|
|
""" |
|
|
SAM2 Wrapper for Video Mask Tracking |
|
|
Handles mask generation and propagation through video |
|
|
""" |
|
|
|
|
|
import os |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
from PIL import Image |
|
|
from pathlib import Path |
|
|
from typing import List, Tuple |
|
|
import tempfile |
|
|
import shutil |
|
|
|
|
|
from sam2.build_sam import build_sam2_video_predictor |
|
|
|
|
|
|
|
|
class SAM2VideoTracker:
    """Point-prompted object segmentation and tracking with SAM2's video predictor.

    Frames are staged to a temporary JPEG directory because SAM2's
    ``init_state`` consumes a directory of numbered frames rather than
    in-memory arrays; the directory is always removed afterwards.
    """

    def __init__(self, checkpoint_path, config_file, device="cuda"):
        """
        Initialize SAM2 video tracker.

        Args:
            checkpoint_path: Path to SAM2 checkpoint
            config_file: Path to SAM2 config file
            device: Device to run on
        """
        self.device = device
        self.predictor = build_sam2_video_predictor(
            config_file=config_file,
            ckpt_path=checkpoint_path,
            device=device
        )
        print(f"SAM2 video tracker initialized on {device}")

    @staticmethod
    def _save_frames(frames, frames_dir: Path) -> None:
        """Write RGB frames as zero-padded JPEGs — the naming SAM2's init_state expects."""
        for i, frame in enumerate(frames):
            Image.fromarray(frame).save(frames_dir / f"{i:05d}.jpg", quality=95)

    def _prompt_first_frame(self, frames_dir: Path, points, labels):
        """Init inference state from *frames_dir* and add point prompts on frame 0.

        All prompts are attached to a single tracked object (obj_id=1).

        Returns:
            (inference_state, object_ids, mask_logits) as produced by SAM2.
        """
        inference_state = self.predictor.init_state(video_path=str(frames_dir))
        _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=1,
            points=np.array(points, dtype=np.float32),
            labels=np.array(labels, dtype=np.int32),
        )
        return inference_state, out_obj_ids, out_mask_logits

    @staticmethod
    def _logits_to_mask(logits) -> np.ndarray:
        """Threshold mask logits at 0 and convert to a uint8 {0, 255} (H, W) mask."""
        mask = (logits > 0.0).cpu().numpy()
        return (mask.squeeze() * 255).astype(np.uint8)

    def track_video(self, frames: List[np.ndarray], points: List[List[int]],
                    labels: List[int]) -> List[np.ndarray]:
        """
        Track object through video using SAM2.

        Args:
            frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
            points: List of [x, y] coordinates for prompts
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            masks: List of numpy arrays, [(H,W)]*n, uint8 binary masks
        """
        temp_dir = Path(tempfile.mkdtemp())
        frames_dir = temp_dir / "frames"
        frames_dir.mkdir(exist_ok=True)

        try:
            print(f"Saving {len(frames)} frames to temporary directory...")
            self._save_frames(frames, frames_dir)

            print("Initializing SAM2 inference state...")
            print(f"Adding {len(points)} point prompts on first frame...")
            inference_state, _, _ = self._prompt_first_frame(frames_dir, points, labels)

            print("Propagating masks through video...")
            masks = []
            for frame_idx, object_ids, mask_logits in self.predictor.propagate_in_video(inference_state):
                # object_ids may be a tensor or a plain list depending on SAM2 version.
                obj_ids_list = object_ids.tolist() if hasattr(object_ids, 'tolist') else object_ids

                if 1 in obj_ids_list:
                    masks.append(self._logits_to_mask(mask_logits[obj_ids_list.index(1)]))
                else:
                    # Object lost on this frame: emit an all-zero mask of the frame size.
                    h, w = frames[0].shape[:2]
                    masks.append(np.zeros((h, w), dtype=np.uint8))

            print(f"Generated {len(masks)} masks")
            return masks

        finally:
            # Always clean up the staged JPEG frames, even on failure.
            shutil.rmtree(temp_dir, ignore_errors=True)

    def get_first_frame_mask(self, frame: np.ndarray, points: List[List[int]],
                             labels: List[int]) -> np.ndarray:
        """
        Get mask for first frame only (for preview).

        Args:
            frame: np.ndarray, (H, W, 3), uint8 RGB frame
            points: List of [x, y] coordinates
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            mask: np.ndarray, (H, W), uint8 binary mask
        """
        temp_dir = Path(tempfile.mkdtemp())
        frames_dir = temp_dir / "frames"
        frames_dir.mkdir(exist_ok=True)

        try:
            self._save_frames([frame], frames_dir)
            _, _, out_mask_logits = self._prompt_first_frame(frames_dir, points, labels)

            if len(out_mask_logits) > 0:
                return self._logits_to_mask(out_mask_logits[0])
            # No mask produced for the prompt: return an empty mask.
            return np.zeros(frame.shape[:2], dtype=np.uint8)

        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)
|
|
|
|
|
|
|
|
def load_sam2_tracker(device="cuda",
                      checkpoint_path="sam2.1_hiera_large.pt",
                      config_file="configs/sam2.1/sam2.1_hiera_l.yaml"):
    """
    Load SAM2 video tracker with pretrained weights.

    Args:
        device: Device to run on
        checkpoint_path: Path to the SAM2 checkpoint (defaults to the
            large 2.1 checkpoint in the working directory)
        config_file: Path to the matching SAM2 model config

    Returns:
        SAM2VideoTracker instance

    Raises:
        FileNotFoundError: if the checkpoint is neither at
            ``checkpoint_path`` nor under ``checkpoints/``.
    """
    if not os.path.exists(checkpoint_path):
        # Fall back to the conventional checkpoints/ directory.
        alt_checkpoint_path = os.path.join("checkpoints", os.path.basename(checkpoint_path))
        if os.path.exists(alt_checkpoint_path):
            checkpoint_path = alt_checkpoint_path
        else:
            raise FileNotFoundError(
                f"SAM2 checkpoint not found at {checkpoint_path} or {alt_checkpoint_path}. "
                "Please run download_checkpoints.sh first or ensure sam2.1_hiera_large.pt is in the root directory."
            )

    print(f"Loading SAM2 from {checkpoint_path}...")
    tracker = SAM2VideoTracker(checkpoint_path, config_file, device)

    return tracker
|
|
|