import base64
import contextlib
from io import BytesIO

import numpy as np
import torch
from PIL import Image
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor


class EndpointHandler:
    """Callable request handler that runs SAM 2 segmentation on one image.

    A request is a dict of the form::

        {"inputs": {"image": "<base64 image>", "config": {"mode": ..., ...}}}

    ``config["mode"]`` selects ``"automatic"`` (default; whole-image mask
    generation) or ``"point"`` (masks prompted by click coordinates).
    Request-level problems are reported as ``{"error": "<message>"}``
    instead of being raised.
    """

    def __init__(self, path=""):
        """Select the device/dtype and build the model and point predictor.

        Args:
            path: Accepted for interface compatibility; not used here.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        from sam2.build_sam import build_sam2_hf

        self.sam2_model = build_sam2_hf(
            "facebook/sam2.1-hiera-base-plus",
            device=self.device,
            apply_postprocessing=False,
        )
        self.predictor = SAM2ImagePredictor(self.sam2_model)

    @contextlib.contextmanager
    def _inference_context(self):
        """Enter inference mode, plus autocast when running on CUDA.

        On CPU ``self.dtype`` is float32, which is not a valid CPU autocast
        target: PyTorch would warn on every request and disable autocast
        anyway, so the autocast context is simply not created there.
        """
        if self.device == "cuda":
            autocast = torch.autocast(self.device, dtype=self.dtype)
        else:
            autocast = contextlib.nullcontext()
        with torch.inference_mode(), autocast:
            yield

    @staticmethod
    def _decode_image(image_b64) -> np.ndarray:
        """Decode a base64-encoded image into an RGB numpy array.

        Raises:
            TypeError, ValueError: the payload is not decodable base64.
            OSError: the decoded bytes cannot be read as an image.
        """
        raw = base64.b64decode(image_b64)
        return np.array(Image.open(BytesIO(raw)).convert("RGB"))

    def _run_automatic(self, image_np: np.ndarray, cfg: dict) -> dict:
        """Generate masks for the whole image with the automatic generator.

        Args:
            image_np: Image array handed to the mask generator.
            cfg: Request config; the generator options read from it (with
                their defaults) are listed in the constructor call below.

        Returns:
            ``{"masks": [...], "count": n}`` where every mask entry carries
            ``segmentation``, ``area``, ``bbox``, ``predicted_iou`` and
            ``stability_score``.
        """
        # The generator is rebuilt per request because its options come
        # from the request config.
        mask_generator = SAM2AutomaticMaskGenerator(
            model=self.sam2_model,
            points_per_side=cfg.get("points_per_side", 16),
            pred_iou_thresh=cfg.get("pred_iou_thresh", 0.7),
            stability_score_thresh=cfg.get("stability_score_thresh", 0.9),
            min_mask_region_area=cfg.get("min_mask_region_area", 200),
            crop_n_layers=cfg.get("crop_n_layers", 0),
            crop_n_points_downscale_factor=cfg.get(
                "crop_n_points_downscale_factor", 1
            ),
        )

        with self._inference_context():
            masks = mask_generator.generate(image_np)

        output = [
            {
                "segmentation": m["segmentation"].tolist(),  # bool mask [H, W]
                "area": m["area"],
                "bbox": m["bbox"],  # [x, y, w, h]
                "predicted_iou": m["predicted_iou"],
                "stability_score": m["stability_score"],
            }
            for m in masks
        ]
        return {"masks": output, "count": len(output)}

    def _run_point(self, image_np: np.ndarray, cfg: dict) -> dict:
        """Predict masks for point prompts, ordered by descending score.

        Args:
            image_np: Image array handed to the predictor.
            cfg: Request config. ``points`` is required: a non-empty list of
                ``[x, y]`` pairs. ``labels`` is optional, one integer per
                point, defaulting to 1 for every point. ``multimask_output``
                is optional and defaults to True.

        Returns:
            ``{"masks": [...], "count": n}`` with ``segmentation`` and
            ``score`` per mask, or ``{"error": ...}`` when the prompts are
            missing or malformed.
        """
        points = cfg.get("points")
        if not points:
            return {
                "error": "mode 'point' requires 'points': [[x, y], ...] in config"
            }

        # Validate prompt shapes here so that malformed requests produce an
        # error dict rather than an exception from inside the predictor.
        try:
            point_coords = np.asarray(points, dtype=np.float32)
        except (TypeError, ValueError):
            point_coords = None
        if (
            point_coords is None
            or point_coords.ndim != 2
            or point_coords.shape[1] != 2
        ):
            return {"error": "'points' must be a list of [x, y] pairs"}

        labels = cfg.get("labels")
        if labels is None:
            # Missing or null labels: treat every point as label 1.
            point_labels = np.ones(len(point_coords), dtype=int)
        else:
            try:
                point_labels = np.asarray(labels, dtype=int)
            except (TypeError, ValueError):
                point_labels = None
            if point_labels is None or point_labels.shape != (len(point_coords),):
                return {
                    "error": "'labels' must be a list of integers with one "
                    "entry per point"
                }

        multimask_output = cfg.get("multimask_output", True)

        with self._inference_context():
            self.predictor.set_image(image_np)
            masks, scores, _ = self.predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=multimask_output,
            )

        # Highest-scoring mask first.
        order = np.argsort(scores)[::-1]
        output = [
            {
                "segmentation": masks[i].astype(bool).tolist(),
                "score": float(scores[i]),
            }
            for i in order
        ]
        return {"masks": output, "count": len(output)}

    def __call__(self, data: dict) -> dict:
        """Handle one request: decode the image and dispatch on the mode.

        Args:
            data: Request payload; see the class docstring for its layout.

        Returns:
            The result dict of the selected mode, or ``{"error": ...}`` when
            the payload is malformed, the image cannot be decoded, or the
            mode is unknown.
        """
        inputs_data = data.get("inputs", {})
        if not isinstance(inputs_data, dict):
            return {"error": "'inputs' must be an object containing 'image'"}

        image_b64 = inputs_data.get("image")
        if image_b64 is None:
            return {"error": "missing 'image' in inputs"}

        try:
            image_np = self._decode_image(image_b64)
        except (TypeError, ValueError, OSError) as exc:
            # Bad base64 or bytes that are not a readable image.
            return {"error": f"could not decode 'image': {exc}"}

        # A null config is treated the same as a missing one.
        cfg = inputs_data.get("config") or {}
        if not isinstance(cfg, dict):
            return {"error": "'config' must be an object"}

        mode = cfg.get("mode", "automatic")
        if mode == "point":
            return self._run_point(image_np, cfg)
        elif mode == "automatic":
            return self._run_automatic(image_np, cfg)
        else:
            return {
                "error": f"unknown mode '{mode}', expected 'automatic' or 'point'"
            }