import numpy as np
import pandas as pd
import skops.io as sio
import shap
import plotly.graph_objects as go
import os
import sys
import warnings

# Silence sklearn version-skew warnings emitted while unpickling artifacts.
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")

import sklearn.compose._column_transformer as _ct

# Compatibility shim: models were serialized with a newer sklearn whose
# ColumnTransformer references `_RemainderColsList`.  If the installed
# sklearn lacks it, provide a minimal stand-in so skops can deserialize.
if not hasattr(_ct, "_RemainderColsList"):
    class _RemainderColsList(list):
        """Minimal shim for sklearn._RemainderColsList (missing in this env)."""
        def __init__(self, lst=None, future_dtype=None):
            super().__init__(lst or [])
            self.future_dtype = future_dtype
    _ct._RemainderColsList = _RemainderColsList
    import sklearn.compose
    sklearn.compose._RemainderColsList = _RemainderColsList

# Numeric model inputs (passed through the preprocessor's numeric branch).
NUM_COLUMNS = ["AGE", "NACS2YR"]

# Categorical model inputs (one-hot encoded by the preprocessor's "cat" branch).
CATEG_COLUMNS = [
    "AGEGPFF",
    "SEX",
    "KPS",
    "DONORF",
    "GRAFTYPE",
    "CONDGRPF",
    "CONDGRP_FINAL",
    "ATGF",
    "GVHD_FINAL",
    "HLA_FINAL",
    "RCMVPR",
    "EXCHTFPR",
    "VOC2YPR",
    "VOCFRQPR",
    "SCATXRSN",
]

# Canonical column order expected by the preprocessor.
FEATURE_NAMES = NUM_COLUMNS + CATEG_COLUMNS

# Outcomes with a trained ensemble classifier on disk.
OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "DWOGF"]
CLASSIFICATION_OUTCOMES = OUTCOMES

# Outcomes surfaced to the UI (OS and EFS are derived, not modeled directly).
REPORTING_OUTCOMES = [
    "OS",
    "EFS",
    "GF",
    "DEAD",
    "AGVHD",
    "CGVHD",
    "VOCPSHI",
    "STROKEHI",
]

# Human-readable names for reporting (keys cover derived outcomes too).
OUTCOME_DESCRIPTIONS = {
    "OS": "Overall Survival",
    "EFS": "Event-Free Survival",
    "DEAD": "Death",
    "GF": "Graft Failure",
    "AGVHD": "Acute Graft-versus-Host Disease",
    "CGVHD": "Chronic Graft-versus-Host Disease",
    "VOCPSHI": "Vaso-Occlusive Crisis Post-HCT",
    "STROKEHI": "Stroke Post-HCT",
}

# Outcomes for which SHAP explanation plots are produced (includes derived).
SHAP_OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "OS", "EFS"]

# Directory containing preprocessor.skops and ensemble_model_*.skops files.
MODEL_DIR = "."
# ---------------------------------------------------------------------------
# Model loading, calibration, consensus prediction, bootstrap CIs, SHAP
# explanations, and HTML icon-array rendering.
#
# NOTE(review): the HTML/SVG template literals in this module arrived with
# their markup stripped by an extraction step (the f-strings were empty /
# broken across lines).  The markup below is a minimal reconstruction that
# preserves every computed value (counts, labels, colors, layout intent as
# described by the surviving comments) — verify against the original app.
# ---------------------------------------------------------------------------

CONSENSUS_THRESHOLD = 0.5   # default per-model positive-vote threshold
DEFAULT_N_BOOT_CI = 500     # bootstrap resamples used for confidence intervals


def _load_skops_model(fname):
    """Load a skops artifact, trusting all types it declares.

    Exits the process on any failure because the app cannot run without its
    models.  NOTE(review): blanket-trusting untrusted types executes code
    from the file — acceptable only for first-party artifacts.
    """
    try:
        untrusted = sio.get_untrusted_types(file=fname)
        return sio.load(fname, trusted=untrusted)
    except Exception as e:
        print(f"Error loading '{fname}': {e}")
        sys.exit(1)


preprocessor = _load_skops_model(os.path.join(MODEL_DIR, "preprocessor.skops"))

# One dict of artifacts per outcome; missing files are skipped with a warning.
classification_model_data = {}
for _o in CLASSIFICATION_OUTCOMES:
    _path = os.path.join(MODEL_DIR, f"ensemble_model_{_o}.skops")
    if os.path.exists(_path):
        classification_model_data[_o] = _load_skops_model(_path)
    else:
        print(f"Warning: Model for {_o} not found at {_path}. Skipping.")

classification_models = {o: d["models"] for o, d in classification_model_data.items()}
betas = {o: d["beta"] for o, d in classification_model_data.items()}
priors = {o: d["prior"] for o, d in classification_model_data.items()}
consensus_thresholds = {
    o: d.get("consensus_threshold", CONSENSUS_THRESHOLD)
    for o, d in classification_model_data.items()
}

# Isotonic-only calibration policy: accept a stored calibrator only when its
# declared type is isotonic (or undeclared); otherwise fall back to the
# legacy "isotonic_calibrator" key, else None (no post-hoc calibration).
calibrators = {}
for _o, _d in classification_model_data.items():
    _cal = None
    _cal_type = _d.get("calibrator_type", None)
    if "calibrator" in _d and _d["calibrator"] is not None:
        if _cal_type is None or _cal_type == "isotonic":
            _cal = _d["calibrator"]
        else:
            print(
                f"Warning: outcome '{_o}' has calibrator_type='{_cal_type}'. "
                "Skipping non-isotonic calibrator (isotonic-only policy)."
            )
    elif "isotonic_calibrator" in _d and _d["isotonic_calibrator"] is not None:
        _cal = _d["isotonic_calibrator"]
    calibrators[_o] = _cal

# Alias expected by app.py
isotonic_calibrators = calibrators

# Out-of-fold calibrated probabilities, used to bootstrap CIs around a point.
oof_probs_calibrated = {
    o: d.get("oof_probs_calibrated") for o, d in classification_model_data.items()
}

# Processed feature space: numeric passthrough + one-hot categorical names.
ohe = preprocessor.named_transformers_["cat"]
ohe_feature_names = ohe.get_feature_names_out(CATEG_COLUMNS)
processed_feature_names = np.concatenate([NUM_COLUMNS, ohe_feature_names])

# ---------------------------------------------------------------------------
# Synthetic background population for SHAP's Independent masker.
# NOTE(review): seeds the *global* NumPy RNG at import time; done once here
# so the background sample (and hence SHAP baselines) is reproducible.
# ---------------------------------------------------------------------------
np.random.seed(23)
_n_background = 500
_background_data = {
    "AGE": np.random.uniform(5, 50, _n_background),
    "NACS2YR": np.random.randint(0, 5, _n_background),
    "AGEGPFF": np.random.choice(["<=10", "11-17", "18-29", "30-49", ">=50"], _n_background),
    "SEX": np.random.choice(["Male", "Female"], _n_background),
    "KPS": np.random.choice(["<90", "≥ 90"], _n_background),
    "DONORF": np.random.choice([
        "HLA identical sibling",
        "HLA mismatch relative",
        "Matched unrelated donor",
        "Mismatched unrelated donor or cord blood",
    ], _n_background),
    "GRAFTYPE": np.random.choice(["Bone marrow", "Peripheral blood", "Cord blood"], _n_background),
    "CONDGRPF": np.random.choice(["MAC", "RIC", "NMA"], _n_background),
    "CONDGRP_FINAL": np.random.choice(["TBI/Cy", "Bu/Cy", "Flu/Bu", "Flu/Mel"], _n_background),
    "ATGF": np.random.choice(["ATG", "Alemtuzumab", "None"], _n_background),
    "GVHD_FINAL": np.random.choice(["CNI + MMF", "CNI + MTX", "Post-CY + siro +- MMF"], _n_background),
    "HLA_FINAL": np.random.choice(["8/8", "7/8", "≤ 6/8"], _n_background),
    "RCMVPR": np.random.choice(["Negative", "Positive"], _n_background),
    "EXCHTFPR": np.random.choice(["No", "Yes"], _n_background),
    "VOC2YPR": np.random.choice(["No", "Yes"], _n_background),
    "VOCFRQPR": np.random.choice(["< 3/yr", "≥ 3/yr"], _n_background),
    "SCATXRSN": np.random.choice([
        "CNS event",
        "Acute chest Syndrome",
        "Recurrent vaso-occlusive pain",
        "Recurrent priapism",
        "Excessive transfusion requirements/iron overload",
        "Cardio-pulmonary",
        "Chronic transfusion",
        "Asymptomatic",
        "Renal insufficiency",
        "Splenic sequestration",
        "Avascular necrosis",
        "Hodgkin lymphoma",
    ], _n_background),
}
_background_df = pd.DataFrame(_background_data)[FEATURE_NAMES]
_X_background = preprocessor.transform(_background_df)
shap_background = shap.maskers.Independent(_X_background)


def calibrate_probabilities_undersampling(p_s, beta):
    """Correct probabilities from a model trained on undersampled data.

    Applies p = beta*p_s / ((beta-1)*p_s + 1), clipped to [0, 1]; the
    denominator is floored at 1e-10 to avoid division by zero.

    Parameters
    ----------
    p_s : array-like of float
        Probabilities under the undersampled class prior.
    beta : float
        Undersampling ratio stored with the model.

    Returns
    -------
    np.ndarray of corrected probabilities in [0, 1].
    """
    p_s = np.asarray(p_s, dtype=float)
    numerator = beta * p_s
    denominator = np.maximum((beta - 1.0) * p_s + 1.0, 1e-10)
    return np.clip(numerator / denominator, 0.0, 1.0)


def predict_consensus_signed_voting(ensemble_models, X_test, threshold=0.5):
    """Ensemble prediction via signed voting (+1 / -1 per model).

    Returns (consensus_pred, avg_proba, avg_signed_vote, flat_individual_probas).
    `avg_proba` is the plain mean of per-model positive-class probabilities.
    """
    individual_probas = np.array(
        [m.predict_proba(X_test)[:, 1] for m in ensemble_models]
    )
    binary_preds = (individual_probas >= threshold).astype(int)
    signed_votes = np.where(binary_preds == 1, 1, -1)
    avg_signed_vote = np.mean(signed_votes, axis=0)
    consensus_pred = (avg_signed_vote > 0).astype(int)
    avg_proba = np.mean(individual_probas, axis=0)
    return consensus_pred, avg_proba, avg_signed_vote, individual_probas.flatten()


def predict_consensus_majority(ensemble_models, X_test, threshold=0.5):
    """Ensemble prediction via simple probability averaging.

    `threshold` is accepted for signature parity with the signed-voting
    variant but does not affect the averaged probabilities.
    Returns (avg_proba, flat_individual_probas).
    """
    individual_probas = np.array(
        [m.predict_proba(X_test)[:, 1] for m in ensemble_models]
    )
    avg_proba = np.mean(individual_probas, axis=0)
    return avg_proba, individual_probas.flatten()


def bootstrap_ci_from_oof(
    point_estimate: float,
    oof_probs: np.ndarray,
    n_boot: int = DEFAULT_N_BOOT_CI,
    confidence: float = 0.95,
    random_state: int = 42,
) -> tuple:
    """Percentile bootstrap CI around `point_estimate` using OOF probabilities.

    Bootstrap means of the OOF distribution are shifted so they are centered
    on the point estimate, then the (alpha/2, 1-alpha/2) percentiles are
    taken and clipped to [0, 1].  Degenerates to (point, point) when no OOF
    data is available.
    """
    if oof_probs is None or len(oof_probs) == 0:
        return float(point_estimate), float(point_estimate)
    oof_probs = np.asarray(oof_probs, dtype=float)
    rng = np.random.RandomState(random_state)
    grand_mean = np.mean(oof_probs)
    n = len(oof_probs)
    boot_means = np.array([
        np.mean(rng.choice(oof_probs, size=n, replace=True)) for _ in range(n_boot)
    ])
    # Re-center the bootstrap distribution on this patient's point estimate.
    shift = point_estimate - grand_mean
    boot_means = boot_means + shift
    alpha = 1.0 - confidence
    lo = float(np.clip(np.percentile(boot_means, 100 * alpha / 2), 0.0, 1.0))
    hi = float(np.clip(np.percentile(boot_means, 100 * (1 - alpha / 2)), 0.0, 1.0))
    return lo, hi


def _calibrate_point(outcome: str, raw_prob: float, use_calibration: bool) -> float:
    """Apply beta (undersampling) correction, then optional isotonic calibration."""
    beta = betas[outcome]
    p_beta = float(calibrate_probabilities_undersampling([raw_prob], beta)[0])
    if not use_calibration:
        return p_beta
    cal = calibrators.get(outcome)
    if cal is None:
        return p_beta
    return float(cal.transform([p_beta])[0])


def predict_all_outcomes(
    user_inputs,
    use_calibration: bool = True,
    use_signed_voting: bool = True,
    n_boot_ci: int = DEFAULT_N_BOOT_CI,
):
    """Predict calibrated probabilities + CIs for all modeled/derived outcomes.

    Parameters
    ----------
    user_inputs : dict or sequence
        One patient's features (dict keyed by FEATURE_NAMES, or an ordered
        sequence matching FEATURE_NAMES).
    use_calibration : bool
        Apply the stored isotonic calibrator after beta correction.
    use_signed_voting : bool
        Use the signed-voting ensemble combiner (probability output is the
        same mean either way).
    n_boot_ci : int
        Bootstrap resamples for the confidence intervals.

    Returns
    -------
    (probs, intervals) : two dicts keyed by outcome.  OS is derived as
    1 - P(DEAD); EFS as clip(1 - P(DWOGF) - P(GF), 0, 1).
    """
    if isinstance(user_inputs, dict):
        input_df = pd.DataFrame([user_inputs])
    else:
        input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
    input_df = input_df[FEATURE_NAMES]
    X = preprocessor.transform(input_df)

    probs, intervals = {}, {}
    for o in CLASSIFICATION_OUTCOMES:
        if o not in classification_models:
            continue
        threshold = consensus_thresholds.get(o, CONSENSUS_THRESHOLD)
        if use_signed_voting:
            _, uncalib_arr, _, _ = predict_consensus_signed_voting(
                classification_models[o], X, threshold
            )
        else:
            uncalib_arr, _ = predict_consensus_majority(
                classification_models[o], X, threshold
            )
        raw_prob = float(uncalib_arr[0])
        event_prob = _calibrate_point(o, raw_prob, use_calibration)
        lo, hi = bootstrap_ci_from_oof(
            point_estimate=event_prob,
            oof_probs=oof_probs_calibrated.get(o),
            n_boot=n_boot_ci,
        )
        probs[o] = event_prob
        intervals[o] = (lo, hi)

    # Derived outcome: overall survival = 1 - P(death); CI bounds swap sides.
    if "DEAD" in probs:
        p_dead = probs["DEAD"]
        probs["OS"] = float(1.0 - p_dead)
        dead_lo, dead_hi = intervals["DEAD"]
        intervals["OS"] = (
            float(np.clip(1.0 - dead_hi, 0, 1)),
            float(np.clip(1.0 - dead_lo, 0, 1)),
        )

    # Derived outcome: EFS = 1 - P(death w/o GF) - P(GF), with a joint
    # bootstrap over both OOF distributions (shifted to this patient's point
    # estimates) for its CI.
    if "DWOGF" in probs and "GF" in probs:
        p_dwogf = probs["DWOGF"]
        p_gf = probs["GF"]
        probs["EFS"] = float(np.clip(1.0 - p_dwogf - p_gf, 0.0, 1.0))
        oof_dwogf = oof_probs_calibrated.get("DWOGF")
        oof_gf = oof_probs_calibrated.get("GF")
        if oof_dwogf is not None and oof_gf is not None:
            oof_dwogf = np.asarray(oof_dwogf, dtype=float)
            oof_gf = np.asarray(oof_gf, dtype=float)
            n_min = min(len(oof_dwogf), len(oof_gf))
            oof_dwogf = oof_dwogf[:n_min]
            oof_gf = oof_gf[:n_min]
            rng = np.random.RandomState(42)
            grand_dwogf = np.mean(oof_dwogf)
            grand_gf = np.mean(oof_gf)
            shift_dwogf = p_dwogf - grand_dwogf
            shift_gf = p_gf - grand_gf
            # FIX(review): previously hard-coded DEFAULT_N_BOOT_CI, ignoring
            # the caller-supplied n_boot_ci used for every other interval.
            efs_boot = np.array([
                np.clip(
                    1.0
                    - (np.mean(rng.choice(oof_dwogf, size=n_min, replace=True)) + shift_dwogf)
                    - (np.mean(rng.choice(oof_gf, size=n_min, replace=True)) + shift_gf),
                    0.0,
                    1.0,
                )
                for _ in range(n_boot_ci)
            ])
            efs_lo = float(np.percentile(efs_boot, 2.5))
            efs_hi = float(np.percentile(efs_boot, 97.5))
            intervals["EFS"] = (efs_lo, efs_hi)
        else:
            intervals["EFS"] = (probs["EFS"], probs["EFS"])

    return probs, intervals


def predict_with_comparison(user_inputs, n_boot_ci: int = DEFAULT_N_BOOT_CI):
    """Run predictions with and without isotonic calibration for comparison."""
    cal_probs, cal_intervals = predict_all_outcomes(user_inputs, True, True, n_boot_ci)
    uncal_probs, uncal_intervals = predict_all_outcomes(user_inputs, False, True, n_boot_ci)
    return (cal_probs, cal_intervals), (uncal_probs, uncal_intervals)


def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
    """Return per-model SHAP values (shape: n_models × n_processed_features).

    `invert=True` negates the values (used for derived outcomes OS/EFS which
    are complements of modeled event probabilities).  `user_inputs` is kept
    for signature compatibility but not used here.
    """
    all_model_shap_vals = []
    for rf_model in classification_models[model_outcome]:
        explainer = shap.TreeExplainer(
            rf_model, model_output="probability", data=shap_background
        )
        shap_vals = explainer.shap_values(X_proc)
        # Normalize across shap API variants: list-per-class, or 3-D array.
        if isinstance(shap_vals, list):
            shap_vals = shap_vals[1]
        elif shap_vals.ndim == 3 and shap_vals.shape[2] == 2:
            shap_vals = shap_vals[:, :, 1]
        sv = shap_vals[0]
        if invert:
            sv = -sv
        all_model_shap_vals.append(sv)
    return np.array(all_model_shap_vals)


def compute_shap_values_with_direction(user_inputs, outcome, max_display=10):
    """Aggregate per-model SHAP values to original features, with bootstrap CIs.

    One-hot columns are summed back into their source categorical feature;
    a bootstrap over the ensemble's models gives a CI per feature.  Returns
    (feature display names, mean SHAP, CI low, CI high), each ordered
    least-to-most important among the top `max_display` features (for
    horizontal bar plotting).
    """
    if isinstance(user_inputs, dict):
        input_df = pd.DataFrame([user_inputs])
    else:
        input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
    X_proc = preprocessor.transform(input_df)

    # Map every processed column back to its original feature name.
    processed_to_orig = {f: f for f in NUM_COLUMNS}
    for pf in ohe_feature_names:
        processed_to_orig[pf] = pf.split("_", 1)[0]

    if outcome == "OS":
        raw_shap = _get_shap_values_for_model_outcome(
            user_inputs, "DEAD", invert=True, X_proc=X_proc
        )
    elif outcome == "EFS":
        shap_dwogf = _get_shap_values_for_model_outcome(
            user_inputs, "DWOGF", invert=True, X_proc=X_proc
        )
        shap_gf = _get_shap_values_for_model_outcome(
            user_inputs, "GF", invert=True, X_proc=X_proc
        )
        raw_shap = np.concatenate([shap_dwogf, shap_gf], axis=0)
    else:
        raw_shap = _get_shap_values_for_model_outcome(
            user_inputs, outcome, invert=False, X_proc=X_proc
        )

    unique_orig_features = list(dict.fromkeys(processed_to_orig.values()))
    n_models = len(raw_shap)
    model_shap_by_orig = np.zeros((n_models, len(unique_orig_features)))
    for model_idx in range(n_models):
        agg_by_orig = {}
        for i, pf in enumerate(processed_feature_names):
            orig = processed_to_orig[pf]
            agg_by_orig.setdefault(orig, 0.0)
            agg_by_orig[orig] += raw_shap[model_idx, i]
        for feat_idx, feat_name in enumerate(unique_orig_features):
            model_shap_by_orig[model_idx, feat_idx] = agg_by_orig.get(feat_name, 0.0)

    mean_shap_vals = np.mean(model_shap_by_orig, axis=0)
    rng = np.random.RandomState(42)
    bootstrap_shap_means = np.array([
        np.mean(model_shap_by_orig[rng.choice(n_models, size=n_models, replace=True)], axis=0)
        for _ in range(DEFAULT_N_BOOT_CI)
    ])
    shap_ci_low = np.percentile(bootstrap_shap_means, 2.5, axis=0)
    shap_ci_high = np.percentile(bootstrap_shap_means, 97.5, axis=0)

    order = np.argsort(-np.abs(mean_shap_vals))
    top_feat_names = []
    for i in order[:max_display]:
        feat_name = unique_orig_features[i]
        # FIX(review): `feat_name in user_inputs` on a *list* of values would
        # be a value-membership test and `user_inputs[feat_name]` would raise
        # TypeError; only annotate values for dict inputs.
        if isinstance(user_inputs, dict) and feat_name in user_inputs:
            val = user_inputs[feat_name]
            if isinstance(val, float) and val != int(val):
                display_name = f"{feat_name} = {val:.2f}"
            elif isinstance(val, (int, float)):
                display_name = f"{feat_name} = {int(val)}"
            else:
                val_str = str(val)
                if len(val_str) > 20:
                    val_str = val_str[:17] + "..."
                display_name = f"{feat_name} = {val_str}"
        else:
            display_name = feat_name
        top_feat_names.append(display_name)

    # Reverse so the most important feature ends up at the top of the h-bar plot.
    top_feat_names = top_feat_names[::-1]
    top_shap_vals = mean_shap_vals[order][:max_display][::-1]
    top_ci_low = shap_ci_low[order][:max_display][::-1]
    top_ci_high = shap_ci_high[order][:max_display][::-1]
    return top_feat_names, top_shap_vals, top_ci_low, top_ci_high


def create_shap_plot(user_inputs, outcome, max_display=10):
    """Horizontal SHAP bar chart (blue = increases, red = decreases) with CIs."""
    feat_names, shap_vals, ci_low, ci_high = compute_shap_values_with_direction(
        user_inputs, outcome, max_display
    )
    colors = ["blue" if v >= 0 else "red" for v in shap_vals]
    error_minus = shap_vals - ci_low
    error_plus = ci_high - shap_vals
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=feat_names,
        x=shap_vals,
        orientation="h",
        marker=dict(color=colors),
        showlegend=False,
        error_x=dict(
            type="data",
            symmetric=False,
            array=error_plus,
            arrayminus=error_minus,
            color="gray",
            thickness=1.5,
            width=4,
        ),
    ))
    fig.add_vline(x=0, line_width=1, line_color="black")
    fig.update_layout(
        title=dict(
            text=OUTCOME_DESCRIPTIONS.get(outcome, outcome),
            x=0.5,
            xanchor="center",
            font=dict(size=14, color="black"),
        ),
        xaxis_title="SHAP value",
        yaxis_title="",
        height=400,
        margin=dict(l=120, r=60, t=50, b=50),
        plot_bgcolor="white",
        paper_bgcolor="white",
        xaxis=dict(showgrid=True, gridcolor="lightgray", zeroline=True,
                   zerolinecolor="black", zerolinewidth=1),
        yaxis=dict(showgrid=False),
    )
    return fig


def create_all_shap_plots(user_inputs, max_display=10):
    """One SHAP figure per outcome in SHAP_OUTCOMES."""
    return {o: create_shap_plot(user_inputs, o, max_display) for o in SHAP_OUTCOMES}


EVENT_COLOR = "#e53935"     # red — patient experiences the event
NO_EVENT_COLOR = "#43a047"  # green — patient does not

# FIX(review): "TDeath" was a typo for "Death".
OUTCOME_TITLES = {
    "DEAD": "Death",
    "GF": "Graft Failure",
    "AGVHD": "Acute GvHD",
    "CGVHD": "Chronic GvHD",
    "VOCPSHI": "Vaso-Occlusive Crisis",
    "STROKEHI": "Stroke Post-HCT",
}
OUTCOME_LABELS = {
    "DEAD": ("Death", "No Death"),
    "GF": ("Graft Failure", "No Graft Failure"),
    "AGVHD": ("Acute GVHD", "No Acute GVHD"),
    "CGVHD": ("Chronic GVHD", "No Chronic GVHD"),
    "VOCPSHI": ("VOC", "No VOC"),
    "STROKEHI": ("Stroke", "No Stroke"),
}


def _stick_figure_svg(color: str, size: int = 16) -> str:
    """Inline SVG stick figure. ViewBox 0 0 20 32 (portrait).

    NOTE(review): the original SVG markup was stripped from the source;
    this is a minimal reconstruction (head + torso + arms + legs) keeping
    the documented viewBox and the size/height relationship.
    """
    h = round(size * 1.6)
    return (
        f'<svg width="{size}" height="{h}" viewBox="0 0 20 32" '
        f'xmlns="http://www.w3.org/2000/svg">'
        f'<circle cx="10" cy="5" r="4" fill="{color}"/>'
        f'<line x1="10" y1="9" x2="10" y2="20" stroke="{color}" stroke-width="3"/>'
        f'<line x1="10" y1="12" x2="3" y2="18" stroke="{color}" stroke-width="2.5"/>'
        f'<line x1="10" y1="12" x2="17" y2="18" stroke="{color}" stroke-width="2.5"/>'
        f'<line x1="10" y1="20" x2="4" y2="30" stroke="{color}" stroke-width="3"/>'
        f'<line x1="10" y1="20" x2="16" y2="30" stroke="{color}" stroke-width="3"/>'
        f'</svg>'
    )


def create_icon_array_html(probability: float, outcome: str) -> str:
    """
    Single outcome card: 10×10 stick-figure grid, red=event, green=no event.
    All cards have identical fixed-height sections so they align in the grid.
    Legend always shows 'Event (N/100)' and 'No Event (N/100)' — never varies.

    NOTE(review): the surrounding HTML was stripped from the source; the
    markup here is reconstructed to match the documented layout (title,
    percentage, grid, two-row legend) — verify against the original app.
    """
    title = OUTCOME_TITLES.get(outcome, OUTCOME_DESCRIPTIONS.get(outcome, outcome))
    event_label, no_event_label = OUTCOME_LABELS.get(outcome, ("Event", "No Event"))
    n_event = round(probability * 100)
    n_no_event = 100 - n_event
    pct_str = f"{probability * 100:.1f}%"

    # 10 rows × 10 figures; the first n_event figures (reading order) are red.
    rows_parts = []
    for row in range(10):
        cells = ""
        for col in range(10):
            idx = row * 10 + col
            color = EVENT_COLOR if idx < n_event else NO_EVENT_COLOR
            cells += _stick_figure_svg(color, size=16)
        rows_parts.append(
            f'<div style="display:flex; justify-content:center; gap:1px;">'
            f'{cells}</div>'
        )
    grid_html = "\n".join(rows_parts)

    # --- legend: fixed two-line block, identical text for every card ---
    fig_event = _stick_figure_svg(EVENT_COLOR, size=13)
    fig_no_event = _stick_figure_svg(NO_EVENT_COLOR, size=13)
    legend_html = (
        f'<div style="font-size:12px; line-height:1.7; text-align:left; '
        f'display:inline-block;">'
        f'<span>{fig_event}</span> '
        f'<span>{event_label}</span> '
        f'<span>({n_event}/100)</span><br/>'
        f'<span>{fig_no_event}</span> '
        f'<span>{no_event_label}</span> '
        f'<span>({n_no_event}/100)</span>'
        f'</div>'
    )

    return (
        f'<div style="border:1px solid #ddd; border-radius:8px; padding:10px; '
        f'background:#fff; text-align:center;">'
        # title — fixed height, 2-line max via min-height
        f'<div style="min-height:36px; font-weight:bold; font-size:14px;">'
        f'{title}</div>'
        # probability number
        f'<div style="height:26px; font-size:18px; font-weight:bold;">{pct_str}</div>'
        # icon grid
        f'<div>{grid_html}</div>'
        # legend — always 2 fixed-height rows
        f'<div style="min-height:48px; margin-top:6px;">{legend_html}</div>'
        f'</div>'
    )


def create_all_icon_arrays(calibrated_probs: dict) -> dict:
    """
    Returns individual cards + a combined '__grid__' key with the 4×2 layout.
    All 6 cards are rendered at equal flex widths and equal internal heights.

    NOTE(review): container markup reconstructed (stripped in source) —
    flex rows of up to 4 cards plus a fixed footnote, per the comments.
    """
    pie_outcomes = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"]
    cards = {o: create_icon_array_html(calibrated_probs[o], o) for o in pie_outcomes}

    rows_html = ""
    for row_start in range(0, len(pie_outcomes), 4):
        row_outcomes = pie_outcomes[row_start: row_start + 4]
        cols = "".join(
            f'<div style="flex:1 1 0; padding:6px;">{cards[o]}</div>'
            for o in row_outcomes
        )
        rows_html += (
            f'<div style="display:flex; flex-direction:row; align-items:stretch;">'
            f'{cols}</div>'
        )

    footnote = (
        f'<div style="font-size:12px; color:#555; margin-top:8px; '
        f'text-align:center;">'
        f'Each figure = 1 patient out of 100 with similar characteristics.&nbsp;'
        f'<span style="color:{EVENT_COLOR};">■ Red = Event</span>'
        f'&nbsp;'
        f'<span style="color:{NO_EVENT_COLOR};">■ Green = No Event</span>'
        f'</div>'
    )
    cards["__grid__"] = f'<div>{rows_html}{footnote}</div>'
    return cards


def create_pie_chart(probability, outcome):
    """Backward-compatible alias: icon arrays replaced the old pie charts."""
    return create_icon_array_html(probability, outcome)


def create_all_pie_charts(calibrated_probs):
    """Backward-compatible alias for create_all_icon_arrays."""
    return create_all_icon_arrays(calibrated_probs)