|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib |
|
|
import logging |
|
|
import os |
|
|
import uuid |
|
|
from pathlib import Path |
|
|
from threading import Lock |
|
|
from typing import Any, Dict, Generator, List |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from app_conf import APP_ROOT, MODEL_SIZE |
|
|
from inference.data_types import ( |
|
|
AddMaskRequest, |
|
|
AddPointsRequest, |
|
|
CancelPorpagateResponse, |
|
|
CancelPropagateInVideoRequest, |
|
|
ClearPointsInFrameRequest, |
|
|
ClearPointsInVideoRequest, |
|
|
ClearPointsInVideoResponse, |
|
|
CloseSessionRequest, |
|
|
CloseSessionResponse, |
|
|
Mask, |
|
|
PropagateDataResponse, |
|
|
PropagateDataValue, |
|
|
PropagateInVideoRequest, |
|
|
RemoveObjectRequest, |
|
|
RemoveObjectResponse, |
|
|
StartSessionRequest, |
|
|
StartSessionResponse, |
|
|
) |
|
|
from pycocotools.mask import decode as decode_masks, encode as encode_masks |
|
|
from sam2.build_sam import build_sam2_video_predictor |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class InferenceAPI:
    """
    Session-based inference service wrapping a SAM 2 video predictor.

    Each session owns one predictor inference state (one video). All predictor
    calls are serialized through ``self.inference_lock``, so at most one request
    runs inference at a time. Masks are exchanged with clients as COCO RLE.
    """

    def __init__(self) -> None:
        super().__init__()

        # session_id -> {"canceled": bool, "state": predictor inference state}
        self.session_states: Dict[str, Any] = {}
        # Logit threshold above which a pixel is treated as foreground.
        self.score_thresh = 0

        # Pick checkpoint/config from the configured model size; any
        # unrecognized value falls back to the base-plus model.
        if MODEL_SIZE == "tiny":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_tiny.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        elif MODEL_SIZE == "small":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_small.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
        elif MODEL_SIZE == "large":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_large.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
        else:  # default: base_plus
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_base_plus.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

        # Device selection: env var override > CUDA > MPS > CPU.
        force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1"
        if force_cpu_device:
            logger.info("forcing CPU device for SAM 2 demo")
        if torch.cuda.is_available() and not force_cpu_device:
            device = torch.device("cuda")
        elif torch.backends.mps.is_available() and not force_cpu_device:
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        logger.info(f"using device: {device}")

        if device.type == "cuda":
            # Enable TF32 on Ampere (compute capability >= 8) GPUs for speed.
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
        elif device.type == "mps":
            logging.warning(
                "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
                "give numerically different outputs and sometimes degraded performance on MPS. "
                "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
            )

        self.device = device
        self.predictor = build_sam2_video_predictor(
            model_cfg, checkpoint, device=device
        )
        # Serializes all predictor access (the predictor is not thread-safe).
        self.inference_lock = Lock()

    def autocast_context(self):
        """Return a bfloat16 autocast context on CUDA, else a no-op context."""
        if self.device.type == "cuda":
            return torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            return contextlib.nullcontext()

    def start_session(self, request: StartSessionRequest) -> StartSessionResponse:
        """Create a new session for the video at ``request.path`` and return its id."""
        with self.autocast_context(), self.inference_lock:
            session_id = str(uuid.uuid4())
            # On MPS, keep the video frames in CPU memory (presumably to reduce
            # MPS memory pressure — TODO confirm against predictor docs).
            offload_video_to_cpu = self.device.type == "mps"
            inference_state = self.predictor.init_state(
                request.path,
                offload_video_to_cpu=offload_video_to_cpu,
            )
            self.session_states[session_id] = {
                "canceled": False,
                "state": inference_state,
            }
            return StartSessionResponse(session_id=session_id)

    def close_session(self, request: CloseSessionRequest) -> CloseSessionResponse:
        """Drop the session's state; success is False if the session is unknown."""
        is_successful = self.__clear_session_state(request.session_id)
        return CloseSessionResponse(success=is_successful)

    def add_points(
        self, request: AddPointsRequest, test: str = ""
    ) -> PropagateDataResponse:
        """
        Add click points (and labels) for an object on a specific video frame
        and return the updated per-object masks for that frame.

        The ``test`` parameter is unused; it is kept only for backward
        compatibility with existing callers.
        """
        with self.autocast_context(), self.inference_lock:
            session = self.__get_session(request.session_id)
            inference_state = session["state"]

            frame_idx = request.frame_index
            obj_id = request.object_id
            points = request.points
            labels = request.labels
            clear_old_points = request.clear_old_points

            # Points arrive in absolute pixel coordinates, hence
            # normalize_coords=False.
            frame_idx, object_ids, masks = self.predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels,
                clear_old_points=clear_old_points,
                normalize_coords=False,
            )

            # Threshold logits into binary masks; drop the channel dim.
            masks_binary = (masks > self.score_thresh)[:, 0].cpu().numpy()

            rle_mask_list = self.__get_rle_mask_list(
                object_ids=object_ids, masks=masks_binary
            )

            return PropagateDataResponse(
                frame_index=frame_idx,
                results=rle_mask_list,
            )

    def add_mask(self, request: AddMaskRequest) -> PropagateDataResponse:
        """
        Add a mask for an object on a specific video frame.

        The request carries a COCO RLE-encoded binary mask of shape
        [H_im, W_im] (1 for foreground, 0 for background).
        Note: providing an input mask overwrites any previous input points
        on this frame.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            frame_idx = request.frame_index
            obj_id = request.object_id
            rle_mask = {
                "counts": request.mask.counts,
                "size": request.mask.size,
            }

            mask = decode_masks(rle_mask)

            logger.info(
                f"add mask on frame {frame_idx} in session {session_id}: {obj_id=}, {mask.shape=}"
            )
            session = self.__get_session(session_id)
            inference_state = session["state"]

            # Bug fix: the predictor is stored as ``self.predictor``;
            # ``self.model`` was never defined and raised AttributeError here.
            frame_idx, obj_ids, video_res_masks = self.predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                mask=torch.tensor(mask > 0),
            )
            masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()

            rle_mask_list = self.__get_rle_mask_list(
                object_ids=obj_ids, masks=masks_binary
            )

            return PropagateDataResponse(
                frame_index=frame_idx,
                results=rle_mask_list,
            )

    def clear_points_in_frame(
        self, request: ClearPointsInFrameRequest
    ) -> PropagateDataResponse:
        """
        Remove all input points in a specific frame and return the
        resulting per-object masks for that frame.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            frame_idx = request.frame_index
            obj_id = request.object_id

            logger.info(
                f"clear inputs on frame {frame_idx} in session {session_id}: {obj_id=}"
            )
            session = self.__get_session(session_id)
            inference_state = session["state"]
            frame_idx, obj_ids, video_res_masks = (
                self.predictor.clear_all_prompts_in_frame(
                    inference_state, frame_idx, obj_id
                )
            )
            masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()

            rle_mask_list = self.__get_rle_mask_list(
                object_ids=obj_ids, masks=masks_binary
            )

            return PropagateDataResponse(
                frame_index=frame_idx,
                results=rle_mask_list,
            )

    def clear_points_in_video(
        self, request: ClearPointsInVideoRequest
    ) -> ClearPointsInVideoResponse:
        """
        Remove all input points in all frames throughout the video.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            logger.info(f"clear all inputs across the video in session {session_id}")
            session = self.__get_session(session_id)
            inference_state = session["state"]
            self.predictor.reset_state(inference_state)
            return ClearPointsInVideoResponse(success=True)

    def remove_object(self, request: RemoveObjectRequest) -> RemoveObjectResponse:
        """
        Remove an object id from the tracking state and return the updated
        masks for every frame the removal affected.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            obj_id = request.object_id
            logger.info(f"remove object in session {session_id}: {obj_id=}")
            session = self.__get_session(session_id)
            inference_state = session["state"]
            new_obj_ids, updated_frames = self.predictor.remove_object(
                inference_state, obj_id
            )

            results = []
            for frame_index, video_res_masks in updated_frames:
                masks = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
                rle_mask_list = self.__get_rle_mask_list(
                    object_ids=new_obj_ids, masks=masks
                )
                results.append(
                    PropagateDataResponse(
                        frame_index=frame_index,
                        results=rle_mask_list,
                    )
                )

            return RemoveObjectResponse(results=results)

    def propagate_in_video(
        self, request: PropagateInVideoRequest
    ) -> Generator[PropagateDataResponse, None, None]:
        """
        Propagate existing input points in all frames to track the object across video.

        Yields one :class:`PropagateDataResponse` per tracked frame. Stops early
        (without error) if the session is canceled via
        :meth:`cancel_propagate_in_video`.
        """
        session_id = request.session_id
        start_frame_idx = request.start_frame_index
        # The demo always propagates in both directions with no frame limit.
        propagation_direction = "both"
        max_frame_num_to_track = None

        with self.autocast_context(), self.inference_lock:
            logger.info(
                f"propagate in video in session {session_id}: "
                f"{propagation_direction=}, {start_frame_idx=}, {max_frame_num_to_track=}"
            )

            try:
                session = self.__get_session(session_id)
                session["canceled"] = False

                inference_state = session["state"]
                if propagation_direction not in ["both", "forward", "backward"]:
                    raise ValueError(
                        f"invalid propagation direction: {propagation_direction}"
                    )

                # Forward pass first, then backward; the two passes are
                # identical except for the `reverse` flag.
                reverse_flags = []
                if propagation_direction in ["both", "forward"]:
                    reverse_flags.append(False)
                if propagation_direction in ["both", "backward"]:
                    reverse_flags.append(True)

                for reverse in reverse_flags:
                    for outputs in self.predictor.propagate_in_video(
                        inference_state=inference_state,
                        start_frame_idx=start_frame_idx,
                        max_frame_num_to_track=max_frame_num_to_track,
                        reverse=reverse,
                    ):
                        # Bail out as soon as a cancel request flips the flag.
                        if session["canceled"]:
                            return None

                        frame_idx, obj_ids, video_res_masks = outputs
                        masks_binary = (
                            (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
                        )

                        rle_mask_list = self.__get_rle_mask_list(
                            object_ids=obj_ids, masks=masks_binary
                        )

                        yield PropagateDataResponse(
                            frame_index=frame_idx,
                            results=rle_mask_list,
                        )
            finally:
                # Runs whether propagation finished, was canceled, or raised.
                logger.info(
                    f"propagation ended in session {session_id}; {self.__get_session_stats()}"
                )

    def cancel_propagate_in_video(
        self, request: CancelPropagateInVideoRequest
    ) -> CancelPorpagateResponse:
        """
        Request cancellation of an in-flight propagation for this session.

        NOTE(review): deliberately does not take ``inference_lock`` — the
        propagation generator holds it, and this flag is how it is told to stop.
        """
        session = self.__get_session(request.session_id)
        session["canceled"] = True
        return CancelPorpagateResponse(success=True)

    def __get_rle_mask_list(
        self, object_ids: List[int], masks: np.ndarray
    ) -> List[PropagateDataValue]:
        """
        Return a list of data values, i.e. list of object/mask combos.
        """
        return [
            self.__get_mask_for_object(object_id=object_id, mask=mask)
            for object_id, mask in zip(object_ids, masks)
        ]

    def __get_mask_for_object(
        self, object_id: int, mask: np.ndarray
    ) -> PropagateDataValue:
        """
        Create a data value for an object/mask combo.
        """
        # pycocotools requires a Fortran-contiguous uint8 array; counts come
        # back as bytes and are decoded to str for JSON serialization.
        mask_rle = encode_masks(np.array(mask, dtype=np.uint8, order="F"))
        mask_rle["counts"] = mask_rle["counts"].decode()
        return PropagateDataValue(
            object_id=object_id,
            mask=Mask(
                size=mask_rle["size"],
                counts=mask_rle["counts"],
            ),
        )

    def __get_session(self, session_id: str):
        """Return the session dict for ``session_id`` or raise RuntimeError."""
        session = self.session_states.get(session_id, None)
        if session is None:
            raise RuntimeError(
                f"Cannot find session {session_id}; it might have expired"
            )
        return session

    def __get_session_stats(self):
        """Get a statistics string for live sessions and their GPU usage."""
        # Removed an injected junk fragment ("Test String Here - -") that was
        # being concatenated into this log string.
        live_session_strs = [
            f"'{session_id}' ({session['state']['num_frames']} frames, "
            f"{len(session['state']['obj_ids'])} objects)"
            for session_id, session in self.session_states.items()
        ]
        session_stats_str = (
            f"live sessions: [{', '.join(live_session_strs)}], GPU memory: "
            f"{torch.cuda.memory_allocated() // 1024**2} MiB used and "
            f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
            f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
            f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
        )
        return session_stats_str

    def __clear_session_state(self, session_id: str) -> bool:
        """Remove the session from the registry; return True if it existed."""
        session = self.session_states.pop(session_id, None)
        if session is None:
            logger.warning(
                f"cannot close session {session_id} as it does not exist (it might have expired); "
                f"{self.__get_session_stats()}"
            )
            return False
        else:
            logger.info(f"removed session {session_id}; {self.__get_session_stats()}")
            return True
|
|
|