Commit b96e083 by feng-x · verified · 1 Parent(s): 4f1901d

Upload folder using huggingface_hub

AGENTS.md CHANGED
@@ -21,7 +21,8 @@ For tasks of **reboot** from a new codex session:
 1. Read doc/v0/PRD.md, doc/v0/Plan.md, doc/v0/Progress.md for baseline implementation
 2. Read doc/v1/PRD.md, doc/v1/Plan.md, doc/v1/Progress.md for edge refinement (v1)
 3. Read doc/v4/PRD.md, doc/v4/Plan.md, doc/v4/Progress.md for SAM 2.1 integration (card + hand)
-4. Assume this is a continuation of an existing project.
+4. Read doc/v5/PRD.md, doc/v5/Plan.md, doc/v5/Progress.md for the Gradio/ZeroGPU deployment port
+5. Assume this is a continuation of an existing project.
 5. Summarize your understanding of the current state and propose the next concrete step without writing code yet.
 
 ## Project Overview
CLAUDE.md CHANGED
@@ -21,7 +21,8 @@ For tasks of **reboot** from a new codex session:
 1. Read doc/v0/PRD.md, doc/v0/Plan.md, doc/v0/Progress.md for baseline implementation
 2. Read doc/v1/PRD.md, doc/v1/Plan.md, doc/v1/Progress.md for edge refinement (v1)
 3. Read doc/v4/PRD.md, doc/v4/Plan.md, doc/v4/Progress.md for SAM 2.1 integration (card + hand)
-4. Assume this is a continuation of an existing project.
+4. Read doc/v5/PRD.md, doc/v5/Plan.md, doc/v5/Progress.md for the Gradio/ZeroGPU deployment port
+5. Assume this is a continuation of an existing project.
 5. Summarize your understanding of the current state and propose the next concrete step without writing code yet.
 
 ## Project Overview
README.md CHANGED
@@ -3,8 +3,10 @@ title: Ring Sizer
 emoji: "\U0001F48D"
 colorFrom: blue
 colorTo: purple
-sdk: docker
-app_port: 7860
+sdk: gradio
+sdk_version: 4.44.0
+app_file: app.py
+python_version: "3.10"
 ---
 
 # Ring Sizer
app.py ADDED
@@ -0,0 +1,423 @@
+#!/usr/bin/env python3
+"""Gradio entry point for the Ring Sizer HuggingFace Space (v5).
+
+Public demo flow only: upload → measurement → result image + ring size
+summary + raw JSON. The Flask app in `web_demo/` is still used locally for
+admin / CSV / ground-truth editing, but HF Spaces now serves this Gradio
+app so the measurement call can run on ZeroGPU-backed H200 GPUs.
+
+See `doc/v5/` for the PRD, plan, and progress notes.
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+import sys
+import tempfile
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import cv2
+import numpy as np
+
+# `spaces` is a no-op outside HF ZeroGPU, so importing it unconditionally
+# keeps the local CPU path working without conditional imports.
+import spaces  # type: ignore
+import gradio as gr
+
+ROOT_DIR = Path(__file__).resolve().parent
+sys.path.insert(0, str(ROOT_DIR))
+
+from measure_finger import (  # noqa: E402
+    measure_finger,
+    measure_multi_finger,
+    apply_calibration,
+)
+from src.ring_size import (  # noqa: E402
+    recommend_ring_size,
+    VALID_RING_MODELS,
+    DEFAULT_RING_MODEL,
+)
+from src.ai_recommendation import ai_explain_recommendation  # noqa: E402
+from src.sam_backend import get_sam2  # noqa: E402
+
+# HF ZeroGPU docs: "models must be placed on cuda at the root module level"
+# (a PyTorch CUDA emulation mode is enabled outside @spaces.GPU functions,
+# so this runs cleanly both on ZeroGPU and CPU). Pre-loading here means the
+# first request does not pay the weight-to-GPU transfer cost.
+try:
+    get_sam2()
+except Exception as exc:  # noqa: BLE001
+    # Don't block app startup if SAM weights are missing - the measurement
+    # call will re-attempt and surface a clearer error to the user.
+    print(f"[v5] SAM preload skipped: {exc}")
+
+# Supabase persistence piggybacks on the same async executor pattern as the
+# Flask app so the GPU slice releases as soon as the measurement returns.
+try:
+    from web_demo.supabase_client import upload_file, save_measurement  # noqa: E402
+    _SUPABASE_AVAILABLE = True
+except Exception as exc:  # noqa: BLE001
+    print(f"[v5] Supabase client not importable ({exc}) - persistence disabled")
+    _SUPABASE_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+_persist_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="supa-persist")
+
+RESULTS_DIR = ROOT_DIR / "web_demo" / "results"
+UPLOADS_DIR = ROOT_DIR / "web_demo" / "uploads"
+RESULTS_DIR.mkdir(parents=True, exist_ok=True)
+UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
+
+DEMO_EDGE_METHOD = "mask"
+DEMO_CARD_METHOD = "sam"
+DEMO_HAND_MASK_METHOD = "sam"
+
+DEFAULT_SAMPLE_PATH = ROOT_DIR / "web_demo" / "static" / "examples" / "default_sample.jpg"
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _numpy_safe(obj: Any) -> Any:
+    """Recursively convert numpy scalar/array types to native Python types.
+
+    Gradio's JSON component calls `json.dumps` internally, which trips on
+    `np.float32`, `np.bool_`, and friends. This mirrors the helper already
+    used by the Flask app.
+    """
+    if isinstance(obj, dict):
+        return {k: _numpy_safe(v) for k, v in obj.items()}
+    if isinstance(obj, (list, tuple)):
+        return [_numpy_safe(v) for v in obj]
+    if isinstance(obj, np.bool_):
+        return bool(obj)
+    if isinstance(obj, np.integer):
+        return int(obj)
+    if isinstance(obj, np.floating):
+        return float(obj)
+    if isinstance(obj, np.ndarray):
+        return obj.tolist()
+    if isinstance(obj, np.generic):
+        return obj.item()
+    return obj
+
+
+def _make_base_name(kol_name: str) -> Tuple[str, str]:
+    run_id = uuid.uuid4().hex[:8]
+    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+    slug = "".join(c if c.isalnum() else "-" for c in (kol_name or "anon")).strip("-").lower() or "anon"
+    return f"{slug}_{timestamp}_{run_id}", run_id
+
+
+def _persist_async(
+    *,
+    upload_path: Optional[Path],
+    upload_name: str,
+    result_png_path: Path,
+    result_png_name: str,
+    record: Dict[str, Any],
+) -> None:
+    """Fire-and-forget Supabase persistence (storage uploads + row insert).
+
+    Errors are logged, never raised - a broken Supabase connection must
+    never poison the measurement the user just saw.
+    """
+    if not _SUPABASE_AVAILABLE:
+        return
+
+    def _task() -> None:
+        try:
+            photo_url = None
+            result_url = None
+            if upload_path is not None and upload_path.exists():
+                photo_url = upload_file(str(upload_path), f"photos/{upload_name}")
+            if result_png_path.exists():
+                result_url = upload_file(str(result_png_path), f"results/{result_png_name}")
+            record_with_urls = dict(record)
+            record_with_urls["photo_url"] = photo_url
+            record_with_urls["result_url"] = result_url
+            save_measurement(record_with_urls)
+        except Exception as exc:  # noqa: BLE001
+            logger.exception("Supabase persist failed for run %s: %s",
+                             record.get("run_id"), exc)
+
+    _persist_executor.submit(_task)
+
+
+def _format_summary(result: Dict[str, Any], mode: str) -> str:
+    """Render a human-readable markdown summary above the raw JSON."""
+    if result.get("fail_reason"):
+        return f"**Measurement failed:** `{result['fail_reason']}`"
+
+    if mode == "multi":
+        lines = ["### Multi-finger result"]
+        for fn in ("index", "middle", "ring"):
+            pf = (result.get("per_finger") or {}).get(fn, {})
+            if pf.get("status") == "ok":
+                diam = pf.get("diameter_cm")
+                best = pf.get("best_match")
+                rng = pf.get("range", (None, None))
+                lines.append(
+                    f"- **{fn.capitalize()}:** {diam:.2f} cm → "
+                    f"size **{best}** (range {rng[0]}–{rng[1]})"
+                )
+            else:
+                lines.append(f"- **{fn.capitalize()}:** failed ({pf.get('fail_reason', 'unknown')})")
+        if result.get("overall_best_size") is not None:
+            lines.append("")
+            lines.append(
+                f"**Recommended size:** **{result['overall_best_size']}** "
+                f"(range {result.get('overall_range_min')}–{result.get('overall_range_max')})"
+            )
+        if result.get("ai_explanation"):
+            lines.append("")
+            lines.append(f"**Why:** {result['ai_explanation']}")
+        return "\n".join(lines)
+
+    # Single finger
+    diam = result.get("finger_outer_diameter_cm")
+    conf = result.get("confidence")
+    ring = result.get("ring_size") or {}
+    lines = ["### Single-finger result"]
+    if diam is not None:
+        lines.append(f"- **Diameter:** {diam:.2f} cm")
+    if result.get("raw_diameter_cm") is not None:
+        lines.append(f"- **Raw (uncalibrated):** {result['raw_diameter_cm']:.2f} cm")
+    if conf is not None:
+        lines.append(f"- **Confidence:** {conf:.2f}")
+    if ring:
+        lines.append(
+            f"- **Ring size:** **{ring.get('best_match')}** "
+            f"(range {ring.get('range_min')}–{ring.get('range_max')})"
+        )
+    return "\n".join(lines)
+
+
+# ---------------------------------------------------------------------------
+# Measurement handler
+# ---------------------------------------------------------------------------
+
+@spaces.GPU(duration=60)
+def run_measurement(
+    image: Optional[np.ndarray],
+    finger_index: str,
+    mode: str,
+    ring_model: str,
+    kol_name: str,
+    ai_explain: bool,
+) -> Tuple[Optional[np.ndarray], Dict[str, Any], str]:
+    """Run the measurement pipeline and return (overlay, json, summary).
+
+    Wrapped in `@spaces.GPU` so HF ZeroGPU allocates an H200 slice per
+    request. Outside ZeroGPU the decorator is a no-op and this runs on CPU.
+    """
+    if image is None:
+        return None, {"error": "No image uploaded"}, "**Error:** please upload an image."
+
+    if ring_model not in VALID_RING_MODELS:
+        ring_model = DEFAULT_RING_MODEL
+
+    # Gradio gives us an RGB numpy array; the rest of the pipeline expects BGR.
+    if image.ndim == 3 and image.shape[2] == 3:
+        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+    else:
+        image_bgr = image
+
+    base_name, run_id = _make_base_name(kol_name)
+    result_png_name = f"{base_name}_result.png"
+    result_png_path = RESULTS_DIR / result_png_name
+
+    # Also save the raw upload so Supabase persistence has something to push.
+    upload_name = f"{base_name}.jpg"
+    upload_path = UPLOADS_DIR / upload_name
+    cv2.imwrite(str(upload_path), image_bgr)
+
+    if mode == "multi":
+        result = measure_multi_finger(
+            image=image_bgr,
+            edge_method=DEMO_EDGE_METHOD,
+            card_method=DEMO_CARD_METHOD,
+            hand_mask_method=DEMO_HAND_MASK_METHOD,
+            result_png_path=str(result_png_path),
+            save_debug=False,
+            no_calibration=False,
+            ring_model=ring_model,
+        )
+        result = _numpy_safe(result)
+
+        per_finger = result.get("per_finger", {})
+        finger_widths = {
+            fn: (pf.get("diameter_cm") if pf.get("status") == "ok" else None)
+            for fn, pf in per_finger.items()
+        }
+        if ai_explain and result.get("overall_best_size") is not None:
+            ai_reason = ai_explain_recommendation(
+                finger_widths,
+                recommended_size=result["overall_best_size"],
+                range_min=result["overall_range_min"],
+                range_max=result["overall_range_max"],
+                ring_model=ring_model,
+            )
+            if ai_reason:
+                result["ai_explanation"] = ai_reason
+
+        # Persist async (release GPU slice first - this runs on CPU after return)
+        confidences = [
+            pf.get("confidence") for pf in per_finger.values()
+            if pf.get("status") == "ok" and pf.get("confidence") is not None
+        ]
+        overall_confidence = min(confidences) if confidences else None
+        _persist_async(
+            upload_path=upload_path,
+            upload_name=upload_name,
+            result_png_path=result_png_path,
+            result_png_name=result_png_name,
+            record={
+                "run_id": run_id,
+                "kol_name": kol_name,
+                "mode": "multi",
+                "ring_model": ring_model,
+                "overall_best_size": result.get("overall_best_size"),
+                "overall_range_min": result.get("overall_range_min"),
+                "overall_range_max": result.get("overall_range_max"),
+                "per_finger": per_finger,
+                "confidence": overall_confidence,
+                "result_json": result,
+                "fail_reason": result.get("fail_reason"),
+            },
+        )
+    else:
+        result = measure_finger(
+            image=image_bgr,
+            finger_index=finger_index,
+            edge_method=DEMO_EDGE_METHOD,
+            card_method=DEMO_CARD_METHOD,
+            hand_mask_method=DEMO_HAND_MASK_METHOD,
+            result_png_path=str(result_png_path),
+            save_debug=False,
+            ring_model=ring_model,
+        )
+
+        raw_diameter = result.get("finger_outer_diameter_cm")
+        if raw_diameter is not None:
+            result["raw_diameter_cm"] = round(raw_diameter, 4)
+            calibrated = round(apply_calibration(raw_diameter), 4)
+            result["finger_outer_diameter_cm"] = calibrated
+            result["calibration_applied"] = True
+            rec = recommend_ring_size(calibrated, ring_model=ring_model)
+            if rec:
+                result["ring_size"] = rec
+
+        result = _numpy_safe(result)
+        ring_size = result.get("ring_size", {}) or {}
+        _persist_async(
+            upload_path=upload_path,
+            upload_name=upload_name,
+            result_png_path=result_png_path,
+            result_png_name=result_png_name,
+            record={
+                "run_id": run_id,
+                "kol_name": kol_name,
+                "mode": "single",
+                "ring_model": ring_model,
+                "finger_index": finger_index,
+                "diameter_cm": result.get("finger_outer_diameter_cm"),
+                "confidence": result.get("confidence"),
+                "overall_best_size": ring_size.get("best_match"),
+                "overall_range_min": ring_size.get("range_min"),
+                "overall_range_max": ring_size.get("range_max"),
+                "result_json": result,
+                "fail_reason": result.get("fail_reason"),
+            },
+        )
+
+    # Load the overlay image Gradio will display.
+    overlay_rgb: Optional[np.ndarray] = None
+    if result_png_path.exists():
+        overlay_bgr = cv2.imread(str(result_png_path))
+        if overlay_bgr is not None:
+            overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
+
+    summary = _format_summary(result, mode)
+    return overlay_rgb, result, summary
+
+
+# ---------------------------------------------------------------------------
+# UI
+# ---------------------------------------------------------------------------
+
+_DESCRIPTION = """
+Upload a single photo with **one hand and a credit card on the same flat
+surface**. The app detects the card (for scale), segments the hand, and
+measures the outer diameter of the chosen finger at the ring-wearing zone.
+"""
+
+_EXAMPLES: List[List[Any]] = []
+if DEFAULT_SAMPLE_PATH.exists():
+    _EXAMPLES.append([str(DEFAULT_SAMPLE_PATH), "index", "single", DEFAULT_RING_MODEL, "", False])
+
+
+def build_demo() -> gr.Blocks:
+    with gr.Blocks(title="Ring Sizer") as demo:
+        gr.Markdown("# 💍 Ring Sizer")
+        gr.Markdown(_DESCRIPTION)
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                image_in = gr.Image(
+                    type="numpy",
+                    label="Hand + credit card photo",
+                    sources=["upload", "webcam"],
+                )
+                finger_in = gr.Dropdown(
+                    choices=["index", "middle", "ring"],
+                    value="index",
+                    label="Finger",
+                )
+                mode_in = gr.Radio(
+                    choices=["single", "multi"],
+                    value="single",
+                    label="Mode",
+                    info="`single` measures one finger; `multi` measures index + middle + ring and aggregates.",
+                )
+                ring_model_in = gr.Dropdown(
+                    choices=list(VALID_RING_MODELS),
+                    value=DEFAULT_RING_MODEL,
+                    label="Ring model",
+                )
+                kol_name_in = gr.Textbox(label="Name (optional)", placeholder="")
+                ai_explain_in = gr.Checkbox(label="Explain recommendation (AI)", value=False)
+                run_btn = gr.Button("Measure", variant="primary")
+
+            with gr.Column(scale=1):
+                image_out = gr.Image(label="Measurement overlay")
+                summary_out = gr.Markdown(label="Summary")
+                json_out = gr.JSON(label="Raw result")
+
+        run_btn.click(
+            fn=run_measurement,
+            inputs=[image_in, finger_in, mode_in, ring_model_in, kol_name_in, ai_explain_in],
+            outputs=[image_out, json_out, summary_out],
+        )
+
+        if _EXAMPLES:
+            gr.Examples(
+                examples=_EXAMPLES,
+                inputs=[image_in, finger_in, mode_in, ring_model_in, kol_name_in, ai_explain_in],
+                label="Try the default sample",
+            )
+
+    return demo
+
+
+demo = build_demo()
+
+
+if __name__ == "__main__":
+    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", "7860")))
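
Note on `_persist_async` in app.py: the "errors are logged, never raised" behaviour can be tried in isolation with the stdlib alone. The following is a minimal sketch, not the code from this commit. `submit_guarded`, `_broken_upload`, and the `persist-sketch` logger name are hypothetical, and the `Future` is returned only so the caller can wait on it (app.py discards it).

```python
import logging
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable

logger = logging.getLogger("persist-sketch")
_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="persist-sketch")


def submit_guarded(task: Callable[[], None], run_id: str) -> Future:
    """Run `task` on a worker thread; log any exception instead of raising it."""

    def _guarded() -> None:
        try:
            task()
        except Exception:  # noqa: BLE001 - deliberately broad, as in the diff
            # logger.exception records the traceback; nothing propagates.
            logger.exception("persist failed for run %s", run_id)

    return _executor.submit(_guarded)


def _broken_upload() -> None:
    # Hypothetical stand-in for a storage call that fails.
    raise ConnectionError("storage unreachable")


future = submit_guarded(_broken_upload, run_id="demo1234")
outcome = future.result(timeout=5)  # returns normally; the error was only logged
```

Because the guard sits inside the submitted callable, the worker's `Future` never carries the exception, so a caller that ignores the `Future` and one that waits on it see the same result.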
measure_finger.py CHANGED
@@ -12,9 +12,39 @@ Usage:
12
  import argparse
13
  import json
14
  import sys
 
 
15
  from pathlib import Path
16
  from typing import Optional, Dict, Any, List, Literal, Tuple
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  import cv2
19
  import numpy as np
20
 
@@ -580,8 +610,12 @@ def measure_finger(
580
  Returns:
581
  Output dictionary with measurement results
582
  """
 
 
 
583
  # Phase 2: Image quality metrics (informational only β€” no hard fail)
584
- quality = assess_image_quality(image)
 
585
  print(f"Image quality: blur={quality['blur_score']:.1f}, "
586
  f"brightness={quality['brightness']:.1f}, "
587
  f"contrast={quality['contrast']:.1f}")
@@ -596,12 +630,13 @@ def measure_finger(
596
  if save_debug and result_png_path is not None:
597
  finger_debug_dir = str(Path(result_png_path).parent / "finger_segmentation_debug")
598
 
599
- hand_data = segment_hand(
600
- image,
601
- finger=finger_index,
602
- debug_dir=finger_debug_dir,
603
- use_sam_mask=(hand_mask_method == "sam"),
604
- )
 
605
 
606
  if hand_data is None:
607
  print("No hand detected in image")
@@ -639,12 +674,13 @@ def measure_finger(
639
  view_angle_ok = True
640
  card_detected = False
641
  else:
642
- if card_method == "sam":
643
- card_result = _sam_card_detect(
644
- image_canonical, hand_data, save_debug, result_png_path
645
- )
646
- else:
647
- card_result = detect_credit_card(image_canonical, debug_dir=card_debug_dir)
 
648
 
649
  if card_result is None:
650
  print("Credit card not detected in image")
@@ -682,7 +718,8 @@ def measure_finger(
682
  # length and can cut into a wider-than-average finger, which would make
683
  # the mask boundary narrower than the true SAM boundary.
684
  raw_hand_mask = hand_data.get("mask")
685
- finger_data = isolate_finger(hand_data, finger=finger_index, image_shape=(h_can, w_can))
 
686
 
687
  if finger_data is None:
688
  print(f"Could not isolate finger: {finger_index}")
@@ -889,20 +926,21 @@ def measure_finger(
889
  else:
890
  edge_mask_input = cleaned_mask
891
 
892
- sobel_measurement = refine_edges_sobel(
893
- image=image_canonical, # Use canonical orientation
894
- axis_data=axis_data,
895
- zone_data=zone_data,
896
- scale_px_per_cm=px_per_cm,
897
- finger_landmarks=finger_data.get("landmarks"),
898
- sobel_threshold=sobel_threshold,
899
- kernel_size=sobel_kernel_size,
900
- use_subpixel=use_subpixel,
901
- finger_mask=edge_mask_input,
902
- debug_dir=edge_debug_dir,
903
- mask_mode=mask_mode,
904
- finger_name=finger_data.get("finger_name"),
905
- )
 
906
 
907
  sobel_width_cm = sobel_measurement["median_width_cm"]
908
  print(f"Edge width: {sobel_width_cm:.4f}cm "
@@ -1057,6 +1095,7 @@ def measure_finger(
1057
  print(f"Warning: Confidence {confidence_breakdown['overall']:.3f} is below threshold {confidence_threshold:.3f}")
1058
 
1059
  # Phase 9: Result visualization (always generated)
 
1060
  if result_png_path is not None:
1061
  print(f"Generating result visualization...")
1062
 
@@ -1141,6 +1180,14 @@ def measure_finger(
1141
  _save_debug_visualization(result_png_path, debug_image)
1142
  print(f"Result visualization saved to: {result_png_path}")
1143
 
 
 
 
 
 
 
 
 
1144
 
1145
  return create_output(
1146
  finger_diameter_cm=median_width_cm,
@@ -1410,8 +1457,12 @@ def measure_multi_finger(
1410
  """
1411
  from src.finger_segmentation import FINGER_LANDMARKS
1412
 
 
 
 
1413
  # Phase 1: Image quality metrics (informational only β€” no hard fail)
1414
- quality = assess_image_quality(image)
 
1415
  print(f"[multi] Image quality: blur={quality['blur_score']:.1f}, "
1416
  f"brightness={quality['brightness']:.1f}, contrast={quality['contrast']:.1f}")
1417
  if not quality["passed"]:
@@ -1428,12 +1479,13 @@ def measure_multi_finger(
1428
  if save_debug and result_png_path is not None:
1429
  finger_debug_dir = str(Path(result_png_path).parent / "finger_segmentation_debug")
1430
 
1431
- hand_data = segment_hand(
1432
- image,
1433
- finger="index",
1434
- debug_dir=finger_debug_dir,
1435
- use_sam_mask=(hand_mask_method == "sam"),
1436
- )
 
1437
  if hand_data is None:
1438
  print("[multi] No hand detected")
1439
  return {"fail_reason": "hand_not_detected", "per_finger": {}, "fingers_measured": 0, "fingers_succeeded": 0}
@@ -1453,12 +1505,13 @@ def measure_multi_finger(
1453
  view_angle_ok = True
1454
  card_detected = False
1455
  else:
1456
- if card_method == "sam":
1457
- card_result = _sam_card_detect(
1458
- image_canonical, hand_data, save_debug, result_png_path
1459
- )
1460
- else:
1461
- card_result = detect_credit_card(image_canonical, debug_dir=card_debug_dir)
 
1462
  if card_result is None:
1463
  # Emit a diagnostic visualization so the failure is debuggable:
1464
  # hand mask + card-prompt seeds on the canonical image. Without
@@ -1502,20 +1555,21 @@ def measure_multi_finger(
1502
  per_finger_raw: Dict[str, Dict] = {}
1503
  for fn in MULTI_FINGERS:
1504
  print(f"\n[multi] === Measuring {fn} finger ===")
1505
- result = _measure_single_finger_from_shared(
1506
- image_canonical=image_canonical,
1507
- hand_data=hand_data,
1508
- finger_name=fn,
1509
- px_per_cm=px_per_cm,
1510
- card_detected=card_detected,
1511
- view_angle_ok=view_angle_ok,
1512
- card_result=card_result,
1513
- scale_confidence=scale_confidence,
1514
- edge_method=edge_method,
1515
- sobel_threshold=sobel_threshold,
1516
- sobel_kernel_size=sobel_kernel_size,
1517
- use_subpixel=use_subpixel,
1518
- )
 
1519
 
1520
  # Apply calibration
1521
  raw_diam = result.get("finger_outer_diameter_cm")
@@ -1541,16 +1595,17 @@ def measure_multi_finger(
1541
 
1542
  # Build debug visualization
1543
  if result_png_path is not None:
1544
- _draw_multi_finger_debug(
1545
- image_canonical=image_canonical,
1546
- per_finger_raw=per_finger_raw,
1547
- aggregated=aggregated,
1548
- card_result=card_result,
1549
- px_per_cm=px_per_cm,
1550
- result_png_path=result_png_path,
1551
- hand_mask=hand_data.get("mask") if hand_data else None,
1552
- hand_landmarks=hand_data.get("landmarks") if hand_data else None,
1553
- )
 
1554
 
1555
  # Clean internal data from output
1556
  for fn, r in per_finger_raw.items():
@@ -1563,6 +1618,11 @@ def measure_multi_finger(
1563
  "lighting_uniform": lighting.get("uniform", True),
1564
  "fingers_well_spaced": spacing.get("well_spaced", True),
1565
  }
 
 
 
 
 
1566
  return aggregated
1567
 
1568
 
 
12
  import argparse
13
  import json
14
  import sys
15
+ import time
16
+ from contextlib import contextmanager
17
  from pathlib import Path
18
  from typing import Optional, Dict, Any, List, Literal, Tuple
19
 
20
+
21
+ @contextmanager
22
+ def _phase(name: str, totals: Optional[Dict[str, float]] = None):
23
+ """Log elapsed wall time for a pipeline phase.
24
+
25
+ Prints `[timing] <name>: <ms> ms` on exit. If `totals` is passed, the
26
+ elapsed milliseconds are also accumulated under `name` so the caller can
27
+ print a summary at the end.
28
+ """
29
+ t0 = time.perf_counter()
30
+ try:
31
+ yield
32
+ finally:
33
+ dt_ms = (time.perf_counter() - t0) * 1000.0
34
+ print(f"[timing] {name}: {dt_ms:.1f} ms")
35
+ if totals is not None:
36
+ totals[name] = totals.get(name, 0.0) + dt_ms
37
+
38
+
39
+ def _print_timing_summary(totals: Dict[str, float]) -> None:
40
+ if not totals:
41
+ return
42
+ total_ms = sum(totals.values())
43
+ print(f"[timing] ===== summary (total {total_ms:.1f} ms) =====")
44
+ for name, ms in sorted(totals.items(), key=lambda kv: -kv[1]):
45
+ pct = (ms / total_ms * 100.0) if total_ms > 0 else 0.0
46
+ print(f"[timing] {name:<28s} {ms:8.1f} ms ({pct:5.1f}%)")
47
+
48
  import cv2
49
  import numpy as np
50
 
 
610
  Returns:
611
  Output dictionary with measurement results
612
  """
613
+ timings: Dict[str, float] = {}
614
+ t_pipeline_start = time.perf_counter()
615
+
616
  # Phase 2: Image quality metrics (informational only β€” no hard fail)
617
+ with _phase("image_quality", timings):
618
+ quality = assess_image_quality(image)
619
  print(f"Image quality: blur={quality['blur_score']:.1f}, "
620
  f"brightness={quality['brightness']:.1f}, "
621
  f"contrast={quality['contrast']:.1f}")
 
630
  if save_debug and result_png_path is not None:
631
  finger_debug_dir = str(Path(result_png_path).parent / "finger_segmentation_debug")
632
 
633
+ with _phase(f"hand_segment[{hand_mask_method}]", timings):
634
+ hand_data = segment_hand(
635
+ image,
636
+ finger=finger_index,
637
+ debug_dir=finger_debug_dir,
638
+ use_sam_mask=(hand_mask_method == "sam"),
639
+ )
640
 
641
  if hand_data is None:
642
  print("No hand detected in image")
 
674
  view_angle_ok = True
675
  card_detected = False
676
  else:
677
+ with _phase(f"card_detect[{card_method}]", timings):
678
+ if card_method == "sam":
679
+ card_result = _sam_card_detect(
680
+ image_canonical, hand_data, save_debug, result_png_path
681
+ )
682
+ else:
683
+ card_result = detect_credit_card(image_canonical, debug_dir=card_debug_dir)
684
 
685
  if card_result is None:
686
  print("Credit card not detected in image")
 
718
  # length and can cut into a wider-than-average finger, which would make
719
  # the mask boundary narrower than the true SAM boundary.
720
  raw_hand_mask = hand_data.get("mask")
721
+ with _phase("finger_isolate", timings):
722
+ finger_data = isolate_finger(hand_data, finger=finger_index, image_shape=(h_can, w_can))
723
 
724
  if finger_data is None:
725
  print(f"Could not isolate finger: {finger_index}")
 
926
  else:
927
  edge_mask_input = cleaned_mask
928
 
929
+ with _phase(f"edge_refine[{mask_mode}]", timings):
930
+ sobel_measurement = refine_edges_sobel(
931
+ image=image_canonical, # Use canonical orientation
932
+ axis_data=axis_data,
933
+ zone_data=zone_data,
934
+ scale_px_per_cm=px_per_cm,
935
+ finger_landmarks=finger_data.get("landmarks"),
936
+ sobel_threshold=sobel_threshold,
937
+ kernel_size=sobel_kernel_size,
938
+ use_subpixel=use_subpixel,
939
+ finger_mask=edge_mask_input,
940
+ debug_dir=edge_debug_dir,
941
+ mask_mode=mask_mode,
942
+ finger_name=finger_data.get("finger_name"),
943
+ )
944
 
945
  sobel_width_cm = sobel_measurement["median_width_cm"]
946
  print(f"Edge width: {sobel_width_cm:.4f}cm "
 
1095
  print(f"Warning: Confidence {confidence_breakdown['overall']:.3f} is below threshold {confidence_threshold:.3f}")
1096
 
1097
  # Phase 9: Result visualization (always generated)
1098
+ t_viz_start = time.perf_counter() if result_png_path is not None else None
1099
  if result_png_path is not None:
1100
  print(f"Generating result visualization...")
1101
 
 
1180
  _save_debug_visualization(result_png_path, debug_image)
1181
  print(f"Result visualization saved to: {result_png_path}")
1182
 
1183
+ if t_viz_start is not None:
1184
+ viz_ms = (time.perf_counter() - t_viz_start) * 1000.0
1185
+ print(f"[timing] visualization: {viz_ms:.1f} ms")
1186
+ timings["visualization"] = timings.get("visualization", 0.0) + viz_ms
1187
+
1188
+ pipeline_ms = (time.perf_counter() - t_pipeline_start) * 1000.0
1189
+ print(f"[timing] pipeline_total: {pipeline_ms:.1f} ms")
1190
+ _print_timing_summary(timings)
1191
 
1192
  return create_output(
1193
  finger_diameter_cm=median_width_cm,
 
1457
  """
1458
  from src.finger_segmentation import FINGER_LANDMARKS
1459
 
1460
+ timings: Dict[str, float] = {}
1461
+ t_pipeline_start = time.perf_counter()
1462
+
1463
  # Phase 1: Image quality metrics (informational only β€” no hard fail)
1464
+ with _phase("image_quality", timings):
1465
+ quality = assess_image_quality(image)
1466
  print(f"[multi] Image quality: blur={quality['blur_score']:.1f}, "
1467
  f"brightness={quality['brightness']:.1f}, contrast={quality['contrast']:.1f}")
1468
  if not quality["passed"]:
 
1479
  if save_debug and result_png_path is not None:
1480
  finger_debug_dir = str(Path(result_png_path).parent / "finger_segmentation_debug")
1481
 
1482
+ with _phase(f"hand_segment[{hand_mask_method}]", timings):
1483
+ hand_data = segment_hand(
1484
+ image,
1485
+ finger="index",
1486
+ debug_dir=finger_debug_dir,
1487
+ use_sam_mask=(hand_mask_method == "sam"),
1488
+ )
1489
  if hand_data is None:
1490
  print("[multi] No hand detected")
1491
  return {"fail_reason": "hand_not_detected", "per_finger": {}, "fingers_measured": 0, "fingers_succeeded": 0}
 
1505
  view_angle_ok = True
1506
  card_detected = False
1507
  else:
1508
+ with _phase(f"card_detect[{card_method}]", timings):
1509
+ if card_method == "sam":
1510
+ card_result = _sam_card_detect(
1511
+ image_canonical, hand_data, save_debug, result_png_path
1512
+ )
1513
+ else:
1514
+ card_result = detect_credit_card(image_canonical, debug_dir=card_debug_dir)
1515
  if card_result is None:
1516
  # Emit a diagnostic visualization so the failure is debuggable:
1517
  # hand mask + card-prompt seeds on the canonical image. Without
 
1555
  per_finger_raw: Dict[str, Dict] = {}
1556
  for fn in MULTI_FINGERS:
1557
  print(f"\n[multi] === Measuring {fn} finger ===")
1558
+ with _phase(f"measure_finger[{fn}]", timings):
1559
+ result = _measure_single_finger_from_shared(
1560
+ image_canonical=image_canonical,
1561
+ hand_data=hand_data,
1562
+ finger_name=fn,
1563
+ px_per_cm=px_per_cm,
1564
+ card_detected=card_detected,
1565
+ view_angle_ok=view_angle_ok,
1566
+ card_result=card_result,
1567
+ scale_confidence=scale_confidence,
1568
+ edge_method=edge_method,
1569
+ sobel_threshold=sobel_threshold,
1570
+ sobel_kernel_size=sobel_kernel_size,
1571
+ use_subpixel=use_subpixel,
1572
+ )
1573
 
1574
  # Apply calibration
1575
  raw_diam = result.get("finger_outer_diameter_cm")
 
1595
 
1596
  # Build debug visualization
1597
  if result_png_path is not None:
1598
+ with _phase("visualization", timings):
1599
+ _draw_multi_finger_debug(
1600
+ image_canonical=image_canonical,
1601
+ per_finger_raw=per_finger_raw,
1602
+ aggregated=aggregated,
1603
+ card_result=card_result,
1604
+ px_per_cm=px_per_cm,
1605
+ result_png_path=result_png_path,
1606
+ hand_mask=hand_data.get("mask") if hand_data else None,
1607
+ hand_landmarks=hand_data.get("landmarks") if hand_data else None,
1608
+ )
1609
 
1610
  # Clean internal data from output
1611
  for fn, r in per_finger_raw.items():
 
1618
  "lighting_uniform": lighting.get("uniform", True),
1619
  "fingers_well_spaced": spacing.get("well_spaced", True),
1620
  }
1621
+
1622
+ pipeline_ms = (time.perf_counter() - t_pipeline_start) * 1000.0
1623
+ print(f"[timing] pipeline_total: {pipeline_ms:.1f} ms")
1624
+ _print_timing_summary(timings)
1625
+
1626
  return aggregated
1627
 
1628
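The hunks above wrap each pipeline stage in `_phase(name, timings)` and finish with `_print_timing_summary(timings)`, but neither helper's definition appears in this part of the diff. A minimal sketch of how such helpers could look, assuming `_phase` is a context manager that accumulates elapsed milliseconds per phase name. The bodies and the output format below are illustrative guesses, not the committed implementation:

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def _phase(name: str, timings: Dict[str, float]) -> Iterator[None]:
    """Time the enclosed block and add the elapsed milliseconds to timings[name]."""
    t0 = time.perf_counter()
    try:
        yield
    finally:
        # Accumulate rather than overwrite, so a phase entered more than once
        # reports its total time. The finally clause also records the time
        # when the wrapped block raises.
        elapsed_ms = (time.perf_counter() - t0) * 1000.0
        timings[name] = timings.get(name, 0.0) + elapsed_ms
        print(f"[timing] {name}: {elapsed_ms:.1f} ms")


def _print_timing_summary(timings: Dict[str, float]) -> None:
    """Print phases from slowest to fastest (hypothetical output format)."""
    for name, ms in sorted(timings.items(), key=lambda kv: kv[1], reverse=True):
        print(f"[timing]   {name:<32s} {ms:8.1f} ms")


timings: Dict[str, float] = {}
with _phase("image_quality", timings):
    time.sleep(0.01)  # stand-in for assess_image_quality(image)
with _phase("image_quality", timings):
    time.sleep(0.01)  # second entry under the same name is added to the first
_print_timing_summary(timings)
```

Accumulating with `timings.get(name, 0.0) + elapsed_ms` mirrors the `timings["visualization"] = timings.get("visualization", 0.0) + viz_ms` line in the diff, so repeated phases sum instead of overwriting each other.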
 
requirements.txt CHANGED
@@ -12,3 +12,6 @@ torch>=2.4.0
12
  torchvision>=0.19.0
13
  transformers>=4.47.0
14
  pillow>=10.0.0
 
 
 
 
12
  torchvision>=0.19.0
13
  transformers>=4.47.0
14
  pillow>=10.0.0
15
+ # v5: HF ZeroGPU requires Gradio SDK; `spaces` provides @spaces.GPU (no-op off ZeroGPU)
16
+ gradio>=4.44.0
17
+ spaces>=0.30.0
src/sam_backend.py CHANGED
@@ -23,6 +23,35 @@ INFERENCE_MAX_SIDE = 1024
23
 
24
  _model = None
25
  _processor = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  def get_sam2() -> Tuple[object, object]:
@@ -32,19 +61,26 @@ def get_sam2() -> Tuple[object, object]:
32
  the HEAD-request retry storm that happens when huggingface.co is slow or
33
  unreachable but the weights are already on disk. On a true cache miss we
34
  fall through to a normal online load.
 
 
 
 
 
 
35
  """
36
- global _model, _processor
37
  if _model is None or _processor is None:
38
  from transformers import Sam2Model, Sam2Processor
 
39
  t0 = time.time()
40
- print(f" Loading SAM 2.1 ({SAM2_MODEL_ID})...")
41
  try:
42
  _processor = Sam2Processor.from_pretrained(SAM2_MODEL_ID, local_files_only=True)
43
- _model = Sam2Model.from_pretrained(SAM2_MODEL_ID, local_files_only=True).to("cpu").eval()
44
- print(f" SAM 2.1 loaded (offline cache) in {time.time() - t0:.1f}s")
45
  except (OSError, ValueError):
46
  # Cache miss - fall back to online download.
47
  _processor = Sam2Processor.from_pretrained(SAM2_MODEL_ID)
48
- _model = Sam2Model.from_pretrained(SAM2_MODEL_ID).to("cpu").eval()
49
- print(f" SAM 2.1 loaded (online) in {time.time() - t0:.1f}s")
50
  return _model, _processor
 
23
 
24
  _model = None
25
  _processor = None
26
+ _device: str = "cpu"
27
+
28
+
29
+ def _select_device() -> str:
30
+ """Pick a torch device for SAM inference.
31
+
32
+ Returns ``"cuda"`` when a GPU is visible (HF ZeroGPU exposes CUDA even
33
+ at module import time via an emulation shim, so this picks the right
34
+ path both at startup and inside ``@spaces.GPU`` functions), otherwise
35
+ ``"cpu"``. Import of torch is local so CLI users without it still see
36
+ a clean error from the caller.
37
+ """
38
+ try:
39
+ import torch
40
+ if torch.cuda.is_available():
41
+ return "cuda"
42
+ except Exception:
43
+ pass
44
+ return "cpu"
45
+
46
+
47
+ def get_sam2_device() -> str:
48
+ """Return the device the SAM singleton was loaded on.
49
+
50
+ Callers use this to move their ``processor(..., return_tensors="pt")``
51
+ outputs onto the same device as the model before the forward pass.
52
+ Returns ``"cpu"`` before ``get_sam2()`` has been called.
53
+ """
54
+ return _device
55
 
56
 
57
  def get_sam2() -> Tuple[object, object]:
 
61
  the HEAD-request retry storm that happens when huggingface.co is slow or
62
  unreachable but the weights are already on disk. On a true cache miss we
63
  fall through to a normal online load.
64
+
65
+ The model is placed on the device returned by ``_select_device()``.
66
+ HF ZeroGPU docs require CUDA placements to happen at module-level
67
+ startup for best performance β€” callers in ZeroGPU Spaces should invoke
68
+ ``get_sam2()`` once at import time so this runs before the first
69
+ ``@spaces.GPU``-wrapped request.
70
  """
71
+ global _model, _processor, _device
72
  if _model is None or _processor is None:
73
  from transformers import Sam2Model, Sam2Processor
74
+ _device = _select_device()
75
  t0 = time.time()
76
+ print(f" Loading SAM 2.1 ({SAM2_MODEL_ID}) on {_device}...")
77
  try:
78
  _processor = Sam2Processor.from_pretrained(SAM2_MODEL_ID, local_files_only=True)
79
+ _model = Sam2Model.from_pretrained(SAM2_MODEL_ID, local_files_only=True).to(_device).eval()
80
+ print(f" SAM 2.1 loaded (offline cache, {_device}) in {time.time() - t0:.1f}s")
81
  except (OSError, ValueError):
82
  # Cache miss - fall back to online download.
83
  _processor = Sam2Processor.from_pretrained(SAM2_MODEL_ID)
84
+ _model = Sam2Model.from_pretrained(SAM2_MODEL_ID).to(_device).eval()
85
+ print(f" SAM 2.1 loaded (online, {_device}) in {time.time() - t0:.1f}s")
86
  return _model, _processor
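The `get_sam2()` docstring asks ZeroGPU callers to load the model once at import time, before the first `@spaces.GPU`-wrapped request. A minimal sketch of that calling pattern, assuming a stand-in loader: `load_model`, `measure`, `gpu`, and `_state` are hypothetical names used only for illustration, and the real app would call `get_sam2()` where `load_model()` appears. The `try`/`except` fallback is an addition here so the sketch also runs where the `spaces` package is not installed:

```python
from typing import Callable, Dict

try:
    import spaces  # provides the @spaces.GPU decorator on Hugging Face Spaces

    gpu: Callable = spaces.GPU
except ImportError:
    # Local runs without the `spaces` package: use a pass-through decorator.
    def gpu(fn: Callable) -> Callable:
        return fn


_state: Dict[str, object] = {}


def load_model() -> object:
    """Stand-in for get_sam2(): build the singleton once, then reuse it."""
    if "model" not in _state:
        # The real loader would run from_pretrained(...).to(device).eval() here.
        _state["model"] = object()
    return _state["model"]


# Warm the singleton at module import, before any request handler runs.
MODEL = load_model()


@gpu
def measure(value: int) -> int:
    """Hypothetical request handler; it reuses the already-loaded singleton."""
    assert load_model() is MODEL  # no second load inside the GPU-wrapped call
    return value * 2


print(measure(21))
```

Loading at module level means the first request pays no model load cost; the handler only looks up the cached instance.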
src/sam_card_detection.py CHANGED
@@ -26,7 +26,7 @@ from .card_detection import (
26
  get_quad_dimensions,
27
  order_corners,
28
  )
29
- from .sam_backend import INFERENCE_MAX_SIDE as PROMPT_INFERENCE_MAX_SIDE, get_sam2
30
 
31
  # HF Hub model id - tiny, small, base-plus, large
32
  SAM2_MODEL_ID = "facebook/sam2.1-hiera-small"
@@ -531,6 +531,13 @@ def detect_credit_card_sam_prompt(
531
  input_labels=input_labels,
532
  return_tensors="pt",
533
  )
 
 
 
 
 
 
 
534
  with torch.inference_mode():
535
  # multimask_output=True gives 3 masks per seed (small / medium / large
536
  # disambiguation of the prompt). Empirically this matters for card
@@ -542,13 +549,13 @@ def detect_credit_card_sam_prompt(
542
 
543
  # Score masks in the scaled 1024-space. Only the single winner is
544
  # upscaled to full resolution afterward, which avoids O(N) 12 MP resizes.
545
- scaled_h = inputs["original_sizes"][0][0].item()
546
- scaled_w = inputs["original_sizes"][0][1].item()
547
  scaled_area = float(scaled_h * scaled_w)
548
 
549
  masks_list = processor.post_process_masks(
550
  outputs.pred_masks.cpu(),
551
- inputs["original_sizes"],
552
  mask_threshold=0.0,
553
  )
554
  masks_tensor = masks_list[0] # (num_prompts, num_candidates, H_s, W_s)
 
26
  get_quad_dimensions,
27
  order_corners,
28
  )
29
+ from .sam_backend import INFERENCE_MAX_SIDE as PROMPT_INFERENCE_MAX_SIDE, get_sam2, get_sam2_device
30
 
31
  # HF Hub model id - tiny, small, base-plus, large
32
  SAM2_MODEL_ID = "facebook/sam2.1-hiera-small"
 
531
  input_labels=input_labels,
532
  return_tensors="pt",
533
  )
534
+ # `original_sizes` is used after the forward pass for mask post-processing
535
+ # and scale calculations. Pull it to CPU before moving `inputs` to the
536
+ # model device so downstream code never has to chase device placement.
537
+ original_sizes_cpu = inputs["original_sizes"].cpu() if hasattr(inputs["original_sizes"], "cpu") else inputs["original_sizes"]
538
+ device = get_sam2_device()
539
+ if device != "cpu":
540
+ inputs = inputs.to(device)
541
  with torch.inference_mode():
542
  # multimask_output=True gives 3 masks per seed (small / medium / large
543
  # disambiguation of the prompt). Empirically this matters for card
 
549
 
550
  # Score masks in the scaled 1024-space. Only the single winner is
551
  # upscaled to full resolution afterward, which avoids O(N) 12 MP resizes.
552
+ scaled_h = int(original_sizes_cpu[0][0].item())
553
+ scaled_w = int(original_sizes_cpu[0][1].item())
554
  scaled_area = float(scaled_h * scaled_w)
555
 
556
  masks_list = processor.post_process_masks(
557
  outputs.pred_masks.cpu(),
558
+ original_sizes_cpu,
559
  mask_threshold=0.0,
560
  )
561
  masks_tensor = masks_list[0] # (num_prompts, num_candidates, H_s, W_s)
src/sam_hand_segmentation.py CHANGED
@@ -22,7 +22,7 @@ from typing import List, Optional, Tuple
22
  import cv2
23
  import numpy as np
24
 
25
- from .sam_backend import INFERENCE_MAX_SIDE, get_sam2
26
 
27
 
28
  def _downscale(image_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
@@ -87,6 +87,9 @@ def segment_hand_sam(
87
  input_labels=[[prompt_labels]],
88
  return_tensors="pt",
89
  )
 
 
 
90
  with torch.inference_mode():
91
  outputs = model(**inputs, multimask_output=True)
92
 
 
22
  import cv2
23
  import numpy as np
24
 
25
+ from .sam_backend import INFERENCE_MAX_SIDE, get_sam2, get_sam2_device
26
 
27
 
28
  def _downscale(image_bgr: np.ndarray) -> Tuple[np.ndarray, float]:
 
87
  input_labels=[[prompt_labels]],
88
  return_tensors="pt",
89
  )
90
+ device = get_sam2_device()
91
+ if device != "cpu":
92
+ inputs = inputs.to(device)
93
  with torch.inference_mode():
94
  outputs = model(**inputs, multimask_output=True)
95