"""Grounded-SAM-2 segmenter with continuous-ID video tracking.

Combines an object detector (open-vocabulary or closed-set) with SAM2's video
predictor to produce temporally consistent segmentation masks with persistent
object IDs across an entire video.

Reference implementation:
    Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id.py
"""

import copy
import logging
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from PIL import Image

from .base import Segmenter, SegmentationResult


# ---------------------------------------------------------------------------
# Data structures (mirrors Grounded-SAM-2 reference utilities)
# ---------------------------------------------------------------------------

@dataclass
class ObjectInfo:
    """Per-object tracking info for a single frame."""

    instance_id: int = 0
    mask: Any = None  # torch.Tensor bool (H, W)
    class_name: str = ""
    x1: int = 0
    y1: int = 0
    x2: int = 0
    y2: int = 0


@dataclass
class MaskDictionary:
    """Tracks object masks across frames with IoU-based ID matching."""

    mask_height: int = 0
    mask_width: int = 0
    labels: Dict[int, ObjectInfo] = field(default_factory=dict)

    def add_new_frame_annotation(
        self,
        mask_list: torch.Tensor,
        box_list: torch.Tensor,
        label_list: list,
    ):
        """Replace ``labels`` with one entry per (mask, box, label) triple.

        Entries get local instance IDs ``1..N`` in input order; iteration
        stops at the shortest of the three inputs (``zip`` semantics).

        Args:
            mask_list: Stacked masks; the last two dimensions are recorded as
                ``mask_height`` / ``mask_width``.
            box_list: Per-object boxes indexed as ``[x1, y1, x2, y2]``.
            label_list: Per-object labels, stored as ``str(label)``.
        """
        anno: Dict[int, ObjectInfo] = {}
        for local_id, (mask, box, label) in enumerate(
            zip(mask_list, box_list, label_list), start=1
        ):
            anno[local_id] = ObjectInfo(
                instance_id=local_id,
                mask=mask,
                class_name=str(label),
                x1=int(box[0]),
                y1=int(box[1]),
                x2=int(box[2]),
                y2=int(box[3]),
            )

        # Only the spatial size is needed here, so read it straight from the
        # input instead of painting a throw-away label image.
        height, width = mask_list.shape[-2:]
        self.mask_height = int(height)
        self.mask_width = int(width)
        self.labels = anno

    def update_masks(
        self,
        tracking_dict: "MaskDictionary",
        iou_threshold: float = 0.5,
        objects_count: int = 0,
    ) -> int:
        """Match current detections against tracked objects via IoU.

        Thin wrapper around ``update_masks_with_remapping`` that discards the
        ID remapping.

        Returns:
            The updated running object count.
        """
        objects_count, _ = self.update_masks_with_remapping(
            tracking_dict=tracking_dict,
            iou_threshold=iou_threshold,
            objects_count=objects_count,
        )
        return objects_count

    def update_masks_with_remapping(
        self,
        tracking_dict: "MaskDictionary",
        iou_threshold: float = 0.5,
        objects_count: int = 0,
    ) -> Tuple[int, Dict[int, int]]:
        """Match detections against tracked objects, returning ID remapping.

        Each entry in ``self.labels`` with a non-empty mask is greedily
        matched to the not-yet-claimed tracked object with the highest IoU
        strictly above ``iou_threshold``. Unmatched entries receive a fresh
        ID (``objects_count`` is incremented). ``self.labels`` is replaced by
        the re-keyed entries; bounding boxes are not carried over.

        Args:
            tracking_dict: Previously tracked objects to match against.
            iou_threshold: Minimum IoU (exclusive) for a match.
            objects_count: Highest ID handed out so far.

        Returns:
            ``(objects_count, remapping)`` where *remapping* maps the original
            (local) IDs to the assigned (global) IDs.
        """
        updated: Dict[int, ObjectInfo] = {}
        remapping: Dict[int, int] = {}
        used_tracked_ids = set()

        for seg_id, seg_info in self.labels.items():
            if seg_info.mask is None or seg_info.mask.sum() == 0:
                continue

            matched_id = 0  # 0 means "no match yet"
            best_iou = iou_threshold
            for obj_info in tracking_dict.labels.values():
                if obj_info.instance_id in used_tracked_ids:
                    continue
                iou = self._iou(seg_info.mask, obj_info.mask)
                if iou > best_iou:
                    best_iou = iou
                    matched_id = obj_info.instance_id

            if not matched_id:
                objects_count += 1
                matched_id = objects_count
            else:
                used_tracked_ids.add(matched_id)

            updated[matched_id] = ObjectInfo(
                instance_id=matched_id,
                mask=seg_info.mask,
                class_name=seg_info.class_name,
            )
            remapping[seg_id] = matched_id

        self.labels = updated
        return objects_count, remapping

    def get_target_class_name(self, instance_id: int) -> str:
        """Return the class name stored for *instance_id*, or ``""`` if absent."""
        info = self.labels.get(instance_id)
        return info.class_name if info else ""

    @staticmethod
    def _iou(m1: torch.Tensor, m2: torch.Tensor) -> float:
        """Return the intersection-over-union of two masks (0.0 if both empty)."""
        if not torch.is_tensor(m1):
            m1 = torch.as_tensor(m1)
        if not torch.is_tensor(m2):
            m2 = torch.as_tensor(m2)
        # Multi-GPU reconciliation can compare masks produced on different
        # devices; normalize both masks onto CPU before arithmetic.
        if m1.device != m2.device:
            m1 = m1.detach().to(device="cpu")
            m2 = m2.detach().to(device="cpu")
        m1f = m1.to(torch.float32)
        m2f = m2.to(torch.float32)
        inter = (m1f * m2f).sum()
        union = m1f.sum() + m2f.sum() - inter
        if union == 0:
            return 0.0
        return float((inter / union).item())


# ---------------------------------------------------------------------------
# GPU-resident bounding-box helper (zero CUDA syncs)
# ---------------------------------------------------------------------------

def _bbox_gpu(bool_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute bboxes from (N, H, W) bool GPU masks. Returns GPU tensors, zero sync.

    Returns:
        bboxes: (N, 4) int64 on same device as input [x_min, y_min, x_max, y_max]
        valid: (N,) bool on same device as input
    """
    _, H, W = bool_masks.shape
    rows = bool_masks.any(dim=2)  # (N, H)
    cols = bool_masks.any(dim=1)  # (N, W)
    valid = rows.any(dim=1)       # (N,)
    rows_f = rows.float()
    cols_f = cols.float()
    # argmax is applied to the 0/1 occupancy profile to locate the first set
    # row/column; flipping first locates the last one. Entries with an empty
    # mask still get a box here and must be filtered with ``valid``.
    bboxes = torch.stack([
        cols_f.argmax(dim=1),                  # x_min
        rows_f.argmax(dim=1),                  # y_min
        W - 1 - cols_f.flip(1).argmax(dim=1),  # x_max
        H - 1 - rows_f.flip(1).argmax(dim=1),  # y_max
    ], dim=1).to(torch.int64)  # (N, 4) int64
    return bboxes, valid


# ---------------------------------------------------------------------------
# GPU-resident segment output (deferred CPU materialization)
# ---------------------------------------------------------------------------

@dataclass
class SegmentOutput:
    """GPU-resident segment propagation result. Zero CUDA syncs to construct."""

    masks: torch.Tensor       # (count, H, W) bool on GPU
    bboxes: torch.Tensor      # (count, 4) int64 on GPU
    valid: torch.Tensor       # (count,) bool on GPU
    frame_indices: List[int]  # len == count
    obj_ids: List[int]        # len == count
    class_names: List[str]    # len == count
    device: str = "cpu"
    # frame index -> buffer positions; built on the first
    # frame_to_object_dict() call and reused afterwards.
    _frame_index: Optional[Dict[int, List[int]]] = field(
        default=None, init=False, repr=False, compare=False
    )

    def last_frame_idx(self) -> Optional[int]:
        """Return the frame index of the last buffered entry, or None if empty."""
        return self.frame_indices[-1] if self.frame_indices else None

    def frame_to_object_dict(
        self,
        frame_idx: int,
        remapping: Optional[Dict[int, int]] = None,
        to_cpu: bool = True,
    ) -> Dict[int, "ObjectInfo"]:
        """Materialize a single frame's ObjectInfo dict from GPU buffers.

        Args:
            frame_idx: The frame index to materialize.
            remapping: Optional local->global ID mapping. IDs missing from
                the mapping are kept unchanged.
            to_cpu: If True, transfer each mask to CPU. Bounding-box
                coordinates are always returned as Python ints.

        Returns:
            ``{obj_id: ObjectInfo}`` for the requested frame (empty when the
            frame has no entries). Entries flagged invalid get an all-zero
            bounding box.
        """
        # Build lazy frame index on first call
        if self._frame_index is None:
            index: Dict[int, List[int]] = {}
            for pos, fi in enumerate(self.frame_indices):
                index.setdefault(fi, []).append(pos)
            self._frame_index = index

        positions = self._frame_index.get(frame_idx)
        if not positions:
            return {}

        # Fetch the per-object scalars for this frame in two batched reads
        # rather than one scalar read per value.
        valid_flags = self.valid[positions].tolist()
        boxes = self.bboxes[positions].tolist()

        result: Dict[int, ObjectInfo] = {}
        for pos, is_valid, box in zip(positions, valid_flags, boxes):
            oid = self.obj_ids[pos]
            global_id = remapping.get(oid, oid) if remapping else oid
            mask = self.masks[pos]
            if to_cpu:
                mask = mask.cpu()
            x1, y1, x2, y2 = box if is_valid else (0, 0, 0, 0)
            result[global_id] = ObjectInfo(
                instance_id=global_id,
                mask=mask,
                class_name=self.class_names[pos],
                x1=x1,
                y1=y1,
                x2=x2,
                y2=y2,
            )
        return result


# ---------------------------------------------------------------------------
# Lazy frame objects wrapper (deferred GPU->CPU per-frame)
# ---------------------------------------------------------------------------

@dataclass
class LazyFrameObjects:
    """Lightweight wrapper for deferred GPU->CPU materialization.

    Holds a reference to a GPU-resident ``SegmentOutput`` plus frame index
    and optional ID remapping. Call ``materialize()`` to perform the GPU->CPU
    transfer (intended to run in a render worker thread).
    """

    segment_output: SegmentOutput
    frame_idx: int
    remapping: Optional[Dict[int, int]] = None

    def materialize(self) -> Dict[int, "ObjectInfo"]:
        """Transfer one frame's data from GPU to CPU and build ObjectInfo dict."""
        return self.segment_output.frame_to_object_dict(
            self.frame_idx,
            remapping=self.remapping,
            to_cpu=True,
        )


# ---------------------------------------------------------------------------
# SAM2 HuggingFace model IDs per size
# ---------------------------------------------------------------------------

_SAM2_HF_MODELS = {
    "small": "facebook/sam2.1-hiera-small",
    "base": "facebook/sam2.1-hiera-base-plus",
    "large": "facebook/sam2.1-hiera-large",
}


def _det_label_names(det) -> List[str]:
    """Extract string labels from a DetectionResult, with fallback.

    Preference order: ``det.label_names``, then ``str()`` of ``det.labels``,
    then ``"object"`` repeated once per box.
    """
    num_boxes = len(det.boxes) if det.boxes is not None else 0
    if det.label_names is not None and len(det.label_names) > 0:
        return list(det.label_names)
    if det.labels is not None and len(det.labels) > 0:
        return [str(label) for label in det.labels]
    return ["object"] * num_boxes


# ---------------------------------------------------------------------------
# Grounded-SAM-2 Segmenter
# ---------------------------------------------------------------------------

class GroundedSAM2Segmenter(Segmenter):
    """SAM2 video segmenter driven by an injected object detector.

    Any ``ObjectDetector`` can be used (defaults to Grounding DINO).

    For single-frame mode (``predict``), uses detector + SAM2 image predictor.
    For video mode (``process_video``), uses detector on keyframes + SAM2
    video predictor for temporal mask propagation with continuous object IDs.

    Timing is accumulated into ``self._perf_metrics`` (guarded by
    ``self._perf_lock`` when present) if those attributes exist; this class
    never creates them itself.
    """

    supports_batch = False
    max_batch_size = 1

    def __init__(
        self,
        model_size: str = "large",
        device: Optional[str] = None,
        step: int = 20,
        iou_threshold: float = 0.5,
        num_maskmem: Optional[int] = None,
        detector_name: Optional[str] = None,
    ):
        """Store configuration; models are loaded lazily on first use.

        Args:
            model_size: One of ``"small"``, ``"base"``, ``"large"``.
            device: Torch device string; defaults to CUDA when available.
            step: Number of frames between detector keyframes.
            iou_threshold: IoU threshold for cross-segment ID matching.
            num_maskmem: Optional override of the video predictor's
                ``num_maskmem``.
            detector_name: Detector to load; ``None`` selects Grounding DINO.

        Raises:
            KeyError: If *model_size* is not a known size.
        """
        self.model_size = model_size
        self.step = step
        self.iou_threshold = iou_threshold
        self.num_maskmem = num_maskmem  # None = use default (7)
        self._detector_name = detector_name  # None = "grounding_dino"

        _size_suffix = {"small": "S", "base": "B", "large": "L"}
        _det_prefix = {"yolo11": "YSAM2"}
        _prefix = _det_prefix.get(detector_name, "GSAM2")
        self.name = f"{_prefix}-{_size_suffix[model_size]}"

        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Lazy-loaded model handles
        self._video_predictor = None
        self._image_predictor = None
        self._detector = None
        self._models_loaded = False

    # -- Internal helpers ---------------------------------------------------

    def _perf_start(self) -> Optional[float]:
        """Return a ``perf_counter`` timestamp, or None when metrics are off."""
        if getattr(self, "_perf_metrics", None) is None:
            return None
        return time.perf_counter()

    def _record_perf(self, key: str, started: Optional[float]) -> None:
        """Add the milliseconds elapsed since *started* to ``_perf_metrics[key]``.

        No-op when metrics are disabled or *started* is None. The update is
        performed under ``self._perf_lock`` when that attribute is set.
        """
        metrics = getattr(self, "_perf_metrics", None)
        if metrics is None or started is None:
            return
        lock = getattr(self, "_perf_lock", None)
        elapsed_ms = (time.perf_counter() - started) * 1000.0
        if lock:
            with lock:
                metrics[key] += elapsed_ms
        else:
            metrics[key] += elapsed_ms

    def _autocast_context(self):
        """Return a bfloat16 autocast context on CUDA, else a null context."""
        device_type = self.device.split(":")[0]
        if device_type == "cuda":
            return torch.autocast(device_type=device_type, dtype=torch.bfloat16)
        return nullcontext()

    # -- Lazy loading -------------------------------------------------------

    def _ensure_models_loaded(self):
        """Load SAM2 predictors and the detector on first call (idempotent)."""
        if self._models_loaded:
            return

        hf_id = _SAM2_HF_MODELS[self.model_size]
        logging.info(
            "Loading Grounded-SAM-2 (%s) on device %s ...", hf_id, self.device
        )

        # Enable TF32 on Ampere+ GPUs
        if torch.cuda.is_available():
            try:
                props = torch.cuda.get_device_properties(
                    int(self.device.split(":")[-1]) if ":" in self.device else 0
                )
                if props.major >= 8:
                    torch.backends.cuda.matmul.allow_tf32 = True
                    torch.backends.cudnn.allow_tf32 = True
            except Exception:
                # Best-effort tuning only: never block model loading on it,
                # but leave a trace for debugging.
                logging.debug("Skipping TF32 setup", exc_info=True)

        from sam2.build_sam import build_sam2_hf, build_sam2_video_predictor_hf
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        # Video predictor (for process_video)
        self._video_predictor = build_sam2_video_predictor_hf(
            hf_id, device=self.device
        )

        # Image predictor (for single-frame predict)
        sam2_image_model = build_sam2_hf(hf_id, device=self.device)
        self._image_predictor = SAM2ImagePredictor(sam2_image_model)

        # Override num_maskmem if requested
        if self.num_maskmem is not None:
            self._patch_num_maskmem(self._video_predictor, self.num_maskmem)
            logging.info("Patched video predictor num_maskmem → %d", self.num_maskmem)

        # Load detector by name (defaults to Grounding DINO)
        from models.model_loader import load_detector_on_device

        det_name = self._detector_name or "grounding_dino"
        self._detector = load_detector_on_device(det_name, self.device)

        self._models_loaded = True
        logging.info("Grounded-SAM-2 models loaded successfully.")

    @staticmethod
    def _patch_num_maskmem(predictor, num_maskmem: int):
        """Override num_maskmem on a loaded SAM2 video predictor at runtime.

        Slices (or zero-pads) the temporal positional encoding parameter to
        match the new memory size so the model runs without shape mismatches.
        Does nothing if the model has no ``num_maskmem`` attribute or the
        value is already the requested one.
        """
        # The underlying model may be predictor itself or predictor.model
        model = getattr(predictor, "model", predictor)
        old = getattr(model, "num_maskmem", None)
        if old is None:
            logging.warning("Cannot patch num_maskmem: attribute not found on model")
            return
        if num_maskmem == old:
            return

        model.num_maskmem = num_maskmem

        # Slice or pad maskmem_tpos_enc (shape: [num_maskmem, 1, 1, mem_dim])
        if hasattr(model, "maskmem_tpos_enc") and model.maskmem_tpos_enc is not None:
            old_enc = model.maskmem_tpos_enc
            if num_maskmem <= old_enc.shape[0]:
                model.maskmem_tpos_enc = torch.nn.Parameter(
                    old_enc[:num_maskmem].clone()
                )
            else:
                # Pad with zeros for the extra slots
                pad = torch.zeros(
                    num_maskmem - old_enc.shape[0],
                    *old_enc.shape[1:],
                    device=old_enc.device,
                    dtype=old_enc.dtype,
                )
                model.maskmem_tpos_enc = torch.nn.Parameter(
                    torch.cat([old_enc, pad], dim=0)
                )

        logging.info("num_maskmem changed from %d to %d", old, num_maskmem)

    # -- GPU-resident SAM2 predict (skip numpy conversion) ------------------

    def _predict_masks_gpu(
        self,
        input_boxes: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run SAM2 image predictor keeping masks on GPU (skip numpy conversion).

        Calls SAM2's internal ``_prep_prompts`` + ``_predict`` directly,
        bypassing the public ``predict()`` which converts to numpy.
        ``set_image`` must have been called on the image predictor first.

        Args:
            input_boxes: (N, 4) float tensor on device.

        Returns:
            ``(masks, scores)`` — *masks* is ``(N, H, W)`` bool GPU tensor,
            *scores* is ``(N,)`` float GPU tensor.
        """
        mask_input, unnorm_coords, labels, unnorm_box = (
            self._image_predictor._prep_prompts(
                point_coords=None,
                point_labels=None,
                box=input_boxes,
                mask_logits=None,
                normalize_coords=True,
            )
        )
        masks, scores, _ = self._image_predictor._predict(
            unnorm_coords,
            labels,
            unnorm_box,
            mask_input,
            multimask_output=False,
            return_logits=False,
        )
        # squeeze(0) only drops a leading dimension of size 1; the ndim
        # checks below then normalize whatever remains to (N, H, W).
        masks = masks.squeeze(0)
        if masks.ndim == 2:
            masks = masks[None]
        elif masks.ndim == 4:
            masks = masks.squeeze(1)
        scores = scores.squeeze(0).flatten()
        return masks, scores

    # -- Single-frame interface (Segmenter.predict) -------------------------

    def predict(
        self, frame: np.ndarray, text_prompts: Optional[list] = None
    ) -> SegmentationResult:
        """Run detector + SAM2 image predictor on a single frame.

        Args:
            frame: Image array; it is passed to the detector as-is and
                converted BGR->RGB before being handed to SAM2.
            text_prompts: Detector queries; defaults to ``["object"]``.

        Returns:
            A ``SegmentationResult`` with ``(N, H, W)`` bool masks. When the
            detector finds nothing, masks has shape ``(0, H, W)`` and the
            other fields are None.
        """
        self._ensure_models_loaded()

        prompts = text_prompts or ["object"]

        # Run detector to get boxes
        det = self._detector.predict(frame, prompts)
        if det.boxes is None or len(det.boxes) == 0:
            return SegmentationResult(
                masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
                scores=None,
                boxes=None,
                label_names=None,
            )

        # SAM2 image predictor expects RGB
        import cv2 as _cv2

        frame_rgb = _cv2.cvtColor(frame, _cv2.COLOR_BGR2RGB)

        with self._autocast_context():
            self._image_predictor.set_image(frame_rgb)
            input_boxes = torch.tensor(det.boxes, device=self.device, dtype=torch.float32)
            masks, scores, _ = self._image_predictor.predict(
                point_coords=None,
                point_labels=None,
                box=input_boxes,
                multimask_output=False,
            )

        # Normalize mask shape to (N, H, W)
        if masks.ndim == 2:
            masks = masks[None]
        elif masks.ndim == 4:
            masks = masks.squeeze(1)

        if isinstance(masks, torch.Tensor):
            masks_np = masks.cpu().numpy().astype(bool)
        else:
            masks_np = np.asarray(masks).astype(bool)

        scores_np = None
        if scores is not None:
            if isinstance(scores, torch.Tensor):
                scores_np = scores.cpu().numpy().flatten()
            else:
                scores_np = np.asarray(scores).flatten()

        return SegmentationResult(
            masks=masks_np,
            scores=scores_np,
            boxes=det.boxes,
            label_names=det.label_names,
        )

    # -- Multi-GPU helper methods -------------------------------------------

    def detect_keyframe(
        self,
        image: "Image.Image",
        text_prompts: List[str],
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], List[str]]:
        """Run detector + SAM2 image predictor on a single keyframe.

        Detector time is accumulated under ``gdino_total_ms`` and SAM2 image
        time under ``sam_image_total_ms`` when perf metrics are enabled. No
        autocast context is entered here; callers that want one must wrap
        the call.

        Args:
            image: PIL Image in RGB mode.
            text_prompts: Text queries for the detector.

        Returns:
            ``(masks, boxes, labels)`` where *masks* is an ``(N, H, W)`` bool
            GPU tensor, *boxes* is an ``(N, 4)`` tensor on device, and
            *labels* is a list of strings. Returns ``(None, None, [])`` when
            no objects are detected.
        """
        self._ensure_models_loaded()

        t_det = self._perf_start()
        frame_rgb = np.array(image)
        # Convert PIL RGB → numpy BGR for detector.predict()
        frame_bgr = frame_rgb[:, :, ::-1].copy()
        det = self._detector.predict(frame_bgr, text_prompts)
        self._record_perf("gdino_total_ms", t_det)

        if det.boxes is None or len(det.boxes) == 0:
            return None, None, []

        input_boxes = torch.tensor(det.boxes, device=self.device, dtype=torch.float32)
        det_labels = _det_label_names(det)

        # SAM2 image predictor
        t_sam = self._perf_start()
        self._image_predictor.set_image(frame_rgb)
        masks, _ = self._predict_masks_gpu(input_boxes)
        self._record_perf("sam_image_total_ms", t_sam)

        return masks, input_boxes, det_labels

    def propagate_segment(
        self,
        inference_state: Any,
        start_idx: int,
        mask_dict: "MaskDictionary",
        step: int,
    ) -> "SegmentOutput":
        """Propagate masks for a single segment via SAM2 video predictor.

        Resets *inference_state*, seeds it at *start_idx* with every mask in
        *mask_dict*, and collects the propagated masks, boxes and validity
        flags into contiguous device buffers. Elapsed time is accumulated
        under ``sam_video_total_ms`` when perf metrics are enabled.

        Returns a GPU-resident ``SegmentOutput`` with zero CUDA syncs.
        Call ``output.frame_to_object_dict()`` to materialize per-frame CPU dicts.
        """
        t_start = self._perf_start()

        self._video_predictor.reset_state(inference_state)
        for obj_id, obj_info in mask_dict.labels.items():
            self._video_predictor.add_new_mask(
                inference_state, start_idx, obj_id, obj_info.mask,
            )

        # Pre-compute class name lookup (avoid repeated dict access in loop)
        obj_id_to_class = {
            oid: mask_dict.get_target_class_name(oid) for oid in mask_dict.labels
        }
        n_obj = len(mask_dict.labels)

        # Pre-allocated GPU buffers (allocated on first yield when H, W known)
        masks_buf = bboxes_buf = valid_buf = None
        frame_indices: List[int] = []
        obj_ids_list: List[int] = []
        class_names_list: List[str] = []
        cursor = 0

        for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
            inference_state,
            max_frame_num_to_track=step,
            start_frame_idx=start_idx,
        ):
            bool_masks = (out_mask_logits[:, 0] > 0.0)  # (N, H, W) GPU async
            n = bool_masks.shape[0]

            # Allocate on first yield
            if masks_buf is None:
                H, W = bool_masks.shape[1], bool_masks.shape[2]
                # NOTE(review): SAM2's propagate_in_video appears to yield the
                # start frame plus up to ``max_frame_num_to_track`` following
                # frames (step + 1 in total) — confirm against the installed
                # SAM2 version. Sizing for step + 1 avoids a reallocating
                # torch.cat on the final frame; the grow path below still
                # covers any mismatch.
                max_entries = (step + 1) * max(n_obj, n)
                masks_buf = torch.empty(max_entries, H, W, dtype=torch.bool, device=self.device)
                bboxes_buf = torch.empty(max_entries, 4, dtype=torch.int64, device=self.device)
                valid_buf = torch.empty(max_entries, dtype=torch.bool, device=self.device)

            # Grow buffers if needed (unlikely but safe)
            if cursor + n > masks_buf.shape[0]:
                grow = max(step * n_obj, cursor + n - masks_buf.shape[0])
                H, W = masks_buf.shape[1], masks_buf.shape[2]
                masks_buf = torch.cat(
                    [masks_buf, torch.empty(grow, H, W, dtype=torch.bool, device=self.device)]
                )
                bboxes_buf = torch.cat(
                    [bboxes_buf, torch.empty(grow, 4, dtype=torch.int64, device=self.device)]
                )
                valid_buf = torch.cat(
                    [valid_buf, torch.empty(grow, dtype=torch.bool, device=self.device)]
                )

            # Inline bbox — GPU async, zero sync
            frame_bboxes, frame_valid = _bbox_gpu(bool_masks)

            # Fill pre-allocated slices — GPU async
            masks_buf[cursor:cursor + n] = bool_masks
            bboxes_buf[cursor:cursor + n] = frame_bboxes
            valid_buf[cursor:cursor + n] = frame_valid

            # Metadata (trivial Python, ~2μs GIL)
            oid_list = out_obj_ids if isinstance(out_obj_ids, list) else list(out_obj_ids)
            for oid in oid_list:
                frame_indices.append(out_frame_idx)
                obj_ids_list.append(oid)
                class_names_list.append(obj_id_to_class.get(oid, ""))
            cursor += n

        # Build output (zero-copy slice if under-filled, empty tensors if no frames)
        if masks_buf is not None:
            output = SegmentOutput(
                masks=masks_buf[:cursor],
                bboxes=bboxes_buf[:cursor],
                valid=valid_buf[:cursor],
                frame_indices=frame_indices,
                obj_ids=obj_ids_list,
                class_names=class_names_list,
                device=self.device,
            )
        else:
            output = SegmentOutput(
                masks=torch.empty(0, 0, 0, dtype=torch.bool, device=self.device),
                bboxes=torch.empty(0, 4, dtype=torch.int64, device=self.device),
                valid=torch.empty(0, dtype=torch.bool, device=self.device),
                frame_indices=[],
                obj_ids=[],
                class_names=[],
                device=self.device,
            )

        self._record_perf("sam_video_total_ms", t_start)
        return output

    # -- Video-level tracking interface -------------------------------------

    def process_video(
        self,
        frame_dir: str,
        frame_names: List[str],
        text_prompts: List[str],
        on_segment: Optional[Callable[[Dict[int, Dict[int, "ObjectInfo"]]], None]] = None,
        on_segment_output: Optional[Callable[["SegmentOutput"], None]] = None,
        _ttfs_t0: Optional[float] = None,
        _ttfs_job_id: Optional[str] = None,
        frame_store=None,
    ) -> Dict[int, Dict[int, ObjectInfo]]:
        """Run full Grounded-SAM-2 tracking pipeline on extracted JPEG frames.

        Every ``self.step`` frames the detector and SAM2 image predictor are
        run on a keyframe, detections are matched by IoU against the last
        propagated frame to keep IDs persistent, and the SAM2 video predictor
        propagates the masks through the segment.

        Args:
            frame_dir: Directory containing JPEG frames. When *frame_store*
                is given it is only used to initialize the predictor state.
            frame_names: Sorted list of frame filenames.
            text_prompts: Text queries for the detector.
            on_segment: Optional callback invoked with
                ``{frame_idx: {obj_id: ObjectInfo}}`` for segments that are
                NOT propagated (no detections, or only empty masks).
            on_segment_output: Optional callback invoked with the
                GPU-resident ``SegmentOutput`` of each propagated segment.
            _ttfs_t0: Optional ``perf_counter`` origin; when set, the time
                until the first segment completes is logged.
            _ttfs_job_id: Job identifier included in that log line.
            frame_store: Optional frame provider. When given, keyframes come
                from ``get_pil_rgb`` and the predictor state's images and
                dimensions are patched from it instead of read from disk.

        Returns:
            Dict mapping frame_idx -> {obj_id: ObjectInfo}. Propagated
            segments contribute only their last frame (the rest is delivered
            through *on_segment_output*); non-propagated segments contribute
            an entry for each of their frames.
        """
        self._ensure_models_loaded()

        device = self.device
        step = self.step
        total_frames = len(frame_names)
        logging.info(
            "Grounded-SAM-2 tracking: %d frames, step=%d, queries=%s",
            total_frames, step, text_prompts,
        )

        sam2_masks = MaskDictionary()
        objects_count = 0
        all_results: Dict[int, Dict[int, ObjectInfo]] = {}

        # Single global autocast context (matches reference implementation)
        with self._autocast_context():
            # Init SAM2 video predictor state
            t_init = self._perf_start()
            if frame_store is not None:
                inference_state = self._video_predictor.init_state(
                    video_path=frame_dir,  # dummy dir with 1 JPEG
                    offload_video_to_cpu=True,
                    async_loading_frames=False,
                )
                # Clear cached_features (dummy frame 0's backbone features)
                inference_state["cached_features"] = {}
                # Patch in real frame data
                img_size = self._video_predictor.image_size
                inference_state["images"] = frame_store.sam2_adapter(image_size=img_size)
                inference_state["num_frames"] = len(frame_store)
                inference_state["video_height"] = frame_store.height
                inference_state["video_width"] = frame_store.width
            else:
                inference_state = self._video_predictor.init_state(
                    video_path=frame_dir,
                    offload_video_to_cpu=True,
                    async_loading_frames=True,
                )
            self._record_perf("init_state_ms", t_init)

            for start_idx in range(0, total_frames, step):
                logging.info("Processing keyframe %d / %d", start_idx, total_frames)

                if frame_store is not None:
                    image = frame_store.get_pil_rgb(start_idx)
                else:
                    # Close the file handle as soon as the pixels are decoded.
                    with Image.open(os.path.join(frame_dir, frame_names[start_idx])) as raw:
                        image = raw.convert("RGB")

                segment_end = min(start_idx + step, total_frames)

                # -- Detector + SAM2 image predictor on keyframe --
                masks, input_boxes, det_labels = self.detect_keyframe(image, text_prompts)

                if masks is None:
                    logging.info(
                        "No detections on keyframe %d, propagating previous masks", start_idx
                    )
                    # Fill empty results for this segment
                    seg_results: Dict[int, Dict[int, ObjectInfo]] = {}
                    for fi in range(start_idx, segment_end):
                        if fi not in all_results:
                            # Carry forward last known masks
                            all_results[fi] = {
                                k: ObjectInfo(
                                    instance_id=v.instance_id,
                                    mask=v.mask,
                                    class_name=v.class_name,
                                    x1=v.x1, y1=v.y1, x2=v.x2, y2=v.y2,
                                )
                                for k, v in sam2_masks.labels.items()
                            } if sam2_masks.labels else {}
                        seg_results[fi] = all_results[fi]
                    if on_segment and seg_results:
                        on_segment(seg_results)
                    if start_idx == 0 and _ttfs_t0 is not None:
                        logging.info(
                            "[TTFS:%s] +%.1fs first_segment_complete (no detections, step=%d)",
                            _ttfs_job_id, time.perf_counter() - _ttfs_t0, step,
                        )
                    continue

                mask_dict = MaskDictionary()
                mask_dict.add_new_frame_annotation(
                    mask_list=masks,
                    box_list=input_boxes,
                    label_list=det_labels,
                )

                # -- IoU matching to maintain persistent IDs --
                t_match = self._perf_start()
                objects_count = mask_dict.update_masks(
                    tracking_dict=sam2_masks,
                    iou_threshold=self.iou_threshold,
                    objects_count=objects_count,
                )
                self._record_perf("id_reconciliation_ms", t_match)

                if len(mask_dict.labels) == 0:
                    seg_results_empty: Dict[int, Dict[int, ObjectInfo]] = {}
                    for fi in range(start_idx, segment_end):
                        all_results[fi] = {}
                        seg_results_empty[fi] = {}
                    if on_segment:
                        on_segment(seg_results_empty)
                    if start_idx == 0 and _ttfs_t0 is not None:
                        logging.info(
                            "[TTFS:%s] +%.1fs first_segment_complete (empty masks, step=%d)",
                            _ttfs_job_id, time.perf_counter() - _ttfs_t0, step,
                        )
                    continue

                # -- SAM2 video predictor: propagate masks --
                # NOTE: propagate_segment() already accumulates into
                # _perf_metrics["sam_video_total_ms"], so no outer timer here.
                segment_output = self.propagate_segment(
                    inference_state, start_idx, mask_dict, step,
                )

                # GPU-deferred path: only materialize last frame for IoU
                last_fi = segment_output.last_frame_idx()
                if last_fi is not None:
                    last_frame_objects = segment_output.frame_to_object_dict(
                        last_fi, to_cpu=True,
                    )
                    all_results[last_fi] = last_frame_objects
                    sam2_masks = MaskDictionary()
                    sam2_masks.labels = copy.deepcopy(last_frame_objects)
                    if last_frame_objects:
                        first_info = next(iter(last_frame_objects.values()))
                        if first_info.mask is not None:
                            has_hw = first_info.mask.ndim >= 2
                            sam2_masks.mask_height = first_info.mask.shape[-2] if has_hw else 0
                            sam2_masks.mask_width = first_info.mask.shape[-1] if has_hw else 0

                if on_segment_output is not None:
                    on_segment_output(segment_output)

                if start_idx == 0 and _ttfs_t0 is not None:
                    logging.info(
                        "[TTFS:%s] +%.1fs first_segment_complete (step=%d)",
                        _ttfs_job_id, time.perf_counter() - _ttfs_t0, step,
                    )

        logging.info(
            "Grounded-SAM-2 tracking complete: %d frames, %d tracked objects",
            len(all_results), objects_count,
        )
        return all_results