| import base64 |
| import os |
| from contextlib import nullcontext |
| from io import BytesIO |
| from typing import Any, Dict |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
|
|
# Default checkpoint id handed to SAM2ImagePredictor.from_pretrained by
# EndpointHandler; the SAM2_MODEL_ID environment variable overrides it.
MODEL_ID: str = "facebook/sam2.1-hiera-base-plus"
|
|
|
|
class EndpointHandler:
    """Box-prompted SAM 2 segmentation handler.

    A ``SAM2ImagePredictor`` is loaded once at construction. Each call decodes
    a base64 image and predicts one mask per prompt box. Each mask is returned
    cropped to its (clamped) box as a base64-encoded PNG.
    """

    def __init__(self, path: str = ""):
        """Load the SAM 2 predictor on CUDA when available, else on CPU.

        Args:
            path: Not used by this handler. It is presumably supplied by the
                hosting runtime (e.g. a model directory) — confirm. The
                checkpoint id is read from the ``SAM2_MODEL_ID`` environment
                variable, falling back to the module-level ``MODEL_ID``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = os.environ.get("SAM2_MODEL_ID", MODEL_ID)
        self.predictor = SAM2ImagePredictor.from_pretrained(
            self.model_id,
            device=self.device,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Segment the object inside each prompt box.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]`` when
                that key exists, otherwise from ``data`` itself:

                * ``image_base64`` (required when boxes are given): base64
                  image data, optionally with a data-URL style prefix
                  (everything up to the first comma is discarded).
                * ``boxes``: list of ``{"id": ..., "box": {"x1", "y1", "x2",
                  "y2"}}`` entries in pixel coordinates.
                * ``mime_type``: copied onto every returned mask entry;
                  defaults to ``"image/png"``.

        Returns:
            ``{"masks": [...]}`` with one entry per input box, in input order.
            Each entry holds:

            * ``id``: the box id as ``str``.
            * ``score``: ``float``.
            * ``mask_png_base64``: the mask cropped to the clamped box.
            * ``box``: the clamped integer box.
            * ``mime_type``.

            An empty or missing ``boxes`` list yields ``{"masks": []}`` without
            decoding the image.

        Raises:
            KeyError: if boxes are given but ``image_base64`` or a box's
                ``id`` / ``box`` / coordinate keys are missing.
        """
        inputs = data.get("inputs", data)
        boxes = inputs.get("boxes", [])
        if not boxes:
            # Nothing to segment: skip the (potentially large) image decode.
            return {"masks": []}

        image = self._decode_image(inputs["image_base64"]).convert("RGB")
        # NOTE(review): masks are always PNG-encoded (see the
        # ``mask_png_base64`` key), yet each entry echoes the request's
        # ``mime_type``. Confirm whether clients expect the source image's type
        # here or the mask's own type ("image/png").
        mime_type = inputs.get("mime_type", "image/png")

        width, height = image.size
        normalized_boxes = [
            {
                "id": str(box["id"]),
                "box": self._normalize_box(box["box"], width, height),
            }
            for box in boxes
        ]
        # Stack the clamped boxes into an (N, 4) float32 array, ordered
        # x1, y1, x2, y2, as the ``box`` prompt for ``predict``.
        input_boxes = np.array(
            [
                [item["box"][key] for key in ("x1", "y1", "x2", "y2")]
                for item in normalized_boxes
            ],
            dtype=np.float32,
        )

        image_array = np.array(image)
        with torch.inference_mode(), self._autocast_context():
            self.predictor.set_image(image_array)
            # multimask_output=False: one mask (and one score) per box.
            masks, scores, _ = self.predictor.predict(
                box=input_boxes,
                multimask_output=False,
            )

        # The output rank can vary (presumably the batch axis is dropped when a
        # single box is given); the _select_* helpers accept each rank.
        masks = np.asarray(masks)
        scores = np.asarray(scores)

        response_masks = []
        for index, item in enumerate(normalized_boxes):
            mask_bool = self._select_mask(masks, index)
            score = self._select_score(scores, index)
            response_masks.append(
                {
                    "id": item["id"],
                    "score": score,
                    "mask_png_base64": self._encode_cropped_mask(mask_bool, item["box"]),
                    "box": item["box"],
                    "mime_type": mime_type,
                }
            )

        return {"masks": response_masks}

    def _autocast_context(self):
        """Return a bfloat16 CUDA autocast context on GPU, else a no-op context."""
        if self.device == "cuda":
            return torch.autocast("cuda", dtype=torch.bfloat16)
        return nullcontext()

    def _select_mask(self, masks: np.ndarray, index: int) -> np.ndarray:
        """Return the boolean mask for box ``index``.

        Accepted shapes:

        * ``(N, C, H, W)``: channel 0 of box ``index`` is used.
        * ``(N, H, W)``: the mask of box ``index`` is used.
        * ``(H, W)``: the single mask is returned whatever ``index`` is.

        Pixels with a value > 0 are foreground.

        Raises:
            ValueError: for any other rank.
        """
        if masks.ndim == 4:
            return masks[index, 0] > 0
        if masks.ndim == 3:
            return masks[index] > 0
        if masks.ndim == 2:
            return masks > 0
        raise ValueError(f"Unexpected mask tensor shape: {masks.shape}")

    def _select_score(self, scores: np.ndarray, index: int) -> float:
        """Return the score for box ``index`` as a Python ``float``.

        Accepted shapes:

        * ``(N, C)``: column 0 is used.
        * ``(N,)``: entry ``index`` is used.
        * A 0-d scalar: returned regardless of ``index``.

        Raises:
            ValueError: for any other rank.
        """
        if scores.ndim == 2:
            return float(scores[index, 0])
        if scores.ndim == 1:
            return float(scores[index])
        if scores.ndim == 0:
            return float(scores)
        raise ValueError(f"Unexpected score tensor shape: {scores.shape}")

    def _decode_image(self, image_base64: str) -> Image.Image:
        """Decode a base64 string into a (lazily loaded) PIL image.

        Any prefix up to and including the first comma (e.g. a
        ``data:image/png;base64,`` header) is stripped before decoding. The
        ``","`` character is not part of the base64 alphabet, so this split
        cannot cut into the payload itself.
        """
        if "," in image_base64:
            image_base64 = image_base64.split(",", 1)[1]
        return Image.open(BytesIO(base64.b64decode(image_base64)))

    def _normalize_box(
        self,
        box: Dict[str, float],
        width: int,
        height: int,
    ) -> Dict[str, int]:
        """Round a box to integer pixels and clamp it inside the image.

        ``x1``/``y1`` are clamped to ``[0, size - 1]``. ``x2``/``y2`` are
        clamped to at most ``size`` and at least one pixel past the start, so
        the result is never empty and can be used directly as slice bounds
        (end-exclusive). Rounding uses Python's ``round`` (half to even).
        Inverted boxes (``x2 < x1``) collapse to a one-pixel-wide span.
        """
        x1 = int(max(0, min(width - 1, round(float(box["x1"])))))
        y1 = int(max(0, min(height - 1, round(float(box["y1"])))))
        x2 = int(max(x1 + 1, min(width, round(float(box["x2"])))))
        y2 = int(max(y1 + 1, min(height, round(float(box["y2"])))))
        return {"x1": x1, "y1": y1, "x2": x2, "y2": y2}

    def _encode_cropped_mask(self, mask: np.ndarray, box: Dict[str, int]) -> str:
        """Crop ``mask`` to ``box`` and return it as a base64 grayscale PNG.

        Foreground pixels become 255 and background pixels 0. ``x2``/``y2``
        are exclusive slice bounds. The return value is the ASCII base64 text
        of the PNG bytes.
        """
        cropped = mask[box["y1"] : box["y2"], box["x1"] : box["x2"]]
        # A 2-D uint8 array is mapped to Pillow mode "L" automatically. The
        # explicit ``mode=`` argument to fromarray is deprecated (Pillow 11.3).
        mask_image = Image.fromarray(cropped.astype(np.uint8) * 255)
        output = BytesIO()
        mask_image.save(output, format="PNG")
        return base64.b64encode(output.getvalue()).decode("ascii")
|
|