import warnings
from collections import OrderedDict

import torch
import torch.nn.functional as F

from tqdm import tqdm

from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames


class SAM2VideoPredictor(SAM2Base):
    """The predictor class to handle user interactions and manage inference states."""

    def __init__(
        self,
        fill_hole_area=0,
        # whether to apply non-overlapping constraints on the output object masks
        non_overlap_masks=False,
        # whether to clear non-conditioning memory of the surrounding frames after
        # adding correction clicks (so outdated memories don't confuse the model)
        clear_non_cond_mem_around_input=False,
        # whether to add frames that receive correction clicks/masks during tracking
        # as conditioning frames
        add_all_frames_to_correct_as_cond=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.fill_hole_area = fill_hole_area
        self.non_overlap_masks = non_overlap_masks
        self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond

    @torch.inference_mode()
    def init_state(
        self,
        video_path,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False,
    ):
        """Initialize an inference state."""
        compute_device = self.device  # device of the model
        images, video_height, video_width = load_video_frames(
            video_path=video_path,
            image_size=self.image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            async_loading_frames=async_loading_frames,
            compute_device=compute_device,
        )
        inference_state = {}
        inference_state["images"] = images
        inference_state["num_frames"] = len(images)
        # whether to offload the video frames to CPU memory
        # (saves GPU memory with only a very small overhead)
        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
        # whether to offload the inference state to CPU memory
        # (saves GPU memory at the cost of a lower tracking fps)
        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
        # the original video height and width, used for resizing final output scores
        inference_state["video_height"] = video_height
        inference_state["video_width"] = video_width
        inference_state["device"] = compute_device
        if offload_state_to_cpu:
            inference_state["storage_device"] = torch.device("cpu")
        else:
            inference_state["storage_device"] = compute_device
        # inputs on each frame
        inference_state["point_inputs_per_obj"] = {}
        inference_state["mask_inputs_per_obj"] = {}
        # visual features on a small number of recently visited frames for quick interactions
        inference_state["cached_features"] = {}
        # values that don't change across frames (so we only need to hold one copy of them)
        inference_state["constants"] = {}
        # mapping between client-side object ids and model-side object indices
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        # the tracking results and states (memories, mask logits, etc.) of each object
        inference_state["output_dict_per_obj"] = {}
        # temporary outputs for frames the user is currently interacting with (clicks or
        # masks); these are merged into "output_dict_per_obj" before propagation starts
        inference_state["temp_output_dict_per_obj"] = {}
        # frames on which each object has already been tracked (and in which direction)
        inference_state["frames_tracked_per_obj"] = {}
        # Warm up the visual backbone and cache the image feature on frame 0
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state

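    # Illustrative usage sketch (not part of this class): how an inference state is
    # typically created. The config, checkpoint, and video paths below are assumptions.
    #
    #   from sam2.build_sam import build_sam2_video_predictor
    #
    #   predictor = build_sam2_video_predictor(
    #       "configs/sam2.1/sam2.1_hiera_l.yaml", "./checkpoints/sam2.1_hiera_large.pt"
    #   )
    #   inference_state = predictor.init_state(video_path="./videos/my_clip")
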
    @classmethod
    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
        """
        Load a pretrained model from the Hugging Face hub.

        Arguments:
          model_id (str): The Hugging Face repository ID.
          **kwargs: Additional arguments to pass to the model constructor.

        Returns:
          (SAM2VideoPredictor): The loaded model.
        """
        from sam2.build_sam import build_sam2_video_predictor_hf

        sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
        return sam_model

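    # Illustrative usage sketch: loading the predictor from the Hugging Face hub.
    # The repository id below is an example and may differ in your setup.
    #
    #   predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
    #   inference_state = predictor.init_state(video_path="./videos/my_clip")
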
    def _obj_id_to_idx(self, inference_state, obj_id):
        """Map client-side object id to model-side object index."""
        obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
        if obj_idx is not None:
            return obj_idx

        # This is a new object id not sent to the server before. We only allow adding
        # new objects *before* the tracking starts.
        allow_new_object = True
        if allow_new_object:
            # get the next object slot
            obj_idx = len(inference_state["obj_id_to_idx"])
            inference_state["obj_id_to_idx"][obj_id] = obj_idx
            inference_state["obj_idx_to_id"][obj_idx] = obj_id
            inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
            # set up input and output structures for this object
            inference_state["point_inputs_per_obj"][obj_idx] = {}
            inference_state["mask_inputs_per_obj"][obj_idx] = {}
            inference_state["output_dict_per_obj"][obj_idx] = {
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            inference_state["temp_output_dict_per_obj"][obj_idx] = {
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            inference_state["frames_tracked_per_obj"][obj_idx] = {}
            return obj_idx
        else:
            raise RuntimeError(
                f"Cannot add new object id {obj_id} after tracking starts. "
                f"All existing object ids: {inference_state['obj_ids']}. "
                f"Please call 'reset_state' to restart from scratch."
            )

    def _obj_idx_to_id(self, inference_state, obj_idx):
        """Map model-side object index to client-side object id."""
        return inference_state["obj_idx_to_id"][obj_idx]

    def _get_obj_num(self, inference_state):
        """Get the total number of unique object ids received so far in this session."""
        return len(inference_state["obj_idx_to_id"])

    @torch.inference_mode()
    def add_new_points_or_box(
        self,
        inference_state,
        frame_idx,
        obj_id,
        points=None,
        labels=None,
        clear_old_points=True,
        normalize_coords=True,
        box=None,
    ):
        """Add new points or a box prompt to a frame."""
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

        if (points is not None) != (labels is not None):
            raise ValueError("points and labels must be provided together")
        if points is None and box is None:
            raise ValueError("at least one of points or box must be provided as input")

        if points is None:
            points = torch.zeros(0, 2, dtype=torch.float32)
        elif not isinstance(points, torch.Tensor):
            points = torch.tensor(points, dtype=torch.float32)
        if labels is None:
            labels = torch.zeros(0, dtype=torch.int32)
        elif not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels, dtype=torch.int32)
        if points.dim() == 2:
            points = points.unsqueeze(0)  # add batch dimension
        if labels.dim() == 1:
            labels = labels.unsqueeze(0)  # add batch dimension

        # If `box` is provided, we add it as the first two points with labels 2 and 3
        # along with the user-provided points (consistent with how SAM 2 is trained).
        if box is not None:
            if not clear_old_points:
                raise ValueError(
                    "cannot add box without clearing old points, since "
                    "box prompt must be provided before any point prompt "
                    "(please use clear_old_points=True instead)"
                )
            if not isinstance(box, torch.Tensor):
                box = torch.tensor(box, dtype=torch.float32, device=points.device)
            box_coords = box.reshape(1, 2, 2)
            box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
            box_labels = box_labels.reshape(1, 2)
            points = torch.cat([box_coords, points], dim=1)
            labels = torch.cat([box_labels, labels], dim=1)

        if normalize_coords:
            video_H = inference_state["video_height"]
            video_W = inference_state["video_width"]
            points = points / torch.tensor([video_W, video_H]).to(points.device)
        # scale the (normalized) coordinates by the model's internal image size
        points = points * self.image_size
        points = points.to(inference_state["device"])
        labels = labels.to(inference_state["device"])

        if not clear_old_points:
            point_inputs = point_inputs_per_frame.get(frame_idx, None)
        else:
            point_inputs = None
        point_inputs = concat_points(point_inputs, points, labels)

        point_inputs_per_frame[frame_idx] = point_inputs
        mask_inputs_per_frame.pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the input points are used to generate segments on this frame
        # without using any memory from other frames, like in SAM. Otherwise (if it has been
        # tracked), the input points will be used to correct the already tracked masks.
        obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
        is_init_cond_frame = frame_idx not in obj_frames_tracked
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = obj_frames_tracked[frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        # Get any previously predicted mask logits on this object and feed them along
        # with the new clicks into the SAM mask decoder.
        prev_sam_mask_logits = None
        # look up the temporary output dict first, which contains the most recent output
        # (if not found, then look up conditioning and non-conditioning frame outputs)
        prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
        if prev_out is None:
            prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
            if prev_out is None:
                prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

        if prev_out is not None and prev_out["pred_masks"] is not None:
            device = inference_state["device"]
            prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
            # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
            prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=point_inputs,
            mask_inputs=None,
            reverse=reverse,
            # Skip the memory encoder when adding clicks or mask. We execute the memory
            # encoder at the beginning of `propagate_in_video` (after the user finalizes
            # their clicks). This allows us to enforce non-overlapping constraints on all
            # objects before encoding them into memory.
            run_mem_encoder=False,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out

        # Resize the output mask to the original video resolution
        obj_ids = inference_state["obj_ids"]
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks

    def add_new_points(self, *args, **kwargs):
        """Deprecated method. Please use `add_new_points_or_box` instead."""
        return self.add_new_points_or_box(*args, **kwargs)

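    # Illustrative usage sketch for point and box prompts (frame index, object ids,
    # and coordinates are made-up values in original-video pixel coordinates):
    #
    #   # a positive click (label 1) and a negative click (label 0) on frame 0
    #   frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
    #       inference_state, frame_idx=0, obj_id=1,
    #       points=[[210.0, 350.0], [250.0, 220.0]], labels=[1, 0],
    #   )
    #   # or a box prompt in (x_min, y_min, x_max, y_max) format
    #   frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
    #       inference_state, frame_idx=0, obj_id=2, box=[300.0, 0.0, 500.0, 400.0],
    #   )
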
    @torch.inference_mode()
    def add_new_mask(
        self,
        inference_state,
        frame_idx,
        obj_id,
        mask,
    ):
        """Add a new mask to a frame."""
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

        if not isinstance(mask, torch.Tensor):
            mask = torch.tensor(mask, dtype=torch.bool)
        assert mask.dim() == 2
        mask_H, mask_W = mask.shape
        mask_inputs_orig = mask[None, None]  # add batch and channel dimension
        mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])

        # resize the mask if it doesn't match the model's input image size
        if mask_H != self.image_size or mask_W != self.image_size:
            mask_inputs = torch.nn.functional.interpolate(
                mask_inputs_orig,
                size=(self.image_size, self.image_size),
                align_corners=False,
                mode="bilinear",
                antialias=True,  # use antialias for downsampling
            )
            mask_inputs = (mask_inputs >= 0.5).float()
        else:
            mask_inputs = mask_inputs_orig

        mask_inputs_per_frame[frame_idx] = mask_inputs
        point_inputs_per_frame.pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the input mask is used to generate segments on this frame
        # without using any memory from other frames, like in SAM. Otherwise (if it has been
        # tracked), the input mask will be used to correct the already tracked masks.
        obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
        is_init_cond_frame = frame_idx not in obj_frames_tracked
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = obj_frames_tracked[frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=None,
            mask_inputs=mask_inputs,
            reverse=reverse,
            # Skip the memory encoder when adding clicks or mask. We execute the memory
            # encoder at the beginning of `propagate_in_video` (after the user finalizes
            # their clicks). This allows us to enforce non-overlapping constraints on all
            # objects before encoding them into memory.
            run_mem_encoder=False,
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out

        # Resize the output mask to the original video resolution
        obj_ids = inference_state["obj_ids"]
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks

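    # Illustrative usage sketch: prompting with a binary mask instead of clicks
    # (the mask below is a made-up H x W boolean array at the original video resolution):
    #
    #   import numpy as np
    #   mask = np.zeros(
    #       (inference_state["video_height"], inference_state["video_width"]), dtype=bool
    #   )
    #   mask[100:200, 150:300] = True
    #   frame_idx, obj_ids, masks = predictor.add_new_mask(
    #       inference_state, frame_idx=0, obj_id=1, mask=mask
    #   )
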
    def _get_orig_video_res_output(self, inference_state, any_res_masks):
        """
        Resize the object scores to the original video resolution (video_res_masks)
        and apply non-overlapping constraints for final output.
        """
        device = inference_state["device"]
        video_H = inference_state["video_height"]
        video_W = inference_state["video_width"]
        any_res_masks = any_res_masks.to(device, non_blocking=True)
        if any_res_masks.shape[-2:] == (video_H, video_W):
            video_res_masks = any_res_masks
        else:
            video_res_masks = torch.nn.functional.interpolate(
                any_res_masks,
                size=(video_H, video_W),
                mode="bilinear",
                align_corners=False,
            )
        if self.non_overlap_masks:
            video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
        return any_res_masks, video_res_masks

    def _consolidate_temp_output_across_obj(
        self,
        inference_state,
        frame_idx,
        is_cond,
        consolidate_at_video_res=False,
    ):
        """
        Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
        a frame into a single output for all objects, including
        1) filling any missing objects either from `output_dict_per_obj` (if they exist in
           `output_dict_per_obj` for this frame) or leaving them as placeholder values
           (if they don't exist in `output_dict_per_obj` for this frame);
        2) if specified, rerunning the memory encoder after applying non-overlapping
           constraints on the object scores.
        """
        batch_size = self._get_obj_num(inference_state)
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
        # Optionally, we allow consolidating the temporary outputs at the original
        # video resolution (to provide a better editing experience for mask prompts).
        if consolidate_at_video_res:
            consolidated_H = inference_state["video_height"]
            consolidated_W = inference_state["video_width"]
            consolidated_mask_key = "pred_masks_video_res"
        else:
            consolidated_H = consolidated_W = self.image_size // 4
            consolidated_mask_key = "pred_masks"

        # Initialize `consolidated_out` with placeholder values. The mask scores default
        # to NO_OBJ_SCORE (a dummy score below the threshold) for objects that have no
        # output on this frame; memories are only encoded later in
        # `propagate_in_video_preflight`.
        consolidated_out = {
            consolidated_mask_key: torch.full(
                size=(batch_size, 1, consolidated_H, consolidated_W),
                fill_value=NO_OBJ_SCORE,
                dtype=torch.float32,
                device=inference_state["storage_device"],
            ),
        }
        for obj_idx in range(batch_size):
            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            out = obj_temp_output_dict[storage_key].get(frame_idx, None)
            # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
            # fall back to its previous output in "output_dict_per_obj" (checking both
            # "cond_frame_outputs" and "non_cond_frame_outputs").
            if out is None:
                out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
            if out is None:
                out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
            # If the object doesn't appear in "output_dict_per_obj" either, skip it and
            # leave its mask scores at the placeholder value above.
            if out is None:
                continue
            # Add the object's output mask to the consolidated output mask
            obj_mask = out["pred_masks"]
            consolidated_pred_masks = consolidated_out[consolidated_mask_key]
            if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
                consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
            else:
                # Resize first if the object mask has a different resolution
                resized_obj_mask = torch.nn.functional.interpolate(
                    obj_mask,
                    size=consolidated_pred_masks.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                )
                consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask

        return consolidated_out

    @torch.inference_mode()
    def propagate_in_video_preflight(self, inference_state):
        """Prepare inference_state and consolidate temporary outputs before tracking."""
        # Check and make sure that at least one object has received inputs.
        batch_size = self._get_obj_num(inference_state)
        if batch_size == 0:
            raise RuntimeError(
                "No input points or masks are provided for any object; please add inputs first."
            )

        # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
        # add them into "output_dict_per_obj".
        for obj_idx in range(batch_size):
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
            for is_cond in [False, True]:
                # Separately consolidate conditioning and non-conditioning temp outputs
                storage_key = (
                    "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
                )
                # Find all the frames that contain temporary outputs for this object
                # (these should be the frames that have just received clicks or mask
                # inputs via `add_new_points_or_box` or `add_new_mask`)
                for frame_idx, out in obj_temp_output_dict[storage_key].items():
                    # Run the memory encoder on the temporary outputs (if the memory
                    # features are missing)
                    if out["maskmem_features"] is None:
                        high_res_masks = torch.nn.functional.interpolate(
                            out["pred_masks"].to(inference_state["device"]),
                            size=(self.image_size, self.image_size),
                            mode="bilinear",
                            align_corners=False,
                        )
                        maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
                            inference_state=inference_state,
                            frame_idx=frame_idx,
                            batch_size=1,  # run on the slice of a single object
                            high_res_masks=high_res_masks,
                            object_score_logits=out["object_score_logits"],
                            # these frames are what the user interacted with
                            is_mask_from_pts=True,
                        )
                        out["maskmem_features"] = maskmem_features
                        out["maskmem_pos_enc"] = maskmem_pos_enc

                    obj_output_dict[storage_key][frame_idx] = out
                    if self.clear_non_cond_mem_around_input:
                        # clear non-conditioning memory of the surrounding frames
                        self._clear_obj_non_cond_mem_around_input(
                            inference_state, frame_idx, obj_idx
                        )

                # clear temporary outputs in `temp_output_dict_per_obj`
                obj_temp_output_dict[storage_key].clear()

            # check and make sure that every object has received input points or masks
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            if len(obj_output_dict["cond_frame_outputs"]) == 0:
                obj_id = self._obj_idx_to_id(inference_state, obj_idx)
                raise RuntimeError(
                    f"No input points or masks are provided for object id {obj_id}; please add inputs first."
                )
            # edge case: if an output is added to "cond_frame_outputs", remove any prior
            # output on the same frame from "non_cond_frame_outputs"
            for frame_idx in obj_output_dict["cond_frame_outputs"]:
                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)

    @torch.inference_mode()
    def propagate_in_video(
        self,
        inference_state,
        start_frame_idx=None,
        max_frame_num_to_track=None,
        reverse=False,
    ):
        """Propagate the input points across frames to track in the entire video."""
        self.propagate_in_video_preflight(inference_state)

        obj_ids = inference_state["obj_ids"]
        num_frames = inference_state["num_frames"]
        batch_size = self._get_obj_num(inference_state)

        # set start index, end index, and processing order
        if start_frame_idx is None:
            # default: start from the earliest frame with input points
            start_frame_idx = min(
                t
                for obj_output_dict in inference_state["output_dict_per_obj"].values()
                for t in obj_output_dict["cond_frame_outputs"]
            )
        if max_frame_num_to_track is None:
            # default: track all the frames in the video
            max_frame_num_to_track = num_frames
        if reverse:
            end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
            if start_frame_idx > 0:
                processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
            else:
                processing_order = []  # skip reverse tracking if starting from frame 0
        else:
            end_frame_idx = min(
                start_frame_idx + max_frame_num_to_track, num_frames - 1
            )
            processing_order = range(start_frame_idx, end_frame_idx + 1)

        for frame_idx in tqdm(processing_order, desc="propagate in video"):
            pred_masks_per_obj = [None] * batch_size
            for obj_idx in range(batch_size):
                obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
                # We skip those frames already in consolidated outputs (these are frames
                # that received input clicks or mask). Note that we cannot directly run
                # batched forward on them via `_run_single_frame_inference` because the
                # number of clicks on each object might be different.
                if frame_idx in obj_output_dict["cond_frame_outputs"]:
                    storage_key = "cond_frame_outputs"
                    current_out = obj_output_dict[storage_key][frame_idx]
                    device = inference_state["device"]
                    pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
                    if self.clear_non_cond_mem_around_input:
                        # clear non-conditioning memory of the surrounding frames
                        self._clear_obj_non_cond_mem_around_input(
                            inference_state, frame_idx, obj_idx
                        )
                else:
                    storage_key = "non_cond_frame_outputs"
                    current_out, pred_masks = self._run_single_frame_inference(
                        inference_state=inference_state,
                        output_dict=obj_output_dict,
                        frame_idx=frame_idx,
                        batch_size=1,  # run on the slice of a single object
                        is_init_cond_frame=False,
                        point_inputs=None,
                        mask_inputs=None,
                        reverse=reverse,
                        run_mem_encoder=True,
                    )
                    obj_output_dict[storage_key][frame_idx] = current_out

                inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
                    "reverse": reverse
                }
                pred_masks_per_obj[obj_idx] = pred_masks

            # Resize the output masks to the original video resolution (we use the mask
            # scores directly on GPU to avoid any CPU round trip in between)
            if len(pred_masks_per_obj) > 1:
                all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
            else:
                all_pred_masks = pred_masks_per_obj[0]
            _, video_res_masks = self._get_orig_video_res_output(
                inference_state, all_pred_masks
            )
            yield frame_idx, obj_ids, video_res_masks

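    # Illustrative usage sketch: `propagate_in_video` is a generator, so results are
    # consumed frame by frame (thresholding the mask logits at 0 is one common choice):
    #
    #   video_segments = {}
    #   for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(inference_state):
    #       video_segments[frame_idx] = {
    #           obj_id: (video_res_masks[i] > 0.0).cpu().numpy()
    #           for i, obj_id in enumerate(obj_ids)
    #       }
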
    @torch.inference_mode()
    def clear_all_prompts_in_frame(
        self, inference_state, frame_idx, obj_id, need_output=True
    ):
        """Remove all input points or mask in a specific frame for a given object."""
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)

        # Clear the conditioning information on the given frame
        inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
        inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)

        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
        temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
        temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)

        # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
        if out is not None:
            # The frame is no longer a conditioning frame since it's not receiving inputs,
            # so we "downgrade" its output (if it exists) to a non-conditioning frame output.
            obj_output_dict["non_cond_frame_outputs"][frame_idx] = out
        inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)

        if not need_output:
            return
        # Finally, output the updated masks per object (after removing the inputs above)
        obj_ids = inference_state["obj_ids"]
        is_cond = any(
            frame_idx in obj_temp_output_dict["cond_frame_outputs"]
            for obj_temp_output_dict in temp_output_dict_per_obj.values()
        )
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks

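    # Illustrative usage sketch: dropping all prompts previously placed on frame 0 for
    # object id 1 (made-up values), e.g. when a user undoes their clicks on that frame:
    #
    #   predictor.clear_all_prompts_in_frame(inference_state, frame_idx=0, obj_id=1)
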
    @torch.inference_mode()
    def reset_state(self, inference_state):
        """Remove all input points or mask in all frames throughout the video."""
        self._reset_tracking_results(inference_state)
        # Remove all object ids
        inference_state["obj_id_to_idx"].clear()
        inference_state["obj_idx_to_id"].clear()
        inference_state["obj_ids"].clear()
        inference_state["point_inputs_per_obj"].clear()
        inference_state["mask_inputs_per_obj"].clear()
        inference_state["output_dict_per_obj"].clear()
        inference_state["temp_output_dict_per_obj"].clear()
        inference_state["frames_tracked_per_obj"].clear()

    def _reset_tracking_results(self, inference_state):
        """Reset all tracking inputs and results across the videos."""
        for v in inference_state["point_inputs_per_obj"].values():
            v.clear()
        for v in inference_state["mask_inputs_per_obj"].values():
            v.clear()
        for v in inference_state["output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        for v in inference_state["temp_output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        for v in inference_state["frames_tracked_per_obj"].values():
            v.clear()

    def _get_image_feature(self, inference_state, frame_idx, batch_size):
        """Compute the image features on a given frame."""
        # Look up in the cache first
        image, backbone_out = inference_state["cached_features"].get(
            frame_idx, (None, None)
        )
        if backbone_out is None:
            # Cache miss -- we will run inference on a single image
            device = inference_state["device"]
            image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
            backbone_out = self.forward_image(image)
            # Cache the most recent frame's feature (for repeated interactions with
            # the same frame; an LRU cache could hold more frames in the future)
            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}

        # expand the features to have the same dimension as the number of objects
        expanded_image = image.expand(batch_size, -1, -1, -1)
        expanded_backbone_out = {
            "backbone_fpn": backbone_out["backbone_fpn"].copy(),
            "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
        }
        for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
            expanded_backbone_out["backbone_fpn"][i] = feat.expand(
                batch_size, -1, -1, -1
            )
        for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
            pos = pos.expand(batch_size, -1, -1, -1)
            expanded_backbone_out["vision_pos_enc"][i] = pos

        features = self._prepare_backbone_features(expanded_backbone_out)
        features = (expanded_image,) + features
        return features

    def _run_single_frame_inference(
        self,
        inference_state,
        output_dict,
        frame_idx,
        batch_size,
        is_init_cond_frame,
        point_inputs,
        mask_inputs,
        reverse,
        run_mem_encoder,
        prev_sam_mask_logits=None,
    ):
        """Run tracking on a single frame based on current inputs and previous memory."""
        # Retrieve correct image features
        (
            _,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)

        # point and mask should not appear as input simultaneously on the same frame
        assert point_inputs is None or mask_inputs is None
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            output_dict=output_dict,
            num_frames=inference_state["num_frames"],
            track_in_reverse=reverse,
            run_mem_encoder=run_mem_encoder,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )

        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = current_out["maskmem_features"]
        if maskmem_features is not None:
            maskmem_features = maskmem_features.to(torch.bfloat16)
            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        pred_masks_gpu = current_out["pred_masks"]
        # potentially fill holes in the predicted masks
        if self.fill_hole_area > 0:
            pred_masks_gpu = fill_holes_in_mask_scores(
                pred_masks_gpu, self.fill_hole_area
            )
        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
        # object pointers are small tensors, so we always keep them on GPU for fast access
        obj_ptr = current_out["obj_ptr"]
        object_score_logits = current_out["object_score_logits"]
        # make a compact version of this frame's output to reduce the state size
        compact_current_out = {
            "maskmem_features": maskmem_features,
            "maskmem_pos_enc": maskmem_pos_enc,
            "pred_masks": pred_masks,
            "obj_ptr": obj_ptr,
            "object_score_logits": object_score_logits,
        }
        return compact_current_out, pred_masks_gpu

    def _run_memory_encoder(
        self,
        inference_state,
        frame_idx,
        batch_size,
        high_res_masks,
        object_score_logits,
        is_mask_from_pts,
    ):
        """
        Run the memory encoder on `high_res_masks`. This is usually done after applying
        non-overlapping constraints to the object scores. Since their scores changed, the
        memory also needs to be recomputed with the memory encoder.
        """
        # Retrieve correct image features
        _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
            inference_state, frame_idx, batch_size
        )
        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
            current_vision_feats=current_vision_feats,
            feat_sizes=feat_sizes,
            pred_masks_high_res=high_res_masks,
            object_score_logits=object_score_logits,
            is_mask_from_pts=is_mask_from_pts,
        )

        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = maskmem_features.to(torch.bfloat16)
        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(
            inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
        )
        return maskmem_features, maskmem_pos_enc

    def _get_maskmem_pos_enc(self, inference_state, current_out):
        """
        `maskmem_pos_enc` is the same across frames and objects, so we cache it as
        a constant in the inference session to reduce session storage size.
        """
        model_constants = inference_state["constants"]
        # "out_maskmem_pos_enc" should be either a list of tensors or None
        out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
        if out_maskmem_pos_enc is not None:
            if "maskmem_pos_enc" not in model_constants:
                assert isinstance(out_maskmem_pos_enc, list)
                # only take the slice for one object, since it's the same across objects
                maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
            else:
                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
            # expand the cached maskmem_pos_enc to the actual batch size
            batch_size = out_maskmem_pos_enc[0].size(0)
            expanded_maskmem_pos_enc = [
                x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
            ]
        else:
            expanded_maskmem_pos_enc = None
        return expanded_maskmem_pos_enc

    @torch.inference_mode()
    def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
        """
        Remove an object id from the tracking state. If strict is True, we check whether
        the object id actually exists and raise an error if it doesn't exist.
        """
        old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
        updated_frames = []
        # Check whether this object id to remove actually exists and possibly raise an error.
        if old_obj_idx_to_rm is None:
            if not strict:
                return inference_state["obj_ids"], updated_frames
            raise RuntimeError(
                f"Cannot remove object id {obj_id} as it doesn't exist. "
                f"All existing object ids: {inference_state['obj_ids']}."
            )

        # If this is the only remaining object id, we simply reset the state.
        if len(inference_state["obj_id_to_idx"]) == 1:
            self.reset_state(inference_state)
            return inference_state["obj_ids"], updated_frames

        # There are still remaining objects after removing this object id. In this case,
        # we need to delete the object's storage from the inference state tensors.
        # Step 0: clear the inputs on those frames where this object id has point or mask
        # input (note that this step is required as it might downgrade conditioning frames
        # to non-conditioning ones)
        obj_input_frames_inds = set()
        obj_input_frames_inds.update(
            inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
        )
        obj_input_frames_inds.update(
            inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
        )
        for frame_idx in obj_input_frames_inds:
            self.clear_all_prompts_in_frame(
                inference_state, frame_idx, obj_id, need_output=False
            )

        # Step 1: update the object id mapping (note that it must be done after Step 0,
        # since Step 0 still requires the old object id mappings in inference_state)
        old_obj_ids = inference_state["obj_ids"]
        old_obj_inds = list(range(len(old_obj_ids)))
        remain_old_obj_inds = old_obj_inds.copy()
        remain_old_obj_inds.remove(old_obj_idx_to_rm)
        new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
        new_obj_inds = list(range(len(new_obj_ids)))
        # build new mappings between object ids and indices
        old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
        inference_state["obj_id_to_idx"] = OrderedDict(zip(new_obj_ids, new_obj_inds))
        inference_state["obj_idx_to_id"] = OrderedDict(zip(new_obj_inds, new_obj_ids))
        inference_state["obj_ids"] = new_obj_ids

        # Step 2: for per-object storage, shift the object indices in the dict keys
        def _map_keys(container):
            new_kvs = []
            for k in old_obj_inds:
                v = container.pop(k)
                if k in old_idx_to_new_idx:
                    new_kvs.append((old_idx_to_new_idx[k], v))
            container.update(new_kvs)

        _map_keys(inference_state["point_inputs_per_obj"])
        _map_keys(inference_state["mask_inputs_per_obj"])
        _map_keys(inference_state["output_dict_per_obj"])
        _map_keys(inference_state["temp_output_dict_per_obj"])
        _map_keys(inference_state["frames_tracked_per_obj"])

        # Step 3: further collect the outputs on those frames in `obj_input_frames_inds`,
        # which could show updated masks for objects previously occluded by the removed one
        if need_output:
            temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
            for frame_idx in obj_input_frames_inds:
                is_cond = any(
                    frame_idx in obj_temp_output_dict["cond_frame_outputs"]
                    for obj_temp_output_dict in temp_output_dict_per_obj.values()
                )
                consolidated_out = self._consolidate_temp_output_across_obj(
                    inference_state,
                    frame_idx,
                    is_cond=is_cond,
                    consolidate_at_video_res=True,
                )
                _, video_res_masks = self._get_orig_video_res_output(
                    inference_state, consolidated_out["pred_masks_video_res"]
                )
                updated_frames.append((frame_idx, video_res_masks))

        return inference_state["obj_ids"], updated_frames

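    # Illustrative usage sketch: removing an object and refreshing the affected frames
    # (obj_id 2 is a made-up id; `strict=True` raises if the id was never added):
    #
    #   remaining_obj_ids, updated_frames = predictor.remove_object(
    #       inference_state, obj_id=2, strict=True
    #   )
    #   for frame_idx, video_res_masks in updated_frames:
    #       ...  # update any cached visualization for this frame
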
    def _clear_obj_non_cond_mem_around_input(self, inference_state, frame_idx, obj_idx):
        """
        Remove the non-conditioning memory around the input frame. When users provide
        correction clicks, the surrounding frames' non-conditioning memories can still
        contain outdated object appearance information and could confuse the model.

        This method clears those non-conditioning memories surrounding the interacted
        frame to avoid giving the model both old and new information about the object.
        """
        r = self.memory_temporal_stride_for_eval
        frame_idx_begin = frame_idx - r * self.num_maskmem
        frame_idx_end = frame_idx + r * self.num_maskmem
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
        for t in range(frame_idx_begin, frame_idx_end + 1):
            non_cond_frame_outputs.pop(t, None)


class SAM2VideoPredictorVOS(SAM2VideoPredictor):
    """Optimized for the VOS setting"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._compile_all_components()

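    # Note: this subclass is typically selected through the builder (e.g. a
    # `vos_optimized=True` flag in `build_sam2_video_predictor`, if available in your
    # version) rather than being instantiated directly.
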
    def _compile_all_components(self):
        print("Compiling all components for VOS setting. First time may be very slow.")
        self.memory_encoder.forward = torch.compile(
            self.memory_encoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,
        )

        self.memory_attention.forward = torch.compile(
            self.memory_attention.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=True,  # the number of memories varies across frames
        )

        self.sam_prompt_encoder.forward = torch.compile(
            self.sam_prompt_encoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,
        )

        self.sam_mask_decoder.forward = torch.compile(
            self.sam_mask_decoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,
        )

    def forward_image(self, img_batch: torch.Tensor):
        """
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
        cloning the backbone features and pos encoding to enable compilation.
        """
        backbone_out = self.image_encoder(img_batch)
        if self.use_high_res_features_in_sam:
            # precompute projected level 0 and level 1 features in the SAM decoder
            # to avoid running them again on every SAM click
            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
                backbone_out["backbone_fpn"][0]
            )
            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
                backbone_out["backbone_fpn"][1]
            )
        # Clone the feature maps and position encodings to enable compilation
        for i in range(len(backbone_out["backbone_fpn"])):
            backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
            backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][
                i
            ].clone()
        return backbone_out

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
    ):
        """
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
        cloning the outputs of prompt_encoder and mask_decoder to enable compilation.
        """
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provided, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into a low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM prompt encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add a learned
            # `no_mask_embed` to indicate no mask input in this case).
            sam_mask_prompt = None

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )
        # Clone image_pe and the outputs of sam_prompt_encoder to enable compilation
        sparse_embeddings = sparse_embeddings.clone()
        dense_embeddings = dense_embeddings.clone()
        image_pe = self.sam_prompt_encoder.get_dense_pe().clone()
        (
            low_res_multimasks,
            ious,
            sam_output_tokens,
            object_score_logits,
        ) = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,  # the image is already batched
            high_res_features=high_res_features,
        )
        # Clone the outputs of sam_mask_decoder to enable compilation
        low_res_multimasks = low_res_multimasks.clone()
        ious = ious.clone()
        sam_output_tokens = sam_output_tokens.clone()
        object_score_logits = object_score_logits.clone()

        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0

            # The mask used for spatial memories is always a *hard* choice between obj
            # and no obj, consistent with the actual mask prediction
            low_res_multimasks = torch.where(
                is_obj_appearing[:, None, None],
                low_res_multimasks,
                NO_OBJ_SCORE,
            )

        # upcast the masks to float32 before interpolation (interpolation on
        # low-precision dtypes can be unsupported or lossy)
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # take the best mask prediction (with the highest IoU estimation)
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        # Extract the object pointer from the SAM output token (with occlusion handling)
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Allow a *soft* no-obj pointer, unlike for masks
            if self.soft_no_obj_ptr:
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def _encode_new_memory(
        self,
        current_vision_feats,
        feat_sizes,
        pred_masks_high_res,
        object_score_logits,
        is_mask_from_pts,
    ):
        """
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
        cloning the memories and their pos enc to enable compilation.
        """
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        # top-level feature, (HW)BC => BCHW
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        if self.non_overlap_masks_for_mem_enc and not self.training:
            # optionally, apply non-overlapping constraints to the masks (it's applied
            # in the batch dimension and should only be used during eval, where all
            # the objects come from the same video under batch size 1)
            pred_masks_high_res = self._apply_non_overlapping_constraints(
                pred_masks_high_res
            )
        # optionally binarize the mask logits coming from point-based interactions
        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
        if binarize and not self.training:
            mask_for_mem = (pred_masks_high_res > 0).float()
        else:
            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
        # apply scale and bias terms to the sigmoid probabilities
        if self.sigmoid_scale_for_mem_enc != 1.0:
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
        maskmem_out = self.memory_encoder(
            pix_feat, mask_for_mem, skip_mask_sigmoid=True  # sigmoid already applied
        )
        # Clone the memory features and their pos enc to enable compilation
        maskmem_features = maskmem_out["vision_features"].clone()
        maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
        # add a no-object embedding to the spatial memory to indicate that the frame
        # is predicted to be occluded (i.e. no object is appearing in the frame)
        if self.no_obj_embed_spatial is not None:
            is_obj_appearing = (object_score_logits > 0).float()
            maskmem_features += (
                1 - is_obj_appearing[..., None, None]
            ) * self.no_obj_embed_spatial[..., None, None].expand(
                *maskmem_features.shape
            )

        return maskmem_features, maskmem_pos_enc