Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import pandas as pd | |
| import torch | |
| from torch import nn | |
| from transformers import ( | |
| BertTokenizer, | |
| BertForSequenceClassification, | |
| TrainingArguments, | |
| Trainer, | |
| DataCollatorWithPadding | |
| ) | |
| from datasets import Dataset | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| precision_recall_fscore_support, | |
| roc_auc_score, | |
| confusion_matrix | |
| ) | |
| import numpy as np | |
| from datetime import datetime | |
| import json | |
| import os | |
| import gc | |
| import random | |
| import traceback | |
# ==================== Page configuration ====================
# Collected in one mapping so the Streamlit page options are visible at a glance.
_page_settings = {
    "page_title": "BERT 乳癌預測系統",
    "page_icon": "🔬",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_page_settings)
# Random seed configuration
RANDOM_SEED = 42


def set_seed(seed=42):
    """Seed the RNGs used by the app and request deterministic kernels.

    Seeds ``random``, NumPy and PyTorch (CPU and all CUDA devices), switches
    cuDNN to deterministic mode with benchmarking off, and exports the
    ``PYTHONHASHSEED`` / ``CUBLAS_WORKSPACE_CONFIG`` environment variables.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    try:
        torch.use_deterministic_algorithms(True)
    except (AttributeError, RuntimeError):
        # Best effort only: keep running when the installed torch build does
        # not offer (or refuses) the global deterministic switch. Narrowed
        # from a bare ``except`` so interrupts and unrelated bugs surface.
        pass


set_seed(RANDOM_SEED)
# Pick the compute device: CUDA when PyTorch reports it available, else CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
# Optional PEFT support: the flag records whether the import succeeded so the
# rest of the app can hide or skip LoRA / AdaLoRA features when it did not.
try:
    from peft import (
        LoraConfig,
        AdaLoraConfig,
        get_peft_model,
        TaskType,
        PeftModel
    )
except ImportError:
    PEFT_AVAILABLE = False
else:
    PEFT_AVAILABLE = True
# ==================== Session state initialisation ====================
# Only keys that are missing are filled in, so values survive Streamlit reruns.
for _state_key, _state_default in (
    ('models_list', []),
    ('training_in_progress', False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
| # ==================== 模型管理函數 ==================== | |
def load_models_list():
    """Load the registry of trained models from ``./saved_models_list.json``.

    Returns:
        list: The stored registry entries. An empty list is returned when the
        file is missing, cannot be read or parsed, or does not hold a JSON
        list, so a damaged registry does not crash every page render.
    """
    models_list_file = './saved_models_list.json'
    try:
        with open(models_list_file, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file; ValueError covers
        # json.JSONDecodeError and UnicodeDecodeError.
        return []
    return loaded if isinstance(loaded, list) else []
def save_models_list(models_list):
    """Persist the model registry to ``./saved_models_list.json``.

    The JSON is written to a sibling temporary file and then moved over the
    real file with ``os.replace``, so a failure while serialising or writing
    leaves the previous registry untouched instead of truncated.

    Args:
        models_list: JSON-serialisable list of model-info dictionaries.

    Raises:
        TypeError: If ``models_list`` contains values ``json`` cannot encode.
        OSError: If the file cannot be written or replaced.
    """
    target_path = './saved_models_list.json'
    tmp_path = target_path + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(models_list, f, indent=2)
        os.replace(tmp_path, target_path)
    finally:
        # After a successful replace the temp file no longer exists; this only
        # cleans up the leftover from a failed write.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def get_available_models():
    """Return one display label per registered model.

    Each label has the form
    ``"<model_path> | <training_type> | <tuning_method> | <timestamp>"``;
    entries without a ``training_type`` fall back to the first-stage label.
    An empty registry yields an empty list.
    """

    def _label(entry):
        kind = entry.get('training_type', '第一次微調')
        return f"{entry['model_path']} | {kind} | {entry['tuning_method']} | {entry['timestamp']}"

    return [_label(entry) for entry in load_models_list()]
def get_first_finetuning_models():
    """Return the saved paths of models not flagged as second-stage fine-tunes."""
    first_stage = []
    for entry in load_models_list():
        # A missing flag counts as first-stage.
        if not entry.get('is_second_finetuning', False):
            first_stage.append(entry)
    return [entry['model_path'] for entry in first_stage]
| # ==================== 數據處理函數 ==================== | |
def process_csv(file_path, label_column='label'):
    """Read an uploaded CSV into a DataFrame and show a ten-row preview.

    Args:
        file_path: Path or file-like object handed to ``pandas.read_csv``.
        label_column: Accepted for interface compatibility; not used here.

    Returns:
        The loaded DataFrame, or ``None`` after reporting the error in the UI
        when reading or previewing fails.
    """
    try:
        frame = pd.read_csv(file_path)
        row_count = len(frame)
        st.success(f"✅ 已載入 {row_count} 筆數據")
        st.write("數據預覽:")
        st.dataframe(frame.head(10))
    except Exception as err:
        st.error(f"❌ 無法讀取 CSV: {str(err)}")
        return None
    else:
        return frame
def prepare_dataset(df, text_column='text', label_column='label'):
    """Build a ``datasets.Dataset`` with ``text`` / ``label`` columns from a DataFrame.

    Rows with a missing text or label are dropped, text is coerced to ``str``
    and labels to ``int``. Without this, an empty CSV cell reaches the
    tokenizer as a float NaN, and a label column containing NaN is read by
    pandas as floats rather than integer class ids.

    Args:
        df: DataFrame holding the training rows (``None`` is reported as an error).
        text_column: Name of the column holding the input text.
        label_column: Name of the column holding the class label.

    Returns:
        The dataset, or ``None`` after showing the error in the UI.
    """
    try:
        if df is None:
            raise ValueError("請先上傳 CSV 數據")
        if text_column not in df.columns or label_column not in df.columns:
            raise ValueError(f"CSV 必須包含 '{text_column}' 和 '{label_column}' 列")
        usable_rows = df[[text_column, label_column]].dropna()
        if usable_rows.empty:
            raise ValueError("CSV 中沒有可用的數據列")
        data = {
            'text': usable_rows[text_column].astype(str).tolist(),
            # Non-numeric labels raise here and are reported below instead of
            # failing later in the middle of training.
            'label': usable_rows[label_column].astype(int).tolist()
        }
        return Dataset.from_dict(data)
    except Exception as e:
        st.error(f"❌ 數據準備錯誤: {str(e)}")
        return None
| # ==================== 訓練函數 ==================== | |
def _split_target_modules(lora_modules):
    """Turn a comma-separated module string into a list of trimmed, non-empty names."""
    return [name.strip() for name in lora_modules.split(",") if name.strip()]


def _classification_metrics(eval_pred):
    """Compute F1, accuracy and ROC-AUC from the Trainer's (logits, labels) pair."""
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        # Some model wrappers return extra outputs; the logits come first.
        logits = logits[0]
    logits = np.asarray(logits)
    predictions = np.argmax(logits, axis=1)
    # The sklearn helper returns (precision, recall, f-score, support); the
    # F1 value is the third element, not the first.
    _, _, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='binary', zero_division=0
    )
    acc = accuracy_score(labels, predictions)
    try:
        # Rank by the class-1 margin so AUC reflects the model's scores rather
        # than the thresholded 0/1 predictions (monotonic in the softmax
        # probability of class 1).
        auc = roc_auc_score(labels, logits[:, 1] - logits[:, 0])
    except ValueError:
        # Raised e.g. when the evaluation split contains a single class.
        auc = 0.0
    if np.isnan(auc):
        auc = 0.0
    return {
        'f1': float(f1),
        'accuracy': float(acc),
        'auc': float(auc)
    }


def train_bert_model(
    train_dataset,
    tuning_method="full",
    epochs=3,
    batch_size=16,
    lr=2e-5,
    warmup_steps=200,
    best_metric='f1',
    weight_mult=0.8,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    lora_modules="query,value",
    adalora_init_r=12,
    adalora_target_r=8,
    adalora_tinit=0,
    adalora_tfinal=0,
    adalora_delta_t=1,
    is_second_tuning=False,
    base_model_path=None
):
    """Fine-tune a BERT sequence classifier and register the saved model.

    The dataset is tokenised (max length 128), split 80/20 into train and
    evaluation parts, trained with the Hugging Face ``Trainer``, evaluated,
    saved under ``./bert_model_<timestamp>`` and appended to the registry.
    Progress and errors are reported through Streamlit widgets.

    Args:
        train_dataset: Dataset with ``text`` and ``label`` columns.
        tuning_method: ``"full"``, ``"LoRA"`` or ``"AdaLoRA"``. The PEFT
            methods fall back to full fine-tuning when PEFT is not installed.
        epochs, batch_size, lr, warmup_steps: Trainer hyper-parameters.
        best_metric: Metric name used to pick the best checkpoint.
        weight_mult: Accepted for interface compatibility; this function does
            not currently use it.
        lora_r, lora_alpha, lora_dropout, lora_modules: LoRA settings
            (``lora_modules`` is a comma-separated list of target modules).
        adalora_init_r, adalora_target_r, adalora_tinit, adalora_tfinal,
            adalora_delta_t: AdaLoRA settings.
        is_second_tuning: Recorded in the registry entry.
        base_model_path: Previously saved model to continue from; ``None``
            starts from ``bert-base-uncased``.

    Returns:
        dict: ``{'status': 'success', 'model_path', 'eval_results',
        'model_info'}`` on success, or ``{'status': 'error', 'message'}``.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()
    model = None
    trainer = None
    try:
        # Load the base model
        status_text.text("🔄 載入模型...")
        progress_bar.progress(10)
        peft_requested = tuning_method in ("LoRA", "AdaLoRA")
        use_peft = peft_requested and PEFT_AVAILABLE
        # Record what was actually trained: without PEFT the run is a full
        # fine-tune, and the saved folder must be loaded as such later.
        recorded_method = "full" if (peft_requested and not use_peft) else tuning_method

        adapter_resumed = False
        if base_model_path is None:
            model = BertForSequenceClassification.from_pretrained(
                "bert-base-uncased",
                num_labels=2
            )
        elif use_peft:
            backbone = BertForSequenceClassification.from_pretrained(
                "bert-base-uncased", num_labels=2
            )
            # is_trainable=True keeps the loaded adapter weights trainable;
            # the default loads them frozen for inference.
            model = PeftModel.from_pretrained(backbone, base_model_path, is_trainable=True)
            adapter_resumed = True
        else:
            # Continue from the saved checkpoint, including for "full" tuning.
            model = BertForSequenceClassification.from_pretrained(base_model_path)
        model = model.to(device)
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # Apply PEFT (only for a fresh adapter; a resumed adapter is trained
        # further instead of having a second adapter stacked on top of it)
        if use_peft and not adapter_resumed:
            target_modules = _split_target_modules(lora_modules)
            if tuning_method == "LoRA":
                status_text.text("⚙️ 配置 LoRA...")
                progress_bar.progress(15)
                peft_config = LoraConfig(
                    r=lora_r,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    bias="none",
                    target_modules=target_modules,
                    task_type=TaskType.SEQ_CLS,
                    inference_mode=False
                )
            else:
                status_text.text("⚙️ 配置 AdaLoRA...")
                progress_bar.progress(15)
                # NOTE(review): some PEFT releases may require extra AdaLoRA
                # arguments such as total_step - confirm against the installed version.
                peft_config = AdaLoraConfig(
                    init_r=adalora_init_r,
                    target_r=adalora_target_r,
                    tinit=int(adalora_tinit),
                    tfinal=int(adalora_tfinal),
                    deltaT=int(adalora_delta_t),
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    bias="none",
                    target_modules=target_modules,
                    task_type=TaskType.SEQ_CLS,
                    inference_mode=False
                )
            model = get_peft_model(model, peft_config)

        # Tokenise the data
        status_text.text("📝 分詞化數據...")
        progress_bar.progress(20)

        def tokenize_function(examples):
            return tokenizer(
                examples['text'],
                padding="max_length",
                truncation=True,
                max_length=128
            )

        tokenized_dataset = train_dataset.map(tokenize_function, batched=True)
        tokenized_dataset = tokenized_dataset.remove_columns(['text'])
        tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')

        # Split into training / validation sets
        split_dataset = tokenized_dataset.train_test_split(test_size=0.2, seed=RANDOM_SEED)

        # Training arguments
        status_text.text("🚀 開始訓練...")
        progress_bar.progress(30)
        training_args = TrainingArguments(
            output_dir='./results',
            num_train_epochs=int(epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size),
            warmup_steps=int(warmup_steps),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=100,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model=best_metric,
            learning_rate=lr,
            seed=RANDOM_SEED
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=split_dataset['train'],
            eval_dataset=split_dataset['test'],
            compute_metrics=_classification_metrics,
            data_collator=DataCollatorWithPadding(tokenizer)
        )

        # Train
        trainer.train()
        progress_bar.progress(80)

        # Evaluate
        status_text.text("📊 評估模型...")
        eval_result = trainer.evaluate(split_dataset['test'])
        progress_bar.progress(90)

        # Save the model (for PEFT models this call is made on the PEFT wrapper)
        status_text.text("💾 保存模型...")
        finished_at = datetime.now()
        model_save_path = f"./bert_model_{finished_at.strftime('%Y%m%d_%H%M%S')}"
        os.makedirs(model_save_path, exist_ok=True)
        model.save_pretrained(model_save_path)
        tokenizer.save_pretrained(model_save_path)

        # Record the model in the registry
        models_list = load_models_list()
        model_info = {
            'model_path': model_save_path,
            'tuning_method': recorded_method,
            'best_metric': best_metric,
            'timestamp': finished_at.strftime('%Y-%m-%d %H:%M:%S'),
            'eval_results': {k: float(v) for k, v in eval_result.items()},
            'training_type': '二次微調' if is_second_tuning else '第一次微調',
            'is_second_finetuning': is_second_tuning,
            'base_model': base_model_path
        }
        models_list.append(model_info)
        save_models_list(models_list)
        progress_bar.progress(100)
        status_text.text("✅ 訓練完成!")

        return {
            'status': 'success',
            'model_path': model_save_path,
            'eval_results': eval_result,
            'model_info': model_info
        }
    except Exception as e:
        status_text.text(f"❌ 訓練失敗: {str(e)}")
        st.error(traceback.format_exc())
        return {'status': 'error', 'message': str(e)}
    finally:
        # Drop the large objects on success and on failure alike before
        # asking the collector and the CUDA cache to release memory.
        trainer = None
        model = None
        gc.collect()
        torch.cuda.empty_cache()
| # ==================== 預測函數 ==================== | |
def predict(text, model_path):
    """Classify one text with a saved model.

    The registry entry for ``model_path`` decides how the model is loaded:
    LoRA / AdaLoRA entries are loaded as an adapter on top of
    ``bert-base-uncased`` when PEFT is available, everything else as a full
    checkpoint.

    Args:
        text: Input text to classify.
        model_path: Folder of the saved model, as stored in the registry.

    Returns:
        dict: On success ``status`` (``'success'``), ``result`` (label string
        for class 0 or class 1), ``confidence``, ``prob_survive``,
        ``prob_death`` and ``model_info`` (registry entry or ``None``).
        On failure ``status`` (``'error'``), ``message`` and ``traceback``.
    """
    model = None
    try:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # Find the registry entry for this path (None when it is not registered)
        model_info = next(
            (m for m in load_models_list() if m['model_path'] == model_path),
            None
        )
        tuning_method = model_info['tuning_method'] if model_info else 'full'

        # Load the model
        if tuning_method in ["LoRA", "AdaLoRA"] and PEFT_AVAILABLE:
            base_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
            model = PeftModel.from_pretrained(base_model, model_path)
        else:
            model = BertForSequenceClassification.from_pretrained(model_path)
        model = model.to(device)
        model.eval()

        # Predict
        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        pred_class = probs.argmax(-1).item()

        return {
            # Same status convention as the error result below, so callers
            # can branch on a single key.
            'status': 'success',
            'result': "存活" if pred_class == 0 else "死亡",
            'confidence': probs[0][pred_class].item(),
            'prob_survive': probs[0][0].item(),
            'prob_death': probs[0][1].item(),
            'model_info': model_info
        }
    except Exception as e:
        return {
            'status': 'error',
            'message': str(e),
            'traceback': traceback.format_exc()
        }
    finally:
        # Release the model on success and on failure alike.
        model = None
        torch.cuda.empty_cache()
# ==================== Streamlit UI ====================
st.title("🔬 BERT 乳癌生存預測系統")

# Sidebar: device information
with st.sidebar:
    st.header("📋 系統信息")
    _cuda_ready = torch.cuda.is_available()
    _sidebar_lines = [
        f"**設備**: {device}",
        f"**PEFT 可用**: {'✅' if PEFT_AVAILABLE else '❌'}",
        f"**CUDA 可用**: {'✅' if _cuda_ready else '❌'}",
    ]
    if _cuda_ready:
        # GPU details are only queried when CUDA is reported available.
        _sidebar_lines.append(f"**GPU**: {torch.cuda.get_device_name(0)}")
        _sidebar_lines.append(f"**CUDA 版本**: {torch.version.cuda}")
    for _sidebar_line in _sidebar_lines:
        st.write(_sidebar_line)

# Main UI: one tab per workflow step
_tab_titles = ["1️⃣ 第一次微調", "2️⃣ 二次微調", "3️⃣ 測試評估", "4️⃣ 模型預測"]
tab1, tab2, tab3, tab4 = st.tabs(_tab_titles)
# ==================== Tab 1: first fine-tuning ====================
with tab1:
    st.header("第一次微調")
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("📤 數據上傳")
        # Start from None on every rerun so a click without a successfully
        # parsed upload is reported instead of raising NameError.
        df = None
        uploaded_file = st.file_uploader("上傳訓練數據 (CSV)", type=['csv'], key='first_train')
        if uploaded_file:
            df = process_csv(uploaded_file)

        st.subheader("⚙️ 訓練參數")
        tuning_method = st.selectbox(
            "微調方法",
            ["full", "LoRA", "AdaLoRA"] if PEFT_AVAILABLE else ["full"],
            key='first_method'
        )
        best_metric = st.selectbox(
            "最佳化指標",
            ["f1", "accuracy", "auc"],
            key='first_metric'
        )
        col_train_a, col_train_b = st.columns(2)
        with col_train_a:
            epochs = st.number_input("訓練輪數", 1, 10, 3, key='first_epochs')
            batch_size = st.number_input("批次大小", 8, 64, 16, key='first_batch')
        with col_train_b:
            lr = st.number_input("學習率", 1e-6, 1e-3, 2e-5, format="%.2e", key='first_lr')
            warmup = st.number_input("Warmup Steps", 0, 1000, 200, key='first_warmup')
        weight_mult = st.slider("權重倍數", 0.1, 2.0, 0.8, 0.1, key='first_weight')

        # LoRA parameters (widgets only shown for LoRA; fixed values otherwise)
        if tuning_method == "LoRA":
            st.write("### 🔷 LoRA 參數")
            col_lora_a, col_lora_b = st.columns(2)
            with col_lora_a:
                lora_r = st.slider("LoRA Rank", 4, 64, 16, 4, key='first_lora_r')
                lora_alpha = st.slider("LoRA Alpha", 8, 128, 32, 8, key='first_lora_alpha')
            with col_lora_b:
                lora_dropout = st.slider("LoRA Dropout", 0.0, 0.5, 0.1, 0.05, key='first_lora_dropout')
                lora_modules = st.text_input("目標模組", "query,value", key='first_lora_modules')
        else:
            lora_r = lora_alpha = 16
            lora_dropout = 0.1
            lora_modules = "query,value"

        # AdaLoRA parameters (widgets only shown for AdaLoRA; fixed values otherwise)
        if tuning_method == "AdaLoRA":
            st.write("### 🔶 AdaLoRA 參數")
            col_ada_a, col_ada_b = st.columns(2)
            with col_ada_a:
                adalora_init_r = st.slider("初始 Rank", 4, 64, 12, 4, key='first_ada_init_r')
                adalora_target_r = st.slider("目標 Rank", 4, 64, 8, 4, key='first_ada_target_r')
            with col_ada_b:
                adalora_tinit = st.number_input("Tinit", 0, 1000, 0, key='first_ada_tinit')
                adalora_tfinal = st.number_input("Tfinal", 0, 1000, 0, key='first_ada_tfinal')
                adalora_delta_t = st.number_input("Delta T", 1, 100, 1, key='first_ada_delta_t')
        else:
            adalora_init_r = adalora_target_r = 12
            adalora_tinit = adalora_tfinal = 0
            adalora_delta_t = 1

        if st.button("🚀 開始第一次微調", key='first_train_btn'):
            if df is None:
                st.warning("⚠️ 請先上傳訓練數據 (CSV)")
            else:
                st.session_state.training_in_progress = True
                try:
                    dataset = prepare_dataset(df)
                    # Compare against None: truthiness of a dataset object
                    # would test its length rather than whether it was built.
                    if dataset is not None:
                        result = train_bert_model(
                            dataset,
                            tuning_method=tuning_method,
                            epochs=epochs,
                            batch_size=batch_size,
                            lr=lr,
                            warmup_steps=warmup,
                            best_metric=best_metric,
                            weight_mult=weight_mult,
                            lora_r=lora_r,
                            lora_alpha=lora_alpha,
                            lora_dropout=lora_dropout,
                            lora_modules=lora_modules,
                            adalora_init_r=adalora_init_r,
                            adalora_target_r=adalora_target_r,
                            adalora_tinit=adalora_tinit,
                            adalora_tfinal=adalora_tfinal,
                            adalora_delta_t=adalora_delta_t,
                            is_second_tuning=False
                        )
                        if result['status'] == 'success':
                            st.success("✅ 訓練完成!")
                            st.json(result['eval_results'])
                        else:
                            st.error(f"❌ 訓練失敗: {result['message']}")
                finally:
                    # Always clear the flag, even if something above raised.
                    st.session_state.training_in_progress = False
# ==================== Tab 2: second fine-tuning ====================
with tab2:
    st.header("二次微調")
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("🔄 選擇基礎模型")
        first_models = get_first_finetuning_models()
        # Stays None when no first-stage model exists, so the button handler
        # can report it instead of hitting an undefined name.
        base_model = None
        if len(first_models) == 0:
            st.warning("⚠️ 請先進行第一次微調")
        else:
            base_model = st.selectbox("選擇第一次微調模型", first_models, key='second_base')

        st.subheader("📤 數據上傳")
        # Reset per rerun so data uploaded in another tab is never reused here.
        df = None
        uploaded_file = st.file_uploader("上傳新訓練數據 (CSV)", type=['csv'], key='second_train')
        if uploaded_file:
            df = process_csv(uploaded_file)

        st.subheader("⚙️ 訓練參數")
        best_metric = st.selectbox(
            "最佳化指標",
            ["f1", "accuracy", "auc"],
            key='second_metric'
        )
        col_train_a, col_train_b = st.columns(2)
        with col_train_a:
            epochs = st.number_input("訓練輪數", 1, 10, 2, key='second_epochs')
            batch_size = st.number_input("批次大小", 8, 64, 16, key='second_batch')
        with col_train_b:
            lr = st.number_input("學習率", 1e-6, 1e-3, 1e-5, format="%.2e", key='second_lr')
            warmup = st.number_input("Warmup Steps", 0, 1000, 100, key='second_warmup')
        weight_mult = st.slider("權重倍數", 0.1, 2.0, 0.8, 0.1, key='second_weight')

        if st.button("🚀 開始二次微調", key='second_train_btn'):
            if base_model is None:
                st.warning("⚠️ 請先進行第一次微調")
            elif df is None:
                st.warning("⚠️ 請先上傳新訓練數據 (CSV)")
            else:
                st.session_state.training_in_progress = True
                try:
                    dataset = prepare_dataset(df)
                    if dataset is not None:
                        # Reuse the tuning method recorded for the base model
                        # ("full" when the path is not in the registry).
                        tuning_method = next(
                            (
                                m['tuning_method']
                                for m in load_models_list()
                                if m['model_path'] == base_model
                            ),
                            "full"
                        )
                        result = train_bert_model(
                            dataset,
                            tuning_method=tuning_method,
                            epochs=epochs,
                            batch_size=batch_size,
                            lr=lr,
                            warmup_steps=warmup,
                            best_metric=best_metric,
                            weight_mult=weight_mult,
                            is_second_tuning=True,
                            base_model_path=base_model
                        )
                        if result['status'] == 'success':
                            st.success("✅ 二次微調完成!")
                            st.json(result['eval_results'])
                        else:
                            st.error(f"❌ 訓練失敗: {result['message']}")
                finally:
                    # Always clear the flag, even if something above raised.
                    st.session_state.training_in_progress = False
# ==================== Tab 3: test evaluation ====================
with tab3:
    st.header("測試評估")
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("📤 上傳測試數據")
        uploaded_file = st.file_uploader("上傳測試數據 (CSV)", type=['csv'], key='test_data')
        if uploaded_file:
            df = process_csv(uploaded_file)

        st.subheader("🎯 選擇模型")
        available_models = get_available_models()
        if available_models:
            selected_model = st.selectbox("選擇模型", available_models, key='test_model')
        else:
            st.warning("⚠️ 無可用模型")

        # Evaluation itself is a placeholder: the button only shows a notice.
        if st.button("▶️ 開始評估", key='test_btn'):
            st.info("⏳ 評估中... (此功能在簡化版本中暫未實現)")
# ==================== Tab 4: model prediction ====================
with tab4:
    st.header("模型預測")
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("🎯 選擇模型")
        available_models = get_available_models()
        # Stays None when the registry is empty, so clicking the button does
        # not hit an undefined name.
        selected_model = None
        if len(available_models) == 0:
            st.warning("⚠️ 無可用模型")
        else:
            selected_model = st.selectbox("選擇預測模型", available_models, key='pred_model')

        st.subheader("📝 輸入文本")
        text_input = st.text_area(
            "輸入病歷或文本",
            placeholder="例如: The patient shows signs of ...",
            height=150,
            key='pred_text'
        )

        if st.button("🔮 進行預測", key='pred_btn'):
            if selected_model is None:
                st.warning("⚠️ 無可用模型")
            elif text_input.strip():
                # The model path is the first " | "-separated field of the label
                model_path = selected_model.split(" | ")[0]
                with st.spinner("⏳ 預測中..."):
                    result = predict(text_input, model_path)
                if 'status' not in result or result['status'] != 'error':
                    col_result_a, col_result_b = st.columns(2)
                    with col_result_a:
                        st.markdown("### 🟢 預測結果")
                        st.markdown(f"## **{result['result']}**")
                        st.metric("信心度", f"{result['confidence']:.1%}")
                    with col_result_b:
                        st.markdown("### 📊 機率分布")
                        st.metric("存活機率", f"{result['prob_survive']:.2%}")
                        st.metric("死亡機率", f"{result['prob_death']:.2%}")
                    st.markdown("---")
                    st.markdown("### 📋 模型信息")
                    if result['model_info']:
                        st.write(f"**訓練類型**: {result['model_info']['training_type']}")
                        st.write(f"**微調方法**: {result['model_info']['tuning_method']}")
                        st.write(f"**訓練時間**: {result['model_info']['timestamp']}")
                else:
                    st.error(f"❌ 預測失敗: {result['message']}")
            else:
                st.warning("⚠️ 請輸入文本")

# Page footer: disclaimer shown below the tabs
st.markdown("---")
st.markdown("**注意**: 此預測系統僅供參考,實際醫療決策應由專業醫師判斷。")