Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -19,6 +19,7 @@ import numpy as np
|
|
| 19 |
from datetime import datetime
|
| 20 |
import json
|
| 21 |
import os
|
|
|
|
| 22 |
|
| 23 |
# PEFT 相關的 import(LoRA 和 AdaLoRA)
|
| 24 |
try:
|
|
@@ -37,8 +38,7 @@ except ImportError:
|
|
| 37 |
# 檢查 GPU
|
| 38 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 39 |
|
| 40 |
-
|
| 41 |
-
LAST_MODEL_PATH = None
|
| 42 |
LAST_TOKENIZER = None
|
| 43 |
LAST_TUNING_METHOD = None
|
| 44 |
|
|
@@ -157,6 +157,12 @@ def run_original_code_with_tuning(
|
|
| 157 |
|
| 158 |
global LAST_MODEL_PATH, LAST_TOKENIZER, LAST_TUNING_METHOD
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
# ==================== 您的原始程式碼開始 ====================
|
| 161 |
|
| 162 |
# 讀取上傳的檔案
|
|
@@ -466,6 +472,32 @@ def run_original_code_with_tuning(
|
|
| 466 |
model.save_pretrained(save_dir)
|
| 467 |
tokenizer.save_pretrained(save_dir)
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
# 儲存到全域變數供預測使用
|
| 470 |
LAST_MODEL_PATH = save_dir
|
| 471 |
LAST_TOKENIZER = tokenizer
|
|
@@ -477,40 +509,131 @@ def run_original_code_with_tuning(
|
|
| 477 |
print("=" * 80)
|
| 478 |
print(f"完成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
| 479 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
# 加入所有資訊到結果中
|
| 481 |
results['tuning_method'] = tuning_method
|
| 482 |
results['best_metric'] = best_metric
|
| 483 |
results['best_metric_value'] = results[f'eval_{metric_map.get(best_metric, "f1")}']
|
| 484 |
results['baseline_results'] = baseline_results
|
|
|
|
| 485 |
|
| 486 |
return results
|
| 487 |
|
| 488 |
-
def predict_text(text_input):
|
| 489 |
"""
|
| 490 |
-
預測功能 -
|
| 491 |
"""
|
| 492 |
-
global LAST_MODEL_PATH, LAST_TOKENIZER, LAST_TUNING_METHOD
|
| 493 |
|
| 494 |
-
if
|
| 495 |
-
return "
|
| 496 |
|
| 497 |
try:
|
| 498 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
|
| 500 |
# 載入模型
|
| 501 |
-
|
|
|
|
| 502 |
# 載入 PEFT 模型
|
| 503 |
base_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
|
| 504 |
-
|
| 505 |
-
|
| 506 |
else:
|
| 507 |
# 載入一般模型
|
| 508 |
-
|
| 509 |
|
| 510 |
-
|
| 511 |
|
| 512 |
-
# Tokenize 輸入
|
| 513 |
-
|
| 514 |
text_input,
|
| 515 |
truncation=True,
|
| 516 |
padding='max_length',
|
|
@@ -518,41 +641,72 @@ def predict_text(text_input):
|
|
| 518 |
return_tensors='pt'
|
| 519 |
).to(device)
|
| 520 |
|
| 521 |
-
# 預測
|
| 522 |
with torch.no_grad():
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
prob_death = probs[0][1].item()
|
| 532 |
|
| 533 |
-
|
| 534 |
-
#
|
| 535 |
|
| 536 |
-
##
|
| 537 |
|
| 538 |
-
##
|
| 539 |
|
| 540 |
-
##
|
| 541 |
-
- 🟢 **存活機率**: {
|
| 542 |
-
- 🔴 **死亡機率**: {
|
| 543 |
|
|
|
|
| 544 |
### 模型資訊:
|
| 545 |
-
-
|
| 546 |
-
-
|
|
|
|
|
|
|
| 547 |
|
| 548 |
---
|
| 549 |
**注意**: 此預測僅供參考,實際醫療決策應由專業醫師判斷。
|
| 550 |
"""
|
| 551 |
|
| 552 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
|
| 554 |
except Exception as e:
|
| 555 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 556 |
|
| 557 |
# ============================================================================
|
| 558 |
# Gradio 介面部分 - 修改輸出為三個格子
|
|
@@ -851,11 +1005,24 @@ with gr.Blocks(title="BERT 完整訓練與預測平台", theme=gr.themes.Soft())
|
|
| 851 |
gr.Markdown("""
|
| 852 |
### 使用訓練好的模型進行預測
|
| 853 |
|
| 854 |
-
|
| 855 |
""")
|
| 856 |
|
| 857 |
with gr.Row():
|
| 858 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 859 |
text_input = gr.Textbox(
|
| 860 |
label="輸入病歷文本",
|
| 861 |
placeholder="請輸入患者的病歷描述(英文)...",
|
|
@@ -881,9 +1048,18 @@ with gr.Blocks(title="BERT 完整訓練與預測平台", theme=gr.themes.Soft())
|
|
| 881 |
)
|
| 882 |
|
| 883 |
with gr.Column():
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 887 |
)
|
| 888 |
|
| 889 |
with gr.Tab("📖 使用說明"):
|
|
@@ -936,7 +1112,7 @@ with gr.Blocks(title="BERT 完整訓練與預測平台", theme=gr.themes.Soft())
|
|
| 936 |
outputs=[lora_params, adalora_params]
|
| 937 |
)
|
| 938 |
|
| 939 |
-
# 設定按鈕動作 - 注意這裡改為三個輸出
|
| 940 |
train_button.click(
|
| 941 |
fn=train_wrapper,
|
| 942 |
inputs=[
|
|
@@ -963,10 +1139,21 @@ with gr.Blocks(title="BERT 完整訓練與預測平台", theme=gr.themes.Soft())
|
|
| 963 |
outputs=[data_info_output, baseline_output, finetuned_output] # 三個輸出
|
| 964 |
)
|
| 965 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 966 |
predict_button.click(
|
| 967 |
fn=predict_text,
|
| 968 |
-
inputs=[text_input],
|
| 969 |
-
outputs=[
|
| 970 |
)
|
| 971 |
|
| 972 |
if __name__ == "__main__":
|
|
|
|
| 19 |
from datetime import datetime
|
| 20 |
import json
|
| 21 |
import os
|
| 22 |
+
import gc # 用於記憶體清理
|
| 23 |
|
| 24 |
# PEFT 相關的 import(LoRA 和 AdaLoRA)
|
| 25 |
try:
|
|
|
|
| 38 |
# 檢查 GPU
|
| 39 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 40 |
|
| 41 |
+
_MODEL_PATH = None
|
|
|
|
| 42 |
LAST_TOKENIZER = None
|
| 43 |
LAST_TUNING_METHOD = None
|
| 44 |
|
|
|
|
| 157 |
|
| 158 |
global LAST_MODEL_PATH, LAST_TOKENIZER, LAST_TUNING_METHOD
|
| 159 |
|
| 160 |
+
# ==================== 清空記憶體(訓練前) ====================
|
| 161 |
+
import gc
|
| 162 |
+
torch.cuda.empty_cache()
|
| 163 |
+
gc.collect()
|
| 164 |
+
print("🧹 記憶體已清空")
|
| 165 |
+
|
| 166 |
# ==================== 您的原始程式碼開始 ====================
|
| 167 |
|
| 168 |
# 讀取上傳的檔案
|
|
|
|
| 472 |
model.save_pretrained(save_dir)
|
| 473 |
tokenizer.save_pretrained(save_dir)
|
| 474 |
|
| 475 |
+
# 儲存模型資訊到 JSON 檔案(用於預測頁面選擇)
|
| 476 |
+
model_info = {
|
| 477 |
+
'model_path': save_dir,
|
| 478 |
+
'tuning_method': tuning_method,
|
| 479 |
+
'best_metric': best_metric,
|
| 480 |
+
'best_metric_value': float(results[f'eval_{metric_map.get(best_metric, "f1")}']),
|
| 481 |
+
'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
|
| 482 |
+
'weight_multiplier': weight_multiplier,
|
| 483 |
+
'epochs': epochs
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
# 讀取現有的模型列表
|
| 487 |
+
models_list_file = './saved_models_list.json'
|
| 488 |
+
if os.path.exists(models_list_file):
|
| 489 |
+
with open(models_list_file, 'r') as f:
|
| 490 |
+
models_list = json.load(f)
|
| 491 |
+
else:
|
| 492 |
+
models_list = []
|
| 493 |
+
|
| 494 |
+
# 加入新模型資訊
|
| 495 |
+
models_list.append(model_info)
|
| 496 |
+
|
| 497 |
+
# 儲存更新後的列表
|
| 498 |
+
with open(models_list_file, 'w') as f:
|
| 499 |
+
json.dump(models_list, f, indent=2)
|
| 500 |
+
|
| 501 |
# 儲存到全域變數供預測使用
|
| 502 |
LAST_MODEL_PATH = save_dir
|
| 503 |
LAST_TOKENIZER = tokenizer
|
|
|
|
| 509 |
print("=" * 80)
|
| 510 |
print(f"完成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
| 511 |
|
| 512 |
+
# ==================== 清空記憶體(訓練後) ====================
|
| 513 |
+
del model
|
| 514 |
+
del trainer
|
| 515 |
+
torch.cuda.empty_cache()
|
| 516 |
+
gc.collect()
|
| 517 |
+
print("🧹 訓練後記憶體已清空")
|
| 518 |
+
|
| 519 |
# 加入所有資訊到結果中
|
| 520 |
results['tuning_method'] = tuning_method
|
| 521 |
results['best_metric'] = best_metric
|
| 522 |
results['best_metric_value'] = results[f'eval_{metric_map.get(best_metric, "f1")}']
|
| 523 |
results['baseline_results'] = baseline_results
|
| 524 |
+
results['model_path'] = save_dir
|
| 525 |
|
| 526 |
return results
|
| 527 |
|
| 528 |
+
def predict_text(model_choice, text_input):
|
| 529 |
"""
|
| 530 |
+
預測功能 - 支援選擇已訓練的模型,並同時顯示未微調和微調的預測結果
|
| 531 |
"""
|
|
|
|
| 532 |
|
| 533 |
+
if not text_input or text_input.strip() == "":
|
| 534 |
+
return "請輸入文本", "請輸入文本"
|
| 535 |
|
| 536 |
try:
|
| 537 |
+
# ==================== 未微調的 BERT 預測 ====================
|
| 538 |
+
print("\n使用未微調 BERT 預測...")
|
| 539 |
+
baseline_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
| 540 |
+
baseline_model = BertForSequenceClassification.from_pretrained(
|
| 541 |
+
"bert-base-uncased",
|
| 542 |
+
num_labels=2
|
| 543 |
+
).to(device)
|
| 544 |
+
baseline_model.eval()
|
| 545 |
+
|
| 546 |
+
# Tokenize 輸入(未微調)
|
| 547 |
+
baseline_inputs = baseline_tokenizer(
|
| 548 |
+
text_input,
|
| 549 |
+
truncation=True,
|
| 550 |
+
padding='max_length',
|
| 551 |
+
max_length=256,
|
| 552 |
+
return_tensors='pt'
|
| 553 |
+
).to(device)
|
| 554 |
+
|
| 555 |
+
# 預測(未微調)
|
| 556 |
+
with torch.no_grad():
|
| 557 |
+
baseline_outputs = baseline_model(**baseline_inputs)
|
| 558 |
+
baseline_probs = torch.nn.functional.softmax(baseline_outputs.logits, dim=-1)
|
| 559 |
+
baseline_pred_class = baseline_probs.argmax(-1).item()
|
| 560 |
+
baseline_confidence = baseline_probs[0][baseline_pred_class].item()
|
| 561 |
+
|
| 562 |
+
baseline_result = "存活" if baseline_pred_class == 0 else "死亡"
|
| 563 |
+
baseline_prob_survive = baseline_probs[0][0].item()
|
| 564 |
+
baseline_prob_death = baseline_probs[0][1].item()
|
| 565 |
+
|
| 566 |
+
baseline_output = f"""
|
| 567 |
+
# 🔵 未微調 BERT 預測結果
|
| 568 |
+
|
| 569 |
+
## 預測類別: **{baseline_result}**
|
| 570 |
+
|
| 571 |
+
## 信心度: **{baseline_confidence:.1%}**
|
| 572 |
+
|
| 573 |
+
## 機率分布:
|
| 574 |
+
- 🟢 **存活機率**: {baseline_prob_survive:.2%}
|
| 575 |
+
- 🔴 **死亡機率**: {baseline_prob_death:.2%}
|
| 576 |
+
|
| 577 |
+
---
|
| 578 |
+
**說明**: 此為原始 BERT 模型,未經任何領域資料訓練
|
| 579 |
+
"""
|
| 580 |
+
|
| 581 |
+
# 清空記憶體
|
| 582 |
+
del baseline_model
|
| 583 |
+
del baseline_tokenizer
|
| 584 |
+
torch.cuda.empty_cache()
|
| 585 |
+
|
| 586 |
+
# ==================== 微調後的 BERT 預測 ====================
|
| 587 |
+
|
| 588 |
+
if model_choice == "請先訓練模型":
|
| 589 |
+
finetuned_output = """
|
| 590 |
+
# 🟢 微調 BERT 預測結果
|
| 591 |
+
|
| 592 |
+
❌ 尚未訓練任何模型,請先在「模型訓練」頁面訓練模型
|
| 593 |
+
"""
|
| 594 |
+
return baseline_output, finetuned_output
|
| 595 |
+
|
| 596 |
+
# 解析選擇的模型路徑
|
| 597 |
+
model_path = model_choice.split(" | ")[0].replace("路徑: ", "")
|
| 598 |
+
|
| 599 |
+
# 從 JSON 讀取模型資訊
|
| 600 |
+
with open('./saved_models_list.json', 'r') as f:
|
| 601 |
+
models_list = json.load(f)
|
| 602 |
+
|
| 603 |
+
selected_model_info = None
|
| 604 |
+
for model_info in models_list:
|
| 605 |
+
if model_info['model_path'] == model_path:
|
| 606 |
+
selected_model_info = model_info
|
| 607 |
+
break
|
| 608 |
+
|
| 609 |
+
if selected_model_info is None:
|
| 610 |
+
finetuned_output = f"""
|
| 611 |
+
# 🟢 微調 BERT 預測結果
|
| 612 |
+
|
| 613 |
+
❌ 找不到模型:{model_path}
|
| 614 |
+
"""
|
| 615 |
+
return baseline_output, finetuned_output
|
| 616 |
+
|
| 617 |
+
print(f"\n使用微調模型: {model_path}")
|
| 618 |
+
|
| 619 |
+
# 載入 tokenizer
|
| 620 |
+
finetuned_tokenizer = BertTokenizer.from_pretrained(model_path)
|
| 621 |
|
| 622 |
# 載入模型
|
| 623 |
+
tuning_method = selected_model_info['tuning_method']
|
| 624 |
+
if tuning_method in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE:
|
| 625 |
# 載入 PEFT 模型
|
| 626 |
base_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
|
| 627 |
+
finetuned_model = PeftModel.from_pretrained(base_model, model_path)
|
| 628 |
+
finetuned_model = finetuned_model.to(device)
|
| 629 |
else:
|
| 630 |
# 載入一般模型
|
| 631 |
+
finetuned_model = BertForSequenceClassification.from_pretrained(model_path).to(device)
|
| 632 |
|
| 633 |
+
finetuned_model.eval()
|
| 634 |
|
| 635 |
+
# Tokenize 輸入(微調)
|
| 636 |
+
finetuned_inputs = finetuned_tokenizer(
|
| 637 |
text_input,
|
| 638 |
truncation=True,
|
| 639 |
padding='max_length',
|
|
|
|
| 641 |
return_tensors='pt'
|
| 642 |
).to(device)
|
| 643 |
|
| 644 |
+
# 預測(微調)
|
| 645 |
with torch.no_grad():
|
| 646 |
+
finetuned_outputs = finetuned_model(**finetuned_inputs)
|
| 647 |
+
finetuned_probs = torch.nn.functional.softmax(finetuned_outputs.logits, dim=-1)
|
| 648 |
+
finetuned_pred_class = finetuned_probs.argmax(-1).item()
|
| 649 |
+
finetuned_confidence = finetuned_probs[0][finetuned_pred_class].item()
|
| 650 |
|
| 651 |
+
finetuned_result = "存活" if finetuned_pred_class == 0 else "死亡"
|
| 652 |
+
finetuned_prob_survive = finetuned_probs[0][0].item()
|
| 653 |
+
finetuned_prob_death = finetuned_probs[0][1].item()
|
|
|
|
| 654 |
|
| 655 |
+
finetuned_output = f"""
|
| 656 |
+
# 🟢 微調 BERT 預測結果
|
| 657 |
|
| 658 |
+
## 預測類別: **{finetuned_result}**
|
| 659 |
|
| 660 |
+
## 信心度: **{finetuned_confidence:.1%}**
|
| 661 |
|
| 662 |
+
## 機率分布:
|
| 663 |
+
- 🟢 **存活機率**: {finetuned_prob_survive:.2%}
|
| 664 |
+
- 🔴 **死亡機率**: {finetuned_prob_death:.2%}
|
| 665 |
|
| 666 |
+
---
|
| 667 |
### 模型資訊:
|
| 668 |
+
- **微調方法**: {selected_model_info['tuning_method']}
|
| 669 |
+
- **最佳化指標**: {selected_model_info['best_metric']}
|
| 670 |
+
- **訓練時間**: {selected_model_info['timestamp']}
|
| 671 |
+
- **模型路徑**: {model_path}
|
| 672 |
|
| 673 |
---
|
| 674 |
**注意**: 此預測僅供參考,實際醫療決策應由專業醫師判斷。
|
| 675 |
"""
|
| 676 |
|
| 677 |
+
# 清空記憶體
|
| 678 |
+
del finetuned_model
|
| 679 |
+
del finetuned_tokenizer
|
| 680 |
+
torch.cuda.empty_cache()
|
| 681 |
+
|
| 682 |
+
return baseline_output, finetuned_output
|
| 683 |
|
| 684 |
except Exception as e:
|
| 685 |
+
import traceback
|
| 686 |
+
error_msg = f"❌ 預測錯誤:{str(e)}\n\n詳細錯誤訊息:\n{traceback.format_exc()}"
|
| 687 |
+
return error_msg, error_msg
|
| 688 |
+
|
| 689 |
+
def get_available_models():
|
| 690 |
+
"""
|
| 691 |
+
取得所有已訓練的模型列表
|
| 692 |
+
"""
|
| 693 |
+
models_list_file = './saved_models_list.json'
|
| 694 |
+
if not os.path.exists(models_list_file):
|
| 695 |
+
return ["請先訓練模型"]
|
| 696 |
+
|
| 697 |
+
with open(models_list_file, 'r') as f:
|
| 698 |
+
models_list = json.load(f)
|
| 699 |
+
|
| 700 |
+
if len(models_list) == 0:
|
| 701 |
+
return ["請先訓練模型"]
|
| 702 |
+
|
| 703 |
+
# 格式化模型選項
|
| 704 |
+
model_choices = []
|
| 705 |
+
for i, model_info in enumerate(models_list, 1):
|
| 706 |
+
choice = f"路徑: {model_info['model_path']} | 方法: {model_info['tuning_method']} | 時間: {model_info['timestamp']}"
|
| 707 |
+
model_choices.append(choice)
|
| 708 |
+
|
| 709 |
+
return model_choices
|
| 710 |
|
| 711 |
# ============================================================================
|
| 712 |
# Gradio 介面部分 - 修改輸出為三個格子
|
|
|
|
| 1005 |
gr.Markdown("""
|
| 1006 |
### 使用訓練好的模型進行預測
|
| 1007 |
|
| 1008 |
+
選擇已訓練的模型,輸入病歷文本進行預測。會同時顯示未微調和微調模型的預測結果以供比較。
|
| 1009 |
""")
|
| 1010 |
|
| 1011 |
with gr.Row():
|
| 1012 |
with gr.Column():
|
| 1013 |
+
# 模型選擇下拉選單
|
| 1014 |
+
model_dropdown = gr.Dropdown(
|
| 1015 |
+
label="選擇模型",
|
| 1016 |
+
choices=["請先訓練模型"],
|
| 1017 |
+
value="請先訓練模型",
|
| 1018 |
+
info="選擇要使用的已訓練模型"
|
| 1019 |
+
)
|
| 1020 |
+
|
| 1021 |
+
refresh_button = gr.Button(
|
| 1022 |
+
"🔄 重新整理模型列表",
|
| 1023 |
+
size="sm"
|
| 1024 |
+
)
|
| 1025 |
+
|
| 1026 |
text_input = gr.Textbox(
|
| 1027 |
label="輸入病歷文本",
|
| 1028 |
placeholder="請輸入患者的病歷描述(英文)...",
|
|
|
|
| 1048 |
)
|
| 1049 |
|
| 1050 |
with gr.Column():
|
| 1051 |
+
gr.Markdown("### 預測結果比較")
|
| 1052 |
+
|
| 1053 |
+
# 上框:未微調 BERT 預測結果
|
| 1054 |
+
baseline_prediction_output = gr.Markdown(
|
| 1055 |
+
label="未微調 BERT",
|
| 1056 |
+
value="等待預測..."
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
# 下框:微調 BERT 預測結果
|
| 1060 |
+
finetuned_prediction_output = gr.Markdown(
|
| 1061 |
+
label="微調 BERT",
|
| 1062 |
+
value="等待預測..."
|
| 1063 |
)
|
| 1064 |
|
| 1065 |
with gr.Tab("📖 使用說明"):
|
|
|
|
| 1112 |
outputs=[lora_params, adalora_params]
|
| 1113 |
)
|
| 1114 |
|
| 1115 |
+
# 設定訓練按鈕動作 - 注意這裡改為三個輸出
|
| 1116 |
train_button.click(
|
| 1117 |
fn=train_wrapper,
|
| 1118 |
inputs=[
|
|
|
|
| 1139 |
outputs=[data_info_output, baseline_output, finetuned_output] # 三個輸出
|
| 1140 |
)
|
| 1141 |
|
| 1142 |
+
# 重新整理模型列表按鈕
|
| 1143 |
+
def refresh_models():
|
| 1144 |
+
return gr.update(choices=get_available_models(), value=get_available_models()[0])
|
| 1145 |
+
|
| 1146 |
+
refresh_button.click(
|
| 1147 |
+
fn=refresh_models,
|
| 1148 |
+
inputs=[],
|
| 1149 |
+
outputs=[model_dropdown]
|
| 1150 |
+
)
|
| 1151 |
+
|
| 1152 |
+
# 預測按鈕動作 - 兩個輸出:未微調和微調
|
| 1153 |
predict_button.click(
|
| 1154 |
fn=predict_text,
|
| 1155 |
+
inputs=[model_dropdown, text_input],
|
| 1156 |
+
outputs=[baseline_prediction_output, finetuned_prediction_output]
|
| 1157 |
)
|
| 1158 |
|
| 1159 |
if __name__ == "__main__":
|