Justin331 committed on
Commit
37b4ab0
·
verified ·
1 Parent(s): 3e8dd07

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. handler.py +144 -194
  2. requirements.txt +1 -1
  3. setup.py +8 -0
handler.py CHANGED
@@ -12,8 +12,8 @@ import numpy as np
12
  from PIL import Image
13
  import cv2
14
 
15
- # Transformers imports for SAM3
16
- from transformers import Sam3VideoModel, Sam3VideoProcessor
17
 
18
  # HuggingFace Hub for uploads
19
  try:
@@ -28,19 +28,17 @@ class EndpointHandler:
28
  SAM3 Video Segmentation Handler for HuggingFace Inference Endpoints
29
 
30
  Processes video with text prompts and returns segmentation masks.
31
- Uses transformers library for clean integration with HuggingFace models.
32
  """
33
 
34
  def __init__(self, path: str = ""):
35
  """
36
- Initialize SAM3 video model using transformers.
37
 
38
  Args:
39
- path: Path to model repository (contains model files)
40
- For HF Inference Endpoints, this is /repository
41
- Contains: sam3.pt, config.json, processor_config.json, etc.
42
  """
43
- print(f"[INIT] Initializing SAM3 video model from {path}")
44
 
45
  # Set device
46
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -49,36 +47,14 @@ class EndpointHandler:
49
 
50
  print(f"[INIT] Using device: {self.device}")
51
 
52
- # Load model and processor from the repository
53
- # If path is empty or ".", try to load from default model ID
54
- model_path = path if path and path != "." else "facebook/sam3"
55
-
56
  try:
57
- print(f"[INIT] Loading model from: {model_path}")
58
- self.model = Sam3VideoModel.from_pretrained(
59
- model_path,
60
- torch_dtype=torch.bfloat16,
61
- device_map=self.device
62
- )
63
-
64
- self.processor = Sam3VideoProcessor.from_pretrained(model_path)
65
-
66
- print("[INIT] SAM3 video model loaded successfully")
67
-
68
  except Exception as e:
69
- print(f"[INIT] Error loading from {model_path}: {e}")
70
- print("[INIT] Falling back to facebook/sam3")
71
-
72
- # Fallback to public model
73
- self.model = Sam3VideoModel.from_pretrained(
74
- "facebook/sam3",
75
- torch_dtype=torch.bfloat16,
76
- device_map=self.device
77
- )
78
-
79
- self.processor = Sam3VideoProcessor.from_pretrained("facebook/sam3")
80
-
81
- print("[INIT] SAM3 video model loaded from facebook/sam3")
82
 
83
  # Initialize HuggingFace API for uploads (if available)
84
  self.hf_api = None
@@ -91,7 +67,7 @@ class EndpointHandler:
91
 
92
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
93
  """
94
- Process video segmentation request using transformers API.
95
 
96
  Expected input format:
97
  {
@@ -134,46 +110,53 @@ class EndpointHandler:
134
  video_path = self._prepare_video(video_data, tmpdir_path)
135
  print(f"[STEP 1] Video prepared at: {video_path}")
136
 
137
- # Step 2: Load video frames
138
- video_frames = self._load_video_frames(video_path)
139
- print(f"[STEP 2] Loaded {len(video_frames)} frames")
140
-
141
- # Step 3: Initialize inference session
142
- inference_session = self.processor.init_video_session(
143
- video=video_frames,
144
- inference_device=self.device,
145
- processing_device="cpu",
146
- video_storage_device="cpu",
147
- dtype=torch.bfloat16,
148
  )
149
- print(f"[STEP 3] Inference session initialized")
 
150
 
151
- # Step 4: Add text prompt
152
- inference_session = self.processor.add_text_prompt(
153
- inference_session=inference_session,
154
- text=text_prompt,
 
 
 
 
155
  )
156
- print(f"[STEP 4] Text prompt added")
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
- # Step 5: Propagate through video and save masks
159
  masks_dir = tmpdir_path / "masks"
160
  masks_dir.mkdir()
161
 
162
- frame_outputs = self._propagate_and_save_masks(
163
- inference_session,
164
- masks_dir
165
- )
166
- print(f"[STEP 5] Propagated through {len(frame_outputs)} frames")
167
-
168
- # Get unique object IDs across all frames
169
  all_object_ids = set()
170
- for frame_output in frame_outputs.values():
171
- if 'object_ids' in frame_output and frame_output['object_ids'] is not None:
172
- ids = frame_output['object_ids']
173
- if torch.is_tensor(ids):
174
- all_object_ids.update(ids.tolist())
175
- else:
176
- all_object_ids.update(ids)
 
177
 
178
  # Step 6: Create ZIP archive
179
  zip_path = tmpdir_path / "masks.zip"
@@ -181,34 +164,45 @@ class EndpointHandler:
181
  zip_size_mb = zip_path.stat().st_size / 1e6
182
  print(f"[STEP 6] Created ZIP archive: {zip_size_mb:.2f} MB")
183
 
184
- # Step 7: Prepare response based on return_format
 
 
 
185
  response = {
186
- "frame_count": len(frame_outputs),
187
  "objects_detected": sorted(list(all_object_ids)) if all_object_ids else [],
188
  "compressed_size_mb": round(zip_size_mb, 2),
189
- "video_metadata": self._get_video_metadata_from_frames(video_frames)
190
  }
191
 
192
  if return_format == "download_url" and output_repo:
193
  # Upload to HuggingFace
194
  download_url = self._upload_to_hf(zip_path, output_repo)
195
  response["download_url"] = download_url
196
- print(f"[STEP 7] Uploaded to HuggingFace: {download_url}")
197
 
198
  elif return_format == "base64":
199
  # Return base64 encoded ZIP
200
  with open(zip_path, "rb") as f:
201
- zip_base64 = base64.b64encode(f.read()).decode('utf-8')
202
- response["masks_zip_base64"] = zip_base64
203
- print(f"[STEP 7] Returning base64 encoded ZIP")
204
 
205
  else:
206
- # metadata_only - just return stats
207
- response["note"] = "Masks generated but not returned. Use return_format='base64' or 'download_url' to get masks."
208
- print(f"[STEP 7] Returning metadata only")
 
 
 
 
 
 
 
 
209
 
210
  return response
211
-
212
  except Exception as e:
213
  print(f"[ERROR] {type(e).__name__}: {str(e)}")
214
  import traceback
@@ -218,134 +212,90 @@ class EndpointHandler:
218
  "error_type": type(e).__name__
219
  }
220
 
221
- def _prepare_video(self, video_data: Any, tmpdir: Path) -> Path:
222
- """Decode base64 video data and save to temporary location."""
223
- video_path = tmpdir / "input_video.mp4"
224
-
225
- if isinstance(video_data, str):
226
- # Base64 encoded
227
  video_bytes = base64.b64decode(video_data)
228
- elif isinstance(video_data, bytes):
229
- video_bytes = video_data
230
- else:
231
- raise ValueError(f"Unsupported video data type: {type(video_data)}")
232
 
 
233
  video_path.write_bytes(video_bytes)
234
- return video_path
235
-
236
- def _load_video_frames(self, video_path: Path) -> list:
237
- """Load video frames from MP4 file."""
238
- from transformers.video_utils import load_video
239
 
240
- # load_video returns (frames, audio) - we only need frames
241
- frames, _ = load_video(str(video_path))
242
- return frames
243
 
244
- def _propagate_and_save_masks(self, inference_session, masks_dir: Path) -> Dict[int, Dict]:
245
  """
246
- Propagate masks through video using transformers API and save to disk.
247
-
248
- Returns dict mapping frame_idx -> outputs
249
  """
250
- outputs_per_frame = {}
251
-
252
- # Use the model's propagate_in_video_iterator
253
- for model_outputs in self.model.propagate_in_video_iterator(
254
- inference_session=inference_session,
255
- max_frame_num_to_track=None # Process all frames
256
- ):
257
- frame_idx = model_outputs.frame_idx
258
-
259
- # Post-process outputs
260
- processed_outputs = self.processor.postprocess_outputs(
261
- inference_session,
262
- model_outputs
263
- )
264
-
265
- outputs_per_frame[frame_idx] = processed_outputs
266
-
267
- # Save masks for this frame
268
- self._save_frame_masks(processed_outputs, masks_dir, frame_idx)
269
 
270
- return outputs_per_frame
271
-
272
- def _save_frame_masks(self, outputs: Dict, masks_dir: Path, frame_idx: int):
273
- """
274
- Save masks for a single frame.
275
 
276
- Saves combined binary mask with all objects.
277
- Format: mask_NNNN.png (white = object, black = background)
278
- """
279
- # Extract masks from outputs
280
- if 'masks' not in outputs or outputs['masks'] is None or len(outputs['masks']) == 0:
281
- # No objects detected - save empty mask
282
- # Get dimensions from inference session or use default
283
- height = 1080
284
- width = 1920
285
- combined_mask = np.zeros((height, width), dtype=np.uint8)
286
- else:
287
- masks = outputs['masks'] # Tensor of shape (num_objects, H, W)
288
-
289
- # Convert to numpy if needed
290
- if torch.is_tensor(masks):
291
- masks = masks.cpu().numpy()
292
-
293
- # Combine all object masks into single binary mask
294
- if len(masks.shape) == 3:
295
- # Multiple objects - combine with logical OR
296
- combined_mask = np.any(masks > 0.5, axis=0).astype(np.uint8) * 255
297
- elif len(masks.shape) == 2:
298
- # Single object
299
- combined_mask = (masks > 0.5).astype(np.uint8) * 255
300
- else:
301
- # Unexpected shape - save empty
302
- combined_mask = np.zeros((1080, 1920), dtype=np.uint8)
303
 
304
- # Save as PNG
305
- mask_filename = masks_dir / f"mask_{frame_idx:04d}.png"
306
- mask_image = Image.fromarray(combined_mask)
307
- mask_image.save(mask_filename, compress_level=9)
 
 
 
 
 
 
 
 
308
 
309
  def _create_zip(self, masks_dir: Path, zip_path: Path):
310
  """Create ZIP archive of all mask PNGs."""
311
- with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
312
- for mask_file in sorted(masks_dir.glob("mask_*.png")):
313
  zipf.write(mask_file, mask_file.name)
314
 
315
- def _upload_to_hf(self, zip_path: Path, output_repo: str) -> str:
316
- """
317
- Upload ZIP to HuggingFace dataset repository.
318
-
319
- Returns: Download URL
320
- """
321
- if not self.hf_api:
322
- raise RuntimeError("HuggingFace Hub API not available. Set HF_TOKEN environment variable.")
323
-
324
- # Upload file to dataset repo
325
- path_in_repo = f"masks/{zip_path.name}"
326
-
327
- self.hf_api.upload_file(
328
- path_or_fileobj=str(zip_path),
329
- path_in_repo=path_in_repo,
330
- repo_id=output_repo,
331
- repo_type="dataset",
332
- )
333
-
334
- # Construct download URL
335
- download_url = f"https://huggingface.co/datasets/{output_repo}/resolve/main/{path_in_repo}"
336
- return download_url
337
-
338
- def _get_video_metadata_from_frames(self, frames: list) -> Dict:
339
- """Extract metadata from loaded video frames."""
340
- if not frames or len(frames) == 0:
341
  return {}
 
 
 
 
 
342
 
343
- # Frames are numpy arrays of shape (H, W, C)
344
- first_frame = frames[0]
345
-
346
- return {
347
- "frame_count": len(frames),
348
- "height": first_frame.shape[0],
349
- "width": first_frame.shape[1],
350
- "channels": first_frame.shape[2] if len(first_frame.shape) > 2 else 1,
351
- }
 
 
 
 
 
 
 
 
 
 
 
 
12
  from PIL import Image
13
  import cv2
14
 
15
+ # SAM3 imports - using local sam3 package in repository
16
+ from sam3.model_builder import build_sam3_video_predictor
17
 
18
  # HuggingFace Hub for uploads
19
  try:
 
28
  SAM3 Video Segmentation Handler for HuggingFace Inference Endpoints
29
 
30
  Processes video with text prompts and returns segmentation masks.
31
+ Uses SAM3 repository code directly from local sam3/ package.
32
  """
33
 
34
  def __init__(self, path: str = ""):
35
  """
36
+ Initialize SAM3 video predictor.
37
 
38
  Args:
39
+ path: Path to model repository (not used - model loads from HF automatically)
 
 
40
  """
41
+ print(f"[INIT] Initializing SAM3 video predictor")
42
 
43
  # Set device
44
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
47
 
48
  print(f"[INIT] Using device: {self.device}")
49
 
50
+ # Build SAM3 video predictor
51
+ # This automatically downloads model from facebook/sam3 on HuggingFace
 
 
52
  try:
53
+ self.predictor = build_sam3_video_predictor(gpus_to_use=[0])
54
+ print("[INIT] SAM3 video predictor loaded successfully")
 
 
 
 
 
 
 
 
 
55
  except Exception as e:
56
+ print(f"[INIT] Error loading SAM3 predictor: {e}")
57
+ raise
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # Initialize HuggingFace API for uploads (if available)
60
  self.hf_api = None
 
67
 
68
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
69
  """
70
+ Process video segmentation request using SAM3 video predictor API.
71
 
72
  Expected input format:
73
  {
 
110
  video_path = self._prepare_video(video_data, tmpdir_path)
111
  print(f"[STEP 1] Video prepared at: {video_path}")
112
 
113
+ # Step 2: Start SAM3 session
114
+ response = self.predictor.handle_request(
115
+ request=dict(
116
+ type="start_session",
117
+ resource_path=str(video_path),
118
+ )
 
 
 
 
 
119
  )
120
+ session_id = response["session_id"]
121
+ print(f"[STEP 2] Session started: {session_id}")
122
 
123
+ # Step 3: Add text prompt
124
+ response = self.predictor.handle_request(
125
+ request=dict(
126
+ type="add_prompt",
127
+ session_id=session_id,
128
+ frame_index=0, # Add prompt on first frame
129
+ text=text_prompt,
130
+ )
131
  )
132
+ print(f"[STEP 3] Text prompt added")
133
+
134
+ # Step 4: Propagate through video and collect outputs
135
+ outputs_per_frame = {}
136
+ for stream_response in self.predictor.handle_stream_request(
137
+ request=dict(
138
+ type="propagate_in_video",
139
+ session_id=session_id,
140
+ )
141
+ ):
142
+ frame_idx = stream_response["frame_index"]
143
+ outputs_per_frame[frame_idx] = stream_response["outputs"]
144
+
145
+ print(f"[STEP 4] Propagated through {len(outputs_per_frame)} frames")
146
 
147
+ # Step 5: Save masks to PNG files
148
  masks_dir = tmpdir_path / "masks"
149
  masks_dir.mkdir()
150
 
 
 
 
 
 
 
 
151
  all_object_ids = set()
152
+ for frame_idx, frame_output in outputs_per_frame.items():
153
+ self._save_frame_masks(frame_output, masks_dir, frame_idx)
154
+
155
+ # Collect object IDs
156
+ if "object_ids" in frame_output and frame_output["object_ids"] is not None:
157
+ all_object_ids.update(frame_output["object_ids"])
158
+
159
+ print(f"[STEP 5] Saved masks for {len(outputs_per_frame)} frames")
160
 
161
  # Step 6: Create ZIP archive
162
  zip_path = tmpdir_path / "masks.zip"
 
164
  zip_size_mb = zip_path.stat().st_size / 1e6
165
  print(f"[STEP 6] Created ZIP archive: {zip_size_mb:.2f} MB")
166
 
167
+ # Step 7: Get video metadata
168
+ video_metadata = self._get_video_metadata(video_path)
169
+
170
+ # Step 8: Prepare response based on return_format
171
  response = {
172
+ "frame_count": len(outputs_per_frame),
173
  "objects_detected": sorted(list(all_object_ids)) if all_object_ids else [],
174
  "compressed_size_mb": round(zip_size_mb, 2),
175
+ "video_metadata": video_metadata
176
  }
177
 
178
  if return_format == "download_url" and output_repo:
179
  # Upload to HuggingFace
180
  download_url = self._upload_to_hf(zip_path, output_repo)
181
  response["download_url"] = download_url
182
+ print(f"[STEP 8] Uploaded to HuggingFace: {download_url}")
183
 
184
  elif return_format == "base64":
185
  # Return base64 encoded ZIP
186
  with open(zip_path, "rb") as f:
187
+ zip_bytes = f.read()
188
+ response["masks_zip_base64"] = base64.b64encode(zip_bytes).decode("utf-8")
189
+ print(f"[STEP 8] Encoded ZIP to base64")
190
 
191
  else:
192
+ # metadata_only - just return the stats
193
+ print(f"[STEP 8] Returning metadata only")
194
+
195
+ # Step 9: Close session
196
+ self.predictor.handle_request(
197
+ request=dict(
198
+ type="close_session",
199
+ session_id=session_id,
200
+ )
201
+ )
202
+ print(f"[STEP 9] Session closed")
203
 
204
  return response
205
+
206
  except Exception as e:
207
  print(f"[ERROR] {type(e).__name__}: {str(e)}")
208
  import traceback
 
212
  "error_type": type(e).__name__
213
  }
214
 
215
+ def _prepare_video(self, video_data: str, tmpdir: Path) -> Path:
216
+ """Decode base64 video and save to file."""
217
+ try:
 
 
 
218
  video_bytes = base64.b64decode(video_data)
219
+ except Exception as e:
220
+ raise ValueError(f"Failed to decode base64 video: {e}")
 
 
221
 
222
+ video_path = tmpdir / "input_video.mp4"
223
  video_path.write_bytes(video_bytes)
 
 
 
 
 
224
 
225
+ return video_path
 
 
226
 
227
+ def _save_frame_masks(self, frame_output: Dict, masks_dir: Path, frame_idx: int):
228
  """
229
+ Save masks for a frame as PNG files.
230
+ Each object gets its own mask file: frame_XXXX_obj_Y.png
 
231
  """
232
+ if "masks" not in frame_output or frame_output["masks"] is None:
233
+ return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
+ masks = frame_output["masks"]
236
+ object_ids = frame_output.get("object_ids", [])
 
 
 
237
 
238
+ # Convert to numpy if tensor
239
+ if torch.is_tensor(masks):
240
+ masks = masks.cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
+ # Save each object's mask
243
+ for i, obj_id in enumerate(object_ids):
244
+ if i < len(masks):
245
+ mask = masks[i]
246
+
247
+ # Convert to binary (0 or 255)
248
+ mask_binary = (mask > 0.5).astype(np.uint8) * 255
249
+
250
+ # Save as PNG
251
+ mask_img = Image.fromarray(mask_binary)
252
+ mask_filename = f"frame_{frame_idx:05d}_obj_{obj_id}.png"
253
+ mask_img.save(masks_dir / mask_filename, compress_level=9)
254
 
255
  def _create_zip(self, masks_dir: Path, zip_path: Path):
256
  """Create ZIP archive of all mask PNGs."""
257
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED, compresslevel=9) as zipf:
258
+ for mask_file in sorted(masks_dir.glob("*.png")):
259
  zipf.write(mask_file, mask_file.name)
260
 
261
+ def _get_video_metadata(self, video_path: Path) -> Dict[str, Any]:
262
+ """Extract video metadata using OpenCV."""
263
+ try:
264
+ cap = cv2.VideoCapture(str(video_path))
265
+ metadata = {
266
+ "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
267
+ "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
268
+ "fps": float(cap.get(cv2.CAP_PROP_FPS)),
269
+ "frame_count": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
270
+ }
271
+ cap.release()
272
+ return metadata
273
+ except Exception as e:
274
+ print(f"[WARNING] Could not extract video metadata: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
275
  return {}
276
+
277
+ def _upload_to_hf(self, zip_path: Path, repo_id: str) -> str:
278
+ """Upload ZIP file to HuggingFace dataset repository."""
279
+ if not self.hf_api:
280
+ raise ValueError("HuggingFace Hub API not initialized. Set HF_TOKEN environment variable.")
281
 
282
+ try:
283
+ # Generate unique filename
284
+ import time
285
+ timestamp = int(time.time())
286
+ filename = f"masks_{timestamp}.zip"
287
+
288
+ # Upload file
289
+ url = self.hf_api.upload_file(
290
+ path_or_fileobj=str(zip_path),
291
+ path_in_repo=filename,
292
+ repo_id=repo_id,
293
+ repo_type="dataset",
294
+ )
295
+
296
+ # Return download URL
297
+ download_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{filename}"
298
+ return download_url
299
+
300
+ except Exception as e:
301
+ raise ValueError(f"Failed to upload to HuggingFace: {e}")
requirements.txt CHANGED
@@ -1 +1 @@
1
- git+https://github.com/huggingface/transformers.git
 
1
+ .
setup.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# Packaging script: registers the packages found next to this file under the
# distribution name "sam3" so the repository can be pip-installed.
from setuptools import find_packages, setup

package_metadata = {
    "name": "sam3",
    "version": "0.1.0",
    "description": "A local package for the SAM3 model and utilities.",
    "packages": find_packages(),
}

setup(**package_metadata)