app.py
CHANGED
|
@@ -66,9 +66,22 @@ class LocalModelManager:
|
|
| 66 |
self.models = {}
|
| 67 |
self.tokenizers = {}
|
| 68 |
self.pipelines = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
|
| 71 |
-
"""モデルの
|
| 72 |
if model_path not in self.models:
|
| 73 |
logger.info(f"Loading model: {model_path}")
|
| 74 |
try:
|
|
@@ -102,6 +115,11 @@ class LocalModelManager:
|
|
| 102 |
logger.error(f"Error loading model {model_path}: {str(e)}")
|
| 103 |
raise
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
@spaces.GPU()
|
| 106 |
def _generate_text_sync(self, pipeline, text: str) -> str:
|
| 107 |
"""同期的なテキスト生成の実行"""
|
|
@@ -157,6 +175,12 @@ class ModelManager:
|
|
| 157 |
model.model_id,
|
| 158 |
token=True # これによりHFトークンを使用
|
| 159 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
async def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
|
| 162 |
"""テキスト生成モデルの実行"""
|
|
@@ -288,6 +312,7 @@ class ToxicityApp:
|
|
| 288 |
def __init__(self):
|
| 289 |
self.ui = UIComponents()
|
| 290 |
self.model_manager = ModelManager()
|
|
|
|
| 291 |
|
| 292 |
def update_model_visibility(self, selected_types: List[str]) -> List[gr.update]:
|
| 293 |
"""モデルの表示状態を更新"""
|
|
@@ -311,40 +336,74 @@ class ToxicityApp:
|
|
| 311 |
class_results.extend([""] * (len(CLASSIFICATION_MODELS) - len(class_results)))
|
| 312 |
|
| 313 |
return gen_results + class_results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
def create_ui(self):
|
| 316 |
"""UIの作成"""
|
| 317 |
with gr.Blocks() as demo:
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
)
|
| 344 |
|
| 345 |
return demo
|
| 346 |
|
| 347 |
def main():
|
|
|
|
| 348 |
app = ToxicityApp()
|
| 349 |
demo = app.create_ui()
|
| 350 |
demo.launch()
|
|
|
|
| 66 |
self.models = {}
|
| 67 |
self.tokenizers = {}
|
| 68 |
self.pipelines = {}
|
| 69 |
+
|
| 70 |
+
def preload_models(self):
|
| 71 |
+
"""起動時にすべてのローカルモデルを事前にロード"""
|
| 72 |
+
logger.info("Preloading all local models...")
|
| 73 |
+
for model in TEXT_GENERATION_MODELS:
|
| 74 |
+
if model.type == ModelType.LOCAL and model.model_path:
|
| 75 |
+
self.load_model_sync(model.model_path, "text-generation")
|
| 76 |
+
|
| 77 |
+
for model in CLASSIFICATION_MODELS:
|
| 78 |
+
if model.type == ModelType.LOCAL and model.model_path:
|
| 79 |
+
self.load_model_sync(model.model_path, "text-classification")
|
| 80 |
+
|
| 81 |
+
logger.info("All local models preloaded successfully")
|
| 82 |
|
| 83 |
+
def load_model_sync(self, model_path: str, task: str = "text-generation"):
|
| 84 |
+
"""モデルの同期ロード"""
|
| 85 |
if model_path not in self.models:
|
| 86 |
logger.info(f"Loading model: {model_path}")
|
| 87 |
try:
|
|
|
|
| 115 |
logger.error(f"Error loading model {model_path}: {str(e)}")
|
| 116 |
raise
|
| 117 |
|
| 118 |
+
async def load_model(self, model_path: str, task: str = "text-generation"):
|
| 119 |
+
"""モデルの遅延ロード(バックワードコンパティビリティのために維持)"""
|
| 120 |
+
if model_path not in self.models:
|
| 121 |
+
self.load_model_sync(model_path, task)
|
| 122 |
+
|
| 123 |
@spaces.GPU()
|
| 124 |
def _generate_text_sync(self, pipeline, text: str) -> str:
|
| 125 |
"""同期的なテキスト生成の実行"""
|
|
|
|
| 175 |
model.model_id,
|
| 176 |
token=True # これによりHFトークンを使用
|
| 177 |
)
|
| 178 |
+
|
| 179 |
+
def preload_models(self):
|
| 180 |
+
"""起動時にすべてのモデルを事前にロード"""
|
| 181 |
+
logger.info("Preloading models...")
|
| 182 |
+
self.local_manager.preload_models()
|
| 183 |
+
logger.info("Model preloading complete")
|
| 184 |
|
| 185 |
async def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
|
| 186 |
"""テキスト生成モデルの実行"""
|
|
|
|
| 312 |
def __init__(self):
|
| 313 |
self.ui = UIComponents()
|
| 314 |
self.model_manager = ModelManager()
|
| 315 |
+
self.models_loaded = False
|
| 316 |
|
| 317 |
def update_model_visibility(self, selected_types: List[str]) -> List[gr.update]:
|
| 318 |
"""モデルの表示状態を更新"""
|
|
|
|
| 336 |
class_results.extend([""] * (len(CLASSIFICATION_MODELS) - len(class_results)))
|
| 337 |
|
| 338 |
return gen_results + class_results
|
| 339 |
+
|
| 340 |
+
def load_models_and_update_ui(self):
|
| 341 |
+
"""モデルをロードしUIを更新する"""
|
| 342 |
+
try:
|
| 343 |
+
# モデルのロード
|
| 344 |
+
self.model_manager.preload_models()
|
| 345 |
+
self.models_loaded = True
|
| 346 |
+
logger.info("Models loaded successfully")
|
| 347 |
+
# ロード完了メッセージを返して、UIのロード中表示を非表示にする
|
| 348 |
+
return gr.update(visible=False), gr.update(visible=True)
|
| 349 |
+
except Exception as e:
|
| 350 |
+
logger.error(f"Error loading models: {e}")
|
| 351 |
+
return gr.update(value=f"Error loading models: {e}"), gr.update(visible=False)
|
| 352 |
|
| 353 |
def create_ui(self):
|
| 354 |
"""UIの作成"""
|
| 355 |
with gr.Blocks() as demo:
|
| 356 |
+
# ロード中コンポーネント
|
| 357 |
+
with gr.Group(visible=True) as loading_group:
|
| 358 |
+
gr.Markdown("""
|
| 359 |
+
# Toxic Eye
|
| 360 |
+
|
| 361 |
+
### Loading models... This may take a few minutes.
|
| 362 |
+
|
| 363 |
+
The application is initializing and preloading all models.
|
| 364 |
+
Please wait while the models are being loaded...
|
| 365 |
+
""")
|
| 366 |
+
|
| 367 |
+
# メインUIコンポーネント(初期状態では非表示)
|
| 368 |
+
with gr.Group(visible=False) as main_ui_group:
|
| 369 |
+
self.ui.create_header()
|
| 370 |
+
self.ui.create_input_section()
|
| 371 |
+
self.ui.create_filter_section()
|
| 372 |
+
self.ui.create_invoke_button()
|
| 373 |
+
self.ui.create_model_tabs()
|
| 374 |
+
|
| 375 |
+
# イベントハンドラの設定
|
| 376 |
+
self.ui.filter_checkboxes.change(
|
| 377 |
+
fn=self.update_model_visibility,
|
| 378 |
+
inputs=[self.ui.filter_checkboxes],
|
| 379 |
+
outputs=[
|
| 380 |
+
output["group"]
|
| 381 |
+
for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]
|
| 382 |
+
for output in outputs
|
| 383 |
+
]
|
| 384 |
+
)
|
| 385 |
|
| 386 |
+
self.ui.invoke_button.click(
|
| 387 |
+
fn=self.handle_invoke,
|
| 388 |
+
inputs=[self.ui.input_text, self.ui.filter_checkboxes],
|
| 389 |
+
outputs=[
|
| 390 |
+
output["output"]
|
| 391 |
+
for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]
|
| 392 |
+
for output in outputs
|
| 393 |
+
]
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
# 起動時にモデルロード処理を実行
|
| 397 |
+
demo.load(
|
| 398 |
+
fn=self.load_models_and_update_ui,
|
| 399 |
+
inputs=None,
|
| 400 |
+
outputs=[loading_group, main_ui_group]
|
| 401 |
)
|
| 402 |
|
| 403 |
return demo
|
| 404 |
|
| 405 |
def main():
|
| 406 |
+
logger.info("Starting Toxic Eye application")
|
| 407 |
app = ToxicityApp()
|
| 408 |
demo = app.create_ui()
|
| 409 |
demo.launch()
|