"""Gradio app: 3AM — SAM2 + MUSt3R geometry-aware video object segmentation.

Workflow: upload a video -> extract frames & MUSt3R features (GPU) ->
annotate one frame with points/boxes -> generate a SAM2 mask -> track the
mask through the whole clip and render an overlay video. Cached examples
under ./private/cache can be loaded from a gallery without re-running the
feature extraction.
"""

import spaces
import subprocess
import sys, os
from pathlib import Path
import math
import pickle
from typing import Any, Dict, List, Tuple, Optional
import importlib, site

import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw
import cv2
import logging

# ============================================================
# Bootstrap (same style as your original app.py)
# ============================================================
ROOT = Path(__file__).resolve().parent
SAM2 = ROOT / "sam2-src"
CKPT = SAM2 / "checkpoints" / "sam2.1_hiera_large.pt"

# Fetch the SAM2 checkpoint on first run.
if not CKPT.exists():
    subprocess.check_call(["bash", "download_ckpts.sh"], cwd=SAM2 / "checkpoints")

# Install the vendored sam2 package if it is not importable yet.
try:
    import sam2.build_sam  # noqa: F401
except ModuleNotFoundError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src"], cwd=ROOT)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src[notebooks]"], cwd=ROOT)

# Build and install the vendored ASMK package (Cython extension) if missing.
try:
    import asmk.index  # noqa: F401
except Exception:
    subprocess.check_call(["cythonize", "*.pyx"], cwd="./asmk-src/cython")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "./asmk-src", "--no-build-isolation"])

# Download model assets / cached examples on first run.
if not os.path.exists("./private"):
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="nycu-cplab/3AM",
        local_dir="./private",
        repo_type="model",
    )

# Make freshly pip-installed packages importable in this process.
for sp in site.getsitepackages():
    site.addsitedir(sp)
importlib.invalidate_caches()

# ============================================================
# Logging
# ============================================================
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("app_user")

# ============================================================
# Engine imports
# ============================================================
from engine import (  # noqa: E402
    get_predictors,
    get_views,
    prepare_sam2_inputs,
    must3r_features_and_output,
    get_single_frame_mask,
    get_tracked_masks,
)

# ============================================================
# Globals
# ============================================================
PREDICTOR_ORIGINAL = None
PREDICTOR = None
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Inference-only app: disable autograd globally. (Replaces the previous
# `torch.no_grad().__enter__()` hack, which leaked an un-exited context
# manager; `set_grad_enabled(False)` has the same global effect.)
torch.set_grad_enabled(False)


def load_models():
    """Lazily initialize the two SAM2 predictors (cached in module globals)."""
    global PREDICTOR_ORIGINAL, PREDICTOR
    if PREDICTOR is None or PREDICTOR_ORIGINAL is None:
        logger.info(f"Initializing models on device: {DEVICE}...")
        PREDICTOR_ORIGINAL, PREDICTOR = get_predictors(device=DEVICE)
        logger.info("Models loaded successfully.")
    return PREDICTOR_ORIGINAL, PREDICTOR


def to_device_nested(x: Any, device: str) -> Any:
    """Recursively move tensors inside dict/list/tuple containers to `device`.

    Non-tensor leaves are returned unchanged.
    """
    if torch.is_tensor(x):
        return x.to(device)
    if isinstance(x, dict):
        return {k: to_device_nested(v, device) for k, v in x.items()}
    if isinstance(x, list):
        return [to_device_nested(v, device) for v in x]
    if isinstance(x, tuple):
        return tuple(to_device_nested(v, device) for v in x)
    return x


# ============================================================
# Helper Functions
# ============================================================
def video_to_frames(video_path, interval=1):
    """Decode `video_path` and return every `interval`-th frame as RGB PIL images."""
    logger.info(f"Extracting frames from video: {video_path} with interval {interval}")
    cap = cv2.VideoCapture(video_path)
    frames = []
    count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if count % interval == 0:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
        count += 1
    cap.release()
    logger.info(f"Extracted {len(frames)} frames (sampled from {count} total frames).")
    return frames


def draw_points(image_pil, points, labels):
    """Return a copy of `image_pil` with annotation dots drawn on it.

    Label colors: 1=positive (green), 0=negative (red), 2=box top-left (blue),
    3=box bottom-right (cyan), anything else yellow.
    """
    img_draw = image_pil.copy()
    draw = ImageDraw.Draw(img_draw)
    r = 15  # dot radius in pixels
    for pt, lbl in zip(points, labels):
        x, y = pt
        if lbl == 1:
            color = "green"
        elif lbl == 0:
            color = "red"
        elif lbl == 2:
            color = "blue"
        elif lbl == 3:
            color = "cyan"
        else:
            color = "yellow"
        draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline="white")
    return img_draw


def overlay_mask(image_pil, mask, color=(255, 0, 0), alpha=0.5):
    """Alpha-blend a boolean/binary mask onto `image_pil`; returns a PIL image.

    The mask is resized (nearest-neighbor) if its spatial size differs from
    the image. `mask=None` returns the image unchanged.
    """
    if mask is None:
        return image_pil
    mask = mask > 0
    img_np = np.array(image_pil)
    h, w = img_np.shape[:2]
    if mask.shape[0] != h or mask.shape[1] != w:
        mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
    overlay = img_np.copy()
    overlay[mask] = np.array(color, dtype=np.uint8)
    combined = cv2.addWeighted(overlay, alpha, img_np, 1 - alpha, 0)
    return Image.fromarray(combined)


def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4", fps=24):
    """Write an mp4 of `frames` with the per-frame masks overlaid in red.

    `masks_dict` maps frame index -> mask (frames without an entry are written
    as-is). Returns `output_path`, or None when there are no frames.
    """
    logger.info(f"Creating video output at {output_path} with {len(frames)} frames.")
    if not frames:
        logger.warning("No frames to create video.")
        return None
    fps = float(fps)
    if not (fps > 0.0):  # also catches NaN
        fps = 24.0
    h, w = np.array(frames[0]).shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
    for idx, frame in enumerate(frames):
        mask = masks_dict.get(idx)
        if mask is not None:
            pil_out = overlay_mask(frame, mask, color=(255, 0, 0), alpha=0.6)
            frame_np = np.array(pil_out)
        else:
            frame_np = np.array(frame)
        frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
        out.write(frame_bgr)
    out.release()
    logger.info("Video creation complete.")
    return output_path


# ============================================================
# Runtime estimation
# ============================================================
def estimate_video_fps(video_path: str) -> float:
    """Best-effort FPS of the video; falls back to 24.0 when unreadable."""
    cap = cv2.VideoCapture(video_path)
    fps = float(cap.get(cv2.CAP_PROP_FPS)) or 0.0
    cap.release()
    return fps if fps > 0.0 else 24.0


def estimate_total_frames(video_path: str) -> int:
    """Best-effort frame count of the video; at least 1."""
    cap = cv2.VideoCapture(video_path)
    n = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
    cap.release()
    return max(1, n)


# Hard ceiling for a single @spaces.GPU allocation, in seconds.
MAX_GPU_SECONDS = 600


def clamp_duration(sec: int) -> int:
    """Clamp a requested GPU duration into [1, MAX_GPU_SECONDS]."""
    return int(min(MAX_GPU_SECONDS, max(1, sec)))


def get_duration_must3r_features(video_path, interval):
    """Estimated GPU seconds for feature extraction (~2 s per sampled frame)."""
    total = estimate_total_frames(video_path)
    interval = max(1, int(interval))
    processed = math.ceil(total / interval)
    sec_per_frame = 2
    return clamp_duration(int(processed * sec_per_frame))


def get_duration_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """Estimated GPU seconds for tracking (~2 s per frame; 100 if unknown)."""
    try:
        n = int(getattr(sam2_input_images, "shape")[0])
    except Exception:
        n = 100  # conservative default when the input has no usable shape
    sec_per_frame = 2
    return clamp_duration(int(n * sec_per_frame))


# ============================================================
# GPU Wrapped Functions
# ============================================================
@spaces.GPU(duration=get_duration_must3r_features)
def process_video_and_features(video_path, interval):
    """Extract frames, run MUSt3R, and prepare SAM2 inputs for the whole clip.

    Returns (pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs,
    sam2_input_images, images_tensor).
    """
    logger.info(f"Starting GPU process: Video feature extraction (Interval: {interval})")
    load_models()
    pil_imgs = video_to_frames(video_path, interval=max(1, int(interval)))
    if not pil_imgs:
        raise ValueError("Could not extract frames from video.")
    views, resize_funcs = get_views(pil_imgs)
    must3r_feats, must3r_outputs = must3r_features_and_output(views, device=DEVICE)
    sam2_input_images, images_tensor = prepare_sam2_inputs(views, pil_imgs, resize_funcs)
    return pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor


@spaces.GPU
def generate_frame_mask(image_tensor, points, labels, original_size):
    """Run SAM2 on a single frame with point/box prompts; returns a numpy mask.

    `points` are in original-image pixel coordinates and are rescaled here to
    the 1024x1024 space SAM2 expects.
    """
    logger.info(f"Generating mask for single frame. Points: {len(points)}")
    load_models()
    # Ensure tensors are on GPU
    image_tensor = image_tensor.to(DEVICE)
    pts_tensor = torch.tensor(points, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    lbl_tensor = torch.tensor(labels, dtype=torch.int32).unsqueeze(0).to(DEVICE)
    w, h = original_size
    pts_tensor[..., 0] /= (w / 1024.0)
    pts_tensor[..., 1] /= (h / 1024.0)
    mask = get_single_frame_mask(
        image=image_tensor,
        predictor_original=PREDICTOR_ORIGINAL,
        points=pts_tensor,
        labels=lbl_tensor,
        device=DEVICE,
    )
    return mask.squeeze().cpu().numpy()


@spaces.GPU(duration=get_duration_tracking)
def run_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """Propagate the first-frame mask through the clip; returns {idx: mask}."""
    logger.info(f"Starting tracking from frame index {start_idx}...")
    load_models()
    # Ensure everything is on GPU (cached examples load from CPU)
    sam2_input_images = sam2_input_images.to(DEVICE)
    must3r_feats = to_device_nested(must3r_feats, DEVICE)
    must3r_outputs = to_device_nested(must3r_outputs, DEVICE)
    mask_tensor = torch.tensor(first_frame_mask).to(DEVICE) > 0
    tracked_masks = get_tracked_masks(
        sam2_input_images=sam2_input_images,
        must3r_feats=must3r_feats,
        must3r_outputs=must3r_outputs,
        start_idx=start_idx,
        first_frame_mask=mask_tensor,
        predictor=PREDICTOR,
        predictor_original=PREDICTOR_ORIGINAL,
        device=DEVICE,
    )
    logger.info(f"Tracking complete. Generated masks for {len(tracked_masks)} frames.")
    return tracked_masks


# ============================================================
# Cache loader (Examples)
# ============================================================
CACHE_ROOT = Path("./private/cache")


def _read_meta(meta_path: Path) -> Dict[str, Any]:
    """Load an example's pickled metadata dict."""
    with open(meta_path, "rb") as f:
        return pickle.load(f)


def _load_frames_from_dir(frames_dir: Path) -> List[Image.Image]:
    """Load all *.jpg frames in `frames_dir` (sorted by name) as RGB PIL images."""
    frames = []
    for p in sorted(frames_dir.glob("*.jpg")):
        frames.append(Image.open(p).convert("RGB"))
    return frames


def list_example_dirs() -> List[Path]:
    """Return cache dirs that contain a complete example (meta, tensors, mp4)."""
    if not CACHE_ROOT.exists():
        return []
    out = []
    for d in sorted(CACHE_ROOT.iterdir()):
        if not d.is_dir():
            continue
        if (d / "meta.pkl").exists() and (d / "state_tensors.pt").exists() and (d / "output_tracking.mp4").exists():
            out.append(d)
    return out


# ============================================================
# Cache loader (Examples) - GALLERY VERSION
# ============================================================
def build_examples_gallery():
    """Build gallery data for examples.

    Returns (gallery_items, cache_index) where gallery_items is a list of
    (thumbnail_path, caption) tuples and cache_index maps gallery position ->
    example file paths/metadata.
    """
    gallery_items = []
    cache_index = {}
    for idx, d in enumerate(list_example_dirs()):
        cache_id = d.name
        meta = _read_meta(d / "meta.pkl")
        frames_dir = d / "frames"
        thumb = d / "vis_img.png"
        if not thumb.exists():
            # Fall back to the first frame as thumbnail.
            jpgs = sorted(frames_dir.glob("*.jpg"))
            if not jpgs:
                continue
            thumb = jpgs[0]
        # Gallery item: (image, caption)
        caption = f"{meta.get('num_frames', 0)} Frames"
        gallery_items.append((str(thumb), caption))
        cache_index[idx] = {
            "cache_id": cache_id,
            "dir": d,
            "meta": meta,
            "video_mp4": str(d / "output_tracking.mp4"),
            "frames_dir": frames_dir,
            "tensors": str(d / "state_tensors.pt"),
        }
    print(f"Found {len(gallery_items)} example directories.")
    return gallery_items, cache_index


def load_cache_into_state(row_idx: int, cache_index: Dict[int, Dict[str, Any]]):
    """Hydrate the app state dict from a cached example directory.

    Returns (state, annotated_first_frame, frame_slider_update, mp4_path, interval).
    """
    info = cache_index[row_idx]
    meta = info["meta"]
    cache_id = info["cache_id"]
    pil_imgs = _load_frames_from_dir(info["frames_dir"])
    if not pil_imgs:
        raise gr.Error("Example frames not found or empty.")
    tensors = torch.load(info["tensors"], map_location="cpu")
    views, resize_funcs = get_views(pil_imgs)
    fps_in = float(meta.get("fps_in", 24.0))
    fps_out = float(meta.get("fps_out", 24.0))
    interval = int(meta.get("interval", 1))
    points = meta.get("points", [])
    labels = meta.get("labels", [])
    first_frame_mask = meta.get("first_frame_mask", None)
    state = {
        "pil_imgs": pil_imgs,
        "views": views,
        "resize_funcs": resize_funcs,
        "must3r_feats": tensors["must3r_feats"],
        "must3r_outputs": tensors["must3r_outputs"],
        "sam2_input_images": tensors["sam2_input_images"],
        "images_tensor": tensors["images_tensor"],
        "current_points": points,
        "current_labels": labels,
        "current_mask": first_frame_mask,
        "frame_idx": 0,
        "video_path": meta.get("video_name", "example"),
        "interval": interval,
        "fps_in": fps_in,
        "fps_out": fps_out,
        "output_video_path": info["video_mp4"],
        "loaded_from_cache": True,
        "cache_id": cache_id,
    }
    vis_img = overlay_mask(pil_imgs[0], state["current_mask"])
    vis_img = draw_points(vis_img, state["current_points"], state["current_labels"])
    slider = gr.Slider(value=0, maximum=len(pil_imgs) - 1, step=1, interactive=True)
    # BUGFIX: previously returned the literal 1 here, discarding the interval
    # stored in the example's metadata; the caller feeds this into the
    # interval slider, so return the real value.
    return state, vis_img, slider, info["video_mp4"], interval


def on_example_select(evt: gr.SelectData, cache_index_state):
    """Handle gallery selection."""
    idx = evt.index
    state, vis_img, slider, mp4_path, interval = load_cache_into_state(idx, cache_index_state)
    return (
        vis_img,
        state,
        slider,
        mp4_path,
        gr.update(value=interval),
        "Ready. Example loaded.",
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
    )


# ============================================================
# UI callbacks (same semantics as your original app.py)
# ============================================================
def on_video_uploaded(video_path):
    """Suggest a frame interval (~100 sampled frames) after a video upload."""
    n_frames = estimate_total_frames(video_path)
    default_interval = max(1, n_frames // 100)
    return (
        gr.update(value=default_interval, maximum=min(30, n_frames)),
        f"Video uploaded ({n_frames} frames). 2) Adjust interval, then click 'Load Frames'.",
    )


def on_video_upload_and_load(video_path, interval):
    """Run GPU feature extraction and build the fresh app state for a new video."""
    logger.info(f"User uploaded video: {video_path}, Interval: {interval}")
    if video_path is None:
        return None, None, gr.Slider(value=0, maximum=0), None
    pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor = process_video_and_features(
        video_path, int(interval)
    )
    fps_in = estimate_video_fps(video_path)
    interval_i = max(1, int(interval))
    # Output video plays at the effective (subsampled) frame rate.
    fps_out = max(1.0, fps_in / interval_i)
    state = {
        "pil_imgs": pil_imgs,
        "views": views,
        "resize_funcs": resize_funcs,
        "must3r_feats": must3r_feats,
        "must3r_outputs": must3r_outputs,
        "sam2_input_images": sam2_input_images,
        "images_tensor": images_tensor,
        "current_points": [],
        "current_labels": [],
        "current_mask": None,
        "frame_idx": 0,
        "video_path": video_path,
        "interval": interval_i,
        "fps_in": fps_in,
        "fps_out": fps_out,
        "output_video_path": None,
        "loaded_from_cache": False,
    }
    first_frame = pil_imgs[0]
    new_slider = gr.Slider(value=0, maximum=len(pil_imgs) - 1, step=1, interactive=True)
    return first_frame, state, new_slider, gr.Image(value=first_frame)


def on_slider_change(state, frame_idx):
    """Switch the displayed frame and clear any in-progress annotations."""
    if not state:
        return None
    frame_idx = int(frame_idx)
    if frame_idx >= len(state["pil_imgs"]):
        frame_idx = len(state["pil_imgs"]) - 1
    frame_idx = max(0, frame_idx)  # guard against negative slider values
    state["frame_idx"] = frame_idx
    state["current_points"] = []
    state["current_labels"] = []
    state["current_mask"] = None
    return state["pil_imgs"][frame_idx]


def on_image_click(state, evt: gr.SelectData, mode):
    """Record a click as a point/box-corner annotation and redraw the frame."""
    if not state:
        return None
    x, y = evt.index
    label_map = {
        "Positive Point": 1,
        "Negative Point": 0,
        "Box Top-Left": 2,
        "Box Bottom-Right": 3,
    }
    label = label_map[mode]
    state["current_points"].append([x, y])
    state["current_labels"].append(label)
    frame_pil = state["pil_imgs"][state["frame_idx"]]
    vis_img = draw_points(frame_pil, state["current_points"], state["current_labels"])
    if state["current_mask"] is not None:
        vis_img = overlay_mask(vis_img, state["current_mask"])
    return vis_img


def on_generate_mask_click(state):
    """Validate annotations, run SAM2 on the current frame, and show the mask."""
    if not state:
        return None
    if not state["current_points"]:
        raise gr.Error("No points or boxes annotated.")
    # At most one complete box (one top-left + one bottom-right) is allowed.
    num_tl = state["current_labels"].count(2)
    num_br = state["current_labels"].count(3)
    if num_tl != num_br or num_tl > 1:
        raise gr.Error(f"Incomplete box detected! TL={num_tl}, BR={num_br}. Must match and be <= 1.")
    frame_idx = state["frame_idx"]
    full_tensor = state["sam2_input_images"]
    frame_tensor = full_tensor[frame_idx].unsqueeze(0)
    original_size = state["pil_imgs"][frame_idx].size
    mask = generate_frame_mask(
        frame_tensor,
        state["current_points"],
        state["current_labels"],
        original_size,
    )
    state["current_mask"] = mask
    frame_pil = state["pil_imgs"][frame_idx]
    vis_img = overlay_mask(frame_pil, mask)
    vis_img = draw_points(vis_img, state["current_points"], state["current_labels"])
    return vis_img


def reset_annotations(state):
    """Clear points/labels/mask and show the clean current frame."""
    if not state:
        return None
    state["current_points"] = []
    state["current_labels"] = []
    state["current_mask"] = None
    frame_idx = state["frame_idx"]
    return state["pil_imgs"][frame_idx]


def on_track_click(state):
    """Track the current mask through the clip and render the overlay video."""
    if not state or state["current_mask"] is None:
        raise gr.Error("Please annotate a frame and generate a mask first.")
    num_tl = state["current_labels"].count(2)
    num_br = state["current_labels"].count(3)
    if num_tl != num_br:
        raise gr.Error("Incomplete box annotations.")
    start_idx = state["frame_idx"]
    first_frame_mask = state["current_mask"]
    tracked_masks_dict = run_tracking(
        state["sam2_input_images"],
        state["must3r_feats"],
        state["must3r_outputs"],
        start_idx,
        first_frame_mask,
    )
    output_path = create_video_from_masks(
        state["pil_imgs"],
        tracked_masks_dict,
        fps=state.get("fps_out", 24.0),
    )
    state["output_video_path"] = output_path
    return output_path


# ============================================================
# App Layout (match original, add Examples at bottom)
# ============================================================
description = """

3AM: Segment Anything with Geometric Consistency in Videos

Upload a video, extract geometric features, annotate a frame, and track the object.

"""

with gr.Blocks(title="3AM: 3egment Anything") as app:
    gr.HTML(description)
    gr.Markdown(
        """
        **Workflow**
        1) Upload video
        2) Adjust frame interval → Load frames
        3) Annotate & generate mask
        4) Track through the video
        """
    )
    app_state = gr.State()

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Step 1 — Upload video")
            video_input = gr.Video(
                label="Upload Video",
                sources=["upload"],
                height=512,
            )
            gr.Markdown("## Step 2 — Set interval, then load frames")
            interval_slider = gr.Slider(
                label="Frame Interval",
                minimum=1,
                maximum=30,
                step=1,
                value=1,
                info="Default ≈ total_frames / 100",
            )
            load_btn = gr.Button("Load Frames", variant="primary")
            process_status = gr.Textbox(
                label="Status",
                value="1) Upload a video.",
                interactive=False,
            )
        with gr.Column(scale=2):
            gr.Markdown("## Step 3 — Annotate frame & generate mask")
            img_display = gr.Image(
                label="Annotate Frame",
                interactive=True,
                height=512,
            )
            frame_slider = gr.Slider(
                label="Select Frame",
                minimum=0,
                maximum=100,
                step=1,
                value=0,
            )
            with gr.Row():
                mode_radio = gr.Radio(
                    choices=[
                        "Positive Point",
                        "Negative Point",
                        "Box Top-Left",
                        "Box Bottom-Right",
                    ],
                    value="Positive Point",
                    label="Annotation Mode",
                )
                with gr.Column():
                    gen_mask_btn = gr.Button(
                        "Generate Mask",
                        variant="primary",
                        interactive=False,
                    )
                    reset_btn = gr.Button(
                        "Reset Annotations",
                        interactive=False,
                    )
            gr.Markdown("## Step 4 — Track through the video")
            with gr.Row():
                track_btn = gr.Button(
                    "Start Tracking",
                    variant="primary",
                    scale=1,
                    interactive=False,
                )
            with gr.Row():
                video_output = gr.Video(
                    label="Tracking Output",
                    autoplay=True,
                    height=512,
                )

    # -------------------------
    # Examples table at bottom
    # -------------------------
    gr.Markdown("## Examples (click to load)")
    gallery_items, cache_index = build_examples_gallery()
    cache_index_state = gr.State(cache_index)
    if gallery_items:
        examples_gallery = gr.Gallery(
            value=gallery_items,
            label="Examples",
            container=True,
            columns=6,
            object_fit="contain",
            show_label=False,
        )
        examples_gallery.select(
            fn=on_example_select,
            inputs=[cache_index_state],
            outputs=[
                img_display,
                app_state,
                frame_slider,
                video_output,
                interval_slider,
                process_status,
                gen_mask_btn,
                reset_btn,
                track_btn,
            ],
        )
    else:
        gr.Markdown("*No examples available.*")

    # ============================================================
    # Events (original + examples)
    # ============================================================
    video_input.upload(
        fn=on_video_uploaded,
        inputs=video_input,
        outputs=[interval_slider, process_status],
    )
    load_btn.click(
        fn=lambda: (
            "Loading frames...",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
        ),
        outputs=[process_status, gen_mask_btn, reset_btn, track_btn],
    ).then(
        fn=on_video_upload_and_load,
        inputs=[video_input, interval_slider],
        outputs=[img_display, app_state, frame_slider, img_display],
    ).then(
        fn=lambda: (
            "Ready. 3) Annotate and generate mask.",
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[process_status, gen_mask_btn, reset_btn, track_btn],
    )
    frame_slider.change(
        fn=on_slider_change,
        inputs=[app_state, frame_slider],
        outputs=[img_display],
    )
    img_display.select(
        fn=on_image_click,
        inputs=[app_state, mode_radio],
        outputs=[img_display],
    )
    gen_mask_btn.click(
        fn=on_generate_mask_click,
        inputs=[app_state],
        outputs=[img_display],
    )
    reset_btn.click(
        fn=reset_annotations,
        inputs=[app_state],
        outputs=[img_display],
    )
    track_btn.click(
        fn=lambda: "Tracking in progress...",
        outputs=process_status,
    ).then(
        fn=on_track_click,
        inputs=[app_state],
        outputs=[video_output],
    ).then(
        fn=lambda: "Tracking complete!",
        outputs=process_status,
    )

if __name__ == "__main__":
    logger.info("Starting Gradio app...")
    app.launch()