smartTranscend commited on
Commit
6a26e21
·
verified ·
1 Parent(s): f7eb620

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -40
app.py CHANGED
@@ -33,7 +33,7 @@ class WeightedTrainer(Trainer):
33
  super().__init__(*args, **kwargs)
34
  self.class_weights = class_weights
35
 
36
- def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
37
  labels = inputs.pop("labels")
38
  outputs = model(**inputs)
39
  if self.class_weights is not None:
@@ -158,22 +158,24 @@ def train_model(df_clean, weight_multiplier, epochs, batch_size, learning_rate,
158
  problem_type="single_label_classification"
159
  ).to(device)
160
 
161
- # 訓練設定
162
  training_args = TrainingArguments(
163
  output_dir='./results',
164
  num_train_epochs=epochs,
165
  per_device_train_batch_size=batch_size,
166
- per_device_eval_batch_size=batch_size*2,
167
  warmup_steps=200,
168
  weight_decay=0.01,
169
  learning_rate=learning_rate,
170
  logging_steps=50,
171
- eval_strategy="epoch",
172
  save_strategy="epoch",
173
  load_best_model_at_end=True,
174
  metric_for_best_model="f1",
175
  report_to="none",
176
- fp16=torch.cuda.is_available() # 使用混合精度加速
 
 
177
  )
178
 
179
  # 建立 Trainer
@@ -189,22 +191,11 @@ def train_model(df_clean, weight_multiplier, epochs, batch_size, learning_rate,
189
  # 訓練模型
190
  progress(0.3, desc="開始訓練...")
191
 
192
- # 訓練進度更新
193
- class ProgressCallback:
194
- def __init__(self, progress_bar, total_steps):
195
- self.progress_bar = progress_bar
196
- self.total_steps = total_steps
197
- self.current_step = 0
198
 
199
- def on_log(self, args, state, control, **kwargs):
200
- self.current_step = state.global_step
201
- progress_val = 0.3 + (0.6 * self.current_step / self.total_steps)
202
- self.progress_bar(progress_val, desc=f"訓練中... (Step {self.current_step}/{self.total_steps})")
203
-
204
- total_steps = len(train_dataset) // batch_size * epochs
205
- progress_callback = ProgressCallback(progress, total_steps)
206
- trainer.add_callback(progress_callback)
207
-
208
  trainer.train()
209
 
210
  progress(0.9, desc="評估模型...")
@@ -220,20 +211,20 @@ def train_model(df_clean, weight_multiplier, epochs, batch_size, learning_rate,
220
 
221
  | 指標 | 數值 |
222
  |------|------|
223
- | **F1 Score** | {results['eval_f1']:.4f} |
224
- | **Accuracy** | {results['eval_accuracy']:.4f} |
225
- | **Precision** | {results['eval_precision']:.4f} |
226
- | **Recall** | {results['eval_recall']:.4f} |
227
- | **Sensitivity** | {results['eval_sensitivity']:.4f} |
228
- | **Specificity** | {results['eval_specificity']:.4f} |
229
- | **AUC** | {results['eval_auc']:.4f} |
230
 
231
  ## 📈 混淆矩陣
232
 
233
  | | 預測:存活 | 預測:死亡 |
234
  |---|-----------|-----------|
235
- | **實際:存活** | TN={results['eval_tn']} | FP={results['eval_fp']} |
236
- | **實際:死亡** | FN={results['eval_fn']} | TP={results['eval_tp']} |
237
 
238
  ## ⚖️ 訓練設定
239
 
@@ -249,9 +240,9 @@ def train_model(df_clean, weight_multiplier, epochs, batch_size, learning_rate,
249
 
250
  ## 💡 模型解讀
251
 
252
- - **Precision** ({results['eval_precision']:.2f}): 預測為死亡的案例中,有 {results['eval_precision']*100:.1f}% 實死亡
253
- - **Recall** ({results['eval_recall']:.2f}): 實際死亡案例中,有 {results['eval_recall']*100:.1f}% 被正確識別
254
- - **F1 Score** ({results['eval_f1']:.2f}): 整體平衡表現 {'優秀' if results['eval_f1'] > 0.8 else '良好' if results['eval_f1'] > 0.6 else '尚可'}
255
 
256
  ---
257
  *訓練完成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*
@@ -271,7 +262,9 @@ def train_model(df_clean, weight_multiplier, epochs, batch_size, learning_rate,
271
  return report, results_json, gr.update(visible=True)
272
 
273
  except Exception as e:
274
- return f"❌ 訓練錯誤:{str(e)}", None, gr.update(visible=False)
 
 
275
 
276
  def predict_sample(text_input):
277
  """預測單一樣本"""
@@ -324,6 +317,8 @@ with gr.Blocks(title="BERT 乳癌存活預測訓練平台", theme=gr.themes.Soft
324
  2. 調整訓練參數
325
  3. 開始訓練
326
  4. 查看結果並測試預測
 
 
327
  """)
328
 
329
  # 狀態變數
@@ -355,9 +350,9 @@ with gr.Blocks(title="BERT 乳癌存活預測訓練平台", theme=gr.themes.Soft
355
  info="建議 3-5 輪"
356
  )
357
  batch_size_slider = gr.Slider(
358
- minimum=4, maximum=32, value=16, step=4,
359
  label="批次大小 (Batch Size)",
360
- info="較小的批次大小需要更多記憶體"
361
  )
362
  lr_slider = gr.Number(
363
  value=2e-5,
@@ -411,15 +406,20 @@ with gr.Blocks(title="BERT 乳癌存活預測訓練平台", theme=gr.themes.Soft
411
  - `label`: 0(存活)或 1(死亡)
412
 
413
  ### 參數說明
414
- - **權重倍數**: 調整對少數類別的重視程度
415
- - **訓練輪數**: 模型看過所有資料的次數
416
- - **批次大小**: 同時處理的樣本數
417
- - **學習率**: 模型更新的步幅
418
 
419
  ### 注意事項
420
  - 訓練時間依資料量和參數而定(通常 5-15 分鐘)
421
  - 建議至少有 100 筆以上的訓練資料
422
- - GPU 會顯著加速訓練
 
 
 
 
 
423
  """)
424
 
425
  # 事件處理
 
33
  super().__init__(*args, **kwargs)
34
  self.class_weights = class_weights
35
 
36
+ def compute_loss(self, model, inputs, return_outputs=False):
37
  labels = inputs.pop("labels")
38
  outputs = model(**inputs)
39
  if self.class_weights is not None:
 
158
  problem_type="single_label_classification"
159
  ).to(device)
160
 
161
+ # 訓練設定 - 使用正確的參數名稱
162
  training_args = TrainingArguments(
163
  output_dir='./results',
164
  num_train_epochs=epochs,
165
  per_device_train_batch_size=batch_size,
166
+ per_device_eval_batch_size=min(batch_size*2, 32),
167
  warmup_steps=200,
168
  weight_decay=0.01,
169
  learning_rate=learning_rate,
170
  logging_steps=50,
171
+ evaluation_strategy="epoch", # 正確的參數名稱
172
  save_strategy="epoch",
173
  load_best_model_at_end=True,
174
  metric_for_best_model="f1",
175
  report_to="none",
176
+ fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7, # 只在支援的 GPU 上使用
177
+ push_to_hub=False, # 明確設定為 False
178
+ remove_unused_columns=False # 避免移除必要欄位
179
  )
180
 
181
  # 建立 Trainer
 
191
  # 訓練模型
192
  progress(0.3, desc="開始訓練...")
193
 
194
+ # 簡單的進度更新
195
+ for epoch in range(epochs):
196
+ progress(0.3 + (0.6 * (epoch + 1) / epochs),
197
+ desc=f"訓練中... Epoch {epoch + 1}/{epochs}")
 
 
198
 
 
 
 
 
 
 
 
 
 
199
  trainer.train()
200
 
201
  progress(0.9, desc="評估模型...")
 
211
 
212
  | 指標 | 數值 |
213
  |------|------|
214
+ | **F1 Score** | {results.get('eval_f1', 0):.4f} |
215
+ | **Accuracy** | {results.get('eval_accuracy', 0):.4f} |
216
+ | **Precision** | {results.get('eval_precision', 0):.4f} |
217
+ | **Recall** | {results.get('eval_recall', 0):.4f} |
218
+ | **Sensitivity** | {results.get('eval_sensitivity', 0):.4f} |
219
+ | **Specificity** | {results.get('eval_specificity', 0):.4f} |
220
+ | **AUC** | {results.get('eval_auc', 0):.4f} |
221
 
222
  ## 📈 混淆矩陣
223
 
224
  | | 預測:存活 | 預測:死亡 |
225
  |---|-----------|-----------|
226
+ | **實際:存活** | TN={results.get('eval_tn', 0)} | FP={results.get('eval_fp', 0)} |
227
+ | **實際:死亡** | FN={results.get('eval_fn', 0)} | TP={results.get('eval_tp', 0)} |
228
 
229
  ## ⚖️ 訓練設定
230
 
 
240
 
241
  ## 💡 模型解讀
242
 
243
+ - **Precision** ({results.get('eval_precision', 0):.2f}): 預測為死亡的案例中的準
244
+ - **Recall** ({results.get('eval_recall', 0):.2f}): 實際死亡案例識別
245
+ - **F1 Score** ({results.get('eval_f1', 0):.2f}): 整體平衡表現
246
 
247
  ---
248
  *訓練完成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*
 
262
  return report, results_json, gr.update(visible=True)
263
 
264
  except Exception as e:
265
+ import traceback
266
+ error_msg = f"❌ 訓練錯誤:{str(e)}\n\n詳細錯誤:\n{traceback.format_exc()}"
267
+ return error_msg, None, gr.update(visible=False)
268
 
269
  def predict_sample(text_input):
270
  """預測單一樣本"""
 
317
  2. 調整訓練參數
318
  3. 開始訓練
319
  4. 查看結果並測試預測
320
+
321
+ **裝置狀態**: {f"🚀 GPU ({torch.cuda.get_device_name(0)})" if torch.cuda.is_available() else "💻 CPU (訓練會較慢)"}
322
  """)
323
 
324
  # 狀態變數
 
350
  info="建議 3-5 輪"
351
  )
352
  batch_size_slider = gr.Slider(
353
+ minimum=4, maximum=32, value=8, step=4,
354
  label="批次大小 (Batch Size)",
355
+ info="GPU 記憶體有限時請使用較小值"
356
  )
357
  lr_slider = gr.Number(
358
  value=2e-5,
 
406
  - `label`: 0(存活)或 1(死亡)
407
 
408
  ### 參數說明
409
+ - **權重倍數**: 調整對少數類別的重視程度(0.1-2.0)
410
+ - **訓練輪數**: 模型看過所有資料的次數(1-10)
411
+ - **批次大小**: 同時處理的樣本數(4-32)
412
+ - **學習率**: 模型更新的步幅(建議 2e-5)
413
 
414
  ### 注意事項
415
  - 訓練時間依資料量和參數而定(通常 5-15 分鐘)
416
  - 建議至少有 100 筆以上的訓練資料
417
+ - GPU 會顯著加速訓練(約快 5-10 倍)
418
+
419
+ ### 常見問題
420
+ - **記憶體不足**: 降低批次大小
421
+ - **訓練太慢**: 減少訓練輪數或使用 GPU
422
+ - **效果不佳**: 增加訓練資料或調整權重倍數
423
  """)
424
 
425
  # 事件處理