smartTranscend commited on
Commit
c18db4c
·
verified ·
1 Parent(s): ec5fbdd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +632 -0
app.py ADDED
@@ -0,0 +1,632 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import torch
4
+ from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
5
+ from peft import LoraConfig, AdaLoraConfig, get_peft_model, TaskType
6
+ from datasets import Dataset
7
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
8
+ from torch import nn
9
+ import os
10
+ from datetime import datetime
11
+
12
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Global state shared across Gradio callbacks (single-process app).
trained_models = {}        # model_id -> {'model', 'tokenizer', 'results', 'baseline', 'config', 'timestamp'}
model_counter = 0          # incremented once per completed training run
baseline_results = {}      # "<base_model>_baseline" -> metric dict of the un-finetuned model
baseline_model_cache = {}  # HF model name -> loaded un-finetuned model, reused by predict()
19
+
20
def calculate_improvement(baseline_val, finetuned_val):
    """Return the relative improvement of ``finetuned_val`` over ``baseline_val`` in percent.

    A zero baseline cannot be divided by: the result is ``inf`` when the
    fine-tuned value shows any gain, and ``0.0`` when there is none.
    """
    if baseline_val == 0:
        return float('inf') if finetuned_val > 0 else 0.0
    delta = finetuned_val - baseline_val
    return delta / baseline_val * 100
28
+
29
def format_improve(val):
    """Render an improvement percentage; an infinite value gets a placeholder."""
    return "N/A (baseline=0)" if val == float('inf') else f"{val:+.1f}%"
34
+
35
def compute_metrics(pred):
    """Compute binary-classification metrics for the HF Trainer.

    Args:
        pred: EvalPrediction-like object with ``label_ids`` (gold labels)
            and ``predictions`` (logits; argmax over the last axis).

    Returns:
        dict with accuracy/f1/precision/recall (positive class = 1),
        sensitivity/specificity, and raw confusion-matrix counts.
        On any failure a dict of zeros is returned so evaluation never aborts.
    """
    try:
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average='binary', pos_label=1, zero_division=0)
        acc = accuracy_score(labels, preds)
        # FIX: pin the label order so the matrix is always 2x2 even when the
        # eval split (or the predictions) contains only one class. The old
        # shape check zeroed every count in that case, reporting 0 sensitivity
        # and specificity for a batch that actually had samples.
        cm = confusion_matrix(labels, preds, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        return {
            'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
            'sensitivity': sensitivity, 'specificity': specificity,
            'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
        }
    except Exception as e:
        print(f"Error in compute_metrics: {e}")
        return {
            'accuracy': 0, 'f1': 0, 'precision': 0, 'recall': 0,
            'sensitivity': 0, 'specificity': 0, 'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0
        }
59
+
60
class WeightedTrainer(Trainer):
    """Trainer that applies per-class weights to the cross-entropy loss,
    used to counter the survived/deceased class imbalance."""
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        # 1-D weight tensor [w_class0, w_class1]; the caller moves it to the
        # training device before constructing the trainer.
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # `num_items_in_batch` is accepted because newer transformers versions
        # pass it; it is unused here.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
        # NOTE(review): logits are reshaped for exactly 2 classes — matches the
        # binary task this app trains; would need changing for num_labels != 2.
        loss = loss_fct(outputs.logits.view(-1, 2), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
71
+
72
def evaluate_baseline(model, tokenizer, test_dataset, device):
    """Evaluate the un-finetuned baseline model on the tokenized test split.

    Args:
        model: BertForSequenceClassification already moved to ``device``.
        tokenizer: unused here; kept so the call signature stays stable.
        test_dataset: HF Dataset rows with 'input_ids', 'attention_mask', 'label'.
        device: torch device for inference.

    Returns:
        Metric dict with the same keys as compute_metrics().
    """
    model.eval()
    all_preds = []
    all_labels = []

    from torch.utils.data import DataLoader

    def collate_fn(batch):
        # datasets yields plain Python lists; stack them into batch tensors.
        return {
            'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
            'attention_mask': torch.stack([torch.tensor(item['attention_mask']) for item in batch]),
            'labels': torch.tensor([item['label'] for item in batch])
        }

    dataloader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)

    with torch.no_grad():
        for batch in dataloader:
            labels = batch.pop('labels')  # keep labels on CPU for metric collection
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='binary', pos_label=1, zero_division=0)
    acc = accuracy_score(all_labels, all_preds)
    # FIX: force a 2x2 matrix even when predictions collapse to a single class
    # — a randomly initialized classification head often predicts one label,
    # so the old shape check routinely zeroed every count for the baseline.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    return {
        'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
        'sensitivity': sensitivity, 'specificity': specificity,
        'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
    }
113
+
114
def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learning_rate,
                     weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
                     weight_mult, best_metric):
    """Train a (PEFT-)fine-tuned BERT on an uploaded CSV and compare it to
    the un-finetuned baseline.

    Gradio callback. Returns four strings: run log, baseline metrics,
    fine-tuned metrics, and a side-by-side comparison. On any failure the
    first string carries the error and the other three are empty.
    """
    global trained_models, model_counter, baseline_results

    # UI label -> HF hub model name (only one base model offered so far).
    model_mapping = {
        "BERT-base": "bert-base-uncased",
    }

    model_name = model_mapping.get(base_model, "bert-base-uncased")

    try:
        if csv_file is None:
            return "❌ 請上傳 CSV", "", "", ""

        df = pd.read_csv(csv_file.name)
        if 'Text' not in df.columns or 'label' not in df.columns:
            return "❌ 需要 Text 和 label 欄位", "", "", ""

        # Normalize to lowercase column names / enforced dtypes.
        df_clean = pd.DataFrame({
            'text': df['Text'].astype(str),
            'label': df['label'].astype(int)
        }).dropna()

        # Class counts: 0 = survived, 1 = deceased.
        n0 = int(sum(df_clean['label'] == 0))
        n1 = int(sum(df_clean['label'] == 1))
        if n1 == 0:
            return "❌ 無死亡樣本", "", "", ""

        # Weight the minority (positive) class by imbalance ratio * multiplier.
        ratio = n0 / n1
        w0, w1 = 1.0, ratio * weight_mult

        info = f"📊 資料: {len(df_clean)} 筆\n存活: {n0} | 死亡: {n1}\n比例: {ratio:.2f}:1\n權重: {w0:.2f} / {w1:.2f}\n模型: {base_model}\n方法: {method.upper()}"

        tokenizer = BertTokenizer.from_pretrained(model_name)
        dataset = Dataset.from_pandas(df_clean[['text', 'label']])

        def preprocess(ex):
            # Fixed-length padding so tensors can be stacked without a collator.
            return tokenizer(ex['text'], truncation=True, padding='max_length', max_length=128)

        tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
        # Fixed seed so the baseline and fine-tuned model see the same split.
        split = tokenized.train_test_split(test_size=0.2, seed=42)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        info += f"\n裝置: {'GPU ✅' if torch.cuda.is_available() else 'CPU ⚠️'}"

        # --- Evaluate the un-finetuned baseline first ---
        info += "\n\n🔍 評估基準模型(未微調)..."
        baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
        baseline_model = baseline_model.to(device)

        baseline_perf = evaluate_baseline(baseline_model, tokenizer, split['test'], device)
        baseline_key = f"{base_model}_baseline"
        baseline_results[baseline_key] = baseline_perf

        info += f"\n基準 F1: {baseline_perf['f1']:.4f}"
        info += f"\n基準 Accuracy: {baseline_perf['accuracy']:.4f}"

        # Release the baseline model before fine-tuning to free memory.
        del baseline_model
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # --- Fine-tuning ---
        info += f"\n\n🔧 套用 {method.upper()} 微調..."
        model = BertForSequenceClassification.from_pretrained(
            model_name, num_labels=2,
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout
        )

        # Wrap with a PEFT adapter; fall through to full fine-tuning if the
        # method string is unknown.
        peft_applied = False
        if method == "lora":
            config = LoraConfig(
                task_type=TaskType.SEQ_CLS,
                r=int(lora_r),
                lora_alpha=int(lora_alpha),
                lora_dropout=lora_dropout,
                target_modules=["query", "value"],
                bias="none"
            )
            model = get_peft_model(model, config)
            peft_applied = True
            info += f"\n✅ LoRA 已套用(r={int(lora_r)}, alpha={int(lora_alpha)})"
        elif method == "adalora":
            config = AdaLoraConfig(
                task_type=TaskType.SEQ_CLS,
                r=int(lora_r),
                lora_alpha=int(lora_alpha),
                lora_dropout=lora_dropout,
                target_modules=["query", "value"],
                init_r=12, tinit=200, tfinal=1000, deltaT=10
            )
            model = get_peft_model(model, config)
            peft_applied = True
            info += f"\n✅ AdaLoRA 已套用(r={int(lora_r)}, alpha={int(lora_alpha)})"

        if not peft_applied:
            info += f"\n⚠️ 警告:{method} 方法未被識別,使用 Full Fine-tuning"

        model = model.to(device)

        # Report trainable-parameter ratio (the point of PEFT).
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        info += f"\n\n💾 參數量\n總參數: {total:,}\n可訓練: {trainable:,}\n比例: {trainable/total*100:.2f}%"

        weights = torch.tensor([w0, w1], dtype=torch.float).to(device)

        args = TrainingArguments(
            output_dir='./results',
            num_train_epochs=int(num_epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size)*2,
            learning_rate=float(learning_rate),
            weight_decay=float(weight_decay),
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model=best_metric,
            report_to="none",
            logging_steps=50,
            save_total_limit=2
        )

        trainer = WeightedTrainer(
            model=model,
            args=args,
            train_dataset=split['train'],
            eval_dataset=split['test'],
            compute_metrics=compute_metrics,
            class_weights=weights
        )

        info += "\n\n⏳ 開始訓練..."
        trainer.train()
        results = trainer.evaluate()

        # Register the trained model under a timestamped ID for later tabs.
        model_counter += 1
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_id = f"{base_model}_{method}_{timestamp}"
        trained_models[model_id] = {
            'model': model,
            'tokenizer': tokenizer,
            'results': results,
            'baseline': baseline_perf,
            'config': {
                'type': base_model,
                'model_name': model_name,
                'method': method,
                'metric': best_metric
            },
            'timestamp': timestamp
        }

        # Per-metric improvement over the baseline.
        f1_improve = calculate_improvement(baseline_perf['f1'], results['eval_f1'])
        acc_improve = calculate_improvement(baseline_perf['accuracy'], results['eval_accuracy'])
        prec_improve = calculate_improvement(baseline_perf['precision'], results['eval_precision'])
        rec_improve = calculate_improvement(baseline_perf['recall'], results['eval_recall'])
        sens_improve = calculate_improvement(baseline_perf['sensitivity'], results['eval_sensitivity'])
        spec_improve = calculate_improvement(baseline_perf['specificity'], results['eval_specificity'])

        # Plain-BERT (baseline) report.
        baseline_output = f"🔬 純 BERT(未微調)\n\n"
        baseline_output += f"📈 表現\n"
        baseline_output += f"F1: {baseline_perf['f1']:.4f}\n"
        baseline_output += f"Accuracy: {baseline_perf['accuracy']:.4f}\n"
        baseline_output += f"Precision: {baseline_perf['precision']:.4f}\n"
        baseline_output += f"Recall: {baseline_perf['recall']:.4f}\n"
        baseline_output += f"Sensitivity: {baseline_perf['sensitivity']:.4f}\n"
        baseline_output += f"Specificity: {baseline_perf['specificity']:.4f}\n\n"
        baseline_output += f"混淆矩陣\n"
        baseline_output += f"TP: {baseline_perf['tp']} | TN: {baseline_perf['tn']}\n"
        baseline_output += f"FP: {baseline_perf['fp']} | FN: {baseline_perf['fn']}"

        # Fine-tuned report.
        finetuned_output = f"✅ 微調 BERT\n模型: {model_id}\n\n"
        finetuned_output += f"📈 表現\n"
        finetuned_output += f"F1: {results['eval_f1']:.4f}\n"
        finetuned_output += f"Accuracy: {results['eval_accuracy']:.4f}\n"
        finetuned_output += f"Precision: {results['eval_precision']:.4f}\n"
        finetuned_output += f"Recall: {results['eval_recall']:.4f}\n"
        finetuned_output += f"Sensitivity: {results['eval_sensitivity']:.4f}\n"
        finetuned_output += f"Specificity: {results['eval_specificity']:.4f}\n\n"
        finetuned_output += f"混淆矩陣\n"
        finetuned_output += f"TP: {results['eval_tp']} | TN: {results['eval_tn']}\n"
        finetuned_output += f"FP: {results['eval_fp']} | FN: {results['eval_fn']}"

        # Side-by-side comparison report.
        comparison_output = f"📊 純 BERT vs 微調 BERT 比較\n\n"
        comparison_output += f"指標改善:\n"
        comparison_output += f"F1: {baseline_perf['f1']:.4f} → {results['eval_f1']:.4f} ({format_improve(f1_improve)})\n"
        comparison_output += f"Accuracy: {baseline_perf['accuracy']:.4f} → {results['eval_accuracy']:.4f} ({format_improve(acc_improve)})\n"
        comparison_output += f"Precision: {baseline_perf['precision']:.4f} → {results['eval_precision']:.4f} ({format_improve(prec_improve)})\n"
        comparison_output += f"Recall: {baseline_perf['recall']:.4f} → {results['eval_recall']:.4f} ({format_improve(rec_improve)})\n"
        comparison_output += f"Sensitivity: {baseline_perf['sensitivity']:.4f} → {results['eval_sensitivity']:.4f} ({format_improve(sens_improve)})\n"
        comparison_output += f"Specificity: {baseline_perf['specificity']:.4f} → {results['eval_specificity']:.4f} ({format_improve(spec_improve)})\n\n"
        comparison_output += f"混淆矩陣變化:\n"
        comparison_output += f"TP: {baseline_perf['tp']} → {results['eval_tp']} ({results['eval_tp'] - baseline_perf['tp']:+d})\n"
        comparison_output += f"TN: {baseline_perf['tn']} → {results['eval_tn']} ({results['eval_tn'] - baseline_perf['tn']:+d})\n"
        comparison_output += f"FP: {baseline_perf['fp']} → {results['eval_fp']} ({results['eval_fp'] - baseline_perf['fp']:+d})\n"
        comparison_output += f"FN: {baseline_perf['fn']} → {results['eval_fn']} ({results['eval_fn'] - baseline_perf['fn']:+d})"

        info += "\n\n✅ 訓練完成!"

        return info, baseline_output, finetuned_output, comparison_output

    except Exception as e:
        import traceback
        error_msg = f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"
        return error_msg, "", "", ""
325
+
326
def predict(model_id, text):
    """Predict survival/death for one text with a trained model, and show the
    un-finetuned baseline's prediction alongside it.

    Gradio callback. Returns a single formatted report string; errors are
    returned as text rather than raised.
    """
    global baseline_model_cache

    if not model_id or model_id not in trained_models:
        return "❌ 請選擇模型"
    if not text:
        return "❌ 請輸入文字"

    try:
        info = trained_models[model_id]
        model, tokenizer = info['model'], info['tokenizer']
        config = info['config']
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
        inputs_cuda = {k: v.to(device) for k, v in inputs.items()}

        # Fine-tuned model prediction.
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs_cuda)
            probs_finetuned = torch.nn.functional.softmax(outputs.logits, dim=-1)
            pred_finetuned = torch.argmax(probs_finetuned, dim=-1).item()

        # Label convention throughout the app: 0 = survived, 1 = deceased.
        result_finetuned = "存活" if pred_finetuned == 0 else "死亡"

        # Baseline model prediction (loaded once per HF model name, then cached).
        cache_key = config['model_name']
        if cache_key not in baseline_model_cache:
            baseline_model = BertForSequenceClassification.from_pretrained(config['model_name'], num_labels=2)
            baseline_model = baseline_model.to(device)
            baseline_model.eval()
            baseline_model_cache[cache_key] = baseline_model
        else:
            baseline_model = baseline_model_cache[cache_key]

        with torch.no_grad():
            outputs_baseline = baseline_model(**inputs_cuda)
            probs_baseline = torch.nn.functional.softmax(outputs_baseline.logits, dim=-1)
            pred_baseline = torch.argmax(probs_baseline, dim=-1).item()

        result_baseline = "存活" if pred_baseline == 0 else "死亡"

        # Do the two models agree?
        agreement = "✅ 一致" if pred_finetuned == pred_baseline else "⚠️ 不一致"

        output = f"""🔮 預測結果比較

📝 輸入文字: {text[:100]}{'...' if len(text) > 100 else ''}

{'='*50}

🧬 微調模型 ({model_id})
預測: {result_finetuned}
信心: {probs_finetuned[0][pred_finetuned].item():.2%}
機率分布:
• 存活: {probs_finetuned[0][0].item():.2%}
• 死亡: {probs_finetuned[0][1].item():.2%}

{'='*50}

🔬 基準模型(未微調 {config['type']})
預測: {result_baseline}
信心: {probs_baseline[0][pred_baseline].item():.2%}
機率分布:
• 存活: {probs_baseline[0][0].item():.2%}
• 死亡: {probs_baseline[0][1].item():.2%}

{'='*50}

📊 結論
兩模型預測: {agreement}
"""

        if pred_finetuned != pred_baseline:
            output += f"\n💡 分析: 微調模型預測為【{result_finetuned}】,而基準模型預測為【{result_baseline}】"
            output += f"\n 這顯示了 fine-tuning 對此案例的影響!"

        f1_improve = calculate_improvement(info['baseline']['f1'], info['results']['eval_f1'])

        output += f"""

📈 模型表現
微調模型 F1: {info['results']['eval_f1']:.4f}
基準模型 F1: {info['baseline']['f1']:.4f}
改善幅度: {format_improve(f1_improve)}
"""

        return output

    except Exception as e:
        import traceback
        return f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"
419
+
420
def compare():
    """Build a Markdown report comparing every trained model with its
    un-finetuned baseline, plus the best model per metric."""
    if not trained_models:
        return "❌ 尚未訓練模型"

    parts = ["# 📊 模型比較\n\n"]

    # Fine-tuned results table.
    parts.append("## 微調模型表現\n\n")
    parts.append("| 模型 | 基礎 | 方法 | F1 | Acc | Prec | Recall | Sens | Spec |\n")
    parts.append("|------|------|------|-----|-----|------|--------|------|------|\n")
    for model_id, entry in trained_models.items():
        res = entry['results']
        cfg = entry['config']
        parts.append(
            f"| {model_id} | {cfg['type']} | {cfg['method'].upper()} | "
            f"{res['eval_f1']:.4f} | {res['eval_accuracy']:.4f} | "
            f"{res['eval_precision']:.4f} | {res['eval_recall']:.4f} | "
            f"{res['eval_sensitivity']:.4f} | {res['eval_specificity']:.4f} |\n"
        )

    # Baseline (un-finetuned) results table.
    parts.append("\n## 基準模型表現(未微調)\n\n")
    parts.append("| 模型 | F1 | Acc | Prec | Recall | Sens | Spec |\n")
    parts.append("|------|-----|-----|------|--------|------|------|\n")
    for model_id, entry in trained_models.items():
        base = entry['baseline']
        cfg = entry['config']
        parts.append(
            f"| {cfg['type']}-baseline | {base['f1']:.4f} | {base['accuracy']:.4f} | "
            f"{base['precision']:.4f} | {base['recall']:.4f} | "
            f"{base['sensitivity']:.4f} | {base['specificity']:.4f} |\n"
        )

    # Best model per metric, with improvement over its own baseline.
    parts.append("\n## 🏆 最佳模型\n\n")
    for metric in ['f1', 'accuracy', 'precision', 'recall', 'sensitivity', 'specificity']:
        best_id, best_entry = max(trained_models.items(),
                                  key=lambda item: item[1]['results'][f'eval_{metric}'])
        baseline_val = best_entry['baseline'][metric]
        finetuned_val = best_entry['results'][f'eval_{metric}']
        improvement = calculate_improvement(baseline_val, finetuned_val)

        parts.append(f"**{metric.upper()}**: {best_id} ({finetuned_val:.4f}, 改善 {format_improve(improvement)})\n\n")

    return "".join(parts)
457
+
458
def refresh_model_list():
    """Return a Dropdown update listing the currently trained model IDs."""
    current_ids = list(trained_models.keys())
    return gr.Dropdown(choices=current_ids)
460
+
461
# Gradio UI: four tabs — train, predict, compare, help.
with gr.Blocks(title="BERT Fine-tuning 教學平台", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧬 BERT Fine-tuning 教學平台")
    gr.Markdown("### 比較基準模型 vs 微調模型的表現差異")

    # --- Training tab: model/method selection, data upload, hyperparameters ---
    with gr.Tab("訓練"):
        gr.Markdown("## 步驟 1: 選擇基礎模型")

        base_model = gr.Dropdown(
            choices=["BERT-base"],
            value="BERT-base",
            label="基礎模型",
            info="更多模型即將推出"
        )

        gr.Markdown("## 步驟 2: 選擇微調方法")

        method = gr.Radio(
            choices=["lora", "adalora"],
            value="lora",
            label="微調方法",
            info="兩種都是參數高效方法,推薦從 LoRA 開始"
        )

        gr.Markdown("## 步驟 3: 上傳資料")
        csv_file = gr.File(label="CSV 檔案 (需包含 Text 和 label 欄位)", file_types=[".csv"])

        gr.Markdown("## 步驟 4: 設定訓練參數")

        gr.Markdown("### 🎯 基本訓練參數")
        with gr.Row():
            num_epochs = gr.Number(value=3, label="訓練輪數 (epochs)", minimum=1, maximum=100, precision=0)
            batch_size = gr.Number(value=8, label="批次大小 (batch_size)", minimum=1, maximum=128, precision=0)
            learning_rate = gr.Number(value=2e-5, label="學習率 (learning_rate)", minimum=0, maximum=1)

        gr.Markdown("### ⚙️ 進階參數")
        with gr.Row():
            weight_decay = gr.Number(value=0.01, label="權重衰減 (weight_decay)", minimum=0, maximum=1)
            dropout = gr.Number(value=0.1, label="Dropout 機率", minimum=0, maximum=1)

        gr.Markdown("### 🔧 LoRA 參數")
        with gr.Row():
            lora_r = gr.Number(value=16, label="LoRA Rank (r)", minimum=1, maximum=256, precision=0,
                               info="推薦 8-16,越大效果越好但越慢")
            lora_alpha = gr.Number(value=32, label="LoRA Alpha", minimum=1, maximum=512, precision=0,
                                   info="通常設為 Rank 的 2 倍")
            lora_dropout = gr.Number(value=0.1, label="LoRA Dropout", minimum=0, maximum=1,
                                     info="防止過擬合")

        gr.Markdown("### ⚖️ 評估設定")
        with gr.Row():
            weight_mult = gr.Number(value=2.0, label="類別權重倍數", minimum=0, maximum=10,
                                    info="推薦 1.5-2.5,過低會忽略少數類")
            best_metric = gr.Dropdown(
                choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
                value="f1",
                label="最佳模型選擇指標",
                info="訓練時用此指標選擇最佳模型"
            )

        train_btn = gr.Button("🚀 開始訓練", variant="primary", size="lg")

        gr.Markdown("## 📊 訓練結果")

        data_info = gr.Textbox(label="📋 資料資訊", lines=10)

        with gr.Row():
            baseline_result = gr.Textbox(label="🔬 純 BERT(未微調)", lines=14)
            finetuned_result = gr.Textbox(label="✅ 微調 BERT", lines=14)

        comparison_result = gr.Textbox(label="📊 純 BERT vs 微調 BERT 比較", lines=14)

        train_btn.click(
            train_bert_model,
            inputs=[csv_file, base_model, method, num_epochs, batch_size, learning_rate,
                    weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
                    weight_mult, best_metric],
            outputs=[data_info, baseline_result, finetuned_result, comparison_result]
        )

    # --- Prediction tab: run a trained model (and its baseline) on free text ---
    with gr.Tab("預測"):
        gr.Markdown("## 使用訓練好的模型預測")

        with gr.Row():
            model_drop = gr.Dropdown(label="選擇模型", choices=list(trained_models.keys()))
            refresh = gr.Button("🔄 刷新")

        text_input = gr.Textbox(label="輸入病例描述", lines=4,
                                placeholder="Patient diagnosed with...")
        predict_btn = gr.Button("預測", variant="primary", size="lg")
        pred_output = gr.Textbox(label="預測結果(含基準模型對比)", lines=20)

        refresh.click(refresh_model_list, outputs=[model_drop])
        predict_btn.click(predict, inputs=[model_drop, text_input], outputs=[pred_output])

        gr.Examples(
            examples=[
                ["Patient with stage II breast cancer, good response to treatment."],
                ["Advanced metastatic cancer, multiple organ involvement."]
            ],
            inputs=text_input
        )

    # --- Comparison tab: Markdown table of every trained model vs baseline ---
    with gr.Tab("比較"):
        gr.Markdown("## 比較所有模型(含基準模型)")
        compare_btn = gr.Button("比較", variant="primary", size="lg")
        compare_output = gr.Markdown()
        compare_btn.click(compare, outputs=[compare_output])

    # --- Help tab ---
    # FIX: repaired mojibake in the help text ("使��你的資料" -> "使用你的資料").
    with gr.Tab("說明"):
        gr.Markdown("""
## 📖 使用說明

### 🎯 平台特色

本平台會自動比較:
- **基準模型**:未經微調的原始 BERT
- **微調模型**:使用你的資料訓練後的 BERT

這樣可以清楚看到 fine-tuning 帶來的改善!

### 基礎模型

- **BERT-base**: 標準 BERT,110M 參數 ⭐目前支援

### 微調方法

- **LoRA**: 低秩適應,參數高效的微調方法 ⭐強烈推薦
  - 只訓練少量參數(通常 <1%)
  - 訓練速度快,效果好
  - 適合大多數情況

- **AdaLoRA**: 自適應 LoRA,動態調整秩
  - 自動找出最重要的參數
  - 可能比 LoRA 效果稍好
  - 訓練時間稍長

### 評估指標

- **F1**: 平衡指標,推薦用於不平衡資料 ⭐
- **Accuracy**: 整體準確率
- **Precision**: 減少假陽性
- **Recall/Sensitivity**: 減少假陰性
- **Specificity**: 真陰性率

### 參數建議

針對不平衡資料(如醫療資料):
- **微調方法**: LoRA(快速有效)或 AdaLoRA(追求極致)
- **LoRA Rank**: 8-16(平衡效果與速度)
- **類別權重倍數**: 1.5-2.5(資料不平衡時)
- **Learning rate**: 2e-5 到 5e-5
- **Epochs**: 3-8(避免過擬合)
- **Batch size**: 8-16(依 GPU 記憶體調整)

### 資料格式

CSV 必須包含:
- `Text`: 病例描述
- `label`: 0=存活, 1=死亡

### 🚀 快速開始

1. 上傳包含 `Text` 和 `label` 欄位的 CSV
2. 使用預設參數(適合大多數情況)
3. 點擊「開始訓練」
4. 在「預測」分頁測試模型
5. 在「比較」分頁查看所有模型表現
""")

if __name__ == "__main__":
    demo.launch()