File size: 14,047 Bytes
41d206b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
import os
import io
import base64
import tempfile
import zipfile
from typing import Dict, Any, Optional
from pathlib import Path
import json

import torch
import numpy as np
from PIL import Image
import cv2

# Transformers imports for SAM3
from transformers import Sam3VideoModel, Sam3VideoProcessor

# HuggingFace Hub for uploads
try:
    from huggingface_hub import HfApi
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False


class EndpointHandler:
    """
    SAM3 Video Segmentation Handler for HuggingFace Inference Endpoints.

    Processes video with text prompts and returns segmentation masks.
    Uses transformers library for clean integration with HuggingFace models.
    """

    def __init__(self, path: str = ""):
        """
        Initialize SAM3 video model using transformers.

        Args:
            path: Path to model repository (contains model files).
                  For HF Inference Endpoints, this is /repository.
                  Contains: sam3.pt, config.json, processor_config.json, etc.

        Raises:
            ValueError: If no CUDA device is available (SAM3 requires a GPU).
        """
        print(f"[INIT] Initializing SAM3 video model from {path}")

        # SAM3 video inference is GPU-only; fail fast on CPU-only hosts.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device != "cuda":
            raise ValueError("SAM3 requires GPU acceleration. No CUDA device found.")

        print(f"[INIT] Using device: {self.device}")

        # Load model and processor from the repository.
        # If path is empty or ".", fall back to the public model ID.
        model_path = path if path and path != "." else "facebook/sam3"

        try:
            print(f"[INIT] Loading model from: {model_path}")
            self.model = Sam3VideoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                device_map=self.device
            )

            self.processor = Sam3VideoProcessor.from_pretrained(model_path)

            print("[INIT] SAM3 video model loaded successfully")

        except Exception as e:
            # Best-effort fallback so the endpoint still boots when the
            # local repository is missing or incomplete.
            print(f"[INIT] Error loading from {model_path}: {e}")
            print("[INIT] Falling back to facebook/sam3")

            self.model = Sam3VideoModel.from_pretrained(
                "facebook/sam3",
                torch_dtype=torch.bfloat16,
                device_map=self.device
            )

            self.processor = Sam3VideoProcessor.from_pretrained("facebook/sam3")

            print("[INIT] SAM3 video model loaded from facebook/sam3")

        # Initialize HuggingFace API for uploads (optional; only when both
        # huggingface_hub is installed and HF_TOKEN is set).
        self.hf_api = None
        hf_token = os.getenv("HF_TOKEN")
        if HF_HUB_AVAILABLE and hf_token:
            self.hf_api = HfApi(token=hf_token)
            print("[INIT] HuggingFace Hub API initialized")
        else:
            print("[INIT] HuggingFace Hub uploads disabled (no token or huggingface_hub not installed)")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process video segmentation request using transformers API.

        Expected input format:
        {
            "video": <base64_encoded_video>,
            "text_prompt": "object to segment",
            "return_format": "download_url" or "base64" or "metadata_only"  # optional
            "output_repo": "username/dataset-name",  # optional, for HF upload
        }

        Returns:
        {
            "download_url": "https://...",  # if uploaded to HF
            "frame_count": 120,
            "video_metadata": {...},
            "compressed_size_mb": 15.3,
            "objects_detected": [1, 2, 3]  # object IDs
        }
        On failure, returns {"error": ..., "error_type": ...} instead of raising,
        so the endpoint always responds with JSON.
        """
        try:
            # Extract parameters; "inputs" is accepted as an alias for
            # "text_prompt" to match the generic HF endpoint payload shape.
            video_data = data.get("video")
            text_prompt = data.get("text_prompt", data.get("inputs", ""))
            output_repo = data.get("output_repo")
            return_format = data.get("return_format", "metadata_only")

            if not video_data:
                return {"error": "No video data provided. Include 'video' in request."}

            if not text_prompt:
                return {"error": "No text prompt provided. Include 'text_prompt' or 'inputs' in request."}

            print(f"[REQUEST] Processing video with prompt: '{text_prompt}'")
            print(f"[REQUEST] Return format: {return_format}")

            # Process video in a temporary directory so all intermediate
            # artifacts (video, masks, ZIP) are cleaned up automatically.
            with tempfile.TemporaryDirectory() as tmpdir:
                tmpdir_path = Path(tmpdir)

                # Step 1: Decode and save video
                video_path = self._prepare_video(video_data, tmpdir_path)
                print(f"[STEP 1] Video prepared at: {video_path}")

                # Step 2: Load video frames
                video_frames = self._load_video_frames(video_path)
                print(f"[STEP 2] Loaded {len(video_frames)} frames")

                # Step 3: Initialize inference session. Frames are kept on CPU
                # to bound GPU memory; compute runs on self.device in bf16.
                inference_session = self.processor.init_video_session(
                    video=video_frames,
                    inference_device=self.device,
                    processing_device="cpu",
                    video_storage_device="cpu",
                    dtype=torch.bfloat16,
                )
                print(f"[STEP 3] Inference session initialized")

                # Step 4: Add text prompt
                inference_session = self.processor.add_text_prompt(
                    inference_session=inference_session,
                    text=text_prompt,
                )
                print(f"[STEP 4] Text prompt added")

                # Step 5: Propagate through video and save masks.
                # Pass the true frame resolution so frames with no detections
                # get an empty mask of the correct size (previously hard-coded
                # to 1080x1920 regardless of the input video).
                masks_dir = tmpdir_path / "masks"
                masks_dir.mkdir()

                frame_shape = (
                    tuple(video_frames[0].shape[:2])
                    if len(video_frames) > 0 else (1080, 1920)
                )
                frame_outputs = self._propagate_and_save_masks(
                    inference_session,
                    masks_dir,
                    frame_shape=frame_shape,
                )
                print(f"[STEP 5] Propagated through {len(frame_outputs)} frames")

                # Collect unique object IDs across all frames.
                all_object_ids = set()
                for frame_output in frame_outputs.values():
                    if 'object_ids' in frame_output and frame_output['object_ids'] is not None:
                        ids = frame_output['object_ids']
                        if torch.is_tensor(ids):
                            all_object_ids.update(ids.tolist())
                        else:
                            all_object_ids.update(ids)

                # Step 6: Create ZIP archive of all per-frame masks.
                zip_path = tmpdir_path / "masks.zip"
                self._create_zip(masks_dir, zip_path)
                zip_size_mb = zip_path.stat().st_size / 1e6
                print(f"[STEP 6] Created ZIP archive: {zip_size_mb:.2f} MB")

                # Step 7: Prepare response based on return_format.
                response = {
                    "frame_count": len(frame_outputs),
                    "objects_detected": sorted(all_object_ids),
                    "compressed_size_mb": round(zip_size_mb, 2),
                    "video_metadata": self._get_video_metadata_from_frames(video_frames)
                }

                if return_format == "download_url" and output_repo:
                    # Upload to HuggingFace
                    download_url = self._upload_to_hf(zip_path, output_repo)
                    response["download_url"] = download_url
                    print(f"[STEP 7] Uploaded to HuggingFace: {download_url}")

                elif return_format == "base64":
                    # Return base64 encoded ZIP inline in the JSON response.
                    with open(zip_path, "rb") as f:
                        zip_base64 = base64.b64encode(f.read()).decode('utf-8')
                    response["masks_zip_base64"] = zip_base64
                    print(f"[STEP 7] Returning base64 encoded ZIP")

                else:
                    # metadata_only - just return stats
                    response["note"] = "Masks generated but not returned. Use return_format='base64' or 'download_url' to get masks."
                    print(f"[STEP 7] Returning metadata only")

                return response

        except Exception as e:
            # Endpoint boundary: convert any failure into a JSON error payload.
            print(f"[ERROR] {type(e).__name__}: {str(e)}")
            import traceback
            traceback.print_exc()
            return {
                "error": str(e),
                "error_type": type(e).__name__
            }

    def _prepare_video(self, video_data: Any, tmpdir: Path) -> Path:
        """
        Decode video data and save it to <tmpdir>/input_video.mp4.

        Args:
            video_data: Base64-encoded string or raw bytes of the video.
            tmpdir: Directory to write the decoded file into.

        Returns:
            Path to the written video file.

        Raises:
            ValueError: If video_data is neither str nor bytes.
        """
        video_path = tmpdir / "input_video.mp4"

        if isinstance(video_data, str):
            # Base64 encoded
            video_bytes = base64.b64decode(video_data)
        elif isinstance(video_data, bytes):
            video_bytes = video_data
        else:
            raise ValueError(f"Unsupported video data type: {type(video_data)}")

        video_path.write_bytes(video_bytes)
        return video_path

    def _load_video_frames(self, video_path: Path) -> list:
        """Load video frames from an MP4 file via transformers' load_video."""
        from transformers.video_utils import load_video

        # load_video returns (frames, audio) - we only need frames
        frames, _ = load_video(str(video_path))
        return frames

    def _propagate_and_save_masks(
        self,
        inference_session,
        masks_dir: Path,
        frame_shape: tuple = (1080, 1920),
    ) -> Dict[int, Dict]:
        """
        Propagate masks through the video and save one PNG per frame.

        Args:
            inference_session: Session created by processor.init_video_session.
            masks_dir: Directory to write mask_NNNN.png files into.
            frame_shape: (height, width) used for empty/fallback masks.

        Returns:
            Dict mapping frame_idx -> post-processed outputs for that frame.
        """
        outputs_per_frame = {}

        # Use the model's propagate_in_video_iterator
        for model_outputs in self.model.propagate_in_video_iterator(
            inference_session=inference_session,
            max_frame_num_to_track=None  # Process all frames
        ):
            frame_idx = model_outputs.frame_idx

            # Post-process outputs
            processed_outputs = self.processor.postprocess_outputs(
                inference_session,
                model_outputs
            )

            outputs_per_frame[frame_idx] = processed_outputs

            # Save masks for this frame
            self._save_frame_masks(
                processed_outputs, masks_dir, frame_idx, frame_shape=frame_shape
            )

        return outputs_per_frame

    def _save_frame_masks(
        self,
        outputs: Dict,
        masks_dir: Path,
        frame_idx: int,
        frame_shape: tuple = (1080, 1920),
    ):
        """
        Save the combined binary mask for a single frame.

        All detected objects are merged into one mask.
        Format: mask_NNNN.png (white = object, black = background).

        Args:
            outputs: Post-processed outputs for the frame; 'masks' is expected
                     to be a tensor/array of shape (num_objects, H, W) or (H, W).
            masks_dir: Directory to write the PNG into.
            frame_idx: Frame index used in the filename.
            frame_shape: (height, width) for the mask when nothing is detected
                         or the mask shape is unexpected.
        """
        height, width = frame_shape

        # Extract masks from outputs
        if 'masks' not in outputs or outputs['masks'] is None or len(outputs['masks']) == 0:
            # No objects detected - save an empty mask at the video resolution
            combined_mask = np.zeros((height, width), dtype=np.uint8)
        else:
            masks = outputs['masks']  # Tensor of shape (num_objects, H, W)

            # Convert to numpy if needed
            if torch.is_tensor(masks):
                masks = masks.cpu().numpy()

            # Combine all object masks into single binary mask
            if len(masks.shape) == 3:
                # Multiple objects - combine with logical OR
                combined_mask = np.any(masks > 0.5, axis=0).astype(np.uint8) * 255
            elif len(masks.shape) == 2:
                # Single object
                combined_mask = (masks > 0.5).astype(np.uint8) * 255
            else:
                # Unexpected shape - save empty mask at the video resolution
                combined_mask = np.zeros((height, width), dtype=np.uint8)

        # Save as PNG with maximum compression (masks are mostly flat regions)
        mask_filename = masks_dir / f"mask_{frame_idx:04d}.png"
        mask_image = Image.fromarray(combined_mask)
        mask_image.save(mask_filename, compress_level=9)

    def _create_zip(self, masks_dir: Path, zip_path: Path):
        """Create a deflate-compressed ZIP of all mask PNGs, in frame order."""
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for mask_file in sorted(masks_dir.glob("mask_*.png")):
                zipf.write(mask_file, mask_file.name)

    def _upload_to_hf(self, zip_path: Path, output_repo: str) -> str:
        """
        Upload ZIP to a HuggingFace dataset repository.

        Args:
            zip_path: Local path of the ZIP archive.
            output_repo: Dataset repo ID ("username/dataset-name").

        Returns:
            Direct download URL for the uploaded file.

        Raises:
            RuntimeError: If the Hub API was not initialized (no HF_TOKEN).
        """
        if not self.hf_api:
            raise RuntimeError("HuggingFace Hub API not available. Set HF_TOKEN environment variable.")

        # Upload file to dataset repo
        path_in_repo = f"masks/{zip_path.name}"

        self.hf_api.upload_file(
            path_or_fileobj=str(zip_path),
            path_in_repo=path_in_repo,
            repo_id=output_repo,
            repo_type="dataset",
        )

        # Construct download URL
        download_url = f"https://huggingface.co/datasets/{output_repo}/resolve/main/{path_in_repo}"
        return download_url

    def _get_video_metadata_from_frames(self, frames: list) -> Dict:
        """
        Extract basic metadata from loaded video frames.

        Uses len() rather than truthiness: frames may be a numpy array,
        whose truth value is ambiguous for more than one frame.

        Returns:
            Dict with frame_count/height/width/channels, or {} if no frames.
        """
        if frames is None or len(frames) == 0:
            return {}

        # Frames are numpy arrays of shape (H, W, C)
        first_frame = frames[0]

        return {
            "frame_count": len(frames),
            "height": first_frame.shape[0],
            "width": first_frame.shape[1],
            "channels": first_frame.shape[2] if len(first_frame.shape) > 2 else 1,
        }