trtd56 commited on
Commit
f47b67b
·
1 Parent(s): 19bd40d

Add auto-download for hand_landmarker model, update .gitignore

Browse files
.gitignore CHANGED
@@ -34,3 +34,7 @@ htmlcov/
34
  # OS
35
  .DS_Store
36
  Thumbs.db
 
 
 
 
 
34
  # OS
35
  .DS_Store
36
  Thumbs.db
37
+
38
+ # ML Models (downloaded at runtime)
39
+ *.task
40
+ models/
rock_paper_scissors/detection/hand_detector.py CHANGED
@@ -2,10 +2,14 @@
2
 
3
  from typing import Optional, NamedTuple
4
  from pathlib import Path
 
5
  import numpy as np
6
 
7
  from ..config import settings
8
 
 
 
 
9
  # MediaPipe Task APIのインポート
10
  try:
11
  import mediapipe as mp
@@ -24,8 +28,16 @@ class HandLandmarks(NamedTuple):
24
  handedness: str # "Left" or "Right"
25
 
26
 
 
 
 
 
 
 
 
 
27
  def _get_model_path() -> Path:
28
- """モデルファイルのパスを取得"""
29
  # パッケージディレクトリからの相対パス
30
  package_dir = Path(__file__).parent.parent
31
  model_path = package_dir / "models" / "hand_landmarker.task"
@@ -38,12 +50,9 @@ def _get_model_path() -> Path:
38
  if alt_path.exists():
39
  return alt_path
40
 
41
- raise FileNotFoundError(
42
- f"Model file not found. Please download it:\n"
43
- f"mkdir -p {package_dir / 'models'} && "
44
- f"curl -L -o {model_path} "
45
- f"https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/latest/hand_landmarker.task"
46
- )
47
 
48
 
49
  class HandDetector:
 
2
 
3
  from typing import Optional, NamedTuple
4
  from pathlib import Path
5
+ import urllib.request
6
  import numpy as np
7
 
8
  from ..config import settings
9
 
10
+ # モデルダウンロードURL
11
+ MODEL_URL = "https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/latest/hand_landmarker.task"
12
+
13
  # MediaPipe Task APIのインポート
14
  try:
15
  import mediapipe as mp
 
28
  handedness: str # "Left" or "Right"
29
 
30
 
31
+ def _download_model(model_path: Path) -> None:
32
+ """モデルファイルをダウンロード"""
33
+ print(f"Downloading hand_landmarker.task model...")
34
+ model_path.parent.mkdir(parents=True, exist_ok=True)
35
+ urllib.request.urlretrieve(MODEL_URL, model_path)
36
+ print(f"Model downloaded to {model_path}")
37
+
38
+
39
  def _get_model_path() -> Path:
40
+ """モデルファイルのパスを取得(存在しない場合は自動ダウンロード)"""
41
  # パッケージディレクトリからの相対パス
42
  package_dir = Path(__file__).parent.parent
43
  model_path = package_dir / "models" / "hand_landmarker.task"
 
50
  if alt_path.exists():
51
  return alt_path
52
 
53
+ # モデルが見つからない場合、自動ダウンロード
54
+ _download_model(model_path)
55
+ return model_path
 
 
 
56
 
57
 
58
  class HandDetector: