""" SAM 3 Custom Inference Handler for Hugging Face Inference Endpoints Model: facebook/sam3 For ProofPath video assessment - text-prompted segmentation to find UI elements. Supports text prompts like "Save button", "dropdown menu", "text input field". KEY CAPABILITIES: - Text-to-segment: Find ALL instances of a concept (e.g., "button" → all buttons) - Promptable Concept Segmentation (PCS): 270K unique concepts - Video tracking: Consistent object IDs across frames - Presence token: Discriminates similar elements ("player in white" vs "player in red") REQUIREMENTS: 1. Set HF_TOKEN environment variable (model is gated) 2. Accept license at https://huggingface.co/facebook/sam3 """ from typing import Dict, List, Any, Optional, Union import torch import numpy as np import base64 import io import os class EndpointHandler: def __init__(self, path: str = ""): """ Initialize SAM 3 model for text-prompted segmentation. Args: path: Path to the model directory (ignored - we load from HF hub) """ model_id = "facebook/sam3" # Get HF token for gated model access hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Import SAM3 components from transformers from transformers import Sam3Processor, Sam3Model self.processor = Sam3Processor.from_pretrained( model_id, token=hf_token, ) self.model = Sam3Model.from_pretrained( model_id, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, token=hf_token, ).to(self.device) self.model.eval() # Also load video model for video segmentation self._video_model = None self._video_processor = None def _get_video_model(self): """Lazy load video model only when needed.""" if self._video_model is None: from transformers import Sam3VideoModel, Sam3VideoProcessor model_id = "facebook/sam3" hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") self._video_processor = Sam3VideoProcessor.from_pretrained( model_id, token=hf_token, ) self._video_model = Sam3VideoModel.from_pretrained( model_id, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, token=hf_token, ).to(self.device) self._video_model.eval() return self._video_model, self._video_processor def _load_image(self, image_data: Any): """Load image from various formats.""" from PIL import Image import requests if isinstance(image_data, Image.Image): return image_data.convert('RGB') elif isinstance(image_data, str): if image_data.startswith(('http://', 'https://')): response = requests.get(image_data, stream=True) return Image.open(response.raw).convert('RGB') elif image_data.startswith('data:'): header, encoded = image_data.split(',', 1) image_bytes = base64.b64decode(encoded) return Image.open(io.BytesIO(image_bytes)).convert('RGB') else: # Assume base64 encoded image_bytes = base64.b64decode(image_data) return Image.open(io.BytesIO(image_bytes)).convert('RGB') elif isinstance(image_data, bytes): return Image.open(io.BytesIO(image_data)).convert('RGB') else: raise ValueError(f"Unsupported image input type: {type(image_data)}") def _load_video_frames(self, video_data: Any, max_frames: int = 100, fps: float = 2.0) -> List: """Load video frames from various formats.""" import cv2 from PIL import Image import tempfile # Decode to temp file if needed if isinstance(video_data, str): if video_data.startswith(('http://', 'https://')): import requests response = requests.get(video_data, stream=True) with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) video_path = f.name elif video_data.startswith('data:'): header, encoded = video_data.split(',', 1) video_bytes = base64.b64decode(encoded) with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: f.write(video_bytes) video_path = f.name else: video_bytes = base64.b64decode(video_data) with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: f.write(video_bytes) video_path = f.name elif isinstance(video_data, bytes): with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: f.write(video_data) video_path = f.name else: raise ValueError(f"Unsupported video input type: {type(video_data)}") try: cap = cv2.VideoCapture(video_path) video_fps = cap.get(cv2.CAP_PROP_FPS) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) duration = total_frames / video_fps if video_fps > 0 else 0 # Calculate frames to sample target_frames = min(max_frames, int(duration * fps), total_frames) if target_frames <= 0: target_frames = min(max_frames, total_frames) frame_indices = np.linspace(0, total_frames - 1, target_frames, dtype=int) frames = [] for idx in frame_indices: cap.set(cv2.CAP_PROP_POS_FRAMES, idx) ret, frame = cap.read() if ret: frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(frame_rgb) frames.append(pil_image) cap.release() metadata = { "duration": duration, "total_frames": total_frames, "sampled_frames": len(frames), "video_fps": video_fps } return frames, metadata finally: if os.path.exists(video_path): os.unlink(video_path) def _masks_to_serializable(self, masks: torch.Tensor) -> List[List[List[int]]]: """Convert binary masks to RLE or simplified format for JSON serialization.""" # For efficiency, we'll return bounding box info and optionally compressed masks # Full masks can be very large - return as base64 encoded numpy if needed masks_np = masks.cpu().numpy().astype(np.uint8) # Return as list of base64-encoded masks encoded_masks = [] for mask in masks_np: # Encode each mask as PNG for compression from PIL import Image img = Image.fromarray(mask * 255) buffer = io.BytesIO() img.save(buffer, format='PNG') encoded = base64.b64encode(buffer.getvalue()).decode('utf-8') encoded_masks.append(encoded) return encoded_masks def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: """ Process image or video with SAM 3 for text-prompted segmentation. INPUT FORMATS: 1. Single image with text prompt (find all instances): { "inputs": , "parameters": { "prompt": "Save button", "threshold": 0.5, "mask_threshold": 0.5, "return_masks": true } } 2. Single image with multiple text prompts: { "inputs": , "parameters": { "prompts": ["button", "text field", "dropdown"], "threshold": 0.5 } } 3. Single image with box prompts (positive/negative): { "inputs": , "parameters": { "prompt": "handle", "boxes": [[40, 183, 318, 204]], "box_labels": [0], // 0=negative, 1=positive "threshold": 0.5 } } 4. Video with text prompt (track all instances): { "inputs": , "parameters": { "mode": "video", "prompt": "Submit button", "max_frames": 100, "fps": 2.0 } } 5. Batch images: { "inputs": [, , ...], "parameters": { "prompts": ["ear", "dial"], // One per image "threshold": 0.5 } } 6. ProofPath UI element detection: { "inputs": , "parameters": { "mode": "ui_elements", "elements": ["Save button", "Cancel button", "text input"], "threshold": 0.5 } } OUTPUT FORMAT: { "results": [ { "prompt": "Save button", "instances": [ { "box": [x1, y1, x2, y2], "score": 0.95, "mask": "" // if return_masks=true } ] } ], "image_size": {"width": 1920, "height": 1080} } """ inputs = data.get("inputs") params = data.get("parameters", {}) if inputs is None: raise ValueError("No inputs provided") mode = params.get("mode", "image") if mode == "video": return self._process_video(inputs, params) elif mode == "ui_elements": return self._process_ui_elements(inputs, params) elif isinstance(inputs, list): return self._process_batch(inputs, params) else: return self._process_single_image(inputs, params) def _process_single_image(self, image_data: Any, params: Dict) -> Dict[str, Any]: """Process a single image with text and/or box prompts.""" image = self._load_image(image_data) threshold = params.get("threshold", 0.5) mask_threshold = params.get("mask_threshold", 0.5) return_masks = params.get("return_masks", True) # Get prompts prompt = params.get("prompt") prompts = params.get("prompts", [prompt] if prompt else []) if not prompts: raise ValueError("No text prompt(s) provided") # Get optional box prompts boxes = params.get("boxes") box_labels = params.get("box_labels") results = [] for text_prompt in prompts: # Prepare inputs if boxes is not None: input_boxes = [boxes] input_boxes_labels = [box_labels] if box_labels else [[1] * len(boxes)] processor_inputs = self.processor( images=image, text=text_prompt, input_boxes=input_boxes, input_boxes_labels=input_boxes_labels, return_tensors="pt" ).to(self.device) else: processor_inputs = self.processor( images=image, text=text_prompt, return_tensors="pt" ).to(self.device) # Run inference with torch.no_grad(): outputs = self.model(**processor_inputs) # Post-process post_results = self.processor.post_process_instance_segmentation( outputs, threshold=threshold, mask_threshold=mask_threshold, target_sizes=processor_inputs.get("original_sizes").tolist() )[0] instances = [] for i in range(len(post_results.get("boxes", []))): instance = { "box": post_results["boxes"][i].tolist(), "score": float(post_results["scores"][i]) } if return_masks and "masks" in post_results: # Encode mask as base64 PNG mask = post_results["masks"][i].cpu().numpy().astype(np.uint8) * 255 from PIL import Image as PILImage mask_img = PILImage.fromarray(mask) buffer = io.BytesIO() mask_img.save(buffer, format='PNG') instance["mask"] = base64.b64encode(buffer.getvalue()).decode('utf-8') instances.append(instance) results.append({ "prompt": text_prompt, "instances": instances, "count": len(instances) }) return { "results": results, "image_size": {"width": image.width, "height": image.height} } def _process_batch(self, images_data: List, params: Dict) -> Dict[str, Any]: """Process multiple images with text prompts.""" images = [self._load_image(img) for img in images_data] prompts = params.get("prompts", []) prompt = params.get("prompt") # Handle single prompt for all images if prompt and not prompts: prompts = [prompt] * len(images) if len(prompts) != len(images): raise ValueError(f"Number of prompts ({len(prompts)}) must match number of images ({len(images)})") threshold = params.get("threshold", 0.5) mask_threshold = params.get("mask_threshold", 0.5) return_masks = params.get("return_masks", False) # Default false for batch # Process batch processor_inputs = self.processor( images=images, text=prompts, return_tensors="pt" ).to(self.device) with torch.no_grad(): outputs = self.model(**processor_inputs) # Post-process all results all_results = self.processor.post_process_instance_segmentation( outputs, threshold=threshold, mask_threshold=mask_threshold, target_sizes=processor_inputs.get("original_sizes").tolist() ) results = [] for idx, (post_results, text_prompt, image) in enumerate(zip(all_results, prompts, images)): instances = [] for i in range(len(post_results.get("boxes", []))): instance = { "box": post_results["boxes"][i].tolist(), "score": float(post_results["scores"][i]) } if return_masks and "masks" in post_results: mask = post_results["masks"][i].cpu().numpy().astype(np.uint8) * 255 from PIL import Image as PILImage mask_img = PILImage.fromarray(mask) buffer = io.BytesIO() mask_img.save(buffer, format='PNG') instance["mask"] = base64.b64encode(buffer.getvalue()).decode('utf-8') instances.append(instance) results.append({ "image_index": idx, "prompt": text_prompt, "instances": instances, "count": len(instances), "image_size": {"width": image.width, "height": image.height} }) return {"results": results} def _process_ui_elements(self, image_data: Any, params: Dict) -> Dict[str, Any]: """ ProofPath-specific mode: Detect multiple UI element types in a screenshot. Returns structured data for each element type with bounding boxes. """ image = self._load_image(image_data) elements = params.get("elements", []) if not elements: # Default UI elements to look for elements = ["button", "text input", "dropdown", "checkbox", "link"] threshold = params.get("threshold", 0.5) mask_threshold = params.get("mask_threshold", 0.5) all_detections = {} for element_type in elements: processor_inputs = self.processor( images=image, text=element_type, return_tensors="pt" ).to(self.device) with torch.no_grad(): outputs = self.model(**processor_inputs) post_results = self.processor.post_process_instance_segmentation( outputs, threshold=threshold, mask_threshold=mask_threshold, target_sizes=processor_inputs.get("original_sizes").tolist() )[0] detections = [] for i in range(len(post_results.get("boxes", []))): box = post_results["boxes"][i].tolist() detections.append({ "box": box, "score": float(post_results["scores"][i]), "center": [ (box[0] + box[2]) / 2, (box[1] + box[3]) / 2 ] }) all_detections[element_type] = { "count": len(detections), "instances": detections } return { "ui_elements": all_detections, "image_size": {"width": image.width, "height": image.height}, "total_elements": sum(d["count"] for d in all_detections.values()) } def _process_video(self, video_data: Any, params: Dict) -> Dict[str, Any]: """ Process video with SAM3 Video for text-prompted tracking. Tracks all instances of the prompted concept across frames. """ video_model, video_processor = self._get_video_model() prompt = params.get("prompt") if not prompt: raise ValueError("Text prompt required for video mode") max_frames = params.get("max_frames", 100) fps = params.get("fps", 2.0) # Load video frames frames, video_metadata = self._load_video_frames(video_data, max_frames, fps) if not frames: raise ValueError("No frames could be extracted from video") # Initialize video session inference_session = video_processor.init_video_session( video=frames, inference_device=self.device, processing_device="cpu", video_storage_device="cpu", dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, ) # Add text prompt inference_session = video_processor.add_text_prompt( inference_session=inference_session, text=prompt, ) # Process all frames outputs_per_frame = {} for model_outputs in video_model.propagate_in_video_iterator( inference_session=inference_session, max_frame_num_to_track=max_frames ): processed = video_processor.postprocess_outputs(inference_session, model_outputs) frame_data = { "frame_idx": model_outputs.frame_idx, "object_ids": processed["object_ids"].tolist() if hasattr(processed["object_ids"], "tolist") else processed["object_ids"], "scores": processed["scores"].tolist() if hasattr(processed["scores"], "tolist") else processed["scores"], "boxes": processed["boxes"].tolist() if hasattr(processed["boxes"], "tolist") else processed["boxes"], } outputs_per_frame[model_outputs.frame_idx] = frame_data # Compile tracking results # Group by object_id to show trajectory object_tracks = {} for frame_idx, frame_data in outputs_per_frame.items(): for i, obj_id in enumerate(frame_data["object_ids"]): obj_id_str = str(obj_id) if obj_id_str not in object_tracks: object_tracks[obj_id_str] = { "object_id": obj_id, "frames": [] } object_tracks[obj_id_str]["frames"].append({ "frame_idx": frame_idx, "box": frame_data["boxes"][i] if i < len(frame_data["boxes"]) else None, "score": frame_data["scores"][i] if i < len(frame_data["scores"]) else None }) return { "prompt": prompt, "video_metadata": video_metadata, "frames_processed": len(outputs_per_frame), "objects_tracked": len(object_tracks), "tracks": list(object_tracks.values()), "per_frame_detections": outputs_per_frame } # For testing locally if __name__ == "__main__": handler = EndpointHandler() # Test with a sample image URL test_data = { "inputs": "http://images.cocodataset.org/val2017/000000077595.jpg", "parameters": { "prompt": "ear", "threshold": 0.5, "return_masks": False } } result = handler(test_data) print(f"Found {result['results'][0]['count']} instances of '{result['results'][0]['prompt']}'") for inst in result['results'][0]['instances']: print(f" Box: {inst['box']}, Score: {inst['score']:.3f}")