Zhen Ye Claude Opus 4.6 committed on
Commit
5aec47c
·
1 Parent(s): 64f68de

perf: GPU-resident tensor pipeline for SAM2 video propagation

Browse files

Eliminate all CUDA synchronization from propagate_segment() by keeping
masks, bboxes, and validity flags on GPU in pre-allocated buffers.
CPU materialization is deferred to a single bulk transfer via
SegmentOutput.to_object_dicts() at the consumer.

- Add _bbox_gpu() for zero-sync bounding box computation on GPU
- Add SegmentOutput dataclass for GPU-resident segment results
- Rewrite propagate_segment() with inline bbox + pre-allocated tensors
- Refactor process_video() to reuse propagate_segment()
- Update Phase 4 reconciliation: 3 CUDA syncs per segment vs 100+

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. inference.py +12 -8
  2. models/segmenters/grounded_sam2.py +149 -117
inference.py CHANGED
@@ -1643,7 +1643,7 @@ def run_grounded_sam2_tracking(
1643
  from PIL import Image as PILImage
1644
 
1645
  from utils.video import extract_frames_to_jpeg_dir
1646
- from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo
1647
 
1648
  active_segmenter = segmenter_name or "gsam2_large"
1649
  logging.info(
@@ -1816,11 +1816,11 @@ def run_grounded_sam2_tracking(
1816
  label_list=labels,
1817
  )
1818
 
1819
- segment_results = seg.propagate_segment(
1820
  state, start_idx, mask_dict, step,
1821
  )
1822
  seg_queue_out.put(
1823
- (seg_idx, start_idx, mask_dict, segment_results)
1824
  )
1825
  except RuntimeError as e:
1826
  if "cancelled" in str(e).lower():
@@ -1853,8 +1853,8 @@ def run_grounded_sam2_tracking(
1853
  # Collect all segment outputs
1854
  segment_data: Dict[int, Tuple] = {}
1855
  while not seg_queue_out.empty():
1856
- seg_idx, start_idx, mask_dict, results = seg_queue_out.get()
1857
- segment_data[seg_idx] = (start_idx, mask_dict, results)
1858
 
1859
  # Phase 4: Sequential ID reconciliation
1860
  if _perf_metrics is not None:
@@ -1865,12 +1865,13 @@ def run_grounded_sam2_tracking(
1865
  tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
1866
 
1867
  def _mask_to_cpu(mask):
 
1868
  if torch.is_tensor(mask):
1869
  return mask.detach().cpu()
1870
  return mask
1871
 
1872
  for seg_idx in sorted(segment_data.keys()):
1873
- start_idx, mask_dict, segment_results = segment_data[seg_idx]
1874
 
1875
  if mask_dict is None or not mask_dict.labels:
1876
  # No detections — carry forward previous masks
@@ -1882,7 +1883,7 @@ def run_grounded_sam2_tracking(
1882
  {
1883
  k: ObjectInfo(
1884
  instance_id=v.instance_id,
1885
- mask=_mask_to_cpu(v.mask),
1886
  class_name=v.class_name,
1887
  x1=v.x1, y1=v.y1,
1888
  x2=v.x2, y2=v.y2,
@@ -1914,6 +1915,9 @@ def run_grounded_sam2_tracking(
1914
  tracking_results[fi] = {}
1915
  continue
1916
 
 
 
 
1917
  # Apply remapping to every frame in this segment
1918
  for frame_idx, frame_objects in segment_results.items():
1919
  remapped: Dict[int, ObjectInfo] = {}
@@ -1923,7 +1927,7 @@ def run_grounded_sam2_tracking(
1923
  continue
1924
  remapped[global_id] = ObjectInfo(
1925
  instance_id=global_id,
1926
- mask=_mask_to_cpu(obj_info.mask),
1927
  class_name=obj_info.class_name,
1928
  x1=obj_info.x1, y1=obj_info.y1,
1929
  x2=obj_info.x2, y2=obj_info.y2,
 
1643
  from PIL import Image as PILImage
1644
 
1645
  from utils.video import extract_frames_to_jpeg_dir
1646
+ from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo, SegmentOutput
1647
 
1648
  active_segmenter = segmenter_name or "gsam2_large"
1649
  logging.info(
 
1816
  label_list=labels,
1817
  )
1818
 
1819
+ segment_output = seg.propagate_segment(
1820
  state, start_idx, mask_dict, step,
1821
  )
1822
  seg_queue_out.put(
1823
+ (seg_idx, start_idx, mask_dict, segment_output)
1824
  )
1825
  except RuntimeError as e:
1826
  if "cancelled" in str(e).lower():
 
1853
  # Collect all segment outputs
1854
  segment_data: Dict[int, Tuple] = {}
1855
  while not seg_queue_out.empty():
1856
+ seg_idx, start_idx, mask_dict, segment_output = seg_queue_out.get()
1857
+ segment_data[seg_idx] = (start_idx, mask_dict, segment_output)
1858
 
1859
  # Phase 4: Sequential ID reconciliation
1860
  if _perf_metrics is not None:
 
1865
  tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
1866
 
1867
  def _mask_to_cpu(mask):
1868
+ """Normalize mask to CPU tensor (still used for keyframe mask_dict)."""
1869
  if torch.is_tensor(mask):
1870
  return mask.detach().cpu()
1871
  return mask
1872
 
1873
  for seg_idx in sorted(segment_data.keys()):
1874
+ start_idx, mask_dict, segment_output = segment_data[seg_idx]
1875
 
1876
  if mask_dict is None or not mask_dict.labels:
1877
  # No detections — carry forward previous masks
 
1883
  {
1884
  k: ObjectInfo(
1885
  instance_id=v.instance_id,
1886
+ mask=v.mask,
1887
  class_name=v.class_name,
1888
  x1=v.x1, y1=v.y1,
1889
  x2=v.x2, y2=v.y2,
 
1915
  tracking_results[fi] = {}
1916
  continue
1917
 
1918
+ # Bulk CPU transfer: 3 CUDA syncs total (was 100+ per-mask syncs)
1919
+ segment_results = segment_output.to_object_dicts()
1920
+
1921
  # Apply remapping to every frame in this segment
1922
  for frame_idx, frame_objects in segment_results.items():
1923
  remapped: Dict[int, ObjectInfo] = {}
 
1927
  continue
1928
  remapped[global_id] = ObjectInfo(
1929
  instance_id=global_id,
1930
+ mask=obj_info.mask,
1931
  class_name=obj_info.class_name,
1932
  x1=obj_info.x1, y1=obj_info.y1,
1933
  x2=obj_info.x2, y2=obj_info.y2,
models/segmenters/grounded_sam2.py CHANGED
@@ -13,7 +13,7 @@ import logging
13
  import time
14
  from contextlib import nullcontext
15
  from dataclasses import dataclass, field
16
- from typing import Any, Dict, List, Optional, Sequence, Tuple
17
 
18
  import numpy as np
19
  import torch
@@ -220,6 +220,72 @@ class MaskDictionary:
220
  return float((inter / union).item())
221
 
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  # ---------------------------------------------------------------------------
224
  # SAM2 HuggingFace model IDs per size
225
  # ---------------------------------------------------------------------------
@@ -466,21 +532,11 @@ class GroundedSAM2Segmenter(Segmenter):
466
  start_idx: int,
467
  mask_dict: "MaskDictionary",
468
  step: int,
469
- ) -> Dict[int, Dict[int, "ObjectInfo"]]:
470
  """Propagate masks for a single segment via SAM2 video predictor.
471
 
472
- Calls ``reset_state`` first, making this safe to call independently
473
- (and therefore parallelisable across GPUs).
474
-
475
- Args:
476
- inference_state: SAM2 video predictor state (from ``init_state``).
477
- start_idx: Starting frame index for this segment.
478
- mask_dict: MaskDictionary with object masks for the keyframe.
479
- step: Maximum number of frames to propagate.
480
-
481
- Returns:
482
- Dict mapping ``frame_idx`` → ``{obj_id: ObjectInfo}`` using the
483
- IDs from *mask_dict* (local, not yet reconciled).
484
  """
485
  _pm = getattr(self, '_perf_metrics', None)
486
  if _pm is not None:
@@ -490,53 +546,72 @@ class GroundedSAM2Segmenter(Segmenter):
490
 
491
  for obj_id, obj_info in mask_dict.labels.items():
492
  self._video_predictor.add_new_mask(
493
- inference_state,
494
- start_idx,
495
- obj_id,
496
- obj_info.mask,
497
  )
498
 
499
- segment_results: Dict[int, Dict[int, ObjectInfo]] = {}
 
 
 
 
 
 
 
 
 
500
 
501
- # Phase A: Drain generator — GPU ops only, zero CUDA syncs
502
- raw_frames: list = []
503
  for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
504
- inference_state,
505
- max_frame_num_to_track=step,
506
- start_frame_idx=start_idx,
507
  ):
508
- bool_masks = (out_mask_logits[:, 0] > 0.0) # (N_obj, H, W) bool, GPU
509
- raw_frames.append((out_frame_idx, list(out_obj_ids), bool_masks))
510
-
511
- # Phase B: Batched bbox + ObjectInfo construction — 2 CUDA syncs total
512
- if raw_frames:
513
- entries: list = []
514
- all_masks: list = []
515
- for frame_idx, obj_ids, bool_masks in raw_frames:
516
- for i, obj_id in enumerate(obj_ids):
517
- entries.append((frame_idx, obj_id, mask_dict.get_target_class_name(obj_id)))
518
- all_masks.append(bool_masks[i])
519
-
520
- if all_masks:
521
- stacked = torch.stack(all_masks)
522
- bboxes_cpu, valid_cpu = ObjectInfo.batch_bbox(stacked)
523
- del stacked
524
-
525
- bboxes_list = bboxes_cpu.tolist()
526
- valid_list = valid_cpu.tolist()
527
-
528
- for idx, (frame_idx, obj_id, class_name) in enumerate(entries):
529
- if valid_list[idx]:
530
- x1, y1, x2, y2 = int(bboxes_list[idx][0]), int(bboxes_list[idx][1]), int(bboxes_list[idx][2]), int(bboxes_list[idx][3])
531
- else:
532
- x1 = y1 = x2 = y2 = 0
533
- info = ObjectInfo(
534
- instance_id=obj_id,
535
- mask=all_masks[idx],
536
- class_name=class_name,
537
- x1=x1, y1=y1, x2=x2, y2=y2,
538
- )
539
- segment_results.setdefault(frame_idx, {})[obj_id] = info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
 
541
  if _pm is not None:
542
  _pl = getattr(self, '_perf_lock', None)
@@ -546,7 +621,7 @@ class GroundedSAM2Segmenter(Segmenter):
546
  else:
547
  _pm["sam_video_total_ms"] += _d
548
 
549
- return segment_results
550
 
551
  # -- Video-level tracking interface -------------------------------------
552
 
@@ -721,66 +796,23 @@ class GroundedSAM2Segmenter(Segmenter):
721
  if _pm is not None:
722
  _t_sv = time.perf_counter()
723
 
724
- self._video_predictor.reset_state(inference_state)
725
-
726
- for obj_id, obj_info in mask_dict.labels.items():
727
- self._video_predictor.add_new_mask(
728
- inference_state,
729
- start_idx,
730
- obj_id,
731
- obj_info.mask,
732
- )
733
-
734
- # Phase A: Drain generator — GPU ops only, zero CUDA syncs
735
- raw_frames: list = []
736
- for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
737
- inference_state,
738
- max_frame_num_to_track=step,
739
- start_frame_idx=start_idx,
740
- ):
741
- bool_masks = (out_mask_logits[:, 0] > 0.0) # (N_obj, H, W) bool, GPU
742
- raw_frames.append((out_frame_idx, list(out_obj_ids), bool_masks))
743
-
744
- # Phase B: Batched bbox + ObjectInfo construction — 2 CUDA syncs total
745
- if raw_frames:
746
- entries: list = []
747
- all_masks: list = []
748
- for frame_idx, obj_ids, bool_masks in raw_frames:
749
- for i, obj_id in enumerate(obj_ids):
750
- entries.append((frame_idx, obj_id, mask_dict.get_target_class_name(obj_id)))
751
- all_masks.append(bool_masks[i])
752
-
753
- if all_masks:
754
- stacked = torch.stack(all_masks)
755
- bboxes_cpu, valid_cpu = ObjectInfo.batch_bbox(stacked)
756
- del stacked
757
-
758
- bboxes_list = bboxes_cpu.tolist()
759
- valid_list = valid_cpu.tolist()
760
-
761
- for idx, (frame_idx, obj_id, class_name) in enumerate(entries):
762
- if valid_list[idx]:
763
- x1, y1, x2, y2 = int(bboxes_list[idx][0]), int(bboxes_list[idx][1]), int(bboxes_list[idx][2]), int(bboxes_list[idx][3])
764
- else:
765
- x1 = y1 = x2 = y2 = 0
766
- info = ObjectInfo(
767
- instance_id=obj_id,
768
- mask=all_masks[idx],
769
- class_name=class_name,
770
- x1=x1, y1=y1, x2=x2, y2=y2,
771
- )
772
- all_results.setdefault(frame_idx, {})[obj_id] = info
773
-
774
- # deepcopy ONLY the last frame (was running every frame before)
775
- last_frame_idx = raw_frames[-1][0]
776
- last_frame_objects = all_results.get(last_frame_idx, {})
777
- sam2_masks = MaskDictionary()
778
- sam2_masks.labels = copy.deepcopy(last_frame_objects)
779
- if last_frame_objects:
780
- first_info = next(iter(last_frame_objects.values()))
781
- if first_info.mask is not None:
782
- sam2_masks.mask_height = first_info.mask.shape[-2] if first_info.mask.ndim >= 2 else 0
783
- sam2_masks.mask_width = first_info.mask.shape[-1] if first_info.mask.ndim >= 2 else 0
784
 
785
  if _pm is not None:
786
  _pl = getattr(self, '_perf_lock', None)
 
13
  import time
14
  from contextlib import nullcontext
15
  from dataclasses import dataclass, field
16
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
17
 
18
  import numpy as np
19
  import torch
 
220
  return float((inter / union).item())
221
 
222
 
223
+ # ---------------------------------------------------------------------------
224
+ # GPU-resident bounding-box helper (zero CUDA syncs)
225
+ # ---------------------------------------------------------------------------
226
+
227
+ def _bbox_gpu(bool_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
228
+ """Compute bboxes from (N, H, W) bool GPU masks. Returns GPU tensors, zero sync.
229
+
230
+ Returns:
231
+ bboxes: (N, 4) int64 on same device as input [x_min, y_min, x_max, y_max]
232
+ valid: (N,) bool on same device as input
233
+ """
234
+ N, H, W = bool_masks.shape
235
+ rows = bool_masks.any(dim=2) # (N, H)
236
+ cols = bool_masks.any(dim=1) # (N, W)
237
+ valid = rows.any(dim=1) # (N,)
238
+ rows_f = rows.float()
239
+ cols_f = cols.float()
240
+ bboxes = torch.stack([
241
+ cols_f.argmax(dim=1), # x_min
242
+ rows_f.argmax(dim=1), # y_min
243
+ W - 1 - cols_f.flip(1).argmax(dim=1), # x_max
244
+ H - 1 - rows_f.flip(1).argmax(dim=1), # y_max
245
+ ], dim=1).to(torch.int64) # (N, 4) int64
246
+ return bboxes, valid
247
+
248
+
249
+ # ---------------------------------------------------------------------------
250
+ # GPU-resident segment output (deferred CPU materialization)
251
+ # ---------------------------------------------------------------------------
252
+
253
+ @dataclass
254
+ class SegmentOutput:
255
+ """GPU-resident segment propagation result. Zero CUDA syncs to construct."""
256
+ masks: torch.Tensor # (count, H, W) bool on GPU
257
+ bboxes: torch.Tensor # (count, 4) int64 on GPU
258
+ valid: torch.Tensor # (count,) bool on GPU
259
+ frame_indices: List[int] # len == count
260
+ obj_ids: List[int] # len == count
261
+ class_names: List[str] # len == count
262
+ device: str = "cpu"
263
+
264
+ def to_object_dicts(self) -> Dict[int, Dict[int, "ObjectInfo"]]:
265
+ """Bulk CPU transfer + ObjectInfo construction. 3 CUDA syncs total."""
266
+ if self.masks.numel() == 0:
267
+ return {}
268
+ masks_cpu = self.masks.cpu() # sync 1
269
+ bboxes_cpu = self.bboxes.cpu() # sync 2
270
+ valid_cpu = self.valid.cpu() # sync 3
271
+ result: Dict[int, Dict[int, ObjectInfo]] = {}
272
+ for i in range(masks_cpu.shape[0]):
273
+ fi, oid, cn = self.frame_indices[i], self.obj_ids[i], self.class_names[i]
274
+ if valid_cpu[i]:
275
+ x1, y1, x2, y2 = int(bboxes_cpu[i, 0]), int(bboxes_cpu[i, 1]), int(bboxes_cpu[i, 2]), int(bboxes_cpu[i, 3])
276
+ else:
277
+ x1 = y1 = x2 = y2 = 0
278
+ info = ObjectInfo(
279
+ instance_id=oid, mask=masks_cpu[i],
280
+ class_name=cn, x1=x1, y1=y1, x2=x2, y2=y2,
281
+ )
282
+ result.setdefault(fi, {})[oid] = info
283
+ return result
284
+
285
+ def last_frame_idx(self) -> Optional[int]:
286
+ return self.frame_indices[-1] if self.frame_indices else None
287
+
288
+
289
  # ---------------------------------------------------------------------------
290
  # SAM2 HuggingFace model IDs per size
291
  # ---------------------------------------------------------------------------
 
532
  start_idx: int,
533
  mask_dict: "MaskDictionary",
534
  step: int,
535
+ ) -> "SegmentOutput":
536
  """Propagate masks for a single segment via SAM2 video predictor.
537
 
538
+ Returns a GPU-resident ``SegmentOutput`` with zero CUDA syncs.
539
+ Call ``output.to_object_dicts()`` to materialize CPU ObjectInfo dicts.
 
 
 
 
 
 
 
 
 
 
540
  """
541
  _pm = getattr(self, '_perf_metrics', None)
542
  if _pm is not None:
 
546
 
547
  for obj_id, obj_info in mask_dict.labels.items():
548
  self._video_predictor.add_new_mask(
549
+ inference_state, start_idx, obj_id, obj_info.mask,
 
 
 
550
  )
551
 
552
+ # Pre-compute class name lookup (avoid repeated dict access in loop)
553
+ obj_id_to_class = {oid: mask_dict.get_target_class_name(oid) for oid in mask_dict.labels}
554
+ n_obj = len(mask_dict.labels)
555
+
556
+ # Pre-allocated GPU buffers (allocated on first yield when H, W known)
557
+ masks_buf = bboxes_buf = valid_buf = None
558
+ frame_indices: List[int] = []
559
+ obj_ids_list: List[int] = []
560
+ class_names_list: List[str] = []
561
+ cursor = 0
562
 
 
 
563
  for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
564
+ inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
 
 
565
  ):
566
+ bool_masks = (out_mask_logits[:, 0] > 0.0) # (N, H, W) GPU async
567
+ n = bool_masks.shape[0]
568
+
569
+ # Allocate on first yield
570
+ if masks_buf is None:
571
+ H, W = bool_masks.shape[1], bool_masks.shape[2]
572
+ max_entries = step * max(n_obj, n)
573
+ masks_buf = torch.empty(max_entries, H, W, dtype=torch.bool, device=self.device)
574
+ bboxes_buf = torch.empty(max_entries, 4, dtype=torch.int64, device=self.device)
575
+ valid_buf = torch.empty(max_entries, dtype=torch.bool, device=self.device)
576
+
577
+ # Grow buffers if needed (unlikely but safe)
578
+ if cursor + n > masks_buf.shape[0]:
579
+ grow = max(step * n_obj, cursor + n - masks_buf.shape[0])
580
+ H, W = masks_buf.shape[1], masks_buf.shape[2]
581
+ masks_buf = torch.cat([masks_buf, torch.empty(grow, H, W, dtype=torch.bool, device=self.device)])
582
+ bboxes_buf = torch.cat([bboxes_buf, torch.empty(grow, 4, dtype=torch.int64, device=self.device)])
583
+ valid_buf = torch.cat([valid_buf, torch.empty(grow, dtype=torch.bool, device=self.device)])
584
+
585
+ # Inline bbox — GPU async, zero sync
586
+ frame_bboxes, frame_valid = _bbox_gpu(bool_masks)
587
+
588
+ # Fill pre-allocated slices GPU async
589
+ masks_buf[cursor:cursor + n] = bool_masks
590
+ bboxes_buf[cursor:cursor + n] = frame_bboxes
591
+ valid_buf[cursor:cursor + n] = frame_valid
592
+
593
+ # Metadata (trivial Python, ~2μs GIL)
594
+ oid_list = list(out_obj_ids) if not isinstance(out_obj_ids, list) else out_obj_ids
595
+ for oid in oid_list:
596
+ frame_indices.append(out_frame_idx)
597
+ obj_ids_list.append(oid)
598
+ class_names_list.append(obj_id_to_class.get(oid, ""))
599
+ cursor += n
600
+
601
+ # Build output (zero-copy slice if under-filled, empty tensors if no frames)
602
+ if masks_buf is not None:
603
+ output = SegmentOutput(
604
+ masks=masks_buf[:cursor], bboxes=bboxes_buf[:cursor],
605
+ valid=valid_buf[:cursor], frame_indices=frame_indices,
606
+ obj_ids=obj_ids_list, class_names=class_names_list, device=self.device,
607
+ )
608
+ else:
609
+ output = SegmentOutput(
610
+ masks=torch.empty(0, 0, 0, dtype=torch.bool, device=self.device),
611
+ bboxes=torch.empty(0, 4, dtype=torch.int64, device=self.device),
612
+ valid=torch.empty(0, dtype=torch.bool, device=self.device),
613
+ frame_indices=[], obj_ids=[], class_names=[], device=self.device,
614
+ )
615
 
616
  if _pm is not None:
617
  _pl = getattr(self, '_perf_lock', None)
 
621
  else:
622
  _pm["sam_video_total_ms"] += _d
623
 
624
+ return output
625
 
626
  # -- Video-level tracking interface -------------------------------------
627
 
 
796
  if _pm is not None:
797
  _t_sv = time.perf_counter()
798
 
799
+ segment_output = self.propagate_segment(
800
+ inference_state, start_idx, mask_dict, step,
801
+ )
802
+ segment_results = segment_output.to_object_dicts()
803
+
804
+ if segment_results:
805
+ all_results.update(segment_results)
806
+ last_fi = segment_output.last_frame_idx()
807
+ if last_fi is not None:
808
+ last_frame_objects = all_results.get(last_fi, {})
809
+ sam2_masks = MaskDictionary()
810
+ sam2_masks.labels = copy.deepcopy(last_frame_objects)
811
+ if last_frame_objects:
812
+ first_info = next(iter(last_frame_objects.values()))
813
+ if first_info.mask is not None:
814
+ sam2_masks.mask_height = first_info.mask.shape[-2] if first_info.mask.ndim >= 2 else 0
815
+ sam2_masks.mask_width = first_info.mask.shape[-1] if first_info.mask.ndim >= 2 else 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
816
 
817
  if _pm is not None:
818
  _pl = getattr(self, '_perf_lock', None)