import gradio as gr
import pandas as pd
import numpy as np
import torch
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    TrainingArguments,
    Trainer,
    TrainerCallback,
)
from peft import LoraConfig, AdaLoraConfig, get_peft_model, TaskType
from datasets import Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from torch import nn
from torch.utils.data import DataLoader, WeightedRandomSampler
import os
from datetime import datetime
import gc
import json
from functools import lru_cache
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# Environment setup
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

# CUDA settings
torch.backends.cudnn.benchmark = False
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# ==================== Globals ====================
trained_models = {}
model_counter = 0
training_histories = {}  # stores each model's training history

# ==================== Training monitor ====================
class TrainingMonitor:
    """Tracks metrics over the course of a training run."""

    def __init__(self):
        self.history = {
            'epoch': [],
            'train_loss': [],
            'eval_loss': [],
            'eval_accuracy': [],
            'eval_f1': [],
            'eval_precision': [],
            'eval_recall': [],
            'learning_rate': [],
            'best_epoch': None,
            'best_metric_value': None
        }

    def log_epoch(self, epoch: int, train_loss: float, eval_metrics: Dict, lr: float):
        """Record the results of one epoch."""
        self.history['epoch'].append(epoch)
        self.history['train_loss'].append(train_loss)
        self.history['eval_loss'].append(eval_metrics.get('eval_loss', 0))
        self.history['eval_accuracy'].append(eval_metrics.get('eval_accuracy', 0))
        self.history['eval_f1'].append(eval_metrics.get('eval_f1', 0))
        self.history['eval_precision'].append(eval_metrics.get('eval_precision', 0))
        self.history['eval_recall'].append(eval_metrics.get('eval_recall', 0))
        self.history['learning_rate'].append(lr)

    def update_best(self, epoch: int, metric_value: float):
        """Update the best epoch/metric seen so far."""
        self.history['best_epoch'] = epoch
        self.history['best_metric_value'] = metric_value

    def get_summary(self) -> str:
        """Return a text summary of the training run."""
        if not self.history['epoch']:
            return "No training history yet"
        summary = "📈 Training History Summary\n"
        summary += f"Total epochs trained: {len(self.history['epoch'])}\n"
        if self.history['best_epoch'] is not None:
            summary += f"Best epoch: {self.history['best_epoch']}\n"
            summary += f"Best metric value: {self.history['best_metric_value']:.4f}\n"
        summary += "\nPer-epoch performance:\n"
        for i, epoch in enumerate(self.history['epoch']):
            summary += f"Epoch {epoch}: Loss={self.history['train_loss'][i]:.4f}, "
            summary += f"F1={self.history['eval_f1'][i]:.4f}, "
            summary += f"Acc={self.history['eval_accuracy'][i]:.4f}\n"
        return summary
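
# Illustrative lifecycle of the monitor (hedged sketch, deliberately left as
# comments so nothing runs at import time):
#   monitor = TrainingMonitor()
#   monitor.log_epoch(1, train_loss=0.52, eval_metrics={'eval_f1': 0.41}, lr=3e-5)
#   monitor.update_best(1, 0.41)
#   print(monitor.get_summary())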
# ==================== Class-weight calculation ====================
def calculate_class_weights(n0: int, n1: int, weight_mult: float = 1.0,
                            method: str = 'sqrt') -> Tuple[float, float]:
    """
    Improved class-weight calculation.

    Args:
        n0: number of negative samples (survived)
        n1: number of positive samples (died)
        weight_mult: multiplier applied to the positive-class weight
        method: weighting scheme ('balanced', 'sqrt', 'log', 'custom')

    Returns:
        (w0, w1): class weights
    """
    if n1 == 0:
        return 1.0, 1.0
    ratio = n0 / n1
    total = n0 + n1
    if method == 'balanced':
        # sklearn-style balanced weights
        w0 = total / (2 * n0) if n0 > 0 else 1.0
        w1 = total / (2 * n1) if n1 > 0 else 1.0
        w1 *= weight_mult
    elif method == 'sqrt':
        # Square root softens extreme weights (recommended for severe imbalance)
        w0 = 1.0
        w1 = min(np.sqrt(ratio) * weight_mult, 10.0)  # capped at 10
    elif method == 'log':
        # Logarithm softens the weights even further
        w0 = 1.0
        w1 = min(np.log1p(ratio) * weight_mult, 8.0)  # capped at 8
    elif method == 'custom':
        # Custom logic keyed to the degree of imbalance
        if ratio > 20:  # severe imbalance
            w0 = 1.0
            w1 = min(5.0 * weight_mult, 10.0)
        elif ratio > 10:  # high imbalance
            w0 = 1.0
            w1 = min(ratio * 0.3 * weight_mult, 8.0)
        elif ratio > 5:  # moderate imbalance
            w0 = 1.0
            w1 = min(ratio * 0.5 * weight_mult, 6.0)
        else:  # mild imbalance
            w0 = 1.0
            w1 = ratio * weight_mult
    else:
        # Fall back to the sqrt method
        w0 = 1.0
        w1 = min(np.sqrt(ratio) * weight_mult, 10.0)
    return w0, w1
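
# Worked example (the values follow directly from the formulas above): with
# 900 negatives and 100 positives the raw ratio is 9:1, so
#   calculate_class_weights(900, 100, method='sqrt')     -> (1.0, 3.0)    sqrt(9)
#   calculate_class_weights(900, 100, method='log')      -> (1.0, ~2.30)  log1p(9)
#   calculate_class_weights(900, 100, method='balanced') -> (~0.56, 5.0)  total/(2*n)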
# ==================== Evaluation metrics ====================
def compute_metrics(pred):
    """Compute the full set of evaluation metrics."""
    try:
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        # Core metrics
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average='binary', pos_label=1, zero_division=0
        )
        acc = accuracy_score(labels, preds)
        # Confusion matrix
        cm = confusion_matrix(labels, preds)
        tn = fp = fn = tp = 0
        if cm.shape == (2, 2):
            tn, fp, fn, tp = cm.ravel()
        # Sensitivity and specificity
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        # Additional metrics
        ppv = tp / (tp + fp) if (tp + fp) > 0 else 0  # positive predictive value
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0  # negative predictive value
        return {
            'accuracy': acc,
            'f1': f1,
            'precision': precision,
            'recall': recall,
            'sensitivity': sensitivity,
            'specificity': specificity,
            'ppv': ppv,
            'npv': npv,
            'tp': int(tp),
            'tn': int(tn),
            'fp': int(fp),
            'fn': int(fn)
        }
    except Exception as e:
        print(f"Error in compute_metrics: {e}")
        return {k: 0 for k in ['accuracy', 'f1', 'precision', 'recall',
                               'sensitivity', 'specificity', 'ppv', 'npv',
                               'tp', 'tn', 'fp', 'fn']}
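
# Quick sanity check (hedged sketch; EvalPrediction is the container the HF
# Trainer passes to compute_metrics):
#   from transformers import EvalPrediction
#   p = EvalPrediction(predictions=np.array([[0.9, 0.1], [0.2, 0.8]]),
#                      label_ids=np.array([0, 1]))
#   compute_metrics(p)  # -> accuracy 1.0, f1 1.0, tp=1, tn=1, fp=0, fn=0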
# ==================== Baseline model evaluation ====================
def evaluate_baseline(model, tokenizer, test_dataset, device, batch_size=16):
    """Evaluate the un-finetuned baseline model."""
    model.eval()
    all_preds = []
    all_labels = []

    def collate_fn(batch):
        return {
            'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
            'attention_mask': torch.stack([torch.tensor(item['attention_mask']) for item in batch]),
            'labels': torch.tensor([item['label'] for item in batch])
        }

    dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        pin_memory=torch.cuda.is_available(),
        num_workers=0  # avoid multiprocessing issues
    )
    with torch.no_grad():
        for batch in dataloader:
            labels = batch.pop('labels')
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
    # Compute all metrics
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='binary', pos_label=1, zero_division=0
    )
    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    tn = fp = fn = tp = 0
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'ppv': ppv,
        'npv': npv,
        'tp': int(tp),
        'tn': int(tn),
        'fp': int(fp),
        'fn': int(fn)
    }
# ==================== Custom Trainer with early stopping ====================
class CustomTrainer(Trainer):
    """Trainer supporting class weights, focal loss, and early stopping."""

    def __init__(self, *args, class_weights=None, use_focal_loss=False,
                 focal_gamma=2.0, monitor=None, early_stopping_patience=3,
                 early_stopping_metric='eval_f1', **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        self.use_focal_loss = use_focal_loss
        self.focal_gamma = focal_gamma
        self.monitor = monitor
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_metric = early_stopping_metric
        self.best_metric = -float('inf')
        self.best_model_state = None
        self.patience_counter = 0
        self.current_epoch = 0
        # HF Trainer fires epoch-end hooks only on registered TrainerCallback
        # instances, never on methods defined on the Trainer subclass itself,
        # so bridge our on_epoch_end through a minimal callback.
        trainer_ref = self

        class _EpochEndBridge(TrainerCallback):
            def on_epoch_end(self, args, state, control, **kwargs):
                return trainer_ref.on_epoch_end(args, state, control, **kwargs)

        self.add_callback(_EpochEndBridge())

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the training loss."""
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        if self.use_focal_loss and self.class_weights is not None:
            # Focal loss: down-weight easy examples by (1 - pt)^gamma
            ce_loss = nn.CrossEntropyLoss(weight=self.class_weights, reduction='none')(
                logits.view(-1, 2), labels.view(-1)
            )
            pt = torch.exp(-ce_loss)
            focal_loss = ((1 - pt) ** self.focal_gamma * ce_loss).mean()
            loss = focal_loss
        elif self.class_weights is not None:
            # Weighted cross-entropy
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits.view(-1, 2), labels.view(-1))
        else:
            # Plain cross-entropy
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, 2), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

    def on_epoch_end(self, args, state, control, **kwargs):
        """Epoch-end hook, invoked via the bridge callback above."""
        self.current_epoch += 1
        # Evaluate the model
        metrics = self.evaluate()
        # Log to the monitor
        if self.monitor:
            self.monitor.log_epoch(
                epoch=self.current_epoch,
                train_loss=state.log_history[-1].get('loss', 0) if state.log_history else 0,
                eval_metrics=metrics,
                lr=self.get_learning_rate()
            )
        # Early-stopping check
        current_metric = metrics.get(self.early_stopping_metric, 0)
        if current_metric > self.best_metric:
            self.best_metric = current_metric
            self.best_model_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
            self.patience_counter = 0
            if self.monitor:
                self.monitor.update_best(self.current_epoch, current_metric)
            print(f"✅ Epoch {self.current_epoch}: new best {self.early_stopping_metric} = {current_metric:.4f}")
        else:
            self.patience_counter += 1
            print(f"⏳ Epoch {self.current_epoch}: no improvement (patience: {self.patience_counter}/{self.early_stopping_patience})")
            if self.patience_counter >= self.early_stopping_patience:
                print(f"🛑 Early stopping at epoch {self.current_epoch}")
                control.should_training_stop = True
        return control

    def get_learning_rate(self):
        """Return the current learning rate."""
        if self.optimizer is None:
            return 0
        return self.optimizer.param_groups[0]['lr']

    def load_best_model(self):
        """Restore the best model weights seen during training."""
        if self.best_model_state:
            self.model.load_state_dict(self.best_model_state)
            print(f"✅ Restored best model (best {self.early_stopping_metric} = {self.best_metric:.4f})")
# ==================== Baseline model cache ====================
@lru_cache(maxsize=2)
def get_cached_baseline_model(model_name: str, num_labels: int = 2):
    """Cache baseline models via functools.lru_cache so repeated predictions
    do not reload the checkpoint every time."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
    return model.to(device)
# ==================== Improvement-rate helpers ====================
def calculate_improvement(baseline_val: float, finetuned_val: float) -> float:
    """Safely compute the relative improvement in percent."""
    if baseline_val == 0:
        return float('inf') if finetuned_val > 0 else 0.0
    return (finetuned_val - baseline_val) / baseline_val * 100

def format_improvement(val: float) -> str:
    """Format an improvement rate for display."""
    if val == float('inf'):
        return "N/A (baseline=0)"
    elif val > 0:
        return f"↑ {val:.1f}%"
    elif val < 0:
        return f"↓ {abs(val):.1f}%"
    else:
        return "→ 0.0%"
# ==================== Main training function ====================
def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learning_rate,
                     weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
                     weight_mult, weight_method, best_metric, use_early_stopping, patience):
    """Train a BERT model and compare it against the un-finetuned baseline."""
    global trained_models, model_counter, training_histories
    model_mapping = {
        "BERT-base": "bert-base-uncased",
        "BERT-base-chinese": "bert-base-chinese",
        "BioBERT": "dmis-lab/biobert-base-cased-v1.2",
        "SciBERT": "allenai/scibert_scivocab_uncased"
    }
    model_name = model_mapping.get(base_model, "bert-base-uncased")
    try:
        # ========== Data validation and loading ==========
        if csv_file is None:
            return "❌ Please upload a CSV file", "", "", "", ""
        df = pd.read_csv(csv_file.name)
        if 'Text' not in df.columns or 'label' not in df.columns:
            return "❌ The CSV must contain 'Text' and 'label' columns", "", "", "", ""
        # Data cleaning
        df_clean = pd.DataFrame({
            'text': df['Text'].astype(str),
            'label': df['label'].astype(int)
        }).dropna()
        # Dataset statistics
        n0 = int(sum(df_clean['label'] == 0))
        n1 = int(sum(df_clean['label'] == 1))
        if n1 == 0:
            return "❌ The dataset contains no positive samples (deaths)", "", "", "", ""
        ratio = n0 / n1 if n1 > 0 else 0
        # ========== Class weights ==========
        w0, w1 = calculate_class_weights(n0, n1, weight_mult, method=weight_method)
        # ========== Dataset info ==========
        info = f"📊 Dataset Statistics\n"
        info += f"{'='*50}\n"
        info += f"Total samples: {len(df_clean):,}\n"
        info += f"Survived (0): {n0:,} ({n0/len(df_clean)*100:.1f}%)\n"
        info += f"Died (1): {n1:,} ({n1/len(df_clean)*100:.1f}%)\n"
        info += f"Imbalance ratio: {ratio:.2f}:1\n"
        info += f"\n⚖️ Class Weight Settings\n"
        info += f"{'='*50}\n"
        info += f"Weighting method: {weight_method}\n"
        info += f"Survived weight: {w0:.3f}\n"
        info += f"Died weight: {w1:.3f}\n"
        info += f"Weight ratio: 1:{w1/w0:.2f}\n"
        # ========== Model and tokenizer ==========
        info += f"\n🤖 Model Configuration\n"
        info += f"{'='*50}\n"
        info += f"Base model: {base_model}\n"
        info += f"Model path: {model_name}\n"
        info += f"Fine-tuning method: {method.upper()}\n"
        tokenizer = BertTokenizer.from_pretrained(model_name)
        # ========== Dataset preparation ==========
        dataset = Dataset.from_pandas(df_clean[['text', 'label']])
        # Stratified splitting in `datasets` requires a ClassLabel column
        dataset = dataset.class_encode_column('label')

        def preprocess(examples):
            return tokenizer(
                examples['text'],
                truncation=True,
                padding='max_length',
                max_length=128
            )

        tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
        split = tokenized.train_test_split(test_size=0.2, seed=42, stratify_by_column='label')
        # ========== Device ==========
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        info += f"Device: {'GPU ✅ (' + torch.cuda.get_device_name(0) + ')' if torch.cuda.is_available() else 'CPU ⚠️'}\n"
        # ========== Baseline evaluation ==========
        info += f"\n📏 Baseline Model Evaluation\n"
        info += f"{'='*50}\n"
        info += f"Evaluating the un-finetuned {base_model}...\n"
        baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
        baseline_model = baseline_model.to(device)
        baseline_perf = evaluate_baseline(
            baseline_model, tokenizer, split['test'], device, batch_size=int(batch_size) * 2
        )
        info += f"Baseline F1: {baseline_perf['f1']:.4f}\n"
        info += f"Baseline accuracy: {baseline_perf['accuracy']:.4f}\n"
        # Free baseline model memory
        del baseline_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        # ========== Fine-tuning configuration ==========
        info += f"\n🔧 Fine-tuning Configuration\n"
        info += f"{'='*50}\n"
        model = BertForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2,
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout
        )
        # Apply the selected PEFT method (configs are imported at the top)
        peft_applied = False
        if method == "lora":
            config = LoraConfig(
                task_type=TaskType.SEQ_CLS,
                r=int(lora_r),
                lora_alpha=int(lora_alpha),
                lora_dropout=lora_dropout,
                target_modules=["query", "value"],
                bias="none"
            )
            model = get_peft_model(model, config)
            peft_applied = True
            info += f"✅ LoRA applied\n"
            info += f"  - Rank (r): {int(lora_r)}\n"
            info += f"  - Alpha: {int(lora_alpha)}\n"
            info += f"  - Dropout: {lora_dropout}\n"
        elif method == "adalora":
            config = AdaLoraConfig(
                task_type=TaskType.SEQ_CLS,
                r=int(lora_r),
                lora_alpha=int(lora_alpha),
                lora_dropout=lora_dropout,
                target_modules=["query", "value"],
                init_r=12,
                target_r=int(lora_r),
                tinit=200,
                tfinal=1000,
                deltaT=10
            )
            model = get_peft_model(model, config)
            peft_applied = True
            info += f"✅ AdaLoRA applied\n"
            info += f"  - Initial rank: 12\n"
            info += f"  - Target rank: {int(lora_r)}\n"
            info += f"  - Alpha: {int(lora_alpha)}\n"
        elif method == "full":
            info += f"✅ Full fine-tuning mode\n"
            peft_applied = False
        model = model.to(device)
        # Parameter statistics
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        info += f"\n💾 Model Parameters\n"
        info += f"{'='*50}\n"
        info += f"Total parameters: {total_params:,}\n"
        info += f"Trainable parameters: {trainable_params:,}\n"
        info += f"Trainable ratio: {trainable_params/total_params*100:.2f}%\n"
        info += f"Memory saved: {(1 - trainable_params/total_params)*100:.1f}%\n"
        # ========== Training setup ==========
        weights = torch.tensor([w0, w1], dtype=torch.float).to(device)
        use_focal = ratio > 10  # use focal loss under severe imbalance
        if use_focal:
            info += f"\n⚡ Special Settings\n"
            info += f"{'='*50}\n"
            info += f"Using focal loss (γ=2.0) for severe imbalance\n"
        # Training arguments
        training_args = TrainingArguments(
            output_dir='./results',
            num_train_epochs=int(num_epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size) * 2,
            learning_rate=float(learning_rate),
            weight_decay=float(weight_decay),
            eval_strategy="epoch",
            save_strategy="no",  # custom checkpointing handled by CustomTrainer
            load_best_model_at_end=False,
            report_to="none",
            logging_steps=max(1, len(split['train']) // (int(batch_size) * 10)),
            warmup_steps=min(500, len(split['train']) // int(batch_size)),
            logging_first_step=True,
            remove_unused_columns=False,
            # Label smoothing under extreme imbalance; note that the Trainer's
            # built-in label smoother is bypassed by our custom compute_loss,
            # so this only acts on the default loss path.
            label_smoothing_factor=0.1 if ratio > 20 else 0.0,
        )
        # Create the monitor
        monitor = TrainingMonitor()
        # Create the custom trainer
        trainer = CustomTrainer(
            model=model,
            args=training_args,
            train_dataset=split['train'],
            eval_dataset=split['test'],
            compute_metrics=compute_metrics,
            class_weights=weights,
            use_focal_loss=use_focal,
            focal_gamma=2.0,
            monitor=monitor,
            early_stopping_patience=int(patience) if use_early_stopping else 999,
            early_stopping_metric=f'eval_{best_metric}'
        )
| info += f"\n🚀 訓練設定\n" | |
| info += f"{'='*50}\n" | |
| info += f"訓練樣本: {len(split['train']):,}\n" | |
| info += f"測試樣本: {len(split['test']):,}\n" | |
| info += f"批次大小: {int(batch_size)}\n" | |
| info += f"訓練輪數: {int(num_epochs)}\n" | |
| info += f"批次數/輪: {len(split['train']) // int(batch_size)}\n" | |
| info += f"Early Stopping: {'開啟 (patience=' + str(patience) + ')' if use_early_stopping else '關閉'}\n" | |
| info += f"最佳指標: {best_metric}\n" | |
| info += f"\n⏳ 開始訓練...\n" | |
| info += f"{'='*50}\n" | |
| # ========== 執行訓練 ========== | |
| train_result = trainer.train() | |
| # 載入最佳模型 | |
| if use_early_stopping: | |
| trainer.load_best_model() | |
| # 最終評估 | |
| final_results = trainer.evaluate() | |
| # ========== 保存模型與結果 ========== | |
| model_counter += 1 | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| model_id = f"{base_model}_{method}_{model_counter}_{timestamp}" | |
| trained_models[model_id] = { | |
| 'model': model, | |
| 'tokenizer': tokenizer, | |
| 'results': final_results, | |
| 'baseline': baseline_perf, | |
| 'config': { | |
| 'type': base_model, | |
| 'model_name': model_name, | |
| 'method': method, | |
| 'metric': best_metric, | |
| 'epochs': int(num_epochs), | |
| 'batch_size': int(batch_size), | |
| 'learning_rate': float(learning_rate), | |
| 'weight_method': weight_method, | |
| 'weight_mult': weight_mult | |
| }, | |
| 'timestamp': timestamp, | |
| 'monitor': monitor # 保存訓練歷史 | |
| } | |
| training_histories[model_id] = monitor.history | |
| info += f"\n✅ 訓練完成!\n" | |
| info += f"最終 Training Loss: {train_result.training_loss:.4f}\n" | |
| if monitor.history['best_epoch']: | |
| info += f"最佳 Epoch: {monitor.history['best_epoch']}\n" | |
| # ========== 準備輸出結果 ========== | |
| # 基準模型結果 | |
| baseline_output = format_baseline_results(baseline_perf) | |
| # 微調模型結果 | |
| finetuned_output = format_finetuned_results(model_id, final_results) | |
| # 比較結果 | |
| comparison_output = format_comparison_results(baseline_perf, final_results) | |
| # 訓練歷程 | |
| history_output = monitor.get_summary() | |
| return info, baseline_output, finetuned_output, comparison_output, history_output | |
| except Exception as e: | |
| import traceback | |
| error_msg = f"❌ 錯誤發生\n\n錯誤類型: {type(e).__name__}\n錯誤訊息: {str(e)}\n\n" | |
| error_msg += f"詳細追蹤:\n{traceback.format_exc()}" | |
| return error_msg, "", "", "", "" | |
# ==================== Output formatting ====================
def format_baseline_results(baseline_perf: Dict) -> str:
    """Format the baseline model results."""
    output = "🔬 Plain BERT (not fine-tuned)\n\n"
    output += "📊 Model Performance\n"
    output += f"{'='*30}\n"
    output += f"F1 Score: {baseline_perf['f1']:.4f}\n"
    output += f"Accuracy: {baseline_perf['accuracy']:.4f}\n"
    output += f"Precision: {baseline_perf['precision']:.4f}\n"
    output += f"Recall: {baseline_perf['recall']:.4f}\n"
    output += f"Sensitivity: {baseline_perf['sensitivity']:.4f}\n"
    output += f"Specificity: {baseline_perf['specificity']:.4f}\n"
    output += f"PPV: {baseline_perf['ppv']:.4f}\n"
    output += f"NPV: {baseline_perf['npv']:.4f}\n\n"
    output += "📈 Confusion Matrix\n"
    output += f"{'='*30}\n"
    output += f"          Pred 0   Pred 1\n"
    output += f"Actual 0  {baseline_perf['tn']:4d}     {baseline_perf['fp']:4d}\n"
    output += f"Actual 1  {baseline_perf['fn']:4d}     {baseline_perf['tp']:4d}\n"
    return output

def format_finetuned_results(model_id: str, results: Dict) -> str:
    """Format the fine-tuned model results."""
    output = f"✅ Fine-tuned BERT\n"
    output += f"Model ID: {model_id}\n\n"
    output += "📊 Model Performance\n"
    output += f"{'='*30}\n"
    output += f"F1 Score: {results['eval_f1']:.4f}\n"
    output += f"Accuracy: {results['eval_accuracy']:.4f}\n"
    output += f"Precision: {results['eval_precision']:.4f}\n"
    output += f"Recall: {results['eval_recall']:.4f}\n"
    output += f"Sensitivity: {results['eval_sensitivity']:.4f}\n"
    output += f"Specificity: {results['eval_specificity']:.4f}\n"
    output += f"PPV: {results['eval_ppv']:.4f}\n"
    output += f"NPV: {results['eval_npv']:.4f}\n\n"
    output += "📈 Confusion Matrix\n"
    output += f"{'='*30}\n"
    output += f"          Pred 0   Pred 1\n"
    output += f"Actual 0  {results['eval_tn']:4d}     {results['eval_fp']:4d}\n"
    output += f"Actual 1  {results['eval_fn']:4d}     {results['eval_tp']:4d}\n"
    return output
def format_comparison_results(baseline_perf: Dict, finetuned_results: Dict) -> str:
    """Format the baseline-vs-finetuned comparison."""
    output = "📊 Plain BERT vs Fine-tuned BERT\n\n"
    output += "Metric improvements:\n"
    output += f"{'='*50}\n"
    output += f"{'Metric':<12} {'Base':>8} {'Tuned':>8} {'Change':>10} {'Improve':>10}\n"
    output += f"{'-'*50}\n"
    metrics = [
        ('F1', 'f1', 'eval_f1'),
        ('Accuracy', 'accuracy', 'eval_accuracy'),
        ('Precision', 'precision', 'eval_precision'),
        ('Recall', 'recall', 'eval_recall'),
        ('Sensitivity', 'sensitivity', 'eval_sensitivity'),
        ('Specificity', 'specificity', 'eval_specificity'),
        ('PPV', 'ppv', 'eval_ppv'),
        ('NPV', 'npv', 'eval_npv')
    ]
    for name, base_key, fine_key in metrics:
        base_val = baseline_perf[base_key]
        fine_val = finetuned_results[fine_key]
        change = fine_val - base_val
        improve = calculate_improvement(base_val, fine_val)
        output += f"{name:<12} {base_val:>8.4f} {fine_val:>8.4f} "
        output += f"{change:+10.4f} {format_improvement(improve):>10}\n"
    output += f"\nConfusion-matrix changes:\n"
    output += f"{'='*40}\n"
    output += f"{'Item':<10} {'Base':>8} {'Tuned':>8} {'Change':>10}\n"
    output += f"{'-'*40}\n"
    cm_items = [
        ('True Pos', 'tp', 'eval_tp'),
        ('True Neg', 'tn', 'eval_tn'),
        ('False Pos', 'fp', 'eval_fp'),
        ('False Neg', 'fn', 'eval_fn')
    ]
    for name, base_key, fine_key in cm_items:
        base_val = baseline_perf[base_key]
        fine_val = finetuned_results[fine_key]
        change = fine_val - base_val
        output += f"{name:<10} {base_val:>8d} {fine_val:>8d} {change:+10d}\n"
    # Verdict
    output += f"\n📈 Overall Assessment:\n"
    output += f"{'='*40}\n"
    f1_improve = calculate_improvement(baseline_perf['f1'], finetuned_results['eval_f1'])
    if f1_improve > 10:
        output += "✅ Significant improvement: fine-tuning clearly boosts performance!\n"
    elif f1_improve > 0:
        output += "✅ Some improvement: fine-tuning has a positive effect.\n"
    elif f1_improve == 0:
        output += "➖ No change: fine-tuning has no visible effect.\n"
    else:
        output += "⚠️ Performance dropped: the hyperparameters may need tuning.\n"
    return output
# ==================== Prediction ====================
def predict(model_id, text):
    """Predict with a selected fine-tuned model and compare against the baseline."""
    if not model_id or model_id not in trained_models:
        return "❌ Please select a trained model"
    if not text or len(text.strip()) == 0:
        return "❌ Please enter some text to classify"
    try:
        # Look up the model
        info = trained_models[model_id]
        model = info['model']
        tokenizer = info['tokenizer']
        config = info['config']
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Tokenize the input
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128
        )
        inputs_device = {k: v.to(device) for k, v in inputs.items()}
        # ========== Fine-tuned model prediction ==========
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs_device)
            logits = outputs.logits
            probs_finetuned = torch.nn.functional.softmax(logits, dim=-1)
            pred_finetuned = torch.argmax(probs_finetuned, dim=-1).item()
            confidence_finetuned = probs_finetuned[0][pred_finetuned].item()
        # ========== Baseline model prediction ==========
        baseline_model = get_cached_baseline_model(config['model_name'])
        baseline_model.eval()
        with torch.no_grad():
            outputs_baseline = baseline_model(**inputs_device)
            logits_baseline = outputs_baseline.logits
            probs_baseline = torch.nn.functional.softmax(logits_baseline, dim=-1)
            pred_baseline = torch.argmax(probs_baseline, dim=-1).item()
            confidence_baseline = probs_baseline[0][pred_baseline].item()
        # ========== Format the output ==========
        result_finetuned = "🟢 Survived" if pred_finetuned == 0 else "🔴 Died"
        result_baseline = "🟢 Survived" if pred_baseline == 0 else "🔴 Died"
        agreement = "✅ Agree" if pred_finetuned == pred_baseline else "⚠️ Disagree"
        output = f"""🔮 Prediction Comparison

📝 Input Text
{'='*60}
{text[:200]}{'...' if len(text) > 200 else ''}
{'='*60}

🎯 Fine-tuned Model ({model_id})
{'='*60}
Prediction: {result_finetuned}
Confidence: {confidence_finetuned:.1%}
Probabilities:
  • Survived (0): {probs_finetuned[0][0].item():.2%}
  • Died (1): {probs_finetuned[0][1].item():.2%}
Model configuration:
  • Method: {config['method'].upper()}
  • Base model: {config['type']}
  • Epochs: {config['epochs']}
{'='*60}

🔬 Baseline Model (un-finetuned {config['type']})
{'='*60}
Prediction: {result_baseline}
Confidence: {confidence_baseline:.1%}
Probabilities:
  • Survived (0): {probs_baseline[0][0].item():.2%}
  • Died (1): {probs_baseline[0][1].item():.2%}
{'='*60}

📊 Prediction Analysis
{'='*60}
Model agreement: {agreement}
"""
        if pred_finetuned != pred_baseline:
            output += f"""
💡 Disagreement analysis:
The fine-tuned model predicts [{result_finetuned}] (confidence: {confidence_finetuned:.1%})
The baseline model predicts [{result_baseline}] (confidence: {confidence_baseline:.1%})
The disagreement shows the effect of fine-tuning on this particular case;
the fine-tuned model may have learned features better suited to your dataset.
"""
        else:
            output += f"""
✅ Agreement analysis:
Both models predict [{result_finetuned}]
Confidence difference: {abs(confidence_finetuned - confidence_baseline):.1%}
"""
        # Overall performance comparison
        f1_improve = calculate_improvement(
            info['baseline']['f1'],
            info['results']['eval_f1']
        )
        output += f"""
📈 Overall Model Performance
{'='*60}
Fine-tuned F1: {info['results']['eval_f1']:.4f}
Baseline F1: {info['baseline']['f1']:.4f}
Improvement: {format_improvement(f1_improve)}

Fine-tuned accuracy: {info['results']['eval_accuracy']:.4f}
Baseline accuracy: {info['baseline']['accuracy']:.4f}
"""
        return output
    except Exception as e:
        import traceback
        return f"❌ Error during prediction\n\n{str(e)}\n\n{traceback.format_exc()}"
# ==================== Model comparison ====================
def compare_models():
    """Compare all trained models."""
    if not trained_models:
        return "❌ No models trained yet. Please train one on the 'Train' tab first."
    output = "# 📊 Model Comparison Report\n\n"
    output += f"{len(trained_models)} trained model(s)\n\n"
    # Fine-tuned model table
    output += "## 🎯 Fine-tuned Model Performance\n\n"
    output += "| Model ID | Base Model | Method | F1 | Accuracy | Precision | Recall | Sensitivity | Specificity |\n"
    output += "|----------|------------|--------|-----|----------|-----------|--------|-------------|-------------|\n"
    for model_id, info in trained_models.items():
        r = info['results']
        c = info['config']
        # Shorten the displayed model ID
        short_id = f"{c['type']}_{c['method']}_{info['timestamp'][-6:]}"
        output += f"| {short_id} | {c['type']} | {c['method'].upper()} | "
        output += f"{r['eval_f1']:.4f} | {r['eval_accuracy']:.4f} | "
        output += f"{r['eval_precision']:.4f} | {r['eval_recall']:.4f} | "
        output += f"{r['eval_sensitivity']:.4f} | {r['eval_specificity']:.4f} |\n"
    # Baseline table
    output += "\n## 🔬 Baseline Performance (not fine-tuned)\n\n"
    # Collect the unique baselines
    unique_baselines = {}
    for model_id, info in trained_models.items():
        base_type = info['config']['type']
        if base_type not in unique_baselines:
            unique_baselines[base_type] = info['baseline']
    output += "| Base Model | F1 | Accuracy | Precision | Recall | Sensitivity | Specificity |\n"
    output += "|------------|-----|----------|-----------|--------|-------------|-------------|\n"
    for base_type, baseline in unique_baselines.items():
        output += f"| {base_type} | {baseline['f1']:.4f} | {baseline['accuracy']:.4f} | "
        output += f"{baseline['precision']:.4f} | {baseline['recall']:.4f} | "
        output += f"{baseline['sensitivity']:.4f} | {baseline['specificity']:.4f} |\n"
    # Best model per metric
    output += "\n## 🏆 Best Model per Metric\n\n"
    metrics_to_check = [
        ('F1 Score', 'eval_f1'),
        ('Accuracy', 'eval_accuracy'),
        ('Precision', 'eval_precision'),
        ('Recall', 'eval_recall'),
        ('Sensitivity', 'eval_sensitivity'),
        ('Specificity', 'eval_specificity')
    ]
    for metric_name, metric_key in metrics_to_check:
        best_model = max(
            trained_models.items(),
            key=lambda x: x[1]['results'][metric_key]
        )
        model_id = best_model[0]
        value = best_model[1]['results'][metric_key]
        baseline_val = best_model[1]['baseline'][metric_key.replace('eval_', '')]
        improvement = calculate_improvement(baseline_val, value)
        output += f"**{metric_name}**: {model_id[:30]}... "
        output += f"({value:.4f}, improved {format_improvement(improvement)})\n\n"
    # Improvement statistics
    output += "## 📈 Improvement Statistics\n\n"
    improvements = []
    for model_id, info in trained_models.items():
        f1_base = info['baseline']['f1']
        f1_fine = info['results']['eval_f1']
        improve = calculate_improvement(f1_base, f1_fine)
        if improve != float('inf'):
            improvements.append({
                'model': model_id,
                'improvement': improve,
                'method': info['config']['method']
            })
    if improvements:
        avg_improvement = np.mean([x['improvement'] for x in improvements])
        max_improvement = max(improvements, key=lambda x: x['improvement'])
        min_improvement = min(improvements, key=lambda x: x['improvement'])
        output += f"Average F1 improvement: {format_improvement(avg_improvement)}\n"
        output += f"Largest improvement: {max_improvement['model'][:30]}... ({format_improvement(max_improvement['improvement'])})\n"
        output += f"Smallest improvement: {min_improvement['model'][:30]}... ({format_improvement(min_improvement['improvement'])})\n\n"
        # Per-method comparison
        method_improvements = {}
        for imp in improvements:
            method = imp['method']
            if method not in method_improvements:
                method_improvements[method] = []
            method_improvements[method].append(imp['improvement'])
        output += "### Average improvement per method:\n"
        for method, imps in method_improvements.items():
            avg_imp = np.mean(imps)
            output += f"- **{method.upper()}**: {format_improvement(avg_imp)}\n"
    return output
# ==================== Gradio UI ====================
def create_demo():
    """Build the Gradio interface."""
    with gr.Blocks(
        title="BERT Fine-tuning Teaching Platform",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {font-family: 'Microsoft JhengHei', 'Arial', sans-serif;}
        """
    ) as demo:
        gr.Markdown(
            """
            # 🧬 BERT Fine-tuning Teaching Platform
            ### Compare baseline vs fine-tuned model performance (improved edition)
            """
        )
        with gr.Tab("🎯 Train"):
            gr.Markdown("## Step 1: Choose a base model")
            base_model = gr.Dropdown(
                choices=["BERT-base", "BERT-base-chinese", "BioBERT", "SciBERT"],
                value="BERT-base",
                label="Base model",
                info="Pick the pretrained model that fits your data"
            )
            gr.Markdown("## Step 2: Choose a fine-tuning method")
            method = gr.Radio(
                choices=["lora", "adalora", "full"],
                value="lora",
                label="Fine-tuning method",
                info="LoRA and AdaLoRA are parameter-efficient; Full updates every weight"
            )
            gr.Markdown("## Step 3: Upload data")
            csv_file = gr.File(
                label="CSV file (must contain Text and label columns)",
                file_types=[".csv"]
            )
            gr.Markdown("## Step 4: Set training parameters")
            with gr.Accordion("🎯 Basic training parameters", open=True):
                with gr.Row():
                    num_epochs = gr.Number(
                        value=5, label="Epochs", minimum=1, maximum=50, precision=0,
                        info="3-10 recommended; more can overfit"
                    )
                    batch_size = gr.Number(
                        value=8, label="Batch size", minimum=1, maximum=64, precision=0,
                        info="Lower this if GPU memory runs out"
                    )
                    learning_rate = gr.Number(
                        value=3e-5, label="Learning rate", minimum=1e-6, maximum=1e-3,
                        info="1e-5 to 5e-5 recommended"
                    )
            with gr.Accordion("⚙️ Advanced parameters"):
                with gr.Row():
                    weight_decay = gr.Number(
                        value=0.01, label="Weight decay", minimum=0, maximum=1,
                        info="Counters overfitting; 0.01-0.1 recommended"
                    )
                    dropout = gr.Number(
                        value=0.1, label="Dropout rate", minimum=0, maximum=0.5,
                        info="Counters overfitting; 0.1-0.3 recommended"
                    )
            with gr.Accordion("🔧 PEFT parameters (LoRA/AdaLoRA)"):
                with gr.Row():
                    lora_r = gr.Number(
                        value=16, label="LoRA rank (r)", minimum=1, maximum=64, precision=0,
                        info="Higher rank means more capacity but more parameters"
                    )
                    lora_alpha = gr.Number(
                        value=32, label="LoRA alpha", minimum=1, maximum=128, precision=0,
                        info="Usually set to twice the rank"
                    )
                    lora_dropout = gr.Number(
                        value=0.05, label="LoRA dropout", minimum=0, maximum=0.5,
                        info="Dropout inside the LoRA layers"
                    )
            with gr.Accordion("⚖️ Class-balance settings"):
                with gr.Row():
                    weight_mult = gr.Number(
                        value=1.0, label="Weight multiplier", minimum=0.1, maximum=5.0,
                        info="Scales the minority-class weight"
                    )
                    weight_method = gr.Dropdown(
                        choices=["sqrt", "log", "balanced", "custom"],
                        value="sqrt",
                        label="Weighting method",
                        info="sqrt and log suit severely imbalanced data"
                    )
            with gr.Accordion("🎯 Training strategy"):
                with gr.Row():
                    best_metric = gr.Dropdown(
                        choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
                        value="f1",
                        label="Model-selection metric",
                        info="The best checkpoint is chosen by this metric"
                    )
                    use_early_stopping = gr.Checkbox(
                        value=True, label="Enable early stopping",
                        info="Stop once the model stops improving"
                    )
                    patience = gr.Number(
                        value=3, label="Patience", minimum=1, maximum=10, precision=0,
                        info="Epochs without improvement before stopping"
                    )
            train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
            gr.Markdown("## 📊 Training Results")
            with gr.Row():
                data_info = gr.Textbox(label="📋 Training Info", lines=25)
                history_output = gr.Textbox(label="📈 Training History", lines=25)
            with gr.Row():
                baseline_result = gr.Textbox(label="🔬 Baseline Model (not fine-tuned)", lines=15)
                finetuned_result = gr.Textbox(label="✅ Fine-tuned Model", lines=15)
            comparison_result = gr.Textbox(label="📊 Performance Comparison", lines=20)
            train_btn.click(
                train_bert_model,
                inputs=[
                    csv_file, base_model, method, num_epochs, batch_size, learning_rate,
                    weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
                    weight_mult, weight_method, best_metric, use_early_stopping, patience
                ],
                outputs=[data_info, baseline_result, finetuned_result, comparison_result, history_output]
            )
        with gr.Tab("🔮 Predict"):
            gr.Markdown("## Run predictions with a trained model")
            with gr.Row():
                model_dropdown = gr.Dropdown(
                    label="Select a model",
                    choices=list(trained_models.keys()),
                    interactive=True
                )
                refresh_btn = gr.Button("🔄 Refresh model list", size="sm")
            text_input = gr.Textbox(
                label="Text to classify",
                lines=5,
                placeholder="Enter a case description or related text..."
            )
            predict_btn = gr.Button("🎯 Predict", variant="primary", size="lg")
            pred_output = gr.Textbox(label="Prediction and Analysis", lines=25)
            # Refresh the model list
            refresh_btn.click(
                lambda: gr.Dropdown(choices=list(trained_models.keys())),
                outputs=[model_dropdown]
            )
            # Run the prediction
            predict_btn.click(
                predict,
                inputs=[model_dropdown, text_input],
                outputs=[pred_output]
            )
            # Examples
            gr.Examples(
                examples=[
                    ["Patient with stage II breast cancer, showing good response to chemotherapy treatment."],
                    ["Advanced metastatic cancer with multiple organ failure, poor prognosis."],
                    ["Early stage tumor detected, surgery scheduled, excellent recovery expected."],
                    ["Terminal stage disease, palliative care initiated, family counseling provided."]
                ],
                inputs=text_input
            )
        with gr.Tab("📊 Compare"):
            gr.Markdown("## Compare all trained models")
            compare_btn = gr.Button("📊 Generate Comparison Report", variant="primary", size="lg")
            compare_output = gr.Markdown()
            compare_btn.click(compare_models, outputs=[compare_output])
| with gr.Tab("📖 說明"): | |
| gr.Markdown(""" | |
| ## 📖 使用說明 | |
| ### 🎯 平台特色 | |
| 本改進版平台提供以下功能: | |
| 1. **自動基準比較**:每次訓練都會自動評估基準模型,清楚顯示微調的改善 | |
| 2. **訓練監控**:記錄每個 epoch 的詳細訓練歷程 | |
| 3. **Early Stopping**:避免過擬合,自動選擇最佳模型 | |
| 4. **多種權重策略**:針對不平衡資料提供多種處理方法 | |
| 5. **完整評估指標**:包含 F1、準確率、精確率、召回率、敏感度、特異度、PPV、NPV | |
| ### 🤖 支援的基礎模型 | |
| - **BERT-base**: 標準英文 BERT,適用於一般英文文本 | |
| - **BERT-base-chinese**: 中文 BERT,適用於中文文本 | |
| - **BioBERT**: 生物醫學領域專用 BERT | |
| - **SciBERT**: 科學文獻專用 BERT | |
| ### 🔧 微調方法說明 | |
| - **LoRA** (Low-Rank Adaptation) | |
| - 參數效率最高,只訓練 <1% 參數 | |
| - 訓練速度快,記憶體需求低 | |
| - 適合大多數場景 | |
| - **AdaLoRA** (Adaptive LoRA) | |
| - 自動調整秩的分配 | |
| - 可能獲得更好的效果 | |
| - 訓練時間稍長 | |
| - **Full** (完全微調) | |
| - 訓練所有參數 | |
| - 可能獲得最佳效果 | |
| - 需要較大記憶體和時間 | |
| ### ⚖️ 處理不平衡資料 | |
| #### 權重計算方法: | |
| 1. **sqrt** (平方根法) - 推薦用於極度不平衡 | |
| - 使用平方根緩和權重 | |
| - 避免權重過大導致過擬合 | |
| 2. **log** (對數法) - 更保守的方法 | |
| - 使用對數進一步緩和 | |
| - 適合極度不平衡且容易過擬合的情況 | |
| 3. **balanced** (平衡法) | |
| - sklearn 風格的自動平衡 | |
| - 適合中度不平衡 | |
| 4. **custom** (自定義) | |
| - 根據不平衡程度自動調整 | |
| - 綜合考慮多種因素 | |
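
            For example, the weights produced by the platform's own `calculate_class_weights`
            helper for a 900-vs-100 split (a 9:1 ratio):

            ```python
            calculate_class_weights(900, 100, method='sqrt')      # -> (1.0, 3.0)
            calculate_class_weights(900, 100, method='log')       # -> (1.0, ~2.30)
            calculate_class_weights(900, 100, method='balanced')  # -> (~0.56, 5.0)
            ```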
            #### Suggested settings:
            **Severe imbalance (>20:1)**
            - Weighting method: sqrt or log
            - Weight multiplier: 0.5-1.0
            - Focal loss (enabled automatically)
            - Early stopping: recommended

            **High imbalance (10-20:1)**
            - Weighting method: sqrt
            - Weight multiplier: 0.8-1.5
            - Early stopping: recommended

            **Moderate imbalance (5-10:1)**
            - Weighting method: balanced
            - Weight multiplier: 1.0-2.0

            **Mild imbalance (<5:1)**
            - Weighting method: balanced
            - Weight multiplier: 1.5-3.0

            ### 📊 Metric Definitions
            - **F1 Score**: harmonic mean of precision and recall; suits imbalanced data
            - **Accuracy**: overall fraction of correct predictions
            - **Precision**: fraction of predicted positives that are truly positive
            - **Recall/Sensitivity**: fraction of true positives that are correctly predicted
            - **Specificity**: fraction of true negatives that are correctly predicted
            - **PPV**: positive predictive value
            - **NPV**: negative predictive value

            ### 🚀 Quick Start
            1. **Prepare the data**
               - CSV format with `Text` and `label` columns
               - label: 0 = negative class (e.g. survived), 1 = positive class (e.g. died)
            2. **Choose a model and method**
               - English data: BERT-base + LoRA
               - Chinese data: BERT-base-chinese + LoRA
               - Medical data: BioBERT + LoRA
            3. **Set the parameters**
               - Start from the defaults
               - Adjust the weighting settings to match the class imbalance
            4. **Train and evaluate**
               - Click "Start Training"
               - Review the baseline-vs-finetuned comparison
               - Inspect the training history
            5. **Test predictions**
               - Pick a model on the "Predict" tab
               - Enter text to classify
               - Compare the pre- and post-finetuning predictions

            ### ⚠️ Notes
            - A GPU speeds up training considerably
            - Batch sizes that are too large can exhaust memory
            - Early stopping helps prevent overfitting
            - Use conservative weight settings for severely imbalanced data

            ### 💡 Tuning Tips
            1. **Out of memory**: lower the batch size or use LoRA
            2. **Overfitting**: raise dropout, enable early stopping, lower the learning rate
            3. **Underfitting**: train longer, raise the learning rate, increase model capacity
            4. **Imbalanced data**: adjust class weights and use an appropriate metric (F1)
            """)
    return demo

# ==================== Entry point ====================
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        max_threads=4
    )