data_balance / app.py
KKingzor's picture
Update app.py
6544fec verified
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, cross_val_score, KFold, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from imblearn.over_sampling import SMOTE
from sklearn.utils import resample
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, roc_curve
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
import numpy as np
# Streamlit page header and CSV upload widget.
st.title("資料平衡與交叉驗證")
st.write("選擇 CSV 資料集,資料平衡方法與模型後,計算 k-Fold 和 Stratified k-Fold Cross-Validation 的分數,並繪製混淆矩陣與 ROC AUC 曲線")
# Upload a CSV file
uploaded_file = st.file_uploader("選擇 CSV 檔案", type=["csv"])
# NOTE(review): the `if`/`try` opened below is closed by the `except` at the
# end of the script; the file's indentation appears to have been stripped,
# so the nesting has to be read from the statement order.
if uploaded_file is not None:
try:
# Load the uploaded file into a DataFrame
data = pd.read_csv(uploaded_file)
# Show the initial missing-value count for each column
missing_values = data.isnull().sum()
if missing_values.any():
st.warning("檢測到缺失值,以下是各欄位的缺失值數量:")
st.write(missing_values[missing_values > 0])
# Identify the numeric columns
numeric_columns = data.select_dtypes(include=[np.number]).columns
# Fill missing values in the numeric columns
for col in numeric_columns:
if data[col].isnull().any():
# Impute with the median, because the mean is easily skewed by extreme values
median_value = data[col].median()
data[col] = data[col].fillna(median_value)
st.info(f"欄位 '{col}' 的缺失值已使用中位數 {median_value:.2f} 填補")
# Check whether any missing values remain; only numeric columns were filled
# above, so whatever is left is reported as an error and the run is stopped
if data.isnull().any().any():
st.error("資料集中仍存在非數值型的缺失值,請檢查資料")
st.stop()
# Show the first few rows so the user can confirm the data
st.write("資料集預覽:", data.head())
# Show basic descriptive statistics of the data
st.write("資料基本統計:", data.describe())
# Data balancing: rebalance the training split, then train and evaluate a model.
def _resample_to_equal_size(X, y, upsample):
    """Resample every class of ``(X, y)`` to one common size.

    Args:
        X: Feature DataFrame (training split only).
        y: Label Series sharing its index with ``X``.
        upsample: If True, minority classes are drawn with replacement up to
            the largest class size; if False, larger classes are drawn
            without replacement down to the smallest class size.

    Returns:
        Tuple ``(X_balanced, y_balanced)`` with a fresh ``0..n-1`` index.
    """
    class_sizes = y.value_counts()
    target_size = int(class_sizes.max() if upsample else class_sizes.min())

    X_parts, y_parts = [], []
    for label, size in class_sizes.items():
        mask = y == label
        X_cls, y_cls = X.loc[mask], y.loc[mask]
        # Classes already at the target size are kept untouched so that the
        # majority class is not needlessly bootstrapped.
        if size != target_size:
            # Resampling X and y together keeps rows and labels aligned.
            X_cls, y_cls = resample(
                X_cls, y_cls,
                replace=upsample, n_samples=target_size, random_state=42,
            )
        X_parts.append(X_cls)
        y_parts.append(y_cls)

    return (
        pd.concat(X_parts, ignore_index=True),
        pd.concat(y_parts, ignore_index=True),
    )


def balance_data(balance_method, model_type, target_col='physiological info_15'):
    """Balance the training data, train the chosen model and evaluate it.

    Reads the DataFrame ``data`` from the enclosing scope, holds out a
    stratified 20% test split, rebalances only the training split and then
    delegates training/evaluation to ``calculate_metrics``.

    Args:
        balance_method: 'SMOTE', '上採樣' (upsampling) or '下採樣'
            (downsampling); any other value leaves the training split as is.
        model_type: '隨機森林' selects a random forest; any other value
            selects a decision tree.
        target_col: Name of the label column in ``data``.

    Returns:
        The dict returned by ``calculate_metrics``, or None when the target
        column is missing or any step fails (the error is shown via
        ``st.error``).
    """
    try:
        # Make sure the target column exists
        if target_col not in data.columns:
            st.error(f"找不到目標變數 '{target_col}'")
            return None

        # Separate features and labels
        X = data.drop(target_col, axis=1)
        y = data[target_col]

        # Show the class distribution
        class_distribution = y.value_counts()
        st.write("類別分布:", class_distribution)

        # Hold out a stratified test split
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )

        # Rebalance the training split only, so that no test row (or a
        # duplicate of one) can end up in the training data.
        if balance_method == 'SMOTE':
            try:
                smote = SMOTE(random_state=42)
                X_train_resampled, y_train_resampled = smote.fit_resample(X_train, y_train)
            except ValueError as e:
                st.error(f"SMOTE 處理時發生錯誤: {str(e)}")
                return None
        elif balance_method in ('上採樣', '下採樣'):
            X_train_resampled, y_train_resampled = _resample_to_equal_size(
                X_train, y_train, upsample=(balance_method == '上採樣')
            )
        else:
            X_train_resampled, y_train_resampled = X_train, y_train

        # Choose the model
        if model_type == '隨機森林':
            model = RandomForestClassifier(random_state=42)
        else:
            model = DecisionTreeClassifier(random_state=42)

        # Train the model and compute the metrics
        return calculate_metrics(model, X_train_resampled, y_train_resampled, X_test, y_test)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        st.error(f"處理資料時發生錯誤: {str(e)}")
        return None
def calculate_metrics(model, X_train, y_train, X_test, y_test):
    """Fit ``model``, render diagnostics in Streamlit and return summary scores.

    Renders a confusion-matrix heatmap and one-vs-rest ROC curves through
    ``st.pyplot``, then runs 5-fold KFold and StratifiedKFold cross-validation
    on the training data.

    Args:
        model: Unfitted scikit-learn classifier; it is fitted in place.
        X_train: Training features.
        y_train: Training labels.
        X_test: Held-out features.
        y_test: Held-out labels.

    Returns:
        Dict with the keys 'KFold Score', 'Stratified KFold Score'
        (``mean ± std`` strings) and 'Test Accuracy'.
    """
    # Train the model
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    class_labels = model.classes_

    # Confusion matrix, with rows/columns ordered and labelled by class
    cm = confusion_matrix(y_test, y_pred, labels=class_labels)
    fig, ax = plt.subplots()
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                    xticklabels=class_labels, yticklabels=class_labels)
        ax.set_title('Confusion Matrix')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        st.pyplot(fig)
    finally:
        # The figure is not released after rendering; close it so figures
        # do not accumulate across reruns.
        plt.close(fig)

    # ROC AUC curves (one-vs-rest)
    try:
        model_ovr = OneVsRestClassifier(model)
        y_prob = model_ovr.fit(X_train, y_train).predict_proba(X_test)
        # Binarize against the classifier's own class order so that column i
        # of the indicator matrix matches column i of ``y_prob``.
        roc_classes = model_ovr.classes_
        y_test_binarized = label_binarize(y_test, classes=roc_classes)
        if len(roc_classes) == 2 and y_test_binarized.shape[1] == 1:
            # For two classes only the positive-class column is produced;
            # add the complementary column to line up with ``y_prob``.
            y_test_binarized = np.hstack([1 - y_test_binarized, y_test_binarized])

        # Plot the ROC curves
        fig, ax = plt.subplots()
        try:
            for i, label in enumerate(roc_classes):
                y_true_i = y_test_binarized[:, i]
                if y_true_i.min() == y_true_i.max():
                    # ROC is undefined unless both outcomes are present.
                    continue
                fpr, tpr, _ = roc_curve(y_true_i, y_prob[:, i])
                roc_auc = roc_auc_score(y_true_i, y_prob[:, i])
                ax.plot(fpr, tpr, label=f'Class {label} (AUC = {roc_auc:.2f})')
            ax.plot([0, 1], [0, 1], 'k--')
            ax.set_xlabel('False Positive Rate')
            ax.set_ylabel('True Positive Rate')
            ax.set_title('ROC Curves')
            ax.legend()
            st.pyplot(fig)
        finally:
            plt.close(fig)
    except Exception as e:
        # The ROC plot is best-effort; a failure here must not block the scores.
        st.warning(f"繪製 ROC 曲線時發生錯誤: {str(e)}")

    # Cross validation
    # NOTE(review): if the caller passes resampled training data, folds may
    # share duplicated/synthetic rows and the scores can be optimistic.
    kfold = KFold(n_splits=5, shuffle=True, random_state=42)
    skfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    kfold_scores = cross_val_score(model, X_train, y_train, cv=kfold)
    skfold_scores = cross_val_score(model, X_train, y_train, cv=skfold)

    return {
        'KFold Score': f"{kfold_scores.mean():.4f} ± {kfold_scores.std():.4f}",
        'Stratified KFold Score': f"{skfold_scores.mean():.4f} ± {skfold_scores.std():.4f}",
        'Test Accuracy': f"{accuracy_score(y_test, y_pred):.4f}"
    }
# Let the user choose the balancing method and the model
balance_method = st.selectbox("選擇資料平衡方法", ["SMOTE", "上採樣", "下採樣", "不平衡"])
model_type = st.selectbox("選擇模型", ["隨機森林", "決策樹"])
# Button that triggers the run
if st.button('計算'):
with st.spinner('正在處理資料與訓練模型...'):
results = balance_data(balance_method, model_type)
# balance_data returns None on failure, so scores are listed only on success
if results:
st.success('計算完成!')
for metric, value in results.items():
st.write(f"{metric}: {value}")
# Handler for the `try` opened right after the upload check: report any
# loading/processing failure to the user
except Exception as e:
st.error(f"載入或處理資料時發生錯誤: {str(e)}")
st.write("請確保您的 CSV 檔案格式正確且包含所需的欄位。")