Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import joblib | |
| import pandas as pd | |
| import streamlit as st | |
| from backend.train_model import train_model | |
# Location of the persisted model; MODEL_PATH is what joblib loads from and
# dumps to elsewhere in this script.
MODEL_DIR = "models"
MODEL_FILE = "my_model.pkl"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)

# Report artefacts: the model-comparison table is read from REPORTS_DIR and
# the diagnostic PNGs from PLOTS_DIR.
REPORTS_DIR = "reports"
PLOTS_DIR = os.path.join(REPORTS_DIR, "plots")

# Input columns, in the order they are selected from a DataFrame before it is
# handed to model.predict.
FEATURES = [
    "Pregnancies",
    "Glucose",
    "BloodPressure",
    "SkinThickness",
    "Insulin",
    "BMI",
    "DiabetesPedigreeFunction",
    "Age",
]
# Page chrome. set_page_config is issued before any other Streamlit call.
st.set_page_config(page_title="Diabetes Prediction Dashboard", layout="wide")
# NOTE(review): the leading characters of the title look like a mis-decoded
# emoji (mojibake) - confirm this file is stored and served as UTF-8.
st.title("π©Ί Diabetes Prediction Dashboard")

# Sidebar navigation
st.sidebar.header("Navigation")
page = st.sidebar.radio("Go to", ["Predict", "Batch Predict", "Reports & Plots"])

# Load a previously trained model if one has been saved. `model` stays None
# when nothing usable is on disk; predict_df reports that to the user.
model = None
if os.path.exists(MODEL_PATH):
    try:
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code - only ever point MODEL_PATH at files this app wrote itself.
        model = joblib.load(MODEL_PATH)
    except Exception as exc:  # noqa: BLE001 - top-level UI boundary
        # A truncated or version-incompatible file must not take the whole
        # dashboard down: the Train button below is how the user replaces it.
        st.warning(
            f"Could not load saved model from `{MODEL_PATH}`: {exc}. "
            "Train a new model to replace it."
        )
# ------------------ Train button ------------------
st.subheader("Train & Predict Diabetes Model")
if st.button("Train Model"):
    with st.spinner("Training in progress... this may take a while β³"):
        model = train_model()
        # joblib.dump does not create parent directories; without this a first
        # run on a checkout lacking `models/` fails only after the (slow)
        # training has already finished.
        os.makedirs(MODEL_DIR, exist_ok=True)
        joblib.dump(model, MODEL_PATH)
    st.success(f"β Model trained and saved to `{MODEL_PATH}`")
| # ------------------ Predict single ------------------ | |
def predict_df(df: pd.DataFrame):
    """Run the loaded model on the FEATURES columns of *df*.

    Shows a Streamlit error and returns None when no model is loaded or when
    *df* lacks any of the FEATURES columns; otherwise returns whatever
    ``model.predict`` yields for ``df[FEATURES]``.
    """
    predictions = None
    if model is None:
        st.error("β οΈ Model not loaded. Train first.")
    else:
        # Keep FEATURES order so the message lists columns as the user expects.
        missing = [name for name in FEATURES if name not in df.columns]
        if missing:
            st.error(f"Missing columns: {missing}")
        else:
            feature_frame = df[FEATURES]
            predictions = model.predict(feature_frame)
    return predictions
if page == "Predict":
    st.subheader("πΉ Single Prediction")

    cols = st.columns(4)
    values = {}
    # feature -> (minimum, maximum, default). A float default selects a float
    # widget; everything else gets an integer widget.
    ranges = {
        "Pregnancies": (0, 20, 1),
        "Glucose": (0, 220, 120),
        "BloodPressure": (0, 150, 70),
        "SkinThickness": (0, 100, 20),
        "Insulin": (0, 900, 80),
        "BMI": (0.0, 70.0, 25.0),
        "DiabetesPedigreeFunction": (0.0, 3.0, 0.5),
        "Age": (0, 120, 30),
    }

    # Lay the inputs out four per row by cycling through the columns.
    for position, feature in enumerate(FEATURES):
        minimum, maximum, initial = ranges[feature]
        # number_input needs min/max/value to share one numeric type.
        cast = float if isinstance(initial, float) else int
        with cols[position % 4]:
            values[feature] = st.number_input(
                feature, cast(minimum), cast(maximum), cast(initial)
            )

    if st.button("Predict"):
        pred = predict_df(pd.DataFrame([values]))
        if pred is not None:
            if int(pred[0]) == 1:
                st.success("β Diabetic")
            else:
                st.success("π’ Not Diabetic")
| # ------------------ Batch predict ------------------ | |
| elif page == "Batch Predict": | |
| st.subheader("π Batch Prediction (Upload CSV)") | |
| st.caption("CSV must include columns: " + ", ".join(FEATURES)) | |
| file = st.file_uploader("Upload CSV", type=["csv"]) | |
| if file is not None: | |
| df = pd.read_csv(file) | |
| st.write("Preview of uploaded data:") | |
| st.dataframe(df.head()) | |
| preds = predict_df(df) | |
| if preds is not None: | |
| out = df.copy() | |
| out["Prediction"] = preds | |
| st.success(f"Predicted {len(out)} rows") | |
| st.dataframe(out.head()) | |
| st.download_button( | |
| "β¬οΈ Download predictions", | |
| data=out.to_csv(index=False).encode('utf-8'), | |
| file_name="predictions.csv", | |
| mime="text/csv" | |
| ) | |
| # ------------------ Reports & plots ------------------ | |
| elif page == "Reports & Plots": | |
| st.subheader("π Model Comparison & Diagnostics") | |
| # Load CSV | |
| perf_csv = os.path.join(REPORTS_DIR, "model_comparison.csv") | |
| perf_json = os.path.join(REPORTS_DIR, "model_comparison.json") | |
| if os.path.exists(perf_csv): | |
| df_perf = pd.read_csv(perf_csv) | |
| st.dataframe(df_perf) | |
| # Highlight best model based on F1 | |
| best_model = df_perf.loc[df_perf['F1'].idxmax()] | |
| st.success(f"Best Model: {best_model['Model']} β F1: {best_model['F1']} Accuracy: {best_model['Accuracy']}") | |
| # Bar chart | |
| st.bar_chart(df_perf.set_index('Model')[['Accuracy', 'F1']]) | |
| else: | |
| st.info("No model comparison CSV found. Train models first!") | |
| # Optionally show JSON | |
| # if os.path.exists(perf_json): | |
| # with open(perf_json, "r") as f: | |
| # json_data = json.load(f) | |
| # st.json(json_data) | |
| # Display plots | |
| plot_files = [ | |
| ("Accuracy (bar)", "model_accuracy.png"), | |
| ("F1 (bar)", "model_f1.png"), | |
| ("Confusion Matrix (best)", "confusion_matrix.png"), | |
| ("ROC (best)", "roc_curve.png"), | |
| ("Variance (before/after)", "variance_comparison.png"), | |
| ("LR Loss vs Iterations", "logreg_loss_curves.png"), | |
| ("LR Accuracy vs Iterations", "logreg_accuracy_curves.png"), | |
| ] | |
| rows = st.columns(2) | |
| for i, (title, fname) in enumerate(plot_files): | |
| p = os.path.join(PLOTS_DIR, fname) | |
| if os.path.exists(p): | |
| with rows[i % 2]: | |
| st.markdown(f"**{title}**") | |
| st.image(p) | |
| else: | |
| st.info(f"{fname} not available yet.") |