sam3 / handler.py
Thibaut's picture
Create handler.py
a00a9a9
raw
history blame
2.17 kB
import base64
import io
import torch
from PIL import Image
from transformers import AutoProcessor, SamModel
class EndpointHandler:
    """Callable endpoint handler: image + text class prompts -> one mask per class.

    The processor and model are loaded once in ``__init__``; every request is
    served by ``__call__``, which returns either a list of per-class results
    or a single ``{"error": ...}`` dict.
    """

    # Cut-off applied to the raw ``pred_masks`` values when binarising.
    # NOTE(review): if ``pred_masks`` holds logits rather than probabilities,
    # a sigmoid (or a 0.0 cut-off) is probably intended -- confirm against the
    # model documentation before changing.
    MASK_THRESHOLD = 0.5

    def __init__(self, path="facebook/sam3"):
        """Load the processor and model from *path* and place the model on a device.

        Args:
            path: Checkpoint name or local directory handed to
                ``from_pretrained``.
        """
        use_cuda = torch.cuda.is_available()
        # Remember the placement so request tensors can be matched to it.
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.dtype = torch.float16 if use_cuda else torch.float32

        self.processor = AutoProcessor.from_pretrained(path)
        # NOTE(review): the default checkpoint is "facebook/sam3" but the class
        # loaded here is ``SamModel``; confirm this class matches the
        # checkpoint and accepts the ``text`` prompts produced in ``__call__``.
        self.model = SamModel.from_pretrained(path, torch_dtype=self.dtype)
        self.model.eval()
        self.model = self.model.to(self.device)

    def __call__(self, data):
        """Segment the requested classes in one image.

        Expected request::

            {
                "inputs": "<base64-encoded image, optionally a data: URI>",
                "parameters": {"classes": ["pothole", "marking"]}
            }

        URLs are not fetched; ``inputs`` must carry the image bytes as base64.

        Args:
            data: Request dict with the keys shown above.

        Returns:
            On success, a list with one dict per class:
            ``{"label": str, "mask": <base64 PNG>, "score": 1.0}``.
            On invalid input, a dict ``{"error": <message>}``.
        """
        image_payload = data.get("inputs")
        # ``"parameters": null`` must behave like a missing key.
        params = data.get("parameters") or {}
        classes = params.get("classes")
        if image_payload is None or classes is None:
            return {"error": "Required fields: inputs (image base64), parameters.classes"}

        # A bare string would otherwise be iterated character by character.
        if isinstance(classes, str):
            classes = [classes]
        if not isinstance(classes, (list, tuple)) or not classes:
            return {"error": "parameters.classes must be a non-empty list of labels"}
        classes = list(classes)

        try:
            pil_image = self._decode_image(image_payload)
        except (TypeError, ValueError, OSError) as exc:
            # Malformed base64 raises ValueError/TypeError; unreadable image
            # bytes surface as OSError from the image loader.
            return {"error": f"Could not decode inputs as a base64 image: {exc}"}

        inputs = self.processor(
            images=pil_image,
            text=classes,
            return_tensors="pt",
        )
        # The model may run in float16; float inputs must use the same dtype
        # and every tensor must live on the model's device.
        inputs = {name: self._to_model(value) for name, value in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        # NOTE(review): assumes the result is [N, H, W] with one mask per
        # entry of ``classes`` -- confirm against the model's output layout.
        pred_masks = outputs.pred_masks.squeeze(1)
        if len(pred_masks) < len(classes):
            return {
                "error": (
                    f"Model returned {len(pred_masks)} mask(s) "
                    f"for {len(classes)} class(es)"
                )
            }

        results = []
        for index, label in enumerate(classes):
            mask = pred_masks[index].float().cpu()
            binary_mask = (mask > self.MASK_THRESHOLD).numpy().astype("uint8") * 255
            results.append({
                "label": label,
                "mask": self._encode_mask(binary_mask),
                # Placeholder: no confidence is read from the model outputs.
                "score": 1.0,
            })
        return results

    @staticmethod
    def _decode_image(payload):
        """Decode a base64 payload (optionally a ``data:`` URI) into an RGB image."""
        if isinstance(payload, str) and payload.startswith("data:"):
            # Drop the "data:<mime>;base64," header so only the payload is decoded.
            _, _, payload = payload.partition(",")
        image_bytes = base64.b64decode(payload)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def _to_model(self, value):
        """Move a processor output to the model device, matching dtype for floats."""
        if not isinstance(value, torch.Tensor):
            return value
        if value.is_floating_point():
            return value.to(device=self.device, dtype=self.dtype)
        # Integer/bool tensors (ids, masks, sizes) keep their dtype.
        return value.to(self.device)

    @staticmethod
    def _encode_mask(binary_mask):
        """Serialise a uint8 mask array as a base64-encoded PNG string."""
        pil_mask = Image.fromarray(binary_mask, mode="L")
        with io.BytesIO() as buf:
            pil_mask.save(buf, format="PNG")
            return base64.b64encode(buf.getvalue()).decode("utf-8")