File size: 4,079 Bytes
df8aa8d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 | from typing import Dict, Any
from transformers import Sam3Processor, Sam3Model
import torch
from PIL import Image
import io
import base64
import numpy as np
class EndpointHandler:
    """
    Custom inference handler running SAM3 text-prompted instance segmentation.

    The ``EndpointHandler(path)`` / ``__call__(data)`` shape presumably follows the
    Hugging Face Inference Endpoints custom-handler convention — confirm against
    the deployment config.
    """

    def __init__(self, path=""):
        """
        Initialize the SAM3 model and processor for text-prompted segmentation.

        Args:
            path: Path to local model files (if deploying with custom weights)
                or empty string to use the default facebook/sam3 model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use local path if provided, otherwise use the default model
        model_id = path if path else "facebook/sam3"
        self.processor = Sam3Processor.from_pretrained(model_id)
        self.model = Sam3Model.from_pretrained(model_id).to(self.device)
        self.model.eval()

    @staticmethod
    def _decode_base64(image_b64):
        """
        Return the raw bytes of a base64 image payload.

        Accepts either a bare base64 string or a data URI such as
        ``data:image/png;base64,<payload>``. Without stripping the prefix,
        the non-validating ``b64decode`` would keep the prefix's alphabet
        characters ("data", "image/png", "base64") and the decode would fail
        or yield corrupted bytes.

        Raises:
            binascii.Error / ValueError / TypeError if the payload is not
            decodable base64.
        """
        if isinstance(image_b64, str) and image_b64.startswith("data:") and "," in image_b64:
            image_b64 = image_b64.split(",", 1)[1]
        return base64.b64decode(image_b64)

    @staticmethod
    def _encode_mask_png(mask):
        """
        Encode one mask tensor as a base64 PNG string.

        Assumes ``mask`` is a 2-D boolean (or 0/1) tensor, as produced by the
        thresholded post-processing — TODO confirm. Foreground pixels become
        255 and background pixels become 0.
        """
        mask_np = mask.cpu().numpy().astype(np.uint8) * 255
        # A 2-D uint8 array maps to Pillow mode "L" automatically; passing
        # ``mode=`` explicitly is deprecated in recent Pillow releases.
        mask_img = Image.fromarray(mask_np)
        buffer = io.BytesIO()
        mask_img.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an image with a text prompt and return segmentation masks.

        Expected input format:
            {
                "inputs": {
                    "image": "<base64_encoded_image or data URI>",
                    "prompt": "text description of object to segment",  # e.g., "a red car"
                    "threshold": 0.5,       # optional, score threshold
                    "mask_threshold": 0.5   # optional, mask binarization threshold
                }
            }
        If "inputs" is absent, the top-level dict itself is used as the inputs.
        The caller's ``data`` dict is not modified.

        Returns:
            {
                "masks": [...],      # List of binary masks as base64 encoded PNGs
                "boxes": [...],      # Bounding boxes in xyxy format
                "scores": [...],     # Confidence scores
                "num_objects": int   # len(masks)
            }
            or {"error": "<message>"} when the request is invalid.
        """
        # 1. Extract inputs (read-only: do not mutate the caller's payload)
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            return {"error": "'inputs' must be an object with 'image' and 'prompt' fields."}

        image_b64 = inputs.get("image")
        text_prompt = inputs.get("prompt", None)
        if not image_b64:
            return {"error": "No image provided. Please provide a base64-encoded image."}
        if not text_prompt:
            return {"error": "No text prompt provided. Please provide a 'prompt' field."}

        # Optional parameters; JSON clients may send these as strings.
        try:
            threshold = float(inputs.get("threshold", 0.5))
            mask_threshold = float(inputs.get("mask_threshold", 0.5))
        except (TypeError, ValueError) as e:
            return {"error": f"Invalid threshold value: {e}"}

        # 2. Decode image. This is the request boundary, so any decode failure
        # (bad base64, unsupported format, oversized image) becomes an error response.
        try:
            image_data = self._decode_base64(image_b64)
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to decode image: {str(e)}"}

        # 3. Process inputs with text prompt
        processor_inputs = self.processor(
            images=image,
            text=text_prompt,
            return_tensors="pt",
        ).to(self.device)

        # 4. Run inference
        with torch.no_grad():
            outputs = self.model(**processor_inputs)

        # 5. Post-process results for the single image in the batch
        results = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=processor_inputs["original_sizes"].tolist(),
        )[0]

        # 6. Format response
        response = {
            "masks": [],
            "boxes": [],
            "scores": [],
        }
        masks = results["masks"]
        if len(masks) > 0:
            response["masks"] = [self._encode_mask_png(mask) for mask in masks]
            if "boxes" in results:
                response["boxes"] = results["boxes"].cpu().tolist()
            if "scores" in results:
                response["scores"] = results["scores"].cpu().tolist()

        response["num_objects"] = len(response["masks"])
        return response
|