Wen1201 committed on
Commit
6a07af8
·
verified ·
1 Parent(s): d00baad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -13
app.py CHANGED
@@ -51,12 +51,20 @@ def setup_bitfit(model):
51
  param.requires_grad = False
52
  return model
53
 
54
- def train_bert_model(csv_file, method, num_epochs, batch_size, learning_rate,
55
  weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
56
  weight_mult, best_metric):
57
  global trained_models, model_counter
58
 
59
- bert_variant = "bert-base-uncased"
 
 
 
 
 
 
 
 
60
 
61
  try:
62
  if csv_file is None:
@@ -76,9 +84,8 @@ def train_bert_model(csv_file, method, num_epochs, batch_size, learning_rate,
76
  ratio = n0 / n1
77
  w0, w1 = 1.0, ratio * weight_mult
78
 
79
- info = f"📊 資料: {len(df_clean)} 筆\n存活: {n0} | 死亡: {n1}\n權重: {w0:.2f} / {w1:.2f}\n模型: {bert_variant}\n方法: {method.upper()}"
80
 
81
- model_name = bert_variant
82
  tokenizer = BertTokenizer.from_pretrained(model_name)
83
  dataset = Dataset.from_pandas(df_clean[['text', 'label']])
84
 
@@ -139,10 +146,10 @@ def train_bert_model(csv_file, method, num_epochs, batch_size, learning_rate,
139
  results = trainer.evaluate()
140
 
141
  model_counter += 1
142
- model_id = f"BERT_Model_{model_counter}_{method}"
143
  trained_models[model_id] = {
144
  'model': model, 'tokenizer': tokenizer, 'results': results,
145
- 'config': {'type': 'BERT', 'variant': bert_variant, 'method': method, 'metric': best_metric}
146
  }
147
 
148
  output = f"✅ 模型: {model_id}\n\n"
@@ -205,13 +212,13 @@ def compare():
205
  return "❌ 尚未訓練模型"
206
 
207
  text = "# 📊 模型比較\n\n"
208
- text += "| 模型 | 方法 | F1 | Acc | Prec | Recall | Sens | Spec |\n"
209
- text += "|------|------|-----|-----|------|--------|------|------|\n"
210
 
211
  for mid, info in trained_models.items():
212
  r = info['results']
213
  c = info['config']
214
- text += f"| {mid} | {c['method'].upper()} | {r['eval_f1']:.4f} | {r['eval_accuracy']:.4f} | "
215
  text += f"{r['eval_precision']:.4f} | {r['eval_recall']:.4f} | "
216
  text += f"{r['eval_sensitivity']:.4f} | {r['eval_specificity']:.4f} |\n"
217
 
@@ -226,7 +233,16 @@ with gr.Blocks(title="BERT Fine-tuning 教學平台", theme=gr.themes.Soft()) as
226
  gr.Markdown("# 🧬 BERT Fine-tuning 教學平台")
227
 
228
  with gr.Tab("訓練"):
229
- gr.Markdown("## 步驟 1: 選擇微調方法")
 
 
 
 
 
 
 
 
 
230
 
231
  method = gr.Radio(
232
  choices=["lora", "adalora", "ia3", "bitfit"],
@@ -234,10 +250,10 @@ with gr.Blocks(title="BERT Fine-tuning 教學平台", theme=gr.themes.Soft()) as
234
  label="微調方法"
235
  )
236
 
237
- gr.Markdown("## 步驟 2: 上傳資料")
238
  csv_file = gr.File(label="CSV 檔案 (需包含 Text 和 label 欄位)", file_types=[".csv"])
239
 
240
- gr.Markdown("## 步驟 3: 設定訓練參數")
241
 
242
  gr.Markdown("### 🎯 基本訓練參數")
243
  with gr.Row():
@@ -278,7 +294,7 @@ with gr.Blocks(title="BERT Fine-tuning 教學平台", theme=gr.themes.Soft()) as
278
 
279
  train_btn.click(
280
  train_bert_model,
281
- inputs=[csv_file, method, num_epochs, batch_size, learning_rate,
282
  weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
283
  weight_mult, best_metric],
284
  outputs=[data_info, train_output, status]
@@ -317,6 +333,12 @@ with gr.Blocks(title="BERT Fine-tuning 教學平台", theme=gr.themes.Soft()) as
317
  gr.Markdown("""
318
  ## 📖 使用說明
319
 
 
 
 
 
 
 
320
  ### 微調方法
321
 
322
  - **LoRA**: 低秩適應,只訓練少量參數 ⭐推薦
 
51
  param.requires_grad = False
52
  return model
53
 
54
+ def train_bert_model(csv_file, base_model, method, num_epochs, batch_size, learning_rate,
55
  weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
56
  weight_mult, best_metric):
57
  global trained_models, model_counter
58
 
59
+ # 模型名稱映射
60
+ model_mapping = {
61
+ "BERT-base": "bert-base-uncased",
62
+ "BERT-large": "bert-large-uncased",
63
+ "BioBERT": "dmis-lab/biobert-v1.1",
64
+ "ClinicalBERT": "emilyalsentzer/Bio_ClinicalBERT"
65
+ }
66
+
67
+ model_name = model_mapping.get(base_model, "bert-base-uncased")
68
 
69
  try:
70
  if csv_file is None:
 
84
  ratio = n0 / n1
85
  w0, w1 = 1.0, ratio * weight_mult
86
 
87
+ info = f"📊 資料: {len(df_clean)} 筆\n存活: {n0} | 死亡: {n1}\n權重: {w0:.2f} / {w1:.2f}\n模型: {base_model}\n方法: {method.upper()}"
88
 
 
89
  tokenizer = BertTokenizer.from_pretrained(model_name)
90
  dataset = Dataset.from_pandas(df_clean[['text', 'label']])
91
 
 
146
  results = trainer.evaluate()
147
 
148
  model_counter += 1
149
+ model_id = f"{base_model}_Model_{model_counter}_{method}"
150
  trained_models[model_id] = {
151
  'model': model, 'tokenizer': tokenizer, 'results': results,
152
+ 'config': {'type': base_model, 'model_name': model_name, 'method': method, 'metric': best_metric}
153
  }
154
 
155
  output = f"✅ 模型: {model_id}\n\n"
 
212
  return "❌ 尚未訓練模型"
213
 
214
  text = "# 📊 模型比較\n\n"
215
+ text += "| 模型 | 基礎模型 | 方法 | F1 | Acc | Prec | Recall | Sens | Spec |\n"
216
+ text += "|------|----------|------|-----|-----|------|--------|------|------|\n"
217
 
218
  for mid, info in trained_models.items():
219
  r = info['results']
220
  c = info['config']
221
+ text += f"| {mid} | {c['type']} | {c['method'].upper()} | {r['eval_f1']:.4f} | {r['eval_accuracy']:.4f} | "
222
  text += f"{r['eval_precision']:.4f} | {r['eval_recall']:.4f} | "
223
  text += f"{r['eval_sensitivity']:.4f} | {r['eval_specificity']:.4f} |\n"
224
 
 
233
  gr.Markdown("# 🧬 BERT Fine-tuning 教學平台")
234
 
235
  with gr.Tab("訓練"):
236
+ gr.Markdown("## 步驟 1: 選擇基礎模型")
237
+
238
+ base_model = gr.Dropdown(
239
+ choices=["BERT-base"],
240
+ value="BERT-base",
241
+ label="基礎模型",
242
+ info="更多模型即將推出"
243
+ )
244
+
245
+ gr.Markdown("## 步驟 2: 選擇微調方法")
246
 
247
  method = gr.Radio(
248
  choices=["lora", "adalora", "ia3", "bitfit"],
 
250
  label="微調方法"
251
  )
252
 
253
+ gr.Markdown("## 步驟 3: 上傳資料")
254
  csv_file = gr.File(label="CSV 檔案 (需包含 Text 和 label 欄位)", file_types=[".csv"])
255
 
256
+ gr.Markdown("## 步驟 4: 設定訓練參數")
257
 
258
  gr.Markdown("### 🎯 基本訓練參數")
259
  with gr.Row():
 
294
 
295
  train_btn.click(
296
  train_bert_model,
297
+ inputs=[csv_file, base_model, method, num_epochs, batch_size, learning_rate,
298
  weight_decay, dropout, lora_r, lora_alpha, lora_dropout,
299
  weight_mult, best_metric],
300
  outputs=[data_info, train_output, status]
 
333
  gr.Markdown("""
334
  ## 📖 使用說明
335
 
336
+ ### 基礎模型
337
+
338
+ - **BERT-base**: 標準 BERT,110M 參數 ⭐目前支援
339
+
340
+ *更多模型(BERT-large、BioBERT、ClinicalBERT)即將推出*
341
+
342
  ### 微調方法
343
 
344
  - **LoRA**: 低秩適應,只訓練少量參數 ⭐推薦