import json from datetime import datetime import numpy as np import pandas as pd import streamlit as st import joblib import os from huggingface_hub import hf_hub_download, HfApi import hmac from sklearn.metrics import ( roc_auc_score, accuracy_score, roc_curve, confusion_matrix, precision_score, recall_score, f1_score, balanced_accuracy_score, precision_recall_curve, average_precision_score, brier_score_loss ) import shap import matplotlib.pyplot as plt from sklearn.calibration import calibration_curve from sklearn.feature_selection import VarianceThreshold, SelectFromModel from sklearn.decomposition import TruncatedSVD from sklearn.pipeline import Pipeline from sklearn.compose import ColumnTransformer from sklearn.preprocessing import OneHotEncoder, StandardScaler from sklearn.impute import SimpleImputer from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split #Figures setting block # REPLACE make_fig with this (or add this and stop using plt.plot directly) def make_fig(figsize=(5.5, 3.6), dpi=120): fig, ax = plt.subplots(figsize=figsize, dpi=dpi) return fig, ax def fig_to_png_bytes(fig, dpi=600): import io buf = io.BytesIO() fig.savefig(buf, format="png", dpi=int(dpi), bbox_inches="tight") buf.seek(0) return buf.getvalue() from typing import Optional def render_plot_with_download( fig, *, title: str, filename: str, export_dpi: int = 600, key: Optional[str] = None ): png_bytes = fig_to_png_bytes(fig, dpi=export_dpi) st.pyplot(fig, clear_figure=False) st.download_button( label=f"Download {title} (PNG {export_dpi} dpi)", data=png_bytes, file_name=filename, mime="image/png", key=key or f"dl_{filename}" ) plt.close(fig) # ============================================================ # Fixed schema definition (PLACEHOLDER FRAMEWORK) # ============================================================ # ========================= # 1) No feature limit # ========================= LABEL_COL = "Outcome Event" # keep label column exactly as header "Outcome Event" in your Excel # ========================= # Survival columns (Excel headers) # ========================= SURV_START_COL_CANDS = ["Date of 1st Bone Marrow biopsy (Date of Diagnosis)", "Diagnosis Date", "Dx Date"] SURV_DEATH_COL_CANDS = ["death_date", "Death Date", "Date of Death"] SURV_CENSOR_COL_CANDS = ["last_followup_date", "Last Follow up date", "Last Follow-up Date", "Last Seen Date"] SURV_EVENT_COL_CANDS = ["Outcome Event", "Death", "Event"] def get_feature_cols_from_df(df: pd.DataFrame): """ Features = ALL columns except Outcome Event (no limit), preserving original order. """ if LABEL_COL not in df.columns: raise ValueError(f"Missing required label column '{LABEL_COL}'. " f"Ensure your Excel header row contains a column named '{LABEL_COL}'.") return [c for c in df.columns if c != LABEL_COL] from datetime import date def to_date_safe(x): """Accepts date/datetime/str; returns python date or None.""" if x is None or (isinstance(x, float) and np.isnan(x)): return None if isinstance(x, datetime): return x.date() if isinstance(x, date): return x # string try: return pd.to_datetime(x).date() except Exception: return None def age_years_at(dob, ref_date): """Age in years at ref_date. Returns float or np.nan.""" dob = to_date_safe(dob) ref_date = to_date_safe(ref_date) if dob is None or ref_date is None: return np.nan if ref_date < dob: return np.nan return (ref_date - dob).days / 365.25 def build_survival_targets(df: pd.DataFrame): start_col = find_col(df, SURV_START_COL_CANDS) death_col = find_col(df, SURV_DEATH_COL_CANDS) censor_col = find_col(df, SURV_CENSOR_COL_CANDS) event_col = find_col(df, SURV_EVENT_COL_CANDS) if start_col is None or event_col is None: raise ValueError("Survival requires at minimum Diagnosis date and Outcome/Event column.") start = pd.to_datetime(df[start_col], errors="coerce") # event indicator (treat YES/1/TRUE as event) ev_raw = df[event_col] if pd.api.types.is_numeric_dtype(ev_raw): ev = ev_raw.fillna(0).astype(int).eq(1) else: ev = (ev_raw.astype(str).str.strip().str.upper().isin(["YES","Y","1","TRUE","T","DEAD","DIED"])) death = pd.to_datetime(df[death_col], errors="coerce") if death_col else pd.Series([pd.NaT]*len(df)) last = pd.to_datetime(df[censor_col], errors="coerce") if censor_col else pd.Series([pd.NaT]*len(df)) end = death.where(ev, last) time_days = (end - start).dt.days.astype("float") time_days = time_days.where(time_days >= 0, np.nan) used = { "start": start_col, "event": event_col, "death": death_col, "censor": censor_col, } return time_days.to_numpy(), ev.astype(int).to_numpy(), used import re def align_columns_to_schema(df: pd.DataFrame, required_cols: list[str]) -> pd.DataFrame: """ Rename df columns to match required_cols using normalized matching. This fixes newline/space/underscore/case differences. """ # normalize inference columns df_cols = list(df.columns) norm_to_actual = {norm_col(c): c for c in df_cols} rename_map = {} for req in required_cols: k = norm_col(req) if k in norm_to_actual: src = norm_to_actual[k] if src != req: rename_map[src] = req if rename_map: df = df.rename(columns=rename_map) return df def norm_col(s: str) -> str: """ Normalize column names for robust matching: - strip leading/trailing whitespace - collapse multiple spaces - remove underscores - lower-case """ s = "" if s is None else str(s) s = s.replace("\u00A0", " ") # non-breaking space -> normal space s = re.sub(r"\s+", " ", s).strip() # collapse spaces s = s.replace("_", "") # ignore underscores return s.lower() def find_col(df: pd.DataFrame, candidates: list[str]) -> str | None: """Return the first column in df that matches any candidate (normalized), else None.""" lookup = {norm_col(c): c for c in df.columns} for cand in candidates: k = norm_col(cand) if k in lookup: return lookup[k] return None def train_survival_bundle( df: pd.DataFrame, feature_cols: list[str], num_cols: list[str], cat_cols: list[str], time_days: np.ndarray, event01: np.ndarray, *, penalizer: float = 0.1 ): """ Returns (bundle_dict, notes). Raises exceptions if hard-fail. Lazy-imports lifelines. """ from lifelines import CoxPHFitter # lazy from sklearn.impute import SimpleImputer # light, ok here too # build survival DF df_surv = df[feature_cols].copy().replace({pd.NA: np.nan}) # coerce numeric/cat for c in num_cols: if c in df_surv.columns: df_surv[c] = pd.to_numeric(df_surv[c], errors="coerce") for c in cat_cols: if c in df_surv.columns: df_surv[c] = df_surv[c].astype("object") df_surv.loc[df_surv[c].isna(), c] = np.nan df_surv[c] = df_surv[c].map(lambda v: v if pd.isna(v) else str(v)) df_surv["time_days"] = time_days df_surv["event"] = event01 df_surv = df_surv.dropna(subset=["time_days", "event"]) duration_col = "time_days" event_col = "event" # one-hot df_surv_oh = pd.get_dummies(df_surv, columns=cat_cols, drop_first=True) if duration_col not in df_surv_oh.columns or event_col not in df_surv_oh.columns: raise ValueError("Survival DF missing duration/event columns after one-hot encoding.") # remove duplicate columns if any messy headers caused duplicates df_surv_oh = df_surv_oh.loc[:, ~df_surv_oh.columns.duplicated()].copy() # predictor columns X_cols = [c for c in df_surv_oh.columns if c not in (duration_col, event_col)] # force numeric for Cox predictors df_surv_oh[X_cols] = df_surv_oh[X_cols].apply(pd.to_numeric, errors="coerce") # ---- Drop zero-variance columns (CRITICAL) ---- # Do this before imputation to remove constant 0/1 one-hot columns var0 = df_surv_oh[X_cols].var(skipna=True) X_cols = [c for c in X_cols if pd.notna(var0.get(c, np.nan)) and var0[c] > 0] if len(X_cols) == 0: raise ValueError("All Cox predictors are zero-variance after one-hot; cannot fit survival model.") # ---- Impute predictors ---- # ---- Phase 1: temporary impute only to compute variance reliably ---- imp_tmp = SimpleImputer(strategy="median") X_tmp = imp_tmp.fit_transform(df_surv_oh[X_cols]) df_surv_oh.loc[:, X_cols] = pd.DataFrame(X_tmp, columns=X_cols, index=df_surv_oh.index) # ---- Drop any columns that became constant after imputation ---- var1 = df_surv_oh[X_cols].var(skipna=True) X_cols = [c for c in X_cols if pd.notna(var1.get(c, np.nan)) and var1[c] > 0] if len(X_cols) == 0: raise ValueError("All Cox predictors became zero-variance after imputation; cannot fit survival model.") # ---- Basic sanity: events vs predictors ---- n_events = int(np.sum(df_surv_oh[event_col].astype(int).to_numpy())) if n_events < 5: raise ValueError(f"Too few events for Cox model (events={n_events}).") # ---- Optional stabilizer: top-k most variable predictors ---- if len(X_cols) > max(10, 2 * n_events): k = max(10, 2 * n_events) vars_sorted = var1[X_cols].sort_values(ascending=False) X_cols = list(vars_sorted.head(k).index) # ---- Phase 2 (CRITICAL): fit FINAL imputer on FINAL X_cols ---- imp = SimpleImputer(strategy="median") X_final = imp.fit_transform(df_surv_oh[X_cols]) df_surv_oh.loc[:, X_cols] = pd.DataFrame(X_final, columns=X_cols, index=df_surv_oh.index) # ---- Fit Cox ---- cph = CoxPHFitter(penalizer=float(max(penalizer, 1.0)), l1_ratio=0.0) cph.fit(df_surv_oh[[*X_cols, duration_col, event_col]], duration_col=duration_col, event_col=event_col) bundle = { "model": cph, "columns": X_cols, "imputer": imp, "cat_cols": cat_cols, "num_cols": num_cols, "feature_cols": feature_cols, "duration_col": duration_col, "event_col": event_col, "version": 2 } return bundle, "Survival model trained successfully." # ============================================================ # Model pipeline # ============================================================ # ========================= # 3) UPDATED pipeline builder # ========================= def build_pipeline( num_cols, cat_cols, *, use_feature_selection: bool = True, l1_C: float = 1.0, selection_max_features: int | None = None, use_dimred: bool = False, svd_components: int = 50, svd_random_state: int = 42, ): """ - No limit on raw variables. - Drops useless columns: (a) zero-variance after preprocessing (b) L1-based selection (optional) - Optional dimensionality reduction via TruncatedSVD (sparse-friendly). """ num_pipe = Pipeline([ ("imputer", SimpleImputer(strategy="median")), ("scaler", StandardScaler()) ]) cat_pipe = Pipeline([ ("imputer", SimpleImputer(strategy="most_frequent")), ("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=True, drop="first")) ]) preprocessor = ColumnTransformer( transformers=[ ("num", num_pipe, num_cols), ("cat", cat_pipe, cat_cols) ], remainder="drop", verbose_feature_names_out=False ) steps = [] steps.append(("preprocess", preprocessor)) # Drop columns that never vary (common in one-hot) steps.append(("vt", VarianceThreshold(threshold=0.0))) # Optional: dimensionality reduction (works with sparse) # WARNING: makes SHAP less interpretable (components instead of original features) if use_dimred: steps.append(("svd", TruncatedSVD( n_components=int(svd_components), random_state=int(svd_random_state) ))) # Optional: feature selection using L1 logistic (sparse-friendly if SVD is OFF; # if SVD is ON, selection happens on components) if use_feature_selection: selector_est = LogisticRegression( penalty="l1", solver="saga", C=float(l1_C), max_iter=5000, n_jobs=-1, class_weight="balanced" ) # If you want to cap features: set max_features and use threshold that keeps top coefficients. # SelectFromModel doesn't have direct "max_features" — simplest safe approach is threshold-based. # Keep threshold='median' as default; adjust if you want more aggressive pruning. selector = SelectFromModel(selector_est, threshold="median") steps.append(("select", selector)) # Final classifier (keep stable, probability-calibratable) clf = LogisticRegression(max_iter=5000, solver="lbfgs", class_weight="balanced") steps.append(("clf", clf)) return Pipeline(steps) # ============================================================ # Validation utilities # ============================================================ def coerce_binary_label(y: pd.Series): y_clean = y.dropna() uniq = list(pd.unique(y_clean)) if len(uniq) != 2: raise ValueError(f"Outcome Event must be binary (2 unique values). Found: {uniq}") if pd.api.types.is_numeric_dtype(y_clean): pos = sorted(uniq)[-1] return (y == pos).astype(int).to_numpy(), pos if y_clean.dtype == bool: return y.astype(int).to_numpy(), True uniq_str = sorted([str(u) for u in uniq]) pos = uniq_str[-1] return y.astype(str).eq(pos).astype(int).to_numpy(), pos # ============================================================ # Training + persistence # ============================================================ def compute_classification_metrics(y_true, y_proba, threshold: float = 0.5): y_pred = (y_proba >= threshold).astype(int) tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel() sensitivity = tp / (tp + fn) if (tp + fn) else 0.0 # recall, TPR specificity = tn / (tn + fp) if (tn + fp) else 0.0 # TNR precision = precision_score(y_true, y_pred, zero_division=0) recall = recall_score(y_true, y_pred, zero_division=0) f1 = f1_score(y_true, y_pred, zero_division=0) acc = accuracy_score(y_true, y_pred) bacc = balanced_accuracy_score(y_true, y_pred) return { "threshold": float(threshold), "tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp), "sensitivity": float(sensitivity), "specificity": float(specificity), "precision": float(precision), "recall": float(recall), "f1": float(f1), "accuracy": float(acc), "balanced_accuracy": float(bacc), } def find_best_threshold(y_true, y_proba, metric="f1"): thresholds = np.linspace(0.05, 0.95, 181) # step ~0.005 best_t, best_val, best_cls = 0.5, -1, None for t in thresholds: cls = compute_classification_metrics(y_true, y_proba, threshold=float(t)) val = cls.get(metric, 0.0) if val > best_val: best_val, best_t, best_cls = val, float(t), cls return best_t, best_val, best_cls def find_best_threshold_f1(y_true, y_proba, t_min=0.01, t_max=0.99, n=199): """ Returns threshold that maximizes F1 on (y_true, y_proba). """ thresholds = np.linspace(float(t_min), float(t_max), int(n)) best = {"threshold": 0.5, "f1": -1.0, "cls": None} for t in thresholds: cls = compute_classification_metrics(y_true, y_proba, threshold=float(t)) if cls["f1"] > best["f1"]: best = {"threshold": float(t), "f1": float(cls["f1"]), "cls": cls} return best["threshold"], best["cls"] #BOOTSTRAPPING UTILITIES def bootstrap_internal_validation( df: pd.DataFrame, feature_cols: list[str], num_cols: list[str], cat_cols: list[str], *, n_boot: int = 1000, threshold: float = 0.5, use_feature_selection: bool = True, l1_C: float = 1.0, use_dimred: bool = False, svd_components: int = 50, random_state: int = 42, ): """ Bootstrap internal validation using OOB evaluation. For each bootstrap replicate: - sample N rows with replacement - fit pipeline on bootstrap sample - evaluate on OOB rows (not sampled) Returns: - df_boot: per-iteration metrics - summary: mean, median, 2.5% and 97.5% CI for key metrics - n_skipped: count of iterations skipped due to empty OOB or single-class OOB """ rng = np.random.default_rng(int(random_state)) n = len(df) idx_all = np.arange(n) rows = [] n_skipped = 0 for b in range(int(n_boot)): boot_idx = rng.integers(0, n, size=n) # sample N rows with replacement oob_mask = np.ones(n, dtype=bool) oob_mask[boot_idx] = False oob_idx = idx_all[oob_mask] # If OOB empty, skip (rare but possible with small N) if len(oob_idx) == 0: n_skipped += 1 continue df_boot = df.iloc[boot_idx].copy() df_oob = df.iloc[oob_idx].copy() # Prepare X/y for boot and OOB Xb = df_boot[feature_cols].copy().replace({pd.NA: np.nan}) yb_raw = df_boot[LABEL_COL].copy() Xo = df_oob[feature_cols].copy().replace({pd.NA: np.nan}) yo_raw = df_oob[LABEL_COL].copy() # Numeric coercion for c in num_cols: Xb[c] = pd.to_numeric(Xb[c], errors="coerce") Xo[c] = pd.to_numeric(Xo[c], errors="coerce") # Categorical coercion for c in cat_cols: Xb[c] = Xb[c].astype("object") Xb.loc[Xb[c].isna(), c] = np.nan Xb[c] = Xb[c].map(lambda v: v if pd.isna(v) else str(v)) Xo[c] = Xo[c].astype("object") Xo.loc[Xo[c].isna(), c] = np.nan Xo[c] = Xo[c].map(lambda v: v if pd.isna(v) else str(v)) # Coerce labels try: yb01, _ = coerce_binary_label(yb_raw) yo01, _ = coerce_binary_label(yo_raw) except Exception: n_skipped += 1 continue # If OOB has only one class, AUC/PR-AUC invalid -> skip if len(np.unique(yo01)) < 2: n_skipped += 1 continue pipe_b = build_pipeline( num_cols, cat_cols, use_feature_selection=use_feature_selection, l1_C=l1_C, use_dimred=use_dimred, svd_components=svd_components, ) try: pipe_b.fit(Xb, yb01) proba_oob = pipe_b.predict_proba(Xo)[:, 1] except Exception: n_skipped += 1 continue # Metrics on OOB try: auc = float(roc_auc_score(yo01, proba_oob)) except Exception: auc = np.nan pr = compute_pr_curve(yo01, proba_oob) # returns dict with average_precision cal = compute_calibration(yo01, proba_oob, n_bins=10, strategy="uniform") # has brier cls = compute_classification_metrics(yo01, proba_oob, threshold=float(threshold)) rows.append({ "iter": int(b), "n_oob": int(len(oob_idx)), "roc_auc": auc, "avg_precision": float(pr["average_precision"]), "brier": float(cal["brier"]), "sensitivity": float(cls["sensitivity"]), "specificity": float(cls["specificity"]), "precision": float(cls["precision"]), "recall": float(cls["recall"]), "f1": float(cls["f1"]), "accuracy": float(cls["accuracy"]), "balanced_accuracy": float(cls["balanced_accuracy"]), }) df_boot = pd.DataFrame(rows) def summarise(col: str): s = df_boot[col].dropna() if len(s) == 0: return {"mean": np.nan, "median": np.nan, "ci2.5": np.nan, "ci97.5": np.nan, "n": 0} return { "mean": float(s.mean()), "median": float(s.median()), "ci2.5": float(np.quantile(s, 0.025)), "ci97.5": float(np.quantile(s, 0.975)), "n": int(len(s)) } summary = { "roc_auc": summarise("roc_auc"), "avg_precision": summarise("avg_precision"), "brier": summarise("brier"), "sensitivity": summarise("sensitivity"), "specificity": summarise("specificity"), "f1": summarise("f1"), "balanced_accuracy": summarise("balanced_accuracy"), } return df_boot, summary, n_skipped #PLOT CURVE UTILITIES def compute_pr_curve(y_true, y_proba): precision, recall, pr_thresholds = precision_recall_curve(y_true, y_proba) ap = average_precision_score(y_true, y_proba) return { "average_precision": float(ap), "precision": [float(x) for x in precision], "recall": [float(x) for x in recall], "thresholds": [float(x) for x in pr_thresholds], } def compute_calibration(y_true, y_proba, n_bins: int = 10, strategy: str = "uniform"): prob_true, prob_pred = calibration_curve( y_true, y_proba, n_bins=n_bins, strategy=strategy ) brier = brier_score_loss(y_true, y_proba) return { "n_bins": int(n_bins), "strategy": str(strategy), "prob_true": [float(x) for x in prob_true], "prob_pred": [float(x) for x in prob_pred], "brier": float(brier), } def decision_curve_analysis(y_true, y_proba, thresholds=None): y_true = np.asarray(y_true).astype(int) y_proba = np.asarray(y_proba).astype(float) if thresholds is None: thresholds = np.linspace(0.01, 0.99, 99) n = len(y_true) prevalence = float(np.mean(y_true)) nb_model = [] nb_all = [] nb_none = [] for pt in thresholds: y_pred = (y_proba >= pt).astype(int) tp = np.sum((y_pred == 1) & (y_true == 1)) fp = np.sum((y_pred == 1) & (y_true == 0)) w = pt / (1.0 - pt) nb_m = (tp / n) - (fp / n) * w nb_a = prevalence - (1.0 - prevalence) * w nb_n = 0.0 nb_model.append(float(nb_m)) nb_all.append(float(nb_a)) nb_none.append(float(nb_n)) return { "thresholds": [float(x) for x in thresholds], "net_benefit_model": nb_model, "net_benefit_all": nb_all, "net_benefit_none": nb_none, "prevalence": prevalence, } def train_and_save( df: pd.DataFrame, feature_cols, num_cols, cat_cols, n_bins: int, cal_strategy: str, dca_points: int, *, use_feature_selection: bool, l1_C: float, use_dimred: bool, svd_components: int,): X = df[feature_cols].copy() y_raw = df[LABEL_COL].copy() X = X.replace({pd.NA: np.nan}) for c in num_cols: X[c] = pd.to_numeric(X[c], errors="coerce") for c in cat_cols: X[c] = X[c].astype("object") X.loc[X[c].isna(), c] = np.nan X[c] = X[c].map(lambda v: v if pd.isna(v) else str(v)) y01, pos_class = coerce_binary_label(y_raw) # ---- Survival targets (time-to-event) ---- # Requires these columns to exist in the training Excel: # SURV_START_COL, SURV_DEATH_COL, SURV_CENSOR_COL surv_used_cols = None try: time_days, event01, surv_used_cols = build_survival_targets(df) except Exception: time_days, event01, surv_used_cols = None, None, None X_train, X_test, y_train, y_test = train_test_split( X, y01, test_size=0.2, random_state=42, stratify=y01 ) pipe = build_pipeline( num_cols, cat_cols, use_feature_selection=use_feature_selection, l1_C=l1_C, use_dimred=use_dimred, svd_components=svd_components ) pipe.fit(X_train, y_train) proba = pipe.predict_proba(X_test)[:, 1] # (metrics code unchanged...) # ----- METRICS BLOCK (MISSING) ----- roc_auc = float(roc_auc_score(y_test, proba)) fpr, tpr, roc_thresholds = roc_curve(y_test, proba) cls_05 = compute_classification_metrics(y_test, proba, threshold=0.5) best_thr, best_val, cls_best = find_best_threshold(y_test, proba, metric="f1") metrics = { "roc_auc": roc_auc, "n_train": int(len(X_train)), "n_test": int(len(X_test)), # reference @0.5 "threshold@0.5": 0.5, "accuracy@0.5": cls_05["accuracy"], "balanced_accuracy@0.5": cls_05["balanced_accuracy"], "precision@0.5": cls_05["precision"], "recall@0.5": cls_05["recall"], "f1@0.5": cls_05["f1"], "sensitivity@0.5": cls_05["sensitivity"], "specificity@0.5": cls_05["specificity"], "confusion_matrix@0.5": { "tn": cls_05["tn"], "fp": cls_05["fp"], "fn": cls_05["fn"], "tp": cls_05["tp"], }, # primary: best F1 threshold "best_threshold_by": "f1", "best_threshold": float(best_thr), "best_f1": float(cls_best["f1"]), "accuracy@best": cls_best["accuracy"], "balanced_accuracy@best": cls_best["balanced_accuracy"], "precision@best": cls_best["precision"], "recall@best": cls_best["recall"], "f1@best": cls_best["f1"], "sensitivity@best": cls_best["sensitivity"], "specificity@best": cls_best["specificity"], "confusion_matrix@best": { "tn": cls_best["tn"], "fp": cls_best["fp"], "fn": cls_best["fn"], "tp": cls_best["tp"], }, "roc_curve": { "fpr": [float(x) for x in fpr], "tpr": [float(x) for x in tpr], "thresholds": [float(x) for x in roc_thresholds], }, "pr_curve": compute_pr_curve(y_test, proba), "calibration": compute_calibration(y_test, proba, n_bins, cal_strategy), "decision_curve": decision_curve_analysis(y_test, proba, np.linspace(0.01, 0.99, dca_points)), } # ---- Train survival model (CoxPH) ---- from sklearn.impute import SimpleImputer survival_trained = False surv_notes = None surv_used_cols = None try: time_days, event01, surv_used_cols = build_survival_targets(df) except Exception: time_days, event01, surv_used_cols = None, None, None if time_days is not None and event01 is not None: try: bundle, surv_notes = train_survival_bundle( df=df, feature_cols=feature_cols, num_cols=num_cols, cat_cols=cat_cols, time_days=time_days, event01=event01, penalizer=0.1 ) joblib.dump(bundle, "survival_bundle.joblib", compress=3) survival_trained = True except Exception as e: survival_trained = False surv_notes = f"Survival model training failed: {e}" else: surv_notes = "Survival columns missing or could not be parsed; survival model not trained." joblib.dump(pipe, "model.joblib",compress=3) meta = { "framework": "Explainable-Acute-Leukemia-Mortality-Predictor", "model": "Logistic Regression", "created_at_utc": datetime.utcnow().isoformat(), "schema": { "features": feature_cols, # now unlimited "numeric": num_cols, "categorical": cat_cols, "label": LABEL_COL, }, "feature_reduction": { "variance_threshold": 0.0, "use_dimred": bool(use_dimred), "svd_components": int(svd_components) if use_dimred else None, "use_feature_selection": bool(use_feature_selection), "l1_C": float(l1_C) if use_feature_selection else None, "selection_method": "SelectFromModel(L1 saga, threshold=median)" if use_feature_selection else None, "note": "If SVD is enabled, SHAP becomes component-level (less interpretable)." }, "shap_background": { "file": "background.csv", "max_rows": 200, "note": "Raw (pre-transform) background sample for SHAP LinearExplainer." }, "survival": { "enabled": bool(survival_trained), "duration_units": "days", "model": "CoxPHFitter (lifelines) with one-hot for categoricals", "required_columns_any_of": { "start": SURV_START_COL_CANDS, "event": SURV_EVENT_COL_CANDS, "death": SURV_DEATH_COL_CANDS, "censor": SURV_CENSOR_COL_CANDS, }, "used_columns": surv_used_cols, "notes": surv_notes }, "positive_class": str(pos_class), "metrics": metrics, } with open("meta.json", "w", encoding="utf-8") as f: json.dump(meta, f, indent=2) st.write("Local survival bundle exists:", os.path.exists("survival_bundle.joblib")) if os.path.exists("survival_bundle.joblib"): st.write("Local survival bundle size (bytes):", os.path.getsize("survival_bundle.joblib")) st.write("Local survival bundle exists:", os.path.exists("survival_bundle.joblib")) st.json(meta.get("survival", {})) return pipe, meta, X_train, y_test, proba # ============================================================ # SHAP # ============================================================ def build_shap_explainer(pipe, X_bg, max_bg=200): if X_bg is None or len(X_bg) == 0: raise ValueError("SHAP background is empty.") if len(X_bg) > int(max_bg): X_bg = X_bg.sample(int(max_bg), random_state=42) clf = pipe.named_steps["clf"] Xt_bg = transform_before_clf(pipe, X_bg) explainer = shap.LinearExplainer( clf, Xt_bg, feature_perturbation="interventional" ) return explainer def safe_dense(Xt, max_rows: int = 200): """ Convert sparse->dense carefully. Avoid converting huge matrices to dense. """ if hasattr(Xt, "shape") and Xt.shape[0] > max_rows: Xt = Xt[:max_rows] try: return Xt.toarray() except Exception: return np.array(Xt) def ensure_model_repo_exists(model_repo_id: str, token: str): """ Optional helper: create the model repo if it doesn't exist. Safe to call; if it exists, it will error -> you can ignore. """ # Increase Hub HTTP timeouts (seconds) os.environ["HF_HUB_HTTP_TIMEOUT"] = "300" # 5 minutes os.environ["HF_HUB_ETAG_TIMEOUT"] = "300" api = HfApi(token=token) try: api.create_repo(repo_id=model_repo_id, repo_type="model", private=False, exist_ok=True) except Exception: pass def coerce_X_like_schema(X: pd.DataFrame, feature_cols: list[str], num_cols: list[str], cat_cols: list[str]) -> pd.DataFrame: """ Robustly align an inference dataframe to the training schema: - normalizes column names (spaces/newlines/underscores/case) - renames matching columns to the model's exact feature names - creates missing columns as NaN (so prediction can still run) """ X = X.copy() # Build normalized lookup from inference file columns -> actual column name in inference inf_lookup = {norm_col(c): c for c in X.columns} # Build output with exact training columns X_out = pd.DataFrame(index=X.index) missing = [] for col in feature_cols: k = norm_col(col) if k in inf_lookup: X_out[col] = X[inf_lookup[k]] else: X_out[col] = np.nan missing.append(col) # Optional: show missing columns (don’t hard fail) if missing: st.warning( "Inference file is missing some training columns (filled as blank/NaN): " + ", ".join(missing[:12]) + (" ..." if len(missing) > 12 else "") ) # Coercions (same as your existing logic) X_out = X_out.replace({pd.NA: np.nan}) for c in num_cols: if c in X_out.columns: X_out[c] = pd.to_numeric(X_out[c], errors="coerce") for c in cat_cols: if c in X_out.columns: X_out[c] = X_out[c].astype("object") X_out.loc[X_out[c].isna(), c] = np.nan X_out[c] = X_out[c].map(lambda v: v if pd.isna(v) else str(v)) return X_out def get_shap_background_auto(model_repo_id: str, feature_cols: list[str], num_cols: list[str], cat_cols: list[str]) -> pd.DataFrame | None: """ Attempts to load SHAP background from HF repo. Returns coerced background or None. """ df_bg = load_latest_background(model_repo_id) if df_bg is None: return None # Ensure required columns exist missing = [c for c in feature_cols if c not in df_bg.columns] if missing: return None return coerce_X_like_schema(df_bg, feature_cols, num_cols, cat_cols) # ============================================================ # SHAP background persistence (best practice) # ============================================================ def save_background_sample_csv(X_bg: pd.DataFrame, feature_cols: list[str], max_rows: int = 200, out_path: str = "background.csv"): """ Saves a small *raw* background dataset (pre-transform) for SHAP explainer. Must contain columns exactly matching feature_cols. """ if X_bg is None or len(X_bg) == 0: raise ValueError("X_bg is empty; cannot save background sample.") X_bg = X_bg[feature_cols].copy() if len(X_bg) > int(max_rows): X_bg = X_bg.sample(int(max_rows), random_state=42) # Preserve exact columns for future loading X_bg.to_csv(out_path, index=False, encoding="utf-8") return out_path def publish_background_to_hub(model_repo_id: str, version_tag: str, background_path: str = "background.csv"): """ Uploads background.csv to both versioned and latest paths. Requires HF_TOKEN with write permissions. """ token = os.environ.get("HF_TOKEN") if not token: raise RuntimeError("HF_TOKEN not found. Add it in Space Settings → Secrets.") api = HfApi(token=token) version_bg_path = f"releases/{version_tag}/background.csv" # Versioned api.upload_file( path_or_fileobj=background_path, path_in_repo=version_bg_path, repo_id=model_repo_id, repo_type="model", commit_message=f"Upload SHAP background ({version_tag})" ) # Latest api.upload_file( path_or_fileobj=background_path, path_in_repo="latest/background.csv", repo_id=model_repo_id, repo_type="model", commit_message=f"Update latest SHAP background ({version_tag})" ) return { "version_bg_path": version_bg_path, "latest_bg_path": "latest/background.csv", } def load_latest_background(model_repo_id: str) -> pd.DataFrame | None: """ Loads latest/background.csv if present. Returns None if not found / cannot load. """ try: bg_file = hf_hub_download( repo_id=model_repo_id, repo_type="model", filename="latest/background.csv", ) df_bg = pd.read_csv(bg_file) return df_bg except Exception: return None def load_background_by_version(model_repo_id: str, version_tag: str) -> pd.DataFrame | None: """ Loads releases//background.csv if present. """ try: bg_file = hf_hub_download( repo_id=model_repo_id, repo_type="model", filename=f"releases/{version_tag}/background.csv", ) df_bg = pd.read_csv(bg_file) return df_bg except Exception: return None def publish_to_hub(model_repo_id: str, version_tag: str): """ Uploads model.joblib and meta.json to a Hugging Face *model* repository under a versioned folder and also updates a 'latest/' copy. Requires Space Secret: HF_TOKEN. """ token = os.environ.get("HF_TOKEN") if not token: raise RuntimeError("HF_TOKEN not found. Add it in Space Settings → Secrets.") api = HfApi(token=token) # Versioned paths version_model_path = f"releases/{version_tag}/model.joblib" version_meta_path = f"releases/{version_tag}/meta.json" # Upload versioned artifacts api.upload_file( path_or_fileobj="model.joblib", path_in_repo=version_model_path, repo_id=model_repo_id, repo_type="model", commit_message=f"Upload model artifacts ({version_tag})" ) api.upload_file( path_or_fileobj="meta.json", path_in_repo=version_meta_path, repo_id=model_repo_id, repo_type="model", commit_message=f"Upload metadata ({version_tag})" ) # Also maintain a 'latest/' copy for easy loading api.upload_file( path_or_fileobj="model.joblib", path_in_repo="latest/model.joblib", repo_id=model_repo_id, repo_type="model", commit_message=f"Update latest model ({version_tag})" ) api.upload_file( path_or_fileobj="meta.json", path_in_repo="latest/meta.json", repo_id=model_repo_id, repo_type="model", commit_message=f"Update latest metadata ({version_tag})" ) # Optional: upload survival model if present if os.path.exists("survival_bundle.joblib"): api.upload_file( path_or_fileobj="survival_bundle.joblib", path_in_repo=f"releases/{version_tag}/survival_bundle.joblib", repo_id=model_repo_id, repo_type="model", commit_message=f"Upload survival model ({version_tag})" ) api.upload_file( path_or_fileobj="survival_bundle.joblib", path_in_repo="latest/survival_bundle.joblib", repo_id=model_repo_id, repo_type="model", commit_message=f"Update latest survival model ({version_tag})" ) files = api.list_repo_files(repo_id=model_repo_id, repo_type="model") st.write([f for f in files if "survival" in f]) return { "version_model_path": version_model_path, "version_meta_path": version_meta_path, "latest_model_path": "latest/model.joblib", "latest_meta_path": "latest/meta.json", } MODEL_REPO_ID = "Synav/Explainable-Acute-Leukemia-Mortality-Predictor" def list_release_versions(model_repo_id: str): """ Returns sorted version tags found under releases//model.joblib in the model repo. """ api = HfApi(token=os.environ.get("HF_TOKEN") or None) files = api.list_repo_files(repo_id=model_repo_id, repo_type="model") versions = set() for f in files: # We only care about releases//model.joblib if f.startswith("releases/") and f.endswith("/model.joblib"): parts = f.split("/") if len(parts) >= 3: versions.add(parts[1]) # Most users want newest first (timestamp tags sort lexicographically) return sorted(versions, reverse=True) def load_model_by_version(model_repo_id: str, version_tag: str): """ Loads a specific version from releases//model.joblib and meta.json """ model_file = hf_hub_download( repo_id=model_repo_id, repo_type="model", filename=f"releases/{version_tag}/model.joblib", ) meta_file = hf_hub_download( repo_id=model_repo_id, repo_type="model", filename=f"releases/{version_tag}/meta.json", ) pipe = joblib.load(model_file) with open(meta_file, "r", encoding="utf-8") as f: meta = json.load(f) return pipe, meta def load_latest_survival_model(model_repo_id: str): f = hf_hub_download( repo_id=model_repo_id, repo_type="model", filename="latest/survival_bundle.joblib", ) return joblib.load(f) def load_survival_model_by_version(model_repo_id: str, version_tag: str): f = hf_hub_download( repo_id=model_repo_id, repo_type="model", filename=f"releases/{version_tag}/survival_bundle.joblib", ) return joblib.load(f) def load_latest_survival_bundle(model_repo_id: str): f = hf_hub_download(repo_id=model_repo_id, repo_type="model", filename="latest/survival_bundle.joblib") return joblib.load(f) def load_latest_model(model_repo_id: str): """ Loads latest/model.joblib and latest/meta.json """ model_file = hf_hub_download( repo_id=model_repo_id, repo_type="model", filename="latest/model.joblib", ) meta_file = hf_hub_download( repo_id=model_repo_id, repo_type="model", filename="latest/meta.json", ) pipe = joblib.load(model_file) with open(meta_file, "r", encoding="utf-8") as f: meta = json.load(f) return pipe, meta def get_post_selection_feature_names(pipe) -> list[str]: """ Returns feature names aligned to the exact columns seen by the final classifier. Works when SVD is OFF. Handles: preprocess -> VarianceThreshold -> SelectFromModel -> clf If some steps are missing, it degrades gracefully. """ pre = pipe.named_steps.get("preprocess", None) if pre is None: raise ValueError("Pipeline missing 'preprocess' step.") # Start with preprocessed (post-onehot) names try: names = list(pre.get_feature_names_out()) except Exception: # Fallback, but you should almost never hit this names = [f"f{i}" for i in range(pipe.named_steps["clf"].coef_.shape[1])] # Apply VarianceThreshold support mask (if present) vt = pipe.named_steps.get("vt", None) if vt is not None and hasattr(vt, "get_support"): vt_mask = vt.get_support() names = [n for n, keep in zip(names, vt_mask) if keep] # If SVD exists, we cannot map back to original one-hot features # because the feature space becomes components. if "svd" in pipe.named_steps: # Return component names for correctness svd = pipe.named_steps["svd"] k = getattr(svd, "n_components", None) or len(names) return [f"SVD_component_{i+1}" for i in range(int(k))] # Apply SelectFromModel mask (if present) sel = pipe.named_steps.get("select", None) if sel is not None and hasattr(sel, "get_support"): sel_mask = sel.get_support() names = [n for n, keep in zip(names, sel_mask) if keep] return names def transform_before_clf(pipe, X): """ Transform X through all pipeline steps BEFORE the classifier. Ensures SHAP sees the exact columns the clf sees. """ Xt = X for name, step in pipe.steps: if name == "clf": break if hasattr(step, "transform"): Xt = step.transform(Xt) return Xt def get_final_feature_names(pipe) -> list[str]: """ Return feature names aligned with the exact columns seen by the classifier. Handles: preprocess -> vt -> (svd?) -> (select?) -> clf """ pre = pipe.named_steps.get("preprocess", None) if pre is None: raise ValueError("Pipeline missing 'preprocess' step.") # Start: post-onehot names try: names = list(pre.get_feature_names_out()) except Exception: names = None # Apply vt mask if we have names vt = pipe.named_steps.get("vt", None) if names is not None and vt is not None and hasattr(vt, "get_support"): vt_mask = vt.get_support() names = [n for n, keep in zip(names, vt_mask) if keep] # If SVD exists: feature space becomes components; names must be components if "svd" in pipe.named_steps: svd = pipe.named_steps["svd"] k = int(getattr(svd, "n_components", 0) or 0) if k <= 0: # fallback if unknown k = 0 return [f"SVD_component_{i+1}" for i in range(k)] # Apply select mask (post-vt) if we have names sel = pipe.named_steps.get("select", None) if names is not None and sel is not None and hasattr(sel, "get_support"): sel_mask = sel.get_support() names = [n for n, keep in zip(names, sel_mask) if keep] # Fallback: if we couldn't get names, return generic based on clf coef size if names is None: n = pipe.named_steps["clf"].coef_.shape[1] names = [f"f{i}" for i in range(n)] return names def options_for(col: str, df: pd.DataFrame | None): base = DEFAULT_OPTIONS.get(col, []) excel = [] if df is not None and col in df.columns: excel = ( df[col].dropna().astype(str).map(lambda x: x.strip()) .loc[lambda s: s != ""].unique().tolist() ) # keep order: defaults first, then any extra from excel out = [] for v in base + sorted(set(excel)): if v not in out: out.append(v) return [""] + out import re # Canonical region labels you can use for analysis # (UN-style: Africa, Americas, Asia, Europe, Oceania; you can later refine into subregions) REGION_UNKNOWN = "Unknown" # Normalization: your Excel has values like "Emarati", "FILIPINO", "Indian ", etc. NATIONALITY_ALIASES = { # Nationality adjectives / common variants -> canonical country name "EMARATI": "United Arab Emirates", "EMIRATI": "United Arab Emirates", "UAE": "United Arab Emirates", "FILIPINO": "Philippines", "PALESTINIAN": "Palestine", "SYRIAN": "Syria", "LEBANESE": "Lebanon", "JORDANIAN": "Jordan", "YEMENI": "Yemen", "EGYPTIAN": "Egypt", "SUDANESE": "Sudan", "ETHIOPIAN": "Ethiopia", "ERITREAN": "Eritrea", "SOMALI": "Somalia", "KENYAN": "Kenya", "UGANDAN": "Uganda", "GUINEAN": "Guinea", "MOROCCAN": "Morocco", "COMORAN": "Comoros", "INDIAN": "India", "PAKISTANI": "Pakistan", "BANGLADESH": "Bangladesh", "NEPALESE": "Nepal", "INDONESIAN": "Indonesia", "MALAYSIAN": "Malaysia", "AMERICAN": "United States", "USA": "United States", "U.S.A": "United States", "UNITED STATES OF AMERICA": "United States", } def normalize_country_name(x: str) -> str | None: """Convert nationality-like strings to a canonical country name when possible.""" if x is None: return None s = str(x).strip() if not s: return None s_up = re.sub(r"\s+", " ", s).strip().upper() # map adjectives/variants if s_up in NATIONALITY_ALIASES: return NATIONALITY_ALIASES[s_up] # If someone already entered a country name, keep it # country_converter can handle many variants; pass through as-is return s.strip() from typing import Optional def country_to_region(country: str | None) -> str: """ Lazy-import country_converter to reduce startup memory. Returns one of: Africa, Americas, Asia, Europe, Oceania, Unknown """ if not country: return REGION_UNKNOWN import country_converter as coco # lazy r = coco.convert(names=country, to="continent") if not r or str(r).lower() in ("not found", "nan", "none"): return REGION_UNKNOWN if r == "America": return "Americas" return str(r) def add_ethnicity_region(df: pd.DataFrame, eth_col: str = "Ethnicity", out_col: str = "Ethnicity_Region") -> pd.DataFrame: if eth_col not in df.columns: df[out_col] = REGION_UNKNOWN return df norm = df[eth_col].map(normalize_country_name) df[out_col] = norm.map(country_to_region) return df # ============================================================ # Streamlit UI # ============================================================ def is_admin() -> bool: """ Admin gating for sensitive actions (Train + Publish). Requires Space secret ADMIN_KEY. """ secret = os.environ.get("ADMIN_KEY", "") if not secret: return False entered = st.session_state.get("admin_key", "") return hmac.compare_digest(str(entered), str(secret)) st.set_page_config(page_title="Explainable-Acute-Leukemia-Mortality-Predictor", layout="wide") # ---------------- Plot settings (global: Train + Predict) ---------------- with st.sidebar: st.markdown("### Plot settings") plot_width = st.slider("Plot width (inches)", 4.0, 10.0, 5.5, 0.1) plot_height = st.slider("Plot height (inches)", 2.5, 6.0, 3.6, 0.1) plot_dpi_screen = st.slider("Screen DPI", 80, 200, 120, 10) export_dpi = st.selectbox("Export DPI (PNG)", [300, 600, 900, 1200], index=1) FIGSIZE = (plot_width, plot_height) st.title("Explainable-Acute-Leukemia-Mortality-Predictor") st.caption("Explainable clinical AI for mortality and outcome prediction in acute leukemia using SHAP-interpretable models") with st.expander("About this AI model, who can use it, and required Excel format", expanded=True): st.markdown(""" ### Quick Start (Single Patient) 1. Open **Predict + SHAP** 2. **Load model** → select *latest* (default) 3. Enter patient details (Core → Clinical → FISH → NGS) 4. Click **Predict single patient** You will get: • Mortality (os) probability • Risk category • 6-month, 1-year, 3-year survival with PLOT • SHAP explanation Edit values to instantly test different scenarios. *For research/decision-support only — not a substitute for clinical judgment.* """) st.warning( "Prediction will fail if feature names or variable types " "do not exactly match the trained model schema." ) # 🔐 Admin controls come AFTER schema explanation with st.expander("Admin controls", expanded=False): st.text_input("Admin key", type="password", key="admin_key") st.caption("Training and publishing are enabled only after admin authentication.") tab_train, tab_predict = st.tabs(["1️⃣ Train", "2️⃣ Predict + SHAP"]) if "pipe" not in st.session_state: st.session_state.pipe = None if "explainer" not in st.session_state: st.session_state.explainer = None # ---------------- TRAIN ---------------- with tab_train: st.subheader("Train model") if not is_admin(): st.info("Training and publishing are restricted. Use Predict + SHAP for inference.") else: st.markdown("### Feature reduction options") use_feature_selection = st.checkbox( "Drop columns that do not affect prediction (L1 feature selection)", value=True, key="train_use_feature_selection" ) l1_C = st.slider( "L1 selection strength (lower = fewer features)", 0.01, 10.0, 1.0, 0.01 ) if use_feature_selection else 1.0 use_dimred = st.checkbox( "Dimensionality reduction (TruncatedSVD) — reduces interpretability", value=False ) svd_components = st.slider( "SVD components (only used if enabled)", 5, 300, 50, 5 ) if use_dimred else 50 st.divider() # then keep your file uploader + training button + publish block here st.markdown("### Bootstrap internal validation (optional)") do_bootstrap = st.checkbox( "Enable bootstrapping (OOB internal validation)", value=False, key="train_bootstrap_enable" ) n_boot = st.slider( "Bootstrap iterations", 50, 2000, 1000, 50, key="train_bootstrap_n" ) if do_bootstrap else 0 boot_thr = st.slider( "Bootstrap classification threshold", 0.0, 1.0, 0.5, 0.01, key="train_bootstrap_thr" ) if do_bootstrap else 0.5 #--------------------- train_file = st.file_uploader("Upload training Excel (.xlsx)", type=["xlsx"]) if train_file is None: st.info("Upload a training Excel file to enable training.") else: df = pd.read_excel(train_file, engine="openpyxl") df.columns = [c.strip() for c in df.columns] feature_cols = get_feature_cols_from_df(df) st.dataframe(df.head(), use_container_width=True) feature_cols = get_feature_cols_from_df(df) st.markdown("### Choose variable types (saved into the model)") default_numeric = feature_cols[:13] # initial suggestion num_cols = st.multiselect( "Numeric variables (will be median-imputed + scaled)", options=feature_cols, default=default_numeric ) # Everything not selected as numeric becomes categorical cat_cols = [c for c in feature_cols if c not in num_cols] st.write(f"Categorical variables (will be most-frequent-imputed + one-hot): {len(cat_cols)}") st.caption("Note: The selected schema is stored with the trained model and must match inference files.") st.markdown("### Evaluation settings") n_bins = st.slider("Calibration bins", 5, 20, 10, 1) cal_strategy = st.selectbox("Calibration binning strategy", ["uniform", "quantile"], index=0) dca_points = st.slider("Decision curve points", 25, 200, 99, 1) if st.button("Train model"): with st.spinner("Training model..."): pipe, meta, X_train, y_test, proba = train_and_save( df, feature_cols, num_cols, cat_cols, n_bins=n_bins, cal_strategy=cal_strategy, dca_points=dca_points, use_feature_selection=use_feature_selection, l1_C=l1_C, use_dimred=use_dimred, svd_components=svd_components ) # --- Save background sample for SHAP (raw X_train) --- try: save_background_sample_csv( X_bg=X_train, feature_cols=feature_cols, max_rows=200, out_path="background.csv" ) st.success("Saved SHAP background sample (background.csv).") except Exception as e: st.warning(f"Could not save SHAP background sample: {e}") st.write("Local survival bundle exists:", os.path.exists("survival_bundle.joblib")) if os.path.exists("survival_bundle.joblib"): st.write("Local survival bundle size (bytes):", os.path.getsize("survival_bundle.joblib")) explainer = build_shap_explainer(pipe, X_train) st.session_state.pipe = pipe st.session_state.explainer = explainer st.session_state.meta = meta st.success("Training complete. model.joblib and meta.json created.") st.divider() st.subheader("Training performance (test split)") m = meta["metrics"] # Show key metrics at threshold 0.5 c1, c2, c3, c4 = st.columns(4) c1.metric("ROC AUC", f"{m['roc_auc']:.3f}") c2.metric("Sensitivity (best F1 thr)", f"{m['sensitivity@best']:.3f}") c3.metric("Specificity (best F1 thr)", f"{m['specificity@best']:.3f}") c4.metric("F1 (best)", f"{m['f1@best']:.3f}") st.caption(f"Best threshold (max F1): {m['best_threshold']:.2f}") c5, c6, c7, c8 = st.columns(4) c5.metric("Precision", f"{m['precision@0.5']:.3f}") c6.metric("Accuracy", f"{m['accuracy@0.5']:.3f}") c7.metric("Balanced Acc", f"{m['balanced_accuracy@0.5']:.3f}") c8.metric("Test N", m["n_test"]) if do_bootstrap: st.divider() st.subheader(f"Bootstrap internal validation (n={n_boot}, OOB)") with st.spinner("Running bootstrap..."): df_boot, boot_summary, n_skipped = bootstrap_internal_validation( df=df, feature_cols=feature_cols, num_cols=num_cols, cat_cols=cat_cols, n_boot=int(n_boot), threshold=float(boot_thr), use_feature_selection=use_feature_selection, l1_C=l1_C, use_dimred=use_dimred, svd_components=svd_components, random_state=42, ) st.caption(f"Completed: {len(df_boot)} iterations. Skipped: {n_skipped} (empty/single-class OOB or fitting issues).") # Display summary st.json(boot_summary) # Show per-iteration table (optional) st.dataframe(df_boot.head(30), use_container_width=True) # Download full bootstrap results st.download_button( "Download bootstrap results (CSV)", df_boot.to_csv(index=False).encode("utf-8"), file_name=f"bootstrap_oob_{n_boot}.csv", mime="text/csv", key="dl_bootstrap_csv" ) # Save into meta.json so it persists with the model st.session_state.meta.setdefault("bootstrap", {}) st.session_state.meta["bootstrap"] = { "method": "bootstrap_oob", "n_boot": int(n_boot), "threshold": float(boot_thr), "skipped": int(n_skipped), "summary": boot_summary, } meta = st.session_state.meta with open("meta.json", "w", encoding="utf-8") as f: json.dump(meta, f, indent=2) # Confusion matrix display cm = m["confusion_matrix@0.5"] cm_df = pd.DataFrame( [[cm["tn"], cm["fp"]], [cm["fn"], cm["tp"]]], index=["Actual 0", "Actual 1"], columns=["Pred 0", "Pred 1"] ) st.markdown("**Confusion Matrix (threshold = 0.5)**") st.dataframe(cm_df) st.json(meta["survival"]) # TRAINING: ROC curve plot # ========================= roc = m["roc_curve"] fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen) ax.plot(roc["fpr"], roc["tpr"]) ax.plot([0, 1], [0, 1]) ax.set_xlabel("False Positive Rate (1 - Specificity)") ax.set_ylabel("True Positive Rate (Sensitivity)") ax.set_title(f"ROC Curve (AUC = {m['roc_auc']:.3f})") render_plot_with_download( fig, title="ROC curve", filename="roc_curve.png", export_dpi=export_dpi, key="dl_train_roc" ) #Precision recall curve # ========================= # TRAINING: PR curve plot # ========================= pr = m["pr_curve"] fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen) ax.plot(pr["recall"], pr["precision"]) ax.set_xlabel("Recall") ax.set_ylabel("Precision") ax.set_title(f"PR Curve (AP = {pr['average_precision']:.3f})") render_plot_with_download( fig, title="PR curve", filename="pr_curve.png", export_dpi=export_dpi, key="dl_train_pr" ) #Calibration plot # ========================= # TRAINING: Calibration plot # ========================= cal = m["calibration"] fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen) ax.plot(cal["prob_pred"], cal["prob_true"]) ax.plot([0, 1], [0, 1]) ax.set_xlabel("Mean predicted probability") ax.set_ylabel("Observed event rate") ax.set_title("Calibration curve") render_plot_with_download( fig, title="Calibration curve", filename="calibration_curve.png", export_dpi=export_dpi, key="dl_train_cal" ) #Decision curve # ========================= # TRAINING: Decision curve analysis plot # ========================= dca = m["decision_curve"] fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen) ax.plot(dca["thresholds"], dca["net_benefit_model"], label="Model") ax.plot(dca["thresholds"], dca["net_benefit_all"], label="Treat all") ax.plot(dca["thresholds"], dca["net_benefit_none"], label="Treat none") ax.set_xlabel("Threshold probability") ax.set_ylabel("Net benefit") ax.set_title("Decision curve analysis") ax.legend() render_plot_with_download( fig, title="Decision curve", filename="decision_curve.png", export_dpi=export_dpi, key="dl_train_dca" ) st.caption( "If the model curve is above Treat-all and Treat-none across a threshold range, " "the model provides net clinical benefit in that range." ) st.divider() st.subheader("Threshold analysis") thr = st.slider("Decision threshold", 0.0, 1.0, 0.5, 0.01) # Recompute threshold-based metrics quickly using stored probabilities # You need y_test and proba in scope. Easiest is to store them in session_state during training. st.session_state.y_test_last = y_test st.session_state.proba_last = proba if "y_test_last" in st.session_state and "proba_last" in st.session_state: cls = compute_classification_metrics(st.session_state.y_test_last, st.session_state.proba_last, threshold=thr) st.write({ "Sensitivity": cls["sensitivity"], "Specificity": cls["specificity"], "Precision": cls["precision"], "Recall": cls["recall"], "F1": cls["f1"], "Accuracy": cls["accuracy"], "Balanced Accuracy": cls["balanced_accuracy"], }) # ---------------- PUBLISH (only after training) ---------------- # ---------------- PUBLISH (only after training) ---------------- if st.session_state.get("pipe") is not None: st.divider() st.subheader("Publish trained model to Hugging Face Hub") default_version = datetime.utcnow().strftime("%Y%m%d-%H%M%S") version_tag = st.text_input( "Version tag", value=default_version, help="Used as releases// in the model repository", ) if st.button("Publish model.joblib + meta.json to Model Repo", key="publish_btn"): try: with st.spinner("Uploading to Hugging Face Model repo..."): paths = publish_to_hub(MODEL_REPO_ID, version_tag) # Upload background.csv if it exists if os.path.exists("background.csv"): bg_paths = publish_background_to_hub(MODEL_REPO_ID, version_tag, background_path="background.csv") paths.update(bg_paths) else: st.warning("background.csv not found; SHAP background will not be uploaded.") st.success("Uploaded successfully to your model repository.") st.json(paths) except Exception as e: st.error(f"Upload failed: {e}") # ---------------- PREDICT ---------------- PRED_N_BINS = 10 PRED_CAL_STRATEGY = "uniform" with tab_predict: st.subheader("Select a trained model (no retraining required)") MODEL_REPO_ID = "Synav/Explainable-Acute-Leukemia-Mortality-Predictor" # Ensure session state keys exist if "pipe" not in st.session_state: st.session_state.pipe = None if "meta" not in st.session_state: st.session_state.meta = None if "explainer" not in st.session_state: st.session_state.explainer = None # List available releases try: versions = list_release_versions(MODEL_REPO_ID) except Exception as e: versions = [] st.error(f"Could not list model versions: {e}") choices = ["latest"] + versions if versions else ["latest"] selected = st.selectbox("Choose model version", choices, index=0) if st.button("Load selected model"): try: with st.spinner("Loading model from Hugging Face Hub..."): if selected == "latest": pipe, meta = load_latest_model(MODEL_REPO_ID) else: pipe, meta = load_model_by_version(MODEL_REPO_ID, selected) st.session_state.pipe = pipe st.session_state.meta = meta st.session_state.explainer = None # rebuild later with inference data st.success(f"Loaded model: {selected}") except Exception as e: st.error(f"Load failed: {e}") st.divider() if st.session_state.get("pipe") is None or st.session_state.get("meta") is None: st.warning("Load a model version above (it must include meta.json), then continue.") st.stop() pipe = st.session_state.pipe meta = st.session_state.meta # Try load survival model (optional) try: if selected == "latest": st.session_state.surv_model = load_latest_survival_model(MODEL_REPO_ID) else: st.session_state.surv_model = load_survival_model_by_version(MODEL_REPO_ID, selected) except Exception as e: st.session_state.surv_model = None st.warning(f"Survival bundle load failed: {e}") # bundle = st.session_state.get("surv_model", None) # bundle = st.session_state.get("surv_model", None) # st.write("surv_model type:", type(bundle)) # st.write("surv_model keys:", list(bundle.keys()) if isinstance(bundle, dict) else None) # if isinstance(bundle, dict) and bundle.get("model") is not None: # st.write("Cox predictors:", len(bundle.get("columns", []))) # st.write("Cox cat cols:", len(bundle.get("cat_cols", []))) # st.write("Cox num cols:", len(bundle.get("num_cols", []))) # st.write("Bundle version:", bundle.get("version")) # cph = bundle["model"] # st.write("Cox coef shape:", getattr(cph, "params_", None).shape if hasattr(cph, "params_") else None) # else: # cph = None DEBUG_SURV = False if DEBUG_SURV: st.write("Survival bundle loaded:", isinstance(bundle, dict)) # 1) MUST come first: schema from meta feature_cols = meta["schema"]["features"] num_cols = meta["schema"]["numeric"] cat_cols = meta["schema"]["categorical"] # ------------------------------------------------------------ # SHAP background: prefer inference file, else HF background.csv # ------------------------------------------------------------ df_inf = st.session_state.get("df_inf") if df_inf is not None: # use user cohort as background (optional) X_bg = coerce_X_like_schema(df_inf, feature_cols, num_cols, cat_cols) else: # fall back to published background X_bg = get_shap_background_auto(MODEL_REPO_ID, feature_cols, num_cols, cat_cols) st.session_state.X_bg_for_shap = X_bg # 2) Now we can build lookup FEATURE_LOOKUP = {norm_col(c): c for c in feature_cols} from datetime import date MIN_DOB = date(1900, 1, 1) ALL_COUNTRIES = [ "Afghanistan","Albania","Algeria","Andorra","Angola","Antigua and Barbuda", "Argentina","Armenia","Australia","Austria","Azerbaijan","Bahamas","Bahrain", "Bangladesh","Barbados","Belarus","Belgium","Belize","Benin","Bhutan","Bolivia", "Bosnia and Herzegovina","Botswana","Brazil","Brunei","Bulgaria","Burkina Faso", "Burundi","Cabo Verde","Cambodia","Cameroon","Canada","Central African Republic", "Chad","Chile","China","Colombia","Comoros","Congo (Congo-Brazzaville)", "Costa Rica","Cote d’Ivoire","Croatia","Cuba","Cyprus","Czechia", "Democratic Republic of the Congo","Denmark","Djibouti","Dominica", "Dominican Republic","Ecuador","Egypt","El Salvador","Equatorial Guinea", "Eritrea","Estonia","Eswatini","Ethiopia","Fiji","Finland","France","Gabon", "Gambia","Georgia","Germany","Ghana","Greece","Grenada","Guatemala","Guinea", "Guinea-Bissau","Guyana","Haiti","Honduras","Hungary","Iceland","India", "Indonesia","Iran","Iraq","Ireland","Israel","Italy","Jamaica","Japan","Jordan", "Kazakhstan","Kenya","Kiribati","Kuwait","Kyrgyzstan","Laos","Latvia","Lebanon", "Lesotho","Liberia","Libya","Liechtenstein","Lithuania","Luxembourg", "Madagascar","Malawi","Malaysia","Maldives","Mali","Malta","Marshall Islands", "Mauritania","Mauritius","Mexico","Micronesia","Moldova","Monaco","Mongolia", "Montenegro","Morocco","Mozambique","Myanmar","Namibia","Nauru","Nepal", "Netherlands","New Zealand","Nicaragua","Niger","Nigeria","North Korea", "North Macedonia","Norway","Oman","Pakistan","Palau","Palestine","Panama", "Papua New Guinea","Paraguay","Peru","Philippines","Poland","Portugal","Qatar", "Romania","Russia","Rwanda","Saint Kitts and Nevis","Saint Lucia", "Saint Vincent and the Grenadines","Samoa","San Marino","Sao Tome and Principe", "Saudi Arabia","Senegal","Serbia","Seychelles","Sierra Leone","Singapore", "Slovakia","Slovenia","Solomon Islands","Somalia","South Africa","South Korea", "South Sudan","Spain","Sri Lanka","Sudan","Suriname","Sweden","Switzerland", "Syria","Taiwan","Tajikistan","Tanzania","Thailand","Timor-Leste","Togo","Tonga", "Trinidad and Tobago","Tunisia","Turkey","Turkmenistan","Tuvalu","Uganda", "Ukraine","United Arab Emirates","United Kingdom","United States","Uruguay", "Uzbekistan","Vanuatu","Vatican City","Venezuela","Vietnam","Yemen","Zambia", "Zimbabwe", "Other","Unknown" ] DEFAULT_OPTIONS = { # ---------------- Demographics ---------------- "Gender": ["Male", "Female"], "Ethnicity": ALL_COUNTRIES, # ---------------- Disease ---------------- "Type of Leukemia": [ "ALL", "AML", "APL", "CML", "MPAL", "Secondary AML", "Other", "Unknown" ], "Risk Assesment": [ "Favorable", "Intermediate", "Adverse", "Unknown" ], "ECOG": [0, 1, 2, 3, 4], # ---------------- Molecular / Diagnostic ---------------- "MDx Leukemia Screen 1": [ "Negative", "t(15;17) / PML-RARA", "t(9;22) / BCR-ABL1", "RUNX1-RUNX1T1", "inv(16) / CBFB-MYH11", "KMT2A-AFF1", "t(1;19) / TCF3-PBX1", "t(10;11)", "Other", "Not done", "Unknown" ], "Post-Induction MRD": [ "Negative", "Positive", "Pending", "Not done", "Unknown" ], # ---------------- Treatment ---------------- "First_Induction": [ "3+7", "3+7+GO", "3+7+TKI", "3+7+Midostaurin", "ATO+ATRA", "Aza+Ven", "CALGB", "CPX-351", "Dexa+TKI", "FLAG-IDA", "HCVAD", "HCVAD+TKI", "HMA", "Hyper-CVAD", "Not fit", "Pediatric protocol", "Other", "Unknown" ], # ---------------- Clinical Yes/No ---------------- # (stored as 0/1 in your code) "Thrombocytopenia": ["No", "Yes"], "Bleeding": ["No", "Yes"], "B Symptoms": ["No", "Yes"], "Lymphadenopathy": ["No", "Yes"], "CNS Involvement": ["No", "Yes"], "Extramedullary Involvement": ["No", "Yes"], } # Canonical dropdown fields (use your intended names) CANON_DROPDOWNS = [ "Ethnicity", "Type of Leukemia", "Post-Induction MRD", "First_Induction", "Risk Assesment", "MDx Leukemia Screen 1", "ECOG", "Gender", ] # Convert canonical -> actual column names present in this loaded model DROPDOWN_FIELDS_ACTUAL = [] for canon in CANON_DROPDOWNS: key = norm_col(canon) if key in FEATURE_LOOKUP: DROPDOWN_FIELDS_ACTUAL.append(FEATURE_LOOKUP[key]) def alias_default_options(canon_name: str): k = norm_col(canon_name) if k in FEATURE_LOOKUP: actual = FEATURE_LOOKUP[k] if canon_name in DEFAULT_OPTIONS and actual not in DEFAULT_OPTIONS: DEFAULT_OPTIONS[actual] = DEFAULT_OPTIONS[canon_name] alias_default_options("Ethnicity") alias_default_options("Type of Leukemia") alias_default_options("Post-Induction MRD") alias_default_options("First_Induction") alias_default_options("Risk Assesment") alias_default_options("MDx Leukemia Screen 1") # Fields you want as controlled dropdowns (rather than free text) DROPDOWN_FIELDS = [ "Gender", "Ethnicity", "Type of Leukemia", "ECOG", "Risk Assesment", "MDx Leukemia Screen 1", "Post-Induction MRD", "First_Induction", ] # Yes/No (stored as 1/0) fields YESNO_FIELDS = [ "Thrombocytopenia", "Bleeding", "B Symptoms", "Lymphadenopathy", "CNS Involvement", "Extramedullary Involvement", ] # Numeric fields with units (show units in the label) UNITS = { 'WBCs on Admission\n ( x10^9/L) ': "x10^9/L", "Hb on Admission (g/L)": "g/L", "LDH on Admission (IU/L)": "IU/L", "Pre-Induction bone marrow biopsy blasts %": "%", "Age (years)": "years", } # FISH / NGS marker groups (multi-select -> sets multiple columns 1/0) FISH_MARKERS = [ "fish_inv16/cbfb_myh11", "fish_t8;21/runx1_runx1t1", "fish_t15;17/PML-RARA", "fish_KMT2A-MLL/11q23", "fish_BCR-ABL1/t9;22", "fish_ETV6-RUNX1/t12;21", "fish_TCF3_PBX1/t1;19", "fish_IGH/14q32rearr", "fish_CRLF2 rearr", "fish_NUP214/9q34", "fish_Other", "fish_del7q/monosomy7", "fish_TP53/del17p", "fish_del13q", "fish_del11q", ] FISH_COUNT_COL = "Number of FISH alteration" NGS_MARKERS = [ "ngs_FLT3", "ngs_NPM1", "ngs_CEBPA", "ngs_DNMT3A", "ngs_IDH1", "ngs_IDH2", "ngs_TET2", "ngs_TP53", "ngs_RUNX1", "ngs_Other", "ngs_spliceosome", ] NGS_COUNT_COL = "No. of ngs_mutation" AGE_FEATURE = "Age (years)" DX_DATE_FEATURE = "Date of 1st Bone Marrow biopsy (Date of Diagnosis)" # keep exact if your schema has trailing space # ========================= # Fields managed by dedicated tabs (do NOT render duplicates in Core tab) # ========================= def get_options_from_df(df: pd.DataFrame, col: str) -> list[str]: """Dropdown options from uploaded Excel (inference file).""" if df is None or col not in df.columns: return [] vals = ( df[col] .dropna() .astype(str) .map(lambda x: x.strip()) .loc[lambda s: s != ""] .unique() .tolist() ) vals = sorted(vals) return vals # ========================= # After model is loaded # ========================= pipe = st.session_state.pipe meta = st.session_state.meta # Map normalized name -> actual model column name FEATURE_LOOKUP = {norm_col(c): c for c in feature_cols} TAB_MANAGED_FIELDS = set() # Derived/auto fields (managed by DOB/Dx/CR widgets) if AGE_FEATURE in feature_cols: TAB_MANAGED_FIELDS.add(AGE_FEATURE) if DX_DATE_FEATURE in feature_cols: TAB_MANAGED_FIELDS.add(DX_DATE_FEATURE) if "Date of 1st CR" in feature_cols: TAB_MANAGED_FIELDS.add("Date of 1st CR") # Clinical yes/no handled in Clinical tab TAB_MANAGED_FIELDS.update([f for f in YESNO_FIELDS if f in feature_cols]) # FISH handled in FISH tab TAB_MANAGED_FIELDS.update([f for f in FISH_MARKERS if f in feature_cols]) # NGS handled in NGS tab TAB_MANAGED_FIELDS.update([f for f in NGS_MARKERS if f in feature_cols]) DROPDOWN_FIELDS = [f for f in DROPDOWN_FIELDS if f not in TAB_MANAGED_FIELDS] # Optional marker counts: handled automatically if FISH_COUNT_COL in feature_cols: TAB_MANAGED_FIELDS.add(FISH_COUNT_COL) if NGS_COUNT_COL in feature_cols: TAB_MANAGED_FIELDS.add(NGS_COUNT_COL) # --------------------------- # Inference Excel (optional) - store in session_state for dropdown options # --------------------------- infer_file = st.file_uploader("Upload inference Excel (.xlsx) (optional, for dropdown options + batch prediction)", type=["xlsx"], key="infer_xlsx") if infer_file: st.session_state.df_inf = pd.read_excel(infer_file, engine="openpyxl") # 🔴 PLACE IT HERE — normalize headers immediately after load st.session_state.df_inf.columns = [ " ".join(str(c).replace("\u00A0"," ").split()) for c in st.session_state.df_inf.columns ] # ✅ ALIGN TO TRAINED MODEL SCHEMA (THIS FIXES YOUR ERROR) st.session_state.df_inf = align_columns_to_schema( st.session_state.df_inf, meta["schema"]["features"]) st.session_state.df_inf = add_ethnicity_region( st.session_state.df_inf, "Ethnicity", "Ethnicity_Region" ) else: st.session_state.df_inf = None df_for_options = st.session_state.df_inf st.divider() st.subheader("Enter patient details to get prediction (Example: DOB → Age, Dx date → prediction)") from datetime import date MIN_DATE = date(1900, 1, 1) c1, c2 = st.columns(2) with c1: dob_unknown = st.checkbox("DOB unknown", value=False, key="dob_unknown_chk") dob = None if not dob_unknown: dob = st.date_input("Date of birth (DOB)", min_value=MIN_DATE, max_value=date.today(), key="dob_date") with c2: dx_unknown = st.checkbox("Diagnosis date unknown", value=False, key="dx_unknown_chk") dx_date = None if not dx_unknown: dx_date = st.date_input("Date of 1st Bone Marrow biopsy (Diagnosis)", min_value=MIN_DATE, max_value=date.today(), key="dx_date") # Optional additional date input for CR1 if present in feature_cols cr1_date = None if "Date of 1st CR" in feature_cols: cr1_unknown = st.checkbox("1st CR date unknown", value=False, key="cr1_unknown") if not cr1_unknown: cr1_date = st.date_input("Date of 1st CR", min_value=MIN_DATE, max_value=date.today(), key="cr1_date") derived_age = age_years_at(dob, dx_date) def yesno_to_01(v: str): if v == "Yes": return 1 if v == "No": return 0 return np.nan # Create tabs FIRST (outside any form) tab_core, tab_clin, tab_fish, tab_ngs = st.tabs(["Core", "Clinical (Yes/No)", "FISH", "NGS"]) # Storage for selections values_by_index = [np.nan] * len(feature_cols) with tab_fish: st.caption("Select one or many FISH alterations. Selected = 1, not selected = 0.") fish_selected = st.multiselect( "FISH alterations", options=[m for m in FISH_MARKERS if m in feature_cols], default=[], key="sp_fish_selected", ) with tab_ngs: st.caption("Select one or many NGS mutations. Selected = 1, not selected = 0.") ngs_selected = st.multiselect( "NGS mutations", options=[m for m in NGS_MARKERS if m in feature_cols], default=[], key="sp_ngs_selected", ) with tab_clin: st.caption("Clinical flags: Yes=1, No=0") for i, f in enumerate(feature_cols): if f in YESNO_FIELDS: v = st.selectbox(f, options=["", "No", "Yes"], index=0, key=f"sp_{i}_yn") values_by_index[i] = yesno_to_01(v) with tab_core: ecog = st.selectbox("ECOG", options=[0, 1, 2, 3, 4], index=0, key="sp_ecog") for i, f in enumerate(feature_cols): # Skip fields handled in other tabs / derived inputs if f in TAB_MANAGED_FIELDS: continue # Age auto-calc (display integer, store float) # --- Age (auto from DOB & Dx date) --- if f.strip() == AGE_FEATURE.strip(): # If DOB or Dx date unknown, allow manual entry if dob_unknown or dx_unknown or np.isnan(derived_age): v_age = st.number_input( f"{f} (enter if DOB/Dx unknown)", value=None, step=1, format="%d", key=f"sp_{i}_age_manual", ) values_by_index[i] = v_age else: st.number_input( f"{f} (auto from DOB & Dx date)", value=int(round(derived_age)), step=1, format="%d", key=f"sp_{i}_age_auto", disabled=True, ) values_by_index[i] = float(derived_age) continue # --- Diagnosis date (auto from dx_date input) --- if f.strip() == DX_DATE_FEATURE.strip(): values_by_index[i] = np.nan if dx_date is None else dx_date.isoformat() st.text_input( f"{f} (auto)", value="" if dx_date is None else dx_date.isoformat(), disabled=True, key=f"sp_{i}_dx_show" ) continue if f.strip() == "Date of 1st CR".strip(): values_by_index[i] = np.nan if cr1_date is None else cr1_date.isoformat() st.text_input( f"{f} (auto)", value="" if cr1_date is None else cr1_date.isoformat(), disabled=True, key=f"sp_{i}_cr_show" ) continue # ECOG mapped to int if f.strip() == "ECOG": values_by_index[i] = int(ecog) continue # Gender dropdown if f == "Gender": v = st.selectbox("Gender", options=["", "Male", "Female"], index=0, key=f"sp_{i}_gender") values_by_index[i] = np.nan if v == "" else v continue # Other dropdown fields (from inference df options if available) # Dropdown fields (robust) if f in DROPDOWN_FIELDS_ACTUAL and f not in ("Gender", "ECOG"): opts = options_for(f, df_for_options) # defaults + excel uniques v = st.selectbox(f, options=opts, index=0, key=f"sp_{i}_dd") values_by_index[i] = np.nan if v == "" else v continue # Numeric if f in num_cols: unit = UNITS.get(f, None) label = f"{f} ({unit})" if unit else f if f == "Age (years)": v = st.number_input(label, value=None, step=1, format="%d", key=f"sp_{i}_num") else: v = st.number_input(label, value=None, format="%.2f", key=f"sp_{i}_num") values_by_index[i] = v continue # Categorical fallback if f in cat_cols: v = st.text_input(f, value="", key=f"sp_{i}_cat") values_by_index[i] = np.nan if v.strip() == "" else v continue # Other fallback v = st.text_input(f, value="", key=f"sp_{i}_other") values_by_index[i] = np.nan if v.strip() == "" else v # Apply FISH/NGS selections to row fish_set = set(fish_selected) ngs_set = set(ngs_selected) for i, f in enumerate(feature_cols): if f in FISH_MARKERS: values_by_index[i] = 1 if f in fish_set else 0 if f in NGS_MARKERS: values_by_index[i] = 1 if f in ngs_set else 0 # Auto-fill marker counts if present if FISH_COUNT_COL in feature_cols: values_by_index[feature_cols.index(FISH_COUNT_COL)] = int(len(fish_selected)) if NGS_COUNT_COL in feature_cols: values_by_index[feature_cols.index(NGS_COUNT_COL)] = int(len(ngs_selected)) st.divider() st.subheader("Predict single patient") m = meta.get("metrics", {}) default_thr = float(m.get("best_threshold", 0.5)) thr_single = st.slider( "Classification threshold", 0.0, 1.0, default_thr, 0.01, key="sp_thr" ) # External validation threshold thr_ext = st.slider( "External validation threshold", 0.0, 1.0, default_thr, 0.01, key="thr_ext" ) low_cut_s, high_cut_s = st.slider( "Risk band cutoffs (low, high)", 0.0, 1.0, (0.2, 0.8), 0.01, key="sp_risk_cuts" ) # Ensure low <= high if low_cut_s > high_cut_s: low_cut_s, high_cut_s = high_cut_s, low_cut_s def band_one(p: float) -> str: if p < low_cut_s: return "Low" if p >= high_cut_s: return "High" return "Intermediate" # Submit button (no form needed; simpler + fewer state surprises) if st.button("Predict single patient", key="sp_predict_btn"): X_one = pd.DataFrame([values_by_index], columns=feature_cols).replace({pd.NA: np.nan}) for c in num_cols: if c in X_one.columns: X_one[c] = pd.to_numeric(X_one[c], errors="coerce") for c in cat_cols: if c in X_one.columns: X_one[c] = X_one[c].astype("object") X_one.loc[X_one[c].isna(), c] = np.nan X_one[c] = X_one[c].map(lambda v: v if pd.isna(v) else str(v)) proba_one = float(pipe.predict_proba(X_one)[:, 1][0]) st.success("Prediction generated.") st.metric("Predicted probability", f"{proba_one:.4f}") # ---- Survival prediction for this patient (if survival model loaded) ---- # ---- Survival prediction for this patient (if survival model loaded) ---- bundle = st.session_state.get("surv_model", None) if isinstance(bundle, dict) and bundle.get("model") is not None: try: cph = bundle["model"] surv_cols = bundle.get("columns", []) # Build Cox input row (same preprocessing as Cox training) cox_feature_cols = bundle.get("feature_cols", feature_cols) cox_num_cols = bundle.get("num_cols", num_cols) cox_cat_cols = bundle.get("cat_cols", cat_cols) df_one_surv = X_one[cox_feature_cols].copy() for c in cox_num_cols: if c in df_one_surv.columns: df_one_surv[c] = pd.to_numeric(df_one_surv[c], errors="coerce") for c in cox_cat_cols: if c in df_one_surv.columns: df_one_surv[c] = df_one_surv[c].astype("object").map( lambda v: v if pd.isna(v) else str(v) ) df_one_surv_oh = pd.get_dummies(df_one_surv, columns=cox_cat_cols, drop_first=True) df_one_surv_oh = df_one_surv_oh.loc[:, ~df_one_surv_oh.columns.duplicated()].copy() # Align to training predictor columns for col in surv_cols: if col not in df_one_surv_oh.columns: df_one_surv_oh[col] = 0 df_one_surv_oh = df_one_surv_oh[surv_cols] imp = bundle.get("imputer", None) if imp is not None: X_imp = imp.transform(df_one_surv_oh) df_one_surv_oh = pd.DataFrame(X_imp, columns=surv_cols, index=df_one_surv_oh.index) else: df_one_surv_oh = df_one_surv_oh.fillna(0) # Predict survival function surv_fn = cph.predict_survival_function(df_one_surv_oh) def surv_at(days: int) -> float: idx = surv_fn.index.values j = int(np.argmin(np.abs(idx - days))) return float(surv_fn.iloc[j, 0]) s6m = surv_at(180) s1y = surv_at(365) s2y = surv_at(730) s3y = surv_at(1095) st.subheader("Predicted survival probability") a, b, c, d = st.columns(4) a.metric("6 months", f"{s6m*100:.1f}%") b.metric("1 year", f"{s1y*100:.1f}%") c.metric("2 years", f"{s2y*100:.1f}%") d.metric("3 years", f"{s3y*100:.1f}%") fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen) ax.plot(surv_fn.index, surv_fn.values) ax.set_xlabel("Days from diagnosis") ax.set_ylabel("Survival probability") ax.set_ylim(0, 1) ax.set_title("Predicted survival curve (patient)") render_plot_with_download( fig, title="Patient survival curve", filename="patient_survival_curve.png", export_dpi=export_dpi, key="dl_patient_surv_curve" ) except Exception as e: st.warning(f"Survival prediction could not be computed: {e}") else: st.info("Survival model not loaded/published yet (survival bundle missing).") out = X_one.copy() out["predicted_probability"] = proba_one pred_class = int(proba_one >= thr_single) out["predicted_class"] = pred_class out["risk_band"] = band_one(proba_one) # ---- SHAP compute only (cache) ---- X_one_t = transform_before_clf(pipe, X_one) explainer = st.session_state.get("explainer") explainer_sig = st.session_state.get("explainer_sig") current_sig = ( selected, None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"])) ) if explainer is None or explainer_sig != current_sig: X_bg = st.session_state.get("X_bg_for_shap") if X_bg is None: st.error("SHAP background not available. Admin must publish latest/background.csv.") st.stop() st.session_state.explainer = build_shap_explainer(pipe, X_bg) st.session_state.explainer_sig = current_sig explainer = st.session_state.explainer shap_vals = explainer.shap_values(X_one_t) if isinstance(shap_vals, list): shap_vals = shap_vals[1] names = get_final_feature_names(pipe) try: x_dense = X_one_t.toarray()[0] except Exception: x_dense = np.array(X_one_t)[0] base = explainer.expected_value if not np.isscalar(base): base = float(np.array(base).reshape(-1)[0]) exp = shap.Explanation( values=shap_vals[0], base_values=float(base), data=x_dense, feature_names=names, ) # CACHE ONLY st.session_state.shap_single_exp = exp st.dataframe(out, use_container_width=True) st.download_button( "Download single patient result (CSV)", out.to_csv(index=False).encode("utf-8"), file_name="single_patient_prediction.csv", mime="text/csv", key="dl_sp_csv", ) # ---- Always render cached SHAP ---- if "shap_single_exp" in st.session_state: exp = st.session_state.shap_single_exp max_display_single = st.slider( "Top features to display (single patient)", 5, 40, 20, 1, key="sp_single_max_display" ) c1, c2 = st.columns(2) with c1: plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen) shap.plots.waterfall(exp, show=False, max_display=max_display_single) fig_w = plt.gcf() render_plot_with_download( fig_w, title="Single-patient SHAP waterfall", filename="single_patient_shap_waterfall.png", export_dpi=export_dpi, key="dl_sp_wf" ) with c2: plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen) shap.plots.bar(exp, show=False, max_display=max_display_single) fig_b = plt.gcf() render_plot_with_download( fig_b, title="Single-patient SHAP bar", filename="single_patient_shap_bar.png", export_dpi=export_dpi, key="dl_sp_bar" ) # --- After SHAP plots are displayed --- with st.expander("How to interpret the SHAP explanation plots", expanded=True): st.markdown(r""" **What is SHAP?** SHAP (SHapley Additive exPlanations) decomposes a model prediction into **feature-wise contributions**. Each feature pushes the prediction **toward higher risk** or **toward lower risk**, relative to the model’s baseline. **What is \(E[f(X)]\)? (baseline / expected model output)** - \(E[f(X)]\) is the model’s **average output** over the reference population used by the explainer (typically the training set). - It is the **starting point** of the explanation before any patient-specific information is applied. **What is \(f(x)\)? (this patient’s model output)** - \(f(x)\) is the model’s **final output for this patient**, after adding all feature contributions to the baseline. - For many classifiers (including logistic regression), SHAP waterfall plots often use the **log-odds (logit) scale** rather than raw probability. **How the waterfall plot should be read** - The plot starts at **\(E[f(X)]\)** and then adds/subtracts feature effects to arrive at **\(f(x)\)**. - **Red bars** push the prediction **toward higher risk** (increase the model output). - **Blue bars** push the prediction **toward lower risk / protective direction** (decrease the model output). - The **bar length** indicates the **strength** of that feature’s influence for this patient. - “**Other features**” is the combined net effect of features not shown individually. **How the bar plot should be read (Top contributors)** - Features are ranked by **absolute SHAP value** (largest impact first). - **Positive SHAP** (right) increases predicted risk; **negative SHAP** (left) decreases predicted risk. **Clinical cautions** - SHAP explains the model’s behavior; it does **not** prove causality. - Ensure variable definitions and patient population match the model’s intended use. """) # ----------------------------- # Batch prediction / validation # ----------------------------- st.divider() st.subheader("Batch prediction / External validation") df_inf = st.session_state.df_inf if df_inf is None: st.info("Upload an inference Excel to run batch prediction / external validation.") else: meta = st.session_state.get("meta") pipe = st.session_state.get("pipe") if meta is None or pipe is None: st.error("Model not loaded. Please load a model version first.") st.stop() feature_cols = meta["schema"]["features"] num_cols = meta["schema"]["numeric"] cat_cols = meta["schema"]["categorical"] missing = [c for c in feature_cols if c not in df_inf.columns] if missing: st.error(f"Inference file is missing required feature columns: {missing}") st.stop() meta = st.session_state.get("meta") if not meta: st.error("Model metadata not loaded. Please load a model version.") st.stop() X_inf = coerce_X_like_schema(df_inf, feature_cols, num_cols, cat_cols) X_inf = X_inf.replace({pd.NA: np.nan}) for c in num_cols: X_inf[c] = pd.to_numeric(X_inf[c], errors="coerce") for c in cat_cols: X_inf[c] = X_inf[c].astype("object") X_inf.loc[X_inf[c].isna(), c] = np.nan X_inf[c] = X_inf[c].map(lambda v: v if pd.isna(v) else str(v)) proba = pipe.predict_proba(X_inf)[:, 1] st.divider() st.subheader("External validation (if Outcome Event label is present)") if LABEL_COL in df_inf.columns: try: y_ext_raw = df_inf[LABEL_COL].copy() y_ext01, _ = coerce_binary_label(y_ext_raw) # Core metrics roc_auc_ext = float(roc_auc_score(y_ext01, proba)) fpr, tpr, roc_thresholds = roc_curve(y_ext01, proba) # Threshold metrics (user-controlled) cls_ext = compute_classification_metrics(y_ext01, proba, threshold=float(thr_ext)) pr_ext = compute_pr_curve(y_ext01, proba) cal_ext = compute_calibration(y_ext01, proba, n_bins=PRED_N_BINS, strategy=PRED_CAL_STRATEGY) dca_ext = decision_curve_analysis(y_ext01, proba) # Display headline metrics c1, c2, c3, c4 = st.columns(4) c1.metric("ROC AUC (external)", f"{roc_auc_ext:.3f}") c2.metric("Sensitivity", f"{cls_ext['sensitivity']:.3f}") c3.metric("Specificity", f"{cls_ext['specificity']:.3f}") c4.metric("F1", f"{cls_ext['f1']:.3f}") # Confusion matrix cm_df = pd.DataFrame( [[cls_ext["tn"], cls_ext["fp"]], [cls_ext["fn"], cls_ext["tp"]]], index=["Actual 0", "Actual 1"], columns=["Pred 0", "Pred 1"], ) st.markdown("**Confusion Matrix (external)**") st.dataframe(cm_df) # ROC plot # ========================= # EXTERNAL: ROC curve plot # (replace your current plt.plot(fpr, tpr) block) # ========================= fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen) ax.plot(fpr, tpr) ax.plot([0, 1], [0, 1]) ax.set_xlabel("False Positive Rate (1 - Specificity)") ax.set_ylabel("True Positive Rate (Sensitivity)") ax.set_title(f"External ROC Curve (AUC = {roc_auc_ext:.3f})") render_plot_with_download( fig, title="External ROC curve", filename="external_roc_curve.png", export_dpi=export_dpi, key="dl_ext_roc" ) # ------------------------- # EXTERNAL: PR curve # ------------------------- fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen) ax.plot(pr_ext["recall"], pr_ext["precision"]) ax.set_xlabel("Recall") ax.set_ylabel("Precision") ax.set_title(f"External PR Curve (AP = {pr_ext['average_precision']:.3f})") render_plot_with_download( fig, title="External PR curve", filename="external_pr_curve.png", export_dpi=export_dpi, key="dl_ext_pr" ) # ------------------------- # EXTERNAL: Calibration # ------------------------- fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen) ax.plot(cal_ext["prob_pred"], cal_ext["prob_true"]) ax.plot([0, 1], [0, 1]) ax.set_xlabel("Mean predicted probability") ax.set_ylabel("Observed event rate") ax.set_title("External calibration curve") render_plot_with_download( fig, title="External calibration curve", filename="external_calibration_curve.png", export_dpi=export_dpi, key="dl_ext_cal" ) # ------------------------- # EXTERNAL: DCA # ------------------------- fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen) ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_model"], label="Model") ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_all"], label="Treat all") ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_none"], label="Treat none") ax.set_xlabel("Threshold probability") ax.set_ylabel("Net benefit") ax.set_title("External decision curve analysis") ax.legend() render_plot_with_download( fig, title="External decision curve", filename="external_decision_curve.png", export_dpi=export_dpi, key="dl_ext_dca" ) except Exception as e: st.error(f"Could not compute external validation metrics: {e}") else: st.info("No Outcome Event column found in the inference Excel, so external validation metrics cannot be computed.") # Predict probabilities proba = pipe.predict_proba(X_inf)[:, 1] df_out = df_inf.copy() df_out["predicted_probability"] = proba # ---- Batch survival probabilities (if survival model loaded) ---- # ---- Batch survival probabilities (if survival model loaded) ---- bundle = st.session_state.get("surv_model", None) if isinstance(bundle, dict) and bundle.get("model") is not None: try: cph = bundle["model"] surv_cols = bundle["columns"] imp = bundle.get("imputer", None) cox_feature_cols = bundle.get("feature_cols", feature_cols) cox_num_cols = bundle.get("num_cols", num_cols) cox_cat_cols = bundle.get("cat_cols", cat_cols) df_surv_in = X_inf[cox_feature_cols].copy() for c in cox_num_cols: if c in df_surv_in.columns: df_surv_in[c] = pd.to_numeric(df_surv_in[c], errors="coerce") for c in cox_cat_cols: if c in df_surv_in.columns: df_surv_in[c] = df_surv_in[c].astype("object") df_surv_in.loc[df_surv_in[c].isna(), c] = np.nan df_surv_in[c] = df_surv_in[c].map(lambda v: v if pd.isna(v) else str(v)) df_surv_in_oh = pd.get_dummies(df_surv_in, columns=cox_cat_cols, drop_first=True) df_surv_in_oh = df_surv_in_oh.loc[:, ~df_surv_in_oh.columns.duplicated()].copy() # align columns for col in surv_cols: if col not in df_surv_in_oh.columns: df_surv_in_oh[col] = 0 df_surv_in_oh = df_surv_in_oh[surv_cols] # impute if imp is not None: X_imp = imp.transform(df_surv_in_oh) df_surv_in_oh = pd.DataFrame(X_imp, columns=surv_cols, index=df_surv_in_oh.index) else: df_surv_in_oh = df_surv_in_oh.fillna(0) surv_fn_all = cph.predict_survival_function(df_surv_in_oh) def surv_vec_at(days: int): idx = surv_fn_all.index.values j = int(np.argmin(np.abs(idx - days))) return surv_fn_all.iloc[j, :].values.astype(float) df_out["survival_6m"] = surv_vec_at(180) df_out["survival_1y"] = surv_vec_at(365) df_out["survival_2y"] = surv_vec_at(730) df_out["survival_3y"] = surv_vec_at(1095) except Exception as e: st.warning(f"Batch survival probabilities could not be computed: {e}") else: st.info("Survival model not loaded/published yet (survival bundle missing).") # ========================================================== # EXPORT: CSV needed for NEJM-style ROC + Calibration + DCA # (label + predicted_probability) # ========================================================== if LABEL_COL in df_out.columns: curves_df = df_out[[LABEL_COL, "predicted_probability"]].copy() st.download_button( "Download CSV for ROC/Calibration/DCA panel", curves_df.to_csv(index=False).encode("utf-8"), file_name="predictions_for_curves.csv", mime="text/csv", key="dl_predictions_for_curves" ) else: st.info("Outcome Event column not present — panel CSV export requires true labels.") df_out = add_ethnicity_region(df_out, eth_col="Ethnicity", out_col="Ethnicity_Region") st.subheader("Grouped outcomes by region (analytics)") if "Ethnicity_Region" in df_out.columns: grp = df_out.groupby("Ethnicity_Region")["predicted_probability"].agg(["count","mean","median"]).reset_index() st.dataframe(grp, use_container_width=True) # --- classification + risk bands --- st.divider() st.subheader("Risk stratification") # Classification threshold thr = st.slider( "Decision threshold for classification", 0.0, 1.0, 0.5, 0.01, key="pred_thr" ) df_out["predicted_class"] = (df_out["predicted_probability"] >= thr).astype(int) # ✅ Batch risk band slider (MISSING in your code) low_cut, high_cut = st.slider( "Risk band cutoffs (low, high)", 0.0, 1.0, (0.2, 0.8), 0.01, key="batch_risk_cuts" ) # safety (optional but good practice) if low_cut > high_cut: low_cut, high_cut = high_cut, low_cut def band(p: float) -> str: if p < low_cut: return "Low" if p >= high_cut: return "High" return "Intermediate" df_out["risk_band"] = df_out["predicted_probability"].map(band) # --- END ADD --- st.dataframe(df_out.head()) st.download_button( "Download predictions", df_out.to_csv(index=False).encode(), "predictions.csv", "text/csv" ) #Batch SHAP for whole BLOCK st.divider() st.subheader("Batch SHAP (first 200 rows)") MAX_BATCH = 200 n_rows = len(X_inf) batch_n = min(MAX_BATCH, n_rows) cA, cB, cC = st.columns([1, 1, 1]) with cA: do_batch = st.button(f"Compute batch SHAP for first {batch_n} rows", key="batch_shap_btn") with cB: max_display_batch = st.slider("Top features to display (batch)",5, 40, 20, 1,key="batch_max_display") with cC: show_beeswarm = st.checkbox("Show beeswarm (slower)", value=True, key="batch_beeswarm") if do_batch: with st.spinner("Computing batch SHAP..."): X_batch = X_inf.iloc[:batch_n].copy() X_batch_t = transform_before_clf(pipe, X_batch) explainer = st.session_state.get("explainer") explainer_sig = st.session_state.get("explainer_sig") # Create a simple signature that changes if model changes or background changes # (using version + number of background rows is usually enough) current_sig = ( selected, # or meta.get("created_at_utc") or meta.get("metrics", {}).get("roc_auc") None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"])) ) if explainer is None or explainer_sig != current_sig: X_bg = st.session_state.get("X_bg_for_shap") if X_bg is None: st.error("SHAP background not available. Admin must publish latest/background.csv.") st.stop() st.session_state.explainer = build_shap_explainer(pipe, X_bg) st.session_state.explainer_sig = current_sig explainer = st.session_state.explainer shap_vals_batch = explainer.shap_values(X_batch_t) if isinstance(shap_vals_batch, list): shap_vals_batch = shap_vals_batch[1] names = get_final_feature_names(pipe) # Absolute sanity check: align names with SHAP columns if len(names) != shap_vals_batch.shape[1]: st.warning( f"Feature name mismatch: names={len(names)} vs shap_cols={shap_vals_batch.shape[1]}. " "Using generic names." ) names = [f"f{i}" for i in range(shap_vals_batch.shape[1])] # Dense conversion once (used for summary + waterfalls) try: X_dense = safe_dense(X_batch_t, max_rows=200) except Exception: X_dense = np.array(X_batch_t) # Cache batch results st.session_state.shap_batch_vals = shap_vals_batch st.session_state.shap_batch_X_dense = X_dense st.session_state.shap_batch_n = batch_n st.session_state.shap_batch_feature_names = names st.success(f"Batch SHAP computed for first {batch_n} rows.") if "shap_batch_vals" in st.session_state: shap_vals_batch = st.session_state.shap_batch_vals X_dense = st.session_state.shap_batch_X_dense batch_n = st.session_state.shap_batch_n names = st.session_state.shap_batch_feature_names st.divider() st.subheader("Export: Top SHAP features per row (batch)") top_k = st.slider("Top-K features per row", 3, 30, 10, 1, key="topk_export") # Optional: include predicted probabilities for the same batch rows # (Assumes you already computed proba for all X_inf earlier) include_proba = st.checkbox("Include predicted probability", value=True, key="include_proba_export") if st.button("Generate Top-K SHAP table", key="gen_topk_shap"): shap_vals_batch = st.session_state.shap_batch_vals # shape: (batch_n, n_features) names = st.session_state.shap_batch_feature_names batch_n = st.session_state.shap_batch_n rows = [] for i in range(batch_n): sv = shap_vals_batch[i] idx = np.argsort(np.abs(sv))[::-1][:top_k] # top-k by absolute SHAP for j in idx: val = float(sv[j]) rows.append({ "row_in_batch": int(i), "feature": str(names[j]), "shap_value": val, "abs_shap_value": abs(val), "direction": "↑" if val > 0 else ("↓" if val < 0 else "0"), }) df_topk = pd.DataFrame(rows) if include_proba: # Use the same batch rows from the previously computed proba vector # If you want absolute Excel row index, add + df_inf.index[0] logic as needed proba_batch = proba[:batch_n] df_proba = pd.DataFrame({"row_in_batch": list(range(batch_n)), "predicted_probability": proba_batch}) df_topk = df_topk.merge(df_proba, on="row_in_batch", how="left") # Sort nicely: each row block by importance df_topk = df_topk.sort_values(["row_in_batch", "abs_shap_value"], ascending=[True, False]) st.dataframe(df_topk, use_container_width=True) st.download_button( "Download Top-K SHAP per row (CSV)", df_topk.to_csv(index=False).encode("utf-8"), file_name=f"shap_top{top_k}_per_row_first{batch_n}.csv", mime="text/csv", key="dl_topk_shap_csv" ) st.markdown(f"### Global SHAP summary (first {batch_n} rows)") # BAR SUMMARY plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen) shap.summary_plot( shap_vals_batch, features=X_dense, feature_names=names, plot_type="bar", max_display=max_display_batch, show=False ) fig_bar = plt.gcf() render_plot_with_download(fig_bar, title="SHAP bar summary", filename="shap_summary_bar.png", export_dpi=export_dpi) # BEESWARM SUMMARY (optional) if show_beeswarm: plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen) shap.summary_plot( shap_vals_batch, features=X_dense, feature_names=names, max_display=max_display_batch, show=False ) fig_swarm = plt.gcf() render_plot_with_download(fig_swarm, title="SHAP beeswarm", filename="shap_beeswarm.png", export_dpi=export_dpi) st.markdown("### Waterfall plots (batch)") rows_to_plot = st.multiselect( "Select rows (within the first batch) to plot waterfalls", options=list(range(batch_n)), default=[0], key="batch_rows_to_plot" ) max_display_single = st.slider("Top features to display (single-row SHAP)",5, 40, 20, 1,key="single_max_display") max_waterfalls = st.slider("Max waterfall plots to render", 1, 10, 3, 1, key="max_waterfalls") rows_to_plot = rows_to_plot[:max_waterfalls] explainer = st.session_state.get("explainer") base = explainer.expected_value if not np.isscalar(base): base = float(np.array(base).reshape(-1)[0]) for r in rows_to_plot: st.markdown(f"**Row {r} (within first {batch_n})**") exp = shap.Explanation( values=shap_vals_batch[r], base_values=float(base), data=X_dense[r], feature_names=names, ) # Waterfall plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen) shap.plots.waterfall(exp, show=False, max_display=max_display_single) fig_w = plt.gcf() render_plot_with_download( fig_w, title=f"Batch SHAP waterfall (row {r})", filename=f"shap_waterfall_batch_row_{r}.png", export_dpi=export_dpi, key=f"dl_wf_batch_{r}" ) # Bar plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen) shap.plots.bar(exp, show=False, max_display=max_display_single) fig_b = plt.gcf() render_plot_with_download( fig_b, title=f"Batch SHAP bar (row {r})", filename=f"shap_bar_batch_row_{r}.png", export_dpi=export_dpi, key=f"dl_bar_batch_{r}" ) # Single row SHAP block st.subheader("SHAP explanation") # 1) Form only computes / updates cached explanation with st.form("shap_form"): row = st.number_input("Row index", 0, len(X_inf) - 1, int(st.session_state.get("shap_row", 0))) explain_btn = st.form_submit_button("Generate SHAP explanation") if explain_btn: st.session_state.shap_row = int(row) X_one = X_inf.iloc[[int(row)]] X_one_t = transform_before_clf(pipe, X_one) explainer = st.session_state.get("explainer") explainer_sig = st.session_state.get("explainer_sig") # Create a simple signature that changes if model changes or background changes # (using version + number of background rows is usually enough) current_sig = ( selected, # or meta.get("created_at_utc") or meta.get("metrics", {}).get("roc_auc") None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"])) ) if explainer is None or explainer_sig != current_sig: X_bg = st.session_state.get("X_bg_for_shap") if X_bg is None: st.error("SHAP background not available. Admin must publish latest/background.csv.") st.stop() st.session_state.explainer = build_shap_explainer(pipe, X_bg) st.session_state.explainer_sig = current_sig explainer = st.session_state.explainer shap_vals = explainer.shap_values(X_one_t) if isinstance(shap_vals, list): shap_vals = shap_vals[1] # positive class names = get_final_feature_names(pipe) if len(names) != shap_vals.shape[1]: names = [f"f{i}" for i in range(shap_vals.shape[1])] try: x_dense = X_one_t.toarray()[0] except Exception: x_dense = np.array(X_one_t)[0] base = explainer.expected_value if not np.isscalar(base): base = float(np.array(base).reshape(-1)[0]) exp = shap.Explanation( values=shap_vals[0], base_values=float(base), data=x_dense, feature_names=names, ) # Cache for re-plotting when sliders change st.session_state.shap_single_exp = exp # 2) Plot section OUTSIDE the form (will rerun on slider changes) if "shap_single_exp" in st.session_state: exp = st.session_state.shap_single_exp max_display_single_row = st.slider( "Top features to display (single row)", 5, 40, int(st.session_state.get("shap_single_max", 20)), 1, key="shap_single_max" ) c1, c2 = st.columns(2) with c1: st.markdown("**Waterfall**") plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen) shap.plots.waterfall(exp, show=False, max_display=max_display_single_row) fig_w = plt.gcf() render_plot_with_download( fig_w, title="SHAP waterfall", filename="shap_waterfall_row.png", export_dpi=export_dpi, key="dl_shap_wf_single" ) with c2: st.markdown("**Top features**") plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen) shap.plots.bar(exp, show=False, max_display=max_display_single_row) fig_b = plt.gcf() render_plot_with_download( fig_b, title="SHAP bar", filename="shap_bar_row.png", export_dpi=export_dpi, key="dl_shap_bar_single" ) else: st.info("Submit a row index to generate the SHAP explanation.")