"""Box-prompted image segmentation handler built on ``SAM2ImagePredictor``."""

import base64
import os
from contextlib import nullcontext
from io import BytesIO
from typing import Any, Dict

import numpy as np
import torch
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Checkpoint identifier used when the SAM2_MODEL_ID environment variable is
# not set.
MODEL_ID = "facebook/sam2.1-hiera-base-plus"


class EndpointHandler:
    """Callable that turns box prompts on one image into cropped PNG masks.

    NOTE(review): the class name and ``__init__(path)`` / ``__call__(data)``
    shape look like the Hugging Face Inference Endpoints custom-handler
    convention -- confirm against the deployment target.
    """

    def __init__(self, path: str = ""):
        """Pick a device and load the predictor.

        Args:
            path: Accepted for interface compatibility; this implementation
                does not read it.
        """
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        # The environment variable takes precedence over the module default.
        self.model_id = os.environ.get("SAM2_MODEL_ID", MODEL_ID)
        self.predictor = SAM2ImagePredictor.from_pretrained(
            self.model_id,
            device=self.device,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Segment every requested box and return one mask entry per box.

        Args:
            data: Request payload. The fields are read from ``data["inputs"]``
                when that key exists, otherwise from ``data`` itself:

                * ``image_base64`` (required): base64 image, optionally with
                  a ``<prefix>,`` header in front of the payload.
                * ``mime_type`` (optional, default ``"image/png"``): echoed
                  back unchanged on every mask entry.
                * ``boxes`` (optional, default ``[]``): items with an ``id``
                  and a ``box`` mapping holding ``x1``, ``y1``, ``x2``,
                  ``y2``.

        Returns:
            ``{"masks": [...]}``. The list is empty when no boxes were given;
            otherwise it holds one dict per box with the keys ``id``,
            ``score``, ``mask_png_base64``, ``box`` and ``mime_type``.

        Raises:
            KeyError: If ``image_base64`` or a required box field is missing.
            ValueError: If the predictor output has an unexpected rank.
        """
        inputs = data.get("inputs", data)
        image = self._decode_image(inputs["image_base64"]).convert("RGB")
        mime_type = inputs.get("mime_type", "image/png")
        boxes = inputs.get("boxes", [])
        if not boxes:
            # The image has already been decoded at this point, so a bad
            # image still fails even when there is nothing to segment.
            return {"masks": []}

        width, height = image.size

        # Clamp every prompt to the image and stringify its identifier.
        prompts = []
        for raw in boxes:
            prompt_id = str(raw["id"])
            prompts.append(
                {
                    "id": prompt_id,
                    "box": self._normalize_box(raw["box"], width, height),
                }
            )

        # One row per prompt in x1, y1, x2, y2 order.
        corner_keys = ("x1", "y1", "x2", "y2")
        input_boxes = np.array(
            [[prompt["box"][key] for key in corner_keys] for prompt in prompts],
            dtype=np.float32,
        )

        pixels = np.array(image)
        with torch.inference_mode():
            with self._autocast_context():
                self.predictor.set_image(pixels)
                masks, scores, _ = self.predictor.predict(
                    box=input_boxes,
                    multimask_output=False,
                )

        masks = np.asarray(masks)
        scores = np.asarray(scores)

        return {
            "masks": [
                self._mask_entry(masks, scores, position, prompt, mime_type)
                for position, prompt in enumerate(prompts)
            ]
        }

    def _mask_entry(
        self,
        masks: np.ndarray,
        scores: np.ndarray,
        index: int,
        prompt: Dict[str, Any],
        mime_type: str,
    ) -> Dict[str, Any]:
        """Assemble the response dict for the prompt at position ``index``."""
        selected = self._select_mask(masks, index)
        confidence = self._select_score(scores, index)
        return {
            "id": prompt["id"],
            "score": confidence,
            "mask_png_base64": self._encode_cropped_mask(selected, prompt["box"]),
            "box": prompt["box"],
            "mime_type": mime_type,
        }

    def _autocast_context(self):
        """Return bfloat16 CUDA autocast on GPU, or a no-op context on CPU."""
        return (
            torch.autocast("cuda", dtype=torch.bfloat16)
            if self.device == "cuda"
            else nullcontext()
        )

    def _select_mask(self, masks: np.ndarray, index: int) -> np.ndarray:
        """Pick the 2-D mask for prompt ``index`` as a boolean array.

        Accepts 4-D input (takes ``[index, 0]``), 3-D input (takes
        ``[index]``) or a single 2-D mask (used as is, ``index`` ignored).
        NOTE(review): presumably the predictor returns 4-D output for several
        boxes and 3-D for a single box -- confirm against the SAM 2 API.

        Raises:
            ValueError: For any other number of dimensions.
        """
        rank = masks.ndim
        if rank == 4:
            plane = masks[index, 0]
        elif rank == 3:
            plane = masks[index]
        elif rank == 2:
            plane = masks
        else:
            raise ValueError(f"Unexpected mask tensor shape: {masks.shape}")
        return plane > 0

    def _select_score(self, scores: np.ndarray, index: int) -> float:
        """Pick the score for prompt ``index`` as a Python float.

        Accepts 2-D input (takes ``[index, 0]``), 1-D input (takes
        ``[index]``) or a 0-D scalar (used as is, ``index`` ignored).

        Raises:
            ValueError: For any other number of dimensions.
        """
        rank = scores.ndim
        if rank == 2:
            value = scores[index, 0]
        elif rank == 1:
            value = scores[index]
        elif rank == 0:
            value = scores
        else:
            raise ValueError(f"Unexpected score tensor shape: {scores.shape}")
        return float(value)

    def _decode_image(self, image_base64: str) -> Image.Image:
        """Open a PIL image from base64 text.

        Everything up to and including the first comma is discarded when a
        comma is present, which strips a ``<prefix>,`` header.
        """
        payload = (
            image_base64.split(",", 1)[1] if "," in image_base64 else image_base64
        )
        raw_bytes = base64.b64decode(payload)
        return Image.open(BytesIO(raw_bytes))

    def _normalize_box(
        self,
        box: Dict[str, float],
        width: int,
        height: int,
    ) -> Dict[str, int]:
        """Round a box to integer pixels and clamp it inside the image.

        ``x1``/``y1`` are limited to ``[0, size - 1]``. ``x2``/``y2`` are
        capped at ``size`` but never drop below ``x1 + 1``/``y1 + 1``, so the
        result always spans at least one pixel in each direction.

        Args:
            box: Mapping with ``x1``, ``y1``, ``x2`` and ``y2``.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            Dict with integer ``x1``, ``y1``, ``x2``, ``y2``.
        """

        def snap(low, high, value):
            # The lower bound is applied last, so it wins when low > high.
            return int(max(low, min(high, round(float(value)))))

        x1 = snap(0, width - 1, box["x1"])
        y1 = snap(0, height - 1, box["y1"])
        x2 = snap(x1 + 1, width, box["x2"])
        y2 = snap(y1 + 1, height, box["y2"])
        return {"x1": x1, "y1": y1, "x2": x2, "y2": y2}

    def _encode_cropped_mask(self, mask: np.ndarray, box: Dict[str, int]) -> str:
        """Crop ``mask`` to ``box`` and return it as a base64 PNG string.

        The crop covers rows ``y1:y2`` and columns ``x1:x2``. It is written
        as an 8-bit grayscale image in which true pixels are 255 and false
        pixels are 0.
        """
        rows = slice(box["y1"], box["y2"])
        cols = slice(box["x1"], box["x2"])
        window = mask[rows, cols]
        grayscale = window.astype(np.uint8) * 255
        mask_image = Image.fromarray(grayscale, mode="L")
        with BytesIO() as buffer:
            mask_image.save(buffer, format="PNG")
            png_bytes = buffer.getvalue()
        return base64.b64encode(png_bytes).decode("ascii")