#!/usr/bin/env python3 """ FastAPI 版 LLMView Word Tree サーバー """ import os import threading import sys from pathlib import Path from typing import List, Dict, Any, Optional from fastapi import FastAPI, HTTPException from huggingface_hub import snapshot_download from pydantic import BaseModel, Field # ZeroGPU対応: spacesパッケージをインポート(デコレータ用) try: import spaces SPACES_AVAILABLE = True print("[SPACE] spacesパッケージをインポートしました") except ImportError: SPACES_AVAILABLE = False print("[SPACE] spacesパッケージが見つかりません(ローカル環境の可能性)") # ダミーデコレータを定義 class DummyGPU: def __call__(self, func): return func spaces = type('spaces', (), {'GPU': DummyGPU()})() try: from package.path_manager import get_path_manager except ImportError: from path_manager import get_path_manager # type: ignore path_manager = get_path_manager() path_manager.setup_sys_path() adapter = None status_message = "モデル初期化中..." status_lock = threading.Lock() HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct") HF_LOCAL_DIR = Path( os.getenv( "HF_MODEL_LOCAL_DIR", path_manager.base_path / "model_cache", ) ) HF_TOKEN = os.getenv("HF_TOKEN") class WordTreeRequest(BaseModel): prompt_text: str = Field(..., description="生成に使用するプロンプト") root_text: str = Field("", description="任意のルートテキスト") top_k: int = Field(5, ge=1, le=50, description="取得する候補数") max_depth: int = Field(10, ge=1, le=50, description="探索深さ") class WordTreeResponse(BaseModel): text: str probability: float def _set_status(message: str) -> None: global status_message with status_lock: status_message = message def _get_dummy_results() -> List[WordTreeResponse]: """モデルが未準備・異常時に返すダミー候補""" dummy_payload = [ {"text": "[eos]", "probability": 0.8}, {"text": "#dummy#候補2", "probability": 0.6}, {"text": "#dummy#候補3", "probability": 0.4}, ] return [WordTreeResponse(**item) for item in dummy_payload] def ensure_model_available() -> str: """モデルリポジトリIDを返す(ストレージ節約のため、Hubから直接読み込む)""" print(f"[MODEL] ensure_model_available() 開始") print(f"[MODEL] モデルリポジトリ: {HF_MODEL_REPO}") print(f"[MODEL] HF_TOKEN設定: {'あり' if HF_TOKEN else 'なし'}") if HF_TOKEN: # トークンの最初の数文字だけを表示(セキュリティのため) token_preview = HF_TOKEN[:7] + "..." + HF_TOKEN[-4:] if len(HF_TOKEN) > 11 else "***" print(f"[MODEL] HF_TOKENプレビュー: {token_preview} (長さ: {len(HF_TOKEN)})") # ストレージ節約のため、モデルをダウンロードせず、リポジトリIDを直接返す # transformers の from_pretrained() が Hub から直接読み込む print(f"[MODEL] ストレージ節約のため、Hubから直接読み込む方式を使用") print(f"[MODEL] モデルパス(リポジトリID): {HF_MODEL_REPO}") # huggingface_hub の login を使って明示的に認証 if HF_TOKEN: try: from huggingface_hub import login print("[MODEL] huggingface_hub.login() を実行中...") login(token=HF_TOKEN, add_to_git_credential=False) print("[MODEL] ログイン成功") except Exception as login_error: print(f"[MODEL] ログインエラー(続行): {login_error}") # リポジトリIDを返す(transformers が Hub から直接読み込む) model_path_str = HF_MODEL_REPO os.environ["LLM_MODEL_PATH"] = model_path_str path_manager.model_path = model_path_str return model_path_str def initialize_model() -> None: """RustAdapter とモデルを初期化""" global adapter try: print("[INIT] モデル初期化スレッド開始") _set_status("モデルを読み込み中です...") from package.rust_adapter import RustAdapter print("[INIT] ensure_model_available() を呼び出し") model_path = ensure_model_available() print(f"[INIT] モデルパス取得: {model_path}") print("[INIT] RustAdapter.get_instance() を呼び出し") adapter = RustAdapter.get_instance(model_path) print("[INIT] RustAdapter初期化完了") _set_status("モデル準備完了") print("[INIT] モデル初期化完了") except Exception as exc: # pragma: no cover error_msg = f"モデル初期化に失敗しました: {exc}" print(f"[INIT] エラー: {error_msg}") _set_status(error_msg) import traceback traceback.print_exc() # プロセスを終了させないように、エラーをログに記録するだけ sys.stderr.write(f"[INIT] モデル初期化エラー(プロセスは継続): {exc}\n") # Space 起動時にバックグラウンドで初期化 threading.Thread(target=initialize_model, daemon=True).start() # ZeroGPU対応: モジュールレベルでGPU要求(起動時に検出されるように) # 注意: Space は起動時に @spaces.GPU デコレータをスキャンするため、 # FastAPI のエンドポイント関数に適用する必要がある app = FastAPI( title="LLMView Word Tree API", description="LLMView の単語ツリー構築 API。/build_word_tree にPOSTしてください。", version="1.0.0", ) # ZeroGPU対応: 起動時に検出されるように、デコレータ付き関数を定義 @spaces.GPU def _gpu_init_function(): """GPU初期化用のダミー関数(Space起動時に検出される)""" pass @app.on_event("startup") async def startup_event(): """アプリ起動時の処理(GPU要求を確実に検出させる)""" if SPACES_AVAILABLE: try: _gpu_init_function() print("[SPACE] GPU要求をstartup eventで送信しました") except Exception as e: print(f"[SPACE] GPU要求エラー: {e}") @app.get("/") def root() -> Dict[str, str]: """簡易案内""" return { "message": "LLMView Word Tree API", "status_endpoint": "/health", "build_endpoint": "/build_word_tree", } @app.get("/health") def health() -> Dict[str, Any]: """状態確認""" with status_lock: current_status = status_message return { "model_loaded": adapter is not None, "status": current_status, "model_path": path_manager.get_model_path(), } @spaces.GPU # ZeroGPU対応: デコレータを先に適用(Space起動時に検出される) @app.post("/build_word_tree", response_model=List[WordTreeResponse]) def build_word_tree(payload: WordTreeRequest) -> List[WordTreeResponse]: """単語ツリーを構築""" if not payload.prompt_text.strip(): raise HTTPException(status_code=400, detail="prompt_text を入力してください。") if adapter is None: print("[API] build_word_tree: モデル未準備(adapter is None)") raise HTTPException( status_code=503, detail=f"モデル準備中です: {status_message}" ) try: print( f"[API] build_word_tree called: prompt=\n##########################\n{payload.prompt_text}\n##########################\n', " f"root=\n%%%%%%%%%%%%%%%%%%%\n{payload.root_text}\n%%%%%%%%%%%%%%%%%%%\n', top_k={payload.top_k}, max_depth={payload.max_depth}\n\n\n\n" ) # print(f"[API] Adapter available: {adapter is not None}") results = adapter.build_word_tree( prompt_text=payload.prompt_text, root_text=payload.root_text, top_k=payload.top_k, max_depth=payload.max_depth, ) # print(f"[API] Generated {len(results)} candidates") print(f"--------------------------------\n[API] results: {[item['text'] for item in results]}\n--------------------------------") if not results: print("[API] No candidates generated, returning dummy candidates") results = _get_dummy_results() return results except Exception as exc: import traceback traceback.print_exc() print(f"[API] build_word_tree error: {exc}, fallback to dummy results") return _get_dummy_results() if __name__ == "__main__": import uvicorn uvicorn.run( app, host="0.0.0.0", port=7860, log_level=os.getenv("UVICORN_LOG_LEVEL", "warning"), access_log=os.getenv("UVICORN_ACCESS_LOG", "false").lower() == "true", )