| |
| """ |
| FastAPI 版 LLMView Word Tree サーバー |
| """ |
|
|
| import os |
| import threading |
| import sys |
| from pathlib import Path |
| from typing import List, Dict, Any, Optional |
|
|
| from fastapi import FastAPI, HTTPException |
| from huggingface_hub import snapshot_download |
| from pydantic import BaseModel, Field |
|
|
| |
try:
    import spaces

    SPACES_AVAILABLE = True
    print("[SPACE] spacesパッケージをインポートしました")
except ImportError:
    SPACES_AVAILABLE = False
    print("[SPACE] spacesパッケージが見つかりません(ローカル環境の可能性)")

    class DummyGPU:
        """Local no-op stand-in for ``spaces.GPU``.

        Accepts both decorator spellings so code written against the real
        package keeps importing without it:

        * ``@spaces.GPU`` -- called with the function, returns it unchanged.
        * ``@spaces.GPU(duration=...)`` / ``@spaces.GPU()`` -- called with
          options only, returns an identity decorator.
        """

        def __call__(self, func=None, **_options):
            if func is None:
                # Parameterised form: the options are meaningless locally, so
                # hand back a decorator that leaves the function untouched.
                return lambda f: f
            return func

    # Anonymous object whose ``GPU`` attribute mimics ``spaces.GPU`` so the
    # ``@spaces.GPU`` decorators in this module work without the package.
    spaces = type("spaces", (), {"GPU": DummyGPU()})()
|
|
# Prefer the packaged layout; when this file runs with path_manager.py as a
# sibling module (no ``package`` parent on sys.path), import it directly.
try:
    from package.path_manager import get_path_manager
except ImportError:  # flat / script layout
    from path_manager import get_path_manager
|
|
|
|
# Resolve project paths before anything else. setup_sys_path() presumably makes
# project-local packages (e.g. ``package.rust_adapter``) importable -- confirm in
# path_manager.
path_manager = get_path_manager()
path_manager.setup_sys_path()

# Shared model state. ``adapter`` stays None until the background initializer
# assigns it; ``status_message`` is updated by _set_status() under ``status_lock``.
adapter = None
status_message = "モデル初期化中..."
status_lock = threading.Lock()

# Hugging Face settings, each overridable through an environment variable.
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct")
# NOTE(review): HF_LOCAL_DIR is not referenced by any code visible in this file.
HF_LOCAL_DIR = Path(
    os.environ.get("HF_MODEL_LOCAL_DIR", path_manager.base_path / "model_cache")
)
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
|
class WordTreeRequest(BaseModel):
    # Request body accepted by POST /build_word_tree.
    # Documented with comments rather than a class docstring so the generated
    # JSON/OpenAPI schema description is left unchanged.
    prompt_text: str = Field(default=..., description="生成に使用するプロンプト")  # required
    root_text: str = Field(default="", description="任意のルートテキスト")
    # Both bounds (1..50 inclusive) are enforced by pydantic validation.
    top_k: int = Field(default=5, ge=1, le=50, description="取得する候補数")
    max_depth: int = Field(default=10, ge=1, le=50, description="探索深さ")
|
|
|
|
class WordTreeResponse(BaseModel):
    # One candidate in the list returned by POST /build_word_tree.
    # Both fields are required (``Field(...)``), exactly as with bare annotations.
    text: str = Field(...)
    # Candidate score; presumably a probability in [0, 1] -- not validated here.
    probability: float = Field(...)
|
|
|
|
def _set_status(message: str) -> None:
    """Publish *message* as the current model status.

    The write happens while holding ``status_lock``; ``/health`` takes the same
    lock before reading ``status_message``.
    """
    global status_message
    with status_lock:
        status_message = message
|
|
|
|
def _get_dummy_results() -> List[WordTreeResponse]:
    """Return placeholder candidates for when the model is not ready or fails.

    A fresh list of new ``WordTreeResponse`` objects is built on every call.
    """
    placeholders = (
        ("[eos]", 0.8),
        ("#dummy#候補2", 0.6),
        ("#dummy#候補3", 0.4),
    )
    return [WordTreeResponse(text=label, probability=score) for label, score in placeholders]
|
|
|
|
def ensure_model_available() -> str:
    """Return the model repository ID (no local download, to save storage).

    Nothing is downloaded here; the repo ID itself is used as the model path.

    Side effects:
        * Prints diagnostics, including a masked preview and the length of
          ``HF_TOKEN`` when it is set.
        * When ``HF_TOKEN`` is set, calls ``huggingface_hub.login``. Any error is
          printed and ignored (best effort).
        * Stores the repo ID in ``os.environ["LLM_MODEL_PATH"]`` and
          ``path_manager.model_path``.
    """
    repo_id = HF_MODEL_REPO
    token = HF_TOKEN

    print("[MODEL] ensure_model_available() 開始")
    print(f"[MODEL] モデルリポジトリ: {repo_id}")
    print(f"[MODEL] HF_TOKEN設定: {'あり' if token else 'なし'}")
    if token:
        # Only show head/tail of a long token; short tokens are fully masked.
        if len(token) > 11:
            masked = f"{token[:7]}...{token[-4:]}"
        else:
            masked = "***"
        print(f"[MODEL] HF_TOKENプレビュー: {masked} (長さ: {len(token)})")

    print("[MODEL] ストレージ節約のため、Hubから直接読み込む方式を使用")
    print(f"[MODEL] モデルパス(リポジトリID): {repo_id}")

    if token:
        # Best-effort login: a failure is reported but must not stop startup.
        try:
            from huggingface_hub import login

            print("[MODEL] huggingface_hub.login() を実行中...")
            login(token=token, add_to_git_credential=False)
            print("[MODEL] ログイン成功")
        except Exception as err:
            print(f"[MODEL] ログインエラー(続行): {err}")

    os.environ["LLM_MODEL_PATH"] = repo_id
    path_manager.model_path = repo_id
    return repo_id
|
|
|
|
def initialize_model() -> None:
    """Load the RustAdapter for the configured model (runs in a background thread).

    On success the module-global ``adapter`` is set and the status becomes
    "ready". On any exception the status message is set to the error, the
    traceback is printed, and a line is written to stderr. The exception is
    swallowed so the server process keeps running.
    """
    global adapter
    try:
        print("[INIT] モデル初期化スレッド開始")
        _set_status("モデルを読み込み中です...")
        # Imported lazily so an import failure is reported through the status.
        from package.rust_adapter import RustAdapter

        print("[INIT] ensure_model_available() を呼び出し")
        resolved_path = ensure_model_available()
        print(f"[INIT] モデルパス取得: {resolved_path}")

        print("[INIT] RustAdapter.get_instance() を呼び出し")
        adapter = RustAdapter.get_instance(resolved_path)
        print("[INIT] RustAdapter初期化完了")

        _set_status("モデル準備完了")
        print("[INIT] モデル初期化完了")
    except Exception as err:
        import traceback

        failure = f"モデル初期化に失敗しました: {err}"
        print(f"[INIT] エラー: {failure}")
        _set_status(failure)
        traceback.print_exc()
        # Intentionally not re-raised: the API stays up and reports the failure.
        sys.stderr.write(f"[INIT] モデル初期化エラー(プロセスは継続): {err}\n")
|
|
|
|
| |
# Start loading the model at import time on a daemon thread (it will not keep
# the process alive), so the app below can answer /health during the load.
threading.Thread(target=initialize_model, daemon=True).start()

app = FastAPI(
    version="1.0.0",
    title="LLMView Word Tree API",
    description="LLMView の単語ツリー構築 API。/build_word_tree にPOSTしてください。",
)
|
|
|
|
| |
@spaces.GPU
def _gpu_init_function():
    """Do nothing; exists only to carry the ``spaces.GPU`` decorator.

    Per the original author's note, this lets the Space detect GPU usage at
    startup. It is invoked once from ``startup_event``.
    """
    return None
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """On app startup, call the GPU-decorated no-op so the GPU request is made.

    Does nothing when the ``spaces`` package is unavailable. Errors are printed
    and ignored.
    """
    if not SPACES_AVAILABLE:
        return
    try:
        _gpu_init_function()
        print("[SPACE] GPU要求をstartup eventで送信しました")
    except Exception as err:
        print(f"[SPACE] GPU要求エラー: {err}")
|
|
|
|
@app.get("/")
def root() -> Dict[str, str]:
    """Return a short index of the API's endpoints."""
    return dict(
        message="LLMView Word Tree API",
        status_endpoint="/health",
        build_endpoint="/build_word_tree",
    )
|
|
|
|
@app.get("/health")
def health() -> Dict[str, Any]:
    """Report whether the adapter is loaded, the status text, and the model path."""
    # Snapshot the status under the same lock _set_status() writes with.
    with status_lock:
        snapshot = status_message
    loaded = adapter is not None
    return {
        "model_loaded": loaded,
        "status": snapshot,
        "model_path": path_manager.get_model_path(),
    }
|
|
|
|
# Decorators apply bottom-up: ``spaces.GPU`` must sit *below* ``app.post`` so the
# route registers the GPU-wrapped callable. With the opposite order, FastAPI
# registers the bare function and the GPU wrapper never runs for requests.
@app.post("/build_word_tree", response_model=List[WordTreeResponse])
@spaces.GPU
def build_word_tree(payload: WordTreeRequest) -> List[WordTreeResponse]:
    """Build word-tree candidates for ``payload.prompt_text``.

    Returns:
        The candidates from ``adapter.build_word_tree``, which are indexed by
        ``item['text']`` below (presumably dicts matching ``WordTreeResponse``
        -- confirm). Returns the dummy candidates from ``_get_dummy_results()``
        when the adapter returns nothing or raises.

    Raises:
        HTTPException: 400 when ``prompt_text`` is empty or whitespace-only;
            503 while ``adapter`` is still None. The 503 detail carries the
            current status message, which may also be an initialization error.
    """
    if not payload.prompt_text.strip():
        raise HTTPException(status_code=400, detail="prompt_text を入力してください。")

    if adapter is None:
        # Read the status under the lock used by _set_status() and /health.
        with status_lock:
            current_status = status_message
        print("[API] build_word_tree: モデル未準備(adapter is None)")
        raise HTTPException(
            status_code=503, detail=f"モデル準備中です: {current_status}"
        )

    try:
        print(
            f"[API] build_word_tree called: prompt=\n##########################\n{payload.prompt_text}\n##########################\n"
            f"root=\n%%%%%%%%%%%%%%%%%%%\n{payload.root_text}\n%%%%%%%%%%%%%%%%%%%\n"
            f"top_k={payload.top_k}, max_depth={payload.max_depth}\n\n\n\n"
        )

        results = adapter.build_word_tree(
            prompt_text=payload.prompt_text,
            root_text=payload.root_text,
            top_k=payload.top_k,
            max_depth=payload.max_depth,
        )

        # Check for an empty/None result *before* indexing into it, so a None
        # result falls back cleanly instead of producing a spurious traceback.
        if not results:
            print("[API] No candidates generated, returning dummy candidates")
            return _get_dummy_results()

        print(f"--------------------------------\n[API] results: {[item['text'] for item in results]}\n--------------------------------")
        return results
    except Exception as exc:
        # Deliberate best effort: any adapter failure degrades to dummy
        # candidates rather than a 500, and the traceback is kept in the logs.
        import traceback

        traceback.print_exc()
        print(f"[API] build_word_tree error: {exc}, fallback to dummy results")
        return _get_dummy_results()
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Quiet by default; UVICORN_LOG_LEVEL / UVICORN_ACCESS_LOG ("true") override.
    log_level = os.getenv("UVICORN_LOG_LEVEL", "warning")
    access_log_enabled = os.getenv("UVICORN_ACCESS_LOG", "false").lower() == "true"
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level=log_level,
        access_log=access_log_enabled,
    )
|
|
|
|
|
|