File size: 2,114 Bytes
b2c1b6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
SAM (Segment Anything Model) integration.

Provides click-to-segment: user clicks a point on the image/canvas,
SAM returns a precise mask of the object at that point.

Model: vit_b (375MB) — fast on M4 Pro via MPS.
Loaded lazily on first use, then cached at module level and reused across requests.
"""

from pathlib import Path

import cv2
import numpy as np
import torch

# The checkpoint path is built relative to this module's own location rather
# than the process's working directory.
_BASE_DIR = Path(__file__).parent
_CHECKPOINT = _BASE_DIR / "models" / "sam_vit_b.pth"

# Prefer MPS (Apple Silicon GPU), fall back to CPU.
# `torch.backends.mps` is absent on PyTorch builds that predate the MPS backend
# (added in 1.12); probe for it so importing this module on such a build picks
# CPU instead of raising AttributeError.
_HAS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
_DEVICE = "mps" if _HAS_MPS else "cpu"

_predictor = None  # lazy-loaded on first use


def _get_predictor():
    """
    Return the shared SamPredictor, constructing it on the first call.

    The first call builds the "vit_b" model from `_CHECKPOINT`, moves it to
    `_DEVICE`, wraps it in a SamPredictor and stores that in the module-level
    `_predictor`. Every later call returns the stored instance unchanged.
    """
    global _predictor

    if _predictor is None:
        # Imported here rather than at module top, so merely importing this
        # module does not require segment_anything to be importable.
        from segment_anything import sam_model_registry, SamPredictor

        build_vit_b = sam_model_registry["vit_b"]
        model = build_vit_b(checkpoint=str(_CHECKPOINT))
        model.to(_DEVICE)
        _predictor = SamPredictor(model)

    return _predictor


# True when the checkpoint is present as a regular file. `is_file()` is used
# instead of `exists()` so a directory that happens to sit at the checkpoint
# path is not reported as a usable model.
# NOTE(review): this only checks the weights file; it does not verify that the
# `segment_anything` package itself is importable.
SAM_AVAILABLE = _CHECKPOINT.is_file()


def segment_at_point(
    img: np.ndarray,
    x: float,
    y: float,
    canvas_w: int,
    canvas_h: int,
) -> np.ndarray:
    """
    Run SAM with a single click point and return a binary mask.

    Args:
        img:       BGR image (original resolution).
        x, y:     Click coordinates in canvas space.
        canvas_w/h: Canvas dimensions used to normalise the click. Must be > 0.

    Returns:
        uint8 mask (255 = region to inpaint, 0 = keep), same size as img.

    Raises:
        ValueError: If canvas_w or canvas_h is not positive.
    """
    # Validate before touching the model so bad input fails without loading
    # the checkpoint.
    if canvas_w <= 0 or canvas_h <= 0:
        raise ValueError(
            f"canvas dimensions must be positive, got {canvas_w}x{canvas_h}"
        )

    h, w = img.shape[:2]

    # Map canvas coords → image coords, then clamp into the image. A click on
    # the right/bottom canvas edge (x == canvas_w) would otherwise map to
    # index w / h, one past the last pixel.
    img_x = min(max(int(x / canvas_w * w), 0), w - 1)
    img_y = min(max(int(y / canvas_h * h), 0), h - 1)

    predictor = _get_predictor()

    # The predictor is given RGB; the input is BGR.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # NOTE(review): set_image runs on every call, even for repeated clicks on
    # the same image — presumably the expensive step; confirm before caching.
    predictor.set_image(img_rgb)

    point_coords = np.array([[img_x, img_y]])
    point_labels = np.array([1])  # 1 = foreground

    masks, scores, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )

    # Pick the mask with the highest confidence score
    best = masks[np.argmax(scores)]

    # Convert bool mask → uint8 (0/255)
    return best.astype(np.uint8) * 255