import gradio as gr import pandas as pd import numpy as np import copy import traceback from inference import ( FEATURE_NAMES, REPORTING_OUTCOMES, OUTCOME_DESCRIPTIONS, OUTCOMES, SHAP_OUTCOMES, predict_with_comparison, predict_all_outcomes, create_all_shap_plots, create_all_icon_arrays, ) # ───────────────────────────────────────────────────────────────────────────── # CHOICES / CONSTANTS # ───────────────────────────────────────────────────────────────────────────── AGEGPFF_CHOICES = ["<=10", "11-17", "18-29", "30-49", ">=50"] SEX_CHOICES = ["Male", "Female"] KPS_CHOICES = ["<90", "≥ 90"] DONORF_CHOICES = [ "HLA identical sibling", "HLA mismatch relative", "Matched unrelated donor", "Mismatched unrelated donor or cord blood", ] GRAFTYPE_CHOICES = ["Bone marrow", "Peripheral blood", "Cord blood"] CONDGRPF_CHOICES = ["MAC", "RIC", "NMA"] CONDGRP_FINAL_CHOICES = [ "TBI/Cy", "TBI/Cy/Flu", "TBI/Cy/Flu/TT", "TBI/Mel", "TBI/Flu", "TBI alone (300/400/600cGy)", "Bu/Cy", "Bu/Mel", "Flu/Bu/TT", "Flu/Bu", "Flu/Mel/TT", "Flu/Mel", "Cy/Flu", "Treosulfan", "Cy alone", "Flud", "TLI", ] ATGF_CHOICES = ["ATG", "Alemtuzumab", "None"] GVHD_FINAL_CHOICES = [ "Ex-vivo T-cell depletion", "CD34 selection", "Post-CY + siro +- MMF", "Post-CY + MMF + CNI", "CNI + MMF", "CNI + MTX", "CNI alone", "CNI + siro", "Siro alone", "MMF + MTX", "MMF + siro", "MMF alone", "MTX alone", "MTX + siro", ] HLA_FINAL_CHOICES = ["8/8", "7/8", "≤ 6/8"] RCMVPR_CHOICES = ["Negative", "Positive"] EXCHTFPR_CHOICES = ["No", "Yes"] VOC2YPR_CHOICES = ["No", "Yes"] VOCFRQPR_CHOICES = ["< 3/yr", "≥ 3/yr"] SCATXRSN_CHOICES = [ "CNS event", "Acute chest Syndrome", "Recurrent vaso-occlusive pain", "Recurrent priapism", "Excessive transfusion requirements/iron overload", "Cardio-pulmonary", "Chronic transfusion", "Asymptomatic", "Renal insufficiency", "Splenic sequestration", "Avascular necrosis", "Hodgkin lymphoma", ] NUM_COLS_SET = {"AGE", "NACS2YR"} MAX_SCENARIOS = 5 GROUPED_REGIMEN_CHOICES = [ ("── HLA IDENTICAL ──", 
"__header_hla_identical__"), ("Hsieh et al 2014", "Hsieh et al 2014"), ("Krishnamurti et al 2019", "Krishnamurti et al 2019"), ("King et al 2015", "King et al 2015"), ("Walters et al 1996", "Walters et al 1996"), ("── HLA MISMATCHED ──", "__header_hla_mismatched__"), ("Bolanos-Meade et al 2022 (HLA Mismatch)", "Bolanos-Meade et al 2022 (HLA Mismatch)"), ("Patel et al 2020 (HLA Mismatch)", "Patel et al 2020 (HLA Mismatch)"), ("── MATCHED UNRELATED ──", "__header_matched_unrelated__"), ("L Krishnamurti et al 2019", "L Krishnamurti et al 2019"), ("Shenoy et al 2016", "Shenoy et al 2016"), ("── MISMATCHED UNRELATED / CORD BLOOD ──", "__header_mismatched_cord__"), ("Bolanos-Meade et al 2022 (Mismatched/Cord)", "Bolanos-Meade et al 2022 (Mismatched/Cord)"), ("Patel et al 2020 (Mismatched/Cord)", "Patel et al 2020 (Mismatched/Cord)"), ("── CUSTOM ──", "__header_custom__"), ("Custom", "Custom"), ] HEADER_VALUES = {v for _, v in GROUPED_REGIMEN_CHOICES if v.startswith("__header_")} PUBLISHED_PRESETS = { "Hsieh et al 2014": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI alone (300/400/600cGy)", "ATGF": "Alemtuzumab", "GVHD_FINAL": "Siro alone", "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"}, "Krishnamurti et al 2019": {"CONDGRPF": "MAC", "CONDGRP_FINAL": "Flu/Bu", "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"}, "King et al 2015": {"CONDGRPF": "RIC", "CONDGRP_FINAL": "Flu/Mel", "ATGF": "Alemtuzumab", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"}, "Walters et al 1996": {"CONDGRPF": "MAC", "CONDGRP_FINAL": "Bu/Cy", "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"}, "Bolanos-Meade et al 2022 (HLA Mismatch)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "HLA mismatch relative"}, "Patel et al 2020 (HLA Mismatch)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": 
"TBI/Cy/Flu/TT", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "HLA mismatch relative"}, "L Krishnamurti et al 2019": {"CONDGRPF": "MAC", "CONDGRP_FINAL": "Flu/Bu", "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "Matched unrelated donor"}, "Shenoy et al 2016": {"CONDGRPF": "RIC", "CONDGRP_FINAL": "Flu/Mel", "ATGF": "Alemtuzumab", "GVHD_FINAL": "CNI + MTX", "HLA_FINAL": "8/8", "DONORF": "Matched unrelated donor"}, "Bolanos-Meade et al 2022 (Mismatched/Cord)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "Mismatched unrelated donor or cord blood"}, "Patel et al 2020 (Mismatched/Cord)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu/TT", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "Mismatched unrelated donor or cord blood"}, } MSD_DONOR = "HLA identical sibling" MUD_DONOR = "Matched unrelated donor" LOCKED_88_DONORS = {MSD_DONOR, MUD_DONOR} NON_88_DONORS = {"HLA mismatch relative", "Mismatched unrelated donor or cord blood"} MSD_REGIMENS = {"Hsieh et al 2014", "Krishnamurti et al 2019", "King et al 2015", "Walters et al 1996"} MMUD_REGIMENS = {"Bolanos-Meade et al 2022 (Mismatched/Cord)", "Patel et al 2020 (Mismatched/Cord)", "Bolanos-Meade et al 2022 (HLA Mismatch)", "Patel et al 2020 (HLA Mismatch)"} PATIENT_FEATURES = ["AGE", "AGEGPFF", "SEX", "KPS", "RCMVPR"] DONOR_FEATURES = ["DONORF", "GRAFTYPE", "HLA_FINAL", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL"] DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"] ALL_FEATURES = PATIENT_FEATURES + DONOR_FEATURES + DISEASE_FEATURES # Icon array outcomes in display order ICON_OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"] OUTCOME_TITLES = { "DEAD": "Death", "GF": "Graft Failure", "AGVHD": "Acute GvHD", "CGVHD": "Chronic GvHD", "VOCPSHI": "Vaso-Occlusive Crisis", "STROKEHI": "Stroke 
Post-HCT", } EVENT_COLOR = "#e53935" NO_EVENT_COLOR = "#43a047" SHAP_ORDER = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "EFS", "STROKEHI", "OS"] SHAP_LABELS = { "DEAD": "Death", "GF": "Graft Failure", "AGVHD": "Acute GvHD", "CGVHD": "Chronic GvHD", "VOCPSHI": "Vaso-Occlusive Crisis Post-HCT", "EFS": "Event-Free Survival", "STROKEHI": "Stroke Post-HCT", "OS": "Overall Survival", } SCENARIO_COLORS = ["#e65100", "#6a1b9a", "#1b5e20", "#0d47a1", "#b71c1c"] # ───────────────────────────────────────────────────────────────────────────── # COMPONENT FACTORY # ───────────────────────────────────────────────────────────────────────────── _FEATURE_META = { "AGE": {"label": "Age (years)", "type": "number", "kwargs": {"minimum": 0, "maximum": 100, "step": 1}}, "AGEGPFF": {"label": "Age Group", "type": "dropdown", "choices": AGEGPFF_CHOICES, "kwargs": {"interactive": False, "info": "Auto-filled from Age"}}, "SEX": {"label": "Sex", "type": "radio", "choices": SEX_CHOICES}, "KPS": {"label": "Karnofsky Performance Status", "type": "radio", "choices": KPS_CHOICES}, "RCMVPR": {"label": "Recipient CMV Status", "type": "radio", "choices": RCMVPR_CHOICES}, "DONORF": {"label": "Donor Type", "type": "dropdown", "choices": DONORF_CHOICES}, "GRAFTYPE": {"label": "Graft Type", "type": "radio", "choices": GRAFTYPE_CHOICES}, "HLA_FINAL": {"label": "HLA Matching", "type": "radio", "choices": HLA_FINAL_CHOICES}, "CONDGRPF": {"label": "Conditioning Intensity", "type": "radio", "choices": CONDGRPF_CHOICES}, "CONDGRP_FINAL": {"label": "Conditioning Regimen", "type": "dropdown", "choices": CONDGRP_FINAL_CHOICES}, "ATGF": {"label": "Serotherapy", "type": "radio", "choices": ATGF_CHOICES}, "GVHD_FINAL": {"label": "GvHD Prophylaxis", "type": "dropdown", "choices": GVHD_FINAL_CHOICES}, "NACS2YR": {"label": "Acute Chest Syndrome Episodes (2 yrs pre-HCT)", "type": "number", "kwargs": {"minimum": 0, "maximum": 50, "step": 1}}, "EXCHTFPR": {"label": "Exchange Transfusion Pre-HCT", "type": "radio", 
"choices": EXCHTFPR_CHOICES}, "VOC2YPR": {"label": "Vaso-Occlusive Crisis in 2 yrs Pre-HCT", "type": "radio", "choices": VOC2YPR_CHOICES}, "VOCFRQPR": {"label": "VoC Frequency Pre-HCT", "type": "radio", "choices": VOCFRQPR_CHOICES}, "SCATXRSN": {"label": "Primary Reason for HCT", "type": "dropdown", "choices": SCATXRSN_CHOICES}, } def make_component(feature: str, suffix: str = ""): meta = _FEATURE_META[feature] label = f"{meta['label']} {suffix}".strip() if suffix else meta["label"] kind = meta["type"] kwargs = dict(meta.get("kwargs", {})) if kind == "number": return gr.Number(label=label, value=None, **kwargs) elif kind == "dropdown": interactive = kwargs.pop("interactive", True) info = kwargs.pop("info", None) return gr.Dropdown(label=label, choices=meta["choices"], value=None, interactive=interactive, info=info, **kwargs) elif kind == "radio": interactive = kwargs.pop("interactive", True) info = kwargs.pop("info", None) return gr.Radio(label=label, choices=meta["choices"], value=None, interactive=interactive, info=info, **kwargs) else: raise ValueError(f"Unknown component type '{kind}' for feature '{feature}'") # ───────────────────────────────────────────────────────────────────────────── # CONSTRAINT / VALIDATION HELPERS # ───────────────────────────────────────────────────────────────────────────── def _hla_update_for_donor(donor_value): if not donor_value: return gr.update(choices=HLA_FINAL_CHOICES, interactive=True) if donor_value in LOCKED_88_DONORS: return gr.update(choices=["8/8"], value="8/8", interactive=False) elif donor_value in NON_88_DONORS: return gr.update(choices=["7/8", "≤ 6/8"], value=None, interactive=True) return gr.update(choices=HLA_FINAL_CHOICES, interactive=True) def _validate_counterfactual_constraints(base_dict, wi_dict, label=""): violations = [] tag = f"[{label}] " if label else "" if base_dict.get("SEX") and wi_dict.get("SEX"): if base_dict["SEX"] != wi_dict["SEX"]: violations.append(f"{tag} Immutable feature: Sex cannot be 
changed.") base_age, wi_age = base_dict.get("AGE"), wi_dict.get("AGE") if base_age is not None and wi_age is not None: try: if float(wi_age) < float(base_age): violations.append(f"{tag}Age cannot be decreased ({base_age} → {wi_age}).") except (TypeError, ValueError): pass wi_donor = wi_dict.get("DONORF") wi_hla = wi_dict.get("HLA_FINAL") if wi_donor and wi_hla: if wi_donor in LOCKED_88_DONORS and wi_hla != "8/8": violations.append(f"{tag}HLA constraint: '{wi_donor}' requires 8/8 HLA.") elif wi_donor in NON_88_DONORS and wi_hla == "8/8": violations.append(f"{tag}HLA constraint: '{wi_donor}' cannot have 8/8 HLA.") if wi_donor: wi_gvhd = wi_dict.get("GVHD_FINAL", "") if wi_donor == MSD_DONOR and wi_gvhd in {"Post-CY + siro +- MMF", "Post-CY + MMF + CNI"}: violations.append( f"{tag} Post-Cy GVHD prophylaxis is inconsistent with donor '{MSD_DONOR}'." ) return violations def _values_to_dict(values): d = {} for f, v in zip(ALL_FEATURES, values): if f in NUM_COLS_SET: try: d[f] = float(v) if v not in (None, "") else None except (TypeError, ValueError): d[f] = None else: d[f] = v return d def _check_missing(user_vals, label=""): missing = [f for f, v in user_vals.items() if v is None or v == "" or (isinstance(v, float) and pd.isna(v))] if missing: raise ValueError( f"{'[' + label + '] ' if label else ''}Please fill in all fields. 
" f"Missing: {', '.join(missing)}" ) def get_age_group(age): if age is None or age == "": return "" try: age = float(age) if age <= 10: return "<=10" elif age <= 17: return "11-17" elif age <= 29: return "18-29" elif age <= 49: return "30-49" else: return ">=50" except (ValueError, TypeError): return "" def vocfrqpr_from_voc2ypr(voc_status): if voc_status == "No": return gr.update(value="< 3/yr", interactive=False) return gr.update(value=None, interactive=True) def apply_grouped_preset(selected_value): if not selected_value or selected_value in HEADER_VALUES: return (gr.update(value=None),) + (gr.update(),) * 6 if selected_value == "Custom": return (gr.update(),) + tuple(gr.update(interactive=True) for _ in range(6)) preset = PUBLISHED_PRESETS.get(selected_value) if not preset: return (gr.update(),) * 7 donor = preset["DONORF"] hla_update = gr.update( value=preset["HLA_FINAL"], interactive=False, choices=(["8/8"] if donor in LOCKED_88_DONORS else (["7/8", "≤ 6/8"] if donor in NON_88_DONORS else HLA_FINAL_CHOICES)), ) return ( gr.update(), gr.update(value=preset["DONORF"], interactive=False), gr.update(value=preset["CONDGRPF"], interactive=False), gr.update(value=preset["CONDGRP_FINAL"], interactive=False), gr.update(value=preset["ATGF"], interactive=False), gr.update(value=preset["GVHD_FINAL"], interactive=False), hla_update, ) def lock_sex(baseline_sex): if baseline_sex: return gr.update(value=baseline_sex, interactive=False) return gr.update(interactive=False) # ───────────────────────────────────────────────────────────────────────────── # ICON ARRAY HTML RENDERERS (pure Python, no JS) # ───────────────────────────────────────────────────────────────────────────── def _stick_figure_svg(color, size=16): h = round(size * 1.6) return ( f'' f'' f'' f'' f'' f'' f'' ) def _render_single_icon_card(probability, outcome, panel_label="", panel_color="#1565c0"): """Render a single icon array card as HTML (no JS). 
Used for Python-driven carousel.""" title = OUTCOME_TITLES.get(outcome, OUTCOME_DESCRIPTIONS.get(outcome, outcome)) n_event = round(probability * 100) n_no_event = 100 - n_event pct_str = f"{probability * 100:.1f}%" rows_parts = [] for row in range(10): cells = "" for col in range(10): idx = row * 10 + col color = EVENT_COLOR if idx < n_event else NO_EVENT_COLOR cells += _stick_figure_svg(color, size=13) rows_parts.append( f'
{cells}
' ) grid_html = "\n".join(rows_parts) fig_e = _stick_figure_svg(EVENT_COLOR, size=11) fig_ne = _stick_figure_svg(NO_EVENT_COLOR, size=11) legend = ( f'
' f'{fig_e}Event' f'({n_event}/100)' f'{fig_ne}No Event' f'({n_no_event}/100)' f'
' ) badge = ( f'
' f'{panel_label}
' ) if panel_label else "" return ( f'
' f'{badge}' f'
{title}
' f'
{pct_str}
' f'
{grid_html}
' f'
{legend}
' f'
' ) def _render_icon_carousel_page(probs_dict, idx): """Render the full carousel HTML for a given outcome index (no JS).""" outcome = ICON_OUTCOMES[idx] total = len(ICON_OUTCOMES) card_html = _render_single_icon_card(probs_dict.get(outcome, 0.0), outcome) label = OUTCOME_TITLES.get(outcome, outcome) # Breadcrumb dots dots = "" for i, o in enumerate(ICON_OUTCOMES): active = i == idx dots += ( f'' ) footnote = ( f'
' f'Each figure = 1 patient out of 100. ' f'■ Red = Event   ' f'■ Green = No Event' f'
' ) return ( f'
' f'
{label}
' f'
{idx+1} / {total}
' f'
{dots}
' f'
{card_html}
' f'{footnote}' f'
' ) def _render_comparison_icon_page(all_probs_list, labels, colors, idx): """Render comparison icon carousel page for a given outcome index (no JS).""" outcome = ICON_OUTCOMES[idx] total = len(ICON_OUTCOMES) out_label = OUTCOME_TITLES.get(outcome, outcome) # Dots breadcrumb dots = "" for i in range(total): active = i == idx dots += ( f'' ) rows_html = "" for probs, label, color in zip(all_probs_list, labels, colors): prob = probs.get(outcome, 0.0) card = _render_single_icon_card(prob, outcome, label, color) rows_html += ( f'
' f'
{label}
' f'
{card}
' f'
' ) footnote = ( f'
' f'Each figure = 1 patient out of 100. ' f'■ Red = Event   ' f'■ Green = No Event' f'
' ) return ( f'
' f'
{out_label}
' f'
{idx+1} / {total}
' f'
{dots}
' f'{rows_html}' f'{footnote}' f'
' ) # ───────────────────────────────────────────────────────────────────────────── # OTHER HTML RENDERERS # ───────────────────────────────────────────────────────────────────────────── def _delta_color_html(delta, is_survival): if abs(delta) < 0.0005: color = "#888888" elif (delta > 0 and is_survival) or (delta < 0 and not is_survival): color = "#2e7d32" else: color = "#c62828" sign = "+" if delta >= 0 else "" return f'{sign}{delta*100:.1f}%' def _build_comparison_table_html(base_probs, base_ci, scenario_probs_list, scenario_ci_list, scenario_labels, scenario_colors): survival_outcomes = {"OS", "EFS"} outcome_headers = "".join( f"{OUTCOME_DESCRIPTIONS.get(o, o)}" for o in REPORTING_OUTCOMES if o in base_probs ) header = ( "
" "" "" "" f"{outcome_headers}" "" ) rows = "" baseline_cells = "" for o in REPORTING_OUTCOMES: if o not in base_probs: continue bp = base_probs[o] blo, bhi = base_ci.get(o, (float("nan"), float("nan"))) baseline_cells += ( f"" ) rows += ( f"" f"" f"{baseline_cells}" ) for j, (sp_dict, sci_dict, s_label) in enumerate(zip(scenario_probs_list, scenario_ci_list, scenario_labels)): sc_color = scenario_colors[j % len(scenario_colors)] bg = "#fafbfc" if j % 2 == 0 else "#ffffff" scenario_cells = "" for o in REPORTING_OUTCOMES: if o not in base_probs or o not in sp_dict: continue bp = base_probs[o] wp = sp_dict[o] wlo, whi = sci_dict.get(o, (float("nan"), float("nan"))) delta = wp - bp is_surv = o in survival_outcomes scenario_cells += ( f"" ) rows += ( f"" f"" f"{scenario_cells}" ) footer = ( "
Scenario
" f"
{bp*100:.1f}%
" f"
[{blo*100:.1f}%–{bhi*100:.1f}%]
" f"
" f"
" f"Baseline
" f"
{wp*100:.1f}%
" f"
[{wlo*100:.1f}%–{whi*100:.1f}%]
" f"
{_delta_color_html(delta, is_surv)}
" f"
" f"
" f"" f"{s_label}
" "
" "Δ from Baseline: Green = improvement  " "Red = worsening  |  " "OS & EFS: higher is better; all other outcomes: lower is better." "
" ) return header + rows + footer def _build_violation_html(violations): if not violations: return "" items = "".join(f"
  • {v}
  • " for v in violations) return ( f'
    ' f'
    ' f'Constraint Violations — Analysis blocked
    ' f'' f'
    ' f'Please correct the above before running the comparison.
    ' f'
    ' ) # ───────────────────────────────────────────────────────────────────────────── # SHAP CAROUSEL HELPERS # ───────────────────────────────────────────────────────────────────────────── def _shap_counter_html(idx): labels = [SHAP_LABELS.get(o, o) for o in SHAP_ORDER] items = " · ".join( f'{l}' for i, l in enumerate(labels) ) return f'
    {items}
    ' # ───────────────────────────────────────────────────────────────────────────── # MAIN PREDICT CALLBACK (Tab 1) # ───────────────────────────────────────────────────────────────────────────── def predict_gradio(*values): try: user_vals = _values_to_dict(values) missing = [f for f, v in user_vals.items() if v is None or v == "" or (isinstance(v, float) and pd.isna(v))] if missing: raise ValueError(f"Please fill in all fields. Missing: {', '.join(missing)}") calibrated, _ = predict_with_comparison(user_vals) calibrated_probs, calibrated_intervals = calibrated rows = [] for outcome in REPORTING_OUTCOMES: desc = OUTCOME_DESCRIPTIONS[outcome] calib_prob = calibrated_probs[outcome] ci_low, ci_high = calibrated_intervals[outcome] rows.append({ "Outcome": desc, "Probability": f"{calib_prob * 100:.1f}%", "95% CI": f"[{ci_low * 100:.1f}% – {ci_high * 100:.1f}%]", }) df = pd.DataFrame(rows) shap_plots = create_all_shap_plots(user_vals, max_display=10) # Icon carousel: start at index 0, store probs in State icon_html = _render_icon_carousel_page(calibrated_probs, 0) first_shap = shap_plots[SHAP_ORDER[0]] shap_crumb = _shap_counter_html(0) return ( df, calibrated_probs, # stored in State for carousel navigation 0, # icon carousel index State icon_html, # displayed icon card shap_plots, # stored in State 0, # shap index State first_shap, # displayed plot shap_crumb, # breadcrumb HTML ) except Exception as e: print(traceback.format_exc()) raise gr.Error(f"{type(e).__name__}: {str(e)}") # ───────────────────────────────────────────────────────────────────────────── # CSS # ───────────────────────────────────────────────────────────────────────────── custom_css = """ .predict-button { background: linear-gradient(to right, #ff6b35, #ff8c42) !important; border: none !important; color: white !important; font-weight: bold !important; font-size: 16px !important; padding: 12px !important; } .predict-button:hover { background: linear-gradient(to right, #ff5722, #ff7b29) !important; 
} .copy-button { background: linear-gradient(to right, #388e3c, #66bb6a) !important; border: none !important; color: white !important; font-weight: 600 !important; } .copy-button:hover { background: linear-gradient(to right, #2e7d32, #43a047) !important; } .copy-from-predict-button { background: linear-gradient(to right, #6a1b9a, #ab47bc) !important; border: none !important; color: white !important; font-weight: 600 !important; } .copy-from-predict-button:hover { background: linear-gradient(to right, #4a148c, #8e24aa) !important; } .counterfactual-button { background: linear-gradient(to right, #1976d2, #42a5f5) !important; border: none !important; color: white !important; font-weight: bold !important; font-size: 15px !important; padding: 12px !important; } .counterfactual-button:hover { background: linear-gradient(to right, #1565c0, #1e88e5) !important; } .shap-nav-button { background: linear-gradient(to right, #37474f, #546e7a) !important; border: none !important; color: white !important; font-weight: bold !important; font-size: 18px !important; padding: 8px 22px !important; border-radius: 6px !important; } .shap-nav-button:hover { background: linear-gradient(to right, #263238, #455a64) !important; } .icon-nav-button { background: linear-gradient(to right, #1565c0, #1e88e5) !important; border: none !important; color: white !important; font-weight: bold !important; font-size: 18px !important; padding: 8px 22px !important; border-radius: 6px !important; } .icon-nav-button:hover { background: linear-gradient(to right, #0d47a1, #1565c0) !important; } .output-dataframe table td:first-child, .output-dataframe table th:first-child { white-space: normal !important; word-break: break-word !important; min-width: 240px !important; } .constraint-info { background: #e8f5e9; border-left: 4px solid #388e3c; padding: 8px 14px; font-size: 12px; color: #1b5e20; border-radius: 4px; margin-bottom: 8px; } .scenario-panel-0 { border-left: 4px solid #e65100 !important; } 
.scenario-panel-1 { border-left: 4px solid #6a1b9a !important; } .scenario-panel-2 { border-left: 4px solid #1b5e20 !important; } .scenario-panel-3 { border-left: 4px solid #0d47a1 !important; } .scenario-panel-4 { border-left: 4px solid #b71c1c !important; } """ # ───────────────────────────────────────────────────────────────────────────── # BUILD UI # ───────────────────────────────────────────────────────────────────────────── with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo: gr.Markdown("# HCT Outcome Prediction Model") n_scenarios_state = gr.State(1) with gr.Tabs(): # ══════════════════════════════════════════════════════════════════════ # TAB 1 — PREDICT OUTCOMES # ══════════════════════════════════════════════════════════════════════ with gr.Tab("Predict Outcomes"): gr.Markdown("Enter patient, transplant, and disease characteristics to predict outcomes.") inputs_dict = {} with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Patient Characteristics") for f in PATIENT_FEATURES: inputs_dict[f] = make_component(f) with gr.Column(scale=1): gr.Markdown("### Transplant Characteristics") grouped_dd = gr.Dropdown( choices=GROUPED_REGIMEN_CHOICES, value=None, label="Published conditioning regimen", info="Auto-fills Donor Type, Conditioning Intensity, Regimen, Serotherapy, GVHD Prophylaxis", ) p_donorf = inputs_dict["DONORF"] = make_component("DONORF") inputs_dict["GRAFTYPE"] = make_component("GRAFTYPE") p_condgrpf = inputs_dict["CONDGRPF"] = make_component("CONDGRPF") p_condgrp_final = inputs_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL") p_atgf = inputs_dict["ATGF"] = make_component("ATGF") p_gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL") p_hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL") with gr.Column(scale=1): gr.Markdown("### Disease Characteristics") for f in DISEASE_FEATURES: inputs_dict[f] = make_component(f) inputs_dict["AGE"].change(get_age_group, inputs_dict["AGE"], 
inputs_dict["AGEGPFF"]) inputs_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, inputs_dict["VOC2YPR"], inputs_dict["VOCFRQPR"]) p_donorf.change(fn=_hla_update_for_donor, inputs=p_donorf, outputs=p_hla_final) grouped_dd.change( apply_grouped_preset, grouped_dd, [grouped_dd, p_donorf, p_condgrpf, p_condgrp_final, p_atgf, p_gvhd_final, p_hla_final], ) inputs_list = [inputs_dict[f] for f in ALL_FEATURES] predict_btn = gr.Button("Predict", elem_classes="predict-button", size="lg") gr.Markdown("---") gr.Markdown("## Prediction Results") output_table = gr.Dataframe( headers=["Outcome", "Probability", "95% CI"], label="Predicted Outcomes", elem_classes="output-dataframe", row_count=(len(REPORTING_OUTCOMES), "dynamic"), column_count=(3, "fixed"), wrap=True, ) # ── Icon Array Carousel (Tab 1) — Python-driven ─────────────────── gr.Markdown("---") gr.Markdown("## Outcome Probability — Icon Arrays") gr.Markdown("*Use the ← → arrows to browse outcomes one at a time.*") # State: store probs dict + current index icon_probs_state = gr.State(None) icon_idx_state = gr.State(0) with gr.Row(): icon_prev_btn = gr.Button("←", elem_classes="icon-nav-button", size="sm", scale=0) icon_display = gr.HTML(scale=4) icon_next_btn = gr.Button("→", elem_classes="icon-nav-button", size="sm", scale=0) def _on_icon_prev(idx, probs): if probs is None: return idx, "" new_idx = max(0, idx - 1) return new_idx, _render_icon_carousel_page(probs, new_idx) def _on_icon_next(idx, probs): if probs is None: return idx, "" new_idx = min(len(ICON_OUTCOMES) - 1, idx + 1) return new_idx, _render_icon_carousel_page(probs, new_idx) icon_prev_btn.click( fn=_on_icon_prev, inputs=[icon_idx_state, icon_probs_state], outputs=[icon_idx_state, icon_display], ) icon_next_btn.click( fn=_on_icon_next, inputs=[icon_idx_state, icon_probs_state], outputs=[icon_idx_state, icon_display], ) # ── SHAP Carousel (Tab 1) ───────────────────────────────────────── gr.Markdown("---") gr.Markdown("## SHAP — Feature Importance") 
gr.Markdown("*Use the ← → buttons to browse SHAP plots one at a time.*") shap_plots_state = gr.State(None) shap_idx_state = gr.State(0) with gr.Row(): shap_prev_btn = gr.Button("←", elem_classes="shap-nav-button", size="sm", scale=0) shap_crumb = gr.HTML(value="", scale=4) shap_next_btn = gr.Button("→", elem_classes="shap-nav-button", size="sm", scale=0) shap_display = gr.Plot(label="SHAP Feature Importance") def _on_shap_prev(idx, plots): new_idx = max(0, idx - 1) plot = plots[SHAP_ORDER[new_idx]] if plots else None return new_idx, plot, _shap_counter_html(new_idx) def _on_shap_next(idx, plots): new_idx = min(len(SHAP_ORDER) - 1, idx + 1) plot = plots[SHAP_ORDER[new_idx]] if plots else None return new_idx, plot, _shap_counter_html(new_idx) shap_prev_btn.click( fn=_on_shap_prev, inputs=[shap_idx_state, shap_plots_state], outputs=[shap_idx_state, shap_display, shap_crumb], ) shap_next_btn.click( fn=_on_shap_next, inputs=[shap_idx_state, shap_plots_state], outputs=[shap_idx_state, shap_display, shap_crumb], ) predict_btn.click( fn=predict_gradio, inputs=inputs_list, outputs=[ output_table, icon_probs_state, # ← store probs dict icon_idx_state, # ← reset to 0 icon_display, # ← show first card shap_plots_state, shap_idx_state, shap_display, shap_crumb, ], ) # ══════════════════════════════════════════════════════════════════════ # TAB 2 — COUNTERFACTUAL ANALYSIS # ══════════════════════════════════════════════════════════════════════ with gr.Tab("Counterfactual Scenarios"): gr.Markdown( "## Counterfactual Scenario Analysis\n" "Enter the **baseline** patient, then choose how many counterfactual " "scenarios you want to compare. Each scenario panel will appear below." ) with gr.Row(): n_scenarios_slider = gr.Slider( minimum=1, maximum=MAX_SCENARIOS, step=1, value=1, label=f"How many counterfactual scenarios do you want to compare? 
(1–{MAX_SCENARIOS})", info="Adjust this first — scenario panels will appear/disappear below.", ) gr.Markdown("---") # ── BASELINE ───────────────────────────────────────────────────── gr.Markdown("## Baseline Patient Profile") with gr.Row(): copy_from_predict_btn = gr.Button( "Copy from Predict tab → Baseline", elem_classes="copy-from-predict-button", size="sm", ) wi_baseline_dict = {} with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Patient Characteristics") for f in PATIENT_FEATURES: wi_baseline_dict[f] = make_component(f) with gr.Column(scale=1): gr.Markdown("### Transplant Characteristics") wi_grouped_base = gr.Dropdown( choices=GROUPED_REGIMEN_CHOICES, value=None, label="Published conditioning regimen (Baseline)", ) wb_donorf = wi_baseline_dict["DONORF"] = make_component("DONORF") wi_baseline_dict["GRAFTYPE"] = make_component("GRAFTYPE") wb_condgrpf = wi_baseline_dict["CONDGRPF"] = make_component("CONDGRPF") wb_condgrp_final = wi_baseline_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL") wb_atgf = wi_baseline_dict["ATGF"] = make_component("ATGF") wb_gvhd_final = wi_baseline_dict["GVHD_FINAL"] = make_component("GVHD_FINAL") wb_hla_final = wi_baseline_dict["HLA_FINAL"] = make_component("HLA_FINAL") with gr.Column(scale=1): gr.Markdown("### Disease Characteristics") for f in DISEASE_FEATURES: wi_baseline_dict[f] = make_component(f) wi_baseline_dict["AGE"].change(get_age_group, wi_baseline_dict["AGE"], wi_baseline_dict["AGEGPFF"]) wi_baseline_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, wi_baseline_dict["VOC2YPR"], wi_baseline_dict["VOCFRQPR"]) wb_donorf.change(fn=_hla_update_for_donor, inputs=wb_donorf, outputs=wb_hla_final) wi_grouped_base.change( apply_grouped_preset, wi_grouped_base, [wi_grouped_base, wb_donorf, wb_condgrpf, wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final], ) wi_baseline_list = [wi_baseline_dict[f] for f in ALL_FEATURES] def _copy_predict_to_baseline(*vals): n = len(ALL_FEATURES) feature_vals = vals[:n] regimen_value = 
vals[n] if len(vals) > n else None preset_outputs = apply_grouped_preset(regimen_value) regimen_dd_upd = gr.update(value=regimen_value) donorf_upd, condgrpf_upd, condgrp_final_upd, atgf_upd, gvhd_final_upd, hla_final_upd = preset_outputs[1:] return (*list(feature_vals), regimen_dd_upd, donorf_upd, condgrpf_upd, condgrp_final_upd, atgf_upd, gvhd_final_upd, hla_final_upd) copy_from_predict_btn.click( fn=_copy_predict_to_baseline, inputs=inputs_list + [grouped_dd], outputs=wi_baseline_list + [wi_grouped_base, wb_donorf, wb_condgrpf, wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final], ) gr.Markdown("---") # ── SCENARIO PANELS ─────────────────────────────────────────────── scenario_dicts = [] scenario_lists = [] scenario_grouped_dds = [] scenario_rows = [] scenario_name_inputs = [] scenario_donorf_handles = [] scenario_hla_handles = [] scenario_voc2ypr_handles = [] scenario_age_handles = [] scenario_agegp_handles = [] scenario_copy_btns = [] for s_idx in range(MAX_SCENARIOS): color = SCENARIO_COLORS[s_idx] label = f"Scenario {s_idx + 1}" suffix = f"({label})" visible_init = (s_idx == 0) with gr.Row(visible=visible_init) as s_row: scenario_rows.append(s_row) with gr.Column(): gr.HTML( f'
    ' f'Counterfactual {label}
    ' ) with gr.Row(): s_name_input = gr.Textbox( label=f"Scenario name ({label})", value=label, placeholder=f"e.g. {label}", scale=2, ) scenario_name_inputs.append(s_name_input) copy_btn = gr.Button( f"⬇ Copy Baseline → {label}", elem_classes="copy-button", size="sm", scale=1, ) scenario_copy_btns.append(copy_btn) s_dict = {} with gr.Row(): with gr.Column(scale=1): gr.Markdown(f"#### Patient — {label}") for f in PATIENT_FEATURES: s_dict[f] = make_component(f, suffix) with gr.Column(scale=1): gr.Markdown(f"#### Transplant — {label}") s_grouped_dd = gr.Dropdown( choices=GROUPED_REGIMEN_CHOICES, value=None, label=f"Published regimen ({label})", ) scenario_grouped_dds.append(s_grouped_dd) s_donorf = s_dict["DONORF"] = make_component("DONORF", suffix) s_dict["GRAFTYPE"] = make_component("GRAFTYPE", suffix) s_condgrpf = s_dict["CONDGRPF"] = make_component("CONDGRPF", suffix) s_condgrp_final = s_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL", suffix) s_atgf = s_dict["ATGF"] = make_component("ATGF", suffix) s_gvhd_final = s_dict["GVHD_FINAL"] = make_component("GVHD_FINAL", suffix) s_hla_final = s_dict["HLA_FINAL"] = make_component("HLA_FINAL", suffix) with gr.Column(scale=1): gr.Markdown(f"#### Disease — {label}") for f in DISEASE_FEATURES: s_dict[f] = make_component(f, suffix) scenario_dicts.append(s_dict) scenario_lists.append([s_dict[f] for f in ALL_FEATURES]) scenario_donorf_handles.append(s_donorf) scenario_hla_handles.append(s_hla_final) scenario_voc2ypr_handles.append(s_dict["VOC2YPR"]) scenario_age_handles.append(s_dict["AGE"]) scenario_agegp_handles.append(s_dict["AGEGPFF"]) s_dict["AGE"].change(get_age_group, s_dict["AGE"], s_dict["AGEGPFF"]) s_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, s_dict["VOC2YPR"], s_dict["VOCFRQPR"]) s_donorf.change(fn=_hla_update_for_donor, inputs=s_donorf, outputs=s_hla_final) s_grouped_dd.change( apply_grouped_preset, s_grouped_dd, [s_grouped_dd, s_donorf, s_condgrpf, s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final], ) 
# NOTE(review): This part of the file appears to be whitespace-collapsed:
# many statements share one physical line and their newlines/indentation are
# gone. Because a Python comment runs to the end of its physical line, each
# inline comment in these lines (e.g. "# ── Slider → show/hide scenario
# panels", "# list of dicts", "# Per-scenario SHAP carousels", "# Sync SHAP
# scenario row visibility with slider", "# Icon carousel: ...",
# "# violation html") now comments out the rest of its line, so the module
# will not parse as-is. The gr.HTML(...) f-string fragments around
# "Counterfactual {label}" and "Scenario {s_idx+1}" also look truncated,
# probably stripped of their HTML markup (e.g. `scen_color` is assigned right
# before the scenario gr.HTML call but is not referenced). Restore the
# original layout from version control instead of hand-editing these lines.
#
# What the code around this note wires up, in order:
# * Per-scenario panel (preceding lines): a scenario-name textbox, a
#   "Copy Baseline" button, and patient / transplant / disease columns built
#   with make_component(f, suffix). AGE drives AGEGPFF (get_age_group),
#   VOC2YPR drives VOCFRQPR (vocfrqpr_from_voc2ypr), DONORF drives HLA_FINAL
#   (_hla_update_for_donor), and the "Published regimen" dropdown writes
#   apply_grouped_preset's result to itself plus DONORF, CONDGRPF,
#   CONDGRP_FINAL, ATGF, GVHD_FINAL and HLA_FINAL (GRAFTYPE is not touched).
# * Per-scenario panel (following lines): the scenario SEX component mirrors
#   the baseline SEX via lock_sex, both on baseline SEX change and on copy
#   click, and s_dict["SEX"].interactive is set to False. copy_btn feeds
#   wi_baseline_list + [wi_grouped_base] positionally into the scenario's
#   ALL_FEATURES-ordered components (assumes wi_baseline_list uses the same
#   order -- not visible here), then re-applies apply_grouped_preset to the
#   baseline regimen to refresh the regimen dropdown and the six transplant
#   components. _make_copy_fn never uses its seven *_ref parameters. If
#   ALL_FEATURES contains those transplant fields (they are stored in
#   s_dict), the same components appear twice in the copy outputs --
#   TODO confirm this is intended.
# * n_scenarios_slider has three separate .change handlers: scenario panel
#   visibility, n_scenarios_state, and SHAP row visibility (registered after
#   the SHAP carousels). _update_scenario_visibility and _update_shap_vis
#   have identical bodies.
# * Icon-array carousel: prev/next clamp the index to
#   [0, len(ICON_OUTCOMES) - 1] and return "" until results are stored.
# * SHAP carousels (baseline + MAX_SCENARIOS slots): the index is clamped to
#   [0, len(SHAP_ORDER) - 1]. _make_prev/_make_next ignore their st/ix
#   parameters and duplicate _wb_prev/_wb_next.
# * run_counterfactual(*all_values) unpacks all_run_inputs: the scenario
#   count, len(ALL_FEATURES) baseline values, MAX_SCENARIOS scenario names,
#   then MAX_SCENARIOS blocks of len(ALL_FEATURES) scenario values. It checks
#   the baseline and the first n scenarios (_check_missing,
#   _validate_counterfactual_constraints) and, on any violation, returns with
#   only the violation HTML filled. Otherwise it calls predict_all_outcomes
#   (use_calibration=True, use_signed_voting=True, n_boot_ci=500) for the
#   baseline and for each active scenario, then builds the comparison table,
#   the icon-carousel state and the SHAP plots; blank scenario names fall
#   back to "Scenario {i+1}". It always returns 7 + 4 + 4 * MAX_SCENARIOS
#   values in all_run_outputs order, using (None, 0, None, "") for inactive
#   SHAP slots. _empty_outputs() is defined but never called; the violation
#   branch rebuilds the same tuple inline.
# * Errors: any exception is printed with its traceback and re-raised as
#   gr.Error("<ExceptionType>: <message>") without an explicit `from e`. If a
#   helper such as _check_missing already raises gr.Error (not visible
#   here), the user would see it re-wrapped as "Error: <message>".
wi_baseline_dict["SEX"].change( fn=lock_sex, inputs=wi_baseline_dict["SEX"], outputs=s_dict["SEX"], ) s_dict["SEX"].interactive = False def _make_copy_fn(s_grouped_dd_ref, s_donorf_ref, s_condgrpf_ref, s_condgrp_final_ref, s_atgf_ref, s_gvhd_final_ref, s_hla_final_ref): def _copy(*vals): n = len(ALL_FEATURES) feature_vals = vals[:n] regimen_value = vals[n] if len(vals) > n else None preset_outputs = apply_grouped_preset(regimen_value) regimen_dd_upd = gr.update(value=regimen_value) d_upd, c_upd, cr_upd, a_upd, g_upd, h_upd = preset_outputs[1:] return (*list(feature_vals), regimen_dd_upd, d_upd, c_upd, cr_upd, a_upd, g_upd, h_upd) return _copy copy_btn.click( fn=_make_copy_fn(s_grouped_dd, s_donorf, s_condgrpf, s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final), inputs=wi_baseline_list + [wi_grouped_base], outputs=(scenario_lists[-1] + [s_grouped_dd, s_donorf, s_condgrpf, s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final]), ) copy_btn.click( fn=lock_sex, inputs=wi_baseline_dict["SEX"], outputs=s_dict["SEX"], ) # ── Slider → show/hide scenario panels ─────────────────────────── def _update_scenario_visibility(n): return [gr.update(visible=(i < int(n))) for i in range(MAX_SCENARIOS)] n_scenarios_slider.change( fn=_update_scenario_visibility, inputs=n_scenarios_slider, outputs=scenario_rows, ) n_scenarios_slider.change( fn=lambda n: int(n), inputs=n_scenarios_slider, outputs=n_scenarios_state, ) # ── RUN ────────────────────────────────────────────────────────── gr.Markdown("---") wi_run_btn = gr.Button( "Run Counterfactual Comparison", elem_classes="counterfactual-button", size="lg", ) # ── RESULTS ────────────────────────────────────────────────────── gr.Markdown("## Comparison Results") wi_violation_html = gr.HTML() gr.Markdown("### Outcome Probability Table") wi_table_html = gr.HTML() gr.Markdown("---") # ── Icon Arrays — comparison, Python-driven carousel ───────────── with gr.Accordion("Outcome Icon Arrays — Baseline vs Scenarios", open=False): gr.Markdown( 
"*One row per scenario. Use the ← → arrows to browse outcomes one at a time.*" ) wi_cmp_probs_state = gr.State(None) # list of dicts wi_cmp_labels_state = gr.State(None) # list of labels wi_cmp_colors_state = gr.State(None) # list of colors wi_cmp_idx_state = gr.State(0) with gr.Row(): wi_cmp_prev_btn = gr.Button("←", elem_classes="icon-nav-button", size="sm", scale=0) wi_icon_html = gr.HTML(scale=4) wi_cmp_next_btn = gr.Button("→", elem_classes="icon-nav-button", size="sm", scale=0) def _on_cmp_prev(idx, probs_list, labels, colors): if probs_list is None: return idx, "" new_idx = max(0, idx - 1) return new_idx, _render_comparison_icon_page(probs_list, labels, colors, new_idx) def _on_cmp_next(idx, probs_list, labels, colors): if probs_list is None: return idx, "" new_idx = min(len(ICON_OUTCOMES) - 1, idx + 1) return new_idx, _render_comparison_icon_page(probs_list, labels, colors, new_idx) wi_cmp_prev_btn.click( fn=_on_cmp_prev, inputs=[wi_cmp_idx_state, wi_cmp_probs_state, wi_cmp_labels_state, wi_cmp_colors_state], outputs=[wi_cmp_idx_state, wi_icon_html], ) wi_cmp_next_btn.click( fn=_on_cmp_next, inputs=[wi_cmp_idx_state, wi_cmp_probs_state, wi_cmp_labels_state, wi_cmp_colors_state], outputs=[wi_cmp_idx_state, wi_icon_html], ) gr.Markdown("---") # ── SHAP Carousels — one per scenario slot ──────────────────────── with gr.Accordion("SHAP Feature Importance", open=False): # Baseline SHAP carousel gr.Markdown("### Baseline") wi_shap_base_store = gr.State(None) wi_shap_base_idx = gr.State(0) with gr.Row(): wb_shap_prev = gr.Button("←", elem_classes="shap-nav-button", size="sm", scale=0) wb_shap_crumb = gr.HTML(value="", scale=4) wb_shap_next = gr.Button("→", elem_classes="shap-nav-button", size="sm", scale=0) wb_shap_plot = gr.Plot(label="Baseline — SHAP") def _wb_prev(idx, plots): new_idx = max(0, idx - 1) plot = plots[SHAP_ORDER[new_idx]] if plots else None return new_idx, plot, _shap_counter_html(new_idx) def _wb_next(idx, plots): new_idx = min(len(SHAP_ORDER) - 
1, idx + 1) plot = plots[SHAP_ORDER[new_idx]] if plots else None return new_idx, plot, _shap_counter_html(new_idx) wb_shap_prev.click(_wb_prev, [wi_shap_base_idx, wi_shap_base_store], [wi_shap_base_idx, wb_shap_plot, wb_shap_crumb]) wb_shap_next.click(_wb_next, [wi_shap_base_idx, wi_shap_base_store], [wi_shap_base_idx, wb_shap_plot, wb_shap_crumb]) # Per-scenario SHAP carousels wi_shap_scen_stores = [] wi_shap_scen_idxs = [] wi_shap_scen_plots = [] wi_shap_scen_crumbs = [] wi_shap_scen_rows = [] for s_idx in range(MAX_SCENARIOS): with gr.Row(visible=(s_idx == 0)) as shap_row: wi_shap_scen_rows.append(shap_row) with gr.Column(): scen_color = SCENARIO_COLORS[s_idx] gr.HTML( f'
    ' f'Scenario {s_idx+1}
    ' ) s_store = gr.State(None) s_idx_s = gr.State(0) wi_shap_scen_stores.append(s_store) wi_shap_scen_idxs.append(s_idx_s) with gr.Row(): s_prev_btn = gr.Button("←", elem_classes="shap-nav-button", size="sm", scale=0) s_crumb = gr.HTML(value="", scale=4) s_next_btn = gr.Button("→", elem_classes="shap-nav-button", size="sm", scale=0) s_plot = gr.Plot(label=f"Scenario {s_idx+1} — SHAP") wi_shap_scen_plots.append(s_plot) wi_shap_scen_crumbs.append(s_crumb) def _make_prev(st, ix): def fn(idx, plots): new_idx = max(0, idx - 1) plot = plots[SHAP_ORDER[new_idx]] if plots else None return new_idx, plot, _shap_counter_html(new_idx) return fn def _make_next(st, ix): def fn(idx, plots): new_idx = min(len(SHAP_ORDER) - 1, idx + 1) plot = plots[SHAP_ORDER[new_idx]] if plots else None return new_idx, plot, _shap_counter_html(new_idx) return fn s_prev_btn.click( _make_prev(s_store, s_idx_s), [s_idx_s, s_store], [s_idx_s, s_plot, s_crumb], ) s_next_btn.click( _make_next(s_store, s_idx_s), [s_idx_s, s_store], [s_idx_s, s_plot, s_crumb], ) # Sync SHAP scenario row visibility with slider def _update_shap_vis(n): return [gr.update(visible=(i < int(n))) for i in range(MAX_SCENARIOS)] n_scenarios_slider.change( fn=_update_shap_vis, inputs=n_scenarios_slider, outputs=wi_shap_scen_rows, ) # ── RUN callback ────────────────────────────────────────────────── def run_counterfactual(*all_values): n_feat = len(ALL_FEATURES) n_scen = int(all_values[0]) base_vals = all_values[1 : 1 + n_feat] name_offset = 1 + n_feat scenario_names_raw = all_values[name_offset : name_offset + MAX_SCENARIOS] feat_offset = name_offset + MAX_SCENARIOS scenario_val_blocks = [ all_values[feat_offset + s * n_feat : feat_offset + (s + 1) * n_feat] for s in range(MAX_SCENARIOS) ] def _empty_outputs(): # violation, table, cmp_probs, cmp_labels, cmp_colors, cmp_idx, icon_html # base_store, base_idx, base_plot, base_crumb # per scenario: store, idx, plot, crumb out = ["", "", None, None, None, 0, ""] out += [None, 0, 
None, ""] for _ in range(MAX_SCENARIOS): out += [None, 0, None, ""] return tuple(out) try: base_dict = _values_to_dict(base_vals) _check_missing(base_dict, "Baseline") all_violations = [] scenario_dicts_active = [] for s in range(n_scen): sd = _values_to_dict(scenario_val_blocks[s]) _check_missing(sd, f"Scenario {s+1}") v = _validate_counterfactual_constraints(base_dict, sd, f"Scenario {s+1}") all_violations.extend(v) scenario_dicts_active.append(sd) if all_violations: out = [_build_violation_html(all_violations), "", None, None, None, 0, ""] out += [None, 0, None, ""] for _ in range(MAX_SCENARIOS): out += [None, 0, None, ""] return tuple(out) base_probs, base_ci = predict_all_outcomes( base_dict, use_calibration=True, use_signed_voting=True, n_boot_ci=500 ) scen_probs_list = [] scen_ci_list = [] for sd in scenario_dicts_active: sp, sci = predict_all_outcomes( sd, use_calibration=True, use_signed_voting=True, n_boot_ci=500 ) scen_probs_list.append(sp) scen_ci_list.append(sci) scen_labels = [] for i in range(n_scen): raw_name = scenario_names_raw[i] if i < len(scenario_names_raw) else "" name = str(raw_name).strip() if raw_name else "" scen_labels.append(name if name else f"Scenario {i+1}") scen_colors = SCENARIO_COLORS[:n_scen] table_html = _build_comparison_table_html( base_probs, base_ci, scen_probs_list, scen_ci_list, scen_labels, scen_colors ) # Icon carousel: store all probs + labels + colors in State all_probs_list = [base_probs] + scen_probs_list all_labels = ["Baseline"] + scen_labels all_colors = ["#1565c0"] + scen_colors icon_html_first = _render_comparison_icon_page(all_probs_list, all_labels, all_colors, 0) # SHAP — baseline base_shap_plots = create_all_shap_plots(base_dict, max_display=10) first_base_plot = base_shap_plots[SHAP_ORDER[0]] base_crumb = _shap_counter_html(0) # SHAP — scenarios scen_shap_data = [] for s in range(MAX_SCENARIOS): if s < n_scen: sp_plots = create_all_shap_plots(scenario_dicts_active[s], max_display=10) 
scen_shap_data.append((sp_plots, sp_plots[SHAP_ORDER[0]], _shap_counter_html(0))) else: scen_shap_data.append((None, None, "")) out = [ "", # violation html table_html, # table all_probs_list, # cmp_probs state all_labels, # cmp_labels state all_colors, # cmp_colors state 0, # cmp_idx state icon_html_first, # icon display base_shap_plots, # base shap store 0, # base shap idx first_base_plot, # base shap plot base_crumb, # base shap crumb ] for s in range(MAX_SCENARIOS): store, plot0, crumb0 = scen_shap_data[s] out += [store, 0, plot0, crumb0] return tuple(out) except Exception as e: print(traceback.format_exc()) raise gr.Error(f"{type(e).__name__}: {str(e)}") # Build output list all_run_outputs = ( [wi_violation_html, wi_table_html, wi_cmp_probs_state, wi_cmp_labels_state, wi_cmp_colors_state, wi_cmp_idx_state, wi_icon_html] + [wi_shap_base_store, wi_shap_base_idx, wb_shap_plot, wb_shap_crumb] + [item for s in range(MAX_SCENARIOS) for item in (wi_shap_scen_stores[s], wi_shap_scen_idxs[s], wi_shap_scen_plots[s], wi_shap_scen_crumbs[s])] ) all_run_inputs = ( [n_scenarios_state] + wi_baseline_list + scenario_name_inputs + [feat for s_list in scenario_lists for feat in s_list] ) wi_run_btn.click( fn=run_counterfactual, inputs=all_run_inputs, outputs=all_run_outputs, ) # ───────────────────────────────────────────────────────────────────────────── if __name__ == "__main__": demo.launch(ssr_mode=False)