File size: 8,895 Bytes
09c17cd
 
 
 
 
401eff3
09c17cd
df1c720
401eff3
09c17cd
 
 
401eff3
09c17cd
 
df1c720
 
 
 
 
 
 
 
 
 
 
 
 
 
09c17cd
 
 
 
 
 
 
 
 
 
 
 
 
adb0f98
401eff3
 
 
 
 
 
 
 
09c17cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
677f8d1
 
 
 
 
 
 
 
 
 
401eff3
17b871a
df1c720
dbb276c
 
8620bbf
 
 
 
df1c720
17b871a
 
 
 
dbb276c
17b871a
 
8620bbf
 
 
 
 
 
 
df1c720
17b871a
 
df1c720
 
 
401eff3
 
09c17cd
 
 
 
df1c720
09c17cd
 
 
df1c720
401eff3
df1c720
 
 
09c17cd
df1c720
 
09c17cd
df1c720
09c17cd
df1c720
 
 
09c17cd
 
df1c720
 
09c17cd
 
 
 
 
ca562c6
17b871a
 
09c17cd
 
 
 
 
 
 
 
03fb695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09c17cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17b871a
09c17cd
 
 
 
 
 
 
677f8d1
09c17cd
 
 
 
 
677f8d1
52c76d4
58dfa5f
677f8d1
52c76d4
677f8d1
09c17cd
 
 
 
 
 
52c76d4
58dfa5f
09c17cd
677f8d1
 
 
09c17cd
 
 
 
 
677f8d1
 
09c17cd
 
 
 
 
9ddfec3
 
 
 
 
 
 
09c17cd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
#!/usr/bin/env python3
"""
FastAPI 版 LLMView Word Tree サーバー
"""

import os
import threading
import sys
from pathlib import Path
from typing import List, Dict, Any, Optional

from fastapi import FastAPI, HTTPException
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field

# ZeroGPU対応: spacesパッケージをインポート(デコレータ用)
try:
    import spaces
    SPACES_AVAILABLE = True
    print("[SPACE] spacesパッケージをインポートしました")
except ImportError:
    SPACES_AVAILABLE = False
    print("[SPACE] spacesパッケージが見つかりません(ローカル環境の可能性)")
    # ダミーデコレータを定義
    class DummyGPU:
        def __call__(self, func):
            return func
    spaces = type('spaces', (), {'GPU': DummyGPU()})()

try:
    from package.path_manager import get_path_manager
except ImportError:
    from path_manager import get_path_manager  # type: ignore


path_manager = get_path_manager()
path_manager.setup_sys_path()

adapter = None
status_message = "モデル初期化中..."
status_lock = threading.Lock()

HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct")
HF_LOCAL_DIR = Path(
    os.getenv(
        "HF_MODEL_LOCAL_DIR",
        path_manager.base_path / "model_cache",
    )
)
HF_TOKEN = os.getenv("HF_TOKEN")


class WordTreeRequest(BaseModel):
    prompt_text: str = Field(..., description="生成に使用するプロンプト")
    root_text: str = Field("", description="任意のルートテキスト")
    top_k: int = Field(5, ge=1, le=50, description="取得する候補数")
    max_depth: int = Field(10, ge=1, le=50, description="探索深さ")


class WordTreeResponse(BaseModel):
    text: str
    probability: float


def _set_status(message: str) -> None:
    global status_message
    with status_lock:
        status_message = message


def _get_dummy_results() -> List[WordTreeResponse]:
    """モデルが未準備・異常時に返すダミー候補"""
    dummy_payload = [
        {"text": "[eos]", "probability": 0.8},
        {"text": "#dummy#候補2", "probability": 0.6},
        {"text": "#dummy#候補3", "probability": 0.4},
    ]
    return [WordTreeResponse(**item) for item in dummy_payload]


def ensure_model_available() -> str:
    """モデルリポジトリIDを返す(ストレージ節約のため、Hubから直接読み込む)"""
    print(f"[MODEL] ensure_model_available() 開始")
    print(f"[MODEL] モデルリポジトリ: {HF_MODEL_REPO}")
    print(f"[MODEL] HF_TOKEN設定: {'あり' if HF_TOKEN else 'なし'}")
    if HF_TOKEN:
        # トークンの最初の数文字だけを表示(セキュリティのため)
        token_preview = HF_TOKEN[:7] + "..." + HF_TOKEN[-4:] if len(HF_TOKEN) > 11 else "***"
        print(f"[MODEL] HF_TOKENプレビュー: {token_preview} (長さ: {len(HF_TOKEN)})")
    
    # ストレージ節約のため、モデルをダウンロードせず、リポジトリIDを直接返す
    # transformers の from_pretrained() が Hub から直接読み込む
    print(f"[MODEL] ストレージ節約のため、Hubから直接読み込む方式を使用")
    print(f"[MODEL] モデルパス(リポジトリID): {HF_MODEL_REPO}")
    
    # huggingface_hub の login を使って明示的に認証
    if HF_TOKEN:
        try:
            from huggingface_hub import login
            print("[MODEL] huggingface_hub.login() を実行中...")
            login(token=HF_TOKEN, add_to_git_credential=False)
            print("[MODEL] ログイン成功")
        except Exception as login_error:
            print(f"[MODEL] ログインエラー(続行): {login_error}")
    
    # リポジトリIDを返す(transformers が Hub から直接読み込む)
    model_path_str = HF_MODEL_REPO
    os.environ["LLM_MODEL_PATH"] = model_path_str
    path_manager.model_path = model_path_str
    return model_path_str


def initialize_model() -> None:
    """RustAdapter とモデルを初期化"""
    global adapter
    try:
        print("[INIT] モデル初期化スレッド開始")
        _set_status("モデルを読み込み中です...")
        from package.rust_adapter import RustAdapter

        print("[INIT] ensure_model_available() を呼び出し")
        model_path = ensure_model_available()
        print(f"[INIT] モデルパス取得: {model_path}")
        
        print("[INIT] RustAdapter.get_instance() を呼び出し")
        adapter = RustAdapter.get_instance(model_path)
        print("[INIT] RustAdapter初期化完了")
        
        _set_status("モデル準備完了")
        print("[INIT] モデル初期化完了")
    except Exception as exc:  # pragma: no cover
        error_msg = f"モデル初期化に失敗しました: {exc}"
        print(f"[INIT] エラー: {error_msg}")
        _set_status(error_msg)
        import traceback
        traceback.print_exc()
        # プロセスを終了させないように、エラーをログに記録するだけ
        sys.stderr.write(f"[INIT] モデル初期化エラー(プロセスは継続): {exc}\n")


# Space 起動時にバックグラウンドで初期化
threading.Thread(target=initialize_model, daemon=True).start()

# ZeroGPU対応: モジュールレベルでGPU要求(起動時に検出されるように)
# 注意: Space は起動時に @spaces.GPU デコレータをスキャンするため、
# FastAPI のエンドポイント関数に適用する必要がある

app = FastAPI(
    title="LLMView Word Tree API",
    description="LLMView の単語ツリー構築 API。/build_word_tree にPOSTしてください。",
    version="1.0.0",
)


# ZeroGPU対応: 起動時に検出されるように、デコレータ付き関数を定義
@spaces.GPU
def _gpu_init_function():
    """GPU初期化用のダミー関数(Space起動時に検出される)"""
    pass


@app.on_event("startup")
async def startup_event():
    """アプリ起動時の処理(GPU要求を確実に検出させる)"""
    if SPACES_AVAILABLE:
        try:
            _gpu_init_function()
            print("[SPACE] GPU要求をstartup eventで送信しました")
        except Exception as e:
            print(f"[SPACE] GPU要求エラー: {e}")


@app.get("/")
def root() -> Dict[str, str]:
    """簡易案内"""
    return {
        "message": "LLMView Word Tree API",
        "status_endpoint": "/health",
        "build_endpoint": "/build_word_tree",
    }


@app.get("/health")
def health() -> Dict[str, Any]:
    """状態確認"""
    with status_lock:
        current_status = status_message
    return {
        "model_loaded": adapter is not None,
        "status": current_status,
        "model_path": path_manager.get_model_path(),
    }


@spaces.GPU  # ZeroGPU対応: デコレータを先に適用(Space起動時に検出される)
@app.post("/build_word_tree", response_model=List[WordTreeResponse])
def build_word_tree(payload: WordTreeRequest) -> List[WordTreeResponse]:
    """単語ツリーを構築"""
    if not payload.prompt_text.strip():
        raise HTTPException(status_code=400, detail="prompt_text を入力してください。")

    if adapter is None:
        print("[API] build_word_tree: モデル未準備(adapter is None)")
        raise HTTPException(
            status_code=503, detail=f"モデル準備中です: {status_message}"
        )

    try:
        print(
            f"[API] build_word_tree called: prompt=\n##########################\n{payload.prompt_text}\n##########################\n', "
            f"root=\n%%%%%%%%%%%%%%%%%%%\n{payload.root_text}\n%%%%%%%%%%%%%%%%%%%\n', top_k={payload.top_k}, max_depth={payload.max_depth}\n\n\n\n"
        )
        # print(f"[API] Adapter available: {adapter is not None}")

        results = adapter.build_word_tree(
            prompt_text=payload.prompt_text,
            root_text=payload.root_text,
            top_k=payload.top_k,
            max_depth=payload.max_depth,
        )
        # print(f"[API] Generated {len(results)} candidates")
        print(f"--------------------------------\n[API] results: {[item['text'] for item in results]}\n--------------------------------")
        if not results:
            print("[API] No candidates generated, returning dummy candidates")
            results = _get_dummy_results()

        return results
    except Exception as exc:
        import traceback

        traceback.print_exc()
        print(f"[API] build_word_tree error: {exc}, fallback to dummy results")
        return _get_dummy_results()


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
    app,
    host="0.0.0.0",
    port=7860,
    log_level=os.getenv("UVICORN_LOG_LEVEL", "warning"),
    access_log=os.getenv("UVICORN_ACCESS_LOG", "false").lower() == "true",
    )