Spaces:
Running
Running
| """ | |
| SAM 2.1-based hand segmentation. | |
| Produces a pixel-accurate hand mask using Meta's Segment Anything 2.1 | |
| (Hiera Tiny) via HuggingFace transformers, seeded by a positive point | |
| prompt at the palm center (derived from MediaPipe landmarks). Optional | |
| negative points can steer SAM away from the credit card. | |
| This replaces the synthetic convex-hull "mask" produced by | |
| `finger_segmentation._create_hand_mask()`, which is built from the | |
| 21 hand landmarks and does not follow the true hand contour. | |
| Prompt-based inference: ~0.6s per call on CPU (vs ~18s for AMG). | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import time | |
| from typing import List, Optional, Tuple | |
| import cv2 | |
| import numpy as np | |
| from .debug_observer import DebugObserver | |
| from .sam_backend import INFERENCE_MAX_SIDE, get_sam2 | |
# Module-level logger, named after this module per the standard logging convention.
logger: logging.Logger = logging.getLogger(__name__)
| def _downscale(image_bgr: np.ndarray) -> Tuple[np.ndarray, float]: | |
| """Downscale so the long side is INFERENCE_MAX_SIDE. Returns (scaled, scale_back). | |
| `scale_back` is the factor to multiply scaled coords by to get original coords. | |
| """ | |
| h, w = image_bgr.shape[:2] | |
| long_side = max(h, w) | |
| if long_side <= INFERENCE_MAX_SIDE: | |
| return image_bgr, 1.0 | |
| scale = INFERENCE_MAX_SIDE / long_side | |
| new_w = int(round(w * scale)) | |
| new_h = int(round(h * scale)) | |
| return cv2.resize(image_bgr, (new_w, new_h), interpolation=cv2.INTER_AREA), 1.0 / scale | |
def _sam_prompt_arrays(
    palm_xy: Tuple[int, int],
    negative_points: Optional[List[Tuple[int, int]]],
    scale_down: float,
) -> Tuple[List[List[int]], List[int]]:
    """Map the palm (positive) and negative prompts into scaled-image pixels.

    Returns ``(points, labels)``. The palm point comes first with label 1,
    followed by each negative point with label 0.
    """
    points = [[int(round(palm_xy[0] * scale_down)), int(round(palm_xy[1] * scale_down))]]
    labels = [1]  # positive prompt: the palm
    for nx, ny in negative_points or ():
        points.append([int(round(nx * scale_down)), int(round(ny * scale_down))])
        labels.append(0)  # negative prompt: steer SAM away from this point
    return points, labels


def _save_sam_hand_debug(
    debug_dir: str,
    image_bgr: np.ndarray,
    mask_full: np.ndarray,
    palm_xy: Tuple[int, int],
    negative_points: Optional[List[Tuple[int, int]]],
    best_score: float,
    infer_time: float,
) -> None:
    """Save the binary mask and an annotated overlay through DebugObserver."""
    observer = DebugObserver(debug_dir)
    mask_u8 = mask_full.astype(np.uint8) * 255
    observer.save_stage("01_sam_hand_mask", mask_u8)

    # Yellow tint over the mask plus its outer contour.
    overlay = image_bgr.copy()
    tint = np.zeros_like(overlay)
    tint[mask_full] = (0, 255, 255)
    overlay = cv2.addWeighted(overlay, 1.0, tint, 0.35, 0)
    contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    cv2.drawContours(overlay, contours, -1, (0, 255, 255), 3, cv2.LINE_AA)

    # Prompt markers: green = palm (positive), red = negatives. cv2 drawing
    # needs plain-int centers, so cast both kinds of point the same way.
    markers = [((int(palm_xy[0]), int(palm_xy[1])), (0, 255, 0))]
    markers += [((int(nx), int(ny)), (0, 0, 255)) for nx, ny in negative_points or ()]
    for center, color in markers:
        cv2.circle(overlay, center, 20, color, -1)
        cv2.circle(overlay, center, 20, (0, 0, 0), 3)

    # White outline under yellow text keeps the label readable on any background.
    label = f"SAM hand score={best_score:.2f} {infer_time:.1f}s"
    cv2.putText(overlay, label, (30, 60), cv2.FONT_HERSHEY_SIMPLEX, 1.1,
                (255, 255, 255), 5, cv2.LINE_AA)
    cv2.putText(overlay, label, (30, 60), cv2.FONT_HERSHEY_SIMPLEX, 1.1,
                (0, 255, 255), 2, cv2.LINE_AA)
    observer.save_stage("02_sam_hand_overlay", overlay)


def segment_hand_sam(
    image_bgr: np.ndarray,
    palm_xy: Tuple[int, int],
    negative_points: Optional[List[Tuple[int, int]]] = None,
    debug_dir: Optional[str] = None,
) -> Optional[np.ndarray]:
    """Return a pixel-accurate bool hand mask (H x W) via SAM 2.1 Tiny.

    Args:
        image_bgr: Full-resolution BGR image in the canonical orientation.
        palm_xy: (x, y) pixel coordinates of the palm center (positive prompt).
        negative_points: Optional list of (x, y) points to steer SAM away from
            non-hand regions (e.g., credit card center).
        debug_dir: Optional directory to save mask + overlay for inspection.

    Returns:
        Bool mask with the same height/width as ``image_bgr``. Returns None if
        the SAM model fails to load or run (the error is logged), or if the
        predicted mask is empty.
    """
    import torch
    import torch.nn.functional as F
    from PIL import Image as PILImage

    h_full, w_full = image_bgr.shape[:2]
    scaled_bgr, scale_back = _downscale(image_bgr)
    pil = PILImage.fromarray(cv2.cvtColor(scaled_bgr, cv2.COLOR_BGR2RGB))

    # Prompts are given in the coordinate space of the image passed to the processor.
    prompt_points, prompt_labels = _sam_prompt_arrays(
        palm_xy, negative_points, 1.0 / scale_back
    )

    try:
        model, processor = get_sam2()
        t0 = time.time()  # timing excludes model loading, as before
        inputs = processor(
            images=pil,
            input_points=[[prompt_points]],
            input_labels=[[prompt_labels]],
            return_tensors="pt",
        )
        with torch.inference_mode():
            outputs = model(**inputs, multimask_output=True)
    except Exception:
        # Documented contract: None on failure, so callers can fall back.
        logger.exception("SAM hand segmentation failed")
        return None

    # Use the raw 256x256 logits and bilinearly upsample them to the full
    # resolution before thresholding. Going through post_process_masks +
    # cv2.INTER_NEAREST binarizes at the scaled resolution and then blows
    # the hard mask up ~4x, which produces visible staircase edges on the
    # finger boundaries (see script/experiment_sam_mask_quality.py).
    pred_masks = outputs.pred_masks.cpu()  # (1, 1, num_cands, H_low, W_low)
    scores = outputs.iou_scores.cpu().numpy()[0, 0]
    best_idx = int(np.argmax(scores))
    best_score = float(scores[best_idx])
    logits_best = pred_masks[0, 0, best_idx].to(torch.float32)
    upsampled = F.interpolate(
        logits_best[None, None],
        size=(h_full, w_full),
        mode="bilinear",
        align_corners=False,
    )[0, 0].numpy()
    mask_full = upsampled > 0.0  # logit > 0 <=> foreground
    infer_time = time.time() - t0

    area = int(mask_full.sum())
    logger.info(
        "SAM hand mask: score=%.3f time=%.1fs area=%dpx",
        best_score, infer_time, area,
    )

    if debug_dir:
        _save_sam_hand_debug(
            debug_dir, image_bgr, mask_full, palm_xy, negative_points,
            best_score, infer_time,
        )

    if area == 0:
        logger.warning("SAM hand mask is empty; returning None")
        return None
    return mask_full
def palm_center_from_landmarks(landmarks_px: np.ndarray) -> Tuple[int, int]:
    """Return the (x, y) pixel coordinate of the palm center.

    The palm center is the mean of the wrist (landmark 0) and the four MCP
    joints (landmarks 5, 9, 13, 17) of the 21 MediaPipe hand landmarks,
    rounded to the nearest integer pixel.

    Args:
        landmarks_px: Per-landmark pixel coordinates, one row per landmark,
            with x and y in the first two columns (extra columns are ignored).
            Any array-like (e.g. a list of tuples) is accepted.

    Returns:
        ``(x, y)`` as plain Python ints.
    """
    pts = np.asarray(landmarks_px, dtype=np.float64)
    palm = pts[[0, 5, 9, 13, 17], :2]
    cx, cy = palm.mean(axis=0)
    return int(round(cx)), int(round(cy))