"""Streamlit dashboard for a diabetes prediction model.

The script offers three pages selected from the sidebar:

* **Predict** - enter the eight feature values by hand and classify one row.
* **Batch Predict** - upload a CSV, classify every row, download the result.
* **Reports & Plots** - show the model comparison table and diagnostic plots
  found under ``reports/``.

A "Train Model" button is shown on every page; it calls ``train_model()``
and persists whatever that returns to ``MODEL_PATH`` with joblib.
"""

import json
import os

import joblib
import pandas as pd
import streamlit as st

from backend.train_model import train_model

MODEL_DIR = "models"
MODEL_FILE = "my_model.pkl"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)
REPORTS_DIR = "reports"
PLOTS_DIR = os.path.join(REPORTS_DIR, "plots")

# Column names the model is fed, in the order passed to ``model.predict``.
FEATURES = [
    "Pregnancies",
    "Glucose",
    "BloodPressure",
    "SkinThickness",
    "Insulin",
    "BMI",
    "DiabetesPedigreeFunction",
    "Age",
]

# Per-feature (minimum, maximum, default) for the single-prediction inputs.
# A float default selects a float input widget; an int default an int widget.
INPUT_RANGES = {
    "Pregnancies": (0, 20, 1),
    "Glucose": (0, 220, 120),
    "BloodPressure": (0, 150, 70),
    "SkinThickness": (0, 100, 20),
    "Insulin": (0, 900, 80),
    "BMI": (0.0, 70.0, 25.0),
    "DiabetesPedigreeFunction": (0.0, 3.0, 0.5),
    "Age": (0, 120, 30),
}

# Columns the Reports page reads from model_comparison.csv.
REPORT_COLUMNS = ["Model", "F1", "Accuracy"]

# (caption, file name) pairs looked up under PLOTS_DIR on the Reports page.
PLOT_FILES = [
    ("Accuracy (bar)", "model_accuracy.png"),
    ("F1 (bar)", "model_f1.png"),
    ("Confusion Matrix (best)", "confusion_matrix.png"),
    ("ROC (best)", "roc_curve.png"),
    ("Variance (before/after)", "variance_comparison.png"),
    ("LR Loss vs Iterations", "logreg_loss_curves.png"),
    ("LR Accuracy vs Iterations", "logreg_accuracy_curves.png"),
]


def predict_df(df: pd.DataFrame):
    """Run the loaded model on ``df`` and return its predictions.

    Reads the module-level ``model``. Problems are reported to the user with
    ``st.error`` instead of raising.

    Args:
        df: Frame that must contain every column listed in ``FEATURES``;
            extra columns are ignored.

    Returns:
        Whatever ``model.predict(df[FEATURES])`` returns, or ``None`` when no
        model is loaded, a feature column is missing, or ``df`` has no rows.
    """
    if model is None:
        st.error("⚠️ Model not loaded. Train first.")
        return None

    missing = [c for c in FEATURES if c not in df.columns]
    if missing:
        st.error(f"Missing columns: {missing}")
        return None

    # A header-only CSV yields zero rows; do not hand an empty frame to the
    # model, report it instead.
    if df.empty:
        st.error("No rows to predict.")
        return None

    return model.predict(df[FEATURES])


# set_page_config is issued before any other Streamlit command.
st.set_page_config(page_title="Diabetes Prediction Dashboard", layout="wide")
st.title("🩺 Diabetes Prediction Dashboard")

# Sidebar navigation
st.sidebar.header("Navigation")
page = st.sidebar.radio("Go to", ["Predict", "Batch Predict", "Reports & Plots"])

model = None
if os.path.exists(MODEL_PATH):
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code.
    # Only point MODEL_PATH at files produced by this application.
    try:
        model = joblib.load(MODEL_PATH)
    except Exception as exc:  # UI boundary: unpickling can fail in many ways
        # A corrupt/incompatible file should not take the whole app down;
        # leave ``model`` as None so the user can retrain.
        st.warning(f"Could not load model from `{MODEL_PATH}`: {exc}")

# ------------------ Train button ------------------
st.subheader("Train & Predict Diabetes Model")
if st.button("Train Model"):
    with st.spinner("Training in progress... this may take a while ⏳"):
        model = train_model()
        # The models directory may not exist on a fresh checkout; without
        # this joblib.dump fails with FileNotFoundError.
        os.makedirs(MODEL_DIR, exist_ok=True)
        joblib.dump(model, MODEL_PATH)
    st.success(f"✅ Model trained and saved to `{MODEL_PATH}`")

# ------------------ Predict single ------------------
if page == "Predict":
    st.subheader("🔹 Single Prediction")
    cols = st.columns(4)
    values = {}
    for i, feature in enumerate(FEATURES):
        # Lay the eight inputs out four per row.
        with cols[i % 4]:
            lo, hi, default = INPUT_RANGES[feature]
            # number_input requires min/max/value to share one numeric type,
            # so cast all three consistently.
            if isinstance(default, float):
                values[feature] = st.number_input(
                    feature,
                    min_value=float(lo),
                    max_value=float(hi),
                    value=float(default),
                )
            else:
                values[feature] = st.number_input(
                    feature,
                    min_value=int(lo),
                    max_value=int(hi),
                    value=int(default),
                )

    if st.button("Predict"):
        row = pd.DataFrame([values])
        pred = predict_df(row)
        if pred is not None:
            # The UI treats a prediction of 1 as the "Diabetic" outcome.
            st.success("✅ Diabetic" if int(pred[0]) == 1 else "🟢 Not Diabetic")

# ------------------ Batch predict ------------------
elif page == "Batch Predict":
    st.subheader("📂 Batch Prediction (Upload CSV)")
    st.caption("CSV must include columns: " + ", ".join(FEATURES))
    uploaded = st.file_uploader("Upload CSV", type=["csv"])
    if uploaded is not None:
        try:
            df = pd.read_csv(uploaded)
        except (
            pd.errors.EmptyDataError,
            pd.errors.ParserError,
            UnicodeDecodeError,
        ) as exc:
            # Empty or malformed uploads are a user error, not a crash.
            st.error(f"Could not read CSV: {exc}")
        else:
            st.write("Preview of uploaded data:")
            st.dataframe(df.head())
            preds = predict_df(df)
            if preds is not None:
                out = df.copy()
                out["Prediction"] = preds
                st.success(f"Predicted {len(out)} rows")
                st.dataframe(out.head())
                st.download_button(
                    "⬇️ Download predictions",
                    data=out.to_csv(index=False).encode("utf-8"),
                    file_name="predictions.csv",
                    mime="text/csv",
                )

# ------------------ Reports & plots ------------------
elif page == "Reports & Plots":
    st.subheader("📊 Model Comparison & Diagnostics")

    # Load CSV
    perf_csv = os.path.join(REPORTS_DIR, "model_comparison.csv")
    perf_json = os.path.join(REPORTS_DIR, "model_comparison.json")
    if os.path.exists(perf_csv):
        df_perf = pd.read_csv(perf_csv)
        st.dataframe(df_perf)

        missing_cols = [c for c in REPORT_COLUMNS if c not in df_perf.columns]
        if missing_cols:
            st.warning(f"Model comparison CSV is missing columns: {missing_cols}")
        elif df_perf.empty or not df_perf["F1"].notna().any():
            # idxmax has no meaningful answer without at least one F1 value.
            st.warning("Model comparison CSV contains no F1 scores.")
        else:
            # Highlight best model based on F1
            best_model = df_perf.loc[df_perf["F1"].idxmax()]
            st.success(
                f"Best Model: {best_model['Model']} ✅ F1: {best_model['F1']} "
                f"Accuracy: {best_model['Accuracy']}"
            )
            # Bar chart
            st.bar_chart(df_perf.set_index("Model")[["Accuracy", "F1"]])
    else:
        st.info("No model comparison CSV found. Train models first!")

    # Optionally show JSON (disabled; kept for reference)
    # if os.path.exists(perf_json):
    #     with open(perf_json, "r") as f:
    #         json_data = json.load(f)
    #     st.json(json_data)

    # Display plots, alternating between two columns.
    plot_cols = st.columns(2)
    for i, (title, fname) in enumerate(PLOT_FILES):
        plot_path = os.path.join(PLOTS_DIR, fname)
        if os.path.exists(plot_path):
            with plot_cols[i % 2]:
                st.markdown(f"**{title}**")
                st.image(plot_path)
        else:
            st.info(f"{fname} not available yet.")