# watermark-remover / sam_segment.py
# Initial commit: AI watermark remover (commit b2c1b6b)
"""
SAM (Segment Anything Model) integration.
Provides click-to-segment: user clicks a point on the image/canvas,
SAM returns a precise mask of the object at that point.
Model: vit_b (375MB) — fast on M4 Pro via MPS.
Loaded once at module level and reused across requests.
"""
from pathlib import Path
import cv2
import numpy as np
import torch
# Checkpoint is expected alongside this file under models/.
_BASE_DIR = Path(__file__).parent
_CHECKPOINT = _BASE_DIR / "models" / "sam_vit_b.pth"
# Prefer MPS (Apple Silicon GPU), fall back to CPU
_DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
# Module-level singleton so the model is loaded once per process.
_predictor = None  # lazy-loaded on first use
def _get_predictor():
    """Return the shared SamPredictor, constructing it on first use.

    The vit_b checkpoint is loaded exactly once per process and cached in
    the module-level ``_predictor`` singleton; subsequent calls reuse it.
    """
    global _predictor
    if _predictor is None:
        # Imported lazily so the module can be loaded (and SAM_AVAILABLE
        # inspected) without segment_anything installed.
        from segment_anything import SamPredictor, sam_model_registry
        model = sam_model_registry["vit_b"](checkpoint=str(_CHECKPOINT))
        model.to(_DEVICE)
        _predictor = SamPredictor(model)
    return _predictor
# True when the checkpoint file exists on disk; evaluated once at import
# time, so a checkpoint downloaded later is not picked up until restart.
SAM_AVAILABLE = _CHECKPOINT.exists()
def segment_at_point(
    img: np.ndarray,
    x: float,
    y: float,
    canvas_w: int,
    canvas_h: int,
) -> np.ndarray:
    """
    Run SAM with a single click point and return a binary mask.

    Args:
        img: BGR image (original resolution).
        x, y: Click coordinates in canvas space.
        canvas_w/h: Canvas dimensions used to normalise the click.

    Returns:
        uint8 mask (255 = region to inpaint, 0 = keep), same size as img.

    Raises:
        ValueError: if canvas_w or canvas_h is not positive.
    """
    # Fail clearly instead of with an opaque ZeroDivisionError below.
    if canvas_w <= 0 or canvas_h <= 0:
        raise ValueError(
            f"Canvas dimensions must be positive, got {canvas_w}x{canvas_h}"
        )
    predictor = _get_predictor()
    h, w = img.shape[:2]
    # Map canvas coords → image coords, clamped to the image so a click on
    # the far edge of the canvas (x == canvas_w) cannot map to column w,
    # which lies outside the image.
    img_x = min(max(int(x / canvas_w * w), 0), w - 1)
    img_y = min(max(int(y / canvas_h * h), 0), h - 1)
    # SAM expects RGB input; the app works in BGR (OpenCV convention).
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    predictor.set_image(img_rgb)
    point_coords = np.array([[img_x, img_y]])
    point_labels = np.array([1])  # 1 = foreground click
    masks, scores, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )
    # SAM returns several candidate masks; keep the highest-scoring one.
    best = masks[np.argmax(scores)]
    # Convert bool mask → uint8 (0/255) for the downstream inpainting step.
    return best.astype(np.uint8) * 255