"""Gradio teaching app: fine-tune BERT with LoRA/AdaLoRA on a binary
survival-prediction CSV and compare the tuned model against the
un-finetuned baseline."""

import gradio as gr
import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
from peft import LoraConfig, AdaLoraConfig, get_peft_model, TaskType
from datasets import Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from torch import nn
import os
from datetime import datetime

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Global state shared across Gradio callbacks.
trained_models = {}        # model_id -> {'model', 'tokenizer', 'results', 'baseline', 'config', 'timestamp'}
model_counter = 0          # bumped once per successful training run
baseline_results = {}      # "<base_model>_baseline" -> metric dict
baseline_model_cache = {}  # HF model name -> loaded un-finetuned model (reused by predict)


def calculate_improvement(baseline_val, finetuned_val):
    """Return the relative improvement in percent, guarding division by zero.

    A zero baseline with a positive fine-tuned value yields +inf, which
    format_improve() renders as "N/A (baseline=0)".
    """
    if baseline_val == 0:
        return float('inf') if finetuned_val > 0 else 0.0
    return (finetuned_val - baseline_val) / baseline_val * 100


def format_improve(val):
    """Format an improvement percentage, handling the inf sentinel."""
    if val == float('inf'):
        return "N/A (baseline=0)"
    return f"{val:+.1f}%"


def _binary_metrics(labels, preds):
    """Shared binary-classification metric dict (positive class = 1).

    Used by both compute_metrics() (Trainer callback) and
    evaluate_baseline() so the two evaluation paths cannot drift apart.
    """
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', pos_label=1, zero_division=0)
    acc = accuracy_score(labels, preds)
    cm = confusion_matrix(labels, preds)
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
    else:
        # Degenerate case: only one class present in labels/preds.
        tn = fp = fn = tp = 0
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    return {
        'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
        'sensitivity': sensitivity, 'specificity': specificity,
        'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
    }


def compute_metrics(pred):
    """Trainer metrics callback; never raises (returns zeros on failure)."""
    try:
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        return _binary_metrics(labels, preds)
    except Exception as e:
        print(f"Error in compute_metrics: {e}")
        return {
            'accuracy': 0, 'f1': 0, 'precision': 0, 'recall': 0,
            'sensitivity': 0, 'specificity': 0,
            'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0
        }


class WeightedTrainer(Trainer):
    """Trainer with a class-weighted cross-entropy loss for imbalanced data."""

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Expected to be a 2-element float tensor already on the model device.
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # num_items_in_batch accepted for compatibility with newer transformers.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
        loss = loss_fct(outputs.logits.view(-1, 2), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


def evaluate_baseline(model, tokenizer, test_dataset, device):
    """Evaluate an un-finetuned model on the test split; return a metric dict."""
    model.eval()
    all_preds = []
    all_labels = []

    from torch.utils.data import DataLoader

    def collate_fn(batch):
        return {
            'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
            'attention_mask': torch.stack([torch.tensor(item['attention_mask']) for item in batch]),
            'labels': torch.tensor([item['label'] for item in batch])
        }

    dataloader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)
    with torch.no_grad():
        for batch in dataloader:
            labels = batch.pop('labels')  # keep labels on CPU for sklearn
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
    return _binary_metrics(all_labels, all_preds)


def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learning_rate,
                     weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
                     weight_mult, best_metric):
    """Train a PEFT-adapted BERT on the uploaded CSV and compare to baseline.

    Returns four strings for the Gradio outputs:
    (info log, baseline report, fine-tuned report, comparison report).
    On any error returns (error message, "", "", "").
    """
    global trained_models, model_counter, baseline_results

    model_mapping = {
        "BERT-base": "bert-base-uncased",
    }
    model_name = model_mapping.get(base_model, "bert-base-uncased")

    try:
        if csv_file is None:
            return "❌ 請上傳 CSV", "", "", ""

        df = pd.read_csv(csv_file.name)
        if 'Text' not in df.columns or 'label' not in df.columns:
            return "❌ 需要 Text 和 label 欄位", "", "", ""

        # BUGFIX: drop missing rows BEFORE casting. astype(int) raises on NaN,
        # so the original cast-then-dropna order crashed on dirty CSVs and the
        # dropna() was dead code.
        df = df.dropna(subset=['Text', 'label'])
        df_clean = pd.DataFrame({
            'text': df['Text'].astype(str),
            'label': df['label'].astype(int)
        })

        n0 = int((df_clean['label'] == 0).sum())
        n1 = int((df_clean['label'] == 1).sum())
        if n1 == 0:
            return "❌ 無死亡樣本", "", "", ""

        # Weight the minority (death) class by the imbalance ratio times a
        # user-tunable multiplier.
        ratio = n0 / n1
        w0, w1 = 1.0, ratio * weight_mult
        info = f"📊 資料: {len(df_clean)} 筆\n存活: {n0} | 死亡: {n1}\n比例: {ratio:.2f}:1\n權重: {w0:.2f} / {w1:.2f}\n模型: {base_model}\n方法: {method.upper()}"

        tokenizer = BertTokenizer.from_pretrained(model_name)
        dataset = Dataset.from_pandas(df_clean[['text', 'label']])

        def preprocess(ex):
            return tokenizer(ex['text'], truncation=True, padding='max_length', max_length=128)

        tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
        split = tokenized.train_test_split(test_size=0.2, seed=42)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        info += f"\n裝置: {'GPU ✅' if torch.cuda.is_available() else 'CPU ⚠️'}"

        # Evaluate the un-finetuned baseline on the same held-out split.
        info += "\n\n🔍 評估基準模型(未微調)..."
        baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
        baseline_model = baseline_model.to(device)
        baseline_perf = evaluate_baseline(baseline_model, tokenizer, split['test'], device)
        baseline_key = f"{base_model}_baseline"
        baseline_results[baseline_key] = baseline_perf
        info += f"\n基準 F1: {baseline_perf['f1']:.4f}"
        info += f"\n基準 Accuracy: {baseline_perf['accuracy']:.4f}"

        # Free the baseline model before loading the trainable one.
        del baseline_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Build the model to fine-tune.
        info += f"\n\n🔧 套用 {method.upper()} 微調..."
        model = BertForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2,
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout
        )

        peft_applied = False
        if method == "lora":
            config = LoraConfig(
                task_type=TaskType.SEQ_CLS,
                r=int(lora_r),
                lora_alpha=int(lora_alpha),
                lora_dropout=lora_dropout,
                target_modules=["query", "value"],
                bias="none"
            )
            model = get_peft_model(model, config)
            peft_applied = True
            info += f"\n✅ LoRA 已套用(r={int(lora_r)}, alpha={int(lora_alpha)})"
        elif method == "adalora":
            config = AdaLoraConfig(
                task_type=TaskType.SEQ_CLS,
                r=int(lora_r),
                lora_alpha=int(lora_alpha),
                lora_dropout=lora_dropout,
                target_modules=["query", "value"],
                init_r=12,
                tinit=200,
                tfinal=1000,
                deltaT=10
            )
            model = get_peft_model(model, config)
            peft_applied = True
            info += f"\n✅ AdaLoRA 已套用(r={int(lora_r)}, alpha={int(lora_alpha)})"

        if not peft_applied:
            # Unknown method string: fall through to full fine-tuning.
            info += f"\n⚠️ 警告:{method} 方法未被識別,使用 Full Fine-tuning"

        model = model.to(device)
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        info += f"\n\n💾 參數量\n總參數: {total:,}\n可訓練: {trainable:,}\n比例: {trainable/total*100:.2f}%"

        weights = torch.tensor([w0, w1], dtype=torch.float).to(device)

        # NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in
        # transformers >= 4.46 — confirm against the pinned version.
        args = TrainingArguments(
            output_dir='./results',
            num_train_epochs=int(num_epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size) * 2,
            learning_rate=float(learning_rate),
            weight_decay=float(weight_decay),
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model=best_metric,
            report_to="none",
            logging_steps=50,
            save_total_limit=2
        )
        trainer = WeightedTrainer(
            model=model,
            args=args,
            train_dataset=split['train'],
            eval_dataset=split['test'],
            compute_metrics=compute_metrics,
            class_weights=weights
        )

        info += "\n\n⏳ 開始訓練..."
        trainer.train()
        results = trainer.evaluate()

        # Register the trained model under a timestamped ID.
        model_counter += 1
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_id = f"{base_model}_{method}_{timestamp}"
        trained_models[model_id] = {
            'model': model,
            'tokenizer': tokenizer,
            'results': results,
            'baseline': baseline_perf,
            'config': {
                'type': base_model,
                'model_name': model_name,
                'method': method,
                'metric': best_metric
            },
            'timestamp': timestamp
        }

        # Per-metric improvement over the un-finetuned baseline.
        f1_improve = calculate_improvement(baseline_perf['f1'], results['eval_f1'])
        acc_improve = calculate_improvement(baseline_perf['accuracy'], results['eval_accuracy'])
        prec_improve = calculate_improvement(baseline_perf['precision'], results['eval_precision'])
        rec_improve = calculate_improvement(baseline_perf['recall'], results['eval_recall'])
        sens_improve = calculate_improvement(baseline_perf['sensitivity'], results['eval_sensitivity'])
        spec_improve = calculate_improvement(baseline_perf['specificity'], results['eval_specificity'])

        # Baseline (un-finetuned) report.
        baseline_output = f"🔬 純 BERT(未微調)\n\n"
        baseline_output += f"📈 表現\n"
        baseline_output += f"F1: {baseline_perf['f1']:.4f}\n"
        baseline_output += f"Accuracy: {baseline_perf['accuracy']:.4f}\n"
        baseline_output += f"Precision: {baseline_perf['precision']:.4f}\n"
        baseline_output += f"Recall: {baseline_perf['recall']:.4f}\n"
        baseline_output += f"Sensitivity: {baseline_perf['sensitivity']:.4f}\n"
        baseline_output += f"Specificity: {baseline_perf['specificity']:.4f}\n\n"
        baseline_output += f"混淆矩陣\n"
        baseline_output += f"TP: {baseline_perf['tp']} | TN: {baseline_perf['tn']}\n"
        baseline_output += f"FP: {baseline_perf['fp']} | FN: {baseline_perf['fn']}"

        # Fine-tuned report.
        finetuned_output = f"✅ 微調 BERT\n模型: {model_id}\n\n"
        finetuned_output += f"📈 表現\n"
        finetuned_output += f"F1: {results['eval_f1']:.4f}\n"
        finetuned_output += f"Accuracy: {results['eval_accuracy']:.4f}\n"
        finetuned_output += f"Precision: {results['eval_precision']:.4f}\n"
        finetuned_output += f"Recall: {results['eval_recall']:.4f}\n"
        finetuned_output += f"Sensitivity: {results['eval_sensitivity']:.4f}\n"
        finetuned_output += f"Specificity: {results['eval_specificity']:.4f}\n\n"
        finetuned_output += f"混淆矩陣\n"
        finetuned_output += f"TP: {results['eval_tp']} | TN: {results['eval_tn']}\n"
        finetuned_output += f"FP: {results['eval_fp']} | FN: {results['eval_fn']}"

        # Side-by-side comparison report.
        comparison_output = f"📊 純 BERT vs 微調 BERT 比較\n\n"
        comparison_output += f"指標改善:\n"
        comparison_output += f"F1: {baseline_perf['f1']:.4f} → {results['eval_f1']:.4f} ({format_improve(f1_improve)})\n"
        comparison_output += f"Accuracy: {baseline_perf['accuracy']:.4f} → {results['eval_accuracy']:.4f} ({format_improve(acc_improve)})\n"
        comparison_output += f"Precision: {baseline_perf['precision']:.4f} → {results['eval_precision']:.4f} ({format_improve(prec_improve)})\n"
        comparison_output += f"Recall: {baseline_perf['recall']:.4f} → {results['eval_recall']:.4f} ({format_improve(rec_improve)})\n"
        comparison_output += f"Sensitivity: {baseline_perf['sensitivity']:.4f} → {results['eval_sensitivity']:.4f} ({format_improve(sens_improve)})\n"
        comparison_output += f"Specificity: {baseline_perf['specificity']:.4f} → {results['eval_specificity']:.4f} ({format_improve(spec_improve)})\n\n"
        comparison_output += f"混淆矩陣變化:\n"
        comparison_output += f"TP: {baseline_perf['tp']} → {results['eval_tp']} ({results['eval_tp'] - baseline_perf['tp']:+d})\n"
        comparison_output += f"TN: {baseline_perf['tn']} → {results['eval_tn']} ({results['eval_tn'] - baseline_perf['tn']:+d})\n"
        comparison_output += f"FP: {baseline_perf['fp']} → {results['eval_fp']} ({results['eval_fp'] - baseline_perf['fp']:+d})\n"
        comparison_output += f"FN: {baseline_perf['fn']} → {results['eval_fn']} ({results['eval_fn'] - baseline_perf['fn']:+d})"

        info += "\n\n✅ 訓練完成!"
        return info, baseline_output, finetuned_output, comparison_output

    except Exception as e:
        import traceback
        error_msg = f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"
        return error_msg, "", "", ""


def predict(model_id, text):
    """Predict with a trained model and the cached un-finetuned baseline.

    Returns a formatted comparison string, or an error message.
    """
    global baseline_model_cache

    if not model_id or model_id not in trained_models:
        return "❌ 請選擇模型"
    if not text:
        return "❌ 請輸入文字"

    try:
        info = trained_models[model_id]
        model, tokenizer = info['model'], info['tokenizer']
        config = info['config']
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
        inputs_cuda = {k: v.to(device) for k, v in inputs.items()}

        # Fine-tuned model prediction.
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs_cuda)
            probs_finetuned = torch.nn.functional.softmax(outputs.logits, dim=-1)
            pred_finetuned = torch.argmax(probs_finetuned, dim=-1).item()
        result_finetuned = "存活" if pred_finetuned == 0 else "死亡"

        # Baseline prediction, loading the un-finetuned model once per name.
        cache_key = config['model_name']
        if cache_key not in baseline_model_cache:
            baseline_model = BertForSequenceClassification.from_pretrained(config['model_name'], num_labels=2)
            baseline_model = baseline_model.to(device)
            baseline_model.eval()
            baseline_model_cache[cache_key] = baseline_model
        else:
            baseline_model = baseline_model_cache[cache_key]

        with torch.no_grad():
            outputs_baseline = baseline_model(**inputs_cuda)
            probs_baseline = torch.nn.functional.softmax(outputs_baseline.logits, dim=-1)
            pred_baseline = torch.argmax(probs_baseline, dim=-1).item()
        result_baseline = "存活" if pred_baseline == 0 else "死亡"

        agreement = "✅ 一致" if pred_finetuned == pred_baseline else "⚠️ 不一致"

        output = f"""🔮 預測結果比較

📝 輸入文字: {text[:100]}{'...' if len(text) > 100 else ''}

{'='*50}
🧬 微調模型 ({model_id})
預測: {result_finetuned}
信心: {probs_finetuned[0][pred_finetuned].item():.2%}
機率分布:
• 存活: {probs_finetuned[0][0].item():.2%}
• 死亡: {probs_finetuned[0][1].item():.2%}

{'='*50}
🔬 基準模型(未微調 {config['type']})
預測: {result_baseline}
信心: {probs_baseline[0][pred_baseline].item():.2%}
機率分布:
• 存活: {probs_baseline[0][0].item():.2%}
• 死亡: {probs_baseline[0][1].item():.2%}

{'='*50}
📊 結論
兩模型預測: {agreement}
"""
        if pred_finetuned != pred_baseline:
            output += f"\n💡 分析: 微調模型預測為【{result_finetuned}】,而基準模型預測為【{result_baseline}】"
            output += f"\n 這顯示了 fine-tuning 對此案例的影響!"

        f1_improve = calculate_improvement(info['baseline']['f1'], info['results']['eval_f1'])
        output += f"""
📈 模型表現
微調模型 F1: {info['results']['eval_f1']:.4f}
基準模型 F1: {info['baseline']['f1']:.4f}
改善幅度: {format_improve(f1_improve)}
"""
        return output

    except Exception as e:
        import traceback
        return f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"


def compare():
    """Build a Markdown report comparing all trained models and baselines."""
    if not trained_models:
        return "❌ 尚未訓練模型"

    text = "# 📊 模型比較\n\n"
    text += "## 微調模型表現\n\n"
    text += "| 模型 | 基礎 | 方法 | F1 | Acc | Prec | Recall | Sens | Spec |\n"
    text += "|------|------|------|-----|-----|------|--------|------|------|\n"
    for mid, info in trained_models.items():
        r = info['results']
        c = info['config']
        text += f"| {mid} | {c['type']} | {c['method'].upper()} | {r['eval_f1']:.4f} | {r['eval_accuracy']:.4f} | "
        text += f"{r['eval_precision']:.4f} | {r['eval_recall']:.4f} | "
        text += f"{r['eval_sensitivity']:.4f} | {r['eval_specificity']:.4f} |\n"

    text += "\n## 基準模型表現(未微調)\n\n"
    text += "| 模型 | F1 | Acc | Prec | Recall | Sens | Spec |\n"
    text += "|------|-----|-----|------|--------|------|------|\n"
    for mid, info in trained_models.items():
        b = info['baseline']
        c = info['config']
        text += f"| {c['type']}-baseline | {b['f1']:.4f} | {b['accuracy']:.4f} | "
        text += f"{b['precision']:.4f} | {b['recall']:.4f} | "
        text += f"{b['sensitivity']:.4f} | {b['specificity']:.4f} |\n"

    text += "\n## 🏆 最佳模型\n\n"
    for metric in ['f1', 'accuracy', 'precision', 'recall', 'sensitivity', 'specificity']:
        best = max(trained_models.items(), key=lambda x: x[1]['results'][f'eval_{metric}'])
        baseline_val = best[1]['baseline'][metric]
        finetuned_val = best[1]['results'][f'eval_{metric}']
        improvement = calculate_improvement(baseline_val, finetuned_val)
        text += f"**{metric.upper()}**: {best[0]} ({finetuned_val:.4f}, 改善 {format_improve(improvement)})\n\n"
    return text


def refresh_model_list():
    """Refresh the model dropdown with the currently trained model IDs."""
    return gr.Dropdown(choices=list(trained_models.keys()))


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="BERT Fine-tuning 教學平台", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧬 BERT Fine-tuning 教學平台")
    gr.Markdown("### 比較基準模型 vs 微調模型的表現差異")

    with gr.Tab("訓練"):
        gr.Markdown("## 步驟 1: 選擇基礎模型")
        base_model = gr.Dropdown(
            choices=["BERT-base"],
            value="BERT-base",
            label="基礎模型",
            info="更多模型即將推出"
        )

        gr.Markdown("## 步驟 2: 選擇微調方法")
        method = gr.Radio(
            choices=["lora", "adalora"],
            value="lora",
            label="微調方法",
            info="兩種都是參數高效方法,推薦從 LoRA 開始"
        )

        gr.Markdown("## 步驟 3: 上傳資料")
        csv_file = gr.File(label="CSV 檔案 (需包含 Text 和 label 欄位)", file_types=[".csv"])

        gr.Markdown("## 步驟 4: 設定訓練參數")
        gr.Markdown("### 🎯 基本訓練參數")
        with gr.Row():
            num_epochs = gr.Number(value=3, label="訓練輪數 (epochs)", minimum=1, maximum=100, precision=0)
            batch_size = gr.Number(value=8, label="批次大小 (batch_size)", minimum=1, maximum=128, precision=0)
            learning_rate = gr.Number(value=2e-5, label="學習率 (learning_rate)", minimum=0, maximum=1)

        gr.Markdown("### ⚙️ 進階參數")
        with gr.Row():
            weight_decay = gr.Number(value=0.01, label="權重衰減 (weight_decay)", minimum=0, maximum=1)
            dropout = gr.Number(value=0.1, label="Dropout 機率", minimum=0, maximum=1)

        gr.Markdown("### 🔧 LoRA 參數")
        with gr.Row():
            lora_r = gr.Number(value=16, label="LoRA Rank (r)", minimum=1, maximum=256, precision=0,
                               info="推薦 8-16,越大效果越好但越慢")
            lora_alpha = gr.Number(value=32, label="LoRA Alpha", minimum=1, maximum=512, precision=0,
                                   info="通常設為 Rank 的 2 倍")
            lora_dropout = gr.Number(value=0.1, label="LoRA Dropout", minimum=0, maximum=1,
                                     info="防止過擬合")

        gr.Markdown("### ⚖️ 評估設定")
        with gr.Row():
            weight_mult = gr.Number(value=2.0, label="類別權重倍數", minimum=0, maximum=10,
                                    info="推薦 1.5-2.5,過低會忽略少數類")
            best_metric = gr.Dropdown(
                choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
                value="f1",
                label="最佳模型選擇指標",
                info="訓練時用此指標選擇最佳模型"
            )

        train_btn = gr.Button("🚀 開始訓練", variant="primary", size="lg")

        gr.Markdown("## 📊 訓練結果")
        data_info = gr.Textbox(label="📋 資料資訊", lines=10)
        with gr.Row():
            baseline_result = gr.Textbox(label="🔬 純 BERT(未微調)", lines=14)
            finetuned_result = gr.Textbox(label="✅ 微調 BERT", lines=14)
        comparison_result = gr.Textbox(label="📊 純 BERT vs 微調 BERT 比較", lines=14)

        train_btn.click(
            train_bert_model,
            inputs=[csv_file, base_model, method, num_epochs, batch_size, learning_rate,
                    weight_decay, dropout, lora_r, lora_alpha, lora_dropout, weight_mult, best_metric],
            outputs=[data_info, baseline_result, finetuned_result, comparison_result]
        )

    with gr.Tab("預測"):
        gr.Markdown("## 使用訓練好的模型預測")
        with gr.Row():
            model_drop = gr.Dropdown(label="選擇模型", choices=list(trained_models.keys()))
            refresh = gr.Button("🔄 刷新")
        text_input = gr.Textbox(label="輸入病例描述", lines=4, placeholder="Patient diagnosed with...")
        predict_btn = gr.Button("預測", variant="primary", size="lg")
        pred_output = gr.Textbox(label="預測結果(含基準模型對比)", lines=20)
        refresh.click(refresh_model_list, outputs=[model_drop])
        predict_btn.click(predict, inputs=[model_drop, text_input], outputs=[pred_output])
        gr.Examples(
            examples=[
                ["Patient with stage II breast cancer, good response to treatment."],
                ["Advanced metastatic cancer, multiple organ involvement."]
            ],
            inputs=text_input
        )

    with gr.Tab("比較"):
        gr.Markdown("## 比較所有模型(含基準模型)")
        compare_btn = gr.Button("比較", variant="primary", size="lg")
        compare_output = gr.Markdown()
        compare_btn.click(compare, outputs=[compare_output])

    with gr.Tab("說明"):
        gr.Markdown("""
## 📖 使用說明

### 🎯 平台特色
本平台會自動比較:
- **基準模型**:未經微調的原始 BERT
- **微調模型**:使用你的資料訓練後的 BERT

這樣可以清楚看到 fine-tuning 帶來的改善!

### 基礎模型
- **BERT-base**: 標準 BERT,110M 參數 ⭐目前支援

### 微調方法
- **LoRA**: 低秩適應,參數高效的微調方法 ⭐強烈推薦
  - 只訓練少量參數(通常 <1%)
  - 訓練速度快,效果好
  - 適合大多數情況
- **AdaLoRA**: 自適應 LoRA,動態調整秩
  - 自動找出最重要的參數
  - 可能比 LoRA 效果稍好
  - 訓練時間稍長

### 評估指標
- **F1**: 平衡指標,推薦用於不平衡資料 ⭐
- **Accuracy**: 整體準確率
- **Precision**: 減少假陽性
- **Recall/Sensitivity**: 減少假陰性
- **Specificity**: 真陰性率

### 參數建議
針對不平衡資料(如醫療資料):
- **微調方法**: LoRA(快速有效)或 AdaLoRA(追求極致)
- **LoRA Rank**: 8-16(平衡效果與速度)
- **類別權重倍數**: 1.5-2.5(資料不平衡時)
- **Learning rate**: 2e-5 到 5e-5
- **Epochs**: 3-8(避免過擬合)
- **Batch size**: 8-16(依 GPU 記憶體調整)

### 資料格式
CSV 必須包含:
- `Text`: 病例描述
- `label`: 0=存活, 1=死亡

### 🚀 快速開始
1. 上傳包含 `Text` 和 `label` 欄位的 CSV
2. 使用預設參數(適合大多數情況)
3. 點擊「開始訓練」
4. 在「預測」分頁測試模型
5. 在「比較」分頁查看所有模型表現
""")

if __name__ == "__main__":
    demo.launch()