Spaces:
Paused
Paused
| """ | |
| SAM (Segment Anything Model) integration. | |
| Provides click-to-segment: user clicks a point on the image/canvas, | |
| SAM returns a precise mask of the object at that point. | |
| Model: vit_b (375MB) — fast on M4 Pro via MPS. | |
| Loaded once at module level and reused across requests. | |
| """ | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| import torch | |
# Module-level configuration, computed once at import time.
_BASE_DIR = Path(__file__).parent
# Expected checkpoint location: <module dir>/models/sam_vit_b.pth
_CHECKPOINT = _BASE_DIR / "models" / "sam_vit_b.pth"
# Prefer MPS (Apple Silicon GPU), fall back to CPU
_DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
_predictor = None  # lazy-loaded on first use by _get_predictor()
def _get_predictor():
    """Return the process-wide ``SamPredictor``, loading the model on first call.

    The vit_b checkpoint is loaded once and cached in the module-level
    ``_predictor``, so the expensive model load happens only on the first
    request. NOTE(review): not thread-safe — concurrent first calls could
    load the model twice; acceptable if requests are handled serially.

    Returns:
        The shared ``SamPredictor`` instance.

    Raises:
        FileNotFoundError: if the checkpoint file is missing (fail fast with
            a clear message instead of an opaque error inside torch.load).
    """
    global _predictor
    if _predictor is not None:
        return _predictor
    if not _CHECKPOINT.exists():
        raise FileNotFoundError(
            f"SAM checkpoint not found at {_CHECKPOINT}; "
            "download sam_vit_b.pth into the models/ directory"
        )
    # Imported lazily so this module can be imported (and SAM_AVAILABLE
    # consulted) even when segment_anything is not installed.
    from segment_anything import sam_model_registry, SamPredictor

    sam = sam_model_registry["vit_b"](checkpoint=str(_CHECKPOINT))
    sam.to(_DEVICE)
    _predictor = SamPredictor(sam)
    return _predictor
# Whether the checkpoint file is present on disk; callers can check this
# before offering click-to-segment. Evaluated once at import time, so it
# will not notice a checkpoint added after the module is imported.
SAM_AVAILABLE = _CHECKPOINT.exists()
def segment_at_point(
    img: np.ndarray,
    x: float,
    y: float,
    canvas_w: int,
    canvas_h: int,
) -> np.ndarray:
    """
    Run SAM with a single click point and return a binary mask.

    Args:
        img: BGR image (original resolution).
        x, y: Click coordinates in canvas space.
        canvas_w/h: Canvas dimensions used to normalise the click; must be
            positive.

    Returns:
        uint8 mask (255 = region to inpaint, 0 = keep), same size as img.

    Raises:
        ValueError: if canvas_w or canvas_h is not positive (a clear error
            instead of ZeroDivisionError from the coordinate mapping).
    """
    if canvas_w <= 0 or canvas_h <= 0:
        raise ValueError(
            f"canvas dimensions must be positive, got {canvas_w}x{canvas_h}"
        )
    predictor = _get_predictor()
    h, w = img.shape[:2]
    # Map canvas coords → image coords, clamped to valid pixel indices:
    # a click on the far right/bottom edge (x == canvas_w) would otherwise
    # map to img_x == w, one pixel outside the image.
    img_x = min(max(int(x / canvas_w * w), 0), w - 1)
    img_y = min(max(int(y / canvas_h * h), 0), h - 1)
    # SamPredictor expects RGB; OpenCV images are BGR.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    predictor.set_image(img_rgb)
    point_coords = np.array([[img_x, img_y]])
    point_labels = np.array([1])  # 1 = foreground
    # multimask_output=True returns 3 candidate masks with quality scores.
    masks, scores, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )
    # Pick the mask with the highest confidence score
    best = masks[np.argmax(scores)]
    # Convert bool mask → uint8 (0/255)
    return best.astype(np.uint8) * 255