sathishleo commited on
Commit
5318f00
·
1 Parent(s): e1cab8a

Add app.py, backend, and model for HF Space

Browse files
Files changed (1) hide show
  1. app.py +146 -2
app.py CHANGED
@@ -1,4 +1,148 @@
 
 
 
1
  import streamlit as st
2
 
3
- st.title("✅ HF Space App Test")
4
- st.write("Hello! This Space is working.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import joblib
3
+ import pandas as pd
4
  import streamlit as st
5
 
6
+ NONE = None
7
+ # from backend.train_model import train_model
8
+
9
+ # Get the current directory of the Streamlit script
10
+ # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11
+ #
12
+ # # Build the absolute path to the model
13
+ # # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
14
+ # MODEL_PATH = os.path.join(BASE_DIR, "..", "models", "best_model.pkl")
15
+ # REPORTS_DIR = os.path.join(BASE_DIR, "..", "reports")
16
+ # PLOTS_DIR = os.path.join(REPORTS_DIR, "plots")
17
+ MODEL_DIR = "backend/models"
18
+ MODEL_FILE = "my_model.pkl"
19
+ MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)
20
+
21
+ REPORTS_DIR = "backend/reports"
22
+ PLOTS_DIR = os.path.join(REPORTS_DIR, "plots")
23
+
24
+
25
+ BACKEND_URL = os.getenv("BACKEND_URL", "http://backend:5000")
26
+
27
+ st.set_page_config(page_title="Diabetes Prediction Dashboard", layout="wide")
28
+ st.title("🩺 Diabetes Prediction Dashboard")
29
+
30
+ # ---------- Sidebar ----------
31
+ st.sidebar.header("Navigation")
32
+ page = st.sidebar.radio("Go to", ["Predict", "Batch Predict", "Reports & Plots"])
33
+
34
+ # ---------- Load best model ----------
35
+
36
+ @st.cache_resource
37
+ def load_model(path):
38
+ if os.path.exists(path):
39
+ model = joblib.load(path)
40
+ st.sidebar.success("✅ Best model loaded")
41
+ return model
42
+ else:
43
+ st.sidebar.error("❌ Best model not found. Run backend/train_model.py first.")
44
+ return None
45
+
46
+ model = load_model(MODEL_PATH)
47
+
48
+ # ---------- Features ----------
49
+ FEATURES = [
50
+ "Pregnancies", "Glucose", "BloodPressure", "SkinThickness",
51
+ "Insulin", "BMI", "DiabetesPedigreeFunction", "Age"
52
+ ]
53
+
54
+ def predict_df(df: pd.DataFrame):
55
+ """Run model prediction on a DataFrame"""
56
+ if model is None:
57
+ st.error("Model not loaded")
58
+ return None
59
+ missing = [c for c in FEATURES if c not in df.columns]
60
+ if missing:
61
+ st.error(f"Missing columns: {missing}")
62
+ return None
63
+ return model.predict(df[FEATURES])
64
+
65
+ # ---------- Pages ----------
66
+ if page == "Predict":
67
+ st.subheader("🔹 Single Prediction")
68
+
69
+ cols = st.columns(4)
70
+ values = {}
71
+ ranges = {
72
+ "Pregnancies": (0, 20, 1), "Glucose": (0, 220, 120),
73
+ "BloodPressure": (0, 150, 70), "SkinThickness": (0, 100, 20),
74
+ "Insulin": (0, 900, 80), "BMI": (0.0, 70.0, 25.0),
75
+ "DiabetesPedigreeFunction": (0.0, 3.0, 0.5), "Age": (0, 120, 30)
76
+ }
77
+
78
+ for i, f in enumerate(FEATURES):
79
+ with cols[i % 4]:
80
+ lo, hi, default = ranges[f]
81
+ if isinstance(default, float):
82
+ values[f] = st.number_input(f, lo, hi, float(default))
83
+ else:
84
+ values[f] = st.number_input(f, int(lo), int(hi), int(default))
85
+
86
+ if st.button("Predict"):
87
+ row = pd.DataFrame([values])
88
+ pred = predict_df(row)
89
+ if pred is not None:
90
+ st.success("✅ Diabetic" if int(pred[0]) == 1 else "🟢 Not Diabetic")
91
+
92
+ elif page == "Batch Predict":
93
+ st.subheader("📂 Batch Prediction (Upload CSV)")
94
+ st.caption("CSV must include columns: " + ", ".join(FEATURES))
95
+
96
+ file = st.file_uploader("Upload CSV", type=["csv"])
97
+ if file is not None:
98
+ df = pd.read_csv(file)
99
+ st.write("Preview of uploaded data:")
100
+ st.dataframe(df.head())
101
+
102
+ preds = predict_df(df)
103
+ if preds is not None:
104
+ out = df.copy()
105
+ out["Prediction"] = preds
106
+ st.success(f"Predicted {len(out)} rows")
107
+ st.dataframe(out.head())
108
+
109
+ st.download_button(
110
+ "⬇️ Download predictions",
111
+ data=out.to_csv(index=False).encode('utf-8'),
112
+ file_name="predictions.csv",
113
+ mime="text/csv"
114
+ )
115
+
116
+ elif page == "Reports & Plots":
117
+ st.subheader("📊 Model Comparison & Diagnostics")
118
+
119
+ # Table report
120
+ # cmp_path = os.path.join(REPORTS_DIR, "model_comparison.csv")
121
+ # if os.path.exists(cmp_path):
122
+ # cmp_df = pd.read_csv(cmp_path)
123
+ # st.dataframe(cmp_df)
124
+ # else:
125
+ # st.warning("⚠️ model_comparison.csv not found. Run training.")
126
+
127
+ # Plots grid
128
+ plot_files = [
129
+ ("Accuracy (bar)", "model_accuracy.png"),
130
+ ("F1 (bar)", "model_f1.png"),
131
+ ("Confusion Matrix (best)", "confusion_matrix.png"),
132
+ ("ROC (best)", "roc_curve.png"),
133
+ ("Variance (before/after)", "variance_comparison.png"),
134
+ ("LR Loss vs Iterations", "logreg_loss_curves.png"),
135
+ ("LR Accuracy vs Iterations", "logreg_accuracy_curves.png"),
136
+ ]
137
+
138
+ rows = st.columns(2)
139
+ i = 0
140
+ for title, fname in plot_files:
141
+ p = os.path.join(PLOTS_DIR, fname)
142
+ if os.path.exists(p):
143
+ with rows[i % 2]:
144
+ st.markdown(f"**{title}**")
145
+ st.image(p, use_container_width=True)
146
+ i += 1
147
+ else:
148
+ st.info(f"{fname} not available yet.")