shivapriyasom commited on
Commit
ab21c9c
·
verified ·
1 Parent(s): 9045ae4

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +604 -0
inference.py CHANGED
@@ -0,0 +1,604 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import skops.io as sio
4
+ import shap
5
+ import plotly.graph_objects as go
6
+ import os
7
+ import sys
8
+ import warnings
9
+
10
+ warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
11
+
12
+ # ---------------------------------------------------------------------------
13
+ # Compatibility patch — inject _RemainderColsList if the installed sklearn
14
+ # version does not have it (added in sklearn 1.4+). This allows .skops files
15
+ # saved with a newer sklearn to load correctly on older environments.
16
+ # ---------------------------------------------------------------------------
17
+ import sklearn.compose._column_transformer as _ct
18
+ if not hasattr(_ct, "_RemainderColsList"):
19
+ class _RemainderColsList(list):
20
+ """Minimal shim for sklearn._RemainderColsList (missing in this env)."""
21
+ def __init__(self, lst=None, future_dtype=None):
22
+ super().__init__(lst or [])
23
+ self.future_dtype = future_dtype
24
+ _ct._RemainderColsList = _RemainderColsList
25
+ import sklearn.compose
26
+ sklearn.compose._RemainderColsList = _RemainderColsList
27
+
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Column / feature definitions
31
+ # ---------------------------------------------------------------------------
32
+
33
+ NUM_COLUMNS = ["AGE", "NACS2YR"]
34
+ CATEG_COLUMNS = [
35
+ "AGEGPFF",
36
+ "SEX",
37
+ "KPS",
38
+ "DONORF",
39
+ "GRAFTYPE",
40
+ "CONDGRPF",
41
+ "CONDGRP_FINAL",
42
+ "ATGF",
43
+ "GVHD_FINAL",
44
+ "HLA_FINAL",
45
+ "RCMVPR",
46
+ "EXCHTFPR",
47
+ "VOC2YPR",
48
+ "VOCFRQPR",
49
+ "SCATXRSN",
50
+ ]
51
+
52
+ FEATURE_NAMES = NUM_COLUMNS + CATEG_COLUMNS
53
+
54
+ OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "DWOGF"]
55
+ CLASSIFICATION_OUTCOMES = OUTCOMES
56
+
57
+ REPORTING_OUTCOMES = [
58
+ "OS", "EFS", "GF", "DEAD",
59
+ "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI",
60
+ ]
61
+
62
+ OUTCOME_DESCRIPTIONS = {
63
+ "OS": "Overall Survival",
64
+ "EFS": "Event-Free Survival",
65
+ "DEAD": "Total Mortality",
66
+ "GF": "Graft Failure",
67
+ "AGVHD": "Acute Graft-versus-Host Disease",
68
+ "CGVHD": "Chronic Graft-versus-Host Disease",
69
+ "VOCPSHI": "Vaso-Occlusive Crisis Post-HCT",
70
+ "STROKEHI": "Stroke Post-HCT",
71
+ }
72
+
73
+ SHAP_OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "OS", "EFS"]
74
+
75
+ MODEL_DIR = "."
76
+ CONSENSUS_THRESHOLD = 0.5
77
+ DEFAULT_N_BOOT_CI = 500
78
+
79
+
80
+ # ---------------------------------------------------------------------------
81
+ # Model loading — skops
82
+ # ---------------------------------------------------------------------------
83
+
84
+ def _load_skops_model(fname):
85
+ try:
86
+ untrusted = sio.get_untrusted_types(file=fname)
87
+ return sio.load(fname, trusted=untrusted)
88
+ except Exception as e:
89
+ print(f"Error loading '{fname}': {e}")
90
+ sys.exit(1)
91
+
92
+
93
+ preprocessor = _load_skops_model(os.path.join(MODEL_DIR, "preprocessor.skops"))
94
+
95
+ classification_model_data = {}
96
+ for _o in CLASSIFICATION_OUTCOMES:
97
+ _path = os.path.join(MODEL_DIR, f"ensemble_model_{_o}.skops")
98
+ if os.path.exists(_path):
99
+ classification_model_data[_o] = _load_skops_model(_path)
100
+ else:
101
+ print(f"Warning: Model for {_o} not found at {_path}. Skipping.")
102
+
103
+ classification_models = {o: d["models"] for o, d in classification_model_data.items()}
104
+ betas = {o: d["beta"] for o, d in classification_model_data.items()}
105
+ priors = {o: d["prior"] for o, d in classification_model_data.items()}
106
+ consensus_thresholds = {
107
+ o: d.get("consensus_threshold", CONSENSUS_THRESHOLD)
108
+ for o, d in classification_model_data.items()
109
+ }
110
+
111
+ # Calibrators — isotonic only; supports both old and new key names
112
+ calibrators = {}
113
+ for _o, _d in classification_model_data.items():
114
+ _cal = None
115
+ _cal_type = _d.get("calibrator_type", None)
116
+
117
+ if "calibrator" in _d and _d["calibrator"] is not None:
118
+ if _cal_type is None or _cal_type == "isotonic":
119
+ _cal = _d["calibrator"]
120
+ else:
121
+ print(
122
+ f"Warning: outcome '{_o}' has calibrator_type='{_cal_type}'. "
123
+ "Skipping non-isotonic calibrator (isotonic-only policy)."
124
+ )
125
+ elif "isotonic_calibrator" in _d and _d["isotonic_calibrator"] is not None:
126
+ _cal = _d["isotonic_calibrator"]
127
+
128
+ calibrators[_o] = _cal
129
+
130
+ # Alias expected by app.py
131
+ isotonic_calibrators = calibrators
132
+
133
+ oof_probs_calibrated = {
134
+ o: d.get("oof_probs_calibrated") for o, d in classification_model_data.items()
135
+ }
136
+
137
+ ohe = preprocessor.named_transformers_["cat"]
138
+ ohe_feature_names = ohe.get_feature_names_out(CATEG_COLUMNS)
139
+ processed_feature_names = np.concatenate([NUM_COLUMNS, ohe_feature_names])
140
+
141
+
142
+ # ---------------------------------------------------------------------------
143
+ # SHAP background data
144
+ # ---------------------------------------------------------------------------
145
+
146
+ np.random.seed(23)
147
+ _n_background = 500
148
+
149
+ _background_data = {
150
+ "AGE": np.random.uniform(5, 50, _n_background),
151
+ "NACS2YR": np.random.randint(0, 5, _n_background),
152
+ "AGEGPFF": np.random.choice(["<=10", "11-17", "18-29", "30-49", ">=50"], _n_background),
153
+ "SEX": np.random.choice(["Male", "Female"], _n_background),
154
+ "KPS": np.random.choice(["<90", "≥ 90"], _n_background),
155
+ "DONORF": np.random.choice([
156
+ "HLA identical sibling", "HLA mismatch relative",
157
+ "Matched unrelated donor",
158
+ "Mismatched unrelated donor or cord blood",
159
+ ], _n_background),
160
+ "GRAFTYPE": np.random.choice(["Bone marrow", "Peripheral blood", "Cord blood"], _n_background),
161
+ "CONDGRPF": np.random.choice(["MAC", "RIC", "NMA"], _n_background),
162
+ "CONDGRP_FINAL": np.random.choice(["TBI/Cy", "Bu/Cy", "Flu/Bu", "Flu/Mel"], _n_background),
163
+ "ATGF": np.random.choice(["ATG", "Alemtuzumab", "None"], _n_background),
164
+ "GVHD_FINAL": np.random.choice(["CNI + MMF", "CNI + MTX", "Post-CY + siro +- MMF"], _n_background),
165
+ "HLA_FINAL": np.random.choice(["8/8", "7/8", "≤ 6/8"], _n_background),
166
+ "RCMVPR": np.random.choice(["Negative", "Positive"], _n_background),
167
+ "EXCHTFPR": np.random.choice(["No", "Yes"], _n_background),
168
+ "VOC2YPR": np.random.choice(["No", "Yes"], _n_background),
169
+ "VOCFRQPR": np.random.choice(["< 3/yr", "≥ 3/yr"], _n_background),
170
+ "SCATXRSN": np.random.choice([
171
+ "CNS event", "Acute chest Syndrome",
172
+ "Recurrent vaso-occlusive pain", "Recurrent priapism",
173
+ "Excessive transfusion requirements/iron overload",
174
+ "Cardio-pulmonary", "Chronic transfusion", "Asymptomatic",
175
+ "Renal insufficiency", "Splenic sequestration",
176
+ "Avascular necrosis", "Hodgkin lymphoma",
177
+ ], _n_background),
178
+ }
179
+
180
+ _background_df = pd.DataFrame(_background_data)[FEATURE_NAMES]
181
+ _X_background = preprocessor.transform(_background_df)
182
+ shap_background = shap.maskers.Independent(_X_background)
183
+
184
+
185
+ # ---------------------------------------------------------------------------
186
+ # Calibration helpers
187
+ # ---------------------------------------------------------------------------
188
+
189
+ def calibrate_probabilities_undersampling(p_s, beta):
190
+ p_s = np.asarray(p_s, dtype=float)
191
+ numerator = beta * p_s
192
+ denominator = np.maximum((beta - 1.0) * p_s + 1.0, 1e-10)
193
+ return np.clip(numerator / denominator, 0.0, 1.0)
194
+
195
+
196
+ def predict_consensus_signed_voting(ensemble_models, X_test, threshold=0.5):
197
+ individual_probas = np.array(
198
+ [m.predict_proba(X_test)[:, 1] for m in ensemble_models]
199
+ )
200
+ binary_preds = (individual_probas >= threshold).astype(int)
201
+ signed_votes = np.where(binary_preds == 1, 1, -1)
202
+ avg_signed_vote = np.mean(signed_votes, axis=0)
203
+ consensus_pred = (avg_signed_vote > 0).astype(int)
204
+ avg_proba = np.mean(individual_probas, axis=0)
205
+ return consensus_pred, avg_proba, avg_signed_vote, individual_probas.flatten()
206
+
207
+
208
+ def predict_consensus_majority(ensemble_models, X_test, threshold=0.5):
209
+ individual_probas = np.array(
210
+ [m.predict_proba(X_test)[:, 1] for m in ensemble_models]
211
+ )
212
+ avg_proba = np.mean(individual_probas, axis=0)
213
+ return avg_proba, individual_probas.flatten()
214
+
215
+
216
+ # ---------------------------------------------------------------------------
217
+ # Bootstrap CI
218
+ # ---------------------------------------------------------------------------
219
+
220
+ def bootstrap_ci_from_oof(
221
+ point_estimate: float,
222
+ oof_probs: np.ndarray,
223
+ n_boot: int = DEFAULT_N_BOOT_CI,
224
+ confidence: float = 0.95,
225
+ random_state: int = 42,
226
+ ) -> tuple:
227
+ if oof_probs is None or len(oof_probs) == 0:
228
+ return float(point_estimate), float(point_estimate)
229
+
230
+ oof_probs = np.asarray(oof_probs, dtype=float)
231
+ rng = np.random.RandomState(random_state)
232
+ grand_mean = np.mean(oof_probs)
233
+ n = len(oof_probs)
234
+
235
+ boot_means = np.array([
236
+ np.mean(rng.choice(oof_probs, size=n, replace=True))
237
+ for _ in range(n_boot)
238
+ ])
239
+
240
+ shift = point_estimate - grand_mean
241
+ boot_means = boot_means + shift
242
+
243
+ alpha = 1.0 - confidence
244
+ lo = float(np.clip(np.percentile(boot_means, 100 * alpha / 2), 0.0, 1.0))
245
+ hi = float(np.clip(np.percentile(boot_means, 100 * (1 - alpha / 2)), 0.0, 1.0))
246
+ return lo, hi
247
+
248
+
249
+ # ---------------------------------------------------------------------------
250
+ # Calibration dispatch
251
+ # ---------------------------------------------------------------------------
252
+
253
+ def _calibrate_point(outcome: str, raw_prob: float, use_calibration: bool) -> float:
254
+ beta = betas[outcome]
255
+ p_beta = float(calibrate_probabilities_undersampling([raw_prob], beta)[0])
256
+
257
+ if not use_calibration:
258
+ return p_beta
259
+
260
+ cal = calibrators.get(outcome)
261
+ if cal is None:
262
+ return p_beta
263
+
264
+ return float(cal.transform([p_beta])[0])
265
+
266
+
267
+ # ---------------------------------------------------------------------------
268
+ # Main prediction functions
269
+ # ---------------------------------------------------------------------------
270
+
271
+ def predict_all_outcomes(
272
+ user_inputs,
273
+ use_calibration: bool = True,
274
+ use_signed_voting: bool = True,
275
+ n_boot_ci: int = DEFAULT_N_BOOT_CI,
276
+ ):
277
+ if isinstance(user_inputs, dict):
278
+ input_df = pd.DataFrame([user_inputs])
279
+ else:
280
+ input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
281
+
282
+ input_df = input_df[FEATURE_NAMES]
283
+ X = preprocessor.transform(input_df)
284
+
285
+ probs, intervals = {}, {}
286
+
287
+ for o in CLASSIFICATION_OUTCOMES:
288
+ if o not in classification_models:
289
+ continue
290
+
291
+ threshold = consensus_thresholds.get(o, CONSENSUS_THRESHOLD)
292
+
293
+ if use_signed_voting:
294
+ _, uncalib_arr, _, _ = predict_consensus_signed_voting(
295
+ classification_models[o], X, threshold
296
+ )
297
+ else:
298
+ uncalib_arr, _ = predict_consensus_majority(
299
+ classification_models[o], X, threshold
300
+ )
301
+
302
+ raw_prob = float(uncalib_arr[0])
303
+ event_prob = _calibrate_point(o, raw_prob, use_calibration)
304
+
305
+ lo, hi = bootstrap_ci_from_oof(
306
+ point_estimate=event_prob,
307
+ oof_probs=oof_probs_calibrated.get(o),
308
+ n_boot=n_boot_ci,
309
+ )
310
+
311
+ probs[o] = event_prob
312
+ intervals[o] = (lo, hi)
313
+
314
+ # OS = 1 - P(DEAD)
315
+ if "DEAD" in probs:
316
+ p_dead = probs["DEAD"]
317
+ probs["OS"] = float(1.0 - p_dead)
318
+
319
+ dead_lo, dead_hi = intervals["DEAD"]
320
+ intervals["OS"] = (
321
+ float(np.clip(1.0 - dead_hi, 0, 1)),
322
+ float(np.clip(1.0 - dead_lo, 0, 1)),
323
+ )
324
+
325
+ # EFS = 1 - P(DWOGF) - P(GF)
326
+ if "DWOGF" in probs and "GF" in probs:
327
+ p_dwogf = probs["DWOGF"]
328
+ p_gf = probs["GF"]
329
+ probs["EFS"] = float(np.clip(1.0 - p_dwogf - p_gf, 0.0, 1.0))
330
+
331
+ oof_dwogf = oof_probs_calibrated.get("DWOGF")
332
+ oof_gf = oof_probs_calibrated.get("GF")
333
+
334
+ if oof_dwogf is not None and oof_gf is not None:
335
+ oof_dwogf = np.asarray(oof_dwogf, dtype=float)
336
+ oof_gf = np.asarray(oof_gf, dtype=float)
337
+ n_min = min(len(oof_dwogf), len(oof_gf))
338
+ oof_dwogf = oof_dwogf[:n_min]
339
+ oof_gf = oof_gf[:n_min]
340
+
341
+ rng = np.random.RandomState(42)
342
+ grand_dwogf = np.mean(oof_dwogf)
343
+ grand_gf = np.mean(oof_gf)
344
+ shift_dwogf = p_dwogf - grand_dwogf
345
+ shift_gf = p_gf - grand_gf
346
+
347
+ efs_boot = np.array([
348
+ np.clip(
349
+ 1.0
350
+ - (np.mean(rng.choice(oof_dwogf, size=n_min, replace=True)) + shift_dwogf)
351
+ - (np.mean(rng.choice(oof_gf, size=n_min, replace=True)) + shift_gf),
352
+ 0.0, 1.0,
353
+ )
354
+ for _ in range(n_boot_ci)
355
+ ])
356
+ efs_lo = float(np.percentile(efs_boot, 2.5))
357
+ efs_hi = float(np.percentile(efs_boot, 97.5))
358
+ intervals["EFS"] = (efs_lo, efs_hi)
359
+ else:
360
+ intervals["EFS"] = (probs["EFS"], probs["EFS"])
361
+
362
+ return probs, intervals
363
+
364
+
365
+ def predict_with_comparison(user_inputs, n_boot_ci: int = DEFAULT_N_BOOT_CI):
366
+ cal_probs, cal_intervals = predict_all_outcomes(user_inputs, True, True, n_boot_ci)
367
+ uncal_probs, uncal_intervals = predict_all_outcomes(user_inputs, False, True, n_boot_ci)
368
+ return (cal_probs, cal_intervals), (uncal_probs, uncal_intervals)
369
+
370
+
371
+ # ---------------------------------------------------------------------------
372
+ # SHAP helpers
373
+ # ---------------------------------------------------------------------------
374
+
375
+ def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
376
+ """Return per-model SHAP values (shape: n_models × n_processed_features)."""
377
+ all_model_shap_vals = []
378
+ for rf_model in classification_models[model_outcome]:
379
+ explainer = shap.TreeExplainer(rf_model, model_output="probability", data=shap_background)
380
+ shap_vals = explainer.shap_values(X_proc)
381
+
382
+ if isinstance(shap_vals, list):
383
+ shap_vals = shap_vals[1]
384
+ elif shap_vals.ndim == 3 and shap_vals.shape[2] == 2:
385
+ shap_vals = shap_vals[:, :, 1]
386
+
387
+ sv = shap_vals[0]
388
+ if invert:
389
+ sv = -sv
390
+ all_model_shap_vals.append(sv)
391
+
392
+ return np.array(all_model_shap_vals)
393
+
394
+
395
+ def compute_shap_values_with_direction(user_inputs, outcome, max_display=10):
396
+ if isinstance(user_inputs, dict):
397
+ input_df = pd.DataFrame([user_inputs])
398
+ else:
399
+ input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
400
+
401
+ X_proc = preprocessor.transform(input_df)
402
+
403
+ processed_to_orig = {f: f for f in NUM_COLUMNS}
404
+ for pf in ohe_feature_names:
405
+ processed_to_orig[pf] = pf.split("_", 1)[0]
406
+
407
+ if outcome == "OS":
408
+ raw_shap = _get_shap_values_for_model_outcome(user_inputs, "DEAD", invert=True, X_proc=X_proc)
409
+ elif outcome == "EFS":
410
+ shap_dwogf = _get_shap_values_for_model_outcome(user_inputs, "DWOGF", invert=True, X_proc=X_proc)
411
+ shap_gf = _get_shap_values_for_model_outcome(user_inputs, "GF", invert=True, X_proc=X_proc)
412
+ raw_shap = np.concatenate([shap_dwogf, shap_gf], axis=0)
413
+ else:
414
+ raw_shap = _get_shap_values_for_model_outcome(user_inputs, outcome, invert=False, X_proc=X_proc)
415
+
416
+ unique_orig_features = list(dict.fromkeys(processed_to_orig.values()))
417
+ n_models = len(raw_shap)
418
+
419
+ model_shap_by_orig = np.zeros((n_models, len(unique_orig_features)))
420
+ for model_idx in range(n_models):
421
+ agg_by_orig = {}
422
+ for i, pf in enumerate(processed_feature_names):
423
+ orig = processed_to_orig[pf]
424
+ agg_by_orig.setdefault(orig, 0.0)
425
+ agg_by_orig[orig] += raw_shap[model_idx, i]
426
+ for feat_idx, feat_name in enumerate(unique_orig_features):
427
+ model_shap_by_orig[model_idx, feat_idx] = agg_by_orig.get(feat_name, 0.0)
428
+
429
+ mean_shap_vals = np.mean(model_shap_by_orig, axis=0)
430
+
431
+ rng = np.random.RandomState(42)
432
+ bootstrap_shap_means = np.array([
433
+ np.mean(model_shap_by_orig[rng.choice(n_models, size=n_models, replace=True)], axis=0)
434
+ for _ in range(DEFAULT_N_BOOT_CI)
435
+ ])
436
+ shap_ci_low = np.percentile(bootstrap_shap_means, 2.5, axis=0)
437
+ shap_ci_high = np.percentile(bootstrap_shap_means, 97.5, axis=0)
438
+
439
+ order = np.argsort(-np.abs(mean_shap_vals))
440
+
441
+ top_feat_names = []
442
+ for i in order[:max_display]:
443
+ feat_name = unique_orig_features[i]
444
+ if feat_name in user_inputs:
445
+ val = user_inputs[feat_name]
446
+ if isinstance(val, float) and val != int(val):
447
+ display_name = f"{feat_name} = {val:.2f}"
448
+ elif isinstance(val, (int, float)):
449
+ display_name = f"{feat_name} = {int(val)}"
450
+ else:
451
+ val_str = str(val)
452
+ if len(val_str) > 20:
453
+ val_str = val_str[:17] + "..."
454
+ display_name = f"{feat_name} = {val_str}"
455
+ else:
456
+ display_name = feat_name
457
+ top_feat_names.append(display_name)
458
+
459
+ top_feat_names = top_feat_names[::-1]
460
+ top_shap_vals = mean_shap_vals[order][:max_display][::-1]
461
+ top_ci_low = shap_ci_low[order][:max_display][::-1]
462
+ top_ci_high = shap_ci_high[order][:max_display][::-1]
463
+
464
+ return top_feat_names, top_shap_vals, top_ci_low, top_ci_high
465
+
466
+
467
+ def create_shap_plot(user_inputs, outcome, max_display=10):
468
+ feat_names, shap_vals, ci_low, ci_high = compute_shap_values_with_direction(
469
+ user_inputs, outcome, max_display
470
+ )
471
+
472
+ colors = ["blue" if v >= 0 else "red" for v in shap_vals]
473
+ error_minus = shap_vals - ci_low
474
+ error_plus = ci_high - shap_vals
475
+
476
+ fig = go.Figure()
477
+ fig.add_trace(go.Bar(
478
+ y=feat_names,
479
+ x=shap_vals,
480
+ orientation="h",
481
+ marker=dict(color=colors),
482
+ showlegend=False,
483
+ error_x=dict(
484
+ type="data",
485
+ symmetric=False,
486
+ array=error_plus,
487
+ arrayminus=error_minus,
488
+ color="gray",
489
+ thickness=1.5,
490
+ width=4,
491
+ ),
492
+ ))
493
+ fig.add_vline(x=0, line_width=1, line_color="black")
494
+
495
+ fig.update_layout(
496
+ title=dict(
497
+ text=OUTCOME_DESCRIPTIONS.get(outcome, outcome),
498
+ x=0.5, xanchor="center",
499
+ font=dict(size=14, color="black"),
500
+ ),
501
+ xaxis_title="SHAP value",
502
+ yaxis_title="",
503
+ height=400,
504
+ margin=dict(l=120, r=60, t=50, b=50),
505
+ plot_bgcolor="white",
506
+ paper_bgcolor="white",
507
+ xaxis=dict(showgrid=True, gridcolor="lightgray", zeroline=True,
508
+ zerolinecolor="black", zerolinewidth=1),
509
+ yaxis=dict(showgrid=False),
510
+ )
511
+ return fig
512
+
513
+
514
+ def create_all_shap_plots(user_inputs, max_display=10):
515
+ return {o: create_shap_plot(user_inputs, o, max_display) for o in SHAP_OUTCOMES}
516
+
517
+
518
+
519
+ def icon_array(probability, outcome):
520
+ outcome_labels = {
521
+ "DEAD": ("Death", "Overall Survival"),
522
+ "GF": ("Graft Failure", "No Graft Failure"),
523
+ "AGVHD": ("AGVHD", "No AGVHD"),
524
+ "CGVHD": ("CGVHD", "No CGVHD"),
525
+ "VOCPSHI": ("VOC Post-HCT", "No VOC Post-HCT"),
526
+ "STROKEHI": ("Stroke Post-HCT", "No Stroke Post-HCT"),
527
+ }
528
+
529
+ event_label, no_event_label = outcome_labels.get(outcome, ("Event", "No Event"))
530
+ n_total = 100
531
+ n_event = round(probability * n_total)
532
+ n_no_event = n_total - n_event
533
+ cols, rows = 10, 10
534
+
535
+ shapes = []
536
+ icon_idx = 0
537
+
538
+ for row in range(rows - 1, -1, -1): # top → bottom
539
+ for col in range(cols): # left → right
540
+ color = "#ff6b6b" if icon_idx < n_event else "#4ecdc4"
541
+ x0 = col * 1.2
542
+ y0 = row * 1.6
543
+
544
+ # --- head (circle) ---
545
+ cx, cy_head, hr = x0 + 0.5, y0 + 1.35, 0.22
546
+ shapes.append(dict(
547
+ type="circle", xref="x", yref="y",
548
+ x0=cx - hr, y0=cy_head - hr,
549
+ x1=cx + hr, y1=cy_head + hr,
550
+ fillcolor=color, line=dict(color=color, width=0),
551
+ ))
552
+
553
+ # --- body (pentagon-ish path) ---
554
+ shapes.append(dict(
555
+ type="path", xref="x", yref="y",
556
+ path=(
557
+ f"M {x0+0.18},{y0+1.10} "
558
+ f"L {x0+0.82},{y0+1.10} "
559
+ f"L {x0+0.90},{y0+0.55} "
560
+ f"L {x0+0.60},{y0+0.55} "
561
+ f"L {x0+0.60},{y0+0.0} "
562
+ f"L {x0+0.40},{y0+0.0} "
563
+ f"L {x0+0.40},{y0+0.55} "
564
+ f"L {x0+0.10},{y0+0.55} Z"
565
+ ),
566
+ fillcolor=color, line=dict(color=color, width=0),
567
+ ))
568
+ icon_idx += 1
569
+
570
+ fig = go.Figure()
571
+ fig.update_layout(
572
+ title=dict(
573
+ text=(
574
+ f"<b>{OUTCOME_DESCRIPTIONS.get(outcome, outcome)}</b><br>"
575
+ f"<span style='font-size:12px;color:#ff6b6b'>■ {event_label}: {n_event}%</span>"
576
+ f"&nbsp;&nbsp;"
577
+ f"<span style='font-size:12px;color:#4ecdc4'>■ {no_event_label}: {n_no_event}%</span>"
578
+ ),
579
+ x=0.5, xanchor="center",
580
+ font=dict(size=14, color="black"),
581
+ ),
582
+ shapes=shapes,
583
+ xaxis=dict(
584
+ range=[-0.3, cols * 1.2 + 0.1],
585
+ showgrid=False, zeroline=False, showticklabels=False,
586
+ ),
587
+ yaxis=dict(
588
+ range=[-0.3, rows * 1.6 + 0.3],
589
+ showgrid=False, zeroline=False, showticklabels=False,
590
+ scaleanchor="x", scaleratio=1,
591
+ ),
592
+ height=460,
593
+ width=430,
594
+ margin=dict(l=10, r=10, t=90, b=10),
595
+ plot_bgcolor="white",
596
+ paper_bgcolor="white",
597
+ )
598
+ return fig
599
+
600
+
601
+
602
+
603
+
604
+