Spaces:
Sleeping
Sleeping
| # app_cache.py | |
| # Purpose: | |
| # - Same UI flow (upload -> load frames -> annotate -> generate mask -> track) | |
| # - After tracking, enable "Save Cache" | |
| # - You can create multiple caches by repeating the workflow | |
| # | |
| # Cache contents per example: | |
| # cache/<key>/ | |
| # meta.pkl | |
| # frames/*.jpg | |
| # state_tensors.pt (must3r_feats, must3r_outputs, sam2_input_images, images_tensor) on CPU | |
| # vis_img.png (mask + prompt preview of the tracked start frame) | |
| # output_tracking.mp4 | |
| # | |
| # Notes: | |
| # - We do NOT pickle views/resize_funcs (recomputed on load). | |
| # - We store frames as JPEG to avoid pickling PIL and to be deterministic/reloadable. | |
| import spaces | |
| import subprocess | |
| import sys, os | |
| from pathlib import Path | |
| import math | |
| import hashlib | |
| import pickle | |
| from datetime import datetime | |
| from typing import Any, Dict, List, Tuple | |
| import importlib, site | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import cv2 | |
| import logging | |
| # ---------------------------- | |
| # Project bootstrap | |
| # ---------------------------- | |
# Paths are anchored at the directory containing this file.
ROOT = Path(__file__).resolve().parent
SAM2 = ROOT / "sam2-src"
CKPT = SAM2 / "checkpoints" / "sam2.1_hiera_large.pt"

# Download the SAM2 checkpoints (via the vendored script) if the large
# checkpoint file is not present yet.
if not CKPT.exists():
    subprocess.check_call(["bash", "download_ckpts.sh"], cwd=SAM2 / "checkpoints")

# Install sam2 in editable mode from the vendored source tree if it is not
# importable yet (the second call adds the "notebooks" extra).
try:
    import sam2.build_sam # noqa
except ModuleNotFoundError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src"], cwd=ROOT)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src[notebooks]"], cwd=ROOT)

# Build (cythonize) and install asmk from the vendored source if importing it
# fails for any reason.
try:
    import asmk.index # noqa: F401
except Exception:
    # NOTE(review): no shell is involved, so "*.pyx" reaches cythonize
    # unexpanded — presumably cythonize expands the glob itself; confirm.
    # These two calls use cwd-relative paths (not ROOT), so they assume the
    # process was started from ROOT.
    subprocess.check_call(["cythonize", "*.pyx"], cwd="./asmk-src/cython")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "./asmk-src", "--no-build-isolation"])

# Download the private model repository into ./private (relative to the
# current working directory) unless that directory already exists.
if not os.path.exists("./private"):
    from huggingface_hub import snapshot_download
    snapshot_download(
        repo_id="nycu-cplab/3AM",
        local_dir="./private",
        repo_type="model",
    )

# Re-register every site-packages directory and drop cached finder state so
# packages pip installed above become importable in this running interpreter.
for sp in site.getsitepackages():
    site.addsitedir(sp)
importlib.invalidate_caches()
| # ---------------------------- | |
| # Logging | |
| # ---------------------------- | |
# Root logging: INFO and above, timestamped, written to stdout.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=[_stdout_handler])

# Named logger shared by the helpers and callbacks in this module.
logger = logging.getLogger("app_cache")
| # ---------------------------- | |
| # Engine imports | |
| # ---------------------------- | |
| from engine import ( | |
| get_predictors, | |
| get_views, | |
| prepare_sam2_inputs, | |
| must3r_features_and_output, | |
| get_single_frame_mask, | |
| get_tracked_masks, | |
| ) | |
| # ---------------------------- | |
| # Globals | |
| # ---------------------------- | |
# Predictors are created lazily by load_models(); both start unset.
PREDICTOR_ORIGINAL = PREDICTOR = None
# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
def load_models():
    """Create the predictors on first use and return them.

    The pair is stored in the module globals ``PREDICTOR_ORIGINAL`` and
    ``PREDICTOR``. ``get_predictors`` is only called again if either is unset.

    Returns:
        tuple: ``(PREDICTOR_ORIGINAL, PREDICTOR)``.
    """
    global PREDICTOR_ORIGINAL, PREDICTOR
    already_loaded = PREDICTOR is not None and PREDICTOR_ORIGINAL is not None
    if not already_loaded:
        logger.info(f"Initializing models on device: {DEVICE}...")
        PREDICTOR_ORIGINAL, PREDICTOR = get_predictors(device=DEVICE)
        logger.info("Models loaded successfully.")
    return PREDICTOR_ORIGINAL, PREDICTOR
# Disable autograd for the thread that imports this module. The function form of
# set_grad_enabled replaces entering a no_grad() context that was never exited.
# Grad mode is thread-local in PyTorch, so this does NOT affect other threads:
# code running elsewhere (e.g. Gradio worker threads for event handlers) must
# disable gradients on its own.
torch.set_grad_enabled(False)
| # ---------------------------- | |
| # Video / visualization helpers | |
| # ---------------------------- | |
def video_to_frames(video_path, interval=1):
    """Decode ``video_path`` and keep every ``interval``-th frame as an RGB PIL image.

    Args:
        video_path: Path to a video file readable by OpenCV.
        interval: Sampling step. Frames whose 0-based index is a multiple of it
            are kept. Values below 1 are treated as 1, since 0 would make the
            modulo below raise ZeroDivisionError.

    Returns:
        list[PIL.Image.Image]: The sampled frames, empty if nothing could be decoded.
    """
    interval = max(1, int(interval))
    logger.info(f"Extracting frames from video: {video_path} with interval={interval}")
    cap = cv2.VideoCapture(video_path)
    frames = []
    count = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if count % interval == 0:
                # OpenCV decodes to BGR; PIL expects RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            count += 1
    finally:
        # Release the decoder even if a conversion fails mid-stream.
        cap.release()
    logger.info(f"Extracted {len(frames)} frames (sampled from {count} total).")
    return frames
# Marker colour per prompt label, checked in this order with ``==``.
# Any label not listed falls back to yellow.
_POINT_COLORS = ((1, "green"), (0, "red"), (2, "blue"), (3, "cyan"))


def _point_color(lbl):
    """Return the marker colour for prompt label ``lbl``."""
    for value, color in _POINT_COLORS:
        if lbl == value:
            return color
    return "yellow"


def draw_points(image_pil, points, labels):
    """Return a copy of ``image_pil`` with a filled, white-outlined dot per prompt.

    Label colours: 1 green (positive), 0 red (negative), 2 blue (box top-left),
    3 cyan (box bottom-right), anything else yellow. The input image is not modified.
    """
    canvas = image_pil.copy()
    pen = ImageDraw.Draw(canvas)
    radius = 7.5
    for (px, py), label in zip(points, labels):
        pen.ellipse(
            (px - radius, py - radius, px + radius, py + radius),
            fill=_point_color(label),
            outline="white",
        )
    return canvas
def overlay_mask(image_pil, mask, color=(255, 0, 0), alpha=0.5):
    """Blend ``color`` into ``image_pil`` wherever ``mask > 0``.

    Args:
        image_pil: Image to tint. It is converted to a numpy array, so a PIL image is expected.
        mask: Array whose positive entries mark foreground, or ``None``. If its
            first two dims differ from the image's, it is nearest-neighbour resized.
        color: Fill colour for masked pixels.
        alpha: Weight of the tinted copy in the blend.

    Returns:
        A new PIL image, or ``image_pil`` itself when ``mask`` is ``None``.
    """
    if mask is None:
        return image_pil
    region = mask > 0
    base = np.array(image_pil)
    height, width = base.shape[:2]
    if region.shape[0] != height or region.shape[1] != width:
        region = cv2.resize(
            region.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST
        ).astype(bool)
    tinted = base.copy()
    tinted[region] = np.array(color, dtype=np.uint8)
    # Unmasked pixels are identical in both inputs, so only the mask region changes.
    blended = cv2.addWeighted(tinted, alpha, base, 1 - alpha, 0)
    return Image.fromarray(blended)
def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4", fps=24):
    """Write ``frames`` to an MP4, overlaying each frame's mask (red, alpha 0.6) when present.

    Args:
        frames: Sequence of RGB PIL images. The first one fixes the output size.
        masks_dict: Mapping ``frame_index -> mask``. Frames without an entry are
            written unchanged. ``None`` is treated as an empty mapping.
        output_path: Destination file.
        fps: Output frame rate. Non-positive or NaN values fall back to 24.

    Returns:
        ``output_path``, or ``None`` when ``frames`` is empty.

    Raises:
        RuntimeError: If OpenCV cannot open a video writer for ``output_path``.
    """
    if not frames:
        return None
    logger.info(f"Creating video output at {output_path} with {len(frames)} frames.")
    masks_dict = masks_dict or {}
    fps = float(fps)
    if not (fps > 0.0):  # also rejects NaN
        fps = 24.0
    h, w = np.array(frames[0]).shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
    if not out.isOpened():
        raise RuntimeError(f"Could not open video writer for {output_path}")
    try:
        for idx, frame in enumerate(frames):
            mask = masks_dict.get(idx)
            if mask is not None:
                frame = overlay_mask(frame, mask, color=(255, 0, 0), alpha=0.6)
            frame_np = np.array(frame)
            if frame_np.shape[:2] != (h, w):
                # VideoWriter silently drops frames whose size differs from its own.
                frame_np = cv2.resize(frame_np, (w, h), interpolation=cv2.INTER_LINEAR)
            out.write(cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalize the container, even if a frame fails to render.
        out.release()
    return output_path
| # ---------------------------- | |
| # Runtime estimation helpers | |
| # ---------------------------- | |
def estimate_video_fps(video_path: str) -> float:
    """Return the FPS reported for ``video_path``, or 24.0 when it is missing or not positive."""
    capture = cv2.VideoCapture(video_path)
    reported = float(capture.get(cv2.CAP_PROP_FPS))
    capture.release()
    if reported > 0.0:
        return reported
    return 24.0
def estimate_total_frames(video_path: str) -> int:
    """Return the frame count reported for ``video_path``, never less than 1."""
    capture = cv2.VideoCapture(video_path)
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    capture.release()
    return max(frame_count, 1)
# Upper bound, in seconds, for any requested GPU duration.
MAX_GPU_SECONDS = 600


def clamp_duration(sec: int) -> int:
    """Clamp ``sec`` into ``[1, MAX_GPU_SECONDS]`` and truncate the result to an int."""
    lower_bounded = max(1, sec)
    return int(min(MAX_GPU_SECONDS, lower_bounded))
def get_duration_must3r_features(video_path, interval):
    """Estimate GPU seconds for feature extraction: 2 s per sampled frame, clamped."""
    total_frames = estimate_total_frames(video_path)
    step = max(1, int(interval))
    seconds_per_frame = 2
    sampled_frames = math.ceil(total_frames / step)
    return clamp_duration(int(sampled_frames * seconds_per_frame))
def get_duration_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """Estimate GPU seconds for tracking: 2 s per frame, clamped.

    Only ``sam2_input_images.shape[0]`` is inspected. If it cannot be read,
    100 frames are assumed. The other parameters mirror ``run_tracking``'s signature.
    """
    try:
        num_frames = int(sam2_input_images.shape[0])
    except Exception:
        num_frames = 100
    return clamp_duration(int(num_frames * 2))
| # ---------------------------- | |
| # GPU functions | |
| # ---------------------------- | |
def process_video_and_features(video_path, interval):
    """Decode the video, then compute MUSt3R features and SAM2 inputs for the sampled frames.

    Args:
        video_path: Path to the uploaded video.
        interval: Frame sampling interval. Values below 1 are treated as 1.

    Returns:
        tuple: ``(pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs,
        sam2_input_images, images_tensor)``.

    Raises:
        ValueError: If no frames could be decoded.
    """
    logger.info(f"GPU: feature extraction interval={interval}")
    load_models()
    pil_imgs = video_to_frames(video_path, interval=max(1, int(interval)))
    if not pil_imgs:
        raise ValueError("Could not extract frames.")
    # Grad mode is thread-local, so a module-level no_grad does not reach the
    # thread this callback runs in (Gradio typically uses worker threads).
    # Without it, the feature tensors kept in the session state could hold
    # autograd graphs alive.
    with torch.no_grad():
        views, resize_funcs = get_views(pil_imgs)
        must3r_feats, must3r_outputs = must3r_features_and_output(views, device=DEVICE)
        sam2_input_images, images_tensor = prepare_sam2_inputs(views, pil_imgs, resize_funcs)
    return pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor
def generate_frame_mask(image_tensor, points, labels, original_size):
    """Predict one frame's mask from click/box prompts.

    Args:
        image_tensor: The frame's SAM2 input with a leading batch dim of 1.
        points: ``[[x, y], ...]`` in original-frame pixel coordinates.
        labels: One int per point (1 positive, 0 negative, 2/3 box corners).
        original_size: ``(width, height)`` of the original frame.

    Returns:
        numpy.ndarray: The predicted mask, squeezed and moved to CPU.
    """
    logger.info(f"GPU: generate mask points={len(points)}")
    load_models()
    # Grad mode is thread-local; disable autograd here, since a module-level
    # setting does not cover the callback thread.
    with torch.no_grad():
        pts_tensor = torch.tensor(points, dtype=torch.float32).unsqueeze(0).to(DEVICE)
        lbl_tensor = torch.tensor(labels, dtype=torch.int32).unsqueeze(0).to(DEVICE)
        w, h = original_size
        # Rescale from original pixels to a 1024-wide/high frame
        # (presumably the SAM2 input resolution — confirm in engine).
        pts_tensor[..., 0] /= (w / 1024.0)
        pts_tensor[..., 1] /= (h / 1024.0)
        mask = get_single_frame_mask(
            image=image_tensor,
            predictor_original=PREDICTOR_ORIGINAL,
            points=pts_tensor,
            labels=lbl_tensor,
            device=DEVICE,
        )
        return mask.squeeze().cpu().numpy()
def run_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """Propagate ``first_frame_mask`` from frame ``start_idx`` through the sampled frames.

    Args:
        sam2_input_images: SAM2 inputs for all sampled frames.
        must3r_feats: MUSt3R features for the same frames.
        must3r_outputs: MUSt3R outputs for the same frames.
        start_idx: Index of the annotated frame to start from.
        first_frame_mask: Mask for that frame. Values ``> 0`` are foreground.

    Returns:
        The result of ``get_tracked_masks``. Callers in this file use it as a
        ``{frame_index: mask}`` mapping.
    """
    logger.info(f"GPU: tracking start_idx={start_idx}")
    load_models()
    # Grad mode is thread-local; disable autograd for this callback thread.
    with torch.no_grad():
        # as_tensor also accepts a tensor without the copy-construct warning.
        mask_tensor = torch.as_tensor(first_frame_mask).to(DEVICE) > 0
        tracked_masks = get_tracked_masks(
            sam2_input_images=sam2_input_images,
            must3r_feats=must3r_feats,
            must3r_outputs=must3r_outputs,
            start_idx=start_idx,
            first_frame_mask=mask_tensor,
            predictor=PREDICTOR,
            predictor_original=PREDICTOR_ORIGINAL,
            device=DEVICE,
        )
    return tracked_masks
| # ---------------------------- | |
| # Cache utilities | |
| # ---------------------------- | |
# Root under which every saved example gets its own <key>/ directory (cwd-relative).
CACHE_DIR = Path("tmpcache") / "cache"
CACHE_DIR.mkdir(exist_ok=True, parents=True)
| def _make_cache_key(video_path: str, interval: int, start_idx: int) -> str: | |
| name = Path(video_path).name if video_path else "video" | |
| stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") | |
| s = f"{name}|interval={interval}|start={start_idx}|{stamp}" | |
| return hashlib.sha256(s.encode("utf-8")).hexdigest()[:16] | |
def _cache_paths(key: str) -> Dict[str, Path]:
    """Ensure ``CACHE_DIR/<key>`` exists and return the path of every cache artifact.

    Keys: ``base`` (the directory itself), ``meta``, ``frames_dir``, ``vis_img``,
    ``tensors`` and ``video``.
    """
    root = CACHE_DIR / key
    root.mkdir(parents=True, exist_ok=True)
    artifact_names = {
        "meta": "meta.pkl",
        "frames_dir": "frames",
        "vis_img": "vis_img.png",
        "tensors": "state_tensors.pt",
        "video": "output_tracking.mp4",
    }
    paths = {"base": root}
    paths.update({label: root / fname for label, fname in artifact_names.items()})
    return paths
def _save_frames_as_jpg(pil_imgs: List[Image.Image], frames_dir: Path, quality: int = 95) -> None:
    """Write each image to ``frames_dir/<index:06d>.jpg`` at ``quality`` with subsampling=0.

    ``frames_dir`` (and its parents) is created if missing.
    """
    frames_dir.mkdir(parents=True, exist_ok=True)
    for index, image in enumerate(pil_imgs):
        target = frames_dir / f"{index:06d}.jpg"
        image.save(target, "JPEG", quality=quality, subsampling=0)
| def _to_cpu(obj: Any) -> Any: | |
| if torch.is_tensor(obj): | |
| return obj.detach().to("cpu") | |
| if isinstance(obj, dict): | |
| return {k: _to_cpu(v) for k, v in obj.items()} | |
| if isinstance(obj, (list, tuple)): | |
| out = [_to_cpu(v) for v in obj] | |
| return type(obj)(out) if isinstance(obj, tuple) else out | |
| return obj | |
def save_full_cache_from_state(state: Dict[str, Any]) -> str:
    """Persist one tracked example from the UI ``state`` into ``CACHE_DIR/<key>/``.

    Writes the sampled frames (JPEG), the preview image, the feature tensors
    (detached, on CPU), a copy of the tracking video, and finally ``meta.pkl``.
    Everything is validated before the first file is written, so a rejected
    save does not leave a half-populated cache directory.

    Args:
        state: Session state dict built by the upload/annotate/track callbacks.

    Returns:
        The new cache key.

    Raises:
        ValueError: If ``state`` is empty, a required field is missing/None, or
            no preview image is available.
        FileNotFoundError: If the tracking output video does not exist.
    """
    if not state:
        raise ValueError("Empty state.")
    required = [
        "pil_imgs",
        "must3r_feats",
        "must3r_outputs",
        "sam2_input_images",
        "images_tensor",
        "output_video_path",
        "video_path",
        "interval",
        "fps_in",
        "fps_out",
        "last_tracking_start_idx",
    ]
    missing = [k for k in required if state.get(k) is None]
    if missing:
        raise ValueError(f"State missing fields: {missing}")

    # Use the tracking-time snapshots ("last_*") when the state carries them,
    # falling back to the live annotation. start_idx, points and labels below
    # already come from tracking time, so the preview and mask should too.
    vis_img = state.get("last_vis_img")
    if vis_img is None:
        vis_img = state.get("current_vis_img")
    if vis_img is None:
        raise ValueError("State missing fields: ['current_vis_img']")
    first_frame_mask = state.get("last_mask")
    if first_frame_mask is None:
        first_frame_mask = state.get("current_mask")

    src = Path(state["output_video_path"])
    if not src.exists():
        raise FileNotFoundError(f"Output video not found: {src}")

    key = _make_cache_key(
        str(state["video_path"]),
        int(state["interval"]),
        int(state["last_tracking_start_idx"]),
    )
    paths = _cache_paths(key)
    _save_frames_as_jpg(state["pil_imgs"], paths["frames_dir"])
    vis_img.save(paths["vis_img"])

    logger.info("Saving tensors to cache...")
    torch.save(
        {
            "must3r_feats": _to_cpu(state["must3r_feats"]),
            "must3r_outputs": _to_cpu(state["must3r_outputs"]),
            "sam2_input_images": _to_cpu(state["sam2_input_images"]),
            "images_tensor": _to_cpu(state["images_tensor"]),
        },
        paths["tensors"],
    )

    dst = paths["video"]
    if src.resolve() != dst.resolve():
        dst.write_bytes(src.read_bytes())

    # Written last, after every other artifact.
    meta = {
        "video_name": Path(str(state["video_path"])).name,
        "interval": int(state["interval"]),
        "fps_in": float(state["fps_in"]),
        "fps_out": float(state["fps_out"]),
        "num_frames": int(len(state["pil_imgs"])),
        "start_idx": int(state["last_tracking_start_idx"]),
        # `or []` also covers keys present with a None value.
        "points": list(state.get("last_points") or []),
        "labels": list(state.get("last_labels") or []),
        "first_frame_mask": first_frame_mask,
        "cache_key": key,
    }
    with open(paths["meta"], "wb") as f:
        pickle.dump(meta, f)
    logger.info(f"Cache saved at key: {key}")
    return key
| # ---------------------------- | |
| # UI callbacks | |
| # ---------------------------- | |
def on_video_upload(video_path, interval):
    """Extract frames and features for ``video_path`` and build a fresh session state.

    Returns:
        tuple: ``(first_frame, state, frame_slider_update, image_update)``. When
        no video is given: ``(None, None, <slider reset to 0>, None)``.
    """
    if video_path is None:
        return None, None, gr.Slider(value=0, maximum=0), None

    results = process_video_and_features(video_path, int(interval))
    (pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs,
     sam2_input_images, images_tensor) = results

    fps_in = estimate_video_fps(video_path)
    step = max(1, int(interval))
    state = dict(
        pil_imgs=pil_imgs,
        views=views,
        resize_funcs=resize_funcs,
        must3r_feats=must3r_feats,
        must3r_outputs=must3r_outputs,
        sam2_input_images=sam2_input_images,
        images_tensor=images_tensor,
        current_points=[],
        current_labels=[],
        current_mask=None,
        frame_idx=0,
        video_path=video_path,
        interval=step,
        fps_in=fps_in,
        # Output FPS follows the sampling interval, but never drops below 1.
        fps_out=max(1.0, fps_in / step),
        # Filled in by the tracking callback.
        output_video_path=None,
        last_tracking_start_idx=None,
        last_points=None,
        last_labels=None,
    )
    first_frame = pil_imgs[0]
    return (
        first_frame,
        state,
        gr.Slider(value=0, maximum=len(pil_imgs) - 1, step=1, interactive=True),
        gr.Image(value=first_frame),
    )
def on_slider_change(state, frame_idx):
    """Select a frame to annotate and discard the in-progress prompts and mask.

    The index is capped at the last frame. Returns that frame, or ``None`` when
    there is no state.
    """
    if not state:
        return None
    selected = min(int(frame_idx), len(state["pil_imgs"]) - 1)
    state.update(frame_idx=selected, current_points=[], current_labels=[], current_mask=None)
    return state["pil_imgs"][selected]
def on_image_click(state, evt: gr.SelectData, mode):
    """Record a click as a prompt of the selected ``mode`` and return the preview.

    The preview shows the current mask (if one exists) with every prompt drawn
    on top. This is the same order used after mask generation, so the blended
    mask never dims or hides the point markers.

    Args:
        state: Session state dict. ``None``/empty returns ``None``.
        evt: Gradio select event. ``evt.index`` is the clicked ``(x, y)``.
        mode: One of the annotation-mode radio choices.
    """
    if not state:
        return None
    x, y = evt.index
    label_map = {
        "Positive Point": 1,
        "Negative Point": 0,
        "Box Top-Left": 2,
        "Box Bottom-Right": 3,
    }
    label = label_map[mode]
    state["current_points"].append([x, y])
    state["current_labels"].append(label)
    vis_img = state["pil_imgs"][state["frame_idx"]]
    if state["current_mask"] is not None:
        vis_img = overlay_mask(vis_img, state["current_mask"])
    return draw_points(vis_img, state["current_points"], state["current_labels"])
def on_generate_mask_click(state):
    """Predict a mask for the current frame from its prompts and return the preview.

    Stores the mask in ``state["current_mask"]`` and a copy of the preview in
    ``state["current_vis_img"]``.

    Raises:
        gr.Error: If there are no prompts, or the box corners are unpaired or
            describe more than one box.
    """
    if not state:
        return None
    points = state["current_points"]
    if not points:
        raise gr.Error("No points or boxes annotated.")
    labels = state["current_labels"]
    tl_count, br_count = labels.count(2), labels.count(3)
    if tl_count != br_count or tl_count > 1:
        raise gr.Error(f"Incomplete box: TL={tl_count}, BR={br_count}. Must match and be <= 1.")

    idx = state["frame_idx"]
    frame_tensor = state["sam2_input_images"][idx].unsqueeze(0)
    frame_pil = state["pil_imgs"][idx]
    mask = generate_frame_mask(frame_tensor, points, labels, frame_pil.size)
    state["current_mask"] = mask

    preview = draw_points(overlay_mask(frame_pil, mask), points, labels)
    state["current_vis_img"] = preview.copy()
    return preview
def reset_annotations(state):
    """Clear the current frame's prompts and mask, returning the unannotated frame."""
    if not state:
        return None
    state.update(current_points=[], current_labels=[], current_mask=None)
    return state["pil_imgs"][state["frame_idx"]]
def on_track_click(state):
    """Track the current mask through all sampled frames and render the result video.

    Records the output path and tracking-time snapshots in ``state``: start
    index, prompts, and also the mask (``last_mask``) and preview image
    (``last_vis_img``) actually used. A later cache save can then stay
    consistent even if the annotation is edited or reset afterwards.

    Returns:
        tuple: ``(output_video_path, state)``.

    Raises:
        gr.Error: If no mask has been generated, or box corners are unpaired.
    """
    if not state or state["current_mask"] is None:
        raise gr.Error("Generate a mask first.")
    num_tl = state["current_labels"].count(2)
    num_br = state["current_labels"].count(3)
    if num_tl != num_br:
        raise gr.Error("Incomplete box annotations.")
    start_idx = int(state["frame_idx"])
    first_frame_mask = state["current_mask"]
    tracked_masks_dict = run_tracking(
        state["sam2_input_images"],
        state["must3r_feats"],
        state["must3r_outputs"],
        start_idx,
        first_frame_mask,
    )
    output_path = create_video_from_masks(
        state["pil_imgs"],
        tracked_masks_dict,
        fps=state.get("fps_out", 24.0),
    )
    state["output_video_path"] = output_path
    state["last_tracking_start_idx"] = start_idx
    state["last_points"] = list(state.get("current_points", []))
    state["last_labels"] = list(state.get("current_labels", []))
    state["last_mask"] = first_frame_mask
    state["last_vis_img"] = state.get("current_vis_img")
    logger.info("Tracking Complete")
    return output_path, state
def on_save_cache_click(state):
    """Save the current session as a cache example and report its key.

    Raises:
        gr.Error: Wrapping the validation / missing-file error from the save,
            so Gradio shows the reason to the user instead of a generic error.
    """
    try:
        key = save_full_cache_from_state(state)
    except (ValueError, FileNotFoundError) as e:
        raise gr.Error(str(e)) from e
    return f"Saved cache key: {key}"
| # ---------------------------- | |
| # UI layout | |
| # ---------------------------- | |
description = """
<div style="text-align: center;">
<h1>3AM: 3egment Anything with Geometric Consistency in Videos</h1>
<p>Cache-builder UI: run full pipeline, then save caches for user examples.</p>
</div>
"""

# NOTE(review): this block was recovered from a paste with all indentation
# stripped; the Row/Column nesting below is a best reconstruction — verify the
# layout against the deployed Space.
with gr.Blocks(title="3AM Cache Builder") as app:
    gr.HTML(description)
    app_state = gr.State()

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Step 1 — Upload video")
            video_input = gr.Video(label="Upload Video", sources=["upload"], height=512)
            gr.Markdown("## Step 2 — Set interval, then load frames")
            interval_slider = gr.Slider(
                label="Frame Interval",
                minimum=1,
                maximum=30,
                step=1,
                value=1,
            )
            load_btn = gr.Button("Load Frames", variant="primary")
            process_status = gr.Textbox(label="Status", value="1) Upload a video.", interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("## Step 3 — Annotate frame & generate mask")
            img_display = gr.Image(label="Annotate Frame", interactive=True, height=512)
            frame_slider = gr.Slider(label="Select Frame", minimum=0, maximum=100, step=1, value=0)
            with gr.Row():
                mode_radio = gr.Radio(
                    choices=["Positive Point", "Negative Point", "Box Top-Left", "Box Bottom-Right"],
                    value="Positive Point",
                    label="Annotation Mode",
                )
                with gr.Column():
                    gen_mask_btn = gr.Button("Generate Mask", variant="primary", interactive=False)
                    reset_btn = gr.Button("Reset Annotations", interactive=False)

            gr.Markdown("## Step 4 — Track & Save Cache")
            with gr.Row():
                track_btn = gr.Button("Start Tracking", variant="primary", interactive=False)
                save_cache_btn = gr.Button("Save Cache", variant="secondary", interactive=False)
            with gr.Row():
                video_output = gr.Video(label="Tracking Output", autoplay=True, height=512)
            cache_status = gr.Textbox(label="Cache", value="", interactive=False)

    # ------------------------
    # Events
    # ------------------------
    def on_video_uploaded(video_path):
        """Suggest a frame interval giving roughly 100 sampled frames."""
        n_frames = estimate_total_frames(video_path)
        max_interval = min(30, n_frames)
        # Clamp: for long videos n_frames // 100 can exceed the slider maximum.
        default_interval = min(max_interval, max(1, n_frames // 100))
        return (
            gr.update(value=default_interval, maximum=max_interval),
            f"Video uploaded ({n_frames} frames). 2) Adjust interval, then click 'Load Frames'.",
        )

    video_input.upload(fn=on_video_uploaded, inputs=video_input, outputs=[interval_slider, process_status])

    load_btn.click(
        fn=lambda: (
            "Loading frames...",
            gr.update(interactive=False),  # gen_mask_btn
            gr.update(interactive=False),  # reset_btn
            gr.update(interactive=False),  # track_btn
            gr.update(interactive=False),  # save_cache_btn
            gr.update(value=""),  # cache_status
        ),
        outputs=[process_status, gen_mask_btn, reset_btn, track_btn, save_cache_btn, cache_status],
    ).then(
        fn=on_video_upload,
        inputs=[video_input, interval_slider],
        outputs=[img_display, app_state, frame_slider, img_display],
    ).success(
        # .success (not .then): unlock annotation only if loading did not raise.
        fn=lambda: (
            "Ready. 3) Annotate and generate mask.",
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[process_status, gen_mask_btn, reset_btn, track_btn],
    )

    frame_slider.change(fn=on_slider_change, inputs=[app_state, frame_slider], outputs=[img_display])
    img_display.select(fn=on_image_click, inputs=[app_state, mode_radio], outputs=[img_display])
    gen_mask_btn.click(fn=on_generate_mask_click, inputs=[app_state], outputs=[img_display])
    reset_btn.click(fn=reset_annotations, inputs=[app_state], outputs=[img_display])

    tracking = track_btn.click(
        fn=lambda: (
            "Tracking in progress...",
            gr.update(interactive=False),  # track_btn
            gr.update(interactive=False),  # save_cache_btn
        ),
        outputs=[process_status, track_btn, save_cache_btn],
    ).then(
        fn=on_track_click,
        inputs=[app_state],
        outputs=[video_output, app_state],
    )
    # Re-enable tracking whether or not it succeeded, so the user can retry...
    tracking.then(fn=lambda: gr.update(interactive=True), outputs=[track_btn])
    # ...but only offer "Save Cache" once tracking actually produced an output.
    tracking.success(
        fn=lambda: (
            "Tracking complete. You can save cache.",
            gr.update(interactive=True),  # save_cache_btn
        ),
        outputs=[process_status, save_cache_btn],
    )

    save_cache_btn.click(fn=on_save_cache_click, inputs=[app_state], outputs=[cache_status])

if __name__ == "__main__":
    app.launch()