| | """ |
| | SAM 3 Custom Inference Handler for Hugging Face Inference Endpoints |
| | Model: facebook/sam3 |
| | |
| | For ProofPath video assessment - text-prompted segmentation to find UI elements. |
| | Supports text prompts like "Save button", "dropdown menu", "text input field". |
| | |
| | KEY CAPABILITIES: |
| | - Text-to-segment: Find ALL instances of a concept (e.g., "button" → all buttons) |
| | - Promptable Concept Segmentation (PCS): 270K unique concepts |
| | - Video tracking: Consistent object IDs across frames |
| | - Presence token: Discriminates similar elements ("player in white" vs "player in red") |
| | |
| | REQUIREMENTS: |
| | 1. Set HF_TOKEN environment variable (model is gated) |
| | 2. Accept license at https://huggingface.co/facebook/sam3 |
| | """ |
| |
|
| | from typing import Dict, List, Any, Optional, Union |
| | import torch |
| | import numpy as np |
| | import base64 |
| | import io |
| | import os |
| |
|
| |
|
class EndpointHandler:
    """Inference Endpoints handler wrapping SAM 3 (``facebook/sam3``).

    Supports text-prompted segmentation of single images and image batches,
    ProofPath UI-element detection on screenshots, and text-prompted object
    tracking in videos. The video model is loaded lazily on first use.
    """

    # Hub repository both the image and the video models are loaded from.
    MODEL_ID = "facebook/sam3"
    # Element types searched for in "ui_elements" mode when the request names none.
    DEFAULT_UI_ELEMENTS = ("button", "text input", "dropdown", "checkbox", "link")
    # Seconds to wait on remote image/video downloads before giving up.
    HTTP_TIMEOUT = 30

    def __init__(self, path: str = ""):
        """
        Initialize SAM 3 model for text-prompted segmentation.

        Args:
            path: Path to the model directory (ignored - we load from HF hub)
        """
        from transformers import Sam3Processor, Sam3Model

        hf_token = self._hf_token()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.processor = Sam3Processor.from_pretrained(self.MODEL_ID, token=hf_token)
        self.model = Sam3Model.from_pretrained(
            self.MODEL_ID,
            torch_dtype=self._model_dtype(),
            token=hf_token,
        ).to(self.device)
        self.model.eval()

        # Video model/processor are only loaded by the first video request.
        self._video_model = None
        self._video_processor = None

    # ------------------------------------------------------------------ #
    # Model loading helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _hf_token() -> Optional[str]:
        """Return the Hugging Face access token from the environment, or None."""
        return os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

    @staticmethod
    def _model_dtype() -> torch.dtype:
        """Return bfloat16 when CUDA is available, otherwise float32."""
        return torch.bfloat16 if torch.cuda.is_available() else torch.float32

    def _get_video_model(self):
        """Lazy load video model only when needed.

        Returns:
            ``(video_model, video_processor)``; both are cached on the instance.
        """
        if self._video_model is None:
            from transformers import Sam3VideoModel, Sam3VideoProcessor

            hf_token = self._hf_token()
            self._video_processor = Sam3VideoProcessor.from_pretrained(
                self.MODEL_ID,
                token=hf_token,
            )
            self._video_model = Sam3VideoModel.from_pretrained(
                self.MODEL_ID,
                torch_dtype=self._model_dtype(),
                token=hf_token,
            ).to(self.device)
            self._video_model.eval()

        return self._video_model, self._video_processor

    # ------------------------------------------------------------------ #
    # Input decoding helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _decode_base64(data: str) -> bytes:
        """Decode a plain base64 string or a ``data:<mime>;base64,<payload>`` URI."""
        if data.startswith('data:'):
            _, data = data.split(',', 1)
        return base64.b64decode(data)

    def _fetch_bytes(self, url: str) -> bytes:
        """Download ``url`` fully into memory.

        Uses ``response.content`` (not ``response.raw``) so that gzip/deflate
        transfer encodings are decoded before the bytes reach PIL.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        import requests

        # NOTE(review): the URL comes straight from the request body; if this
        # endpoint is publicly reachable consider restricting hosts (SSRF).
        response = requests.get(url, timeout=self.HTTP_TIMEOUT)
        response.raise_for_status()
        return response.content

    def _download_to(self, url: str, fileobj) -> None:
        """Stream ``url`` into the binary file object ``fileobj`` in 8 KiB chunks.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        import requests

        # NOTE(review): same SSRF consideration as _fetch_bytes applies here.
        with requests.get(url, stream=True, timeout=self.HTTP_TIMEOUT) as response:
            response.raise_for_status()
            # iter_content decodes any Content-Encoding transparently.
            for chunk in response.iter_content(chunk_size=8192):
                fileobj.write(chunk)

    def _load_image(self, image_data: Any):
        """Load an RGB PIL image from a PIL image, URL, data URI, base64 string or raw bytes.

        Raises:
            ValueError: if ``image_data`` is of an unsupported type.
        """
        from PIL import Image

        if isinstance(image_data, Image.Image):
            return image_data.convert('RGB')

        if isinstance(image_data, str):
            if image_data.startswith(('http://', 'https://')):
                raw = self._fetch_bytes(image_data)
            else:
                raw = self._decode_base64(image_data)
        elif isinstance(image_data, (bytes, bytearray)):
            raw = bytes(image_data)
        else:
            raise ValueError(f"Unsupported image input type: {type(image_data)}")

        return Image.open(io.BytesIO(raw)).convert('RGB')

    @staticmethod
    def _sample_frame_indices(total_frames: int, video_fps: float, max_frames: int, fps: float) -> List[int]:
        """Pick evenly spaced frame indices to decode from a video.

        Aims for ``fps`` samples per second of video, capped at ``max_frames`` and
        at ``total_frames``. If the video's own fps is unknown (<= 0), it falls
        back to ``min(max_frames, total_frames)`` samples. Returns an empty list
        when nothing can be sampled (OpenCV may report 0 or negative frame counts).
        """
        duration = total_frames / video_fps if video_fps > 0 else 0
        target_frames = min(max_frames, int(duration * fps), total_frames)
        if target_frames <= 0:
            target_frames = min(max_frames, total_frames)
        if target_frames <= 0:
            return []
        return np.linspace(0, total_frames - 1, target_frames, dtype=int).tolist()

    def _load_video_frames(self, video_data: Any, max_frames: int = 100, fps: float = 2.0) -> tuple[List[Any], Dict[str, Any]]:
        """Decode up to ``max_frames`` RGB frames, sampled at roughly ``fps`` per second.

        Args:
            video_data: URL, data URI, base64 string, or raw video bytes.
            max_frames: Upper bound on the number of frames returned.
            fps: Desired sampling rate in frames per second of video.

        Returns:
            ``(frames, metadata)``. ``frames`` is a list of PIL images. The
            ``metadata`` dict holds duration, total/sampled frame counts, the
            video's fps, and ``frame_indices`` (the source-video index of each
            returned frame).

        Raises:
            ValueError: if ``video_data`` is of an unsupported type.
        """
        if not isinstance(video_data, (str, bytes, bytearray)):
            raise ValueError(f"Unsupported video input type: {type(video_data)}")

        import cv2
        from PIL import Image
        import tempfile

        # OpenCV decodes from a path, so the payload is spooled to a temp file.
        # The file is created before the try so the finally clause always
        # removes it, even if the download or base64 decoding fails midway.
        fd, video_path = tempfile.mkstemp(suffix='.mp4')
        cap = None
        try:
            with os.fdopen(fd, 'wb') as f:
                if isinstance(video_data, str) and video_data.startswith(('http://', 'https://')):
                    self._download_to(video_data, f)
                elif isinstance(video_data, str):
                    f.write(self._decode_base64(video_data))
                else:
                    f.write(video_data)

            cap = cv2.VideoCapture(video_path)
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            duration = total_frames / video_fps if video_fps > 0 else 0

            frames = []
            frame_indices = []
            for idx in self._sample_frame_indices(total_frames, video_fps, max_frames, fps):
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ok, frame = cap.read()
                if ok:
                    # OpenCV yields BGR; PIL / the model expect RGB.
                    frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                    frame_indices.append(idx)

            metadata = {
                "duration": duration,
                "total_frames": total_frames,
                "sampled_frames": len(frames),
                "video_fps": video_fps,
                "frame_indices": frame_indices,
            }
            return frames, metadata
        finally:
            if cap is not None:
                cap.release()
            if os.path.exists(video_path):
                os.unlink(video_path)

    # ------------------------------------------------------------------ #
    # Output formatting helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _to_list(value: Any) -> Any:
        """Convert tensors/arrays to plain lists for JSON; pass other values through."""
        return value.tolist() if hasattr(value, "tolist") else value

    def _masks_to_serializable(self, masks: torch.Tensor) -> List[str]:
        """Encode a stack of binary masks as base64 PNG strings (one per mask).

        Nonzero pixels become 255 and zero pixels become 0 in a single-channel
        PNG. The cast to uint8 happens in torch so that bool, float32 and
        bfloat16 masks are all handled (numpy has no bfloat16).
        """
        from PIL import Image

        encoded_masks = []
        for mask in masks.detach().to(torch.uint8).cpu().numpy():
            buffer = io.BytesIO()
            Image.fromarray(mask * 255).save(buffer, format='PNG')
            encoded_masks.append(base64.b64encode(buffer.getvalue()).decode('utf-8'))
        return encoded_masks

    def _format_instances(self, post_results: Dict[str, Any], return_masks: bool) -> List[Dict[str, Any]]:
        """Turn one post-processed result dict into JSON-ready instance dicts.

        Each instance has ``box`` and ``score``, plus ``mask`` (base64 PNG)
        when ``return_masks`` is true and the results contain masks.
        """
        boxes = post_results.get("boxes", [])
        scores = post_results.get("scores", [])
        masks = None
        if return_masks and "masks" in post_results:
            masks = self._masks_to_serializable(post_results["masks"])

        instances = []
        for i in range(len(boxes)):
            instance = {
                "box": boxes[i].tolist(),
                "score": float(scores[i]),
            }
            if masks is not None:
                instance["mask"] = masks[i]
            instances.append(instance)
        return instances

    def _segment(self, image: Any, text: str, threshold: float, mask_threshold: float,
                 boxes: Optional[List[List[float]]] = None,
                 box_labels: Optional[List[int]] = None) -> Dict[str, Any]:
        """Run SAM 3 on one image with one text prompt (plus optional box prompts).

        When ``boxes`` is given without ``box_labels``, every box is treated as
        positive (label 1). Returns the first (only) entry of
        ``post_process_instance_segmentation``.
        """
        extra = {}
        if boxes is not None:
            extra["input_boxes"] = [boxes]
            extra["input_boxes_labels"] = [box_labels] if box_labels else [[1] * len(boxes)]

        # NOTE(review): inputs are moved to the device but not cast to the model
        # dtype (bfloat16 on GPU); confirm Sam3Model casts pixel_values itself.
        processor_inputs = self.processor(
            images=image,
            text=text,
            return_tensors="pt",
            **extra,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**processor_inputs)

        return self.processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=processor_inputs.get("original_sizes").tolist(),
        )[0]

    # ------------------------------------------------------------------ #
    # Request entry point and per-mode handlers
    # ------------------------------------------------------------------ #

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process image or video with SAM 3 for text-prompted segmentation.

        INPUT FORMATS:

        1. Single image with text prompt (find all instances):
        {
            "inputs": <image_url_or_base64>,
            "parameters": {
                "prompt": "Save button",
                "threshold": 0.5,
                "mask_threshold": 0.5,
                "return_masks": true        // default true for single images
            }
        }

        2. Single image with multiple text prompts:
        {
            "inputs": <image_url_or_base64>,
            "parameters": {
                "prompts": ["button", "text field", "dropdown"],
                "threshold": 0.5
            }
        }

        3. Single image with box prompts (positive/negative):
        {
            "inputs": <image_url_or_base64>,
            "parameters": {
                "prompt": "handle",
                "boxes": [[40, 183, 318, 204]],
                "box_labels": [0],  // 0=negative, 1=positive (default: all positive)
                "threshold": 0.5
            }
        }

        4. Video with text prompt (track all instances):
        {
            "inputs": <video_url_or_base64>,
            "parameters": {
                "mode": "video",
                "prompt": "Submit button",
                "max_frames": 100,
                "fps": 2.0
            }
        }

        5. Batch images:
        {
            "inputs": [<image1>, <image2>, ...],
            "parameters": {
                "prompts": ["ear", "dial"],  // One per image (or a single "prompt" for all)
                "threshold": 0.5,
                "return_masks": false        // default false for batches
            }
        }

        6. ProofPath UI element detection:
        {
            "inputs": <screenshot_base64>,
            "parameters": {
                "mode": "ui_elements",
                "elements": ["Save button", "Cancel button", "text input"],
                "threshold": 0.5
            }
        }

        OUTPUT FORMAT (image modes 1-3):
        {
            "results": [
                {
                    "prompt": "Save button",
                    "instances": [
                        {
                            "box": [x1, y1, x2, y2],
                            "score": 0.95,
                            "mask": "<base64_png>"  // if return_masks=true
                        }
                    ],
                    "count": 1
                }
            ],
            "image_size": {"width": 1920, "height": 1080}
        }

        A missing or null "parameters" is treated as an empty dict. Any mode
        other than "video" / "ui_elements" falls through to batch (list
        inputs) or single-image processing.

        Raises:
            ValueError: if "inputs" is missing or the parameters are invalid.
        """
        inputs = data.get("inputs")
        # `or {}` also covers an explicit JSON null for "parameters".
        params = data.get("parameters") or {}

        if inputs is None:
            raise ValueError("No inputs provided")

        mode = params.get("mode", "image")

        if mode == "video":
            return self._process_video(inputs, params)
        if mode == "ui_elements":
            return self._process_ui_elements(inputs, params)
        if isinstance(inputs, list):
            return self._process_batch(inputs, params)
        return self._process_single_image(inputs, params)

    def _process_single_image(self, image_data: Any, params: Dict) -> Dict[str, Any]:
        """Segment one image for each text prompt, optionally guided by box prompts.

        Raises:
            ValueError: if neither "prompt" nor "prompts" is provided.
        """
        image = self._load_image(image_data)

        threshold = params.get("threshold", 0.5)
        mask_threshold = params.get("mask_threshold", 0.5)
        return_masks = params.get("return_masks", True)

        prompt = params.get("prompt")
        prompts = params.get("prompts", [prompt] if prompt else [])
        if isinstance(prompts, str):
            # A bare string is one prompt, not a sequence of one-letter prompts.
            prompts = [prompts]
        if not prompts:
            raise ValueError("No text prompt(s) provided")

        boxes = params.get("boxes")
        box_labels = params.get("box_labels")

        results = []
        for text_prompt in prompts:
            post_results = self._segment(
                image, text_prompt, threshold, mask_threshold,
                boxes=boxes, box_labels=box_labels,
            )
            instances = self._format_instances(post_results, return_masks)
            results.append({
                "prompt": text_prompt,
                "instances": instances,
                "count": len(instances),
            })

        return {
            "results": results,
            "image_size": {"width": image.width, "height": image.height},
        }

    def _process_batch(self, images_data: List, params: Dict) -> Dict[str, Any]:
        """Segment several images in one forward pass, one text prompt per image.

        A single "prompt" (or a string "prompts") is applied to every image.

        Raises:
            ValueError: if no images are given or prompt/image counts differ.
        """
        images = [self._load_image(img) for img in images_data]
        if not images:
            raise ValueError("No images provided")

        prompts = params.get("prompts") or []
        if isinstance(prompts, str):
            prompts = [prompts] * len(images)
        prompt = params.get("prompt")
        if prompt and not prompts:
            prompts = [prompt] * len(images)

        if len(prompts) != len(images):
            raise ValueError(f"Number of prompts ({len(prompts)}) must match number of images ({len(images)})")

        threshold = params.get("threshold", 0.5)
        mask_threshold = params.get("mask_threshold", 0.5)
        # Masks are opt-in for batches to keep response payloads small.
        return_masks = params.get("return_masks", False)

        processor_inputs = self.processor(
            images=images,
            text=prompts,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**processor_inputs)

        all_results = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=processor_inputs.get("original_sizes").tolist(),
        )

        results = []
        for idx, (post_results, text_prompt, image) in enumerate(zip(all_results, prompts, images)):
            instances = self._format_instances(post_results, return_masks)
            results.append({
                "image_index": idx,
                "prompt": text_prompt,
                "instances": instances,
                "count": len(instances),
                "image_size": {"width": image.width, "height": image.height},
            })

        return {"results": results}

    def _process_ui_elements(self, image_data: Any, params: Dict) -> Dict[str, Any]:
        """
        ProofPath-specific mode: Detect multiple UI element types in a screenshot.

        Each requested element type is segmented separately. The result holds
        its boxes, scores and box centers (no masks). Falls back to
        DEFAULT_UI_ELEMENTS when "elements" is missing or empty.
        """
        image = self._load_image(image_data)

        elements = params.get("elements") or list(self.DEFAULT_UI_ELEMENTS)
        if isinstance(elements, str):
            elements = [elements]

        threshold = params.get("threshold", 0.5)
        mask_threshold = params.get("mask_threshold", 0.5)

        all_detections = {}
        for element_type in elements:
            post_results = self._segment(image, element_type, threshold, mask_threshold)

            detections = []
            for instance in self._format_instances(post_results, return_masks=False):
                x1, y1, x2, y2 = instance["box"]
                instance["center"] = [(x1 + x2) / 2, (y1 + y2) / 2]
                detections.append(instance)

            all_detections[element_type] = {
                "count": len(detections),
                "instances": detections,
            }

        return {
            "ui_elements": all_detections,
            "image_size": {"width": image.width, "height": image.height},
            "total_elements": sum(d["count"] for d in all_detections.values()),
        }

    def _process_video(self, video_data: Any, params: Dict) -> Dict[str, Any]:
        """
        Process video with SAM3 Video for text-prompted tracking.

        Tracks all instances of the prompted concept across the sampled frames.
        ``frame_idx`` values index the sampled frames.
        ``video_metadata["frame_indices"]`` maps each one back to its frame
        number in the source video.

        Raises:
            ValueError: if no prompt is given or no frames could be decoded.
        """
        prompt = params.get("prompt")
        if not prompt:
            # Validate before loading the (large) video model.
            raise ValueError("Text prompt required for video mode")

        video_model, video_processor = self._get_video_model()

        max_frames = params.get("max_frames", 100)
        fps = params.get("fps", 2.0)

        frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)
        if not frames:
            raise ValueError("No frames could be extracted from video")

        inference_session = video_processor.init_video_session(
            video=frames,
            inference_device=self.device,
            processing_device="cpu",
            video_storage_device="cpu",
            dtype=self._model_dtype(),
        )
        inference_session = video_processor.add_text_prompt(
            inference_session=inference_session,
            text=prompt,
        )

        outputs_per_frame = {}
        # no_grad keeps autograd from retaining activations across frames.
        with torch.no_grad():
            for model_outputs in video_model.propagate_in_video_iterator(
                inference_session=inference_session,
                max_frame_num_to_track=max_frames,
            ):
                processed = video_processor.postprocess_outputs(inference_session, model_outputs)
                frame_idx = model_outputs.frame_idx
                outputs_per_frame[frame_idx] = {
                    "frame_idx": frame_idx,
                    "object_ids": self._to_list(processed["object_ids"]),
                    "scores": self._to_list(processed["scores"]),
                    "boxes": self._to_list(processed["boxes"]),
                }

        # Regroup per-frame detections into one track per object id.
        object_tracks: Dict[str, Dict[str, Any]] = {}
        for frame_idx, frame_data in outputs_per_frame.items():
            boxes = frame_data["boxes"]
            scores = frame_data["scores"]
            for i, obj_id in enumerate(frame_data["object_ids"]):
                track = object_tracks.setdefault(str(obj_id), {"object_id": obj_id, "frames": []})
                track["frames"].append({
                    "frame_idx": frame_idx,
                    "box": boxes[i] if i < len(boxes) else None,
                    "score": scores[i] if i < len(scores) else None,
                })

        return {
            "prompt": prompt,
            "video_metadata": video_metadata,
            "frames_processed": len(outputs_per_frame),
            "objects_tracked": len(object_tracks),
            "tracks": list(object_tracks.values()),
            "per_frame_detections": outputs_per_frame,
        }
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Manual smoke test: needs network access, HF_TOKEN, and the accepted
    # facebook/sam3 license. Looks for every "ear" in a COCO validation image.
    demo_request = {
        "inputs": "http://images.cocodataset.org/val2017/000000077595.jpg",
        "parameters": {
            "prompt": "ear",
            "threshold": 0.5,
            "return_masks": False,
        },
    }

    endpoint = EndpointHandler()
    first_result = endpoint(demo_request)["results"][0]

    print(f"Found {first_result['count']} instances of '{first_result['prompt']}'")
    for detection in first_result["instances"]:
        print(f" Box: {detection['box']}, Score: {detection['score']:.3f}")
| |
|