cheenchan commited on
Commit
471e946
·
1 Parent(s): d4f2461

Compat with Gradio upload/gallery

Browse files
frame_extraction/src/frame_extraction/app.py CHANGED
@@ -7,11 +7,11 @@ from typing import Any
7
 
8
  import gradio as gr
9
  import numpy as np
 
10
 
11
  from .config import MatchConfig
12
  from .matcher import match_frames
13
 
14
-
15
  CATALOG_PATH = Path(os.getenv("FRAME_CATALOG", "catalog/catalog.json"))
16
  OUTPUT_DIR = Path(os.getenv("FRAME_OUTPUT_DIR", "app_outputs"))
17
 
@@ -23,25 +23,28 @@ def load_catalog() -> dict[str, Any] | None:
23
  return None
24
 
25
 
 
 
 
 
 
26
  catalog_cache = load_catalog()
27
 
28
 
29
- def predict(image_inputs: list[np.ndarray]) -> tuple[list[dict[str, Any]], list[tuple[str, str]]]:
30
  if catalog_cache is None:
31
  raise gr.Error("Catalog not found. Upload catalog.json or set FRAME_CATALOG.")
32
 
33
- if not image_inputs:
34
  raise gr.Error("Please upload at least one frame.")
35
 
 
36
  frames_dir = OUTPUT_DIR / "inputs"
37
- frames_dir.mkdir(parents=True, exist_ok=True)
38
-
39
- from PIL import Image
40
 
41
  saved_paths: list[Path] = []
42
- for idx, image in enumerate(image_inputs):
43
  output_path = frames_dir / f"upload_{idx:03d}.png"
44
- Image.fromarray(image).save(output_path)
45
  saved_paths.append(output_path)
46
 
47
  output_path = OUTPUT_DIR / "matches.json"
@@ -55,7 +58,8 @@ def predict(image_inputs: list[np.ndarray]) -> tuple[list[dict[str, Any]], list[
55
  match_frames(cfg)
56
  data = json.loads(output_path.read_text(encoding="utf-8"))
57
  gallery_items = [
58
- (item["reference_crop"], f"{item['character_id']} ({item['similarity']:.2f})") for item in data
 
59
  ]
60
  return data, gallery_items
61
 
@@ -63,28 +67,25 @@ def predict(image_inputs: list[np.ndarray]) -> tuple[list[dict[str, Any]], list[
63
  def build_interface() -> gr.Blocks:
64
  with gr.Blocks() as demo:
65
  gr.Markdown("# Character Reference Matcher")
66
- image_input = gr.UploadButton(
67
  label="Upload frames",
68
  file_types=["image"],
69
  file_count="multiple",
70
  )
71
  matches_json = gr.JSON(label="Matches")
72
- gallery = gr.Gallery(label="Reference Thumbnails").style(grid=2)
73
-
74
- def _handle_upload(files: list[gr.FileData]):
75
- from PIL import Image
76
- import numpy as np
77
 
78
- images = [np.array(Image.open(file.name)) for file in files]
79
- return predict(images)
 
80
 
81
- image_input.upload(_handle_upload, inputs=image_input, outputs=[matches_json, gallery])
82
  return demo
83
 
84
 
85
  def main() -> None:
86
- interface = build_interface()
87
- interface.launch()
88
 
89
 
90
  if __name__ == "__main__":
 
7
 
8
  import gradio as gr
9
  import numpy as np
10
+ from PIL import Image
11
 
12
  from .config import MatchConfig
13
  from .matcher import match_frames
14
 
 
15
  CATALOG_PATH = Path(os.getenv("FRAME_CATALOG", "catalog/catalog.json"))
16
  OUTPUT_DIR = Path(os.getenv("FRAME_OUTPUT_DIR", "app_outputs"))
17
 
 
23
  return None
24
 
25
 
26
+ def ensure_output_dirs() -> None:
27
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
28
+ (OUTPUT_DIR / "inputs").mkdir(parents=True, exist_ok=True)
29
+
30
+
31
  catalog_cache = load_catalog()
32
 
33
 
34
+ def predict_from_arrays(arrays: list[np.ndarray]) -> tuple[list[dict[str, Any]], list[list[str]]]:
35
  if catalog_cache is None:
36
  raise gr.Error("Catalog not found. Upload catalog.json or set FRAME_CATALOG.")
37
 
38
+ if not arrays:
39
  raise gr.Error("Please upload at least one frame.")
40
 
41
+ ensure_output_dirs()
42
  frames_dir = OUTPUT_DIR / "inputs"
 
 
 
43
 
44
  saved_paths: list[Path] = []
45
+ for idx, array in enumerate(arrays):
46
  output_path = frames_dir / f"upload_{idx:03d}.png"
47
+ Image.fromarray(array).save(output_path)
48
  saved_paths.append(output_path)
49
 
50
  output_path = OUTPUT_DIR / "matches.json"
 
58
  match_frames(cfg)
59
  data = json.loads(output_path.read_text(encoding="utf-8"))
60
  gallery_items = [
61
+ [item["reference_crop"], f"{item['character_id']} ({item['similarity']:.2f})"]
62
+ for item in data
63
  ]
64
  return data, gallery_items
65
 
 
67
  def build_interface() -> gr.Blocks:
68
  with gr.Blocks() as demo:
69
  gr.Markdown("# Character Reference Matcher")
70
+ upload = gr.UploadButton(
71
  label="Upload frames",
72
  file_types=["image"],
73
  file_count="multiple",
74
  )
75
  matches_json = gr.JSON(label="Matches")
76
+ gallery = gr.Gallery(label="Reference Thumbnails", columns=2, height="auto")
 
 
 
 
77
 
78
+ def handle_upload(files: list[gr.FileData]) -> tuple[list[dict[str, Any]], list[list[str]]]:
79
+ arrays = [np.array(Image.open(file.name).convert("RGB")) for file in files]
80
+ return predict_from_arrays(arrays)
81
 
82
+ upload.upload(handle_upload, inputs=upload, outputs=[matches_json, gallery])
83
  return demo
84
 
85
 
86
  def main() -> None:
87
+ demo = build_interface()
88
+ demo.launch()
89
 
90
 
91
  if __name__ == "__main__":