| |
|
|
| |
|
|
| """ |
| Base predictor class shared by SAM3 and SAM3.1 (multiplex) video predictors. |
| |
| Provides the common handle_request/handle_stream_request API and session management. |
| Subclasses only need to override methods where their behavior differs. |
| """ |
|
|
| import gc |
| import time |
| import uuid |
| from typing import Dict, List, Optional |
|
|
| import torch |
| from ..logger import get_logger |
| from ..model.device_utils import accelerator_autocast, empty_accelerator_cache |
|
|
| logger = get_logger(__name__) |
|
|
|
|
| class Sam3BasePredictor: |
| """ |
| Base class for SAM3 video predictors. Provides: |
| - Session management (start, reset, close) |
| - Request dispatch (handle_request / handle_stream_request) |
| - Common add_prompt / propagate_in_video / remove_object / reset_session / close_session |
| |
| Subclasses must set `self.model` and `self._all_inference_states` before use. |
| """ |
|
|
| def __init__(self): |
| |
| self.model = None |
| self._all_inference_states: Dict[str, dict] = {} |
|
|
| @staticmethod |
| def _bf16_autocast(): |
| return accelerator_autocast() |
|
|
| |
|
|
| @torch.inference_mode() |
| def handle_request(self, request): |
| """Dispatch a request based on its type.""" |
| request_type = request["type"] |
| if request_type == "start_session": |
| return self.start_session( |
| resource_path=request["resource_path"], |
| session_id=request.get("session_id", None), |
| offload_video_to_cpu=request.get("offload_video_to_cpu", False), |
| offload_state_to_cpu=request.get("offload_state_to_cpu", False), |
| cache_frame_outputs=request.get("cache_frame_outputs", True), |
| ) |
| elif request_type == "add_prompt": |
| return self.add_prompt( |
| session_id=request["session_id"], |
| frame_idx=request["frame_index"], |
| text=request.get("text", None), |
| points=request.get("points", None), |
| point_labels=request.get("point_labels", None), |
| clear_old_points=request.get("clear_old_points", True), |
| bounding_boxes=request.get("bounding_boxes", None), |
| bounding_box_labels=request.get("bounding_box_labels", None), |
| clear_old_boxes=request.get("clear_old_boxes", True), |
| output_prob_thresh=request.get( |
| "output_prob_thresh", |
| getattr(self, "default_output_prob_thresh", 0.5), |
| ), |
| obj_id=request.get("obj_id", None), |
| rel_coordinates=request.get("rel_coordinates", True), |
| preencoded_text_outputs=request.get("preencoded_text_outputs", None), |
| ) |
| elif request_type == "remove_object": |
| return self.remove_object( |
| session_id=request["session_id"], |
| frame_idx=request.get("frame_index", 0), |
| obj_id=request["obj_id"], |
| ) |
| elif request_type == "reset_session": |
| return self.reset_session(session_id=request["session_id"]) |
| elif request_type == "cancel_propagation": |
| return self.cancel_propagation(session_id=request["session_id"]) |
| elif request_type == "close_session": |
| return self.close_session( |
| session_id=request["session_id"], |
| run_gc_collect=request.get("run_gc_collect", True), |
| ) |
| else: |
| raise RuntimeError(f"invalid request type: {request_type}") |
|
|
| @torch.inference_mode() |
| def handle_stream_request(self, request): |
| """Dispatch a stream request based on its type.""" |
| request_type = request["type"] |
| if request_type == "propagate_in_video": |
| yield from self.propagate_in_video( |
| session_id=request["session_id"], |
| propagation_direction=request.get("propagation_direction", "both"), |
| start_frame_idx=request.get("start_frame_index", None), |
| max_frame_num_to_track=request.get("max_frame_num_to_track", None), |
| output_prob_thresh=request.get( |
| "output_prob_thresh", |
| getattr(self, "default_output_prob_thresh", 0.5), |
| ), |
| progress_callback=request.get("progress_callback", None), |
| ) |
| else: |
| raise RuntimeError(f"invalid request type: {request_type}") |
|
|
| |
|
|
| def start_session( |
| self, |
| resource_path, |
| session_id=None, |
| offload_video_to_cpu=False, |
| offload_state_to_cpu=False, |
| cache_frame_outputs=True, |
| ): |
| """Start a new inference session on a video directory or path.""" |
| init_kwargs = dict( |
| resource_path=resource_path, |
| offload_video_to_cpu=offload_video_to_cpu, |
| offload_state_to_cpu=offload_state_to_cpu, |
| ) |
| if hasattr(self, "async_loading_frames"): |
| init_kwargs["async_loading_frames"] = self.async_loading_frames |
| if hasattr(self, "video_loader_type"): |
| init_kwargs["video_loader_type"] = self.video_loader_type |
| import inspect |
|
|
| sig = inspect.signature(self.model.init_state) |
| filtered_kwargs = {k: v for k, v in init_kwargs.items() if k in sig.parameters} |
| inference_state = self.model.init_state(**filtered_kwargs) |
| inference_state["cache_frame_outputs"] = bool(cache_frame_outputs) |
|
|
| if not session_id: |
| session_id = str(uuid.uuid4()) |
| self._all_inference_states[session_id] = { |
| "state": inference_state, |
| "session_id": session_id, |
| "start_time": time.time(), |
| "last_use_time": time.time(), |
| } |
| logger.info(f"started new session {session_id}") |
| return {"session_id": session_id} |
|
|
| def add_prompt( |
| self, |
| session_id: str, |
| frame_idx: int, |
| text: Optional[str] = None, |
| points=None, |
| point_labels=None, |
| clear_old_points: bool = True, |
| bounding_boxes=None, |
| bounding_box_labels=None, |
| clear_old_boxes: bool = True, |
| output_prob_thresh: float = 0.5, |
| obj_id: Optional[int] = None, |
| rel_coordinates: bool = True, |
| preencoded_text_outputs=None, |
| ): |
| """Add text, box and/or point prompt on a specific video frame.""" |
| session = self._get_session(session_id) |
| inference_state = session["state"] |
| self._extend_expiration_time(session) |
|
|
| |
| if points is not None and not isinstance(points, torch.Tensor): |
| points = torch.tensor(points, dtype=torch.float32) |
| if point_labels is not None and not isinstance(point_labels, torch.Tensor): |
| point_labels = torch.tensor(point_labels, dtype=torch.int32) |
| if bounding_boxes is not None and not isinstance(bounding_boxes, torch.Tensor): |
| bounding_boxes = torch.tensor(bounding_boxes, dtype=torch.float32) |
| if bounding_box_labels is not None and not isinstance( |
| bounding_box_labels, torch.Tensor |
| ): |
| bounding_box_labels = torch.tensor(bounding_box_labels, dtype=torch.int32) |
|
|
| kwargs = dict( |
| inference_state=inference_state, |
| frame_idx=frame_idx, |
| text_str=text, |
| points=points, |
| point_labels=point_labels, |
| clear_old_points=clear_old_points, |
| boxes_xywh=bounding_boxes, |
| box_labels=bounding_box_labels, |
| clear_old_boxes=clear_old_boxes, |
| output_prob_thresh=output_prob_thresh, |
| rel_coordinates=rel_coordinates, |
| preencoded_text_outputs=preencoded_text_outputs, |
| ) |
| if obj_id is not None: |
| kwargs["obj_id"] = obj_id |
|
|
| |
| |
| import inspect |
|
|
| sig = inspect.signature(self.model.add_prompt) |
| valid_params = set(sig.parameters.keys()) |
| filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} |
|
|
| with self._bf16_autocast(): |
| frame_idx, outputs = self.model.add_prompt(**filtered_kwargs) |
| return {"frame_index": frame_idx, "outputs": outputs} |
|
|
| def remove_object( |
| self, |
| session_id: str, |
| frame_idx: int = 0, |
| obj_id: int = 0, |
| is_user_action: bool = True, |
| ): |
| """Remove an object from tracking.""" |
| session = self._get_session(session_id) |
| inference_state = session["state"] |
| self._extend_expiration_time(session) |
|
|
| result = self.model.remove_object( |
| inference_state, obj_id, frame_idx=frame_idx, is_user_action=is_user_action |
| ) |
| |
| if result is None or (isinstance(result, tuple) and result[1] is None): |
| import numpy as np |
|
|
| out_obj_ids = torch.zeros(0, dtype=torch.int64) |
| out_binary_masks = torch.zeros( |
| 0, |
| inference_state["orig_height"], |
| inference_state["orig_width"], |
| dtype=torch.bool, |
| ) |
| out_boxes_xywh = torch.zeros(0, 4, dtype=torch.float32) |
| outputs = { |
| "out_obj_ids": out_obj_ids.cpu().numpy(), |
| "out_boxes_xywh": out_boxes_xywh.cpu().numpy(), |
| "out_binary_masks": out_binary_masks.cpu().numpy(), |
| } |
| elif isinstance(result, tuple): |
| _, outputs = result |
| else: |
| outputs = result |
| return {"frame_index": frame_idx, "outputs": outputs} |
|
|
| def cancel_propagation(self, session_id): |
| """Cancel any ongoing propagation. No-op if not supported by the model.""" |
| session = self._get_session(session_id) |
| inference_state = session["state"] |
| self._extend_expiration_time(session) |
| if hasattr(self.model, "cancel_propagation"): |
| self.model.cancel_propagation(inference_state) |
| return {"is_success": True} |
|
|
| def propagate_in_video( |
| self, |
| session_id, |
| propagation_direction="both", |
| start_frame_idx=None, |
| max_frame_num_to_track=None, |
| output_prob_thresh=0.5, |
| **kwargs, |
| ): |
| """Propagate the added prompts to get results on all video frames.""" |
| try: |
| session = self._get_session(session_id) |
| inference_state = session["state"] |
| self._extend_expiration_time(session) |
| if propagation_direction not in ["both", "forward", "backward"]: |
| raise ValueError( |
| f"invalid propagation direction: {propagation_direction}" |
| ) |
|
|
| propagate_kwargs = dict( |
| inference_state=inference_state, |
| start_frame_idx=start_frame_idx, |
| max_frame_num_to_track=max_frame_num_to_track, |
| ) |
| |
| import inspect |
|
|
| sig = inspect.signature(self.model.propagate_in_video) |
| if "output_prob_thresh" in sig.parameters: |
| propagate_kwargs["output_prob_thresh"] = output_prob_thresh |
| for k, v in kwargs.items(): |
| if k in sig.parameters: |
| propagate_kwargs[k] = v |
|
|
| |
| with self._bf16_autocast(): |
| if propagation_direction in ["both", "forward"]: |
| for frame_idx, outputs in self.model.propagate_in_video( |
| **propagate_kwargs, |
| reverse=False, |
| ): |
| yield {"frame_index": frame_idx, "outputs": outputs} |
| |
| if propagation_direction in ["both", "backward"]: |
| for frame_idx, outputs in self.model.propagate_in_video( |
| **propagate_kwargs, |
| reverse=True, |
| ): |
| yield {"frame_index": frame_idx, "outputs": outputs} |
| finally: |
| logger.info(f"propagation ended in session {session_id}") |
|
|
| def reset_session(self, session_id): |
| """Reset the session to its initial state.""" |
| session = self._get_session(session_id) |
| inference_state = session["state"] |
| self._extend_expiration_time(session) |
| self.model.reset_state(inference_state) |
| return {"is_success": True} |
|
|
| def close_session(self, session_id, run_gc_collect=True): |
| """Close a session. Idempotent.""" |
| session = self._all_inference_states.pop(session_id, None) |
| if session is None: |
| logger.warning(f"cannot close session {session_id} as it does not exist") |
| else: |
| del session |
| if run_gc_collect: |
| gc.collect() |
| logger.info(f"removed session {session_id}") |
| return {"is_success": True} |
|
|
| def _get_session(self, session_id): |
| session = self._all_inference_states.get(session_id, None) |
| if session is None: |
| raise RuntimeError( |
| f"Cannot find session {session_id}; it might have expired" |
| ) |
| return session |
|
|
| def _extend_expiration_time(self, session): |
| """Update last-use time for session expiration tracking.""" |
| session["last_use_time"] = time.time() |
|
|
| def shutdown(self): |
| """Shutdown the predictor, release runtime caches, and move model weights off CUDA.""" |
| self._all_inference_states.clear() |
| self._exit_bf16_contexts() |
| model = self.model |
| if model is not None: |
| self._clear_cuda_runtime_caches(model) |
| if isinstance(model, torch.nn.Module): |
| model.to("cpu") |
| self.model = None |
| gc.collect() |
| empty_accelerator_cache() |
|
|
| def _exit_bf16_contexts(self): |
| seen = set() |
| for obj in [self, self.model]: |
| if obj is None: |
| continue |
| modules = obj.modules() if isinstance(obj, torch.nn.Module) else [obj] |
| for module in modules: |
| context = getattr(module, "bf16_context", None) |
| if context is None or id(context) in seen: |
| continue |
| seen.add(id(context)) |
| context.__exit__(None, None, None) |
| module.bf16_context = None |
|
|
| @staticmethod |
| def _clear_cuda_runtime_caches(model): |
| for module in model.modules(): |
| cache = getattr(module, "cache", None) |
| if isinstance(cache, dict): |
| cache.clear() |
| coord_cache = getattr(module, "coord_cache", None) |
| if isinstance(coord_cache, dict): |
| coord_cache.clear() |
| if hasattr(module, "compilable_cord_cache"): |
| module.compilable_cord_cache = None |
| if hasattr(module, "compilable_stored_size"): |
| module.compilable_stored_size = None |
| for name in ("freqs_cis", "freqs_cis_real", "freqs_cis_imag"): |
| tensor = getattr(module, name, None) |
| if torch.is_tensor(tensor) and tensor.is_cuda: |
| setattr(module, name, tensor.cpu()) |
|
|