File size: 6,812 Bytes
d33e75e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""
SAM2 Wrapper for Video Mask Tracking - Hugging Face Space Version
Handles mask generation and propagation through video
"""

import sys
import os
from pathlib import Path

# Make the SAM2 package importable: use an installed copy when available,
# otherwise append the first known local checkout to sys.path.
try:
    import sam2
except ImportError:
    for candidate in ("/home/cvlab19/project/samuel/CVPR/sam2", "./sam2"):
        if os.path.exists(candidate):
            sys.path.append(candidate)
            break

import cv2
import numpy as np
import torch
from PIL import Image
from typing import List, Tuple
import tempfile
import shutil

from sam2.build_sam import build_sam2_video_predictor


class SAM2VideoTracker:
    """Point-prompted object mask tracking through video with SAM2.

    SAM2's video predictor loads frames from a directory of JPEG files,
    so each public method materializes its input frames in a temporary
    directory that is always removed afterwards, even on failure.
    """

    def __init__(self, checkpoint_path, config_file, device="cuda"):
        """
        Initialize SAM2 video tracker.

        Args:
            checkpoint_path: Path to SAM2 checkpoint
            config_file: Path to SAM2 config file
            device: Device to run on (e.g. "cuda", "cpu")
        """
        self.device = device
        self.predictor = build_sam2_video_predictor(
            config_file=config_file,
            ckpt_path=checkpoint_path,
            device=device
        )
        print(f"SAM2 video tracker initialized on {device}")

    @staticmethod
    def _dump_frames(temp_dir: Path, frames: List[np.ndarray]) -> Path:
        """Write RGB uint8 frames as JPEGs under ``temp_dir/frames``.

        Returns the frames directory to hand to SAM2's ``init_state``.
        Zero-padded file names keep lexicographic order equal to
        chronological order, which the directory loader relies on.
        """
        frames_dir = temp_dir / "frames"
        frames_dir.mkdir(exist_ok=True)
        for i, frame in enumerate(frames):
            # NOTE: JPEG is lossy; quality=95 keeps artifacts negligible.
            Image.fromarray(frame).save(frames_dir / f"{i:05d}.jpg", quality=95)
        return frames_dir

    def _add_point_prompts(self, inference_state, points: List[List[int]],
                           labels: List[int]):
        """Register the user's point prompts for object id 1 on frame 0.

        Returns the predictor's (frame_idx, object_ids, mask_logits) tuple.
        """
        points_array = np.array(points, dtype=np.float32)
        labels_array = np.array(labels, dtype=np.int32)
        return self.predictor.add_new_points(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=1,
            points=points_array,
            labels=labels_array,
        )

    def track_video(self, frames: List[np.ndarray], points: List[List[int]], 
                   labels: List[int]) -> List[np.ndarray]:
        """
        Track object through video using SAM2
        
        Args:
            frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
            points: List of [x, y] coordinates for prompts (on first frame)
            labels: List of labels (1 for positive, 0 for negative)
            
        Returns:
            masks: List of numpy arrays, [(H,W)]*n, uint8 binary masks
                   (255 = object, 0 = background), one per frame
        """
        # Create the temp dir first so the finally-cleanup also covers
        # failures while the frames are being written.
        temp_dir = Path(tempfile.mkdtemp())

        try:
            print(f"Saving {len(frames)} frames to temporary directory...")
            frames_dir = self._dump_frames(temp_dir, frames)

            print("Initializing SAM2 inference state...")
            inference_state = self.predictor.init_state(video_path=str(frames_dir))

            print(f"Adding {len(points)} point prompts on first frame...")
            self._add_point_prompts(inference_state, points, labels)

            print("Propagating masks through video...")
            masks = []
            for frame_idx, object_ids, mask_logits in self.predictor.propagate_in_video(inference_state):
                # object_ids may be a tensor or a plain list depending on
                # the SAM2 version; normalize to a list before searching.
                obj_ids_list = object_ids.tolist() if hasattr(object_ids, 'tolist') else object_ids

                if 1 in obj_ids_list:
                    mask_idx = obj_ids_list.index(1)
                    # Logit > 0 <=> probability > 0.5: threshold to binary.
                    mask = (mask_logits[mask_idx] > 0.0).cpu().numpy()
                    masks.append((mask.squeeze() * 255).astype(np.uint8))
                else:
                    # Object lost on this frame: emit an all-background mask.
                    # Assumes all frames share the first frame's size.
                    h, w = frames[0].shape[:2]
                    masks.append(np.zeros((h, w), dtype=np.uint8))

            print(f"Generated {len(masks)} masks")
            return masks

        finally:
            # Clean up temporary directory
            shutil.rmtree(temp_dir, ignore_errors=True)

    def get_first_frame_mask(self, frame: np.ndarray, points: List[List[int]], 
                            labels: List[int]) -> np.ndarray:
        """
        Get mask for first frame only (for preview)
        
        Args:
            frame: np.ndarray, (H, W, 3), uint8 RGB frame
            points: List of [x, y] coordinates
            labels: List of labels (1 for positive, 0 for negative)
            
        Returns:
            mask: np.ndarray, (H, W), uint8 binary mask (255 = object)
        """
        temp_dir = Path(tempfile.mkdtemp())

        try:
            # A single-frame "video": same pipeline, no propagation step.
            frames_dir = self._dump_frames(temp_dir, [frame])

            inference_state = self.predictor.init_state(video_path=str(frames_dir))
            _, out_obj_ids, out_mask_logits = self._add_point_prompts(
                inference_state, points, labels
            )

            if len(out_mask_logits) > 0:
                mask = (out_mask_logits[0] > 0.0).cpu().numpy()
                return (mask.squeeze() * 255).astype(np.uint8)
            # No mask produced: fall back to an all-background mask.
            return np.zeros(frame.shape[:2], dtype=np.uint8)

        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)


def load_sam2_tracker(checkpoint_path=None, device="cuda"):
    """
    Build a SAM2 video tracker with pretrained weights.

    Args:
        checkpoint_path: Path to SAM2 checkpoint; when None, the default
            location ``checkpoints/sam2.1_hiera_large.pt`` is used.
        device: Device to run on.

    Returns:
        SAM2VideoTracker instance.
    """
    # Guard-clause default for the checkpoint location.
    if checkpoint_path is None:
        checkpoint_path = "checkpoints/sam2.1_hiera_large.pt"

    # Prefer the config shipped inside the SAM2 repo; fall back to a
    # local yaml copy when the repo layout is not available.
    config_file = "configs/sam2.1/sam2.1_hiera_l.yaml"
    if not os.path.exists(config_file):
        config_file = "sam2_hiera_l.yaml"

    print(f"Loading SAM2 from {checkpoint_path}...")
    print(f"Using config: {config_file}")

    return SAM2VideoTracker(checkpoint_path, config_file, device)