File size: 6,779 Bytes
d33e75e 0b67fec d33e75e 0b67fec e6076ca 0b67fec d33e75e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
"""
SAM2 Wrapper for Video Mask Tracking
Handles mask generation and propagation through video
"""
import os
import cv2
import numpy as np
import torch
from PIL import Image
from pathlib import Path
from typing import List, Tuple
import tempfile
import shutil
from sam2.build_sam import build_sam2_video_predictor
class SAM2VideoTracker:
    """Single-object video mask tracker built on a SAM2 video predictor.

    Frames are written to a temporary folder of numbered JPEGs, and that
    folder is passed to ``predictor.init_state``. Object id 1 is prompted
    with point clicks on frame 0, and its mask logits are thresholded at 0
    into uint8 masks with values 0/255.
    """

    # Object id under which the single tracked object is registered.
    _OBJ_ID = 1

    def __init__(self, checkpoint_path, config_file, device="cuda"):
        """
        Initialize SAM2 video tracker

        Args:
            checkpoint_path: Path to SAM2 checkpoint
            config_file: SAM2 config, passed to build_sam2_video_predictor
            device: Device to run on
        """
        self.device = device
        self.predictor = build_sam2_video_predictor(
            config_file=config_file,
            ckpt_path=checkpoint_path,
            device=device,
        )
        print(f"SAM2 video tracker initialized on {device}")

    @staticmethod
    def _prepare_prompts(points, labels):
        """Validate point prompts and convert them to numpy arrays.

        Returns:
            (points_array, labels_array): a float32 array of shape (N, 2)
            and an int32 array of shape (N,).

        Raises:
            ValueError: if there are no points, the points are not [x, y]
                pairs, or the number of labels differs from the number of points.
        """
        points_array = np.asarray(points, dtype=np.float32)
        labels_array = np.asarray(labels, dtype=np.int32)
        if points_array.size == 0:
            raise ValueError("At least one point prompt is required")
        if points_array.ndim != 2 or points_array.shape[1] != 2:
            raise ValueError(
                f"points must be a list of [x, y] pairs, got shape {points_array.shape}"
            )
        if labels_array.shape != (points_array.shape[0],):
            raise ValueError(
                f"Expected {points_array.shape[0]} labels, got shape {labels_array.shape}"
            )
        return points_array, labels_array

    @staticmethod
    def _write_frames(frames, frames_dir):
        """Save frames into frames_dir as 00000.jpg, 00001.jpg, ... (JPEG quality 95)."""
        for i, frame in enumerate(frames):
            # JPEG cannot store an alpha channel, so normalise every frame to RGB.
            image = Image.fromarray(frame).convert("RGB")
            image.save(frames_dir / f"{i:05d}.jpg", quality=95)

    @staticmethod
    def _logits_to_mask(logits):
        """Threshold one object's mask logits at 0 and return a 0/255 uint8 (H, W) mask."""
        mask = (logits > 0.0).cpu().numpy()
        # Collapse the leading (assumed singleton) axis. Unlike squeeze(), this
        # keeps a spatial axis of size 1.
        mask = mask.reshape(mask.shape[-2:])
        return mask.astype(np.uint8) * 255

    def _mask_for_object(self, obj_ids, mask_logits):
        """Return the 0/255 mask of the tracked object, or None if it is absent."""
        # obj_ids can be a tensor or a list.
        ids = obj_ids.tolist() if hasattr(obj_ids, "tolist") else list(obj_ids)
        if self._OBJ_ID not in ids:
            return None
        return self._logits_to_mask(mask_logits[ids.index(self._OBJ_ID)])

    def track_video(self, frames: List[np.ndarray], points: List[List[int]],
                    labels: List[int]) -> List[np.ndarray]:
        """
        Track object through video using SAM2

        Args:
            frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
            points: List of [x, y] coordinates for prompts on the first frame
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            masks: List of len(frames) numpy arrays, [(H,W)]*n, uint8 masks
                with values 0/255. masks[i] belongs to frames[i]. A frame for
                which the predictor yields no mask of the object gets an
                all-zero mask.

        Raises:
            ValueError: if frames is empty or the prompts are malformed.
        """
        if len(frames) == 0:
            raise ValueError("track_video requires at least one frame")
        points_array, labels_array = self._prepare_prompts(points, labels)
        h, w = frames[0].shape[:2]

        # The temporary folder stays alive for the whole inference and is
        # removed even if writing, init or propagation raises.
        with tempfile.TemporaryDirectory() as tmp_root:
            frames_dir = Path(tmp_root) / "frames"
            frames_dir.mkdir()

            print(f"Saving {len(frames)} frames to temporary directory...")
            self._write_frames(frames, frames_dir)

            print("Initializing SAM2 inference state...")
            inference_state = self.predictor.init_state(video_path=str(frames_dir))

            print(f"Adding {len(points)} point prompts on first frame...")
            self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=self._OBJ_ID,
                points=points_array,
                labels=labels_array,
            )

            print("Propagating masks through video...")
            # Store masks by frame index so masks[i] always matches frames[i].
            masks = [None] * len(frames)
            for frame_idx, obj_ids, mask_logits in self.predictor.propagate_in_video(
                    inference_state):
                masks[frame_idx] = self._mask_for_object(obj_ids, mask_logits)

        # No mask for a frame -> use an empty mask.
        masks = [m if m is not None else np.zeros((h, w), dtype=np.uint8) for m in masks]
        print(f"Generated {len(masks)} masks")
        return masks

    def get_first_frame_mask(self, frame: np.ndarray, points: List[List[int]],
                             labels: List[int]) -> np.ndarray:
        """
        Get mask for first frame only (for preview)

        Args:
            frame: np.ndarray, (H, W, 3), uint8 RGB frame
            points: List of [x, y] coordinates
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            mask: np.ndarray, (H, W), uint8 mask with values 0/255, or all
                zeros if the predictor returns no mask for the object.

        Raises:
            ValueError: if the prompts are malformed.
        """
        points_array, labels_array = self._prepare_prompts(points, labels)

        with tempfile.TemporaryDirectory() as tmp_root:
            frames_dir = Path(tmp_root) / "frames"
            frames_dir.mkdir()
            self._write_frames([frame], frames_dir)

            inference_state = self.predictor.init_state(video_path=str(frames_dir))
            _, obj_ids, mask_logits = self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=self._OBJ_ID,
                points=points_array,
                labels=labels_array,
            )
            mask = self._mask_for_object(obj_ids, mask_logits)

        if mask is None:
            return np.zeros(frame.shape[:2], dtype=np.uint8)
        return mask
def load_sam2_tracker(device="cuda", *, checkpoint_path=None,
                      config_file="configs/sam2.1/sam2.1_hiera_l.yaml"):
    """
    Load SAM2 video tracker with pretrained weights

    Args:
        device: Device to run on
        checkpoint_path: Optional explicit checkpoint file. When None, look for
            sam2.1_hiera_large.pt in the working directory, then in checkpoints/.
        config_file: SAM2 config passed through to SAM2VideoTracker.

    Returns:
        SAM2VideoTracker instance

    Raises:
        FileNotFoundError: if no checkpoint file exists at any candidate path.
    """
    if checkpoint_path is not None:
        candidates = [checkpoint_path]
    else:
        # Relative paths, so the same code works on a Hugging Face Space.
        candidates = [
            "sam2.1_hiera_large.pt",
            os.path.join("checkpoints", "sam2.1_hiera_large.pt"),
        ]

    found = next((path for path in candidates if os.path.isfile(path)), None)
    if found is None:
        raise FileNotFoundError(
            f"SAM2 checkpoint not found at {' or '.join(map(str, candidates))}. "
            "Please run download_checkpoints.sh first or place sam2.1_hiera_large.pt "
            "in the root directory or in checkpoints/."
        )

    print(f"Loading SAM2 from {found}...")
    return SAM2VideoTracker(found, config_file, device)
|