sathishleo commited on
Commit
f95a877
·
1 Parent(s): 00aa295

Add app.py, backend, and model for HF Space

Browse files
Files changed (1) hide show
  1. app.py +22 -71
app.py CHANGED
@@ -1,21 +1,9 @@
1
  import os
2
- import subprocess
3
-
4
  import joblib
5
  import pandas as pd
6
  import streamlit as st
7
- from backend.train_model import train_model # your function
8
- NONE = None
9
- # from backend.train_model import train_model
10
 
11
- # Get the current directory of the Streamlit script
12
- # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
13
- #
14
- # # Build the absolute path to the model
15
- # # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16
- # MODEL_PATH = os.path.join(BASE_DIR, "..", "models", "best_model.pkl")
17
- # REPORTS_DIR = os.path.join(BASE_DIR, "..", "reports")
18
- # PLOTS_DIR = os.path.join(REPORTS_DIR, "plots")
19
  MODEL_DIR = "models"
20
  MODEL_FILE = "my_model.pkl"
21
  MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)
@@ -23,45 +11,35 @@ MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)
23
  REPORTS_DIR = "reports"
24
  PLOTS_DIR = os.path.join(REPORTS_DIR, "plots")
25
 
26
-
27
- BACKEND_URL = os.getenv("BACKEND_URL", "http://backend:5000")
 
 
28
 
29
  st.set_page_config(page_title="Diabetes Prediction Dashboard", layout="wide")
30
  st.title("🩺 Diabetes Prediction Dashboard")
31
 
32
- # ---------- Sidebar ----------
33
  st.sidebar.header("Navigation")
34
  page = st.sidebar.radio("Go to", ["Predict", "Batch Predict", "Reports & Plots"])
35
 
36
- # ---------- Load best model ----------
 
 
37
 
38
- # @st.cache_resource
39
- # def load_model(path):
40
- # if os.path.exists(path):
41
- # model = joblib.load(path)
42
- # st.sidebar.success("✅ Best model loaded")
43
- # return model
44
- # else:
45
- # result = subprocess.run(["python", "backend/train_model.py"], capture_output=True, text=True)
46
- # st.text(result.stdout)
47
- # st.text(result.stderr)
48
- #
49
- # # Reload the trained model
50
- # model = load_model(MODEL_PATH)
51
- # return model
52
- #
53
- # model = load_model(MODEL_PATH)
54
 
55
- # ---------- Features ----------
56
- FEATURES = [
57
- "Pregnancies", "Glucose", "BloodPressure", "SkinThickness",
58
- "Insulin", "BMI", "DiabetesPedigreeFunction", "Age"
59
- ]
60
 
 
61
  def predict_df(df: pd.DataFrame):
62
- """Run model prediction on a DataFrame"""
63
  if model is None:
64
- st.error("Model not loaded")
65
  return None
66
  missing = [c for c in FEATURES if c not in df.columns]
67
  if missing:
@@ -69,21 +47,8 @@ def predict_df(df: pd.DataFrame):
69
  return None
70
  return model.predict(df[FEATURES])
71
 
72
- # # ---------- Pages ----------
73
- # model = joblib.load(MODEL_PATH)
74
- st.title("Train & Predict Diabetes Model")
75
-
76
- if not os.path.exists(MODEL_PATH):
77
- st.warning("No trained model found. Please train the model first.")
78
-
79
- if st.button("Train Model"):
80
- st.info("Training started...")
81
- model = train_model()
82
- joblib.dump(model, MODEL_PATH)
83
- st.success(f"Model trained and saved to {MODEL_PATH}")
84
- elif page == "Predict":
85
  st.subheader("🔹 Single Prediction")
86
-
87
  cols = st.columns(4)
88
  values = {}
89
  ranges = {
@@ -92,7 +57,6 @@ elif page == "Predict":
92
  "Insulin": (0, 900, 80), "BMI": (0.0, 70.0, 25.0),
93
  "DiabetesPedigreeFunction": (0.0, 3.0, 0.5), "Age": (0, 120, 30)
94
  }
95
-
96
  for i, f in enumerate(FEATURES):
97
  with cols[i % 4]:
98
  lo, hi, default = ranges[f]
@@ -100,30 +64,27 @@ elif page == "Predict":
100
  values[f] = st.number_input(f, lo, hi, float(default))
101
  else:
102
  values[f] = st.number_input(f, int(lo), int(hi), int(default))
103
-
104
  if st.button("Predict"):
105
  row = pd.DataFrame([values])
106
  pred = predict_df(row)
107
  if pred is not None:
108
  st.success("✅ Diabetic" if int(pred[0]) == 1 else "🟢 Not Diabetic")
109
 
 
110
  elif page == "Batch Predict":
111
  st.subheader("📂 Batch Prediction (Upload CSV)")
112
  st.caption("CSV must include columns: " + ", ".join(FEATURES))
113
-
114
  file = st.file_uploader("Upload CSV", type=["csv"])
115
  if file is not None:
116
  df = pd.read_csv(file)
117
  st.write("Preview of uploaded data:")
118
  st.dataframe(df.head())
119
-
120
  preds = predict_df(df)
121
  if preds is not None:
122
  out = df.copy()
123
  out["Prediction"] = preds
124
  st.success(f"Predicted {len(out)} rows")
125
  st.dataframe(out.head())
126
-
127
  st.download_button(
128
  "⬇️ Download predictions",
129
  data=out.to_csv(index=False).encode('utf-8'),
@@ -131,18 +92,9 @@ elif page == "Batch Predict":
131
  mime="text/csv"
132
  )
133
 
 
134
  elif page == "Reports & Plots":
135
  st.subheader("📊 Model Comparison & Diagnostics")
136
-
137
- # Table report
138
- # cmp_path = os.path.join(REPORTS_DIR, "model_comparison.csv")
139
- # if os.path.exists(cmp_path):
140
- # cmp_df = pd.read_csv(cmp_path)
141
- # st.dataframe(cmp_df)
142
- # else:
143
- # st.warning("⚠️ model_comparison.csv not found. Run training.")
144
-
145
- # Plots grid
146
  plot_files = [
147
  ("Accuracy (bar)", "model_accuracy.png"),
148
  ("F1 (bar)", "model_f1.png"),
@@ -152,7 +104,6 @@ elif page == "Reports & Plots":
152
  ("LR Loss vs Iterations", "logreg_loss_curves.png"),
153
  ("LR Accuracy vs Iterations", "logreg_accuracy_curves.png"),
154
  ]
155
-
156
  rows = st.columns(2)
157
  i = 0
158
  for title, fname in plot_files:
@@ -160,7 +111,7 @@ elif page == "Reports & Plots":
160
  if os.path.exists(p):
161
  with rows[i % 2]:
162
  st.markdown(f"**{title}**")
163
- st.image(p, use_container_width=True)
164
  i += 1
165
  else:
166
  st.info(f"{fname} not available yet.")
 
1
  import os
 
 
2
  import joblib
3
  import pandas as pd
4
  import streamlit as st
5
+ from backend.train_model import train_model
 
 
6
 
 
 
 
 
 
 
 
 
7
  MODEL_DIR = "models"
8
  MODEL_FILE = "my_model.pkl"
9
  MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)
 
11
  REPORTS_DIR = "reports"
12
  PLOTS_DIR = os.path.join(REPORTS_DIR, "plots")
13
 
14
+ FEATURES = [
15
+ "Pregnancies", "Glucose", "BloodPressure", "SkinThickness",
16
+ "Insulin", "BMI", "DiabetesPedigreeFunction", "Age"
17
+ ]
18
 
19
  st.set_page_config(page_title="Diabetes Prediction Dashboard", layout="wide")
20
  st.title("🩺 Diabetes Prediction Dashboard")
21
 
22
+ # Sidebar navigation
23
  st.sidebar.header("Navigation")
24
  page = st.sidebar.radio("Go to", ["Predict", "Batch Predict", "Reports & Plots"])
25
 
26
+ model = None
27
+ if os.path.exists(MODEL_PATH):
28
+ model = joblib.load(MODEL_PATH)
29
 
30
+ # ------------------ Train button ------------------
31
+ st.subheader("Train & Predict Diabetes Model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ if st.button("Train Model"):
34
+ with st.spinner("Training in progress... this may take a while ⏳"):
35
+ model = train_model()
36
+ joblib.dump(model, MODEL_PATH)
37
+ st.success(f"✅ Model trained and saved to `{MODEL_PATH}`")
38
 
39
+ # ------------------ Predict single ------------------
40
  def predict_df(df: pd.DataFrame):
 
41
  if model is None:
42
+ st.error("⚠️ Model not loaded. Train first.")
43
  return None
44
  missing = [c for c in FEATURES if c not in df.columns]
45
  if missing:
 
47
  return None
48
  return model.predict(df[FEATURES])
49
 
50
+ if page == "Predict":
 
 
 
 
 
 
 
 
 
 
 
 
51
  st.subheader("🔹 Single Prediction")
 
52
  cols = st.columns(4)
53
  values = {}
54
  ranges = {
 
57
  "Insulin": (0, 900, 80), "BMI": (0.0, 70.0, 25.0),
58
  "DiabetesPedigreeFunction": (0.0, 3.0, 0.5), "Age": (0, 120, 30)
59
  }
 
60
  for i, f in enumerate(FEATURES):
61
  with cols[i % 4]:
62
  lo, hi, default = ranges[f]
 
64
  values[f] = st.number_input(f, lo, hi, float(default))
65
  else:
66
  values[f] = st.number_input(f, int(lo), int(hi), int(default))
 
67
  if st.button("Predict"):
68
  row = pd.DataFrame([values])
69
  pred = predict_df(row)
70
  if pred is not None:
71
  st.success("✅ Diabetic" if int(pred[0]) == 1 else "🟢 Not Diabetic")
72
 
73
+ # ------------------ Batch predict ------------------
74
  elif page == "Batch Predict":
75
  st.subheader("📂 Batch Prediction (Upload CSV)")
76
  st.caption("CSV must include columns: " + ", ".join(FEATURES))
 
77
  file = st.file_uploader("Upload CSV", type=["csv"])
78
  if file is not None:
79
  df = pd.read_csv(file)
80
  st.write("Preview of uploaded data:")
81
  st.dataframe(df.head())
 
82
  preds = predict_df(df)
83
  if preds is not None:
84
  out = df.copy()
85
  out["Prediction"] = preds
86
  st.success(f"Predicted {len(out)} rows")
87
  st.dataframe(out.head())
 
88
  st.download_button(
89
  "⬇️ Download predictions",
90
  data=out.to_csv(index=False).encode('utf-8'),
 
92
  mime="text/csv"
93
  )
94
 
95
+ # ------------------ Reports & plots ------------------
96
  elif page == "Reports & Plots":
97
  st.subheader("📊 Model Comparison & Diagnostics")
 
 
 
 
 
 
 
 
 
 
98
  plot_files = [
99
  ("Accuracy (bar)", "model_accuracy.png"),
100
  ("F1 (bar)", "model_f1.png"),
 
104
  ("LR Loss vs Iterations", "logreg_loss_curves.png"),
105
  ("LR Accuracy vs Iterations", "logreg_accuracy_curves.png"),
106
  ]
 
107
  rows = st.columns(2)
108
  i = 0
109
  for title, fname in plot_files:
 
111
  if os.path.exists(p):
112
  with rows[i % 2]:
113
  st.markdown(f"**{title}**")
114
+ st.image(p, width=700) # ✅ works on all Streamlit versions
115
  i += 1
116
  else:
117
  st.info(f"{fname} not available yet.")