| import logging |
| from collections import OrderedDict |
| from copy import deepcopy |
| from typing import Iterable, Optional |
|
|
| import numpy as np |
| import torch |
| from ..model.data_misc import NestedTensor |
| from ..model.device_utils import get_accelerator_device |
| from ..model.io_utils import load_video_frames |
| from ..model.multiplex_utils import MultiplexState |
| from ..model.sam3_tracker_utils import fill_holes_in_mask_scores |
| from ..model.video_tracking_multiplex import ( |
| concat_points, |
| NO_OBJ_SCORE, |
| VideoTrackingDynamicMultiplex, |
| ) |
| from tqdm import tqdm |
|
|
|
|
| class VideoTrackingMultiplexDemo(VideoTrackingDynamicMultiplex): |
| """ |
| The demo class that extends the `VideoTrackingDynamicMultiplex` to handle user interactions |
| and manage inference states, with support for multi-object tracking. |
| |
| Interactions are not yet implemented. |
| """ |
|
|
| def __init__( |
| self, |
| |
| |
| clear_non_cond_mem_around_input=False, |
| |
| clear_non_cond_mem_for_multi_obj=False, |
| |
| fill_hole_area=0, |
| |
| |
| always_start_from_first_ann_frame=False, |
| |
| |
| |
| |
| max_point_num_in_prompt_enc=16, |
| non_overlap_masks_for_output=True, |
| **kwargs, |
| ): |
| super().__init__(**kwargs) |
|
|
| self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input |
| self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj |
| self.fill_hole_area = fill_hole_area |
| self.always_start_from_first_ann_frame = always_start_from_first_ann_frame |
| self.max_point_num_in_prompt_enc = max_point_num_in_prompt_enc |
| self.non_overlap_masks_for_output = non_overlap_masks_for_output |
|
|
| @torch.inference_mode() |
| def init_state( |
| self, |
| video_path, |
| offload_video_to_cpu, |
| offload_state_to_cpu, |
| async_loading_frames=False, |
| use_cv2=False, |
| ): |
| """Initialize a inference state.""" |
| |
| |
| |
| if not self.apply_sigmoid_to_mask_logits_for_mem_enc: |
| raise NotImplementedError( |
| "Multi-object tracking requires sigmoid in memory encoder for non-overlapping constraints." |
| ) |
|
|
| images, video_height, video_width = load_video_frames( |
| video_path=video_path, |
| image_size=self.image_size, |
| offload_video_to_cpu=offload_video_to_cpu, |
| async_loading_frames=async_loading_frames, |
| use_cv2=use_cv2, |
| ) |
| inference_state = {} |
| inference_state["images"] = images |
| inference_state["num_frames"] = len(images) |
| |
| |
| inference_state["offload_video_to_cpu"] = offload_video_to_cpu |
| |
| |
| |
| |
| inference_state["offload_state_to_cpu"] = offload_state_to_cpu |
| |
| inference_state["video_height"] = video_height |
| inference_state["video_width"] = video_width |
| inference_state["device"] = get_accelerator_device() |
| if offload_state_to_cpu: |
| inference_state["storage_device"] = torch.device("cpu") |
| else: |
| inference_state["storage_device"] = get_accelerator_device() |
| |
| inference_state["point_inputs_per_obj"] = {} |
| inference_state["mask_inputs_per_obj"] = {} |
| |
| inference_state["cached_features"] = {} |
| |
| inference_state["constants"] = {} |
| |
| inference_state["obj_id_to_idx"] = OrderedDict() |
| inference_state["obj_idx_to_id"] = OrderedDict() |
| inference_state["obj_ids"] = [] |
| |
| inference_state["output_dict"] = { |
| "cond_frame_outputs": {}, |
| "non_cond_frame_outputs": {}, |
| } |
| |
| inference_state["first_ann_frame_idx"] = None |
| |
| inference_state["output_dict_per_obj"] = {} |
| |
| |
| inference_state["temp_output_dict_per_obj"] = {} |
| |
| |
| inference_state["consolidated_frame_inds"] = { |
| "cond_frame_outputs": set(), |
| "non_cond_frame_outputs": set(), |
| } |
| |
| inference_state["tracking_has_started"] = False |
| inference_state["frames_already_tracked"] = {} |
| inference_state["multiplex_state"] = None |
| |
| |
| inference_state["user_refined_frames_per_obj"] = {} |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| return inference_state |
|
|
| def _obj_id_to_idx(self, inference_state, obj_id, error_if_new=False): |
| """Map client-side object id to model-side object index.""" |
| obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) |
| if obj_idx is not None: |
| return obj_idx |
|
|
| if ( |
| self.is_dynamic_model or not inference_state["tracking_has_started"] |
| ) and not error_if_new: |
| |
| obj_idx = len(inference_state["obj_id_to_idx"]) |
| inference_state["obj_id_to_idx"][obj_id] = obj_idx |
| inference_state["obj_idx_to_id"][obj_idx] = obj_id |
| inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) |
| |
| inference_state["point_inputs_per_obj"][obj_idx] = {} |
| inference_state["mask_inputs_per_obj"][obj_idx] = {} |
| inference_state["output_dict_per_obj"][obj_idx] = { |
| "cond_frame_outputs": {}, |
| "non_cond_frame_outputs": {}, |
| } |
| inference_state["temp_output_dict_per_obj"][obj_idx] = { |
| "cond_frame_outputs": {}, |
| "non_cond_frame_outputs": {}, |
| } |
| return obj_idx |
| else: |
| raise RuntimeError( |
| f"Cannot add new object id {obj_id}. " |
| f"All existing object ids: {inference_state['obj_ids']}." |
| ) |
|
|
| def _obj_idx_to_id(self, inference_state, obj_idx): |
| """Map model-side object index to client-side object id.""" |
| return inference_state["obj_idx_to_id"][obj_idx] |
|
|
| def _get_obj_num(self, inference_state): |
| """Get the total number of unique object ids received so far in this session.""" |
| |
| return inference_state["multiplex_state"].total_valid_entries |
|
|
| @torch.inference_mode() |
| def _extract_object_for_interaction(self, inference_state, obj_id, frame_idx): |
| """ |
| Extract a single object from multiplex state for singleton interaction. |
| Adapted from sam3_multiplex_tracking._extract_object_to_singleton_state() |
| |
| Returns: |
| singleton_state: New inference state containing only this object |
| obj_idx_in_source: Original object index before removal (for merging back) |
| """ |
| source_state = inference_state |
| obj_idx_in_source = source_state["obj_id_to_idx"][obj_id] |
|
|
| |
| multiplex_state = source_state.get("multiplex_state") |
|
|
| |
| singleton_consolidated_outputs = { |
| "cond_frame_outputs": {}, |
| "non_cond_frame_outputs": {}, |
| } |
|
|
| if "output_dict" in source_state: |
| for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: |
| source_outputs = source_state["output_dict"].get(storage_key, {}) |
|
|
| for f_idx, source_frame_out in source_outputs.items(): |
| |
| has_valid_data = ( |
| source_frame_out["pred_masks"].shape[0] >= obj_idx_in_source + 1 |
| ) |
|
|
| if has_valid_data: |
| |
| singleton_frame_out = { |
| "pred_masks": source_frame_out["pred_masks"][ |
| obj_idx_in_source : obj_idx_in_source + 1 |
| ].clone(), |
| "object_score_logits": source_frame_out[ |
| "object_score_logits" |
| ][obj_idx_in_source : obj_idx_in_source + 1].clone(), |
| |
| "image_features": source_frame_out.get("image_features"), |
| "image_pos_enc": source_frame_out.get("image_pos_enc"), |
| "local_obj_id_to_idx": {obj_id: 0}, |
| } |
|
|
| |
| maskmem_features = source_frame_out.get("maskmem_features") |
| if maskmem_features is not None: |
| if multiplex_state is not None: |
| expected_buckets = multiplex_state.num_buckets |
| expected_multiplex = multiplex_state.multiplex_count |
| if ( |
| maskmem_features.dim() >= 2 |
| and maskmem_features.shape[0] == expected_buckets |
| and maskmem_features.shape[1] == expected_multiplex |
| ): |
| try: |
| demuxed_features = multiplex_state.demux( |
| maskmem_features |
| ) |
| except AssertionError as exc: |
| logging.warning( |
| "[EXTRACT] demux failed for maskmem_features shape %s: %s", |
| tuple(maskmem_features.shape), |
| exc, |
| ) |
| demuxed_features = None |
| if demuxed_features is not None: |
| maskmem_features = demuxed_features[ |
| obj_idx_in_source : obj_idx_in_source + 1 |
| ].clone() |
| else: |
| maskmem_features = maskmem_features[ |
| obj_idx_in_source : obj_idx_in_source + 1 |
| ].clone() |
| elif maskmem_features.shape[0] == 0: |
| |
| maskmem_features = None |
| elif maskmem_features.shape[0] >= obj_idx_in_source + 1: |
| |
| maskmem_features = maskmem_features[ |
| obj_idx_in_source : obj_idx_in_source + 1 |
| ].clone() |
| else: |
| logging.warning( |
| "[EXTRACT] maskmem_features shape %s incompatible with multiplex state; dropping tensor", |
| tuple(maskmem_features.shape), |
| ) |
| maskmem_features = None |
| else: |
| maskmem_features = maskmem_features[ |
| obj_idx_in_source : obj_idx_in_source + 1 |
| ].clone() |
| singleton_frame_out["maskmem_features"] = maskmem_features |
|
|
| |
| maskmem_pos_enc = source_frame_out.get("maskmem_pos_enc") |
| if maskmem_pos_enc is not None: |
| remapped_pos_enc = [] |
| for level_enc in maskmem_pos_enc: |
| if level_enc is None: |
| remapped_pos_enc.append(None) |
| continue |
| if multiplex_state is not None: |
| expected_buckets = multiplex_state.num_buckets |
| expected_multiplex = multiplex_state.multiplex_count |
| if ( |
| level_enc.dim() >= 2 |
| and level_enc.shape[0] == expected_buckets |
| and level_enc.shape[1] == expected_multiplex |
| ): |
| try: |
| demuxed_level = multiplex_state.demux( |
| level_enc |
| ) |
| except AssertionError as exc: |
| logging.warning( |
| "[EXTRACT] demux failed for maskmem_pos_enc level shape %s: %s", |
| tuple(level_enc.shape), |
| exc, |
| ) |
| demuxed_level = None |
| if demuxed_level is not None: |
| remapped_pos_enc.append( |
| demuxed_level[ |
| obj_idx_in_source : obj_idx_in_source |
| + 1 |
| ].clone() |
| ) |
| elif ( |
| level_enc.shape[0] >= obj_idx_in_source + 1 |
| ): |
| remapped_pos_enc.append( |
| level_enc[ |
| obj_idx_in_source : obj_idx_in_source |
| + 1 |
| ].clone() |
| ) |
| else: |
| logging.warning( |
| "[EXTRACT] maskmem_pos_enc level shape %s incompatible with multiplex state; dropping level", |
| tuple(level_enc.shape), |
| ) |
| remapped_pos_enc.append(None) |
| elif level_enc.shape[0] >= obj_idx_in_source + 1: |
| remapped_pos_enc.append( |
| level_enc[ |
| obj_idx_in_source : obj_idx_in_source |
| + 1 |
| ].clone() |
| ) |
| else: |
| logging.warning( |
| "[EXTRACT] maskmem_pos_enc level shape %s incompatible with multiplex state; dropping level", |
| tuple(level_enc.shape), |
| ) |
| remapped_pos_enc.append(None) |
| else: |
| remapped_pos_enc.append( |
| level_enc[ |
| obj_idx_in_source : obj_idx_in_source + 1 |
| ].clone() |
| ) |
| maskmem_pos_enc = remapped_pos_enc |
| singleton_frame_out["maskmem_pos_enc"] = maskmem_pos_enc |
|
|
| |
| if ( |
| "obj_ptr" in source_frame_out |
| and self.use_obj_ptrs_in_encoder |
| ): |
| source_obj_ptr = source_frame_out["obj_ptr"] |
| if multiplex_state is not None: |
| |
| obj_ptr_data_space = multiplex_state.demux( |
| source_obj_ptr |
| ) |
| |
| singleton_frame_out["obj_ptr"] = obj_ptr_data_space[ |
| obj_idx_in_source : obj_idx_in_source + 1 |
| ].clone() |
| else: |
| singleton_frame_out["obj_ptr"] = source_obj_ptr[ |
| obj_idx_in_source : obj_idx_in_source + 1 |
| ].clone() |
|
|
| |
| if "conditioning_objects" in source_frame_out: |
| if ( |
| obj_idx_in_source |
| in source_frame_out["conditioning_objects"] |
| ): |
| singleton_frame_out["conditioning_objects"] = {0} |
| else: |
| singleton_frame_out["conditioning_objects"] = set() |
|
|
| singleton_consolidated_outputs[storage_key][f_idx] = ( |
| singleton_frame_out |
| ) |
|
|
| |
| extracted_point_inputs = {} |
| extracted_mask_inputs = {} |
|
|
| if ( |
| "point_inputs_per_obj" in source_state |
| and obj_idx_in_source in source_state["point_inputs_per_obj"] |
| ): |
| extracted_point_inputs = source_state["point_inputs_per_obj"][ |
| obj_idx_in_source |
| ].copy() |
|
|
| if ( |
| "mask_inputs_per_obj" in source_state |
| and obj_idx_in_source in source_state["mask_inputs_per_obj"] |
| ): |
| extracted_mask_inputs = source_state["mask_inputs_per_obj"][ |
| obj_idx_in_source |
| ].copy() |
|
|
| |
| extracted_obj_cond_outputs = {} |
| extracted_obj_non_cond_outputs = {} |
| extracted_temp_cond_outputs = {} |
| extracted_temp_non_cond_outputs = {} |
|
|
| if ( |
| "output_dict_per_obj" in source_state |
| and obj_idx_in_source in source_state["output_dict_per_obj"] |
| ): |
| obj_output_dict = source_state["output_dict_per_obj"][obj_idx_in_source] |
| extracted_obj_cond_outputs = obj_output_dict.get( |
| "cond_frame_outputs", {} |
| ).copy() |
| extracted_obj_non_cond_outputs = obj_output_dict.get( |
| "non_cond_frame_outputs", {} |
| ).copy() |
|
|
| if ( |
| "temp_output_dict_per_obj" in source_state |
| and obj_idx_in_source in source_state["temp_output_dict_per_obj"] |
| ): |
| temp_obj_output_dict = source_state["temp_output_dict_per_obj"][ |
| obj_idx_in_source |
| ] |
| extracted_temp_cond_outputs = temp_obj_output_dict.get( |
| "cond_frame_outputs", {} |
| ).copy() |
| extracted_temp_non_cond_outputs = temp_obj_output_dict.get( |
| "non_cond_frame_outputs", {} |
| ).copy() |
|
|
| |
| remaining_obj_ids, _ = self.remove_object( |
| source_state, |
| obj_id, |
| strict=False, |
| need_output=False, |
| clear_user_refined_map=False, |
| ) |
|
|
| |
| updated_multiplex_state = source_state.get("multiplex_state") |
| if updated_multiplex_state is not None: |
| if ( |
| getattr(updated_multiplex_state, "assignments", None) is None |
| or updated_multiplex_state.total_valid_entries == 0 |
| ): |
| source_state["multiplex_state"] = None |
|
|
| |
| singleton_state = self.init_state( |
| cached_features=source_state["cached_features"], |
| video_height=source_state["video_height"], |
| video_width=source_state["video_width"], |
| num_frames=source_state["num_frames"], |
| ) |
|
|
| |
| singleton_state["obj_id_to_idx"] = {obj_id: 0} |
| singleton_state["obj_idx_to_id"] = {0: obj_id} |
| singleton_state["obj_ids"] = [obj_id] |
| singleton_state["point_inputs_per_obj"] = {0: extracted_point_inputs} |
| singleton_state["mask_inputs_per_obj"] = {0: extracted_mask_inputs} |
| singleton_state["output_dict_per_obj"] = { |
| 0: { |
| "cond_frame_outputs": extracted_obj_cond_outputs, |
| "non_cond_frame_outputs": extracted_obj_non_cond_outputs, |
| } |
| } |
| singleton_state["temp_output_dict_per_obj"] = { |
| 0: { |
| "cond_frame_outputs": extracted_temp_cond_outputs, |
| "non_cond_frame_outputs": extracted_temp_non_cond_outputs, |
| } |
| } |
| singleton_state["frames_already_tracked"] = source_state[ |
| "frames_already_tracked" |
| ].copy() |
|
|
| |
| new_multiplex_state = self.multiplex_controller.get_state( |
| num_valid_entries=1, |
| device=source_state["device"], |
| dtype=torch.float32, |
| random=False, |
| object_ids=[obj_id], |
| ) |
| singleton_state["multiplex_state"] = new_multiplex_state |
|
|
| |
| for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: |
| for f_idx, frame_out in singleton_consolidated_outputs[storage_key].items(): |
| |
| if frame_out.get("maskmem_features") is not None: |
| |
| frame_out["maskmem_features"] = frame_out[ |
| "maskmem_features" |
| ].clone() |
|
|
| if frame_out.get("maskmem_pos_enc") is not None: |
| remapped_levels = [] |
| for level_enc in frame_out["maskmem_pos_enc"]: |
| if level_enc is None: |
| remapped_levels.append(None) |
| continue |
| remapped_levels.append(level_enc.clone()) |
| frame_out["maskmem_pos_enc"] = remapped_levels |
|
|
| |
| if "obj_ptr" in frame_out and self.use_obj_ptrs_in_encoder: |
| |
| frame_out["obj_ptr"] = new_multiplex_state.mux(frame_out["obj_ptr"]) |
|
|
| singleton_state["output_dict"] = singleton_consolidated_outputs |
|
|
| return singleton_state, obj_idx_in_source |
|
|
| @torch.inference_mode() |
| def _merge_singleton_interaction_result( |
| self, |
| inference_state, |
| singleton_state, |
| obj_id, |
| original_obj_idx, |
| ): |
| """ |
| Merge singleton interaction result back into multiplex state. |
| |
| SIMPLIFIED APPROACH: Add object back at the END (new index), not at original position. |
| This avoids complex index shifting and works with multiplex controller's add_objects() API. |
| |
| Args: |
| inference_state: The main multiplex inference state |
| singleton_state: The singleton state with interaction results |
| obj_id: The object ID |
| original_obj_idx: The original index before extraction (unused - we add at end instead) |
| """ |
| |
| new_obj_idx = len(inference_state["obj_ids"]) |
|
|
| |
| inference_state["obj_ids"].append(obj_id) |
| inference_state["obj_id_to_idx"][obj_id] = new_obj_idx |
|
|
| |
| |
| inference_state["output_dict_per_obj"][new_obj_idx] = { |
| "cond_frame_outputs": {}, |
| "non_cond_frame_outputs": {}, |
| } |
| inference_state["temp_output_dict_per_obj"][new_obj_idx] = { |
| "cond_frame_outputs": {}, |
| "non_cond_frame_outputs": {}, |
| } |
|
|
| inference_state["obj_idx_to_id"][new_obj_idx] = obj_id |
|
|
| |
| multiplex_state = inference_state.get("multiplex_state") |
|
|
| assignments = ( |
| getattr(multiplex_state, "assignments", None) |
| if multiplex_state is not None |
| else None |
| ) |
| total_valid_entries = ( |
| getattr(multiplex_state, "total_valid_entries", 0) |
| if multiplex_state is not None and assignments is not None |
| else 0 |
| ) |
| need_state_reinit = ( |
| multiplex_state is None or assignments is None or total_valid_entries == 0 |
| ) |
|
|
| if not need_state_reinit and getattr(multiplex_state, "object_ids", None): |
| if obj_id in multiplex_state.object_ids: |
| old_idx = multiplex_state.object_ids.index(obj_id) |
| multiplex_state.remove_objects(object_indices=[old_idx], strict=False) |
| assignments = getattr(multiplex_state, "assignments", None) |
| total_valid_entries = ( |
| getattr(multiplex_state, "total_valid_entries", 0) |
| if assignments is not None |
| else 0 |
| ) |
| need_state_reinit = assignments is None or total_valid_entries == 0 |
|
|
| if need_state_reinit: |
| inference_state["multiplex_state"] = self.multiplex_controller.get_state( |
| num_valid_entries=len(inference_state["obj_ids"]), |
| device=inference_state["device"], |
| dtype=torch.float32, |
| random=False, |
| object_ids=list(inference_state["obj_ids"]), |
| ) |
| multiplex_state = inference_state["multiplex_state"] |
| else: |
| |
| multiplex_state.add_objects( |
| object_indices=[new_obj_idx], |
| object_ids=[obj_id], |
| allow_new_buckets=True, |
| ) |
|
|
| |
| singleton_obj_idx = 0 |
| if ( |
| "point_inputs_per_obj" in singleton_state |
| and singleton_obj_idx in singleton_state["point_inputs_per_obj"] |
| ): |
| if "point_inputs_per_obj" not in inference_state: |
| inference_state["point_inputs_per_obj"] = {} |
| inference_state["point_inputs_per_obj"][new_obj_idx] = singleton_state[ |
| "point_inputs_per_obj" |
| ][singleton_obj_idx].copy() |
|
|
| if ( |
| "mask_inputs_per_obj" in singleton_state |
| and singleton_obj_idx in singleton_state["mask_inputs_per_obj"] |
| ): |
| if "mask_inputs_per_obj" not in inference_state: |
| inference_state["mask_inputs_per_obj"] = {} |
| inference_state["mask_inputs_per_obj"][new_obj_idx] = singleton_state[ |
| "mask_inputs_per_obj" |
| ][singleton_obj_idx].copy() |
|
|
| |
| if ( |
| "output_dict_per_obj" in singleton_state |
| and singleton_obj_idx in singleton_state["output_dict_per_obj"] |
| ): |
| if "output_dict_per_obj" not in inference_state: |
| inference_state["output_dict_per_obj"] = {} |
| inference_state["output_dict_per_obj"][new_obj_idx] = singleton_state[ |
| "output_dict_per_obj" |
| ][singleton_obj_idx].copy() |
|
|
| if ( |
| "temp_output_dict_per_obj" in singleton_state |
| and singleton_obj_idx in singleton_state["temp_output_dict_per_obj"] |
| ): |
| if "temp_output_dict_per_obj" not in inference_state: |
| inference_state["temp_output_dict_per_obj"] = {} |
| inference_state["temp_output_dict_per_obj"][new_obj_idx] = singleton_state[ |
| "temp_output_dict_per_obj" |
| ][singleton_obj_idx].copy() |
|
|
| |
| |
| |
| if "output_dict" in singleton_state: |
| singleton_multiplex_state = singleton_state.get("multiplex_state") |
| for singleton_storage_key in [ |
| "cond_frame_outputs", |
| "non_cond_frame_outputs", |
| ]: |
| singleton_outputs = singleton_state["output_dict"].get( |
| singleton_storage_key, {} |
| ) |
|
|
| |
| if not singleton_outputs: |
| continue |
|
|
| for frame_idx, singleton_frame_out in singleton_outputs.items(): |
| |
| if "output_dict" not in inference_state: |
| inference_state["output_dict"] = { |
| "cond_frame_outputs": {}, |
| "non_cond_frame_outputs": {}, |
| } |
|
|
| if ( |
| frame_idx |
| not in inference_state["output_dict"][singleton_storage_key] |
| ): |
| |
| num_objs = len(inference_state["obj_ids"]) |
|
|
| |
| |
| if num_objs <= new_obj_idx: |
| num_objs = new_obj_idx + 1 |
|
|
| new_maskmem_features = None |
| new_maskmem_pos_enc = None |
| if ( |
| singleton_frame_out.get("maskmem_features") is not None |
| and multiplex_state is not None |
| ): |
| |
| singleton_features_muxed = singleton_frame_out[ |
| "maskmem_features" |
| ] |
| if singleton_features_muxed.shape[:2] == ( |
| singleton_multiplex_state.num_buckets, |
| singleton_multiplex_state.multiplex_count, |
| ): |
| |
| singleton_features_data = ( |
| singleton_multiplex_state.demux( |
| singleton_features_muxed |
| ) |
| ) |
| else: |
| |
| singleton_features_data = singleton_features_muxed |
|
|
| feature_shape = (num_objs,) + singleton_features_data.shape[ |
| 1: |
| ] |
| maskmem_features_data = torch.zeros( |
| feature_shape, |
| dtype=singleton_features_data.dtype, |
| device=singleton_features_data.device, |
| ) |
| maskmem_features_data[new_obj_idx : new_obj_idx + 1] = ( |
| singleton_features_data |
| ) |
| |
| new_maskmem_features = multiplex_state.mux( |
| maskmem_features_data |
| ) |
|
|
| if ( |
| singleton_frame_out.get("maskmem_pos_enc") is not None |
| and multiplex_state is not None |
| ): |
| new_maskmem_pos_enc = [] |
| for level_enc in singleton_frame_out["maskmem_pos_enc"]: |
| if level_enc is None: |
| new_maskmem_pos_enc.append(None) |
| continue |
| |
| if level_enc.shape[:2] == ( |
| singleton_multiplex_state.num_buckets, |
| singleton_multiplex_state.multiplex_count, |
| ): |
| |
| level_data = singleton_multiplex_state.demux( |
| level_enc |
| ) |
| else: |
| |
| level_data = level_enc |
|
|
| level_shape = (num_objs,) + level_data.shape[1:] |
| level_tensor = torch.zeros( |
| level_shape, |
| dtype=level_data.dtype, |
| device=level_data.device, |
| ) |
| level_tensor[new_obj_idx : new_obj_idx + 1] = level_data |
| |
| new_maskmem_pos_enc.append( |
| multiplex_state.mux(level_tensor) |
| ) |
|
|
| inference_state["output_dict"][singleton_storage_key][ |
| frame_idx |
| ] = { |
| "maskmem_features": new_maskmem_features, |
| "maskmem_pos_enc": new_maskmem_pos_enc, |
| "image_features": singleton_frame_out.get("image_features"), |
| "image_pos_enc": singleton_frame_out.get("image_pos_enc"), |
| "local_obj_id_to_idx": {obj_id: new_obj_idx}, |
| "conditioning_objects": ( |
| set([new_obj_idx]) |
| if singleton_obj_idx |
| in singleton_frame_out.get( |
| "conditioning_objects", set() |
| ) |
| else set() |
| ), |
| "pred_masks": torch.zeros( |
| ( |
| num_objs, |
| 1, |
| singleton_frame_out["pred_masks"].shape[2], |
| singleton_frame_out["pred_masks"].shape[3], |
| ), |
| dtype=singleton_frame_out["pred_masks"].dtype, |
| device=singleton_frame_out["pred_masks"].device, |
| ), |
| "object_score_logits": torch.full( |
| (num_objs, 1), |
| NO_OBJ_SCORE, |
| dtype=singleton_frame_out["object_score_logits"].dtype, |
| device=singleton_frame_out[ |
| "object_score_logits" |
| ].device, |
| ), |
| } |
| |
| inference_state["output_dict"][singleton_storage_key][ |
| frame_idx |
| ]["pred_masks"][ |
| new_obj_idx : new_obj_idx + 1 |
| ] = singleton_frame_out["pred_masks"] |
| inference_state["output_dict"][singleton_storage_key][ |
| frame_idx |
| ]["object_score_logits"][ |
| new_obj_idx : new_obj_idx + 1 |
| ] = singleton_frame_out["object_score_logits"] |
|
|
| |
| if "pred_masks_video_res" in singleton_frame_out: |
| inference_state["output_dict"][singleton_storage_key][ |
| frame_idx |
| ]["pred_masks_video_res"] = torch.zeros( |
| ( |
| num_objs, |
| 1, |
| singleton_frame_out["pred_masks_video_res"].shape[ |
| 2 |
| ], |
| singleton_frame_out["pred_masks_video_res"].shape[ |
| 3 |
| ], |
| ), |
| dtype=singleton_frame_out["pred_masks_video_res"].dtype, |
| device=singleton_frame_out[ |
| "pred_masks_video_res" |
| ].device, |
| ) |
| inference_state["output_dict"][singleton_storage_key][ |
| frame_idx |
| ]["pred_masks_video_res"][ |
| new_obj_idx : new_obj_idx + 1 |
| ] = singleton_frame_out["pred_masks_video_res"] |
|
|
| |
| if ( |
| "obj_ptr" in singleton_frame_out |
| and self.use_obj_ptrs_in_encoder |
| ): |
| singleton_obj_ptr_data = singleton_multiplex_state.demux( |
| singleton_frame_out["obj_ptr"] |
| ) |
| obj_ptr_data = torch.zeros( |
| (num_objs, singleton_obj_ptr_data.shape[1]), |
| dtype=singleton_obj_ptr_data.dtype, |
| device=singleton_obj_ptr_data.device, |
| ) |
| obj_ptr_data[new_obj_idx : new_obj_idx + 1] = ( |
| singleton_obj_ptr_data |
| ) |
| inference_state["output_dict"][singleton_storage_key][ |
| frame_idx |
| ]["obj_ptr"] = multiplex_state.mux(obj_ptr_data) |
| else: |
| |
| main_frame_out = inference_state["output_dict"][ |
| singleton_storage_key |
| ][frame_idx] |
|
|
| num_objs_total = len(inference_state["obj_ids"]) |
|
|
| if ( |
| singleton_frame_out.get("maskmem_features") is not None |
| and multiplex_state is not None |
| ): |
| |
| singleton_features_muxed = singleton_frame_out[ |
| "maskmem_features" |
| ] |
| if singleton_features_muxed.shape[:2] == ( |
| singleton_multiplex_state.num_buckets, |
| singleton_multiplex_state.multiplex_count, |
| ): |
| |
| singleton_features_data = ( |
| singleton_multiplex_state.demux( |
| singleton_features_muxed |
| ) |
| ) |
| else: |
| |
| singleton_features_data = singleton_features_muxed |
|
|
| existing_features_muxed = main_frame_out.get( |
| "maskmem_features" |
| ) |
| if existing_features_muxed is not None: |
| |
| if existing_features_muxed.shape[:2] == ( |
| multiplex_state.num_buckets, |
| multiplex_state.multiplex_count, |
| ): |
| |
| existing_features_data = multiplex_state.demux( |
| existing_features_muxed |
| ) |
| else: |
| |
| existing_features_data = existing_features_muxed |
| else: |
| existing_features_data = None |
|
|
| if existing_features_data is None: |
| feature_shape = ( |
| num_objs_total, |
| ) + singleton_features_data.shape[1:] |
| existing_features_data = torch.zeros( |
| feature_shape, |
| dtype=singleton_features_data.dtype, |
| device=singleton_features_data.device, |
| ) |
| elif existing_features_data.shape[0] < num_objs_total: |
| pad_size = ( |
| num_objs_total - existing_features_data.shape[0] |
| ) |
| pad = torch.zeros( |
| (pad_size,) + existing_features_data.shape[1:], |
| dtype=existing_features_data.dtype, |
| device=existing_features_data.device, |
| ) |
| existing_features_data = torch.cat( |
| [existing_features_data, pad], dim=0 |
| ) |
|
|
| existing_features_data[new_obj_idx : new_obj_idx + 1] = ( |
| singleton_features_data |
| ) |
| main_frame_out["maskmem_features"] = multiplex_state.mux( |
| existing_features_data |
| ) |
|
|
| if ( |
| singleton_frame_out.get("maskmem_pos_enc") is not None |
| and multiplex_state is not None |
| ): |
| existing_pos_enc_list = ( |
| main_frame_out.get("maskmem_pos_enc") or [] |
| ) |
| new_maskmem_pos_enc = [] |
| max_levels = max( |
| len(singleton_frame_out["maskmem_pos_enc"]), |
| len(existing_pos_enc_list), |
| ) |
| for level_idx in range(max_levels): |
| singleton_level_muxed = ( |
| singleton_frame_out["maskmem_pos_enc"][level_idx] |
| if level_idx |
| < len(singleton_frame_out["maskmem_pos_enc"]) |
| else None |
| ) |
| existing_level_muxed = ( |
| existing_pos_enc_list[level_idx] |
| if level_idx < len(existing_pos_enc_list) |
| else None |
| ) |
|
|
| if singleton_level_muxed is None: |
| |
| new_maskmem_pos_enc.append(existing_level_muxed) |
| continue |
|
|
| |
| if singleton_level_muxed.shape[:2] == ( |
| singleton_multiplex_state.num_buckets, |
| singleton_multiplex_state.multiplex_count, |
| ): |
| |
| singleton_level_data = ( |
| singleton_multiplex_state.demux( |
| singleton_level_muxed |
| ) |
| ) |
| else: |
| |
| singleton_level_data = singleton_level_muxed |
|
|
| if existing_level_muxed is not None: |
| |
| if existing_level_muxed.shape[:2] == ( |
| multiplex_state.num_buckets, |
| multiplex_state.multiplex_count, |
| ): |
| |
| existing_level_data = multiplex_state.demux( |
| existing_level_muxed |
| ) |
| else: |
| |
| existing_level_data = existing_level_muxed |
| else: |
| existing_level_data = None |
|
|
| if existing_level_data is None: |
| level_shape = ( |
| num_objs_total, |
| ) + singleton_level_data.shape[1:] |
| existing_level_data = torch.zeros( |
| level_shape, |
| dtype=singleton_level_data.dtype, |
| device=singleton_level_data.device, |
| ) |
| elif existing_level_data.shape[0] < num_objs_total: |
| pad_size = ( |
| num_objs_total - existing_level_data.shape[0] |
| ) |
| pad = torch.zeros( |
| (pad_size,) + existing_level_data.shape[1:], |
| dtype=existing_level_data.dtype, |
| device=existing_level_data.device, |
| ) |
| existing_level_data = torch.cat( |
| [existing_level_data, pad], dim=0 |
| ) |
|
|
| existing_level_data[new_obj_idx : new_obj_idx + 1] = ( |
| singleton_level_data |
| ) |
| new_maskmem_pos_enc.append( |
| multiplex_state.mux(existing_level_data) |
| ) |
|
|
| main_frame_out["maskmem_pos_enc"] = new_maskmem_pos_enc |
|
|
| singleton_pred_masks = singleton_frame_out[ |
| "pred_masks" |
| ] |
| singleton_scores = singleton_frame_out[ |
| "object_score_logits" |
| ] |
|
|
| |
| num_existing_objs = main_frame_out["pred_masks"].shape[0] |
| if new_obj_idx >= num_existing_objs: |
| num_objs_needed = new_obj_idx + 1 |
| pad_size = num_objs_needed - num_existing_objs |
|
|
| main_frame_out["pred_masks"] = torch.cat( |
| [ |
| main_frame_out["pred_masks"], |
| torch.zeros( |
| ( |
| pad_size, |
| 1, |
| singleton_pred_masks.shape[2], |
| singleton_pred_masks.shape[3], |
| ), |
| dtype=singleton_pred_masks.dtype, |
| device=singleton_pred_masks.device, |
| ), |
| ], |
| dim=0, |
| ) |
|
|
| main_frame_out["object_score_logits"] = torch.cat( |
| [ |
| main_frame_out["object_score_logits"], |
| torch.full( |
| (pad_size, 1), |
| NO_OBJ_SCORE, |
| dtype=singleton_scores.dtype, |
| device=singleton_scores.device, |
| ), |
| ], |
| dim=0, |
| ) |
|
|
| |
| main_frame_out["pred_masks"][new_obj_idx : new_obj_idx + 1] = ( |
| singleton_pred_masks |
| ) |
| main_frame_out["object_score_logits"][ |
| new_obj_idx : new_obj_idx + 1 |
| ] = singleton_scores |
| |
| |
| |
| if "local_obj_id_to_idx" not in main_frame_out: |
| main_frame_out["local_obj_id_to_idx"] = deepcopy( |
| inference_state["obj_id_to_idx"] |
| ) |
| main_frame_out["local_obj_id_to_idx"][obj_id] = new_obj_idx |
|
|
| |
| if "pred_masks_video_res" in singleton_frame_out: |
| if "pred_masks_video_res" in main_frame_out: |
| |
| if ( |
| main_frame_out["pred_masks_video_res"].shape[0] |
| < new_obj_idx + 1 |
| ): |
| pad_size = ( |
| new_obj_idx |
| + 1 |
| - main_frame_out["pred_masks_video_res"].shape[ |
| 0 |
| ] |
| ) |
| main_frame_out["pred_masks_video_res"] = torch.cat( |
| [ |
| main_frame_out["pred_masks_video_res"], |
| torch.zeros( |
| ( |
| pad_size, |
| 1, |
| singleton_frame_out[ |
| "pred_masks_video_res" |
| ].shape[2], |
| singleton_frame_out[ |
| "pred_masks_video_res" |
| ].shape[3], |
| ), |
| dtype=singleton_frame_out[ |
| "pred_masks_video_res" |
| ].dtype, |
| device=singleton_frame_out[ |
| "pred_masks_video_res" |
| ].device, |
| ), |
| ], |
| dim=0, |
| ) |
| else: |
| |
| num_objs = len(inference_state["obj_ids"]) |
| main_frame_out["pred_masks_video_res"] = torch.zeros( |
| ( |
| num_objs, |
| 1, |
| singleton_frame_out[ |
| "pred_masks_video_res" |
| ].shape[2], |
| singleton_frame_out[ |
| "pred_masks_video_res" |
| ].shape[3], |
| ), |
| dtype=singleton_frame_out[ |
| "pred_masks_video_res" |
| ].dtype, |
| device=singleton_frame_out[ |
| "pred_masks_video_res" |
| ].device, |
| ) |
| |
| main_frame_out["pred_masks_video_res"][ |
| new_obj_idx : new_obj_idx + 1 |
| ] = singleton_frame_out["pred_masks_video_res"] |
|
|
| |
| if ( |
| "obj_ptr" in singleton_frame_out |
| and self.use_obj_ptrs_in_encoder |
| ): |
| singleton_obj_ptr_data = singleton_multiplex_state.demux( |
| singleton_frame_out["obj_ptr"] |
| ) |
|
|
| if "obj_ptr" in main_frame_out: |
| |
| |
| |
|
|
| old_obj_ptr_muxed = main_frame_out["obj_ptr"] |
| |
| old_num_buckets = old_obj_ptr_muxed.shape[1] |
|
|
| |
| if old_num_buckets != multiplex_state.num_buckets: |
| |
| |
| num_objs = len(inference_state["obj_ids"]) |
| obj_ptr_data = torch.zeros( |
| (num_objs, singleton_obj_ptr_data.shape[1]), |
| dtype=singleton_obj_ptr_data.dtype, |
| device=singleton_obj_ptr_data.device, |
| ) |
| |
| obj_ptr_data[new_obj_idx : new_obj_idx + 1] = ( |
| singleton_obj_ptr_data |
| ) |
| main_frame_out["obj_ptr"] = multiplex_state.mux( |
| obj_ptr_data |
| ) |
| else: |
| |
| main_obj_ptr_data = multiplex_state.demux( |
| old_obj_ptr_muxed |
| ) |
|
|
| |
| if main_obj_ptr_data.shape[0] < new_obj_idx + 1: |
| pad_size = ( |
| new_obj_idx + 1 - main_obj_ptr_data.shape[0] |
| ) |
| main_obj_ptr_data = torch.cat( |
| [ |
| main_obj_ptr_data, |
| torch.zeros( |
| ( |
| pad_size, |
| main_obj_ptr_data.shape[1], |
| ), |
| dtype=main_obj_ptr_data.dtype, |
| device=main_obj_ptr_data.device, |
| ), |
| ], |
| dim=0, |
| ) |
|
|
| main_obj_ptr_data[new_obj_idx : new_obj_idx + 1] = ( |
| singleton_obj_ptr_data |
| ) |
| main_frame_out["obj_ptr"] = multiplex_state.mux( |
| main_obj_ptr_data |
| ) |
| else: |
| |
| num_objs = len(inference_state["obj_ids"]) |
| obj_ptr_data = torch.zeros( |
| (num_objs, singleton_obj_ptr_data.shape[1]), |
| dtype=singleton_obj_ptr_data.dtype, |
| device=singleton_obj_ptr_data.device, |
| ) |
| obj_ptr_data[new_obj_idx : new_obj_idx + 1] = ( |
| singleton_obj_ptr_data |
| ) |
| main_frame_out["obj_ptr"] = multiplex_state.mux( |
| obj_ptr_data |
| ) |
|
|
| |
| if singleton_obj_idx in singleton_frame_out.get( |
| "conditioning_objects", set() |
| ): |
| main_frame_out["conditioning_objects"].add(new_obj_idx) |
|
|
| @torch.inference_mode() |
| def add_new_points( |
| self, |
| inference_state, |
| frame_idx, |
| obj_id, |
| points, |
| labels, |
| clear_old_points, |
| rel_coordinates=True, |
| use_prev_mem_frame=False, |
| ): |
| """ |
| Add new points to create a new object in the multiplex model. |
| |
| This method converts point inputs to masks via the interactivity head and adds |
| the new object to the existing multiplex bucket (for dynamic models). |
| |
| Args: |
| inference_state: Current inference state |
| frame_idx: Frame index to add points |
| obj_id: Object ID (will be auto-created if new) |
| points: Point coordinates tensor |
| labels: Point labels tensor (1 for positive, 0 for negative) |
| clear_old_points: Whether to clear old points on this frame |
| rel_coordinates: Whether points are in relative coordinates [0, 1] |
| use_prev_mem_frame: Whether to use previous memory frames (for compatibility) |
| |
| Returns: |
| Tuple of (frame_idx, obj_ids, low_res_masks, video_res_masks) |
| """ |
| obj_idx = self._obj_id_to_idx(inference_state, obj_id) |
| obj_idxs = [obj_idx] |
| obj_ids = [obj_id] |
|
|
| point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] |
| mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] |
|
|
| if points.dim() == 2: |
| points = points.unsqueeze(0) |
| if labels.dim() == 1: |
| labels = labels.unsqueeze(0) |
|
|
| if rel_coordinates: |
| points = points * self.image_size |
|
|
| points = points.to(inference_state["device"]) |
| labels = labels.to(inference_state["device"]) |
|
|
| if not clear_old_points: |
| old_point_inputs = point_inputs_per_frame.get(frame_idx, None) |
| else: |
| old_point_inputs = None |
|
|
| point_inputs = concat_points(old_point_inputs, points, labels) |
| point_inputs_per_frame[frame_idx] = point_inputs |
|
|
| is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] |
|
|
| if is_init_cond_frame: |
| reverse = False |
| else: |
| reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] |
|
|
| is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond |
| storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" |
|
|
| multiplex_state = inference_state["multiplex_state"] |
| is_new_state = multiplex_state is None |
|
|
| if is_new_state: |
| multiplex_state = self.multiplex_controller.get_state( |
| num_valid_entries=1, |
| device=inference_state["device"], |
| dtype=torch.float32, |
| random=False, |
| object_ids=obj_ids, |
| ) |
| inference_state["multiplex_state"] = multiplex_state |
|
|
| |
| |
| |
| |
| is_existing_object = ( |
| not is_new_state |
| and multiplex_state is not None |
| and obj_id in multiplex_state.object_ids |
| ) |
|
|
| if is_existing_object: |
| if is_init_cond_frame: |
| is_new_obj = False |
| is_refine = False |
| is_gap_fill_case = True |
| else: |
| is_new_obj = False |
| is_refine = True |
| is_gap_fill_case = False |
| else: |
| is_new_obj = True |
| is_refine = False |
| is_gap_fill_case = False |
|
|
| if is_new_obj: |
| should_add_to_existing = not is_new_state |
| allow_new_buckets_local = True |
| prefer_new_buckets_local = True |
|
|
| current_out, _ = self._run_single_frame_inference( |
| inference_state=inference_state, |
| output_dict=inference_state["output_dict"], |
| frame_idx=frame_idx, |
| batch_size=1, |
| is_init_cond_frame=True, |
| point_inputs=point_inputs, |
| mask_inputs=None, |
| reverse=False, |
| run_mem_encoder=False, |
| prev_sam_mask_logits=None, |
| add_to_existing_state=should_add_to_existing, |
| new_obj_idxs=obj_idxs, |
| new_obj_ids=obj_ids, |
| allow_new_buckets=allow_new_buckets_local, |
| prefer_new_buckets=prefer_new_buckets_local, |
| objects_to_interact=None, |
| ) |
| elif is_refine: |
| singleton_state, original_obj_idx = self._extract_object_for_interaction( |
| inference_state, obj_id, frame_idx |
| ) |
|
|
| user_refined_frames_map = inference_state.get( |
| "user_refined_frames_per_obj", {} |
| ) |
| user_refined_frames = user_refined_frames_map.get(obj_id) |
| if user_refined_frames is None: |
| user_refined_frames = set() |
| is_first_refinement = frame_idx not in user_refined_frames |
|
|
| prev_sam_mask_logits_singleton = None |
| if not is_first_refinement: |
| singleton_obj_idx = 0 |
| singleton_output_dict = singleton_state["output_dict_per_obj"][ |
| singleton_obj_idx |
| ] |
| singleton_temp_output_dict = singleton_state[ |
| "temp_output_dict_per_obj" |
| ][singleton_obj_idx] |
|
|
| |
| |
| |
| prev_out = None |
|
|
| storage_key_current = ( |
| "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" |
| ) |
| prev_out = singleton_temp_output_dict[storage_key_current].get( |
| frame_idx |
| ) |
|
|
| if prev_out is None: |
| prev_out = singleton_output_dict["cond_frame_outputs"].get( |
| frame_idx |
| ) |
| if prev_out is None: |
| prev_out = singleton_output_dict["non_cond_frame_outputs"].get( |
| frame_idx |
| ) |
|
|
| if prev_out is not None and prev_out["pred_masks"] is not None: |
| prev_sam_mask_logits_singleton = prev_out["pred_masks"].cuda( |
| non_blocking=True |
| ) |
| prev_sam_mask_logits_singleton = torch.clamp( |
| prev_sam_mask_logits_singleton, -32.0, 32.0 |
| ) |
|
|
| if is_first_refinement: |
| |
| |
| singleton_is_init_cond = True |
| singleton_objects_to_interact = None |
| else: |
| |
| singleton_is_init_cond = False |
| singleton_objects_to_interact = ( |
| [0] if prev_sam_mask_logits_singleton is not None else None |
| ) |
|
|
| singleton_obj_idx = 0 |
| singleton_obj_idxs = [singleton_obj_idx] |
| singleton_obj_ids = [obj_id] |
|
|
| current_out, _ = self._run_single_frame_inference( |
| inference_state=singleton_state, |
| output_dict=singleton_state["output_dict"], |
| frame_idx=frame_idx, |
| batch_size=1, |
| is_init_cond_frame=singleton_is_init_cond, |
| point_inputs=point_inputs, |
| mask_inputs=None, |
| reverse=False, |
| run_mem_encoder=False, |
| prev_sam_mask_logits=prev_sam_mask_logits_singleton, |
| add_to_existing_state=False, |
| new_obj_idxs=singleton_obj_idxs, |
| new_obj_ids=singleton_obj_ids, |
| allow_new_buckets=False, |
| objects_to_interact=singleton_objects_to_interact, |
| ) |
|
|
| singleton_storage_key = ( |
| "cond_frame_outputs" |
| if singleton_is_init_cond |
| else "non_cond_frame_outputs" |
| ) |
|
|
| _, singleton_video_res_masks = self._get_orig_video_res_output( |
| singleton_state, current_out["pred_masks"] |
| ) |
| current_out["pred_masks_video_res"] = singleton_video_res_masks |
|
|
| singleton_state["output_dict"][singleton_storage_key][frame_idx] = ( |
| current_out |
| ) |
|
|
| self._merge_singleton_interaction_result( |
| inference_state, singleton_state, obj_id, original_obj_idx |
| ) |
|
|
| obj_idx = inference_state["obj_id_to_idx"][obj_id] |
| obj_idxs = [obj_idx] |
|
|
| if "user_refined_frames_per_obj" not in inference_state: |
| inference_state["user_refined_frames_per_obj"] = {} |
| if obj_id not in inference_state["user_refined_frames_per_obj"]: |
| inference_state["user_refined_frames_per_obj"][obj_id] = set() |
|
|
| inference_state["user_refined_frames_per_obj"][obj_id].add(frame_idx) |
|
|
| merged_frame_out = inference_state["output_dict"][singleton_storage_key][ |
| frame_idx |
| ] |
| obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] |
| obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] |
|
|
| if "pred_masks_video_res" in merged_frame_out: |
| pred_masks_video_res_slice = merged_frame_out["pred_masks_video_res"][ |
| obj_idx : obj_idx + 1 |
| ] |
| else: |
| _, video_res_masks = self._get_orig_video_res_output( |
| inference_state, merged_frame_out["pred_masks"] |
| ) |
| pred_masks_video_res_slice = video_res_masks[obj_idx : obj_idx + 1] |
|
|
| pred_masks_slice = merged_frame_out["pred_masks"][obj_idx : obj_idx + 1] |
|
|
| obj_temp_output_dict[singleton_storage_key][frame_idx] = { |
| "pred_masks": pred_masks_slice, |
| "pred_masks_video_res": pred_masks_video_res_slice, |
| "object_score_logits": merged_frame_out["object_score_logits"][ |
| obj_idx : obj_idx + 1 |
| ], |
| } |
| obj_output_dict[singleton_storage_key][frame_idx] = obj_temp_output_dict[ |
| singleton_storage_key |
| ][frame_idx] |
|
|
| elif is_gap_fill_case: |
| |
| |
| |
| obj_idx = inference_state["obj_id_to_idx"][obj_id] |
| obj_idxs = [obj_idx] |
| batch_size = self._get_obj_num(inference_state) |
|
|
| obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] |
| obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] |
|
|
| current_out, _ = self._run_single_frame_inference( |
| inference_state=inference_state, |
| output_dict=inference_state["output_dict"], |
| frame_idx=frame_idx, |
| batch_size=batch_size, |
| is_init_cond_frame=True, |
| point_inputs=point_inputs, |
| mask_inputs=None, |
| reverse=False, |
| run_mem_encoder=False, |
| prev_sam_mask_logits=None, |
| add_to_existing_state=False, |
| new_obj_idxs=[obj_idx], |
| new_obj_ids=[obj_id], |
| allow_new_buckets=False, |
| prefer_new_buckets=False, |
| objects_to_interact=[obj_idx], |
| ) |
|
|
| current_out["local_obj_id_to_idx"] = deepcopy( |
| inference_state["obj_id_to_idx"] |
| ) |
|
|
| _, video_res_masks = self._get_orig_video_res_output( |
| inference_state, current_out["pred_masks"] |
| ) |
| current_out["pred_masks_video_res"] = video_res_masks |
|
|
| is_cond = storage_key == "cond_frame_outputs" |
| if ( |
| is_cond |
| and frame_idx |
| in inference_state["output_dict"]["non_cond_frame_outputs"] |
| ): |
| del inference_state["output_dict"]["non_cond_frame_outputs"][frame_idx] |
| if "consolidated_frame_inds" in inference_state: |
| inference_state["consolidated_frame_inds"][ |
| "non_cond_frame_outputs" |
| ].discard(frame_idx) |
|
|
| |
| inference_state["output_dict"][storage_key][frame_idx] = current_out |
|
|
| |
| if "consolidated_frame_inds" in inference_state: |
| inference_state["consolidated_frame_inds"][storage_key].add(frame_idx) |
|
|
| |
| obj_temp_output_dict[storage_key][frame_idx] = { |
| "pred_masks": current_out["pred_masks"][obj_idx : obj_idx + 1], |
| "pred_masks_video_res": video_res_masks[obj_idx : obj_idx + 1], |
| "object_score_logits": current_out["object_score_logits"][ |
| obj_idx : obj_idx + 1 |
| ], |
| } |
| obj_output_dict[storage_key][frame_idx] = obj_temp_output_dict[storage_key][ |
| frame_idx |
| ] |
|
|
| |
| obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] |
| obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] |
|
|
| |
| if is_refine or is_gap_fill_case: |
| |
| |
|
|
| singleton_obj_idx = 0 |
|
|
| |
| _, video_res_masks_singleton = self._get_orig_video_res_output( |
| inference_state, current_out["pred_masks"] |
| ) |
|
|
| |
| if "consolidated_frame_inds" in inference_state: |
| inference_state["consolidated_frame_inds"][storage_key].add(frame_idx) |
|
|
| |
| video_res_masks_to_return = video_res_masks_singleton[ |
| singleton_obj_idx : singleton_obj_idx + 1 |
| ] |
| else: |
| |
| _, video_res_masks = self._get_orig_video_res_output( |
| inference_state, current_out["pred_masks"] |
| ) |
|
|
| current_out["pred_masks_video_res"] = video_res_masks |
| current_out["local_obj_id_to_idx"] = deepcopy( |
| inference_state["obj_id_to_idx"] |
| ) |
|
|
| |
| if ( |
| is_cond |
| and frame_idx |
| in inference_state["output_dict"]["non_cond_frame_outputs"] |
| ): |
| del inference_state["output_dict"]["non_cond_frame_outputs"][frame_idx] |
| |
| if "consolidated_frame_inds" in inference_state: |
| inference_state["consolidated_frame_inds"][ |
| "non_cond_frame_outputs" |
| ].discard(frame_idx) |
|
|
| inference_state["output_dict"][storage_key][frame_idx] = current_out |
|
|
| |
| if "consolidated_frame_inds" in inference_state: |
| inference_state["consolidated_frame_inds"][storage_key].add(frame_idx) |
|
|
| |
| obj_temp_output_dict[storage_key][frame_idx] = { |
| "pred_masks_video_res": current_out["pred_masks_video_res"][ |
| obj_idx : obj_idx + 1 |
| ], |
| "pred_masks": current_out["pred_masks"][obj_idx : obj_idx + 1], |
| "object_score_logits": current_out["object_score_logits"][ |
| obj_idx : obj_idx + 1 |
| ], |
| } |
|
|
| obj_output_dict[storage_key][frame_idx] = obj_temp_output_dict[storage_key][ |
| frame_idx |
| ] |
|
|
| video_res_masks_to_return = video_res_masks[obj_idx : obj_idx + 1] |
|
|
| low_res_masks = None |
| return frame_idx, obj_ids, low_res_masks, video_res_masks_to_return |
|
|
| @torch.inference_mode() |
| def add_new_masks( |
| self, |
| inference_state, |
| frame_idx, |
| obj_ids, |
| masks, |
| |
| add_mask_to_memory=False, |
| |
| reconditioning=False, |
| ): |
| """Add new mask to a frame.""" |
| if isinstance(obj_ids, np.ndarray): |
| obj_ids = obj_ids.tolist() |
| obj_idxs = [ |
| self._obj_id_to_idx(inference_state, obj_id, error_if_new=reconditioning) |
| for obj_id in obj_ids |
| ] |
| point_inputs_per_frame = [ |
| inference_state["point_inputs_per_obj"][obj_idx] for obj_idx in obj_idxs |
| ] |
| mask_inputs_per_frame = [ |
| inference_state["mask_inputs_per_obj"][obj_idx] for obj_idx in obj_idxs |
| ] |
|
|
| assert masks.dim() == 3 |
| num_objects, mask_H, mask_W = masks.shape |
| assert num_objects == len(obj_ids) |
| masks_inputs_orig = masks[:, None, :, :] |
| masks_inputs_orig = masks_inputs_orig.float().to(inference_state["device"]) |
|
|
| |
| if mask_H != self.input_mask_size or mask_W != self.input_mask_size: |
| mask_inputs = torch.nn.functional.interpolate( |
| masks_inputs_orig, |
| size=(self.input_mask_size, self.input_mask_size), |
| align_corners=False, |
| mode="bilinear", |
| antialias=True, |
| ) |
| else: |
| mask_inputs = masks_inputs_orig |
|
|
| |
| video_H = inference_state["video_height"] |
| video_W = inference_state["video_width"] |
| if mask_H != video_H or mask_W != video_W: |
| mask_inputs_video_res = torch.nn.functional.interpolate( |
| masks_inputs_orig, |
| size=(video_H, video_W), |
| align_corners=False, |
| mode="bilinear", |
| antialias=True, |
| ) |
| else: |
| mask_inputs_video_res = masks_inputs_orig |
| |
| mask_inputs_video_res = mask_inputs_video_res > 0.5 |
|
|
| multiplex_state = inference_state["multiplex_state"] |
| is_new_state = multiplex_state is None |
|
|
| if not reconditioning: |
| if is_new_state: |
| multiplex_state = self.multiplex_controller.get_state( |
| num_valid_entries=num_objects, |
| device=inference_state["device"], |
| dtype=torch.float32, |
| random=False, |
| object_ids=obj_ids, |
| ) |
| inference_state["multiplex_state"] = multiplex_state |
| else: |
| assert self.is_dynamic_model, ( |
| "New objects are not allowed after state creation" |
| ) |
|
|
| for i in range(num_objects): |
| mask_inputs_per_frame[i][frame_idx] = mask_inputs_video_res[i : i + 1] |
| point_inputs_per_frame[i].pop(frame_idx, None) |
| |
| |
| |
| |
| is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] |
| |
| if is_init_cond_frame: |
| reverse = False |
| else: |
| reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] |
| obj_output_dicts = [ |
| inference_state["output_dict_per_obj"][obj_idx] for obj_idx in obj_idxs |
| ] |
| obj_temp_output_dicts = [ |
| inference_state["temp_output_dict_per_obj"][obj_idx] for obj_idx in obj_idxs |
| ] |
| |
| |
| is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond |
| storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" |
|
|
| |
| allow_new_buckets_local = False |
| if not is_new_state and not reconditioning and multiplex_state is not None: |
| if multiplex_state.available_slots < num_objects: |
| allow_new_buckets_local = True |
|
|
| current_out, _ = self._run_single_frame_inference( |
| inference_state=inference_state, |
| output_dict=inference_state["output_dict"], |
| frame_idx=frame_idx, |
| batch_size=num_objects, |
| is_init_cond_frame=is_init_cond_frame, |
| point_inputs=None, |
| mask_inputs=mask_inputs, |
| reverse=reverse, |
| |
| |
| |
| |
| run_mem_encoder=False, |
| add_to_existing_state=not is_new_state and not reconditioning, |
| new_obj_idxs=obj_idxs, |
| new_obj_ids=obj_ids, |
| allow_new_buckets=allow_new_buckets_local, |
| reconditioning=reconditioning, |
| ) |
| |
| |
| |
| |
| _, video_res_masks = self._get_orig_video_res_output( |
| inference_state, current_out["pred_masks"] |
| ) |
| obj_idxs_t = torch.as_tensor(obj_idxs, device=video_res_masks.device) |
| video_res_masks[obj_idxs_t] = torch.where( |
| mask_inputs_video_res, -NO_OBJ_SCORE, NO_OBJ_SCORE |
| ) |
|
|
| current_out["pred_masks_video_res"] = video_res_masks |
| with torch.profiler.record_function("add_new_masks._deepcopy"): |
| current_out["local_obj_id_to_idx"] = deepcopy( |
| inference_state["obj_id_to_idx"] |
| ) |
| if ( |
| is_cond |
| and frame_idx in inference_state["output_dict"]["non_cond_frame_outputs"] |
| ): |
| del inference_state["output_dict"]["non_cond_frame_outputs"][frame_idx] |
| |
| if "consolidated_frame_inds" in inference_state: |
| inference_state["consolidated_frame_inds"][ |
| "non_cond_frame_outputs" |
| ].discard(frame_idx) |
|
|
| inference_state["output_dict"][storage_key][frame_idx] = current_out |
|
|
| |
| if "consolidated_frame_inds" in inference_state: |
| inference_state["consolidated_frame_inds"][storage_key].add(frame_idx) |
|
|
| with torch.profiler.record_function("add_new_masks.obj_loop"): |
| |
| for i, obj_idx in enumerate(obj_idxs): |
| |
| |
| obj_temp_output_dicts[i][storage_key][frame_idx] = { |
| "pred_masks_video_res": current_out["pred_masks_video_res"][ |
| obj_idx : obj_idx + 1 |
| ] |
| } |
| obj_output_dicts[i][storage_key][frame_idx] = obj_temp_output_dicts[i][ |
| storage_key |
| ][frame_idx] |
|
|
| |
| |
| combined_new_mask = mask_inputs_video_res.any( |
| dim=0, keepdim=True |
| ) |
|
|
| |
| num_new = len(obj_idxs) |
| exclude_self_masks = {} |
| if num_new > 1: |
| for i in range(num_new): |
| other_indices = torch.cat( |
| [ |
| torch.arange(i, device=mask_inputs_video_res.device), |
| torch.arange( |
| i + 1, num_new, device=mask_inputs_video_res.device |
| ), |
| ] |
| ) |
| exclude_self_masks[obj_idxs[i]] = mask_inputs_video_res[ |
| other_indices |
| ].any(dim=0, keepdim=True) |
|
|
| |
| temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] |
| obj_idxs_set = set(obj_idxs) |
|
|
| for obj_idx2, obj_temp_output_dict2 in temp_output_dict_per_obj.items(): |
| current_out2 = obj_temp_output_dict2[storage_key].get(frame_idx, None) |
| if current_out2 is None: |
| continue |
|
|
| if obj_idx2 not in obj_idxs_set: |
| |
| suppress_mask = combined_new_mask |
| elif obj_idx2 in exclude_self_masks: |
| |
| suppress_mask = exclude_self_masks[obj_idx2] |
| else: |
| |
| continue |
|
|
| current_out2["pred_masks_video_res"] = torch.where( |
| suppress_mask, |
| NO_OBJ_SCORE, |
| current_out2["pred_masks_video_res"], |
| ) |
|
|
| |
| obj_ids = inference_state["obj_ids"] |
| consolidated_out = self._consolidate_temp_output_across_obj( |
| inference_state, |
| frame_idx, |
| is_cond=is_cond, |
| run_mem_encoder=False, |
| consolidate_at_video_res=True, |
| ) |
| _, video_res_masks = self._get_orig_video_res_output( |
| inference_state, consolidated_out["pred_masks_video_res"] |
| ) |
| low_res_masks = None |
|
|
| consolidated_out["local_obj_id_to_idx"] = current_out["local_obj_id_to_idx"] |
|
|
| return frame_idx, obj_ids, low_res_masks, video_res_masks |
|
|
| def _get_orig_video_res_output(self, inference_state, any_res_masks): |
| """ |
| Resize the object scores to the original video resolution (video_res_masks) |
| and apply non-overlapping constraints for final output. |
| """ |
| device = inference_state["device"] |
| video_H = inference_state["video_height"] |
| video_W = inference_state["video_width"] |
| any_res_masks = any_res_masks.to(device, non_blocking=True) |
| if any_res_masks.shape[-2:] == (video_H, video_W): |
| video_res_masks = any_res_masks |
| else: |
| video_res_masks = torch.nn.functional.interpolate( |
| any_res_masks, |
| size=(video_H, video_W), |
| mode="bilinear", |
| align_corners=False, |
| ) |
| if self.non_overlap_masks_for_output: |
| video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) |
| |
| if self.fill_hole_area > 0: |
| video_res_masks = fill_holes_in_mask_scores( |
| video_res_masks, self.fill_hole_area |
| ) |
| return any_res_masks, video_res_masks |
|
|
| def _consolidate_temp_output_across_obj( |
| self, |
| inference_state, |
| frame_idx, |
| is_cond, |
| run_mem_encoder, |
| consolidate_at_video_res=False, |
| ): |
| """ |
| Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on |
| a frame into a single output for all objects, including |
| 1) fill any missing objects either from `output_dict_per_obj` (if they exist in |
| `output_dict_per_obj` for this frame) or leave them as placeholder values |
| (if they don't exist in `output_dict_per_obj` for this frame); |
| 2) if specified, rerun memory encoder after apply non-overlapping constraints |
| on the object scores. |
| """ |
| batch_size = self._get_obj_num(inference_state) |
| storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" |
|
|
| |
| |
| max_obj_idx = batch_size - 1 |
|
|
| |
| for obj_idx in inference_state["temp_output_dict_per_obj"].keys(): |
| if obj_idx > max_obj_idx: |
| max_obj_idx = obj_idx |
| for obj_idx in inference_state["output_dict_per_obj"].keys(): |
| if obj_idx > max_obj_idx: |
| max_obj_idx = obj_idx |
|
|
| |
| consolidated_batch_size = max(max_obj_idx + 1, 0) |
|
|
| |
| |
| if consolidate_at_video_res: |
| assert not run_mem_encoder, "memory encoder cannot run at video resolution" |
| consolidated_H = inference_state["video_height"] |
| consolidated_W = inference_state["video_width"] |
| consolidated_mask_key = "pred_masks_video_res" |
| else: |
| consolidated_H = consolidated_W = self.low_res_mask_size |
| consolidated_mask_key = "pred_masks" |
|
|
| |
| |
| |
| |
|
|
| consolidated_out = { |
| "conditioning_objects": None, |
| "maskmem_features": None, |
| "maskmem_pos_enc": None, |
| "image_features": None, |
| "image_pos_enc": None, |
| "obj_ptr": None, |
| consolidated_mask_key: torch.full( |
| size=( |
| consolidated_batch_size, |
| 1, |
| consolidated_H, |
| consolidated_W, |
| ), |
| fill_value=NO_OBJ_SCORE, |
| dtype=torch.float32, |
| device=inference_state["storage_device"], |
| ), |
| } |
|
|
| all_out = inference_state["output_dict"]["cond_frame_outputs"].get( |
| frame_idx, None |
| ) |
| if all_out is None: |
| all_out = inference_state["output_dict"]["non_cond_frame_outputs"].get( |
| frame_idx, None |
| ) |
|
|
| |
| |
| need_to_reconstruct_from_per_obj = all_out is None |
|
|
| if need_to_reconstruct_from_per_obj: |
| |
| |
| conditioning_objects = set() |
| for obj_idx in range(batch_size): |
| |
| if obj_idx in inference_state["point_inputs_per_obj"]: |
| point_inputs = inference_state["point_inputs_per_obj"][obj_idx] |
| if ( |
| frame_idx in point_inputs |
| and point_inputs[frame_idx] is not None |
| ): |
| conditioning_objects.add(obj_idx) |
| continue |
|
|
| |
| if obj_idx in inference_state["mask_inputs_per_obj"]: |
| mask_inputs = inference_state["mask_inputs_per_obj"][obj_idx] |
| if frame_idx in mask_inputs and mask_inputs[frame_idx] is not None: |
| conditioning_objects.add(obj_idx) |
|
|
| consolidated_out["conditioning_objects"] = conditioning_objects |
| |
| |
| else: |
| |
| consolidated_out["conditioning_objects"] = all_out.get( |
| "conditioning_objects", set() |
| ) |
| consolidated_out["obj_ptr"] = all_out["obj_ptr"] |
| consolidated_out["object_score_logits"] = all_out["object_score_logits"] |
| if self.use_memory_selection: |
| consolidated_out["iou_score"] = all_out["iou_score"] |
| |
| consolidated_out["maskmem_features"] = all_out.get("maskmem_features") |
| consolidated_out["maskmem_pos_enc"] = all_out.get("maskmem_pos_enc") |
| consolidated_out["image_features"] = all_out.get("image_features") |
| consolidated_out["image_pos_enc"] = all_out.get("image_pos_enc") |
| consolidated_out["local_obj_id_to_idx"] = all_out.get( |
| "local_obj_id_to_idx", {} |
| ) |
| consolidated_out["obj_ptr"] = all_out["obj_ptr"] |
| consolidated_out["object_score_logits"] = all_out["object_score_logits"] |
| if self.use_memory_selection: |
| consolidated_out["iou_score"] = all_out["iou_score"] |
| |
| consolidated_out["maskmem_features"] = all_out.get("maskmem_features") |
| consolidated_out["maskmem_pos_enc"] = all_out.get("maskmem_pos_enc") |
| consolidated_out["image_features"] = all_out.get("image_features") |
| consolidated_out["image_pos_enc"] = all_out.get("image_pos_enc") |
| consolidated_out["local_obj_id_to_idx"] = all_out.get( |
| "local_obj_id_to_idx", {} |
| ) |
| all_mask = all_out.get("pred_masks_video_res", all_out["pred_masks"]) |
| |
| |
| |
| if all_mask.shape[-2:] == (consolidated_H, consolidated_W): |
| consolidated_out[consolidated_mask_key] = all_mask |
| else: |
| |
| |
| is_downsampling = all_mask.shape[-1] > consolidated_W |
| resized_mask = torch.nn.functional.interpolate( |
| all_mask, |
| size=(consolidated_H, consolidated_W), |
| mode="bilinear", |
| align_corners=False, |
| antialias=is_downsampling, |
| ) |
| consolidated_out[consolidated_mask_key] = resized_mask |
|
|
| |
| |
| obj_score_logits_list = [] |
| obj_ptr_list = [] if need_to_reconstruct_from_per_obj else None |
| iou_scores_list = ( |
| [] |
| if need_to_reconstruct_from_per_obj and self.use_memory_selection |
| else None |
| ) |
|
|
| |
| |
| if ( |
| need_to_reconstruct_from_per_obj |
| and consolidated_mask_key not in consolidated_out |
| ): |
| |
| consolidated_out[consolidated_mask_key] = torch.zeros( |
| (consolidated_batch_size, 1, consolidated_H, consolidated_W), |
| dtype=torch.float32, |
| device=inference_state["storage_device"], |
| ) |
| consolidated_out["object_score_logits"] = torch.full( |
| (consolidated_batch_size, 1), |
| NO_OBJ_SCORE, |
| dtype=torch.float32, |
| device=inference_state["storage_device"], |
| ) |
|
|
| for obj_idx in range( |
| consolidated_batch_size |
| ): |
| |
| if obj_idx not in inference_state["temp_output_dict_per_obj"]: |
| continue |
| if obj_idx not in inference_state["output_dict_per_obj"]: |
| continue |
| obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] |
| obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] |
| out = obj_temp_output_dict[storage_key].get(frame_idx, None) |
| |
| |
| |
| |
| if out is None: |
| out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) |
| if out is None: |
| out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) |
| if out is None: |
| |
| continue |
| |
| |
| obj_mask = out.get("pred_masks_video_res") |
| if obj_mask is None: |
| obj_mask = out.get("pred_masks") |
| consolidated_pred_masks = consolidated_out[consolidated_mask_key] |
|
|
| |
| |
| if obj_idx >= consolidated_pred_masks.shape[0]: |
| pad_size = obj_idx + 1 - consolidated_pred_masks.shape[0] |
| consolidated_pred_masks = torch.cat( |
| [ |
| consolidated_pred_masks, |
| torch.zeros( |
| ( |
| pad_size, |
| 1, |
| consolidated_pred_masks.shape[-2], |
| consolidated_pred_masks.shape[-1], |
| ), |
| dtype=consolidated_pred_masks.dtype, |
| device=consolidated_pred_masks.device, |
| ), |
| ], |
| dim=0, |
| ) |
| consolidated_out[consolidated_mask_key] = consolidated_pred_masks |
| |
| if "object_score_logits" in consolidated_out: |
| consolidated_scores = consolidated_out["object_score_logits"] |
| consolidated_scores = torch.cat( |
| [ |
| consolidated_scores, |
| torch.full( |
| (pad_size, 1), |
| NO_OBJ_SCORE, |
| dtype=consolidated_scores.dtype, |
| device=consolidated_scores.device, |
| ), |
| ], |
| dim=0, |
| ) |
| consolidated_out["object_score_logits"] = consolidated_scores |
|
|
| if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: |
| |
| if obj_mask.dtype != consolidated_pred_masks.dtype: |
| obj_mask = obj_mask.to(consolidated_pred_masks.dtype) |
| consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask |
| else: |
| |
| is_downsampling = "pred_masks_video_res" in out |
| resized_obj_mask = torch.nn.functional.interpolate( |
| obj_mask, |
| size=consolidated_pred_masks.shape[-2:], |
| mode="bilinear", |
| align_corners=False, |
| antialias=is_downsampling, |
| ) |
| |
| if resized_obj_mask.dtype != consolidated_pred_masks.dtype: |
| resized_obj_mask = resized_obj_mask.to( |
| consolidated_pred_masks.dtype |
| ) |
| consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask |
|
|
| |
| if need_to_reconstruct_from_per_obj: |
| if "object_score_logits" in out: |
| obj_score_logits_list.append(out["object_score_logits"]) |
| if self.use_memory_selection and "iou_score" in out: |
| iou_scores_list.append(out["iou_score"]) |
|
|
| |
| if need_to_reconstruct_from_per_obj: |
| |
| |
| |
| |
| |
| if not obj_score_logits_list and run_mem_encoder: |
| run_mem_encoder = False |
|
|
| if obj_score_logits_list: |
| consolidated_out["object_score_logits"] = torch.cat( |
| obj_score_logits_list, dim=0 |
| ) |
| else: |
| |
| device = inference_state["device"] |
| consolidated_out["object_score_logits"] = torch.zeros( |
| (batch_size, 1), |
| dtype=torch.float32, |
| device=device, |
| ) |
|
|
| if self.use_memory_selection: |
| if iou_scores_list: |
| consolidated_out["iou_score"] = torch.cat(iou_scores_list, dim=0) |
| else: |
| consolidated_out["iou_score"] = None |
|
|
| |
| consolidated_out["obj_ptr"] = None |
|
|
| |
| |
| if run_mem_encoder: |
| device = inference_state["device"] |
| high_res_masks = torch.nn.functional.interpolate( |
| consolidated_out["pred_masks"].to(device, non_blocking=True), |
| size=(self.image_size, self.image_size), |
| mode="bilinear", |
| align_corners=False, |
| ) |
| high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) |
| maskmem_features, maskmem_pos_enc, image_features, image_pos_enc = ( |
| self._run_memory_encoder( |
| inference_state=inference_state, |
| frame_idx=frame_idx, |
| batch_size=batch_size, |
| high_res_masks=high_res_masks, |
| object_score_logits=consolidated_out["object_score_logits"], |
| is_mask_from_pts=True, |
| conditioning_objects=consolidated_out[ |
| "conditioning_objects" |
| ], |
| ) |
| ) |
| consolidated_out["maskmem_features"] = maskmem_features |
| consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc |
| consolidated_out["image_features"] = image_features |
| consolidated_out["image_pos_enc"] = image_pos_enc |
|
|
| return consolidated_out |
|
|
| @torch.inference_mode() |
| def propagate_in_video_preflight(self, inference_state, run_mem_encoder=True): |
| """Prepare inference_state and consolidate temporary outputs before tracking.""" |
| inference_state["tracking_has_started"] = True |
| batch_size = self._get_obj_num(inference_state) |
|
|
| |
| |
| temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] |
| output_dict = inference_state["output_dict"] |
| |
| |
| |
| consolidated_frame_inds = inference_state["consolidated_frame_inds"] |
| for is_cond in [False, True]: |
| |
| storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" |
| |
| |
| |
| temp_frame_inds = set() |
| for obj_temp_output_dict in temp_output_dict_per_obj.values(): |
| temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) |
| consolidated_frame_inds[storage_key].update(temp_frame_inds) |
| |
| for frame_idx in temp_frame_inds: |
| consolidated_out = self._consolidate_temp_output_across_obj( |
| inference_state, |
| frame_idx, |
| is_cond=is_cond, |
| run_mem_encoder=run_mem_encoder, |
| ) |
| |
| output_dict[storage_key][frame_idx] = consolidated_out |
| self._add_output_per_object( |
| inference_state, frame_idx, consolidated_out, storage_key |
| ) |
| clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( |
| self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 |
| ) |
| if clear_non_cond_mem: |
| |
| self._clear_non_cond_mem_around_input(inference_state, frame_idx) |
|
|
| |
| for obj_temp_output_dict in temp_output_dict_per_obj.values(): |
| obj_temp_output_dict[storage_key].clear() |
|
|
| |
| |
| for frame_idx in output_dict["cond_frame_outputs"]: |
| output_dict["non_cond_frame_outputs"].pop(frame_idx, None) |
| for obj_output_dict in inference_state["output_dict_per_obj"].values(): |
| for frame_idx in obj_output_dict["cond_frame_outputs"]: |
| obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) |
| for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: |
| assert frame_idx in output_dict["cond_frame_outputs"] |
| consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) |
|
|
| |
| |
| all_consolidated_frame_inds = ( |
| consolidated_frame_inds["cond_frame_outputs"] |
| | consolidated_frame_inds["non_cond_frame_outputs"] |
| ) |
|
|
| input_frames_inds = set() |
| for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): |
| input_frames_inds.update(point_inputs_per_frame.keys()) |
| for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): |
| input_frames_inds.update(mask_inputs_per_frame.keys()) |
| assert all_consolidated_frame_inds == input_frames_inds |
| |
| if inference_state["first_ann_frame_idx"] is None: |
| inference_state["first_ann_frame_idx"] = min( |
| input_frames_inds, default=None |
| ) |
| |
| |
| if ( |
| inference_state["first_ann_frame_idx"] |
| not in output_dict["cond_frame_outputs"] |
| ): |
| inference_state["first_ann_frame_idx"] = min( |
| output_dict["cond_frame_outputs"], default=None |
| ) |
|
|
| def _get_processing_order( |
| self, inference_state, start_frame_idx, max_frame_num_to_track, reverse |
| ): |
| num_frames = inference_state["num_frames"] |
| |
| if self.always_start_from_first_ann_frame: |
| |
| |
| start_frame_idx = inference_state["first_ann_frame_idx"] |
| if start_frame_idx is None: |
| |
| start_frame_idx = min(inference_state["output_dict"]["cond_frame_outputs"]) |
| if max_frame_num_to_track is None: |
| |
| max_frame_num_to_track = num_frames |
| if reverse: |
| end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) |
| if start_frame_idx > 0: |
| processing_order = range(start_frame_idx, end_frame_idx - 1, -1) |
| else: |
| |
| |
| |
| |
| processing_order = [0] |
| else: |
| end_frame_idx = min( |
| start_frame_idx + max_frame_num_to_track, num_frames - 1 |
| ) |
| processing_order = range(start_frame_idx, end_frame_idx + 1) |
| return processing_order |
|
|
| @torch.inference_mode() |
| def propagate_in_video( |
| self, |
| inference_state, |
| start_frame_idx, |
| max_frame_num_to_track, |
| reverse, |
| tqdm_disable=False, |
| obj_ids=None, |
| run_mem_encoder=True, |
| ): |
| """Propagate the input points across frames to track in the entire video.""" |
| output_dict = inference_state["output_dict"] |
| consolidated_frame_inds = inference_state["consolidated_frame_inds"] |
| if obj_ids is not None: |
| raise NotImplementedError( |
| "Per-object tracking yet for batched inference if not implemented." |
| ) |
| obj_ids = inference_state["obj_ids"] |
| batch_size = self._get_obj_num(inference_state) |
| if len(output_dict["cond_frame_outputs"]) == 0: |
| raise RuntimeError("No points are provided; please add points first") |
| clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( |
| self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 |
| ) |
| assert clear_non_cond_mem is False, "Not implemented" |
|
|
| processing_order = self._get_processing_order( |
| inference_state, |
| start_frame_idx, |
| max_frame_num_to_track, |
| reverse, |
| ) |
|
|
| for frame_idx in tqdm( |
| processing_order, desc="propagate in video", disable=tqdm_disable |
| ): |
| |
| |
| |
| |
| if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: |
| storage_key = "cond_frame_outputs" |
| current_out = output_dict[storage_key][frame_idx] |
| pred_masks = current_out["pred_masks"] |
| if clear_non_cond_mem: |
| |
| self._clear_non_cond_mem_around_input(inference_state, frame_idx) |
| elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: |
| storage_key = "non_cond_frame_outputs" |
| current_out = output_dict[storage_key][frame_idx] |
| pred_masks = current_out["pred_masks"] |
| else: |
| storage_key = "non_cond_frame_outputs" |
| with torch.profiler.record_function( |
| "VideoTrackingMultiplexDemo._run_single_frame_inference" |
| ): |
| current_out, pred_masks = self._run_single_frame_inference( |
| inference_state=inference_state, |
| output_dict=output_dict, |
| frame_idx=frame_idx, |
| batch_size=batch_size, |
| is_init_cond_frame=False, |
| point_inputs=None, |
| mask_inputs=None, |
| reverse=reverse, |
| run_mem_encoder=run_mem_encoder, |
| ) |
| current_out["local_obj_id_to_idx"] = deepcopy( |
| inference_state["obj_id_to_idx"] |
| ) |
| output_dict[storage_key][frame_idx] = current_out |
| |
| |
| self._add_output_per_object( |
| inference_state, frame_idx, current_out, storage_key |
| ) |
| inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} |
|
|
| |
| |
| low_res_masks, video_res_masks = self._get_orig_video_res_output( |
| inference_state, pred_masks |
| ) |
| yield frame_idx, obj_ids, low_res_masks, video_res_masks |
|
|
| def _add_output_per_object( |
| self, inference_state, frame_idx, current_out, storage_key |
| ): |
| """ |
| Split a multi-object output into per-object output slices and add them into |
| `output_dict_per_obj`. The resulting slices share the same tensor storage. |
| """ |
| |
| |
|
|
| output_dict_per_obj = inference_state["output_dict_per_obj"] |
| for obj_idx, obj_output_dict in output_dict_per_obj.items(): |
| obj_slice = slice(obj_idx, obj_idx + 1) |
| obj_out = { |
| "pred_masks": current_out["pred_masks"][obj_slice], |
| "object_score_logits": current_out["object_score_logits"][obj_slice], |
| } |
| if self.use_memory_selection: |
| obj_out["iou_score"] = current_out["iou_score"][obj_slice] |
| obj_output_dict[storage_key][frame_idx] = obj_out |
|
|
| @torch.inference_mode() |
| def clear_all_points_in_frame( |
| self, |
| inference_state, |
| frame_idx, |
| obj_id, |
| need_output=True, |
| preserve_user_refined: bool = False, |
| ): |
| """Remove all input points or mask in a specific frame for a given object.""" |
| obj_idx = self._obj_id_to_idx(inference_state, obj_id) |
|
|
| |
| inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None) |
| inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None) |
|
|
| |
| if ( |
| not preserve_user_refined |
| and "user_refined_frames_per_obj" in inference_state |
| ): |
| user_refined_map = inference_state["user_refined_frames_per_obj"] |
| if obj_id in user_refined_map: |
| user_refined_map[obj_id].discard(frame_idx) |
|
|
| temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] |
| temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) |
| temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) |
|
|
| |
| batch_size = self._get_obj_num(inference_state) |
| frame_has_input = False |
| for obj_idx2 in range(batch_size): |
| |
| if obj_idx2 not in inference_state["point_inputs_per_obj"]: |
| continue |
| if obj_idx2 not in inference_state["mask_inputs_per_obj"]: |
| continue |
| if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]: |
| frame_has_input = True |
| break |
| if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]: |
| frame_has_input = True |
| break |
|
|
| |
| |
| if not frame_has_input: |
| output_dict = inference_state["output_dict"] |
| consolidated_frame_inds = inference_state["consolidated_frame_inds"] |
| consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx) |
| consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) |
| |
| out = output_dict["cond_frame_outputs"].pop(frame_idx, None) |
| if out is not None: |
| |
| |
| output_dict["non_cond_frame_outputs"][frame_idx] = out |
| inference_state["frames_already_tracked"].pop(frame_idx, None) |
| |
| for obj_idx2 in range(batch_size): |
| |
| if obj_idx2 not in inference_state["output_dict_per_obj"]: |
| continue |
| obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2] |
| obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) |
| if obj_out is not None: |
| obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out |
|
|
| |
| if len(output_dict["cond_frame_outputs"]) == 0: |
| self._reset_tracking_results(inference_state) |
|
|
| if not need_output: |
| return |
| |
| obj_ids = inference_state["obj_ids"] |
| is_cond = any( |
| frame_idx in obj_temp_output_dict["cond_frame_outputs"] |
| for obj_temp_output_dict in temp_output_dict_per_obj.values() |
| ) |
| consolidated_out = self._consolidate_temp_output_across_obj( |
| inference_state, |
| frame_idx, |
| is_cond=is_cond, |
| run_mem_encoder=False, |
| consolidate_at_video_res=True, |
| ) |
| _, video_res_masks = self._get_orig_video_res_output( |
| inference_state, consolidated_out["pred_masks_video_res"] |
| ) |
| low_res_masks = None |
| return frame_idx, obj_ids, low_res_masks, video_res_masks |
|
|
| @torch.inference_mode() |
| def clear_all_points_in_video(self, inference_state): |
| """Remove all input points or mask in all frames throughout the video.""" |
| self._reset_tracking_results(inference_state) |
| |
| inference_state["obj_id_to_idx"].clear() |
| inference_state["obj_idx_to_id"].clear() |
| inference_state["obj_ids"].clear() |
| inference_state["point_inputs_per_obj"].clear() |
| inference_state["mask_inputs_per_obj"].clear() |
| inference_state["output_dict_per_obj"].clear() |
| inference_state["temp_output_dict_per_obj"].clear() |
| inference_state["multiplex_state"] = None |
|
|
| def _reset_tracking_results(self, inference_state): |
| """Reset all tracking inputs and results across the videos.""" |
| for v in inference_state["point_inputs_per_obj"].values(): |
| v.clear() |
| for v in inference_state["mask_inputs_per_obj"].values(): |
| v.clear() |
| for v in inference_state["output_dict_per_obj"].values(): |
| v["cond_frame_outputs"].clear() |
| v["non_cond_frame_outputs"].clear() |
| for v in inference_state["temp_output_dict_per_obj"].values(): |
| v["cond_frame_outputs"].clear() |
| v["non_cond_frame_outputs"].clear() |
| inference_state["output_dict"]["cond_frame_outputs"].clear() |
| inference_state["output_dict"]["non_cond_frame_outputs"].clear() |
| inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() |
| inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() |
| inference_state["tracking_has_started"] = False |
| inference_state["frames_already_tracked"].clear() |
| inference_state["first_ann_frame_idx"] = None |
|
|
| def _get_image_feature(self, inference_state, frame_idx, batch_size): |
| """Compute the image features on a given frame.""" |
| |
| image, backbone_out = inference_state["cached_features"].get( |
| frame_idx, (None, None) |
| ) |
| if backbone_out is None: |
| |
| image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0) |
| |
| backbone_out = self.forward_image( |
| NestedTensor(tensors=image, mask=None), |
| need_sam3_out=True, |
| need_interactive_out=True, |
| need_propagation_out=True, |
| ) |
| |
| |
| inference_state["cached_features"] = {frame_idx: (image, backbone_out)} |
|
|
| features = self._prepare_backbone_features(backbone_out) |
| return image, features |
|
|
| def _run_single_frame_inference( |
| self, |
| inference_state, |
| output_dict, |
| frame_idx, |
| batch_size, |
| is_init_cond_frame, |
| point_inputs, |
| mask_inputs, |
| reverse, |
| run_mem_encoder, |
| prev_sam_mask_logits=None, |
| add_to_existing_state: bool = False, |
| new_obj_idxs: Optional[list[int]] = None, |
| new_obj_ids: Optional[list[int]] = None, |
| allow_new_buckets: bool = False, |
| prefer_new_buckets: bool = False, |
| reconditioning: bool = False, |
| objects_to_interact: Optional[list[int]] = None, |
| ): |
| """Run tracking on a single frame based on current inputs and previous memory.""" |
| |
| with torch.profiler.record_function( |
| "VideoTrackingMultiplexDemo._get_image_feature" |
| ): |
| image, backbone_features = self._get_image_feature( |
| inference_state, frame_idx, batch_size |
| ) |
|
|
| if add_to_existing_state or reconditioning: |
| assert new_obj_idxs is not None |
| assert new_obj_ids is not None |
|
|
| backbone_features_interactive = backbone_features["interactive"] |
| backbone_features_propagation = backbone_features["sam2_backbone_out"] |
|
|
| if add_to_existing_state or reconditioning: |
| with torch.profiler.record_function( |
| "VideoTrackingMultiplexDemo.add_new_masks_to_existing_state" |
| ): |
| |
| |
| existing_out = output_dict["cond_frame_outputs"].get(frame_idx) |
| if existing_out is None: |
| existing_out = output_dict["non_cond_frame_outputs"].get(frame_idx) |
| if existing_out is None: |
| raise RuntimeError( |
| f"No existing output found for frame {frame_idx} in either storage" |
| ) |
|
|
| |
| interactive_pix_feat = self._get_interactive_pix_mem( |
| backbone_features_interactive["vision_feats"], |
| backbone_features_interactive["feat_sizes"], |
| ) |
|
|
| |
| interactive_high_res_features = [ |
| x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) |
| for x, s in zip( |
| backbone_features_interactive["vision_feats"][:-1], |
| backbone_features_interactive["feat_sizes"][:-1], |
| ) |
| ] |
|
|
| |
| propagation_vision_feats = ( |
| backbone_features_propagation["vision_feats"] |
| if run_mem_encoder |
| else None |
| ) |
| propagation_feat_sizes = ( |
| backbone_features_propagation["feat_sizes"] |
| if run_mem_encoder |
| else None |
| ) |
|
|
| |
| if reconditioning: |
| self.recondition_masks_in_existing_state( |
| interactive_pix_feat=interactive_pix_feat, |
| interactive_high_res_features=interactive_high_res_features, |
| propagation_vision_feats=propagation_vision_feats, |
| propagation_feat_sizes=propagation_feat_sizes, |
| new_masks=mask_inputs, |
| obj_idxs_in_mask=new_obj_idxs, |
| obj_ids_in_mask=new_obj_ids, |
| prev_output=existing_out, |
| multiplex_state=inference_state["multiplex_state"], |
| add_mask_to_memory=run_mem_encoder, |
| ) |
| else: |
| |
| |
| new_masks_from_points = None |
| if mask_inputs is None and point_inputs is not None: |
| with torch.profiler.record_function( |
| "VideoTrackingMultiplexDemo.points_to_masks" |
| ): |
| multimask_output = self._use_multimask( |
| is_init_cond_frame, point_inputs=point_inputs |
| ) |
| interaction_out = self._forward_sam_heads( |
| backbone_features=interactive_pix_feat, |
| point_inputs=point_inputs, |
| mask_inputs=None, |
| interactive_high_res_features=interactive_high_res_features, |
| multimask_output=multimask_output, |
| objects_to_interact=new_obj_idxs, |
| multiplex_state=inference_state["multiplex_state"], |
| ) |
| new_masks_from_points = interaction_out["low_res_masks"] |
|
|
| self.add_new_masks_to_existing_state( |
| interactive_pix_feat=interactive_pix_feat, |
| interactive_high_res_features=interactive_high_res_features, |
| propagation_vision_feats=propagation_vision_feats, |
| propagation_feat_sizes=propagation_feat_sizes, |
| new_masks=( |
| mask_inputs |
| if mask_inputs is not None |
| else new_masks_from_points |
| ), |
| obj_idxs_in_mask=new_obj_idxs, |
| obj_ids_in_mask=new_obj_ids, |
| prev_output=existing_out, |
| multiplex_state=inference_state["multiplex_state"], |
| add_mask_to_memory=run_mem_encoder, |
| are_masks_from_pts=(mask_inputs is None), |
| allow_new_buckets=allow_new_buckets, |
| prefer_new_buckets=prefer_new_buckets, |
| ) |
|
|
| |
| current_out = existing_out |
| else: |
| |
| assert point_inputs is None or mask_inputs is None |
| with torch.profiler.record_function( |
| "VideoTrackingMultiplexDemo.track_step" |
| ): |
| current_out = self.track_step( |
| frame_idx=frame_idx, |
| is_init_cond_frame=is_init_cond_frame, |
| backbone_features_interactive=backbone_features_interactive, |
| backbone_features_propagation=backbone_features_propagation, |
| image=image, |
| point_inputs=point_inputs, |
| mask_inputs=mask_inputs, |
| gt_masks=None, |
| frames_to_add_correction_pt=[], |
| output_dict=output_dict, |
| num_frames=inference_state["num_frames"], |
| track_in_reverse=reverse, |
| run_mem_encoder=run_mem_encoder, |
| prev_sam_mask_logits=prev_sam_mask_logits, |
| multiplex_state=inference_state["multiplex_state"], |
| objects_to_interact=objects_to_interact, |
| ) |
|
|
| |
| storage_device = inference_state["storage_device"] |
| if current_out.get("maskmem_features") is not None: |
| maskmem_features = current_out["maskmem_features"] |
| maskmem_features = maskmem_features.to( |
| device=storage_device, dtype=torch.bfloat16, non_blocking=True |
| ) |
| else: |
| maskmem_features = None |
|
|
| if current_out.get("image_features") is not None: |
| assert "image_pos_enc" in current_out |
| image_features = current_out["image_features"].to( |
| storage_device, non_blocking=True |
| ) |
| image_pos_enc = current_out["image_pos_enc"].to( |
| storage_device, non_blocking=True |
| ) |
| else: |
| image_features = image_pos_enc = None |
|
|
| pred_masks_gpu = current_out["pred_masks"] |
| pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) |
| |
| with torch.profiler.record_function( |
| "VideoTrackingMultiplexDemo.maskmem_pos_enc" |
| ): |
| maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) |
| |
| obj_ptr = current_out["obj_ptr"] |
| object_score_logits = current_out["object_score_logits"] |
| conditioning_objects = current_out["conditioning_objects"] |
| |
| compact_current_out = { |
| "maskmem_features": maskmem_features, |
| "maskmem_pos_enc": maskmem_pos_enc, |
| "image_features": image_features, |
| "image_pos_enc": image_pos_enc, |
| "pred_masks": pred_masks, |
| "obj_ptr": obj_ptr, |
| "object_score_logits": object_score_logits, |
| "conditioning_objects": conditioning_objects, |
| } |
| if self.use_memory_selection: |
| with torch.profiler.record_function( |
| "VideoTrackingMultiplexDemo.use_memory_selection" |
| ): |
| compact_current_out["iou_score"] = current_out["iou_score"] |
| compact_current_out["eff_iou_score"] = self.cal_mem_score( |
| object_score_logits, current_out["iou_score"] |
| ) |
| return compact_current_out, pred_masks_gpu |
|
|
| def _run_memory_encoder( |
| self, |
| inference_state, |
| frame_idx, |
| batch_size, |
| high_res_masks, |
| object_score_logits, |
| is_mask_from_pts, |
| conditioning_objects=None, |
| ): |
| """ |
| Run the memory encoder on `high_res_masks`. This is usually after applying |
| non-overlapping constraints to object scores. Since their scores changed, their |
| memory also need to be computed again with the memory encoder. |
| """ |
| |
| image, backbone_features = self._get_image_feature( |
| inference_state, frame_idx, batch_size |
| ) |
| backbone_features_propagation = backbone_features["sam2_backbone_out"] |
| propagation_vision_feats = backbone_features_propagation["vision_feats"] |
| propagation_vision_pos_embeds = backbone_features_propagation[ |
| "vision_pos_embeds" |
| ] |
| propagation_feat_sizes = backbone_features_propagation["feat_sizes"] |
|
|
| |
| if conditioning_objects is None: |
| output_dict = inference_state["output_dict"] |
| for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: |
| storage = output_dict[storage_key] |
| if frame_idx not in storage: |
| continue |
| conditioning_objects = storage[frame_idx]["conditioning_objects"] |
| break |
| else: |
| raise ValueError(f"conditioning objects not found at {frame_idx=}") |
|
|
| maskmem_features, maskmem_pos_enc = self._encode_new_memory( |
| image=image, |
| current_vision_feats=propagation_vision_feats, |
| feat_sizes=propagation_feat_sizes, |
| pred_masks_high_res=high_res_masks, |
| object_score_logits=object_score_logits, |
| is_mask_from_pts=is_mask_from_pts, |
| conditioning_objects=conditioning_objects, |
| multiplex_state=inference_state["multiplex_state"], |
| ) |
|
|
| |
| storage_device = inference_state["storage_device"] |
| maskmem_features = maskmem_features.to(torch.bfloat16) |
| maskmem_features = maskmem_features.to(storage_device, non_blocking=True) |
| |
| maskmem_pos_enc = self._get_maskmem_pos_enc( |
| inference_state, {"maskmem_pos_enc": maskmem_pos_enc} |
| ) |
|
|
| image_features = propagation_vision_feats[-1] |
| image_features = image_features.to(storage_device, non_blocking=True) |
| image_pos_enc = propagation_vision_pos_embeds[-1] |
| image_pos_enc = image_pos_enc.to(storage_device, non_blocking=True) |
| return maskmem_features, maskmem_pos_enc, image_features, image_pos_enc |
|
|
| def _get_maskmem_pos_enc(self, inference_state, current_out): |
| """ |
| `maskmem_pos_enc` is the same across frames and objects, so we cache it as |
| a constant in the inference session to reduce session storage size. |
| """ |
| model_constants = inference_state["constants"] |
| |
| out_maskmem_pos_enc = current_out.get("maskmem_pos_enc") |
| if out_maskmem_pos_enc is not None: |
| if "maskmem_pos_enc" not in model_constants: |
| assert isinstance(out_maskmem_pos_enc, list) |
| |
| maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] |
| model_constants["maskmem_pos_enc"] = maskmem_pos_enc |
| else: |
| maskmem_pos_enc = model_constants["maskmem_pos_enc"] |
| |
| batch_size = out_maskmem_pos_enc[0].size(0) |
| expanded_maskmem_pos_enc = [ |
| x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc |
| ] |
| else: |
| expanded_maskmem_pos_enc = None |
| return expanded_maskmem_pos_enc |
|
|
| @torch.inference_mode() |
| def remove_object( |
| self, |
| inference_state, |
| obj_id: int, |
| strict=False, |
| need_output=True, |
| clear_user_refined_map: bool = True, |
| ): |
| """ |
| Remove a single object from the tracking state. |
| |
| This is a convenience wrapper around remove_objects() for removing a single object. |
| |
| Args: |
| inference_state: Current inference state |
| obj_id: Object ID to remove |
| strict: If True, raise error if object doesn't exist |
| need_output: Whether to return updated frames |
| |
| Returns: |
| Tuple of (remaining_obj_ids, updated_frames) |
| """ |
| return self.remove_objects( |
| inference_state, |
| obj_ids=[obj_id], |
| strict=strict, |
| need_output=need_output, |
| clear_user_refined_map=clear_user_refined_map, |
| ) |
|
|
| @torch.inference_mode() |
| def remove_objects( |
| self, |
| inference_state, |
| obj_ids: Iterable[int], |
| strict=False, |
| need_output=True, |
| clear_user_refined_map: bool = True, |
| ): |
| """ |
| Remove a list of object ids from the tracking state. If strict is True, we check whether |
| the object ids actually exist and raise an error if any of them don't exist. |
| """ |
| obj_ids = list(obj_ids) |
| old_obj_idxs_to_rm = [ |
| inference_state["obj_id_to_idx"].get(obj_id, None) for obj_id in obj_ids |
| ] |
| updated_frames = [] |
| actually_used_obj_ids = [] |
| removing_any = False |
| for old_obj_idx_to_rm, obj_id in zip(old_obj_idxs_to_rm, obj_ids, strict=True): |
| if old_obj_idx_to_rm is None: |
| if strict: |
| raise ValueError( |
| f"Object id {obj_id} does not exist in the tracking state." |
| ) |
| else: |
| actually_used_obj_ids.append(obj_id) |
| removing_any = True |
| if not removing_any: |
| return inference_state["obj_ids"], updated_frames |
|
|
| |
| old_obj_idxs_to_rm = [x for x in old_obj_idxs_to_rm if x is not None] |
| obj_ids = actually_used_obj_ids |
| removed_obj_ids = list(obj_ids) |
|
|
| |
| |
| |
| |
| |
| if clear_user_refined_map and "user_refined_frames_per_obj" in inference_state: |
| user_refined_map = inference_state["user_refined_frames_per_obj"] |
| for removed_obj_id in removed_obj_ids: |
| if removed_obj_id in user_refined_map: |
| user_refined_map.pop(removed_obj_id, None) |
|
|
| all_obj_input_frames_inds = set() |
| for old_obj_idx_to_rm, obj_id in zip(old_obj_idxs_to_rm, obj_ids, strict=True): |
| obj_input_frames_inds = set() |
| obj_input_frames_inds.update( |
| inference_state["point_inputs_per_obj"][old_obj_idx_to_rm] |
| ) |
| obj_input_frames_inds.update( |
| inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm] |
| ) |
| for frame_idx in obj_input_frames_inds: |
| self.clear_all_points_in_frame( |
| inference_state, |
| frame_idx, |
| obj_id, |
| need_output=False, |
| preserve_user_refined=not clear_user_refined_map, |
| ) |
| all_obj_input_frames_inds.update(obj_input_frames_inds) |
|
|
| |
| |
| old_obj_ids = inference_state["obj_ids"] |
| old_obj_inds = list(range(len(old_obj_ids))) |
| remain_old_obj_inds = old_obj_inds.copy() |
| for old_obj_idx_to_rm in old_obj_idxs_to_rm: |
| remain_old_obj_inds.remove(old_obj_idx_to_rm) |
| new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds] |
| new_obj_inds = list(range(len(new_obj_ids))) |
| |
| old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds)) |
| inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds)) |
| inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids)) |
| inference_state["obj_ids"] = new_obj_ids |
|
|
| if len(new_obj_ids) == 0: |
| return new_obj_ids, updated_frames |
|
|
| |
| |
| |
| def _map_keys(container): |
| new_kvs = [] |
| for k in old_obj_inds: |
| v = container.pop(k) |
| if k in old_idx_to_new_idx: |
| new_kvs.append((old_idx_to_new_idx[k], v)) |
| container.update(new_kvs) |
|
|
| _map_keys(inference_state["point_inputs_per_obj"]) |
| _map_keys(inference_state["mask_inputs_per_obj"]) |
| _map_keys(inference_state["output_dict_per_obj"]) |
| _map_keys(inference_state["temp_output_dict_per_obj"]) |
|
|
| multiplex_state: MultiplexState = inference_state["multiplex_state"] |
| |
| buckets_to_keep = multiplex_state.remove_objects( |
| old_obj_idxs_to_rm, strict=True |
| ) |
| obj_ids = set(obj_ids) |
|
|
| |
| def _slice_state(output_dict, storage_key): |
| for frame_idx, out in output_dict[storage_key].items(): |
| if out.get("maskmem_features") is not None: |
| out["maskmem_features"] = out["maskmem_features"][buckets_to_keep] |
| if out.get("maskmem_pos_enc") is not None: |
| out["maskmem_pos_enc"] = [ |
| x[buckets_to_keep] for x in out["maskmem_pos_enc"] |
| ] |
| |
| out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out) |
| if out.get("obj_ptr") is not None: |
| out["obj_ptr"] = out["obj_ptr"][buckets_to_keep] |
|
|
| |
| |
| |
| |
| local_obj_id_to_idx = out["local_obj_id_to_idx"] |
|
|
| |
| local_remain_old_obj_inds = [ |
| obj_idx |
| for obj_id, obj_idx in local_obj_id_to_idx.items() |
| if obj_id not in obj_ids |
| ] |
|
|
| |
| max_pred = out["pred_masks"].shape[0] |
| max_scores = out["object_score_logits"].shape[0] |
| keep_indices = [ |
| idx |
| for idx in local_remain_old_obj_inds |
| if 0 <= idx < max_pred and 0 <= idx < max_scores |
| ] |
| out["pred_masks"] = out["pred_masks"][keep_indices] |
| out["object_score_logits"] = out["object_score_logits"][keep_indices] |
| if self.use_memory_selection: |
| out["iou_score"] = out["iou_score"][keep_indices] |
| out["eff_iou_score"] = self.cal_mem_score( |
| out["object_score_logits"], out["iou_score"] |
| ) |
| sliced_conditioning_objects = set() |
|
|
| |
| new_local_obj_id_to_idx = {} |
| old_to_new = { |
| old_idx: new_i for new_i, old_idx in enumerate(keep_indices) |
| } |
| for obj_id, old_idx in local_obj_id_to_idx.items(): |
| if obj_id not in obj_ids: |
| |
| if old_idx in old_to_new: |
| new_idx = old_to_new[old_idx] |
| new_local_obj_id_to_idx[obj_id] = new_idx |
| if old_idx in out["conditioning_objects"]: |
| sliced_conditioning_objects.add(new_idx) |
|
|
| out["local_obj_id_to_idx"] = new_local_obj_id_to_idx |
| out["conditioning_objects"] = sliced_conditioning_objects |
|
|
| |
| self._add_output_per_object( |
| inference_state, frame_idx, out, storage_key |
| ) |
|
|
| _slice_state(inference_state["output_dict"], "cond_frame_outputs") |
| _slice_state(inference_state["output_dict"], "non_cond_frame_outputs") |
|
|
| |
| |
| if need_output: |
| temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] |
| for frame_idx in all_obj_input_frames_inds: |
| is_cond = any( |
| frame_idx in obj_temp_output_dict["cond_frame_outputs"] |
| for obj_temp_output_dict in temp_output_dict_per_obj.values() |
| ) |
| consolidated_out = self._consolidate_temp_output_across_obj( |
| inference_state, |
| frame_idx, |
| is_cond=is_cond, |
| run_mem_encoder=False, |
| consolidate_at_video_res=True, |
| ) |
| _, video_res_masks = self._get_orig_video_res_output( |
| inference_state, consolidated_out["pred_masks_video_res"] |
| ) |
| updated_frames.append((frame_idx, video_res_masks)) |
|
|
| return inference_state["obj_ids"], updated_frames |
|
|
| def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): |
| """ |
| Remove the non-conditioning memory around the input frame. When users provide |
| correction clicks, the surrounding frames' non-conditioning memories can still |
| contain outdated object appearance information and could confuse the model. |
| |
| This function clears those non-conditioning memories surrounding the interacted |
| frame to avoid giving the model both old and new information about the object. |
| """ |
| r = self.memory_temporal_stride_for_eval |
| frame_idx_begin = frame_idx - r * self.num_maskmem |
| frame_idx_end = frame_idx + r * self.num_maskmem |
| output_dict = inference_state["output_dict"] |
| non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] |
| for t in range(frame_idx_begin, frame_idx_end + 1): |
| non_cond_frame_outputs.pop(t, None) |
| for obj_output_dict in inference_state["output_dict_per_obj"].values(): |
| obj_output_dict["non_cond_frame_outputs"].pop(t, None) |
|
|
| @torch.inference_mode() |
| @torch.autocast(device_type="cuda", dtype=torch.bfloat16) |
| def warm_up_compilation( |
| self, offload_video_to_cpu=False, offload_state_to_cpu=False |
| ): |
| """ |
| Warm up the model by running a dummy inference to compile the model. This is |
| useful to avoid the compilation overhead in the first inference call. |
| """ |
| if not self.compile_all_components: |
| return |
|
|
| raise NotImplementedError( |
| "Please use `VideoTrackingMultiplexDemoPerBucketInference` instead for full model compilation." |
| ) |
|
|
|
|
| class Sam3VideoTrackingMultiplexDemo(VideoTrackingMultiplexDemo): |
| @torch.inference_mode() |
| def init_state( |
| self, |
| video_height, |
| video_width, |
| num_frames, |
| cached_features=None, |
| offload_video_to_cpu=False, |
| offload_state_to_cpu=False, |
| ): |
| """Initialize a inference state.""" |
| |
| |
| |
| if not self.apply_sigmoid_to_mask_logits_for_mem_enc: |
| raise NotImplementedError( |
| "Multi-object tracking requires sigmoid in memory encoder for non-overlapping constraints." |
| ) |
| inference_state = {} |
| |
| inference_state["num_frames"] = num_frames |
| |
| |
| inference_state["offload_video_to_cpu"] = offload_video_to_cpu |
| |
| |
| |
| |
| inference_state["offload_state_to_cpu"] = offload_state_to_cpu |
| |
| inference_state["video_height"] = video_height |
| inference_state["video_width"] = video_width |
| inference_state["device"] = get_accelerator_device() |
| if offload_state_to_cpu: |
| inference_state["storage_device"] = torch.device("cpu") |
| else: |
| inference_state["storage_device"] = get_accelerator_device() |
| |
| inference_state["point_inputs_per_obj"] = {} |
| inference_state["mask_inputs_per_obj"] = {} |
| |
| inference_state["cached_features"] = ( |
| {} if cached_features is None else cached_features |
| ) |
| |
| inference_state["constants"] = {} |
| |
| inference_state["obj_id_to_idx"] = OrderedDict() |
| inference_state["obj_idx_to_id"] = OrderedDict() |
| inference_state["obj_ids"] = [] |
| |
| inference_state["output_dict"] = { |
| "cond_frame_outputs": {}, |
| "non_cond_frame_outputs": {}, |
| } |
| |
| inference_state["first_ann_frame_idx"] = None |
| |
| inference_state["output_dict_per_obj"] = {} |
| |
| |
| inference_state["temp_output_dict_per_obj"] = {} |
| |
| |
| inference_state["consolidated_frame_inds"] = { |
| "cond_frame_outputs": set(), |
| "non_cond_frame_outputs": set(), |
| } |
| |
| inference_state["tracking_has_started"] = False |
| inference_state["frames_already_tracked"] = {} |
| inference_state["multiplex_state"] = None |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| self.clear_all_points_in_video(inference_state) |
| return inference_state |
|
|
| def _suppress_shrinked_masks( |
| self, pred_masks, new_pred_masks, shrink_threshold=0.3 |
| ): |
| area_before = (pred_masks > 0).sum(dim=(-1, -2)) |
| area_after = (new_pred_masks > 0).sum(dim=(-1, -2)) |
| area_before = torch.clamp(area_before, min=1.0) |
| area_ratio = area_after / area_before |
| keep = area_ratio >= shrink_threshold |
| keep_mask = keep[..., None, None].expand_as(pred_masks) |
| pred_masks_after = torch.where( |
| keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0) |
| ) |
| return pred_masks_after |
|
|
| @staticmethod |
| def _suppress_object_pw_area_shrinkage(pred_masks): |
| """ |
| This function suppresses masks that shrink in area after applying pixelwise non-overlapping constriants. |
| Note that the final output can still be overlapping. |
| """ |
| |
| |
| |
| |
|
|
| batch_size = pred_masks.size(0) |
| if batch_size == 1: |
| return pred_masks |
|
|
| device = pred_masks.device |
| |
| max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) |
| |
| batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] |
| keep = max_obj_inds == batch_obj_inds |
| |
| |
| pixel_level_non_overlapping_masks = torch.where( |
| keep, pred_masks, torch.clamp(pred_masks, max=-10.0) |
| ) |
|
|
| |
| |
| |
| |
| |
|
|
| shrink_threshold = 0.3 |
| area_before = (pred_masks > 0).sum(dim=(-1, -2)) |
| area_after = (pixel_level_non_overlapping_masks > 0).sum(dim=(-1, -2)) |
| area_before = torch.clamp(area_before, min=1.0) |
| area_ratio = area_after / area_before |
| keep = area_ratio >= shrink_threshold |
| keep_mask = keep[..., None, None].expand_as(pred_masks) |
| pred_masks_after = torch.where( |
| keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0) |
| ) |
|
|
| return pred_masks_after |
|
|
| def _apply_object_wise_non_overlapping_constraints( |
| self, pred_masks, obj_scores, background_value=-10.0 |
| ): |
| """ |
| Applies non-overlapping constraints object wise (i.e. only one object can claim the overlapping region) |
| """ |
| |
| |
| pred_masks_single_score = torch.where( |
| pred_masks > 0, obj_scores[..., None, None], background_value |
| ) |
| |
| pixel_level_non_overlapping_masks = super()._apply_non_overlapping_constraints( |
| pred_masks_single_score |
| ) |
| |
| pred_masks = torch.where( |
| pixel_level_non_overlapping_masks > 0, |
| pred_masks, |
| torch.clamp(pred_masks, max=background_value), |
| ) |
| return pred_masks |
|
|
| @torch.inference_mode() |
| def propagate_in_video( |
| self, |
| inference_state, |
| start_frame_idx, |
| max_frame_num_to_track, |
| reverse, |
| tqdm_disable=False, |
| obj_ids=None, |
| run_mem_encoder=True, |
| ): |
| """Propagate the input points across frames to track in the entire video.""" |
| |
| output_dict = inference_state["output_dict"] |
| consolidated_frame_inds = inference_state["consolidated_frame_inds"] |
| if obj_ids is not None: |
| raise NotImplementedError( |
| "Per-object tracking yet for batched inference if not implemented." |
| ) |
| obj_ids = inference_state["obj_ids"] |
| batch_size = self._get_obj_num(inference_state) |
| if len(output_dict["cond_frame_outputs"]) == 0: |
| raise RuntimeError("No points are provided; please add points first") |
| clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( |
| self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 |
| ) |
|
|
| processing_order = self._get_processing_order( |
| inference_state, |
| start_frame_idx, |
| max_frame_num_to_track, |
| reverse, |
| ) |
|
|
| for frame_idx in tqdm( |
| processing_order, desc="propagate in video", disable=tqdm_disable |
| ): |
| |
| |
| |
| |
| if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: |
| storage_key = "cond_frame_outputs" |
| current_out = output_dict[storage_key][frame_idx] |
| pred_masks = current_out["pred_masks"] |
| obj_scores = current_out["object_score_logits"] |
| if clear_non_cond_mem: |
| |
| self._clear_non_cond_mem_around_input(inference_state, frame_idx) |
| elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: |
| storage_key = "non_cond_frame_outputs" |
| current_out = output_dict[storage_key][frame_idx] |
| pred_masks = current_out["pred_masks"] |
| obj_scores = current_out["object_score_logits"] |
| else: |
| storage_key = "non_cond_frame_outputs" |
| with torch.profiler.record_function( |
| "VideoTrackingMultiplexDemo._run_single_frame_inference" |
| ): |
| current_out, pred_masks = self._run_single_frame_inference( |
| inference_state=inference_state, |
| output_dict=output_dict, |
| frame_idx=frame_idx, |
| batch_size=batch_size, |
| is_init_cond_frame=False, |
| point_inputs=None, |
| mask_inputs=None, |
| reverse=reverse, |
| run_mem_encoder=run_mem_encoder, |
| ) |
| obj_scores = current_out["object_score_logits"] |
| current_out["local_obj_id_to_idx"] = deepcopy( |
| inference_state["obj_id_to_idx"] |
| ) |
| output_dict[storage_key][frame_idx] = current_out |
|
|
| |
| |
| self._add_output_per_object( |
| inference_state, frame_idx, current_out, storage_key |
| ) |
| inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} |
|
|
| |
| |
| low_res_masks, video_res_masks = self._get_orig_video_res_output( |
| inference_state, pred_masks |
| ) |
| yield frame_idx, obj_ids, low_res_masks, video_res_masks, obj_scores |
|
|