|
|
""" |
|
|
SAM 3 Custom Inference Handler for Hugging Face Inference Endpoints |
|
|
Model: facebook/sam3 |
|
|
|
|
|
For ProofPath video assessment - text-prompted segmentation to find UI elements. |
|
|
Supports text prompts like "Save button", "dropdown menu", "text input field". |
|
|
|
|
|
KEY CAPABILITIES: |
|
|
- Text-to-segment: Find ALL instances of a concept (e.g., "button" → all buttons) |
|
|
- Promptable Concept Segmentation (PCS): 270K unique concepts |
|
|
- Video tracking: Consistent object IDs across frames |
|
|
- Presence token: Discriminates similar elements ("player in white" vs "player in red") |
|
|
|
|
|
REQUIREMENTS: |
|
|
1. Set HF_TOKEN environment variable (model is gated) |
|
|
2. Accept license at https://huggingface.co/facebook/sam3 |
|
|
""" |
|
|
|
|
|
from typing import Dict, List, Any, Optional, Union |
|
|
import torch |
|
|
import numpy as np |
|
|
import base64 |
|
|
import io |
|
|
import os |
|
|
|
|
|
|
|
|
class EndpointHandler:
    def __init__(self, path: str = ""):
        """Load the SAM 3 image model and processor from the Hugging Face hub.

        Args:
            path: Model directory supplied by the endpoint runtime; unused
                here because weights are always pulled from the hub.
        """
        repo_id = "facebook/sam3"

        # The checkpoint is gated, so an access token must be supplied via env.
        auth_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        from transformers import Sam3Processor, Sam3Model

        self.processor = Sam3Processor.from_pretrained(
            repo_id,
            token=auth_token,
        )

        # bf16 on GPU, fp32 on CPU.
        weight_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.model = Sam3Model.from_pretrained(
            repo_id,
            torch_dtype=weight_dtype,
            token=auth_token,
        ).to(self.device)
        self.model.eval()

        # The video model is expensive; it is constructed lazily on first use.
        self._video_model = None
        self._video_processor = None
|
|
|
|
|
def _get_video_model(self): |
|
|
"""Lazy load video model only when needed.""" |
|
|
if self._video_model is None: |
|
|
from transformers import Sam3VideoModel, Sam3VideoProcessor |
|
|
|
|
|
model_id = "facebook/sam3" |
|
|
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") |
|
|
|
|
|
self._video_processor = Sam3VideoProcessor.from_pretrained( |
|
|
model_id, |
|
|
token=hf_token, |
|
|
) |
|
|
|
|
|
self._video_model = Sam3VideoModel.from_pretrained( |
|
|
model_id, |
|
|
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, |
|
|
token=hf_token, |
|
|
).to(self.device) |
|
|
|
|
|
self._video_model.eval() |
|
|
|
|
|
return self._video_model, self._video_processor |
|
|
|
|
|
def _load_image(self, image_data: Any): |
|
|
"""Load image from various formats.""" |
|
|
from PIL import Image |
|
|
import requests |
|
|
|
|
|
if isinstance(image_data, Image.Image): |
|
|
return image_data.convert('RGB') |
|
|
elif isinstance(image_data, str): |
|
|
if image_data.startswith(('http://', 'https://')): |
|
|
response = requests.get(image_data, stream=True) |
|
|
return Image.open(response.raw).convert('RGB') |
|
|
elif image_data.startswith('data:'): |
|
|
header, encoded = image_data.split(',', 1) |
|
|
image_bytes = base64.b64decode(encoded) |
|
|
return Image.open(io.BytesIO(image_bytes)).convert('RGB') |
|
|
else: |
|
|
|
|
|
image_bytes = base64.b64decode(image_data) |
|
|
return Image.open(io.BytesIO(image_bytes)).convert('RGB') |
|
|
elif isinstance(image_data, bytes): |
|
|
return Image.open(io.BytesIO(image_data)).convert('RGB') |
|
|
else: |
|
|
raise ValueError(f"Unsupported image input type: {type(image_data)}") |
|
|
|
|
|
def _load_video_frames(self, video_data: Any, max_frames: int = 100, fps: float = 2.0) -> List: |
|
|
"""Load video frames from various formats.""" |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
import tempfile |
|
|
|
|
|
|
|
|
if isinstance(video_data, str): |
|
|
if video_data.startswith(('http://', 'https://')): |
|
|
import requests |
|
|
response = requests.get(video_data, stream=True) |
|
|
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: |
|
|
for chunk in response.iter_content(chunk_size=8192): |
|
|
f.write(chunk) |
|
|
video_path = f.name |
|
|
elif video_data.startswith('data:'): |
|
|
header, encoded = video_data.split(',', 1) |
|
|
video_bytes = base64.b64decode(encoded) |
|
|
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: |
|
|
f.write(video_bytes) |
|
|
video_path = f.name |
|
|
else: |
|
|
video_bytes = base64.b64decode(video_data) |
|
|
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: |
|
|
f.write(video_bytes) |
|
|
video_path = f.name |
|
|
elif isinstance(video_data, bytes): |
|
|
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f: |
|
|
f.write(video_data) |
|
|
video_path = f.name |
|
|
else: |
|
|
raise ValueError(f"Unsupported video input type: {type(video_data)}") |
|
|
|
|
|
try: |
|
|
cap = cv2.VideoCapture(video_path) |
|
|
video_fps = cap.get(cv2.CAP_PROP_FPS) |
|
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) |
|
|
duration = total_frames / video_fps if video_fps > 0 else 0 |
|
|
|
|
|
|
|
|
target_frames = min(max_frames, int(duration * fps), total_frames) |
|
|
if target_frames <= 0: |
|
|
target_frames = min(max_frames, total_frames) |
|
|
|
|
|
frame_indices = np.linspace(0, total_frames - 1, target_frames, dtype=int) |
|
|
|
|
|
frames = [] |
|
|
for idx in frame_indices: |
|
|
cap.set(cv2.CAP_PROP_POS_FRAMES, idx) |
|
|
ret, frame = cap.read() |
|
|
if ret: |
|
|
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
|
|
pil_image = Image.fromarray(frame_rgb) |
|
|
frames.append(pil_image) |
|
|
|
|
|
cap.release() |
|
|
|
|
|
metadata = { |
|
|
"duration": duration, |
|
|
"total_frames": total_frames, |
|
|
"sampled_frames": len(frames), |
|
|
"video_fps": video_fps |
|
|
} |
|
|
|
|
|
return frames, metadata |
|
|
|
|
|
finally: |
|
|
if os.path.exists(video_path): |
|
|
os.unlink(video_path) |
|
|
|
|
|
def _masks_to_serializable(self, masks: torch.Tensor) -> List[List[List[int]]]: |
|
|
"""Convert binary masks to RLE or simplified format for JSON serialization.""" |
|
|
|
|
|
|
|
|
masks_np = masks.cpu().numpy().astype(np.uint8) |
|
|
|
|
|
|
|
|
encoded_masks = [] |
|
|
for mask in masks_np: |
|
|
|
|
|
from PIL import Image |
|
|
img = Image.fromarray(mask * 255) |
|
|
buffer = io.BytesIO() |
|
|
img.save(buffer, format='PNG') |
|
|
encoded = base64.b64encode(buffer.getvalue()).decode('utf-8') |
|
|
encoded_masks.append(encoded) |
|
|
|
|
|
return encoded_masks |
|
|
|
|
|
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: |
|
|
""" |
|
|
Process image or video with SAM 3 for text-prompted segmentation. |
|
|
|
|
|
INPUT FORMATS: |
|
|
|
|
|
1. Single image with text prompt (find all instances): |
|
|
{ |
|
|
"inputs": <image_url_or_base64>, |
|
|
"parameters": { |
|
|
"prompt": "Save button", |
|
|
"threshold": 0.5, |
|
|
"mask_threshold": 0.5, |
|
|
"return_masks": true |
|
|
} |
|
|
} |
|
|
|
|
|
2. Single image with multiple text prompts: |
|
|
{ |
|
|
"inputs": <image_url_or_base64>, |
|
|
"parameters": { |
|
|
"prompts": ["button", "text field", "dropdown"], |
|
|
"threshold": 0.5 |
|
|
} |
|
|
} |
|
|
|
|
|
3. Single image with box prompts (positive/negative): |
|
|
{ |
|
|
"inputs": <image_url_or_base64>, |
|
|
"parameters": { |
|
|
"prompt": "handle", |
|
|
"boxes": [[40, 183, 318, 204]], |
|
|
"box_labels": [0], // 0=negative, 1=positive |
|
|
"threshold": 0.5 |
|
|
} |
|
|
} |
|
|
|
|
|
4. Video with text prompt (track all instances): |
|
|
{ |
|
|
"inputs": <video_url_or_base64>, |
|
|
"parameters": { |
|
|
"mode": "video", |
|
|
"prompt": "Submit button", |
|
|
"max_frames": 100, |
|
|
"fps": 2.0 |
|
|
} |
|
|
} |
|
|
|
|
|
5. Batch images: |
|
|
{ |
|
|
"inputs": [<image1>, <image2>, ...], |
|
|
"parameters": { |
|
|
"prompts": ["ear", "dial"], // One per image |
|
|
"threshold": 0.5 |
|
|
} |
|
|
} |
|
|
|
|
|
6. ProofPath UI element detection: |
|
|
{ |
|
|
"inputs": <screenshot_base64>, |
|
|
"parameters": { |
|
|
"mode": "ui_elements", |
|
|
"elements": ["Save button", "Cancel button", "text input"], |
|
|
"threshold": 0.5 |
|
|
} |
|
|
} |
|
|
|
|
|
OUTPUT FORMAT: |
|
|
{ |
|
|
"results": [ |
|
|
{ |
|
|
"prompt": "Save button", |
|
|
"instances": [ |
|
|
{ |
|
|
"box": [x1, y1, x2, y2], |
|
|
"score": 0.95, |
|
|
"mask": "<base64_png>" // if return_masks=true |
|
|
} |
|
|
] |
|
|
} |
|
|
], |
|
|
"image_size": {"width": 1920, "height": 1080} |
|
|
} |
|
|
""" |
|
|
inputs = data.get("inputs") |
|
|
params = data.get("parameters", {}) |
|
|
|
|
|
if inputs is None: |
|
|
raise ValueError("No inputs provided") |
|
|
|
|
|
mode = params.get("mode", "image") |
|
|
|
|
|
if mode == "video": |
|
|
return self._process_video(inputs, params) |
|
|
elif mode == "ui_elements": |
|
|
return self._process_ui_elements(inputs, params) |
|
|
elif isinstance(inputs, list): |
|
|
return self._process_batch(inputs, params) |
|
|
else: |
|
|
return self._process_single_image(inputs, params) |
|
|
|
|
|
def _process_single_image(self, image_data: Any, params: Dict) -> Dict[str, Any]: |
|
|
"""Process a single image with text and/or box prompts.""" |
|
|
image = self._load_image(image_data) |
|
|
|
|
|
threshold = params.get("threshold", 0.5) |
|
|
mask_threshold = params.get("mask_threshold", 0.5) |
|
|
return_masks = params.get("return_masks", True) |
|
|
|
|
|
|
|
|
prompt = params.get("prompt") |
|
|
prompts = params.get("prompts", [prompt] if prompt else []) |
|
|
|
|
|
if not prompts: |
|
|
raise ValueError("No text prompt(s) provided") |
|
|
|
|
|
|
|
|
boxes = params.get("boxes") |
|
|
box_labels = params.get("box_labels") |
|
|
|
|
|
results = [] |
|
|
|
|
|
for text_prompt in prompts: |
|
|
|
|
|
if boxes is not None: |
|
|
input_boxes = [boxes] |
|
|
input_boxes_labels = [box_labels] if box_labels else [[1] * len(boxes)] |
|
|
|
|
|
processor_inputs = self.processor( |
|
|
images=image, |
|
|
text=text_prompt, |
|
|
input_boxes=input_boxes, |
|
|
input_boxes_labels=input_boxes_labels, |
|
|
return_tensors="pt" |
|
|
).to(self.device) |
|
|
else: |
|
|
processor_inputs = self.processor( |
|
|
images=image, |
|
|
text=text_prompt, |
|
|
return_tensors="pt" |
|
|
).to(self.device) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = self.model(**processor_inputs) |
|
|
|
|
|
|
|
|
post_results = self.processor.post_process_instance_segmentation( |
|
|
outputs, |
|
|
threshold=threshold, |
|
|
mask_threshold=mask_threshold, |
|
|
target_sizes=processor_inputs.get("original_sizes").tolist() |
|
|
)[0] |
|
|
|
|
|
instances = [] |
|
|
for i in range(len(post_results.get("boxes", []))): |
|
|
instance = { |
|
|
"box": post_results["boxes"][i].tolist(), |
|
|
"score": float(post_results["scores"][i]) |
|
|
} |
|
|
|
|
|
if return_masks and "masks" in post_results: |
|
|
|
|
|
mask = post_results["masks"][i].cpu().numpy().astype(np.uint8) * 255 |
|
|
from PIL import Image as PILImage |
|
|
mask_img = PILImage.fromarray(mask) |
|
|
buffer = io.BytesIO() |
|
|
mask_img.save(buffer, format='PNG') |
|
|
instance["mask"] = base64.b64encode(buffer.getvalue()).decode('utf-8') |
|
|
|
|
|
instances.append(instance) |
|
|
|
|
|
results.append({ |
|
|
"prompt": text_prompt, |
|
|
"instances": instances, |
|
|
"count": len(instances) |
|
|
}) |
|
|
|
|
|
return { |
|
|
"results": results, |
|
|
"image_size": {"width": image.width, "height": image.height} |
|
|
} |
|
|
|
|
|
def _process_batch(self, images_data: List, params: Dict) -> Dict[str, Any]: |
|
|
"""Process multiple images with text prompts.""" |
|
|
images = [self._load_image(img) for img in images_data] |
|
|
|
|
|
prompts = params.get("prompts", []) |
|
|
prompt = params.get("prompt") |
|
|
|
|
|
|
|
|
if prompt and not prompts: |
|
|
prompts = [prompt] * len(images) |
|
|
|
|
|
if len(prompts) != len(images): |
|
|
raise ValueError(f"Number of prompts ({len(prompts)}) must match number of images ({len(images)})") |
|
|
|
|
|
threshold = params.get("threshold", 0.5) |
|
|
mask_threshold = params.get("mask_threshold", 0.5) |
|
|
return_masks = params.get("return_masks", False) |
|
|
|
|
|
|
|
|
processor_inputs = self.processor( |
|
|
images=images, |
|
|
text=prompts, |
|
|
return_tensors="pt" |
|
|
).to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = self.model(**processor_inputs) |
|
|
|
|
|
|
|
|
all_results = self.processor.post_process_instance_segmentation( |
|
|
outputs, |
|
|
threshold=threshold, |
|
|
mask_threshold=mask_threshold, |
|
|
target_sizes=processor_inputs.get("original_sizes").tolist() |
|
|
) |
|
|
|
|
|
results = [] |
|
|
for idx, (post_results, text_prompt, image) in enumerate(zip(all_results, prompts, images)): |
|
|
instances = [] |
|
|
for i in range(len(post_results.get("boxes", []))): |
|
|
instance = { |
|
|
"box": post_results["boxes"][i].tolist(), |
|
|
"score": float(post_results["scores"][i]) |
|
|
} |
|
|
|
|
|
if return_masks and "masks" in post_results: |
|
|
mask = post_results["masks"][i].cpu().numpy().astype(np.uint8) * 255 |
|
|
from PIL import Image as PILImage |
|
|
mask_img = PILImage.fromarray(mask) |
|
|
buffer = io.BytesIO() |
|
|
mask_img.save(buffer, format='PNG') |
|
|
instance["mask"] = base64.b64encode(buffer.getvalue()).decode('utf-8') |
|
|
|
|
|
instances.append(instance) |
|
|
|
|
|
results.append({ |
|
|
"image_index": idx, |
|
|
"prompt": text_prompt, |
|
|
"instances": instances, |
|
|
"count": len(instances), |
|
|
"image_size": {"width": image.width, "height": image.height} |
|
|
}) |
|
|
|
|
|
return {"results": results} |
|
|
|
|
|
def _process_ui_elements(self, image_data: Any, params: Dict) -> Dict[str, Any]: |
|
|
""" |
|
|
ProofPath-specific mode: Detect multiple UI element types in a screenshot. |
|
|
Returns structured data for each element type with bounding boxes. |
|
|
""" |
|
|
image = self._load_image(image_data) |
|
|
|
|
|
elements = params.get("elements", []) |
|
|
if not elements: |
|
|
|
|
|
elements = ["button", "text input", "dropdown", "checkbox", "link"] |
|
|
|
|
|
threshold = params.get("threshold", 0.5) |
|
|
mask_threshold = params.get("mask_threshold", 0.5) |
|
|
|
|
|
all_detections = {} |
|
|
|
|
|
for element_type in elements: |
|
|
processor_inputs = self.processor( |
|
|
images=image, |
|
|
text=element_type, |
|
|
return_tensors="pt" |
|
|
).to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = self.model(**processor_inputs) |
|
|
|
|
|
post_results = self.processor.post_process_instance_segmentation( |
|
|
outputs, |
|
|
threshold=threshold, |
|
|
mask_threshold=mask_threshold, |
|
|
target_sizes=processor_inputs.get("original_sizes").tolist() |
|
|
)[0] |
|
|
|
|
|
detections = [] |
|
|
for i in range(len(post_results.get("boxes", []))): |
|
|
box = post_results["boxes"][i].tolist() |
|
|
detections.append({ |
|
|
"box": box, |
|
|
"score": float(post_results["scores"][i]), |
|
|
"center": [ |
|
|
(box[0] + box[2]) / 2, |
|
|
(box[1] + box[3]) / 2 |
|
|
] |
|
|
}) |
|
|
|
|
|
all_detections[element_type] = { |
|
|
"count": len(detections), |
|
|
"instances": detections |
|
|
} |
|
|
|
|
|
return { |
|
|
"ui_elements": all_detections, |
|
|
"image_size": {"width": image.width, "height": image.height}, |
|
|
"total_elements": sum(d["count"] for d in all_detections.values()) |
|
|
} |
|
|
|
|
|
    def _process_video(self, video_data: Any, params: Dict) -> Dict[str, Any]:
        """
        Process video with SAM3 Video for text-prompted tracking.
        Tracks all instances of the prompted concept across frames.

        Args:
            video_data: Video in any format accepted by ``_load_video_frames``.
            params: ``prompt`` (required text prompt), optional ``max_frames``
                and ``fps`` frame-sampling controls.

        Returns:
            Dict with the prompt, video metadata, per-frame detections and
            per-object tracks (one entry per tracked object id).

        Raises:
            ValueError: If no text prompt is given or no frames can be read.
        """
        video_model, video_processor = self._get_video_model()

        prompt = params.get("prompt")
        if not prompt:
            raise ValueError("Text prompt required for video mode")

        max_frames = params.get("max_frames", 100)
        fps = params.get("fps", 2.0)

        # Decode and uniformly sample frames before any model work.
        frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)

        if not frames:
            raise ValueError("No frames could be extracted from video")

        # Keep frame storage/preprocessing on CPU; only inference runs on
        # self.device to limit GPU memory pressure.
        inference_session = video_processor.init_video_session(
            video=frames,
            inference_device=self.device,
            processing_device="cpu",
            video_storage_device="cpu",
            dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        )

        # Attach the text prompt to the session before propagation.
        inference_session = video_processor.add_text_prompt(
            inference_session=inference_session,
            text=prompt,
        )

        # Propagate through the video frame by frame, collecting raw
        # detections keyed by frame index.
        outputs_per_frame = {}
        for model_outputs in video_model.propagate_in_video_iterator(
            inference_session=inference_session,
            max_frame_num_to_track=max_frames
        ):
            processed = video_processor.postprocess_outputs(inference_session, model_outputs)

            # Outputs may be tensors or plain lists depending on the
            # processor version; normalize to JSON-safe lists either way.
            frame_data = {
                "frame_idx": model_outputs.frame_idx,
                "object_ids": processed["object_ids"].tolist() if hasattr(processed["object_ids"], "tolist") else processed["object_ids"],
                "scores": processed["scores"].tolist() if hasattr(processed["scores"], "tolist") else processed["scores"],
                "boxes": processed["boxes"].tolist() if hasattr(processed["boxes"], "tolist") else processed["boxes"],
            }

            outputs_per_frame[model_outputs.frame_idx] = frame_data

        # Pivot per-frame detections into per-object tracks so callers can
        # follow one object id across the whole clip.
        object_tracks = {}
        for frame_idx, frame_data in outputs_per_frame.items():
            for i, obj_id in enumerate(frame_data["object_ids"]):
                obj_id_str = str(obj_id)
                if obj_id_str not in object_tracks:
                    object_tracks[obj_id_str] = {
                        "object_id": obj_id,
                        "frames": []
                    }
                # Index-guard: boxes/scores can be shorter than object_ids
                # in edge cases, so fall back to None rather than raising.
                object_tracks[obj_id_str]["frames"].append({
                    "frame_idx": frame_idx,
                    "box": frame_data["boxes"][i] if i < len(frame_data["boxes"]) else None,
                    "score": frame_data["scores"][i] if i < len(frame_data["scores"]) else None
                })

        return {
            "prompt": prompt,
            "video_metadata": video_metadata,
            "frames_processed": len(outputs_per_frame),
            "objects_tracked": len(object_tracks),
            "tracks": list(object_tracks.values()),
            "per_frame_detections": outputs_per_frame
        }
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Local smoke test: one text-prompted pass over a public COCO image.
    handler = EndpointHandler()

    test_data = {
        "inputs": "http://images.cocodataset.org/val2017/000000077595.jpg",
        "parameters": {
            "prompt": "ear",
            "threshold": 0.5,
            "return_masks": False
        }
    }

    result = handler(test_data)

    first = result['results'][0]
    print(f"Found {first['count']} instances of '{first['prompt']}'")
    for inst in first['instances']:
        print(f"  Box: {inst['box']}, Score: {inst['score']:.3f}")
|
|
|