from collections import OrderedDict

import torch
from tqdm import tqdm

from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames


class SAM2VideoPredictor(SAM2Base):
    """The predictor class to handle user interactions and manage inference states."""

    def __init__(
        self,
        # the area (in pixels) below which small holes in the output masks are filled
        # (0 disables hole filling)
        fill_hole_area=0,
        # whether to apply non-overlapping constraints on the output object masks
        non_overlap_masks=False,
        # whether to clear non-conditioning memory of the surrounding frames after users
        # provide correction clicks (only applied to single-object tracking unless
        # `clear_non_cond_mem_for_multi_obj` is also set)
        clear_non_cond_mem_around_input=False,
        # whether to also clear surrounding non-conditioning memory under multi-object
        # tracking (only effective when `clear_non_cond_mem_around_input` is enabled)
        clear_non_cond_mem_for_multi_obj=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.fill_hole_area = fill_hole_area
        self.non_overlap_masks = non_overlap_masks
        self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
        self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj

    @torch.inference_mode()
    def init_state(
        self,
        video_path=None,
        images=None,
        device="cpu",
        async_loading_frames=False,
    ):
        """Initialize an inference state from a video path or preloaded frames."""
        if images is not None:
            # frames are provided directly by the caller
            images, video_height, video_width = load_video_frames(
                video_path=None,
                images=images,
                image_size=self.image_size,
                async_loading_frames=async_loading_frames,
                device=device,
            )
        else:
            # load the frames from `video_path`
            images, video_height, video_width = load_video_frames(
                video_path=video_path,
                image_size=self.image_size,
                async_loading_frames=async_loading_frames,
                device=device,
            )
        inference_state = dict()
        inference_state["images"] = images
        inference_state["num_frames"] = len(images)
        # the original video height and width, used for resizing final output scores
        inference_state["video_height"] = video_height
        inference_state["video_width"] = video_width
        inference_state["device"] = device
        inference_state["storage_device"] = device
        # inputs on each frame
        inference_state["point_inputs_per_obj"] = {}
        inference_state["mask_inputs_per_obj"] = {}
        # visual features of recently visited frames (cached for quick interactions)
        inference_state["cached_features"] = {}
        # values that don't change across frames (so we only need to hold one copy of them)
        inference_state["constants"] = {}
        # mapping between client-side object IDs and model-side object indices
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        # a storage to hold the model's tracking results and states on each frame
        inference_state["output_dict"] = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        # slice (view) of each object's tracking results, sharing memory with "output_dict"
        inference_state["output_dict_per_obj"] = {}
        # a temporary storage to hold new outputs when users interact with a frame
        # to add clicks or mask (it's merged into "output_dict" before propagation starts)
        inference_state["temp_output_dict_per_obj"] = {}
        # frames that already hold consolidated outputs from click or mask inputs
        # (we directly use their consolidated outputs during tracking)
        inference_state["consolidated_frame_inds"] = {
            "cond_frame_outputs": set(),  # set containing frame indices
            "non_cond_frame_outputs": set(),  # set containing frame indices
        }
        # whether tracking has started (after which no new objects can be added)
        inference_state["tracking_has_started"] = False
        # metadata for each tracked frame (e.g. which direction it was tracked in)
        inference_state["frames_already_tracked"] = {}
        # Warm up the visual backbone and cache the image feature on frame 0
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state
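
    # Example (a minimal sketch, not part of the class API): an inference state can be
    # built either from a video path or from frames already loaded in memory, e.g.:
    #
    #   state = predictor.init_state(video_path="<path to video frames>")
    #   # or, with preloaded frames (in whatever form `load_video_frames` accepts):
    #   state = predictor.init_state(images=frames, device="cuda")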

    def _obj_id_to_idx(self, inference_state, obj_id):
        """Map client-side object id to model-side object index."""
        obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
        if obj_idx is not None:
            return obj_idx

        # This is a new object id that hasn't been seen before. We only allow adding
        # new objects *before* tracking starts.
        allow_new_object = not inference_state["tracking_has_started"]
        if allow_new_object:
            # get the next object slot
            obj_idx = len(inference_state["obj_id_to_idx"])
            inference_state["obj_id_to_idx"][obj_id] = obj_idx
            inference_state["obj_idx_to_id"][obj_idx] = obj_id
            inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
            # set up input and output structures for this object
            inference_state["point_inputs_per_obj"][obj_idx] = {}
            inference_state["mask_inputs_per_obj"][obj_idx] = {}
            inference_state["output_dict_per_obj"][obj_idx] = {
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            inference_state["temp_output_dict_per_obj"][obj_idx] = {
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            return obj_idx
        else:
            raise RuntimeError(
                f"Cannot add new object id {obj_id} after tracking starts. "
                f"All existing object ids: {inference_state['obj_ids']}. "
                f"Please call 'reset_state' to restart from scratch."
            )

    def _obj_idx_to_id(self, inference_state, obj_idx):
        """Map model-side object index to client-side object id."""
        return inference_state["obj_idx_to_id"][obj_idx]

    def _get_obj_num(self, inference_state):
        """Get the total number of unique object ids received so far in this session."""
        return len(inference_state["obj_idx_to_id"])

    @torch.inference_mode()
    def add_new_points(
        self,
        inference_state,
        frame_idx,
        obj_id,
        points,
        labels,
        clear_old_points=True,
        normalize_coords=True,
    ):
        """Add new points to a frame."""
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

        if not isinstance(points, torch.Tensor):
            points = torch.tensor(points, dtype=torch.float32)
        if not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels, dtype=torch.int32)
        if points.dim() == 2:
            points = points.unsqueeze(0)  # add batch dimension
        if labels.dim() == 1:
            labels = labels.unsqueeze(0)  # add batch dimension
        if normalize_coords:
            video_H = inference_state["video_height"]
            video_W = inference_state["video_width"]
            points = points / torch.tensor([video_W, video_H]).to(points.device)
        # scale the (normalized) coordinates by the model's internal image size
        points = points * self.image_size
        points = points.to(inference_state["device"])
        labels = labels.to(inference_state["device"])

        if not clear_old_points:
            point_inputs = point_inputs_per_frame.get(frame_idx, None)
        else:
            point_inputs = None
        point_inputs = concat_points(point_inputs, points, labels)

        point_inputs_per_frame[frame_idx] = point_inputs
        mask_inputs_per_frame.pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the input points are used to generate segments on this frame
        # without using any memory from other frames, as in SAM. Otherwise (if it has been
        # tracked), the input points will be used to correct the already tracked masks.
        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        # Get any previously predicted mask logits on this object and feed them along
        # with the new clicks into the SAM mask decoder.
        prev_sam_mask_logits = None
        # look up the temporary output dict first, which contains the most recent output
        # (if not found, then look up conditioning and non-conditioning frame outputs)
        prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
        if prev_out is None:
            prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
            if prev_out is None:
                prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

        if prev_out is not None and prev_out["pred_masks"] is not None:
            prev_sam_mask_logits = prev_out["pred_masks"].to(inference_state["device"])
            # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
            prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=point_inputs,
            mask_inputs=None,
            reverse=reverse,
            # Skip the memory encoder when adding clicks or mask. We execute the memory
            # encoder at the beginning of `propagate_in_video` (after the user finalizes
            # their clicks). This allows us to enforce non-overlapping constraints on all
            # objects before encoding them into memory.
            run_mem_encoder=False,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out

        # Resize the output mask to the original video resolution
        obj_ids = inference_state["obj_ids"]
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks
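
    # Example (a minimal sketch, not part of the class API): assuming `predictor` is an
    # instance of this class and `state` came from `init_state(...)`, a single positive
    # click on frame 0 for object id 1 could be supplied as:
    #
    #   frame_idx, obj_ids, masks = predictor.add_new_points(
    #       state, frame_idx=0, obj_id=1,
    #       points=[[210.0, 350.0]],  # (x, y) in original video coordinates
    #       labels=[1],               # 1 = positive click, 0 = negative click
    #   )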

    @torch.inference_mode()
    def add_new_mask(
        self,
        inference_state,
        frame_idx,
        obj_id,
        mask,
    ):
        """Add new mask to a frame."""
        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

        if not isinstance(mask, torch.Tensor):
            mask = torch.tensor(mask, dtype=torch.bool)
        assert mask.dim() == 2
        mask_H, mask_W = mask.shape
        mask_inputs_orig = mask[None, None]  # add batch and channel dimension
        mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])

        # resize the mask if it doesn't match the model's internal image size
        if mask_H != self.image_size or mask_W != self.image_size:
            mask_inputs = torch.nn.functional.interpolate(
                mask_inputs_orig,
                size=(self.image_size, self.image_size),
                align_corners=False,
                mode="bilinear",
                antialias=True,  # use antialias for downsampling
            )
            mask_inputs = (mask_inputs >= 0.5).float()
        else:
            mask_inputs = mask_inputs_orig

        mask_inputs_per_frame[frame_idx] = mask_inputs
        point_inputs_per_frame.pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the input mask is used to generate segments on this frame
        # without using any memory from other frames, as in SAM. Otherwise (if it has been
        # tracked), the input mask will be used to correct the already tracked masks.
        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        current_out, _ = self._run_single_frame_inference(
            inference_state=inference_state,
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=None,
            mask_inputs=mask_inputs,
            reverse=reverse,
            # Skip the memory encoder when adding clicks or mask. We execute the memory
            # encoder at the beginning of `propagate_in_video` (after the user finalizes
            # their clicks). This allows us to enforce non-overlapping constraints on all
            # objects before encoding them into memory.
            run_mem_encoder=False,
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out

        # Resize the output mask to the original video resolution
        obj_ids = inference_state["obj_ids"]
        consolidated_out = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated_out["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks

    def _get_orig_video_res_output(self, inference_state, any_res_masks):
        """
        Resize the object scores to the original video resolution (video_res_masks)
        and apply non-overlapping constraints for final output.
        """
        device = inference_state["device"]
        video_H = inference_state["video_height"]
        video_W = inference_state["video_width"]
        any_res_masks = any_res_masks.to(device, non_blocking=True)
        if any_res_masks.shape[-2:] == (video_H, video_W):
            video_res_masks = any_res_masks
        else:
            video_res_masks = torch.nn.functional.interpolate(
                any_res_masks,
                size=(video_H, video_W),
                mode="bilinear",
                align_corners=False,
            )
        if self.non_overlap_masks:
            video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
        return any_res_masks, video_res_masks

    def _consolidate_temp_output_across_obj(
        self,
        inference_state,
        frame_idx,
        is_cond,
        run_mem_encoder,
        consolidate_at_video_res=False,
    ):
        """
        Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
        a frame into a single output for all objects, including
        1) fill any missing objects either from `output_dict_per_obj` (if they exist in
           `output_dict_per_obj` for this frame) or leave them as placeholder values
           (if they don't exist in `output_dict_per_obj` for this frame);
        2) if specified, rerun the memory encoder after applying non-overlapping
           constraints on the object scores.
        """
        batch_size = self._get_obj_num(inference_state)
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
        # Optionally, we allow consolidating the temporary outputs at the original
        # video resolution (to provide a better editing experience for mask prompts).
        if consolidate_at_video_res:
            assert not run_mem_encoder, "memory encoder cannot run at video resolution"
            consolidated_H = inference_state["video_height"]
            consolidated_W = inference_state["video_width"]
            consolidated_mask_key = "pred_masks_video_res"
        else:
            consolidated_H = consolidated_W = self.image_size // 4
            consolidated_mask_key = "pred_masks"

        # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
        # will be added when rerunning the memory encoder after applying non-overlapping
        # constraints to object scores. Its mask scores are prefilled with a large
        # negative value (NO_OBJ_SCORE) to represent missing objects.
        consolidated_out = {
            "maskmem_features": None,
            "maskmem_pos_enc": None,
            consolidated_mask_key: torch.full(
                size=(batch_size, 1, consolidated_H, consolidated_W),
                fill_value=NO_OBJ_SCORE,
                dtype=torch.float32,
                device=inference_state["storage_device"],
            ),
            "obj_ptr": torch.full(
                size=(batch_size, self.hidden_dim),
                fill_value=NO_OBJ_SCORE,
                dtype=torch.float32,
                device=inference_state["device"],
            ),
        }
        empty_mask_ptr = None
        for obj_idx in range(batch_size):
            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            out = obj_temp_output_dict[storage_key].get(frame_idx, None)
            # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
            # we fall back and look up its previous output in "output_dict_per_obj"
            # (checking both "cond_frame_outputs" and "non_cond_frame_outputs").
            if out is None:
                out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
            if out is None:
                out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
            # If the object doesn't appear in "output_dict_per_obj" either, we skip it
            # and leave its mask scores as the default NO_OBJ_SCORE placeholder, and set
            # its object pointer to a dummy pointer.
            if out is None:
                # A dummy object pointer is only needed when rerunning the memory encoder;
                # before that, the consolidated mask scores are enough for the
                # interactive outputs at the resolutions used in this function.
                if run_mem_encoder:
                    if empty_mask_ptr is None:
                        empty_mask_ptr = self._get_empty_mask_ptr(
                            inference_state, frame_idx
                        )
                    # fill the object pointer with a dummy pointer (based on an empty mask)
                    consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
                continue
            # Add the temporary object output mask to the consolidated output mask
            obj_mask = out["pred_masks"]
            consolidated_pred_masks = consolidated_out[consolidated_mask_key]
            if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
                consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
            else:
                # Resize first if the temporary object mask has a different resolution
                resized_obj_mask = torch.nn.functional.interpolate(
                    obj_mask,
                    size=consolidated_pred_masks.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                )
                consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
            consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]

        # Optionally, apply non-overlapping constraints on the consolidated scores
        # and rerun the memory encoder
        if run_mem_encoder:
            device = inference_state["device"]
            high_res_masks = torch.nn.functional.interpolate(
                consolidated_out["pred_masks"].to(device, non_blocking=True),
                size=(self.image_size, self.image_size),
                mode="bilinear",
                align_corners=False,
            )
            if self.non_overlap_masks_for_mem_enc:
                high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
            maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
                inference_state=inference_state,
                frame_idx=frame_idx,
                batch_size=batch_size,
                high_res_masks=high_res_masks,
                is_mask_from_pts=True,  # these frames are what the user interacted with
            )
            consolidated_out["maskmem_features"] = maskmem_features
            consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc

        return consolidated_out

    def _get_empty_mask_ptr(self, inference_state, frame_idx):
        """Get a dummy object pointer based on an empty mask on the current frame."""
        # A dummy (empty) mask with a single object
        batch_size = 1
        mask_inputs = torch.zeros(
            (batch_size, 1, self.image_size, self.image_size),
            dtype=torch.float32,
            device=inference_state["device"],
        )

        # Retrieve correct image features
        (
            _,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)

        # Feed the empty mask and image features above to get a dummy object pointer
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=True,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=None,
            mask_inputs=mask_inputs,
            output_dict={},
            num_frames=inference_state["num_frames"],
            track_in_reverse=False,
            run_mem_encoder=False,
            prev_sam_mask_logits=None,
        )
        return current_out["obj_ptr"]

    @torch.inference_mode()
    def propagate_in_video_preflight(self, inference_state):
        """Prepare inference_state and consolidate temporary outputs before tracking."""
        # Tracking has started and we don't allow adding new objects until the session is reset.
        inference_state["tracking_has_started"] = True
        batch_size = self._get_obj_num(inference_state)

        # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
        # add them into "output_dict".
        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
        output_dict = inference_state["output_dict"]
        # "consolidated_frame_inds" contains the indices of those frames where consolidated
        # temporary outputs have been added (either in this call or any previous calls
        # to `propagate_in_video_preflight`).
        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
        for is_cond in [False, True]:
            # Separately consolidate conditioning and non-conditioning temp outputs
            storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
            # Find all the frames that contain temporary outputs for any objects
            # (these should be the frames that just received clicks or mask inputs
            # via `add_new_points` or `add_new_mask`)
            temp_frame_inds = set()
            for obj_temp_output_dict in temp_output_dict_per_obj.values():
                temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
            consolidated_frame_inds[storage_key].update(temp_frame_inds)
            # consolidate the temporary output across all objects on this frame
            for frame_idx in temp_frame_inds:
                consolidated_out = self._consolidate_temp_output_across_obj(
                    inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
                )
                # merge them into "output_dict" and also create per-object slices
                output_dict[storage_key][frame_idx] = consolidated_out
                self._add_output_per_object(
                    inference_state, frame_idx, consolidated_out, storage_key
                )
                clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
                    self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
                )
                if clear_non_cond_mem:
                    # clear non-conditioning memory of the surrounding frames
                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)

            # clear temporary outputs in `temp_output_dict_per_obj`
            for obj_temp_output_dict in temp_output_dict_per_obj.values():
                obj_temp_output_dict[storage_key].clear()

        # edge case: if an output is added to "cond_frame_outputs", we remove any prior
        # output on the same frame in "non_cond_frame_outputs"
        for frame_idx in output_dict["cond_frame_outputs"]:
            output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
        for obj_output_dict in inference_state["output_dict_per_obj"].values():
            for frame_idx in obj_output_dict["cond_frame_outputs"]:
                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
        for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
            assert frame_idx in output_dict["cond_frame_outputs"]
            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)

        # Make sure the frame indices in "consolidated_frame_inds" are exactly those
        # frames with either point or mask inputs (which should hold under a correct workflow).
        all_consolidated_frame_inds = (
            consolidated_frame_inds["cond_frame_outputs"]
            | consolidated_frame_inds["non_cond_frame_outputs"]
        )
        input_frames_inds = set()
        for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
            input_frames_inds.update(point_inputs_per_frame.keys())
        for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
            input_frames_inds.update(mask_inputs_per_frame.keys())
        assert all_consolidated_frame_inds == input_frames_inds

    @torch.inference_mode()
    def propagate_in_video(
        self,
        inference_state,
        start_frame_idx=None,
        max_frame_num_to_track=None,
        reverse=False,
    ):
        """Propagate the input points across frames to track in the entire video."""
        self.propagate_in_video_preflight(inference_state)

        output_dict = inference_state["output_dict"]
        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
        obj_ids = inference_state["obj_ids"]
        num_frames = inference_state["num_frames"]
        batch_size = self._get_obj_num(inference_state)
        if len(output_dict["cond_frame_outputs"]) == 0:
            raise RuntimeError("No points are provided; please add points first")
        clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
            self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
        )

        # set the start index, end index, and processing order
        if start_frame_idx is None:
            # default: start from the earliest frame with input points
            start_frame_idx = min(output_dict["cond_frame_outputs"])
        if max_frame_num_to_track is None:
            # default: track all the frames in the video
            max_frame_num_to_track = num_frames
        if reverse:
            end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
            if start_frame_idx > 0:
                processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
            else:
                processing_order = []  # skip reverse tracking if starting from frame 0
        else:
            end_frame_idx = min(
                start_frame_idx + max_frame_num_to_track, num_frames - 1
            )
            processing_order = range(start_frame_idx, end_frame_idx + 1)

        for frame_idx in tqdm(processing_order, desc="propagate in video"):
            # We skip those frames already in consolidated outputs (these are frames
            # that received input clicks or mask). Note that we cannot directly run
            # batched forward on them via `_run_single_frame_inference` because the
            # number of clicks on each object might be different.
            if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
                storage_key = "cond_frame_outputs"
                current_out = output_dict[storage_key][frame_idx]
                pred_masks = current_out["pred_masks"]
                if clear_non_cond_mem:
                    # clear non-conditioning memory of the surrounding frames
                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
            elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
                storage_key = "non_cond_frame_outputs"
                current_out = output_dict[storage_key][frame_idx]
                pred_masks = current_out["pred_masks"]
            else:
                storage_key = "non_cond_frame_outputs"
                current_out, pred_masks = self._run_single_frame_inference(
                    inference_state=inference_state,
                    output_dict=output_dict,
                    frame_idx=frame_idx,
                    batch_size=batch_size,
                    is_init_cond_frame=False,
                    point_inputs=None,
                    mask_inputs=None,
                    reverse=reverse,
                    run_mem_encoder=True,
                )
                output_dict[storage_key][frame_idx] = current_out
            # Create slices of per-object outputs for subsequent interaction with each
            # individual object after tracking.
            self._add_output_per_object(
                inference_state, frame_idx, current_out, storage_key
            )
            inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}

            # Resize the output mask to the original video resolution (using the scores
            # already on the compute device to avoid an extra conversion)
            _, video_res_masks = self._get_orig_video_res_output(
                inference_state, pred_masks
            )
            yield frame_idx, obj_ids, video_res_masks
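
    # Example (a minimal sketch, not part of the class API): after all clicks/masks have
    # been added, propagation is typically consumed as a generator, e.g.:
    #
    #   video_segments = {}
    #   for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
    #       video_segments[frame_idx] = {
    #           obj_id: (masks[i] > 0.0).cpu().numpy()
    #           for i, obj_id in enumerate(obj_ids)
    #       }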

    def _add_output_per_object(
        self, inference_state, frame_idx, current_out, storage_key
    ):
        """
        Split a multi-object output into per-object output slices and add them into
        `output_dict_per_obj`. The resulting slices share the same tensor storage.
        """
        maskmem_features = current_out["maskmem_features"]
        assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)

        maskmem_pos_enc = current_out["maskmem_pos_enc"]
        assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)

        output_dict_per_obj = inference_state["output_dict_per_obj"]
        for obj_idx, obj_output_dict in output_dict_per_obj.items():
            obj_slice = slice(obj_idx, obj_idx + 1)
            obj_out = {
                "maskmem_features": None,
                "maskmem_pos_enc": None,
                "pred_masks": current_out["pred_masks"][obj_slice],
                "obj_ptr": current_out["obj_ptr"][obj_slice],
            }
            if maskmem_features is not None:
                obj_out["maskmem_features"] = maskmem_features[obj_slice]
            if maskmem_pos_enc is not None:
                obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
            obj_output_dict[storage_key][frame_idx] = obj_out

    @torch.inference_mode()
    def reset_state(self, inference_state):
        """Remove all input points or masks in all frames throughout the video."""
        self._reset_tracking_results(inference_state)
        # Remove all object ids
        inference_state["obj_id_to_idx"].clear()
        inference_state["obj_idx_to_id"].clear()
        inference_state["obj_ids"].clear()
        inference_state["point_inputs_per_obj"].clear()
        inference_state["mask_inputs_per_obj"].clear()
        inference_state["output_dict_per_obj"].clear()
        inference_state["temp_output_dict_per_obj"].clear()

    def _reset_tracking_results(self, inference_state):
        """Reset all tracking inputs and results across the videos."""
        for v in inference_state["point_inputs_per_obj"].values():
            v.clear()
        for v in inference_state["mask_inputs_per_obj"].values():
            v.clear()
        for v in inference_state["output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        for v in inference_state["temp_output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        inference_state["output_dict"]["cond_frame_outputs"].clear()
        inference_state["output_dict"]["non_cond_frame_outputs"].clear()
        inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
        inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"].clear()

    def _get_image_feature(self, inference_state, frame_idx, batch_size):
        """Compute the image features on a given frame."""
        # Look up in the cache first
        image, backbone_out = inference_state["cached_features"].get(
            frame_idx, (None, None)
        )
        if backbone_out is None:
            # Cache miss -- we will run inference on a single image
            image = (
                inference_state["images"][frame_idx]
                .to(inference_state["device"])
                .float()
                .unsqueeze(0)
            )
            backbone_out = self.forward_image(image)
            # Cache the most recent frame's feature (for repeated interactions
            # with the same frame; an LRU cache could hold more frames).
            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}

        # expand the features to have the same dimension as the number of objects
        expanded_image = image.expand(batch_size, -1, -1, -1)
        expanded_backbone_out = {
            "backbone_fpn": backbone_out["backbone_fpn"].copy(),
            "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
        }
        for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
            expanded_backbone_out["backbone_fpn"][i] = feat.expand(
                batch_size, -1, -1, -1
            )
        for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
            pos = pos.expand(batch_size, -1, -1, -1)
            expanded_backbone_out["vision_pos_enc"][i] = pos

        features = self._prepare_backbone_features(expanded_backbone_out)
        features = (expanded_image,) + features
        return features

    def _run_single_frame_inference(
        self,
        inference_state,
        output_dict,
        frame_idx,
        batch_size,
        is_init_cond_frame,
        point_inputs,
        mask_inputs,
        reverse,
        run_mem_encoder,
        prev_sam_mask_logits=None,
    ):
        """Run tracking on a single frame based on current inputs and previous memory."""
        # Retrieve correct image features
        (
            _,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)

        # point and mask should not appear as input simultaneously on the same frame
        assert point_inputs is None or mask_inputs is None
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            output_dict=output_dict,
            num_frames=inference_state["num_frames"],
            track_in_reverse=reverse,
            run_mem_encoder=run_mem_encoder,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )

        # optionally offload the output to the storage device to save memory
        storage_device = inference_state["storage_device"]
        maskmem_features = current_out["maskmem_features"]
        if maskmem_features is not None:
            maskmem_features = maskmem_features.to(torch.bfloat16)
            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        pred_masks_gpu = current_out["pred_masks"]
        # potentially fill holes in the predicted masks
        if self.fill_hole_area > 0:
            pred_masks_gpu = fill_holes_in_mask_scores(
                pred_masks_gpu, self.fill_hole_area
            )
        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
        # the object pointer is a small tensor, so we keep it on the compute device
        obj_ptr = current_out["obj_ptr"]
        # make a compact version of this frame's output to reduce the state size
        compact_current_out = {
            "maskmem_features": maskmem_features,
            "maskmem_pos_enc": maskmem_pos_enc,
            "pred_masks": pred_masks,
            "obj_ptr": obj_ptr,
        }
        return compact_current_out, pred_masks_gpu

    def _run_memory_encoder(
        self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
    ):
        """
        Run the memory encoder on `high_res_masks`. This is usually after applying
        non-overlapping constraints to object scores. Since their scores changed, their
        memory also needs to be computed again with the memory encoder.
        """
        # Retrieve correct image features
        _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
            inference_state, frame_idx, batch_size
        )
        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
            current_vision_feats=current_vision_feats,
            feat_sizes=feat_sizes,
            pred_masks_high_res=high_res_masks,
            is_mask_from_pts=is_mask_from_pts,
        )

        # optionally offload the output to the storage device to save memory
        storage_device = inference_state["storage_device"]
        maskmem_features = maskmem_features.to(torch.bfloat16)
        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(
            inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
        )
        return maskmem_features, maskmem_pos_enc

    def _get_maskmem_pos_enc(self, inference_state, current_out):
        """
        `maskmem_pos_enc` is the same across frames and objects, so we cache it as
        a constant in the inference session to reduce session storage size.
        """
        model_constants = inference_state["constants"]
        # "out_maskmem_pos_enc" should be either a list of tensors or None
        out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
        if out_maskmem_pos_enc is not None:
            if "maskmem_pos_enc" not in model_constants:
                assert isinstance(out_maskmem_pos_enc, list)
                # only take the slice for one object, since it's the same across objects
                maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
            else:
                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
            # expand the cached maskmem_pos_enc to the actual batch size
            batch_size = out_maskmem_pos_enc[0].size(0)
            expanded_maskmem_pos_enc = [
                x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
            ]
        else:
            expanded_maskmem_pos_enc = None
        return expanded_maskmem_pos_enc

    def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
        """
        Remove the non-conditioning memory around the input frame. When users provide
        correction clicks, the surrounding frames' non-conditioning memories can still
        contain outdated object appearance information and could confuse the model.

        This method clears those non-conditioning memories surrounding the interacted
        frame to avoid giving the model both old and new information about the object.
        """
        r = self.memory_temporal_stride_for_eval
        frame_idx_begin = frame_idx - r * self.num_maskmem
        frame_idx_end = frame_idx + r * self.num_maskmem
        output_dict = inference_state["output_dict"]
        non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
        for t in range(frame_idx_begin, frame_idx_end + 1):
            non_cond_frame_outputs.pop(t, None)
            for obj_output_dict in inference_state["output_dict_per_obj"].values():
                obj_output_dict["non_cond_frame_outputs"].pop(t, None)