# app_cache.py # Purpose: # - Same UI flow (upload -> load frames -> annotate -> generate mask -> track) # - After tracking, enable "Save Cache" # - You can create multiple caches by repeating the workflow # # Cache contents per example: # cache// # meta.pkl # frames/*.jpg # state_tensors.pt (must3r_feats, must3r_outputs, sam2_input_images, images_tensor) on CPU # output_tracking.mp4 # # Notes: # - We do NOT pickle views/resize_funcs (recomputed on load). # - We store frames as JPEG to avoid pickling PIL and to be deterministic/reloadable. import spaces import subprocess import sys, os from pathlib import Path import math import hashlib import pickle from datetime import datetime from typing import Any, Dict, List, Tuple import importlib, site import gradio as gr import torch import numpy as np from PIL import Image, ImageDraw import cv2 import logging # ---------------------------- # Project bootstrap # ---------------------------- ROOT = Path(__file__).resolve().parent SAM2 = ROOT / "sam2-src" CKPT = SAM2 / "checkpoints" / "sam2.1_hiera_large.pt" # download sam2 checkpoints if not CKPT.exists(): subprocess.check_call(["bash", "download_ckpts.sh"], cwd=SAM2 / "checkpoints") # install sam2 try: import sam2.build_sam # noqa except ModuleNotFoundError: subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src"], cwd=ROOT) subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src[notebooks]"], cwd=ROOT) # install asmk try: import asmk.index # noqa: F401 except Exception: subprocess.check_call(["cythonize", "*.pyx"], cwd="./asmk-src/cython") subprocess.check_call([sys.executable, "-m", "pip", "install", "./asmk-src", "--no-build-isolation"]) # download private checkpoints if not os.path.exists("./private"): from huggingface_hub import snapshot_download snapshot_download( repo_id="nycu-cplab/3AM", local_dir="./private", repo_type="model", ) for sp in site.getsitepackages(): site.addsitedir(sp) importlib.invalidate_caches() # 
# ----------------------------
# Logging
# ----------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("app_cache")

# ----------------------------
# Engine imports
# ----------------------------
# Imported after the bootstrap block above so freshly installed packages resolve.
from engine import (
    get_predictors,
    get_views,
    prepare_sam2_inputs,
    must3r_features_and_output,
    get_single_frame_mask,
    get_tracked_masks,
)

# ----------------------------
# Globals
# ----------------------------
# Lazily-initialized predictor pair returned by get_predictors()
# (exact roles defined in engine; PREDICTOR_ORIGINAL is used for
# single-frame masks, PREDICTOR for tracking — see callers below).
PREDICTOR_ORIGINAL = None
PREDICTOR = None
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_models():
    """Lazily initialize the global predictors; idempotent.

    Returns:
        tuple: (PREDICTOR_ORIGINAL, PREDICTOR) as produced by get_predictors().
    """
    global PREDICTOR_ORIGINAL, PREDICTOR
    if PREDICTOR is None or PREDICTOR_ORIGINAL is None:
        logger.info(f"Initializing models on device: {DEVICE}...")
        PREDICTOR_ORIGINAL, PREDICTOR = get_predictors(device=DEVICE)
        logger.info("Models loaded successfully.")
    return PREDICTOR_ORIGINAL, PREDICTOR


# Ensure no_grad globally (as you had): the context manager is entered and
# deliberately never exited, so autograd stays disabled for the whole process.
torch.no_grad().__enter__()


# ----------------------------
# Video / visualization helpers
# ----------------------------
def video_to_frames(video_path, interval=1):
    """Decode a video file and return every `interval`-th frame as an RGB PIL image."""
    logger.info(f"Extracting frames from video: {video_path} with interval={interval}")
    cap = cv2.VideoCapture(video_path)
    frames = []
    count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if count % interval == 0:
            # OpenCV decodes BGR; convert to RGB before wrapping in PIL.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
        count += 1
    cap.release()
    logger.info(f"Extracted {len(frames)} frames (sampled from {count} total).")
    return frames


def draw_points(image_pil, points, labels):
    """Draw annotation dots on a copy of the frame.

    Label colors: 1 = positive point (green), 0 = negative point (red),
    2 = box top-left (blue), 3 = box bottom-right (cyan); anything else yellow.
    """
    img_draw = image_pil.copy()
    draw = ImageDraw.Draw(img_draw)
    r = 7.5
    for pt, lbl in zip(points, labels):
        x, y = pt
        if lbl == 1:
            color = "green"
        elif lbl == 0:
            color = "red"
        elif lbl == 2:
            color = "blue"
        elif lbl == 3:
            color = "cyan"
        else:
            color = "yellow"
        draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline="white")
    return img_draw
def overlay_mask(image_pil, mask, color=(255, 0, 0), alpha=0.5):
    """Alpha-blend a binary mask onto a PIL image.

    Args:
        image_pil: source PIL image (returned unchanged when mask is None).
        mask: array-like; treated as binary via `mask > 0`. Resized with
            nearest-neighbor if its HxW does not match the image.
        color: RGB fill for masked pixels.
        alpha: blend weight of the colored overlay.

    Returns:
        PIL.Image with the mask blended in (or the original image if mask is None).
    """
    if mask is None:
        return image_pil
    mask = mask > 0
    img_np = np.array(image_pil)
    h, w = img_np.shape[:2]
    if mask.shape[0] != h or mask.shape[1] != w:
        mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
    overlay = img_np.copy()
    overlay[mask] = np.array(color, dtype=np.uint8)
    combined = cv2.addWeighted(overlay, alpha, img_np, 1 - alpha, 0)
    return Image.fromarray(combined)


def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4", fps=24):
    """Render frames (with mask overlays where available) into an mp4.

    Args:
        frames: list of PIL RGB frames; an empty list returns None.
        masks_dict: {frame_index: mask} — frames without an entry are written as-is.
        output_path: destination mp4 path.
        fps: output frame rate; non-positive/invalid values fall back to 24.

    Returns:
        output_path on success, or None when `frames` is empty.
    """
    logger.info(f"Creating video output at {output_path} with {len(frames)} frames.")
    if not frames:
        return None
    fps = float(fps)
    if not (fps > 0.0):
        fps = 24.0
    h, w = np.array(frames[0]).shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
    # try/finally so the writer is released (and the file finalized) even if
    # an overlay or color conversion raises mid-loop.
    try:
        for idx, frame in enumerate(frames):
            mask = masks_dict.get(idx)
            if mask is not None:
                pil_out = overlay_mask(frame, mask, color=(255, 0, 0), alpha=0.6)
                frame_np = np.array(pil_out)
            else:
                frame_np = np.array(frame)
            frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
            out.write(frame_bgr)
    finally:
        out.release()
    return output_path


# ----------------------------
# Runtime estimation helpers
# ----------------------------
def estimate_video_fps(video_path: str) -> float:
    """Read the container's FPS; fall back to 24.0 when unavailable/zero."""
    cap = cv2.VideoCapture(video_path)
    fps = float(cap.get(cv2.CAP_PROP_FPS)) or 0.0
    cap.release()
    return fps if fps > 0.0 else 24.0


def estimate_total_frames(video_path: str) -> int:
    """Read the container's frame count; always at least 1."""
    cap = cv2.VideoCapture(video_path)
    n = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
    cap.release()
    return max(1, n)


# Hard ceiling (seconds) for a single GPU task duration estimate.
MAX_GPU_SECONDS = 600


def clamp_duration(sec: int) -> int:
    """Clamp a duration estimate to the [1, MAX_GPU_SECONDS] range."""
    return int(min(MAX_GPU_SECONDS, max(1, sec)))


def get_duration_must3r_features(video_path, interval):
    """Estimate GPU-seconds for feature extraction: ~2 s per sampled frame, clamped."""
    total = estimate_total_frames(video_path)
    interval = max(1, int(interval))
    processed = math.ceil(total / interval)
    sec_per_frame = 2
    return clamp_duration(int(processed * sec_per_frame))
def get_duration_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """Estimate the GPU-seconds budget for run_tracking.

    Uses the number of frames in `sam2_input_images` (first shape dim) at
    ~2 s/frame, clamped; falls back to 100 frames when no usable .shape exists.
    """
    try:
        n = int(getattr(sam2_input_images, "shape")[0])
    except Exception:
        n = 100  # conservative fallback
    sec_per_frame = 2
    return clamp_duration(int(n * sec_per_frame))


# ----------------------------
# GPU functions
# ----------------------------
@spaces.GPU(duration=get_duration_must3r_features)
def process_video_and_features(video_path, interval):
    """Extract frames and run MUSt3R/SAM2 preprocessing on the GPU.

    Returns:
        tuple: (pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs,
        sam2_input_images, images_tensor).

    Raises:
        ValueError: if no frames could be extracted from the video.
    """
    logger.info(f"GPU: feature extraction interval={interval}")
    load_models()
    pil_imgs = video_to_frames(video_path, interval=max(1, int(interval)))
    if not pil_imgs:
        raise ValueError("Could not extract frames.")
    views, resize_funcs = get_views(pil_imgs)
    must3r_feats, must3r_outputs = must3r_features_and_output(views, device=DEVICE)
    sam2_input_images, images_tensor = prepare_sam2_inputs(views, pil_imgs, resize_funcs)
    return pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor


@spaces.GPU
def generate_frame_mask(image_tensor, points, labels, original_size):
    """Run single-frame SAM2 segmentation for the annotated points/boxes.

    Args:
        image_tensor: one preprocessed frame (batched with unsqueeze by the caller).
        points: [[x, y], ...] in original-frame pixel coordinates.
        labels: per-point labels (1/0 points, 2/3 box corners).
        original_size: (w, h) of the original frame.

    Returns:
        np.ndarray: the predicted mask, squeezed and moved to CPU.
    """
    logger.info(f"GPU: generate mask points={len(points)}")
    load_models()
    pts_tensor = torch.tensor(points, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    lbl_tensor = torch.tensor(labels, dtype=torch.int32).unsqueeze(0).to(DEVICE)
    w, h = original_size
    # Rescale click coordinates from the original (w, h) frame into the
    # 1024x1024 model input space.
    pts_tensor[..., 0] /= (w / 1024.0)
    pts_tensor[..., 1] /= (h / 1024.0)
    mask = get_single_frame_mask(
        image=image_tensor,
        predictor_original=PREDICTOR_ORIGINAL,
        points=pts_tensor,
        labels=lbl_tensor,
        device=DEVICE,
    )
    return mask.squeeze().cpu().numpy()


@spaces.GPU(duration=get_duration_tracking)
def run_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """Propagate the first-frame mask through the clip; returns engine's tracked-masks dict."""
    logger.info(f"GPU: tracking start_idx={start_idx}")
    load_models()
    mask_tensor = torch.tensor(first_frame_mask).to(DEVICE) > 0
    tracked_masks = get_tracked_masks(
        sam2_input_images=sam2_input_images,
        must3r_feats=must3r_feats,
        must3r_outputs=must3r_outputs,
        start_idx=start_idx,
        first_frame_mask=mask_tensor,
        predictor=PREDICTOR,
        predictor_original=PREDICTOR_ORIGINAL,
        device=DEVICE,
    )
    return tracked_masks


# ----------------------------
# Cache utilities
# ----------------------------
CACHE_DIR = Path("./tmpcache/cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)


def _make_cache_key(video_path: str, interval: int, start_idx: int) -> str:
    """Derive a unique 16-hex-char key from video name, parameters and a UTC timestamp."""
    from datetime import timezone  # local import: keeps the file-level import block untouched

    name = Path(video_path).name if video_path else "video"
    # datetime.utcnow() is deprecated (3.12+); use an aware UTC datetime instead.
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    s = f"{name}|interval={interval}|start={start_idx}|{stamp}"
    return hashlib.sha256(s.encode("utf-8")).hexdigest()[:16]


def _cache_paths(key: str) -> Dict[str, Path]:
    """Create (if needed) and return the on-disk layout for one cache entry."""
    base = CACHE_DIR / key
    base.mkdir(parents=True, exist_ok=True)
    return {
        "base": base,
        "meta": base / "meta.pkl",
        "frames_dir": base / "frames",
        "vis_img": base / "vis_img.png",
        "tensors": base / "state_tensors.pt",
        "video": base / "output_tracking.mp4",
    }


def _save_frames_as_jpg(pil_imgs: List[Image.Image], frames_dir: Path, quality: int = 95) -> None:
    """Write frames as zero-padded JPEGs (deterministic/reloadable; avoids pickling PIL)."""
    frames_dir.mkdir(parents=True, exist_ok=True)
    for i, im in enumerate(pil_imgs):
        im.save(frames_dir / f"{i:06d}.jpg", "JPEG", quality=quality, subsampling=0)


def _to_cpu(obj: Any) -> Any:
    """Recursively detach tensors and move them to CPU inside nested dict/list/tuple containers."""
    if torch.is_tensor(obj):
        return obj.detach().to("cpu")
    if isinstance(obj, dict):
        return {k: _to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        out = [_to_cpu(v) for v in obj]
        return type(obj)(out) if isinstance(obj, tuple) else out
    return obj


def save_full_cache_from_state(state: Dict[str, Any]) -> str:
    """Persist the full pipeline state (frames, tensors, tracking video, metadata).

    Returns:
        str: the cache key the entry was saved under.

    Raises:
        ValueError: if the state is empty or missing required fields.
        FileNotFoundError: if the tracking output video does not exist.
    """
    import shutil  # local import: used only here, keeps the file-level import block untouched

    if not state:
        raise ValueError("Empty state.")
    required = [
        "pil_imgs",
        "must3r_feats",
        "must3r_outputs",
        "sam2_input_images",
        "images_tensor",
        "output_video_path",
        "video_path",
        "interval",
        "fps_in",
        "fps_out",
        "last_tracking_start_idx",
        # Saved below; validating it here gives a clear ValueError instead of a KeyError.
        "current_vis_img",
    ]
    missing = [k for k in required if k not in state or state[k] is None]
    if missing:
        raise ValueError(f"State missing fields: {missing}")

    key = _make_cache_key(
        str(state["video_path"]),
        int(state["interval"]),
        int(state["last_tracking_start_idx"]),
    )
    paths = _cache_paths(key)

    _save_frames_as_jpg(state["pil_imgs"], paths["frames_dir"])
    state["current_vis_img"].save(paths["vis_img"])

    logger.info("Saving tensors to cache...")
    torch.save(
        {
            "must3r_feats": _to_cpu(state["must3r_feats"]),
            "must3r_outputs": _to_cpu(state["must3r_outputs"]),
            "sam2_input_images": _to_cpu(state["sam2_input_images"]),
            "images_tensor": _to_cpu(state["images_tensor"]),
        },
        paths["tensors"],
    )

    src = Path(state["output_video_path"])
    if not src.exists():
        raise FileNotFoundError(f"Output video not found: {src}")
    dst = paths["video"]
    if src.resolve() != dst.resolve():
        # Stream the copy instead of reading the whole video into memory.
        shutil.copyfile(src, dst)

    meta = {
        "video_name": Path(str(state["video_path"])).name,
        "interval": int(state["interval"]),
        "fps_in": float(state["fps_in"]),
        "fps_out": float(state["fps_out"]),
        "num_frames": int(len(state["pil_imgs"])),
        "start_idx": int(state["last_tracking_start_idx"]),
        "points": list(state.get("last_points", [])),
        "labels": list(state.get("last_labels", [])),
        "first_frame_mask": state.get("current_mask", None),
        "cache_key": key,
    }
    with open(paths["meta"], "wb") as f:
        pickle.dump(meta, f)

    logger.info(f"Cache saved at key: {key}")
    return key
def on_video_upload(video_path, interval):
    """Extract frames + features and build the per-session state dict.

    Returns (first_frame, state, slider_update, image_update); all-None-ish
    placeholders when no video was provided.
    """
    if video_path is None:
        return None, None, gr.Slider(value=0, maximum=0), None
    pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor = process_video_and_features(
        video_path, int(interval)
    )
    fps_in = estimate_video_fps(video_path)
    interval_i = max(1, int(interval))
    # Output video plays at the sampled rate, never below 1 fps.
    fps_out = max(1.0, fps_in / interval_i)
    state = {
        "pil_imgs": pil_imgs,
        "views": views,
        "resize_funcs": resize_funcs,
        "must3r_feats": must3r_feats,
        "must3r_outputs": must3r_outputs,
        "sam2_input_images": sam2_input_images,
        "images_tensor": images_tensor,
        "current_points": [],
        "current_labels": [],
        "current_mask": None,
        "frame_idx": 0,
        "video_path": video_path,
        "interval": interval_i,
        "fps_in": fps_in,
        "fps_out": fps_out,
        # tracking outputs (filled later)
        "output_video_path": None,
        "last_tracking_start_idx": None,
        "last_points": None,
        "last_labels": None,
    }
    first_frame = pil_imgs[0]
    new_slider = gr.Slider(value=0, maximum=len(pil_imgs) - 1, step=1, interactive=True)
    return first_frame, state, new_slider, gr.Image(value=first_frame)


def on_slider_change(state, frame_idx):
    """Switch the annotated frame; clears any in-progress annotations/mask."""
    if not state:
        return None
    frame_idx = int(frame_idx)
    frame_idx = min(frame_idx, len(state["pil_imgs"]) - 1)
    state["frame_idx"] = frame_idx
    state["current_points"] = []
    state["current_labels"] = []
    state["current_mask"] = None
    frame = state["pil_imgs"][frame_idx]
    return frame


def on_image_click(state, evt: gr.SelectData, mode):
    """Record a click as an annotation (point or box corner) and redraw the frame."""
    if not state:
        return None
    x, y = evt.index
    label_map = {
        "Positive Point": 1,
        "Negative Point": 0,
        "Box Top-Left": 2,
        "Box Bottom-Right": 3,
    }
    label = label_map[mode]
    state["current_points"].append([x, y])
    state["current_labels"].append(label)
    frame_pil = state["pil_imgs"][state["frame_idx"]]
    vis_img = draw_points(frame_pil, state["current_points"], state["current_labels"])
    if state["current_mask"] is not None:
        vis_img = overlay_mask(vis_img, state["current_mask"])
    return vis_img


def on_generate_mask_click(state):
    """Validate annotations, run SAM2 on the current frame and show the mask overlay.

    Raises:
        gr.Error: if there are no annotations or box corners are unpaired / > 1 box.
    """
    if not state:
        return None
    if not state["current_points"]:
        raise gr.Error("No points or boxes annotated.")
    num_tl = state["current_labels"].count(2)
    num_br = state["current_labels"].count(3)
    if num_tl != num_br or num_tl > 1:
        raise gr.Error(f"Incomplete box: TL={num_tl}, BR={num_br}. Must match and be <= 1.")
    frame_idx = state["frame_idx"]
    full_tensor = state["sam2_input_images"]
    frame_tensor = full_tensor[frame_idx].unsqueeze(0)
    original_size = state["pil_imgs"][frame_idx].size
    mask = generate_frame_mask(
        frame_tensor,
        state["current_points"],
        state["current_labels"],
        original_size,
    )
    state["current_mask"] = mask
    frame_pil = state["pil_imgs"][frame_idx]
    vis_img = overlay_mask(frame_pil, mask)
    vis_img = draw_points(vis_img, state["current_points"], state["current_labels"])
    # Keep a copy for the cache (saved as vis_img.png on "Save Cache").
    state["current_vis_img"] = vis_img.copy()
    return vis_img


def reset_annotations(state):
    """Drop all annotations/mask for the current frame and show it clean."""
    if not state:
        return None
    state["current_points"] = []
    state["current_labels"] = []
    state["current_mask"] = None
    frame_idx = state["frame_idx"]
    return state["pil_imgs"][frame_idx]


def on_track_click(state):
    """Run tracking from the current frame's mask and render the output video.

    Raises:
        gr.Error: if no mask exists yet or box annotations are incomplete.
    """
    if not state or state["current_mask"] is None:
        raise gr.Error("Generate a mask first.")
    num_tl = state["current_labels"].count(2)
    num_br = state["current_labels"].count(3)
    if num_tl != num_br:
        raise gr.Error("Incomplete box annotations.")
    start_idx = int(state["frame_idx"])
    first_frame_mask = state["current_mask"]
    tracked_masks_dict = run_tracking(
        state["sam2_input_images"],
        state["must3r_feats"],
        state["must3r_outputs"],
        start_idx,
        first_frame_mask,
    )
    output_path = create_video_from_masks(
        state["pil_imgs"],
        tracked_masks_dict,
        fps=state.get("fps_out", 24.0),
    )
    # Remember what produced this result so "Save Cache" can record it.
    state["output_video_path"] = output_path
    state["last_tracking_start_idx"] = start_idx
    state["last_points"] = list(state.get("current_points", []))
    state["last_labels"] = list(state.get("current_labels", []))
    logger.info("Tracking Complete")
    return output_path, state


def on_save_cache_click(state):
    """Persist the current session to the cache directory and report the key."""
    key = save_full_cache_from_state(state)
    return f"Saved cache key: {key}"


# ----------------------------
# UI layout
# ----------------------------
description = """

3AM: 3egment Anything with Geometric Consistency in Videos

Cache-builder UI: run full pipeline, then save caches for user examples.

"""
""" with gr.Blocks(title="3AM Cache Builder") as app: gr.HTML(description) app_state = gr.State() with gr.Row(): with gr.Column(scale=1): gr.Markdown("## Step 1 — Upload video") video_input = gr.Video(label="Upload Video", sources=["upload"], height=512) gr.Markdown("## Step 2 — Set interval, then load frames") interval_slider = gr.Slider( label="Frame Interval", minimum=1, maximum=30, step=1, value=1, ) load_btn = gr.Button("Load Frames", variant="primary") process_status = gr.Textbox(label="Status", value="1) Upload a video.", interactive=False) with gr.Column(scale=2): gr.Markdown("## Step 3 — Annotate frame & generate mask") img_display = gr.Image(label="Annotate Frame", interactive=True, height=512) frame_slider = gr.Slider(label="Select Frame", minimum=0, maximum=100, step=1, value=0) with gr.Row(): mode_radio = gr.Radio( choices=["Positive Point", "Negative Point", "Box Top-Left", "Box Bottom-Right"], value="Positive Point", label="Annotation Mode", ) with gr.Column(): gen_mask_btn = gr.Button("Generate Mask", variant="primary", interactive=False) reset_btn = gr.Button("Reset Annotations", interactive=False) gr.Markdown("## Step 4 — Track & Save Cache") with gr.Row(): track_btn = gr.Button("Start Tracking", variant="primary", interactive=False) save_cache_btn = gr.Button("Save Cache", variant="secondary", interactive=False) with gr.Row(): video_output = gr.Video(label="Tracking Output", autoplay=True, height=512) cache_status = gr.Textbox(label="Cache", value="", interactive=False) # ------------------------ # Events # ------------------------ def on_video_uploaded(video_path): n_frames = estimate_total_frames(video_path) default_interval = max(1, n_frames // 100) return ( gr.update(value=default_interval, maximum=min(30, n_frames)), f"Video uploaded ({n_frames} frames). 
2) Adjust interval, then click 'Load Frames'.", ) video_input.upload(fn=on_video_uploaded, inputs=video_input, outputs=[interval_slider, process_status]) load_btn.click( fn=lambda: ( "Loading frames...", gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), # save_cache_btn gr.update(value=""), ), outputs=[process_status, gen_mask_btn, reset_btn, track_btn, save_cache_btn, cache_status], ).then( fn=on_video_upload, inputs=[video_input, interval_slider], outputs=[img_display, app_state, frame_slider, img_display], ).then( fn=lambda: ( "Ready. 3) Annotate and generate mask.", gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), ), outputs=[process_status, gen_mask_btn, reset_btn, track_btn], ) frame_slider.change(fn=on_slider_change, inputs=[app_state, frame_slider], outputs=[img_display]) img_display.select(fn=on_image_click, inputs=[app_state, mode_radio], outputs=[img_display]) gen_mask_btn.click(fn=on_generate_mask_click, inputs=[app_state], outputs=[img_display]) reset_btn.click(fn=reset_annotations, inputs=[app_state], outputs=[img_display]) track_btn.click( fn=lambda: ( "Tracking in progress...", gr.update(interactive=False), gr.update(interactive=False), ), outputs=[process_status, track_btn, save_cache_btn], ).then( fn=on_track_click, inputs=[app_state], outputs=[video_output, app_state], ).then( fn=lambda: ( "Tracking complete. You can save cache.", gr.update(interactive=True), # track_btn gr.update(interactive=True), # save_cache_btn ), outputs=[process_status, track_btn, save_cache_btn], ) save_cache_btn.click(fn=on_save_cache_click, inputs=[app_state], outputs=[cache_status]) if __name__ == "__main__": app.launch()