WatNeru commited on
Commit
dbb276c
·
1 Parent(s): adb0f98

Fix model download logic for PyTorch models and improve error handling

Browse files
Files changed (1) hide show
  1. app.py +35 -9
app.py CHANGED
@@ -71,15 +71,22 @@ def _set_status(message: str) -> None:
71
  def ensure_model_available() -> str:
72
  """モデルディレクトリをローカルに用意(なければHFから取得)"""
73
  print(f"[MODEL] ensure_model_available() 開始")
74
- current_path = Path(path_manager.get_model_path())
75
- print(f"[MODEL] 現在のモデルパス: {current_path}")
76
 
77
- # PyTorch モデルの場合、ディレクトリ全体チェック
78
- if current_path.exists() and current_path.is_dir():
79
- # config.json があるか確認(モデルディレクトリの確認)
80
- if (current_path / "config.json").exists():
81
- print(f"[MODEL] 既存のモデルディレクトリを使用: {current_path}")
82
- return str(current_path)
 
 
 
 
 
 
 
83
 
84
  print(f"[MODEL] モデルディレクトリが見つからないため、ダウンロードを開始")
85
  HF_LOCAL_DIR.mkdir(parents=True, exist_ok=True)
@@ -88,6 +95,9 @@ def ensure_model_available() -> str:
88
 
89
  # snapshot_downloadでモデル全体をダウンロード(PyTorch モデルは複数ファイル)
90
  try:
 
 
 
91
  downloaded_dir = snapshot_download(
92
  repo_id=HF_MODEL_REPO,
93
  local_dir=str(HF_LOCAL_DIR),
@@ -96,7 +106,16 @@ def ensure_model_available() -> str:
96
  )
97
  print(f"[MODEL] snapshot_download完了: {downloaded_dir}")
98
  except Exception as e:
99
- print(f"[MODEL] snapshot_downloadエラー: {e}")
 
 
 
 
 
 
 
 
 
100
  import traceback
101
  traceback.print_exc()
102
  raise
@@ -105,10 +124,17 @@ def ensure_model_available() -> str:
105
  downloaded_dir_path = Path(downloaded_dir)
106
  print(f"[MODEL] ダウンロード先パス: {downloaded_dir_path}")
107
 
 
 
 
 
 
 
108
  # config.json があるか確認
109
  if not (downloaded_dir_path / "config.json").exists():
110
  raise FileNotFoundError(
111
  f"モデルディレクトリ {downloaded_dir} に config.json が見つかりません。"
 
112
  )
113
 
114
  model_path_str = str(downloaded_dir_path.resolve())
 
71
  def ensure_model_available() -> str:
72
  """モデルディレクトリをローカルに用意(なければHFから取得)"""
73
  print(f"[MODEL] ensure_model_available() 開始")
74
+ print(f"[MODEL] モデルリポジトリ: {HF_MODEL_REPO}")
75
+ print(f"[MODEL] HF_TOKEN設定: {'あり' if HF_TOKEN else 'なし'}")
76
 
77
+ # モデルディレクトリのパス構築(リポジトリ名から)
78
+ model_dir_name = HF_MODEL_REPO.split("/")[-1] # "Llama-3.2-3B-Instruct"
79
+ model_cache_path = HF_LOCAL_DIR / model_dir_name
80
+ print(f"[MODEL] モデルキャッシュパス: {model_cache_path}")
81
+
82
+ # 既存のモデルディレクトリをチェック
83
+ if model_cache_path.exists() and model_cache_path.is_dir():
84
+ if (model_cache_path / "config.json").exists():
85
+ print(f"[MODEL] 既存のモデルディレクトリを使用: {model_cache_path}")
86
+ model_path_str = str(model_cache_path.resolve())
87
+ os.environ["LLM_MODEL_PATH"] = model_path_str
88
+ path_manager.model_path = model_path_str
89
+ return model_path_str
90
 
91
  print(f"[MODEL] モデルディレクトリが見つからないため、ダウンロードを開始")
92
  HF_LOCAL_DIR.mkdir(parents=True, exist_ok=True)
 
95
 
96
  # snapshot_downloadでモデル全体をダウンロード(PyTorch モデルは複数ファイル)
97
  try:
98
+ if not HF_TOKEN:
99
+ print("[MODEL] 警告: HF_TOKEN が設定されていません。認証が必要なモデルの場合、ダウンロードに失敗する可能性があります。")
100
+
101
  downloaded_dir = snapshot_download(
102
  repo_id=HF_MODEL_REPO,
103
  local_dir=str(HF_LOCAL_DIR),
 
106
  )
107
  print(f"[MODEL] snapshot_download完了: {downloaded_dir}")
108
  except Exception as e:
109
+ error_msg = str(e)
110
+ print(f"[MODEL] snapshot_downloadエラー: {error_msg}")
111
+
112
+ # 認証エラーの場合、より詳細なメッセージを表示
113
+ if "401" in error_msg or "authentication" in error_msg.lower() or "token" in error_msg.lower():
114
+ print("[MODEL] 認証エラーの可能性があります。HF_TOKEN が正しく設定されているか確認してください。")
115
+ elif "404" in error_msg or "not found" in error_msg.lower():
116
+ print(f"[MODEL] リポジトリが見つかりません: {HF_MODEL_REPO}")
117
+ print("[MODEL] リポジトリ名が正しいか、アクセス権限があるか確認してください。")
118
+
119
  import traceback
120
  traceback.print_exc()
121
  raise
 
124
  downloaded_dir_path = Path(downloaded_dir)
125
  print(f"[MODEL] ダウンロード先パス: {downloaded_dir_path}")
126
 
127
+ # ダウンロードされたファイルをリストアップ
128
+ downloaded_files = list(downloaded_dir_path.glob("*"))
129
+ print(f"[MODEL] ダウンロードされたファイル数: {len(downloaded_files)}")
130
+ if downloaded_files:
131
+ print(f"[MODEL] ダウンロードされたファイル: {[f.name for f in downloaded_files[:10]]}")
132
+
133
  # config.json があるか確認
134
  if not (downloaded_dir_path / "config.json").exists():
135
  raise FileNotFoundError(
136
  f"モデルディレクトリ {downloaded_dir} に config.json が見つかりません。"
137
+ f"ダウンロードされたファイル: {[f.name for f in downloaded_files[:10]]}"
138
  )
139
 
140
  model_path_str = str(downloaded_dir_path.resolve())