| |
|
|
| import datetime |
| import logging |
| import math |
| import os |
| from collections import defaultdict |
| from copy import deepcopy |
| from enum import Enum |
| from typing import Any, Dict, List, Set |
|
|
| import numpy as np |
| import numpy.typing as npt |
| import torch |
| import torch.distributed as dist |
| import torch.nn.functional as F |
|
|
| from sam3 import perflib |
| from sam3.logger import get_logger |
| from sam3.model.box_ops import fast_diag_box_iou |
| from sam3.model.data_misc import BatchedDatapoint |
| from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores, mask_to_box |
| from sam3.perflib.masks_ops import mask_iou |
| from sam3.train.masks_ops import rle_encode |
| from torch import nn, Tensor |
|
|
| logger = get_logger(__name__) |
|
|
|
|
class MaskletConfirmationStatus(Enum):
    """Two-state confirmation flag for a masklet (a tracked object).

    The integer values are part of the contract and must stay stable.
    NOTE(review): presumably consumed by the masklet-confirmation logic
    (`update_masklet_confirmation_status`), which is defined outside this
    section -- confirm there what promotes a masklet to CONFIRMED.
    """

    # Masklet has not been confirmed.
    UNCONFIRMED = 1
    # Masklet has been confirmed.
    CONFIRMED = 2
|
|
|
|
| class Sam3VideoBase(nn.Module): |
def __init__(
    self,
    detector: nn.Module,
    tracker: nn.Module,
    score_threshold_detection=0.5,
    det_nms_thresh=0.0,
    assoc_iou_thresh=0.5,
    trk_assoc_iou_thresh=0.5,
    new_det_thresh=0.0,
    hotstart_delay=0,
    hotstart_unmatch_thresh=3,
    hotstart_dup_thresh=3,
    suppress_unmatched_only_within_hotstart=True,
    init_trk_keep_alive=0,
    max_trk_keep_alive=8,
    min_trk_keep_alive=-4,
    suppress_overlapping_based_on_recent_occlusion_threshold=0.0,
    decrease_trk_keep_alive_for_empty_masklets=False,
    o2o_matching_masklets_enable=False,
    suppress_det_close_to_boundary=False,
    fill_hole_area=16,
    max_num_objects=-1,
    recondition_every_nth_frame=-1,
    masklet_confirmation_enable=False,
    masklet_confirmation_consecutive_det_thresh=3,
    reconstruction_bbox_iou_thresh=0.0,
    reconstruction_bbox_det_score=0.0,
):
    """Store the detector, the tracker and every tracking hyper-parameter.

    The module is put in eval mode, and the distributed rank / world size
    are read from the ``RANK`` / ``WORLD_SIZE`` environment variables
    (defaulting to a single process).

    Only parameters whose use is visible in this part of the file are
    described; the remaining ones are stored unchanged for helpers defined
    elsewhere.

    Args:
        detector: detection model; its text backbone and
            ``forward_video_grounding_multigpu`` are called per frame.
        tracker: tracking model; its mask-decoder convs, ``add_new_mask``
            and ``propagate_in_video_preflight`` are used.
        score_threshold_detection: detections are kept when their
            probability is strictly above this value; also passed to the
            detector as the NMS probability threshold.
        det_nms_thresh: NMS IoU threshold; NMS runs only when it is > 0.
        hotstart_delay: when > 0, both hotstart thresholds must not exceed
            it (checked with ``assert``).
        suppress_overlapping_based_on_recent_occlusion_threshold: mask-IoU
            threshold for overlap suppression; enabled only when > 0.
        suppress_det_close_to_boundary: when true, new detections whose
            box centre is near the image border are dropped.
        fill_hole_area: ``max_area`` given to hole filling for new masks.
        max_num_objects: cap on tracked objects; a non-positive value
            selects the defaults of 10000 objects and 16 objects per
            process for compilation.
        recondition_every_nth_frame: when > 0, masklets are reconditioned
            on frames whose index is a multiple of it.
        masklet_confirmation_enable: turns on the confirmation-status
            update performed on rank 0.
        reconstruction_bbox_iou_thresh: when > 0, a masklet whose box IoU
            with its detection falls below it is reconditioned.
        reconstruction_bbox_det_score: minimum detection score required
            for that IoU-triggered reconditioning.
    """
    super().__init__()

    # Sub-models.
    self.detector = detector
    self.tracker = tracker

    # Detection / association thresholds.
    self.score_threshold_detection = score_threshold_detection
    self.det_nms_thresh = det_nms_thresh
    self.assoc_iou_thresh = assoc_iou_thresh
    self.trk_assoc_iou_thresh = trk_assoc_iou_thresh
    self.new_det_thresh = new_det_thresh

    # Hotstart settings: the thresholds cannot exceed the delay window.
    if hotstart_delay > 0:
        for hotstart_thresh in (hotstart_unmatch_thresh, hotstart_dup_thresh):
            assert hotstart_thresh <= hotstart_delay
    self.hotstart_delay = hotstart_delay
    self.hotstart_unmatch_thresh = hotstart_unmatch_thresh
    self.hotstart_dup_thresh = hotstart_dup_thresh
    self.suppress_unmatched_only_within_hotstart = (
        suppress_unmatched_only_within_hotstart
    )

    # Keep-alive counters and suppression switches.
    self.init_trk_keep_alive = init_trk_keep_alive
    self.max_trk_keep_alive = max_trk_keep_alive
    self.min_trk_keep_alive = min_trk_keep_alive
    self.suppress_overlapping_based_on_recent_occlusion_threshold = (
        suppress_overlapping_based_on_recent_occlusion_threshold
    )
    self.suppress_det_close_to_boundary = suppress_det_close_to_boundary
    self.decrease_trk_keep_alive_for_empty_masklets = (
        decrease_trk_keep_alive_for_empty_masklets
    )
    self.o2o_matching_masklets_enable = o2o_matching_masklets_enable
    self.fill_hole_area = fill_hole_area

    # Inference-only module.
    self.eval()

    # Distributed layout; the CPU process group is created lazily.
    self.rank = int(os.getenv("RANK", "0"))
    self.world_size = int(os.getenv("WORLD_SIZE", "1"))
    self._dist_pg_cpu = None

    # Object budget: a non-positive request falls back to the defaults,
    # otherwise the budget is split evenly (rounded up) across processes.
    if not max_num_objects > 0:
        max_num_objects = 10000
        num_obj_for_compile = 16
    else:
        num_obj_for_compile = math.ceil(max_num_objects / self.world_size)
    logger.info(f"setting {max_num_objects=} and {num_obj_for_compile=}")
    self.max_num_objects = max_num_objects
    self.num_obj_for_compile = num_obj_for_compile

    # Reconditioning and confirmation settings.
    self.recondition_every_nth_frame = recondition_every_nth_frame
    self.masklet_confirmation_enable = masklet_confirmation_enable
    self.masklet_confirmation_consecutive_det_thresh = (
        masklet_confirmation_consecutive_det_thresh
    )
    self.reconstruction_bbox_iou_thresh = reconstruction_bbox_iou_thresh
    self.reconstruction_bbox_det_score = reconstruction_bbox_det_score
|
|
@property
def device(self):
    """Device of the model, taken from its first parameter and then cached.

    The value is memoised in ``self._device`` on first access, so later
    reads do not consult the parameters again.
    """
    cached_device = getattr(self, "_device", None)
    if not cached_device:
        cached_device = next(self.parameters()).device
    self._device = cached_device
    return self._device
|
|
def _init_dist_pg_cpu(self):
    """Create the gloo process group stored in ``self._dist_pg_cpu``.

    The collective timeout (in seconds) is read from the
    ``SAM3_COLLECTIVE_OP_TIMEOUT_SEC`` environment variable, default 180.
    """
    timeout_sec = int(os.getenv("SAM3_COLLECTIVE_OP_TIMEOUT_SEC", "180"))
    self._dist_pg_cpu = dist.new_group(
        backend="gloo",
        timeout=datetime.timedelta(seconds=timeout_sec),
    )
|
|
def broadcast_python_obj_cpu(self, python_obj_list, src):
    """Broadcast a list of Python objects from rank ``src`` over the CPU group.

    The gloo process group is created on first use. ``python_obj_list`` is
    filled in place on the receiving ranks, as done by
    ``dist.broadcast_object_list``.
    """
    cpu_group = self._dist_pg_cpu
    if cpu_group is None:
        self._init_dist_pg_cpu()
        cpu_group = self._dist_pg_cpu
    dist.broadcast_object_list(python_obj_list, src=src, group=cpu_group)
|
|
def _det_track_one_frame(
    self,
    frame_idx: int,
    num_frames: int,
    reverse: bool,
    input_batch: BatchedDatapoint,
    geometric_prompt: Any,
    tracker_states_local: List[Any],
    tracker_metadata_prev: Dict[str, Any],
    feature_cache: Dict,
    orig_vid_height: int,
    orig_vid_width: int,
    is_image_only: bool = False,
    allow_new_detections: bool = True,
):
    """
    This function handles one-step inference for the DenseTracking model in an SPMD manner.
    At a high-level, all GPUs execute the same function calls as if it's done on a single GPU,
    while under the hood, some function calls involve distributed computation based on sharded
    SAM2 states.

    - `input_batch` contains image and other inputs on the entire video; it should be identical across GPUs
    - `tracker_states_local` holds the local masklet information in this GPU shard
    - `tracker_metadata_prev` manages the metadata for SAM2 objects, such as which masklet is held on which GPUs
      it contains both global and local masklet information

    An empty `tracker_metadata_prev` is filled in place from `_initialize_metadata()`.

    Returns a 6-tuple:
      (obj_id_to_mask, obj_id_to_score, tracker_states_local_new,
       tracker_metadata_new, frame_stats, tracker scores)
    where the first two are empty dicts on ranks other than 0, and the last
    item is a list of sigmoid scores when at least one object was tracked
    (otherwise the empty score tensor is returned as is).
    """
    # Arguments shared by every stage below.
    frame_kwargs = dict(frame_idx=frame_idx, num_frames=num_frames, reverse=reverse)

    # Stage 1: backbone + detector on this frame.
    det_out = self.run_backbone_and_detection(
        input_batch=input_batch,
        geometric_prompt=geometric_prompt,
        feature_cache=feature_cache,
        allow_new_detections=allow_new_detections,
        **frame_kwargs,
    )

    # Stage 2: propagate the existing masklets to this frame.
    if not tracker_metadata_prev:
        # First call: populate the caller's dict in place.
        tracker_metadata_prev.update(self._initialize_metadata())
    tracker_low_res_masks_global, tracker_obj_scores_global = (
        self.run_tracker_propagation(
            tracker_states_local=tracker_states_local,
            tracker_metadata_prev=tracker_metadata_prev,
            **frame_kwargs,
        )
    )

    # Stage 3: plan the tracker update (what to add / remove / recondition).
    tracker_update_plan, tracker_metadata_new = (
        self.run_tracker_update_planning_phase(
            det_out=det_out,
            tracker_low_res_masks_global=tracker_low_res_masks_global,
            tracker_obj_scores_global=tracker_obj_scores_global,
            tracker_metadata_prev=tracker_metadata_prev,
            tracker_states_local=tracker_states_local,
            is_image_only=is_image_only,
            **frame_kwargs,
        )
    )
    reconditioned_obj_ids = tracker_update_plan.get("reconditioned_obj_ids", set())
    det_to_matched_trk_obj_ids = tracker_update_plan.get(
        "det_to_matched_trk_obj_ids", {}
    )

    # Stage 4: apply the plan to the local tracker states.
    tracker_states_local_new = self.run_tracker_update_execution_phase(
        det_out=det_out,
        tracker_states_local=tracker_states_local,
        tracker_update_plan=tracker_update_plan,
        orig_vid_height=orig_vid_height,
        orig_vid_width=orig_vid_width,
        feature_cache=feature_cache,
        **frame_kwargs,
    )

    # Stage 5: only rank 0 assembles the per-object output masks.
    obj_id_to_mask, obj_id_to_score = {}, {}
    if self.rank == 0:
        obj_id_to_mask = self.build_outputs(
            det_out=det_out,
            tracker_low_res_masks_global=tracker_low_res_masks_global,
            tracker_obj_scores_global=tracker_obj_scores_global,
            tracker_metadata_prev=tracker_metadata_prev,
            tracker_update_plan=tracker_update_plan,
            orig_vid_height=orig_vid_height,
            orig_vid_width=orig_vid_width,
            reconditioned_obj_ids=reconditioned_obj_ids,
            det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
            **frame_kwargs,
        )
        obj_id_to_score = tracker_metadata_new["obj_id_to_score"]

    frame_stats = {
        "num_obj_tracked": np.sum(tracker_metadata_new["num_obj_per_gpu"]),
        "num_obj_dropped": tracker_update_plan["num_obj_dropped_due_to_limit"],
    }

    # Record per-frame tracker scores for the objects that were propagated.
    tracker_scores_out = tracker_obj_scores_global
    if tracker_obj_scores_global.shape[0] > 0:
        tracker_scores_out = tracker_obj_scores_global.sigmoid().tolist()
        propagated_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"]
        scores_this_frame = tracker_metadata_new[
            "obj_id_to_tracker_score_frame_wise"
        ][frame_idx]
        scores_this_frame.update(dict(zip(propagated_obj_ids, tracker_scores_out)))

    return (
        obj_id_to_mask,
        obj_id_to_score,
        tracker_states_local_new,
        tracker_metadata_new,
        frame_stats,
        tracker_scores_out,
    )
|
|
| def _suppress_detections_close_to_boundary(self, boxes, margin=0.025): |
| """ |
| Suppress detections too close to image edges (for normalized boxes). |
| |
| boxes: (N, 4) in xyxy format, normalized [0,1] |
| margin: fraction of image |
| """ |
| x_min, y_min, x_max, y_max = boxes.unbind(-1) |
| x_c = (x_min + x_max) / 2 |
| y_c = (y_min + y_max) / 2 |
| keep = ( |
| (x_c > margin) |
| & (x_c < 1.0 - margin) |
| & (y_c > margin) |
| & (y_c < 1.0 - margin) |
| ) |
|
|
| return keep |
|
|
def run_backbone_and_detection(
    self,
    frame_idx: int,
    num_frames: int,
    input_batch: BatchedDatapoint,
    geometric_prompt: Any,
    feature_cache: Dict,
    reverse: bool,
    allow_new_detections: bool,
):
    """Run the detector on one frame and cache the tracker backbone features.

    Side effects on `feature_cache`:
      - "text": replaced by a single-entry dict for the current text batch
        whenever that batch is not already cached;
      - "multigpu_buffer": created (empty) on first use and handed to the
        detector;
      - `frame_idx`: set to `(image, {"tracker_backbone_out": ...})`;
      - the entry of the previously processed neighbouring frame
        (`frame_idx - 1`, or `frame_idx + 1` when `reverse`) is dropped.

    Returns:
        dict with "bbox", "mask" and "scores" for the predictions whose
        probability exceeds `self.score_threshold_detection`. When
        `allow_new_detections` is false the probabilities are shifted by
        -1e8 so that nothing passes the threshold.
    """
    # Text features: reuse the cached ones when the text batch is unchanged.
    text_batch_key = tuple(input_batch.find_text_batch)
    cached_text = feature_cache.get("text", {})
    if text_batch_key in cached_text:
        text_outputs = cached_text[text_batch_key]
    else:
        text_outputs = self.detector.backbone.forward_text(
            input_batch.find_text_batch, device=self.device
        )
        # Only the most recent text batch is retained.
        feature_cache["text"] = {text_batch_key: text_outputs}

    # Buffer shared with the detector across calls.
    feature_cache.setdefault("multigpu_buffer", {})

    # Optional tracking bounds supplied through the cache.
    tracking_bounds = feature_cache.get("tracking_bounds", {})
    max_frame_num_to_track = tracking_bounds.get("max_frame_num_to_track")
    start_frame_idx = tracking_bounds.get("propagate_in_video_start_frame_idx")

    backbone_inputs = {
        "img_batch_all_stages": input_batch.img_batch,
        **text_outputs,
    }
    sam3_image_out, _ = self.detector.forward_video_grounding_multigpu(
        backbone_out=backbone_inputs,
        find_inputs=input_batch.find_inputs,
        geometric_prompt=geometric_prompt,
        frame_idx=frame_idx,
        num_frames=num_frames,
        multigpu_buffer=feature_cache["multigpu_buffer"],
        track_in_reverse=reverse,
        return_tracker_backbone_feats=True,
        run_nms=self.det_nms_thresh > 0.0,
        nms_prob_thresh=self.score_threshold_detection,
        nms_iou_thresh=self.det_nms_thresh,
        max_frame_num_to_track=max_frame_num_to_track,
        propagate_in_video_start_frame_idx=start_frame_idx,
    )

    # Keep only predictions above the score threshold.
    pred_probs = sam3_image_out["pred_logits"].squeeze(-1).sigmoid()
    if not allow_new_detections:
        # Push every probability far below any threshold.
        pred_probs = pred_probs - 1e8
    pred_boxes_xyxy = sam3_image_out["pred_boxes_xyxy"]
    pred_masks = sam3_image_out["pred_masks"]
    pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection)
    first_idx, second_idx = pos_pred_idx[0], pos_pred_idx[1]
    det_out = {
        "bbox": pred_boxes_xyxy[first_idx, second_idx],
        "mask": pred_masks[first_idx, second_idx],
        "scores": pred_probs[first_idx, second_idx],
    }

    # Tracker backbone features: the two finer FPN levels go through the
    # tracker mask decoder's conv_s0 / conv_s1, the last level is used as is.
    mask_decoder = self.tracker.sam_mask_decoder
    tracker_backbone_fpn = [
        mask_decoder.conv_s0(sam3_image_out["tracker_backbone_fpn_0"]),
        mask_decoder.conv_s1(sam3_image_out["tracker_backbone_fpn_1"]),
        sam3_image_out["tracker_backbone_fpn_2"],
    ]
    backbone_cache = {
        "tracker_backbone_out": {
            "vision_features": tracker_backbone_fpn[-1],
            "vision_pos_enc": sam3_image_out["tracker_backbone_pos_enc"],
            "backbone_fpn": tracker_backbone_fpn,
        }
    }
    feature_cache[frame_idx] = (input_batch.img_batch[frame_idx], backbone_cache)

    # Drop the features of the frame processed just before this one.
    stale_frame_idx = frame_idx + 1 if reverse else frame_idx - 1
    feature_cache.pop(stale_frame_idx, None)
    return det_out
|
|
def run_tracker_propagation(
    self,
    frame_idx: int,
    num_frames: int,
    reverse: bool,
    tracker_states_local: List[Any],
    tracker_metadata_prev: Dict[str, npt.NDArray],
):
    """Propagate the local masklets one frame and gather results from all ranks.

    The local object ids returned by the tracker must equal the ids
    recorded for this rank in `tracker_metadata_prev["obj_ids_per_gpu"]`.

    Returns:
        (low_res_masks_global, obj_scores_global): with a single process
        these are the local tensors themselves; otherwise they are the
        per-rank results (cast to float) concatenated along dim 0 in rank
        order.
    """
    obj_ids_local, low_res_masks_local, obj_scores_local = (
        self._propogate_tracker_one_frame_local_gpu(
            tracker_states_local, frame_idx=frame_idx, reverse=reverse
        )
    )

    # Sanity check: the tracker and the metadata agree on the local objects.
    assert np.all(
        obj_ids_local == tracker_metadata_prev["obj_ids_per_gpu"][self.rank]
    ), "{} != {}".format(
        obj_ids_local, tracker_metadata_prev["obj_ids_per_gpu"][self.rank]
    )

    _, H_mask, W_mask = low_res_masks_local.shape
    if self.world_size <= 1:
        # Single process: nothing to gather.
        return low_res_masks_local, obj_scores_local

    # all_gather needs contiguous tensors of a common dtype, and receive
    # buffers sized from the per-rank object counts in the metadata.
    low_res_masks_local = low_res_masks_local.float().contiguous()
    obj_scores_local = obj_scores_local.float().contiguous()
    num_obj_per_gpu = tracker_metadata_prev["num_obj_per_gpu"]
    num_obj_this_gpu = num_obj_per_gpu[self.rank]
    assert low_res_masks_local.size(0) == num_obj_this_gpu
    assert obj_scores_local.size(0) == num_obj_this_gpu

    low_res_masks_peers = []
    for num_obj in num_obj_per_gpu:
        low_res_masks_peers.append(
            low_res_masks_local.new_empty(num_obj, H_mask, W_mask)
        )
    obj_scores_peers = []
    for num_obj in num_obj_per_gpu:
        obj_scores_peers.append(obj_scores_local.new_empty(num_obj))

    dist.all_gather(low_res_masks_peers, low_res_masks_local)
    dist.all_gather(obj_scores_peers, obj_scores_local)
    return (
        torch.cat(low_res_masks_peers, dim=0),
        torch.cat(obj_scores_peers, dim=0),
    )
|
|
def _recondition_masklets(
    self,
    frame_idx,
    det_out: Dict[str, Tensor],
    trk_id_to_max_iou_high_conf_det: Dict[int, int],
    tracker_states_local: List[Any],
    tracker_metadata: Dict[str, npt.NDArray],
    tracker_obj_scores_global: Tensor,
):
    """Re-seed tracked objects with the mask of their matched detection.

    For every `(track id, detection index)` pair, the detection mask is
    resized to the tracker's input mask resolution and binarised at 0. It
    is then added as a new mask to each local inference state that holds
    the track, provided the track's score exceeds the confidence cut-off.
    Every state touched for that track is re-run through
    `propagate_in_video_preflight(..., run_mem_encoder=True)`.

    Args:
        frame_idx: frame on which the new masks are added.
        det_out: detector output; only `det_out["mask"]` is read.
        trk_id_to_max_iou_high_conf_det: mapping from track object id to
            the index of its detection in `det_out` (iterated with
            `.items()`, hence a mapping rather than a list).
        tracker_states_local: inference states held by this rank.
        tracker_metadata: must provide "obj_ids_all_gpu", whose order is
            used to index `tracker_obj_scores_global`.
        tracker_obj_scores_global: one score per entry of
            "obj_ids_all_gpu".

    Returns:
        `tracker_states_local` (the same list; states are updated by the
        tracker calls).
    """
    # NOTE(review): the scores appear to be pre-sigmoid at this point
    # (`_det_track_one_frame` applies `.sigmoid()` to them afterwards),
    # while 0.8 reads like a probability -- confirm the intended scale.
    HIGH_CONF_THRESH = 0.8
    # Loop-invariant: resolution the tracker expects for input masks.
    input_mask_res = self.tracker.input_mask_size

    for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items():
        # Slice (not index) so the mask keeps a leading batch dimension.
        new_mask = det_out["mask"][det_idx : det_idx + 1]
        new_mask_binary = (
            F.interpolate(
                new_mask.unsqueeze(1),
                size=(input_mask_res, input_mask_res),
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)[0]
            > 0
        )

        # Row of this track in the global score tensor; `.item()` raises
        # unless the id occurs exactly once in "obj_ids_all_gpu".
        obj_idx = np.where(tracker_metadata["obj_ids_all_gpu"] == trk_obj_id)[
            0
        ].item()
        obj_score = tracker_obj_scores_global[obj_idx]

        reconditioned_states_idx = set()
        for state_idx, inference_state in enumerate(tracker_states_local):
            if (
                trk_obj_id in inference_state["obj_ids"]
                # Skip low-score tracks so that occluded / poor masklets
                # are not re-seeded.
                and obj_score > HIGH_CONF_THRESH
            ):
                logger.debug(
                    f"Adding new mask for track {trk_obj_id} at frame {frame_idx}. Objects {inference_state['obj_ids']} are all reconditioned."
                )
                self.tracker.add_new_mask(
                    inference_state=inference_state,
                    frame_idx=frame_idx,
                    obj_id=trk_obj_id,
                    mask=new_mask_binary,
                )
                reconditioned_states_idx.add(state_idx)

        # Refresh every state that received a new mask for this track.
        for idx in reconditioned_states_idx:
            self.tracker.propagate_in_video_preflight(
                tracker_states_local[idx], run_mem_encoder=True
            )
    return tracker_states_local
|
|
def run_tracker_update_planning_phase(
    self,
    frame_idx: int,
    num_frames: int,
    reverse: bool,
    det_out: Dict[str, Tensor],
    tracker_low_res_masks_global: Tensor,
    tracker_obj_scores_global: Tensor,
    tracker_metadata_prev: Dict[str, Any],
    tracker_states_local: List[Any],
    is_image_only: bool = False,
):
    """Plan the tracker update for this frame and derive the new metadata.

    Outline:
      1. Rank 0 associates detections with propagated masklets, optionally
         drops new detections near the image border, enforces
         `self.max_num_objects` (skipped when `is_image_only`), assigns ids
         and ranks to the new objects and runs the hotstart logic.
      2. With more than one process, rank 0 broadcasts the plan; other
         ranks check that the object counts agree with their own record.
      3. Masklets are reconditioned when a matched detection box disagrees
         with the tracker box (IoU below
         `self.reconstruction_bbox_iou_thresh`) or periodically every
         `self.recondition_every_nth_frame` frames.
      4. If any object is tracked, overlapping masks are optionally
         suppressed and the tracker memories are updated.
      5. The per-rank object lists, scores and max object id are updated
         in a fresh metadata dict.

    Args:
        frame_idx, num_frames, reverse: current frame position/direction.
        det_out: detector output with "mask", "scores" and "bbox".
        tracker_low_res_masks_global: propagated mask logits for all
            objects, ordered like `tracker_metadata_prev["obj_ids_all_gpu"]`.
        tracker_obj_scores_global: propagated object scores, same order.
        tracker_metadata_prev: metadata of the previous step (not mutated
            here; copies are made for the new metadata).
        tracker_states_local: inference states held by this rank.
        is_image_only: disables the object-count limit.

    Returns:
        `(tracker_update_plan, tracker_metadata_new)`.

    Raises:
        RuntimeError: on a non-zero rank when the per-rank object counts
            received from rank 0 differ from the local record.
    """
    # Warm-up flag: a missing attribute counts as "warm-up complete".
    warm_up_complete = getattr(self, "_warm_up_complete", True)

    # New metadata starts as a copy of the previous one; "obj_ids_all_gpu"
    # is filled at the end and "rank0_metadata" only exists on rank 0.
    tracker_metadata_new = {
        "obj_ids_per_gpu": deepcopy(tracker_metadata_prev["obj_ids_per_gpu"]),
        "obj_ids_all_gpu": None,
        "num_obj_per_gpu": deepcopy(tracker_metadata_prev["num_obj_per_gpu"]),
        "obj_id_to_score": deepcopy(tracker_metadata_prev["obj_id_to_score"]),
        "obj_id_to_tracker_score_frame_wise": deepcopy(
            tracker_metadata_prev["obj_id_to_tracker_score_frame_wise"]
        ),
        "obj_id_to_last_occluded": {},
        "max_obj_id": deepcopy(tracker_metadata_prev["max_obj_id"]),
    }

    # Filled in place further down; the plan dict holds this same set.
    reconditioned_obj_ids = set()

    # ---- Step 1: association and bookkeeping on rank 0 -----------------
    det_mask_preds: Tensor = det_out["mask"]
    det_scores_np: npt.NDArray = det_out["scores"].float().cpu().numpy()
    det_bbox_xyxy: Tensor = det_out["bbox"]
    if self.rank == 0:
        (
            new_det_fa_inds,
            unmatched_trk_obj_ids,
            det_to_matched_trk_obj_ids,
            trk_id_to_max_iou_high_conf_det,
            empty_trk_obj_ids,
        ) = self._associate_det_trk(
            det_masks=det_mask_preds,
            det_scores_np=det_scores_np,
            trk_masks=tracker_low_res_masks_global,
            trk_obj_ids=tracker_metadata_prev["obj_ids_all_gpu"],
        )
        if self.suppress_det_close_to_boundary:
            keep = self._suppress_detections_close_to_boundary(
                det_bbox_xyxy[new_det_fa_inds]
            )
            new_det_fa_inds = new_det_fa_inds[keep.cpu().numpy()]

        # Enforce the cap on the total number of tracked objects.
        prev_obj_num = np.sum(tracker_metadata_prev["num_obj_per_gpu"])
        new_det_num = len(new_det_fa_inds)
        num_obj_dropped_due_to_limit = 0
        if not is_image_only and prev_obj_num + new_det_num > self.max_num_objects:
            logger.warning(
                f"hitting {self.max_num_objects=} with {new_det_num=} and {prev_obj_num=}"
            )
            new_det_num_to_keep = self.max_num_objects - prev_obj_num
            num_obj_dropped_due_to_limit = new_det_num - new_det_num_to_keep
            new_det_fa_inds = self._drop_new_det_with_obj_limit(
                new_det_fa_inds, det_scores_np, new_det_num_to_keep
            )
            assert len(new_det_fa_inds) == new_det_num_to_keep
            new_det_num = len(new_det_fa_inds)

        # New objects get consecutive ids after the largest id so far and
        # are assigned to ranks based on the current per-rank workload.
        new_det_start_obj_id = tracker_metadata_prev["max_obj_id"] + 1
        new_det_obj_ids = new_det_start_obj_id + np.arange(new_det_num)
        prev_workload_per_gpu = tracker_metadata_prev["num_obj_per_gpu"]
        new_det_gpu_ids = self._assign_new_det_to_gpus(
            new_det_num=new_det_num,
            prev_workload_per_gpu=prev_workload_per_gpu,
        )

        # Hotstart logic decides which objects are removed on this frame.
        rank0_metadata_new = deepcopy(tracker_metadata_prev["rank0_metadata"])
        if warm_up_complete:
            obj_ids_newly_removed, rank0_metadata_new = self._process_hotstart(
                frame_idx=frame_idx,
                num_frames=num_frames,
                reverse=reverse,
                det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
                new_det_obj_ids=new_det_obj_ids,
                empty_trk_obj_ids=empty_trk_obj_ids,
                unmatched_trk_obj_ids=unmatched_trk_obj_ids,
                rank0_metadata=rank0_metadata_new,
                tracker_metadata=tracker_metadata_prev,
            )
        else:
            # Nothing is removed while warm-up is still running.
            obj_ids_newly_removed = set()
        tracker_metadata_new["rank0_metadata"] = rank0_metadata_new

    # ---- Step 2: share the plan with the other ranks -------------------
    NUM_BROADCAST_ITEMS = 9
    if self.rank == 0 and self.world_size > 1:
        # The object counts are sent along so receivers can cross-check
        # them against their own metadata.
        num_obj_per_gpu_on_rank0 = tracker_metadata_prev["num_obj_per_gpu"]
        update_plan = [
            new_det_fa_inds,
            new_det_obj_ids,
            new_det_gpu_ids,
            num_obj_per_gpu_on_rank0,
            unmatched_trk_obj_ids,
            det_to_matched_trk_obj_ids,
            obj_ids_newly_removed,
            num_obj_dropped_due_to_limit,
            trk_id_to_max_iou_high_conf_det,
        ]
        assert (
            len(update_plan) == NUM_BROADCAST_ITEMS
        ), f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
        self.broadcast_python_obj_cpu(update_plan, src=0)
    elif self.rank > 0 and self.world_size > 1:
        # Placeholder list that the broadcast fills in place.
        update_plan = [None] * NUM_BROADCAST_ITEMS
        self.broadcast_python_obj_cpu(update_plan, src=0)
        (
            new_det_fa_inds,
            new_det_obj_ids,
            new_det_gpu_ids,
            num_obj_per_gpu_on_rank0,
            unmatched_trk_obj_ids,
            det_to_matched_trk_obj_ids,
            obj_ids_newly_removed,
            num_obj_dropped_due_to_limit,
            trk_id_to_max_iou_high_conf_det,
        ) = update_plan
        # Metadata is maintained independently on every rank; a mismatch
        # means the ranks have diverged.
        if not np.all(
            num_obj_per_gpu_on_rank0 == tracker_metadata_prev["num_obj_per_gpu"]
        ):
            raise RuntimeError(
                f"{self.rank=} received {num_obj_per_gpu_on_rank0=}, which is inconsistent with local record "
                f"{tracker_metadata_prev['num_obj_per_gpu']=}. There's likely a bug in update planning or execution."
            )

    tracker_update_plan = {
        "new_det_fa_inds": new_det_fa_inds,
        "new_det_obj_ids": new_det_obj_ids,
        "new_det_gpu_ids": new_det_gpu_ids,
        "unmatched_trk_obj_ids": unmatched_trk_obj_ids,
        "det_to_matched_trk_obj_ids": det_to_matched_trk_obj_ids,
        "obj_ids_newly_removed": obj_ids_newly_removed,
        "num_obj_dropped_due_to_limit": num_obj_dropped_due_to_limit,
        "trk_id_to_max_iou_high_conf_det": trk_id_to_max_iou_high_conf_det,
        "reconditioned_obj_ids": reconditioned_obj_ids,
    }

    # ---- Step 3: decide whether masklets must be reconditioned ---------
    should_recondition_iou = False
    if (
        self.reconstruction_bbox_iou_thresh > 0
        and len(trk_id_to_max_iou_high_conf_det) > 0
    ):
        # Map each object id to its first row in the global tensors once,
        # instead of converting the id array to a list and scanning it for
        # every candidate track.
        trk_idx_by_obj_id = {}
        for row_idx, known_obj_id in enumerate(
            tracker_metadata_prev["obj_ids_all_gpu"]
        ):
            trk_idx_by_obj_id.setdefault(known_obj_id, row_idx)

        for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items():
            det_box = det_out["bbox"][det_idx]
            det_score = det_out["scores"][det_idx]

            trk_idx = trk_idx_by_obj_id.get(trk_obj_id)
            if trk_idx is None:
                # Track is not among the propagated objects.
                continue

            tracker_mask = tracker_low_res_masks_global[trk_idx]
            mask_binary = tracker_mask > 0
            mask_area = mask_binary.sum().item()
            if mask_area == 0:
                # An empty mask has no box to compare against.
                continue

            # Tracker box from its mask, divided by the mask size so it is
            # on the same scale as the detection box.
            tracker_box_pixels = (
                mask_to_box(mask_binary.unsqueeze(0).unsqueeze(0))
                .squeeze(0)
                .squeeze(0)
            )
            mask_height, mask_width = tracker_mask.shape[-2:]
            tracker_box_normalized = torch.tensor(
                [
                    tracker_box_pixels[0] / mask_width,
                    tracker_box_pixels[1] / mask_height,
                    tracker_box_pixels[2] / mask_width,
                    tracker_box_pixels[3] / mask_height,
                ],
                device=tracker_box_pixels.device,
            )

            det_box_batch = det_box.unsqueeze(0)
            tracker_box_batch = tracker_box_normalized.unsqueeze(0)
            iou = fast_diag_box_iou(det_box_batch, tracker_box_batch)[0]

            # Low overlap with a sufficiently confident detection means
            # the tracker has drifted.
            if (
                iou < self.reconstruction_bbox_iou_thresh
                and det_score >= self.reconstruction_bbox_det_score
            ):
                should_recondition_iou = True
                reconditioned_obj_ids.add(trk_obj_id)

    should_recondition_periodic = (
        self.recondition_every_nth_frame > 0
        and frame_idx % self.recondition_every_nth_frame == 0
        and len(trk_id_to_max_iou_high_conf_det) > 0
    )

    # Either trigger reconditions all tracks in the mapping.
    if should_recondition_periodic or should_recondition_iou:
        self._recondition_masklets(
            frame_idx,
            det_out,
            trk_id_to_max_iou_high_conf_det,
            tracker_states_local,
            tracker_metadata_prev,
            tracker_obj_scores_global,
        )

    # ---- Step 4: overlap suppression and memory update -----------------
    batch_size = tracker_low_res_masks_global.size(0)
    if batch_size > 0:
        if (
            warm_up_complete
            and self.suppress_overlapping_based_on_recent_occlusion_threshold > 0.0
        ):
            # The masks are modified in place and also returned.
            tracker_low_res_masks_global = (
                self._suppress_overlapping_based_on_recent_occlusion(
                    frame_idx,
                    tracker_low_res_masks_global,
                    tracker_metadata_prev,
                    tracker_metadata_new,
                    obj_ids_newly_removed,
                    reverse,
                )
            )

        self._tracker_update_memories(
            tracker_states_local,
            frame_idx,
            tracker_metadata=tracker_metadata_prev,
            low_res_masks=tracker_low_res_masks_global,
        )

    # ---- Step 5: update the metadata ----------------------------------
    # Every rank updates the object lists of all ranks so that the
    # metadata stays identical everywhere.
    removed_obj_ids = list(obj_ids_newly_removed)
    for rank in range(self.world_size):
        new_det_obj_ids_this_gpu = new_det_obj_ids[new_det_gpu_ids == rank]
        updated_obj_ids_this_gpu = tracker_metadata_new["obj_ids_per_gpu"][rank]
        if len(new_det_obj_ids_this_gpu) > 0:
            updated_obj_ids_this_gpu = np.concatenate(
                [updated_obj_ids_this_gpu, new_det_obj_ids_this_gpu]
            )
        if len(obj_ids_newly_removed) > 0:
            is_removed = np.isin(updated_obj_ids_this_gpu, removed_obj_ids)
            updated_obj_ids_this_gpu = updated_obj_ids_this_gpu[~is_removed]
        tracker_metadata_new["obj_ids_per_gpu"][rank] = updated_obj_ids_this_gpu
        tracker_metadata_new["num_obj_per_gpu"][rank] = len(
            updated_obj_ids_this_gpu
        )
    tracker_metadata_new["obj_ids_all_gpu"] = np.concatenate(
        tracker_metadata_new["obj_ids_per_gpu"]
    )

    # New objects: record the detection score both as object score and as
    # this frame's tracker score, then advance the max object id.
    if len(new_det_obj_ids) > 0:
        tracker_metadata_new["obj_id_to_score"].update(
            zip(new_det_obj_ids, det_scores_np[new_det_fa_inds])
        )
        tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][
            frame_idx
        ].update(zip(new_det_obj_ids, det_scores_np[new_det_fa_inds]))
        tracker_metadata_new["max_obj_id"] = max(
            tracker_metadata_new["max_obj_id"],
            np.max(new_det_obj_ids),
        )

    # Removed objects get a very low sentinel score and lose their
    # occlusion record.
    for obj_id in obj_ids_newly_removed:
        tracker_metadata_new["obj_id_to_score"][obj_id] = -1e4
        tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx][
            obj_id
        ] = -1e4
        tracker_metadata_new["obj_id_to_last_occluded"].pop(obj_id, None)

    # "rank0_metadata" must exist on rank 0 and only there.
    assert ("rank0_metadata" in tracker_metadata_new) == (self.rank == 0)
    if self.rank == 0 and self.masklet_confirmation_enable:
        rank0_metadata = self.update_masklet_confirmation_status(
            rank0_metadata=tracker_metadata_new["rank0_metadata"],
            obj_ids_all_gpu_prev=tracker_metadata_prev["obj_ids_all_gpu"],
            obj_ids_all_gpu_updated=tracker_metadata_new["obj_ids_all_gpu"],
            det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
            new_det_obj_ids=new_det_obj_ids,
        )
        tracker_metadata_new["rank0_metadata"] = rank0_metadata

    return tracker_update_plan, tracker_metadata_new
|
|
| def _suppress_overlapping_based_on_recent_occlusion( |
| self, |
| frame_idx: int, |
| tracker_low_res_masks_global: Tensor, |
| tracker_metadata_prev: Dict[str, Any], |
| tracker_metadata_new: Dict[str, Any], |
| obj_ids_newly_removed: Set[int], |
| reverse: bool = False, |
| ): |
| """ |
| Suppress overlapping masks based on the most recent occlusion information. If an object is removed by hotstart, we always suppress it if it overlaps with any other object. |
| Args: |
| frame_idx (int): The current frame index. |
| tracker_low_res_masks_global (Tensor): The low-resolution masks for the current frame. |
| tracker_metadata_prev (Dict[str, Any]): The metadata from the previous frame. |
| tracker_metadata_new (Dict[str, Any]): The metadata for the current frame. |
| obj_ids_newly_removed (Set[int]): The object IDs that have been removed. |
| Return: |
| Tensor: The updated low-resolution masks with some objects suppressed. |
| """ |
| obj_ids_global = tracker_metadata_prev["obj_ids_all_gpu"] |
| binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0 |
| batch_size = tracker_low_res_masks_global.size(0) |
| if batch_size > 0: |
| assert ( |
| len(obj_ids_global) == batch_size |
| ), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}" |
| NEVER_OCCLUDED = -1 |
| ALWAYS_OCCLUDED = 100000 |
| last_occluded_prev = torch.cat( |
| [ |
| tracker_metadata_prev["obj_id_to_last_occluded"].get( |
| obj_id, |
| torch.full( |
| (1,), |
| fill_value=( |
| NEVER_OCCLUDED |
| if obj_id not in obj_ids_newly_removed |
| else ALWAYS_OCCLUDED |
| ), |
| device=binary_tracker_low_res_masks_global.device, |
| dtype=torch.long, |
| ), |
| ) |
| for obj_id in obj_ids_global |
| ], |
| dim=0, |
| ) |
| to_suppress = self._get_objects_to_suppress_based_on_most_recently_occluded( |
| binary_tracker_low_res_masks_global, |
| last_occluded_prev, |
| obj_ids_global, |
| frame_idx, |
| reverse, |
| ) |
|
|
| |
| is_obj_occluded = ~(binary_tracker_low_res_masks_global.any(dim=(-1, -2))) |
| is_obj_occluded_or_suppressed = is_obj_occluded | to_suppress |
| last_occluded_new = last_occluded_prev.clone() |
| last_occluded_new[is_obj_occluded_or_suppressed] = frame_idx |
| |
| tracker_metadata_new["obj_id_to_last_occluded"] = { |
| obj_id: last_occluded_new[obj_idx : obj_idx + 1] |
| for obj_idx, obj_id in enumerate(obj_ids_global) |
| } |
|
|
| |
| NO_OBJ_LOGIT = -10 |
| tracker_low_res_masks_global[to_suppress] = NO_OBJ_LOGIT |
|
|
| return tracker_low_res_masks_global |
|
|
def run_tracker_update_execution_phase(
    self,
    frame_idx: int,
    num_frames: int,
    reverse: bool,
    det_out: Dict[str, Tensor],
    tracker_states_local: List[Any],
    tracker_update_plan: Dict[str, npt.NDArray],
    orig_vid_height: int,
    orig_vid_width: int,
    feature_cache: Dict,
):
    """Apply the update plan to the tracker states held by this rank.

    New detections assigned to this rank (`new_det_gpu_ids == self.rank`)
    are added as new objects using their detection masks, then the objects
    listed in the plan's "obj_ids_newly_removed" are removed.

    Returns:
        The local tracker states: the value returned by
        `_tracker_add_new_objects` when objects were added, otherwise the
        list that was passed in.
    """
    new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"]
    new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"]
    new_det_gpu_ids: npt.NDArray = tracker_update_plan["new_det_gpu_ids"]

    # Restrict the plan to the new objects owned by this rank.
    owned_by_this_rank: npt.NDArray = new_det_gpu_ids == self.rank
    local_obj_ids: npt.NDArray = new_det_obj_ids[owned_by_this_rank]
    local_det_inds: npt.NDArray = new_det_fa_inds[owned_by_this_rank]
    obj_ids_newly_removed: Set[int] = tracker_update_plan["obj_ids_newly_removed"]

    # Additions first ...
    if len(local_det_inds) > 0:
        local_det_masks: Tensor = det_out["mask"][torch.from_numpy(local_det_inds)]
        tracker_states_local = self._tracker_add_new_objects(
            frame_idx=frame_idx,
            num_frames=num_frames,
            new_obj_ids=local_obj_ids,
            new_obj_masks=local_det_masks,
            tracker_states_local=tracker_states_local,
            orig_vid_height=orig_vid_height,
            orig_vid_width=orig_vid_width,
            feature_cache=feature_cache,
        )

    # ... then removals.
    if len(obj_ids_newly_removed) > 0:
        self._tracker_remove_objects(tracker_states_local, obj_ids_newly_removed)

    return tracker_states_local
|
|
def build_outputs(
    self,
    frame_idx: int,
    num_frames: int,
    reverse: bool,
    det_out: Dict[str, Tensor],
    tracker_low_res_masks_global: Tensor,
    tracker_obj_scores_global: Tensor,
    tracker_metadata_prev: Dict[str, npt.NDArray],
    tracker_update_plan: Dict[str, npt.NDArray],
    orig_vid_height: int,
    orig_vid_width: int,
    reconditioned_obj_ids: set = None,
    det_to_matched_trk_obj_ids: dict = None,
):
    """Assemble the video-resolution binary mask of every object on this frame.

    Three sources are merged, later ones overriding earlier ones:
      1. propagated masklets (`tracker_low_res_masks_global`), keyed by
         `tracker_metadata_prev["obj_ids_all_gpu"]`;
      2. newly detected objects from the update plan, after hole filling;
      3. for each id in `reconditioned_obj_ids` that has an entry in the
         plan's "trk_id_to_max_iou_high_conf_det", the matched detection
         mask.

    All masks are bilinearly resized to `(orig_vid_height, orig_vid_width)`
    and thresholded at 0. `frame_idx`, `num_frames`, `reverse`,
    `tracker_obj_scores_global` and `det_to_matched_trk_obj_ids` are not
    used by this method.

    Returns:
        dict mapping object id to its boolean mask tensor.
    """
    video_size = (orig_vid_height, orig_vid_width)
    new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"]
    new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"]

    # 1) Existing masklets, upsampled to the video resolution.
    existing_masklet_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"]
    existing_masklet_binary = (
        F.interpolate(
            tracker_low_res_masks_global.unsqueeze(1),
            size=video_size,
            mode="bilinear",
            align_corners=False,
        )
        > 0
    )
    assert len(existing_masklet_obj_ids) == len(existing_masklet_binary)
    obj_id_to_mask = dict(zip(existing_masklet_obj_ids, existing_masklet_binary))

    # 2) New detections: clean the low-res masks, then upsample.
    new_det_low_res_masks = fill_holes_in_mask_scores(
        det_out["mask"][torch.from_numpy(new_det_fa_inds)].unsqueeze(1),
        max_area=self.fill_hole_area,
        fill_holes=True,
        remove_sprinkles=True,
    )
    new_masklet_video_res_masks = F.interpolate(
        new_det_low_res_masks,
        size=video_size,
        mode="bilinear",
        align_corners=False,
    )
    new_masklet_binary = new_masklet_video_res_masks > 0
    assert len(new_det_obj_ids) == len(new_masklet_video_res_masks)
    obj_id_to_mask.update(zip(new_det_obj_ids, new_masklet_binary))

    # 3) Reconditioned objects: show the detection mask for this frame.
    if reconditioned_obj_ids is None or len(reconditioned_obj_ids) == 0:
        return obj_id_to_mask

    trk_id_to_max_iou_high_conf_det = tracker_update_plan.get(
        "trk_id_to_max_iou_high_conf_det", {}
    )
    for obj_id in reconditioned_obj_ids:
        det_idx = trk_id_to_max_iou_high_conf_det.get(obj_id)
        if det_idx is None:
            continue
        # Add batch and channel dims for interpolation, then drop the
        # batch dim so the result has the same layout as the other masks.
        det_mask = det_out["mask"][det_idx][None, None]
        det_mask_resized = (
            F.interpolate(
                det_mask.float(),
                size=video_size,
                mode="bilinear",
                align_corners=False,
            )
            > 0
        )
        obj_id_to_mask[obj_id] = det_mask_resized.squeeze(0)

    return obj_id_to_mask
|
|
| def _get_objects_to_suppress_based_on_most_recently_occluded( |
| self, |
| binary_low_res_masks: Tensor, |
| last_occluded: List[int], |
| obj_ids: List[int], |
| frame_idx: int = None, |
| reverse: bool = False, |
| ): |
| |
| assert ( |
| binary_low_res_masks.dtype == torch.bool |
| ), f"Expected boolean tensor, got {binary_low_res_masks.dtype}" |
| to_suppress = torch.zeros( |
| binary_low_res_masks.size(0), |
| device=binary_low_res_masks.device, |
| dtype=torch.bool, |
| ) |
| if len(obj_ids) <= 1: |
| return to_suppress |
|
|
| iou = mask_iou(binary_low_res_masks, binary_low_res_masks) |
|
|
| |
| mask_iou_thresh = ( |
| iou >= self.suppress_overlapping_based_on_recent_occlusion_threshold |
| ) |
| overlapping_pairs = torch.triu(mask_iou_thresh, diagonal=1) |
|
|
| last_occ_expanded_i = last_occluded.unsqueeze(1) |
| last_occ_expanded_j = last_occluded.unsqueeze(0) |
| |
| cmp_op = torch.gt if not reverse else torch.lt |
| suppress_i_mask = ( |
| overlapping_pairs |
| & cmp_op( |
| last_occ_expanded_i, last_occ_expanded_j |
| ) |
| & ( |
| last_occ_expanded_j > -1 |
| ) |
| ) |
| suppress_j_mask = ( |
| overlapping_pairs |
| & cmp_op(last_occ_expanded_j, last_occ_expanded_i) |
| & ( |
| last_occ_expanded_i > -1 |
| ) |
| ) |
| |
| to_suppress = suppress_i_mask.any(dim=1) | suppress_j_mask.any(dim=0) |
|
|
| |
| if ( |
| self.rank == 0 |
| and logger.isEnabledFor(logging.DEBUG) |
| and frame_idx is not None |
| ): |
| suppress_i_mask = suppress_i_mask.cpu().numpy() |
| suppress_j_mask = suppress_j_mask.cpu().numpy() |
| last_occluded = last_occluded.cpu().numpy() |
|
|
| |
| batch_size = suppress_i_mask.shape[0] |
|
|
| |
| for i in range(batch_size): |
| for j in range(batch_size): |
| if suppress_i_mask[i, j]: |
| logger.debug( |
| f"{frame_idx=}: Suppressing obj {obj_ids[i]} last occluded {last_occluded[i]} in favor of {obj_ids[j]} last occluded {last_occluded[j]}" |
| ) |
|
|
| |
| for i in range(batch_size): |
| for j in range(batch_size): |
| if suppress_j_mask[i, j]: |
| logger.debug( |
| f"{frame_idx=}: Suppressing obj {obj_ids[j]} last occluded {last_occluded[j]} in favor of {obj_ids[i]} last occluded {last_occluded[i]}" |
| ) |
|
|
| return to_suppress |
|
|
| def _propogate_tracker_one_frame_local_gpu( |
| self, |
| inference_states: List[Any], |
| frame_idx: int, |
| reverse: bool, |
| |
| run_mem_encoder: bool = False, |
| ): |
| """ |
| inference_states: List of inference states, each state corresponds to a different set of objects. |
| """ |
| obj_ids_local = [] |
| low_res_masks_list = [] |
| obj_scores_list = [] |
| for inference_state in inference_states: |
| if len(inference_state["obj_ids"]) == 0: |
| continue |
|
|
| |
| num_frames_propagated = 0 |
| for out in self.tracker.propagate_in_video( |
| inference_state, |
| start_frame_idx=frame_idx, |
| |
| |
| max_frame_num_to_track=0, |
| reverse=reverse, |
| tqdm_disable=True, |
| run_mem_encoder=run_mem_encoder, |
| ): |
| out_frame_idx, out_obj_ids, out_low_res_masks, _, out_obj_scores = out |
| num_frames_propagated += 1 |
|
|
| |
| assert ( |
| num_frames_propagated == 1 and out_frame_idx == frame_idx |
| ), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}" |
| assert isinstance(out_obj_ids, list) |
| obj_ids_local.extend(out_obj_ids) |
| low_res_masks_list.append(out_low_res_masks.squeeze(1)) |
| obj_scores_list.append(out_obj_scores.squeeze(1)) |
|
|
| |
| H_mask = W_mask = self.tracker.low_res_mask_size |
| if len(low_res_masks_list) > 0: |
| low_res_masks_local = torch.cat(low_res_masks_list, dim=0) |
| obj_scores_local = torch.cat(obj_scores_list, dim=0) |
| assert low_res_masks_local.shape[1:] == (H_mask, W_mask) |
|
|
| |
| low_res_masks_local = fill_holes_in_mask_scores( |
| low_res_masks_local.unsqueeze(1), |
| max_area=self.fill_hole_area, |
| fill_holes=True, |
| remove_sprinkles=True, |
| ) |
| low_res_masks_local = low_res_masks_local.squeeze(1) |
| else: |
| low_res_masks_local = torch.zeros(0, H_mask, W_mask, device=self.device) |
| obj_scores_local = torch.zeros(0, device=self.device) |
|
|
| return obj_ids_local, low_res_masks_local, obj_scores_local |
|
|
| def _associate_det_trk( |
| self, |
| det_masks: Tensor, |
| det_scores_np: npt.NDArray, |
| trk_masks: Tensor, |
| trk_obj_ids: npt.NDArray, |
| ): |
| """ |
| Match detections on the current frame with the existing masklets. |
| |
| Args: |
| - det_masks: (N, H, W) tensor of predicted masks |
| - det_scores_np: (N,) array of detection scores |
| - trk_masks: (M, H, W) tensor of track masks |
| - trk_obj_ids: (M,) array of object IDs corresponding to trk_masks |
| |
| Returns: |
| - new_det_fa_inds: array of new object indices. |
| - unmatched_trk_obj_ids: array of existing masklet object IDs that are not matched |
| to any detections on this frame (for unmatched, we only count masklets with >0 area) |
| - det_to_matched_trk_obj_ids: dict[int, npt.NDArray]: mapping from detector's detection indices |
| to the list of matched tracklet object IDs |
| - empty_trk_obj_ids: array of existing masklet object IDs with zero area in SAM2 prediction |
| """ |
| iou_threshold = self.assoc_iou_thresh |
| iou_threshold_trk = self.trk_assoc_iou_thresh |
| new_det_thresh = self.new_det_thresh |
|
|
| assert det_masks.is_floating_point(), "float tensor expected (do not binarize)" |
| assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)" |
| assert ( |
| trk_masks.size(0) == len(trk_obj_ids) |
| ), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}" |
| if trk_masks.size(0) == 0: |
| |
| new_det_fa_inds = np.arange(det_masks.size(0)) |
| unmatched_trk_obj_ids = np.array([], np.int64) |
| empty_trk_obj_ids = np.array([], np.int64) |
| det_to_matched_trk_obj_ids = {} |
| trk_id_to_max_iou_high_conf_det = {} |
| return ( |
| new_det_fa_inds, |
| unmatched_trk_obj_ids, |
| det_to_matched_trk_obj_ids, |
| trk_id_to_max_iou_high_conf_det, |
| empty_trk_obj_ids, |
| ) |
| elif det_masks.size(0) == 0: |
| |
| new_det_fa_inds = np.array([], np.int64) |
| trk_is_nonempty = (trk_masks > 0).any(dim=(1, 2)).cpu().numpy() |
| unmatched_trk_obj_ids = trk_obj_ids[trk_is_nonempty] |
| empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty] |
| det_to_matched_trk_obj_ids = {} |
| trk_id_to_max_iou_high_conf_det = {} |
| return ( |
| new_det_fa_inds, |
| unmatched_trk_obj_ids, |
| det_to_matched_trk_obj_ids, |
| trk_id_to_max_iou_high_conf_det, |
| empty_trk_obj_ids, |
| ) |
|
|
| if det_masks.shape[-2:] != trk_masks.shape[-2:]: |
| |
| if np.prod(det_masks.shape[-2:]) < np.prod(trk_masks.shape[-2:]): |
| trk_masks = F.interpolate( |
| trk_masks.unsqueeze(1), |
| size=det_masks.shape[-2:], |
| mode="bilinear", |
| align_corners=False, |
| ).squeeze(1) |
| else: |
| |
| det_masks = F.interpolate( |
| det_masks.unsqueeze(1), |
| size=trk_masks.shape[-2:], |
| mode="bilinear", |
| align_corners=False, |
| ).squeeze(1) |
|
|
| det_masks_binary = det_masks > 0 |
| trk_masks_binary = trk_masks > 0 |
| ious = mask_iou(det_masks_binary, trk_masks_binary) |
|
|
| ious_np = ious.cpu().numpy() |
| if self.o2o_matching_masklets_enable: |
| from scipy.optimize import linear_sum_assignment |
|
|
| |
| cost_matrix = 1 - ious_np |
| row_ind, col_ind = linear_sum_assignment(cost_matrix) |
| trk_is_matched = np.zeros(trk_masks.size(0), dtype=bool) |
| for d, t in zip(row_ind, col_ind): |
| if ious_np[d, t] >= iou_threshold_trk: |
| trk_is_matched[t] = True |
| else: |
| trk_is_matched = (ious_np >= iou_threshold_trk).any(axis=0) |
| |
| trk_is_nonempty = trk_masks_binary.any(dim=(1, 2)).cpu().numpy() |
| trk_is_unmatched = np.logical_and(trk_is_nonempty, ~trk_is_matched) |
| unmatched_trk_obj_ids = trk_obj_ids[trk_is_unmatched] |
| |
| empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty] |
|
|
| |
| |
| is_new_det = np.logical_and( |
| det_scores_np >= new_det_thresh, |
| np.logical_not(np.any(ious_np >= iou_threshold, axis=1)), |
| ) |
| new_det_fa_inds = np.nonzero(is_new_det)[0] |
|
|
| |
| det_to_matched_trk_obj_ids = {} |
| trk_id_to_max_iou_high_conf_det = {} |
| HIGH_CONF_THRESH = 0.8 |
| HIGH_IOU_THRESH = 0.8 |
| det_to_max_iou_trk_idx = np.argmax(ious_np, axis=1) |
| det_is_high_conf = (det_scores_np >= HIGH_CONF_THRESH) & ~is_new_det |
| det_is_high_iou = np.max(ious_np, axis=1) >= HIGH_IOU_THRESH |
| det_is_high_conf_and_iou = set( |
| np.nonzero(det_is_high_conf & det_is_high_iou)[0] |
| ) |
| for d in range(det_masks.size(0)): |
| det_to_matched_trk_obj_ids[d] = trk_obj_ids[ious_np[d, :] >= iou_threshold] |
| if d in det_is_high_conf_and_iou: |
| trk_obj_id = trk_obj_ids[det_to_max_iou_trk_idx[d]].item() |
| trk_id_to_max_iou_high_conf_det[trk_obj_id] = d |
|
|
| return ( |
| new_det_fa_inds, |
| unmatched_trk_obj_ids, |
| det_to_matched_trk_obj_ids, |
| trk_id_to_max_iou_high_conf_det, |
| empty_trk_obj_ids, |
| ) |
|
|
| def _assign_new_det_to_gpus(self, new_det_num, prev_workload_per_gpu): |
| """Distribute the new objects to the GPUs with the least workload.""" |
| workload_per_gpu: npt.NDArray = prev_workload_per_gpu.copy() |
| new_det_gpu_ids = np.zeros(new_det_num, np.int64) |
|
|
| |
| for i in range(len(new_det_gpu_ids)): |
| |
| min_gpu = np.argmin(workload_per_gpu) |
| new_det_gpu_ids[i] = min_gpu |
| workload_per_gpu[min_gpu] += 1 |
| return new_det_gpu_ids |
|
|
| def _process_hotstart( |
| self, |
| frame_idx: int, |
| num_frames: int, |
| reverse: bool, |
| det_to_matched_trk_obj_ids: Dict[int, npt.NDArray], |
| new_det_obj_ids: npt.NDArray, |
| empty_trk_obj_ids: npt.NDArray, |
| unmatched_trk_obj_ids: npt.NDArray, |
| rank0_metadata: Dict[str, Any], |
| tracker_metadata: Dict[str, Any], |
| ): |
| """Handle hotstart heuristics to remove unmatched or duplicated objects.""" |
| |
| obj_first_frame_idx = rank0_metadata["obj_first_frame_idx"] |
| |
| unmatched_frame_inds = rank0_metadata["unmatched_frame_inds"] |
| trk_keep_alive = rank0_metadata["trk_keep_alive"] |
| |
| overlap_pair_to_frame_inds = rank0_metadata["overlap_pair_to_frame_inds"] |
| |
| removed_obj_ids = rank0_metadata["removed_obj_ids"] |
| suppressed_obj_ids = rank0_metadata["suppressed_obj_ids"][frame_idx] |
|
|
| obj_ids_newly_removed = set() |
| hotstart_diff = ( |
| frame_idx - self.hotstart_delay |
| if not reverse |
| else frame_idx + self.hotstart_delay |
| ) |
|
|
| |
| for obj_id in new_det_obj_ids: |
| if obj_id not in obj_first_frame_idx: |
| obj_first_frame_idx[obj_id] = frame_idx |
| assert obj_id not in trk_keep_alive |
| trk_keep_alive[obj_id] = self.init_trk_keep_alive |
|
|
| matched_trks = set() |
| |
| for matched_trks_per_det in det_to_matched_trk_obj_ids.values(): |
| matched_trks.update(matched_trks_per_det) |
| for obj_id in matched_trks: |
| |
| trk_keep_alive[obj_id] = min( |
| self.max_trk_keep_alive, trk_keep_alive[obj_id] + 1 |
| ) |
| for obj_id in unmatched_trk_obj_ids: |
| unmatched_frame_inds[obj_id].append(frame_idx) |
| |
| |
| trk_keep_alive[obj_id] = max( |
| self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1 |
| ) |
| if self.decrease_trk_keep_alive_for_empty_masklets: |
| for obj_id in empty_trk_obj_ids: |
| |
| trk_keep_alive[obj_id] = max( |
| self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1 |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| for obj_id, frame_indices in unmatched_frame_inds.items(): |
| if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed: |
| continue |
| if len(frame_indices) >= self.hotstart_unmatch_thresh: |
| is_within_hotstart = ( |
| obj_first_frame_idx[obj_id] > hotstart_diff and not reverse |
| ) or (obj_first_frame_idx[obj_id] < hotstart_diff and reverse) |
| if is_within_hotstart: |
| obj_ids_newly_removed.add(obj_id) |
| logger.debug( |
| f"Removing object {obj_id} at frame {frame_idx} " |
| f"since it is unmatched for frames: {frame_indices}" |
| ) |
| if ( |
| trk_keep_alive[obj_id] <= 0 |
| and not self.suppress_unmatched_only_within_hotstart |
| and obj_id not in removed_obj_ids |
| and obj_id not in obj_ids_newly_removed |
| ): |
| logger.debug( |
| f"Suppressing object {obj_id} at frame {frame_idx}, due to being unmatched" |
| ) |
| suppressed_obj_ids.add(obj_id) |
|
|
| |
| |
| for _, matched_trk_obj_ids in det_to_matched_trk_obj_ids.items(): |
| if len(matched_trk_obj_ids) < 2: |
| continue |
| |
| |
| first_appear_obj_id = ( |
| min(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x]) |
| if not reverse |
| else max(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x]) |
| ) |
| for obj_id in matched_trk_obj_ids: |
| if obj_id != first_appear_obj_id: |
| key = (first_appear_obj_id, obj_id) |
| overlap_pair_to_frame_inds[key].append(frame_idx) |
|
|
| |
| |
| for (first_obj_id, obj_id), frame_indices in overlap_pair_to_frame_inds.items(): |
| if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed: |
| continue |
| if (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or ( |
| obj_first_frame_idx[obj_id] < hotstart_diff and reverse |
| ): |
| if len(frame_indices) >= self.hotstart_dup_thresh: |
| obj_ids_newly_removed.add(obj_id) |
| logger.debug( |
| f"Removing object {obj_id} at frame {frame_idx} " |
| f"since it overlaps with another track {first_obj_id} at frames: {frame_indices}" |
| ) |
|
|
| removed_obj_ids.update(obj_ids_newly_removed) |
| return obj_ids_newly_removed, rank0_metadata |
|
|
| def _tracker_update_memories( |
| self, |
| tracker_inference_states: List[Any], |
| frame_idx: int, |
| tracker_metadata: Dict[str, Any], |
| low_res_masks: Tensor, |
| ): |
| """ |
| Run Sam2 memory encoder, enforcing non-overlapping constraints globally. |
| """ |
| if len(tracker_inference_states) == 0: |
| return |
| |
| high_res_H, high_res_W = ( |
| self.tracker.maskmem_backbone.mask_downsampler.interpol_size |
| ) |
| |
| high_res_masks = F.interpolate( |
| low_res_masks.unsqueeze(1), |
| size=(high_res_H, high_res_W), |
| mode="bilinear", |
| align_corners=False, |
| ) |
| |
| if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: |
| high_res_masks = self.tracker._suppress_object_pw_area_shrinkage( |
| high_res_masks |
| ) |
| |
| object_score_logits = torch.where( |
| (high_res_masks > 0).any(dim=(-1, -2)), 10.0, -10.0 |
| ) |
|
|
| |
| start_idx_gpu = sum(tracker_metadata["num_obj_per_gpu"][: self.rank]) |
| start_idx_state = start_idx_gpu |
| for tracker_state in tracker_inference_states: |
| num_obj_per_state = len(tracker_state["obj_ids"]) |
| if num_obj_per_state == 0: |
| continue |
| |
| end_idx_state = start_idx_state + num_obj_per_state |
| local_high_res_masks = high_res_masks[start_idx_state:end_idx_state] |
| local_object_score_logits = object_score_logits[ |
| start_idx_state:end_idx_state |
| ] |
| local_batch_size = local_high_res_masks.size(0) |
| |
|
|
| encoded_mem = self.tracker._run_memory_encoder( |
| tracker_state, |
| frame_idx, |
| local_batch_size, |
| local_high_res_masks, |
| local_object_score_logits, |
| is_mask_from_pts=False, |
| ) |
| local_maskmem_features, local_maskmem_pos_enc = encoded_mem |
| |
| output_dict = tracker_state["output_dict"] |
| for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: |
| if frame_idx not in output_dict[storage_key]: |
| continue |
| output_dict[storage_key][frame_idx]["maskmem_features"] = ( |
| local_maskmem_features |
| ) |
| output_dict[storage_key][frame_idx]["maskmem_pos_enc"] = [ |
| pos for pos in local_maskmem_pos_enc |
| ] |
| |
| |
| self.tracker._add_output_per_object( |
| inference_state=tracker_state, |
| frame_idx=frame_idx, |
| current_out=output_dict[storage_key][frame_idx], |
| storage_key=storage_key, |
| ) |
| start_idx_state += num_obj_per_state |
|
|
| def _tracker_add_new_objects( |
| self, |
| frame_idx: int, |
| num_frames: int, |
| new_obj_ids: List[int], |
| new_obj_masks: Tensor, |
| tracker_states_local: List[Any], |
| orig_vid_height: int, |
| orig_vid_width: int, |
| feature_cache: Dict, |
| ): |
| """Add a new object to SAM2 inference states.""" |
| prev_tracker_state = ( |
| tracker_states_local[0] if len(tracker_states_local) > 0 else None |
| ) |
|
|
| |
| |
| |
| new_tracker_state = self.tracker.init_state( |
| cached_features=feature_cache, |
| video_height=orig_vid_height, |
| video_width=orig_vid_width, |
| num_frames=num_frames, |
| ) |
| new_tracker_state["backbone_out"] = ( |
| prev_tracker_state.get("backbone_out", None) |
| if prev_tracker_state is not None |
| else None |
| ) |
|
|
| assert len(new_obj_ids) == new_obj_masks.size(0) |
| assert new_obj_masks.is_floating_point() |
| input_mask_res = self.tracker.input_mask_size |
| new_obj_masks = F.interpolate( |
| new_obj_masks.unsqueeze(1), |
| size=(input_mask_res, input_mask_res), |
| mode="bilinear", |
| align_corners=False, |
| ).squeeze(1) |
| new_obj_masks = new_obj_masks > 0 |
|
|
| |
| for new_obj_id, new_mask in zip(new_obj_ids, new_obj_masks): |
| self.tracker.add_new_mask( |
| inference_state=new_tracker_state, |
| frame_idx=frame_idx, |
| obj_id=new_obj_id, |
| mask=new_mask, |
| add_mask_to_memory=True, |
| ) |
| |
| self.tracker.propagate_in_video_preflight( |
| new_tracker_state, run_mem_encoder=True |
| ) |
| tracker_states_local.append(new_tracker_state) |
| return tracker_states_local |
|
|
| def _tracker_remove_object(self, tracker_states_local: List[Any], obj_id: int): |
| """ |
| Remove an object from SAM2 inference states. This would remove the object from |
| all frames in the video. |
| """ |
| tracker_states_local_before_removal = tracker_states_local.copy() |
| tracker_states_local.clear() |
| for tracker_inference_state in tracker_states_local_before_removal: |
| |
| |
| new_obj_ids, _ = self.tracker.remove_object( |
| tracker_inference_state, obj_id, strict=False, need_output=False |
| ) |
| |
| if len(new_obj_ids) > 0: |
| tracker_states_local.append(tracker_inference_state) |
|
|
| def _tracker_remove_objects( |
| self, tracker_states_local: List[Any], obj_ids: list[int] |
| ): |
| """ |
| Remove an object from SAM2 inference states. This would remove the object from |
| all frames in the video. |
| """ |
| for obj_id in obj_ids: |
| self._tracker_remove_object(tracker_states_local, obj_id) |
|
|
| def _initialize_metadata(self): |
| """Initialize metadata for the masklets.""" |
| tracker_metadata = { |
| "obj_ids_per_gpu": [np.array([], np.int64) for _ in range(self.world_size)], |
| "obj_ids_all_gpu": np.array([], np.int64), |
| "num_obj_per_gpu": np.zeros(self.world_size, np.int64), |
| "max_obj_id": -1, |
| "obj_id_to_score": {}, |
| "obj_id_to_tracker_score_frame_wise": defaultdict(dict), |
| "obj_id_to_last_occluded": {}, |
| } |
| if self.rank == 0: |
| |
| |
| |
| |
| |
| rank0_metadata = { |
| "obj_first_frame_idx": {}, |
| "unmatched_frame_inds": defaultdict(list), |
| "trk_keep_alive": defaultdict( |
| int |
| ), |
| "overlap_pair_to_frame_inds": defaultdict(list), |
| "removed_obj_ids": set(), |
| "suppressed_obj_ids": defaultdict( |
| set |
| ), |
| } |
| if self.masklet_confirmation_enable: |
| |
| rank0_metadata["masklet_confirmation"] = { |
| |
| "status": np.array([], np.int64), |
| |
| |
| "consecutive_det_num": np.array([], np.int64), |
| } |
| tracker_metadata["rank0_metadata"] = rank0_metadata |
|
|
| return tracker_metadata |
|
|
def update_masklet_confirmation_status(
    self,
    rank0_metadata: Dict[str, Any],
    obj_ids_all_gpu_prev: npt.NDArray,
    obj_ids_all_gpu_updated: npt.NDArray,
    det_to_matched_trk_obj_ids: Dict[int, npt.NDArray],
    new_det_obj_ids: npt.NDArray,
):
    """
    Carry the masklet confirmation state over to the updated object list and refresh it.

    Objects present in both ID arrays keep their previous status and
    consecutive-detection count; objects that are new start as UNCONFIRMED with a
    count of 0. The count is then increased by one for every object that is new on
    this frame or matched by some detection, and reset to 0 otherwise. Objects whose
    count reaches `self.masklet_confirmation_consecutive_det_thresh` become CONFIRMED;
    a status is never moved back to UNCONFIRMED here.

    Args:
    - rank0_metadata: dict holding "masklet_confirmation" (updated in place)
    - obj_ids_all_gpu_prev: object IDs that the stored arrays are aligned with
    - obj_ids_all_gpu_updated: object IDs the new arrays should be aligned with
    - det_to_matched_trk_obj_ids: detection index -> matched track object IDs
    - new_det_obj_ids: object IDs created on this frame

    Returns:
    - rank0_metadata, the same dict that was passed in
    """
    confirmation_data = rank0_metadata["masklet_confirmation"]

    status_prev = confirmation_data["status"]
    consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
    assert (
        status_prev.shape == obj_ids_all_gpu_prev.shape
    ), f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"

    # Re-index: find, for each surviving previous object, its slot in the updated array.
    slot_in_updated = {
        obj_id: slot for slot, obj_id in enumerate(obj_ids_all_gpu_updated)
    }
    survives = np.isin(obj_ids_all_gpu_prev, obj_ids_all_gpu_updated)
    target_slots = np.array(
        [slot_in_updated[obj_id] for obj_id in obj_ids_all_gpu_prev[survives]],
        dtype=np.int64,
    )

    # Start from defaults (UNCONFIRMED, zero count), then copy the survivors over.
    status = np.full_like(
        obj_ids_all_gpu_updated,
        fill_value=MaskletConfirmationStatus.UNCONFIRMED.value,
    )
    consecutive_det_num = np.zeros_like(obj_ids_all_gpu_updated)
    status[target_slots] = status_prev[survives]
    consecutive_det_num[target_slots] = consecutive_det_num_prev[survives]

    # An object counts as detected on this frame if it is new or matched by a detection.
    is_matched = np.isin(obj_ids_all_gpu_updated, new_det_obj_ids)
    for matched_trk_obj_ids in det_to_matched_trk_obj_ids.values():
        is_matched = is_matched | np.isin(
            obj_ids_all_gpu_updated, matched_trk_obj_ids
        )
    # Extend the streak for detected objects, break it for the others.
    consecutive_det_num = np.where(is_matched, consecutive_det_num + 1, 0)

    # Promote objects whose streak is long enough.
    reached_thresh = (
        consecutive_det_num >= self.masklet_confirmation_consecutive_det_thresh
    )
    status[reached_thresh] = MaskletConfirmationStatus.CONFIRMED.value

    confirmation_data["status"] = status
    confirmation_data["consecutive_det_num"] = consecutive_det_num
    return rank0_metadata
|
|
def forward(self, input: BatchedDatapoint, is_inference: bool = False):
    """
    Not supported: this method always raises.

    Raises:
    - NotImplementedError: unconditionally, regardless of the arguments
    """
    reason = "Evaluation outside demo is not implemented yet"
    raise NotImplementedError(reason)
|
|
def _load_checkpoint(self, ckpt_path: str, strict: bool = True):
    """
    Load the weights stored under the "model" key of a checkpoint file into this module.

    Args:
    - ckpt_path: path of the checkpoint, read with `torch.load` onto the CPU with
      `weights_only=True`
    - strict: forwarded to `load_state_dict`

    Missing or unexpected keys are reported with a warning; a clean load is
    reported at info level.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    missing_keys, unexpected_keys = self.load_state_dict(
        checkpoint["model"], strict=strict
    )
    if not missing_keys and not unexpected_keys:
        logger.info("Loaded ckpt successfully without missing or unexpected keys")
    else:
        logger.warning(f"Loaded ckpt with {missing_keys=}, {unexpected_keys=}")
|
|
def prep_for_evaluator(self, video_frames, tracking_res, scores_labels):
    """
    This method is only used for benchmark eval (not used in the demo).

    Args:
    - video_frames: sequence of frames; the first frame's `.size` gives (width, height)
    - tracking_res: dict frame index -> dict object ID -> mask; frames or objects
      that are absent are filled with an all-False (1, height, width) mask
    - scores_labels: dict object ID -> (score, label), where `score` supports `.item()`

    Returns a dict with:
    - "scores": tensor with one score per object
    - "labels": tensor with one label per object
    - "boxes": stacked per-object boxes from `mask_to_box`, or an empty
      (0, num_frames, 4) tensor when there are no objects
    - "masks_rle": list with one `rle_encode` result per object
    - "per_frame_scores": the same tensor object as "scores"
    """
    num_frames = len(video_frames)
    w, h = video_frames[0].size
    zero_mask = torch.zeros((1, h, w), dtype=torch.bool)

    scores, labels, boxes, masks_rle = [], [], [], []
    for oid in list(scores_labels.keys()):
        o_score = scores_labels[oid][0].item()
        o_label = scores_labels[oid][1]

        # One mask per frame, falling back to the empty mask where nothing was tracked.
        per_frame_masks = [
            tracking_res[frame_idx].get(oid, zero_mask)
            if frame_idx in tracking_res
            else zero_mask
            for frame_idx in range(num_frames)
        ]
        o_masks = torch.cat(per_frame_masks, dim=0)

        scores.append(o_score)
        labels.append(o_label)
        boxes.append(mask_to_box(o_masks.unsqueeze(1)).squeeze())
        masks_rle.append(rle_encode(o_masks, return_areas=True))

    preds = {
        "scores": scores,
        "labels": labels,
        "boxes": boxes,
        "masks_rle": masks_rle,
    }

    # Convert the collected lists to tensors, with explicit empty tensors when
    # there are no objects.
    if len(boxes) > 0:
        preds["boxes"] = torch.stack(boxes, dim=0)
    else:
        preds["boxes"] = torch.empty(
            (0, num_frames, 4), dtype=torch.float32, device=self.device
        )

    if len(scores) > 0:
        preds["scores"] = torch.tensor(scores, device=self.device)
    else:
        preds["scores"] = torch.empty((0,), device=self.device)
    preds["per_frame_scores"] = preds["scores"]

    if len(labels) > 0:
        preds["labels"] = torch.tensor(labels, device=self.device)
    else:
        preds["labels"] = torch.empty((0,), device=self.device)

    return preds
|
|
| def _encode_prompt(self, **kwargs): |
| return self.detector._encode_prompt(**kwargs) |
|
|
| def _drop_new_det_with_obj_limit(self, new_det_fa_inds, det_scores_np, num_to_keep): |
| """ |
| Drop a few new detections based on the maximum number of objects. We drop new objects based |
| on their detection scores, keeping the high-scoring ones and dropping the low-scoring ones. |
| """ |
| assert 0 <= num_to_keep <= len(new_det_fa_inds) |
| if num_to_keep == 0: |
| return np.array([], np.int64) |
| if num_to_keep == len(new_det_fa_inds): |
| return new_det_fa_inds |
|
|
| |
| score_order = np.argsort(det_scores_np[new_det_fa_inds])[::-1] |
| new_det_fa_inds = new_det_fa_inds[score_order[:num_to_keep]] |
| return new_det_fa_inds |
|
|