File size: 11,979 Bytes
0447f30
 
 
 
 
 
 
 
 
 
 
f1d5201
 
0447f30
e1a7842
 
 
 
 
 
0447f30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1d5201
 
 
 
 
 
 
 
 
 
 
 
0447f30
 
 
 
 
 
 
 
3fbc018
 
 
 
 
 
 
 
 
 
0447f30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1a7842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0447f30
e1a7842
 
 
 
0447f30
e1a7842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0447f30
e1a7842
 
 
 
 
 
 
 
0447f30
e1a7842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0447f30
 
0dac8e2
 
 
 
 
f1d5201
 
e1a7842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1d5201
3fbc018
f1d5201
 
 
 
 
 
 
 
 
 
 
3fbc018
e1a7842
 
3fbc018
f1d5201
 
 
 
3fbc018
f1d5201
 
 
 
 
 
 
3fbc018
 
 
 
 
f1d5201
 
 
 
 
 
 
 
3fbc018
 
f1d5201
 
 
 
 
3fbc018
 
f1d5201
 
0447f30
e1a7842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0447f30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
#!/usr/bin/env python3
"""
LLMView Multi-Model - Gradioアプリ
Hugging Face Spaces用
"""
import os
import sys
import threading
from pathlib import Path
from typing import List, Dict, Any, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# Gradio UIを有効化するかどうか(環境変数で制御)
ENABLE_GRADIO_UI = os.getenv("ENABLE_GRADIO_UI", "false").lower() == "true"

if ENABLE_GRADIO_UI:
    import gradio as gr

# ZeroGPU対応: spacesパッケージをインポート(デコレータ用)
try:
    import spaces
    SPACES_AVAILABLE = True
    print("[SPACE] spacesパッケージをインポートしました")
except ImportError:
    SPACES_AVAILABLE = False
    print("[SPACE] spacesパッケージが見つかりません(ローカル環境の可能性)")
    # ダミーデコレータを定義
    class DummyGPU:
        def __call__(self, func):
            return func
    spaces = type('spaces', (), {'GPU': DummyGPU()})()

# パッケージパスを追加
sys.path.insert(0, str(Path(__file__).parent))

from package.ai import get_ai_model
from package.word_processor import WordDeterminer, WordPiece
from package.adapter import ModelAdapter

# グローバル変数
adapter: Optional[ModelAdapter] = None
status_message = "モデル初期化中..."
status_lock = threading.Lock()

# 環境変数から設定を取得
MODEL_TYPE = os.getenv("MODEL_TYPE", "transformers")
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct")

# FastAPI用のリクエスト/レスポンスモデル
class WordTreeRequest(BaseModel):
    prompt_text: str = Field(..., description="生成に使用するプロンプト")
    root_text: str = Field("", description="任意のルートテキスト")
    top_k: int = Field(5, ge=1, le=50, description="取得する候補数")
    max_depth: int = Field(10, ge=1, le=50, description="探索深さ")


class WordTreeResponse(BaseModel):
    text: str
    probability: float


def _set_status(message: str) -> None:
    """ステータスメッセージを更新"""
    global status_message
    with status_lock:
        status_message = message


def _get_dummy_results() -> List[WordTreeResponse]:
    """モデルが未準備・異常時に返すダミー候補"""
    dummy_payload = [
        {"text": "[eos]", "probability": 0.8},
        {"text": "#dummy#候補2", "probability": 0.6},
        {"text": "#dummy#候補3", "probability": 0.4},
    ]
    return [WordTreeResponse(**item) for item in dummy_payload]


def initialize_model() -> None:
    """モデルを初期化"""
    global adapter
    try:
        print("[INIT] モデル初期化開始")
        _set_status("モデルを読み込み中です...")
        
        # AIモデルを取得
        ai_model = get_ai_model()
        print(f"[INIT] AIモデル取得成功: {type(ai_model)}")
        
        # ModelAdapterを初期化
        adapter = ModelAdapter(ai_model)
        print("[INIT] ModelAdapter初期化完了")
        
        _set_status("モデル準備完了")
        print("[INIT] モデル初期化完了")
    except Exception as exc:
        error_msg = f"モデル初期化に失敗しました: {exc}"
        print(f"[INIT] エラー: {error_msg}")
        _set_status(error_msg)
        import traceback
        traceback.print_exc()


# バックグラウンドでモデルを初期化
threading.Thread(target=initialize_model, daemon=True).start()


def build_word_tree(
    prompt_text: str,
    root_text: str = "",
    top_k: int = 5,
    max_depth: int = 10
) -> List[Dict[str, Any]]:
    """
    単語ツリーを構築
    
    Args:
        prompt_text: プロンプトテキスト
        root_text: ルートテキスト(オプション)
        top_k: 取得する候補数
        max_depth: 最大深さ
        
    Returns:
        List[Dict[str, Any]]: 候補リスト
    """
    if not prompt_text.strip():
        return [{"text": "プロンプトを入力してください", "probability": 0.0}]
    
    if adapter is None:
        with status_lock:
            current_status = status_message
        return [{"text": f"モデル準備中: {current_status}", "probability": 0.0}]
    
    try:
        results = adapter.build_word_tree(
            prompt_text=prompt_text,
            root_text=root_text,
            top_k=top_k,
            max_depth=max_depth,
        )
        
        if not results:
            return [{"text": "候補が生成されませんでした", "probability": 0.0}]
        
        return results
    except Exception as exc:
        import traceback
        traceback.print_exc()
        return [{"text": f"エラー: {exc}", "probability": 0.0}]


def get_status() -> str:
    """ステータスを取得"""
    with status_lock:
        current_status = status_message
    
    model_info = f"モデルタイプ: {MODEL_TYPE}\n"
    if MODEL_TYPE == "transformers":
        model_info += f"モデル: {HF_MODEL_REPO}\n"
    
    return f"{model_info}ステータス: {current_status}"


# Gradioインターフェース(オプション)
demo = None
if ENABLE_GRADIO_UI:
    with gr.Blocks(title="LLMView Multi-Model", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # LLMView Multi-Model
        
        複数のAIモデルに対応した単語ツリー構築ツール
        
        ## 使い方
        1. プロンプトを入力
        2. オプションでルートテキストを指定(既存のテキストの続きを生成する場合)
        3. パラメータを調整(top_k: 候補数、max_depth: 最大深さ)
        4. 「単語ツリーを構築」ボタンをクリック
        """)
        
        with gr.Row():
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="プロンプト",
                    placeholder="例: 電球を作ったのは誰?",
                    lines=3
                )
                root_input = gr.Textbox(
                    label="ルートテキスト(オプション)",
                    placeholder="例: 電球を作ったのは",
                    lines=2
                )
                
                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="候補数 (top_k)"
                    )
                    max_depth_slider = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=10,
                        step=1,
                        label="最大深さ (max_depth)"
                    )
                
                build_btn = gr.Button("単語ツリーを構築", variant="primary")
            
            with gr.Column(scale=1):
                status_output = gr.Textbox(
                    label="ステータス",
                    value=get_status(),
                    lines=5,
                    interactive=False
                )
                refresh_status_btn = gr.Button("ステータス更新")
        
        results_output = gr.Dataframe(
            label="結果",
            headers=["テキスト", "確率"],
            datatype=["str", "number"],
            interactive=False
        )
        
        # イベントハンドラ
        def build_and_display(prompt, root, top_k, max_depth):
            results = build_word_tree(prompt, root, int(top_k), int(max_depth))
            # DataFrame用に変換
            df_data = [[r["text"], f"{r['probability']:.4f}"] for r in results]
            return df_data, get_status()
        
        build_btn.click(
            fn=build_and_display,
            inputs=[prompt_input, root_input, top_k_slider, max_depth_slider],
            outputs=[results_output, status_output]
        )
        
        refresh_status_btn.click(
            fn=lambda: get_status(),
            outputs=status_output
        )


# FastAPIアプリを作成(元のLLMViewと同じ構造)
app = FastAPI(
    title="LLMView Multi-Model API",
    description="LLMView の単語ツリー構築 API。/build_word_tree にPOSTしてください。",
    version="1.0.0",
)


# ZeroGPU対応: 起動時に検出されるように、デコレータ付き関数を定義
@spaces.GPU
def _gpu_init_function():
    """GPU初期化用のダミー関数(Space起動時に検出される)"""
    pass


@app.on_event("startup")
async def startup_event():
    """アプリ起動時の処理(GPU要求を確実に検出させる)"""
    if SPACES_AVAILABLE:
        try:
            _gpu_init_function()
            print("[SPACE] GPU要求をstartup eventで送信しました")
        except Exception as e:
            print(f"[SPACE] GPU要求エラー: {e}")


@app.get("/")
def root() -> Dict[str, str]:
    """簡易案内(元のLLMViewと同じ)"""
    return {
        "message": "LLMView Multi-Model API",
        "status_endpoint": "/health",
        "build_endpoint": "/build_word_tree",
    }


@app.get("/health")
def health() -> Dict[str, Any]:
    """状態確認(元のLLMViewと同じ形式)"""
    with status_lock:
        current_status = status_message
    
    return {
        "model_loaded": adapter is not None,
        "status": current_status,
        "model_type": MODEL_TYPE,
        "model_path": HF_MODEL_REPO if MODEL_TYPE == "transformers" else None,
    }


@spaces.GPU  # ZeroGPU対応: デコレータを先に適用(Space起動時に検出される)
@app.post("/build_word_tree", response_model=List[WordTreeResponse])
def api_build_word_tree(payload: WordTreeRequest) -> List[WordTreeResponse]:
    """単語ツリーを構築(元のLLMViewと同じAPI)"""
    if not payload.prompt_text.strip():
        raise HTTPException(status_code=400, detail="prompt_text を入力してください。")

    if adapter is None:
        print("[API] build_word_tree: モデル未準備(adapter is None)")
        with status_lock:
            current_status = status_message
        raise HTTPException(
            status_code=503, detail=f"モデル準備中です: {current_status}"
        )

    try:
        print(
            f"[API] build_word_tree called: prompt=\n##########################\n{payload.prompt_text}\n##########################\n', "
            f"root=\n%%%%%%%%%%%%%%%%%%%\n{payload.root_text}\n%%%%%%%%%%%%%%%%%%%\n', top_k={payload.top_k}, max_depth={payload.max_depth}"
        )

        results = adapter.build_word_tree(
            prompt_text=payload.prompt_text,
            root_text=payload.root_text,
            top_k=payload.top_k,
            max_depth=payload.max_depth,
        )

        if not results:
            print("[API] No candidates generated, returning dummy candidates")
            results = _get_dummy_results()

        return [WordTreeResponse(**item) for item in results]
    except Exception as exc:
        import traceback
        traceback.print_exc()
        print(f"[API] build_word_tree error: {exc}, fallback to dummy results")
        return _get_dummy_results()


if __name__ == "__main__":
    # Gradio UIが有効な場合はGradioアプリを起動、無効な場合はFastAPIのみ
    if ENABLE_GRADIO_UI and demo is not None:
        # GradioアプリにFastAPIを統合
        demo.fastapi_app = app
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    else:
        # FastAPIのみ(元のLLMViewと同じ)
        import uvicorn
        uvicorn.run(
            app,
            host="0.0.0.0",
            port=7860,
            log_level=os.getenv("UVICORN_LOG_LEVEL", "warning"),
            access_log=os.getenv("UVICORN_ACCESS_LOG", "false").lower() == "true",
        )