shivapriyasom commited on
Commit
2e034b5
Β·
verified Β·
1 Parent(s): 3c51df3

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +99 -76
inference.py CHANGED
@@ -14,13 +14,11 @@ print(f"Working directory: {os.getcwd()}")
14
  print(f"Files present: {os.listdir('.')}")
15
 
16
  # ---------------------------------------------------------------------------
17
- # Compatibility patch β€” inject _RemainderColsList if the installed sklearn
18
- # version does not have it (added in sklearn 1.4+).
19
  # ---------------------------------------------------------------------------
20
  import sklearn.compose._column_transformer as _ct
21
  if not hasattr(_ct, "_RemainderColsList"):
22
  class _RemainderColsList(list):
23
- """Minimal shim for sklearn._RemainderColsList (missing in this env)."""
24
  def __init__(self, lst=None, future_dtype=None):
25
  super().__init__(lst or [])
26
  self.future_dtype = future_dtype
@@ -70,11 +68,10 @@ DEFAULT_N_BOOT_CI = 500
70
 
71
 
72
  # ---------------------------------------------------------------------------
73
- # Model loading β€” skops
74
  # ---------------------------------------------------------------------------
75
 
76
  def _load_skops_model(fname):
77
- """Load a skops model file. Raises RuntimeError on failure (no sys.exit)."""
78
  if not os.path.exists(fname):
79
  raise RuntimeError(f"Model file not found: {fname}")
80
  try:
@@ -106,26 +103,19 @@ consensus_thresholds = {
106
  for o, d in classification_model_data.items()
107
  }
108
 
109
- # Calibrators β€” isotonic only; supports both old and new key names
110
  calibrators = {}
111
  for _o, _d in classification_model_data.items():
112
  _cal = None
113
  _cal_type = _d.get("calibrator_type", None)
114
-
115
  if "calibrator" in _d and _d["calibrator"] is not None:
116
  if _cal_type is None or _cal_type == "isotonic":
117
  _cal = _d["calibrator"]
118
  else:
119
- print(
120
- f" Warning: outcome '{_o}' has calibrator_type='{_cal_type}'. "
121
- "Skipping non-isotonic calibrator (isotonic-only policy)."
122
- )
123
  elif "isotonic_calibrator" in _d and _d["isotonic_calibrator"] is not None:
124
  _cal = _d["isotonic_calibrator"]
125
-
126
  calibrators[_o] = _cal
127
 
128
- # Alias expected by app.py
129
  isotonic_calibrators = calibrators
130
 
131
  oof_probs_calibrated = {
@@ -255,14 +245,11 @@ def bootstrap_ci_from_oof(
255
  def _calibrate_point(outcome: str, raw_prob: float, use_calibration: bool) -> float:
256
  beta = betas[outcome]
257
  p_beta = float(calibrate_probabilities_undersampling([raw_prob], beta)[0])
258
-
259
  if not use_calibration:
260
  return p_beta
261
-
262
  cal = calibrators.get(outcome)
263
  if cal is None:
264
  return p_beta
265
-
266
  return float(cal.transform([p_beta])[0])
267
 
268
 
@@ -317,7 +304,6 @@ def predict_all_outcomes(
317
  if "DEAD" in probs:
318
  p_dead = probs["DEAD"]
319
  probs["OS"] = float(1.0 - p_dead)
320
-
321
  dead_lo, dead_hi = intervals["DEAD"]
322
  intervals["OS"] = (
323
  float(np.clip(1.0 - dead_hi, 0, 1)),
@@ -355,9 +341,10 @@ def predict_all_outcomes(
355
  )
356
  for _ in range(n_boot_ci)
357
  ])
358
- efs_lo = float(np.percentile(efs_boot, 2.5))
359
- efs_hi = float(np.percentile(efs_boot, 97.5))
360
- intervals["EFS"] = (efs_lo, efs_hi)
 
361
  else:
362
  intervals["EFS"] = (probs["EFS"], probs["EFS"])
363
 
@@ -375,22 +362,18 @@ def predict_with_comparison(user_inputs, n_boot_ci: int = DEFAULT_N_BOOT_CI):
375
  # ---------------------------------------------------------------------------
376
 
377
  def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
378
- """Return per-model SHAP values (shape: n_models x n_processed_features)."""
379
  all_model_shap_vals = []
380
  for rf_model in classification_models[model_outcome]:
381
  explainer = shap.TreeExplainer(rf_model, model_output="probability", data=shap_background)
382
  shap_vals = explainer.shap_values(X_proc)
383
-
384
  if isinstance(shap_vals, list):
385
  shap_vals = shap_vals[1]
386
  elif shap_vals.ndim == 3 and shap_vals.shape[2] == 2:
387
  shap_vals = shap_vals[:, :, 1]
388
-
389
  sv = shap_vals[0]
390
  if invert:
391
  sv = -sv
392
  all_model_shap_vals.append(sv)
393
-
394
  return np.array(all_model_shap_vals)
395
 
396
 
@@ -518,9 +501,85 @@ def create_all_shap_plots(user_inputs, max_display=10):
518
 
519
 
520
  # ---------------------------------------------------------------------------
521
- # Icon array
522
  # ---------------------------------------------------------------------------
523
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
  def icon_array(probability, outcome):
525
  outcome_labels = {
526
  "DEAD": ("Death", "Overall Survival"),
@@ -537,51 +596,15 @@ def icon_array(probability, outcome):
537
  n_no_event = n_total - n_event
538
  cols, rows = 10, 10
539
 
540
- shapes = []
541
- icon_idx = 0
542
- sx, sy = 0.38, 0.38 # icon half-size within each unit cell
543
 
544
- for row in range(rows): # row 0 = top
545
- for col in range(cols): # col 0 = left
546
- color = "#ff6b6b" if icon_idx < n_event else "#4ecdc4"
547
  cx = col
548
- cy = rows - 1 - row # invert so row 0 renders at the top
549
-
550
- # ── head (circle) ──────────────────────────────────────────
551
- hr = sy * 0.22
552
- hx = cx
553
- hy = cy + sy * 0.65
554
- shapes.append(dict(
555
- type="circle", xref="x", yref="y",
556
- x0=hx - hr, y0=hy - hr,
557
- x1=hx + hr, y1=hy + hr,
558
- fillcolor=color, line=dict(color=color, width=0),
559
- ))
560
-
561
- # ── body (symmetric trapezoid: shoulders β†’ waist β†’ feet) ───
562
- tx0 = cx - sx * 0.32
563
- tx1 = cx + sx * 0.32
564
- wx0 = cx - sx * 0.20
565
- wx1 = cx + sx * 0.20
566
- bx0 = cx - sx * 0.32
567
- bx1 = cx + sx * 0.32
568
-
569
- ty_top = cy + sy * 0.38 # shoulder line
570
- ty_waist = cy + sy * 0.00 # waist line
571
- ty_bottom = cy - sy * 0.42 # feet line
572
-
573
- shapes.append(dict(
574
- type="path", xref="x", yref="y",
575
- path=(
576
- f"M {tx0},{ty_top} "
577
- f"L {tx1},{ty_top} "
578
- f"L {wx1},{ty_waist} "
579
- f"L {bx1},{ty_bottom} "
580
- f"L {bx0},{ty_bottom} "
581
- f"L {wx0},{ty_waist} Z"
582
- ),
583
- fillcolor=color, line=dict(color=color, width=0),
584
- ))
585
  icon_idx += 1
586
 
587
  fig = go.Figure()
@@ -589,27 +612,27 @@ def icon_array(probability, outcome):
589
  title=dict(
590
  text=(
591
  f"<b>{OUTCOME_DESCRIPTIONS.get(outcome, outcome)}</b><br>"
592
- f"<span style='font-size:12px;color:#ff6b6b'>β–  {event_label}: {n_event}%</span>"
593
  f"&nbsp;&nbsp;"
594
- f"<span style='font-size:12px;color:#4ecdc4'>β–  {no_event_label}: {n_no_event}%</span>"
595
  ),
596
  x=0.5, xanchor="center",
597
  font=dict(size=14, color="black"),
598
  ),
599
- shapes=shapes,
600
  xaxis=dict(
601
- range=[-0.65, cols - 0.35],
602
  showgrid=False, zeroline=False, showticklabels=False,
603
  fixedrange=True,
604
  ),
605
  yaxis=dict(
606
- range=[-0.65, rows - 0.35],
607
  showgrid=False, zeroline=False, showticklabels=False,
608
  fixedrange=True,
609
  ),
610
- height=420,
611
- width=400,
612
- margin=dict(l=10, r=10, t=80, b=10),
613
  plot_bgcolor="white",
614
  paper_bgcolor="white",
615
  )
 
14
  print(f"Files present: {os.listdir('.')}")
15
 
16
  # ---------------------------------------------------------------------------
17
+ # Compatibility patch
 
18
  # ---------------------------------------------------------------------------
19
  import sklearn.compose._column_transformer as _ct
20
  if not hasattr(_ct, "_RemainderColsList"):
21
  class _RemainderColsList(list):
 
22
  def __init__(self, lst=None, future_dtype=None):
23
  super().__init__(lst or [])
24
  self.future_dtype = future_dtype
 
68
 
69
 
70
  # ---------------------------------------------------------------------------
71
+ # Model loading
72
  # ---------------------------------------------------------------------------
73
 
74
  def _load_skops_model(fname):
 
75
  if not os.path.exists(fname):
76
  raise RuntimeError(f"Model file not found: {fname}")
77
  try:
 
103
  for o, d in classification_model_data.items()
104
  }
105
 
 
106
  calibrators = {}
107
  for _o, _d in classification_model_data.items():
108
  _cal = None
109
  _cal_type = _d.get("calibrator_type", None)
 
110
  if "calibrator" in _d and _d["calibrator"] is not None:
111
  if _cal_type is None or _cal_type == "isotonic":
112
  _cal = _d["calibrator"]
113
  else:
114
+ print(f" Warning: outcome '{_o}' has calibrator_type='{_cal_type}'. Skipping.")
 
 
 
115
  elif "isotonic_calibrator" in _d and _d["isotonic_calibrator"] is not None:
116
  _cal = _d["isotonic_calibrator"]
 
117
  calibrators[_o] = _cal
118
 
 
119
  isotonic_calibrators = calibrators
120
 
121
  oof_probs_calibrated = {
 
245
  def _calibrate_point(outcome: str, raw_prob: float, use_calibration: bool) -> float:
246
  beta = betas[outcome]
247
  p_beta = float(calibrate_probabilities_undersampling([raw_prob], beta)[0])
 
248
  if not use_calibration:
249
  return p_beta
 
250
  cal = calibrators.get(outcome)
251
  if cal is None:
252
  return p_beta
 
253
  return float(cal.transform([p_beta])[0])
254
 
255
 
 
304
  if "DEAD" in probs:
305
  p_dead = probs["DEAD"]
306
  probs["OS"] = float(1.0 - p_dead)
 
307
  dead_lo, dead_hi = intervals["DEAD"]
308
  intervals["OS"] = (
309
  float(np.clip(1.0 - dead_hi, 0, 1)),
 
341
  )
342
  for _ in range(n_boot_ci)
343
  ])
344
+ intervals["EFS"] = (
345
+ float(np.percentile(efs_boot, 2.5)),
346
+ float(np.percentile(efs_boot, 97.5)),
347
+ )
348
  else:
349
  intervals["EFS"] = (probs["EFS"], probs["EFS"])
350
 
 
362
  # ---------------------------------------------------------------------------
363
 
364
  def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
 
365
  all_model_shap_vals = []
366
  for rf_model in classification_models[model_outcome]:
367
  explainer = shap.TreeExplainer(rf_model, model_output="probability", data=shap_background)
368
  shap_vals = explainer.shap_values(X_proc)
 
369
  if isinstance(shap_vals, list):
370
  shap_vals = shap_vals[1]
371
  elif shap_vals.ndim == 3 and shap_vals.shape[2] == 2:
372
  shap_vals = shap_vals[:, :, 1]
 
373
  sv = shap_vals[0]
374
  if invert:
375
  sv = -sv
376
  all_model_shap_vals.append(sv)
 
377
  return np.array(all_model_shap_vals)
378
 
379
 
 
501
 
502
 
503
  # ---------------------------------------------------------------------------
504
+ # Icon array β€” proper stick figures
505
  # ---------------------------------------------------------------------------
506
 
507
+ def _stick_figure_shapes(cx, cy, color, s=0.42):
508
+ """
509
+ Draw a classic stick figure centred at (cx, cy) with scale s.
510
+
511
+ Parts:
512
+ - head : circle
513
+ - spine : vertical line from neck to hips
514
+ - arms : angled lines left/right from mid-spine
515
+ - legs : angled lines left/right from hips
516
+ """
517
+ shapes = []
518
+ lw = dict(color=color, width=max(1, s * 4)) # line width scales with size
519
+
520
+ # ── head ──────────────────────────────────────────────────────────────
521
+ hr = s * 0.22 # head radius
522
+ hy = cy + s * 0.60 # head centre y
523
+ shapes.append(dict(
524
+ type="circle", xref="x", yref="y",
525
+ x0=cx - hr, y0=hy - hr,
526
+ x1=cx + hr, y1=hy + hr,
527
+ fillcolor=color,
528
+ line=dict(color=color, width=0),
529
+ ))
530
+
531
+ # key y levels
532
+ neck_y = cy + s * 0.35 # base of head / top of spine
533
+ hip_y = cy - s * 0.20 # bottom of spine / top of legs
534
+ arm_y = cy + s * 0.15 # where arms branch from spine
535
+ foot_y = cy - s * 0.60 # feet
536
+
537
+ # ── spine ─────────────────────────────────────────────────────────────
538
+ shapes.append(dict(
539
+ type="line", xref="x", yref="y",
540
+ x0=cx, y0=neck_y,
541
+ x1=cx, y1=hip_y,
542
+ line=lw,
543
+ ))
544
+
545
+ # ── arms (angled outward) ──────────────────────────────────────────────
546
+ arm_dx = s * 0.35
547
+ arm_dy = s * 0.18
548
+ # left arm
549
+ shapes.append(dict(
550
+ type="line", xref="x", yref="y",
551
+ x0=cx, y0=arm_y,
552
+ x1=cx - arm_dx, y1=arm_y - arm_dy,
553
+ line=lw,
554
+ ))
555
+ # right arm
556
+ shapes.append(dict(
557
+ type="line", xref="x", yref="y",
558
+ x0=cx, y0=arm_y,
559
+ x1=cx + arm_dx, y1=arm_y - arm_dy,
560
+ line=lw,
561
+ ))
562
+
563
+ # ── legs (angled outward) ──────────────────────────────────────────────
564
+ leg_dx = s * 0.28
565
+ # left leg
566
+ shapes.append(dict(
567
+ type="line", xref="x", yref="y",
568
+ x0=cx, y0=hip_y,
569
+ x1=cx - leg_dx, y1=foot_y,
570
+ line=lw,
571
+ ))
572
+ # right leg
573
+ shapes.append(dict(
574
+ type="line", xref="x", yref="y",
575
+ x0=cx, y0=hip_y,
576
+ x1=cx + leg_dx, y1=foot_y,
577
+ line=lw,
578
+ ))
579
+
580
+ return shapes
581
+
582
+
583
  def icon_array(probability, outcome):
584
  outcome_labels = {
585
  "DEAD": ("Death", "Overall Survival"),
 
596
  n_no_event = n_total - n_event
597
  cols, rows = 10, 10
598
 
599
+ all_shapes = []
600
+ icon_idx = 0
 
601
 
602
+ for row in range(rows): # row 0 = top
603
+ for col in range(cols): # col 0 = left
604
+ color = "#e05555" if icon_idx < n_event else "#3bbfad"
605
  cx = col
606
+ cy = rows - 1 - row # invert so row 0 is at top
607
+ all_shapes.extend(_stick_figure_shapes(cx, cy, color, s=0.38))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608
  icon_idx += 1
609
 
610
  fig = go.Figure()
 
612
  title=dict(
613
  text=(
614
  f"<b>{OUTCOME_DESCRIPTIONS.get(outcome, outcome)}</b><br>"
615
+ f"<span style='font-size:12px;color:#e05555'>β–  {event_label}: {n_event}%</span>"
616
  f"&nbsp;&nbsp;"
617
+ f"<span style='font-size:12px;color:#3bbfad'>β–  {no_event_label}: {n_no_event}%</span>"
618
  ),
619
  x=0.5, xanchor="center",
620
  font=dict(size=14, color="black"),
621
  ),
622
+ shapes=all_shapes,
623
  xaxis=dict(
624
+ range=[-0.7, cols - 0.3],
625
  showgrid=False, zeroline=False, showticklabels=False,
626
  fixedrange=True,
627
  ),
628
  yaxis=dict(
629
+ range=[-0.8, rows - 0.2],
630
  showgrid=False, zeroline=False, showticklabels=False,
631
  fixedrange=True,
632
  ),
633
+ height=440,
634
+ width=420,
635
+ margin=dict(l=10, r=10, t=85, b=10),
636
  plot_bgcolor="white",
637
  paper_bgcolor="white",
638
  )