# 0624 / app.py
# smartTranscend's picture
# Update app.py
# f9f016f verified
# Raw
# History Blame Contribute Delete
# 25.5 kB
import streamlit as st
import pandas as pd
import torch
from torch import nn
from transformers import (
BertTokenizer,
BertForSequenceClassification,
TrainingArguments,
Trainer,
DataCollatorWithPadding
)
from datasets import Dataset
from sklearn.metrics import (
accuracy_score,
precision_recall_fscore_support,
roc_auc_score,
confusion_matrix
)
import numpy as np
from datetime import datetime
import json
import os
import gc
import random
import traceback
# ==================== 配置 ====================
# Page-level Streamlit settings (must run before any other st.* output call).
_PAGE_CONFIG = dict(
    page_title="BERT 乳癌預測系統",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
# Random seed shared by every RNG below and by the training/split code.
RANDOM_SEED = 42


def set_seed(seed=42):
    """Seed all random number generators used by the app.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA
    device), switches cuDNN to deterministic mode, and asks PyTorch to use
    deterministic algorithms when the installed version provides that API.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # PYTHONHASHSEED is only read when an interpreter starts, so this affects
    # child processes (e.g. data-loader workers), not the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # cuBLAS needs this workspace setting to be deterministic once
    # use_deterministic_algorithms(True) is active.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    try:
        torch.use_deterministic_algorithms(True)
    except AttributeError:
        # Older PyTorch releases lack this API; the seeding above is still a
        # useful best effort, so carry on instead of failing at import time.
        pass


set_seed(RANDOM_SEED)
# Compute device chosen once at import time: CUDA when PyTorch can see a GPU,
# otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Optional dependency: LoRA / AdaLoRA need the `peft` package. PEFT_AVAILABLE
# records whether it imported; the peft names are only bound when it did.
try:
    from peft import AdaLoraConfig, LoraConfig, PeftModel, TaskType, get_peft_model
except ImportError:
    PEFT_AVAILABLE = False
else:
    PEFT_AVAILABLE = True
# ==================== 初始化 Session State ====================
# Per-browser-session defaults: only written the first time the script runs in
# a session, so later reruns keep the stored values.
for _state_key, _state_default in (('models_list', []), ('training_in_progress', False)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# ==================== 模型管理函數 ====================
def load_models_list():
    """Return the saved model registry, or ``[]`` when none exists yet.

    The registry lives in ``./saved_models_list.json`` (relative to the
    working directory), the file written by ``save_models_list``.
    """
    registry_path = './saved_models_list.json'
    if not os.path.exists(registry_path):
        return []
    with open(registry_path, 'r') as registry_file:
        records = json.load(registry_file)
    return records
def save_models_list(models_list):
    """Persist the model registry to ``./saved_models_list.json``.

    The JSON is written to a sibling temporary file first and then moved over
    the real file with ``os.replace``. Writing straight into the target, as
    before, truncated it first, so a failed or interrupted dump left a corrupt
    registry that ``json.load`` could no longer read. Now the previous
    registry stays intact unless the new one was written completely.

    Args:
        models_list: JSON-serialisable list of model-info dicts.

    Raises:
        TypeError: if *models_list* contains values ``json`` cannot encode
            (the existing registry file is left unchanged).
    """
    target_path = './saved_models_list.json'
    tmp_path = target_path + '.tmp'
    try:
        with open(tmp_path, 'w') as f:
            json.dump(models_list, f, indent=2)
        # Atomic on POSIX and Windows when both paths are on the same volume.
        os.replace(tmp_path, target_path)
    except BaseException:
        # Remove the partial temp file, then propagate the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def get_available_models():
    """Build the selectbox labels for every registered model.

    Each label is ``"<model_path> | <training type> | <tuning method> |
    <timestamp>"``; entries without a ``training_type`` are shown as first
    fine-tuning runs. Callers recover the path with the text before the
    first ``" | "``.
    """
    entries = load_models_list()
    labels = []
    for entry in entries:
        kind = entry.get('training_type', '第一次微調')
        labels.append(
            f"{entry['model_path']} | {kind} | {entry['tuning_method']} | {entry['timestamp']}"
        )
    return labels
def get_first_finetuning_models():
    """Return ``model_path`` for every registry entry not flagged as a second fine-tuning."""
    paths = []
    for entry in load_models_list():
        if entry.get('is_second_finetuning', False):
            continue
        paths.append(entry['model_path'])
    return paths
# ==================== 數據處理函數 ====================
def process_csv(file_path, label_column='label'):
    """Read a CSV and show a success message plus a 10-row preview.

    Args:
        file_path: Anything ``pd.read_csv`` accepts; the UI passes Streamlit
            uploaded-file objects.
        label_column: Not used by this function; kept for API compatibility.

    Returns:
        The DataFrame, or None when reading or previewing failed (an error
        message is shown instead).
    """
    try:
        df = pd.read_csv(file_path)
        st.success(f"✅ 已載入 {len(df)} 筆數據")
        st.write("數據預覽:")
        preview = df.head(10)
        st.dataframe(preview)
    except Exception as e:
        st.error(f"❌ 無法讀取 CSV: {str(e)}")
        return None
    return df
def prepare_dataset(df, text_column='text', label_column='label'):
    """Turn an uploaded DataFrame into a ``datasets.Dataset`` for training.

    The result always has a ``text`` column (``str``) and a ``label`` column
    (``int``, 0 or 1). That is what the tokenizer and the 2-label
    ``BertForSequenceClassification`` in ``train_bert_model`` require:
    - a NaN text cell crashes the tokenizer;
    - float labels make Transformers select its multi-label loss;
    - labels other than 0/1 index past the 2-way classifier head.

    Rows with a missing text or label are dropped, and a warning reports how
    many were dropped.

    Args:
        df: DataFrame from ``process_csv`` (None if reading failed).
        text_column: Column holding the input text.
        label_column: Column holding the class label. Values must be 0 or 1;
            numeric strings ("1") and floats (1.0) are accepted.

    Returns:
        The Dataset, or None after showing an error in the UI.
    """
    try:
        if df is None:
            raise ValueError("未提供數據")
        if text_column not in df.columns or label_column not in df.columns:
            raise ValueError(f"CSV 必須包含 '{text_column}' 和 '{label_column}' 列")
        clean = df[[text_column, label_column]].dropna()
        dropped = len(df) - len(clean)
        if dropped:
            st.warning(f"⚠️ 已略過 {dropped} 筆缺少文本或標籤的數據")
        if clean.empty:
            raise ValueError("沒有可用的訓練數據")
        labels = pd.to_numeric(clean[label_column], errors='raise')
        if not labels.isin([0, 1]).all():
            raise ValueError(f"'{label_column}' 列的值必須為 0 或 1")
        data = {
            'text': clean[text_column].astype(str).tolist(),
            'label': labels.astype(int).tolist()
        }
        dataset = Dataset.from_dict(data)
        return dataset
    except Exception as e:
        # UI boundary: show the problem to the user instead of crashing the app.
        st.error(f"❌ 數據準備錯誤: {str(e)}")
        return None
# ==================== 訓練函數 ====================
def _binary_classification_metrics(eval_pred):
    """Compute F1 / accuracy / ROC-AUC for a 2-class ``Trainer`` evaluation.

    Args:
        eval_pred: The ``(logits, labels)`` pair ``transformers.Trainer``
            passes to ``compute_metrics``. Logits are assumed to be
            ``(n_samples, 2)``; TODO confirm if the model's outputs change.

    Returns:
        dict with float ``f1``, ``accuracy`` and ``auc`` (``auc`` is 0.0 when
        the evaluation split contains a single class).
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        # Some models return extra outputs; the logits come first.
        logits = logits[0]
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    predictions = np.argmax(logits, axis=1)
    # precision_recall_fscore_support returns (precision, recall, f1, support);
    # F1 is the THIRD element (the old code unpacked precision by mistake).
    _, _, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='binary', zero_division=0
    )
    acc = accuracy_score(labels, predictions)
    # ROC-AUC needs a continuous score, not hard 0/1 predictions: use the
    # softmax probability of the positive class (index 1).
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp_logits = np.exp(shifted)
    positive_prob = exp_logits[:, 1] / exp_logits.sum(axis=1)
    try:
        auc = roc_auc_score(labels, positive_prob)
    except ValueError:
        # Only one class present in this split: AUC is undefined.
        auc = 0.0
    return {
        'f1': float(f1),
        'accuracy': float(acc),
        'auc': float(auc)
    }


def train_bert_model(
    train_dataset,
    tuning_method="full",
    epochs=3,
    batch_size=16,
    lr=2e-5,
    warmup_steps=200,
    best_metric='f1',
    weight_mult=0.8,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    lora_modules="query,value",
    adalora_init_r=12,
    adalora_target_r=8,
    adalora_tinit=0,
    adalora_tfinal=0,
    adalora_delta_t=1,
    is_second_tuning=False,
    base_model_path=None
):
    """Fine-tune a 2-class BERT classifier, save it and register it.

    Args:
        train_dataset: ``datasets.Dataset`` with ``text`` and ``label``
            columns (see ``prepare_dataset``). 20% is held out for evaluation.
        tuning_method: ``"full"``, ``"LoRA"`` or ``"AdaLoRA"``. When ``peft``
            is not importable, no adapter is attached and the whole model
            is trained.
        epochs, batch_size, lr, warmup_steps: Passed to ``TrainingArguments``.
        best_metric: ``"f1"``, ``"accuracy"`` or ``"auc"``; used to select the
            best checkpoint.
        weight_mult: NOTE(review): accepted (and exposed in the UI) but not
            used anywhere in this function.
        lora_r, lora_alpha, lora_dropout, lora_modules: LoRA settings;
            ``lora_modules`` is a comma-separated list of module names.
        adalora_*: AdaLoRA schedule settings.
        is_second_tuning: Recorded in the registry entry.
        base_model_path: Saved model to continue from (second fine-tuning).
            None starts from ``bert-base-uncased``.

    Returns:
        On success, ``{'status': 'success', 'model_path', 'eval_results',
        'model_info'}``. On failure, ``{'status': 'error', 'message'}``
        (the traceback is also shown in the UI).

    Side effects:
        Writes checkpoints under ``./results``, saves the model and tokenizer
        to ``./bert_model_<timestamp>``, appends to the model registry, and
        draws a progress bar.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()
    try:
        # Load the starting model.
        status_text.text("🔄 載入模型...")
        progress_bar.progress(10)
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        use_peft = tuning_method in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE
        if base_model_path is None:
            # First fine-tuning: start from the pretrained checkpoint.
            model = BertForSequenceClassification.from_pretrained(
                "bert-base-uncased",
                num_labels=2
            )
        elif use_peft:
            # Second fine-tuning of an adapter model: re-attach the saved
            # adapter and keep it trainable. PeftModel.from_pretrained loads
            # adapters frozen by default.
            base_model = BertForSequenceClassification.from_pretrained(
                "bert-base-uncased", num_labels=2
            )
            model = PeftModel.from_pretrained(base_model, base_model_path, is_trainable=True)
        else:
            # Second full fine-tuning: continue from the saved weights
            # (previously ignored in favour of a fresh bert-base-uncased).
            model = BertForSequenceClassification.from_pretrained(base_model_path)

        # Attach a NEW adapter only when starting from the pretrained
        # checkpoint. A second fine-tuning keeps training the loaded adapter
        # instead of nesting another PEFT wrapper on top of it.
        if base_model_path is None and use_peft:
            target_modules = [name.strip() for name in lora_modules.split(",") if name.strip()]
            if tuning_method == "LoRA":
                status_text.text("⚙️ 配置 LoRA...")
                progress_bar.progress(15)
                peft_config = LoraConfig(
                    r=lora_r,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    bias="none",
                    target_modules=target_modules,
                    task_type=TaskType.SEQ_CLS,
                    inference_mode=False
                )
            else:
                status_text.text("⚙️ 配置 AdaLoRA...")
                progress_bar.progress(15)
                # NOTE(review): recent peft versions require `total_step` for
                # AdaLoRA and expect `update_and_allocate` to be called every
                # step; the stock Trainer loop below does neither. Confirm
                # against the installed peft version.
                peft_config = AdaLoraConfig(
                    init_r=adalora_init_r,
                    target_r=adalora_target_r,
                    tinit=int(adalora_tinit),
                    tfinal=int(adalora_tfinal),
                    deltaT=int(adalora_delta_t),
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    bias="none",
                    target_modules=target_modules,
                    task_type=TaskType.SEQ_CLS,
                    inference_mode=False
                )
            model = get_peft_model(model, peft_config)
        model = model.to(device)

        # Tokenize the data.
        status_text.text("📝 分詞化數據...")
        progress_bar.progress(20)

        def tokenize_function(examples):
            return tokenizer(
                examples['text'],
                padding="max_length",
                truncation=True,
                max_length=128
            )

        tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
        tokenized_dataset = tokenized_dataset.remove_columns(['text'])
        tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')
        # 80/20 train/validation split, reproducible via RANDOM_SEED.
        split_dataset = tokenized_dataset.train_test_split(test_size=0.2, seed=RANDOM_SEED)

        # Training.
        status_text.text("🚀 開始訓練...")
        progress_bar.progress(30)
        training_args = TrainingArguments(
            output_dir='./results',
            num_train_epochs=int(epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size),
            warmup_steps=int(warmup_steps),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=100,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model=best_metric,
            learning_rate=lr,
            seed=RANDOM_SEED
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=split_dataset['train'],
            eval_dataset=split_dataset['test'],
            compute_metrics=_binary_classification_metrics,
            data_collator=DataCollatorWithPadding(tokenizer)
        )
        trainer.train()
        progress_bar.progress(80)

        # Evaluate the best checkpoint (restored by load_best_model_at_end).
        status_text.text("📊 評估模型...")
        eval_result = trainer.evaluate(split_dataset['test'])
        progress_bar.progress(90)

        # Save: PEFT models write only their adapter, full models all weights.
        status_text.text("💾 保存模型...")
        model_save_path = f"./bert_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        os.makedirs(model_save_path, exist_ok=True)
        model.save_pretrained(model_save_path)
        tokenizer.save_pretrained(model_save_path)

        # Register the model.
        models_list = load_models_list()
        model_info = {
            'model_path': model_save_path,
            'tuning_method': tuning_method,
            'best_metric': best_metric,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'eval_results': {k: float(v) for k, v in eval_result.items()},
            'training_type': '二次微調' if is_second_tuning else '第一次微調',
            'is_second_finetuning': is_second_tuning,
            'base_model': base_model_path
        }
        models_list.append(model_info)
        save_models_list(models_list)
        progress_bar.progress(100)
        status_text.text("✅ 訓練完成!")

        results = {
            'status': 'success',
            'model_path': model_save_path,
            'eval_results': eval_result,
            'model_info': model_info
        }
        del trainer
        torch.cuda.empty_cache()
        gc.collect()
        return results
    except Exception as e:
        # UI boundary: report the failure instead of crashing the page.
        status_text.text(f"❌ 訓練失敗: {str(e)}")
        st.error(traceback.format_exc())
        return {'status': 'error', 'message': str(e)}
# ==================== 預測函數 ====================
def predict(text, model_path):
    """Classify *text* with the saved model at *model_path*.

    Adapter models (registry method LoRA/AdaLoRA, with peft available) are
    rebuilt on top of ``bert-base-uncased``. Anything else is loaded as a
    full checkpoint from *model_path*.

    Returns:
        On success, a dict with:
        - ``result``: "存活" for class 0, "死亡" for class 1;
        - ``confidence``: probability of the predicted class;
        - ``prob_survive`` / ``prob_death``: class 0 / class 1 probabilities;
        - ``model_info``: the registry entry, or None if it is not registered.

        On any failure, ``{'status': 'error', 'message', 'traceback'}``.
    """
    try:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # Look up the registry entry to know how the model was saved.
        model_info = next(
            (entry for entry in load_models_list() if entry['model_path'] == model_path),
            None,
        )
        tuning_method = 'full' if not model_info else model_info['tuning_method']

        if tuning_method in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE:
            backbone = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
            model = PeftModel.from_pretrained(backbone, model_path)
        else:
            model = BertForSequenceClassification.from_pretrained(model_path)
        model = model.to(device)
        model.eval()

        encoded = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            logits = model(**encoded).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        pred_class = probs.argmax(-1).item()
        first_row = probs[0]
        confidence = first_row[pred_class].item()
        label = "存活" if pred_class == 0 else "死亡"
        prob_survive = first_row[0].item()
        prob_death = first_row[1].item()

        # Free GPU memory; a model is reloaded on every prediction.
        del model
        torch.cuda.empty_cache()
        return {
            'result': label,
            'confidence': confidence,
            'prob_survive': prob_survive,
            'prob_death': prob_death,
            'model_info': model_info
        }
    except Exception as e:
        return {
            'status': 'error',
            'message': str(e),
            'traceback': traceback.format_exc()
        }
# ==================== Streamlit UI ====================
st.title("🔬 BERT 乳癌生存預測系統")

# Sidebar: what this process sees in terms of device, peft and CUDA.
with st.sidebar:
    st.header("📋 系統信息")
    _cuda_ok = torch.cuda.is_available()
    for _info_label, _info_value in (
        ("設備", device),
        ("PEFT 可用", '✅' if PEFT_AVAILABLE else '❌'),
        ("CUDA 可用", '✅' if _cuda_ok else '❌'),
    ):
        st.write(f"**{_info_label}**: {_info_value}")
    if _cuda_ok:
        st.write(f"**GPU**: {torch.cuda.get_device_name(0)}")
        st.write(f"**CUDA 版本**: {torch.version.cuda}")

# Main UI: one tab per workflow step.
_TAB_LABELS = ["1️⃣ 第一次微調", "2️⃣ 二次微調", "3️⃣ 測試評估", "4️⃣ 模型預測"]
tab1, tab2, tab3, tab4 = st.tabs(_TAB_LABELS)
# ==================== Tab 1: 第一次微調 ====================
# Tab 1: upload data, choose hyper-parameters and run the first fine-tuning.
with tab1:
    st.header("第一次微調")
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("📤 數據上傳")
        uploaded_file = st.file_uploader("上傳訓練數據 (CSV)", type=['csv'], key='first_train')
        # Tab-local and always bound: the button handler below can detect
        # "nothing uploaded" instead of raising NameError, and this upload
        # can no longer leak into the second tab through a shared `df`.
        first_df = None
        if uploaded_file:
            first_df = process_csv(uploaded_file)

        st.subheader("⚙️ 訓練參數")
        tuning_method = st.selectbox(
            "微調方法",
            ["full", "LoRA", "AdaLoRA"] if PEFT_AVAILABLE else ["full"],
            key='first_method'
        )
        best_metric = st.selectbox(
            "最佳化指標",
            ["f1", "accuracy", "auc"],
            key='first_metric'
        )
        col_train_a, col_train_b = st.columns(2)
        with col_train_a:
            epochs = st.number_input("訓練輪數", 1, 10, 3, key='first_epochs')
            batch_size = st.number_input("批次大小", 8, 64, 16, key='first_batch')
        with col_train_b:
            lr = st.number_input("學習率", 1e-6, 1e-3, 2e-5, format="%.2e", key='first_lr')
            warmup = st.number_input("Warmup Steps", 0, 1000, 200, key='first_warmup')
        weight_mult = st.slider("權重倍數", 0.1, 2.0, 0.8, 0.1, key='first_weight')

        # LoRA settings (defaults are used when another method is selected).
        if tuning_method == "LoRA":
            st.write("### 🔷 LoRA 參數")
            col_lora_a, col_lora_b = st.columns(2)
            with col_lora_a:
                lora_r = st.slider("LoRA Rank", 4, 64, 16, 4, key='first_lora_r')
                lora_alpha = st.slider("LoRA Alpha", 8, 128, 32, 8, key='first_lora_alpha')
            with col_lora_b:
                lora_dropout = st.slider("LoRA Dropout", 0.0, 0.5, 0.1, 0.05, key='first_lora_dropout')
                lora_modules = st.text_input("目標模組", "query,value", key='first_lora_modules')
        else:
            lora_r = lora_alpha = 16
            lora_dropout = 0.1
            lora_modules = "query,value"

        # AdaLoRA settings (defaults are used when another method is selected).
        if tuning_method == "AdaLoRA":
            st.write("### 🔶 AdaLoRA 參數")
            col_ada_a, col_ada_b = st.columns(2)
            with col_ada_a:
                adalora_init_r = st.slider("初始 Rank", 4, 64, 12, 4, key='first_ada_init_r')
                adalora_target_r = st.slider("目標 Rank", 4, 64, 8, 4, key='first_ada_target_r')
            with col_ada_b:
                adalora_tinit = st.number_input("Tinit", 0, 1000, 0, key='first_ada_tinit')
                adalora_tfinal = st.number_input("Tfinal", 0, 1000, 0, key='first_ada_tfinal')
                adalora_delta_t = st.number_input("Delta T", 1, 100, 1, key='first_ada_delta_t')
        else:
            adalora_init_r = adalora_target_r = 12
            adalora_tinit = adalora_tfinal = 0
            adalora_delta_t = 1

        if st.button("🚀 開始第一次微調", key='first_train_btn'):
            if first_df is None:
                st.warning("⚠️ 請先上傳訓練數據")
            else:
                st.session_state.training_in_progress = True
                try:
                    dataset = prepare_dataset(first_df)
                    if dataset:
                        result = train_bert_model(
                            dataset,
                            tuning_method=tuning_method,
                            epochs=epochs,
                            batch_size=batch_size,
                            lr=lr,
                            warmup_steps=warmup,
                            best_metric=best_metric,
                            weight_mult=weight_mult,
                            lora_r=lora_r,
                            lora_alpha=lora_alpha,
                            lora_dropout=lora_dropout,
                            lora_modules=lora_modules,
                            adalora_init_r=adalora_init_r,
                            adalora_target_r=adalora_target_r,
                            adalora_tinit=adalora_tinit,
                            adalora_tfinal=adalora_tfinal,
                            adalora_delta_t=adalora_delta_t,
                            is_second_tuning=False
                        )
                        if result['status'] == 'success':
                            st.success("✅ 訓練完成!")
                            st.json(result['eval_results'])
                        else:
                            st.error(f"❌ 訓練失敗: {result['message']}")
                finally:
                    # Always clear the flag, even if an unexpected error escapes.
                    st.session_state.training_in_progress = False
# ==================== Tab 2: 二次微調 ====================
# Tab 2: continue training a first-fine-tuning model on new data.
with tab2:
    st.header("二次微調")
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("🔄 選擇基礎模型")
        first_models = get_first_finetuning_models()
        if len(first_models) == 0:
            st.warning("⚠️ 請先進行第一次微調")
        else:
            base_model = st.selectbox("選擇第一次微調模型", first_models, key='second_base')

            st.subheader("📤 數據上傳")
            uploaded_file = st.file_uploader("上傳新訓練數據 (CSV)", type=['csv'], key='second_train')
            # Tab-local so this tab never trains on data uploaded in Tab 1
            # (the old shared `df` silently carried it over).
            second_df = None
            if uploaded_file:
                second_df = process_csv(uploaded_file)

            st.subheader("⚙️ 訓練參數")
            best_metric = st.selectbox(
                "最佳化指標",
                ["f1", "accuracy", "auc"],
                key='second_metric'
            )
            col_train_a, col_train_b = st.columns(2)
            with col_train_a:
                epochs = st.number_input("訓練輪數", 1, 10, 2, key='second_epochs')
                batch_size = st.number_input("批次大小", 8, 64, 16, key='second_batch')
            with col_train_b:
                lr = st.number_input("學習率", 1e-6, 1e-3, 1e-5, format="%.2e", key='second_lr')
                warmup = st.number_input("Warmup Steps", 0, 1000, 100, key='second_warmup')
            weight_mult = st.slider("權重倍數", 0.1, 2.0, 0.8, 0.1, key='second_weight')

            if st.button("🚀 開始二次微調", key='second_train_btn'):
                if second_df is None:
                    st.warning("⚠️ 請先上傳訓練數據")
                else:
                    st.session_state.training_in_progress = True
                    try:
                        dataset = prepare_dataset(second_df)
                        if dataset:
                            # Reuse the tuning method the base model was trained with.
                            tuning_method = next(
                                (m['tuning_method'] for m in load_models_list()
                                 if m['model_path'] == base_model),
                                "full",
                            )
                            result = train_bert_model(
                                dataset,
                                tuning_method=tuning_method,
                                epochs=epochs,
                                batch_size=batch_size,
                                lr=lr,
                                warmup_steps=warmup,
                                best_metric=best_metric,
                                weight_mult=weight_mult,
                                is_second_tuning=True,
                                base_model_path=base_model
                            )
                            if result['status'] == 'success':
                                st.success("✅ 二次微調完成!")
                                st.json(result['eval_results'])
                            else:
                                st.error(f"❌ 訓練失敗: {result['message']}")
                    finally:
                        # Always clear the flag, even if an unexpected error escapes.
                        st.session_state.training_in_progress = False
# ==================== Tab 3: 測試評估 ====================
# Tab 3: evaluation on an uploaded test set (the action itself is a stub).
with tab3:
    st.header("測試評估")
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("📤 上傳測試數據")
        uploaded_file = st.file_uploader("上傳測試數據 (CSV)", type=['csv'], key='test_data')
        if uploaded_file:
            df = process_csv(uploaded_file)

        st.subheader("🎯 選擇模型")
        available_models = get_available_models()
        if not available_models:
            st.warning("⚠️ 無可用模型")
        else:
            selected_model = st.selectbox("選擇模型", available_models, key='test_model')
            evaluate_clicked = st.button("▶️ 開始評估", key='test_btn')
            if evaluate_clicked:
                st.info("⏳ 評估中... (此功能在簡化版本中暫未實現)")
# ==================== Tab 4: 模型預測 ====================
# Tab 4: single-text prediction with any registered model.
with tab4:
    st.header("模型預測")
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("🎯 選擇模型")
        available_models = get_available_models()
        if not available_models:
            st.warning("⚠️ 無可用模型")
        else:
            selected_model = st.selectbox("選擇預測模型", available_models, key='pred_model')

            st.subheader("📝 輸入文本")
            text_input = st.text_area(
                "輸入病歷或文本",
                placeholder="例如: The patient shows signs of ...",
                height=150,
                key='pred_text'
            )

            if st.button("🔮 進行預測", key='pred_btn'):
                if not text_input.strip():
                    st.warning("⚠️ 請輸入文本")
                else:
                    # The selectbox label starts with the model path.
                    model_path = selected_model.partition(" | ")[0]
                    with st.spinner("⏳ 預測中..."):
                        result = predict(text_input, model_path)

                    if result.get('status') == 'error':
                        st.error(f"❌ 預測失敗: {result['message']}")
                    else:
                        col_result_a, col_result_b = st.columns(2)
                        with col_result_a:
                            st.markdown("### 🟢 預測結果")
                            st.markdown(f"## **{result['result']}**")
                            st.metric("信心度", f"{result['confidence']:.1%}")
                        with col_result_b:
                            st.markdown("### 📊 機率分布")
                            st.metric("存活機率", f"{result['prob_survive']:.2%}")
                            st.metric("死亡機率", f"{result['prob_death']:.2%}")

                        st.markdown("---")
                        st.markdown("### 📋 模型信息")
                        info = result['model_info']
                        if info:
                            st.write(f"**訓練類型**: {info['training_type']}")
                            st.write(f"**微調方法**: {info['tuning_method']}")
                            st.write(f"**訓練時間**: {info['timestamp']}")
# Footer disclaimer rendered after all tabs.
_DISCLAIMER = "**注意**: 此預測系統僅供參考,實際醫療決策應由專業醫師判斷。"
st.markdown("---")
st.markdown(_DISCLAIMER)