koreyspace's picture
Fix SAM2 handler imports
3306ab7 verified
Raw
History Blame Contribute Delete
4.35 kB
import base64
import os
from contextlib import nullcontext
from io import BytesIO
from typing import Any, Dict
import numpy as np
import torch
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Default model id handed to SAM2ImagePredictor.from_pretrained by
# EndpointHandler; the SAM2_MODEL_ID environment variable overrides it.
MODEL_ID = "facebook/sam2.1-hiera-base-plus"
class EndpointHandler:
    """Callable handler that turns box prompts into SAM2 segmentation masks.

    A request carries one base64-encoded image plus a list of labelled boxes.
    For every box the handler returns a PNG mask cropped to that box, together
    with the predictor's score for the mask.
    """

    def __init__(self, path: str = ""):
        """Pick a device and load the SAM2 image predictor.

        Args:
            path: Not read by this handler. The model is selected by the
                ``SAM2_MODEL_ID`` environment variable, falling back to the
                module-level ``MODEL_ID``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = os.environ.get("SAM2_MODEL_ID", MODEL_ID)
        self.predictor = SAM2ImagePredictor.from_pretrained(
            self.model_id,
            device=self.device,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Segment the requested boxes of one image.

        Args:
            data: Request payload. The fields are read from ``data["inputs"]``
                when that key exists, otherwise from ``data`` itself:

                * ``image_base64`` (required): base64 image, optionally
                  prefixed with a data-URL header.
                * ``boxes`` (optional): list of
                  ``{"id": ..., "box": {"x1", "y1", "x2", "y2"}}`` in pixel
                  coordinates.
                * ``mime_type`` (optional, default ``"image/png"``): echoed
                  back unchanged on every mask entry.

        Returns:
            ``{"masks": [...]}`` with one entry per input box, in input order.
            Each entry holds ``id``, ``score``, ``mask_png_base64``, the
            normalized ``box`` and ``mime_type``. The list is empty when no
            boxes were supplied.

        Raises:
            KeyError: If ``image_base64`` or any box field is missing.
            ValueError: If the predictor returns masks or scores with an
                unexpected number of dimensions.
        """
        inputs = data.get("inputs", data)
        image = self._decode_image(inputs["image_base64"]).convert("RGB")
        # NOTE: this is the caller-supplied value passed through as-is; the
        # mask payload itself is always encoded as PNG.
        mime_type = inputs.get("mime_type", "image/png")
        boxes = inputs.get("boxes", [])
        if not boxes:
            return {"masks": []}

        width, height = image.size
        normalized_boxes = [
            {
                "id": str(box["id"]),
                "box": self._normalize_box(box["box"], width, height),
            }
            for box in boxes
        ]
        # One row per box, in x1, y1, x2, y2 order.
        input_boxes = np.array(
            [
                [item["box"][key] for key in ("x1", "y1", "x2", "y2")]
                for item in normalized_boxes
            ],
            dtype=np.float32,
        )

        image_array = np.array(image)
        with torch.inference_mode(), self._autocast_context():
            self.predictor.set_image(image_array)
            masks, scores, _ = self.predictor.predict(
                box=input_boxes,
                multimask_output=False,
            )
        masks = np.asarray(masks)
        scores = np.asarray(scores)

        response_masks = []
        for index, item in enumerate(normalized_boxes):
            mask_bool = self._select_mask(masks, index)
            score = self._select_score(scores, index)
            response_masks.append(
                {
                    "id": item["id"],
                    "score": score,
                    "mask_png_base64": self._encode_cropped_mask(
                        mask_bool, item["box"]
                    ),
                    "box": item["box"],
                    "mime_type": mime_type,
                }
            )
        return {"masks": response_masks}

    def _autocast_context(self):
        """Return bfloat16 CUDA autocast on GPU, or a no-op context on CPU."""
        if self.device == "cuda":
            return torch.autocast("cuda", dtype=torch.bfloat16)
        return nullcontext()

    def _select_mask(self, masks: np.ndarray, index: int) -> np.ndarray:
        """Extract the boolean mask for box ``index`` from the predictor output.

        The output rank is handled defensively: 4-D arrays are indexed as
        ``[index, 0]``, 3-D arrays as ``[index]``, and a 2-D array is treated
        as the single mask regardless of ``index``.
        NOTE(review): presumably SAM2 yields 4-D output for several boxes and
        3-D output for a single box -- confirm against the installed version.

        Raises:
            ValueError: If ``masks`` has any other number of dimensions.
        """
        if masks.ndim == 4:
            return masks[index, 0] > 0
        if masks.ndim == 3:
            return masks[index] > 0
        if masks.ndim == 2:
            return masks > 0
        raise ValueError(f"Unexpected mask tensor shape: {masks.shape}")

    def _select_score(self, scores: np.ndarray, index: int) -> float:
        """Extract the score for box ``index`` as a plain Python float.

        Mirrors ``_select_mask``: 2-D arrays are indexed as ``[index, 0]``,
        1-D arrays as ``[index]``, and a 0-D array is the score itself.

        Raises:
            ValueError: If ``scores`` has any other number of dimensions.
        """
        if scores.ndim == 2:
            return float(scores[index, 0])
        if scores.ndim == 1:
            return float(scores[index])
        if scores.ndim == 0:
            return float(scores)
        raise ValueError(f"Unexpected score tensor shape: {scores.shape}")

    def _decode_image(self, image_base64: str) -> "Image.Image":
        """Decode a base64 string (optionally a data URL) into a PIL image."""
        # A comma can only come from a "data:...;base64," header, since the
        # base64 alphabet itself contains none; keep what follows it.
        if "," in image_base64:
            image_base64 = image_base64.split(",", 1)[1]
        return Image.open(BytesIO(base64.b64decode(image_base64)))

    def _normalize_box(
        self,
        box: Dict[str, float],
        width: int,
        height: int,
    ) -> Dict[str, int]:
        """Round a box to integer pixels and clamp it inside the image.

        Corners supplied in the wrong order (``x2 < x1`` or ``y2 < y1``) are
        swapped first, so the same rectangle is described either way. The
        result always spans at least one pixel on each axis.

        Args:
            box: Mapping with ``x1``, ``y1``, ``x2``, ``y2`` pixel coordinates.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            Mapping with integer ``x1 < x2 <= width`` and ``y1 < y2 <= height``.
        """
        left, right = sorted((float(box["x1"]), float(box["x2"])))
        top, bottom = sorted((float(box["y1"]), float(box["y2"])))
        # The origin corner stops one pixel short of the edge so the far
        # corner can always sit at least one pixel beyond it.
        x1 = int(max(0, min(width - 1, round(left))))
        y1 = int(max(0, min(height - 1, round(top))))
        x2 = int(max(x1 + 1, min(width, round(right))))
        y2 = int(max(y1 + 1, min(height, round(bottom))))
        return {"x1": x1, "y1": y1, "x2": x2, "y2": y2}

    def _encode_cropped_mask(self, mask: np.ndarray, box: Dict[str, int]) -> str:
        """Crop ``mask`` to ``box`` and return it as a base64-encoded PNG.

        Mask pixels that are set become 255 and all others 0.
        """
        cropped = mask[box["y1"]:box["y2"], box["x1"]:box["x2"]]
        pixels = cropped.astype(np.uint8) * 255
        # A 2-D uint8 array is read as 8-bit grayscale ("L") automatically,
        # so no explicit mode argument is needed.
        mask_image = Image.fromarray(pixels)
        output = BytesIO()
        mask_image.save(output, format="PNG")
        return base64.b64encode(output.getvalue()).decode("ascii")