"""Streamlit app for fine-tuning BERT on breast-cancer survival text data.

The app offers four tabs: first-stage fine-tuning, second-stage fine-tuning
on top of a first-stage model, test evaluation (placeholder) and single-text
prediction. Trained models are saved to timestamped directories and indexed
in a JSON file next to the script.
"""

import gc
import json
import os
import random
import traceback
from datetime import datetime

import numpy as np
import pandas as pd
import streamlit as st
import torch
from datasets import Dataset
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    precision_recall_fscore_support,
    roc_auc_score,
)
from torch import nn
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

# ==================== Configuration ====================
# Must stay the first Streamlit call in the script.
st.set_page_config(
    page_title="BERT 乳癌預測系統",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Random seed shared by training, dataset splitting and RNG setup.
RANDOM_SEED = 42

# Checkpoint used for the tokenizer and as the backbone of every model.
BASE_CHECKPOINT = "bert-base-uncased"
# JSON index of all models trained by this app.
MODELS_LIST_FILE = './saved_models_list.json'
# Tuning methods that are implemented through PEFT adapters.
PEFT_METHODS = ("LoRA", "AdaLoRA")
# Maximum token length used for both training and prediction.
MAX_SEQ_LENGTH = 128
# Fraction of the uploaded data held out for validation.
VALIDATION_FRACTION = 0.2


def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic kernels.

    Args:
        seed: Integer seed applied to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    try:
        torch.use_deterministic_algorithms(True)
    except (AttributeError, RuntimeError):
        # Best effort: older torch builds may not offer this switch.
        pass


set_seed(RANDOM_SEED)

# Select the compute device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# PEFT is optional; without it only full fine-tuning is offered.
try:
    from peft import (
        LoraConfig,
        AdaLoraConfig,
        get_peft_model,
        TaskType,
        PeftModel
    )
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False

# ==================== Session state initialisation ====================
if 'models_list' not in st.session_state:
    st.session_state.models_list = []
if 'training_in_progress' not in st.session_state:
    st.session_state.training_in_progress = False


# ==================== Model registry helpers ====================
def load_models_list():
    """Return the list of trained-model records, or [] if none were saved."""
    if os.path.exists(MODELS_LIST_FILE):
        with open(MODELS_LIST_FILE, 'r', encoding='utf-8') as f:
            return json.load(f)
    return []


def save_models_list(models_list):
    """Write the list of trained-model records to the JSON index file."""
    with open(MODELS_LIST_FILE, 'w', encoding='utf-8') as f:
        json.dump(models_list, f, indent=2)


def get_available_models():
    """Return one display string per trained model.

    Each string is ``"<path> | <training type> | <method> | <timestamp>"``;
    callers recover the path by splitting on ``" | "``.
    """
    return [
        f"{info['model_path']} | {info.get('training_type', '第一次微調')} | "
        f"{info['tuning_method']} | {info['timestamp']}"
        for info in load_models_list()
    ]


def get_first_finetuning_models():
    """Return the paths of models produced by first-stage fine-tuning."""
    return [
        info['model_path']
        for info in load_models_list()
        if not info.get('is_second_finetuning', False)
    ]


# ==================== Data handling ====================
def process_csv(file_path, label_column='label'):
    """Read an uploaded CSV, show a preview and return the DataFrame.

    Args:
        file_path: Path or file-like object accepted by ``pd.read_csv``.
        label_column: Kept for interface compatibility; not used here.

    Returns:
        The loaded DataFrame, or None if the file could not be read.
    """
    try:
        df = pd.read_csv(file_path)
        st.success(f"✅ 已載入 {len(df)} 筆數據")
        st.write("數據預覽:")
        st.dataframe(df.head(10))
        return df
    except Exception as e:
        st.error(f"❌ 無法讀取 CSV: {str(e)}")
        return None


def prepare_dataset(df, text_column='text', label_column='label'):
    """Build a ``datasets.Dataset`` with 'text' and 'label' columns.

    Args:
        df: Source DataFrame.
        text_column: Name of the column holding the input text.
        label_column: Name of the column holding the class label.

    Returns:
        The dataset, or None (after showing an error) if a required column
        is missing or conversion fails.
    """
    try:
        if text_column not in df.columns or label_column not in df.columns:
            raise ValueError(f"CSV 必須包含 '{text_column}' 和 '{label_column}' 列")
        data = {
            'text': df[text_column].tolist(),
            'label': df[label_column].tolist()
        }
        return Dataset.from_dict(data)
    except Exception as e:
        st.error(f"❌ 數據準備錯誤: {str(e)}")
        return None


# ==================== Training ====================
def _parse_target_modules(lora_modules):
    """Split a comma-separated module list, dropping blanks and whitespace."""
    return [name.strip() for name in lora_modules.split(",") if name.strip()]


def _load_model_for_training(tuning_method, base_model_path):
    """Load the model to train and report whether an adapter was restored.

    Returns:
        ``(model, adapter_restored)``. ``adapter_restored`` is True when a
        saved PEFT adapter was loaded from ``base_model_path``; in that case
        the caller must not wrap the model with a new adapter.
    """
    if base_model_path is None:
        fresh = BertForSequenceClassification.from_pretrained(
            BASE_CHECKPOINT,
            num_labels=2
        )
        return fresh, False

    if tuning_method in PEFT_METHODS and PEFT_AVAILABLE:
        backbone = BertForSequenceClassification.from_pretrained(
            BASE_CHECKPOINT,
            num_labels=2
        )
        # is_trainable=True keeps the restored adapter weights unfrozen so
        # second-stage training actually updates them.
        restored = PeftModel.from_pretrained(
            backbone, base_model_path, is_trainable=True
        )
        return restored, True

    # Fully fine-tuned checkpoints are continued from their saved weights.
    return BertForSequenceClassification.from_pretrained(base_model_path), False


def compute_metrics(eval_pred):
    """Compute F1, accuracy and ROC-AUC for the positive class (label 1).

    Args:
        eval_pred: ``(logits, labels)`` pair as supplied by ``Trainer``.

    Returns:
        Dict with float values under 'f1', 'accuracy' and 'auc'. 'auc' is
        0.0 when it is undefined (e.g. only one class in the labels).
    """
    logits, labels = eval_pred
    logits = np.asarray(logits)
    predictions = np.argmax(logits, axis=1)

    # Return order is (precision, recall, fscore, support).
    _, _, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='binary', zero_division=0
    )
    acc = accuracy_score(labels, predictions)

    # AUC needs a ranking score, so use the softmax probability of class 1.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    positive_prob = exp_scores[:, 1] / exp_scores.sum(axis=1)
    try:
        auc = roc_auc_score(labels, positive_prob)
    except ValueError:
        auc = 0.0
    if np.isnan(auc):
        auc = 0.0

    return {
        'f1': float(f1),
        'accuracy': float(acc),
        'auc': float(auc)
    }


def train_bert_model(
    train_dataset,
    tuning_method="full",
    epochs=3,
    batch_size=16,
    lr=2e-5,
    warmup_steps=200,
    best_metric='f1',
    weight_mult=0.8,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    lora_modules="query,value",
    adalora_init_r=12,
    adalora_target_r=8,
    adalora_tinit=0,
    adalora_tfinal=0,
    adalora_delta_t=1,
    is_second_tuning=False,
    base_model_path=None
):
    """Fine-tune a BERT classifier, save it and register it in the index.

    Args:
        train_dataset: Dataset with 'text' and 'label' columns; 20% is held
            out for validation.
        tuning_method: "full", "LoRA" or "AdaLoRA".
        epochs, batch_size, lr, warmup_steps: Trainer hyper-parameters.
        best_metric: Metric name used to pick the best checkpoint.
        weight_mult: Accepted for interface compatibility; currently unused.
        lora_r, lora_alpha, lora_dropout, lora_modules: LoRA settings
            (``lora_modules`` is a comma-separated list of module names).
        adalora_init_r, adalora_target_r, adalora_tinit, adalora_tfinal,
            adalora_delta_t: AdaLoRA settings.
        is_second_tuning: Recorded in the model index as the training type.
        base_model_path: Saved model to continue from, or None to start from
            the base checkpoint. When a saved adapter is restored, its own
            configuration is reused and the LoRA/AdaLoRA arguments above are
            ignored.

    Returns:
        On success, a dict with 'status' == 'success', 'model_path',
        'eval_results' and 'model_info'. On failure, a dict with
        'status' == 'error' and 'message'.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()
    trainer = None
    model = None

    try:
        # Load the model and tokenizer.
        status_text.text("🔄 載入模型...")
        progress_bar.progress(10)

        model, adapter_restored = _load_model_for_training(
            tuning_method, base_model_path
        )
        model = model.to(device)
        tokenizer = BertTokenizer.from_pretrained(BASE_CHECKPOINT)

        # Attach a new adapter only when none was restored from disk;
        # wrapping a restored adapter again would stack a second one.
        if not adapter_restored and PEFT_AVAILABLE:
            if tuning_method == "LoRA":
                status_text.text("⚙️ 配置 LoRA...")
                progress_bar.progress(15)

                lora_config = LoraConfig(
                    r=lora_r,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    bias="none",
                    target_modules=_parse_target_modules(lora_modules),
                    task_type=TaskType.SEQ_CLS,
                    inference_mode=False
                )
                model = get_peft_model(model, lora_config)

            elif tuning_method == "AdaLoRA":
                status_text.text("⚙️ 配置 AdaLoRA...")
                progress_bar.progress(15)

                # Estimated optimiser steps for single-device training
                # without gradient accumulation.
                # NOTE(review): recent PEFT releases appear to require a
                # positive total_step for AdaLoRA — confirm against the
                # installed version.
                n_total = len(train_dataset)
                n_train = n_total - int(np.ceil(VALIDATION_FRACTION * n_total))
                steps_per_epoch = max(1, -(-n_train // int(batch_size)))
                total_step = steps_per_epoch * int(epochs)

                adalora_config = AdaLoraConfig(
                    init_r=adalora_init_r,
                    target_r=adalora_target_r,
                    tinit=int(adalora_tinit),
                    tfinal=int(adalora_tfinal),
                    deltaT=int(adalora_delta_t),
                    total_step=total_step,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    bias="none",
                    target_modules=_parse_target_modules(lora_modules),
                    task_type=TaskType.SEQ_CLS,
                    inference_mode=False
                )
                model = get_peft_model(model, adalora_config)

        # Tokenize the data.
        status_text.text("📝 分詞化數據...")
        progress_bar.progress(20)

        def tokenize_function(examples):
            return tokenizer(
                examples['text'],
                padding="max_length",
                truncation=True,
                max_length=MAX_SEQ_LENGTH
            )

        tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
        tokenized_dataset = tokenized_dataset.remove_columns(['text'])
        tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')

        # Split into training and validation sets.
        split_dataset = tokenized_dataset.train_test_split(
            test_size=VALIDATION_FRACTION, seed=RANDOM_SEED
        )

        # Training arguments.
        status_text.text("🚀 開始訓練...")
        progress_bar.progress(30)

        training_args = TrainingArguments(
            output_dir='./results',
            num_train_epochs=int(epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size),
            warmup_steps=int(warmup_steps),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=100,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model=best_metric,
            learning_rate=lr,
            seed=RANDOM_SEED
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=split_dataset['train'],
            eval_dataset=split_dataset['test'],
            compute_metrics=compute_metrics,
            data_collator=DataCollatorWithPadding(tokenizer)
        )

        # Train.
        trainer.train()
        progress_bar.progress(80)

        # Evaluate.
        status_text.text("📊 評估模型...")
        eval_result = trainer.evaluate(split_dataset['test'])
        eval_metrics = {k: float(v) for k, v in eval_result.items()}
        progress_bar.progress(90)

        # Save the model (adapter weights for PEFT models) and tokenizer.
        status_text.text("💾 保存模型...")
        model_save_path = f"./bert_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        os.makedirs(model_save_path, exist_ok=True)
        model.save_pretrained(model_save_path)
        tokenizer.save_pretrained(model_save_path)

        # Record the model in the index.
        models_list = load_models_list()
        model_info = {
            'model_path': model_save_path,
            'tuning_method': tuning_method,
            'best_metric': best_metric,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'eval_results': eval_metrics,
            'training_type': '二次微調' if is_second_tuning else '第一次微調',
            'is_second_finetuning': is_second_tuning,
            'base_model': base_model_path
        }
        models_list.append(model_info)
        save_models_list(models_list)

        progress_bar.progress(100)
        status_text.text("✅ 訓練完成!")

        return {
            'status': 'success',
            'model_path': model_save_path,
            'eval_results': eval_metrics,
            'model_info': model_info
        }

    except Exception as e:
        status_text.text(f"❌ 訓練失敗: {str(e)}")
        st.error(traceback.format_exc())
        return {'status': 'error', 'message': str(e)}

    finally:
        # Drop references on both success and failure so memory is freed.
        trainer = None
        model = None
        gc.collect()
        torch.cuda.empty_cache()


# ==================== Prediction ====================
def predict(text, model_path):
    """Classify one text with a saved model.

    Args:
        text: Input text to classify.
        model_path: Directory of a model saved by ``train_bert_model``.

    Returns:
        On success, a dict with 'result' ("存活" for class 0, "死亡" for
        class 1), 'confidence', 'prob_survive', 'prob_death' and
        'model_info' (the index record, or None if not found). On failure,
        a dict with 'status' == 'error', 'message' and 'traceback'.
    """
    model = None
    try:
        tokenizer = BertTokenizer.from_pretrained(BASE_CHECKPOINT)

        # Look up how the model was trained to decide how to load it.
        model_info = next(
            (m for m in load_models_list() if m['model_path'] == model_path),
            None
        )
        tuning_method = model_info['tuning_method'] if model_info else 'full'

        if tuning_method in PEFT_METHODS and PEFT_AVAILABLE:
            base_model = BertForSequenceClassification.from_pretrained(
                BASE_CHECKPOINT, num_labels=2
            )
            model = PeftModel.from_pretrained(base_model, model_path)
        else:
            model = BertForSequenceClassification.from_pretrained(model_path)

        model = model.to(device)
        model.eval()

        # Run inference.
        inputs = tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=MAX_SEQ_LENGTH
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

        pred_class = probs.argmax(-1).item()
        confidence = probs[0][pred_class].item()

        return {
            'result': "存活" if pred_class == 0 else "死亡",
            'confidence': confidence,
            'prob_survive': probs[0][0].item(),
            'prob_death': probs[0][1].item(),
            'model_info': model_info
        }

    except Exception as e:
        return {
            'status': 'error',
            'message': str(e),
            'traceback': traceback.format_exc()
        }

    finally:
        model = None
        torch.cuda.empty_cache()


def _show_training_result(result, success_message):
    """Render the outcome dict returned by ``train_bert_model``."""
    if result['status'] == 'success':
        st.success(success_message)
        st.json(result['eval_results'])
    else:
        st.error(f"❌ 訓練失敗: {result['message']}")


# ==================== Streamlit UI ====================
st.title("🔬 BERT 乳癌生存預測系統")

# Sidebar: device information.
with st.sidebar:
    st.header("📋 系統信息")
    st.write(f"**設備**: {device}")
    st.write(f"**PEFT 可用**: {'✅' if PEFT_AVAILABLE else '❌'}")
    st.write(f"**CUDA 可用**: {'✅' if torch.cuda.is_available() else '❌'}")

    if torch.cuda.is_available():
        st.write(f"**GPU**: {torch.cuda.get_device_name(0)}")
        st.write(f"**CUDA 版本**: {torch.version.cuda}")

# Main UI: tabs.
tab1, tab2, tab3, tab4 = st.tabs(
    ["1️⃣ 第一次微調", "2️⃣ 二次微調", "3️⃣ 測試評估", "4️⃣ 模型預測"]
)

# ==================== Tab 1: first-stage fine-tuning ====================
with tab1:
    st.header("第一次微調")

    col1, col2 = st.columns([1, 2])

    with col1:
        st.subheader("📤 數據上傳")
        uploaded_file = st.file_uploader("上傳訓練數據 (CSV)", type=['csv'], key='first_train')

        # Each tab keeps its own DataFrame so one tab never trains on
        # another tab's upload.
        first_df = None
        if uploaded_file is not None:
            first_df = process_csv(uploaded_file)

        st.subheader("⚙️ 訓練參數")

        tuning_method = st.selectbox(
            "微調方法",
            ["full", "LoRA", "AdaLoRA"] if PEFT_AVAILABLE else ["full"],
            key='first_method'
        )

        best_metric = st.selectbox(
            "最佳化指標",
            ["f1", "accuracy", "auc"],
            key='first_metric'
        )

        col_train_a, col_train_b = st.columns(2)
        with col_train_a:
            epochs = st.number_input("訓練輪數", 1, 10, 3, key='first_epochs')
            batch_size = st.number_input("批次大小", 8, 64, 16, key='first_batch')

        with col_train_b:
            lr = st.number_input("學習率", 1e-6, 1e-3, 2e-5, format="%.2e", key='first_lr')
            warmup = st.number_input("Warmup Steps", 0, 1000, 200, key='first_warmup')

        weight_mult = st.slider("權重倍數", 0.1, 2.0, 0.8, 0.1, key='first_weight')

        # LoRA parameters.
        if tuning_method == "LoRA":
            st.write("### 🔷 LoRA 參數")
            col_lora_a, col_lora_b = st.columns(2)
            with col_lora_a:
                lora_r = st.slider("LoRA Rank", 4, 64, 16, 4, key='first_lora_r')
                lora_alpha = st.slider("LoRA Alpha", 8, 128, 32, 8, key='first_lora_alpha')
            with col_lora_b:
                lora_dropout = st.slider("LoRA Dropout", 0.0, 0.5, 0.1, 0.05, key='first_lora_dropout')
                lora_modules = st.text_input("目標模組", "query,value", key='first_lora_modules')
        else:
            lora_r = lora_alpha = 16
            lora_dropout = 0.1
            lora_modules = "query,value"

        # AdaLoRA parameters.
        if tuning_method == "AdaLoRA":
            st.write("### 🔶 AdaLoRA 參數")
            col_ada_a, col_ada_b = st.columns(2)
            with col_ada_a:
                adalora_init_r = st.slider("初始 Rank", 4, 64, 12, 4, key='first_ada_init_r')
                adalora_target_r = st.slider("目標 Rank", 4, 64, 8, 4, key='first_ada_target_r')
            with col_ada_b:
                adalora_tinit = st.number_input("Tinit", 0, 1000, 0, key='first_ada_tinit')
                adalora_tfinal = st.number_input("Tfinal", 0, 1000, 0, key='first_ada_tfinal')
                adalora_delta_t = st.number_input("Delta T", 1, 100, 1, key='first_ada_delta_t')
        else:
            adalora_init_r = adalora_target_r = 12
            adalora_tinit = adalora_tfinal = 0
            adalora_delta_t = 1

        if st.button("🚀 開始第一次微調", key='first_train_btn'):
            if first_df is None:
                st.warning("⚠️ 請先上傳訓練數據")
            else:
                st.session_state.training_in_progress = True
                try:
                    dataset = prepare_dataset(first_df)
                    if dataset is not None and len(dataset) > 0:
                        result = train_bert_model(
                            dataset,
                            tuning_method=tuning_method,
                            epochs=epochs,
                            batch_size=batch_size,
                            lr=lr,
                            warmup_steps=warmup,
                            best_metric=best_metric,
                            weight_mult=weight_mult,
                            lora_r=lora_r,
                            lora_alpha=lora_alpha,
                            lora_dropout=lora_dropout,
                            lora_modules=lora_modules,
                            adalora_init_r=adalora_init_r,
                            adalora_target_r=adalora_target_r,
                            adalora_tinit=adalora_tinit,
                            adalora_tfinal=adalora_tfinal,
                            adalora_delta_t=adalora_delta_t,
                            is_second_tuning=False
                        )
                        _show_training_result(result, "✅ 訓練完成!")
                    elif dataset is not None:
                        st.warning("⚠️ 請先上傳訓練數據")
                finally:
                    st.session_state.training_in_progress = False

# ==================== Tab 2: second-stage fine-tuning ====================
with tab2:
    st.header("二次微調")

    col1, col2 = st.columns([1, 2])

    with col1:
        st.subheader("🔄 選擇基礎模型")
        first_models = get_first_finetuning_models()

        base_model = None
        if not first_models:
            st.warning("⚠️ 請先進行第一次微調")
        else:
            base_model = st.selectbox("選擇第一次微調模型", first_models, key='second_base')

        st.subheader("📤 數據上傳")
        uploaded_file = st.file_uploader("上傳新訓練數據 (CSV)", type=['csv'], key='second_train')

        second_df = None
        if uploaded_file is not None:
            second_df = process_csv(uploaded_file)

        st.subheader("⚙️ 訓練參數")

        best_metric = st.selectbox(
            "最佳化指標",
            ["f1", "accuracy", "auc"],
            key='second_metric'
        )

        col_train_a, col_train_b = st.columns(2)
        with col_train_a:
            epochs = st.number_input("訓練輪數", 1, 10, 2, key='second_epochs')
            batch_size = st.number_input("批次大小", 8, 64, 16, key='second_batch')

        with col_train_b:
            lr = st.number_input("學習率", 1e-6, 1e-3, 1e-5, format="%.2e", key='second_lr')
            warmup = st.number_input("Warmup Steps", 0, 1000, 100, key='second_warmup')

        weight_mult = st.slider("權重倍數", 0.1, 2.0, 0.8, 0.1, key='second_weight')

        if st.button("🚀 開始二次微調", key='second_train_btn'):
            if base_model is None:
                st.warning("⚠️ 請先進行第一次微調")
            elif second_df is None:
                st.warning("⚠️ 請先上傳訓練數據")
            else:
                st.session_state.training_in_progress = True
                try:
                    dataset = prepare_dataset(second_df)
                    if dataset is not None and len(dataset) > 0:
                        # Reuse the tuning method of the chosen base model.
                        tuning_method = next(
                            (
                                m['tuning_method']
                                for m in load_models_list()
                                if m['model_path'] == base_model
                            ),
                            "full"
                        )

                        result = train_bert_model(
                            dataset,
                            tuning_method=tuning_method,
                            epochs=epochs,
                            batch_size=batch_size,
                            lr=lr,
                            warmup_steps=warmup,
                            best_metric=best_metric,
                            weight_mult=weight_mult,
                            is_second_tuning=True,
                            base_model_path=base_model
                        )
                        _show_training_result(result, "✅ 二次微調完成!")
                    elif dataset is not None:
                        st.warning("⚠️ 請先上傳訓練數據")
                finally:
                    st.session_state.training_in_progress = False

# ==================== Tab 3: test evaluation ====================
with tab3:
    st.header("測試評估")

    col1, col2 = st.columns([1, 2])

    with col1:
        st.subheader("📤 上傳測試數據")
        uploaded_file = st.file_uploader("上傳測試數據 (CSV)", type=['csv'], key='test_data')

        test_df = None
        if uploaded_file is not None:
            test_df = process_csv(uploaded_file)

        st.subheader("🎯 選擇模型")
        available_models = get_available_models()

        if not available_models:
            st.warning("⚠️ 無可用模型")
        else:
            selected_test_model = st.selectbox("選擇模型", available_models, key='test_model')

        if st.button("▶️ 開始評估", key='test_btn'):
            # Evaluation is not implemented in this simplified version.
            st.info("⏳ 評估中... (此功能在簡化版本中暫未實現)")

# ==================== Tab 4: prediction ====================
with tab4:
    st.header("模型預測")

    col1, col2 = st.columns([1, 2])

    with col1:
        st.subheader("🎯 選擇模型")
        available_models = get_available_models()

        selected_model = None
        if not available_models:
            st.warning("⚠️ 無可用模型")
        else:
            selected_model = st.selectbox("選擇預測模型", available_models, key='pred_model')

        st.subheader("📝 輸入文本")
        text_input = st.text_area(
            "輸入病歷或文本",
            placeholder="例如: The patient shows signs of ...",
            height=150,
            key='pred_text'
        )

        if st.button("🔮 進行預測", key='pred_btn'):
            if selected_model is None:
                st.warning("⚠️ 無可用模型")
            elif text_input.strip():
                # The model path is the first field of the display string.
                model_path = selected_model.split(" | ")[0]

                with st.spinner("⏳ 預測中..."):
                    result = predict(text_input, model_path)

                if result.get('status') != 'error':
                    col_result_a, col_result_b = st.columns(2)

                    with col_result_a:
                        st.markdown("### 🟢 預測結果")
                        st.markdown(f"## **{result['result']}**")
                        st.metric("信心度", f"{result['confidence']:.1%}")

                    with col_result_b:
                        st.markdown("### 📊 機率分布")
                        st.metric("存活機率", f"{result['prob_survive']:.2%}")
                        st.metric("死亡機率", f"{result['prob_death']:.2%}")

                    st.markdown("---")
                    st.markdown("### 📋 模型信息")
                    if result['model_info']:
                        st.write(f"**訓練類型**: {result['model_info']['training_type']}")
                        st.write(f"**微調方法**: {result['model_info']['tuning_method']}")
                        st.write(f"**訓練時間**: {result['model_info']['timestamp']}")
                else:
                    st.error(f"❌ 預測失敗: {result['message']}")
            else:
                st.warning("⚠️ 請輸入文本")

st.markdown("---")
st.markdown("**注意**: 此預測系統僅供參考,實際醫療決策應由專業醫師判斷。")