# sam3 / handler.py
# Uploaded by peterproofpath ("Upload 2 files"), commit 0bd64d1.
"""
SAM 3 Custom Inference Handler for Hugging Face Inference Endpoints
Model: facebook/sam3
For ProofPath video assessment - text-prompted segmentation to find UI elements.
Supports text prompts like "Save button", "dropdown menu", "text input field".
KEY CAPABILITIES:
- Text-to-segment: Find ALL instances of a concept (e.g., "button" → all buttons)
- Promptable Concept Segmentation (PCS): 270K unique concepts
- Video tracking: Consistent object IDs across frames
- Presence token: Discriminates similar elements ("player in white" vs "player in red")
REQUIREMENTS:
1. Set HF_TOKEN environment variable (model is gated)
2. Accept license at https://huggingface.co/facebook/sam3
"""
import base64
import io
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for SAM 3 text-prompted segmentation.

    Supports single images, image batches, a ProofPath "ui_elements" mode and
    video tracking; see ``__call__`` for the accepted request formats.
    """

    # Hub repository used for both the image and the video model (gated).
    MODEL_ID = "facebook/sam3"
    # Seconds to wait on the network when fetching image/video URLs. Without a
    # timeout a stalled server would block the worker indefinitely.
    HTTP_TIMEOUT_S = 60
    # Element types searched for in "ui_elements" mode when none are supplied.
    DEFAULT_UI_ELEMENTS = ("button", "text input", "dropdown", "checkbox", "link")

    def __init__(self, path: str = ""):
        """
        Initialize SAM 3 model for text-prompted segmentation.

        Args:
            path: Path to the model directory (ignored - we load from HF hub)
        """
        # Imported here so this module can be imported without transformers.
        from transformers import Sam3Model, Sam3Processor

        hf_token = self._hf_token()
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        # bfloat16 on GPU, float32 on CPU; reused for the video model/session.
        self.dtype = torch.bfloat16 if use_cuda else torch.float32

        self.processor = Sam3Processor.from_pretrained(self.MODEL_ID, token=hf_token)
        self.model = Sam3Model.from_pretrained(
            self.MODEL_ID,
            torch_dtype=self.dtype,
            token=hf_token,
        ).to(self.device)
        self.model.eval()

        # The video model is loaded lazily by _get_video_model().
        self._video_model = None
        self._video_processor = None

    @staticmethod
    def _hf_token() -> Optional[str]:
        """Return the Hub token from HF_TOKEN, else HUGGING_FACE_HUB_TOKEN (None if neither is set)."""
        return os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

    def _get_video_model(self):
        """Return ``(video_model, video_processor)``, loading them on first use."""
        if self._video_model is None:
            from transformers import Sam3VideoModel, Sam3VideoProcessor

            hf_token = self._hf_token()
            self._video_processor = Sam3VideoProcessor.from_pretrained(
                self.MODEL_ID,
                token=hf_token,
            )
            self._video_model = Sam3VideoModel.from_pretrained(
                self.MODEL_ID,
                torch_dtype=self.dtype,
                token=hf_token,
            ).to(self.device)
            self._video_model.eval()
        return self._video_model, self._video_processor

    @staticmethod
    def _decode_base64_payload(payload: str) -> bytes:
        """Decode a plain base64 string or a ``data:<mime>;base64,<data>`` URL to bytes."""
        if payload.startswith('data:'):
            # Everything before the first comma is the data-URL header.
            _, payload = payload.split(',', 1)
        return base64.b64decode(payload)

    def _load_image(self, image_data: Any):
        """Load an RGB PIL image.

        Args:
            image_data: A PIL image, an ``http(s)`` URL, a ``data:`` URL, a plain
                base64 string, or raw encoded image bytes.

        Raises:
            ValueError: For any other input type.
            requests.HTTPError: If a URL fetch returns an error status.
        """
        from PIL import Image

        if isinstance(image_data, Image.Image):
            return image_data.convert('RGB')
        if isinstance(image_data, str):
            if image_data.startswith(('http://', 'https://')):
                import requests

                response = requests.get(image_data, timeout=self.HTTP_TIMEOUT_S)
                # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
                response.raise_for_status()
                # .content is decoded per Content-Encoding; response.raw is not.
                image_bytes = response.content
            else:
                image_bytes = self._decode_base64_payload(image_data)
        elif isinstance(image_data, bytes):
            image_bytes = image_data
        else:
            raise ValueError(f"Unsupported image input type: {type(image_data)}")
        return Image.open(io.BytesIO(image_bytes)).convert('RGB')

    @staticmethod
    def _sample_frame_indices(total_frames: int, video_fps: float, max_frames: int, fps: float) -> List[int]:
        """Pick evenly spaced frame indices to decode.

        Aims for ``fps`` frames per second of video, capped at ``max_frames`` and
        ``total_frames``. When that yields nothing (unknown source FPS or a very
        short clip) it falls back to ``min(max_frames, total_frames)`` frames.
        Returns an empty list when there is nothing to sample.
        """
        # A negative frame count would make np.linspace raise; treat it as empty.
        total_frames = max(0, int(total_frames))
        max_frames = max(0, int(max_frames))
        duration = total_frames / video_fps if video_fps > 0 else 0
        target_frames = min(max_frames, int(duration * fps), total_frames)
        if target_frames <= 0:
            target_frames = min(max_frames, total_frames)
        if target_frames <= 0:
            return []
        return np.linspace(0, total_frames - 1, target_frames, dtype=int).tolist()

    def _read_sampled_frames(self, video_path: str, max_frames: int, fps: float) -> Tuple[List[Any], Dict[str, Any]]:
        """Decode sampled frames from ``video_path`` as RGB PIL images plus metadata."""
        import cv2
        from PIL import Image

        cap = cv2.VideoCapture(video_path)
        try:
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = max(0, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
            duration = total_frames / video_fps if video_fps > 0 else 0
            frames = []
            for idx in self._sample_frame_indices(total_frames, video_fps, max_frames, fps):
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ok, frame = cap.read()
                # Frames that fail to decode are skipped rather than aborting.
                if ok:
                    frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        finally:
            # Release the decoder even if reading a frame raised.
            cap.release()
        metadata = {
            "duration": duration,
            "total_frames": total_frames,
            "sampled_frames": len(frames),
            "video_fps": video_fps,
        }
        return frames, metadata

    def _load_video_frames(self, video_data: Any, max_frames: int = 100, fps: float = 2.0) -> Tuple[List[Any], Dict[str, Any]]:
        """Decode a video and sample frames from it.

        Args:
            video_data: ``http(s)`` URL, ``data:`` URL, plain base64 string or raw bytes.
            max_frames: Upper bound on the number of frames returned.
            fps: Target sampling rate, in frames per second of video time.

        Returns:
            ``(frames, metadata)``: a list of RGB PIL images and a dict with
            ``duration``, ``total_frames``, ``sampled_frames`` and ``video_fps``.

        Raises:
            ValueError: If ``video_data`` is neither ``str`` nor ``bytes``.
        """
        import tempfile

        if not isinstance(video_data, (str, bytes)):
            raise ValueError(f"Unsupported video input type: {type(video_data)}")

        # OpenCV decodes from a path, so the payload is spilled to a temp file.
        # The file is created up front so the finally-block always removes it,
        # even when the download or base64 decode fails part-way.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
            video_path = tmp.name
        try:
            if isinstance(video_data, str) and video_data.startswith(('http://', 'https://')):
                import requests

                with requests.get(video_data, stream=True, timeout=self.HTTP_TIMEOUT_S) as response:
                    response.raise_for_status()
                    with open(video_path, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
            else:
                payload = video_data if isinstance(video_data, bytes) else self._decode_base64_payload(video_data)
                with open(video_path, 'wb') as f:
                    f.write(payload)
            return self._read_sampled_frames(video_path, max_frames, fps)
        finally:
            if os.path.exists(video_path):
                os.unlink(video_path)

    @staticmethod
    def _encode_mask(mask: Any) -> str:
        """Encode one 0/1 mask (tensor or array) as a base64 PNG (1 -> 255)."""
        from PIL import Image

        if isinstance(mask, torch.Tensor):
            mask = mask.cpu().numpy()
        pixels = np.asarray(mask).astype(np.uint8) * 255
        buffer = io.BytesIO()
        Image.fromarray(pixels).save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

    def _masks_to_serializable(self, masks: torch.Tensor) -> List[str]:
        """Convert a stack of binary masks to a list of base64-encoded PNG strings."""
        masks_np = masks.cpu().numpy().astype(np.uint8)
        return [self._encode_mask(mask) for mask in masks_np]

    @staticmethod
    def _to_list(value: Any) -> Any:
        """Return ``value.tolist()`` when available (tensors/arrays), else ``value`` unchanged."""
        return value.tolist() if hasattr(value, "tolist") else value

    @staticmethod
    def _as_prompt_list(value: Any) -> List[str]:
        """Normalize a prompt parameter to a list; a bare string becomes one prompt, not characters."""
        if value is None:
            return []
        if isinstance(value, str):
            return [value] if value else []
        return list(value)

    def _prepare_inputs(self, processor_inputs: Any) -> Any:
        """Move processor outputs to the model device and match ``pixel_values`` to the model dtype.

        NOTE(review): ``.to(device)`` does not change dtype, so float pixel values
        would reach bfloat16 weights on GPU. Only ``pixel_values`` is cast, leaving
        token ids and box coordinates at full precision; this is a no-op on CPU
        and harmless if the model already casts internally.
        """
        processor_inputs = processor_inputs.to(self.device)
        pixel_values = processor_inputs.get("pixel_values")
        if isinstance(pixel_values, torch.Tensor) and pixel_values.is_floating_point():
            processor_inputs["pixel_values"] = pixel_values.to(self.model.dtype)
        return processor_inputs

    def _infer(self, images: Any, text: Any, threshold: float, mask_threshold: float,
               **prompt_kwargs: Any) -> List[Dict[str, Any]]:
        """Run the image model once; return one post-processed result dict per image.

        ``images``/``text`` go straight to the processor (a single image and
        string, or equal-length lists). ``prompt_kwargs`` carries optional box
        prompts (``input_boxes``, ``input_boxes_labels``).
        """
        processor_inputs = self._prepare_inputs(
            self.processor(images=images, text=text, return_tensors="pt", **prompt_kwargs)
        )
        with torch.no_grad():
            outputs = self.model(**processor_inputs)
        return self.processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=processor_inputs.get("original_sizes").tolist(),
        )

    def _format_instances(self, post_results: Dict[str, Any], return_masks: bool) -> List[Dict[str, Any]]:
        """Turn one post-processed result into ``[{"box", "score"[, "mask"]}, ...]``.

        Masks are included (as base64 PNG) only when ``return_masks`` is true and
        the result actually contains masks.
        """
        boxes = post_results.get("boxes", [])
        scores = post_results.get("scores", [])
        masks = post_results.get("masks") if return_masks else None
        instances = []
        for i, (box, score) in enumerate(zip(boxes, scores)):
            instance = {"box": self._to_list(box), "score": float(score)}
            if masks is not None:
                instance["mask"] = self._encode_mask(masks[i])
            instances.append(instance)
        return instances

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process image or video with SAM 3 for text-prompted segmentation.

        Dispatch: ``parameters.mode == "video"`` -> video tracking,
        ``"ui_elements"`` -> UI element detection, a list of inputs -> batch,
        anything else -> single image. ``parameters`` may be omitted or null.

        INPUT FORMATS:
        1. Single image with text prompt (find all instances):
        {
            "inputs": <image_url_or_base64>,
            "parameters": {
                "prompt": "Save button",
                "threshold": 0.5,
                "mask_threshold": 0.5,
                "return_masks": true
            }
        }
        2. Single image with multiple text prompts:
        {
            "inputs": <image_url_or_base64>,
            "parameters": {
                "prompts": ["button", "text field", "dropdown"],
                "threshold": 0.5
            }
        }
        3. Single image with box prompts (positive/negative):
        {
            "inputs": <image_url_or_base64>,
            "parameters": {
                "prompt": "handle",
                "boxes": [[40, 183, 318, 204]],
                "box_labels": [0],  // 0=negative, 1=positive (default: all positive)
                "threshold": 0.5
            }
        }
        4. Video with text prompt (track all instances):
        {
            "inputs": <video_url_or_base64>,
            "parameters": {
                "mode": "video",
                "prompt": "Submit button",
                "max_frames": 100,
                "fps": 2.0
            }
        }
        5. Batch images:
        {
            "inputs": [<image1>, <image2>, ...],
            "parameters": {
                "prompts": ["ear", "dial"],  // one per image, or a single "prompt" for all
                "threshold": 0.5
            }
        }
        6. ProofPath UI element detection:
        {
            "inputs": <screenshot_base64>,
            "parameters": {
                "mode": "ui_elements",
                "elements": ["Save button", "Cancel button", "text input"],
                "threshold": 0.5
            }
        }

        OUTPUT FORMAT (image modes 1-3):
        {
            "results": [
                {
                    "prompt": "Save button",
                    "instances": [
                        {
                            "box": [x1, y1, x2, y2],
                            "score": 0.95,
                            "mask": "<base64_png>"  // if return_masks=true
                        }
                    ],
                    "count": 1
                }
            ],
            "image_size": {"width": 1920, "height": 1080}
        }

        Raises:
            ValueError: If ``inputs`` is missing/null or the request is invalid.
        """
        inputs = data.get("inputs")
        # "parameters" may be absent or explicitly null in the JSON body.
        params = data.get("parameters") or {}
        if inputs is None:
            raise ValueError("No inputs provided")

        mode = params.get("mode", "image")
        if mode == "video":
            return self._process_video(inputs, params)
        if mode == "ui_elements":
            return self._process_ui_elements(inputs, params)
        if isinstance(inputs, list):
            return self._process_batch(inputs, params)
        return self._process_single_image(inputs, params)

    def _process_single_image(self, image_data: Any, params: Dict) -> Dict[str, Any]:
        """Segment one image for each text prompt, optionally guided by box prompts.

        Reads ``prompts`` (or ``prompt``), ``threshold`` (0.5), ``mask_threshold``
        (0.5), ``return_masks`` (True), ``boxes`` and ``box_labels`` from ``params``.

        Raises:
            ValueError: If no text prompt is given (checked before the image is fetched).
        """
        prompts = self._as_prompt_list(params.get("prompts"))
        if not prompts:
            prompts = self._as_prompt_list(params.get("prompt"))
        if not prompts:
            raise ValueError("No text prompt(s) provided")

        threshold = params.get("threshold", 0.5)
        mask_threshold = params.get("mask_threshold", 0.5)
        return_masks = params.get("return_masks", True)

        # Box prompts are identical for every text prompt, so build them once.
        box_kwargs: Dict[str, Any] = {}
        boxes = params.get("boxes")
        if boxes is not None:
            box_labels = params.get("box_labels")
            box_kwargs = {
                "input_boxes": [boxes],
                # Boxes without explicit labels are treated as positive (1).
                "input_boxes_labels": [box_labels] if box_labels else [[1] * len(boxes)],
            }

        image = self._load_image(image_data)
        results = []
        for text_prompt in prompts:
            post_results = self._infer(image, text_prompt, threshold, mask_threshold, **box_kwargs)[0]
            instances = self._format_instances(post_results, return_masks)
            results.append({
                "prompt": text_prompt,
                "instances": instances,
                "count": len(instances),
            })

        return {
            "results": results,
            "image_size": {"width": image.width, "height": image.height},
        }

    def _process_batch(self, images_data: List, params: Dict) -> Dict[str, Any]:
        """Segment several images in one forward pass, one text prompt per image.

        ``prompts`` must have one entry per image; alternatively a single
        ``prompt`` is applied to every image. Masks are off by default here.

        Raises:
            ValueError: If no images are given or prompt and image counts differ
                (both checked before any image is fetched).
        """
        if not images_data:
            raise ValueError("No images provided")

        prompts = self._as_prompt_list(params.get("prompts"))
        prompt = params.get("prompt")
        # Handle single prompt for all images
        if prompt and not prompts:
            prompts = [prompt] * len(images_data)
        if len(prompts) != len(images_data):
            raise ValueError(f"Number of prompts ({len(prompts)}) must match number of images ({len(images_data)})")

        threshold = params.get("threshold", 0.5)
        mask_threshold = params.get("mask_threshold", 0.5)
        return_masks = params.get("return_masks", False)  # Default false for batch

        images = [self._load_image(img) for img in images_data]
        all_results = self._infer(images, prompts, threshold, mask_threshold)

        results = []
        for idx, (post_results, text_prompt, image) in enumerate(zip(all_results, prompts, images)):
            instances = self._format_instances(post_results, return_masks)
            results.append({
                "image_index": idx,
                "prompt": text_prompt,
                "instances": instances,
                "count": len(instances),
                "image_size": {"width": image.width, "height": image.height},
            })
        return {"results": results}

    def _process_ui_elements(self, image_data: Any, params: Dict) -> Dict[str, Any]:
        """
        ProofPath-specific mode: Detect multiple UI element types in a screenshot.

        Runs one text-prompted pass per entry in ``elements`` (defaulting to
        DEFAULT_UI_ELEMENTS) and returns, per element type, the boxes, scores and
        box centers. Masks are never returned in this mode.
        """
        elements = self._as_prompt_list(params.get("elements")) or list(self.DEFAULT_UI_ELEMENTS)
        threshold = params.get("threshold", 0.5)
        mask_threshold = params.get("mask_threshold", 0.5)

        image = self._load_image(image_data)
        all_detections = {}
        for element_type in elements:
            post_results = self._infer(image, element_type, threshold, mask_threshold)[0]
            detections = self._format_instances(post_results, return_masks=False)
            for detection in detections:
                box = detection["box"]
                # Center point of the [x1, y1, x2, y2] box, e.g. for click targets.
                detection["center"] = [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]
            all_detections[element_type] = {
                "count": len(detections),
                "instances": detections,
            }

        return {
            "ui_elements": all_detections,
            "image_size": {"width": image.width, "height": image.height},
            "total_elements": sum(d["count"] for d in all_detections.values()),
        }

    @staticmethod
    def _group_tracks(outputs_per_frame: Dict[Any, Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Regroup per-frame detections into one trajectory per object id.

        Tracks appear in order of first sighting. A missing box/score for an id
        (parallel lists shorter than ``object_ids``) is recorded as None.
        """
        object_tracks: Dict[str, Dict[str, Any]] = {}
        for frame_idx, frame_data in outputs_per_frame.items():
            boxes = frame_data["boxes"]
            scores = frame_data["scores"]
            for i, obj_id in enumerate(frame_data["object_ids"]):
                track = object_tracks.setdefault(str(obj_id), {"object_id": obj_id, "frames": []})
                track["frames"].append({
                    "frame_idx": frame_idx,
                    "box": boxes[i] if i < len(boxes) else None,
                    "score": scores[i] if i < len(scores) else None,
                })
        return list(object_tracks.values())

    def _process_video(self, video_data: Any, params: Dict) -> Dict[str, Any]:
        """
        Process video with SAM3 Video for text-prompted tracking.

        Tracks all instances of the prompted concept across sampled frames and
        returns both per-object tracks and per-frame detections.

        Raises:
            ValueError: If no prompt is given (checked before the video model is
                loaded) or no frames could be decoded.
        """
        prompt = params.get("prompt")
        if not prompt:
            raise ValueError("Text prompt required for video mode")
        max_frames = int(params.get("max_frames", 100))
        fps = float(params.get("fps", 2.0))

        video_model, video_processor = self._get_video_model()

        frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)
        if not frames:
            raise ValueError("No frames could be extracted from video")

        inference_session = video_processor.init_video_session(
            video=frames,
            inference_device=self.device,
            processing_device="cpu",
            video_storage_device="cpu",
            dtype=self.dtype,
        )
        inference_session = video_processor.add_text_prompt(
            inference_session=inference_session,
            text=prompt,
        )

        outputs_per_frame = {}
        # Inference only: no autograd graph is needed while propagating.
        with torch.no_grad():
            for model_outputs in video_model.propagate_in_video_iterator(
                inference_session=inference_session,
                max_frame_num_to_track=max_frames,
            ):
                processed = video_processor.postprocess_outputs(inference_session, model_outputs)
                frame_idx = model_outputs.frame_idx
                outputs_per_frame[frame_idx] = {
                    "frame_idx": frame_idx,
                    "object_ids": self._to_list(processed["object_ids"]),
                    "scores": self._to_list(processed["scores"]),
                    "boxes": self._to_list(processed["boxes"]),
                }

        tracks = self._group_tracks(outputs_per_frame)
        return {
            "prompt": prompt,
            "video_metadata": video_metadata,
            "frames_processed": len(outputs_per_frame),
            "objects_tracked": len(tracks),
            "tracks": tracks,
            "per_frame_detections": outputs_per_frame,
        }
# Local smoke test: one text-prompted segmentation on a public COCO image.
if __name__ == "__main__":
    endpoint = EndpointHandler()

    request_body = {
        "inputs": "http://images.cocodataset.org/val2017/000000077595.jpg",
        "parameters": {
            "prompt": "ear",
            "threshold": 0.5,
            "return_masks": False,
        },
    }
    response = endpoint(request_body)

    first = response['results'][0]
    print(f"Found {first['count']} instances of '{first['prompt']}'")
    for detection in first['instances']:
        print(f" Box: {detection['box']}, Score: {detection['score']:.3f}")