Final-CF / app.py
shivapriyasom's picture
Update app.py
d0d0725 verified
Raw
History Blame Contribute Delete
70.6 kB
import gradio as gr
import pandas as pd
import numpy as np
import copy
import traceback
from inference import (
FEATURE_NAMES,
REPORTING_OUTCOMES,
OUTCOME_DESCRIPTIONS,
OUTCOMES,
SHAP_OUTCOMES,
predict_with_comparison,
predict_all_outcomes,
create_all_shap_plots,
create_all_icon_arrays,
)
# ─────────────────────────────────────────────────────────────────────────────
# CHOICES / CONSTANTS
# ─────────────────────────────────────────────────────────────────────────────
# Answer options for each categorical input. make_component passes these lists
# to the Dropdown/Radio widgets, so list order is display order; the string
# values themselves are what end up in the feature dict sent to `inference`.
# NOTE(review): some labels (e.g. "β‰₯ 90", "≀ 6/8", "β‰₯ 3/yr") look like
# mis-encoded "≥"/"≤" characters. They are kept verbatim because other code
# (HLA helpers, and presumably the model's category encoders in `inference`)
# may match these exact strings — confirm before normalising them.
AGEGPFF_CHOICES = ["<=10", "11-17", "18-29", "30-49", ">=50"]
SEX_CHOICES = ["Male", "Female"]
KPS_CHOICES = ["<90", "β‰₯ 90"]
DONORF_CHOICES = [
    "HLA identical sibling", "HLA mismatch relative",
    "Matched unrelated donor", "Mismatched unrelated donor or cord blood",
]
GRAFTYPE_CHOICES = ["Bone marrow", "Peripheral blood", "Cord blood"]
CONDGRPF_CHOICES = ["MAC", "RIC", "NMA"]
CONDGRP_FINAL_CHOICES = [
    "TBI/Cy", "TBI/Cy/Flu", "TBI/Cy/Flu/TT", "TBI/Mel", "TBI/Flu",
    "TBI alone (300/400/600cGy)", "Bu/Cy", "Bu/Mel", "Flu/Bu/TT",
    "Flu/Bu", "Flu/Mel/TT", "Flu/Mel", "Cy/Flu", "Treosulfan",
    "Cy alone", "Flud", "TLI",
]
# "None" is a literal option string (not Python's None), so _check_missing
# does not treat it as a blank field.
ATGF_CHOICES = ["ATG", "Alemtuzumab", "None"]
GVHD_FINAL_CHOICES = [
    "Ex-vivo T-cell depletion", "CD34 selection", "Post-CY + siro +- MMF",
    "Post-CY + MMF + CNI", "CNI + MMF", "CNI + MTX", "CNI alone",
    "CNI + siro", "Siro alone", "MMF + MTX", "MMF + siro", "MMF alone",
    "MTX alone", "MTX + siro",
]
HLA_FINAL_CHOICES = ["8/8", "7/8", "≀ 6/8"]
RCMVPR_CHOICES = ["Negative", "Positive"]
EXCHTFPR_CHOICES = ["No", "Yes"]
VOC2YPR_CHOICES = ["No", "Yes"]
VOCFRQPR_CHOICES = ["< 3/yr", "β‰₯ 3/yr"]
SCATXRSN_CHOICES = [
    "CNS event", "Acute chest Syndrome", "Recurrent vaso-occlusive pain",
    "Recurrent priapism", "Excessive transfusion requirements/iron overload",
    "Cardio-pulmonary", "Chronic transfusion", "Asymptomatic",
    "Renal insufficiency", "Splenic sequestration", "Avascular necrosis",
    "Hodgkin lymphoma",
]
# Features entered as numbers; _values_to_dict coerces these to float and
# passes every other feature through unchanged.
NUM_COLS_SET = {"AGE", "NACS2YR"}
# Upper bound of the "how many counterfactual scenarios" slider.
MAX_SCENARIOS = 5
# Options for the "Published conditioning regimen" dropdown as (label, value)
# pairs. Values starting with "__header_" are group headings rather than
# regimens; apply_grouped_preset resets the dropdown when one is picked.
GROUPED_REGIMEN_CHOICES = [
    ("── HLA IDENTICAL ──", "__header_hla_identical__"),
    ("Hsieh et al 2014", "Hsieh et al 2014"),
    ("Krishnamurti et al 2019", "Krishnamurti et al 2019"),
    ("King et al 2015", "King et al 2015"),
    ("Walters et al 1996", "Walters et al 1996"),
    ("── HLA MISMATCHED ──", "__header_hla_mismatched__"),
    ("Bolanos-Meade et al 2022 (HLA Mismatch)", "Bolanos-Meade et al 2022 (HLA Mismatch)"),
    ("Patel et al 2020 (HLA Mismatch)", "Patel et al 2020 (HLA Mismatch)"),
    ("── MATCHED UNRELATED ──", "__header_matched_unrelated__"),
    ("L Krishnamurti et al 2019", "L Krishnamurti et al 2019"),
    ("Shenoy et al 2016", "Shenoy et al 2016"),
    ("── MISMATCHED UNRELATED / CORD BLOOD ──", "__header_mismatched_cord__"),
    ("Bolanos-Meade et al 2022 (Mismatched/Cord)", "Bolanos-Meade et al 2022 (Mismatched/Cord)"),
    ("Patel et al 2020 (Mismatched/Cord)", "Patel et al 2020 (Mismatched/Cord)"),
    ("── CUSTOM ──", "__header_custom__"),
    ("Custom", "Custom"),
]
# Values of the heading entries above (non-selectable group titles).
HEADER_VALUES = {v for _, v in GROUPED_REGIMEN_CHOICES if v.startswith("__header_")}
# Transplant fields that apply_grouped_preset fills in (and locks) for each
# published regimen, keyed by the non-heading dropdown values other than
# "Custom".
PUBLISHED_PRESETS = {
    "Hsieh et al 2014": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI alone (300/400/600cGy)", "ATGF": "Alemtuzumab", "GVHD_FINAL": "Siro alone", "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"},
    "Krishnamurti et al 2019": {"CONDGRPF": "MAC", "CONDGRP_FINAL": "Flu/Bu", "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"},
    "King et al 2015": {"CONDGRPF": "RIC", "CONDGRP_FINAL": "Flu/Mel", "ATGF": "Alemtuzumab", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"},
    "Walters et al 1996": {"CONDGRPF": "MAC", "CONDGRP_FINAL": "Bu/Cy", "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"},
    "Bolanos-Meade et al 2022 (HLA Mismatch)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "HLA mismatch relative"},
    "Patel et al 2020 (HLA Mismatch)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu/TT", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "HLA mismatch relative"},
    "L Krishnamurti et al 2019": {"CONDGRPF": "MAC", "CONDGRP_FINAL": "Flu/Bu", "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "Matched unrelated donor"},
    "Shenoy et al 2016": {"CONDGRPF": "RIC", "CONDGRP_FINAL": "Flu/Mel", "ATGF": "Alemtuzumab", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "Matched unrelated donor"},
    "Bolanos-Meade et al 2022 (Mismatched/Cord)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "Mismatched unrelated donor or cord blood"},
    "Patel et al 2020 (Mismatched/Cord)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu/TT", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "Mismatched unrelated donor or cord blood"},
}
# Donor-type groups for the HLA rules: sibling and matched-unrelated donors
# are fixed at 8/8, while the two mismatched donor types may not be 8/8.
MSD_DONOR = "HLA identical sibling"
MUD_DONOR = "Matched unrelated donor"
LOCKED_88_DONORS = {MSD_DONOR, MUD_DONOR}
NON_88_DONORS = {"HLA mismatch relative", "Mismatched unrelated donor or cord blood"}
# Published regimens grouped by donor category.
# NOTE(review): not referenced in the visible part of this file — presumably
# used further down; confirm before removing.
MSD_REGIMENS = {"Hsieh et al 2014", "Krishnamurti et al 2019", "King et al 2015", "Walters et al 1996"}
MMUD_REGIMENS = {"Bolanos-Meade et al 2022 (Mismatched/Cord)", "Patel et al 2020 (Mismatched/Cord)",
                 "Bolanos-Meade et al 2022 (HLA Mismatch)", "Patel et al 2020 (HLA Mismatch)"}
# Input features grouped by UI column. ALL_FEATURES fixes the positional order
# in which component values reach callbacks (inputs_list is built from it) and
# are zipped back into a dict by _values_to_dict.
PATIENT_FEATURES = ["AGE", "AGEGPFF", "SEX", "KPS", "RCMVPR"]
DONOR_FEATURES = ["DONORF", "GRAFTYPE", "HLA_FINAL", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL"]
DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
ALL_FEATURES = PATIENT_FEATURES + DONOR_FEATURES + DISEASE_FEATURES
# Icon array outcomes in display order
ICON_OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"]
# Short titles shown on icon-array cards and carousel headers.
OUTCOME_TITLES = {
    "DEAD": "Death",
    "GF": "Graft Failure",
    "AGVHD": "Acute GvHD",
    "CGVHD": "Chronic GvHD",
    "VOCPSHI": "Vaso-Occlusive Crisis",
    "STROKEHI": "Stroke Post-HCT",
}
# Figure colours for "event" (red) and "no event" (green) in the icon arrays.
EVENT_COLOR = "#e53935"
NO_EVENT_COLOR = "#43a047"
# Paging order of the SHAP carousel and the labels used in its breadcrumb.
SHAP_ORDER = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "EFS", "STROKEHI", "OS"]
SHAP_LABELS = {
    "DEAD": "Death",
    "GF": "Graft Failure",
    "AGVHD": "Acute GvHD",
    "CGVHD": "Chronic GvHD",
    "VOCPSHI": "Vaso-Occlusive Crisis Post-HCT",
    "EFS": "Event-Free Survival",
    "STROKEHI": "Stroke Post-HCT",
    "OS": "Overall Survival",
}
# One accent colour per counterfactual scenario; the same colours appear in
# the .scenario-panel-0 … .scenario-panel-4 CSS rules.
SCENARIO_COLORS = ["#e65100", "#6a1b9a", "#1b5e20", "#0d47a1", "#b71c1c"]
# ─────────────────────────────────────────────────────────────────────────────
# COMPONENT FACTORY
# ─────────────────────────────────────────────────────────────────────────────
# Declarative spec for every input widget, consumed by make_component:
#   "label"   - widget label (make_component may append a suffix)
#   "type"    - "number", "dropdown" or "radio"
#   "choices" - option list for dropdown/radio widgets
#   "kwargs"  - extra widget keyword arguments; for dropdown/radio the
#               "interactive" and "info" entries are popped out and passed
#               explicitly, everything else is forwarded as-is.
_FEATURE_META = {
    "AGE": {"label": "Age (years)", "type": "number", "kwargs": {"minimum": 0, "maximum": 100, "step": 1}},
    # Read-only; the UI wires AGE.change -> get_age_group to fill it in.
    "AGEGPFF": {"label": "Age Group", "type": "dropdown", "choices": AGEGPFF_CHOICES, "kwargs": {"interactive": False, "info": "Auto-filled from Age"}},
    "SEX": {"label": "Sex", "type": "radio", "choices": SEX_CHOICES},
    "KPS": {"label": "Karnofsky Performance Status", "type": "radio", "choices": KPS_CHOICES},
    "RCMVPR": {"label": "Recipient CMV Status", "type": "radio", "choices": RCMVPR_CHOICES},
    "DONORF": {"label": "Donor Type", "type": "dropdown", "choices": DONORF_CHOICES},
    "GRAFTYPE": {"label": "Graft Type", "type": "radio", "choices": GRAFTYPE_CHOICES},
    "HLA_FINAL": {"label": "HLA Matching", "type": "radio", "choices": HLA_FINAL_CHOICES},
    "CONDGRPF": {"label": "Conditioning Intensity", "type": "radio", "choices": CONDGRPF_CHOICES},
    "CONDGRP_FINAL": {"label": "Conditioning Regimen", "type": "dropdown", "choices": CONDGRP_FINAL_CHOICES},
    "ATGF": {"label": "Serotherapy", "type": "radio", "choices": ATGF_CHOICES},
    "GVHD_FINAL": {"label": "GvHD Prophylaxis", "type": "dropdown", "choices": GVHD_FINAL_CHOICES},
    "NACS2YR": {"label": "Acute Chest Syndrome Episodes (2 yrs pre-HCT)", "type": "number", "kwargs": {"minimum": 0, "maximum": 50, "step": 1}},
    "EXCHTFPR": {"label": "Exchange Transfusion Pre-HCT", "type": "radio", "choices": EXCHTFPR_CHOICES},
    "VOC2YPR": {"label": "Vaso-Occlusive Crisis in 2 yrs Pre-HCT", "type": "radio", "choices": VOC2YPR_CHOICES},
    "VOCFRQPR": {"label": "VoC Frequency Pre-HCT", "type": "radio", "choices": VOCFRQPR_CHOICES},
    "SCATXRSN": {"label": "Primary Reason for HCT", "type": "dropdown", "choices": SCATXRSN_CHOICES},
}
def make_component(feature: str, suffix: str = ""):
    """Build the Gradio input widget described by ``_FEATURE_META[feature]``.

    Args:
        feature: Key into ``_FEATURE_META``.
        suffix: Optional text appended (after a space) to the widget label.

    Returns:
        A ``gr.Number``, ``gr.Dropdown`` or ``gr.Radio`` with no initial value.

    Raises:
        KeyError: If *feature* has no entry in ``_FEATURE_META``.
        ValueError: If the entry names an unsupported widget type.
    """
    spec = _FEATURE_META[feature]
    base_label = spec["label"]
    label = f"{base_label} {suffix}".strip() if suffix else base_label
    widget_kind = spec["type"]
    # Copy so the pop() calls below never mutate the shared spec dict.
    extra = dict(spec.get("kwargs", {}))

    if widget_kind == "number":
        return gr.Number(label=label, value=None, **extra)

    choice_widgets = {"dropdown": gr.Dropdown, "radio": gr.Radio}
    if widget_kind not in choice_widgets:
        raise ValueError(f"Unknown component type '{widget_kind}' for feature '{feature}'")

    interactive = extra.pop("interactive", True)
    info = extra.pop("info", None)
    widget_cls = choice_widgets[widget_kind]
    return widget_cls(label=label, choices=spec["choices"], value=None,
                      interactive=interactive, info=info, **extra)
# ─────────────────────────────────────────────────────────────────────────────
# CONSTRAINT / VALIDATION HELPERS
# ─────────────────────────────────────────────────────────────────────────────
def _hla_update_for_donor(donor_value):
    """Return a ``gr.update`` that keeps the HLA radio consistent with the donor type.

    - Sibling / matched-unrelated donor: only "8/8" is offered; it is selected and locked.
    - Mismatched relative / mismatched-unrelated-or-cord donor: every option except
      "8/8" is offered and the current selection is cleared.
    - No donor yet, or any other value: all options are offered and the current
      selection is left untouched.
    """
    if donor_value and donor_value in LOCKED_88_DONORS:
        return gr.update(choices=["8/8"], value="8/8", interactive=False)
    if donor_value and donor_value in NON_88_DONORS:
        # HLA_FINAL_CHOICES[1:] is every option after "8/8".
        return gr.update(choices=HLA_FINAL_CHOICES[1:], value=None, interactive=True)
    return gr.update(choices=HLA_FINAL_CHOICES, interactive=True)
def _validate_counterfactual_constraints(base_dict, wi_dict, label=""):
violations = []
tag = f"[{label}] " if label else ""
if base_dict.get("SEX") and wi_dict.get("SEX"):
if base_dict["SEX"] != wi_dict["SEX"]:
violations.append(f"{tag} Immutable feature: Sex cannot be changed.")
base_age, wi_age = base_dict.get("AGE"), wi_dict.get("AGE")
if base_age is not None and wi_age is not None:
try:
if float(wi_age) < float(base_age):
violations.append(f"{tag}Age cannot be decreased ({base_age} β†’ {wi_age}).")
except (TypeError, ValueError):
pass
wi_donor = wi_dict.get("DONORF")
wi_hla = wi_dict.get("HLA_FINAL")
if wi_donor and wi_hla:
if wi_donor in LOCKED_88_DONORS and wi_hla != "8/8":
violations.append(f"{tag}HLA constraint: '{wi_donor}' requires 8/8 HLA.")
elif wi_donor in NON_88_DONORS and wi_hla == "8/8":
violations.append(f"{tag}HLA constraint: '{wi_donor}' cannot have 8/8 HLA.")
if wi_donor:
wi_gvhd = wi_dict.get("GVHD_FINAL", "")
if wi_donor == MSD_DONOR and wi_gvhd in {"Post-CY + siro +- MMF", "Post-CY + MMF + CNI"}:
violations.append(
f"{tag} Post-Cy GVHD prophylaxis is inconsistent with donor '{MSD_DONOR}'."
)
return violations
def _values_to_dict(values):
    """Pair positional component values with their ALL_FEATURES names.

    Features in NUM_COLS_SET are coerced to float; blanks (None or "") and
    values that cannot be parsed become None. All other values pass through
    unchanged. Pairing uses zip(), so surplus values are ignored and, if too
    few are given, the trailing features are simply absent from the result.
    """
    def _as_float(raw):
        try:
            return None if raw in (None, "") else float(raw)
        except (TypeError, ValueError):
            return None

    return {
        name: (_as_float(raw) if name in NUM_COLS_SET else raw)
        for name, raw in zip(ALL_FEATURES, values)
    }
def _check_missing(user_vals, label=""):
missing = [f for f, v in user_vals.items()
if v is None or v == "" or (isinstance(v, float) and pd.isna(v))]
if missing:
raise ValueError(
f"{'[' + label + '] ' if label else ''}Please fill in all fields. "
f"Missing: {', '.join(missing)}"
)
def get_age_group(age):
    """Map an age in years to its AGEGPFF bucket label.

    Buckets use inclusive upper bounds: <=10, 11-17, 18-29, 30-49, >=50.
    Fractional ages move to the next bucket once they pass an edge
    (e.g. 10.5 -> "11-17").

    Returns "" when *age* is blank (None or ""), not numeric, or NaN. A NaN
    fails every "<=" comparison and would otherwise fall through to ">=50".
    """
    import math

    if age is None or age == "":
        return ""
    try:
        years = float(age)
    except (ValueError, TypeError):
        return ""
    if math.isnan(years):
        return ""
    for upper_bound, group in ((10, "<=10"), (17, "11-17"), (29, "18-29"), (49, "30-49")):
        if years <= upper_bound:
            return group
    return ">=50"
def vocfrqpr_from_voc2ypr(voc_status):
    """Return the VoC-frequency radio state implied by the "VoC in 2 yrs" answer.

    "No" sets the frequency to "< 3/yr" and locks the radio; any other answer
    (including blank) clears the frequency and makes the radio editable.
    """
    no_crisis = voc_status == "No"
    return gr.update(
        value="< 3/yr" if no_crisis else None,
        interactive=not no_crisis,
    )
def apply_grouped_preset(selected_value):
    """Map a "Published conditioning regimen" selection to widget updates.

    The returned 7-tuple lines up with this handler's outputs:
    (regimen dropdown, DONORF, CONDGRPF, CONDGRP_FINAL, ATGF, GVHD_FINAL, HLA_FINAL).

    - Blank value or a group heading: reset the dropdown to None and leave the
      six transplant fields untouched.
    - "Custom": make the six fields interactive again without changing values.
    - A key of PUBLISHED_PRESETS: fill the fields from the preset and lock them;
      the HLA radio's choices follow the preset's donor type.
    - Anything else: change nothing.
    """
    n_fields = 6

    if not selected_value or selected_value in HEADER_VALUES:
        return (gr.update(value=None), *([gr.update()] * n_fields))

    if selected_value == "Custom":
        unlocked = [gr.update(interactive=True) for _ in range(n_fields)]
        return (gr.update(), *unlocked)

    preset = PUBLISHED_PRESETS.get(selected_value)
    if not preset:
        return tuple([gr.update()] * (n_fields + 1))

    donor = preset["DONORF"]
    if donor in LOCKED_88_DONORS:
        hla_choices = ["8/8"]
    elif donor in NON_88_DONORS:
        # Every HLA option after "8/8".
        hla_choices = HLA_FINAL_CHOICES[1:]
    else:
        hla_choices = HLA_FINAL_CHOICES
    hla_update = gr.update(value=preset["HLA_FINAL"], interactive=False, choices=hla_choices)

    locked_fields = [
        gr.update(value=preset[key], interactive=False)
        for key in ("DONORF", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL")
    ]
    return (gr.update(), *locked_fields, hla_update)
def lock_sex(baseline_sex):
    """Freeze a Sex radio, since sex is treated as immutable across scenarios.

    When a baseline sex is given it is copied into the field. Otherwise the
    field is only made non-interactive and keeps its current value.
    """
    if not baseline_sex:
        return gr.update(interactive=False)
    return gr.update(value=baseline_sex, interactive=False)
# ─────────────────────────────────────────────────────────────────────────────
# ICON ARRAY HTML RENDERERS (pure Python, no JS)
# ─────────────────────────────────────────────────────────────────────────────
def _stick_figure_svg(color, size=16):
h = round(size * 1.6)
return (
f'<svg xmlns="http://www.w3.org/2000/svg" width="{size}" height="{h}" '
f'viewBox="0 0 20 32" style="display:block;flex-shrink:0;" '
f'stroke="{color}" stroke-width="2.2" stroke-linecap="round" fill="none">'
f'<circle cx="10" cy="5" r="3.8" fill="{color}" stroke="none"/>'
f'<line x1="10" y1="9" x2="10" y2="20"/>'
f'<line x1="3" y1="13" x2="17" y2="13"/>'
f'<line x1="10" y1="20" x2="4" y2="30"/>'
f'<line x1="10" y1="20" x2="16" y2="30"/>'
f'</svg>'
)
def _render_single_icon_card(probability, outcome, panel_label="", panel_color="#1565c0"):
    """Render a single icon array card as HTML (no JS). Used for Python-driven carousel.

    Parameters
    ----------
    probability : float
        Event probability, nominally in [0, 1].
    outcome : str
        Outcome key, used to look up the card title.
    panel_label : str, optional
        When non-empty, shown as a coloured badge above the title.
    panel_color : str, optional
        Badge background colour.

    Returns
    -------
    str
        Card HTML: title, percentage, a 10x10 grid of stick figures (red =
        event, green = no event) and a legend with the counts. The red count is
        the probability rounded to a whole percent and clamped to 0..100, so an
        out-of-range value cannot yield a negative "No Event" count. A missing,
        NaN or infinite probability shows "N/A" with an all-green grid instead
        of raising.
    """
    import math

    title = OUTCOME_TITLES.get(outcome, OUTCOME_DESCRIPTIONS.get(outcome, outcome))
    if probability is None or not math.isfinite(probability):
        n_event = 0
        pct_str = "N/A"
    else:
        # int() keeps the counts printing as integers; the clamp keeps the
        # 100-figure grid and legend consistent.
        n_event = min(100, max(0, int(round(probability * 100))))
        pct_str = f"{probability * 100:.1f}%"
    n_no_event = 100 - n_event

    # Two figure variants, built once and reused for all 100 cells.
    cell_event = _stick_figure_svg(EVENT_COLOR, size=13)
    cell_no_event = _stick_figure_svg(NO_EVENT_COLOR, size=13)
    rows_parts = []
    for row in range(10):
        cells = "".join(
            cell_event if row * 10 + col < n_event else cell_no_event
            for col in range(10)
        )
        rows_parts.append(
            f'<div style="display:flex;justify-content:center;gap:1px;margin-bottom:1px;">{cells}</div>'
        )
    grid_html = "\n".join(rows_parts)

    fig_e = _stick_figure_svg(EVENT_COLOR, size=11)
    fig_ne = _stick_figure_svg(NO_EVENT_COLOR, size=11)
    legend = (
        f'<div style="display:inline-grid;grid-template-columns:13px 1fr 36px;'
        f'align-items:center;gap:3px;row-gap:3px;">'
        f'{fig_e}<span style="color:{EVENT_COLOR};font-weight:700;font-size:9px;">Event</span>'
        f'<span style="color:#888;font-size:8px;">({n_event}/100)</span>'
        f'{fig_ne}<span style="color:{NO_EVENT_COLOR};font-weight:700;font-size:9px;">No Event</span>'
        f'<span style="color:#888;font-size:8px;">({n_no_event}/100)</span>'
        f'</div>'
    )
    badge = (
        f'<div style="background:{panel_color};color:#fff;font-size:8px;font-weight:700;'
        f'border-radius:3px;padding:1px 5px;margin-bottom:2px;display:inline-block;">'
        f'{panel_label}</div>'
    ) if panel_label else ""
    return (
        f'<div style="background:#fff;border:1px solid #e0e0e0;border-radius:7px;'
        f'padding:10px 8px;text-align:center;font-family:\'Segoe UI\',Arial,sans-serif;'
        f'box-shadow:0 2px 4px rgba(0,0,0,0.06);box-sizing:border-box;max-width:340px;'
        f'margin:0 auto;display:flex;flex-direction:column;align-items:center;">'
        f'{badge}'
        f'<div style="min-height:30px;display:flex;align-items:center;justify-content:center;'
        f'font-size:13px;font-weight:700;color:#222;line-height:1.3;margin-bottom:4px;">{title}</div>'
        f'<div style="font-size:22px;font-weight:800;color:{EVENT_COLOR};'
        f'line-height:1;margin-bottom:6px;">{pct_str}</div>'
        f'<div style="margin-bottom:4px;">{grid_html}</div>'
        f'<div>{legend}</div>'
        f'</div>'
    )
def _render_icon_carousel_page(probs_dict, idx):
    """Render the full carousel HTML for a given outcome index (no JS).

    Parameters
    ----------
    probs_dict : dict
        Outcome key -> probability; outcomes not present render as 0.0.
    idx : int
        Page position in ICON_OUTCOMES. Clamped into range, so a stale or
        out-of-range index shows the nearest page with a matching "n / total"
        counter instead of raising IndexError or wrapping from the end.

    Returns
    -------
    str
        HTML with the outcome title, page counter, breadcrumb dots, the icon
        card and a colour-legend footnote.
    """
    total = len(ICON_OUTCOMES)
    idx = max(0, min(int(idx), total - 1))
    outcome = ICON_OUTCOMES[idx]
    card_html = _render_single_icon_card(probs_dict.get(outcome, 0.0), outcome)
    label = OUTCOME_TITLES.get(outcome, outcome)

    # Breadcrumb dots: the current page gets a larger blue dot.
    dot_parts = []
    for i in range(total):
        active = i == idx
        dot_size = "10" if active else "7"
        dot_color = "#1565c0" if active else "#ccc"
        dot_parts.append(
            f'<span style="display:inline-block;width:{dot_size}px;'
            f'height:{dot_size}px;border-radius:50%;'
            f'background:{dot_color};'
            f'margin:0 3px;vertical-align:middle;"></span>'
        )
    dots = "".join(dot_parts)

    footnote = (
        f'<div style="font-size:10px;color:#888;text-align:center;margin-top:10px;">'
        f'Each figure = 1 patient out of 100. '
        f'<span style="color:{EVENT_COLOR};font-weight:600;">&#9632; Red = Event</span> &nbsp; '
        f'<span style="color:{NO_EVENT_COLOR};font-weight:600;">&#9632; Green = No Event</span>'
        f'</div>'
    )
    return (
        f'<div style="font-family:\'Segoe UI\',Arial,sans-serif;padding:4px 0;text-align:center;">'
        f'<div style="font-size:15px;font-weight:700;color:#1565c0;margin-bottom:2px;">{label}</div>'
        f'<div style="font-size:11px;color:#888;margin-bottom:10px;">{idx+1} / {total}</div>'
        f'<div style="margin-bottom:2px;">{dots}</div>'
        f'<div style="margin-top:10px;">{card_html}</div>'
        f'{footnote}'
        f'</div>'
    )
def _render_comparison_icon_page(all_probs_list, labels, colors, idx):
    """Render comparison icon carousel page for a given outcome index (no JS).

    Parameters
    ----------
    all_probs_list : sequence of dict
        One outcome -> probability dict per panel; missing outcomes render as 0.0.
    labels : sequence of str
        Panel names, paired with all_probs_list (pairing stops at the shorter one).
    colors : sequence of str
        Panel accent colours, reused cyclically (like the comparison table) so a
        shorter colour list does not drop panels. Falls back to blue if empty.
    idx : int
        Page position in ICON_OUTCOMES, clamped into the valid range.

    Returns
    -------
    str
        HTML with the outcome title, page counter, breadcrumb dots, one row
        (name tag + icon card) per panel and a colour-legend footnote.
    """
    total = len(ICON_OUTCOMES)
    idx = max(0, min(int(idx), total - 1))
    outcome = ICON_OUTCOMES[idx]
    out_label = OUTCOME_TITLES.get(outcome, outcome)

    # Breadcrumb dots: the current page gets a larger blue dot.
    dot_parts = []
    for i in range(total):
        active = i == idx
        dot_size = "10" if active else "7"
        dot_color = "#1565c0" if active else "#ccc"
        dot_parts.append(
            f'<span style="display:inline-block;width:{dot_size}px;'
            f'height:{dot_size}px;border-radius:50%;'
            f'background:{dot_color};'
            f'margin:0 3px;vertical-align:middle;"></span>'
        )
    dots = "".join(dot_parts)

    row_parts = []
    for j, (probs, label) in enumerate(zip(all_probs_list, labels)):
        color = colors[j % len(colors)] if colors else "#1565c0"
        card = _render_single_icon_card(probs.get(outcome, 0.0), outcome, label, color)
        row_parts.append(
            f'<div style="display:flex;align-items:flex-start;gap:10px;margin-bottom:12px;">'
            f'<div style="width:80px;flex-shrink:0;font-size:9px;font-weight:700;'
            f'text-align:center;border-radius:4px;padding:4px 6px;color:#fff;'
            f'background:{color};white-space:normal;word-break:break-word;">{label}</div>'
            f'<div style="flex:1;">{card}</div>'
            f'</div>'
        )
    rows_html = "".join(row_parts)

    footnote = (
        f'<div style="font-size:10px;color:#888;text-align:center;margin-top:4px;">'
        f'Each figure = 1 patient out of 100. '
        f'<span style="color:{EVENT_COLOR};font-weight:600;">&#9632; Red = Event</span> &nbsp; '
        f'<span style="color:{NO_EVENT_COLOR};font-weight:600;">&#9632; Green = No Event</span>'
        f'</div>'
    )
    return (
        f'<div style="font-family:\'Segoe UI\',Arial,sans-serif;padding:4px 0;text-align:center;">'
        f'<div style="font-size:15px;font-weight:700;color:#1565c0;margin-bottom:2px;">{out_label}</div>'
        f'<div style="font-size:11px;color:#888;margin-bottom:6px;">{idx+1} / {total}</div>'
        f'<div style="margin-bottom:10px;">{dots}</div>'
        f'{rows_html}'
        f'{footnote}'
        f'</div>'
    )
# ─────────────────────────────────────────────────────────────────────────────
# OTHER HTML RENDERERS
# ─────────────────────────────────────────────────────────────────────────────
def _delta_color_html(delta, is_survival):
if abs(delta) < 0.0005:
color = "#888888"
elif (delta > 0 and is_survival) or (delta < 0 and not is_survival):
color = "#2e7d32"
else:
color = "#c62828"
sign = "+" if delta >= 0 else ""
return f'<span style="color:{color};font-weight:700;">{sign}{delta*100:.1f}%</span>'
def _build_comparison_table_html(base_probs, base_ci, scenario_probs_list, scenario_ci_list, scenario_labels, scenario_colors):
    """Build an HTML table comparing baseline and counterfactual predictions.

    Parameters
    ----------
    base_probs : dict
        Outcome key -> baseline probability. Only outcomes present here become
        columns, in REPORTING_OUTCOMES order.
    base_ci : dict
        Outcome key -> (low, high) interval for the baseline.
    scenario_probs_list, scenario_ci_list, scenario_labels : sequences
        Per-scenario probability dicts, interval dicts and display names,
        consumed in lockstep (zip stops at the shortest).
    scenario_colors : sequence of str
        Scenario accent colours, reused cyclically; grey if empty.

    Returns
    -------
    str
        HTML table with one baseline row and one row per scenario.
        - Each scenario cell shows the probability, its interval and the
          coloured change from baseline (via _delta_color_html).
        - If a scenario lacks an outcome, a "–" placeholder cell is emitted so
          the remaining cells stay under their headers.
        - A missing or NaN interval is shown as "[n/a]".
    """
    survival_outcomes = {"OS", "EFS"}
    # Columns are fixed once from the baseline so every row has the same cells.
    columns = [o for o in REPORTING_OUTCOMES if o in base_probs]
    missing_ci = (float("nan"), float("nan"))

    def _ci_text(ci):
        # "[lo%–hi%]", or "[n/a]" when either bound is missing/NaN.
        low, high = ci
        if pd.isna(low) or pd.isna(high):
            return "[n/a]"
        return f"[{low*100:.1f}%–{high*100:.1f}%]"

    outcome_headers = "".join(
        f"<th style='padding:8px 10px;text-align:center;border-bottom:2px solid #ccd;"
        f"color:#333;font-size:11px;'>{OUTCOME_DESCRIPTIONS.get(o, o)}</th>"
        for o in columns
    )
    header = (
        "<div style='overflow-x:auto;'>"
        "<table style='width:100%;border-collapse:collapse;"
        "font-family:\"Segoe UI\",Arial,sans-serif;font-size:12px;'>"
        "<thead><tr style='background:#f0f4f8;'>"
        "<th style='padding:9px 12px;text-align:left;border-bottom:2px solid #ccd;"
        "color:#333;min-width:120px;'>Scenario</th>"
        f"{outcome_headers}"
        "</tr></thead><tbody>"
    )

    baseline_cells = "".join(
        "<td style='padding:7px 10px;text-align:center;'>"
        f"<div style='font-weight:700;color:#1565c0;font-size:13px;'>{base_probs[o]*100:.1f}%</div>"
        f"<div style='color:#5c7fa8;font-size:9px;'>{_ci_text(base_ci.get(o, missing_ci))}</div>"
        "</td>"
        for o in columns
    )
    row_parts = [
        "<tr style='background:#e8f0fb;'>"
        "<td style='padding:8px 12px;font-weight:700;color:#1565c0;border-right:2px solid #ccd;'>"
        "<div style='display:flex;align-items:center;gap:6px;'>"
        "<span style='display:inline-block;width:10px;height:10px;border-radius:50%;"
        "background:#1565c0;flex-shrink:0;'></span>Baseline</div></td>"
        f"{baseline_cells}</tr>"
    ]

    scenarios = zip(scenario_probs_list, scenario_ci_list, scenario_labels)
    for j, (sp_dict, sci_dict, s_label) in enumerate(scenarios):
        sc_color = scenario_colors[j % len(scenario_colors)] if scenario_colors else "#555555"
        bg = "#fafbfc" if j % 2 == 0 else "#ffffff"
        cell_parts = []
        for o in columns:
            if o not in sp_dict:
                # Placeholder keeps later columns aligned with their headers.
                cell_parts.append(
                    "<td style='padding:7px 10px;text-align:center;color:#bbb;'>–</td>"
                )
                continue
            wp = sp_dict[o]
            delta = wp - base_probs[o]
            cell_parts.append(
                "<td style='padding:7px 10px;text-align:center;'>"
                f"<div style='font-weight:700;color:{sc_color};font-size:13px;'>{wp*100:.1f}%</div>"
                f"<div style='color:#888;font-size:9px;'>{_ci_text(sci_dict.get(o, missing_ci))}</div>"
                f"<div style='font-size:10px;margin-top:1px;'>"
                f"{_delta_color_html(delta, o in survival_outcomes)}</div>"
                "</td>"
            )
        scenario_cells = "".join(cell_parts)
        row_parts.append(
            f"<tr style='background:{bg};'>"
            "<td style='padding:8px 12px;font-weight:600;border-right:2px solid #ccd;'>"
            "<div style='display:flex;align-items:center;gap:6px;'>"
            "<span style='display:inline-block;width:10px;height:10px;border-radius:50%;"
            f"background:{sc_color};flex-shrink:0;'></span>"
            f"<span style='color:{sc_color};'>{s_label}</span></div></td>"
            f"{scenario_cells}</tr>"
        )

    footer = (
        "</tbody></table></div>"
        "<div style='font-size:10.5px;color:#888;margin-top:6px;'>"
        "Δ from Baseline: <span style='color:#2e7d32;font-weight:600;'>Green = improvement</span> &nbsp;"
        "<span style='color:#c62828;font-weight:600;'>Red = worsening</span> &nbsp;|&nbsp; "
        "OS &amp; EFS: higher is better; all other outcomes: lower is better."
        "</div>"
    )
    return header + "".join(row_parts) + footer
def _build_violation_html(violations):
if not violations:
return ""
items = "".join(f"<li style='margin-bottom:6px;'>{v}</li>" for v in violations)
return (
f'<div style="background:#fff3e0;border:2px solid #e65100;border-radius:8px;'
f'padding:14px 18px;font-family:\'Segoe UI\',Arial,sans-serif;margin-bottom:12px;">'
f'<div style="font-weight:700;font-size:14px;color:#bf360c;margin-bottom:8px;">'
f'Constraint Violations β€” Analysis blocked</div>'
f'<ul style="margin:0;padding-left:20px;color:#6d1f00;font-size:13px;">{items}</ul>'
f'<div style="margin-top:10px;font-size:11px;color:#888;">'
f'Please correct the above before running the comparison.</div>'
f'</div>'
)
# ─────────────────────────────────────────────────────────────────────────────
# SHAP CAROUSEL HELPERS
# ─────────────────────────────────────────────────────────────────────────────
def _shap_counter_html(idx):
    """Return the SHAP carousel breadcrumb with every outcome label.

    The label at position *idx* in SHAP_ORDER is bold blue and the others are
    light grey. Labels are separated by a middle dot.
    """
    parts = []
    for i, outcome in enumerate(SHAP_ORDER):
        current = i == idx
        parts.append(
            f'<span style="font-weight:{"700" if current else "400"};'
            f'color:{"#1565c0" if current else "#aaa"};font-size:11px;">'
            f'{SHAP_LABELS.get(outcome, outcome)}</span>'
        )
    items = " · ".join(parts)
    return f'<div style="text-align:center;padding:4px 0 2px;">{items}</div>'
# ─────────────────────────────────────────────────────────────────────────────
# MAIN PREDICT CALLBACK (Tab 1)
# ─────────────────────────────────────────────────────────────────────────────
def predict_gradio(*values):
    """Run the Tab 1 prediction for one patient and build every result widget.

    Parameters
    ----------
    *values
        Component values in ALL_FEATURES order (as wired from ``inputs_list``).

    Returns
    -------
    tuple
        Eight items, in order:
        1. Results-table DataFrame.
        2. Calibrated probabilities, stored in the icon-carousel State.
        3. Icon-carousel index 0.
        4. First icon page HTML.
        5. SHAP plot dict, stored in the SHAP State.
        6. SHAP index 0.
        7. First SHAP plot.
        8. SHAP breadcrumb HTML.

    Raises
    ------
    gr.Error
        Wrapping any failure, including blank fields, so Gradio shows it to the
        user. The original exception is chained and its traceback printed.
    """
    try:
        user_vals = _values_to_dict(values)
        # Shared blank/NaN check; with no label its message matches the old inline one.
        _check_missing(user_vals)
        calibrated, _ = predict_with_comparison(user_vals)
        calibrated_probs, calibrated_intervals = calibrated

        rows = []
        for outcome in REPORTING_OUTCOMES:
            desc = OUTCOME_DESCRIPTIONS[outcome]
            calib_prob = calibrated_probs[outcome]
            ci_low, ci_high = calibrated_intervals[outcome]
            rows.append({
                "Outcome": desc,
                "Probability": f"{calib_prob * 100:.1f}%",
                "95% CI": f"[{ci_low * 100:.1f}% – {ci_high * 100:.1f}%]",
            })
        df = pd.DataFrame(rows)

        shap_plots = create_all_shap_plots(user_vals, max_display=10)
        # Both carousels start on their first page.
        icon_html = _render_icon_carousel_page(calibrated_probs, 0)
        first_shap = shap_plots[SHAP_ORDER[0]]
        shap_crumb = _shap_counter_html(0)
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"{type(e).__name__}: {str(e)}") from e

    return (
        df,
        calibrated_probs,  # stored in State for carousel navigation
        0,                 # icon carousel index State
        icon_html,         # displayed icon card
        shap_plots,        # stored in State
        0,                 # shap index State
        first_shap,        # displayed plot
        shap_crumb,        # breadcrumb HTML
    )
# ─────────────────────────────────────────────────────────────────────────────
# CSS
# ─────────────────────────────────────────────────────────────────────────────
# Stylesheet passed to gr.Blocks(css=...). The class names are attached to
# widgets through elem_classes (e.g. "predict-button", "icon-nav-button",
# "shap-nav-button", "output-dataframe"). The .scenario-panel-N border colours
# are the same values as SCENARIO_COLORS.
custom_css = """
.predict-button {
background: linear-gradient(to right, #ff6b35, #ff8c42) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 16px !important; padding: 12px !important;
}
.predict-button:hover { background: linear-gradient(to right, #ff5722, #ff7b29) !important; }
.copy-button {
background: linear-gradient(to right, #388e3c, #66bb6a) !important;
border: none !important; color: white !important; font-weight: 600 !important;
}
.copy-button:hover { background: linear-gradient(to right, #2e7d32, #43a047) !important; }
.copy-from-predict-button {
background: linear-gradient(to right, #6a1b9a, #ab47bc) !important;
border: none !important; color: white !important; font-weight: 600 !important;
}
.copy-from-predict-button:hover { background: linear-gradient(to right, #4a148c, #8e24aa) !important; }
.counterfactual-button {
background: linear-gradient(to right, #1976d2, #42a5f5) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 15px !important; padding: 12px !important;
}
.counterfactual-button:hover { background: linear-gradient(to right, #1565c0, #1e88e5) !important; }
.shap-nav-button {
background: linear-gradient(to right, #37474f, #546e7a) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 18px !important;
padding: 8px 22px !important; border-radius: 6px !important;
}
.shap-nav-button:hover { background: linear-gradient(to right, #263238, #455a64) !important; }
.icon-nav-button {
background: linear-gradient(to right, #1565c0, #1e88e5) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 18px !important;
padding: 8px 22px !important; border-radius: 6px !important;
}
.icon-nav-button:hover { background: linear-gradient(to right, #0d47a1, #1565c0) !important; }
.output-dataframe table td:first-child,
.output-dataframe table th:first-child {
white-space: normal !important; word-break: break-word !important; min-width: 240px !important;
}
.constraint-info {
background: #e8f5e9; border-left: 4px solid #388e3c;
padding: 8px 14px; font-size: 12px; color: #1b5e20;
border-radius: 4px; margin-bottom: 8px;
}
.scenario-panel-0 { border-left: 4px solid #e65100 !important; }
.scenario-panel-1 { border-left: 4px solid #6a1b9a !important; }
.scenario-panel-2 { border-left: 4px solid #1b5e20 !important; }
.scenario-panel-3 { border-left: 4px solid #0d47a1 !important; }
.scenario-panel-4 { border-left: 4px solid #b71c1c !important; }
"""
# ─────────────────────────────────────────────────────────────────────────────
# BUILD UI
# ─────────────────────────────────────────────────────────────────────────────
with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
gr.Markdown("# HCT Outcome Prediction Model")
n_scenarios_state = gr.State(1)
with gr.Tabs():
# ══════════════════════════════════════════════════════════════════════
# TAB 1 β€” PREDICT OUTCOMES
# ══════════════════════════════════════════════════════════════════════
with gr.Tab("Predict Outcomes"):
gr.Markdown("Enter patient, transplant, and disease characteristics to predict outcomes.")
inputs_dict = {}
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Patient Characteristics")
for f in PATIENT_FEATURES:
inputs_dict[f] = make_component(f)
with gr.Column(scale=1):
gr.Markdown("### Transplant Characteristics")
grouped_dd = gr.Dropdown(
choices=GROUPED_REGIMEN_CHOICES, value=None,
label="Published conditioning regimen",
info="Auto-fills Donor Type, Conditioning Intensity, Regimen, Serotherapy, GVHD Prophylaxis",
)
p_donorf = inputs_dict["DONORF"] = make_component("DONORF")
inputs_dict["GRAFTYPE"] = make_component("GRAFTYPE")
p_condgrpf = inputs_dict["CONDGRPF"] = make_component("CONDGRPF")
p_condgrp_final = inputs_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL")
p_atgf = inputs_dict["ATGF"] = make_component("ATGF")
p_gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
p_hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL")
with gr.Column(scale=1):
gr.Markdown("### Disease Characteristics")
for f in DISEASE_FEATURES:
inputs_dict[f] = make_component(f)
inputs_dict["AGE"].change(get_age_group, inputs_dict["AGE"], inputs_dict["AGEGPFF"])
inputs_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, inputs_dict["VOC2YPR"], inputs_dict["VOCFRQPR"])
p_donorf.change(fn=_hla_update_for_donor, inputs=p_donorf, outputs=p_hla_final)
grouped_dd.change(
apply_grouped_preset, grouped_dd,
[grouped_dd, p_donorf, p_condgrpf, p_condgrp_final, p_atgf, p_gvhd_final, p_hla_final],
)
inputs_list = [inputs_dict[f] for f in ALL_FEATURES]
predict_btn = gr.Button("Predict", elem_classes="predict-button", size="lg")
gr.Markdown("---")
gr.Markdown("## Prediction Results")
output_table = gr.Dataframe(
headers=["Outcome", "Probability", "95% CI"],
label="Predicted Outcomes",
elem_classes="output-dataframe",
row_count=(len(REPORTING_OUTCOMES), "dynamic"),
column_count=(3, "fixed"),
wrap=True,
)
# ── Icon Array Carousel (Tab 1) β€” Python-driven ───────────────────
gr.Markdown("---")
gr.Markdown("## Outcome Probability β€” Icon Arrays")
gr.Markdown("*Use the ← β†’ arrows to browse outcomes one at a time.*")
# State: store probs dict + current index
icon_probs_state = gr.State(None)
icon_idx_state = gr.State(0)
with gr.Row():
icon_prev_btn = gr.Button("←", elem_classes="icon-nav-button", size="sm", scale=0)
icon_display = gr.HTML(scale=4)
icon_next_btn = gr.Button("β†’", elem_classes="icon-nav-button", size="sm", scale=0)
def _on_icon_prev(idx, probs):
if probs is None:
return idx, ""
new_idx = max(0, idx - 1)
return new_idx, _render_icon_carousel_page(probs, new_idx)
def _on_icon_next(idx, probs):
if probs is None:
return idx, ""
new_idx = min(len(ICON_OUTCOMES) - 1, idx + 1)
return new_idx, _render_icon_carousel_page(probs, new_idx)
icon_prev_btn.click(
fn=_on_icon_prev,
inputs=[icon_idx_state, icon_probs_state],
outputs=[icon_idx_state, icon_display],
)
icon_next_btn.click(
fn=_on_icon_next,
inputs=[icon_idx_state, icon_probs_state],
outputs=[icon_idx_state, icon_display],
)
# ── SHAP Carousel (Tab 1) ─────────────────────────────────────────
gr.Markdown("---")
gr.Markdown("## SHAP β€” Feature Importance")
gr.Markdown("*Use the ← β†’ buttons to browse SHAP plots one at a time.*")
shap_plots_state = gr.State(None)
shap_idx_state = gr.State(0)
with gr.Row():
shap_prev_btn = gr.Button("←", elem_classes="shap-nav-button", size="sm", scale=0)
shap_crumb = gr.HTML(value="", scale=4)
shap_next_btn = gr.Button("β†’", elem_classes="shap-nav-button", size="sm", scale=0)
shap_display = gr.Plot(label="SHAP Feature Importance")
def _on_shap_prev(idx, plots):
    """Step the Tab-1 SHAP carousel back by one plot.

    Args:
        idx: current position in SHAP_ORDER.
        plots: plot mapping keyed by SHAP_ORDER entries, or None/empty before
            any prediction.

    Returns:
        ``(new index, plot, counter HTML)``. With no plots stored, the index
        is kept and both plot and counter are cleared. The original code
        instead advanced the counter over an empty plot. The icon-carousel
        handlers already clear in this way.
    """
    if not plots:
        return idx, None, ""
    new_idx = max(0, idx - 1)
    return new_idx, plots[SHAP_ORDER[new_idx]], _shap_counter_html(new_idx)
def _on_shap_next(idx, plots):
    """Step the Tab-1 SHAP carousel forward by one plot.

    Args:
        idx: current position in SHAP_ORDER.
        plots: plot mapping keyed by SHAP_ORDER entries, or None/empty before
            any prediction.

    Returns:
        ``(new index, plot, counter HTML)``. The index is capped at the last
        SHAP_ORDER position. With no plots stored, the index is kept and the
        plot/counter are cleared rather than showing a position with no plot.
    """
    if not plots:
        return idx, None, ""
    new_idx = min(len(SHAP_ORDER) - 1, idx + 1)
    return new_idx, plots[SHAP_ORDER[new_idx]], _shap_counter_html(new_idx)
# Wire the Tab-1 SHAP arrows; each returns (new index, plot, counter HTML).
shap_prev_btn.click(
fn=_on_shap_prev,
inputs=[shap_idx_state, shap_plots_state],
outputs=[shap_idx_state, shap_display, shap_crumb],
)
shap_next_btn.click(
fn=_on_shap_next,
inputs=[shap_idx_state, shap_plots_state],
outputs=[shap_idx_state, shap_display, shap_crumb],
)
# Main prediction. predict_gradio (defined outside this view) must return one
# value per entry of `outputs`, in exactly this order.
predict_btn.click(
fn=predict_gradio,
inputs=inputs_list,
outputs=[
output_table,
icon_probs_state, # ← store probs dict
icon_idx_state, # ← reset to 0
icon_display, # ← show first card
shap_plots_state,
shap_idx_state,
shap_display,
shap_crumb,
],
)
# ══════════════════════════════════════════════════════════════════════
# TAB 2 — COUNTERFACTUAL ANALYSIS
# ══════════════════════════════════════════════════════════════════════
with gr.Tab("Counterfactual Scenarios"):
gr.Markdown(
"## Counterfactual Scenario Analysis\n"
"Enter the **baseline** patient, then choose how many counterfactual "
"scenarios you want to compare. Each scenario panel will appear below."
)
# The slider's value controls how many scenario panels and SHAP rows are
# shown (see the .change handlers further down).
with gr.Row():
n_scenarios_slider = gr.Slider(
minimum=1, maximum=MAX_SCENARIOS, step=1, value=1,
label=f"How many counterfactual scenarios do you want to compare? (1–{MAX_SCENARIOS})",
info="Adjust this first β€” scenario panels will appear/disappear below.",
)
gr.Markdown("---")
# ── BASELINE ─────────────────────────────────────────────────────
gr.Markdown("## Baseline Patient Profile")
with gr.Row():
copy_from_predict_btn = gr.Button(
"Copy from Predict tab β†’ Baseline",
elem_classes="copy-from-predict-button",
size="sm",
)
# Feature name -> input component for the baseline patient.
wi_baseline_dict = {}
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Patient Characteristics")
for f in PATIENT_FEATURES:
wi_baseline_dict[f] = make_component(f)
with gr.Column(scale=1):
gr.Markdown("### Transplant Characteristics")
# Selecting a published regimen fills the transplant fields below through
# apply_grouped_preset (wired after the columns are built).
wi_grouped_base = gr.Dropdown(
choices=GROUPED_REGIMEN_CHOICES, value=None,
label="Published conditioning regimen (Baseline)",
)
# Transplant components are also bound to local names so the preset/HLA
# handlers below can target them directly.
wb_donorf = wi_baseline_dict["DONORF"] = make_component("DONORF")
wi_baseline_dict["GRAFTYPE"] = make_component("GRAFTYPE")
wb_condgrpf = wi_baseline_dict["CONDGRPF"] = make_component("CONDGRPF")
wb_condgrp_final = wi_baseline_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL")
wb_atgf = wi_baseline_dict["ATGF"] = make_component("ATGF")
wb_gvhd_final = wi_baseline_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
wb_hla_final = wi_baseline_dict["HLA_FINAL"] = make_component("HLA_FINAL")
with gr.Column(scale=1):
gr.Markdown("### Disease Characteristics")
for f in DISEASE_FEATURES:
wi_baseline_dict[f] = make_component(f)
# Derived fields: AGE -> AGEGPFF, VOC2YPR -> VOCFRQPR, DONORF -> HLA_FINAL
# (via _hla_update_for_donor).
wi_baseline_dict["AGE"].change(get_age_group, wi_baseline_dict["AGE"], wi_baseline_dict["AGEGPFF"])
wi_baseline_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, wi_baseline_dict["VOC2YPR"], wi_baseline_dict["VOCFRQPR"])
wb_donorf.change(fn=_hla_update_for_donor, inputs=wb_donorf, outputs=wb_hla_final)
# The preset's first output is the dropdown itself, followed by the six
# transplant fields in this order.
wi_grouped_base.change(
apply_grouped_preset, wi_grouped_base,
[wi_grouped_base, wb_donorf, wb_condgrpf, wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
)
# Baseline components in ALL_FEATURES order — the order every callback unpacks.
wi_baseline_list = [wi_baseline_dict[f] for f in ALL_FEATURES]
def _copy_predict_to_baseline(*vals):
    """Mirror the Predict-tab inputs onto the Baseline panel.

    Args:
        *vals: Predict-tab feature values in ALL_FEATURES order, optionally
            followed by that tab's published-regimen selection.

    Returns:
        The feature values unchanged, then an update that sets the Baseline
        regimen dropdown to the same selection, then the updates for DONORF,
        CONDGRPF, CONDGRP_FINAL, ATGF, GVHD_FINAL and HLA_FINAL produced by
        apply_grouped_preset.
    """
    n_features = len(ALL_FEATURES)
    copied_features = vals[:n_features]
    has_regimen = len(vals) > n_features
    regimen = vals[n_features] if has_regimen else None
    # Element 0 of the preset output targets the dropdown itself. It is
    # replaced by an explicit update that carries the copied selection.
    preset = apply_grouped_preset(regimen)
    (donorf, condgrpf, condgrp_final,
     atgf, gvhd_final, hla_final) = preset[1:]
    dropdown = gr.update(value=regimen)
    return (*copied_features, dropdown, donorf, condgrpf,
            condgrp_final, atgf, gvhd_final, hla_final)
# Copy the Predict-tab inputs plus its regimen dropdown (grouped_dd) into the
# baseline panel.
# NOTE(review): if ALL_FEATURES contains the transplant keys (DONORF, ...),
# those components appear both inside wi_baseline_list and again in the
# explicit tail of `outputs`, so each receives two values per click. The later
# preset update presumably wins; confirm this is intended.
copy_from_predict_btn.click(
fn=_copy_predict_to_baseline,
inputs=inputs_list + [grouped_dd],
outputs=wi_baseline_list + [wi_grouped_base, wb_donorf, wb_condgrpf,
wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
)
gr.Markdown("---")
# ── SCENARIO PANELS ───────────────────────────────────────────────
# Per-slot handles, appended once per scenario slot in the loop below.
# scenario_lists (components in ALL_FEATURES order), scenario_rows and
# scenario_name_inputs are read later in this view. The other lists are not
# read within this view.
scenario_dicts = []
scenario_lists = []
scenario_grouped_dds = []
scenario_rows = []
scenario_name_inputs = []
scenario_donorf_handles = []
scenario_hla_handles = []
scenario_voc2ypr_handles = []
scenario_age_handles = []
scenario_agegp_handles = []
scenario_copy_btns = []
# Build all MAX_SCENARIOS scenario panels up front; the slider only toggles
# their visibility. Each panel mirrors the baseline layout, with `suffix`
# appended to component labels.
for s_idx in range(MAX_SCENARIOS):
color = SCENARIO_COLORS[s_idx]
label = f"Scenario {s_idx + 1}"
suffix = f"({label})"
# Only the first panel starts visible, matching the slider's initial value of 1.
visible_init = (s_idx == 0)
with gr.Row(visible=visible_init) as s_row:
scenario_rows.append(s_row)
with gr.Column():
gr.HTML(
f'<div style="background:{color};color:#fff;font-weight:700;'
f'font-size:15px;padding:8px 14px;border-radius:6px;margin-bottom:6px;">'
f'Counterfactual {label}</div>'
)
# Editable display name (read by run_counterfactual) and copy-from-baseline button.
with gr.Row():
s_name_input = gr.Textbox(
label=f"Scenario name ({label})",
value=label,
placeholder=f"e.g. {label}",
scale=2,
)
scenario_name_inputs.append(s_name_input)
copy_btn = gr.Button(
f"⬇ Copy Baseline β†’ {label}",
elem_classes="copy-button",
size="sm",
scale=1,
)
scenario_copy_btns.append(copy_btn)
# Feature name -> input component for this scenario slot.
s_dict = {}
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(f"#### Patient β€” {label}")
for f in PATIENT_FEATURES:
s_dict[f] = make_component(f, suffix)
with gr.Column(scale=1):
gr.Markdown(f"#### Transplant β€” {label}")
s_grouped_dd = gr.Dropdown(
choices=GROUPED_REGIMEN_CHOICES, value=None,
label=f"Published regimen ({label})",
)
scenario_grouped_dds.append(s_grouped_dd)
s_donorf = s_dict["DONORF"] = make_component("DONORF", suffix)
s_dict["GRAFTYPE"] = make_component("GRAFTYPE", suffix)
s_condgrpf = s_dict["CONDGRPF"] = make_component("CONDGRPF", suffix)
s_condgrp_final = s_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL", suffix)
s_atgf = s_dict["ATGF"] = make_component("ATGF", suffix)
s_gvhd_final = s_dict["GVHD_FINAL"] = make_component("GVHD_FINAL", suffix)
s_hla_final = s_dict["HLA_FINAL"] = make_component("HLA_FINAL", suffix)
with gr.Column(scale=1):
gr.Markdown(f"#### Disease β€” {label}")
for f in DISEASE_FEATURES:
s_dict[f] = make_component(f, suffix)
# Record this slot's handles; scenario_lists[-1] holds its components in
# ALL_FEATURES order (used by the copy button and run_counterfactual).
scenario_dicts.append(s_dict)
scenario_lists.append([s_dict[f] for f in ALL_FEATURES])
scenario_donorf_handles.append(s_donorf)
scenario_hla_handles.append(s_hla_final)
scenario_voc2ypr_handles.append(s_dict["VOC2YPR"])
scenario_age_handles.append(s_dict["AGE"])
scenario_agegp_handles.append(s_dict["AGEGPFF"])
# Same derived-field and preset wiring as the baseline panel.
s_dict["AGE"].change(get_age_group, s_dict["AGE"], s_dict["AGEGPFF"])
s_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, s_dict["VOC2YPR"], s_dict["VOCFRQPR"])
s_donorf.change(fn=_hla_update_for_donor, inputs=s_donorf, outputs=s_hla_final)
s_grouped_dd.change(
apply_grouped_preset, s_grouped_dd,
[s_grouped_dd, s_donorf, s_condgrpf, s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final],
)
# Scenario SEX follows the baseline: baseline SEX changes are pushed into this
# slot through lock_sex, and the slot's own SEX field is made non-interactive.
# NOTE(review): assigning .interactive after construction relies on Gradio
# reading the attribute when it builds the page config. Confirm with the
# installed version; passing interactive=False at creation would be more robust.
wi_baseline_dict["SEX"].change(
fn=lock_sex,
inputs=wi_baseline_dict["SEX"],
outputs=s_dict["SEX"],
)
s_dict["SEX"].interactive = False
# Factory producing this slot's copy callback.
# NOTE(review): the *_ref parameters are never used by _copy, so the closure
# captures nothing slot-specific.
def _make_copy_fn(s_grouped_dd_ref, s_donorf_ref, s_condgrpf_ref,
s_condgrp_final_ref, s_atgf_ref, s_gvhd_final_ref,
s_hla_final_ref):
# vals = baseline features (ALL_FEATURES order) + optional baseline regimen.
def _copy(*vals):
n = len(ALL_FEATURES)
feature_vals = vals[:n]
regimen_value = vals[n] if len(vals) > n else None
preset_outputs = apply_grouped_preset(regimen_value)
# preset_outputs[0] targets the dropdown; it is replaced by the copied selection.
regimen_dd_upd = gr.update(value=regimen_value)
d_upd, c_upd, cr_upd, a_upd, g_upd, h_upd = preset_outputs[1:]
return (*list(feature_vals), regimen_dd_upd,
d_upd, c_upd, cr_upd, a_upd, g_upd, h_upd)
return _copy
# Copy baseline features and regimen into this slot. As with Predict→Baseline,
# the transplant components may appear both in scenario_lists[-1] and in the
# explicit output tail, in which case the later preset updates presumably win.
copy_btn.click(
fn=_make_copy_fn(s_grouped_dd, s_donorf, s_condgrpf,
s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final),
inputs=wi_baseline_list + [wi_grouped_base],
outputs=(scenario_lists[-1]
+ [s_grouped_dd, s_donorf, s_condgrpf,
s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final]),
)
# The same click also re-applies the baseline SEX to this slot via lock_sex.
copy_btn.click(
fn=lock_sex,
inputs=wi_baseline_dict["SEX"],
outputs=s_dict["SEX"],
)
# ── Slider → show/hide scenario panels ───────────────────────────
def _update_scenario_visibility(n):
    """Build one visibility update per scenario panel row.

    Args:
        n: slider value (number of scenarios to show).

    Returns:
        A list of MAX_SCENARIOS updates. Slots below ``int(n)`` are visible
        and the rest are hidden.
    """
    updates = []
    for slot in range(MAX_SCENARIOS):
        updates.append(gr.update(visible=slot < int(n)))
    return updates
# Slider change: toggle scenario panel visibility, and mirror the count into
# n_scenarios_state (defined outside this view), which run_counterfactual
# reads as its first input.
n_scenarios_slider.change(
fn=_update_scenario_visibility,
inputs=n_scenarios_slider,
outputs=scenario_rows,
)
n_scenarios_slider.change(
fn=lambda n: int(n),
inputs=n_scenarios_slider,
outputs=n_scenarios_state,
)
# ── RUN ──────────────────────────────────────────────────────────
gr.Markdown("---")
wi_run_btn = gr.Button(
"Run Counterfactual Comparison",
elem_classes="counterfactual-button",
size="lg",
)
# ── RESULTS ──────────────────────────────────────────────────────
gr.Markdown("## Comparison Results")
# Receives constraint-violation HTML from run_counterfactual ("" when none).
wi_violation_html = gr.HTML()
gr.Markdown("### Outcome Probability Table")
wi_table_html = gr.HTML()
gr.Markdown("---")
# ── Icon Arrays — comparison, Python-driven carousel ─────────────
with gr.Accordion("Outcome Icon Arrays β€” Baseline vs Scenarios", open=False):
gr.Markdown(
"*One row per scenario. Use the ← β†’ arrows to browse outcomes one at a time.*"
)
# Carousel state written by run_counterfactual: parallel lists with the
# baseline first, then each active scenario.
wi_cmp_probs_state = gr.State(None) # list of probability dicts
wi_cmp_labels_state = gr.State(None) # list of display labels
wi_cmp_colors_state = gr.State(None) # list of colors
wi_cmp_idx_state = gr.State(0)
with gr.Row():
wi_cmp_prev_btn = gr.Button("←", elem_classes="icon-nav-button", size="sm", scale=0)
wi_icon_html = gr.HTML(scale=4)
wi_cmp_next_btn = gr.Button("β†’", elem_classes="icon-nav-button", size="sm", scale=0)
def _on_cmp_prev(idx, probs_list, labels, colors):
    """Show the previous outcome page of the Tab-2 comparison icon carousel.

    Args:
        idx: current outcome position.
        probs_list, labels, colors: parallel lists stored by
            run_counterfactual, or None before a run.

    Returns:
        ``(index, html)``. Before a run the index is kept and the display is
        blanked; otherwise the index moves left (not below 0).
    """
    if probs_list is not None:
        previous = max(0, idx - 1)
        page = _render_comparison_icon_page(probs_list, labels, colors, previous)
        return previous, page
    return idx, ""
def _on_cmp_next(idx, probs_list, labels, colors):
    """Show the next outcome page of the Tab-2 comparison icon carousel.

    Args:
        idx: current outcome position.
        probs_list, labels, colors: parallel lists stored by
            run_counterfactual, or None before a run.

    Returns:
        ``(index, html)``. Before a run the index is kept and the display is
        blanked; otherwise the index moves right, capped at the last
        ICON_OUTCOMES entry.
    """
    if probs_list is not None:
        following = min(idx + 1, len(ICON_OUTCOMES) - 1)
        page = _render_comparison_icon_page(probs_list, labels, colors, following)
        return following, page
    return idx, ""
# Wire the comparison carousel arrows: read (index, probs, labels, colors),
# write back (new index, page HTML).
wi_cmp_prev_btn.click(
fn=_on_cmp_prev,
inputs=[wi_cmp_idx_state, wi_cmp_probs_state, wi_cmp_labels_state, wi_cmp_colors_state],
outputs=[wi_cmp_idx_state, wi_icon_html],
)
wi_cmp_next_btn.click(
fn=_on_cmp_next,
inputs=[wi_cmp_idx_state, wi_cmp_probs_state, wi_cmp_labels_state, wi_cmp_colors_state],
outputs=[wi_cmp_idx_state, wi_icon_html],
)
gr.Markdown("---")
# ── SHAP Carousels — one per scenario slot ────────────────────────
with gr.Accordion("SHAP Feature Importance", open=False):
# Baseline SHAP carousel: plot store + position, nav row, plot.
gr.Markdown("### Baseline")
wi_shap_base_store = gr.State(None)
wi_shap_base_idx = gr.State(0)
with gr.Row():
wb_shap_prev = gr.Button("←", elem_classes="shap-nav-button", size="sm", scale=0)
wb_shap_crumb = gr.HTML(value="", scale=4)
wb_shap_next = gr.Button("β†’", elem_classes="shap-nav-button", size="sm", scale=0)
wb_shap_plot = gr.Plot(label="Baseline β€” SHAP")
def _wb_prev(idx, plots):
    """Step the Tab-2 baseline SHAP carousel back by one plot.

    Args:
        idx: current position in SHAP_ORDER.
        plots: plot mapping stored by run_counterfactual, or None before a
            run or after a constraint-violation reset.

    Returns:
        ``(new index, plot, counter HTML)``. With no plots stored, the index
        is kept and the plot/counter are cleared rather than showing a
        position over an empty plot.
    """
    if not plots:
        return idx, None, ""
    new_idx = max(0, idx - 1)
    return new_idx, plots[SHAP_ORDER[new_idx]], _shap_counter_html(new_idx)
def _wb_next(idx, plots):
    """Step the Tab-2 baseline SHAP carousel forward by one plot.

    Args:
        idx: current position in SHAP_ORDER.
        plots: plot mapping stored by run_counterfactual, or None before a
            run or after a constraint-violation reset.

    Returns:
        ``(new index, plot, counter HTML)``, with the index capped at the last
        SHAP_ORDER position. With no plots stored, the index is kept and the
        plot/counter are cleared.
    """
    if not plots:
        return idx, None, ""
    new_idx = min(len(SHAP_ORDER) - 1, idx + 1)
    return new_idx, plots[SHAP_ORDER[new_idx]], _shap_counter_html(new_idx)
# Wire the baseline SHAP arrows: (index, store) -> (index, plot, counter).
wb_shap_prev.click(_wb_prev, [wi_shap_base_idx, wi_shap_base_store],
[wi_shap_base_idx, wb_shap_plot, wb_shap_crumb])
wb_shap_next.click(_wb_next, [wi_shap_base_idx, wi_shap_base_store],
[wi_shap_base_idx, wb_shap_plot, wb_shap_crumb])
# Per-scenario SHAP carousels: one entry per slot in each list, filled by the
# loop below and consumed when building run_counterfactual's outputs.
wi_shap_scen_stores = []
wi_shap_scen_idxs = []
wi_shap_scen_plots = []
wi_shap_scen_crumbs = []
wi_shap_scen_rows = []
# One SHAP carousel per scenario slot. Only slot 0 starts visible; the slider
# handler below toggles the rest.
# NOTE(review): the header always reads "Scenario N" even if the user renamed
# the scenario.
for s_idx in range(MAX_SCENARIOS):
with gr.Row(visible=(s_idx == 0)) as shap_row:
wi_shap_scen_rows.append(shap_row)
with gr.Column():
scen_color = SCENARIO_COLORS[s_idx]
gr.HTML(
f'<div style="font-weight:700;font-size:13px;color:{scen_color};">'
f'Scenario {s_idx+1}</div>'
)
s_store = gr.State(None)
s_idx_s = gr.State(0)
wi_shap_scen_stores.append(s_store)
wi_shap_scen_idxs.append(s_idx_s)
with gr.Row():
s_prev_btn = gr.Button("←", elem_classes="shap-nav-button", size="sm", scale=0)
s_crumb = gr.HTML(value="", scale=4)
s_next_btn = gr.Button("β†’", elem_classes="shap-nav-button", size="sm", scale=0)
s_plot = gr.Plot(label=f"Scenario {s_idx+1} β€” SHAP")
wi_shap_scen_plots.append(s_plot)
wi_shap_scen_crumbs.append(s_crumb)
# Nav-handler factories. NOTE(review): `st`/`ix` are unused. The handlers
# receive everything through Gradio inputs, so these factories only make a
# fresh function per slot. Like the original Tab-1 SHAP handlers, they still
# advance the counter when `plots` is None.
def _make_prev(st, ix):
def fn(idx, plots):
new_idx = max(0, idx - 1)
plot = plots[SHAP_ORDER[new_idx]] if plots else None
return new_idx, plot, _shap_counter_html(new_idx)
return fn
def _make_next(st, ix):
def fn(idx, plots):
new_idx = min(len(SHAP_ORDER) - 1, idx + 1)
plot = plots[SHAP_ORDER[new_idx]] if plots else None
return new_idx, plot, _shap_counter_html(new_idx)
return fn
s_prev_btn.click(
_make_prev(s_store, s_idx_s),
[s_idx_s, s_store],
[s_idx_s, s_plot, s_crumb],
)
s_next_btn.click(
_make_next(s_store, s_idx_s),
[s_idx_s, s_store],
[s_idx_s, s_plot, s_crumb],
)
# Sync SHAP scenario row visibility with slider
def _update_shap_vis(n):
    """Build one visibility update per scenario SHAP row.

    Args:
        n: slider value (number of scenarios to show).

    Returns:
        A list of MAX_SCENARIOS updates; rows below ``int(n)`` are visible.
    """
    shown = []
    for row_idx in range(MAX_SCENARIOS):
        shown.append(gr.update(visible=row_idx < int(n)))
    return shown
# Show/hide the per-scenario SHAP rows together with the scenario panels.
n_scenarios_slider.change(
fn=_update_shap_vis,
inputs=n_scenarios_slider,
outputs=wi_shap_scen_rows,
)
# ── RUN callback ──────────────────────────────────────────────────
def run_counterfactual(*all_values):
    """Predict the baseline and every active scenario, then build all Tab-2 outputs.

    Positional inputs, in the order of ``all_run_inputs``:
      * ``all_values[0]``: number of active scenarios (``n_scenarios_state``);
      * next ``len(ALL_FEATURES)`` values: baseline features;
      * next ``MAX_SCENARIOS`` values: scenario-name textboxes;
      * then ``MAX_SCENARIOS`` consecutive blocks of ``len(ALL_FEATURES)``
        feature values, one per scenario slot.

    Returns:
        A flat tuple matching ``all_run_outputs``:
          * violation HTML and table HTML;
          * icon-carousel probs/labels/colors states, icon index and first
            icon page;
          * baseline SHAP store/index/plot/counter;
          * the same four SHAP values for each of the MAX_SCENARIOS slots,
            left empty for inactive slots.
        If any scenario violates a counterfactual constraint, only the
        violation HTML is filled and every other output is reset.

    Raises:
        gr.Error: any exception while parsing, validating, predicting or
            plotting is re-raised as gr.Error (chained to the original). An
            existing gr.Error is passed through unchanged rather than
            re-wrapped.
    """
    n_feat = len(ALL_FEATURES)

    def _empty_outputs(violation_html=""):
        """Outputs that clear every result widget, optionally showing *violation_html*."""
        # violation, table, cmp_probs, cmp_labels, cmp_colors, cmp_idx, icon_html
        out = [violation_html, "", None, None, None, 0, ""]
        # baseline SHAP: store, idx, plot, crumb
        out += [None, 0, None, ""]
        # per scenario slot: store, idx, plot, crumb
        for _ in range(MAX_SCENARIOS):
            out += [None, 0, None, ""]
        return tuple(out)

    try:
        # Clamp so an out-of-range count can never index past the
        # MAX_SCENARIOS input blocks sliced out below.
        n_scen = max(0, min(int(all_values[0]), MAX_SCENARIOS))
        base_vals = all_values[1 : 1 + n_feat]
        name_offset = 1 + n_feat
        scenario_names_raw = all_values[name_offset : name_offset + MAX_SCENARIOS]
        feat_offset = name_offset + MAX_SCENARIOS
        scenario_val_blocks = [
            all_values[feat_offset + s * n_feat : feat_offset + (s + 1) * n_feat]
            for s in range(MAX_SCENARIOS)
        ]

        base_dict = _values_to_dict(base_vals)
        _check_missing(base_dict, "Baseline")

        # Validate every active scenario before running any model.
        all_violations = []
        scenario_dicts_active = []
        for s in range(n_scen):
            scen_name = f"Scenario {s+1}"
            sd = _values_to_dict(scenario_val_blocks[s])
            _check_missing(sd, scen_name)
            all_violations.extend(
                _validate_counterfactual_constraints(base_dict, sd, scen_name)
            )
            scenario_dicts_active.append(sd)
        if all_violations:
            return _empty_outputs(_build_violation_html(all_violations))

        predict_kwargs = dict(use_calibration=True, use_signed_voting=True, n_boot_ci=500)
        base_probs, base_ci = predict_all_outcomes(base_dict, **predict_kwargs)
        scen_probs_list = []
        scen_ci_list = []
        for sd in scenario_dicts_active:
            sp, sci = predict_all_outcomes(sd, **predict_kwargs)
            scen_probs_list.append(sp)
            scen_ci_list.append(sci)

        # User-supplied scenario names, falling back to "Scenario N" when blank.
        scen_labels = []
        for i in range(n_scen):
            raw_name = scenario_names_raw[i] if i < len(scenario_names_raw) else ""
            name = str(raw_name).strip() if raw_name else ""
            scen_labels.append(name or f"Scenario {i+1}")
        scen_colors = SCENARIO_COLORS[:n_scen]

        table_html = _build_comparison_table_html(
            base_probs, base_ci, scen_probs_list, scen_ci_list, scen_labels, scen_colors
        )

        # Icon carousel: baseline first, then scenarios, kept in State for the arrows.
        all_probs_list = [base_probs] + scen_probs_list
        all_labels = ["Baseline"] + scen_labels
        all_colors = ["#1565c0"] + scen_colors
        icon_html_first = _render_comparison_icon_page(all_probs_list, all_labels, all_colors, 0)

        # SHAP: the baseline carousel starts on the first SHAP_ORDER plot.
        base_shap_plots = create_all_shap_plots(base_dict, max_display=10)
        out = [
            "",                                  # violation html
            table_html,                          # table
            all_probs_list,                      # cmp_probs state
            all_labels,                          # cmp_labels state
            all_colors,                          # cmp_colors state
            0,                                   # cmp_idx state
            icon_html_first,                     # icon display
            base_shap_plots,                     # base shap store
            0,                                   # base shap idx
            base_shap_plots[SHAP_ORDER[0]],      # base shap plot
            _shap_counter_html(0),               # base shap crumb
        ]
        # SHAP per scenario slot; inactive slots are cleared.
        for s in range(MAX_SCENARIOS):
            if s < n_scen:
                sp_plots = create_all_shap_plots(scenario_dicts_active[s], max_display=10)
                out += [sp_plots, 0, sp_plots[SHAP_ORDER[0]], _shap_counter_html(0)]
            else:
                out += [None, 0, None, ""]
        return tuple(out)
    except gr.Error:
        # Already user-facing: re-raise as-is instead of wrapping it again.
        raise
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"{type(e).__name__}: {str(e)}") from e
# Build output list. Order must match the tuple run_counterfactual returns:
# 7 table/icon outputs, 4 baseline SHAP outputs, then 4 per scenario slot.
all_run_outputs = (
[wi_violation_html, wi_table_html,
wi_cmp_probs_state, wi_cmp_labels_state, wi_cmp_colors_state,
wi_cmp_idx_state, wi_icon_html]
+ [wi_shap_base_store, wi_shap_base_idx, wb_shap_plot, wb_shap_crumb]
+ [item for s in range(MAX_SCENARIOS)
for item in (wi_shap_scen_stores[s], wi_shap_scen_idxs[s],
wi_shap_scen_plots[s], wi_shap_scen_crumbs[s])]
)
# Inputs, in the order run_counterfactual slices them: scenario count,
# baseline features, scenario names, then every slot's features (ALL_FEATURES
# order). All MAX_SCENARIOS slots are sent; only the first `count` are used.
all_run_inputs = (
[n_scenarios_state]
+ wi_baseline_list
+ scenario_name_inputs
+ [feat for s_list in scenario_lists for feat in s_list]
)
wi_run_btn.click(
fn=run_counterfactual,
inputs=all_run_inputs,
outputs=all_run_outputs,
)
# ─────────────────────────────────────────────────────────────────────────────
# Script entry point: launch the Gradio app with server-side rendering disabled.
if __name__ == "__main__":
demo.launch(ssr_mode=False)