shivapriyasom commited on
Commit
fc0f605
·
verified ·
1 Parent(s): d920fdb

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +14 -44
inference.py CHANGED
@@ -9,11 +9,7 @@ import warnings
9
 
10
  warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
11
 
12
- # ---------------------------------------------------------------------------
13
- # Compatibility patch — inject _RemainderColsList if the installed sklearn
14
- # version does not have it (added in sklearn 1.4+). This allows .skops files
15
- # saved with a newer sklearn to load correctly on older environments.
16
- # ---------------------------------------------------------------------------
17
  import sklearn.compose._column_transformer as _ct
18
  if not hasattr(_ct, "_RemainderColsList"):
19
  class _RemainderColsList(list):
@@ -26,9 +22,6 @@ if not hasattr(_ct, "_RemainderColsList"):
26
  sklearn.compose._RemainderColsList = _RemainderColsList
27
 
28
 
29
- # ---------------------------------------------------------------------------
30
- # Column / feature definitions
31
- # ---------------------------------------------------------------------------
32
 
33
  NUM_COLUMNS = ["AGE", "NACS2YR"]
34
  CATEG_COLUMNS = [
@@ -62,7 +55,7 @@ REPORTING_OUTCOMES = [
62
  OUTCOME_DESCRIPTIONS = {
63
  "OS": "Overall Survival",
64
  "EFS": "Event-Free Survival",
65
- "DEAD": "Total Mortality",
66
  "GF": "Graft Failure",
67
  "AGVHD": "Acute Graft-versus-Host Disease",
68
  "CGVHD": "Chronic Graft-versus-Host Disease",
@@ -77,9 +70,7 @@ CONSENSUS_THRESHOLD = 0.5
77
  DEFAULT_N_BOOT_CI = 500
78
 
79
 
80
- # ---------------------------------------------------------------------------
81
- # Model loading — skops
82
- # ---------------------------------------------------------------------------
83
 
84
  def _load_skops_model(fname):
85
  try:
@@ -108,7 +99,7 @@ consensus_thresholds = {
108
  for o, d in classification_model_data.items()
109
  }
110
 
111
- # Calibrators — isotonic only; supports both old and new key names
112
  calibrators = {}
113
  for _o, _d in classification_model_data.items():
114
  _cal = None
@@ -139,9 +130,7 @@ ohe_feature_names = ohe.get_feature_names_out(CATEG_COLUMNS)
139
  processed_feature_names = np.concatenate([NUM_COLUMNS, ohe_feature_names])
140
 
141
 
142
- # ---------------------------------------------------------------------------
143
- # SHAP background data
144
- # ---------------------------------------------------------------------------
145
 
146
  np.random.seed(23)
147
  _n_background = 500
@@ -182,9 +171,6 @@ _X_background = preprocessor.transform(_background_df)
182
  shap_background = shap.maskers.Independent(_X_background)
183
 
184
 
185
- # ---------------------------------------------------------------------------
186
- # Calibration helpers
187
- # ---------------------------------------------------------------------------
188
 
189
  def calibrate_probabilities_undersampling(p_s, beta):
190
  p_s = np.asarray(p_s, dtype=float)
@@ -213,9 +199,7 @@ def predict_consensus_majority(ensemble_models, X_test, threshold=0.5):
213
  return avg_proba, individual_probas.flatten()
214
 
215
 
216
- # ---------------------------------------------------------------------------
217
- # Bootstrap CI
218
- # ---------------------------------------------------------------------------
219
 
220
  def bootstrap_ci_from_oof(
221
  point_estimate: float,
@@ -246,9 +230,7 @@ def bootstrap_ci_from_oof(
246
  return lo, hi
247
 
248
 
249
- # ---------------------------------------------------------------------------
250
- # Calibration dispatch
251
- # ---------------------------------------------------------------------------
252
 
253
  def _calibrate_point(outcome: str, raw_prob: float, use_calibration: bool) -> float:
254
  beta = betas[outcome]
@@ -264,9 +246,7 @@ def _calibrate_point(outcome: str, raw_prob: float, use_calibration: bool) -> fl
264
  return float(cal.transform([p_beta])[0])
265
 
266
 
267
- # ---------------------------------------------------------------------------
268
- # Main prediction functions
269
- # ---------------------------------------------------------------------------
270
 
271
  def predict_all_outcomes(
272
  user_inputs,
@@ -311,7 +291,6 @@ def predict_all_outcomes(
311
  probs[o] = event_prob
312
  intervals[o] = (lo, hi)
313
 
314
- # OS = 1 - P(DEAD)
315
  if "DEAD" in probs:
316
  p_dead = probs["DEAD"]
317
  probs["OS"] = float(1.0 - p_dead)
@@ -322,7 +301,6 @@ def predict_all_outcomes(
322
  float(np.clip(1.0 - dead_lo, 0, 1)),
323
  )
324
 
325
- # EFS = 1 - P(DWOGF) - P(GF)
326
  if "DWOGF" in probs and "GF" in probs:
327
  p_dwogf = probs["DWOGF"]
328
  p_gf = probs["GF"]
@@ -368,9 +346,7 @@ def predict_with_comparison(user_inputs, n_boot_ci: int = DEFAULT_N_BOOT_CI):
368
  return (cal_probs, cal_intervals), (uncal_probs, uncal_intervals)
369
 
370
 
371
- # ---------------------------------------------------------------------------
372
- # SHAP helpers
373
- # ---------------------------------------------------------------------------
374
 
375
  def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
376
  """Return per-model SHAP values (shape: n_models × n_processed_features)."""
@@ -515,15 +491,13 @@ def create_all_shap_plots(user_inputs, max_display=10):
515
  return {o: create_shap_plot(user_inputs, o, max_display) for o in SHAP_OUTCOMES}
516
 
517
 
518
- # ---------------------------------------------------------------------------
519
- # Icon Array (replaces Pie Charts)
520
- # ---------------------------------------------------------------------------
521
 
522
- EVENT_COLOR = "#e53935" # red — event
523
- NO_EVENT_COLOR = "#43a047" # green — no event
 
524
 
525
  OUTCOME_TITLES = {
526
- "DEAD": "Total Mortality",
527
  "GF": "Graft Failure",
528
  "AGVHD": "Acute GvHD",
529
  "CGVHD": "Chronic GvHD",
@@ -531,7 +505,7 @@ OUTCOME_TITLES = {
531
  "STROKEHI": "Stroke Post-HCT",
532
  }
533
 
534
- # Short, equal-length label pairs so the legend stays uniformly sized
535
  OUTCOME_LABELS = {
536
  "DEAD": ("Death", "No Death"),
537
  "GF": ("Graft Failure", "No Graft Failure"),
@@ -570,7 +544,6 @@ def create_icon_array_html(probability: float, outcome: str) -> str:
570
  n_no_event = 100 - n_event
571
  pct_str = f"{probability * 100:.1f}%"
572
 
573
- # --- grid: 10 rows × 10 cols ---
574
  rows_parts = []
575
  for row in range(10):
576
  cells = ""
@@ -656,9 +629,6 @@ def create_all_icon_arrays(calibrated_probs: dict) -> dict:
656
  return cards
657
 
658
 
659
- # ---------------------------------------------------------------------------
660
- # Backward-compatibility aliases
661
- # ---------------------------------------------------------------------------
662
 
663
  def create_pie_chart(probability, outcome):
664
  return create_icon_array_html(probability, outcome)
 
9
 
10
  warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
11
 
12
+
 
 
 
 
13
  import sklearn.compose._column_transformer as _ct
14
  if not hasattr(_ct, "_RemainderColsList"):
15
  class _RemainderColsList(list):
 
22
  sklearn.compose._RemainderColsList = _RemainderColsList
23
 
24
 
 
 
 
25
 
26
  NUM_COLUMNS = ["AGE", "NACS2YR"]
27
  CATEG_COLUMNS = [
 
55
  OUTCOME_DESCRIPTIONS = {
56
  "OS": "Overall Survival",
57
  "EFS": "Event-Free Survival",
58
+ "DEAD": "Death",
59
  "GF": "Graft Failure",
60
  "AGVHD": "Acute Graft-versus-Host Disease",
61
  "CGVHD": "Chronic Graft-versus-Host Disease",
 
70
  DEFAULT_N_BOOT_CI = 500
71
 
72
 
73
+
 
 
74
 
75
  def _load_skops_model(fname):
76
  try:
 
99
  for o, d in classification_model_data.items()
100
  }
101
 
102
+
103
  calibrators = {}
104
  for _o, _d in classification_model_data.items():
105
  _cal = None
 
130
  processed_feature_names = np.concatenate([NUM_COLUMNS, ohe_feature_names])
131
 
132
 
133
+
 
 
134
 
135
  np.random.seed(23)
136
  _n_background = 500
 
171
  shap_background = shap.maskers.Independent(_X_background)
172
 
173
 
 
 
 
174
 
175
  def calibrate_probabilities_undersampling(p_s, beta):
176
  p_s = np.asarray(p_s, dtype=float)
 
199
  return avg_proba, individual_probas.flatten()
200
 
201
 
202
+
 
 
203
 
204
  def bootstrap_ci_from_oof(
205
  point_estimate: float,
 
230
  return lo, hi
231
 
232
 
233
+
 
 
234
 
235
  def _calibrate_point(outcome: str, raw_prob: float, use_calibration: bool) -> float:
236
  beta = betas[outcome]
 
246
  return float(cal.transform([p_beta])[0])
247
 
248
 
249
+
 
 
250
 
251
  def predict_all_outcomes(
252
  user_inputs,
 
291
  probs[o] = event_prob
292
  intervals[o] = (lo, hi)
293
 
 
294
  if "DEAD" in probs:
295
  p_dead = probs["DEAD"]
296
  probs["OS"] = float(1.0 - p_dead)
 
301
  float(np.clip(1.0 - dead_lo, 0, 1)),
302
  )
303
 
 
304
  if "DWOGF" in probs and "GF" in probs:
305
  p_dwogf = probs["DWOGF"]
306
  p_gf = probs["GF"]
 
346
  return (cal_probs, cal_intervals), (uncal_probs, uncal_intervals)
347
 
348
 
349
+
 
 
350
 
351
  def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
352
  """Return per-model SHAP values (shape: n_models × n_processed_features)."""
 
491
  return {o: create_shap_plot(user_inputs, o, max_display) for o in SHAP_OUTCOMES}
492
 
493
 
 
 
 
494
 
495
+
496
+ EVENT_COLOR = "#e53935"
497
+ NO_EVENT_COLOR = "#43a047"
498
 
499
  OUTCOME_TITLES = {
500
+ "DEAD": "TDeath",
501
  "GF": "Graft Failure",
502
  "AGVHD": "Acute GvHD",
503
  "CGVHD": "Chronic GvHD",
 
505
  "STROKEHI": "Stroke Post-HCT",
506
  }
507
 
508
+
509
  OUTCOME_LABELS = {
510
  "DEAD": ("Death", "No Death"),
511
  "GF": ("Graft Failure", "No Graft Failure"),
 
544
  n_no_event = 100 - n_event
545
  pct_str = f"{probability * 100:.1f}%"
546
 
 
547
  rows_parts = []
548
  for row in range(10):
549
  cells = ""
 
629
  return cards
630
 
631
 
 
 
 
632
 
633
  def create_pie_chart(probability, outcome):
634
  return create_icon_array_html(probability, outcome)