Spaces:
Paused
Paused
| import gradio as gr | |
| import pandas as pd | |
| import torch | |
| from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer | |
| from peft import LoraConfig, AdaLoraConfig, get_peft_model, TaskType | |
| from datasets import Dataset | |
| from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix | |
| from torch import nn | |
| import os | |
| from datetime import datetime | |
# Set TOKENIZERS_PARALLELISM to "false" before any tokenizer is created.
# NOTE(review): presumably this silences the tokenizers fork/parallelism
# warning when the Trainer spawns workers -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Global state shared by the Gradio callbacks below.

# model_id -> {'model', 'tokenizer', 'results', 'baseline', 'config', 'timestamp'};
# filled in by train_bert_model and read by predict / compare / refresh_model_list.
trained_models = {}
# Incremented once per successful training run by train_bert_model.
model_counter = 0
# "<base_model>_baseline" -> metrics dict of the untuned model on the test split.
baseline_results = {}
# Hugging Face model name -> untuned classification model, reused across predict() calls.
baseline_model_cache = {}
def calculate_improvement(baseline_val, finetuned_val):
    """Return the relative change from baseline to fine-tuned, in percent.

    A zero baseline cannot be used as a divisor: in that case the result is
    ``float('inf')`` when the fine-tuned value is positive and ``0.0``
    otherwise.
    """
    if baseline_val != 0:
        return (finetuned_val - baseline_val) / baseline_val * 100
    return float('inf') if finetuned_val > 0 else 0.0
def format_improve(val):
    """Render an improvement percentage as a signed, one-decimal string.

    The ``float('inf')`` sentinel produced for a zero baseline is shown as
    a fixed "N/A" marker instead of a number.
    """
    is_undefined = val == float('inf')
    return "N/A (baseline=0)" if is_undefined else f"{val:+.1f}%"
def compute_metrics(pred):
    """Compute binary-classification metrics for the Trainer's evaluation loop.

    Args:
        pred: Object exposing ``label_ids`` (true labels) and ``predictions``
            (per-class scores; the arg-max over the last axis is taken as the
            predicted class).

    Returns:
        dict with ``accuracy``, ``f1``, ``precision``, ``recall``,
        ``sensitivity``, ``specificity`` and the confusion-matrix counts
        ``tp``, ``tn``, ``fp``, ``fn``. Class 1 is the positive class.
        If anything goes wrong the error is printed and every value is 0, so
        a metric failure does not abort training.
    """
    try:
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average='binary', pos_label=1, zero_division=0
        )
        acc = accuracy_score(labels, preds)
        # Pin the label order so the matrix is always 2x2. Without it, a split
        # where only one class appears yields a 1x1 matrix and the counts
        # would have been reported as all zeros.
        tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        return {
            'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
            'sensitivity': sensitivity, 'specificity': specificity,
            'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
        }
    except Exception as e:
        # Deliberate best-effort: report and fall back to zeros.
        print(f"Error in compute_metrics: {e}")
        return {
            'accuracy': 0, 'f1': 0, 'precision': 0, 'recall': 0,
            'sensitivity': 0, 'specificity': 0, 'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0
        }
class WeightedTrainer(Trainer):
    """Trainer whose loss is cross-entropy with optional per-class weights."""

    def __init__(self, *args, class_weights=None, **kwargs):
        """Forward everything to ``Trainer`` and remember the class weights.

        Args:
            class_weights: Optional 1-D tensor with one weight per class, or
                ``None`` for unweighted cross-entropy.
        """
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the class-weighted cross-entropy loss for one batch.

        Args:
            model: Model to run; its output must expose ``logits``.
            inputs: Batch mapping that includes a ``"labels"`` entry.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for signature compatibility; unused.
        """
        labels = inputs["labels"]
        # Build a label-free copy instead of popping, so the caller's batch
        # dict is left intact.
        model_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.logits
        weight = self.class_weights
        if weight is not None:
            # The weight tensor must live on the same device as the logits.
            weight = weight.to(logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        # Take the class count from the logits rather than hard-coding 2.
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
def evaluate_baseline(model, tokenizer, test_dataset, device):
    """Evaluate a model that has not been fine-tuned on the test split.

    Args:
        model: Sequence-classification model; it is switched to eval mode.
        tokenizer: Unused here; kept so existing callers keep working.
        test_dataset: Iterable of items with ``input_ids``,
            ``attention_mask`` and ``label`` entries.
        device: Torch device the input tensors are moved to.

    Returns:
        dict with ``accuracy``, ``f1``, ``precision``, ``recall``,
        ``sensitivity``, ``specificity`` and the confusion-matrix counts
        ``tp``, ``tn``, ``fp``, ``fn``. Class 1 is the positive class.
    """
    from torch.utils.data import DataLoader

    model.eval()
    all_preds = []
    all_labels = []

    def collate_fn(batch):
        return {
            'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
            'attention_mask': torch.stack([torch.tensor(item['attention_mask']) for item in batch]),
            'labels': torch.tensor([item['label'] for item in batch])
        }

    dataloader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)
    with torch.no_grad():
        for batch in dataloader:
            # Labels stay on the CPU; only the model inputs are moved.
            labels = batch.pop('labels')
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.tolist())

    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='binary', pos_label=1, zero_division=0
    )
    acc = accuracy_score(all_labels, all_preds)
    # Pin the label order so the matrix is always 2x2. An untuned model often
    # predicts a single class; if the split also held one class the matrix
    # would be 1x1 and the counts would have been reported as all zeros.
    tn, fp, fn, tp = confusion_matrix(all_labels, all_preds, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    return {
        'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
        'sensitivity': sensitivity, 'specificity': specificity,
        'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
    }
# (display label, metrics-dict key) pairs, in the order they are reported.
_REPORT_METRICS = (
    ("F1", "f1"),
    ("Accuracy", "accuracy"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("Sensitivity", "sensitivity"),
    ("Specificity", "specificity"),
)
# Confusion-matrix cells, in the order they are reported in the comparison.
_CONFUSION_CELLS = ("tp", "tn", "fp", "fn")


def _format_metric_report(header, metrics, key_prefix=""):
    """Render one model's metrics and confusion matrix as display text."""
    def value(name):
        return metrics[f"{key_prefix}{name}"]

    metric_lines = [f"{label}: {value(key):.4f}" for label, key in _REPORT_METRICS]
    return (
        f"{header}📈 表現\n"
        + "\n".join(metric_lines)
        + "\n\n混淆矩陣\n"
        + f"TP: {value('tp')} | TN: {value('tn')}\n"
        + f"FP: {value('fp')} | FN: {value('fn')}"
    )


def _format_comparison_report(baseline, finetuned):
    """Render the baseline-vs-fine-tuned comparison as display text."""
    metric_lines = []
    for label, key in _REPORT_METRICS:
        before, after = baseline[key], finetuned[f"eval_{key}"]
        change = format_improve(calculate_improvement(before, after))
        metric_lines.append(f"{label}: {before:.4f} → {after:.4f} ({change})")
    count_lines = []
    for key in _CONFUSION_CELLS:
        before, after = baseline[key], finetuned[f"eval_{key}"]
        count_lines.append(f"{key.upper()}: {before} → {after} ({after - before:+d})")
    return (
        "📊 純 BERT vs 微調 BERT 比較\n\n"
        "指標改善:\n"
        + "\n".join(metric_lines)
        + "\n\n混淆矩陣變化:\n"
        + "\n".join(count_lines)
    )


def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learning_rate,
                     weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
                     weight_mult, best_metric):
    """Fine-tune BERT on an uploaded CSV and compare it with the untuned model.

    Args:
        csv_file: Uploaded file (an object with a ``name`` path attribute) or
            a plain path string; ``None`` when nothing was uploaded.
        base_model: UI name of the base model; unknown names fall back to
            ``bert-base-uncased``.
        method: ``"lora"`` or ``"adalora"``; anything else trains all weights.
        num_epochs, batch_size, learning_rate, weight_decay: Training
            hyper-parameters passed to ``TrainingArguments``.
        dropout: Hidden and attention dropout probability for the model.
        lora_r, lora_alpha, lora_dropout: Adapter hyper-parameters.
        weight_mult: Multiplier on the class-1 loss weight, which is
            ``(#class0 / #class1) * weight_mult``.
        best_metric: Metric name used to pick the best checkpoint.

    Returns:
        A 4-tuple of display strings: run log, baseline report, fine-tuned
        report and comparison report. On failure the first element holds the
        error message (with traceback) and the rest are empty.

    Side effects:
        Registers the trained model in ``trained_models``, stores baseline
        metrics in ``baseline_results`` and increments ``model_counter``.
    """
    global trained_models, model_counter, baseline_results
    model_mapping = {
        "BERT-base": "bert-base-uncased",
    }
    model_name = model_mapping.get(base_model, "bert-base-uncased")
    try:
        if csv_file is None:
            return "❌ 請上傳 CSV", "", "", ""
        # Accept either an uploaded-file object or a bare path string.
        csv_path = getattr(csv_file, "name", csv_file)
        df = pd.read_csv(csv_path)
        if 'Text' not in df.columns or 'label' not in df.columns:
            return "❌ 需要 Text 和 label 欄位", "", "", ""
        # Drop incomplete rows BEFORE casting: astype(str) would turn a missing
        # text into the literal "nan", and astype(int) raises on a missing label.
        usable = df.dropna(subset=['Text', 'label'])
        df_clean = pd.DataFrame({
            'text': usable['Text'].astype(str),
            'label': usable['label'].astype(int),
        }).reset_index(drop=True)
        # The model has exactly two output classes.
        if not df_clean['label'].isin([0, 1]).all():
            return "❌ label 只能是 0 或 1", "", "", ""
        n0 = int((df_clean['label'] == 0).sum())
        n1 = int((df_clean['label'] == 1).sum())
        if n1 == 0:
            return "❌ 無死亡樣本", "", "", ""
        if n0 == 0:
            # ratio would be 0, giving class 1 a zero weight and a 0/0 loss.
            return "❌ 無存活樣本", "", "", ""
        ratio = n0 / n1
        w0, w1 = 1.0, ratio * weight_mult
        info = f"📊 資料: {len(df_clean)} 筆\n存活: {n0} | 死亡: {n1}\n比例: {ratio:.2f}:1\n權重: {w0:.2f} / {w1:.2f}\n模型: {base_model}\n方法: {method.upper()}"

        tokenizer = BertTokenizer.from_pretrained(model_name)
        dataset = Dataset.from_pandas(df_clean[['text', 'label']])

        def preprocess(ex):
            return tokenizer(ex['text'], truncation=True, padding='max_length', max_length=128)

        tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
        split = tokenized.train_test_split(test_size=0.2, seed=42)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        info += f"\n裝置: {'GPU ✅' if torch.cuda.is_available() else 'CPU ⚠️'}"

        # Evaluate the baseline (not fine-tuned) model.
        # NOTE(review): the classification head of a freshly loaded checkpoint
        # is presumably randomly initialised, so these numbers may vary from
        # run to run -- confirm whether a fixed seed is wanted.
        info += "\n\n🔍 評估基準模型(未微調)..."
        baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
        baseline_model = baseline_model.to(device)
        baseline_perf = evaluate_baseline(baseline_model, tokenizer, split['test'], device)
        baseline_key = f"{base_model}_baseline"
        baseline_results[baseline_key] = baseline_perf
        info += f"\n基準 F1: {baseline_perf['f1']:.4f}"
        info += f"\n基準 Accuracy: {baseline_perf['accuracy']:.4f}"
        # Release the baseline model to free memory.
        del baseline_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Start fine-tuning.
        info += f"\n\n🔧 套用 {method.upper()} 微調..."
        model = BertForSequenceClassification.from_pretrained(
            model_name, num_labels=2,
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout
        )
        peft_applied = False
        if method == "lora":
            config = LoraConfig(
                task_type=TaskType.SEQ_CLS,
                r=int(lora_r),
                lora_alpha=int(lora_alpha),
                lora_dropout=lora_dropout,
                target_modules=["query", "value"],
                bias="none"
            )
            model = get_peft_model(model, config)
            peft_applied = True
            info += f"\n✅ LoRA 已套用(r={int(lora_r)}, alpha={int(lora_alpha)})"
        elif method == "adalora":
            # NOTE(review): init_r/tinit/tfinal are fixed here regardless of
            # lora_r and the dataset size; newer PEFT releases may also require
            # total_step -- confirm against the installed PEFT version.
            config = AdaLoraConfig(
                task_type=TaskType.SEQ_CLS,
                r=int(lora_r),
                lora_alpha=int(lora_alpha),
                lora_dropout=lora_dropout,
                target_modules=["query", "value"],
                init_r=12, tinit=200, tfinal=1000, deltaT=10
            )
            model = get_peft_model(model, config)
            peft_applied = True
            info += f"\n✅ AdaLoRA 已套用(r={int(lora_r)}, alpha={int(lora_alpha)})"
        if not peft_applied:
            info += f"\n⚠️ 警告:{method} 方法未被識別,使用 Full Fine-tuning"
        model = model.to(device)

        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        info += f"\n\n💾 參數量\n總參數: {total:,}\n可訓練: {trainable:,}\n比例: {trainable/total*100:.2f}%"

        weights = torch.tensor([w0, w1], dtype=torch.float).to(device)

        # The per-epoch evaluation keyword was renamed from
        # `evaluation_strategy` to `eval_strategy`; use whichever name the
        # installed transformers version accepts.
        import inspect
        accepted = inspect.signature(TrainingArguments.__init__).parameters
        strategy_kwarg = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"
        args = TrainingArguments(
            output_dir='./results',
            num_train_epochs=int(num_epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size)*2,
            learning_rate=float(learning_rate),
            weight_decay=float(weight_decay),
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model=best_metric,
            report_to="none",
            logging_steps=50,
            save_total_limit=2,
            **{strategy_kwarg: "epoch"}
        )
        trainer = WeightedTrainer(
            model=model,
            args=args,
            train_dataset=split['train'],
            eval_dataset=split['test'],
            compute_metrics=compute_metrics,
            class_weights=weights
        )
        info += "\n\n⏳ 開始訓練..."
        trainer.train()
        results = trainer.evaluate()

        # Build a timestamped model ID and register the run.
        model_counter += 1
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_id = f"{base_model}_{method}_{timestamp}"
        trained_models[model_id] = {
            'model': model,
            'tokenizer': tokenizer,
            'results': results,
            'baseline': baseline_perf,
            'config': {
                'type': base_model,
                'model_name': model_name,
                'method': method,
                'metric': best_metric
            },
            'timestamp': timestamp
        }

        baseline_output = _format_metric_report("🔬 純 BERT(未微調)\n\n", baseline_perf)
        finetuned_output = _format_metric_report(
            f"✅ 微調 BERT\n模型: {model_id}\n\n", results, key_prefix="eval_"
        )
        comparison_output = _format_comparison_report(baseline_perf, results)
        info += "\n\n✅ 訓練完成!"
        return info, baseline_output, finetuned_output, comparison_output
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing
        # the Gradio callback.
        import traceback
        error_msg = f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"
        return error_msg, "", "", ""
def predict(model_id, text):
    """Classify ``text`` with a trained model and with the untuned baseline.

    Args:
        model_id: Key into ``trained_models``.
        text: Input text to classify.

    Returns:
        A display string comparing both predictions (class 0 is shown as
        "存活", class 1 as "死亡"), or an error message when the model id is
        missing/unknown, the text is empty, or inference fails.
    """
    global baseline_model_cache
    if not model_id or model_id not in trained_models:
        return "❌ 請選擇模型"
    if not text:
        return "❌ 請輸入文字"
    try:
        info = trained_models[model_id]
        config = info['config']
        tuned_model = info['model']
        tokenizer = info['tokenizer']
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
        batch = {name: tensor.to(device) for name, tensor in encoded.items()}

        def classify(net):
            # One forward pass -> (class probabilities, winning class, label).
            with torch.no_grad():
                logits = net(**batch).logits
                probs = torch.nn.functional.softmax(logits, dim=-1)
                winner = torch.argmax(probs, dim=-1).item()
            return probs, winner, ("存活" if winner == 0 else "死亡")

        # Prediction: fine-tuned model.
        tuned_model.eval()
        probs_finetuned, pred_finetuned, result_finetuned = classify(tuned_model)

        # Prediction: baseline model, loaded once per checkpoint name.
        cache_key = config['model_name']
        if cache_key in baseline_model_cache:
            baseline_model = baseline_model_cache[cache_key]
        else:
            baseline_model = BertForSequenceClassification.from_pretrained(config['model_name'], num_labels=2)
            baseline_model = baseline_model.to(device)
            baseline_model.eval()
            baseline_model_cache[cache_key] = baseline_model
        probs_baseline, pred_baseline, result_baseline = classify(baseline_model)

        # Do the two models agree?
        same_answer = pred_finetuned == pred_baseline
        agreement = "✅ 一致" if same_answer else "⚠️ 不一致"
        output = f"""🔮 預測結果比較
📝 輸入文字: {text[:100]}{'...' if len(text) > 100 else ''}
{'='*50}
🧬 微調模型 ({model_id})
預測: {result_finetuned}
信心: {probs_finetuned[0][pred_finetuned].item():.2%}
機率分布:
• 存活: {probs_finetuned[0][0].item():.2%}
• 死亡: {probs_finetuned[0][1].item():.2%}
{'='*50}
🔬 基準模型(未微調 {config['type']})
預測: {result_baseline}
信心: {probs_baseline[0][pred_baseline].item():.2%}
機率分布:
• 存活: {probs_baseline[0][0].item():.2%}
• 死亡: {probs_baseline[0][1].item():.2%}
{'='*50}
📊 結論
兩模型預測: {agreement}
"""
        if not same_answer:
            output += f"\n💡 分析: 微調模型預測為【{result_finetuned}】,而基準模型預測為【{result_baseline}】"
            output += f"\n 這顯示了 fine-tuning 對此案例的影響!"
        f1_improve = calculate_improvement(info['baseline']['f1'], info['results']['eval_f1'])
        output += f"""
📈 模型表現
微調模型 F1: {info['results']['eval_f1']:.4f}
基準模型 F1: {info['baseline']['f1']:.4f}
改善幅度: {format_improve(f1_improve)}
"""
        return output
    except Exception as e:
        import traceback
        return f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"
def compare():
    """Build a Markdown report comparing every trained model.

    The report has three sections: a table of fine-tuned metrics, a table of
    the matching baseline metrics (one row per trained model), and the best
    model for each metric with its improvement over its own baseline.
    Returns an error message when no model has been trained yet.
    """
    if not trained_models:
        return "❌ 尚未訓練模型"

    chunks = [
        "# 📊 模型比較\n\n",
        "## 微調模型表現\n\n",
        "| 模型 | 基礎 | 方法 | F1 | Acc | Prec | Recall | Sens | Spec |\n",
        "|------|------|------|-----|-----|------|--------|------|------|\n",
    ]
    for model_id, entry in trained_models.items():
        res = entry['results']
        cfg = entry['config']
        chunks.append(
            f"| {model_id} | {cfg['type']} | {cfg['method'].upper()} | {res['eval_f1']:.4f} | {res['eval_accuracy']:.4f} | "
            f"{res['eval_precision']:.4f} | {res['eval_recall']:.4f} | "
            f"{res['eval_sensitivity']:.4f} | {res['eval_specificity']:.4f} |\n"
        )

    chunks.extend([
        "\n## 基準模型表現(未微調)\n\n",
        "| 模型 | F1 | Acc | Prec | Recall | Sens | Spec |\n",
        "|------|-----|-----|------|--------|------|------|\n",
    ])
    for entry in trained_models.values():
        base = entry['baseline']
        cfg = entry['config']
        chunks.append(
            f"| {cfg['type']}-baseline | {base['f1']:.4f} | {base['accuracy']:.4f} | "
            f"{base['precision']:.4f} | {base['recall']:.4f} | "
            f"{base['sensitivity']:.4f} | {base['specificity']:.4f} |\n"
        )

    chunks.append("\n## 🏆 最佳模型\n\n")
    for metric in ('f1', 'accuracy', 'precision', 'recall', 'sensitivity', 'specificity'):
        score_key = f'eval_{metric}'
        # On ties, max() keeps the first model in insertion order.
        best_id, best_entry = max(
            trained_models.items(), key=lambda pair: pair[1]['results'][score_key]
        )
        baseline_val = best_entry['baseline'][metric]
        finetuned_val = best_entry['results'][score_key]
        improvement = calculate_improvement(baseline_val, finetuned_val)
        chunks.append(
            f"**{metric.upper()}**: {best_id} ({finetuned_val:.4f}, 改善 {format_improve(improvement)})\n\n"
        )
    return "".join(chunks)
def refresh_model_list():
    """Return a Dropdown whose choices are the currently trained model ids."""
    model_ids = list(trained_models.keys())
    return gr.Dropdown(choices=model_ids)
# Gradio UI
# Four tabs: training, prediction, comparison and help text. Component
# creation order determines the page layout, so statements stay in order.
# NOTE(review): the container nesting below (which components sit inside each
# gr.Row) and the indentation inside the help Markdown were reconstructed
# from a flattened listing -- confirm against the deployed layout.
with gr.Blocks(title="BERT Fine-tuning 教學平台", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧬 BERT Fine-tuning 教學平台")
    gr.Markdown("### 比較基準模型 vs 微調模型的表現差異")
    # --- Training tab: collects inputs for train_bert_model ---
    with gr.Tab("訓練"):
        gr.Markdown("## 步驟 1: 選擇基礎模型")
        base_model = gr.Dropdown(
            choices=["BERT-base"],
            value="BERT-base",
            label="基礎模型",
            info="更多模型即將推出"
        )
        gr.Markdown("## 步驟 2: 選擇微調方法")
        method = gr.Radio(
            choices=["lora", "adalora"],
            value="lora",
            label="微調方法",
            info="兩種都是參數高效方法,推薦從 LoRA 開始"
        )
        gr.Markdown("## 步驟 3: 上傳資料")
        csv_file = gr.File(label="CSV 檔案 (需包含 Text 和 label 欄位)", file_types=[".csv"])
        gr.Markdown("## 步驟 4: 設定訓練參數")
        gr.Markdown("### 🎯 基本訓練參數")
        with gr.Row():
            num_epochs = gr.Number(value=3, label="訓練輪數 (epochs)", minimum=1, maximum=100, precision=0)
            batch_size = gr.Number(value=8, label="批次大小 (batch_size)", minimum=1, maximum=128, precision=0)
            learning_rate = gr.Number(value=2e-5, label="學習率 (learning_rate)", minimum=0, maximum=1)
        gr.Markdown("### ⚙️ 進階參數")
        with gr.Row():
            weight_decay = gr.Number(value=0.01, label="權重衰減 (weight_decay)", minimum=0, maximum=1)
            dropout = gr.Number(value=0.1, label="Dropout 機率", minimum=0, maximum=1)
        gr.Markdown("### 🔧 LoRA 參數")
        with gr.Row():
            lora_r = gr.Number(value=16, label="LoRA Rank (r)", minimum=1, maximum=256, precision=0,
                               info="推薦 8-16,越大效果越好但越慢")
            lora_alpha = gr.Number(value=32, label="LoRA Alpha", minimum=1, maximum=512, precision=0,
                                   info="通常設為 Rank 的 2 倍")
            lora_dropout = gr.Number(value=0.1, label="LoRA Dropout", minimum=0, maximum=1,
                                     info="防止過擬合")
        gr.Markdown("### ⚖️ 評估設定")
        with gr.Row():
            weight_mult = gr.Number(value=2.0, label="類別權重倍數", minimum=0, maximum=10,
                                    info="推薦 1.5-2.5,過低會忽略少數類")
            best_metric = gr.Dropdown(
                choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
                value="f1",
                label="最佳模型選擇指標",
                info="訓練時用此指標選擇最佳模型"
            )
        train_btn = gr.Button("🚀 開始訓練", variant="primary", size="lg")
        gr.Markdown("## 📊 訓練結果")
        data_info = gr.Textbox(label="📋 資料資訊", lines=10)
        with gr.Row():
            baseline_result = gr.Textbox(label="🔬 純 BERT(未微調)", lines=14)
            finetuned_result = gr.Textbox(label="✅ 微調 BERT", lines=14)
        comparison_result = gr.Textbox(label="📊 純 BERT vs 微調 BERT 比較", lines=14)
        # The input order must match train_bert_model's parameter order; the
        # four outputs receive its 4-tuple of display strings.
        train_btn.click(
            train_bert_model,
            inputs=[csv_file, base_model, method, num_epochs, batch_size, learning_rate,
                    weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
                    weight_mult, best_metric],
            outputs=[data_info, baseline_result, finetuned_result, comparison_result]
        )
    # --- Prediction tab: runs predict() on a selected trained model ---
    with gr.Tab("預測"):
        gr.Markdown("## 使用訓練好的模型預測")
        with gr.Row():
            # Choices are captured when the UI is built; the refresh button
            # re-reads trained_models after a training run.
            model_drop = gr.Dropdown(label="選擇模型", choices=list(trained_models.keys()))
            refresh = gr.Button("🔄 刷新")
        text_input = gr.Textbox(label="輸入病例描述", lines=4,
                                placeholder="Patient diagnosed with...")
        predict_btn = gr.Button("預測", variant="primary", size="lg")
        pred_output = gr.Textbox(label="預測結果(含基準模型對比)", lines=20)
        refresh.click(refresh_model_list, outputs=[model_drop])
        predict_btn.click(predict, inputs=[model_drop, text_input], outputs=[pred_output])
        gr.Examples(
            examples=[
                ["Patient with stage II breast cancer, good response to treatment."],
                ["Advanced metastatic cancer, multiple organ involvement."]
            ],
            inputs=text_input
        )
    # --- Comparison tab: renders compare()'s Markdown report ---
    with gr.Tab("比較"):
        gr.Markdown("## 比較所有模型(含基準模型)")
        compare_btn = gr.Button("比較", variant="primary", size="lg")
        compare_output = gr.Markdown()
        compare_btn.click(compare, outputs=[compare_output])
    # --- Help tab: static usage notes ---
    with gr.Tab("說明"):
        gr.Markdown("""
## 📖 使用說明
### 🎯 平台特色
本平台會自動比較:
- **基準模型**:未經微調的原始 BERT
- **微調模型**:使用你的資料訓練後的 BERT
這樣可以清楚看到 fine-tuning 帶來的改善!
### 基礎模型
- **BERT-base**: 標準 BERT,110M 參數 ⭐目前支援
### 微調方法
- **LoRA**: 低秩適應,參數高效的微調方法 ⭐強烈推薦
- 只訓練少量參數(通常 <1%)
- 訓練速度快,效果好
- 適合大多數情況
- **AdaLoRA**: 自適應 LoRA,動態調整秩
- 自動找出最重要的參數
- 可能比 LoRA 效果稍好
- 訓練時間稍長
### 評估指標
- **F1**: 平衡指標,推薦用於不平衡資料 ⭐
- **Accuracy**: 整體準確率
- **Precision**: 減少假陽性
- **Recall/Sensitivity**: 減少假陰性
- **Specificity**: 真陰性率
### 參數建議
針對不平衡資料(如醫療資料):
- **微調方法**: LoRA(快速有效)或 AdaLoRA(追求極致)
- **LoRA Rank**: 8-16(平衡效果與速度)
- **類別權重倍數**: 1.5-2.5(資料不平衡時)
- **Learning rate**: 2e-5 到 5e-5
- **Epochs**: 3-8(避免過擬合)
- **Batch size**: 8-16(依 GPU 記憶體調整)
### 資料格式
CSV 必須包含:
- `Text`: 病例描述
- `label`: 0=存活, 1=死亡
### 🚀 快速開始
1. 上傳包含 `Text` 和 `label` 欄位的 CSV
2. 使用預設參數(適合大多數情況)
3. 點擊「開始訓練」
4. 在「預測」分頁測試模型
5. 在「比較」分頁查看所有模型表現
""")

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()