# Hugging Face Spaces demo app for 3AM (Space status banner removed — it was page-scrape residue, not code).
| import spaces | |
| import subprocess | |
| import sys, os | |
| from pathlib import Path | |
| import math | |
| import pickle | |
| from typing import Any, Dict, List, Tuple, Optional | |
| import importlib, site | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import cv2 | |
| import logging | |
# ============================================================
# Bootstrap (same style as your original app.py)
# ============================================================
ROOT = Path(__file__).resolve().parent
SAM2 = ROOT / "sam2-src"
CKPT = SAM2 / "checkpoints" / "sam2.1_hiera_large.pt"
# Fetch the SAM2 checkpoint on first launch only.
if not CKPT.exists():
    subprocess.check_call(["bash", "download_ckpts.sh"], cwd=SAM2 / "checkpoints")
# Install the vendored sam2 package (editable) if it is not importable yet.
try:
    import sam2.build_sam  # noqa: F401
except ModuleNotFoundError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src"], cwd=ROOT)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src[notebooks]"], cwd=ROOT)
# Build and install the vendored ASMK package; it needs its Cython extension
# compiled first, hence the explicit cythonize step and --no-build-isolation.
try:
    import asmk.index  # noqa: F401
except Exception:
    subprocess.check_call(["cythonize", "*.pyx"], cwd="./asmk-src/cython")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "./asmk-src", "--no-build-isolation"])
# One-time download of model assets / cached examples into ./private.
if not os.path.exists("./private"):
    from huggingface_hub import snapshot_download
    snapshot_download(
        repo_id="nycu-cplab/3AM",
        local_dir="./private",
        repo_type="model",
    )
# Re-scan site-packages so packages pip-installed above become importable in
# this same process without a restart.
for sp in site.getsitepackages():
    site.addsitedir(sp)
importlib.invalidate_caches()
# ============================================================
# Logging
# ============================================================
# Log to stdout so messages appear in the Spaces container log.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("app_user")
| # ============================================================ | |
| # Engine imports | |
| # ============================================================ | |
| from engine import ( # noqa: E402 | |
| get_predictors, | |
| get_views, | |
| prepare_sam2_inputs, | |
| must3r_features_and_output, | |
| get_single_frame_mask, | |
| get_tracked_masks, | |
| ) | |
# ============================================================
# Globals
# ============================================================
# Predictors are created lazily by load_models() on first use.
PREDICTOR_ORIGINAL = None
PREDICTOR = None
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Inference-only app: disable autograd for the whole process. This replaces
# the previous `torch.no_grad().__enter__()` hack (a context manager entered
# and never exited) with the documented global switch.
torch.set_grad_enabled(False)
def load_models():
    """Lazily build the SAM2/MUSt3R predictors, caching them in module globals."""
    global PREDICTOR_ORIGINAL, PREDICTOR
    needs_init = PREDICTOR is None or PREDICTOR_ORIGINAL is None
    if needs_init:
        logger.info(f"Initializing models on device: {DEVICE}...")
        PREDICTOR_ORIGINAL, PREDICTOR = get_predictors(device=DEVICE)
        logger.info("Models loaded successfully.")
    return PREDICTOR_ORIGINAL, PREDICTOR
def to_device_nested(x: Any, device: str) -> Any:
    """Recursively move every tensor inside nested dict/list/tuple containers to *device*.

    Non-tensor leaves are returned unchanged; container types are preserved.
    """
    if torch.is_tensor(x):
        return x.to(device)
    if isinstance(x, dict):
        return {key: to_device_nested(val, device) for key, val in x.items()}
    if isinstance(x, (list, tuple)):
        moved = [to_device_nested(item, device) for item in x]
        return moved if isinstance(x, list) else tuple(moved)
    return x
| # ============================================================ | |
| # Helper Functions | |
| # ============================================================ | |
def video_to_frames(video_path, interval=1):
    """Decode *video_path* and return every *interval*-th frame as a PIL RGB image.

    Args:
        video_path: Path to any video OpenCV can decode.
        interval: Sampling stride in frames; values < 1 are clamped to 1
            (interval=0 previously crashed with ZeroDivisionError in `count % interval`).

    Returns:
        list[PIL.Image.Image]: Sampled frames in decode order.
    """
    interval = max(1, int(interval))  # guard against a zero/negative stride
    logger.info(f"Extracting frames from video: {video_path} with interval {interval}")
    cap = cv2.VideoCapture(video_path)
    frames = []
    count = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if count % interval == 0:
                # OpenCV decodes BGR; PIL expects RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
            count += 1
    finally:
        cap.release()  # always free the decoder handle, even on error
    logger.info(f"Extracted {len(frames)} frames (sampled from {count} total frames).")
    return frames
def draw_points(image_pil, points, labels):
    """Return a copy of *image_pil* with one colored dot drawn per annotated point."""
    # label -> dot color (1=positive, 0=negative, 2=box top-left, 3=box bottom-right)
    palette = {1: "green", 0: "red", 2: "blue", 3: "cyan"}
    radius = 15
    annotated = image_pil.copy()
    pen = ImageDraw.Draw(annotated)
    for (x, y), lbl in zip(points, labels):
        fill = palette.get(lbl, "yellow")
        pen.ellipse((x - radius, y - radius, x + radius, y + radius), fill=fill, outline="white")
    return annotated
def overlay_mask(image_pil, mask, color=(255, 0, 0), alpha=0.5):
    """Alpha-blend a binary mask onto *image_pil*; mask is resized if shapes differ.

    Returns *image_pil* unchanged when mask is None.
    """
    if mask is None:
        return image_pil
    binary = mask > 0
    base = np.array(image_pil)
    h, w = base.shape[:2]
    if binary.shape[0] != h or binary.shape[1] != w:
        # Nearest-neighbor keeps the mask strictly binary after resizing.
        binary = cv2.resize(binary.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
    painted = base.copy()
    painted[binary] = np.array(color, dtype=np.uint8)
    blended = cv2.addWeighted(painted, alpha, base, 1 - alpha, 0)
    return Image.fromarray(blended)
def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4", fps=24):
    """Render frames (with red mask overlays where available) into an mp4 file.

    Returns *output_path*, or None when *frames* is empty.
    """
    logger.info(f"Creating video output at {output_path} with {len(frames)} frames.")
    if not frames:
        logger.warning("No frames to create video.")
        return None
    fps = float(fps)
    if not (fps > 0.0):  # also catches NaN
        fps = 24.0
    height, width = np.array(frames[0]).shape[:2]
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    for idx, frame in enumerate(frames):
        mask = masks_dict.get(idx)
        if mask is None:
            rgb = np.array(frame)
        else:
            rgb = np.array(overlay_mask(frame, mask, color=(255, 0, 0), alpha=0.6))
        writer.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    writer.release()
    logger.info("Video creation complete.")
    return output_path
| # ============================================================ | |
| # Runtime estimation | |
| # ============================================================ | |
def estimate_video_fps(video_path: str) -> float:
    """Read the container FPS; fall back to 24.0 when metadata is missing or zero."""
    cap = cv2.VideoCapture(video_path)
    raw_fps = float(cap.get(cv2.CAP_PROP_FPS)) or 0.0
    cap.release()
    return raw_fps if raw_fps > 0.0 else 24.0
def estimate_total_frames(video_path: str) -> int:
    """Read the container frame count; always returns at least 1."""
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
    cap.release()
    return max(1, frame_count)
MAX_GPU_SECONDS = 600  # hard ceiling on requested GPU time per call

def clamp_duration(sec: int) -> int:
    """Clamp a GPU-time estimate into the range [1, MAX_GPU_SECONDS] seconds."""
    bounded_below = max(1, sec)
    return int(min(MAX_GPU_SECONDS, bounded_below))
def get_duration_must3r_features(video_path, interval):
    """Estimate GPU seconds for feature extraction (~2 s per processed frame)."""
    stride = max(1, int(interval))
    processed = math.ceil(estimate_total_frames(video_path) / stride)
    sec_per_frame = 2
    return clamp_duration(int(processed * sec_per_frame))
def get_duration_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """Estimate GPU seconds for tracking (~2 s per frame; 100-frame fallback)."""
    try:
        n_frames = int(getattr(sam2_input_images, "shape")[0])
    except Exception:
        n_frames = 100  # unknown input -> assume a typical clip length
    sec_per_frame = 2
    return clamp_duration(int(n_frames * sec_per_frame))
| # ============================================================ | |
| # GPU Wrapped Functions | |
| # ============================================================ | |
def process_video_and_features(video_path, interval):
    """Decode the video and run MUSt3R + SAM2 preprocessing on the sampled frames.

    Returns the 7-tuple (pil_imgs, views, resize_funcs, must3r_feats,
    must3r_outputs, sam2_input_images, images_tensor) stored in the UI state.
    """
    logger.info(f"Starting GPU process: Video feature extraction (Interval: {interval})")
    load_models()
    stride = max(1, int(interval))
    pil_imgs = video_to_frames(video_path, interval=stride)
    if not pil_imgs:
        raise ValueError("Could not extract frames from video.")
    views, resize_funcs = get_views(pil_imgs)
    feats, outputs = must3r_features_and_output(views, device=DEVICE)
    sam2_inputs, images_tensor = prepare_sam2_inputs(views, pil_imgs, resize_funcs)
    return pil_imgs, views, resize_funcs, feats, outputs, sam2_inputs, images_tensor
def generate_frame_mask(image_tensor, points, labels, original_size):
    """Run SAM2 on a single frame with point/box prompts; returns a 2D numpy mask.

    NOTE(review): prompt coordinates are rescaled from the original frame size
    to a 1024x1024 model input — confirm against `prepare_sam2_inputs`.
    """
    logger.info(f"Generating mask for single frame. Points: {len(points)}")
    load_models()
    # Move the frame and prompts onto the compute device.
    frame = image_tensor.to(DEVICE)
    pts = torch.tensor(points, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    lbls = torch.tensor(labels, dtype=torch.int32).unsqueeze(0).to(DEVICE)
    width, height = original_size
    pts[..., 0] /= (width / 1024.0)
    pts[..., 1] /= (height / 1024.0)
    mask = get_single_frame_mask(
        image=frame,
        predictor_original=PREDICTOR_ORIGINAL,
        points=pts,
        labels=lbls,
        device=DEVICE,
    )
    return mask.squeeze().cpu().numpy()
def run_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """Propagate the annotated mask through all frames; returns masks per frame index."""
    logger.info(f"Starting tracking from frame index {start_idx}...")
    load_models()
    # Cached examples are stored on CPU; move every tensor to the compute device.
    images_dev = sam2_input_images.to(DEVICE)
    feats_dev = to_device_nested(must3r_feats, DEVICE)
    outputs_dev = to_device_nested(must3r_outputs, DEVICE)
    seed_mask = torch.tensor(first_frame_mask).to(DEVICE) > 0
    tracked = get_tracked_masks(
        sam2_input_images=images_dev,
        must3r_feats=feats_dev,
        must3r_outputs=outputs_dev,
        start_idx=start_idx,
        first_frame_mask=seed_mask,
        predictor=PREDICTOR,
        predictor_original=PREDICTOR_ORIGINAL,
        device=DEVICE,
    )
    logger.info(f"Tracking complete. Generated masks for {len(tracked)} frames.")
    return tracked
| # ============================================================ | |
| # Cache loader (Examples) | |
| # ============================================================ | |
CACHE_ROOT = Path("./private/cache")  # cached example directories live here

def _read_meta(meta_path: Path) -> Dict[str, Any]:
    """Deserialize an example's meta.pkl.

    NOTE: pickle is only acceptable here because these files come from the
    private snapshot this app downloads itself — never point this at
    untrusted input.
    """
    with open(meta_path, "rb") as handle:
        return pickle.load(handle)
def _load_frames_from_dir(frames_dir: Path) -> List[Image.Image]:
    """Load the cached frames (frames_dir/*.jpg) in filename order as RGB images."""
    return [Image.open(path).convert("RGB") for path in sorted(frames_dir.glob("*.jpg"))]
def list_example_dirs() -> List[Path]:
    """Return cache subdirectories that contain a complete example bundle."""
    if not CACHE_ROOT.exists():
        return []
    required = ("meta.pkl", "state_tensors.pt", "output_tracking.mp4")
    return [
        d
        for d in sorted(CACHE_ROOT.iterdir())
        if d.is_dir() and all((d / name).exists() for name in required)
    ]
| # ============================================================ | |
| # Cache loader (Examples) - GALLERY VERSION | |
| # ============================================================ | |
def build_examples_gallery():
    """Collect (thumbnail, caption) pairs plus an index of cached example dirs.

    Returns:
        tuple[list, dict]: gallery items for gr.Gallery, and a mapping from
        gallery position -> info dict consumed by `load_cache_into_state`.

    Fixes vs. the previous version:
      * cache_index is now keyed by the item's gallery position rather than the
        directory enumeration index. When a directory was skipped (no jpgs),
        those two drifted apart and `on_example_select` (which looks up
        `evt.index`, the gallery position) hit the wrong — or a missing — entry.
      * uses the module logger instead of print() for consistent log output.
    """
    gallery_items = []
    cache_index = {}
    for d in list_example_dirs():
        meta = _read_meta(d / "meta.pkl")
        frames_dir = d / "frames"
        thumb = d / "vis_img.png"
        if not thumb.exists():
            # Fall back to the first raw frame when no visualization was saved.
            jpgs = sorted(frames_dir.glob("*.jpg"))
            if not jpgs:
                continue  # no thumbnail possible -> skip this example entirely
            thumb = jpgs[0]
        caption = f"{meta.get('num_frames', 0)} Frames"
        position = len(gallery_items)  # gallery position == select-event index
        gallery_items.append((str(thumb), caption))
        cache_index[position] = {
            "cache_id": d.name,
            "dir": d,
            "meta": meta,
            "video_mp4": str(d / "output_tracking.mp4"),
            "frames_dir": frames_dir,
            "tensors": str(d / "state_tensors.pt"),
        }
    logger.info(f"Found {len(gallery_items)} example directories.")
    return gallery_items, cache_index
def load_cache_into_state(row_idx: int, cache_index: Dict[int, Dict[str, Any]]):
    """Rebuild the full app state from the cached example at index *row_idx*.

    Returns the 5-tuple (state, vis_img, slider, mp4_path, interval) consumed
    by `on_example_select`.
    """
    info = cache_index[row_idx]
    meta = info["meta"]
    cache_id = info["cache_id"]
    pil_imgs = _load_frames_from_dir(info["frames_dir"])
    if not pil_imgs:
        raise gr.Error("Example frames not found or empty.")
    # Load precomputed tensors onto CPU; run_tracking moves them to DEVICE on demand.
    tensors = torch.load(info["tensors"], map_location="cpu")
    views, resize_funcs = get_views(pil_imgs)
    # Annotation metadata saved when the example was cached.
    fps_in = float(meta.get("fps_in", 24.0))
    fps_out = float(meta.get("fps_out", 24.0))
    interval = int(meta.get("interval", 1))
    points = meta.get("points", [])
    labels = meta.get("labels", [])
    first_frame_mask = meta.get("first_frame_mask", None)
    # Same schema as the state built in on_video_upload_and_load.
    state = {
        "pil_imgs": pil_imgs,
        "views": views,
        "resize_funcs": resize_funcs,
        "must3r_feats": tensors["must3r_feats"],
        "must3r_outputs": tensors["must3r_outputs"],
        "sam2_input_images": tensors["sam2_input_images"],
        "images_tensor": tensors["images_tensor"],
        "current_points": points,
        "current_labels": labels,
        "current_mask": first_frame_mask,
        "frame_idx": 0,
        "video_path": meta.get("video_name", "example"),
        "interval": interval,
        "fps_in": fps_in,
        "fps_out": fps_out,
        "output_video_path": info["video_mp4"],
        "loaded_from_cache": True,
        "cache_id": cache_id,
    }
    # Preview: cached mask + annotation points drawn over the first frame.
    vis_img = overlay_mask(pil_imgs[0], state["current_mask"])
    vis_img = draw_points(vis_img, state["current_points"], state["current_labels"])
    slider = gr.Slider(value=0, maximum=len(pil_imgs) - 1, step=1, interactive=True)
    # NOTE(review): returns the literal interval 1 rather than meta's `interval`
    # parsed above — confirm whether that is intended.
    return state, vis_img, slider, info["video_mp4"], 1
def on_example_select(evt: gr.SelectData, cache_index_state):
    """Gallery click handler: hydrate the UI from the selected cached example."""
    state, vis_img, slider, mp4_path, interval = load_cache_into_state(evt.index, cache_index_state)
    # Re-enable the mask / reset / track buttons now that a state exists.
    unlock = [gr.update(interactive=True) for _ in range(3)]
    return (
        vis_img,
        state,
        slider,
        mp4_path,
        gr.update(value=interval),
        "Ready. Example loaded.",
        *unlock,
    )
| # ============================================================ | |
| # UI callbacks (same semantics as your original app.py) | |
| # ============================================================ | |
def on_video_uploaded(video_path):
    """Suggest a frame interval (~100 sampled frames) right after upload."""
    n_frames = estimate_total_frames(video_path)
    suggested = max(1, n_frames // 100)
    status = f"Video uploaded ({n_frames} frames). 2) Adjust interval, then click 'Load Frames'."
    return gr.update(value=suggested, maximum=min(30, n_frames)), status
def on_video_upload_and_load(video_path, interval):
    """Extract frames + features for the uploaded video and build the app state.

    Returns (first_frame, state, frame_slider, image); all None / zeroed when
    no video is present.
    """
    logger.info(f"User uploaded video: {video_path}, Interval: {interval}")
    if video_path is None:
        return None, None, gr.Slider(value=0, maximum=0), None
    pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor = process_video_and_features(
        video_path, int(interval)
    )
    # Output video keeps real-time pacing: source fps divided by the sampling stride.
    fps_in = estimate_video_fps(video_path)
    interval_i = max(1, int(interval))
    fps_out = max(1.0, fps_in / interval_i)
    # Central session state consumed by every later callback.
    state = {
        "pil_imgs": pil_imgs,
        "views": views,
        "resize_funcs": resize_funcs,
        "must3r_feats": must3r_feats,
        "must3r_outputs": must3r_outputs,
        "sam2_input_images": sam2_input_images,
        "images_tensor": images_tensor,
        "current_points": [],
        "current_labels": [],
        "current_mask": None,
        "frame_idx": 0,
        "video_path": video_path,
        "interval": interval_i,
        "fps_in": fps_in,
        "fps_out": fps_out,
        "output_video_path": None,
        "loaded_from_cache": False,
    }
    first_frame = pil_imgs[0]
    new_slider = gr.Slider(value=0, maximum=len(pil_imgs) - 1, step=1, interactive=True)
    # img_display appears twice in the event's outputs list; both get the first frame.
    return first_frame, state, new_slider, gr.Image(value=first_frame)
def on_slider_change(state, frame_idx):
    """Show the selected frame and clear any in-progress annotations."""
    if not state:
        return None
    # Clamp to the last frame; the slider may briefly exceed the loaded range.
    idx = min(int(frame_idx), len(state["pil_imgs"]) - 1)
    state["frame_idx"] = idx
    # Changing frames invalidates the current prompts and mask.
    state["current_points"] = []
    state["current_labels"] = []
    state["current_mask"] = None
    return state["pil_imgs"][idx]
def on_image_click(state, evt: gr.SelectData, mode):
    """Record a clicked prompt point and redraw the annotated frame."""
    if not state:
        return None
    # Annotation mode -> SAM2 prompt label.
    label_map = {
        "Positive Point": 1,
        "Negative Point": 0,
        "Box Top-Left": 2,
        "Box Bottom-Right": 3,
    }
    x, y = evt.index
    state["current_points"].append([x, y])
    state["current_labels"].append(label_map[mode])
    frame_pil = state["pil_imgs"][state["frame_idx"]]
    vis = draw_points(frame_pil, state["current_points"], state["current_labels"])
    mask = state["current_mask"]
    if mask is not None:
        vis = overlay_mask(vis, mask)
    return vis
def on_generate_mask_click(state):
    """Validate annotations, run SAM2 on the current frame, and display the mask."""
    if not state:
        return None
    if not state["current_points"]:
        raise gr.Error("No points or boxes annotated.")
    num_tl = state["current_labels"].count(2)
    num_br = state["current_labels"].count(3)
    # At most one complete top-left/bottom-right box pair is supported.
    if num_tl != num_br or num_tl > 1:
        raise gr.Error(f"Incomplete box detected! TL={num_tl}, BR={num_br}. Must match and be <= 1.")
    frame_idx = state["frame_idx"]
    frame_tensor = state["sam2_input_images"][frame_idx].unsqueeze(0)
    mask = generate_frame_mask(
        frame_tensor,
        state["current_points"],
        state["current_labels"],
        state["pil_imgs"][frame_idx].size,
    )
    state["current_mask"] = mask
    vis = overlay_mask(state["pil_imgs"][frame_idx], mask)
    return draw_points(vis, state["current_points"], state["current_labels"])
def reset_annotations(state):
    """Drop all prompts and the generated mask; show the clean current frame."""
    if not state:
        return None
    state.update(current_points=[], current_labels=[], current_mask=None)
    return state["pil_imgs"][state["frame_idx"]]
def on_track_click(state):
    """Track the generated mask through the video and render the result mp4."""
    if not state or state["current_mask"] is None:
        raise gr.Error("Please annotate a frame and generate a mask first.")
    if state["current_labels"].count(2) != state["current_labels"].count(3):
        raise gr.Error("Incomplete box annotations.")
    masks = run_tracking(
        state["sam2_input_images"],
        state["must3r_feats"],
        state["must3r_outputs"],
        state["frame_idx"],
        state["current_mask"],
    )
    out_path = create_video_from_masks(
        state["pil_imgs"],
        masks,
        fps=state.get("fps_out", 24.0),
    )
    state["output_video_path"] = out_path
    return out_path
| # ============================================================ | |
| # App Layout (match original, add Examples at bottom) | |
| # ============================================================ | |
description = """
<div style="text-align: center;">
<h1>3AM: Segment Anything with Geometric Consistency in Videos </h1>
<p>Upload a video, extract geometric features, annotate a frame, and track the object.</p>
</div>
"""
with gr.Blocks(title="3AM: 3egment Anything") as app:
    # Header + workflow summary.
    gr.HTML(description)
    gr.Markdown(
        """
**Workflow**
1) Upload video
2) Adjust frame interval → Load frames
3) Annotate & generate mask
4) Track through the video
"""
    )
    # Per-session state dict (see on_video_upload_and_load for its schema).
    app_state = gr.State()
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Step 1 — Upload video")
            video_input = gr.Video(
                label="Upload Video",
                sources=["upload"],
                height=512,
            )
            gr.Markdown("## Step 2 — Set interval, then load frames")
            interval_slider = gr.Slider(
                label="Frame Interval",
                minimum=1,
                maximum=30,
                step=1,
                value=1,
                info="Default ≈ total_frames / 100",
            )
            load_btn = gr.Button("Load Frames", variant="primary")
            process_status = gr.Textbox(
                label="Status",
                value="1) Upload a video.",
                interactive=False,
            )
        with gr.Column(scale=2):
            gr.Markdown("## Step 3 — Annotate frame & generate mask")
            img_display = gr.Image(
                label="Annotate Frame",
                interactive=True,
                height=512,
            )
            frame_slider = gr.Slider(
                label="Select Frame",
                minimum=0,
                maximum=100,
                step=1,
                value=0,
            )
            with gr.Row():
                mode_radio = gr.Radio(
                    choices=[
                        "Positive Point",
                        "Negative Point",
                        "Box Top-Left",
                        "Box Bottom-Right",
                    ],
                    value="Positive Point",
                    label="Annotation Mode",
                )
                with gr.Column():
                    # Disabled until frames are loaded; re-enabled by the
                    # load_btn chain and by example selection.
                    gen_mask_btn = gr.Button(
                        "Generate Mask",
                        variant="primary",
                        interactive=False,
                    )
                    reset_btn = gr.Button(
                        "Reset Annotations",
                        interactive=False,
                    )
            gr.Markdown("## Step 4 — Track through the video")
            with gr.Row():
                track_btn = gr.Button(
                    "Start Tracking",
                    variant="primary",
                    scale=1,
                    interactive=False,
                )
            with gr.Row():
                video_output = gr.Video(
                    label="Tracking Output",
                    autoplay=True,
                    height=512,
                )
    # -------------------------
    # Examples table at bottom
    # -------------------------
    gr.Markdown("## Examples (click to load)")
    gallery_items, cache_index = build_examples_gallery()
    cache_index_state = gr.State(cache_index)
    if gallery_items:
        examples_gallery = gr.Gallery(
            value=gallery_items,
            label="Examples",
            container=True,
            columns=6,
            object_fit="contain",
            show_label=False,
        )
        examples_gallery.select(
            fn=on_example_select,
            inputs=[cache_index_state],
            outputs=[
                img_display,
                app_state,
                frame_slider,
                video_output,
                interval_slider,
                process_status,
                gen_mask_btn,
                reset_btn,
                track_btn,
            ],
        )
    else:
        gr.Markdown("*No examples available.*")
    # ============================================================
    # Events (original + examples)
    # ============================================================
    video_input.upload(
        fn=on_video_uploaded,
        inputs=video_input,
        outputs=[interval_slider, process_status],
    )
    # Load chain: lock buttons -> heavy processing -> unlock buttons.
    load_btn.click(
        fn=lambda: (
            "Loading frames...",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
        ),
        outputs=[process_status, gen_mask_btn, reset_btn, track_btn],
    ).then(
        fn=on_video_upload_and_load,
        inputs=[video_input, interval_slider],
        outputs=[img_display, app_state, frame_slider, img_display],
    ).then(
        fn=lambda: (
            "Ready. 3) Annotate and generate mask.",
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[process_status, gen_mask_btn, reset_btn, track_btn],
    )
    frame_slider.change(
        fn=on_slider_change,
        inputs=[app_state, frame_slider],
        outputs=[img_display],
    )
    img_display.select(
        fn=on_image_click,
        inputs=[app_state, mode_radio],
        outputs=[img_display],
    )
    gen_mask_btn.click(
        fn=on_generate_mask_click,
        inputs=[app_state],
        outputs=[img_display],
    )
    reset_btn.click(
        fn=reset_annotations,
        inputs=[app_state],
        outputs=[img_display],
    )
    # Track chain: status message -> run tracking -> final status.
    track_btn.click(
        fn=lambda: "Tracking in progress...",
        outputs=process_status,
    ).then(
        fn=on_track_click,
        inputs=[app_state],
        outputs=[video_output],
    ).then(
        fn=lambda: "Tracking complete!",
        outputs=process_status,
    )
if __name__ == "__main__":
    logger.info("Starting Gradio app...")
    app.launch()