File size: 6,779 Bytes
d33e75e
 
 
 
 
0b67fec
d33e75e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b67fec
 
 
e6076ca
0b67fec
 
 
 
 
 
 
 
 
 
 
 
d33e75e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""
SAM2 Wrapper for Video Mask Tracking
Handles mask generation and propagation through video
"""

import os
import cv2
import numpy as np
import torch
from PIL import Image
from pathlib import Path
from typing import List, Tuple
import tempfile
import shutil

from sam2.build_sam import build_sam2_video_predictor


class SAM2VideoTracker:
    """
    Single-object video tracker built on a SAM2 video predictor.

    Frames are written as numbered JPEG files into a temporary directory whose
    path is handed to ``predictor.init_state``. Point prompts are added on
    frame 0 for object id 1. The predictor's mask logits are thresholded at 0
    into uint8 masks with values {0, 255}.
    """

    # Object id under which the single tracked object is registered with SAM2.
    _OBJ_ID = 1

    def __init__(self, checkpoint_path, config_file, device="cuda"):
        """
        Initialize SAM2 video tracker

        Args:
            checkpoint_path: Path to SAM2 checkpoint
            config_file: Path to SAM2 config file
            device: Device to run on
        """
        self.device = device
        self.predictor = build_sam2_video_predictor(
            config_file=config_file,
            ckpt_path=checkpoint_path,
            device=device
        )
        print(f"SAM2 video tracker initialized on {device}")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _save_frames(frames, frames_dir: Path) -> None:
        """Write frames into frames_dir as 00000.jpg, 00001.jpg, ... (JPEG quality 95)."""
        frames_dir.mkdir(exist_ok=True)
        for i, frame in enumerate(frames):
            Image.fromarray(frame).save(frames_dir / f"{i:05d}.jpg", quality=95)

    @staticmethod
    def _prepare_prompts(points, labels) -> Tuple[np.ndarray, np.ndarray]:
        """
        Convert point prompts into the arrays passed to ``add_new_points``.

        Args:
            points: List of [x, y] coordinates
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            (points_array, labels_array): float32 array of shape (N, 2) and
            int32 array of shape (N,)

        Raises:
            ValueError: if there are no points, a point is not an [x, y] pair,
                or the number of labels differs from the number of points.
        """
        points_array = np.asarray(points, dtype=np.float32)
        labels_array = np.asarray(labels, dtype=np.int32)
        if points_array.ndim != 2 or points_array.shape[0] == 0 or points_array.shape[1] != 2:
            raise ValueError(
                f"points must be a non-empty list of [x, y] pairs, got shape {points_array.shape}"
            )
        if labels_array.shape != (points_array.shape[0],):
            raise ValueError(
                f"expected {points_array.shape[0]} labels (one per point), got shape {labels_array.shape}"
            )
        return points_array, labels_array

    @staticmethod
    def _logits_to_mask(logits) -> np.ndarray:
        """Threshold one object's mask logits (a torch tensor) at 0 into an (H, W) uint8 {0, 255} mask."""
        mask = (logits > 0.0).cpu().numpy()
        # Per-object logits presumably carry leading singleton dims, e.g. (1, H, W).
        # Keep the trailing two axes explicitly instead of squeeze(), which would
        # also drop an H or W of size 1.
        mask = mask.reshape(mask.shape[-2:])
        return mask.astype(np.uint8) * np.uint8(255)

    @classmethod
    def _select_object_mask(cls, obj_ids, mask_logits, shape: Tuple[int, int]) -> np.ndarray:
        """
        Pick the mask of the tracked object out of a predictor output.

        Args:
            obj_ids: object ids reported by the predictor (list or tensor)
            mask_logits: per-object mask logits, indexed in the same order as obj_ids
            shape: (H, W) of the all-zero mask returned when the object is absent

        Returns:
            (H, W) uint8 mask with values {0, 255}
        """
        ids = obj_ids.tolist() if hasattr(obj_ids, "tolist") else list(obj_ids)
        if cls._OBJ_ID in ids:
            return cls._logits_to_mask(mask_logits[ids.index(cls._OBJ_ID)])
        return np.zeros(shape, dtype=np.uint8)

    # --------------------------------------------------------------- public API

    def track_video(self, frames: List[np.ndarray], points: List[List[int]],
                   labels: List[int]) -> List[np.ndarray]:
        """
        Track object through video using SAM2

        Args:
            frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
            points: List of [x, y] coordinates for prompts on frame 0
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            masks: List of numpy arrays, [(H,W)]*n, uint8 binary masks (0/255),
                one per input frame and in the same order. A frame for which the
                predictor reports no mask gets an all-zero mask.

        Raises:
            ValueError: if frames is empty or the prompts are malformed.
        """
        # len() rather than truthiness so a stacked (n, H, W, 3) ndarray also works.
        if len(frames) == 0:
            raise ValueError("frames must contain at least one frame")
        points_array, labels_array = self._prepare_prompts(points, labels)
        height, width = frames[0].shape[:2]

        temp_dir = Path(tempfile.mkdtemp())
        try:
            frames_dir = temp_dir / "frames"
            print(f"Saving {len(frames)} frames to temporary directory...")
            self._save_frames(frames, frames_dir)

            print("Initializing SAM2 inference state...")
            inference_state = self.predictor.init_state(video_path=str(frames_dir))

            print(f"Adding {len(points)} point prompts on first frame...")
            self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=self._OBJ_ID,
                points=points_array,
                labels=labels_array,
            )

            print("Propagating masks through video...")
            # Place each result by its frame index so the output lines up with
            # `frames` independent of the order in which frames are yielded.
            slots = [None] * len(frames)
            for frame_idx, object_ids, mask_logits in self.predictor.propagate_in_video(inference_state):
                slots[frame_idx] = self._select_object_mask(object_ids, mask_logits, (height, width))
            masks = [
                mask if mask is not None else np.zeros((height, width), dtype=np.uint8)
                for mask in slots
            ]

            print(f"Generated {len(masks)} masks")
            return masks

        finally:
            # Best-effort cleanup of the temporary JPEG folder.
            shutil.rmtree(temp_dir, ignore_errors=True)

    def get_first_frame_mask(self, frame: np.ndarray, points: List[List[int]],
                            labels: List[int]) -> np.ndarray:
        """
        Get mask for first frame only (for preview)

        Args:
            frame: np.ndarray, (H, W, 3), uint8 RGB frame
            points: List of [x, y] coordinates
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            mask: np.ndarray, (H, W), uint8 binary mask (0/255); all zeros when
                the predictor returns no mask for the tracked object

        Raises:
            ValueError: if the prompts are malformed.
        """
        points_array, labels_array = self._prepare_prompts(points, labels)

        temp_dir = Path(tempfile.mkdtemp())
        try:
            # SAM2's init_state takes a directory, so the single frame is
            # treated as a one-frame video.
            frames_dir = temp_dir / "frames"
            self._save_frames([frame], frames_dir)

            inference_state = self.predictor.init_state(video_path=str(frames_dir))

            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=self._OBJ_ID,
                points=points_array,
                labels=labels_array,
            )

            return self._select_object_mask(out_obj_ids, out_mask_logits, frame.shape[:2])

        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)


def load_sam2_tracker(device="cuda"):
    """
    Build a SAM2VideoTracker from the SAM2.1 Hiera-Large checkpoint.

    The checkpoint is looked up relative to the current working directory:
    first as ``sam2.1_hiera_large.pt``, then as
    ``checkpoints/sam2.1_hiera_large.pt``. The first existing path is used.

    Args:
        device: Device to run on

    Returns:
        SAM2VideoTracker instance

    Raises:
        FileNotFoundError: if the checkpoint exists at neither location.
    """
    # Relative locations chosen to work on a Hugging Face Space layout.
    primary = "sam2.1_hiera_large.pt"
    fallback = os.path.join("checkpoints", "sam2.1_hiera_large.pt")
    config_name = "configs/sam2.1/sam2.1_hiera_l.yaml"

    # Checked in order; the generator stops at the first hit.
    resolved = next(
        (candidate for candidate in (primary, fallback) if os.path.exists(candidate)),
        None,
    )
    if resolved is None:
        raise FileNotFoundError(
            f"SAM2 checkpoint not found at {primary} or {fallback}. "
            "Please run download_checkpoints.sh first or ensure sam2.1_hiera_large.pt is in the root directory."
        )

    print(f"Loading SAM2 from {resolved}...")
    return SAM2VideoTracker(resolved, config_name, device)