# detection_base/models/segmenters/grounded_sam2.py
"""Grounded-SAM-2 segmenter with continuous-ID video tracking.
Combines an object detector (open-vocabulary or closed-set) with SAM2's video
predictor to produce temporally consistent segmentation masks with
persistent object IDs across an entire video.
Reference implementation:
Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id.py
"""
import copy
import logging
import time
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from PIL import Image
from .base import Segmenter, SegmentationResult
# ---------------------------------------------------------------------------
# Data structures (mirrors Grounded-SAM-2 reference utilities)
# ---------------------------------------------------------------------------
@dataclass
class ObjectInfo:
"""Per-object tracking info for a single frame."""
instance_id: int = 0
    mask: Optional[torch.Tensor] = None  # bool (H, W)
class_name: str = ""
x1: int = 0
y1: int = 0
x2: int = 0
y2: int = 0
@dataclass
class MaskDictionary:
"""Tracks object masks across frames with IoU-based ID matching."""
mask_height: int = 0
mask_width: int = 0
labels: Dict[int, ObjectInfo] = field(default_factory=dict)
    def add_new_frame_annotation(
        self,
        mask_list: torch.Tensor,
        box_list: torch.Tensor,
        label_list: List[str],
    ):
        """Seed the dictionary from one frame's detections (local IDs start at 1)."""
        anno = {}
        for idx, (mask, box, label) in enumerate(zip(mask_list, box_list, label_list)):
            final_index = idx + 1
            anno[final_index] = ObjectInfo(
                instance_id=final_index,
                mask=mask,
                class_name=str(label),
                x1=int(box[0]),
                y1=int(box[1]),
                x2=int(box[2]),
                y2=int(box[3]),
            )
        # Only the mask shape is needed here. (The reference code also built a
        # CPU label image, but only its shape was ever read, and indexing a CPU
        # tensor with GPU boolean masks would raise a device-mismatch error.)
        self.mask_height = int(mask_list.shape[-2])
        self.mask_width = int(mask_list.shape[-1])
        self.labels = anno
    def update_masks(
        self,
        tracking_dict: "MaskDictionary",
        iou_threshold: float = 0.5,
        objects_count: int = 0,
    ) -> int:
        """Match current detections against tracked objects via IoU.
        Thin wrapper over ``update_masks_with_remapping`` that discards the
        local->global ID mapping.
        """
        objects_count, _ = self.update_masks_with_remapping(
            tracking_dict, iou_threshold=iou_threshold, objects_count=objects_count,
        )
        return objects_count
def update_masks_with_remapping(
self,
tracking_dict: "MaskDictionary",
iou_threshold: float = 0.5,
objects_count: int = 0,
) -> Tuple[int, Dict[int, int]]:
"""Match detections against tracked objects, returning ID remapping.
Same logic as ``update_masks`` but additionally returns a dict
mapping original (local) IDs to the assigned (global) IDs.
"""
updated = {}
remapping: Dict[int, int] = {}
used_tracked_ids = set()
for seg_id, seg_info in self.labels.items():
if seg_info.mask is None or seg_info.mask.sum() == 0:
continue
matched_id = 0
best_iou = iou_threshold
for _obj_id, obj_info in tracking_dict.labels.items():
if obj_info.instance_id in used_tracked_ids:
continue
iou = self._iou(seg_info.mask, obj_info.mask)
if iou > best_iou:
best_iou = iou
matched_id = obj_info.instance_id
if not matched_id:
objects_count += 1
matched_id = objects_count
else:
used_tracked_ids.add(matched_id)
new_info = ObjectInfo(
instance_id=matched_id,
mask=seg_info.mask,
class_name=seg_info.class_name,
)
updated[matched_id] = new_info
remapping[seg_id] = matched_id
self.labels = updated
return objects_count, remapping
def get_target_class_name(self, instance_id: int) -> str:
info = self.labels.get(instance_id)
return info.class_name if info else ""
@staticmethod
def _iou(m1: torch.Tensor, m2: torch.Tensor) -> float:
if not torch.is_tensor(m1):
m1 = torch.as_tensor(m1)
if not torch.is_tensor(m2):
m2 = torch.as_tensor(m2)
# Multi-GPU reconciliation can compare masks produced on different
# devices; normalize both masks onto CPU before arithmetic.
if m1.device != m2.device:
m1 = m1.detach().to(device="cpu")
m2 = m2.detach().to(device="cpu")
m1f = m1.to(torch.float32)
m2f = m2.to(torch.float32)
inter = (m1f * m2f).sum()
union = m1f.sum() + m2f.sum() - inter
if union == 0:
return 0.0
return float((inter / union).item())
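# Usage sketch (hypothetical tensors/labels): how process_video() keeps IDs
# stable across keyframes. `prev` holds the masks SAM2 propagated to the end of
# the previous segment; `cur` holds fresh keyframe detections.
#
#   cur = MaskDictionary()
#   cur.add_new_frame_annotation(mask_list=cur_masks, box_list=cur_boxes,
#                                label_list=["person", "dog"])
#   objects_count = cur.update_masks(tracking_dict=prev,
#                                    iou_threshold=0.5,
#                                    objects_count=objects_count)
#   # Detections with IoU strictly above 0.5 against an unclaimed tracked mask
#   # inherit its global ID; the rest get objects_count+1, objects_count+2, ...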
# ---------------------------------------------------------------------------
# GPU-resident bounding-box helper (zero CUDA syncs)
# ---------------------------------------------------------------------------
def _bbox_gpu(bool_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute bboxes from (N, H, W) bool GPU masks. Returns GPU tensors, zero sync.
Returns:
bboxes: (N, 4) int64 on same device as input [x_min, y_min, x_max, y_max]
valid: (N,) bool on same device as input
"""
N, H, W = bool_masks.shape
rows = bool_masks.any(dim=2) # (N, H)
cols = bool_masks.any(dim=1) # (N, W)
valid = rows.any(dim=1) # (N,)
rows_f = rows.float()
cols_f = cols.float()
bboxes = torch.stack([
cols_f.argmax(dim=1), # x_min
rows_f.argmax(dim=1), # y_min
W - 1 - cols_f.flip(1).argmax(dim=1), # x_max
H - 1 - rows_f.flip(1).argmax(dim=1), # y_max
], dim=1).to(torch.int64) # (N, 4) int64
return bboxes, valid
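# Worked example of the flip/argmax trick above (hypothetical 5-column row):
#   cols    = [F, T, T, F, F]  -> cols_f.argmax = 1      -> x_min = 1
#   flipped = [F, F, T, T, F]  -> argmax = 2             -> x_max = 5 - 1 - 2 = 2
# argmax returns the *first* maximal element, so argmax on the flipped row
# measures the last True column's distance from the right edge, all on-GPU.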
# ---------------------------------------------------------------------------
# GPU-resident segment output (deferred CPU materialization)
# ---------------------------------------------------------------------------
@dataclass
class SegmentOutput:
"""GPU-resident segment propagation result. Zero CUDA syncs to construct."""
masks: torch.Tensor # (count, H, W) bool on GPU
bboxes: torch.Tensor # (count, 4) int64 on GPU
valid: torch.Tensor # (count,) bool on GPU
frame_indices: List[int] # len == count
obj_ids: List[int] # len == count
class_names: List[str] # len == count
device: str = "cpu"
def last_frame_idx(self) -> Optional[int]:
return self.frame_indices[-1] if self.frame_indices else None
def frame_to_object_dict(
self,
frame_idx: int,
remapping: Optional[Dict[int, int]] = None,
to_cpu: bool = True,
) -> Dict[int, "ObjectInfo"]:
"""Materialize a single frame's ObjectInfo dict from GPU buffers.
Args:
frame_idx: The frame index to materialize.
remapping: Optional local->global ID mapping.
to_cpu: If True, transfer mask/bbox to CPU.
Returns:
``{obj_id: ObjectInfo}`` for the requested frame.
"""
# Build lazy frame index on first call
if not hasattr(self, '_frame_index'):
idx: Dict[int, List[int]] = {}
for i, fi in enumerate(self.frame_indices):
idx.setdefault(fi, []).append(i)
self._frame_index = idx
positions = self._frame_index.get(frame_idx)
if not positions:
return {}
result: Dict[int, ObjectInfo] = {}
for i in positions:
oid = self.obj_ids[i]
cn = self.class_names[i]
global_id = remapping.get(oid, oid) if remapping else oid
mask = self.masks[i]
valid = self.valid[i]
if valid:
bbox = self.bboxes[i]
if to_cpu:
mask = mask.cpu()
x1 = int(bbox[0].item())
y1 = int(bbox[1].item())
x2 = int(bbox[2].item())
y2 = int(bbox[3].item())
else:
x1, y1 = int(bbox[0]), int(bbox[1])
x2, y2 = int(bbox[2]), int(bbox[3])
else:
if to_cpu:
mask = mask.cpu()
x1 = y1 = x2 = y2 = 0
result[global_id] = ObjectInfo(
instance_id=global_id,
mask=mask,
class_name=cn,
x1=x1, y1=y1, x2=x2, y2=y2,
)
return result
# ---------------------------------------------------------------------------
# Lazy frame objects wrapper (deferred GPU->CPU per-frame)
# ---------------------------------------------------------------------------
@dataclass
class LazyFrameObjects:
"""Lightweight wrapper for deferred GPU->CPU materialization.
Holds a reference to a GPU-resident ``SegmentOutput`` plus frame index
and optional ID remapping. Call ``materialize()`` to perform the
GPU->CPU transfer (intended to run in a render worker thread).
"""
segment_output: SegmentOutput
frame_idx: int
remapping: Optional[Dict[int, int]] = None
def materialize(self) -> Dict[int, "ObjectInfo"]:
"""Transfer one frame's data from GPU to CPU and build ObjectInfo dict."""
return self.segment_output.frame_to_object_dict(
self.frame_idx, remapping=self.remapping, to_cpu=True,
)
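# Usage sketch (hypothetical queue/worker): the tracking thread enqueues the
# wrapper cheaply; a render worker pays the GPU->CPU transfer cost later.
#
#   lazy = LazyFrameObjects(segment_output=seg_out, frame_idx=fi,
#                           remapping=remapping)
#   render_queue.put(lazy)                      # no CUDA sync on this thread
#   ...
#   objects = render_queue.get().materialize()  # sync happens in the worker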
# ---------------------------------------------------------------------------
# SAM2 HuggingFace model IDs per size
# ---------------------------------------------------------------------------
_SAM2_HF_MODELS = {
"small": "facebook/sam2.1-hiera-small",
"base": "facebook/sam2.1-hiera-base-plus",
"large": "facebook/sam2.1-hiera-large",
}
def _det_label_names(det) -> List[str]:
"""Extract string labels from a DetectionResult, with fallback."""
num_boxes = len(det.boxes) if det.boxes is not None else 0
if det.label_names is not None and len(det.label_names) > 0:
return list(det.label_names)
if det.labels is not None and len(det.labels) > 0:
return [str(l) for l in det.labels]
return ["object"] * num_boxes
# ---------------------------------------------------------------------------
# Grounded-SAM-2 Segmenter
# ---------------------------------------------------------------------------
class GroundedSAM2Segmenter(Segmenter):
"""SAM2 video segmenter driven by an injected object detector.
Any ``ObjectDetector`` can be used (defaults to Grounding DINO).
For single-frame mode (``predict``), uses detector + SAM2 image predictor.
For video mode (``process_video``), uses detector on keyframes + SAM2 video
predictor for temporal mask propagation with continuous object IDs.
"""
supports_batch = False
max_batch_size = 1
def __init__(
self,
model_size: str = "large",
device: Optional[str] = None,
step: int = 20,
iou_threshold: float = 0.5,
num_maskmem: Optional[int] = None,
detector_name: Optional[str] = None,
):
self.model_size = model_size
self.step = step
self.iou_threshold = iou_threshold
self.num_maskmem = num_maskmem # None = use default (7)
self._detector_name = detector_name # None = "grounding_dino"
_size_suffix = {"small": "S", "base": "B", "large": "L"}
_det_prefix = {"yolo11": "YSAM2"}
_prefix = _det_prefix.get(detector_name, "GSAM2")
self.name = f"{_prefix}-{_size_suffix[model_size]}"
if device:
self.device = device
else:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Lazy-loaded model handles
self._video_predictor = None
self._image_predictor = None
self._detector = None
self._models_loaded = False
# -- Lazy loading -------------------------------------------------------
def _ensure_models_loaded(self):
if self._models_loaded:
return
hf_id = _SAM2_HF_MODELS[self.model_size]
logging.info(
"Loading Grounded-SAM-2 (%s) on device %s ...", hf_id, self.device
)
# Enable TF32 on Ampere+ GPUs
if torch.cuda.is_available():
try:
props = torch.cuda.get_device_properties(
int(self.device.split(":")[-1]) if ":" in self.device else 0
)
if props.major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
except Exception:
pass
from sam2.build_sam import build_sam2_hf, build_sam2_video_predictor_hf
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Video predictor (for process_video)
self._video_predictor = build_sam2_video_predictor_hf(
hf_id, device=self.device
)
# Image predictor (for single-frame predict)
sam2_image_model = build_sam2_hf(hf_id, device=self.device)
self._image_predictor = SAM2ImagePredictor(sam2_image_model)
# Override num_maskmem if requested
if self.num_maskmem is not None:
self._patch_num_maskmem(self._video_predictor, self.num_maskmem)
logging.info("Patched video predictor num_maskmem → %d", self.num_maskmem)
# Load detector by name (defaults to Grounding DINO)
from models.model_loader import load_detector_on_device
det_name = self._detector_name or "grounding_dino"
self._detector = load_detector_on_device(det_name, self.device)
self._models_loaded = True
logging.info("Grounded-SAM-2 models loaded successfully.")
@staticmethod
def _patch_num_maskmem(predictor, num_maskmem: int):
"""Override num_maskmem on a loaded SAM2 video predictor at runtime.
Slices the temporal positional encoding parameter to match the new
memory size so the model runs without shape mismatches.
"""
import torch.nn as nn
# The underlying model may be predictor itself or predictor.model
model = getattr(predictor, "model", predictor)
old = getattr(model, "num_maskmem", None)
if old is None:
logging.warning("Cannot patch num_maskmem: attribute not found on model")
return
if num_maskmem == old:
return
model.num_maskmem = num_maskmem
# Slice or pad maskmem_tpos_enc (shape: [num_maskmem, 1, 1, mem_dim])
if hasattr(model, "maskmem_tpos_enc") and model.maskmem_tpos_enc is not None:
old_enc = model.maskmem_tpos_enc
if num_maskmem <= old_enc.shape[0]:
model.maskmem_tpos_enc = nn.Parameter(
old_enc[:num_maskmem].clone()
)
else:
# Pad with zeros for the extra slots
pad = torch.zeros(
num_maskmem - old_enc.shape[0], *old_enc.shape[1:],
device=old_enc.device, dtype=old_enc.dtype,
)
model.maskmem_tpos_enc = nn.Parameter(
torch.cat([old_enc, pad], dim=0)
)
logging.info("num_maskmem changed from %d to %d", old, num_maskmem)
# -- GPU-resident SAM2 predict (skip numpy conversion) ------------------
def _predict_masks_gpu(
self, input_boxes: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Run SAM2 image predictor keeping masks on GPU (skip numpy conversion).
Calls SAM2's internal ``_prep_prompts`` + ``_predict`` directly,
bypassing the public ``predict()`` which converts to numpy.
Args:
input_boxes: (N, 4) float tensor on device.
Returns:
``(masks, scores)`` — *masks* is ``(N, H, W)`` bool GPU tensor,
*scores* is ``(N,)`` float GPU tensor.
"""
mask_input, unnorm_coords, labels, unnorm_box = (
self._image_predictor._prep_prompts(
point_coords=None,
point_labels=None,
box=input_boxes,
mask_logits=None,
normalize_coords=True,
)
)
masks, scores, _ = self._image_predictor._predict(
unnorm_coords, labels, unnorm_box, mask_input,
multimask_output=False, return_logits=False,
)
# _predict returns (1, N, ..., H, W); squeeze batch dim
masks = masks.squeeze(0)
if masks.ndim == 2:
masks = masks[None]
elif masks.ndim == 4:
masks = masks.squeeze(1)
scores = scores.squeeze(0).flatten()
return masks, scores
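    # Usage sketch (hypothetical boxes): same inputs as the public predict(),
    # but every tensor stays on device:
    #   self._image_predictor.set_image(frame_rgb)
    #   boxes = torch.tensor([[10., 20., 110., 220.]], device=self.device)
    #   masks, scores = self._predict_masks_gpu(boxes)  # bool (1, H, W) on GPU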
# -- Single-frame interface (Segmenter.predict) -------------------------
def predict(
self, frame: np.ndarray, text_prompts: Optional[list] = None
) -> SegmentationResult:
"""Run detector + SAM2 image predictor on a single frame."""
self._ensure_models_loaded()
prompts = text_prompts or ["object"]
# Run detector to get boxes
det = self._detector.predict(frame, prompts)
if det.boxes is None or len(det.boxes) == 0:
return SegmentationResult(
masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
scores=None,
boxes=None,
label_names=None,
)
# SAM2 image predictor expects RGB
import cv2 as _cv2
frame_rgb = _cv2.cvtColor(frame, _cv2.COLOR_BGR2RGB)
device_type = self.device.split(":")[0]
autocast_ctx = torch.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
with autocast_ctx:
self._image_predictor.set_image(frame_rgb)
input_boxes = torch.tensor(det.boxes, device=self.device, dtype=torch.float32)
masks, scores, _ = self._image_predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
# Normalize mask shape to (N, H, W)
if masks.ndim == 2:
masks = masks[None]
elif masks.ndim == 4:
masks = masks.squeeze(1)
if isinstance(masks, torch.Tensor):
masks_np = masks.cpu().numpy().astype(bool)
else:
masks_np = np.asarray(masks).astype(bool)
scores_np = None
if scores is not None:
if isinstance(scores, torch.Tensor):
scores_np = scores.cpu().numpy().flatten()
else:
scores_np = np.asarray(scores).flatten()
return SegmentationResult(
masks=masks_np,
scores=scores_np,
boxes=det.boxes,
label_names=det.label_names,
)
# -- Multi-GPU helper methods -------------------------------------------
def detect_keyframe(
self,
image: "Image",
text_prompts: List[str],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], List[str]]:
"""Run detector + SAM2 image predictor on a single keyframe.
Args:
image: PIL Image in RGB mode.
text_prompts: Text queries for the detector.
Returns:
``(masks, boxes, labels)`` where *masks* is an ``(N, H, W)``
bool GPU tensor, *boxes* is an ``(N, 4)`` tensor on device,
and *labels* is a list of strings. Returns
``(None, None, [])`` when no objects are detected.
"""
self._ensure_models_loaded()
_pm = getattr(self, '_perf_metrics', None)
if _pm is not None:
_t0 = time.perf_counter()
# Convert PIL RGB → numpy BGR for detector.predict()
frame_bgr = np.array(image)[:, :, ::-1].copy()
det = self._detector.predict(frame_bgr, text_prompts)
if _pm is not None:
_pl = getattr(self, '_perf_lock', None)
_d = (time.perf_counter() - _t0) * 1000.0
if _pl:
with _pl: _pm["gdino_total_ms"] += _d
else:
_pm["gdino_total_ms"] += _d
if det.boxes is None or len(det.boxes) == 0:
return None, None, []
input_boxes = torch.tensor(det.boxes, device=self.device, dtype=torch.float32)
det_labels = _det_label_names(det)
# SAM2 image predictor
if _pm is not None:
_t1 = time.perf_counter()
self._image_predictor.set_image(np.array(image))
masks, _ = self._predict_masks_gpu(input_boxes)
if _pm is not None:
_pl = getattr(self, '_perf_lock', None)
_d = (time.perf_counter() - _t1) * 1000.0
if _pl:
with _pl: _pm["sam_image_total_ms"] += _d
else:
_pm["sam_image_total_ms"] += _d
return masks, input_boxes, det_labels
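    # Usage sketch: detect_keyframe() output feeds straight into a
    # MaskDictionary (this mirrors the keyframe step of process_video):
    #   masks, boxes, labels = seg.detect_keyframe(pil_image, ["person"])
    #   if masks is not None:
    #       md = MaskDictionary()
    #       md.add_new_frame_annotation(masks, boxes, labels)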
def propagate_segment(
self,
inference_state: Any,
start_idx: int,
mask_dict: "MaskDictionary",
step: int,
) -> "SegmentOutput":
"""Propagate masks for a single segment via SAM2 video predictor.
Returns a GPU-resident ``SegmentOutput`` with zero CUDA syncs.
Call ``output.frame_to_object_dict()`` to materialize per-frame CPU dicts.
"""
_pm = getattr(self, '_perf_metrics', None)
if _pm is not None:
_t0 = time.perf_counter()
self._video_predictor.reset_state(inference_state)
for obj_id, obj_info in mask_dict.labels.items():
self._video_predictor.add_new_mask(
inference_state, start_idx, obj_id, obj_info.mask,
)
# Pre-compute class name lookup (avoid repeated dict access in loop)
obj_id_to_class = {oid: mask_dict.get_target_class_name(oid) for oid in mask_dict.labels}
n_obj = len(mask_dict.labels)
# Pre-allocated GPU buffers (allocated on first yield when H, W known)
masks_buf = bboxes_buf = valid_buf = None
frame_indices: List[int] = []
obj_ids_list: List[int] = []
class_names_list: List[str] = []
cursor = 0
for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
):
bool_masks = (out_mask_logits[:, 0] > 0.0) # (N, H, W) GPU async
n = bool_masks.shape[0]
# Allocate on first yield
if masks_buf is None:
H, W = bool_masks.shape[1], bool_masks.shape[2]
max_entries = step * max(n_obj, n)
masks_buf = torch.empty(max_entries, H, W, dtype=torch.bool, device=self.device)
bboxes_buf = torch.empty(max_entries, 4, dtype=torch.int64, device=self.device)
valid_buf = torch.empty(max_entries, dtype=torch.bool, device=self.device)
# Grow buffers if needed (unlikely but safe)
if cursor + n > masks_buf.shape[0]:
grow = max(step * n_obj, cursor + n - masks_buf.shape[0])
H, W = masks_buf.shape[1], masks_buf.shape[2]
masks_buf = torch.cat([masks_buf, torch.empty(grow, H, W, dtype=torch.bool, device=self.device)])
bboxes_buf = torch.cat([bboxes_buf, torch.empty(grow, 4, dtype=torch.int64, device=self.device)])
valid_buf = torch.cat([valid_buf, torch.empty(grow, dtype=torch.bool, device=self.device)])
# Inline bbox — GPU async, zero sync
frame_bboxes, frame_valid = _bbox_gpu(bool_masks)
# Fill pre-allocated slices — GPU async
masks_buf[cursor:cursor + n] = bool_masks
bboxes_buf[cursor:cursor + n] = frame_bboxes
valid_buf[cursor:cursor + n] = frame_valid
            # Metadata bookkeeping (cheap Python-side work, no CUDA sync)
oid_list = list(out_obj_ids) if not isinstance(out_obj_ids, list) else out_obj_ids
for oid in oid_list:
frame_indices.append(out_frame_idx)
obj_ids_list.append(oid)
class_names_list.append(obj_id_to_class.get(oid, ""))
cursor += n
# Build output (zero-copy slice if under-filled, empty tensors if no frames)
if masks_buf is not None:
output = SegmentOutput(
masks=masks_buf[:cursor], bboxes=bboxes_buf[:cursor],
valid=valid_buf[:cursor], frame_indices=frame_indices,
obj_ids=obj_ids_list, class_names=class_names_list, device=self.device,
)
else:
output = SegmentOutput(
masks=torch.empty(0, 0, 0, dtype=torch.bool, device=self.device),
bboxes=torch.empty(0, 4, dtype=torch.int64, device=self.device),
valid=torch.empty(0, dtype=torch.bool, device=self.device),
frame_indices=[], obj_ids=[], class_names=[], device=self.device,
)
if _pm is not None:
_pl = getattr(self, '_perf_lock', None)
_d = (time.perf_counter() - _t0) * 1000.0
if _pl:
with _pl: _pm["sam_video_total_ms"] += _d
else:
_pm["sam_video_total_ms"] += _d
return output
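    # Sync-boundary note: propagate_segment() only enqueues GPU work. The first
    # CUDA sync happens when a consumer materializes a frame, e.g. the last
    # frame fetched in process_video() for IoU matching, or a render worker
    # calling LazyFrameObjects.materialize().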
# -- Video-level tracking interface -------------------------------------
def process_video(
self,
frame_dir: str,
frame_names: List[str],
text_prompts: List[str],
on_segment: Optional[Callable[[Dict[int, Dict[int, "ObjectInfo"]]], None]] = None,
on_segment_output: Optional[Callable[["SegmentOutput"], None]] = None,
_ttfs_t0: Optional[float] = None,
_ttfs_job_id: Optional[str] = None,
frame_store=None,
) -> Dict[int, Dict[int, ObjectInfo]]:
"""Run full Grounded-SAM-2 tracking pipeline on extracted JPEG frames.
Args:
frame_dir: Directory containing JPEG frames.
frame_names: Sorted list of frame filenames.
text_prompts: Text queries for the detector.
on_segment: Optional callback invoked after each segment completes.
Receives ``{frame_idx: {obj_id: ObjectInfo}}`` for the segment.
Returns:
Dict mapping frame_idx -> {obj_id: ObjectInfo} with masks,
bboxes, and class names for every frame.
"""
import os
self._ensure_models_loaded()
device = self.device
step = self.step
total_frames = len(frame_names)
logging.info(
"Grounded-SAM-2 tracking: %d frames, step=%d, queries=%s",
total_frames, step, text_prompts,
)
# Single global autocast context (matches reference implementation)
device_type = device.split(":")[0]
autocast_ctx = torch.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
_pm = getattr(self, '_perf_metrics', None)
sam2_masks = MaskDictionary()
objects_count = 0
all_results: Dict[int, Dict[int, ObjectInfo]] = {}
with autocast_ctx:
# Init SAM2 video predictor state
if _pm is not None:
_t_init = time.perf_counter()
if frame_store is not None:
inference_state = self._video_predictor.init_state(
video_path=frame_dir, # dummy dir with 1 JPEG
offload_video_to_cpu=True,
async_loading_frames=False,
)
# Clear cached_features (dummy frame 0's backbone features)
inference_state["cached_features"] = {}
# Patch in real frame data
img_size = self._video_predictor.image_size
inference_state["images"] = frame_store.sam2_adapter(image_size=img_size)
inference_state["num_frames"] = len(frame_store)
inference_state["video_height"] = frame_store.height
inference_state["video_width"] = frame_store.width
else:
inference_state = self._video_predictor.init_state(
video_path=frame_dir,
offload_video_to_cpu=True,
async_loading_frames=True,
)
if _pm is not None:
_pl = getattr(self, '_perf_lock', None)
_d = (time.perf_counter() - _t_init) * 1000.0
if _pl:
with _pl: _pm["init_state_ms"] += _d
else:
_pm["init_state_ms"] += _d
for start_idx in range(0, total_frames, step):
logging.info("Processing keyframe %d / %d", start_idx, total_frames)
if frame_store is not None:
image = frame_store.get_pil_rgb(start_idx)
else:
image = Image.open(os.path.join(frame_dir, frame_names[start_idx])).convert("RGB")
mask_dict = MaskDictionary()
# -- Detector on keyframe --
if _pm is not None:
_t_gd = time.perf_counter()
                frame_bgr = np.array(image)[:, :, ::-1].copy()  # PIL RGB → numpy BGR
det = self._detector.predict(frame_bgr, text_prompts)
if _pm is not None:
_pl = getattr(self, '_perf_lock', None)
_d = (time.perf_counter() - _t_gd) * 1000.0
if _pl:
with _pl: _pm["gdino_total_ms"] += _d
else:
_pm["gdino_total_ms"] += _d
if det.boxes is None or len(det.boxes) == 0:
input_boxes = torch.zeros((0, 4), device=device)
det_labels = []
else:
input_boxes = torch.tensor(det.boxes, device=device, dtype=torch.float32)
det_labels = _det_label_names(det)
                if len(input_boxes) == 0:
                    logging.info("No detections on keyframe %d, carrying forward previous masks", start_idx)
                    # Fill this segment by carrying forward the last known masks
                    seg_results: Dict[int, Dict[int, ObjectInfo]] = {}
                    for fi in range(start_idx, min(start_idx + step, total_frames)):
                        if fi not in all_results:
                            if sam2_masks.labels:
                                all_results[fi] = {
                                    k: ObjectInfo(
                                        instance_id=v.instance_id,
                                        mask=v.mask,
                                        class_name=v.class_name,
                                        x1=v.x1, y1=v.y1, x2=v.x2, y2=v.y2,
                                    )
                                    for k, v in sam2_masks.labels.items()
                                }
                            else:
                                all_results[fi] = {}
                        seg_results[fi] = all_results[fi]
if on_segment and seg_results:
on_segment(seg_results)
if start_idx == 0 and _ttfs_t0 is not None:
logging.info("[TTFS:%s] +%.1fs first_segment_complete (no detections, step=%d)",
_ttfs_job_id, time.perf_counter() - _ttfs_t0, step)
continue
# -- SAM2 image predictor on keyframe --
if _pm is not None:
_t_si = time.perf_counter()
self._image_predictor.set_image(np.array(image))
masks, _ = self._predict_masks_gpu(input_boxes)
if _pm is not None:
_pl = getattr(self, '_perf_lock', None)
_d = (time.perf_counter() - _t_si) * 1000.0
if _pl:
with _pl: _pm["sam_image_total_ms"] += _d
else:
_pm["sam_image_total_ms"] += _d
mask_dict.add_new_frame_annotation(
mask_list=masks,
box_list=input_boxes.clone() if torch.is_tensor(input_boxes) else torch.tensor(input_boxes),
label_list=det_labels,
)
# -- IoU matching to maintain persistent IDs --
if _pm is not None:
_t_id = time.perf_counter()
objects_count = mask_dict.update_masks(
tracking_dict=sam2_masks,
iou_threshold=self.iou_threshold,
objects_count=objects_count,
)
if _pm is not None:
_pl = getattr(self, '_perf_lock', None)
_d = (time.perf_counter() - _t_id) * 1000.0
if _pl:
with _pl: _pm["id_reconciliation_ms"] += _d
else:
_pm["id_reconciliation_ms"] += _d
if len(mask_dict.labels) == 0:
seg_results_empty: Dict[int, Dict[int, ObjectInfo]] = {}
for fi in range(start_idx, min(start_idx + step, total_frames)):
all_results[fi] = {}
seg_results_empty[fi] = {}
if on_segment:
on_segment(seg_results_empty)
if start_idx == 0 and _ttfs_t0 is not None:
logging.info("[TTFS:%s] +%.1fs first_segment_complete (empty masks, step=%d)",
_ttfs_job_id, time.perf_counter() - _ttfs_t0, step)
continue
# -- SAM2 video predictor: propagate masks --
# NOTE: propagate_segment() already accumulates into
# _pm["sam_video_total_ms"], so no outer timer here.
segment_output = self.propagate_segment(
inference_state, start_idx, mask_dict, step,
)
# GPU-deferred path: only materialize last frame for IoU
last_fi = segment_output.last_frame_idx()
if last_fi is not None:
last_frame_objects = segment_output.frame_to_object_dict(
last_fi, to_cpu=True,
)
all_results[last_fi] = last_frame_objects
sam2_masks = MaskDictionary()
sam2_masks.labels = copy.deepcopy(last_frame_objects)
if last_frame_objects:
first_info = next(iter(last_frame_objects.values()))
if first_info.mask is not None:
sam2_masks.mask_height = first_info.mask.shape[-2] if first_info.mask.ndim >= 2 else 0
sam2_masks.mask_width = first_info.mask.shape[-1] if first_info.mask.ndim >= 2 else 0
if on_segment_output is not None:
on_segment_output(segment_output)
if start_idx == 0 and _ttfs_t0 is not None:
logging.info("[TTFS:%s] +%.1fs first_segment_complete (step=%d)",
_ttfs_job_id, time.perf_counter() - _ttfs_t0, step)
logging.info(
"Grounded-SAM-2 tracking complete: %d frames, %d tracked objects",
len(all_results), objects_count,
)
return all_results
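# Timeline sketch for process_video with step=20 over 50 frames (illustrative):
#   keyframes at 0, 20, 40; each keyframe runs the detector + SAM2 image
#   predictor, IoU-matches the results against the masks propagated to the
#   previous segment's last frame, then SAM2's video predictor propagates the
#   matched masks forward over the next `step` frames.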