ColabWan / preprocessing /sam3 /model /sam3_multiplex_detector.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
40.4 kB
import os
import torch
from ..model.vl_combiner import SAM3VLBackbone
try:
from ..model.vl_combiner import SAM3VLBackboneTri
except ImportError:
SAM3VLBackboneTri = None
from typing import Dict, List, Optional
import numpy as np
from ..model.data_misc import BatchedDatapoint, FindStage
from ..model.geometry_encoders import Prompt
from ..model.model_misc import SAM3Output
from ..model.sam3_image import Sam3Image
from ..model.sam3_multiplex_detector_utils import nms_masks
class Sam3MultiplexImageBase(Sam3Image):
"""A wrapper class to run Sam3Image on videos for per-frame detection (no tracking)."""
def __init__(
    self,
    *args,
    tracking_score_thresh: float = 0.0,
    offload_outputs_to_cpu_for_eval: bool = False,
    **kwargs,
):
    """Configure the per-frame detector on top of the parent image model.

    Args:
        *args: Positional arguments passed through to the parent constructor.
        tracking_score_thresh: Logit threshold used by `_get_dummy_object_ids`;
            queries whose logit is not above it get the placeholder id -1.
        offload_outputs_to_cpu_for_eval: When True, `forward` moves every
            per-frame output tensor to CPU right after it is computed.
        **kwargs: Keyword arguments passed through to the parent constructor.
    """
    super().__init__(*args, **kwargs)
    # dummy option -- it doesn't do anything
    self.trim_outputs_for_eval = True
    self.offload_outputs_to_cpu_for_eval = offload_outputs_to_cpu_for_eval
    self.tracking_score_thresh = tracking_score_thresh
def forward(
    self,
    input: BatchedDatapoint,
    is_inference=False,  # (a dummy parameter not used anymore)
):
    """Run detection independently on every frame described by `input`.

    Text features are computed once and reused for all frames; each frame is
    then routed through `forward_video_grounding`.

    Args:
        input: Batched datapoint providing `img_batch`, `find_text_batch`,
            `find_inputs` and `find_targets` (one entry per frame).
        is_inference: Unused; kept for interface compatibility.

    Returns:
        A tuple `(outputs, None)` where `outputs` is a `SAM3Output` holding
        one single-step stage per frame.
    """
    assert not self.training, (
        "Sam3MultiplexImageBase should only be used in eval mode."
    )
    text_features = self.backbone.forward_text(
        input.find_text_batch, device=self.device
    )
    backbone_out = {"img_batch_all_stages": input.img_batch}
    backbone_out.update(text_features)

    all_frames_out = SAM3Output(iter_mode=SAM3Output.IterMode.LAST_STEP_PER_STAGE)
    for frame_idx, find_input in enumerate(input.find_inputs):
        find_target = input.find_targets[frame_idx]
        frame_out, _ = self.forward_video_grounding(
            backbone_out=backbone_out,
            find_input=find_input,
            find_target=find_target,
            geometric_prompt=self._get_geo_prompt_from_find_input(find_input),
        )
        if self.offload_outputs_to_cpu_for_eval:
            # free GPU memory by keeping evaluation outputs on the host
            frame_out = {name: tensor.cpu() for name, tensor in frame_out.items()}
        all_frames_out.append([frame_out])

    # no query getter is produced in this per-frame mode
    return all_frames_out, None
def forward_video_grounding(
    self,
    backbone_out,
    find_input,
    find_target,
    geometric_prompt: Prompt,
    **kwargs,
):
    """Detect on one frame via `forward_grounding` and keep only the core keys.

    Args:
        backbone_out: Backbone/text features, passed through unchanged.
        find_input: Find-stage input for the frame.
        find_target: Find-stage target for the frame.
        geometric_prompt: Geometric prompt for the frame.
        **kwargs: Accepted and ignored.

    Returns:
        `(out, backbone_out)` where `out` holds `pred_logits`, `pred_boxes`,
        `pred_boxes_xyxy`, `pred_masks` and placeholder `pred_object_ids`.
    """
    full_out = self.forward_grounding(
        backbone_out=backbone_out,
        find_input=find_input,
        find_target=find_target,
        geometric_prompt=geometric_prompt,
    )
    kept_keys = ("pred_logits", "pred_boxes", "pred_boxes_xyxy", "pred_masks")
    trimmed = {key: full_out[key] for key in kept_keys}
    trimmed["pred_object_ids"] = self._get_dummy_object_ids(full_out["pred_logits"])
    return trimmed, backbone_out
def _get_dummy_object_ids(self, pred_logits):
"""Generate dummy object IDs for the detected objects, based on their detection query indices."""
# Assuming pred_logits has shape [batch_size, num_queries, num_classes]
B, Q, _ = pred_logits.shape
is_above_thresh = pred_logits.squeeze(2) > self.tracking_score_thresh
dummy_obj_ids = torch.arange(Q, device=self.device).expand(B, -1)
dummy_obj_ids = torch.where(is_above_thresh, dummy_obj_ids, -1)
return dummy_obj_ids
def _trim_outputs(self, *args, **kwargs):
pass # not needed for image-on-video
def _batch_find_inputs(
    self,
    find_inputs: List[FindStage],
    chunk_start: int,
    chunk_end: int,
) -> FindStage:
    """
    Merge the FindStage objects of one chunk into a single batched FindStage.

    The batched `img_ids` are the actual frame indices of the chunk
    (`chunk_start` .. `chunk_end - 1`); any modulo needed for circular-buffer
    access is applied later in `_get_img_feats`. The per-frame source stage is
    looked up with `index % len(find_inputs)`.

    Args:
        find_inputs: List of FindStage objects for all frames.
        chunk_start: Start index of the chunk.
        chunk_end: End index of the chunk (exclusive).
    Returns:
        A single FindStage with tensors concatenated along dim 0. Optional
        fields that are None on the first frame of the chunk stay None;
        `ptrs`, `ptrs_seg` and `object_ids` are always None.
    """
    frame_indices = range(chunk_start, chunk_end)
    stages = [find_inputs[idx % len(find_inputs)] for idx in frame_indices]

    # img_ids follow the chunk's frame indices, not the source stages' own ids
    template_ids = stages[0].img_ids
    batched_img_ids = torch.tensor(
        list(frame_indices), device=template_ids.device, dtype=template_ids.dtype
    )
    batched_img_ids_np = np.array(list(frame_indices))

    def cat_optional(field_name):
        """Concatenate a field over the chunk, or return None if it is unset."""
        values = [getattr(stage, field_name) for stage in stages]
        if values[0] is None:
            return None
        return torch.cat(values, dim=0)

    # text_ids are mandatory, so they are concatenated unconditionally
    batched_text_ids = torch.cat([stage.text_ids for stage in stages], dim=0)

    return FindStage(
        img_ids=batched_img_ids,
        img_ids_np=batched_img_ids_np,
        text_ids=batched_text_ids,
        input_boxes=cat_optional("input_boxes"),
        input_boxes_mask=cat_optional("input_boxes_mask"),
        input_boxes_label=cat_optional("input_boxes_label"),
        input_points=cat_optional("input_points"),
        input_points_mask=cat_optional("input_points_mask"),
        ptrs=None,  # Not batching pointers for now
        ptrs_seg=None,
        object_ids=None,
        input_boxes_before_embed=cat_optional("input_boxes_before_embed"),
        input_points_before_embed=cat_optional("input_points_before_embed"),
    )
def _batch_geometric_prompts(
    self,
    geometric_prompts: List[Prompt],
    chunk_start: int,
    chunk_end: int,
) -> Prompt:
    """
    Batch the prompts of frames `chunk_start` .. `chunk_end - 1` into one Prompt.

    Args:
        geometric_prompts: List of Prompt objects for all frames.
        chunk_start: Start index of the chunk.
        chunk_end: End index of the chunk (exclusive).
    Returns:
        A single Prompt with batched tensors, as produced by
        `_batch_geometric_prompts_from_list`.
    """
    selected = []
    for frame_idx in range(chunk_start, chunk_end):
        # plain indexing (not slicing) so an out-of-range frame raises
        selected.append(geometric_prompts[frame_idx])
    return self._batch_geometric_prompts_from_list(selected)
def _batch_geometric_prompts_from_list(
    self,
    chunk_prompts: List[Prompt],
) -> Prompt:
    """
    Concatenate several Prompt objects into one batched Prompt.

    Prompt follows a seq-first, batch-second layout, so the batch axis differs
    per field:
    - box_embeddings: N_boxes x B x C_box - batch along dim 1
    - box_mask: B x N_boxes - batch along dim 0
    - box_labels: N_boxes x B - batch along dim 1
    - point_embeddings: N_points x B x C_point - batch along dim 1
    - point_mask: B x N_points - batch along dim 0
    - point_labels: N_points x B - batch along dim 1

    Args:
        chunk_prompts: List of Prompt objects to batch.
    Returns:
        A single Prompt with batched tensors; a field that is None on the
        first prompt is None in the result.
    """

    def cat_field(field_name, batch_dim):
        """Concatenate one Prompt field along its batch axis (None if unset)."""
        values = [getattr(prompt, field_name) for prompt in chunk_prompts]
        return None if values[0] is None else torch.cat(values, dim=batch_dim)

    return Prompt(
        box_embeddings=cat_field("box_embeddings", 1),
        box_mask=cat_field("box_mask", 0),
        box_labels=cat_field("box_labels", 1),
        point_embeddings=cat_field("point_embeddings", 1),
        point_mask=cat_field("point_mask", 0),
        point_labels=cat_field("point_labels", 1),
    )
class Sam3MultiplexDetector(Sam3MultiplexImageBase):
def __init__(
    self,
    *args,
    async_all_gather=True,
    gather_backbone_out=None,
    is_multiplex=False,
    **kwargs,
):
    """Configure the multi-GPU detector.

    Args:
        *args: Positional arguments passed through to the parent constructor.
        async_all_gather: Whether `_gather_tensor` issues asynchronous all-gathers.
        gather_backbone_out: Whether backbone features are gathered across GPUs
            too. None means "only if the backbone is a `SAM3VLBackbone` (or a
            `SAM3VLBackboneTri`, when that class could be imported)".
        is_multiplex: Whether the additional interactive backbone features are
            gathered and returned.
        **kwargs: Keyword arguments passed through to the parent constructor.
    """
    super().__init__(*args, **kwargs)
    # rank / world size are taken from the environment, defaulting to single-process
    self.rank = int(os.environ.get("RANK", "0"))
    self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
    self.async_all_gather = async_all_gather
    if gather_backbone_out is None:
        # SAM3VLBackboneTri is optional, so only test against it when it exists
        if SAM3VLBackboneTri is None:
            vl_backbone_types = (SAM3VLBackbone,)
        else:
            vl_backbone_types = (SAM3VLBackbone, SAM3VLBackboneTri)
        gather_backbone_out = isinstance(self.backbone, vl_backbone_types)
    self.gather_backbone_out = gather_backbone_out
    self.is_multiplex = is_multiplex
def forward_video_grounding_multigpu(
    self,
    backbone_out,
    find_inputs,
    geometric_prompt: Prompt,
    frame_idx,
    num_frames,
    # `multigpu_buffer` is a dict to cache FA outputs in a chunk between different calls
    multigpu_buffer,
    track_in_reverse=False,
    # whether to also return the SAM2 backbone features (in addition to FA results)
    return_sam2_backbone_feats=False,
    # whether to perform NMS and suppress the scores of those detections removed by NMS
    run_nms=False,
    nms_prob_thresh=None,
    nms_iou_thresh=None,
    nms_use_iom=False,
    # tracking bounds to respect max_frame_num_to_track
    max_frame_num_to_track=None,
    propagate_in_video_start_frame_idx=None,
    # feature_cache for buffered backbone computation
    feature_cache=None,
    **kwargs,
):
    """
    Compute the FA detection outputs in a distributed manner, where all GPUs process
    a chunk of frames (equal to the number of GPUs) at once and store them in cache.

    The call (1) makes sure the chunk containing `frame_idx` is in
    `multigpu_buffer`, (2) reads this frame's entries out of the buffer, waiting
    on any pending all-gather handle, (3) evicts the previously visited chunk and
    (4) pre-computes the next chunk so computation overlaps with the transfer.

    Args:
        backbone_out: Backbone/text features, returned unchanged.
        find_inputs: Per-frame find-stage inputs.
        geometric_prompt: Geometric prompt shared by all frames.
        frame_idx: Index of the frame whose outputs are requested.
        num_frames: Total number of frames.
        multigpu_buffer: Dict mapping frame index -> {key: (value, handle)};
            mutated in place.
        track_in_reverse: Whether frames are visited in decreasing order.
        return_sam2_backbone_feats: Whether backbone feature entries are returned.
        run_nms, nms_prob_thresh, nms_iou_thresh, nms_use_iom: NMS settings,
            forwarded to every chunk computation.
        max_frame_num_to_track: Optional size of the tracking window; chunks are
            not pre-fetched outside of it.
        propagate_in_video_start_frame_idx: Frame the window is anchored at
            (treated as 0 when None and a window size is given).
        feature_cache: Optional cache forwarded to `forward_grounding`.
        **kwargs: Accepted and ignored.

    Returns:
        `(out, backbone_out)` where `out` holds this frame's buffered entries.
    """
    world_size = self.world_size

    # Calculate valid frame range based on max_frame_num_to_track
    # We prevent pre-fetching beyond the tracking window relative to current frame
    valid_frame_start = 0
    valid_frame_end = num_frames
    if max_frame_num_to_track is not None:
        if propagate_in_video_start_frame_idx is None:
            propagate_in_video_start_frame_idx = 0
        if track_in_reverse:
            # When going backwards, limit how far back we can go from current frame
            valid_frame_start = max(
                0,
                propagate_in_video_start_frame_idx - max_frame_num_to_track + 1,
            )
        else:
            # When going forwards, limit how far ahead we can go from current frame
            valid_frame_end = min(
                num_frames,
                propagate_in_video_start_frame_idx + max_frame_num_to_track,
            )

    # Every chunk is computed with the exact same settings. Sharing one kwargs
    # dict keeps the "current chunk" and "next chunk" calls from drifting apart:
    # the prefetch call previously dropped `nms_use_iom`, so all chunks after
    # the first one ran NMS with the default `nms_use_iom=False`.
    build_kwargs = dict(
        backbone_out=backbone_out,
        find_inputs=find_inputs,
        geometric_prompt=geometric_prompt,
        num_frames=num_frames,
        multigpu_buffer=multigpu_buffer,
        run_nms=run_nms,
        nms_prob_thresh=nms_prob_thresh,
        nms_iou_thresh=nms_iou_thresh,
        nms_use_iom=nms_use_iom,
        feature_cache=feature_cache,
    )

    # Step 1: fetch the FA outputs in the current chunk from buffer
    frame_idx_curr_b = frame_idx - frame_idx % world_size
    frame_idx_curr_e = min(frame_idx_curr_b + world_size, num_frames)
    # Clamp the current chunk to the valid tracking range
    frame_idx_curr_b = max(frame_idx_curr_b, valid_frame_start)
    frame_idx_curr_e = min(frame_idx_curr_e, valid_frame_end)
    # in case the current frame's FA results are not in the buffer yet, build the current chunk
    # (this should only happen on the first chunk, since we are also building the next chunk below)
    if frame_idx not in multigpu_buffer:
        with torch.profiler.record_function("build_multigpu_buffer_next_chunk1"):
            self._build_multigpu_buffer_next_chunk(
                frame_idx_begin=frame_idx_curr_b,
                frame_idx_end=frame_idx_curr_e,
                **build_kwargs,
            )

    # read out the current frame's results from `multigpu_buffer`
    out = {}
    for k, (v, handle) in multigpu_buffer[frame_idx].items():
        if self.is_multiplex:
            if (
                k.startswith("interactive_backbone_")
                or k.startswith("propagation_backbone_")
            ) and not return_sam2_backbone_feats:
                continue
        else:
            if k.startswith("sam2_backbone_") and not return_sam2_backbone_feats:
                continue
        if handle is not None:
            handle.wait()  # wait for async all-gather to finish
        out[k] = v

    # Step 2: remove FA outputs of the previous chunk from cache to save GPU memory
    if not track_in_reverse and frame_idx_curr_b - world_size >= 0:
        frame_idx_prev_e = frame_idx_curr_b
        frame_idx_prev_b = frame_idx_curr_b - world_size
    elif track_in_reverse and frame_idx_curr_e < num_frames:
        frame_idx_prev_b = frame_idx_curr_e
        frame_idx_prev_e = min(frame_idx_prev_b + world_size, num_frames)
    else:
        frame_idx_prev_b = frame_idx_prev_e = None
    if frame_idx_prev_b is not None:
        for frame_idx_rm in range(frame_idx_prev_b, frame_idx_prev_e):
            multigpu_buffer.pop(frame_idx_rm, None)

    # Step 3: compute and cache FA outputs of the next chunk ahead of time
    # (so that we can overlap computation with all-gather transfer)
    # Respect tracking bounds when calculating next chunk
    if not track_in_reverse and frame_idx_curr_e < valid_frame_end:
        frame_idx_next_b = frame_idx_curr_e
        frame_idx_next_e = min(frame_idx_next_b + world_size, valid_frame_end)
    elif track_in_reverse and frame_idx_curr_b - world_size >= valid_frame_start:
        frame_idx_next_e = frame_idx_curr_b
        frame_idx_next_b = max(frame_idx_curr_b - world_size, valid_frame_start)
    else:
        frame_idx_next_b = frame_idx_next_e = None
    if frame_idx_next_b is not None and frame_idx_next_b not in multigpu_buffer:
        with torch.profiler.record_function("build_multigpu_buffer_next_chunk2"):
            self._build_multigpu_buffer_next_chunk(
                frame_idx_begin=frame_idx_next_b,
                frame_idx_end=frame_idx_next_e,
                **build_kwargs,
            )

    return out, backbone_out
def _build_multigpu_buffer_next_chunk(
    self,
    backbone_out,
    find_inputs,
    geometric_prompt: Prompt,
    frame_idx_begin,
    frame_idx_end,
    num_frames,
    multigpu_buffer,
    run_nms=False,
    nms_prob_thresh=None,
    nms_iou_thresh=None,
    nms_use_iom=False,
    feature_cache=None,
):
    """Compute FA outputs on a chunk of frames and store their results in multigpu_buffer.

    Each rank runs `forward_grounding` on one frame of the chunk, optionally
    applies NMS, all-gathers the results, and writes one entry per rank into
    `multigpu_buffer` as {key: (value, all_gather_handle)}.

    Args:
        backbone_out: Backbone/text features passed to `forward_grounding`.
        find_inputs: Per-frame find-stage inputs.
        geometric_prompt: Geometric prompt passed to `forward_grounding`.
        frame_idx_begin: First frame of the chunk.
        frame_idx_end: End of the chunk (exclusive).
        num_frames: Total number of frames; entries at or past it are not stored.
        multigpu_buffer: Dict filled in place, keyed by frame index.
        run_nms: Whether to run mask NMS on the local detections.
        nms_prob_thresh: NMS probability threshold (required when `run_nms`).
        nms_iou_thresh: NMS IoU threshold (required when `run_nms`).
        nms_use_iom: Forwarded to `nms_masks`.
        feature_cache: Optional cache forwarded to `forward_grounding`.
    """
    # each GPU computes FA on one frame in the chunk (in a round-robin manner);
    # ranks past the end of the chunk recompute its last frame
    local_frame_idx = min(frame_idx_begin + self.rank, frame_idx_end - 1)

    # `forward_grounding` (from base class `Sam3MultiplexImageBase`) runs FA on a single frame
    with torch.profiler.record_function("forward_grounding"):
        out_local = self.forward_grounding(
            backbone_out=backbone_out,
            # HACK: Since find_inputs is on GPU having to realloc is expensive so changing the values in place for the prod usecase
            # i.e. when using the streaming frame loader resource instead of local file. For non-prod is always
            # frame_idx_local_gpu < len(find_inputs) so should be a no-op
            find_input=find_inputs[local_frame_idx % len(find_inputs)],
            find_target=None,
            geometric_prompt=geometric_prompt,
            feature_cache=feature_cache,
        )

    if run_nms:
        with torch.profiler.record_function("nms_masks"):
            # run NMS as a post-processing step on top of the detection outputs
            assert nms_prob_thresh is not None and nms_iou_thresh is not None
            pred_probs = out_local["pred_logits"].squeeze(-1).sigmoid()
            pred_masks = out_local["pred_masks"]
            # loop over text prompts (not an overhead for demo where there's only 1 prompt)
            for prompt_idx in range(pred_probs.size(0)):
                keep = nms_masks(
                    pred_probs=pred_probs[prompt_idx],
                    pred_masks=pred_masks[prompt_idx],
                    prob_threshold=nms_prob_thresh,
                    iou_threshold=nms_iou_thresh,
                    nms_use_iom=nms_use_iom,
                    do_compile=getattr(self, "compile_model", False),
                    running_in_prod=getattr(self, "running_in_prod", False),
                )
                # push the logits of suppressed detections far below any threshold
                suppressed = (~keep).float()
                out_local["pred_logits"][prompt_idx, :, 0] -= 1e4 * suppressed

    def gather_fpn_levels(feats):
        """Validate one backbone output and all-gather its three FPN levels in order."""
        assert feats["vision_mask"] is None
        assert len(feats["backbone_fpn"]) == 3  # SAM2 backbone always have 3 levels
        assert all(x.mask is None for x in feats["backbone_fpn"])
        # cast the SAM2 backbone features to bfloat16 for all-gather (this is usually
        # a no-op, SAM2 backbone features are likely already in bfloat16 due to AMP)
        levels_bf16 = [x.to(torch.bfloat16) for x in feats["backbone_fpn"]]
        gathered_levels = [self._gather_tensor(level.tensors) for level in levels_bf16]
        # vision_pos_enc is the same on all frames, so no need to all-gather them
        return gathered_levels, feats["vision_pos_enc"]

    inte_fpn_keys = (
        "interactive_backbone_fpn_0",
        "interactive_backbone_fpn_1",
        "interactive_backbone_fpn_2",
    )
    sam2_fpn_keys = (
        "sam2_backbone_fpn_0",
        "sam2_backbone_fpn_1",
        "sam2_backbone_fpn_2",
    )
    if self.gather_backbone_out:
        local_backbone_out = out_local["prev_encoder_out"]["backbone_out"]
        if self.is_multiplex:
            # Note that we should not need to compute the interaction features every frame
            # TODO: rooms for optimization
            inte_gathered, inte_vision_pos_enc = gather_fpn_levels(
                local_backbone_out["interactive"]
            )
        # gather the SAM 2 backbone features across GPUs (after the interactive
        # ones, so all ranks issue their collectives in the same order)
        sam2_gathered, vision_pos_enc = gather_fpn_levels(
            local_backbone_out["sam2_backbone_out"]
        )

    # trim the FA output to only include the necessary keys
    trimmed_local = {
        key: out_local[key]
        for key in ("pred_logits", "pred_boxes", "pred_boxes_xyxy", "pred_masks")
    }
    trimmed_local["pred_object_ids"] = self._get_dummy_object_ids(
        out_local["pred_logits"]
    )

    # gather the results: after this step, each GPU will receive FA outputs on
    # all frames in the chunk and store them in `multigpu_buffer`
    out_gathered = {
        key: self._gather_tensor(value) for key, value in trimmed_local.items()
    }
    for rank in range(self.world_size):
        frame_idx_to_save = frame_idx_begin + rank
        # NOTE(review): the bound is `num_frames`, not `frame_idx_end`, so when the
        # chunk was clamped below `num_frames` the repeated last-frame output of
        # surplus ranks is still stored under later frame indices -- confirm intent.
        if frame_idx_to_save >= num_frames:
            continue
        frame_buffer = {
            key: (per_rank[rank], handle)
            for key, (per_rank, handle) in out_gathered.items()
        }
        if self.gather_backbone_out:
            # also add gathered SAM 2 backbone features to frame_buffer
            if self.is_multiplex:
                for key, (per_rank, handle) in zip(inte_fpn_keys, inte_gathered):
                    frame_buffer[key] = (per_rank[rank], handle)
                frame_buffer["interactive_backbone_pos_enc"] = (
                    inte_vision_pos_enc,
                    None,
                )
            for key, (per_rank, handle) in zip(sam2_fpn_keys, sam2_gathered):
                frame_buffer[key] = (per_rank[rank], handle)
            frame_buffer["sam2_backbone_pos_enc"] = (vision_pos_enc, None)
        multigpu_buffer[frame_idx_to_save] = frame_buffer
def _gather_tensor(self, x):
if self.world_size == 1:
return [x], None
async_op = self.async_all_gather
# here `.contiguous()` is required -- otherwise NCCL all_gather
# sometimes gives wrong results (based on Ronghang's observations)
x = x.contiguous() # ensure contiguous memory for NCCL
output_list = [torch.empty_like(x) for _ in range(self.world_size)]
handle = torch.distributed.all_gather(output_list, x, async_op=async_op)
return output_list, handle
def forward_video_grounding_batched_multigpu(
    self,
    backbone_out,
    find_inputs,
    geometric_prompt: Prompt,
    frame_idx,
    num_frames,
    # `grounding_cache` is a dict to cache FA outputs in a chunk between different calls
    grounding_cache,
    track_in_reverse=False,
    # whether to also return the SAM2 backbone features (in addition to FA results)
    return_sam2_backbone_feats=False,
    # whether to perform NMS and suppress the scores of those detections removed by NMS
    run_nms=False,
    nms_prob_thresh=None,
    nms_iou_thresh=None,
    nms_use_iom=False,
    # tracking bounds to respect max_frame_num_to_track
    max_frame_num_to_track=None,
    propagate_in_video_start_frame_idx=None,
    # feature_cache for buffered backbone computation
    feature_cache=None,
    # batch_size for batched forward_grounding (default: 16)
    batch_size=16,
):
    """
    Fully batched forward_grounding that processes chunks of frames together on each GPU.

    Unlike forward_video_grounding_multigpu which processes 1 frame per GPU per chunk,
    this method runs `forward_grounding` once on a chunk of up to `batch_size`
    frames, caches the batched result in `grounding_cache["grounding_buffer"]`
    under the key `(chunk_start, chunk_end)`, and slices out the requested frame.

    Args:
        backbone_out: Dictionary containing backbone outputs and image batch.
        find_inputs: List of FindStage objects for all frames.
        geometric_prompt: Not read by this method; the per-frame prompts are
            rebuilt from `find_inputs` when a chunk is processed.
        frame_idx: Current frame index to process.
        num_frames: Total number of frames in the video.
        grounding_cache: Dictionary to cache grounding outputs (mutated in place).
        track_in_reverse: If True, processing in reverse frame order.
        return_sam2_backbone_feats: Whether to also return SAM2 backbone features.
        run_nms: Whether to perform NMS on detection outputs.
        nms_prob_thresh: Probability threshold for NMS.
        nms_iou_thresh: IoU threshold for NMS.
        nms_use_iom: Whether to use IoM for NMS.
        max_frame_num_to_track: Maximum number of frames to track.
        propagate_in_video_start_frame_idx: Start frame index for propagation
            (treated as 0 when None and `max_frame_num_to_track` is given).
        feature_cache: Optional dictionary for backbone feature caching.
        batch_size: Number of frames to batch together per GPU (default: 16).
    Returns:
        Tuple of (out, backbone_out) where out contains detection results for frame_idx.
    """
    # Exclusive upper bound of the frames a chunk may contain. Only the upper
    # bound matters here: chunk starts are always aligned to `batch_size`.
    if max_frame_num_to_track is not None:
        if propagate_in_video_start_frame_idx is None:
            propagate_in_video_start_frame_idx = 0
        if track_in_reverse:
            # The window is [start - max + 1, start], so as an exclusive bound
            # the end is start + 1. Using `start` itself dropped the start
            # frame from its own chunk (or produced a zero-size chunk).
            valid_frame_end = propagate_in_video_start_frame_idx + 1
        else:
            valid_frame_end = (
                propagate_in_video_start_frame_idx + max_frame_num_to_track
            )
        # never extend a chunk past the end of the video, matching the
        # unbounded branch below and forward_video_grounding_multigpu
        valid_frame_end = min(num_frames, valid_frame_end)
    else:
        valid_frame_end = num_frames

    # Initialize grounding_buffer if not present
    grounding_buffer = grounding_cache.setdefault("grounding_buffer", {})

    # Calculate chunk boundaries - use batch_size instead of world_size
    chunk_start = (frame_idx // batch_size) * batch_size
    chunk_end = min(chunk_start + batch_size, valid_frame_end)
    chunk_key = (chunk_start, chunk_end)

    # Evict the previously visited chunk. Doing this before (possibly)
    # processing a new chunk frees its memory ahead of the new allocation; the
    # call is idempotent, so running it on every frame is harmless.
    self._cleanup_previous_chunks_multigpu(
        grounding_cache=grounding_cache,
        current_chunk_key=chunk_key,
        batch_size=batch_size,
        num_frames=num_frames,
        track_in_reverse=track_in_reverse,
    )

    # Process chunk if not already cached
    if chunk_key not in grounding_buffer:
        with torch.profiler.record_function(
            "forward_grounding_batched.process_chunk"
        ):
            grounding_buffer[chunk_key] = self._process_grounding_chunk_batched(
                backbone_out=backbone_out,
                find_inputs=find_inputs,
                chunk_start=chunk_start,
                chunk_end=chunk_end,
                run_nms=run_nms,
                nms_prob_thresh=nms_prob_thresh,
                nms_iou_thresh=nms_iou_thresh,
                nms_use_iom=nms_use_iom,
                feature_cache=feature_cache,
                return_sam2_backbone_feats=return_sam2_backbone_feats,
            )

    # Slice out the output for this specific frame from the cached chunk
    out = self._slice_batched_output(
        grounding_buffer[chunk_key],
        frame_idx - chunk_start,
        return_sam2_backbone_feats,
    )
    return out, backbone_out
def _process_grounding_chunk_batched(
self,
backbone_out,
find_inputs,
chunk_start: int,
chunk_end: int,
run_nms: bool,
nms_prob_thresh,
nms_iou_thresh,
nms_use_iom: bool,
feature_cache,
return_sam2_backbone_feats: bool,
):
"""
Process a chunk of frames through the full forward_grounding pipeline in batch.
"""
chunk_size = chunk_end - chunk_start
# Build geometric prompts for the chunk
chunk_geo_prompts = [
self._get_geo_prompt_from_find_input(find_inputs[i % len(find_inputs)])
for i in range(chunk_start, chunk_end)
]
# Batch the find_inputs for this chunk
batched_find_input = self._batch_find_inputs(
find_inputs, chunk_start, chunk_end
)
# Batch the geometric prompts
batched_geometric_prompt = self._batch_geometric_prompts_from_list(
chunk_geo_prompts
)
# Run forward_grounding on the batched input
with torch.profiler.record_function("forward_grounding_batched.forward"):
out = self.forward_grounding(
backbone_out=backbone_out,
find_input=batched_find_input,
find_target=None,
geometric_prompt=batched_geometric_prompt,
feature_cache=feature_cache,
)
# Apply NMS per frame in the batch
if run_nms:
with torch.profiler.record_function("forward_grounding_batched.nms"):
assert nms_prob_thresh is not None and nms_iou_thresh is not None
pred_probs = out["pred_logits"].squeeze(-1).sigmoid()
pred_masks = out["pred_masks"]
# pred_probs shape: [batch_size, num_queries]
# pred_masks shape: [batch_size, num_queries, H, W]
# Use batched NMS to process all frames at once
keep = nms_masks(
pred_probs=pred_probs,
pred_masks=pred_masks,
prob_threshold=nms_prob_thresh,
iou_threshold=nms_iou_thresh,
nms_use_iom=nms_use_iom,
do_compile=getattr(self, "compile_model", False),
running_in_prod=getattr(self, "running_in_prod", False),
)
# Set a very low threshold for detections removed by NMS
# keep shape: [batch_size, num_queries]
out["pred_logits"][:, :, 0] -= 1e4 * (~keep).float()
# Extract SAM2 backbone features if requested
if return_sam2_backbone_feats and "prev_encoder_out" in out:
backbone_data = out["prev_encoder_out"]["backbone_out"]
if self.is_multiplex and "interactive" in backbone_data:
out["_interactive_backbone"] = backbone_data["interactive"]
if "sam2_backbone_out" in backbone_data:
out["_sam2_backbone"] = backbone_data["sam2_backbone_out"]
out["_chunk_size"] = chunk_size
return out
def _slice_batched_output(
self,
chunk_outputs,
local_idx: int,
return_sam2_backbone_feats: bool,
):
"""
Slice a single frame's output from the batched chunk outputs.
"""
out = {}
# Keys to slice at batch dimension
batch_dim_keys = {
"pred_logits",
"pred_boxes",
"pred_boxes_xyxy",
"pred_masks",
"pred_logits_o2m",
"pred_boxes_o2m",
"pred_boxes_xyxy_o2m",
"pred_masks_o2m",
"queries",
"presence_logit_dec",
}
# Keys to skip
skip_keys = {
"_chunk_size",
"_interactive_backbone",
"_sam2_backbone",
"prev_encoder_out",
"encoder_hidden_states",
"aux_outputs",
}
for key, value in chunk_outputs.items():
if key in skip_keys:
continue
if key in batch_dim_keys and isinstance(value, torch.Tensor):
out[key] = value[local_idx : local_idx + 1]
elif isinstance(value, torch.Tensor):
try:
out[key] = value[local_idx : local_idx + 1]
except (IndexError, RuntimeError):
out[key] = value
# Add object IDs
if "pred_logits" in out:
out["pred_object_ids"] = self._get_dummy_object_ids(out["pred_logits"])
# Add SAM2 backbone features if requested
if return_sam2_backbone_feats:
if "_sam2_backbone" in chunk_outputs:
sam2_bb = chunk_outputs["_sam2_backbone"]
out["sam2_backbone_fpn_0"] = sam2_bb["backbone_fpn"][0].tensors[
local_idx : local_idx + 1
]
out["sam2_backbone_fpn_1"] = sam2_bb["backbone_fpn"][1].tensors[
local_idx : local_idx + 1
]
out["sam2_backbone_fpn_2"] = sam2_bb["backbone_fpn"][2].tensors[
local_idx : local_idx + 1
]
out["sam2_backbone_pos_enc"] = [
x[local_idx : local_idx + 1] for x in sam2_bb["vision_pos_enc"]
]
if self.is_multiplex and "_interactive_backbone" in chunk_outputs:
inte_bb = chunk_outputs["_interactive_backbone"]
out["interactive_backbone_fpn_0"] = inte_bb["backbone_fpn"][0].tensors[
local_idx : local_idx + 1
]
out["interactive_backbone_fpn_1"] = inte_bb["backbone_fpn"][1].tensors[
local_idx : local_idx + 1
]
out["interactive_backbone_fpn_2"] = inte_bb["backbone_fpn"][2].tensors[
local_idx : local_idx + 1
]
out["interactive_backbone_pos_enc"] = [
x[local_idx : local_idx + 1] for x in inte_bb["vision_pos_enc"]
]
return out
def _cleanup_previous_chunks_multigpu(
self,
grounding_cache,
current_chunk_key,
batch_size: int,
num_frames: int,
track_in_reverse: bool,
):
"""Remove previous chunks from cache to save GPU memory."""
chunk_start, chunk_end = current_chunk_key
if not track_in_reverse:
prev_chunk_start = chunk_start - batch_size
if prev_chunk_start >= 0:
prev_chunk_end = chunk_start
prev_chunk_key = (prev_chunk_start, prev_chunk_end)
# Cleanup grounding_buffer entry
chunk = grounding_cache["grounding_buffer"].pop(prev_chunk_key, None)
if chunk is not None:
del chunk
else:
next_chunk_start = chunk_end
if next_chunk_start < num_frames:
next_chunk_end = min(next_chunk_start + batch_size, num_frames)
next_chunk_key = (next_chunk_start, next_chunk_end)
grounding_cache["grounding_buffer"].pop(next_chunk_key, None)