nycu-cplab commited on
Commit
9d941d0
·
1 Parent(s): 03d1a10

ui improve

Browse files
Files changed (1) hide show
  1. app.py +177 -52
app.py CHANGED
@@ -2,6 +2,7 @@ import spaces
2
  import subprocess
3
  import sys, os
4
  from pathlib import Path
 
5
 
6
  ''' loading modules '''
7
  ROOT = Path(__file__).resolve().parent
@@ -173,7 +174,9 @@ def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4
173
  if not frames:
174
  logger.warning("No frames to create video.")
175
  return None
176
-
 
 
177
  h, w = np.array(frames[0]).shape[:2]
178
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
179
  out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
@@ -195,7 +198,34 @@ def create_video_from_masks(frames, masks_dict, output_path="output_tracking.mp4
195
 
196
  # --- GPU Wrapped Functions ---
197
 
198
- @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  def process_video_and_features(video_path, interval):
200
  """Load video, subsample frames, get views, MUSt3R features, SAM2 inputs."""
201
  logger.info(f"Starting GPU process: Video feature extraction (Interval: {interval})")
@@ -251,7 +281,17 @@ def generate_frame_mask(image_tensor, points, labels, original_size):
251
  logger.error(f"Error during mask generation: {e}")
252
  raise e
253
 
254
- @spaces.GPU
 
 
 
 
 
 
 
 
 
 
255
  def run_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
256
  """Track the mask across the video."""
257
  logger.info(f"Starting tracking from frame index {start_idx}...")
@@ -289,6 +329,10 @@ def on_video_upload(video_path, interval):
289
  logger.error(f"Failed to process video: {e}")
290
  raise gr.Error(f"Processing failed: {str(e)}")
291
 
 
 
 
 
292
  # Initialize state
293
  state = {
294
  "pil_imgs": pil_imgs,
@@ -301,7 +345,11 @@ def on_video_upload(video_path, interval):
301
  "current_points": [],
302
  "current_labels": [],
303
  "current_mask": None,
304
- "frame_idx": 0
 
 
 
 
305
  }
306
 
307
  first_frame = pil_imgs[0]
@@ -427,7 +475,11 @@ def on_track_click(state):
427
  first_frame_mask
428
  )
429
 
430
- output_path = create_video_from_masks(state["pil_imgs"], tracked_masks_dict)
 
 
 
 
431
  return output_path
432
  except Exception as e:
433
  logger.error(f"Tracking failed in UI callback: {e}")
@@ -451,101 +503,173 @@ description = """
451
  <p>Upload a video, geometric features are extracted automatically. Select a frame, click to annotate objects, and track them in 3D-consistent space.</p>
452
  </div>
453
  """
454
-
455
  with gr.Blocks(title="3AM: 3egment Anything") as app:
456
  gr.HTML(description)
457
-
 
 
 
 
 
 
 
 
 
 
 
458
  app_state = gr.State()
459
-
460
  with gr.Row():
461
  with gr.Column(scale=1):
462
- with gr.Group():
463
- # Added height limit to video input
464
- video_input = gr.Video(
465
- label="Upload Video",
466
- sources=["upload"],
467
- height=512
468
- )
469
- interval_slider = gr.Slider(
470
- label="Frame Interval (Applied to entire pipeline)",
471
- minimum=1,
472
- maximum=30,
473
- step=1,
474
- value=1,
475
- info="Process every N-th frame. Higher values = faster processing but lower temporal resolution."
476
- )
477
- process_status = gr.Textbox(label="Status", value="Waiting for upload...", interactive=False)
478
-
 
 
 
 
 
 
 
 
 
 
 
479
  with gr.Column(scale=2):
480
- # Added height limit to image display
481
  img_display = gr.Image(
482
- label="Annotate Frame",
483
- interactive=True,
484
  height=512
485
  )
486
- frame_slider = gr.Slider(label="Select Frame", minimum=0, maximum=100, step=1, value=0)
487
-
 
 
 
 
 
 
 
488
  with gr.Row():
489
  mode_radio = gr.Radio(
490
- choices=["Positive Point", "Negative Point", "Box Top-Left", "Box Bottom-Right"],
 
 
 
 
 
491
  value="Positive Point",
492
  label="Annotation Mode"
493
  )
494
  with gr.Column():
495
- gen_mask_btn = gr.Button("Generate Mask", variant="primary")
496
- reset_btn = gr.Button("Reset Annotations")
497
-
 
 
 
 
 
 
 
 
498
  with gr.Row():
499
- track_btn = gr.Button("Start Tracking", variant="primary", scale=1)
500
-
 
 
 
 
 
501
  with gr.Row():
502
- # Added height limit to video output
503
  video_output = gr.Video(
504
- label="Tracking Output",
505
- autoplay=True,
506
  height=512
507
  )
508
 
509
- # --- Events ---
510
-
 
 
 
 
 
 
 
 
 
 
 
 
511
  video_input.upload(
512
- fn=lambda: "Processing video (MUSt3R + SAM2)...",
513
- outputs=process_status
 
 
 
 
 
 
 
 
 
 
 
 
514
  ).then(
515
  fn=on_video_upload,
516
  inputs=[video_input, interval_slider],
517
  outputs=[img_display, app_state, frame_slider, img_display]
518
  ).then(
519
- fn=lambda: "Ready to annotate.",
520
- outputs=process_status
 
 
 
 
 
521
  )
522
-
523
  frame_slider.change(
524
  fn=on_slider_change,
525
  inputs=[app_state, frame_slider],
526
  outputs=[img_display]
527
  )
528
-
529
- # 1. Click on image -> Draw point (no mask gen)
530
  img_display.select(
531
  fn=on_image_click,
532
  inputs=[app_state, mode_radio],
533
  outputs=[img_display]
534
  )
535
-
536
- # 2. Click Generate -> Check box consistency & Gen Mask
537
  gen_mask_btn.click(
538
  fn=on_generate_mask_click,
539
  inputs=[app_state],
540
  outputs=[img_display]
541
  )
542
-
543
  reset_btn.click(
544
  fn=reset_annotations,
545
  inputs=[app_state],
546
  outputs=[img_display]
547
  )
548
-
549
  track_btn.click(
550
  fn=lambda: "Tracking in progress...",
551
  outputs=process_status
@@ -558,6 +682,7 @@ with gr.Blocks(title="3AM: 3egment Anything") as app:
558
  outputs=process_status
559
  )
560
 
 
561
  if __name__ == "__main__":
562
  logger.info("Starting Gradio app...")
563
  app.launch()
 
2
  import subprocess
3
  import sys, os
4
  from pathlib import Path
5
+ import math
6
 
7
  ''' loading modules '''
8
  ROOT = Path(__file__).resolve().parent
 
174
  if not frames:
175
  logger.warning("No frames to create video.")
176
  return None
177
+ fps = float(fps)
178
+ if not (fps > 0.0):
179
+ fps = 24.0
180
  h, w = np.array(frames[0]).shape[:2]
181
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
182
  out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
 
198
 
199
  # --- GPU Wrapped Functions ---
200
 
201
+ def estimate_video_fps(video_path: str) -> float:
202
+ cap = cv2.VideoCapture(video_path)
203
+ fps = float(cap.get(cv2.CAP_PROP_FPS)) or 0.0
204
+ cap.release()
205
+ # Robust fallback if metadata is missing
206
+ return fps if fps > 0.0 else 24.0
207
+
208
+ MAX_GPU_SECONDS = 600 # e.g., 10 minutes
209
+ def clamp_duration(sec: int) -> int:
210
+ return int(min(MAX_GPU_SECONDS, max(1, sec)))
211
+
212
+ def estimate_total_frames(video_path: str) -> int:
213
+ cap = cv2.VideoCapture(video_path)
214
+ n = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
215
+ cap.release()
216
+ return max(1, n)
217
+
218
+ def get_duration_must3r_features(video_path, interval):
219
+ # interval is applied to the entire pipeline, so actual processed frames ~= ceil(total / interval)
220
+ total = estimate_total_frames(video_path)
221
+ interval = max(1, int(interval))
222
+ processed = math.ceil(total / interval)
223
+
224
+ # Tune this coefficient based on your observed runtime on ZeroGPU
225
+ sec_per_frame = 2
226
+ return clamp_duration(int(processed * sec_per_frame))
227
+
228
+ @spaces.GPU(duration=get_duration_must3r_features)
229
  def process_video_and_features(video_path, interval):
230
  """Load video, subsample frames, get views, MUSt3R features, SAM2 inputs."""
231
  logger.info(f"Starting GPU process: Video feature extraction (Interval: {interval})")
 
281
  logger.error(f"Error during mask generation: {e}")
282
  raise e
283
 
284
+ def get_duration_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
285
+ # sam2_input_images is already subsampled, so this is the true number of frames to track
286
+ try:
287
+ n = int(getattr(sam2_input_images, "shape")[0])
288
+ except Exception:
289
+ n = 100 # fallback if something unexpected is passed
290
+
291
+ sec_per_frame = 2
292
+ return clamp_duration(int(n * sec_per_frame))
293
+
294
+ @spaces.GPU(duration=get_duration_tracking)
295
  def run_tracking(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask):
296
  """Track the mask across the video."""
297
  logger.info(f"Starting tracking from frame index {start_idx}...")
 
329
  logger.error(f"Failed to process video: {e}")
330
  raise gr.Error(f"Processing failed: {str(e)}")
331
 
332
+ fps_in = estimate_video_fps(video_path)
333
+ interval_i = max(1, int(interval))
334
+ fps_out = max(1.0, fps_in / interval_i)
335
+
336
  # Initialize state
337
  state = {
338
  "pil_imgs": pil_imgs,
 
345
  "current_points": [],
346
  "current_labels": [],
347
  "current_mask": None,
348
+ "frame_idx": 0,
349
+ "video_path": video_path,
350
+ "interval": interval_i,
351
+ "fps_in": fps_in,
352
+ "fps_out": fps_out
353
  }
354
 
355
  first_frame = pil_imgs[0]
 
475
  first_frame_mask
476
  )
477
 
478
+ output_path = create_video_from_masks(
479
+ state["pil_imgs"],
480
+ tracked_masks_dict,
481
+ fps=state.get("fps_out", 24.0),
482
+ )
483
  return output_path
484
  except Exception as e:
485
  logger.error(f"Tracking failed in UI callback: {e}")
 
503
  <p>Upload a video, geometric features are extracted automatically. Select a frame, click to annotate objects, and track them in 3D-consistent space.</p>
504
  </div>
505
  """
 
506
  with gr.Blocks(title="3AM: 3egment Anything") as app:
507
  gr.HTML(description)
508
+
509
+ gr.Markdown(
510
+ """
511
+ # 3AM: 3egment Anything
512
+ **Workflow**
513
+ 1) Upload video
514
+ 2) Adjust frame interval → Load frames
515
+ 3) Annotate & generate mask
516
+ 4) Track through the video
517
+ """
518
+ )
519
+
520
  app_state = gr.State()
521
+
522
  with gr.Row():
523
  with gr.Column(scale=1):
524
+ gr.Markdown("## Step 1 β€” Upload video")
525
+ video_input = gr.Video(
526
+ label="Upload Video",
527
+ sources=["upload"],
528
+ height=512
529
+ )
530
+
531
+ gr.Markdown("## Step 2 β€” Set interval, then load frames")
532
+ interval_slider = gr.Slider(
533
+ label="Frame Interval",
534
+ minimum=1,
535
+ maximum=30,
536
+ step=1,
537
+ value=1,
538
+ info="Default ≈ total_frames / 100"
539
+ )
540
+
541
+ load_btn = gr.Button(
542
+ "Load Frames",
543
+ variant="primary"
544
+ )
545
+
546
+ process_status = gr.Textbox(
547
+ label="Status",
548
+ value="1) Upload a video.",
549
+ interactive=False
550
+ )
551
+
552
  with gr.Column(scale=2):
553
+ gr.Markdown("## Step 3 β€” Annotate frame & generate mask")
554
  img_display = gr.Image(
555
+ label="Annotate Frame",
556
+ interactive=True,
557
  height=512
558
  )
559
+
560
+ frame_slider = gr.Slider(
561
+ label="Select Frame",
562
+ minimum=0,
563
+ maximum=100,
564
+ step=1,
565
+ value=0
566
+ )
567
+
568
  with gr.Row():
569
  mode_radio = gr.Radio(
570
+ choices=[
571
+ "Positive Point",
572
+ "Negative Point",
573
+ "Box Top-Left",
574
+ "Box Bottom-Right",
575
+ ],
576
  value="Positive Point",
577
  label="Annotation Mode"
578
  )
579
  with gr.Column():
580
+ gen_mask_btn = gr.Button(
581
+ "Generate Mask",
582
+ variant="primary",
583
+ interactive=False
584
+ )
585
+ reset_btn = gr.Button(
586
+ "Reset Annotations",
587
+ interactive=False
588
+ )
589
+
590
+ gr.Markdown("## Step 4 β€” Track through the video")
591
  with gr.Row():
592
+ track_btn = gr.Button(
593
+ "Start Tracking",
594
+ variant="primary",
595
+ scale=1,
596
+ interactive=False
597
+ )
598
+
599
  with gr.Row():
 
600
  video_output = gr.Video(
601
+ label="Tracking Output",
602
+ autoplay=True,
603
  height=512
604
  )
605
 
606
+ # ------------------------------------------------
607
+ # Events
608
+ # ------------------------------------------------
609
+
610
+ # Upload: only read metadata & set default interval
611
+ def on_video_uploaded(video_path):
612
+ n_frames = estimate_total_frames(video_path)
613
+ default_interval = max(1, n_frames // 100)
614
+ return (
615
+ gr.update(value=default_interval, maximum=min(30, n_frames)),
616
+ f"Video uploaded ({n_frames} frames). "
617
+ "2) Adjust interval, then click 'Load Frames'."
618
+ )
619
+
620
  video_input.upload(
621
+ fn=on_video_uploaded,
622
+ inputs=video_input,
623
+ outputs=[interval_slider, process_status]
624
+ )
625
+
626
+ # Load frames: heavy compute happens here
627
+ load_btn.click(
628
+ fn=lambda: (
629
+ "Loading frames...",
630
+ gr.update(interactive=False),
631
+ gr.update(interactive=False),
632
+ gr.update(interactive=False),
633
+ ),
634
+ outputs=[process_status, gen_mask_btn, reset_btn, track_btn]
635
  ).then(
636
  fn=on_video_upload,
637
  inputs=[video_input, interval_slider],
638
  outputs=[img_display, app_state, frame_slider, img_display]
639
  ).then(
640
+ fn=lambda: (
641
+ "Ready. 3) Annotate and generate mask.",
642
+ gr.update(interactive=True),
643
+ gr.update(interactive=True),
644
+ gr.update(interactive=True),
645
+ ),
646
+ outputs=[process_status, gen_mask_btn, reset_btn, track_btn]
647
  )
648
+
649
  frame_slider.change(
650
  fn=on_slider_change,
651
  inputs=[app_state, frame_slider],
652
  outputs=[img_display]
653
  )
654
+
 
655
  img_display.select(
656
  fn=on_image_click,
657
  inputs=[app_state, mode_radio],
658
  outputs=[img_display]
659
  )
660
+
 
661
  gen_mask_btn.click(
662
  fn=on_generate_mask_click,
663
  inputs=[app_state],
664
  outputs=[img_display]
665
  )
666
+
667
  reset_btn.click(
668
  fn=reset_annotations,
669
  inputs=[app_state],
670
  outputs=[img_display]
671
  )
672
+
673
  track_btn.click(
674
  fn=lambda: "Tracking in progress...",
675
  outputs=process_status
 
682
  outputs=process_status
683
  )
684
 
685
+
686
  if __name__ == "__main__":
687
  logger.info("Starting Gradio app...")
688
  app.launch()