smartTranscend commited on
Commit
39d09a2
·
verified ·
1 Parent(s): 42f5370

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -42
app.py CHANGED
@@ -19,6 +19,7 @@ import numpy as np
19
  from datetime import datetime
20
  import json
21
  import os
 
22
 
23
  # PEFT 相關的 import(LoRA 和 AdaLoRA)
24
  try:
@@ -37,8 +38,7 @@ except ImportError:
37
  # 檢查 GPU
38
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
 
40
- # 全域變數儲存最後訓練的模型路徑和 tokenizer
41
- LAST_MODEL_PATH = None
42
  LAST_TOKENIZER = None
43
  LAST_TUNING_METHOD = None
44
 
@@ -157,6 +157,12 @@ def run_original_code_with_tuning(
157
 
158
  global LAST_MODEL_PATH, LAST_TOKENIZER, LAST_TUNING_METHOD
159
 
 
 
 
 
 
 
160
  # ==================== 您的原始程式碼開始 ====================
161
 
162
  # 讀取上傳的檔案
@@ -466,6 +472,32 @@ def run_original_code_with_tuning(
466
  model.save_pretrained(save_dir)
467
  tokenizer.save_pretrained(save_dir)
468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  # 儲存到全域變數供預測使用
470
  LAST_MODEL_PATH = save_dir
471
  LAST_TOKENIZER = tokenizer
@@ -477,40 +509,131 @@ def run_original_code_with_tuning(
477
  print("=" * 80)
478
  print(f"完成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
479
 
 
 
 
 
 
 
 
480
  # 加入所有資訊到結果中
481
  results['tuning_method'] = tuning_method
482
  results['best_metric'] = best_metric
483
  results['best_metric_value'] = results[f'eval_{metric_map.get(best_metric, "f1")}']
484
  results['baseline_results'] = baseline_results
 
485
 
486
  return results
487
 
488
- def predict_text(text_input):
489
  """
490
- 預測功能 - 使用最後訓練的模型
491
  """
492
- global LAST_MODEL_PATH, LAST_TOKENIZER, LAST_TUNING_METHOD
493
 
494
- if LAST_MODEL_PATH is None:
495
- return "先訓練模型"
496
 
497
  try:
498
- print(f"\n使用模型: {LAST_MODEL_PATH}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
 
500
  # 載入模型
501
- if LAST_TUNING_METHOD in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE:
 
502
  # 載入 PEFT 模型
503
  base_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
504
- model = PeftModel.from_pretrained(base_model, LAST_MODEL_PATH)
505
- model = model.to(device)
506
  else:
507
  # 載入一般模型
508
- model = BertForSequenceClassification.from_pretrained(LAST_MODEL_PATH).to(device)
509
 
510
- model.eval()
511
 
512
- # Tokenize 輸入
513
- inputs = LAST_TOKENIZER(
514
  text_input,
515
  truncation=True,
516
  padding='max_length',
@@ -518,41 +641,72 @@ def predict_text(text_input):
518
  return_tensors='pt'
519
  ).to(device)
520
 
521
- # 預測
522
  with torch.no_grad():
523
- outputs = model(**inputs)
524
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
525
- pred_class = probs.argmax(-1).item()
526
- confidence = probs[0][pred_class].item()
527
 
528
- # 準備結果
529
- result = "存活" if pred_class == 0 else "死亡"
530
- prob_survive = probs[0][0].item()
531
- prob_death = probs[0][1].item()
532
 
533
- output = f"""
534
- ## 🔮 預測結果
535
 
536
- ### 預測類別: **{result}**
537
 
538
- ### 信心度: **{confidence:.1%}**
539
 
540
- ### 機率分布:
541
- - 🟢 **存活機率**: {prob_survive:.2%}
542
- - 🔴 **死亡機率**: {prob_death:.2%}
543
 
 
544
  ### 模型資訊:
545
- - 使用方法: {LAST_TUNING_METHOD}
546
- - 模型路徑: {LAST_MODEL_PATH}
 
 
547
 
548
  ---
549
  **注意**: 此預測僅供參考,實際醫療決策應由專業醫師判斷。
550
  """
551
 
552
- return output
 
 
 
 
 
553
 
554
  except Exception as e:
555
- return f"❌ 預測錯誤:{str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
556
 
557
  # ============================================================================
558
  # Gradio 介面部分 - 修改輸出為三個格子
@@ -851,11 +1005,24 @@ with gr.Blocks(title="BERT 完整訓練與預測平台", theme=gr.themes.Soft())
851
  gr.Markdown("""
852
  ### 使用訓練好的模型進行預測
853
 
854
- 訓練頁面完成訓練後可以在這裡輸入病歷文本進行預測。
855
  """)
856
 
857
  with gr.Row():
858
  with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
859
  text_input = gr.Textbox(
860
  label="輸入病歷文本",
861
  placeholder="請輸入患者的病歷描述(英文)...",
@@ -881,9 +1048,18 @@ with gr.Blocks(title="BERT 完整訓練與預測平台", theme=gr.themes.Soft())
881
  )
882
 
883
  with gr.Column():
884
- prediction_output = gr.Markdown(
885
- label="預測結果",
886
- value="請先完成模型訓練,然後輸入文本進行預測..."
 
 
 
 
 
 
 
 
 
887
  )
888
 
889
  with gr.Tab("📖 使用說明"):
@@ -936,7 +1112,7 @@ with gr.Blocks(title="BERT 完整訓練與預測平台", theme=gr.themes.Soft())
936
  outputs=[lora_params, adalora_params]
937
  )
938
 
939
- # 設定按鈕動作 - 注意這裡改為三個輸出
940
  train_button.click(
941
  fn=train_wrapper,
942
  inputs=[
@@ -963,10 +1139,21 @@ with gr.Blocks(title="BERT 完整訓練與預測平台", theme=gr.themes.Soft())
963
  outputs=[data_info_output, baseline_output, finetuned_output] # 三個輸出
964
  )
965
 
 
 
 
 
 
 
 
 
 
 
 
966
  predict_button.click(
967
  fn=predict_text,
968
- inputs=[text_input],
969
- outputs=[prediction_output]
970
  )
971
 
972
  if __name__ == "__main__":
 
19
  from datetime import datetime
20
  import json
21
  import os
22
+ import gc # 用於記憶體清理
23
 
24
  # PEFT 相關的 import(LoRA 和 AdaLoRA)
25
  try:
 
38
  # 檢查 GPU
39
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
 
41
+ _MODEL_PATH = None
 
42
  LAST_TOKENIZER = None
43
  LAST_TUNING_METHOD = None
44
 
 
157
 
158
  global LAST_MODEL_PATH, LAST_TOKENIZER, LAST_TUNING_METHOD
159
 
160
+ # ==================== 清空記憶體(訓練前) ====================
161
+ import gc
162
+ torch.cuda.empty_cache()
163
+ gc.collect()
164
+ print("🧹 記憶體已清空")
165
+
166
  # ==================== 您的原始程式碼開始 ====================
167
 
168
  # 讀取上傳的檔案
 
472
  model.save_pretrained(save_dir)
473
  tokenizer.save_pretrained(save_dir)
474
 
475
+ # 儲存模型資訊到 JSON 檔案(用於預測頁面選擇)
476
+ model_info = {
477
+ 'model_path': save_dir,
478
+ 'tuning_method': tuning_method,
479
+ 'best_metric': best_metric,
480
+ 'best_metric_value': float(results[f'eval_{metric_map.get(best_metric, "f1")}']),
481
+ 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
482
+ 'weight_multiplier': weight_multiplier,
483
+ 'epochs': epochs
484
+ }
485
+
486
+ # 讀取現有的模型列表
487
+ models_list_file = './saved_models_list.json'
488
+ if os.path.exists(models_list_file):
489
+ with open(models_list_file, 'r') as f:
490
+ models_list = json.load(f)
491
+ else:
492
+ models_list = []
493
+
494
+ # 加入新模型資訊
495
+ models_list.append(model_info)
496
+
497
+ # 儲存更新後的列表
498
+ with open(models_list_file, 'w') as f:
499
+ json.dump(models_list, f, indent=2)
500
+
501
  # 儲存到全域變數供預測使用
502
  LAST_MODEL_PATH = save_dir
503
  LAST_TOKENIZER = tokenizer
 
509
  print("=" * 80)
510
  print(f"完成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
511
 
512
+ # ==================== 清空記憶體(訓練後) ====================
513
+ del model
514
+ del trainer
515
+ torch.cuda.empty_cache()
516
+ gc.collect()
517
+ print("🧹 訓練後記憶體已清空")
518
+
519
  # 加入所有資訊到結果中
520
  results['tuning_method'] = tuning_method
521
  results['best_metric'] = best_metric
522
  results['best_metric_value'] = results[f'eval_{metric_map.get(best_metric, "f1")}']
523
  results['baseline_results'] = baseline_results
524
+ results['model_path'] = save_dir
525
 
526
  return results
527
 
528
+ def predict_text(model_choice, text_input):
529
  """
530
+ 預測功能 - 支援選擇已訓練的模型,並同時顯示未微調和微調的預測結果
531
  """
 
532
 
533
+ if not text_input or text_input.strip() == "":
534
+ return "請輸入文本", "輸入文本"
535
 
536
  try:
537
+ # ==================== 未微調的 BERT 預測 ====================
538
+ print("\n使用未微調 BERT 預測...")
539
+ baseline_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
540
+ baseline_model = BertForSequenceClassification.from_pretrained(
541
+ "bert-base-uncased",
542
+ num_labels=2
543
+ ).to(device)
544
+ baseline_model.eval()
545
+
546
+ # Tokenize 輸入(未微調)
547
+ baseline_inputs = baseline_tokenizer(
548
+ text_input,
549
+ truncation=True,
550
+ padding='max_length',
551
+ max_length=256,
552
+ return_tensors='pt'
553
+ ).to(device)
554
+
555
+ # 預測(未微調)
556
+ with torch.no_grad():
557
+ baseline_outputs = baseline_model(**baseline_inputs)
558
+ baseline_probs = torch.nn.functional.softmax(baseline_outputs.logits, dim=-1)
559
+ baseline_pred_class = baseline_probs.argmax(-1).item()
560
+ baseline_confidence = baseline_probs[0][baseline_pred_class].item()
561
+
562
+ baseline_result = "存活" if baseline_pred_class == 0 else "死亡"
563
+ baseline_prob_survive = baseline_probs[0][0].item()
564
+ baseline_prob_death = baseline_probs[0][1].item()
565
+
566
+ baseline_output = f"""
567
+ # 🔵 未微調 BERT 預測結果
568
+
569
+ ## 預測類別: **{baseline_result}**
570
+
571
+ ## 信心度: **{baseline_confidence:.1%}**
572
+
573
+ ## 機率分布:
574
+ - 🟢 **存活機率**: {baseline_prob_survive:.2%}
575
+ - 🔴 **死亡機率**: {baseline_prob_death:.2%}
576
+
577
+ ---
578
+ **說明**: 此為原始 BERT 模型,未經任何領域資料訓練
579
+ """
580
+
581
+ # 清空記憶體
582
+ del baseline_model
583
+ del baseline_tokenizer
584
+ torch.cuda.empty_cache()
585
+
586
+ # ==================== 微調後的 BERT 預測 ====================
587
+
588
+ if model_choice == "請先訓練模型":
589
+ finetuned_output = """
590
+ # 🟢 微調 BERT 預測結果
591
+
592
+ ❌ 尚未訓練任何模型,請先在「模型訓練」頁面訓練模型
593
+ """
594
+ return baseline_output, finetuned_output
595
+
596
+ # 解析選擇的模型路徑
597
+ model_path = model_choice.split(" | ")[0].replace("路徑: ", "")
598
+
599
+ # 從 JSON 讀取模型資訊
600
+ with open('./saved_models_list.json', 'r') as f:
601
+ models_list = json.load(f)
602
+
603
+ selected_model_info = None
604
+ for model_info in models_list:
605
+ if model_info['model_path'] == model_path:
606
+ selected_model_info = model_info
607
+ break
608
+
609
+ if selected_model_info is None:
610
+ finetuned_output = f"""
611
+ # 🟢 微調 BERT 預測結果
612
+
613
+ ❌ 找不到模型:{model_path}
614
+ """
615
+ return baseline_output, finetuned_output
616
+
617
+ print(f"\n使用微調模型: {model_path}")
618
+
619
+ # 載入 tokenizer
620
+ finetuned_tokenizer = BertTokenizer.from_pretrained(model_path)
621
 
622
  # 載入模型
623
+ tuning_method = selected_model_info['tuning_method']
624
+ if tuning_method in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE:
625
  # 載入 PEFT 模型
626
  base_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
627
+ finetuned_model = PeftModel.from_pretrained(base_model, model_path)
628
+ finetuned_model = finetuned_model.to(device)
629
  else:
630
  # 載入一般模型
631
+ finetuned_model = BertForSequenceClassification.from_pretrained(model_path).to(device)
632
 
633
+ finetuned_model.eval()
634
 
635
+ # Tokenize 輸入(微調)
636
+ finetuned_inputs = finetuned_tokenizer(
637
  text_input,
638
  truncation=True,
639
  padding='max_length',
 
641
  return_tensors='pt'
642
  ).to(device)
643
 
644
+ # 預測(微調)
645
  with torch.no_grad():
646
+ finetuned_outputs = finetuned_model(**finetuned_inputs)
647
+ finetuned_probs = torch.nn.functional.softmax(finetuned_outputs.logits, dim=-1)
648
+ finetuned_pred_class = finetuned_probs.argmax(-1).item()
649
+ finetuned_confidence = finetuned_probs[0][finetuned_pred_class].item()
650
 
651
+ finetuned_result = "存活" if finetuned_pred_class == 0 else "死亡"
652
+ finetuned_prob_survive = finetuned_probs[0][0].item()
653
+ finetuned_prob_death = finetuned_probs[0][1].item()
 
654
 
655
+ finetuned_output = f"""
656
+ # 🟢 微調 BERT 預測結果
657
 
658
+ ## 預測類別: **{finetuned_result}**
659
 
660
+ ## 信心度: **{finetuned_confidence:.1%}**
661
 
662
+ ## 機率分布:
663
+ - 🟢 **存活機率**: {finetuned_prob_survive:.2%}
664
+ - 🔴 **死亡機率**: {finetuned_prob_death:.2%}
665
 
666
+ ---
667
  ### 模型資訊:
668
+ - **微調方法**: {selected_model_info['tuning_method']}
669
+ - **最佳化指標**: {selected_model_info['best_metric']}
670
+ - **訓練時間**: {selected_model_info['timestamp']}
671
+ - **模型路徑**: {model_path}
672
 
673
  ---
674
  **注意**: 此預測僅供參考,實際醫療決策應由專業醫師判斷。
675
  """
676
 
677
+ # 清空記憶體
678
+ del finetuned_model
679
+ del finetuned_tokenizer
680
+ torch.cuda.empty_cache()
681
+
682
+ return baseline_output, finetuned_output
683
 
684
  except Exception as e:
685
+ import traceback
686
+ error_msg = f"❌ 預測錯誤:{str(e)}\n\n詳細錯誤訊息:\n{traceback.format_exc()}"
687
+ return error_msg, error_msg
688
+
689
+ def get_available_models():
690
+ """
691
+ 取得所有已訓練的模型列表
692
+ """
693
+ models_list_file = './saved_models_list.json'
694
+ if not os.path.exists(models_list_file):
695
+ return ["請先訓練模型"]
696
+
697
+ with open(models_list_file, 'r') as f:
698
+ models_list = json.load(f)
699
+
700
+ if len(models_list) == 0:
701
+ return ["請先訓練模型"]
702
+
703
+ # 格式化模型選項
704
+ model_choices = []
705
+ for i, model_info in enumerate(models_list, 1):
706
+ choice = f"路徑: {model_info['model_path']} | 方法: {model_info['tuning_method']} | 時間: {model_info['timestamp']}"
707
+ model_choices.append(choice)
708
+
709
+ return model_choices
710
 
711
  # ============================================================================
712
  # Gradio 介面部分 - 修改輸出為三個格子
 
1005
  gr.Markdown("""
1006
  ### 使用訓練好的模型進行預測
1007
 
1008
+ 選擇已訓練的模型,輸入病歷文本進行預測。會同時顯示未微調和微調模型的預測結果以供比較。
1009
  """)
1010
 
1011
  with gr.Row():
1012
  with gr.Column():
1013
+ # 模型選擇下拉選單
1014
+ model_dropdown = gr.Dropdown(
1015
+ label="選擇模型",
1016
+ choices=["請先訓練模型"],
1017
+ value="請先訓練模型",
1018
+ info="選擇要使用的已訓練模型"
1019
+ )
1020
+
1021
+ refresh_button = gr.Button(
1022
+ "🔄 重新整理模型列表",
1023
+ size="sm"
1024
+ )
1025
+
1026
  text_input = gr.Textbox(
1027
  label="輸入病歷文本",
1028
  placeholder="請輸入患者的病歷描述(英文)...",
 
1048
  )
1049
 
1050
  with gr.Column():
1051
+ gr.Markdown("### 預測結果比較")
1052
+
1053
+ # 上框:未微調 BERT 預測結果
1054
+ baseline_prediction_output = gr.Markdown(
1055
+ label="未微調 BERT",
1056
+ value="等待預測..."
1057
+ )
1058
+
1059
+ # 下框:微調 BERT 預測結果
1060
+ finetuned_prediction_output = gr.Markdown(
1061
+ label="微調 BERT",
1062
+ value="等待預測..."
1063
  )
1064
 
1065
  with gr.Tab("📖 使用說明"):
 
1112
  outputs=[lora_params, adalora_params]
1113
  )
1114
 
1115
+ # 設定訓練按鈕動作 - 注意這裡改為三個輸出
1116
  train_button.click(
1117
  fn=train_wrapper,
1118
  inputs=[
 
1139
  outputs=[data_info_output, baseline_output, finetuned_output] # 三個輸出
1140
  )
1141
 
1142
+ # 重新整理模型列表按鈕
1143
+ def refresh_models():
1144
+ return gr.update(choices=get_available_models(), value=get_available_models()[0])
1145
+
1146
+ refresh_button.click(
1147
+ fn=refresh_models,
1148
+ inputs=[],
1149
+ outputs=[model_dropdown]
1150
+ )
1151
+
1152
+ # 預測按鈕動作 - 兩個輸出:未微調和微調
1153
  predict_button.click(
1154
  fn=predict_text,
1155
+ inputs=[model_dropdown, text_input],
1156
+ outputs=[baseline_prediction_output, finetuned_prediction_output]
1157
  )
1158
 
1159
  if __name__ == "__main__":