# ring-sizer / script / compare_hand_sam.py
# feng-x's picture
# Upload folder using huggingface_hub
# 22df1ea verified
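"""Compare hand-mask quality across backends on a single image.
Runs MediaPipe (current pipeline), SAM 2.1 tiny, and SAM 2.1 small. The SAM
models are prompted with a positive point at the palm center (derived from the
MediaPipe landmarks) plus a negative point placed to the left of the hand.
Saves each panel at full resolution, a side-by-side comparison of all panels
(up to 4), and a zoomed crop of the same panels around the MediaPipe hand mask.
"""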
from __future__ import annotations
import sys
import time
from pathlib import Path
from typing import Tuple
import cv2
import numpy as np
from PIL import Image as PILImage
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from src.finger_segmentation import segment_hand # noqa: E402
# Input image read by main(); a relative path, so it resolves against the
# current working directory.
IMG_PATH = Path("input/sample-04-12/card_2.jpg")
# Directory that receives the per-panel and combined comparison images
# (created by main() if it does not exist).
OUT_DIR = Path("output/hand_sam_compare")
# (display name, model id) pairs; main() runs them in this order and passes
# the model id to run_sam(), which loads it with `from_pretrained`.
SAM_MODELS = [
("sam2.1-tiny", "facebook/sam2.1-hiera-tiny"),
("sam2.1-small", "facebook/sam2.1-hiera-small"),
]
def palm_and_card_points(image_bgr: np.ndarray, hand_data: dict) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Return (palm_center, card_center) pixel coords in the canonical image space.

    Palm center = mean of wrist + MCPs (landmarks 0, 5, 9, 13, 17).
    Card center = a rough point to the left of the hand (negative prompt hint),
    kept inside the image horizontally.

    Args:
        image_bgr: Image whose height and width bound the card hint.
        hand_data: Mapping with a "landmarks" entry, an array-like of shape
            (21, 2 or 3) in pixel coords; only the first two columns are used.

    Returns:
        ``((palm_x, palm_y), (card_x, card_y))`` as plain Python ints.

    Raises:
        RuntimeError: If ``hand_data`` has no (or empty) landmarks.
    """
    landmarks = hand_data.get("landmarks")
    if landmarks is None or len(landmarks) == 0:
        raise RuntimeError("MediaPipe returned no landmarks")

    # landmarks is (21, 2 or 3) in pixel coords; drop any z column
    lm = np.asarray(landmarks)[:, :2]
    palm_ids = [0, 5, 9, 13, 17]
    palm_center = tuple(np.round(lm[palm_ids].mean(axis=0)).astype(int).tolist())

    # Card hint: far from hand, toward image left
    h, w = image_bgr.shape[:2]
    hand_x_min = int(lm[:, 0].min())
    # Stay at least 50 px from the left border when the image allows it, but
    # never step past the last valid column (matters for narrow images).
    card_x = min(max(50, hand_x_min - 150), max(w - 1, 0))
    card_y = h // 2
    return palm_center, (card_x, card_y)
def run_sam(
    model_id: str,
    image_rgb: np.ndarray,
    palm_xy: Tuple[int, int],
    negative_xy: Tuple[int, int],
) -> Tuple[np.ndarray, float, float]:
    """Segment the hand with SAM 2.1 using a positive palm point and a negative card point.

    The processor and model are loaded from ``model_id`` and run on the CPU.
    The reported time starts after loading, so it covers preprocessing,
    inference and post-processing only.

    Returns:
        ``(mask, score, seconds)``: the boolean mask of the candidate with
        the highest predicted IoU score, that score, and the elapsed time.
    """
    import torch
    from transformers import Sam2Model, Sam2Processor

    sam_processor = Sam2Processor.from_pretrained(model_id)
    sam_model = Sam2Model.from_pretrained(model_id).to("cpu").eval()
    pil_image = PILImage.fromarray(image_rgb)

    # One image, one prompt set, two points: label 1 = palm, label 0 = card hint
    prompt_points = [[[list(palm_xy), list(negative_xy)]]]
    prompt_labels = [[[1, 0]]]

    started = time.time()
    model_inputs = sam_processor(
        images=pil_image,
        input_points=prompt_points,
        input_labels=prompt_labels,
        return_tensors="pt",
    )
    with torch.inference_mode():
        model_outputs = sam_model(**model_inputs, multimask_output=True)

    per_image_masks = sam_processor.post_process_masks(
        model_outputs.pred_masks.cpu(),
        model_inputs["original_sizes"],
        mask_threshold=0.0,
    )
    # (num_candidates, H, W) for first image, first prompt set
    candidates = per_image_masks[0][0]
    candidate_scores = model_outputs.iou_scores.cpu().numpy()[0, 0]

    winner = int(np.argmax(candidate_scores))
    best_mask = candidates[winner].numpy().astype(bool)
    return best_mask, float(candidate_scores[winner]), time.time() - started
def mask_to_overlay(image_bgr: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int]) -> np.ndarray:
    """Return a copy of ``image_bgr`` with ``mask`` tinted in ``color`` and outlined.

    The input image is not modified. ``mask`` is used as a boolean index, so
    its shape must match the image's height and width.
    """
    base = image_bgr.copy()

    # Colour layer that is non-zero only inside the mask
    color_layer = np.zeros_like(base)
    color_layer[mask] = color
    blended = cv2.addWeighted(base, 1.0, color_layer, 0.35, 0)

    # Outline only the outer boundaries of the mask
    binary = mask.astype(np.uint8) * 255
    outlines, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    cv2.drawContours(blended, outlines, -1, color, 2, cv2.LINE_AA)
    return blended
def label_panel(img: np.ndarray, text: str) -> np.ndarray:
    """Stamp a caption banner onto the top of ``img``.

    A filled black strip 60 px high is drawn across the full width and
    ``text`` is written on it in white. ``img`` is modified in place and the
    same array is returned.
    """
    _, panel_width = img.shape[:2]
    black = (0, 0, 0)
    white = (255, 255, 255)

    cv2.rectangle(img, (0, 0), (panel_width, 60), black, -1)
    cv2.putText(
        img,
        text,
        (20, 42),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.3,
        white,
        3,
        cv2.LINE_AA,
    )
    return img
def main() -> int:
    """Run every backend on ``IMG_PATH`` and write comparison images to ``OUT_DIR``.

    Steps: load the image, run ``segment_hand`` (MediaPipe baseline), derive
    the SAM prompt points from its landmarks, run each entry of
    ``SAM_MODELS``, then save one overlay panel per backend, a side-by-side
    image, and a zoomed crop around the MediaPipe mask.

    A SAM backend that raises is reported and skipped; the remaining panels
    are still written, each under the name of the backend that produced it.

    Returns:
        0 on success; 1 if the image cannot be loaded or MediaPipe provides
        no hand, no mask, or no landmarks.
    """
    import traceback

    OUT_DIR.mkdir(parents=True, exist_ok=True)
    image_bgr = cv2.imread(str(IMG_PATH))
    if image_bgr is None:
        print(f"Failed to load {IMG_PATH}")
        return 1
    print(f"Image: {IMG_PATH} {image_bgr.shape}")

    # --- MediaPipe baseline ---
    t0 = time.time()
    hand_data = segment_hand(image_bgr, finger="index")
    mp_time = time.time() - t0
    if hand_data is None:
        print("MediaPipe detected no hand — aborting")
        return 1
    # Fall back to the raw image when the key is missing OR present but None
    canonical_image = hand_data.get("canonical_image")
    if canonical_image is None:
        canonical_image = image_bgr
    mp_mask = hand_data.get("mask")
    if mp_mask is None:
        print("MediaPipe did not return a hand mask")
        return 1
    mp_mask = mp_mask.astype(bool)
    print(f"MediaPipe: {mp_time:.1f}s mask_area={mp_mask.sum()}")

    # Work in the canonical image so the comparison is apples-to-apples
    image_for_sam = canonical_image.copy()
    try:
        palm_xy, card_xy = palm_and_card_points(image_for_sam, hand_data)
    except RuntimeError as e:
        # Same exit convention as the other "nothing to compare" cases above
        print(f"Cannot build SAM prompts: {e}")
        return 1
    print(f"Palm prompt: {palm_xy} Negative hint: {card_xy}")
    image_rgb = cv2.cvtColor(image_for_sam, cv2.COLOR_BGR2RGB)

    # --- SAM models ---
    # Insertion order (mediapipe first, then SAM_MODELS order) is the panel order.
    results = {"mediapipe": (mp_mask, None, mp_time)}
    for name, model_id in SAM_MODELS:
        print(f"\n=== {name} ({model_id}) ===")
        try:
            mask, score, seconds = run_sam(model_id, image_rgb, palm_xy, card_xy)
            # Align shape (should already be canonical)
            if mask.shape != mp_mask.shape:
                mask = cv2.resize(
                    mask.astype(np.uint8),
                    (mp_mask.shape[1], mp_mask.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                ).astype(bool)
            print(f" score={score:.3f} time={seconds:.1f}s area={mask.sum()}")
            results[name] = (mask, score, seconds)
        except Exception as e:
            # Best effort: one failing backend must not prevent the others
            # from being compared.
            print(f" FAILED: {e!r}")
            traceback.print_exc()

    # --- Render panels ---
    colors = {
        "mediapipe": (0, 165, 255),  # orange
        "sam2.1-tiny": (0, 255, 255),  # yellow
        "sam2.1-small": (0, 255, 0),  # green
    }
    fallback_color = (255, 0, 255)  # magenta for backends without an entry above
    sam_prefix = "sam2.1-"

    # Each panel travels with its own file tag so a skipped backend cannot
    # shift the names of the panels that follow it.
    panels = []

    # Panel 0: original with prompt points
    orig = image_for_sam.copy()
    cv2.circle(orig, palm_xy, 18, (0, 255, 0), -1)
    cv2.circle(orig, palm_xy, 18, (0, 0, 0), 3)
    cv2.circle(orig, card_xy, 18, (0, 0, 255), -1)
    cv2.circle(orig, card_xy, 18, (0, 0, 0), 3)
    panels.append(("orig", label_panel(orig, "original + prompts")))

    for name, (mask, score, seconds) in results.items():
        panel = mask_to_overlay(image_for_sam, mask, colors.get(name, fallback_color))
        label = f"{name} {seconds:.1f}s"
        if score is not None:
            label += f" score={score:.2f}"
        # "sam2.1-tiny" -> "tiny"; names without the prefix are used as-is
        tag = name[len(sam_prefix):] if name.startswith(sam_prefix) else name
        panels.append((tag, label_panel(panel, label)))

    # Save individual panels full-res
    for i, (tag, panel) in enumerate(panels):
        cv2.imwrite(str(OUT_DIR / f"panel_{i}_{tag}.png"), panel)

    # Build a single side-by-side at a readable size
    def resize_to_height(img: np.ndarray, H: int) -> np.ndarray:
        h, w = img.shape[:2]
        scale = H / h
        return cv2.resize(img, (int(round(w * scale)), H), interpolation=cv2.INTER_AREA)

    target_h = 900
    resized = [resize_to_height(panel, target_h) for _, panel in panels]
    combined = np.hstack(resized)
    cv2.imwrite(str(OUT_DIR / "comparison_full.png"), combined)

    # Also zoom-crop around the hand for fine-detail inspection
    ys, xs = np.where(mp_mask)
    if len(xs) > 0:
        pad = 80
        x0, x1 = max(0, xs.min() - pad), min(image_for_sam.shape[1], xs.max() + pad)
        y0, y1 = max(0, ys.min() - pad), min(image_for_sam.shape[0], ys.max() + pad)
        crops = [resize_to_height(panel[y0:y1, x0:x1], target_h) for _, panel in panels]
        combined_zoom = np.hstack(crops)
        cv2.imwrite(str(OUT_DIR / "comparison_zoom.png"), combined_zoom)

    print(f"\nSaved panels to {OUT_DIR}/")
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())