Synav commited on
Commit
aa01912
·
verified ·
1 Parent(s): cd77cd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py CHANGED
@@ -100,6 +100,31 @@ def get_feature_cols_from_df(df: pd.DataFrame):
100
  f"Ensure your Excel header row contains a column named '{LABEL_COL}'.")
101
  return [c for c in df.columns if c != LABEL_COL]
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  # ============================================================
105
  # Model pipeline
@@ -1110,6 +1135,136 @@ with tab_predict:
1110
 
1111
  pipe = st.session_state.pipe
1112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1113
  infer_file = st.file_uploader("Upload inference Excel (.xlsx)", type=["xlsx"])
1114
  if infer_file:
1115
  df_inf = pd.read_excel(infer_file, engine="openpyxl")
@@ -1117,6 +1272,7 @@ with tab_predict:
1117
  if not meta:
1118
  st.error("Model metadata not loaded. Please load a model version.")
1119
  st.stop()
 
1120
 
1121
  feature_cols = meta["schema"]["features"]
1122
  num_cols = meta["schema"]["numeric"]
 
100
  f"Ensure your Excel header row contains a column named '{LABEL_COL}'.")
101
  return [c for c in df.columns if c != LABEL_COL]
102
 
103
+ from datetime import date
104
+
105
+ def to_date_safe(x):
106
+ """Accepts date/datetime/str; returns python date or None."""
107
+ if x is None or (isinstance(x, float) and np.isnan(x)):
108
+ return None
109
+ if isinstance(x, datetime):
110
+ return x.date()
111
+ if isinstance(x, date):
112
+ return x
113
+ # string
114
+ try:
115
+ return pd.to_datetime(x).date()
116
+ except Exception:
117
+ return None
118
+
119
+ def age_years_at(dob, ref_date):
120
+ """Age in years at ref_date. Returns float or np.nan."""
121
+ dob = to_date_safe(dob)
122
+ ref_date = to_date_safe(ref_date)
123
+ if dob is None or ref_date is None:
124
+ return np.nan
125
+ if ref_date < dob:
126
+ return np.nan
127
+ return (ref_date - dob).days / 365.25
128
 
129
  # ============================================================
130
  # Model pipeline
 
1135
 
1136
  pipe = st.session_state.pipe
1137
 
1138
+ # =========================
1139
+ # After model is loaded
1140
+ # =========================
1141
+ pipe = st.session_state.pipe
1142
+ meta = st.session_state.meta
1143
+ feature_cols = meta["schema"]["features"]
1144
+ num_cols = meta["schema"]["numeric"]
1145
+ cat_cols = meta["schema"]["categorical"]
1146
+
1147
+ st.divider()
1148
+ st.subheader("Single patient entry (DOB → Age, Dx date → prediction)")
1149
+
1150
+ AGE_FEATURE = "Age (years)"
1151
+ DX_DATE_FEATURE = "Date of 1st Bone Marrow biopsy (Date of Diagnosis) " # note trailing space
1152
+
1153
+ # safer date inputs (Streamlit versions vary with None defaults)
1154
+ c1, c2, c3 = st.columns(3)
1155
+ with c1:
1156
+ dob_tmp = st.date_input("Date of birth (DOB)", value=date.today(), key="sp_dob")
1157
+ use_dob = st.checkbox("Use DOB", value=False, key="sp_use_dob")
1158
+ dob = dob_tmp if use_dob else None
1159
+
1160
+ with c2:
1161
+ dx_tmp = st.date_input("Date of Diagnosis / 1st Bone Marrow biopsy", value=date.today(), key="sp_dx_date")
1162
+ use_dx = st.checkbox("Use Dx date", value=False, key="sp_use_dx")
1163
+ dx_date = dx_tmp if use_dx else None
1164
+
1165
+ with c3:
1166
+ ecog = st.selectbox("ECOG", options=[0, 1, 2, 3, 4], index=0, key="sp_ecog")
1167
+
1168
+ from typing import Optional
1169
+ def age_years_at(dob: Optional[date], ref_date: Optional[date]) -> float:
1170
+ if dob is None or ref_date is None:
1171
+ return np.nan
1172
+ if ref_date < dob:
1173
+ return np.nan
1174
+ return (ref_date - dob).days / 365.25
1175
+
1176
+ derived_age = age_years_at(dob, dx_date)
1177
+
1178
+ with st.expander("Enter remaining model features", expanded=False):
1179
+ with st.form("single_patient_form"):
1180
+
1181
+ values_by_index = [np.nan] * len(feature_cols)
1182
+
1183
+ for i, f in enumerate(feature_cols):
1184
+
1185
+ if f == AGE_FEATURE:
1186
+ st.number_input(
1187
+ f"{f} (auto)",
1188
+ value=None if np.isnan(derived_age) else float(derived_age),
1189
+ format="%.2f",
1190
+ key=f"sp_{i}_age_display",
1191
+ )
1192
+ values_by_index[i] = derived_age
1193
+ continue
1194
+
1195
+ if f.strip() == DX_DATE_FEATURE.strip():
1196
+ # if model expects numeric, store numeric timestamp
1197
+ if f in num_cols and dx_date is not None:
1198
+ values_by_index[i] = pd.Timestamp(dx_date).toordinal()
1199
+ else:
1200
+ values_by_index[i] = np.nan if dx_date is None else str(dx_date)
1201
+ continue
1202
+
1203
+
1204
+ if f.strip() == "ECOG":
1205
+ values_by_index[i] = ecog
1206
+ continue
1207
+
1208
+
1209
+ if f in num_cols:
1210
+ v = st.number_input(f, value=None, format="%.6f", key=f"sp_{i}_num")
1211
+ values_by_index[i] = v
1212
+ elif f in cat_cols:
1213
+ v = st.text_input(f, value="", key=f"sp_{i}_cat")
1214
+ values_by_index[i] = np.nan if v.strip() == "" else v
1215
+ else:
1216
+ v = st.text_input(f, value="", key=f"sp_{i}_other")
1217
+ values_by_index[i] = np.nan if v.strip() == "" else v
1218
+
1219
+ submitted = st.form_submit_button("Predict single patient")
1220
+
1221
+ # IMPORTANT: everything below must be inside "if submitted:"
1222
+ if submitted:
1223
+ X_one = pd.DataFrame([values_by_index], columns=feature_cols).replace({pd.NA: np.nan})
1224
+
1225
+ for c in num_cols:
1226
+ if c in X_one.columns:
1227
+ X_one[c] = pd.to_numeric(X_one[c], errors="coerce")
1228
+
1229
+ for c in cat_cols:
1230
+ if c in X_one.columns:
1231
+ X_one[c] = X_one[c].astype("object")
1232
+ X_one.loc[X_one[c].isna(), c] = np.nan
1233
+ X_one[c] = X_one[c].map(lambda v: v if pd.isna(v) else str(v))
1234
+
1235
+ proba_one = float(pipe.predict_proba(X_one)[:, 1][0])
1236
+ st.success("Prediction generated.")
1237
+ st.metric("Predicted probability", f"{proba_one:.4f}")
1238
+
1239
+ thr_single = st.slider("Classification threshold", 0.0, 1.0, 0.5, 0.01, key="sp_thr")
1240
+ pred_class = int(proba_one >= thr_single)
1241
+
1242
+ low_cut_s, high_cut_s = st.slider(
1243
+ "Risk band cutoffs (low, high)", 0.0, 1.0, (0.2, 0.8), 0.01, key="sp_risk_cuts"
1244
+ )
1245
+
1246
+ def band_one(p):
1247
+ if p < low_cut_s:
1248
+ return "Low"
1249
+ if p >= high_cut_s:
1250
+ return "High"
1251
+ return "Intermediate"
1252
+
1253
+ out = X_one.copy()
1254
+ out["predicted_probability"] = proba_one
1255
+ out["predicted_class"] = pred_class
1256
+ out["risk_band"] = band_one(proba_one)
1257
+
1258
+ st.dataframe(out, use_container_width=True)
1259
+
1260
+ st.download_button(
1261
+ "Download single patient result (CSV)",
1262
+ out.to_csv(index=False).encode("utf-8"),
1263
+ file_name="single_patient_prediction.csv",
1264
+ mime="text/csv",
1265
+ key="dl_sp_csv",
1266
+ )
1267
+
1268
  infer_file = st.file_uploader("Upload inference Excel (.xlsx)", type=["xlsx"])
1269
  if infer_file:
1270
  df_inf = pd.read_excel(infer_file, engine="openpyxl")
 
1272
  if not meta:
1273
  st.error("Model metadata not loaded. Please load a model version.")
1274
  st.stop()
1275
+
1276
 
1277
  feature_cols = meta["schema"]["features"]
1278
  num_cols = meta["schema"]["numeric"]