| from collections import defaultdict |
|
|
| """ |
| Video tracking model with multiplexing support. |
| |
| This file extends the base video tracking with prompt functionality to add: |
| - Multiplexing: Support for processing multiple objects simultaneously |
| - Recording image features in memory to support the decoupled transformer for memory reading |
| """ |
|
|
| import logging |
| from copy import deepcopy |
|
|
| try: |
| from typing import Iterable, Literal, NotRequired, Optional, Required, TypedDict |
| except ImportError: |
| from typing_extensions import ( |
| Iterable, |
| Literal, |
| NotRequired, |
| Optional, |
| Required, |
| TypedDict, |
| ) |
|
|
| import numpy as np |
| import torch |
| import torch.distributed |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from ..model.data_misc import BatchedDatapoint, NestedTensor |
| from ..model.memory import SimpleMaskEncoder |
| from ..model.multiplex_mask_decoder import MLP, MultiplexMaskDecoder |
| from ..model.multiplex_utils import MultiplexController, MultiplexState |
| from ..model.sam3_tracker_utils import ( |
| get_1d_sine_pe, |
| get_next_point, |
| sample_box_points, |
| select_closest_cond_frames, |
| ) |
| from ..sam.mask_decoder import MaskDecoder |
| from ..sam.prompt_encoder import PositionEmbeddingRandom, PromptEncoder |
| from ..sam.transformer import TwoWayTransformer |
| from timm.layers import trunc_normal_ |
|
|
|
|
| |
# Logit written into every mask pixel of an object judged absent (see the
# `torch.where(is_obj_appearing, ..., NO_OBJ_SCORE)` in `_forward_sam_heads`).
NO_OBJ_SCORE: float = -1024.0

# Neck output names. NOTE(review): presumably the keys of the interactive and
# SAM2-propagation backbone outputs; their consumer is not visible here.
neck_outs: list[str] = ["interactive", "sam2_backbone_out"]
|
|
|
|
| class SAMOutput(TypedDict, total=True): |
| |
| low_res_multimasks: torch.Tensor |
| high_res_multimasks: torch.Tensor |
| ious: torch.Tensor |
| low_res_masks: torch.Tensor |
| high_res_masks: torch.Tensor |
| object_score_logits: torch.Tensor |
| obj_ptr: NotRequired[torch.Tensor] |
|
|
|
|
| class StageOutput(TypedDict, total=False): |
| |
| conditioning_objects: Required[set[int]] |
|
|
| |
| pred_masks: torch.Tensor |
| pred_masks_high_res: torch.Tensor |
| point_inputs: dict[str, torch.Tensor] |
| mask_inputs: torch.Tensor |
| object_score_logits: torch.Tensor |
| obj_ptr: torch.Tensor |
| maskmem_features: torch.Tensor |
| maskmem_pos_enc: list[torch.Tensor] |
| image_features: torch.Tensor |
| image_pos_enc: torch.Tensor |
|
|
| |
| iou_score: torch.Tensor |
| eff_iou_score: torch.Tensor |
|
|
| |
| multistep_pred_masks: torch.Tensor |
| multistep_pred_masks_high_res: torch.Tensor |
| multistep_pred_multimasks: list[torch.Tensor] |
| multistep_pred_multimasks_high_res: list[torch.Tensor] |
| multistep_pred_ious: list[torch.Tensor] |
| multistep_point_inputs: list[dict] |
| multistep_object_score_logits: list[torch.Tensor] |
|
|
|
|
| class VideoTrackingMultiplex(nn.Module): |
| def __init__( |
| self, |
| backbone: nn.Module, |
| transformer: nn.Module, |
| maskmem_backbone: nn.Module, |
| multiplex_controller: MultiplexController, |
| num_maskmem: int = 7, |
| image_size: int = 512, |
| backbone_stride: int = 16, |
| prob_to_use_pt_input_for_train: float = 0.0, |
| prob_to_use_pt_input_for_eval: float = 0.0, |
| prob_to_use_box_input_for_train: float = 0.0, |
| prob_to_use_box_input_for_eval: float = 0.0, |
| |
| apply_sigmoid_to_mask_logits_for_mem_enc: bool = False, |
| sigmoid_scale_for_mem_enc: float = 1.0, |
| sigmoid_bias_for_mem_enc: float = 0.0, |
| |
| binarize_mask_from_pts_for_mem_enc: bool = False, |
| use_mask_input_as_output_without_sam: bool = False, |
| |
| |
| |
| |
| |
| num_frames_to_correct_for_train: int = 1, |
| num_frames_to_correct_for_eval: int = 1, |
| rand_frames_to_correct_for_train: bool = False, |
| rand_frames_to_correct_for_eval: bool = False, |
| prob_correct_all_objects_for_train: float = 0.0, |
| ratio_of_objects_to_correct_for_train: float = 1.0, |
| force_correct_all_for_conditional_inputs: bool = False, |
| rand_objects_to_correct_for_train: bool = True, |
| |
| |
| |
| |
| |
| |
| num_init_cond_frames_for_train: int = 1, |
| num_init_cond_frames_for_eval: int = 1, |
| rand_init_cond_frames_for_train: bool = True, |
| rand_init_cond_frames_for_eval: bool = False, |
| |
| |
| |
| max_cond_frames_in_attn: int = -1, |
| |
| keep_first_cond_frame=False, |
| |
| |
| add_all_frames_to_correct_as_cond: bool = False, |
| |
| |
| num_correction_pt_per_frame: int = 7, |
| |
| |
| |
| pt_sampling_for_eval: Literal["uniform", "center"] = "center", |
| |
| |
| |
| prob_to_sample_from_gt_for_train: float = 0.0, |
| |
| |
| directly_add_no_mem_embed: bool = False, |
| |
| use_high_res_features_in_sam: bool = False, |
| |
| multimask_output_in_sam: bool = False, |
| |
| |
| multimask_min_pt_num: int = 1, |
| multimask_max_pt_num: int = 1, |
| |
| multimask_output_for_tracking: bool = False, |
| |
| |
| use_multimask_token_for_obj_ptr: bool = False, |
| |
| |
| use_best_iou_mask_for_mem_enc: bool = False, |
| |
| iou_prediction_use_sigmoid: bool = False, |
| |
| iter_use_prev_mask_pred: bool = False, |
| |
| |
| forward_backbone_per_frame_for_eval: bool = False, |
| |
| |
| |
| memory_temporal_stride_for_eval: int = 1, |
| |
| |
| offload_output_to_cpu_for_eval: bool = False, |
| |
| |
| trim_past_non_cond_mem_for_eval: bool = False, |
| |
| non_overlap_masks_for_mem_enc: bool = False, |
| |
| use_obj_ptrs_in_encoder: bool = False, |
| |
| max_obj_ptrs_in_encoder: int = 16, |
| |
| add_tpos_enc_to_obj_ptrs: bool = True, |
| |
| |
| proj_tpos_enc_in_obj_ptrs: bool = False, |
| |
| |
| use_signed_tpos_enc_to_obj_ptrs: bool = False, |
| |
| |
| only_obj_ptrs_in_the_past_for_eval: bool = False, |
| |
| pred_obj_scores: bool = False, |
| |
| pred_obj_scores_mlp: bool = False, |
| |
| |
| |
| fixed_no_obj_ptr: bool = False, |
| use_no_obj_ptr: bool = True, |
| use_mlp_for_obj_ptr_proj: bool = False, |
| |
| use_linear_no_obj_ptr: bool = False, |
| |
| no_obj_embed_spatial: bool = False, |
| |
| sincos_tpos_enc: bool = True, |
| |
| sam_mask_decoder_extra_args: Optional[dict] = None, |
| |
| compile_all_components: bool = False, |
| |
| save_image_features: bool = False, |
| |
| num_multimask_outputs: int = 3, |
| |
| decode_mask_with_shared_tokens: bool = False, |
| |
| decode_mask_attribute_with_shared_tokens: bool = False, |
| share_necks: bool = False, |
| |
| |
| |
| randomness_fix: bool = False, |
| |
| add_output_suppression_embeddings: bool = False, |
| |
| add_object_conditional_embeddings: bool = False, |
| |
| add_object_unconditional_embeddings: Optional[bool] = None, |
| |
| condition_as_mask_input: bool = False, |
| condition_as_mask_input_fg: float = 1.0, |
| condition_as_mask_input_bg: float = 0.0, |
| |
| |
| |
| |
| use_maskmem_tpos_v2: bool = False, |
| |
| use_memory_selection: bool = False, |
| |
| mf_threshold: float = 0.01, |
| |
| is_dynamic_model: bool = False, |
| object_score_logit_threshold: float = 0.0, |
| stability_score_attentuation: bool = False, |
| ): |
| super().__init__() |
|
|
| |
| interactive_sam_mask_decoder_extra_args = deepcopy(sam_mask_decoder_extra_args) |
| if sam_mask_decoder_extra_args is not None: |
| dynamic_multimask_via_stability = sam_mask_decoder_extra_args.get( |
| "dynamic_multimask_via_stability", False |
| ) |
| if dynamic_multimask_via_stability: |
| sam_mask_decoder_extra_args["dynamic_multimask_via_stability"] = False |
| print( |
| "dynamic_multimask_via_stability is reset to False in the multiplex model" |
| ) |
|
|
| |
| self.backbone = backbone |
| |
| self.use_high_res_features_in_sam = use_high_res_features_in_sam |
| self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 |
| self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder |
| self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder |
| if use_obj_ptrs_in_encoder: |
| |
| |
| |
| self.interactive_mask_downsample = torch.nn.Conv2d( |
| 1, 1, kernel_size=4, stride=4 |
| ) |
|
|
| self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs |
| if proj_tpos_enc_in_obj_ptrs: |
| assert add_tpos_enc_to_obj_ptrs |
| self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs |
| self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs |
| self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval |
| self.multiplex_controller = multiplex_controller |
| self.save_image_features = save_image_features |
| self.multiplex_count = self.multiplex_controller.multiplex_count |
|
|
| |
| |
| assert transformer.decoder is None, "transformer should be encoder-only" |
| self.transformer = transformer |
| self.hidden_dim: int = transformer.d_model |
|
|
| |
| self.maskmem_backbone = maskmem_backbone |
| self.mem_dim = self.hidden_dim |
| if hasattr(self.maskmem_backbone, "out_proj") and hasattr( |
| self.maskmem_backbone.out_proj, "weight" |
| ): |
| |
| mem_dim = self.maskmem_backbone.out_proj.weight.shape[0] |
| assert mem_dim == self.hidden_dim, ( |
| "there should be no compression of memory embeddings" |
| ) |
| self.num_maskmem = num_maskmem |
| |
| self.sincos_tpos_enc = sincos_tpos_enc |
| self.use_maskmem_tpos_v2 = use_maskmem_tpos_v2 |
| |
| |
| |
| self.maskmem_tpos_enc = torch.nn.Parameter( |
| torch.zeros(num_maskmem, 1, 1, self.mem_dim) |
| ) |
| trunc_normal_(self.maskmem_tpos_enc, std=0.02) |
|
|
| |
| self.interactivity_no_mem_embed = torch.nn.Parameter( |
| torch.zeros(1, 1, self.hidden_dim) |
| ) |
| trunc_normal_(self.interactivity_no_mem_embed, std=0.02) |
| self.directly_add_no_mem_embed = directly_add_no_mem_embed |
|
|
| |
| |
| self.apply_sigmoid_to_mask_logits_for_mem_enc = ( |
| apply_sigmoid_to_mask_logits_for_mem_enc |
| ) |
| if apply_sigmoid_to_mask_logits_for_mem_enc: |
| self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc |
| self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc |
|
|
| if binarize_mask_from_pts_for_mem_enc: |
| logging.warning( |
| """ |
| The current model is not trained with binarize_mask_from_pts_for_mem_enc; |
| We force it to False here because external callers often hardcoded this |
| to True, ignoring the config. |
| Re-training should be possible. |
| """ |
| ) |
| binarize_mask_from_pts_for_mem_enc = False |
|
|
| self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc |
| self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc |
| self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval |
| |
| |
| self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam |
| self.multimask_output_in_sam = multimask_output_in_sam |
| self.multimask_min_pt_num = multimask_min_pt_num |
| self.multimask_max_pt_num = multimask_max_pt_num |
| self.multimask_output_for_tracking = multimask_output_for_tracking |
| self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr |
| self.use_best_iou_mask_for_mem_enc = use_best_iou_mask_for_mem_enc |
| self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid |
| self.object_score_logit_threshold = object_score_logit_threshold |
| self.stability_score_attentuation = stability_score_attentuation |
| if iter_use_prev_mask_pred: |
| |
| |
| |
| |
| if min(prob_to_use_pt_input_for_train, prob_to_use_pt_input_for_eval) < 1: |
| assert use_mask_input_as_output_without_sam |
| self.iter_use_prev_mask_pred = iter_use_prev_mask_pred |
|
|
| |
| |
| self.image_size = image_size |
| self.backbone_stride = backbone_stride |
| self.low_res_mask_size = self.image_size // self.backbone_stride * 4 |
| |
| |
| |
| self.input_mask_size = self.low_res_mask_size * 4 |
| self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval |
| self.offload_output_to_cpu_for_eval = offload_output_to_cpu_for_eval |
| if trim_past_non_cond_mem_for_eval: |
| assert num_frames_to_correct_for_eval <= 1, ( |
| "trim_past_non_cond_mem_for_eval=True requires that only the first frame receives prompts" |
| ) |
| self.trim_past_non_cond_mem_for_eval = trim_past_non_cond_mem_for_eval |
| self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args |
| self.interactive_sam_mask_decoder_extra_args = ( |
| interactive_sam_mask_decoder_extra_args |
| ) |
| self.pred_obj_scores = pred_obj_scores |
| self.pred_obj_scores_mlp = pred_obj_scores_mlp |
| self.fixed_no_obj_ptr = fixed_no_obj_ptr |
| self.use_no_obj_ptr = use_no_obj_ptr |
| self.use_linear_no_obj_ptr = use_linear_no_obj_ptr |
|
|
| if self.fixed_no_obj_ptr: |
| assert self.pred_obj_scores |
| assert self.use_obj_ptrs_in_encoder |
| if ( |
| self.pred_obj_scores |
| and self.use_obj_ptrs_in_encoder |
| and self.use_no_obj_ptr |
| ): |
| if self.use_linear_no_obj_ptr: |
| self.no_obj_ptr_linear = nn.Linear(self.hidden_dim, self.hidden_dim) |
| else: |
| self.no_obj_ptr = torch.nn.Parameter( |
| torch.zeros(self.multiplex_count, self.hidden_dim) |
| ) |
| trunc_normal_(self.no_obj_ptr, std=0.02) |
|
|
| self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj |
| self.no_obj_embed_spatial = None |
| if no_obj_embed_spatial: |
| self.no_obj_embed_spatial = torch.nn.Parameter( |
| torch.zeros(self.multiplex_count, self.hidden_dim) |
| ) |
| trunc_normal_(self.no_obj_embed_spatial, std=0.02) |
| self.num_multimask_outputs = num_multimask_outputs |
| self.decode_mask_with_shared_tokens = decode_mask_with_shared_tokens |
| self.decode_mask_attribute_with_shared_tokens = ( |
| decode_mask_attribute_with_shared_tokens |
| ) |
| self.share_necks = share_necks |
|
|
| self.add_output_suppression_embeddings = add_output_suppression_embeddings |
| if self.add_output_suppression_embeddings: |
| self.output_valid_embed = torch.nn.Parameter( |
| torch.zeros(self.multiplex_count, self.hidden_dim) |
| ) |
| self.output_invalid_embed = torch.nn.Parameter( |
| torch.zeros(self.multiplex_count, self.hidden_dim) |
| ) |
| trunc_normal_(self.output_valid_embed, std=0.02) |
| trunc_normal_(self.output_invalid_embed, std=0.02) |
| self.add_object_conditional_embeddings = add_object_conditional_embeddings |
| if add_object_unconditional_embeddings is None: |
| add_object_unconditional_embeddings = add_object_conditional_embeddings |
| self.add_object_unconditional_embeddings = add_object_unconditional_embeddings |
| if add_object_unconditional_embeddings: |
| assert add_object_conditional_embeddings |
| if self.add_object_conditional_embeddings: |
| |
| |
| |
| |
| self.obj_cond_embed = torch.nn.Parameter( |
| torch.zeros(self.multiplex_count, self.hidden_dim) |
| ) |
| trunc_normal_(self.obj_cond_embed, std=0.02) |
| if self.add_object_unconditional_embeddings: |
| |
| self.obj_non_cond_embed = torch.nn.Parameter( |
| torch.zeros(self.multiplex_count, self.hidden_dim) |
| ) |
| trunc_normal_(self.obj_non_cond_embed, std=0.02) |
|
|
| self.condition_as_mask_input = condition_as_mask_input |
| self.condition_as_mask_input_fg = condition_as_mask_input_fg |
| self.condition_as_mask_input_bg = condition_as_mask_input_bg |
|
|
| self.is_dynamic_model = is_dynamic_model |
|
|
| self._build_sam_heads() |
|
|
| |
| self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train |
| self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train |
| self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval |
| self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval |
| if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0: |
| logging.info("Using points (sampled from masks) as inputs") |
| assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train |
| assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval |
| self.num_frames_to_correct_for_train = num_frames_to_correct_for_train |
| self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval |
| self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train |
| self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval |
| self.prob_correct_all_objects_for_train = prob_correct_all_objects_for_train |
| self.ratio_of_objects_to_correct_for_train = ( |
| ratio_of_objects_to_correct_for_train |
| ) |
| self.rand_objects_to_correct_for_train = rand_objects_to_correct_for_train |
| self.force_correct_all_for_conditional_inputs = ( |
| force_correct_all_for_conditional_inputs |
| ) |
| |
| self.num_init_cond_frames_for_train = num_init_cond_frames_for_train |
| self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval |
| self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train |
| self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval |
| self.max_cond_frames_in_attn = max_cond_frames_in_attn |
| self.keep_first_cond_frame = keep_first_cond_frame |
| self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond |
| self.num_correction_pt_per_frame = num_correction_pt_per_frame |
| self.pt_sampling_for_eval = pt_sampling_for_eval |
| self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train |
| |
| self.rng = np.random.default_rng(seed=42) |
| if randomness_fix: |
| self.rng2 = np.random.default_rng(seed=42) |
| else: |
| self.rng2 = self.rng |
|
|
| |
| self.use_memory_selection = use_memory_selection |
| self.mf_threshold = mf_threshold |
|
|
| |
| self.compile_all_components = compile_all_components |
| if self.compile_all_components: |
| self._compile_all_components() |
|
|
| def _get_tpos_enc(self, rel_pos_list, device, max_abs_pos=None, dummy=False): |
| if dummy: |
| return torch.zeros(len(rel_pos_list), self.mem_dim, device=device) |
|
|
| t_diff_max = max_abs_pos - 1 if max_abs_pos is not None else 1 |
| pos_enc = ( |
| torch.tensor(rel_pos_list, device="cpu").pin_memory().to(device=device, non_blocking=True) |
| / t_diff_max |
| ) |
| if self.sincos_tpos_enc: |
| tpos_dim = ( |
| self.hidden_dim if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim |
| ) |
| pos_enc = get_1d_sine_pe(pos_enc, dim=tpos_dim) |
| else: |
| raise NotImplementedError |
| pos_enc = self.obj_ptr_tpos_proj(pos_enc) |
|
|
| return pos_enc |
|
|
| def _build_sam_heads(self): |
| """Build SAM-style prompt encoder and mask decoder.""" |
| self.sam_prompt_embed_dim = self.hidden_dim |
| self.sam_image_embedding_size = self.image_size // self.backbone_stride |
|
|
| self.image_pe_layer = PositionEmbeddingRandom(self.hidden_dim // 2) |
|
|
| |
| |
| self.interactive_sam_prompt_encoder = PromptEncoder( |
| embed_dim=self.sam_prompt_embed_dim, |
| image_embedding_size=( |
| self.sam_image_embedding_size, |
| self.sam_image_embedding_size, |
| ), |
| input_image_size=(self.image_size, self.image_size), |
| mask_in_chans=16, |
| ) |
|
|
| self.interactive_sam_mask_decoder = MaskDecoder( |
| num_multimask_outputs=3, |
| transformer=TwoWayTransformer( |
| depth=2, |
| embedding_dim=self.sam_prompt_embed_dim, |
| mlp_dim=2048, |
| num_heads=8, |
| ), |
| transformer_dim=self.sam_prompt_embed_dim, |
| iou_head_depth=3, |
| iou_head_hidden_dim=256, |
| use_high_res_features=self.use_high_res_features_in_sam, |
| iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, |
| pred_obj_scores=self.pred_obj_scores, |
| pred_obj_scores_mlp=self.pred_obj_scores_mlp, |
| use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, |
| **(self.interactive_sam_mask_decoder_extra_args or {}), |
| ) |
| if self.share_necks: |
| |
| del self.interactive_sam_mask_decoder.conv_s0 |
| del self.interactive_sam_mask_decoder.conv_s1 |
|
|
| self.sam_mask_decoder = MultiplexMaskDecoder( |
| multiplex_count=self.multiplex_count, |
| num_multimask_outputs=self.num_multimask_outputs, |
| transformer=TwoWayTransformer( |
| depth=2, |
| embedding_dim=self.hidden_dim, |
| mlp_dim=2048, |
| num_heads=8, |
| ), |
| transformer_dim=self.hidden_dim, |
| iou_head_depth=3, |
| iou_head_hidden_dim=256, |
| use_high_res_features=self.use_high_res_features_in_sam, |
| iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, |
| pred_obj_scores=self.pred_obj_scores, |
| pred_obj_scores_mlp=self.pred_obj_scores_mlp, |
| use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, |
| decode_mask_with_shared_tokens=self.decode_mask_with_shared_tokens, |
| decode_mask_attribute_with_shared_tokens=self.decode_mask_attribute_with_shared_tokens, |
| multimask_outputs_only=self.num_multimask_outputs > 0 |
| and self.multimask_output_in_sam, |
| **(self.sam_mask_decoder_extra_args or {}), |
| ) |
|
|
| if self.use_obj_ptrs_in_encoder: |
| |
| self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) |
| self.interactive_obj_ptr_proj = torch.nn.Linear( |
| self.hidden_dim, self.hidden_dim |
| ) |
| if self.use_mlp_for_obj_ptr_proj: |
| self.obj_ptr_proj = MLP( |
| self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 |
| ) |
| self.interactive_obj_ptr_proj = MLP( |
| self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 |
| ) |
| else: |
| self.obj_ptr_proj = torch.nn.Identity() |
| self.interactive_obj_ptr_proj = torch.nn.Identity() |
| if self.proj_tpos_enc_in_obj_ptrs: |
| |
| |
| self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) |
| else: |
| self.obj_ptr_tpos_proj = torch.nn.Identity() |
|
|
| def _get_interactive_pix_mem( |
| self, features: torch.Tensor, feat_sizes: list[tuple] |
| ) -> torch.Tensor: |
| assert self.directly_add_no_mem_embed |
| pix_feat_with_mem = features[-1] + self.interactivity_no_mem_embed |
| B = features[-1].size(1) |
| C = self.hidden_dim |
| H, W = feat_sizes[-1] |
| pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) |
| return pix_feat_with_mem |
|
|
    def _forward_sam_heads(
        self,
        backbone_features: torch.Tensor,
        *,
        point_inputs: Optional[dict[str, torch.Tensor]] = None,
        mask_inputs: Optional[torch.Tensor] = None,
        interactive_high_res_features: Optional[list[torch.Tensor]] = None,
        propagation_high_res_features: Optional[list[torch.Tensor]] = None,
        multimask_output: bool = False,
        gt_masks=None,
        multiplex_state: MultiplexState,
        objects_to_interact: Optional[list[int]] = None,
    ) -> SAMOutput:
        """
        Forward SAM prompt encoders and mask heads.
        Exactly one head runs per call: the per-object interactive head when
        `point_inputs` or `mask_inputs` is given, otherwise the multiplexed
        propagation head.

        Inputs:
        - backbone_features: image features of [B, C, H, W] shape, where C must
          equal `hidden_dim` and H, W must equal `sam_image_embedding_size`.
        - point_inputs: a dictionary with "point_coords" and "point_labels", where
          1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
             absolute pixel-unit coordinate in (x, y) format of the P input points
          2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
             positive clicks, 0 means negative clicks, and -1 means padding
          If only `mask_inputs` is given, a single padding point (label -1) per
          object is used instead.
        - mask_inputs: a 4-D mask of [B, 1, h, w] shape, float or bool. It is
          resized (bilinear, antialiased) to the prompt encoder's
          `mask_input_size` when its spatial size differs.
        - interactive_high_res_features: required on the interactive path;
          high-resolution feature maps forwarded to the interactive decoder.
        - propagation_high_res_features: required on the propagation path;
          high-resolution feature maps forwarded to the multiplex decoder.
        - multimask_output: if True, the decoders return several candidate masks
          with IoU estimates and the best one is selected; otherwise one mask.
        - gt_masks: accepted by the signature but not used in this method.
        - multiplex_state: keyword-only, no default. On the propagation path it
          demultiplexes the decoder outputs, supplies the valid-object mask for
          the output-suppression embeddings, and (on both paths) maps the
          learned no-object pointers to objects.
        - objects_to_interact: required on the interactive path; indices of the
          objects the prompts refer to (used to pick their no-object pointers).

        Outputs:
        - low_res_multimasks: [B, M, H*4, W*4] float32 mask logits (before
          sigmoid) for all candidates. With `pred_obj_scores`, objects whose score
          logit is not above `object_score_logit_threshold` are filled with
          NO_OBJ_SCORE.
        - high_res_multimasks: [B, M, image_size, image_size], bilinearly
          upsampled from the low-resolution masks.
        - ious: [B, M] estimated IoU per candidate. With
          `stability_score_attentuation` (multimask selection only), these are
          the IoUs multiplied by the stability scores.
        - low_res_masks: [B, 1, H*4, W*4], the selected mask. It is the highest-IoU
          candidate when selecting among multimasks. When multimask output comes
          from shared tokens on the propagation path, it is the first candidate;
          otherwise it is `low_res_multimasks` itself.
        - high_res_masks: the full-resolution counterpart of `low_res_masks`.
        - object_score_logits: object presence logits.
        - obj_ptr (only with `use_obj_ptrs_in_encoder`): object pointer projected
          from the selected SAM output token, optionally blended with the learned
          "no object" pointer. NOTE(review): an earlier docstring described this
          as [num_buckets, multiplex_count, C]. Since the propagation outputs are
          demuxed and the interactive no-object pointers are indexed by object,
          it appears to be one row per object; confirm.
        """

        device = backbone_features.device
        assert backbone_features.size(1) == self.hidden_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # Any prompt (points or mask) selects the interactive head.
        is_interactive = point_inputs is not None or mask_inputs is not None

        if is_interactive:
            """
            Image-level, per-object interactive path
            """
            assert interactive_high_res_features is not None
            assert objects_to_interact is not None

            # Point prompts; a mask-only prompt gets one padding point
            # (label -1) per object.
            if point_inputs is not None:
                sam_point_coords = point_inputs["point_coords"]
                sam_point_labels = point_inputs["point_labels"]
            else:
                assert mask_inputs is not None
                sam_point_coords = torch.zeros(
                    mask_inputs.shape[0], 1, 2, device=device
                )
                sam_point_labels = -torch.ones(
                    mask_inputs.shape[0], 1, dtype=torch.int32, device=device
                )

            # Mask prompt, resized to the prompt encoder's expected input size.
            if mask_inputs is not None:
                assert len(mask_inputs.shape) == 4
                if (
                    mask_inputs.shape[-2:]
                    != self.interactive_sam_prompt_encoder.mask_input_size
                ):
                    sam_mask_prompt = F.interpolate(
                        mask_inputs.float(),
                        size=self.interactive_sam_prompt_encoder.mask_input_size,
                        align_corners=False,
                        mode="bilinear",
                        antialias=True,
                    )
                else:
                    sam_mask_prompt = mask_inputs
            else:
                sam_mask_prompt = None

            sparse_embeddings, dense_embeddings = self.interactive_sam_prompt_encoder(
                points=(sam_point_coords, sam_point_labels),
                boxes=None,
                masks=sam_mask_prompt,
            )

            # NOTE(review): `_maybe_clone` is defined outside this chunk;
            # presumably it clones tensors when needed (e.g. for compiled
            # graphs) — confirm.
            sparse_embeddings = self._maybe_clone(sparse_embeddings)
            dense_embeddings = self._maybe_clone(dense_embeddings)
            image_pe = self._maybe_clone(
                self.interactive_sam_prompt_encoder.get_dense_pe()
            )
            (
                low_res_multimasks,
                ious,
                sam_output_tokens,
                object_score_logits,
            ) = self.interactive_sam_mask_decoder(
                image_embeddings=backbone_features,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
                repeat_image=True,
                high_res_features=interactive_high_res_features,
            )

        else:
            """
            Multiplexed propagation path
            """
            assert propagation_high_res_features is not None
            assert multiplex_state is not None

            # Per-slot embedding chosen by whether each slot holds a valid object.
            if self.add_output_suppression_embeddings:
                output_valid_embed = self.output_valid_embed.unsqueeze(0)
                output_invalid_embed = self.output_invalid_embed.unsqueeze(0)
                valid_object_mask = (
                    multiplex_state.get_valid_object_mask().unsqueeze(-1).float()
                )
                output_merged_embed = (
                    valid_object_mask * output_valid_embed
                    + (1 - valid_object_mask) * output_invalid_embed
                )
            else:
                output_merged_embed = None

            image_pe = self._maybe_clone(self.get_propagation_dense_pe())
            out = self.sam_mask_decoder(
                image_embeddings=backbone_features,
                image_pe=image_pe,
                high_res_features=propagation_high_res_features,
                multimask_output=multimask_output,
                extra_per_object_embeddings=output_merged_embed,
            )
            low_res_multimasks = out["masks"]
            ious = out["iou_pred"]
            sam_output_tokens = out["sam_tokens_out"]
            object_score_logits = out["object_score_logits"]

            # Convert the multiplexed decoder outputs back to per-object form.
            low_res_multimasks = multiplex_state.demux(low_res_multimasks)
            ious = multiplex_state.demux(ious)
            object_score_logits = multiplex_state.demux(object_score_logits)
            sam_output_tokens = multiplex_state.demux(sam_output_tokens)

        """
        The interactive and the propagation paths converge here
        """
        low_res_multimasks = self._maybe_clone(low_res_multimasks)
        ious = self._maybe_clone(ious)
        object_score_logits = self._maybe_clone(object_score_logits)
        sam_output_tokens = self._maybe_clone(sam_output_tokens)

        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > self.object_score_logit_threshold

            # Overwrite the masks of absent objects with a large negative logit.
            low_res_multimasks = torch.where(
                is_obj_appearing[:, None, None],
                low_res_multimasks,
                NO_OBJ_SCORE,
            )

        # Upsample in float32 to the full image size.
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        # Default object token: the first output token.
        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output and (
            not self.decode_mask_with_shared_tokens or is_interactive
        ):
            # Pick the best candidate by (optionally stability-weighted) IoU.
            if self.stability_score_attentuation:
                stability_score = self.sam_mask_decoder._get_stability_scores(
                    low_res_multimasks
                )
                ious = ious * stability_score

            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(ious.shape[0], device=device)

            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            # With one token per candidate, use the token of the chosen mask.
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            if multimask_output and not is_interactive:
                # Shared-token multimask decoding on the propagation path: take
                # the first candidate.
                assert self.decode_mask_with_shared_tokens
                low_res_masks = low_res_multimasks[:, 0:1]
                high_res_masks = high_res_multimasks[:, 0:1]
            else:
                low_res_masks = low_res_multimasks
                high_res_masks = high_res_multimasks

        # Object pointer from the selected token, using the head-specific projection.
        if self.use_obj_ptrs_in_encoder:
            if is_interactive:
                obj_ptr = self.interactive_obj_ptr_proj(sam_output_token)
            else:
                obj_ptr = self.obj_ptr_proj(sam_output_token)

            # Blend in the learned "no object" pointer for absent objects.
            if self.pred_obj_scores and self.use_no_obj_ptr:
                lambda_is_obj_appearing = is_obj_appearing.float()
                if self.use_linear_no_obj_ptr:
                    obj_ptr = lambda_is_obj_appearing * obj_ptr + (
                        1 - lambda_is_obj_appearing
                    ) * self.no_obj_ptr_linear(obj_ptr)
                else:
                    if self.fixed_no_obj_ptr:
                        obj_ptr = lambda_is_obj_appearing * obj_ptr

                    # Broadcast the per-slot pointers over all buckets, then
                    # demux them to per-object order like the decoder outputs.
                    selected_no_obj_ptr = self.no_obj_ptr.unsqueeze(0).repeat(
                        multiplex_state.num_buckets, 1, 1
                    )
                    selected_no_obj_ptr = multiplex_state.demux(selected_no_obj_ptr)
                    if is_interactive:
                        # Interactive outputs only cover the prompted objects.
                        selected_no_obj_ptr = selected_no_obj_ptr[objects_to_interact]

                    obj_ptr = (
                        obj_ptr + (1 - lambda_is_obj_appearing) * selected_no_obj_ptr
                    )

        outputs: SAMOutput = {
            "low_res_multimasks": low_res_multimasks,
            "high_res_multimasks": high_res_multimasks,
            "ious": ious,
            "low_res_masks": low_res_masks,
            "high_res_masks": high_res_masks,
            "object_score_logits": object_score_logits,
        }
        if self.use_obj_ptrs_in_encoder:
            outputs["obj_ptr"] = obj_ptr
        return outputs
|
|
| def _use_mask_as_output( |
| self, |
| backbone_features: torch.Tensor, |
| high_res_features: list[torch.Tensor], |
| mask_inputs: torch.Tensor, |
| multiplex_state: MultiplexState, |
| objects_in_mask: Optional[list[int]] = None, |
| ) -> SAMOutput: |
| """ |
| Directly turn binary `mask_inputs` into a output mask logits without using SAM. |
| (same input and output shapes as in _forward_sam_heads above). |
| """ |
| if objects_in_mask is None: |
| objects_in_mask = list(range(multiplex_state.total_valid_entries)) |
|
|
| |
| out_scale, out_bias = 20.0, -10.0 |
| mask_inputs_float = mask_inputs.to(backbone_features.dtype) |
| assert mask_inputs.shape[0] == len(objects_in_mask), ( |
| f"{mask_inputs.shape[0]} != {len(objects_in_mask)}" |
| ) |
| high_res_masks = mask_inputs_float * out_scale + out_bias |
| low_res_masks = F.interpolate( |
| high_res_masks, |
| size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), |
| align_corners=False, |
| mode="bilinear", |
| antialias=True, |
| ) |
| |
| ious = mask_inputs.new_ones( |
| mask_inputs.size(0), 1, dtype=backbone_features.dtype |
| ) |
|
|
| if self.use_obj_ptrs_in_encoder: |
| |
| sam_outputs = self._forward_sam_heads( |
| backbone_features=backbone_features, |
| mask_inputs=self.interactive_mask_downsample(mask_inputs_float), |
| interactive_high_res_features=high_res_features, |
| gt_masks=mask_inputs, |
| objects_to_interact=objects_in_mask, |
| multiplex_state=multiplex_state, |
| ) |
| obj_ptr = sam_outputs["obj_ptr"] |
|
|
| |
| |
| |
| is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) |
| is_obj_appearing = is_obj_appearing[..., None] |
| lambda_is_obj_appearing = is_obj_appearing.float() |
| object_score_logits = out_scale * lambda_is_obj_appearing + out_bias |
| |
| |
| |
| |
| if self.pred_obj_scores and self.use_no_obj_ptr: |
| if self.use_linear_no_obj_ptr: |
| obj_ptr = lambda_is_obj_appearing * obj_ptr + ( |
| 1 - lambda_is_obj_appearing |
| ) * self.no_obj_ptr_linear(obj_ptr) |
| else: |
| if self.fixed_no_obj_ptr: |
| obj_ptr = lambda_is_obj_appearing * obj_ptr |
| |
| selected_no_obj_ptr = self.no_obj_ptr.unsqueeze(0).repeat( |
| multiplex_state.num_buckets, 1, 1 |
| ) |
| selected_no_obj_ptr = multiplex_state.demux(selected_no_obj_ptr) |
| selected_no_obj_ptr = selected_no_obj_ptr[objects_in_mask] |
| obj_ptr = ( |
| obj_ptr + (1 - lambda_is_obj_appearing) * selected_no_obj_ptr |
| ) |
|
|
| outputs: SAMOutput = { |
| "low_res_multimasks": low_res_masks, |
| "high_res_multimasks": high_res_masks, |
| "ious": ious, |
| "low_res_masks": low_res_masks, |
| "high_res_masks": high_res_masks, |
| "object_score_logits": object_score_logits, |
| } |
| if self.use_obj_ptrs_in_encoder: |
| outputs["obj_ptr"] = obj_ptr |
| return outputs |
|
|
def forward(self, input: BatchedDatapoint, is_inference=False):
    """
    Prepare prompts for a batched video datapoint and track through it.

    Backbone features for the whole batch are computed up front during
    training, or in eval when `forward_backbone_per_frame_for_eval` is off.
    Otherwise an empty dict is handed on and features are computed per frame
    by `forward_tracking`. `is_inference` is accepted but not read here.

    Returns:
        `(tracking_outputs, None)`, where `tracking_outputs` is whatever
        `forward_tracking` returns.
    """
    precompute_features = (
        self.training or not self.forward_backbone_per_frame_for_eval
    )
    features = (
        self.forward_image(
            input.img_batch, need_interactive_out=True, need_propagation_out=True
        )
        if precompute_features
        else {}
    )
    prompted = self.prepare_prompt_inputs(features, input)
    tracking_outputs = self.forward_tracking(prompted, input)
    return tracking_outputs, None
|
|
def forward_image(
    self,
    img_batch,
    *,
    need_sam3_out: bool = False,
    need_interactive_out: bool = False,
    need_propagation_out: bool = False,
):
    """
    Compute the backbone features for `img_batch`.

    With `share_necks`, one propagation neck serves both roles. It is
    requested whenever either output is needed, and the "interactive" entry
    aliases the "sam2_backbone_out" entry. When `use_high_res_features_in_sam`
    is set, the first two FPN levels of each requested neck are projected
    with the matching mask decoder's `conv_s0` / `conv_s1`. Finally, every FPN
    tensor and positional encoding is passed through `_maybe_clone`.
    """
    if not self.share_necks:
        backbone_out = self.backbone.forward_image(
            img_batch,
            need_sam3_out=need_sam3_out,
            need_interactive_out=need_interactive_out,
            need_propagation_out=need_propagation_out,
        )
    else:
        need_propagation_out = need_interactive_out or need_propagation_out
        need_interactive_out = False
        backbone_out = self.backbone.forward_image(
            img_batch,
            need_sam3_out=need_sam3_out,
            need_sam2_out=need_propagation_out,
        )
        backbone_out["interactive"] = backbone_out["sam2_backbone_out"]

    if self.use_high_res_features_in_sam:
        # (neck key, decoder providing the projection convs), in apply order.
        projections = []
        if need_interactive_out:
            projections.append(("interactive", self.interactive_sam_mask_decoder))
        if need_propagation_out:
            projections.append(("sam2_backbone_out", self.sam_mask_decoder))
        for neck_key, decoder in projections:
            fpn = backbone_out[neck_key]["backbone_fpn"]
            fpn[0].tensors = decoder.conv_s0(fpn[0].tensors)
            fpn[1].tensors = decoder.conv_s1(fpn[1].tensors)

    for neck in backbone_out.values():
        pos_encs = neck["vision_pos_enc"]
        for level, feature in enumerate(neck["backbone_fpn"]):
            feature.tensors = self._maybe_clone(feature.tensors)
            pos_encs[level] = self._maybe_clone(pos_encs[level])
    return backbone_out
|
|
| def _prepare_prompt_inputs_meta(self, backbone_out, input, start_frame_idx=0): |
| |
| |
| gt_masks_per_frame = { |
| stage_id: targets.segments.unsqueeze(1) |
| for stage_id, targets in enumerate(input.find_targets) |
| } |
| backbone_out["gt_masks_per_frame"] = gt_masks_per_frame |
| num_frames = len(input.find_targets) |
| backbone_out["num_frames"] = num_frames |
|
|
| |
| if self.training: |
| prob_to_use_pt_input = self.prob_to_use_pt_input_for_train |
| num_frames_to_correct = self.num_frames_to_correct_for_train |
| rand_frames_to_correct = self.rand_frames_to_correct_for_train |
| num_init_cond_frames = self.num_init_cond_frames_for_train |
| rand_init_cond_frames = self.rand_init_cond_frames_for_train |
| else: |
| prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval |
| num_frames_to_correct = self.num_frames_to_correct_for_eval |
| rand_frames_to_correct = self.rand_frames_to_correct_for_eval |
| num_init_cond_frames = self.num_init_cond_frames_for_eval |
| rand_init_cond_frames = self.rand_init_cond_frames_for_eval |
| if num_frames == 1: |
| |
| |
| prob_to_use_pt_input = 1.0 |
| num_frames_to_correct = 1 |
| num_init_cond_frames = 1 |
| assert num_init_cond_frames >= 1 |
| |
| use_pt_input = self.rng.random() < prob_to_use_pt_input |
| if rand_init_cond_frames and num_init_cond_frames > 1: |
| |
| num_init_cond_frames = self.rng.integers( |
| 1, num_init_cond_frames, endpoint=True |
| ) |
| if ( |
| use_pt_input |
| and rand_frames_to_correct |
| and num_frames_to_correct > num_init_cond_frames |
| ): |
| |
| |
| num_frames_to_correct = self.rng.integers( |
| num_init_cond_frames, num_frames_to_correct, endpoint=True |
| ) |
| backbone_out["use_pt_input"] = use_pt_input |
|
|
| |
| if num_init_cond_frames == 1: |
| init_cond_frames = [start_frame_idx] |
| else: |
| |
| init_cond_frames = [start_frame_idx] + self.rng.choice( |
| range(start_frame_idx + 1, num_frames), |
| num_init_cond_frames - 1, |
| replace=False, |
| ).tolist() |
| backbone_out["init_cond_frames"] = init_cond_frames |
| backbone_out["frames_not_in_init_cond"] = [ |
| t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames |
| ] |
|
|
| |
| |
| if not use_pt_input: |
| |
| frames_to_add_correction_pt = [] |
| elif num_frames_to_correct == num_init_cond_frames: |
| frames_to_add_correction_pt = init_cond_frames |
| else: |
| assert num_frames_to_correct > num_init_cond_frames |
| |
| extra_num = num_frames_to_correct - num_init_cond_frames |
| frames_to_add_correction_pt = ( |
| init_cond_frames |
| + self.rng.choice( |
| backbone_out["frames_not_in_init_cond"], extra_num, replace=False |
| ).tolist() |
| ) |
| backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt |
|
|
| return backbone_out |
|
|
| def _prepare_conditional_frames(self, backbone_out): |
| init_cond_frames = backbone_out["init_cond_frames"] |
| gt_masks_per_frame = backbone_out["gt_masks_per_frame"] |
| use_pt_input = backbone_out["use_pt_input"] |
|
|
| if self.training: |
| prob_to_use_box_input = self.prob_to_use_box_input_for_train |
| else: |
| prob_to_use_box_input = self.prob_to_use_box_input_for_eval |
|
|
| |
| backbone_out["mask_inputs_per_frame"] = {} |
| backbone_out["point_inputs_per_frame"] = {} |
| for t in init_cond_frames: |
| if not use_pt_input: |
| backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t] |
| else: |
| |
| use_box_input = self.rng.random() < prob_to_use_box_input |
| if use_box_input: |
| points, labels = sample_box_points( |
| gt_masks_per_frame[t], |
| ) |
| else: |
| |
| |
| points, labels = get_next_point( |
| gt_masks=gt_masks_per_frame[t], |
| pred_masks=None, |
| method=( |
| "uniform" if self.training else self.pt_sampling_for_eval |
| ), |
| ) |
|
|
| point_inputs = {"point_coords": points, "point_labels": labels} |
| backbone_out["point_inputs_per_frame"][t] = point_inputs |
|
|
| return backbone_out |
|
|
def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
    """
    Fill `backbone_out` with the sampled mask, point or box prompts.

    Prompt metadata (conditioning / correction frames, point vs. mask) is
    sampled first. The concrete prompts for the conditioning frames are built
    from it afterwards. `start_frame_idx` lets tracking start from a custom
    frame and run to the end of the video (for evaluation).
    """
    with_meta = self._prepare_prompt_inputs_meta(backbone_out, input, start_frame_idx)
    return self._prepare_conditional_frames(with_meta)
|
|
def _prepare_backbone_features(self, backbone_out):
    """
    Flatten the last `num_feature_levels` FPN levels of every known neck.

    For each name in the module-level `neck_outs` that is present in
    `backbone_out`, the result holds:
      - "vision_feats": `tensors.flatten(2).permute(2, 0, 1)` per level
        (assuming (B, C, H, W) inputs this is (HW, B, C) -- TODO confirm)
      - "vision_pos_embeds": positional encodings flattened the same way
      - "vision_masks": padding masks flattened from dim 1, or None
      - "feat_sizes": the (H, W) of each level's positional encoding
    """
    prepared = {}
    for neck_k in (name for name in neck_outs if name in backbone_out):
        neck_out = backbone_out[neck_k]
        levels = self.num_feature_levels
        fpn_maps = neck_out["backbone_fpn"]
        pos_maps = neck_out["vision_pos_enc"]
        assert len(fpn_maps) == len(pos_maps)
        assert len(fpn_maps) >= levels
        fpn_maps, pos_maps = fpn_maps[-levels:], pos_maps[-levels:]

        prepared[neck_k] = {
            "vision_feats": [
                fmap.tensors.flatten(2).permute(2, 0, 1) for fmap in fpn_maps
            ],
            "vision_pos_embeds": [
                pos.flatten(2).permute(2, 0, 1) for pos in pos_maps
            ],
            "vision_masks": [
                None if fmap.mask is None else fmap.mask.flatten(1)
                for fmap in fpn_maps
            ],
            "feat_sizes": [(pos.shape[-2], pos.shape[-1]) for pos in pos_maps],
        }
    return prepared
|
|
def _prepare_backbone_features_per_frame(
    self,
    img_batch,
    img_ids,
    *,
    need_interactive_out: bool = False,
    need_propagation_out: bool = False,
):
    """
    Compute and flatten the backbone features of a single frame on demand.

    `img_ids` must contain exactly one index into `img_batch`. Returns the
    selected image tensor and the `_prepare_backbone_features` output for it.
    """
    assert img_ids.numel() == 1
    frame_image = img_batch.tensors[img_ids]
    frame_mask = None if img_batch.mask is None else img_batch.mask[img_ids]

    raw_features = self.forward_image(
        NestedTensor(tensors=frame_image, mask=frame_mask),
        need_interactive_out=need_interactive_out,
        need_propagation_out=need_propagation_out,
    )
    return frame_image, self._prepare_backbone_features(raw_features)
|
|
| def _prepare_memory_conditioned_features( |
| self, |
| *, |
| frame_idx, |
| is_init_cond_frame, |
| current_vision_feats, |
| current_vision_masks, |
| current_vision_pos_embeds, |
| feat_sizes, |
| output_dict, |
| num_frames, |
| track_in_reverse=False, |
| use_prev_mem_frame=True, |
| multiplex_state: MultiplexState, |
| ): |
| """Fuse the current frame's visual feature map with previous memory.""" |
| B = multiplex_state.num_buckets |
| |
| vision_feat = current_vision_feats[-1].expand(-1, B, -1) |
| vision_mask = ( |
| current_vision_masks[-1].expand(-1, B, -1) |
| if current_vision_masks[-1] is not None |
| else None |
| ) |
| vision_pos_embed = current_vision_pos_embeds[-1].expand(-1, B, -1) |
|
|
| C = self.hidden_dim |
| H, W = feat_sizes[-1] |
| device = current_vision_feats[-1].device |
| |
| |
| if self.num_maskmem == 0: |
| pix_feat = vision_feat.permute(1, 2, 0).view(B, C, H, W) |
| return pix_feat |
|
|
| num_obj_ptr_tokens = 0 |
| tpos_sign_mul = -1 if track_in_reverse else 1 |
| |
| if not is_init_cond_frame and use_prev_mem_frame: |
| |
| |
| to_cat_prompt, to_cat_prompt_pos_embed = [], [] |
| if self.save_image_features: |
| to_cat_image_feat, to_cat_image_pos_embed = [], [] |
| |
| |
| assert len(output_dict["cond_frame_outputs"]) > 0 |
| |
| cond_outputs = output_dict["cond_frame_outputs"] |
| selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( |
| frame_idx, |
| cond_outputs, |
| self.max_cond_frames_in_attn, |
| keep_first_cond_frame=self.keep_first_cond_frame, |
| ) |
|
|
| t_pos_and_prevs = [ |
| ((frame_idx - t) * tpos_sign_mul, out, True) |
| for t, out in selected_cond_outputs.items() |
| ] |
| |
| |
| |
| |
| r = 1 if self.training else self.memory_temporal_stride_for_eval |
|
|
| if self.use_memory_selection: |
| valid_indices = self.frame_filter( |
| output_dict, track_in_reverse, frame_idx, num_frames, r |
| ) |
|
|
| for t_pos in range(1, self.num_maskmem): |
| t_rel = self.num_maskmem - t_pos |
| if self.use_memory_selection: |
| if t_rel > len(valid_indices): |
| continue |
| prev_frame_idx = valid_indices[-t_rel] |
| else: |
| if t_rel == 1: |
| |
| if not track_in_reverse: |
| |
| prev_frame_idx = frame_idx - t_rel |
| else: |
| |
| prev_frame_idx = frame_idx + t_rel |
| else: |
| |
| if not track_in_reverse: |
| |
| |
| prev_frame_idx = ((frame_idx - 2) // r) * r |
| |
| prev_frame_idx = prev_frame_idx - (t_rel - 2) * r |
| else: |
| |
| |
| prev_frame_idx = -(-(frame_idx + 2) // r) * r |
| |
| prev_frame_idx = prev_frame_idx + (t_rel - 2) * r |
| out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) |
| if out is None: |
| |
| |
| out = unselected_cond_outputs.get(prev_frame_idx, None) |
| t_pos_and_prevs.append((t_pos, out, False)) |
|
|
| for t_pos, prev, is_selected_cond_frame in t_pos_and_prevs: |
| if prev is None: |
| continue |
|
|
| feats = prev.get("maskmem_features") |
| if feats is None: |
| continue |
| |
| |
| feats = feats.cuda(non_blocking=True) |
| if feats.dim() == 5: |
| feats = multiplex_state.demux(feats).contiguous() |
| prev["maskmem_features"] = ( |
| feats.cpu() if not feats.is_cuda else feats |
| ) |
|
|
| if feats.shape[0] == 0: |
| continue |
|
|
| to_cat_prompt.append(feats.flatten(2).permute(2, 0, 1)) |
| |
| |
| maskmem_pos_list = prev.get("maskmem_pos_enc") |
| if not maskmem_pos_list: |
| continue |
| maskmem_enc = maskmem_pos_list[-1] |
| if maskmem_enc is None: |
| continue |
| maskmem_enc = maskmem_enc.cuda(non_blocking=True) |
| if maskmem_enc.dim() == 5: |
| maskmem_enc = multiplex_state.demux(maskmem_enc).contiguous() |
| prev["maskmem_pos_enc"][-1] = ( |
| maskmem_enc.cpu() if not maskmem_enc.is_cuda else maskmem_enc |
| ) |
| maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) |
|
|
| if self.use_maskmem_tpos_v2: |
| |
| if t_pos <= 0 or t_pos >= self.num_maskmem: |
| tpos_enc = self.maskmem_tpos_enc[self.num_maskmem - 1] |
| else: |
| tpos_enc = self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] |
| else: |
| |
| |
| |
| t = t_pos if not is_selected_cond_frame else 0 |
| tpos_enc = self.maskmem_tpos_enc[self.num_maskmem - t - 1] |
|
|
| maskmem_enc = maskmem_enc + tpos_enc |
|
|
| if self.save_image_features: |
| |
| image_feat = prev["image_features"].cuda() |
| image_pos_embed = prev["image_pos_enc"].cuda() + tpos_enc |
| to_cat_image_feat.append(image_feat) |
| to_cat_image_pos_embed.append(image_pos_embed) |
|
|
| to_cat_prompt_pos_embed.append(maskmem_enc) |
|
|
| |
| if self.use_obj_ptrs_in_encoder: |
| max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) |
| |
| |
| if not self.training and self.only_obj_ptrs_in_the_past_for_eval: |
| ptr_cond_outputs = { |
| t: out |
| for t, out in selected_cond_outputs.items() |
| if (t >= frame_idx if track_in_reverse else t <= frame_idx) |
| } |
| else: |
| ptr_cond_outputs = selected_cond_outputs |
| pos_and_outs_for_ptr = [ |
| |
| ( |
| ( |
| (frame_idx - t) * tpos_sign_mul |
| if self.use_signed_tpos_enc_to_obj_ptrs |
| else abs(frame_idx - t) |
| ), |
| out, |
| True, |
| ) |
| for t, out in ptr_cond_outputs.items() |
| ] |
|
|
| |
| for t_diff in range(1, max_obj_ptrs_in_encoder): |
| if not self.use_memory_selection: |
| t = ( |
| frame_idx + t_diff |
| if track_in_reverse |
| else frame_idx - t_diff |
| ) |
| if t < 0 or (num_frames is not None and t >= num_frames): |
| break |
| else: |
| if -t_diff <= -len(valid_indices): |
| break |
| t = valid_indices[-t_diff] |
|
|
| out = output_dict["non_cond_frame_outputs"].get( |
| t, unselected_cond_outputs.get(t, None) |
| ) |
| if out is not None: |
| pos_and_outs_for_ptr.append((t_diff, out, False)) |
|
|
| |
| if len(pos_and_outs_for_ptr) > 0: |
| pos_list, out_list, is_selected_cond_frame_list = zip( |
| *pos_and_outs_for_ptr |
| ) |
| |
| filtered_data = [ |
| (pos, out, is_cond) |
| for pos, out, is_cond in zip( |
| pos_list, out_list, is_selected_cond_frame_list |
| ) |
| if "obj_ptr" in out |
| ] |
|
|
| |
| if len(filtered_data) > 0: |
| pos_list, out_list, is_selected_cond_frame_list = zip( |
| *filtered_data |
| ) |
| |
| |
| obj_ptrs = torch.cat( |
| [out["obj_ptr"] for out in out_list], dim=1 |
| ).transpose(0, 1) |
|
|
| |
| |
| if self.add_tpos_enc_to_obj_ptrs: |
| obj_pos = self._get_tpos_enc( |
| pos_list, |
| max_abs_pos=max_obj_ptrs_in_encoder, |
| device=device, |
| ) |
| else: |
| obj_pos = self._get_tpos_enc( |
| pos_list, device=device, dummy=True |
| ) |
| |
| obj_pos = obj_pos.unsqueeze(1).expand(-1, B, -1) |
|
|
| assert self.mem_dim == C, ( |
| f"obj_ptrs.shape = {obj_ptrs.shape}, C = {C}" |
| ) |
|
|
| |
| obj_pos = obj_pos.repeat_interleave( |
| multiplex_state.multiplex_count, dim=0 |
| ) |
|
|
| to_cat_prompt.append(obj_ptrs) |
| to_cat_prompt_pos_embed.append(obj_pos) |
| |
| num_obj_ptr_tokens = obj_ptrs.shape[0] |
| else: |
| |
| num_obj_ptr_tokens = 0 |
| else: |
| num_obj_ptr_tokens = 0 |
| else: |
| |
| raise NotImplementedError( |
| "Any init cond frame should have gone to _use_mask_as_output instead" |
| ) |
|
|
| |
| if len(to_cat_prompt) == 0: |
| |
| |
| |
| pix_feat = vision_feat.permute(1, 2, 0).view(B, C, H, W) |
| return pix_feat |
|
|
| prompt = torch.cat(to_cat_prompt, dim=0) |
| prompt_mask = None |
| prompt_pos_embed = torch.cat(to_cat_prompt_pos_embed, dim=0) |
|
|
| if self.save_image_features: |
| assert prompt_mask is None |
| assert vision_mask is None |
| if len(to_cat_image_feat) == 0 or len(to_cat_image_pos_embed) == 0: |
| |
| pix_feat = vision_feat.permute(1, 2, 0).view(B, C, H, W) |
| return pix_feat |
| image_feat = torch.cat(to_cat_image_feat, dim=0) |
| image_pos_embed = torch.cat(to_cat_image_pos_embed, dim=0) |
|
|
| encoder_out = self.transformer.encoder( |
| image=current_vision_feats[-1], |
| src=vision_feat, |
| memory_image=image_feat, |
| memory=prompt, |
| image_pos=current_vision_pos_embeds[-1], |
| src_pos=vision_pos_embed, |
| memory_image_pos=image_pos_embed, |
| memory_pos=prompt_pos_embed, |
| num_obj_ptr_tokens=num_obj_ptr_tokens, |
| ) |
| else: |
| encoder_out = self.transformer.encoder( |
| src=vision_feat, |
| src_key_padding_mask=vision_mask, |
| src_pos=vision_pos_embed, |
| prompt=prompt, |
| prompt_pos=prompt_pos_embed, |
| prompt_key_padding_mask=prompt_mask, |
| feat_sizes=feat_sizes, |
| num_obj_ptr_tokens=num_obj_ptr_tokens, |
| ) |
| |
| pix_feat_with_mem = encoder_out["memory"].permute(1, 2, 0).view(B, C, H, W) |
| return pix_feat_with_mem |
|
|
def _encode_new_memory(
    self,
    image,
    current_vision_feats,
    feat_sizes,
    pred_masks_high_res,
    object_score_logits,
    is_mask_from_pts,
    *,
    conditioning_objects: Optional[Iterable[int]] = None,
    multiplex_state: MultiplexState,
):
    """
    Encode the current frame and its predicted masks into a memory entry.

    Args:
        image: current image; only passed to `maskmem_backbone` when it is
            not a `SimpleMaskEncoder`.
        current_vision_feats: flattened features; the last level is viewed
            back to (B, C, H, W) with (H, W) = feat_sizes[-1].
        feat_sizes: (H, W) of each feature level.
        pred_masks_high_res: per-object mask logits (4-D, one row per object).
        object_score_logits: per-object presence logits, or None. When given
            (and `no_obj_embed_spatial` exists) they are zero-padded or
            truncated to `multiplex_state.total_valid_entries` and used to
            add the no-object embedding.
        is_mask_from_pts: whether the masks came from point prompts (only
            relevant to the binarization option).
        conditioning_objects: objects conditioned on this frame; the other
            valid objects are treated as unconditioned. None means none.
        multiplex_state: used to mux per-object tensors into buckets and to
            demux the encoder output.

    Returns:
        `(maskmem_features, maskmem_pos_enc)`. Each is demuxed when the mask
        encoder returns a 5-D tensor.
    """
    B = current_vision_feats[-1].size(1)
    C = self.hidden_dim
    H, W = feat_sizes[-1]
    pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
    if self.non_overlap_masks_for_mem_enc and not self.training:
        pred_masks_high_res = self._apply_non_overlapping_constraints(
            pred_masks_high_res
        )

    if self.apply_sigmoid_to_mask_logits_for_mem_enc:
        assert not self.binarize_mask_from_pts_for_mem_enc, (
            "haven't been trained this way; beware of hardcoded config override"
        )
        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
        if binarize and not self.training:
            mask_for_mem = (pred_masks_high_res > 0).float()
        else:
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
        # Optional affine rescaling of the sigmoid probabilities.
        if self.sigmoid_scale_for_mem_enc != 1.0:
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
    else:
        mask_for_mem = pred_masks_high_res

    if self.add_object_conditional_embeddings or self.condition_as_mask_input:
        if conditioning_objects is None:
            conditioning_objects = []
            unconditioning_objects = sorted(
                list(multiplex_state.get_all_valid_object_idx())
            )
        else:
            conditioning_objects = sorted(list(conditioning_objects))
            all_objects_idx = multiplex_state.get_all_valid_object_idx()
            unconditioning_objects = sorted(
                [i for i in all_objects_idx if i not in conditioning_objects]
            )

    # Multiplex the per-object masks into buckets for the mask encoder.
    mux_mask_for_mem = multiplex_state.mux(mask_for_mem).squeeze(2)

    if self.condition_as_mask_input:
        # Extra input channel(s): `condition_as_mask_input_fg` for conditioning
        # objects, `condition_as_mask_input_bg` for all others.
        num_objects = mask_for_mem.shape[0]
        cond_values = torch.full(
            (num_objects,),
            self.condition_as_mask_input_bg,
            device=mask_for_mem.device,
            dtype=mask_for_mem.dtype,
        )
        if len(conditioning_objects) > 0:
            cond_values[conditioning_objects] = self.condition_as_mask_input_fg
        embedded_conditions = cond_values.view(-1, 1, 1, 1).expand_as(mask_for_mem)
        embedded_conditions = multiplex_state.mux(embedded_conditions).squeeze(2)
        mux_mask_for_mem = torch.cat([mux_mask_for_mem, embedded_conditions], dim=1)

    if isinstance(self.maskmem_backbone, SimpleMaskEncoder):
        maskmem_out = self.maskmem_backbone(
            pix_feat,
            mux_mask_for_mem,
            skip_mask_sigmoid=True,
        )
    else:
        maskmem_out = self.maskmem_backbone(image, pix_feat, mux_mask_for_mem)
    maskmem_features = self._maybe_clone(maskmem_out["vision_features"])
    maskmem_pos_enc = [self._maybe_clone(m) for m in maskmem_out["vision_pos_enc"]]

    # The no-object embedding needs presence scores. Previously
    # `is_obj_appearing` was unbound when `object_score_logits` was None.
    if self.no_obj_embed_spatial is not None and object_score_logits is not None:
        obj_expected = multiplex_state.total_valid_entries
        obj_current = object_score_logits.shape[0]
        if obj_current < obj_expected:
            pad_shape = (obj_expected - obj_current,) + tuple(
                object_score_logits.shape[1:]
            )
            object_score_logits = torch.cat(
                [object_score_logits, object_score_logits.new_zeros(pad_shape)],
                dim=0,
            )
        elif obj_current > obj_expected:
            object_score_logits = object_score_logits[:obj_expected]
        object_score_logits = multiplex_state.mux(object_score_logits)
        is_obj_appearing = (
            object_score_logits > self.object_score_logit_threshold
        ).float()
        no_obj_embed_spatial = self.no_obj_embed_spatial.unsqueeze(0).repeat(
            multiplex_state.num_buckets, 1, 1
        )
        no_obj_embed = ((1 - is_obj_appearing) * no_obj_embed_spatial).sum(dim=1)
        maskmem_features += no_obj_embed[..., None, None].expand_as(
            maskmem_features
        )

    if self.add_object_conditional_embeddings:
        obj_cond_embed = self.obj_cond_embed.unsqueeze(0).repeat(
            multiplex_state.num_buckets, 1, 1
        )
        obj_merged_embed = multiplex_state.demux(obj_cond_embed)
        if self.add_object_unconditional_embeddings:
            obj_non_cond_embed = self.obj_non_cond_embed.unsqueeze(0).repeat(
                multiplex_state.num_buckets, 1, 1
            )
            obj_non_cond_embed = multiplex_state.demux(obj_non_cond_embed)
            if self.training:
                # Copy before the in-place write below (presumably to keep it
                # autograd-safe during training).
                obj_merged_embed = obj_merged_embed.clone()
            obj_merged_embed[unconditioning_objects] = obj_non_cond_embed[
                unconditioning_objects
            ]
        obj_merged_embed = multiplex_state.mux(obj_merged_embed).sum(dim=1)
        maskmem_features = maskmem_features + obj_merged_embed[
            ..., None, None
        ].expand_as(maskmem_features)

    if maskmem_features.dim() == 5:
        maskmem_features = multiplex_state.demux(maskmem_features).contiguous()
    maskmem_pos_enc = [
        multiplex_state.demux(pos_enc).contiguous()
        if pos_enc is not None and pos_enc.dim() == 5
        else pos_enc
        for pos_enc in maskmem_pos_enc
    ]
    return maskmem_features, maskmem_pos_enc
|
|
def forward_tracking(
    self,
    backbone_out,
    input,
    return_dict=False,
    objects_to_interact: Optional[list[int]] = None,
):
    """
    Run `track_step` on every frame of the video, one stage at a time.

    Frames are visited in `init_cond_frames` order first, then in
    `frames_not_in_init_cond` order. If `backbone_out` already holds image
    features ("interactive" or "sam2_backbone_out"), they are sliced per
    frame. Otherwise they are computed on the fly, requesting the interactive
    neck only for conditioning frames and frames with correction points.

    A frame's output goes into "cond_frame_outputs" when it is an initial
    conditioning frame (or, with `add_all_frames_to_correct_as_cond`, a
    correction frame). Otherwise it goes into "non_cond_frame_outputs".

    Returns:
        If `return_dict`, the dict of both output maps plus "multiplex_state".
        Otherwise a list of the outputs for frames 0..num_frames-1 with the
        "obj_ptr" entry removed (every frame in that range must have been
        processed).
    """
    features_precomputed = (
        "interactive" in backbone_out or "sam2_backbone_out" in backbone_out
    )
    if features_precomputed:
        backbone_features = self._prepare_backbone_features(backbone_out)

    num_frames = backbone_out["num_frames"]
    init_cond_frames = backbone_out["init_cond_frames"]
    frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
    frame_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]

    output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}

    first_gt = backbone_out["gt_masks_per_frame"][0]
    multiplex_state = self.multiplex_controller.get_state(
        first_gt.shape[0],
        device=first_gt.device,
        dtype=torch.float,
        random=self.training,
    )

    for stage_id in frame_order:
        stage_img_ids = input.find_inputs[stage_id].img_ids
        # Every query of a stage must point at the same image.
        assert all(img_id == stage_img_ids[0] for img_id in stage_img_ids)
        img_ids = torch.tensor(
            [stage_img_ids[0]], device=stage_img_ids.device, dtype=stage_img_ids.dtype
        )

        if not features_precomputed:
            wants_interactive = (stage_id in frames_to_add_correction_pt) or (
                stage_id in init_cond_frames
            )
            current_image, current_backbone_features = (
                self._prepare_backbone_features_per_frame(
                    input.img_batch,
                    img_ids,
                    need_interactive_out=wants_interactive,
                    need_propagation_out=True,
                )
            )
        else:
            current_image = input.img_batch.tensors[img_ids]
            current_backbone_features = {
                neck_k: {
                    "vision_feats": [x[:, img_ids] for x in neck_feats["vision_feats"]],
                    "vision_masks": [
                        None if x is None else x[img_ids]
                        for x in neck_feats["vision_masks"]
                    ],
                    "vision_pos_embeds": [
                        x[:, img_ids] for x in neck_feats["vision_pos_embeds"]
                    ],
                    "feat_sizes": neck_feats["feat_sizes"],
                }
                for neck_k, neck_feats in backbone_features.items()
            }

        is_init_cond_frame = stage_id in init_cond_frames
        current_out = self.track_step(
            frame_idx=stage_id,
            is_init_cond_frame=is_init_cond_frame,
            backbone_features_interactive=current_backbone_features.get(
                "interactive"
            ),
            backbone_features_propagation=current_backbone_features.get(
                "sam2_backbone_out"
            ),
            image=current_image,
            point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
            mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
            gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
            frames_to_add_correction_pt=frames_to_add_correction_pt,
            output_dict=output_dict,
            num_frames=num_frames,
            multiplex_state=multiplex_state,
            objects_to_interact=objects_to_interact,
        )

        store_as_cond = is_init_cond_frame or (
            self.add_all_frames_to_correct_as_cond
            and stage_id in frames_to_add_correction_pt
        )
        target_key = "cond_frame_outputs" if store_as_cond else "non_cond_frame_outputs"
        output_dict[target_key][stage_id] = current_out

    output_dict["multiplex_state"] = multiplex_state
    if return_dict:
        return output_dict

    # Non-conditioning outputs win on key collisions, as in a sequential update.
    per_frame = {
        **output_dict["cond_frame_outputs"],
        **output_dict["non_cond_frame_outputs"],
    }
    return [
        {key: value for key, value in per_frame[t].items() if key != "obj_ptr"}
        for t in range(num_frames)
    ]
|
|
| def _track_step_aux( |
| self, |
| *, |
| frame_idx, |
| is_init_cond_frame, |
| backbone_features_interactive, |
| backbone_features_propagation, |
| image, |
| point_inputs, |
| mask_inputs, |
| gt_masks, |
| frames_to_add_correction_pt, |
| output_dict, |
| num_frames, |
| track_in_reverse=False, |
| run_mem_encoder=True, |
| prev_sam_mask_logits=None, |
| multiplex_state: MultiplexState, |
| objects_to_interact: Optional[list[int]] = None, |
| need_aux_output: bool = False, |
| ) -> tuple[StageOutput, dict]: |
| """ |
| There are four different modes that track_step might enter, based on the inputs |
| 1. Mask-as-output. This is when mask_inputs is not None. |
| The input mask is returned directly. This case is for FA/VOS initialization. |
| 2. Propagation-only. This is when mask_inputs and point_inputs are empty. |
| We propagate masks using the memory only. This case is for VOS propagation. |
| 3. Interaction-only. This is when mask_inputs is None, point_inputs is not None, |
| and one of the followings is satisified: |
| a) prev_sam_mask_logits is not None. In this case, we refine prev_sam_mask_logits |
| with additional interactions, updating only the objects specified in objects_to_interact. |
| objects_to_interact must not be None. |
| This occurs when we refine the same frame with multiple point inputs iteratively. |
| b) prev_sam_mask_logits is None, and is_init_cond_frame is True. |
| This case is for initializing the first frame. All objects will have point inputs. |
| This mostly happens during training/interactive eval. |
| 4. Propagation-and-interaction. This is when mask_inputs is None, point_inputs is not None, |
| prev_sam_mask_logits is None, and objects_to_interact is not None. |
| This is when we are propagating to a new frame that has point inputs (from previous interactions). |
| This is more of an edge case that could happen in offline interactive eval. |
| We first propagate the mask to the current frame, and then perform interaction on the selected |
| objects. Finally, we replace the masks of the interacted objects in the propagated output |
| with the masks from the interaction output. |
| """ |
| current_out: StageOutput = { |
| "conditioning_objects": set(), |
| "point_inputs": point_inputs, |
| "mask_inputs": mask_inputs, |
| } |
|
|
| mode = None |
| if mask_inputs is not None: |
| mode = "mask_as_output" |
| elif point_inputs is None: |
| mode = "propagation_only" |
| elif point_inputs is not None: |
| |
| if prev_sam_mask_logits is not None: |
| assert objects_to_interact is not None, ( |
| "objects_to_interact must be specified when refining with prev_sam_mask_logits" |
| ) |
| mode = "interaction_only" |
| |
| elif is_init_cond_frame: |
| mode = "interaction_only" |
| |
| elif objects_to_interact is not None and prev_sam_mask_logits is None: |
| assert not self.training |
| mode = "propagation_and_interaction" |
|
|
| if mode is None: |
| raise ValueError( |
| f"Unable to determine tracking case. " |
| f"mask_inputs={mask_inputs is not None}, " |
| f"point_inputs={point_inputs is not None}, " |
| f"prev_sam_mask_logits={prev_sam_mask_logits is not None}, " |
| f"objects_to_interact={objects_to_interact}, " |
| f"is_init_cond_frame={is_init_cond_frame}" |
| ) |
| |
| interactive_high_res_features = interactive_vision_feats = None |
| interactive_feat_sizes = None |
| if backbone_features_interactive is not None: |
| interactive_vision_feats = backbone_features_interactive["vision_feats"] |
| interactive_feat_sizes = backbone_features_interactive["feat_sizes"] |
|
|
| |
| if len(interactive_vision_feats) > 1: |
| interactive_high_res_features = [ |
| x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) |
| for x, s in zip( |
| interactive_vision_feats[:-1], interactive_feat_sizes[:-1] |
| ) |
| ] |
| else: |
| |
| assert mode not in ["interaction_only", "propagation_and_interaction"] |
|
|
| propagation_high_res_features = propagation_vision_feats = None |
| propagation_vision_masks = None |
| propagation_vision_pos_embeds = propagation_feat_sizes = None |
| if backbone_features_propagation is not None: |
| propagation_vision_feats = backbone_features_propagation["vision_feats"] |
| propagation_vision_masks = backbone_features_propagation["vision_masks"] |
| propagation_vision_pos_embeds = backbone_features_propagation[ |
| "vision_pos_embeds" |
| ] |
| propagation_feat_sizes = backbone_features_propagation["feat_sizes"] |
|
|
| |
| if len(propagation_vision_feats) > 1: |
| propagation_high_res_features = [ |
| x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) |
| for x, s in zip( |
| propagation_vision_feats[:-1], propagation_feat_sizes[:-1] |
| ) |
| ] |
| else: |
| |
| assert mode not in ["propagation_only", "propagation_and_interaction"] |
| assert not run_mem_encoder |
|
|
| interactive_pix_feat = None |
| if mode == "mask_as_output": |
| |
| assert self.use_mask_input_as_output_without_sam |
| |
| |
| |
| interactive_pix_feat = self._get_interactive_pix_mem( |
| interactive_vision_feats, interactive_feat_sizes |
| ) |
| sam_outputs = self._use_mask_as_output( |
| backbone_features=interactive_pix_feat, |
| high_res_features=interactive_high_res_features, |
| mask_inputs=mask_inputs, |
| multiplex_state=multiplex_state, |
| ) |
| |
| current_out["conditioning_objects"].update(range(mask_inputs.shape[0])) |
| else: |
| |
| propagation_out = None |
| if mode in ["propagation_only", "propagation_and_interaction"]: |
| |
| assert backbone_features_propagation is not None |
| assert propagation_vision_feats is not None |
| assert propagation_vision_masks is not None |
| assert propagation_vision_pos_embeds is not None |
| assert propagation_feat_sizes is not None |
| pix_feat_with_mem = self._prepare_memory_conditioned_features( |
| frame_idx=frame_idx, |
| is_init_cond_frame=is_init_cond_frame, |
| current_vision_feats=propagation_vision_feats[-1:], |
| current_vision_masks=propagation_vision_masks[-1:], |
| current_vision_pos_embeds=propagation_vision_pos_embeds[-1:], |
| feat_sizes=propagation_feat_sizes[-1:], |
| output_dict=output_dict, |
| num_frames=num_frames, |
| track_in_reverse=track_in_reverse, |
| multiplex_state=multiplex_state, |
| ) |
|
|
| |
| |
| multimask_output = self._use_multimask( |
| is_init_cond_frame, point_inputs=None |
| ) |
| propagation_out = self._forward_sam_heads( |
| backbone_features=pix_feat_with_mem, |
| propagation_high_res_features=propagation_high_res_features, |
| multimask_output=multimask_output, |
| objects_to_interact=list( |
| range(multiplex_state.total_valid_entries) |
| ), |
| multiplex_state=multiplex_state, |
| ) |
|
|
| interaction_out = None |
| if mode in ["interaction_only", "propagation_and_interaction"]: |
| assert backbone_features_interactive is not None |
| assert interactive_vision_feats is not None |
| assert interactive_feat_sizes is not None |
| interactive_pix_feat = self._get_interactive_pix_mem( |
| interactive_vision_feats, interactive_feat_sizes |
| ) |
|
|
| |
| |
| |
| |
| |
| assert mask_inputs is None and point_inputs is not None |
| if prev_sam_mask_logits is not None: |
| assert objects_to_interact is not None |
| assert self.iter_use_prev_mask_pred |
| assert mode != "propagation_and_interaction" |
| mask_inputs = prev_sam_mask_logits[objects_to_interact] |
| elif mode == "propagation_and_interaction": |
| |
| assert objects_to_interact is not None |
| assert propagation_out is not None |
| mask_inputs = propagation_out["low_res_masks"][objects_to_interact] |
|
|
| if objects_to_interact is not None: |
| assert point_inputs["point_coords"].shape[0] == len( |
| objects_to_interact |
| ) |
| assert point_inputs["point_labels"].shape[0] == len( |
| objects_to_interact |
| ) |
|
|
| multimask_output = self._use_multimask( |
| is_init_cond_frame, point_inputs=point_inputs |
| ) |
| interaction_out = self._forward_sam_heads( |
| backbone_features=interactive_pix_feat, |
| point_inputs=point_inputs, |
| mask_inputs=mask_inputs, |
| interactive_high_res_features=interactive_high_res_features, |
| multimask_output=multimask_output, |
| objects_to_interact=( |
| objects_to_interact |
| if objects_to_interact is not None |
| else list(range(multiplex_state.total_valid_entries)) |
| ), |
| multiplex_state=multiplex_state, |
| ) |
| if objects_to_interact is None: |
| current_out["conditioning_objects"].update( |
| multiplex_state.get_all_valid_object_idx() |
| ) |
| else: |
| current_out["conditioning_objects"].update(objects_to_interact) |
|
|
| if propagation_out is None and interaction_out is not None: |
| sam_outputs = interaction_out |
| elif interaction_out is None and propagation_out is not None: |
| sam_outputs = propagation_out |
| else: |
| |
| assert propagation_out is not None and interaction_out is not None |
| keys_to_merge = [ |
| "low_res_multimasks", |
| "high_res_multimasks", |
| "low_res_masks", |
| "high_res_masks", |
| "ious", |
| "object_score_logits", |
| "obj_ptr", |
| ] |
| for k in keys_to_merge: |
| src = interaction_out[k] |
| dst = propagation_out[k] |
| |
| if torch.is_tensor(src) and torch.is_tensor(dst): |
| if torch.is_floating_point(src) and src.dtype != dst.dtype: |
| src = src.to(dtype=dst.dtype) |
| propagation_out[k][objects_to_interact] = src |
| sam_outputs = propagation_out |
|
|
| low_res_multimasks = sam_outputs["low_res_multimasks"] |
| high_res_multimasks = sam_outputs["high_res_multimasks"] |
| ious = sam_outputs["ious"] |
| low_res_masks = sam_outputs["low_res_masks"] |
| high_res_masks = sam_outputs["high_res_masks"] |
| object_score_logits = sam_outputs["object_score_logits"] |
|
|
| current_out["multistep_pred_masks"] = low_res_masks |
| current_out["multistep_pred_masks_high_res"] = high_res_masks |
| current_out["multistep_pred_multimasks"] = [low_res_multimasks] |
| current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks] |
| current_out["multistep_pred_ious"] = [ious] |
| current_out["multistep_point_inputs"] = [point_inputs] |
| current_out["multistep_object_score_logits"] = [object_score_logits] |
|
|
| if self.use_obj_ptrs_in_encoder: |
| obj_ptr = sam_outputs["obj_ptr"] |
|
|
| |
| if frame_idx in frames_to_add_correction_pt: |
| assert gt_masks is not None |
| assert interactive_vision_feats is not None |
| assert interactive_feat_sizes is not None |
| all_pred_masks = [low_res_masks] |
| all_pred_high_res_masks = [high_res_masks] |
| all_pred_multimasks = [low_res_multimasks] |
| all_pred_high_res_multimasks = [high_res_multimasks] |
| all_pred_ious = [ious] |
| all_point_inputs = [point_inputs] |
| all_object_score_logits = [object_score_logits] |
|
|
| |
| if self.training: |
| assert objects_to_interact is None |
|
|
| interact_with_all_objects = ( |
| self.rng.random() < self.prob_correct_all_objects_for_train |
| ) or ( |
| self.force_correct_all_for_conditional_inputs and is_init_cond_frame |
| ) |
|
|
| if interact_with_all_objects: |
| num_objects_to_correct = gt_masks.shape[0] |
| elif self.rand_objects_to_correct_for_train: |
| num_objects_to_correct = self.rng2.integers( |
| 1, |
| int( |
| gt_masks.shape[0] |
| * self.ratio_of_objects_to_correct_for_train |
| ) |
| + 1, |
| ) |
| else: |
| num_objects_to_correct = max( |
| 1, |
| int( |
| gt_masks.shape[0] |
| * self.ratio_of_objects_to_correct_for_train |
| ), |
| ) |
|
|
| objects_to_interact = self.rng2.choice( |
| range(gt_masks.shape[0]), |
| size=num_objects_to_correct, |
| replace=False, |
| ).tolist() |
|
|
| if point_inputs is not None: |
| |
| point_inputs = { |
| "point_coords": point_inputs["point_coords"][ |
| objects_to_interact |
| ], |
| "point_labels": point_inputs["point_labels"][ |
| objects_to_interact |
| ], |
| } |
| else: |
| assert objects_to_interact is not None |
| |
|
|
| if point_inputs is not None: |
| assert point_inputs["point_coords"].shape[0] == len(objects_to_interact) |
| assert point_inputs["point_labels"].shape[0] == len(objects_to_interact) |
|
|
| for _ in range(self.num_correction_pt_per_frame): |
| |
| |
| if self.training and self.prob_to_sample_from_gt_for_train > 0: |
| sample_from_gt = ( |
| self.rng.random() < self.prob_to_sample_from_gt_for_train |
| ) |
| else: |
| sample_from_gt = False |
| |
| pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0) |
| new_points, new_labels = get_next_point( |
| gt_masks=gt_masks[objects_to_interact], |
| pred_masks=( |
| pred_for_new_pt[objects_to_interact] |
| if pred_for_new_pt is not None |
| else None |
| ), |
| method="uniform" if self.training else self.pt_sampling_for_eval, |
| ) |
| point_inputs = concat_points(point_inputs, new_points, new_labels) |
| assert low_res_masks.shape[0] > max(objects_to_interact), ( |
| f"interacting {objects_to_interact} in {low_res_masks.shape}?" |
| ) |
| if self.iter_use_prev_mask_pred: |
| |
| |
| |
| mask_inputs = low_res_masks[objects_to_interact] |
| multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) |
| pix_feat_with_mem = self._get_interactive_pix_mem( |
| interactive_vision_feats, interactive_feat_sizes |
| ) |
| sam_outputs = self._forward_sam_heads( |
| backbone_features=pix_feat_with_mem, |
| point_inputs=point_inputs, |
| mask_inputs=mask_inputs, |
| interactive_high_res_features=interactive_high_res_features, |
| propagation_high_res_features=propagation_high_res_features, |
| multimask_output=multimask_output, |
| gt_masks=gt_masks, |
| objects_to_interact=objects_to_interact, |
| multiplex_state=multiplex_state, |
| ) |
| interact_low_res_multimasks = sam_outputs["low_res_multimasks"] |
| interact_high_res_multimasks = sam_outputs["high_res_multimasks"] |
| interact_ious = sam_outputs["ious"] |
| interact_low_res_masks = sam_outputs["low_res_masks"] |
| interact_high_res_masks = sam_outputs["high_res_masks"] |
| interact_object_score_logits = sam_outputs["object_score_logits"] |
| if self.use_obj_ptrs_in_encoder: |
| interact_obj_ptr = sam_outputs["obj_ptr"] |
|
|
| if self.training: |
| |
| low_res_masks = low_res_masks.clone() |
| high_res_masks = high_res_masks.clone() |
| low_res_multimasks = low_res_multimasks.clone() |
| high_res_multimasks = high_res_multimasks.clone() |
| ious = ious.clone() |
| object_score_logits = object_score_logits.clone() |
| obj_ptr = obj_ptr.clone() if self.use_obj_ptrs_in_encoder else None |
|
|
| |
| if ( |
| torch.is_floating_point(interact_low_res_masks) |
| and interact_low_res_masks.dtype != low_res_masks.dtype |
| ): |
| interact_low_res_masks = interact_low_res_masks.to( |
| dtype=low_res_masks.dtype |
| ) |
| low_res_masks[objects_to_interact] = interact_low_res_masks |
| if ( |
| torch.is_floating_point(interact_high_res_masks) |
| and interact_high_res_masks.dtype != high_res_masks.dtype |
| ): |
| interact_high_res_masks = interact_high_res_masks.to( |
| dtype=high_res_masks.dtype |
| ) |
| high_res_masks[objects_to_interact] = interact_high_res_masks |
| if ( |
| torch.is_floating_point(interact_low_res_multimasks) |
| and interact_low_res_multimasks.dtype != low_res_multimasks.dtype |
| ): |
| interact_low_res_multimasks = interact_low_res_multimasks.to( |
| dtype=low_res_multimasks.dtype |
| ) |
| low_res_multimasks[objects_to_interact] = interact_low_res_multimasks |
| if ( |
| torch.is_floating_point(interact_high_res_multimasks) |
| and interact_high_res_multimasks.dtype != high_res_multimasks.dtype |
| ): |
| interact_high_res_multimasks = interact_high_res_multimasks.to( |
| dtype=high_res_multimasks.dtype |
| ) |
| high_res_multimasks[objects_to_interact] = interact_high_res_multimasks |
| if ( |
| torch.is_floating_point(interact_ious) |
| and interact_ious.dtype != ious.dtype |
| ): |
| interact_ious = interact_ious.to(dtype=ious.dtype) |
| ious[objects_to_interact] = interact_ious |
| if ( |
| torch.is_floating_point(interact_object_score_logits) |
| and interact_object_score_logits.dtype != object_score_logits.dtype |
| ): |
| interact_object_score_logits = interact_object_score_logits.to( |
| dtype=object_score_logits.dtype |
| ) |
| object_score_logits[objects_to_interact] = interact_object_score_logits |
| if self.use_obj_ptrs_in_encoder: |
| obj_ptr[objects_to_interact] = interact_obj_ptr |
|
|
| all_pred_masks.append(low_res_masks) |
| all_pred_high_res_masks.append(high_res_masks) |
| all_pred_multimasks.append(low_res_multimasks) |
| all_pred_high_res_multimasks.append(high_res_multimasks) |
| all_pred_ious.append(ious) |
| all_point_inputs.append(point_inputs) |
| all_object_score_logits.append(object_score_logits) |
|
|
| |
| |
| current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1) |
| current_out["multistep_pred_masks_high_res"] = torch.cat( |
| all_pred_high_res_masks, dim=1 |
| ) |
| current_out["multistep_pred_multimasks"] = all_pred_multimasks |
| current_out["multistep_pred_multimasks_high_res"] = ( |
| all_pred_high_res_multimasks |
| ) |
| current_out["multistep_pred_ious"] = all_pred_ious |
| current_out["multistep_point_inputs"] = all_point_inputs |
| current_out["multistep_object_score_logits"] = all_object_score_logits |
|
|
| if self.add_all_frames_to_correct_as_cond: |
| if objects_to_interact is None: |
| current_out["conditioning_objects"].update( |
| multiplex_state.get_all_valid_object_idx() |
| ) |
| else: |
| current_out["conditioning_objects"].update(set(objects_to_interact)) |
|
|
| |
| current_out["pred_masks"] = low_res_masks |
| current_out["pred_masks_high_res"] = high_res_masks |
| if self.use_obj_ptrs_in_encoder: |
| |
| current_out["obj_ptr"] = multiplex_state.mux(obj_ptr) |
| if self.use_memory_selection: |
| current_out["object_score_logits"] = object_score_logits |
| iou_score = current_out["multistep_pred_ious"][-1].max(-1)[0] |
| current_out["iou_score"] = iou_score |
| current_out["eff_iou_score"] = self.cal_mem_score( |
| object_score_logits, iou_score |
| ) |
| |
| current_out["object_score_logits"] = object_score_logits |
|
|
| |
| |
| |
| |
| if run_mem_encoder and self.num_maskmem > 0: |
| high_res_masks_for_mem_enc = high_res_masks |
| maskmem_features, maskmem_pos_enc = self._encode_new_memory( |
| image=image, |
| current_vision_feats=propagation_vision_feats, |
| feat_sizes=propagation_feat_sizes, |
| pred_masks_high_res=high_res_masks_for_mem_enc, |
| object_score_logits=object_score_logits, |
| is_mask_from_pts=(point_inputs is not None), |
| conditioning_objects=current_out["conditioning_objects"], |
| multiplex_state=multiplex_state, |
| ) |
| current_out["maskmem_features"] = maskmem_features |
| current_out["maskmem_pos_enc"] = maskmem_pos_enc |
|
|
| if self.save_image_features: |
| current_out["image_features"] = propagation_vision_feats[-1] |
| current_out["image_pos_enc"] = propagation_vision_pos_embeds[-1] |
|
|
| |
| aux_output = {} |
| if need_aux_output: |
| if interactive_pix_feat is None: |
| interactive_pix_feat = self._get_interactive_pix_mem( |
| interactive_vision_feats, interactive_feat_sizes |
| ) |
| aux_output["interactive_pix_feat"] = interactive_pix_feat |
| aux_output["interactive_high_res_features"] = interactive_high_res_features |
| aux_output["propagation_vision_feats"] = propagation_vision_feats |
| aux_output["propagation_feat_sizes"] = propagation_feat_sizes |
|
|
| return current_out, aux_output |
|
|
| def _trim_output_and_memory( |
| self, |
| frame_idx: int, |
| output_dict: dict[str, dict[int, StageOutput]], |
| current_out: StageOutput, |
| memory_encoder_was_used: bool, |
| ) -> StageOutput: |
| |
| |
| if self.offload_output_to_cpu_for_eval and not self.training: |
| |
| trimmed_out: StageOutput = { |
| "conditioning_objects": current_out["conditioning_objects"], |
| "pred_masks": current_out["pred_masks"].cpu(), |
| "pred_masks_high_res": current_out["pred_masks_high_res"].cpu(), |
| |
| "object_score_logits": current_out["object_score_logits"], |
| "multistep_point_inputs": current_out["multistep_point_inputs"], |
| } |
| if self.use_obj_ptrs_in_encoder: |
| trimmed_out["obj_ptr"] = current_out["obj_ptr"] |
| if memory_encoder_was_used and self.num_maskmem > 0: |
| trimmed_out["maskmem_features"] = current_out["maskmem_features"].cpu() |
| trimmed_out["maskmem_pos_enc"] = [ |
| x.cpu() for x in current_out["maskmem_pos_enc"] |
| ] |
| if self.save_image_features: |
| trimmed_out["image_features"] = current_out["image_features"].cpu() |
| trimmed_out["image_pos_enc"] = current_out["image_pos_enc"].cpu() |
| current_out = trimmed_out |
|
|
| |
| |
| |
| def _trim_past_out( |
| past_out: StageOutput, current_out: StageOutput |
| ) -> Optional[StageOutput]: |
| if past_out is None: |
| return None |
| trimmed_past_out: StageOutput = { |
| "conditioning_objects": past_out["conditioning_objects"], |
| "pred_masks": past_out["pred_masks"], |
| "object_score_logits": past_out["object_score_logits"], |
| } |
| if "local_obj_id_to_idx" in past_out: |
| trimmed_past_out["local_obj_id_to_idx"] = past_out["local_obj_id_to_idx"].copy() |
| if "multistep_point_inputs" in past_out: |
| trimmed_past_out["multistep_point_inputs"] = past_out["multistep_point_inputs"] |
| if self.use_obj_ptrs_in_encoder and "obj_ptr" in past_out: |
| trimmed_past_out["obj_ptr"] = past_out["obj_ptr"] |
| return trimmed_past_out |
|
|
| if self.trim_past_non_cond_mem_for_eval and not self.training: |
| r = self.memory_temporal_stride_for_eval |
| past_frame_idx = frame_idx - r * self.num_maskmem |
| past_out = output_dict["non_cond_frame_outputs"].get(past_frame_idx, None) |
|
|
| if past_out is not None: |
| if ( |
| self.use_memory_selection |
| and past_out.get("eff_iou_score", 0) < self.mf_threshold |
| ) or not self.use_memory_selection: |
| output_dict["non_cond_frame_outputs"][past_frame_idx] = ( |
| _trim_past_out(past_out, current_out) |
| ) |
|
|
| if ( |
| self.use_memory_selection and not self.offload_output_to_cpu_for_eval |
| ): |
| far_old_frame_idx = frame_idx - 20 * self.max_obj_ptrs_in_encoder |
| past_out = output_dict["non_cond_frame_outputs"].get( |
| far_old_frame_idx, None |
| ) |
| if past_out is not None: |
| output_dict["non_cond_frame_outputs"][far_old_frame_idx] = ( |
| _trim_past_out(past_out, current_out) |
| ) |
|
|
| return current_out |
|
|
| def track_step( |
| self, |
| *, |
| frame_idx, |
| is_init_cond_frame, |
| backbone_features_interactive, |
| backbone_features_propagation, |
| image, |
| point_inputs, |
| mask_inputs, |
| gt_masks, |
| frames_to_add_correction_pt, |
| output_dict, |
| num_frames, |
| track_in_reverse=False, |
| |
| |
| |
| |
| |
| run_mem_encoder=True, |
| |
| prev_sam_mask_logits=None, |
| multiplex_state: MultiplexState, |
| |
| |
| objects_to_interact: Optional[list[int]] = None, |
| ) -> StageOutput: |
| current_out, _ = self._track_step_aux( |
| frame_idx=frame_idx, |
| is_init_cond_frame=is_init_cond_frame, |
| backbone_features_interactive=backbone_features_interactive, |
| backbone_features_propagation=backbone_features_propagation, |
| image=image, |
| point_inputs=point_inputs, |
| mask_inputs=mask_inputs, |
| gt_masks=gt_masks, |
| frames_to_add_correction_pt=frames_to_add_correction_pt, |
| output_dict=output_dict, |
| num_frames=num_frames, |
| track_in_reverse=track_in_reverse, |
| run_mem_encoder=run_mem_encoder, |
| prev_sam_mask_logits=prev_sam_mask_logits, |
| multiplex_state=multiplex_state, |
| objects_to_interact=objects_to_interact, |
| need_aux_output=False, |
| ) |
| current_out = self._trim_output_and_memory( |
| frame_idx, output_dict, current_out, memory_encoder_was_used=run_mem_encoder |
| ) |
|
|
| return current_out |
|
|
| def back_convert(self, targets): |
| """To be compatible with SetCriterionAPI losses (mask loss only).""" |
| batched_targets = {} |
| batched_targets["num_boxes"] = targets.num_boxes |
| batched_targets["masks"] = targets.segments |
| batched_targets["is_valid_mask"] = targets.is_valid_segment |
| return batched_targets |
|
|
| def _use_multimask(self, is_init_cond_frame, point_inputs): |
| """Whether to use multimask output in the SAM head.""" |
| num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) |
| multimask_output = ( |
| self.multimask_output_in_sam |
| and (is_init_cond_frame or self.multimask_output_for_tracking) |
| and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) |
| and self.num_multimask_outputs > 0 |
| ) |
| return multimask_output |
|
|
| def _apply_non_overlapping_constraints(self, pred_masks): |
| """ |
| Apply non-overlapping constraints to the object scores in pred_masks. Here we |
| keep only the highest scoring object at each spatial location in pred_masks. |
| """ |
| batch_size = pred_masks.size(0) |
| if batch_size == 1: |
| return pred_masks |
|
|
| device = pred_masks.device |
| |
| max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) |
| |
| batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] |
| keep = max_obj_inds == batch_obj_inds |
| |
| |
| pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) |
| return pred_masks |
|
|
| def _compile_all_components(self): |
| """Compile all model components for faster inference.""" |
| |
| |
| torch._dynamo.config.cache_size_limit = 64 |
| torch._dynamo.config.accumulated_cache_size_limit = 2048 |
|
|
| logging.info("Compiling all components. First time may be very slow.") |
|
|
| self.maskmem_backbone.forward = torch.compile( |
| self.maskmem_backbone.forward, |
| mode="max-autotune", |
| fullgraph=True, |
| dynamic=False, |
| ) |
| self.transformer.encoder.forward = torch.compile( |
| self.transformer.encoder.forward, |
| mode="max-autotune", |
| fullgraph=True, |
| dynamic=True, |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
| self.sam_mask_decoder.forward = torch.compile( |
| self.sam_mask_decoder.forward, |
| mode="max-autotune", |
| fullgraph=True, |
| dynamic=False, |
| ) |
|
|
| def _maybe_clone(self, x): |
| """Clone a tensor if and only if `self.compile_all_components` is True.""" |
| return x.clone() if self.compile_all_components else x |
|
|
| def get_propagation_dense_pe(self) -> torch.Tensor: |
| """ |
| Returns the positional encoding used to encode point prompts, |
| applied to a dense set of points the shape of the image encoding. |
| |
| Returns: |
| torch.Tensor: Positional encoding with shape |
| 1x(embed_dim)x(embedding_h)x(embedding_w) |
| """ |
| return self.image_pe_layer( |
| (self.sam_image_embedding_size, self.sam_image_embedding_size) |
| ).unsqueeze(0) |
|
|
| def cal_mem_score(self, object_score_logits, iou_score): |
| object_score_norm = torch.where( |
| object_score_logits > 0, |
| object_score_logits.sigmoid() * 2 - 1, |
| torch.zeros_like(object_score_logits), |
| ) |
| score_per_frame = (object_score_norm * iou_score).mean() |
| return score_per_frame |
|
|
| def frame_filter(self, output_dict, track_in_reverse, frame_idx, num_frames, r): |
| if (frame_idx == 0 and not track_in_reverse) or ( |
| frame_idx == num_frames - 1 and track_in_reverse |
| ): |
| return [] |
|
|
| max_num = min( |
| num_frames, self.max_obj_ptrs_in_encoder |
| ) |
|
|
| if not track_in_reverse: |
| start = frame_idx - 1 |
| end = 0 |
| step = -r |
| must_include = frame_idx - 1 |
| else: |
| start = frame_idx + 1 |
| end = num_frames |
| step = r |
| must_include = frame_idx + 1 |
|
|
| valid_indices = [] |
| for i in range(start, end, step): |
| if ( |
| i not in output_dict["non_cond_frame_outputs"] |
| or "eff_iou_score" not in output_dict["non_cond_frame_outputs"][i] |
| ): |
| continue |
|
|
| score_per_frame = output_dict["non_cond_frame_outputs"][i]["eff_iou_score"] |
|
|
| if score_per_frame > self.mf_threshold: |
| valid_indices.insert(0, i) |
|
|
| if len(valid_indices) >= max_num - 1: |
| break |
|
|
| if must_include not in valid_indices: |
| valid_indices.append(must_include) |
|
|
| return valid_indices |
|
|
|
|
def concat_points(old_point_inputs, new_points, new_labels):
    """Append new click points/labels after any previous ones.

    Concatenation is along dim 1 (the per-object click axis). Without previous
    inputs, the new tensors are used as-is (not copied).
    """
    if old_point_inputs is not None:
        new_points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
        new_labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
    return {"point_coords": new_points, "point_labels": new_labels}
|
|
|
|
def _append(
    d1: StageOutput, d2: SAMOutput, k1: str, k2: str, dim: int = 0, strict: bool = True
):
    """Concatenate ``d2[k2]`` onto ``d1[k1]`` along ``dim``, storing the result in ``d1``.

    If ``k1`` is missing from ``d1``: assert when ``strict``, otherwise do nothing.
    """
    if strict:
        assert k1 in d1, f"{k1} not found"
    elif k1 not in d1:
        return
    d1[k1] = torch.cat((d1[k1], d2[k2]), dim=dim)
|
|
|
|
def _merge(
    d1: StageOutput,
    d2: SAMOutput,
    k1: str,
    k2: str,
    d2_idx: list[int],
    strict: bool = True,
):
    """Write ``d2[k2]`` into rows ``d2_idx`` of ``d1[k1]`` in place.

    The source is cast to the destination's dtype first. If ``k1`` is missing
    from ``d1``: assert when ``strict``, otherwise do nothing.
    """
    if strict:
        assert k1 in d1, f"{k1} not found"
    elif k1 not in d1:
        return
    target = d1[k1]
    target[d2_idx] = d2[k2].to(dtype=target.dtype)
|
|
|
|
| class VideoTrackingDynamicMultiplex(VideoTrackingMultiplex): |
| def __init__( |
| self, |
| enable_dynamic_training: bool = True, |
| rand_num_transition_points: bool = True, |
| max_num_transition_points: int = 3, |
| add_all_transition_frames_as_cond: bool = True, |
| max_trans_frames_in_attn: int = 4, |
| is_dynamic_model: bool = True, |
| is_dynamic_vos_evaluation: bool = False, |
| **kwargs, |
| ): |
| super().__init__(is_dynamic_model=is_dynamic_model, **kwargs) |
|
|
| self.enable_dynamic_training = enable_dynamic_training |
| self.rand_num_transition_points = rand_num_transition_points |
| self.max_num_transition_points = max_num_transition_points |
|
|
| self.add_all_transition_frames_as_cond = add_all_transition_frames_as_cond |
| self.max_trans_frames_in_attn = max_trans_frames_in_attn |
| self.is_dynamic_vos_evaluation = is_dynamic_vos_evaluation |
|
|
| def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0): |
| """ |
| Prepare input mask, point or box prompts. Optionally, we allow tracking from |
| a custom `start_frame_idx` to the end of the video (for evaluation purposes). |
| """ |
|
|
| """ |
| This function, in addition to the prompt preparation done in the parent class, preprocesses the |
| masks and pre-computes visibility/validity attributes necessary for training with dynamic bucketing. |
| |
| **Data** |
| We use a modified dataset class and a modified collate_fn such that: |
| 1. The mask for an object is loaded if it is visible (area>0) on any of the loaded frames |
| 2. A "visible_objects_per_frame" attribute is computed, which contains the set of objects with area>0 on each frame |
| |
| Here, we use [] to denote a set of objects; i.e., object A and B are represented as [A, B]. |
| Consider the masks given by the dataloader in an arbitrary yet deterministic order. |
| That is, [2, 3] can appear on the first frame, and [1, 2, 3, 17] can appear on the second frame. |
| |
| This is incompatible with the object addition implementation, since we assume new objects are appended, not inserted. |
| Thus, we compute object_appearance_order which sorts the object idx using the frame at which they appear |
| (conditional frames always appear first). For objects that appear on the same frame, we shuffle them as augmentation. |
| We also reorder the ground-truth masks used for supervision. |
| |
| **Causal supervision** |
| Since not all objects appear on the first frame, we should not supervise on the objects that the model has no knowledge of yet. |
| Thus, we keep track of the set of objects that have been introduced, and the frame at which that happens. |
| We compute valid_idx_per_frame (and correspondingly trim the ground-truth) to enforce reasonable supervisions. |
| |
| **Transition points** |
| Transition points are non-initial-conditioning frames that introduce new objects. We uniformly sample some frames |
| to be candidates for transition points, and use them if they actually introduce new objects compared to the last seen |
| conditional frame/transition point. |
| Transitions do not always happen when an object first becomes visible, because our (initial) sampling is agnostic to visibility. |
| This is intended, as new objects do not always get detected immediately in the dense tracking setting. |
| """ |
|
|
| |
| backbone_out = super()._prepare_prompt_inputs_meta( |
| backbone_out, input, start_frame_idx=start_frame_idx |
| ) |
|
|
| num_frames = backbone_out["num_frames"] |
| gt_masks_per_frame = backbone_out["gt_masks_per_frame"] |
|
|
| if self.training or self.is_dynamic_vos_evaluation: |
| visible_objects_per_frame: dict[int, set[int]] = ( |
| input.visible_objects_per_frame |
| ) |
| else: |
| visible_objects_per_frame: dict[int, set[int]] = { |
| stage_id: set(range(gt_masks_per_frame[stage_id].shape[0])) |
| for stage_id in range(num_frames) |
| } |
|
|
| |
| |
| init_cond_frames: list[int] = backbone_out["init_cond_frames"] |
| init_cond_frames = sorted(init_cond_frames) |
| frames_not_in_init_cond: list[int] = backbone_out["frames_not_in_init_cond"] |
|
|
| |
| |
| if len(visible_objects_per_frame[start_frame_idx]) == 0: |
| if self.training: |
| logging.warning("Empty first frame, tracking an empty object") |
| visible_objects_per_frame[start_frame_idx] = {0} |
| |
| for stage_id in range(num_frames): |
| gt_masks_per_frame[stage_id][0] = torch.zeros_like( |
| gt_masks_per_frame[stage_id][0] |
| ) |
| else: |
| |
| |
| assert self.is_dynamic_vos_evaluation, ( |
| f"{visible_objects_per_frame=} invalid" |
| ) |
| assert len(init_cond_frames) == 1 |
| for stage_id in range(start_frame_idx, num_frames): |
| if len(visible_objects_per_frame[stage_id]) > 0: |
| init_cond_frames = [stage_id] |
| break |
| for i in range( |
| init_cond_frames[0] + 1 |
| ): |
| if i in frames_not_in_init_cond: |
| frames_not_in_init_cond.remove(i) |
|
|
| backbone_out["init_cond_frames"] = init_cond_frames |
|
|
| |
| |
| |
| valid_idx_per_frame: dict[int, list[int]] = {} |
| |
| valid_idx_prior_to_each_transition: dict[int, list[int]] = {} |
| new_idx_per_transition: dict[int, list[int]] = {} |
|
|
| if self.training and self.enable_dynamic_training: |
| |
| if self.rand_num_transition_points: |
| |
| num_transition_points = self.rng.integers( |
| 1, self.max_num_transition_points, endpoint=True |
| ) |
| else: |
| num_transition_points = self.max_num_transition_points |
|
|
| available_transition_points = frames_not_in_init_cond |
| num_transition_points = min( |
| num_transition_points, len(available_transition_points) |
| ) |
| |
| transition_points = self.rng2.choice( |
| available_transition_points, num_transition_points, replace=False |
| ).tolist() |
| transition_points = sorted(transition_points) |
|
|
| |
| filtered_transition_points = [] |
| objects_seen = set() |
| for stage_id in init_cond_frames: |
| objects_seen.update(visible_objects_per_frame[stage_id]) |
|
|
| for stage_id in range(start_frame_idx, num_frames): |
| if stage_id in transition_points: |
| new_objects_seen = ( |
| visible_objects_per_frame[stage_id] - objects_seen |
| ) |
| if len(new_objects_seen) > 0: |
| filtered_transition_points.append(stage_id) |
| objects_seen.update(new_objects_seen) |
| new_idx_per_transition[stage_id] = list(new_objects_seen) |
| transition_points = filtered_transition_points |
|
|
| |
| init_objects = set() |
| for stage_id in init_cond_frames: |
| init_objects.update(visible_objects_per_frame[stage_id]) |
| init_objects = list(init_objects) |
| self.rng2.shuffle(init_objects) |
|
|
| object_appearance_order = init_objects.copy() |
| valid_idx_per_frame[start_frame_idx] = list(range(len(init_objects))) |
| for stage_id in range(start_frame_idx + 1, num_frames): |
| if stage_id in transition_points: |
| |
| stage_objects = new_idx_per_transition[stage_id].copy() |
| self.rng2.shuffle(stage_objects) |
| valid_idx_prior_to_each_transition[stage_id] = list( |
| range(len(object_appearance_order)) |
| ) |
| new_idx_per_transition[stage_id] = list( |
| range( |
| len(object_appearance_order), |
| len(object_appearance_order) + len(stage_objects), |
| ) |
| ) |
| object_appearance_order.extend(stage_objects) |
|
|
| |
| if stage_id in init_cond_frames: |
| |
| |
| |
| |
| |
| valid_idx_per_frame[stage_id] = valid_idx_per_frame[ |
| start_frame_idx |
| ].copy() |
| elif stage_id in frames_not_in_init_cond: |
| valid_idx_per_frame[stage_id] = list( |
| range(len(object_appearance_order)) |
| ) |
| else: |
| raise ValueError( |
| f"Unexpected {stage_id=}? {init_cond_frames=} {frames_not_in_init_cond=} {transition_points=}" |
| ) |
| elif self.is_dynamic_vos_evaluation and not self.training: |
| |
| |
| |
| |
|
|
| |
| object_appearance_order: list[int] = [] |
| object_appear_at_stage: dict[int, int] = {} |
| transition_points: list[int] = [] |
| stage_to_new_objects: dict[int, list[int]] = defaultdict(list) |
| for stage_id in range(start_frame_idx, num_frames): |
| visible_objects = sorted(list(visible_objects_per_frame[stage_id])) |
| for obj_id in visible_objects: |
| if obj_id in object_appear_at_stage: |
| continue |
|
|
| object_appear_at_stage[obj_id] = stage_id |
| object_appearance_order.append(obj_id) |
| stage_to_new_objects[stage_id].append(obj_id) |
| if stage_id not in init_cond_frames: |
| transition_points.append(stage_id) |
|
|
| |
| objects_seen_so_far = [] |
| for stage_id in range(start_frame_idx, num_frames): |
| if stage_id in transition_points: |
| |
| new_objects = stage_to_new_objects[stage_id] |
| num_objects_before = len(objects_seen_so_far) |
|
|
| |
| valid_idx_prior_to_each_transition[stage_id] = list( |
| range(num_objects_before) |
| ) |
| |
| new_idx_per_transition[stage_id] = list( |
| range(num_objects_before, num_objects_before + len(new_objects)) |
| ) |
|
|
| objects_seen_so_far.extend(new_objects) |
|
|
| |
| if stage_id in init_cond_frames: |
| |
| valid_idx_per_frame[stage_id] = list( |
| range(len(stage_to_new_objects[stage_id])) |
| ) |
| objects_seen_so_far.extend(stage_to_new_objects[stage_id]) |
| else: |
| |
| valid_idx_per_frame[stage_id] = list( |
| range(len(objects_seen_so_far)) |
| ) |
|
|
| else: |
| |
| transition_points = [] |
| visible_objects_on_first_frame = sorted( |
| list(visible_objects_per_frame[start_frame_idx]) |
| ) |
| |
| object_orderings = list(range(len(visible_objects_on_first_frame))) |
| |
| object_appearance_order = visible_objects_on_first_frame.copy() |
| for stage_id in range(start_frame_idx, num_frames): |
| valid_idx_per_frame[stage_id] = object_orderings.copy() |
|
|
| |
| for stage_id in range(start_frame_idx, num_frames): |
| gt_masks_per_frame[stage_id] = gt_masks_per_frame[stage_id][ |
| object_appearance_order |
| ][valid_idx_per_frame[stage_id]] |
|
|
| |
| |
| |
| |
| for stage_id, targets in enumerate(input.find_targets): |
| if stage_id in transition_points: |
| |
| prev_objects = valid_idx_prior_to_each_transition[stage_id] |
| |
| targets.segments = gt_masks_per_frame[stage_id][prev_objects].squeeze(1) |
| else: |
| targets.segments = gt_masks_per_frame[stage_id].squeeze(1) |
| |
| |
| targets.num_boxes = targets.num_boxes[: targets.segments.shape[0]] |
|
|
| backbone_out["valid_idx_per_frame"] = valid_idx_per_frame |
| backbone_out["new_idx_per_transition"] = new_idx_per_transition |
| backbone_out["valid_objects_prior_to_each_transition"] = ( |
| valid_idx_prior_to_each_transition |
| ) |
| backbone_out["transition_points"] = set(transition_points) |
| backbone_out["gt_masks_per_frame"] = gt_masks_per_frame |
| backbone_out["object_appearance_order"] = object_appearance_order |
|
|
| backbone_out = self._prepare_conditional_frames(backbone_out) |
|
|
| return backbone_out |
|
|
def add_new_masks_to_existing_state(
    self,
    *,
    interactive_pix_feat: torch.Tensor,
    interactive_high_res_features: list[torch.Tensor],
    propagation_vision_feats: Optional[list[torch.Tensor]],
    propagation_feat_sizes: Optional[list[tuple[int, int]]],
    new_masks: torch.Tensor,
    obj_idxs_in_mask: list[int],
    obj_ids_in_mask: Optional[list[int]],
    prev_output: StageOutput,
    multiplex_state: MultiplexState,
    add_mask_to_memory: bool = True,
    are_masks_from_pts: bool = False,
    allow_new_buckets: bool = False,
    prefer_new_buckets: bool = False,
) -> None:
    """
    Add new objects to an existing output/multiplex state.

    This function encodes the input masks as new masks and merges them with the existing state.
    The new object entries are always appended to the existing objects.

    This is because, in the dense tracking scenario, we should always propagate (existing state)
    to the current frame first before introducing the new objects.

    Args:
        interactive_pix_feat: features passed to ``_use_mask_as_output`` as
            ``backbone_features``.
        interactive_high_res_features: features passed to ``_use_mask_as_output``
            as ``high_res_features``.
        propagation_vision_feats: passed to ``_encode_new_memory`` as
            ``current_vision_feats`` when ``add_mask_to_memory`` is True.
        propagation_feat_sizes: passed to ``_encode_new_memory`` as ``feat_sizes``.
        new_masks: one mask per new object along dim 0; ``new_masks.shape[0]``
            must equal ``len(obj_idxs_in_mask)``.
        obj_idxs_in_mask: only its length is validated here; the slots used for
            the new objects are allocated by ``multiplex_state``.
        obj_ids_in_mask: optional ids registered through
            ``multiplex_state.add_objects``; must match the number of new masks.
        prev_output: stage output that is updated in place.
        multiplex_state: updated in place with the new objects.
        add_mask_to_memory: if True, re-encode the memory of all objects (old and
            new) after merging and store it in ``prev_output``.
        are_masks_from_pts: forwarded to ``_encode_new_memory`` as
            ``is_mask_from_pts``.
        allow_new_buckets: forwarded to the multiplex state when allocating slots.
        prefer_new_buckets: forwarded to the multiplex state when allocating slots.
    """
    assert self.use_mask_input_as_output_without_sam
    assert new_masks.shape[0] == len(obj_idxs_in_mask)

    num_new_objects = new_masks.shape[0]
    if obj_ids_in_mask is not None:
        assert len(obj_ids_in_mask) == num_new_objects

    if self.use_obj_ptrs_in_encoder:
        # Demux with the state *before* the new objects are registered, so the
        # pointers line up with the existing objects only.
        existing_pointers = multiplex_state.demux(prev_output["obj_ptr"])

    # Allocate slots for the new objects and register them in the multiplex state.
    new_object_idx = multiplex_state.find_next_batch_of_available_indices(
        num_objects=num_new_objects,
        allow_new_buckets=allow_new_buckets,
        prefer_new_buckets=prefer_new_buckets,
    )
    multiplex_state.add_objects(
        object_indices=new_object_idx,
        object_ids=obj_ids_in_mask,
        allow_new_buckets=allow_new_buckets,
        prefer_new_buckets=prefer_new_buckets,
    )

    # Turn the input masks directly into outputs (no SAM decoding).
    mask_output = self._use_mask_as_output(
        backbone_features=interactive_pix_feat,
        high_res_features=interactive_high_res_features,
        mask_inputs=new_masks,
        multiplex_state=multiplex_state,
        objects_in_mask=new_object_idx,
    )

    # The existing high-res masks may come from a different resolution than the
    # interactive ones; resize them to the full (H, W) of the new masks so the two
    # can be concatenated. Comparing both spatial dims (rather than only the
    # width) keeps this correct for non-square masks.
    interactive_size = tuple(mask_output["high_res_masks"].shape[-2:])
    existing_high_res = prev_output.get("pred_masks_high_res")
    if (
        existing_high_res is not None
        and tuple(existing_high_res.shape[-2:]) != interactive_size
    ):
        prev_output["pred_masks_high_res"] = F.interpolate(
            existing_high_res,
            size=interactive_size,
            mode="bilinear",
            align_corners=False,
        )

    # Bring the new low-res masks to the spatial size of the existing ones.
    h, w = prev_output["pred_masks"].shape[-2:]
    mask_output["low_res_masks"] = F.interpolate(
        mask_output["low_res_masks"],
        size=(h, w),
        align_corners=False,
        mode="bilinear",
        antialias=True,
    )

    # Append the new objects' entries after the existing ones.
    _append(prev_output, mask_output, "pred_masks", "low_res_masks")
    _append(
        prev_output,
        mask_output,
        "pred_masks_high_res",
        "high_res_masks",
        strict=False,
    )
    _append(prev_output, mask_output, "object_score_logits", "object_score_logits")
    if self.use_memory_selection:
        mask_output["ious"] = mask_output["ious"].squeeze(-1)
        _append(prev_output, mask_output, "iou_score", "ious")

    if "input_masks" in prev_output:
        prev_output["input_masks"] = torch.cat(
            [prev_output["input_masks"], new_masks], dim=0
        )

    if self.use_obj_ptrs_in_encoder:
        # Concatenate in demuxed space, then re-mux with the updated state.
        new_pointers = mask_output["obj_ptr"].to(existing_pointers.dtype)
        combined_pointers = torch.cat([existing_pointers, new_pointers], dim=0)
        prev_output["obj_ptr"] = multiplex_state.mux(combined_pointers)

    # The newly added objects are conditioned on their input masks.
    prev_output["conditioning_objects"].update(new_object_idx)

    if add_mask_to_memory:
        assert (
            prev_output["pred_masks_high_res"].shape[0]
            == multiplex_state.total_valid_entries
        )
        # Re-encode memory for all objects, now that old and new are merged.
        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
            image=None,
            current_vision_feats=propagation_vision_feats,
            feat_sizes=propagation_feat_sizes,
            pred_masks_high_res=prev_output["pred_masks_high_res"],
            object_score_logits=prev_output["object_score_logits"],
            conditioning_objects=prev_output["conditioning_objects"],
            is_mask_from_pts=are_masks_from_pts,
            multiplex_state=multiplex_state,
        )
        prev_output["maskmem_features"] = maskmem_features
        prev_output["maskmem_pos_enc"] = maskmem_pos_enc
        if self.save_image_features:
            assert "image_features" in prev_output
            assert "image_pos_enc" in prev_output
|
|
def recondition_masks_in_existing_state(
    self,
    *,
    interactive_pix_feat: torch.Tensor,
    interactive_high_res_features: list[torch.Tensor],
    propagation_vision_feats: Optional[list[torch.Tensor]],
    propagation_feat_sizes: Optional[list[tuple[int, int]]],
    new_masks: torch.Tensor,
    obj_idxs_in_mask: list[int],
    obj_ids_in_mask: Optional[list[int]],
    prev_output: StageOutput,
    multiplex_state: MultiplexState,
    add_mask_to_memory: bool = True,
) -> None:
    """
    Recondition existing objects in an existing output/multiplex state.

    This function encodes the input masks and merges them with the existing state.

    Args:
        interactive_pix_feat: features passed to ``_use_mask_as_output`` as
            ``backbone_features``.
        interactive_high_res_features: features passed to ``_use_mask_as_output``
            as ``high_res_features``.
        propagation_vision_feats: passed to ``_encode_new_memory`` as
            ``current_vision_feats`` when ``add_mask_to_memory`` is True.
        propagation_feat_sizes: passed to ``_encode_new_memory`` as ``feat_sizes``.
        new_masks: one mask per reconditioned object along dim 0;
            ``new_masks.shape[0]`` must equal ``len(obj_idxs_in_mask)``.
        obj_idxs_in_mask: indices of the existing objects whose entries are
            overwritten by the new masks.
        obj_ids_in_mask: optional; only its length is validated here.
        prev_output: stage output that is updated in place.
        multiplex_state: multiplex state the objects belong to (not modified here).
        add_mask_to_memory: if True, re-encode the memory of all objects after
            merging and store it in ``prev_output``.
    """
    assert self.use_mask_input_as_output_without_sam
    assert new_masks.shape[0] == len(obj_idxs_in_mask)

    num_new_objects = new_masks.shape[0]
    if obj_ids_in_mask is not None:
        assert len(obj_ids_in_mask) == num_new_objects

    if self.use_obj_ptrs_in_encoder:
        existing_pointers = multiplex_state.demux(prev_output["obj_ptr"])

    # Turn the input masks directly into outputs for the selected objects.
    mask_output = self._use_mask_as_output(
        backbone_features=interactive_pix_feat,
        high_res_features=interactive_high_res_features,
        mask_inputs=new_masks,
        multiplex_state=multiplex_state,
        objects_in_mask=obj_idxs_in_mask,
    )

    # Like add_new_masks_to_existing_state: if the existing high-res masks are
    # at a different (H, W) than the interactive ones, resize them first so the
    # per-object merge below writes masks of matching shape.
    interactive_size = tuple(mask_output["high_res_masks"].shape[-2:])
    existing_high_res = prev_output.get("pred_masks_high_res")
    if (
        existing_high_res is not None
        and tuple(existing_high_res.shape[-2:]) != interactive_size
    ):
        prev_output["pred_masks_high_res"] = F.interpolate(
            existing_high_res,
            size=interactive_size,
            mode="bilinear",
            align_corners=False,
        )

    # Bring the new low-res masks to the spatial size of the existing ones.
    h, w = prev_output["pred_masks"].shape[-2:]
    mask_output["low_res_masks"] = F.interpolate(
        mask_output["low_res_masks"],
        size=(h, w),
        align_corners=False,
        mode="bilinear",
        antialias=True,
    )

    # Overwrite the entries of the reconditioned objects.
    _merge(
        prev_output, mask_output, "pred_masks", "low_res_masks", obj_idxs_in_mask
    )
    _merge(
        prev_output,
        mask_output,
        "pred_masks_high_res",
        "high_res_masks",
        obj_idxs_in_mask,
        strict=False,
    )
    _merge(
        prev_output,
        mask_output,
        "object_score_logits",
        "object_score_logits",
        obj_idxs_in_mask,
    )
    if self.use_memory_selection:
        mask_output["ious"] = mask_output["ious"].squeeze(-1)
        _merge(
            prev_output,
            mask_output,
            "iou_score",
            "ious",
            obj_idxs_in_mask,
        )

    if "input_masks" in prev_output:
        prev_output["input_masks"][obj_idxs_in_mask] = new_masks

    if self.use_obj_ptrs_in_encoder:
        # Replace pointers in demuxed space, then re-mux.
        new_pointers = mask_output["obj_ptr"].to(existing_pointers.dtype)
        existing_pointers[obj_idxs_in_mask] = new_pointers
        prev_output["obj_ptr"] = multiplex_state.mux(existing_pointers)

    # The reconditioned objects are now conditioned on their input masks.
    prev_output["conditioning_objects"].update(obj_idxs_in_mask)

    if add_mask_to_memory:
        assert (
            prev_output["pred_masks_high_res"].shape[0]
            == multiplex_state.total_valid_entries
        )
        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
            image=None,
            current_vision_feats=propagation_vision_feats,
            feat_sizes=propagation_feat_sizes,
            pred_masks_high_res=prev_output["pred_masks_high_res"],
            object_score_logits=prev_output["object_score_logits"],
            conditioning_objects=prev_output["conditioning_objects"],
            is_mask_from_pts=False,
            multiplex_state=multiplex_state,
        )
        prev_output["maskmem_features"] = maskmem_features
        prev_output["maskmem_pos_enc"] = maskmem_pos_enc
        if self.save_image_features:
            assert "image_features" in prev_output
            assert "image_pos_enc" in prev_output
|
|
def track_step(
    self,
    *,
    frame_idx,
    is_init_cond_frame,
    backbone_features_interactive,
    backbone_features_propagation,
    image,
    point_inputs,
    mask_inputs,
    gt_masks,
    frames_to_add_correction_pt,
    output_dict,
    num_frames,
    track_in_reverse=False,
    run_mem_encoder=True,
    prev_sam_mask_logits=None,
    multiplex_state: MultiplexState,
    objects_to_interact: Optional[list[int]] = None,
    new_object_masks: Optional[torch.Tensor] = None,
    new_object_idxs: Optional[list[int]] = None,
    new_object_ids: Optional[list[int]] = None,
    are_new_masks_from_pts: bool = False,
) -> StageOutput:
    """Track one frame, optionally appending new objects, and trim the result.

    When ``new_object_masks`` is given, the regular step runs without the memory
    encoder and also returns auxiliary features. The new objects are then
    appended via ``add_new_masks_to_existing_state``, which encodes memory
    (if ``run_mem_encoder``) once old and new objects are merged.

    Returns:
        The stage output after ``_trim_output_and_memory``.
    """
    adding_objects = new_object_masks is not None
    # Defer memory encoding until the new objects have been merged in.
    step_runs_mem_encoder = run_mem_encoder and not adding_objects

    current_out, aux_out = self._track_step_aux(
        frame_idx=frame_idx,
        is_init_cond_frame=is_init_cond_frame,
        backbone_features_interactive=backbone_features_interactive,
        backbone_features_propagation=backbone_features_propagation,
        image=image,
        point_inputs=point_inputs,
        mask_inputs=mask_inputs,
        gt_masks=gt_masks,
        frames_to_add_correction_pt=frames_to_add_correction_pt,
        output_dict=output_dict,
        num_frames=num_frames,
        track_in_reverse=track_in_reverse,
        run_mem_encoder=step_runs_mem_encoder,
        prev_sam_mask_logits=prev_sam_mask_logits,
        multiplex_state=multiplex_state,
        objects_to_interact=objects_to_interact,
        need_aux_output=adding_objects,
    )

    if adding_objects:
        assert new_object_idxs is not None
        self.add_new_masks_to_existing_state(
            interactive_pix_feat=aux_out["interactive_pix_feat"],
            interactive_high_res_features=aux_out["interactive_high_res_features"],
            propagation_vision_feats=aux_out["propagation_vision_feats"],
            propagation_feat_sizes=aux_out["propagation_feat_sizes"],
            new_masks=new_object_masks,
            obj_idxs_in_mask=new_object_idxs,
            obj_ids_in_mask=new_object_ids,
            prev_output=current_out,
            multiplex_state=multiplex_state,
            add_mask_to_memory=run_mem_encoder,
            are_masks_from_pts=are_new_masks_from_pts,
        )

    return self._trim_output_and_memory(
        frame_idx=frame_idx,
        output_dict=output_dict,
        current_out=current_out,
        memory_encoder_was_used=run_mem_encoder,
    )
|
|
def forward_tracking(
    self,
    backbone_out,
    input,
    return_dict=False,
    objects_to_interact: Optional[list[int]] = None,
):
    """Forward video tracking on each frame (and sample correction clicks).

    Frames are processed in ``init_cond_frames + frames_not_in_init_cond`` order.
    At transition points, the ground-truth masks of newly appearing objects are
    split off and passed to ``track_step`` as new objects.

    Args:
        backbone_out: prepared backbone/data dict. Keys read here include
            ``num_frames``, ``init_cond_frames``, ``frames_not_in_init_cond``,
            ``frames_to_add_correction_pt``, ``new_idx_per_transition``,
            ``valid_objects_prior_to_each_transition``, ``transition_points``,
            ``gt_masks_per_frame``, ``point_inputs_per_frame``,
            ``mask_inputs_per_frame`` and, for dynamic VOS evaluation,
            ``object_appearance_order``.
        input: batched datapoint; ``find_inputs[t].img_ids`` and ``img_batch``
            are read, plus ``find_metadatas[0].coco_image_id`` (its length is the
            number of objects) under dynamic VOS evaluation.
        return_dict: if True, return the raw ``output_dict`` (cond / non-cond
            outputs and the final ``multiplex_state``).
        objects_to_interact: forwarded unchanged to ``track_step``.

    Returns:
        ``output_dict`` if ``return_dict`` is True. Otherwise a list with one
        output per frame, with ``obj_ptr`` removed. Under dynamic VOS evaluation,
        frames that were never tracked get all-zero ``pred_masks``. Every
        frame's ``pred_masks`` is zero-padded for objects not yet (or never)
        seen and reordered so that row ``i`` corresponds to object ``i``.
    """
    img_feats_already_computed = (
        "interactive" in backbone_out or "sam2_backbone_out" in backbone_out
    )
    if img_feats_already_computed:
        # Features for all frames were precomputed; they are sliced per frame below.
        backbone_features = self._prepare_backbone_features(backbone_out)

    num_frames = backbone_out["num_frames"]
    init_cond_frames = backbone_out["init_cond_frames"]
    frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
    # Conditioning frames are tracked first, then all remaining frames.
    processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]

    new_idx_per_transition = backbone_out["new_idx_per_transition"]
    valid_objects_prior_to_each_transition = backbone_out[
        "valid_objects_prior_to_each_transition"
    ]
    transition_points = backbone_out["transition_points"]

    cond_frame_outputs: dict[int, StageOutput] = {}
    non_cond_frame_outputs: dict[int, StageOutput] = {}
    output_dict = {
        "cond_frame_outputs": cond_frame_outputs,
        "non_cond_frame_outputs": non_cond_frame_outputs,
    }
    # The multiplex state is sized by the objects present on the first processed frame.
    first_gt_masks = backbone_out["gt_masks_per_frame"][processing_order[0]]
    multiplex_state = self.multiplex_controller.get_state(
        first_gt_masks.shape[0],
        device=first_gt_masks.device,
        dtype=torch.float,
        random=self.training,
    )

    for stage_id in processing_order:
        img_ids = input.find_inputs[stage_id].img_ids
        # All queries of a stage must refer to the same image; keep a single id.
        assert all([img_id == img_ids[0] for img_id in img_ids])
        img_ids = torch.tensor(
            [img_ids[0]], device=img_ids.device, dtype=img_ids.dtype
        )

        if img_feats_already_computed:
            current_image = input.img_batch.tensors[img_ids]
            current_backbone_features = {}
            for neck_k, neck_out in backbone_features.items():
                current_backbone_features[neck_k] = {
                    "vision_feats": [
                        x[:, img_ids] for x in neck_out["vision_feats"]
                    ],
                    "vision_masks": [
                        x[img_ids] if x is not None else None
                        for x in neck_out["vision_masks"]
                    ],
                    "vision_pos_embeds": [
                        x[:, img_ids] for x in neck_out["vision_pos_embeds"]
                    ],
                    "feat_sizes": neck_out["feat_sizes"],
                }
        else:
            # Interactive features are only needed where prompts/new masks are used.
            need_interactive_out = (
                (stage_id in frames_to_add_correction_pt)
                or (stage_id in init_cond_frames)
                or (stage_id in transition_points)
            )
            (current_image, current_backbone_features) = (
                self._prepare_backbone_features_per_frame(
                    input.img_batch,
                    img_ids,
                    need_interactive_out=need_interactive_out,
                    need_propagation_out=True,
                )
            )

        gt_masks = backbone_out["gt_masks_per_frame"].get(stage_id, None)
        if stage_id in transition_points:
            assert gt_masks is not None
            # New objects occupy the trailing rows of this frame's gt masks.
            new_object_idxs = new_idx_per_transition[stage_id]
            assert sorted(new_object_idxs) == new_object_idxs
            assert new_object_idxs[0] == len(
                valid_objects_prior_to_each_transition[stage_id]
            ), (
                f"{new_object_idxs=}; {gt_masks.shape=}; {valid_objects_prior_to_each_transition[stage_id]=}"
            )
            assert new_object_idxs[-1] == (len(gt_masks) - 1), (
                f"{new_object_idxs=}; {gt_masks.shape=}"
            )
            new_object_masks = gt_masks[new_object_idxs]
            # Only the previously known objects are tracked from memory.
            gt_masks = gt_masks[: new_object_idxs[0]]
        else:
            new_object_masks = None
            new_object_idxs = None

        current_out = self.track_step(
            frame_idx=stage_id,
            is_init_cond_frame=stage_id in init_cond_frames,
            backbone_features_interactive=current_backbone_features.get(
                "interactive"
            ),
            backbone_features_propagation=current_backbone_features.get(
                "sam2_backbone_out"
            ),
            image=current_image,
            point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
            mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
            gt_masks=gt_masks,
            frames_to_add_correction_pt=frames_to_add_correction_pt,
            output_dict=output_dict,
            num_frames=num_frames,
            multiplex_state=multiplex_state,
            objects_to_interact=objects_to_interact,
            new_object_masks=new_object_masks,
            new_object_idxs=new_object_idxs,
        )

        add_output_as_cond_frame = (
            stage_id in init_cond_frames
            or (
                self.add_all_frames_to_correct_as_cond
                and stage_id in frames_to_add_correction_pt
            )
            or (
                self.add_all_transition_frames_as_cond
                and stage_id in transition_points
            )
        )
        if add_output_as_cond_frame:
            output_dict["cond_frame_outputs"][stage_id] = current_out
        else:
            output_dict["non_cond_frame_outputs"][stage_id] = current_out

    output_dict["multiplex_state"] = multiplex_state

    if return_dict:
        return output_dict

    all_frame_outputs = {}
    all_frame_outputs.update(output_dict["cond_frame_outputs"])
    all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
    if self.is_dynamic_vos_evaluation:
        # Frames that were not tracked (e.g. before the start frame) become None.
        all_frame_outputs = [all_frame_outputs.get(t) for t in range(num_frames)]
    else:
        all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]

    # Object pointers are internal; drop them from the returned outputs.
    all_frame_outputs = [
        {k: v for k, v in d.items() if k != "obj_ptr"} if d is not None else None
        for d in all_frame_outputs
    ]

    if self.is_dynamic_vos_evaluation:
        object_appearance_order = backbone_out["object_appearance_order"]
        num_objects = len(input.find_metadatas[0].coco_image_id)
        num_appeared = len(object_appearance_order)
        assert len(set(object_appearance_order)) == num_appeared, (
            f"duplicate objects in {object_appearance_order=}"
        )

        # Predictions are stored in order of first appearance. Build the inverse
        # permutation: row `obj_id` of the output takes row
        # `inverse_object_appearance_order[obj_id]` of the (padded) predictions.
        inverse_object_appearance_order: list[Optional[int]] = [None] * max(
            num_objects, num_appeared
        )
        for idx, obj_id in enumerate(object_appearance_order):
            inverse_object_appearance_order[obj_id] = idx
        # Objects that never appear (at any id, not only at the end of the id
        # range) are mapped to the zero rows appended after the tracked objects.
        next_padding_row = num_appeared
        for obj_id, row in enumerate(inverse_object_appearance_order):
            if row is None:
                inverse_object_appearance_order[obj_id] = next_padding_row
                next_padding_row += 1
        num_output_rows = len(inverse_object_appearance_order)

        # Any tracked frame serves as the reference for shape/dtype/device of the
        # all-zero outputs of untracked frames; prefer the latest one.
        reference_out = next(
            (d for d in reversed(all_frame_outputs) if d is not None), None
        )
        if reference_out is not None:
            reference_mask = reference_out["pred_masks"]
            for stage_i, frame_out in enumerate(all_frame_outputs):
                if frame_out is None:
                    all_frame_outputs[stage_i] = {
                        "pred_masks": torch.zeros(
                            (num_output_rows, *reference_mask.shape[1:]),
                            device=reference_mask.device,
                            dtype=reference_mask.dtype,
                        )
                    }
                    continue

                # Pad with zero masks for objects not (yet) tracked, then reorder
                # every frame -- including frames that need no padding -- into
                # object-id order.
                pred_mask = frame_out["pred_masks"]
                num_padding_rows = num_output_rows - pred_mask.shape[0]
                if num_padding_rows > 0:
                    pred_mask = torch.cat(
                        [
                            pred_mask,
                            pred_mask.new_zeros(
                                (num_padding_rows, *pred_mask.shape[1:])
                            ),
                        ],
                        dim=0,
                    )
                frame_out["pred_masks"] = pred_mask[inverse_object_appearance_order]

    return all_frame_outputs
|
|