# smartTranscend's picture
# Rename 1031 combine_Wen.py to app.py
# f7f2a8b verified
import gradio as gr
import pandas as pd
import torch
from transformers import (
BertTokenizer, BertForSequenceClassification,
AutoTokenizer, AutoModelForSequenceClassification,
TrainingArguments, Trainer, DataCollatorWithPadding
)
from peft import (
LoraConfig, AdaLoraConfig, AdaptionPromptConfig, PrefixTuningConfig,
get_peft_model, TaskType, prepare_model_for_kbit_training
)
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from torch import nn
import os
from datetime import datetime
import gc
import numpy as np
# Silence HF tokenizers fork warnings when DataLoader workers are spawned.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Cap the CUDA allocator split size to reduce memory fragmentation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
# Disable the cuDNN autotuner (input lengths vary between runs).
torch.backends.cudnn.benchmark = False
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# Global state shared by the Gradio callbacks below.
trained_models = {}  # model_id -> {'model', 'tokenizer', 'results', 'baseline', 'config', ...}
model_counter = 0  # number of fine-tuned models produced this session
baseline_results = {}  # baseline metrics keyed by "<base_model>_baseline"
baseline_model_cache = {}  # un-finetuned model objects, keyed by HF model name
baseline_performance_cache = {}  # cached baseline evaluation metrics
second_stage_models = {}  # models produced by second-stage fine-tuning
def calculate_improvement(baseline_val, finetuned_val):
    """Return the percentage change from baseline_val to finetuned_val.

    A zero baseline yields +inf when the fine-tuned value improved, and 0.0
    when it did not, avoiding a ZeroDivisionError.
    """
    if baseline_val == 0:
        return float('inf') if finetuned_val > 0 else 0.0
    return (finetuned_val - baseline_val) / baseline_val * 100
def format_improve(val):
    """Render an improvement percentage as a signed one-decimal string.

    An infinite value (zero baseline) is shown as "N/A (baseline=0)".
    """
    return "N/A (baseline=0)" if val == float('inf') else f"{val:+.1f}%"
def thorough_memory_cleanup():
    """Free Python-level garbage and, when CUDA is present, release cached
    GPU memory as aggressively as possible."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
def compute_metrics(pred):
    """Trainer metric hook: compute binary-classification metrics.

    Args:
        pred: EvalPrediction-like object with .label_ids (true labels) and
            .predictions (logits, argmax'd over the last axis).

    Returns:
        dict with accuracy/f1/precision/recall/sensitivity/specificity and the
        raw confusion-matrix counts tp/tn/fp/fn; all zeros on any failure.
    """
    try:
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average='binary', pos_label=1, zero_division=0)
        acc = accuracy_score(labels, preds)
        # BUG FIX: pin the label set so the matrix is always 2x2. Previously a
        # degenerate eval set containing a single class produced a 1x1 matrix
        # and all four counts were silently reported as 0.
        cm = confusion_matrix(labels, preds, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        return {
            'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
            'sensitivity': sensitivity, 'specificity': specificity,
            'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
        }
    except Exception as e:
        # Never crash evaluation; surface the problem and return zeroed metrics.
        print(f"Error in compute_metrics: {e}")
        return {
            'accuracy': 0, 'f1': 0, 'precision': 0, 'recall': 0,
            'sensitivity': 0, 'specificity': 0, 'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0
        }
def evaluate_baseline(model, tokenizer, test_dataset, device, is_llama=False):
    """Evaluate a sequence-classification model on a tokenized dataset.

    Args:
        model: model whose forward returns an object with .logits.
        tokenizer: unused here; kept for interface compatibility with callers.
        test_dataset: iterable of dicts with 'input_ids', 'attention_mask',
            'label' (lists or tensors).
        device: target device for inputs when is_llama is False.
        is_llama: when True, inputs follow model.device (device_map="auto").

    Returns:
        dict with accuracy/f1/precision/recall/sensitivity/specificity and the
        confusion-matrix counts tp/tn/fp/fn.
    """
    from torch.utils.data import DataLoader
    model.eval()
    all_preds = []
    all_labels = []

    def collate_fn(batch):
        # Items may arrive as plain Python lists, so wrap them into tensors.
        return {
            'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
            'attention_mask': torch.stack([torch.tensor(item['attention_mask']) for item in batch]),
            'labels': torch.tensor([item['label'] for item in batch])
        }

    dataloader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)
    with torch.no_grad():
        for batch in dataloader:
            labels = batch.pop('labels')
            # Llama models loaded with device_map="auto" manage their own placement.
            target_device = model.device if is_llama else device
            inputs = {k: v.to(target_device) for k, v in batch.items()}
            outputs = model(**inputs)
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='binary', pos_label=1, zero_division=0)
    acc = accuracy_score(all_labels, all_preds)
    # BUG FIX: pin the label set so the matrix is always 2x2; previously a
    # single-class evaluation set zeroed out all four counts.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    return {
        'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
        'sensitivity': sensitivity, 'specificity': specificity,
        'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
    }
class WeightedTrainer(Trainer):
    """Hugging Face Trainer variant with optional class weights and focal loss.

    class_weights: 1-D tensor of per-class weights (or None).
    use_focal_loss: when True, apply focal modulation (1 - p_t)^gamma to the
        per-sample cross-entropy before averaging.
    focal_gamma: focusing exponent for the focal term.
    """

    def __init__(self, *args, class_weights=None, use_focal_loss=False, focal_gamma=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        self.use_focal_loss = use_focal_loss
        self.focal_gamma = focal_gamma

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Move weights onto the logits' dtype/device lazily (fp16 for Llama).
        weights = None
        if self.class_weights is not None:
            weights = self.class_weights.to(logits.dtype).to(logits.device)
        flat_logits = logits.view(-1, 2)
        flat_labels = labels.view(-1)
        if self.use_focal_loss:
            per_sample_ce = nn.CrossEntropyLoss(reduction='none')(flat_logits, flat_labels)
            # p_t = exp(-CE); down-weight easy (high-confidence) samples.
            modulator = (1 - torch.exp(-per_sample_ce)) ** self.focal_gamma
            per_sample_loss = modulator * per_sample_ce
            if weights is not None:
                per_sample_loss = per_sample_loss * weights[flat_labels]
            loss = per_sample_loss.mean()
        else:
            loss = nn.CrossEntropyLoss(weight=weights)(flat_logits, flat_labels)
        return (loss, outputs) if return_outputs else loss
def train_model(csv_file, base_model, method, num_epochs, batch_size, learning_rate,
                weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
                adalora_init_r, adalora_tinit, adalora_tfinal, adalora_deltaT,
                adapter_len, prefix_len, best_metric):
    """First-stage fine-tuning: balance the data, evaluate the un-finetuned
    baseline, apply the selected PEFT method, train, and report.

    Returns:
        4-tuple of display strings (training log, baseline report,
        fine-tuned report, comparison report); on error (message, "", "", "").
    """
    global trained_models, model_counter, baseline_results, baseline_performance_cache
    thorough_memory_cleanup()
    print(f"🧹 GPU 記憶體清理完成")
    model_mapping = {
        "BERT-base": "bert-base-uncased",
        "Llama-3.2-1B": "meta-llama/Llama-3.2-1B",
    }
    model_name = model_mapping.get(base_model, "bert-base-uncased")
    is_llama = "llama" in model_name.lower()
    try:
        if csv_file is None:
            return "❌ 請上傳 CSV", "", "", ""
        df = pd.read_csv(csv_file.name)
        text_col = 'Text' if 'Text' in df.columns else 'text'
        label_col = 'label' if 'label' in df.columns else 'nbcd'
        if text_col not in df.columns or label_col not in df.columns:
            return f"❌ 需要 {text_col} 和 {label_col} 欄位", "", "", ""
        df_clean = pd.DataFrame({
            'text': df[text_col].astype(str),
            'label': df[label_col].astype(int)
        }).dropna()
        avg_length = df_clean['text'].str.len().mean()
        min_length = df_clean['text'].str.len().min()
        max_length = df_clean['text'].str.len().max()
        n0_original = int(sum(df_clean['label'] == 0))
        n1_original = int(sum(df_clean['label'] == 1))
        if n1_original == 0:
            return "❌ 無死亡樣本", "", "", ""
        ratio_original = n0_original / n1_original
        info = f"📊 原始資料: {len(df_clean)} 筆\n"
        info += f"📏 文本長度: 平均 {avg_length:.0f} | 最小 {min_length} | 最大 {max_length}\n"
        info += f"📈 原始分布 - 存活: {n0_original} | 死亡: {n1_original} (比例 {ratio_original:.2f}:1)\n"
        # Data balancing is applied to every model (originally Llama-only).
        if True:
            info += f"\n⚖️ 資料平衡策略:執行平衡處理...\n"
            df_class_0 = df_clean[df_clean['label'] == 0]
            df_class_1 = df_clean[df_clean['label'] == 1]
            # BERT and Llama use different balanced sample counts.
            target_n = 500 if not is_llama else 700
            if len(df_class_0) > target_n:
                # Undersample the majority (survival) class without replacement.
                df_class_0_balanced = resample(df_class_0, n_samples=target_n, random_state=42, replace=False)
                info += f" ✅ Class 0 欠採樣: {len(df_class_0)} → {len(df_class_0_balanced)} 筆\n"
            else:
                df_class_0_balanced = df_class_0
                info += f" ⚠️ Class 0 樣本數不足,保持 {len(df_class_0)} 筆\n"
            if len(df_class_1) < target_n:
                # Oversample the minority (death) class with replacement.
                df_class_1_balanced = resample(df_class_1, n_samples=target_n, random_state=42, replace=True)
                info += f" ✅ Class 1 過採樣: {len(df_class_1)} → {len(df_class_1_balanced)} 筆\n"
            else:
                df_class_1_balanced = df_class_1
                info += f" ⚠️ Class 1 樣本數充足,保持 {len(df_class_1)} 筆\n"
            df_clean = pd.concat([df_class_0_balanced, df_class_1_balanced])
            df_clean = df_clean.sample(frac=1, random_state=42).reset_index(drop=True)
        n0 = int(sum(df_clean['label'] == 0))
        n1 = int(sum(df_clean['label'] == 1))
        ratio = n0 / n1
        info += f"\n📊 平衡後資料: {len(df_clean)} 筆\n"
        info += f"📈 平衡後分布 - 存活: {n0} | 死亡: {n1} (比例 {ratio:.2f}:1)\n"
        # The classes are balanced above, so uniform class weights suffice.
        w0 = 1.0
        w1 = 1.0
        info += f"🎯 類別權重: {w0:.4f} / {w1:.4f} (資料已平衡,使用相等權重)\n"
        info += f"🤖 模型: {base_model}\n"
        info += f"🔧 方法: {method.upper()}"
        if is_llama:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Llama ships without a pad token; reuse EOS for padding.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.pad_token_id = tokenizer.eos_token_id
            max_length = 512
        else:
            tokenizer = BertTokenizer.from_pretrained(model_name)
            max_length = 256
        dataset = Dataset.from_pandas(df_clean[['text', 'label']])
        def preprocess(ex):
            return tokenizer(ex['text'], truncation=True, padding='max_length', max_length=max_length)
        tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
        split = tokenized.train_test_split(test_size=0.2, seed=42)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        info += f"\n裝置: {'GPU ✅' if torch.cuda.is_available() else 'CPU ⚠️'}"
        # Evaluate the un-finetuned baseline once per base model (cached).
        baseline_key = f"{base_model}_baseline"
        if baseline_key in baseline_performance_cache:
            info += f"\n✅ 使用快取的 Baseline 評估結果\n"
            baseline_perf = baseline_performance_cache[baseline_key]
        else:
            info += f"\n🔍 首次評估 Baseline 模型...\n"
            if is_llama:
                baseline_model = AutoModelForSequenceClassification.from_pretrained(
                    model_name, num_labels=2,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    device_map="auto" if torch.cuda.is_available() else None
                )
                baseline_model.config.pad_token_id = tokenizer.pad_token_id
            else:
                baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
                baseline_model = baseline_model.to(device)
            baseline_perf = evaluate_baseline(baseline_model, tokenizer, split['test'], device, is_llama=is_llama)
            baseline_performance_cache[baseline_key] = baseline_perf
            baseline_results[baseline_key] = baseline_perf
            del baseline_model
            thorough_memory_cleanup()
        info += f"\n\n🔧 套用 {method.upper()} 微調..."
        if is_llama:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name, num_labels=2,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )
            model.config.pad_token_id = tokenizer.pad_token_id
        else:
            model = BertForSequenceClassification.from_pretrained(
                model_name, num_labels=2,
                hidden_dropout_prob=dropout,
                attention_probs_dropout_prob=dropout
            )
        peft_applied = False
        # Apply the PEFT configuration matching the selected method.
        if method == "lora":
            if is_llama:
                config = LoraConfig(
                    task_type=TaskType.SEQ_CLS,
                    r=int(lora_r),
                    lora_alpha=int(lora_alpha),
                    lora_dropout=lora_dropout,
                    target_modules=["q_proj", "v_proj"],
                    bias="none"
                )
            else:
                config = LoraConfig(
                    task_type=TaskType.SEQ_CLS,
                    r=int(lora_r),
                    lora_alpha=int(lora_alpha),
                    lora_dropout=lora_dropout,
                    target_modules=["query", "value"],
                    bias="none"
                )
            model = get_peft_model(model, config)
            peft_applied = True
            info += f"\n✅ LoRA 已套用(r={int(lora_r)}, alpha={int(lora_alpha)})"
        elif method == "adalora":
            # Clamp the AdaLoRA rank-scheduling milestones to the real step budget.
            steps_per_epoch = len(split['train']) // int(batch_size)
            total_steps = steps_per_epoch * int(num_epochs)
            adjusted_tinit = min(int(adalora_tinit), int(total_steps * 0.2))
            adjusted_tfinal = min(int(adalora_tfinal), int(total_steps * 0.9))
            if adjusted_tinit >= adjusted_tfinal:
                adjusted_tinit = int(total_steps * 0.1)
                adjusted_tfinal = int(total_steps * 0.8)
            info += f"\n📊 AdaLoRA 步數調整:\n"
            info += f" 總訓練步數: {total_steps}\n"
            info += f" tinit: {int(adalora_tinit)} → {adjusted_tinit}\n"
            info += f" tfinal: {int(adalora_tfinal)} → {adjusted_tfinal}\n"
            if is_llama:
                config = AdaLoraConfig(
                    task_type=TaskType.SEQ_CLS,
                    r=int(lora_r),
                    lora_alpha=int(lora_alpha),
                    lora_dropout=lora_dropout,
                    target_modules=["q_proj", "v_proj"],
                    init_r=int(adalora_init_r),
                    tinit=adjusted_tinit,
                    tfinal=adjusted_tfinal,
                    deltaT=int(adalora_deltaT)
                )
            else:
                config = AdaLoraConfig(
                    task_type=TaskType.SEQ_CLS,
                    r=int(lora_r),
                    lora_alpha=int(lora_alpha),
                    lora_dropout=lora_dropout,
                    target_modules=["query", "value"],
                    init_r=int(adalora_init_r),
                    tinit=adjusted_tinit,
                    tfinal=adjusted_tfinal,
                    deltaT=int(adalora_deltaT)
                )
            model = get_peft_model(model, config)
            peft_applied = True
            info += f"\n✅ AdaLoRA 已套用(r={int(lora_r)}, alpha={int(lora_alpha)}, init_r={int(adalora_init_r)})"
        elif method == "adapter":
            # LLaMA-Adapter style adaption prompts (Llama only).
            if is_llama:
                config = AdaptionPromptConfig(
                    task_type=TaskType.SEQ_CLS,
                    adapter_len=int(adapter_len),
                    adapter_layers=30  # tuned to the Llama layer count
                )
                model = get_peft_model(model, config)
                peft_applied = True
                info += f"\n✅ Adapter 已套用(length={int(adapter_len)})"
            else:
                # Adaption prompts are Llama-specific; fall back to LoRA for BERT.
                info += f"\n⚠️ Adapter 僅支援 Llama,改用 LoRA"
                config = LoraConfig(
                    task_type=TaskType.SEQ_CLS,
                    r=int(lora_r),
                    lora_alpha=int(lora_alpha),
                    lora_dropout=lora_dropout,
                    target_modules=["query", "value"],
                    bias="none"
                )
                model = get_peft_model(model, config)
                peft_applied = True
        elif method == "prefix":
            # Prefix tuning with an MLP reparameterization.
            config = PrefixTuningConfig(
                task_type=TaskType.SEQ_CLS,
                num_virtual_tokens=int(prefix_len),
                prefix_projection=True
            )
            model = get_peft_model(model, config)
            peft_applied = True
            info += f"\n✅ Prefix Tuning 已套用(tokens={int(prefix_len)})"
        elif method == "prompt":
            # Prompt tuning: like prefix tuning but without the projection MLP.
            config = PrefixTuningConfig(
                task_type=TaskType.SEQ_CLS,
                num_virtual_tokens=int(prefix_len),
                prefix_projection=False
            )
            model = get_peft_model(model, config)
            peft_applied = True
            info += f"\n✅ Prompt Tuning 已套用(tokens={int(prefix_len)})"
        elif method == "bitfit":
            # BitFit: freeze everything except bias terms.
            for name, param in model.named_parameters():
                if 'bias' not in name:
                    param.requires_grad = False
            peft_applied = True
            info += f"\n✅ BitFit 已套用(僅訓練 bias 參數)"
        if not peft_applied:
            info += f"\n⚠️ 警告:{method} 方法未被識別,使用 Full Fine-tuning"
        if not is_llama:
            model = model.to(device)
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        info += f"\n\n💾 參數量\n總參數: {total:,}\n可訓練: {trainable:,}\n比例: {trainable/total*100:.2f}%"
        # Class-weight tensor must match the model's compute dtype (fp16 on GPU Llama).
        if is_llama:
            weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            weights = torch.tensor([w0, w1], dtype=weight_dtype).to(model.device)
        else:
            weights = torch.tensor([w0, w1], dtype=torch.float32).to(device)
        info += f"\n⚖️ 權重 dtype: {weights.dtype} | device: {weights.device}\n"
        # 'loss' is the only supported metric where smaller is better.
        metrics_lower_is_better = ['loss']
        is_greater_better = best_metric not in metrics_lower_is_better
        args = TrainingArguments(
            output_dir='./results',
            num_train_epochs=int(num_epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size)*2,
            learning_rate=float(learning_rate),
            weight_decay=float(weight_decay),
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model=best_metric,
            greater_is_better=is_greater_better,
            report_to="none",
            logging_steps=10,
            warmup_steps=100 if is_llama else 50,
            warmup_ratio=0.1 if is_llama else 0.0,
            logging_first_step=True,
            bf16=(torch.cuda.is_available() and is_llama),
            gradient_accumulation_steps=4 if is_llama else 1,
            gradient_checkpointing=True if is_llama else False,
            optim="adamw_torch",
            seed=42,
            max_grad_norm=0.3 if is_llama else 1.0,
        )
        info += f"\n📊 最佳模型選擇: {best_metric} ({'越大越好' if is_greater_better else '越小越好'})\n"
        focal_gamma = 2.0
        trainer = WeightedTrainer(
            model=model,
            args=args,
            train_dataset=split['train'],
            eval_dataset=split['test'],
            compute_metrics=compute_metrics,
            class_weights=weights,
            use_focal_loss=True,
            focal_gamma=2.0
        )
        if is_llama:
            # BUG FIX: the original f-string referenced an undefined
            # `weight_boost` variable and raised NameError on every Llama run;
            # report the actual class weights instead.
            info += f"\n⚡ Llama 使用 Focal Loss (gamma={focal_gamma}) + 權重 {w0:.1f}/{w1:.1f} 策略"
        info += "\n\n⏳ 開始訓練..."
        info += f"\n📊 訓練前檢查:"
        info += f"\n - 訓練樣本: {len(split['train'])}"
        info += f"\n - 測試樣本: {len(split['test'])}"
        info += f"\n - 批次數/epoch: {len(split['train']) // int(batch_size)}"
        train_result = trainer.train()
        info += f"\n\n✅ 訓練完成!"
        info += f"\n📉 最終 Training Loss: {train_result.training_loss:.4f}"
        results = trainer.evaluate()
        # Register the fine-tuned model for later prediction/comparison.
        model_counter += 1
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_id = f"{base_model}_{method}_{timestamp}"
        trained_models[model_id] = {
            'model': model,
            'tokenizer': tokenizer,
            'results': results,
            'baseline': baseline_perf,
            'config': {
                'type': base_model,
                'model_name': model_name,
                'method': method,
                'metric': best_metric,
                'is_llama': is_llama
            },
            'timestamp': timestamp,
            'stage': 1  # first-stage training
        }
        metric_name_map = {
            'f1': 'F1',
            'accuracy': 'Accuracy',
            'precision': 'Precision',
            'recall': 'Recall',
            'sensitivity': 'Sensitivity',
            'specificity': 'Specificity'
        }
        baseline_val = baseline_perf[best_metric]
        finetuned_val = results[f'eval_{best_metric}']
        improvement = calculate_improvement(baseline_val, finetuned_val)
        baseline_output = f"🔬 純 {base_model}(未微調)\n\n"
        baseline_output += f"📊 {metric_name_map[best_metric]} 表現\n"
        baseline_output += f"{metric_name_map[best_metric]}: {baseline_val:.4f}\n\n"
        baseline_output += f"混淆矩陣\n"
        baseline_output += f"TP: {baseline_perf['tp']} | TN: {baseline_perf['tn']}\n"
        baseline_output += f"FP: {baseline_perf['fp']} | FN: {baseline_perf['fn']}"
        finetuned_output = f"✅ 微調 {base_model}\n"
        finetuned_output += f"模型: {model_id}\n\n"
        finetuned_output += f"📊 {metric_name_map[best_metric]} 表現\n"
        finetuned_output += f"{metric_name_map[best_metric]}: {finetuned_val:.4f}\n\n"
        finetuned_output += f"混淆矩陣\n"
        finetuned_output += f"TP: {results['eval_tp']} | TN: {results['eval_tn']}\n"
        finetuned_output += f"FP: {results['eval_fp']} | FN: {results['eval_fn']}"
        comparison_output = f"📊 純 {base_model} vs 微調 {base_model} 比較\n\n"
        comparison_output += f"🎯 選擇的評估指標: {metric_name_map[best_metric]}\n\n"
        comparison_output += f"{metric_name_map[best_metric]} 改善:\n"
        comparison_output += f"{baseline_val:.4f} → {finetuned_val:.4f} ({format_improve(improvement)})\n\n"
        comparison_output += f"混淆矩陣變化:\n"
        comparison_output += f"TP: {baseline_perf['tp']} → {results['eval_tp']} ({results['eval_tp'] - baseline_perf['tp']:+d})\n"
        comparison_output += f"TN: {baseline_perf['tn']} → {results['eval_tn']} ({results['eval_tn'] - baseline_perf['tn']:+d})\n"
        comparison_output += f"FP: {baseline_perf['fp']} → {results['eval_fp']} ({results['eval_fp'] - baseline_perf['fp']:+d})\n"
        comparison_output += f"FN: {baseline_perf['fn']} → {results['eval_fn']} ({results['eval_fn'] - baseline_perf['fn']:+d})"
        info += "\n\n✅ 訓練完成!"
        thorough_memory_cleanup()
        return info, baseline_output, finetuned_output, comparison_output
    except Exception as e:
        thorough_memory_cleanup()
        import traceback
        error_msg = f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"
        return error_msg, "", "", ""
def second_stage_train(first_model_id, csv_file, num_epochs, batch_size, learning_rate, best_metric):
    """Second-stage fine-tuning: continue training a stage-1 model on new data.

    Returns:
        3-tuple of display strings (log, stage-1 report, stage-2 report);
        on error (message, "", "").
    """
    global trained_models, second_stage_models
    if not first_model_id or first_model_id not in trained_models:
        return "❌ 請選擇第一階段模型", "", ""
    if csv_file is None:
        return "❌ 請上傳新的訓練資料", "", ""
    try:
        thorough_memory_cleanup()
        # Reuse the stage-1 model/tokenizer in place.
        first_model_info = trained_models[first_model_id]
        model = first_model_info['model']
        tokenizer = first_model_info['tokenizer']
        config = first_model_info['config']
        is_llama = config['is_llama']
        info = f"🔄 二次微調\n"
        info += f"基於模型: {first_model_id}\n"
        info += f"方法: {config['method'].upper()}\n\n"
        # Load the new training data.
        df = pd.read_csv(csv_file.name)
        text_col = 'Text' if 'Text' in df.columns else 'text'
        label_col = 'label' if 'label' in df.columns else 'nbcd'
        # ROBUSTNESS FIX: validate the columns (consistent with train_model);
        # previously a malformed CSV only failed later with a KeyError.
        if text_col not in df.columns or label_col not in df.columns:
            return f"❌ 需要 {text_col} 和 {label_col} 欄位", "", ""
        df_clean = pd.DataFrame({
            'text': df[text_col].astype(str),
            'label': df[label_col].astype(int)
        }).dropna()
        n0 = int(sum(df_clean['label'] == 0))
        n1 = int(sum(df_clean['label'] == 1))
        # ROBUSTNESS FIX: the weight formulas below divide by n1.
        if n1 == 0:
            return "❌ 無死亡樣本", "", ""
        info += f"📊 新資料: {len(df_clean)} 筆\n"
        info += f"📈 分布 - 存活: {n0} | 死亡: {n1}\n\n"
        # Tokenize with the same max length used at stage 1.
        max_length = 512 if is_llama else 256
        dataset = Dataset.from_pandas(df_clean[['text', 'label']])
        def preprocess(ex):
            return tokenizer(ex['text'], truncation=True, padding='max_length', max_length=max_length)
        tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
        split = tokenized.train_test_split(test_size=0.2, seed=42)
        # Class weights: Llama gets a boosted minority weight (fp16 on GPU).
        if is_llama:
            w0 = 1.0
            w1 = (n0 / n1) * 1.5
            weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            weights = torch.tensor([w0, w1], dtype=weight_dtype).to(model.device)
        else:
            w0 = 1.0
            w1 = min((n0 / n1) * 0.8, 15.0)
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            weights = torch.tensor([w0, w1], dtype=torch.float32).to(device)
        info += f"🎯 類別權重: {w0:.4f} / {w1:.4f}\n"
        # CONSISTENCY FIX: respect the metric direction like train_model does;
        # the original hard-coded greater_is_better=True even for 'loss'.
        is_greater_better = best_metric not in ['loss']
        args = TrainingArguments(
            output_dir='./results_stage2',
            num_train_epochs=int(num_epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size)*2,
            learning_rate=float(learning_rate) * 0.5,  # halved for stage 2
            weight_decay=0.01,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model=best_metric,
            greater_is_better=is_greater_better,
            report_to="none",
            logging_steps=10,
            seed=43  # different seed from stage 1
        )
        info += f"\n⏳ 開始二次微調(學習率減半)...\n"
        trainer = WeightedTrainer(
            model=model,
            args=args,
            train_dataset=split['train'],
            eval_dataset=split['test'],
            compute_metrics=compute_metrics,
            class_weights=weights,
            use_focal_loss=is_llama
        )
        train_result = trainer.train()
        results = trainer.evaluate()
        info += f"\n✅ 二次微調完成!\n"
        info += f"📉 Training Loss: {train_result.training_loss:.4f}\n"
        # Register the stage-2 model.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_id = f"{first_model_id}_stage2_{timestamp}"
        second_stage_models[model_id] = {
            'model': model,
            'tokenizer': tokenizer,
            'results': results,
            'first_stage_id': first_model_id,
            'first_stage_results': first_model_info['results'],
            'baseline': first_model_info['baseline'],
            'config': config,
            'timestamp': timestamp,
            'stage': 2
        }
        # Also register in trained_models so prediction can use it.
        trained_models[model_id] = second_stage_models[model_id]
        metric_name_map = {
            'f1': 'F1', 'accuracy': 'Accuracy', 'precision': 'Precision',
            'recall': 'Recall', 'sensitivity': 'Sensitivity', 'specificity': 'Specificity'
        }
        # Improvement chain: baseline -> stage 1 -> stage 2.
        baseline_val = first_model_info['baseline'][best_metric]
        stage1_val = first_model_info['results'][f'eval_{best_metric}']
        stage2_val = results[f'eval_{best_metric}']
        stage1_improve = calculate_improvement(baseline_val, stage1_val)
        stage2_improve = calculate_improvement(stage1_val, stage2_val)
        total_improve = calculate_improvement(baseline_val, stage2_val)
        stage1_output = f"🥇 第一階段微調結果\n\n"
        stage1_output += f"模型: {first_model_id}\n"
        stage1_output += f"{metric_name_map[best_metric]}: {stage1_val:.4f}\n"
        stage1_output += f"較 Baseline 改善: {format_improve(stage1_improve)}\n\n"
        stage1_output += f"混淆矩陣\n"
        stage1_output += f"TP: {first_model_info['results']['eval_tp']} | TN: {first_model_info['results']['eval_tn']}\n"
        stage1_output += f"FP: {first_model_info['results']['eval_fp']} | FN: {first_model_info['results']['eval_fn']}"
        stage2_output = f"🥈 第二階段微調結果\n\n"
        stage2_output += f"模型: {model_id}\n"
        stage2_output += f"{metric_name_map[best_metric]}: {stage2_val:.4f}\n"
        stage2_output += f"較第一階段改善: {format_improve(stage2_improve)}\n"
        stage2_output += f"較 Baseline 總改善: {format_improve(total_improve)}\n\n"
        stage2_output += f"混淆矩陣\n"
        stage2_output += f"TP: {results['eval_tp']} | TN: {results['eval_tn']}\n"
        stage2_output += f"FP: {results['eval_fp']} | FN: {results['eval_fn']}"
        thorough_memory_cleanup()
        return info, stage1_output, stage2_output
    except Exception as e:
        thorough_memory_cleanup()
        import traceback
        return f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}", "", ""
def evaluate_on_new_data(csv_file, selected_models):
    """Evaluate the selected trained models (plus the un-finetuned baseline of
    the first selection's base model) on a brand-new CSV.

    Returns:
        A markdown report string, or an error message on failure.
    """
    global trained_models, baseline_model_cache
    if csv_file is None:
        return "❌ 請上傳測試資料"
    if not selected_models:
        return "❌ 請至少選擇一個模型"
    try:
        # Load and normalize the evaluation CSV.
        df = pd.read_csv(csv_file.name)
        text_col = 'Text' if 'Text' in df.columns else 'text'
        label_col = 'label' if 'label' in df.columns else 'nbcd'
        df_clean = pd.DataFrame({
            'text': df[text_col].astype(str),
            'label': df[label_col].astype(int)
        }).dropna()
        output = f"# 📊 全新資料評估報告\n\n"
        output += f"## 測試資料概況\n"
        output += f"- 總樣本數: {len(df_clean)}\n"
        output += f"- 存活 (0): {sum(df_clean['label']==0)}\n"
        output += f"- 死亡 (1): {sum(df_clean['label']==1)}\n\n"
        output += f"## 模型表現比較\n\n"
        results_table = []
        for model_id in selected_models:
            if model_id not in trained_models:
                continue
            info = trained_models[model_id]
            model = info['model']
            tokenizer = info['tokenizer']
            config = info['config']
            is_llama = config['is_llama']
            # Tokenize with the same max length used at training time.
            max_length = 512 if is_llama else 256
            dataset = Dataset.from_pandas(df_clean[['text', 'label']])
            def preprocess(ex):
                return tokenizer(ex['text'], truncation=True, padding='max_length', max_length=max_length)
            tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            perf = evaluate_baseline(model, tokenizer, tokenized, device, is_llama=is_llama)
            stage = info.get('stage', 1)
            stage_label = "🔬 Baseline" if "baseline" in model_id else f"🥇 Stage {stage}"
            results_table.append({
                'model': model_id,
                'stage': stage_label,
                'method': config['method'].upper(),
                'f1': perf['f1'],
                'acc': perf['accuracy'],
                'prec': perf['precision'],
                'recall': perf['recall'],
                'sens': perf['sensitivity'],
                'spec': perf['specificity'],
                'tp': perf['tp'],
                'tn': perf['tn'],
                'fp': perf['fp'],
                'fn': perf['fn']
            })
        # Also evaluate the raw (un-finetuned) base model for reference.
        if results_table:
            first_model = trained_models[selected_models[0]]
            config = first_model['config']
            model_name = config['model_name']
            is_llama = config['is_llama']
            cache_key = model_name
            if cache_key not in baseline_model_cache:
                if is_llama:
                    baseline_model = AutoModelForSequenceClassification.from_pretrained(
                        model_name, num_labels=2,
                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                        device_map="auto" if torch.cuda.is_available() else None
                    )
                    baseline_model.config.pad_token_id = first_model['tokenizer'].pad_token_id
                else:
                    baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
                    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                    baseline_model = baseline_model.to(device)
                baseline_model.eval()
                baseline_model_cache[cache_key] = baseline_model
            else:
                baseline_model = baseline_model_cache[cache_key]
            tokenizer = first_model['tokenizer']
            max_length = 512 if is_llama else 256
            dataset = Dataset.from_pandas(df_clean[['text', 'label']])
            def preprocess(ex):
                return tokenizer(ex['text'], truncation=True, padding='max_length', max_length=max_length)
            tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            baseline_perf = evaluate_baseline(baseline_model, tokenizer, tokenized, device, is_llama=is_llama)
            results_table.insert(0, {
                'model': f"{config['type']}-Baseline",
                'stage': "🔬 Baseline",
                'method': "None",
                'f1': baseline_perf['f1'],
                'acc': baseline_perf['accuracy'],
                'prec': baseline_perf['precision'],
                'recall': baseline_perf['recall'],
                'sens': baseline_perf['sensitivity'],
                'spec': baseline_perf['specificity'],
                'tp': baseline_perf['tp'],
                'tn': baseline_perf['tn'],
                'fp': baseline_perf['fp'],
                'fn': baseline_perf['fn']
            })
        # Metrics table.
        output += "| 模型 | 階段 | 方法 | F1 | Acc | Prec | Recall | Sens | Spec |\n"
        output += "|------|------|------|-----|-----|------|--------|------|------|\n"
        for r in results_table:
            output += f"| {r['model'][:30]} | {r['stage']} | {r['method']} | "
            output += f"{r['f1']:.4f} | {r['acc']:.4f} | {r['prec']:.4f} | "
            output += f"{r['recall']:.4f} | {r['sens']:.4f} | {r['spec']:.4f} |\n"
        output += "\n## 混淆矩陣\n\n"
        output += "| 模型 | TP | TN | FP | FN |\n"
        # BUG FIX: the original separator row was "|------|----|----|----|\----|"
        # (stray backslash and one column short), which broke markdown rendering.
        output += "|------|----|----|----|----|\n"
        for r in results_table:
            output += f"| {r['model'][:30]} | {r['tp']} | {r['tn']} | {r['fp']} | {r['fn']} |\n"
        # Per-metric winners versus the baseline row (index 0).
        output += "\n## 🏆 最佳模型\n\n"
        for metric in ['f1', 'acc', 'sens', 'spec']:
            best = max(results_table, key=lambda x: x[metric])
            baseline_val = results_table[0][metric]
            improve = calculate_improvement(baseline_val, best[metric])
            metric_names = {'f1': 'F1', 'acc': 'Accuracy', 'sens': 'Sensitivity', 'spec': 'Specificity'}
            output += f"**{metric_names[metric]}**: {best['model'][:30]} ({best[metric]:.4f}, 較 Baseline 改善 {format_improve(improve)})\n\n"
        return output
    except Exception as e:
        import traceback
        return f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"
def predict(model_id, text):
    """Classify one text with a fine-tuned model and compare against the
    cached un-finetuned baseline of the same base model.

    Args:
        model_id: key into the global trained_models registry.
        text: raw input text to classify (class 0 = 存活, class 1 = 死亡).

    Returns:
        A human-readable report string, or an error message on failure.
    """
    global baseline_model_cache
    if not model_id or model_id not in trained_models:
        return "❌ 請選擇模型"
    if not text:
        return "❌ 請輸入文字"
    try:
        info = trained_models[model_id]
        model, tokenizer = info['model'], info['tokenizer']
        config = info['config']
        is_llama = config.get('is_llama', False)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Same max length the model was trained with.
        max_length = 512 if is_llama else 256
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=max_length)
        # Llama models loaded with device_map="auto" own their placement.
        if not is_llama:
            inputs_cuda = {k: v.to(device) for k, v in inputs.items()}
        else:
            inputs_cuda = {k: v.to(model.device) for k, v in inputs.items()}
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs_cuda)
            probs_finetuned = torch.nn.functional.softmax(outputs.logits, dim=-1)
            pred_finetuned = torch.argmax(probs_finetuned, dim=-1).item()
        result_finetuned = "存活" if pred_finetuned == 0 else "死亡"
        # Lazily load and cache the un-finetuned baseline for comparison.
        cache_key = config['model_name']
        if cache_key not in baseline_model_cache:
            if is_llama:
                baseline_model = AutoModelForSequenceClassification.from_pretrained(
                    config['model_name'], num_labels=2,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    device_map="auto" if torch.cuda.is_available() else None
                )
                baseline_model.config.pad_token_id = tokenizer.pad_token_id
            else:
                baseline_model = BertForSequenceClassification.from_pretrained(config['model_name'], num_labels=2)
                baseline_model = baseline_model.to(device)
            baseline_model.eval()
            baseline_model_cache[cache_key] = baseline_model
        else:
            baseline_model = baseline_model_cache[cache_key]
        with torch.no_grad():
            if is_llama:
                inputs_baseline = {k: v.to(baseline_model.device) for k, v in inputs.items()}
            else:
                inputs_baseline = inputs_cuda
            outputs_baseline = baseline_model(**inputs_baseline)
            probs_baseline = torch.nn.functional.softmax(outputs_baseline.logits, dim=-1)
            pred_baseline = torch.argmax(probs_baseline, dim=-1).item()
        result_baseline = "存活" if pred_baseline == 0 else "死亡"
        agreement = "✅ 一致" if pred_finetuned == pred_baseline else "⚠️ 不一致"
        metric_name_map = {
            'f1': 'F1',
            'accuracy': 'Accuracy',
            'precision': 'Precision',
            'recall': 'Recall',
            'sensitivity': 'Sensitivity',
            'specificity': 'Specificity'
        }
        # Metric chosen at training time drives the performance footer.
        selected_metric = config['metric']
        metric_display = metric_name_map[selected_metric]
        baseline_metric_val = info['baseline'][selected_metric]
        finetuned_metric_val = info['results'][f'eval_{selected_metric}']
        improvement = calculate_improvement(baseline_metric_val, finetuned_metric_val)
        stage = info.get('stage', 1)
        stage_label = f"Stage {stage}" if stage > 1 else "微調"
        output = f"""🔮 預測結果比較
📝 輸入文字: {text[:100]}{'...' if len(text) > 100 else ''}
{'='*50}
🧬 {stage_label}模型 ({model_id})
預測: {result_finetuned}
信心: {probs_finetuned[0][pred_finetuned].item():.2%}
機率分布:
• 存活: {probs_finetuned[0][0].item():.2%}
• 死亡: {probs_finetuned[0][1].item():.2%}
{'='*50}
🔬 基準模型(未微調 {config['type']})
預測: {result_baseline}
信心: {probs_baseline[0][pred_baseline].item():.2%}
機率分布:
• 存活: {probs_baseline[0][0].item():.2%}
• 死亡: {probs_baseline[0][1].item():.2%}
{'='*50}
📊 結論
兩模型預測: {agreement}
"""
        if pred_finetuned != pred_baseline:
            output += f"\n💡 分析: {stage_label}模型預測為【{result_finetuned}】,而基準模型預測為【{result_baseline}】"
            output += f"\n 這顯示了 fine-tuning 對此案例的影響!"
        output += f"""
📈 模型表現(基於 {metric_display}
{stage_label}模型 {metric_display}: {finetuned_metric_val:.4f}
基準模型 {metric_display}: {baseline_metric_val:.4f}
改善幅度: {format_improve(improvement)}
"""
        return output
    except Exception as e:
        import traceback
        return f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"
def compare():
    """Build a markdown report comparing every trained model against its
    un-finetuned baseline and listing the best model per metric."""
    if not trained_models:
        return "❌ 尚未訓練模型"
    report = "# 📊 模型比較\n\n"
    report += "## 微調模型表現\n\n"
    report += "| 模型 | 階段 | 基礎 | 方法 | F1 | Acc | Prec | Recall | Sens | Spec |\n"
    report += "|------|------|------|------|-----|-----|------|--------|------|------|\n"
    for model_id, entry in trained_models.items():
        res, cfg = entry['results'], entry['config']
        stage_no = entry.get('stage', 1)
        report += (
            f"| {model_id} | Stage{stage_no} | {cfg['type']} | {cfg['method'].upper()} | "
            f"{res['eval_f1']:.4f} | {res['eval_accuracy']:.4f} | "
            f"{res['eval_precision']:.4f} | {res['eval_recall']:.4f} | "
            f"{res['eval_sensitivity']:.4f} | {res['eval_specificity']:.4f} |\n"
        )
    report += "\n## 基準模型表現(未微調)\n\n"
    report += "| 模型 | F1 | Acc | Prec | Recall | Sens | Spec |\n"
    report += "|------|-----|-----|------|--------|------|------|\n"
    # Each base model shares one baseline; list every distinct one once.
    already_listed = set()
    for entry in trained_models.values():
        base = entry['baseline']
        key = f"{entry['config']['type']}-baseline"
        if key in already_listed:
            continue
        already_listed.add(key)
        report += (
            f"| {key} | {base['f1']:.4f} | {base['accuracy']:.4f} | "
            f"{base['precision']:.4f} | {base['recall']:.4f} | "
            f"{base['sensitivity']:.4f} | {base['specificity']:.4f} |\n"
        )
    report += "\n## 🏆 最佳模型\n\n"
    for metric in ['f1', 'accuracy', 'precision', 'recall', 'sensitivity', 'specificity']:
        winner_id, winner = max(trained_models.items(), key=lambda kv: kv[1]['results'][f'eval_{metric}'])
        before = winner['baseline'][metric]
        after = winner['results'][f'eval_{metric}']
        gain = calculate_improvement(before, after)
        report += f"**{metric.upper()}**: {winner_id} ({after:.4f}, 改善 {format_improve(gain)})\n\n"
    return report
def refresh_model_list():
    """Return a Dropdown update listing the current trained-model IDs."""
    model_ids = [model_id for model_id in trained_models]
    return gr.Dropdown(choices=model_ids)
def refresh_model_checkboxes():
    """Return a CheckboxGroup update listing the current trained-model IDs."""
    model_ids = [model_id for model_id in trained_models]
    return gr.CheckboxGroup(choices=model_ids)
def clear_gpu_memory():
    """Flush the baseline caches and release GPU memory by hand.

    Returns:
        str: a status message — GPU memory statistics when CUDA is
        available, a plain confirmation in CPU mode, or the failure text.
    """
    global baseline_model_cache, baseline_performance_cache
    try:
        # Drop cached baseline models/metrics so their tensors become
        # collectible before the actual cleanup pass.
        baseline_model_cache.clear()
        baseline_performance_cache.clear()
        thorough_memory_cleanup()
        if not torch.cuda.is_available():
            return "✅ 記憶體清理完成(CPU 模式)"
        gib = 1024 ** 3  # bytes per GiB
        allocated = torch.cuda.memory_allocated(0) / gib
        reserved = torch.cuda.memory_reserved(0) / gib
        max_allocated = torch.cuda.max_memory_allocated(0) / gib
        return f"""✅ GPU 記憶體清理完成!
當前狀態:
已分配: {allocated:.2f} GB
已保留: {reserved:.2f} GB
峰值使用: {max_allocated:.2f} GB"""
    except Exception as e:
        # Best-effort UI action: surface the error text instead of raising.
        return f"❌ 清理失敗: {str(e)}"
def update_method_params(method):
    """Toggle visibility of the per-method hyper-parameter groups.

    Args:
        method: the selected fine-tuning method ("lora", "adalora",
            "adapter", "prefix", "prompt", or "bitfit").

    Returns:
        dict: component -> gr.update(visible=...) for the four groups.
    """
    visibility = {
        lora_params: method in ("lora", "adalora"),
        adalora_params: method == "adalora",
        adapter_params: method == "adapter",
        prefix_params: method in ("prefix", "prompt"),
    }
    return {group: gr.update(visible=shown) for group, shown in visibility.items()}
# Gradio UI — four workflow tabs plus a comparison tab. The event
# handlers wired below (train_model, second_stage_train,
# evaluate_on_new_data, predict) are defined elsewhere in this module.
with gr.Blocks(title="完整版 Fine-tuning 平台", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 完整版 BERT & Llama Fine-tuning 平台 v3")
    gr.Markdown("### 支持 6 種微調方法 + 二次微調 + 全新資料測試")
    gr.Markdown("#### ✨ LoRA | AdaLoRA | Adapter | Prefix Tuning | Prompt Tuning | BitFit")

    # ---- Tab 1: first-stage training --------------------------------------
    with gr.Tab("🥇 第一階段訓練"):
        gr.Markdown("## 步驟 1: 選擇基礎模型")
        base_model = gr.Dropdown(
            choices=["BERT-base", "Llama-3.2-1B"],
            value="BERT-base",
            label="基礎模型"
        )
        gr.Markdown("### 🧹 記憶體管理")
        with gr.Row():
            clear_mem_btn = gr.Button("🧹 清理 GPU 記憶體", variant="secondary")
            mem_status = gr.Textbox(label="記憶體狀態", lines=4, interactive=False, scale=2)
        gr.Markdown("## 步驟 2: 選擇微調方法")
        method = gr.Radio(
            choices=["lora", "adalora", "adapter", "prefix", "prompt", "bitfit"],
            value="lora",
            label="微調方法"
        )
        gr.Markdown("## 步驟 3: 上傳資料")
        csv_file = gr.File(label="CSV 檔案 (需包含 Text/text 和 label/nbcd 欄位)", file_types=[".csv"])
        gr.Markdown("## 步驟 4: 設定訓練參數")
        with gr.Row():
            num_epochs = gr.Number(value=8, label="訓練輪數", minimum=1, maximum=100, precision=0)
            batch_size = gr.Number(value=16, label="批次大小", minimum=1, maximum=128, precision=0)
            learning_rate = gr.Number(value=2e-5, label="學習率", minimum=0, maximum=1)
        with gr.Row():
            weight_decay = gr.Number(value=0.01, label="權重衰減", minimum=0, maximum=1)
            dropout = gr.Number(value=0.3, label="Dropout", minimum=0, maximum=1)

        # Method-specific hyper-parameter groups; only one is visible at a
        # time — update_method_params toggles them when `method` changes.
        gr.Markdown("### 🔧 方法參數")
        with gr.Group(visible=True) as lora_params:
            gr.Markdown("#### LoRA 參數")
            with gr.Row():
                lora_r = gr.Number(value=32, label="Rank (r)", minimum=1, maximum=256, precision=0)
                lora_alpha = gr.Number(value=64, label="Alpha", minimum=1, maximum=512, precision=0)
                lora_dropout = gr.Number(value=0.1, label="Dropout", minimum=0, maximum=1)
        with gr.Group(visible=False) as adalora_params:
            gr.Markdown("#### AdaLoRA 參數")
            with gr.Row():
                adalora_init_r = gr.Number(value=12, label="初始 Rank", minimum=1, maximum=64, precision=0)
                adalora_tinit = gr.Number(value=200, label="Tinit", minimum=0, maximum=1000, precision=0)
            with gr.Row():
                adalora_tfinal = gr.Number(value=1000, label="Tfinal", minimum=0, maximum=5000, precision=0)
                adalora_deltaT = gr.Number(value=10, label="DeltaT", minimum=1, maximum=100, precision=0)
        with gr.Group(visible=False) as adapter_params:
            gr.Markdown("#### Adapter 參數")
            adapter_len = gr.Number(value=10, label="Adapter Length", minimum=1, maximum=50, precision=0,
                                    info="Adapter tokens 數量")
        with gr.Group(visible=False) as prefix_params:
            gr.Markdown("#### Prefix/Prompt 參數")
            prefix_len = gr.Number(value=20, label="Virtual Tokens", minimum=1, maximum=100, precision=0,
                                   info="虛擬 token 數量")
        method.change(
            update_method_params,
            inputs=[method],
            outputs=[lora_params, adalora_params, adapter_params, prefix_params]
        )
        best_metric = gr.Dropdown(
            choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
            value="f1",
            label="最佳模型選擇指標"
        )
        train_btn = gr.Button("🚀 開始訓練", variant="primary", size="lg")
        gr.Markdown("## 📊 訓練結果")
        data_info = gr.Textbox(label="📋 資料資訊", lines=10)
        with gr.Row():
            baseline_result = gr.Textbox(label="🔬 Baseline", lines=14)
            finetuned_result = gr.Textbox(label="✅ 微調模型", lines=14)
            comparison_result = gr.Textbox(label="📊 比較", lines=14)
        clear_mem_btn.click(clear_gpu_memory, outputs=[mem_status])
        # NOTE: input order here must match train_model's signature.
        train_btn.click(
            train_model,
            inputs=[csv_file, base_model, method, num_epochs, batch_size, learning_rate,
                    weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
                    adalora_init_r, adalora_tinit, adalora_tfinal, adalora_deltaT,
                    adapter_len, prefix_len, best_metric],
            outputs=[data_info, baseline_result, finetuned_result, comparison_result]
        )

    # ---- Tab 2: second-stage fine-tuning ----------------------------------
    with gr.Tab("🥈 第二階段訓練"):
        gr.Markdown("## 二次微調:基於已訓練模型繼續訓練")
        gr.Markdown("### 選擇第一階段模型,上傳新資料,進行二次微調")
        with gr.Row():
            first_model_select = gr.Dropdown(label="選擇第一階段模型", choices=list(trained_models.keys()))
            refresh_stage1 = gr.Button("🔄 刷新模型列表")
        stage2_csv = gr.File(label="上傳新的訓練資料 CSV", file_types=[".csv"])
        gr.Markdown("### 二次微調參數")
        with gr.Row():
            stage2_epochs = gr.Number(value=3, label="訓練輪數", minimum=1, maximum=20, precision=0,
                                      info="建議較少輪數")
            stage2_batch = gr.Number(value=16, label="批次大小", minimum=1, maximum=128, precision=0)
            stage2_lr = gr.Number(value=1e-5, label="學習率", minimum=0, maximum=1,
                                  info="自動減半,建議更小")
        stage2_metric = gr.Dropdown(
            choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
            value="f1",
            label="評估指標"
        )
        stage2_train_btn = gr.Button("🔄 開始二次微調", variant="primary", size="lg")
        gr.Markdown("## 📊 二次微調結果")
        stage2_info = gr.Textbox(label="📋 訓練資訊", lines=8)
        with gr.Row():
            stage1_result = gr.Textbox(label="🥇 第一階段", lines=12)
            stage2_result = gr.Textbox(label="🥈 第二階段", lines=12)
        refresh_stage1.click(refresh_model_list, outputs=[first_model_select])
        stage2_train_btn.click(
            second_stage_train,
            inputs=[first_model_select, stage2_csv, stage2_epochs, stage2_batch, stage2_lr, stage2_metric],
            outputs=[stage2_info, stage1_result, stage2_result]
        )

    # ---- Tab 3: evaluation on unseen data ---------------------------------
    with gr.Tab("🆕 全新資料測試"):
        gr.Markdown("## 在全新資料上測試所有模型")
        gr.Markdown("### 上傳模型未見過的測試資料,比較 Baseline、Stage1、Stage2 的表現")
        test_csv = gr.File(label="上傳測試資料 CSV", file_types=[".csv"])
        with gr.Row():
            test_models = gr.CheckboxGroup(label="選擇要測試的模型", choices=list(trained_models.keys()))
            refresh_test = gr.Button("🔄 刷新")
        test_btn = gr.Button("🧪 開始測試", variant="primary", size="lg")
        test_output = gr.Markdown(label="測試結果")
        refresh_test.click(refresh_model_checkboxes, outputs=[test_models])
        test_btn.click(
            evaluate_on_new_data,
            inputs=[test_csv, test_models],
            outputs=[test_output]
        )

    # ---- Tab 4: single-text prediction ------------------------------------
    with gr.Tab("🔮 預測"):
        gr.Markdown("## 使用訓練好的模型預測")
        with gr.Row():
            model_drop = gr.Dropdown(label="選擇模型", choices=list(trained_models.keys()))
            refresh = gr.Button("🔄 刷新")
        text_input = gr.Textbox(label="輸入病例描述", lines=4,
                                placeholder="Patient diagnosed with...")
        predict_btn = gr.Button("預測", variant="primary", size="lg")
        pred_output = gr.Textbox(label="預測結果", lines=20)
        refresh.click(refresh_model_list, outputs=[model_drop])
        predict_btn.click(predict, inputs=[model_drop, text_input], outputs=[pred_output])
        gr.Examples(
            examples=[
                ["Patient with stage II breast cancer, good response to treatment."],
                ["Advanced metastatic cancer, multiple organ involvement."]
            ],
            inputs=text_input
        )

    # ---- Tab 5: cross-model comparison ------------------------------------
    with gr.Tab("📊 比較"):
        gr.Markdown("## 比較所有模型")
        compare_btn = gr.Button("比較", variant="primary", size="lg")
        compare_output = gr.Markdown()
        compare_btn.click(compare, outputs=[compare_output])
if __name__ == "__main__":
    # Bind to all interfaces on Gradio's default port; no public share link.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        max_threads=4
    )