Synav committed on
Commit
8824526
·
verified ·
1 Parent(s): 3d12e25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +289 -287
app.py CHANGED
@@ -2198,9 +2198,7 @@ with tab_predict:
2198
  # =========================
2199
  pipe = st.session_state.pipe
2200
  meta = st.session_state.meta
2201
- feature_cols = meta["schema"]["features"]
2202
- num_cols = meta["schema"]["numeric"]
2203
- cat_cols = meta["schema"]["categorical"]
2204
 
2205
 
2206
 
@@ -2426,300 +2424,282 @@ with tab_predict:
2426
  values_by_index[i] = np.nan if v.strip() == "" else v
2427
 
2428
 
2429
- # Apply FISH/NGS selections to row
2430
- fish_set = set(fish_selected)
2431
- ngs_set = set(ngs_selected)
 
 
 
 
 
 
2432
 
2433
- for i, f in enumerate(feature_cols):
2434
- if f in FISH_MARKERS:
2435
- values_by_index[i] = 1 if f in fish_set else 0
2436
- if f in NGS_MARKERS:
2437
- values_by_index[i] = 1 if f in ngs_set else 0
2438
-
2439
- # Auto-fill marker counts if present
2440
- if FISH_COUNT_COL in feature_cols:
2441
- values_by_index[feature_cols.index(FISH_COUNT_COL)] = int(len(fish_selected))
2442
- if NGS_COUNT_COL in feature_cols:
2443
- values_by_index[feature_cols.index(NGS_COUNT_COL)] = int(len(ngs_selected))
2444
- st.divider()
2445
- st.subheader("Predict single patient")
2446
 
2447
- m = meta.get("metrics", {})
2448
- default_thr = float(m.get("best_threshold", 0.5))
2449
-
2450
- thr_single = st.slider(
2451
- "Classification threshold",
2452
- 0.0, 1.0, default_thr, 0.01,
2453
- key="sp_thr"
2454
- )
2455
-
2456
- # External validation threshold
2457
- thr_ext = st.slider(
2458
- "External validation threshold",
2459
- 0.0, 1.0, default_thr, 0.01,
2460
- key="thr_ext"
2461
- )
2462
 
2463
- low_cut_s, high_cut_s = st.slider(
2464
- "Risk band cutoffs (low, high)",
2465
- 0.0, 1.0, (0.2, 0.8), 0.01,
2466
- key="sp_risk_cuts"
2467
- )
2468
-
2469
- # Ensure low <= high
2470
- if low_cut_s > high_cut_s:
2471
- low_cut_s, high_cut_s = high_cut_s, low_cut_s
2472
 
2473
 
2474
- def band_one(p: float) -> str:
2475
- if p < low_cut_s:
2476
- return "Low"
2477
- if p >= high_cut_s:
2478
- return "High"
2479
- return "Intermediate"
2480
- # Submit button (no form needed; simpler + fewer state surprises)
2481
- if st.button("Predict single patient", key="sp_predict_btn"):
2482
- X_one = pd.DataFrame([values_by_index], columns=feature_cols).replace({pd.NA: np.nan})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2483
 
2484
- for c in num_cols:
2485
- if c in X_one.columns:
2486
- X_one[c] = pd.to_numeric(X_one[c], errors="coerce")
 
2487
 
2488
- for c in cat_cols:
2489
- if c in X_one.columns:
2490
- X_one[c] = X_one[c].astype("object")
2491
- X_one.loc[X_one[c].isna(), c] = np.nan
2492
- X_one[c] = X_one[c].map(lambda v: v if pd.isna(v) else str(v))
2493
-
2494
 
2495
- proba_one = float(pipe.predict_proba(X_one)[:, 1][0])
2496
- st.success("Prediction generated.")
2497
- st.metric("Predicted probability", f"{proba_one:.4f}")
2498
-
2499
-
2500
- # ---- Survival prediction for this patient (if survival model loaded) ----
2501
- cph = st.session_state.get("surv_model", None)
2502
-
2503
- if cph is not None:
2504
- try:
2505
- # Build Cox input row with same preprocessing used in training Cox:
2506
- # predictors + one-hot categoricals
2507
- df_one_surv = X_one[feature_cols].copy()
2508
- for c in num_cols:
2509
  df_one_surv[c] = pd.to_numeric(df_one_surv[c], errors="coerce")
2510
- for c in cat_cols:
2511
- df_one_surv[c] = df_one_surv[c].astype("object").map(lambda v: v if pd.isna(v) else str(v))
2512
-
2513
- df_one_surv_oh = pd.get_dummies(df_one_surv, columns=cat_cols, drop_first=True)
2514
-
2515
- # Align columns to Cox model training columns
2516
- bundle = st.session_state.get("surv_model", None)
2517
- if isinstance(bundle, dict) and "model" in bundle:
2518
- cph = bundle["model"]
2519
- surv_cols = bundle.get("columns", [])
2520
- else:
2521
- cph = None
2522
- surv_cols = []
2523
-
2524
- if cph is not None:
2525
- try:
2526
- df_one_surv = X_one[feature_cols].copy()
2527
-
2528
- for c in num_cols:
2529
- if c in df_one_surv.columns:
2530
- df_one_surv[c] = pd.to_numeric(df_one_surv[c], errors="coerce")
2531
- for c in cat_cols:
2532
- if c in df_one_surv.columns:
2533
- df_one_surv[c] = df_one_surv[c].astype("object").map(lambda v: v if pd.isna(v) else str(v))
2534
-
2535
- df_one_surv_oh = pd.get_dummies(df_one_surv, columns=cat_cols, drop_first=True)
2536
-
2537
- # Align to training predictor columns
2538
- for col in surv_cols:
2539
- if col not in df_one_surv_oh.columns:
2540
- df_one_surv_oh[col] = 0
2541
- df_one_surv_oh = df_one_surv_oh[surv_cols]
2542
-
2543
- surv_fn = cph.predict_survival_function(df_one_surv_oh)
2544
- ...
2545
- except Exception as e:
2546
- st.warning(f"Survival prediction could not be computed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2547
 
2548
-
2549
- surv_fn = cph.predict_survival_function(df_one_surv_oh)
2550
 
2551
- def surv_at(days: int) -> float:
2552
- # find nearest time index
2553
- idx = surv_fn.index.values
2554
- j = int(np.argmin(np.abs(idx - days)))
2555
- return float(surv_fn.iloc[j, 0])
2556
-
2557
- s6m = surv_at(180)
2558
- s1y = surv_at(365)
2559
- s2y = surv_at(730)
2560
- s3y = surv_at(1095)
2561
-
2562
- st.subheader("Predicted survival probability")
2563
- a,b,c,d = st.columns(4)
2564
- a.metric("6 months", f"{s6m*100:.1f}%")
2565
- b.metric("1 year", f"{s1y*100:.1f}%")
2566
- c.metric("2 years", f"{s2y*100:.1f}%")
2567
- d.metric("3 years", f"{s3y*100:.1f}%")
2568
-
2569
- # Optional: plot curve
2570
- fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
2571
- ax.plot(surv_fn.index, surv_fn.values)
2572
- ax.set_xlabel("Days from diagnosis")
2573
- ax.set_ylabel("Survival probability")
2574
- ax.set_ylim(0, 1)
2575
- ax.set_title("Predicted survival curve (patient)")
2576
-
2577
- render_plot_with_download(
2578
- fig,
2579
- title="Patient survival curve",
2580
- filename="patient_survival_curve.png",
2581
- export_dpi=export_dpi,
2582
- key="dl_patient_surv_curve"
2583
- )
2584
 
2585
- except Exception as e:
2586
- st.warning(f"Survival prediction could not be computed: {e}")
2587
- else:
2588
- st.info("Survival model not loaded/published yet (survival_model.joblib missing).")
2589
 
2590
-
2591
- out = X_one.copy()
2592
- out["predicted_probability"] = proba_one
2593
- pred_class = int(proba_one >= thr_single)
2594
- out["predicted_class"] = pred_class
2595
-
2596
- out["risk_band"] = band_one(proba_one)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2597
 
2598
-
 
 
 
 
 
 
 
 
 
2599
 
2600
- # ---- SHAP compute only (cache) ----
2601
- X_one_t = transform_before_clf(pipe, X_one)
2602
-
2603
- explainer = st.session_state.get("explainer")
2604
- explainer_sig = st.session_state.get("explainer_sig")
2605
-
2606
- current_sig = (
2607
- selected,
2608
- None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"]))
2609
- )
2610
-
2611
- if explainer is None or explainer_sig != current_sig:
2612
- X_bg = st.session_state.get("X_bg_for_shap")
2613
- if X_bg is None:
2614
- st.error("SHAP background not available. Admin must publish latest/background.csv.")
2615
- st.stop()
2616
-
2617
- st.session_state.explainer = build_shap_explainer(pipe, X_bg)
2618
- st.session_state.explainer_sig = current_sig
2619
- explainer = st.session_state.explainer
2620
-
2621
- shap_vals = explainer.shap_values(X_one_t)
2622
- if isinstance(shap_vals, list):
2623
- shap_vals = shap_vals[1]
2624
-
2625
- names = get_final_feature_names(pipe)
2626
-
2627
- try:
2628
- x_dense = X_one_t.toarray()[0]
2629
- except Exception:
2630
- x_dense = np.array(X_one_t)[0]
2631
-
2632
- base = explainer.expected_value
2633
- if not np.isscalar(base):
2634
- base = float(np.array(base).reshape(-1)[0])
2635
-
2636
- exp = shap.Explanation(
2637
- values=shap_vals[0],
2638
- base_values=float(base),
2639
- data=x_dense,
2640
- feature_names=names,
2641
  )
2642
-
2643
- # CACHE ONLY
2644
- st.session_state.shap_single_exp = exp
2645
-
2646
 
2647
- st.dataframe(out, use_container_width=True)
2648
 
2649
- st.download_button(
2650
- "Download single patient result (CSV)",
2651
- out.to_csv(index=False).encode("utf-8"),
2652
- file_name="single_patient_prediction.csv",
2653
- mime="text/csv",
2654
- key="dl_sp_csv",
2655
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2656
 
2657
- # ---- Always render cached SHAP ----
2658
- if "shap_single_exp" in st.session_state:
2659
- exp = st.session_state.shap_single_exp
 
 
 
2660
 
2661
- max_display_single = st.slider(
2662
- "Top features to display (single patient)",
2663
- 5, 40, 20, 1,
2664
- key="sp_single_max_display"
2665
- )
2666
 
2667
- c1, c2 = st.columns(2)
 
 
2668
 
2669
- with c1:
2670
- plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
2671
- shap.plots.waterfall(exp, show=False, max_display=max_display_single)
2672
- fig_w = plt.gcf()
2673
- render_plot_with_download(
2674
- fig_w,
2675
- title="Single-patient SHAP waterfall",
2676
- filename="single_patient_shap_waterfall.png",
2677
- export_dpi=export_dpi,
2678
- key="dl_sp_wf"
2679
- )
2680
 
2681
- with c2:
2682
- plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
2683
- shap.plots.bar(exp, show=False, max_display=max_display_single)
2684
- fig_b = plt.gcf()
2685
- render_plot_with_download(
2686
- fig_b,
2687
- title="Single-patient SHAP bar",
2688
- filename="single_patient_shap_bar.png",
2689
- export_dpi=export_dpi,
2690
- key="dl_sp_bar"
2691
- )
2692
-
2693
- # --- After SHAP plots are displayed ---
2694
- with st.expander("How to interpret the SHAP explanation plots", expanded=True):
2695
- st.markdown(r"""
2696
- **What is SHAP?**
2697
- SHAP (SHapley Additive exPlanations) decomposes a model prediction into **feature-wise contributions**.
2698
- Each feature pushes the prediction **toward higher risk** or **toward lower risk**, relative to the model’s baseline.
2699
-
2700
- **What is \(E[f(X)]\)? (baseline / expected model output)**
2701
- - \(E[f(X)]\) is the model’s **average output** over the reference population used by the explainer (typically the training set).
2702
- - It is the **starting point** of the explanation before any patient-specific information is applied.
2703
-
2704
- **What is \(f(x)\)? (this patient’s model output)**
2705
- - \(f(x)\) is the model’s **final output for this patient**, after adding all feature contributions to the baseline.
2706
- - For many classifiers (including logistic regression), SHAP waterfall plots often use the **log-odds (logit) scale** rather than raw probability.
2707
-
2708
- **How the waterfall plot should be read**
2709
- - The plot starts at **\(E[f(X)]\)** and then adds/subtracts feature effects to arrive at **\(f(x)\)**.
2710
- - **Red bars** push the prediction **toward higher risk** (increase the model output).
2711
- - **Blue bars** push the prediction **toward lower risk / protective direction** (decrease the model output).
2712
- - The **bar length** indicates the **strength** of that feature’s influence for this patient.
2713
- - “**Other features**” is the combined net effect of features not shown individually.
2714
-
2715
- **How the bar plot should be read (Top contributors)**
2716
- - Features are ranked by **absolute SHAP value** (largest impact first).
2717
- - **Positive SHAP** (right) increases predicted risk; **negative SHAP** (left) decreases predicted risk.
2718
-
2719
- **Clinical cautions**
2720
- - SHAP explains the model’s behavior; it does **not** prove causality.
2721
- - Ensure variable definitions and patient population match the model’s intended use.
2722
- """)
2723
 
2724
 
2725
  # -----------------------------
@@ -2915,29 +2895,51 @@ with tab_predict:
2915
  cph = None
2916
  surv_cols = []
2917
 
2918
- if cph is not None:
 
 
 
2919
  try:
2920
- df_one_surv = X_one[feature_cols].copy()
 
 
 
2921
 
2922
  for c in num_cols:
2923
- if c in df_one_surv.columns:
2924
- df_one_surv[c] = pd.to_numeric(df_one_surv[c], errors="coerce")
 
2925
  for c in cat_cols:
2926
- if c in df_one_surv.columns:
2927
- df_one_surv[c] = df_one_surv[c].astype("object").map(lambda v: v if pd.isna(v) else str(v))
 
 
2928
 
2929
- df_one_surv_oh = pd.get_dummies(df_one_surv, columns=cat_cols, drop_first=True)
2930
 
2931
  # Align to training predictor columns
2932
  for col in surv_cols:
2933
- if col not in df_one_surv_oh.columns:
2934
- df_one_surv_oh[col] = 0
2935
- df_one_surv_oh = df_one_surv_oh[surv_cols]
 
 
 
 
 
 
 
 
 
 
 
 
2936
 
2937
- surv_fn = cph.predict_survival_function(df_one_surv_oh)
2938
- ...
2939
  except Exception as e:
2940
- st.warning(f"Survival prediction could not be computed: {e}")
 
 
 
2941
 
2942
 
2943
  surv_fn_all = cph.predict_survival_function(df_surv_in_oh)
 
2198
  # =========================
2199
  pipe = st.session_state.pipe
2200
  meta = st.session_state.meta
2201
+
 
 
2202
 
2203
 
2204
 
 
2424
  values_by_index[i] = np.nan if v.strip() == "" else v
2425
 
2426
 
2427
+ # Apply FISH/NGS selections to row
2428
+ fish_set = set(fish_selected)
2429
+ ngs_set = set(ngs_selected)
2430
+
2431
+ for i, f in enumerate(feature_cols):
2432
+ if f in FISH_MARKERS:
2433
+ values_by_index[i] = 1 if f in fish_set else 0
2434
+ if f in NGS_MARKERS:
2435
+ values_by_index[i] = 1 if f in ngs_set else 0
2436
 
2437
+ # Auto-fill marker counts if present
2438
+ if FISH_COUNT_COL in feature_cols:
2439
+ values_by_index[feature_cols.index(FISH_COUNT_COL)] = int(len(fish_selected))
2440
+ if NGS_COUNT_COL in feature_cols:
2441
+ values_by_index[feature_cols.index(NGS_COUNT_COL)] = int(len(ngs_selected))
2442
+ st.divider()
2443
+ st.subheader("Predict single patient")
 
 
 
 
 
 
2444
 
2445
+ m = meta.get("metrics", {})
2446
+ default_thr = float(m.get("best_threshold", 0.5))
2447
+
2448
+ thr_single = st.slider(
2449
+ "Classification threshold",
2450
+ 0.0, 1.0, default_thr, 0.01,
2451
+ key="sp_thr"
2452
+ )
2453
+
2454
+ # External validation threshold
2455
+ thr_ext = st.slider(
2456
+ "External validation threshold",
2457
+ 0.0, 1.0, default_thr, 0.01,
2458
+ key="thr_ext"
2459
+ )
2460
 
2461
+ low_cut_s, high_cut_s = st.slider(
2462
+ "Risk band cutoffs (low, high)",
2463
+ 0.0, 1.0, (0.2, 0.8), 0.01,
2464
+ key="sp_risk_cuts"
2465
+ )
2466
+
2467
+ # Ensure low <= high
2468
+ if low_cut_s > high_cut_s:
2469
+ low_cut_s, high_cut_s = high_cut_s, low_cut_s
2470
 
2471
 
2472
+ def band_one(p: float) -> str:
2473
+ if p < low_cut_s:
2474
+ return "Low"
2475
+ if p >= high_cut_s:
2476
+ return "High"
2477
+ return "Intermediate"
2478
+ # Submit button (no form needed; simpler + fewer state surprises)
2479
+ if st.button("Predict single patient", key="sp_predict_btn"):
2480
+ X_one = pd.DataFrame([values_by_index], columns=feature_cols).replace({pd.NA: np.nan})
2481
+
2482
+ for c in num_cols:
2483
+ if c in X_one.columns:
2484
+ X_one[c] = pd.to_numeric(X_one[c], errors="coerce")
2485
+
2486
+ for c in cat_cols:
2487
+ if c in X_one.columns:
2488
+ X_one[c] = X_one[c].astype("object")
2489
+ X_one.loc[X_one[c].isna(), c] = np.nan
2490
+ X_one[c] = X_one[c].map(lambda v: v if pd.isna(v) else str(v))
2491
+
2492
+
2493
+ proba_one = float(pipe.predict_proba(X_one)[:, 1][0])
2494
+ st.success("Prediction generated.")
2495
+ st.metric("Predicted probability", f"{proba_one:.4f}")
2496
+
2497
+
2498
+ # ---- Survival prediction for this patient (if survival model loaded) ----
2499
+ # ---- Survival prediction for this patient (if survival model loaded) ----
2500
+ bundle = st.session_state.get("surv_model", None)
2501
 
2502
+ if isinstance(bundle, dict) and bundle.get("model") is not None:
2503
+ try:
2504
+ cph = bundle["model"]
2505
+ surv_cols = bundle.get("columns", [])
2506
 
2507
+ # Build Cox input row (same preprocessing as Cox training)
2508
+ df_one_surv = X_one[feature_cols].copy()
 
 
 
 
2509
 
2510
+ for c in num_cols:
2511
+ if c in df_one_surv.columns:
 
 
 
 
 
 
 
 
 
 
 
 
2512
  df_one_surv[c] = pd.to_numeric(df_one_surv[c], errors="coerce")
2513
+
2514
+ for c in cat_cols:
2515
+ if c in df_one_surv.columns:
2516
+ df_one_surv[c] = df_one_surv[c].astype("object").map(
2517
+ lambda v: v if pd.isna(v) else str(v)
2518
+ )
2519
+
2520
+ df_one_surv_oh = pd.get_dummies(df_one_surv, columns=cat_cols, drop_first=True)
2521
+
2522
+ # Align to training predictor columns
2523
+ for col in surv_cols:
2524
+ if col not in df_one_surv_oh.columns:
2525
+ df_one_surv_oh[col] = 0
2526
+ df_one_surv_oh = df_one_surv_oh[surv_cols]
2527
+
2528
+ # Predict survival function
2529
+ surv_fn = cph.predict_survival_function(df_one_surv_oh)
2530
+
2531
+ def surv_at(days: int) -> float:
2532
+ idx = surv_fn.index.values
2533
+ j = int(np.argmin(np.abs(idx - days)))
2534
+ return float(surv_fn.iloc[j, 0])
2535
+
2536
+ s6m = surv_at(180)
2537
+ s1y = surv_at(365)
2538
+ s2y = surv_at(730)
2539
+ s3y = surv_at(1095)
2540
+
2541
+ st.subheader("Predicted survival probability")
2542
+ a, b, c, d = st.columns(4)
2543
+ a.metric("6 months", f"{s6m*100:.1f}%")
2544
+ b.metric("1 year", f"{s1y*100:.1f}%")
2545
+ c.metric("2 years", f"{s2y*100:.1f}%")
2546
+ d.metric("3 years", f"{s3y*100:.1f}%")
2547
+
2548
+ fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
2549
+ ax.plot(surv_fn.index, surv_fn.values)
2550
+ ax.set_xlabel("Days from diagnosis")
2551
+ ax.set_ylabel("Survival probability")
2552
+ ax.set_ylim(0, 1)
2553
+ ax.set_title("Predicted survival curve (patient)")
2554
+
2555
+ render_plot_with_download(
2556
+ fig,
2557
+ title="Patient survival curve",
2558
+ filename="patient_survival_curve.png",
2559
+ export_dpi=export_dpi,
2560
+ key="dl_patient_surv_curve"
2561
+ )
2562
+
2563
+ except Exception as e:
2564
+ st.warning(f"Survival prediction could not be computed: {e}")
2565
+
2566
+ else:
2567
+ st.info("Survival model not loaded/published yet (survival bundle missing).")
2568
 
 
 
2569
 
2570
+
2571
+ out = X_one.copy()
2572
+ out["predicted_probability"] = proba_one
2573
+ pred_class = int(proba_one >= thr_single)
2574
+ out["predicted_class"] = pred_class
2575
+
2576
+ out["risk_band"] = band_one(proba_one)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2577
 
2578
+
 
 
 
2579
 
2580
+ # ---- SHAP compute only (cache) ----
2581
+ X_one_t = transform_before_clf(pipe, X_one)
2582
+
2583
+ explainer = st.session_state.get("explainer")
2584
+ explainer_sig = st.session_state.get("explainer_sig")
2585
+
2586
+ current_sig = (
2587
+ selected,
2588
+ None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"]))
2589
+ )
2590
+
2591
+ if explainer is None or explainer_sig != current_sig:
2592
+ X_bg = st.session_state.get("X_bg_for_shap")
2593
+ if X_bg is None:
2594
+ st.error("SHAP background not available. Admin must publish latest/background.csv.")
2595
+ st.stop()
2596
+
2597
+ st.session_state.explainer = build_shap_explainer(pipe, X_bg)
2598
+ st.session_state.explainer_sig = current_sig
2599
+ explainer = st.session_state.explainer
2600
+
2601
+ shap_vals = explainer.shap_values(X_one_t)
2602
+ if isinstance(shap_vals, list):
2603
+ shap_vals = shap_vals[1]
2604
+
2605
+ names = get_final_feature_names(pipe)
2606
+
2607
+ try:
2608
+ x_dense = X_one_t.toarray()[0]
2609
+ except Exception:
2610
+ x_dense = np.array(X_one_t)[0]
2611
+
2612
+ base = explainer.expected_value
2613
+ if not np.isscalar(base):
2614
+ base = float(np.array(base).reshape(-1)[0])
2615
+
2616
+ exp = shap.Explanation(
2617
+ values=shap_vals[0],
2618
+ base_values=float(base),
2619
+ data=x_dense,
2620
+ feature_names=names,
2621
+ )
2622
+
2623
+ # CACHE ONLY
2624
+ st.session_state.shap_single_exp = exp
2625
 
2626
+
2627
+ st.dataframe(out, use_container_width=True)
2628
+
2629
+ st.download_button(
2630
+ "Download single patient result (CSV)",
2631
+ out.to_csv(index=False).encode("utf-8"),
2632
+ file_name="single_patient_prediction.csv",
2633
+ mime="text/csv",
2634
+ key="dl_sp_csv",
2635
+ )
2636
 
2637
+ # ---- Always render cached SHAP ----
2638
+ if "shap_single_exp" in st.session_state:
2639
+ exp = st.session_state.shap_single_exp
2640
+
2641
+ max_display_single = st.slider(
2642
+ "Top features to display (single patient)",
2643
+ 5, 40, 20, 1,
2644
+ key="sp_single_max_display"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2645
  )
 
 
 
 
2646
 
2647
+ c1, c2 = st.columns(2)
2648
 
2649
+ with c1:
2650
+ plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
2651
+ shap.plots.waterfall(exp, show=False, max_display=max_display_single)
2652
+ fig_w = plt.gcf()
2653
+ render_plot_with_download(
2654
+ fig_w,
2655
+ title="Single-patient SHAP waterfall",
2656
+ filename="single_patient_shap_waterfall.png",
2657
+ export_dpi=export_dpi,
2658
+ key="dl_sp_wf"
2659
+ )
2660
+
2661
+ with c2:
2662
+ plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
2663
+ shap.plots.bar(exp, show=False, max_display=max_display_single)
2664
+ fig_b = plt.gcf()
2665
+ render_plot_with_download(
2666
+ fig_b,
2667
+ title="Single-patient SHAP bar",
2668
+ filename="single_patient_shap_bar.png",
2669
+ export_dpi=export_dpi,
2670
+ key="dl_sp_bar"
2671
+ )
2672
 
2673
+ # --- After SHAP plots are displayed ---
2674
+ with st.expander("How to interpret the SHAP explanation plots", expanded=True):
2675
+ st.markdown(r"""
2676
+ **What is SHAP?**
2677
+ SHAP (SHapley Additive exPlanations) decomposes a model prediction into **feature-wise contributions**.
2678
+ Each feature pushes the prediction **toward higher risk** or **toward lower risk**, relative to the model’s baseline.
2679
 
2680
+ **What is \(E[f(X)]\)? (baseline / expected model output)**
2681
+ - \(E[f(X)]\) is the model’s **average output** over the reference population used by the explainer (typically the training set).
2682
+ - It is the **starting point** of the explanation before any patient-specific information is applied.
 
 
2683
 
2684
+ **What is \(f(x)\)? (this patient’s model output)**
2685
+ - \(f(x)\) is the model’s **final output for this patient**, after adding all feature contributions to the baseline.
2686
+ - For many classifiers (including logistic regression), SHAP waterfall plots often use the **log-odds (logit) scale** rather than raw probability.
2687
 
2688
+ **How the waterfall plot should be read**
2689
+ - The plot starts at **\(E[f(X)]\)** and then adds/subtracts feature effects to arrive at **\(f(x)\)**.
2690
+ - **Red bars** push the prediction **toward higher risk** (increase the model output).
2691
+ - **Blue bars** push the prediction **toward lower risk / protective direction** (decrease the model output).
2692
+ - The **bar length** indicates the **strength** of that feature’s influence for this patient.
2693
+ - “**Other features**” is the combined net effect of features not shown individually.
 
 
 
 
 
2694
 
2695
+ **How the bar plot should be read (Top contributors)**
2696
+ - Features are ranked by **absolute SHAP value** (largest impact first).
2697
+ - **Positive SHAP** (right) increases predicted risk; **negative SHAP** (left) decreases predicted risk.
2698
+
2699
+ **Clinical cautions**
2700
+ - SHAP explains the model’s behavior; it does **not** prove causality.
2701
+ - Ensure variable definitions and patient population match the model’s intended use.
2702
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2703
 
2704
 
2705
  # -----------------------------
 
2895
  cph = None
2896
  surv_cols = []
2897
 
2898
+ # ---- Batch survival probabilities (if survival model loaded) ----
2899
+ bundle = st.session_state.get("surv_model", None)
2900
+
2901
+ if isinstance(bundle, dict) and bundle.get("model") is not None:
2902
  try:
2903
+ cph = bundle["model"]
2904
+ surv_cols = bundle.get("columns", [])
2905
+
2906
+ df_surv_in = X_inf[feature_cols].copy()
2907
 
2908
  for c in num_cols:
2909
+ if c in df_surv_in.columns:
2910
+ df_surv_in[c] = pd.to_numeric(df_surv_in[c], errors="coerce")
2911
+
2912
  for c in cat_cols:
2913
+ if c in df_surv_in.columns:
2914
+ df_surv_in[c] = df_surv_in[c].astype("object").map(
2915
+ lambda v: v if pd.isna(v) else str(v)
2916
+ )
2917
 
2918
+ df_surv_in_oh = pd.get_dummies(df_surv_in, columns=cat_cols, drop_first=True)
2919
 
2920
  # Align to training predictor columns
2921
  for col in surv_cols:
2922
+ if col not in df_surv_in_oh.columns:
2923
+ df_surv_in_oh[col] = 0
2924
+ df_surv_in_oh = df_surv_in_oh[surv_cols]
2925
+
2926
+ surv_fn_all = cph.predict_survival_function(df_surv_in_oh)
2927
+
2928
+ def surv_vec_at(days: int):
2929
+ idx = surv_fn_all.index.values
2930
+ j = int(np.argmin(np.abs(idx - days)))
2931
+ return surv_fn_all.iloc[j, :].values.astype(float)
2932
+
2933
+ df_out["survival_6m"] = surv_vec_at(180)
2934
+ df_out["survival_1y"] = surv_vec_at(365)
2935
+ df_out["survival_2y"] = surv_vec_at(730)
2936
+ df_out["survival_3y"] = surv_vec_at(1095)
2937
 
 
 
2938
  except Exception as e:
2939
+ st.warning(f"Batch survival probabilities could not be computed: {e}")
2940
+ else:
2941
+ st.info("Survival model not loaded/published yet (survival bundle missing).")
2942
+
2943
 
2944
 
2945
  surv_fn_all = cph.predict_survival_function(df_surv_in_oh)