Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| LLMView Multi-Model - Gradioアプリ | |
| Hugging Face Spaces用 | |
| """ | |
| import os | |
| import sys | |
| import threading | |
| from pathlib import Path | |
| from typing import List, Dict, Any, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, Field | |
| # Gradio UIを有効化するかどうか(環境変数で制御) | |
| ENABLE_GRADIO_UI = os.getenv("ENABLE_GRADIO_UI", "false").lower() == "true" | |
| if ENABLE_GRADIO_UI: | |
| import gradio as gr | |
| # ZeroGPU対応: spacesパッケージをインポート(デコレータ用) | |
| try: | |
| import spaces | |
| SPACES_AVAILABLE = True | |
| print("[SPACE] spacesパッケージをインポートしました") | |
| except ImportError: | |
| SPACES_AVAILABLE = False | |
| print("[SPACE] spacesパッケージが見つかりません(ローカル環境の可能性)") | |
| # ダミーデコレータを定義 | |
| class DummyGPU: | |
| def __call__(self, func): | |
| return func | |
| spaces = type('spaces', (), {'GPU': DummyGPU()})() | |
| # パッケージパスを追加 | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| from package.ai import get_ai_model | |
| from package.word_processor import WordDeterminer, WordPiece | |
| from package.adapter import ModelAdapter | |
| # グローバル変数 | |
| adapter: Optional[ModelAdapter] = None | |
| status_message = "モデル初期化中..." | |
| status_lock = threading.Lock() | |
| # 環境変数から設定を取得 | |
| MODEL_TYPE = os.getenv("MODEL_TYPE", "transformers") | |
| HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct") | |
| # FastAPI用のリクエスト/レスポンスモデル | |
| class WordTreeRequest(BaseModel): | |
| prompt_text: str = Field(..., description="生成に使用するプロンプト") | |
| root_text: str = Field("", description="任意のルートテキスト") | |
| top_k: int = Field(5, ge=1, le=50, description="取得する候補数") | |
| max_depth: int = Field(10, ge=1, le=50, description="探索深さ") | |
| class WordTreeResponse(BaseModel): | |
| text: str | |
| probability: float | |
| def _set_status(message: str) -> None: | |
| """ステータスメッセージを更新""" | |
| global status_message | |
| with status_lock: | |
| status_message = message | |
| def _get_dummy_results() -> List[WordTreeResponse]: | |
| """モデルが未準備・異常時に返すダミー候補""" | |
| dummy_payload = [ | |
| {"text": "[eos]", "probability": 0.8}, | |
| {"text": "#dummy#候補2", "probability": 0.6}, | |
| {"text": "#dummy#候補3", "probability": 0.4}, | |
| ] | |
| return [WordTreeResponse(**item) for item in dummy_payload] | |
| def initialize_model() -> None: | |
| """モデルを初期化""" | |
| global adapter | |
| try: | |
| print("[INIT] モデル初期化開始") | |
| _set_status("モデルを読み込み中です...") | |
| # AIモデルを取得 | |
| ai_model = get_ai_model() | |
| print(f"[INIT] AIモデル取得成功: {type(ai_model)}") | |
| # ModelAdapterを初期化 | |
| adapter = ModelAdapter(ai_model) | |
| print("[INIT] ModelAdapter初期化完了") | |
| _set_status("モデル準備完了") | |
| print("[INIT] モデル初期化完了") | |
| except Exception as exc: | |
| error_msg = f"モデル初期化に失敗しました: {exc}" | |
| print(f"[INIT] エラー: {error_msg}") | |
| _set_status(error_msg) | |
| import traceback | |
| traceback.print_exc() | |
| # バックグラウンドでモデルを初期化 | |
| threading.Thread(target=initialize_model, daemon=True).start() | |
| def build_word_tree( | |
| prompt_text: str, | |
| root_text: str = "", | |
| top_k: int = 5, | |
| max_depth: int = 10 | |
| ) -> List[Dict[str, Any]]: | |
| """ | |
| 単語ツリーを構築 | |
| Args: | |
| prompt_text: プロンプトテキスト | |
| root_text: ルートテキスト(オプション) | |
| top_k: 取得する候補数 | |
| max_depth: 最大深さ | |
| Returns: | |
| List[Dict[str, Any]]: 候補リスト | |
| """ | |
| if not prompt_text.strip(): | |
| return [{"text": "プロンプトを入力してください", "probability": 0.0}] | |
| if adapter is None: | |
| with status_lock: | |
| current_status = status_message | |
| return [{"text": f"モデル準備中: {current_status}", "probability": 0.0}] | |
| try: | |
| results = adapter.build_word_tree( | |
| prompt_text=prompt_text, | |
| root_text=root_text, | |
| top_k=top_k, | |
| max_depth=max_depth, | |
| ) | |
| if not results: | |
| return [{"text": "候補が生成されませんでした", "probability": 0.0}] | |
| return results | |
| except Exception as exc: | |
| import traceback | |
| traceback.print_exc() | |
| return [{"text": f"エラー: {exc}", "probability": 0.0}] | |
| def get_status() -> str: | |
| """ステータスを取得""" | |
| with status_lock: | |
| current_status = status_message | |
| model_info = f"モデルタイプ: {MODEL_TYPE}\n" | |
| if MODEL_TYPE == "transformers": | |
| model_info += f"モデル: {HF_MODEL_REPO}\n" | |
| return f"{model_info}ステータス: {current_status}" | |
| # Gradioインターフェース(オプション) | |
| demo = None | |
| if ENABLE_GRADIO_UI: | |
| with gr.Blocks(title="LLMView Multi-Model", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown(""" | |
| # LLMView Multi-Model | |
| 複数のAIモデルに対応した単語ツリー構築ツール | |
| ## 使い方 | |
| 1. プロンプトを入力 | |
| 2. オプションでルートテキストを指定(既存のテキストの続きを生成する場合) | |
| 3. パラメータを調整(top_k: 候補数、max_depth: 最大深さ) | |
| 4. 「単語ツリーを構築」ボタンをクリック | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| prompt_input = gr.Textbox( | |
| label="プロンプト", | |
| placeholder="例: 電球を作ったのは誰?", | |
| lines=3 | |
| ) | |
| root_input = gr.Textbox( | |
| label="ルートテキスト(オプション)", | |
| placeholder="例: 電球を作ったのは", | |
| lines=2 | |
| ) | |
| with gr.Row(): | |
| top_k_slider = gr.Slider( | |
| minimum=1, | |
| maximum=20, | |
| value=5, | |
| step=1, | |
| label="候補数 (top_k)" | |
| ) | |
| max_depth_slider = gr.Slider( | |
| minimum=1, | |
| maximum=50, | |
| value=10, | |
| step=1, | |
| label="最大深さ (max_depth)" | |
| ) | |
| build_btn = gr.Button("単語ツリーを構築", variant="primary") | |
| with gr.Column(scale=1): | |
| status_output = gr.Textbox( | |
| label="ステータス", | |
| value=get_status(), | |
| lines=5, | |
| interactive=False | |
| ) | |
| refresh_status_btn = gr.Button("ステータス更新") | |
| results_output = gr.Dataframe( | |
| label="結果", | |
| headers=["テキスト", "確率"], | |
| datatype=["str", "number"], | |
| interactive=False | |
| ) | |
| # イベントハンドラ | |
| def build_and_display(prompt, root, top_k, max_depth): | |
| results = build_word_tree(prompt, root, int(top_k), int(max_depth)) | |
| # DataFrame用に変換 | |
| df_data = [[r["text"], f"{r['probability']:.4f}"] for r in results] | |
| return df_data, get_status() | |
| build_btn.click( | |
| fn=build_and_display, | |
| inputs=[prompt_input, root_input, top_k_slider, max_depth_slider], | |
| outputs=[results_output, status_output] | |
| ) | |
| refresh_status_btn.click( | |
| fn=lambda: get_status(), | |
| outputs=status_output | |
| ) | |
| # FastAPIアプリを作成(元のLLMViewと同じ構造) | |
| app = FastAPI( | |
| title="LLMView Multi-Model API", | |
| description="LLMView の単語ツリー構築 API。/build_word_tree にPOSTしてください。", | |
| version="1.0.0", | |
| ) | |
| # ZeroGPU対応: 起動時に検出されるように、デコレータ付き関数を定義 | |
| def _gpu_init_function(): | |
| """GPU初期化用のダミー関数(Space起動時に検出される)""" | |
| pass | |
| async def startup_event(): | |
| """アプリ起動時の処理(GPU要求を確実に検出させる)""" | |
| if SPACES_AVAILABLE: | |
| try: | |
| _gpu_init_function() | |
| print("[SPACE] GPU要求をstartup eventで送信しました") | |
| except Exception as e: | |
| print(f"[SPACE] GPU要求エラー: {e}") | |
| def root() -> Dict[str, str]: | |
| """簡易案内(元のLLMViewと同じ)""" | |
| return { | |
| "message": "LLMView Multi-Model API", | |
| "status_endpoint": "/health", | |
| "build_endpoint": "/build_word_tree", | |
| } | |
| def health() -> Dict[str, Any]: | |
| """状態確認(元のLLMViewと同じ形式)""" | |
| with status_lock: | |
| current_status = status_message | |
| return { | |
| "model_loaded": adapter is not None, | |
| "status": current_status, | |
| "model_type": MODEL_TYPE, | |
| "model_path": HF_MODEL_REPO if MODEL_TYPE == "transformers" else None, | |
| } | |
| # ZeroGPU対応: デコレータを先に適用(Space起動時に検出される) | |
| def api_build_word_tree(payload: WordTreeRequest) -> List[WordTreeResponse]: | |
| """単語ツリーを構築(元のLLMViewと同じAPI)""" | |
| if not payload.prompt_text.strip(): | |
| raise HTTPException(status_code=400, detail="prompt_text を入力してください。") | |
| if adapter is None: | |
| print("[API] build_word_tree: モデル未準備(adapter is None)") | |
| with status_lock: | |
| current_status = status_message | |
| raise HTTPException( | |
| status_code=503, detail=f"モデル準備中です: {current_status}" | |
| ) | |
| try: | |
| print( | |
| f"[API] build_word_tree called: prompt=\n##########################\n{payload.prompt_text}\n##########################\n', " | |
| f"root=\n%%%%%%%%%%%%%%%%%%%\n{payload.root_text}\n%%%%%%%%%%%%%%%%%%%\n', top_k={payload.top_k}, max_depth={payload.max_depth}" | |
| ) | |
| results = adapter.build_word_tree( | |
| prompt_text=payload.prompt_text, | |
| root_text=payload.root_text, | |
| top_k=payload.top_k, | |
| max_depth=payload.max_depth, | |
| ) | |
| if not results: | |
| print("[API] No candidates generated, returning dummy candidates") | |
| results = _get_dummy_results() | |
| return [WordTreeResponse(**item) for item in results] | |
| except Exception as exc: | |
| import traceback | |
| traceback.print_exc() | |
| print(f"[API] build_word_tree error: {exc}, fallback to dummy results") | |
| return _get_dummy_results() | |
| if __name__ == "__main__": | |
| # Gradio UIが有効な場合はGradioアプリを起動、無効な場合はFastAPIのみ | |
| if ENABLE_GRADIO_UI and demo is not None: | |
| # GradioアプリにFastAPIを統合 | |
| demo.fastapi_app = app | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=False | |
| ) | |
| else: | |
| # FastAPIのみ(元のLLMViewと同じ) | |
| import uvicorn | |
| uvicorn.run( | |
| app, | |
| host="0.0.0.0", | |
| port=7860, | |
| log_level=os.getenv("UVICORN_LOG_LEVEL", "warning"), | |
| access_log=os.getenv("UVICORN_ACCESS_LOG", "false").lower() == "true", | |
| ) | |