File size: 6,812 Bytes
d33e75e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
"""
SAM2 Wrapper for Video Mask Tracking - Hugging Face Space Version
Handles mask generation and propagation through video
"""
import sys
import os
from pathlib import Path
# Add SAM2 to path if installed
try:
import sam2
except ImportError:
# Try to add from common locations
possible_paths = [
"/home/cvlab19/project/samuel/CVPR/sam2",
"./sam2"
]
for path in possible_paths:
if os.path.exists(path):
sys.path.append(path)
break
import cv2
import numpy as np
import torch
from PIL import Image
from typing import List, Tuple
import tempfile
import shutil
from sam2.build_sam import build_sam2_video_predictor
class SAM2VideoTracker:
    """Point-prompted mask tracking over a video using SAM2's video predictor.

    Frames are staged as JPEGs in a temporary directory because the SAM2
    predictor's `init_state` consumes a directory of image files.
    """

    def __init__(self, checkpoint_path, config_file, device="cuda"):
        """
        Initialize SAM2 video tracker

        Args:
            checkpoint_path: Path to SAM2 checkpoint
            config_file: Path to SAM2 config file
            device: Device to run on
        """
        self.device = device
        self.predictor = build_sam2_video_predictor(
            config_file=config_file,
            ckpt_path=checkpoint_path,
            device=device,
        )
        print(f"SAM2 video tracker initialized on {device}")

    @staticmethod
    def _logits_to_mask(logits) -> np.ndarray:
        """Threshold raw mask logits at 0 into a (H, W) uint8 mask of {0, 255}."""
        mask = (logits > 0.0).cpu().numpy()
        return (mask.squeeze() * 255).astype(np.uint8)

    def _init_state_with_points(self, frames_dir, points: List[List[int]],
                                labels: List[int]):
        """Build an inference state over the saved frames and register the point
        prompts on frame 0 as object id 1.

        Returns:
            (inference_state, out_obj_ids, out_mask_logits) as produced by the
            SAM2 predictor.
        """
        inference_state = self.predictor.init_state(video_path=str(frames_dir))
        points_array = np.array(points, dtype=np.float32)
        labels_array = np.array(labels, dtype=np.int32)
        _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=1,
            points=points_array,
            labels=labels_array,
        )
        return inference_state, out_obj_ids, out_mask_logits

    def track_video(self, frames: List[np.ndarray], points: List[List[int]],
                    labels: List[int]) -> List[np.ndarray]:
        """
        Track object through video using SAM2

        Args:
            frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
            points: List of [x, y] coordinates for prompts
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            masks: List of numpy arrays, [(H,W)]*n, uint8 binary masks
        """
        # Stage frames on disk; always cleaned up in the finally block.
        temp_dir = Path(tempfile.mkdtemp())
        frames_dir = temp_dir / "frames"
        frames_dir.mkdir(exist_ok=True)
        try:
            print(f"Saving {len(frames)} frames to temporary directory...")
            for i, frame in enumerate(frames):
                # Zero-padded names keep the predictor's frame ordering stable.
                Image.fromarray(frame).save(frames_dir / f"{i:05d}.jpg", quality=95)

            print("Initializing SAM2 inference state...")
            print(f"Adding {len(points)} point prompts on first frame...")
            inference_state, _, _ = self._init_state_with_points(
                frames_dir, points, labels
            )

            print("Propagating masks through video...")
            masks = []
            for frame_idx, object_ids, mask_logits in self.predictor.propagate_in_video(inference_state):
                # object_ids may be a tensor or a plain list depending on the
                # SAM2 version; normalize before membership lookup.
                obj_ids_list = object_ids.tolist() if hasattr(object_ids, 'tolist') else object_ids
                if 1 in obj_ids_list:
                    mask_idx = obj_ids_list.index(1)
                    masks.append(self._logits_to_mask(mask_logits[mask_idx]))
                else:
                    # No mask for this frame: emit an all-zero mask sized like
                    # the CURRENT frame (was frames[0], wrong if sizes differ).
                    h, w = frames[frame_idx].shape[:2]
                    masks.append(np.zeros((h, w), dtype=np.uint8))
            print(f"Generated {len(masks)} masks")
            return masks
        finally:
            # Clean up temporary directory
            shutil.rmtree(temp_dir, ignore_errors=True)

    def get_first_frame_mask(self, frame: np.ndarray, points: List[List[int]],
                             labels: List[int]) -> np.ndarray:
        """
        Get mask for first frame only (for preview)

        Args:
            frame: np.ndarray, (H, W, 3), uint8 RGB frame
            points: List of [x, y] coordinates
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            mask: np.ndarray, (H, W), uint8 binary mask
        """
        temp_dir = Path(tempfile.mkdtemp())
        frames_dir = temp_dir / "frames"
        frames_dir.mkdir(exist_ok=True)
        try:
            # Single-frame "video": save just one image and prompt on it.
            Image.fromarray(frame).save(frames_dir / "00000.jpg", quality=95)
            _, out_obj_ids, out_mask_logits = self._init_state_with_points(
                frames_dir, points, labels
            )
            if len(out_mask_logits) > 0:
                return self._logits_to_mask(out_mask_logits[0])
            # Predictor produced nothing; return an empty mask of matching size.
            return np.zeros(frame.shape[:2], dtype=np.uint8)
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)
def load_sam2_tracker(checkpoint_path=None, device="cuda"):
    """
    Build a ready-to-use SAM2 video tracker.

    Args:
        checkpoint_path: Optional path to a SAM2 checkpoint; defaults to the
            bundled sam2.1 hiera-large weights location when None.
        device: Device string passed through to the tracker (e.g. "cuda").

    Returns:
        A SAM2VideoTracker wrapping the loaded predictor.
    """
    if checkpoint_path is None:
        checkpoint_path = "checkpoints/sam2.1_hiera_large.pt"

    # Prefer the config path as laid out in the SAM2 repo; if it is absent,
    # fall back to a local yaml sitting next to the app.
    config_file = "configs/sam2.1/sam2.1_hiera_l.yaml"
    if not os.path.exists(config_file):
        config_file = "sam2_hiera_l.yaml"

    print(f"Loading SAM2 from {checkpoint_path}...")
    print(f"Using config: {config_file}")
    return SAM2VideoTracker(checkpoint_path, config_file, device)
|