| | import base64 |
| | import io |
| | import torch |
| | from PIL import Image |
| | from transformers import AutoProcessor, SamModel |
| |
|
| |
|
class EndpointHandler:
    """Inference Endpoints handler: text-prompted SAM segmentation.

    Loads the processor/model once at start-up, then serves requests that
    carry a base64 image plus a list of class prompts, returning one
    base64-encoded PNG mask per class.
    """

    def __init__(self, path="facebook/sam3"):
        """Load the processor and model.

        Args:
            path: Hub repo id or local path of the checkpoint.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        # fp16 halves memory/bandwidth on GPU; CPU kernels want fp32.
        # Cache device/dtype so inputs can be cast to match the model —
        # an fp16 model fed fp32 pixel_values raises a dtype mismatch.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.model = SamModel.from_pretrained(path, torch_dtype=self.dtype)
        self.model.eval()
        self.model = self.model.to(self.device)

    def __call__(self, data):
        """Run segmentation for one request.

        Expected HF pipeline request:
        {
            "inputs": "<base64 or URL>",
            "parameters": {
                "classes": ["pothole", "marking"]
            }
        }

        Returns:
            A list of {"label", "mask", "score"} dicts (mask is a base64
            PNG, 255 = foreground), or an {"error": ...} dict on bad input.
        """
        image_b64 = data.get("inputs")
        params = data.get("parameters", {})
        classes = params.get("classes")

        if image_b64 is None or classes is None:
            return {"error": "Required fields: inputs (image base64), parameters.classes"}

        # A lone string would otherwise be iterated character-by-character
        # in the results loop below.
        if isinstance(classes, str):
            classes = [classes]

        # Untrusted payload: malformed base64 or a non-image must produce a
        # structured error, not an unhandled 500.
        try:
            image_bytes = base64.b64decode(image_b64)
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as exc:
            return {"error": f"Could not decode input image: {exc}"}

        inputs = self.processor(
            images=pil_image,
            text=classes,
            return_tensors="pt"
        )

        # Move only tensors (processor output may hold non-tensor entries)
        # and cast floating tensors to the model dtype (see __init__).
        moved = {}
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                value = value.to(self.device)
                if value.is_floating_point():
                    value = value.to(self.dtype)
            moved[key] = value
        inputs = moved

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Drop the per-prompt mask axis: presumably
        # (num_prompts, 1, H, W) -> (num_prompts, H, W) — TODO confirm
        # against the checkpoint's output shape.
        pred_masks = outputs.pred_masks.squeeze(1)

        results = []
        for i, cls in enumerate(classes):
            # pred_masks are logits; map through sigmoid so the 0.5 cut-off
            # is a probability threshold (raw logits > 0.5 was skewed).
            probs = torch.sigmoid(pred_masks[i].float().cpu())
            binary_mask = (probs > 0.5).numpy().astype("uint8") * 255

            pil_mask = Image.fromarray(binary_mask, mode="L")
            buf = io.BytesIO()
            pil_mask.save(buf, format="PNG")
            mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

            results.append({
                "label": cls,
                "mask": mask_b64,
                # NOTE(review): the model does not expose a per-class score
                # here; 1.0 is a placeholder kept from the original.
                "score": 1.0
            })

        return results
| |
|