import datetime import logging import math import os import sys from collections import defaultdict from copy import deepcopy from typing import Any, Dict, List, Optional, Set, Tuple import numpy as np import torch import torch.distributed as dist import torch.nn.functional as F from ..logger import get_logger from ..model.box_ops import fast_diag_box_iou from ..model.data_misc import BatchedDatapoint, NestedTensor from ..model.device_utils import accelerator_autocast from ..model.sam3_multiplex_detector import Sam3MultiplexDetector from ..model.sam3_tracker_utils import fill_holes_in_mask_scores, mask_to_box from ..model.sam3_video_base import ( _associate_det_trk_compilable, LazyAssociateDetTrkResult, MaskletConfirmationStatus, realize_adt_result, RealizedAssociateDetTrkresult, Sam3VideoBase, ) from ..perflib.masks_ops import mask_iou from ..train.masks_ops import rle_encode from torch import nn, Tensor # a short 3-min timeout to quickly detect any synchronization failures SAM3_COLLECTIVE_OP_TIMEOUT_SEC = int(os.getenv("SAM3_COLLECTIVE_OP_TIMEOUT_SEC", "180")) logger = get_logger(__name__) class Sam3MultiplexTrackerPredictor(nn.Module): def __init__( self, config_file, checkpoint_file=None, hydra_overrides=None, per_obj_inference=False, fill_hole_area=0, use_fa3=False, use_rope_real=True, keep_first_cond_frame=False, is_multiplex=False, is_multiplex_dynamic=False, use_memory_selection=False, ): """ Initialize the SAM2 predictor with the given configuration and checkpoint. Args: config_file (str): Path to the configuration file. checkpoint_file (str, optional): Path to the checkpoint file. If None, the model will be initialized without loading weights. hydra_overrides (list, optional): List of Hydra overrides to apply to the configuration. per_obj_inference (bool): If True, the model will perform per-object inference instead of bucketized batching. """ super().__init__() ####################################### # Load model from config and checkpoint ####################################### from hydra import compose, initialize_config_module from hydra.core.global_hydra import GlobalHydra from hydra.utils import instantiate package_root = __package__.rsplit(".model", 1)[0] # Ensure proper Hydra initialization if not GlobalHydra().is_initialized(): logger.info("Sam3MultiplexTrackerPredictor: GlobalHydra not initialized") GlobalHydra.instance().clear() initialize_config_module(f"{package_root}.config", version_base="1.2") if hydra_overrides is None: hydra_overrides = [] self.is_multiplex = is_multiplex self.is_multiplex_dynamic = is_multiplex_dynamic self.per_obj_inference = per_obj_inference if self.is_multiplex: inference_model_class = f"{package_root}.model.video_tracking_multiplex_demo.Sam3VideoTrackingMultiplexDemo" else: inference_model_class = ( f"{package_root}.model.video_tracking_with_prompt_demo_per_obj_inference.Sam3VideoTrackingWithPromptDemoPerObjInference" if per_obj_inference else f"{package_root}.model.video_tracking_with_prompt_demo.Sam3VideoTrackingWithPromptDemo" ) hydra_overrides = list(hydra_overrides) hydra_overrides.extend( [ "launcher.experiment_log_dir=''", f"++trainer.model._target_={inference_model_class}", # Shared backbone cfg "++trainer.model.image_size=1008", "++trainer.model.backbone_stride=14", "++trainer.model.maskmem_backbone.mask_downsampler.interpol_size=[1152,1152]", "++trainer.model.backbone.forward_in_chunk_for_eval=false", # always start tracking from the frame where we receive the first annotation # (clicks or mask) and ignore the `start_frame_idx` passed to `propagate_in_video` "++trainer.model.always_start_from_first_ann_frame=false", # apply non-overlapping constraints on the object masks in the # memory encoder to avoid/alleviate superposing mask predictions "++trainer.model.non_overlap_masks_for_mem_enc=false", # Do not apply non-overlapping constraints on the output "++trainer.model.non_overlap_masks_for_output=false", # attend to at most 4 temporally closest conditioning frames in the encoder for # better temporal locality and a better handling to a large number of annotated frames "++trainer.model.max_cond_frames_in_attn=4", f"++trainer.model.keep_first_cond_frame={keep_first_cond_frame}", # turn off all offloading options in the demo (we handle them separately in the demo class) "++trainer.model.offload_output_to_cpu_for_eval=false", "++trainer.model.trim_past_non_cond_mem_for_eval=false", # torch.compile on the image backbone (w/ `dynamic=false` and `fullgraph=true` to capture a full graph) # "++trainer.model.backbone.compile_mode=max-autotune", # "++trainer.model.backbone.compile_extra_args.fullgraph=true", # "++trainer.model.backbone.compile_extra_args.dynamic=false", "++trainer.model.backbone.visual.trunk.weights_path=null", # Postprocessing/demo options # dynamically fall back to multi-mask if the single mask is not stable "++trainer.model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", "++trainer.model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", "++trainer.model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking "++trainer.model.binarize_mask_from_pts_for_mem_enc=true", # only attend to object pointers in the past (before the current frame) in the encoder during evaluation "++trainer.model.only_obj_ptrs_in_the_past_for_eval=true", # clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks "++trainer.model.clear_non_cond_mem_around_input=true", "++trainer.model.transformer.encoder.layer.self_attention.feat_sizes=[72,72]", "++trainer.model.transformer.encoder.layer.cross_attention.feat_sizes=[72,72]", # fill small holes in the final masks up to `fill_hole_area` (after resizing them to the original video resolution) f"++trainer.model.fill_hole_area={fill_hole_area}", f"++trainer.model.transformer.encoder.layer.self_attention.use_fa3={use_fa3}", f"++trainer.model.transformer.encoder.layer.cross_attention.use_fa3={use_fa3}", f"++trainer.model.transformer.encoder.layer.self_attention.use_rope_real={use_rope_real}", f"++trainer.model.transformer.encoder.layer.cross_attention.use_rope_real={use_rope_real}", ] ) if self.is_multiplex or self.is_multiplex_dynamic: hydra_overrides.extend( [ f"++trainer.model.transformer.encoder.layer.self_attention_rope.use_fa3={use_fa3}", f"++trainer.model.transformer.encoder.layer.cross_attention_rope.use_fa3={use_fa3}", f"++trainer.model.transformer.encoder.layer.self_attention_rope.use_rope_real={use_rope_real}", f"++trainer.model.transformer.encoder.layer.cross_attention_rope.use_rope_real={use_rope_real}", ] ) hydra_overrides.extend( [f"++trainer.model.use_memory_selection={use_memory_selection}"] ) cfg = compose(config_name=config_file, overrides=hydra_overrides) model = instantiate(cfg.trainer.model, _recursive_=True) del model.backbone # Remove backbone since it is shared with the sam3 model if checkpoint_file is not None: ckpt = torch.load(checkpoint_file, map_location="cpu") model.load_state_dict(ckpt["model"], strict=False) self.model = model self.per_obj_inference = per_obj_inference self.fill_hole_area = fill_hole_area # use bfloat16 inference for Flash Attention kernel self.bf16_context = accelerator_autocast() self.bf16_context.__enter__() # keep using for the entire model process def __getattr__(self, name): # Expose all attributes of the underlying model model = super().__getattr__("model") if name == "model": return model return getattr(model, name) def forward(self, *args, **kwargs): raise NotImplementedError( "Use the sam2 predictor APIs instead. Check VideoTrackingWithPromptDemo class for details." ) def add_output_per_object(self, *args, **kwargs): if self.per_obj_inference: # nothing needs to be done as each object is already stored separately return # for batched inference state, we also need to add per-object # memory slides to support instance interactivity self._add_output_per_object(*args, **kwargs) class Sam3MultiplexBase(Sam3VideoBase): def __init__( self, tracker, detector, ckpt_path=None, sam3_ckpt_path=None, # prob threshold for detection outputs -- only keep detections above this threshold # enters NMS and det-to-track matching score_threshold_detection=0.5, # Detection threshold when running on image-only inputs image_only_det_thresh=0.5, # IoU threshold for detection NMS det_nms_thresh=0.0, # If `det_nms_use_iom` is True, use IoM instead of IoU for NMS det_nms_use_iom=False, # IoU threshold for det-to-track matching -- a detection is considered "matched" to a tracklet it # overlaps with a tracklet above this threshold -- it is often a loose threshold like 0.1 assoc_iou_thresh=0.5, # IoU threshold for det-to-track matching, which is used to determine whether a masklet is "unmatched" # by any detections -- it is often a stricter threshold like 0.5 trk_assoc_iou_thresh=0.5, # prob threshold for a detection to be added as a new object new_det_thresh=0.5, # hotstart parameters: we hold off the outputs for `hotstart_delay` frames and # 1) remove those tracklets unmatched by any detections based on `hotstart_unmatch_thresh` # 2) remove those tracklets overlapping with one another based on `hotstart_dup_thresh` hotstart_delay=0, hotstart_unmatch_thresh=3, hotstart_dup_thresh=3, # Whether to suppress masks only within hotstart. If False, we can suppress masks even if they start before hotstart period. suppress_unmatched_only_within_hotstart=True, init_trk_keep_alive=0, max_trk_keep_alive=8, min_trk_keep_alive=-4, # Threshold for suppressing overlapping objects based on recent occlusion suppress_overlapping_based_on_recent_occlusion_threshold=0.0, allow_unoccluded_to_suppress: bool = False, decrease_trk_keep_alive_for_empty_masklets=False, o2o_matching_masklets_enable=False, # Enable hungarian matching to match existing masklets suppress_det_close_to_boundary=False, fill_hole_area=16, sprinkle_removal_area=16, # The maximum number of objects (masklets) to track across all GPUs (for no limit, set it to -1) max_num_objects=128, # 128 objects (total across all GPUs) should be able to cover nearly all cases max_num_kboxes=20, recondition_every_nth_frame=-1, use_iom_recondition=False, iom_thresh_recondition=0.8, iou_thresh_recondition=0.8, is_multiplex=False, # masket confirmation status (to suppress unconfirmed masklets) masklet_confirmation_enable=False, # a masklet is confirmed after being consecutively detected and matched for # `masklet_confirmation_consecutive_det_thresh` masklet_confirmation_consecutive_det_thresh=3, # bbox heuristic parameters reconstruction_bbox_iou_thresh=0.0, reconstruction_bbox_det_score=0.5, reapply_no_object_pointer: bool = False, # reapply the no object pointer for suppressed objects running_in_prod=False, # Flag to specify if we are running in FBInfra for Insta Edit/Segments use_batched_grounding=False, batched_grounding_batch_size=1, **kwargs, ): nn.Module.__init__(self) assert isinstance(tracker, Sam3MultiplexTrackerPredictor) self.tracker = tracker assert isinstance(detector, Sam3MultiplexDetector) self.detector = detector if sam3_ckpt_path: ckpt = torch.load(sam3_ckpt_path, map_location="cpu", weights_only=True) self.detector.load_state_dict(ckpt["model"], strict=False) elif ckpt_path: self._load_checkpoint(ckpt_path, strict=False) self.score_threshold_detection = score_threshold_detection self.image_only_det_thresh = image_only_det_thresh self.det_nms_thresh = det_nms_thresh self.det_nms_use_iom = det_nms_use_iom self.assoc_iou_thresh = assoc_iou_thresh self.trk_assoc_iou_thresh = trk_assoc_iou_thresh self.new_det_thresh = new_det_thresh self.is_multiplex = is_multiplex self.running_in_prod = running_in_prod self.detector.running_in_prod = running_in_prod assert ( self.is_multiplex == self.tracker.is_multiplex == self.detector.is_multiplex ), ( f"is_multiplex must be the same for all models: {self.is_multiplex=}, {self.tracker.is_multiplex=}, {self.detector.is_multiplex=}" ) # hotstart parameters if hotstart_delay > 0: assert hotstart_unmatch_thresh <= hotstart_delay assert hotstart_dup_thresh <= hotstart_delay self.hotstart_delay = hotstart_delay self.hotstart_unmatch_thresh = hotstart_unmatch_thresh self.hotstart_dup_thresh = hotstart_dup_thresh self.suppress_unmatched_only_within_hotstart = ( suppress_unmatched_only_within_hotstart ) self.init_trk_keep_alive = init_trk_keep_alive self.max_trk_keep_alive = max_trk_keep_alive self.min_trk_keep_alive = min_trk_keep_alive self.suppress_overlapping_based_on_recent_occlusion_threshold = ( suppress_overlapping_based_on_recent_occlusion_threshold ) self.allow_unoccluded_to_suppress = allow_unoccluded_to_suppress self.suppress_det_close_to_boundary = suppress_det_close_to_boundary self.decrease_trk_keep_alive_for_empty_masklets = ( decrease_trk_keep_alive_for_empty_masklets ) self.o2o_matching_masklets_enable = o2o_matching_masklets_enable self.fill_hole_area = fill_hole_area self.sprinkle_removal_area = sprinkle_removal_area self.eval() self.rank = int(os.getenv("RANK", "0")) self.world_size = int(os.getenv("WORLD_SIZE", "1")) self._dist_pg_cpu = None # CPU process group (lazy-initialized on first use) # Initialize profiling variables self._profiler = None self._frame_count = 0 self._profile_save_dir = os.getenv("PROFILE_SAVE_DIR", "/tmp/profiling") self._profiling_enabled = os.getenv("ENABLE_PROFILING", "0").lower() == "1" # the maximum object number if max_num_objects > 0: multiplex_divisor = ( self.tracker.multiplex_controller.allowed_bucket_capacity if self.is_multiplex else 1 ) num_obj_for_compile = math.ceil( max_num_objects / (self.world_size * multiplex_divisor) ) else: max_num_objects = 10000 # no limit num_obj_for_compile = 16 logger.info( f"`setting max_num_objects` to {max_num_objects} -- creating {num_obj_for_compile=} objects for torch.compile cache" ) self.max_num_objects = max_num_objects self.num_obj_for_compile = num_obj_for_compile self.max_num_kboxes = max_num_kboxes self.recondition_every_nth_frame = recondition_every_nth_frame self.use_iom_recondition = use_iom_recondition self.iom_thresh_recondition = iom_thresh_recondition self.iou_thresh_recondition = iou_thresh_recondition self.masklet_confirmation_enable = masklet_confirmation_enable self.masklet_confirmation_consecutive_det_thresh = ( masklet_confirmation_consecutive_det_thresh ) self.reconstruction_bbox_iou_thresh = reconstruction_bbox_iou_thresh self.reconstruction_bbox_det_score = reconstruction_bbox_det_score self.reapply_no_object_pointer = reapply_no_object_pointer # Batched grounding configuration self.use_batched_grounding = use_batched_grounding self.batched_grounding_batch_size = ( batched_grounding_batch_size # Batch size for batched grounding ) if self.is_multiplex: assert not self.tracker.multiplex_controller.training, ( "This model class should only be used for eval." ) self.bucket_capacity: int = ( self.tracker.multiplex_controller.allowed_bucket_capacity ) def all_gather_cpu(self, tensor_list, tensor): if self._dist_pg_cpu is None: self._init_dist_pg_cpu() dist.broadcast(tensor_list, tensor, group=self._dist_pg_cpu) def all_gather_python_obj_cpu(self, object_list, python_obj): if self._dist_pg_cpu is None: self._init_dist_pg_cpu() dist.all_gather_object(object_list, python_obj, group=self._dist_pg_cpu) def broadcast_cpu(self, x, src): if self._dist_pg_cpu is None: self._init_dist_pg_cpu() dist.broadcast(x, src=src, group=self._dist_pg_cpu) def _start_profiling(self, frame_idx): self._profiling_enabled = os.getenv("ENABLE_PROFILING", "0").lower() == "1" self._profile_end_frame = int(os.getenv("PROFILE_END_FRAME", "-1")) """Start profiling for _det_track_one_frame if conditions are met.""" if not self._profiling_enabled: return False if not getattr(self, "_warm_up_complete", False): return False if self._profiler is not None: return True # Start profiling os.makedirs(self._profile_save_dir, exist_ok=True) profile_path = os.path.join( self._profile_save_dir, f"det_track_frame_rank_{self.rank}.json.gz" ) self._profiler = torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], record_shapes=True, experimental_config=torch.profiler._ExperimentalConfig( profile_all_threads=True ), ) self._profiler.start() self._current_profile_path = profile_path print(f"Started profiling frame on {frame_idx} on rank {self.rank}") return True def _stop_profiling(self): """Stop profiling and save trace.""" if self._profiler is not None: self._profiler.stop() self._profiler.export_chrome_trace(self._current_profile_path) print(f"Profiling trace saved to: {self._current_profile_path}") print( f"You can open this file in Perfetto (https://ui.perfetto.dev/) to visualize the trace" ) self._profiler = None self._profiling_enabled = False os.environ["ENABLE_PROFILING"] = "0" def _det_track_one_frame( self, frame_idx: int, num_frames: int, reverse: bool, input_batch: BatchedDatapoint, geometric_prompt: Any, tracker_states_local: List[Any], tracker_metadata_prev: Dict[str, Any], feature_cache: Dict, orig_vid_height: int, orig_vid_width: int, is_image_only: bool = False, ): profiling_enabled = self._start_profiling(frame_idx) try: return self._det_track_one_frame_impl( frame_idx=frame_idx, num_frames=num_frames, reverse=reverse, input_batch=input_batch, geometric_prompt=geometric_prompt, tracker_states_local=tracker_states_local, tracker_metadata_prev=tracker_metadata_prev, feature_cache=feature_cache, orig_vid_height=orig_vid_height, orig_vid_width=orig_vid_width, is_image_only=is_image_only, ) finally: if profiling_enabled: if sys.exc_info()[0] is not None: # If there is an exception, stop profiling self._stop_profiling() else: if ( (not reverse and frame_idx == num_frames - 1) or (reverse and frame_idx == 0) or self._profile_end_frame == frame_idx ): # Stop profiling if reached the last frame self._stop_profiling() def _det_track_one_frame_impl( self, frame_idx: int, num_frames: int, reverse: bool, input_batch: BatchedDatapoint, geometric_prompt: Any, tracker_states_local: List[Any], tracker_metadata_prev: Dict[str, Any], feature_cache: Dict, orig_vid_height: int, orig_vid_width: int, is_image_only: bool, ): """ This function handles one-step inference for the multiplex model in an SPMD manner. At a high-level, all GPUs execute the same function calls as if it's done on a single GPU, while under the hood, some function calls involve distributed computation based on sharded SAM2 states. - `input_batch` contains image and other inputs on the entire video; it should be identical across GPUs - `tracker_states_local` holds the local masklet information in this GPU shard - `tracker_metadata_prev` manages the metadata for SAM2 objects, such as which masklet is hold on which GPUs it contains both global and local masklet information """ # Step 1: run backbone and FA in a distributed manner -- this is done via Sam3MultiplexDetector, # a distributed FA model (assigned to `self.detector`) that shards frames in a round-robin manner. # It returns a "det_out" dict for `frame_idx` and fills SAM2 backbone features for `frame_idx` # into `feature_cache`. Despite its distributed inference under the hood, the results would be # the same as if it is running backbone and FA for every frame on a single GPU. with torch.profiler.record_function("run_backbone_and_detection"): det_out, pos_pred_mask = self.run_backbone_and_detection( frame_idx=frame_idx, num_frames=num_frames, reverse=reverse, input_batch=input_batch, geometric_prompt=geometric_prompt, feature_cache=feature_cache, use_batched_grounding=self.use_batched_grounding, batched_grounding_batch_size=self.batched_grounding_batch_size, ) # Step 2: each GPU propagates its local SAM2 states to get the SAM2 prediction masks. # the returned `tracker_low_res_masks_global` contains the concatenated masklet predictions # gathered from all GPUs (as if they are propagated on a single GPU). Note that this step only # runs the SAM2 propagation step, but doesn't encode new memory for the predicted masks; # we defer memory encoding to `run_tracker_update_execution_phase` after resolving all heuristics. with torch.profiler.record_function("run_tracker_propagation"): if tracker_metadata_prev == {}: # initialize masklet metadata if it's uninitialized (empty dict) tracker_metadata_prev.update(self._initialize_metadata()) tracker_low_res_masks_global, tracker_obj_scores_global = ( self.run_tracker_propagation( frame_idx=frame_idx, num_frames=num_frames, reverse=reverse, tracker_states_local=tracker_states_local, tracker_metadata_prev=tracker_metadata_prev, ) ) with torch.profiler.record_function("GPU sync and filter"): # Remove leading dimension (assumes batch size 1) assert pos_pred_mask.shape[0] == 1 pos_pred_mask = pos_pred_mask.squeeze(0) det_out = {k: det_out[k][0] for k in det_out} # Move detections we'll actually keep at the top for future logic pos_pred_mask_idx = pos_pred_mask.argsort(descending=True) pos_pred_mask = torch.index_select( pos_pred_mask, dim=0, index=pos_pred_mask_idx ) det_out = { k: torch.index_select(det_out[k], dim=0, index=pos_pred_mask_idx) for k in det_out } # Step 3: based on detection outputs and the propagated SAM2 prediction masks, we make plans # for SAM2 masklet updates (i.e. which objects to add and remove, how to load-balance them, etc). # We also run SAM2 memory encoder globally in this step to resolve non-overlapping constraints. # **This step should involve all the heuristics needed for any updates.** Most of the update # planning will be done on the master rank (GPU 0) and the resulting plan `sam2_update_plan` is # broadcasted to other GPUs (to be executed in a distributed manner). This step also generates the # new masklet metadata `tracker_metadata_new` (based on its previous version `tracker_metadata_prev`). with torch.profiler.record_function("run_tracker_update_planning_phase"): sam2_update_plan, tracker_metadata_new = ( self.run_tracker_update_planning_phase( frame_idx=frame_idx, num_frames=num_frames, reverse=reverse, det_out=det_out, det_keep=pos_pred_mask, tracker_low_res_masks_global=tracker_low_res_masks_global, tracker_obj_scores_global=tracker_obj_scores_global, tracker_metadata_prev=tracker_metadata_prev, tracker_states_local=tracker_states_local, is_image_only=is_image_only, ) ) # Get reconditioning info from the update plan reconditioned_obj_ids = sam2_update_plan.get("reconditioned_obj_ids", set()) det_to_matched_trk_obj_ids = sam2_update_plan.get( "det_to_matched_trk_obj_ids", {} ) # Step 4: based on `sam2_update_plan`, each GPU executes the update w.r.t. its local SAM2 inference states with torch.profiler.record_function("run_tracker_update_execution_phase"): tracker_states_local_new = self.run_tracker_update_execution_phase( frame_idx=frame_idx, num_frames=num_frames, reverse=reverse, det_out=det_out, tracker_states_local=tracker_states_local, tracker_update_plan=sam2_update_plan, tracker_metadata_new=tracker_metadata_new, orig_vid_height=orig_vid_height, orig_vid_width=orig_vid_width, feature_cache=feature_cache, ) # Step 5: finally, build the outputs for this frame (it only needs to be done on GPU 0 since # only GPU 0 will send outputs to the server). with torch.profiler.record_function("build_outputs"): if self.rank == 0: obj_id_to_mask = self.build_outputs( frame_idx=frame_idx, num_frames=num_frames, reverse=reverse, det_out=det_out, tracker_low_res_masks_global=tracker_low_res_masks_global, tracker_obj_scores_global=tracker_obj_scores_global, tracker_metadata_prev=tracker_metadata_prev, sam2_update_plan=sam2_update_plan, orig_vid_height=orig_vid_height, orig_vid_width=orig_vid_width, reconditioned_obj_ids=reconditioned_obj_ids, det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids, ) obj_id_to_score = tracker_metadata_new["obj_id_to_score"] else: obj_id_to_mask, obj_id_to_score = {}, {} # dummy outputs on other GPUs # a few statistics for the current frame as a part of the output frame_stats = { "num_obj_tracked": np.sum(tracker_metadata_new["num_obj_per_gpu"]), "num_obj_dropped": sam2_update_plan["num_obj_dropped_due_to_limit"], } # add sam2 scores to metadata, it should be fired for frames except the first frame if tracker_obj_scores_global.shape[0] > 0: # Convert tracker_obj_scores_global to sigmoid scores before updating tracker_obj_scores_global = tracker_obj_scores_global.sigmoid() sam2_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"] tracker_metadata_new["obj_id_to_sam2_score_frame_wise"][frame_idx].update( dict(zip(sam2_obj_ids, tracker_obj_scores_global)) ) return ( obj_id_to_mask, # a dict: obj_id --> output mask obj_id_to_score, # a dict: obj_id --> output score (prob) tracker_states_local_new, tracker_metadata_new, frame_stats, tracker_obj_scores_global, # a dict: obj_id --> sam2 frame-level scores ) def run_backbone_and_detection( self, frame_idx: int, num_frames: int, input_batch: BatchedDatapoint, geometric_prompt: Any, feature_cache: Dict, reverse: bool, use_batched_grounding: bool = False, batched_grounding_batch_size: int = 16, ): # Step 1: if text feature is not cached in `feature_cache`, compute and cache it text_batch_key = tuple(input_batch.find_text_batch) if "text" not in feature_cache or text_batch_key not in feature_cache["text"]: text_outputs = self.detector.backbone.forward_text( input_batch.find_text_batch, device=self.device ) # note: we only cache the text feature of the most recent prompt feature_cache["text"] = {text_batch_key: text_outputs} else: text_outputs = feature_cache["text"][text_batch_key] feature_cache.pop(frame_idx - 1 if not reverse else frame_idx + 1, None) # Step 2: run backbone, FA detection, and post-processing with NMS # Extract max_frame_num_to_track from feature_cache if available tracking_bounds = feature_cache.get("tracking_bounds", {}) max_frame_num_to_track = tracking_bounds.get("max_frame_num_to_track") start_frame_idx = tracking_bounds.get("propagate_in_video_start_frame_idx") backbone_out = { "img_batch_all_stages": input_batch.img_batch, **text_outputs, } if use_batched_grounding: # Use fully batched forward_grounding approach if "grounding_cache" not in feature_cache: feature_cache["grounding_cache"] = {} with torch.profiler.record_function( "forward_video_grounding_batched_multigpu" ): sam3_image_out, _ = ( self.detector.forward_video_grounding_batched_multigpu( backbone_out=backbone_out, find_inputs=input_batch.find_inputs, geometric_prompt=geometric_prompt, frame_idx=frame_idx, num_frames=num_frames, grounding_cache=feature_cache["grounding_cache"], track_in_reverse=reverse, return_sam2_backbone_feats=True, run_nms=self.det_nms_thresh > 0.0, nms_prob_thresh=self.score_threshold_detection, nms_iou_thresh=self.det_nms_thresh, nms_use_iom=self.det_nms_use_iom, max_frame_num_to_track=max_frame_num_to_track, propagate_in_video_start_frame_idx=start_frame_idx, feature_cache=feature_cache, batch_size=batched_grounding_batch_size, ) ) else: # Use existing multi-GPU distributed approach if "multigpu_buffer" not in feature_cache: # "multigpu_buffer" is a buffer cache used by `self.detector` and it needs # to be passed to `forward_video_grounding_multigpu` for every call feature_cache["multigpu_buffer"] = {} with torch.profiler.record_function("forward_video_grounding_multigpu"): sam3_image_out, _ = self.detector.forward_video_grounding_multigpu( backbone_out=backbone_out, find_inputs=input_batch.find_inputs, geometric_prompt=geometric_prompt, frame_idx=frame_idx, num_frames=num_frames, multigpu_buffer=feature_cache["multigpu_buffer"], track_in_reverse=reverse, # also get the SAM2 backbone features return_sam2_backbone_feats=True, # run NMS as a part of distributed FA computation run_nms=self.det_nms_thresh > 0.0, nms_prob_thresh=self.score_threshold_detection, nms_iou_thresh=self.det_nms_thresh, nms_use_iom=self.det_nms_use_iom, # pass max_frame_num_to_track to respect tracking limits max_frame_num_to_track=max_frame_num_to_track, propagate_in_video_start_frame_idx=start_frame_idx, # pass feature_cache for buffered backbone computation feature_cache=feature_cache, ) # note: detections in `sam3_image_out` has already gone through NMS pred_probs = sam3_image_out["pred_logits"].squeeze(-1).sigmoid() pred_boxes_xyxy = sam3_image_out["pred_boxes_xyxy"] pred_masks = sam3_image_out["pred_masks"] # get the positive detection outputs above threshold pos_pred_mask = pred_probs > self.score_threshold_detection if self.suppress_det_close_to_boundary: # Suppress detections too close to image edges (for normalized boxes). keep = self._suppress_detections_close_to_boundary(pred_boxes_xyxy) pos_pred_mask = pos_pred_mask & keep det_out = { "bbox": pred_boxes_xyxy, "mask": pred_masks, "scores": pred_probs, } # Step 3: build SAM2 backbone features and store them in `feature_cache` backbone_cache = {} if self.is_multiplex: # For the multiplex model we have separate interaction and propagation features # TODO: We do not need the interaction features every frame so there are rooms for optimization interaction_sam_mask_decoder = self.tracker.interactive_sam_mask_decoder interaction_backbone_fpn = [ interaction_sam_mask_decoder.conv_s0( sam3_image_out["interactive_backbone_fpn_0"] ), interaction_sam_mask_decoder.conv_s1( sam3_image_out["interactive_backbone_fpn_1"] ), sam3_image_out[ "interactive_backbone_fpn_2" ], # fpn_2 doesn't need additional conv ] interaction_backbone_out = { "vision_features": interaction_backbone_fpn[-1], # top-level feature "vision_mask": None, "vision_pos_enc": sam3_image_out["interactive_backbone_pos_enc"], "backbone_fpn": [ NestedTensor(x, None) for x in interaction_backbone_fpn ], } backbone_cache["interactive"] = interaction_backbone_out sam_mask_decoder = self.tracker.sam_mask_decoder sam2_backbone_fpn = [ sam_mask_decoder.conv_s0(sam3_image_out["sam2_backbone_fpn_0"]), sam_mask_decoder.conv_s1(sam3_image_out["sam2_backbone_fpn_1"]), sam3_image_out["sam2_backbone_fpn_2"], # fpn_2 doesn't need additional conv ] sam2_backbone_out = { "vision_features": sam2_backbone_fpn[-1], # top-level feature "vision_mask": None, "vision_pos_enc": sam3_image_out["sam2_backbone_pos_enc"], "backbone_fpn": [NestedTensor(x, None) for x in sam2_backbone_fpn], } backbone_cache["sam2_backbone_out"] = sam2_backbone_out with torch.profiler.record_function("run_backbone_and_detection.feature_cache"): feature_cache[frame_idx] = ( input_batch.img_batch.tensors[frame_idx], backbone_cache, ) # remove from `feature_cache` old features to save GPU memory feature_cache.pop(frame_idx - 1 if not reverse else frame_idx + 1, None) return det_out, pos_pred_mask def run_tracker_propagation( self, frame_idx: int, num_frames: int, reverse: bool, tracker_states_local: List[Any], tracker_metadata_prev: Dict[str, np.ndarray], ): # Step 1: propagate the local SAM2 states to get the current frame's prediction # `low_res_masks_local` of the existing masklets on this GPU # - obj_ids_local: List[int] -- list of object IDs # - low_res_masks_local: Tensor -- (num_local_obj, H_mask, W_mask) with torch.profiler.record_function("propagate_tracker_one_frame_local_gpu"): obj_ids_local, low_res_masks_local, obj_scores_local = ( self._propogate_tracker_one_frame_local_gpu( tracker_states_local, frame_idx=frame_idx, reverse=reverse ) ) assert np.all( obj_ids_local == tracker_metadata_prev["obj_ids_per_gpu"][self.rank] ), "{} != {}".format( obj_ids_local, tracker_metadata_prev["obj_ids_per_gpu"][self.rank] ) # Step 2: all-gather `low_res_masks_local` into `low_res_masks_global` # - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask) with torch.profiler.record_function("all_gather_low_res_masks_local"): _, H_mask, W_mask = low_res_masks_local.shape if self.world_size > 1: # `low_res_masks_local` and `obj_scores_local` need to be contiguous and float32 # (they could be non-contiguous due to slicing and/or bfloat16 due to autocast) low_res_masks_local = low_res_masks_local.float().contiguous() obj_scores_local = obj_scores_local.float().contiguous() num_obj_this_gpu = tracker_metadata_prev["num_obj_per_gpu"][self.rank] assert low_res_masks_local.size(0) == num_obj_this_gpu assert obj_scores_local.size(0) == num_obj_this_gpu low_res_masks_peers = [ low_res_masks_local.new_empty(num_obj, H_mask, W_mask) for num_obj in tracker_metadata_prev["num_obj_per_gpu"] ] obj_scores_peers = [ obj_scores_local.new_empty(num_obj) for num_obj in tracker_metadata_prev["num_obj_per_gpu"] ] dist.all_gather(low_res_masks_peers, low_res_masks_local) dist.all_gather(obj_scores_peers, obj_scores_local) low_res_masks_global = torch.cat(low_res_masks_peers, dim=0) obj_scores_global = torch.cat(obj_scores_peers, dim=0) else: low_res_masks_global = low_res_masks_local obj_scores_global = obj_scores_local return low_res_masks_global, obj_scores_global def _recondition_masklets( self, frame_idx, det_out: Dict[str, Tensor], trk_id_to_max_iou_high_conf_det: Dict[int, int], # trk_obj_id -> det_idx tracker_states_local: List[Any], tracker_metadata: Dict[str, np.ndarray], tracker_obj_scores_global: Tensor, tracker_low_res_masks_global: Tensor, ): reconditioned_obj_ids = set() HIGH_CONF_THRESH = 0.8 input_mask_res = self.tracker.input_mask_size if len(trk_id_to_max_iou_high_conf_det) == 0: return tracker_states_local, reconditioned_obj_ids # === BATCH ALL INDEX LOOKUPS ON GPU === trk_obj_ids = list(trk_id_to_max_iou_high_conf_det.keys()) det_indices = list(trk_id_to_max_iou_high_conf_det.values()) # Convert obj_ids_all_gpu to tensor once (keep on GPU) obj_ids_all_gpu_t = torch.from_numpy(tracker_metadata["obj_ids_all_gpu"]).to( device=tracker_obj_scores_global.device ) trk_obj_ids_t = torch.tensor( trk_obj_ids, device=tracker_obj_scores_global.device ) det_indices_t = torch.tensor( det_indices, device=tracker_obj_scores_global.device ) # Batched lookup: find obj_idx for each trk_obj_id # Shape: (num_trk, num_all_obj) -> find matching indices matches = trk_obj_ids_t.unsqueeze(1) == obj_ids_all_gpu_t.unsqueeze(0) # (N, M) obj_indices_t = matches.int().argmax(dim=1) # (N,) # Batched score lookup and filtering - NO SYNC until we need CPU decision obj_scores_batch = tracker_obj_scores_global[obj_indices_t].sigmoid() # (N,) high_conf_mask = obj_scores_batch > HIGH_CONF_THRESH # (N,) bool tensor on GPU # === SINGLE SYNC POINT: Transfer filter mask to CPU === high_conf_mask_cpu = high_conf_mask.cpu().numpy() # Filter to only high-confidence items valid_trk_obj_ids = [ tid for tid, valid in zip(trk_obj_ids, high_conf_mask_cpu) if valid ] valid_det_indices = [ did for did, valid in zip(det_indices, high_conf_mask_cpu) if valid ] valid_obj_indices = obj_indices_t[high_conf_mask] # Keep as tensor if len(valid_trk_obj_ids) == 0: return tracker_states_local, reconditioned_obj_ids # === BATCH MASK OPERATIONS === valid_det_indices_t = torch.tensor( valid_det_indices, device=det_out["mask"].device ) # Batch fetch all detection masks at once new_masks = det_out["mask"][valid_det_indices_t] # (K, H, W) new_masks_binary = ( F.interpolate( new_masks.unsqueeze(1), size=(input_mask_res, input_mask_res), mode="bilinear", align_corners=False, ).squeeze(1) > 0 ) # (K, H, W) # Batch update low_res_masks_global old_masks = tracker_low_res_masks_global[valid_obj_indices] # (K, H, W) binary_agreement = (new_masks > 0) == (old_masks > 0) updated_masks = torch.where(binary_agreement, old_masks, new_masks) # Batch hole filling updated_masks = fill_holes_in_mask_scores( updated_masks.unsqueeze(1), fill_hole_area=self.fill_hole_area, sprinkle_removal_area=self.sprinkle_removal_area, fill_holes=True, remove_sprinkles=True, ).squeeze(1) # Write back (scatter) tracker_low_res_masks_global[valid_obj_indices] = updated_masks # === NOW DO THE STATE UPDATES (still needs iteration but with pre-filtered data) === if self.is_multiplex: state_to_recondition_info = {} for i, trk_obj_id in enumerate(valid_trk_obj_ids): for state_idx, inference_state in enumerate(tracker_states_local): if trk_obj_id in inference_state["obj_ids"]: if state_idx not in state_to_recondition_info: state_to_recondition_info[state_idx] = [] state_to_recondition_info[state_idx].append( (trk_obj_id, new_masks_binary[i]) ) break for state_idx, recondition_list in state_to_recondition_info.items(): inference_state = tracker_states_local[state_idx] obj_ids_to_recondition = [item[0] for item in recondition_list] masks_to_recondition = torch.stack( [item[1] for item in recondition_list] ) with torch.profiler.record_function( "_recodition_masklets.add_new_masks" ): self.tracker.add_new_masks( inference_state=inference_state, frame_idx=frame_idx, obj_ids=obj_ids_to_recondition, masks=masks_to_recondition, reconditioning=True, ) reconditioned_obj_ids.update(inference_state["obj_idx_to_id"].values()) else: # Non-multiplex: still iterate but masks already computed for i, trk_obj_id in enumerate(valid_trk_obj_ids): for inference_state in tracker_states_local: if trk_obj_id in inference_state["obj_ids"]: self.tracker.add_new_mask( inference_state=inference_state, frame_idx=frame_idx, obj_id=trk_obj_id, mask=new_masks_binary[i], ) reconditioned_obj_ids.update( inference_state["obj_idx_to_id"].values() ) break return tracker_states_local, reconditioned_obj_ids def _deepcopy(self, x): # If running in prod, dont need to do a deepcopy as we only traverse in 1 direction if True: return x return deepcopy(x) def run_tracker_update_planning_phase( self, frame_idx: int, num_frames: int, reverse: bool, det_out: Dict[str, Tensor], det_keep: Tensor, tracker_low_res_masks_global: Tensor, tracker_obj_scores_global: Tensor, tracker_metadata_prev: Dict[str, np.ndarray], tracker_states_local: List[Any], is_image_only: bool = False, ): # initialize new metadata from previous metadata (its values will be updated later) with torch.profiler.record_function("initialize_tracker_metadata_new"): tracker_metadata_new = self._create_planning_metadata(tracker_metadata_prev) # Initialize reconditioned_obj_ids early to avoid UnboundLocalError reconditioned_obj_ids = set() # Step 1: make the update plan and resolve heuristics on GPU 0 det_mask_preds: Tensor = det_out["mask"] # low-res mask logits det_scores: Tensor = det_out["scores"].float() # a) match FA and SAM2 masks and find new objects with torch.profiler.record_function("associate_det_trk"): adt_result = self._associate_det_trk( det_masks=det_mask_preds, det_scores=det_scores, det_keep=det_keep, trk_masks=tracker_low_res_masks_global, trk_obj_ids=tracker_metadata_prev["obj_ids_all_gpu"], default_det_thresh=( self.image_only_det_thresh if is_image_only else None ), ) # b) handle hotstart heuristics to remove objects (GPU-vectorized, no sync!) # here `rank0_metadata` contains metadata stored on (and only accessible to) GPU 0; # we avoid broadcasting them to other GPUs to save communication cost, assuming # that `rank0_metadata` is not needed by other GPUs rank0_metadata_new = self._deepcopy(tracker_metadata_prev["rank0_metadata"]) if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: # Call GPU-vectorized hotstart using lazy adt_result (NO realize_adt yet!) with torch.profiler.record_function("_process_hotstart_gpu"): to_remove_mask, to_suppress_mask, gpu_metadata_new = ( self._process_hotstart_gpu( frame_idx=frame_idx, reverse=reverse, adt_result=adt_result, # Still lazy - no sync! tracker_metadata_prev=tracker_metadata_prev, gpu_metadata_prev=tracker_metadata_prev["gpu_metadata"], ) ) # IMPORTANT: From this point, tracker_metadata_new["gpu_metadata"] is updated but CPU metadata (obj_ids_all_gpu, etc.) is NOT tracker_metadata_new["gpu_metadata"] = gpu_metadata_new else: # if warm-up is not complete, we don't remove any objects N_obj = tracker_low_res_masks_global.size(0) to_remove_mask = torch.zeros( N_obj, dtype=torch.bool, device=tracker_low_res_masks_global.device ) to_suppress_mask = torch.zeros( N_obj, dtype=torch.bool, device=tracker_low_res_masks_global.device ) tracker_metadata_new["rank0_metadata"] = rank0_metadata_new # Step 3 (optional): recondition masklets based on high-confidence detections before memory encoding # NOTE: Running this in execution phase (after memory encoding) can lead to suboptimal results should_recondition_iou = False # Evaluate tracklets for reconditioning based on bbox IoU mismatch with detections if self.reconstruction_bbox_iou_thresh > 0: adt_result = realize_adt_result( adt_result, tracker_metadata_prev, det_mask_preds ) if ( self.reconstruction_bbox_iou_thresh > 0 and len(adt_result.trk_id_to_max_iou_high_conf_det) > 0 ): with torch.profiler.record_function( "evaluate_reconstruction_bbox_iou_thresh" ): trk_obj_ids = adt_result.trk_id_to_max_iou_high_conf_det.keys() sam2_obj_ids_all_gpu = list(tracker_metadata_prev["obj_ids_all_gpu"]) trk_ids = [ sam2_obj_ids_all_gpu.index(trk_obj_id) for trk_obj_id in trk_obj_ids if trk_obj_id in sam2_obj_ids_all_gpu ] det_ids = list(adt_result.trk_id_to_max_iou_high_conf_det.values()) det_boxes_bbox_iou = det_out["bbox"][det_ids] det_scores_bbox_iou = det_out["scores"][det_ids] sam2_mask = tracker_low_res_masks_global[trk_ids] mask_binary = sam2_mask > 0 sam2_box_pixels = mask_to_box(mask_binary.unsqueeze(1)).squeeze(1) mask_height, mask_width = sam2_mask.shape[-2:] sam2_box_normalized = sam2_box_pixels / torch.tensor( [mask_width, mask_height, mask_width, mask_height], device=sam2_box_pixels.device, ) iou = fast_diag_box_iou(det_boxes_bbox_iou, sam2_box_normalized)[0] if iou < self.reconstruction_bbox_iou_thresh and torch.any( det_scores_bbox_iou >= self.reconstruction_bbox_det_score ): should_recondition_iou = True if ( self.recondition_every_nth_frame > 0 and frame_idx % self.recondition_every_nth_frame == 0 ): adt_result = realize_adt_result( adt_result, tracker_metadata_prev, det_mask_preds ) should_recondition_periodic = ( self.recondition_every_nth_frame > 0 and frame_idx % self.recondition_every_nth_frame == 0 and len(adt_result.trk_id_to_max_iou_high_conf_det) > 0 ) # Recondition if periodic or IoU condition met if should_recondition_periodic or should_recondition_iou: adt_result = realize_adt_result( adt_result, tracker_metadata_prev, det_mask_preds ) # NOTE: sam2_low_res_mask_global is modified in-place on all GPUs. with torch.profiler.record_function("_recondition_masklets"): tracker_states_local, reconditioned_obj_ids = ( self._recondition_masklets( frame_idx, det_out, adt_result.trk_id_to_max_iou_high_conf_det, tracker_states_local, tracker_metadata_prev, tracker_obj_scores_global, tracker_low_res_masks_global, ) ) for state in tracker_states_local: if any( obj_id in reconditioned_obj_ids for obj_id in state.get("obj_ids", []) ): self.tracker.propagate_in_video_preflight( state, run_mem_encoder=True ) # Step 4: Run SAM2 memory encoder on the current frame's prediction masks # This is done on all GPUs batch_size = tracker_low_res_masks_global.size(0) if batch_size > 0: if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: if self.suppress_overlapping_based_on_recent_occlusion_threshold > 0.0: # NOTE: tracker_low_res_masks_global is updated in-place then returned with torch.profiler.record_function( "_suppress_overlapping_based_on_recent_occlusion" ): tracker_low_res_masks_global = ( self._suppress_overlapping_based_on_recent_occlusion( frame_idx, tracker_low_res_masks_global, tracker_metadata_prev, tracker_metadata_new, to_remove_mask, # GPU boolean mask, no sync! reverse, ) ) with torch.profiler.record_function("_tracker_update_memories"): self._tracker_update_memories( tracker_states_local, frame_idx, tracker_metadata=tracker_metadata_prev, low_res_masks=tracker_low_res_masks_global, ) # NOW realize adt_result after memory encoding (sync only for GPU load balancing) adt_result = realize_adt_result( adt_result, tracker_metadata_prev, det_mask_preds ) new_det_obj_ids, new_det_gpu_ids, num_obj_dropped_due_to_limit = ( adt_result.get_new_det_gpu_ids( tracker_metadata_prev, is_image_only, det_scores, self ) ) # Convert GPU removal mask to CPU obj_id set for metadata updates if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: obj_ids_all_gpu = tracker_metadata_prev["obj_ids_all_gpu"] to_remove_cpu = to_remove_mask.cpu().numpy() obj_ids_newly_removed = set(obj_ids_all_gpu[to_remove_cpu].tolist()) else: obj_ids_newly_removed = set() # Step 4: update the SAM2 metadata based on the update plan # note: except for "rank0_metadata" (that is only available on GPU 0), # the updated `tracker_metadata_new` should be identical on all GPUs for rank in range(self.world_size): new_det_obj_ids_this_gpu = new_det_obj_ids[new_det_gpu_ids == rank] updated_obj_ids_this_gpu = tracker_metadata_new["obj_ids_per_gpu"][rank] if len(new_det_obj_ids_this_gpu) > 0: updated_obj_ids_this_gpu = np.concatenate( [updated_obj_ids_this_gpu, new_det_obj_ids_this_gpu] ) if len(obj_ids_newly_removed) > 0: is_removed = np.isin( updated_obj_ids_this_gpu, list(obj_ids_newly_removed) ) updated_obj_ids_this_gpu = updated_obj_ids_this_gpu[~is_removed] tracker_metadata_new["obj_ids_per_gpu"][rank] = updated_obj_ids_this_gpu tracker_metadata_new["num_obj_per_gpu"][rank] = len( updated_obj_ids_this_gpu ) tracker_metadata_new["obj_ids_all_gpu"] = np.concatenate( tracker_metadata_new["obj_ids_per_gpu"] ) # update object scores and the maximum object ID assigned so far if len(new_det_obj_ids) > 0: det_scores_np: np.ndarray = det_scores.cpu().numpy() tracker_metadata_new["obj_id_to_score"].update( zip(new_det_obj_ids, det_scores_np[adt_result.new_det_fa_inds]) ) # sam2 scores are not available for new objects, use det score instead. # Store as GPU tensors for consistency with SAM2 propagation scores new_det_scores_tensor = det_scores[adt_result.new_det_fa_inds] tracker_metadata_new["obj_id_to_sam2_score_frame_wise"][frame_idx].update( zip(new_det_obj_ids, new_det_scores_tensor) ) tracker_metadata_new["max_obj_id"] = max( tracker_metadata_new["max_obj_id"], np.max(new_det_obj_ids), ) # for removed objects, we set their scores to a very low value (-1e4) but still # keep them in "obj_id_to_score" (it's easier to handle outputs this way) for obj_id in obj_ids_newly_removed: tracker_metadata_new["obj_id_to_score"][obj_id] = -1e4 # Store as GPU tensor for consistency tracker_metadata_new["obj_id_to_sam2_score_frame_wise"][frame_idx][ obj_id ] = torch.tensor(-1e4, dtype=torch.float32, device=det_scores.device) tracker_metadata_new["obj_id_to_last_occluded"].pop(obj_id, None) # check that "rank0_metadata" is in tracker_metadata_new if and only if it's GPU 0 assert "rank0_metadata" in tracker_metadata_new if self.masklet_confirmation_enable: with torch.profiler.record_function("update_masklet_confirmation_status"): rank0_metadata = self.update_masklet_confirmation_status( rank0_metadata=tracker_metadata_new["rank0_metadata"], obj_ids_all_gpu_prev=tracker_metadata_prev["obj_ids_all_gpu"], obj_ids_all_gpu_updated=tracker_metadata_new["obj_ids_all_gpu"], det_to_matched_trk_obj_ids=adt_result.det_to_matched_trk_obj_ids, new_det_obj_ids=new_det_obj_ids, ) tracker_metadata_new["rank0_metadata"] = rank0_metadata # Compact GPU metadata NOW (after sync) in preparation for next frame # This removes entries for objects that will be deleted in execution phase # so next frame's _process_hotstart_gpu doesn't need to do sync-inducing compaction if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: if ( "gpu_metadata" in tracker_metadata_new and tracker_metadata_new["gpu_metadata"].get("N_obj", 0) > 0 ): with torch.profiler.record_function("compact_gpu_metadata"): gpu_meta = tracker_metadata_new["gpu_metadata"] removed_mask = gpu_meta[ "removed_mask" ] # (N_obj,) - which objects marked for removal keep_indices = torch.nonzero(~removed_mask, as_tuple=True)[0] gpu_meta["obj_first_frame"] = gpu_meta["obj_first_frame"][ keep_indices ] gpu_meta["consecutive_unmatch_count"] = gpu_meta[ "consecutive_unmatch_count" ][keep_indices] gpu_meta["trk_keep_alive"] = gpu_meta["trk_keep_alive"][ keep_indices ] gpu_meta["removed_mask"] = gpu_meta["removed_mask"][ keep_indices ] # Should be all False gpu_meta["last_occluded_tensor"] = gpu_meta["last_occluded_tensor"][ keep_indices ] # Compact pairwise matrix (remove both rows and columns) overlap_counts = gpu_meta["overlap_pair_counts"] overlap_counts = overlap_counts[keep_indices][:, keep_indices] gpu_meta["overlap_pair_counts"] = overlap_counts # Update N_obj to reflect post-removal count gpu_meta["N_obj"] = keep_indices.size(0) # After compaction, extend gpu_metadata with new objects' initial values # This ensures obj_first_frame is set to the detection frame, not propagation frame num_new = len(new_det_obj_ids) if num_new > 0: with torch.profiler.record_function( "extend_gpu_metadata_for_new_objects" ): gpu_meta = tracker_metadata_new["gpu_metadata"] device = det_scores.device NEVER_OCCLUDED = -1 # Extend all metadata tensors for new objects gpu_meta["obj_first_frame"] = torch.cat( [ gpu_meta.get( "obj_first_frame", torch.empty(0, dtype=torch.long, device=device), ), torch.full( (num_new,), frame_idx, dtype=torch.long, device=device ), ] ) gpu_meta["consecutive_unmatch_count"] = torch.cat( [ gpu_meta.get( "consecutive_unmatch_count", torch.empty(0, dtype=torch.long, device=device), ), torch.zeros(num_new, dtype=torch.long, device=device), ] ) gpu_meta["trk_keep_alive"] = torch.cat( [ gpu_meta.get( "trk_keep_alive", torch.empty(0, dtype=torch.long, device=device), ), torch.full( (num_new,), self.init_trk_keep_alive, dtype=torch.long, device=device, ), ] ) gpu_meta["removed_mask"] = torch.cat( [ gpu_meta.get( "removed_mask", torch.empty(0, dtype=torch.bool, device=device), ), torch.zeros(num_new, dtype=torch.bool, device=device), ] ) gpu_meta["last_occluded_tensor"] = torch.cat( [ gpu_meta.get( "last_occluded_tensor", torch.empty(0, dtype=torch.long, device=device), ), torch.full( (num_new,), NEVER_OCCLUDED, dtype=torch.long, device=device, ), ] ) # Grow overlap matrix old_N = gpu_meta.get("N_obj", 0) new_N = old_N + num_new old_overlap = gpu_meta.get( "overlap_pair_counts", torch.zeros((0, 0), dtype=torch.long, device=device), ) new_overlap = torch.zeros( (new_N, new_N), dtype=torch.long, device=device ) if old_N > 0: new_overlap[:old_N, :old_N] = old_overlap gpu_meta["overlap_pair_counts"] = new_overlap gpu_meta["N_obj"] = new_N sam2_update_plan = { "new_det_fa_inds": adt_result.new_det_fa_inds, # np.ndarray "new_det_obj_ids": new_det_obj_ids, # np.ndarray "new_det_gpu_ids": new_det_gpu_ids, # np.ndarray "unmatched_trk_obj_ids": adt_result.unmatched_trk_obj_ids, # np.ndarray "det_to_matched_trk_obj_ids": adt_result.det_to_matched_trk_obj_ids, # dict "obj_ids_newly_removed": obj_ids_newly_removed, # set "num_obj_dropped_due_to_limit": num_obj_dropped_due_to_limit, # int "trk_id_to_max_iou_high_conf_det": adt_result.trk_id_to_max_iou_high_conf_det, # dict "reconditioned_obj_ids": reconditioned_obj_ids, # set } return sam2_update_plan, tracker_metadata_new def _suppress_overlapping_based_on_recent_occlusion( self, frame_idx: int, tracker_low_res_masks_global: Tensor, tracker_metadata_prev: Dict[str, Any], tracker_metadata_new: Dict[str, Any], to_remove_mask: Tensor, # GPU boolean mask (N_obj,) instead of CPU set reverse: bool = False, ): """ Suppress overlapping masks based on the most recent occlusion information. If an object is removed by hotstart, we always suppress it if it overlaps with any other object. Args: frame_idx (int): The current frame index. tracker_low_res_masks_global (Tensor): The low-resolution masks for the current frame. tracker_metadata_prev (Dict[str, Any]): The metadata from the previous frame. tracker_metadata_new (Dict[str, Any]): The metadata for the current frame (with updated gpu_metadata from _process_hotstart_gpu). to_remove_mask (Tensor): GPU boolean mask (N_obj,) indicating which objects are removed. Return: Tensor: The updated low-resolution masks with some objects suppressed. """ # NOTE: obj_ids_global is only used for debug logging, so we can use prev (it won't match perfectly but close enough for debugging) # The actual suppression logic uses GPU tensors which ARE in the correct index space from tracker_metadata_new obj_ids_global = tracker_metadata_prev["obj_ids_all_gpu"] binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0 batch_size = tracker_low_res_masks_global.size(0) num_ids = len(obj_ids_global) # immediately to force proper debugging. (Aligned with merge decision 4.5.2) assert batch_size == num_ids, ( f"Mask/metadata count mismatch in _suppress_overlapping: " f"batch_size={batch_size}, num_ids={num_ids}, frame_idx={frame_idx}" ) binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0 if batch_size > 0: assert len(obj_ids_global) == batch_size, ( f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}" ) NEVER_OCCLUDED = -1 ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic # GPU-vectorized: Build last_occluded_prev tensor without iteration/syncs device = binary_tracker_low_res_masks_global.device # Get last_occluded from UPDATED gpu_metadata (already in correct index space from _process_hotstart_gpu) gpu_metadata_new = tracker_metadata_new["gpu_metadata"] last_occluded_prev = gpu_metadata_new["last_occluded_tensor"] # Sanity check: ensure last_occluded_tensor is in sync with batch_size assert last_occluded_prev.size(0) == batch_size, ( f"last_occluded_tensor size mismatch: {last_occluded_prev.size(0)} vs {batch_size}. " f"This indicates gpu_metadata tensors are out of sync." ) # Set ALWAYS_OCCLUDED for removed objects (fully vectorized, no sync!) last_occluded_prev = torch.where( to_remove_mask, torch.full_like(last_occluded_prev, ALWAYS_OCCLUDED), last_occluded_prev, ) to_suppress = self._get_objects_to_suppress_based_on_most_recently_occluded( binary_tracker_low_res_masks_global, last_occluded_prev, obj_ids_global, frame_idx, reverse, ) # Update metadata with occlusion information (fully vectorized) is_obj_occluded = ~(binary_tracker_low_res_masks_global.any(dim=(-1, -2))) is_obj_occluded_or_suppressed = is_obj_occluded | to_suppress last_occluded_new = last_occluded_prev.clone() last_occluded_new[is_obj_occluded_or_suppressed] = frame_idx # Store in gpu_metadata to keep it aligned with other metadata tensors tracker_metadata_new["gpu_metadata"]["last_occluded_tensor"] = ( last_occluded_new ) # Also maintain legacy dict format for backwards compatibility # This conversion happens on CPU AFTER memory encoding, not in critical path tracker_metadata_new[ "obj_id_to_last_occluded" ] = {} # Will be populated later if needed # Zero out suppressed masks before memory encoding NO_OBJ_LOGIT = -10 tracker_low_res_masks_global[to_suppress] = NO_OBJ_LOGIT return tracker_low_res_masks_global def _create_planning_metadata(self, tracker_metadata_prev): """Extend planning metadata with multiplex-specific fields.""" metadata = super()._create_planning_metadata(tracker_metadata_prev) if self.is_multiplex: metadata["num_buc_per_gpu"] = self._deepcopy( tracker_metadata_prev["num_buc_per_gpu"] ) metadata["gpu_metadata"] = tracker_metadata_prev["gpu_metadata"] return metadata def _post_execution_phase_hook(self, tracker_states_local, tracker_metadata_new): """Update bucket count after execution phase (multiplex-specific).""" if self.is_multiplex and tracker_metadata_new is not None: actual_bucket_count = self._count_buckets_in_states(tracker_states_local) tracker_metadata_new["num_buc_per_gpu"][self.rank] = actual_bucket_count def _count_buckets_in_states(self, tracker_states_local: List[Any]) -> int: """Count the total number of buckets across all states.""" if not self.is_multiplex: return 0 total_buckets = 0 for state in tracker_states_local: if "multiplex_state" in state: total_buckets += state["multiplex_state"].num_buckets return total_buckets def build_outputs( self, frame_idx: int, num_frames: int, reverse: bool, det_out: Dict[ str, Tensor ], # TODO: Only det_out["mask"][new_det_fa_inds_local_t] is needed tracker_low_res_masks_global: Tensor, tracker_obj_scores_global: Tensor, tracker_metadata_prev: Dict[str, np.ndarray], sam2_update_plan: Dict[str, np.ndarray], orig_vid_height: int, orig_vid_width: int, reconditioned_obj_ids: set = None, det_to_matched_trk_obj_ids: dict = None, ): new_det_fa_inds: np.ndarray = sam2_update_plan["new_det_fa_inds"] new_det_obj_ids: np.ndarray = sam2_update_plan["new_det_obj_ids"] obj_id_to_mask = {} # obj_id --> output mask tensor # Part 1: masks from previous SAM2 propagation # Align IDs and masks from previous SAM2 propagation existing_masklet_obj_ids_all = tracker_metadata_prev["obj_ids_all_gpu"] existing_masklet_obj_ids_per_gpu = np.concatenate( tracker_metadata_prev["obj_ids_per_gpu"] ) use_per_gpu_ids = len(existing_masklet_obj_ids_per_gpu) != len( existing_masklet_obj_ids_all ) or not np.array_equal( existing_masklet_obj_ids_per_gpu, existing_masklet_obj_ids_all ) existing_masklet_obj_ids = ( existing_masklet_obj_ids_per_gpu if use_per_gpu_ids else existing_masklet_obj_ids_all ) existing_masklet_video_res_masks = F.interpolate( tracker_low_res_masks_global.unsqueeze(1), size=(orig_vid_height, orig_vid_width), mode="bilinear", align_corners=False, ) # (num_obj, 1, H_video, W_video) # Pad/truncate masks to match metadata count num_masks = existing_masklet_video_res_masks.size(0) num_ids = len(existing_masklet_obj_ids) if num_masks != num_ids: if num_masks < num_ids: pad = existing_masklet_video_res_masks.new_zeros( (num_ids - num_masks, 1, orig_vid_height, orig_vid_width) ) existing_masklet_video_res_masks = torch.cat( [existing_masklet_video_res_masks, pad], dim=0 ) else: existing_masklet_video_res_masks = existing_masklet_video_res_masks[ :num_ids ] existing_masklet_binary = existing_masklet_video_res_masks > 0 for obj_id, mask in zip(existing_masklet_obj_ids, existing_masklet_binary): obj_id_to_mask[obj_id] = mask # (1, H_video, W_video) # Part 2: masks from new detections new_det_fa_inds_t = torch.from_numpy(new_det_fa_inds) new_det_low_res_masks = det_out["mask"][new_det_fa_inds_t].unsqueeze(1) new_det_low_res_masks = fill_holes_in_mask_scores( new_det_low_res_masks, fill_hole_area=self.fill_hole_area, sprinkle_removal_area=self.sprinkle_removal_area, fill_holes=True, remove_sprinkles=True, ) new_masklet_video_res_masks = F.interpolate( new_det_low_res_masks, size=(orig_vid_height, orig_vid_width), mode="bilinear", align_corners=False, ) # (num_obj, 1, H_video, W_video) new_masklet_binary = new_masklet_video_res_masks > 0 assert len(new_det_obj_ids) == len(new_masklet_video_res_masks) for obj_id, mask in zip(new_det_obj_ids, new_masklet_binary): obj_id_to_mask[obj_id] = mask # (1, H_video, W_video) return obj_id_to_mask def _get_objects_to_suppress_based_on_most_recently_occluded( self, binary_low_res_masks: Tensor, last_occluded: Tensor, # GPU tensor (N_obj,) with frame indices obj_ids: np.ndarray, # numpy array of object IDs frame_idx: int = None, reverse: bool = False, ): # Suppress overlapping masks for objects that were most recently occluded assert binary_low_res_masks.dtype == torch.bool, ( f"Expected boolean tensor, got {binary_low_res_masks.dtype}" ) to_suppress = torch.zeros( binary_low_res_masks.size(0), device=binary_low_res_masks.device, dtype=torch.bool, ) if len(obj_ids) <= 1: return to_suppress iou = mask_iou(binary_low_res_masks, binary_low_res_masks) # [N,N] # Create masks for upper triangular matrix (i < j) and IoU threshold mask_iou_thresh = ( iou >= self.suppress_overlapping_based_on_recent_occlusion_threshold ) overlapping_pairs = torch.triu(mask_iou_thresh, diagonal=1) # [N,N] last_occ_expanded_i = last_occluded.unsqueeze(1) # (N, 1) last_occ_expanded_j = last_occluded.unsqueeze(0) # (1, N) cmp_op = torch.gt if not reverse else torch.lt if self.allow_unoccluded_to_suppress: # Suppress most recently occluded suppress_i_mask = overlapping_pairs & cmp_op( last_occ_expanded_i, last_occ_expanded_j ) suppress_j_mask = overlapping_pairs & cmp_op( last_occ_expanded_j, last_occ_expanded_i ) else: # Suppress most recently occluded suppress_i_mask = ( overlapping_pairs & cmp_op( last_occ_expanded_i, last_occ_expanded_j ) # (last_occ_expanded_i > last_occ_expanded_j) & (last_occ_expanded_j > -1) # j can suppress i only if j was previously occluded ) suppress_j_mask = ( overlapping_pairs & cmp_op(last_occ_expanded_j, last_occ_expanded_i) & ( last_occ_expanded_i > -1 ) # i can suppress j only if i was previously occluded ) # Apply suppression to_suppress = suppress_i_mask.any(dim=1) | suppress_j_mask.any(dim=0) # Log for debugging if ( self.rank == 0 and logger.isEnabledFor(logging.DEBUG) and frame_idx is not None ): suppress_i_mask = suppress_i_mask.cpu().numpy() suppress_j_mask = suppress_j_mask.cpu().numpy() last_occluded = last_occluded.cpu().numpy() # Find all suppression pairs without using torch.where batch_size = suppress_i_mask.shape[0] # Log i-suppression cases (where i gets suppressed in favor of j) for i in range(batch_size): for j in range(batch_size): if suppress_i_mask[i, j]: logger.debug( f"{frame_idx=}: Suppressing obj {obj_ids[i]} last occluded {last_occluded[i]} in favor of {obj_ids[j]} last occluded {last_occluded[j]}" ) # Log j-suppression cases (where j gets suppressed in favor of i) for i in range(batch_size): for j in range(batch_size): if suppress_j_mask[i, j]: logger.debug( f"{frame_idx=}: Suppressing obj {obj_ids[j]} last occluded {last_occluded[j]} in favor of {obj_ids[i]} last occluded {last_occluded[i]}" ) return to_suppress def _propogate_tracker_one_frame_local_gpu( self, inference_states: List[Any], frame_idx: int, reverse: bool, # by default, we disable memory encoding until we gather all outputs run_mem_encoder: bool = False, # When specified, only return masks/scores for these object ids filter_obj_ids: Optional[List[int]] = None, ): """ inference_states: List of inference states, each state corresponds to a different set of objects. """ obj_ids_local = [] low_res_masks_list = [] obj_scores_list = [] for inference_state in inference_states: if len(inference_state["obj_ids"]) == 0: continue # skip propagation on empty inference states # propagate one frame num_frames_propagated = 0 with torch.profiler.record_function("sam2_predictor.propagate_in_video"): for out in self.tracker.propagate_in_video( inference_state, start_frame_idx=frame_idx, # end_frame_idx = start_frame_idx + max_frame_num_to_track # (i.e. propagating 1 frame since end_frame_idx is inclusive) max_frame_num_to_track=0, reverse=reverse, tqdm_disable=True, run_mem_encoder=run_mem_encoder, ): # TODO we only need low-res outputs here for all-gather across GPUs, # so we can remove the high-res interpolation in `propagate_in_video` out_frame_idx, out_obj_ids, out_low_res_masks, _, out_obj_scores = ( out ) num_frames_propagated += 1 # only 1 frames should be propagated assert num_frames_propagated == 1 and out_frame_idx == frame_idx, ( f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}" ) assert isinstance(out_obj_ids, list) # Optionally filter to a subset of object ids (for partial propagation). # We also clamp indices to available rows to avoid CUDA index_select assertions. if filter_obj_ids is not None: if len(out_obj_ids) > 0: max_mask_rows = out_low_res_masks.shape[0] max_score_rows = out_obj_scores.shape[0] # Special case: common single-object refinement path where SAM2 returns a single mask row # but a longer out_obj_ids list for the state. Treat the lone row as the requested object. if ( len(filter_obj_ids) == 1 and max_mask_rows == 1 and max_score_rows == 1 ): out_obj_ids = [filter_obj_ids[0]] keep_indices = [0] else: keep_indices = [ i for i, oid in enumerate(out_obj_ids) if oid in filter_obj_ids and i < max_mask_rows and i < max_score_rows ] else: keep_indices = [] if len(keep_indices) > 0: idx_tensor = torch.as_tensor( keep_indices, device=out_low_res_masks.device, dtype=torch.long ) out_low_res_masks = out_low_res_masks.index_select( dim=0, index=idx_tensor ) out_obj_scores = out_obj_scores.index_select( dim=0, index=idx_tensor ) out_obj_ids = [out_obj_ids[i] for i in keep_indices] else: # no selected objects in this local state; skip appending out_obj_ids = [] if len(out_obj_ids) > 0: obj_ids_local.extend(out_obj_ids) low_res_masks_list.append(out_low_res_masks.squeeze(1)) obj_scores_list.append(out_obj_scores.squeeze(1)) # concatenate the output masklets from all local inference states with torch.profiler.record_function( "sam2_predictor.propagate_in_video.fill_holes" ): H_mask = W_mask = self.tracker.low_res_mask_size if len(low_res_masks_list) > 0: low_res_masks_local = torch.cat(low_res_masks_list, dim=0) obj_scores_local = torch.cat(obj_scores_list, dim=0) assert low_res_masks_local.shape[1:] == (H_mask, W_mask) # Apply hole filling to the masks low_res_masks_local = fill_holes_in_mask_scores( low_res_masks_local.unsqueeze(1), fill_hole_area=self.fill_hole_area, sprinkle_removal_area=self.sprinkle_removal_area, fill_holes=True, remove_sprinkles=True, ) low_res_masks_local = low_res_masks_local.squeeze(1) else: low_res_masks_local = torch.zeros(0, H_mask, W_mask, device=self.device) obj_scores_local = torch.zeros(0, device=self.device) if self.is_multiplex and self.tracker.is_multiplex_dynamic: # obj_ids_local might not be sorted, which is problematic because # the rest of the code assumes that they are. # Currently this only happens in the dynamic multiplex setting (since we backfill states) # so we only check for this condition here, but this should be generally applicable. # Note that a similar remapping is necessary when we update the memory, e.g., # in _tracker_update_memories if obj_ids_local != sorted(obj_ids_local): # Get sorting permutation sort_indices = sorted( range(len(obj_ids_local)), key=lambda i: obj_ids_local[i] ) # Apply permutation to reorder everything obj_ids_local = [obj_ids_local[i] for i in sort_indices] low_res_masks_local = low_res_masks_local[sort_indices] obj_scores_local = obj_scores_local[sort_indices] if self.is_multiplex and self.tracker.is_multiplex_dynamic: # obj_ids_local might not be sorted, which is problematic because # the rest of the code assumes that they are. # Currently this only happens in the dynamic multiplex setting (since we backfill states) # so we only check for this condition here, but this should be generally applicable. # Note that a similar remapping is necessary when we update the memory, e.g., # in _tracker_update_memories if obj_ids_local != sorted(obj_ids_local): # Get sorting permutation sort_indices = sorted( range(len(obj_ids_local)), key=lambda i: obj_ids_local[i] ) # Apply permutation to reorder everything obj_ids_local = [obj_ids_local[i] for i in sort_indices] if low_res_masks_local.shape[0] == len(sort_indices): low_res_masks_local = low_res_masks_local[sort_indices] obj_scores_local = obj_scores_local[sort_indices] return obj_ids_local, low_res_masks_local, obj_scores_local def _associate_det_trk( self, det_masks: Tensor, det_scores: Tensor, det_keep: Tensor, trk_masks: Tensor, trk_obj_ids: np.ndarray, default_det_thresh: Optional[float] = None, ): """ Match detections on the current frame with the existing masklets. Args: - det_masks: (N, H, W) tensor of predicted masks - det_scores: (N,) array of detection scores - trk_masks: (M, H, W) tensor of track masks - trk_obj_ids: (M,) array of object IDs corresponding to trk_masks Returns: - new_det_fa_inds: array of new object indices among in FA detection outputs - unmatched_trk_obj_ids: array of existing masklet object IDs that are not matched to any detections on this frame (for unmatched, we only count masklets with >0 area) - det_to_matched_trk_obj_ids: dict[int, np.ndarray]: mapping from FA detection indices to the list of matched tracklet object IDs - empty_trk_obj_ids: array of existing masklet object IDs with zero area in SAM2 prediction """ HIGH_CONF_THRESH = 0.8 iou_threshold = self.assoc_iou_thresh iou_threshold_trk = self.trk_assoc_iou_thresh new_det_thresh = ( self.new_det_thresh if default_det_thresh is None else default_det_thresh ) assert det_masks.is_floating_point(), "float tensor expected (do not binarize)" assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)" assert trk_masks.size(0) == len(trk_obj_ids), ( f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}" ) if trk_masks.size(0) == 0: with torch.profiler.record_function("No tracklets"): num_trk = 0 is_new_det = det_scores >= new_det_thresh trk_is_unmatched = torch.zeros( num_trk, dtype=torch.bool, device=det_scores.device ) trk_is_nonempty = torch.zeros( num_trk, dtype=torch.bool, device=det_scores.device ) num_det = det_scores.shape[0] det_to_max_iou_trk_idx = torch.full( (num_det,), -1, dtype=torch.long, device=det_scores.device ) det_is_high_conf = det_scores >= HIGH_CONF_THRESH det_is_high_iou = torch.zeros( num_det, dtype=torch.bool, device=det_scores.device ) im_mask = torch.zeros( num_det, num_trk, dtype=torch.bool, device=det_scores.device ) return LazyAssociateDetTrkResult( trk_is_unmatched, trk_is_nonempty, is_new_det, det_to_max_iou_trk_idx, det_is_high_conf, det_is_high_iou, det_keep, im_mask, ) elif det_masks.size(0) == 0: with torch.profiler.record_function("No detections"): assert det_keep.size(0) == 0 # Make sure the keep mask agrees trk_is_nonempty = (trk_masks > 0).any(dim=(1, 2)) num_det = 0 num_trk = trk_masks.shape[0] trk_is_unmatched = torch.ones( num_trk, dtype=torch.bool, device=trk_masks.device ) trk_is_nonempty_tensor = trk_is_nonempty.to(trk_masks.device) is_new_det = torch.zeros( num_det, dtype=torch.bool, device=trk_masks.device ) det_to_max_iou_trk_idx = torch.full( (num_det,), -1, dtype=torch.long, device=trk_masks.device ) det_is_high_conf = torch.zeros( num_det, dtype=torch.bool, device=trk_masks.device ) det_is_high_iou = torch.zeros( num_det, dtype=torch.bool, device=trk_masks.device ) im_mask = torch.zeros( num_det, num_trk, dtype=torch.bool, device=trk_masks.device ) return LazyAssociateDetTrkResult( trk_is_unmatched, trk_is_nonempty_tensor, is_new_det, det_to_max_iou_trk_idx, det_is_high_conf, det_is_high_iou, det_keep, im_mask, ) if det_masks.shape[-2:] != trk_masks.shape[-2:]: # resize to the smaller size to save GPU memory if np.prod(det_masks.shape[-2:]) < np.prod(trk_masks.shape[-2:]): trk_masks = F.interpolate( trk_masks.unsqueeze(1), size=det_masks.shape[-2:], mode="bilinear", align_corners=False, ).squeeze(1) else: # resize detections to track size det_masks = F.interpolate( det_masks.unsqueeze(1), size=trk_masks.shape[-2:], mode="bilinear", align_corners=False, ).squeeze(1) with torch.profiler.record_function("associate_det_trk_compilable"): if trk_masks.shape[0] < self.max_num_objects: padding_size = self.max_num_objects - trk_masks.shape[0] trk_masks_padded = torch.cat( [ trk_masks, torch.zeros( padding_size, *trk_masks.shape[1:], device=trk_masks.device, dtype=trk_masks.dtype, ), ], dim=0, ) else: trk_masks_padded = trk_masks result = _associate_det_trk_compilable( det_masks, det_scores, det_keep, trk_masks_padded, new_det_thresh, iou_threshold_trk, iou_threshold, HIGH_CONF_THRESH, self.use_iom_recondition, self.o2o_matching_masklets_enable, self.iom_thresh_recondition, self.iou_thresh_recondition, ) ( trk_is_unmatched, trk_is_nonempty, is_new_det, det_to_max_iou_trk_idx, det_is_high_conf, det_is_high_iou, det_keep, im_mask, ) = result trk_is_unmatched = trk_is_unmatched[: trk_masks.shape[0]] trk_is_nonempty = trk_is_nonempty[: trk_masks.shape[0]] im_mask = im_mask[:, : trk_masks.shape[0]] return LazyAssociateDetTrkResult( trk_is_unmatched, trk_is_nonempty, is_new_det, det_to_max_iou_trk_idx, det_is_high_conf, det_is_high_iou, det_keep, im_mask, ) def _assign_new_det_to_gpus(self, new_det_num, prev_workload_per_gpu): """Distribute the new objects to the GPUs with the least workload.""" workload_per_gpu: np.ndarray = prev_workload_per_gpu.copy() new_det_gpu_ids = np.zeros(new_det_num, np.int64) if self.is_multiplex: # assign the objects in a batch of multiplex_count for i in range(0, new_det_num, self.bucket_capacity): # find the GPU with the least workload min_gpu = np.argmin(workload_per_gpu) new_det_gpu_ids[i : i + self.bucket_capacity] = min_gpu workload_per_gpu[min_gpu] += 1 else: # assign the objects one by one for i in range(len(new_det_gpu_ids)): # find the GPU with the least workload min_gpu = np.argmin(workload_per_gpu) new_det_gpu_ids[i] = min_gpu workload_per_gpu[min_gpu] += 1 return new_det_gpu_ids def _process_hotstart_gpu( self, frame_idx: int, reverse: bool, adt_result, # LazyAssociateDetTrkResult (always lazy now) tracker_metadata_prev: Dict[str, Any], gpu_metadata_prev: Dict[str, Tensor], ) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]: """ Compute removal/suppression masks entirely on GPU without ANY syncs or branches. Uses position-indexed metadata (indexed 0 to N_obj-1) instead of obj_id-indexed to avoid needing obj_ids as GPU tensor. Returns: to_remove: boolean tensor (N_obj,) - objects to remove this frame to_suppress: boolean tensor (N_obj,) - objec ts to suppress (overlap suppression) gpu_metadata_new: updated GPU metadata for next frame """ # Handle edge case: if adt_result is already realized (no tracks exist), # return empty masks since there's nothing to remove if isinstance(adt_result, RealizedAssociateDetTrkresult): # No tracks exist, so no objects to remove/suppress empty_mask = torch.zeros(0, dtype=torch.bool, device=self.device) return empty_mask, empty_mask, {"N_obj": 0} device = adt_result.trk_is_unmatched.device N_obj = adt_result.trk_is_unmatched.size(0) # Number of current objects # ============================================================================ # STEP 1: Initialize/extract position-indexed GPU metadata # ============================================================================ # All metadata tensors are indexed by POSITION (0 to N_obj-1), not by obj_id # This grows/shrinks each frame as objects are added/removed # Get previous frame's metadata (sized for previous N_obj) # NOTE: Metadata is already compacted from previous frame (removed objects are already filtered out) prev_N_obj = gpu_metadata_prev.get("N_obj", 0) if prev_N_obj > 0: # Metadata from previous frame (position-indexed, already compacted) obj_first_frame_prev = gpu_metadata_prev["obj_first_frame"] # (prev_N_obj,) consecutive_unmatch_count_prev = gpu_metadata_prev[ "consecutive_unmatch_count" ] # (prev_N_obj,) trk_keep_alive_prev = gpu_metadata_prev["trk_keep_alive"] # (prev_N_obj,) removed_mask_prev = gpu_metadata_prev[ "removed_mask" ] # (prev_N_obj,) - should be all False after compaction overlap_pair_counts_prev = gpu_metadata_prev[ "overlap_pair_counts" ] # (prev_N_obj, prev_N_obj) last_occluded_prev = gpu_metadata_prev[ "last_occluded_tensor" ] # (prev_N_obj,) else: # First frame - no previous metadata obj_first_frame_prev = None consecutive_unmatch_count_prev = None trk_keep_alive_prev = None removed_mask_prev = None overlap_pair_counts_prev = None last_occluded_prev = None # ============================================================================ # STEP 2: Carry forward metadata from previous frame # ============================================================================ # Current frame has N_obj objects (from propagation) # New objects are added via extend_gpu_metadata_for_new_objects AFTER compaction, # so prev_N_obj should already include objects detected on previous frame. # N_obj should equal prev_N_obj (no new objects mid-planning-phase). assert N_obj == prev_N_obj, ( f"N_obj ({N_obj}) should equal prev_N_obj ({prev_N_obj}); new objects handled after compaction" ) # Carry forward existing metadata (or initialize if first frame) NEVER_OCCLUDED = -1 obj_first_frame = ( obj_first_frame_prev if obj_first_frame_prev is not None else torch.full((N_obj,), frame_idx, dtype=torch.long, device=device) ) consecutive_unmatch_count = ( consecutive_unmatch_count_prev if consecutive_unmatch_count_prev is not None else torch.zeros(N_obj, dtype=torch.long, device=device) ) trk_keep_alive = ( trk_keep_alive_prev if trk_keep_alive_prev is not None else torch.zeros(N_obj, dtype=torch.long, device=device) ) removed_mask = ( removed_mask_prev if removed_mask_prev is not None else torch.zeros(N_obj, dtype=torch.bool, device=device) ) overlap_pair_counts = ( overlap_pair_counts_prev if overlap_pair_counts_prev is not None else torch.zeros((N_obj, N_obj), dtype=torch.long, device=device) ) last_occluded = ( last_occluded_prev if last_occluded_prev is not None else torch.full((N_obj,), NEVER_OCCLUDED, dtype=torch.long, device=device) ) # ============================================================================ # STEP 3: Update keep-alive counters (fully vectorized) # ============================================================================ # Determine which tracks are matched by ANY detection trk_is_matched = adt_result.im_mask.any(dim=0) # (N_obj,) # Update: +1 for matched, -1 for unmatched, clamp to [min, max] trk_keep_alive = torch.where( trk_is_matched, trk_keep_alive + 1, trk_keep_alive - 1 ) trk_keep_alive = torch.clamp( trk_keep_alive, min=self.min_trk_keep_alive, max=self.max_trk_keep_alive ) # Also decrement for empty tracklets (if configured) if self.decrease_trk_keep_alive_for_empty_masklets: trk_keep_alive = torch.where( ~adt_result.trk_is_nonempty, torch.clamp(trk_keep_alive - 1, min=self.min_trk_keep_alive), trk_keep_alive, ) # ============================================================================ # STEP 4: Update total unmatch counters (fully vectorized) # ============================================================================ # Increment for unmatched, but DON'T reset for matched # Original logic accumulates total unmatched frames, not consecutive consecutive_unmatch_count = torch.where( adt_result.trk_is_unmatched, consecutive_unmatch_count + 1, consecutive_unmatch_count, # Keep previous value, don't reset ) # ============================================================================ # STEP 5: Update pairwise overlap tracking (fully vectorized) # ============================================================================ # Find detections matched by multiple tracks tracks_per_det = adt_result.im_mask.sum(dim=1) # (N_det,) multi_match_mask = tracks_per_det > 1 # (N_det,) # Build overlap increment matrix using einsum multi_match_tracks = adt_result.im_mask & multi_match_mask.unsqueeze( 1 ) # (N_det, N_obj) # Compute pairwise overlaps: for each detection, outer product of matched tracks pairwise_overlap_this_frame = torch.einsum( "di,dj->dij", multi_match_tracks.float(), multi_match_tracks.float() ) # (N_det, N_obj, N_obj) # Sum across detections overlap_increment = pairwise_overlap_this_frame.sum(dim=0) # (N_obj, N_obj) overlap_increment.fill_diagonal_(0) # No self-overlap overlap_increment = torch.triu( overlap_increment, diagonal=1 ) # Upper triangle only # Add this frame's increments (accumulate across frames, don't reset) # Original logic: overlap_pair_to_frame_inds[key].append(frame_idx) - never clears overlap_pair_counts = overlap_pair_counts + overlap_increment.long() # ============================================================================ # STEP 6: Compute removal decisions - UNMATCH criterion (fully vectorized) # ============================================================================ # Hotstart boundary hotstart_diff = ( frame_idx - self.hotstart_delay if not reverse else frame_idx + self.hotstart_delay ) # Check if objects are within hotstart window is_within_hotstart = ( (obj_first_frame > hotstart_diff) if not reverse else (obj_first_frame < hotstart_diff) ) # Remove if: within hotstart AND unmatched >= threshold AND not already removed remove_by_unmatch = ( is_within_hotstart & (consecutive_unmatch_count >= self.hotstart_unmatch_thresh) & ~removed_mask ) # Suppress if: keep_alive <= 0 AND not hotstart-only mode AND not removed suppress_by_unmatch = ( (trk_keep_alive <= 0) & torch.tensor(not self.suppress_unmatched_only_within_hotstart, device="cpu") .pin_memory() .to(device=device, non_blocking=True) & ~removed_mask & ~remove_by_unmatch ) # ============================================================================ # STEP 7: Compute removal decisions - OVERLAP criterion (fully vectorized) # ============================================================================ # For each object, find max overlap count with any EARLIER object # "Earlier" = appeared in an earlier frame # Build matrix: is_earlier[i, j] = True if object i appeared before object j first_frames_i = obj_first_frame.unsqueeze(1) # (N_obj, 1) first_frames_j = obj_first_frame.unsqueeze(0) # (1, N_obj) if not reverse: is_earlier_matrix = first_frames_i < first_frames_j # (N_obj, N_obj) else: is_earlier_matrix = first_frames_i > first_frames_j # (N_obj, N_obj) # ============================================================================ # STEP 8: Combine removal/suppression decisions # ============================================================================ # Mask overlap counts to only consider earlier objects if N_obj == 0: to_remove = remove_by_unmatch else: overlap_with_earlier = torch.where( is_earlier_matrix, overlap_pair_counts, torch.zeros_like(overlap_pair_counts), ) # For each object (column j), find max overlap with any earlier object (row i) max_overlap_with_earlier, _ = overlap_with_earlier.max(dim=0) # (N_obj,) # Remove if: within hotstart AND overlapped with earlier >= threshold remove_by_overlap = ( is_within_hotstart & (max_overlap_with_earlier >= self.hotstart_dup_thresh) & ~removed_mask ) to_remove = remove_by_unmatch | remove_by_overlap # (N_obj,) to_suppress = suppress_by_unmatch # (N_obj,) # Update removed mask for future frames removed_mask = removed_mask | to_remove # ============================================================================ # STEP 9: Package updated metadata (NO SYNCS) # ============================================================================ gpu_metadata_new = { "N_obj": N_obj, "obj_first_frame": obj_first_frame, "consecutive_unmatch_count": consecutive_unmatch_count, "trk_keep_alive": trk_keep_alive, "removed_mask": removed_mask, "overlap_pair_counts": overlap_pair_counts, "last_occluded_tensor": last_occluded, } return to_remove, to_suppress, gpu_metadata_new def _process_hotstart( self, frame_idx: int, num_frames: int, reverse: bool, det_to_matched_trk_obj_ids: Dict[int, np.ndarray], new_det_obj_ids: np.ndarray, empty_trk_obj_ids: np.ndarray, unmatched_trk_obj_ids: np.ndarray, rank0_metadata: Dict[str, Any], tracker_metadata: Dict[str, Any], ): """Handle hotstart heuristics to remove unmatched or duplicated objects.""" # obj_id --> first frame index where the object was detected obj_first_frame_idx = rank0_metadata["obj_first_frame_idx"] # obj_id --> [mismatched frame indices] unmatched_frame_inds = rank0_metadata["unmatched_frame_inds"] trk_keep_alive = rank0_metadata["trk_keep_alive"] # (first_appear_obj_id, obj_id) --> [overlap frame indices] overlap_pair_to_frame_inds = rank0_metadata["overlap_pair_to_frame_inds"] # removed_obj_ids: object IDs that are suppressed via hot-start removed_obj_ids = rank0_metadata["removed_obj_ids"] suppressed_obj_ids = rank0_metadata["suppressed_obj_ids"][frame_idx] obj_ids_newly_removed = set() # object IDs to be newly removed on this frame hotstart_diff = ( frame_idx - self.hotstart_delay if not reverse else frame_idx + self.hotstart_delay ) # Step 1: log the frame index where each object ID first appears for obj_id in new_det_obj_ids: if obj_id not in obj_first_frame_idx: obj_first_frame_idx[obj_id] = frame_idx assert obj_id not in trk_keep_alive trk_keep_alive[obj_id] = self.init_trk_keep_alive matched_trks = set() # We use the det-->tracks list to check for matched objects. Otherwise, we need to compute areas to decide whether they're occluded for matched_trks_per_det in det_to_matched_trk_obj_ids.values(): matched_trks.update(matched_trks_per_det) for obj_id in matched_trks: # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the max value of trk_keep_alive trk_keep_alive[obj_id] = min( self.max_trk_keep_alive, trk_keep_alive[obj_id] + 1 ) for obj_id in unmatched_trk_obj_ids: unmatched_frame_inds[obj_id].append(frame_idx) # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive # The max keep alive is 2x the min, means the model prefers to keep the prediction rather than suppress it if it was matched long enough. trk_keep_alive[obj_id] = max( self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1 ) if self.decrease_trk_keep_alive_for_empty_masklets: for obj_id in empty_trk_obj_ids: # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive trk_keep_alive[obj_id] = max( self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1 ) # Step 2: removed tracks that has not matched with detections for `hotstart_unmatch_thresh` frames with hotstart period # a) add unmatched frame indices for each existing object ID # note that `unmatched_trk_obj_ids` contains those frames where the SAM2 output mask # doesn't match any FA detection; it excludes those frames where SAM2 gives an empty mask # b) remove a masklet if it first appears after `hotstart_diff` and is unmatched for more # than `self.hotstart_unmatch_thresh` frames for obj_id, frame_indices in unmatched_frame_inds.items(): if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed: continue # skip if the object is already removed if len(frame_indices) >= self.hotstart_unmatch_thresh: is_within_hotstart = ( obj_first_frame_idx[obj_id] > hotstart_diff and not reverse ) or (obj_first_frame_idx[obj_id] < hotstart_diff and reverse) if is_within_hotstart: obj_ids_newly_removed.add(obj_id) logger.info( f"Removing object {obj_id} at frame {frame_idx} " f"since it is unmatched for frames: {frame_indices}" ) if ( trk_keep_alive[obj_id] <= 0 # Object has not been matched for too long and not self.suppress_unmatched_only_within_hotstart and obj_id not in removed_obj_ids and obj_id not in obj_ids_newly_removed ): logger.debug( f"Suppressing object {obj_id} at frame {frame_idx}, due to being unmatched" ) suppressed_obj_ids.add(obj_id) # Step 3: removed tracks that overlaps with another track for `hotstart_dup_thresh` frames # a) find overlaps tracks -- we consider overlap if they match to the same detection for _, matched_trk_obj_ids in det_to_matched_trk_obj_ids.items(): if len(matched_trk_obj_ids) < 2: continue # only count detections that are matched to multiple (>=2) masklets # if there are multiple matched track ids, we need to find the one that appeared first; # these later appearing ids may be removed since they may be considered as duplicates first_appear_obj_id = ( min(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x]) if not reverse else max(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x]) ) for obj_id in matched_trk_obj_ids: if obj_id != first_appear_obj_id: key = (first_appear_obj_id, obj_id) overlap_pair_to_frame_inds[key].append(frame_idx) # b) remove a masklet if it first appears after `hotstart_diff` and it overlaps with another # masklet (that appears earlier) for more than `self.hotstart_dup_thresh` frames for (first_obj_id, obj_id), frame_indices in overlap_pair_to_frame_inds.items(): if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed: continue # skip if the object is already removed if (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or ( obj_first_frame_idx[obj_id] < hotstart_diff and reverse ): if len(frame_indices) >= self.hotstart_dup_thresh: obj_ids_newly_removed.add(obj_id) logger.info( f"Removing object {obj_id} at frame {frame_idx} " f"since it overlaps with another track {first_obj_id} at frames: {frame_indices}" ) removed_obj_ids.update(obj_ids_newly_removed) return obj_ids_newly_removed, rank0_metadata def _tracker_update_memories( self, sam2_inference_states: List[Any], frame_idx: int, tracker_metadata: Dict[str, Any], low_res_masks: Tensor, ): """ Run Sam2 memory encoder, enforcing non-overlapping constraints globally. """ # TODO: Add most recently occluded heuristic for suppression of overlapping masks if len(sam2_inference_states) == 0: return # Avoid an extra interpolation step by directly interpolating to `interpol_size` high_res_H, high_res_W = ( self.tracker.maskmem_backbone.mask_downsampler.interpol_size ) # NOTE: inspect this part if we observe OOMs in the demo high_res_masks = F.interpolate( low_res_masks.unsqueeze(1), size=(high_res_H, high_res_W), mode="bilinear", align_corners=False, ) # We first apply non-overlapping constraints before memory encoding. This may include some suppression heuristics. with torch.profiler.record_function( "sam2_predictor.propagate_in_video.apply_non_overlapping_constraints" ): # TODO: try _apply_object_wise_non_overlapping_constraints instead high_res_masks = self.tracker._suppress_object_pw_area_shrinkage( high_res_masks ) # Instead of gathering the predicted object scores, we use mask areas as a proxy. object_score_logits = torch.where( (high_res_masks > 0).any(dim=(-1, -2)), 10.0, -10.0 ) if self.is_multiplex and self.tracker.is_multiplex_dynamic: # The objects in the masks are ordered w.r.t. object IDs, # which might not be true in the dynamic multiplex case with backfilling # (see also _propogate_tracker_one_frame_local_gpu) # We need to plan globally for the mask assignment here object_idx_assignment: dict[int, list[int]] = {} all_object_ids: list[int] = [] object_id_to_state_i: dict[int, int] = {} for state_i, sam2_state in enumerate(sam2_inference_states): obj_ids = sam2_state["obj_ids"] all_object_ids.extend(obj_ids) for obj_id in obj_ids: object_id_to_state_i[obj_id] = state_i object_idx_assignment[state_i] = [] sorted_indices = sorted( range(len(all_object_ids)), key=lambda i: all_object_ids[i] ) # Build the object_idx_assignment mapping for global_idx, local_idx in enumerate(sorted_indices): obj_id = all_object_ids[local_idx] object_idx_assignment[object_id_to_state_i[obj_id]].append(global_idx) # Run the memory encoder on local slices for each GPU start_idx_gpu = sum(tracker_metadata["num_obj_per_gpu"][: self.rank]) start_idx_state = start_idx_gpu for state_i, sam2_state in enumerate(sam2_inference_states): num_obj_per_state = len(sam2_state["obj_ids"]) if num_obj_per_state == 0: continue # Get the local high-res masks and object score logits for this inference state if self.is_multiplex and self.tracker.is_multiplex_dynamic: local_idx = ( torch.tensor(object_idx_assignment[state_i], device="cpu") .pin_memory() .to(device=high_res_masks.device, non_blocking=True) ) local_high_res_masks = high_res_masks[local_idx] local_object_score_logits = object_score_logits[local_idx] else: end_idx_state = start_idx_state + num_obj_per_state local_high_res_masks = high_res_masks[start_idx_state:end_idx_state] local_object_score_logits = object_score_logits[ start_idx_state:end_idx_state ] local_batch_size = local_high_res_masks.size(0) # Run Sam2 memory encoder. Note that we do not re-enforce the non-overlapping constraint as it is turned off by default encoded_mem = self.tracker._run_memory_encoder( sam2_state, frame_idx, local_batch_size, local_high_res_masks, local_object_score_logits, is_mask_from_pts=False, ) if self.is_multiplex: ( local_maskmem_features, local_maskmem_pos_enc, local_image_features, local_image_pos_enc, ) = encoded_mem else: local_maskmem_features, local_maskmem_pos_enc = encoded_mem # Store encoded memories in the local inference state output_dict = sam2_state["output_dict"] for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: if frame_idx not in output_dict[storage_key]: continue output_dict[storage_key][frame_idx]["maskmem_features"] = ( local_maskmem_features ) output_dict[storage_key][frame_idx]["maskmem_pos_enc"] = [ pos for pos in local_maskmem_pos_enc ] if self.is_multiplex: output_dict[storage_key][frame_idx]["image_features"] = ( local_image_features ) output_dict[storage_key][frame_idx]["image_pos_enc"] = ( local_image_pos_enc ) if self.reapply_no_object_pointer: # reapply the no_object_pointer projection for the objects suppressed by the heuristics newly_suppressed_objects = ( output_dict[storage_key][frame_idx]["object_score_logits"] > self.tracker.object_score_logit_threshold ) & (local_object_score_logits < 0) if torch.any(newly_suppressed_objects): existing_pointers = output_dict[storage_key][frame_idx][ "obj_ptr" ] multiplex_state = sam2_state["multiplex_state"] existing_pointers = multiplex_state.demux(existing_pointers) newly_suppressed_objects = newly_suppressed_objects.float() new_pointers = ( newly_suppressed_objects * self.tracker.no_obj_ptr_linear(existing_pointers) + (1 - newly_suppressed_objects) * existing_pointers ) output_dict[storage_key][frame_idx]["obj_ptr"] = ( multiplex_state.mux(new_pointers) ) elif self.reapply_no_object_pointer: raise NotImplementedError( "reapply_no_object_pointer is not implemented for non-multiplex" ) # for batched inference state, we also need to add per-object # memory slides to support instance interactivity self.tracker.add_output_per_object( inference_state=sam2_state, frame_idx=frame_idx, current_out=output_dict[storage_key][frame_idx], storage_key=storage_key, ) start_idx_state += num_obj_per_state def _tracker_add_new_objects( self, frame_idx: int, num_frames: int, new_obj_ids: List[int], new_obj_masks: Tensor, tracker_states_local: List[Any], orig_vid_height: int, orig_vid_width: int, feature_cache: Dict, ): """Add new objects to SAM2 inference states.""" prev_sam2_state = ( tracker_states_local[0] if len(tracker_states_local) > 0 else None ) # prepare inference_state if self.tracker.is_multiplex_dynamic: # in multiplex_dynamic mode, we first try to find the best-fit # inference state for the new objects. # Create a new state if needed num_new_objects = len(new_obj_ids) # Try to find existing states with available slots best_state = None best_available_slots = float("inf") for state in tracker_states_local: available_slots = state["multiplex_state"].available_slots # Find the state with the least available slots that can still fit the new objects if ( available_slots >= num_new_objects and available_slots < best_available_slots ): best_state = state best_available_slots = available_slots if best_state is not None: # Use the existing state with sufficient available slots new_sam2_state = best_state else: # Need to create a new state new_sam2_state = self.tracker.init_state( cached_features=feature_cache, video_height=orig_vid_height, video_width=orig_vid_width, num_frames=num_frames, ) new_sam2_state["backbone_out"] = ( prev_sam2_state.get("backbone_out", None) if prev_sam2_state is not None else None ) # Add the new state to our local states list tracker_states_local.append(new_sam2_state) else: if self.tracker.per_obj_inference: # in per_obj_inference mode, init_state happens only once, # new obj_ids will be added to the existing inference state if prev_sam2_state is not None: new_sam2_state = prev_sam2_state else: new_sam2_state = self.tracker.init_state( cached_features=feature_cache, video_height=orig_vid_height, video_width=orig_vid_width, num_frames=num_frames, ) new_sam2_state["backbone_out"] = None tracker_states_local = [new_sam2_state] else: # batch objects that first appear on the same frame together # Clear inference state. Keep the cached image features if available. new_sam2_state = self.tracker.init_state( cached_features=feature_cache, video_height=orig_vid_height, video_width=orig_vid_width, num_frames=num_frames, ) new_sam2_state["backbone_out"] = ( prev_sam2_state.get("backbone_out", None) if prev_sam2_state is not None else None ) tracker_states_local.append(new_sam2_state) assert len(new_obj_ids) == new_obj_masks.size(0) assert new_obj_masks.is_floating_point() # TODO consider removing this interpolation -- it's probably no longer needed # we should edit `self.tracker.add_new_mask` to directly take low-res input masks input_mask_res = self.tracker.input_mask_size new_obj_masks = F.interpolate( new_obj_masks.unsqueeze(1), size=(input_mask_res, input_mask_res), mode="bilinear", align_corners=False, ).squeeze(1) new_obj_masks = new_obj_masks > 0 if self.is_multiplex: # add all objects at once # NOTE: In the current implementation, add_new_masks also runs the memory encoder # the non-overlapping constraint is enforced self.tracker.add_new_masks( inference_state=new_sam2_state, frame_idx=frame_idx, obj_ids=new_obj_ids, masks=new_obj_masks, add_mask_to_memory=True, ) else: # add object one by one for new_obj_id, new_mask in zip(new_obj_ids, new_obj_masks): self.tracker.add_new_mask( inference_state=new_sam2_state, frame_idx=frame_idx, obj_id=new_obj_id, mask=new_mask, add_mask_to_memory=True, ) # NOTE: we skip enforcing the non-overlapping constraint **globally** when adding new objects. self.tracker.propagate_in_video_preflight(new_sam2_state, run_mem_encoder=True) return tracker_states_local def _tracker_remove_objects( self, tracker_states_local: List[Any], obj_ids: list[int] ): """ Remove an object from SAM2 inference states. This would remove the object from all frames in the video. """ if self.is_multiplex: tracker_states_local_before_removal = tracker_states_local.copy() tracker_states_local.clear() for sam2_inference_state in tracker_states_local_before_removal: # we try to remove `obj_id` on every inference state with `strict=False` # it will not do anything if an inference state doesn't contain `obj_id` new_obj_ids, _ = self.tracker.remove_objects( sam2_inference_state, obj_ids, strict=False, need_output=False ) # only keep an inference state if it's non-empty after object removal if len(new_obj_ids) > 0: tracker_states_local.append(sam2_inference_state) else: for obj_id in obj_ids: self._tracker_remove_object(tracker_states_local, obj_id) def update_masklet_confirmation_status( self, rank0_metadata: Dict[str, Any], obj_ids_all_gpu_prev: np.ndarray, obj_ids_all_gpu_updated: np.ndarray, det_to_matched_trk_obj_ids: Dict[int, np.ndarray], new_det_obj_ids: np.ndarray, ): """ Update masklet confirmation status. """ confirmation_data = rank0_metadata["masklet_confirmation"] status_prev = confirmation_data["status"] consecutive_det_num_prev = confirmation_data["consecutive_det_num"] N_prev = len(obj_ids_all_gpu_prev) N_updated = len(obj_ids_all_gpu_updated) # a) Map previous confirmation data to updated positions # For small arrays, simple dict lookup is fast unconfirmed_val = MaskletConfirmationStatus.UNCONFIRMED.value status = np.full(N_updated, unconfirmed_val, dtype=np.int64) consecutive_det_num = np.zeros(N_updated, dtype=np.int64) if N_prev > 0 and N_updated > 0: # Build mapping: obj_id -> new index obj_id_to_new_idx = { obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated) } # Copy previous values for objects that still exist for old_idx, obj_id in enumerate(obj_ids_all_gpu_prev): new_idx = obj_id_to_new_idx.get(obj_id) if new_idx is not None: status[new_idx] = status_prev[old_idx] consecutive_det_num[new_idx] = consecutive_det_num_prev[old_idx] # b) Update confirmation status based on current frame detections # Build set of all matched object IDs matched_obj_ids = set(new_det_obj_ids) for matched_trk_ids in det_to_matched_trk_obj_ids.values(): matched_obj_ids.update(matched_trk_ids) # Update consecutive detection count and status for idx, obj_id in enumerate(obj_ids_all_gpu_updated): if obj_id in matched_obj_ids: consecutive_det_num[idx] += 1 else: consecutive_det_num[idx] = 0 # Update status to CONFIRMED where threshold is met if ( consecutive_det_num[idx] >= self.masklet_confirmation_consecutive_det_thresh ): status[idx] = MaskletConfirmationStatus.CONFIRMED.value # Store updated arrays confirmation_data["status"] = status confirmation_data["consecutive_det_num"] = consecutive_det_num return rank0_metadata class Sam3MultiplexPredictorWrapper(Sam3MultiplexTrackerPredictor): """ Wraps a pre-built multiplex tracker model with the same interface as the onevision Sam3MultiplexTrackerPredictor class. Inherits from Sam3MultiplexTrackerPredictor to pass isinstance checks, but skips Sam3MultiplexTrackerPredictor.__init__ (which requires Hydra). Provides bf16 autocast, attribute proxying, and configuration flags needed by Sam3MultiplexTracking. The onevision Sam3MultiplexTrackerPredictor builds the tracker from Hydra config and applies extensive hydra_overrides. This version skips Hydra entirely — the caller is responsible for building the tracker via model_builder.py with the correct parameters. Key parameters that the onevision Sam3MultiplexTrackerPredictor sets via hydra_overrides (documented here for reference — these must be set in model_builder.py): - image_size=1008, backbone_stride=14 - maskmem_backbone.mask_downsampler.interpol_size=[1152,1152] - always_start_from_first_ann_frame=false - non_overlap_masks_for_mem_enc=false, non_overlap_masks_for_output=false - max_cond_frames_in_attn=4 - offload_output_to_cpu_for_eval=false, trim_past_non_cond_mem_for_eval=false - sam_mask_decoder_extra_args: dynamic_multimask_via_stability=true, etc. - binarize_mask_from_pts_for_mem_enc=true (SAM2 tracker default) - only_obj_ptrs_in_the_past_for_eval=true - clear_non_cond_mem_around_input=true - transformer.encoder.layer.self_attention.feat_sizes=[72,72] - transformer.encoder.layer.cross_attention.feat_sizes=[72,72] - fill_hole_area= - use_fa3, use_rope_real on self_attention, cross_attention, self_attention_rope, cross_attention_rope - use_memory_selection """ def __init__( self, model, per_obj_inference=False, fill_hole_area=0, is_multiplex=True, is_multiplex_dynamic=True, ): # Skip Sam3MultiplexTrackerPredictor.__init__ (requires Hydra) — call nn.Module.__init__ directly nn.Module.__init__(self) self.model = model self.per_obj_inference = per_obj_inference self.fill_hole_area = fill_hole_area self.is_multiplex = is_multiplex self.is_multiplex_dynamic = is_multiplex_dynamic # use bfloat16 inference for Flash Attention kernel self.bf16_context = accelerator_autocast() self.bf16_context.__enter__()