# app_cache.py
# Purpose:
# - Keep the same UI flow (upload -> load frames -> annotate -> generate mask -> track)
# - After tracking completes, enable "Save Cache"
# - Repeat the workflow to build multiple caches
#
# Cache contents per example:
#   cache/<key>/
#     meta.pkl
#     frames/*.jpg
#     vis_img.png          (annotated frame: points + mask overlay)
#     state_tensors.pt     (must3r_feats, must3r_outputs, sam2_input_images, images_tensor) on CPU
#     output_tracking.mp4
#
# Notes:
# - We do NOT pickle views/resize_funcs (recomputed on load).
# - We store frames as JPEG to avoid pickling PIL and to be deterministic/reloadable.
import spaces
import subprocess
import sys, os
from pathlib import Path
import math
import hashlib
import pickle
from datetime import datetime, timezone
from typing import Any, Dict, List, Tuple
import importlib, site
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw
import cv2
import logging
# ----------------------------
# Project bootstrap
# ----------------------------
ROOT = Path(__file__).resolve().parent
SAM2 = ROOT / "sam2-src"
CKPT = SAM2 / "checkpoints" / "sam2.1_hiera_large.pt"
# download sam2 checkpoints
if not CKPT.exists():
subprocess.check_call(["bash", "download_ckpts.sh"], cwd=SAM2 / "checkpoints")
# install sam2
try:
import sam2.build_sam # noqa
except ModuleNotFoundError:
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src"], cwd=ROOT)
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src[notebooks]"], cwd=ROOT)
# install asmk
try:
import asmk.index # noqa: F401
except Exception:
subprocess.check_call(["cythonize", "*.pyx"], cwd="./asmk-src/cython")
subprocess.check_call([sys.executable, "-m", "pip", "install", "./asmk-src", "--no-build-isolation"])
# download private checkpoints
if not os.path.exists("./private"):
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="nycu-cplab/3AM",
local_dir="./private",
repo_type="model",
)
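# Refresh the import machinery so the packages installed above (sam2, asmk)
# become importable in this same process without a restart.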
for sp in site.getsitepackages():
site.addsitedir(sp)
importlib.invalidate_caches()
# ----------------------------
# Logging
# ----------------------------
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("app_cache")
# ----------------------------
# Engine imports
# ----------------------------
from engine import (
get_predictors,
get_views,
prepare_sam2_inputs,
must3r_features_and_output,
get_single_frame_mask,
get_tracked_masks,
)
# ----------------------------
# Globals
# ----------------------------
PREDICTOR_ORIGINAL = None
PREDICTOR = None
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def load_models():
global PREDICTOR_ORIGINAL, PREDICTOR
if PREDICTOR is None or PREDICTOR_ORIGINAL is None:
logger.info(f"Initializing models on device: {DEVICE}...")
PREDICTOR_ORIGINAL, PREDICTOR = get_predictors(device=DEVICE)
logger.info("Models loaded successfully.")
return PREDICTOR_ORIGINAL, PREDICTOR
# Inference-only app: disable autograd globally.
torch.set_grad_enabled(False)
# ----------------------------
# Video / visualization helpers
# ----------------------------
def video_to_frames(video_path, interval=1):
logger.info(f"Extracting frames from video: {video_path} with interval={interval}")
cap = cv2.VideoCapture(video_path)
frames = []
count = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
if count % interval == 0:
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(Image.fromarray(frame_rgb))
count += 1
cap.release()
logger.info(f"Extracted {len(frames)} frames (sampled from {count} total).")
return frames
def draw_points(image_pil, points, labels):
img_draw = image_pil.copy()
draw = ImageDraw.Draw(img_draw)
r = 7.5
for pt, lbl in zip(points, labels):
x, y = pt
if lbl == 1:
color = "green"
elif lbl == 0:
color = "red"
elif lbl == 2:
color = "blue"
elif lbl == 3:
color = "cyan"
else:
color = "yellow"
draw.ellipse((x-r, y-r, x+r, y+r), fill=color, outline="white")
return img_draw
def overlay_mask(image_pil, mask, color=(255, 0, 0), alpha=0.5):
if mask is None:
return image_pil
mask = mask > 0
img_np = np.array(image_pil)
h, w = img_np.shape[:2]
if mask.shape[0] != h or mask.shape[1] != w:
mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
overlay = img_np.copy()
overlay[mask] = np.array(color, dtype=np.uint8)
combined = cv2.addWeighted(overlay, alpha, img_np, 1 - alpha, 0)
return Image.fromarray(combined)
def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4", fps=24):
logger.info(f"Creating video output at {output_path} with {len(frames)} frames.")
if not frames:
return None
fps = float(fps)
if not (fps > 0.0):
fps = 24.0
h, w = np.array(frames[0]).shape[:2]
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
for idx, frame in enumerate(frames):
mask = masks_dict.get(idx)
if mask is not None:
pil_out = overlay_mask(frame, mask, color=(255, 0, 0), alpha=0.6)
frame_np = np.array(pil_out)
else:
frame_np = np.array(frame)
frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
out.write(frame_bgr)
out.release()
return output_path
# ----------------------------
# Runtime estimation helpers
# ----------------------------
def estimate_video_fps(video_path: str) -> float:
cap = cv2.VideoCapture(video_path)
fps = float(cap.get(cv2.CAP_PROP_FPS)) or 0.0
cap.release()
return fps if fps > 0.0 else 24.0
def estimate_total_frames(video_path: str) -> int:
cap = cv2.VideoCapture(video_path)
n = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
cap.release()
return max(1, n)
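# ZeroGPU accepts a callable for @spaces.GPU(duration=...) that receives the
# same arguments as the decorated function, so each GPU allocation can be
# sized per request. The estimators below assume roughly 2 seconds of GPU
# time per processed frame, clamped to the Space's budget.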
MAX_GPU_SECONDS = 600
def clamp_duration(sec: int) -> int:
return int(min(MAX_GPU_SECONDS, max(1, sec)))
def get_duration_must3r_features(video_path, interval):
total = estimate_total_frames(video_path)
interval = max(1, int(interval))
processed = math.ceil(total / interval)
sec_per_frame = 2
return clamp_duration(int(processed * sec_per_frame))
def get_duration_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
try:
n = int(getattr(sam2_input_images, "shape")[0])
except Exception:
n = 100
sec_per_frame = 2
return clamp_duration(int(n * sec_per_frame))
# ----------------------------
# GPU functions
# ----------------------------
@spaces.GPU(duration=get_duration_must3r_features)
def process_video_and_features(video_path, interval):
logger.info(f"GPU: feature extraction interval={interval}")
load_models()
pil_imgs = video_to_frames(video_path, interval=max(1, int(interval)))
if not pil_imgs:
raise ValueError("Could not extract frames.")
views, resize_funcs = get_views(pil_imgs)
must3r_feats, must3r_outputs = must3r_features_and_output(views, device=DEVICE)
sam2_input_images, images_tensor = prepare_sam2_inputs(views, pil_imgs, resize_funcs)
return pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor
@spaces.GPU
def generate_frame_mask(image_tensor, points, labels, original_size):
logger.info(f"GPU: generate mask points={len(points)}")
load_models()
pts_tensor = torch.tensor(points, dtype=torch.float32).unsqueeze(0).to(DEVICE)
lbl_tensor = torch.tensor(labels, dtype=torch.int32).unsqueeze(0).to(DEVICE)
w, h = original_size
pts_tensor[..., 0] /= (w / 1024.0)
pts_tensor[..., 1] /= (h / 1024.0)
mask = get_single_frame_mask(
image=image_tensor,
predictor_original=PREDICTOR_ORIGINAL,
points=pts_tensor,
labels=lbl_tensor,
device=DEVICE,
)
return mask.squeeze().cpu().numpy()
@spaces.GPU(duration=get_duration_tracking)
def run_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
logger.info(f"GPU: tracking start_idx={start_idx}")
load_models()
mask_tensor = torch.tensor(first_frame_mask).to(DEVICE) > 0
tracked_masks = get_tracked_masks(
sam2_input_images=sam2_input_images,
must3r_feats=must3r_feats,
must3r_outputs=must3r_outputs,
start_idx=start_idx,
first_frame_mask=mask_tensor,
predictor=PREDICTOR,
predictor_original=PREDICTOR_ORIGINAL,
device=DEVICE,
)
return tracked_masks
# ----------------------------
# Cache utilities
# ----------------------------
CACHE_DIR = Path("./tmpcache/cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)
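# Keys hash the video name, settings, and a UTC timestamp, so repeating the
# workflow always yields a fresh cache entry rather than overwriting one.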
def _make_cache_key(video_path: str, interval: int, start_idx: int) -> str:
name = Path(video_path).name if video_path else "video"
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
s = f"{name}|interval={interval}|start={start_idx}|{stamp}"
return hashlib.sha256(s.encode("utf-8")).hexdigest()[:16]
def _cache_paths(key: str) -> Dict[str, Path]:
base = CACHE_DIR / key
base.mkdir(parents=True, exist_ok=True)
return {
"base": base,
"meta": base / "meta.pkl",
"frames_dir": base / "frames",
"vis_img": base / "vis_img.png",
"tensors": base / "state_tensors.pt",
"video": base / "output_tracking.mp4",
}
def _save_frames_as_jpg(pil_imgs: List[Image.Image], frames_dir: Path, quality: int = 95) -> None:
frames_dir.mkdir(parents=True, exist_ok=True)
for i, im in enumerate(pil_imgs):
im.save(frames_dir / f"{i:06d}.jpg", "JPEG", quality=quality, subsampling=0)
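# Recursively detach tensors and move them to CPU so torch.save writes
# device-independent files; dicts, lists, and tuples are walked, everything
# else passes through unchanged.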
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().to("cpu")
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
out = [_to_cpu(v) for v in obj]
return type(obj)(out) if isinstance(obj, tuple) else out
return obj
def save_full_cache_from_state(state: Dict[str, Any]) -> str:
if not state:
raise ValueError("Empty state.")
required = [
"pil_imgs",
"must3r_feats",
"must3r_outputs",
"sam2_input_images",
"images_tensor",
"output_video_path",
"video_path",
"interval",
"fps_in",
"fps_out",
"last_tracking_start_idx",
]
missing = [k for k in required if k not in state or state[k] is None]
if missing:
raise ValueError(f"State missing fields: {missing}")
key = _make_cache_key(
str(state["video_path"]),
int(state["interval"]),
int(state["last_tracking_start_idx"]),
)
paths = _cache_paths(key)
_save_frames_as_jpg(state["pil_imgs"], paths["frames_dir"])
    vis_img = state.get("current_vis_img")
    if vis_img is not None:
        vis_img.save(paths["vis_img"])
    logger.info("Saving tensors to cache...")
torch.save(
{
"must3r_feats": _to_cpu(state["must3r_feats"]),
"must3r_outputs": _to_cpu(state["must3r_outputs"]),
"sam2_input_images": _to_cpu(state["sam2_input_images"]),
"images_tensor": _to_cpu(state["images_tensor"]),
},
paths["tensors"],
)
src = Path(state["output_video_path"])
if not src.exists():
raise FileNotFoundError(f"Output video not found: {src}")
dst = paths["video"]
if src.resolve() != dst.resolve():
dst.write_bytes(src.read_bytes())
meta = {
"video_name": Path(str(state["video_path"])).name,
"interval": int(state["interval"]),
"fps_in": float(state["fps_in"]),
"fps_out": float(state["fps_out"]),
"num_frames": int(len(state["pil_imgs"])),
"start_idx": int(state["last_tracking_start_idx"]),
"points": list(state.get("last_points", [])),
"labels": list(state.get("last_labels", [])),
        "first_frame_mask": state.get("current_mask"),
"cache_key": key,
}
with open(paths["meta"], "wb") as f:
pickle.dump(meta, f)
print(f"Cache saved at key: {key}")
return key
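# A minimal load-side sketch mirroring save_full_cache_from_state. The name
# load_full_cache_from_key is hypothetical (this file defines no loader);
# views/resize_funcs are recomputed via get_views rather than unpickled, per
# the header notes. Adapt as needed for the consuming app.
def load_full_cache_from_key(key: str) -> Dict[str, Any]:
    paths = _cache_paths(key)
    with open(paths["meta"], "rb") as f:
        meta = pickle.load(f)
    # Frames were saved as zero-padded JPEGs, so a lexicographic sort restores order.
    pil_imgs = [Image.open(p).convert("RGB") for p in sorted(paths["frames_dir"].glob("*.jpg"))]
    tensors = torch.load(paths["tensors"], map_location="cpu")
    # NOTE: get_views may expect the same preprocessing environment as the GPU path.
    views, resize_funcs = get_views(pil_imgs)
    return {
        "meta": meta,
        "pil_imgs": pil_imgs,
        "views": views,
        "resize_funcs": resize_funcs,
        "output_video_path": str(paths["video"]),
        **tensors,  # must3r_feats, must3r_outputs, sam2_input_images, images_tensor
    }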
# ----------------------------
# UI callbacks
# ----------------------------
def on_video_upload(video_path, interval):
if video_path is None:
return None, None, gr.Slider(value=0, maximum=0), None
pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor = process_video_and_features(
video_path, int(interval)
)
fps_in = estimate_video_fps(video_path)
interval_i = max(1, int(interval))
fps_out = max(1.0, fps_in / interval_i)
state = {
"pil_imgs": pil_imgs,
"views": views,
"resize_funcs": resize_funcs,
"must3r_feats": must3r_feats,
"must3r_outputs": must3r_outputs,
"sam2_input_images": sam2_input_images,
"images_tensor": images_tensor,
"current_points": [],
"current_labels": [],
"current_mask": None,
"frame_idx": 0,
"video_path": video_path,
"interval": interval_i,
"fps_in": fps_in,
"fps_out": fps_out,
# tracking outputs (filled later)
"output_video_path": None,
"last_tracking_start_idx": None,
"last_points": None,
"last_labels": None,
}
first_frame = pil_imgs[0]
new_slider = gr.Slider(value=0, maximum=len(pil_imgs) - 1, step=1, interactive=True)
return first_frame, state, new_slider, gr.Image(value=first_frame)
def on_slider_change(state, frame_idx):
if not state:
return None
frame_idx = int(frame_idx)
    frame_idx = max(0, min(frame_idx, len(state["pil_imgs"]) - 1))
state["frame_idx"] = frame_idx
state["current_points"] = []
state["current_labels"] = []
state["current_mask"] = None
frame = state["pil_imgs"][frame_idx]
return frame
def on_image_click(state, evt: gr.SelectData, mode):
if not state:
return None
x, y = evt.index
label_map = {
"Positive Point": 1,
"Negative Point": 0,
"Box Top-Left": 2,
"Box Bottom-Right": 3,
}
label = label_map[mode]
state["current_points"].append([x, y])
state["current_labels"].append(label)
frame_pil = state["pil_imgs"][state["frame_idx"]]
vis_img = draw_points(frame_pil, state["current_points"], state["current_labels"])
if state["current_mask"] is not None:
vis_img = overlay_mask(vis_img, state["current_mask"])
return vis_img
def on_generate_mask_click(state):
if not state:
return None
if not state["current_points"]:
raise gr.Error("No points or boxes annotated.")
num_tl = state["current_labels"].count(2)
num_br = state["current_labels"].count(3)
if num_tl != num_br or num_tl > 1:
raise gr.Error(f"Incomplete box: TL={num_tl}, BR={num_br}. Must match and be <= 1.")
frame_idx = state["frame_idx"]
full_tensor = state["sam2_input_images"]
frame_tensor = full_tensor[frame_idx].unsqueeze(0)
original_size = state["pil_imgs"][frame_idx].size
mask = generate_frame_mask(
frame_tensor,
state["current_points"],
state["current_labels"],
original_size,
)
state["current_mask"] = mask
frame_pil = state["pil_imgs"][frame_idx]
vis_img = overlay_mask(frame_pil, mask)
vis_img = draw_points(vis_img, state["current_points"], state["current_labels"])
state["current_vis_img"] = vis_img.copy()
return vis_img
def reset_annotations(state):
if not state:
return None
state["current_points"] = []
state["current_labels"] = []
state["current_mask"] = None
frame_idx = state["frame_idx"]
return state["pil_imgs"][frame_idx]
def on_track_click(state):
if not state or state["current_mask"] is None:
raise gr.Error("Generate a mask first.")
num_tl = state["current_labels"].count(2)
num_br = state["current_labels"].count(3)
if num_tl != num_br:
raise gr.Error("Incomplete box annotations.")
start_idx = int(state["frame_idx"])
first_frame_mask = state["current_mask"]
tracked_masks_dict = run_tracking(
state["sam2_input_images"],
state["must3r_feats"],
state["must3r_outputs"],
start_idx,
first_frame_mask,
)
output_path = create_video_from_masks(
state["pil_imgs"],
tracked_masks_dict,
fps=state.get("fps_out", 24.0),
)
state["output_video_path"] = output_path
state["last_tracking_start_idx"] = start_idx
state["last_points"] = list(state.get("current_points", []))
state["last_labels"] = list(state.get("current_labels", []))
print(f"Tracking Complete")
return output_path, state
def on_save_cache_click(state):
    try:
        key = save_full_cache_from_state(state)
    except (ValueError, FileNotFoundError) as e:
        raise gr.Error(str(e))
    return f"Saved cache key: {key}"
# ----------------------------
# UI layout
# ----------------------------
description = """
<div style="text-align: center;">
<h1>3AM: 3egment Anything with Geometric Consistency in Videos</h1>
<p>Cache-builder UI: run the full pipeline, then save caches for user-facing examples.</p>
</div>
"""
with gr.Blocks(title="3AM Cache Builder") as app:
gr.HTML(description)
app_state = gr.State()
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("## Step 1 — Upload video")
video_input = gr.Video(label="Upload Video", sources=["upload"], height=512)
gr.Markdown("## Step 2 — Set interval, then load frames")
interval_slider = gr.Slider(
label="Frame Interval",
minimum=1,
maximum=30,
step=1,
value=1,
)
load_btn = gr.Button("Load Frames", variant="primary")
process_status = gr.Textbox(label="Status", value="1) Upload a video.", interactive=False)
with gr.Column(scale=2):
gr.Markdown("## Step 3 — Annotate frame & generate mask")
img_display = gr.Image(label="Annotate Frame", interactive=True, height=512)
frame_slider = gr.Slider(label="Select Frame", minimum=0, maximum=100, step=1, value=0)
with gr.Row():
mode_radio = gr.Radio(
choices=["Positive Point", "Negative Point", "Box Top-Left", "Box Bottom-Right"],
value="Positive Point",
label="Annotation Mode",
)
with gr.Column():
gen_mask_btn = gr.Button("Generate Mask", variant="primary", interactive=False)
reset_btn = gr.Button("Reset Annotations", interactive=False)
gr.Markdown("## Step 4 — Track & Save Cache")
with gr.Row():
track_btn = gr.Button("Start Tracking", variant="primary", interactive=False)
save_cache_btn = gr.Button("Save Cache", variant="secondary", interactive=False)
with gr.Row():
video_output = gr.Video(label="Tracking Output", autoplay=True, height=512)
cache_status = gr.Textbox(label="Cache", value="", interactive=False)
# ------------------------
# Events
# ------------------------
def on_video_uploaded(video_path):
n_frames = estimate_total_frames(video_path)
        default_interval = min(30, max(1, n_frames // 100))
return (
gr.update(value=default_interval, maximum=min(30, n_frames)),
f"Video uploaded ({n_frames} frames). 2) Adjust interval, then click 'Load Frames'.",
)
video_input.upload(fn=on_video_uploaded, inputs=video_input, outputs=[interval_slider, process_status])
load_btn.click(
fn=lambda: (
"Loading frames...",
gr.update(interactive=False),
gr.update(interactive=False),
gr.update(interactive=False),
gr.update(interactive=False), # save_cache_btn
gr.update(value=""),
),
outputs=[process_status, gen_mask_btn, reset_btn, track_btn, save_cache_btn, cache_status],
).then(
fn=on_video_upload,
inputs=[video_input, interval_slider],
outputs=[img_display, app_state, frame_slider, img_display],
).then(
fn=lambda: (
"Ready. 3) Annotate and generate mask.",
gr.update(interactive=True),
gr.update(interactive=True),
gr.update(interactive=True),
),
outputs=[process_status, gen_mask_btn, reset_btn, track_btn],
)
frame_slider.change(fn=on_slider_change, inputs=[app_state, frame_slider], outputs=[img_display])
img_display.select(fn=on_image_click, inputs=[app_state, mode_radio], outputs=[img_display])
gen_mask_btn.click(fn=on_generate_mask_click, inputs=[app_state], outputs=[img_display])
reset_btn.click(fn=reset_annotations, inputs=[app_state], outputs=[img_display])
track_btn.click(
fn=lambda: (
"Tracking in progress...",
gr.update(interactive=False),
gr.update(interactive=False),
),
outputs=[process_status, track_btn, save_cache_btn],
).then(
fn=on_track_click,
inputs=[app_state],
outputs=[video_output, app_state],
).then(
fn=lambda: (
"Tracking complete. You can save cache.",
gr.update(interactive=True), # track_btn
gr.update(interactive=True), # save_cache_btn
),
outputs=[process_status, track_btn, save_cache_btn],
)
save_cache_btn.click(fn=on_save_cache_click, inputs=[app_state], outputs=[cache_status])
if __name__ == "__main__":
app.launch()