|
|
import json |
|
|
from datetime import datetime |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import streamlit as st |
|
|
import joblib |
|
|
|
|
|
|
|
|
import os |
|
|
from huggingface_hub import hf_hub_download, HfApi |
|
|
import hmac |
|
|
from sklearn.metrics import ( |
|
|
roc_auc_score, accuracy_score, |
|
|
roc_curve, confusion_matrix, |
|
|
precision_score, recall_score, f1_score, |
|
|
balanced_accuracy_score, |
|
|
precision_recall_curve, average_precision_score, |
|
|
brier_score_loss |
|
|
) |
|
|
import shap |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
from sklearn.calibration import calibration_curve |
|
|
from sklearn.feature_selection import VarianceThreshold, SelectFromModel |
|
|
from sklearn.decomposition import TruncatedSVD |
|
|
|
|
|
from sklearn.pipeline import Pipeline |
|
|
from sklearn.compose import ColumnTransformer |
|
|
from sklearn.preprocessing import OneHotEncoder, StandardScaler |
|
|
from sklearn.impute import SimpleImputer |
|
|
from sklearn.linear_model import LogisticRegression |
|
|
from sklearn.model_selection import train_test_split |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_fig(figsize=(5.5, 3.6), dpi=120):
    """Create a single-axes matplotlib figure at the app's standard size.

    Returns the (figure, axes) pair produced by ``plt.subplots``.
    """
    return plt.subplots(figsize=figsize, dpi=dpi)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fig_to_png_bytes(fig, dpi=600):
    """Render *fig* to an in-memory PNG (tight bounding box) and return the bytes."""
    import io

    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=int(dpi), bbox_inches="tight")
        return buf.getvalue()
|
|
|
|
|
from typing import Optional |
|
|
|
|
|
def render_plot_with_download(
    fig,
    *,
    title: str,
    filename: str,
    export_dpi: int = 600,
    key: Optional[str] = None
):
    """Display *fig* in Streamlit and offer a high-resolution PNG download.

    Parameters
    ----------
    fig : matplotlib figure to render (closed after rendering).
    title : human-readable plot name used in the button label.
    filename : download file name; also used to derive the widget key.
    export_dpi : resolution of the exported PNG.
    key : optional explicit Streamlit widget key; defaults to one derived
        from *filename*.
    """
    png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
    st.pyplot(fig, clear_figure=False)
    st.download_button(
        label=f"Download {title} (PNG {export_dpi} dpi)",
        data=png_bytes,
        file_name=filename,
        mime="image/png",
        # Bug fix: the previous default was the constant literal "dl_(unknown)"
        # (an f-string with no placeholder), so every plot on a page shared one
        # widget key and Streamlit raised DuplicateWidgetID. Derive the key
        # from the filename instead so each button is unique.
        key=key or f"dl_{filename}"
    )
    # Free the figure once both the inline render and the PNG export exist.
    plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Name of the mandatory binary label column in the training spreadsheet.
LABEL_COL = "Outcome Event"

# Candidate header spellings for the survival-analysis columns. Matching is
# done case/space/underscore-insensitively (see norm_col/find_col), so these
# only need to cover distinct wordings, not formatting variants.
SURV_START_COL_CANDS = ["Date of 1st Bone Marrow biopsy (Date of Diagnosis)", "Diagnosis Date", "Dx Date"]
SURV_DEATH_COL_CANDS = ["death_date", "Death Date", "Date of Death"]
SURV_CENSOR_COL_CANDS = ["last_followup_date", "Last Follow up date", "Last Follow-up Date", "Last Seen Date"]
SURV_EVENT_COL_CANDS = ["Outcome Event", "Death", "Event"]
|
|
|
|
|
|
|
|
|
|
|
def get_feature_cols_from_df(df: pd.DataFrame):
    """Return every column except the outcome label, preserving sheet order.

    Raises ValueError when the mandatory label column is absent.
    """
    columns = list(df.columns)
    if LABEL_COL not in columns:
        raise ValueError(f"Missing required label column '{LABEL_COL}'. "
                         f"Ensure your Excel header row contains a column named '{LABEL_COL}'.")
    return [name for name in columns if name != LABEL_COL]
|
|
|
|
|
from datetime import date |
|
|
|
|
|
def to_date_safe(x):
    """Coerce a date/datetime/string-like value to a python ``date``.

    Returns None for None, float NaN, or anything pandas cannot parse.
    """
    if x is None:
        return None
    if isinstance(x, float) and np.isnan(x):
        return None
    # datetime must be checked before date: every datetime IS a date.
    if isinstance(x, datetime):
        return x.date()
    if isinstance(x, date):
        return x
    try:
        return pd.to_datetime(x).date()
    except Exception:
        return None
|
|
|
|
|
def age_years_at(dob, ref_date):
    """Age in fractional years at *ref_date*.

    Returns np.nan when either date is unparseable or *ref_date* precedes
    the date of birth.
    """
    born = to_date_safe(dob)
    ref = to_date_safe(ref_date)
    if born is None or ref is None or ref < born:
        return np.nan
    return (ref - born).days / 365.25
|
|
|
|
|
def build_survival_targets(df: pd.DataFrame):
    """Derive survival targets (time_days, event01, used_columns) from *df*.

    - Start time: first column matching SURV_START_COL_CANDS.
    - Event indicator: numeric 1, or YES-like strings, counts as an event.
    - End time: death date for events, last follow-up date otherwise.
    Negative durations (end before start) are replaced with NaN.

    Returns
    -------
    (time_days, event01, used) where *used* maps role -> matched column name.

    Raises
    ------
    ValueError when no diagnosis-date or event column can be found.
    """
    start_col = find_col(df, SURV_START_COL_CANDS)
    death_col = find_col(df, SURV_DEATH_COL_CANDS)
    censor_col = find_col(df, SURV_CENSOR_COL_CANDS)
    event_col = find_col(df, SURV_EVENT_COL_CANDS)

    if start_col is None or event_col is None:
        raise ValueError("Survival requires at minimum Diagnosis date and Outcome/Event column.")

    start = pd.to_datetime(df[start_col], errors="coerce")

    ev_raw = df[event_col]
    if pd.api.types.is_numeric_dtype(ev_raw):
        ev = ev_raw.fillna(0).astype(int).eq(1)
    else:
        ev = (ev_raw.astype(str).str.strip().str.upper().isin(["YES","Y","1","TRUE","T","DEAD","DIED"]))

    # Bug fix: the fallback NaT series previously used a default RangeIndex
    # (pd.Series([pd.NaT]*len(df))); when df carries a filtered/non-default
    # index, .where() and the date subtraction below align on index labels
    # and would silently produce NaT/NaN rows. Build the fallback on df.index.
    death = pd.to_datetime(df[death_col], errors="coerce") if death_col else pd.Series(pd.NaT, index=df.index)
    last = pd.to_datetime(df[censor_col], errors="coerce") if censor_col else pd.Series(pd.NaT, index=df.index)

    # Events end at death; censored rows end at last follow-up.
    end = death.where(ev, last)
    time_days = (end - start).dt.days.astype("float")
    time_days = time_days.where(time_days >= 0, np.nan)

    used = {
        "start": start_col,
        "event": event_col,
        "death": death_col,
        "censor": censor_col,
    }

    return time_days.to_numpy(), ev.astype(int).to_numpy(), used
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
|
def align_columns_to_schema(df: pd.DataFrame, required_cols: list[str]) -> pd.DataFrame:
    """Rename df columns onto the exact training-schema names.

    Columns are matched after norm_col() normalization, so newline, spacing,
    underscore and case differences are tolerated. Unmatched columns are
    left untouched.
    """
    actual_by_norm = {norm_col(name): name for name in df.columns}

    to_rename = {}
    for wanted in required_cols:
        found = actual_by_norm.get(norm_col(wanted))
        if found is not None and found != wanted:
            to_rename[found] = wanted

    return df.rename(columns=to_rename) if to_rename else df
|
|
|
|
|
|
|
|
def norm_col(s: str) -> str:
    """Normalize a column name for robust matching.

    Steps: treat NBSP as a space, collapse runs of whitespace, strip the
    ends, drop underscores, lower-case. None becomes "".
    """
    text = str(s) if s is not None else ""
    text = text.replace("\u00A0", " ")
    text = re.sub(r"\s+", " ", text).strip()
    return text.replace("_", "").lower()
|
|
|
|
|
def find_col(df: pd.DataFrame, candidates: list[str]) -> str | None:
    """Return the first df column matching any candidate after normalization, else None."""
    by_norm = {norm_col(name): name for name in df.columns}
    for wanted in candidates:
        hit = by_norm.get(norm_col(wanted))
        if hit is not None:
            return hit
    return None
|
|
|
|
|
def train_survival_bundle(
    df: pd.DataFrame,
    feature_cols: list[str],
    num_cols: list[str],
    cat_cols: list[str],
    time_days: np.ndarray,
    event01: np.ndarray,
    *,
    penalizer: float = 0.1
):
    """Fit a penalized Cox proportional-hazards model on the given features.

    Returns (bundle_dict, notes). Raises ValueError on hard failures
    (no usable predictors, too few events). Lazy-imports lifelines so the
    dependency is only needed when survival modelling is requested.
    """
    from lifelines import CoxPHFitter
    from sklearn.impute import SimpleImputer

    # Work on a copy of the feature frame; normalize pandas NA to np.nan.
    df_surv = df[feature_cols].copy().replace({pd.NA: np.nan})

    # Coerce numeric features; unparseable values become NaN.
    for c in num_cols:
        if c in df_surv.columns:
            df_surv[c] = pd.to_numeric(df_surv[c], errors="coerce")

    # Categoricals: force object dtype, keep NaN as NaN, stringify the rest.
    for c in cat_cols:
        if c in df_surv.columns:
            df_surv[c] = df_surv[c].astype("object")
            df_surv.loc[df_surv[c].isna(), c] = np.nan
            df_surv[c] = df_surv[c].map(lambda v: v if pd.isna(v) else str(v))

    # Attach targets and drop rows lacking either duration or event.
    df_surv["time_days"] = time_days
    df_surv["event"] = event01
    df_surv = df_surv.dropna(subset=["time_days", "event"])

    duration_col = "time_days"
    event_col = "event"

    # One-hot encode categoricals (reference level dropped).
    df_surv_oh = pd.get_dummies(df_surv, columns=cat_cols, drop_first=True)
    if duration_col not in df_surv_oh.columns or event_col not in df_surv_oh.columns:
        raise ValueError("Survival DF missing duration/event columns after one-hot encoding.")

    # Drop duplicated column names that one-hot expansion may have produced.
    df_surv_oh = df_surv_oh.loc[:, ~df_surv_oh.columns.duplicated()].copy()

    X_cols = [c for c in df_surv_oh.columns if c not in (duration_col, event_col)]

    # Ensure every predictor is numeric before variance filtering.
    df_surv_oh[X_cols] = df_surv_oh[X_cols].apply(pd.to_numeric, errors="coerce")

    # First variance pass: remove constant predictors (they break Cox fitting).
    var0 = df_surv_oh[X_cols].var(skipna=True)
    X_cols = [c for c in X_cols if pd.notna(var0.get(c, np.nan)) and var0[c] > 0]

    if len(X_cols) == 0:
        raise ValueError("All Cox predictors are zero-variance after one-hot; cannot fit survival model.")

    # Median-impute, then re-check variance (imputation can flatten columns).
    imp_tmp = SimpleImputer(strategy="median")
    X_tmp = imp_tmp.fit_transform(df_surv_oh[X_cols])
    df_surv_oh.loc[:, X_cols] = pd.DataFrame(X_tmp, columns=X_cols, index=df_surv_oh.index)

    var1 = df_surv_oh[X_cols].var(skipna=True)
    X_cols = [c for c in X_cols if pd.notna(var1.get(c, np.nan)) and var1[c] > 0]
    if len(X_cols) == 0:
        raise ValueError("All Cox predictors became zero-variance after imputation; cannot fit survival model.")

    # Require a minimum number of observed events for a stable fit.
    n_events = int(np.sum(df_surv_oh[event_col].astype(int).to_numpy()))
    if n_events < 5:
        raise ValueError(f"Too few events for Cox model (events={n_events}).")

    # Cap predictor count at max(10, 2 * events), keeping the highest-variance
    # columns, to limit overfitting relative to the event count.
    if len(X_cols) > max(10, 2 * n_events):
        k = max(10, 2 * n_events)
        vars_sorted = var1[X_cols].sort_values(ascending=False)
        X_cols = list(vars_sorted.head(k).index)

    # Final imputer fitted on the surviving columns only; stored in the bundle
    # so inference can apply the identical transform.
    imp = SimpleImputer(strategy="median")
    X_final = imp.fit_transform(df_surv_oh[X_cols])
    df_surv_oh.loc[:, X_cols] = pd.DataFrame(X_final, columns=X_cols, index=df_surv_oh.index)

    # NOTE(review): max(penalizer, 1.0) floors the ridge penalty at 1.0, so the
    # `penalizer` parameter (default 0.1) has no effect for values below 1.0.
    # Looks unintentional — confirm whether the floor or the parameter is the
    # desired behavior.
    cph = CoxPHFitter(penalizer=float(max(penalizer, 1.0)), l1_ratio=0.0)
    cph.fit(df_surv_oh[[*X_cols, duration_col, event_col]],
            duration_col=duration_col,
            event_col=event_col)

    # Everything needed to reproduce the transform + model at inference time.
    bundle = {
        "model": cph,
        "columns": X_cols,
        "imputer": imp,
        "cat_cols": cat_cols,
        "num_cols": num_cols,
        "feature_cols": feature_cols,
        "duration_col": duration_col,
        "event_col": event_col,
        "version": 2
    }
    return bundle, "Survival model trained successfully."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_pipeline(
    num_cols,
    cat_cols,
    *,
    use_feature_selection: bool = True,
    l1_C: float = 1.0,
    selection_max_features: int | None = None,
    use_dimred: bool = False,
    svd_components: int = 50,
    svd_random_state: int = 42,
):
    """Build the classification pipeline.

    - No limit on raw variables.
    - Drops useless columns:
      (a) zero-variance after preprocessing
      (b) L1-based selection (optional)
    - Optional dimensionality reduction via TruncatedSVD (sparse-friendly).

    Parameters
    ----------
    num_cols, cat_cols : numeric / categorical feature column names.
    use_feature_selection : enable the L1-based SelectFromModel step.
    l1_C : inverse regularization strength for the L1 selector.
    selection_max_features : optional cap on features kept by the selector
        (None = no cap, previous behavior).
    use_dimred : insert a TruncatedSVD step after the variance filter.
    svd_components, svd_random_state : TruncatedSVD configuration.
    """
    # Numeric branch: median imputation then standardization.
    num_pipe = Pipeline([
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler())
    ])

    # Categorical branch: mode imputation then sparse one-hot (first level
    # dropped; unseen categories at inference are ignored, not errors).
    cat_pipe = Pipeline([
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=True, drop="first"))
    ])

    preprocessor = ColumnTransformer(
        transformers=[
            ("num", num_pipe, num_cols),
            ("cat", cat_pipe, cat_cols)
        ],
        remainder="drop",
        verbose_feature_names_out=False
    )

    steps = []
    steps.append(("preprocess", preprocessor))

    # Remove exactly-constant columns produced by preprocessing.
    steps.append(("vt", VarianceThreshold(threshold=0.0)))

    if use_dimred:
        steps.append(("svd", TruncatedSVD(
            n_components=int(svd_components),
            random_state=int(svd_random_state)
        )))

    if use_feature_selection:
        selector_est = LogisticRegression(
            penalty="l1", solver="saga", C=float(l1_C),
            max_iter=5000, n_jobs=-1,
            class_weight="balanced"
        )

        # Bug fix: `selection_max_features` was previously accepted but never
        # used. Pass it through so callers can cap the retained feature count;
        # the default None preserves the old unlimited behavior.
        selector = SelectFromModel(
            selector_est,
            threshold="median",
            max_features=selection_max_features,
        )
        steps.append(("select", selector))

    # Final classifier; class_weight compensates for label imbalance.
    clf = LogisticRegression(max_iter=5000, solver="lbfgs", class_weight="balanced")

    steps.append(("clf", clf))

    return Pipeline(steps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def coerce_binary_label(y: pd.Series):
    """Coerce a two-valued label series to a 0/1 integer array.

    Returns (y01, positive_class). The positive class is the larger value
    for numeric labels, True for booleans, otherwise the lexicographically
    last string representation. Raises ValueError unless exactly two
    distinct non-null values are present.
    """
    observed = y.dropna()
    classes = list(pd.unique(observed))
    if len(classes) != 2:
        raise ValueError(f"Outcome Event must be binary (2 unique values). Found: {classes}")

    # Numeric labels: the larger value is the positive class.
    if pd.api.types.is_numeric_dtype(observed):
        positive = sorted(classes)[-1]
        return (y == positive).astype(int).to_numpy(), positive

    # Boolean labels: True is the positive class.
    if observed.dtype == bool:
        return y.astype(int).to_numpy(), True

    # String-like labels: positive = lexicographically last string form.
    as_strings = sorted(str(c) for c in classes)
    positive = as_strings[-1]
    return y.astype(str).eq(positive).astype(int).to_numpy(), positive
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_classification_metrics(y_true, y_proba, threshold: float = 0.5):
    """Binarize probabilities at *threshold* and compute standard metrics.

    Returns a dict with confusion-matrix cells plus sensitivity, specificity,
    precision, recall, F1, accuracy and balanced accuracy (all floats/ints).
    """
    y_pred = (y_proba >= threshold).astype(int)

    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    # Guard the hand-computed rates against empty classes (avoid 0/0).
    sens = tp / (tp + fn) if (tp + fn) else 0.0
    spec = tn / (tn + fp) if (tn + fp) else 0.0

    return {
        "threshold": float(threshold),
        "tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp),
        "sensitivity": float(sens),
        "specificity": float(spec),
        "precision": float(precision_score(y_true, y_pred, zero_division=0)),
        "recall": float(recall_score(y_true, y_pred, zero_division=0)),
        "f1": float(f1_score(y_true, y_pred, zero_division=0)),
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "balanced_accuracy": float(balanced_accuracy_score(y_true, y_pred)),
    }
|
|
|
|
|
def find_best_threshold(y_true, y_proba, metric="f1"):
    """Grid-search thresholds in [0.05, 0.95] for the best value of *metric*.

    Returns (best_threshold, best_value, metrics_dict_at_best).
    """
    best_t = 0.5
    best_val = -1
    best_cls = None
    for t in np.linspace(0.05, 0.95, 181):
        candidate = compute_classification_metrics(y_true, y_proba, threshold=float(t))
        score = candidate.get(metric, 0.0)
        if score > best_val:
            best_t, best_val, best_cls = float(t), score, candidate
    return best_t, best_val, best_cls
|
|
|
|
|
def find_best_threshold_f1(y_true, y_proba, t_min=0.01, t_max=0.99, n=199):
    """Return (threshold, metrics_dict) maximizing F1 over an even grid."""
    best_thr = 0.5
    best_f1 = -1.0
    best_metrics = None

    for t in np.linspace(float(t_min), float(t_max), int(n)):
        metrics = compute_classification_metrics(y_true, y_proba, threshold=float(t))
        if metrics["f1"] > best_f1:
            best_thr = float(t)
            best_f1 = float(metrics["f1"])
            best_metrics = metrics

    return best_thr, best_metrics
|
|
|
|
|
|
|
|
|
|
|
def bootstrap_internal_validation(
    df: pd.DataFrame,
    feature_cols: list[str],
    num_cols: list[str],
    cat_cols: list[str],
    *,
    n_boot: int = 1000,
    threshold: float = 0.5,
    use_feature_selection: bool = True,
    l1_C: float = 1.0,
    use_dimred: bool = False,
    svd_components: int = 50,
    random_state: int = 42,
):
    """
    Bootstrap internal validation using out-of-bag (OOB) evaluation.

    For each bootstrap replicate:
      - sample N rows with replacement
      - fit the classification pipeline on the bootstrap sample
      - evaluate on the OOB rows (rows not sampled)

    Returns
    -------
    df_boot : per-iteration metrics DataFrame
    summary : mean, median, 2.5% and 97.5% quantiles for key metrics
    n_skipped : iterations skipped (empty OOB, single-class OOB, label
        coercion failure, or pipeline fit/predict failure)
    """
    rng = np.random.default_rng(int(random_state))
    n = len(df)
    idx_all = np.arange(n)

    rows = []
    n_skipped = 0

    for b in range(int(n_boot)):
        # Bootstrap sample indices; the complement forms the OOB set.
        boot_idx = rng.integers(0, n, size=n)
        oob_mask = np.ones(n, dtype=bool)
        oob_mask[boot_idx] = False
        oob_idx = idx_all[oob_mask]

        if len(oob_idx) == 0:
            n_skipped += 1
            continue

        df_boot = df.iloc[boot_idx].copy()
        df_oob = df.iloc[oob_idx].copy()

        Xb = df_boot[feature_cols].copy().replace({pd.NA: np.nan})
        yb_raw = df_boot[LABEL_COL].copy()

        Xo = df_oob[feature_cols].copy().replace({pd.NA: np.nan})
        yo_raw = df_oob[LABEL_COL].copy()

        # Coerce numerics in both splits; unparseable values become NaN.
        for c in num_cols:
            Xb[c] = pd.to_numeric(Xb[c], errors="coerce")
            Xo[c] = pd.to_numeric(Xo[c], errors="coerce")

        # Categoricals: object dtype, NaN preserved, everything else stringified.
        for c in cat_cols:
            Xb[c] = Xb[c].astype("object")
            Xb.loc[Xb[c].isna(), c] = np.nan
            Xb[c] = Xb[c].map(lambda v: v if pd.isna(v) else str(v))

            Xo[c] = Xo[c].astype("object")
            Xo.loc[Xo[c].isna(), c] = np.nan
            Xo[c] = Xo[c].map(lambda v: v if pd.isna(v) else str(v))

        # A resample may contain only one label class; skip such iterations.
        try:
            yb01, _ = coerce_binary_label(yb_raw)
            yo01, _ = coerce_binary_label(yo_raw)
        except Exception:
            n_skipped += 1
            continue

        # ROC AUC is undefined on a single-class OOB set.
        if len(np.unique(yo01)) < 2:
            n_skipped += 1
            continue

        # Fresh pipeline per replicate so selection/fitting never leaks
        # information across iterations.
        pipe_b = build_pipeline(
            num_cols, cat_cols,
            use_feature_selection=use_feature_selection,
            l1_C=l1_C,
            use_dimred=use_dimred,
            svd_components=svd_components,
        )

        try:
            pipe_b.fit(Xb, yb01)
            proba_oob = pipe_b.predict_proba(Xo)[:, 1]
        except Exception:
            n_skipped += 1
            continue

        try:
            auc = float(roc_auc_score(yo01, proba_oob))
        except Exception:
            auc = np.nan

        pr = compute_pr_curve(yo01, proba_oob)
        cal = compute_calibration(yo01, proba_oob, n_bins=10, strategy="uniform")
        cls = compute_classification_metrics(yo01, proba_oob, threshold=float(threshold))

        rows.append({
            "iter": int(b),
            "n_oob": int(len(oob_idx)),
            "roc_auc": auc,
            "avg_precision": float(pr["average_precision"]),
            "brier": float(cal["brier"]),
            "sensitivity": float(cls["sensitivity"]),
            "specificity": float(cls["specificity"]),
            "precision": float(cls["precision"]),
            "recall": float(cls["recall"]),
            "f1": float(cls["f1"]),
            "accuracy": float(cls["accuracy"]),
            "balanced_accuracy": float(cls["balanced_accuracy"]),
        })

    # NOTE: reuses the name df_boot (the last loop sample) for the results
    # frame; intentional here since the loop is finished.
    df_boot = pd.DataFrame(rows)

    def summarise(col: str):
        # Percentile (quantile) summary of one metric across replicates.
        s = df_boot[col].dropna()
        if len(s) == 0:
            return {"mean": np.nan, "median": np.nan, "ci2.5": np.nan, "ci97.5": np.nan, "n": 0}
        return {
            "mean": float(s.mean()),
            "median": float(s.median()),
            "ci2.5": float(np.quantile(s, 0.025)),
            "ci97.5": float(np.quantile(s, 0.975)),
            "n": int(len(s))
        }

    summary = {
        "roc_auc": summarise("roc_auc"),
        "avg_precision": summarise("avg_precision"),
        "brier": summarise("brier"),
        "sensitivity": summarise("sensitivity"),
        "specificity": summarise("specificity"),
        "f1": summarise("f1"),
        "balanced_accuracy": summarise("balanced_accuracy"),
    }

    return df_boot, summary, n_skipped
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_pr_curve(y_true, y_proba):
    """Precision-recall curve plus average precision, as JSON-safe lists."""
    precision, recall, pr_thresholds = precision_recall_curve(y_true, y_proba)
    return {
        "average_precision": float(average_precision_score(y_true, y_proba)),
        "precision": [float(v) for v in precision],
        "recall": [float(v) for v in recall],
        "thresholds": [float(v) for v in pr_thresholds],
    }
|
|
|
|
|
|
|
|
def compute_calibration(y_true, y_proba, n_bins: int = 10, strategy: str = "uniform"):
    """Calibration-curve points and Brier score, as a JSON-safe dict."""
    prob_true, prob_pred = calibration_curve(y_true, y_proba, n_bins=n_bins, strategy=strategy)
    return {
        "n_bins": int(n_bins),
        "strategy": str(strategy),
        "prob_true": [float(v) for v in prob_true],
        "prob_pred": [float(v) for v in prob_pred],
        "brier": float(brier_score_loss(y_true, y_proba)),
    }
|
|
|
|
|
|
|
|
def decision_curve_analysis(y_true, y_proba, thresholds=None):
    """Decision-curve analysis: net benefit of the model vs. treat-all/none.

    Net benefit at probability threshold pt is
        TP/n - (FP/n) * pt/(1 - pt)
    for the model, and the prevalence-based analogue for "treat everyone".
    Returns a JSON-safe dict of the three curves plus the prevalence.
    """
    y_true = np.asarray(y_true).astype(int)
    y_proba = np.asarray(y_proba).astype(float)

    if thresholds is None:
        thresholds = np.linspace(0.01, 0.99, 99)

    n = len(y_true)
    prevalence = float(np.mean(y_true))

    nb_model, nb_all, nb_none = [], [], []

    for pt in thresholds:
        flagged = (y_proba >= pt).astype(int)

        tp = np.sum((flagged == 1) & (y_true == 1))
        fp = np.sum((flagged == 1) & (y_true == 0))

        # Odds weighting of false positives at this threshold.
        odds = pt / (1.0 - pt)

        nb_model.append(float((tp / n) - (fp / n) * odds))
        nb_all.append(float(prevalence - (1.0 - prevalence) * odds))
        nb_none.append(0.0)

    return {
        "thresholds": [float(x) for x in thresholds],
        "net_benefit_model": nb_model,
        "net_benefit_all": nb_all,
        "net_benefit_none": nb_none,
        "prevalence": prevalence,
    }
|
|
|
|
|
|
|
|
def train_and_save(
    df: pd.DataFrame,
    feature_cols,
    num_cols,
    cat_cols,
    n_bins: int,
    cal_strategy: str,
    dca_points: int,
    *,
    use_feature_selection: bool,
    l1_C: float,
    use_dimred: bool,
    svd_components: int,):
    """Train the classifier (and, when possible, the survival model), persist
    artifacts to disk, and return evaluation data.

    Side effects: writes model.joblib, meta.json and (when survival training
    succeeds) survival_bundle.joblib to the working directory, and emits
    Streamlit debug output.

    Returns (pipe, meta, X_train, y_test, proba).
    """
    X = df[feature_cols].copy()
    y_raw = df[LABEL_COL].copy()

    # Normalize pandas NA to np.nan before type coercion.
    X = X.replace({pd.NA: np.nan})

    for c in num_cols:
        X[c] = pd.to_numeric(X[c], errors="coerce")

    # Categoricals: object dtype, NaN preserved, everything else stringified.
    for c in cat_cols:
        X[c] = X[c].astype("object")
        X.loc[X[c].isna(), c] = np.nan
        X[c] = X[c].map(lambda v: v if pd.isna(v) else str(v))

    y01, pos_class = coerce_binary_label(y_raw)

    # NOTE(review): this survival-target computation is dead code — the same
    # call is repeated (and its results actually used) further below; the
    # values computed here are overwritten before use.
    surv_used_cols = None
    try:
        time_days, event01, surv_used_cols = build_survival_targets(df)
    except Exception:
        time_days, event01, surv_used_cols = None, None, None

    # Stratified 80/20 hold-out split for classifier evaluation.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y01, test_size=0.2, random_state=42, stratify=y01
    )

    pipe = build_pipeline(
        num_cols, cat_cols,
        use_feature_selection=use_feature_selection,
        l1_C=l1_C,
        use_dimred=use_dimred,
        svd_components=svd_components
    )

    pipe.fit(X_train, y_train)
    proba = pipe.predict_proba(X_test)[:, 1]

    # Test-set discrimination / threshold metrics.
    roc_auc = float(roc_auc_score(y_test, proba))
    fpr, tpr, roc_thresholds = roc_curve(y_test, proba)
    cls_05 = compute_classification_metrics(y_test, proba, threshold=0.5)
    best_thr, best_val, cls_best = find_best_threshold(y_test, proba, metric="f1")

    # JSON-serializable evaluation bundle: metrics at 0.5, at the F1-optimal
    # threshold, plus the full ROC/PR/calibration/decision curves.
    metrics = {
        "roc_auc": roc_auc,
        "n_train": int(len(X_train)),
        "n_test": int(len(X_test)),

        "threshold@0.5": 0.5,
        "accuracy@0.5": cls_05["accuracy"],
        "balanced_accuracy@0.5": cls_05["balanced_accuracy"],
        "precision@0.5": cls_05["precision"],
        "recall@0.5": cls_05["recall"],
        "f1@0.5": cls_05["f1"],
        "sensitivity@0.5": cls_05["sensitivity"],
        "specificity@0.5": cls_05["specificity"],
        "confusion_matrix@0.5": {
            "tn": cls_05["tn"], "fp": cls_05["fp"],
            "fn": cls_05["fn"], "tp": cls_05["tp"],
        },

        "best_threshold_by": "f1",
        "best_threshold": float(best_thr),
        "best_f1": float(cls_best["f1"]),
        "accuracy@best": cls_best["accuracy"],
        "balanced_accuracy@best": cls_best["balanced_accuracy"],
        "precision@best": cls_best["precision"],
        "recall@best": cls_best["recall"],
        "f1@best": cls_best["f1"],
        "sensitivity@best": cls_best["sensitivity"],
        "specificity@best": cls_best["specificity"],
        "confusion_matrix@best": {
            "tn": cls_best["tn"], "fp": cls_best["fp"],
            "fn": cls_best["fn"], "tp": cls_best["tp"],
        },

        "roc_curve": {
            "fpr": [float(x) for x in fpr],
            "tpr": [float(x) for x in tpr],
            "thresholds": [float(x) for x in roc_thresholds],
        },
        "pr_curve": compute_pr_curve(y_test, proba),
        "calibration": compute_calibration(y_test, proba, n_bins, cal_strategy),
        "decision_curve": decision_curve_analysis(y_test, proba, np.linspace(0.01, 0.99, dca_points)),
    }

    from sklearn.impute import SimpleImputer

    # --- Survival model (best-effort; failures are recorded, not raised) ---
    survival_trained = False
    surv_notes = None
    surv_used_cols = None

    try:
        time_days, event01, surv_used_cols = build_survival_targets(df)
    except Exception:
        time_days, event01, surv_used_cols = None, None, None

    if time_days is not None and event01 is not None:
        try:
            bundle, surv_notes = train_survival_bundle(
                df=df,
                feature_cols=feature_cols,
                num_cols=num_cols,
                cat_cols=cat_cols,
                time_days=time_days,
                event01=event01,
                penalizer=0.1
            )
            joblib.dump(bundle, "survival_bundle.joblib", compress=3)
            survival_trained = True
        except Exception as e:
            survival_trained = False
            surv_notes = f"Survival model training failed: {e}"
    else:
        surv_notes = "Survival columns missing or could not be parsed; survival model not trained."

    # Persist the fitted classification pipeline.
    joblib.dump(pipe, "model.joblib",compress=3)

    # Metadata describing schema, configuration and evaluation results;
    # written next to the model so inference can validate inputs.
    # NOTE(review): datetime.utcnow() is deprecated in Python 3.12+ — consider
    # datetime.now(timezone.utc).
    meta = {
        "framework": "Explainable-Acute-Leukemia-Mortality-Predictor",
        "model": "Logistic Regression",
        "created_at_utc": datetime.utcnow().isoformat(),
        "schema": {
            "features": feature_cols,
            "numeric": num_cols,
            "categorical": cat_cols,
            "label": LABEL_COL,
        },
        "feature_reduction": {
            "variance_threshold": 0.0,
            "use_dimred": bool(use_dimred),
            "svd_components": int(svd_components) if use_dimred else None,
            "use_feature_selection": bool(use_feature_selection),
            "l1_C": float(l1_C) if use_feature_selection else None,
            "selection_method": "SelectFromModel(L1 saga, threshold=median)" if use_feature_selection else None,
            "note": "If SVD is enabled, SHAP becomes component-level (less interpretable)."
        },
        "shap_background": {
            "file": "background.csv",
            "max_rows": 200,
            "note": "Raw (pre-transform) background sample for SHAP LinearExplainer."
        },
        "survival": {
            "enabled": bool(survival_trained),
            "duration_units": "days",
            "model": "CoxPHFitter (lifelines) with one-hot for categoricals",
            "required_columns_any_of": {
                "start": SURV_START_COL_CANDS,
                "event": SURV_EVENT_COL_CANDS,
                "death": SURV_DEATH_COL_CANDS,
                "censor": SURV_CENSOR_COL_CANDS,
            },
            "used_columns": surv_used_cols,
            "notes": surv_notes
        },

        "positive_class": str(pos_class),
        "metrics": metrics,
    }

    with open("meta.json", "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    # Debug output for the Streamlit UI.
    st.write("Local survival bundle exists:", os.path.exists("survival_bundle.joblib"))
    if os.path.exists("survival_bundle.joblib"):
        st.write("Local survival bundle size (bytes):", os.path.getsize("survival_bundle.joblib"))
    # NOTE(review): this second "bundle exists" write appears to duplicate the
    # one above — confirm whether it is intentional.
    st.write("Local survival bundle exists:", os.path.exists("survival_bundle.joblib"))
    st.json(meta.get("survival", {}))

    return pipe, meta, X_train, y_test, proba
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_shap_explainer(pipe, X_bg, max_bg=200):
    """Build a SHAP LinearExplainer for the pipeline's final classifier.

    *X_bg* is a raw (pre-transform) background frame; it is subsampled to at
    most *max_bg* rows and pushed through the pipeline's preprocessing steps
    before being handed to SHAP. Raises ValueError on an empty background.
    """
    if X_bg is None or len(X_bg) == 0:
        raise ValueError("SHAP background is empty.")

    limit = int(max_bg)
    if len(X_bg) > limit:
        X_bg = X_bg.sample(limit, random_state=42)

    final_clf = pipe.named_steps["clf"]
    background = transform_before_clf(pipe, X_bg)

    return shap.LinearExplainer(
        final_clf,
        background,
        feature_perturbation="interventional"
    )
|
|
|
|
|
|
|
|
def safe_dense(Xt, max_rows: int = 200):
    """Densify a (possibly sparse) matrix, truncating to *max_rows* rows first
    so a huge matrix is never fully materialized in memory."""
    if hasattr(Xt, "shape") and Xt.shape[0] > max_rows:
        Xt = Xt[:max_rows]
    try:
        # scipy sparse path
        return Xt.toarray()
    except Exception:
        # already dense / generic array-like
        return np.array(Xt)
|
|
|
|
|
|
|
|
|
|
|
def ensure_model_repo_exists(model_repo_id: str, token: str):
    """Best-effort creation of the Hub model repo.

    Safe to call repeatedly: exist_ok covers the already-exists case and any
    other failure is deliberately swallowed.
    """
    # Generous HTTP timeouts for slow Space networking.
    os.environ["HF_HUB_HTTP_TIMEOUT"] = "300"
    os.environ["HF_HUB_ETAG_TIMEOUT"] = "300"

    api = HfApi(token=token)
    try:
        api.create_repo(repo_id=model_repo_id, repo_type="model", private=False, exist_ok=True)
    except Exception:
        # Best-effort helper: ignore creation failures entirely.
        pass
|
|
|
|
|
def coerce_X_like_schema(X: pd.DataFrame, feature_cols: list[str], num_cols: list[str], cat_cols: list[str]) -> pd.DataFrame:
    """Align an inference frame to the training schema.

    - Columns are matched on normalized names (spaces/newlines/underscores/
      case tolerated) and renamed to the model's exact feature names.
    - Training columns missing from the file are created as NaN (with a
      Streamlit warning) so prediction can still run.
    - Numeric columns are coerced to numbers; categoricals to string/NaN.
    """
    source = X.copy()

    available = {norm_col(name): name for name in source.columns}

    aligned = pd.DataFrame(index=source.index)

    absent = []
    for feature in feature_cols:
        match = available.get(norm_col(feature))
        if match is not None:
            aligned[feature] = source[match]
        else:
            aligned[feature] = np.nan
            absent.append(feature)

    if absent:
        st.warning(
            "Inference file is missing some training columns (filled as blank/NaN): "
            + ", ".join(absent[:12]) + (" ..." if len(absent) > 12 else "")
        )

    aligned = aligned.replace({pd.NA: np.nan})

    for col in num_cols:
        if col in aligned.columns:
            aligned[col] = pd.to_numeric(aligned[col], errors="coerce")

    for col in cat_cols:
        if col in aligned.columns:
            aligned[col] = aligned[col].astype("object")
            aligned.loc[aligned[col].isna(), col] = np.nan
            aligned[col] = aligned[col].map(lambda v: v if pd.isna(v) else str(v))

    return aligned
|
|
|
|
|
|
|
|
|
|
|
def get_shap_background_auto(model_repo_id: str, feature_cols: list[str], num_cols: list[str], cat_cols: list[str]) -> pd.DataFrame | None:
    """Fetch the latest SHAP background from the Hub, coerced to the schema.

    Returns None when the file is unavailable or lacks any training column.
    """
    background = load_latest_background(model_repo_id)
    if background is None:
        return None

    # Require every training column to be present (exact names here; the
    # published background is saved with the model's own feature names).
    if any(col not in background.columns for col in feature_cols):
        return None

    return coerce_X_like_schema(background, feature_cols, num_cols, cat_cols)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_background_sample_csv(X_bg: pd.DataFrame, feature_cols: list[str], max_rows: int = 200, out_path: str = "background.csv"):
    """Persist a small *raw* (pre-transform) background sample for SHAP.

    The output contains exactly the *feature_cols* columns, capped at
    *max_rows* rows (deterministic subsample). Returns *out_path*.
    Raises ValueError on an empty input.
    """
    if X_bg is None or len(X_bg) == 0:
        raise ValueError("X_bg is empty; cannot save background sample.")

    sample = X_bg[feature_cols].copy()

    limit = int(max_rows)
    if len(sample) > limit:
        sample = sample.sample(limit, random_state=42)

    sample.to_csv(out_path, index=False, encoding="utf-8")
    return out_path
|
|
|
|
|
|
|
|
def publish_background_to_hub(model_repo_id: str, version_tag: str, background_path: str = "background.csv"):
    """Upload the SHAP background CSV to the Hugging Face model repo.

    Writes both a versioned copy (releases/<version_tag>/background.csv) and
    the rolling latest/background.csv.
    Requires HF_TOKEN with write permissions.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN not found. Add it in Space Settings → Secrets.")
    api = HfApi(token=token)

    version_bg_path = f"releases/{version_tag}/background.csv"

    # (destination path in repo, commit message) for the two copies we maintain.
    targets = [
        (version_bg_path, f"Upload SHAP background ({version_tag})"),
        ("latest/background.csv", f"Update latest SHAP background ({version_tag})"),
    ]
    for repo_path, message in targets:
        api.upload_file(
            path_or_fileobj=background_path,
            path_in_repo=repo_path,
            repo_id=model_repo_id,
            repo_type="model",
            commit_message=message,
        )

    return {
        "version_bg_path": version_bg_path,
        "latest_bg_path": "latest/background.csv",
    }
|
|
|
|
|
|
|
|
def load_latest_background(model_repo_id: str) -> pd.DataFrame | None:
    """Download and parse latest/background.csv from the model repo.

    Best-effort: any failure (missing file, network error, parse error)
    yields None rather than raising.
    """
    try:
        local_path = hf_hub_download(
            repo_id=model_repo_id,
            repo_type="model",
            filename="latest/background.csv",
        )
        return pd.read_csv(local_path)
    except Exception:
        # Background is optional; callers treat None as "not available".
        return None
|
|
|
|
|
|
|
|
def load_background_by_version(model_repo_id: str, version_tag: str) -> pd.DataFrame | None:
    """Download and parse releases/<version_tag>/background.csv if present.

    Returns None on any failure (missing file, network error, parse error).
    """
    try:
        local_path = hf_hub_download(
            repo_id=model_repo_id,
            repo_type="model",
            filename=f"releases/{version_tag}/background.csv",
        )
        return pd.read_csv(local_path)
    except Exception:
        return None
|
|
|
|
|
|
|
|
def publish_to_hub(model_repo_id: str, version_tag: str):
    """
    Uploads model.joblib and meta.json to a Hugging Face *model* repository
    under a versioned folder (releases/<version_tag>/) and also updates the
    rolling 'latest/' copies. The optional survival_bundle.joblib is published
    the same way when it exists locally.

    Requires Space Secret: HF_TOKEN (write access).
    Returns a dict of the versioned and latest artifact paths in the repo.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN not found. Add it in Space Settings → Secrets.")

    api = HfApi(token=token)

    version_model_path = f"releases/{version_tag}/model.joblib"
    version_meta_path = f"releases/{version_tag}/meta.json"

    def _upload(local_path: str, repo_path: str, message: str):
        # Single place for the repeated upload call (all artifacts share
        # the same repo id / repo type).
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=repo_path,
            repo_id=model_repo_id,
            repo_type="model",
            commit_message=message,
        )

    # Versioned copies.
    _upload("model.joblib", version_model_path, f"Upload model artifacts ({version_tag})")
    _upload("meta.json", version_meta_path, f"Upload metadata ({version_tag})")

    # Rolling 'latest' copies.
    _upload("model.joblib", "latest/model.joblib", f"Update latest model ({version_tag})")
    _upload("meta.json", "latest/meta.json", f"Update latest metadata ({version_tag})")

    # Survival bundle is optional: publish only when training produced it.
    if os.path.exists("survival_bundle.joblib"):
        _upload(
            "survival_bundle.joblib",
            f"releases/{version_tag}/survival_bundle.joblib",
            f"Upload survival model ({version_tag})",
        )
        _upload(
            "survival_bundle.joblib",
            "latest/survival_bundle.joblib",
            f"Update latest survival model ({version_tag})",
        )

    # NOTE(review): a leftover debug call previously listed all repo files and
    # st.write()-ed the survival-related ones into the UI on every publish;
    # that debug output (and its extra API round-trip) has been removed.

    return {
        "version_model_path": version_model_path,
        "version_meta_path": version_meta_path,
        "latest_model_path": "latest/model.joblib",
        "latest_meta_path": "latest/meta.json",
    }
|
|
|
|
|
MODEL_REPO_ID = "Synav/Explainable-Acute-Leukemia-Mortality-Predictor" |
|
|
|
|
|
def list_release_versions(model_repo_id: str):
    """
    Return the version tags that have releases/<version>/model.joblib in the
    model repo, sorted in reverse lexicographic order (newest-looking first).
    """
    # Anonymous access works for public repos; use the token when available.
    api = HfApi(token=os.environ.get("HF_TOKEN") or None)
    repo_files = api.list_repo_files(repo_id=model_repo_id, repo_type="model")

    found: set[str] = set()
    for path in repo_files:
        if not (path.startswith("releases/") and path.endswith("/model.joblib")):
            continue
        segments = path.split("/")
        if len(segments) >= 3:
            # releases/<version>/model.joblib -> <version>
            found.add(segments[1])

    return sorted(found, reverse=True)
|
|
|
|
|
|
|
|
def load_model_by_version(model_repo_id: str, version_tag: str):
    """
    Download and deserialize releases/<version_tag>/{model.joblib, meta.json}.

    Returns (pipeline, meta_dict).
    """
    release_prefix = f"releases/{version_tag}"
    model_file = hf_hub_download(
        repo_id=model_repo_id,
        repo_type="model",
        filename=f"{release_prefix}/model.joblib",
    )
    meta_file = hf_hub_download(
        repo_id=model_repo_id,
        repo_type="model",
        filename=f"{release_prefix}/meta.json",
    )

    pipe = joblib.load(model_file)
    with open(meta_file, "r", encoding="utf-8") as fh:
        meta = json.load(fh)

    return pipe, meta
|
|
|
|
|
def load_latest_survival_model(model_repo_id: str):
    """Download and deserialize latest/survival_bundle.joblib from the model repo."""
    local_path = hf_hub_download(
        repo_id=model_repo_id,
        repo_type="model",
        filename="latest/survival_bundle.joblib",
    )
    return joblib.load(local_path)
|
|
|
|
|
def load_survival_model_by_version(model_repo_id: str, version_tag: str):
    """Download and deserialize releases/<version_tag>/survival_bundle.joblib."""
    local_path = hf_hub_download(
        repo_id=model_repo_id,
        repo_type="model",
        filename=f"releases/{version_tag}/survival_bundle.joblib",
    )
    return joblib.load(local_path)
|
|
|
|
|
def load_latest_survival_bundle(model_repo_id: str):
    """Alias kept for backward compatibility.

    This was a byte-for-byte duplicate of load_latest_survival_model; delegate
    so there is a single download/deserialize code path to maintain.
    """
    return load_latest_survival_model(model_repo_id)
|
|
|
|
|
|
|
|
def load_latest_model(model_repo_id: str):
    """
    Download and deserialize latest/model.joblib and latest/meta.json.

    Returns (pipeline, meta_dict).
    """
    model_file = hf_hub_download(
        repo_id=model_repo_id,
        repo_type="model",
        filename="latest/model.joblib",
    )
    meta_file = hf_hub_download(
        repo_id=model_repo_id,
        repo_type="model",
        filename="latest/meta.json",
    )

    pipe = joblib.load(model_file)
    with open(meta_file, "r", encoding="utf-8") as fh:
        meta = json.load(fh)

    return pipe, meta
|
|
|
|
|
def get_post_selection_feature_names(pipe) -> list[str]:
    """
    Return feature names aligned to the exact columns seen by the final classifier.

    Handles the layout: preprocess -> VarianceThreshold ('vt') ->
    SelectFromModel ('select') -> clf. If an 'svd' step is present, component
    labels replace the named features. Missing steps degrade gracefully.

    Bug fix: when the preprocessor cannot report names, the generic f0..fN
    fallback is sized from clf.coef_.shape[1] — i.e. AFTER vt/select have
    already filtered columns. The old code built that fallback up front and
    then ran it through the vt/select masks again, producing wrong (truncated)
    names. The fallback is now deferred until after masking, mirroring
    get_final_feature_names.

    Raises ValueError if the pipeline has no 'preprocess' step.
    """
    pre = pipe.named_steps.get("preprocess", None)
    if pre is None:
        raise ValueError("Pipeline missing 'preprocess' step.")

    try:
        names = list(pre.get_feature_names_out())
    except Exception:
        names = None  # resolved via the classifier-width fallback below

    # VarianceThreshold mask (only meaningful for real preprocessor names).
    vt = pipe.named_steps.get("vt", None)
    if names is not None and vt is not None and hasattr(vt, "get_support"):
        names = [n for n, keep in zip(names, vt.get_support()) if keep]

    if "svd" in pipe.named_steps:
        # SVD destroys original feature identity; report component labels.
        svd = pipe.named_steps["svd"]
        k = getattr(svd, "n_components", None) or (len(names) if names is not None else 0)
        return [f"SVD_component_{i+1}" for i in range(int(k))]

    # SelectFromModel mask.
    sel = pipe.named_steps.get("select", None)
    if names is not None and sel is not None and hasattr(sel, "get_support"):
        names = [n for n, keep in zip(names, sel.get_support()) if keep]

    if names is None:
        # Fallback sized to the classifier's input width (post-selection).
        names = [f"f{i}" for i in range(pipe.named_steps["clf"].coef_.shape[1])]

    return names
|
|
|
|
|
def transform_before_clf(pipe, X):
    """
    Push X through every pipeline step that precedes the 'clf' step.

    Guarantees SHAP receives exactly the matrix the classifier consumes.
    Steps without a `transform` method are skipped.
    """
    result = X
    for step_name, transformer in pipe.steps:
        if step_name == "clf":
            break
        if hasattr(transformer, "transform"):
            result = transformer.transform(result)
    return result
|
|
|
|
|
def get_final_feature_names(pipe) -> list[str]:
    """
    Return feature names aligned with the exact columns seen by the classifier.

    Handles the layout: preprocess -> vt -> (svd?) -> (select?) -> clf.
    When the preprocessor cannot report names, falls back to generic f0..fN
    labels sized from the classifier's coefficient matrix.

    Raises ValueError if the pipeline has no 'preprocess' step.
    """
    steps = pipe.named_steps
    pre = steps.get("preprocess", None)
    if pre is None:
        raise ValueError("Pipeline missing 'preprocess' step.")

    try:
        names = list(pre.get_feature_names_out())
    except Exception:
        names = None  # resolved via the classifier fallback below

    def _apply_mask(current, selector):
        # Filter `current` by a fitted selector's boolean support mask.
        return [label for label, keep in zip(current, selector.get_support()) if keep]

    vt = steps.get("vt", None)
    if names is not None and vt is not None and hasattr(vt, "get_support"):
        names = _apply_mask(names, vt)

    if "svd" in steps:
        # SVD destroys original feature identity; report component labels.
        k = int(getattr(steps["svd"], "n_components", 0) or 0)
        if k <= 0:
            k = 0
        return [f"SVD_component_{i+1}" for i in range(k)]

    sel = steps.get("select", None)
    if names is not None and sel is not None and hasattr(sel, "get_support"):
        names = _apply_mask(names, sel)

    if names is None:
        width = pipe.named_steps["clf"].coef_.shape[1]
        names = [f"f{i}" for i in range(width)]

    return names
|
|
def options_for(col: str, df: pd.DataFrame | None):
    """
    Build the dropdown option list for `col`.

    Result: a leading blank entry, then the curated DEFAULT_OPTIONS for that
    column (order preserved), then any additional non-empty stripped values
    observed in the uploaded data — de-duplicated.
    """
    curated = DEFAULT_OPTIONS.get(col, [])

    observed = []
    if df is not None and col in df.columns:
        cleaned = df[col].dropna().astype(str).map(lambda x: x.strip())
        observed = cleaned.loc[lambda s: s != ""].unique().tolist()

    deduped = []
    for value in curated + sorted(set(observed)):
        if value not in deduped:
            deduped.append(value)
    return [""] + deduped
|
|
|
|
|
import re |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Label used whenever a nationality cannot be mapped to a continent/region.
REGION_UNKNOWN = "Unknown"

# Maps upper-cased nationality/demonym strings (as they appear in the clinical
# Excel files) to canonical country names. Keys are matched after whitespace
# normalisation and upper-casing (see normalize_country_name).
NATIONALITY_ALIASES = {
    # Gulf / UAE variants
    "EMARATI": "United Arab Emirates",
    "EMIRATI": "United Arab Emirates",
    "UAE": "United Arab Emirates",
    # Middle East / Africa demonyms
    "FILIPINO": "Philippines",
    "PALESTINIAN": "Palestine",
    "SYRIAN": "Syria",
    "LEBANESE": "Lebanon",
    "JORDANIAN": "Jordan",
    "YEMENI": "Yemen",
    "EGYPTIAN": "Egypt",
    "SUDANESE": "Sudan",
    "ETHIOPIAN": "Ethiopia",
    "ERITREAN": "Eritrea",
    "SOMALI": "Somalia",
    "KENYAN": "Kenya",
    "UGANDAN": "Uganda",
    "GUINEAN": "Guinea",
    "MOROCCAN": "Morocco",
    "COMORAN": "Comoros",
    # South / Southeast Asia demonyms
    "INDIAN": "India",
    "PAKISTANI": "Pakistan",
    "BANGLADESH": "Bangladesh",
    "NEPALESE": "Nepal",
    "INDONESIAN": "Indonesia",
    "MALAYSIAN": "Malaysia",
    # United States variants
    "AMERICAN": "United States",
    "USA": "United States",
    "U.S.A": "United States",
    "UNITED STATES OF AMERICA": "United States",
}
|
|
|
|
|
def normalize_country_name(x: str) -> str | None:
    """Convert nationality-like strings to a canonical country name when possible.

    Returns None for missing/blank input. Known demonyms (e.g. "EMIRATI") are
    resolved via NATIONALITY_ALIASES; anything else is returned stripped as-is.
    """
    if x is None:
        return None

    text = str(x).strip()
    if not text:
        return None

    # Collapse internal whitespace and upper-case for the alias lookup.
    lookup_key = re.sub(r"\s+", " ", text).strip().upper()
    canonical = NATIONALITY_ALIASES.get(lookup_key)
    if canonical is not None:
        return canonical

    return text.strip()
|
|
|
|
|
from typing import Optional |
|
|
|
|
|
def country_to_region(country: str | None) -> str:
    """
    Map a country name to a coarse region using country_converter,
    imported lazily to keep app startup memory low.

    Returns one of: Africa, Americas, Asia, Europe, Oceania, Unknown.
    """
    if not country:
        return REGION_UNKNOWN

    import country_converter as coco

    continent = coco.convert(names=country, to="continent")
    if not continent or str(continent).lower() in ("not found", "nan", "none"):
        return REGION_UNKNOWN

    # country_converter reports "America"; the app uses the UN-style "Americas".
    if continent == "America":
        return "Americas"
    return str(continent)
|
|
|
|
|
|
|
|
|
|
|
def add_ethnicity_region(df: pd.DataFrame, eth_col: str = "Ethnicity", out_col: str = "Ethnicity_Region") -> pd.DataFrame:
    """
    Derive a coarse region column from a nationality/ethnicity column.

    Mutates `df` in place (adds `out_col`) and returns it. When the source
    column is absent, every row gets REGION_UNKNOWN.
    """
    if eth_col not in df.columns:
        df[out_col] = REGION_UNKNOWN
        return df

    countries = df[eth_col].map(normalize_country_name)
    df[out_col] = countries.map(country_to_region)
    return df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_admin() -> bool:
    """
    Admin gating for sensitive actions (Train + Publish).

    True only when the ADMIN_KEY Space secret is configured AND the key typed
    into the "Admin controls" expander matches it. Comparison uses
    hmac.compare_digest to avoid timing side channels.
    """
    expected = os.environ.get("ADMIN_KEY", "")
    if not expected:
        # No secret configured -> admin features stay locked.
        return False
    supplied = st.session_state.get("admin_key", "")
    return hmac.compare_digest(str(supplied), str(expected))
|
|
# ---- Streamlit page / global plot configuration ----
st.set_page_config(page_title="Explainable-Acute-Leukemia-Mortality-Predictor", layout="wide")

with st.sidebar:
    st.markdown("### Plot settings")
    # Figure geometry and on-screen resolution are user-tunable; export DPI
    # only affects the downloadable PNGs, not the rendered charts.
    plot_width = st.slider("Plot width (inches)", 4.0, 10.0, 5.5, 0.1)
    plot_height = st.slider("Plot height (inches)", 2.5, 6.0, 3.6, 0.1)
    plot_dpi_screen = st.slider("Screen DPI", 80, 200, 120, 10)
    export_dpi = st.selectbox("Export DPI (PNG)", [300, 600, 900, 1200], index=1)

# Shared figure size for every chart rendered below.
FIGSIZE = (plot_width, plot_height)

st.title("Explainable-Acute-Leukemia-Mortality-Predictor")
st.caption("Explainable clinical AI for mortality and outcome prediction in acute leukemia using SHAP-interpretable models")
|
|
|
|
|
# ---- Intro / usage guidance ----
with st.expander("About this AI model, who can use it, and required Excel format", expanded=True):
    st.markdown("""
### Quick Start (Single Patient)

1. Open **Predict + SHAP**
2. **Load model** → select *latest* (default)
3. Enter patient details (Core → Clinical → FISH → NGS)
4. Click **Predict single patient**

You will get:
• Mortality (os) probability
• Risk category
• 6-month, 1-year, 3-year survival with PLOT
• SHAP explanation

Edit values to instantly test different scenarios.

*For research/decision-support only — not a substitute for clinical judgment.*
""")

# Schema mismatch is the most common failure mode at inference time.
st.warning(
    "Prediction will fail if feature names or variable types "
    "do not exactly match the trained model schema."
)

# ---- Admin authentication (unlocks Train + Publish; see is_admin) ----
with st.expander("Admin controls", expanded=False):
    st.text_input("Admin key", type="password", key="admin_key")
    st.caption("Training and publishing are enabled only after admin authentication.")
|
|
|
|
|
|
|
|
|
|
|
# Two-tab layout: training (admin only) and inference.
tab_train, tab_predict = st.tabs(["1️⃣ Train", "2️⃣ Predict + SHAP"])

# Persist the fitted pipeline and its SHAP explainer across Streamlit reruns.
if "pipe" not in st.session_state:
    st.session_state.pipe = None
if "explainer" not in st.session_state:
    st.session_state.explainer = None
|
|
|
|
|
|
|
|
|
|
|
with tab_train:
    st.subheader("Train model")

    # Training is admin-gated; everyone else is pointed at the Predict tab.
    if not is_admin():
        st.info("Training and publishing are restricted. Use Predict + SHAP for inference.")
    else:
        st.markdown("### Feature reduction options")

        use_feature_selection = st.checkbox(
            "Drop columns that do not affect prediction (L1 feature selection)",
            value=True,
            key="train_use_feature_selection"
        )
        # L1 strength slider is only shown when selection is enabled.
        l1_C = st.slider(
            "L1 selection strength (lower = fewer features)",
            0.01, 10.0, 1.0, 0.01
        ) if use_feature_selection else 1.0

        use_dimred = st.checkbox(
            "Dimensionality reduction (TruncatedSVD) — reduces interpretability",
            value=False
        )

        svd_components = st.slider(
            "SVD components (only used if enabled)",
            5, 300, 50, 5
        ) if use_dimred else 50

        st.divider()

        st.markdown("### Bootstrap internal validation (optional)")

        do_bootstrap = st.checkbox(
            "Enable bootstrapping (OOB internal validation)",
            value=False,
            key="train_bootstrap_enable"
        )

        n_boot = st.slider(
            "Bootstrap iterations",
            50, 2000, 1000, 50,
            key="train_bootstrap_n"
        ) if do_bootstrap else 0

        boot_thr = st.slider(
            "Bootstrap classification threshold",
            0.0, 1.0, 0.5, 0.01,
            key="train_bootstrap_thr"
        ) if do_bootstrap else 0.5

        train_file = st.file_uploader("Upload training Excel (.xlsx)", type=["xlsx"])

        if train_file is None:
            st.info("Upload a training Excel file to enable training.")
        else:
            df = pd.read_excel(train_file, engine="openpyxl")
            # Column names are stripped so schema matching is whitespace-tolerant.
            df.columns = [c.strip() for c in df.columns]
            feature_cols = get_feature_cols_from_df(df)

            st.dataframe(df.head(), use_container_width=True)
            feature_cols = get_feature_cols_from_df(df)

            st.markdown("### Choose variable types (saved into the model)")
            # Heuristic default: the first 13 columns are assumed numeric;
            # the admin can override in the multiselect below.
            default_numeric = feature_cols[:13]
            num_cols = st.multiselect(
                "Numeric variables (will be median-imputed + scaled)",
                options=feature_cols,
                default=default_numeric
            )

            # Everything not chosen as numeric is treated as categorical.
            cat_cols = [c for c in feature_cols if c not in num_cols]

            st.write(f"Categorical variables (will be most-frequent-imputed + one-hot): {len(cat_cols)}")
            st.caption("Note: The selected schema is stored with the trained model and must match inference files.")

            st.markdown("### Evaluation settings")
            n_bins = st.slider("Calibration bins", 5, 20, 10, 1)
            cal_strategy = st.selectbox("Calibration binning strategy", ["uniform", "quantile"], index=0)

            dca_points = st.slider("Decision curve points", 25, 200, 99, 1)

            if st.button("Train model"):
                with st.spinner("Training model..."):
                    pipe, meta, X_train, y_test, proba = train_and_save(
                        df, feature_cols, num_cols, cat_cols,
                        n_bins=n_bins, cal_strategy=cal_strategy, dca_points=dca_points,
                        use_feature_selection=use_feature_selection,
                        l1_C=l1_C,
                        use_dimred=use_dimred,
                        svd_components=svd_components
                    )

                # Best-effort: a failed background save must not abort training.
                try:
                    save_background_sample_csv(
                        X_bg=X_train,
                        feature_cols=feature_cols,
                        max_rows=200,
                        out_path="background.csv"
                    )
                    st.success("Saved SHAP background sample (background.csv).")
                except Exception as e:
                    st.warning(f"Could not save SHAP background sample: {e}")
                st.write("Local survival bundle exists:", os.path.exists("survival_bundle.joblib"))
                if os.path.exists("survival_bundle.joblib"):
                    st.write("Local survival bundle size (bytes):", os.path.getsize("survival_bundle.joblib"))

                explainer = build_shap_explainer(pipe, X_train)

                # Keep artifacts in session so the Publish section and the
                # Predict tab can reuse them without retraining.
                st.session_state.pipe = pipe
                st.session_state.explainer = explainer
                st.session_state.meta = meta

                st.success("Training complete. model.joblib and meta.json created.")

                st.divider()
                st.subheader("Training performance (test split)")

                m = meta["metrics"]

                # Headline metrics at the F1-optimal threshold.
                c1, c2, c3, c4 = st.columns(4)
                c1.metric("ROC AUC", f"{m['roc_auc']:.3f}")
                c2.metric("Sensitivity (best F1 thr)", f"{m['sensitivity@best']:.3f}")
                c3.metric("Specificity (best F1 thr)", f"{m['specificity@best']:.3f}")
                c4.metric("F1 (best)", f"{m['f1@best']:.3f}")
                st.caption(f"Best threshold (max F1): {m['best_threshold']:.2f}")

                # Secondary metrics at the fixed 0.5 threshold.
                c5, c6, c7, c8 = st.columns(4)
                c5.metric("Precision", f"{m['precision@0.5']:.3f}")
                c6.metric("Accuracy", f"{m['accuracy@0.5']:.3f}")
                c7.metric("Balanced Acc", f"{m['balanced_accuracy@0.5']:.3f}")
                c8.metric("Test N", m["n_test"])

                if do_bootstrap:
                    st.divider()
                    st.subheader(f"Bootstrap internal validation (n={n_boot}, OOB)")

                    with st.spinner("Running bootstrap..."):
                        df_boot, boot_summary, n_skipped = bootstrap_internal_validation(
                            df=df,
                            feature_cols=feature_cols,
                            num_cols=num_cols,
                            cat_cols=cat_cols,
                            n_boot=int(n_boot),
                            threshold=float(boot_thr),
                            use_feature_selection=use_feature_selection,
                            l1_C=l1_C,
                            use_dimred=use_dimred,
                            svd_components=svd_components,
                            random_state=42,
                        )

                    st.caption(f"Completed: {len(df_boot)} iterations. Skipped: {n_skipped} (empty/single-class OOB or fitting issues).")

                    st.json(boot_summary)

                    st.dataframe(df_boot.head(30), use_container_width=True)

                    st.download_button(
                        "Download bootstrap results (CSV)",
                        df_boot.to_csv(index=False).encode("utf-8"),
                        file_name=f"bootstrap_oob_{n_boot}.csv",
                        mime="text/csv",
                        key="dl_bootstrap_csv"
                    )

                    # Fold bootstrap results into meta.json so they are
                    # published alongside the model.
                    st.session_state.meta.setdefault("bootstrap", {})
                    st.session_state.meta["bootstrap"] = {
                        "method": "bootstrap_oob",
                        "n_boot": int(n_boot),
                        "threshold": float(boot_thr),
                        "skipped": int(n_skipped),
                        "summary": boot_summary,
                    }

                    meta = st.session_state.meta
                    with open("meta.json", "w", encoding="utf-8") as f:
                        json.dump(meta, f, indent=2)

                # Confusion matrix at the fixed 0.5 threshold.
                cm = m["confusion_matrix@0.5"]
                cm_df = pd.DataFrame(
                    [[cm["tn"], cm["fp"]], [cm["fn"], cm["tp"]]],
                    index=["Actual 0", "Actual 1"],
                    columns=["Pred 0", "Pred 1"]
                )
                st.markdown("**Confusion Matrix (threshold = 0.5)**")
                st.dataframe(cm_df)

                st.json(meta["survival"])

                # ROC curve.
                roc = m["roc_curve"]
                fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
                ax.plot(roc["fpr"], roc["tpr"])
                ax.plot([0, 1], [0, 1])
                ax.set_xlabel("False Positive Rate (1 - Specificity)")
                ax.set_ylabel("True Positive Rate (Sensitivity)")
                ax.set_title(f"ROC Curve (AUC = {m['roc_auc']:.3f})")

                render_plot_with_download(
                    fig,
                    title="ROC curve",
                    filename="roc_curve.png",
                    export_dpi=export_dpi,
                    key="dl_train_roc"
                )

                # Precision-recall curve.
                pr = m["pr_curve"]
                fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
                ax.plot(pr["recall"], pr["precision"])
                ax.set_xlabel("Recall")
                ax.set_ylabel("Precision")
                ax.set_title(f"PR Curve (AP = {pr['average_precision']:.3f})")

                render_plot_with_download(
                    fig,
                    title="PR curve",
                    filename="pr_curve.png",
                    export_dpi=export_dpi,
                    key="dl_train_pr"
                )

                # Calibration curve (reliability diagram).
                cal = m["calibration"]
                fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
                ax.plot(cal["prob_pred"], cal["prob_true"])
                ax.plot([0, 1], [0, 1])
                ax.set_xlabel("Mean predicted probability")
                ax.set_ylabel("Observed event rate")
                ax.set_title("Calibration curve")

                render_plot_with_download(
                    fig,
                    title="Calibration curve",
                    filename="calibration_curve.png",
                    export_dpi=export_dpi,
                    key="dl_train_cal"
                )

                # Decision curve analysis (net benefit vs threshold).
                dca = m["decision_curve"]
                fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
                ax.plot(dca["thresholds"], dca["net_benefit_model"], label="Model")
                ax.plot(dca["thresholds"], dca["net_benefit_all"], label="Treat all")
                ax.plot(dca["thresholds"], dca["net_benefit_none"], label="Treat none")
                ax.set_xlabel("Threshold probability")
                ax.set_ylabel("Net benefit")
                ax.set_title("Decision curve analysis")
                ax.legend()

                render_plot_with_download(
                    fig,
                    title="Decision curve",
                    filename="decision_curve.png",
                    export_dpi=export_dpi,
                    key="dl_train_dca"
                )

                st.caption(
                    "If the model curve is above Treat-all and Treat-none across a threshold range, "
                    "the model provides net clinical benefit in that range."
                )

                st.divider()
                st.subheader("Threshold analysis")

                thr = st.slider("Decision threshold", 0.0, 1.0, 0.5, 0.01)

                # Stash the test split so the threshold metrics can be
                # recomputed from session state.
                st.session_state.y_test_last = y_test
                st.session_state.proba_last = proba
                if "y_test_last" in st.session_state and "proba_last" in st.session_state:
                    cls = compute_classification_metrics(st.session_state.y_test_last, st.session_state.proba_last, threshold=thr)
                    st.write({
                        "Sensitivity": cls["sensitivity"],
                        "Specificity": cls["specificity"],
                        "Precision": cls["precision"],
                        "Recall": cls["recall"],
                        "F1": cls["f1"],
                        "Accuracy": cls["accuracy"],
                        "Balanced Accuracy": cls["balanced_accuracy"],
                    })

        # Publish section: available whenever a trained pipeline is in session.
        if st.session_state.get("pipe") is not None:
            st.divider()
            st.subheader("Publish trained model to Hugging Face Hub")

            # UTC timestamp makes version tags naturally sortable.
            default_version = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
            version_tag = st.text_input(
                "Version tag",
                value=default_version,
                help="Used as releases/<version>/ in the model repository",
            )

            if st.button("Publish model.joblib + meta.json to Model Repo", key="publish_btn"):
                try:
                    with st.spinner("Uploading to Hugging Face Model repo..."):
                        paths = publish_to_hub(MODEL_REPO_ID, version_tag)

                        # SHAP background is optional but recommended.
                        if os.path.exists("background.csv"):
                            bg_paths = publish_background_to_hub(MODEL_REPO_ID, version_tag, background_path="background.csv")
                            paths.update(bg_paths)
                        else:
                            st.warning("background.csv not found; SHAP background will not be uploaded.")

                    st.success("Uploaded successfully to your model repository.")
                    st.json(paths)
                except Exception as e:
                    st.error(f"Upload failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fixed calibration settings used on the Predict tab (not user-tunable there).
PRED_N_BINS = 10
PRED_CAL_STRATEGY = "uniform"
|
|
|
|
|
|
|
|
|
|
|
with tab_predict:
    st.subheader("Select a trained model (no retraining required)")

    MODEL_REPO_ID = "Synav/Explainable-Acute-Leukemia-Mortality-Predictor"

    # Defensive re-initialisation: this tab may run before the Train tab code.
    if "pipe" not in st.session_state:
        st.session_state.pipe = None
    if "meta" not in st.session_state:
        st.session_state.meta = None
    if "explainer" not in st.session_state:
        st.session_state.explainer = None

    # Enumerate published versions; fall back to "latest" only on failure.
    try:
        versions = list_release_versions(MODEL_REPO_ID)
    except Exception as e:
        versions = []
        st.error(f"Could not list model versions: {e}")

    choices = ["latest"] + versions if versions else ["latest"]
    selected = st.selectbox("Choose model version", choices, index=0)

    if st.button("Load selected model"):
        try:
            with st.spinner("Loading model from Hugging Face Hub..."):
                if selected == "latest":
                    pipe, meta = load_latest_model(MODEL_REPO_ID)
                else:
                    pipe, meta = load_model_by_version(MODEL_REPO_ID, selected)

                st.session_state.pipe = pipe
                st.session_state.meta = meta
                # Explainer is invalidated and rebuilt lazily for the new model.
                st.session_state.explainer = None
            st.success(f"Loaded model: {selected}")
        except Exception as e:
            st.error(f"Load failed: {e}")

    st.divider()
    if st.session_state.get("pipe") is None or st.session_state.get("meta") is None:
        st.warning("Load a model version above (it must include meta.json), then continue.")
        st.stop()  # nothing below can run without a model + its schema

    pipe = st.session_state.pipe
    meta = st.session_state.meta

    # Survival bundle is optional; mortality prediction still works without it.
    try:
        if selected == "latest":
            st.session_state.surv_model = load_latest_survival_model(MODEL_REPO_ID)
        else:
            st.session_state.surv_model = load_survival_model_by_version(MODEL_REPO_ID, selected)
    except Exception as e:
        st.session_state.surv_model = None
        st.warning(f"Survival bundle load failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DEBUG_SURV = False |
|
|
|
|
|
if DEBUG_SURV: |
|
|
st.write("Survival bundle loaded:", isinstance(bundle, dict)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Schema saved at training time — inference inputs must match it exactly.
    feature_cols = meta["schema"]["features"]
    num_cols = meta["schema"]["numeric"]
    cat_cols = meta["schema"]["categorical"]

    df_inf = st.session_state.get("df_inf")

    if df_inf is not None:
        # Prefer the uploaded inference file as the SHAP background.
        X_bg = coerce_X_like_schema(df_inf, feature_cols, num_cols, cat_cols)
    else:
        # Otherwise fall back to the background published with the model
        # (may be None if unavailable).
        X_bg = get_shap_background_auto(MODEL_REPO_ID, feature_cols, num_cols, cat_cols)

    st.session_state.X_bg_for_shap = X_bg

    # Normalised-name -> canonical training column lookup for tolerant matching.
    FEATURE_LOOKUP = {norm_col(c): c for c in feature_cols}
|
|
|
|
|
|
|
|
|
|
|
from datetime import date

# Earliest accepted date of birth for the DOB picker.
MIN_DOB = date(1900, 1, 1)

# Country list used as "Ethnicity" dropdown options, with catch-all entries
# ("Other", "Unknown") appended at the end.
ALL_COUNTRIES = [
    "Afghanistan","Albania","Algeria","Andorra","Angola","Antigua and Barbuda",
    "Argentina","Armenia","Australia","Austria","Azerbaijan","Bahamas","Bahrain",
    "Bangladesh","Barbados","Belarus","Belgium","Belize","Benin","Bhutan","Bolivia",
    "Bosnia and Herzegovina","Botswana","Brazil","Brunei","Bulgaria","Burkina Faso",
    "Burundi","Cabo Verde","Cambodia","Cameroon","Canada","Central African Republic",
    "Chad","Chile","China","Colombia","Comoros","Congo (Congo-Brazzaville)",
    "Costa Rica","Cote d’Ivoire","Croatia","Cuba","Cyprus","Czechia",
    "Democratic Republic of the Congo","Denmark","Djibouti","Dominica",
    "Dominican Republic","Ecuador","Egypt","El Salvador","Equatorial Guinea",
    "Eritrea","Estonia","Eswatini","Ethiopia","Fiji","Finland","France","Gabon",
    "Gambia","Georgia","Germany","Ghana","Greece","Grenada","Guatemala","Guinea",
    "Guinea-Bissau","Guyana","Haiti","Honduras","Hungary","Iceland","India",
    "Indonesia","Iran","Iraq","Ireland","Israel","Italy","Jamaica","Japan","Jordan",
    "Kazakhstan","Kenya","Kiribati","Kuwait","Kyrgyzstan","Laos","Latvia","Lebanon",
    "Lesotho","Liberia","Libya","Liechtenstein","Lithuania","Luxembourg",
    "Madagascar","Malawi","Malaysia","Maldives","Mali","Malta","Marshall Islands",
    "Mauritania","Mauritius","Mexico","Micronesia","Moldova","Monaco","Mongolia",
    "Montenegro","Morocco","Mozambique","Myanmar","Namibia","Nauru","Nepal",
    "Netherlands","New Zealand","Nicaragua","Niger","Nigeria","North Korea",
    "North Macedonia","Norway","Oman","Pakistan","Palau","Palestine","Panama",
    "Papua New Guinea","Paraguay","Peru","Philippines","Poland","Portugal","Qatar",
    "Romania","Russia","Rwanda","Saint Kitts and Nevis","Saint Lucia",
    "Saint Vincent and the Grenadines","Samoa","San Marino","Sao Tome and Principe",
    "Saudi Arabia","Senegal","Serbia","Seychelles","Sierra Leone","Singapore",
    "Slovakia","Slovenia","Solomon Islands","Somalia","South Africa","South Korea",
    "South Sudan","Spain","Sri Lanka","Sudan","Suriname","Sweden","Switzerland",
    "Syria","Taiwan","Tajikistan","Tanzania","Thailand","Timor-Leste","Togo","Tonga",
    "Trinidad and Tobago","Tunisia","Turkey","Turkmenistan","Tuvalu","Uganda",
    "Ukraine","United Arab Emirates","United Kingdom","United States","Uruguay",
    "Uzbekistan","Vanuatu","Vatican City","Venezuela","Vietnam","Yemen","Zambia",
    "Zimbabwe",
    "Other","Unknown"
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fallback dropdown options per field, used when the uploaded inference file
# does not supply observed values for a column. Keys are the canonical UI
# field names; alias_default_options() later mirrors entries under the exact
# schema column names. NOTE: "Risk Assesment" spelling matches the schema.
DEFAULT_OPTIONS = {
    "Gender": ["Male", "Female"],

    "Ethnicity": ALL_COUNTRIES,

    "Type of Leukemia": [
        "ALL",
        "AML",
        "APL",
        "CML",
        "MPAL",
        "Secondary AML",
        "Other",
        "Unknown"
    ],

    "Risk Assesment": [
        "Favorable",
        "Intermediate",
        "Adverse",
        "Unknown"
    ],

    # ECOG performance status, integer scale 0 (fully active) to 4.
    "ECOG": [0, 1, 2, 3, 4],

    "MDx Leukemia Screen 1": [
        "Negative",
        "t(15;17) / PML-RARA",
        "t(9;22) / BCR-ABL1",
        "RUNX1-RUNX1T1",
        "inv(16) / CBFB-MYH11",
        "KMT2A-AFF1",
        "t(1;19) / TCF3-PBX1",
        "t(10;11)",
        "Other",
        "Not done",
        "Unknown"
    ],

    "Post-Induction MRD": [
        "Negative",
        "Positive",
        "Pending",
        "Not done",
        "Unknown"
    ],

    "First_Induction": [
        "3+7",
        "3+7+GO",
        "3+7+TKI",
        "3+7+Midostaurin",
        "ATO+ATRA",
        "Aza+Ven",
        "CALGB",
        "CPX-351",
        "Dexa+TKI",
        "FLAG-IDA",
        "HCVAD",
        "HCVAD+TKI",
        "HMA",
        "Hyper-CVAD",
        "Not fit",
        "Pediatric protocol",
        "Other",
        "Unknown"
    ],

    # Binary clinical flags rendered as Yes/No selectboxes elsewhere.
    "Thrombocytopenia": ["No", "Yes"],
    "Bleeding": ["No", "Yes"],
    "B Symptoms": ["No", "Yes"],
    "Lymphadenopathy": ["No", "Yes"],
    "CNS Involvement": ["No", "Yes"],
    "Extramedullary Involvement": ["No", "Yes"],
}
|
|
|
|
|
|
|
|
# Canonical names of fields that should be rendered as dropdowns.
CANON_DROPDOWNS = [
    "Ethnicity",
    "Type of Leukemia",
    "Post-Induction MRD",
    "First_Induction",
    "Risk Assesment",
    "MDx Leukemia Screen 1",
    "ECOG",
    "Gender",
]

# Resolve each canonical dropdown name to the exact column name used by the
# loaded model schema (via the normalized-name lookup); names absent from
# the schema are silently skipped.
DROPDOWN_FIELDS_ACTUAL = []
for canon in CANON_DROPDOWNS:
    key = norm_col(canon)
    if key in FEATURE_LOOKUP:
        DROPDOWN_FIELDS_ACTUAL.append(FEATURE_LOOKUP[key])
|
|
def alias_default_options(canon_name: str):
    """Mirror DEFAULT_OPTIONS[canon_name] under the schema's actual column name.

    If the canonical field name resolves (via normalized lookup) to a schema
    column whose exact spelling differs, copy the option list to that exact
    key so option lookups by schema name succeed. No-op when the field is not
    in the schema, has no defaults, or the actual key already exists.
    """
    actual = FEATURE_LOOKUP.get(norm_col(canon_name))
    if actual is None:
        return
    if canon_name not in DEFAULT_OPTIONS or actual in DEFAULT_OPTIONS:
        return
    DEFAULT_OPTIONS[actual] = DEFAULT_OPTIONS[canon_name]
|
|
|
|
|
# Propagate default options to the schema's exact column spellings for every
# dropdown field whose canonical name may differ from the schema name.
alias_default_options("Ethnicity")
alias_default_options("Type of Leukemia")
alias_default_options("Post-Induction MRD")
alias_default_options("First_Induction")
alias_default_options("Risk Assesment")
alias_default_options("MDx Leukemia Screen 1")
|
|
|
|
|
|
|
|
|
|
|
# Dropdown fields by canonical name (filtered later against tab-managed
# fields before rendering).
DROPDOWN_FIELDS = [
    "Gender",
    "Ethnicity",
    "Type of Leukemia",
    "ECOG",
    "Risk Assesment",
    "MDx Leukemia Screen 1",
    "Post-Induction MRD",
    "First_Induction",
]

# Binary clinical flags rendered in the "Clinical (Yes/No)" tab.
YESNO_FIELDS = [
    "Thrombocytopenia",
    "Bleeding",
    "B Symptoms",
    "Lymphadenopathy",
    "CNS Involvement",
    "Extramedullary Involvement",
]

# Display units appended to numeric-input labels (UNITS.get(feature_name)).
# NOTE(review): the WBC key deliberately contains a literal "\n" escape and a
# trailing space; uploaded-file columns are whitespace-normalized elsewhere,
# so this key only matches if the model schema itself keeps the raw spelling
# — confirm against meta.json.
UNITS = {
    'WBCs on Admission\n ( x10^9/L) ': "x10^9/L",
    "Hb on Admission (g/L)": "g/L",
    "LDH on Admission (IU/L)": "IU/L",
    "Pre-Induction bone marrow biopsy blasts %": "%",
    "Age (years)": "years",
}
|
|
|
|
|
|
|
|
# Binary FISH alteration columns (1 = present, 0 = absent) managed by the
# FISH tab's multiselect rather than individual widgets.
FISH_MARKERS = [
    "fish_inv16/cbfb_myh11",
    "fish_t8;21/runx1_runx1t1",
    "fish_t15;17/PML-RARA",
    "fish_KMT2A-MLL/11q23",
    "fish_BCR-ABL1/t9;22",
    "fish_ETV6-RUNX1/t12;21",
    "fish_TCF3_PBX1/t1;19",
    "fish_IGH/14q32rearr",
    "fish_CRLF2 rearr",
    "fish_NUP214/9q34",
    "fish_Other",
    "fish_del7q/monosomy7",
    "fish_TP53/del17p",
    "fish_del13q",
    "fish_del11q",
]
# Derived feature: number of selected FISH alterations.
FISH_COUNT_COL = "Number of FISH alteration"

# Binary NGS mutation columns managed by the NGS tab's multiselect.
NGS_MARKERS = [
    "ngs_FLT3",
    "ngs_NPM1",
    "ngs_CEBPA",
    "ngs_DNMT3A",
    "ngs_IDH1",
    "ngs_IDH2",
    "ngs_TET2",
    "ngs_TP53",
    "ngs_RUNX1",
    "ngs_Other",
    "ngs_spliceosome",
]
# Derived feature: number of selected NGS mutations.
NGS_COUNT_COL = "No. of ngs_mutation"

# Exact schema names of the age and diagnosis-date features, which are
# derived from the DOB/Dx-date pickers rather than entered directly.
AGE_FEATURE = "Age (years)"
DX_DATE_FEATURE = "Date of 1st Bone Marrow biopsy (Date of Diagnosis)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_options_from_df(df: pd.DataFrame, col: str) -> list[str]:
    """Dropdown options from uploaded Excel (inference file).

    Returns the sorted, de-duplicated, non-blank string values observed in
    *col*, or an empty list when *df* is missing or lacks the column.
    """
    if df is None or col not in df.columns:
        return []
    cleaned = df[col].dropna().astype(str).str.strip()
    cleaned = cleaned[cleaned != ""]
    return sorted(cleaned.unique().tolist())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Re-bind the loaded artifacts (redundant with the earlier binding, kept as-is).
pipe = st.session_state.pipe
meta = st.session_state.meta

# Rebuild the normalized-name lookup (same mapping as built earlier).
FEATURE_LOOKUP = {norm_col(c): c for c in feature_cols}

# Features whose widgets live in a dedicated tab (or are derived), so the
# generic per-feature loop in the Core tab must skip them.
TAB_MANAGED_FIELDS = set()

# Derived from DOB / date pickers.
if AGE_FEATURE in feature_cols:
    TAB_MANAGED_FIELDS.add(AGE_FEATURE)
if DX_DATE_FEATURE in feature_cols:
    TAB_MANAGED_FIELDS.add(DX_DATE_FEATURE)
if "Date of 1st CR" in feature_cols:
    TAB_MANAGED_FIELDS.add("Date of 1st CR")

# Rendered in the Clinical (Yes/No) tab.
TAB_MANAGED_FIELDS.update([f for f in YESNO_FIELDS if f in feature_cols])

# Rendered as multiselects in the FISH and NGS tabs.
TAB_MANAGED_FIELDS.update([f for f in FISH_MARKERS if f in feature_cols])
TAB_MANAGED_FIELDS.update([f for f in NGS_MARKERS if f in feature_cols])

# Remove tab-managed fields from the generic dropdown list.
DROPDOWN_FIELDS = [f for f in DROPDOWN_FIELDS if f not in TAB_MANAGED_FIELDS]

# Counts are derived from the multiselect lengths, never entered by hand.
if FISH_COUNT_COL in feature_cols:
    TAB_MANAGED_FIELDS.add(FISH_COUNT_COL)
if NGS_COUNT_COL in feature_cols:
    TAB_MANAGED_FIELDS.add(NGS_COUNT_COL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Optional inference Excel upload ---------------------------------------
# Used both to populate dropdown options with observed values and for batch
# prediction / external validation further down the page.
infer_file = st.file_uploader("Upload inference Excel (.xlsx) (optional, for dropdown options + batch prediction)", type=["xlsx"], key="infer_xlsx")

if infer_file:
    st.session_state.df_inf = pd.read_excel(infer_file, engine="openpyxl")

    # Normalize column names: replace non-breaking spaces and collapse all
    # runs of whitespace (incl. newlines) to single spaces.
    st.session_state.df_inf.columns = [
        " ".join(str(c).replace("\u00A0", " ").split())
        for c in st.session_state.df_inf.columns
    ]

    # Reorder/subset columns to match the model's training schema.
    st.session_state.df_inf = align_columns_to_schema(
        st.session_state.df_inf,
        meta["schema"]["features"])

    # Derive a coarse geographic region from the Ethnicity column.
    st.session_state.df_inf = add_ethnicity_region(
        st.session_state.df_inf,
        "Ethnicity",
        "Ethnicity_Region"
    )
else:
    # No upload this rerun: clear any previously uploaded frame.
    st.session_state.df_inf = None

# Frame used to populate dropdown options (None -> fall back to defaults).
df_for_options = st.session_state.df_inf
|
|
|
|
|
st.divider()
st.subheader("Enter patient details to get prediction (Example: DOB → Age, Dx date → prediction)")

# Re-import kept as in original (already imported above; harmless).
from datetime import date

# Earliest selectable date for all date pickers below.
MIN_DATE = date(1900, 1, 1)

# DOB and diagnosis-date pickers side by side; each has an "unknown"
# checkbox that suppresses the picker and leaves the value as None.
c1, c2 = st.columns(2)
with c1:
    dob_unknown = st.checkbox("DOB unknown", value=False, key="dob_unknown_chk")
    dob = None
    if not dob_unknown:
        dob = st.date_input("Date of birth (DOB)", min_value=MIN_DATE, max_value=date.today(), key="dob_date")

with c2:
    dx_unknown = st.checkbox("Diagnosis date unknown", value=False, key="dx_unknown_chk")
    dx_date = None
    if not dx_unknown:
        dx_date = st.date_input("Date of 1st Bone Marrow biopsy (Diagnosis)", min_value=MIN_DATE, max_value=date.today(), key="dx_date")

# Date of first complete remission — only shown if the schema uses it.
cr1_date = None
if "Date of 1st CR" in feature_cols:
    cr1_unknown = st.checkbox("1st CR date unknown", value=False, key="cr1_unknown")
    if not cr1_unknown:
        cr1_date = st.date_input("Date of 1st CR", min_value=MIN_DATE, max_value=date.today(), key="cr1_date")

# Age at diagnosis derived from DOB and Dx date (NaN when either is None —
# assumed behavior of age_years_at; defined elsewhere in the file).
derived_age = age_years_at(dob, dx_date)
|
|
|
|
|
|
|
|
|
|
|
def yesno_to_01(v: str):
    """Map a Yes/No widget value to 1/0; any other value means missing (NaN)."""
    return {"Yes": 1, "No": 0}.get(v, np.nan)
|
|
|
|
|
|
|
|
# Four input tabs: general fields, binary clinical flags, FISH, NGS.
tab_core, tab_clin, tab_fish, tab_ngs = st.tabs(["Core", "Clinical (Yes/No)", "FISH", "NGS"])

# One slot per schema feature, filled in by the widgets below; NaN = missing.
values_by_index = [np.nan] * len(feature_cols)

with tab_fish:
    st.caption("Select one or many FISH alterations. Selected = 1, not selected = 0.")
    fish_selected = st.multiselect(
        "FISH alterations",
        options=[m for m in FISH_MARKERS if m in feature_cols],
        default=[],
        key="sp_fish_selected",
    )

with tab_ngs:
    st.caption("Select one or many NGS mutations. Selected = 1, not selected = 0.")
    ngs_selected = st.multiselect(
        "NGS mutations",
        options=[m for m in NGS_MARKERS if m in feature_cols],
        default=[],
        key="sp_ngs_selected",
    )

with tab_clin:
    st.caption("Clinical flags: Yes=1, No=0")
    # Enumerate the full schema so each widget key (sp_{i}_yn) stays tied to
    # the feature's position in values_by_index.
    for i, f in enumerate(feature_cols):
        if f in YESNO_FIELDS:
            v = st.selectbox(f, options=["", "No", "Yes"], index=0, key=f"sp_{i}_yn")
            values_by_index[i] = yesno_to_01(v)
|
|
|
|
|
with tab_core:
    # ECOG gets a dedicated widget; its value is copied into values_by_index
    # when the loop below reaches the ECOG feature.
    ecog = st.selectbox("ECOG", options=[0, 1, 2, 3, 4], index=0, key="sp_ecog")

    # Render one input per remaining schema feature, dispatching on field
    # type. Widget keys embed the feature index i so keys stay stable.
    for i, f in enumerate(feature_cols):

        # Handled by a dedicated tab / derived elsewhere.
        if f in TAB_MANAGED_FIELDS:
            continue

        # Age: auto-derived from DOB & Dx date when both known, otherwise a
        # manual (nullable) integer input.
        if f.strip() == AGE_FEATURE.strip():
            if dob_unknown or dx_unknown or np.isnan(derived_age):
                v_age = st.number_input(
                    f"{f} (enter if DOB/Dx unknown)",
                    value=None,
                    step=1,
                    format="%d",
                    key=f"sp_{i}_age_manual",
                )
                values_by_index[i] = v_age
            else:
                # Read-only display of the derived age; stored as float.
                st.number_input(
                    f"{f} (auto from DOB & Dx date)",
                    value=int(round(derived_age)),
                    step=1,
                    format="%d",
                    key=f"sp_{i}_age_auto",
                    disabled=True,
                )
                values_by_index[i] = float(derived_age)
            continue

        # Diagnosis date: taken from the picker, shown read-only (ISO format).
        if f.strip() == DX_DATE_FEATURE.strip():
            values_by_index[i] = np.nan if dx_date is None else dx_date.isoformat()
            st.text_input(
                f"{f} (auto)",
                value="" if dx_date is None else dx_date.isoformat(),
                disabled=True,
                key=f"sp_{i}_dx_show"
            )
            continue

        # Date of first CR: same pattern as the diagnosis date.
        if f.strip() == "Date of 1st CR".strip():
            values_by_index[i] = np.nan if cr1_date is None else cr1_date.isoformat()
            st.text_input(
                f"{f} (auto)",
                value="" if cr1_date is None else cr1_date.isoformat(),
                disabled=True,
                key=f"sp_{i}_cr_show"
            )
            continue

        # ECOG: copy the dedicated widget's value (as int).
        if f.strip() == "ECOG":
            values_by_index[i] = int(ecog)
            continue

        # Gender: fixed options with "" meaning missing.
        if f == "Gender":
            v = st.selectbox("Gender", options=["", "Male", "Female"], index=0, key=f"sp_{i}_gender")
            values_by_index[i] = np.nan if v == "" else v
            continue

        # Other dropdown fields: options sourced from the uploaded file when
        # available (options_for), falling back to defaults.
        if f in DROPDOWN_FIELDS_ACTUAL and f not in ("Gender", "ECOG"):
            opts = options_for(f, df_for_options)
            v = st.selectbox(f, options=opts, index=0, key=f"sp_{i}_dd")
            values_by_index[i] = np.nan if v == "" else v
            continue

        # Numeric features: nullable number input, label annotated with units.
        if f in num_cols:
            unit = UNITS.get(f, None)
            label = f"{f} ({unit})" if unit else f
            if f == "Age (years)":
                # Integer input for age; other numerics use 2 decimal places.
                v = st.number_input(label, value=None, step=1, format="%d", key=f"sp_{i}_num")
            else:
                v = st.number_input(label, value=None, format="%.2f", key=f"sp_{i}_num")
            values_by_index[i] = v
            continue

        # Remaining categorical features: free text, blank -> missing.
        if f in cat_cols:
            v = st.text_input(f, value="", key=f"sp_{i}_cat")
            values_by_index[i] = np.nan if v.strip() == "" else v
            continue

        # Fallback for features not classified as numeric or categorical.
        v = st.text_input(f, value="", key=f"sp_{i}_other")
        values_by_index[i] = np.nan if v.strip() == "" else v
|
|
|
|
|
|
|
|
|
|
|
# Translate the FISH/NGS multiselect choices into per-marker 0/1 flags.
fish_set = set(fish_selected)
ngs_set = set(ngs_selected)

for i, f in enumerate(feature_cols):
    if f in FISH_MARKERS:
        values_by_index[i] = 1 if f in fish_set else 0
    if f in NGS_MARKERS:
        values_by_index[i] = 1 if f in ngs_set else 0

# Derived count features = number of selected alterations/mutations.
if FISH_COUNT_COL in feature_cols:
    values_by_index[feature_cols.index(FISH_COUNT_COL)] = int(len(fish_selected))
if NGS_COUNT_COL in feature_cols:
    values_by_index[feature_cols.index(NGS_COUNT_COL)] = int(len(ngs_selected))
|
|
st.divider()
st.subheader("Predict single patient")

# Default classification threshold comes from the training metrics stored in
# meta.json, falling back to 0.5.
m = meta.get("metrics", {})
default_thr = float(m.get("best_threshold", 0.5))

# Threshold used for the single-patient class assignment.
thr_single = st.slider(
    "Classification threshold",
    0.0, 1.0, default_thr, 0.01,
    key="sp_thr"
)

# Separate threshold used later by the external-validation section.
thr_ext = st.slider(
    "External validation threshold",
    0.0, 1.0, default_thr, 0.01,
    key="thr_ext"
)

# Probability cutoffs defining the Low / Intermediate / High risk bands.
low_cut_s, high_cut_s = st.slider(
    "Risk band cutoffs (low, high)",
    0.0, 1.0, (0.2, 0.8), 0.01,
    key="sp_risk_cuts"
)

# Guard against the user dragging the handles past each other.
if low_cut_s > high_cut_s:
    low_cut_s, high_cut_s = high_cut_s, low_cut_s
|
|
|
|
|
|
|
|
def band_one(p: float) -> str:
    """Bucket a predicted probability into Low / Intermediate / High using
    the single-patient cutoffs (low_cut_s, high_cut_s)."""
    if p >= high_cut_s:
        return "High"
    return "Low" if p < low_cut_s else "Intermediate"
|
|
|
|
|
if st.button("Predict single patient", key="sp_predict_btn"):
    # Assemble the single-row frame in schema column order.
    X_one = pd.DataFrame([values_by_index], columns=feature_cols).replace({pd.NA: np.nan})

    # Coerce types to match training: numerics via to_numeric, categoricals
    # to plain strings (preserving NaN for missing).
    for c in num_cols:
        if c in X_one.columns:
            X_one[c] = pd.to_numeric(X_one[c], errors="coerce")

    for c in cat_cols:
        if c in X_one.columns:
            X_one[c] = X_one[c].astype("object")
            X_one.loc[X_one[c].isna(), c] = np.nan
            X_one[c] = X_one[c].map(lambda v: v if pd.isna(v) else str(v))

    # Positive-class probability from the classification pipeline.
    proba_one = float(pipe.predict_proba(X_one)[:, 1][0])
    st.success("Prediction generated.")
    st.metric("Predicted probability", f"{proba_one:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # --- Survival prediction for this patient (optional) -------------------
    # The survival bundle is a dict with the fitted Cox model, its design
    # columns, optional imputer, and (optionally) its own schema splits.
    bundle = st.session_state.get("surv_model", None)

    if isinstance(bundle, dict) and bundle.get("model") is not None:
        try:
            cph = bundle["model"]
            surv_cols = bundle.get("columns", [])

            # The Cox model may have been trained on a different feature
            # subset; fall back to the classifier's schema if absent.
            cox_feature_cols = bundle.get("feature_cols", feature_cols)
            cox_num_cols = bundle.get("num_cols", num_cols)
            cox_cat_cols = bundle.get("cat_cols", cat_cols)
            df_one_surv = X_one[cox_feature_cols].copy()

            # Same type coercions as for the classifier input.
            for c in cox_num_cols:
                if c in df_one_surv.columns:
                    df_one_surv[c] = pd.to_numeric(df_one_surv[c], errors="coerce")

            for c in cox_cat_cols:
                if c in df_one_surv.columns:
                    df_one_surv[c] = df_one_surv[c].astype("object").map(
                        lambda v: v if pd.isna(v) else str(v)
                    )

            # One-hot encode to reproduce the Cox design matrix.
            df_one_surv_oh = pd.get_dummies(df_one_surv, columns=cox_cat_cols, drop_first=True)
            df_one_surv_oh = df_one_surv_oh.loc[:, ~df_one_surv_oh.columns.duplicated()].copy()

            # Add any training-time dummy columns missing from this single
            # row (unseen categories), then order exactly as at fit time.
            for col in surv_cols:
                if col not in df_one_surv_oh.columns:
                    df_one_surv_oh[col] = 0
            df_one_surv_oh = df_one_surv_oh[surv_cols]

            # Impute with the bundled imputer if present; else zero-fill.
            imp = bundle.get("imputer", None)
            if imp is not None:
                X_imp = imp.transform(df_one_surv_oh)
                df_one_surv_oh = pd.DataFrame(X_imp, columns=surv_cols, index=df_one_surv_oh.index)
            else:
                df_one_surv_oh = df_one_surv_oh.fillna(0)

            # Survival function S(t) indexed by time (days — assumed; see
            # the day-based lookups below).
            surv_fn = cph.predict_survival_function(df_one_surv_oh)

            def surv_at(days: int) -> float:
                """S(t) at the time grid point nearest to `days`."""
                idx = surv_fn.index.values
                j = int(np.argmin(np.abs(idx - days)))
                return float(surv_fn.iloc[j, 0])

            s6m = surv_at(180)
            s1y = surv_at(365)
            s2y = surv_at(730)
            s3y = surv_at(1095)

            st.subheader("Predicted survival probability")
            a, b, c, d = st.columns(4)
            a.metric("6 months", f"{s6m*100:.1f}%")
            b.metric("1 year", f"{s1y*100:.1f}%")
            c.metric("2 years", f"{s2y*100:.1f}%")
            d.metric("3 years", f"{s3y*100:.1f}%")

            # Full predicted survival curve with a high-DPI download option.
            fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
            ax.plot(surv_fn.index, surv_fn.values)
            ax.set_xlabel("Days from diagnosis")
            ax.set_ylabel("Survival probability")
            ax.set_ylim(0, 1)
            ax.set_title("Predicted survival curve (patient)")

            render_plot_with_download(
                fig,
                title="Patient survival curve",
                filename="patient_survival_curve.png",
                export_dpi=export_dpi,
                key="dl_patient_surv_curve"
            )

        except Exception as e:
            # Survival output is best-effort; never block the classifier result.
            st.warning(f"Survival prediction could not be computed: {e}")

    else:
        st.info("Survival model not loaded/published yet (survival bundle missing).")
|
|
|
|
|
|
|
|
|
|
|
    # --- Assemble the downloadable single-patient result --------------------
    out = X_one.copy()
    out["predicted_probability"] = proba_one
    pred_class = int(proba_one >= thr_single)
    out["predicted_class"] = pred_class
    out["risk_band"] = band_one(proba_one)

    # Transform through all pipeline steps before the final classifier so the
    # SHAP explainer sees the model's actual input space.
    X_one_t = transform_before_clf(pipe, X_one)

    # Reuse the cached explainer unless the model version or the SHAP
    # background has changed since it was built (tracked via a signature).
    explainer = st.session_state.get("explainer")
    explainer_sig = st.session_state.get("explainer_sig")

    current_sig = (
        selected,
        None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"]))
    )

    if explainer is None or explainer_sig != current_sig:
        X_bg = st.session_state.get("X_bg_for_shap")
        if X_bg is None:
            st.error("SHAP background not available. Admin must publish latest/background.csv.")
            st.stop()

        st.session_state.explainer = build_shap_explainer(pipe, X_bg)
        st.session_state.explainer_sig = current_sig
        explainer = st.session_state.explainer

    # Older SHAP APIs return a per-class list; take the positive class.
    shap_vals = explainer.shap_values(X_one_t)
    if isinstance(shap_vals, list):
        shap_vals = shap_vals[1]

    # Feature names after all pipeline transforms (e.g. one-hot expansion).
    names = get_final_feature_names(pipe)

    # Densify the transformed row (sparse output from the ColumnTransformer).
    try:
        x_dense = X_one_t.toarray()[0]
    except Exception:
        x_dense = np.array(X_one_t)[0]

    # expected_value may be scalar or per-class array; reduce to one scalar.
    base = explainer.expected_value
    if not np.isscalar(base):
        base = float(np.array(base).reshape(-1)[0])

    exp = shap.Explanation(
        values=shap_vals[0],
        base_values=float(base),
        data=x_dense,
        feature_names=names,
    )

    # Store the explanation so the plots below survive Streamlit reruns.
    st.session_state.shap_single_exp = exp

    st.dataframe(out, use_container_width=True)

    st.download_button(
        "Download single patient result (CSV)",
        out.to_csv(index=False).encode("utf-8"),
        file_name="single_patient_prediction.csv",
        mime="text/csv",
        key="dl_sp_csv",
    )
|
|
|
|
|
|
|
|
# Render the stored single-patient SHAP explanation (persists across reruns,
# so the plots remain visible after slider interactions).
if "shap_single_exp" in st.session_state:
    exp = st.session_state.shap_single_exp

    max_display_single = st.slider(
        "Top features to display (single patient)",
        5, 40, 20, 1,
        key="sp_single_max_display"
    )

    c1, c2 = st.columns(2)

    with c1:
        # shap.plots.* draw on the current figure; create it first, then
        # grab it via gcf() for export.
        plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
        shap.plots.waterfall(exp, show=False, max_display=max_display_single)
        fig_w = plt.gcf()
        render_plot_with_download(
            fig_w,
            title="Single-patient SHAP waterfall",
            filename="single_patient_shap_waterfall.png",
            export_dpi=export_dpi,
            key="dl_sp_wf"
        )

    with c2:
        plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
        shap.plots.bar(exp, show=False, max_display=max_display_single)
        fig_b = plt.gcf()
        render_plot_with_download(
            fig_b,
            title="Single-patient SHAP bar",
            filename="single_patient_shap_bar.png",
            export_dpi=export_dpi,
            key="dl_sp_bar"
        )

    # User-facing guidance on reading the SHAP plots (raw string preserves
    # the LaTeX backslashes in \( ... \)).
    with st.expander("How to interpret the SHAP explanation plots", expanded=True):
        st.markdown(r"""
**What is SHAP?**
SHAP (SHapley Additive exPlanations) decomposes a model prediction into **feature-wise contributions**.
Each feature pushes the prediction **toward higher risk** or **toward lower risk**, relative to the model’s baseline.

**What is \(E[f(X)]\)? (baseline / expected model output)**
- \(E[f(X)]\) is the model’s **average output** over the reference population used by the explainer (typically the training set).
- It is the **starting point** of the explanation before any patient-specific information is applied.

**What is \(f(x)\)? (this patient’s model output)**
- \(f(x)\) is the model’s **final output for this patient**, after adding all feature contributions to the baseline.
- For many classifiers (including logistic regression), SHAP waterfall plots often use the **log-odds (logit) scale** rather than raw probability.

**How the waterfall plot should be read**
- The plot starts at **\(E[f(X)]\)** and then adds/subtracts feature effects to arrive at **\(f(x)\)**.
- **Red bars** push the prediction **toward higher risk** (increase the model output).
- **Blue bars** push the prediction **toward lower risk / protective direction** (decrease the model output).
- The **bar length** indicates the **strength** of that feature’s influence for this patient.
- “**Other features**” is the combined net effect of features not shown individually.

**How the bar plot should be read (Top contributors)**
- Features are ranked by **absolute SHAP value** (largest impact first).
- **Positive SHAP** (right) increases predicted risk; **negative SHAP** (left) decreases predicted risk.

**Clinical cautions**
- SHAP explains the model’s behavior; it does **not** prove causality.
- Ensure variable definitions and patient population match the model’s intended use.
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
st.divider()
st.subheader("Batch prediction / External validation")

df_inf = st.session_state.df_inf
if df_inf is None:
    st.info("Upload an inference Excel to run batch prediction / external validation.")
else:
    # Defensive re-fetch of the loaded artifacts for this section.
    meta = st.session_state.get("meta")
    pipe = st.session_state.get("pipe")
    if meta is None or pipe is None:
        st.error("Model not loaded. Please load a model version first.")
        st.stop()

    feature_cols = meta["schema"]["features"]
    num_cols = meta["schema"]["numeric"]
    cat_cols = meta["schema"]["categorical"]

    # The uploaded file must contain every schema feature column.
    missing = [c for c in feature_cols if c not in df_inf.columns]
    if missing:
        st.error(f"Inference file is missing required feature columns: {missing}")
        st.stop()
    # NOTE(review): redundant — meta was already fetched and validated a few
    # lines above; kept as-is.
    meta = st.session_state.get("meta")
    if not meta:
        st.error("Model metadata not loaded. Please load a model version.")
        st.stop()
|
|
|
|
|
    # Align the uploaded frame to the training schema, then apply the same
    # type coercions used for the single-patient row.
    X_inf = coerce_X_like_schema(df_inf, feature_cols, num_cols, cat_cols)

    X_inf = X_inf.replace({pd.NA: np.nan})

    for c in num_cols:
        X_inf[c] = pd.to_numeric(X_inf[c], errors="coerce")

    for c in cat_cols:
        X_inf[c] = X_inf[c].astype("object")
        X_inf.loc[X_inf[c].isna(), c] = np.nan
        X_inf[c] = X_inf[c].map(lambda v: v if pd.isna(v) else str(v))

    # Positive-class probabilities for every row of the uploaded file.
    proba = pipe.predict_proba(X_inf)[:, 1]
|
|
|
|
|
    st.divider()
    st.subheader("External validation (if Outcome Event label is present)")

    # Metrics are only computable when the true outcome column is present.
    if LABEL_COL in df_inf.columns:
        try:
            # Coerce the raw label column to 0/1.
            y_ext_raw = df_inf[LABEL_COL].copy()
            y_ext01, _ = coerce_binary_label(y_ext_raw)

            # Discrimination: ROC AUC and the full ROC curve.
            roc_auc_ext = float(roc_auc_score(y_ext01, proba))
            fpr, tpr, roc_thresholds = roc_curve(y_ext01, proba)

            # Threshold-dependent classification metrics at the user-chosen
            # external threshold (thr_ext slider above).
            cls_ext = compute_classification_metrics(y_ext01, proba, threshold=float(thr_ext))

            # Precision-recall, calibration, and decision-curve analysis.
            pr_ext = compute_pr_curve(y_ext01, proba)
            cal_ext = compute_calibration(y_ext01, proba, n_bins=PRED_N_BINS, strategy=PRED_CAL_STRATEGY)

            dca_ext = decision_curve_analysis(y_ext01, proba)

            # Headline metrics.
            c1, c2, c3, c4 = st.columns(4)
            c1.metric("ROC AUC (external)", f"{roc_auc_ext:.3f}")
            c2.metric("Sensitivity", f"{cls_ext['sensitivity']:.3f}")
            c3.metric("Specificity", f"{cls_ext['specificity']:.3f}")
            c4.metric("F1", f"{cls_ext['f1']:.3f}")

            # Confusion matrix at thr_ext.
            cm_df = pd.DataFrame(
                [[cls_ext["tn"], cls_ext["fp"]], [cls_ext["fn"], cls_ext["tp"]]],
                index=["Actual 0", "Actual 1"],
                columns=["Pred 0", "Pred 1"],
            )
            st.markdown("**Confusion Matrix (external)**")
            st.dataframe(cm_df)

            # ROC curve (diagonal = chance).
            fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
            ax.plot(fpr, tpr)
            ax.plot([0, 1], [0, 1])
            ax.set_xlabel("False Positive Rate (1 - Specificity)")
            ax.set_ylabel("True Positive Rate (Sensitivity)")
            ax.set_title(f"External ROC Curve (AUC = {roc_auc_ext:.3f})")

            render_plot_with_download(
                fig,
                title="External ROC curve",
                filename="external_roc_curve.png",
                export_dpi=export_dpi,
                key="dl_ext_roc"
            )

            # Precision-recall curve with average precision in the title.
            fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
            ax.plot(pr_ext["recall"], pr_ext["precision"])
            ax.set_xlabel("Recall")
            ax.set_ylabel("Precision")
            ax.set_title(f"External PR Curve (AP = {pr_ext['average_precision']:.3f})")

            render_plot_with_download(
                fig,
                title="External PR curve",
                filename="external_pr_curve.png",
                export_dpi=export_dpi,
                key="dl_ext_pr"
            )

            # Calibration: predicted vs observed event rates (diagonal = ideal).
            fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
            ax.plot(cal_ext["prob_pred"], cal_ext["prob_true"])
            ax.plot([0, 1], [0, 1])
            ax.set_xlabel("Mean predicted probability")
            ax.set_ylabel("Observed event rate")
            ax.set_title("External calibration curve")

            render_plot_with_download(
                fig,
                title="External calibration curve",
                filename="external_calibration_curve.png",
                export_dpi=export_dpi,
                key="dl_ext_cal"
            )

            # Decision-curve analysis vs. treat-all / treat-none baselines.
            fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
            ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_model"], label="Model")
            ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_all"], label="Treat all")
            ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_none"], label="Treat none")
            ax.set_xlabel("Threshold probability")
            ax.set_ylabel("Net benefit")
            ax.set_title("External decision curve analysis")
            ax.legend()

            render_plot_with_download(
                fig,
                title="External decision curve",
                filename="external_decision_curve.png",
                export_dpi=export_dpi,
                key="dl_ext_dca"
            )

        except Exception as e:
            st.error(f"Could not compute external validation metrics: {e}")
    else:
        st.info("No Outcome Event column found in the inference Excel, so external validation metrics cannot be computed.")
|
|
|
|
|
|
|
|
    # NOTE(review): recomputes the same probabilities already computed above
    # the validation section; kept as-is (deterministic, just redundant work).
    proba = pipe.predict_proba(X_inf)[:, 1]

    df_out = df_inf.copy()
    df_out["predicted_probability"] = proba

    # --- Batch survival probabilities (optional) ----------------------------
    # Mirrors the single-patient survival path, applied to all rows.
    bundle = st.session_state.get("surv_model", None)

    if isinstance(bundle, dict) and bundle.get("model") is not None:
        try:
            cph = bundle["model"]
            surv_cols = bundle["columns"]
            imp = bundle.get("imputer", None)

            cox_feature_cols = bundle.get("feature_cols", feature_cols)
            cox_num_cols = bundle.get("num_cols", num_cols)
            cox_cat_cols = bundle.get("cat_cols", cat_cols)
            df_surv_in = X_inf[cox_feature_cols].copy()

            # Same coercions as the classifier input.
            for c in cox_num_cols:
                if c in df_surv_in.columns:
                    df_surv_in[c] = pd.to_numeric(df_surv_in[c], errors="coerce")

            for c in cox_cat_cols:
                if c in df_surv_in.columns:
                    df_surv_in[c] = df_surv_in[c].astype("object")
                    df_surv_in.loc[df_surv_in[c].isna(), c] = np.nan
                    df_surv_in[c] = df_surv_in[c].map(lambda v: v if pd.isna(v) else str(v))

            # Reproduce the Cox design matrix via one-hot encoding.
            df_surv_in_oh = pd.get_dummies(df_surv_in, columns=cox_cat_cols, drop_first=True)
            df_surv_in_oh = df_surv_in_oh.loc[:, ~df_surv_in_oh.columns.duplicated()].copy()

            # Add missing training-time dummies, enforce fit-time column order.
            for col in surv_cols:
                if col not in df_surv_in_oh.columns:
                    df_surv_in_oh[col] = 0
            df_surv_in_oh = df_surv_in_oh[surv_cols]

            # Bundled imputer if present; else zero-fill.
            if imp is not None:
                X_imp = imp.transform(df_surv_in_oh)
                df_surv_in_oh = pd.DataFrame(X_imp, columns=surv_cols, index=df_surv_in_oh.index)
            else:
                df_surv_in_oh = df_surv_in_oh.fillna(0)

            # One survival-function column per patient, indexed by time.
            surv_fn_all = cph.predict_survival_function(df_surv_in_oh)

            def surv_vec_at(days: int):
                """Per-patient S(t) at the time grid point nearest `days`."""
                idx = surv_fn_all.index.values
                j = int(np.argmin(np.abs(idx - days)))
                return surv_fn_all.iloc[j, :].values.astype(float)

            df_out["survival_6m"] = surv_vec_at(180)
            df_out["survival_1y"] = surv_vec_at(365)
            df_out["survival_2y"] = surv_vec_at(730)
            df_out["survival_3y"] = surv_vec_at(1095)

        except Exception as e:
            st.warning(f"Batch survival probabilities could not be computed: {e}")

    else:
        st.info("Survival model not loaded/published yet (survival bundle missing).")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Export the (label, probability) pairs needed to rebuild the
    # ROC/calibration/DCA panel offline.
    if LABEL_COL in df_out.columns:
        curves_df = df_out[[LABEL_COL, "predicted_probability"]].copy()

        st.download_button(
            "Download CSV for ROC/Calibration/DCA panel",
            curves_df.to_csv(index=False).encode("utf-8"),
            file_name="predictions_for_curves.csv",
            mime="text/csv",
            key="dl_predictions_for_curves"
        )
    else:
        st.info("Outcome Event column not present — panel CSV export requires true labels.")

    # (Re)derive the region column on the output frame.
    df_out = add_ethnicity_region(df_out, eth_col="Ethnicity", out_col="Ethnicity_Region")

    # Simple analytics: predicted-probability summary per geographic region.
    st.subheader("Grouped outcomes by region (analytics)")
    if "Ethnicity_Region" in df_out.columns:
        grp = df_out.groupby("Ethnicity_Region")["predicted_probability"].agg(["count","mean","median"]).reset_index()
        st.dataframe(grp, use_container_width=True)
|
|
|
|
|
|
|
|
st.divider() |
|
|
st.subheader("Risk stratification") |
|
|
|
|
|
|
|
|
thr = st.slider( |
|
|
"Decision threshold for classification", |
|
|
0.0, 1.0, 0.5, 0.01, |
|
|
key="pred_thr" |
|
|
) |
|
|
|
|
|
df_out["predicted_class"] = (df_out["predicted_probability"] >= thr).astype(int) |
|
|
|
|
|
|
|
|
|
|
|
low_cut, high_cut = st.slider( |
|
|
"Risk band cutoffs (low, high)", |
|
|
0.0, 1.0, (0.2, 0.8), 0.01, |
|
|
key="batch_risk_cuts" |
|
|
) |
|
|
|
|
|
|
|
|
if low_cut > high_cut: |
|
|
low_cut, high_cut = high_cut, low_cut |
|
|
|
|
|
|
|
|
def band(p: float) -> str:
    """Map a predicted probability to 'Low' / 'Intermediate' / 'High'.

    Uses the slider cutoffs from the enclosing scope:
    p < low_cut -> 'Low'; p >= high_cut -> 'High'; otherwise 'Intermediate'.
    (Bands cannot overlap because low_cut <= high_cut is enforced above.)
    """
    if p >= high_cut:
        return "High"
    return "Low" if p < low_cut else "Intermediate"


# Attach a categorical risk band to every scored row.
df_out["risk_band"] = df_out["predicted_probability"].map(band)
|
|
|
|
|
|
|
|
|
|
|
# Preview the scored frame and offer the full predictions table as CSV.
st.dataframe(df_out.head())

st.download_button(
    "Download predictions",
    # Explicit UTF-8 for consistency with the other CSV exports in this page
    # (the bare encode() relied on the platform default).
    df_out.to_csv(index=False).encode("utf-8"),
    file_name="predictions.csv",
    mime="text/csv",
    # Explicit key, matching the convention of every other download button
    # here, to avoid Streamlit widget-key collisions.
    key="dl_predictions_full",
)
|
|
|
|
|
|
|
|
|
|
|
st.divider()
st.subheader("Batch SHAP (first 200 rows)")

# Cap the batch at 200 rows to keep SHAP computation tractable.
MAX_BATCH = 200
n_rows = len(X_inf)
batch_n = min(MAX_BATCH, n_rows)

# Three equal-width control columns: trigger, display depth, beeswarm toggle.
col_btn, col_display, col_swarm = st.columns([1, 1, 1])
with col_btn:
    do_batch = st.button(f"Compute batch SHAP for first {batch_n} rows", key="batch_shap_btn")
with col_display:
    max_display_batch = st.slider("Top features to display (batch)", 5, 40, 20, 1, key="batch_max_display")
with col_swarm:
    show_beeswarm = st.checkbox("Show beeswarm (slower)", value=True, key="batch_beeswarm")
|
|
|
|
|
# Heavy path: compute SHAP values for the first `batch_n` rows on demand.
if do_batch:
    with st.spinner("Computing batch SHAP..."):
        # Slice the inference frame and run it through every pipeline step
        # before the final classifier, so SHAP sees the model's input space.
        X_batch = X_inf.iloc[:batch_n].copy()

        X_batch_t = transform_before_clf(pipe, X_batch)

        # Reuse a cached explainer when its signature (selected model +
        # background-sample size) still matches; rebuilding is expensive.
        explainer = st.session_state.get("explainer")
        explainer_sig = st.session_state.get("explainer_sig")

        current_sig = (
            selected,
            None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"]))
        )

        if explainer is None or explainer_sig != current_sig:
            X_bg = st.session_state.get("X_bg_for_shap")
            if X_bg is None:
                # Without a background sample SHAP values are meaningless;
                # abort the rerun rather than show bogus output.
                st.error("SHAP background not available. Admin must publish latest/background.csv.")
                st.stop()

            st.session_state.explainer = build_shap_explainer(pipe, X_bg)
            st.session_state.explainer_sig = current_sig
            explainer = st.session_state.explainer

        shap_vals_batch = explainer.shap_values(X_batch_t)
        # Older SHAP explainers return one array per class for classifiers;
        # keep the positive-class (index 1) contributions.
        if isinstance(shap_vals_batch, list):
            shap_vals_batch = shap_vals_batch[1]

        names = get_final_feature_names(pipe)

        # Guard against transformer/name drift: fall back to generic names
        # instead of mislabelling columns.
        if len(names) != shap_vals_batch.shape[1]:
            st.warning(
                f"Feature name mismatch: names={len(names)} vs shap_cols={shap_vals_batch.shape[1]}. "
                "Using generic names."
            )
            names = [f"f{i}" for i in range(shap_vals_batch.shape[1])]

        # SHAP plotting needs a dense matrix; densify sparse output with a
        # best-effort fallback for unexpected matrix types.
        try:
            X_dense = safe_dense(X_batch_t, max_rows=200)
        except Exception:
            X_dense = np.array(X_batch_t)

        # Persist results so the widgets below can re-render across reruns
        # without recomputing SHAP.
        st.session_state.shap_batch_vals = shap_vals_batch
        st.session_state.shap_batch_X_dense = X_dense
        st.session_state.shap_batch_n = batch_n
        st.session_state.shap_batch_feature_names = names

    st.success(f"Batch SHAP computed for first {batch_n} rows.")
|
|
|
|
|
# Re-render persisted batch-SHAP artifacts on every rerun (the expensive
# computation above stores them in session_state).
if "shap_batch_vals" in st.session_state:
    shap_vals_batch = st.session_state.shap_batch_vals
    X_dense = st.session_state.shap_batch_X_dense
    batch_n = st.session_state.shap_batch_n
    names = st.session_state.shap_batch_feature_names

    st.divider()
    st.subheader("Export: Top SHAP features per row (batch)")

    # How many of each row's strongest features to export.
    top_k = st.slider("Top-K features per row", 3, 30, 10, 1, key="topk_export")

    include_proba = st.checkbox("Include predicted probability", value=True, key="include_proba_export")

    if st.button("Generate Top-K SHAP table", key="gen_topk_shap"):
        shap_vals_batch = st.session_state.shap_batch_vals
        names = st.session_state.shap_batch_feature_names
        batch_n = st.session_state.shap_batch_n

        # One output record per (batch row, top-k feature) pair.
        rows = []
        for i in range(batch_n):
            sv = shap_vals_batch[i]
            # Indices of the k largest |SHAP| values, descending.
            idx = np.argsort(np.abs(sv))[::-1][:top_k]

            for j in idx:
                val = float(sv[j])
                rows.append({
                    "row_in_batch": int(i),
                    "feature": str(names[j]),
                    "shap_value": val,
                    "abs_shap_value": abs(val),
                    "direction": "↑" if val > 0 else ("↓" if val < 0 else "0"),
                })

        df_topk = pd.DataFrame(rows)

        if include_proba:
            # NOTE(review): `proba` is defined earlier in the file; assumed
            # to be aligned row-for-row with X_inf — confirm upstream.
            proba_batch = proba[:batch_n]
            df_proba = pd.DataFrame({"row_in_batch": list(range(batch_n)), "predicted_probability": proba_batch})
            df_topk = df_topk.merge(df_proba, on="row_in_batch", how="left")

        # Stable presentation: rows in order, strongest features first.
        df_topk = df_topk.sort_values(["row_in_batch", "abs_shap_value"], ascending=[True, False])

        st.dataframe(df_topk, use_container_width=True)

        st.download_button(
            "Download Top-K SHAP per row (CSV)",
            df_topk.to_csv(index=False).encode("utf-8"),
            file_name=f"shap_top{top_k}_per_row_first{batch_n}.csv",
            mime="text/csv",
            key="dl_topk_shap_csv"
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
st.markdown(f"### Global SHAP summary (first {batch_n} rows)")

# Mean-|SHAP| bar chart aggregated across the batch.
plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
shap.summary_plot(
    shap_vals_batch,
    features=X_dense,
    feature_names=names,
    plot_type="bar",
    max_display=max_display_batch,
    show=False
)
fig_bar = plt.gcf()
# Explicit key: without it, this download button and the beeswarm one below
# both fall back to render_plot_with_download's constant default key and
# collide (Streamlit duplicate-widget error).
render_plot_with_download(
    fig_bar,
    title="SHAP bar summary",
    filename="shap_summary_bar.png",
    export_dpi=export_dpi,
    key="dl_shap_summary_bar",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Optional beeswarm (per-row dots); slower than the bar summary.
if show_beeswarm:
    plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
    shap.summary_plot(
        shap_vals_batch,
        features=X_dense,
        feature_names=names,
        max_display=max_display_batch,
        show=False
    )
    fig_swarm = plt.gcf()
    # Explicit key avoids colliding with the bar-summary download button,
    # which otherwise shares the same fallback key inside
    # render_plot_with_download.
    render_plot_with_download(
        fig_swarm,
        title="SHAP beeswarm",
        filename="shap_beeswarm.png",
        export_dpi=export_dpi,
        key="dl_shap_beeswarm",
    )
|
|
|
|
|
|
|
|
|
|
|
st.markdown("### Waterfall plots (batch)")

# Row picker — indices are positions within the computed batch, not the
# original upload.
rows_to_plot = st.multiselect(
    "Select rows (within the first batch) to plot waterfalls",
    options=list(range(batch_n)),
    default=[0],
    key="batch_rows_to_plot"
)
max_display_single = st.slider("Top features to display (single-row SHAP)",5, 40, 20, 1,key="single_max_display")

max_waterfalls = st.slider("Max waterfall plots to render", 1, 10, 3, 1, key="max_waterfalls")
# Hard cap on rendered plots; extra selections are silently dropped.
rows_to_plot = rows_to_plot[:max_waterfalls]

# NOTE(review): assumes the explainer is still cached in session state; if
# it was evicted, `.expected_value` below raises AttributeError — confirm.
explainer = st.session_state.get("explainer")
base = explainer.expected_value
# Multi-output explainers return an array of base values; use the first.
if not np.isscalar(base):
    base = float(np.array(base).reshape(-1)[0])

for r in rows_to_plot:
    st.markdown(f"**Row {r} (within first {batch_n})**")

    # Wrap one row's SHAP values in an Explanation object so the new-style
    # shap.plots API can render it.
    exp = shap.Explanation(
        values=shap_vals_batch[r],
        base_values=float(base),
        data=X_dense[r],
        feature_names=names,
    )

    # Waterfall: signed, ordered contributions for this row.
    plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
    shap.plots.waterfall(exp, show=False, max_display=max_display_single)
    fig_w = plt.gcf()
    render_plot_with_download(
        fig_w,
        title=f"Batch SHAP waterfall (row {r})",
        filename=f"shap_waterfall_batch_row_{r}.png",
        export_dpi=export_dpi,
        key=f"dl_wf_batch_{r}"
    )

    # Bar: absolute-magnitude view of the same row.
    plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
    shap.plots.bar(exp, show=False, max_display=max_display_single)
    fig_b = plt.gcf()
    render_plot_with_download(
        fig_b,
        title=f"Batch SHAP bar (row {r})",
        filename=f"shap_bar_batch_row_{r}.png",
        export_dpi=export_dpi,
        key=f"dl_bar_batch_{r}"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
st.subheader("SHAP explanation")

with st.form("shap_form"):
    # Clamp the remembered row index into the current data range: a stale
    # session value saved from a larger upload would otherwise exceed
    # max_value and make st.number_input raise on rerun.
    max_row = len(X_inf) - 1
    default_row = min(max(int(st.session_state.get("shap_row", 0)), 0), max_row)
    row = st.number_input("Row index", 0, max_row, default_row)
    explain_btn = st.form_submit_button("Generate SHAP explanation")
|
|
|
|
|
# Compute and cache a single-row SHAP explanation on form submission.
if explain_btn:
    st.session_state.shap_row = int(row)

    # Transform just the selected row through the pre-classifier pipeline.
    X_one = X_inf.iloc[[int(row)]]
    X_one_t = transform_before_clf(pipe, X_one)

    # Reuse the cached explainer when its signature (selected model +
    # background-sample size) is unchanged; rebuilding is expensive.
    explainer = st.session_state.get("explainer")
    explainer_sig = st.session_state.get("explainer_sig")

    current_sig = (
        selected,
        None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"]))
    )

    if explainer is None or explainer_sig != current_sig:
        X_bg = st.session_state.get("X_bg_for_shap")
        if X_bg is None:
            # No background sample -> SHAP values would be meaningless;
            # abort the rerun instead of showing bogus output.
            st.error("SHAP background not available. Admin must publish latest/background.csv.")
            st.stop()

        st.session_state.explainer = build_shap_explainer(pipe, X_bg)
        st.session_state.explainer_sig = current_sig
        explainer = st.session_state.explainer

    shap_vals = explainer.shap_values(X_one_t)
    # Older SHAP explainers return one array per class; keep class 1.
    if isinstance(shap_vals, list):
        shap_vals = shap_vals[1]

    names = get_final_feature_names(pipe)
    # Fall back to generic names on transformer/name drift.
    if len(names) != shap_vals.shape[1]:
        names = [f"f{i}" for i in range(shap_vals.shape[1])]

    # Densify the transformed row (sparse matrices lack plain row indexing).
    try:
        x_dense = X_one_t.toarray()[0]
    except Exception:
        x_dense = np.array(X_one_t)[0]

    base = explainer.expected_value
    # Multi-output explainers return an array of base values; use the first.
    if not np.isscalar(base):
        base = float(np.array(base).reshape(-1)[0])

    exp = shap.Explanation(
        values=shap_vals[0],
        base_values=float(base),
        data=x_dense,
        feature_names=names,
    )

    # Cache so the plots below survive widget-triggered reruns.
    st.session_state.shap_single_exp = exp
|
|
|
|
|
|
|
|
# Render the cached single-row explanation, or prompt for a submission.
if "shap_single_exp" not in st.session_state:
    st.info("Submit a row index to generate the SHAP explanation.")
else:
    exp = st.session_state.shap_single_exp

    max_display_single_row = st.slider(
        "Top features to display (single row)",
        5, 40, int(st.session_state.get("shap_single_max", 20)), 1,
        key="shap_single_max"
    )

    # Side-by-side views of the same explanation.
    col_waterfall, col_bar = st.columns(2)

    with col_waterfall:
        st.markdown("**Waterfall**")
        plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
        shap.plots.waterfall(exp, show=False, max_display=max_display_single_row)
        render_plot_with_download(
            plt.gcf(),
            title="SHAP waterfall",
            filename="shap_waterfall_row.png",
            export_dpi=export_dpi,
            key="dl_shap_wf_single"
        )

    with col_bar:
        st.markdown("**Top features**")
        plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
        shap.plots.bar(exp, show=False, max_display=max_display_single_row)
        render_plot_with_download(
            plt.gcf(),
            title="SHAP bar",
            filename="shap_bar_row.png",
            export_dpi=export_dpi,
            key="dl_shap_bar_single"
        )
|
|
|
|
|
|
|
|
|
|
|
|