"""Inference module: loads skops-serialized ensemble models and exposes
prediction, SHAP-explanation, and visualization helpers for HCT outcomes."""

import numpy as np
import pandas as pd
import skops.io as sio
import shap
import plotly.graph_objects as go
import os
import sys
import warnings

# sklearn emits version-mismatch UserWarnings when unpickling models
# trained under a different release; silence only that module's warnings.
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")

print("===== Application Startup =====")
print(f"Working directory: {os.getcwd()}")
print(f"Files present: {os.listdir('.')}")

# ---------------------------------------------------------------------------
# Compatibility patch
# ---------------------------------------------------------------------------
# Some sklearn versions lack `_RemainderColsList`, which skops needs when
# deserializing a ColumnTransformer saved under a different sklearn version.
# Inject a minimal stand-in when it is missing.
import sklearn.compose._column_transformer as _ct

if not hasattr(_ct, "_RemainderColsList"):
    class _RemainderColsList(list):
        # Minimal shim: behaves as a plain list but carries the
        # `future_dtype` attribute the real class exposes.
        def __init__(self, lst=None, future_dtype=None):
            super().__init__(lst or [])
            self.future_dtype = future_dtype

    _ct._RemainderColsList = _RemainderColsList
    import sklearn.compose
    sklearn.compose._RemainderColsList = _RemainderColsList
    print("Patched _RemainderColsList into sklearn.compose")

# ---------------------------------------------------------------------------
# Column / feature definitions
# ---------------------------------------------------------------------------
# Numeric model inputs.
NUM_COLUMNS = ["AGE", "NACS2YR"]
# Categorical model inputs (one-hot encoded by the preprocessor).
# NOTE: several names contain underscores (e.g. CONDGRP_FINAL), which
# matters when mapping one-hot feature names back to original columns.
CATEG_COLUMNS = [
    "AGEGPFF",
    "SEX",
    "KPS",
    "DONORF",
    "GRAFTYPE",
    "CONDGRPF",
    "CONDGRP_FINAL",
    "ATGF",
    "GVHD_FINAL",
    "HLA_FINAL",
    "RCMVPR",
    "EXCHTFPR",
    "VOC2YPR",
    "VOCFRQPR",
    "SCATXRSN",
]
FEATURE_NAMES = NUM_COLUMNS + CATEG_COLUMNS

# Outcomes with one trained ensemble classifier each.
OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "DWOGF"]
CLASSIFICATION_OUTCOMES = OUTCOMES
# Outcomes surfaced in reports; OS and EFS are derived from the modeled
# outcomes rather than predicted directly.
REPORTING_OUTCOMES = [
    "OS",
    "EFS",
    "GF",
    "DEAD",
    "AGVHD",
    "CGVHD",
    "VOCPSHI",
    "STROKEHI",
]
OUTCOME_DESCRIPTIONS = {
    "OS": "Overall Survival",
    "EFS": "Event-Free Survival",
    "DEAD": "Total Mortality",
    "GF": "Graft Failure",
    "AGVHD": "Acute Graft-versus-Host Disease",
    "CGVHD": "Chronic Graft-versus-Host Disease",
    "VOCPSHI": "Vaso-Occlusive Crisis Post-HCT",
    "STROKEHI": "Stroke Post-HCT",
}
SHAP_OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "OS", "EFS"]

# Directory containing the *.skops model artifacts.
MODEL_DIR = "."
# Default per-model probability threshold used in consensus voting.
CONSENSUS_THRESHOLD = 0.5
# Default number of bootstrap resamples for confidence intervals.
DEFAULT_N_BOOT_CI = 500


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def _load_skops_model(fname):
    """Load a skops artifact from `fname`, trusting whatever untrusted types
    it declares.

    Raises RuntimeError if the file is missing or fails to deserialize.
    NOTE(review): trusting `get_untrusted_types` wholesale is only safe for
    artifacts from a trusted source — these model files ship with the app.
    """
    if not os.path.exists(fname):
        raise RuntimeError(f"Model file not found: {fname}")
    try:
        untrusted = sio.get_untrusted_types(file=fname)
        model = sio.load(fname, trusted=untrusted)
        print(f" Loaded: {fname}")
        return model
    except Exception as e:
        raise RuntimeError(f"Failed to load '{fname}': {type(e).__name__}: {e}") from e


print("Loading preprocessor...")
preprocessor = _load_skops_model(os.path.join(MODEL_DIR, "preprocessor.skops"))

print("Loading ensemble models...")
# Each artifact is a dict bundling the ensemble plus calibration metadata
# (keys observed below: "models", "beta", "prior", and optional
# "consensus_threshold", "calibrator", "calibrator_type",
# "isotonic_calibrator", "oof_probs_calibrated").
classification_model_data = {}
for _o in CLASSIFICATION_OUTCOMES:
    _path = os.path.join(MODEL_DIR, f"ensemble_model_{_o}.skops")
    if os.path.exists(_path):
        classification_model_data[_o] = _load_skops_model(_path)
    else:
        print(f" Warning: Model for {_o} not found at {_path}. Skipping.")

classification_models = {o: d["models"] for o, d in classification_model_data.items()}
# `beta` is the undersampling correction factor per outcome.
betas = {o: d["beta"] for o, d in classification_model_data.items()}
priors = {o: d["prior"] for o, d in classification_model_data.items()}
consensus_thresholds = {
    o: d.get("consensus_threshold", CONSENSUS_THRESHOLD)
    for o, d in classification_model_data.items()
}

# Pick a probability calibrator per outcome.  Preference order: a
# "calibrator" entry (only if isotonic or untyped), else a legacy
# "isotonic_calibrator" entry, else None (no calibration).
calibrators = {}
for _o, _d in classification_model_data.items():
    _cal = None
    _cal_type = _d.get("calibrator_type", None)
    if "calibrator" in _d and _d["calibrator"] is not None:
        if _cal_type is None or _cal_type == "isotonic":
            _cal = _d["calibrator"]
        else:
            print(f" Warning: outcome '{_o}' has calibrator_type='{_cal_type}'. Skipping.")
    elif "isotonic_calibrator" in _d and _d["isotonic_calibrator"] is not None:
        _cal = _d["isotonic_calibrator"]
    calibrators[_o] = _cal

# Backward-compatible alias used elsewhere in the project.
isotonic_calibrators = calibrators

# Out-of-fold calibrated probabilities, used to bootstrap CIs.
oof_probs_calibrated = {
    o: d.get("oof_probs_calibrated") for o, d in classification_model_data.items()
}

# Names of the transformed feature space: numeric columns first, then the
# one-hot-encoded categorical columns (order matches preprocessor output).
ohe = preprocessor.named_transformers_["cat"]
ohe_feature_names = ohe.get_feature_names_out(CATEG_COLUMNS)
processed_feature_names = np.concatenate([NUM_COLUMNS, ohe_feature_names])
print(f"Models loaded: {list(classification_models.keys())}")

# ---------------------------------------------------------------------------
# SHAP background data
# ---------------------------------------------------------------------------
# A synthetic background sample (not real patients) drawn from plausible
# category levels; only used as the SHAP masker's reference distribution.
print("Building SHAP background...")
np.random.seed(23)
_n_background = 500
_background_data = {
    "AGE": np.random.uniform(5, 50, _n_background),
    "NACS2YR": np.random.randint(0, 5, _n_background),
    "AGEGPFF": np.random.choice(["<=10", "11-17", "18-29", "30-49", ">=50"], _n_background),
    "SEX": np.random.choice(["Male", "Female"], _n_background),
    "KPS": np.random.choice(["<90", "≥ 90"], _n_background),
    "DONORF": np.random.choice([
        "HLA identical sibling",
        "HLA mismatch relative",
        "Matched unrelated donor",
        "Mismatched unrelated donor or cord blood",
    ], _n_background),
    "GRAFTYPE": np.random.choice(["Bone marrow", "Peripheral blood", "Cord blood"], _n_background),
    "CONDGRPF": np.random.choice(["MAC", "RIC", "NMA"], _n_background),
    "CONDGRP_FINAL": np.random.choice(["TBI/Cy", "Bu/Cy", "Flu/Bu", "Flu/Mel"], _n_background),
    "ATGF": np.random.choice(["ATG", "Alemtuzumab", "None"], _n_background),
    "GVHD_FINAL": np.random.choice(["CNI + MMF", "CNI + MTX", "Post-CY + siro +- MMF"], _n_background),
    "HLA_FINAL": np.random.choice(["8/8", "7/8", "≤ 6/8"], _n_background),
    "RCMVPR": np.random.choice(["Negative", "Positive"], _n_background),
    "EXCHTFPR": np.random.choice(["No", "Yes"], _n_background),
    "VOC2YPR": np.random.choice(["No", "Yes"], _n_background),
    "VOCFRQPR": np.random.choice(["< 3/yr", "≥ 3/yr"], _n_background),
    "SCATXRSN": np.random.choice([
        "CNS event",
        "Acute chest Syndrome",
        "Recurrent vaso-occlusive pain",
        "Recurrent priapism",
        "Excessive transfusion requirements/iron overload",
        "Cardio-pulmonary",
        "Chronic transfusion",
        "Asymptomatic",
        "Renal insufficiency",
        "Splenic sequestration",
        "Avascular necrosis",
        "Hodgkin lymphoma",
    ], _n_background),
}
_background_df = pd.DataFrame(_background_data)[FEATURE_NAMES]
_X_background = preprocessor.transform(_background_df)
shap_background = shap.maskers.Independent(_X_background)
print("SHAP background ready.")


# ---------------------------------------------------------------------------
# Calibration helpers
# ---------------------------------------------------------------------------
def calibrate_probabilities_undersampling(p_s, beta):
    """Correct probabilities from a model trained on undersampled data.

    Applies p = beta*p_s / ((beta-1)*p_s + 1); the denominator is floored
    at 1e-10 to avoid division by zero, and the result clipped to [0, 1].
    Returns an ndarray the same shape as `p_s`.
    """
    p_s = np.asarray(p_s, dtype=float)
    numerator = beta * p_s
    denominator = np.maximum((beta - 1.0) * p_s + 1.0, 1e-10)
    return np.clip(numerator / denominator, 0.0, 1.0)


def predict_consensus_signed_voting(ensemble_models, X_test, threshold=0.5):
    """Ensemble prediction via signed voting.

    Each model votes +1 (proba >= threshold) or -1; the consensus class is
    positive when the mean vote is > 0.

    Returns (consensus_pred, avg_proba, avg_signed_vote, flat_probas) where
    flat_probas is every model's positive-class probability, flattened.
    """
    individual_probas = np.array(
        [m.predict_proba(X_test)[:, 1] for m in ensemble_models]
    )
    binary_preds = (individual_probas >= threshold).astype(int)
    signed_votes = np.where(binary_preds == 1, 1, -1)
    avg_signed_vote = np.mean(signed_votes, axis=0)
    consensus_pred = (avg_signed_vote > 0).astype(int)
    avg_proba = np.mean(individual_probas, axis=0)
    return consensus_pred, avg_proba, avg_signed_vote, individual_probas.flatten()


def predict_consensus_majority(ensemble_models, X_test, threshold=0.5):
    """Return (avg_proba, flat_probas) across the ensemble.

    NOTE(review): despite the name, no majority vote is computed — the
    `threshold` argument is currently unused; callers only consume the
    averaged probability.
    """
    individual_probas = np.array(
        [m.predict_proba(X_test)[:, 1] for m in ensemble_models]
    )
    avg_proba = np.mean(individual_probas, axis=0)
    return avg_proba, individual_probas.flatten()


# ---------------------------------------------------------------------------
# Bootstrap CI
# ---------------------------------------------------------------------------
def bootstrap_ci_from_oof(
    point_estimate: float,
    oof_probs: np.ndarray,
    n_boot: int = DEFAULT_N_BOOT_CI,
    confidence: float = 0.95,
    random_state: int = 42,
) -> tuple:
    """Bootstrap a CI for `point_estimate` from out-of-fold probabilities.

    Bootstrap means of `oof_probs` are shifted so their center sits at the
    point estimate, then percentile bounds (clipped to [0, 1]) are taken.
    Degenerates to (point_estimate, point_estimate) when no OOF data.
    """
    if oof_probs is None or len(oof_probs) == 0:
        return float(point_estimate), float(point_estimate)
    oof_probs = np.asarray(oof_probs, dtype=float)
    rng = np.random.RandomState(random_state)
    grand_mean = np.mean(oof_probs)
    n = len(oof_probs)
    boot_means = np.array([
        np.mean(rng.choice(oof_probs, size=n, replace=True))
        for _ in range(n_boot)
    ])
    # Re-center the bootstrap distribution on the patient-specific estimate.
    shift = point_estimate - grand_mean
    boot_means = boot_means + shift
    alpha = 1.0 - confidence
    lo = float(np.clip(np.percentile(boot_means, 100 * alpha / 2), 0.0, 1.0))
    hi = float(np.clip(np.percentile(boot_means, 100 * (1 - alpha / 2)), 0.0, 1.0))
    return lo, hi


# ---------------------------------------------------------------------------
# Calibration dispatch
# ---------------------------------------------------------------------------
def _calibrate_point(outcome: str, raw_prob: float, use_calibration: bool) -> float:
    """Apply undersampling (beta) correction, then the per-outcome isotonic
    calibrator when requested and available."""
    beta = betas[outcome]
    p_beta = float(calibrate_probabilities_undersampling([raw_prob], beta)[0])
    if not use_calibration:
        return p_beta
    cal = calibrators.get(outcome)
    if cal is None:
        return p_beta
    return float(cal.transform([p_beta])[0])


# ---------------------------------------------------------------------------
# Main prediction functions
# ---------------------------------------------------------------------------
def predict_all_outcomes(
    user_inputs,
    use_calibration: bool = True,
    use_signed_voting: bool = True,
    n_boot_ci: int = DEFAULT_N_BOOT_CI,
):
    """Predict every modeled outcome for one patient.

    `user_inputs` is either a dict keyed by FEATURE_NAMES or a sequence in
    FEATURE_NAMES order.  Returns (probs, intervals): outcome -> probability
    and outcome -> (lo, hi) CI.  OS and EFS are derived:
      OS  = 1 - P(DEAD)
      EFS = 1 - P(DWOGF) - P(GF), clipped to [0, 1]
    """
    if isinstance(user_inputs, dict):
        input_df = pd.DataFrame([user_inputs])
    else:
        input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
    input_df = input_df[FEATURE_NAMES]
    X = preprocessor.transform(input_df)
    probs, intervals = {}, {}
    for o in CLASSIFICATION_OUTCOMES:
        if o not in classification_models:
            continue
        threshold = consensus_thresholds.get(o, CONSENSUS_THRESHOLD)
        if use_signed_voting:
            _, uncalib_arr, _, _ = predict_consensus_signed_voting(
                classification_models[o], X, threshold
            )
        else:
            uncalib_arr, _ = predict_consensus_majority(
                classification_models[o], X, threshold
            )
        raw_prob = float(uncalib_arr[0])
        event_prob = _calibrate_point(o, raw_prob, use_calibration)
        lo, hi = bootstrap_ci_from_oof(
            point_estimate=event_prob,
            oof_probs=oof_probs_calibrated.get(o),
            n_boot=n_boot_ci,
        )
        probs[o] = event_prob
        intervals[o] = (lo, hi)

    # OS = 1 - P(DEAD); its CI is the DEAD CI mirrored around 1.
    if "DEAD" in probs:
        p_dead = probs["DEAD"]
        probs["OS"] = float(1.0 - p_dead)
        dead_lo, dead_hi = intervals["DEAD"]
        intervals["OS"] = (
            float(np.clip(1.0 - dead_hi, 0, 1)),
            float(np.clip(1.0 - dead_lo, 0, 1)),
        )

    # EFS = 1 - P(DWOGF) - P(GF); CI bootstrapped jointly from both OOF sets.
    if "DWOGF" in probs and "GF" in probs:
        p_dwogf = probs["DWOGF"]
        p_gf = probs["GF"]
        probs["EFS"] = float(np.clip(1.0 - p_dwogf - p_gf, 0.0, 1.0))
        oof_dwogf = oof_probs_calibrated.get("DWOGF")
        oof_gf = oof_probs_calibrated.get("GF")
        if oof_dwogf is not None and oof_gf is not None:
            oof_dwogf = np.asarray(oof_dwogf, dtype=float)
            oof_gf = np.asarray(oof_gf, dtype=float)
            # Truncate to a common length so the resamples are comparable.
            n_min = min(len(oof_dwogf), len(oof_gf))
            oof_dwogf = oof_dwogf[:n_min]
            oof_gf = oof_gf[:n_min]
            rng = np.random.RandomState(42)
            grand_dwogf = np.mean(oof_dwogf)
            grand_gf = np.mean(oof_gf)
            # Re-center each bootstrap mean on this patient's estimate.
            shift_dwogf = p_dwogf - grand_dwogf
            shift_gf = p_gf - grand_gf
            efs_boot = np.array([
                np.clip(
                    1.0
                    - (np.mean(rng.choice(oof_dwogf, size=n_min, replace=True)) + shift_dwogf)
                    - (np.mean(rng.choice(oof_gf, size=n_min, replace=True)) + shift_gf),
                    0.0,
                    1.0,
                )
                for _ in range(n_boot_ci)
            ])
            # NOTE(review): percentiles hard-coded at 2.5/97.5 (95% CI) here,
            # unlike bootstrap_ci_from_oof which takes a confidence argument.
            intervals["EFS"] = (
                float(np.percentile(efs_boot, 2.5)),
                float(np.percentile(efs_boot, 97.5)),
            )
        else:
            intervals["EFS"] = (probs["EFS"], probs["EFS"])
    return probs, intervals


def predict_with_comparison(user_inputs, n_boot_ci: int = DEFAULT_N_BOOT_CI):
    """Run predictions twice — with and without isotonic calibration — and
    return ((cal_probs, cal_intervals), (uncal_probs, uncal_intervals))."""
    cal_probs, cal_intervals = predict_all_outcomes(user_inputs, True, True, n_boot_ci)
    uncal_probs, uncal_intervals = predict_all_outcomes(user_inputs, False, True, n_boot_ci)
    return (cal_probs, cal_intervals), (uncal_probs, uncal_intervals)
# ---------------------------------------------------------------------------
# SHAP helpers
# ---------------------------------------------------------------------------
def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
    """Return per-model SHAP values for one outcome's ensemble.

    Output shape: (n_models, n_processed_features) for the single patient
    row in `X_proc`.  When `invert` is True the signs are flipped (used for
    derived "good" outcomes such as OS/EFS).  `user_inputs` is unused and
    kept only for signature compatibility with existing callers.
    """
    all_model_shap_vals = []
    for rf_model in classification_models[model_outcome]:
        explainer = shap.TreeExplainer(
            rf_model, model_output="probability", data=shap_background
        )
        shap_vals = explainer.shap_values(X_proc)
        # Normalize across shap versions: list-of-classes vs 3-D array.
        if isinstance(shap_vals, list):
            shap_vals = shap_vals[1]
        elif shap_vals.ndim == 3 and shap_vals.shape[2] == 2:
            shap_vals = shap_vals[:, :, 1]
        sv = shap_vals[0]
        if invert:
            sv = -sv
        all_model_shap_vals.append(sv)
    return np.array(all_model_shap_vals)


def _map_processed_to_original():
    """Map each processed feature name back to its original input column.

    Numeric columns map to themselves.  One-hot names are matched against
    CATEG_COLUMNS by longest prefix, which is required because some column
    names themselves contain underscores (CONDGRP_FINAL, GVHD_FINAL,
    HLA_FINAL): a naive split("_", 1)[0] would mis-map e.g.
    "CONDGRP_FINAL_TBI/Cy" to the nonexistent feature "CONDGRP".
    """
    processed_to_orig = {f: f for f in NUM_COLUMNS}
    categ_by_len = sorted(CATEG_COLUMNS, key=len, reverse=True)
    for pf in ohe_feature_names:
        orig = next((c for c in categ_by_len if pf.startswith(c + "_")), None)
        # Fallback preserves the old behavior for any unexpected name.
        processed_to_orig[pf] = orig if orig is not None else pf.split("_", 1)[0]
    return processed_to_orig


def compute_shap_values_with_direction(user_inputs, outcome, max_display=10):
    """Compute aggregated SHAP values (with bootstrap CIs) for one patient.

    One-hot contributions are summed back to their original input column.
    OS uses inverted DEAD SHAP values; EFS pools inverted DWOGF and GF.
    Returns (feature_labels, shap_values, ci_low, ci_high), each ordered by
    descending |SHAP| then reversed for bottom-up horizontal-bar plotting.
    """
    if isinstance(user_inputs, dict):
        input_df = pd.DataFrame([user_inputs])
    else:
        input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
    X_proc = preprocessor.transform(input_df)

    processed_to_orig = _map_processed_to_original()

    if outcome == "OS":
        raw_shap = _get_shap_values_for_model_outcome(
            user_inputs, "DEAD", invert=True, X_proc=X_proc
        )
    elif outcome == "EFS":
        shap_dwogf = _get_shap_values_for_model_outcome(
            user_inputs, "DWOGF", invert=True, X_proc=X_proc
        )
        shap_gf = _get_shap_values_for_model_outcome(
            user_inputs, "GF", invert=True, X_proc=X_proc
        )
        raw_shap = np.concatenate([shap_dwogf, shap_gf], axis=0)
    else:
        raw_shap = _get_shap_values_for_model_outcome(
            user_inputs, outcome, invert=False, X_proc=X_proc
        )

    unique_orig_features = list(dict.fromkeys(processed_to_orig.values()))
    n_models = len(raw_shap)
    model_shap_by_orig = np.zeros((n_models, len(unique_orig_features)))
    for model_idx in range(n_models):
        # Sum the one-hot columns' contributions per original feature.
        agg_by_orig = {}
        for i, pf in enumerate(processed_feature_names):
            orig = processed_to_orig[pf]
            agg_by_orig.setdefault(orig, 0.0)
            agg_by_orig[orig] += raw_shap[model_idx, i]
        for feat_idx, feat_name in enumerate(unique_orig_features):
            model_shap_by_orig[model_idx, feat_idx] = agg_by_orig.get(feat_name, 0.0)

    mean_shap_vals = np.mean(model_shap_by_orig, axis=0)

    # Bootstrap over ensemble members to get a CI per feature.
    rng = np.random.RandomState(42)
    bootstrap_shap_means = np.array([
        np.mean(
            model_shap_by_orig[rng.choice(n_models, size=n_models, replace=True)],
            axis=0,
        )
        for _ in range(DEFAULT_N_BOOT_CI)
    ])
    shap_ci_low = np.percentile(bootstrap_shap_means, 2.5, axis=0)
    shap_ci_high = np.percentile(bootstrap_shap_means, 97.5, axis=0)

    order = np.argsort(-np.abs(mean_shap_vals))
    top_feat_names = []
    for i in order[:max_display]:
        feat_name = unique_orig_features[i]
        # Only a dict input supports value lookup for the label.
        if feat_name in user_inputs:
            val = user_inputs[feat_name]
            if isinstance(val, float) and val != int(val):
                display_name = f"{feat_name} = {val:.2f}"
            elif isinstance(val, (int, float)):
                display_name = f"{feat_name} = {int(val)}"
            else:
                val_str = str(val)
                if len(val_str) > 20:
                    val_str = val_str[:17] + "..."
                display_name = f"{feat_name} = {val_str}"
        else:
            display_name = feat_name
        top_feat_names.append(display_name)

    # Reverse so the largest |SHAP| ends up at the top of a horizontal bar.
    top_feat_names = top_feat_names[::-1]
    top_shap_vals = mean_shap_vals[order][:max_display][::-1]
    top_ci_low = shap_ci_low[order][:max_display][::-1]
    top_ci_high = shap_ci_high[order][:max_display][::-1]
    return top_feat_names, top_shap_vals, top_ci_low, top_ci_high


def create_shap_plot(user_inputs, outcome, max_display=10):
    """Build a horizontal SHAP bar chart (blue = positive, red = negative)
    with asymmetric bootstrap error bars for one outcome."""
    feat_names, shap_vals, ci_low, ci_high = compute_shap_values_with_direction(
        user_inputs, outcome, max_display
    )
    colors = ["blue" if v >= 0 else "red" for v in shap_vals]
    error_minus = shap_vals - ci_low
    error_plus = ci_high - shap_vals
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=feat_names,
        x=shap_vals,
        orientation="h",
        marker=dict(color=colors),
        showlegend=False,
        error_x=dict(
            type="data",
            symmetric=False,
            array=error_plus,
            arrayminus=error_minus,
            color="gray",
            thickness=1.5,
            width=4,
        ),
    ))
    fig.add_vline(x=0, line_width=1, line_color="black")
    fig.update_layout(
        title=dict(
            text=OUTCOME_DESCRIPTIONS.get(outcome, outcome),
            x=0.5,
            xanchor="center",
            font=dict(size=14, color="black"),
        ),
        xaxis_title="SHAP value",
        yaxis_title="",
        height=400,
        margin=dict(l=120, r=60, t=50, b=50),
        plot_bgcolor="white",
        paper_bgcolor="white",
        xaxis=dict(
            showgrid=True,
            gridcolor="lightgray",
            zeroline=True,
            zerolinecolor="black",
            zerolinewidth=1,
        ),
        yaxis=dict(showgrid=False),
    )
    return fig


def create_all_shap_plots(user_inputs, max_display=10):
    """Return {outcome: SHAP figure} for every outcome in SHAP_OUTCOMES."""
    return {o: create_shap_plot(user_inputs, o, max_display) for o in SHAP_OUTCOMES}


# ---------------------------------------------------------------------------
# Icon array
# ---------------------------------------------------------------------------
# Root cause of previous gaps / distortion:
#   Plotly shape coords are in DATA units.  If px-per-data-unit differs on
#   x vs y axes the circle head becomes an ellipse and spacing looks uneven.
# Fix:
#   - Use EQUAL axis spans on x and y (both = cols + 2*pad = 10.3)
#   - Set width and height so usable pixels are EQUAL on both axes
#     (usable_w = W - 20, usable_h = H - 105  =>  H = W + 85)
#   - 1 data unit then maps to the same pixel count on both axes, so the
#     circles are round and spacing is uniform.
# ---------------------------------------------------------------------------
def _stick_figure(cx, cy, color, s):
    """Return Plotly shape dicts for a stick figure centred at (cx, cy).

    s = scale (data units).  With a cell size of 1.0, s ~= 0.46 gives a
    figure filling ~75% of the cell vertically.

    Anatomy (offsets relative to cy):
        head centre : cy + s*0.55, radius s*0.18
        neck top    : cy + s*0.35
        hip         : cy - s*0.15
        arm branch  : cy + s*0.18
        foot        : cy - s*0.55
    """
    shapes = []
    lw = dict(color=color, width=1.8)  # fixed pixel width — looks consistent
    # head
    hr = s * 0.18
    hy = cy + s * 0.55
    shapes.append(dict(
        type="circle", xref="x", yref="y",
        x0=cx - hr, y0=hy - hr, x1=cx + hr, y1=hy + hr,
        fillcolor=color, line=dict(color=color, width=0),
    ))
    neck_y = cy + s * 0.35
    hip_y = cy - s * 0.15
    arm_y = cy + s * 0.18
    foot_y = cy - s * 0.55
    # spine
    shapes.append(dict(type="line", xref="x", yref="y",
                       x0=cx, y0=neck_y, x1=cx, y1=hip_y, line=lw))
    # arms
    adx = s * 0.32
    ady = s * 0.15
    shapes.append(dict(type="line", xref="x", yref="y",
                       x0=cx, y0=arm_y, x1=cx - adx, y1=arm_y - ady, line=lw))
    shapes.append(dict(type="line", xref="x", yref="y",
                       x0=cx, y0=arm_y, x1=cx + adx, y1=arm_y - ady, line=lw))
    # legs
    ldx = s * 0.26
    shapes.append(dict(type="line", xref="x", yref="y",
                       x0=cx, y0=hip_y, x1=cx - ldx, y1=foot_y, line=lw))
    shapes.append(dict(type="line", xref="x", yref="y",
                       x0=cx, y0=hip_y, x1=cx + ldx, y1=foot_y, line=lw))
    return shapes


def icon_array(probability, outcome):
    """Render a 10x10 icon array: `probability` of 100 stick figures are
    colored as events (red), the rest as no-event (teal)."""
    outcome_labels = {
        "DEAD": ("Death", "Overall Survival"),
        "GF": ("Graft Failure", "No Graft Failure"),
        "AGVHD": ("AGVHD", "No AGVHD"),
        "CGVHD": ("CGVHD", "No CGVHD"),
        "VOCPSHI": ("VOC Post-HCT", "No VOC Post-HCT"),
        "STROKEHI": ("Stroke Post-HCT", "No Stroke Post-HCT"),
    }
    event_label, no_event_label = outcome_labels.get(outcome, ("Event", "No Event"))
    # Clamp so slightly out-of-range probabilities cannot over/underflow
    # the 100-icon grid.
    n_event = int(np.clip(round(probability * 100), 0, 100))
    n_no_event = 100 - n_event
    cols, rows = 10, 10

    # Layout constants (see the block comment above for the derivation):
    # icons on an integer grid 0..9 x 0..9, padding 0.65 each side gives an
    # axis span of 10.30 on both axes; H = W + 85 keeps px/unit equal.
    PAD = 0.65
    W = 400
    H = W + 85          # = 485 -> usable 380 px on both axes
    S = 0.46            # figure scale (~75% vertical fill per cell)
    x_lo, x_hi = -PAD, (cols - 1) + PAD   # -0.65 .. 9.65, span 10.30
    y_lo, y_hi = -PAD, (rows - 1) + PAD   # -0.65 .. 9.65, span 10.30

    all_shapes = []
    icon_idx = 0
    for row in range(rows):          # row 0 -> top of grid
        for col in range(cols):      # col 0 -> left
            color = "#e05555" if icon_idx < n_event else "#3bbfad"
            cx = col
            cy = (rows - 1) - row    # invert: row 0 -> cy=9 (top)
            all_shapes.extend(_stick_figure(cx, cy, color, S))
            icon_idx += 1

    fig = go.Figure()
    fig.update_layout(
        title=dict(
            # NOTE(review): the source for this title was garbled; the
            # original likely carried HTML color spans around the legend
            # squares.  Reconstructed as a plain <br>-separated title.
            text=(
                f"{OUTCOME_DESCRIPTIONS.get(outcome, outcome)}<br>"
                f"■ {event_label}: {n_event}%"
                f"  "
                f"■ {no_event_label}: {n_no_event}%"
            ),
            x=0.5,
            xanchor="center",
            font=dict(size=14, color="black"),
        ),
        shapes=all_shapes,
        xaxis=dict(
            range=[x_lo, x_hi],
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            fixedrange=True,
        ),
        yaxis=dict(
            range=[y_lo, y_hi],
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            fixedrange=True,
            # scaleanchor / scaleratio intentionally OMITTED —
            # equal spans + equal usable pixels already guarantee
            # identical px/unit on both axes without distortion.
        ),
        width=W,
        height=H,
        margin=dict(l=10, r=10, t=95, b=10),
        plot_bgcolor="white",
        paper_bgcolor="white",
    )
    return fig


print("===== inference.py loaded successfully =====")