Justin331 commited on
Commit
41d206b
·
verified ·
1 Parent(s): 682c8e3

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. handler.py +351 -0
  2. requirements.txt +16 -0
handler.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import base64
4
+ import tempfile
5
+ import zipfile
6
+ from typing import Dict, Any, Optional
7
+ from pathlib import Path
8
+ import json
9
+
10
+ import torch
11
+ import numpy as np
12
+ from PIL import Image
13
+ import cv2
14
+
15
+ # Transformers imports for SAM3
16
+ from transformers import Sam3VideoModel, Sam3VideoProcessor
17
+
18
+ # HuggingFace Hub for uploads
19
+ try:
20
+ from huggingface_hub import HfApi
21
+ HF_HUB_AVAILABLE = True
22
+ except ImportError:
23
+ HF_HUB_AVAILABLE = False
24
+
25
+
26
+ class EndpointHandler:
27
+ """
28
+ SAM3 Video Segmentation Handler for HuggingFace Inference Endpoints
29
+
30
+ Processes video with text prompts and returns segmentation masks.
31
+ Uses transformers library for clean integration with HuggingFace models.
32
+ """
33
+
34
+ def __init__(self, path: str = ""):
35
+ """
36
+ Initialize SAM3 video model using transformers.
37
+
38
+ Args:
39
+ path: Path to model repository (contains model files)
40
+ For HF Inference Endpoints, this is /repository
41
+ Contains: sam3.pt, config.json, processor_config.json, etc.
42
+ """
43
+ print(f"[INIT] Initializing SAM3 video model from {path}")
44
+
45
+ # Set device
46
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
47
+ if self.device != "cuda":
48
+ raise ValueError("SAM3 requires GPU acceleration. No CUDA device found.")
49
+
50
+ print(f"[INIT] Using device: {self.device}")
51
+
52
+ # Load model and processor from the repository
53
+ # If path is empty or ".", try to load from default model ID
54
+ model_path = path if path and path != "." else "facebook/sam3"
55
+
56
+ try:
57
+ print(f"[INIT] Loading model from: {model_path}")
58
+ self.model = Sam3VideoModel.from_pretrained(
59
+ model_path,
60
+ torch_dtype=torch.bfloat16,
61
+ device_map=self.device
62
+ )
63
+
64
+ self.processor = Sam3VideoProcessor.from_pretrained(model_path)
65
+
66
+ print("[INIT] SAM3 video model loaded successfully")
67
+
68
+ except Exception as e:
69
+ print(f"[INIT] Error loading from {model_path}: {e}")
70
+ print("[INIT] Falling back to facebook/sam3")
71
+
72
+ # Fallback to public model
73
+ self.model = Sam3VideoModel.from_pretrained(
74
+ "facebook/sam3",
75
+ torch_dtype=torch.bfloat16,
76
+ device_map=self.device
77
+ )
78
+
79
+ self.processor = Sam3VideoProcessor.from_pretrained("facebook/sam3")
80
+
81
+ print("[INIT] SAM3 video model loaded from facebook/sam3")
82
+
83
+ # Initialize HuggingFace API for uploads (if available)
84
+ self.hf_api = None
85
+ hf_token = os.getenv("HF_TOKEN")
86
+ if HF_HUB_AVAILABLE and hf_token:
87
+ self.hf_api = HfApi(token=hf_token)
88
+ print("[INIT] HuggingFace Hub API initialized")
89
+ else:
90
+ print("[INIT] HuggingFace Hub uploads disabled (no token or huggingface_hub not installed)")
91
+
92
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
93
+ """
94
+ Process video segmentation request using transformers API.
95
+
96
+ Expected input format:
97
+ {
98
+ "video": <base64_encoded_video>,
99
+ "text_prompt": "object to segment",
100
+ "return_format": "download_url" or "base64" or "metadata_only" # optional
101
+ "output_repo": "username/dataset-name", # optional, for HF upload
102
+ }
103
+
104
+ Returns:
105
+ {
106
+ "download_url": "https://...", # if uploaded to HF
107
+ "frame_count": 120,
108
+ "video_metadata": {...},
109
+ "compressed_size_mb": 15.3,
110
+ "objects_detected": [1, 2, 3] # object IDs
111
+ }
112
+ """
113
+ try:
114
+ # Extract parameters
115
+ video_data = data.get("video")
116
+ text_prompt = data.get("text_prompt", data.get("inputs", ""))
117
+ output_repo = data.get("output_repo")
118
+ return_format = data.get("return_format", "metadata_only")
119
+
120
+ if not video_data:
121
+ return {"error": "No video data provided. Include 'video' in request."}
122
+
123
+ if not text_prompt:
124
+ return {"error": "No text prompt provided. Include 'text_prompt' or 'inputs' in request."}
125
+
126
+ print(f"[REQUEST] Processing video with prompt: '{text_prompt}'")
127
+ print(f"[REQUEST] Return format: {return_format}")
128
+
129
+ # Process video in temporary directory
130
+ with tempfile.TemporaryDirectory() as tmpdir:
131
+ tmpdir_path = Path(tmpdir)
132
+
133
+ # Step 1: Decode and save video
134
+ video_path = self._prepare_video(video_data, tmpdir_path)
135
+ print(f"[STEP 1] Video prepared at: {video_path}")
136
+
137
+ # Step 2: Load video frames
138
+ video_frames = self._load_video_frames(video_path)
139
+ print(f"[STEP 2] Loaded {len(video_frames)} frames")
140
+
141
+ # Step 3: Initialize inference session
142
+ inference_session = self.processor.init_video_session(
143
+ video=video_frames,
144
+ inference_device=self.device,
145
+ processing_device="cpu",
146
+ video_storage_device="cpu",
147
+ dtype=torch.bfloat16,
148
+ )
149
+ print(f"[STEP 3] Inference session initialized")
150
+
151
+ # Step 4: Add text prompt
152
+ inference_session = self.processor.add_text_prompt(
153
+ inference_session=inference_session,
154
+ text=text_prompt,
155
+ )
156
+ print(f"[STEP 4] Text prompt added")
157
+
158
+ # Step 5: Propagate through video and save masks
159
+ masks_dir = tmpdir_path / "masks"
160
+ masks_dir.mkdir()
161
+
162
+ frame_outputs = self._propagate_and_save_masks(
163
+ inference_session,
164
+ masks_dir
165
+ )
166
+ print(f"[STEP 5] Propagated through {len(frame_outputs)} frames")
167
+
168
+ # Get unique object IDs across all frames
169
+ all_object_ids = set()
170
+ for frame_output in frame_outputs.values():
171
+ if 'object_ids' in frame_output and frame_output['object_ids'] is not None:
172
+ ids = frame_output['object_ids']
173
+ if torch.is_tensor(ids):
174
+ all_object_ids.update(ids.tolist())
175
+ else:
176
+ all_object_ids.update(ids)
177
+
178
+ # Step 6: Create ZIP archive
179
+ zip_path = tmpdir_path / "masks.zip"
180
+ self._create_zip(masks_dir, zip_path)
181
+ zip_size_mb = zip_path.stat().st_size / 1e6
182
+ print(f"[STEP 6] Created ZIP archive: {zip_size_mb:.2f} MB")
183
+
184
+ # Step 7: Prepare response based on return_format
185
+ response = {
186
+ "frame_count": len(frame_outputs),
187
+ "objects_detected": sorted(list(all_object_ids)) if all_object_ids else [],
188
+ "compressed_size_mb": round(zip_size_mb, 2),
189
+ "video_metadata": self._get_video_metadata_from_frames(video_frames)
190
+ }
191
+
192
+ if return_format == "download_url" and output_repo:
193
+ # Upload to HuggingFace
194
+ download_url = self._upload_to_hf(zip_path, output_repo)
195
+ response["download_url"] = download_url
196
+ print(f"[STEP 7] Uploaded to HuggingFace: {download_url}")
197
+
198
+ elif return_format == "base64":
199
+ # Return base64 encoded ZIP
200
+ with open(zip_path, "rb") as f:
201
+ zip_base64 = base64.b64encode(f.read()).decode('utf-8')
202
+ response["masks_zip_base64"] = zip_base64
203
+ print(f"[STEP 7] Returning base64 encoded ZIP")
204
+
205
+ else:
206
+ # metadata_only - just return stats
207
+ response["note"] = "Masks generated but not returned. Use return_format='base64' or 'download_url' to get masks."
208
+ print(f"[STEP 7] Returning metadata only")
209
+
210
+ return response
211
+
212
+ except Exception as e:
213
+ print(f"[ERROR] {type(e).__name__}: {str(e)}")
214
+ import traceback
215
+ traceback.print_exc()
216
+ return {
217
+ "error": str(e),
218
+ "error_type": type(e).__name__
219
+ }
220
+
221
+ def _prepare_video(self, video_data: Any, tmpdir: Path) -> Path:
222
+ """Decode base64 video data and save to temporary location."""
223
+ video_path = tmpdir / "input_video.mp4"
224
+
225
+ if isinstance(video_data, str):
226
+ # Base64 encoded
227
+ video_bytes = base64.b64decode(video_data)
228
+ elif isinstance(video_data, bytes):
229
+ video_bytes = video_data
230
+ else:
231
+ raise ValueError(f"Unsupported video data type: {type(video_data)}")
232
+
233
+ video_path.write_bytes(video_bytes)
234
+ return video_path
235
+
236
+ def _load_video_frames(self, video_path: Path) -> list:
237
+ """Load video frames from MP4 file."""
238
+ from transformers.video_utils import load_video
239
+
240
+ # load_video returns (frames, audio) - we only need frames
241
+ frames, _ = load_video(str(video_path))
242
+ return frames
243
+
244
+ def _propagate_and_save_masks(self, inference_session, masks_dir: Path) -> Dict[int, Dict]:
245
+ """
246
+ Propagate masks through video using transformers API and save to disk.
247
+
248
+ Returns dict mapping frame_idx -> outputs
249
+ """
250
+ outputs_per_frame = {}
251
+
252
+ # Use the model's propagate_in_video_iterator
253
+ for model_outputs in self.model.propagate_in_video_iterator(
254
+ inference_session=inference_session,
255
+ max_frame_num_to_track=None # Process all frames
256
+ ):
257
+ frame_idx = model_outputs.frame_idx
258
+
259
+ # Post-process outputs
260
+ processed_outputs = self.processor.postprocess_outputs(
261
+ inference_session,
262
+ model_outputs
263
+ )
264
+
265
+ outputs_per_frame[frame_idx] = processed_outputs
266
+
267
+ # Save masks for this frame
268
+ self._save_frame_masks(processed_outputs, masks_dir, frame_idx)
269
+
270
+ return outputs_per_frame
271
+
272
+ def _save_frame_masks(self, outputs: Dict, masks_dir: Path, frame_idx: int):
273
+ """
274
+ Save masks for a single frame.
275
+
276
+ Saves combined binary mask with all objects.
277
+ Format: mask_NNNN.png (white = object, black = background)
278
+ """
279
+ # Extract masks from outputs
280
+ if 'masks' not in outputs or outputs['masks'] is None or len(outputs['masks']) == 0:
281
+ # No objects detected - save empty mask
282
+ # Get dimensions from inference session or use default
283
+ height = 1080
284
+ width = 1920
285
+ combined_mask = np.zeros((height, width), dtype=np.uint8)
286
+ else:
287
+ masks = outputs['masks'] # Tensor of shape (num_objects, H, W)
288
+
289
+ # Convert to numpy if needed
290
+ if torch.is_tensor(masks):
291
+ masks = masks.cpu().numpy()
292
+
293
+ # Combine all object masks into single binary mask
294
+ if len(masks.shape) == 3:
295
+ # Multiple objects - combine with logical OR
296
+ combined_mask = np.any(masks > 0.5, axis=0).astype(np.uint8) * 255
297
+ elif len(masks.shape) == 2:
298
+ # Single object
299
+ combined_mask = (masks > 0.5).astype(np.uint8) * 255
300
+ else:
301
+ # Unexpected shape - save empty
302
+ combined_mask = np.zeros((1080, 1920), dtype=np.uint8)
303
+
304
+ # Save as PNG
305
+ mask_filename = masks_dir / f"mask_{frame_idx:04d}.png"
306
+ mask_image = Image.fromarray(combined_mask)
307
+ mask_image.save(mask_filename, compress_level=9)
308
+
309
+ def _create_zip(self, masks_dir: Path, zip_path: Path):
310
+ """Create ZIP archive of all mask PNGs."""
311
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
312
+ for mask_file in sorted(masks_dir.glob("mask_*.png")):
313
+ zipf.write(mask_file, mask_file.name)
314
+
315
+ def _upload_to_hf(self, zip_path: Path, output_repo: str) -> str:
316
+ """
317
+ Upload ZIP to HuggingFace dataset repository.
318
+
319
+ Returns: Download URL
320
+ """
321
+ if not self.hf_api:
322
+ raise RuntimeError("HuggingFace Hub API not available. Set HF_TOKEN environment variable.")
323
+
324
+ # Upload file to dataset repo
325
+ path_in_repo = f"masks/{zip_path.name}"
326
+
327
+ self.hf_api.upload_file(
328
+ path_or_fileobj=str(zip_path),
329
+ path_in_repo=path_in_repo,
330
+ repo_id=output_repo,
331
+ repo_type="dataset",
332
+ )
333
+
334
+ # Construct download URL
335
+ download_url = f"https://huggingface.co/datasets/{output_repo}/resolve/main/{path_in_repo}"
336
+ return download_url
337
+
338
+ def _get_video_metadata_from_frames(self, frames: list) -> Dict:
339
+ """Extract metadata from loaded video frames."""
340
+ if not frames or len(frames) == 0:
341
+ return {}
342
+
343
+ # Frames are numpy arrays of shape (H, W, C)
344
+ first_frame = frames[0]
345
+
346
+ return {
347
+ "frame_count": len(frames),
348
+ "height": first_frame.shape[0],
349
+ "width": first_frame.shape[1],
350
+ "channels": first_frame.shape[2] if len(first_frame.shape) > 2 else 1,
351
+ }
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SAM3 Handler Requirements for HuggingFace Inference Endpoints
2
+
3
+ # PyTorch (must be installed with CUDA support)
4
+ torch>=2.7.0
5
+ torchvision
6
+
7
+ # Transformers library (includes SAM3 support)
8
+ transformers>=4.38.0
9
+ accelerate
10
+
11
+ # Image/Video processing
12
+ opencv-python
13
+ Pillow
14
+
15
+ # HuggingFace Hub for uploads
16
+ huggingface_hub