# mlmodels / app.py
# Committed by sathishleo in 6d489e7:
#   "Add app.py, backend, requirements, ignore models folder"
import json
import os
import joblib
import pandas as pd
import streamlit as st
from backend.train_model import train_model
# Location of the persisted model (paths are relative to the working directory).
MODEL_DIR = "models"
MODEL_FILE = "my_model.pkl"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)

# Directories the "Reports & Plots" page reads comparison tables and images from.
REPORTS_DIR = "reports"
PLOTS_DIR = os.path.join(REPORTS_DIR, "plots")

# Input columns fed to the model, in the order they are selected for prediction.
FEATURES = [
    "Pregnancies",
    "Glucose",
    "BloodPressure",
    "SkinThickness",
    "Insulin",
    "BMI",
    "DiabetesPedigreeFunction",
    "Age",
]
st.set_page_config(page_title="Diabetes Prediction Dashboard", layout="wide")
st.title("🩺 Diabetes Prediction Dashboard")
# Sidebar navigation
st.sidebar.header("Navigation")
page = st.sidebar.radio("Go to", ["Predict", "Batch Predict", "Reports & Plots"])
# Load a previously trained model if one exists on disk; otherwise `model` stays
# None and the user has to press "Train Model" before predicting.
model = None
if os.path.exists(MODEL_PATH):
    try:
        # NOTE: joblib.load unpickles the file, so MODEL_PATH must only ever hold
        # a model this app wrote itself -- never point it at user-supplied files.
        model = joblib.load(MODEL_PATH)
    except Exception as exc:  # any unpickling failure should be recoverable
        # A stale or incompatible pickle (e.g. saved by another library version)
        # must not crash the whole app, otherwise "Train Model" is unreachable.
        st.warning(f"⚠️ Could not load saved model from `{MODEL_PATH}`: {exc}. Please retrain.")
        model = None
# ------------------ Train button ------------------
st.subheader("Train & Predict Diabetes Model")
if st.button("Train Model"):
    with st.spinner("Training in progress... this may take a while ⏳"):
        model = train_model()
        # MODEL_DIR may not exist yet (e.g. on a fresh checkout); joblib.dump
        # does not create parent directories itself.
        os.makedirs(MODEL_DIR, exist_ok=True)
        joblib.dump(model, MODEL_PATH)
    st.success(f"✅ Model trained and saved to `{MODEL_PATH}`")
# ------------------ Shared prediction helper ------------------
def predict_df(df: pd.DataFrame):
    """Run the loaded model on ``df`` and return its predictions.

    Shared by the single-row and batch pages. Only the ``FEATURES`` columns are
    passed to the model, in ``FEATURES`` order; any extra columns are ignored.

    Returns ``None`` (after showing an error/warning in the UI) when no model is
    loaded, when required feature columns are missing, or when ``df`` has no
    rows; otherwise returns the result of ``model.predict``.
    """
    if model is None:
        st.error("⚠️ Model not loaded. Train first.")
        return None
    missing = [c for c in FEATURES if c not in df.columns]
    if missing:
        st.error(f"Missing columns: {missing}")
        return None
    if df.empty:
        # e.g. a header-only CSV; estimators such as scikit-learn's raise on
        # zero-sample input, so report it instead of crashing the page.
        st.warning("No rows to predict.")
        return None
    return model.predict(df[FEATURES])
# ------------------ Single prediction page ------------------
if page == "Predict":
    st.subheader("🔹 Single Prediction")
    cols = st.columns(4)
    values = {}
    # (min, max, default) per feature. Float defaults get float inputs; all
    # other entries are coerced to int inputs below.
    ranges = {
        "Pregnancies": (0, 20, 1), "Glucose": (0, 220, 120),
        "BloodPressure": (0, 150, 70), "SkinThickness": (0, 100, 20),
        "Insulin": (0, 900, 80), "BMI": (0.0, 70.0, 25.0),
        "DiabetesPedigreeFunction": (0.0, 3.0, 0.5), "Age": (0, 120, 30),
    }
    # Lay the inputs out across four columns, filling them round-robin.
    for i, f in enumerate(FEATURES):
        with cols[i % 4]:
            lo, hi, default = ranges[f]
            if isinstance(default, float):
                values[f] = st.number_input(f, lo, hi, float(default))
            else:
                values[f] = st.number_input(f, int(lo), int(hi), int(default))
    if st.button("Predict"):
        # One-row frame whose columns are exactly FEATURES.
        row = pd.DataFrame([values])
        pred = predict_df(row)
        if pred is not None:
            st.success("✅ Diabetic" if int(pred[0]) == 1 else "🟒 Not Diabetic")
# ------------------ Batch predict ------------------
elif page == "Batch Predict":
st.subheader("πŸ“‚ Batch Prediction (Upload CSV)")
st.caption("CSV must include columns: " + ", ".join(FEATURES))
file = st.file_uploader("Upload CSV", type=["csv"])
if file is not None:
df = pd.read_csv(file)
st.write("Preview of uploaded data:")
st.dataframe(df.head())
preds = predict_df(df)
if preds is not None:
out = df.copy()
out["Prediction"] = preds
st.success(f"Predicted {len(out)} rows")
st.dataframe(out.head())
st.download_button(
"⬇️ Download predictions",
data=out.to_csv(index=False).encode('utf-8'),
file_name="predictions.csv",
mime="text/csv"
)
# ------------------ Reports & plots ------------------
elif page == "Reports & Plots":
st.subheader("πŸ“Š Model Comparison & Diagnostics")
# Load CSV
perf_csv = os.path.join(REPORTS_DIR, "model_comparison.csv")
perf_json = os.path.join(REPORTS_DIR, "model_comparison.json")
if os.path.exists(perf_csv):
df_perf = pd.read_csv(perf_csv)
st.dataframe(df_perf)
# Highlight best model based on F1
best_model = df_perf.loc[df_perf['F1'].idxmax()]
st.success(f"Best Model: {best_model['Model']} βœ… F1: {best_model['F1']} Accuracy: {best_model['Accuracy']}")
# Bar chart
st.bar_chart(df_perf.set_index('Model')[['Accuracy', 'F1']])
else:
st.info("No model comparison CSV found. Train models first!")
# Optionally show JSON
# if os.path.exists(perf_json):
# with open(perf_json, "r") as f:
# json_data = json.load(f)
# st.json(json_data)
# Display plots
plot_files = [
("Accuracy (bar)", "model_accuracy.png"),
("F1 (bar)", "model_f1.png"),
("Confusion Matrix (best)", "confusion_matrix.png"),
("ROC (best)", "roc_curve.png"),
("Variance (before/after)", "variance_comparison.png"),
("LR Loss vs Iterations", "logreg_loss_curves.png"),
("LR Accuracy vs Iterations", "logreg_accuracy_curves.png"),
]
rows = st.columns(2)
for i, (title, fname) in enumerate(plot_files):
p = os.path.join(PLOTS_DIR, fname)
if os.path.exists(p):
with rows[i % 2]:
st.markdown(f"**{title}**")
st.image(p)
else:
st.info(f"{fname} not available yet.")