# Author: Oleksandr Ternovyi
# Commit 0a2b18e: "Read config from inside inputs, not top-level request body"
import torch
import numpy as np
from PIL import Image
import base64
from io import BytesIO
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor
class EndpointHandler:
def __init__(self, path=""):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
from sam2.build_sam import build_sam2_hf
self.sam2_model = build_sam2_hf(
"facebook/sam2.1-hiera-base-plus",
device=self.device,
apply_postprocessing=False
)
self.predictor = SAM2ImagePredictor(self.sam2_model)
def _run_automatic(self, image_np: np.ndarray, cfg: dict) -> dict:
mask_generator = SAM2AutomaticMaskGenerator(
model=self.sam2_model,
points_per_side=cfg.get("points_per_side", 16),
pred_iou_thresh=cfg.get("pred_iou_thresh", 0.7),
stability_score_thresh=cfg.get("stability_score_thresh", 0.9),
min_mask_region_area=cfg.get("min_mask_region_area", 200),
crop_n_layers=cfg.get("crop_n_layers", 0),
crop_n_points_downscale_factor=cfg.get("crop_n_points_downscale_factor", 1),
)
with torch.inference_mode(), torch.autocast(self.device, dtype=self.dtype):
masks = mask_generator.generate(image_np)
output = [
{
"segmentation": m["segmentation"].tolist(), # bool mask [H, W]
"area": m["area"],
"bbox": m["bbox"], # [x, y, w, h]
"predicted_iou": m["predicted_iou"],
"stability_score": m["stability_score"],
}
for m in masks
]
return {"masks": output, "count": len(output)}
def _run_point(self, image_np: np.ndarray, cfg: dict) -> dict:
points = cfg.get("points")
if not points:
return {"error": "mode 'point' requires 'points': [[x, y], ...] in config"}
labels = cfg.get("labels", [1] * len(points))
multimask_output = cfg.get("multimask_output", True)
with torch.inference_mode(), torch.autocast(self.device, dtype=self.dtype):
self.predictor.set_image(image_np)
masks, scores, _ = self.predictor.predict(
point_coords=np.array(points),
point_labels=np.array(labels),
multimask_output=multimask_output,
)
order = np.argsort(scores)[::-1]
output = [
{
"segmentation": masks[i].astype(bool).tolist(),
"score": float(scores[i]),
}
for i in order
]
return {"masks": output, "count": len(output)}
def __call__(self, data: dict) -> dict:
inputs_data = data.get("inputs", {})
image_b64 = inputs_data.get("image")
if image_b64 is None:
return {"error": "missing 'image' in inputs"}
image_np = np.array(
Image.open(BytesIO(base64.b64decode(image_b64))).convert("RGB")
)
cfg = inputs_data.get("config", {})
mode = cfg.get("mode", "automatic")
if mode == "point":
return self._run_point(image_np, cfg)
elif mode == "automatic":
return self._run_automatic(image_np, cfg)
else:
return {"error": f"unknown mode '{mode}', expected 'automatic' or 'point'"}