shivapriyasom commited on
Commit
33eec8a
Β·
verified Β·
1 Parent(s): f559a3c

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +42 -32
inference.py CHANGED
@@ -373,7 +373,7 @@ def predict_with_comparison(user_inputs, n_boot_ci: int = DEFAULT_N_BOOT_CI):
373
  # ---------------------------------------------------------------------------
374
 
375
  def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
376
- """Return per-model SHAP values (shape: n_models Γ— n_processed_features)."""
377
  all_model_shap_vals = []
378
  for rf_model in classification_models[model_outcome]:
379
  explainer = shap.TreeExplainer(rf_model, model_output="probability", data=shap_background)
@@ -515,6 +515,9 @@ def create_all_shap_plots(user_inputs, max_display=10):
515
  return {o: create_shap_plot(user_inputs, o, max_display) for o in SHAP_OUTCOMES}
516
 
517
 
 
 
 
518
 
519
  def icon_array(probability, outcome):
520
  outcome_labels = {
@@ -532,36 +535,48 @@ def icon_array(probability, outcome):
532
  n_no_event = n_total - n_event
533
  cols, rows = 10, 10
534
 
535
- shapes = []
536
  icon_idx = 0
 
537
 
538
- for row in range(rows - 1, -1, -1): # top β†’ bottom
539
- for col in range(cols): # left β†’ right
540
  color = "#ff6b6b" if icon_idx < n_event else "#4ecdc4"
541
- x0 = col * 1.2
542
- y0 = row * 1.6
543
 
544
- # --- head (circle) ---
545
- cx, cy_head, hr = x0 + 0.5, y0 + 1.35, 0.22
 
 
546
  shapes.append(dict(
547
  type="circle", xref="x", yref="y",
548
- x0=cx - hr, y0=cy_head - hr,
549
- x1=cx + hr, y1=cy_head + hr,
550
  fillcolor=color, line=dict(color=color, width=0),
551
  ))
552
 
553
- # --- body (pentagon-ish path) ---
 
 
 
 
 
 
 
 
 
 
 
554
  shapes.append(dict(
555
  type="path", xref="x", yref="y",
556
  path=(
557
- f"M {x0+0.18},{y0+1.10} "
558
- f"L {x0+0.82},{y0+1.10} "
559
- f"L {x0+0.90},{y0+0.55} "
560
- f"L {x0+0.60},{y0+0.55} "
561
- f"L {x0+0.60},{y0+0.0} "
562
- f"L {x0+0.40},{y0+0.0} "
563
- f"L {x0+0.40},{y0+0.55} "
564
- f"L {x0+0.10},{y0+0.55} Z"
565
  ),
566
  fillcolor=color, line=dict(color=color, width=0),
567
  ))
@@ -581,24 +596,19 @@ def icon_array(probability, outcome):
581
  ),
582
  shapes=shapes,
583
  xaxis=dict(
584
- range=[-0.3, cols * 1.2 + 0.1],
585
  showgrid=False, zeroline=False, showticklabels=False,
 
586
  ),
587
  yaxis=dict(
588
- range=[-0.3, rows * 1.6 + 0.3],
589
  showgrid=False, zeroline=False, showticklabels=False,
590
- scaleanchor="x", scaleratio=1,
591
  ),
592
- height=460,
593
- width=430,
594
- margin=dict(l=10, r=10, t=90, b=10),
595
  plot_bgcolor="white",
596
  paper_bgcolor="white",
597
  )
598
- return fig
599
-
600
-
601
-
602
-
603
-
604
-
 
373
  # ---------------------------------------------------------------------------
374
 
375
  def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
376
+ """Return per-model SHAP values (shape: n_models x n_processed_features)."""
377
  all_model_shap_vals = []
378
  for rf_model in classification_models[model_outcome]:
379
  explainer = shap.TreeExplainer(rf_model, model_output="probability", data=shap_background)
 
515
  return {o: create_shap_plot(user_inputs, o, max_display) for o in SHAP_OUTCOMES}
516
 
517
 
518
+ # ---------------------------------------------------------------------------
519
+ # Icon array
520
+ # ---------------------------------------------------------------------------
521
 
522
  def icon_array(probability, outcome):
523
  outcome_labels = {
 
535
  n_no_event = n_total - n_event
536
  cols, rows = 10, 10
537
 
538
+ shapes = []
539
  icon_idx = 0
540
+ sx, sy = 0.38, 0.38 # icon half-size within each unit cell
541
 
542
+ for row in range(rows): # row 0 = top
543
+ for col in range(cols): # col 0 = left
544
  color = "#ff6b6b" if icon_idx < n_event else "#4ecdc4"
545
+ cx = col
546
+ cy = rows - 1 - row # invert so row 0 renders at the top
547
 
548
+ # ── head (circle) ──────────────────────────────────────────
549
+ hr = sy * 0.22
550
+ hx = cx
551
+ hy = cy + sy * 0.65
552
  shapes.append(dict(
553
  type="circle", xref="x", yref="y",
554
+ x0=hx - hr, y0=hy - hr,
555
+ x1=hx + hr, y1=hy + hr,
556
  fillcolor=color, line=dict(color=color, width=0),
557
  ))
558
 
559
+ # ── body (symmetric trapezoid: shoulders β†’ waist β†’ feet) ───
560
+ tx0 = cx - sx * 0.32
561
+ tx1 = cx + sx * 0.32
562
+ wx0 = cx - sx * 0.20
563
+ wx1 = cx + sx * 0.20
564
+ bx0 = cx - sx * 0.32
565
+ bx1 = cx + sx * 0.32
566
+
567
+ ty_top = cy + sy * 0.38 # shoulder line
568
+ ty_waist = cy + sy * 0.00 # waist line
569
+ ty_bottom = cy - sy * 0.42 # feet line
570
+
571
  shapes.append(dict(
572
  type="path", xref="x", yref="y",
573
  path=(
574
+ f"M {tx0},{ty_top} "
575
+ f"L {tx1},{ty_top} "
576
+ f"L {wx1},{ty_waist} "
577
+ f"L {bx1},{ty_bottom} "
578
+ f"L {bx0},{ty_bottom} "
579
+ f"L {wx0},{ty_waist} Z"
 
 
580
  ),
581
  fillcolor=color, line=dict(color=color, width=0),
582
  ))
 
596
  ),
597
  shapes=shapes,
598
  xaxis=dict(
599
+ range=[-0.65, cols - 0.35],
600
  showgrid=False, zeroline=False, showticklabels=False,
601
+ fixedrange=True,
602
  ),
603
  yaxis=dict(
604
+ range=[-0.65, rows - 0.35],
605
  showgrid=False, zeroline=False, showticklabels=False,
606
+ fixedrange=True,
607
  ),
608
+ height=420,
609
+ width=400,
610
+ margin=dict(l=10, r=10, t=80, b=10),
611
  plot_bgcolor="white",
612
  paper_bgcolor="white",
613
  )
614
+ return fig