Spaces:
Paused
Paused
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,632 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
|
| 5 |
+
from peft import LoraConfig, AdaLoraConfig, get_peft_model, TaskType
|
| 6 |
+
from datasets import Dataset
|
| 7 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
|
| 8 |
+
from torch import nn
|
| 9 |
+
import os
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
|
| 12 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 13 |
+
|
| 14 |
+
# 全域變數
|
| 15 |
+
trained_models = {}
|
| 16 |
+
model_counter = 0
|
| 17 |
+
baseline_results = {}
|
| 18 |
+
baseline_model_cache = {}
|
| 19 |
+
|
| 20 |
+
def calculate_improvement(baseline_val, finetuned_val):
|
| 21 |
+
"""安全計算改善率"""
|
| 22 |
+
if baseline_val == 0:
|
| 23 |
+
if finetuned_val > 0:
|
| 24 |
+
return float('inf')
|
| 25 |
+
else:
|
| 26 |
+
return 0.0
|
| 27 |
+
return (finetuned_val - baseline_val) / baseline_val * 100
|
| 28 |
+
|
| 29 |
+
def format_improve(val):
|
| 30 |
+
"""格式化改善率"""
|
| 31 |
+
if val == float('inf'):
|
| 32 |
+
return "N/A (baseline=0)"
|
| 33 |
+
return f"{val:+.1f}%"
|
| 34 |
+
|
| 35 |
+
def compute_metrics(pred):
|
| 36 |
+
try:
|
| 37 |
+
labels = pred.label_ids
|
| 38 |
+
preds = pred.predictions.argmax(-1)
|
| 39 |
+
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary', pos_label=1, zero_division=0)
|
| 40 |
+
acc = accuracy_score(labels, preds)
|
| 41 |
+
cm = confusion_matrix(labels, preds)
|
| 42 |
+
if cm.shape == (2, 2):
|
| 43 |
+
tn, fp, fn, tp = cm.ravel()
|
| 44 |
+
else:
|
| 45 |
+
tn = fp = fn = tp = 0
|
| 46 |
+
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
|
| 47 |
+
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
|
| 48 |
+
return {
|
| 49 |
+
'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
|
| 50 |
+
'sensitivity': sensitivity, 'specificity': specificity,
|
| 51 |
+
'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
|
| 52 |
+
}
|
| 53 |
+
except Exception as e:
|
| 54 |
+
print(f"Error in compute_metrics: {e}")
|
| 55 |
+
return {
|
| 56 |
+
'accuracy': 0, 'f1': 0, 'precision': 0, 'recall': 0,
|
| 57 |
+
'sensitivity': 0, 'specificity': 0, 'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
class WeightedTrainer(Trainer):
|
| 61 |
+
def __init__(self, *args, class_weights=None, **kwargs):
|
| 62 |
+
super().__init__(*args, **kwargs)
|
| 63 |
+
self.class_weights = class_weights
|
| 64 |
+
|
| 65 |
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
| 66 |
+
labels = inputs.pop("labels")
|
| 67 |
+
outputs = model(**inputs)
|
| 68 |
+
loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
|
| 69 |
+
loss = loss_fct(outputs.logits.view(-1, 2), labels.view(-1))
|
| 70 |
+
return (loss, outputs) if return_outputs else loss
|
| 71 |
+
|
| 72 |
+
def evaluate_baseline(model, tokenizer, test_dataset, device):
|
| 73 |
+
"""評估未微調的基準模型"""
|
| 74 |
+
model.eval()
|
| 75 |
+
all_preds = []
|
| 76 |
+
all_labels = []
|
| 77 |
+
|
| 78 |
+
from torch.utils.data import DataLoader
|
| 79 |
+
|
| 80 |
+
def collate_fn(batch):
|
| 81 |
+
return {
|
| 82 |
+
'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
|
| 83 |
+
'attention_mask': torch.stack([torch.tensor(item['attention_mask']) for item in batch]),
|
| 84 |
+
'labels': torch.tensor([item['label'] for item in batch])
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
dataloader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)
|
| 88 |
+
|
| 89 |
+
with torch.no_grad():
|
| 90 |
+
for batch in dataloader:
|
| 91 |
+
labels = batch.pop('labels')
|
| 92 |
+
inputs = {k: v.to(device) for k, v in batch.items()}
|
| 93 |
+
outputs = model(**inputs)
|
| 94 |
+
preds = torch.argmax(outputs.logits, dim=-1)
|
| 95 |
+
all_preds.extend(preds.cpu().numpy())
|
| 96 |
+
all_labels.extend(labels.numpy())
|
| 97 |
+
|
| 98 |
+
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary', pos_label=1, zero_division=0)
|
| 99 |
+
acc = accuracy_score(all_labels, all_preds)
|
| 100 |
+
cm = confusion_matrix(all_labels, all_preds)
|
| 101 |
+
if cm.shape == (2, 2):
|
| 102 |
+
tn, fp, fn, tp = cm.ravel()
|
| 103 |
+
else:
|
| 104 |
+
tn = fp = fn = tp = 0
|
| 105 |
+
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
|
| 106 |
+
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
|
| 107 |
+
|
| 108 |
+
return {
|
| 109 |
+
'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall,
|
| 110 |
+
'sensitivity': sensitivity, 'specificity': specificity,
|
| 111 |
+
'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learning_rate,
|
| 115 |
+
weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
|
| 116 |
+
weight_mult, best_metric):
|
| 117 |
+
global trained_models, model_counter, baseline_results
|
| 118 |
+
|
| 119 |
+
model_mapping = {
|
| 120 |
+
"BERT-base": "bert-base-uncased",
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
model_name = model_mapping.get(base_model, "bert-base-uncased")
|
| 124 |
+
|
| 125 |
+
try:
|
| 126 |
+
if csv_file is None:
|
| 127 |
+
return "❌ 請上傳 CSV", "", "", ""
|
| 128 |
+
|
| 129 |
+
df = pd.read_csv(csv_file.name)
|
| 130 |
+
if 'Text' not in df.columns or 'label' not in df.columns:
|
| 131 |
+
return "❌ 需要 Text 和 label 欄位", "", "", ""
|
| 132 |
+
|
| 133 |
+
df_clean = pd.DataFrame({
|
| 134 |
+
'text': df['Text'].astype(str),
|
| 135 |
+
'label': df['label'].astype(int)
|
| 136 |
+
}).dropna()
|
| 137 |
+
|
| 138 |
+
n0 = int(sum(df_clean['label'] == 0))
|
| 139 |
+
n1 = int(sum(df_clean['label'] == 1))
|
| 140 |
+
if n1 == 0:
|
| 141 |
+
return "❌ 無死亡樣本", "", "", ""
|
| 142 |
+
|
| 143 |
+
ratio = n0 / n1
|
| 144 |
+
w0, w1 = 1.0, ratio * weight_mult
|
| 145 |
+
|
| 146 |
+
info = f"📊 資料: {len(df_clean)} 筆\n存活: {n0} | 死亡: {n1}\n比例: {ratio:.2f}:1\n權重: {w0:.2f} / {w1:.2f}\n模型: {base_model}\n方法: {method.upper()}"
|
| 147 |
+
|
| 148 |
+
tokenizer = BertTokenizer.from_pretrained(model_name)
|
| 149 |
+
dataset = Dataset.from_pandas(df_clean[['text', 'label']])
|
| 150 |
+
|
| 151 |
+
def preprocess(ex):
|
| 152 |
+
return tokenizer(ex['text'], truncation=True, padding='max_length', max_length=128)
|
| 153 |
+
|
| 154 |
+
tokenized = dataset.map(preprocess, batched=True, remove_columns=['text'])
|
| 155 |
+
split = tokenized.train_test_split(test_size=0.2, seed=42)
|
| 156 |
+
|
| 157 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 158 |
+
info += f"\n裝置: {'GPU ✅' if torch.cuda.is_available() else 'CPU ⚠️'}"
|
| 159 |
+
|
| 160 |
+
# 評估基準模型(未微調)
|
| 161 |
+
info += "\n\n🔍 評估基準模型(未微調)..."
|
| 162 |
+
baseline_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
|
| 163 |
+
baseline_model = baseline_model.to(device)
|
| 164 |
+
|
| 165 |
+
baseline_perf = evaluate_baseline(baseline_model, tokenizer, split['test'], device)
|
| 166 |
+
baseline_key = f"{base_model}_baseline"
|
| 167 |
+
baseline_results[baseline_key] = baseline_perf
|
| 168 |
+
|
| 169 |
+
info += f"\n基準 F1: {baseline_perf['f1']:.4f}"
|
| 170 |
+
info += f"\n基準 Accuracy: {baseline_perf['accuracy']:.4f}"
|
| 171 |
+
|
| 172 |
+
# 清理基準模型以釋放記憶體
|
| 173 |
+
del baseline_model
|
| 174 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 175 |
+
|
| 176 |
+
# 開始微調
|
| 177 |
+
info += f"\n\n🔧 套用 {method.upper()} 微調..."
|
| 178 |
+
model = BertForSequenceClassification.from_pretrained(
|
| 179 |
+
model_name, num_labels=2,
|
| 180 |
+
hidden_dropout_prob=dropout,
|
| 181 |
+
attention_probs_dropout_prob=dropout
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
peft_applied = False
|
| 185 |
+
if method == "lora":
|
| 186 |
+
config = LoraConfig(
|
| 187 |
+
task_type=TaskType.SEQ_CLS,
|
| 188 |
+
r=int(lora_r),
|
| 189 |
+
lora_alpha=int(lora_alpha),
|
| 190 |
+
lora_dropout=lora_dropout,
|
| 191 |
+
target_modules=["query", "value"],
|
| 192 |
+
bias="none"
|
| 193 |
+
)
|
| 194 |
+
model = get_peft_model(model, config)
|
| 195 |
+
peft_applied = True
|
| 196 |
+
info += f"\n✅ LoRA 已套用(r={int(lora_r)}, alpha={int(lora_alpha)})"
|
| 197 |
+
elif method == "adalora":
|
| 198 |
+
config = AdaLoraConfig(
|
| 199 |
+
task_type=TaskType.SEQ_CLS,
|
| 200 |
+
r=int(lora_r),
|
| 201 |
+
lora_alpha=int(lora_alpha),
|
| 202 |
+
lora_dropout=lora_dropout,
|
| 203 |
+
target_modules=["query", "value"],
|
| 204 |
+
init_r=12, tinit=200, tfinal=1000, deltaT=10
|
| 205 |
+
)
|
| 206 |
+
model = get_peft_model(model, config)
|
| 207 |
+
peft_applied = True
|
| 208 |
+
info += f"\n✅ AdaLoRA 已套用(r={int(lora_r)}, alpha={int(lora_alpha)})"
|
| 209 |
+
|
| 210 |
+
if not peft_applied:
|
| 211 |
+
info += f"\n⚠️ 警告:{method} 方法未被識別,使用 Full Fine-tuning"
|
| 212 |
+
|
| 213 |
+
model = model.to(device)
|
| 214 |
+
|
| 215 |
+
total = sum(p.numel() for p in model.parameters())
|
| 216 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 217 |
+
info += f"\n\n💾 參數量\n總參數: {total:,}\n可訓練: {trainable:,}\n比例: {trainable/total*100:.2f}%"
|
| 218 |
+
|
| 219 |
+
weights = torch.tensor([w0, w1], dtype=torch.float).to(device)
|
| 220 |
+
|
| 221 |
+
args = TrainingArguments(
|
| 222 |
+
output_dir='./results',
|
| 223 |
+
num_train_epochs=int(num_epochs),
|
| 224 |
+
per_device_train_batch_size=int(batch_size),
|
| 225 |
+
per_device_eval_batch_size=int(batch_size)*2,
|
| 226 |
+
learning_rate=float(learning_rate),
|
| 227 |
+
weight_decay=float(weight_decay),
|
| 228 |
+
evaluation_strategy="epoch",
|
| 229 |
+
save_strategy="epoch",
|
| 230 |
+
load_best_model_at_end=True,
|
| 231 |
+
metric_for_best_model=best_metric,
|
| 232 |
+
report_to="none",
|
| 233 |
+
logging_steps=50,
|
| 234 |
+
save_total_limit=2
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
trainer = WeightedTrainer(
|
| 238 |
+
model=model,
|
| 239 |
+
args=args,
|
| 240 |
+
train_dataset=split['train'],
|
| 241 |
+
eval_dataset=split['test'],
|
| 242 |
+
compute_metrics=compute_metrics,
|
| 243 |
+
class_weights=weights
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
info += "\n\n⏳ 開始訓練..."
|
| 247 |
+
trainer.train()
|
| 248 |
+
results = trainer.evaluate()
|
| 249 |
+
|
| 250 |
+
# 生成帶時間戳的模型 ID
|
| 251 |
+
model_counter += 1
|
| 252 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 253 |
+
model_id = f"{base_model}_{method}_{timestamp}"
|
| 254 |
+
trained_models[model_id] = {
|
| 255 |
+
'model': model,
|
| 256 |
+
'tokenizer': tokenizer,
|
| 257 |
+
'results': results,
|
| 258 |
+
'baseline': baseline_perf,
|
| 259 |
+
'config': {
|
| 260 |
+
'type': base_model,
|
| 261 |
+
'model_name': model_name,
|
| 262 |
+
'method': method,
|
| 263 |
+
'metric': best_metric
|
| 264 |
+
},
|
| 265 |
+
'timestamp': timestamp
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
# 計算改善
|
| 269 |
+
f1_improve = calculate_improvement(baseline_perf['f1'], results['eval_f1'])
|
| 270 |
+
acc_improve = calculate_improvement(baseline_perf['accuracy'], results['eval_accuracy'])
|
| 271 |
+
prec_improve = calculate_improvement(baseline_perf['precision'], results['eval_precision'])
|
| 272 |
+
rec_improve = calculate_improvement(baseline_perf['recall'], results['eval_recall'])
|
| 273 |
+
sens_improve = calculate_improvement(baseline_perf['sensitivity'], results['eval_sensitivity'])
|
| 274 |
+
spec_improve = calculate_improvement(baseline_perf['specificity'], results['eval_specificity'])
|
| 275 |
+
|
| 276 |
+
# 純 BERT 輸出
|
| 277 |
+
baseline_output = f"🔬 純 BERT(未微調)\n\n"
|
| 278 |
+
baseline_output += f"📈 表現\n"
|
| 279 |
+
baseline_output += f"F1: {baseline_perf['f1']:.4f}\n"
|
| 280 |
+
baseline_output += f"Accuracy: {baseline_perf['accuracy']:.4f}\n"
|
| 281 |
+
baseline_output += f"Precision: {baseline_perf['precision']:.4f}\n"
|
| 282 |
+
baseline_output += f"Recall: {baseline_perf['recall']:.4f}\n"
|
| 283 |
+
baseline_output += f"Sensitivity: {baseline_perf['sensitivity']:.4f}\n"
|
| 284 |
+
baseline_output += f"Specificity: {baseline_perf['specificity']:.4f}\n\n"
|
| 285 |
+
baseline_output += f"混淆矩陣\n"
|
| 286 |
+
baseline_output += f"TP: {baseline_perf['tp']} | TN: {baseline_perf['tn']}\n"
|
| 287 |
+
baseline_output += f"FP: {baseline_perf['fp']} | FN: {baseline_perf['fn']}"
|
| 288 |
+
|
| 289 |
+
# 微調 BERT 輸出
|
| 290 |
+
finetuned_output = f"✅ 微調 BERT\n模型: {model_id}\n\n"
|
| 291 |
+
finetuned_output += f"📈 表現\n"
|
| 292 |
+
finetuned_output += f"F1: {results['eval_f1']:.4f}\n"
|
| 293 |
+
finetuned_output += f"Accuracy: {results['eval_accuracy']:.4f}\n"
|
| 294 |
+
finetuned_output += f"Precision: {results['eval_precision']:.4f}\n"
|
| 295 |
+
finetuned_output += f"Recall: {results['eval_recall']:.4f}\n"
|
| 296 |
+
finetuned_output += f"Sensitivity: {results['eval_sensitivity']:.4f}\n"
|
| 297 |
+
finetuned_output += f"Specificity: {results['eval_specificity']:.4f}\n\n"
|
| 298 |
+
finetuned_output += f"混淆矩陣\n"
|
| 299 |
+
finetuned_output += f"TP: {results['eval_tp']} | TN: {results['eval_tn']}\n"
|
| 300 |
+
finetuned_output += f"FP: {results['eval_fp']} | FN: {results['eval_fn']}"
|
| 301 |
+
|
| 302 |
+
# 比較結果輸出
|
| 303 |
+
comparison_output = f"📊 純 BERT vs 微調 BERT 比較\n\n"
|
| 304 |
+
comparison_output += f"指標改善:\n"
|
| 305 |
+
comparison_output += f"F1: {baseline_perf['f1']:.4f} → {results['eval_f1']:.4f} ({format_improve(f1_improve)})\n"
|
| 306 |
+
comparison_output += f"Accuracy: {baseline_perf['accuracy']:.4f} → {results['eval_accuracy']:.4f} ({format_improve(acc_improve)})\n"
|
| 307 |
+
comparison_output += f"Precision: {baseline_perf['precision']:.4f} → {results['eval_precision']:.4f} ({format_improve(prec_improve)})\n"
|
| 308 |
+
comparison_output += f"Recall: {baseline_perf['recall']:.4f} → {results['eval_recall']:.4f} ({format_improve(rec_improve)})\n"
|
| 309 |
+
comparison_output += f"Sensitivity: {baseline_perf['sensitivity']:.4f} → {results['eval_sensitivity']:.4f} ({format_improve(sens_improve)})\n"
|
| 310 |
+
comparison_output += f"Specificity: {baseline_perf['specificity']:.4f} → {results['eval_specificity']:.4f} ({format_improve(spec_improve)})\n\n"
|
| 311 |
+
comparison_output += f"混淆矩陣變化:\n"
|
| 312 |
+
comparison_output += f"TP: {baseline_perf['tp']} → {results['eval_tp']} ({results['eval_tp'] - baseline_perf['tp']:+d})\n"
|
| 313 |
+
comparison_output += f"TN: {baseline_perf['tn']} → {results['eval_tn']} ({results['eval_tn'] - baseline_perf['tn']:+d})\n"
|
| 314 |
+
comparison_output += f"FP: {baseline_perf['fp']} → {results['eval_fp']} ({results['eval_fp'] - baseline_perf['fp']:+d})\n"
|
| 315 |
+
comparison_output += f"FN: {baseline_perf['fn']} → {results['eval_fn']} ({results['eval_fn'] - baseline_perf['fn']:+d})"
|
| 316 |
+
|
| 317 |
+
info += "\n\n✅ 訓練完成!"
|
| 318 |
+
|
| 319 |
+
return info, baseline_output, finetuned_output, comparison_output
|
| 320 |
+
|
| 321 |
+
except Exception as e:
|
| 322 |
+
import traceback
|
| 323 |
+
error_msg = f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"
|
| 324 |
+
return error_msg, "", "", ""
|
| 325 |
+
|
| 326 |
+
def predict(model_id, text):
|
| 327 |
+
global baseline_model_cache
|
| 328 |
+
|
| 329 |
+
if not model_id or model_id not in trained_models:
|
| 330 |
+
return "❌ 請選擇模型"
|
| 331 |
+
if not text:
|
| 332 |
+
return "❌ 請輸入文字"
|
| 333 |
+
|
| 334 |
+
try:
|
| 335 |
+
info = trained_models[model_id]
|
| 336 |
+
model, tokenizer = info['model'], info['tokenizer']
|
| 337 |
+
config = info['config']
|
| 338 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 339 |
+
|
| 340 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
|
| 341 |
+
inputs_cuda = {k: v.to(device) for k, v in inputs.items()}
|
| 342 |
+
|
| 343 |
+
# 預測:微調模型
|
| 344 |
+
model.eval()
|
| 345 |
+
with torch.no_grad():
|
| 346 |
+
outputs = model(**inputs_cuda)
|
| 347 |
+
probs_finetuned = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
| 348 |
+
pred_finetuned = torch.argmax(probs_finetuned, dim=-1).item()
|
| 349 |
+
|
| 350 |
+
result_finetuned = "存活" if pred_finetuned == 0 else "死亡"
|
| 351 |
+
|
| 352 |
+
# 預測:基準模型(使用快取)
|
| 353 |
+
cache_key = config['model_name']
|
| 354 |
+
if cache_key not in baseline_model_cache:
|
| 355 |
+
baseline_model = BertForSequenceClassification.from_pretrained(config['model_name'], num_labels=2)
|
| 356 |
+
baseline_model = baseline_model.to(device)
|
| 357 |
+
baseline_model.eval()
|
| 358 |
+
baseline_model_cache[cache_key] = baseline_model
|
| 359 |
+
else:
|
| 360 |
+
baseline_model = baseline_model_cache[cache_key]
|
| 361 |
+
|
| 362 |
+
with torch.no_grad():
|
| 363 |
+
outputs_baseline = baseline_model(**inputs_cuda)
|
| 364 |
+
probs_baseline = torch.nn.functional.softmax(outputs_baseline.logits, dim=-1)
|
| 365 |
+
pred_baseline = torch.argmax(probs_baseline, dim=-1).item()
|
| 366 |
+
|
| 367 |
+
result_baseline = "存活" if pred_baseline == 0 else "死亡"
|
| 368 |
+
|
| 369 |
+
# 判斷是否一致
|
| 370 |
+
agreement = "✅ 一致" if pred_finetuned == pred_baseline else "⚠️ 不一致"
|
| 371 |
+
|
| 372 |
+
output = f"""🔮 預測結果比較
|
| 373 |
+
|
| 374 |
+
📝 輸入文字: {text[:100]}{'...' if len(text) > 100 else ''}
|
| 375 |
+
|
| 376 |
+
{'='*50}
|
| 377 |
+
|
| 378 |
+
🧬 微調模型 ({model_id})
|
| 379 |
+
預測: {result_finetuned}
|
| 380 |
+
信心: {probs_finetuned[0][pred_finetuned].item():.2%}
|
| 381 |
+
機率分布:
|
| 382 |
+
• 存活: {probs_finetuned[0][0].item():.2%}
|
| 383 |
+
• 死亡: {probs_finetuned[0][1].item():.2%}
|
| 384 |
+
|
| 385 |
+
{'='*50}
|
| 386 |
+
|
| 387 |
+
🔬 基準模型(未微調 {config['type']})
|
| 388 |
+
預測: {result_baseline}
|
| 389 |
+
信心: {probs_baseline[0][pred_baseline].item():.2%}
|
| 390 |
+
機率分布:
|
| 391 |
+
• 存活: {probs_baseline[0][0].item():.2%}
|
| 392 |
+
• 死亡: {probs_baseline[0][1].item():.2%}
|
| 393 |
+
|
| 394 |
+
{'='*50}
|
| 395 |
+
|
| 396 |
+
📊 結論
|
| 397 |
+
兩模型預測: {agreement}
|
| 398 |
+
"""
|
| 399 |
+
|
| 400 |
+
if pred_finetuned != pred_baseline:
|
| 401 |
+
output += f"\n💡 分析: 微調模型預測為【{result_finetuned}】,而基準模型預測為【{result_baseline}】"
|
| 402 |
+
output += f"\n 這顯示了 fine-tuning 對此案例的影響!"
|
| 403 |
+
|
| 404 |
+
f1_improve = calculate_improvement(info['baseline']['f1'], info['results']['eval_f1'])
|
| 405 |
+
|
| 406 |
+
output += f"""
|
| 407 |
+
|
| 408 |
+
📈 模型表現
|
| 409 |
+
微調模型 F1: {info['results']['eval_f1']:.4f}
|
| 410 |
+
基準模型 F1: {info['baseline']['f1']:.4f}
|
| 411 |
+
改善幅度: {format_improve(f1_improve)}
|
| 412 |
+
"""
|
| 413 |
+
|
| 414 |
+
return output
|
| 415 |
+
|
| 416 |
+
except Exception as e:
|
| 417 |
+
import traceback
|
| 418 |
+
return f"❌ 錯誤: {str(e)}\n\n{traceback.format_exc()}"
|
| 419 |
+
|
| 420 |
+
def compare():
|
| 421 |
+
if not trained_models:
|
| 422 |
+
return "❌ 尚未訓練模型"
|
| 423 |
+
|
| 424 |
+
text = "# 📊 模型比較\n\n"
|
| 425 |
+
text += "## 微調模型表現\n\n"
|
| 426 |
+
text += "| 模型 | 基礎 | 方法 | F1 | Acc | Prec | Recall | Sens | Spec |\n"
|
| 427 |
+
text += "|------|------|------|-----|-----|------|--------|------|------|\n"
|
| 428 |
+
|
| 429 |
+
for mid, info in trained_models.items():
|
| 430 |
+
r = info['results']
|
| 431 |
+
c = info['config']
|
| 432 |
+
text += f"| {mid} | {c['type']} | {c['method'].upper()} | {r['eval_f1']:.4f} | {r['eval_accuracy']:.4f} | "
|
| 433 |
+
text += f"{r['eval_precision']:.4f} | {r['eval_recall']:.4f} | "
|
| 434 |
+
text += f"{r['eval_sensitivity']:.4f} | {r['eval_specificity']:.4f} |\n"
|
| 435 |
+
|
| 436 |
+
text += "\n## 基準模型表現(未微調)\n\n"
|
| 437 |
+
text += "| 模型 | F1 | Acc | Prec | Recall | Sens | Spec |\n"
|
| 438 |
+
text += "|------|-----|-----|------|--------|------|------|\n"
|
| 439 |
+
|
| 440 |
+
for mid, info in trained_models.items():
|
| 441 |
+
b = info['baseline']
|
| 442 |
+
c = info['config']
|
| 443 |
+
text += f"| {c['type']}-baseline | {b['f1']:.4f} | {b['accuracy']:.4f} | "
|
| 444 |
+
text += f"{b['precision']:.4f} | {b['recall']:.4f} | "
|
| 445 |
+
text += f"{b['sensitivity']:.4f} | {b['specificity']:.4f} |\n"
|
| 446 |
+
|
| 447 |
+
text += "\n## 🏆 最佳模型\n\n"
|
| 448 |
+
for metric in ['f1', 'accuracy', 'precision', 'recall', 'sensitivity', 'specificity']:
|
| 449 |
+
best = max(trained_models.items(), key=lambda x: x[1]['results'][f'eval_{metric}'])
|
| 450 |
+
baseline_val = best[1]['baseline'][metric]
|
| 451 |
+
finetuned_val = best[1]['results'][f'eval_{metric}']
|
| 452 |
+
improvement = calculate_improvement(baseline_val, finetuned_val)
|
| 453 |
+
|
| 454 |
+
text += f"**{metric.upper()}**: {best[0]} ({finetuned_val:.4f}, 改善 {format_improve(improvement)})\n\n"
|
| 455 |
+
|
| 456 |
+
return text
|
| 457 |
+
|
| 458 |
+
def refresh_model_list():
|
| 459 |
+
return gr.Dropdown(choices=list(trained_models.keys()))
|
| 460 |
+
|
| 461 |
+
# Gradio UI
|
| 462 |
+
with gr.Blocks(title="BERT Fine-tuning 教學平台", theme=gr.themes.Soft()) as demo:
|
| 463 |
+
gr.Markdown("# 🧬 BERT Fine-tuning 教學平台")
|
| 464 |
+
gr.Markdown("### 比較基準模型 vs 微調模型的表現差異")
|
| 465 |
+
|
| 466 |
+
with gr.Tab("訓練"):
|
| 467 |
+
gr.Markdown("## 步驟 1: 選擇基礎模型")
|
| 468 |
+
|
| 469 |
+
base_model = gr.Dropdown(
|
| 470 |
+
choices=["BERT-base"],
|
| 471 |
+
value="BERT-base",
|
| 472 |
+
label="基礎模型",
|
| 473 |
+
info="更多模型即將推出"
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
gr.Markdown("## 步驟 2: 選擇微調方法")
|
| 477 |
+
|
| 478 |
+
method = gr.Radio(
|
| 479 |
+
choices=["lora", "adalora"],
|
| 480 |
+
value="lora",
|
| 481 |
+
label="微調方法",
|
| 482 |
+
info="兩種都是參數高效方法,推薦從 LoRA 開始"
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
gr.Markdown("## 步驟 3: 上傳資料")
|
| 486 |
+
csv_file = gr.File(label="CSV 檔案 (需包含 Text 和 label 欄位)", file_types=[".csv"])
|
| 487 |
+
|
| 488 |
+
gr.Markdown("## 步驟 4: 設定訓練參數")
|
| 489 |
+
|
| 490 |
+
gr.Markdown("### 🎯 基本訓練參數")
|
| 491 |
+
with gr.Row():
|
| 492 |
+
num_epochs = gr.Number(value=3, label="訓練輪數 (epochs)", minimum=1, maximum=100, precision=0)
|
| 493 |
+
batch_size = gr.Number(value=8, label="批次大小 (batch_size)", minimum=1, maximum=128, precision=0)
|
| 494 |
+
learning_rate = gr.Number(value=2e-5, label="學習率 (learning_rate)", minimum=0, maximum=1)
|
| 495 |
+
|
| 496 |
+
gr.Markdown("### ⚙️ 進階參數")
|
| 497 |
+
with gr.Row():
|
| 498 |
+
weight_decay = gr.Number(value=0.01, label="權重衰減 (weight_decay)", minimum=0, maximum=1)
|
| 499 |
+
dropout = gr.Number(value=0.1, label="Dropout 機率", minimum=0, maximum=1)
|
| 500 |
+
|
| 501 |
+
gr.Markdown("### 🔧 LoRA 參數")
|
| 502 |
+
with gr.Row():
|
| 503 |
+
lora_r = gr.Number(value=16, label="LoRA Rank (r)", minimum=1, maximum=256, precision=0,
|
| 504 |
+
info="推薦 8-16,越大效果越好但越慢")
|
| 505 |
+
lora_alpha = gr.Number(value=32, label="LoRA Alpha", minimum=1, maximum=512, precision=0,
|
| 506 |
+
info="通常設為 Rank 的 2 倍")
|
| 507 |
+
lora_dropout = gr.Number(value=0.1, label="LoRA Dropout", minimum=0, maximum=1,
|
| 508 |
+
info="防止過擬合")
|
| 509 |
+
|
| 510 |
+
gr.Markdown("### ⚖️ 評估設定")
|
| 511 |
+
with gr.Row():
|
| 512 |
+
weight_mult = gr.Number(value=2.0, label="類別權重倍數", minimum=0, maximum=10,
|
| 513 |
+
info="推薦 1.5-2.5,過低會忽略少數類")
|
| 514 |
+
best_metric = gr.Dropdown(
|
| 515 |
+
choices=["f1", "accuracy", "precision", "recall", "sensitivity", "specificity"],
|
| 516 |
+
value="f1",
|
| 517 |
+
label="最佳模型選擇指標",
|
| 518 |
+
info="訓練時用此指標選擇最佳模型"
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
train_btn = gr.Button("🚀 開始訓練", variant="primary", size="lg")
|
| 522 |
+
|
| 523 |
+
gr.Markdown("## 📊 訓練結果")
|
| 524 |
+
|
| 525 |
+
data_info = gr.Textbox(label="📋 資料資訊", lines=10)
|
| 526 |
+
|
| 527 |
+
with gr.Row():
|
| 528 |
+
baseline_result = gr.Textbox(label="🔬 純 BERT(未微調)", lines=14)
|
| 529 |
+
finetuned_result = gr.Textbox(label="✅ 微調 BERT", lines=14)
|
| 530 |
+
|
| 531 |
+
comparison_result = gr.Textbox(label="📊 純 BERT vs 微調 BERT 比較", lines=14)
|
| 532 |
+
|
| 533 |
+
train_btn.click(
|
| 534 |
+
train_bert_model,
|
| 535 |
+
inputs=[csv_file, base_model, method, num_epochs, batch_size, learning_rate,
|
| 536 |
+
weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
|
| 537 |
+
weight_mult, best_metric],
|
| 538 |
+
outputs=[data_info, baseline_result, finetuned_result, comparison_result]
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
with gr.Tab("預測"):
|
| 542 |
+
gr.Markdown("## 使用訓練好的模型預測")
|
| 543 |
+
|
| 544 |
+
with gr.Row():
|
| 545 |
+
model_drop = gr.Dropdown(label="選擇模型", choices=list(trained_models.keys()))
|
| 546 |
+
refresh = gr.Button("🔄 刷新")
|
| 547 |
+
|
| 548 |
+
text_input = gr.Textbox(label="輸入病例描述", lines=4,
|
| 549 |
+
placeholder="Patient diagnosed with...")
|
| 550 |
+
predict_btn = gr.Button("預測", variant="primary", size="lg")
|
| 551 |
+
pred_output = gr.Textbox(label="預測結果(含基準模型對比)", lines=20)
|
| 552 |
+
|
| 553 |
+
refresh.click(refresh_model_list, outputs=[model_drop])
|
| 554 |
+
predict_btn.click(predict, inputs=[model_drop, text_input], outputs=[pred_output])
|
| 555 |
+
|
| 556 |
+
gr.Examples(
|
| 557 |
+
examples=[
|
| 558 |
+
["Patient with stage II breast cancer, good response to treatment."],
|
| 559 |
+
["Advanced metastatic cancer, multiple organ involvement."]
|
| 560 |
+
],
|
| 561 |
+
inputs=text_input
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
with gr.Tab("比較"):
|
| 565 |
+
gr.Markdown("## 比較所有模型(含基準模型)")
|
| 566 |
+
compare_btn = gr.Button("比較", variant="primary", size="lg")
|
| 567 |
+
compare_output = gr.Markdown()
|
| 568 |
+
compare_btn.click(compare, outputs=[compare_output])
|
| 569 |
+
|
| 570 |
+
with gr.Tab("說明"):
|
| 571 |
+
gr.Markdown("""
|
| 572 |
+
## 📖 使用說明
|
| 573 |
+
|
| 574 |
+
### 🎯 平台特色
|
| 575 |
+
|
| 576 |
+
本平台會自動比較:
|
| 577 |
+
- **基準模型**:未經微調的原始 BERT
|
| 578 |
+
- **微調模型**:使��你的資料訓練後的 BERT
|
| 579 |
+
|
| 580 |
+
這樣可以清楚看到 fine-tuning 帶來的改善!
|
| 581 |
+
|
| 582 |
+
### 基礎模型
|
| 583 |
+
|
| 584 |
+
- **BERT-base**: 標準 BERT,110M 參數 ⭐目前支援
|
| 585 |
+
|
| 586 |
+
### 微調方法
|
| 587 |
+
|
| 588 |
+
- **LoRA**: 低秩適應,參數高效的微調方法 ⭐強烈推薦
|
| 589 |
+
- 只訓練少量參數(通常 <1%)
|
| 590 |
+
- 訓練速度快,效果好
|
| 591 |
+
- 適合大多數情況
|
| 592 |
+
|
| 593 |
+
- **AdaLoRA**: 自適應 LoRA,動態調整秩
|
| 594 |
+
- 自動找出最重要的參數
|
| 595 |
+
- 可能比 LoRA 效果稍好
|
| 596 |
+
- 訓練時間稍長
|
| 597 |
+
|
| 598 |
+
### 評估指標
|
| 599 |
+
|
| 600 |
+
- **F1**: 平衡指標,推薦用於不平衡資料 ⭐
|
| 601 |
+
- **Accuracy**: 整體準確率
|
| 602 |
+
- **Precision**: 減少假陽性
|
| 603 |
+
- **Recall/Sensitivity**: 減少假陰性
|
| 604 |
+
- **Specificity**: 真陰性率
|
| 605 |
+
|
| 606 |
+
### 參數建議
|
| 607 |
+
|
| 608 |
+
針對不平衡資料(如醫療資料):
|
| 609 |
+
- **微調方法**: LoRA(快速有效)或 AdaLoRA(追求極致)
|
| 610 |
+
- **LoRA Rank**: 8-16(平衡效果與速度)
|
| 611 |
+
- **類別權重倍數**: 1.5-2.5(資料不平衡時)
|
| 612 |
+
- **Learning rate**: 2e-5 到 5e-5
|
| 613 |
+
- **Epochs**: 3-8(避免過擬合)
|
| 614 |
+
- **Batch size**: 8-16(依 GPU 記憶體調整)
|
| 615 |
+
|
| 616 |
+
### 資料格式
|
| 617 |
+
|
| 618 |
+
CSV 必須包含:
|
| 619 |
+
- `Text`: 病例描述
|
| 620 |
+
- `label`: 0=存活, 1=死亡
|
| 621 |
+
|
| 622 |
+
### 🚀 快速開始
|
| 623 |
+
|
| 624 |
+
1. 上傳包含 `Text` 和 `label` 欄位的 CSV
|
| 625 |
+
2. 使用預設參數(適合大多數情況)
|
| 626 |
+
3. 點擊「開始訓練」
|
| 627 |
+
4. 在「預測」分頁測試模型
|
| 628 |
+
5. 在「比較」分頁查看所有模型表現
|
| 629 |
+
""")
|
| 630 |
+
|
| 631 |
+
if __name__ == "__main__":
|
| 632 |
+
demo.launch()
|