|
|
|
|
|
|
|
|
|
|
|
|
| import warnings
|
| from collections import OrderedDict
|
|
|
| import torch
|
| import torch.nn.functional as F
|
|
|
| from tqdm import tqdm
|
|
|
| from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
|
| from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
|
|
|
|
|
| class SAM2VideoPredictor(SAM2Base):
|
| """The predictor class to handle user interactions and manage inference states."""
|
|
|
def __init__(
    self,
    fill_hole_area=0,
    non_overlap_masks=False,
    clear_non_cond_mem_around_input=False,
    add_all_frames_to_correct_as_cond=False,
    **kwargs,
):
    """Set up the base model and remember the interactive-inference options.

    Args:
        fill_hole_area: stored as-is; `_run_single_frame_inference` passes it
            to `fill_holes_in_mask_scores` when it is greater than zero.
        non_overlap_masks: when true, `_get_orig_video_res_output` applies
            non-overlapping constraints to the masks it returns.
        clear_non_cond_mem_around_input: when true, non-conditioning memory
            around input frames is cleared during preflight and propagation.
        add_all_frames_to_correct_as_cond: when true, frames that receive
            new inputs are always stored as conditioning frames.
        **kwargs: forwarded unchanged to the parent constructor.
    """
    super().__init__(**kwargs)
    # Post-processing of the predicted masks.
    self.fill_hole_area, self.non_overlap_masks = fill_hole_area, non_overlap_masks
    # Memory handling around frames that receive user inputs.
    self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
    self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
|
|
@torch.inference_mode()
def init_state(
    self,
    video_path,
    offload_video_to_cpu=False,
    offload_state_to_cpu=False,
    async_loading_frames=False,
):
    """Load the video frames and create a fresh inference-state dictionary.

    Args:
        video_path: forwarded to `load_video_frames`.
        offload_video_to_cpu: forwarded to `load_video_frames` and recorded
            in the state.
        offload_state_to_cpu: when true, the state's "storage_device" is the
            CPU; otherwise it is the model's compute device.
        async_loading_frames: forwarded to `load_video_frames`.

    Returns:
        The inference-state dict used by every other method of this class.
    """
    device = self.device
    frames, height, width = load_video_frames(
        video_path=video_path,
        image_size=self.image_size,
        offload_video_to_cpu=offload_video_to_cpu,
        async_loading_frames=async_loading_frames,
        compute_device=device,
    )
    storage_device = torch.device("cpu") if offload_state_to_cpu else device
    inference_state = {
        # video data and where tensors live
        "images": frames,
        "num_frames": len(frames),
        "offload_video_to_cpu": offload_video_to_cpu,
        "offload_state_to_cpu": offload_state_to_cpu,
        "video_height": height,
        "video_width": width,
        "device": device,
        "storage_device": storage_device,
        # user prompts, keyed by object index and then by frame index
        "point_inputs_per_obj": {},
        "mask_inputs_per_obj": {},
        # image features of recently visited frames
        "cached_features": {},
        # values shared by all frames (see `_get_maskmem_pos_enc`)
        "constants": {},
        # client-side object id <-> model-side object index
        "obj_id_to_idx": OrderedDict(),
        "obj_idx_to_id": OrderedDict(),
        "obj_ids": [],
        # committed and not-yet-committed per-object outputs
        "output_dict_per_obj": {},
        "temp_output_dict_per_obj": {},
        # per object: frame index -> {"reverse": bool} for frames already tracked
        "frames_tracked_per_obj": {},
    }
    # Compute (and cache) the features of the first frame right away.
    self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
    return inference_state
|
|
|
@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
    """Build a video predictor from a Hugging Face hub checkpoint.

    Arguments:
      model_id (str): The Hugging Face repository ID.
      **kwargs: Additional arguments passed on to
        `build_sam2_video_predictor_hf`.

    Returns:
      (SAM2VideoPredictor): The loaded model.
    """
    # Imported here rather than at module level, as in the original code.
    from sam2.build_sam import build_sam2_video_predictor_hf

    return build_sam2_video_predictor_hf(model_id, **kwargs)
|
|
|
def _obj_id_to_idx(self, inference_state, obj_id):
    """Map a client-side object id to its model-side object index.

    A known id returns its existing index. An unknown id is registered: it
    gets the next free index and empty per-object slots are created for its
    point inputs, mask inputs, committed outputs, temporary outputs and
    tracked-frame bookkeeping.

    Args:
        inference_state: the session dictionary created by `init_state`.
        obj_id: client-side object id (any hashable value).

    Returns:
        The integer object index associated with `obj_id`.
    """
    id_to_idx = inference_state["obj_id_to_idx"]
    obj_idx = id_to_idx.get(obj_id)
    if obj_idx is not None:
        return obj_idx

    # New objects may be added at any time, so there is no failure path here.
    obj_idx = len(id_to_idx)
    id_to_idx[obj_id] = obj_idx
    inference_state["obj_idx_to_id"][obj_idx] = obj_id
    inference_state["obj_ids"] = list(id_to_idx)

    # Create the per-object slots for the newly registered object.
    inference_state["point_inputs_per_obj"][obj_idx] = {}
    inference_state["mask_inputs_per_obj"][obj_idx] = {}
    inference_state["output_dict_per_obj"][obj_idx] = {
        "cond_frame_outputs": {},
        "non_cond_frame_outputs": {},
    }
    inference_state["temp_output_dict_per_obj"][obj_idx] = {
        "cond_frame_outputs": {},
        "non_cond_frame_outputs": {},
    }
    inference_state["frames_tracked_per_obj"][obj_idx] = {}
    return obj_idx
|
|
|
def _obj_idx_to_id(self, inference_state, obj_idx):
    """Return the client-side object id stored for a model-side object index."""
    idx_to_id = inference_state["obj_idx_to_id"]
    return idx_to_id[obj_idx]
|
|
|
def _get_obj_num(self, inference_state):
    """Return how many object ids are currently registered in this session."""
    registered = inference_state["obj_idx_to_id"]
    return len(registered)
|
|
|
@torch.inference_mode()
def add_new_points_or_box(
    self,
    inference_state,
    frame_idx,
    obj_id,
    points=None,
    labels=None,
    clear_old_points=True,
    normalize_coords=True,
    box=None,
):
    """Add point and/or box prompts for one object on one frame and re-predict it.

    All argument validation and tensor preparation happens before `obj_id` is
    looked up, so a rejected call does not register a new object without any
    inputs in `inference_state`.

    Args:
        inference_state: the session dictionary created by `init_state`.
        frame_idx: index of the frame that receives the prompts.
        obj_id: client-side object id; registered if it was not seen before.
        points: point coordinates (tensor or array-like). A 2-D input gets a
            leading batch dimension. Must be given together with `labels`.
        labels: point labels (tensor or array-like). A 1-D input gets a
            leading batch dimension.
        clear_old_points: if true, the new prompts replace the points stored
            for this frame; otherwise they are appended to them. Must be true
            when `box` is given.
        normalize_coords: if true, coordinates are divided by
            (video width, video height) before being scaled by
            `self.image_size`.
        box: optional box; reshaped into two points that are prepended to
            `points` with labels 2 and 3.

    Returns:
        `(frame_idx, obj_ids, video_res_masks)` where `obj_ids` lists all
        registered object ids and `video_res_masks` holds the consolidated
        masks of this frame at the original video resolution.

    Raises:
        ValueError: if only one of `points`/`labels` is given, if neither
            points nor a box is given, or if a box is combined with
            `clear_old_points=False`.
    """
    # ---- validate the request before mutating any state ----
    if (points is not None) != (labels is not None):
        raise ValueError("points and labels must be provided together")
    if points is None and box is None:
        raise ValueError("at least one of points or box must be provided as input")
    if box is not None and not clear_old_points:
        raise ValueError(
            "cannot add box without clearing old points, since "
            "box prompt must be provided before any point prompt "
            "(please use clear_old_points=True instead)"
        )

    # ---- bring points/labels into batched tensor form ----
    if points is None:
        points = torch.zeros(0, 2, dtype=torch.float32)
    elif not isinstance(points, torch.Tensor):
        points = torch.tensor(points, dtype=torch.float32)
    if labels is None:
        labels = torch.zeros(0, dtype=torch.int32)
    elif not isinstance(labels, torch.Tensor):
        labels = torch.tensor(labels, dtype=torch.int32)
    if points.dim() == 2:
        points = points.unsqueeze(0)  # add batch dimension
    if labels.dim() == 1:
        labels = labels.unsqueeze(0)  # add batch dimension

    # A box is encoded as two extra points (labels 2 and 3) placed in front
    # of the regular points.
    if box is not None:
        if not isinstance(box, torch.Tensor):
            box = torch.tensor(box, dtype=torch.float32, device=points.device)
        box_coords = box.reshape(1, 2, 2)
        box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
        box_labels = box_labels.reshape(1, 2)
        points = torch.cat([box_coords, points], dim=1)
        labels = torch.cat([box_labels, labels], dim=1)

    if normalize_coords:
        video_H = inference_state["video_height"]
        video_W = inference_state["video_width"]
        points = points / torch.tensor([video_W, video_H]).to(points.device)
    # Scale the (normalized) coordinates by the model's internal image size.
    points = points * self.image_size
    points = points.to(inference_state["device"])
    labels = labels.to(inference_state["device"])

    # ---- only now register/look up the object and store the prompts ----
    obj_idx = self._obj_id_to_idx(inference_state, obj_id)
    point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
    mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

    if not clear_old_points:
        point_inputs = point_inputs_per_frame.get(frame_idx, None)
    else:
        point_inputs = None
    point_inputs = concat_points(point_inputs, points, labels)

    point_inputs_per_frame[frame_idx] = point_inputs
    # Point and mask inputs are mutually exclusive on a frame.
    mask_inputs_per_frame.pop(frame_idx, None)

    # A frame that has not been tracked yet is treated as an initial
    # conditioning frame; otherwise reuse the direction it was tracked in.
    obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
    is_init_cond_frame = frame_idx not in obj_frames_tracked
    if is_init_cond_frame:
        reverse = False
    else:
        reverse = obj_frames_tracked[frame_idx]["reverse"]
    obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
    obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]

    is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
    storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

    # Look for an earlier prediction on this frame: first among the temporary
    # outputs, then among the committed conditioning and non-conditioning ones.
    prev_sam_mask_logits = None
    prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
    if prev_out is None:
        prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
        if prev_out is None:
            prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

    if prev_out is not None and prev_out["pred_masks"] is not None:
        device = inference_state["device"]
        prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
        # Keep the previous logits within a bounded range.
        prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
    current_out, _ = self._run_single_frame_inference(
        inference_state=inference_state,
        output_dict=obj_output_dict,
        frame_idx=frame_idx,
        batch_size=1,
        is_init_cond_frame=is_init_cond_frame,
        point_inputs=point_inputs,
        mask_inputs=None,
        reverse=reverse,
        # The memory encoder is skipped here; `propagate_in_video_preflight`
        # runs it for outputs whose memory features are still missing.
        run_mem_encoder=False,
        prev_sam_mask_logits=prev_sam_mask_logits,
    )
    obj_temp_output_dict[storage_key][frame_idx] = current_out

    # Merge this frame's outputs of all objects at video resolution.
    obj_ids = inference_state["obj_ids"]
    consolidated_out = self._consolidate_temp_output_across_obj(
        inference_state,
        frame_idx,
        is_cond=is_cond,
        consolidate_at_video_res=True,
    )
    _, video_res_masks = self._get_orig_video_res_output(
        inference_state, consolidated_out["pred_masks_video_res"]
    )
    return frame_idx, obj_ids, video_res_masks
|
|
|
def add_new_points(self, *args, **kwargs):
    """Deprecated method. Please use `add_new_points_or_box` instead.

    Emits a `DeprecationWarning` and forwards all arguments unchanged to
    `add_new_points_or_box`, returning its result.
    """
    warnings.warn(
        "`add_new_points` is deprecated; please use `add_new_points_or_box` instead.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller of this method
    )
    return self.add_new_points_or_box(*args, **kwargs)
|
|
|
@torch.inference_mode()
def add_new_mask(
    self,
    inference_state,
    frame_idx,
    obj_id,
    mask,
):
    """Add a mask prompt for one object on one frame and re-predict it.

    The mask is validated and converted before `obj_id` is looked up, so a
    rejected call does not register a new object without any inputs.

    Args:
        inference_state: the session dictionary created by `init_state`.
        frame_idx: index of the frame that receives the mask.
        obj_id: client-side object id; registered if it was not seen before.
        mask: 2-D mask (tensor or array-like, converted with dtype bool). It
            is resized to `self.image_size` and re-binarized at 0.5 when its
            size differs from the model's image size.

    Returns:
        `(frame_idx, obj_ids, video_res_masks)` where `video_res_masks` holds
        the consolidated masks of this frame at the original video resolution.

    Raises:
        AssertionError: if `mask` is not 2-dimensional.
    """
    # ---- validate and prepare the mask before mutating any state ----
    if not isinstance(mask, torch.Tensor):
        mask = torch.tensor(mask, dtype=torch.bool)
    assert mask.dim() == 2, "mask must be a 2-D (height, width) array"
    mask_H, mask_W = mask.shape
    mask_inputs_orig = mask[None, None]  # add batch and channel dimension
    mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])

    # Resize the mask if it doesn't match the model's image size.
    if mask_H != self.image_size or mask_W != self.image_size:
        mask_inputs = torch.nn.functional.interpolate(
            mask_inputs_orig,
            size=(self.image_size, self.image_size),
            align_corners=False,
            mode="bilinear",
            antialias=True,
        )
        mask_inputs = (mask_inputs >= 0.5).float()
    else:
        mask_inputs = mask_inputs_orig

    # ---- only now register/look up the object and store the prompt ----
    obj_idx = self._obj_id_to_idx(inference_state, obj_id)
    point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
    mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

    mask_inputs_per_frame[frame_idx] = mask_inputs
    # Point and mask inputs are mutually exclusive on a frame.
    point_inputs_per_frame.pop(frame_idx, None)

    # A frame that has not been tracked yet is treated as an initial
    # conditioning frame; otherwise reuse the direction it was tracked in.
    obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
    is_init_cond_frame = frame_idx not in obj_frames_tracked
    if is_init_cond_frame:
        reverse = False
    else:
        reverse = obj_frames_tracked[frame_idx]["reverse"]
    obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
    obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]

    is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
    storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

    current_out, _ = self._run_single_frame_inference(
        inference_state=inference_state,
        output_dict=obj_output_dict,
        frame_idx=frame_idx,
        batch_size=1,
        is_init_cond_frame=is_init_cond_frame,
        point_inputs=None,
        mask_inputs=mask_inputs,
        reverse=reverse,
        # The memory encoder is skipped here; `propagate_in_video_preflight`
        # runs it for outputs whose memory features are still missing.
        run_mem_encoder=False,
    )
    obj_temp_output_dict[storage_key][frame_idx] = current_out

    # Merge this frame's outputs of all objects at video resolution.
    obj_ids = inference_state["obj_ids"]
    consolidated_out = self._consolidate_temp_output_across_obj(
        inference_state,
        frame_idx,
        is_cond=is_cond,
        consolidate_at_video_res=True,
    )
    _, video_res_masks = self._get_orig_video_res_output(
        inference_state, consolidated_out["pred_masks_video_res"]
    )
    return frame_idx, obj_ids, video_res_masks
|
|
|
def _get_orig_video_res_output(self, inference_state, any_res_masks):
    """Bring mask scores to the original video resolution for final output.

    The scores are moved to the state's compute device, bilinearly resized
    when their spatial size differs from the video's, and passed through the
    non-overlapping constraints if `self.non_overlap_masks` is set.

    Returns:
        `(any_res_masks, video_res_masks)`: the input scores on the compute
        device and their video-resolution counterpart.
    """
    target_hw = (inference_state["video_height"], inference_state["video_width"])
    any_res_masks = any_res_masks.to(inference_state["device"], non_blocking=True)

    video_res_masks = any_res_masks
    if any_res_masks.shape[-2:] != target_hw:
        video_res_masks = torch.nn.functional.interpolate(
            any_res_masks,
            size=target_hw,
            mode="bilinear",
            align_corners=False,
        )
    if self.non_overlap_masks:
        video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
    return any_res_masks, video_res_masks
|
|
|
def _consolidate_temp_output_across_obj(
    self,
    inference_state,
    frame_idx,
    is_cond,
    consolidate_at_video_res=False,
):
    """Merge the per-object outputs of one frame into a single batched mask.

    For every registered object the mask is taken from, in order of
    preference: its temporary outputs (conditioning or non-conditioning,
    chosen by `is_cond`), its committed conditioning outputs, its committed
    non-conditioning outputs. Objects with no output on this frame keep the
    `NO_OBJ_SCORE` placeholder.

    Args:
        inference_state: the session dictionary created by `init_state`.
        frame_idx: the frame to consolidate.
        is_cond: selects which temporary output dict is consulted first.
        consolidate_at_video_res: if true, the result has the video's
            height/width and is stored under "pred_masks_video_res";
            otherwise it is `self.image_size // 4` square and stored under
            "pred_masks".

    Returns:
        A dict with a single entry (key as described above) holding a
        float32 tensor of size (num_objects, 1, H, W) on the storage device.
    """
    num_objs = self._get_obj_num(inference_state)
    temp_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

    if consolidate_at_video_res:
        mask_key = "pred_masks_video_res"
        out_h = inference_state["video_height"]
        out_w = inference_state["video_width"]
    else:
        mask_key = "pred_masks"
        out_h = out_w = self.image_size // 4

    # Start from "no object" everywhere and fill in the objects we have.
    merged_masks = torch.full(
        size=(num_objs, 1, out_h, out_w),
        fill_value=NO_OBJ_SCORE,
        dtype=torch.float32,
        device=inference_state["storage_device"],
    )
    out_hw = merged_masks.shape[-2:]

    for obj_idx in range(num_objs):
        temp_outputs = inference_state["temp_output_dict_per_obj"][obj_idx]
        committed_outputs = inference_state["output_dict_per_obj"][obj_idx]

        out = temp_outputs[temp_key].get(frame_idx, None)
        if out is None:
            out = committed_outputs["cond_frame_outputs"].get(frame_idx, None)
        if out is None:
            out = committed_outputs["non_cond_frame_outputs"].get(frame_idx, None)
        if out is None:
            # Nothing known about this object on this frame: keep placeholder.
            continue

        obj_mask = out["pred_masks"]
        if obj_mask.shape[-2:] != out_hw:
            # Resize first if the stored mask has a different resolution.
            obj_mask = torch.nn.functional.interpolate(
                obj_mask,
                size=out_hw,
                mode="bilinear",
                align_corners=False,
            )
        merged_masks[obj_idx : obj_idx + 1] = obj_mask

    return {mask_key: merged_masks}
|
|
|
@torch.inference_mode()
def propagate_in_video_preflight(self, inference_state):
    """Prepare inference_state and consolidate temporary outputs before tracking.

    For every registered object, the entries of `temp_output_dict_per_obj`
    are moved into `output_dict_per_obj` (running the memory encoder first
    for entries whose "maskmem_features" is still None), and the temporary
    dicts are emptied.

    Raises:
        RuntimeError: if no object is registered, or if some object ends up
            without any conditioning frame output.
    """
    # Check and make sure that at least one object has been registered.
    batch_size = self._get_obj_num(inference_state)
    if batch_size == 0:
        raise RuntimeError(
            "No input points or masks are provided for any object; please add inputs first."
        )

    for obj_idx in range(batch_size):
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Non-conditioning temporary outputs are handled first, then the
        # conditioning ones.
        for is_cond in [False, True]:
            storage_key = (
                "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
            )
            # Every frame in the temporary dict is committed below.
            for frame_idx, out in obj_temp_output_dict[storage_key].items():
                # Temporary outputs are produced with run_mem_encoder=False,
                # so their memory features are computed here when missing.
                if out["maskmem_features"] is None:
                    # Upsample the stored masks to the model's image size.
                    high_res_masks = torch.nn.functional.interpolate(
                        out["pred_masks"].to(inference_state["device"]),
                        size=(self.image_size, self.image_size),
                        mode="bilinear",
                        align_corners=False,
                    )
                    maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
                        inference_state=inference_state,
                        frame_idx=frame_idx,
                        batch_size=1,
                        high_res_masks=high_res_masks,
                        object_score_logits=out["object_score_logits"],
                        is_mask_from_pts=True,
                    )
                    out["maskmem_features"] = maskmem_features
                    out["maskmem_pos_enc"] = maskmem_pos_enc

                # Commit the (now complete) output for this frame.
                obj_output_dict[storage_key][frame_idx] = out
                if self.clear_non_cond_mem_around_input:
                    # NOTE(review): this per-object helper is only reached when
                    # the flag is on; confirm a method with this exact name is
                    # available on the class.
                    self._clear_obj_non_cond_mem_around_input(
                        inference_state, frame_idx, obj_idx
                    )

            # All temporary outputs of this kind have been committed.
            obj_temp_output_dict[storage_key].clear()

        # Every object must have at least one conditioning frame output.
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        if len(obj_output_dict["cond_frame_outputs"]) == 0:
            obj_id = self._obj_idx_to_id(inference_state, obj_idx)
            raise RuntimeError(
                f"No input points or masks are provided for object id {obj_id}; please add inputs first."
            )

        # A frame that is a conditioning frame must not also keep an entry
        # among the non-conditioning outputs.
        for frame_idx in obj_output_dict["cond_frame_outputs"]:
            obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
|
|
@torch.inference_mode()
def propagate_in_video(
    self,
    inference_state,
    start_frame_idx=None,
    max_frame_num_to_track=None,
    reverse=False,
):
    """Track all registered objects through the video, one frame at a time.

    Args:
        inference_state: the session dictionary created by `init_state`.
        start_frame_idx: first frame to process; defaults to the earliest
            conditioning frame over all objects.
        max_frame_num_to_track: maximum distance (in frames) to move away
            from the start frame; defaults to the number of video frames.
        reverse: if true, frames are visited in decreasing index order.
            Nothing is processed when reverse tracking starts at frame 0.

    Yields:
        `(frame_idx, obj_ids, video_res_masks)` for every processed frame.
    """
    self.propagate_in_video_preflight(inference_state)

    obj_ids = inference_state["obj_ids"]
    num_frames = inference_state["num_frames"]
    batch_size = self._get_obj_num(inference_state)

    # Fill in the defaults of the tracking range.
    if start_frame_idx is None:
        start_frame_idx = min(
            cond_idx
            for per_obj_outputs in inference_state["output_dict_per_obj"].values()
            for cond_idx in per_obj_outputs["cond_frame_outputs"]
        )
    if max_frame_num_to_track is None:
        max_frame_num_to_track = num_frames

    # Decide which frames are visited and in which order.
    if not reverse:
        last_frame_idx = min(
            start_frame_idx + max_frame_num_to_track, num_frames - 1
        )
        processing_order = range(start_frame_idx, last_frame_idx + 1)
    else:
        last_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
        processing_order = (
            range(start_frame_idx, last_frame_idx - 1, -1)
            if start_frame_idx > 0
            else []
        )

    for frame_idx in tqdm(processing_order, desc="propagate in video"):
        pred_masks_per_obj = []
        for obj_idx in range(batch_size):
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            if frame_idx in obj_output_dict["cond_frame_outputs"]:
                # Conditioning frames already have an output; reuse it.
                current_out = obj_output_dict["cond_frame_outputs"][frame_idx]
                device = inference_state["device"]
                pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
                if self.clear_non_cond_mem_around_input:
                    self._clear_obj_non_cond_mem_around_input(
                        inference_state, frame_idx, obj_idx
                    )
            else:
                # Any other frame is predicted from memory and stored.
                current_out, pred_masks = self._run_single_frame_inference(
                    inference_state=inference_state,
                    output_dict=obj_output_dict,
                    frame_idx=frame_idx,
                    batch_size=1,
                    is_init_cond_frame=False,
                    point_inputs=None,
                    mask_inputs=None,
                    reverse=reverse,
                    run_mem_encoder=True,
                )
                obj_output_dict["non_cond_frame_outputs"][frame_idx] = current_out

            # Remember that (and in which direction) this frame was tracked.
            inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
                "reverse": reverse
            }
            pred_masks_per_obj.append(pred_masks)

        # Stack the objects' masks and resize them to the video resolution.
        if len(pred_masks_per_obj) > 1:
            all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
        else:
            all_pred_masks = pred_masks_per_obj[0]
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, all_pred_masks
        )
        yield frame_idx, obj_ids, video_res_masks
|
|
|
@torch.inference_mode()
def clear_all_prompts_in_frame(
    self, inference_state, frame_idx, obj_id, need_output=True
):
    """Drop every prompt (points or mask) of one object on one frame.

    The frame's temporary outputs for the object are discarded. A committed
    conditioning output on that frame is moved to the non-conditioning
    outputs, and the frame's tracked-frame entry is removed with it.

    Returns:
        None when `need_output` is false; otherwise
        `(frame_idx, obj_ids, video_res_masks)` re-consolidated for the frame.
    """
    obj_idx = self._obj_id_to_idx(inference_state, obj_id)

    # Forget the stored prompts on this frame.
    for inputs_key in ("point_inputs_per_obj", "mask_inputs_per_obj"):
        inference_state[inputs_key][obj_idx].pop(frame_idx, None)

    # Forget the not-yet-committed outputs on this frame.
    temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
    for storage_key in ("cond_frame_outputs", "non_cond_frame_outputs"):
        temp_output_dict_per_obj[obj_idx][storage_key].pop(frame_idx, None)

    # Without inputs the frame is no longer a conditioning frame, so an
    # existing conditioning output is downgraded to a non-conditioning one.
    obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
    demoted_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
    if demoted_out is not None:
        obj_output_dict["non_cond_frame_outputs"][frame_idx] = demoted_out
        inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)

    if not need_output:
        return

    # Re-consolidate the frame: it counts as conditioning if any object
    # still has a temporary conditioning output on it.
    obj_ids = inference_state["obj_ids"]
    is_cond = any(
        frame_idx in per_obj_temp["cond_frame_outputs"]
        for per_obj_temp in temp_output_dict_per_obj.values()
    )
    consolidated_out = self._consolidate_temp_output_across_obj(
        inference_state,
        frame_idx,
        is_cond=is_cond,
        consolidate_at_video_res=True,
    )
    _, video_res_masks = self._get_orig_video_res_output(
        inference_state, consolidated_out["pred_masks_video_res"]
    )
    return frame_idx, obj_ids, video_res_masks
|
|
|
@torch.inference_mode()
def reset_state(self, inference_state):
    """Forget all objects, prompts and tracking results of the session."""
    self._reset_tracking_results(inference_state)
    # Empty the id mappings and every per-object container in place.
    for state_key in (
        "obj_id_to_idx",
        "obj_idx_to_id",
        "obj_ids",
        "point_inputs_per_obj",
        "mask_inputs_per_obj",
        "output_dict_per_obj",
        "temp_output_dict_per_obj",
        "frames_tracked_per_obj",
    ):
        inference_state[state_key].clear()
|
|
|
def _reset_tracking_results(self, inference_state):
    """Empty all per-object inputs and results while keeping the objects registered."""
    # Prompts: one dict of frames per object.
    for inputs_key in ("point_inputs_per_obj", "mask_inputs_per_obj"):
        for per_frame_inputs in inference_state[inputs_key].values():
            per_frame_inputs.clear()
    # Committed and temporary outputs: two dicts of frames per object.
    for outputs_key in ("output_dict_per_obj", "temp_output_dict_per_obj"):
        for per_obj_outputs in inference_state[outputs_key].values():
            per_obj_outputs["cond_frame_outputs"].clear()
            per_obj_outputs["non_cond_frame_outputs"].clear()
    # Bookkeeping of already tracked frames.
    for tracked_frames in inference_state["frames_tracked_per_obj"].values():
        tracked_frames.clear()
|
|
|
def _get_image_feature(self, inference_state, frame_idx, batch_size):
    """Return the backbone features of a frame, expanded to `batch_size`.

    The features of the most recently computed frame are kept in
    `inference_state["cached_features"]` (a single entry, replaced on every
    cache miss) so that repeated calls on the same frame skip the image
    encoder.

    Returns:
        A tuple whose first element is the expanded image tensor, followed
        by whatever `_prepare_backbone_features` returns.
    """
    image, backbone_out = inference_state["cached_features"].get(
        frame_idx, (None, None)
    )
    if backbone_out is None:
        # Cache miss: encode this single frame and make it the cached one.
        device = inference_state["device"]
        image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
        backbone_out = self.forward_image(image)
        inference_state["cached_features"] = {frame_idx: (image, backbone_out)}

    # Expand image and features along the first dimension to `batch_size`.
    # New lists are built, so the cached lists themselves stay untouched.
    expanded_image = image.expand(batch_size, -1, -1, -1)
    expanded_backbone_out = {
        "backbone_fpn": [
            level_feat.expand(batch_size, -1, -1, -1)
            for level_feat in backbone_out["backbone_fpn"]
        ],
        "vision_pos_enc": [
            level_pos.expand(batch_size, -1, -1, -1)
            for level_pos in backbone_out["vision_pos_enc"]
        ],
    }

    prepared = self._prepare_backbone_features(expanded_backbone_out)
    return (expanded_image,) + prepared
|
|
|
def _run_single_frame_inference(
    self,
    inference_state,
    output_dict,
    frame_idx,
    batch_size,
    is_init_cond_frame,
    point_inputs,
    mask_inputs,
    reverse,
    run_mem_encoder,
    prev_sam_mask_logits=None,
):
    """Run `track_step` on one frame and package its outputs for storage.

    Returns:
        `(compact_current_out, pred_masks_gpu)`. The dict holds the memory
        features (cast to bfloat16) and the predicted masks moved to the
        storage device, the shared mask-memory positional encoding, the
        object pointer and the object score logits. `pred_masks_gpu` is the
        (optionally hole-filled) mask tensor before that move.
    """
    frame_features = self._get_image_feature(inference_state, frame_idx, batch_size)
    _, _, vision_feats, vision_pos_embeds, feat_sizes = frame_features

    # Points and a mask must not be given together for the same frame.
    assert point_inputs is None or mask_inputs is None
    current_out = self.track_step(
        frame_idx=frame_idx,
        is_init_cond_frame=is_init_cond_frame,
        current_vision_feats=vision_feats,
        current_vision_pos_embeds=vision_pos_embeds,
        feat_sizes=feat_sizes,
        point_inputs=point_inputs,
        mask_inputs=mask_inputs,
        output_dict=output_dict,
        num_frames=inference_state["num_frames"],
        track_in_reverse=reverse,
        run_mem_encoder=run_mem_encoder,
        prev_sam_mask_logits=prev_sam_mask_logits,
    )

    storage_device = inference_state["storage_device"]

    # Memory features are stored in bfloat16 on the storage device.
    maskmem_features = current_out["maskmem_features"]
    if maskmem_features is not None:
        maskmem_features = maskmem_features.to(torch.bfloat16).to(
            storage_device, non_blocking=True
        )

    # Optionally fill holes before the masks are stored or returned.
    pred_masks_gpu = current_out["pred_masks"]
    if self.fill_hole_area > 0:
        pred_masks_gpu = fill_holes_in_mask_scores(
            pred_masks_gpu, self.fill_hole_area
        )
    stored_pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)

    # The positional encoding is shared via the session constants.
    shared_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)

    compact_current_out = {
        "maskmem_features": maskmem_features,
        "maskmem_pos_enc": shared_pos_enc,
        "pred_masks": stored_pred_masks,
        "obj_ptr": current_out["obj_ptr"],
        "object_score_logits": current_out["object_score_logits"],
    }
    return compact_current_out, pred_masks_gpu
|
|
|
def _run_memory_encoder(
    self,
    inference_state,
    frame_idx,
    batch_size,
    high_res_masks,
    object_score_logits,
    is_mask_from_pts,
):
    """Encode `high_res_masks` of a frame into memory features.

    Returns:
        `(maskmem_features, maskmem_pos_enc)`: the features cast to bfloat16
        on the storage device, and the positional encoding as returned by
        `_get_maskmem_pos_enc`.
    """
    frame_features = self._get_image_feature(inference_state, frame_idx, batch_size)
    _, _, vision_feats, _, feat_sizes = frame_features

    encoded_features, encoded_pos_enc = self._encode_new_memory(
        current_vision_feats=vision_feats,
        feat_sizes=feat_sizes,
        pred_masks_high_res=high_res_masks,
        object_score_logits=object_score_logits,
        is_mask_from_pts=is_mask_from_pts,
    )

    # Store the features in bfloat16 on the storage device.
    maskmem_features = encoded_features.to(torch.bfloat16).to(
        inference_state["storage_device"], non_blocking=True
    )
    # Route the positional encoding through the session-wide cache.
    maskmem_pos_enc = self._get_maskmem_pos_enc(
        inference_state, {"maskmem_pos_enc": encoded_pos_enc}
    )
    return maskmem_features, maskmem_pos_enc
|
|
|
def _get_maskmem_pos_enc(self, inference_state, current_out):
    """Return the mask-memory positional encoding via the session-wide cache.

    The first non-None encoding seen is reduced to its first batch entry and
    stored in `inference_state["constants"]`; every later call re-expands
    that cached copy to the batch size of the incoming encoding instead of
    keeping a new one.

    Returns:
        A list of expanded tensors, or None if `current_out` carries no
        positional encoding.
    """
    constants = inference_state["constants"]
    incoming_pos_enc = current_out["maskmem_pos_enc"]
    if incoming_pos_enc is None:
        return None

    if "maskmem_pos_enc" in constants:
        cached_pos_enc = constants["maskmem_pos_enc"]
    else:
        assert isinstance(incoming_pos_enc, list)
        # Keep only the slice of one object and reuse it from now on.
        cached_pos_enc = [level[0:1].clone() for level in incoming_pos_enc]
        constants["maskmem_pos_enc"] = cached_pos_enc

    # Expand the cached slice to the batch size of the incoming encoding.
    target_batch = incoming_pos_enc[0].size(0)
    return [level.expand(target_batch, -1, -1, -1) for level in cached_pos_enc]
|
|
|
@torch.inference_mode()
def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
    """
    Remove an object id from the tracking state. If strict is True, we check whether
    the object id actually exists and raise an error if it doesn't exist.

    Returns:
        `(obj_ids, updated_frames)`: the remaining object ids and, when
        `need_output` is true, a list of `(frame_idx, video_res_masks)` for
        the frames on which the removed object had inputs. The list is empty
        when the object did not exist or was the last remaining object.

    Raises:
        RuntimeError: if `strict` is true and `obj_id` is not registered.
    """
    old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
    updated_frames = []

    # Unknown object id: either a silent no-op or an error, depending on `strict`.
    if old_obj_idx_to_rm is None:
        if not strict:
            return inference_state["obj_ids"], updated_frames
        raise RuntimeError(
            f"Cannot remove object id {obj_id} as it doesn't exist. "
            f"All existing object ids: {inference_state['obj_ids']}."
        )

    # Removing the only remaining object is the same as resetting the state.
    if len(inference_state["obj_id_to_idx"]) == 1:
        self.reset_state(inference_state)
        return inference_state["obj_ids"], updated_frames

    # Step 0: clear the prompts of the object on every frame where it had
    # point or mask inputs. This must happen while `obj_id` is still mapped,
    # because `clear_all_prompts_in_frame` looks the object up by id.
    obj_input_frames_inds = set()
    obj_input_frames_inds.update(
        inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
    )
    obj_input_frames_inds.update(
        inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
    )
    for frame_idx in obj_input_frames_inds:
        self.clear_all_prompts_in_frame(
            inference_state, frame_idx, obj_id, need_output=False
        )

    # Step 1: rebuild the id <-> index mappings without the removed object;
    # the remaining objects get consecutive indices in their old order.
    old_obj_ids = inference_state["obj_ids"]
    old_obj_inds = list(range(len(old_obj_ids)))
    remain_old_obj_inds = old_obj_inds.copy()
    remain_old_obj_inds.remove(old_obj_idx_to_rm)
    new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
    new_obj_inds = list(range(len(new_obj_ids)))

    old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
    inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
    inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
    inference_state["obj_ids"] = new_obj_ids

    # Step 2: re-key every per-object container from old to new indices;
    # the removed object's entry is popped and not re-inserted.
    def _map_keys(container):
        new_kvs = []
        for k in old_obj_inds:
            v = container.pop(k)
            if k in old_idx_to_new_idx:
                new_kvs.append((old_idx_to_new_idx[k], v))
        container.update(new_kvs)

    _map_keys(inference_state["point_inputs_per_obj"])
    _map_keys(inference_state["mask_inputs_per_obj"])
    _map_keys(inference_state["output_dict_per_obj"])
    _map_keys(inference_state["temp_output_dict_per_obj"])
    _map_keys(inference_state["frames_tracked_per_obj"])

    # Step 3: re-consolidate the frames on which the removed object had
    # inputs, so the caller gets the masks of the remaining objects.
    if need_output:
        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
        for frame_idx in obj_input_frames_inds:
            # The frame counts as conditioning if any remaining object has a
            # temporary conditioning output on it.
            is_cond = any(
                frame_idx in obj_temp_output_dict["cond_frame_outputs"]
                for obj_temp_output_dict in temp_output_dict_per_obj.values()
            )
            consolidated_out = self._consolidate_temp_output_across_obj(
                inference_state,
                frame_idx,
                is_cond=is_cond,
                consolidate_at_video_res=True,
            )
            _, video_res_masks = self._get_orig_video_res_output(
                inference_state, consolidated_out["pred_masks_video_res"]
            )
            updated_frames.append((frame_idx, video_res_masks))

    return inference_state["obj_ids"], updated_frames
|
|
|
def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
    """
    Remove the non-conditioning memory around the input frame. When users provide
    correction clicks, the surrounding frames' non-conditioning memories can still
    contain outdated object appearance information and could confuse the model.

    This method clears those non-conditioning memories surrounding the interacted
    frame, for every registered object, to avoid giving the model both old and
    new information about the object.
    """
    for obj_idx in range(self._get_obj_num(inference_state)):
        self._clear_obj_non_cond_mem_around_input(inference_state, frame_idx, obj_idx)

def _clear_obj_non_cond_mem_around_input(self, inference_state, frame_idx, obj_idx):
    """
    Remove one object's non-conditioning memory around the input frame.

    This is the per-object variant called from `propagate_in_video_preflight`
    and `propagate_in_video`. It drops the object's non-conditioning outputs
    on all frames within `memory_temporal_stride_for_eval * num_maskmem`
    frames of `frame_idx` (both directions, inclusive).
    """
    r = self.memory_temporal_stride_for_eval
    frame_idx_begin = frame_idx - r * self.num_maskmem
    frame_idx_end = frame_idx + r * self.num_maskmem
    obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
    non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
    for t in range(frame_idx_begin, frame_idx_end + 1):
        non_cond_frame_outputs.pop(t, None)
|
|
|
|
|
class SAM2VideoPredictorVOS(SAM2VideoPredictor):
    """
    Video predictor variant optimized for the VOS setting.

    On construction, the `forward` methods of the memory encoder, the memory
    attention, the SAM prompt encoder and the SAM mask decoder are wrapped with
    `torch.compile`. The methods overridden below clone the tensors crossing
    those module boundaries in order to enable compilation.
    """

    def __init__(self, *args, **kwargs):
        """Build the predictor (all arguments go to the parent), then compile it."""
        super().__init__(*args, **kwargs)
        self._compile_all_components()

    def _compile_all_components(self):
        """Replace the `forward` of each heavy sub-module with a compiled version."""
        print("Compiling all components for VOS setting. First time may be very slow.")
        # (attribute name on self, value for torch.compile's `dynamic` flag),
        # listed in the order the components are compiled.
        # NOTE(review): only memory_attention is compiled with dynamic shapes;
        # presumably its input size varies between calls -- confirm.
        compile_plan = (
            ("memory_encoder", False),
            ("memory_attention", True),
            ("sam_prompt_encoder", False),
            ("sam_mask_decoder", False),
        )
        for component_name, dynamic_shapes in compile_plan:
            component = getattr(self, component_name)
            component.forward = torch.compile(
                component.forward,
                mode="max-autotune",
                fullgraph=True,
                dynamic=dynamic_shapes,
            )

    def forward_image(self, img_batch: torch.Tensor):
        """
        Run the image encoder on `img_batch` and return its output dict.

        Same as the corresponding method of the parent (SAM2VideoPredictor),
        except that the backbone features and their positional encodings are
        cloned to enable compilation.
        """
        backbone_out = self.image_encoder(img_batch)
        if self.use_high_res_features_in_sam:
            # Project the level-0 and level-1 features with the mask decoder's
            # conv_s0 / conv_s1 here, once per image.
            fpn = backbone_out["backbone_fpn"]
            mask_decoder = self.sam_mask_decoder
            fpn[0] = mask_decoder.conv_s0(fpn[0])
            fpn[1] = mask_decoder.conv_s1(fpn[1])

        # Clone every level (features and matching positional encoding).
        for level, feat in enumerate(backbone_out["backbone_fpn"]):
            backbone_out["backbone_fpn"][level] = feat.clone()
            pos_enc = backbone_out["vision_pos_enc"][level]
            backbone_out["vision_pos_enc"][level] = pos_enc.clone()
        return backbone_out

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
    ):
        """
        Run the SAM prompt encoder and mask decoder on one batch of features.

        Same as the corresponding method of the parent (SAM2VideoPredictor),
        except that the outputs of the prompt encoder and of the mask decoder
        are cloned to enable compilation.

        Returns the tuple
        `(low_res_multimasks, high_res_multimasks, ious, low_res_masks,
        high_res_masks, obj_ptr, object_score_logits)`.
        """
        B = backbone_features.size(0)
        device = backbone_features.device
        assert (
            backbone_features.size(1) == self.sam_prompt_embed_dim
            and backbone_features.size(2) == self.sam_image_embedding_size
            and backbone_features.size(3) == self.sam_image_embedding_size
        )

        # --- point prompt ---
        if point_inputs is None:
            # No points given: use one dummy point at the origin with label -1.
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = torch.full(
                (B, 1), -1, dtype=torch.int32, device=device
            )
        else:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B

        # --- mask prompt ---
        sam_mask_prompt = None
        if mask_inputs is not None:
            # Expect a 4-D tensor whose first two dims are (B, 1).
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            prompt_size = self.sam_prompt_encoder.mask_input_size
            sam_mask_prompt = mask_inputs
            if mask_inputs.shape[-2:] != prompt_size:
                # Resize to the spatial size the prompt encoder expects.
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=prompt_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,
                )

        prompt_encoder_out = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )
        # Clone the prompt encoder outputs and the dense positional encoding
        # before handing them to the mask decoder.
        sparse_embeddings, dense_embeddings = (
            emb.clone() for emb in prompt_encoder_out
        )
        image_pe = self.sam_prompt_encoder.get_dense_pe().clone()

        mask_decoder_out = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,
            high_res_features=high_res_features,
        )
        # Clone the mask decoder outputs as well.
        low_res_multimasks, ious, sam_output_tokens, object_score_logits = (
            out.clone() for out in mask_decoder_out
        )

        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0
            # Hard choice for the masks: where the object score is not
            # positive, every mask logit is replaced by NO_OBJ_SCORE.
            low_res_multimasks = torch.where(
                is_obj_appearing[:, None, None],
                low_res_multimasks,
                NO_OBJ_SCORE,
            )

        # Cast to float32 before upsampling to the model's input resolution.
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if not multimask_output:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
        else:
            # Keep, per batch element, the mask with the highest predicted IoU.
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]

        # Object pointer derived from the selected SAM output token.
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Unlike the masks, the pointer may be blended softly (sigmoid of
            # the object score) instead of with a hard 0/1 weight.
            lambda_is_obj_appearing = (
                object_score_logits.sigmoid()
                if self.soft_no_obj_ptr
                else is_obj_appearing.float()
            )
            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def _encode_new_memory(
        self,
        current_vision_feats,
        feat_sizes,
        pred_masks_high_res,
        object_score_logits,
        is_mask_from_pts,
    ):
        """
        Encode the current frame's features and predicted mask into a memory.

        Same as the corresponding method of the parent (SAM2VideoPredictor),
        except that the memory features and their positional encodings are
        cloned to enable compilation.

        Returns `(maskmem_features, maskmem_pos_enc)`, where `maskmem_pos_enc`
        is a list of tensors.
        """
        top_level_feat = current_vision_feats[-1]
        B = top_level_feat.size(1)
        C = self.hidden_dim
        H, W = feat_sizes[-1]

        # (HW, B, C) -> (B, C, H, W)
        pix_feat = top_level_feat.permute(1, 2, 0).view(B, C, H, W)
        if self.non_overlap_masks_for_mem_enc and not self.training:
            # Eval-only: optionally make the masks non-overlapping first.
            pred_masks_high_res = self._apply_non_overlapping_constraints(
                pred_masks_high_res
            )

        if (
            self.binarize_mask_from_pts_for_mem_enc
            and is_mask_from_pts
            and not self.training
        ):
            # Hard 0/1 mask obtained by thresholding the logits at 0.
            mask_for_mem = (pred_masks_high_res > 0).float()
        else:
            # Soft mask: sigmoid maps the raw logits into (0, 1).
            mask_for_mem = torch.sigmoid(pred_masks_high_res)

        # Optional affine transform of the mask values (skipped when identity).
        if self.sigmoid_scale_for_mem_enc != 1.0:
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc

        # The mask has already been mapped above, hence skip_mask_sigmoid=True.
        maskmem_out = self.memory_encoder(
            pix_feat, mask_for_mem, skip_mask_sigmoid=True
        )

        # Clone the encoder outputs before any further (in-place) use.
        maskmem_features = maskmem_out["vision_features"].clone()
        maskmem_pos_enc = [pos.clone() for pos in maskmem_out["vision_pos_enc"]]

        if self.no_obj_embed_spatial is not None:
            # Where the object score is not positive, add the no-object
            # embedding (broadcast over the spatial dims) to the memory.
            is_obj_appearing = (object_score_logits > 0).float()
            no_obj_weight = 1 - is_obj_appearing[..., None, None]
            no_obj_embed = self.no_obj_embed_spatial[..., None, None].expand(
                *maskmem_features.shape
            )
            maskmem_features += no_obj_weight * no_obj_embed

        return maskmem_features, maskmem_pos_enc
|
|
|