"""
SAM (Segment Anything Model) integration.

Provides click-to-segment: the user clicks a point on the image/canvas and SAM
returns a precise mask of the object at that point.

Model: vit_b (~375MB) — runs on the Apple Silicon GPU via MPS when available,
otherwise on CPU. The model is loaded lazily on the first segmentation call
and the same predictor is reused across requests.
"""

from pathlib import Path
import threading

import cv2
import numpy as np
import torch

_BASE_DIR = Path(__file__).parent
_CHECKPOINT = _BASE_DIR / "models" / "sam_vit_b.pth"

# Prefer MPS (Apple Silicon GPU), fall back to CPU
_DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"

_predictor = None  # lazy-loaded on first use

# Guards the one-time model load so concurrent first requests don't each
# build (and keep in memory) a separate copy of the model.
_load_lock = threading.Lock()

# The predictor is stateful: set_image() caches an image embedding that the
# following predict() call reads. Serialising the pair prevents two concurrent
# requests from interleaving and segmenting each other's image.
_infer_lock = threading.Lock()


def _get_predictor():
    """
    Return the shared SamPredictor, building it on the first call.

    Raises:
        FileNotFoundError: if the SAM checkpoint file does not exist.
    """
    global _predictor
    if _predictor is not None:
        return _predictor
    with _load_lock:
        # Re-check: another thread may have finished loading while we waited.
        if _predictor is None:
            if not _CHECKPOINT.exists():
                raise FileNotFoundError(f"SAM checkpoint not found: {_CHECKPOINT}")
            # Imported lazily so importing this module stays cheap and works
            # even when segment_anything is not installed (see SAM_AVAILABLE).
            from segment_anything import sam_model_registry, SamPredictor

            sam = sam_model_registry["vit_b"](checkpoint=str(_CHECKPOINT))
            sam.to(_DEVICE)
            _predictor = SamPredictor(sam)
    return _predictor


SAM_AVAILABLE = _CHECKPOINT.exists()


def _to_rgb(img: np.ndarray) -> np.ndarray:
    """Convert a grayscale, BGR or BGRA OpenCV image to 3-channel RGB."""
    if img.ndim == 2 or img.shape[2] == 1:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if img.shape[2] == 4:
        return cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def segment_at_point(
    img: np.ndarray,
    x: float,
    y: float,
    canvas_w: int,
    canvas_h: int,
) -> np.ndarray:
    """
    Run SAM with a single foreground click and return a binary mask.

    Args:
        img: OpenCV image at original resolution. BGR is expected; grayscale
            and BGRA images are converted as well.
        x, y: Click coordinates in canvas space.
        canvas_w, canvas_h: Canvas dimensions used to map the click onto the
            image. Must be positive.

    Returns:
        uint8 mask (255 = region to inpaint, 0 = keep), same height/width as img.

    Raises:
        ValueError: if canvas_w or canvas_h is not positive.
        FileNotFoundError: if the SAM checkpoint is missing.
    """
    if canvas_w <= 0 or canvas_h <= 0:
        raise ValueError(
            f"canvas dimensions must be positive, got {canvas_w}x{canvas_h}"
        )

    predictor = _get_predictor()
    h, w = img.shape[:2]

    # Map canvas coords → image pixel coords. Clamp into the image so a click
    # exactly on the right/bottom edge (x == canvas_w) or slightly outside the
    # canvas still refers to a valid pixel.
    img_x = min(max(int(x / canvas_w * w), 0), w - 1)
    img_y = min(max(int(y / canvas_h * h), 0), h - 1)

    img_rgb = _to_rgb(img)
    point_coords = np.array([[img_x, img_y]])
    point_labels = np.array([1])  # 1 = foreground

    with _infer_lock:
        predictor.set_image(img_rgb)
        masks, scores, _ = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=True,
        )

    # multimask_output=True returns several candidates; keep the highest-scoring one.
    best = masks[np.argmax(scores)]

    # Convert bool mask → uint8 (0/255)
    return best.astype(np.uint8) * 255