import base64
import io

import torch
from PIL import Image
from transformers import AutoProcessor, SamModel


class EndpointHandler:
    """Hugging Face inference-endpoint handler wrapping a SAM segmentation model.

    Loads the processor and model once at startup; `__call__` handles one
    request at a time, taking a base64-encoded image plus a list of class
    prompts and returning one base64-encoded binary PNG mask per class.
    """

    def __init__(self, path="facebook/sam3"):
        self.processor = AutoProcessor.from_pretrained(path)
        # fp16 on GPU for memory/speed; fp32 on CPU (half precision is
        # poorly supported on CPU).
        self.model = SamModel.from_pretrained(
            path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        self.model.eval()
        if torch.cuda.is_available():
            self.model = self.model.cuda()

    def __call__(self, data):
        """Run segmentation for one request.

        Expected HF pipeline request:
        {
          "inputs": "<base64-encoded image>",
          "parameters": { "classes": ["pothole", "marking"] }
        }

        Returns a list of {"label", "mask", "score"} dicts, or a single
        {"error": ...} dict on bad input (the handler's error convention).
        """
        image_b64 = data.get("inputs", None)
        params = data.get("parameters", {})
        classes = params.get("classes", None)

        if image_b64 is None or classes is None:
            return {"error": "Required fields: inputs (image base64), parameters.classes"}

        # Normalize a bare string prompt to a one-element list so iteration
        # below doesn't silently walk over individual characters.
        if isinstance(classes, str):
            classes = [classes]

        # Decode image; report malformed payloads via the error convention
        # instead of crashing the endpoint with an unhandled exception.
        try:
            image_bytes = base64.b64decode(image_b64)
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as exc:
            return {"error": f"Could not decode input image: {exc}"}

        inputs = self.processor(
            images=pil_image,
            text=classes,
            return_tensors="pt",
        )

        if torch.cuda.is_available():
            model_dtype = next(self.model.parameters()).dtype
            moved = {}
            for k, v in inputs.items():
                # Processor output may contain non-tensor entries; move only
                # tensors, and cast floating-point ones to the model dtype —
                # otherwise fp32 pixel_values hit the fp16 model and the
                # forward pass raises a dtype-mismatch error.
                if isinstance(v, torch.Tensor):
                    v = v.cuda()
                    if v.is_floating_point():
                        v = v.to(model_dtype)
                moved[k] = v
            inputs = moved

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Drop the per-prompt singleton dim -> [N, H, W].
        # NOTE(review): assumes one mask per text prompt, in prompt order —
        # confirm against the model card for this checkpoint.
        pred_masks = outputs.pred_masks.squeeze(1)

        results = []
        for i, cls in enumerate(classes):
            # Guard against fewer predicted masks than requested classes so
            # a mismatched request can't raise an uncaught IndexError.
            if i >= pred_masks.shape[0]:
                break
            mask = pred_masks[i].float().cpu()
            # NOTE(review): thresholding raw model output at 0.5 — if these
            # are logits a sigmoid (or threshold at 0) may be intended; verify.
            binary_mask = (mask > 0.5).numpy().astype("uint8") * 255
            pil_mask = Image.fromarray(binary_mask, mode="L")

            buf = io.BytesIO()
            pil_mask.save(buf, format="PNG")
            mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

            results.append({
                "label": cls,
                "mask": mask_b64,
                "score": 1.0,  # SAM3 does not output per-class confidence
            })

        return results