"""jiggle-physics / segmentation.py

Replace SAM2 with geometric ellipses for anatomy regions.

Author: Justin Wood (commit 5ec223d)
"""
import sys
import time
import numpy as np
from PIL import Image
import torch
def log(msg: str):
    """Emit *msg* to stdout immediately (newline-terminated, flushed)."""
    sys.stdout.write(msg + "\n")
    sys.stdout.flush()
def _make_log_tqdm():
    """tqdm subclass that routes per-file download progress to our log buffer.

    Returns the *class* (not an instance) so it can be handed to
    huggingface_hub's ``snapshot_download(tqdm_class=...)``. Bars never render
    to the terminal (disable=True); instead, files larger than ~100 KB get a
    start line, a progress line roughly every 10 percentage points, and a
    completion line, all via log().
    """
    try:
        from tqdm.auto import tqdm as _Base
    except ImportError:
        from tqdm import tqdm as _Base

    class _LogTqdm(_Base):
        def __init__(self, *args, **kwargs):
            # Last percentage we logged; -1 guarantees the first qualifying
            # update() triggers a log line. Set before super().__init__ so it
            # exists even if tqdm's init invokes update/refresh internally.
            self.__last_pct = -1
            # disable=True suppresses terminal rendering; tracking still works
            super().__init__(*args, disable=True, **kwargs)
            # Only announce downloads of real payloads (>100 KB); tiny config
            # files would otherwise spam the log.
            if self.total and self.total > 100_000:
                log(f"[SAM2] Downloading {self.desc or 'file'} ({self.total/1e6:.1f} MB) ...")

        def update(self, n=1):
            super().update(n)
            # Skip unknown-size streams and small files (<=100 KB).
            if not self.total or self.total <= 100_000:
                return
            pct = min(100, int(self.n / self.total * 100))
            # Throttle to roughly one line per 10% of progress.
            if pct >= self.__last_pct + 10:
                log(f"[SAM2] {self.desc}: {self.n/1e6:.0f}/{self.total/1e6:.0f} MB ({pct}%)")
                self.__last_pct = pct

        def close(self):
            super().close()
            # Report completion only when (nearly) the whole file arrived,
            # so aborted downloads don't log a false "done".
            if self.total and self.total > 100_000 and self.n >= self.total * 0.99:
                log(f"[SAM2] {self.desc}: βœ“ done")

    return _LogTqdm
def load_sam2():
    """Download (if needed) and deserialize the SAM2 model and processor.

    Returns:
        (model, processor) β€” the model is in eval mode and still on CPU;
        callers are responsible for moving it to the right device.
    """
    from huggingface_hub import snapshot_download
    from transformers import Sam2Model, Sam2Processor

    model_id = "facebook/sam2-hiera-large"

    # Phase 1: ensure files are in the HF cache. A warm cache returns
    # instantly; a cold one streams per-file progress through our tqdm shim.
    log("[SAM2] Checking model files in HF cache ...")
    t0 = time.time()
    snapshot_download(model_id, tqdm_class=_make_log_tqdm())
    log(f"[SAM2] Cache ready ({time.time()-t0:.1f}s). Loading processor ...")

    # Phase 2: processor deserialization (cheap).
    t1 = time.time()
    processor = Sam2Processor.from_pretrained(model_id)
    log(f"[SAM2] Processor loaded ({time.time()-t1:.1f}s). Loading model weights ...")

    # Phase 3: model weights (~1-2 GB) β€” the slow part, can take 30-60s.
    # .eval() returns the module itself, so chaining it is safe.
    t2 = time.time()
    model = Sam2Model.from_pretrained(model_id).eval()
    log(f"[SAM2] Model loaded ({time.time()-t2:.1f}s). Total init: {time.time()-t0:.1f}s.")
    return model, processor
# Module-level memo for the (model, processor) pair; populated on first use.
_sam2_cache = None


def get_sam2():
    """Return the process-wide (model, processor) pair, loading on first call."""
    global _sam2_cache
    if _sam2_cache is not None:
        log("[SAM2] Using cached model.")
        return _sam2_cache
    log("[SAM2] Cold start β€” initializing model for the first time ...")
    _sam2_cache = load_sam2()
    return _sam2_cache
# Each prompt: (click_x, click_y, bbox_x1, bbox_y1, bbox_x2, bbox_y2).
# All values normalized [0,1] relative to image width/height. Bbox constrains
# SAM2 to look only within that region, which is far more reliable than a
# point alone for body parts. Coordinates assume a roughly centered, upright,
# front-facing subject β€” TODO confirm against typical input framing.
DEFAULT_PROMPTS = {
    "breast_left": (0.40, 0.36, 0.28, 0.26, 0.50, 0.46),
    "breast_right": (0.60, 0.36, 0.50, 0.26, 0.72, 0.46),
    "buttocks": (0.50, 0.72, 0.30, 0.62, 0.70, 0.85),
    "ponytail": (0.50, 0.10, 0.35, 0.00, 0.65, 0.20),
    "hair": (0.50, 0.10, 0.30, 0.00, 0.70, 0.25),
}
# Regions routed to the MediaPipe/ellipse path in segment_regions() instead
# of SAM2 (see the rationale comment there).
ANATOMY_REGIONS = {"breast_left", "breast_right", "buttocks"}
def segment_regions(image: Image.Image, requested: list[str], click_points: dict | None = None) -> dict:
    """Segment the requested body regions of *image*.

    Args:
        image: PIL image to segment.
        requested: region names; anatomy regions (ANATOMY_REGIONS) go through
            the MediaPipe/ellipse path, everything else through SAM2 using
            DEFAULT_PROMPTS.
        click_points: optional {region: (x_px, y_px)} pixel-coordinate
            overrides for the SAM2 click prompt.

    Returns:
        {region: {"mask": nested bool list, "bbox": [x, y, w, h]}}.
        Unknown regions and regions whose SAM2 mask came back empty are
        omitted from the result.
    """
    log(f"[Segment] Requested: {requested} | image size: {image.size}")
    # Body-region masks come from MediaPipe pose + ellipse β€” not SAM2.
    # SAM2 segments by pixel similarity, which on clothed photos catches the
    # tank top / shirt color rather than the underlying anatomy.
    anatomy_requests = [r for r in requested if r in ANATOMY_REGIONS]
    sam_requests = [r for r in requested if r not in ANATOMY_REGIONS]
    results: dict = {}
    if anatomy_requests:
        from anatomy import segment_anatomy
        results.update(segment_anatomy(image, anatomy_requests))
    if not sam_requests:
        log(f"[Segment] All {len(results)} regions complete (anatomy only).")
        return results

    log(f"[SAM2] Falling back to SAM2 for: {sam_requests}")
    model, processor = get_sam2()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    log(f"[SAM2] Using device: {device}")
    model = model.to(device)
    W, H = image.size
    for i, region in enumerate(sam_requests):
        if region not in DEFAULT_PROMPTS:
            log(f"[SAM2] Skipping unknown region: {region}")
            continue
        # BUG FIX: the progress denominator was len(requested), which counts
        # the anatomy regions handled above β€” only the SAM2 queue belongs here.
        log(f"[SAM2] Processing region {i+1}/{len(sam_requests)}: {region} ...")
        t = time.time()
        prompt = DEFAULT_PROMPTS[region]
        if click_points and region in click_points:
            px, py = click_points[region]
        else:
            px, py = prompt[0] * W, prompt[1] * H
        # Pass both a click point AND a bounding box. The bbox constrains SAM2
        # to segment only inside that region, which is essential for parts of
        # a body where a click alone yields ambiguous results (subpart vs torso
        # vs whole subject).
        bx1, by1, bx2, by2 = prompt[2] * W, prompt[3] * H, prompt[4] * W, prompt[5] * H
        log(f"[SAM2] {region} click=({px:.0f},{py:.0f}) bbox=({bx1:.0f},{by1:.0f},{bx2:.0f},{by2:.0f})")
        # 4-level nesting: [image][object][point][xy]; boxes: [image][object][xyxy]
        inputs = processor(
            images=image,
            input_points=[[[[px, py]]]],
            input_boxes=[[[bx1, by1, bx2, by2]]],
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        masks = processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
        )[0]
        # SAM2 returns 3 masks (subpart / part / whole). argmax-ing IoU scores
        # often picks the "whole subject" mask, which is wrong for body-region
        # segmentation β€” we want the local part the click landed on.
        scores = outputs.iou_scores[0, 0].cpu().numpy()
        mtensor = masks[0].numpy()
        if mtensor.ndim == 4:
            mtensor = mtensor[0]
        # mtensor is now (num_masks, H, W)
        total_px = mtensor.shape[1] * mtensor.shape[2]
        areas = [int(np.sum(m > 0)) for m in mtensor]
        log(f"[SAM2] mask shape: {mtensor.shape}, areas: {areas}, scores: {scores.tolist()}")
        # Filter to masks with area between 0.5% and 40% of image, then pick
        # the one with the *highest* IoU score (model's own confidence) β€” not
        # the smallest, which often gave us a low-confidence sliver.
        # (Comprehension/key variable renamed from `i`, which shadowed the
        # enumerate loop index above.)
        candidates = [
            idx for idx in range(len(mtensor))
            if 0.005 * total_px <= areas[idx] <= 0.40 * total_px
        ]
        if candidates:
            best = max(candidates, key=lambda idx: scores[idx])
            log(f"[SAM2] picked mask idx={best} (highest score within 0.5–40% range, score={scores[best]:.3f}, area={areas[best]})")
        else:
            best = int(np.argmax(scores))
            log(f"[SAM2] no mask in range β€” falling back to argmax idx={best}")
        mask = mtensor[best].astype(bool)
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)
        # BUG FIX: an all-False mask used to crash with IndexError on the
        # np.where(...)[0][[0, -1]] fancy-index below. Skip the region instead.
        if not rows.any():
            log(f"[SAM2] '{region}' mask is empty β€” skipping region.")
            continue
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]
        log(f"[SAM2] '{region}' done in {time.time()-t:.1f}s β€” bbox=[{cmin},{rmin},{cmax-cmin},{rmax-rmin}] score={scores[best]:.3f}")
        results[region] = {
            "mask": mask.tolist(),
            "bbox": [int(cmin), int(rmin), int(cmax - cmin), int(rmax - rmin)],
        }
    log(f"[Segment] All {len(results)} regions complete.")
    return results