"""
SAM 2.1-based hand segmentation.
Produces a pixel-accurate hand mask using Meta's Segment Anything 2.1
(Hiera Tiny) via HuggingFace transformers, seeded by a positive point
prompt at the palm center (derived from MediaPipe landmarks). Optional
negative points can steer SAM away from the credit card.
This replaces the synthetic convex-hull "mask" produced by
`finger_segmentation._create_hand_mask()`, which is built from the
21 hand landmarks and does not follow the true hand contour.
Prompt-based inference: ~0.6s per call on CPU (vs ~18s for AMG).
"""
from __future__ import annotations
import logging
import time
from typing import List, Optional, Tuple
import cv2
import numpy as np
from .debug_observer import DebugObserver
from .sam_backend import INFERENCE_MAX_SIDE, get_sam2
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def _downscale(image_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
    """Shrink an image so its longer side equals INFERENCE_MAX_SIDE.

    Images whose longer side is already within the limit are returned
    unchanged (same object, no copy) with a factor of 1.0.

    Returns:
        (image, scale_back): the possibly-resized image and the factor by
        which coordinates in that image must be multiplied to map them back
        into the original image's coordinate space.
    """
    height, width = image_bgr.shape[:2]
    longest = max(height, width)

    if longest > INFERENCE_MAX_SIDE:
        shrink = INFERENCE_MAX_SIDE / longest
        # cv2.resize takes the target size as (width, height).
        target_size = (int(round(width * shrink)), int(round(height * shrink)))
        resized = cv2.resize(image_bgr, target_size, interpolation=cv2.INTER_AREA)
        return resized, 1.0 / shrink

    return image_bgr, 1.0
def _save_hand_debug_overlay(
    debug_dir: str,
    image_bgr: np.ndarray,
    mask_full: np.ndarray,
    palm_xy: Tuple[int, int],
    negative_points: Optional[List[Tuple[int, int]]],
    best_score: float,
    infer_time: float,
) -> None:
    """Save the binary hand mask and an annotated overlay via DebugObserver.

    The overlay shows the mask as a translucent tint with its outer contour,
    the positive (palm) prompt as a green dot, each negative prompt as a red
    dot, and a caption with the candidate score and inference time.
    """
    observer = DebugObserver(debug_dir)

    # One uint8 rendering of the mask serves both the saved stage and the
    # contour extraction.
    mask_u8 = mask_full.astype(np.uint8) * 255
    observer.save_stage("01_sam_hand_mask", mask_u8)

    tint = np.zeros_like(image_bgr)
    tint[mask_full] = (0, 255, 255)
    # addWeighted allocates its own output, so image_bgr is not modified.
    overlay = cv2.addWeighted(image_bgr, 1.0, tint, 0.35, 0)

    contours, _ = cv2.findContours(
        mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    cv2.drawContours(overlay, contours, -1, (0, 255, 255), 3, cv2.LINE_AA)

    # cv2 drawing calls need plain Python ints. Coerce the palm point the same
    # way the prompt path does, so float / NumPy-scalar coordinates that work
    # for inference do not crash only when debugging is enabled.
    palm_pt = (int(round(palm_xy[0])), int(round(palm_xy[1])))
    cv2.circle(overlay, palm_pt, 20, (0, 255, 0), -1)
    cv2.circle(overlay, palm_pt, 20, (0, 0, 0), 3)

    if negative_points:
        for nx, ny in negative_points:
            neg_pt = (int(nx), int(ny))
            cv2.circle(overlay, neg_pt, 20, (0, 0, 255), -1)
            cv2.circle(overlay, neg_pt, 20, (0, 0, 0), 3)

    # Caption is drawn twice: a thick white pass as an outline, then a thin
    # coloured pass on top.
    label = f"SAM hand score={best_score:.2f} {infer_time:.1f}s"
    cv2.putText(overlay, label, (30, 60), cv2.FONT_HERSHEY_SIMPLEX, 1.1,
                (255, 255, 255), 5, cv2.LINE_AA)
    cv2.putText(overlay, label, (30, 60), cv2.FONT_HERSHEY_SIMPLEX, 1.1,
                (0, 255, 255), 2, cv2.LINE_AA)
    observer.save_stage("02_sam_hand_overlay", overlay)


def segment_hand_sam(
    image_bgr: np.ndarray,
    palm_xy: Tuple[int, int],
    negative_points: Optional[List[Tuple[int, int]]] = None,
    debug_dir: Optional[str] = None,
) -> Optional[np.ndarray]:
    """Return a pixel-accurate bool hand mask (H x W) via SAM 2.1 Tiny.

    The image is downscaled for inference, the prompt points are mapped into
    the downscaled space, and the best-scoring candidate's logits are
    upsampled to the original resolution before thresholding at 0.

    Args:
        image_bgr: Full-resolution BGR image in the canonical orientation.
        palm_xy: (x, y) pixel coordinates of the palm center (positive prompt),
            expressed in `image_bgr` coordinates.
        negative_points: Optional list of (x, y) points, in `image_bgr`
            coordinates, to steer SAM away from non-hand regions
            (e.g., credit card center).
        debug_dir: Optional directory to save mask + overlay for inspection.

    Returns:
        Bool mask with shape `image_bgr.shape[:2]`. This implementation has no
        code path that returns None: errors raised while loading the model or
        running inference propagate to the caller. The Optional return
        annotation is kept for backward compatibility with callers that
        check for None.
    """
    import torch
    import torch.nn.functional as F
    from PIL import Image as PILImage

    h_full, w_full = image_bgr.shape[:2]
    scaled_bgr, scale_back = _downscale(image_bgr)
    pil = PILImage.fromarray(cv2.cvtColor(scaled_bgr, cv2.COLOR_BGR2RGB))

    # Map prompt points into the scaled image space.
    scale_down = 1.0 / scale_back  # original -> scaled

    def _to_scaled(x: float, y: float) -> List[int]:
        return [int(round(x * scale_down)), int(round(y * scale_down))]

    # Label 1 marks the positive (palm) prompt, label 0 the negative prompts.
    prompt_points = [_to_scaled(palm_xy[0], palm_xy[1])]
    prompt_labels = [1]
    if negative_points:
        for nx, ny in negative_points:
            prompt_points.append(_to_scaled(nx, ny))
            prompt_labels.append(0)

    model, processor = get_sam2()

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments.
    t0 = time.perf_counter()
    inputs = processor(
        images=pil,
        input_points=[[prompt_points]],
        input_labels=[[prompt_labels]],
        return_tensors="pt",
    )
    with torch.inference_mode():
        outputs = model(**inputs, multimask_output=True)

    # Use the raw low-resolution logits and bilinearly upsample them to the
    # full resolution before thresholding. Going through post_process_masks +
    # cv2.INTER_NEAREST binarizes at the scaled resolution and then blows
    # the hard mask up ~4x, which produces visible staircase edges on the
    # finger boundaries (see script/experiment_sam_mask_quality.py).
    pred_masks = outputs.pred_masks.cpu()  # (1, 1, num_cands, H_low, W_low)
    scores = outputs.iou_scores.cpu().numpy()[0, 0]
    best_idx = int(np.argmax(scores))
    best_score = float(scores[best_idx])

    logits_best = pred_masks[0, 0, best_idx].to(torch.float32)
    upsampled = F.interpolate(
        logits_best.unsqueeze(0).unsqueeze(0),  # add batch + channel dims
        size=(h_full, w_full),
        mode="bilinear",
        align_corners=False,
    )[0, 0].numpy()
    mask_full = upsampled > 0.0  # logit > 0 counts as foreground

    infer_time = time.perf_counter() - t0
    logger.info(
        "SAM hand mask: score=%.3f time=%.1fs area=%dpx",
        best_score, infer_time, int(mask_full.sum()),
    )

    if debug_dir:
        _save_hand_debug_overlay(
            debug_dir, image_bgr, mask_full, palm_xy, negative_points,
            best_score, infer_time,
        )

    return mask_full
def palm_center_from_landmarks(landmarks_px: np.ndarray) -> Tuple[int, int]:
    """Compute the palm center in pixels from the 21 MediaPipe hand landmarks.

    The center is the average of the first two columns (x, y) of five rows:
    the wrist (index 0) and the four MCP landmarks (indices 5, 9, 13, 17).
    Any further columns of `landmarks_px` are ignored.

    Returns:
        (x, y) rounded to the nearest integers.
    """
    wrist = 0
    mcp_joints = (5, 9, 13, 17)

    anchors = landmarks_px[[wrist, *mcp_joints], :2]
    cx, cy = anchors.mean(axis=0)
    return int(round(cx)), int(round(cy))
|