Spaces:
Running
Running
| """ | |
| Experiment: isolate the cause of the 'staircase' edges on our SAM 2.1 hand mask. | |
| Runs five configurations of SAM 2.1 Hiera Small against a single sample image, | |
| seeded by the MediaPipe palm center. All five variants feed the *same* pixel | |
| prompts to the *same* model; only the inference input resolution and the | |
| mask-upscale strategy differ. | |
| Configurations: | |
| A baseline 1024-long-side hard-mask + INTER_NEAREST upscale (current) | |
| B lin_hard 1024-long-side hard-mask + INTER_LINEAR + rethresh | |
| C soft 1024-long-side raw logits -> bilinear -> threshold full-res | |
| D native_nn native hard-mask + INTER_NEAREST (no upscale needed) | |
| E native_sft native raw logits -> bilinear -> threshold | |
| For each variant the script saves: | |
| - full-resolution binary mask PNG | |
| - hand overlay with yellow mask + green palm prompt dot | |
| - 600x600 fingertip crop centered on the middle-finger tip (landmark 12) | |
| so the staircase vs smooth comparison is visible at a glance | |
| It also prints a small table: | |
| - perimeter (px) cv2.arcLength of the largest contour | |
| - iso ratio perimeter / sqrt(area) -- higher = more jagged | |
| - rel. to baseline (%) iso ratio relative to config A | |
| Usage: | |
| python script/experiment_sam_mask_quality.py \\ | |
| --input input/sample-04-12/card_2.jpg \\ | |
| --output-dir output/sam_mask_quality | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import cv2 | |
| import numpy as np | |
| # Add repo root so we can import src.* | |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | |
| from src.finger_segmentation import segment_hand | |
| from src.sam_backend import get_sam2 | |
| from src.sam_hand_segmentation import palm_center_from_landmarks | |
| # ---------------------------------------------------------------------------- | |
| # SAM inference variants | |
| # ---------------------------------------------------------------------------- | |
def _run_sam(
    image_bgr: np.ndarray,
    palm_xy: Tuple[int, int],
    inference_long_side: Optional[int],
    upscale_mode: str,
) -> Tuple[np.ndarray, float, float]:
    """Run SAM 2.1 with controlled inference resolution and upscale path.

    Args:
        image_bgr: Full-resolution BGR image (canonical orientation).
        palm_xy: (x, y) pixel coords of palm center in the full-res image.
        inference_long_side: If set, downscale so long-side equals this value.
            An image whose long side is already at or below this value is fed
            unchanged. If None, feed native resolution.
        upscale_mode: One of:
            - "nearest_hard": post_process_masks -> INTER_NEAREST to full res.
            - "linear_hard": post_process_masks -> INTER_LINEAR -> re-threshold.
            - "soft": raw pred_masks (256x256) -> bilinear to full res
              -> threshold at 0.0.

    Returns:
        (mask_full: bool HxW, iou_score: float, infer_seconds: float).
        infer_seconds is wall-clock time from just before the processor call
        to the end of mask upscaling; the get_sam2() call happens before the
        timer starts.

    Raises:
        ValueError: If upscale_mode is not one of the three values above.
    """
    import torch
    import torch.nn.functional as F
    from PIL import Image as PILImage

    # Validate before any expensive work. Checking only inside the resize
    # branch would let an unknown mode pass silently whenever the mask already
    # has full-resolution shape.
    if upscale_mode not in ("nearest_hard", "linear_hard", "soft"):
        raise ValueError(f"unknown upscale_mode: {upscale_mode}")

    h_full, w_full = image_bgr.shape[:2]
    long_side = max(h_full, w_full)
    if inference_long_side is None or long_side <= inference_long_side:
        scaled_bgr = image_bgr
        scale_down = 1.0
    else:
        scale_down = inference_long_side / long_side
        # max(1, ...) keeps extreme aspect ratios from rounding a side to 0.
        new_w = max(1, int(round(w_full * scale_down)))
        new_h = max(1, int(round(h_full * scale_down)))
        scaled_bgr = cv2.resize(
            image_bgr, (new_w, new_h), interpolation=cv2.INTER_AREA
        )

    scaled_rgb = cv2.cvtColor(scaled_bgr, cv2.COLOR_BGR2RGB)
    pil = PILImage.fromarray(scaled_rgb)

    # Map the prompt into the inference image and clamp it: rounding a point
    # near the right/bottom border can otherwise land one pixel outside.
    h_scaled, w_scaled = scaled_bgr.shape[:2]
    palm_scaled = [
        min(max(int(round(palm_xy[0] * scale_down)), 0), w_scaled - 1),
        min(max(int(round(palm_xy[1] * scale_down)), 0), h_scaled - 1),
    ]

    model, processor = get_sam2()

    # perf_counter is monotonic, so the interval cannot be skewed by
    # wall-clock adjustments.
    t0 = time.perf_counter()
    inputs = processor(
        images=pil,
        input_points=[[[palm_scaled]]],
        input_labels=[[[1]]],
        return_tensors="pt",
    )
    with torch.inference_mode():
        outputs = model(**inputs, multimask_output=True)

    pred_masks = outputs.pred_masks.cpu()  # (1, 1, num_cands, H_low, W_low)
    iou_scores = outputs.iou_scores.cpu().numpy()[0, 0]
    # Keep the candidate the model itself scores highest.
    best_idx = int(np.argmax(iou_scores))
    best_score = float(iou_scores[best_idx])

    if upscale_mode == "soft":
        # Upsample the raw logits and threshold only at full resolution.
        logits = pred_masks[0, 0, best_idx].to(torch.float32)  # (H_low, W_low)
        upsampled = F.interpolate(
            logits.unsqueeze(0).unsqueeze(0),
            size=(h_full, w_full),
            mode="bilinear",
            align_corners=False,
        )[0, 0].numpy()
        mask_full = upsampled > 0.0
    else:
        # Reuse the CPU copy made above instead of copying pred_masks again.
        masks_scaled = processor.post_process_masks(
            pred_masks,
            inputs["original_sizes"],
            mask_threshold=0.0,
        )[0][0]
        hard_scaled = masks_scaled[best_idx].numpy()  # scaled-res hard mask
        if hard_scaled.shape == (h_full, w_full):
            # Already full resolution (native variants): nothing to upscale.
            mask_full = hard_scaled.astype(bool)
        elif upscale_mode == "linear_hard":
            # Interpolate in float and cut at 0.5. Interpolating a 0/1 uint8
            # mask would leave the threshold to integer rounding inside
            # cv2.resize.
            blended = cv2.resize(
                hard_scaled.astype(np.float32),
                (w_full, h_full),
                interpolation=cv2.INTER_LINEAR,
            )
            mask_full = blended >= 0.5
        else:  # "nearest_hard"
            resized = cv2.resize(
                hard_scaled.astype(np.uint8),
                (w_full, h_full),
                interpolation=cv2.INTER_NEAREST,
            )
            mask_full = resized.astype(bool)

    return mask_full, best_score, time.perf_counter() - t0
| # ---------------------------------------------------------------------------- | |
| # Metrics + visualization helpers | |
| # ---------------------------------------------------------------------------- | |
def _roughness_metrics(mask: np.ndarray) -> Dict[str, float]:
    """Measure the outline of the largest external contour in *mask*.

    Returns a dict with keys "perimeter_px" (cv2.arcLength of the closed
    contour), "area_px" (cv2.contourArea) and "iso_ratio"
    (perimeter / sqrt(area), with a 1e-9 term added to the denominator so a
    zero-area contour does not divide by zero). When no contour is found the
    perimeter and area are 0.0 and the ratio is NaN.
    """
    binary = mask.astype(np.uint8) * 255
    found, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if len(found) == 0:
        return {"perimeter_px": 0.0, "area_px": 0.0, "iso_ratio": float("nan")}

    # Only the biggest blob is measured; smaller external contours are ignored.
    biggest = max(found, key=cv2.contourArea)
    area_px = float(cv2.contourArea(biggest))
    perimeter_px = float(cv2.arcLength(biggest, closed=True))
    return {
        "perimeter_px": perimeter_px,
        "area_px": area_px,
        "iso_ratio": perimeter_px / (np.sqrt(area_px) + 1e-9),
    }
def _save_overlay(
    path: Path,
    image_bgr: np.ndarray,
    mask: np.ndarray,
    palm_xy: Tuple[int, int],
    label: str,
) -> None:
    """Write an annotated copy of *image_bgr* to *path*.

    The copy gets a yellow tint (weight 0.35) over the masked pixels, yellow
    outlines of the mask's external contours, a green dot with a black ring
    at *palm_xy*, and *label* drawn in yellow over a white stroke near the
    top-left corner. *image_bgr* itself is not modified.
    """
    yellow = (0, 255, 255)

    # Additive tint: only masked pixels of the tint layer are non-zero.
    tint_layer = np.zeros_like(image_bgr)
    tint_layer[mask] = yellow
    canvas = cv2.addWeighted(image_bgr.copy(), 1.0, tint_layer, 0.35, 0)

    mask_u8 = mask.astype(np.uint8) * 255
    outlines, _ = cv2.findContours(
        mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    cv2.drawContours(canvas, outlines, -1, yellow, 2, cv2.LINE_AA)

    # Prompt marker: filled green disc (thickness -1), then a black ring.
    for colour, thickness in (((0, 255, 0), -1), ((0, 0, 0), 3)):
        cv2.circle(canvas, palm_xy, 18, colour, thickness)

    # Two text passes: a wide white stroke first, the yellow glyphs on top.
    for colour, thickness in (((255, 255, 255), 5), (yellow, 2)):
        cv2.putText(canvas, label, (30, 60), cv2.FONT_HERSHEY_SIMPLEX, 1.1,
                    colour, thickness, cv2.LINE_AA)

    cv2.imwrite(str(path), canvas)
def _save_fingertip_crop(
    path: Path,
    image_bgr: np.ndarray,
    mask: np.ndarray,
    center_xy: Tuple[int, int],
    crop_half: int = 300,
    label: str = "",
) -> None:
    """Write a tinted, outlined crop of *image_bgr* around *center_xy*.

    The window extends *crop_half* pixels on each side of the centre and is
    clipped with max(0, ...) / min(size, ...), so it is smaller than
    2 * crop_half when the centre lies near a border. Masked pixels get a
    yellow tint (weight 0.4) and the mask's external contours inside the
    window are outlined in yellow. A non-empty *label* is drawn in yellow over
    a white stroke near the top-left corner.
    """
    img_h, img_w = image_bgr.shape[:2]
    center_x, center_y = center_xy
    left, right = max(0, center_x - crop_half), min(img_w, center_x + crop_half)
    top, bottom = max(0, center_y - crop_half), min(img_h, center_y + crop_half)

    window = image_bgr[top:bottom, left:right].copy()
    window_mask = mask[top:bottom, left:right]
    yellow = (0, 255, 255)

    tint_layer = np.zeros_like(window)
    tint_layer[window_mask] = yellow
    window = cv2.addWeighted(window, 1.0, tint_layer, 0.4, 0)

    # Contours are traced on the cropped mask, so they are in crop coordinates.
    window_mask_u8 = window_mask.astype(np.uint8) * 255
    outlines, _ = cv2.findContours(
        window_mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    cv2.drawContours(window, outlines, -1, yellow, 2, cv2.LINE_AA)

    if label:
        # Wide white stroke first, then the yellow glyphs on top of it.
        for colour, thickness in (((255, 255, 255), 4), (yellow, 2)):
            cv2.putText(window, label, (15, 35), cv2.FONT_HERSHEY_SIMPLEX, 0.9,
                        colour, thickness, cv2.LINE_AA)

    cv2.imwrite(str(path), window)
| # ---------------------------------------------------------------------------- | |
| # Main | |
| # ---------------------------------------------------------------------------- | |
def main() -> int:
    """Run every SAM variant on one image, save the artefacts, print a table.

    Reads --input, obtains the canonical image and landmarks from
    segment_hand(), runs the five configurations through _run_sam(), and for
    each writes a mask PNG, an overlay and a fingertip crop into --output-dir.
    Finally prints a summary table and writes a side-by-side fingertip strip.

    Returns:
        0 on success; 2 if the input image cannot be read or no hand is
        detected (the error message goes to stderr).
    """
    # RawDescriptionHelpFormatter keeps the module docstring's line breaks, so
    # the configuration table and usage example stay readable in --help.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--input", default="input/sample-04-12/card_2.jpg")
    parser.add_argument("--output-dir", default="output/sam_mask_quality")
    args = parser.parse_args()

    in_path = Path(args.input)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    image = cv2.imread(str(in_path))
    if image is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        print(f"ERROR: could not read {in_path}", file=sys.stderr)
        return 2
    print(f"Input: {in_path} ({image.shape[1]}x{image.shape[0]})")

    # Get canonical image + landmarks WITHOUT running SAM. Pass a large
    # max_dimension so MediaPipe's internal resize does not happen and the
    # canonical image stays at native resolution.
    hand_data = segment_hand(
        image=image,
        finger="middle",
        max_dimension=3000,
        debug_dir=None,
        use_sam_mask=False,
    )
    if hand_data is None:
        print("ERROR: hand detection failed", file=sys.stderr)
        return 2

    canonical = hand_data["canonical_image"]
    landmarks = hand_data["landmarks"]  # (21, 2) in canonical px coords
    palm_xy = palm_center_from_landmarks(landmarks)
    # Landmark 12 is the middle-finger tip; it centres the comparison crops.
    middle_tip_xy = (int(round(landmarks[12, 0])), int(round(landmarks[12, 1])))
    ch, cw = canonical.shape[:2]
    print(f"Canonical: {cw}x{ch} palm=({palm_xy[0]},{palm_xy[1]}) "
          f"middle_tip=({middle_tip_xy[0]},{middle_tip_xy[1]})")

    # Save canonical reference image
    cv2.imwrite(str(out_dir / "00_canonical.png"), canonical)

    # (output name, inference long side or None for native, upscale mode)
    variants = [
        ("A_baseline_1024_nn", 1024, "nearest_hard"),
        ("B_1024_linear_hard", 1024, "linear_hard"),
        ("C_1024_soft", 1024, "soft"),
        ("D_native_nn", None, "nearest_hard"),
        ("E_native_soft", None, "soft"),
    ]

    results = []
    for name, long_side, mode in variants:
        print(f"\n=== {name} long_side={long_side} mode={mode} ===")
        mask, score, secs = _run_sam(canonical, palm_xy, long_side, mode)
        m = _roughness_metrics(mask)
        print(f" iou={score:.3f} time={secs:.2f}s "
              f"perim={m['perimeter_px']:.0f}px iso={m['iso_ratio']:.3f}")

        cv2.imwrite(str(out_dir / f"{name}_mask.png"),
                    (mask.astype(np.uint8)) * 255)
        _save_overlay(
            out_dir / f"{name}_overlay.png",
            canonical, mask, palm_xy,
            label=f"{name} iou={score:.2f}",
        )
        _save_fingertip_crop(
            out_dir / f"{name}_fingertip.png",
            canonical, mask, middle_tip_xy,
            crop_half=300,
            label=name,
        )
        results.append((name, score, secs, m["perimeter_px"], m["iso_ratio"]))

    # Summary table
    print("\n")
    print("=" * 78)
    print(f"{'variant':<22}{'iou':>8}{'time(s)':>10}"
          f"{'perim(px)':>14}{'iso':>10}{'vs A (%)':>12}")
    print("-" * 78)
    # The first variant (A) is the reference for the relative column.
    iso_base = results[0][4]
    for name, score, secs, perim, iso in results:
        rel = (iso / iso_base - 1.0) * 100.0 if iso_base else float("nan")
        print(f"{name:<22}{score:>8.3f}{secs:>10.2f}"
              f"{perim:>14.0f}{iso:>10.3f}{rel:>11.1f}%")
    print("=" * 78)

    # Side-by-side fingertip comparison, built from the crops written above.
    # Skipped if any crop file failed to load.
    crops = [cv2.imread(str(out_dir / f"{name}_fingertip.png")) for name, *_ in results]
    if all(c is not None for c in crops):
        panel = np.hstack(crops)
        cv2.imwrite(str(out_dir / "fingertip_comparison.png"), panel)
        print(f"\nFingertip comparison strip: {out_dir / 'fingertip_comparison.png'}")

    print(f"\nAll outputs saved to: {out_dir}/")
    return 0
if __name__ == "__main__":
    # Hand main()'s integer status to the interpreter as the exit code.
    sys.exit(main())