Update app.py
Browse files
app.py
CHANGED
|
@@ -145,7 +145,7 @@ def build_pipeline(
|
|
| 145 |
|
| 146 |
cat_pipe = Pipeline([
|
| 147 |
("imputer", SimpleImputer(strategy="most_frequent")),
|
| 148 |
-
("onehot", OneHotEncoder(handle_unknown="ignore",
|
| 149 |
])
|
| 150 |
|
| 151 |
preprocessor = ColumnTransformer(
|
|
@@ -813,7 +813,7 @@ st.warning(
|
|
| 813 |
with st.expander("Admin controls", expanded=False):
|
| 814 |
st.text_input("Admin key", type="password", key="admin_key")
|
| 815 |
st.caption("Training and publishing are enabled only after admin authentication.")
|
| 816 |
-
|
| 817 |
|
| 818 |
|
| 819 |
tab_train, tab_predict = st.tabs(["1️⃣ Train", "2️⃣ Predict + SHAP"])
|
|
@@ -823,34 +823,40 @@ if "pipe" not in st.session_state:
|
|
| 823 |
if "explainer" not in st.session_state:
|
| 824 |
st.session_state.explainer = None
|
| 825 |
|
|
|
|
|
|
|
| 826 |
|
|
|
|
|
|
|
|
|
|
| 827 |
|
|
|
|
| 828 |
|
| 829 |
-
st.
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
|
| 836 |
-
|
| 837 |
-
|
| 838 |
-
|
| 839 |
-
|
| 840 |
-
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
|
|
|
|
|
|
|
|
|
| 845 |
|
| 846 |
-
|
| 847 |
-
"SVD components (only used if enabled)",
|
| 848 |
-
5, 300, 50, 5
|
| 849 |
-
) if use_dimred else 50
|
| 850 |
|
| 851 |
# ---------------- TRAIN ----------------
|
| 852 |
|
| 853 |
-
|
| 854 |
|
| 855 |
st.subheader("Train model")
|
| 856 |
if not is_admin():
|
|
@@ -1133,7 +1139,7 @@ with tab_predict:
|
|
| 1133 |
st.divider()
|
| 1134 |
if st.session_state.pipe is None:
|
| 1135 |
st.warning("Load a model version above, then upload an inference Excel.")
|
| 1136 |
-
|
| 1137 |
|
| 1138 |
pipe = st.session_state.pipe
|
| 1139 |
|
|
@@ -1258,23 +1264,23 @@ with tab_predict:
|
|
| 1258 |
# --- header dates ---
|
| 1259 |
c1, c2 = st.columns(2)
|
| 1260 |
with c1:
|
| 1261 |
-
|
| 1262 |
-
|
| 1263 |
-
|
| 1264 |
-
min_value=MIN_DOB,
|
| 1265 |
-
|
| 1266 |
-
key="sp_dob",
|
| 1267 |
-
)
|
| 1268 |
with c2:
|
| 1269 |
-
|
| 1270 |
-
|
| 1271 |
-
|
| 1272 |
-
|
| 1273 |
-
|
| 1274 |
-
|
| 1275 |
-
|
|
|
|
| 1276 |
|
| 1277 |
-
derived_age = age_years_at(dob, dx_date)
|
|
|
|
| 1278 |
|
| 1279 |
def yesno_to_01(v: str):
|
| 1280 |
if v == "Yes":
|
|
@@ -1393,14 +1399,7 @@ with tab_predict:
|
|
| 1393 |
v = st.text_input(f, value="", key=f"sp_{i}_other")
|
| 1394 |
values_by_index[i] = np.nan if v.strip() == "" else v
|
| 1395 |
|
| 1396 |
-
|
| 1397 |
-
st.caption("Clinical flags: Yes=1, No=0")
|
| 1398 |
-
for i, f in enumerate(feature_cols):
|
| 1399 |
-
if f in YESNO_FIELDS:
|
| 1400 |
-
v = st.selectbox(f, options=["", "No", "Yes"], index=0, key=f"sp_{i}_yn")
|
| 1401 |
-
values_by_index[i] = yesno_to_01(v)
|
| 1402 |
-
|
| 1403 |
-
|
| 1404 |
# Apply FISH/NGS selections to row
|
| 1405 |
fish_set = set(fish_selected)
|
| 1406 |
ngs_set = set(ngs_selected)
|
|
@@ -1529,11 +1528,8 @@ with tab_predict:
|
|
| 1529 |
cls_ext = compute_classification_metrics(y_ext01, proba, threshold=float(thr_ext))
|
| 1530 |
|
| 1531 |
pr_ext = compute_pr_curve(y_ext01, proba)
|
| 1532 |
-
cal_ext = compute_calibration(
|
| 1533 |
-
|
| 1534 |
-
n_bins=int(n_bins) if "n_bins" in locals() else 10,
|
| 1535 |
-
strategy=str(cal_strategy) if "cal_strategy" in locals() else "uniform"
|
| 1536 |
-
)
|
| 1537 |
dca_ext = decision_curve_analysis(y_ext01, proba)
|
| 1538 |
|
| 1539 |
# Display headline metrics
|
|
@@ -1819,16 +1815,17 @@ with tab_predict:
|
|
| 1819 |
|
| 1820 |
|
| 1821 |
# BEESWARM SUMMARY (optional)
|
| 1822 |
-
|
| 1823 |
-
|
| 1824 |
-
|
| 1825 |
-
|
| 1826 |
-
|
| 1827 |
-
|
| 1828 |
-
|
| 1829 |
-
|
| 1830 |
-
|
| 1831 |
-
|
|
|
|
| 1832 |
|
| 1833 |
|
| 1834 |
|
|
|
|
| 145 |
|
| 146 |
cat_pipe = Pipeline([
|
| 147 |
("imputer", SimpleImputer(strategy="most_frequent")),
|
| 148 |
+
("onehot", OneHotEncoder(handle_unknown="ignore", sparse=True, drop="first"))
|
| 149 |
])
|
| 150 |
|
| 151 |
preprocessor = ColumnTransformer(
|
|
|
|
| 813 |
with st.expander("Admin controls", expanded=False):
|
| 814 |
st.text_input("Admin key", type="password", key="admin_key")
|
| 815 |
st.caption("Training and publishing are enabled only after admin authentication.")
|
| 816 |
+
|
| 817 |
|
| 818 |
|
| 819 |
tab_train, tab_predict = st.tabs(["1️⃣ Train", "2️⃣ Predict + SHAP"])
|
|
|
|
| 823 |
if "explainer" not in st.session_state:
|
| 824 |
st.session_state.explainer = None
|
| 825 |
|
| 826 |
+
with tab_train:
|
| 827 |
+
st.subheader("Train model")
|
| 828 |
|
| 829 |
+
if not is_admin():
|
| 830 |
+
st.info("Training and publishing are restricted. Use Predict + SHAP for inference.")
|
| 831 |
+
st.stop()
|
| 832 |
|
| 833 |
+
st.markdown("### Feature reduction options")
|
| 834 |
|
| 835 |
+
use_feature_selection = st.checkbox(
|
| 836 |
+
"Drop columns that do not affect prediction (L1 feature selection)",
|
| 837 |
+
value=True,
|
| 838 |
+
key="train_use_feature_selection"
|
| 839 |
+
)
|
| 840 |
+
l1_C = st.slider(
|
| 841 |
+
"L1 selection strength (lower = fewer features)",
|
| 842 |
+
0.01, 10.0, 1.0, 0.01
|
| 843 |
+
) if use_feature_selection else 1.0
|
| 844 |
+
|
| 845 |
+
use_dimred = st.checkbox(
|
| 846 |
+
"Dimensionality reduction (TruncatedSVD) — reduces interpretability",
|
| 847 |
+
value=False
|
| 848 |
+
)
|
| 849 |
+
|
| 850 |
+
svd_components = st.slider(
|
| 851 |
+
"SVD components (only used if enabled)",
|
| 852 |
+
5, 300, 50, 5
|
| 853 |
+
) if use_dimred else 50
|
| 854 |
|
| 855 |
+
st.divider()
|
|
|
|
|
|
|
|
|
|
| 856 |
|
| 857 |
# ---------------- TRAIN ----------------
|
| 858 |
|
| 859 |
+
|
| 860 |
|
| 861 |
st.subheader("Train model")
|
| 862 |
if not is_admin():
|
|
|
|
| 1139 |
st.divider()
|
| 1140 |
if st.session_state.pipe is None:
|
| 1141 |
st.warning("Load a model version above, then upload an inference Excel.")
|
| 1142 |
+
|
| 1143 |
|
| 1144 |
pipe = st.session_state.pipe
|
| 1145 |
|
|
|
|
| 1264 |
# --- header dates ---
|
| 1265 |
c1, c2 = st.columns(2)
|
| 1266 |
with c1:
|
| 1267 |
+
dob_unknown = st.checkbox("DOB unknown", value=False, key="dob_unknown")
|
| 1268 |
+
dob = None
|
| 1269 |
+
if not dob_unknown:
|
| 1270 |
+
dob = st.date_input("Date of birth (DOB)", min_value=MIN_DOB, max_value=date.today(), key="dob")
|
| 1271 |
+
|
|
|
|
|
|
|
| 1272 |
with c2:
|
| 1273 |
+
dx_unknown = st.checkbox("Diagnosis date unknown", value=False, key="dx_unknown")
|
| 1274 |
+
dx_date = None
|
| 1275 |
+
if not dx_unknown:
|
| 1276 |
+
dx_date = st.date_input(
|
| 1277 |
+
"Date of Diagnosis / 1st Bone Marrow biopsy",
|
| 1278 |
+
min_value=MIN_DOB, max_value=date.today(),
|
| 1279 |
+
key="dx_date"
|
| 1280 |
+
)
|
| 1281 |
|
| 1282 |
+
derived_age = age_years_at(dob, dx_date)
|
| 1283 |
+
|
| 1284 |
|
| 1285 |
def yesno_to_01(v: str):
|
| 1286 |
if v == "Yes":
|
|
|
|
| 1399 |
v = st.text_input(f, value="", key=f"sp_{i}_other")
|
| 1400 |
values_by_index[i] = np.nan if v.strip() == "" else v
|
| 1401 |
|
| 1402 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1403 |
# Apply FISH/NGS selections to row
|
| 1404 |
fish_set = set(fish_selected)
|
| 1405 |
ngs_set = set(ngs_selected)
|
|
|
|
| 1528 |
cls_ext = compute_classification_metrics(y_ext01, proba, threshold=float(thr_ext))
|
| 1529 |
|
| 1530 |
pr_ext = compute_pr_curve(y_ext01, proba)
|
| 1531 |
+
cal_ext = compute_calibration(y_ext01, proba, n_bins=PRED_N_BINS, strategy=PRED_CAL_STRATEGY)
|
| 1532 |
+
|
|
|
|
|
|
|
|
|
|
| 1533 |
dca_ext = decision_curve_analysis(y_ext01, proba)
|
| 1534 |
|
| 1535 |
# Display headline metrics
|
|
|
|
| 1815 |
|
| 1816 |
|
| 1817 |
# BEESWARM SUMMARY (optional)
|
| 1818 |
+
if show_beeswarm:
|
| 1819 |
+
plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1820 |
+
shap.summary_plot(
|
| 1821 |
+
shap_vals_batch,
|
| 1822 |
+
features=X_dense,
|
| 1823 |
+
feature_names=names,
|
| 1824 |
+
max_display=max_display,
|
| 1825 |
+
show=False
|
| 1826 |
+
)
|
| 1827 |
+
fig_swarm = plt.gcf()
|
| 1828 |
+
render_plot_with_download(fig_swarm, title="SHAP beeswarm", filename="shap_beeswarm.png", export_dpi=export_dpi)
|
| 1829 |
|
| 1830 |
|
| 1831 |
|