|
|
|
|
|
|
|
|
from collections import OrderedDict |
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
|
|
|
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base |
|
|
from sam2.sam2_video_predictor import SAM2VideoPredictor as _SAM2VideoPredictor |
|
|
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores |
|
|
|
|
|
from sam_utils import load_video_frames_v2, load_video_frames |
|
|
|
|
|
|
|
|
class SAM2VideoPredictor(_SAM2VideoPredictor):
    """SAM2 video predictor with custom frame-loading entry points.

    Extends the upstream predictor with two state initializers:
    ``init_state`` loads frames from a video path via ``load_video_frames``,
    while ``init_state_v2`` accepts already-loaded frames via
    ``load_video_frames_v2``. Both delegate to a single private helper so the
    inference-state layout is defined in exactly one place (the original code
    duplicated it verbatim in both methods).

    NOTE(review): the device is hard-coded to ``torch.device("cuda")`` below,
    matching the original behavior — a CUDA-capable device is presumably
    required; confirm before running on CPU-only hosts.
    """

    @torch.inference_mode()
    def init_state(
        self,
        video_path,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False,
        frame_names=None,
    ):
        """Initialize an inference state from a video path.

        Args:
            video_path: Path to the video (or frame directory) passed through
                to ``load_video_frames``.
            offload_video_to_cpu: If True, keep the loaded frames on CPU.
            offload_state_to_cpu: If True, keep tracking-state tensors on CPU.
            async_loading_frames: If True, let the loader fetch frames
                asynchronously.
            frame_names: Optional explicit list of frame names for the loader.

        Returns:
            dict: A freshly initialized inference state (frame-0 features are
            precomputed as a warm-up).
        """
        images, video_height, video_width = load_video_frames(
            video_path=video_path,
            image_size=self.image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            async_loading_frames=async_loading_frames,
            frame_names=frame_names,
        )
        return self._finalize_inference_state(
            images,
            video_height,
            video_width,
            offload_video_to_cpu,
            offload_state_to_cpu,
        )

    @torch.inference_mode()
    def init_state_v2(
        self,
        frames,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False,
        frame_names=None,
    ):
        """Initialize an inference state from pre-loaded frames.

        Identical to :meth:`init_state` except that the frame data is supplied
        directly and handed to ``load_video_frames_v2``.

        Args:
            frames: Pre-loaded frame data passed through to
                ``load_video_frames_v2``.
            offload_video_to_cpu: If True, keep the loaded frames on CPU.
            offload_state_to_cpu: If True, keep tracking-state tensors on CPU.
            async_loading_frames: If True, let the loader fetch frames
                asynchronously.
            frame_names: Optional explicit list of frame names for the loader.

        Returns:
            dict: A freshly initialized inference state (frame-0 features are
            precomputed as a warm-up).
        """
        images, video_height, video_width = load_video_frames_v2(
            frames=frames,
            image_size=self.image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            async_loading_frames=async_loading_frames,
            frame_names=frame_names,
        )
        return self._finalize_inference_state(
            images,
            video_height,
            video_width,
            offload_video_to_cpu,
            offload_state_to_cpu,
        )

    def _finalize_inference_state(
        self,
        images,
        video_height,
        video_width,
        offload_video_to_cpu,
        offload_state_to_cpu,
    ):
        """Build the inference-state dict shared by both initializers.

        Populates every state key exactly as the original duplicated code did,
        then warms up the image encoder on frame 0.

        Args:
            images: Loaded frame tensor/sequence (``len(images)`` frames).
            video_height: Original video height in pixels.
            video_width: Original video width in pixels.
            offload_video_to_cpu: Whether frames are kept on CPU.
            offload_state_to_cpu: Whether state tensors are kept on CPU.

        Returns:
            dict: The initialized inference state.
        """
        inference_state = {}
        inference_state["images"] = images
        inference_state["num_frames"] = len(images)
        # Whether to offload the video frames to CPU memory.
        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
        # Whether to offload the inference state tensors to CPU memory.
        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
        # Original resolution, kept so outputs can be mapped back to it.
        inference_state["video_height"] = video_height
        inference_state["video_width"] = video_width
        # Compute device is hard-coded to CUDA (preserved from the original).
        inference_state["device"] = torch.device("cuda")
        # State tensors live on CPU when offloading, otherwise on the GPU.
        if offload_state_to_cpu:
            inference_state["storage_device"] = torch.device("cpu")
        else:
            inference_state["storage_device"] = torch.device("cuda")
        # Per-object point and mask inputs.
        inference_state["point_inputs_per_obj"] = {}
        inference_state["mask_inputs_per_obj"] = {}
        # Cached visual features for recently visited frames.
        inference_state["cached_features"] = {}
        # Values that do not change across frames (held once).
        inference_state["constants"] = {}
        # Mapping between client-side object ids and model-side indices.
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        # Tracking outputs on each frame, merged across all objects.
        inference_state["output_dict"] = {
            "cond_frame_outputs": {},
            "non_cond_frame_outputs": {},
        }
        # Per-object views of the tracking outputs.
        inference_state["output_dict_per_obj"] = {}
        # Temporary per-object outputs for frames with fresh inputs.
        inference_state["temp_output_dict_per_obj"] = {}
        # Frame indices that already hold consolidated outputs.
        inference_state["consolidated_frame_inds"] = {
            "cond_frame_outputs": set(),
            "non_cond_frame_outputs": set(),
        }
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"] = {}
        # Warm up the backbone and cache features for frame 0.
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state