Update app.py
Browse files
app.py
CHANGED
|
@@ -100,6 +100,31 @@ def get_feature_cols_from_df(df: pd.DataFrame):
|
|
| 100 |
f"Ensure your Excel header row contains a column named '{LABEL_COL}'.")
|
| 101 |
return [c for c in df.columns if c != LABEL_COL]
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
# ============================================================
|
| 105 |
# Model pipeline
|
|
@@ -1110,6 +1135,136 @@ with tab_predict:
|
|
| 1110 |
|
| 1111 |
pipe = st.session_state.pipe
|
| 1112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1113 |
infer_file = st.file_uploader("Upload inference Excel (.xlsx)", type=["xlsx"])
|
| 1114 |
if infer_file:
|
| 1115 |
df_inf = pd.read_excel(infer_file, engine="openpyxl")
|
|
@@ -1117,6 +1272,7 @@ with tab_predict:
|
|
| 1117 |
if not meta:
|
| 1118 |
st.error("Model metadata not loaded. Please load a model version.")
|
| 1119 |
st.stop()
|
|
|
|
| 1120 |
|
| 1121 |
feature_cols = meta["schema"]["features"]
|
| 1122 |
num_cols = meta["schema"]["numeric"]
|
|
|
|
| 100 |
f"Ensure your Excel header row contains a column named '{LABEL_COL}'.")
|
| 101 |
return [c for c in df.columns if c != LABEL_COL]
|
| 102 |
|
| 103 |
+
from datetime import date
|
| 104 |
+
|
| 105 |
+
def to_date_safe(x):
    """Coerce a date-like value to a plain ``datetime.date``.

    Accepts ``date``, ``datetime`` (including ``pandas.Timestamp``) and
    anything ``pandas.to_datetime`` can parse as a scalar, typically a
    string.

    Returns:
        A ``datetime.date``, or ``None`` when the value is missing
        (``None``, float NaN, ``pd.NaT``, ``pd.NA``, empty string) or
        cannot be parsed as a single date.
    """
    if x is None or (isinstance(x, float) and np.isnan(x)):
        return None
    if isinstance(x, datetime):
        # pd.NaT is a datetime subclass whose .date() returns NaT again,
        # so it has to be filtered out before calling .date().
        return None if pd.isna(x) else x.date()
    if isinstance(x, date):
        return x
    # Anything else (usually a string): best-effort parse.
    try:
        parsed = pd.to_datetime(x)
    except Exception:
        # Deliberate best-effort: any parse failure means "no usable date".
        return None
    # Non-scalar results (e.g. a DatetimeIndex) and NaT are not usable dates.
    if not isinstance(parsed, datetime) or pd.isna(parsed):
        return None
    return parsed.date()
|
| 118 |
+
|
| 119 |
+
def age_years_at(dob, ref_date):
    """Return the age in (fractional) years at ``ref_date``.

    Both arguments are passed through ``to_date_safe`` first. The result
    is the day difference divided by 365.25. ``np.nan`` is returned when
    either date is unavailable or when ``ref_date`` precedes ``dob``.
    """
    birth = to_date_safe(dob)
    reference = to_date_safe(ref_date)

    # Single guard: both dates present and the reference not before birth.
    usable = (
        birth is not None
        and reference is not None
        and not (reference < birth)
    )
    if not usable:
        return np.nan

    elapsed = reference - birth
    return elapsed.days / 365.25
|
| 128 |
|
| 129 |
# ============================================================
|
| 130 |
# Model pipeline
|
|
|
|
| 1135 |
|
| 1136 |
pipe = st.session_state.pipe
|
| 1137 |
|
| 1138 |
+
# =========================
# After model is loaded: single patient entry
# =========================
# NOTE(review): in app.py this section lives inside `with tab_predict:`
# (see the diff hunk header); re-indent to match the enclosing block.
pipe = st.session_state.pipe
meta = st.session_state.meta
if not meta:
    # Guard before indexing meta["schema"]; same message as the batch path.
    st.error("Model metadata not loaded. Please load a model version.")
    st.stop()

feature_cols = meta["schema"]["features"]
num_cols = meta["schema"]["numeric"]
cat_cols = meta["schema"]["categorical"]

st.divider()
st.subheader("Single patient entry (DOB → Age, Dx date → prediction)")

AGE_FEATURE = "Age (years)"
DX_DATE_FEATURE = "Date of 1st Bone Marrow biopsy (Date of Diagnosis) "  # note trailing space

# Date inputs default to today and are paired with a checkbox, because
# Streamlit versions vary in how they treat a None default for date_input.
c1, c2, c3 = st.columns(3)
with c1:
    dob_tmp = st.date_input("Date of birth (DOB)", value=date.today(), key="sp_dob")
    use_dob = st.checkbox("Use DOB", value=False, key="sp_use_dob")
    dob = dob_tmp if use_dob else None

with c2:
    dx_tmp = st.date_input(
        "Date of Diagnosis / 1st Bone Marrow biopsy",
        value=date.today(),
        key="sp_dx_date",
    )
    use_dx = st.checkbox("Use Dx date", value=False, key="sp_use_dx")
    dx_date = dx_tmp if use_dx else None

with c3:
    ecog = st.selectbox("ECOG", options=[0, 1, 2, 3, 4], index=0, key="sp_ecog")

# Uses the module-level age_years_at helper. It is intentionally NOT
# redefined here: a def inside a `with` block rebinds the module-level name.
derived_age = age_years_at(dob, dx_date)

with st.expander("Enter remaining model features", expanded=False):
    with st.form("single_patient_form"):
        # One slot per model feature, in schema order; NaN means "missing".
        values_by_index = [np.nan] * len(feature_cols)

        for i, f in enumerate(feature_cols):
            if f == AGE_FEATURE:
                # Display only: the model always receives derived_age, so the
                # field is disabled to avoid edits that would be ignored.
                st.number_input(
                    f"{f} (auto)",
                    value=None if np.isnan(derived_age) else float(derived_age),
                    format="%.2f",
                    key=f"sp_{i}_age_display",
                    disabled=True,
                )
                values_by_index[i] = derived_age
                continue

            if f.strip() == DX_DATE_FEATURE.strip():
                if f in num_cols and dx_date is not None:
                    # Model expects a number: store the proleptic ordinal.
                    values_by_index[i] = pd.Timestamp(dx_date).toordinal()
                else:
                    values_by_index[i] = np.nan if dx_date is None else str(dx_date)
                continue

            if f.strip() == "ECOG":
                values_by_index[i] = ecog
                continue

            if f in num_cols:
                # value=None leaves the field empty; coerced to NaN later.
                values_by_index[i] = st.number_input(
                    f, value=None, format="%.6f", key=f"sp_{i}_num"
                )
            else:
                # Categorical and unclassified features share a free-text
                # input; only the widget key suffix differs.
                suffix = "cat" if f in cat_cols else "other"
                v = st.text_input(f, value="", key=f"sp_{i}_{suffix}")
                values_by_index[i] = np.nan if v.strip() == "" else v

        submitted = st.form_submit_button("Predict single patient")

# NOTE(review): the result section is placed outside the expander/form
# (st.download_button cannot be used inside st.form) -- confirm this matches
# the intended layout.
if submitted:
    X_one = pd.DataFrame([values_by_index], columns=feature_cols).replace({pd.NA: np.nan})

    for c in num_cols:
        if c in X_one.columns:
            X_one[c] = pd.to_numeric(X_one[c], errors="coerce")

    for c in cat_cols:
        if c in X_one.columns:
            X_one[c] = X_one[c].astype("object")
            X_one.loc[X_one[c].isna(), c] = np.nan
            X_one[c] = X_one[c].map(lambda v: v if pd.isna(v) else str(v))

    proba_one = float(pipe.predict_proba(X_one)[:, 1][0])
    # form_submit_button is True only on the run right after the click, so
    # the result is kept in session state; otherwise moving a slider below
    # (which triggers a rerun) would make the whole result disappear.
    st.session_state["sp_result"] = {"pipe": pipe, "X": X_one, "proba": proba_one}
    st.success("Prediction generated.")

sp_result = st.session_state.get("sp_result")
# Only show a stored result that was produced by the currently loaded model.
if sp_result is not None and sp_result["pipe"] is pipe:
    X_one = sp_result["X"]
    proba_one = sp_result["proba"]

    st.metric("Predicted probability", f"{proba_one:.4f}")

    thr_single = st.slider("Classification threshold", 0.0, 1.0, 0.5, 0.01, key="sp_thr")
    pred_class = int(proba_one >= thr_single)

    low_cut_s, high_cut_s = st.slider(
        "Risk band cutoffs (low, high)", 0.0, 1.0, (0.2, 0.8), 0.01, key="sp_risk_cuts"
    )

    def band_one(p):
        """Map a probability to a risk band using the slider cutoffs."""
        if p < low_cut_s:
            return "Low"
        if p >= high_cut_s:
            return "High"
        return "Intermediate"

    out = X_one.copy()
    out["predicted_probability"] = proba_one
    out["predicted_class"] = pred_class
    out["risk_band"] = band_one(proba_one)

    st.dataframe(out, use_container_width=True)

    st.download_button(
        "Download single patient result (CSV)",
        out.to_csv(index=False).encode("utf-8"),
        file_name="single_patient_prediction.csv",
        mime="text/csv",
        key="dl_sp_csv",
    )
|
| 1267 |
+
|
| 1268 |
infer_file = st.file_uploader("Upload inference Excel (.xlsx)", type=["xlsx"])
|
| 1269 |
if infer_file:
|
| 1270 |
df_inf = pd.read_excel(infer_file, engine="openpyxl")
|
|
|
|
| 1272 |
if not meta:
|
| 1273 |
st.error("Model metadata not loaded. Please load a model version.")
|
| 1274 |
st.stop()
|
| 1275 |
+
|
| 1276 |
|
| 1277 |
feature_cols = meta["schema"]["features"]
|
| 1278 |
num_cols = meta["schema"]["numeric"]
|