smartTranscend committed
Commit b9be0aa · verified · 1 Parent(s): c18db4c

Update app.py

Files changed (1)
  1. app.py +130 -37

app.py CHANGED
@@ -8,8 +8,15 @@ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, con
 from torch import nn
 import os
 from datetime import datetime
+import gc
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
+
+# Use smaller defaults to save memory
+torch.backends.cudnn.benchmark = False
+if torch.cuda.is_available():
+    torch.cuda.empty_cache()
 
 # Global variables
 trained_models = {}
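A note on the new allocator setting: `PYTORCH_CUDA_ALLOC_CONF` is read when the CUDA caching allocator first initializes, so setting it at import time as above should take effect. A quick way to watch whether the split-size cap is actually easing fragmentation (a sketch; the printed numbers are illustrative):

```python
# Sketch: inspect the caching allocator before/after a training run.
# memory_allocated() counts live tensors; memory_reserved() counts the
# blocks the allocator is holding. A large gap between the two
# suggests fragmentation.
import torch

if torch.cuda.is_available():
    print(f"allocated: {torch.cuda.memory_allocated() / 2**20:.0f} MiB")
    print(f"reserved:  {torch.cuda.memory_reserved() / 2**20:.0f} MiB")
```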
@@ -57,16 +64,71 @@ def compute_metrics(pred):
     'sensitivity': 0, 'specificity': 0, 'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0
   }
 
+def evaluate_baseline(model, tokenizer, test_dataset, device):
+    """Evaluate the baseline (non-fine-tuned) model"""
+    model.eval()
+    all_preds = []
+    all_labels = []
+
+    from torch.utils.data import DataLoader
+
+    def collate_fn(batch):
+        return {
+            'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
+            'attention_mask': torch.stack([torch.tensor(item['attention_mask']) for item in batch]),
+            'labels': torch.tensor([item['label'] for item in batch])
+        }
+
+    dataloader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)
+
+    with torch.no_grad():
+        for batch in dataloader:
+            labels = batch.pop('labels')
+            inputs = {k: v.to(device) for k, v in batch.items()}
+            outputs = model(**inputs)
+            preds = torch.argmax(outputs.logits, dim=-1)
+            all_preds.extend(preds.cpu().numpy())
+            all_labels.extend(labels.numpy())
+
+    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary', pos_label=1, zero_division=0)
+    acc = accuracy_score(all_labels, all_preds)
+    cm = confusion_matrix(all_labels, all_preds)
+    if cm.shape == (2, 2):
+        tn, fp, fn, tp = cm.ravel()
+    else:
+        tn = fp = fn = tp = 0
+    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
+    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
+
+    return {
+        'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
+        'sensitivity': sensitivity, 'specificity': specificity,
+        'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
+    }
+
 class WeightedTrainer(Trainer):
-    def __init__(self, *args, class_weights=None, **kwargs):
+    def __init__(self, *args, class_weights=None, use_focal_loss=False, **kwargs):
         super().__init__(*args, **kwargs)
         self.class_weights = class_weights
+        self.use_focal_loss = use_focal_loss
 
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         labels = inputs.pop("labels")
         outputs = model(**inputs)
-        loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
-        loss = loss_fct(outputs.logits.view(-1, 2), labels.view(-1))
+        logits = outputs.logits
+
+        if self.use_focal_loss:
+            # Focal Loss: focus more on hard-to-classify samples
+            ce_loss = nn.CrossEntropyLoss(weight=self.class_weights, reduction='none')(
+                logits.view(-1, 2), labels.view(-1)
+            )
+            pt = torch.exp(-ce_loss)
+            focal_loss = ((1 - pt) ** 2 * ce_loss).mean()
+            loss = focal_loss
+        else:
+            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
+            loss = loss_fct(logits.view(-1, 2), labels.view(-1))
+
         return (loss, outputs) if return_outputs else loss
 
 def evaluate_baseline(model, tokenizer, test_dataset, device):
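The focal-loss branch added above is the standard (1 − pt)^γ re-weighting with γ fixed at 2. A toy-data sketch of the same computation (values are illustrative; note that with class weights, `ce` is the weighted cross-entropy, so `pt = torch.exp(-ce)` only equals the true-class probability exactly when the weight is 1):

```python
import torch
from torch import nn

logits = torch.tensor([[3.0, -2.0],   # confident & correct -> focal term near 0
                       [0.3, 0.1]])   # uncertain -> keeps most of its CE loss
labels = torch.tensor([0, 1])
weights = torch.tensor([1.0, 4.0])    # minority class up-weighted

ce = nn.CrossEntropyLoss(weight=weights, reduction='none')(logits, labels)
pt = torch.exp(-ce)                   # ~ probability of the true class
loss = ((1 - pt) ** 2 * ce).mean()    # (1 - pt)^2 down-weights easy examples
```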
@@ -141,9 +203,14 @@ def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learn
         return "❌ No mortality samples", "", "", ""
 
     ratio = n0 / n1
-    w0, w1 = 1.0, ratio * weight_mult
+    # Dynamically adjust the weight calculation
+    if ratio > 10:  # extremely imbalanced
+        w0, w1 = 1.0, min(ratio * weight_mult, ratio * 0.7)  # cap the maximum weight
+    else:
+        w0, w1 = 1.0, ratio * weight_mult
 
-    info = f"📊 Data: {len(df_clean)} rows\nSurvived: {n0} | Died: {n1}\nRatio: {ratio:.2f}:1\nWeights: {w0:.2f} / {w1:.2f}\nModel: {base_model}\nMethod: {method.upper()}"
+    info = f"📊 Data: {len(df_clean)} rows\nSurvived: {n0} | Died: {n1}\nRatio: {ratio:.2f}:1\n"
+    info += f"⚖️ Weights: {w0:.2f} / {w1:.2f}\nModel: {base_model}\nMethod: {method.upper()}"
 
     tokenizer = BertTokenizer.from_pretrained(model_name)
     dataset = Dataset.from_pandas(df_clean[['text', 'label']])
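The cap in this hunk is `min(ratio * weight_mult, ratio * 0.7)`, which equals `ratio * min(weight_mult, 0.7)`, so it only bites when `weight_mult > 0.7`. A worked example with hypothetical class counts:

```python
n0, n1, weight_mult = 950, 50, 1.0    # hypothetical survived/died counts
ratio = n0 / n1                       # 19.0 -> the "extremely imbalanced" branch
w1 = min(ratio * weight_mult,         # uncapped: 19.0
         ratio * 0.7)                 # capped to 13.3
w0 = 1.0
```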
@@ -157,8 +224,7 @@ def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learn
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     info += f"\nDevice: {'GPU ✅' if torch.cuda.is_available() else 'CPU ⚠️'}"
 
-    # Evaluate the baseline model (not fine-tuned)
-    info += "\n\n🔍 Evaluating the baseline model (not fine-tuned)..."
+    # 🔇 Silently evaluate the baseline model (not shown in the data info)
     baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
     baseline_model = baseline_model.to(device)
 
@@ -166,12 +232,11 @@ def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learn
     baseline_key = f"{base_model}_baseline"
     baseline_results[baseline_key] = baseline_perf
 
-    info += f"\nBaseline F1: {baseline_perf['f1']:.4f}"
-    info += f"\nBaseline Accuracy: {baseline_perf['accuracy']:.4f}"
-
     # Clean up the baseline model to free memory
     del baseline_model
-    torch.cuda.empty_cache() if torch.cuda.is_available() else None
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
 
     # Start fine-tuning
    info += f"\n\n🔧 Applying {method.upper()} fine-tuning..."
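One detail worth knowing about the cleanup added here: `torch.cuda.empty_cache()` can only release blocks whose tensors Python has already freed, so the more common ordering is to collect first and empty the cache last (a sketch; both orderings free the memory eventually):

```python
import gc
import torch

del baseline_model            # drop the last reference to the model
gc.collect()                  # let Python actually free its CUDA tensors
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # now return the cached blocks to the driver
```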
@@ -226,12 +291,12 @@ def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learn
         learning_rate=float(learning_rate),
         weight_decay=float(weight_decay),
         evaluation_strategy="epoch",
-        save_strategy="epoch",
-        load_best_model_at_end=True,
-        metric_for_best_model=best_metric,
+        save_strategy="no",  # 🔧 Don't save checkpoints; avoids PEFT loading issues
+        load_best_model_at_end=False,  # 🔧 Disabled; just use the final epoch
         report_to="none",
-        logging_steps=50,
-        save_total_limit=2
+        logging_steps=10,
+        warmup_steps=50,
+        logging_first_step=True
     )
 
     trainer = WeightedTrainer(
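With `save_strategy="no"`, nothing is checkpointed during training, so the fine-tuned adapter lives only in memory. If persistence were wanted, a minimal post-training save might look like this (a sketch, not part of this commit; `save_dir` is a hypothetical path, and `model` is assumed to be the PEFT-wrapped model):

```python
save_dir = "./lora_adapter"          # hypothetical output path
model.save_pretrained(save_dir)      # writes adapter_config.json + adapter weights
tokenizer.save_pretrained(save_dir)  # keep the tokenizer alongside it
```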
@@ -240,11 +305,27 @@ def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learn
         train_dataset=split['train'],
         eval_dataset=split['test'],
         compute_metrics=compute_metrics,
-        class_weights=weights
+        class_weights=weights,
+        use_focal_loss=(ratio > 10)  # use Focal Loss when extremely imbalanced
     )
 
+    if ratio > 10:
+        info += "\n\n⚡ Using Focal Loss for extremely imbalanced data"
+
     info += "\n\n⏳ Starting training..."
-    trainer.train()
+
+    # Pre-training checks
+    info += f"\n📊 Pre-training checks:"
+    info += f"\n - Training samples: {len(split['train'])}"
+    info += f"\n - Test samples: {len(split['test'])}"
+    info += f"\n - Batches/epoch: {len(split['train']) // int(batch_size)}"
+
+    train_result = trainer.train()
+
+    # Post-training info
+    info += f"\n\n✅ Training complete!"
+    info += f"\n📉 Final training loss: {train_result.training_loss:.4f}"
+
     results = trainer.evaluate()
 
     # Generate a timestamped model ID
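A small caveat on the pre-training check added above: `len(split['train']) // int(batch_size)` floors, ignoring a trailing partial batch, while the HF `Trainer` keeps that batch by default (`dataloader_drop_last=False`). The exact per-epoch step count is the ceiling:

```python
import math

# Matches what the Trainer actually runs when dataloader_drop_last=False.
steps_per_epoch = math.ceil(len(split['train']) / int(batch_size))
```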
@@ -255,7 +336,7 @@ def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learn
         'model': model,
         'tokenizer': tokenizer,
         'results': results,
-        'baseline': baseline_perf,
+        'baseline': baseline_perf,  # keep the baseline results for later use
         'config': {
             'type': base_model,
             'model_name': model_name,
@@ -275,7 +356,7 @@ def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learn
 
     # Vanilla BERT output
     baseline_output = f"🔬 Vanilla BERT (not fine-tuned)\n\n"
-    baseline_output += f"📈 Performance\n"
+    baseline_output += f"📊 Performance\n"
     baseline_output += f"F1: {baseline_perf['f1']:.4f}\n"
     baseline_output += f"Accuracy: {baseline_perf['accuracy']:.4f}\n"
     baseline_output += f"Precision: {baseline_perf['precision']:.4f}\n"
@@ -287,8 +368,9 @@ def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learn
     baseline_output += f"FP: {baseline_perf['fp']} | FN: {baseline_perf['fn']}"
 
     # Fine-tuned BERT output
-    finetuned_output = f"✅ Fine-tuned BERT\nModel: {model_id}\n\n"
-    finetuned_output += f"📈 Performance\n"
+    finetuned_output = f"✅ Fine-tuned BERT\n"
+    finetuned_output += f"Model: {model_id}\n\n"
+    finetuned_output += f"📊 Performance\n"
     finetuned_output += f"F1: {results['eval_f1']:.4f}\n"
     finetuned_output += f"Accuracy: {results['eval_accuracy']:.4f}\n"
     finetuned_output += f"Precision: {results['eval_precision']:.4f}\n"
@@ -489,9 +571,12 @@ with gr.Blocks(title="BERT Fine-tuning Teaching Platform", theme=gr.themes.Soft()) as
 
     gr.Markdown("### 🎯 Basic training parameters")
     with gr.Row():
-        num_epochs = gr.Number(value=3, label="Training epochs (epochs)", minimum=1, maximum=100, precision=0)
-        batch_size = gr.Number(value=8, label="Batch size (batch_size)", minimum=1, maximum=128, precision=0)
-        learning_rate = gr.Number(value=2e-5, label="Learning rate (learning_rate)", minimum=0, maximum=1)
+        num_epochs = gr.Number(value=5, label="Training epochs (epochs)", minimum=1, maximum=100, precision=0,
+                               info="5-8 recommended")
+        batch_size = gr.Number(value=4, label="Batch size (batch_size)", minimum=1, maximum=128, precision=0,
+                               info="drop to 4 if memory is tight")
+        learning_rate = gr.Number(value=5e-5, label="Learning rate (learning_rate)", minimum=0, maximum=1, format=".0e",
+                                  info="5e-5 is a balanced choice")
 
     gr.Markdown("### ⚙️ Advanced parameters")
     with gr.Row():
@@ -500,17 +585,17 @@ with gr.Blocks(title="BERT Fine-tuning Teaching Platform", theme=gr.themes.Soft()) as
 
     gr.Markdown("### 🔧 LoRA parameters")
     with gr.Row():
-        lora_r = gr.Number(value=16, label="LoRA Rank (r)", minimum=1, maximum=256, precision=0,
-                           info="8-16 recommended; larger works better but is slower")
-        lora_alpha = gr.Number(value=32, label="LoRA Alpha", minimum=1, maximum=512, precision=0,
-                               info="usually set to Rank × 2")
-        lora_dropout = gr.Number(value=0.1, label="LoRA Dropout", minimum=0, maximum=1,
-                                 info="prevents overfitting")
+        lora_r = gr.Number(value=32, label="LoRA Rank (r)", minimum=1, maximum=256, precision=0,
+                           info="raised to 32 for more capacity")
+        lora_alpha = gr.Number(value=64, label="LoRA Alpha", minimum=1, maximum=512, precision=0,
+                               info="Alpha = Rank × 2")
+        lora_dropout = gr.Number(value=0.05, label="LoRA Dropout", minimum=0, maximum=1,
+                                 info="lower dropout to avoid underfitting")
 
     gr.Markdown("### ⚖️ Evaluation settings")
     with gr.Row():
-        weight_mult = gr.Number(value=2.0, label="Class weight multiplier", minimum=0, maximum=10,
-                                info="1.5-2.5 recommended; too low ignores the minority class")
+        weight_mult = gr.Number(value=1.0, label="Class weight multiplier", minimum=0, maximum=5,
+                                info="⚠️ With extreme imbalance use 0.5-1.5; don't exceed 2.0")
         best_metric = gr.Dropdown(
             choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
             value="f1",
@@ -608,9 +693,12 @@ with gr.Blocks(title="BERT Fine-tuning Teaching Platform", theme=gr.themes.Soft()) as
 For imbalanced data (such as medical data):
 - **Fine-tuning method**: LoRA (fast and effective) or AdaLoRA (for maximum performance)
 - **LoRA Rank**: 8-16 (balances quality and speed)
-- **Class weight multiplier**: 1.5-2.5 (for imbalanced data)
-- **Learning rate**: 2e-5 to 5e-5
-- **Epochs**: 3-8 (to avoid overfitting)
+- **Class weight multiplier**:
+  - ⚠️ **Extreme imbalance (>10:1)**: 0.5-1.0 (your case!)
+  - Moderate imbalance (3-10:1): 1.0-1.5
+  - Mild imbalance (<3:1): 1.5-2.5
+- **Learning rate**: 3e-5 to 5e-5 (a higher learning rate pairs well with LoRA)
+- **Epochs**: 5-10 (extreme imbalance needs more epochs)
 - **Batch size**: 8-16 (adjust to GPU memory)
 
 ### Data format
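The multiplier ranges in the list above, expressed as a small helper (hypothetical; not part of app.py):

```python
def suggested_weight_mult(ratio: float) -> tuple:
    """Map the majority:minority ratio to a (low, high) multiplier range."""
    if ratio > 10:        # extreme imbalance
        return (0.5, 1.0)
    if ratio >= 3:        # moderate imbalance
        return (1.0, 1.5)
    return (1.5, 2.5)     # mild imbalance
```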
@@ -629,4 +717,9 @@ with gr.Blocks(title="BERT Fine-tuning Teaching Platform", theme=gr.themes.Soft()) as
     """)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+        max_threads=4  # limit the number of threads
+    )