shivapriyasom commited on
Commit
9c0b6ef
·
verified ·
1 Parent(s): 0032613

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +65 -88
inference.py CHANGED
@@ -519,67 +519,42 @@ def create_all_shap_plots(user_inputs, max_display=10):
519
  # Icon Array (replaces Pie Charts)
520
  # ---------------------------------------------------------------------------
521
 
522
- # SVG path for a simple person/figure icon
523
- _PERSON_SVG_PATH = (
524
- "M10,2 C11.65,2 13,3.35 13,5 C13,6.65 11.65,8 10,8 "
525
- "C8.35,8 7,6.65 7,5 C7,3.35 8.35,2 10,2 Z "
526
- "M10,9 C6.13,9 3,10.79 3,13 L3,14 L17,14 L17,13 "
527
- "C17,10.79 13.87,9 10,9 Z"
528
- )
529
 
530
  OUTCOME_CONFIG = {
531
- "DEAD": {
532
- "event_label": "Death",
533
- "no_event_label": "Survival",
534
- "event_color": "#e63946",
535
- "no_event_color": "#2a9d8f",
536
- "title": "Total Mortality",
537
- },
538
- "GF": {
539
- "event_label": "Graft Failure",
540
- "no_event_label": "No Graft Failure",
541
- "event_color": "#e76f51",
542
- "no_event_color": "#457b9d",
543
- "title": "Graft Failure",
544
- },
545
- "AGVHD": {
546
- "event_label": "Acute GVHD",
547
- "no_event_label": "No Acute GVHD",
548
- "event_color": "#c77dff",
549
- "no_event_color": "#48cae4",
550
- "title": "Acute GvHD",
551
- },
552
- "CGVHD": {
553
- "event_label": "Chronic GVHD",
554
- "no_event_label": "No Chronic GVHD",
555
- "event_color": "#9b2226",
556
- "no_event_color": "#94d2bd",
557
- "title": "Chronic GvHD",
558
- },
559
- "VOCPSHI": {
560
- "event_label": "VOC Post-HCT",
561
- "no_event_label": "No VOC",
562
- "event_color": "#f4a261",
563
- "no_event_color": "#264653",
564
- "title": "Vaso-Occlusive Crisis",
565
- },
566
- "STROKEHI": {
567
- "event_label": "Stroke Post-HCT",
568
- "no_event_label": "No Stroke",
569
- "event_color": "#d62828",
570
- "no_event_color": "#606c38",
571
- "title": "Stroke Post-HCT",
572
- },
573
  }
574
 
575
 
576
- def _person_icon_svg(color: str, size: int = 20) -> str:
577
- """Return an inline SVG person icon of the given color."""
 
 
 
 
 
 
 
 
 
 
578
  return (
579
  f'<svg xmlns="http://www.w3.org/2000/svg" '
580
- f'width="{size}" height="{size}" viewBox="0 0 20 20" '
581
- f'style="display:inline-block;vertical-align:middle;">'
582
- f'<path d="{_PERSON_SVG_PATH}" fill="{color}"/>'
 
 
 
 
 
583
  f'</svg>'
584
  )
585
 
@@ -587,14 +562,13 @@ def _person_icon_svg(color: str, size: int = 20) -> str:
587
  def create_icon_array_html(probability: float, outcome: str) -> str:
588
  """
589
  Build an HTML icon array card for a single outcome.
590
- 100 person icons in a 10×10 grid; filled icons represent the event probability.
 
591
  Returns a self-contained HTML string suitable for gr.HTML.
592
  """
593
  cfg = OUTCOME_CONFIG.get(outcome, {
594
  "event_label": "Event",
595
  "no_event_label": "No Event",
596
- "event_color": "#e63946",
597
- "no_event_color": "#adb5bd",
598
  "title": OUTCOME_DESCRIPTIONS.get(outcome, outcome),
599
  })
600
 
@@ -602,47 +576,51 @@ def create_icon_array_html(probability: float, outcome: str) -> str:
602
  n_no_event = 100 - n_event
603
  pct_str = f"{probability * 100:.1f}%"
604
 
605
- # Build icon rows (10 columns × 10 rows)
606
- icons_html = ""
607
- for i in range(100):
608
- color = cfg["event_color"] if i < n_event else cfg["no_event_color"]
609
- icons_html += _person_icon_svg(color, size=18)
610
- if (i + 1) % 10 == 0:
611
- icons_html += "<br>"
 
 
 
 
 
 
 
612
 
613
  html = f"""
614
  <div style="
615
  background: #ffffff;
616
  border: 1px solid #e0e0e0;
617
  border-radius: 10px;
618
- padding: 14px 16px;
619
  text-align: center;
620
  font-family: 'Segoe UI', Arial, sans-serif;
621
  box-shadow: 0 2px 6px rgba(0,0,0,0.07);
622
- height: 100%;
623
  box-sizing: border-box;
624
  ">
625
- <div style="font-size: 13px; font-weight: 600; color: #333; margin-bottom: 6px; line-height: 1.3;">
626
  {cfg['title']}
627
  </div>
628
- <div style="font-size: 26px; font-weight: 700; color: {cfg['event_color']}; margin-bottom: 6px;">
629
  {pct_str}
630
  </div>
631
- <div style="line-height: 1.6; margin-bottom: 8px;">
632
- {icons_html}
633
  </div>
634
- <div style="font-size: 11px; color: #555; display: flex; justify-content: center; gap: 14px; flex-wrap: wrap;">
635
- <span>
636
- <svg width="11" height="11" viewBox="0 0 20 20" style="vertical-align:middle;margin-right:3px;">
637
- <path d="{_PERSON_SVG_PATH}" fill="{cfg['event_color']}"/>
638
- </svg>
639
- {cfg['event_label']} ({n_event}/100)
640
  </span>
641
- <span>
642
- <svg width="11" height="11" viewBox="0 0 20 20" style="vertical-align:middle;margin-right:3px;">
643
- <path d="{_PERSON_SVG_PATH}" fill="{cfg['no_event_color']}"/>
644
- </svg>
645
- {cfg['no_event_label']} ({n_no_event}/100)
646
  </span>
647
  </div>
648
  </div>
@@ -663,19 +641,18 @@ def create_all_icon_arrays(calibrated_probs: dict) -> dict:
663
  for row_start in range(0, len(pie_outcomes), 4):
664
  row_outcomes = pie_outcomes[row_start: row_start + 4]
665
  cols_html = "".join(
666
- f'<div style="flex: 1 1 0; min-width: 0;">{cards[o]}</div>'
667
  for o in row_outcomes
668
  )
669
- rows_html += f"""
670
- <div style="display: flex; gap: 12px; margin-bottom: 12px;">
671
- {cols_html}
672
- </div>"""
673
 
674
  grid_html = f"""
675
- <div style="font-family: 'Segoe UI', Arial, sans-serif; padding: 4px 0;">
676
  {rows_html}
677
- <div style="font-size: 11px; color: #888; text-align: center; margin-top: 4px;">
678
  Each figure represents 1 patient out of 100 with similar characteristics.
 
 
679
  </div>
680
  </div>
681
  """
 
519
  # Icon Array (replaces Pie Charts)
520
  # ---------------------------------------------------------------------------
521
 
522
+ # Uniform colors for all outcomes: red = event, green = no event
523
+ EVENT_COLOR = "#e53935" # red
524
+ NO_EVENT_COLOR = "#43a047" # green
 
 
 
 
525
 
526
  OUTCOME_CONFIG = {
527
+ "DEAD": {"event_label": "Death", "no_event_label": "Survival", "title": "Total Mortality"},
528
+ "GF": {"event_label": "Graft Failure", "no_event_label": "No Graft Failure", "title": "Graft Failure"},
529
+ "AGVHD": {"event_label": "Acute GVHD", "no_event_label": "No Acute GVHD", "title": "Acute GvHD"},
530
+ "CGVHD": {"event_label": "Chronic GVHD", "no_event_label": "No Chronic GVHD", "title": "Chronic GvHD"},
531
+ "VOCPSHI": {"event_label": "VOC Post-HCT", "no_event_label": "No VOC", "title": "Vaso-Occlusive Crisis"},
532
+ "STROKEHI": {"event_label": "Stroke Post-HCT", "no_event_label": "No Stroke", "title": "Stroke Post-HCT"},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  }
534
 
535
 
536
+ def _stick_figure_svg(color: str, size: int = 20) -> str:
537
+ """
538
+ Return an inline SVG stick figure (head + body + arms + legs).
539
+ ViewBox is 0 0 20 32 so the figure is taller than wide.
540
+ stroke-width is scaled so figures look clean at small sizes.
541
+ """
542
+ sw = 2.2 # stroke width
543
+ # head: circle centred at (10, 5) r=4
544
+ # body: (10,9) → (10,20)
545
+ # arms: (3,13) → (17,13)
546
+ # left leg: (10,20) → (4,30)
547
+ # right leg: (10,20) → (16,30)
548
  return (
549
  f'<svg xmlns="http://www.w3.org/2000/svg" '
550
+ f'width="{size}" height="{round(size * 1.6)}" viewBox="0 0 20 32" '
551
+ f'style="display:inline-block;vertical-align:bottom;" '
552
+ f'stroke="{color}" stroke-width="{sw}" stroke-linecap="round" fill="none">'
553
+ f'<circle cx="10" cy="5" r="3.8" fill="{color}" stroke="none"/>'
554
+ f'<line x1="10" y1="9" x2="10" y2="20"/>'
555
+ f'<line x1="3" y1="13" x2="17" y2="13"/>'
556
+ f'<line x1="10" y1="20" x2="4" y2="30"/>'
557
+ f'<line x1="10" y1="20" x2="16" y2="30"/>'
558
  f'</svg>'
559
  )
560
 
 
562
  def create_icon_array_html(probability: float, outcome: str) -> str:
563
  """
564
  Build an HTML icon array card for a single outcome.
565
+ 100 stick-figure icons in a 10×10 grid.
566
+ Red = event, green = no event (uniform across all outcomes).
567
  Returns a self-contained HTML string suitable for gr.HTML.
568
  """
569
  cfg = OUTCOME_CONFIG.get(outcome, {
570
  "event_label": "Event",
571
  "no_event_label": "No Event",
 
 
572
  "title": OUTCOME_DESCRIPTIONS.get(outcome, outcome),
573
  })
574
 
 
576
  n_no_event = 100 - n_event
577
  pct_str = f"{probability * 100:.1f}%"
578
 
579
+ # Build icon grid: 10 per row, event (red) figures first, then no-event (green)
580
+ rows_parts = []
581
+ for row in range(10):
582
+ row_html = '<div style="display:flex;justify-content:center;gap:1px;margin-bottom:1px;">'
583
+ for col in range(10):
584
+ idx = row * 10 + col
585
+ color = EVENT_COLOR if idx < n_event else NO_EVENT_COLOR
586
+ row_html += _stick_figure_svg(color, size=16)
587
+ row_html += '</div>'
588
+ rows_parts.append(row_html)
589
+ grid_html = "\n".join(rows_parts)
590
+
591
+ legend_figure_event = _stick_figure_svg(EVENT_COLOR, size=12)
592
+ legend_figure_no_event = _stick_figure_svg(NO_EVENT_COLOR, size=12)
593
 
594
  html = f"""
595
  <div style="
596
  background: #ffffff;
597
  border: 1px solid #e0e0e0;
598
  border-radius: 10px;
599
+ padding: 12px 10px 10px 10px;
600
  text-align: center;
601
  font-family: 'Segoe UI', Arial, sans-serif;
602
  box-shadow: 0 2px 6px rgba(0,0,0,0.07);
 
603
  box-sizing: border-box;
604
  ">
605
+ <div style="font-size: 12px; font-weight: 700; color: #222; margin-bottom: 4px; line-height: 1.3;">
606
  {cfg['title']}
607
  </div>
608
+ <div style="font-size: 24px; font-weight: 800; color: {EVENT_COLOR}; margin-bottom: 6px;">
609
  {pct_str}
610
  </div>
611
+ <div style="margin-bottom: 6px;">
612
+ {grid_html}
613
  </div>
614
+ <div style="font-size: 10.5px; color: #444; display:flex; justify-content:center; gap:12px; flex-wrap:wrap;">
615
+ <span style="display:inline-flex;align-items:center;gap:3px;">
616
+ {legend_figure_event}
617
+ <span style="color:{EVENT_COLOR};font-weight:600;">{cfg['event_label']}</span>
618
+ &nbsp;({n_event}/100)
 
619
  </span>
620
+ <span style="display:inline-flex;align-items:center;gap:3px;">
621
+ {legend_figure_no_event}
622
+ <span style="color:{NO_EVENT_COLOR};font-weight:600;">{cfg['no_event_label']}</span>
623
+ &nbsp;({n_no_event}/100)
 
624
  </span>
625
  </div>
626
  </div>
 
641
  for row_start in range(0, len(pie_outcomes), 4):
642
  row_outcomes = pie_outcomes[row_start: row_start + 4]
643
  cols_html = "".join(
644
+ f'<div style="flex:1 1 0;min-width:0;">{cards[o]}</div>'
645
  for o in row_outcomes
646
  )
647
+ rows_html += f'<div style="display:flex;gap:10px;margin-bottom:10px;">{cols_html}</div>'
 
 
 
648
 
649
  grid_html = f"""
650
+ <div style="font-family:'Segoe UI',Arial,sans-serif;padding:4px 0;">
651
  {rows_html}
652
+ <div style="font-size:11px;color:#888;text-align:center;margin-top:2px;">
653
  Each figure represents 1 patient out of 100 with similar characteristics.
654
+ &nbsp;<span style="color:{EVENT_COLOR};font-weight:600;">■ Red = event</span>
655
+ &nbsp;<span style="color:{NO_EVENT_COLOR};font-weight:600;">■ Green = no event</span>
656
  </div>
657
  </div>
658
  """