Allow building catalog from uploaded video
Browse files
frame_extraction/src/frame_extraction/app.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import json
|
| 4 |
import os
|
|
|
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Any
|
| 7 |
|
|
@@ -9,47 +12,60 @@ import gradio as gr
|
|
| 9 |
import numpy as np
|
| 10 |
from PIL import Image
|
| 11 |
|
| 12 |
-
from .
|
|
|
|
| 13 |
from .matcher import match_frames
|
| 14 |
|
| 15 |
-
|
|
|
|
| 16 |
OUTPUT_DIR = Path(os.getenv("FRAME_OUTPUT_DIR", "app_outputs"))
|
| 17 |
|
| 18 |
|
| 19 |
-
def load_catalog() -> dict[str, Any] | None:
    """Return the parsed catalog JSON, or ``None`` when the catalog file is absent.

    The file location comes from the module-level ``CATALOG_PATH`` and the
    file is decoded as UTF-8.
    """
    catalog_file = Path(CATALOG_PATH)
    # Guard clause: nothing to load when the catalog has not been written yet.
    if not catalog_file.exists():
        return None
    raw_text = catalog_file.read_text(encoding="utf-8")
    return json.loads(raw_text)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
def ensure_output_dirs() -> None:
|
| 27 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 28 |
-
(OUTPUT_DIR / "
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
-
def predict_from_arrays(arrays: list[np.ndarray]) -> tuple[list[dict[str, Any]], list[list[str]]]:
|
| 35 |
-
if
|
| 36 |
-
raise gr.Error("Catalog not
|
| 37 |
|
| 38 |
if not arrays:
|
| 39 |
raise gr.Error("Please upload at least one frame.")
|
| 40 |
|
| 41 |
ensure_output_dirs()
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
saved_paths: list[Path] = []
|
| 45 |
for idx, array in enumerate(arrays):
|
| 46 |
-
|
| 47 |
-
Image.fromarray(array).save(output_path)
|
| 48 |
-
saved_paths.append(output_path)
|
| 49 |
|
| 50 |
-
output_path = OUTPUT_DIR / "
|
| 51 |
cfg = MatchConfig(
|
| 52 |
-
catalog_path=Path(
|
| 53 |
frames_dir=frames_dir,
|
| 54 |
output_path=output_path,
|
| 55 |
top_k=1,
|
|
@@ -58,7 +74,7 @@ def predict_from_arrays(arrays: list[np.ndarray]) -> tuple[list[dict[str, Any]],
|
|
| 58 |
match_frames(cfg)
|
| 59 |
data = json.loads(output_path.read_text(encoding="utf-8"))
|
| 60 |
gallery_items = [
|
| 61 |
-
[item
|
| 62 |
for item in data
|
| 63 |
]
|
| 64 |
return data, gallery_items
|
|
@@ -67,19 +83,31 @@ def predict_from_arrays(arrays: list[np.ndarray]) -> tuple[list[dict[str, Any]],
|
|
| 67 |
def build_interface() -> gr.Blocks:
|
| 68 |
with gr.Blocks() as demo:
|
| 69 |
gr.Markdown("# Character Reference Matcher")
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
label="Upload frames",
|
| 72 |
file_types=["image"],
|
| 73 |
file_count="multiple",
|
| 74 |
)
|
|
|
|
| 75 |
matches_json = gr.JSON(label="Matches")
|
| 76 |
-
gallery = gr.Gallery(label="Reference Thumbnails", columns=2
|
|
|
|
|
|
|
| 77 |
|
| 78 |
-
def
|
| 79 |
arrays = [np.array(Image.open(file.name).convert("RGB")) for file in files]
|
| 80 |
-
return predict_from_arrays(arrays)
|
| 81 |
|
| 82 |
-
|
| 83 |
return demo
|
| 84 |
|
| 85 |
|
|
|
|
| 1 |
+
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import json
|
| 5 |
import os
|
| 6 |
+
import shutil
|
| 7 |
+
import uuid
|
| 8 |
from pathlib import Path
|
| 9 |
from typing import Any
|
| 10 |
|
|
|
|
| 12 |
import numpy as np
|
| 13 |
from PIL import Image
|
| 14 |
|
| 15 |
+
from .catalog import build_catalog
|
| 16 |
+
from .config import CatalogConfig, MatchConfig
|
| 17 |
from .matcher import match_frames
|
| 18 |
|
| 19 |
+
CATALOG_ENV_DEFAULT = "catalog/catalog.json"
|
| 20 |
+
CATALOG_PATH = Path(os.getenv("FRAME_CATALOG", CATALOG_ENV_DEFAULT))
|
| 21 |
OUTPUT_DIR = Path(os.getenv("FRAME_OUTPUT_DIR", "app_outputs"))
|
| 22 |
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
def ensure_output_dirs() -> None:
    """Create the output root and its working subdirectories if missing.

    Creates ``OUTPUT_DIR`` followed by its ``catalogs``, ``videos`` and
    ``frames`` children. Existing directories are left untouched.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    for subdir_name in ("catalogs", "videos", "frames"):
        (OUTPUT_DIR / subdir_name).mkdir(parents=True, exist_ok=True)
|
| 29 |
|
| 30 |
|
| 31 |
+
def build_catalog_from_video(files: list[gr.FileData] | gr.FileData | str) -> tuple[str | None, str]:
    """Build a reference catalog from an uploaded source video.

    The upload is copied into ``OUTPUT_DIR / "videos"`` under a name prefixed
    with a random run id, then ``build_catalog`` is invoked with a fresh
    ``OUTPUT_DIR / "catalogs" / "catalog_<run_id>"`` output directory.

    Args:
        files: The upload payload. The upload button wired to this callback
            uses ``file_count="single"``, so this is normally one file rather
            than a list; a list (first entry is used) is still accepted.
            Each file may be a filesystem path or an object exposing the path
            as ``.name``.

    Returns:
        ``(catalog_path, status_message)`` where ``catalog_path`` is the
        string form of the path returned by ``build_catalog`` and the message
        reports how many entries the catalog's ``"references"`` key holds.

    Raises:
        gr.Error: If no file was supplied.
    """
    if not files:
        raise gr.Error("Please upload a source video first.")

    ensure_output_dirs()

    # A single-file upload is not wrapped in a list; indexing it would fail
    # (or, for a path string, silently pick its first character).
    upload = files[0] if isinstance(files, (list, tuple)) else files
    # Depending on the Gradio version/type the payload is a path string or a
    # file-like wrapper whose ``.name`` is the temp-file path.
    source = Path(upload if isinstance(upload, (str, os.PathLike)) else upload.name)

    # Short random prefix keeps repeated uploads of the same filename apart.
    run_id = uuid.uuid4().hex[:8]
    video_path = OUTPUT_DIR / "videos" / f"{run_id}_{source.name}"
    shutil.copy(source, video_path)

    catalog_dir = OUTPUT_DIR / "catalogs" / f"catalog_{run_id}"
    cfg = CatalogConfig(video_path=video_path, output_dir=catalog_dir)
    catalog_path = build_catalog(cfg)

    catalog_data = json.loads(catalog_path.read_text(encoding="utf-8"))
    ref_count = len(catalog_data.get("references", []))
    message = f"Catalog ready ({ref_count} references)."
    return str(catalog_path), message
|
| 49 |
|
| 50 |
|
| 51 |
+
def predict_from_arrays(arrays: list[np.ndarray], catalog_path: str | None) -> tuple[list[dict[str, Any]], list[list[str]]]:
|
| 52 |
+
if not catalog_path:
|
| 53 |
+
raise gr.Error("Catalog not ready yet. Upload a video first.")
|
| 54 |
|
| 55 |
if not arrays:
|
| 56 |
raise gr.Error("Please upload at least one frame.")
|
| 57 |
|
| 58 |
ensure_output_dirs()
|
| 59 |
+
run_id = uuid.uuid4().hex[:8]
|
| 60 |
+
frames_dir = OUTPUT_DIR / "frames" / run_id
|
| 61 |
+
frames_dir.mkdir(parents=True, exist_ok=True)
|
| 62 |
|
|
|
|
| 63 |
for idx, array in enumerate(arrays):
|
| 64 |
+
Image.fromarray(array).save(frames_dir / f"upload_{idx:03d}.png")
|
|
|
|
|
|
|
| 65 |
|
| 66 |
+
output_path = OUTPUT_DIR / f"matches_{run_id}.json"
|
| 67 |
cfg = MatchConfig(
|
| 68 |
+
catalog_path=Path(catalog_path),
|
| 69 |
frames_dir=frames_dir,
|
| 70 |
output_path=output_path,
|
| 71 |
top_k=1,
|
|
|
|
| 74 |
match_frames(cfg)
|
| 75 |
data = json.loads(output_path.read_text(encoding="utf-8"))
|
| 76 |
gallery_items = [
|
| 77 |
+
[item.get("reference_crop", ""), f"{item.get('character_id', 'unknown')} ({item.get('similarity', 0):.2f})"]
|
| 78 |
for item in data
|
| 79 |
]
|
| 80 |
return data, gallery_items
|
|
|
|
| 83 |
def build_interface() -> gr.Blocks:
    """Assemble the Gradio UI for catalog building and frame matching.

    Layout: a title, a read-only status box, an upload button for the source
    video, an upload button for frames, a JSON view of the matches and a
    gallery of reference thumbnails. Uploading a video runs
    ``build_catalog_from_video`` and stores the resulting catalog path in
    per-session state; uploading frames runs ``predict_from_arrays`` against
    that stored path.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Character Reference Matcher")
        # gr.State is a regular component class, not a generic: it must be
        # called directly rather than subscripted with a type at runtime.
        # Seed it with the default catalog when that file already exists.
        catalog_state = gr.State(str(CATALOG_PATH) if CATALOG_PATH.exists() else None)
        status_box = gr.Textbox(label="Status", value="Upload a video to generate a catalog.", interactive=False)

        video_upload = gr.UploadButton(
            label="Upload source video",
            file_types=["video"],
            file_count="single",
        )

        frame_upload = gr.UploadButton(
            label="Upload frames",
            file_types=["image"],
            file_count="multiple",
        )

        matches_json = gr.JSON(label="Matches")
        gallery = gr.Gallery(label="Reference Thumbnails", columns=2)

        video_upload.upload(build_catalog_from_video, inputs=video_upload, outputs=[catalog_state, status_box])

        def handle_frames(files: list[gr.FileData], catalog_path: str | None) -> tuple[list[dict[str, Any]], list[list[str]]]:
            """Decode the uploaded frames to RGB arrays and match them against the catalog."""
            arrays: list[np.ndarray] = []
            # ``files or []`` lets an empty upload fall through to the
            # validation inside predict_from_arrays instead of raising here.
            for file in files or []:
                # Uploads arrive either as path strings or as objects whose
                # ``.name`` is the temp-file path, depending on Gradio version.
                frame_path = file if isinstance(file, (str, os.PathLike)) else file.name
                # Context manager closes the underlying file handle promptly.
                with Image.open(frame_path) as img:
                    arrays.append(np.array(img.convert("RGB")))
            return predict_from_arrays(arrays, catalog_path)

        frame_upload.upload(handle_frames, inputs=[frame_upload, catalog_state], outputs=[matches_json, gallery])
    return demo
|
| 112 |
|
| 113 |
|