Spaces:
Sleeping
Sleeping
File size: 4,953 Bytes
a3124e1 5318f00 a7d529a f95a877 5318f00 be605fe 5318f00 be605fe 5318f00 f95a877 5318f00 f95a877 5318f00 f95a877 5318f00 f95a877 5318f00 f95a877 5318f00 f95a877 5318f00 f95a877 5318f00 f95a877 5318f00 f95a877 5318f00 f95a877 5318f00 cea075e 6d489e7 cea075e 5318f00 cea075e 5318f00 d14e29d 5318f00 cea075e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import json
import os
import joblib
import pandas as pd
import streamlit as st
from backend.train_model import train_model
# Artifact locations. These are relative paths, so they resolve against the
# working directory the app is launched from.
MODEL_DIR = "models"
MODEL_FILE = "my_model.pkl"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)

REPORTS_DIR = "reports"
PLOTS_DIR = os.path.join(REPORTS_DIR, "plots")

# Model input columns. This is kept as a list (not a tuple) because it is used
# to select DataFrame columns, in this order, before calling predict().
FEATURES = [
    "Pregnancies",
    "Glucose",
    "BloodPressure",
    "SkinThickness",
    "Insulin",
    "BMI",
    "DiabetesPedigreeFunction",
    "Age",
]
st.set_page_config(page_title="Diabetes Prediction Dashboard", layout="wide")
st.title("🩺 Diabetes Prediction Dashboard")

# Sidebar navigation: selects which page section below is rendered.
st.sidebar.header("Navigation")
page = st.sidebar.radio("Go to", ["Predict", "Batch Predict", "Reports & Plots"])

# Load a previously saved model, if any. `model` stays None when there is no
# usable file; predict_df() checks for that and asks the user to train.
# A file that fails to unpickle (truncated write, library-version mismatch, ...)
# must not abort the script, because the "Train Model" button rendered later in
# this script is how the user recovers. So report the error and carry on.
# NOTE: joblib.load unpickles arbitrary objects; only load model files this app
# produced itself.
model = None
if os.path.exists(MODEL_PATH):
    try:
        model = joblib.load(MODEL_PATH)
    except Exception as err:  # UI boundary: surface the problem instead of crashing
        model = None
        st.warning(f"⚠️ Could not load saved model from `{MODEL_PATH}` ({err}). Train a new one.")
# ------------------ Train button ------------------
# Rendered on every page: trains a model, replaces the module-level `model`
# used by predict_df(), and persists it to MODEL_PATH.
st.subheader("Train & Predict Diabetes Model")
if st.button("Train Model"):
    with st.spinner("Training in progress... this may take a while ⏳"):
        # Assumes train_model() returns the fitted estimator (it is saved here
        # and later used via .predict) -- confirm against backend.train_model.
        model = train_model()
        # joblib.dump does not create parent directories; on a fresh checkout
        # MODEL_DIR may not exist yet.
        os.makedirs(MODEL_DIR, exist_ok=True)
        joblib.dump(model, MODEL_PATH)
    st.success(f"✅ Model trained and saved to `{MODEL_PATH}`")
# ------------------ Prediction helper ------------------
def predict_df(df: pd.DataFrame):
    """Run the loaded model on *df* and return its predictions.

    Used by both the single-row and the batch prediction pages. Only the
    columns listed in ``FEATURES`` are passed to the model, in that order;
    any extra columns in *df* are ignored.

    Args:
        df: Input rows; must contain every column named in ``FEATURES``.

    Returns:
        The result of ``model.predict`` on the selected columns, or ``None``
        when no model is loaded, a required column is missing, or *df* has no
        rows. In each ``None`` case a message is shown in the UI.
    """
    if model is None:
        st.error("⚠️ Model not loaded. Train first.")
        return None

    missing = [c for c in FEATURES if c not in df.columns]
    if missing:
        st.error(f"Missing columns: {missing}")
        return None

    if df.empty:
        # e.g. a header-only CSV. scikit-learn-style estimators typically raise
        # on zero-sample input, so stop here with a clear message instead.
        st.warning("No rows to predict.")
        return None

    return model.predict(df[FEATURES])
if page == "Predict":
    # Single-row prediction form: one numeric input per feature, spread over
    # four columns, plus a button that runs the model on that one row.
    st.subheader("🔹 Single Prediction")
    cols = st.columns(4)
    values = {}
    # feature -> (min, max, default). A float default selects a float input;
    # otherwise an integer input is used (number_input needs all three
    # numeric arguments to share one type).
    ranges = {
        "Pregnancies": (0, 20, 1), "Glucose": (0, 220, 120),
        "BloodPressure": (0, 150, 70), "SkinThickness": (0, 100, 20),
        "Insulin": (0, 900, 80), "BMI": (0.0, 70.0, 25.0),
        "DiabetesPedigreeFunction": (0.0, 3.0, 0.5), "Age": (0, 120, 30),
    }
    for i, f in enumerate(FEATURES):
        lo, hi, default = ranges[f]
        with cols[i % 4]:
            if isinstance(default, float):
                values[f] = st.number_input(
                    f, min_value=float(lo), max_value=float(hi), value=float(default)
                )
            else:
                values[f] = st.number_input(
                    f, min_value=int(lo), max_value=int(hi), value=int(default)
                )

    if st.button("Predict"):
        row = pd.DataFrame([values])
        pred = predict_df(row)
        if pred is not None:
            # Style the two outcomes differently so a positive result is not
            # shown in the same green "success" box as a negative one.
            if int(pred[0]) == 1:
                st.warning("⚠️ Diabetic")
            else:
                st.success("🟢 Not Diabetic")
# ------------------ Batch predict ------------------
elif page == "Batch Predict":
st.subheader("📂 Batch Prediction (Upload CSV)")
st.caption("CSV must include columns: " + ", ".join(FEATURES))
file = st.file_uploader("Upload CSV", type=["csv"])
if file is not None:
df = pd.read_csv(file)
st.write("Preview of uploaded data:")
st.dataframe(df.head())
preds = predict_df(df)
if preds is not None:
out = df.copy()
out["Prediction"] = preds
st.success(f"Predicted {len(out)} rows")
st.dataframe(out.head())
st.download_button(
"⬇️ Download predictions",
data=out.to_csv(index=False).encode('utf-8'),
file_name="predictions.csv",
mime="text/csv"
)
# ------------------ Reports & plots ------------------
elif page == "Reports & Plots":
st.subheader("📊 Model Comparison & Diagnostics")
# Load CSV
perf_csv = os.path.join(REPORTS_DIR, "model_comparison.csv")
perf_json = os.path.join(REPORTS_DIR, "model_comparison.json")
if os.path.exists(perf_csv):
df_perf = pd.read_csv(perf_csv)
st.dataframe(df_perf)
# Highlight best model based on F1
best_model = df_perf.loc[df_perf['F1'].idxmax()]
st.success(f"Best Model: {best_model['Model']} ✅ F1: {best_model['F1']} Accuracy: {best_model['Accuracy']}")
# Bar chart
st.bar_chart(df_perf.set_index('Model')[['Accuracy', 'F1']])
else:
st.info("No model comparison CSV found. Train models first!")
# Optionally show JSON
# if os.path.exists(perf_json):
# with open(perf_json, "r") as f:
# json_data = json.load(f)
# st.json(json_data)
# Display plots
plot_files = [
("Accuracy (bar)", "model_accuracy.png"),
("F1 (bar)", "model_f1.png"),
("Confusion Matrix (best)", "confusion_matrix.png"),
("ROC (best)", "roc_curve.png"),
("Variance (before/after)", "variance_comparison.png"),
("LR Loss vs Iterations", "logreg_loss_curves.png"),
("LR Accuracy vs Iterations", "logreg_accuracy_curves.png"),
]
rows = st.columns(2)
for i, (title, fname) in enumerate(plot_files):
p = os.path.join(PLOTS_DIR, fname)
if os.path.exists(p):
with rows[i % 2]:
st.markdown(f"**{title}**")
st.image(p)
else:
st.info(f"{fname} not available yet.") |