import sys
import time
import numpy as np
from PIL import Image
import torch


def log(msg: str):
    print(msg, flush=True)


def _make_log_tqdm():
    """tqdm subclass that routes per-file download progress to our log buffer."""
    try:
        from tqdm.auto import tqdm as _Base
    except ImportError:
        from tqdm import tqdm as _Base

    class _LogTqdm(_Base):
        def __init__(self, *args, **kwargs):
            self.__last_pct = -1
            # disable=True suppresses terminal rendering; tracking still works
            super().__init__(*args, disable=True, **kwargs)
            if self.total and self.total > 100_000:
                log(f"[SAM2] Downloading {self.desc or 'file'} ({self.total/1e6:.1f} MB) ...")

        def update(self, n=1):
            super().update(n)
            if not self.total or self.total <= 100_000:
                return
            pct = min(100, int(self.n / self.total * 100))
            if pct >= self.__last_pct + 10:
                log(f"[SAM2]   {self.desc}: {self.n/1e6:.0f}/{self.total/1e6:.0f} MB ({pct}%)")
                self.__last_pct = pct

        def close(self):
            super().close()
            if self.total and self.total > 100_000 and self.n >= self.total * 0.99:
                log(f"[SAM2]   {self.desc}: βœ“ done")

    return _LogTqdm


def load_sam2():
    from huggingface_hub import snapshot_download
    from transformers import Sam2Model, Sam2Processor

    model_id = "facebook/sam2-hiera-large"

    # Phase 1: download (instant if already cached; shows per-file progress if not)
    log("[SAM2] Checking model files in HF cache ...")
    t0 = time.time()
    snapshot_download(model_id, tqdm_class=_make_log_tqdm())
    log(f"[SAM2] Cache ready ({time.time()-t0:.1f}s). Loading processor ...")

    # Phase 2: deserialize processor
    t1 = time.time()
    processor = Sam2Processor.from_pretrained(model_id)
    log(f"[SAM2] Processor loaded ({time.time()-t1:.1f}s). Loading model weights ...")

    # Phase 3: deserialize model weights (~1-2 GB into host RAM; the move to
    # GPU happens later, at inference time). Can take 30-60s.
    t2 = time.time()
    model = Sam2Model.from_pretrained(model_id)
    model.eval()
    log(f"[SAM2] Model loaded ({time.time()-t2:.1f}s). Total init: {time.time()-t0:.1f}s.")

    return model, processor


_sam2_cache = None


def get_sam2():
    global _sam2_cache
    if _sam2_cache is None:
        log("[SAM2] Cold start β€” initializing model for the first time ...")
        _sam2_cache = load_sam2()
    else:
        log("[SAM2] Using cached model.")
    return _sam2_cache
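

# Warm-up note: calling get_sam2() once at process startup (from whatever init
# hook the hosting framework provides; hypothetical, not shown here) moves the
# 30-60s cold-start cost out of the first user request:
#
#     model, processor = get_sam2()  # eager init at startup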


# Each prompt: (click_x, click_y, bbox_x1, bbox_y1, bbox_x2, bbox_y2)
# All values normalized [0,1]. Bbox constrains SAM2 to look only within
# that region, which is far more reliable than a point alone for body parts.
DEFAULT_PROMPTS = {
    "breast_left":  (0.40, 0.36, 0.28, 0.26, 0.50, 0.46),
    "breast_right": (0.60, 0.36, 0.50, 0.26, 0.72, 0.46),
    "buttocks":     (0.50, 0.72, 0.30, 0.62, 0.70, 0.85),
    "ponytail":     (0.50, 0.10, 0.35, 0.00, 0.65, 0.20),
    "hair":         (0.50, 0.10, 0.30, 0.00, 0.70, 0.25),
}
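
# For concreteness (assuming a hypothetical 1024x1536 portrait): the
# "buttocks" prompt above denormalizes to click=(512, 1106) and
# bbox=(307, 952, 717, 1306), which is exactly the arithmetic
# segment_regions performs below.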


ANATOMY_REGIONS = {"breast_left", "breast_right", "buttocks"}


def segment_regions(image: Image.Image, requested: list[str], click_points: dict | None = None) -> dict:
    log(f"[Segment] Requested: {requested} | image size: {image.size}")

    # Body-region masks come from MediaPipe pose + ellipse fitting, not SAM2.
    # SAM2 segments by pixel similarity, which on clothed photos catches the
    # tank top / shirt color rather than the underlying anatomy.
    anatomy_requests = [r for r in requested if r in ANATOMY_REGIONS]
    sam_requests = [r for r in requested if r not in ANATOMY_REGIONS]

    results: dict = {}
    if anatomy_requests:
        from anatomy import segment_anatomy
        results.update(segment_anatomy(image, anatomy_requests))

    if not sam_requests:
        log(f"[Segment] All {len(results)} regions complete (anatomy only).")
        return results

    log(f"[SAM2] Falling back to SAM2 for: {sam_requests}")
    model, processor = get_sam2()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    log(f"[SAM2] Using device: {device}")
    model = model.to(device)

    W, H = image.size

    for i, region in enumerate(sam_requests):
        if region not in DEFAULT_PROMPTS:
            log(f"[SAM2] Skipping unknown region: {region}")
            continue

        log(f"[SAM2] Processing region {i+1}/{len(requested)}: {region} ...")
        t = time.time()

        prompt = DEFAULT_PROMPTS[region]
        if click_points and region in click_points:
            # Caller-supplied click, already in pixel coordinates.
            px, py = click_points[region]
        else:
            px, py = prompt[0] * W, prompt[1] * H

        # Pass both a click point AND a bounding box. The bbox constrains SAM2
        # to segment only inside that region, which is essential for parts of
        # a body where a click alone yields ambiguous results (subpart vs torso
        # vs whole subject).
        bx1, by1, bx2, by2 = prompt[2] * W, prompt[3] * H, prompt[4] * W, prompt[5] * H
        log(f"[SAM2] {region} click=({px:.0f},{py:.0f}) bbox=({bx1:.0f},{by1:.0f},{bx2:.0f},{by2:.0f})")

        # 4-level nesting: [image][object][point][xy]; boxes: [image][object][xyxy]
        inputs = processor(
            images=image,
            input_points=[[[[px, py]]]],
            input_boxes=[[[bx1, by1, bx2, by2]]],
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        masks = processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
        )[0]

        # SAM2 returns 3 masks (subpart / part / whole). Blindly argmax-ing the
        # IoU scores often picks the "whole subject" mask, which is wrong for
        # body-region segmentation: we want the local part the click landed on.
        scores = outputs.iou_scores[0, 0].cpu().numpy()
        mtensor = masks[0].numpy()
        if mtensor.ndim == 4:
            mtensor = mtensor[0]
        # mtensor is now (num_masks, H, W)
        total_px = mtensor.shape[1] * mtensor.shape[2]
        areas = [int(np.sum(m > 0)) for m in mtensor]
        log(f"[SAM2] mask shape: {mtensor.shape}, areas: {areas}, scores: {scores.tolist()}")

        # Filter to masks with area between 0.5% and 40% of image, then pick
        # the one with the *highest* IoU score (model's own confidence) β€” not
        # the smallest, which often gave us a low-confidence sliver.
        candidates = [
            j for j in range(len(mtensor))
            if 0.005 * total_px <= areas[j] <= 0.40 * total_px
        ]
        if candidates:
            best = max(candidates, key=lambda j: scores[j])
            log(f"[SAM2] picked mask idx={best} (highest score within 0.5-40% range, score={scores[best]:.3f}, area={areas[best]})")
        else:
            best = int(np.argmax(scores))
            log(f"[SAM2] no mask in range β€” falling back to argmax idx={best}")

        mask = mtensor[best].astype(bool)

        if not mask.any():
            log(f"[SAM2] '{region}' produced an empty mask; skipping.")
            continue

        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]

        log(f"[SAM2] '{region}' done in {time.time()-t:.1f}s β€” bbox=[{cmin},{rmin},{cmax-cmin},{rmax-rmin}] score={scores[best]:.3f}")

        results[region] = {
            "mask": mask.tolist(),
            "bbox": [int(cmin), int(rmin), int(cmax - cmin), int(rmax - rmin)],
        }

    log(f"[Segment] All {len(results)} regions complete.")
    return results
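

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the service path. The image path
    # and region names are placeholders; run e.g.:
    #   python <this file> photo.jpg hair ponytail
    path = sys.argv[1] if len(sys.argv) > 1 else "test.jpg"
    regions = sys.argv[2:] or ["hair"]
    img = Image.open(path).convert("RGB")
    out = segment_regions(img, regions)
    for name, data in out.items():
        log(f"[Main] {name}: bbox={data['bbox']}")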