nyasukun commited on
Commit
2ed7694
·
1 Parent(s): e48aba1
Files changed (1) hide show
  1. app.py +331 -363
app.py CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
2
  from huggingface_hub import AsyncInferenceClient
3
  from typing import List, Dict, Optional, Union
4
  import logging
5
- from dataclasses import dataclass
6
  from enum import Enum, auto
7
  import torch
8
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, pipeline
@@ -15,397 +14,366 @@ logging.basicConfig(
15
  )
16
  logger = logging.getLogger(__name__)
17
 
18
- # モデルの定義
19
- class ModelType(Enum):
20
- LOCAL = "local"
21
- INFERENCE_API = "inference_api"
22
 
23
- @dataclass
24
- class ModelConfig:
25
- name: str
26
- description: str
27
- type: ModelType
28
- model_id: Optional[str] = None
29
- model_path: Optional[str] = None
30
-
31
- # モデル定義を拡充
32
  TEXT_GENERATION_MODELS = [
33
- ModelConfig(
34
- name="Zephyr-7B",
35
- description="Specialized in understanding context and nuance",
36
- type=ModelType.INFERENCE_API,
37
- model_id="HuggingFaceH4/zephyr-7b-beta"
38
- ),
39
- ModelConfig(
40
- name="Llama-2",
41
- description="Known for its robust performance in content analysis",
42
- type=ModelType.LOCAL,
43
- model_path="meta-llama/Llama-2-7b-hf"
44
- ),
45
- ModelConfig(
46
- name="Mistral-7B",
47
- description="Offers precise and detailed text evaluation",
48
- type=ModelType.LOCAL,
49
- model_path="mistralai/Mistral-7B-v0.1"
50
- )
51
  ]
52
 
53
  CLASSIFICATION_MODELS = [
54
- ModelConfig(
55
- name="Toxic-BERT",
56
- description="Fine-tuned for toxic content detection",
57
- type=ModelType.LOCAL,
58
- model_path="unitary/toxic-bert"
59
- )
60
  ]
61
 
62
- class LocalModelManager:
63
- def __init__(self):
64
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
65
- logger.info(f"Using device: {self.device}")
66
- self.models = {}
67
- self.tokenizers = {}
68
- self.pipelines = {}
69
-
70
- def preload_models(self):
71
- """起動時にすべてのローカルモデルを事前にロード"""
72
- logger.info("Preloading all local models...")
73
- for model in TEXT_GENERATION_MODELS:
74
- if model.type == ModelType.LOCAL and model.model_path:
75
- self.load_model_sync(model.model_path, "text-generation")
76
-
77
- for model in CLASSIFICATION_MODELS:
78
- if model.type == ModelType.LOCAL and model.model_path:
79
- self.load_model_sync(model.model_path, "text-classification")
80
-
81
- logger.info("All local models preloaded successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- def load_model_sync(self, model_path: str, task: str = "text-generation"):
84
- """モデルの同期ロード"""
85
- if model_path not in self.models:
86
- logger.info(f"Loading model: {model_path}")
 
 
87
  try:
88
- self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
89
-
90
- if task == "text-generation":
91
- model = AutoModelForCausalLM.from_pretrained(
92
- model_path,
93
- torch_dtype=torch.float16,
94
- device_map="auto"
95
- )
96
- self.pipelines[model_path] = pipeline(
97
- "text-generation",
98
- model=model,
99
- tokenizer=self.tokenizers[model_path]
100
- )
101
- else: # classification
102
- model = AutoModelForSequenceClassification.from_pretrained(
103
- model_path,
104
- device_map="auto"
105
- )
106
- self.pipelines[model_path] = pipeline(
107
- "text-classification",
108
- model=model,
109
- tokenizer=self.tokenizers[model_path]
110
- )
111
-
112
- self.models[model_path] = model
113
- logger.info(f"Model loaded successfully: {model_path}")
114
  except Exception as e:
115
- logger.error(f"Error loading model {model_path}: {str(e)}")
116
- raise
117
-
118
- async def load_model(self, model_path: str, task: str = "text-generation"):
119
- """モデルの遅延ロード(バックワードコンパティビリティのために維持)"""
120
- if model_path not in self.models:
121
- self.load_model_sync(model_path, task)
122
-
123
- @spaces.GPU()
124
- def _generate_text_sync(self, pipeline, text: str) -> str:
125
- """同期的なテキスト生成の実行"""
126
- outputs = pipeline(
 
 
 
 
 
 
 
 
127
  text,
128
- max_new_tokens=100,
129
  do_sample=True,
130
  temperature=0.7,
131
  top_p=0.9,
132
  num_return_sequences=1
133
  )
134
  return outputs[0]["generated_text"]
135
-
136
- async def generate_text(self, model_path: str, text: str) -> str:
137
- """テキスト生成の実行(非同期ラッパー)"""
138
- if model_path not in self.models:
139
- await self.load_model(model_path, "text-generation")
140
-
141
- try:
142
- return self._generate_text_sync(self.pipelines[model_path], text)
143
- except Exception as e:
144
- logger.error(f"Error in text generation with {model_path}: {str(e)}")
145
- raise
146
-
147
- @spaces.GPU()
148
- def _classify_text_sync(self, pipeline, text: str) -> str:
149
- """同期的なテキスト分類の実行"""
150
- result = pipeline(text)
151
- return str(result)
152
-
153
- async def classify_text(self, model_path: str, text: str) -> str:
154
- """テキスト分類の実行(非同期ラッパー)"""
155
- if model_path not in self.models:
156
- await self.load_model(model_path, "text-classification")
157
-
158
- try:
159
- return self._classify_text_sync(self.pipelines[model_path], text)
160
- except Exception as e:
161
- logger.error(f"Error in classification with {model_path}: {str(e)}")
162
- raise
163
-
164
- class ModelManager:
165
- def __init__(self):
166
- self.api_clients = {}
167
- self.local_manager = LocalModelManager()
168
- self._initialize_clients()
169
-
170
- def _initialize_clients(self):
171
- """Inference APIクライアントの初期化"""
172
- for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
173
- if model.type == ModelType.INFERENCE_API and model.model_id:
174
- self.api_clients[model.model_id] = AsyncInferenceClient(
175
- model.model_id,
176
- token=True # これによりHFトークンを使用
177
- )
178
 
179
- def preload_models(self):
180
- """起動時にすべてのモデルを事前にロード"""
181
- logger.info("Preloading models...")
182
- self.local_manager.preload_models()
183
- logger.info("Model preloading complete")
184
-
185
- async def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
186
- """テキスト生成モデルの実行"""
187
- results = []
188
- for model in TEXT_GENERATION_MODELS:
189
- if model.type.value in selected_types:
190
- try:
191
- if model.type == ModelType.INFERENCE_API:
192
- logger.info(f"Running API text generation: {model.name}")
193
- response = await self.api_clients[model.model_id].text_generation(
194
- text, max_new_tokens=100, temperature=0.7
195
- )
196
- results.append(f"{model.name}: {response}")
197
- else:
198
- logger.info(f"Running local text generation: {model.name}")
199
- response = await self.local_manager.generate_text(model.model_path, text)
200
- results.append(f"{model.name}: {response}")
201
- except Exception as e:
202
- logger.error(f"Error in {model.name}: {str(e)}")
203
- results.append(f"{model.name}: Error - {str(e)}")
204
- return results
205
-
206
- async def run_classification(self, text: str, selected_types: List[str]) -> List[str]:
207
- """分類モデルの実行"""
208
- results = []
209
- for model in CLASSIFICATION_MODELS:
210
- if model.type.value in selected_types:
211
- try:
212
- if model.type == ModelType.INFERENCE_API:
213
- logger.info(f"Running API classification: {model.name}")
214
- response = await self.api_clients[model.model_id].text_classification(text)
215
- results.append(f"{model.name}: {response}")
216
- else:
217
- logger.info(f"Running local classification: {model.name}")
218
- response = await self.local_manager.classify_text(model.model_path, text)
219
- results.append(f"{model.name}: {response}")
220
- except Exception as e:
221
- logger.error(f"Error in {model.name}: {str(e)}")
222
- results.append(f"{model.name}: Error - {str(e)}")
223
- return results
224
-
225
- class UIComponents:
226
- def __init__(self):
227
- self.input_text = None
228
- self.filter_checkboxes = None
229
- self.invoke_button = None
230
- self.gen_model_outputs = []
231
- self.class_model_outputs = []
232
- self.community_output = None
233
-
234
- def create_header(self):
235
- """ヘッダーセクションの作成"""
236
- return gr.Markdown("""
237
- # Toxic Eye
238
- This system evaluates the toxicity level of input text using multiple approaches.
239
- """)
240
-
241
- def create_input_section(self):
242
- """入力セクションの作成"""
243
- with gr.Row():
244
- self.input_text = gr.Textbox(
245
- label="Input Text",
246
- placeholder="Enter text to analyze...",
247
- lines=3
248
- )
249
-
250
- def create_filter_section(self):
251
- """フィルターセクションの作成"""
252
- with gr.Row():
253
- self.filter_checkboxes = gr.CheckboxGroup(
254
- choices=[t.value for t in ModelType],
255
- value=[t.value for t in ModelType],
256
- label="Filter Models",
257
- info="Choose which types of models to display",
258
- interactive=True
259
- )
260
-
261
- def create_invoke_button(self):
262
- """Invokeボタンの作成"""
263
- with gr.Row():
264
- self.invoke_button = gr.Button(
265
- "Invoke Selected Models",
266
- variant="primary",
267
- size="lg"
268
- )
269
-
270
- def create_model_grid(self, models: List[ModelConfig]) -> List[Dict]:
271
- """モデルグリッドの作成"""
272
- outputs = []
273
- with gr.Column() as container:
274
- for i in range(0, len(models), 2):
275
- with gr.Row() as row:
276
- for j in range(min(2, len(models) - i)):
277
- model = models[i + j]
278
- with gr.Column():
279
- with gr.Group() as group:
280
- gr.Markdown(f"### {model.name}")
281
- gr.Markdown(f"Type: {model.type.value}")
282
- output = gr.Textbox(
283
- label="Model Output",
284
- lines=5,
285
- interactive=False,
286
- info=model.description
287
- )
288
- outputs.append({
289
- "type": model.type.value,
290
- "name": model.name,
291
- "output": output,
292
- "group": group
293
- })
294
- return outputs
295
-
296
- def create_model_tabs(self):
297
- """モデルタブの作成"""
298
- with gr.Tabs():
299
- with gr.Tab("Text Generation LLM"):
300
- self.gen_model_outputs = self.create_model_grid(TEXT_GENERATION_MODELS)
301
- with gr.Tab("Classification LLM"):
302
- self.class_model_outputs = self.create_model_grid(CLASSIFICATION_MODELS)
303
- with gr.Tab("Community (Not implemented)"):
304
- with gr.Column():
305
- self.community_output = gr.Textbox(
306
- label="Related Community Topics",
307
- lines=5,
308
- interactive=False
309
  )
310
-
311
- class ToxicityApp:
312
- def __init__(self):
313
- self.ui = UIComponents()
314
- self.model_manager = ModelManager()
315
- self.models_loaded = False
316
-
317
- def update_model_visibility(self, selected_types: List[str]) -> List[gr.update]:
318
- """モデルの表示状態を更新"""
319
- logger.info(f"Updating visibility for types: {selected_types}")
320
-
321
- updates = []
322
- for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]:
323
- for output in outputs:
324
- visible = output["type"] in selected_types
325
- logger.info(f"Model {output['name']} (type: {output['type']}): visible = {visible}")
326
- updates.append(gr.update(visible=visible))
327
- return updates
328
-
329
- async def handle_invoke(self, text: str, selected_types: List[str]) -> List[str]:
330
- """Invokeボタンのハンドラ"""
331
- gen_results = await self.model_manager.run_text_generation(text, selected_types)
332
- class_results = await self.model_manager.run_classification(text, selected_types)
333
-
334
- # 結果リストの長さを調整
335
- gen_results.extend([""] * (len(TEXT_GENERATION_MODELS) - len(gen_results)))
336
- class_results.extend([""] * (len(CLASSIFICATION_MODELS) - len(class_results)))
337
-
338
- return gen_results + class_results
 
 
 
 
 
 
339
 
340
- def load_models_and_update_ui(self):
341
- """モデルをロードしUIを更新する"""
342
- try:
343
- # モデルのロード
344
- self.model_manager.preload_models()
345
- self.models_loaded = True
346
- logger.info("Models loaded successfully")
347
- # ロード完了メッセージを返して、UIのロード中表示を非表示にする
348
- return gr.update(visible=False), gr.update(visible=True)
349
- except Exception as e:
350
- logger.error(f"Error loading models: {e}")
351
- return gr.update(value=f"Error loading models: {e}"), gr.update(visible=False)
352
 
353
- def create_ui(self):
354
- """UIの作成"""
355
- with gr.Blocks() as demo:
356
- # ロード中コンポーネント
357
- with gr.Group(visible=True) as loading_group:
358
- gr.Markdown("""
359
- # Toxic Eye
360
-
361
- ### Loading models... This may take a few minutes.
362
-
363
- The application is initializing and preloading all models.
364
- Please wait while the models are being loaded...
365
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
 
367
- # メインUIコンポーネント(初期状態では非表示)
368
- with gr.Group(visible=False) as main_ui_group:
369
- self.ui.create_header()
370
- self.ui.create_input_section()
371
- self.ui.create_filter_section()
372
- self.ui.create_invoke_button()
373
- self.ui.create_model_tabs()
374
-
375
- # イベントハンドラの設定
376
- self.ui.filter_checkboxes.change(
377
- fn=self.update_model_visibility,
378
- inputs=[self.ui.filter_checkboxes],
379
- outputs=[
380
- output["group"]
381
- for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]
382
- for output in outputs
383
- ]
 
 
 
384
  )
385
-
386
- self.ui.invoke_button.click(
387
- fn=self.handle_invoke,
388
- inputs=[self.ui.input_text, self.ui.filter_checkboxes],
389
- outputs=[
390
- output["output"]
391
- for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]
392
- for output in outputs
393
- ]
394
  )
395
 
396
- # 起動時にモデルロード処理を実行
397
- demo.load(
398
- fn=self.load_models_and_update_ui,
399
- inputs=None,
400
- outputs=[loading_group, main_ui_group]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
 
403
- return demo
404
-
405
  def main():
406
  logger.info("Starting Toxic Eye application")
407
- app = ToxicityApp()
408
- demo = app.create_ui()
409
  demo.launch()
410
 
411
  if __name__ == "__main__":
 
2
  from huggingface_hub import AsyncInferenceClient
3
  from typing import List, Dict, Optional, Union
4
  import logging
 
5
  from enum import Enum, auto
6
  import torch
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, pipeline
 
14
  )
15
  logger = logging.getLogger(__name__)
16
 
17
+ # モデルタイプの定義
18
+ LOCAL = "local"
19
+ INFERENCE_API = "inference_api"
 
20
 
21
+ # モデル定義
 
 
 
 
 
 
 
 
22
  TEXT_GENERATION_MODELS = [
23
+ {
24
+ "name": "Zephyr-7B",
25
+ "description": "Specialized in understanding context and nuance",
26
+ "type": INFERENCE_API,
27
+ "model_id": "HuggingFaceH4/zephyr-7b-beta"
28
+ },
29
+ {
30
+ "name": "Llama-2",
31
+ "description": "Known for its robust performance in content analysis",
32
+ "type": LOCAL,
33
+ "model_path": "meta-llama/Llama-2-7b-hf"
34
+ },
35
+ {
36
+ "name": "Mistral-7B",
37
+ "description": "Offers precise and detailed text evaluation",
38
+ "type": LOCAL,
39
+ "model_path": "mistralai/Mistral-7B-v0.1"
40
+ }
41
  ]
42
 
43
  CLASSIFICATION_MODELS = [
44
+ {
45
+ "name": "Toxic-BERT",
46
+ "description": "Fine-tuned for toxic content detection",
47
+ "type": LOCAL,
48
+ "model_path": "unitary/toxic-bert"
49
+ }
50
  ]
51
 
52
+ # グローバル変数でモデルやトークナイザーを管理
53
+ models = {}
54
+ tokenizers = {}
55
+ pipelines = {}
56
+ api_clients = {}
57
+
58
+ # インファレンスAPIクライアントの初期化
59
+ def initialize_api_clients():
60
+ """Inference APIクライアントの初期化"""
61
+ for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
62
+ if model["type"] == INFERENCE_API and "model_id" in model:
63
+ api_clients[model["model_id"]] = AsyncInferenceClient(
64
+ model["model_id"],
65
+ token=True # これによりHFトークンを使用
66
+ )
67
+ logger.info("API clients initialized")
68
+
69
+ # モデルのロード関数
70
+ def load_model(model_path, task="text-generation"):
71
+ """モデルの同期ロード"""
72
+ if model_path not in models:
73
+ logger.info(f"Loading model: {model_path}")
74
+ try:
75
+ tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
76
+
77
+ if task == "text-generation":
78
+ model = AutoModelForCausalLM.from_pretrained(
79
+ model_path,
80
+ torch_dtype=torch.float16,
81
+ load_in_8bit=True, # メモリ使用量削減のため8bit量子化を使用
82
+ device_map="auto"
83
+ )
84
+ pipelines[model_path] = pipeline(
85
+ "text-generation",
86
+ model=model,
87
+ tokenizer=tokenizers[model_path]
88
+ )
89
+ else: # classification
90
+ model = AutoModelForSequenceClassification.from_pretrained(
91
+ model_path,
92
+ device_map="auto"
93
+ )
94
+ pipelines[model_path] = pipeline(
95
+ "text-classification",
96
+ model=model,
97
+ tokenizer=tokenizers[model_path]
98
+ )
99
+
100
+ models[model_path] = model
101
+ logger.info(f"Model loaded successfully: {model_path}")
102
+ except Exception as e:
103
+ logger.error(f"Error loading model {model_path}: {str(e)}")
104
+ raise
105
 
106
+ # すべてのモデルを事前にロード
107
+ def preload_models():
108
+ """起動時にすべてのローカルモデルを事前にロード"""
109
+ logger.info("Preloading all local models...")
110
+ for model in TEXT_GENERATION_MODELS:
111
+ if model["type"] == LOCAL and "model_path" in model:
112
  try:
113
+ load_model(model["model_path"], "text-generation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  except Exception as e:
115
+ logger.error(f"Failed to preload {model['name']}: {e}")
116
+
117
+ for model in CLASSIFICATION_MODELS:
118
+ if model["type"] == LOCAL and "model_path" in model:
119
+ try:
120
+ load_model(model["model_path"], "text-classification")
121
+ except Exception as e:
122
+ logger.error(f"Failed to preload {model['name']}: {e}")
123
+
124
+ logger.info("Model preloading complete")
125
+
126
+ # テキスト生成の実行関数
127
+ @spaces.GPU()
128
+ def generate_text(model_path, text):
129
+ """テキスト生成の実行"""
130
+ if model_path not in models:
131
+ load_model(model_path, "text-generation")
132
+
133
+ try:
134
+ outputs = pipelines[model_path](
135
  text,
136
+ max_new_tokens=50, # トークン数を減らしてGPUメモリ使用量を削減
137
  do_sample=True,
138
  temperature=0.7,
139
  top_p=0.9,
140
  num_return_sequences=1
141
  )
142
  return outputs[0]["generated_text"]
143
+ except Exception as e:
144
+ logger.error(f"Error in text generation with {model_path}: {str(e)}")
145
+ raise
146
+
147
+ # テキスト分類の実行関数
148
+ @spaces.GPU()
149
+ def classify_text(model_path, text):
150
+ """テキスト分類の実行"""
151
+ if model_path not in models:
152
+ load_model(model_path, "text-classification")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
+ try:
155
+ result = pipelines[model_path](text)
156
+ return str(result)
157
+ except Exception as e:
158
+ logger.error(f"Error in classification with {model_path}: {str(e)}")
159
+ raise
160
+
161
+ # 複数のモデルでテキスト生成実行
162
+ async def run_text_generation(text, selected_types):
163
+ """テキスト生成モデルの実行"""
164
+ results = []
165
+ for model in TEXT_GENERATION_MODELS:
166
+ if model["type"] in selected_types:
167
+ try:
168
+ if model["type"] == INFERENCE_API:
169
+ logger.info(f"Running API text generation: {model['name']}")
170
+ response = await api_clients[model["model_id"]].text_generation(
171
+ text, max_new_tokens=50, temperature=0.7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  )
173
+ results.append(f"{model['name']}: {response}")
174
+ else:
175
+ logger.info(f"Running local text generation: {model['name']}")
176
+ response = generate_text(model["model_path"], text)
177
+ results.append(f"{model['name']}: {response}")
178
+ except Exception as e:
179
+ logger.error(f"Error in {model['name']}: {str(e)}")
180
+ results.append(f"{model['name']}: Error - {str(e)}")
181
+ return results
182
+
183
+ # 複数のモデルでテキスト分類を実行
184
+ async def run_classification(text, selected_types):
185
+ """分類モデルの実行"""
186
+ results = []
187
+ for model in CLASSIFICATION_MODELS:
188
+ if model["type"] in selected_types:
189
+ try:
190
+ if model["type"] == INFERENCE_API:
191
+ logger.info(f"Running API classification: {model['name']}")
192
+ response = await api_clients[model["model_id"]].text_classification(text)
193
+ results.append(f"{model['name']}: {response}")
194
+ else:
195
+ logger.info(f"Running local classification: {model['name']}")
196
+ response = classify_text(model["model_path"], text)
197
+ results.append(f"{model['name']}: {response}")
198
+ except Exception as e:
199
+ logger.error(f"Error in {model['name']}: {str(e)}")
200
+ results.append(f"{model['name']}: Error - {str(e)}")
201
+ return results
202
+
203
+ # Invokeボタンのハンドラ
204
+ async def handle_invoke(text, selected_types):
205
+ """Invokeボタンのハンドラ"""
206
+ gen_results = await run_text_generation(text, selected_types)
207
+ class_results = await run_classification(text, selected_types)
208
 
209
+ # 結果リストの長さを調整
210
+ gen_results.extend([""] * (len(TEXT_GENERATION_MODELS) - len(gen_results)))
211
+ class_results.extend([""] * (len(CLASSIFICATION_MODELS) - len(class_results)))
212
+
213
+ return gen_results + class_results
 
 
 
 
 
 
 
214
 
215
+ # モデルの表示状態を更新
216
+ def update_model_visibility(selected_types):
217
+ """モデルの表示状態を更新"""
218
+ logger.info(f"Updating visibility for types: {selected_types}")
219
+
220
+ updates = []
221
+ for model_outputs in [gen_model_outputs, class_model_outputs]:
222
+ for output in model_outputs:
223
+ visible = output["type"] in selected_types
224
+ logger.info(f"Model {output['name']} (type: {output['type']}): visible = {visible}")
225
+ updates.append(gr.update(visible=visible))
226
+ return updates
227
+
228
+ # モデルをロードしUIを更新する
229
+ def load_models_and_update_ui():
230
+ """モデルをロードしUIを更新する"""
231
+ try:
232
+ # APIクライアント初期化
233
+ initialize_api_clients()
234
+ # モデルのロード
235
+ preload_models()
236
+ logger.info("Models loaded successfully")
237
+ # ロード完了メッセージを返して、UIのロード中表示を非表示にする
238
+ return gr.update(visible=False), gr.update(visible=True)
239
+ except Exception as e:
240
+ logger.error(f"Error loading models: {e}")
241
+ return gr.update(value=f"Error loading models: {e}"), gr.update(visible=False)
242
+
243
+ # モデルグリッドの作成
244
+ def create_model_grid(models):
245
+ """モデルグリッドの作成"""
246
+ outputs = []
247
+ with gr.Column() as container:
248
+ for i in range(0, len(models), 2):
249
+ with gr.Row() as row:
250
+ for j in range(min(2, len(models) - i)):
251
+ model = models[i + j]
252
+ with gr.Column():
253
+ with gr.Group() as group:
254
+ gr.Markdown(f"### {model['name']}")
255
+ gr.Markdown(f"Type: {model['type']}")
256
+ output = gr.Textbox(
257
+ label="Model Output",
258
+ lines=5,
259
+ interactive=False,
260
+ info=model['description']
261
+ )
262
+ outputs.append({
263
+ "type": model["type"],
264
+ "name": model["name"],
265
+ "output": output,
266
+ "group": group
267
+ })
268
+ return outputs
269
+
270
+ # グローバル変数としてUI部品を保持
271
+ input_text = None
272
+ filter_checkboxes = None
273
+ invoke_button = None
274
+ gen_model_outputs = []
275
+ class_model_outputs = []
276
+ community_output = None
277
+
278
+ # UIの作成
279
+ def create_ui():
280
+ """UIの作成"""
281
+ global input_text, filter_checkboxes, invoke_button, gen_model_outputs, class_model_outputs, community_output
282
+
283
+ with gr.Blocks() as demo:
284
+ # ロード中コンポーネント
285
+ with gr.Group(visible=True) as loading_group:
286
+ gr.Markdown("""
287
+ # Toxic Eye
288
 
289
+ ### Loading models... This may take a few minutes.
290
+
291
+ The application is initializing and preloading all models.
292
+ Please wait while the models are being loaded...
293
+ """)
294
+
295
+ # メインUIコンポーネント(初期状態では非表示)
296
+ with gr.Group(visible=False) as main_ui_group:
297
+ # ヘッダー
298
+ gr.Markdown("""
299
+ # Toxic Eye
300
+ This system evaluates the toxicity level of input text using multiple approaches.
301
+ """)
302
+
303
+ # 入力セクション
304
+ with gr.Row():
305
+ input_text = gr.Textbox(
306
+ label="Input Text",
307
+ placeholder="Enter text to analyze...",
308
+ lines=3
309
  )
310
+
311
+ # フィルターセクション
312
+ with gr.Row():
313
+ filter_checkboxes = gr.CheckboxGroup(
314
+ choices=[LOCAL, INFERENCE_API],
315
+ value=[LOCAL, INFERENCE_API],
316
+ label="Filter Models",
317
+ info="Choose which types of models to display",
318
+ interactive=True
319
  )
320
 
321
+ # Invokeボタン
322
+ with gr.Row():
323
+ invoke_button = gr.Button(
324
+ "Invoke Selected Models",
325
+ variant="primary",
326
+ size="lg"
327
+ )
328
+
329
+ # モデルタブ
330
+ with gr.Tabs():
331
+ with gr.Tab("Text Generation LLM"):
332
+ gen_model_outputs = create_model_grid(TEXT_GENERATION_MODELS)
333
+ with gr.Tab("Classification LLM"):
334
+ class_model_outputs = create_model_grid(CLASSIFICATION_MODELS)
335
+ with gr.Tab("Community (Not implemented)"):
336
+ with gr.Column():
337
+ community_output = gr.Textbox(
338
+ label="Related Community Topics",
339
+ lines=5,
340
+ interactive=False
341
+ )
342
+
343
+ # イベントハンドラの設定
344
+ filter_checkboxes.change(
345
+ fn=update_model_visibility,
346
+ inputs=[filter_checkboxes],
347
+ outputs=[
348
+ output["group"]
349
+ for outputs in [gen_model_outputs, class_model_outputs]
350
+ for output in outputs
351
+ ]
352
  )
353
+
354
+ invoke_button.click(
355
+ fn=handle_invoke,
356
+ inputs=[input_text, filter_checkboxes],
357
+ outputs=[
358
+ output["output"]
359
+ for outputs in [gen_model_outputs, class_model_outputs]
360
+ for output in outputs
361
+ ]
362
+ )
363
+
364
+ # 起動時にモデルロード処理を実行
365
+ demo.load(
366
+ fn=load_models_and_update_ui,
367
+ inputs=None,
368
+ outputs=[loading_group, main_ui_group]
369
+ )
370
+
371
+ return demo
372
 
373
+ # メイン関数
 
374
  def main():
375
  logger.info("Starting Toxic Eye application")
376
+ demo = create_ui()
 
377
  demo.launch()
378
 
379
  if __name__ == "__main__":