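# ring-sizer/script/experiment_sam_mask_quality.py
# Hugging Face Hub page residue (not part of the script): feng-x's picture
# Hugging Face Hub page residue (not part of the script): Upload folder using huggingface_hub
# Hugging Face Hub page residue (not part of the script): 0c727ab verified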
"""
Experiment: isolate the cause of the 'staircase' edges on our SAM 2.1 hand mask.
Runs five configurations of SAM 2.1 Hiera Small against a single sample image,
seeded by the MediaPipe palm center. All five variants feed the *same* pixel
prompts to the *same* model; only the inference input resolution and the
mask-upscale strategy differ.
Configurations:
A baseline 1024-long-side hard-mask + INTER_NEAREST upscale (current)
B lin_hard 1024-long-side hard-mask + INTER_LINEAR + rethresh
C soft 1024-long-side raw logits -> bilinear -> threshold full-res
D native_nn native hard-mask + INTER_NEAREST (no upscale needed)
E native_sft native raw logits -> bilinear -> threshold
For each variant the script saves:
- full-resolution binary mask PNG
- hand overlay with yellow mask + green palm prompt dot
- 600x600 fingertip crop centered on the middle-finger tip (landmark 12)
so the staircase vs smooth comparison is visible at a glance
It also prints a small table:
- perimeter (px) cv2.arcLength of the largest contour
- iso ratio perimeter / sqrt(area) -- higher = more jagged
- rel. to baseline (%) iso ratio relative to config A
Usage:
python script/experiment_sam_mask_quality.py \\
--input input/sample-04-12/card_2.jpg \\
--output-dir output/sam_mask_quality
"""
from __future__ import annotations
import argparse
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import cv2
import numpy as np
# Add repo root so we can import src.*
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.finger_segmentation import segment_hand
from src.sam_backend import get_sam2
from src.sam_hand_segmentation import palm_center_from_landmarks
# ----------------------------------------------------------------------------
# SAM inference variants
# ----------------------------------------------------------------------------
def _run_sam(
    image_bgr: np.ndarray,
    palm_xy: Tuple[int, int],
    inference_long_side: Optional[int],
    upscale_mode: str,
) -> Tuple[np.ndarray, float, float]:
    """Run SAM 2.1 with controlled inference resolution and upscale path.

    Args:
        image_bgr: Full-resolution BGR image (canonical orientation).
        palm_xy: (x, y) pixel coords of palm center in the full-res image.
        inference_long_side: If set, downscale so long-side equals this value.
            If None, or if the image's long side is already no larger than
            this value, the image is fed unchanged.
        upscale_mode: One of:
            - "nearest_hard": post_process_masks -> INTER_NEAREST to full res.
            - "linear_hard": post_process_masks -> INTER_LINEAR -> re-threshold.
            - "soft": raw pred_masks logits -> bilinear to full res
              -> threshold at 0.0.

    Returns:
        (mask_full: bool HxW, iou_score: float, infer_seconds: float).
        infer_seconds covers the processor call, the forward pass and the
        mask post-processing; the get_sam2() call is outside the timed span.

    Raises:
        ValueError: If upscale_mode is not one of the modes listed above.
    """
    # Validate first: previously an unknown mode was only rejected when a
    # resize was actually needed, so a typo at native resolution silently
    # produced a mask. Failing here also avoids loading the model for nothing.
    if upscale_mode not in ("nearest_hard", "linear_hard", "soft"):
        raise ValueError(f"unknown upscale_mode: {upscale_mode}")

    import torch
    import torch.nn.functional as F
    from PIL import Image as PILImage

    h_full, w_full = image_bgr.shape[:2]
    long_side = max(h_full, w_full)
    if inference_long_side is None or long_side <= inference_long_side:
        scaled_bgr = image_bgr
        scale = 1.0
    else:
        scale = inference_long_side / long_side
        new_w = int(round(w_full * scale))
        new_h = int(round(h_full * scale))
        scaled_bgr = cv2.resize(
            image_bgr, (new_w, new_h), interpolation=cv2.INTER_AREA
        )

    scaled_rgb = cv2.cvtColor(scaled_bgr, cv2.COLOR_BGR2RGB)
    pil = PILImage.fromarray(scaled_rgb)

    # Map the prompt into the scaled image and clamp it: rounding can push a
    # point that sits near the right/bottom border one pixel past the edge.
    scaled_h, scaled_w = scaled_bgr.shape[:2]
    palm_scaled = [
        min(max(int(round(palm_xy[0] * scale)), 0), scaled_w - 1),
        min(max(int(round(palm_xy[1] * scale)), 0), scaled_h - 1),
    ]

    model, processor = get_sam2()

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    t0 = time.perf_counter()
    inputs = processor(
        images=pil,
        input_points=[[[palm_scaled]]],
        input_labels=[[[1]]],
        return_tensors="pt",
    )
    with torch.inference_mode():
        outputs = model(**inputs, multimask_output=True)

    pred_masks = outputs.pred_masks.cpu()  # (1, 1, num_cands, H_low, W_low)
    iou_scores = outputs.iou_scores.cpu().numpy()[0, 0]
    # Keep the candidate the model itself scores highest.
    best_idx = int(np.argmax(iou_scores))
    best_score = float(iou_scores[best_idx])

    if upscale_mode == "soft":
        # Interpolate the raw logits straight to full resolution and only
        # then binarize.
        logits = pred_masks[0, 0, best_idx].to(torch.float32)  # (H_low, W_low)
        upsampled = F.interpolate(
            logits.unsqueeze(0).unsqueeze(0),
            size=(h_full, w_full),
            mode="bilinear",
            align_corners=False,
        )[0, 0].numpy()
        mask_full = upsampled > 0.0
    else:
        # Let the processor binarize at the (possibly downscaled) inference
        # resolution, then bring that hard mask up to full resolution.
        masks_scaled = processor.post_process_masks(
            pred_masks,
            inputs["original_sizes"],
            mask_threshold=0.0,
        )[0][0]
        mask_scaled = masks_scaled[best_idx].numpy().astype(np.uint8)  # scaled-res
        if mask_scaled.shape == (h_full, w_full):
            mask_full = mask_scaled.astype(bool)
        else:
            if upscale_mode == "linear_hard":
                interp = cv2.INTER_LINEAR
            else:
                interp = cv2.INTER_NEAREST
            resized = cv2.resize(
                mask_scaled, (w_full, h_full), interpolation=interp
            )
            # `resized` is uint8, so ">= 1" is the same as "non-zero": it is
            # the re-threshold for linear_hard and the plain bool cast the
            # nearest path used before.
            mask_full = resized >= 1

    return mask_full, best_score, time.perf_counter() - t0
# ----------------------------------------------------------------------------
# Metrics + visualization helpers
# ----------------------------------------------------------------------------
def _roughness_metrics(mask: np.ndarray) -> Dict[str, float]:
    """Perimeter + isoperimetric ratio of the largest contour.

    Args:
        mask: 2-D mask; it is cast to uint8 and scaled to 0/255 before the
            external contours are extracted.

    Returns:
        Dict with:
            - "perimeter_px": cv2.arcLength of the largest external contour.
            - "area_px": cv2.contourArea of that contour.
            - "iso_ratio": perimeter / sqrt(area). NaN when no contour is
              found or when the largest contour encloses zero area.
    """
    mask_u8 = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(
        mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    if not contours:
        return {"perimeter_px": 0.0, "area_px": 0.0, "iso_ratio": float("nan")}

    largest = max(contours, key=cv2.contourArea)
    perim = float(cv2.arcLength(largest, closed=True))
    area = float(cv2.contourArea(largest))

    # A degenerate contour (zero enclosed area) has no meaningful ratio.
    # Dividing by an epsilon instead would report a ~1e9-scale number that
    # looks like a measurement, so report NaN like the no-contour case.
    if area > 0.0:
        iso = float(perim / np.sqrt(area))
    else:
        iso = float("nan")
    return {"perimeter_px": perim, "area_px": area, "iso_ratio": iso}
def _save_overlay(
    path: Path,
    image_bgr: np.ndarray,
    mask: np.ndarray,
    palm_xy: Tuple[int, int],
    label: str,
) -> None:
    """Write the full image with the mask tinted yellow and the prompt marked.

    Args:
        path: Destination file handed to cv2.imwrite.
        image_bgr: BGR image to draw on; it is not modified.
        mask: Boolean mask used to index the image (same height/width).
        palm_xy: (x, y) center of the green prompt dot.
        label: Caption drawn near the top-left corner.
    """
    yellow = (0, 255, 255)

    # Blend a yellow layer (non-zero only where the mask is set) onto the image.
    mask_layer = np.zeros_like(image_bgr)
    mask_layer[mask] = yellow
    canvas = cv2.addWeighted(image_bgr, 1.0, mask_layer, 0.35, 0)

    # Outline the mask with its external contours.
    binary = mask.astype(np.uint8) * 255
    outlines, _hierarchy = cv2.findContours(
        binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    cv2.drawContours(canvas, outlines, -1, yellow, 2, cv2.LINE_AA)

    # Prompt marker: filled green disc first, then a black ring on top.
    for color, thickness in (((0, 255, 0), -1), ((0, 0, 0), 3)):
        cv2.circle(canvas, palm_xy, 18, color, thickness)

    # Caption: thick white stroke underneath, thinner yellow stroke on top.
    for color, thickness in (((255, 255, 255), 5), (yellow, 2)):
        cv2.putText(
            canvas, label, (30, 60), cv2.FONT_HERSHEY_SIMPLEX, 1.1,
            color, thickness, cv2.LINE_AA,
        )

    cv2.imwrite(str(path), canvas)
def _save_fingertip_crop(
    path: Path,
    image_bgr: np.ndarray,
    mask: np.ndarray,
    center_xy: Tuple[int, int],
    crop_half: int = 300,
    label: str = "",
) -> None:
    """Write a tinted, outlined crop of the image around ``center_xy``.

    Args:
        path: Destination file handed to cv2.imwrite.
        image_bgr: BGR image to crop from; it is not modified.
        mask: Boolean mask with the same height/width as the image.
        center_xy: (x, y) center of the crop window.
        crop_half: Half the window side; the window is clipped to the image,
            so the crop can be smaller than 2 * crop_half near a border.
        label: Optional caption; nothing is drawn when it is empty.
    """
    yellow = (0, 255, 255)
    img_h, img_w = image_bgr.shape[:2]
    center_x, center_y = center_xy

    # Window clipped to the image bounds.
    rows = slice(max(0, center_y - crop_half), min(img_h, center_y + crop_half))
    cols = slice(max(0, center_x - crop_half), min(img_w, center_x + crop_half))
    patch = image_bgr[rows, cols]
    patch_mask = mask[rows, cols]

    # Blend a yellow layer onto the patch; addWeighted returns a new array.
    mask_layer = np.zeros_like(patch)
    mask_layer[patch_mask] = yellow
    canvas = cv2.addWeighted(patch, 1.0, mask_layer, 0.4, 0)

    # Outline the part of the mask that falls inside the window.
    outlines, _hierarchy = cv2.findContours(
        patch_mask.astype(np.uint8) * 255,
        cv2.RETR_EXTERNAL,
        cv2.CHAIN_APPROX_NONE,
    )
    cv2.drawContours(canvas, outlines, -1, yellow, 2, cv2.LINE_AA)

    if label:
        # Thick white stroke underneath, thinner yellow stroke on top.
        for color, thickness in (((255, 255, 255), 4), (yellow, 2)):
            cv2.putText(
                canvas, label, (15, 35), cv2.FONT_HERSHEY_SIMPLEX, 0.9,
                color, thickness, cv2.LINE_AA,
            )

    cv2.imwrite(str(path), canvas)
# ----------------------------------------------------------------------------
# Main
# ----------------------------------------------------------------------------
def main() -> int:
    """Run every SAM variant on one image, save the artifacts, print a table.

    Returns:
        Process exit code: 0 on success, 2 when the input image cannot be
        read or hand detection returns nothing.
    """
    # RawDescriptionHelpFormatter keeps the module docstring's configuration
    # table and usage lines intact in --help instead of reflowing them.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--input", default="input/sample-04-12/card_2.jpg")
    parser.add_argument("--output-dir", default="output/sam_mask_quality")
    args = parser.parse_args()

    in_path = Path(args.input)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    image = cv2.imread(str(in_path))
    if image is None:
        print(f"ERROR: could not read {in_path}", file=sys.stderr)
        return 2
    print(f"Input: {in_path} ({image.shape[1]}x{image.shape[0]})")

    # Get canonical image + landmarks WITHOUT running SAM. Pass a
    # max_dimension at least as large as the image's long side so the
    # internal resize does not happen and the canonical image stays at native
    # resolution. A fixed 3000 would not achieve that for larger photos.
    # NOTE(review): assumes max_dimension caps the long side, as the original
    # comment describes; confirm against src.finger_segmentation.
    native_long_side = max(image.shape[:2])
    hand_data = segment_hand(
        image=image,
        finger="middle",
        max_dimension=max(3000, native_long_side),
        debug_dir=None,
        use_sam_mask=False,
    )
    if hand_data is None:
        print("ERROR: hand detection failed", file=sys.stderr)
        return 2

    canonical = hand_data["canonical_image"]
    landmarks = hand_data["landmarks"]  # (21, 2) in canonical px coords
    palm_xy = palm_center_from_landmarks(landmarks)
    # Landmark 12 is the middle-finger tip (see the module docstring).
    middle_tip_xy = (int(round(landmarks[12, 0])), int(round(landmarks[12, 1])))
    ch, cw = canonical.shape[:2]
    print(f"Canonical: {cw}x{ch} palm=({palm_xy[0]},{palm_xy[1]}) "
          f"middle_tip=({middle_tip_xy[0]},{middle_tip_xy[1]})")

    # Save canonical reference image
    cv2.imwrite(str(out_dir / "00_canonical.png"), canonical)

    # (name, inference long side or None for native, upscale mode)
    variants = [
        ("A_baseline_1024_nn", 1024, "nearest_hard"),
        ("B_1024_linear_hard", 1024, "linear_hard"),
        ("C_1024_soft", 1024, "soft"),
        ("D_native_nn", None, "nearest_hard"),
        ("E_native_soft", None, "soft"),
    ]

    results = []
    for name, long_side, mode in variants:
        print(f"\n=== {name} long_side={long_side} mode={mode} ===")
        mask, score, secs = _run_sam(canonical, palm_xy, long_side, mode)
        m = _roughness_metrics(mask)
        print(f" iou={score:.3f} time={secs:.2f}s "
              f"perim={m['perimeter_px']:.0f}px iso={m['iso_ratio']:.3f}")
        cv2.imwrite(str(out_dir / f"{name}_mask.png"),
                    (mask.astype(np.uint8)) * 255)
        _save_overlay(
            out_dir / f"{name}_overlay.png",
            canonical, mask, palm_xy,
            label=f"{name} iou={score:.2f}",
        )
        _save_fingertip_crop(
            out_dir / f"{name}_fingertip.png",
            canonical, mask, middle_tip_xy,
            crop_half=300,
            label=name,
        )
        results.append((name, score, secs, m["perimeter_px"], m["iso_ratio"]))

    # Summary table
    print("\n")
    print("=" * 78)
    print(f"{'variant':<22}{'iou':>8}{'time(s)':>10}"
          f"{'perim(px)':>14}{'iso':>10}{'vs A (%)':>12}")
    print("-" * 78)
    # Variant A (first row) is the baseline every iso ratio is compared to.
    iso_base = results[0][4]
    for name, score, secs, perim, iso in results:
        rel = (iso / iso_base - 1.0) * 100.0 if iso_base else float("nan")
        print(f"{name:<22}{score:>8.3f}{secs:>10.2f}"
              f"{perim:>14.0f}{iso:>10.3f}{rel:>11.1f}%")
    print("=" * 78)

    # Side-by-side fingertip comparison, built from the crops written above.
    # It is skipped if any of them cannot be read back.
    crops = [cv2.imread(str(out_dir / f"{name}_fingertip.png")) for name, *_ in results]
    if all(c is not None for c in crops):
        panel = np.hstack(crops)
        cv2.imwrite(str(out_dir / "fingertip_comparison.png"), panel)
        print(f"\nFingertip comparison strip: {out_dir / 'fingertip_comparison.png'}")

    print(f"\nAll outputs saved to: {out_dir}/")
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())