Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +65 -88
inference.py
CHANGED
|
@@ -519,67 +519,42 @@ def create_all_shap_plots(user_inputs, max_display=10):
|
|
| 519 |
# Icon Array (replaces Pie Charts)
|
| 520 |
# ---------------------------------------------------------------------------
|
| 521 |
|
| 522 |
-
#
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
"C8.35,8 7,6.65 7,5 C7,3.35 8.35,2 10,2 Z "
|
| 526 |
-
"M10,9 C6.13,9 3,10.79 3,13 L3,14 L17,14 L17,13 "
|
| 527 |
-
"C17,10.79 13.87,9 10,9 Z"
|
| 528 |
-
)
|
| 529 |
|
| 530 |
OUTCOME_CONFIG = {
|
| 531 |
-
"DEAD": {
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
},
|
| 538 |
-
"GF": {
|
| 539 |
-
"event_label": "Graft Failure",
|
| 540 |
-
"no_event_label": "No Graft Failure",
|
| 541 |
-
"event_color": "#e76f51",
|
| 542 |
-
"no_event_color": "#457b9d",
|
| 543 |
-
"title": "Graft Failure",
|
| 544 |
-
},
|
| 545 |
-
"AGVHD": {
|
| 546 |
-
"event_label": "Acute GVHD",
|
| 547 |
-
"no_event_label": "No Acute GVHD",
|
| 548 |
-
"event_color": "#c77dff",
|
| 549 |
-
"no_event_color": "#48cae4",
|
| 550 |
-
"title": "Acute GvHD",
|
| 551 |
-
},
|
| 552 |
-
"CGVHD": {
|
| 553 |
-
"event_label": "Chronic GVHD",
|
| 554 |
-
"no_event_label": "No Chronic GVHD",
|
| 555 |
-
"event_color": "#9b2226",
|
| 556 |
-
"no_event_color": "#94d2bd",
|
| 557 |
-
"title": "Chronic GvHD",
|
| 558 |
-
},
|
| 559 |
-
"VOCPSHI": {
|
| 560 |
-
"event_label": "VOC Post-HCT",
|
| 561 |
-
"no_event_label": "No VOC",
|
| 562 |
-
"event_color": "#f4a261",
|
| 563 |
-
"no_event_color": "#264653",
|
| 564 |
-
"title": "Vaso-Occlusive Crisis",
|
| 565 |
-
},
|
| 566 |
-
"STROKEHI": {
|
| 567 |
-
"event_label": "Stroke Post-HCT",
|
| 568 |
-
"no_event_label": "No Stroke",
|
| 569 |
-
"event_color": "#d62828",
|
| 570 |
-
"no_event_color": "#606c38",
|
| 571 |
-
"title": "Stroke Post-HCT",
|
| 572 |
-
},
|
| 573 |
}
|
| 574 |
|
| 575 |
|
| 576 |
-
def
|
| 577 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
return (
|
| 579 |
f'<svg xmlns="http://www.w3.org/2000/svg" '
|
| 580 |
-
f'width="{size}" height="{size}" viewBox="0 0 20
|
| 581 |
-
f'style="display:inline-block;vertical-align:
|
| 582 |
-
f'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
f'</svg>'
|
| 584 |
)
|
| 585 |
|
|
@@ -587,14 +562,13 @@ def _person_icon_svg(color: str, size: int = 20) -> str:
|
|
| 587 |
def create_icon_array_html(probability: float, outcome: str) -> str:
|
| 588 |
"""
|
| 589 |
Build an HTML icon array card for a single outcome.
|
| 590 |
-
100
|
|
|
|
| 591 |
Returns a self-contained HTML string suitable for gr.HTML.
|
| 592 |
"""
|
| 593 |
cfg = OUTCOME_CONFIG.get(outcome, {
|
| 594 |
"event_label": "Event",
|
| 595 |
"no_event_label": "No Event",
|
| 596 |
-
"event_color": "#e63946",
|
| 597 |
-
"no_event_color": "#adb5bd",
|
| 598 |
"title": OUTCOME_DESCRIPTIONS.get(outcome, outcome),
|
| 599 |
})
|
| 600 |
|
|
@@ -602,47 +576,51 @@ def create_icon_array_html(probability: float, outcome: str) -> str:
|
|
| 602 |
n_no_event = 100 - n_event
|
| 603 |
pct_str = f"{probability * 100:.1f}%"
|
| 604 |
|
| 605 |
-
# Build icon
|
| 606 |
-
|
| 607 |
-
for
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 612 |
|
| 613 |
html = f"""
|
| 614 |
<div style="
|
| 615 |
background: #ffffff;
|
| 616 |
border: 1px solid #e0e0e0;
|
| 617 |
border-radius: 10px;
|
| 618 |
-
padding:
|
| 619 |
text-align: center;
|
| 620 |
font-family: 'Segoe UI', Arial, sans-serif;
|
| 621 |
box-shadow: 0 2px 6px rgba(0,0,0,0.07);
|
| 622 |
-
height: 100%;
|
| 623 |
box-sizing: border-box;
|
| 624 |
">
|
| 625 |
-
<div style="font-size:
|
| 626 |
{cfg['title']}
|
| 627 |
</div>
|
| 628 |
-
<div style="font-size:
|
| 629 |
{pct_str}
|
| 630 |
</div>
|
| 631 |
-
<div style="
|
| 632 |
-
{
|
| 633 |
</div>
|
| 634 |
-
<div style="font-size:
|
| 635 |
-
<span>
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
{cfg['event_label']} ({n_event}/100)
|
| 640 |
</span>
|
| 641 |
-
<span>
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
{cfg['no_event_label']} ({n_no_event}/100)
|
| 646 |
</span>
|
| 647 |
</div>
|
| 648 |
</div>
|
|
@@ -663,19 +641,18 @@ def create_all_icon_arrays(calibrated_probs: dict) -> dict:
|
|
| 663 |
for row_start in range(0, len(pie_outcomes), 4):
|
| 664 |
row_outcomes = pie_outcomes[row_start: row_start + 4]
|
| 665 |
cols_html = "".join(
|
| 666 |
-
f'<div style="flex:
|
| 667 |
for o in row_outcomes
|
| 668 |
)
|
| 669 |
-
rows_html += f""
|
| 670 |
-
<div style="display: flex; gap: 12px; margin-bottom: 12px;">
|
| 671 |
-
{cols_html}
|
| 672 |
-
</div>"""
|
| 673 |
|
| 674 |
grid_html = f"""
|
| 675 |
-
<div style="font-family:
|
| 676 |
{rows_html}
|
| 677 |
-
<div style="font-size:
|
| 678 |
Each figure represents 1 patient out of 100 with similar characteristics.
|
|
|
|
|
|
|
| 679 |
</div>
|
| 680 |
</div>
|
| 681 |
"""
|
|
|
|
| 519 |
# Icon Array (replaces Pie Charts)
|
| 520 |
# ---------------------------------------------------------------------------
|
| 521 |
|
| 522 |
+
# Uniform colors for all outcomes: red = event, green = no event
|
| 523 |
+
EVENT_COLOR = "#e53935" # red
|
| 524 |
+
NO_EVENT_COLOR = "#43a047" # green
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
|
| 526 |
OUTCOME_CONFIG = {
|
| 527 |
+
"DEAD": {"event_label": "Death", "no_event_label": "Survival", "title": "Total Mortality"},
|
| 528 |
+
"GF": {"event_label": "Graft Failure", "no_event_label": "No Graft Failure", "title": "Graft Failure"},
|
| 529 |
+
"AGVHD": {"event_label": "Acute GVHD", "no_event_label": "No Acute GVHD", "title": "Acute GvHD"},
|
| 530 |
+
"CGVHD": {"event_label": "Chronic GVHD", "no_event_label": "No Chronic GVHD", "title": "Chronic GvHD"},
|
| 531 |
+
"VOCPSHI": {"event_label": "VOC Post-HCT", "no_event_label": "No VOC", "title": "Vaso-Occlusive Crisis"},
|
| 532 |
+
"STROKEHI": {"event_label": "Stroke Post-HCT", "no_event_label": "No Stroke", "title": "Stroke Post-HCT"},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
}
|
| 534 |
|
| 535 |
|
| 536 |
+
def _stick_figure_svg(color: str, size: int = 20) -> str:
|
| 537 |
+
"""
|
| 538 |
+
Return an inline SVG stick figure (head + body + arms + legs).
|
| 539 |
+
ViewBox is 0 0 20 32 so the figure is taller than wide.
|
| 540 |
+
stroke-width is scaled so figures look clean at small sizes.
|
| 541 |
+
"""
|
| 542 |
+
sw = 2.2 # stroke width
|
| 543 |
+
# head: circle centred at (10, 5) r=4
|
| 544 |
+
# body: (10,9) → (10,20)
|
| 545 |
+
# arms: (3,13) → (17,13)
|
| 546 |
+
# left leg: (10,20) → (4,30)
|
| 547 |
+
# right leg: (10,20) → (16,30)
|
| 548 |
return (
|
| 549 |
f'<svg xmlns="http://www.w3.org/2000/svg" '
|
| 550 |
+
f'width="{size}" height="{round(size * 1.6)}" viewBox="0 0 20 32" '
|
| 551 |
+
f'style="display:inline-block;vertical-align:bottom;" '
|
| 552 |
+
f'stroke="{color}" stroke-width="{sw}" stroke-linecap="round" fill="none">'
|
| 553 |
+
f'<circle cx="10" cy="5" r="3.8" fill="{color}" stroke="none"/>'
|
| 554 |
+
f'<line x1="10" y1="9" x2="10" y2="20"/>'
|
| 555 |
+
f'<line x1="3" y1="13" x2="17" y2="13"/>'
|
| 556 |
+
f'<line x1="10" y1="20" x2="4" y2="30"/>'
|
| 557 |
+
f'<line x1="10" y1="20" x2="16" y2="30"/>'
|
| 558 |
f'</svg>'
|
| 559 |
)
|
| 560 |
|
|
|
|
| 562 |
def create_icon_array_html(probability: float, outcome: str) -> str:
|
| 563 |
"""
|
| 564 |
Build an HTML icon array card for a single outcome.
|
| 565 |
+
100 stick-figure icons in a 10×10 grid.
|
| 566 |
+
Red = event, green = no event (uniform across all outcomes).
|
| 567 |
Returns a self-contained HTML string suitable for gr.HTML.
|
| 568 |
"""
|
| 569 |
cfg = OUTCOME_CONFIG.get(outcome, {
|
| 570 |
"event_label": "Event",
|
| 571 |
"no_event_label": "No Event",
|
|
|
|
|
|
|
| 572 |
"title": OUTCOME_DESCRIPTIONS.get(outcome, outcome),
|
| 573 |
})
|
| 574 |
|
|
|
|
| 576 |
n_no_event = 100 - n_event
|
| 577 |
pct_str = f"{probability * 100:.1f}%"
|
| 578 |
|
| 579 |
+
# Build icon grid: 10 per row, event (red) figures first, then no-event (green)
|
| 580 |
+
rows_parts = []
|
| 581 |
+
for row in range(10):
|
| 582 |
+
row_html = '<div style="display:flex;justify-content:center;gap:1px;margin-bottom:1px;">'
|
| 583 |
+
for col in range(10):
|
| 584 |
+
idx = row * 10 + col
|
| 585 |
+
color = EVENT_COLOR if idx < n_event else NO_EVENT_COLOR
|
| 586 |
+
row_html += _stick_figure_svg(color, size=16)
|
| 587 |
+
row_html += '</div>'
|
| 588 |
+
rows_parts.append(row_html)
|
| 589 |
+
grid_html = "\n".join(rows_parts)
|
| 590 |
+
|
| 591 |
+
legend_figure_event = _stick_figure_svg(EVENT_COLOR, size=12)
|
| 592 |
+
legend_figure_no_event = _stick_figure_svg(NO_EVENT_COLOR, size=12)
|
| 593 |
|
| 594 |
html = f"""
|
| 595 |
<div style="
|
| 596 |
background: #ffffff;
|
| 597 |
border: 1px solid #e0e0e0;
|
| 598 |
border-radius: 10px;
|
| 599 |
+
padding: 12px 10px 10px 10px;
|
| 600 |
text-align: center;
|
| 601 |
font-family: 'Segoe UI', Arial, sans-serif;
|
| 602 |
box-shadow: 0 2px 6px rgba(0,0,0,0.07);
|
|
|
|
| 603 |
box-sizing: border-box;
|
| 604 |
">
|
| 605 |
+
<div style="font-size: 12px; font-weight: 700; color: #222; margin-bottom: 4px; line-height: 1.3;">
|
| 606 |
{cfg['title']}
|
| 607 |
</div>
|
| 608 |
+
<div style="font-size: 24px; font-weight: 800; color: {EVENT_COLOR}; margin-bottom: 6px;">
|
| 609 |
{pct_str}
|
| 610 |
</div>
|
| 611 |
+
<div style="margin-bottom: 6px;">
|
| 612 |
+
{grid_html}
|
| 613 |
</div>
|
| 614 |
+
<div style="font-size: 10.5px; color: #444; display:flex; justify-content:center; gap:12px; flex-wrap:wrap;">
|
| 615 |
+
<span style="display:inline-flex;align-items:center;gap:3px;">
|
| 616 |
+
{legend_figure_event}
|
| 617 |
+
<span style="color:{EVENT_COLOR};font-weight:600;">{cfg['event_label']}</span>
|
| 618 |
+
({n_event}/100)
|
|
|
|
| 619 |
</span>
|
| 620 |
+
<span style="display:inline-flex;align-items:center;gap:3px;">
|
| 621 |
+
{legend_figure_no_event}
|
| 622 |
+
<span style="color:{NO_EVENT_COLOR};font-weight:600;">{cfg['no_event_label']}</span>
|
| 623 |
+
({n_no_event}/100)
|
|
|
|
| 624 |
</span>
|
| 625 |
</div>
|
| 626 |
</div>
|
|
|
|
| 641 |
for row_start in range(0, len(pie_outcomes), 4):
|
| 642 |
row_outcomes = pie_outcomes[row_start: row_start + 4]
|
| 643 |
cols_html = "".join(
|
| 644 |
+
f'<div style="flex:1 1 0;min-width:0;">{cards[o]}</div>'
|
| 645 |
for o in row_outcomes
|
| 646 |
)
|
| 647 |
+
rows_html += f'<div style="display:flex;gap:10px;margin-bottom:10px;">{cols_html}</div>'
|
|
|
|
|
|
|
|
|
|
| 648 |
|
| 649 |
grid_html = f"""
|
| 650 |
+
<div style="font-family:'Segoe UI',Arial,sans-serif;padding:4px 0;">
|
| 651 |
{rows_html}
|
| 652 |
+
<div style="font-size:11px;color:#888;text-align:center;margin-top:2px;">
|
| 653 |
Each figure represents 1 patient out of 100 with similar characteristics.
|
| 654 |
+
<span style="color:{EVENT_COLOR};font-weight:600;">■ Red = event</span>
|
| 655 |
+
<span style="color:{NO_EVENT_COLOR};font-weight:600;">■ Green = no event</span>
|
| 656 |
</div>
|
| 657 |
</div>
|
| 658 |
"""
|