cheenchan commited on
Commit
d4f2461
·
1 Parent(s): d564157

Use upload button for multiple frames

Browse files
frame_extraction/src/frame_extraction/app.py CHANGED
@@ -7,7 +7,6 @@ from typing import Any
7
 
8
  import gradio as gr
9
  import numpy as np
10
- import torch
11
 
12
  from .config import MatchConfig
13
  from .matcher import match_frames
@@ -27,19 +26,23 @@ def load_catalog() -> dict[str, Any] | None:
27
  catalog_cache = load_catalog()
28
 
29
 
30
- def predict(images: list[np.ndarray]) -> tuple[list[str], list[str]]:
31
  if catalog_cache is None:
32
  raise gr.Error("Catalog not found. Upload catalog.json or set FRAME_CATALOG.")
33
 
34
- if not images:
35
  raise gr.Error("Please upload at least one frame.")
36
 
37
  frames_dir = OUTPUT_DIR / "inputs"
38
  frames_dir.mkdir(parents=True, exist_ok=True)
39
- for idx, image in enumerate(images):
40
- from PIL import Image
41
 
42
- Image.fromarray(image).save(frames_dir / f"upload_{idx:03d}.png")
 
 
 
 
 
 
43
 
44
  output_path = OUTPUT_DIR / "matches.json"
45
  cfg = MatchConfig(
@@ -51,23 +54,37 @@ def predict(images: list[np.ndarray]) -> tuple[list[str], list[str]]:
51
  )
52
  match_frames(cfg)
53
  data = json.loads(output_path.read_text(encoding="utf-8"))
54
- gallery_items = [(item["reference_crop"], f"{item['character_id']} ({item['similarity']:.2f})") for item in data]
 
 
55
  return data, gallery_items
56
 
57
 
58
- with gr.Blocks() as demo:
59
- gr.Markdown("# Character Reference Matcher")
60
- with gr.Row():
61
- image_input = gr.Image(type="numpy", image_mode="RGB", label="Upload frames", multiple=True)
62
- submit = gr.Button("Match Characters")
63
- matches_json = gr.JSON(label="Matches")
64
- gallery = gr.Gallery(label="Reference Thumbnails").style(grid=2)
 
 
 
 
 
 
 
 
 
 
65
 
66
- submit.click(predict, inputs=image_input, outputs=[matches_json, gallery])
 
67
 
68
 
69
  def main() -> None:
70
- demo.launch()
 
71
 
72
 
73
  if __name__ == "__main__":
 
7
 
8
  import gradio as gr
9
  import numpy as np
 
10
 
11
  from .config import MatchConfig
12
  from .matcher import match_frames
 
26
  catalog_cache = load_catalog()
27
 
28
 
29
+ def predict(image_inputs: list[np.ndarray]) -> tuple[list[dict[str, Any]], list[tuple[str, str]]]:
30
  if catalog_cache is None:
31
  raise gr.Error("Catalog not found. Upload catalog.json or set FRAME_CATALOG.")
32
 
33
+ if not image_inputs:
34
  raise gr.Error("Please upload at least one frame.")
35
 
36
  frames_dir = OUTPUT_DIR / "inputs"
37
  frames_dir.mkdir(parents=True, exist_ok=True)
 
 
38
 
39
+ from PIL import Image
40
+
41
+ saved_paths: list[Path] = []
42
+ for idx, image in enumerate(image_inputs):
43
+ output_path = frames_dir / f"upload_{idx:03d}.png"
44
+ Image.fromarray(image).save(output_path)
45
+ saved_paths.append(output_path)
46
 
47
  output_path = OUTPUT_DIR / "matches.json"
48
  cfg = MatchConfig(
 
54
  )
55
  match_frames(cfg)
56
  data = json.loads(output_path.read_text(encoding="utf-8"))
57
+ gallery_items = [
58
+ (item["reference_crop"], f"{item['character_id']} ({item['similarity']:.2f})") for item in data
59
+ ]
60
  return data, gallery_items
61
 
62
 
63
+ def build_interface() -> gr.Blocks:
64
+ with gr.Blocks() as demo:
65
+ gr.Markdown("# Character Reference Matcher")
66
+ image_input = gr.UploadButton(
67
+ label="Upload frames",
68
+ file_types=["image"],
69
+ file_count="multiple",
70
+ )
71
+ matches_json = gr.JSON(label="Matches")
72
+ gallery = gr.Gallery(label="Reference Thumbnails").style(grid=2)
73
+
74
+ def _handle_upload(files: list[gr.FileData]):
75
+ from PIL import Image
76
+ import numpy as np
77
+
78
+ images = [np.array(Image.open(file.name)) for file in files]
79
+ return predict(images)
80
 
81
+ image_input.upload(_handle_upload, inputs=image_input, outputs=[matches_json, gallery])
82
+ return demo
83
 
84
 
85
  def main() -> None:
86
+ interface = build_interface()
87
+ interface.launch()
88
 
89
 
90
  if __name__ == "__main__":