Spaces:
Paused
Paused
Zhen Ye Claude Opus 4.6 committed on
Commit ·
5aec47c
1
Parent(s): 64f68de
perf: GPU-resident tensor pipeline for SAM2 video propagation
Browse files
Eliminate all CUDA synchronization from propagate_segment() by keeping
masks, bboxes, and validity flags on GPU in pre-allocated buffers.
CPU materialization is deferred to a single bulk transfer via
SegmentOutput.to_object_dicts() at the consumer.
- Add _bbox_gpu() for zero-sync bounding box computation on GPU
- Add SegmentOutput dataclass for GPU-resident segment results
- Rewrite propagate_segment() with inline bbox + pre-allocated tensors
- Refactor process_video() to reuse propagate_segment()
- Update Phase 4 reconciliation: 3 CUDA syncs per segment vs 100+
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- inference.py +12 -8
- models/segmenters/grounded_sam2.py +149 -117
inference.py
CHANGED
|
@@ -1643,7 +1643,7 @@ def run_grounded_sam2_tracking(
|
|
| 1643 |
from PIL import Image as PILImage
|
| 1644 |
|
| 1645 |
from utils.video import extract_frames_to_jpeg_dir
|
| 1646 |
-
from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo
|
| 1647 |
|
| 1648 |
active_segmenter = segmenter_name or "gsam2_large"
|
| 1649 |
logging.info(
|
|
@@ -1816,11 +1816,11 @@ def run_grounded_sam2_tracking(
|
|
| 1816 |
label_list=labels,
|
| 1817 |
)
|
| 1818 |
|
| 1819 |
-
|
| 1820 |
state, start_idx, mask_dict, step,
|
| 1821 |
)
|
| 1822 |
seg_queue_out.put(
|
| 1823 |
-
(seg_idx, start_idx, mask_dict,
|
| 1824 |
)
|
| 1825 |
except RuntimeError as e:
|
| 1826 |
if "cancelled" in str(e).lower():
|
|
@@ -1853,8 +1853,8 @@ def run_grounded_sam2_tracking(
|
|
| 1853 |
# Collect all segment outputs
|
| 1854 |
segment_data: Dict[int, Tuple] = {}
|
| 1855 |
while not seg_queue_out.empty():
|
| 1856 |
-
seg_idx, start_idx, mask_dict,
|
| 1857 |
-
segment_data[seg_idx] = (start_idx, mask_dict,
|
| 1858 |
|
| 1859 |
# Phase 4: Sequential ID reconciliation
|
| 1860 |
if _perf_metrics is not None:
|
|
@@ -1865,12 +1865,13 @@ def run_grounded_sam2_tracking(
|
|
| 1865 |
tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
|
| 1866 |
|
| 1867 |
def _mask_to_cpu(mask):
|
|
|
|
| 1868 |
if torch.is_tensor(mask):
|
| 1869 |
return mask.detach().cpu()
|
| 1870 |
return mask
|
| 1871 |
|
| 1872 |
for seg_idx in sorted(segment_data.keys()):
|
| 1873 |
-
start_idx, mask_dict,
|
| 1874 |
|
| 1875 |
if mask_dict is None or not mask_dict.labels:
|
| 1876 |
# No detections — carry forward previous masks
|
|
@@ -1882,7 +1883,7 @@ def run_grounded_sam2_tracking(
|
|
| 1882 |
{
|
| 1883 |
k: ObjectInfo(
|
| 1884 |
instance_id=v.instance_id,
|
| 1885 |
-
mask=
|
| 1886 |
class_name=v.class_name,
|
| 1887 |
x1=v.x1, y1=v.y1,
|
| 1888 |
x2=v.x2, y2=v.y2,
|
|
@@ -1914,6 +1915,9 @@ def run_grounded_sam2_tracking(
|
|
| 1914 |
tracking_results[fi] = {}
|
| 1915 |
continue
|
| 1916 |
|
|
|
|
|
|
|
|
|
|
| 1917 |
# Apply remapping to every frame in this segment
|
| 1918 |
for frame_idx, frame_objects in segment_results.items():
|
| 1919 |
remapped: Dict[int, ObjectInfo] = {}
|
|
@@ -1923,7 +1927,7 @@ def run_grounded_sam2_tracking(
|
|
| 1923 |
continue
|
| 1924 |
remapped[global_id] = ObjectInfo(
|
| 1925 |
instance_id=global_id,
|
| 1926 |
-
mask=
|
| 1927 |
class_name=obj_info.class_name,
|
| 1928 |
x1=obj_info.x1, y1=obj_info.y1,
|
| 1929 |
x2=obj_info.x2, y2=obj_info.y2,
|
|
|
|
| 1643 |
from PIL import Image as PILImage
|
| 1644 |
|
| 1645 |
from utils.video import extract_frames_to_jpeg_dir
|
| 1646 |
+
from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo, SegmentOutput
|
| 1647 |
|
| 1648 |
active_segmenter = segmenter_name or "gsam2_large"
|
| 1649 |
logging.info(
|
|
|
|
| 1816 |
label_list=labels,
|
| 1817 |
)
|
| 1818 |
|
| 1819 |
+
segment_output = seg.propagate_segment(
|
| 1820 |
state, start_idx, mask_dict, step,
|
| 1821 |
)
|
| 1822 |
seg_queue_out.put(
|
| 1823 |
+
(seg_idx, start_idx, mask_dict, segment_output)
|
| 1824 |
)
|
| 1825 |
except RuntimeError as e:
|
| 1826 |
if "cancelled" in str(e).lower():
|
|
|
|
| 1853 |
# Collect all segment outputs
|
| 1854 |
segment_data: Dict[int, Tuple] = {}
|
| 1855 |
while not seg_queue_out.empty():
|
| 1856 |
+
seg_idx, start_idx, mask_dict, segment_output = seg_queue_out.get()
|
| 1857 |
+
segment_data[seg_idx] = (start_idx, mask_dict, segment_output)
|
| 1858 |
|
| 1859 |
# Phase 4: Sequential ID reconciliation
|
| 1860 |
if _perf_metrics is not None:
|
|
|
|
| 1865 |
tracking_results: Dict[int, Dict[int, ObjectInfo]] = {}
|
| 1866 |
|
| 1867 |
def _mask_to_cpu(mask):
|
| 1868 |
+
"""Normalize mask to CPU tensor (still used for keyframe mask_dict)."""
|
| 1869 |
if torch.is_tensor(mask):
|
| 1870 |
return mask.detach().cpu()
|
| 1871 |
return mask
|
| 1872 |
|
| 1873 |
for seg_idx in sorted(segment_data.keys()):
|
| 1874 |
+
start_idx, mask_dict, segment_output = segment_data[seg_idx]
|
| 1875 |
|
| 1876 |
if mask_dict is None or not mask_dict.labels:
|
| 1877 |
# No detections — carry forward previous masks
|
|
|
|
| 1883 |
{
|
| 1884 |
k: ObjectInfo(
|
| 1885 |
instance_id=v.instance_id,
|
| 1886 |
+
mask=v.mask,
|
| 1887 |
class_name=v.class_name,
|
| 1888 |
x1=v.x1, y1=v.y1,
|
| 1889 |
x2=v.x2, y2=v.y2,
|
|
|
|
| 1915 |
tracking_results[fi] = {}
|
| 1916 |
continue
|
| 1917 |
|
| 1918 |
+
# Bulk CPU transfer: 3 CUDA syncs total (was 100+ per-mask syncs)
|
| 1919 |
+
segment_results = segment_output.to_object_dicts()
|
| 1920 |
+
|
| 1921 |
# Apply remapping to every frame in this segment
|
| 1922 |
for frame_idx, frame_objects in segment_results.items():
|
| 1923 |
remapped: Dict[int, ObjectInfo] = {}
|
|
|
|
| 1927 |
continue
|
| 1928 |
remapped[global_id] = ObjectInfo(
|
| 1929 |
instance_id=global_id,
|
| 1930 |
+
mask=obj_info.mask,
|
| 1931 |
class_name=obj_info.class_name,
|
| 1932 |
x1=obj_info.x1, y1=obj_info.y1,
|
| 1933 |
x2=obj_info.x2, y2=obj_info.y2,
|
models/segmenters/grounded_sam2.py
CHANGED
|
@@ -13,7 +13,7 @@ import logging
|
|
| 13 |
import time
|
| 14 |
from contextlib import nullcontext
|
| 15 |
from dataclasses import dataclass, field
|
| 16 |
-
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 17 |
|
| 18 |
import numpy as np
|
| 19 |
import torch
|
|
@@ -220,6 +220,72 @@ class MaskDictionary:
|
|
| 220 |
return float((inter / union).item())
|
| 221 |
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
# ---------------------------------------------------------------------------
|
| 224 |
# SAM2 HuggingFace model IDs per size
|
| 225 |
# ---------------------------------------------------------------------------
|
|
@@ -466,21 +532,11 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 466 |
start_idx: int,
|
| 467 |
mask_dict: "MaskDictionary",
|
| 468 |
step: int,
|
| 469 |
-
) ->
|
| 470 |
"""Propagate masks for a single segment via SAM2 video predictor.
|
| 471 |
|
| 472 |
-
|
| 473 |
-
(
|
| 474 |
-
|
| 475 |
-
Args:
|
| 476 |
-
inference_state: SAM2 video predictor state (from ``init_state``).
|
| 477 |
-
start_idx: Starting frame index for this segment.
|
| 478 |
-
mask_dict: MaskDictionary with object masks for the keyframe.
|
| 479 |
-
step: Maximum number of frames to propagate.
|
| 480 |
-
|
| 481 |
-
Returns:
|
| 482 |
-
Dict mapping ``frame_idx`` → ``{obj_id: ObjectInfo}`` using the
|
| 483 |
-
IDs from *mask_dict* (local, not yet reconciled).
|
| 484 |
"""
|
| 485 |
_pm = getattr(self, '_perf_metrics', None)
|
| 486 |
if _pm is not None:
|
|
@@ -490,53 +546,72 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 490 |
|
| 491 |
for obj_id, obj_info in mask_dict.labels.items():
|
| 492 |
self._video_predictor.add_new_mask(
|
| 493 |
-
inference_state,
|
| 494 |
-
start_idx,
|
| 495 |
-
obj_id,
|
| 496 |
-
obj_info.mask,
|
| 497 |
)
|
| 498 |
|
| 499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
|
| 501 |
-
# Phase A: Drain generator — GPU ops only, zero CUDA syncs
|
| 502 |
-
raw_frames: list = []
|
| 503 |
for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
|
| 504 |
-
inference_state,
|
| 505 |
-
max_frame_num_to_track=step,
|
| 506 |
-
start_frame_idx=start_idx,
|
| 507 |
):
|
| 508 |
-
bool_masks = (out_mask_logits[:, 0] > 0.0)
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
if
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 540 |
|
| 541 |
if _pm is not None:
|
| 542 |
_pl = getattr(self, '_perf_lock', None)
|
|
@@ -546,7 +621,7 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 546 |
else:
|
| 547 |
_pm["sam_video_total_ms"] += _d
|
| 548 |
|
| 549 |
-
return
|
| 550 |
|
| 551 |
# -- Video-level tracking interface -------------------------------------
|
| 552 |
|
|
@@ -721,66 +796,23 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 721 |
if _pm is not None:
|
| 722 |
_t_sv = time.perf_counter()
|
| 723 |
|
| 724 |
-
self.
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
bool_masks = (out_mask_logits[:, 0] > 0.0) # (N_obj, H, W) bool, GPU
|
| 742 |
-
raw_frames.append((out_frame_idx, list(out_obj_ids), bool_masks))
|
| 743 |
-
|
| 744 |
-
# Phase B: Batched bbox + ObjectInfo construction — 2 CUDA syncs total
|
| 745 |
-
if raw_frames:
|
| 746 |
-
entries: list = []
|
| 747 |
-
all_masks: list = []
|
| 748 |
-
for frame_idx, obj_ids, bool_masks in raw_frames:
|
| 749 |
-
for i, obj_id in enumerate(obj_ids):
|
| 750 |
-
entries.append((frame_idx, obj_id, mask_dict.get_target_class_name(obj_id)))
|
| 751 |
-
all_masks.append(bool_masks[i])
|
| 752 |
-
|
| 753 |
-
if all_masks:
|
| 754 |
-
stacked = torch.stack(all_masks)
|
| 755 |
-
bboxes_cpu, valid_cpu = ObjectInfo.batch_bbox(stacked)
|
| 756 |
-
del stacked
|
| 757 |
-
|
| 758 |
-
bboxes_list = bboxes_cpu.tolist()
|
| 759 |
-
valid_list = valid_cpu.tolist()
|
| 760 |
-
|
| 761 |
-
for idx, (frame_idx, obj_id, class_name) in enumerate(entries):
|
| 762 |
-
if valid_list[idx]:
|
| 763 |
-
x1, y1, x2, y2 = int(bboxes_list[idx][0]), int(bboxes_list[idx][1]), int(bboxes_list[idx][2]), int(bboxes_list[idx][3])
|
| 764 |
-
else:
|
| 765 |
-
x1 = y1 = x2 = y2 = 0
|
| 766 |
-
info = ObjectInfo(
|
| 767 |
-
instance_id=obj_id,
|
| 768 |
-
mask=all_masks[idx],
|
| 769 |
-
class_name=class_name,
|
| 770 |
-
x1=x1, y1=y1, x2=x2, y2=y2,
|
| 771 |
-
)
|
| 772 |
-
all_results.setdefault(frame_idx, {})[obj_id] = info
|
| 773 |
-
|
| 774 |
-
# deepcopy ONLY the last frame (was running every frame before)
|
| 775 |
-
last_frame_idx = raw_frames[-1][0]
|
| 776 |
-
last_frame_objects = all_results.get(last_frame_idx, {})
|
| 777 |
-
sam2_masks = MaskDictionary()
|
| 778 |
-
sam2_masks.labels = copy.deepcopy(last_frame_objects)
|
| 779 |
-
if last_frame_objects:
|
| 780 |
-
first_info = next(iter(last_frame_objects.values()))
|
| 781 |
-
if first_info.mask is not None:
|
| 782 |
-
sam2_masks.mask_height = first_info.mask.shape[-2] if first_info.mask.ndim >= 2 else 0
|
| 783 |
-
sam2_masks.mask_width = first_info.mask.shape[-1] if first_info.mask.ndim >= 2 else 0
|
| 784 |
|
| 785 |
if _pm is not None:
|
| 786 |
_pl = getattr(self, '_perf_lock', None)
|
|
|
|
| 13 |
import time
|
| 14 |
from contextlib import nullcontext
|
| 15 |
from dataclasses import dataclass, field
|
| 16 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
|
| 17 |
|
| 18 |
import numpy as np
|
| 19 |
import torch
|
|
|
|
| 220 |
return float((inter / union).item())
|
| 221 |
|
| 222 |
|
| 223 |
+
# ---------------------------------------------------------------------------
|
| 224 |
+
# GPU-resident bounding-box helper (zero CUDA syncs)
|
| 225 |
+
# ---------------------------------------------------------------------------
|
| 226 |
+
|
| 227 |
+
def _bbox_gpu(bool_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 228 |
+
"""Compute bboxes from (N, H, W) bool GPU masks. Returns GPU tensors, zero sync.
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
bboxes: (N, 4) int64 on same device as input [x_min, y_min, x_max, y_max]
|
| 232 |
+
valid: (N,) bool on same device as input
|
| 233 |
+
"""
|
| 234 |
+
N, H, W = bool_masks.shape
|
| 235 |
+
rows = bool_masks.any(dim=2) # (N, H)
|
| 236 |
+
cols = bool_masks.any(dim=1) # (N, W)
|
| 237 |
+
valid = rows.any(dim=1) # (N,)
|
| 238 |
+
rows_f = rows.float()
|
| 239 |
+
cols_f = cols.float()
|
| 240 |
+
bboxes = torch.stack([
|
| 241 |
+
cols_f.argmax(dim=1), # x_min
|
| 242 |
+
rows_f.argmax(dim=1), # y_min
|
| 243 |
+
W - 1 - cols_f.flip(1).argmax(dim=1), # x_max
|
| 244 |
+
H - 1 - rows_f.flip(1).argmax(dim=1), # y_max
|
| 245 |
+
], dim=1).to(torch.int64) # (N, 4) int64
|
| 246 |
+
return bboxes, valid
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# ---------------------------------------------------------------------------
|
| 250 |
+
# GPU-resident segment output (deferred CPU materialization)
|
| 251 |
+
# ---------------------------------------------------------------------------
|
| 252 |
+
|
| 253 |
+
@dataclass
class SegmentOutput:
    """Result of propagating one segment, kept as batched tensors.

    Entry ``i`` of every field describes the same (frame, object) pair.
    Nothing is copied to the host until ``to_object_dicts()`` is called.
    """
    masks: torch.Tensor          # (count, H, W) bool on GPU
    bboxes: torch.Tensor         # (count, 4) int64 on GPU
    valid: torch.Tensor          # (count,) bool on GPU
    frame_indices: List[int]     # len == count
    obj_ids: List[int]           # len == count
    class_names: List[str]       # len == count
    device: str = "cpu"

    def to_object_dicts(self) -> Dict[int, Dict[int, "ObjectInfo"]]:
        """Materialize the segment as ``{frame_idx: {obj_id: ObjectInfo}}``.

        Performs three bulk ``.cpu()`` transfers (masks, bboxes, valid) and
        then builds one ``ObjectInfo`` per entry. Entries whose ``valid``
        flag is False get an all-zero bounding box.

        Returns:
            An empty dict when there are no mask elements; otherwise a dict
            keyed by frame index whose values map object id to ``ObjectInfo``.
            Each ``ObjectInfo.mask`` is a row of the CPU mask tensor.
        """
        if self.masks.numel() == 0:
            return {}

        masks_cpu = self.masks.cpu()               # transfer 1
        # Convert the small tensors to plain Python lists once, instead of
        # indexing the tensors (and allocating a scalar tensor) per entry.
        bboxes_list = self.bboxes.cpu().tolist()   # transfer 2
        valid_list = self.valid.cpu().tolist()     # transfer 3

        result: Dict[int, Dict[int, ObjectInfo]] = {}
        for i, mask in enumerate(masks_cpu):
            fi = self.frame_indices[i]
            oid = self.obj_ids[i]
            cn = self.class_names[i]
            if valid_list[i]:
                x1, y1, x2, y2 = bboxes_list[i]
            else:
                x1 = y1 = x2 = y2 = 0
            info = ObjectInfo(
                instance_id=oid, mask=mask,
                class_name=cn, x1=x1, y1=y1, x2=x2, y2=y2,
            )
            result.setdefault(fi, {})[oid] = info
        return result

    def last_frame_idx(self) -> Optional[int]:
        """Return the frame index of the final entry, or None if there are none."""
        return self.frame_indices[-1] if self.frame_indices else None
|
| 287 |
+
|
| 288 |
+
|
| 289 |
# ---------------------------------------------------------------------------
|
| 290 |
# SAM2 HuggingFace model IDs per size
|
| 291 |
# ---------------------------------------------------------------------------
|
|
|
|
| 532 |
start_idx: int,
|
| 533 |
mask_dict: "MaskDictionary",
|
| 534 |
step: int,
|
| 535 |
+
) -> "SegmentOutput":
|
| 536 |
"""Propagate masks for a single segment via SAM2 video predictor.
|
| 537 |
|
| 538 |
+
Returns a GPU-resident ``SegmentOutput`` with zero CUDA syncs.
|
| 539 |
+
Call ``output.to_object_dicts()`` to materialize CPU ObjectInfo dicts.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 540 |
"""
|
| 541 |
_pm = getattr(self, '_perf_metrics', None)
|
| 542 |
if _pm is not None:
|
|
|
|
| 546 |
|
| 547 |
for obj_id, obj_info in mask_dict.labels.items():
|
| 548 |
self._video_predictor.add_new_mask(
|
| 549 |
+
inference_state, start_idx, obj_id, obj_info.mask,
|
|
|
|
|
|
|
|
|
|
| 550 |
)
|
| 551 |
|
| 552 |
+
# Pre-compute class name lookup (avoid repeated dict access in loop)
|
| 553 |
+
obj_id_to_class = {oid: mask_dict.get_target_class_name(oid) for oid in mask_dict.labels}
|
| 554 |
+
n_obj = len(mask_dict.labels)
|
| 555 |
+
|
| 556 |
+
# Pre-allocated GPU buffers (allocated on first yield when H, W known)
|
| 557 |
+
masks_buf = bboxes_buf = valid_buf = None
|
| 558 |
+
frame_indices: List[int] = []
|
| 559 |
+
obj_ids_list: List[int] = []
|
| 560 |
+
class_names_list: List[str] = []
|
| 561 |
+
cursor = 0
|
| 562 |
|
|
|
|
|
|
|
| 563 |
for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
|
| 564 |
+
inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
|
|
|
|
|
|
|
| 565 |
):
|
| 566 |
+
bool_masks = (out_mask_logits[:, 0] > 0.0) # (N, H, W) GPU async
|
| 567 |
+
n = bool_masks.shape[0]
|
| 568 |
+
|
| 569 |
+
# Allocate on first yield
|
| 570 |
+
if masks_buf is None:
|
| 571 |
+
H, W = bool_masks.shape[1], bool_masks.shape[2]
|
| 572 |
+
max_entries = step * max(n_obj, n)
|
| 573 |
+
masks_buf = torch.empty(max_entries, H, W, dtype=torch.bool, device=self.device)
|
| 574 |
+
bboxes_buf = torch.empty(max_entries, 4, dtype=torch.int64, device=self.device)
|
| 575 |
+
valid_buf = torch.empty(max_entries, dtype=torch.bool, device=self.device)
|
| 576 |
+
|
| 577 |
+
# Grow buffers if needed (unlikely but safe)
|
| 578 |
+
if cursor + n > masks_buf.shape[0]:
|
| 579 |
+
grow = max(step * n_obj, cursor + n - masks_buf.shape[0])
|
| 580 |
+
H, W = masks_buf.shape[1], masks_buf.shape[2]
|
| 581 |
+
masks_buf = torch.cat([masks_buf, torch.empty(grow, H, W, dtype=torch.bool, device=self.device)])
|
| 582 |
+
bboxes_buf = torch.cat([bboxes_buf, torch.empty(grow, 4, dtype=torch.int64, device=self.device)])
|
| 583 |
+
valid_buf = torch.cat([valid_buf, torch.empty(grow, dtype=torch.bool, device=self.device)])
|
| 584 |
+
|
| 585 |
+
# Inline bbox — GPU async, zero sync
|
| 586 |
+
frame_bboxes, frame_valid = _bbox_gpu(bool_masks)
|
| 587 |
+
|
| 588 |
+
# Fill pre-allocated slices — GPU async
|
| 589 |
+
masks_buf[cursor:cursor + n] = bool_masks
|
| 590 |
+
bboxes_buf[cursor:cursor + n] = frame_bboxes
|
| 591 |
+
valid_buf[cursor:cursor + n] = frame_valid
|
| 592 |
+
|
| 593 |
+
# Metadata (trivial Python, ~2μs GIL)
|
| 594 |
+
oid_list = list(out_obj_ids) if not isinstance(out_obj_ids, list) else out_obj_ids
|
| 595 |
+
for oid in oid_list:
|
| 596 |
+
frame_indices.append(out_frame_idx)
|
| 597 |
+
obj_ids_list.append(oid)
|
| 598 |
+
class_names_list.append(obj_id_to_class.get(oid, ""))
|
| 599 |
+
cursor += n
|
| 600 |
+
|
| 601 |
+
# Build output (zero-copy slice if under-filled, empty tensors if no frames)
|
| 602 |
+
if masks_buf is not None:
|
| 603 |
+
output = SegmentOutput(
|
| 604 |
+
masks=masks_buf[:cursor], bboxes=bboxes_buf[:cursor],
|
| 605 |
+
valid=valid_buf[:cursor], frame_indices=frame_indices,
|
| 606 |
+
obj_ids=obj_ids_list, class_names=class_names_list, device=self.device,
|
| 607 |
+
)
|
| 608 |
+
else:
|
| 609 |
+
output = SegmentOutput(
|
| 610 |
+
masks=torch.empty(0, 0, 0, dtype=torch.bool, device=self.device),
|
| 611 |
+
bboxes=torch.empty(0, 4, dtype=torch.int64, device=self.device),
|
| 612 |
+
valid=torch.empty(0, dtype=torch.bool, device=self.device),
|
| 613 |
+
frame_indices=[], obj_ids=[], class_names=[], device=self.device,
|
| 614 |
+
)
|
| 615 |
|
| 616 |
if _pm is not None:
|
| 617 |
_pl = getattr(self, '_perf_lock', None)
|
|
|
|
| 621 |
else:
|
| 622 |
_pm["sam_video_total_ms"] += _d
|
| 623 |
|
| 624 |
+
return output
|
| 625 |
|
| 626 |
# -- Video-level tracking interface -------------------------------------
|
| 627 |
|
|
|
|
| 796 |
if _pm is not None:
|
| 797 |
_t_sv = time.perf_counter()
|
| 798 |
|
| 799 |
+
segment_output = self.propagate_segment(
|
| 800 |
+
inference_state, start_idx, mask_dict, step,
|
| 801 |
+
)
|
| 802 |
+
segment_results = segment_output.to_object_dicts()
|
| 803 |
+
|
| 804 |
+
if segment_results:
|
| 805 |
+
all_results.update(segment_results)
|
| 806 |
+
last_fi = segment_output.last_frame_idx()
|
| 807 |
+
if last_fi is not None:
|
| 808 |
+
last_frame_objects = all_results.get(last_fi, {})
|
| 809 |
+
sam2_masks = MaskDictionary()
|
| 810 |
+
sam2_masks.labels = copy.deepcopy(last_frame_objects)
|
| 811 |
+
if last_frame_objects:
|
| 812 |
+
first_info = next(iter(last_frame_objects.values()))
|
| 813 |
+
if first_info.mask is not None:
|
| 814 |
+
sam2_masks.mask_height = first_info.mask.shape[-2] if first_info.mask.ndim >= 2 else 0
|
| 815 |
+
sam2_masks.mask_width = first_info.mask.shape[-1] if first_info.mask.ndim >= 2 else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 816 |
|
| 817 |
if _pm is not None:
|
| 818 |
_pl = getattr(self, '_perf_lock', None)
|