import sys
import time

import numpy as np
from PIL import Image
import torch


def log(msg: str):
    print(msg, flush=True)


def _make_log_tqdm():
    """tqdm subclass that routes per-file download progress to our log buffer."""
    try:
        from tqdm.auto import tqdm as _Base
    except ImportError:
        from tqdm import tqdm as _Base

    class _LogTqdm(_Base):
        def __init__(self, *args, **kwargs):
            self.__last_pct = -1
            # disable=True suppresses terminal rendering; tracking still works
            super().__init__(*args, disable=True, **kwargs)
            if self.total and self.total > 100_000:
                log(f"[SAM2] Downloading {self.desc or 'file'} ({self.total/1e6:.1f} MB) ...")

        def update(self, n=1):
            super().update(n)
            if not self.total or self.total <= 100_000:
                return
            pct = min(100, int(self.n / self.total * 100))
            if pct >= self.__last_pct + 10:
                log(f"[SAM2] {self.desc}: {self.n/1e6:.0f}/{self.total/1e6:.0f} MB ({pct}%)")
                self.__last_pct = pct

        def close(self):
            super().close()
            if self.total and self.total > 100_000 and self.n >= self.total * 0.99:
                log(f"[SAM2] {self.desc}: ✓ done")

    return _LogTqdm


def load_sam2():
    from huggingface_hub import snapshot_download
    from transformers import Sam2Model, Sam2Processor

    model_id = "facebook/sam2-hiera-large"

    # Phase 1: download (instant if already cached; shows per-file progress if not)
    log("[SAM2] Checking model files in HF cache ...")
    t0 = time.time()
    snapshot_download(model_id, tqdm_class=_make_log_tqdm())
    log(f"[SAM2] Cache ready ({time.time()-t0:.1f}s). Loading processor ...")

    # Phase 2: deserialize processor
    t1 = time.time()
    processor = Sam2Processor.from_pretrained(model_id)
    log(f"[SAM2] Processor loaded ({time.time()-t1:.1f}s). Loading model weights ...")

    # Phase 3: deserialize model (~1-2 GB into GPU RAM — can take 30-60s)
    t2 = time.time()
    model = Sam2Model.from_pretrained(model_id)
    model.eval()
    log(f"[SAM2] Model loaded ({time.time()-t2:.1f}s). Total init: {time.time()-t0:.1f}s.")

    return model, processor


_sam2_cache = None


def get_sam2():
    global _sam2_cache
    if _sam2_cache is None:
        log("[SAM2] Cold start — initializing model for the first time ...")
        _sam2_cache = load_sam2()
    else:
        log("[SAM2] Using cached model.")
    return _sam2_cache


# Each prompt: (click_x, click_y, bbox_x1, bbox_y1, bbox_x2, bbox_y2)
# All values normalized [0,1]. Bbox constrains SAM2 to look only within
# that region, which is far more reliable than a point alone for body parts.
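# Worked example (illustrative image size, not from the code): on a 1000x1500 px
# image, "breast_left" below denormalizes to click = (0.40*1000, 0.36*1500)
# = (400, 540) and bbox = (0.28*1000, 0.26*1500, 0.50*1000, 0.46*1500)
# = (280, 390, 500, 690); segment_regions() applies this same scaling per region.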
DEFAULT_PROMPTS = {
    "breast_left": (0.40, 0.36, 0.28, 0.26, 0.50, 0.46),
    "breast_right": (0.60, 0.36, 0.50, 0.26, 0.72, 0.46),
    "buttocks": (0.50, 0.72, 0.30, 0.62, 0.70, 0.85),
    "ponytail": (0.50, 0.10, 0.35, 0.00, 0.65, 0.20),
    "hair": (0.50, 0.10, 0.30, 0.00, 0.70, 0.25),
}

ANATOMY_REGIONS = {"breast_left", "breast_right", "buttocks"}


def segment_regions(image: Image.Image, requested: list[str], click_points: dict | None = None) -> dict:
    """Segment each requested region.

    Returns {region: {"mask": nested bool list (H x W), "bbox": [x, y, w, h]}}.
    click_points may override a region's default click with pixel coords (x, y).
    """
    log(f"[Segment] Requested: {requested} | image size: {image.size}")

    # Body-region masks come from MediaPipe pose + ellipse — not SAM2.
    # SAM2 segments by pixel similarity, which on clothed photos catches the
    # tank top / shirt color rather than the underlying anatomy.
    anatomy_requests = [r for r in requested if r in ANATOMY_REGIONS]
    sam_requests = [r for r in requested if r not in ANATOMY_REGIONS]
    results: dict = {}

    if anatomy_requests:
        from anatomy import segment_anatomy
        results.update(segment_anatomy(image, anatomy_requests))

    if not sam_requests:
        log(f"[Segment] All {len(results)} regions complete (anatomy only).")
        return results

    log(f"[SAM2] Falling back to SAM2 for: {sam_requests}")
    model, processor = get_sam2()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    log(f"[SAM2] Using device: {device}")
    model = model.to(device)

    W, H = image.size

    for i, region in enumerate(sam_requests):
        if region not in DEFAULT_PROMPTS:
            log(f"[SAM2] Skipping unknown region: {region}")
            continue

        log(f"[SAM2] Processing region {i+1}/{len(sam_requests)}: {region} ...")
        t = time.time()
        prompt = DEFAULT_PROMPTS[region]

        if click_points and region in click_points:
            px, py = click_points[region]
        else:
            px, py = prompt[0] * W, prompt[1] * H

        # Pass both a click point AND a bounding box. The bbox constrains SAM2
        # to segment only inside that region, which is essential for parts of
        # a body where a click alone yields ambiguous results (subpart vs torso
        # vs whole subject).
        bx1, by1, bx2, by2 = prompt[2] * W, prompt[3] * H, prompt[4] * W, prompt[5] * H
        log(f"[SAM2] {region} click=({px:.0f},{py:.0f}) bbox=({bx1:.0f},{by1:.0f},{bx2:.0f},{by2:.0f})")

        # 4-level nesting: [image][object][point][xy]; boxes: [image][object][xyxy]
        inputs = processor(
            images=image,
            input_points=[[[[px, py]]]],
            input_boxes=[[[bx1, by1, bx2, by2]]],
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        masks = processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
        )[0]

        # SAM2 returns 3 masks (subpart / part / whole). Naively argmax-ing the
        # IoU scores often picks the "whole subject" mask, which is wrong for
        # body-region segmentation — we want the local part the click landed on,
        # so the selection below is constrained by mask area first.
        scores = outputs.iou_scores[0, 0].cpu().numpy()
        mtensor = masks[0].numpy()
        if mtensor.ndim == 4:
            mtensor = mtensor[0]
        # mtensor is now (num_masks, H, W)
        total_px = mtensor.shape[1] * mtensor.shape[2]
        areas = [int(np.sum(m > 0)) for m in mtensor]
        log(f"[SAM2] mask shape: {mtensor.shape}, areas: {areas}, scores: {scores.tolist()}")

        # Filter to masks with area between 0.5% and 40% of the image, then pick
        # the one with the *highest* IoU score (the model's own confidence) — not
        # the smallest, which often gave us a low-confidence sliver.
        candidates = [
            j for j in range(len(mtensor))
            if 0.005 * total_px <= areas[j] <= 0.40 * total_px
        ]
        if candidates:
            best = max(candidates, key=lambda j: scores[j])
            log(f"[SAM2] picked mask idx={best} (highest score within 0.5–40% range, "
                f"score={scores[best]:.3f}, area={areas[best]})")
        else:
            best = int(np.argmax(scores))
            log(f"[SAM2] no mask in range — falling back to argmax idx={best}")

        mask = mtensor[best].astype(bool)
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]

        log(f"[SAM2] '{region}' done in {time.time()-t:.1f}s — "
            f"bbox=[{cmin},{rmin},{cmax-cmin},{rmax-rmin}] score={scores[best]:.3f}")
        results[region] = {
            "mask": mask.tolist(),
            "bbox": [int(cmin), int(rmin), int(cmax - cmin), int(rmax - rmin)],
        }

    log(f"[Segment] All {len(results)} regions complete.")
    return results
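# ---------------------------------------------------------------------------
# Minimal manual smoke test. This is an illustrative sketch, not part of the
# service path: the image path is a hypothetical local file supplied on the
# command line, and "hair" is used because it goes through SAM2 rather than
# the MediaPipe-backed anatomy module. Masks are written out as grayscale
# PNGs so the result can be checked by eye.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    image_path = sys.argv[1] if len(sys.argv) > 1 else "test.jpg"  # hypothetical default path
    img = Image.open(image_path).convert("RGB")

    out = segment_regions(img, ["hair"])

    for name, data in out.items():
        mask_img = np.array(data["mask"], dtype=np.uint8) * 255
        Image.fromarray(mask_img).save(f"mask_{name}.png")
        log(f"[Test] {name}: bbox={data['bbox']} -> saved mask_{name}.png")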