yonigozlan HF Staff committed on
Commit
adb1841
·
1 Parent(s): 03b77e4

process frames one at a time

Browse files
Files changed (1) hide show
  1. app.py +40 -66
app.py CHANGED
@@ -54,10 +54,12 @@ def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], di
54
  cap.release()
55
  if fps_val and fps_val > 0:
56
  info["fps"] = float(fps_val)
57
- except Exception:
 
58
  pass
59
  return pil_frames, info
60
- except Exception:
 
61
  # Fallback to OpenCV
62
  try:
63
  import cv2 # type: ignore
@@ -180,14 +182,6 @@ def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoPr
180
  if GLOBAL_STATE.model_repo_id == desired_repo:
181
  return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype
182
  # Different repo requested: dispose current and reload
183
- try:
184
- del GLOBAL_STATE.model
185
- except Exception:
186
- pass
187
- try:
188
- del GLOBAL_STATE.processor
189
- except Exception:
190
- pass
191
  GLOBAL_STATE.model = None
192
  GLOBAL_STATE.processor = None
193
  print(f"Loading model from {desired_repo}")
@@ -219,16 +213,8 @@ def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
219
  GLOBAL_STATE.clicks_by_frame_obj.clear()
220
  GLOBAL_STATE.boxes_by_frame_obj.clear()
221
  GLOBAL_STATE.composited_frames.clear()
222
- # Dispose previous session cleanly
223
- try:
224
- if GLOBAL_STATE.inference_session is not None:
225
- GLOBAL_STATE.inference_session.reset_inference_session()
226
- except Exception:
227
- pass
228
  GLOBAL_STATE.inference_session = None
229
- gc.collect()
230
  GLOBAL_STATE.inference_session = processor.init_video_session(
231
- video=GLOBAL_STATE.video_frames,
232
  inference_device=device,
233
  video_storage_device="cpu",
234
  dtype=dtype,
@@ -265,40 +251,18 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
265
  # Enforce max duration of 8 seconds (trim if longer)
266
  MAX_SECONDS = 8.0
267
  trimmed_note = ""
268
- fps_in = None
269
- if isinstance(info, dict) and info.get("fps"):
270
- try:
271
- fps_in = float(info["fps"]) or None
272
- except Exception:
273
- fps_in = None
274
- if fps_in is not None:
275
- max_frames_allowed = int(MAX_SECONDS * fps_in)
276
- if len(frames) > max_frames_allowed:
277
- frames = frames[:max_frames_allowed]
278
- trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
279
- if isinstance(info, dict):
280
- info["num_frames"] = len(frames)
281
- else:
282
- # Fallback when FPS unknown: assume ~30 FPS and cap to 240 frames (~8s)
283
- max_frames_allowed = 240
284
- if len(frames) > max_frames_allowed:
285
- frames = frames[:max_frames_allowed]
286
- trimmed_note = " (trimmed to 240 frames ~8s @30fps)"
287
- if isinstance(info, dict):
288
- info["num_frames"] = len(frames)
289
-
290
  GLOBAL_STATE.video_frames = frames
291
  # Try to capture original FPS if provided by loader
292
- GLOBAL_STATE.video_fps = None
293
- if isinstance(info, dict) and info.get("fps"):
294
- try:
295
- GLOBAL_STATE.video_fps = float(info["fps"]) or None
296
- except Exception:
297
- GLOBAL_STATE.video_fps = None
298
-
299
  # Initialize session
300
  inference_session = processor.init_video_session(
301
- video=frames,
302
  inference_device=device,
303
  video_storage_device="cpu",
304
  dtype=dtype,
@@ -412,6 +376,12 @@ def on_image_click(
412
  processor = state.processor
413
  model = state.model
414
  inference_session = state.inference_session
 
 
 
 
 
 
415
 
416
  if state.current_prompt_type == "Boxes":
417
  # Two-click box input
@@ -443,6 +413,7 @@ def on_image_click(
443
  obj_ids=int(obj_id),
444
  input_boxes=[[[x_min, y_min, x_max, y_max]]],
445
  clear_old_inputs=True, # For boxes, always clear old inputs
 
446
  )
447
 
448
  frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
@@ -465,6 +436,7 @@ def on_image_click(
465
  obj_ids=int(obj_id),
466
  input_points=[[[[int(x), int(y)]]]],
467
  input_labels=[[[int(label_int)]]],
 
468
  clear_old_inputs=bool(clear_old),
469
  )
470
 
@@ -477,10 +449,7 @@ def on_image_click(
477
 
478
  # Forward on that frame
479
  with torch.inference_mode():
480
- outputs = model(
481
- inference_session=inference_session,
482
- frame_idx=int(frame_idx),
483
- )
484
 
485
  H = inference_session.video_height
486
  W = inference_session.video_width
@@ -509,8 +478,8 @@ def on_image_click(
509
  @spaces.GPU()
510
  def propagate_masks(GLOBAL_STATE: gr.State):
511
  if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
512
- yield "Load a video first.", gr.update()
513
- return
514
 
515
  processor = deepcopy(GLOBAL_STATE.processor)
516
  model = deepcopy(GLOBAL_STATE.model)
@@ -524,17 +493,19 @@ def propagate_masks(GLOBAL_STATE: gr.State):
524
  processed = 0
525
 
526
  # Initial status; no slider change yet
527
- yield f"Propagating masks: {processed}/{total}", gr.update()
528
 
529
  last_frame_idx = 0
530
  with torch.inference_mode():
531
- for sam2_video_output in model.propagate_in_video_iterator(inference_session):
 
 
 
 
532
  H = inference_session.video_height
533
  W = inference_session.video_width
534
  pred_masks = sam2_video_output.pred_masks.detach().cpu()
535
  video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
536
-
537
- frame_idx = int(sam2_video_output.frame_idx)
538
  last_frame_idx = frame_idx
539
  masks_for_frame: dict[int, np.ndarray] = {}
540
  obj_ids_order = list(inference_session.obj_ids)
@@ -547,15 +518,15 @@ def propagate_masks(GLOBAL_STATE: gr.State):
547
 
548
  processed += 1
549
  # Every 15th frame (or last), move slider to current frame to update preview via slider binding
550
- if processed % 15 == 0 or processed == total:
551
- yield f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
552
- else:
553
- yield f"Propagating masks: {processed}/{total}", gr.update()
554
 
555
  text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
556
 
557
  # Final status; ensure slider points to last processed frame
558
- yield text, gr.update(value=last_frame_idx)
559
 
560
 
561
  def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
@@ -785,14 +756,16 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
785
 
786
  iio.imwrite(out_path, [fr[:, :, ::-1] for fr in frames_np], plugin="pyav", fps=fps)
787
  return out_path
788
- except Exception:
 
789
  # Fallbacks
790
  try:
791
  import imageio.v2 as imageio # type: ignore
792
 
793
  imageio.mimsave(out_path, [fr[:, :, ::-1] for fr in frames_np], fps=fps)
794
  return out_path
795
- except Exception:
 
796
  try:
797
  import cv2 # type: ignore
798
 
@@ -803,6 +776,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
803
  writer.release()
804
  return out_path
805
  except Exception as e:
 
806
  raise gr.Error(f"Failed to render video: {e}")
807
 
808
  render_btn.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video])
@@ -811,7 +785,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
811
  propagate_btn.click(
812
  propagate_masks,
813
  inputs=[GLOBAL_STATE],
814
- outputs=[propagate_status, frame_slider],
815
  )
816
 
817
  reset_btn.click(
 
54
  cap.release()
55
  if fps_val and fps_val > 0:
56
  info["fps"] = float(fps_val)
57
+ except Exception as e:
58
+ print(f"Failed to render video with cv2: {e}")
59
  pass
60
  return pil_frames, info
61
+ except Exception as e:
62
+ print(f"Failed to load video with transformers.video_utils: {e}")
63
  # Fallback to OpenCV
64
  try:
65
  import cv2 # type: ignore
 
182
  if GLOBAL_STATE.model_repo_id == desired_repo:
183
  return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype
184
  # Different repo requested: dispose current and reload
 
 
 
 
 
 
 
 
185
  GLOBAL_STATE.model = None
186
  GLOBAL_STATE.processor = None
187
  print(f"Loading model from {desired_repo}")
 
213
  GLOBAL_STATE.clicks_by_frame_obj.clear()
214
  GLOBAL_STATE.boxes_by_frame_obj.clear()
215
  GLOBAL_STATE.composited_frames.clear()
 
 
 
 
 
 
216
  GLOBAL_STATE.inference_session = None
 
217
  GLOBAL_STATE.inference_session = processor.init_video_session(
 
218
  inference_device=device,
219
  video_storage_device="cpu",
220
  dtype=dtype,
 
251
  # Enforce max duration of 8 seconds (trim if longer)
252
  MAX_SECONDS = 8.0
253
  trimmed_note = ""
254
+ fps_in = info.get("fps")
255
+ max_frames_allowed = int(MAX_SECONDS * fps_in)
256
+ if len(frames) > max_frames_allowed:
257
+ frames = frames[:max_frames_allowed]
258
+ trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
259
+ if isinstance(info, dict):
260
+ info["num_frames"] = len(frames)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  GLOBAL_STATE.video_frames = frames
262
  # Try to capture original FPS if provided by loader
263
+ GLOBAL_STATE.video_fps = float(fps_in)
 
 
 
 
 
 
264
  # Initialize session
265
  inference_session = processor.init_video_session(
 
266
  inference_device=device,
267
  video_storage_device="cpu",
268
  dtype=dtype,
 
376
  processor = state.processor
377
  model = state.model
378
  inference_session = state.inference_session
379
+ original_size = None
380
+ pixel_values = None
381
+ if not inference_session.processed_frames or frame_idx not in inference_session.processed_frames:
382
+ inputs = processor(images=state.video_frames[frame_idx], device=state.device, return_tensors="pt")
383
+ original_size = inputs.original_sizes[0]
384
+ pixel_values = inputs.pixel_values[0]
385
 
386
  if state.current_prompt_type == "Boxes":
387
  # Two-click box input
 
413
  obj_ids=int(obj_id),
414
  input_boxes=[[[x_min, y_min, x_max, y_max]]],
415
  clear_old_inputs=True, # For boxes, always clear old inputs
416
+ original_size=original_size,
417
  )
418
 
419
  frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
 
436
  obj_ids=int(obj_id),
437
  input_points=[[[[int(x), int(y)]]]],
438
  input_labels=[[[int(label_int)]]],
439
+ original_size=original_size,
440
  clear_old_inputs=bool(clear_old),
441
  )
442
 
 
449
 
450
  # Forward on that frame
451
  with torch.inference_mode():
452
+ outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=int(frame_idx))
 
 
 
453
 
454
  H = inference_session.video_height
455
  W = inference_session.video_width
 
478
  @spaces.GPU()
479
  def propagate_masks(GLOBAL_STATE: gr.State):
480
  if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
481
+ # yield GLOBAL_STATE, "Load a video first.", gr.update()
482
+ return GLOBAL_STATE, "Load a video first.", gr.update()
483
 
484
  processor = deepcopy(GLOBAL_STATE.processor)
485
  model = deepcopy(GLOBAL_STATE.model)
 
493
  processed = 0
494
 
495
  # Initial status; no slider change yet
496
+ yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
497
 
498
  last_frame_idx = 0
499
  with torch.inference_mode():
500
+ for frame_idx, frame in enumerate(GLOBAL_STATE.video_frames):
501
+ pixel_values = None
502
+ if not inference_session.processed_frames or frame_idx not in inference_session.processed_frames:
503
+ pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
504
+ sam2_video_output = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)
505
  H = inference_session.video_height
506
  W = inference_session.video_width
507
  pred_masks = sam2_video_output.pred_masks.detach().cpu()
508
  video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
 
 
509
  last_frame_idx = frame_idx
510
  masks_for_frame: dict[int, np.ndarray] = {}
511
  obj_ids_order = list(inference_session.obj_ids)
 
518
 
519
  processed += 1
520
  # Every 30th frame (or last), move slider to current frame to update preview via slider binding
521
+ if processed % 30 == 0 or processed == total:
522
+ yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
523
+ # else:
524
+ # yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
525
 
526
  text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
527
 
528
  # Final status; ensure slider points to last processed frame
529
+ yield GLOBAL_STATE, text, gr.update(value=last_frame_idx)
530
 
531
 
532
  def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
 
756
 
757
  iio.imwrite(out_path, [fr[:, :, ::-1] for fr in frames_np], plugin="pyav", fps=fps)
758
  return out_path
759
+ except Exception as e:
760
+ print(f"Failed to render video with imageio.v3: {e}")
761
  # Fallbacks
762
  try:
763
  import imageio.v2 as imageio # type: ignore
764
 
765
  imageio.mimsave(out_path, [fr[:, :, ::-1] for fr in frames_np], fps=fps)
766
  return out_path
767
+ except Exception as e:
768
+ print(f"Failed to render video with imageio.v2: {e}")
769
  try:
770
  import cv2 # type: ignore
771
 
 
776
  writer.release()
777
  return out_path
778
  except Exception as e:
779
+ print(f"Failed to render video with cv2: {e}")
780
  raise gr.Error(f"Failed to render video: {e}")
781
 
782
  render_btn.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video])
 
785
  propagate_btn.click(
786
  propagate_masks,
787
  inputs=[GLOBAL_STATE],
788
+ outputs=[GLOBAL_STATE, propagate_status, frame_slider],
789
  )
790
 
791
  reset_btn.click(