Update app.py
Browse files
app.py
CHANGED
|
@@ -2198,9 +2198,7 @@ with tab_predict:
|
|
| 2198 |
# =========================
|
| 2199 |
pipe = st.session_state.pipe
|
| 2200 |
meta = st.session_state.meta
|
| 2201 |
-
|
| 2202 |
-
num_cols = meta["schema"]["numeric"]
|
| 2203 |
-
cat_cols = meta["schema"]["categorical"]
|
| 2204 |
|
| 2205 |
|
| 2206 |
|
|
@@ -2426,300 +2424,282 @@ with tab_predict:
|
|
| 2426 |
values_by_index[i] = np.nan if v.strip() == "" else v
|
| 2427 |
|
| 2428 |
|
| 2429 |
-
|
| 2430 |
-
|
| 2431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2432 |
|
| 2433 |
-
|
| 2434 |
-
|
| 2435 |
-
|
| 2436 |
-
|
| 2437 |
-
|
| 2438 |
-
|
| 2439 |
-
|
| 2440 |
-
if FISH_COUNT_COL in feature_cols:
|
| 2441 |
-
values_by_index[feature_cols.index(FISH_COUNT_COL)] = int(len(fish_selected))
|
| 2442 |
-
if NGS_COUNT_COL in feature_cols:
|
| 2443 |
-
values_by_index[feature_cols.index(NGS_COUNT_COL)] = int(len(ngs_selected))
|
| 2444 |
-
st.divider()
|
| 2445 |
-
st.subheader("Predict single patient")
|
| 2446 |
|
| 2447 |
-
|
| 2448 |
-
|
| 2449 |
-
|
| 2450 |
-
|
| 2451 |
-
|
| 2452 |
-
|
| 2453 |
-
|
| 2454 |
-
|
| 2455 |
-
|
| 2456 |
-
|
| 2457 |
-
|
| 2458 |
-
|
| 2459 |
-
|
| 2460 |
-
|
| 2461 |
-
|
| 2462 |
|
| 2463 |
-
|
| 2464 |
-
|
| 2465 |
-
|
| 2466 |
-
|
| 2467 |
-
|
| 2468 |
-
|
| 2469 |
-
|
| 2470 |
-
|
| 2471 |
-
|
| 2472 |
|
| 2473 |
|
| 2474 |
-
|
| 2475 |
-
|
| 2476 |
-
|
| 2477 |
-
|
| 2478 |
-
|
| 2479 |
-
|
| 2480 |
-
|
| 2481 |
-
|
| 2482 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2483 |
|
| 2484 |
-
|
| 2485 |
-
|
| 2486 |
-
|
|
|
|
| 2487 |
|
| 2488 |
-
|
| 2489 |
-
|
| 2490 |
-
X_one[c] = X_one[c].astype("object")
|
| 2491 |
-
X_one.loc[X_one[c].isna(), c] = np.nan
|
| 2492 |
-
X_one[c] = X_one[c].map(lambda v: v if pd.isna(v) else str(v))
|
| 2493 |
-
|
| 2494 |
|
| 2495 |
-
|
| 2496 |
-
|
| 2497 |
-
st.metric("Predicted probability", f"{proba_one:.4f}")
|
| 2498 |
-
|
| 2499 |
-
|
| 2500 |
-
# ---- Survival prediction for this patient (if survival model loaded) ----
|
| 2501 |
-
cph = st.session_state.get("surv_model", None)
|
| 2502 |
-
|
| 2503 |
-
if cph is not None:
|
| 2504 |
-
try:
|
| 2505 |
-
# Build Cox input row with same preprocessing used in training Cox:
|
| 2506 |
-
# predictors + one-hot categoricals
|
| 2507 |
-
df_one_surv = X_one[feature_cols].copy()
|
| 2508 |
-
for c in num_cols:
|
| 2509 |
df_one_surv[c] = pd.to_numeric(df_one_surv[c], errors="coerce")
|
| 2510 |
-
|
| 2511 |
-
|
| 2512 |
-
|
| 2513 |
-
|
| 2514 |
-
|
| 2515 |
-
|
| 2516 |
-
|
| 2517 |
-
|
| 2518 |
-
|
| 2519 |
-
|
| 2520 |
-
|
| 2521 |
-
|
| 2522 |
-
|
| 2523 |
-
|
| 2524 |
-
|
| 2525 |
-
|
| 2526 |
-
|
| 2527 |
-
|
| 2528 |
-
|
| 2529 |
-
|
| 2530 |
-
|
| 2531 |
-
|
| 2532 |
-
|
| 2533 |
-
|
| 2534 |
-
|
| 2535 |
-
|
| 2536 |
-
|
| 2537 |
-
|
| 2538 |
-
|
| 2539 |
-
|
| 2540 |
-
|
| 2541 |
-
|
| 2542 |
-
|
| 2543 |
-
|
| 2544 |
-
|
| 2545 |
-
|
| 2546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2547 |
|
| 2548 |
-
|
| 2549 |
-
surv_fn = cph.predict_survival_function(df_one_surv_oh)
|
| 2550 |
|
| 2551 |
-
|
| 2552 |
-
|
| 2553 |
-
|
| 2554 |
-
|
| 2555 |
-
|
| 2556 |
-
|
| 2557 |
-
|
| 2558 |
-
s1y = surv_at(365)
|
| 2559 |
-
s2y = surv_at(730)
|
| 2560 |
-
s3y = surv_at(1095)
|
| 2561 |
-
|
| 2562 |
-
st.subheader("Predicted survival probability")
|
| 2563 |
-
a,b,c,d = st.columns(4)
|
| 2564 |
-
a.metric("6 months", f"{s6m*100:.1f}%")
|
| 2565 |
-
b.metric("1 year", f"{s1y*100:.1f}%")
|
| 2566 |
-
c.metric("2 years", f"{s2y*100:.1f}%")
|
| 2567 |
-
d.metric("3 years", f"{s3y*100:.1f}%")
|
| 2568 |
-
|
| 2569 |
-
# Optional: plot curve
|
| 2570 |
-
fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 2571 |
-
ax.plot(surv_fn.index, surv_fn.values)
|
| 2572 |
-
ax.set_xlabel("Days from diagnosis")
|
| 2573 |
-
ax.set_ylabel("Survival probability")
|
| 2574 |
-
ax.set_ylim(0, 1)
|
| 2575 |
-
ax.set_title("Predicted survival curve (patient)")
|
| 2576 |
-
|
| 2577 |
-
render_plot_with_download(
|
| 2578 |
-
fig,
|
| 2579 |
-
title="Patient survival curve",
|
| 2580 |
-
filename="patient_survival_curve.png",
|
| 2581 |
-
export_dpi=export_dpi,
|
| 2582 |
-
key="dl_patient_surv_curve"
|
| 2583 |
-
)
|
| 2584 |
|
| 2585 |
-
|
| 2586 |
-
st.warning(f"Survival prediction could not be computed: {e}")
|
| 2587 |
-
else:
|
| 2588 |
-
st.info("Survival model not loaded/published yet (survival_model.joblib missing).")
|
| 2589 |
|
| 2590 |
-
|
| 2591 |
-
|
| 2592 |
-
|
| 2593 |
-
|
| 2594 |
-
|
| 2595 |
-
|
| 2596 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2597 |
|
| 2598 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2599 |
|
| 2600 |
-
|
| 2601 |
-
|
| 2602 |
-
|
| 2603 |
-
|
| 2604 |
-
|
| 2605 |
-
|
| 2606 |
-
|
| 2607 |
-
|
| 2608 |
-
None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"]))
|
| 2609 |
-
)
|
| 2610 |
-
|
| 2611 |
-
if explainer is None or explainer_sig != current_sig:
|
| 2612 |
-
X_bg = st.session_state.get("X_bg_for_shap")
|
| 2613 |
-
if X_bg is None:
|
| 2614 |
-
st.error("SHAP background not available. Admin must publish latest/background.csv.")
|
| 2615 |
-
st.stop()
|
| 2616 |
-
|
| 2617 |
-
st.session_state.explainer = build_shap_explainer(pipe, X_bg)
|
| 2618 |
-
st.session_state.explainer_sig = current_sig
|
| 2619 |
-
explainer = st.session_state.explainer
|
| 2620 |
-
|
| 2621 |
-
shap_vals = explainer.shap_values(X_one_t)
|
| 2622 |
-
if isinstance(shap_vals, list):
|
| 2623 |
-
shap_vals = shap_vals[1]
|
| 2624 |
-
|
| 2625 |
-
names = get_final_feature_names(pipe)
|
| 2626 |
-
|
| 2627 |
-
try:
|
| 2628 |
-
x_dense = X_one_t.toarray()[0]
|
| 2629 |
-
except Exception:
|
| 2630 |
-
x_dense = np.array(X_one_t)[0]
|
| 2631 |
-
|
| 2632 |
-
base = explainer.expected_value
|
| 2633 |
-
if not np.isscalar(base):
|
| 2634 |
-
base = float(np.array(base).reshape(-1)[0])
|
| 2635 |
-
|
| 2636 |
-
exp = shap.Explanation(
|
| 2637 |
-
values=shap_vals[0],
|
| 2638 |
-
base_values=float(base),
|
| 2639 |
-
data=x_dense,
|
| 2640 |
-
feature_names=names,
|
| 2641 |
)
|
| 2642 |
-
|
| 2643 |
-
# CACHE ONLY
|
| 2644 |
-
st.session_state.shap_single_exp = exp
|
| 2645 |
-
|
| 2646 |
|
| 2647 |
-
|
| 2648 |
|
| 2649 |
-
|
| 2650 |
-
|
| 2651 |
-
|
| 2652 |
-
|
| 2653 |
-
|
| 2654 |
-
|
| 2655 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2656 |
|
| 2657 |
-
# ---
|
| 2658 |
-
|
| 2659 |
-
|
|
|
|
|
|
|
|
|
|
| 2660 |
|
| 2661 |
-
|
| 2662 |
-
|
| 2663 |
-
|
| 2664 |
-
key="sp_single_max_display"
|
| 2665 |
-
)
|
| 2666 |
|
| 2667 |
-
|
|
|
|
|
|
|
| 2668 |
|
| 2669 |
-
|
| 2670 |
-
|
| 2671 |
-
|
| 2672 |
-
|
| 2673 |
-
|
| 2674 |
-
|
| 2675 |
-
title="Single-patient SHAP waterfall",
|
| 2676 |
-
filename="single_patient_shap_waterfall.png",
|
| 2677 |
-
export_dpi=export_dpi,
|
| 2678 |
-
key="dl_sp_wf"
|
| 2679 |
-
)
|
| 2680 |
|
| 2681 |
-
|
| 2682 |
-
|
| 2683 |
-
|
| 2684 |
-
|
| 2685 |
-
|
| 2686 |
-
|
| 2687 |
-
|
| 2688 |
-
|
| 2689 |
-
export_dpi=export_dpi,
|
| 2690 |
-
key="dl_sp_bar"
|
| 2691 |
-
)
|
| 2692 |
-
|
| 2693 |
-
# --- After SHAP plots are displayed ---
|
| 2694 |
-
with st.expander("How to interpret the SHAP explanation plots", expanded=True):
|
| 2695 |
-
st.markdown(r"""
|
| 2696 |
-
**What is SHAP?**
|
| 2697 |
-
SHAP (SHapley Additive exPlanations) decomposes a model prediction into **feature-wise contributions**.
|
| 2698 |
-
Each feature pushes the prediction **toward higher risk** or **toward lower risk**, relative to the model’s baseline.
|
| 2699 |
-
|
| 2700 |
-
**What is \(E[f(X)]\)? (baseline / expected model output)**
|
| 2701 |
-
- \(E[f(X)]\) is the model’s **average output** over the reference population used by the explainer (typically the training set).
|
| 2702 |
-
- It is the **starting point** of the explanation before any patient-specific information is applied.
|
| 2703 |
-
|
| 2704 |
-
**What is \(f(x)\)? (this patient’s model output)**
|
| 2705 |
-
- \(f(x)\) is the model’s **final output for this patient**, after adding all feature contributions to the baseline.
|
| 2706 |
-
- For many classifiers (including logistic regression), SHAP waterfall plots often use the **log-odds (logit) scale** rather than raw probability.
|
| 2707 |
-
|
| 2708 |
-
**How the waterfall plot should be read**
|
| 2709 |
-
- The plot starts at **\(E[f(X)]\)** and then adds/subtracts feature effects to arrive at **\(f(x)\)**.
|
| 2710 |
-
- **Red bars** push the prediction **toward higher risk** (increase the model output).
|
| 2711 |
-
- **Blue bars** push the prediction **toward lower risk / protective direction** (decrease the model output).
|
| 2712 |
-
- The **bar length** indicates the **strength** of that feature’s influence for this patient.
|
| 2713 |
-
- “**Other features**” is the combined net effect of features not shown individually.
|
| 2714 |
-
|
| 2715 |
-
**How the bar plot should be read (Top contributors)**
|
| 2716 |
-
- Features are ranked by **absolute SHAP value** (largest impact first).
|
| 2717 |
-
- **Positive SHAP** (right) increases predicted risk; **negative SHAP** (left) decreases predicted risk.
|
| 2718 |
-
|
| 2719 |
-
**Clinical cautions**
|
| 2720 |
-
- SHAP explains the model’s behavior; it does **not** prove causality.
|
| 2721 |
-
- Ensure variable definitions and patient population match the model’s intended use.
|
| 2722 |
-
""")
|
| 2723 |
|
| 2724 |
|
| 2725 |
# -----------------------------
|
|
@@ -2915,29 +2895,51 @@ with tab_predict:
|
|
| 2915 |
cph = None
|
| 2916 |
surv_cols = []
|
| 2917 |
|
| 2918 |
-
if
|
|
|
|
|
|
|
|
|
|
| 2919 |
try:
|
| 2920 |
-
|
|
|
|
|
|
|
|
|
|
| 2921 |
|
| 2922 |
for c in num_cols:
|
| 2923 |
-
if c in
|
| 2924 |
-
|
|
|
|
| 2925 |
for c in cat_cols:
|
| 2926 |
-
if c in
|
| 2927 |
-
|
|
|
|
|
|
|
| 2928 |
|
| 2929 |
-
|
| 2930 |
|
| 2931 |
# Align to training predictor columns
|
| 2932 |
for col in surv_cols:
|
| 2933 |
-
if col not in
|
| 2934 |
-
|
| 2935 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2936 |
|
| 2937 |
-
surv_fn = cph.predict_survival_function(df_one_surv_oh)
|
| 2938 |
-
...
|
| 2939 |
except Exception as e:
|
| 2940 |
-
st.warning(f"
|
|
|
|
|
|
|
|
|
|
| 2941 |
|
| 2942 |
|
| 2943 |
surv_fn_all = cph.predict_survival_function(df_surv_in_oh)
|
|
|
|
| 2198 |
# =========================
|
| 2199 |
pipe = st.session_state.pipe
|
| 2200 |
meta = st.session_state.meta
|
| 2201 |
+
|
|
|
|
|
|
|
| 2202 |
|
| 2203 |
|
| 2204 |
|
|
|
|
| 2424 |
values_by_index[i] = np.nan if v.strip() == "" else v
|
| 2425 |
|
| 2426 |
|
| 2427 |
+
# Write the FISH/NGS multiselect choices into the single-patient input row.
# Every marker column that is a model feature becomes 1 when selected, else 0.
fish_set, ngs_set = set(fish_selected), set(ngs_selected)

for i, f in enumerate(feature_cols):
    # NGS is evaluated after FISH, so it wins if a feature is listed in both.
    for marker_pool, chosen in ((FISH_MARKERS, fish_set), (NGS_MARKERS, ngs_set)):
        if f in marker_pool:
            values_by_index[i] = int(f in chosen)

# Derived "number of positive markers" features are filled automatically
# when the model was trained with them.
for count_col, picked in (
    (FISH_COUNT_COL, fish_selected),
    (NGS_COUNT_COL, ngs_selected),
):
    if count_col in feature_cols:
        values_by_index[feature_cols.index(count_col)] = len(picked)

st.divider()
st.subheader("Predict single patient")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2444 |
|
| 2445 |
+
# Thresholds default to the value stored with the published model's metrics,
# falling back to 0.5 when none was recorded.
m = meta.get("metrics", {})
default_thr = float(m.get("best_threshold", 0.5))

thr_single = st.slider(
    label="Classification threshold",
    min_value=0.0,
    max_value=1.0,
    value=default_thr,
    step=0.01,
    key="sp_thr",
)

# External validation threshold
thr_ext = st.slider(
    label="External validation threshold",
    min_value=0.0,
    max_value=1.0,
    value=default_thr,
    step=0.01,
    key="thr_ext",
)

low_cut_s, high_cut_s = st.slider(
    label="Risk band cutoffs (low, high)",
    min_value=0.0,
    max_value=1.0,
    value=(0.2, 0.8),
    step=0.01,
    key="sp_risk_cuts",
)

# Ensure low <= high
low_cut_s, high_cut_s = min(low_cut_s, high_cut_s), max(low_cut_s, high_cut_s)


def band_one(p: float) -> str:
    """Map a predicted probability to "Low", "Intermediate" or "High".

    Values below ``low_cut_s`` are "Low", values at or above ``high_cut_s``
    are "High", everything else is "Intermediate".
    """
    return "Low" if p < low_cut_s else ("High" if p >= high_cut_s else "Intermediate")
|
| 2478 |
+
# Submit button (no form needed; simpler + fewer state surprises)
#
# On click this handler:
#   1. builds a one-row frame from the widget values and coerces dtypes,
#   2. shows the classifier probability,
#   3. shows Cox survival probabilities when a survival bundle is loaded,
#   4. shows the result table and its CSV download,
#   5. computes the SHAP explanation and caches it in session state
#      (it is rendered elsewhere from the cache).
# NOTE(review): pipe, feature_cols, num_cols, cat_cols, selected, make_fig,
# render_plot_with_download, transform_before_clf, build_shap_explainer and
# get_final_feature_names are defined outside this section.
if st.button("Predict single patient", key="sp_predict_btn"):
    X_one = pd.DataFrame([values_by_index], columns=feature_cols).replace({pd.NA: np.nan})

    for c in num_cols:
        if c in X_one.columns:
            X_one[c] = pd.to_numeric(X_one[c], errors="coerce")

    for c in cat_cols:
        if c in X_one.columns:
            # Categoricals are passed as strings; missing values stay NaN.
            X_one[c] = X_one[c].astype("object")
            X_one.loc[X_one[c].isna(), c] = np.nan
            X_one[c] = X_one[c].map(lambda v: v if pd.isna(v) else str(v))

    proba_one = float(pipe.predict_proba(X_one)[:, 1][0])
    st.success("Prediction generated.")
    st.metric("Predicted probability", f"{proba_one:.4f}")

    # ---- Survival prediction for this patient (if survival model loaded) ----
    bundle = st.session_state.get("surv_model", None)

    if isinstance(bundle, dict) and bundle.get("model") is not None:
        try:
            cph = bundle["model"]
            surv_cols = list(bundle.get("columns", []))
            if not surv_cols:
                raise ValueError("survival bundle has no training column list")

            # Build Cox input row (same preprocessing as Cox training)
            df_one_surv = X_one[feature_cols].copy()

            for c in num_cols:
                if c in df_one_surv.columns:
                    df_one_surv[c] = pd.to_numeric(df_one_surv[c], errors="coerce")

            # Only encode categoricals that exist; get_dummies raises
            # KeyError for a name in `columns=` that is not in the frame.
            present_cat_cols = [name for name in cat_cols if name in df_one_surv.columns]
            for c in present_cat_cols:
                df_one_surv[c] = df_one_surv[c].astype("object").map(
                    lambda v: v if pd.isna(v) else str(v)
                )

            # drop_first must be False here: with a single row every
            # categorical has exactly one level, and drop_first=True would
            # drop it, silently turning every patient into the reference
            # category. Encoding all levels and aligning to the training
            # columns discards reference-level dummies and zero-fills levels
            # this patient does not have.
            df_one_surv_oh = pd.get_dummies(
                df_one_surv, columns=present_cat_cols, drop_first=False
            )
            df_one_surv_oh = df_one_surv_oh.reindex(columns=surv_cols, fill_value=0)

            # Predict survival function
            surv_fn = cph.predict_survival_function(df_one_surv_oh)
            surv_times = surv_fn.index.values

            def surv_at(days: int) -> float:
                """Return S(days), reading the curve as a right-continuous step function.

                Uses the last time point <= ``days`` rather than the nearest
                one, so the value never comes from a later time than asked.
                Assumes the time index is sorted ascending - TODO confirm.
                """
                n_reached = int((surv_times <= days).sum())
                if n_reached == 0:
                    # Horizon precedes the first time point of the curve.
                    return 1.0
                return float(surv_fn.iloc[n_reached - 1, 0])

            horizons = (("6 months", 180), ("1 year", 365), ("2 years", 730), ("3 years", 1095))
            horizon_probs = [surv_at(days) for _, days in horizons]

            st.subheader("Predicted survival probability")
            for slot, (label, _), prob in zip(st.columns(4), horizons, horizon_probs):
                slot.metric(label, f"{prob*100:.1f}%")

            fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
            ax.plot(surv_fn.index, surv_fn.values)
            ax.set_xlabel("Days from diagnosis")
            ax.set_ylabel("Survival probability")
            ax.set_ylim(0, 1)
            ax.set_title("Predicted survival curve (patient)")

            render_plot_with_download(
                fig,
                title="Patient survival curve",
                filename="patient_survival_curve.png",
                export_dpi=export_dpi,
                key="dl_patient_surv_curve"
            )

        except Exception as e:
            # Best effort: a survival failure must not block the classifier result.
            st.warning(f"Survival prediction could not be computed: {e}")

    else:
        st.info("Survival model not loaded/published yet (survival bundle missing).")

    out = X_one.copy()
    out["predicted_probability"] = proba_one
    pred_class = int(proba_one >= thr_single)
    out["predicted_class"] = pred_class
    out["risk_band"] = band_one(proba_one)

    # Render the result table and download BEFORE the SHAP step: that step
    # may call st.stop(), which previously hid the prediction output whenever
    # the SHAP background was unavailable. The SHAP step draws nothing itself,
    # so the on-screen order is unchanged when it succeeds.
    st.dataframe(out, use_container_width=True)

    st.download_button(
        "Download single patient result (CSV)",
        out.to_csv(index=False).encode("utf-8"),
        file_name="single_patient_prediction.csv",
        mime="text/csv",
        key="dl_sp_csv",
    )

    # ---- SHAP compute only (cache) ----
    X_one_t = transform_before_clf(pipe, X_one)

    explainer = st.session_state.get("explainer")
    explainer_sig = st.session_state.get("explainer_sig")

    # The explainer is rebuilt when the selection or background size changes.
    X_bg = st.session_state.get("X_bg_for_shap")
    current_sig = (
        selected,
        None if X_bg is None else int(len(X_bg)),
    )

    if explainer is None or explainer_sig != current_sig:
        if X_bg is None:
            st.error("SHAP background not available. Admin must publish latest/background.csv.")
            st.stop()

        st.session_state.explainer = build_shap_explainer(pipe, X_bg)
        st.session_state.explainer_sig = current_sig
        explainer = st.session_state.explainer

    shap_vals = explainer.shap_values(X_one_t)

    # Track which model output the SHAP values belong to, so the base value
    # is read from the same output. Previously class-1 values were paired
    # with the first (class-0) expected value.
    output_idx = 0
    if isinstance(shap_vals, list):
        # Per-class list: keep the positive class when there is one.
        output_idx = 1 if len(shap_vals) > 1 else 0
        shap_vals = shap_vals[output_idx]
    else:
        shap_vals = np.asarray(shap_vals)
        if shap_vals.ndim == 3:
            # (samples, features, outputs) layout: slice the positive class.
            output_idx = 1 if shap_vals.shape[2] > 1 else 0
            shap_vals = shap_vals[:, :, output_idx]

    names = get_final_feature_names(pipe)

    try:
        x_dense = X_one_t.toarray()[0]
    except Exception:
        # Not a sparse matrix: fall back to a plain array conversion.
        x_dense = np.array(X_one_t)[0]

    base_candidates = np.asarray(explainer.expected_value, dtype=float).reshape(-1)
    if base_candidates.size > output_idx:
        base = float(base_candidates[output_idx])
    else:
        base = float(base_candidates[0])

    exp = shap.Explanation(
        values=shap_vals[0],
        base_values=base,
        data=x_dense,
        feature_names=names,
    )

    # CACHE ONLY
    st.session_state.shap_single_exp = exp
|
| 2636 |
|
| 2637 |
+
# ---- Always render cached SHAP ----
# The explanation is drawn from session state on every rerun, so moving the
# slider below does not require pressing the predict button again.
if "shap_single_exp" in st.session_state:
    exp = st.session_state.shap_single_exp

    max_display_single = st.slider(
        label="Top features to display (single patient)",
        min_value=5,
        max_value=40,
        value=20,
        step=1,
        key="sp_single_max_display",
    )

    c1, c2 = st.columns(2)

    # One entry per panel: (column, SHAP plot function, title, file name, widget key).
    shap_panels = (
        (c1, shap.plots.waterfall, "Single-patient SHAP waterfall",
         "single_patient_shap_waterfall.png", "dl_sp_wf"),
        (c2, shap.plots.bar, "Single-patient SHAP bar",
         "single_patient_shap_bar.png", "dl_sp_bar"),
    )

    for panel_col, draw_plot, panel_title, panel_file, panel_key in shap_panels:
        with panel_col:
            # Open a fresh figure, draw into it with show=False, then hand
            # the current figure to the shared render/download helper.
            plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
            draw_plot(exp, show=False, max_display=max_display_single)
            render_plot_with_download(
                plt.gcf(),
                title=panel_title,
                filename=panel_file,
                export_dpi=export_dpi,
                key=panel_key,
            )
|
| 2672 |
|
| 2673 |
+
# --- After SHAP plots are displayed ---
# Static reader guidance for the waterfall and bar plots, shown expanded.
shap_help_md = r"""
**What is SHAP?**
SHAP (SHapley Additive exPlanations) decomposes a model prediction into **feature-wise contributions**.
Each feature pushes the prediction **toward higher risk** or **toward lower risk**, relative to the model’s baseline.

**What is \(E[f(X)]\)? (baseline / expected model output)**
- \(E[f(X)]\) is the model’s **average output** over the reference population used by the explainer (typically the training set).
- It is the **starting point** of the explanation before any patient-specific information is applied.

**What is \(f(x)\)? (this patient’s model output)**
- \(f(x)\) is the model’s **final output for this patient**, after adding all feature contributions to the baseline.
- For many classifiers (including logistic regression), SHAP waterfall plots often use the **log-odds (logit) scale** rather than raw probability.

**How the waterfall plot should be read**
- The plot starts at **\(E[f(X)]\)** and then adds/subtracts feature effects to arrive at **\(f(x)\)**.
- **Red bars** push the prediction **toward higher risk** (increase the model output).
- **Blue bars** push the prediction **toward lower risk / protective direction** (decrease the model output).
- The **bar length** indicates the **strength** of that feature’s influence for this patient.
- “**Other features**” is the combined net effect of features not shown individually.

**How the bar plot should be read (Top contributors)**
- Features are ranked by **absolute SHAP value** (largest impact first).
- **Positive SHAP** (right) increases predicted risk; **negative SHAP** (left) decreases predicted risk.

**Clinical cautions**
- SHAP explains the model’s behavior; it does **not** prove causality.
- Ensure variable definitions and patient population match the model’s intended use.
"""

with st.expander("How to interpret the SHAP explanation plots", expanded=True):
    st.markdown(shap_help_md)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2703 |
|
| 2704 |
|
| 2705 |
# -----------------------------
|
|
|
|
| 2895 |
cph = None
|
| 2896 |
surv_cols = []
|
| 2897 |
|
| 2898 |
+
# ---- Batch survival probabilities (if survival model loaded) ----
# Adds survival_6m / survival_1y / survival_2y / survival_3y columns to df_out,
# one value per row of X_inf, using the Cox model in the survival bundle.
# NOTE(review): the statement after this section calls
# cph.predict_survival_function(df_surv_in_oh) again; confirm it is guarded
# for the case where no bundle is loaded and cph is still None.
bundle = st.session_state.get("surv_model", None)

if isinstance(bundle, dict) and bundle.get("model") is not None:
    try:
        cph = bundle["model"]
        surv_cols = list(bundle.get("columns", []))
        if not surv_cols:
            raise ValueError("survival bundle has no training column list")

        df_surv_in = X_inf[feature_cols].copy()

        for c in num_cols:
            if c in df_surv_in.columns:
                df_surv_in[c] = pd.to_numeric(df_surv_in[c], errors="coerce")

        # Only encode categoricals that exist; get_dummies raises KeyError
        # for a name in `columns=` that is not in the frame.
        present_cat_cols = [name for name in cat_cols if name in df_surv_in.columns]
        for c in present_cat_cols:
            df_surv_in[c] = df_surv_in[c].astype("object").map(
                lambda v: v if pd.isna(v) else str(v)
            )

        # drop_first must be False here: drop_first=True removes the first
        # level present in THIS batch, which need not be the reference level
        # dropped at training time, so rows could be encoded as the wrong
        # category. Encoding all levels and aligning to the training columns
        # discards reference-level dummies and zero-fills unseen levels.
        df_surv_in_oh = pd.get_dummies(
            df_surv_in, columns=present_cat_cols, drop_first=False
        )

        # Align to training predictor columns
        df_surv_in_oh = df_surv_in_oh.reindex(columns=surv_cols, fill_value=0)

        surv_fn_all = cph.predict_survival_function(df_surv_in_oh)
        surv_times_all = surv_fn_all.index.values

        def surv_vec_at(days: int):
            """Return S(days) for every row, reading curves as step functions.

            Uses the last time point <= ``days`` rather than the nearest
            one, so values never come from a later time than asked.
            Assumes the time index is sorted ascending - TODO confirm.
            """
            n_reached = int((surv_times_all <= days).sum())
            if n_reached == 0:
                # Horizon precedes the first time point of the curves.
                return np.ones(surv_fn_all.shape[1], dtype=float)
            return surv_fn_all.iloc[n_reached - 1, :].values.astype(float)

        df_out["survival_6m"] = surv_vec_at(180)
        df_out["survival_1y"] = surv_vec_at(365)
        df_out["survival_2y"] = surv_vec_at(730)
        df_out["survival_3y"] = surv_vec_at(1095)

    except Exception as e:
        # Best effort: the batch classifier output is still usable without survival.
        st.warning(f"Batch survival probabilities could not be computed: {e}")
else:
    st.info("Survival model not loaded/published yet (survival bundle missing).")
|
| 2942 |
+
|
| 2943 |
|
| 2944 |
|
| 2945 |
surv_fn_all = cph.predict_survival_function(df_surv_in_oh)
|