# sam3-endpoint / handler.py
# Uploaded by hfpipo with huggingface_hub (commit df8aa8d, 4.08 kB).
from typing import Dict, Any
from transformers import Sam3Processor, Sam3Model
import torch
from PIL import Image
import io
import base64
import numpy as np
class EndpointHandler:
    """Inference Endpoint handler for SAM3 text-prompted instance segmentation."""

    def __init__(self, path=""):
        """
        Initialize the SAM3 model and processor for text-prompted segmentation.

        Args:
            path: Path to local model files (if deploying with custom weights)
                or empty string to use the default facebook/sam3 model
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use local path if provided, otherwise use the default model
        model_id = path if path else "facebook/sam3"
        self.processor = Sam3Processor.from_pretrained(model_id)
        self.model = Sam3Model.from_pretrained(model_id).to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an image with a text prompt and return segmentation masks.

        Expected input format:
            {
                "inputs": {
                    "image": "<base64_encoded_image>",  # a "data:...;base64," URI is also accepted
                    "prompt": "text description of object to segment",  # e.g., "a red car"
                    "threshold": 0.5,        # optional detection score threshold
                    "mask_threshold": 0.5    # optional mask binarization threshold
                },
                "parameters": {...}  # optional; thresholds may also be given here
            }
        If "inputs" is absent, the top-level dict itself is treated as the inputs.
        Threshold values in "inputs" take precedence over those in "parameters";
        both default to 0.5.

        Returns:
            {
                "masks": [...],      # List of binary masks as base64 encoded PNGs
                "boxes": [...],      # Bounding boxes in xyxy format
                "scores": [...],     # Confidence scores
                "num_objects": int   # Number of returned masks
            }
            or {"error": "<message>"} when the request is invalid or the image
            cannot be decoded. Errors raised during inference are not caught.
        """
        # 1. Extract inputs. Use .get (not .pop) so the caller's payload is not mutated.
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            return {"error": "'inputs' must be an object with 'image' and 'prompt' fields."}
        parameters = data.get("parameters")
        if not isinstance(parameters, dict):
            parameters = {}

        image_b64 = inputs.get("image")
        text_prompt = inputs.get("prompt", None)
        if not image_b64:
            return {"error": "No image provided. Please provide a base64-encoded image."}
        if not text_prompt:
            return {"error": "No text prompt provided. Please provide a 'prompt' field."}

        # Optional parameters. JSON clients may send numbers as strings, so coerce
        # to float here rather than failing later inside post-processing.
        thresholds = {}
        for name in ("threshold", "mask_threshold"):
            raw_value = inputs.get(name, parameters.get(name, 0.5))
            try:
                thresholds[name] = float(raw_value)
            except (TypeError, ValueError):
                return {"error": f"Invalid '{name}' value: {raw_value!r}. Expected a number."}

        # 2. Decode image
        try:
            image = self._decode_image(image_b64)
        except Exception as e:
            return {"error": f"Failed to decode image: {str(e)}"}

        # 3. Process inputs with text prompt
        processor_inputs = self.processor(
            images=image,
            text=text_prompt,
            return_tensors="pt",
        ).to(self.device)

        # 4. Run Inference
        with torch.no_grad():
            outputs = self.model(**processor_inputs)

        # 5. Post-process results (masks/boxes are mapped back to the original image size)
        results = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=thresholds["threshold"],
            mask_threshold=thresholds["mask_threshold"],
            target_sizes=processor_inputs.get("original_sizes").tolist(),
        )[0]

        # 6. Format response
        masks = [self._mask_to_png_b64(mask) for mask in results["masks"]]
        boxes = results["boxes"].cpu().tolist() if "boxes" in results else []
        scores = results["scores"].cpu().tolist() if "scores" in results else []
        return {
            "masks": masks,
            "boxes": boxes,
            "scores": scores,
            "num_objects": len(masks),
        }

    @staticmethod
    def _decode_image(image_b64):
        """Decode a base64 string (optionally a data: URI) into an RGB PIL image."""
        if isinstance(image_b64, str) and image_b64.startswith("data:"):
            # Browsers produce "data:image/png;base64,<payload>"; keep only the payload.
            image_b64 = image_b64.partition(",")[2]
        image_data = base64.b64decode(image_b64)
        return Image.open(io.BytesIO(image_data)).convert("RGB")

    @staticmethod
    def _mask_to_png_b64(mask):
        """Encode one mask tensor as a base64 PNG (0 = background, 255 = object)."""
        # Assumes a 2-D (H, W) mask as yielded by iterating results["masks"] — TODO confirm.
        # Casting through bool keeps any nonzero value as foreground.
        mask_np = mask.detach().cpu().numpy().astype(bool).astype(np.uint8) * 255
        buffer = io.BytesIO()
        # A 2-D uint8 array is inferred as mode "L"; the explicit `mode` arg is deprecated.
        Image.fromarray(mask_np).save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")