hysts HF Staff committed on
Commit
ffce5fe
·
1 Parent(s): 4e68fb4
Files changed (1) hide show
  1. app.py +118 -19
app.py CHANGED
@@ -1,11 +1,14 @@
1
  import colorsys
2
  import gc
3
  import tempfile
4
- from collections.abc import Iterator
 
 
5
 
6
  import cv2
7
  import gradio as gr
8
  import numpy as np
 
9
  import torch
10
  from gradio.themes import Soft
11
  from PIL import Image, ImageDraw, ImageFont
@@ -15,7 +18,7 @@ MODEL_ID = "facebook/sam3"
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
  DTYPE = torch.bfloat16
17
 
18
- TRACKER_MODEL = Sam3TrackerVideoModel.from_pretrained(MODEL_ID, torch_dtype=DTYPE, device_map=DEVICE).eval()
19
  TRACKER_PROCESSOR = Sam3TrackerVideoProcessor.from_pretrained(MODEL_ID)
20
 
21
  TEXT_VIDEO_MODEL = Sam3VideoModel.from_pretrained(MODEL_ID).to(DEVICE, dtype=DTYPE).eval()
@@ -25,6 +28,81 @@ print("Models loaded successfully!")
25
  MAX_SECONDS = 8.0
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
29
  cap = cv2.VideoCapture(video_path_or_url)
30
  frames = []
@@ -175,8 +253,8 @@ def init_video_session(
175
  processor = TEXT_VIDEO_PROCESSOR
176
  state.inference_session = processor.init_video_session(
177
  video=frames,
178
- inference_device=DEVICE,
179
- inference_state_device=DEVICE,
180
  processing_device="cpu",
181
  video_storage_device="cpu",
182
  dtype=DTYPE,
@@ -185,13 +263,17 @@ def init_video_session(
185
  processor = TRACKER_PROCESSOR
186
  state.inference_session = processor.init_video_session(
187
  video=raw_video,
188
- inference_device=DEVICE,
189
- inference_state_device=DEVICE,
190
  processing_device="cpu",
191
  video_storage_device="cpu",
192
  dtype=DTYPE,
193
  )
194
 
 
 
 
 
195
  first_frame = frames[0]
196
  max_idx = len(frames) - 1
197
  if active_tab == "text":
@@ -362,6 +444,7 @@ def _ensure_color_for_obj(state: AppState, obj_id: int) -> None:
362
  state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
363
 
364
 
 
365
  def on_image_click(
366
  img: Image.Image | np.ndarray,
367
  state: AppState,
@@ -370,12 +453,13 @@ def on_image_click(
370
  label: str,
371
  clear_old: bool,
372
  evt: gr.SelectData,
373
- ) -> Image.Image:
374
  if state is None or state.inference_session is None:
375
  return img
376
 
377
  model = TRACKER_MODEL
378
  processor = TRACKER_PROCESSOR
 
379
 
380
  x = y = None
381
  if evt is not None:
@@ -471,14 +555,17 @@ def on_image_click(
471
 
472
  state.composited_frames.pop(ann_frame_idx, None)
473
 
474
- return update_frame_display(state, ann_frame_idx)
 
 
475
 
476
 
 
477
  def on_text_prompt(
478
  state: AppState,
479
  frame_idx: int,
480
  text_prompt: str,
481
- ) -> tuple[Image.Image, str, str]:
482
  if state is None or state.inference_session is None:
483
  return None, "Upload a video and enter text prompt.", "**Active prompts:** None"
484
 
@@ -487,7 +574,7 @@ def on_text_prompt(
487
 
488
  if not text_prompt or not text_prompt.strip():
489
  active_prompts = _get_active_prompts_display(state)
490
- return update_frame_display(state, int(frame_idx)), "Please enter a text prompt.", active_prompts
491
 
492
  frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
493
 
@@ -495,7 +582,9 @@ def on_text_prompt(
495
  prompt_texts = [p.strip() for p in text_prompt.split(",") if p.strip()]
496
  if not prompt_texts:
497
  active_prompts = _get_active_prompts_display(state)
498
- return update_frame_display(state, int(frame_idx)), "Please enter a valid text prompt.", active_prompts
 
 
499
 
500
  # Add text prompt(s) - supports both single string and list of strings
501
  state.inference_session = processor.add_text_prompt(
@@ -579,7 +668,10 @@ def on_text_prompt(
579
  status = f"Processed text prompt(s) {prompts_str} on frame {frame_idx}. No objects detected."
580
 
581
  active_prompts = _get_active_prompts_display(state)
582
- return update_frame_display(state, int(frame_idx)), status, active_prompts
 
 
 
583
 
584
 
585
  def _get_active_prompts_display(state: AppState) -> str:
@@ -596,6 +688,7 @@ def _get_active_prompts_display(state: AppState) -> str:
596
  return "**Active prompts:** None"
597
 
598
 
 
599
  def propagate_masks(state: AppState) -> Iterator[tuple[AppState, str, dict]]:
600
  if state is None:
601
  return state, "Load a video first.", gr.update()
@@ -619,6 +712,8 @@ def propagate_masks(state: AppState) -> Iterator[tuple[AppState, str, dict]]:
619
  model = TEXT_VIDEO_MODEL
620
  processor = TEXT_VIDEO_PROCESSOR
621
 
 
 
622
  # Collect all unique prompts from existing frame annotations
623
  text_prompt_to_obj_ids = {}
624
  for frame_idx, frame_texts in state.text_prompts_by_frame_obj.items():
@@ -638,6 +733,7 @@ def propagate_masks(state: AppState) -> Iterator[tuple[AppState, str, dict]]:
638
  text_prompt_to_obj_ids[text_prompt].sort()
639
 
640
  if not text_prompt_to_obj_ids:
 
641
  yield state, "No text prompts found. Please add a text prompt first.", gr.update()
642
  return
643
 
@@ -705,7 +801,9 @@ def propagate_masks(state: AppState) -> Iterator[tuple[AppState, str, dict]]:
705
  last_frame_idx = frame_idx
706
  processed += 1
707
  if processed % 30 == 0 or processed == total:
 
708
  yield state, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
 
709
  else:
710
  if state.inference_session is None:
711
  yield state, "Tracker model not loaded.", gr.update()
@@ -714,6 +812,8 @@ def propagate_masks(state: AppState) -> Iterator[tuple[AppState, str, dict]]:
714
  model = TRACKER_MODEL
715
  processor = TRACKER_PROCESSOR
716
 
 
 
717
  for sam2_video_output in model.propagate_in_video_iterator(inference_session=state.inference_session):
718
  video_res_masks = processor.post_process_masks(
719
  [sam2_video_output.pred_masks],
@@ -731,9 +831,12 @@ def propagate_masks(state: AppState) -> Iterator[tuple[AppState, str, dict]]:
731
  last_frame_idx = frame_idx
732
  processed += 1
733
  if processed % 30 == 0 or processed == total:
 
734
  yield state, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
 
735
 
736
  text = f"Propagated masks across {processed} frames."
 
737
  yield state, text, gr.update(value=last_frame_idx)
738
 
739
 
@@ -1079,17 +1182,13 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose"
1079
  preview_pointbox.select(
1080
  fn=on_image_click,
1081
  inputs=[preview_pointbox, app_state, frame_slider_pointbox, obj_id_inp, label_radio, clear_old_chk],
1082
- outputs=preview_pointbox,
1083
  )
1084
 
1085
- def _on_text_apply(state: AppState, frame_idx: int, text: str) -> tuple[Image.Image, str, str]:
1086
- img, status, active_prompts = on_text_prompt(state, frame_idx, text)
1087
- return img, status, active_prompts
1088
-
1089
  text_apply_btn.click(
1090
- fn=_on_text_apply,
1091
  inputs=[app_state, frame_slider_text, text_prompt_input],
1092
- outputs=[preview_text, text_status, active_prompts_display],
1093
  )
1094
 
1095
  reset_prompts_btn.click(
 
1
  import colorsys
2
  import gc
3
  import tempfile
4
+ from collections import defaultdict
5
+ from collections.abc import Iterator, Mapping, Sequence
6
+ from typing import Any
7
 
8
  import cv2
9
  import gradio as gr
10
  import numpy as np
11
+ import spaces
12
  import torch
13
  from gradio.themes import Soft
14
  from PIL import Image, ImageDraw, ImageFont
 
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
  DTYPE = torch.bfloat16
20
 
21
+ TRACKER_MODEL = Sam3TrackerVideoModel.from_pretrained(MODEL_ID, torch_dtype=DTYPE).to(DEVICE).eval()
22
  TRACKER_PROCESSOR = Sam3TrackerVideoProcessor.from_pretrained(MODEL_ID)
23
 
24
  TEXT_VIDEO_MODEL = Sam3VideoModel.from_pretrained(MODEL_ID).to(DEVICE, dtype=DTYPE).eval()
 
28
  MAX_SECONDS = 8.0
29
 
30
 
31
+ def to_device_recursive(obj: Any, device: str | torch.device) -> Any: # noqa: ANN401
32
+ """Return a new object where all torch.Tensors reachable from `obj` are moved to the given device.
33
+
34
+ - Does NOT mutate the original object.
35
+ - Handles:
36
+ * torch.Tensor
37
+ * Mapping (e.g. dict, defaultdict, OrderedDict, etc.)
38
+ * Sequence (e.g. list, tuple) except str/bytes
39
+ * Custom classes with attributes (__dict__)
40
+ - Tries to preserve container types where reasonable.
41
+ """
42
+ device = torch.device(device)
43
+ memo = {}
44
+
45
+ def _convert(x: Any) -> Any: # noqa: ANN401, C901
46
+ obj_id = id(x)
47
+ if obj_id in memo:
48
+ return memo[obj_id]
49
+
50
+ # 1. Tensor
51
+ if isinstance(x, torch.Tensor):
52
+ y = x.to(device)
53
+ memo[obj_id] = y
54
+ return y
55
+
56
+ # 2. Mapping (dict, defaultdict, etc.)
57
+ if isinstance(x, Mapping):
58
+ # Special case: defaultdict
59
+ if isinstance(x, defaultdict):
60
+ y = defaultdict(x.default_factory)
61
+ memo[obj_id] = y
62
+ for k, v in x.items():
63
+ y[k] = _convert(v)
64
+ return y
65
+
66
+ # Try to rebuild the same type using (key, value) pairs
67
+ try:
68
+ y = type(x)((k, _convert(v)) for k, v in x.items())
69
+ memo[obj_id] = y
70
+ return y
71
+ except TypeError:
72
+ # Fallback: plain dict
73
+ y = {k: _convert(v) for k, v in x.items()}
74
+ memo[obj_id] = y
75
+ return y
76
+
77
+ # 3. Sequence (list/tuple/etc.) but not str/bytes
78
+ if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
79
+ if isinstance(x, list):
80
+ y = [_convert(v) for v in x]
81
+ elif isinstance(x, tuple):
82
+ y = type(x)(_convert(v) for v in x)
83
+ else:
84
+ try:
85
+ y = type(x)(_convert(v) for v in x)
86
+ except TypeError:
87
+ y = [_convert(v) for v in x]
88
+ memo[obj_id] = y
89
+ return y
90
+
91
+ # 4. Custom object with attributes (__dict__)
92
+ if hasattr(x, "__dict__") and not isinstance(x, type):
93
+ new_obj = x.__class__.__new__(x.__class__)
94
+ memo[obj_id] = new_obj
95
+ for name, value in vars(x).items():
96
+ setattr(new_obj, name, _convert(value))
97
+ return new_obj
98
+
99
+ # 5. Everything else → keep as-is
100
+ memo[obj_id] = x
101
+ return x
102
+
103
+ return _convert(obj)
104
+
105
+
106
  def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
107
  cap = cv2.VideoCapture(video_path_or_url)
108
  frames = []
 
253
  processor = TEXT_VIDEO_PROCESSOR
254
  state.inference_session = processor.init_video_session(
255
  video=frames,
256
+ inference_device="cpu",
257
+ inference_state_device="cpu",
258
  processing_device="cpu",
259
  video_storage_device="cpu",
260
  dtype=DTYPE,
 
263
  processor = TRACKER_PROCESSOR
264
  state.inference_session = processor.init_video_session(
265
  video=raw_video,
266
+ inference_device="cpu",
267
+ inference_state_device="cpu",
268
  processing_device="cpu",
269
  video_storage_device="cpu",
270
  dtype=DTYPE,
271
  )
272
 
273
+ state.inference_session.inference_device = DEVICE
274
+ state.inference_session.processing_device = DEVICE
275
+ state.inference_session.cache.inference_device = DEVICE
276
+
277
  first_frame = frames[0]
278
  max_idx = len(frames) - 1
279
  if active_tab == "text":
 
444
  state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
445
 
446
 
447
+ @spaces.GPU
448
  def on_image_click(
449
  img: Image.Image | np.ndarray,
450
  state: AppState,
 
453
  label: str,
454
  clear_old: bool,
455
  evt: gr.SelectData,
456
+ ) -> tuple[Image.Image, AppState]:
457
  if state is None or state.inference_session is None:
458
  return img
459
 
460
  model = TRACKER_MODEL
461
  processor = TRACKER_PROCESSOR
462
+ state.inference_session = to_device_recursive(state.inference_session, DEVICE)
463
 
464
  x = y = None
465
  if evt is not None:
 
555
 
556
  state.composited_frames.pop(ann_frame_idx, None)
557
 
558
+ state.inference_session = to_device_recursive(state.inference_session, "cpu")
559
+
560
+ return update_frame_display(state, ann_frame_idx), state
561
 
562
 
563
+ @spaces.GPU
564
  def on_text_prompt(
565
  state: AppState,
566
  frame_idx: int,
567
  text_prompt: str,
568
+ ) -> tuple[Image.Image, str, str, AppState]:
569
  if state is None or state.inference_session is None:
570
  return None, "Upload a video and enter text prompt.", "**Active prompts:** None"
571
 
 
574
 
575
  if not text_prompt or not text_prompt.strip():
576
  active_prompts = _get_active_prompts_display(state)
577
+ return update_frame_display(state, int(frame_idx)), "Please enter a text prompt.", active_prompts, state
578
 
579
  frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
580
 
 
582
  prompt_texts = [p.strip() for p in text_prompt.split(",") if p.strip()]
583
  if not prompt_texts:
584
  active_prompts = _get_active_prompts_display(state)
585
+ return update_frame_display(state, int(frame_idx)), "Please enter a valid text prompt.", active_prompts, state
586
+
587
+ state.inference_session = to_device_recursive(state.inference_session, DEVICE)
588
 
589
  # Add text prompt(s) - supports both single string and list of strings
590
  state.inference_session = processor.add_text_prompt(
 
668
  status = f"Processed text prompt(s) {prompts_str} on frame {frame_idx}. No objects detected."
669
 
670
  active_prompts = _get_active_prompts_display(state)
671
+
672
+ state.inference_session = to_device_recursive(state.inference_session, "cpu")
673
+
674
+ return update_frame_display(state, int(frame_idx)), status, active_prompts, state
675
 
676
 
677
  def _get_active_prompts_display(state: AppState) -> str:
 
688
  return "**Active prompts:** None"
689
 
690
 
691
+ @spaces.GPU
692
  def propagate_masks(state: AppState) -> Iterator[tuple[AppState, str, dict]]:
693
  if state is None:
694
  return state, "Load a video first.", gr.update()
 
712
  model = TEXT_VIDEO_MODEL
713
  processor = TEXT_VIDEO_PROCESSOR
714
 
715
+ state.inference_session = to_device_recursive(state.inference_session, DEVICE)
716
+
717
  # Collect all unique prompts from existing frame annotations
718
  text_prompt_to_obj_ids = {}
719
  for frame_idx, frame_texts in state.text_prompts_by_frame_obj.items():
 
733
  text_prompt_to_obj_ids[text_prompt].sort()
734
 
735
  if not text_prompt_to_obj_ids:
736
+ state.inference_session = to_device_recursive(state.inference_session, "cpu")
737
  yield state, "No text prompts found. Please add a text prompt first.", gr.update()
738
  return
739
 
 
801
  last_frame_idx = frame_idx
802
  processed += 1
803
  if processed % 30 == 0 or processed == total:
804
+ state.inference_session = to_device_recursive(state.inference_session, "cpu")
805
  yield state, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
806
+ state.inference_session = to_device_recursive(state.inference_session, DEVICE)
807
  else:
808
  if state.inference_session is None:
809
  yield state, "Tracker model not loaded.", gr.update()
 
812
  model = TRACKER_MODEL
813
  processor = TRACKER_PROCESSOR
814
 
815
+ state.inference_session = to_device_recursive(state.inference_session, DEVICE)
816
+
817
  for sam2_video_output in model.propagate_in_video_iterator(inference_session=state.inference_session):
818
  video_res_masks = processor.post_process_masks(
819
  [sam2_video_output.pred_masks],
 
831
  last_frame_idx = frame_idx
832
  processed += 1
833
  if processed % 30 == 0 or processed == total:
834
+ state.inference_session = to_device_recursive(state.inference_session, "cpu")
835
  yield state, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
836
+ state.inference_session = to_device_recursive(state.inference_session, DEVICE)
837
 
838
  text = f"Propagated masks across {processed} frames."
839
+ state.inference_session = to_device_recursive(state.inference_session, "cpu")
840
  yield state, text, gr.update(value=last_frame_idx)
841
 
842
 
 
1182
  preview_pointbox.select(
1183
  fn=on_image_click,
1184
  inputs=[preview_pointbox, app_state, frame_slider_pointbox, obj_id_inp, label_radio, clear_old_chk],
1185
+ outputs=[preview_pointbox, app_state],
1186
  )
1187
 
 
 
 
 
1188
  text_apply_btn.click(
1189
+ fn=on_text_prompt,
1190
  inputs=[app_state, frame_slider_text, text_prompt_input],
1191
+ outputs=[preview_text, text_status, active_prompts_display, app_state],
1192
  )
1193
 
1194
  reset_prompts_btn.click(