# Synav's picture
# Update app.py
# b2b0f33 verified
import json
from datetime import datetime
import numpy as np
import pandas as pd
import streamlit as st
import joblib
import os
from huggingface_hub import hf_hub_download, HfApi
import hmac
from sklearn.metrics import (
roc_auc_score, accuracy_score,
roc_curve, confusion_matrix,
precision_score, recall_score, f1_score,
balanced_accuracy_score,
precision_recall_curve, average_precision_score,
brier_score_loss
)
import shap
import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve
from sklearn.feature_selection import VarianceThreshold, SelectFromModel
from sklearn.decomposition import TruncatedSVD
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
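# ============================================================
# Figure helpers: create, export as PNG, render with a download button.
# Use make_fig() for new plots rather than calling plt.plot directly.
# ============================================================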
def make_fig(figsize=(5.5, 3.6), dpi=120):
    """Create a single-axes Matplotlib figure with the app's default size and DPI.

    Returns the ``(fig, ax)`` pair produced by ``plt.subplots``.
    """
    return plt.subplots(figsize=figsize, dpi=dpi)
def fig_to_png_bytes(fig, dpi=600):
    """Render ``fig`` as PNG (``dpi`` coerced to int, tight bounding box) and return the bytes."""
    import io

    with io.BytesIO() as buffer:
        fig.savefig(buffer, format="png", dpi=int(dpi), bbox_inches="tight")
        # getvalue() returns the whole buffer regardless of the stream position.
        return buffer.getvalue()
from typing import Optional
def render_plot_with_download(
    fig,
    *,
    title: str,
    filename: str,
    export_dpi: int = 600,
    key: Optional[str] = None
):
    """
    Display a Matplotlib figure in Streamlit with a PNG download button, then close it.

    Parameters
    ----------
    fig : matplotlib Figure to show and export.
    title : used in the download button label.
    filename : file name offered to the browser for the PNG.
    export_dpi : resolution of the exported PNG (the on-screen render is separate).
    key : Streamlit widget key; defaults to ``"dl_<filename>"``.

    The figure is closed in a ``finally`` block so it is released even when PNG
    rendering or a Streamlit call raises (pyplot retains open figures until they
    are explicitly closed).
    """
    try:
        # Render the export PNG before display; the figure is not cleared by st.pyplot.
        png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
        st.pyplot(fig, clear_figure=False)
        st.download_button(
            label=f"Download {title} (PNG {export_dpi} dpi)",
            data=png_bytes,
            file_name=filename,
            mime="image/png",
            key=key or f"dl_{filename}",
        )
    finally:
        plt.close(fig)
# ============================================================
# Fixed schema definition
# ============================================================
# =========================
# Features: no limit -- every column except LABEL_COL is used as a feature
# (see get_feature_cols_from_df).
# =========================
# Binary classification target; the Excel header must match this name.
LABEL_COL = "Outcome Event" # keep label column exactly as header "Outcome Event" in your Excel
# =========================
# Survival columns (Excel headers)
# =========================
# Candidate headers for each survival role. find_col() returns the first
# candidate present in the frame, comparing names after norm_col()
# normalisation (case, repeated/edge whitespace and underscores are ignored).
# Time origin (diagnosis date).
SURV_START_COL_CANDS = ["Date of 1st Bone Marrow biopsy (Date of Diagnosis)", "Diagnosis Date", "Dx Date"]
# Death date: end of follow-up for rows flagged as events.
SURV_DEATH_COL_CANDS = ["death_date", "Death Date", "Date of Death"]
# Last follow-up date: end of follow-up for rows not flagged as events (censored).
SURV_CENSOR_COL_CANDS = ["last_followup_date", "Last Follow up date", "Last Follow-up Date", "Last Seen Date"]
# Event indicator; the first candidate is the classification label column itself.
SURV_EVENT_COL_CANDS = ["Outcome Event", "Death", "Event"]
def get_feature_cols_from_df(df: pd.DataFrame):
    """
    Return every column of ``df`` except LABEL_COL, in the frame's column order.

    Raises
    ------
    ValueError
        If LABEL_COL is not one of the columns.
    """
    if LABEL_COL in df.columns:
        return [col for col in df.columns if col != LABEL_COL]
    raise ValueError(
        f"Missing required label column '{LABEL_COL}'. "
        f"Ensure your Excel header row contains a column named '{LABEL_COL}'."
    )
from datetime import date
def to_date_safe(x):
    """
    Coerce a date-like value to a ``datetime.date``.

    Accepts ``datetime.date`` / ``datetime.datetime`` (including ``pd.Timestamp``)
    and anything ``pd.to_datetime`` can parse (strings, ``np.datetime64``, ...).

    Returns
    -------
    datetime.date or None
        None for None, NaN, ``pd.NaT``, ``pd.NA``, empty strings and unparseable input.
    """
    if x is None:
        return None
    # pd.isna covers NaN, NaT and pd.NA. NaT is itself a datetime subclass, so
    # without this check it passed the isinstance branch below and came back as
    # NaT instead of None.
    try:
        if pd.isna(x):
            return None
    except (TypeError, ValueError):
        pass  # non-scalar input: let the parsing below decide
    if isinstance(x, datetime):
        return x.date()
    if isinstance(x, date):
        return x
    try:
        ts = pd.to_datetime(x)
        # Some inputs (e.g. "") parse to NaT rather than raising.
        return None if pd.isna(ts) else ts.date()
    except Exception:
        return None
def age_years_at(dob, ref_date):
    """
    Age in years (days / 365.25) on ``ref_date`` for someone born on ``dob``.

    Both arguments go through to_date_safe. Returns np.nan when either date is
    missing/unparseable or when ``ref_date`` precedes ``dob``.
    """
    born = to_date_safe(dob)
    on = to_date_safe(ref_date)
    if born is None or on is None or on < born:
        return np.nan
    return (on - born).days / 365.25
def build_survival_targets(df: pd.DataFrame):
    """
    Derive time-to-event targets (in days) from the date columns of ``df``.

    Columns are located with find_col using the SURV_*_COL_CANDS candidate lists.
    Rows flagged as events end at the death date; all other rows end at the last
    follow-up date. Durations that are negative or cannot be computed are NaN.

    Returns
    -------
    (time_days, event01, used)
        time_days : float ndarray, one entry per row of df (NaN where undetermined)
        event01   : int ndarray of 0/1 event indicators
        used      : dict role -> matched column name (None when not found)

    Raises
    ------
    ValueError
        If no start (diagnosis) column or no event column can be found.
    """
    start_col = find_col(df, SURV_START_COL_CANDS)
    death_col = find_col(df, SURV_DEATH_COL_CANDS)
    censor_col = find_col(df, SURV_CENSOR_COL_CANDS)
    event_col = find_col(df, SURV_EVENT_COL_CANDS)
    if start_col is None or event_col is None:
        raise ValueError("Survival requires at minimum Diagnosis date and Outcome/Event column.")

    start = pd.to_datetime(df[start_col], errors="coerce")

    # Event indicator: numeric columns -> value == 1; text -> YES/Y/1/TRUE/T/DEAD/DIED.
    ev_raw = df[event_col]
    if pd.api.types.is_numeric_dtype(ev_raw):
        ev = ev_raw.fillna(0).astype(int).eq(1)
    else:
        ev = ev_raw.astype(str).str.strip().str.upper().isin(
            ["YES", "Y", "1", "TRUE", "T", "DEAD", "DIED"]
        )

    def _dates_or_nat(col):
        # The all-NaT placeholder MUST share df's index: Series.where below aligns
        # on index, and a default RangeIndex silently misaligned every row of a
        # frame whose index is not 0..n-1 (e.g. after row filtering).
        if col is None:
            return pd.Series(pd.NaT, index=df.index, dtype="datetime64[ns]")
        return pd.to_datetime(df[col], errors="coerce")

    death = _dates_or_nat(death_col)
    last = _dates_or_nat(censor_col)

    # Event rows end at death; censored rows at last follow-up.
    end = death.where(ev, last)
    time_days = (end - start).dt.days.astype("float")
    time_days = time_days.where(time_days >= 0, np.nan)

    used = {
        "start": start_col,
        "event": event_col,
        "death": death_col,
        "censor": censor_col,
    }
    return time_days.to_numpy(), ev.astype(int).to_numpy(), used
import re
def align_columns_to_schema(df: pd.DataFrame, required_cols: list[str]) -> pd.DataFrame:
    """
    Rename columns of ``df`` to the exact names in ``required_cols`` when they match
    after norm_col normalisation (whitespace/newlines, underscores, case).

    A required name that is already present verbatim is left alone, a column that is
    itself a required name is never renamed to another required name, and each
    source column is renamed at most once. The result therefore never gains
    duplicate column labels. Unmatched required columns are not added.
    """
    present = set(df.columns)
    required = set(required_cols)
    norm_to_actual = {norm_col(c): c for c in df.columns}
    rename_map = {}
    for req in required_cols:
        if req in present:
            # Renaming a look-alike column onto an existing name would duplicate it.
            continue
        src = norm_to_actual.get(norm_col(req))
        if src is None or src in rename_map or src in required:
            continue
        rename_map[src] = req
    if rename_map:
        df = df.rename(columns=rename_map)
    return df
def norm_col(s: str) -> str:
    """
    Canonical form of a column name for tolerant matching.

    Steps, in order: None -> "", non-breaking spaces -> spaces, whitespace runs
    (including newlines) collapsed to one space, ends stripped, underscores removed,
    lower-cased.
    """
    text = str(s) if s is not None else ""
    collapsed = re.sub(r"\s+", " ", text.replace("\u00A0", " ")).strip()
    return collapsed.replace("_", "").lower()
def find_col(df: pd.DataFrame, candidates: list[str]) -> str | None:
    """
    Return the actual name of the first column matching any candidate, else None.

    Candidates are tried in order and compared via norm_col. When several df
    columns normalise to the same key, the last one wins.
    """
    lookup = {}
    for actual in df.columns:
        lookup[norm_col(actual)] = actual
    keys = (norm_col(cand) for cand in candidates)
    return next((lookup[k] for k in keys if k in lookup), None)
def train_survival_bundle(
    df: pd.DataFrame,
    feature_cols: list[str],
    num_cols: list[str],
    cat_cols: list[str],
    time_days: np.ndarray,
    event01: np.ndarray,
    *,
    penalizer: float = 0.1
):
    """
    Fit a penalized Cox proportional-hazards model (lifelines CoxPHFitter).

    Pipeline: coerce dtypes -> one-hot categoricals (pd.get_dummies, drop_first=True)
    -> drop zero-variance predictors -> median-impute -> drop predictors that became
    constant -> if there are more than max(10, 2 * n_events) predictors keep only the
    highest-variance ones -> refit a median imputer on the final columns -> fit Cox.

    Parameters
    ----------
    df : training frame containing ``feature_cols``.
    feature_cols, num_cols, cat_cols : raw feature schema.
    time_days, event01 : durations (days) and 0/1 events, positionally aligned with
        the rows of ``df``; rows with NaN duration are dropped.
    penalizer : requested penalty strength (see NOTE below).

    Returns
    -------
    (bundle, notes)
        bundle : dict with the fitted model, final predictor columns ("columns"), the
            fitted median imputer, the schema lists, duration/event column names and
            "version": 2.
        notes : status message.

    Raises
    ------
    ValueError
        If no predictor with non-zero variance remains, or fewer than 5 events exist.
    Import errors (lifelines is imported lazily) and errors from CoxPHFitter.fit
    propagate to the caller.

    NOTE(review): the effective penalizer is max(penalizer, 1.0), so the 0.1 default
    and any value below 1.0 are silently raised to 1.0 -- confirm this is intended.
    """
    from lifelines import CoxPHFitter # lazy
    from sklearn.impute import SimpleImputer # light, ok here too
    # build survival DF
    df_surv = df[feature_cols].copy().replace({pd.NA: np.nan})
    # coerce numeric/cat (unparseable numerics -> NaN; non-missing categoricals -> str)
    for c in num_cols:
        if c in df_surv.columns:
            df_surv[c] = pd.to_numeric(df_surv[c], errors="coerce")
    for c in cat_cols:
        if c in df_surv.columns:
            df_surv[c] = df_surv[c].astype("object")
            df_surv.loc[df_surv[c].isna(), c] = np.nan
            df_surv[c] = df_surv[c].map(lambda v: v if pd.isna(v) else str(v))
    # Assigned positionally: time_days/event01 must be in the same row order as df.
    df_surv["time_days"] = time_days
    df_surv["event"] = event01
    # Rows with an undetermined duration cannot be used by the Cox fit.
    df_surv = df_surv.dropna(subset=["time_days", "event"])
    duration_col = "time_days"
    event_col = "event"
    # one-hot (first level of each categorical dropped as the reference)
    df_surv_oh = pd.get_dummies(df_surv, columns=cat_cols, drop_first=True)
    if duration_col not in df_surv_oh.columns or event_col not in df_surv_oh.columns:
        raise ValueError("Survival DF missing duration/event columns after one-hot encoding.")
    # remove duplicate columns if any messy headers caused duplicates (first occurrence kept)
    df_surv_oh = df_surv_oh.loc[:, ~df_surv_oh.columns.duplicated()].copy()
    # predictor columns
    X_cols = [c for c in df_surv_oh.columns if c not in (duration_col, event_col)]
    # force numeric for Cox predictors
    df_surv_oh[X_cols] = df_surv_oh[X_cols].apply(pd.to_numeric, errors="coerce")
    # ---- Drop zero-variance columns (CRITICAL) ----
    # Do this before imputation to remove constant 0/1 one-hot columns.
    # Columns whose variance is NaN (e.g. all missing) are dropped too.
    var0 = df_surv_oh[X_cols].var(skipna=True)
    X_cols = [c for c in X_cols if pd.notna(var0.get(c, np.nan)) and var0[c] > 0]
    if len(X_cols) == 0:
        raise ValueError("All Cox predictors are zero-variance after one-hot; cannot fit survival model.")
    # ---- Impute predictors ----
    # ---- Phase 1: temporary impute only to compute variance reliably ----
    # The imputed values are written back into df_surv_oh, so the Phase 2 imputer
    # below is fitted on already-imputed data.
    imp_tmp = SimpleImputer(strategy="median")
    X_tmp = imp_tmp.fit_transform(df_surv_oh[X_cols])
    df_surv_oh.loc[:, X_cols] = pd.DataFrame(X_tmp, columns=X_cols, index=df_surv_oh.index)
    # ---- Drop any columns that became constant after imputation ----
    var1 = df_surv_oh[X_cols].var(skipna=True)
    X_cols = [c for c in X_cols if pd.notna(var1.get(c, np.nan)) and var1[c] > 0]
    if len(X_cols) == 0:
        raise ValueError("All Cox predictors became zero-variance after imputation; cannot fit survival model.")
    # ---- Basic sanity: events vs predictors ----
    n_events = int(np.sum(df_surv_oh[event_col].astype(int).to_numpy()))
    if n_events < 5:
        raise ValueError(f"Too few events for Cox model (events={n_events}).")
    # ---- Optional stabilizer: top-k most variable predictors ----
    # Caps the predictor count at max(10, 2 * n_events), keeping the highest-variance columns.
    if len(X_cols) > max(10, 2 * n_events):
        k = max(10, 2 * n_events)
        vars_sorted = var1[X_cols].sort_values(ascending=False)
        X_cols = list(vars_sorted.head(k).index)
    # ---- Phase 2 (CRITICAL): fit FINAL imputer on FINAL X_cols ----
    # This is the imputer stored in the bundle under "imputer".
    imp = SimpleImputer(strategy="median")
    X_final = imp.fit_transform(df_surv_oh[X_cols])
    df_surv_oh.loc[:, X_cols] = pd.DataFrame(X_final, columns=X_cols, index=df_surv_oh.index)
    # ---- Fit Cox ----
    # NOTE(review): penalizer is floored at 1.0 here; l1_ratio=0.0 means a pure L2 penalty.
    cph = CoxPHFitter(penalizer=float(max(penalizer, 1.0)), l1_ratio=0.0)
    cph.fit(df_surv_oh[[*X_cols, duration_col, event_col]],
            duration_col=duration_col,
            event_col=event_col)
    bundle = {
        "model": cph,
        "columns": X_cols,
        "imputer": imp,
        "cat_cols": cat_cols,
        "num_cols": num_cols,
        "feature_cols": feature_cols,
        "duration_col": duration_col,
        "event_col": event_col,
        "version": 2
    }
    return bundle, "Survival model trained successfully."
# ============================================================
# Model pipeline
# ============================================================
# =========================
# Classification pipeline builder
# =========================
def build_pipeline(
    num_cols,
    cat_cols,
    *,
    use_feature_selection: bool = True,
    l1_C: float = 1.0,
    selection_max_features: int | None = None,
    use_dimred: bool = False,
    svd_components: int = 50,
    svd_random_state: int = 42,
):
    """
    Build the classification pipeline.

    Steps:
      preprocess : numeric -> median impute + standard scale;
                   categorical -> most-frequent impute + one-hot (sparse, first level
                   dropped, unknown categories ignored). Other columns are dropped.
      vt         : VarianceThreshold(0.0) removes columns that never vary.
      svd        : optional TruncatedSVD (sparse-friendly). SHAP then explains
                   components instead of original features.
      select     : optional L1 (saga) logistic SelectFromModel, threshold="median";
                   ``selection_max_features`` (if given) additionally caps the number
                   of features kept.
      clf        : class-balanced LogisticRegression (lbfgs).

    No limit is placed on the number of raw input variables.
    """
    num_pipe = Pipeline([
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler())
    ])
    cat_pipe = Pipeline([
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=True, drop="first"))
    ])
    preprocessor = ColumnTransformer(
        transformers=[
            ("num", num_pipe, num_cols),
            ("cat", cat_pipe, cat_cols)
        ],
        remainder="drop",
        verbose_feature_names_out=False
    )

    steps = [("preprocess", preprocessor)]
    # Drop columns that never vary (common in one-hot).
    steps.append(("vt", VarianceThreshold(threshold=0.0)))

    if use_dimred:
        steps.append(("svd", TruncatedSVD(
            n_components=int(svd_components),
            random_state=int(svd_random_state)
        )))

    # With SVD enabled, selection operates on components rather than features.
    if use_feature_selection:
        selector_est = LogisticRegression(
            penalty="l1", solver="saga", C=float(l1_C),
            max_iter=5000, n_jobs=-1,
            class_weight="balanced"
        )
        # threshold="median" keeps features whose importance is >= the median;
        # max_features (None = no cap) limits how many of those are kept.
        selector = SelectFromModel(
            selector_est,
            threshold="median",
            max_features=None if selection_max_features is None else int(selection_max_features),
        )
        steps.append(("select", selector))

    # Final classifier (keep stable, probability-calibratable)
    clf = LogisticRegression(max_iter=5000, solver="lbfgs", class_weight="balanced")
    steps.append(("clf", clf))
    return Pipeline(steps)
# ============================================================
# Validation utilities
# ============================================================
def coerce_binary_label(y: pd.Series):
    """
    Map a two-valued label Series to a 0/1 int array.

    The positive class is the larger value for numeric (including bool) labels,
    and the lexicographically larger string otherwise.

    Returns
    -------
    (y01, positive_value)

    Raises
    ------
    ValueError
        If the non-missing labels do not contain exactly two distinct values.

    NOTE(review): missing labels are not dropped from the output. They never equal
    the positive value, so they are encoded as 0 (negative class).
    """
    observed = y.dropna()
    levels = list(pd.unique(observed))
    if len(levels) != 2:
        raise ValueError(f"Outcome Event must be binary (2 unique values). Found: {levels}")

    if pd.api.types.is_numeric_dtype(observed):
        positive = sorted(levels)[-1]
        return (y == positive).astype(int).to_numpy(), positive

    if observed.dtype == bool:
        return y.astype(int).to_numpy(), True

    positive = max(str(level) for level in levels)
    return (y.astype(str) == positive).astype(int).to_numpy(), positive
# ============================================================
# Classification metrics & threshold search
# ============================================================
def compute_classification_metrics(y_true, y_proba, threshold: float = 0.5):
    """
    Threshold predicted probabilities and compute confusion-matrix based metrics.

    A case is predicted positive when ``y_proba >= threshold``. Ratios with a zero
    denominator are reported as 0.0.

    Returns
    -------
    dict with threshold, tn/fp/fn/tp, sensitivity (= recall), specificity,
    precision, recall, f1, accuracy, balanced_accuracy.
    """
    # asarray so plain lists also work (list >= float raises TypeError).
    y_pred = (np.asarray(y_proba, dtype=float) >= threshold).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0  # recall, TPR
    specificity = tn / (tn + fp) if (tn + fp) else 0.0  # TNR
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    acc = accuracy_score(y_true, y_pred)
    bacc = balanced_accuracy_score(y_true, y_pred)
    return {
        "threshold": float(threshold),
        "tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp),
        "sensitivity": float(sensitivity),
        "specificity": float(specificity),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
        "accuracy": float(acc),
        "balanced_accuracy": float(bacc),
    }
def find_best_threshold(y_true, y_proba, metric="f1"):
    """
    Grid-search a decision threshold in [0.05, 0.95] (181 points, step 0.005).

    The threshold maximising ``metric`` (a key of compute_classification_metrics;
    unknown keys score 0.0) is returned. Ties keep the lowest threshold.

    Returns
    -------
    (best_threshold, best_value, metrics_at_best)
    """
    best_t, best_val, best_cls = 0.5, -1, None
    for t in map(float, np.linspace(0.05, 0.95, 181)):
        cls = compute_classification_metrics(y_true, y_proba, threshold=t)
        score = cls.get(metric, 0.0)
        if score > best_val:
            best_t, best_val, best_cls = t, score, cls
    return best_t, best_val, best_cls
def find_best_threshold_f1(y_true, y_proba, t_min=0.01, t_max=0.99, n=199):
    """
    Return the F1-maximising threshold over ``n`` evenly spaced points in [t_min, t_max].

    Returns
    -------
    (best_threshold, metrics_at_best)
        best_threshold defaults to 0.5 if the grid is empty. Ties keep the lowest threshold.
    """
    best_thr, best_f1, best_cls = 0.5, -1.0, None
    for raw in np.linspace(float(t_min), float(t_max), int(n)):
        thr = float(raw)
        cls = compute_classification_metrics(y_true, y_proba, threshold=thr)
        if cls["f1"] > best_f1:
            best_thr, best_f1, best_cls = thr, float(cls["f1"]), cls
    return best_thr, best_cls
#BOOTSTRAPPING UTILITIES
def bootstrap_internal_validation(
    df: pd.DataFrame,
    feature_cols: list[str],
    num_cols: list[str],
    cat_cols: list[str],
    *,
    n_boot: int = 1000,
    threshold: float = 0.5,
    use_feature_selection: bool = True,
    l1_C: float = 1.0,
    use_dimred: bool = False,
    svd_components: int = 50,
    random_state: int = 42,
):
    """
    Bootstrap internal validation using out-of-bag (OOB) evaluation.

    For each of ``n_boot`` replicates: draw N rows with replacement, fit a fresh
    pipeline (build_pipeline) on them, and score it on the rows not drawn.

    A replicate is skipped when any of the following holds:
      - the OOB set is empty;
      - either sample does not contain exactly two label values;
      - the OOB set contains one class only;
      - fitting or prediction raises.

    Returns
    -------
    df_boot : pd.DataFrame
        One row per evaluated replicate (iter, n_oob, roc_auc, avg_precision, brier
        and the threshold metrics at ``threshold``). It is empty, with the same
        columns, if no replicate could be evaluated.
    summary : dict
        metric -> {"mean", "median", "ci2.5", "ci97.5", "n"} (percentile intervals;
        NaN with n=0 when no values are available).
    n_skipped : int
        Number of replicates skipped.
    """
    result_columns = [
        "iter", "n_oob", "roc_auc", "avg_precision", "brier",
        "sensitivity", "specificity", "precision", "recall", "f1",
        "accuracy", "balanced_accuracy",
    ]
    rng = np.random.default_rng(int(random_state))
    n = len(df)
    idx_all = np.arange(n)
    rows = []
    n_skipped = 0

    def _prepare(frame):
        """Apply the training-time dtype coercions to the feature columns of frame."""
        X = frame[feature_cols].copy().replace({pd.NA: np.nan})
        for c in num_cols:
            X[c] = pd.to_numeric(X[c], errors="coerce")
        for c in cat_cols:
            X[c] = X[c].astype("object")
            X.loc[X[c].isna(), c] = np.nan
            X[c] = X[c].map(lambda v: v if pd.isna(v) else str(v))
        return X

    for b in range(int(n_boot)):
        if n == 0:
            # Nothing to resample (rng.integers(0, 0) would raise).
            n_skipped += 1
            continue
        boot_idx = rng.integers(0, n, size=n)  # sample N rows with replacement
        oob_mask = np.ones(n, dtype=bool)
        oob_mask[boot_idx] = False
        oob_idx = idx_all[oob_mask]
        # If OOB empty, skip (rare but possible with small N)
        if len(oob_idx) == 0:
            n_skipped += 1
            continue

        df_b = df.iloc[boot_idx]
        df_o = df.iloc[oob_idx]
        Xb = _prepare(df_b)
        Xo = _prepare(df_o)

        try:
            yb01, _ = coerce_binary_label(df_b[LABEL_COL].copy())
            yo01, _ = coerce_binary_label(df_o[LABEL_COL].copy())
        except Exception:
            n_skipped += 1
            continue

        # If OOB has only one class, AUC/PR-AUC invalid -> skip
        if len(np.unique(yo01)) < 2:
            n_skipped += 1
            continue

        pipe_b = build_pipeline(
            num_cols, cat_cols,
            use_feature_selection=use_feature_selection,
            l1_C=l1_C,
            use_dimred=use_dimred,
            svd_components=svd_components,
        )
        try:
            pipe_b.fit(Xb, yb01)
            proba_oob = pipe_b.predict_proba(Xo)[:, 1]
        except Exception:
            n_skipped += 1
            continue

        try:
            auc = float(roc_auc_score(yo01, proba_oob))
        except Exception:
            auc = np.nan
        pr = compute_pr_curve(yo01, proba_oob)
        cal = compute_calibration(yo01, proba_oob, n_bins=10, strategy="uniform")
        cls = compute_classification_metrics(yo01, proba_oob, threshold=float(threshold))
        rows.append({
            "iter": int(b),
            "n_oob": int(len(oob_idx)),
            "roc_auc": auc,
            "avg_precision": float(pr["average_precision"]),
            "brier": float(cal["brier"]),
            "sensitivity": float(cls["sensitivity"]),
            "specificity": float(cls["specificity"]),
            "precision": float(cls["precision"]),
            "recall": float(cls["recall"]),
            "f1": float(cls["f1"]),
            "accuracy": float(cls["accuracy"]),
            "balanced_accuracy": float(cls["balanced_accuracy"]),
        })

    # Explicit columns keep the summary working when every replicate was skipped:
    # pd.DataFrame([]) has no columns, and df_boot[col] would raise KeyError.
    df_boot = pd.DataFrame(rows, columns=result_columns)

    def summarise(col: str):
        s = df_boot[col].dropna()
        if len(s) == 0:
            return {"mean": np.nan, "median": np.nan, "ci2.5": np.nan, "ci97.5": np.nan, "n": 0}
        return {
            "mean": float(s.mean()),
            "median": float(s.median()),
            "ci2.5": float(np.quantile(s, 0.025)),
            "ci97.5": float(np.quantile(s, 0.975)),
            "n": int(len(s))
        }

    summary = {
        "roc_auc": summarise("roc_auc"),
        "avg_precision": summarise("avg_precision"),
        "brier": summarise("brier"),
        "sensitivity": summarise("sensitivity"),
        "specificity": summarise("specificity"),
        "f1": summarise("f1"),
        "balanced_accuracy": summarise("balanced_accuracy"),
    }
    return df_boot, summary, n_skipped
#PLOT CURVE UTILITIES
def compute_pr_curve(y_true, y_proba):
    """
    Precision-recall curve and average precision, as JSON-serialisable floats.

    Returns a dict with "average_precision", "precision", "recall" and "thresholds"
    (the arrays from sklearn's precision_recall_curve converted to lists).
    """
    prec, rec, thr = precision_recall_curve(y_true, y_proba)
    return {
        "average_precision": float(average_precision_score(y_true, y_proba)),
        "precision": list(map(float, prec)),
        "recall": list(map(float, rec)),
        "thresholds": list(map(float, thr)),
    }
def compute_calibration(y_true, y_proba, n_bins: int = 10, strategy: str = "uniform"):
    """
    Reliability-curve points (sklearn calibration_curve) plus the Brier score.

    Returns a JSON-serialisable dict: n_bins, strategy, prob_true (observed event
    fraction per bin), prob_pred (mean predicted probability per bin), brier.
    """
    observed, predicted = calibration_curve(
        y_true, y_proba, n_bins=n_bins, strategy=strategy
    )
    brier = float(brier_score_loss(y_true, y_proba))
    out = {"n_bins": int(n_bins), "strategy": str(strategy)}
    out["prob_true"] = list(map(float, observed))
    out["prob_pred"] = list(map(float, predicted))
    out["brier"] = brier
    return out
def decision_curve_analysis(y_true, y_proba, thresholds=None):
    """
    Net-benefit decision curve for a probabilistic classifier.

    For each threshold probability pt, with w = pt / (1 - pt) and N samples:
      model   : TP/N - FP/N * w   (treat when y_proba >= pt)
      all     : prevalence - (1 - prevalence) * w
      none    : 0

    Parameters
    ----------
    y_true : 0/1 outcomes.
    y_proba : predicted probabilities.
    thresholds : any iterable of threshold probabilities; defaults to 99 points in
        [0.01, 0.99]. It is materialised once, so one-shot iterables such as
        generators work (previously they were consumed by the loop and came back
        as an empty "thresholds" list).

    Returns
    -------
    dict with thresholds, net_benefit_model, net_benefit_all, net_benefit_none
    (lists of floats) and prevalence.
    """
    y_true = np.asarray(y_true).astype(int)
    y_proba = np.asarray(y_proba).astype(float)
    if thresholds is None:
        thresholds = np.linspace(0.01, 0.99, 99)
    thr = np.asarray(list(thresholds), dtype=float).ravel()

    n = len(y_true)
    prevalence = float(np.mean(y_true))

    # Rows = thresholds, columns = samples.
    treated = y_proba[None, :] >= thr[:, None]
    tp = np.sum(treated & (y_true == 1)[None, :], axis=1)
    fp = np.sum(treated & (y_true == 0)[None, :], axis=1)
    w = thr / (1.0 - thr)

    nb_model = (tp / n) - (fp / n) * w
    nb_all = prevalence - (1.0 - prevalence) * w

    return {
        "thresholds": [float(x) for x in thr],
        "net_benefit_model": [float(x) for x in nb_model],
        "net_benefit_all": [float(x) for x in nb_all],
        "net_benefit_none": [0.0] * len(thr),
        "prevalence": prevalence,
    }
def _train_and_persist_survival(df, feature_cols, num_cols, cat_cols, out_path="survival_bundle.joblib"):
    """
    Build survival targets, fit the CoxPH bundle and save it to ``out_path``.

    Returns ``(trained, notes, used_cols)``. ``used_cols`` is the column mapping from
    build_survival_targets (None if the targets could not be built). Failures while
    building targets or fitting/saving the model are reported through ``notes``
    instead of being raised.
    """
    # Drop any bundle from a previous run first: publish_to_hub uploads this file
    # whenever it exists, so a stale bundle would otherwise be shipped next to a
    # meta.json that reports survival as disabled.
    if os.path.exists(out_path):
        os.remove(out_path)
    try:
        time_days, event01, used_cols = build_survival_targets(df)
    except Exception:
        return False, "Survival columns missing or could not be parsed; survival model not trained.", None
    try:
        bundle, notes = train_survival_bundle(
            df=df,
            feature_cols=feature_cols,
            num_cols=num_cols,
            cat_cols=cat_cols,
            time_days=time_days,
            event01=event01,
            penalizer=0.1,
        )
        joblib.dump(bundle, out_path, compress=3)
    except Exception as e:
        return False, f"Survival model training failed: {e}", used_cols
    return True, notes, used_cols


def train_and_save(
    df: pd.DataFrame,
    feature_cols,
    num_cols,
    cat_cols,
    n_bins: int,
    cal_strategy: str,
    dca_points: int,
    *,
    use_feature_selection: bool,
    l1_C: float,
    use_dimred: bool,
    svd_components: int,
):
    """
    Train the classifier (and, when possible, a CoxPH survival model) and persist artifacts.

    Files written to the working directory: model.joblib, meta.json and, only when
    survival training succeeds, survival_bundle.joblib. Survival diagnostics are
    shown via Streamlit.

    Returns
    -------
    (pipe, meta, X_train, y_test, proba)
        proba holds predicted positive-class probabilities for the held-out split.
    """
    X = df[feature_cols].copy()
    y_raw = df[LABEL_COL].copy()
    X = X.replace({pd.NA: np.nan})
    for c in num_cols:
        X[c] = pd.to_numeric(X[c], errors="coerce")
    for c in cat_cols:
        X[c] = X[c].astype("object")
        X.loc[X[c].isna(), c] = np.nan
        X[c] = X[c].map(lambda v: v if pd.isna(v) else str(v))
    y01, pos_class = coerce_binary_label(y_raw)

    # Stratified 80/20 hold-out split (fixed seed) for the reported metrics.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y01, test_size=0.2, random_state=42, stratify=y01
    )
    pipe = build_pipeline(
        num_cols, cat_cols,
        use_feature_selection=use_feature_selection,
        l1_C=l1_C,
        use_dimred=use_dimred,
        svd_components=svd_components
    )
    pipe.fit(X_train, y_train)
    proba = pipe.predict_proba(X_test)[:, 1]

    # ----- Metrics -----
    roc_auc = float(roc_auc_score(y_test, proba))
    fpr, tpr, roc_thresholds = roc_curve(y_test, proba)
    cls_05 = compute_classification_metrics(y_test, proba, threshold=0.5)
    # NOTE: the best threshold is selected on the same test split it is reported
    # on, so the "@best" metrics are optimistic.
    best_thr, _, cls_best = find_best_threshold(y_test, proba, metric="f1")
    metrics = {
        "roc_auc": roc_auc,
        "n_train": int(len(X_train)),
        "n_test": int(len(X_test)),
        # reference @0.5
        "threshold@0.5": 0.5,
        "accuracy@0.5": cls_05["accuracy"],
        "balanced_accuracy@0.5": cls_05["balanced_accuracy"],
        "precision@0.5": cls_05["precision"],
        "recall@0.5": cls_05["recall"],
        "f1@0.5": cls_05["f1"],
        "sensitivity@0.5": cls_05["sensitivity"],
        "specificity@0.5": cls_05["specificity"],
        "confusion_matrix@0.5": {
            "tn": cls_05["tn"], "fp": cls_05["fp"],
            "fn": cls_05["fn"], "tp": cls_05["tp"],
        },
        # primary: best F1 threshold
        "best_threshold_by": "f1",
        "best_threshold": float(best_thr),
        "best_f1": float(cls_best["f1"]),
        "accuracy@best": cls_best["accuracy"],
        "balanced_accuracy@best": cls_best["balanced_accuracy"],
        "precision@best": cls_best["precision"],
        "recall@best": cls_best["recall"],
        "f1@best": cls_best["f1"],
        "sensitivity@best": cls_best["sensitivity"],
        "specificity@best": cls_best["specificity"],
        "confusion_matrix@best": {
            "tn": cls_best["tn"], "fp": cls_best["fp"],
            "fn": cls_best["fn"], "tp": cls_best["tp"],
        },
        "roc_curve": {
            "fpr": [float(x) for x in fpr],
            "tpr": [float(x) for x in tpr],
            "thresholds": [float(x) for x in roc_thresholds],
        },
        "pr_curve": compute_pr_curve(y_test, proba),
        "calibration": compute_calibration(y_test, proba, n_bins, cal_strategy),
        "decision_curve": decision_curve_analysis(y_test, proba, np.linspace(0.01, 0.99, dca_points)),
    }

    # ---- Survival model (CoxPH), trained on the full df ----
    survival_trained, surv_notes, surv_used_cols = _train_and_persist_survival(
        df, feature_cols, num_cols, cat_cols
    )

    joblib.dump(pipe, "model.joblib", compress=3)
    meta = {
        "framework": "Explainable-Acute-Leukemia-Mortality-Predictor",
        "model": "Logistic Regression",
        "created_at_utc": datetime.utcnow().isoformat(),
        "schema": {
            "features": feature_cols,  # now unlimited
            "numeric": num_cols,
            "categorical": cat_cols,
            "label": LABEL_COL,
        },
        "feature_reduction": {
            "variance_threshold": 0.0,
            "use_dimred": bool(use_dimred),
            "svd_components": int(svd_components) if use_dimred else None,
            "use_feature_selection": bool(use_feature_selection),
            "l1_C": float(l1_C) if use_feature_selection else None,
            "selection_method": "SelectFromModel(L1 saga, threshold=median)" if use_feature_selection else None,
            "note": "If SVD is enabled, SHAP becomes component-level (less interpretable)."
        },
        "shap_background": {
            "file": "background.csv",
            "max_rows": 200,
            "note": "Raw (pre-transform) background sample for SHAP LinearExplainer."
        },
        "survival": {
            "enabled": bool(survival_trained),
            "duration_units": "days",
            "model": "CoxPHFitter (lifelines) with one-hot for categoricals",
            "required_columns_any_of": {
                "start": SURV_START_COL_CANDS,
                "event": SURV_EVENT_COL_CANDS,
                "death": SURV_DEATH_COL_CANDS,
                "censor": SURV_CENSOR_COL_CANDS,
            },
            "used_columns": surv_used_cols,
            "notes": surv_notes
        },
        "positive_class": str(pos_class),
        "metrics": metrics,
    }
    with open("meta.json", "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    bundle_exists = os.path.exists("survival_bundle.joblib")
    st.write("Local survival bundle exists:", bundle_exists)
    if bundle_exists:
        st.write("Local survival bundle size (bytes):", os.path.getsize("survival_bundle.joblib"))
    st.json(meta.get("survival", {}))
    return pipe, meta, X_train, y_test, proba
# ============================================================
# SHAP
# ============================================================
def build_shap_explainer(pipe, X_bg, max_bg=200):
    """
    Build a SHAP LinearExplainer for the pipeline's final "clf" step.

    The raw background frame is down-sampled to ``max_bg`` rows (random_state=42)
    and passed through transform_before_clf, so SHAP works in the classifier's
    input space.

    Raises
    ------
    ValueError
        If the background is None or empty.
    """
    if X_bg is None or len(X_bg) == 0:
        raise ValueError("SHAP background is empty.")
    cap = int(max_bg)
    background = X_bg.sample(cap, random_state=42) if len(X_bg) > cap else X_bg
    model = pipe.named_steps["clf"]
    transformed_bg = transform_before_clf(pipe, background)
    return shap.LinearExplainer(
        model,
        transformed_bg,
        feature_perturbation="interventional",
    )
def safe_dense(Xt, max_rows: int = 200):
    """
    Return at most the first ``max_rows`` rows of ``Xt`` as a dense NumPy array.

    Rows are truncated before densifying, so a large sparse matrix is never
    expanded in full. Objects exposing ``toarray()`` (scipy sparse) use it; anything
    else goes through ``np.array``.
    """
    # COO and some other sparse formats do not support row slicing; CSR does.
    if hasattr(Xt, "tocsr"):
        Xt = Xt.tocsr()
    if hasattr(Xt, "shape") and Xt.shape[0] > max_rows:
        Xt = Xt[:max_rows]
    # Decide by capability rather than a blanket `except Exception`. That also
    # swallowed genuine toarray() failures (e.g. MemoryError) and then returned
    # np.array(sparse), a meaningless 0-d object array.
    if hasattr(Xt, "toarray"):
        return Xt.toarray()
    return np.array(Xt)
def ensure_model_repo_exists(model_repo_id: str, token: str):
    """
    Best-effort creation of a public Hugging Face model repo.

    Uses ``exist_ok=True``, so an existing repo is not an error. Any exception from
    ``create_repo`` (auth, network, ...) is swallowed and the caller gets no signal.
    Also sets the Hub HTTP/ETag timeout environment variables to 300 seconds.
    """
    # NOTE(review): huggingface_hub may read these variables when it is imported,
    # in which case setting them here has no effect on this process -- confirm.
    os.environ["HF_HUB_HTTP_TIMEOUT"] = "300"  # 5 minutes
    os.environ["HF_HUB_ETAG_TIMEOUT"] = "300"
    hub = HfApi(token=token)
    try:
        hub.create_repo(repo_id=model_repo_id, repo_type="model", private=False, exist_ok=True)
    except Exception:
        pass  # deliberate best-effort
def coerce_X_like_schema(X: pd.DataFrame, feature_cols: list[str], num_cols: list[str], cat_cols: list[str]) -> pd.DataFrame:
    """
    Re-key an inference frame onto the training feature schema and coerce dtypes.

    Each training column is matched to an inference column via norm_col (case,
    whitespace/newlines and underscores may differ) and copied under its exact
    training name, in feature_cols order. Training columns with no match are added
    as all-NaN and reported in a Streamlit warning (first 12 names) rather than
    failing. Numeric columns then go through pd.to_numeric(errors="coerce");
    non-missing categorical values are cast to str.
    """
    source = X.copy()
    by_key = {norm_col(name): name for name in source.columns}

    X_out = pd.DataFrame(index=source.index)
    missing = []
    for col in feature_cols:
        key = norm_col(col)
        if key not in by_key:
            missing.append(col)
            X_out[col] = np.nan
            continue
        X_out[col] = source[by_key[key]]

    if missing:
        shown = ", ".join(missing[:12])
        suffix = " ..." if len(missing) > 12 else ""
        st.warning(
            "Inference file is missing some training columns (filled as blank/NaN): "
            + shown + suffix
        )

    X_out = X_out.replace({pd.NA: np.nan})
    for c in num_cols:
        if c not in X_out.columns:
            continue
        X_out[c] = pd.to_numeric(X_out[c], errors="coerce")
    for c in cat_cols:
        if c not in X_out.columns:
            continue
        X_out[c] = X_out[c].astype("object")
        X_out.loc[X_out[c].isna(), c] = np.nan
        X_out[c] = X_out[c].map(lambda v: v if pd.isna(v) else str(v))
    return X_out
def get_shap_background_auto(model_repo_id: str, feature_cols: list[str], num_cols: list[str], cat_cols: list[str]) -> pd.DataFrame | None:
    """
    Load latest/background.csv from the model repo and coerce it to the schema.

    Returns None when the file cannot be loaded or lacks any feature column
    (exact name match).
    """
    df_bg = load_latest_background(model_repo_id)
    if df_bg is None or any(c not in df_bg.columns for c in feature_cols):
        return None
    return coerce_X_like_schema(df_bg, feature_cols, num_cols, cat_cols)
# ============================================================
# SHAP background persistence (best practice)
# ============================================================
def save_background_sample_csv(X_bg: pd.DataFrame, feature_cols: list[str], max_rows: int = 200, out_path: str = "background.csv"):
    """
    Write a raw (pre-transform) SHAP background sample to CSV and return its path.

    Only ``feature_cols`` are written, in that order. Frames longer than
    ``max_rows`` are randomly down-sampled (random_state=42). The index is not
    written.

    Raises
    ------
    ValueError
        If X_bg is None or empty.
    """
    if X_bg is None or len(X_bg) == 0:
        raise ValueError("X_bg is empty; cannot save background sample.")
    limit = int(max_rows)
    sample = X_bg[feature_cols].copy()
    if len(sample) > limit:
        sample = sample.sample(limit, random_state=42)
    sample.to_csv(out_path, index=False, encoding="utf-8")
    return out_path
def publish_background_to_hub(model_repo_id: str, version_tag: str, background_path: str = "background.csv"):
    """
    Upload the SHAP background CSV to releases/<version_tag>/ and latest/ in the model repo.

    Requires the HF_TOKEN environment variable (write access).

    Raises
    ------
    RuntimeError
        If HF_TOKEN is not set.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN not found. Add it in Space Settings → Secrets.")
    api = HfApi(token=token)

    version_bg_path = f"releases/{version_tag}/background.csv"
    latest_bg_path = "latest/background.csv"
    uploads = (
        (version_bg_path, f"Upload SHAP background ({version_tag})"),
        (latest_bg_path, f"Update latest SHAP background ({version_tag})"),
    )
    for repo_path, message in uploads:
        api.upload_file(
            path_or_fileobj=background_path,
            path_in_repo=repo_path,
            repo_id=model_repo_id,
            repo_type="model",
            commit_message=message,
        )
    return {
        "version_bg_path": version_bg_path,
        "latest_bg_path": latest_bg_path,
    }
def load_latest_background(model_repo_id: str) -> pd.DataFrame | None:
    """Download and read latest/background.csv; return None on any failure (missing file, network, parse)."""
    try:
        local_csv = hf_hub_download(
            repo_id=model_repo_id,
            repo_type="model",
            filename="latest/background.csv",
        )
        return pd.read_csv(local_csv)
    except Exception:
        return None
def load_background_by_version(model_repo_id: str, version_tag: str) -> pd.DataFrame | None:
    """Download and read releases/<version_tag>/background.csv; return None on any failure."""
    remote_name = f"releases/{version_tag}/background.csv"
    try:
        local_csv = hf_hub_download(
            repo_id=model_repo_id,
            repo_type="model",
            filename=remote_name,
        )
        return pd.read_csv(local_csv)
    except Exception:
        return None
def publish_to_hub(model_repo_id: str, version_tag: str):
    """
    Upload model.joblib and meta.json to releases/<version_tag>/ and latest/ in a
    Hugging Face *model* repository.

    survival_bundle.joblib is uploaded to the same two locations only when it exists
    locally AND the local meta.json does not report survival as disabled. This
    prevents a stale bundle left over from an earlier run from being published
    with a model whose survival training failed. Requires the HF_TOKEN environment
    variable (Space Secret).

    Returns a dict with the versioned and latest model/meta repo paths.

    Raises
    ------
    RuntimeError
        If HF_TOKEN is not set.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN not found. Add it in Space Settings → Secrets.")
    api = HfApi(token=token)

    def _upload(local_path, repo_path, message):
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=repo_path,
            repo_id=model_repo_id,
            repo_type="model",
            commit_message=message,
        )

    def _meta_allows_survival():
        # Missing/unreadable meta keeps the old behaviour (upload if the file exists).
        try:
            with open("meta.json", "r", encoding="utf-8") as fh:
                survival = json.load(fh).get("survival") or {}
            return bool(survival.get("enabled", True))
        except (OSError, ValueError, AttributeError):
            return True

    version_model_path = f"releases/{version_tag}/model.joblib"
    version_meta_path = f"releases/{version_tag}/meta.json"

    # Versioned artifacts, then the 'latest/' copies used for easy loading.
    _upload("model.joblib", version_model_path, f"Upload model artifacts ({version_tag})")
    _upload("meta.json", version_meta_path, f"Upload metadata ({version_tag})")
    _upload("model.joblib", "latest/model.joblib", f"Update latest model ({version_tag})")
    _upload("meta.json", "latest/meta.json", f"Update latest metadata ({version_tag})")

    if os.path.exists("survival_bundle.joblib") and _meta_allows_survival():
        _upload(
            "survival_bundle.joblib",
            f"releases/{version_tag}/survival_bundle.joblib",
            f"Upload survival model ({version_tag})",
        )
        _upload(
            "survival_bundle.joblib",
            "latest/survival_bundle.joblib",
            f"Update latest survival model ({version_tag})",
        )

    # Diagnostic: show which survival artifacts the repo now contains.
    files = api.list_repo_files(repo_id=model_repo_id, repo_type="model")
    st.write([f for f in files if "survival" in f])

    return {
        "version_model_path": version_model_path,
        "version_meta_path": version_meta_path,
        "latest_model_path": "latest/model.joblib",
        "latest_meta_path": "latest/meta.json",
    }
# Hugging Face *model* repository id for the published artifacts (releases/<version>/...
# and latest/...); presumably passed as model_repo_id to the publish/load helpers.
MODEL_REPO_ID = "Synav/Explainable-Acute-Leukemia-Mortality-Predictor"
def list_release_versions(model_repo_id: str):
    """
    Return the version tags that have a model at ``releases/<version>/model.joblib``.

    Only files exactly one directory below ``releases/`` count. A deeper path
    such as ``releases/v1/extra/model.joblib`` does not register ``v1``, because
    ``load_model_by_version`` reads ``releases/<version>/model.joblib`` and would
    fail for such a tag. Empty tags (``releases//model.joblib``) are ignored.

    Args:
        model_repo_id: Hugging Face model repository id.

    Returns:
        Distinct tags in reverse lexicographic order. For timestamp tags such as
        ``20240101-120000`` this is newest first.
    """
    # HF_TOKEN is optional here: token=None is passed when it is unset or empty.
    api = HfApi(token=os.environ.get("HF_TOKEN") or None)
    files = api.list_repo_files(repo_id=model_repo_id, repo_type="model")
    versions = set()
    for path in files:
        parts = path.split("/")
        if (
            len(parts) == 3
            and parts[0] == "releases"
            and parts[1]
            and parts[2] == "model.joblib"
        ):
            versions.add(parts[1])
    return sorted(versions, reverse=True)
def load_model_by_version(model_repo_id: str, version_tag: str):
    """
    Download and load one published release.

    Fetches ``releases/<version_tag>/model.joblib`` and then
    ``releases/<version_tag>/meta.json`` from the model repository.

    Returns:
        ``(pipe, meta)``: the object unpickled by joblib and the parsed JSON.
    """
    release_dir = f"releases/{version_tag}"
    model_path, meta_path = (
        hf_hub_download(
            repo_id=model_repo_id,
            repo_type="model",
            filename=f"{release_dir}/{artifact}",
        )
        for artifact in ("model.joblib", "meta.json")
    )
    # NOTE(review): joblib.load unpickles, which can execute code; only load from a trusted repo.
    pipe = joblib.load(model_path)
    with open(meta_path, "r", encoding="utf-8") as fh:
        meta = json.load(fh)
    return pipe, meta
def load_latest_survival_model(model_repo_id: str):
    """Download ``latest/survival_bundle.joblib`` from the model repo and return the unpickled object."""
    bundle_path = hf_hub_download(
        filename="latest/survival_bundle.joblib",
        repo_type="model",
        repo_id=model_repo_id,
    )
    bundle = joblib.load(bundle_path)
    return bundle
def load_survival_model_by_version(model_repo_id: str, version_tag: str):
    """Download ``releases/<version_tag>/survival_bundle.joblib`` and return the unpickled object."""
    remote_name = f"releases/{version_tag}/survival_bundle.joblib"
    local_path = hf_hub_download(repo_id=model_repo_id, repo_type="model", filename=remote_name)
    return joblib.load(local_path)
def load_latest_survival_bundle(model_repo_id: str):
    """
    Download ``latest/survival_bundle.joblib`` and return the unpickled object.

    Performs the same download and load as ``load_latest_survival_model``.
    """
    local_path = hf_hub_download(
        repo_id=model_repo_id,
        repo_type="model",
        filename="latest/survival_bundle.joblib",
    )
    bundle = joblib.load(local_path)
    return bundle
def load_latest_model(model_repo_id: str):
    """
    Download and load the artifacts under ``latest/`` in the model repository.

    Returns:
        ``(pipe, meta)``: the joblib-unpickled ``latest/model.joblib`` and the
        parsed ``latest/meta.json``.
    """
    downloaded = {}
    for artifact in ("model.joblib", "meta.json"):
        downloaded[artifact] = hf_hub_download(
            repo_id=model_repo_id,
            repo_type="model",
            filename=f"latest/{artifact}",
        )
    pipe = joblib.load(downloaded["model.joblib"])
    with open(downloaded["meta.json"], "r", encoding="utf-8") as fh:
        meta = json.load(fh)
    return pipe, meta
def get_post_selection_feature_names(pipe) -> list[str]:
    """
    Return feature names aligned to the exact columns seen by the final classifier.

    Handles ``preprocess -> vt (VarianceThreshold) -> select (SelectFromModel) -> clf``.
    Optional steps that are missing are skipped. If an ``svd`` step is present,
    the classifier sees SVD components, so component names are returned instead.

    If the preprocessor cannot report names, generic ``f0 .. f{n-1}`` names sized
    to the classifier's coefficient width are returned. These are built last,
    after the vt/select masks. The masks are sized to the preprocessed feature
    space, so applying them to classifier-sized placeholders would truncate
    or mislabel the list.

    Args:
        pipe: fitted pipeline exposing ``named_steps``.

    Returns:
        One name per classifier input column.

    Raises:
        ValueError: if the pipeline has no ``preprocess`` step.
    """
    steps = pipe.named_steps
    pre = steps.get("preprocess", None)
    if pre is None:
        raise ValueError("Pipeline missing 'preprocess' step.")

    # Post-one-hot names. None means "unknown" and is resolved by the fallback at the end.
    try:
        names = list(pre.get_feature_names_out())
    except Exception:  # best-effort: any failure falls back to generic names
        names = None

    vt = steps.get("vt", None)
    if names is not None and vt is not None and hasattr(vt, "get_support"):
        names = [n for n, keep in zip(names, vt.get_support()) if keep]

    # SVD replaces the feature space with components, so one-hot names no longer apply.
    if "svd" in steps:
        k = getattr(steps["svd"], "n_components", None)
        if not k:
            k = len(names) if names is not None else steps["clf"].coef_.shape[1]
        return [f"SVD_component_{i+1}" for i in range(int(k))]

    sel = steps.get("select", None)
    if names is not None and sel is not None and hasattr(sel, "get_support"):
        names = [n for n, keep in zip(names, sel.get_support()) if keep]

    if names is None:
        n_inputs = steps["clf"].coef_.shape[1]
        names = [f"f{i}" for i in range(n_inputs)]
    return names
def transform_before_clf(pipe, X):
    """
    Push X through every pipeline step that precedes the step named "clf".

    Steps without a ``transform`` method (e.g. "passthrough") are skipped. The
    result is the matrix the classifier itself would receive, which is what
    SHAP must be given.
    """
    out = X
    for step_name, transformer in pipe.steps:
        if step_name == "clf":
            return out
        if hasattr(transformer, "transform"):
            out = transformer.transform(out)
    return out
def get_final_feature_names(pipe) -> list[str]:
    """
    Return feature names aligned with the exact columns seen by the classifier.

    Handles ``preprocess -> vt -> (svd?) -> (select?) -> clf``:

    * names start from ``preprocess.get_feature_names_out()``;
    * the ``vt`` and ``select`` support masks are applied when those steps exist;
    * with an ``svd`` step the classifier sees components, so
      ``SVD_component_1..k`` is returned. If ``n_components`` is missing or
      non-positive, ``k`` falls back to the classifier's coefficient width
      instead of producing an empty list;
    * if the preprocessor cannot report names, generic ``f0..f{n-1}`` names
      sized to the classifier's coefficient width are returned.

    Raises:
        ValueError: if the pipeline has no ``preprocess`` step.
    """
    steps = pipe.named_steps
    pre = steps.get("preprocess", None)
    if pre is None:
        raise ValueError("Pipeline missing 'preprocess' step.")

    try:
        names = list(pre.get_feature_names_out())
    except Exception:  # best-effort: resolved by the generic fallback below
        names = None

    vt = steps.get("vt", None)
    if names is not None and vt is not None and hasattr(vt, "get_support"):
        names = [n for n, keep in zip(names, vt.get_support()) if keep]

    if "svd" in steps:
        k = int(getattr(steps["svd"], "n_components", 0) or 0)
        if k <= 0:
            # Unknown component count: size the list to what the classifier consumes.
            k = int(steps["clf"].coef_.shape[1])
        return [f"SVD_component_{i+1}" for i in range(k)]

    sel = steps.get("select", None)
    if names is not None and sel is not None and hasattr(sel, "get_support"):
        names = [n for n, keep in zip(names, sel.get_support()) if keep]

    if names is None:
        names = [f"f{i}" for i in range(steps["clf"].coef_.shape[1])]
    return names
def options_for(col: str, df: pd.DataFrame | None):
base = DEFAULT_OPTIONS.get(col, [])
excel = []
if df is not None and col in df.columns:
excel = (
df[col].dropna().astype(str).map(lambda x: x.strip())
.loc[lambda s: s != ""].unique().tolist()
)
# keep order: defaults first, then any extra from excel
out = []
for v in base + sorted(set(excel)):
if v not in out:
out.append(v)
return [""] + out
import re
# Canonical region labels used for analysis
# (UN-style continents: Africa, Americas, Asia, Europe, Oceania; subregions can be added later).
REGION_UNKNOWN = "Unknown"
# Normalization table. The Excel data contains nationality adjectives and spelling
# variants such as "Emarati", "FILIPINO" or "Indian " (trailing space). Keys are
# upper-case with internal whitespace collapsed to single spaces, which is the form
# normalize_country_name looks up.
NATIONALITY_ALIASES = {
    # Nationality adjectives / common variants -> canonical country name
    "EMARATI": "United Arab Emirates",
    "EMIRATI": "United Arab Emirates",
    "UAE": "United Arab Emirates",
    "FILIPINO": "Philippines",
    "PALESTINIAN": "Palestine",
    "SYRIAN": "Syria",
    "LEBANESE": "Lebanon",
    "JORDANIAN": "Jordan",
    "YEMENI": "Yemen",
    "EGYPTIAN": "Egypt",
    "SUDANESE": "Sudan",
    "ETHIOPIAN": "Ethiopia",
    "ERITREAN": "Eritrea",
    "SOMALI": "Somalia",
    "KENYAN": "Kenya",
    "UGANDAN": "Uganda",
    "GUINEAN": "Guinea",
    "MOROCCAN": "Morocco",
    "COMORAN": "Comoros",
    "INDIAN": "India",
    "PAKISTANI": "Pakistan",
    "BANGLADESH": "Bangladesh",
    "NEPALESE": "Nepal",
    "INDONESIAN": "Indonesia",
    "MALAYSIAN": "Malaysia",
    "AMERICAN": "United States",
    "USA": "United States",
    "U.S.A": "United States",
    "UNITED STATES OF AMERICA": "United States",
    # Additional demonyms / spellings (same country targets as the dropdown list)
    "BANGLADESHI": "Bangladesh",
    "COMORIAN": "Comoros",
    "SRI LANKAN": "Sri Lanka",
    "AFGHAN": "Afghanistan",
    "IRANIAN": "Iran",
    "IRAQI": "Iraq",
    "SAUDI": "Saudi Arabia",
    "OMANI": "Oman",
    "QATARI": "Qatar",
    "KUWAITI": "Kuwait",
    "BAHRAINI": "Bahrain",
    "TUNISIAN": "Tunisia",
    "ALGERIAN": "Algeria",
    "LIBYAN": "Libya",
    "NIGERIAN": "Nigeria",
    "BRITISH": "United Kingdom",
    "U.S.A.": "United States",
    "U.S.": "United States",
    "US": "United States",
}
def normalize_country_name(x: str) -> str | None:
"""Convert nationality-like strings to a canonical country name when possible."""
if x is None:
return None
s = str(x).strip()
if not s:
return None
s_up = re.sub(r"\s+", " ", s).strip().upper()
# map adjectives/variants
if s_up in NATIONALITY_ALIASES:
return NATIONALITY_ALIASES[s_up]
# If someone already entered a country name, keep it
# country_converter can handle many variants; pass through as-is
return s.strip()
from typing import Optional
def country_to_region(country: str | None) -> str:
"""
Lazy-import country_converter to reduce startup memory.
Returns one of: Africa, Americas, Asia, Europe, Oceania, Unknown
"""
if not country:
return REGION_UNKNOWN
import country_converter as coco # lazy
r = coco.convert(names=country, to="continent")
if not r or str(r).lower() in ("not found", "nan", "none"):
return REGION_UNKNOWN
if r == "America":
return "Americas"
return str(r)
def add_ethnicity_region(df: pd.DataFrame, eth_col: str = "Ethnicity", out_col: str = "Ethnicity_Region") -> pd.DataFrame:
    """
    Add a continent column derived from a nationality/ethnicity column.

    ``df[out_col]`` is written in place and the same DataFrame is returned. When
    *eth_col* is absent, *out_col* is filled with REGION_UNKNOWN. Otherwise each
    value is passed through normalize_country_name and then country_to_region.

    country_to_region results are memoized per distinct normalized value. A cohort
    with many repeated nationalities therefore triggers one country_converter
    lookup per distinct name rather than one per row.
    """
    if eth_col not in df.columns:
        df[out_col] = REGION_UNKNOWN
        return df

    memo: dict = {}

    def _region(country):
        if country not in memo:
            memo[country] = country_to_region(country)
        return memo[country]

    df[out_col] = df[eth_col].map(normalize_country_name).map(_region)
    return df
# ============================================================
# Streamlit UI
# ============================================================
def is_admin() -> bool:
    """
    Return True when the admin key entered in the UI matches the ADMIN_KEY secret.

    Gates the sensitive actions (Train + Publish). Returns False when the
    ADMIN_KEY environment variable (Space secret) is unset or empty, so a missing
    secret never unlocks admin mode.

    The comparison uses hmac.compare_digest on UTF-8 bytes. compare_digest only
    accepts ASCII-only ``str``, so comparing raw strings would raise TypeError
    (and crash the page) when a non-ASCII character is typed or pasted. With
    bytes, such input simply fails authentication.
    """
    secret = os.environ.get("ADMIN_KEY", "")
    if not secret:
        return False
    entered = st.session_state.get("admin_key", "")
    return hmac.compare_digest(str(entered).encode("utf-8"), str(secret).encode("utf-8"))
# NOTE(review): Streamlit expects set_page_config before other st.* output; keep it
# as the first rendering call (is_admin above only defines a function).
st.set_page_config(page_title="Explainable-Acute-Leukemia-Mortality-Predictor", layout="wide")
# ---------------- Plot settings (global: Train + Predict) ----------------
# Sidebar controls shared by the plots: on-screen size/DPI, and the DPI of the PNG
# downloads passed to render_plot_with_download.
with st.sidebar:
st.markdown("### Plot settings")
plot_width = st.slider("Plot width (inches)", 4.0, 10.0, 5.5, 0.1)
plot_height = st.slider("Plot height (inches)", 2.5, 6.0, 3.6, 0.1)
plot_dpi_screen = st.slider("Screen DPI", 80, 200, 120, 10)
export_dpi = st.selectbox("Export DPI (PNG)", [300, 600, 900, 1200], index=1)
# (width, height) in inches; passed as make_fig(figsize=FIGSIZE, ...) by the plots below.
FIGSIZE = (plot_width, plot_height)
st.title("Explainable-Acute-Leukemia-Mortality-Predictor")
st.caption("Explainable clinical AI for mortality and outcome prediction in acute leukemia using SHAP-interpretable models")
# Landing-page help, expanded by default.
# NOTE(review): the expander title promises the "required Excel format", but the
# markdown inside only covers the single-patient Quick Start steps.
with st.expander("About this AI model, who can use it, and required Excel format", expanded=True):
st.markdown("""
### Quick Start (Single Patient)
1. Open **Predict + SHAP**
2. **Load model** → select *latest* (default)
3. Enter patient details (Core → Clinical → FISH → NGS)
4. Click **Predict single patient**
You will get:
• Mortality (os) probability
• Risk category
• 6-month, 1-year, 3-year survival with PLOT
• SHAP explanation
Edit values to instantly test different scenarios.
*For research/decision-support only — not a substitute for clinical judgment.*
""")
st.warning(
"Prediction will fail if feature names or variable types "
"do not exactly match the trained model schema."
)
# 🔐 Admin controls come AFTER schema explanation
# The password field stores its value in st.session_state["admin_key"], which
# is_admin() compares against the ADMIN_KEY secret.
with st.expander("Admin controls", expanded=False):
st.text_input("Admin key", type="password", key="admin_key")
st.caption("Training and publishing are enabled only after admin authentication.")
tab_train, tab_predict = st.tabs(["1️⃣ Train", "2️⃣ Predict + SHAP"])
# Session-state slots shared by both tabs: the fitted/loaded pipeline and its SHAP
# explainer. ("meta" is initialised later, inside the Predict tab.)
if "pipe" not in st.session_state:
st.session_state.pipe = None
if "explainer" not in st.session_state:
st.session_state.explainer = None
# ---------------- TRAIN ----------------
with tab_train:
st.subheader("Train model")
# Non-admins get this notice instead of the training controls.
if not is_admin():
st.info("Training and publishing are restricted. Use Predict + SHAP for inference.")
else:
st.markdown("### Feature reduction options")
use_feature_selection = st.checkbox(
"Drop columns that do not affect prediction (L1 feature selection)",
value=True,
key="train_use_feature_selection"
)
# Passed to train_and_save as l1_C; per the label, lower values keep fewer features.
# 1.0 is passed through unchanged when selection is off.
l1_C = st.slider(
"L1 selection strength (lower = fewer features)",
0.01, 10.0, 1.0, 0.01
) if use_feature_selection else 1.0
use_dimred = st.checkbox(
"Dimensionality reduction (TruncatedSVD) — reduces interpretability",
value=False
)
svd_components = st.slider(
"SVD components (only used if enabled)",
5, 300, 50, 5
) if use_dimred else 50
st.divider()
# then keep your file uploader + training button + publish block here
st.markdown("### Bootstrap internal validation (optional)")
do_bootstrap = st.checkbox(
"Enable bootstrapping (OOB internal validation)",
value=False,
key="train_bootstrap_enable"
)
n_boot = st.slider(
"Bootstrap iterations",
50, 2000, 1000, 50,
key="train_bootstrap_n"
) if do_bootstrap else 0
boot_thr = st.slider(
"Bootstrap classification threshold",
0.0, 1.0, 0.5, 0.01,
key="train_bootstrap_thr"
) if do_bootstrap else 0.5
#---------------------
train_file = st.file_uploader("Upload training Excel (.xlsx)", type=["xlsx"])
if train_file is None:
st.info("Upload a training Excel file to enable training.")
else:
df = pd.read_excel(train_file, engine="openpyxl")
# Trim stray whitespace from headers.
# NOTE(review): c.strip() raises AttributeError for non-string headers (e.g. numeric
# column names); the inference path uses str(c) for this reason.
df.columns = [c.strip() for c in df.columns]
feature_cols = get_feature_cols_from_df(df)
st.dataframe(df.head(), use_container_width=True)
# NOTE(review): duplicate of the call two lines above; one of them can be dropped.
feature_cols = get_feature_cols_from_df(df)
st.markdown("### Choose variable types (saved into the model)")
# NOTE(review): pre-selects the first 13 feature columns as numeric. This assumes the
# Excel column order; the user can change the selection below.
default_numeric = feature_cols[:13] # initial suggestion
num_cols = st.multiselect(
"Numeric variables (will be median-imputed + scaled)",
options=feature_cols,
default=default_numeric
)
# Everything not selected as numeric becomes categorical
cat_cols = [c for c in feature_cols if c not in num_cols]
st.write(f"Categorical variables (will be most-frequent-imputed + one-hot): {len(cat_cols)}")
st.caption("Note: The selected schema is stored with the trained model and must match inference files.")
st.markdown("### Evaluation settings")
# Passed to train_and_save for the calibration curve and the decision-curve grid.
n_bins = st.slider("Calibration bins", 5, 20, 10, 1)
cal_strategy = st.selectbox("Calibration binning strategy", ["uniform", "quantile"], index=0)
dca_points = st.slider("Decision curve points", 25, 200, 99, 1)
# NOTE(review): st.button is True only on the rerun triggered by the click. Any later
# widget change (e.g. the threshold slider further down) reruns the script with it
# False, so everything rendered under this branch disappears. Persist the results in
# st.session_state and render them outside the branch to keep them visible.
if st.button("Train model"):
with st.spinner("Training model..."):
# train_and_save returns the fitted pipeline, its metadata dict (including "metrics"),
# the raw training features, and the test labels / predicted probabilities.
pipe, meta, X_train, y_test, proba = train_and_save(
df, feature_cols, num_cols, cat_cols,
n_bins=n_bins, cal_strategy=cal_strategy, dca_points=dca_points,
use_feature_selection=use_feature_selection,
l1_C=l1_C,
use_dimred=use_dimred,
svd_components=svd_components
)
# --- Save background sample for SHAP (raw X_train) ---
# Best-effort: a failure only warns. The Publish step uploads background.csv if it exists.
try:
save_background_sample_csv(
X_bg=X_train,
feature_cols=feature_cols,
max_rows=200,
out_path="background.csv"
)
st.success("Saved SHAP background sample (background.csv).")
except Exception as e:
st.warning(f"Could not save SHAP background sample: {e}")
# Diagnostics: whether a local survival bundle exists (publish_to_hub uploads it when present).
st.write("Local survival bundle exists:", os.path.exists("survival_bundle.joblib"))
if os.path.exists("survival_bundle.joblib"):
st.write("Local survival bundle size (bytes):", os.path.getsize("survival_bundle.joblib"))
explainer = build_shap_explainer(pipe, X_train)
st.session_state.pipe = pipe
st.session_state.explainer = explainer
st.session_state.meta = meta
st.success("Training complete. model.joblib and meta.json created.")
st.divider()
st.subheader("Training performance (test split)")
m = meta["metrics"]
# Row 1: ROC AUC, then sensitivity / specificity / F1 at the threshold that maximises F1.
c1, c2, c3, c4 = st.columns(4)
c1.metric("ROC AUC", f"{m['roc_auc']:.3f}")
c2.metric("Sensitivity (best F1 thr)", f"{m['sensitivity@best']:.3f}")
c3.metric("Specificity (best F1 thr)", f"{m['specificity@best']:.3f}")
c4.metric("F1 (best)", f"{m['f1@best']:.3f}")
st.caption(f"Best threshold (max F1): {m['best_threshold']:.2f}")
# Row 2: metrics at the fixed 0.5 threshold, plus the test-set size.
c5, c6, c7, c8 = st.columns(4)
c5.metric("Precision", f"{m['precision@0.5']:.3f}")
c6.metric("Accuracy", f"{m['accuracy@0.5']:.3f}")
c7.metric("Balanced Acc", f"{m['balanced_accuracy@0.5']:.3f}")
c8.metric("Test N", m["n_test"])
# Optional bootstrap internal validation scored on out-of-bag (OOB) rows.
if do_bootstrap:
st.divider()
st.subheader(f"Bootstrap internal validation (n={n_boot}, OOB)")
with st.spinner("Running bootstrap..."):
df_boot, boot_summary, n_skipped = bootstrap_internal_validation(
df=df,
feature_cols=feature_cols,
num_cols=num_cols,
cat_cols=cat_cols,
n_boot=int(n_boot),
threshold=float(boot_thr),
use_feature_selection=use_feature_selection,
l1_C=l1_C,
use_dimred=use_dimred,
svd_components=svd_components,
random_state=42,
)
st.caption(f"Completed: {len(df_boot)} iterations. Skipped: {n_skipped} (empty/single-class OOB or fitting issues).")
# Display summary
st.json(boot_summary)
# Show per-iteration table (optional)
st.dataframe(df_boot.head(30), use_container_width=True)
# Download full bootstrap results
st.download_button(
"Download bootstrap results (CSV)",
df_boot.to_csv(index=False).encode("utf-8"),
file_name=f"bootstrap_oob_{n_boot}.csv",
mime="text/csv",
key="dl_bootstrap_csv"
)
# Save into meta.json so it persists with the model
# NOTE(review): setdefault is redundant; the key is overwritten on the next statement.
st.session_state.meta.setdefault("bootstrap", {})
st.session_state.meta["bootstrap"] = {
"method": "bootstrap_oob",
"n_boot": int(n_boot),
"threshold": float(boot_thr),
"skipped": int(n_skipped),
"summary": boot_summary,
}
meta = st.session_state.meta
# Rewrites the local meta.json (the file the Publish step uploads) with the bootstrap section.
with open("meta.json", "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
# Confusion matrix display
cm = m["confusion_matrix@0.5"]
cm_df = pd.DataFrame(
[[cm["tn"], cm["fp"]], [cm["fn"], cm["tp"]]],
index=["Actual 0", "Actual 1"],
columns=["Pred 0", "Pred 1"]
)
st.markdown("**Confusion Matrix (threshold = 0.5)**")
st.dataframe(cm_df)
# NOTE(review): raises KeyError if train_and_save did not add a "survival" entry to meta.
st.json(meta["survival"])
# TRAINING: ROC curve plot
# =========================
# All curves below are drawn from arrays precomputed into meta["metrics"] (m).
roc = m["roc_curve"]
fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
ax.plot(roc["fpr"], roc["tpr"])
# Diagonal reference line (chance-level classifier).
ax.plot([0, 1], [0, 1])
ax.set_xlabel("False Positive Rate (1 - Specificity)")
ax.set_ylabel("True Positive Rate (Sensitivity)")
ax.set_title(f"ROC Curve (AUC = {m['roc_auc']:.3f})")
render_plot_with_download(
fig,
title="ROC curve",
filename="roc_curve.png",
export_dpi=export_dpi,
key="dl_train_roc"
)
#Precision recall curve
# =========================
# TRAINING: PR curve plot
# =========================
pr = m["pr_curve"]
fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
ax.plot(pr["recall"], pr["precision"])
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title(f"PR Curve (AP = {pr['average_precision']:.3f})")
render_plot_with_download(
fig,
title="PR curve",
filename="pr_curve.png",
export_dpi=export_dpi,
key="dl_train_pr"
)
#Calibration plot
# =========================
# TRAINING: Calibration plot
# =========================
cal = m["calibration"]
fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
ax.plot(cal["prob_pred"], cal["prob_true"])
# Diagonal reference line (perfect calibration).
ax.plot([0, 1], [0, 1])
ax.set_xlabel("Mean predicted probability")
ax.set_ylabel("Observed event rate")
ax.set_title("Calibration curve")
render_plot_with_download(
fig,
title="Calibration curve",
filename="calibration_curve.png",
export_dpi=export_dpi,
key="dl_train_cal"
)
#Decision curve
# =========================
# TRAINING: Decision curve analysis plot
# =========================
dca = m["decision_curve"]
fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
ax.plot(dca["thresholds"], dca["net_benefit_model"], label="Model")
ax.plot(dca["thresholds"], dca["net_benefit_all"], label="Treat all")
ax.plot(dca["thresholds"], dca["net_benefit_none"], label="Treat none")
ax.set_xlabel("Threshold probability")
ax.set_ylabel("Net benefit")
ax.set_title("Decision curve analysis")
ax.legend()
render_plot_with_download(
fig,
title="Decision curve",
filename="decision_curve.png",
export_dpi=export_dpi,
key="dl_train_dca"
)
st.caption(
"If the model curve is above Treat-all and Treat-none across a threshold range, "
"the model provides net clinical benefit in that range."
)
st.divider()
st.subheader("Threshold analysis")
# NOTE(review): this uses y_test/proba from train_and_save, so it appears to sit inside
# the "Train model" button branch. Moving the slider triggers a rerun in which the
# button is False, so this section vanishes before the new threshold is displayed.
thr = st.slider("Decision threshold", 0.0, 1.0, 0.5, 0.01)
# Recompute threshold-based metrics quickly using stored probabilities
# You need y_test and proba in scope. Easiest is to store them in session_state during training.
st.session_state.y_test_last = y_test
st.session_state.proba_last = proba
# Always true at this point: both keys were assigned just above.
if "y_test_last" in st.session_state and "proba_last" in st.session_state:
cls = compute_classification_metrics(st.session_state.y_test_last, st.session_state.proba_last, threshold=thr)
st.write({
"Sensitivity": cls["sensitivity"],
"Specificity": cls["specificity"],
"Precision": cls["precision"],
"Recall": cls["recall"],
"F1": cls["f1"],
"Accuracy": cls["accuracy"],
"Balanced Accuracy": cls["balanced_accuracy"],
})
# ---------------- PUBLISH (only after training) ----------------
# Shown whenever a pipeline is in session state.
# NOTE(review): st.session_state.pipe is also set by "Load selected model" in the
# Predict tab, while publish_to_hub uploads the local model.joblib / meta.json files.
# Confirm those files come from a local training run before publishing.
if st.session_state.get("pipe") is not None:
st.divider()
st.subheader("Publish trained model to Hugging Face Hub")
# UTC timestamp tag; this format sorts lexicographically in chronological order,
# which the newest-first version listing relies on.
# NOTE(review): datetime.utcnow() is deprecated since Python 3.12; prefer
# datetime.now(timezone.utc).
default_version = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
# NOTE(review): the tag is used verbatim as a path segment; a "/" in it would create
# nested folders instead of a single releases/<version>/ folder.
version_tag = st.text_input(
"Version tag",
value=default_version,
help="Used as releases/<version>/ in the model repository",
)
if st.button("Publish model.joblib + meta.json to Model Repo", key="publish_btn"):
try:
with st.spinner("Uploading to Hugging Face Model repo..."):
paths = publish_to_hub(MODEL_REPO_ID, version_tag)
# Upload background.csv if it exists
if os.path.exists("background.csv"):
bg_paths = publish_background_to_hub(MODEL_REPO_ID, version_tag, background_path="background.csv")
paths.update(bg_paths)
else:
st.warning("background.csv not found; SHAP background will not be uploaded.")
st.success("Uploaded successfully to your model repository.")
st.json(paths)
except Exception as e:
st.error(f"Upload failed: {e}")
# ---------------- PREDICT ----------------
# Fixed calibration settings, presumably the Predict-tab counterparts of the Train tab's
# "Calibration bins" / "binning strategy" widgets (not referenced within this chunk).
PRED_N_BINS = 10
PRED_CAL_STRATEGY = "uniform"
with tab_predict:
st.subheader("Select a trained model (no retraining required)")
# Same value as the module-level MODEL_REPO_ID (re-assigned here).
MODEL_REPO_ID = "Synav/Explainable-Acute-Leukemia-Mortality-Predictor"
# Ensure session state keys exist
if "pipe" not in st.session_state:
st.session_state.pipe = None
if "meta" not in st.session_state:
st.session_state.meta = None
if "explainer" not in st.session_state:
st.session_state.explainer = None
# List available releases
# A listing failure is reported but not fatal: "latest" stays selectable.
try:
versions = list_release_versions(MODEL_REPO_ID)
except Exception as e:
versions = []
st.error(f"Could not list model versions: {e}")
choices = ["latest"] + versions if versions else ["latest"]
selected = st.selectbox("Choose model version", choices, index=0)
# Loading happens only on click; the result is kept in session state across reruns.
if st.button("Load selected model"):
try:
with st.spinner("Loading model from Hugging Face Hub..."):
if selected == "latest":
pipe, meta = load_latest_model(MODEL_REPO_ID)
else:
pipe, meta = load_model_by_version(MODEL_REPO_ID, selected)
st.session_state.pipe = pipe
st.session_state.meta = meta
st.session_state.explainer = None # rebuild later with inference data
st.success(f"Loaded model: {selected}")
except Exception as e:
st.error(f"Load failed: {e}")
st.divider()
# Nothing below can run without a pipeline and its meta.json; st.stop() ends this script run.
if st.session_state.get("pipe") is None or st.session_state.get("meta") is None:
st.warning("Load a model version above (it must include meta.json), then continue.")
st.stop()
pipe = st.session_state.pipe
meta = st.session_state.meta
# Try load survival model (optional)
# NOTE(review): this runs on every rerun (any widget change) and keys off the selectbox
# value `selected`, not the version last loaded with the button. The two diverge if the
# selectbox is changed without clicking "Load selected model". Consider loading the
# survival bundle inside the button handler instead.
try:
if selected == "latest":
st.session_state.surv_model = load_latest_survival_model(MODEL_REPO_ID)
else:
st.session_state.surv_model = load_survival_model_by_version(MODEL_REPO_ID, selected)
except Exception as e:
st.session_state.surv_model = None
st.warning(f"Survival bundle load failed: {e}")
# bundle = st.session_state.get("surv_model", None)
# bundle = st.session_state.get("surv_model", None)
# st.write("surv_model type:", type(bundle))
# st.write("surv_model keys:", list(bundle.keys()) if isinstance(bundle, dict) else None)
# if isinstance(bundle, dict) and bundle.get("model") is not None:
# st.write("Cox predictors:", len(bundle.get("columns", [])))
# st.write("Cox cat cols:", len(bundle.get("cat_cols", [])))
# st.write("Cox num cols:", len(bundle.get("num_cols", [])))
# st.write("Bundle version:", bundle.get("version"))
# cph = bundle["model"]
# st.write("Cox coef shape:", getattr(cph, "params_", None).shape if hasattr(cph, "params_") else None)
# else:
# cph = None
# Set to True to print whether the optional survival bundle was loaded.
DEBUG_SURV = False
if DEBUG_SURV:
    # `bundle` was only assigned in commented-out code, so enabling this flag raised
    # NameError; read it from session state here instead.
    bundle = st.session_state.get("surv_model", None)
    st.write("Survival bundle loaded:", isinstance(bundle, dict))
# 1) MUST come first: schema from meta
# Column lists the loaded model was trained with (meta["schema"]).
feature_cols = meta["schema"]["features"]
num_cols = meta["schema"]["numeric"]
cat_cols = meta["schema"]["categorical"]
# ------------------------------------------------------------
# SHAP background: prefer inference file, else HF background.csv
# ------------------------------------------------------------
# NOTE(review): df_inf is assigned by the inference uploader further down this script,
# so on each rerun this reads the value stored by the previous run.
df_inf = st.session_state.get("df_inf")
if df_inf is not None:
# use user cohort as background (optional)
X_bg = coerce_X_like_schema(df_inf, feature_cols, num_cols, cat_cols)
else:
# fall back to published background
X_bg = get_shap_background_auto(MODEL_REPO_ID, feature_cols, num_cols, cat_cols)
st.session_state.X_bg_for_shap = X_bg
# 2) Now we can build lookup
# norm_col(header) -> exact model column name (norm_col is defined elsewhere; presumably
# a formatting-insensitive key).
FEATURE_LOOKUP = {norm_col(c): c for c in feature_cols}
from datetime import date
# NOTE(review): MIN_DOB is not referenced within this chunk; the date pickers below use MIN_DATE (same value).
MIN_DOB = date(1900, 1, 1)
# Options for the "Ethnicity" dropdown (see DEFAULT_OPTIONS), plus "Other"/"Unknown" catch-alls.
ALL_COUNTRIES = [
"Afghanistan","Albania","Algeria","Andorra","Angola","Antigua and Barbuda",
"Argentina","Armenia","Australia","Austria","Azerbaijan","Bahamas","Bahrain",
"Bangladesh","Barbados","Belarus","Belgium","Belize","Benin","Bhutan","Bolivia",
"Bosnia and Herzegovina","Botswana","Brazil","Brunei","Bulgaria","Burkina Faso",
"Burundi","Cabo Verde","Cambodia","Cameroon","Canada","Central African Republic",
"Chad","Chile","China","Colombia","Comoros","Congo (Congo-Brazzaville)",
"Costa Rica","Cote d’Ivoire","Croatia","Cuba","Cyprus","Czechia",
"Democratic Republic of the Congo","Denmark","Djibouti","Dominica",
"Dominican Republic","Ecuador","Egypt","El Salvador","Equatorial Guinea",
"Eritrea","Estonia","Eswatini","Ethiopia","Fiji","Finland","France","Gabon",
"Gambia","Georgia","Germany","Ghana","Greece","Grenada","Guatemala","Guinea",
"Guinea-Bissau","Guyana","Haiti","Honduras","Hungary","Iceland","India",
"Indonesia","Iran","Iraq","Ireland","Israel","Italy","Jamaica","Japan","Jordan",
"Kazakhstan","Kenya","Kiribati","Kuwait","Kyrgyzstan","Laos","Latvia","Lebanon",
"Lesotho","Liberia","Libya","Liechtenstein","Lithuania","Luxembourg",
"Madagascar","Malawi","Malaysia","Maldives","Mali","Malta","Marshall Islands",
"Mauritania","Mauritius","Mexico","Micronesia","Moldova","Monaco","Mongolia",
"Montenegro","Morocco","Mozambique","Myanmar","Namibia","Nauru","Nepal",
"Netherlands","New Zealand","Nicaragua","Niger","Nigeria","North Korea",
"North Macedonia","Norway","Oman","Pakistan","Palau","Palestine","Panama",
"Papua New Guinea","Paraguay","Peru","Philippines","Poland","Portugal","Qatar",
"Romania","Russia","Rwanda","Saint Kitts and Nevis","Saint Lucia",
"Saint Vincent and the Grenadines","Samoa","San Marino","Sao Tome and Principe",
"Saudi Arabia","Senegal","Serbia","Seychelles","Sierra Leone","Singapore",
"Slovakia","Slovenia","Solomon Islands","Somalia","South Africa","South Korea",
"South Sudan","Spain","Sri Lanka","Sudan","Suriname","Sweden","Switzerland",
"Syria","Taiwan","Tajikistan","Tanzania","Thailand","Timor-Leste","Togo","Tonga",
"Trinidad and Tobago","Tunisia","Turkey","Turkmenistan","Tuvalu","Uganda",
"Ukraine","United Arab Emirates","United Kingdom","United States","Uruguay",
"Uzbekistan","Vanuatu","Vatican City","Venezuela","Vietnam","Yemen","Zambia",
"Zimbabwe",
"Other","Unknown"
]
# Default dropdown choices per feature. options_for() lists these first, then any extra
# values found in the inference Excel. Keys must match feature names exactly (e.g. the
# "Risk Assesment" spelling is also used in CANON_DROPDOWNS / DROPDOWN_FIELDS).
DEFAULT_OPTIONS = {
# ---------------- Demographics ----------------
"Gender": ["Male", "Female"],
"Ethnicity": ALL_COUNTRIES,
# ---------------- Disease ----------------
"Type of Leukemia": [
"ALL",
"AML",
"APL",
"CML",
"MPAL",
"Secondary AML",
"Other",
"Unknown"
],
"Risk Assesment": [
"Favorable",
"Intermediate",
"Adverse",
"Unknown"
],
# Integers (not strings), unlike the other option lists.
"ECOG": [0, 1, 2, 3, 4],
# ---------------- Molecular / Diagnostic ----------------
"MDx Leukemia Screen 1": [
"Negative",
"t(15;17) / PML-RARA",
"t(9;22) / BCR-ABL1",
"RUNX1-RUNX1T1",
"inv(16) / CBFB-MYH11",
"KMT2A-AFF1",
"t(1;19) / TCF3-PBX1",
"t(10;11)",
"Other",
"Not done",
"Unknown"
],
"Post-Induction MRD": [
"Negative",
"Positive",
"Pending",
"Not done",
"Unknown"
],
# ---------------- Treatment ----------------
"First_Induction": [
"3+7",
"3+7+GO",
"3+7+TKI",
"3+7+Midostaurin",
"ATO+ATRA",
"Aza+Ven",
"CALGB",
"CPX-351",
"Dexa+TKI",
"FLAG-IDA",
"HCVAD",
"HCVAD+TKI",
"HMA",
"Hyper-CVAD",
"Not fit",
"Pediatric protocol",
"Other",
"Unknown"
],
# ---------------- Clinical Yes/No ----------------
# (selections are converted to 1/0 by yesno_to_01 in the Clinical tab)
"Thrombocytopenia": ["No", "Yes"],
"Bleeding": ["No", "Yes"],
"B Symptoms": ["No", "Yes"],
"Lymphadenopathy": ["No", "Yes"],
"CNS Involvement": ["No", "Yes"],
"Extramedullary Involvement": ["No", "Yes"],
}
# Canonical dropdown fields (use your intended names)
CANON_DROPDOWNS = [
"Ethnicity",
"Type of Leukemia",
"Post-Induction MRD",
"First_Induction",
"Risk Assesment",
"MDx Leukemia Screen 1",
"ECOG",
"Gender",
]
# Convert canonical -> actual column names present in this loaded model
# (canonical names absent from the model's schema are skipped).
# NOTE(review): DROPDOWN_FIELDS_ACTUAL is not read within this chunk; confirm it is used later.
DROPDOWN_FIELDS_ACTUAL = []
for canon in CANON_DROPDOWNS:
key = norm_col(canon)
if key in FEATURE_LOOKUP:
DROPDOWN_FIELDS_ACTUAL.append(FEATURE_LOOKUP[key])
def alias_default_options(canon_name: str):
    """
    Register DEFAULT_OPTIONS[canon_name] under the model's actual column name.

    FEATURE_LOOKUP maps norm_col(name) to the column name used by the loaded
    model. If *canon_name* resolves to an actual column that has no options yet,
    that column gets the canonical option list (the same list object, not a copy).
    Existing entries are never overwritten.
    """
    lookup_key = norm_col(canon_name)
    if lookup_key not in FEATURE_LOOKUP:
        return
    actual = FEATURE_LOOKUP[lookup_key]
    if canon_name not in DEFAULT_OPTIONS or actual in DEFAULT_OPTIONS:
        return
    DEFAULT_OPTIONS[actual] = DEFAULT_OPTIONS[canon_name]
# Make the canonical option lists available under the model's actual column names.
# NOTE(review): "ECOG" and "Gender" (also in CANON_DROPDOWNS) are not aliased here.
alias_default_options("Ethnicity")
alias_default_options("Type of Leukemia")
alias_default_options("Post-Induction MRD")
alias_default_options("First_Induction")
alias_default_options("Risk Assesment")
alias_default_options("MDx Leukemia Screen 1")
# Fields you want as controlled dropdowns (rather than free text)
# (filtered further below to drop fields owned by dedicated tabs/widgets).
DROPDOWN_FIELDS = [
"Gender",
"Ethnicity",
"Type of Leukemia",
"ECOG",
"Risk Assesment",
"MDx Leukemia Screen 1",
"Post-Induction MRD",
"First_Induction",
]
# Yes/No (stored as 1/0) fields
YESNO_FIELDS = [
"Thrombocytopenia",
"Bleeding",
"B Symptoms",
"Lymphadenopathy",
"CNS Involvement",
"Extramedullary Involvement",
]
# Numeric fields with units (show units in the label)
# NOTE(review): the WBC key contains an embedded newline and trailing spaces, presumably
# mirroring the raw training header; confirm it matches meta["schema"]["features"].
UNITS = {
'WBCs on Admission\n ( x10^9/L) ': "x10^9/L",
"Hb on Admission (g/L)": "g/L",
"LDH on Admission (IU/L)": "IU/L",
"Pre-Induction bone marrow biopsy blasts %": "%",
"Age (years)": "years",
}
# FISH / NGS marker groups (multi-select -> sets multiple columns 1/0)
FISH_MARKERS = [
"fish_inv16/cbfb_myh11",
"fish_t8;21/runx1_runx1t1",
"fish_t15;17/PML-RARA",
"fish_KMT2A-MLL/11q23",
"fish_BCR-ABL1/t9;22",
"fish_ETV6-RUNX1/t12;21",
"fish_TCF3_PBX1/t1;19",
"fish_IGH/14q32rearr",
"fish_CRLF2 rearr",
"fish_NUP214/9q34",
"fish_Other",
"fish_del7q/monosomy7",
"fish_TP53/del17p",
"fish_del13q",
"fish_del11q",
]
# Count columns are treated as auto-managed (see TAB_MANAGED_FIELDS below).
FISH_COUNT_COL = "Number of FISH alteration"
NGS_MARKERS = [
"ngs_FLT3",
"ngs_NPM1",
"ngs_CEBPA",
"ngs_DNMT3A",
"ngs_IDH1",
"ngs_IDH2",
"ngs_TET2",
"ngs_TP53",
"ngs_RUNX1",
"ngs_Other",
"ngs_spliceosome",
]
NGS_COUNT_COL = "No. of ngs_mutation"
AGE_FEATURE = "Age (years)"
DX_DATE_FEATURE = "Date of 1st Bone Marrow biopsy (Date of Diagnosis)" # keep exact if your schema has trailing space
# =========================
# Fields managed by dedicated tabs (do NOT render duplicates in Core tab)
# =========================
def get_options_from_df(df: pd.DataFrame, col: str) -> list[str]:
    """
    Dropdown options from the uploaded inference Excel.

    Returns the sorted, distinct, whitespace-stripped string values of
    ``df[col]``, excluding missing cells and values that are empty after
    stripping. Returns ``[]`` when *df* is None or lacks *col*.
    """
    if df is None or col not in df.columns:
        return []
    as_text = df[col].dropna().astype(str).str.strip()
    return sorted(set(as_text[as_text != ""]))
# =========================
# After model is loaded
# =========================
# NOTE(review): pipe, meta and FEATURE_LOOKUP were already assigned identically earlier
# in this tab; these re-assignments are redundant.
pipe = st.session_state.pipe
meta = st.session_state.meta
# Map normalized name -> actual model column name
FEATURE_LOOKUP = {norm_col(c): c for c in feature_cols}
# Columns owned by dedicated widgets (dates, Clinical/FISH/NGS tabs) or derived
# automatically; the Core tab must not render them a second time.
TAB_MANAGED_FIELDS = set()
# Derived/auto fields (managed by DOB/Dx/CR widgets)
if AGE_FEATURE in feature_cols:
TAB_MANAGED_FIELDS.add(AGE_FEATURE)
if DX_DATE_FEATURE in feature_cols:
TAB_MANAGED_FIELDS.add(DX_DATE_FEATURE)
if "Date of 1st CR" in feature_cols:
TAB_MANAGED_FIELDS.add("Date of 1st CR")
# Clinical yes/no handled in Clinical tab
TAB_MANAGED_FIELDS.update([f for f in YESNO_FIELDS if f in feature_cols])
# FISH handled in FISH tab
TAB_MANAGED_FIELDS.update([f for f in FISH_MARKERS if f in feature_cols])
# NGS handled in NGS tab
TAB_MANAGED_FIELDS.update([f for f in NGS_MARKERS if f in feature_cols])
DROPDOWN_FIELDS = [f for f in DROPDOWN_FIELDS if f not in TAB_MANAGED_FIELDS]
# Optional marker counts: handled automatically
if FISH_COUNT_COL in feature_cols:
TAB_MANAGED_FIELDS.add(FISH_COUNT_COL)
if NGS_COUNT_COL in feature_cols:
TAB_MANAGED_FIELDS.add(NGS_COUNT_COL)
# ---------------------------
# Inference Excel (optional) - store in session_state for dropdown options
# ---------------------------
# The attached file is re-parsed on every rerun.
infer_file = st.file_uploader("Upload inference Excel (.xlsx) (optional, for dropdown options + batch prediction)", type=["xlsx"], key="infer_xlsx")
if infer_file:
st.session_state.df_inf = pd.read_excel(infer_file, engine="openpyxl")
# Normalize headers immediately after load: non-breaking spaces become spaces and any
# run of whitespace (including newlines) collapses to a single space.
st.session_state.df_inf.columns = [
" ".join(str(c).replace("\u00A0"," ").split())
for c in st.session_state.df_inf.columns
]
# Align the columns to the trained model's feature schema.
st.session_state.df_inf = align_columns_to_schema(
st.session_state.df_inf,
meta["schema"]["features"])
# Adds a derived "Ethnicity_Region" column (continent of the Ethnicity value).
# NOTE(review): added after schema alignment, so it is not in meta["schema"]["features"];
# confirm downstream code ignores extra columns.
st.session_state.df_inf = add_ethnicity_region(
st.session_state.df_inf,
"Ethnicity",
"Ethnicity_Region"
)
else:
# No file attached on this run: clear any previously stored cohort.
st.session_state.df_inf = None
df_for_options = st.session_state.df_inf
st.divider()
st.subheader("Enter patient details to get prediction (Example: DOB → Age, Dx date → prediction)")
from datetime import date
MIN_DATE = date(1900, 1, 1)
# Each date has an "unknown" checkbox; when it is ticked the date stays None.
c1, c2 = st.columns(2)
with c1:
dob_unknown = st.checkbox("DOB unknown", value=False, key="dob_unknown_chk")
dob = None
if not dob_unknown:
dob = st.date_input("Date of birth (DOB)", min_value=MIN_DATE, max_value=date.today(), key="dob_date")
with c2:
dx_unknown = st.checkbox("Diagnosis date unknown", value=False, key="dx_unknown_chk")
dx_date = None
if not dx_unknown:
dx_date = st.date_input("Date of 1st Bone Marrow biopsy (Diagnosis)", min_value=MIN_DATE, max_value=date.today(), key="dx_date")
# Optional additional date input for CR1 if present in feature_cols
cr1_date = None
if "Date of 1st CR" in feature_cols:
cr1_unknown = st.checkbox("1st CR date unknown", value=False, key="cr1_unknown")
if not cr1_unknown:
cr1_date = st.date_input("Date of 1st CR", min_value=MIN_DATE, max_value=date.today(), key="cr1_date")
# NOTE(review): age_years_at is defined elsewhere; presumably age at diagnosis, and a
# missing value when either date is None. Confirm.
derived_age = age_years_at(dob, dx_date)
def yesno_to_01(v: str):
    """Map a Yes/No selectbox value to a model-ready flag.

    Returns 1 for "Yes", 0 for "No", and NaN for anything else (e.g. the blank
    "" choice), so unanswered flags are passed on as missing values.
    """
    if v == "Yes":
        return 1
    return 0 if v == "No" else np.nan
# Create tabs FIRST (outside any form)
tab_core, tab_clin, tab_fish, tab_ngs = st.tabs(["Core", "Clinical (Yes/No)", "FISH", "NGS"])
# Storage for selections: one slot per model feature, in feature_cols order;
# NaN stands for "not provided".
values_by_index = [np.nan] * len(feature_cols)
with tab_fish:
    st.caption("Select one or many FISH alterations. Selected = 1, not selected = 0.")
    # Only markers that are model features are offered; the 0/1 values are
    # written into values_by_index after all tabs have been rendered.
    fish_selected = st.multiselect(
        "FISH alterations",
        options=[m for m in FISH_MARKERS if m in feature_cols],
        default=[],
        key="sp_fish_selected",
    )
with tab_ngs:
    st.caption("Select one or many NGS mutations. Selected = 1, not selected = 0.")
    # Same pattern as FISH: options restricted to model features.
    ngs_selected = st.multiselect(
        "NGS mutations",
        options=[m for m in NGS_MARKERS if m in feature_cols],
        default=[],
        key="sp_ngs_selected",
    )
with tab_clin:
    st.caption("Clinical flags: Yes=1, No=0")
    # One Yes/No selector per clinical flag the model uses; blank maps to NaN.
    for i, f in enumerate(feature_cols):
        if f in YESNO_FIELDS:
            v = st.selectbox(f, options=["", "No", "Yes"], index=0, key=f"sp_{i}_yn")
            values_by_index[i] = yesno_to_01(v)
with tab_core:
    ecog = st.selectbox("ECOG", options=[0, 1, 2, 3, 4], index=0, key="sp_ecog")
    for i, f in enumerate(feature_cols):
        # Derived date/age fields are also listed in TAB_MANAGED_FIELDS (so they
        # never get a generic input). They must be handled BEFORE the
        # TAB_MANAGED_FIELDS skip below; otherwise their values are never written
        # and the model always receives NaN for age / diagnosis date / CR1 date.

        # --- Age (auto from DOB & Dx date; manual entry if either is unknown) ---
        if f.strip() == AGE_FEATURE.strip():
            # pd.isna also covers a None result from age_years_at.
            if dob_unknown or dx_unknown or pd.isna(derived_age):
                v_age = st.number_input(
                    f"{f} (enter if DOB/Dx unknown)",
                    value=None,
                    step=1,
                    format="%d",
                    key=f"sp_{i}_age_manual",
                )
                values_by_index[i] = v_age
            else:
                # Display as an integer, store the exact (float) age.
                st.number_input(
                    f"{f} (auto from DOB & Dx date)",
                    value=int(round(derived_age)),
                    step=1,
                    format="%d",
                    key=f"sp_{i}_age_auto",
                    disabled=True,
                )
                values_by_index[i] = float(derived_age)
            continue
        # --- Diagnosis date (auto from dx_date input, stored as ISO string) ---
        if f.strip() == DX_DATE_FEATURE.strip():
            values_by_index[i] = np.nan if dx_date is None else dx_date.isoformat()
            st.text_input(
                f"{f} (auto)",
                value="" if dx_date is None else dx_date.isoformat(),
                disabled=True,
                key=f"sp_{i}_dx_show"
            )
            continue
        # --- Date of 1st CR (auto from cr1_date input, stored as ISO string) ---
        if f.strip() == "Date of 1st CR":
            values_by_index[i] = np.nan if cr1_date is None else cr1_date.isoformat()
            st.text_input(
                f"{f} (auto)",
                value="" if cr1_date is None else cr1_date.isoformat(),
                disabled=True,
                key=f"sp_{i}_cr_show"
            )
            continue
        # Skip fields filled by the Clinical / FISH / NGS tabs or automatically
        # (marker counts).
        if f in TAB_MANAGED_FIELDS:
            continue
        # ECOG mapped to int
        if f.strip() == "ECOG":
            values_by_index[i] = int(ecog)
            continue
        # Gender dropdown (blank -> NaN)
        if f == "Gender":
            v = st.selectbox("Gender", options=["", "Male", "Female"], index=0, key=f"sp_{i}_gender")
            values_by_index[i] = np.nan if v == "" else v
            continue
        # Other dropdown fields: options from options_for (defaults + values found
        # in the uploaded inference Excel, if any); blank -> NaN.
        if f in DROPDOWN_FIELDS_ACTUAL and f not in ("Gender", "ECOG"):
            opts = options_for(f, df_for_options)
            v = st.selectbox(f, options=opts, index=0, key=f"sp_{i}_dd")
            values_by_index[i] = np.nan if v == "" else v
            continue
        # Numeric inputs (empty input gives None, coerced to NaN at prediction time)
        if f in num_cols:
            unit = UNITS.get(f, None)
            label = f"{f} ({unit})" if unit else f
            if f == "Age (years)":
                v = st.number_input(label, value=None, step=1, format="%d", key=f"sp_{i}_num")
            else:
                v = st.number_input(label, value=None, format="%.2f", key=f"sp_{i}_num")
            values_by_index[i] = v
            continue
        # Categorical fallback: free text (blank -> NaN)
        if f in cat_cols:
            v = st.text_input(f, value="", key=f"sp_{i}_cat")
            values_by_index[i] = np.nan if v.strip() == "" else v
            continue
        # Other fallback: free text (blank -> NaN)
        v = st.text_input(f, value="", key=f"sp_{i}_other")
        values_by_index[i] = np.nan if v.strip() == "" else v
# Turn the FISH / NGS multiselect picks into 0/1 indicator features
# (1 = marker selected, 0 = not selected).
fish_set = set(fish_selected)
ngs_set = set(ngs_selected)
for pos, feat in enumerate(feature_cols):
    if feat in FISH_MARKERS:
        values_by_index[pos] = int(feat in fish_set)
    if feat in NGS_MARKERS:
        values_by_index[pos] = int(feat in ngs_set)
# Marker-count columns (when the model has them) = number of markers picked.
if FISH_COUNT_COL in feature_cols:
    values_by_index[feature_cols.index(FISH_COUNT_COL)] = len(fish_selected)
if NGS_COUNT_COL in feature_cols:
    values_by_index[feature_cols.index(NGS_COUNT_COL)] = len(ngs_selected)
st.divider()
st.subheader("Predict single patient")
# Metrics stored with the model; best_threshold (0.5 if absent) is the default
# for both threshold sliders below.
m = meta.get("metrics", {})
default_thr = float(m.get("best_threshold", 0.5))
# Cut-off that turns the single-patient probability into predicted_class.
thr_single = st.slider(
    "Classification threshold",
    0.0, 1.0, default_thr, 0.01,
    key="sp_thr"
)
# External validation threshold
# NOTE(review): created in the single-patient section but only read by the
# external-validation metrics in the batch section further down.
thr_ext = st.slider(
    "External validation threshold",
    0.0, 1.0, default_thr, 0.01,
    key="thr_ext"
)
# Probability cut-offs splitting the patient into Low / Intermediate / High risk.
low_cut_s, high_cut_s = st.slider(
    "Risk band cutoffs (low, high)",
    0.0, 1.0, (0.2, 0.8), 0.01,
    key="sp_risk_cuts"
)
# Ensure low <= high (defensive swap if the cut-offs arrive reversed)
if low_cut_s > high_cut_s:
    low_cut_s, high_cut_s = high_cut_s, low_cut_s
def band_one(p: float) -> str:
    """Return the risk band for probability p.

    "Low" below low_cut_s, "High" at or above high_cut_s, otherwise
    "Intermediate". The cut-offs are read at call time.
    """
    if p < low_cut_s:
        return "Low"
    if p >= high_cut_s:
        return "High"
    return "Intermediate"
# Submit button (no form needed; simpler + fewer state surprises)
if st.button("Predict single patient", key="sp_predict_btn"):
    # One-row frame in model feature order, coerced to the schema dtypes:
    # numeric -> float (invalid -> NaN), categorical -> str with NaN preserved.
    X_one = pd.DataFrame([values_by_index], columns=feature_cols).replace({pd.NA: np.nan})
    for c in num_cols:
        if c in X_one.columns:
            X_one[c] = pd.to_numeric(X_one[c], errors="coerce")
    for c in cat_cols:
        if c in X_one.columns:
            X_one[c] = X_one[c].astype("object")
            X_one.loc[X_one[c].isna(), c] = np.nan
            X_one[c] = X_one[c].map(lambda v: v if pd.isna(v) else str(v))
    # Probability of the positive class (column 1 of predict_proba).
    proba_one = float(pipe.predict_proba(X_one)[:, 1][0])
    st.success("Prediction generated.")
    st.metric("Predicted probability", f"{proba_one:.4f}")

    # ---- Survival prediction for this patient (if survival model loaded) ----
    bundle = st.session_state.get("surv_model", None)
    if isinstance(bundle, dict) and bundle.get("model") is not None:
        try:
            cph = bundle["model"]
            surv_cols = bundle.get("columns", [])
            # Build the Cox input row with the preprocessing recorded in the bundle.
            cox_feature_cols = bundle.get("feature_cols", feature_cols)
            cox_num_cols = bundle.get("num_cols", num_cols)
            cox_cat_cols = bundle.get("cat_cols", cat_cols)
            df_one_surv = X_one[cox_feature_cols].copy()
            for c in cox_num_cols:
                if c in df_one_surv.columns:
                    df_one_surv[c] = pd.to_numeric(df_one_surv[c], errors="coerce")
            for c in cox_cat_cols:
                if c in df_one_surv.columns:
                    df_one_surv[c] = df_one_surv[c].astype("object").map(
                        lambda v: v if pd.isna(v) else str(v)
                    )
            # One-hot encode ALL levels, then keep only the training predictor
            # columns. With drop_first=True a single row's only level would be
            # dropped, silently encoding every categorical as the reference
            # level. Selecting surv_cols discards the training reference dummy instead.
            df_one_surv_oh = pd.get_dummies(df_one_surv, columns=cox_cat_cols, drop_first=False)
            df_one_surv_oh = df_one_surv_oh.loc[:, ~df_one_surv_oh.columns.duplicated()].copy()
            # Align to training predictor columns; dummies absent for this patient are 0.
            df_one_surv_oh = df_one_surv_oh.reindex(columns=surv_cols, fill_value=0)
            imp = bundle.get("imputer", None)
            if imp is not None:
                X_imp = imp.transform(df_one_surv_oh)
                df_one_surv_oh = pd.DataFrame(X_imp, columns=surv_cols, index=df_one_surv_oh.index)
            else:
                df_one_surv_oh = df_one_surv_oh.fillna(0)
            # Survival curve for this patient: index = follow-up time (treated as
            # days below), a single column.
            surv_fn = cph.predict_survival_function(df_one_surv_oh)

            def surv_at(days: int) -> float:
                """Read S(days) off the step-function survival curve.

                Uses the latest curve time <= days, because the curve is constant
                between its time points. A nearest-time lookup could read a later
                point and under-report survival. Returns 1.0 before the first
                curve time.
                """
                times = np.asarray(surv_fn.index.values, dtype=float)
                eligible = np.flatnonzero(times <= days)
                if eligible.size == 0:
                    return 1.0
                j = int(eligible[np.argmax(times[eligible])])
                return float(surv_fn.iloc[j, 0])

            s6m = surv_at(180)
            s1y = surv_at(365)
            s2y = surv_at(730)
            s3y = surv_at(1095)
            st.subheader("Predicted survival probability")
            a, b, c, d = st.columns(4)
            a.metric("6 months", f"{s6m*100:.1f}%")
            b.metric("1 year", f"{s1y*100:.1f}%")
            c.metric("2 years", f"{s2y*100:.1f}%")
            d.metric("3 years", f"{s3y*100:.1f}%")
            fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
            ax.plot(surv_fn.index, surv_fn.values)
            ax.set_xlabel("Days from diagnosis")
            ax.set_ylabel("Survival probability")
            ax.set_ylim(0, 1)
            ax.set_title("Predicted survival curve (patient)")
            render_plot_with_download(
                fig,
                title="Patient survival curve",
                filename="patient_survival_curve.png",
                export_dpi=export_dpi,
                key="dl_patient_surv_curve"
            )
        except Exception as e:
            st.warning(f"Survival prediction could not be computed: {e}")
    else:
        st.info("Survival model not loaded/published yet (survival bundle missing).")

    out = X_one.copy()
    out["predicted_probability"] = proba_one
    pred_class = int(proba_one >= thr_single)
    out["predicted_class"] = pred_class
    out["risk_band"] = band_one(proba_one)
    # Show and offer the result before SHAP, so it stays available even when the
    # explanation cannot be computed.
    st.dataframe(out, use_container_width=True)
    st.download_button(
        "Download single patient result (CSV)",
        out.to_csv(index=False).encode("utf-8"),
        file_name="single_patient_prediction.csv",
        mime="text/csv",
        key="dl_sp_csv",
    )

    # ---- SHAP compute only (cached; plotted by the section below) ----
    X_bg = st.session_state.get("X_bg_for_shap")
    # Rebuild the explainer when the model version or background size changes.
    current_sig = (selected, None if X_bg is None else int(len(X_bg)))
    explainer = st.session_state.get("explainer")
    if explainer is None or st.session_state.get("explainer_sig") != current_sig:
        explainer = None
        if X_bg is not None:
            st.session_state.explainer = build_shap_explainer(pipe, X_bg)
            st.session_state.explainer_sig = current_sig
            explainer = st.session_state.explainer
    if explainer is None:
        # Clear any explanation cached for a previous patient so it is not shown
        # as this patient's. Report the problem without stopping the whole page.
        st.session_state.pop("shap_single_exp", None)
        st.error("SHAP background not available. Admin must publish latest/background.csv.")
    else:
        X_one_t = transform_before_clf(pipe, X_one)
        shap_vals = explainer.shap_values(X_one_t)
        if isinstance(shap_vals, list):
            shap_vals = shap_vals[1]  # positive class
        names = get_final_feature_names(pipe)
        # Fall back to generic names if they do not line up with the SHAP columns.
        if len(names) != np.shape(shap_vals)[1]:
            names = [f"f{k}" for k in range(np.shape(shap_vals)[1])]
        try:
            x_dense = X_one_t.toarray()[0]
        except Exception:
            x_dense = np.array(X_one_t)[0]
        base_arr = np.asarray(explainer.expected_value, dtype=float).reshape(-1)
        # With per-class base values, use the positive class to match shap_vals[1].
        base = float(base_arr[1] if base_arr.size > 1 else base_arr[0])
        # CACHE ONLY
        st.session_state.shap_single_exp = shap.Explanation(
            values=shap_vals[0],
            base_values=base,
            data=x_dense,
            feature_names=names,
        )
# ---- Always render cached SHAP (single patient) ----
if "shap_single_exp" in st.session_state:
    exp = st.session_state.shap_single_exp
    max_display_single = st.slider(
        "Top features to display (single patient)",
        5, 40, 20, 1,
        key="sp_single_max_display"
    )
    c1, c2 = st.columns(2)
    with c1:
        plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
        shap.plots.waterfall(exp, show=False, max_display=max_display_single)
        fig_w = plt.gcf()
        render_plot_with_download(
            fig_w,
            title="Single-patient SHAP waterfall",
            filename="single_patient_shap_waterfall.png",
            export_dpi=export_dpi,
            key="dl_sp_wf"
        )
    with c2:
        plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
        shap.plots.bar(exp, show=False, max_display=max_display_single)
        fig_b = plt.gcf()
        render_plot_with_download(
            fig_b,
            title="Single-patient SHAP bar",
            filename="single_patient_shap_bar.png",
            export_dpi=export_dpi,
            key="dl_sp_bar"
        )
    # --- After SHAP plots are displayed ---
    # st.markdown renders LaTeX only inside $...$ (not \( ... \)), and the text
    # is kept unindented so markdown does not turn it into a code block.
    with st.expander("How to interpret the SHAP explanation plots", expanded=True):
        st.markdown(r"""
**What is SHAP?**
SHAP (SHapley Additive exPlanations) decomposes a model prediction into **feature-wise contributions**.
Each feature pushes the prediction **toward higher risk** or **toward lower risk**, relative to the model’s baseline.

**What is $E[f(X)]$? (baseline / expected model output)**
- $E[f(X)]$ is the model’s **average output** over the reference population used by the explainer (typically the training set).
- It is the **starting point** of the explanation before any patient-specific information is applied.

**What is $f(x)$? (this patient’s model output)**
- $f(x)$ is the model’s **final output for this patient**, after adding all feature contributions to the baseline.
- For many classifiers (including logistic regression), SHAP waterfall plots often use the **log-odds (logit) scale** rather than raw probability.

**How the waterfall plot should be read**
- The plot starts at **$E[f(X)]$** and then adds/subtracts feature effects to arrive at **$f(x)$**.
- **Red bars** push the prediction **toward higher risk** (increase the model output).
- **Blue bars** push the prediction **toward lower risk / protective direction** (decrease the model output).
- The **bar length** indicates the **strength** of that feature’s influence for this patient.
- “**Other features**” is the combined net effect of features not shown individually.

**How the bar plot should be read (Top contributors)**
- Features are ranked by **absolute SHAP value** (largest impact first).
- **Positive SHAP** (right) increases predicted risk; **negative SHAP** (left) decreases predicted risk.

**Clinical cautions**
- SHAP explains the model’s behavior; it does **not** prove causality.
- Ensure variable definitions and patient population match the model’s intended use.
""")
# -----------------------------
# Batch prediction / validation
# -----------------------------
st.divider()
st.subheader("Batch prediction / External validation")
df_inf = st.session_state.df_inf
if df_inf is None:
    st.info("Upload an inference Excel to run batch prediction / external validation.")
else:
    # Re-read the loaded model from session state; nothing below can run without it.
    # (A second metadata check was removed: meta is already validated here, and an
    # empty meta would fail on meta["schema"] before reaching it.)
    meta = st.session_state.get("meta")
    pipe = st.session_state.get("pipe")
    if meta is None or pipe is None:
        st.error("Model not loaded. Please load a model version first.")
        st.stop()
    feature_cols = meta["schema"]["features"]
    num_cols = meta["schema"]["numeric"]
    cat_cols = meta["schema"]["categorical"]
    missing = [c for c in feature_cols if c not in df_inf.columns]
    if missing:
        st.error(f"Inference file is missing required feature columns: {missing}")
        st.stop()
    # Model input in schema order: numeric -> float (invalid -> NaN),
    # categorical -> str with NaN preserved.
    X_inf = coerce_X_like_schema(df_inf, feature_cols, num_cols, cat_cols)
    X_inf = X_inf.replace({pd.NA: np.nan})
    for c in num_cols:
        X_inf[c] = pd.to_numeric(X_inf[c], errors="coerce")
    for c in cat_cols:
        X_inf[c] = X_inf[c].astype("object")
        X_inf.loc[X_inf[c].isna(), c] = np.nan
        X_inf[c] = X_inf[c].map(lambda v: v if pd.isna(v) else str(v))
    # Positive-class probability for every inference row (reused further down).
    proba = pipe.predict_proba(X_inf)[:, 1]
st.divider()
st.subheader("External validation (if Outcome Event label is present)")
# NOTE(review): requires LABEL_COL to survive align_columns_to_schema at upload
# time — confirm that helper keeps non-feature columns.
if LABEL_COL in df_inf.columns:
    try:
        y_ext_raw = df_inf[LABEL_COL].copy()
        # Convert the raw label column to 0/1 (second return value unused here).
        y_ext01, _ = coerce_binary_label(y_ext_raw)
        # Threshold-free metrics: ROC AUC and the ROC curve points.
        roc_auc_ext = float(roc_auc_score(y_ext01, proba))
        fpr, tpr, roc_thresholds = roc_curve(y_ext01, proba)
        # Threshold metrics at the user-controlled external-validation threshold
        cls_ext = compute_classification_metrics(y_ext01, proba, threshold=float(thr_ext))
        pr_ext = compute_pr_curve(y_ext01, proba)
        cal_ext = compute_calibration(y_ext01, proba, n_bins=PRED_N_BINS, strategy=PRED_CAL_STRATEGY)
        dca_ext = decision_curve_analysis(y_ext01, proba)
        # Display headline metrics
        c1, c2, c3, c4 = st.columns(4)
        c1.metric("ROC AUC (external)", f"{roc_auc_ext:.3f}")
        c2.metric("Sensitivity", f"{cls_ext['sensitivity']:.3f}")
        c3.metric("Specificity", f"{cls_ext['specificity']:.3f}")
        c4.metric("F1", f"{cls_ext['f1']:.3f}")
        # Confusion matrix at thr_ext (rows = actual, columns = predicted)
        cm_df = pd.DataFrame(
            [[cls_ext["tn"], cls_ext["fp"]], [cls_ext["fn"], cls_ext["tp"]]],
            index=["Actual 0", "Actual 1"],
            columns=["Pred 0", "Pred 1"],
        )
        st.markdown("**Confusion Matrix (external)**")
        st.dataframe(cm_df)
        # -------------------------
        # EXTERNAL: ROC curve (with chance diagonal)
        # -------------------------
        fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
        ax.plot(fpr, tpr)
        ax.plot([0, 1], [0, 1])
        ax.set_xlabel("False Positive Rate (1 - Specificity)")
        ax.set_ylabel("True Positive Rate (Sensitivity)")
        ax.set_title(f"External ROC Curve (AUC = {roc_auc_ext:.3f})")
        render_plot_with_download(
            fig,
            title="External ROC curve",
            filename="external_roc_curve.png",
            export_dpi=export_dpi,
            key="dl_ext_roc"
        )
        # -------------------------
        # EXTERNAL: PR curve
        # -------------------------
        fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
        ax.plot(pr_ext["recall"], pr_ext["precision"])
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.set_title(f"External PR Curve (AP = {pr_ext['average_precision']:.3f})")
        render_plot_with_download(
            fig,
            title="External PR curve",
            filename="external_pr_curve.png",
            export_dpi=export_dpi,
            key="dl_ext_pr"
        )
        # -------------------------
        # EXTERNAL: Calibration (with perfect-calibration diagonal)
        # -------------------------
        fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
        ax.plot(cal_ext["prob_pred"], cal_ext["prob_true"])
        ax.plot([0, 1], [0, 1])
        ax.set_xlabel("Mean predicted probability")
        ax.set_ylabel("Observed event rate")
        ax.set_title("External calibration curve")
        render_plot_with_download(
            fig,
            title="External calibration curve",
            filename="external_calibration_curve.png",
            export_dpi=export_dpi,
            key="dl_ext_cal"
        )
        # -------------------------
        # EXTERNAL: DCA (model vs treat-all vs treat-none net benefit)
        # -------------------------
        fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
        ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_model"], label="Model")
        ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_all"], label="Treat all")
        ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_none"], label="Treat none")
        ax.set_xlabel("Threshold probability")
        ax.set_ylabel("Net benefit")
        ax.set_title("External decision curve analysis")
        ax.legend()
        render_plot_with_download(
            fig,
            title="External decision curve",
            filename="external_decision_curve.png",
            export_dpi=export_dpi,
            key="dl_ext_dca"
        )
    except Exception as e:
        # Any failure above (e.g. a label column roc_auc_score rejects) is
        # reported here instead of aborting the page.
        st.error(f"Could not compute external validation metrics: {e}")
else:
    st.info("No Outcome Event column found in the inference Excel, so external validation metrics cannot be computed.")
# Predicted probabilities: `proba` was already computed from X_inf when the batch
# section was set up (X_inf is unchanged since). Reuse it rather than running
# the pipeline a second time.
df_out = df_inf.copy()
df_out["predicted_probability"] = proba
# ---- Batch survival probabilities (if survival model loaded) ----
bundle = st.session_state.get("surv_model", None)
if isinstance(bundle, dict) and bundle.get("model") is not None:
    try:
        cph = bundle["model"]
        surv_cols = bundle["columns"]
        imp = bundle.get("imputer", None)
        cox_feature_cols = bundle.get("feature_cols", feature_cols)
        cox_num_cols = bundle.get("num_cols", num_cols)
        cox_cat_cols = bundle.get("cat_cols", cat_cols)
        df_surv_in = X_inf[cox_feature_cols].copy()
        for c in cox_num_cols:
            if c in df_surv_in.columns:
                df_surv_in[c] = pd.to_numeric(df_surv_in[c], errors="coerce")
        for c in cox_cat_cols:
            if c in df_surv_in.columns:
                df_surv_in[c] = df_surv_in[c].astype("object")
                df_surv_in.loc[df_surv_in[c].isna(), c] = np.nan
                df_surv_in[c] = df_surv_in[c].map(lambda v: v if pd.isna(v) else str(v))
        # One-hot encode ALL levels, then select the training columns.
        # drop_first=True here would drop whichever level sorts first IN THIS FILE.
        # That level need not be the training reference level, and a column with
        # a single level in this file would lose all of its information.
        df_surv_in_oh = pd.get_dummies(df_surv_in, columns=cox_cat_cols, drop_first=False)
        df_surv_in_oh = df_surv_in_oh.loc[:, ~df_surv_in_oh.columns.duplicated()].copy()
        # Align to training predictor columns; dummies absent from this file are 0.
        df_surv_in_oh = df_surv_in_oh.reindex(columns=surv_cols, fill_value=0)
        # impute
        if imp is not None:
            X_imp = imp.transform(df_surv_in_oh)
            df_surv_in_oh = pd.DataFrame(X_imp, columns=surv_cols, index=df_surv_in_oh.index)
        else:
            df_surv_in_oh = df_surv_in_oh.fillna(0)
        # One survival curve per row: index = follow-up time (treated as days
        # below), one column per inference row.
        surv_fn_all = cph.predict_survival_function(df_surv_in_oh)

        def surv_vec_at(days: int):
            """S(days) for every row, read off the step-function curves.

            Uses the latest curve time <= days (the curves are constant between
            time points). A nearest-time lookup could read a later point.
            Returns 1.0 for all rows before the first curve time.
            """
            times = np.asarray(surv_fn_all.index.values, dtype=float)
            eligible = np.flatnonzero(times <= days)
            if eligible.size == 0:
                return np.ones(surv_fn_all.shape[1], dtype=float)
            j = int(eligible[np.argmax(times[eligible])])
            return surv_fn_all.iloc[j, :].values.astype(float)

        df_out["survival_6m"] = surv_vec_at(180)
        df_out["survival_1y"] = surv_vec_at(365)
        df_out["survival_2y"] = surv_vec_at(730)
        df_out["survival_3y"] = surv_vec_at(1095)
    except Exception as e:
        st.warning(f"Batch survival probabilities could not be computed: {e}")
else:
    st.info("Survival model not loaded/published yet (survival bundle missing).")
# ==========================================================
# EXPORT: two-column CSV (true label + predicted probability), the input
# needed for NEJM-style ROC + Calibration + DCA panels.
# ==========================================================
if LABEL_COL not in df_out.columns:
    st.info("Outcome Event column not present — panel CSV export requires true labels.")
else:
    curves_df = df_out[[LABEL_COL, "predicted_probability"]].copy()
    st.download_button(
        "Download CSV for ROC/Calibration/DCA panel",
        data=curves_df.to_csv(index=False).encode("utf-8"),
        file_name="predictions_for_curves.csv",
        mime="text/csv",
        key="dl_predictions_for_curves",
    )
# Derive Ethnicity_Region from Ethnicity for the grouped analytics below.
df_out = add_ethnicity_region(df_out, eth_col="Ethnicity", out_col="Ethnicity_Region")
st.subheader("Grouped outcomes by region (analytics)")
if "Ethnicity_Region" in df_out.columns:
    # Per-region count / mean / median of the predicted probability.
    grp = (
        df_out.groupby("Ethnicity_Region")["predicted_probability"]
        .agg(["count", "mean", "median"])
        .reset_index()
    )
    st.dataframe(grp, use_container_width=True)
# --- classification + risk bands ---
st.divider()
st.subheader("Risk stratification")
# Default the batch cut-off to the model's tuned threshold (best_threshold in the
# model metadata, 0.5 if absent). This is the same default the single-patient
# and external-validation sliders use, so exported predicted_class values agree
# with the external confusion matrix unless the user moves this slider.
batch_default_thr = float((meta.get("metrics") or {}).get("best_threshold", 0.5))
# st.slider rejects a default outside [min, max]; keep it in range.
batch_default_thr = min(max(batch_default_thr, 0.0), 1.0)
thr = st.slider(
    "Decision threshold for classification",
    0.0, 1.0, batch_default_thr, 0.01,
    key="pred_thr"
)
df_out["predicted_class"] = (df_out["predicted_probability"] >= thr).astype(int)
# Batch risk bands: user-chosen (low, high) probability cut-offs.
low_cut, high_cut = st.slider(
    "Risk band cutoffs (low, high)",
    0.0, 1.0, (0.2, 0.8), 0.01,
    key="batch_risk_cuts"
)
# Defensive: make sure low <= high.
if low_cut > high_cut:
    low_cut, high_cut = high_cut, low_cut


def band(p: float) -> str:
    """Return "Low" below low_cut, "High" at or above high_cut, else "Intermediate"."""
    if p < low_cut:
        return "Low"
    if p >= high_cut:
        return "High"
    return "Intermediate"


df_out["risk_band"] = df_out["predicted_probability"].map(band)
st.dataframe(df_out.head())
st.download_button(
    "Download predictions",
    df_out.to_csv(index=False).encode(),
    "predictions.csv",
    "text/csv"
)
# Batch SHAP: explain the first MAX_BATCH rows of the inference file and cache
# the results in session_state for the rendering section below.
st.divider()
st.subheader("Batch SHAP (first 200 rows)")
# Upper bound on rows explained in one go.
MAX_BATCH = 200
n_rows = len(X_inf)
batch_n = min(MAX_BATCH, n_rows)
cA, cB, cC = st.columns([1, 1, 1])
with cA:
    do_batch = st.button(f"Compute batch SHAP for first {batch_n} rows", key="batch_shap_btn")
with cB:
    max_display_batch = st.slider("Top features to display (batch)",5, 40, 20, 1,key="batch_max_display")
with cC:
    show_beeswarm = st.checkbox("Show beeswarm (slower)", value=True, key="batch_beeswarm")
if do_batch:
    with st.spinner("Computing batch SHAP..."):
        X_batch = X_inf.iloc[:batch_n].copy()
        # Presumably the pipeline's preprocessed features (input of the final
        # classifier), i.e. the space the explainer works in — confirm in transform_before_clf.
        X_batch_t = transform_before_clf(pipe, X_batch)
        explainer = st.session_state.get("explainer")
        explainer_sig = st.session_state.get("explainer_sig")
        # Create a simple signature that changes if model changes or background changes
        # (using version + number of background rows is usually enough)
        current_sig = (
            selected,  # or meta.get("created_at_utc") or meta.get("metrics", {}).get("roc_auc")
            None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"]))
        )
        if explainer is None or explainer_sig != current_sig:
            X_bg = st.session_state.get("X_bg_for_shap")
            if X_bg is None:
                st.error("SHAP background not available. Admin must publish latest/background.csv.")
                # NOTE: st.stop() ends this script run, so nothing below it renders.
                st.stop()
            st.session_state.explainer = build_shap_explainer(pipe, X_bg)
            st.session_state.explainer_sig = current_sig
            explainer = st.session_state.explainer
        shap_vals_batch = explainer.shap_values(X_batch_t)
        # Some explainers return one array per class; keep the positive class.
        if isinstance(shap_vals_batch, list):
            shap_vals_batch = shap_vals_batch[1]
        names = get_final_feature_names(pipe)
        # Absolute sanity check: align names with SHAP columns
        if len(names) != shap_vals_batch.shape[1]:
            st.warning(
                f"Feature name mismatch: names={len(names)} vs shap_cols={shap_vals_batch.shape[1]}. "
                "Using generic names."
            )
            names = [f"f{i}" for i in range(shap_vals_batch.shape[1])]
        # Dense conversion once (used for summary + waterfalls)
        try:
            X_dense = safe_dense(X_batch_t, max_rows=200)
        except Exception:
            X_dense = np.array(X_batch_t)
        # Cache batch results so the plots below survive widget-triggered reruns.
        st.session_state.shap_batch_vals = shap_vals_batch
        st.session_state.shap_batch_X_dense = X_dense
        st.session_state.shap_batch_n = batch_n
        st.session_state.shap_batch_feature_names = names
        st.success(f"Batch SHAP computed for first {batch_n} rows.")
if "shap_batch_vals" in st.session_state:
    # Re-render from the cached batch SHAP results so the widgets below do not
    # trigger a recomputation.
    shap_vals_batch = st.session_state.shap_batch_vals  # shape: (batch_n, n_features)
    X_dense = st.session_state.shap_batch_X_dense
    batch_n = st.session_state.shap_batch_n
    names = st.session_state.shap_batch_feature_names
    st.divider()
    st.subheader("Export: Top SHAP features per row (batch)")
    top_k = st.slider("Top-K features per row", 3, 30, 10, 1, key="topk_export")
    # Optional: attach each batch row's predicted probability (from `proba`,
    # computed for all of X_inf earlier).
    include_proba = st.checkbox("Include predicted probability", value=True, key="include_proba_export")
    if st.button("Generate Top-K SHAP table", key="gen_topk_shap"):
        rows = []
        for i in range(batch_n):
            sv = shap_vals_batch[i]
            idx = np.argsort(np.abs(sv))[::-1][:top_k]  # top-k by absolute SHAP
            for j in idx:
                val = float(sv[j])
                rows.append({
                    "row_in_batch": int(i),
                    "feature": str(names[j]),
                    "shap_value": val,
                    "abs_shap_value": abs(val),
                    "direction": "↑" if val > 0 else ("↓" if val < 0 else "0"),
                })
        # Explicit columns keep the merge/sort below valid even if no rows were produced.
        df_topk = pd.DataFrame(
            rows,
            columns=["row_in_batch", "feature", "shap_value", "abs_shap_value", "direction"],
        )
        if include_proba:
            # The cached batch can be longer than the current `proba` when a shorter
            # inference file was uploaded after computing batch SHAP. Building the
            # frame from mismatched lengths would raise, so only rows that exist
            # are attached. Positions are within the batch, not Excel row numbers.
            n_proba = min(batch_n, len(proba))
            df_proba = pd.DataFrame({
                "row_in_batch": list(range(n_proba)),
                "predicted_probability": proba[:n_proba],
            })
            df_topk = df_topk.merge(df_proba, on="row_in_batch", how="left")
        # Sort nicely: each row block by importance
        df_topk = df_topk.sort_values(["row_in_batch", "abs_shap_value"], ascending=[True, False])
        st.dataframe(df_topk, use_container_width=True)
        st.download_button(
            "Download Top-K SHAP per row (CSV)",
            df_topk.to_csv(index=False).encode("utf-8"),
            file_name=f"shap_top{top_k}_per_row_first{batch_n}.csv",
            mime="text/csv",
            key="dl_topk_shap_csv"
        )
    st.markdown(f"### Global SHAP summary (first {batch_n} rows)")
    # BAR SUMMARY (global feature importance)
    plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
    shap.summary_plot(
        shap_vals_batch,
        features=X_dense,
        feature_names=names,
        plot_type="bar",
        max_display=max_display_batch,
        show=False
    )
    fig_bar = plt.gcf()
    render_plot_with_download(fig_bar, title="SHAP bar summary", filename="shap_summary_bar.png", export_dpi=export_dpi)
    # BEESWARM SUMMARY (optional)
    if show_beeswarm:
        plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
        shap.summary_plot(
            shap_vals_batch,
            features=X_dense,
            feature_names=names,
            max_display=max_display_batch,
            show=False
        )
        fig_swarm = plt.gcf()
        render_plot_with_download(fig_swarm, title="SHAP beeswarm", filename="shap_beeswarm.png", export_dpi=export_dpi)
    st.markdown("### Waterfall plots (batch)")
    rows_to_plot = st.multiselect(
        "Select rows (within the first batch) to plot waterfalls",
        options=list(range(batch_n)),
        # A default must be one of the options; there is none for an empty batch.
        default=[0] if batch_n > 0 else [],
        key="batch_rows_to_plot"
    )
    max_display_single = st.slider("Top features to display (single-row SHAP)", 5, 40, 20, 1, key="single_max_display")
    max_waterfalls = st.slider("Max waterfall plots to render", 1, 10, 3, 1, key="max_waterfalls")
    rows_to_plot = rows_to_plot[:max_waterfalls]
    explainer = st.session_state.get("explainer")
    if explainer is None:
        st.warning("SHAP explainer not available; recompute batch SHAP to plot waterfalls.")
    else:
        base_arr = np.asarray(explainer.expected_value, dtype=float).reshape(-1)
        # With per-class base values, use the positive class, matching the
        # positive-class SHAP values cached for the batch.
        base = float(base_arr[1] if base_arr.size > 1 else base_arr[0])
        for r in rows_to_plot:
            st.markdown(f"**Row {r} (within first {batch_n})**")
            exp = shap.Explanation(
                values=shap_vals_batch[r],
                base_values=base,
                data=X_dense[r],
                feature_names=names,
            )
            # Waterfall
            plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
            shap.plots.waterfall(exp, show=False, max_display=max_display_single)
            fig_w = plt.gcf()
            render_plot_with_download(
                fig_w,
                title=f"Batch SHAP waterfall (row {r})",
                filename=f"shap_waterfall_batch_row_{r}.png",
                export_dpi=export_dpi,
                key=f"dl_wf_batch_{r}"
            )
            # Bar
            plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
            shap.plots.bar(exp, show=False, max_display=max_display_single)
            fig_b = plt.gcf()
            render_plot_with_download(
                fig_b,
                title=f"Batch SHAP bar (row {r})",
                filename=f"shap_bar_batch_row_{r}.png",
                export_dpi=export_dpi,
                key=f"dl_bar_batch_{r}"
            )
# Single row SHAP block: explain one row of the inference file.
st.subheader("SHAP explanation")
# Remembered row index, clamped to the current file. An index chosen for an
# earlier, longer file would exceed max_value, and st.number_input raises then.
shap_row_default = min(max(int(st.session_state.get("shap_row", 0)), 0), max(len(X_inf) - 1, 0))
# 1) Form only computes / updates the cached explanation
with st.form("shap_form"):
    row = st.number_input("Row index", 0, len(X_inf) - 1, shap_row_default)
    explain_btn = st.form_submit_button("Generate SHAP explanation")
if explain_btn:
    st.session_state.shap_row = int(row)
    X_one = X_inf.iloc[[int(row)]]
    X_bg = st.session_state.get("X_bg_for_shap")
    # Rebuild the explainer when the model version or background size changes.
    current_sig = (selected, None if X_bg is None else int(len(X_bg)))
    explainer = st.session_state.get("explainer")
    if explainer is None or st.session_state.get("explainer_sig") != current_sig:
        explainer = None
        if X_bg is not None:
            st.session_state.explainer = build_shap_explainer(pipe, X_bg)
            st.session_state.explainer_sig = current_sig
            explainer = st.session_state.explainer
    if explainer is None:
        # Drop a previously cached row explanation so it is not shown for this
        # request. Report the problem without stopping the whole page.
        st.session_state.pop("shap_row_exp", None)
        st.error("SHAP background not available. Admin must publish latest/background.csv.")
    else:
        X_one_t = transform_before_clf(pipe, X_one)
        shap_vals = explainer.shap_values(X_one_t)
        if isinstance(shap_vals, list):
            shap_vals = shap_vals[1]  # positive class
        names = get_final_feature_names(pipe)
        if len(names) != shap_vals.shape[1]:
            names = [f"f{i}" for i in range(shap_vals.shape[1])]
        try:
            x_dense = X_one_t.toarray()[0]
        except Exception:
            x_dense = np.array(X_one_t)[0]
        base_arr = np.asarray(explainer.expected_value, dtype=float).reshape(-1)
        # With per-class base values, use the positive class to match shap_vals[1].
        base = float(base_arr[1] if base_arr.size > 1 else base_arr[0])
        # Cached under its own key, separate from the single-patient explanation
        # ("shap_single_exp"), so each section only shows its own result.
        st.session_state.shap_row_exp = shap.Explanation(
            values=shap_vals[0],
            base_values=base,
            data=x_dense,
            feature_names=names,
        )
# 2) Plot section OUTSIDE the form (reruns on slider changes)
if "shap_row_exp" in st.session_state:
    exp = st.session_state.shap_row_exp
    max_display_single_row = st.slider(
        "Top features to display (single row)",
        5, 40, int(st.session_state.get("shap_single_max", 20)), 1,
        key="shap_single_max"
    )
    c1, c2 = st.columns(2)
    with c1:
        st.markdown("**Waterfall**")
        plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
        shap.plots.waterfall(exp, show=False, max_display=max_display_single_row)
        fig_w = plt.gcf()
        render_plot_with_download(
            fig_w,
            title="SHAP waterfall",
            filename="shap_waterfall_row.png",
            export_dpi=export_dpi,
            key="dl_shap_wf_single"
        )
    with c2:
        st.markdown("**Top features**")
        plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
        shap.plots.bar(exp, show=False, max_display=max_display_single_row)
        fig_b = plt.gcf()
        render_plot_with_download(
            fig_b,
            title="SHAP bar",
            filename="shap_bar_row.png",
            export_dpi=export_dpi,
            key="dl_shap_bar_single"
        )
else:
    st.info("Submit a row index to generate the SHAP explanation.")