# Copyright (c) 2025. Your modifications here.
# A wrapper for sam2 functions
from collections import OrderedDict

import torch
from tqdm import tqdm

from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.sam2_video_predictor import SAM2VideoPredictor as _SAM2VideoPredictor
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores

from sam_utils import load_video_frames_v2, load_video_frames
class SAM2VideoPredictor(_SAM2VideoPredictor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def init_state(
        self,
        video_path,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False,
        frame_names=None
    ):
        """Initialize an inference state from a video path."""
        images, video_height, video_width = load_video_frames(
            video_path=video_path,
            image_size=self.image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            async_loading_frames=async_loading_frames,
            frame_names=frame_names
        )
        inference_state = {}
        inference_state["images"] = images
        inference_state["num_frames"] = len(images)
        # whether to offload the video frames to CPU memory
        # turning on this option saves the GPU memory with only a very small overhead
        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
        # whether to offload the inference state to CPU memory
        # turning on this option saves the GPU memory at the cost of a lower tracking fps
        # (e.g. in a test case of the 768x768 model, fps dropped from 27 to 24 when tracking one object
        # and from 24 to 21 when tracking two objects)
        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
        # the original video height and width, used for resizing final output scores
        inference_state["video_height"] = video_height
        inference_state["video_width"] = video_width
        inference_state["device"] = torch.device("cuda")
        if offload_state_to_cpu:
            inference_state["storage_device"] = torch.device("cpu")
        else:
            inference_state["storage_device"] = torch.device("cuda")
        # inputs on each frame
        inference_state["point_inputs_per_obj"] = {}
        inference_state["mask_inputs_per_obj"] = {}
        # visual features on a small number of recently visited frames for quick interactions
        inference_state["cached_features"] = {}
        # values that don't change across frames (so we only need to hold one copy of them)
        inference_state["constants"] = {}
        # mapping between client-side object id and model-side object index
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        # Storage for the model's tracking results and states on each frame
        inference_state["output_dict"] = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        # Slice (view) of each object's tracking results, sharing the same memory with "output_dict"
        inference_state["output_dict_per_obj"] = {}
        # Temporary storage for new outputs when the user interacts with a frame
        # to add clicks or a mask (it's merged into "output_dict" before propagation starts)
        inference_state["temp_output_dict_per_obj"] = {}
        # Frames that already hold consolidated outputs from click or mask inputs
        # (we directly use their consolidated outputs during tracking)
        inference_state["consolidated_frame_inds"] = {
            "cond_frame_outputs": set(),  # set containing frame indices
            "non_cond_frame_outputs": set(),  # set containing frame indices
        }
        # metadata for each tracking frame (e.g. which direction it's tracked)
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"] = {}
        # Warm up the visual backbone and cache the image feature on frame 0
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state
    def init_state_v2(
        self,
        frames,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False,
        frame_names=None
    ):
        """Initialize an inference state from frames that are already loaded in memory."""
        images, video_height, video_width = load_video_frames_v2(
            frames=frames,
            image_size=self.image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            async_loading_frames=async_loading_frames,
            frame_names=frame_names
        )
        inference_state = {}
        inference_state["images"] = images
        inference_state["num_frames"] = len(images)
        # whether to offload the video frames to CPU memory
        # turning on this option saves the GPU memory with only a very small overhead
        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
        # whether to offload the inference state to CPU memory
        # turning on this option saves the GPU memory at the cost of a lower tracking fps
        # (e.g. in a test case of the 768x768 model, fps dropped from 27 to 24 when tracking one object
        # and from 24 to 21 when tracking two objects)
        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
        # the original video height and width, used for resizing final output scores
        inference_state["video_height"] = video_height
        inference_state["video_width"] = video_width
        inference_state["device"] = torch.device("cuda")
        if offload_state_to_cpu:
            inference_state["storage_device"] = torch.device("cpu")
        else:
            inference_state["storage_device"] = torch.device("cuda")
        # inputs on each frame
        inference_state["point_inputs_per_obj"] = {}
        inference_state["mask_inputs_per_obj"] = {}
        # visual features on a small number of recently visited frames for quick interactions
        inference_state["cached_features"] = {}
        # values that don't change across frames (so we only need to hold one copy of them)
        inference_state["constants"] = {}
        # mapping between client-side object id and model-side object index
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        # Storage for the model's tracking results and states on each frame
        inference_state["output_dict"] = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        # Slice (view) of each object's tracking results, sharing the same memory with "output_dict"
        inference_state["output_dict_per_obj"] = {}
        # Temporary storage for new outputs when the user interacts with a frame
        # to add clicks or a mask (it's merged into "output_dict" before propagation starts)
        inference_state["temp_output_dict_per_obj"] = {}
        # Frames that already hold consolidated outputs from click or mask inputs
        # (we directly use their consolidated outputs during tracking)
        inference_state["consolidated_frame_inds"] = {
            "cond_frame_outputs": set(),  # set containing frame indices
            "non_cond_frame_outputs": set(),  # set containing frame indices
        }
        # metadata for each tracking frame (e.g. which direction it's tracked)
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"] = {}
        # Warm up the visual backbone and cache the image feature on frame 0
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state
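
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the wrapper). It assumes the
# wrapper is obtained by building the upstream predictor with SAM 2's
# `build_sam2_video_predictor` and re-pointing its class to this subclass,
# which only adds methods and no new state; the config/checkpoint paths below
# are placeholders, and the frame format accepted by `load_video_frames_v2`
# (a list of HxWx3 uint8 RGB arrays) is an assumption. The prompting and
# propagation calls are the methods inherited from the upstream
# SAM2VideoPredictor; the actual Space may wire this up differently.
if __name__ == "__main__":
    import numpy as np
    from sam2.build_sam import build_sam2_video_predictor

    # Placeholder config/checkpoint paths; substitute your own.
    predictor = build_sam2_video_predictor(
        "configs/sam2.1/sam2.1_hiera_l.yaml", "checkpoints/sam2.1_hiera_large.pt"
    )
    predictor.__class__ = SAM2VideoPredictor  # reuse the built weights with this wrapper

    # Dummy frames standing in for a decoded video clip.
    frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(8)]

    with torch.inference_mode():
        state = predictor.init_state_v2(frames, offload_video_to_cpu=True)

        # Prompt object 1 with a single positive click on frame 0.
        predictor.add_new_points_or_box(
            inference_state=state,
            frame_idx=0,
            obj_id=1,
            points=[[320, 240]],
            labels=[1],
        )

        # Propagate the prompt through the remaining frames.
        for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
            masks = (mask_logits > 0.0).cpu().numpy()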