cheenchan committed on
Commit
45f025a
·
1 Parent(s): de0fca3

Auto-display catalog characters after video upload

Browse files
frame_extraction/src/frame_extraction/app.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  import shutil
7
  import uuid
8
  from pathlib import Path
9
- from typing import Any
10
 
11
  import gradio as gr
12
  import numpy as np
@@ -28,7 +28,30 @@ def ensure_output_dirs() -> None:
28
  (OUTPUT_DIR / "frames").mkdir(parents=True, exist_ok=True)
29
 
30
 
31
- def build_catalog_from_video(file: gr.FileData) -> tuple[str | None, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  if file is None:
33
  raise gr.Error("Please upload a source video first.")
34
 
@@ -41,10 +64,8 @@ def build_catalog_from_video(file: gr.FileData) -> tuple[str | None, str]:
41
  catalog_dir = OUTPUT_DIR / "catalogs" / f"catalog_{run_id}"
42
  cfg = CatalogConfig(video_path=video_path, output_dir=catalog_dir)
43
  catalog_path = build_catalog(cfg)
44
- catalog_data = json.loads(catalog_path.read_text(encoding="utf-8"))
45
- ref_count = len(catalog_data.get("references", []))
46
- message = f"Catalog ready ({ref_count} references)."
47
- return str(catalog_path), message
48
 
49
 
50
  def predict_from_arrays(arrays: list[np.ndarray], catalog_path: str | None) -> tuple[list[dict[str, Any]], list[list[str]]]:
@@ -75,15 +96,28 @@ def predict_from_arrays(arrays: list[np.ndarray], catalog_path: str | None) -> t
75
  gallery_items = [
76
  [item.get("reference_crop", ""), f"{item.get('character_id', 'unknown')} ({item.get('similarity', 0):.2f})"]
77
  for item in data
 
78
  ]
79
  return data, gallery_items
80
 
81
 
 
 
 
 
 
 
 
82
  def build_interface() -> gr.Blocks:
 
 
83
  with gr.Blocks() as demo:
84
  gr.Markdown("# Character Reference Matcher")
85
- catalog_state = gr.State(str(CATALOG_PATH) if CATALOG_PATH.exists() else None)
86
- status_box = gr.Textbox(label="Status", value="Upload a video to generate a catalog.", interactive=False)
 
 
 
87
 
88
  video_upload = gr.File(label="Source video", file_types=["video"], height="auto")
89
  frame_upload = gr.UploadButton(
@@ -93,15 +127,18 @@ def build_interface() -> gr.Blocks:
93
  )
94
 
95
  matches_json = gr.JSON(label="Matches")
96
- gallery = gr.Gallery(label="Reference Thumbnails", columns=2)
 
 
 
97
 
98
- video_upload.change(build_catalog_from_video, inputs=video_upload, outputs=[catalog_state, status_box])
99
 
100
- def handle_frames(files: list[gr.FileData], catalog_path: str | None) -> tuple[list[dict[str, Any]], list[list[str]]]:
101
  arrays = [np.array(Image.open(file.name).convert("RGB")) for file in files]
102
  return predict_from_arrays(arrays, catalog_path)
103
 
104
- frame_upload.upload(handle_frames, inputs=[frame_upload, catalog_state], outputs=[matches_json, gallery])
105
  return demo
106
 
107
 
 
6
  import shutil
7
  import uuid
8
  from pathlib import Path
9
+ from typing import Any, Tuple
10
 
11
  import gradio as gr
12
  import numpy as np
 
28
  (OUTPUT_DIR / "frames").mkdir(parents=True, exist_ok=True)
29
 
30
 
31
def summarize_catalog(catalog_path: Path) -> tuple[str, list[dict[str, Any]], list[list[str]]]:
    """Summarize a catalog JSON file as a status message, an index and gallery items.

    Args:
        catalog_path: Path of the catalog JSON file to read.

    Returns:
        A ``(message, index, gallery)`` triple:

        * ``message`` -- status text reporting how many references were found,
          or why the catalog could not be used.
        * ``index`` -- one dict per reference holding its ``character_id``,
          ``reference_path``, ``frame_path`` and ``sharpness`` values
          (``None`` for any key the reference lacks).
        * ``gallery`` -- ``[reference_path, character_id]`` pairs for every
          reference that has a non-empty ``reference_path``.

        A missing or unparsable catalog yields an explanatory message with an
        empty index and gallery instead of raising.
    """
    # Read first and handle the failure, rather than checking exists() and
    # then reading: the file could disappear between the two steps.
    try:
        raw = catalog_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return ("Catalog not found.", [], [])

    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return ("Catalog could not be parsed.", [], [])

    # `"references": null` makes .get() return None even with a default, and
    # the top level may not be an object at all; normalise both to "no refs".
    raw_references = data.get("references") if isinstance(data, dict) else None
    references = [ref for ref in raw_references or [] if isinstance(ref, dict)]

    message = f"Catalog ready ({len(references)} references)."
    index_keys = ("character_id", "reference_path", "frame_path", "sharpness")
    index = [{key: ref.get(key) for key in index_keys} for ref in references]
    # References without an image path are left out of the gallery.
    gallery = [
        [ref["reference_path"], ref.get("character_id", "unknown")]
        for ref in references
        if ref.get("reference_path")
    ]
    return message, index, gallery
52
+
53
+
54
+ def build_catalog_from_video(file: gr.FileData) -> tuple[str | None, str, list[dict[str, Any]], list[list[str]]]:
55
  if file is None:
56
  raise gr.Error("Please upload a source video first.")
57
 
 
64
  catalog_dir = OUTPUT_DIR / "catalogs" / f"catalog_{run_id}"
65
  cfg = CatalogConfig(video_path=video_path, output_dir=catalog_dir)
66
  catalog_path = build_catalog(cfg)
67
+ message, index, gallery = summarize_catalog(catalog_path)
68
+ return str(catalog_path), message, index, gallery
 
 
69
 
70
 
71
  def predict_from_arrays(arrays: list[np.ndarray], catalog_path: str | None) -> tuple[list[dict[str, Any]], list[list[str]]]:
 
96
  gallery_items = [
97
  [item.get("reference_crop", ""), f"{item.get('character_id', 'unknown')} ({item.get('similarity', 0):.2f})"]
98
  for item in data
99
+ if item.get("reference_crop")
100
  ]
101
  return data, gallery_items
102
 
103
 
104
def load_initial_catalog() -> tuple[str | None, str, list[dict[str, Any]], list[list[str]]]:
    """Return the catalog state the interface should start with.

    Returns:
        A ``(catalog_path, status, index, gallery)`` tuple. When the file at
        ``CATALOG_PATH`` exists, the path is returned as a string together
        with the values produced by ``summarize_catalog``. Otherwise the path
        is ``None``, the status asks for a video upload, and the index and
        gallery are empty lists.
    """
    # Guard clause: nothing on disk yet, so start with an empty view.
    if not CATALOG_PATH.exists():
        return None, "Upload a video to generate a catalog.", [], []

    status, entries, thumbnails = summarize_catalog(CATALOG_PATH)
    return str(CATALOG_PATH), status, entries, thumbnails
109
+
110
+
111
  def build_interface() -> gr.Blocks:
112
+ initial_catalog, initial_status, initial_index, initial_gallery = load_initial_catalog()
113
+
114
  with gr.Blocks() as demo:
115
  gr.Markdown("# Character Reference Matcher")
116
+ catalog_state = gr.State(initial_catalog)
117
+
118
+ status_box = gr.Textbox(label="Status", value=initial_status, interactive=False)
119
+ catalog_json = gr.JSON(label="Character Index", value=initial_index)
120
+ catalog_gallery = gr.Gallery(label="Catalog Characters", columns=4, value=initial_gallery)
121
 
122
  video_upload = gr.File(label="Source video", file_types=["video"], height="auto")
123
  frame_upload = gr.UploadButton(
 
127
  )
128
 
129
  matches_json = gr.JSON(label="Matches")
130
+ match_gallery = gr.Gallery(label="Matched Characters", columns=3)
131
+
132
+ def on_video_upload(file: gr.FileData) -> tuple[str | None, str, list[dict[str, Any]], list[list[str]]]:
133
+ return build_catalog_from_video(file)
134
 
135
+ video_upload.change(on_video_upload, inputs=video_upload, outputs=[catalog_state, status_box, catalog_json, catalog_gallery])
136
 
137
+ def on_frames_upload(files: list[gr.FileData], catalog_path: str | None) -> tuple[list[dict[str, Any]], list[list[str]]]:
138
  arrays = [np.array(Image.open(file.name).convert("RGB")) for file in files]
139
  return predict_from_arrays(arrays, catalog_path)
140
 
141
+ frame_upload.upload(on_frames_upload, inputs=[frame_upload, catalog_state], outputs=[matches_json, match_gallery])
142
  return demo
143
 
144