# SPRIGHT_With_Icon_array / inference.py
# Author: shivapriyasom
# Commit c744750 ("Update inference.py", verified)
import numpy as np
import pandas as pd
import skops.io as sio
import shap
import plotly.graph_objects as go
import os
import sys
import warnings
# Ignore UserWarnings raised from modules whose name matches "sklearn", then
# log where the app is running from and which files it can see.
warnings.filterwarnings(action="ignore", category=UserWarning, module="sklearn")
print("===== Application Startup =====")
print("Working directory: " + os.getcwd())
print("Files present: " + str(os.listdir(".")))
# ---------------------------------------------------------------------------
# Compatibility patch
# ---------------------------------------------------------------------------
# If the installed scikit-learn has no `_RemainderColsList` in
# sklearn.compose._column_transformer, define a minimal stand-in and register
# it both there and on `sklearn.compose`. This runs before any .skops file is
# loaded further down in this module.
# NOTE(review): presumably this lets a ColumnTransformer saved with a
# scikit-learn version that has this class be deserialised by one that does
# not — confirm against the training and serving versions.
import sklearn.compose._column_transformer as _ct
if not hasattr(_ct, "_RemainderColsList"):
    class _RemainderColsList(list):
        """Plain-list stand-in that also records ``future_dtype``.

        Only the constructor below is provided; no other behaviour of
        scikit-learn's own class is reproduced.
        """
        def __init__(self, lst=None, future_dtype=None):
            # A missing or falsy `lst` yields an empty list.
            super().__init__(lst or [])
            self.future_dtype = future_dtype
    _ct._RemainderColsList = _RemainderColsList
    import sklearn.compose
    sklearn.compose._RemainderColsList = _RemainderColsList
    print("Patched _RemainderColsList into sklearn.compose")
# ---------------------------------------------------------------------------
# Column / feature definitions
# ---------------------------------------------------------------------------
# Raw model inputs. FEATURE_NAMES (numeric first, then categorical) fixes the
# column order of the DataFrame handed to the preprocessor.
NUM_COLUMNS = ["AGE", "NACS2YR"]
CATEG_COLUMNS = [
    "AGEGPFF", "SEX", "KPS", "DONORF", "GRAFTYPE", "CONDGRPF",
    "CONDGRP_FINAL", "ATGF", "GVHD_FINAL", "HLA_FINAL",
    "RCMVPR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN",
]
FEATURE_NAMES = [*NUM_COLUMNS, *CATEG_COLUMNS]

# Outcomes with a trained ensemble (one ensemble_model_<OUTCOME>.skops each).
OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "DWOGF"]
CLASSIFICATION_OUTCOMES = OUTCOMES  # alias: the very same list object

# OS and EFS are derived from the classification outcomes at prediction time.
REPORTING_OUTCOMES = ["OS", "EFS", "GF", "DEAD", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"]

# Human-readable titles; there is no entry for DWOGF, so lookups using
# .get(outcome, outcome) fall back to the raw code for it.
OUTCOME_DESCRIPTIONS = {
    "OS": "Overall Survival",
    "EFS": "Event-Free Survival",
    "DEAD": "Total Mortality",
    "GF": "Graft Failure",
    "AGVHD": "Acute Graft-versus-Host Disease",
    "CGVHD": "Chronic Graft-versus-Host Disease",
    "VOCPSHI": "Vaso-Occlusive Crisis Post-HCT",
    "STROKEHI": "Stroke Post-HCT",
}

SHAP_OUTCOMES = [
    "DEAD", "GF", "AGVHD", "CGVHD",
    "VOCPSHI", "STROKEHI", "OS", "EFS",
]

MODEL_DIR = "."              # directory holding the .skops artefacts
CONSENSUS_THRESHOLD = 0.5    # per-model vote threshold unless the model file overrides it
DEFAULT_N_BOOT_CI = 500      # bootstrap replicates for confidence intervals
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def _load_skops_model(fname):
    """Deserialise a skops file and return the stored object.

    Raises:
        RuntimeError: if ``fname`` does not exist, or if loading fails for any
            reason (the underlying exception is chained as ``__cause__``).
    """
    if not os.path.exists(fname):
        raise RuntimeError(f"Model file not found: {fname}")
    try:
        # SECURITY NOTE(review): every type skops reports as untrusted is
        # handed straight back as trusted, which disables skops' type
        # allow-listing. Only acceptable for artefacts from a trusted source.
        obj = sio.load(fname, trusted=sio.get_untrusted_types(file=fname))
        print(f" Loaded: {fname}")
    except Exception as exc:
        raise RuntimeError(
            f"Failed to load '{fname}': {type(exc).__name__}: {exc}"
        ) from exc
    return obj
print("Loading preprocessor...")
preprocessor = _load_skops_model(os.path.join(MODEL_DIR, "preprocessor.skops"))

# One ensemble file per classification outcome. A missing file only skips
# that outcome; a file that exists but fails to load raises (see
# _load_skops_model).
print("Loading ensemble models...")
classification_model_data = {}
for _o in CLASSIFICATION_OUTCOMES:
    _path = os.path.join(MODEL_DIR, f"ensemble_model_{_o}.skops")
    if not os.path.exists(_path):
        print(f" Warning: Model for {_o} not found at {_path}. Skipping.")
        continue
    classification_model_data[_o] = _load_skops_model(_path)

# Each loaded entry must provide "models", "beta" and "prior";
# "consensus_threshold" is optional.
classification_models = {
    name: entry["models"] for name, entry in classification_model_data.items()
}
betas = {name: entry["beta"] for name, entry in classification_model_data.items()}
priors = {name: entry["prior"] for name, entry in classification_model_data.items()}
consensus_thresholds = {
    name: entry.get("consensus_threshold", CONSENSUS_THRESHOLD)
    for name, entry in classification_model_data.items()
}
# Choose at most one post-hoc calibrator per outcome:
#   * a non-None "calibrator" entry is used when "calibrator_type" is absent
#     or "isotonic"; any other type is skipped with a warning (the
#     "isotonic_calibrator" entry is NOT consulted in that case);
#   * otherwise a non-None "isotonic_calibrator" entry is used;
#   * otherwise the outcome maps to None.
calibrators = {}
for _o, _d in classification_model_data.items():
    _cal_type = _d.get("calibrator_type", None)
    if "calibrator" in _d and _d["calibrator"] is not None:
        if _cal_type is None or _cal_type == "isotonic":
            _cal = _d["calibrator"]
        else:
            print(f" Warning: outcome '{_o}' has calibrator_type='{_cal_type}'. Skipping.")
            _cal = None
    elif "isotonic_calibrator" in _d and _d["isotonic_calibrator"] is not None:
        _cal = _d["isotonic_calibrator"]
    else:
        _cal = None
    calibrators[_o] = _cal
isotonic_calibrators = calibrators  # alias: the same dict object
# Stored out-of-fold calibrated probabilities per outcome (None when absent);
# predict_all_outcomes resamples these for its bootstrap intervals.
oof_probs_calibrated = {
    name: entry.get("oof_probs_calibrated")
    for name, entry in classification_model_data.items()
}

# Output column names of the preprocessor: numeric columns followed by the
# one-hot columns of its "cat" transformer.
# NOTE(review): assumes the preprocessor emits numeric columns first and has
# no other output columns — confirm against how it was built.
ohe = preprocessor.named_transformers_["cat"]
ohe_feature_names = ohe.get_feature_names_out(CATEG_COLUMNS)
processed_feature_names = np.concatenate([NUM_COLUMNS, ohe_feature_names])
print(f"Models loaded: {list(classification_models)}")
# ---------------------------------------------------------------------------
# SHAP background data
# ---------------------------------------------------------------------------
# Synthetic background for the SHAP TreeExplainers: 500 rows drawn uniformly
# at random over each feature's listed categories / numeric range (not real
# patients), passed through the fitted preprocessor.
print("Building SHAP background...")
# A private RandomState seeded with 23 yields exactly the same draws that
# np.random.seed(23) followed by np.random.* calls produced, without
# reseeding NumPy's process-wide global RNG as an import side effect.
_background_rng = np.random.RandomState(23)
_n_background = 500
# Threshold categories are spelled with the real "≥" / "≤" symbols.
# NOTE(review): every label below must match a category the preprocessor's
# one-hot encoder was fitted on — confirm against the training data.
_background_data = {
    "AGE": _background_rng.uniform(5, 50, _n_background),
    "NACS2YR": _background_rng.randint(0, 5, _n_background),
    "AGEGPFF": _background_rng.choice(["<=10", "11-17", "18-29", "30-49", ">=50"], _n_background),
    "SEX": _background_rng.choice(["Male", "Female"], _n_background),
    "KPS": _background_rng.choice(["<90", "≥ 90"], _n_background),
    "DONORF": _background_rng.choice([
        "HLA identical sibling", "HLA mismatch relative",
        "Matched unrelated donor",
        "Mismatched unrelated donor or cord blood",
    ], _n_background),
    "GRAFTYPE": _background_rng.choice(["Bone marrow", "Peripheral blood", "Cord blood"], _n_background),
    "CONDGRPF": _background_rng.choice(["MAC", "RIC", "NMA"], _n_background),
    "CONDGRP_FINAL": _background_rng.choice(["TBI/Cy", "Bu/Cy", "Flu/Bu", "Flu/Mel"], _n_background),
    "ATGF": _background_rng.choice(["ATG", "Alemtuzumab", "None"], _n_background),
    "GVHD_FINAL": _background_rng.choice(["CNI + MMF", "CNI + MTX", "Post-CY + siro +- MMF"], _n_background),
    "HLA_FINAL": _background_rng.choice(["8/8", "7/8", "≤ 6/8"], _n_background),
    "RCMVPR": _background_rng.choice(["Negative", "Positive"], _n_background),
    "EXCHTFPR": _background_rng.choice(["No", "Yes"], _n_background),
    "VOC2YPR": _background_rng.choice(["No", "Yes"], _n_background),
    "VOCFRQPR": _background_rng.choice(["< 3/yr", "≥ 3/yr"], _n_background),
    "SCATXRSN": _background_rng.choice([
        "CNS event", "Acute chest Syndrome",
        "Recurrent vaso-occlusive pain", "Recurrent priapism",
        "Excessive transfusion requirements/iron overload",
        "Cardio-pulmonary", "Chronic transfusion", "Asymptomatic",
        "Renal insufficiency", "Splenic sequestration",
        "Avascular necrosis", "Hodgkin lymphoma",
    ], _n_background),
}
_background_df = pd.DataFrame(_background_data)[FEATURE_NAMES]
_X_background = preprocessor.transform(_background_df)
shap_background = shap.maskers.Independent(_X_background)
print("SHAP background ready.")
# ---------------------------------------------------------------------------
# Calibration helpers
# ---------------------------------------------------------------------------
def calibrate_probabilities_undersampling(p_s, beta):
    """Map probabilities from an undersampled model back to the original prior.

    Computes ``beta * p / ((beta - 1) * p + 1)`` element-wise, with the
    denominator floored at 1e-10 and the result clipped to [0, 1].
    ``beta == 1`` is the identity. NOTE(review): this matches the usual
    correction for random undersampling of the negative class, with
    ``beta`` presumably the fraction of negatives kept — confirm against the
    training code.

    Args:
        p_s: scalar or array-like of probabilities from the resampled model.
        beta: undersampling ratio.

    Returns:
        NumPy array (or NumPy scalar for scalar input) of corrected probabilities.
    """
    probs = np.asarray(p_s, dtype=float)
    safe_denominator = np.maximum(probs * (beta - 1.0) + 1.0, 1e-10)
    return np.clip((beta * probs) / safe_denominator, 0.0, 1.0)
def predict_consensus_signed_voting(ensemble_models, X_test, threshold=0.5):
    """Signed-vote consensus across an ensemble of probabilistic classifiers.

    Each model votes +1 when column 1 of its ``predict_proba`` (assumed to be
    the positive class) is >= ``threshold`` and -1 otherwise. A sample is
    labelled 1 only when the mean vote is strictly positive, so ties give 0.

    Returns:
        (consensus_pred, avg_proba, avg_signed_vote, flat_probas):
        per-sample int labels, per-sample mean model probability, per-sample
        mean vote in [-1, 1], and the (n_models, n_samples) probability
        matrix flattened row by row.
    """
    probas = np.array([model.predict_proba(X_test)[:, 1] for model in ensemble_models])
    votes = np.where(probas >= threshold, 1, -1)
    mean_vote = np.mean(votes, axis=0)
    labels = (mean_vote > 0).astype(int)
    return labels, np.mean(probas, axis=0), mean_vote, probas.flatten()
def predict_consensus_majority(ensemble_models, X_test, threshold=0.5):
    """Average the ensemble's column-1 ``predict_proba`` outputs.

    ``threshold`` is accepted for signature parity with
    ``predict_consensus_signed_voting`` but is not used.

    Returns:
        (avg_proba, flat_probas): per-sample mean probability and the
        (n_models, n_samples) probability matrix flattened row by row.
    """
    per_model = [model.predict_proba(X_test)[:, 1] for model in ensemble_models]
    matrix = np.array(per_model)
    return np.mean(matrix, axis=0), matrix.flatten()
# ---------------------------------------------------------------------------
# Bootstrap CI
# ---------------------------------------------------------------------------
def bootstrap_ci_from_oof(
    point_estimate: float,
    oof_probs: np.ndarray,
    n_boot: int = DEFAULT_N_BOOT_CI,
    confidence: float = 0.95,
    random_state: int = 42,
) -> tuple:
    """Percentile-bootstrap interval re-centred on ``point_estimate``.

    The mean of ``oof_probs`` is bootstrapped ``n_boot`` times (resampling
    with replacement from ``RandomState(random_state)``). Every bootstrap mean
    is then shifted by ``point_estimate - mean(oof_probs)``, and the central
    ``confidence`` percentile range is returned, clipped to [0, 1]. The width
    therefore depends on ``oof_probs``, not on the patient's inputs.

    Returns:
        (lo, hi) as Python floats; ``(point_estimate, point_estimate)`` when
        ``oof_probs`` is None or empty.
    """
    if oof_probs is None or len(oof_probs) == 0:
        return float(point_estimate), float(point_estimate)

    sample = np.asarray(oof_probs, dtype=float)
    size = len(sample)
    rng = np.random.RandomState(random_state)
    resampled = np.fromiter(
        (np.mean(rng.choice(sample, size=size, replace=True)) for _ in range(n_boot)),
        dtype=float,
    )
    # Move the bootstrap distribution from the OOF grand mean to the estimate.
    resampled = resampled + (point_estimate - np.mean(sample))

    alpha = 1.0 - confidence
    lower = np.percentile(resampled, 100 * alpha / 2)
    upper = np.percentile(resampled, 100 * (1 - alpha / 2))
    return float(np.clip(lower, 0.0, 1.0)), float(np.clip(upper, 0.0, 1.0))
# ---------------------------------------------------------------------------
# Calibration dispatch
# ---------------------------------------------------------------------------
def _calibrate_point(outcome: str, raw_prob: float, use_calibration: bool) -> float:
    """Calibrate one raw ensemble probability for ``outcome``.

    The undersampling (beta) correction is always applied. When
    ``use_calibration`` is true and a calibrator was loaded for the outcome,
    its ``transform`` is applied on top.
    NOTE(review): if that calibrator is a scikit-learn IsotonicRegression
    fitted with the default out_of_bounds="nan", inputs outside its fitted
    range come back as NaN — confirm how it was fitted.
    """
    corrected = float(calibrate_probabilities_undersampling([raw_prob], betas[outcome])[0])
    calibrator = calibrators.get(outcome) if use_calibration else None
    if calibrator is None:
        return corrected
    return float(calibrator.transform([corrected])[0])
# ---------------------------------------------------------------------------
# Main prediction functions
# ---------------------------------------------------------------------------
def _efs_bootstrap_interval(p_dwogf, p_gf, oof_dwogf, oof_gf, n_boot, random_state=42):
    """95 % percentile-bootstrap interval for EFS = 1 - P(DWOGF) - P(GF).

    Both OOF arrays are truncated to their common length. Each replicate
    draws a DWOGF resample and then a GF resample, each mean is shifted onto
    the patient's point estimate, and the replicate EFS is clipped to [0, 1].
    Returns None when either OOF array is missing or empty, so the caller can
    fall back to a zero-width interval as ``bootstrap_ci_from_oof`` does.
    """
    if oof_dwogf is None or oof_gf is None:
        return None
    oof_dwogf = np.asarray(oof_dwogf, dtype=float)
    oof_gf = np.asarray(oof_gf, dtype=float)
    n_min = min(len(oof_dwogf), len(oof_gf))
    if n_min == 0:
        # Resampling an empty array would give NaN means and a (nan, nan) interval.
        return None
    oof_dwogf = oof_dwogf[:n_min]
    oof_gf = oof_gf[:n_min]
    rng = np.random.RandomState(random_state)
    shift_dwogf = p_dwogf - np.mean(oof_dwogf)
    shift_gf = p_gf - np.mean(oof_gf)
    efs_boot = np.array([
        np.clip(
            1.0
            - (np.mean(rng.choice(oof_dwogf, size=n_min, replace=True)) + shift_dwogf)
            - (np.mean(rng.choice(oof_gf, size=n_min, replace=True)) + shift_gf),
            0.0, 1.0,
        )
        for _ in range(n_boot)
    ])
    return (
        float(np.percentile(efs_boot, 2.5)),
        float(np.percentile(efs_boot, 97.5)),
    )


def predict_all_outcomes(
    user_inputs,
    use_calibration: bool = True,
    use_signed_voting: bool = True,
    n_boot_ci: int = DEFAULT_N_BOOT_CI,
):
    """Predict every available outcome (plus derived OS and EFS) for one patient.

    Args:
        user_inputs: dict keyed by FEATURE_NAMES, or a sequence of values in
            FEATURE_NAMES order.
        use_calibration: apply the loaded per-outcome calibrator after the
            beta (undersampling) correction.
        use_signed_voting: aggregate members with
            ``predict_consensus_signed_voting`` rather than
            ``predict_consensus_majority``. Both return the mean member
            probability as the point estimate used here.
        n_boot_ci: bootstrap replicates for every interval.

    Returns:
        (probs, intervals): dicts keyed by outcome code. They contain every
        loaded classification outcome, plus "OS" (1 - DEAD) when DEAD is
        available and "EFS" (1 - DWOGF - GF, clipped to [0, 1]) when both
        DWOGF and GF are. Intervals are (lo, hi) float tuples.
    """
    if isinstance(user_inputs, dict):
        input_df = pd.DataFrame([user_inputs])
    else:
        input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
    input_df = input_df[FEATURE_NAMES]
    X = preprocessor.transform(input_df)

    probs, intervals = {}, {}
    for o in CLASSIFICATION_OUTCOMES:
        if o not in classification_models:
            continue
        threshold = consensus_thresholds.get(o, CONSENSUS_THRESHOLD)
        if use_signed_voting:
            _, uncalib_arr, _, _ = predict_consensus_signed_voting(
                classification_models[o], X, threshold
            )
        else:
            uncalib_arr, _ = predict_consensus_majority(
                classification_models[o], X, threshold
            )
        event_prob = _calibrate_point(o, float(uncalib_arr[0]), use_calibration)
        lo, hi = bootstrap_ci_from_oof(
            point_estimate=event_prob,
            oof_probs=oof_probs_calibrated.get(o),
            n_boot=n_boot_ci,
        )
        probs[o] = event_prob
        intervals[o] = (lo, hi)

    # OS = 1 - P(DEAD); the interval bounds swap roles under the negation.
    if "DEAD" in probs:
        probs["OS"] = float(1.0 - probs["DEAD"])
        dead_lo, dead_hi = intervals["DEAD"]
        intervals["OS"] = (
            float(np.clip(1.0 - dead_hi, 0, 1)),
            float(np.clip(1.0 - dead_lo, 0, 1)),
        )

    # EFS = 1 - P(DWOGF) - P(GF)
    if "DWOGF" in probs and "GF" in probs:
        p_dwogf = probs["DWOGF"]
        p_gf = probs["GF"]
        probs["EFS"] = float(np.clip(1.0 - p_dwogf - p_gf, 0.0, 1.0))
        efs_interval = _efs_bootstrap_interval(
            p_dwogf,
            p_gf,
            oof_probs_calibrated.get("DWOGF"),
            oof_probs_calibrated.get("GF"),
            n_boot_ci,
        )
        if efs_interval is None:
            efs_interval = (probs["EFS"], probs["EFS"])
        intervals["EFS"] = efs_interval

    return probs, intervals
def predict_with_comparison(user_inputs, n_boot_ci: int = DEFAULT_N_BOOT_CI):
    """Run ``predict_all_outcomes`` with and without post-hoc calibration.

    Both runs use signed voting and the same ``n_boot_ci``.

    Returns:
        ((cal_probs, cal_intervals), (uncal_probs, uncal_intervals))
    """
    calibrated = predict_all_outcomes(
        user_inputs, use_calibration=True, use_signed_voting=True, n_boot_ci=n_boot_ci
    )
    uncalibrated = predict_all_outcomes(
        user_inputs, use_calibration=False, use_signed_voting=True, n_boot_ci=n_boot_ci
    )
    return calibrated, uncalibrated
# ---------------------------------------------------------------------------
# SHAP helpers
# ---------------------------------------------------------------------------
def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
    """Per-member SHAP values of the first row of ``X_proc`` for one ensemble.

    A TreeExplainer (probability output, ``shap_background`` as background)
    is built for every member of ``classification_models[model_outcome]``.
    Its class-1 SHAP values are taken, whether returned as a list or as a
    trailing axis of size 2. They are negated when ``invert`` is true.
    ``user_inputs`` is not used.

    Returns:
        Array with one row per ensemble member and one column per
        preprocessed feature.
    """
    rows = []
    for member in classification_models[model_outcome]:
        explainer = shap.TreeExplainer(member, model_output="probability", data=shap_background)
        values = explainer.shap_values(X_proc)
        if isinstance(values, list):
            values = values[1]
        elif values.ndim == 3 and values.shape[2] == 2:
            values = values[:, :, 1]
        first_row = values[0]
        rows.append(-first_row if invert else first_row)
    return np.array(rows)
def _processed_to_original_feature_map():
    """Map every preprocessed column name to the raw input feature it encodes.

    Numeric columns map to themselves. A one-hot column is attributed to the
    longest categorical column name it starts with (followed by "_"), so
    names that contain underscores themselves ("CONDGRP_FINAL", "GVHD_FINAL",
    "HLA_FINAL") are kept intact instead of being cut at their first "_".
    """
    mapping = {name: name for name in NUM_COLUMNS}
    categ_longest_first = sorted(CATEG_COLUMNS, key=len, reverse=True)
    for pf in ohe_feature_names:
        pf_str = str(pf)
        owner = next(
            (col for col in categ_longest_first if pf_str.startswith(col + "_")),
            None,
        )
        # Fallback: the old first-underscore split, for unexpected names.
        mapping[pf] = owner if owner is not None else pf_str.split("_", 1)[0]
    return mapping


def _shap_by_original_feature(raw_shap, processed_to_orig, unique_orig_features):
    """Sum per-processed-column SHAP values into per-original-feature columns.

    ``raw_shap`` has one row per ensemble member and one column per entry of
    ``processed_feature_names``; the result keeps the rows and has one column
    per entry of ``unique_orig_features``.
    """
    raw_shap = np.asarray(raw_shap)
    by_orig = np.zeros((len(raw_shap), len(unique_orig_features)))
    if len(raw_shap) == 0:
        return by_orig
    column_of = {name: j for j, name in enumerate(unique_orig_features)}
    for i, pf in enumerate(processed_feature_names):
        by_orig[:, column_of[processed_to_orig[pf]]] += raw_shap[:, i]
    return by_orig


def _format_feature_label(feat_name, input_values):
    """Bar label "NAME = value", or just NAME when no input value is known."""
    if feat_name not in input_values:
        return feat_name
    val = input_values[feat_name]
    if isinstance(val, float) and not val.is_integer():
        # Also covers NaN / inf, which int() cannot convert.
        return f"{feat_name} = {val:.2f}"
    if isinstance(val, (int, float)):
        return f"{feat_name} = {int(val)}"
    val_str = str(val)
    if len(val_str) > 20:
        val_str = val_str[:17] + "..."
    return f"{feat_name} = {val_str}"


def compute_shap_values_with_direction(user_inputs, outcome, max_display=10):
    """SHAP contribution of each raw input feature for one patient and outcome.

    One-hot columns are summed back to their original feature. "OS" uses the
    negated DEAD attributions (OS = 1 - DEAD). "EFS" uses the SUM of the
    negated DWOGF and GF attributions: since EFS = 1 - DWOGF - GF, linearity
    of SHAP values makes its attribution the sum of the two components.
    Uncertainty comes from bootstrapping ensemble members; each component
    ensemble is resampled separately, with ``RandomState(42)`` and
    ``DEFAULT_N_BOOT_CI`` replicates.

    Args:
        user_inputs: dict keyed by FEATURE_NAMES, or a sequence of values in
            FEATURE_NAMES order.
        outcome: a classification outcome code, "OS" or "EFS".
        max_display: number of features to return.

    Returns:
        (labels, mean_shap, ci_low, ci_high) for the ``max_display`` features
        with the largest |mean SHAP|, listed from smallest to largest
        magnitude (reversed for a horizontal bar chart).
    """
    if isinstance(user_inputs, dict):
        input_df = pd.DataFrame([user_inputs])
        input_values = user_inputs
    else:
        input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
        # Lets positional inputs show their values in the bar labels too.
        input_values = dict(zip(FEATURE_NAMES, user_inputs))
    X_proc = preprocessor.transform(input_df)

    processed_to_orig = _processed_to_original_feature_map()
    unique_orig_features = list(dict.fromkeys(processed_to_orig.values()))
    n_features = len(unique_orig_features)

    if outcome == "OS":
        components = [("DEAD", True)]
    elif outcome == "EFS":
        components = [("DWOGF", True), ("GF", True)]
    else:
        components = [(outcome, False)]

    # One (n_members, n_features) matrix per component ensemble.
    groups = [
        _shap_by_original_feature(
            _get_shap_values_for_model_outcome(
                user_inputs, model_outcome, invert=invert, X_proc=X_proc
            ),
            processed_to_orig,
            unique_orig_features,
        )
        for model_outcome, invert in components
    ]

    mean_shap_vals = np.zeros(n_features)
    for group in groups:
        mean_shap_vals = mean_shap_vals + np.mean(group, axis=0)

    rng = np.random.RandomState(42)
    bootstrap_shap_means = np.zeros((DEFAULT_N_BOOT_CI, n_features))
    for b in range(DEFAULT_N_BOOT_CI):
        for group in groups:
            n_models = len(group)
            picks = rng.choice(n_models, size=n_models, replace=True)
            bootstrap_shap_means[b] += np.mean(group[picks], axis=0)
    shap_ci_low = np.percentile(bootstrap_shap_means, 2.5, axis=0)
    shap_ci_high = np.percentile(bootstrap_shap_means, 97.5, axis=0)

    top = np.argsort(-np.abs(mean_shap_vals))[:max_display]
    top_feat_names = [
        _format_feature_label(unique_orig_features[i], input_values) for i in top
    ][::-1]
    top_shap_vals = mean_shap_vals[top][::-1]
    top_ci_low = shap_ci_low[top][::-1]
    top_ci_high = shap_ci_high[top][::-1]
    return top_feat_names, top_shap_vals, top_ci_low, top_ci_high
def create_shap_plot(user_inputs, outcome, max_display=10):
    """Horizontal SHAP bar chart with bootstrap CI whiskers for one outcome.

    Bars with value >= 0 are blue and the rest red. The title is the
    outcome's entry in OUTCOME_DESCRIPTIONS, or the raw code if it has none.
    """
    labels, values, lower, upper = compute_shap_values_with_direction(
        user_inputs, outcome, max_display
    )
    bar = go.Bar(
        y=labels,
        x=values,
        orientation="h",
        marker=dict(color=["blue" if v >= 0 else "red" for v in values]),
        showlegend=False,
        # Asymmetric whiskers: distance from each bar end to its CI bounds.
        error_x=dict(
            type="data",
            symmetric=False,
            array=upper - values,
            arrayminus=values - lower,
            color="gray",
            thickness=1.5,
            width=4,
        ),
    )
    fig = go.Figure(data=[bar])
    fig.add_vline(x=0, line_width=1, line_color="black")
    fig.update_layout(
        title=dict(
            text=OUTCOME_DESCRIPTIONS.get(outcome, outcome),
            x=0.5,
            xanchor="center",
            font=dict(size=14, color="black"),
        ),
        xaxis_title="SHAP value",
        yaxis_title="",
        height=400,
        margin=dict(l=120, r=60, t=50, b=50),
        plot_bgcolor="white",
        paper_bgcolor="white",
        xaxis=dict(
            showgrid=True,
            gridcolor="lightgray",
            zeroline=True,
            zerolinecolor="black",
            zerolinewidth=1,
        ),
        yaxis=dict(showgrid=False),
    )
    return fig
def create_all_shap_plots(user_inputs, max_display=10):
    """Build one SHAP figure per outcome in SHAP_OUTCOMES, keyed by outcome code."""
    plots = {}
    for outcome in SHAP_OUTCOMES:
        plots[outcome] = create_shap_plot(user_inputs, outcome, max_display)
    return plots
# ---------------------------------------------------------------------------
# Icon array
# ---------------------------------------------------------------------------
# Root cause of previous gaps / distortion:
# Plotly shape coordinates are in DATA units. If pixels-per-data-unit differ
# between the x and y axes, the circular head becomes an ellipse and the
# spacing looks uneven.
#
# Fix:
# • Use EQUAL axis spans on x and y (both = (cols - 1) + 2*pad = 10.3).
# • Choose width and height so the usable pixels are EQUAL on both axes:
#     usable_w = W - margin_l - margin_r = W - 20
#     usable_h = H - margin_t - margin_b = H - 105
#     usable_w == usable_h → H = W + 85
# • One data unit then spans the same number of pixels on both axes, so
#   circles stay round and spacing is uniform.
# ---------------------------------------------------------------------------
def _stick_figure(cx, cy, color, s):
    """Return the Plotly shape dicts that draw one stick figure.

    The figure is centred at data coordinates ``(cx, cy)`` and scaled by
    ``s`` (data units). On the 1.0-unit grid used by ``icon_array``,
    ``s ≈ 0.46`` fills roughly 75 % of a cell vertically.

    Vertical layout, as offsets from ``cy``:
        head centre  +0.55*s  (filled circle, radius 0.18*s)
        neck         +0.35*s  (top of the spine)
        arm joint    +0.18*s
        hip          -0.15*s  (bottom of the spine)
        feet         -0.55*s

    Shapes come back in drawing order: head, spine, left arm, right arm,
    left leg, right leg.
    """
    stroke = dict(color=color, width=1.8)  # pixel width, shared by every line

    def segment(x0, y0, x1, y1):
        return dict(type="line", xref="x", yref="y",
                    x0=x0, y0=y0, x1=x1, y1=y1, line=stroke)

    radius = s * 0.18
    head_y = cy + s * 0.55
    neck_y = cy + s * 0.35
    hip_y = cy - s * 0.15
    shoulder_y = cy + s * 0.18
    foot_y = cy - s * 0.55
    arm_dx = s * 0.32
    arm_dy = s * 0.15
    leg_dx = s * 0.26

    head = dict(
        type="circle", xref="x", yref="y",
        x0=cx - radius, y0=head_y - radius,
        x1=cx + radius, y1=head_y + radius,
        fillcolor=color,
        line=dict(color=color, width=0),
    )
    return [
        head,
        segment(cx, neck_y, cx, hip_y),
        segment(cx, shoulder_y, cx - arm_dx, shoulder_y - arm_dy),
        segment(cx, shoulder_y, cx + arm_dx, shoulder_y - arm_dy),
        segment(cx, hip_y, cx - leg_dx, foot_y),
        segment(cx, hip_y, cx + leg_dx, foot_y),
    ]
def icon_array(probability, outcome):
    """Build a 10 x 10 stick-figure icon array for one outcome probability.

    Args:
        probability: event probability. It is clamped to [0, 1] and rounded
            to a whole number of icons out of 100; NaN still raises in round().
        outcome: outcome code that selects the legend labels and the title
            (OUTCOME_DESCRIPTIONS entry, falling back to the code itself).

    Returns:
        A plotly Figure. The first ``n_event`` icons, filled row by row from
        the top-left, are drawn in #e05555; the rest are drawn in #3bbfad.
    """
    outcome_labels = {
        "DEAD": ("Death", "Overall Survival"),
        "GF": ("Graft Failure", "No Graft Failure"),
        "AGVHD": ("AGVHD", "No AGVHD"),
        "CGVHD": ("CGVHD", "No CGVHD"),
        "VOCPSHI": ("VOC Post-HCT", "No VOC Post-HCT"),
        "STROKEHI": ("Stroke Post-HCT", "No Stroke Post-HCT"),
    }
    event_label, no_event_label = outcome_labels.get(outcome, ("Event", "No Event"))
    # float() turns NumPy scalars into a Python float so round() returns an
    # int; clamping keeps both icon counts within 0..100.
    p = min(max(float(probability), 0.0), 1.0)
    n_event = round(p * 100)
    n_no_event = 100 - n_event
    cols, rows = 10, 10
    # ── Layout constants ──────────────────────────────────────────────────
    # Icons sit on an integer grid 0..9 × 0..9.
    # Padding of 0.65 on each side → axis span = 9 + 2*0.65 = 10.30
    # Margins: left=10, right=10, top=95, bottom=10
    #   usable_w = W - 20 ; usable_h = H - 105
    # For identical px-per-unit on both axes: usable_w == usable_h
    #   → H = W + 85
    # Both axis spans are 10.30 as well.
    PAD = 0.65
    W = 400
    H = W + 85  # = 485 → 380 usable px on both axes
    S = 0.46    # stick-figure scale (≈ 75 % vertical fill per cell)
    x_lo, x_hi = -PAD, (cols - 1) + PAD  # -0.65 … 9.65, span 10.30
    y_lo, y_hi = -PAD, (rows - 1) + PAD  # -0.65 … 9.65, span 10.30
    all_shapes = []
    icon_idx = 0
    for row in range(rows):  # row 0 → top of the grid
        for col in range(cols):  # col 0 → left
            color = "#e05555" if icon_idx < n_event else "#3bbfad"
            cx = col
            cy = (rows - 1) - row  # invert: row 0 → cy = 9 (top)
            all_shapes.extend(_stick_figure(cx, cy, color, S))
            icon_idx += 1
    fig = go.Figure()
    fig.update_layout(
        title=dict(
            text=(
                f"<b>{OUTCOME_DESCRIPTIONS.get(outcome, outcome)}</b><br>"
                f"<span style='font-size:12px;color:#e05555'>"
                f"■ {event_label}: {n_event}%</span>"
                f"&nbsp;&nbsp;"
                f"<span style='font-size:12px;color:#3bbfad'>"
                f"■ {no_event_label}: {n_no_event}%</span>"
            ),
            x=0.5,
            xanchor="center",
            font=dict(size=14, color="black"),
        ),
        shapes=all_shapes,
        xaxis=dict(
            range=[x_lo, x_hi],
            showgrid=False, zeroline=False, showticklabels=False,
            fixedrange=True,
        ),
        yaxis=dict(
            range=[y_lo, y_hi],
            showgrid=False, zeroline=False, showticklabels=False,
            fixedrange=True,
            # scaleanchor / scaleratio intentionally omitted: equal spans plus
            # equal usable pixels already give identical px/unit on both axes.
        ),
        width=W,
        height=H,
        margin=dict(l=10, r=10, t=95, b=10),
        plot_bgcolor="white",
        paper_bgcolor="white",
    )
    return fig
# Printed only after every statement above has run without raising.
print("===== inference.py loaded successfully =====")