Spaces:
Paused
Paused
| """Grounded-SAM-2 segmenter with continuous-ID video tracking. | |
| Combines an object detector (open-vocabulary or closed-set) with SAM2's video | |
| predictor to produce temporally consistent segmentation masks with | |
| persistent object IDs across an entire video. | |
| Reference implementation: | |
| Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id.py | |
| """ | |
| import copy | |
| import logging | |
| import time | |
| from contextlib import nullcontext | |
| from dataclasses import dataclass, field | |
| from typing import Any, Callable, Dict, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from .base import Segmenter, SegmentationResult | |
| # --------------------------------------------------------------------------- | |
| # Data structures (mirrors Grounded-SAM-2 reference utilities) | |
| # --------------------------------------------------------------------------- | |
@dataclass
class ObjectInfo:
    """Per-object tracking info for a single frame.

    A plain record: every field has a default, so instances can be built
    with any subset of keyword arguments (as the rest of this module does).
    """

    # Identifier of the object this record describes.
    instance_id: int = 0
    # Segmentation mask; torch.Tensor bool (H, W) when set.
    mask: Any = None
    # Label of the object ("" when unknown).
    class_name: str = ""
    # Bounding-box corners, filled from [x_min, y_min, x_max, y_max] boxes.
    x1: int = 0
    y1: int = 0
    x2: int = 0
    y2: int = 0
@dataclass
class MaskDictionary:
    """Tracks object masks across frames with IoU-based ID matching.

    ``labels`` maps an object ID to its ``ObjectInfo``. A freshly annotated
    frame uses local IDs (1..N in detection order); ``update_masks`` /
    ``update_masks_with_remapping`` then re-key the entries to persistent IDs
    by matching each mask against the masks of an already-tracked dictionary.
    """

    mask_height: int = 0
    mask_width: int = 0
    labels: Dict[int, "ObjectInfo"] = field(default_factory=dict)

    def add_new_frame_annotation(
        self,
        mask_list: torch.Tensor,
        box_list: torch.Tensor,
        label_list: list,
    ):
        """Replace ``labels`` with one entry per (mask, box, label) triple.

        Args:
            mask_list: Stack of masks; the last two dimensions are taken as
                (height, width).
            box_list: One box per mask, indexed as ``[x1, y1, x2, y2]``.
            label_list: One label per mask; stored as ``str(label)``.

        Entries are keyed 1..N in input order. ``zip`` stops at the shortest
        input, so surplus masks/boxes/labels are ignored.
        """
        anno: Dict[int, ObjectInfo] = {}
        for local_id, (mask, box, label) in enumerate(
            zip(mask_list, box_list, label_list), start=1
        ):
            anno[local_id] = ObjectInfo(
                instance_id=local_id,
                mask=mask,
                class_name=str(label),
                x1=int(box[0]),
                y1=int(box[1]),
                x2=int(box[2]),
                y2=int(box[3]),
            )
        # Only the frame size is recorded; no combined label image is kept,
        # so nothing needs to be painted (and no cross-device indexing occurs).
        self.mask_height = int(mask_list.shape[-2])
        self.mask_width = int(mask_list.shape[-1])
        self.labels = anno

    def update_masks(
        self,
        tracking_dict: "MaskDictionary",
        iou_threshold: float = 0.5,
        objects_count: int = 0,
    ) -> int:
        """Match current detections against tracked objects via IoU.

        Thin wrapper over ``update_masks_with_remapping`` that discards the
        ID remapping.

        Returns:
            The updated ``objects_count`` (highest ID handed out so far).
        """
        objects_count, _ = self.update_masks_with_remapping(
            tracking_dict=tracking_dict,
            iou_threshold=iou_threshold,
            objects_count=objects_count,
        )
        return objects_count

    def update_masks_with_remapping(
        self,
        tracking_dict: "MaskDictionary",
        iou_threshold: float = 0.5,
        objects_count: int = 0,
    ) -> Tuple[int, Dict[int, int]]:
        """Match detections against tracked objects, returning ID remapping.

        For every entry of ``self.labels`` with a non-empty mask, the tracked
        object with the highest IoU strictly above ``iou_threshold`` (and not
        already claimed by an earlier entry) donates its ID. Entries without
        a match receive the next fresh ID (``objects_count + 1``). Entries
        whose mask is ``None`` or all-zero are dropped.

        Args:
            tracking_dict: Dictionary holding the already-tracked objects.
            iou_threshold: Minimum IoU (exclusive) for a match.
            objects_count: Highest ID handed out so far.

        Returns:
            ``(objects_count, remapping)`` where *remapping* maps each kept
            original (local) ID to its assigned (global) ID. ``self.labels``
            is replaced by the re-keyed entries.
        """
        updated: Dict[int, ObjectInfo] = {}
        remapping: Dict[int, int] = {}
        used_tracked_ids = set()

        for seg_id, seg_info in self.labels.items():
            if seg_info.mask is None or seg_info.mask.sum() == 0:
                continue

            matched_id = 0
            best_iou = iou_threshold
            for obj_info in tracking_dict.labels.values():
                if obj_info.instance_id in used_tracked_ids:
                    continue
                iou = self._iou(seg_info.mask, obj_info.mask)
                if iou > best_iou:
                    best_iou = iou
                    matched_id = obj_info.instance_id

            if matched_id:
                used_tracked_ids.add(matched_id)
            else:
                objects_count += 1
                matched_id = objects_count

            updated[matched_id] = ObjectInfo(
                instance_id=matched_id,
                mask=seg_info.mask,
                class_name=seg_info.class_name,
                # Keep the detection box instead of resetting it to zeros.
                x1=seg_info.x1,
                y1=seg_info.y1,
                x2=seg_info.x2,
                y2=seg_info.y2,
            )
            remapping[seg_id] = matched_id

        self.labels = updated
        return objects_count, remapping

    def get_target_class_name(self, instance_id: int) -> str:
        """Return the class name stored for ``instance_id`` or ``""``."""
        info = self.labels.get(instance_id)
        return info.class_name if info else ""

    @staticmethod
    def _iou(m1: torch.Tensor, m2: torch.Tensor) -> float:
        """Return intersection-over-union of two masks as a Python float.

        Non-tensor inputs are converted with ``torch.as_tensor``. Returns
        ``0.0`` when both masks are empty.
        """
        if not torch.is_tensor(m1):
            m1 = torch.as_tensor(m1)
        if not torch.is_tensor(m2):
            m2 = torch.as_tensor(m2)
        # Multi-GPU reconciliation can compare masks produced on different
        # devices; normalize both masks onto CPU before arithmetic.
        if m1.device != m2.device:
            m1 = m1.detach().to(device="cpu")
            m2 = m2.detach().to(device="cpu")
        m1f = m1.to(torch.float32)
        m2f = m2.to(torch.float32)
        inter = (m1f * m2f).sum()
        union = m1f.sum() + m2f.sum() - inter
        if union == 0:
            return 0.0
        return float((inter / union).item())
| # --------------------------------------------------------------------------- | |
| # GPU-resident bounding-box helper (zero CUDA syncs) | |
| # --------------------------------------------------------------------------- | |
def _bbox_gpu(bool_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Derive one axis-aligned box per mask using tensor ops only.

    No ``.item()`` / ``.cpu()`` calls are made, so results stay on the input's
    device.

    Args:
        bool_masks: ``(N, H, W)`` boolean masks.

    Returns:
        ``(bboxes, valid)``: *bboxes* is ``(N, 4)`` int64 laid out as
        ``[x_min, y_min, x_max, y_max]``; *valid* is ``(N,)`` bool and is True
        for masks with at least one set pixel. Box rows whose *valid* entry is
        False are placeholders and must be ignored by the caller.
    """
    _, height, width = bool_masks.shape

    row_any = bool_masks.any(dim=2)  # (N, H): rows containing a set pixel
    col_any = bool_masks.any(dim=1)  # (N, W): columns containing a set pixel

    # argmax over a 0/1 float vector yields the index of the first 1;
    # flipping first locates the last 1 instead.
    row_score = row_any.float()
    col_score = col_any.float()
    left = col_score.argmax(dim=1)
    top = row_score.argmax(dim=1)
    right = width - 1 - col_score.flip(1).argmax(dim=1)
    bottom = height - 1 - row_score.flip(1).argmax(dim=1)

    corners = torch.stack((left, top, right, bottom), dim=1).to(torch.int64)
    non_empty = row_any.any(dim=1)
    return corners, non_empty
| # --------------------------------------------------------------------------- | |
| # GPU-resident segment output (deferred CPU materialization) | |
| # --------------------------------------------------------------------------- | |
@dataclass
class SegmentOutput:
    """Device-resident segment propagation result.

    Holds flat, parallel buffers: entry ``i`` of every tensor/list describes
    one (frame, object) pair. Constructing the object only stores references;
    per-frame data is materialized on demand by ``frame_to_object_dict``.
    """

    masks: torch.Tensor        # (count, H, W) bool
    bboxes: torch.Tensor       # (count, 4) int64, [x_min, y_min, x_max, y_max]
    valid: torch.Tensor        # (count,) bool, False for empty masks
    frame_indices: List[int]   # len == count
    obj_ids: List[int]         # len == count
    class_names: List[str]     # len == count
    device: str = "cpu"

    def last_frame_idx(self) -> Optional[int]:
        """Return the frame index of the final entry, or None when empty."""
        return self.frame_indices[-1] if self.frame_indices else None

    def _positions_by_frame(self) -> Dict[int, List[int]]:
        """Return (building once) the map frame index -> buffer positions.

        The map is cached on the instance as ``_frame_index`` and is not
        refreshed if ``frame_indices`` is mutated afterwards.
        """
        index = getattr(self, "_frame_index", None)
        if index is None:
            index = {}
            for position, frame in enumerate(self.frame_indices):
                index.setdefault(frame, []).append(position)
            self._frame_index = index
        return index

    def frame_to_object_dict(
        self,
        frame_idx: int,
        remapping: Optional[Dict[int, int]] = None,
        to_cpu: bool = True,
    ) -> Dict[int, "ObjectInfo"]:
        """Materialize a single frame's ObjectInfo dict from the buffers.

        Args:
            frame_idx: The frame index to materialize.
            remapping: Optional local->global ID mapping; IDs absent from the
                mapping are kept unchanged.
            to_cpu: If True, the returned masks are moved to CPU. Box
                coordinates are always returned as Python ints.

        Returns:
            ``{obj_id: ObjectInfo}`` for the requested frame (empty when the
            frame is not part of this segment). Entries flagged invalid get an
            all-zero box.
        """
        positions = self._positions_by_frame().get(frame_idx)
        if not positions:
            return {}

        result: Dict[int, ObjectInfo] = {}
        for i in positions:
            oid = self.obj_ids[i]
            global_id = remapping.get(oid, oid) if remapping else oid

            mask = self.masks[i]
            if to_cpu:
                mask = mask.cpu()

            if bool(self.valid[i]):
                # One transfer for all four coordinates instead of four.
                x1, y1, x2, y2 = (int(c) for c in self.bboxes[i].tolist())
            else:
                x1 = y1 = x2 = y2 = 0

            result[global_id] = ObjectInfo(
                instance_id=global_id,
                mask=mask,
                class_name=self.class_names[i],
                x1=x1, y1=y1, x2=x2, y2=y2,
            )
        return result
| # --------------------------------------------------------------------------- | |
| # Lazy frame objects wrapper (deferred GPU->CPU per-frame) | |
| # --------------------------------------------------------------------------- | |
@dataclass
class LazyFrameObjects:
    """Lightweight wrapper for deferred GPU->CPU materialization.

    Holds a reference to a ``SegmentOutput`` plus a frame index and an
    optional ID remapping. Call ``materialize()`` to perform the transfer
    (intended to run in a render worker thread).
    """

    segment_output: "SegmentOutput"
    frame_idx: int
    remapping: Optional[Dict[int, int]] = None

    def materialize(self) -> Dict[int, "ObjectInfo"]:
        """Build this frame's ``{obj_id: ObjectInfo}`` dict with CPU masks.

        Delegates to ``SegmentOutput.frame_to_object_dict`` with
        ``to_cpu=True`` and the stored remapping.
        """
        return self.segment_output.frame_to_object_dict(
            self.frame_idx, remapping=self.remapping, to_cpu=True,
        )
| # --------------------------------------------------------------------------- | |
| # SAM2 HuggingFace model IDs per size | |
| # --------------------------------------------------------------------------- | |
# Maps the ``model_size`` option to the HuggingFace repo ID handed to the
# SAM2 builders when the models are loaded.
_SAM2_HF_MODELS = dict(
    small="facebook/sam2.1-hiera-small",
    base="facebook/sam2.1-hiera-base-plus",
    large="facebook/sam2.1-hiera-large",
)
def _det_label_names(det) -> List[str]:
    """Return one string label per detection, using the best source available.

    Preference order: ``det.label_names`` when non-empty, then ``det.labels``
    (each converted with ``str``), otherwise ``"object"`` repeated once per
    box in ``det.boxes``.
    """
    box_total = 0 if det.boxes is None else len(det.boxes)

    names = det.label_names
    if names is not None and len(names) > 0:
        return list(names)

    raw_labels = det.labels
    if raw_labels is None or len(raw_labels) == 0:
        return ["object"] * box_total
    return list(map(str, raw_labels))
| # --------------------------------------------------------------------------- | |
| # Grounded-SAM-2 Segmenter | |
| # --------------------------------------------------------------------------- | |
class GroundedSAM2Segmenter(Segmenter):
    """SAM2 video segmenter driven by an injected object detector.

    Any ``ObjectDetector`` can be used (defaults to Grounding DINO).

    For single-frame mode (``predict``), uses detector + SAM2 image predictor.
    For video mode (``process_video``), uses detector on keyframes + SAM2 video
    predictor for temporal mask propagation with continuous object IDs.

    Optional instrumentation: when an attribute ``_perf_metrics`` (a mutable
    mapping of counter name -> milliseconds) is set on the instance, elapsed
    times are accumulated into it, guarded by ``_perf_lock`` when that
    attribute is set as well.
    """

    supports_batch = False
    max_batch_size = 1

    def __init__(
        self,
        model_size: str = "large",
        device: Optional[str] = None,
        step: int = 20,
        iou_threshold: float = 0.5,
        num_maskmem: Optional[int] = None,
        detector_name: Optional[str] = None,
    ):
        """Store configuration; models are loaded lazily on first use.

        Args:
            model_size: One of ``"small"``, ``"base"``, ``"large"``.
            device: Torch device string; defaults to ``"cuda"`` when
                available, else ``"cpu"``.
            step: Keyframe interval (in frames) used by ``process_video``.
            iou_threshold: IoU above which a new detection is matched to an
                already-tracked object.
            num_maskmem: Optional override for the video predictor's
                ``num_maskmem``; ``None`` keeps the model's own value.
            detector_name: Detector to load; ``None`` means
                ``"grounding_dino"``.

        Raises:
            KeyError: If ``model_size`` is not a known size.
        """
        self.model_size = model_size
        self.step = step
        self.iou_threshold = iou_threshold
        self.num_maskmem = num_maskmem  # None = keep the model's default
        self._detector_name = detector_name  # None = "grounding_dino"

        _size_suffix = {"small": "S", "base": "B", "large": "L"}
        _det_prefix = {"yolo11": "YSAM2"}
        _prefix = _det_prefix.get(detector_name, "GSAM2")
        self.name = f"{_prefix}-{_size_suffix[model_size]}"

        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Lazy-loaded model handles
        self._video_predictor = None
        self._image_predictor = None
        self._detector = None
        self._models_loaded = False

    # -- Perf instrumentation helpers ---------------------------------------

    def _perf_timer_start(self) -> Optional[float]:
        """Return a ``perf_counter`` timestamp, or None when metrics are off."""
        if getattr(self, "_perf_metrics", None) is None:
            return None
        return time.perf_counter()

    def _perf_timer_stop(self, key: str, started: Optional[float]) -> None:
        """Add the milliseconds elapsed since ``started`` to metric ``key``.

        No-op when metrics are disabled or ``started`` is None.
        """
        metrics = getattr(self, "_perf_metrics", None)
        if metrics is None or started is None:
            return
        elapsed_ms = (time.perf_counter() - started) * 1000.0
        lock = getattr(self, "_perf_lock", None)
        if lock:
            with lock:
                metrics[key] += elapsed_ms
        else:
            metrics[key] += elapsed_ms

    # -- Lazy loading -------------------------------------------------------

    def _ensure_models_loaded(self):
        """Load SAM2 predictors and the detector on first call; no-op after."""
        if self._models_loaded:
            return

        hf_id = _SAM2_HF_MODELS[self.model_size]
        logging.info(
            "Loading Grounded-SAM-2 (%s) on device %s ...", hf_id, self.device
        )

        # Enable TF32 on Ampere+ GPUs (best effort: never blocks loading)
        if torch.cuda.is_available():
            try:
                device_index = (
                    int(self.device.split(":")[-1]) if ":" in self.device else 0
                )
                props = torch.cuda.get_device_properties(device_index)
                if props.major >= 8:
                    torch.backends.cuda.matmul.allow_tf32 = True
                    torch.backends.cudnn.allow_tf32 = True
            except Exception:
                logging.debug(
                    "Skipping TF32 setup for device %s", self.device, exc_info=True
                )

        from sam2.build_sam import build_sam2_hf, build_sam2_video_predictor_hf
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        # Video predictor (for process_video)
        self._video_predictor = build_sam2_video_predictor_hf(
            hf_id, device=self.device
        )

        # Image predictor (for single-frame predict)
        sam2_image_model = build_sam2_hf(hf_id, device=self.device)
        self._image_predictor = SAM2ImagePredictor(sam2_image_model)

        # Override num_maskmem if requested
        if self.num_maskmem is not None:
            self._patch_num_maskmem(self._video_predictor, self.num_maskmem)
            logging.info("Patched video predictor num_maskmem → %d", self.num_maskmem)

        # Load detector by name (defaults to Grounding DINO)
        from models.model_loader import load_detector_on_device

        det_name = self._detector_name or "grounding_dino"
        self._detector = load_detector_on_device(det_name, self.device)

        self._models_loaded = True
        logging.info("Grounded-SAM-2 models loaded successfully.")

    @staticmethod
    def _patch_num_maskmem(predictor, num_maskmem: int):
        """Override num_maskmem on a loaded SAM2 video predictor at runtime.

        Slices (or zero-pads) the temporal positional encoding parameter to
        match the new memory size so the model runs without shape mismatches.
        Does nothing when the attribute is missing or already equal.
        """
        import torch.nn as nn

        # The underlying model may be predictor itself or predictor.model
        model = getattr(predictor, "model", predictor)
        old = getattr(model, "num_maskmem", None)
        if old is None:
            logging.warning("Cannot patch num_maskmem: attribute not found on model")
            return
        if num_maskmem == old:
            return

        model.num_maskmem = num_maskmem

        # Slice or pad maskmem_tpos_enc (shape: [num_maskmem, 1, 1, mem_dim])
        if hasattr(model, "maskmem_tpos_enc") and model.maskmem_tpos_enc is not None:
            old_enc = model.maskmem_tpos_enc
            if num_maskmem <= old_enc.shape[0]:
                model.maskmem_tpos_enc = nn.Parameter(
                    old_enc[:num_maskmem].clone()
                )
            else:
                # Pad with zeros for the extra slots
                pad = torch.zeros(
                    num_maskmem - old_enc.shape[0], *old_enc.shape[1:],
                    device=old_enc.device, dtype=old_enc.dtype,
                )
                model.maskmem_tpos_enc = nn.Parameter(
                    torch.cat([old_enc, pad], dim=0)
                )

        logging.info("num_maskmem changed from %d to %d", old, num_maskmem)

    # -- GPU-resident SAM2 predict (skip numpy conversion) ------------------

    def _predict_masks_gpu(
        self, input_boxes: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run SAM2 image predictor keeping masks as tensors (no numpy).

        Calls SAM2's internal ``_prep_prompts`` + ``_predict`` directly,
        bypassing the public ``predict()`` which converts to numpy.
        ``set_image`` must have been called on the image predictor first.

        Args:
            input_boxes: (N, 4) float tensor on device.

        Returns:
            ``(masks, scores)`` — *masks* normalized to ``(N, H, W)``,
            *scores* flattened to 1-D.
        """
        mask_input, unnorm_coords, labels, unnorm_box = (
            self._image_predictor._prep_prompts(
                point_coords=None,
                point_labels=None,
                box=input_boxes,
                mask_logits=None,
                normalize_coords=True,
            )
        )
        masks, scores, _ = self._image_predictor._predict(
            unnorm_coords, labels, unnorm_box, mask_input,
            multimask_output=False, return_logits=False,
        )
        # Drop a leading singleton dim if present, then normalize to (N, H, W)
        masks = masks.squeeze(0)
        if masks.ndim == 2:
            masks = masks[None]
        elif masks.ndim == 4:
            masks = masks.squeeze(1)
        scores = scores.squeeze(0).flatten()
        return masks, scores

    # -- Single-frame interface (Segmenter.predict) -------------------------

    def predict(
        self, frame: np.ndarray, text_prompts: Optional[list] = None
    ) -> SegmentationResult:
        """Run detector + SAM2 image predictor on a single frame.

        Args:
            frame: Image array; converted BGR->RGB before SAM2 sees it.
            text_prompts: Detector queries; defaults to ``["object"]``.

        Returns:
            ``SegmentationResult`` with ``(N, H, W)`` bool masks. When nothing
            is detected, masks has N == 0 and the other fields are ``None``.
        """
        self._ensure_models_loaded()

        prompts = text_prompts or ["object"]

        # Run detector to get boxes
        det = self._detector.predict(frame, prompts)

        if det.boxes is None or len(det.boxes) == 0:
            return SegmentationResult(
                masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
                scores=None,
                boxes=None,
                label_names=None,
            )

        # SAM2 image predictor expects RGB
        import cv2 as _cv2

        frame_rgb = _cv2.cvtColor(frame, _cv2.COLOR_BGR2RGB)

        device_type = self.device.split(":")[0]
        autocast_ctx = (
            torch.autocast(device_type=device_type, dtype=torch.bfloat16)
            if device_type == "cuda"
            else nullcontext()
        )

        with autocast_ctx:
            self._image_predictor.set_image(frame_rgb)
            input_boxes = torch.tensor(det.boxes, device=self.device, dtype=torch.float32)
            masks, scores, _ = self._image_predictor.predict(
                point_coords=None,
                point_labels=None,
                box=input_boxes,
                multimask_output=False,
            )

        # Normalize mask shape to (N, H, W)
        if masks.ndim == 2:
            masks = masks[None]
        elif masks.ndim == 4:
            masks = masks.squeeze(1)

        if isinstance(masks, torch.Tensor):
            masks_np = masks.cpu().numpy().astype(bool)
        else:
            masks_np = np.asarray(masks).astype(bool)

        scores_np = None
        if scores is not None:
            if isinstance(scores, torch.Tensor):
                scores_np = scores.cpu().numpy().flatten()
            else:
                scores_np = np.asarray(scores).flatten()

        return SegmentationResult(
            masks=masks_np,
            scores=scores_np,
            boxes=det.boxes,
            label_names=det.label_names,
        )

    # -- Multi-GPU helper methods -------------------------------------------

    def detect_keyframe(
        self,
        image: "Image.Image",
        text_prompts: List[str],
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], List[str]]:
        """Run detector + SAM2 image predictor on a single keyframe.

        Accumulates ``gdino_total_ms`` and ``sam_image_total_ms`` when perf
        metrics are enabled.

        Args:
            image: PIL Image in RGB mode.
            text_prompts: Text queries for the detector.

        Returns:
            ``(masks, boxes, labels)`` where *masks* is an ``(N, H, W)``
            tensor, *boxes* is an ``(N, 4)`` float32 tensor on ``self.device``,
            and *labels* is a list of strings. Returns
            ``(None, None, [])`` when no objects are detected.
        """
        self._ensure_models_loaded()

        started = self._perf_timer_start()
        frame_rgb = np.array(image)
        # Convert PIL RGB → numpy BGR for detector.predict()
        frame_bgr = frame_rgb[:, :, ::-1].copy()
        det = self._detector.predict(frame_bgr, text_prompts)
        self._perf_timer_stop("gdino_total_ms", started)

        if det.boxes is None or len(det.boxes) == 0:
            return None, None, []

        input_boxes = torch.tensor(det.boxes, device=self.device, dtype=torch.float32)
        det_labels = _det_label_names(det)

        # SAM2 image predictor
        started = self._perf_timer_start()
        self._image_predictor.set_image(frame_rgb)
        masks, _ = self._predict_masks_gpu(input_boxes)
        self._perf_timer_stop("sam_image_total_ms", started)

        return masks, input_boxes, det_labels

    def propagate_segment(
        self,
        inference_state: Any,
        start_idx: int,
        mask_dict: "MaskDictionary",
        step: int,
    ) -> "SegmentOutput":
        """Propagate masks for a single segment via SAM2 video predictor.

        Resets the predictor state, seeds it with every mask in ``mask_dict``
        at ``start_idx``, then collects the propagated masks and their boxes
        into flat device buffers. Accumulates ``sam_video_total_ms`` when perf
        metrics are enabled.

        Returns:
            A ``SegmentOutput``; call ``output.frame_to_object_dict()`` to
            materialize per-frame dicts.
        """
        started = self._perf_timer_start()

        self._video_predictor.reset_state(inference_state)

        for obj_id, obj_info in mask_dict.labels.items():
            self._video_predictor.add_new_mask(
                inference_state, start_idx, obj_id, obj_info.mask,
            )

        # Pre-compute class name lookup (avoid repeated dict access in loop)
        obj_id_to_class = {
            oid: mask_dict.get_target_class_name(oid) for oid in mask_dict.labels
        }
        n_obj = len(mask_dict.labels)

        # Device buffers are allocated on the first yield, when H and W are known
        masks_buf = bboxes_buf = valid_buf = None
        frame_indices: List[int] = []
        obj_ids_list: List[int] = []
        class_names_list: List[str] = []
        cursor = 0

        for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
            inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
        ):
            bool_masks = (out_mask_logits[:, 0] > 0.0)  # (N, H, W)
            n = bool_masks.shape[0]

            if masks_buf is None:
                H, W = bool_masks.shape[1], bool_masks.shape[2]
                # The predictor yields the start frame plus up to `step`
                # following frames, i.e. at most step + 1 frames per segment.
                max_entries = (step + 1) * max(n_obj, n)
                masks_buf = torch.empty(max_entries, H, W, dtype=torch.bool, device=self.device)
                bboxes_buf = torch.empty(max_entries, 4, dtype=torch.int64, device=self.device)
                valid_buf = torch.empty(max_entries, dtype=torch.bool, device=self.device)

            # Safety net: grow the buffers if the estimate was too small
            if cursor + n > masks_buf.shape[0]:
                grow = max(step * n_obj, cursor + n - masks_buf.shape[0])
                H, W = masks_buf.shape[1], masks_buf.shape[2]
                masks_buf = torch.cat([masks_buf, torch.empty(grow, H, W, dtype=torch.bool, device=self.device)])
                bboxes_buf = torch.cat([bboxes_buf, torch.empty(grow, 4, dtype=torch.int64, device=self.device)])
                valid_buf = torch.cat([valid_buf, torch.empty(grow, dtype=torch.bool, device=self.device)])

            frame_bboxes, frame_valid = _bbox_gpu(bool_masks)

            masks_buf[cursor:cursor + n] = bool_masks
            bboxes_buf[cursor:cursor + n] = frame_bboxes
            valid_buf[cursor:cursor + n] = frame_valid

            # Parallel Python-side metadata, one entry per yielded object
            oid_list = list(out_obj_ids) if not isinstance(out_obj_ids, list) else out_obj_ids
            for oid in oid_list:
                frame_indices.append(out_frame_idx)
                obj_ids_list.append(oid)
                class_names_list.append(obj_id_to_class.get(oid, ""))
            cursor += n

        # Build output (slice if under-filled, empty tensors if no frames)
        if masks_buf is not None:
            output = SegmentOutput(
                masks=masks_buf[:cursor], bboxes=bboxes_buf[:cursor],
                valid=valid_buf[:cursor], frame_indices=frame_indices,
                obj_ids=obj_ids_list, class_names=class_names_list, device=self.device,
            )
        else:
            output = SegmentOutput(
                masks=torch.empty(0, 0, 0, dtype=torch.bool, device=self.device),
                bboxes=torch.empty(0, 4, dtype=torch.int64, device=self.device),
                valid=torch.empty(0, dtype=torch.bool, device=self.device),
                frame_indices=[], obj_ids=[], class_names=[], device=self.device,
            )

        self._perf_timer_stop("sam_video_total_ms", started)
        return output

    # -- Video-level tracking interface -------------------------------------

    def process_video(
        self,
        frame_dir: str,
        frame_names: List[str],
        text_prompts: List[str],
        on_segment: Optional[Callable[[Dict[int, Dict[int, "ObjectInfo"]]], None]] = None,
        on_segment_output: Optional[Callable[["SegmentOutput"], None]] = None,
        _ttfs_t0: Optional[float] = None,
        _ttfs_job_id: Optional[str] = None,
        frame_store=None,
    ) -> Dict[int, Dict[int, ObjectInfo]]:
        """Run full Grounded-SAM-2 tracking pipeline on extracted JPEG frames.

        Every ``self.step`` frames a keyframe is detected + segmented, its
        masks are matched by IoU against the previous segment's last frame to
        keep IDs persistent, and the masks are propagated through the segment.

        Args:
            frame_dir: Directory containing JPEG frames (when ``frame_store``
                is given it is only used to initialise the predictor state).
            frame_names: Sorted list of frame filenames; its length defines
                the number of frames processed.
            text_prompts: Text queries for the detector.
            on_segment: Optional callback receiving
                ``{frame_idx: {obj_id: ObjectInfo}}``. It is invoked only for
                segments that end up without objects to propagate (no
                detections, or only empty masks).
            on_segment_output: Optional callback receiving the
                ``SegmentOutput`` of every propagated segment.
            _ttfs_t0: ``time.perf_counter()`` origin; when given, the elapsed
                time to completion of the first segment is logged.
            _ttfs_job_id: Identifier included in that log line.
            frame_store: Optional in-memory frame source providing
                ``sam2_adapter(image_size=...)``, ``get_pil_rgb(idx)``,
                ``__len__``, ``height`` and ``width``.

        Returns:
            Dict mapping frame_idx -> {obj_id: ObjectInfo}. For propagated
            segments only the segment's last frame is materialized here (the
            full per-frame data is delivered through ``on_segment_output``);
            segments without propagation contribute an entry for each frame.
        """
        import os

        self._ensure_models_loaded()

        device = self.device
        step = self.step
        total_frames = len(frame_names)
        logging.info(
            "Grounded-SAM-2 tracking: %d frames, step=%d, queries=%s",
            total_frames, step, text_prompts,
        )

        # Single global autocast context (matches reference implementation)
        device_type = device.split(":")[0]
        autocast_ctx = (
            torch.autocast(device_type=device_type, dtype=torch.bfloat16)
            if device_type == "cuda"
            else nullcontext()
        )

        sam2_masks = MaskDictionary()
        objects_count = 0
        all_results: Dict[int, Dict[int, ObjectInfo]] = {}

        with autocast_ctx:
            # Init SAM2 video predictor state
            started = self._perf_timer_start()
            if frame_store is not None:
                inference_state = self._video_predictor.init_state(
                    video_path=frame_dir,  # dummy dir with 1 JPEG
                    offload_video_to_cpu=True,
                    async_loading_frames=False,
                )
                # Clear cached_features (dummy frame 0's backbone features)
                inference_state["cached_features"] = {}
                # Patch in real frame data
                img_size = self._video_predictor.image_size
                inference_state["images"] = frame_store.sam2_adapter(image_size=img_size)
                inference_state["num_frames"] = len(frame_store)
                inference_state["video_height"] = frame_store.height
                inference_state["video_width"] = frame_store.width
            else:
                inference_state = self._video_predictor.init_state(
                    video_path=frame_dir,
                    offload_video_to_cpu=True,
                    async_loading_frames=True,
                )
            self._perf_timer_stop("init_state_ms", started)

            for start_idx in range(0, total_frames, step):
                logging.info("Processing keyframe %d / %d", start_idx, total_frames)
                segment_frames = range(start_idx, min(start_idx + step, total_frames))

                if frame_store is not None:
                    image = frame_store.get_pil_rgb(start_idx)
                else:
                    # Context manager closes the file once pixels are converted
                    with Image.open(os.path.join(frame_dir, frame_names[start_idx])) as raw:
                        image = raw.convert("RGB")

                # -- Detector + SAM2 image predictor on keyframe --
                masks, input_boxes, det_labels = self.detect_keyframe(image, text_prompts)

                if masks is None:
                    logging.info("No detections on keyframe %d, propagating previous masks", start_idx)
                    # Fill results for this segment from the last known masks
                    seg_results: Dict[int, Dict[int, ObjectInfo]] = {}
                    for fi in segment_frames:
                        if fi not in all_results:
                            all_results[fi] = {
                                k: ObjectInfo(
                                    instance_id=v.instance_id,
                                    mask=v.mask,
                                    class_name=v.class_name,
                                    x1=v.x1, y1=v.y1, x2=v.x2, y2=v.y2,
                                )
                                for k, v in sam2_masks.labels.items()
                            }
                        seg_results[fi] = all_results[fi]
                    if on_segment and seg_results:
                        on_segment(seg_results)
                    if start_idx == 0 and _ttfs_t0 is not None:
                        logging.info("[TTFS:%s] +%.1fs first_segment_complete (no detections, step=%d)",
                                     _ttfs_job_id, time.perf_counter() - _ttfs_t0, step)
                    continue

                mask_dict = MaskDictionary()
                mask_dict.add_new_frame_annotation(
                    mask_list=masks,
                    box_list=input_boxes.clone(),
                    label_list=det_labels,
                )

                # -- IoU matching to maintain persistent IDs --
                started = self._perf_timer_start()
                objects_count = mask_dict.update_masks(
                    tracking_dict=sam2_masks,
                    iou_threshold=self.iou_threshold,
                    objects_count=objects_count,
                )
                self._perf_timer_stop("id_reconciliation_ms", started)

                if len(mask_dict.labels) == 0:
                    seg_results_empty: Dict[int, Dict[int, ObjectInfo]] = {}
                    for fi in segment_frames:
                        all_results[fi] = {}
                        seg_results_empty[fi] = {}
                    if on_segment:
                        on_segment(seg_results_empty)
                    if start_idx == 0 and _ttfs_t0 is not None:
                        logging.info("[TTFS:%s] +%.1fs first_segment_complete (empty masks, step=%d)",
                                     _ttfs_job_id, time.perf_counter() - _ttfs_t0, step)
                    continue

                # -- SAM2 video predictor: propagate masks --
                # NOTE: propagate_segment() already accumulates into
                # "sam_video_total_ms", so no outer timer here.
                segment_output = self.propagate_segment(
                    inference_state, start_idx, mask_dict, step,
                )

                # Deferred path: only materialize the last frame, which seeds
                # the IoU matching of the next keyframe
                last_fi = segment_output.last_frame_idx()
                if last_fi is not None:
                    last_frame_objects = segment_output.frame_to_object_dict(
                        last_fi, to_cpu=True,
                    )
                    all_results[last_fi] = last_frame_objects

                    sam2_masks = MaskDictionary()
                    sam2_masks.labels = copy.deepcopy(last_frame_objects)
                    if last_frame_objects:
                        first_info = next(iter(last_frame_objects.values()))
                        if first_info.mask is not None:
                            has_hw = first_info.mask.ndim >= 2
                            sam2_masks.mask_height = first_info.mask.shape[-2] if has_hw else 0
                            sam2_masks.mask_width = first_info.mask.shape[-1] if has_hw else 0

                if on_segment_output is not None:
                    on_segment_output(segment_output)

                if start_idx == 0 and _ttfs_t0 is not None:
                    logging.info("[TTFS:%s] +%.1fs first_segment_complete (step=%d)",
                                 _ttfs_job_id, time.perf_counter() - _ttfs_t0, step)

        logging.info(
            "Grounded-SAM-2 tracking complete: %d frames, %d tracked objects",
            len(all_results), objects_count,
        )
        return all_results