WatNeru commited on
Commit
e1a7842
·
1 Parent(s): 0dac8e2

gradio ui disable

Browse files
Files changed (1) hide show
  1. app.py +133 -95
app.py CHANGED
@@ -9,11 +9,15 @@ import threading
9
  from pathlib import Path
10
  from typing import List, Dict, Any, Optional
11
 
12
- import gradio as gr
13
  from fastapi import FastAPI, HTTPException
14
- from fastapi.responses import JSONResponse
15
  from pydantic import BaseModel, Field
16
 
 
 
 
 
 
 
17
  # ZeroGPU対応: spacesパッケージをインポート(デコレータ用)
18
  try:
19
  import spaces
@@ -159,84 +163,94 @@ def get_status() -> str:
159
  return f"{model_info}ステータス: {current_status}"
160
 
161
 
162
- # Gradioインターフェース
163
- with gr.Blocks(title="LLMView Multi-Model", theme=gr.themes.Soft()) as demo:
164
- gr.Markdown("""
165
- # LLMView Multi-Model
166
-
167
- 複数のAIモデルに対応した単語ツリー構築ツール
168
-
169
- ## 使い方
170
- 1. プロンプトを入力
171
- 2. オプションでルートテキストを指定(既存のテキストの続きを生成する場合)
172
- 3. パラメータ調整(top_k: 候補数、max_depth: 最大深さ)
173
- 4. 「単語ツリーを構築」ボタンクリック
174
- """)
175
-
176
- with gr.Row():
177
- with gr.Column(scale=2):
178
- prompt_input = gr.Textbox(
179
- label="プロンプト",
180
- placeholder="例: 電球を作ったのは誰?",
181
- lines=3
182
- )
183
- root_input = gr.Textbox(
184
- label="ルートテキスト(オプション)",
185
- placeholder="例: 電球を作ったのは",
186
- lines=2
187
- )
188
-
189
- with gr.Row():
190
- top_k_slider = gr.Slider(
191
- minimum=1,
192
- maximum=20,
193
- value=5,
194
- step=1,
195
- label="候補数 (top_k)"
196
  )
197
- max_depth_slider = gr.Slider(
198
- minimum=1,
199
- maximum=50,
200
- value=10,
201
- step=1,
202
- label="最大深さ (max_depth)"
203
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- build_btn = gr.Button("単語ツリーを構築", variant="primary")
 
 
 
 
 
 
 
206
 
207
- with gr.Column(scale=1):
208
- status_output = gr.Textbox(
209
- label="ータス",
210
- value=get_status(),
211
- lines=5,
212
- interactive=False
213
- )
214
- refresh_status_btn = gr.Button("ステータス更新")
215
-
216
- results_output = gr.Dataframe(
217
- label="結果",
218
- headers=["テキスト", "確率"],
219
- datatype=["str", "number"],
220
- interactive=False
221
- )
222
-
223
- # イベントハンドラ
224
- def build_and_display(prompt, root, top_k, max_depth):
225
- results = build_word_tree(prompt, root, int(top_k), int(max_depth))
226
- # DataFrame用に変換
227
- df_data = [[r["text"], f"{r['probability']:.4f}"] for r in results]
228
- return df_data, get_status()
229
-
230
- build_btn.click(
231
- fn=build_and_display,
232
- inputs=[prompt_input, root_input, top_k_slider, max_depth_slider],
233
- outputs=[results_output, status_output]
234
- )
235
-
236
- refresh_status_btn.click(
237
- fn=lambda: get_status(),
238
- outputs=status_output
239
- )
240
 
241
 
242
  # ZeroGPU対応: 起動時に検出されるように、デコレータ付き関数を定義
@@ -246,9 +260,28 @@ def _gpu_init_function():
246
  pass
247
 
248
 
249
- # GradioアプリのFastAPIインスタンスに直接ルートを追加
250
- # Gradioアプリは内部でFastAPIインスタンスを持っているので、それに直接ルートを追加
251
- @demo.app.get("/health")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  def health() -> Dict[str, Any]:
253
  """状態確認(元のLLMViewと同じ形式)"""
254
  with status_lock:
@@ -263,8 +296,8 @@ def health() -> Dict[str, Any]:
263
 
264
 
265
  @spaces.GPU # ZeroGPU対応: デコレータを先に適用(Space起動時に検出される)
266
- @demo.app.post("/build_word_tree", response_model=List[WordTreeResponse])
267
- def build_word_tree(payload: WordTreeRequest) -> List[WordTreeResponse]:
268
  """単語ツリーを構築(元のLLMViewと同じAPI)"""
269
  if not payload.prompt_text.strip():
270
  raise HTTPException(status_code=400, detail="prompt_text を入力してください。")
@@ -303,18 +336,23 @@ def build_word_tree(payload: WordTreeRequest) -> List[WordTreeResponse]:
303
 
304
 
305
  if __name__ == "__main__":
306
- # Hugging Face Spaces用設定
307
- # GPU要求を確実に検出させる
308
- if SPACES_AVAILABLE:
309
- try:
310
- _gpu_init_function()
311
- print("[SPACE] GPU要求を送信しました")
312
- except Exception as e:
313
- print(f"[SPACE] GPU要求エラー: {e}")
314
-
315
- demo.launch(
316
- server_name="0.0.0.0",
317
- server_port=7860,
318
- share=False
319
- )
 
 
 
 
 
320
 
 
9
  from pathlib import Path
10
  from typing import List, Dict, Any, Optional
11
 
 
12
  from fastapi import FastAPI, HTTPException
 
13
  from pydantic import BaseModel, Field
14
 
15
+ # Gradio UIを有効化するかどうか(環境変数で制御)
16
+ ENABLE_GRADIO_UI = os.getenv("ENABLE_GRADIO_UI", "false").lower() == "true"
17
+
18
+ if ENABLE_GRADIO_UI:
19
+ import gradio as gr
20
+
21
  # ZeroGPU対応: spacesパッケージをインポート(デコレータ用)
22
  try:
23
  import spaces
 
163
  return f"{model_info}ステータス: {current_status}"
164
 
165
 
166
+ # Gradioインターフェース(オプション)
167
+ demo = None
168
+ if ENABLE_GRADIO_UI:
169
+ with gr.Blocks(title="LLMView Multi-Model", theme=gr.themes.Soft()) as demo:
170
+ gr.Markdown("""
171
+ # LLMView Multi-Model
172
+
173
+ 複数のAIモデルに対応した単語ツリー構築ツール
174
+
175
+ ## 使い方
176
+ 1. プロンプト入力
177
+ 2. オプションでルトテキスト指定(既存のテキストの続き生成する場合)
178
+ 3. パラメータを調整(top_k: 候補数、max_depth: 最大深さ)
179
+ 4. 「単語ツリーを構築」ボタンをクリック
180
+ """)
181
+
182
+ with gr.Row():
183
+ with gr.Column(scale=2):
184
+ prompt_input = gr.Textbox(
185
+ label="プロンプト",
186
+ placeholder="例: 電球を作ったのは誰?",
187
+ lines=3
 
 
 
 
 
 
 
 
 
 
 
 
188
  )
189
+ root_input = gr.Textbox(
190
+ label="ルートテキスト(オプション)",
191
+ placeholder="例: 電球を作ったのは",
192
+ lines=2
 
 
193
  )
194
+
195
+ with gr.Row():
196
+ top_k_slider = gr.Slider(
197
+ minimum=1,
198
+ maximum=20,
199
+ value=5,
200
+ step=1,
201
+ label="候補数 (top_k)"
202
+ )
203
+ max_depth_slider = gr.Slider(
204
+ minimum=1,
205
+ maximum=50,
206
+ value=10,
207
+ step=1,
208
+ label="最大深さ (max_depth)"
209
+ )
210
+
211
+ build_btn = gr.Button("単語ツリーを構築", variant="primary")
212
 
213
+ with gr.Column(scale=1):
214
+ status_output = gr.Textbox(
215
+ label="ステータス",
216
+ value=get_status(),
217
+ lines=5,
218
+ interactive=False
219
+ )
220
+ refresh_status_btn = gr.Button("ステータス更新")
221
 
222
+ results_output = gr.Dataframe(
223
+ label="結果",
224
+ headers=["テ", "確率"],
225
+ datatype=["str", "number"],
226
+ interactive=False
227
+ )
228
+
229
+ # イベントハンドラ
230
+ def build_and_display(prompt, root, top_k, max_depth):
231
+ results = build_word_tree(prompt, root, int(top_k), int(max_depth))
232
+ # DataFrame用に変換
233
+ df_data = [[r["text"], f"{r['probability']:.4f}"] for r in results]
234
+ return df_data, get_status()
235
+
236
+ build_btn.click(
237
+ fn=build_and_display,
238
+ inputs=[prompt_input, root_input, top_k_slider, max_depth_slider],
239
+ outputs=[results_output, status_output]
240
+ )
241
+
242
+ refresh_status_btn.click(
243
+ fn=lambda: get_status(),
244
+ outputs=status_output
245
+ )
246
+
247
+
248
+ # FastAPIアプリを作成(元のLLMViewと同じ構造)
249
+ app = FastAPI(
250
+ title="LLMView Multi-Model API",
251
+ description="LLMView の単語ツリー構築 API。/build_word_tree にPOSTしてください。",
252
+ version="1.0.0",
253
+ )
 
254
 
255
 
256
  # ZeroGPU対応: 起動時に検出されるように、デコレータ付き関数を定義
 
260
  pass
261
 
262
 
263
+ @app.on_event("startup")
264
+ async def startup_event():
265
+ """アプリ起動時の処理(GPU要求を確実に検出させる)"""
266
+ if SPACES_AVAILABLE:
267
+ try:
268
+ _gpu_init_function()
269
+ print("[SPACE] GPU要求をstartup eventで送信しました")
270
+ except Exception as e:
271
+ print(f"[SPACE] GPU要求エラー: {e}")
272
+
273
+
274
+ @app.get("/")
275
+ def root() -> Dict[str, str]:
276
+ """簡易案内(元のLLMViewと同じ)"""
277
+ return {
278
+ "message": "LLMView Multi-Model API",
279
+ "status_endpoint": "/health",
280
+ "build_endpoint": "/build_word_tree",
281
+ }
282
+
283
+
284
+ @app.get("/health")
285
  def health() -> Dict[str, Any]:
286
  """状態確認(元のLLMViewと同じ形式)"""
287
  with status_lock:
 
296
 
297
 
298
  @spaces.GPU # ZeroGPU対応: デコレータを先に適用(Space起動時に検出される)
299
+ @app.post("/build_word_tree", response_model=List[WordTreeResponse])
300
+ def api_build_word_tree(payload: WordTreeRequest) -> List[WordTreeResponse]:
301
  """単語ツリーを構築(元のLLMViewと同じAPI)"""
302
  if not payload.prompt_text.strip():
303
  raise HTTPException(status_code=400, detail="prompt_text を入力してください。")
 
336
 
337
 
338
  if __name__ == "__main__":
339
+ # Gradio UIが有効な場合はGradioアプリを起動、無効な場合はFastAPI
340
+ if ENABLE_GRADIO_UI and demo is not None:
341
+ # GradioアプリにFastAPIを統合
342
+ demo.fastapi_app = app
343
+ demo.launch(
344
+ server_name="0.0.0.0",
345
+ server_port=7860,
346
+ share=False
347
+ )
348
+ else:
349
+ # FastAPIのみ(元のLLMViewと同じ)
350
+ import uvicorn
351
+ uvicorn.run(
352
+ app,
353
+ host="0.0.0.0",
354
+ port=7860,
355
+ log_level=os.getenv("UVICORN_LOG_LEVEL", "warning"),
356
+ access_log=os.getenv("UVICORN_ACCESS_LOG", "false").lower() == "true",
357
+ )
358