# hearts / app.py
# Last updated by sidcww in commit 07bd0fb ("Update app.py").
import pandas as pd
import seaborn as sns
import streamlit as st
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
def process_and_evaluate(file, target="target", test_size=0.2, random_state=42):
    """Train a random-forest classifier on a CSV file and evaluate it.

    Every text column, including the target if it is one, is label-encoded
    in place. The rows are split into train/test sets, and a
    ``RandomForestClassifier`` is fitted on the training split and scored on
    the held-out split.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``
            (e.g. a Streamlit ``UploadedFile``).
        target: Name of the label column. Defaults to ``"target"``.
        test_size: Fraction of rows held out for evaluation.
        random_state: Seed used for both the split and the classifier.

    Returns:
        A tuple ``(df, conf_matrix, classification_rep, fpr, tpr, roc_auc)``:
        the label-encoded DataFrame, the test-set confusion matrix, the
        text classification report, the ROC false/true positive rates, and
        the area under that ROC curve.

    Raises:
        KeyError: If ``target`` is not a column of the CSV.
        ValueError: If the training split does not contain exactly two
            classes (the ROC curve here is a binary-classification metric),
            or if pandas/scikit-learn reject the data.
    """
    df = pd.read_csv(file)
    if target not in df.columns:
        raise KeyError(f"CSV has no {target!r} column to use as the prediction target")

    # Label-encode text columns so the classifier receives numbers. Casting
    # to str first lets columns that mix strings with missing values (NaN)
    # be encoded; otherwise LabelEncoder fails when it sorts mixed types.
    for col in df.select_dtypes(include=["object", "string"]).columns:
        df[col] = LabelEncoder().fit_transform(df[col].astype(str))

    X = df.drop(columns=[target])
    y = df[target]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    clf = RandomForestClassifier(random_state=random_state)
    clf.fit(X_train, y_train)

    # predict_proba()[:, 1] and the ROC curve below only make sense for a
    # two-class problem. Fail with a clear message instead of an IndexError
    # (one class) or an error deep inside roc_curve (multiclass).
    if len(clf.classes_) != 2:
        raise ValueError(
            f"Expected a binary target, but the training data has "
            f"{len(clf.classes_)} class(es): {list(clf.classes_)}"
        )

    y_pred = clf.predict(X_test)
    # predict_proba columns follow clf.classes_, so column 1 scores
    # classes_[1]. Passing that label as pos_label keeps roc_curve working
    # when the labels are not {0, 1} or {-1, 1} (e.g. 1/2).
    positive_label = clf.classes_[1]
    y_pred_prob = clf.predict_proba(X_test)[:, 1]

    conf_matrix = confusion_matrix(y_test, y_pred)
    classification_rep = classification_report(y_test, y_pred)
    fpr, tpr, _ = roc_curve(y_test, y_pred_prob, pos_label=positive_label)
    roc_auc = auc(fpr, tpr)
    return df, conf_matrix, classification_rep, fpr, tpr, roc_auc
def _show_figure(fig):
    """Render ``fig`` in the Streamlit page, then close it.

    pyplot keeps every figure created via ``plt.subplots`` registered until
    it is closed. Closing it here stops figures from piling up each time the
    app script runs again.
    """
    st.pyplot(fig)
    plt.close(fig)


def _plot_confusion_matrix(conf_matrix):
    """Draw the confusion matrix as an annotated heatmap."""
    fig, ax = plt.subplots()
    sns.heatmap(conf_matrix, annot=True, fmt='d', ax=ax, cmap='Blues')
    ax.set_title("混淆矩陣")
    ax.set_xlabel('預測標籤')
    ax.set_ylabel('實際標籤')
    _show_figure(fig)


def _plot_correlation_matrix(df):
    """Draw the pairwise correlation of the numeric/bool columns of ``df``."""
    # Restrict to numeric/bool columns: newer pandas versions raise in
    # DataFrame.corr() when non-numeric columns are present.
    corr_matrix = df.select_dtypes(include=["number", "bool"]).corr()
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', ax=ax)
    ax.set_title("相關矩陣")
    ax.set_xlabel('特徵')
    ax.set_ylabel('特徵')
    _show_figure(fig)


def _plot_roc_curve(fpr, tpr, roc_auc):
    """Plot the ROC curve (with its AUC in the legend) against the chance diagonal."""
    fig, ax = plt.subplots(figsize=(10, 8))  # larger figure for a clearer plot
    ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC曲線 (AUC = {roc_auc:.2f})')
    ax.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('假陽性率', fontsize=14)
    ax.set_ylabel('真正例率', fontsize=14)
    ax.set_title('接收者操作特徵曲線 (ROC)', fontsize=16)
    ax.legend(loc="lower right", fontsize=12)
    ax.grid(True, linestyle='--', linewidth=0.5)
    _show_figure(fig)


def main():
    """Render the heart-disease prediction page.

    Asks the user for a CSV file, trains and evaluates a classifier on it via
    ``process_and_evaluate``, and shows the classification report, the
    confusion matrix, the feature correlation matrix, and the ROC curve/AUC.

    NOTE(review): Matplotlib's default font may lack CJK glyphs, so the
    Chinese chart labels could render as empty boxes unless a CJK-capable
    font is configured. Confirm this in the deployment environment.
    """
    st.title("心臟病預測")
    st.write("上傳包含心臟病數據的CSV文件,以獲取分類報告、混淆矩陣、相關矩陣、ROC曲線和AUC。")
    uploaded_file = st.file_uploader("選擇一個CSV文件", type="csv")
    if uploaded_file is None:
        return

    try:
        df, conf_matrix, classification_rep, fpr, tpr, roc_auc = process_and_evaluate(uploaded_file)
    except (KeyError, ValueError) as exc:
        # KeyError: missing target column. ValueError: unparsable/empty CSV
        # (pandas parser errors subclass ValueError), a non-binary target,
        # or data scikit-learn rejects. Show a message instead of a traceback.
        st.error(f"無法處理上傳的文件:{exc}")
        return

    st.subheader("分類報告")
    st.text(classification_rep)

    st.subheader("混淆矩陣")
    _plot_confusion_matrix(conf_matrix)

    st.subheader("相關矩陣")
    _plot_correlation_matrix(df)

    st.subheader("ROC曲線和AUC")
    _plot_roc_curve(fpr, tpr, roc_auc)
# Entry point: render the app when this module is executed directly.
if __name__ == "__main__":
    main()