# app.py — Gradio demo: BERT breast-cancer survival prediction (training + inference)
import gradio as gr
import pandas as pd
import torch
from torch import nn
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    TrainingArguments,
    Trainer
)
from datasets import Dataset
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    confusion_matrix
)
import numpy as np
from datetime import datetime
import json
import os
import gc  # used to release memory between training runs

# Optional PEFT imports (LoRA and AdaLoRA); the app degrades gracefully to
# full fine-tuning when the package is not installed.
try:
    from peft import (
        LoraConfig,
        AdaLoraConfig,
        get_peft_model,
        TaskType,
        PeftModel
    )
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
    print("⚠️ PEFT 未安裝,LoRA 和 AdaLoRA 功能將不可用")
# Select GPU when available; fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Module-level state shared between the training and prediction tabs.
# BUG FIX: the training code declares `global LAST_MODEL_PATH, ...` and assigns
# LAST_MODEL_PATH, but the original module defined `_MODEL_PATH` instead, so
# LAST_MODEL_PATH did not exist until the first training run completed.
LAST_MODEL_PATH = None      # directory of the most recently trained model
LAST_TOKENIZER = None       # tokenizer paired with that model
LAST_TUNING_METHOD = None   # "Full Fine-tuning" / "LoRA" / "AdaLoRA"
def evaluate_baseline_bert(eval_dataset, df_clean):
    """
    Evaluate an untrained, pretrained bert-base-uncased on the held-out split.

    Args:
        eval_dataset: kept for backward compatibility with existing callers;
            unused here because the split is rebuilt with the baseline
            tokenizer using the same test_size/seed as training, so the
            evaluated rows match the training-side validation split.
        df_clean: DataFrame with 'text' and 'label' columns (0/1 labels).

    Returns:
        dict with f1/accuracy/precision/recall/sensitivity/specificity/auc
        plus confusion-matrix counts (tp/tn/fp/fn) as plain Python numbers.
    """
    print("\n" + "=" * 80)
    print("評估 Baseline 純 BERT(完全沒看過資料)")
    print("=" * 80)
    # Load the pristine pretrained model; its classification head is randomly
    # initialized, so this is the "never saw your data" baseline.
    baseline_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    baseline_model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",
        num_labels=2
    ).to(device)
    baseline_model.eval()
    print(" ⚠️ 這個模型完全沒有使用您的資料訓練")
    # Re-tokenize and re-split with the same seed/test_size as training.
    baseline_dataset = Dataset.from_pandas(df_clean[['text', 'label']])

    def baseline_preprocess(examples):
        # Batched tokenization; fixed-length padding keeps tensor shapes uniform.
        return baseline_tokenizer(examples['text'], truncation=True, padding='max_length', max_length=256)

    baseline_tokenized = baseline_dataset.map(baseline_preprocess, batched=True)
    baseline_split = baseline_tokenized.train_test_split(test_size=0.2, seed=42)
    baseline_eval_dataset = baseline_split['test']
    # Evaluation-only Trainer (no training happens here).
    baseline_trainer_args = TrainingArguments(
        output_dir='./temp_baseline',
        per_device_eval_batch_size=32,
        report_to="none"
    )
    baseline_trainer = Trainer(
        model=baseline_model,
        args=baseline_trainer_args,
    )
    print("🔄 評估純 BERT...")
    predictions_output = baseline_trainer.predict(baseline_eval_dataset)
    all_preds = predictions_output.predictions.argmax(-1)
    all_labels = predictions_output.label_ids
    # Probability of the positive (death) class for AUC.
    probs = torch.nn.functional.softmax(torch.tensor(predictions_output.predictions), dim=-1)[:, 1].numpy()
    # Compute metrics.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='binary', pos_label=1, zero_division=0
    )
    acc = accuracy_score(all_labels, all_preds)
    try:
        auc = roc_auc_score(all_labels, probs)
    except ValueError:
        # roc_auc_score raises ValueError when only one class is present
        # (narrowed from the original bare except, which hid real failures).
        auc = 0.0
    cm = confusion_matrix(all_labels, all_preds)
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    else:
        # Degenerate case: labels/predictions cover a single class.
        sensitivity = specificity = 0
        tn = fp = fn = tp = 0
    baseline_results = {
        'f1': float(f1),
        'accuracy': float(acc),
        'precision': float(precision),
        'recall': float(recall),
        'sensitivity': float(sensitivity),
        'specificity': float(specificity),
        'auc': float(auc),
        'tp': int(tp),
        'tn': int(tn),
        'fp': int(fp),
        'fn': int(fn)
    }
    print("✅ Baseline 評估完成")
    return baseline_results
def run_original_code_with_tuning(
    file_path,
    weight_multiplier,
    epochs,
    batch_size,
    learning_rate,
    warmup_steps,
    tuning_method,
    best_metric,
    # LoRA parameters
    lora_r,
    lora_alpha,
    lora_dropout,
    lora_modules,
    # AdaLoRA parameters
    adalora_init_r,
    adalora_target_r,
    adalora_tinit,
    adalora_tfinal,
    adalora_delta_t
):
    """
    Train a class-weighted BERT survival classifier on an uploaded CSV and
    compare it against the untrained baseline.

    The CSV must contain 'Text' and 'label' columns (0 = survived,
    1 = deceased).  Depending on `tuning_method`, the model is fully
    fine-tuned or wrapped with LoRA / AdaLoRA adapters (when PEFT is
    installed).  The trained model is saved to disk, registered in
    ./saved_models_list.json for the prediction tab, and the enriched
    `trainer.evaluate()` dict is returned.

    Fixes vs. the original version:
      * removed the redundant local `import gc` (already imported at module level)
      * `compute_metrics` catches only ValueError from roc_auc_score instead
        of a bare except
      * `WeightedTrainer.compute_loss` accepts **kwargs so it keeps working
        on newer transformers versions that pass `num_items_in_batch`
      * collapsed the duplicated model-save branches (both ran identical code)
      * a corrupt saved_models_list.json no longer crashes the run
    """
    global LAST_MODEL_PATH, LAST_TOKENIZER, LAST_TUNING_METHOD

    # ==================== free memory (before training) ====================
    torch.cuda.empty_cache()
    gc.collect()
    print("🧹 記憶體已清空")

    # ==================== load the uploaded file ====================
    df_original = pd.read_csv(file_path)
    df_clean = pd.DataFrame({
        'text': df_original['Text'],
        'label': df_original['label']
    })
    df_clean = df_clean.dropna()

    print("\n" + "=" * 80)
    print(f"乳癌存活預測 BERT Fine-tuning - {tuning_method} 方法")
    print("=" * 80)
    print(f"開始時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"微調方法: {tuning_method}")
    print(f"最佳化指標: {best_metric}")
    print("=" * 80)

    # Load the tokenizer once; reused for training and saved with the model.
    print("\n📦 載入 BERT Tokenizer...")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    print("✅ Tokenizer 載入完成")

    def compute_metrics(pred):
        # Metrics callback invoked by the Trainer on every evaluation pass.
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        # Probability of the positive (death) class for AUC.
        probs = torch.nn.functional.softmax(torch.tensor(pred.predictions), dim=-1)[:, 1].numpy()
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average='binary', pos_label=1, zero_division=0
        )
        acc = accuracy_score(labels, preds)
        try:
            auc = roc_auc_score(labels, probs)
        except ValueError:
            # roc_auc_score raises ValueError when only one class is present.
            auc = 0.0
        cm = confusion_matrix(labels, preds)
        if cm.shape == (2, 2):
            tn, fp, fn, tp = cm.ravel()
        else:
            # Degenerate matrix: the model predicted a single class.
            if len(np.unique(preds)) == 1:
                if preds[0] == 0:
                    tn, fp, fn, tp = sum(labels == 0), 0, sum(labels == 1), 0
                else:
                    tn, fp, fn, tp = 0, sum(labels == 0), 0, sum(labels == 1)
            else:
                tn = fp = fn = tp = 0
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        return {
            'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
            'auc': auc, 'sensitivity': sensitivity, 'specificity': specificity,
            'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
        }

    # ========================================================================
    # Step 1: prepare the data (no rebalancing — keep the original ratio)
    # ========================================================================
    print("\n" + "=" * 80)
    print("步驟 1:準備資料(保持原始比例)")
    print("=" * 80)
    print(f"\n原始資料分布:")
    print(f" 存活 (0): {sum(df_clean['label']==0)} 筆 ({sum(df_clean['label']==0)/len(df_clean)*100:.1f}%)")
    print(f" 死亡 (1): {sum(df_clean['label']==1)} 筆 ({sum(df_clean['label']==1)/len(df_clean)*100:.1f}%)")
    ratio = sum(df_clean['label']==0) / sum(df_clean['label']==1)
    print(f" 不平衡比例: {ratio:.1f}:1")

    # ========================================================================
    # Step 2: tokenization
    # ========================================================================
    print("\n" + "=" * 80)
    print("步驟 2:Tokenization")
    print("=" * 80)
    dataset = Dataset.from_pandas(df_clean[['text', 'label']])

    def preprocess_function(examples):
        # Batched tokenization; fixed-length padding keeps tensor shapes uniform.
        return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=256)

    tokenized_dataset = dataset.map(preprocess_function, batched=True)
    # Same seed/test_size as evaluate_baseline_bert so both see the same split.
    train_test_split = tokenized_dataset.train_test_split(test_size=0.2, seed=42)
    train_dataset = train_test_split['train']
    eval_dataset = train_test_split['test']
    print(f"\n✅ 資料集準備完成:")
    print(f" 訓練集: {len(train_dataset)} 筆")
    print(f" 驗證集: {len(eval_dataset)} 筆")

    # ========================================================================
    # Step 3: class weights
    # ========================================================================
    print("\n" + "=" * 80)
    print(f"步驟 3:設定類別權重({weight_multiplier}x 倍數)")
    print("=" * 80)
    weight_0 = 1.0
    # Up-weight the minority (death) class proportionally to the imbalance.
    weight_1 = ratio * weight_multiplier
    print(f"\n權重設定:")
    print(f" 倍數: {weight_multiplier}x")
    print(f" 存活類權重: {weight_0:.3f}")
    print(f" 死亡類權重: {weight_1:.3f} (= {ratio:.1f} × {weight_multiplier})")
    class_weights = torch.tensor([weight_0, weight_1], dtype=torch.float).to(device)

    # ========================================================================
    # Step 4: build and train the model (tuning-method specific part)
    # ========================================================================
    print("\n" + "=" * 80)
    print(f"步驟 4:訓練 {tuning_method} BERT 模型")
    print("=" * 80)
    print(f"\n🔄 初始化模型 ({tuning_method})...")
    # Base model shared by every tuning method.
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2, problem_type="single_label_classification"
    )
    if tuning_method == "Full Fine-tuning":
        # Every parameter stays trainable.
        model = model.to(device)
        print("✅ 使用完整 Fine-tuning(所有參數可訓練)")
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        all_params = sum(p.numel() for p in model.parameters())
        print(f" 可訓練參數: {trainable_params:,} / {all_params:,} ({100 * trainable_params / all_params:.2f}%)")
    elif tuning_method == "LoRA" and PEFT_AVAILABLE:
        # LoRA: low-rank adapters on the selected attention projections.
        target_modules = lora_modules.split(",") if lora_modules else ["query", "value"]
        target_modules = [m.strip() for m in target_modules]
        peft_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            r=int(lora_r),
            lora_alpha=int(lora_alpha),
            lora_dropout=float(lora_dropout),
            target_modules=target_modules
        )
        model = get_peft_model(model, peft_config)
        model = model.to(device)
        print("✅ 使用 LoRA 微調")
        print(f" LoRA rank (r): {lora_r}")
        print(f" LoRA alpha: {lora_alpha}")
        print(f" LoRA dropout: {lora_dropout}")
        print(f" 目標模組: {target_modules}")
        model.print_trainable_parameters()
    elif tuning_method == "AdaLoRA" and PEFT_AVAILABLE:
        # AdaLoRA: like LoRA, but the rank budget is reallocated during training.
        target_modules = lora_modules.split(",") if lora_modules else ["query", "value"]
        target_modules = [m.strip() for m in target_modules]
        peft_config = AdaLoraConfig(
            task_type=TaskType.SEQ_CLS,
            init_r=int(adalora_init_r),
            target_r=int(adalora_target_r),
            tinit=int(adalora_tinit),
            tfinal=int(adalora_tfinal),
            deltaT=int(adalora_delta_t),
            lora_alpha=int(lora_alpha),
            lora_dropout=float(lora_dropout),
            target_modules=target_modules
        )
        model = get_peft_model(model, peft_config)
        model = model.to(device)
        print("✅ 使用 AdaLoRA 微調")
        print(f" 初始 rank: {adalora_init_r}")
        print(f" 目標 rank: {adalora_target_r}")
        print(f" Tinit: {adalora_tinit}, Tfinal: {adalora_tfinal}, DeltaT: {adalora_delta_t}")
        model.print_trainable_parameters()
    else:
        # PEFT missing or unknown method: fall back to full fine-tuning.
        model = model.to(device)
        print("⚠️ PEFT 未安裝或方法無效,使用 Full Fine-tuning")

    class WeightedTrainer(Trainer):
        """Trainer whose loss applies the class weights computed above."""
        # **kwargs absorbs arguments added in newer transformers versions
        # (e.g. num_items_in_batch), keeping this override forward-compatible.
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            loss_fct = nn.CrossEntropyLoss(weight=class_weights)
            loss = loss_fct(outputs.logits.view(-1, 2), labels.view(-1))
            return (loss, outputs) if return_outputs else loss

    # Metric key used to select the best checkpoint.
    metric_map = {
        "f1": "f1",
        "accuracy": "accuracy",
        "precision": "precision",
        "recall": "recall",
        "sensitivity": "sensitivity",
        "specificity": "specificity",
        "auc": "auc"
    }
    training_args = TrainingArguments(
        output_dir='./results_weight',
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size*2,
        warmup_steps=warmup_steps,
        weight_decay=0.01,
        learning_rate=learning_rate,
        logging_steps=50,
        # NOTE(review): renamed to `eval_strategy` in transformers>=4.46 —
        # confirm the pinned transformers version before upgrading.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model=metric_map.get(best_metric, "f1"),  # user-selected metric
        report_to="none",
        greater_is_better=True  # every offered metric is higher-is-better
    )
    trainer = WeightedTrainer(
        model=model, args=training_args,
        train_dataset=train_dataset, eval_dataset=eval_dataset,
        compute_metrics=compute_metrics
    )
    print(f"\n🚀 開始訓練({epochs} epochs)...")
    print(f" 最佳化指標: {best_metric}")
    print("-" * 80)
    trainer.train()
    print("\n✅ 模型訓練完成!")

    # Evaluate the (best) trained model.
    print("\n📊 評估模型...")
    results = trainer.evaluate()
    print(f"\n{tuning_method} BERT ({weight_multiplier}x 權重) 表現:")
    print(f" F1 Score: {results['eval_f1']:.4f}")
    print(f" Accuracy: {results['eval_accuracy']:.4f}")
    print(f" Precision: {results['eval_precision']:.4f}")
    print(f" Recall: {results['eval_recall']:.4f}")
    print(f" Sensitivity: {results['eval_sensitivity']:.4f}")
    print(f" Specificity: {results['eval_specificity']:.4f}")
    print(f" AUC: {results['eval_auc']:.4f}")
    print(f" 混淆矩陣: Tp={results['eval_tp']}, Tn={results['eval_tn']}, "
          f"Fp={results['eval_fp']}, Fn={results['eval_fn']}")

    # ========================================================================
    # Step 5: baseline comparison (untouched BERT)
    # ========================================================================
    print("\n" + "=" * 80)
    print("步驟 5:Baseline 比較 - 純 BERT(完全沒看過資料)")
    print("=" * 80)
    baseline_results = evaluate_baseline_bert(eval_dataset, df_clean)

    # ========================================================================
    # Step 6: side-by-side comparison table
    # ========================================================================
    print("\n" + "=" * 80)
    print(f"📊 【對比結果】純 BERT vs {tuning_method} BERT")
    print("=" * 80)
    print("\n📋 詳細比較表:")
    print("-" * 100)
    print(f"{'指標':<15} {'純 BERT':<20} {tuning_method:<20} {'改善幅度':<20}")
    print("-" * 100)
    metrics_to_compare = [
        ('F1 Score', 'f1', 'eval_f1'),
        ('Accuracy', 'accuracy', 'eval_accuracy'),
        ('Precision', 'precision', 'eval_precision'),
        ('Recall', 'recall', 'eval_recall'),
        ('Sensitivity', 'sensitivity', 'eval_sensitivity'),
        ('Specificity', 'specificity', 'eval_specificity'),
        ('AUC', 'auc', 'eval_auc')
    ]
    for name, baseline_key, finetuned_key in metrics_to_compare:
        baseline_val = baseline_results[baseline_key]
        finetuned_val = results[finetuned_key]
        # Relative improvement; guarded against division by zero.
        improvement = ((finetuned_val - baseline_val) / baseline_val * 100) if baseline_val > 0 else 0
        print(f"{name:<15} {baseline_val:<20.4f} {finetuned_val:<20.4f} {improvement:>+18.1f}%")
    print("-" * 100)

    # Save the model.  save_pretrained works for both plain HF models and
    # PEFT-wrapped models, so the original's two identical branches were
    # collapsed into a single call.
    save_dir = f'./breast_cancer_bert_{tuning_method.lower().replace(" ", "_")}'
    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)

    # Register the model in the JSON registry used by the prediction tab.
    model_info = {
        'model_path': save_dir,
        'tuning_method': tuning_method,
        'best_metric': best_metric,
        'best_metric_value': float(results[f'eval_{metric_map.get(best_metric, "f1")}']),
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'weight_multiplier': weight_multiplier,
        'epochs': epochs
    }
    models_list_file = './saved_models_list.json'
    if os.path.exists(models_list_file):
        try:
            with open(models_list_file, 'r') as f:
                models_list = json.load(f)
        except json.JSONDecodeError:
            # Corrupt registry file: start over instead of crashing the run.
            models_list = []
    else:
        models_list = []
    models_list.append(model_info)
    with open(models_list_file, 'w') as f:
        json.dump(models_list, f, indent=2)

    # Expose the latest model to the prediction code via module globals.
    LAST_MODEL_PATH = save_dir
    LAST_TOKENIZER = tokenizer
    LAST_TUNING_METHOD = tuning_method
    print(f"\n💾 模型已儲存至: {save_dir}")
    print("\n" + "=" * 80)
    print("🎉 訓練完成!")
    print("=" * 80)
    print(f"完成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # ==================== free memory (after training) ====================
    del model
    del trainer
    torch.cuda.empty_cache()
    gc.collect()
    print("🧹 訓練後記憶體已清空")

    # Enrich the evaluation dict for the Gradio wrapper.
    results['tuning_method'] = tuning_method
    results['best_metric'] = best_metric
    results['best_metric_value'] = results[f'eval_{metric_map.get(best_metric, "f1")}']
    results['baseline_results'] = baseline_results
    results['model_path'] = save_dir
    return results
def predict_text(model_choice, text_input):
    """
    Predict survival for a single clinical-note string.

    Runs the text through two models and returns a tuple of markdown strings
    (baseline_output, finetuned_output) for the Gradio prediction tab:
      1. an untouched pretrained bert-base-uncased (baseline), and
      2. the fine-tuned model selected in the dropdown, resolved through
         ./saved_models_list.json (PEFT adapters are loaded on top of a fresh
         base model when the entry was trained with LoRA/AdaLoRA).

    Any exception is caught at this boundary and reported in both panes.
    """
    if not text_input or text_input.strip() == "":
        return "請輸入文本", "請輸入文本"
    try:
        # ==================== prediction with the un-tuned BERT ====================
        print("\n使用未微調 BERT 預測...")
        baseline_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        baseline_model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased",
            num_labels=2
        ).to(device)
        baseline_model.eval()
        # Tokenize the input (baseline).
        baseline_inputs = baseline_tokenizer(
            text_input,
            truncation=True,
            padding='max_length',
            max_length=256,
            return_tensors='pt'
        ).to(device)
        # Predict (baseline).
        with torch.no_grad():
            baseline_outputs = baseline_model(**baseline_inputs)
            baseline_probs = torch.nn.functional.softmax(baseline_outputs.logits, dim=-1)
            baseline_pred_class = baseline_probs.argmax(-1).item()
            baseline_confidence = baseline_probs[0][baseline_pred_class].item()
        # Class 0 = survived, class 1 = deceased.
        baseline_result = "存活" if baseline_pred_class == 0 else "死亡"
        baseline_prob_survive = baseline_probs[0][0].item()
        baseline_prob_death = baseline_probs[0][1].item()
        baseline_output = f"""
# 🔵 未微調 BERT 預測結果
## 預測類別: **{baseline_result}**
## 信心度: **{baseline_confidence:.1%}**
## 機率分布:
- 🟢 **存活機率**: {baseline_prob_survive:.2%}
- 🔴 **死亡機率**: {baseline_prob_death:.2%}
---
**說明**: 此為原始 BERT 模型,未經任何領域資料訓練
"""
        # Release the baseline model before loading the fine-tuned one.
        del baseline_model
        del baseline_tokenizer
        torch.cuda.empty_cache()
        # ==================== prediction with the fine-tuned BERT ====================
        if model_choice == "請先訓練模型":
            finetuned_output = """
# 🟢 微調 BERT 預測結果
❌ 尚未訓練任何模型,請先在「模型訓練」頁面訓練模型
"""
            return baseline_output, finetuned_output
        # Recover the model directory from the dropdown display string.
        model_path = model_choice.split(" | ")[0].replace("路徑: ", "")
        # Look the entry up in the saved-model registry.
        with open('./saved_models_list.json', 'r') as f:
            models_list = json.load(f)
        selected_model_info = None
        for model_info in models_list:
            if model_info['model_path'] == model_path:
                selected_model_info = model_info
                break
        if selected_model_info is None:
            finetuned_output = f"""
# 🟢 微調 BERT 預測結果
❌ 找不到模型:{model_path}
"""
            return baseline_output, finetuned_output
        print(f"\n使用微調模型: {model_path}")
        # Load the tokenizer saved alongside the model.
        finetuned_tokenizer = BertTokenizer.from_pretrained(model_path)
        # Load the model itself.
        tuning_method = selected_model_info['tuning_method']
        if tuning_method in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE:
            # PEFT checkpoints store only the adapters; attach them to a
            # freshly loaded base model.
            base_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
            finetuned_model = PeftModel.from_pretrained(base_model, model_path)
            finetuned_model = finetuned_model.to(device)
        else:
            # Plain fully fine-tuned checkpoint.
            finetuned_model = BertForSequenceClassification.from_pretrained(model_path).to(device)
        finetuned_model.eval()
        # Tokenize the input (fine-tuned).
        finetuned_inputs = finetuned_tokenizer(
            text_input,
            truncation=True,
            padding='max_length',
            max_length=256,
            return_tensors='pt'
        ).to(device)
        # Predict (fine-tuned).
        with torch.no_grad():
            finetuned_outputs = finetuned_model(**finetuned_inputs)
            finetuned_probs = torch.nn.functional.softmax(finetuned_outputs.logits, dim=-1)
            finetuned_pred_class = finetuned_probs.argmax(-1).item()
            finetuned_confidence = finetuned_probs[0][finetuned_pred_class].item()
        finetuned_result = "存活" if finetuned_pred_class == 0 else "死亡"
        finetuned_prob_survive = finetuned_probs[0][0].item()
        finetuned_prob_death = finetuned_probs[0][1].item()
        finetuned_output = f"""
# 🟢 微調 BERT 預測結果
## 預測類別: **{finetuned_result}**
## 信心度: **{finetuned_confidence:.1%}**
## 機率分布:
- 🟢 **存活機率**: {finetuned_prob_survive:.2%}
- 🔴 **死亡機率**: {finetuned_prob_death:.2%}
---
### 模型資訊:
- **微調方法**: {selected_model_info['tuning_method']}
- **最佳化指標**: {selected_model_info['best_metric']}
- **訓練時間**: {selected_model_info['timestamp']}
- **模型路徑**: {model_path}
---
**注意**: 此預測僅供參考,實際醫療決策應由專業醫師判斷。
"""
        # Release the fine-tuned model.
        del finetuned_model
        del finetuned_tokenizer
        torch.cuda.empty_cache()
        return baseline_output, finetuned_output
    except Exception as e:
        # Top-level boundary: surface the traceback in both output panes.
        import traceback
        error_msg = f"❌ 預測錯誤:{str(e)}\n\n詳細錯誤訊息:\n{traceback.format_exc()}"
        return error_msg, error_msg
def get_available_models():
    """Return dropdown display strings for every model recorded in the registry.

    Reads ./saved_models_list.json; when the registry is missing or empty,
    returns the single placeholder option "請先訓練模型".
    """
    registry_path = './saved_models_list.json'
    placeholder = ["請先訓練模型"]
    if not os.path.exists(registry_path):
        return placeholder
    with open(registry_path, 'r') as fh:
        saved_models = json.load(fh)
    if not saved_models:
        return placeholder
    # One "path | method | timestamp" line per registry entry.
    return [
        f"路徑: {entry['model_path']} | 方法: {entry['tuning_method']} | 時間: {entry['timestamp']}"
        for entry in saved_models
    ]
# ============================================================================
# Gradio interface helpers — the training output is split into three panes
# ============================================================================
def train_wrapper(
    file,
    tuning_method,
    weight_mult,
    epochs,
    batch_size,
    lr,
    warmup,
    best_metric,
    lora_r,
    lora_alpha,
    lora_dropout,
    lora_modules,
    adalora_init_r,
    adalora_target_r,
    adalora_tinit,
    adalora_tfinal,
    adalora_delta_t
):
    """Adapter between the Gradio training tab and run_original_code_with_tuning.

    Returns three markdown strings: (training-config summary, baseline-BERT
    metrics pane, fine-tuned metrics pane).  Any exception is rendered into
    the first pane with its full traceback.
    """
    if file is None:
        return "請上傳 CSV 檔案", "", ""
    try:
        # Run the actual training pipeline.
        results = run_original_code_with_tuning(
            file_path=file.name,
            weight_multiplier=weight_mult,
            epochs=int(epochs),
            batch_size=int(batch_size),
            learning_rate=lr,
            warmup_steps=int(warmup),
            tuning_method=tuning_method,
            best_metric=best_metric,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            lora_modules=lora_modules,
            adalora_init_r=adalora_init_r,
            adalora_target_r=adalora_target_r,
            adalora_tinit=adalora_tinit,
            adalora_tfinal=adalora_tfinal,
            adalora_delta_t=adalora_delta_t
        )
        # Baseline metrics were computed inside the pipeline.
        baseline_results = results['baseline_results']
        # ==================== format the three output panes ====================
        # Pane 1: training configuration (top, full width).
        data_info = f"""
# 📊 資料資訊
## 🔧 訓練配置
- **微調方法**: {results['tuning_method']}
- **最佳化指標**: {results['best_metric']}
- **最佳指標值**: {results['best_metric_value']:.4f}
## ⚙️ 訓練參數
- **權重倍數**: {weight_mult}x
- **訓練輪數**: {epochs}
- **批次大小**: {batch_size}
- **學習率**: {lr}
- **Warmup Steps**: {warmup}
✅ 訓練完成!模型已儲存,可在「預測」頁面使用!
"""
        # Pane 2: untouched BERT (middle left).
        baseline_output = f"""
# 🔵 純 BERT (Baseline)
## 未經訓練
### 📈 評估指標
| 指標 | 數值 |
|------|------|
| **F1 Score** | {baseline_results['f1']:.4f} |
| **Accuracy** | {baseline_results['accuracy']:.4f} |
| **Precision** | {baseline_results['precision']:.4f} |
| **Recall** | {baseline_results['recall']:.4f} |
| **Sensitivity** | {baseline_results['sensitivity']:.4f} |
| **Specificity** | {baseline_results['specificity']:.4f} |
| **AUC** | {baseline_results['auc']:.4f} |
### 📈 混淆矩陣
| | 預測:存活 | 預測:死亡 |
|---|-----------|-----------|
| **實際:存活** | TN={baseline_results['tn']} | FP={baseline_results['fp']} |
| **實際:死亡** | FN={baseline_results['fn']} | TP={baseline_results['tp']} |
"""
        # Pane 3: fine-tuned BERT (middle right).
        finetuned_output = f"""
# 🟢 經微調 BERT
## {results['tuning_method']}
### 📈 評估指標
| 指標 | 數值 |
|------|------|
| **F1 Score** | {results['eval_f1']:.4f} |
| **Accuracy** | {results['eval_accuracy']:.4f} |
| **Precision** | {results['eval_precision']:.4f} |
| **Recall** | {results['eval_recall']:.4f} |
| **Sensitivity** | {results['eval_sensitivity']:.4f} |
| **Specificity** | {results['eval_specificity']:.4f} |
| **AUC** | {results['eval_auc']:.4f} |
### 📈 混淆矩陣
| | 預測:存活 | 預測:死亡 |
|---|-----------|-----------|
| **實際:存活** | TN={results['eval_tn']} | FP={results['eval_fp']} |
| **實際:死亡** | FN={results['eval_fn']} | TP={results['eval_tp']} |
"""
        return data_info, baseline_output, finetuned_output
    except Exception as e:
        # Top-level boundary: surface the traceback in the first pane.
        import traceback
        error_msg = f"❌ 錯誤:{str(e)}\n\n詳細錯誤訊息:\n{traceback.format_exc()}"
        return error_msg, "", ""
# Build the Gradio interface.
with gr.Blocks(title="BERT 完整訓練與預測平台", theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown("""
# 🏥 BERT 乳癌存活預測 - 完整訓練與預測平台
### 🌟 功能特色:
- 🎯 支援三種微調方法:Full Fine-tuning、LoRA、AdaLoRA
- 📊 自動比較有/無微調的表現差異
- 🎨 可選擇最佳化指標(F1、Accuracy、Precision、Recall 等)
- 🔮 訓練後可直接預測新病例
- 💾 自動儲存最佳模型
""")
    with gr.Tab("🎯 模型訓練"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📤 資料上傳")
                file_input = gr.File(
                    label="上傳 CSV 檔案",
                    file_types=[".csv"]
                )
                gr.Markdown("### 🔧 微調方法選擇")
                tuning_method = gr.Radio(
                    choices=["Full Fine-tuning", "LoRA", "AdaLoRA"],
                    value="Full Fine-tuning",
                    label="選擇微調方法",
                    info="不同的參數效率微調方法"
                )
                gr.Markdown("### 🎯 最佳模型選擇")
                best_metric = gr.Dropdown(
                    choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity", "auc"],
                    value="f1",
                    label="選擇最佳化指標",
                    info="模型會根據此指標選擇最佳檢查點,結果會特別顯示此指標"
                )
                gr.Markdown("### ⚙️ 基本訓練參數")
                weight_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.8,
                    step=0.1,
                    label="權重倍數",
                    info="調整死亡類別的權重"
                )
                epochs_input = gr.Number(
                    value=8,
                    label="訓練輪數 (Epochs)"
                )
                batch_size_input = gr.Number(
                    value=16,
                    label="批次大小 (Batch Size)"
                )
                lr_input = gr.Number(
                    value=2e-5,
                    label="學習率 (Learning Rate)"
                )
                warmup_input = gr.Number(
                    value=200,
                    label="Warmup Steps"
                )
                # LoRA-specific parameters (hidden until LoRA/AdaLoRA is selected).
                with gr.Column(visible=False) as lora_params:
                    gr.Markdown("### 🔷 LoRA 參數")
                    lora_r = gr.Slider(
                        minimum=4,
                        maximum=64,
                        value=16,
                        step=4,
                        label="LoRA Rank (r)",
                        info="低秩分解的秩"
                    )
                    lora_alpha = gr.Slider(
                        minimum=8,
                        maximum=128,
                        value=32,
                        step=8,
                        label="LoRA Alpha",
                        info="LoRA 縮放參數"
                    )
                    lora_dropout = gr.Slider(
                        minimum=0.0,
                        maximum=0.5,
                        value=0.1,
                        step=0.05,
                        label="LoRA Dropout"
                    )
                    lora_modules = gr.Textbox(
                        value="query,value",
                        label="目標模組",
                        info="用逗號分隔"
                    )
                # AdaLoRA-specific parameters (hidden by default).
                with gr.Column(visible=False) as adalora_params:
                    gr.Markdown("### 🔶 AdaLoRA 參數")
                    adalora_init_r = gr.Slider(
                        minimum=4,
                        maximum=64,
                        value=12,
                        step=4,
                        label="初始 Rank"
                    )
                    adalora_target_r = gr.Slider(
                        minimum=4,
                        maximum=64,
                        value=8,
                        step=4,
                        label="目標 Rank"
                    )
                    adalora_tinit = gr.Number(value=0, label="Tinit")
                    adalora_tfinal = gr.Number(value=0, label="Tfinal")
                    adalora_delta_t = gr.Number(value=1, label="Delta T")
                train_button = gr.Button(
                    "🚀 開始訓練",
                    variant="primary",
                    size="lg"
                )
            with gr.Column(scale=2):
                gr.Markdown("### 📊 訓練結果與比較")
                # Pane 1: data/training info (top, full width).
                data_info_output = gr.Markdown(
                    value="### 等待訓練...\n\n訓練完成後會顯示資料資訊和訓練配置",
                    label="資料資訊"
                )
                # Panes 2 and 3 side by side.
                with gr.Row():
                    # Pane 2: untouched BERT (left).
                    baseline_output = gr.Markdown(
                        value="### 純 BERT (未微調)\n等待訓練完成...",
                        label="純 BERT"
                    )
                    # Pane 3: fine-tuned BERT (right).
                    finetuned_output = gr.Markdown(
                        value="### 經微調 BERT\n等待訓練完成...",
                        label="經微調 BERT"
                    )
    with gr.Tab("🔮 模型預測"):
        gr.Markdown("""
### 使用訓練好的模型進行預測
選擇已訓練的模型,輸入病歷文本進行預測。會同時顯示未微調和微調模型的預測結果以供比較。
""")
        with gr.Row():
            with gr.Column():
                # Dropdown of registered models.
                model_dropdown = gr.Dropdown(
                    label="選擇模型",
                    choices=["請先訓練模型"],
                    value="請先訓練模型",
                    info="選擇要使用的已訓練模型"
                )
                refresh_button = gr.Button(
                    "🔄 重新整理模型列表",
                    size="sm"
                )
                text_input = gr.Textbox(
                    label="輸入病歷文本",
                    placeholder="請輸入患者的病歷描述(英文)...",
                    lines=10
                )
                gr.Examples(
                    examples=[
                        ["Patient is a 45-year-old female with stage II breast cancer, ER+/PR+/HER2-, underwent mastectomy and chemotherapy."],
                        ["65-year-old woman diagnosed with triple-negative breast cancer, stage III, with lymph node involvement."],
                        ["Early stage breast cancer detected, patient is 38 years old, no family history, scheduled for lumpectomy."],
                        ["Patient with advanced metastatic breast cancer, multiple organ involvement, poor prognosis."],
                        ["Young patient, BRCA1 positive, preventive double mastectomy performed, good recovery."]
                    ],
                    inputs=text_input,
                    label="範例文本(點擊使用)"
                )
                predict_button = gr.Button(
                    "🔮 開始預測",
                    variant="primary",
                    size="lg"
                )
            with gr.Column():
                gr.Markdown("### 預測結果比較")
                # Top pane: baseline prediction result.
                baseline_prediction_output = gr.Markdown(
                    label="未微調 BERT",
                    value="等待預測..."
                )
                # Bottom pane: fine-tuned prediction result.
                finetuned_prediction_output = gr.Markdown(
                    label="微調 BERT",
                    value="等待預測..."
                )
    with gr.Tab("📖 使用說明"):
        gr.Markdown("""
## 🔧 微調方法說明
| 方法 | 訓練速度 | 記憶體 | 效果 | 適用場景 |
|------|---------|--------|------|----------|
| **Full Fine-tuning** | 1x (基準) | 高 | 最佳 | 資源充足,要最佳效果 |
| **LoRA** | 3-5x 快 | 低 | 良好 | 資源有限,快速實驗 |
| **AdaLoRA** | 3-5x 快 | 低 | 良好 | 自動調整,平衡效果 |
## 📊 指標說明
- **F1 Score**: 精確率和召回率的調和平均,平衡指標
- **Accuracy**: 整體準確率
- **Precision**: 預測為死亡中的準確率
- **Recall/Sensitivity**: 實際死亡中被正確識別的比例
- **Specificity**: 實際存活中被正確識別的比例
- **AUC**: ROC 曲線下面積,整體分類能力
## 💡 使用建議
1. **資料不平衡嚴重**:增加權重倍數,使用 Recall 作為最佳化指標
2. **避免誤診**:使用 Precision 作為最佳化指標
3. **整體平衡**:使用 F1 Score 作為最佳化指標
4. **快速實驗**:使用 LoRA,減少 epochs
5. **最佳效果**:使用 Full Fine-tuning,8-10 epochs
## ⚠️ 注意事項
- 訓練時間依資料量和方法而定(5-20 分鐘)
- 建議至少 100 筆訓練資料
- GPU 會顯著加速訓練
- 預測結果僅供參考,實際醫療決策應由專業醫師判斷
""")

    # Show/hide the parameter panels matching the selected tuning method.
    def update_params_visibility(method):
        if method == "LoRA":
            return gr.update(visible=True), gr.update(visible=False)
        elif method == "AdaLoRA":
            # AdaLoRA reuses the LoRA alpha/dropout/modules fields, so both
            # panels are shown.
            return gr.update(visible=True), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    tuning_method.change(
        fn=update_params_visibility,
        inputs=[tuning_method],
        outputs=[lora_params, adalora_params]
    )
    # Training button — three outputs (info / baseline / fine-tuned panes).
    train_button.click(
        fn=train_wrapper,
        inputs=[
            file_input,
            tuning_method,
            weight_slider,
            epochs_input,
            batch_size_input,
            lr_input,
            warmup_input,
            best_metric,
            # LoRA parameters
            lora_r,
            lora_alpha,
            lora_dropout,
            lora_modules,
            # AdaLoRA parameters
            adalora_init_r,
            adalora_target_r,
            adalora_tinit,
            adalora_tfinal,
            adalora_delta_t
        ],
        outputs=[data_info_output, baseline_output, finetuned_output]  # three panes
    )

    # Refresh the model dropdown from the on-disk registry.
    def refresh_models():
        return gr.update(choices=get_available_models(), value=get_available_models()[0])

    refresh_button.click(
        fn=refresh_models,
        inputs=[],
        outputs=[model_dropdown]
    )
    # Prediction button — two outputs (baseline / fine-tuned).
    predict_button.click(
        fn=predict_text,
        inputs=[model_dropdown, text_input],
        outputs=[baseline_prediction_output, finetuned_prediction_output]
    )

if __name__ == "__main__":
    demo.launch()