ring-sizer / src /sam_hand_segmentation.py
feng-x's picture
Upload folder using huggingface_hub
6f3fe10 verified
"""
SAM 2.1-based hand segmentation.
Produces a pixel-accurate hand mask using Meta's Segment Anything 2.1
(Hiera Tiny) via HuggingFace transformers, seeded by a positive point
prompt at the palm center (derived from MediaPipe landmarks). Optional
negative points can steer SAM away from the credit card.
This replaces the synthetic convex-hull "mask" produced by
`finger_segmentation._create_hand_mask()`, which is built from the
21 hand landmarks and does not follow the true hand contour.
Prompt-based inference takes ~0.6s per call on CPU (vs ~18s for SAM's
automatic mask generation, AMG).
"""
from __future__ import annotations
import logging
import time
from typing import List, Optional, Tuple
import cv2
import numpy as np
from .debug_observer import DebugObserver
from .sam_backend import INFERENCE_MAX_SIDE, get_sam2
logger = logging.getLogger(__name__)
def _downscale(image_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
    """Shrink the image so its long side equals INFERENCE_MAX_SIDE.

    Images whose long side is already within INFERENCE_MAX_SIDE are returned
    unchanged (same object, no copy).

    Args:
        image_bgr: Image array; only its first two dimensions (height, width)
            are inspected here.

    Returns:
        (scaled, scale_back): the possibly resized image and the factor by
        which coordinates in the scaled image must be multiplied to get
        coordinates in the original image (1.0 when nothing was resized).
    """
    h, w = image_bgr.shape[:2]
    long_side = max(h, w)
    if long_side <= INFERENCE_MAX_SIDE:
        return image_bgr, 1.0

    scale = INFERENCE_MAX_SIDE / long_side
    # Clamp to 1 px: for extremely elongated images the short side would
    # otherwise round down to 0, which is not a valid resize target.
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    scaled = cv2.resize(image_bgr, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return scaled, 1.0 / scale
def _scale_prompt_points(
    palm_xy: Tuple[int, int],
    negative_points: Optional[List[Tuple[int, int]]],
    scale_down: float,
) -> Tuple[List[List[int]], List[int]]:
    """Map the prompt points from original-image to scaled-image coordinates.

    Returns (points, labels): the palm point comes first with label 1,
    followed by every negative point with label 0.
    """

    def to_scaled(x, y) -> List[int]:
        return [int(round(x * scale_down)), int(round(y * scale_down))]

    points = [to_scaled(palm_xy[0], palm_xy[1])]
    labels = [1]
    for nx, ny in negative_points or ():
        points.append(to_scaled(nx, ny))
        labels.append(0)
    return points, labels


def _save_hand_debug(
    debug_dir: str,
    image_bgr: np.ndarray,
    mask_full: np.ndarray,
    palm_xy: Tuple[int, int],
    negative_points: Optional[List[Tuple[int, int]]],
    best_score: float,
    infer_time: float,
) -> None:
    """Save the binary mask and an annotated overlay through DebugObserver."""
    observer = DebugObserver(debug_dir)
    mask_u8 = mask_full.astype(np.uint8) * 255
    observer.save_stage("01_sam_hand_mask", mask_u8)

    # Yellow tint over the mask area plus a yellow outline.
    overlay = image_bgr.copy()
    tint = np.zeros_like(overlay)
    tint[mask_full] = (0, 255, 255)
    overlay = cv2.addWeighted(overlay, 1.0, tint, 0.35, 0)
    contours, _ = cv2.findContours(
        mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    cv2.drawContours(overlay, contours, -1, (0, 255, 255), 3, cv2.LINE_AA)

    # Positive prompt in green, negative prompts in red, each with a black
    # ring. Coordinates are cast to int the same way for both kinds of
    # points before being handed to cv2.circle.
    palm_pt = (int(palm_xy[0]), int(palm_xy[1]))
    cv2.circle(overlay, palm_pt, 20, (0, 255, 0), -1)
    cv2.circle(overlay, palm_pt, 20, (0, 0, 0), 3)
    for nx, ny in negative_points or ():
        neg_pt = (int(nx), int(ny))
        cv2.circle(overlay, neg_pt, 20, (0, 0, 255), -1)
        cv2.circle(overlay, neg_pt, 20, (0, 0, 0), 3)

    # Thick white text underneath thin yellow text gives an outlined label.
    label = f"SAM hand score={best_score:.2f} {infer_time:.1f}s"
    cv2.putText(overlay, label, (30, 60), cv2.FONT_HERSHEY_SIMPLEX, 1.1,
                (255, 255, 255), 5, cv2.LINE_AA)
    cv2.putText(overlay, label, (30, 60), cv2.FONT_HERSHEY_SIMPLEX, 1.1,
                (0, 255, 255), 2, cv2.LINE_AA)
    observer.save_stage("02_sam_hand_overlay", overlay)


def segment_hand_sam(
    image_bgr: np.ndarray,
    palm_xy: Tuple[int, int],
    negative_points: Optional[List[Tuple[int, int]]] = None,
    debug_dir: Optional[str] = None,
) -> Optional[np.ndarray]:
    """Return a pixel-accurate bool hand mask (H x W) via SAM 2.1 Tiny.

    Args:
        image_bgr: Full-resolution BGR image in the canonical orientation.
        palm_xy: (x, y) pixel coordinates of the palm center (positive prompt).
        negative_points: Optional list of (x, y) points to steer SAM away from
            non-hand regions (e.g., credit card center).
        debug_dir: Optional directory to save mask + overlay for inspection.

    Returns:
        Bool mask with shape `image_bgr.shape[:2]`, or None when the input
        image is None or empty. Errors raised by the processor or model are
        not caught here and propagate to the caller.
    """
    if image_bgr is None or image_bgr.size == 0:
        logger.warning("SAM hand mask: empty input image, skipping segmentation")
        return None

    # Heavy dependencies are imported lazily, on first use of this function.
    import torch
    import torch.nn.functional as F
    from PIL import Image as PILImage

    h_full, w_full = image_bgr.shape[:2]
    scaled_bgr, scale_back = _downscale(image_bgr)
    pil = PILImage.fromarray(cv2.cvtColor(scaled_bgr, cv2.COLOR_BGR2RGB))

    # Prompts are given in original-image pixels; the model sees the scaled
    # image, so they are mapped with the inverse of scale_back.
    prompt_points, prompt_labels = _scale_prompt_points(
        palm_xy, negative_points, 1.0 / scale_back
    )

    model, processor = get_sam2()
    # perf_counter is monotonic, unlike time.time(), so the reported duration
    # cannot be skewed by wall-clock adjustments.
    t0 = time.perf_counter()
    inputs = processor(
        images=pil,
        input_points=[[prompt_points]],
        input_labels=[[prompt_labels]],
        return_tensors="pt",
    )
    with torch.inference_mode():
        outputs = model(**inputs, multimask_output=True)

    # Use the raw 256x256 logits and bilinearly upsample them to the full
    # resolution before thresholding. Going through post_process_masks +
    # cv2.INTER_NEAREST binarizes at the scaled resolution and then blows
    # the hard mask up ~4x, which produces visible staircase edges on the
    # finger boundaries (see script/experiment_sam_mask_quality.py).
    pred_masks = outputs.pred_masks.cpu()  # (1, 1, num_cands, H_low, W_low)
    # Cast to float32 before .numpy(), mirroring the logits below, so that
    # reduced-precision model outputs do not break the conversion.
    scores = outputs.iou_scores.to(torch.float32).cpu().numpy()[0, 0]
    best_idx = int(np.argmax(scores))
    best_score = float(scores[best_idx])

    logits_best = pred_masks[0, 0, best_idx].to(torch.float32)
    upsampled = F.interpolate(
        logits_best.unsqueeze(0).unsqueeze(0),
        size=(h_full, w_full),
        mode="bilinear",
        align_corners=False,
    )[0, 0].numpy()
    mask_full = upsampled > 0.0  # logit > 0 marks a foreground pixel

    infer_time = time.perf_counter() - t0
    logger.info(
        "SAM hand mask: score=%.3f time=%.1fs area=%dpx",
        best_score, infer_time, int(mask_full.sum()),
    )

    if debug_dir:
        _save_hand_debug(
            debug_dir, image_bgr, mask_full, palm_xy, negative_points,
            best_score, infer_time,
        )
    return mask_full
def palm_center_from_landmarks(landmarks_px: np.ndarray) -> Tuple[int, int]:
    """Return (x, y) pixel coord of the palm center from the 21 MediaPipe landmarks.

    Defined as the mean of wrist (0) + four MCPs (5, 9, 13, 17).

    Args:
        landmarks_px: Array-like of landmark rows whose first two columns are
            x and y in pixels. Any further columns are ignored. Nested
            lists/tuples are accepted as well as ndarrays.

    Returns:
        The palm center as a pair of Python ints, rounded to the nearest pixel.
    """
    # asarray is a no-op for ndarrays and lets plain Python sequences use the
    # fancy indexing below.
    points = np.asarray(landmarks_px)
    palm_rows = [0, 5, 9, 13, 17]
    center = np.mean(points[palm_rows, :2], axis=0)
    return (int(round(center[0])), int(round(center[1])))