# Streamlit app exported from a Hugging Face Space (page-status lines removed).
| import streamlit as st | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.model_selection import train_test_split, cross_val_score, KFold, StratifiedKFold | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.tree import DecisionTreeClassifier | |
| from imblearn.over_sampling import SMOTE | |
| from sklearn.utils import resample | |
| from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, roc_curve | |
| from sklearn.preprocessing import label_binarize | |
| from sklearn.multiclass import OneVsRestClassifier | |
| import numpy as np | |
| # Streamlit interface for file upload | |
| st.title("資料平衡與交叉驗證") | |
| st.write("選擇 CSV 資料集,資料平衡方法與模型後,計算 k-Fold 和 Stratified k-Fold Cross-Validation 的分數,並繪製混淆矩陣與 ROC AUC 曲線") | |
| # 上傳 CSV 檔案 | |
| uploaded_file = st.file_uploader("選擇 CSV 檔案", type=["csv"]) | |
| if uploaded_file is not None: | |
| try: | |
| # 載入資料 | |
| data = pd.read_csv(uploaded_file) | |
| # 顯示初始的缺失值統計 | |
| missing_values = data.isnull().sum() | |
| if missing_values.any(): | |
| st.warning("檢測到缺失值,以下是各欄位的缺失值數量:") | |
| st.write(missing_values[missing_values > 0]) | |
| # 檢查數值型欄位 | |
| numeric_columns = data.select_dtypes(include=[np.number]).columns | |
| # 對數值型欄位進行缺失值處理 | |
| for col in numeric_columns: | |
| if data[col].isnull().any(): | |
| # 使用中位數填補缺失值,因為平均值容易受極端值影響 | |
| median_value = data[col].median() | |
| data[col] = data[col].fillna(median_value) | |
| st.info(f"欄位 '{col}' 的缺失值已使用中位數 {median_value:.2f} 填補") | |
| # 檢查是否還有缺失值 | |
| if data.isnull().any().any(): | |
| st.error("資料集中仍存在非數值型的缺失值,請檢查資料") | |
| st.stop() | |
| # 顯示資料前幾行供確認 | |
| st.write("資料集預覽:", data.head()) | |
| # 顯示資料基本統計資訊 | |
| st.write("資料基本統計:", data.describe()) | |
| # 定義資料平衡方法處理的函數 | |
| def balance_data(balance_method, model_type): | |
| try: | |
| # 檢查目標變數是否存在 | |
| if 'physiological info_15' not in data.columns: | |
| st.error("找不到目標變數 'physiological info_15'") | |
| return None | |
| # 分離特徵與標籤 | |
| X = data.drop('physiological info_15', axis=1) | |
| y = data['physiological info_15'] | |
| # 檢查類別分布 | |
| class_distribution = y.value_counts() | |
| st.write("類別分布:", class_distribution) | |
| # 分割資料集 | |
| X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y) | |
| # 根據使用者選擇的平衡方法處理資料 | |
| if balance_method == 'SMOTE': | |
| try: | |
| smote = SMOTE(random_state=42) | |
| X_train_resampled, y_train_resampled = smote.fit_resample(X_train, y_train) | |
| except ValueError as e: | |
| st.error(f"SMOTE 處理時發生錯誤: {str(e)}") | |
| return None | |
| elif balance_method == '上採樣': | |
| # 上採樣邏輯保持不變 | |
| class_1 = data[data['physiological info_15'] == 1.0] | |
| class_2 = data[data['physiological info_15'] == 2.0] | |
| class_3 = data[data['physiological info_15'] == 3.0] | |
| max_size = max(len(class_1), len(class_2), len(class_3)) | |
| class_1_upsampled = resample(class_1.drop('physiological info_15', axis=1), | |
| replace=True, n_samples=max_size, random_state=42) | |
| class_2_upsampled = resample(class_2.drop('physiological info_15', axis=1), | |
| replace=True, n_samples=max_size, random_state=42) | |
| class_3_upsampled = resample(class_3.drop('physiological info_15', axis=1), | |
| replace=True, n_samples=max_size, random_state=42) | |
| X_train_resampled = pd.concat([class_1_upsampled, class_2_upsampled, class_3_upsampled]) | |
| y_train_resampled = pd.concat([ | |
| pd.Series([1.0] * max_size), | |
| pd.Series([2.0] * max_size), | |
| pd.Series([3.0] * max_size) | |
| ]) | |
| elif balance_method == '下採樣': | |
| # 下採樣邏輯保持不變 | |
| min_size = min(len(class_1), len(class_2), len(class_3)) | |
| class_1_downsampled = resample(class_1.drop('physiological info_15', axis=1), | |
| replace=False, n_samples=min_size, random_state=42) | |
| class_2_downsampled = resample(class_2.drop('physiological info_15', axis=1), | |
| replace=False, n_samples=min_size, random_state=42) | |
| class_3_downsampled = resample(class_3.drop('physiological info_15', axis=1), | |
| replace=False, n_samples=min_size, random_state=42) | |
| X_train_resampled = pd.concat([class_1_downsampled, class_2_downsampled, class_3_downsampled]) | |
| y_train_resampled = pd.concat([ | |
| pd.Series([1.0] * min_size), | |
| pd.Series([2.0] * min_size), | |
| pd.Series([3.0] * min_size) | |
| ]) | |
| else: | |
| X_train_resampled, y_train_resampled = X_train, y_train | |
| # 選擇模型 | |
| if model_type == '隨機森林': | |
| model = RandomForestClassifier(random_state=42) | |
| else: | |
| model = DecisionTreeClassifier(random_state=42) | |
| # 訓練模型並計算指標 | |
| return calculate_metrics(model, X_train_resampled, y_train_resampled, X_test, y_test) | |
| except Exception as e: | |
| st.error(f"處理資料時發生錯誤: {str(e)}") | |
| return None | |
| def calculate_metrics(model, X_train, y_train, X_test, y_test): | |
| # 訓練模型 | |
| model.fit(X_train, y_train) | |
| y_pred = model.predict(X_test) | |
| # 混淆矩陣 | |
| cm = confusion_matrix(y_test, y_pred) | |
| fig, ax = plt.subplots() | |
| sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax) | |
| ax.set_title('Confusion Matrix') | |
| ax.set_xlabel('Predicted') | |
| ax.set_ylabel('Actual') | |
| st.pyplot(fig) | |
| # ROC AUC 曲線 | |
| try: | |
| y_test_binarized = label_binarize(y_test, classes=np.unique(y_test)) | |
| n_classes = y_test_binarized.shape[1] | |
| model_ovr = OneVsRestClassifier(model) | |
| y_prob = model_ovr.fit(X_train, y_train).predict_proba(X_test) | |
| # 繪製 ROC 曲線 | |
| fig, ax = plt.subplots() | |
| for i in range(n_classes): | |
| fpr, tpr, _ = roc_curve(y_test_binarized[:, i], y_prob[:, i]) | |
| roc_auc = roc_auc_score(y_test_binarized[:, i], y_prob[:, i]) | |
| ax.plot(fpr, tpr, label=f'Class {i+1} (AUC = {roc_auc:.2f})') | |
| ax.plot([0, 1], [0, 1], 'k--') | |
| ax.set_xlabel('False Positive Rate') | |
| ax.set_ylabel('True Positive Rate') | |
| ax.set_title('ROC Curves') | |
| ax.legend() | |
| st.pyplot(fig) | |
| except Exception as e: | |
| st.warning(f"繪製 ROC 曲線時發生錯誤: {str(e)}") | |
| # Cross Validation | |
| kfold = KFold(n_splits=5, shuffle=True, random_state=42) | |
| skfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) | |
| kfold_scores = cross_val_score(model, X_train, y_train, cv=kfold) | |
| skfold_scores = cross_val_score(model, X_train, y_train, cv=skfold) | |
| return { | |
| 'KFold Score': f"{kfold_scores.mean():.4f} ± {kfold_scores.std():.4f}", | |
| 'Stratified KFold Score': f"{skfold_scores.mean():.4f} ± {skfold_scores.std():.4f}", | |
| 'Test Accuracy': f"{accuracy_score(y_test, y_pred):.4f}" | |
| } | |
| # 選擇資料平衡方法和模型 | |
| balance_method = st.selectbox("選擇資料平衡方法", ["SMOTE", "上採樣", "下採樣", "不平衡"]) | |
| model_type = st.selectbox("選擇模型", ["隨機森林", "決策樹"]) | |
| # 按鈕來執行 | |
| if st.button('計算'): | |
| with st.spinner('正在處理資料與訓練模型...'): | |
| results = balance_data(balance_method, model_type) | |
| if results: | |
| st.success('計算完成!') | |
| for metric, value in results.items(): | |
| st.write(f"{metric}: {value}") | |
| except Exception as e: | |
| st.error(f"載入或處理資料時發生錯誤: {str(e)}") | |
| st.write("請確保您的 CSV 檔案格式正確且包含所需的欄位。") |