# 3AM / app.py — Gradio demo app (nycu-cplab), revision 3ab2b45
import spaces
import subprocess
import sys, os
from pathlib import Path
import math
import pickle
from typing import Any, Dict, List, Tuple, Optional
import importlib, site
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw
import cv2
import logging
# ============================================================
# Bootstrap (same style as your original app.py)
# ============================================================
ROOT = Path(__file__).resolve().parent
SAM2 = ROOT / "sam2-src"
CKPT = SAM2 / "checkpoints" / "sam2.1_hiera_large.pt"

# Download the SAM2 checkpoint on first run only.
if not CKPT.exists():
    subprocess.check_call(["bash", "download_ckpts.sh"], cwd=SAM2 / "checkpoints")

# Install the vendored sam2 package (editable) if it is not importable yet.
try:
    import sam2.build_sam  # noqa: F401
except ModuleNotFoundError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src"], cwd=ROOT)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src[notebooks]"], cwd=ROOT)

# Build and install the ASMK package (Cython extension) if missing.
try:
    import asmk.index  # noqa: F401
except Exception:
    subprocess.check_call(["cythonize", "*.pyx"], cwd="./asmk-src/cython")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "./asmk-src", "--no-build-isolation"])

# Fetch private assets (example cache, etc.) from the Hugging Face Hub once.
if not os.path.exists("./private"):
    from huggingface_hub import snapshot_download
    snapshot_download(
        repo_id="nycu-cplab/3AM",
        local_dir="./private",
        repo_type="model",
    )

# Re-scan site-packages so the just-installed packages resolve in this process.
for sp in site.getsitepackages():
    site.addsitedir(sp)
importlib.invalidate_caches()
# ============================================================
# Logging
# ============================================================
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    # Log to stdout so the hosting platform captures the messages.
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("app_user")
# ============================================================
# Engine imports
# ============================================================
from engine import ( # noqa: E402
get_predictors,
get_views,
prepare_sam2_inputs,
must3r_features_and_output,
get_single_frame_mask,
get_tracked_masks,
)
# ============================================================
# Globals
# ============================================================
PREDICTOR_ORIGINAL = None  # lazily-created SAM2 predictor (see load_models)
PREDICTOR = None  # lazily-created tracking predictor (see load_models)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# HACK: disable autograd for the whole process lifetime; the context manager is
# deliberately entered and never exited because this app only runs inference.
torch.no_grad().__enter__()
def load_models():
    """Lazily create the global predictors on first use.

    Returns the (PREDICTOR_ORIGINAL, PREDICTOR) pair, creating both via
    get_predictors() only when either is still unset.
    """
    global PREDICTOR_ORIGINAL, PREDICTOR
    if PREDICTOR_ORIGINAL is None or PREDICTOR is None:
        logger.info(f"Initializing models on device: {DEVICE}...")
        PREDICTOR_ORIGINAL, PREDICTOR = get_predictors(device=DEVICE)
        logger.info("Models loaded successfully.")
    return PREDICTOR_ORIGINAL, PREDICTOR
def to_device_nested(x: Any, device: str) -> Any:
    """Recursively move every tensor inside nested dict/list/tuple containers to *device*.

    Non-tensor leaves are returned unchanged; container types are preserved.
    """
    if torch.is_tensor(x):
        return x.to(device)
    if isinstance(x, (list, tuple)):
        moved = [to_device_nested(item, device) for item in x]
        return moved if isinstance(x, list) else tuple(moved)
    if isinstance(x, dict):
        return {key: to_device_nested(val, device) for key, val in x.items()}
    return x
# ============================================================
# Helper Functions
# ============================================================
def video_to_frames(video_path, interval=1):
    """Decode *video_path* and return every *interval*-th frame as an RGB PIL image."""
    logger.info(f"Extracting frames from video: {video_path} with interval {interval}")
    capture = cv2.VideoCapture(video_path)
    sampled = []
    total = 0
    while capture.isOpened():
        ok, bgr = capture.read()
        if not ok:
            break
        if total % interval == 0:
            # OpenCV decodes BGR; convert to RGB before wrapping in PIL.
            sampled.append(Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)))
        total += 1
    capture.release()
    logger.info(f"Extracted {len(sampled)} frames (sampled from {total} total frames).")
    return sampled
def draw_points(image_pil, points, labels):
    """Return a copy of *image_pil* with a colored dot drawn at each point.

    Label colors: 1=positive (green), 0=negative (red), 2=box top-left (blue),
    3=box bottom-right (cyan); any other label is drawn yellow.
    """
    palette = {1: "green", 0: "red", 2: "blue", 3: "cyan"}
    radius = 15
    annotated = image_pil.copy()
    draw = ImageDraw.Draw(annotated)
    for (x, y), lbl in zip(points, labels):
        fill = palette.get(lbl, "yellow")
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=fill, outline="white")
    return annotated
def overlay_mask(image_pil, mask, color=(255, 0, 0), alpha=0.5):
    """Alpha-blend a binary *mask* over *image_pil*; returns a new PIL image.

    A None mask returns the image unchanged; a mask of a different size is
    nearest-neighbor resized to match the image first.
    """
    if mask is None:
        return image_pil
    base = np.array(image_pil)
    h, w = base.shape[:2]
    binary = mask > 0
    if binary.shape[0] != h or binary.shape[1] != w:
        binary = cv2.resize(binary.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
    tinted = base.copy()
    tinted[binary] = np.array(color, dtype=np.uint8)
    blended = cv2.addWeighted(tinted, alpha, base, 1 - alpha, 0)
    return Image.fromarray(blended)
def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4", fps=24):
    """Write an mp4 where each frame carries its tracked mask overlay (if any).

    *masks_dict* maps frame index -> mask; frames without an entry are written
    unmodified. Returns the output path, or None when *frames* is empty.
    """
    logger.info(f"Creating video output at {output_path} with {len(frames)} frames.")
    if not frames:
        logger.warning("No frames to create video.")
        return None
    fps = float(fps)
    if not (fps > 0.0):  # also catches NaN
        fps = 24.0
    height, width = np.array(frames[0]).shape[:2]
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    for idx, frame in enumerate(frames):
        mask = masks_dict.get(idx)
        rendered = frame if mask is None else overlay_mask(frame, mask, color=(255, 0, 0), alpha=0.6)
        writer.write(cv2.cvtColor(np.array(rendered), cv2.COLOR_RGB2BGR))
    writer.release()
    logger.info("Video creation complete.")
    return output_path
# ============================================================
# Runtime estimation
# ============================================================
def estimate_video_fps(video_path: str) -> float:
    """Return the container-reported FPS of *video_path*, defaulting to 24 when unknown."""
    capture = cv2.VideoCapture(video_path)
    reported = float(capture.get(cv2.CAP_PROP_FPS)) or 0.0
    capture.release()
    return reported if reported > 0.0 else 24.0
def estimate_total_frames(video_path: str) -> int:
    """Return the container-reported frame count of *video_path* (at least 1)."""
    capture = cv2.VideoCapture(video_path)
    count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
    capture.release()
    return max(1, count)
MAX_GPU_SECONDS = 600  # hard ceiling for one ZeroGPU allocation request


def clamp_duration(sec: int) -> int:
    """Clamp a requested GPU duration (seconds) into [1, MAX_GPU_SECONDS]."""
    return int(max(1, min(MAX_GPU_SECONDS, sec)))
def get_duration_must3r_features(video_path, interval):
    """Estimate GPU seconds for feature extraction (~2 s per processed frame)."""
    step = max(1, int(interval))
    processed = math.ceil(estimate_total_frames(video_path) / step)
    return clamp_duration(int(processed * 2))
def get_duration_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """Estimate GPU seconds for tracking (~2 s per frame in the batch)."""
    try:
        num_frames = int(sam2_input_images.shape[0])
    except Exception:
        num_frames = 100  # fallback when the input exposes no usable shape
    return clamp_duration(int(num_frames * 2))
# ============================================================
# GPU Wrapped Functions
# ============================================================
@spaces.GPU(duration=get_duration_must3r_features)
def process_video_and_features(video_path, interval):
    """GPU job: sample frames, run MUSt3R feature extraction, prepare SAM2 inputs.

    Returns (frames, views, resize_funcs, must3r_feats, must3r_outputs,
    sam2_input_images, images_tensor); raises ValueError on an undecodable video.
    """
    logger.info(f"Starting GPU process: Video feature extraction (Interval: {interval})")
    load_models()
    frames = video_to_frames(video_path, interval=max(1, int(interval)))
    if not frames:
        raise ValueError("Could not extract frames from video.")
    views, resize_funcs = get_views(frames)
    feats, outputs = must3r_features_and_output(views, device=DEVICE)
    sam2_inputs, images_tensor = prepare_sam2_inputs(views, frames, resize_funcs)
    return frames, views, resize_funcs, feats, outputs, sam2_inputs, images_tensor
@spaces.GPU
def generate_frame_mask(image_tensor, points, labels, original_size):
    """GPU job: segment one frame from point/box prompts; returns a 2D numpy mask."""
    logger.info(f"Generating mask for single frame. Points: {len(points)}")
    load_models()
    # Move the image and prompts to the GPU, adding a batch dimension to the prompts.
    image_tensor = image_tensor.to(DEVICE)
    prompt_pts = torch.tensor(points, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    prompt_lbls = torch.tensor(labels, dtype=torch.int32).unsqueeze(0).to(DEVICE)
    # Rescale click coordinates from the original image size to the 1024-pixel
    # resolution of the SAM2 input tensor.
    w, h = original_size
    prompt_pts[..., 0] /= (w / 1024.0)
    prompt_pts[..., 1] /= (h / 1024.0)
    mask = get_single_frame_mask(
        image=image_tensor,
        predictor_original=PREDICTOR_ORIGINAL,
        points=prompt_pts,
        labels=prompt_lbls,
        device=DEVICE,
    )
    return mask.squeeze().cpu().numpy()
@spaces.GPU(duration=get_duration_tracking)
def run_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """GPU job: propagate *first_frame_mask* from frame *start_idx* through the video.

    Returns the result of get_tracked_masks, consumed downstream as an
    index -> mask mapping by create_video_from_masks.
    """
    logger.info(f"Starting tracking from frame index {start_idx}...")
    load_models()
    # Ensure everything is on GPU (cached examples load from CPU)
    sam2_input_images = sam2_input_images.to(DEVICE)
    must3r_feats = to_device_nested(must3r_feats, DEVICE)
    must3r_outputs = to_device_nested(must3r_outputs, DEVICE)
    # Binarize the mask prompt on the target device.
    mask_tensor = torch.tensor(first_frame_mask).to(DEVICE) > 0
    tracked_masks = get_tracked_masks(
        sam2_input_images=sam2_input_images,
        must3r_feats=must3r_feats,
        must3r_outputs=must3r_outputs,
        start_idx=start_idx,
        first_frame_mask=mask_tensor,
        predictor=PREDICTOR,
        predictor_original=PREDICTOR_ORIGINAL,
        device=DEVICE,
    )
    logger.info(f"Tracking complete. Generated masks for {len(tracked_masks)} frames.")
    return tracked_masks
# ============================================================
# Cache loader (Examples)
# ============================================================
CACHE_ROOT = Path("./private/cache")
def _read_meta(meta_path: Path) -> Dict[str, Any]:
with open(meta_path, "rb") as f:
return pickle.load(f)
def _load_frames_from_dir(frames_dir: Path) -> List[Image.Image]:
    """Load every *.jpg in *frames_dir* (sorted by filename) as an RGB PIL image."""
    return [Image.open(path).convert("RGB") for path in sorted(frames_dir.glob("*.jpg"))]
def list_example_dirs() -> List[Path]:
    """Return cache subdirectories holding a complete example (meta, tensors, video)."""
    if not CACHE_ROOT.exists():
        return []
    required = ("meta.pkl", "state_tensors.pt", "output_tracking.mp4")
    return [
        entry
        for entry in sorted(CACHE_ROOT.iterdir())
        if entry.is_dir() and all((entry / name).exists() for name in required)
    ]
# ============================================================
# Cache loader (Examples) - GALLERY VERSION
# ============================================================
def build_examples_gallery():
    """Build gallery data for examples.

    Returns:
        (gallery_items, cache_index): gallery_items is a list of
        (thumbnail_path, caption) tuples for gr.Gallery; cache_index maps a
        gallery position to the paths/metadata needed to load that example.
    """
    gallery_items = []
    cache_index = {}
    for d in list_example_dirs():
        cache_id = d.name
        meta = _read_meta(d / "meta.pkl")
        frames_dir = d / "frames"
        # Prefer the pre-rendered visualization as thumbnail; fall back to the
        # first extracted frame, and skip the example when neither exists.
        thumb = d / "vis_img.png"
        if not thumb.exists():
            jpgs = sorted(frames_dir.glob("*.jpg"))
            if not jpgs:
                continue
            thumb = jpgs[0]
        caption = f"{meta.get('num_frames', 0)} Frames"
        gallery_items.append((str(thumb), caption))
        # BUGFIX: key by the item's gallery position, not the directory
        # enumeration index. When an example is skipped above, the two diverge
        # and gr.SelectData.index (a gallery position) would load the wrong
        # example or miss the key entirely.
        cache_index[len(gallery_items) - 1] = {
            "cache_id": cache_id,
            "dir": d,
            "meta": meta,
            "video_mp4": str(d / "output_tracking.mp4"),
            "frames_dir": frames_dir,
            "tensors": str(d / "state_tensors.pt"),
        }
    # Use the module logger (not print) so this appears in the app logs.
    logger.info(f"Found {len(gallery_items)} example directories.")
    return gallery_items, cache_index
def load_cache_into_state(row_idx: int, cache_index: Dict[int, Dict[str, Any]]):
    """Rebuild the per-session app state from a pre-computed example cache.

    Args:
        row_idx: gallery index of the selected example.
        cache_index: mapping produced by build_examples_gallery().

    Returns:
        (state, vis_img, slider, mp4_path, interval) where vis_img is frame 0
        with the cached mask and annotation points drawn on it.
    """
    info = cache_index[row_idx]
    meta = info["meta"]
    cache_id = info["cache_id"]
    pil_imgs = _load_frames_from_dir(info["frames_dir"])
    if not pil_imgs:
        raise gr.Error("Example frames not found or empty.")
    # Load cached tensors onto CPU; the GPU-wrapped functions move them later.
    tensors = torch.load(info["tensors"], map_location="cpu")
    views, resize_funcs = get_views(pil_imgs)
    fps_in = float(meta.get("fps_in", 24.0))
    fps_out = float(meta.get("fps_out", 24.0))
    interval = int(meta.get("interval", 1))
    points = meta.get("points", [])
    labels = meta.get("labels", [])
    first_frame_mask = meta.get("first_frame_mask", None)
    state = {
        "pil_imgs": pil_imgs,
        "views": views,
        "resize_funcs": resize_funcs,
        "must3r_feats": tensors["must3r_feats"],
        "must3r_outputs": tensors["must3r_outputs"],
        "sam2_input_images": tensors["sam2_input_images"],
        "images_tensor": tensors["images_tensor"],
        "current_points": points,
        "current_labels": labels,
        "current_mask": first_frame_mask,
        "frame_idx": 0,
        "video_path": meta.get("video_name", "example"),
        "interval": interval,
        "fps_in": fps_in,
        "fps_out": fps_out,
        "output_video_path": info["video_mp4"],
        "loaded_from_cache": True,
        "cache_id": cache_id,
    }
    # Visualize frame 0 with the cached mask and annotation points.
    vis_img = overlay_mask(pil_imgs[0], state["current_mask"])
    vis_img = draw_points(vis_img, state["current_points"], state["current_labels"])
    slider = gr.Slider(value=0, maximum=len(pil_imgs) - 1, step=1, interactive=True)
    # NOTE(review): the returned interval is hard-coded to 1 (not meta's
    # interval) — presumably because cached frames are already sampled; confirm.
    return state, vis_img, slider, info["video_mp4"], 1
def on_example_select(evt: gr.SelectData, cache_index_state):
    """Handle gallery selection."""
    loaded = load_cache_into_state(evt.index, cache_index_state)
    state, vis_img, slider, mp4_path, interval = loaded
    return (
        vis_img,
        state,
        slider,
        mp4_path,
        gr.update(value=interval),
        "Ready. Example loaded.",
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
# ============================================================
# UI callbacks (same semantics as your original app.py)
# ============================================================
def on_video_uploaded(video_path):
    """Suggest a frame interval (~100 processed frames) right after an upload."""
    n_frames = estimate_total_frames(video_path)
    default_interval = max(1, n_frames // 100)
    status = f"Video uploaded ({n_frames} frames). 2) Adjust interval, then click 'Load Frames'."
    return gr.update(value=default_interval, maximum=min(30, n_frames)), status
def on_video_upload_and_load(video_path, interval):
    """Run the GPU feature-extraction pipeline and build a fresh app state.

    Returns (first frame, state, updated frame slider, image component update);
    all-None-ish placeholders when no video is selected.
    """
    logger.info(f"User uploaded video: {video_path}, Interval: {interval}")
    if video_path is None:
        return None, None, gr.Slider(value=0, maximum=0), None
    pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor = process_video_and_features(
        video_path, int(interval)
    )
    # Output FPS compensates for frame subsampling so playback speed matches.
    fps_in = estimate_video_fps(video_path)
    interval_i = max(1, int(interval))
    fps_out = max(1.0, fps_in / interval_i)
    state = {
        "pil_imgs": pil_imgs,
        "views": views,
        "resize_funcs": resize_funcs,
        "must3r_feats": must3r_feats,
        "must3r_outputs": must3r_outputs,
        "sam2_input_images": sam2_input_images,
        "images_tensor": images_tensor,
        "current_points": [],
        "current_labels": [],
        "current_mask": None,
        "frame_idx": 0,
        "video_path": video_path,
        "interval": interval_i,
        "fps_in": fps_in,
        "fps_out": fps_out,
        "output_video_path": None,
        "loaded_from_cache": False,
    }
    first_frame = pil_imgs[0]
    new_slider = gr.Slider(value=0, maximum=len(pil_imgs) - 1, step=1, interactive=True)
    return first_frame, state, new_slider, gr.Image(value=first_frame)
def on_slider_change(state, frame_idx):
    """Show the newly selected frame and clear any in-progress annotations."""
    if not state:
        return None
    frames = state["pil_imgs"]
    # Clamp to the last frame in case the slider overshoots the frame count.
    frame_idx = min(int(frame_idx), len(frames) - 1)
    state["frame_idx"] = frame_idx
    state["current_points"] = []
    state["current_labels"] = []
    state["current_mask"] = None
    return frames[frame_idx]
def on_image_click(state, evt: gr.SelectData, mode):
    """Record a click as an annotation for the current frame and redraw it."""
    if not state:
        return None
    x, y = evt.index
    # Map the UI annotation mode to its numeric label.
    label = {
        "Positive Point": 1,
        "Negative Point": 0,
        "Box Top-Left": 2,
        "Box Bottom-Right": 3,
    }[mode]
    state["current_points"].append([x, y])
    state["current_labels"].append(label)
    frame_pil = state["pil_imgs"][state["frame_idx"]]
    vis_img = draw_points(frame_pil, state["current_points"], state["current_labels"])
    if state["current_mask"] is not None:
        vis_img = overlay_mask(vis_img, state["current_mask"])
    return vis_img
def on_generate_mask_click(state):
    """Validate the annotations, run SAM2 on the selected frame, and redraw it."""
    if not state:
        return None
    if not state["current_points"]:
        raise gr.Error("No points or boxes annotated.")
    labels = state["current_labels"]
    # At most one box is allowed, and its corners must come in pairs.
    num_tl = labels.count(2)
    num_br = labels.count(3)
    if num_tl != num_br or num_tl > 1:
        raise gr.Error(f"Incomplete box detected! TL={num_tl}, BR={num_br}. Must match and be <= 1.")
    frame_idx = state["frame_idx"]
    frame_tensor = state["sam2_input_images"][frame_idx].unsqueeze(0)
    original_size = state["pil_imgs"][frame_idx].size
    mask = generate_frame_mask(
        frame_tensor,
        state["current_points"],
        labels,
        original_size,
    )
    state["current_mask"] = mask
    vis_img = overlay_mask(state["pil_imgs"][frame_idx], mask)
    return draw_points(vis_img, state["current_points"], labels)
def reset_annotations(state):
    """Drop all annotations and the current mask; return the clean current frame."""
    if not state:
        return None
    state.update(current_points=[], current_labels=[], current_mask=None)
    return state["pil_imgs"][state["frame_idx"]]
def on_track_click(state):
    """Propagate the mask from the annotated frame and render the output video."""
    if not state or state["current_mask"] is None:
        raise gr.Error("Please annotate a frame and generate a mask first.")
    labels = state["current_labels"]
    if labels.count(2) != labels.count(3):
        raise gr.Error("Incomplete box annotations.")
    tracked_masks_dict = run_tracking(
        state["sam2_input_images"],
        state["must3r_feats"],
        state["must3r_outputs"],
        state["frame_idx"],
        state["current_mask"],
    )
    output_path = create_video_from_masks(
        state["pil_imgs"],
        tracked_masks_dict,
        fps=state.get("fps_out", 24.0),
    )
    state["output_video_path"] = output_path
    return output_path
# ============================================================
# App Layout (match original, add Examples at bottom)
# ============================================================
# HTML header rendered at the top of the Gradio page.
description = """
<div style="text-align: center;">
<h1>3AM: Segment Anything with Geometric Consistency in Videos </h1>
<p>Upload a video, extract geometric features, annotate a frame, and track the object.</p>
</div>
"""
with gr.Blocks(title="3AM: 3egment Anything") as app:
    gr.HTML(description)
    gr.Markdown(
        """
**Workflow**
1) Upload video
2) Adjust frame interval → Load frames
3) Annotate & generate mask
4) Track through the video
"""
    )
    # Per-session state dict (frames, features, annotations, output path, ...).
    app_state = gr.State()

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Step 1 — Upload video")
            video_input = gr.Video(
                label="Upload Video",
                sources=["upload"],
                height=512,
            )
            gr.Markdown("## Step 2 — Set interval, then load frames")
            interval_slider = gr.Slider(
                label="Frame Interval",
                minimum=1,
                maximum=30,
                step=1,
                value=1,
                info="Default ≈ total_frames / 100",
            )
            load_btn = gr.Button("Load Frames", variant="primary")
            process_status = gr.Textbox(
                label="Status",
                value="1) Upload a video.",
                interactive=False,
            )
        with gr.Column(scale=2):
            gr.Markdown("## Step 3 — Annotate frame & generate mask")
            img_display = gr.Image(
                label="Annotate Frame",
                interactive=True,
                height=512,
            )
            frame_slider = gr.Slider(
                label="Select Frame",
                minimum=0,
                maximum=100,
                step=1,
                value=0,
            )
            with gr.Row():
                mode_radio = gr.Radio(
                    choices=[
                        "Positive Point",
                        "Negative Point",
                        "Box Top-Left",
                        "Box Bottom-Right",
                    ],
                    value="Positive Point",
                    label="Annotation Mode",
                )
                with gr.Column():
                    # Disabled until frames are loaded (re-enabled by the
                    # load_btn event chain below).
                    gen_mask_btn = gr.Button(
                        "Generate Mask",
                        variant="primary",
                        interactive=False,
                    )
                    reset_btn = gr.Button(
                        "Reset Annotations",
                        interactive=False,
                    )
            gr.Markdown("## Step 4 — Track through the video")
            with gr.Row():
                track_btn = gr.Button(
                    "Start Tracking",
                    variant="primary",
                    scale=1,
                    interactive=False,
                )
            with gr.Row():
                video_output = gr.Video(
                    label="Tracking Output",
                    autoplay=True,
                    height=512,
                )

    # -------------------------
    # Examples table at bottom
    # -------------------------
    gr.Markdown("## Examples (click to load)")
    gallery_items, cache_index = build_examples_gallery()
    cache_index_state = gr.State(cache_index)
    if gallery_items:
        examples_gallery = gr.Gallery(
            value=gallery_items,
            label="Examples",
            container=True,
            columns=6,
            object_fit="contain",
            show_label=False,
        )
        # Clicking a thumbnail loads the cached example into the app state.
        examples_gallery.select(
            fn=on_example_select,
            inputs=[cache_index_state],
            outputs=[
                img_display,
                app_state,
                frame_slider,
                video_output,
                interval_slider,
                process_status,
                gen_mask_btn,
                reset_btn,
                track_btn,
            ],
        )
    else:
        gr.Markdown("*No examples available.*")

    # ============================================================
    # Events (original + examples)
    # ============================================================
    video_input.upload(
        fn=on_video_uploaded,
        inputs=video_input,
        outputs=[interval_slider, process_status],
    )
    # Load chain: disable buttons -> heavy GPU processing -> re-enable buttons.
    load_btn.click(
        fn=lambda: (
            "Loading frames...",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
        ),
        outputs=[process_status, gen_mask_btn, reset_btn, track_btn],
    ).then(
        fn=on_video_upload_and_load,
        inputs=[video_input, interval_slider],
        outputs=[img_display, app_state, frame_slider, img_display],
    ).then(
        fn=lambda: (
            "Ready. 3) Annotate and generate mask.",
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[process_status, gen_mask_btn, reset_btn, track_btn],
    )
    frame_slider.change(
        fn=on_slider_change,
        inputs=[app_state, frame_slider],
        outputs=[img_display],
    )
    img_display.select(
        fn=on_image_click,
        inputs=[app_state, mode_radio],
        outputs=[img_display],
    )
    gen_mask_btn.click(
        fn=on_generate_mask_click,
        inputs=[app_state],
        outputs=[img_display],
    )
    reset_btn.click(
        fn=reset_annotations,
        inputs=[app_state],
        outputs=[img_display],
    )
    # Track chain: status message -> run tracking -> completion message.
    track_btn.click(
        fn=lambda: "Tracking in progress...",
        outputs=process_status,
    ).then(
        fn=on_track_click,
        inputs=[app_state],
        outputs=[video_output],
    ).then(
        fn=lambda: "Tracking complete!",
        outputs=process_status,
    )
# Script entry point: launch the Gradio server.
if __name__ == "__main__":
    logger.info("Starting Gradio app...")
    app.launch()