# 1111 / app.py
# smartTranscend's picture
# Update app.py
# 39d09a2 verified
import gradio as gr
import pandas as pd
import torch
from torch import nn
from transformers import (
BertTokenizer,
BertForSequenceClassification,
TrainingArguments,
Trainer
)
from datasets import Dataset
from sklearn.metrics import (
accuracy_score,
precision_recall_fscore_support,
roc_auc_score,
confusion_matrix
)
import numpy as np
from datetime import datetime
import json
import os
import gc # 用於記憶體清理
# PEFT-related imports (LoRA and AdaLoRA). PEFT is optional: when it is not
# installed the flag below is False and the training code falls back to
# full fine-tuning.
try:
    from peft import (
        LoraConfig,
        AdaLoraConfig,
        get_peft_model,
        TaskType,
        PeftModel
    )
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
    print("⚠️ PEFT 未安裝,LoRA 和 AdaLoRA 功能將不可用")

# Use the GPU when one is available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Kept as-is in case anything still refers to this name.
_MODEL_PATH = None
# State of the most recent training run. The training function declares these
# three names as globals and assigns them after saving a model, so all three
# are initialised here to be defined before the first run.
LAST_MODEL_PATH = None
LAST_TOKENIZER = None
LAST_TUNING_METHOD = None
def evaluate_baseline_bert(eval_dataset, df_clean):
    """Score an un-fine-tuned ``bert-base-uncased`` classifier as a baseline.

    The pretrained checkpoint is loaded with a 2-label classification head
    and evaluated without any training on the user's data.

    Args:
        eval_dataset: Accepted for interface compatibility but not used; the
            evaluation split is rebuilt from ``df_clean`` below.
        df_clean: DataFrame with a ``text`` column and a binary ``label``
            column (0 / 1).

    Returns:
        dict with float entries ``f1``, ``accuracy``, ``precision``,
        ``recall``, ``sensitivity``, ``specificity``, ``auc`` and int
        confusion-matrix counts ``tp``, ``tn``, ``fp``, ``fn``.
    """
    print("\n" + "=" * 80)
    print("評估 Baseline 純 BERT(完全沒看過資料)")
    print("=" * 80)

    # Load the plain pretrained BERT.
    baseline_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    baseline_model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",
        num_labels=2
    ).to(device)
    baseline_model.eval()
    print(" ⚠️ 這個模型完全沒有使用您的資料訓練")

    # Rebuild the validation split. The tokenizer settings, test_size and
    # seed match the ones used by the training function in this file, so the
    # held-out rows should be the same ones the fine-tuned model is scored on.
    baseline_dataset = Dataset.from_pandas(df_clean[['text', 'label']])

    def baseline_preprocess(examples):
        return baseline_tokenizer(examples['text'], truncation=True, padding='max_length', max_length=256)

    baseline_tokenized = baseline_dataset.map(baseline_preprocess, batched=True)
    baseline_split = baseline_tokenized.train_test_split(test_size=0.2, seed=42)
    baseline_eval_dataset = baseline_split['test']

    # A Trainer is used only as a convenient batched prediction loop.
    baseline_trainer_args = TrainingArguments(
        output_dir='./temp_baseline',
        per_device_eval_batch_size=32,
        report_to="none"
    )
    baseline_trainer = Trainer(
        model=baseline_model,
        args=baseline_trainer_args,
    )

    # Evaluate the baseline.
    print("🔄 評估純 BERT...")
    predictions_output = baseline_trainer.predict(baseline_eval_dataset)
    all_preds = predictions_output.predictions.argmax(-1)
    all_labels = predictions_output.label_ids
    # Probability of the positive class (label 1), used for AUC.
    probs = torch.nn.functional.softmax(torch.tensor(predictions_output.predictions), dim=-1)[:, 1].numpy()

    # Compute metrics.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='binary', pos_label=1, zero_division=0
    )
    acc = accuracy_score(all_labels, all_preds)
    try:
        auc = roc_auc_score(all_labels, probs)
    except ValueError:
        # AUC is undefined when the split contains a single class; report 0.0.
        auc = 0.0

    # Passing labels=[0, 1] keeps the matrix 2x2 even when only one class
    # occurs, so the counts stay meaningful instead of collapsing to zeros.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    baseline_results = {
        'f1': float(f1),
        'accuracy': float(acc),
        'precision': float(precision),
        'recall': float(recall),
        'sensitivity': float(sensitivity),
        'specificity': float(specificity),
        'auc': float(auc),
        'tp': int(tp),
        'tn': int(tn),
        'fp': int(fp),
        'fn': int(fn)
    }
    print("✅ Baseline 評估完成")
    return baseline_results
def run_original_code_with_tuning(
    file_path,
    weight_multiplier,
    epochs,
    batch_size,
    learning_rate,
    warmup_steps,
    tuning_method,
    best_metric,
    # LoRA parameters
    lora_r,
    lora_alpha,
    lora_dropout,
    lora_modules,
    # AdaLoRA parameters
    adalora_init_r,
    adalora_target_r,
    adalora_tinit,
    adalora_tfinal,
    adalora_delta_t
):
    """Fine-tune BERT on an uploaded CSV and compare it with an untrained baseline.

    Steps: read the CSV, tokenize, split 80/20 (seed 42), train with a
    class-weighted cross-entropy loss using the chosen tuning method,
    evaluate, evaluate the plain-BERT baseline, print a comparison table,
    save model and tokenizer, and record the run in
    ``./saved_models_list.json``.

    Args:
        file_path: Path of a CSV file with a ``Text`` and a ``label`` column.
        weight_multiplier: Factor applied to the class-imbalance ratio to get
            the loss weight of label 1.
        epochs: Number of training epochs.
        batch_size: Per-device training batch size (evaluation uses twice this).
        learning_rate: Optimizer learning rate.
        warmup_steps: Number of warmup steps.
        tuning_method: ``"Full Fine-tuning"``, ``"LoRA"`` or ``"AdaLoRA"``.
            Anything else, or a PEFT method without PEFT installed, falls
            back to full fine-tuning.
        best_metric: Metric name used to pick the best checkpoint; unknown
            names fall back to ``"f1"``.
        lora_r, lora_alpha, lora_dropout: LoRA hyper-parameters (alpha and
            dropout are also passed to AdaLoRA).
        lora_modules: Comma-separated target module names; empty means
            ``query`` and ``value``.
        adalora_init_r, adalora_target_r, adalora_tinit, adalora_tfinal,
        adalora_delta_t: AdaLoRA hyper-parameters.

    Returns:
        The dict returned by ``trainer.evaluate()`` (``eval_*`` keys) extended
        with ``tuning_method``, ``best_metric``, ``best_metric_value``,
        ``baseline_results`` and ``model_path``.

    Raises:
        ValueError: If, after dropping rows with missing values, the data
            does not contain both label 0 and label 1.
    """
    global LAST_MODEL_PATH, LAST_TOKENIZER, LAST_TUNING_METHOD

    import inspect

    # ==================== Free memory (before training) ====================
    torch.cuda.empty_cache()
    gc.collect()
    print("🧹 記憶體已清空")

    # Read the uploaded file and keep only the two columns that are used.
    df_original = pd.read_csv(file_path)
    df_clean = pd.DataFrame({
        'text': df_original['Text'],
        'label': df_original['label']
    })
    # pandas reads an integer column that has missing values as float; cast
    # back to int after dropping them so the loss receives class indices.
    df_clean = df_clean.dropna().astype({'label': int})

    n_survive = int((df_clean['label'] == 0).sum())
    n_death = int((df_clean['label'] == 1).sum())
    if n_survive == 0 or n_death == 0:
        # Without both classes the imbalance ratio below is 0 or infinite and
        # the weighted loss is meaningless.
        raise ValueError(
            f"資料必須同時包含 label 0 與 label 1 的樣本 "
            f"(label 0: {n_survive} 筆, label 1: {n_death} 筆)"
        )

    print("\n" + "=" * 80)
    print(f"乳癌存活預測 BERT Fine-tuning - {tuning_method} 方法")
    print("=" * 80)
    print(f"開始時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"微調方法: {tuning_method}")
    print(f"最佳化指標: {best_metric}")
    print("=" * 80)

    # Load the tokenizer.
    print("\n📦 載入 BERT Tokenizer...")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    print("✅ Tokenizer 載入完成")

    def compute_metrics(pred):
        """Return the metric dict the Trainer reports for one evaluation pass."""
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        # Probability of the positive class (label 1), used for AUC.
        probs = torch.nn.functional.softmax(torch.tensor(pred.predictions), dim=-1)[:, 1].numpy()
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average='binary', pos_label=1, zero_division=0
        )
        acc = accuracy_score(labels, preds)
        try:
            auc = roc_auc_score(labels, probs)
        except ValueError:
            # AUC is undefined when only one class is present.
            auc = 0.0
        # labels=[0, 1] keeps the matrix 2x2 even if only one class occurs.
        tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        return {
            'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
            'auc': auc, 'sensitivity': sensitivity, 'specificity': specificity,
            'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
        }

    # ========================================================================
    # Step 1: prepare data (class proportions are left as they are)
    # ========================================================================
    print("\n" + "=" * 80)
    print("步驟 1:準備資料(保持原始比例)")
    print("=" * 80)
    print("\n原始資料分布:")
    print(f" 存活 (0): {n_survive} 筆 ({n_survive/len(df_clean)*100:.1f}%)")
    print(f" 死亡 (1): {n_death} 筆 ({n_death/len(df_clean)*100:.1f}%)")
    ratio = n_survive / n_death
    print(f" 不平衡比例: {ratio:.1f}:1")

    # ========================================================================
    # Step 2: tokenization
    # ========================================================================
    print("\n" + "=" * 80)
    print("步驟 2:Tokenization")
    print("=" * 80)
    dataset = Dataset.from_pandas(df_clean[['text', 'label']])

    def preprocess_function(examples):
        return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=256)

    tokenized_dataset = dataset.map(preprocess_function, batched=True)
    train_test_split = tokenized_dataset.train_test_split(test_size=0.2, seed=42)
    train_dataset = train_test_split['train']
    eval_dataset = train_test_split['test']
    print("\n✅ 資料集準備完成:")
    print(f" 訓練集: {len(train_dataset)} 筆")
    print(f" 驗證集: {len(eval_dataset)} 筆")

    # ========================================================================
    # Step 3: class weights
    # ========================================================================
    print("\n" + "=" * 80)
    print(f"步驟 3:設定類別權重({weight_multiplier}x 倍數)")
    print("=" * 80)
    weight_0 = 1.0
    weight_1 = ratio * weight_multiplier
    print("\n權重設定:")
    print(f" 倍數: {weight_multiplier}x")
    print(f" 存活類權重: {weight_0:.3f}")
    print(f" 死亡類權重: {weight_1:.3f} (= {ratio:.1f} × {weight_multiplier})")
    class_weights = torch.tensor([weight_0, weight_1], dtype=torch.float).to(device)

    # ========================================================================
    # Step 4: train the model with the selected tuning method
    # ========================================================================
    print("\n" + "=" * 80)
    print(f"步驟 4:訓練 {tuning_method} BERT 模型")
    print("=" * 80)
    print(f"\n🔄 初始化模型 ({tuning_method})...")

    # Base model shared by all tuning methods.
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2, problem_type="single_label_classification"
    )

    def parse_target_modules():
        """Split the comma-separated module list, defaulting to query/value."""
        names = [m.strip() for m in lora_modules.split(",")] if lora_modules else []
        names = [m for m in names if m]
        return names or ["query", "value"]

    if tuning_method == "Full Fine-tuning":
        model = model.to(device)
        print("✅ 使用完整 Fine-tuning(所有參數可訓練)")
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        all_params = sum(p.numel() for p in model.parameters())
        print(f" 可訓練參數: {trainable_params:,} / {all_params:,} ({100 * trainable_params / all_params:.2f}%)")
    elif tuning_method == "LoRA" and PEFT_AVAILABLE:
        target_modules = parse_target_modules()
        peft_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            r=int(lora_r),
            lora_alpha=int(lora_alpha),
            lora_dropout=float(lora_dropout),
            target_modules=target_modules
        )
        model = get_peft_model(model, peft_config)
        model = model.to(device)
        print("✅ 使用 LoRA 微調")
        print(f" LoRA rank (r): {lora_r}")
        print(f" LoRA alpha: {lora_alpha}")
        print(f" LoRA dropout: {lora_dropout}")
        print(f" 目標模組: {target_modules}")
        model.print_trainable_parameters()
    elif tuning_method == "AdaLoRA" and PEFT_AVAILABLE:
        target_modules = parse_target_modules()
        peft_config = AdaLoraConfig(
            task_type=TaskType.SEQ_CLS,
            init_r=int(adalora_init_r),
            target_r=int(adalora_target_r),
            tinit=int(adalora_tinit),
            tfinal=int(adalora_tfinal),
            deltaT=int(adalora_delta_t),
            lora_alpha=int(lora_alpha),
            lora_dropout=float(lora_dropout),
            target_modules=target_modules
        )
        model = get_peft_model(model, peft_config)
        model = model.to(device)
        print("✅ 使用 AdaLoRA 微調")
        print(f" 初始 rank: {adalora_init_r}")
        print(f" 目標 rank: {adalora_target_r}")
        print(f" Tinit: {adalora_tinit}, Tfinal: {adalora_tfinal}, DeltaT: {adalora_delta_t}")
        model.print_trainable_parameters()
    else:
        # Unknown method or PEFT missing: fall back to full fine-tuning.
        model = model.to(device)
        print("⚠️ PEFT 未安裝或方法無效,使用 Full Fine-tuning")

    class WeightedTrainer(Trainer):
        """Trainer whose loss is cross-entropy weighted by ``class_weights``."""

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # **kwargs absorbs extra keyword arguments that newer Trainer
            # versions pass (e.g. num_items_in_batch); they are not needed here.
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            loss_fct = nn.CrossEntropyLoss(weight=class_weights)
            loss = loss_fct(outputs.logits.view(-1, 2), labels.view(-1))
            return (loss, outputs) if return_outputs else loss

    # Map the UI metric name to the key produced by compute_metrics.
    metric_map = {
        "f1": "f1",
        "accuracy": "accuracy",
        "precision": "precision",
        "recall": "recall",
        "sensitivity": "sensitivity",
        "specificity": "specificity",
        "auc": "auc"
    }
    metric_key = metric_map.get(best_metric, "f1")

    # Newer transformers releases name this argument `eval_strategy`; older
    # ones only know `evaluation_strategy`. Use whichever the installed
    # version accepts.
    accepted_args = inspect.signature(TrainingArguments.__init__).parameters
    strategy_kwarg = "eval_strategy" if "eval_strategy" in accepted_args else "evaluation_strategy"

    training_args = TrainingArguments(
        output_dir='./results_weight',
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size*2,
        warmup_steps=warmup_steps,
        weight_decay=0.01,
        learning_rate=learning_rate,
        logging_steps=50,
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model=metric_key,  # checkpoint selection metric
        report_to="none",
        greater_is_better=True,  # every supported metric is higher-is-better
        **{strategy_kwarg: "epoch"}
    )
    trainer = WeightedTrainer(
        model=model, args=training_args,
        train_dataset=train_dataset, eval_dataset=eval_dataset,
        compute_metrics=compute_metrics
    )
    print(f"\n🚀 開始訓練({epochs} epochs)...")
    print(f" 最佳化指標: {best_metric}")
    print("-" * 80)
    trainer.train()
    print("\n✅ 模型訓練完成!")

    # Evaluate the fine-tuned model.
    print("\n📊 評估模型...")
    results = trainer.evaluate()
    print(f"\n{tuning_method} BERT ({weight_multiplier}x 權重) 表現:")
    print(f" F1 Score: {results['eval_f1']:.4f}")
    print(f" Accuracy: {results['eval_accuracy']:.4f}")
    print(f" Precision: {results['eval_precision']:.4f}")
    print(f" Recall: {results['eval_recall']:.4f}")
    print(f" Sensitivity: {results['eval_sensitivity']:.4f}")
    print(f" Specificity: {results['eval_specificity']:.4f}")
    print(f" AUC: {results['eval_auc']:.4f}")
    print(f" 混淆矩陣: Tp={results['eval_tp']}, Tn={results['eval_tn']}, "
          f"Fp={results['eval_fp']}, Fn={results['eval_fn']}")

    # ========================================================================
    # Step 5: baseline comparison (plain BERT)
    # ========================================================================
    print("\n" + "=" * 80)
    print("步驟 5:Baseline 比較 - 純 BERT(完全沒看過資料)")
    print("=" * 80)
    baseline_results = evaluate_baseline_bert(eval_dataset, df_clean)

    # ========================================================================
    # Step 6: comparison table
    # ========================================================================
    print("\n" + "=" * 80)
    print(f"📊 【對比結果】純 BERT vs {tuning_method} BERT")
    print("=" * 80)
    print("\n📋 詳細比較表:")
    print("-" * 100)
    print(f"{'指標':<15} {'純 BERT':<20} {tuning_method:<20} {'改善幅度':<20}")
    print("-" * 100)
    metrics_to_compare = [
        ('F1 Score', 'f1', 'eval_f1'),
        ('Accuracy', 'accuracy', 'eval_accuracy'),
        ('Precision', 'precision', 'eval_precision'),
        ('Recall', 'recall', 'eval_recall'),
        ('Sensitivity', 'sensitivity', 'eval_sensitivity'),
        ('Specificity', 'specificity', 'eval_specificity'),
        ('AUC', 'auc', 'eval_auc')
    ]
    for name, baseline_key, finetuned_key in metrics_to_compare:
        baseline_val = baseline_results[baseline_key]
        finetuned_val = results[finetuned_key]
        # Relative improvement in percent; 0 when the baseline is 0 to avoid
        # dividing by zero.
        improvement = ((finetuned_val - baseline_val) / baseline_val * 100) if baseline_val > 0 else 0
        print(f"{name:<15} {baseline_val:<20.4f} {finetuned_val:<20.4f} {improvement:>+18.1f}%")
    print("-" * 100)

    # Save model and tokenizer. The same calls are used for PEFT and regular
    # models.
    save_dir = f'./breast_cancer_bert_{tuning_method.lower().replace(" ", "_")}'
    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)

    # Record the run so the prediction page can offer it.
    best_metric_value = results[f'eval_{metric_key}']
    model_info = {
        'model_path': save_dir,
        'tuning_method': tuning_method,
        'best_metric': best_metric,
        'best_metric_value': float(best_metric_value),
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'weight_multiplier': weight_multiplier,
        'epochs': epochs
    }
    models_list_file = './saved_models_list.json'
    if os.path.exists(models_list_file):
        with open(models_list_file, 'r', encoding='utf-8') as f:
            models_list = json.load(f)
    else:
        models_list = []
    # save_dir depends only on the tuning method, so this run has just
    # overwritten any earlier model stored there; drop those stale entries.
    models_list = [m for m in models_list if m.get('model_path') != save_dir]
    models_list.append(model_info)
    with open(models_list_file, 'w', encoding='utf-8') as f:
        json.dump(models_list, f, indent=2)

    # Remember the latest run in module-level state.
    LAST_MODEL_PATH = save_dir
    LAST_TOKENIZER = tokenizer
    LAST_TUNING_METHOD = tuning_method
    print(f"\n💾 模型已儲存至: {save_dir}")
    print("\n" + "=" * 80)
    print("🎉 訓練完成!")
    print("=" * 80)
    print(f"完成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # ==================== Free memory (after training) ====================
    del model
    del trainer
    torch.cuda.empty_cache()
    gc.collect()
    print("🧹 訓練後記憶體已清空")

    # Attach run information to the returned metrics.
    results['tuning_method'] = tuning_method
    results['best_metric'] = best_metric
    results['best_metric_value'] = best_metric_value
    results['baseline_results'] = baseline_results
    results['model_path'] = save_dir
    return results
def _classify_single_text(model, tokenizer, text):
    """Run one text through ``model`` and return its class probabilities.

    Returns:
        Tuple ``(predicted_class, confidence, prob_class_0, prob_class_1)``
        where ``confidence`` is the probability of the predicted class.
    """
    inputs = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=256,
        return_tensors='pt'
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    pred_class = probs.argmax(-1).item()
    return pred_class, probs[0][pred_class].item(), probs[0][0].item(), probs[0][1].item()


def predict_text(model_choice, text_input):
    """Predict with both the un-fine-tuned BERT and a selected trained model.

    Args:
        model_choice: Dropdown entry in the form produced by
            ``get_available_models`` (``"路徑: <path> | 方法: ... | 時間: ..."``),
            the placeholder ``"請先訓練模型"``, or an empty value.
        text_input: Text to classify.

    Returns:
        Tuple ``(baseline_markdown, finetuned_markdown)``. For empty text both
        entries ask for input; on any exception both entries carry the error
        message with its traceback.
    """
    if not text_input or text_input.strip() == "":
        return "請輸入文本", "請輸入文本"
    try:
        # ==================== Un-fine-tuned BERT ====================
        print("\n使用未微調 BERT 預測...")
        baseline_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        baseline_model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased",
            num_labels=2
        ).to(device)
        baseline_model.eval()

        (baseline_pred_class, baseline_confidence,
         baseline_prob_survive, baseline_prob_death) = _classify_single_text(
            baseline_model, baseline_tokenizer, text_input
        )
        baseline_result = "存活" if baseline_pred_class == 0 else "死亡"
        baseline_output = f"""
# 🔵 未微調 BERT 預測結果
## 預測類別: **{baseline_result}**
## 信心度: **{baseline_confidence:.1%}**
## 機率分布:
- 🟢 **存活機率**: {baseline_prob_survive:.2%}
- 🔴 **死亡機率**: {baseline_prob_death:.2%}
---
**說明**: 此為原始 BERT 模型,未經任何領域資料訓練
"""
        # Release the baseline model before loading the fine-tuned one.
        del baseline_model
        del baseline_tokenizer
        torch.cuda.empty_cache()

        # ==================== Fine-tuned BERT ====================
        # An empty selection is treated like the placeholder entry instead of
        # failing on .split() below.
        if not model_choice or model_choice == "請先訓練模型":
            finetuned_output = """
# 🟢 微調 BERT 預測結果
❌ 尚未訓練任何模型,請先在「模型訓練」頁面訓練模型
"""
            return baseline_output, finetuned_output

        # The path is the first " | "-separated field of the dropdown entry.
        model_path = model_choice.split(" | ")[0].replace("路徑: ", "")

        # Look the model up in the registry file.
        with open('./saved_models_list.json', 'r') as f:
            models_list = json.load(f)
        # Several runs can share one path (a later run overwrites the
        # directory), so search newest-first to describe what is on disk.
        selected_model_info = next(
            (info for info in reversed(models_list) if info['model_path'] == model_path),
            None
        )
        if selected_model_info is None:
            finetuned_output = f"""
# 🟢 微調 BERT 預測結果
❌ 找不到模型:{model_path}
"""
            return baseline_output, finetuned_output

        print(f"\n使用微調模型: {model_path}")
        finetuned_tokenizer = BertTokenizer.from_pretrained(model_path)

        tuning_method = selected_model_info['tuning_method']
        if tuning_method in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE:
            # PEFT checkpoints are loaded on top of the base model.
            base_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
            finetuned_model = PeftModel.from_pretrained(base_model, model_path)
            finetuned_model = finetuned_model.to(device)
        else:
            # Regular checkpoint.
            finetuned_model = BertForSequenceClassification.from_pretrained(model_path).to(device)
        finetuned_model.eval()

        (finetuned_pred_class, finetuned_confidence,
         finetuned_prob_survive, finetuned_prob_death) = _classify_single_text(
            finetuned_model, finetuned_tokenizer, text_input
        )
        finetuned_result = "存活" if finetuned_pred_class == 0 else "死亡"
        finetuned_output = f"""
# 🟢 微調 BERT 預測結果
## 預測類別: **{finetuned_result}**
## 信心度: **{finetuned_confidence:.1%}**
## 機率分布:
- 🟢 **存活機率**: {finetuned_prob_survive:.2%}
- 🔴 **死亡機率**: {finetuned_prob_death:.2%}
---
### 模型資訊:
- **微調方法**: {selected_model_info['tuning_method']}
- **最佳化指標**: {selected_model_info['best_metric']}
- **訓練時間**: {selected_model_info['timestamp']}
- **模型路徑**: {model_path}
---
**注意**: 此預測僅供參考,實際醫療決策應由專業醫師判斷。
"""
        # Release the fine-tuned model.
        del finetuned_model
        del finetuned_tokenizer
        torch.cuda.empty_cache()
        return baseline_output, finetuned_output
    except Exception as e:
        # UI boundary: show the failure in both panels instead of raising.
        import traceback
        error_msg = f"❌ 預測錯誤:{str(e)}\n\n詳細錯誤訊息:\n{traceback.format_exc()}"
        return error_msg, error_msg
def get_available_models():
    """Return the dropdown entries for every model recorded on disk.

    Reads ``./saved_models_list.json``. Each entry is formatted as
    ``"路徑: <model_path> | 方法: <tuning_method> | 時間: <timestamp>"``.

    Returns:
        A list of entry strings, or ``["請先訓練模型"]`` when the file is
        missing, unreadable, not valid JSON, not a list, or empty.
    """
    models_list_file = './saved_models_list.json'
    no_models = ["請先訓練模型"]
    if not os.path.exists(models_list_file):
        return no_models
    try:
        with open(models_list_file, 'r', encoding='utf-8') as f:
            models_list = json.load(f)
    except (OSError, ValueError):
        # A truncated or corrupt registry (json.JSONDecodeError is a
        # ValueError) is treated as "nothing trained yet" so the UI keeps
        # working.
        return no_models
    if not isinstance(models_list, list) or not models_list:
        return no_models
    # Format one dropdown entry per recorded model.
    return [
        f"路徑: {model_info['model_path']} | 方法: {model_info['tuning_method']} | 時間: {model_info['timestamp']}"
        for model_info in models_list
    ]
# ============================================================================
# Gradio interface - training results are rendered in three panels
# ============================================================================
def train_wrapper(
    file,
    tuning_method,
    weight_mult,
    epochs,
    batch_size,
    lr,
    warmup,
    best_metric,
    lora_r,
    lora_alpha,
    lora_dropout,
    lora_modules,
    adalora_init_r,
    adalora_target_r,
    adalora_tinit,
    adalora_tfinal,
    adalora_delta_t
):
    """Run training from the Gradio form and format the result as Markdown.

    Args:
        file: Uploaded CSV, either an object with a ``.name`` path attribute
            or a plain path string; ``None`` when nothing was uploaded.
        tuning_method, weight_mult, epochs, batch_size, lr, warmup,
        best_metric, lora_*, adalora_*: Form values forwarded to
            ``run_original_code_with_tuning`` (``epochs``, ``batch_size`` and
            ``warmup`` are cast to ``int`` first).

    Returns:
        Tuple of three Markdown strings: training configuration, baseline
        metrics, fine-tuned metrics. When no file is given, or when training
        raises, the first entry carries the message and the other two are
        empty strings.
    """
    if file is None:
        return "請上傳 CSV 檔案", "", ""
    try:
        # Depending on the Gradio version/configuration the upload may arrive
        # as a path string or as a file-like object exposing `.name`.
        file_path = file if isinstance(file, str) else file.name

        # Run the training pipeline.
        results = run_original_code_with_tuning(
            file_path=file_path,
            weight_multiplier=weight_mult,
            epochs=int(epochs),
            batch_size=int(batch_size),
            learning_rate=lr,
            warmup_steps=int(warmup),
            tuning_method=tuning_method,
            best_metric=best_metric,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            lora_modules=lora_modules,
            adalora_init_r=adalora_init_r,
            adalora_target_r=adalora_target_r,
            adalora_tinit=adalora_tinit,
            adalora_tfinal=adalora_tfinal,
            adalora_delta_t=adalora_delta_t
        )
        baseline_results = results['baseline_results']

        # Panel 1: training configuration (top, full width).
        data_info = f"""
# 📊 資料資訊
## 🔧 訓練配置
- **微調方法**: {results['tuning_method']}
- **最佳化指標**: {results['best_metric']}
- **最佳指標值**: {results['best_metric_value']:.4f}
## ⚙️ 訓練參數
- **權重倍數**: {weight_mult}x
- **訓練輪數**: {epochs}
- **批次大小**: {batch_size}
- **學習率**: {lr}
- **Warmup Steps**: {warmup}
✅ 訓練完成!模型已儲存,可在「預測」頁面使用!
"""
        # Panel 2: plain BERT baseline (middle left).
        baseline_output = f"""
# 🔵 純 BERT (Baseline)
## 未經訓練
### 📈 評估指標
| 指標 | 數值 |
|------|------|
| **F1 Score** | {baseline_results['f1']:.4f} |
| **Accuracy** | {baseline_results['accuracy']:.4f} |
| **Precision** | {baseline_results['precision']:.4f} |
| **Recall** | {baseline_results['recall']:.4f} |
| **Sensitivity** | {baseline_results['sensitivity']:.4f} |
| **Specificity** | {baseline_results['specificity']:.4f} |
| **AUC** | {baseline_results['auc']:.4f} |
### 📈 混淆矩陣
| | 預測:存活 | 預測:死亡 |
|---|-----------|-----------|
| **實際:存活** | TN={baseline_results['tn']} | FP={baseline_results['fp']} |
| **實際:死亡** | FN={baseline_results['fn']} | TP={baseline_results['tp']} |
"""
        # Panel 3: fine-tuned BERT (middle right).
        finetuned_output = f"""
# 🟢 經微調 BERT
## {results['tuning_method']}
### 📈 評估指標
| 指標 | 數值 |
|------|------|
| **F1 Score** | {results['eval_f1']:.4f} |
| **Accuracy** | {results['eval_accuracy']:.4f} |
| **Precision** | {results['eval_precision']:.4f} |
| **Recall** | {results['eval_recall']:.4f} |
| **Sensitivity** | {results['eval_sensitivity']:.4f} |
| **Specificity** | {results['eval_specificity']:.4f} |
| **AUC** | {results['eval_auc']:.4f} |
### 📈 混淆矩陣
| | 預測:存活 | 預測:死亡 |
|---|-----------|-----------|
| **實際:存活** | TN={results['eval_tn']} | FP={results['eval_fp']} |
| **實際:死亡** | FN={results['eval_fn']} | TP={results['eval_tp']} |
"""
        return data_info, baseline_output, finetuned_output
    except Exception as e:
        # UI boundary: report the failure in the first panel instead of raising.
        import traceback
        error_msg = f"❌ 錯誤:{str(e)}\n\n詳細錯誤訊息:\n{traceback.format_exc()}"
        return error_msg, "", ""
# Build the Gradio interface: a training tab, a prediction tab and a help tab,
# followed by the event wiring and the script entry point.
with gr.Blocks(title="BERT 完整訓練與預測平台", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🏥 BERT 乳癌存活預測 - 完整訓練與預測平台
### 🌟 功能特色:
- 🎯 支援三種微調方法:Full Fine-tuning、LoRA、AdaLoRA
- 📊 自動比較有/無微調的表現差異
- 🎨 可選擇最佳化指標(F1、Accuracy、Precision、Recall 等)
- 🔮 訓練後可直接預測新病例
- 💾 自動儲存最佳模型
""")
    with gr.Tab("🎯 模型訓練"):
        with gr.Row():
            # Left column: inputs and hyper-parameters.
            with gr.Column(scale=1):
                gr.Markdown("### 📤 資料上傳")
                file_input = gr.File(
                    label="上傳 CSV 檔案",
                    file_types=[".csv"]
                )
                gr.Markdown("### 🔧 微調方法選擇")
                tuning_method = gr.Radio(
                    choices=["Full Fine-tuning", "LoRA", "AdaLoRA"],
                    value="Full Fine-tuning",
                    label="選擇微調方法",
                    info="不同的參數效率微調方法"
                )
                gr.Markdown("### 🎯 最佳模型選擇")
                best_metric = gr.Dropdown(
                    choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity", "auc"],
                    value="f1",
                    label="選擇最佳化指標",
                    info="模型會根據此指標選擇最佳檢查點,結果會特別顯示此指標"
                )
                gr.Markdown("### ⚙️ 基本訓練參數")
                weight_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.8,
                    step=0.1,
                    label="權重倍數",
                    info="調整死亡類別的權重"
                )
                epochs_input = gr.Number(
                    value=8,
                    label="訓練輪數 (Epochs)"
                )
                batch_size_input = gr.Number(
                    value=16,
                    label="批次大小 (Batch Size)"
                )
                lr_input = gr.Number(
                    value=2e-5,
                    label="學習率 (Learning Rate)"
                )
                warmup_input = gr.Number(
                    value=200,
                    label="Warmup Steps"
                )
                # LoRA-specific parameters (hidden by default).
                with gr.Column(visible=False) as lora_params:
                    gr.Markdown("### 🔷 LoRA 參數")
                    lora_r = gr.Slider(
                        minimum=4,
                        maximum=64,
                        value=16,
                        step=4,
                        label="LoRA Rank (r)",
                        info="低秩分解的秩"
                    )
                    lora_alpha = gr.Slider(
                        minimum=8,
                        maximum=128,
                        value=32,
                        step=8,
                        label="LoRA Alpha",
                        info="LoRA 縮放參數"
                    )
                    lora_dropout = gr.Slider(
                        minimum=0.0,
                        maximum=0.5,
                        value=0.1,
                        step=0.05,
                        label="LoRA Dropout"
                    )
                    lora_modules = gr.Textbox(
                        value="query,value",
                        label="目標模組",
                        info="用逗號分隔"
                    )
                # AdaLoRA-specific parameters (hidden by default).
                with gr.Column(visible=False) as adalora_params:
                    gr.Markdown("### 🔶 AdaLoRA 參數")
                    adalora_init_r = gr.Slider(
                        minimum=4,
                        maximum=64,
                        value=12,
                        step=4,
                        label="初始 Rank"
                    )
                    adalora_target_r = gr.Slider(
                        minimum=4,
                        maximum=64,
                        value=8,
                        step=4,
                        label="目標 Rank"
                    )
                    adalora_tinit = gr.Number(value=0, label="Tinit")
                    adalora_tfinal = gr.Number(value=0, label="Tfinal")
                    adalora_delta_t = gr.Number(value=1, label="Delta T")
                train_button = gr.Button(
                    "🚀 開始訓練",
                    variant="primary",
                    size="lg"
                )
            # Right column: training results.
            with gr.Column(scale=2):
                gr.Markdown("### 📊 訓練結果與比較")
                # Panel 1: training configuration (top, full width).
                data_info_output = gr.Markdown(
                    value="### 等待訓練...\n\n訓練完成後會顯示資料資訊和訓練配置",
                    label="資料資訊"
                )
                # Panels 2 and 3: shown side by side below panel 1.
                with gr.Row():
                    # Panel 2: plain BERT (left).
                    baseline_output = gr.Markdown(
                        value="### 純 BERT (未微調)\n等待訓練完成...",
                        label="純 BERT"
                    )
                    # Panel 3: fine-tuned BERT (right).
                    finetuned_output = gr.Markdown(
                        value="### 經微調 BERT\n等待訓練完成...",
                        label="經微調 BERT"
                    )
    with gr.Tab("🔮 模型預測"):
        gr.Markdown("""
### 使用訓練好的模型進行預測
選擇已訓練的模型,輸入病歷文本進行預測。會同時顯示未微調和微調模型的預測結果以供比較。
""")
        with gr.Row():
            with gr.Column():
                # Model selection dropdown. It starts with the placeholder
                # entry; the refresh button below reloads the list.
                model_dropdown = gr.Dropdown(
                    label="選擇模型",
                    choices=["請先訓練模型"],
                    value="請先訓練模型",
                    info="選擇要使用的已訓練模型"
                )
                refresh_button = gr.Button(
                    "🔄 重新整理模型列表",
                    size="sm"
                )
                text_input = gr.Textbox(
                    label="輸入病歷文本",
                    placeholder="請輸入患者的病歷描述(英文)...",
                    lines=10
                )
                gr.Examples(
                    examples=[
                        ["Patient is a 45-year-old female with stage II breast cancer, ER+/PR+/HER2-, underwent mastectomy and chemotherapy."],
                        ["65-year-old woman diagnosed with triple-negative breast cancer, stage III, with lymph node involvement."],
                        ["Early stage breast cancer detected, patient is 38 years old, no family history, scheduled for lumpectomy."],
                        ["Patient with advanced metastatic breast cancer, multiple organ involvement, poor prognosis."],
                        ["Young patient, BRCA1 positive, preventive double mastectomy performed, good recovery."]
                    ],
                    inputs=text_input,
                    label="範例文本(點擊使用)"
                )
                predict_button = gr.Button(
                    "🔮 開始預測",
                    variant="primary",
                    size="lg"
                )
            with gr.Column():
                gr.Markdown("### 預測結果比較")
                # Upper box: prediction of the un-fine-tuned BERT.
                baseline_prediction_output = gr.Markdown(
                    label="未微調 BERT",
                    value="等待預測..."
                )
                # Lower box: prediction of the fine-tuned BERT.
                finetuned_prediction_output = gr.Markdown(
                    label="微調 BERT",
                    value="等待預測..."
                )
    with gr.Tab("📖 使用說明"):
        gr.Markdown("""
## 🔧 微調方法說明
| 方法 | 訓練速度 | 記憶體 | 效果 | 適用場景 |
|------|---------|--------|------|----------|
| **Full Fine-tuning** | 1x (基準) | 高 | 最佳 | 資源充足,要最佳效果 |
| **LoRA** | 3-5x 快 | 低 | 良好 | 資源有限,快速實驗 |
| **AdaLoRA** | 3-5x 快 | 低 | 良好 | 自動調整,平衡效果 |
## 📊 指標說明
- **F1 Score**: 精確率和召回率的調和平均,平衡指標
- **Accuracy**: 整體準確率
- **Precision**: 預測為死亡中的準確率
- **Recall/Sensitivity**: 實際死亡中被正確識別的比例
- **Specificity**: 實際存活中被正確識別的比例
- **AUC**: ROC 曲線下面積,整體分類能力
## 💡 使用建議
1. **資料不平衡嚴重**:增加權重倍數,使用 Recall 作為最佳化指標
2. **避免誤診**:使用 Precision 作為最佳化指標
3. **整體平衡**:使用 F1 Score 作為最佳化指標
4. **快速實驗**:使用 LoRA,減少 epochs
5. **最佳效果**:使用 Full Fine-tuning,8-10 epochs
## ⚠️ 注意事項
- 訓練時間依資料量和方法而定(5-20 分鐘)
- 建議至少 100 筆訓練資料
- GPU 會顯著加速訓練
- 預測結果僅供參考,實際醫療決策應由專業醫師判斷
""")

    # Show/hide the parameter groups that apply to the selected tuning method.
    def update_params_visibility(method):
        # Returns one visibility update for each of (lora_params, adalora_params).
        if method == "LoRA":
            return gr.update(visible=True), gr.update(visible=False)
        elif method == "AdaLoRA":
            # AdaLoRA shows the LoRA group too: the training code passes LoRA
            # alpha, dropout and target modules to the AdaLoRA config as well.
            return gr.update(visible=True), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    tuning_method.change(
        fn=update_params_visibility,
        inputs=[tuning_method],
        outputs=[lora_params, adalora_params]
    )

    # Train button: the wrapper returns three Markdown strings, one per panel.
    train_button.click(
        fn=train_wrapper,
        inputs=[
            file_input,
            tuning_method,
            weight_slider,
            epochs_input,
            batch_size_input,
            lr_input,
            warmup_input,
            best_metric,
            # LoRA parameters
            lora_r,
            lora_alpha,
            lora_dropout,
            lora_modules,
            # AdaLoRA parameters
            adalora_init_r,
            adalora_target_r,
            adalora_tinit,
            adalora_tfinal,
            adalora_delta_t
        ],
        outputs=[data_info_output, baseline_output, finetuned_output]  # three outputs
    )

    # Refresh button: reload the model list and select its first entry.
    def refresh_models():
        return gr.update(choices=get_available_models(), value=get_available_models()[0])

    refresh_button.click(
        fn=refresh_models,
        inputs=[],
        outputs=[model_dropdown]
    )

    # Predict button: two outputs, un-fine-tuned and fine-tuned predictions.
    predict_button.click(
        fn=predict_text,
        inputs=[model_dropdown, text_input],
        outputs=[baseline_prediction_output, finetuned_prediction_output]
    )

# Start the web app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()