nycu-cplab commited on
Commit
fb1055c
·
1 Parent(s): 9d941d0
Files changed (2) hide show
  1. app.py +443 -286
  2. app_cache.py +675 -0
app.py CHANGED
@@ -1,105 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import spaces
2
  import subprocess
3
  import sys, os
4
  from pathlib import Path
5
  import math
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- ''' loading modules '''
 
 
8
  ROOT = Path(__file__).resolve().parent
9
  SAM2 = ROOT / "sam2-src"
10
  CKPT = SAM2 / "checkpoints" / "sam2.1_hiera_large.pt"
11
- ASMK = ROOT / "asmk"
12
 
13
- ''' download sam2 checkpoints '''
14
  if not CKPT.exists():
15
  subprocess.check_call(["bash", "download_ckpts.sh"], cwd=SAM2 / "checkpoints")
16
 
17
- ''' install sam2 '''
18
  try:
19
- import sam2.build_sam
20
  except ModuleNotFoundError:
21
  subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src"], cwd=ROOT)
22
  subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src[notebooks]"], cwd=ROOT)
23
 
24
- ''' install asmk '''
25
  try:
26
  import asmk.index # noqa: F401
27
- except Exception as e:
28
- subprocess.check_call(
29
- ["cythonize", "*.pyx"], cwd='./asmk-src/cython'
30
- )
31
- subprocess.check_call(
32
- [sys.executable, "-m", "pip", "install", './asmk-src', "--no-build-isolation"]
33
- )
34
 
35
- ''' download some checkpoints '''
36
- if not os.path.exists('./private'):
37
  from huggingface_hub import snapshot_download
38
- local_dir = snapshot_download(
39
  repo_id="nycu-cplab/3AM",
40
  local_dir="./private",
41
  repo_type="model",
42
  )
43
- import importlib, site
44
  for sp in site.getsitepackages():
45
  site.addsitedir(sp)
46
  importlib.invalidate_caches()
47
 
48
- import gradio as gr
49
- import torch
50
- torch.no_grad().__enter__()
51
- import numpy as np
52
- from PIL import Image, ImageDraw
53
- import cv2
54
- import copy
55
- import json
56
- import logging
57
- import sys
58
- # --- Logging Configuration ---
59
  logging.basicConfig(
60
  level=logging.INFO,
61
  format="%(asctime)s [%(levelname)s] %(message)s",
62
- handlers=[
63
- logging.StreamHandler(sys.stdout)
64
- ]
65
  )
66
- logger = logging.getLogger(__name__)
 
67
 
68
- # Import functions from your engine.py
69
- from engine import (
 
 
70
  get_predictors,
71
  get_views,
72
  prepare_sam2_inputs,
73
  must3r_features_and_output,
74
  get_single_frame_mask,
75
- get_tracked_masks
76
  )
77
 
78
- # --- Global Configuration & Model Loading ---
79
 
 
 
 
80
  PREDICTOR_ORIGINAL = None
81
  PREDICTOR = None
82
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
83
 
84
  def load_models():
85
  global PREDICTOR_ORIGINAL, PREDICTOR
86
  if PREDICTOR is None or PREDICTOR_ORIGINAL is None:
87
  logger.info(f"Initializing models on device: {DEVICE}...")
88
- try:
89
- PREDICTOR_ORIGINAL, PREDICTOR = get_predictors(device=DEVICE)
90
- logger.info("Models loaded successfully.")
91
- except Exception as e:
92
- logger.error(f"Failed to load models: {e}")
93
- raise e
94
  return PREDICTOR_ORIGINAL, PREDICTOR
95
 
96
- # --- Helper Functions ---
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def video_to_frames(video_path, interval=1):
99
- """
100
- Extract frames from video path to a list of PIL Images.
101
- Respects the frame interval (e.g., interval=5 takes every 5th frame).
102
- """
103
  logger.info(f"Extracting frames from video: {video_path} with interval {interval}")
104
  cap = cv2.VideoCapture(video_path)
105
  frames = []
@@ -108,68 +136,50 @@ def video_to_frames(video_path, interval=1):
108
  ret, frame = cap.read()
109
  if not ret:
110
  break
111
-
112
- # Only keep frame if it matches the interval
113
  if count % interval == 0:
114
- # Convert BGR to RGB
115
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
116
  frames.append(Image.fromarray(frame_rgb))
117
-
118
  count += 1
119
-
120
  cap.release()
121
  logger.info(f"Extracted {len(frames)} frames (sampled from {count} total frames).")
122
  return frames
123
 
 
124
  def draw_points(image_pil, points, labels):
125
- """Draws visual markers for clicks on the image."""
126
  img_draw = image_pil.copy()
127
  draw = ImageDraw.Draw(img_draw)
128
-
129
- # Radius of points
130
  r = 5
131
-
132
  for pt, lbl in zip(points, labels):
133
  x, y = pt
134
- if lbl == 1: # Positive
135
  color = "green"
136
- elif lbl == 0: # Negative
137
  color = "red"
138
- elif lbl == 2: # Box Top-Left
139
  color = "blue"
140
- elif lbl == 3: # Box Bottom-Right
141
  color = "cyan"
142
  else:
143
  color = "yellow"
144
-
145
- draw.ellipse((x-r, y-r, x+r, y+r), fill=color, outline="white")
146
-
147
  return img_draw
148
 
 
149
  def overlay_mask(image_pil, mask, color=(255, 0, 0), alpha=0.5):
150
- """Overlay a binary mask on a PIL image."""
151
  if mask is None:
152
  return image_pil
153
-
154
- # Ensure mask is bool or 0/1
155
  mask = mask > 0
156
-
157
  img_np = np.array(image_pil)
158
  h, w = img_np.shape[:2]
159
-
160
- # Resize mask to image size if necessary
161
  if mask.shape[0] != h or mask.shape[1] != w:
162
- logger.debug(f"Resizing mask from {mask.shape} to {(h, w)}")
163
  mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
164
-
165
  overlay = img_np.copy()
166
  overlay[mask] = np.array(color, dtype=np.uint8)
167
-
168
  combined = cv2.addWeighted(overlay, alpha, img_np, 1 - alpha, 0)
169
  return Image.fromarray(combined)
170
 
 
171
  def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4", fps=24):
172
- """Combine original frames and tracking masks into a video."""
173
  logger.info(f"Creating video output at {output_path} with {len(frames)} frames.")
174
  if not frames:
175
  logger.warning("No frames to create video.")
@@ -178,9 +188,9 @@ def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4
178
  if not (fps > 0.0):
179
  fps = 24.0
180
  h, w = np.array(frames[0]).shape[:2]
181
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
182
  out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
183
-
184
  for idx, frame in enumerate(frames):
185
  mask = masks_dict.get(idx)
186
  if mask is not None:
@@ -188,26 +198,23 @@ def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4
188
  frame_np = np.array(pil_out)
189
  else:
190
  frame_np = np.array(frame)
191
-
192
  frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
193
  out.write(frame_bgr)
194
-
195
  out.release()
196
  logger.info("Video creation complete.")
197
  return output_path
198
 
199
- # --- GPU Wrapped Functions ---
200
 
 
 
 
201
  def estimate_video_fps(video_path: str) -> float:
202
  cap = cv2.VideoCapture(video_path)
203
  fps = float(cap.get(cv2.CAP_PROP_FPS)) or 0.0
204
  cap.release()
205
- # Robust fallback if metadata is missing
206
  return fps if fps > 0.0 else 24.0
207
 
208
- MAX_GPU_SECONDS = 600 # e.g., 10 minutes
209
- def clamp_duration(sec: int) -> int:
210
- return int(min(MAX_GPU_SECONDS, max(1, sec)))
211
 
212
  def estimate_total_frames(video_path: str) -> int:
213
  cap = cv2.VideoCapture(video_path)
@@ -215,125 +222,250 @@ def estimate_total_frames(video_path: str) -> int:
215
  cap.release()
216
  return max(1, n)
217
 
 
 
 
 
 
 
 
 
218
  def get_duration_must3r_features(video_path, interval):
219
- # interval is applied to the entire pipeline, so actual processed frames ~= ceil(total / interval)
220
  total = estimate_total_frames(video_path)
221
  interval = max(1, int(interval))
222
  processed = math.ceil(total / interval)
223
-
224
- # Tune this coefficient based on your observed runtime on ZeroGPU
225
  sec_per_frame = 2
226
  return clamp_duration(int(processed * sec_per_frame))
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  @spaces.GPU(duration=get_duration_must3r_features)
229
  def process_video_and_features(video_path, interval):
230
- """Load video, subsample frames, get views, MUSt3R features, SAM2 inputs."""
231
  logger.info(f"Starting GPU process: Video feature extraction (Interval: {interval})")
232
  load_models()
233
-
234
- # Pass interval to subsample frames immediately
235
- pil_imgs = video_to_frames(video_path, interval=interval)
236
  if not pil_imgs:
237
  raise ValueError("Could not extract frames from video.")
238
 
239
- logger.info("Step 1/3: Getting views...")
240
  views, resize_funcs = get_views(pil_imgs)
241
- # Ensure consistent resizing
242
- pil_imgs_resized = [resize_funcs[0].transforms[0](p) for p in pil_imgs]
243
-
244
- logger.info("Step 2/3: Extracting MUSt3R features...")
245
  must3r_feats, must3r_outputs = must3r_features_and_output(views, device=DEVICE)
246
- logger.debug(f"MUSt3R features extracted. Output keys: {must3r_outputs.keys()}")
247
-
248
- logger.info("Step 3/3: Preparing SAM2 inputs...")
249
  sam2_input_images, images_tensor = prepare_sam2_inputs(views, pil_imgs, resize_funcs)
250
- logger.debug(f"SAM2 input shape: {sam2_input_images.shape}")
251
-
252
- logger.info("Feature extraction complete.")
253
  return pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor
254
 
 
255
  @spaces.GPU
256
  def generate_frame_mask(image_tensor, points, labels, original_size):
257
- """Generate mask for a single frame based on clicks."""
258
  logger.info(f"Generating mask for single frame. Points: {len(points)}")
259
  load_models()
260
-
 
 
 
261
  pts_tensor = torch.tensor(points, dtype=torch.float32).unsqueeze(0).to(DEVICE)
262
  lbl_tensor = torch.tensor(labels, dtype=torch.int32).unsqueeze(0).to(DEVICE)
263
-
264
  w, h = original_size
265
- # Normalize points
266
  pts_tensor[..., 0] /= (w / 1024.0)
267
  pts_tensor[..., 1] /= (h / 1024.0)
268
 
269
- try:
270
- mask = get_single_frame_mask(
271
- image=image_tensor,
272
- predictor_original=PREDICTOR_ORIGINAL,
273
- points=pts_tensor,
274
- labels=lbl_tensor,
275
- device=DEVICE
276
- )
277
- logger.info("Mask generation successful.")
278
- mask_np = mask.squeeze().cpu().numpy()
279
- return mask_np
280
- except Exception as e:
281
- logger.error(f"Error during mask generation: {e}")
282
- raise e
283
-
284
- def get_duration_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
285
- # sam2_input_images is already subsampled, so this is the true number of frames to track
286
- try:
287
- n = int(getattr(sam2_input_images, "shape")[0])
288
- except Exception:
289
- n = 100 # fallback if something unexpected is passed
290
 
291
- sec_per_frame = 2
292
- return clamp_duration(int(n * sec_per_frame))
293
 
294
  @spaces.GPU(duration=get_duration_tracking)
295
  def run_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
296
- """Track the mask across the video."""
297
  logger.info(f"Starting tracking from frame index {start_idx}...")
298
  load_models()
299
-
 
 
 
 
 
300
  mask_tensor = torch.tensor(first_frame_mask).to(DEVICE) > 0
301
-
302
- try:
303
- tracked_masks = get_tracked_masks(
304
- sam2_input_images=sam2_input_images,
305
- must3r_feats=must3r_feats,
306
- must3r_outputs=must3r_outputs,
307
- start_idx=start_idx,
308
- first_frame_mask=mask_tensor,
309
- predictor=PREDICTOR,
310
- predictor_original=PREDICTOR_ORIGINAL,
311
- device=DEVICE
312
- )
313
- logger.info(f"Tracking complete. Generated masks for {len(tracked_masks)} frames.")
314
- return tracked_masks
315
- except Exception as e:
316
- logger.error(f"Error during tracking: {e}")
317
- raise e
318
 
319
- # --- Gradio Callbacks ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
- def on_video_upload(video_path, interval):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  logger.info(f"User uploaded video: {video_path}, Interval: {interval}")
323
  if video_path is None:
324
  return None, None, gr.Slider(value=0, maximum=0), None
325
-
326
- try:
327
- pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor = process_video_and_features(video_path, int(interval))
328
- except Exception as e:
329
- logger.error(f"Failed to process video: {e}")
330
- raise gr.Error(f"Processing failed: {str(e)}")
331
-
332
  fps_in = estimate_video_fps(video_path)
333
  interval_i = max(1, int(interval))
334
  fps_out = max(1.0, fps_in / interval_i)
335
 
336
- # Initialize state
337
  state = {
338
  "pil_imgs": pil_imgs,
339
  "views": views,
@@ -349,160 +481,165 @@ def on_video_upload(video_path, interval):
349
  "video_path": video_path,
350
  "interval": interval_i,
351
  "fps_in": fps_in,
352
- "fps_out": fps_out
 
 
353
  }
354
-
355
  first_frame = pil_imgs[0]
356
- new_slider = gr.Slider(value=0, maximum=len(pil_imgs)-1, step=1, interactive=True)
357
  return first_frame, state, new_slider, gr.Image(value=first_frame)
358
 
 
359
  def on_slider_change(state, frame_idx):
360
  if not state:
361
  return None
362
-
363
  if frame_idx >= len(state["pil_imgs"]):
364
  frame_idx = len(state["pil_imgs"]) - 1
365
-
366
  state["frame_idx"] = frame_idx
367
  state["current_points"] = []
368
  state["current_labels"] = []
369
  state["current_mask"] = None
370
-
371
- frame = state["pil_imgs"][frame_idx]
372
- return frame
373
 
374
  def on_image_click(state, evt: gr.SelectData, mode):
375
- """
376
- Registers the click, updates state, and draws the point/box corner.
377
- Does NOT generate the mask.
378
- """
379
  if not state:
380
  return None
381
-
382
  x, y = evt.index
383
- logger.info(f"User clicked at ({x}, {y}) with mode: {mode}")
384
-
385
  label_map = {
386
  "Positive Point": 1,
387
  "Negative Point": 0,
388
  "Box Top-Left": 2,
389
- "Box Bottom-Right": 3
390
  }
391
  label = label_map[mode]
392
-
393
- # Update State
394
  state["current_points"].append([x, y])
395
  state["current_labels"].append(label)
396
-
397
- # Visual Feedback Only (Draw points)
398
  frame_pil = state["pil_imgs"][state["frame_idx"]]
399
  vis_img = draw_points(frame_pil, state["current_points"], state["current_labels"])
400
-
401
- # Keep old mask visible if it exists, but don't update it yet
402
  if state["current_mask"] is not None:
403
  vis_img = overlay_mask(vis_img, state["current_mask"])
404
-
405
  return vis_img
406
 
 
407
  def on_generate_mask_click(state):
408
- """
409
- Called when 'Generate Mask' button is clicked.
410
- Validates inputs (box completion) and triggers GPU mask generation.
411
- """
412
  if not state:
413
  return None
414
-
415
- logger.info("Generate Mask button clicked.")
416
-
417
  if not state["current_points"]:
418
  raise gr.Error("No points or boxes annotated.")
419
 
420
- # --- BOX VALIDATION LOGIC ---
421
  num_tl = state["current_labels"].count(2)
422
  num_br = state["current_labels"].count(3)
423
-
424
  if num_tl != num_br or num_tl > 1:
425
- logger.warning(f"Box mismatch: TL={num_tl}, BR={num_br}")
426
- raise gr.Error(f"Incomplete box detected! You have {num_tl} top-left(s) and {num_br} bottom-right(s). They must match and be <= 1.")
427
 
428
- # Proceed to inference
429
  frame_idx = state["frame_idx"]
430
  full_tensor = state["sam2_input_images"]
431
- frame_tensor = full_tensor[frame_idx].unsqueeze(0)
432
- original_size = state["pil_imgs"][frame_idx].size
433
-
434
- try:
435
- mask = generate_frame_mask(
436
- frame_tensor,
437
- state["current_points"],
438
- state["current_labels"],
439
- original_size
440
- )
441
- except Exception as e:
442
- logger.error(f"Mask generation failed: {e}")
443
- raise gr.Error("Failed to generate mask.")
444
-
445
  state["current_mask"] = mask
446
-
447
- # Visualization: Draw Mask AND Points
448
  frame_pil = state["pil_imgs"][frame_idx]
449
  vis_img = overlay_mask(frame_pil, mask)
450
  vis_img = draw_points(vis_img, state["current_points"], state["current_labels"])
451
-
452
  return vis_img
453
 
 
 
 
 
 
 
 
 
 
 
 
454
  def on_track_click(state):
455
- logger.info("Track button clicked.")
456
  if not state or state["current_mask"] is None:
457
- logger.warning("Track attempted without mask/state.")
458
  raise gr.Error("Please annotate a frame and generate a mask first.")
459
-
460
- # Double check box consistency just in case
461
  num_tl = state["current_labels"].count(2)
462
  num_br = state["current_labels"].count(3)
463
  if num_tl != num_br:
464
  raise gr.Error("Incomplete box annotations.")
465
-
466
  start_idx = state["frame_idx"]
467
  first_frame_mask = state["current_mask"]
468
-
469
- try:
470
- tracked_masks_dict = run_tracking(
471
- state["sam2_input_images"],
472
- state["must3r_feats"],
473
- state["must3r_outputs"],
474
- start_idx,
475
- first_frame_mask
476
- )
477
-
478
- output_path = create_video_from_masks(
479
- state["pil_imgs"],
480
- tracked_masks_dict,
481
- fps=state.get("fps_out", 24.0),
482
- )
483
- return output_path
484
- except Exception as e:
485
- logger.error(f"Tracking failed in UI callback: {e}")
486
- raise gr.Error(f"Tracking failed: {str(e)}")
487
 
488
- def reset_annotations(state):
489
- if not state:
490
- return None
491
- logger.info("Resetting annotations for current frame.")
492
- state["current_points"] = []
493
- state["current_labels"] = []
494
- state["current_mask"] = None
495
- frame_idx = state["frame_idx"]
496
- return state["pil_imgs"][frame_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
 
498
- # --- App Layout ---
499
 
 
 
 
 
 
 
 
 
 
500
  description = """
501
  <div style="text-align: center;">
502
- <h1>3AM: 3egment Anything with Geometric Consistency in Videos</h1>
503
- <p>Upload a video, geometric features are extracted automatically. Select a frame, click to annotate objects, and track them in 3D-consistent space.</p>
504
  </div>
505
  """
 
506
  with gr.Blocks(title="3AM: 3egment Anything") as app:
507
  gr.HTML(description)
508
 
@@ -513,11 +650,12 @@ with gr.Blocks(title="3AM: 3egment Anything") as app:
513
  1) Upload video
514
  2) Adjust frame interval → Load frames
515
  3) Annotate & generate mask
516
- 4) Track through the video
517
  """
518
  )
519
 
520
  app_state = gr.State()
 
521
 
522
  with gr.Row():
523
  with gr.Column(scale=1):
@@ -525,7 +663,7 @@ with gr.Blocks(title="3AM: 3egment Anything") as app:
525
  video_input = gr.Video(
526
  label="Upload Video",
527
  sources=["upload"],
528
- height=512
529
  )
530
 
531
  gr.Markdown("## Step 2 — Set interval, then load frames")
@@ -535,18 +673,15 @@ with gr.Blocks(title="3AM: 3egment Anything") as app:
535
  maximum=30,
536
  step=1,
537
  value=1,
538
- info="Default ≈ total_frames / 100"
539
  )
540
 
541
- load_btn = gr.Button(
542
- "Load Frames",
543
- variant="primary"
544
- )
545
 
546
  process_status = gr.Textbox(
547
  label="Status",
548
  value="1) Upload a video.",
549
- interactive=False
550
  )
551
 
552
  with gr.Column(scale=2):
@@ -554,7 +689,7 @@ with gr.Blocks(title="3AM: 3egment Anything") as app:
554
  img_display = gr.Image(
555
  label="Annotate Frame",
556
  interactive=True,
557
- height=512
558
  )
559
 
560
  frame_slider = gr.Slider(
@@ -562,7 +697,7 @@ with gr.Blocks(title="3AM: 3egment Anything") as app:
562
  minimum=0,
563
  maximum=100,
564
  step=1,
565
- value=0
566
  )
567
 
568
  with gr.Row():
@@ -574,17 +709,17 @@ with gr.Blocks(title="3AM: 3egment Anything") as app:
574
  "Box Bottom-Right",
575
  ],
576
  value="Positive Point",
577
- label="Annotation Mode"
578
  )
579
  with gr.Column():
580
  gen_mask_btn = gr.Button(
581
  "Generate Mask",
582
  variant="primary",
583
- interactive=False
584
  )
585
  reset_btn = gr.Button(
586
  "Reset Annotations",
587
- interactive=False
588
  )
589
 
590
  gr.Markdown("## Step 4 — Track through the video")
@@ -593,37 +728,43 @@ with gr.Blocks(title="3AM: 3egment Anything") as app:
593
  "Start Tracking",
594
  variant="primary",
595
  scale=1,
596
- interactive=False
597
  )
598
 
599
  with gr.Row():
600
  video_output = gr.Video(
601
  label="Tracking Output",
602
  autoplay=True,
603
- height=512
604
  )
605
 
606
- # ------------------------------------------------
607
- # Events
608
- # ------------------------------------------------
609
-
610
- # Upload: only read metadata & set default interval
611
- def on_video_uploaded(video_path):
612
- n_frames = estimate_total_frames(video_path)
613
- default_interval = max(1, n_frames // 100)
614
- return (
615
- gr.update(value=default_interval, maximum=min(30, n_frames)),
616
- f"Video uploaded ({n_frames} frames). "
617
- "2) Adjust interval, then click 'Load Frames'."
618
- )
 
 
 
619
 
 
 
 
 
620
  video_input.upload(
621
  fn=on_video_uploaded,
622
  inputs=video_input,
623
- outputs=[interval_slider, process_status]
624
  )
625
 
626
- # Load frames: heavy compute happens here
627
  load_btn.click(
628
  fn=lambda: (
629
  "Loading frames...",
@@ -631,11 +772,11 @@ with gr.Blocks(title="3AM: 3egment Anything") as app:
631
  gr.update(interactive=False),
632
  gr.update(interactive=False),
633
  ),
634
- outputs=[process_status, gen_mask_btn, reset_btn, track_btn]
635
  ).then(
636
- fn=on_video_upload,
637
  inputs=[video_input, interval_slider],
638
- outputs=[img_display, app_state, frame_slider, img_display]
639
  ).then(
640
  fn=lambda: (
641
  "Ready. 3) Annotate and generate mask.",
@@ -643,46 +784,62 @@ with gr.Blocks(title="3AM: 3egment Anything") as app:
643
  gr.update(interactive=True),
644
  gr.update(interactive=True),
645
  ),
646
- outputs=[process_status, gen_mask_btn, reset_btn, track_btn]
647
  )
648
 
649
  frame_slider.change(
650
  fn=on_slider_change,
651
  inputs=[app_state, frame_slider],
652
- outputs=[img_display]
653
  )
654
 
655
  img_display.select(
656
  fn=on_image_click,
657
  inputs=[app_state, mode_radio],
658
- outputs=[img_display]
659
  )
660
 
661
  gen_mask_btn.click(
662
  fn=on_generate_mask_click,
663
  inputs=[app_state],
664
- outputs=[img_display]
665
  )
666
 
667
  reset_btn.click(
668
  fn=reset_annotations,
669
  inputs=[app_state],
670
- outputs=[img_display]
671
  )
672
 
673
  track_btn.click(
674
  fn=lambda: "Tracking in progress...",
675
- outputs=process_status
676
  ).then(
677
  fn=on_track_click,
678
  inputs=[app_state],
679
- outputs=[video_output]
680
  ).then(
681
  fn=lambda: "Tracking complete!",
682
- outputs=process_status
683
  )
684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685
 
686
  if __name__ == "__main__":
687
  logger.info("Starting Gradio app...")
688
- app.launch()
 
1
+ # app_user.py
2
+ # User-facing app:
3
+ # - Same workflow as original app.py (upload -> set interval -> Load Frames -> annotate -> Generate Mask -> Track)
4
+ # - Adds an Examples table at the bottom
5
+ # - Loads examples from ./private/cache/*
6
+ # - Each row shows the first-frame thumbnail
7
+ # - Clicking a row instantly loads the cached example (state + precomputed output mp4)
8
+ #
9
+ # Expected cache structure per example directory:
10
+ # ./private/cache/<cache_id>/
11
+ # meta.pkl
12
+ # frames/000000.jpg (thumbnail) + more frames
13
+ # state_tensors.pt (must3r_feats, must3r_outputs, sam2_input_images, images_tensor) saved on CPU
14
+ # output_tracking.mp4
15
+ #
16
+ # Notes:
17
+ # - tracked_masks_dict is not required.
18
+ # - views/resize_funcs are recomputed on load (cheap vs must3r/tracking).
19
+
20
  import spaces
21
  import subprocess
22
  import sys, os
23
  from pathlib import Path
24
  import math
25
+ import pickle
26
+ from typing import Any, Dict, List, Tuple, Optional
27
+
28
+ import importlib, site
29
+
30
+ import gradio as gr
31
+ import torch
32
+ import numpy as np
33
+ from PIL import Image, ImageDraw
34
+ import cv2
35
+ import logging
36
+
37
 
38
+ # ============================================================
39
+ # Bootstrap (same style as your original app.py)
40
+ # ============================================================
41
  ROOT = Path(__file__).resolve().parent
42
  SAM2 = ROOT / "sam2-src"
43
  CKPT = SAM2 / "checkpoints" / "sam2.1_hiera_large.pt"
 
44
 
 
45
  if not CKPT.exists():
46
  subprocess.check_call(["bash", "download_ckpts.sh"], cwd=SAM2 / "checkpoints")
47
 
 
48
  try:
49
+ import sam2.build_sam # noqa: F401
50
  except ModuleNotFoundError:
51
  subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src"], cwd=ROOT)
52
  subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src[notebooks]"], cwd=ROOT)
53
 
 
54
  try:
55
  import asmk.index # noqa: F401
56
+ except Exception:
57
+ subprocess.check_call(["cythonize", "*.pyx"], cwd="./asmk-src/cython")
58
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "./asmk-src", "--no-build-isolation"])
 
 
 
 
59
 
60
+ if not os.path.exists("./private"):
 
61
  from huggingface_hub import snapshot_download
62
+ snapshot_download(
63
  repo_id="nycu-cplab/3AM",
64
  local_dir="./private",
65
  repo_type="model",
66
  )
67
+
68
  for sp in site.getsitepackages():
69
  site.addsitedir(sp)
70
  importlib.invalidate_caches()
71
 
72
+
73
+ # ============================================================
74
+ # Logging
75
+ # ============================================================
 
 
 
 
 
 
 
76
  logging.basicConfig(
77
  level=logging.INFO,
78
  format="%(asctime)s [%(levelname)s] %(message)s",
79
+ handlers=[logging.StreamHandler(sys.stdout)],
 
 
80
  )
81
+ logger = logging.getLogger("app_user")
82
+
83
 
84
+ # ============================================================
85
+ # Engine imports
86
+ # ============================================================
87
+ from engine import ( # noqa: E402
88
  get_predictors,
89
  get_views,
90
  prepare_sam2_inputs,
91
  must3r_features_and_output,
92
  get_single_frame_mask,
93
+ get_tracked_masks,
94
  )
95
 
 
96
 
97
+ # ============================================================
98
+ # Globals
99
+ # ============================================================
100
  PREDICTOR_ORIGINAL = None
101
  PREDICTOR = None
102
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
103
+ torch.no_grad().__enter__()
104
+
105
 
106
  def load_models():
107
  global PREDICTOR_ORIGINAL, PREDICTOR
108
  if PREDICTOR is None or PREDICTOR_ORIGINAL is None:
109
  logger.info(f"Initializing models on device: {DEVICE}...")
110
+ PREDICTOR_ORIGINAL, PREDICTOR = get_predictors(device=DEVICE)
111
+ logger.info("Models loaded successfully.")
 
 
 
 
112
  return PREDICTOR_ORIGINAL, PREDICTOR
113
 
 
114
 
115
+ def to_device_nested(x: Any, device: str) -> Any:
116
+ if torch.is_tensor(x):
117
+ return x.to(device)
118
+ if isinstance(x, dict):
119
+ return {k: to_device_nested(v, device) for k, v in x.items()}
120
+ if isinstance(x, list):
121
+ return [to_device_nested(v, device) for v in x]
122
+ if isinstance(x, tuple):
123
+ return tuple(to_device_nested(v, device) for v in x)
124
+ return x
125
+
126
+
127
+ # ============================================================
128
+ # Helper Functions
129
+ # ============================================================
130
  def video_to_frames(video_path, interval=1):
 
 
 
 
131
  logger.info(f"Extracting frames from video: {video_path} with interval {interval}")
132
  cap = cv2.VideoCapture(video_path)
133
  frames = []
 
136
  ret, frame = cap.read()
137
  if not ret:
138
  break
 
 
139
  if count % interval == 0:
 
140
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
141
  frames.append(Image.fromarray(frame_rgb))
 
142
  count += 1
 
143
  cap.release()
144
  logger.info(f"Extracted {len(frames)} frames (sampled from {count} total frames).")
145
  return frames
146
 
147
+
148
  def draw_points(image_pil, points, labels):
 
149
  img_draw = image_pil.copy()
150
  draw = ImageDraw.Draw(img_draw)
 
 
151
  r = 5
 
152
  for pt, lbl in zip(points, labels):
153
  x, y = pt
154
+ if lbl == 1:
155
  color = "green"
156
+ elif lbl == 0:
157
  color = "red"
158
+ elif lbl == 2:
159
  color = "blue"
160
+ elif lbl == 3:
161
  color = "cyan"
162
  else:
163
  color = "yellow"
164
+ draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline="white")
 
 
165
  return img_draw
166
 
167
+
168
  def overlay_mask(image_pil, mask, color=(255, 0, 0), alpha=0.5):
 
169
  if mask is None:
170
  return image_pil
 
 
171
  mask = mask > 0
 
172
  img_np = np.array(image_pil)
173
  h, w = img_np.shape[:2]
 
 
174
  if mask.shape[0] != h or mask.shape[1] != w:
 
175
  mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
 
176
  overlay = img_np.copy()
177
  overlay[mask] = np.array(color, dtype=np.uint8)
 
178
  combined = cv2.addWeighted(overlay, alpha, img_np, 1 - alpha, 0)
179
  return Image.fromarray(combined)
180
 
181
+
182
  def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4", fps=24):
 
183
  logger.info(f"Creating video output at {output_path} with {len(frames)} frames.")
184
  if not frames:
185
  logger.warning("No frames to create video.")
 
188
  if not (fps > 0.0):
189
  fps = 24.0
190
  h, w = np.array(frames[0]).shape[:2]
191
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
192
  out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
193
+
194
  for idx, frame in enumerate(frames):
195
  mask = masks_dict.get(idx)
196
  if mask is not None:
 
198
  frame_np = np.array(pil_out)
199
  else:
200
  frame_np = np.array(frame)
 
201
  frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
202
  out.write(frame_bgr)
203
+
204
  out.release()
205
  logger.info("Video creation complete.")
206
  return output_path
207
 
 
208
 
209
+ # ============================================================
210
+ # Runtime estimation
211
+ # ============================================================
212
  def estimate_video_fps(video_path: str) -> float:
213
  cap = cv2.VideoCapture(video_path)
214
  fps = float(cap.get(cv2.CAP_PROP_FPS)) or 0.0
215
  cap.release()
 
216
  return fps if fps > 0.0 else 24.0
217
 
 
 
 
218
 
219
  def estimate_total_frames(video_path: str) -> int:
220
  cap = cv2.VideoCapture(video_path)
 
222
  cap.release()
223
  return max(1, n)
224
 
225
+
226
+ MAX_GPU_SECONDS = 600
227
+
228
+
229
+ def clamp_duration(sec: int) -> int:
230
+ return int(min(MAX_GPU_SECONDS, max(1, sec)))
231
+
232
+
233
  def get_duration_must3r_features(video_path, interval):
 
234
  total = estimate_total_frames(video_path)
235
  interval = max(1, int(interval))
236
  processed = math.ceil(total / interval)
 
 
237
  sec_per_frame = 2
238
  return clamp_duration(int(processed * sec_per_frame))
239
 
240
+
241
+ def get_duration_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
242
+ try:
243
+ n = int(getattr(sam2_input_images, "shape")[0])
244
+ except Exception:
245
+ n = 100
246
+ sec_per_frame = 2
247
+ return clamp_duration(int(n * sec_per_frame))
248
+
249
+
250
+ # ============================================================
251
+ # GPU Wrapped Functions
252
+ # ============================================================
253
  @spaces.GPU(duration=get_duration_must3r_features)
254
  def process_video_and_features(video_path, interval):
 
255
  logger.info(f"Starting GPU process: Video feature extraction (Interval: {interval})")
256
  load_models()
257
+
258
+ pil_imgs = video_to_frames(video_path, interval=max(1, int(interval)))
 
259
  if not pil_imgs:
260
  raise ValueError("Could not extract frames from video.")
261
 
 
262
  views, resize_funcs = get_views(pil_imgs)
263
+
 
 
 
264
  must3r_feats, must3r_outputs = must3r_features_and_output(views, device=DEVICE)
265
+
 
 
266
  sam2_input_images, images_tensor = prepare_sam2_inputs(views, pil_imgs, resize_funcs)
267
+
 
 
268
  return pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor
269
 
270
+
271
@spaces.GPU
def generate_frame_mask(image_tensor, points, labels, original_size):
    """GPU stage 2: run SAM2 on one frame prompted by the user's clicks.

    Args:
        image_tensor: Single-frame batch already preprocessed for SAM2.
        points: [[x, y], ...] click coordinates in original-frame pixels.
        labels: Per-point codes (1/0 positive/negative, 2/3 box corners).
        original_size: (width, height) of the displayed frame; used to map
            clicks into the model's 1024x1024 input space.

    Returns:
        The predicted mask as a squeezed NumPy array on CPU.
    """
    logger.info(f"Generating mask for single frame. Points: {len(points)}")
    load_models()

    # Ensure tensors are on GPU (cached examples are stored on CPU).
    image_tensor = image_tensor.to(DEVICE)

    pts_tensor = torch.tensor(points, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    lbl_tensor = torch.tensor(labels, dtype=torch.int32).unsqueeze(0).to(DEVICE)

    # Rescale click coordinates from original resolution to the 1024x1024 input.
    w, h = original_size
    pts_tensor[..., 0] /= (w / 1024.0)
    pts_tensor[..., 1] /= (h / 1024.0)

    mask = get_single_frame_mask(
        image=image_tensor,
        predictor_original=PREDICTOR_ORIGINAL,
        points=pts_tensor,
        labels=lbl_tensor,
        device=DEVICE,
    )
    return mask.squeeze().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
 
 
295
 
296
@spaces.GPU(duration=get_duration_tracking)
def run_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """GPU stage 3: propagate the approved first-frame mask through the clip.

    Args:
        sam2_input_images: Stacked preprocessed frames (may arrive on CPU).
        must3r_feats / must3r_outputs: Nested feature containers from stage 1.
        start_idx: Index of the annotated frame tracking starts from.
        first_frame_mask: User-approved mask (array-like; binarized below).

    Returns:
        Mapping of frame index -> mask as produced by get_tracked_masks.
    """
    logger.info(f"Starting tracking from frame index {start_idx}...")
    load_models()

    # Ensure everything is on GPU (cached examples load from CPU).
    # NOTE(review): to_device_nested is defined elsewhere in this file -- confirm it
    # handles every container type present in the MUSt3R outputs.
    sam2_input_images = sam2_input_images.to(DEVICE)
    must3r_feats = to_device_nested(must3r_feats, DEVICE)
    must3r_outputs = to_device_nested(must3r_outputs, DEVICE)

    # Binarize the prompt mask on the GPU.
    mask_tensor = torch.tensor(first_frame_mask).to(DEVICE) > 0

    tracked_masks = get_tracked_masks(
        sam2_input_images=sam2_input_images,
        must3r_feats=must3r_feats,
        must3r_outputs=must3r_outputs,
        start_idx=start_idx,
        first_frame_mask=mask_tensor,
        predictor=PREDICTOR,
        predictor_original=PREDICTOR_ORIGINAL,
        device=DEVICE,
    )
    logger.info(f"Tracking complete. Generated masks for {len(tracked_masks)} frames.")
    return tracked_masks
320
+
321
+
322
+ # ============================================================
323
+ # Cache loader (Examples)
324
+ # ============================================================
325
+ CACHE_ROOT = Path("./private/cache")
326
+
327
+
328
def _read_meta(meta_path: Path) -> Dict[str, Any]:
    """Deserialize an example's meta.pkl.

    NOTE: pickle is only acceptable here because the cache is generated locally
    by app_cache.py, never from untrusted input.
    """
    with meta_path.open("rb") as fh:
        return pickle.load(fh)
331
+
332
+
333
def _load_frames_from_dir(frames_dir: Path) -> List[Image.Image]:
    """Load cached frames (frames/*.jpg) in filename order as RGB PIL images."""
    return [
        Image.open(path).convert("RGB")
        for path in sorted(frames_dir.glob("*.jpg"))
    ]
338
+
339
+
340
def list_example_dirs() -> List[Path]:
    """Return cache subdirectories holding a complete example, sorted by name."""
    if not CACHE_ROOT.exists():
        return []
    required = ("meta.pkl", "state_tensors.pt", "output_tracking.mp4")
    return [
        entry
        for entry in sorted(CACHE_ROOT.iterdir())
        if entry.is_dir() and all((entry / name).exists() for name in required)
    ]
350
+
351
+
352
def build_examples_table():
    """Scan the cache and build the examples table rows plus a lookup index.

    Each row: [thumbnail_path, video_name, interval, num_frames, cache_id].

    Returns:
        (rows, cache_index) where cache_index maps cache_id -> paths/metadata.
    """
    rows, cache_index = [], {}

    for example_dir in list_example_dirs():
        cache_id = example_dir.name
        meta = _read_meta(example_dir / "meta.pkl")
        frames_dir = example_dir / "frames"

        # Prefer the canonical first frame as thumbnail; fall back to any jpg.
        thumb = frames_dir / "000000.jpg"
        if not thumb.exists():
            candidates = sorted(frames_dir.glob("*.jpg"))
            if not candidates:
                continue  # no frames cached -> nothing to show
            thumb = candidates[0]

        rows.append([
            str(thumb),                       # image cell
            meta.get("video_name", cache_id),
            int(meta.get("interval", 1)),
            int(meta.get("num_frames", 0)),
            cache_id,                         # hidden but kept
        ])

        cache_index[cache_id] = {
            "dir": example_dir,
            "meta": meta,
            "video_mp4": str(example_dir / "output_tracking.mp4"),
            "frames_dir": frames_dir,
            "tensors": str(example_dir / "state_tensors.pt"),
        }

    return rows, cache_index
393
+
394
+
395
+
396
def load_cache_into_state(cache_id: str, cache_index: Dict[str, Dict[str, Any]]) -> Tuple[Dict[str, Any], Image.Image, gr.Slider, str, int]:
    """Rebuild the full session state dict from a precomputed example cache.

    Args:
        cache_id: Key into cache_index (directory name under the cache root).
        cache_index: Mapping produced by build_examples_table.

    Returns:
        (state, first_frame, slider_component, output_mp4_path, interval).

    Raises:
        gr.Error: On unknown cache_id or an empty frames directory.
    """
    if cache_id not in cache_index:
        raise gr.Error(f"Unknown cache_id: {cache_id}")

    info = cache_index[cache_id]
    meta = info["meta"]

    pil_imgs = _load_frames_from_dir(info["frames_dir"])
    if not pil_imgs:
        raise gr.Error("Example frames not found or empty.")

    # Heavy tensors were cached on CPU; run_tracking moves them to GPU later.
    tensors = torch.load(info["tensors"], map_location="cpu")

    # Recompute lightweight parts (views/resize_funcs are not pickled).
    views, resize_funcs = get_views(pil_imgs)

    fps_in = float(meta.get("fps_in", 24.0))
    fps_out = float(meta.get("fps_out", 24.0))
    interval = int(meta.get("interval", 1))

    state = {
        "pil_imgs": pil_imgs,
        "views": views,
        "resize_funcs": resize_funcs,
        "must3r_feats": tensors["must3r_feats"],
        "must3r_outputs": tensors["must3r_outputs"],
        "sam2_input_images": tensors["sam2_input_images"],
        "images_tensor": tensors["images_tensor"],
        "current_points": [],
        "current_labels": [],
        "current_mask": None,
        "frame_idx": 0,
        "video_path": meta.get("video_name", "example"),
        "interval": interval,
        "fps_in": fps_in,
        "fps_out": fps_out,
        # precomputed output
        "output_video_path": info["video_mp4"],
        "loaded_from_cache": True,
        "cache_id": cache_id,
    }

    first_frame = pil_imgs[0]
    slider = gr.Slider(value=0, maximum=len(pil_imgs) - 1, step=1, interactive=True)

    return state, first_frame, slider, info["video_mp4"], interval
442
+
443
+
444
+ # ============================================================
445
+ # UI callbacks (same semantics as your original app.py)
446
+ # ============================================================
447
def on_video_uploaded(video_path):
    """Suggest a default sampling interval (~100 frames) right after upload."""
    total = estimate_total_frames(video_path)
    suggested = max(1, total // 100)
    status = f"Video uploaded ({total} frames). 2) Adjust interval, then click 'Load Frames'."
    return gr.update(value=suggested, maximum=min(30, total)), status
454
+
455
+
456
+ def on_video_upload_and_load(video_path, interval):
457
  logger.info(f"User uploaded video: {video_path}, Interval: {interval}")
458
  if video_path is None:
459
  return None, None, gr.Slider(value=0, maximum=0), None
460
+
461
+ pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor = process_video_and_features(
462
+ video_path, int(interval)
463
+ )
464
+
 
 
465
  fps_in = estimate_video_fps(video_path)
466
  interval_i = max(1, int(interval))
467
  fps_out = max(1.0, fps_in / interval_i)
468
 
 
469
  state = {
470
  "pil_imgs": pil_imgs,
471
  "views": views,
 
481
  "video_path": video_path,
482
  "interval": interval_i,
483
  "fps_in": fps_in,
484
+ "fps_out": fps_out,
485
+ "output_video_path": None,
486
+ "loaded_from_cache": False,
487
  }
488
+
489
  first_frame = pil_imgs[0]
490
+ new_slider = gr.Slider(value=0, maximum=len(pil_imgs) - 1, step=1, interactive=True)
491
  return first_frame, state, new_slider, gr.Image(value=first_frame)
492
 
493
+
494
def on_slider_change(state, frame_idx):
    """Select a frame: clamp the index, clear stale annotations, return the frame."""
    if not state:
        return None
    idx = min(int(frame_idx), len(state["pil_imgs"]) - 1)
    state.update(
        frame_idx=idx,
        current_points=[],
        current_labels=[],
        current_mask=None,
    )
    return state["pil_imgs"][idx]
507
+
508
 
509
def on_image_click(state, evt: gr.SelectData, mode):
    """Record a click (point or box corner) and redraw the annotated frame."""
    if not state:
        return None

    x, y = evt.index
    # Radio-button label -> SAM2 prompt label code.
    mode_to_label = {
        "Positive Point": 1,
        "Negative Point": 0,
        "Box Top-Left": 2,
        "Box Bottom-Right": 3,
    }
    state["current_points"].append([x, y])
    state["current_labels"].append(mode_to_label[mode])

    base = state["pil_imgs"][state["frame_idx"]]
    annotated = draw_points(base, state["current_points"], state["current_labels"])
    if state["current_mask"] is None:
        return annotated
    return overlay_mask(annotated, state["current_mask"])
530
 
531
+
532
def on_generate_mask_click(state):
    """Run SAM2 on the current frame with the accumulated prompts; preview the mask."""
    if not state:
        return None
    if not state["current_points"]:
        raise gr.Error("No points or boxes annotated.")

    # A box requires exactly one top-left (2) paired with one bottom-right (3).
    num_tl = state["current_labels"].count(2)
    num_br = state["current_labels"].count(3)
    if num_tl != num_br or num_tl > 1:
        raise gr.Error(f"Incomplete box detected! TL={num_tl}, BR={num_br}. Must match and be <= 1.")

    idx = state["frame_idx"]
    frame_batch = state["sam2_input_images"][idx].unsqueeze(0)
    mask = generate_frame_mask(
        frame_batch,
        state["current_points"],
        state["current_labels"],
        state["pil_imgs"][idx].size,
    )
    state["current_mask"] = mask

    preview = overlay_mask(state["pil_imgs"][idx], mask)
    return draw_points(preview, state["current_points"], state["current_labels"])
561
 
562
+
563
def reset_annotations(state):
    """Discard all clicks and the preview mask; return the clean current frame."""
    if not state:
        return None
    state.update(current_points=[], current_labels=[], current_mask=None)
    return state["pil_imgs"][state["frame_idx"]]
571
+
572
+
573
def on_track_click(state):
    """Propagate the approved mask through the whole clip and render the result.

    Returns:
        Path of the rendered mp4 (also stored in state["output_video_path"]).

    Raises:
        gr.Error: When no mask has been generated yet or box clicks are unpaired.
    """
    if not state or state["current_mask"] is None:
        raise gr.Error("Please annotate a frame and generate a mask first.")

    # Boxes must come in complete TL/BR pairs before tracking.
    num_tl = state["current_labels"].count(2)
    num_br = state["current_labels"].count(3)
    if num_tl != num_br:
        raise gr.Error("Incomplete box annotations.")

    start_idx = state["frame_idx"]
    first_frame_mask = state["current_mask"]

    # Heavy GPU stage: mask propagation over all frames.
    tracked_masks_dict = run_tracking(
        state["sam2_input_images"],
        state["must3r_feats"],
        state["must3r_outputs"],
        start_idx,
        first_frame_mask,
    )

    # CPU stage: burn the masks into the frames and encode the output clip.
    output_path = create_video_from_masks(
        state["pil_imgs"],
        tracked_masks_dict,
        fps=state.get("fps_out", 24.0),
    )
    state["output_video_path"] = output_path
    return output_path
600
+
601
+
602
+ # ============================================================
603
+ # Examples UI: row click handler
604
+ # ============================================================
605
def on_example_row_click(evt: gr.SelectData, cache_index_state):
    """Load a precomputed example when a row of the examples table is clicked.

    Returns updates for: annotated image, app state, frame slider, output
    video, interval slider, status text, and the three action buttons.
    """
    # row = [thumb_path, video_name, interval, frames, cache_id]
    # FIX: depending on the Gradio version, a Dataframe select event exposes
    # either the full clicked row (evt.row_value) or only the clicked cell
    # (evt.value). Prefer row_value so cache_id (column 4) is reachable no
    # matter which cell the user clicked.
    row = getattr(evt, "row_value", None) or evt.value

    cache_id = row[4]
    state, first_frame, slider, mp4_path, interval = load_cache_into_state(
        cache_id, cache_index_state
    )

    return (
        first_frame,
        state,
        slider,
        mp4_path,
        gr.update(value=interval),
        "Ready. Example loaded.",
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
625
 
 
626
 
627
+ # ============================================================
628
+ # Build examples at startup
629
+ # ============================================================
630
+ examples_rows, cache_index = build_examples_table()
631
+
632
+
633
+ # ============================================================
634
+ # App Layout (match original, add Examples at bottom)
635
+ # ============================================================
636
  description = """
637
  <div style="text-align: center;">
638
+ <h1>3AM: 3egment Anything</h1>
639
+ <p>Upload a video, extract geometric features, annotate a frame, and track the object.</p>
640
  </div>
641
  """
642
+
643
  with gr.Blocks(title="3AM: 3egment Anything") as app:
644
  gr.HTML(description)
645
 
 
650
  1) Upload video
651
  2) Adjust frame interval → Load frames
652
  3) Annotate & generate mask
653
+ 4) Track through the video
654
  """
655
  )
656
 
657
  app_state = gr.State()
658
+ cache_index_state = gr.State(cache_index)
659
 
660
  with gr.Row():
661
  with gr.Column(scale=1):
 
663
  video_input = gr.Video(
664
  label="Upload Video",
665
  sources=["upload"],
666
+ height=512,
667
  )
668
 
669
  gr.Markdown("## Step 2 — Set interval, then load frames")
 
673
  maximum=30,
674
  step=1,
675
  value=1,
676
+ info="Default ≈ total_frames / 100",
677
  )
678
 
679
+ load_btn = gr.Button("Load Frames", variant="primary")
 
 
 
680
 
681
  process_status = gr.Textbox(
682
  label="Status",
683
  value="1) Upload a video.",
684
+ interactive=False,
685
  )
686
 
687
  with gr.Column(scale=2):
 
689
  img_display = gr.Image(
690
  label="Annotate Frame",
691
  interactive=True,
692
+ height=512,
693
  )
694
 
695
  frame_slider = gr.Slider(
 
697
  minimum=0,
698
  maximum=100,
699
  step=1,
700
+ value=0,
701
  )
702
 
703
  with gr.Row():
 
709
  "Box Bottom-Right",
710
  ],
711
  value="Positive Point",
712
+ label="Annotation Mode",
713
  )
714
  with gr.Column():
715
  gen_mask_btn = gr.Button(
716
  "Generate Mask",
717
  variant="primary",
718
+ interactive=False,
719
  )
720
  reset_btn = gr.Button(
721
  "Reset Annotations",
722
+ interactive=False,
723
  )
724
 
725
  gr.Markdown("## Step 4 — Track through the video")
 
728
  "Start Tracking",
729
  variant="primary",
730
  scale=1,
731
+ interactive=False,
732
  )
733
 
734
  with gr.Row():
735
  video_output = gr.Video(
736
  label="Tracking Output",
737
  autoplay=True,
738
+ height=512,
739
  )
740
 
741
+ # -------------------------
742
+ # Examples table at bottom
743
+ # -------------------------
744
+ gr.Markdown("## Examples (click a row to load)")
745
+
746
+ examples_df = gr.Dataframe(
747
+ headers=["Example", "Video", "Interval", "Frames", "cache_id"],
748
+ datatype=["image", "str", "number", "number", "str"],
749
+ value=examples_rows,
750
+ row_count=len(examples_rows),
751
+ col_count=(5, "fixed"),
752
+ interactive=False,
753
+ wrap=True,
754
+ visible=True,
755
+ )
756
+ examples_df.style({"display": "none"}, columns=["cache_id"])
757
 
758
+
759
+ # ============================================================
760
+ # Events (original + examples)
761
+ # ============================================================
762
  video_input.upload(
763
  fn=on_video_uploaded,
764
  inputs=video_input,
765
+ outputs=[interval_slider, process_status],
766
  )
767
 
 
768
  load_btn.click(
769
  fn=lambda: (
770
  "Loading frames...",
 
772
  gr.update(interactive=False),
773
  gr.update(interactive=False),
774
  ),
775
+ outputs=[process_status, gen_mask_btn, reset_btn, track_btn],
776
  ).then(
777
+ fn=on_video_upload_and_load,
778
  inputs=[video_input, interval_slider],
779
+ outputs=[img_display, app_state, frame_slider, img_display],
780
  ).then(
781
  fn=lambda: (
782
  "Ready. 3) Annotate and generate mask.",
 
784
  gr.update(interactive=True),
785
  gr.update(interactive=True),
786
  ),
787
+ outputs=[process_status, gen_mask_btn, reset_btn, track_btn],
788
  )
789
 
790
  frame_slider.change(
791
  fn=on_slider_change,
792
  inputs=[app_state, frame_slider],
793
+ outputs=[img_display],
794
  )
795
 
796
  img_display.select(
797
  fn=on_image_click,
798
  inputs=[app_state, mode_radio],
799
+ outputs=[img_display],
800
  )
801
 
802
  gen_mask_btn.click(
803
  fn=on_generate_mask_click,
804
  inputs=[app_state],
805
+ outputs=[img_display],
806
  )
807
 
808
  reset_btn.click(
809
  fn=reset_annotations,
810
  inputs=[app_state],
811
+ outputs=[img_display],
812
  )
813
 
814
  track_btn.click(
815
  fn=lambda: "Tracking in progress...",
816
+ outputs=process_status,
817
  ).then(
818
  fn=on_track_click,
819
  inputs=[app_state],
820
+ outputs=[video_output],
821
  ).then(
822
  fn=lambda: "Tracking complete!",
823
+ outputs=process_status,
824
  )
825
 
826
+ examples_df.select(
827
+ fn=on_example_row_click,
828
+ inputs=[cache_index_state],
829
+ outputs=[
830
+ img_display,
831
+ app_state,
832
+ frame_slider,
833
+ video_output,
834
+ interval_slider,
835
+ process_status,
836
+ gen_mask_btn,
837
+ reset_btn,
838
+ track_btn,
839
+ ],
840
+ )
841
+
842
 
843
  if __name__ == "__main__":
844
  logger.info("Starting Gradio app...")
845
+ app.launch()
app_cache.py ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app_cache.py
2
+ # Purpose:
3
+ # - Same UI flow (upload -> load frames -> annotate -> generate mask -> track)
4
+ # - After tracking, enable "Save Cache"
5
+ # - You can create multiple caches by repeating the workflow
6
+ #
7
+ # Cache contents per example:
8
+ # cache/<key>/
9
+ # meta.pkl
10
+ # frames/*.jpg
11
+ # state_tensors.pt (must3r_feats, must3r_outputs, sam2_input_images, images_tensor) on CPU
12
+ # output_tracking.mp4
13
+ #
14
+ # Notes:
15
+ # - We do NOT pickle views/resize_funcs (recomputed on load).
16
+ # - We store frames as JPEG to avoid pickling PIL and to be deterministic/reloadable.
17
+
18
+ import spaces
19
+ import subprocess
20
+ import sys, os
21
+ from pathlib import Path
22
+ import math
23
+ import hashlib
24
+ import pickle
25
+ from datetime import datetime
26
+ from typing import Any, Dict, List, Tuple
27
+
28
+ import importlib, site
29
+
30
+ import gradio as gr
31
+ import torch
32
+ import numpy as np
33
+ from PIL import Image, ImageDraw
34
+ import cv2
35
+ import logging
36
+
37
+ # ----------------------------
38
+ # Project bootstrap
39
+ # ----------------------------
40
+ ROOT = Path(__file__).resolve().parent
41
+ SAM2 = ROOT / "sam2-src"
42
+ CKPT = SAM2 / "checkpoints" / "sam2.1_hiera_large.pt"
43
+
44
+ # download sam2 checkpoints
45
+ if not CKPT.exists():
46
+ subprocess.check_call(["bash", "download_ckpts.sh"], cwd=SAM2 / "checkpoints")
47
+
48
+ # install sam2
49
+ try:
50
+ import sam2.build_sam # noqa
51
+ except ModuleNotFoundError:
52
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src"], cwd=ROOT)
53
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "./sam2-src[notebooks]"], cwd=ROOT)
54
+
55
+ # install asmk
56
+ try:
57
+ import asmk.index # noqa: F401
58
+ except Exception:
59
+ subprocess.check_call(["cythonize", "*.pyx"], cwd="./asmk-src/cython")
60
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "./asmk-src", "--no-build-isolation"])
61
+
62
+ # download private checkpoints
63
+ if not os.path.exists("./private"):
64
+ from huggingface_hub import snapshot_download
65
+ snapshot_download(
66
+ repo_id="nycu-cplab/3AM",
67
+ local_dir="./private",
68
+ repo_type="model",
69
+ )
70
+
71
+ for sp in site.getsitepackages():
72
+ site.addsitedir(sp)
73
+ importlib.invalidate_caches()
74
+
75
+ # ----------------------------
76
+ # Logging
77
+ # ----------------------------
78
+ logging.basicConfig(
79
+ level=logging.INFO,
80
+ format="%(asctime)s [%(levelname)s] %(message)s",
81
+ handlers=[logging.StreamHandler(sys.stdout)],
82
+ )
83
+ logger = logging.getLogger("app_cache")
84
+
85
+ # ----------------------------
86
+ # Engine imports
87
+ # ----------------------------
88
+ from engine import (
89
+ get_predictors,
90
+ get_views,
91
+ prepare_sam2_inputs,
92
+ must3r_features_and_output,
93
+ get_single_frame_mask,
94
+ get_tracked_masks,
95
+ )
96
+
97
+ # ----------------------------
98
+ # Globals
99
+ # ----------------------------
100
+ PREDICTOR_ORIGINAL = None
101
+ PREDICTOR = None
102
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
103
+
104
def load_models():
    """Lazily build the SAM2 predictor pair once and memoize it in module globals."""
    global PREDICTOR_ORIGINAL, PREDICTOR
    if PREDICTOR_ORIGINAL is not None and PREDICTOR is not None:
        return PREDICTOR_ORIGINAL, PREDICTOR
    logger.info(f"Initializing models on device: {DEVICE}...")
    PREDICTOR_ORIGINAL, PREDICTOR = get_predictors(device=DEVICE)
    logger.info("Models loaded successfully.")
    return PREDICTOR_ORIGINAL, PREDICTOR
111
+
112
+ # Ensure no_grad globally (as you had)
113
+ torch.no_grad().__enter__()
114
+
115
+ # ----------------------------
116
+ # Video / visualization helpers
117
+ # ----------------------------
118
def video_to_frames(video_path, interval=1):
    """Decode a video and return every `interval`-th frame as an RGB PIL image.

    Args:
        video_path: Path readable by OpenCV.
        interval: Sampling stride; values < 1 are treated as 1.

    Returns:
        List of PIL.Image frames in temporal order.
    """
    logger.info(f"Extracting frames from video: {video_path} with interval={interval}")
    # FIX: interval=0 (or negative) would raise ZeroDivisionError in `count % interval`.
    interval = max(1, int(interval))
    cap = cv2.VideoCapture(video_path)
    frames = []
    count = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if count % interval == 0:
                # OpenCV decodes BGR; PIL expects RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
            count += 1
    finally:
        # FIX: release the capture even if decoding raises mid-loop.
        cap.release()
    logger.info(f"Extracted {len(frames)} frames (sampled from {count} total).")
    return frames
134
+
135
def draw_points(image_pil, points, labels):
    """Return a copy of the image with one colored dot per annotation point."""
    # Label -> dot color: 1=positive, 0=negative, 2=box top-left, 3=box bottom-right.
    palette = {1: "green", 0: "red", 2: "blue", 3: "cyan"}
    radius = 5
    canvas = image_pil.copy()
    pen = ImageDraw.Draw(canvas)
    for (px, py), lbl in zip(points, labels):
        pen.ellipse(
            (px - radius, py - radius, px + radius, py + radius),
            fill=palette.get(lbl, "yellow"),
            outline="white",
        )
    return canvas
153
+
154
def overlay_mask(image_pil, mask, color=(255, 0, 0), alpha=0.5):
    """Alpha-blend a binary mask (painted in `color`) over the image; no-op if mask is None."""
    if mask is None:
        return image_pil
    base = np.array(image_pil)
    height, width = base.shape[:2]
    binary = mask > 0
    if binary.shape[:2] != (height, width):
        # Nearest-neighbor keeps the mask hard-edged when rescaling to image size.
        binary = cv2.resize(
            binary.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST
        ).astype(bool)
    painted = base.copy()
    painted[binary] = np.array(color, dtype=np.uint8)
    blended = cv2.addWeighted(painted, alpha, base, 1 - alpha, 0)
    return Image.fromarray(blended)
166
+
167
def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4", fps=24):
    """Render frames to an mp4, overlaying each frame's mask in red when present.

    Returns output_path, or None when `frames` is empty.
    """
    logger.info(f"Creating video output at {output_path} with {len(frames)} frames.")
    if not frames:
        return None
    fps = float(fps)
    if not (fps > 0.0):
        fps = 24.0  # guard against 0/NaN reported by a broken container header
    height, width = np.array(frames[0]).shape[:2]
    writer = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )

    for idx, frame in enumerate(frames):
        mask = masks_dict.get(idx)
        if mask is None:
            rgb = np.array(frame)
        else:
            rgb = np.array(overlay_mask(frame, mask, color=(255, 0, 0), alpha=0.6))
        writer.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

    writer.release()
    return output_path
190
+
191
+ # ----------------------------
192
+ # Runtime estimation helpers
193
+ # ----------------------------
194
def estimate_video_fps(video_path: str) -> float:
    """Return the container-reported FPS, defaulting to 24.0 when unreadable."""
    capture = cv2.VideoCapture(video_path)
    try:
        reported = float(capture.get(cv2.CAP_PROP_FPS)) or 0.0
    finally:
        capture.release()
    return reported if reported > 0.0 else 24.0
199
+
200
def estimate_total_frames(video_path: str) -> int:
    """Return CAP_PROP_FRAME_COUNT for the video, clamped to at least 1."""
    capture = cv2.VideoCapture(video_path)
    total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
    capture.release()
    return total if total > 1 else 1
205
+
206
# Upper bound (seconds) on any single GPU allocation request.
MAX_GPU_SECONDS = 600

def clamp_duration(sec: int) -> int:
    """Clamp a GPU-time estimate into [1, MAX_GPU_SECONDS] seconds."""
    # The middle element of the sorted triple is the clamped value.
    return int(sorted((1, sec, MAX_GPU_SECONDS))[1])
210
+
211
def get_duration_must3r_features(video_path, interval):
    """GPU-time estimate for feature extraction: ~2 s per sampled frame."""
    stride = max(1, int(interval))
    frames_to_process = math.ceil(estimate_total_frames(video_path) / stride)
    return clamp_duration(int(frames_to_process * 2))
217
+
218
def get_duration_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """GPU-time estimate for tracking: ~2 s per frame (100-frame fallback)."""
    try:
        frame_count = int(sam2_input_images.shape[0])
    except Exception:
        frame_count = 100  # shape unavailable -> assume a mid-sized clip
    return clamp_duration(int(frame_count * 2))
225
+
226
+ # ----------------------------
227
+ # GPU functions
228
+ # ----------------------------
229
@spaces.GPU(duration=get_duration_must3r_features)
def process_video_and_features(video_path, interval):
    """GPU stage 1: decode frames and precompute MUSt3R features + SAM2 inputs.

    Returns:
        Tuple (pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs,
        sam2_input_images, images_tensor) stored by on_video_upload.

    Raises:
        ValueError: If no frames could be decoded.
    """
    logger.info(f"GPU: feature extraction interval={interval}")
    load_models()

    pil_imgs = video_to_frames(video_path, interval=max(1, int(interval)))
    if not pil_imgs:
        raise ValueError("Could not extract frames.")

    # NOTE(review): semantics of views/resize_funcs inferred from usage -- see engine.py.
    views, resize_funcs = get_views(pil_imgs)

    # Geometric features consumed by the tracking stage.
    must3r_feats, must3r_outputs = must3r_features_and_output(views, device=DEVICE)

    # Pre-resized image tensors for SAM2 prompting.
    sam2_input_images, images_tensor = prepare_sam2_inputs(views, pil_imgs, resize_funcs)

    return pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor
245
+
246
@spaces.GPU
def generate_frame_mask(image_tensor, points, labels, original_size):
    """GPU stage 2: SAM2 mask for one frame from user click prompts.

    Args:
        image_tensor: Single-frame batch preprocessed for SAM2.
        points: [[x, y], ...] clicks in original-frame pixels.
        labels: Per-point codes (1/0 pos/neg, 2/3 box corners).
        original_size: (width, height) used to rescale clicks to 1024x1024.

    Returns:
        The predicted mask as a squeezed NumPy array on CPU.
    """
    logger.info(f"GPU: generate mask points={len(points)}")
    load_models()

    pts_tensor = torch.tensor(points, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    lbl_tensor = torch.tensor(labels, dtype=torch.int32).unsqueeze(0).to(DEVICE)

    # Rescale clicks from original frame pixels to SAM2's 1024x1024 input space.
    w, h = original_size
    pts_tensor[..., 0] /= (w / 1024.0)
    pts_tensor[..., 1] /= (h / 1024.0)

    # NOTE(review): unlike app.py's twin, image_tensor is not moved to DEVICE here;
    # in this file it comes straight from prepare_sam2_inputs -- confirm its device.
    mask = get_single_frame_mask(
        image=image_tensor,
        predictor_original=PREDICTOR_ORIGINAL,
        points=pts_tensor,
        labels=lbl_tensor,
        device=DEVICE,
    )
    return mask.squeeze().cpu().numpy()
266
+
267
@spaces.GPU(duration=get_duration_tracking)
def run_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
    """GPU stage 3: propagate the annotated mask through all frames.

    Returns:
        Mapping of frame index -> mask as produced by get_tracked_masks
        (exact container type defined in engine.py).
    """
    logger.info(f"GPU: tracking start_idx={start_idx}")
    load_models()

    # Binarize the user-approved mask on the GPU before propagation.
    mask_tensor = torch.tensor(first_frame_mask).to(DEVICE) > 0

    tracked_masks = get_tracked_masks(
        sam2_input_images=sam2_input_images,
        must3r_feats=must3r_feats,
        must3r_outputs=must3r_outputs,
        start_idx=start_idx,
        first_frame_mask=mask_tensor,
        predictor=PREDICTOR,
        predictor_original=PREDICTOR_ORIGINAL,
        device=DEVICE,
    )
    return tracked_masks
285
+
286
+ # ----------------------------
287
+ # Cache utilities
288
+ # ----------------------------
289
+ CACHE_DIR = Path("./cache")
290
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
291
+
292
def _make_cache_key(video_path: str, interval: int, start_idx: int) -> str:
    """Derive a 16-hex cache id from video name, run parameters, and a UTC timestamp.

    The second-resolution timestamp makes repeated runs produce distinct entries.
    """
    from datetime import timezone  # local import to avoid touching the module import block

    name = Path(video_path).name if video_path else "video"
    # FIX: datetime.utcnow() is deprecated since Python 3.12; use an aware UTC clock.
    # strftime output is identical, so existing cache keys remain compatible.
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    s = f"{name}|interval={interval}|start={start_idx}|{stamp}"
    return hashlib.sha256(s.encode("utf-8")).hexdigest()[:16]
297
+
298
def _cache_paths(key: str) -> Dict[str, Path]:
    """Return the canonical file layout for cache entry `key`, creating its base dir."""
    base = CACHE_DIR / key
    base.mkdir(parents=True, exist_ok=True)
    filenames = {
        "meta": "meta.pkl",
        "frames_dir": "frames",
        "tensors": "state_tensors.pt",
        "video": "output_tracking.mp4",
    }
    paths = {"base": base}
    paths.update((alias, base / fname) for alias, fname in filenames.items())
    return paths
308
+
309
def _save_frames_as_jpg(pil_imgs: List[Image.Image], frames_dir: Path, quality: int = 95) -> None:
    """Write frames as zero-padded 000000.jpg, 000001.jpg, ... into frames_dir."""
    frames_dir.mkdir(parents=True, exist_ok=True)
    for index, frame in enumerate(pil_imgs):
        target = frames_dir / f"{index:06d}.jpg"
        # subsampling=0 keeps full chroma resolution for sharp thumbnails.
        frame.save(target, "JPEG", quality=quality, subsampling=0)
313
+
314
def _to_cpu(obj: Any) -> Any:
    """Recursively detach tensors and move them to CPU inside nested containers.

    Handles dicts, lists, and tuples (including NamedTuples); any other object
    is returned unchanged.
    """
    if torch.is_tensor(obj):
        return obj.detach().to("cpu")
    if isinstance(obj, dict):
        return {k: _to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        out = [_to_cpu(v) for v in obj]
        if isinstance(obj, tuple):
            # FIX: NamedTuple constructors take their fields positionally, so
            # type(obj)(out) would raise; splat the fields instead.
            return type(obj)(*out) if hasattr(obj, "_fields") else tuple(out)
        return out
    return obj
323
+
324
def _pack_masks_uint8_cpu(tracked_masks_dict: Dict[int, Any]) -> Dict[int, torch.Tensor]:
    """Normalize tracked masks to {int frame_idx: binary uint8 CPU tensor}."""

    def _as_tensor(value):
        # Accept ndarray, tensor, or anything torch.tensor() understands.
        if isinstance(value, np.ndarray):
            return torch.from_numpy(value)
        return value if torch.is_tensor(value) else torch.tensor(value)

    return {
        int(frame_idx): (_as_tensor(mask) > 0).to(torch.uint8).cpu()
        for frame_idx, mask in tracked_masks_dict.items()
    }
335
+
336
def save_full_cache_from_state(state: Dict[str, Any]) -> str:
    """Persist a completed run (frames, tensors, tracking video, meta) under ./cache.

    Returns:
        The 16-hex cache key of the newly written entry.

    Raises:
        ValueError: When the state dict is empty or missing required fields.
        FileNotFoundError: When the rendered tracking video is absent.
    """
    if not state:
        raise ValueError("Empty state.")
    # Everything needed to later reload this example without re-running GPU stages.
    required = [
        "pil_imgs",
        "must3r_feats",
        "must3r_outputs",
        "sam2_input_images",
        "images_tensor",
        "output_video_path",
        "video_path",
        "interval",
        "fps_in",
        "fps_out",
        "last_tracking_start_idx",
    ]
    missing = [k for k in required if k not in state or state[k] is None]
    if missing:
        raise ValueError(f"State missing fields: {missing}")

    key = _make_cache_key(
        str(state["video_path"]),
        int(state["interval"]),
        int(state["last_tracking_start_idx"]),
    )
    paths = _cache_paths(key)

    # 1) Frames as JPEGs (avoids pickling PIL objects; deterministic reload).
    _save_frames_as_jpg(state["pil_imgs"], paths["frames_dir"])

    # 2) Heavy tensors, forced to CPU so the cache loads without a GPU.
    torch.save(
        {
            "must3r_feats": _to_cpu(state["must3r_feats"]),
            "must3r_outputs": _to_cpu(state["must3r_outputs"]),
            "sam2_input_images": _to_cpu(state["sam2_input_images"]),
            "images_tensor": _to_cpu(state["images_tensor"]),
        },
        paths["tensors"],
    )

    # 3) Copy the rendered tracking video in, unless it is already in place.
    src = Path(state["output_video_path"])
    if not src.exists():
        raise FileNotFoundError(f"Output video not found: {src}")
    dst = paths["video"]
    if src.resolve() != dst.resolve():
        dst.write_bytes(src.read_bytes())

    # 4) Lightweight metadata consumed by the examples table.
    # NOTE(review): on_video_upload initializes last_points/last_labels to None,
    # and list(None) would raise here -- confirm they are set before saving.
    meta = {
        "video_name": Path(str(state["video_path"])).name,
        "interval": int(state["interval"]),
        "fps_in": float(state["fps_in"]),
        "fps_out": float(state["fps_out"]),
        "num_frames": int(len(state["pil_imgs"])),
        "start_idx": int(state["last_tracking_start_idx"]),
        "points": list(state.get("last_points", [])),
        "labels": list(state.get("last_labels", [])),
        "cache_key": key,
    }
    with open(paths["meta"], "wb") as f:
        pickle.dump(meta, f)

    return key
397
+
398
+ # ----------------------------
399
+ # UI callbacks
400
+ # ----------------------------
401
def on_video_upload(video_path, interval):
    """Full load pipeline: run GPU feature extraction and build the session state.

    Returns:
        (first_frame, state, slider_update, image_update); all None/zeroed
        when no video was provided.
    """
    if video_path is None:
        return None, None, gr.Slider(value=0, maximum=0), None

    pil_imgs, views, resize_funcs, must3r_feats, must3r_outputs, sam2_input_images, images_tensor = process_video_and_features(
        video_path, int(interval)
    )

    # Output FPS compensates for frame subsampling so playback speed is preserved.
    fps_in = estimate_video_fps(video_path)
    interval_i = max(1, int(interval))
    fps_out = max(1.0, fps_in / interval_i)

    state = {
        "pil_imgs": pil_imgs,
        "views": views,
        "resize_funcs": resize_funcs,
        "must3r_feats": must3r_feats,
        "must3r_outputs": must3r_outputs,
        "sam2_input_images": sam2_input_images,
        "images_tensor": images_tensor,
        "current_points": [],
        "current_labels": [],
        "current_mask": None,
        "frame_idx": 0,
        "video_path": video_path,
        "interval": interval_i,
        "fps_in": fps_in,
        "fps_out": fps_out,
        # tracking outputs (filled later)
        "output_video_path": None,
        "last_tracking_start_idx": None,
        "last_points": None,
        "last_labels": None,
    }

    first_frame = pil_imgs[0]
    new_slider = gr.Slider(value=0, maximum=len(pil_imgs) - 1, step=1, interactive=True)
    return first_frame, state, new_slider, gr.Image(value=first_frame)
439
+
440
def on_slider_change(state, frame_idx):
    """Jump to the selected frame and discard any in-progress annotations.

    Returns the newly selected frame image, or ``None`` when no session
    state exists yet.
    """
    if not state:
        return None
    # Clamp to the last available frame in case the slider maximum is stale.
    idx = min(int(frame_idx), len(state["pil_imgs"]) - 1)
    state["frame_idx"] = idx
    # Annotations are per-frame, so switching frames resets them.
    state["current_points"] = []
    state["current_labels"] = []
    state["current_mask"] = None
    return state["pil_imgs"][idx]
def on_image_click(state, evt: gr.SelectData, mode):
    """Record one clicked point (or box corner) and redraw the frame.

    The click's pixel coordinates come from ``evt.index``; the annotation
    mode string selects the SAM2-style label that is stored alongside it.
    """
    if not state:
        return None
    x, y = evt.index

    # Map the radio-button mode to its numeric label.
    mode_to_label = {
        "Positive Point": 1,
        "Negative Point": 0,
        "Box Top-Left": 2,
        "Box Bottom-Right": 3,
    }
    # Resolve the label before mutating state so an unknown mode leaves
    # points/labels consistent.
    label = mode_to_label[mode]
    state["current_points"].append([x, y])
    state["current_labels"].append(label)

    # Redraw: points first, then the mask overlay if one exists.
    base = state["pil_imgs"][state["frame_idx"]]
    vis = draw_points(base, state["current_points"], state["current_labels"])
    if state["current_mask"] is not None:
        vis = overlay_mask(vis, state["current_mask"])
    return vis
def on_generate_mask_click(state):
    """Run single-frame mask generation from the current annotations.

    Validates that box corners (if any) come in exactly one complete
    top-left/bottom-right pair, then calls ``generate_frame_mask`` on the
    selected frame's SAM2 input tensor. Stores the mask in state and
    returns the frame with mask overlay plus annotation markers.

    Raises:
        gr.Error: when nothing is annotated or the box corners are
            mismatched / more than one box was drawn.
    """
    if not state:
        return None
    if not state["current_points"]:
        raise gr.Error("No points or boxes annotated.")

    labels = state["current_labels"]
    num_tl = labels.count(2)
    num_br = labels.count(3)
    # At most one box, and its two corners must both be present.
    if num_tl != num_br or num_tl > 1:
        raise gr.Error(f"Incomplete box: TL={num_tl}, BR={num_br}. Must match and be <= 1.")

    idx = state["frame_idx"]
    # Select the one frame from the stacked SAM2 inputs, keeping a batch dim.
    frame_tensor = state["sam2_input_images"][idx].unsqueeze(0)
    original_size = state["pil_imgs"][idx].size

    mask = generate_frame_mask(
        frame_tensor,
        state["current_points"],
        labels,
        original_size,
    )
    state["current_mask"] = mask

    # Overlay the mask first, then draw the clicked points on top.
    vis = overlay_mask(state["pil_imgs"][idx], mask)
    return draw_points(vis, state["current_points"], labels)
def reset_annotations(state):
    """Drop every click and the current mask; return the clean frame.

    Returns ``None`` when no session state exists yet.
    """
    if not state:
        return None
    state.update(current_points=[], current_labels=[], current_mask=None)
    return state["pil_imgs"][state["frame_idx"]]
def on_track_click(state):
    """Propagate the current mask through the whole clip and render a video.

    Requires a mask produced by ``on_generate_mask_click``. Runs tracking
    from the currently selected frame, writes the overlay video, and records
    the inputs used (start frame, points, labels) so they can be cached.

    Raises:
        gr.Error: when no mask exists yet or box corners are mismatched.
    """
    if not state or state["current_mask"] is None:
        raise gr.Error("Generate a mask first.")

    labels = state["current_labels"]
    if labels.count(2) != labels.count(3):
        raise gr.Error("Incomplete box annotations.")

    anchor_idx = int(state["frame_idx"])
    seed_mask = state["current_mask"]

    masks_by_frame = run_tracking(
        state["sam2_input_images"],
        state["must3r_feats"],
        state["must3r_outputs"],
        anchor_idx,
        seed_mask,
    )

    rendered_path = create_video_from_masks(
        state["pil_imgs"],
        masks_by_frame,
        fps=state.get("fps_out", 24.0),
    )

    # Remember what produced this result so "Save Cache" can persist it.
    state["output_video_path"] = rendered_path
    state["last_tracking_start_idx"] = anchor_idx
    state["last_points"] = list(state.get("current_points", []))
    state["last_labels"] = list(state.get("current_labels", []))

    return rendered_path, state
def on_save_cache_click(state):
    """Persist the full pipeline state and report the resulting cache key."""
    cache_key = save_full_cache_from_state(state)
    return f"Saved cache key: {cache_key}"
# ----------------------------
# UI layout
# ----------------------------
# Static header HTML rendered at the top of the Gradio app via gr.HTML().
# NOTE: "3egment" is the project's own stylization of the 3AM name.
description = """
<div style="text-align: center;">
<h1>3AM: 3egment Anything with Geometric Consistency in Videos</h1>
<p>Cache-builder UI: run full pipeline, then save caches for user examples.</p>
</div>
"""
# Gradio application: layout (Step 1–4 panels) followed by event wiring.
# Buttons start disabled and are progressively enabled as each step
# completes (load -> annotate -> track -> save cache).
with gr.Blocks(title="3AM Cache Builder") as app:
    gr.HTML(description)

    # Per-session dict built by on_video_upload; shared by all callbacks.
    app_state = gr.State()

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Step 1 — Upload video")
            video_input = gr.Video(label="Upload Video", sources=["upload"], height=512)

            gr.Markdown("## Step 2 — Set interval, then load frames")
            interval_slider = gr.Slider(
                label="Frame Interval",
                minimum=1,
                maximum=30,
                step=1,
                value=1,
            )

            load_btn = gr.Button("Load Frames", variant="primary")

            process_status = gr.Textbox(label="Status", value="1) Upload a video.", interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("## Step 3 — Annotate frame & generate mask")
            img_display = gr.Image(label="Annotate Frame", interactive=True, height=512)

            frame_slider = gr.Slider(label="Select Frame", minimum=0, maximum=100, step=1, value=0)

            with gr.Row():
                mode_radio = gr.Radio(
                    choices=["Positive Point", "Negative Point", "Box Top-Left", "Box Bottom-Right"],
                    value="Positive Point",
                    label="Annotation Mode",
                )
                with gr.Column():
                    gen_mask_btn = gr.Button("Generate Mask", variant="primary", interactive=False)
                    reset_btn = gr.Button("Reset Annotations", interactive=False)

            gr.Markdown("## Step 4 — Track & Save Cache")
            with gr.Row():
                track_btn = gr.Button("Start Tracking", variant="primary", interactive=False)
                save_cache_btn = gr.Button("Save Cache", variant="secondary", interactive=False)

            with gr.Row():
                video_output = gr.Video(label="Tracking Output", autoplay=True, height=512)

            cache_status = gr.Textbox(label="Cache", value="", interactive=False)

    # ------------------------
    # Events
    # ------------------------
    def on_video_uploaded(video_path):
        """Suggest a default sampling interval (~100 frames) after upload."""
        n_frames = estimate_total_frames(video_path)
        default_interval = max(1, n_frames // 100)
        return (
            gr.update(value=default_interval, maximum=min(30, n_frames)),
            f"Video uploaded ({n_frames} frames). 2) Adjust interval, then click 'Load Frames'.",
        )

    video_input.upload(fn=on_video_uploaded, inputs=video_input, outputs=[interval_slider, process_status])

    # Load Frames: disable everything, run the heavy pipeline, then re-enable.
    # NOTE: img_display appears twice in the middle step's outputs — once for
    # the first-frame preview and once for the gr.Image update.
    load_btn.click(
        fn=lambda: (
            "Loading frames...",
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),  # save_cache_btn
            gr.update(value=""),
        ),
        outputs=[process_status, gen_mask_btn, reset_btn, track_btn, save_cache_btn, cache_status],
    ).then(
        fn=on_video_upload,
        inputs=[video_input, interval_slider],
        outputs=[img_display, app_state, frame_slider, img_display],
    ).then(
        fn=lambda: (
            "Ready. 3) Annotate and generate mask.",
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[process_status, gen_mask_btn, reset_btn, track_btn],
    )

    frame_slider.change(fn=on_slider_change, inputs=[app_state, frame_slider], outputs=[img_display])

    img_display.select(fn=on_image_click, inputs=[app_state, mode_radio], outputs=[img_display])

    gen_mask_btn.click(fn=on_generate_mask_click, inputs=[app_state], outputs=[img_display])

    reset_btn.click(fn=reset_annotations, inputs=[app_state], outputs=[img_display])

    # Start Tracking: lock the buttons while tracking runs, then unlock and
    # enable cache saving once the output video is ready.
    track_btn.click(
        fn=lambda: (
            "Tracking in progress...",
            gr.update(interactive=False),
            gr.update(interactive=False),
        ),
        outputs=[process_status, track_btn, save_cache_btn],
    ).then(
        fn=on_track_click,
        inputs=[app_state],
        outputs=[video_output, app_state],
    ).then(
        fn=lambda: (
            "Tracking complete. You can save cache.",
            gr.update(interactive=True),  # track_btn
            gr.update(interactive=True),  # save_cache_btn
        ),
        outputs=[process_status, track_btn, save_cache_btn],
    )

    save_cache_btn.click(fn=on_save_cache_click, inputs=[app_state], outputs=[cache_status])
if __name__ == "__main__":
    # Launch the cache-builder Gradio app when run as a script.
    app.launch()