# Whatifs / app.py
# Last change: "Update app.py" by shivapriyasom (commit 3f4a35c, verified).
import gradio as gr
import pandas as pd
import numpy as np
import copy
import traceback
from inference import (
FEATURE_NAMES,
REPORTING_OUTCOMES,
OUTCOME_DESCRIPTIONS,
OUTCOMES,
SHAP_OUTCOMES,
predict_with_comparison,
predict_all_outcomes,
create_all_shap_plots,
create_all_icon_arrays,
)
# ─────────────────────────────────────────────────────────────────────────────
# CHOICES / CONSTANTS
# ─────────────────────────────────────────────────────────────────────────────
# Allowed values for each categorical input. These strings are passed
# unchanged into the model inputs (predict_gradio -> predict_with_comparison),
# so they presumably must match what inference.py expects — TODO confirm
# before editing any of them.
# Age-group bins; get_age_group() returns exactly these labels.
AGEGPFF_CHOICES = ["<=10", "11-17", "18-29", "30-49", ">=50"]
SEX_CHOICES = ["Male", "Female"]
KPS_CHOICES = ["<90", "≥ 90"]
DONORF_CHOICES = [
"HLA identical sibling", "HLA mismatch relative",
"Matched unrelated donor", "Mismatched unrelated donor or cord blood",
]
GRAFTYPE_CHOICES = ["Bone marrow", "Peripheral blood", "Cord blood"]
CONDGRPF_CHOICES = ["MAC", "RIC", "NMA"]
CONDGRP_FINAL_CHOICES = [
"TBI/Cy", "TBI/Cy/Flu", "TBI/Cy/Flu/TT", "TBI/Mel", "TBI/Flu",
"TBI alone (300/400/600cGy)", "Bu/Cy", "Bu/Mel", "Flu/Bu/TT",
"Flu/Bu", "Flu/Mel/TT", "Flu/Mel", "Cy/Flu", "Treosulfan",
"Cy alone", "Flud", "TLI",
]
ATGF_CHOICES = ["ATG", "Alemtuzumab", "None"]
GVHD_FINAL_CHOICES = [
"Ex-vivo T-cell depletion", "CD34 selection", "Post-CY + siro +- MMF",
"Post-CY + MMF + CNI", "CNI + MMF", "CNI + MTX", "CNI alone",
"CNI + siro", "Siro alone", "MMF + MTX", "MMF + siro", "MMF alone",
"MTX alone", "MTX + siro",
]
HLA_FINAL_CHOICES = ["8/8", "7/8", "≤ 6/8"]
RCMVPR_CHOICES = ["Negative", "Positive"]
EXCHTFPR_CHOICES = ["No", "Yes"]
VOC2YPR_CHOICES = ["No", "Yes"]
VOCFRQPR_CHOICES = ["< 3/yr", "≥ 3/yr"]
SCATXRSN_CHOICES = [
"CNS event", "Acute chest Syndrome", "Recurrent vaso-occlusive pain",
"Recurrent priapism", "Excessive transfusion requirements/iron overload",
"Cardio-pulmonary", "Chronic transfusion", "Asymptomatic",
"Renal insufficiency", "Splenic sequestration", "Avascular necrosis",
"Hodgkin lymphoma",
]
# Features that _values_to_dict() coerces to float (blank/unparsable -> None);
# every other feature is passed through unchanged.
NUM_COLS_SET = {"AGE", "NACS2YR"}
# Upper bound on the number of counterfactual scenario panels built in the UI.
MAX_SCENARIOS = 5
# (label, value) pairs for the "Published conditioning regimen" dropdowns.
# Entries whose value starts with "__header_" are visual section separators:
# apply_grouped_preset() clears the dropdown when one of them is picked.
# "Custom" re-enables the transplant fields instead of filling them.
GROUPED_REGIMEN_CHOICES = [
("── HLA IDENTICAL ──", "__header_hla_identical__"),
("Hsieh et al 2014", "Hsieh et al 2014"),
("Krishnamurti et al 2019", "Krishnamurti et al 2019"),
("King et al 2015", "King et al 2015"),
("Walters et al 1996", "Walters et al 1996"),
("── HLA MISMATCHED ──", "__header_hla_mismatched__"),
("Bolanos-Meade et al 2022 (HLA Mismatch)", "Bolanos-Meade et al 2022 (HLA Mismatch)"),
("Patel et al 2020 (HLA Mismatch)", "Patel et al 2020 (HLA Mismatch)"),
("── MATCHED UNRELATED ──", "__header_matched_unrelated__"),
("L Krishnamurti et al 2019", "L Krishnamurti et al 2019"),
("Shenoy et al 2016", "Shenoy et al 2016"),
("── MISMATCHED UNRELATED / CORD BLOOD ──", "__header_mismatched_cord__"),
("Bolanos-Meade et al 2022 (Mismatched/Cord)", "Bolanos-Meade et al 2022 (Mismatched/Cord)"),
("Patel et al 2020 (Mismatched/Cord)", "Patel et al 2020 (Mismatched/Cord)"),
("── CUSTOM ──", "__header_custom__"),
("Custom", "Custom"),
]
# Values of the separator entries above.
HEADER_VALUES = {v for _, v in GROUPED_REGIMEN_CHOICES if v.startswith("__header_")}
# Transplant-field values that apply_grouped_preset() fills in (and locks) for
# each published regimen. Keys are the non-header, non-"Custom" values of
# GROUPED_REGIMEN_CHOICES.
PUBLISHED_PRESETS = {
"Hsieh et al 2014": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI alone (300/400/600cGy)", "ATGF": "Alemtuzumab", "GVHD_FINAL": "Siro alone", "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"},
"Krishnamurti et al 2019": {"CONDGRPF": "MAC", "CONDGRP_FINAL": "Flu/Bu", "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"},
"King et al 2015": {"CONDGRPF": "RIC", "CONDGRP_FINAL": "Flu/Mel", "ATGF": "Alemtuzumab", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"},
"Walters et al 1996": {"CONDGRPF": "MAC", "CONDGRP_FINAL": "Bu/Cy", "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"},
"Bolanos-Meade et al 2022 (HLA Mismatch)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "HLA mismatch relative"},
"Patel et al 2020 (HLA Mismatch)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu/TT", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "HLA mismatch relative"},
"L Krishnamurti et al 2019": {"CONDGRPF": "MAC", "CONDGRP_FINAL": "Flu/Bu", "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "Matched unrelated donor"},
"Shenoy et al 2016": {"CONDGRPF": "RIC", "CONDGRP_FINAL": "Flu/Mel", "ATGF": "Alemtuzumab", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "Matched unrelated donor"},
"Bolanos-Meade et al 2022 (Mismatched/Cord)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "Mismatched unrelated donor or cord blood"},
"Patel et al 2020 (Mismatched/Cord)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu/TT", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "Mismatched unrelated donor or cord blood"},
}
# Donor types whose HLA match is forced to "8/8" (HLA radio locked), see
# _hla_update_for_donor() and _validate_counterfactual_constraints().
MSD_DONOR = "HLA identical sibling"
MUD_DONOR = "Matched unrelated donor"
LOCKED_88_DONORS = {MSD_DONOR, MUD_DONOR}
# Donor types that may NOT be recorded as "8/8".
NON_88_DONORS = {"HLA mismatch relative", "Mismatched unrelated donor or cord blood"}
# NOTE(review): the two regimen groupings below are not referenced in the
# visible part of this file — confirm they are still used before relying on them.
MSD_REGIMENS = {"Hsieh et al 2014", "Krishnamurti et al 2019", "King et al 2015", "Walters et al 1996"}
MMUD_REGIMENS = {"Bolanos-Meade et al 2022 (Mismatched/Cord)", "Patel et al 2020 (Mismatched/Cord)",
"Bolanos-Meade et al 2022 (HLA Mismatch)", "Patel et al 2020 (HLA Mismatch)"}
# Feature groups, one per UI column. ALL_FEATURES fixes the positional order in
# which component values reach the callbacks (_values_to_dict zips against it).
PATIENT_FEATURES = ["AGE", "AGEGPFF", "SEX", "KPS", "RCMVPR"]
DONOR_FEATURES = ["DONORF", "GRAFTYPE", "HLA_FINAL", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL"]
DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
ALL_FEATURES = PATIENT_FEATURES + DONOR_FEATURES + DISEASE_FEATURES
# Outcomes rendered as icon arrays in the comparison grid (one column each).
ICON_OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"]
# Short card titles; _icon_card_html() falls back to OUTCOME_DESCRIPTIONS.
OUTCOME_TITLES = {
"DEAD": "Death",
"GF": "Graft Failure",
"AGVHD": "Acute GvHD",
"CGVHD": "Chronic GvHD",
"VOCPSHI": "Vaso-Occlusive Crisis",
"STROKEHI": "Stroke Post-HCT",
}
# Icon-array colours: red figures = event, green figures = no event.
EVENT_COLOR = "#e53935"
NO_EVENT_COLOR = "#43a047"
# Outcome order of the SHAP plots; same order as the shap_* outputs of Tab 1.
SHAP_ORDER = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "EFS", "STROKEHI", "OS"]
# Accent colour per scenario (index = scenario number - 1). The
# .scenario-panel-N CSS rules in custom_css use the same hex values.
SCENARIO_COLORS = ["#e65100", "#6a1b9a", "#1b5e20", "#0d47a1", "#b71c1c"]
# ─────────────────────────────────────────────────────────────────────────────
# COMPONENT FACTORY
# ─────────────────────────────────────────────────────────────────────────────
# Per-feature widget specification consumed by make_component():
#   "label"   - human-readable label (make_component may append a suffix)
#   "type"    - "number" -> gr.Number, "dropdown" -> gr.Dropdown,
#               "radio" -> gr.Radio
#   "choices" - allowed values (dropdown / radio only)
#   "kwargs"  - extra constructor arguments; for dropdown / radio,
#               "interactive" (default True) and "info" are extracted
#               explicitly, the rest are forwarded verbatim.
# AGEGPFF is read-only because the UI fills it from AGE via get_age_group().
_FEATURE_META = {
"AGE": {"label": "Age (years)", "type": "number", "kwargs": {"minimum": 0, "maximum": 100, "step": 1}},
"AGEGPFF": {"label": "Age Group", "type": "dropdown", "choices": AGEGPFF_CHOICES, "kwargs": {"interactive": False, "info": "Auto-filled from Age"}},
"SEX": {"label": "Sex", "type": "radio", "choices": SEX_CHOICES},
"KPS": {"label": "Karnofsky Performance Status", "type": "radio", "choices": KPS_CHOICES},
"RCMVPR": {"label": "Recipient CMV Status", "type": "radio", "choices": RCMVPR_CHOICES},
"DONORF": {"label": "Donor Type", "type": "dropdown", "choices": DONORF_CHOICES},
"GRAFTYPE": {"label": "Graft Type", "type": "radio", "choices": GRAFTYPE_CHOICES},
"HLA_FINAL": {"label": "HLA Matching", "type": "radio", "choices": HLA_FINAL_CHOICES},
"CONDGRPF": {"label": "Conditioning Intensity", "type": "radio", "choices": CONDGRPF_CHOICES},
"CONDGRP_FINAL": {"label": "Conditioning Regimen", "type": "dropdown", "choices": CONDGRP_FINAL_CHOICES},
"ATGF": {"label": "Serotherapy", "type": "radio", "choices": ATGF_CHOICES},
"GVHD_FINAL": {"label": "GvHD Prophylaxis", "type": "dropdown", "choices": GVHD_FINAL_CHOICES},
"NACS2YR": {"label": "Acute Chest Syndrome Episodes (2 yrs pre-HCT)", "type": "number", "kwargs": {"minimum": 0, "maximum": 50, "step": 1}},
"EXCHTFPR": {"label": "Exchange Transfusion Pre-HCT", "type": "radio", "choices": EXCHTFPR_CHOICES},
"VOC2YPR": {"label": "Vaso-Occlusive Crisis in 2 yrs Pre-HCT", "type": "radio", "choices": VOC2YPR_CHOICES},
"VOCFRQPR": {"label": "VoC Frequency Pre-HCT", "type": "radio", "choices": VOCFRQPR_CHOICES},
"SCATXRSN": {"label": "Primary Reason for HCT", "type": "dropdown", "choices": SCATXRSN_CHOICES},
}
def make_component(feature: str, suffix: str = ""):
    """Build the empty Gradio input widget described by ``_FEATURE_META[feature]``.

    Args:
        feature: Key into ``_FEATURE_META``.
        suffix: Optional text appended to the label (e.g. "(Scenario 2)").

    Returns:
        A ``gr.Number``, ``gr.Dropdown`` or ``gr.Radio`` with ``value=None``.

    Raises:
        KeyError: If *feature* has no entry in ``_FEATURE_META``.
        ValueError: If the entry declares an unsupported widget type.
    """
    spec = _FEATURE_META[feature]
    base_label = spec["label"]
    text = f"{base_label} {suffix}".strip() if suffix else base_label
    widget_type = spec["type"]
    extra = dict(spec.get("kwargs", {}))  # copy: the pops below must not mutate the spec

    if widget_type == "number":
        return gr.Number(label=text, value=None, **extra)

    # Dropdowns and radios share one construction path; only the class differs.
    choice_widgets = {"dropdown": gr.Dropdown, "radio": gr.Radio}
    if widget_type not in choice_widgets:
        raise ValueError(f"Unknown component type '{widget_type}' for feature '{feature}'")
    widget_cls = choice_widgets[widget_type]
    is_interactive = extra.pop("interactive", True)
    info_text = extra.pop("info", None)
    return widget_cls(label=text, choices=spec["choices"], value=None,
                      interactive=is_interactive, info=info_text, **extra)
# ─────────────────────────────────────────────────────────────────────────────
# CONSTRAINT / VALIDATION HELPERS
# ─────────────────────────────────────────────────────────────────────────────
def _hla_update_for_donor(donor_value):
    """Return a ``gr.update`` restricting the HLA radio to choices valid for *donor_value*.

    - Donors in LOCKED_88_DONORS: only "8/8", pre-selected and read-only.
    - Donors in NON_88_DONORS: "7/8" / "≤ 6/8", selection cleared, editable.
    - No donor, or any other donor: all HLA choices, editable; the current
      value is left untouched (no ``value`` key in the update).
    """
    if donor_value:
        if donor_value in LOCKED_88_DONORS:
            return gr.update(choices=["8/8"], value="8/8", interactive=False)
        if donor_value in NON_88_DONORS:
            return gr.update(choices=["7/8", "≤ 6/8"], value=None, interactive=True)
    return gr.update(choices=HLA_FINAL_CHOICES, interactive=True)
def _validate_counterfactual_constraints(base_dict, wi_dict, label=""):
violations = []
tag = f"[{label}] " if label else ""
if base_dict.get("SEX") and wi_dict.get("SEX"):
if base_dict["SEX"] != wi_dict["SEX"]:
violations.append(f"{tag} Immutable feature: Sex cannot be changed.")
base_age, wi_age = base_dict.get("AGE"), wi_dict.get("AGE")
if base_age is not None and wi_age is not None:
try:
if float(wi_age) < float(base_age):
violations.append(f"{tag}Age cannot be decreased ({base_age}{wi_age}).")
except (TypeError, ValueError):
pass
wi_donor = wi_dict.get("DONORF")
wi_hla = wi_dict.get("HLA_FINAL")
if wi_donor and wi_hla:
if wi_donor in LOCKED_88_DONORS and wi_hla != "8/8":
violations.append(f"{tag}HLA constraint: '{wi_donor}' requires 8/8 HLA.")
elif wi_donor in NON_88_DONORS and wi_hla == "8/8":
violations.append(f"{tag}HLA constraint: '{wi_donor}' cannot have 8/8 HLA.")
if wi_donor:
wi_gvhd = wi_dict.get("GVHD_FINAL", "")
if wi_donor == MSD_DONOR and wi_gvhd in {"Post-CY + siro +- MMF", "Post-CY + MMF + CNI"}:
violations.append(
f"{tag} Post-Cy GVHD prophylaxis is inconsistent with donor '{MSD_DONOR}'."
)
return violations
def _values_to_dict(values):
    """Pair positional component values with feature names in ALL_FEATURES order.

    Features in NUM_COLS_SET are converted to ``float``; blank (``None``/``""``)
    or unparsable values become ``None``. All other features pass through
    unchanged. Pairing uses ``zip``, so surplus values on either side are dropped.
    """
    def _coerce(feature, raw):
        if feature not in NUM_COLS_SET:
            return raw
        try:
            return None if raw in (None, "") else float(raw)
        except (TypeError, ValueError):
            return None

    return {feature: _coerce(feature, raw) for feature, raw in zip(ALL_FEATURES, values)}
def _check_missing(user_vals, label=""):
missing = [f for f, v in user_vals.items()
if v is None or v == "" or (isinstance(v, float) and pd.isna(v))]
if missing:
raise ValueError(
f"{'[' + label + '] ' if label else ''}Please fill in all fields. "
f"Missing: {', '.join(missing)}"
)
def get_age_group(age):
    """Map an age in years to its AGEGPFF bin label.

    Args:
        age: A number or numeric string; ``None`` and ``""`` are accepted.

    Returns:
        One of ``"<=10"``, ``"11-17"``, ``"18-29"``, ``"30-49"``, ``">=50"``,
        or ``""`` when *age* is missing, non-numeric or NaN.
    """
    if age is None or age == "":
        return ""
    try:
        years = float(age)
    except (ValueError, TypeError):
        return ""
    if pd.isna(years):
        # NaN compares False against every bound and would otherwise fall
        # through to the oldest bin.
        return ""
    if years <= 10:
        return "<=10"
    if years <= 17:
        return "11-17"
    if years <= 29:
        return "18-29"
    if years <= 49:
        return "30-49"
    return ">=50"
def vocfrqpr_from_voc2ypr(voc_status):
    """Return a ``gr.update`` syncing the VoC-frequency radio with *voc_status*.

    "No" forces the frequency to "< 3/yr" and makes the radio read-only;
    any other value (including empty) clears the radio and makes it editable.
    """
    no_recent_voc = voc_status == "No"
    if not no_recent_voc:
        return gr.update(value=None, interactive=True)
    return gr.update(value="< 3/yr", interactive=False)
def apply_grouped_preset(selected_value):
    """Translate a "Published conditioning regimen" selection into field updates.

    Returns a 7-tuple of ``gr.update`` objects, in order: the regimen dropdown
    itself, then DONORF, CONDGRPF, CONDGRP_FINAL, ATGF, GVHD_FINAL, HLA_FINAL.

    - Empty value or a section header: clear the dropdown, leave fields as-is.
    - "Custom": make all six fields editable.
    - Unknown value: no-op updates.
    - Published preset: fill the six fields from PUBLISHED_PRESETS and lock
      them; the HLA choices are narrowed according to the preset's donor type.
    """
    n_fields = 6  # DONORF, CONDGRPF, CONDGRP_FINAL, ATGF, GVHD_FINAL, HLA_FINAL
    if not selected_value or selected_value in HEADER_VALUES:
        return (gr.update(value=None), *(gr.update() for _ in range(n_fields)))
    if selected_value == "Custom":
        return (gr.update(), *(gr.update(interactive=True) for _ in range(n_fields)))

    preset = PUBLISHED_PRESETS.get(selected_value)
    if not preset:
        return tuple(gr.update() for _ in range(n_fields + 1))

    donor = preset["DONORF"]
    if donor in LOCKED_88_DONORS:
        hla_choices = ["8/8"]
    elif donor in NON_88_DONORS:
        hla_choices = ["7/8", "≤ 6/8"]
    else:
        hla_choices = HLA_FINAL_CHOICES
    locked_fields = [
        gr.update(value=preset[key], interactive=False)
        for key in ("DONORF", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL")
    ]
    hla_field = gr.update(value=preset["HLA_FINAL"], interactive=False, choices=hla_choices)
    return (gr.update(), *locked_fields, hla_field)
def lock_sex(baseline_sex):
    """Return a ``gr.update`` that disables a Sex radio.

    When *baseline_sex* is set, the radio is also set to that value.
    """
    update_kwargs = {"interactive": False}
    if baseline_sex:
        update_kwargs = {"value": baseline_sex, "interactive": False}
    return gr.update(**update_kwargs)
# ─────────────────────────────────────────────────────────────────────────────
# HTML RENDERERS
# ─────────────────────────────────────────────────────────────────────────────
def _stick_figure_svg(color, size=16):
h = round(size * 1.6)
return (
f'<svg xmlns="http://www.w3.org/2000/svg" width="{size}" height="{h}" '
f'viewBox="0 0 20 32" style="display:block;flex-shrink:0;" '
f'stroke="{color}" stroke-width="2.2" stroke-linecap="round" fill="none">'
f'<circle cx="10" cy="5" r="3.8" fill="{color}" stroke="none"/>'
f'<line x1="10" y1="9" x2="10" y2="20"/>'
f'<line x1="3" y1="13" x2="17" y2="13"/>'
f'<line x1="10" y1="20" x2="4" y2="30"/>'
f'<line x1="10" y1="20" x2="16" y2="30"/>'
f'</svg>'
)
def _icon_card_html(probability, outcome, panel_label="", panel_color="#1565c0"):
    """Render one 10x10 icon-array card for a single outcome probability.

    Args:
        probability: Event probability as a fraction. The number of red
            (event) figures is ``round(probability * 100)`` clamped to
            [0, 100]; the percentage text shows the unclamped value.
        outcome: Outcome code; the title comes from OUTCOME_TITLES, then
            OUTCOME_DESCRIPTIONS, then the code itself.
        panel_label: Optional badge text shown above the title (HTML-escaped).
        panel_color: Badge background colour.

    Returns:
        An HTML string.
    """
    import html

    title = OUTCOME_TITLES.get(outcome, OUTCOME_DESCRIPTIONS.get(outcome, outcome))
    # Clamp so a value marginally outside [0, 1] cannot draw more than 100
    # event figures or report a negative "No Event" count.
    n_event = min(100, max(0, round(probability * 100)))
    n_no_event = 100 - n_event
    pct_str = f"{probability * 100:.1f}%"

    # Only two distinct figures exist; build them once instead of 100 times.
    event_icon = _stick_figure_svg(EVENT_COLOR, size=13)
    no_event_icon = _stick_figure_svg(NO_EVENT_COLOR, size=13)
    rows_parts = []
    for row in range(10):
        cells = "".join(
            event_icon if row * 10 + col < n_event else no_event_icon
            for col in range(10)
        )
        rows_parts.append(
            f'<div style="display:flex;justify-content:center;gap:1px;margin-bottom:1px;">{cells}</div>'
        )
    grid_html = "\n".join(rows_parts)

    fig_e = _stick_figure_svg(EVENT_COLOR, size=11)
    fig_ne = _stick_figure_svg(NO_EVENT_COLOR, size=11)
    legend = (
        f'<div style="display:inline-grid;grid-template-columns:13px 1fr 36px;'
        f'align-items:center;gap:3px;row-gap:3px;">'
        f'{fig_e}<span style="color:{EVENT_COLOR};font-weight:700;font-size:9px;">Event</span>'
        f'<span style="color:#888;font-size:8px;">({n_event}/100)</span>'
        f'{fig_ne}<span style="color:{NO_EVENT_COLOR};font-weight:700;font-size:9px;">No Event</span>'
        f'<span style="color:#888;font-size:8px;">({n_no_event}/100)</span>'
        f'</div>'
    )
    badge = (
        f'<div style="background:{panel_color};color:#fff;font-size:8px;font-weight:700;'
        f'border-radius:3px;padding:1px 5px;margin-bottom:2px;display:inline-block;">'
        f'{html.escape(str(panel_label))}</div>'
    ) if panel_label else ""
    return (
        f'<div style="background:#fff;border:1px solid #e0e0e0;border-radius:7px;'
        f'padding:6px 5px;text-align:center;font-family:\'Segoe UI\',Arial,sans-serif;'
        f'box-shadow:0 2px 4px rgba(0,0,0,0.06);box-sizing:border-box;'
        f'display:flex;flex-direction:column;align-items:center;">'
        f'{badge}'
        f'<div style="min-height:26px;display:flex;align-items:center;justify-content:center;'
        f'font-size:10px;font-weight:700;color:#222;line-height:1.3;margin-bottom:1px;">{title}</div>'
        f'<div style="font-size:18px;font-weight:800;color:{EVENT_COLOR};'
        f'line-height:1;margin-bottom:3px;">{pct_str}</div>'
        f'<div style="margin-bottom:3px;">{grid_html}</div>'
        f'<div>{legend}</div>'
        f'</div>'
    )
def _build_comparison_icon_grid(all_probs_list, labels, colors):
    """Render icon-array cards for several profiles, one row per profile.

    Args:
        all_probs_list: One probability dict per profile (baseline first,
            then scenarios); each must contain every outcome in ICON_OUTCOMES.
        labels: Display label per profile; HTML-escaped because it may be
            free text typed by the user.
        colors: Hex colour per profile, used for the row badge.

    Returns:
        An HTML string: a header row of outcome names, one row per profile
        (columns = ICON_OUTCOMES), and a colour-legend footnote. Rows stop
        at the shortest of the three input sequences.
    """
    import html

    outcome_headers = "".join(
        f'<div style="flex:1 1 0%;min-width:0;text-align:center;font-size:11px;'
        f'font-weight:700;color:#555;padding:4px 2px;">'
        f'{OUTCOME_TITLES.get(o, o)}</div>'
        for o in ICON_OUTCOMES
    )
    # The left padding matches the 104px label column plus the 6px gap.
    header_row = (
        f'<div style="display:flex;gap:6px;margin-bottom:4px;padding-left:110px;">'
        f'{outcome_headers}</div>'
    )
    row_chunks = [header_row]
    for probs, label, color in zip(all_probs_list, labels, colors):
        safe_label = html.escape(str(label))
        # The title attribute exposes the full name when the badge is truncated.
        row_label = (
            f'<div style="width:104px;flex-shrink:0;display:flex;align-items:center;'
            f'justify-content:flex-end;padding-right:6px;">'
            f'<span style="background:{color};color:#fff;font-size:9px;font-weight:700;'
            f'border-radius:4px;padding:2px 6px;white-space:nowrap;overflow:hidden;'
            f'text-overflow:ellipsis;max-width:100px;" title="{safe_label}">{safe_label}</span></div>'
        )
        cards = "".join(
            f'<div style="flex:1 1 0%;min-width:0;">'
            f'{_icon_card_html(probs[o], o, "", color)}'
            f'</div>'
            for o in ICON_OUTCOMES
        )
        row_chunks.append(
            f'<div style="display:flex;gap:6px;margin-bottom:8px;align-items:stretch;">'
            f'{row_label}{cards}</div>'
        )
    footnote = (
        f'<div style="font-size:10px;color:#888;text-align:center;margin-top:2px;">'
        f'Each figure = 1 patient out of 100. '
        f'<span style="color:{EVENT_COLOR};font-weight:600;">&#9632; Red = Event</span> &nbsp; '
        f'<span style="color:{NO_EVENT_COLOR};font-weight:600;">&#9632; Green = No Event</span>'
        f'</div>'
    )
    rows_html = "".join(row_chunks)
    return f'<div style="font-family:\'Segoe UI\',Arial,sans-serif;padding:4px 0;">{rows_html}{footnote}</div>'
def _delta_color_html(delta, is_survival):
if abs(delta) < 0.0005:
color = "#888888"
elif (delta > 0 and is_survival) or (delta < 0 and not is_survival):
color = "#2e7d32"
else:
color = "#c62828"
sign = "+" if delta >= 0 else ""
return f'<span style="color:{color};font-weight:700;">{sign}{delta*100:.1f}%</span>'
def _build_comparison_table_html(base_probs, base_ci, scenario_probs_list, scenario_ci_list, scenario_labels, scenario_colors):
    """Render the baseline-vs-scenarios comparison table as HTML.

    Layout: rows = Baseline + one row per scenario; columns = the outcomes of
    REPORTING_OUTCOMES that are present in *base_probs*.

    Args:
        base_probs: Baseline probability per outcome code.
        base_ci: Baseline ``(low, high)`` interval per outcome code; missing
            entries are shown as NaN.
        scenario_probs_list: One probability dict per scenario.
        scenario_ci_list: One interval dict per scenario.
        scenario_labels: Display name per scenario (HTML-escaped, since it may
            be free text typed by the user).
        scenario_colors: Colour palette, cycled by scenario index.

    Returns:
        An HTML string containing the table followed by a legend footer.
        A scenario that lacks an outcome gets a "—" cell so its row stays
        aligned with the header.
    """
    import html

    survival_outcomes = {"OS", "EFS"}  # higher probability is better
    missing_ci = (float("nan"), float("nan"))
    # One shared column list for the header, baseline and scenario rows.
    shown_outcomes = [o for o in REPORTING_OUTCOMES if o in base_probs]

    outcome_headers = "".join(
        f"<th style='padding:8px 10px;text-align:center;border-bottom:2px solid #ccd;"
        f"color:#333;font-size:11px;'>{OUTCOME_DESCRIPTIONS.get(o, o)}</th>"
        for o in shown_outcomes
    )
    header = (
        "<div style='overflow-x:auto;'>"
        "<table style='width:100%;border-collapse:collapse;"
        "font-family:\"Segoe UI\",Arial,sans-serif;font-size:12px;'>"
        "<thead><tr style='background:#f0f4f8;'>"
        "<th style='padding:9px 12px;text-align:left;border-bottom:2px solid #ccd;"
        "color:#333;min-width:120px;'>Scenario</th>"
        f"{outcome_headers}"
        "</tr></thead><tbody>"
    )

    # Baseline row
    baseline_cells = []
    for o in shown_outcomes:
        bp = base_probs[o]
        blo, bhi = base_ci.get(o, missing_ci)
        baseline_cells.append(
            f"<td style='padding:7px 10px;text-align:center;'>"
            f"<div style='font-weight:700;color:#1565c0;font-size:13px;'>{bp*100:.1f}%</div>"
            f"<div style='color:#5c7fa8;font-size:9px;'>[{blo*100:.1f}%–{bhi*100:.1f}%]</div>"
            f"</td>"
        )
    baseline_html = "".join(baseline_cells)
    rows = [
        f"<tr style='background:#e8f0fb;'>"
        f"<td style='padding:8px 12px;font-weight:700;color:#1565c0;border-right:2px solid #ccd;'>"
        f"<div style='display:flex;align-items:center;gap:6px;'>"
        f"<span style='display:inline-block;width:10px;height:10px;border-radius:50%;"
        f"background:#1565c0;flex-shrink:0;'></span>Baseline</div></td>"
        f"{baseline_html}</tr>"
    ]

    # Scenario rows
    empty_cell = "<td style='padding:7px 10px;text-align:center;color:#888;'>—</td>"
    for j, (sp_dict, sci_dict, s_label) in enumerate(
        zip(scenario_probs_list, scenario_ci_list, scenario_labels)
    ):
        sc_color = scenario_colors[j % len(scenario_colors)]
        bg = "#fafbfc" if j % 2 == 0 else "#ffffff"
        cells = []
        for o in shown_outcomes:
            if o not in sp_dict:
                cells.append(empty_cell)
                continue
            bp = base_probs[o]
            wp = sp_dict[o]
            wlo, whi = sci_dict.get(o, missing_ci)
            delta_html = _delta_color_html(wp - bp, o in survival_outcomes)
            cells.append(
                f"<td style='padding:7px 10px;text-align:center;'>"
                f"<div style='font-weight:700;color:{sc_color};font-size:13px;'>{wp*100:.1f}%</div>"
                f"<div style='color:#888;font-size:9px;'>[{wlo*100:.1f}%–{whi*100:.1f}%]</div>"
                f"<div style='font-size:10px;margin-top:1px;'>{delta_html}</div>"
                f"</td>"
            )
        safe_label = html.escape(str(s_label))
        scenario_html = "".join(cells)
        rows.append(
            f"<tr style='background:{bg};'>"
            f"<td style='padding:8px 12px;font-weight:600;border-right:2px solid #ccd;'>"
            f"<div style='display:flex;align-items:center;gap:6px;'>"
            f"<span style='display:inline-block;width:10px;height:10px;border-radius:50%;"
            f"background:{sc_color};flex-shrink:0;'></span>"
            f"<span style='color:{sc_color};'>{safe_label}</span></div></td>"
            f"{scenario_html}</tr>"
        )

    footer = (
        "</tbody></table></div>"
        "<div style='font-size:10.5px;color:#888;margin-top:6px;'>"
        "Δ from Baseline: <span style='color:#2e7d32;font-weight:600;'>Green = improvement</span> &nbsp;"
        "<span style='color:#c62828;font-weight:600;'>Red = worsening</span> &nbsp;|&nbsp; "
        "OS &amp; EFS: higher is better; all other outcomes: lower is better."
        "</div>"
    )
    return header + "".join(rows) + footer
def _build_violation_html(violations):
if not violations:
return ""
items = "".join(f"<li style='margin-bottom:6px;'>{v}</li>" for v in violations)
return (
f'<div style="background:#fff3e0;border:2px solid #e65100;border-radius:8px;'
f'padding:14px 18px;font-family:\'Segoe UI\',Arial,sans-serif;margin-bottom:12px;">'
f'<div style="font-weight:700;font-size:14px;color:#bf360c;margin-bottom:8px;">'
f'Constraint Violations — Analysis blocked</div>'
f'<ul style="margin:0;padding-left:20px;color:#6d1f00;font-size:13px;">{items}</ul>'
f'<div style="margin-top:10px;font-size:11px;color:#888;">'
f'Please correct the above before running the comparison.</div>'
f'</div>'
)
# ─────────────────────────────────────────────────────────────────────────────
# MAIN PREDICT CALLBACK (Tab 1)
# ─────────────────────────────────────────────────────────────────────────────
def predict_gradio(*values):
    """Gradio callback for the "Predict" button on the first tab.

    Args:
        *values: Component values in ALL_FEATURES order.

    Returns:
        A tuple containing:
        - a DataFrame with columns Outcome / Probability / 95% CI, one row per
          REPORTING_OUTCOMES entry;
        - the icon-array grid HTML;
        - one SHAP plot per outcome, in SHAP_ORDER.

    Raises:
        gr.Error: Wraps any failure, including missing inputs, so Gradio shows
            it to the user. The full traceback is printed to stdout first.
    """
    try:
        user_vals = _values_to_dict(values)
        _check_missing(user_vals)  # same message as the former inline check
        calibrated, _ = predict_with_comparison(user_vals)
        calibrated_probs, calibrated_intervals = calibrated
        rows = []
        for outcome in REPORTING_OUTCOMES:
            ci_low, ci_high = calibrated_intervals[outcome]
            rows.append({
                "Outcome": OUTCOME_DESCRIPTIONS[outcome],
                "Probability": f"{calibrated_probs[outcome] * 100:.1f}%",
                "95% CI": f"[{ci_low * 100:.1f}% – {ci_high * 100:.1f}%]",
            })
        df = pd.DataFrame(rows)
        shap_plots = create_all_shap_plots(user_vals, max_display=10)
        icon_arrays = create_all_icon_arrays(calibrated_probs)
        # SHAP_ORDER lists outcomes in the same order as the shap_* output
        # components this callback is wired to.
        return (df, icon_arrays["__grid__"], *(shap_plots[o] for o in SHAP_ORDER))
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"{type(e).__name__}: {str(e)}") from e
# ─────────────────────────────────────────────────────────────────────────────
# CSS
# ─────────────────────────────────────────────────────────────────────────────
custom_css = """
.predict-button {
background: linear-gradient(to right, #ff6b35, #ff8c42) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 16px !important; padding: 12px !important;
}
.predict-button:hover { background: linear-gradient(to right, #ff5722, #ff7b29) !important; }
.copy-button {
background: linear-gradient(to right, #388e3c, #66bb6a) !important;
border: none !important; color: white !important; font-weight: 600 !important;
}
.copy-button:hover { background: linear-gradient(to right, #2e7d32, #43a047) !important; }
.copy-from-predict-button {
background: linear-gradient(to right, #6a1b9a, #ab47bc) !important;
border: none !important; color: white !important; font-weight: 600 !important;
}
.copy-from-predict-button:hover { background: linear-gradient(to right, #4a148c, #8e24aa) !important; }
.counterfactual-button {
background: linear-gradient(to right, #1976d2, #42a5f5) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 15px !important; padding: 12px !important;
}
.counterfactual-button:hover { background: linear-gradient(to right, #1565c0, #1e88e5) !important; }
.output-dataframe table td:first-child,
.output-dataframe table th:first-child {
white-space: normal !important; word-break: break-word !important; min-width: 240px !important;
}
.constraint-info {
background: #e8f5e9; border-left: 4px solid #388e3c;
padding: 8px 14px; font-size: 12px; color: #1b5e20;
border-radius: 4px; margin-bottom: 8px;
}
.scenario-panel-0 { border-left: 4px solid #e65100 !important; }
.scenario-panel-1 { border-left: 4px solid #6a1b9a !important; }
.scenario-panel-2 { border-left: 4px solid #1b5e20 !important; }
.scenario-panel-3 { border-left: 4px solid #0d47a1 !important; }
.scenario-panel-4 { border-left: 4px solid #b71c1c !important; }
"""
# ─────────────────────────────────────────────────────────────────────────────
# BUILD UI
# ─────────────────────────────────────────────────────────────────────────────
with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
gr.Markdown("# HCT Outcome Prediction Model")
# ── shared state: how many counterfactual scenarios are active ────────────
n_scenarios_state = gr.State(1)
with gr.Tabs():
# ══════════════════════════════════════════════════════════════════════
# TAB 1 — PREDICT OUTCOMES
# ══════════════════════════════════════════════════════════════════════
with gr.Tab("Predict Outcomes"):
gr.Markdown("Enter patient, transplant, and disease characteristics to predict outcomes.")
inputs_dict = {}
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Patient Characteristics")
for f in PATIENT_FEATURES:
inputs_dict[f] = make_component(f)
with gr.Column(scale=1):
gr.Markdown("### Transplant Characteristics")
grouped_dd = gr.Dropdown(
choices=GROUPED_REGIMEN_CHOICES, value=None,
label="Published conditioning regimen",
info="Auto-fills Donor Type, Conditioning Intensity, Regimen, Serotherapy, GVHD Prophylaxis",
)
p_donorf = inputs_dict["DONORF"] = make_component("DONORF")
inputs_dict["GRAFTYPE"] = make_component("GRAFTYPE")
p_condgrpf = inputs_dict["CONDGRPF"] = make_component("CONDGRPF")
p_condgrp_final = inputs_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL")
p_atgf = inputs_dict["ATGF"] = make_component("ATGF")
p_gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
p_hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL")
with gr.Column(scale=1):
gr.Markdown("### Disease Characteristics")
for f in DISEASE_FEATURES:
inputs_dict[f] = make_component(f)
inputs_dict["AGE"].change(get_age_group, inputs_dict["AGE"], inputs_dict["AGEGPFF"])
inputs_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, inputs_dict["VOC2YPR"], inputs_dict["VOCFRQPR"])
p_donorf.change(fn=_hla_update_for_donor, inputs=p_donorf, outputs=p_hla_final)
grouped_dd.change(
apply_grouped_preset, grouped_dd,
[grouped_dd, p_donorf, p_condgrpf, p_condgrp_final, p_atgf, p_gvhd_final, p_hla_final],
)
inputs_list = [inputs_dict[f] for f in ALL_FEATURES]
predict_btn = gr.Button("Predict", elem_classes="predict-button", size="lg")
gr.Markdown("---")
gr.Markdown("## Prediction Results")
output_table = gr.Dataframe(
headers=["Outcome", "Probability", "95% CI"],
label="Predicted Outcomes",
elem_classes="output-dataframe",
row_count=(len(REPORTING_OUTCOMES), "dynamic"),
column_count=(3, "fixed"),
wrap=True,
)
gr.Markdown("---")
gr.Markdown("## Outcome Probability — Icon Arrays")
icon_array_grid = gr.HTML()
gr.Markdown("---")
gr.Markdown("## SHAP — Feature Importance")
with gr.Row():
shap_dead = gr.Plot(label="Death")
shap_gf = gr.Plot(label="Graft Failure")
shap_agvhd = gr.Plot(label="Acute GvHD")
shap_cgvhd = gr.Plot(label="Chronic GvHD")
with gr.Row():
shap_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
shap_efs = gr.Plot(label="Event-Free Survival")
shap_stroke = gr.Plot(label="Stroke Post-HCT")
shap_os = gr.Plot(label="Overall Survival")
predict_btn.click(
fn=predict_gradio,
inputs=inputs_list,
outputs=[
output_table, icon_array_grid,
shap_dead, shap_gf, shap_agvhd, shap_cgvhd,
shap_vocpshi, shap_efs, shap_stroke, shap_os,
],
)
# ══════════════════════════════════════════════════════════════════════
# TAB 2 — COUNTERFACTUAL ANALYSIS (dynamic multi-scenario)
# ══════════════════════════════════════════════════════════════════════
with gr.Tab("Counterfactual Scenarios"):
gr.Markdown(
"## Counterfactual Scenario Analysis\n"
"Enter the **baseline** patient, then choose how many counterfactual "
"scenarios you want to compare. Each scenario panel will appear below."
)
# ── Number of scenarios selector ──────────────────────────────────
with gr.Row():
n_scenarios_slider = gr.Slider(
minimum=1, maximum=MAX_SCENARIOS, step=1, value=1,
label=f"How many counterfactual scenarios do you want to compare? (1–{MAX_SCENARIOS})",
info="Adjust this first — scenario panels will appear/disappear below.",
)
gr.Markdown("---")
# ── BASELINE ─────────────────────────────────────────────────────
gr.Markdown("## Baseline Patient Profile")
with gr.Row():
copy_from_predict_btn = gr.Button(
"Copy from Predict tab → Baseline",
elem_classes="copy-from-predict-button",
size="sm",
)
wi_baseline_dict = {}
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Patient Characteristics")
for f in PATIENT_FEATURES:
wi_baseline_dict[f] = make_component(f)
with gr.Column(scale=1):
gr.Markdown("### Transplant Characteristics")
wi_grouped_base = gr.Dropdown(
choices=GROUPED_REGIMEN_CHOICES, value=None,
label="Published conditioning regimen (Baseline)",
)
wb_donorf = wi_baseline_dict["DONORF"] = make_component("DONORF")
wi_baseline_dict["GRAFTYPE"] = make_component("GRAFTYPE")
wb_condgrpf = wi_baseline_dict["CONDGRPF"] = make_component("CONDGRPF")
wb_condgrp_final = wi_baseline_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL")
wb_atgf = wi_baseline_dict["ATGF"] = make_component("ATGF")
wb_gvhd_final = wi_baseline_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
wb_hla_final = wi_baseline_dict["HLA_FINAL"] = make_component("HLA_FINAL")
with gr.Column(scale=1):
gr.Markdown("### Disease Characteristics")
for f in DISEASE_FEATURES:
wi_baseline_dict[f] = make_component(f)
wi_baseline_dict["AGE"].change(get_age_group, wi_baseline_dict["AGE"], wi_baseline_dict["AGEGPFF"])
wi_baseline_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, wi_baseline_dict["VOC2YPR"], wi_baseline_dict["VOCFRQPR"])
wb_donorf.change(fn=_hla_update_for_donor, inputs=wb_donorf, outputs=wb_hla_final)
wi_grouped_base.change(
apply_grouped_preset, wi_grouped_base,
[wi_grouped_base, wb_donorf, wb_condgrpf, wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
)
wi_baseline_list = [wi_baseline_dict[f] for f in ALL_FEATURES]
# Copy from Predict tab → Baseline
def _copy_predict_to_baseline(*vals):
    """Copy the Predict-tab inputs into the baseline panel.

    ``vals`` holds one value per entry of ``ALL_FEATURES``, optionally
    followed by the Predict tab's published-regimen selection.

    Returns the feature values unchanged, then an update that sets the
    baseline regimen dropdown, then the six transplant-field updates
    (DONORF, CONDGRPF, CONDGRP_FINAL, ATGF, GVHD_FINAL, HLA_FINAL) that
    ``apply_grouped_preset`` produces for that regimen.
    """
    feature_count = len(ALL_FEATURES)
    if len(vals) > feature_count:
        regimen = vals[feature_count]
    else:
        regimen = None
    preset = apply_grouped_preset(regimen)
    dropdown_update = gr.update(value=regimen)
    # The preset's first output targets the dropdown itself. It is replaced by
    # the explicit value update above, so only the remaining six are kept.
    (donor_u, condgrpf_u, condgrp_final_u,
     atg_u, gvhd_u, hla_u) = preset[1:]
    copied = list(vals[:feature_count])
    return (*copied, dropdown_update, donor_u, condgrpf_u,
            condgrp_final_u, atg_u, gvhd_u, hla_u)
# NOTE(review): if ALL_FEATURES contains DONORF/CONDGRPF/CONDGRP_FINAL/ATGF/
# GVHD_FINAL/HLA_FINAL, those components appear twice in `outputs` (once via
# wi_baseline_list, once in the preset list). The later preset update is
# presumably the one that wins. Confirm against the Gradio version in use.
copy_from_predict_btn.click(
fn=_copy_predict_to_baseline,
inputs=inputs_list + [grouped_dd],
outputs=wi_baseline_list + [wi_grouped_base, wb_donorf, wb_condgrpf,
wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
)
gr.Markdown("---")
# ── SCENARIO PANELS ───────────────────────────────────────────────
# Per-scenario component registries, one entry per slot (MAX_SCENARIOS total).
scenario_dicts = []
scenario_lists = []
scenario_grouped_dds = []
scenario_rows = []
scenario_name_inputs = [] # NEW: one gr.Textbox per scenario for custom name
# NOTE(review): the *_handles lists are only appended to within this chunk;
# presumably they are read elsewhere. Verify before removing.
scenario_donorf_handles = []
scenario_hla_handles = []
scenario_voc2ypr_handles = []
scenario_age_handles = []
scenario_agegp_handles = []
scenario_copy_btns = []
# All MAX_SCENARIOS panels are built up front. Only the first starts visible,
# and the slider toggles the rest.
for s_idx in range(MAX_SCENARIOS):
color = SCENARIO_COLORS[s_idx]
label = f"Scenario {s_idx + 1}"
suffix = f"({label})"
visible_init = (s_idx == 0)
with gr.Row(visible=visible_init) as s_row:
scenario_rows.append(s_row)
with gr.Column():
# Colored header banner for this scenario slot.
gr.HTML(
f'<div style="background:{color};color:#fff;font-weight:700;'
f'font-size:15px;padding:8px 14px;border-radius:6px;margin-bottom:6px;">'
f'Counterfactual {label}</div>'
)
with gr.Row():
# NEW: scenario name input
s_name_input = gr.Textbox(
label=f"Scenario name ({label})",
value=label,
placeholder=f"e.g. {label}",
scale=2,
)
scenario_name_inputs.append(s_name_input)
copy_btn = gr.Button(
f"⬇ Copy Baseline → {label}",
elem_classes="copy-button",
size="sm",
scale=1,
)
scenario_copy_btns.append(copy_btn)
# Same layout as the baseline panel; `suffix` makes each component's label
# scenario-specific.
s_dict = {}
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(f"#### Patient — {label}")
for f in PATIENT_FEATURES:
s_dict[f] = make_component(f, suffix)
with gr.Column(scale=1):
gr.Markdown(f"#### Transplant — {label}")
s_grouped_dd = gr.Dropdown(
choices=GROUPED_REGIMEN_CHOICES, value=None,
label=f"Published regimen ({label})",
)
scenario_grouped_dds.append(s_grouped_dd)
s_donorf = s_dict["DONORF"] = make_component("DONORF", suffix)
s_dict["GRAFTYPE"] = make_component("GRAFTYPE", suffix)
s_condgrpf = s_dict["CONDGRPF"] = make_component("CONDGRPF", suffix)
s_condgrp_final = s_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL", suffix)
s_atgf = s_dict["ATGF"] = make_component("ATGF", suffix)
s_gvhd_final = s_dict["GVHD_FINAL"] = make_component("GVHD_FINAL", suffix)
s_hla_final = s_dict["HLA_FINAL"] = make_component("HLA_FINAL", suffix)
with gr.Column(scale=1):
gr.Markdown(f"#### Disease — {label}")
for f in DISEASE_FEATURES:
s_dict[f] = make_component(f, suffix)
# Scenario components in ALL_FEATURES order; this matches the baseline
# layout that run_counterfactual slices.
scenario_dicts.append(s_dict)
scenario_lists.append([s_dict[f] for f in ALL_FEATURES])
scenario_donorf_handles.append(s_donorf)
scenario_hla_handles.append(s_hla_final)
scenario_voc2ypr_handles.append(s_dict["VOC2YPR"])
scenario_age_handles.append(s_dict["AGE"])
scenario_agegp_handles.append(s_dict["AGEGPFF"])
# Wire up constraints within each scenario panel
s_dict["AGE"].change(get_age_group, s_dict["AGE"], s_dict["AGEGPFF"])
s_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, s_dict["VOC2YPR"], s_dict["VOCFRQPR"])
s_donorf.change(fn=_hla_update_for_donor, inputs=s_donorf, outputs=s_hla_final)
s_grouped_dd.change(
apply_grouped_preset, s_grouped_dd,
[s_grouped_dd, s_donorf, s_condgrpf, s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final],
)
# Mirror baseline sex into each scenario and lock it
wi_baseline_dict["SEX"].change(
fn=lock_sex,
inputs=wi_baseline_dict["SEX"],
outputs=s_dict["SEX"],
)
# NOTE(review): setting `interactive` after construction relies on Gradio
# reading the attribute when it builds the app config. Confirm on the
# installed Gradio version; passing interactive=False at construction is
# the documented route.
s_dict["SEX"].interactive = False
# Copy Baseline → this scenario
def _make_copy_fn(s_grouped_dd_ref, s_donorf_ref, s_condgrpf_ref,
                  s_condgrp_final_ref, s_atgf_ref, s_gvhd_final_ref,
                  s_hla_final_ref):
    """Build the click handler that copies the baseline into one scenario.

    The ``*_ref`` parameters are accepted for call-site compatibility, but the
    returned handler does not use them. The target components are bound by the
    ``outputs`` list passed to ``copy_btn.click`` instead.

    The returned handler takes one value per entry of ``ALL_FEATURES``,
    optionally followed by the baseline regimen selection. It returns those
    feature values, an update that sets the scenario's regimen dropdown, and
    the six transplant-field updates produced by ``apply_grouped_preset``.
    """
    def _copy(*vals):
        width = len(ALL_FEATURES)
        regimen = None
        if len(vals) > width:
            regimen = vals[width]
        preset = apply_grouped_preset(regimen)
        dropdown_update = gr.update(value=regimen)
        # preset[0] targets the dropdown and is superseded by dropdown_update.
        (donor_u, condgrpf_u, condgrp_final_u,
         atg_u, gvhd_u, hla_u) = preset[1:]
        head = list(vals[:width])
        return (*head, dropdown_update,
                donor_u, condgrpf_u, condgrp_final_u, atg_u, gvhd_u, hla_u)

    return _copy
# Copy the baseline values (and its regimen preset) into this scenario's
# components. scenario_lists[-1] is the list appended for this slot above.
copy_btn.click(
fn=_make_copy_fn(s_grouped_dd, s_donorf, s_condgrpf,
s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final),
inputs=wi_baseline_list + [wi_grouped_base],
outputs=(scenario_lists[-1]
+ [s_grouped_dd, s_donorf, s_condgrpf,
s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final]),
)
# Re-lock sex after copy
# (A second, independent listener on the same click; it writes the
# lock_sex result for the baseline SEX into this scenario's SEX component.)
copy_btn.click(
fn=lock_sex,
inputs=wi_baseline_dict["SEX"],
outputs=s_dict["SEX"],
)
# ── Slider → show/hide scenario panels ───────────────────────────
def _update_scenario_visibility(n):
    """Return one visibility update per scenario slot; slots below ``n`` are shown."""
    updates = []
    for slot in range(MAX_SCENARIOS):
        updates.append(gr.update(visible=slot < int(n)))
    return updates
# Slider listener 1: show the first n scenario panels and hide the rest.
n_scenarios_slider.change(
fn=_update_scenario_visibility,
inputs=n_scenarios_slider,
outputs=scenario_rows,
)
# Slider listener 2: mirror the slider value (as int) into n_scenarios_state,
# which run_counterfactual receives as its first input.
n_scenarios_slider.change(
fn=lambda n: int(n),
inputs=n_scenarios_slider,
outputs=n_scenarios_state,
)
# ── RUN ──────────────────────────────────────────────────────────
gr.Markdown("---")
wi_run_btn = gr.Button(
"Run Counterfactual Comparison",
elem_classes="counterfactual-button",
size="lg",
)
# ── RESULTS ──────────────────────────────────────────────────────
# Output slots filled by run_counterfactual: violations, comparison table,
# then icon arrays.
gr.Markdown("## Comparison Results")
wi_violation_html = gr.HTML()
gr.Markdown("### Outcome Probability Table")
wi_table_html = gr.HTML()
gr.Markdown("---")
# NEW: Collapsible icon arrays section
with gr.Accordion("Outcome Icon Arrays — Baseline vs Scenarios", open=False):
gr.Markdown("*Icon arrays show each outcome probability per 100 patients.*")
wi_icon_html = gr.HTML()
gr.Markdown("---")
# Collapsible SHAP section. Each scenario slot's heading and plots share one
# Column, so toggling that Column's visibility hides the heading together
# with its plots. Headings placed outside the toggled container would stay
# visible for every slot.
with gr.Accordion("SHAP Feature Importance", open=False):
    gr.Markdown("### Baseline")
    with gr.Row():
        wi_shap_base = {o: gr.Plot(label=f"{o} — Baseline") for o in SHAP_ORDER}
    # One (container, {outcome: Plot}) pair per scenario slot. Only slot 0
    # starts visible; the slider toggles the container.
    wi_shap_scenarios = []
    for s_idx in range(MAX_SCENARIOS):
        with gr.Column(visible=(s_idx == 0)) as shap_row:
            gr.Markdown(f"### Scenario {s_idx + 1}")
            with gr.Row():
                shap_plots_s = {
                    o: gr.Plot(label=f"{o} — Scenario {s_idx + 1}")
                    for o in SHAP_ORDER
                }
        wi_shap_scenarios.append((shap_row, shap_plots_s))
# Also toggle SHAP rows with slider
def _update_shap_visibility(n):
    """Return one visibility update per SHAP scenario slot; slots below ``n`` are shown."""
    def _is_shown(slot):
        return slot < int(n)

    return [gr.update(visible=_is_shown(slot)) for slot in range(MAX_SCENARIOS)]
# Slider listener 3: toggle the per-scenario SHAP containers in step with the
# scenario panels.
n_scenarios_slider.change(
fn=_update_shap_visibility,
inputs=n_scenarios_slider,
outputs=[row for row, _ in wi_shap_scenarios],
)
# ── RUN callback ──────────────────────────────────────────────────
def run_counterfactual(*all_values):
    """Predict outcomes for the baseline and each active scenario and render them.

    ``all_values`` follows the layout of ``all_run_inputs``:

    * ``all_values[0]``: number of active scenarios (``n_scenarios_state``);
    * the next ``len(ALL_FEATURES)`` values: baseline features;
    * the next ``MAX_SCENARIOS`` values: custom scenario names;
    * then ``MAX_SCENARIOS`` consecutive blocks of ``len(ALL_FEATURES)``
      scenario feature values.

    Returns a flat tuple in the order of ``all_run_outputs``: violation HTML,
    table HTML, icon-grid HTML, one baseline SHAP plot per outcome in
    ``SHAP_ORDER``, then one SHAP plot per outcome for every scenario slot.
    Inactive slots get ``None``, and so do all SHAP slots when constraint
    violations are reported.

    Raises:
        gr.Error: if any step fails. A ``gr.Error`` raised by a helper is
        propagated unchanged so its message reaches the user as written; any
        other exception is logged and wrapped.
    """
    n_feat = len(ALL_FEATURES)
    n_shap = len(SHAP_ORDER)
    # The state is driven by a slider bounded by MAX_SCENARIOS. Clamp anyway
    # so an out-of-range value can never index past the wired scenario slots.
    n_scen = max(0, min(int(all_values[0]), MAX_SCENARIOS))
    base_vals = all_values[1 : 1 + n_feat]
    # Scenario names come next (MAX_SCENARIOS of them).
    name_offset = 1 + n_feat
    scenario_names_raw = all_values[name_offset : name_offset + MAX_SCENARIOS]
    # Then MAX_SCENARIOS consecutive blocks of n_feat scenario values.
    feat_offset = name_offset + MAX_SCENARIOS
    scenario_val_blocks = [
        all_values[feat_offset + s * n_feat : feat_offset + (s + 1) * n_feat]
        for s in range(MAX_SCENARIOS)
    ]
    # Cleared SHAP outputs. They are sized from SHAP_ORDER (not a literal) so
    # the returned tuple always has as many entries as all_run_outputs, which
    # is also built from SHAP_ORDER.
    base_shap_outputs = [None] * n_shap
    scene_shap_outputs = [[None] * n_shap for _ in range(MAX_SCENARIOS)]
    try:
        base_dict = _values_to_dict(base_vals)
        _check_missing(base_dict, "Baseline")
        all_violations = []
        scenario_dicts_active = []
        for s in range(n_scen):
            sd = _values_to_dict(scenario_val_blocks[s])
            _check_missing(sd, f"Scenario {s+1}")
            all_violations.extend(
                _validate_counterfactual_constraints(base_dict, sd, f"Scenario {s+1}")
            )
            scenario_dicts_active.append(sd)
        if all_violations:
            # Report the violations and clear every other output.
            return (
                _build_violation_html(all_violations), "", "",
                *base_shap_outputs,
                *[p for plots in scene_shap_outputs for p in plots],
            )
        base_probs, base_ci = predict_all_outcomes(
            base_dict, use_calibration=True, use_signed_voting=True, n_boot_ci=500
        )
        scen_probs_list = []
        scen_ci_list = []
        for sd in scenario_dicts_active:
            sp, sci = predict_all_outcomes(
                sd, use_calibration=True, use_signed_voting=True, n_boot_ci=500
            )
            scen_probs_list.append(sp)
            scen_ci_list.append(sci)
        # Use custom scenario names, falling back to "Scenario N" when blank.
        scen_labels = []
        for i in range(n_scen):
            raw_name = scenario_names_raw[i] if i < len(scenario_names_raw) else ""
            name = str(raw_name).strip() if raw_name else ""
            scen_labels.append(name if name else f"Scenario {i+1}")
        scen_colors = SCENARIO_COLORS[:n_scen]
        table_html = _build_comparison_table_html(
            base_probs, base_ci, scen_probs_list, scen_ci_list, scen_labels, scen_colors
        )
        icon_html = _build_comparison_icon_grid(
            [base_probs] + scen_probs_list,
            ["Baseline"] + scen_labels,
            ["#1565c0"] + scen_colors,
        )
        # SHAP: fill the baseline slot and each active scenario slot; inactive
        # slots keep their None placeholders, which clears stale plots.
        base_shap_plots = create_all_shap_plots(base_dict, max_display=10)
        base_shap_outputs = [base_shap_plots[o] for o in SHAP_ORDER]
        for s, sd in enumerate(scenario_dicts_active):
            sp_plots = create_all_shap_plots(sd, max_display=10)
            scene_shap_outputs[s] = [sp_plots[o] for o in SHAP_ORDER]
        return (
            "",
            table_html,
            icon_html,
            *base_shap_outputs,
            *[p for plots in scene_shap_outputs for p in plots],
        )
    except gr.Error:
        # Already user-facing (e.g. from a validation helper); don't re-wrap,
        # which would prefix the message with the class name "Error: ".
        raise
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"{type(e).__name__}: {str(e)}") from e
# Build flat output list: violation + table + icon + len(SHAP_ORDER) baseline
# SHAP plots + MAX_SCENARIOS × len(SHAP_ORDER) scenario SHAP plots
all_run_outputs = (
[wi_violation_html, wi_table_html, wi_icon_html]
+ [wi_shap_base[o] for o in SHAP_ORDER]
+ [shap_plots_s[o] for _, shap_plots_s in wi_shap_scenarios for o in SHAP_ORDER]
)
# Build flat input list:
# n_scenarios_state + wi_baseline_list + scenario_name_inputs + MAX_SCENARIOS × scenario_list
# (run_counterfactual slices its *all_values using exactly this order.)
all_run_inputs = (
[n_scenarios_state]
+ wi_baseline_list
+ scenario_name_inputs # NEW: custom names
+ [feat for s_list in scenario_lists for feat in s_list]
)
wi_run_btn.click(
fn=run_counterfactual,
inputs=all_run_inputs,
outputs=all_run_outputs,
)
# ─────────────────────────────────────────────────────────────────────────────
# Launch the app only when this file is run as a script, not when imported.
# NOTE(review): ssr_mode=False presumably turns off Gradio's server-side
# rendering mode. Confirm against the installed Gradio version's launch() docs.
if __name__ == "__main__":
demo.launch(ssr_mode=False)