| | import os |
| | import io |
| | import base64 |
| | import tempfile |
| | import zipfile |
| | from typing import Dict, Any, Optional |
| | from pathlib import Path |
| | import json |
| |
|
| | import torch |
| | import numpy as np |
| | from PIL import Image |
| | import cv2 |
| |
|
| | |
| | from transformers import Sam3VideoModel, Sam3VideoProcessor |
| |
|
| | |
| | try: |
| | from huggingface_hub import HfApi |
| | HF_HUB_AVAILABLE = True |
| | except ImportError: |
| | HF_HUB_AVAILABLE = False |
| |
|
| |
|
class EndpointHandler:
    """
    SAM3 Video Segmentation Handler for HuggingFace Inference Endpoints.

    Processes a video with a text prompt and returns per-frame segmentation
    masks (a ZIP of binary PNGs) plus metadata, built on the transformers
    Sam3VideoModel / Sam3VideoProcessor video-session API.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the SAM3 video model using transformers.

        Args:
            path: Path to the model repository (contains model files).
                For HF Inference Endpoints this is /repository. When empty,
                ".", or loading from it fails, falls back to "facebook/sam3".

        Raises:
            ValueError: If no CUDA device is available.
        """
        print(f"[INIT] Initializing SAM3 video model from {path}")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device != "cuda":
            # Video propagation in bfloat16 is impractical on CPU; fail fast.
            raise ValueError("SAM3 requires GPU acceleration. No CUDA device found.")

        print(f"[INIT] Using device: {self.device}")

        model_path = path if path and path != "." else "facebook/sam3"

        try:
            print(f"[INIT] Loading model from: {model_path}")
            self.model, self.processor = self._load_model_and_processor(model_path)
            print("[INIT] SAM3 video model loaded successfully")
        except Exception as e:
            # Any failure (missing files, bad config, ...) falls back to the Hub copy.
            print(f"[INIT] Error loading from {model_path}: {e}")
            print("[INIT] Falling back to facebook/sam3")
            self.model, self.processor = self._load_model_and_processor("facebook/sam3")
            print("[INIT] SAM3 video model loaded from facebook/sam3")

        # Optional HuggingFace Hub client, used only for "download_url" results.
        self.hf_api = None
        hf_token = os.getenv("HF_TOKEN")
        if HF_HUB_AVAILABLE and hf_token:
            self.hf_api = HfApi(token=hf_token)
            print("[INIT] HuggingFace Hub API initialized")
        else:
            print("[INIT] HuggingFace Hub uploads disabled (no token or huggingface_hub not installed)")

    def _load_model_and_processor(self, model_path: str):
        """Load Sam3VideoModel + Sam3VideoProcessor from *model_path* in bfloat16 on self.device."""
        model = Sam3VideoModel.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map=self.device,
        )
        processor = Sam3VideoProcessor.from_pretrained(model_path)
        return model, processor

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a video segmentation request using the transformers API.

        Expected input format:
            {
                "video": <base64_encoded_video, data URL, or raw bytes>,
                "text_prompt": "object to segment",   # "inputs" accepted as alias
                "return_format": "download_url" | "base64" | "metadata_only",  # optional
                "output_repo": "username/dataset-name",  # required for "download_url"
            }

        Returns:
            Dict with "frame_count", "objects_detected", "compressed_size_mb",
            "video_metadata"; plus "download_url" or "masks_zip_base64"
            depending on return_format. On failure returns
            {"error": ..., "error_type": ...} instead of raising.
        """
        try:
            video_data = data.get("video")
            # "inputs" is accepted as an alias for "text_prompt" (HF endpoint convention).
            text_prompt = data.get("text_prompt", data.get("inputs", ""))
            output_repo = data.get("output_repo")
            return_format = data.get("return_format", "metadata_only")

            if not video_data:
                return {"error": "No video data provided. Include 'video' in request."}

            if not text_prompt:
                return {"error": "No text prompt provided. Include 'text_prompt' or 'inputs' in request."}

            print(f"[REQUEST] Processing video with prompt: '{text_prompt}'")
            print(f"[REQUEST] Return format: {return_format}")

            with tempfile.TemporaryDirectory() as tmpdir:
                tmpdir_path = Path(tmpdir)

                video_path = self._prepare_video(video_data, tmpdir_path)
                print(f"[STEP 1] Video prepared at: {video_path}")

                video_frames = self._load_video_frames(video_path)
                print(f"[STEP 2] Loaded {len(video_frames)} frames")

                # Actual frame geometry, so empty-frame masks match the video size
                # instead of the previous hard-coded 1920x1080 fallback.
                frame_shape = None
                if len(video_frames) > 0:
                    frame_shape = (video_frames[0].shape[0], video_frames[0].shape[1])

                inference_session = self.processor.init_video_session(
                    video=video_frames,
                    inference_device=self.device,
                    processing_device="cpu",
                    video_storage_device="cpu",
                    dtype=torch.bfloat16,
                )
                print(f"[STEP 3] Inference session initialized")

                inference_session = self.processor.add_text_prompt(
                    inference_session=inference_session,
                    text=text_prompt,
                )
                print(f"[STEP 4] Text prompt added")

                masks_dir = tmpdir_path / "masks"
                masks_dir.mkdir()

                frame_outputs = self._propagate_and_save_masks(
                    inference_session,
                    masks_dir,
                    frame_shape=frame_shape,
                )
                print(f"[STEP 5] Propagated through {len(frame_outputs)} frames")

                all_object_ids = self._collect_object_ids(frame_outputs)

                zip_path = tmpdir_path / "masks.zip"
                self._create_zip(masks_dir, zip_path)
                zip_size_mb = zip_path.stat().st_size / 1e6
                print(f"[STEP 6] Created ZIP archive: {zip_size_mb:.2f} MB")

                response = {
                    "frame_count": len(frame_outputs),
                    "objects_detected": sorted(all_object_ids),
                    "compressed_size_mb": round(zip_size_mb, 2),
                    "video_metadata": self._get_video_metadata_from_frames(video_frames),
                }

                if return_format == "download_url" and output_repo:
                    download_url = self._upload_to_hf(zip_path, output_repo)
                    response["download_url"] = download_url
                    print(f"[STEP 7] Uploaded to HuggingFace: {download_url}")

                elif return_format == "base64":
                    with open(zip_path, "rb") as f:
                        zip_base64 = base64.b64encode(f.read()).decode('utf-8')
                    response["masks_zip_base64"] = zip_base64
                    print(f"[STEP 7] Returning base64 encoded ZIP")

                else:
                    # Fix: a "download_url" request without output_repo used to
                    # degrade to metadata-only silently; surface that explicitly.
                    if return_format == "download_url":
                        response["warning"] = "return_format='download_url' requires 'output_repo'; returning metadata only."
                    response["note"] = "Masks generated but not returned. Use return_format='base64' or 'download_url' to get masks."
                    print(f"[STEP 7] Returning metadata only")

                return response

        except Exception as e:
            # Endpoint boundary: convert any failure into a structured error payload.
            print(f"[ERROR] {type(e).__name__}: {str(e)}")
            import traceback
            traceback.print_exc()
            return {
                "error": str(e),
                "error_type": type(e).__name__,
            }

    @staticmethod
    def _collect_object_ids(frame_outputs: Dict[int, Dict]) -> set:
        """Return the union of object IDs seen across all per-frame outputs."""
        all_object_ids = set()
        for frame_output in frame_outputs.values():
            if 'object_ids' in frame_output and frame_output['object_ids'] is not None:
                ids = frame_output['object_ids']
                if torch.is_tensor(ids):
                    all_object_ids.update(ids.tolist())
                else:
                    all_object_ids.update(ids)
        return all_object_ids

    def _prepare_video(self, video_data: Any, tmpdir: Path) -> Path:
        """
        Decode video payload and write it to *tmpdir*/input_video.mp4.

        Accepts a plain base64 string, a "data:...;base64,..." data URL,
        or raw bytes. Returns the path of the written file.

        Raises:
            ValueError: For unsupported payload types.
        """
        video_path = tmpdir / "input_video.mp4"

        if isinstance(video_data, str):
            # Tolerate data-URL payloads by stripping the header before decoding.
            if video_data.startswith("data:") and "," in video_data:
                video_data = video_data.split(",", 1)[1]
            video_bytes = base64.b64decode(video_data)
        elif isinstance(video_data, bytes):
            video_bytes = video_data
        else:
            raise ValueError(f"Unsupported video data type: {type(video_data)}")

        video_path.write_bytes(video_bytes)
        return video_path

    def _load_video_frames(self, video_path: Path) -> list:
        """Load all frames from an MP4 file (metadata from load_video is discarded)."""
        from transformers.video_utils import load_video

        frames, _ = load_video(str(video_path))
        return frames

    def _propagate_and_save_masks(
        self,
        inference_session,
        masks_dir: Path,
        frame_shape: Optional[tuple] = None,
    ) -> Dict[int, Dict]:
        """
        Propagate masks through the video and save one PNG per frame.

        Args:
            inference_session: Session created by the processor.
            masks_dir: Directory receiving mask_NNNN.png files.
            frame_shape: Optional (height, width) used for frames without masks.

        Returns:
            Dict mapping frame_idx -> post-processed outputs.
        """
        outputs_per_frame = {}

        for model_outputs in self.model.propagate_in_video_iterator(
            inference_session=inference_session,
            max_frame_num_to_track=None,
        ):
            frame_idx = model_outputs.frame_idx

            processed_outputs = self.processor.postprocess_outputs(
                inference_session,
                model_outputs,
            )

            outputs_per_frame[frame_idx] = processed_outputs

            # Persist immediately so mask images never accumulate in memory.
            self._save_frame_masks(processed_outputs, masks_dir, frame_idx, frame_shape)

        return outputs_per_frame

    def _save_frame_masks(
        self,
        outputs: Dict,
        masks_dir: Path,
        frame_idx: int,
        frame_shape: Optional[tuple] = None,
    ):
        """
        Save the combined binary mask for one frame.

        Format: mask_NNNN.png (white = any object, black = background).

        Args:
            frame_shape: (height, width) used for frames with no masks;
                defaults to (1080, 1920) when unknown (legacy behavior).
        """
        # Fix: empty frames previously always produced 1920x1080 masks
        # regardless of the actual video resolution.
        height, width = frame_shape if frame_shape else (1080, 1920)

        if 'masks' not in outputs or outputs['masks'] is None or len(outputs['masks']) == 0:
            combined_mask = np.zeros((height, width), dtype=np.uint8)
        else:
            masks = outputs['masks']

            if torch.is_tensor(masks):
                masks = masks.cpu().numpy()

            if len(masks.shape) == 3:
                # Assumed (num_objects, H, W): union of all object masks.
                combined_mask = np.any(masks > 0.5, axis=0).astype(np.uint8) * 255
            elif len(masks.shape) == 2:
                # Single (H, W) mask.
                combined_mask = (masks > 0.5).astype(np.uint8) * 255
            else:
                # Unexpected rank; emit an empty mask rather than fail the request.
                combined_mask = np.zeros((height, width), dtype=np.uint8)

        mask_filename = masks_dir / f"mask_{frame_idx:04d}.png"
        mask_image = Image.fromarray(combined_mask)
        mask_image.save(mask_filename, compress_level=9)

    def _create_zip(self, masks_dir: Path, zip_path: Path):
        """Create a ZIP archive of all mask PNGs, in sorted (frame) order."""
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for mask_file in sorted(masks_dir.glob("mask_*.png")):
                zipf.write(mask_file, mask_file.name)

    def _upload_to_hf(self, zip_path: Path, output_repo: str) -> str:
        """
        Upload the ZIP to a HuggingFace dataset repository.

        Returns:
            Public resolve URL for the uploaded file.

        Raises:
            RuntimeError: If the Hub client was not initialized (no HF_TOKEN).
        """
        if not self.hf_api:
            raise RuntimeError("HuggingFace Hub API not available. Set HF_TOKEN environment variable.")

        path_in_repo = f"masks/{zip_path.name}"

        self.hf_api.upload_file(
            path_or_fileobj=str(zip_path),
            path_in_repo=path_in_repo,
            repo_id=output_repo,
            repo_type="dataset",
        )

        return f"https://huggingface.co/datasets/{output_repo}/resolve/main/{path_in_repo}"

    def _get_video_metadata_from_frames(self, frames: list) -> Dict:
        """Return frame_count/height/width/channels from loaded frames ({} if empty)."""
        if not frames:
            return {}

        first_frame = frames[0]

        return {
            "frame_count": len(frames),
            "height": first_frame.shape[0],
            "width": first_frame.shape[1],
            "channels": first_frame.shape[2] if len(first_frame.shape) > 2 else 1,
        }
| |
|