import sys
import time
import numpy as np
from PIL import Image
import torch
def log(msg: str):
    """Write *msg* to stdout and flush immediately so progress lines
    stream in real time even under buffered/captured output."""
    sys.stdout.write(f"{msg}\n")
    sys.stdout.flush()
def _make_log_tqdm():
    """tqdm subclass that routes per-file download progress to our log buffer.

    Returns the class itself (not an instance) so it can be handed to
    huggingface_hub's ``snapshot_download(tqdm_class=...)``.
    """
    try:
        from tqdm.auto import tqdm as _Base
    except ImportError:
        from tqdm import tqdm as _Base

    class _LogTqdm(_Base):
        # Files smaller than this finish too fast to be worth logging.
        _MIN_BYTES = 100_000

        def __init__(self, *args, **kwargs):
            # Must be set before super().__init__ in case tqdm triggers an
            # update during construction.
            self.__last_pct = -1
            # disable=True suppresses terminal rendering entirely.
            super().__init__(*args, disable=True, **kwargs)
            if self.total and self.total > self._MIN_BYTES:
                log(f"[SAM2] Downloading {self.desc or 'file'} ({self.total/1e6:.1f} MB) ...")

        def update(self, n=1):
            super().update(n)
            # BUG FIX: a disabled tqdm's update() returns before touching
            # self.n, so the original never advanced past 0 and the
            # percentage logs (and close()'s "done" line) never fired.
            # Track progress manually when disabled.
            if self.disable:
                self.n += n
            if not self.total or self.total <= self._MIN_BYTES:
                return
            pct = min(100, int(self.n / self.total * 100))
            # Log at most once per 10% step to keep the log readable.
            if pct >= self.__last_pct + 10:
                log(f"[SAM2] {self.desc}: {self.n/1e6:.0f}/{self.total/1e6:.0f} MB ({pct}%)")
                self.__last_pct = pct

        def close(self):
            super().close()
            # >= 99% rather than == 100% tolerates rounding in byte counts.
            if self.total and self.total > self._MIN_BYTES and self.n >= self.total * 0.99:
                log(f"[SAM2] {self.desc}: β done")

    return _LogTqdm
def load_sam2():
    """Fetch (if not cached) and deserialize the SAM2 checkpoint.

    Returns:
        ``(model, processor)`` — the model is left in eval mode.
    """
    from huggingface_hub import snapshot_download
    from transformers import Sam2Model, Sam2Processor

    model_id = "facebook/sam2-hiera-large"

    # Phase 1: download (instant if already cached; shows per-file progress if not)
    log("[SAM2] Checking model files in HF cache ...")
    start = time.time()
    snapshot_download(model_id, tqdm_class=_make_log_tqdm())
    log(f"[SAM2] Cache ready ({time.time()-start:.1f}s). Loading processor ...")

    # Phase 2: deserialize processor
    proc_start = time.time()
    processor = Sam2Processor.from_pretrained(model_id)
    log(f"[SAM2] Processor loaded ({time.time()-proc_start:.1f}s). Loading model weights ...")

    # Phase 3: deserialize model (~1-2 GB into GPU RAM β can take 30-60s)
    weights_start = time.time()
    model = Sam2Model.from_pretrained(model_id)
    model.eval()
    log(f"[SAM2] Model loaded ({time.time()-weights_start:.1f}s). Total init: {time.time()-start:.1f}s.")
    return model, processor
_sam2_cache = None
def get_sam2():
global _sam2_cache
if _sam2_cache is None:
log("[SAM2] Cold start β initializing model for the first time ...")
_sam2_cache = load_sam2()
else:
log("[SAM2] Using cached model.")
return _sam2_cache
# Each prompt: (click_x, click_y, bbox_x1, bbox_y1, bbox_x2, bbox_y2)
# All values normalized [0,1]. Bbox constrains SAM2 to look only within
# that region, which is far more reliable than a point alone for body parts.
# NOTE(review): these coordinates presumably assume an upright, roughly
# centered, front-facing subject — confirm against the calling pipeline.
DEFAULT_PROMPTS = {
    "breast_left": (0.40, 0.36, 0.28, 0.26, 0.50, 0.46),
    "breast_right": (0.60, 0.36, 0.50, 0.26, 0.72, 0.46),
    "buttocks": (0.50, 0.72, 0.30, 0.62, 0.70, 0.85),
    "ponytail": (0.50, 0.10, 0.35, 0.00, 0.65, 0.20),
    "hair": (0.50, 0.10, 0.30, 0.00, 0.70, 0.25),
}
# Regions routed to the MediaPipe-based `anatomy` module instead of SAM2
# (see segment_regions for the rationale).
ANATOMY_REGIONS = {"breast_left", "breast_right", "buttocks"}
def segment_regions(image: Image.Image, requested: list[str], click_points: dict | None = None) -> dict:
    """Segment the requested regions of *image*.

    Anatomy regions (ANATOMY_REGIONS) are delegated to the MediaPipe-based
    `anatomy` module; everything else is segmented with SAM2 using the
    point + bounding-box prompts in DEFAULT_PROMPTS.

    Args:
        image: source photo (PIL image; only .size is read here).
        requested: region names. Names not in DEFAULT_PROMPTS and not in
            ANATOMY_REGIONS are logged and skipped, not raised.
        click_points: optional {region: (x_px, y_px)} absolute-pixel click
            overrides; when absent, the normalized default click is used.

    Returns:
        {region: {"mask": nested bool list (H x W), "bbox": [x, y, w, h]}}.
    """
    log(f"[Segment] Requested: {requested} | image size: {image.size}")
    # Body-region masks come from MediaPipe pose + ellipse β not SAM2.
    # SAM2 segments by pixel similarity, which on clothed photos catches the
    # tank top / shirt color rather than the underlying anatomy.
    anatomy_requests = [r for r in requested if r in ANATOMY_REGIONS]
    sam_requests = [r for r in requested if r not in ANATOMY_REGIONS]
    results: dict = {}
    if anatomy_requests:
        from anatomy import segment_anatomy
        results.update(segment_anatomy(image, anatomy_requests))
    if not sam_requests:
        log(f"[Segment] All {len(results)} regions complete (anatomy only).")
        return results
    log(f"[SAM2] Falling back to SAM2 for: {sam_requests}")
    model, processor = get_sam2()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    log(f"[SAM2] Using device: {device}")
    model = model.to(device)
    W, H = image.size
    for i, region in enumerate(sam_requests):
        if region not in DEFAULT_PROMPTS:
            log(f"[SAM2] Skipping unknown region: {region}")
            continue
        # BUG FIX: denominator was len(requested), over-counting whenever
        # some regions were handled by the anatomy path above.
        log(f"[SAM2] Processing region {i+1}/{len(sam_requests)}: {region} ...")
        t = time.time()
        prompt = DEFAULT_PROMPTS[region]
        # Caller-supplied click (absolute pixels) wins over the default.
        if click_points and region in click_points:
            px, py = click_points[region]
        else:
            px, py = prompt[0] * W, prompt[1] * H
        # Pass both a click point AND a bounding box. The bbox constrains SAM2
        # to segment only inside that region, which is essential for parts of
        # a body where a click alone yields ambiguous results (subpart vs torso
        # vs whole subject).
        bx1, by1, bx2, by2 = prompt[2] * W, prompt[3] * H, prompt[4] * W, prompt[5] * H
        log(f"[SAM2] {region} click=({px:.0f},{py:.0f}) bbox=({bx1:.0f},{by1:.0f},{bx2:.0f},{by2:.0f})")
        # 4-level nesting: [image][object][point][xy]; boxes: [image][object][xyxy]
        inputs = processor(
            images=image,
            input_points=[[[[px, py]]]],
            input_boxes=[[[bx1, by1, bx2, by2]]],
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        masks = processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
        )[0]
        # SAM2 returns 3 masks (subpart / part / whole). argmax-ing IoU scores
        # often picks the "whole subject" mask, which is wrong for body-region
        # segmentation β we want the local part the click landed on.
        scores = outputs.iou_scores[0, 0].cpu().numpy()
        mtensor = masks[0].numpy()
        if mtensor.ndim == 4:
            mtensor = mtensor[0]
        # mtensor is now (num_masks, H, W)
        total_px = mtensor.shape[1] * mtensor.shape[2]
        areas = [int(np.sum(m > 0)) for m in mtensor]
        log(f"[SAM2] mask shape: {mtensor.shape}, areas: {areas}, scores: {scores.tolist()}")
        # Filter to masks with area between 0.5% and 40% of image, then pick
        # the one with the *highest* IoU score (model's own confidence) β not
        # the smallest, which often gave us a low-confidence sliver.
        # (Renamed the comprehension variable: it previously shadowed the
        # outer loop index `i`.)
        candidates = [
            idx for idx in range(len(mtensor))
            if 0.005 * total_px <= areas[idx] <= 0.40 * total_px
        ]
        if candidates:
            best = max(candidates, key=lambda idx: scores[idx])
            log(f"[SAM2] picked mask idx={best} (highest score within 0.5β40% range, score={scores[best]:.3f}, area={areas[best]})")
        else:
            best = int(np.argmax(scores))
            log(f"[SAM2] no mask in range β falling back to argmax idx={best}")
        mask = mtensor[best].astype(bool)
        # ROBUSTNESS FIX: an all-False mask made np.where(rows)[0][[0, -1]]
        # raise IndexError; report an empty bbox instead of crashing.
        if not mask.any():
            log(f"[SAM2] '{region}' selected mask is empty; returning empty bbox")
            results[region] = {"mask": mask.tolist(), "bbox": [0, 0, 0, 0]}
            continue
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]
        log(f"[SAM2] '{region}' done in {time.time()-t:.1f}s β bbox=[{cmin},{rmin},{cmax-cmin},{rmax-rmin}] score={scores[best]:.3f}")
        results[region] = {
            "mask": mask.tolist(),
            "bbox": [int(cmin), int(rmin), int(cmax - cmin), int(rmax - rmin)],
        }
    log(f"[Segment] All {len(results)} regions complete.")
    return results