Spaces:
Build error
Build error
| """ | |
| SAM2 Wrapper for Video Mask Tracking - Hugging Face Space Version | |
| Handles mask generation and propagation through video | |
| """ | |
| import sys | |
| import os | |
| from pathlib import Path | |
| # Add SAM2 to path if installed | |
| try: | |
| import sam2 | |
| except ImportError: | |
| # Try to add from common locations | |
| possible_paths = [ | |
| "/home/cvlab19/project/samuel/CVPR/sam2", | |
| "./sam2" | |
| ] | |
| for path in possible_paths: | |
| if os.path.exists(path): | |
| sys.path.append(path) | |
| break | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from typing import List, Tuple | |
| import tempfile | |
| import shutil | |
| from sam2.build_sam import build_sam2_video_predictor | |
| class SAM2VideoTracker: | |
| def __init__(self, checkpoint_path, config_file, device="cuda"): | |
| """ | |
| Initialize SAM2 video tracker | |
| Args: | |
| checkpoint_path: Path to SAM2 checkpoint | |
| config_file: Path to SAM2 config file | |
| device: Device to run on | |
| """ | |
| self.device = device | |
| self.predictor = build_sam2_video_predictor( | |
| config_file=config_file, | |
| ckpt_path=checkpoint_path, | |
| device=device | |
| ) | |
| print(f"SAM2 video tracker initialized on {device}") | |
| def track_video(self, frames: List[np.ndarray], points: List[List[int]], | |
| labels: List[int]) -> List[np.ndarray]: | |
| """ | |
| Track object through video using SAM2 | |
| Args: | |
| frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames | |
| points: List of [x, y] coordinates for prompts | |
| labels: List of labels (1 for positive, 0 for negative) | |
| Returns: | |
| masks: List of numpy arrays, [(H,W)]*n, uint8 binary masks | |
| """ | |
| # Create temporary directory for frames | |
| temp_dir = Path(tempfile.mkdtemp()) | |
| frames_dir = temp_dir / "frames" | |
| frames_dir.mkdir(exist_ok=True) | |
| try: | |
| # Save frames to temp directory | |
| print(f"Saving {len(frames)} frames to temporary directory...") | |
| for i, frame in enumerate(frames): | |
| frame_path = frames_dir / f"{i:05d}.jpg" | |
| Image.fromarray(frame).save(frame_path, quality=95) | |
| # Initialize SAM2 video predictor | |
| print("Initializing SAM2 inference state...") | |
| inference_state = self.predictor.init_state(video_path=str(frames_dir)) | |
| # Add prompts on first frame | |
| points_array = np.array(points, dtype=np.float32) | |
| labels_array = np.array(labels, dtype=np.int32) | |
| print(f"Adding {len(points)} point prompts on first frame...") | |
| _, out_obj_ids, out_mask_logits = self.predictor.add_new_points( | |
| inference_state=inference_state, | |
| frame_idx=0, | |
| obj_id=1, | |
| points=points_array, | |
| labels=labels_array, | |
| ) | |
| # Propagate through video | |
| print("Propagating masks through video...") | |
| masks = [] | |
| for frame_idx, object_ids, mask_logits in self.predictor.propagate_in_video(inference_state): | |
| # Get mask for object ID 1 | |
| obj_ids_list = object_ids.tolist() if hasattr(object_ids, 'tolist') else object_ids | |
| if 1 in obj_ids_list: | |
| mask_idx = obj_ids_list.index(1) | |
| mask = (mask_logits[mask_idx] > 0.0).cpu().numpy() | |
| mask_uint8 = (mask.squeeze() * 255).astype(np.uint8) | |
| masks.append(mask_uint8) | |
| else: | |
| # No mask for this frame, use empty mask | |
| h, w = frames[0].shape[:2] | |
| masks.append(np.zeros((h, w), dtype=np.uint8)) | |
| print(f"Generated {len(masks)} masks") | |
| return masks | |
| finally: | |
| # Clean up temporary directory | |
| shutil.rmtree(temp_dir, ignore_errors=True) | |
| def get_first_frame_mask(self, frame: np.ndarray, points: List[List[int]], | |
| labels: List[int]) -> np.ndarray: | |
| """ | |
| Get mask for first frame only (for preview) | |
| Args: | |
| frame: np.ndarray, (H, W, 3), uint8 RGB frame | |
| points: List of [x, y] coordinates | |
| labels: List of labels (1 for positive, 0 for negative) | |
| Returns: | |
| mask: np.ndarray, (H, W), uint8 binary mask | |
| """ | |
| # Create temporary directory | |
| temp_dir = Path(tempfile.mkdtemp()) | |
| frames_dir = temp_dir / "frames" | |
| frames_dir.mkdir(exist_ok=True) | |
| try: | |
| # Save single frame | |
| frame_path = frames_dir / "00000.jpg" | |
| Image.fromarray(frame).save(frame_path, quality=95) | |
| # Initialize SAM2 | |
| inference_state = self.predictor.init_state(video_path=str(frames_dir)) | |
| # Add prompts | |
| points_array = np.array(points, dtype=np.float32) | |
| labels_array = np.array(labels, dtype=np.int32) | |
| _, out_obj_ids, out_mask_logits = self.predictor.add_new_points( | |
| inference_state=inference_state, | |
| frame_idx=0, | |
| obj_id=1, | |
| points=points_array, | |
| labels=labels_array, | |
| ) | |
| # Get mask | |
| if len(out_mask_logits) > 0: | |
| mask = (out_mask_logits[0] > 0.0).cpu().numpy() | |
| mask_uint8 = (mask.squeeze() * 255).astype(np.uint8) | |
| return mask_uint8 | |
| else: | |
| return np.zeros(frame.shape[:2], dtype=np.uint8) | |
| finally: | |
| shutil.rmtree(temp_dir, ignore_errors=True) | |
| def load_sam2_tracker(checkpoint_path=None, device="cuda"): | |
| """ | |
| Load SAM2 video tracker with pretrained weights | |
| Args: | |
| checkpoint_path: Path to SAM2 checkpoint (if None, uses default location) | |
| device: Device to run on | |
| Returns: | |
| SAM2VideoTracker instance | |
| """ | |
| # Use provided path or default | |
| if checkpoint_path is None: | |
| checkpoint_path = "checkpoints/sam2.1_hiera_large.pt" | |
| # Config file should be in the SAM2 repo | |
| config_file = "configs/sam2.1/sam2.1_hiera_l.yaml" | |
| # Check if we need to use the local yaml file | |
| if not os.path.exists(config_file): | |
| config_file = "sam2_hiera_l.yaml" | |
| print(f"Loading SAM2 from {checkpoint_path}...") | |
| print(f"Using config: {config_file}") | |
| tracker = SAM2VideoTracker(checkpoint_path, config_file, device) | |
| return tracker | |