WatNeru commited on
Commit
17b871a
·
1 Parent(s): ca562c6

Load model directly from Hub to save storage space

Browse files
Files changed (2) hide show
  1. app.py +13 -81
  2. package/ai.py +14 -5
app.py CHANGED
@@ -69,7 +69,7 @@ def _set_status(message: str) -> None:
69
 
70
 
71
  def ensure_model_available() -> str:
72
- """モデルディレクトリをローカルに用意なければHFから取得)"""
73
  print(f"[MODEL] ensure_model_available() 開始")
74
  print(f"[MODEL] モデルリポジトリ: {HF_MODEL_REPO}")
75
  print(f"[MODEL] HF_TOKEN設定: {'あり' if HF_TOKEN else 'なし'}")
@@ -78,32 +78,13 @@ def ensure_model_available() -> str:
78
  token_preview = HF_TOKEN[:7] + "..." + HF_TOKEN[-4:] if len(HF_TOKEN) > 11 else "***"
79
  print(f"[MODEL] HF_TOKENプレビュー: {token_preview} (長さ: {len(HF_TOKEN)})")
80
 
81
- # モデルディレクトリのパス構築(リポジトリ名から)
82
- model_dir_name = HF_MODEL_REPO.split("/")[-1] # "Llama-3.2-3B-Instruct"
83
- model_cache_path = HF_LOCAL_DIR / model_dir_name
84
- print(f"[MODEL] モデルキャッシュパス: {model_cache_path}")
85
 
86
- # 既存モデルディレクトリチェック
87
- if model_cache_path.exists() and model_cache_path.is_dir():
88
- if (model_cache_path / "config.json").exists():
89
- print(f"[MODEL] 既存のモデルディレクトリを使用: {model_cache_path}")
90
- model_path_str = str(model_cache_path.resolve())
91
- os.environ["LLM_MODEL_PATH"] = model_path_str
92
- path_manager.model_path = model_path_str
93
- return model_path_str
94
-
95
- print(f"[MODEL] モデルディレクトリが見つからないため、ダウンロードを開始")
96
- HF_LOCAL_DIR.mkdir(parents=True, exist_ok=True)
97
- print(f"[MODEL] ダウンロード先ディレクトリ: {HF_LOCAL_DIR}")
98
- _set_status("Hugging Face からモデルをダウンロード中...")
99
-
100
- # snapshot_downloadでモデル全体をダウンロード(PyTorch モデルは複数ファイル)
101
- try:
102
- if not HF_TOKEN:
103
- print("[MODEL] 警告: HF_TOKEN が設定されていません。認証が必要なモデルの場合、ダウンロードに失敗する可能性があります。")
104
- raise ValueError("HF_TOKEN が設定されていません")
105
-
106
- # huggingface_hub の login を使って明示的に認証(念のため)
107
  try:
108
  from huggingface_hub import login
109
  print("[MODEL] huggingface_hub.login() を実行中...")
@@ -111,49 +92,9 @@ def ensure_model_available() -> str:
111
  print("[MODEL] ログイン成功")
112
  except Exception as login_error:
113
  print(f"[MODEL] ログインエラー(続行): {login_error}")
114
-
115
- print(f"[MODEL] snapshot_download を開始: {HF_MODEL_REPO}")
116
- downloaded_dir = snapshot_download(
117
- repo_id=HF_MODEL_REPO,
118
- local_dir=str(HF_LOCAL_DIR),
119
- local_dir_use_symlinks=False,
120
- token=HF_TOKEN,
121
- )
122
- print(f"[MODEL] snapshot_download完了: {downloaded_dir}")
123
- except Exception as e:
124
- error_msg = str(e)
125
- print(f"[MODEL] snapshot_downloadエラー: {error_msg}")
126
-
127
- # 認証エラーの場合、より詳細なメッセージを表示
128
- if "401" in error_msg or "authentication" in error_msg.lower() or "token" in error_msg.lower():
129
- print("[MODEL] 認証エラーの可能性があります。HF_TOKEN が正しく設定されているか確認してください。")
130
- elif "404" in error_msg or "not found" in error_msg.lower():
131
- print(f"[MODEL] リポジトリが見つかりません: {HF_MODEL_REPO}")
132
- print("[MODEL] リポジトリ名が正しいか、アクセス権限があるか確認してください。")
133
-
134
- import traceback
135
- traceback.print_exc()
136
- raise
137
-
138
- # ダウンロードされたディレクトリを確認
139
- downloaded_dir_path = Path(downloaded_dir)
140
- print(f"[MODEL] ダウンロード先パス: {downloaded_dir_path}")
141
-
142
- # ダウンロードされたファイルをリストアップ
143
- downloaded_files = list(downloaded_dir_path.glob("*"))
144
- print(f"[MODEL] ダウンロードされたファイル数: {len(downloaded_files)}")
145
- if downloaded_files:
146
- print(f"[MODEL] ダウンロードされたファイル: {[f.name for f in downloaded_files[:10]]}")
147
 
148
- # config.jsonある確認
149
- if not (downloaded_dir_path / "config.json").exists():
150
- raise FileNotFoundError(
151
- f"モデルディレクトリ {downloaded_dir} に config.json が見つか���ません。"
152
- f"ダウンロードされたファイル: {[f.name for f in downloaded_files[:10]]}"
153
- )
154
-
155
- model_path_str = str(downloaded_dir_path.resolve())
156
- print(f"[MODEL] モデルパスを設定: {model_path_str}")
157
  os.environ["LLM_MODEL_PATH"] = model_path_str
158
  path_manager.model_path = model_path_str
159
  return model_path_str
@@ -191,17 +132,8 @@ def initialize_model() -> None:
191
  threading.Thread(target=initialize_model, daemon=True).start()
192
 
193
  # ZeroGPU対応: モジュールレベルでGPU要求(起動時に検出されるように)
194
- if SPACES_AVAILABLE:
195
- try:
196
- # spaces.GPU() を呼び出してデコレータを取得し、ダミー関数に適用
197
- gpu_decorator = spaces.GPU()
198
- @gpu_decorator
199
- def _gpu_request_dummy():
200
- """GPU要求用のダミー関数(Space起動時に検出される)"""
201
- pass
202
- print("[SPACE] GPU要求をモジュールレベルで送信しました")
203
- except Exception as e:
204
- print(f"[SPACE] GPU要求エラー: {e}")
205
 
206
  app = FastAPI(
207
  title="LLMView Word Tree API",
@@ -210,8 +142,8 @@ app = FastAPI(
210
  )
211
 
212
 
 
213
  @app.get("/")
214
- @spaces.GPU # ZeroGPU対応: root エンドポイントにも適用して起動時に検出されるように
215
  def root() -> Dict[str, str]:
216
  """簡易案内"""
217
  return {
@@ -233,8 +165,8 @@ def health() -> Dict[str, Any]:
233
  }
234
 
235
 
 
236
  @app.post("/build_word_tree", response_model=List[WordTreeResponse])
237
- @spaces.GPU # ZeroGPU対応: このエンドポイントでGPUを要求
238
  def build_word_tree(payload: WordTreeRequest) -> List[WordTreeResponse]:
239
  """単語ツリーを構築"""
240
  if not payload.prompt_text.strip():
 
69
 
70
 
71
  def ensure_model_available() -> str:
72
+ """モデルリポジトリID返すストレージ節約のため、Hubから直接読み込む)"""
73
  print(f"[MODEL] ensure_model_available() 開始")
74
  print(f"[MODEL] モデルリポジトリ: {HF_MODEL_REPO}")
75
  print(f"[MODEL] HF_TOKEN設定: {'あり' if HF_TOKEN else 'なし'}")
 
78
  token_preview = HF_TOKEN[:7] + "..." + HF_TOKEN[-4:] if len(HF_TOKEN) > 11 else "***"
79
  print(f"[MODEL] HF_TOKENプレビュー: {token_preview} (長さ: {len(HF_TOKEN)})")
80
 
81
+ # ストレージ節約のため、モデルをダウンロードせず、リポジトリIDを直接返す
82
+ # transformers の from_pretrained() が Hub から直接読み込む
83
+ print(f"[MODEL] ストレージ節約のため、Hubから直接読み込む方式を使用")
84
+ print(f"[MODEL] モデルパス(リポジトリID): {HF_MODEL_REPO}")
85
 
86
+ # huggingface_hub login 使って明示的に認証
87
+ if HF_TOKEN:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  try:
89
  from huggingface_hub import login
90
  print("[MODEL] huggingface_hub.login() を実行中...")
 
92
  print("[MODEL] ログイン成功")
93
  except Exception as login_error:
94
  print(f"[MODEL] ログインエラー(続行): {login_error}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # リポジトリIDを返す(transformers Hub ら直接読み込む)
97
+ model_path_str = HF_MODEL_REPO
 
 
 
 
 
 
 
98
  os.environ["LLM_MODEL_PATH"] = model_path_str
99
  path_manager.model_path = model_path_str
100
  return model_path_str
 
132
  threading.Thread(target=initialize_model, daemon=True).start()
133
 
134
  # ZeroGPU対応: モジュールレベルでGPU要求(起動時に検出されるように)
135
+ # 注意: Space は起動時に @spaces.GPU デコレータをスキャンするため、
136
+ # FastAPI のエンドポイント関数に適用する必要がある
 
 
 
 
 
 
 
 
 
137
 
138
  app = FastAPI(
139
  title="LLMView Word Tree API",
 
142
  )
143
 
144
 
145
+ @spaces.GPU # ZeroGPU対応: デコレータを先に適用(Space起動時に検出される)
146
  @app.get("/")
 
147
  def root() -> Dict[str, str]:
148
  """簡易案内"""
149
  return {
 
165
  }
166
 
167
 
168
+ @spaces.GPU # ZeroGPU対応: デコレータを先に適用(Space起動時に検出される)
169
  @app.post("/build_word_tree", response_model=List[WordTreeResponse])
 
170
  def build_word_tree(payload: WordTreeRequest) -> List[WordTreeResponse]:
171
  """単語ツリーを構築"""
172
  if not payload.prompt_text.strip():
package/ai.py CHANGED
@@ -47,9 +47,9 @@ class AI:
47
  cls._instances.clear()
48
 
49
  def _load_model(self, model_path: str) -> Optional[Any]:
50
- """モデルをロード(Transformers使用)"""
51
  try:
52
- if not model_path or not os.path.exists(model_path):
53
  return None
54
 
55
  # transformersを使用してモデルをロード
@@ -67,16 +67,25 @@ class AI:
67
  print(f"[AI] モデルをロード中: {model_path}")
68
  print(f"[AI] デバイス: {device}")
69
 
70
- # トークナイザーとモデルロー
 
 
 
 
 
 
 
 
 
71
  tokenizer = AutoTokenizer.from_pretrained(
72
  model_path,
73
- token=os.getenv("HF_TOKEN"),
74
  )
75
  model = AutoModelForCausalLM.from_pretrained(
76
  model_path,
77
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
78
  device_map="auto" if device == "cuda" else None,
79
- token=os.getenv("HF_TOKEN"),
80
  )
81
 
82
  if device == "cpu":
 
47
  cls._instances.clear()
48
 
49
  def _load_model(self, model_path: str) -> Optional[Any]:
50
+ """モデルをロード(Transformers使用、Hubから直接読み込み)"""
51
  try:
52
+ if not model_path:
53
  return None
54
 
55
  # transformersを使用してモデルをロード
 
67
  print(f"[AI] モデルをロード中: {model_path}")
68
  print(f"[AI] デバイス: {device}")
69
 
70
+ # モデルパスがリポジトリID("user/repo"形式)か、ローカルパスかを判定
71
+ hf_token = os.getenv("HF_TOKEN")
72
+ is_repo_id = "/" in model_path and not os.path.exists(model_path)
73
+
74
+ if is_repo_id:
75
+ print(f"[AI] Hugging Face Hub から直接読み込み: {model_path}")
76
+ else:
77
+ print(f"[AI] ローカルパスから読み込み: {model_path}")
78
+
79
+ # トークナイザーとモデルをロード(Hubから直接読み込む)
80
  tokenizer = AutoTokenizer.from_pretrained(
81
  model_path,
82
+ token=hf_token,
83
  )
84
  model = AutoModelForCausalLM.from_pretrained(
85
  model_path,
86
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
87
  device_map="auto" if device == "cuda" else None,
88
+ token=hf_token,
89
  )
90
 
91
  if device == "cpu":