smartTranscend committed on
Commit
4c77d06
·
verified ·
1 Parent(s): 73db451

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +230 -83
app.py CHANGED
@@ -22,43 +22,35 @@ import json
22
  # 檢查 GPU
23
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24
 
25
- def train_bert_model(file, weight_multiplier=0.8, epochs=3):
26
  """
27
- 個函數幾乎完全保持您原始程式碼的邏輯
28
- 它包裝一個函
29
  """
30
 
31
- output_log = []
 
32
 
33
- output_log.append("\n" + "=" * 80)
34
- output_log.append("乳癌存活預測 BERT Fine-tuning")
35
- output_log.append("=" * 80)
36
- output_log.append(f"開始時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
37
- output_log.append(f"使用裝置: {device}")
38
- output_log.append("=" * 80)
39
-
40
- # ============ 以下幾乎都是您的原始程式碼 ============
41
-
42
- # 讀取資料
43
- df_original = pd.read_csv(file.name)
44
  df_clean = pd.DataFrame({
45
  'text': df_original['Text'],
46
  'label': df_original['label']
47
  })
48
  df_clean = df_clean.dropna()
49
 
50
- output_log.append(f"\n原始資料分布:")
51
- output_log.append(f" 存活 (0): {sum(df_clean['label']==0)} ({sum(df_clean['label']==0)/len(df_clean)*100:.1f}%)")
52
- output_log.append(f" 死亡 (1): {sum(df_clean['label']==1)} ({sum(df_clean['label']==1)/len(df_clean)*100:.1f}%)")
53
-
54
- ratio = sum(df_clean['label']==0) / sum(df_clean['label']==1)
55
- output_log.append(f" 不平衡比例: {ratio:.1f}:1")
56
 
57
  # 載入 Tokenizer
58
- output_log.append("\n📦 載入 BERT Tokenizer...")
59
  tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
 
60
 
61
- # 評估函數 - 完全您的原始程式
62
  def compute_metrics(pred):
63
  labels = pred.label_ids
64
  preds = pred.predictions.argmax(-1)
@@ -95,7 +87,29 @@ def train_bert_model(file, weight_multiplier=0.8, epochs=3):
95
  'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
96
  }
97
 
98
- # Tokenization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  dataset = Dataset.from_pandas(df_clean[['text', 'label']])
100
 
101
  def preprocess_function(examples):
@@ -106,29 +120,44 @@ def train_bert_model(file, weight_multiplier=0.8, epochs=3):
106
  train_dataset = train_test_split['train']
107
  eval_dataset = train_test_split['test']
108
 
109
- output_log.append(f"\n✅ 資料集準備完成:")
110
- output_log.append(f" 訓練集: {len(train_dataset)} 筆")
111
- output_log.append(f" 驗證集: {len(eval_dataset)} 筆")
 
 
 
 
 
 
 
 
112
 
113
- # 設定權重 - 照您的原始程式
114
  weight_0 = 1.0
115
  weight_1 = ratio * weight_multiplier
116
 
117
- output_log.append(f"\n權重設定:")
118
- output_log.append(f" 倍數: {weight_multiplier}x")
119
- output_log.append(f" 存活類權重: {weight_0:.3f}")
120
- output_log.append(f" 死亡類權重: {weight_1:.3f}")
121
 
122
  class_weights = torch.tensor([weight_0, weight_1], dtype=torch.float).to(device)
123
 
124
- # 載入模型
125
- output_log.append("\n🔄 初始化模型...")
 
 
 
 
 
 
 
 
126
  model = BertForSequenceClassification.from_pretrained(
127
  "bert-base-uncased", num_labels=2, problem_type="single_label_classification"
128
  )
129
  model = model.to(device)
130
 
131
- # 自訂 Trainer - 您的原始程式
132
  class WeightedTrainer(Trainer):
133
  def compute_loss(self, model, inputs, return_outputs=False):
134
  labels = inputs.pop("labels")
@@ -137,17 +166,17 @@ def train_bert_model(file, weight_multiplier=0.8, epochs=3):
137
  loss = loss_fct(outputs.logits.view(-1, 2), labels.view(-1))
138
  return (loss, outputs) if return_outputs else loss
139
 
140
- # 訓練設定 - 照您的原始程式,只改 eval_strategy
141
  training_args = TrainingArguments(
142
  output_dir='./results_weight',
143
- num_train_epochs=epochs,
144
- per_device_train_batch_size=16,
145
- per_device_eval_batch_size=32,
146
- warmup_steps=200,
147
  weight_decay=0.01,
148
- learning_rate=2e-5,
149
  logging_steps=50,
150
- evaluation_strategy="epoch", # 這裡:eval_strategy → evaluation_strategy
151
  save_strategy="epoch",
152
  load_best_model_at_end=True,
153
  metric_for_best_model="sensitivity",
@@ -160,74 +189,192 @@ def train_bert_model(file, weight_multiplier=0.8, epochs=3):
160
  compute_metrics=compute_metrics
161
  )
162
 
163
- output_log.append("\n🚀 開始訓練...")
164
- output_log.append("-" * 80)
165
 
166
- # 訓練
167
  trainer.train()
168
 
169
- output_log.append("\n✅ Fine-tuned 模型訓練完成!")
170
 
171
- # 評估
172
- output_log.append("\n📊 評估 Fine-tuned 模型...")
173
  finetuned_results = trainer.evaluate()
174
 
175
- output_log.append(f"\nFine-tuned BERT ({weight_multiplier}x 權重) 表現:")
176
- output_log.append(f" F1 Score: {finetuned_results['eval_f1']:.4f}")
177
- output_log.append(f" Accuracy: {finetuned_results['eval_accuracy']:.4f}")
178
- output_log.append(f" Precision: {finetuned_results['eval_precision']:.4f}")
179
- output_log.append(f" Recall: {finetuned_results['eval_recall']:.4f}")
180
- output_log.append(f" Sensitivity: {finetuned_results['eval_sensitivity']:.4f}")
181
- output_log.append(f" Specificity: {finetuned_results['eval_specificity']:.4f}")
182
- output_log.append(f" 混淆矩陣: Tp={finetuned_results['eval_tp']}, Tn={finetuned_results['eval_tn']}, "
183
  f"Fp={finetuned_results['eval_fp']}, Fn={finetuned_results['eval_fn']}")
184
 
185
  # 儲存模型
186
- save_dir = './breast_cancer_bert'
187
  model.save_pretrained(save_dir)
188
  tokenizer.save_pretrained(save_dir)
189
 
190
- output_log.append(f"\n💾 模型已儲存至: {save_dir}")
191
- output_log.append("=" * 80)
192
- output_log.append("🎉 訓練完成!")
193
- output_log.append("=" * 80)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
- # 返回訓練日誌
196
- return "\n".join(output_log)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
- # 建立簡單的 Gradio 介面
199
- with gr.Blocks(title="BERT 乳癌存活預測訓練") as demo:
 
200
  gr.Markdown("""
201
  # 🏥 BERT 乳癌存活預測訓練平台
202
 
203
- 上傳 CSV 檔案(需包含 'Text' 和 'label' 欄位),點擊訓練按鈕開始。
 
 
 
 
 
204
  """)
205
 
206
  with gr.Row():
207
- with gr.Column():
208
- file_input = gr.File(label="上傳 CSV 檔案", file_types=[".csv"])
 
 
 
 
 
 
 
 
209
  weight_slider = gr.Slider(
210
- minimum=0.1, maximum=2.0, value=0.8, step=0.1,
211
- label="權重倍數"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  )
213
- epochs_slider = gr.Slider(
214
- minimum=1, maximum=10, value=3, step=1,
215
- label="訓練輪數 (Epochs)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  )
217
- train_btn = gr.Button("開始訓練", variant="primary")
218
 
219
- with gr.Column():
220
- output = gr.Textbox(
221
- label="訓練輸出",
222
- lines=30,
223
- max_lines=50
224
  )
225
 
226
- train_btn.click(
227
- fn=train_bert_model,
228
- inputs=[file_input, weight_slider, epochs_slider],
229
- outputs=output
 
 
 
 
 
 
 
 
230
  )
 
 
 
 
 
 
 
 
 
231
 
232
  if __name__ == "__main__":
233
  demo.launch()
 
22
  # 檢查 GPU
23
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24
 
25
+ def run_your_original_code(file_path, weight_multiplier, epochs, batch_size, learning_rate, warmup_steps):
26
  """
27
+ 裡直接貼上原始程式碼
28
+ 只把必要的參數改
29
  """
30
 
31
+ # ==================== 以下是您的原始程式碼 ====================
32
+ # 我只把檔案讀取和參數部分改成變數,其他完全不動
33
 
34
+ # 讀取上傳的檔案
35
+ df_original = pd.read_csv(file_path)
 
 
 
 
 
 
 
 
 
36
  df_clean = pd.DataFrame({
37
  'text': df_original['Text'],
38
  'label': df_original['label']
39
  })
40
  df_clean = df_clean.dropna()
41
 
42
+ print("\n" + "=" * 80)
43
+ print("乳癌存活預測 BERT Fine-tuning - " + str(weight_multiplier) + "x 權重策略")
44
+ print("=" * 80)
45
+ print(f"開始時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
46
+ print("=" * 80)
 
47
 
48
  # 載入 Tokenizer
49
+ print("\n📦 載入 BERT Tokenizer...")
50
  tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
51
+ print("✅ Tokenizer 載入完成")
52
 
53
+ # 評估函數 - 完全您的原始程式
54
  def compute_metrics(pred):
55
  labels = pred.label_ids
56
  preds = pred.predictions.argmax(-1)
 
87
  'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)
88
  }
89
 
90
+ # ============================================================================
91
+ # 步驟 1:準備資料(不做平衡)
92
+ # ============================================================================
93
+
94
+ print("\n" + "=" * 80)
95
+ print("步驟 1:準備資料(保持原始比例)")
96
+ print("=" * 80)
97
+
98
+ print(f"\n原始資料分布:")
99
+ print(f" 存活 (0): {sum(df_clean['label']==0)} 筆 ({sum(df_clean['label']==0)/len(df_clean)*100:.1f}%)")
100
+ print(f" 死亡 (1): {sum(df_clean['label']==1)} 筆 ({sum(df_clean['label']==1)/len(df_clean)*100:.1f}%)")
101
+
102
+ ratio = sum(df_clean['label']==0) / sum(df_clean['label']==1)
103
+ print(f" 不平衡比例: {ratio:.1f}:1")
104
+
105
+ # ============================================================================
106
+ # 步驟 2:Tokenization
107
+ # ============================================================================
108
+
109
+ print("\n" + "=" * 80)
110
+ print("步驟 2:Tokenization")
111
+ print("=" * 80)
112
+
113
  dataset = Dataset.from_pandas(df_clean[['text', 'label']])
114
 
115
  def preprocess_function(examples):
 
120
  train_dataset = train_test_split['train']
121
  eval_dataset = train_test_split['test']
122
 
123
+ print(f"\n✅ 資料集準備完成:")
124
+ print(f" 訓練集: {len(train_dataset)} 筆")
125
+ print(f" 驗證集: {len(eval_dataset)} 筆")
126
+
127
+ # ============================================================================
128
+ # 步驟 3:設定權重 - 這裡用參數
129
+ # ============================================================================
130
+
131
+ print("\n" + "=" * 80)
132
+ print(f"步驟 3:設定類別權重({weight_multiplier}x 倍數)")
133
+ print("=" * 80)
134
 
 
135
  weight_0 = 1.0
136
  weight_1 = ratio * weight_multiplier
137
 
138
+ print(f"\n權重設定:")
139
+ print(f" 倍數: {weight_multiplier}x")
140
+ print(f" 存活類權重: {weight_0:.3f}")
141
+ print(f" 死亡類權重: {weight_1:.3f} (= {ratio:.1f} × {weight_multiplier})")
142
 
143
  class_weights = torch.tensor([weight_0, weight_1], dtype=torch.float).to(device)
144
 
145
+ # ============================================================================
146
+ # 步驟 4:訓練模型
147
+ # ============================================================================
148
+
149
+ print("\n" + "=" * 80)
150
+ print("步驟 4:訓練 Fine-tuned BERT 模型")
151
+ print("=" * 80)
152
+
153
+ print("\n🔄 初始化模型...")
154
+
155
  model = BertForSequenceClassification.from_pretrained(
156
  "bert-base-uncased", num_labels=2, problem_type="single_label_classification"
157
  )
158
  model = model.to(device)
159
 
160
+ # 自訂 Trainer(使用權重)- 您的原始程式
161
  class WeightedTrainer(Trainer):
162
  def compute_loss(self, model, inputs, return_outputs=False):
163
  labels = inputs.pop("labels")
 
166
  loss = loss_fct(outputs.logits.view(-1, 2), labels.view(-1))
167
  return (loss, outputs) if return_outputs else loss
168
 
169
+ # 訓練設定 - 使用參數
170
  training_args = TrainingArguments(
171
  output_dir='./results_weight',
172
+ num_train_epochs=epochs, # 使用參數
173
+ per_device_train_batch_size=batch_size, # 使用參數
174
+ per_device_eval_batch_size=batch_size*2, # 使用參數
175
+ warmup_steps=warmup_steps, # 使用參數
176
  weight_decay=0.01,
177
+ learning_rate=learning_rate, # 使用參數
178
  logging_steps=50,
179
+ evaluation_strategy="epoch", # 改為新版參數名
180
  save_strategy="epoch",
181
  load_best_model_at_end=True,
182
  metric_for_best_model="sensitivity",
 
189
  compute_metrics=compute_metrics
190
  )
191
 
192
+ print(f"\n🚀 開始訓練({epochs} epochs)...")
193
+ print("-" * 80)
194
 
 
195
  trainer.train()
196
 
197
+ print("\n✅ Fine-tuned 模型訓練完成!")
198
 
199
+ # 評估 Fine-tuned 模型
200
+ print("\n📊 評估 Fine-tuned 模型...")
201
  finetuned_results = trainer.evaluate()
202
 
203
+ print(f"\nFine-tuned BERT ({weight_multiplier}x 權重) 表現:")
204
+ print(f" F1 Score: {finetuned_results['eval_f1']:.4f}")
205
+ print(f" Accuracy: {finetuned_results['eval_accuracy']:.4f}")
206
+ print(f" Precision: {finetuned_results['eval_precision']:.4f}")
207
+ print(f" Recall: {finetuned_results['eval_recall']:.4f}")
208
+ print(f" Sensitivity: {finetuned_results['eval_sensitivity']:.4f}")
209
+ print(f" Specificity: {finetuned_results['eval_specificity']:.4f}")
210
+ print(f" 混淆矩陣: Tp={finetuned_results['eval_tp']}, Tn={finetuned_results['eval_tn']}, "
211
  f"Fp={finetuned_results['eval_fp']}, Fn={finetuned_results['eval_fn']}")
212
 
213
  # 儲存模型
214
+ save_dir = './breast_cancer_bert_weight'
215
  model.save_pretrained(save_dir)
216
  tokenizer.save_pretrained(save_dir)
217
 
218
+ print(f"\n💾 Fine-tuned 模型已儲存至: {save_dir}")
219
+ print("\n" + "=" * 80)
220
+ print("🎉 訓練完成!")
221
+ print("=" * 80)
222
+ print(f"完成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
223
+
224
+ # 回傳結果
225
+ return finetuned_results
226
+
227
+ # ============================================================================
228
+ # Gradio 介面部分 - 只是包裝,不改您的程式
229
+ # ============================================================================
230
+
231
+ def train_wrapper(file, weight_mult, epochs, batch_size, lr, warmup):
232
+ """包裝函數,處理 Gradio 的輸入輸出"""
233
+
234
+ if file is None:
235
+ return "請上傳 CSV 檔案"
236
 
237
+ try:
238
+ # 呼叫您的原始程式碼
239
+ results = run_your_original_code(
240
+ file_path=file.name,
241
+ weight_multiplier=weight_mult,
242
+ epochs=int(epochs),
243
+ batch_size=int(batch_size),
244
+ learning_rate=lr,
245
+ warmup_steps=int(warmup)
246
+ )
247
+
248
+ # 格式化輸出
249
+ output = f"""
250
+ # 🎉 訓練完成!
251
+
252
+ ## 📊 模型表現指標
253
+
254
+ | 指標 | 數值 |
255
+ |------|------|
256
+ | **F1 Score** | {results['eval_f1']:.4f} |
257
+ | **Accuracy** | {results['eval_accuracy']:.4f} |
258
+ | **Precision** | {results['eval_precision']:.4f} |
259
+ | **Recall** | {results['eval_recall']:.4f} |
260
+ | **Sensitivity** | {results['eval_sensitivity']:.4f} |
261
+ | **Specificity** | {results['eval_specificity']:.4f} |
262
+
263
+ ## 📈 混淆矩陣
264
+
265
+ - True Positive (TP): {results['eval_tp']}
266
+ - True Negative (TN): {results['eval_tn']}
267
+ - False Positive (FP): {results['eval_fp']}
268
+ - False Negative (FN): {results['eval_fn']}
269
+
270
+ ## ⚙️ 使用的參數
271
+
272
+ - 權重倍數: {weight_mult}x
273
+ - 訓練輪數: {epochs}
274
+ - 批次大小: {batch_size}
275
+ - 學習率: {lr}
276
+ - Warmup Steps: {warmup}
277
+
278
+ 模型已儲存至 `./breast_cancer_bert_weight/`
279
+ """
280
+ return output
281
+
282
+ except Exception as e:
283
+ return f"❌ 錯誤:{str(e)}"
284
 
285
+ # 建立 Gradio 介面
286
+ with gr.Blocks(title="BERT 乳癌存活預測訓練", theme=gr.themes.Soft()) as demo:
287
+
288
  gr.Markdown("""
289
  # 🏥 BERT 乳癌存活預測訓練平台
290
 
291
+ ### 使用說明:
292
+ 1. 上傳您的 CSV 檔案(需包含 'Text' 和 'label' 欄位)
293
+ 2. 調整訓練參數(或使用預設值)
294
+ 3. 點擊「開始訓練」
295
+
296
+ **注意**:這個介面只是包裝您的原始程式碼,核心邏輯完全不變。
297
  """)
298
 
299
  with gr.Row():
300
+ with gr.Column(scale=1):
301
+ gr.Markdown("### 📤 資料與參數設定")
302
+
303
+ file_input = gr.File(
304
+ label="上傳 CSV 檔案",
305
+ file_types=[".csv"]
306
+ )
307
+
308
+ gr.Markdown("### ⚙️ 訓練參數")
309
+
310
  weight_slider = gr.Slider(
311
+ minimum=0.1,
312
+ maximum=2.0,
313
+ value=0.8,
314
+ step=0.1,
315
+ label="權重倍數 (Weight Multiplier)",
316
+ info="調整死亡類別的權重,您原始程式使用 0.8"
317
+ )
318
+
319
+ epochs_input = gr.Number(
320
+ value=8,
321
+ label="訓練輪數 (Epochs)",
322
+ info="您原始程式使用 8"
323
+ )
324
+
325
+ batch_size_input = gr.Number(
326
+ value=16,
327
+ label="批次大小 (Batch Size)",
328
+ info="您原始程式使用 16"
329
  )
330
+
331
+ lr_input = gr.Number(
332
+ value=2e-5,
333
+ label="學習率 (Learning Rate)",
334
+ info="您原始程式使用 2e-5"
335
+ )
336
+
337
+ warmup_input = gr.Number(
338
+ value=200,
339
+ label="Warmup Steps",
340
+ info="您原始程式使用 200"
341
+ )
342
+
343
+ train_button = gr.Button(
344
+ "🚀 開始訓練",
345
+ variant="primary",
346
+ size="lg"
347
  )
 
348
 
349
+ with gr.Column(scale=2):
350
+ gr.Markdown("### 📊 訓練結果")
351
+ output_text = gr.Markdown(
352
+ value="等待訓練...",
353
+ label="輸出結果"
354
  )
355
 
356
+ # 設定按鈕動作
357
+ train_button.click(
358
+ fn=train_wrapper,
359
+ inputs=[
360
+ file_input,
361
+ weight_slider,
362
+ epochs_input,
363
+ batch_size_input,
364
+ lr_input,
365
+ warmup_input
366
+ ],
367
+ outputs=output_text
368
  )
369
+
370
+ gr.Markdown("""
371
+ ---
372
+ ### 📝 備註
373
+ - 訓練時間依資料量和參數而定,通常需要 5-15 分鐘
374
+ - 建議使用 GPU 以加快訓練速度
375
+ - 模型會自動儲存在伺服器上
376
+ - 這個介面完全保留您的原始訓練邏輯
377
+ """)
378
 
379
  if __name__ == "__main__":
380
  demo.launch()