| import os |
|
|
| import torch |
| from ..model.vl_combiner import SAM3VLBackbone |
|
|
| try: |
| from ..model.vl_combiner import SAM3VLBackboneTri |
| except ImportError: |
| SAM3VLBackboneTri = None |
| from typing import Dict, List, Optional |
|
|
| import numpy as np |
| from ..model.data_misc import BatchedDatapoint, FindStage |
| from ..model.geometry_encoders import Prompt |
| from ..model.model_misc import SAM3Output |
| from ..model.sam3_image import Sam3Image |
| from ..model.sam3_multiplex_detector_utils import nms_masks |
|
|
|
|
| class Sam3MultiplexImageBase(Sam3Image): |
| """A wrapper class to run Sam3Image on videos for per-frame detection (no tracking).""" |
|
|
| def __init__( |
| self, |
| *args, |
| tracking_score_thresh: float = 0.0, |
| offload_outputs_to_cpu_for_eval: bool = False, |
| **kwargs, |
| ): |
| super().__init__(*args, **kwargs) |
| self.tracking_score_thresh = tracking_score_thresh |
| self.offload_outputs_to_cpu_for_eval = offload_outputs_to_cpu_for_eval |
| self.trim_outputs_for_eval = True |
|
|
| def forward( |
| self, |
| input: BatchedDatapoint, |
| is_inference=False, |
| ): |
| assert not self.training, ( |
| "Sam3MultiplexImageBase should only be used in eval mode." |
| ) |
|
|
| device = self.device |
| backbone_out = {"img_batch_all_stages": input.img_batch} |
| text_outputs = self.backbone.forward_text(input.find_text_batch, device=device) |
| backbone_out.update(text_outputs) |
| num_frames = len(input.find_inputs) |
|
|
| previous_stages_out = SAM3Output( |
| iter_mode=SAM3Output.IterMode.LAST_STEP_PER_STAGE |
| ) |
| for frame_idx in range(num_frames): |
| find_input = input.find_inputs[frame_idx] |
| find_target = input.find_targets[frame_idx] |
| geometric_prompt = self._get_geo_prompt_from_find_input(find_input) |
| cur_out, _ = self.forward_video_grounding( |
| backbone_out=backbone_out, |
| find_input=find_input, |
| find_target=find_target, |
| geometric_prompt=geometric_prompt, |
| ) |
| |
| if self.offload_outputs_to_cpu_for_eval: |
| cur_out = {k: v.cpu() for k, v in cur_out.items()} |
|
|
| previous_stages_out.append([cur_out]) |
|
|
| get_queries = None |
| return previous_stages_out, get_queries |
|
|
| def forward_video_grounding( |
| self, |
| backbone_out, |
| find_input, |
| find_target, |
| geometric_prompt: Prompt, |
| **kwargs, |
| ): |
| |
| out = self.forward_grounding( |
| backbone_out=backbone_out, |
| find_input=find_input, |
| find_target=find_target, |
| geometric_prompt=geometric_prompt, |
| ) |
| |
| out = { |
| "pred_logits": out["pred_logits"], |
| "pred_boxes": out["pred_boxes"], |
| "pred_boxes_xyxy": out["pred_boxes_xyxy"], |
| "pred_masks": out["pred_masks"], |
| "pred_object_ids": self._get_dummy_object_ids(out["pred_logits"]), |
| } |
| return out, backbone_out |
|
|
| def _get_dummy_object_ids(self, pred_logits): |
| """Generate dummy object IDs for the detected objects, based on their detection query indices.""" |
| |
| B, Q, _ = pred_logits.shape |
| is_above_thresh = pred_logits.squeeze(2) > self.tracking_score_thresh |
| dummy_obj_ids = torch.arange(Q, device=self.device).expand(B, -1) |
| dummy_obj_ids = torch.where(is_above_thresh, dummy_obj_ids, -1) |
| return dummy_obj_ids |
|
|
| def _trim_outputs(self, *args, **kwargs): |
| pass |
|
|
| def _batch_find_inputs( |
| self, |
| find_inputs: List[FindStage], |
| chunk_start: int, |
| chunk_end: int, |
| ) -> FindStage: |
| """ |
| Batch multiple FindStage objects into a single batched FindStage. |
| |
| For each frame in the chunk, creates img_ids that point to the correct |
| frame index. When processing streaming video, the img_ids are the actual |
| frame indices (e.g., [0, 1, 2, ..., 15] for chunk 0-16), and the modulo |
| for circular buffer access is applied later in _get_img_feats. |
| |
| Args: |
| find_inputs: List of FindStage objects for all frames. |
| chunk_start: Start index of the chunk. |
| chunk_end: End index of the chunk (exclusive). |
| |
| Returns: |
| A single FindStage with batched tensors. |
| """ |
| chunk_find_inputs = [ |
| find_inputs[i % len(find_inputs)] for i in range(chunk_start, chunk_end) |
| ] |
|
|
| |
| |
| |
| device = chunk_find_inputs[0].img_ids.device |
| dtype = chunk_find_inputs[0].img_ids.dtype |
| img_ids_list = [ |
| torch.tensor([i], device=device, dtype=dtype) |
| for i in range(chunk_start, chunk_end) |
| ] |
| batched_img_ids = torch.cat(img_ids_list, dim=0) |
|
|
| |
| img_ids_np_list = [np.array([i]) for i in range(chunk_start, chunk_end)] |
| batched_img_ids_np = np.concatenate(img_ids_np_list, axis=0) |
|
|
| |
| text_ids_list = [fi.text_ids for fi in chunk_find_inputs] |
| batched_text_ids = torch.cat(text_ids_list, dim=0) |
|
|
| |
| input_boxes_list = [fi.input_boxes for fi in chunk_find_inputs] |
| batched_input_boxes = ( |
| torch.cat(input_boxes_list, dim=0) |
| if input_boxes_list[0] is not None |
| else None |
| ) |
|
|
| |
| input_boxes_mask_list = [fi.input_boxes_mask for fi in chunk_find_inputs] |
| batched_input_boxes_mask = ( |
| torch.cat(input_boxes_mask_list, dim=0) |
| if input_boxes_mask_list[0] is not None |
| else None |
| ) |
|
|
| |
| input_boxes_label_list = [fi.input_boxes_label for fi in chunk_find_inputs] |
| batched_input_boxes_label = ( |
| torch.cat(input_boxes_label_list, dim=0) |
| if input_boxes_label_list[0] is not None |
| else None |
| ) |
|
|
| |
| input_points_list = [fi.input_points for fi in chunk_find_inputs] |
| batched_input_points = ( |
| torch.cat(input_points_list, dim=0) |
| if input_points_list[0] is not None |
| else None |
| ) |
|
|
| |
| input_points_mask_list = [fi.input_points_mask for fi in chunk_find_inputs] |
| batched_input_points_mask = ( |
| torch.cat(input_points_mask_list, dim=0) |
| if input_points_mask_list[0] is not None |
| else None |
| ) |
|
|
| |
| input_boxes_before_embed_list = [ |
| fi.input_boxes_before_embed for fi in chunk_find_inputs |
| ] |
| batched_input_boxes_before_embed = ( |
| torch.cat(input_boxes_before_embed_list, dim=0) |
| if input_boxes_before_embed_list[0] is not None |
| else None |
| ) |
|
|
| input_points_before_embed_list = [ |
| fi.input_points_before_embed for fi in chunk_find_inputs |
| ] |
| batched_input_points_before_embed = ( |
| torch.cat(input_points_before_embed_list, dim=0) |
| if input_points_before_embed_list[0] is not None |
| else None |
| ) |
|
|
| |
| batched_find_input = FindStage( |
| img_ids=batched_img_ids, |
| img_ids_np=batched_img_ids_np, |
| text_ids=batched_text_ids, |
| input_boxes=batched_input_boxes, |
| input_boxes_mask=batched_input_boxes_mask, |
| input_boxes_label=batched_input_boxes_label, |
| input_points=batched_input_points, |
| input_points_mask=batched_input_points_mask, |
| ptrs=None, |
| ptrs_seg=None, |
| object_ids=None, |
| input_boxes_before_embed=batched_input_boxes_before_embed, |
| input_points_before_embed=batched_input_points_before_embed, |
| ) |
|
|
| return batched_find_input |
|
|
| def _batch_geometric_prompts( |
| self, |
| geometric_prompts: List[Prompt], |
| chunk_start: int, |
| chunk_end: int, |
| ) -> Prompt: |
| """ |
| Batch multiple Prompt objects into a single batched Prompt. |
| |
| Args: |
| geometric_prompts: List of Prompt objects for all frames. |
| chunk_start: Start index of the chunk. |
| chunk_end: End index of the chunk (exclusive). |
| |
| Returns: |
| A single Prompt with batched tensors. |
| """ |
| chunk_prompts = [geometric_prompts[i] for i in range(chunk_start, chunk_end)] |
| return self._batch_geometric_prompts_from_list(chunk_prompts) |
|
|
| def _batch_geometric_prompts_from_list( |
| self, |
| chunk_prompts: List[Prompt], |
| ) -> Prompt: |
| """ |
| Batch a list of Prompt objects into a single batched Prompt. |
| |
| Prompt uses seq-first, batch-second convention: |
| - box_embeddings: N_boxes x B x C_box - batch along dim 1 |
| - box_mask: B x N_boxes - batch along dim 0 |
| - box_labels: N_boxes x B - batch along dim 1 |
| - point_embeddings: N_points x B x C_point - batch along dim 1 |
| - point_mask: B x N_points - batch along dim 0 |
| - point_labels: N_points x B - batch along dim 1 |
| |
| Args: |
| chunk_prompts: List of Prompt objects to batch. |
| |
| Returns: |
| A single Prompt with batched tensors. |
| """ |
|
|
| |
| def batch_tensors(tensors, dim): |
| if tensors[0] is None: |
| return None |
| return torch.cat(tensors, dim=dim) |
|
|
| |
| box_embeddings_list = [p.box_embeddings for p in chunk_prompts] |
| batched_box_embeddings = batch_tensors(box_embeddings_list, dim=1) |
|
|
| |
| box_mask_list = [p.box_mask for p in chunk_prompts] |
| batched_box_mask = batch_tensors(box_mask_list, dim=0) |
|
|
| |
| box_labels_list = [p.box_labels for p in chunk_prompts] |
| batched_box_labels = batch_tensors(box_labels_list, dim=1) |
|
|
| |
| point_embeddings_list = [p.point_embeddings for p in chunk_prompts] |
| batched_point_embeddings = batch_tensors(point_embeddings_list, dim=1) |
|
|
| |
| point_mask_list = [p.point_mask for p in chunk_prompts] |
| batched_point_mask = batch_tensors(point_mask_list, dim=0) |
|
|
| |
| point_labels_list = [p.point_labels for p in chunk_prompts] |
| batched_point_labels = batch_tensors(point_labels_list, dim=1) |
|
|
| |
| batched_prompt = Prompt( |
| box_embeddings=batched_box_embeddings, |
| box_mask=batched_box_mask, |
| box_labels=batched_box_labels, |
| point_embeddings=batched_point_embeddings, |
| point_mask=batched_point_mask, |
| point_labels=batched_point_labels, |
| ) |
|
|
| return batched_prompt |
|
|
|
|
| class Sam3MultiplexDetector(Sam3MultiplexImageBase): |
| def __init__( |
| self, |
| *args, |
| async_all_gather=True, |
| gather_backbone_out=None, |
| is_multiplex=False, |
| **kwargs, |
| ): |
| super().__init__(*args, **kwargs) |
| self.rank = int(os.getenv("RANK", "0")) |
| self.world_size = int(os.getenv("WORLD_SIZE", "1")) |
| self.async_all_gather = async_all_gather |
|
|
| |
| if gather_backbone_out is None: |
| gather_backbone_out = isinstance(self.backbone, SAM3VLBackbone) or ( |
| SAM3VLBackboneTri is not None |
| and isinstance(self.backbone, SAM3VLBackboneTri) |
| ) |
| self.gather_backbone_out = gather_backbone_out |
| self.is_multiplex = is_multiplex |
|
|
| def forward_video_grounding_multigpu( |
| self, |
| backbone_out, |
| find_inputs, |
| geometric_prompt: Prompt, |
| frame_idx, |
| num_frames, |
| |
| multigpu_buffer, |
| track_in_reverse=False, |
| |
| return_sam2_backbone_feats=False, |
| |
| run_nms=False, |
| nms_prob_thresh=None, |
| nms_iou_thresh=None, |
| nms_use_iom=False, |
| |
| max_frame_num_to_track=None, |
| propagate_in_video_start_frame_idx=None, |
| |
| feature_cache=None, |
| **kwargs, |
| ): |
| """ |
| Compute the FA detection outputs in a distributed manner, where all GPUs process |
| a chunk of frames (equal to the number of GPUs) at once and store them in cache. |
| """ |
| |
| |
| if max_frame_num_to_track is not None: |
| if propagate_in_video_start_frame_idx is None: |
| propagate_in_video_start_frame_idx = 0 |
| if track_in_reverse: |
| |
| valid_frame_start = max( |
| 0, |
| propagate_in_video_start_frame_idx - max_frame_num_to_track + 1, |
| ) |
| valid_frame_end = num_frames |
| else: |
| |
| valid_frame_start = 0 |
| valid_frame_end = min( |
| num_frames, |
| propagate_in_video_start_frame_idx + max_frame_num_to_track, |
| ) |
| else: |
| |
| valid_frame_start = 0 |
| valid_frame_end = num_frames |
|
|
| |
| frame_idx_curr_b = frame_idx - frame_idx % self.world_size |
| frame_idx_curr_e = min(frame_idx_curr_b + self.world_size, num_frames) |
|
|
| |
| frame_idx_curr_b = max(frame_idx_curr_b, valid_frame_start) |
| frame_idx_curr_e = min(frame_idx_curr_e, valid_frame_end) |
| |
| |
| if frame_idx not in multigpu_buffer: |
| with torch.profiler.record_function("build_multigpu_buffer_next_chunk1"): |
| self._build_multigpu_buffer_next_chunk( |
| backbone_out=backbone_out, |
| find_inputs=find_inputs, |
| geometric_prompt=geometric_prompt, |
| frame_idx_begin=frame_idx_curr_b, |
| frame_idx_end=frame_idx_curr_e, |
| num_frames=num_frames, |
| multigpu_buffer=multigpu_buffer, |
| run_nms=run_nms, |
| nms_prob_thresh=nms_prob_thresh, |
| nms_iou_thresh=nms_iou_thresh, |
| nms_use_iom=nms_use_iom, |
| feature_cache=feature_cache, |
| ) |
|
|
| |
| out = {} |
| for k, (v, handle) in multigpu_buffer[frame_idx].items(): |
| if self.is_multiplex: |
| if ( |
| k.startswith("interactive_backbone_") |
| or k.startswith("propagation_backbone_") |
| ) and not return_sam2_backbone_feats: |
| continue |
| else: |
| if k.startswith("sam2_backbone_") and not return_sam2_backbone_feats: |
| continue |
| if handle is not None: |
| handle.wait() |
| out[k] = v |
|
|
| |
| if not track_in_reverse and frame_idx_curr_b - self.world_size >= 0: |
| frame_idx_prev_e = frame_idx_curr_b |
| frame_idx_prev_b = frame_idx_curr_b - self.world_size |
| elif track_in_reverse and frame_idx_curr_e < num_frames: |
| frame_idx_prev_b = frame_idx_curr_e |
| frame_idx_prev_e = min(frame_idx_prev_b + self.world_size, num_frames) |
| else: |
| frame_idx_prev_b = frame_idx_prev_e = None |
| if frame_idx_prev_b is not None: |
| for frame_idx_rm in range(frame_idx_prev_b, frame_idx_prev_e): |
| multigpu_buffer.pop(frame_idx_rm, None) |
|
|
| |
| |
| |
|
|
| if not track_in_reverse and frame_idx_curr_e < valid_frame_end: |
| frame_idx_next_b = frame_idx_curr_e |
| frame_idx_next_e = min(frame_idx_next_b + self.world_size, valid_frame_end) |
| elif ( |
| track_in_reverse and frame_idx_curr_b - self.world_size >= valid_frame_start |
| ): |
| frame_idx_next_e = frame_idx_curr_b |
| frame_idx_next_b = max( |
| frame_idx_curr_b - self.world_size, valid_frame_start |
| ) |
| else: |
| frame_idx_next_b = frame_idx_next_e = None |
| if frame_idx_next_b is not None and frame_idx_next_b not in multigpu_buffer: |
| with torch.profiler.record_function("build_multigpu_buffer_next_chunk2"): |
| self._build_multigpu_buffer_next_chunk( |
| backbone_out=backbone_out, |
| find_inputs=find_inputs, |
| geometric_prompt=geometric_prompt, |
| frame_idx_begin=frame_idx_next_b, |
| frame_idx_end=frame_idx_next_e, |
| num_frames=num_frames, |
| multigpu_buffer=multigpu_buffer, |
| run_nms=run_nms, |
| nms_prob_thresh=nms_prob_thresh, |
| nms_iou_thresh=nms_iou_thresh, |
| feature_cache=feature_cache, |
| ) |
|
|
| return out, backbone_out |
|
|
| def _build_multigpu_buffer_next_chunk( |
| self, |
| backbone_out, |
| find_inputs, |
| geometric_prompt: Prompt, |
| frame_idx_begin, |
| frame_idx_end, |
| num_frames, |
| multigpu_buffer, |
| run_nms=False, |
| nms_prob_thresh=None, |
| nms_iou_thresh=None, |
| nms_use_iom=False, |
| feature_cache=None, |
| ): |
| """Compute FA outputs on a chunk of frames and store their results in multigpu_buffer.""" |
| |
| frame_idx_local_gpu = min(frame_idx_begin + self.rank, frame_idx_end - 1) |
| |
| with torch.profiler.record_function("forward_grounding"): |
| out_local = self.forward_grounding( |
| backbone_out=backbone_out, |
| |
| |
| |
| find_input=find_inputs[frame_idx_local_gpu % len(find_inputs)], |
| find_target=None, |
| geometric_prompt=geometric_prompt, |
| feature_cache=feature_cache, |
| ) |
| if run_nms: |
| with torch.profiler.record_function("nms_masks"): |
| |
| assert nms_prob_thresh is not None and nms_iou_thresh is not None |
| pred_probs = out_local["pred_logits"].squeeze(-1).sigmoid() |
| pred_masks = out_local["pred_masks"] |
| |
| for prompt_idx in range(pred_probs.size(0)): |
| keep = nms_masks( |
| pred_probs=pred_probs[prompt_idx], |
| pred_masks=pred_masks[prompt_idx], |
| prob_threshold=nms_prob_thresh, |
| iou_threshold=nms_iou_thresh, |
| nms_use_iom=nms_use_iom, |
| do_compile=getattr(self, "compile_model", False), |
| running_in_prod=getattr(self, "running_in_prod", False), |
| ) |
| |
| out_local["pred_logits"][prompt_idx, :, 0] -= 1e4 * (~keep).float() |
|
|
| if self.gather_backbone_out: |
| |
| if self.is_multiplex: |
| |
| |
|
|
| |
| inte_feats = out_local["prev_encoder_out"]["backbone_out"][ |
| "interactive" |
| ] |
| assert inte_feats["vision_mask"] is None |
| assert ( |
| len(inte_feats["backbone_fpn"]) == 3 |
| ) |
| assert all(x.mask is None for x in inte_feats["backbone_fpn"]) |
| |
| |
| inte_backbone_fpn_bf16 = [ |
| x.to(torch.bfloat16) for x in inte_feats["backbone_fpn"] |
| ] |
| inte_fpn0, inte_fpn_handle0 = self._gather_tensor( |
| inte_backbone_fpn_bf16[0].tensors |
| ) |
| inte_fpn1, inte_fpn_handle1 = self._gather_tensor( |
| inte_backbone_fpn_bf16[1].tensors |
| ) |
| inte_fpn2, inte_fpn_handle2 = self._gather_tensor( |
| inte_backbone_fpn_bf16[2].tensors |
| ) |
| |
| inte_vision_pos_enc = inte_feats["vision_pos_enc"] |
|
|
| feats = out_local["prev_encoder_out"]["backbone_out"]["sam2_backbone_out"] |
| assert feats["vision_mask"] is None |
| assert len(feats["backbone_fpn"]) == 3 |
| assert all(x.mask is None for x in feats["backbone_fpn"]) |
| |
| |
| backbone_fpn_bf16 = [x.to(torch.bfloat16) for x in feats["backbone_fpn"]] |
| fpn0, fpn_handle0 = self._gather_tensor(backbone_fpn_bf16[0].tensors) |
| fpn1, fpn_handle1 = self._gather_tensor(backbone_fpn_bf16[1].tensors) |
| fpn2, fpn_handle2 = self._gather_tensor(backbone_fpn_bf16[2].tensors) |
| |
| vision_pos_enc = feats["vision_pos_enc"] |
|
|
| |
| out_local = { |
| "pred_logits": out_local["pred_logits"], |
| "pred_boxes": out_local["pred_boxes"], |
| "pred_boxes_xyxy": out_local["pred_boxes_xyxy"], |
| "pred_masks": out_local["pred_masks"], |
| "pred_object_ids": self._get_dummy_object_ids(out_local["pred_logits"]), |
| } |
|
|
| |
| |
| out_gathered = {k: self._gather_tensor(v) for k, v in out_local.items()} |
| for rank in range(self.world_size): |
| frame_idx_to_save = frame_idx_begin + rank |
| if frame_idx_to_save >= num_frames: |
| continue |
| frame_buffer = { |
| k: (v[rank], handle) for k, (v, handle) in out_gathered.items() |
| } |
| if self.gather_backbone_out: |
| |
| if self.is_multiplex: |
| frame_buffer["interactive_backbone_fpn_0"] = ( |
| inte_fpn0[rank], |
| inte_fpn_handle0, |
| ) |
| frame_buffer["interactive_backbone_fpn_1"] = ( |
| inte_fpn1[rank], |
| inte_fpn_handle1, |
| ) |
| frame_buffer["interactive_backbone_fpn_2"] = ( |
| inte_fpn2[rank], |
| inte_fpn_handle2, |
| ) |
| frame_buffer["interactive_backbone_pos_enc"] = ( |
| inte_vision_pos_enc, |
| None, |
| ) |
| frame_buffer["sam2_backbone_fpn_0"] = (fpn0[rank], fpn_handle0) |
| frame_buffer["sam2_backbone_fpn_1"] = (fpn1[rank], fpn_handle1) |
| frame_buffer["sam2_backbone_fpn_2"] = (fpn2[rank], fpn_handle2) |
| frame_buffer["sam2_backbone_pos_enc"] = (vision_pos_enc, None) |
|
|
| multigpu_buffer[frame_idx_to_save] = frame_buffer |
|
|
| def _gather_tensor(self, x): |
| if self.world_size == 1: |
| return [x], None |
|
|
| async_op = self.async_all_gather |
| |
| |
| x = x.contiguous() |
| output_list = [torch.empty_like(x) for _ in range(self.world_size)] |
| handle = torch.distributed.all_gather(output_list, x, async_op=async_op) |
| return output_list, handle |
|
|
| def forward_video_grounding_batched_multigpu( |
| self, |
| backbone_out, |
| find_inputs, |
| geometric_prompt: Prompt, |
| frame_idx, |
| num_frames, |
| |
| grounding_cache, |
| track_in_reverse=False, |
| |
| return_sam2_backbone_feats=False, |
| |
| run_nms=False, |
| nms_prob_thresh=None, |
| nms_iou_thresh=None, |
| nms_use_iom=False, |
| |
| max_frame_num_to_track=None, |
| propagate_in_video_start_frame_idx=None, |
| |
| feature_cache=None, |
| |
| batch_size=16, |
| ): |
| """ |
| Fully batched forward_grounding that processes chunks of frames together on each GPU. |
| |
| Unlike forward_video_grounding_multigpu which processes 1 frame per GPU per chunk, |
| this method processes `batch_size` frames at once using the batched forward_grounding |
| approach from Sam3MultiplexImageBase. |
| |
| For single-GPU (world_size=1), this is equivalent to forward_grounding_batched. |
| For multi-GPU, each GPU processes batch_size frames in parallel. |
| |
| Args: |
| backbone_out: Dictionary containing backbone outputs and image batch. |
| find_inputs: List of FindStage objects for all frames. |
| geometric_prompt: Prompt object (used as template, individual prompts are |
| constructed from find_inputs for batching). |
| frame_idx: Current frame index to process. |
| num_frames: Total number of frames in the video. |
| grounding_cache: Dictionary to cache grounding outputs. |
| track_in_reverse: If True, processing in reverse frame order. |
| return_sam2_backbone_feats: Whether to also return SAM2 backbone features. |
| run_nms: Whether to perform NMS on detection outputs. |
| nms_prob_thresh: Probability threshold for NMS. |
| nms_iou_thresh: IoU threshold for NMS. |
| nms_use_iom: Whether to use IoM for NMS. |
| max_frame_num_to_track: Maximum number of frames to track. |
| propagate_in_video_start_frame_idx: Start frame index for propagation. |
| feature_cache: Optional dictionary for backbone feature caching. |
| batch_size: Number of frames to batch together per GPU (default: 16). |
| |
| Returns: |
| Tuple of (out, backbone_out) where out contains detection results for frame_idx. |
| """ |
| |
| if max_frame_num_to_track is not None: |
| if propagate_in_video_start_frame_idx is None: |
| propagate_in_video_start_frame_idx = 0 |
| if track_in_reverse: |
| valid_frame_start = ( |
| propagate_in_video_start_frame_idx - max_frame_num_to_track + 1 |
| ) |
| valid_frame_end = propagate_in_video_start_frame_idx |
| else: |
| valid_frame_start = propagate_in_video_start_frame_idx |
| valid_frame_end = ( |
| propagate_in_video_start_frame_idx + max_frame_num_to_track |
| ) |
| else: |
| valid_frame_start = 0 |
| valid_frame_end = num_frames |
|
|
| |
| if "grounding_buffer" not in grounding_cache: |
| grounding_cache["grounding_buffer"] = {} |
|
|
| |
| chunk_start = (frame_idx // batch_size) * batch_size |
| chunk_end = min(chunk_start + batch_size, valid_frame_end) |
| chunk_key = (chunk_start, chunk_end) |
|
|
| |
| if chunk_key not in grounding_cache["grounding_buffer"]: |
| self._cleanup_previous_chunks_multigpu( |
| grounding_cache=grounding_cache, |
| current_chunk_key=chunk_key, |
| batch_size=batch_size, |
| num_frames=num_frames, |
| track_in_reverse=track_in_reverse, |
| ) |
| with torch.profiler.record_function( |
| "forward_grounding_batched.process_chunk" |
| ): |
| chunk_outputs = self._process_grounding_chunk_batched( |
| backbone_out=backbone_out, |
| find_inputs=find_inputs, |
| chunk_start=chunk_start, |
| chunk_end=chunk_end, |
| run_nms=run_nms, |
| nms_prob_thresh=nms_prob_thresh, |
| nms_iou_thresh=nms_iou_thresh, |
| nms_use_iom=nms_use_iom, |
| feature_cache=feature_cache, |
| return_sam2_backbone_feats=return_sam2_backbone_feats, |
| ) |
| grounding_cache["grounding_buffer"][chunk_key] = chunk_outputs |
|
|
| |
| self._cleanup_previous_chunks_multigpu( |
| grounding_cache=grounding_cache, |
| current_chunk_key=chunk_key, |
| batch_size=batch_size, |
| num_frames=num_frames, |
| track_in_reverse=track_in_reverse, |
| ) |
|
|
| |
| chunk_outputs = grounding_cache["grounding_buffer"][chunk_key] |
| local_idx = frame_idx - chunk_start |
|
|
| |
| out = self._slice_batched_output( |
| chunk_outputs, local_idx, return_sam2_backbone_feats |
| ) |
|
|
| return out, backbone_out |
|
|
| def _process_grounding_chunk_batched( |
| self, |
| backbone_out, |
| find_inputs, |
| chunk_start: int, |
| chunk_end: int, |
| run_nms: bool, |
| nms_prob_thresh, |
| nms_iou_thresh, |
| nms_use_iom: bool, |
| feature_cache, |
| return_sam2_backbone_feats: bool, |
| ): |
| """ |
| Process a chunk of frames through the full forward_grounding pipeline in batch. |
| """ |
| chunk_size = chunk_end - chunk_start |
|
|
| |
| chunk_geo_prompts = [ |
| self._get_geo_prompt_from_find_input(find_inputs[i % len(find_inputs)]) |
| for i in range(chunk_start, chunk_end) |
| ] |
|
|
| |
| batched_find_input = self._batch_find_inputs( |
| find_inputs, chunk_start, chunk_end |
| ) |
|
|
| |
| batched_geometric_prompt = self._batch_geometric_prompts_from_list( |
| chunk_geo_prompts |
| ) |
|
|
| |
| with torch.profiler.record_function("forward_grounding_batched.forward"): |
| out = self.forward_grounding( |
| backbone_out=backbone_out, |
| find_input=batched_find_input, |
| find_target=None, |
| geometric_prompt=batched_geometric_prompt, |
| feature_cache=feature_cache, |
| ) |
|
|
| |
| if run_nms: |
| with torch.profiler.record_function("forward_grounding_batched.nms"): |
| assert nms_prob_thresh is not None and nms_iou_thresh is not None |
| pred_probs = out["pred_logits"].squeeze(-1).sigmoid() |
| pred_masks = out["pred_masks"] |
| |
| |
| |
| keep = nms_masks( |
| pred_probs=pred_probs, |
| pred_masks=pred_masks, |
| prob_threshold=nms_prob_thresh, |
| iou_threshold=nms_iou_thresh, |
| nms_use_iom=nms_use_iom, |
| do_compile=getattr(self, "compile_model", False), |
| running_in_prod=getattr(self, "running_in_prod", False), |
| ) |
| |
| |
| out["pred_logits"][:, :, 0] -= 1e4 * (~keep).float() |
|
|
| |
| if return_sam2_backbone_feats and "prev_encoder_out" in out: |
| backbone_data = out["prev_encoder_out"]["backbone_out"] |
| if self.is_multiplex and "interactive" in backbone_data: |
| out["_interactive_backbone"] = backbone_data["interactive"] |
| if "sam2_backbone_out" in backbone_data: |
| out["_sam2_backbone"] = backbone_data["sam2_backbone_out"] |
|
|
| out["_chunk_size"] = chunk_size |
| return out |
|
|
| def _slice_batched_output( |
| self, |
| chunk_outputs, |
| local_idx: int, |
| return_sam2_backbone_feats: bool, |
| ): |
| """ |
| Slice a single frame's output from the batched chunk outputs. |
| """ |
| out = {} |
|
|
| |
| batch_dim_keys = { |
| "pred_logits", |
| "pred_boxes", |
| "pred_boxes_xyxy", |
| "pred_masks", |
| "pred_logits_o2m", |
| "pred_boxes_o2m", |
| "pred_boxes_xyxy_o2m", |
| "pred_masks_o2m", |
| "queries", |
| "presence_logit_dec", |
| } |
|
|
| |
| skip_keys = { |
| "_chunk_size", |
| "_interactive_backbone", |
| "_sam2_backbone", |
| "prev_encoder_out", |
| "encoder_hidden_states", |
| "aux_outputs", |
| } |
|
|
| for key, value in chunk_outputs.items(): |
| if key in skip_keys: |
| continue |
| if key in batch_dim_keys and isinstance(value, torch.Tensor): |
| out[key] = value[local_idx : local_idx + 1] |
| elif isinstance(value, torch.Tensor): |
| try: |
| out[key] = value[local_idx : local_idx + 1] |
| except (IndexError, RuntimeError): |
| out[key] = value |
|
|
| |
| if "pred_logits" in out: |
| out["pred_object_ids"] = self._get_dummy_object_ids(out["pred_logits"]) |
|
|
| |
| if return_sam2_backbone_feats: |
| if "_sam2_backbone" in chunk_outputs: |
| sam2_bb = chunk_outputs["_sam2_backbone"] |
| out["sam2_backbone_fpn_0"] = sam2_bb["backbone_fpn"][0].tensors[ |
| local_idx : local_idx + 1 |
| ] |
| out["sam2_backbone_fpn_1"] = sam2_bb["backbone_fpn"][1].tensors[ |
| local_idx : local_idx + 1 |
| ] |
| out["sam2_backbone_fpn_2"] = sam2_bb["backbone_fpn"][2].tensors[ |
| local_idx : local_idx + 1 |
| ] |
| out["sam2_backbone_pos_enc"] = [ |
| x[local_idx : local_idx + 1] for x in sam2_bb["vision_pos_enc"] |
| ] |
|
|
| if self.is_multiplex and "_interactive_backbone" in chunk_outputs: |
| inte_bb = chunk_outputs["_interactive_backbone"] |
| out["interactive_backbone_fpn_0"] = inte_bb["backbone_fpn"][0].tensors[ |
| local_idx : local_idx + 1 |
| ] |
| out["interactive_backbone_fpn_1"] = inte_bb["backbone_fpn"][1].tensors[ |
| local_idx : local_idx + 1 |
| ] |
| out["interactive_backbone_fpn_2"] = inte_bb["backbone_fpn"][2].tensors[ |
| local_idx : local_idx + 1 |
| ] |
| out["interactive_backbone_pos_enc"] = [ |
| x[local_idx : local_idx + 1] for x in inte_bb["vision_pos_enc"] |
| ] |
|
|
| return out |
|
|
| def _cleanup_previous_chunks_multigpu( |
| self, |
| grounding_cache, |
| current_chunk_key, |
| batch_size: int, |
| num_frames: int, |
| track_in_reverse: bool, |
| ): |
| """Remove previous chunks from cache to save GPU memory.""" |
| chunk_start, chunk_end = current_chunk_key |
|
|
| if not track_in_reverse: |
| prev_chunk_start = chunk_start - batch_size |
| if prev_chunk_start >= 0: |
| prev_chunk_end = chunk_start |
| prev_chunk_key = (prev_chunk_start, prev_chunk_end) |
|
|
| |
| chunk = grounding_cache["grounding_buffer"].pop(prev_chunk_key, None) |
| if chunk is not None: |
| del chunk |
| else: |
| next_chunk_start = chunk_end |
| if next_chunk_start < num_frames: |
| next_chunk_end = min(next_chunk_start + batch_size, num_frames) |
| next_chunk_key = (next_chunk_start, next_chunk_end) |
| grounding_cache["grounding_buffer"].pop(next_chunk_key, None) |
|
|