import gradio as gr
import pandas as pd
import numpy as np
import copy
import traceback
from inference import (
    FEATURE_NAMES,
    REPORTING_OUTCOMES,
    OUTCOME_DESCRIPTIONS,
    OUTCOMES,
    SHAP_OUTCOMES,
    predict_with_comparison,
    predict_all_outcomes,
    create_all_shap_plots,
    create_all_icon_arrays,
)

# ─────────────────────────────────────────────────────────────────────────────
# CHOICES / CONSTANTS
# ─────────────────────────────────────────────────────────────────────────────
# Each *_CHOICES list holds the options offered by the input widget for the
# feature of the same name (see the feature metadata table further down).
AGEGPFF_CHOICES = ["<=10", "11-17", "18-29", "30-49", ">=50"]
SEX_CHOICES = ["Male", "Female"]
KPS_CHOICES = ["<90", "≥ 90"]
DONORF_CHOICES = [
    "HLA identical sibling",
    "HLA mismatch relative",
    "Matched unrelated donor",
    "Mismatched unrelated donor or cord blood",
]
GRAFTYPE_CHOICES = ["Bone marrow", "Peripheral blood", "Cord blood"]
CONDGRPF_CHOICES = ["MAC", "RIC", "NMA"]
CONDGRP_FINAL_CHOICES = [
    "TBI/Cy",
    "TBI/Cy/Flu",
    "TBI/Cy/Flu/TT",
    "TBI/Mel",
    "TBI/Flu",
    "TBI alone (300/400/600cGy)",
    "Bu/Cy",
    "Bu/Mel",
    "Flu/Bu/TT",
    "Flu/Bu",
    "Flu/Mel/TT",
    "Flu/Mel",
    "Cy/Flu",
    "Treosulfan",
    "Cy alone",
    "Flud",
    "TLI",
]
ATGF_CHOICES = ["ATG", "Alemtuzumab", "None"]
GVHD_FINAL_CHOICES = [
    "Ex-vivo T-cell depletion",
    "CD34 selection",
    "Post-CY + siro +- MMF",
    "Post-CY + MMF + CNI",
    "CNI + MMF",
    "CNI + MTX",
    "CNI alone",
    "CNI + siro",
    "Siro alone",
    "MMF + MTX",
    "MMF + siro",
    "MMF alone",
    "MTX alone",
    "MTX + siro",
]
HLA_FINAL_CHOICES = ["8/8", "7/8", "≤ 6/8"]
RCMVPR_CHOICES = ["Negative", "Positive"]
EXCHTFPR_CHOICES = ["No", "Yes"]
VOC2YPR_CHOICES = ["No", "Yes"]
VOCFRQPR_CHOICES = ["< 3/yr", "≥ 3/yr"]
SCATXRSN_CHOICES = [
    "CNS event",
    "Acute chest Syndrome",
    "Recurrent vaso-occlusive pain",
    "Recurrent priapism",
    "Excessive transfusion requirements/iron overload",
    "Cardio-pulmonary",
    "Chronic transfusion",
    "Asymptomatic",
    "Renal insufficiency",
    "Splenic sequestration",
    "Avascular necrosis",
    "Hodgkin lymphoma",
]

# Features whose widget value is coerced to float before prediction.
NUM_COLS_SET = {"AGE", "NACS2YR"}
# Upper bound on the number of counterfactual scenario panels that are built.
MAX_SCENARIOS = 5

# (display label, value) pairs for the "published regimen" dropdown.  Entries
# whose value starts with "__header_" are visual group separators only; they
# are collected in HEADER_VALUES so callbacks can reject them as selections.
GROUPED_REGIMEN_CHOICES = [
    ("── HLA IDENTICAL ──", "__header_hla_identical__"),
    ("Hsieh et al 2014", "Hsieh et al 2014"),
    ("Krishnamurti et al 2019", "Krishnamurti et al 2019"),
    ("King et al 2015", "King et al 2015"),
    ("Walters et al 1996", "Walters et al 1996"),
    ("── HLA MISMATCHED ──", "__header_hla_mismatched__"),
    ("Bolanos-Meade et al 2022 (HLA Mismatch)", "Bolanos-Meade et al 2022 (HLA Mismatch)"),
    ("Patel et al 2020 (HLA Mismatch)", "Patel et al 2020 (HLA Mismatch)"),
    ("── MATCHED UNRELATED ──", "__header_matched_unrelated__"),
    ("L Krishnamurti et al 2019", "L Krishnamurti et al 2019"),
    ("Shenoy et al 2016", "Shenoy et al 2016"),
    ("── MISMATCHED UNRELATED / CORD BLOOD ──", "__header_mismatched_cord__"),
    ("Bolanos-Meade et al 2022 (Mismatched/Cord)", "Bolanos-Meade et al 2022 (Mismatched/Cord)"),
    ("Patel et al 2020 (Mismatched/Cord)", "Patel et al 2020 (Mismatched/Cord)"),
    ("── CUSTOM ──", "__header_custom__"),
    ("Custom", "Custom"),
]
HEADER_VALUES = {v for _, v in GROUPED_REGIMEN_CHOICES if v.startswith("__header_")}

# Field values applied when a published regimen is picked.  Every preset
# supplies the same six keys; the keys are feature names.
PUBLISHED_PRESETS = {
    "Hsieh et al 2014": {
        "CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI alone (300/400/600cGy)",
        "ATGF": "Alemtuzumab", "GVHD_FINAL": "Siro alone",
        "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"},
    "Krishnamurti et al 2019": {
        "CONDGRPF": "MAC", "CONDGRP_FINAL": "Flu/Bu",
        "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX",
        "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"},
    "King et al 2015": {
        "CONDGRPF": "RIC", "CONDGRP_FINAL": "Flu/Mel",
        "ATGF": "Alemtuzumab", "GVHD_FINAL": "CNI + MTX",
        "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"},
    "Walters et al 1996": {
        "CONDGRPF": "MAC", "CONDGRP_FINAL": "Bu/Cy",
        "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX",
        "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling"},
    "Bolanos-Meade et al 2022 (HLA Mismatch)": {
        "CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu",
        "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF",
        "HLA_FINAL": "7/8", "DONORF": "HLA mismatch relative"},
    "Patel et al 2020 (HLA Mismatch)": {
        "CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu/TT",
        "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF",
        "HLA_FINAL": "7/8", "DONORF": "HLA mismatch relative"},
    "L Krishnamurti et al 2019": {
        "CONDGRPF": "MAC", "CONDGRP_FINAL": "Flu/Bu",
        "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX",
        "HLA_FINAL": "8/8", "DONORF": "Matched unrelated donor"},
    "Shenoy et al 2016": {
        "CONDGRPF": "RIC", "CONDGRP_FINAL": "Flu/Mel",
        "ATGF": "Alemtuzumab", "GVHD_FINAL": "CNI + MTX",
        "HLA_FINAL": "8/8", "DONORF": "Matched unrelated donor"},
    "Bolanos-Meade et al 2022 (Mismatched/Cord)": {
        "CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu",
        "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF",
        "HLA_FINAL": "7/8", "DONORF": "Mismatched unrelated donor or cord blood"},
    "Patel et al 2020 (Mismatched/Cord)": {
        "CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu/TT",
        "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF",
        "HLA_FINAL": "7/8", "DONORF": "Mismatched unrelated donor or cord blood"},
}

MSD_DONOR = "HLA identical sibling"
MUD_DONOR = "Matched unrelated donor"
# Donor types for which HLA matching is pinned to "8/8" ...
LOCKED_88_DONORS = {MSD_DONOR, MUD_DONOR}
# ... and donor types for which "8/8" is not offered.
NON_88_DONORS = {"HLA mismatch relative", "Mismatched unrelated donor or cord blood"}
MSD_REGIMENS = {"Hsieh et al 2014", "Krishnamurti et al 2019", "King et al 2015", "Walters et al 1996"}
MMUD_REGIMENS = {
    "Bolanos-Meade et al 2022 (Mismatched/Cord)",
    "Patel et al 2020 (Mismatched/Cord)",
    "Bolanos-Meade et al 2022 (HLA Mismatch)",
    "Patel et al 2020 (HLA Mismatch)",
}

# Feature groups.  ALL_FEATURES fixes the positional order in which widget
# values are passed to, and unpacked by, the prediction callbacks.
PATIENT_FEATURES = ["AGE", "AGEGPFF", "SEX", "KPS", "RCMVPR"]
DONOR_FEATURES = ["DONORF", "GRAFTYPE", "HLA_FINAL", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL"]
DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
ALL_FEATURES = PATIENT_FEATURES + DONOR_FEATURES + DISEASE_FEATURES

# Outcomes rendered as icon arrays, and their short display titles.
ICON_OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"]
OUTCOME_TITLES = {
    "DEAD": "Death",
    "GF": "Graft Failure",
    "AGVHD": "Acute GvHD",
    "CGVHD": "Chronic GvHD",
    "VOCPSHI": "Vaso-Occlusive Crisis",
    "STROKEHI": "Stroke Post-HCT",
}
EVENT_COLOR = "#e53935"
NO_EVENT_COLOR = "#43a047"
# Order in which SHAP plots are laid out in the UI and returned by callbacks.
SHAP_ORDER = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "EFS", "STROKEHI", "OS"]
SCENARIO_COLORS = ["#e65100", "#6a1b9a", "#1b5e20", "#0d47a1", "#b71c1c"]

# ─────────────────────────────────────────────────────────────────────────────
# COMPONENT FACTORY
# ─────────────────────────────────────────────────────────────────────────────
# feature name -> widget description consumed by make_component().
_FEATURE_META = {
    "AGE": {"label": "Age (years)", "type": "number",
            "kwargs": {"minimum": 0, "maximum": 100, "step": 1}},
    "AGEGPFF": {"label": "Age Group", "type": "dropdown", "choices": AGEGPFF_CHOICES,
                "kwargs": {"interactive": False, "info": "Auto-filled from Age"}},
    "SEX": {"label": "Sex", "type": "radio", "choices": SEX_CHOICES},
    "KPS": {"label": "Karnofsky Performance Status", "type": "radio", "choices": KPS_CHOICES},
    "RCMVPR": {"label": "Recipient CMV Status", "type": "radio", "choices": RCMVPR_CHOICES},
    "DONORF": {"label": "Donor Type", "type": "dropdown", "choices": DONORF_CHOICES},
    "GRAFTYPE": {"label": "Graft Type", "type": "radio", "choices": GRAFTYPE_CHOICES},
    "HLA_FINAL": {"label": "HLA Matching", "type": "radio", "choices": HLA_FINAL_CHOICES},
    "CONDGRPF": {"label": "Conditioning Intensity", "type": "radio", "choices": CONDGRPF_CHOICES},
    "CONDGRP_FINAL": {"label": "Conditioning Regimen", "type": "dropdown", "choices": CONDGRP_FINAL_CHOICES},
    "ATGF": {"label": "Serotherapy", "type": "radio", "choices": ATGF_CHOICES},
    "GVHD_FINAL": {"label": "GvHD Prophylaxis", "type": "dropdown", "choices": GVHD_FINAL_CHOICES},
    "NACS2YR": {"label": "Acute Chest Syndrome Episodes (2 yrs pre-HCT)", "type": "number",
                "kwargs": {"minimum": 0, "maximum": 50, "step": 1}},
    "EXCHTFPR": {"label": "Exchange Transfusion Pre-HCT", "type": "radio", "choices": EXCHTFPR_CHOICES},
    "VOC2YPR": {"label": "Vaso-Occlusive Crisis in 2 yrs Pre-HCT", "type": "radio", "choices": VOC2YPR_CHOICES},
    "VOCFRQPR": {"label": "VoC Frequency Pre-HCT", "type": "radio", "choices": VOCFRQPR_CHOICES},
    "SCATXRSN": {"label": "Primary Reason for HCT", "type": "dropdown", "choices": SCATXRSN_CHOICES},
}

# Order in which the transplant widgets are laid out inside a panel (HLA last),
# and the six widgets a regimen preset writes to, in the positional order
# returned by apply_grouped_preset() after the regimen dropdown itself.
_TRANSPLANT_LAYOUT_ORDER = ["DONORF", "GRAFTYPE", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL", "HLA_FINAL"]
_PRESET_TARGETS = ["DONORF", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL", "HLA_FINAL"]


def make_component(feature: str, suffix: str = ""):
    """Create the (initially empty) Gradio input widget for *feature*.

    Args:
        feature: Key into ``_FEATURE_META``.
        suffix: Optional text appended to the widget label, e.g. "(Scenario 2)".

    Returns:
        A ``gr.Number``, ``gr.Dropdown`` or ``gr.Radio``.

    Raises:
        KeyError: If *feature* has no metadata entry.
        ValueError: If the metadata names an unknown widget type.
    """
    meta = _FEATURE_META[feature]
    label = f"{meta['label']} {suffix}".strip() if suffix else meta["label"]
    kind = meta["type"]
    kwargs = dict(meta.get("kwargs", {}))  # copy so the shared table is never mutated

    if kind == "number":
        return gr.Number(label=label, value=None, **kwargs)

    choice_widgets = {"dropdown": gr.Dropdown, "radio": gr.Radio}
    if kind not in choice_widgets:
        raise ValueError(f"Unknown component type '{kind}' for feature '{feature}'")
    interactive = kwargs.pop("interactive", True)
    info = kwargs.pop("info", None)
    return choice_widgets[kind](
        label=label, choices=meta["choices"], value=None,
        interactive=interactive, info=info, **kwargs,
    )


# ─────────────────────────────────────────────────────────────────────────────
# CONSTRAINT / VALIDATION HELPERS
# ─────────────────────────────────────────────────────────────────────────────
def _hla_choices_for_donor(donor_value):
    """Return the HLA-matching options permitted for *donor_value*."""
    if donor_value in LOCKED_88_DONORS:
        return ["8/8"]
    if donor_value in NON_88_DONORS:
        return ["7/8", "≤ 6/8"]
    return HLA_FINAL_CHOICES


def _hla_update_for_donor(donor_value):
    """Return the update for the HLA widget after the donor type was edited.

    Locked donors pin HLA to "8/8" (read-only); mismatched donors clear the
    value and offer only the non-8/8 options; anything else restores the full
    list without touching the current value.
    """
    choices = _hla_choices_for_donor(donor_value)
    if donor_value in LOCKED_88_DONORS:
        return gr.update(choices=choices, value="8/8", interactive=False)
    if donor_value in NON_88_DONORS:
        return gr.update(choices=choices, value=None, interactive=True)
    return gr.update(choices=choices, interactive=True)


def _validate_counterfactual_constraints(base_dict, wi_dict, label=""):
    """Check a counterfactual profile against its baseline.

    Args:
        base_dict: Baseline feature dict.
        wi_dict: Counterfactual ("what-if") feature dict.
        label: Optional scenario label used as a ``[label] `` message prefix.

    Returns:
        A list of human-readable violation messages (empty when valid).
    """
    violations = []
    tag = f"[{label}] " if label else ""

    if base_dict.get("SEX") and wi_dict.get("SEX") and base_dict["SEX"] != wi_dict["SEX"]:
        violations.append(f"{tag}Immutable feature: Sex cannot be changed.")

    base_age, wi_age = base_dict.get("AGE"), wi_dict.get("AGE")
    if base_age is not None and wi_age is not None:
        try:
            if float(wi_age) < float(base_age):
                violations.append(f"{tag}Age cannot be decreased ({base_age} → {wi_age}).")
        except (TypeError, ValueError):
            # Non-numeric ages are left to the missing-field check.
            pass

    wi_donor = wi_dict.get("DONORF")
    wi_hla = wi_dict.get("HLA_FINAL")
    if wi_donor and wi_hla:
        if wi_donor in LOCKED_88_DONORS and wi_hla != "8/8":
            violations.append(f"{tag}HLA constraint: '{wi_donor}' requires 8/8 HLA.")
        elif wi_donor in NON_88_DONORS and wi_hla == "8/8":
            violations.append(f"{tag}HLA constraint: '{wi_donor}' cannot have 8/8 HLA.")

    if wi_donor:
        wi_gvhd = wi_dict.get("GVHD_FINAL", "")
        if wi_donor == MSD_DONOR and wi_gvhd in {"Post-CY + siro +- MMF", "Post-CY + MMF + CNI"}:
            violations.append(
                f"{tag}Post-Cy GVHD prophylaxis is inconsistent with donor '{MSD_DONOR}'."
            )
    return violations


def _values_to_dict(values):
    """Pair positional widget values with ``ALL_FEATURES``.

    Numeric features (``NUM_COLS_SET``) are coerced to ``float``; empty or
    unparsable numeric values become ``None``.
    """
    d = {}
    for f, v in zip(ALL_FEATURES, values):
        if f in NUM_COLS_SET:
            try:
                d[f] = float(v) if v not in (None, "") else None
            except (TypeError, ValueError):
                d[f] = None
        else:
            d[f] = v
    return d


def _check_missing(user_vals, label=""):
    """Raise ``ValueError`` naming every feature that is None, "" or NaN."""
    missing = [
        f for f, v in user_vals.items()
        if v is None or v == "" or (isinstance(v, float) and pd.isna(v))
    ]
    if missing:
        prefix = f"[{label}] " if label else ""
        raise ValueError(
            f"{prefix}Please fill in all fields. "
            f"Missing: {', '.join(missing)}"
        )


def get_age_group(age):
    """Map an age in years to its ``AGEGPFF_CHOICES`` bucket.

    Returns ``None`` (which clears the dropdown) when *age* is empty,
    non-numeric or NaN.  An empty string is deliberately not returned: it is
    not one of the dropdown's choices.
    """
    if age is None or age == "":
        return None
    try:
        age = float(age)
    except (ValueError, TypeError):
        return None
    if age != age:  # NaN compares unequal to itself
        return None
    if age <= 10:
        return "<=10"
    if age <= 17:
        return "11-17"
    if age <= 29:
        return "18-29"
    if age <= 49:
        return "30-49"
    return ">=50"


def vocfrqpr_from_voc2ypr(voc_status):
    """Return the update for VoC frequency after VoC status was edited.

    "No" pins the frequency to "< 3/yr" (read-only); any other status clears
    it and makes it editable.
    """
    if voc_status == "No":
        return gr.update(value="< 3/yr", interactive=False)
    return gr.update(value=None, interactive=True)


def apply_grouped_preset(selected_value):
    """Translate a regimen-dropdown selection into widget updates.

    Returns:
        A 7-tuple of updates: the regimen dropdown itself, followed by the
        widgets listed in ``_PRESET_TARGETS`` (donor, conditioning intensity,
        conditioning regimen, serotherapy, GvHD prophylaxis, HLA).
    """
    if not selected_value or selected_value in HEADER_VALUES:
        # Group headers are not selectable: reset the dropdown, leave the rest.
        return (gr.update(value=None),) + (gr.update(),) * 6
    if selected_value == "Custom":
        return (gr.update(),) + tuple(gr.update(interactive=True) for _ in range(6))

    preset = PUBLISHED_PRESETS.get(selected_value)
    if not preset:
        return (gr.update(),) * 7

    hla_update = gr.update(
        value=preset["HLA_FINAL"],
        interactive=False,
        choices=_hla_choices_for_donor(preset["DONORF"]),
    )
    return (
        gr.update(),
        gr.update(value=preset["DONORF"], interactive=False),
        gr.update(value=preset["CONDGRPF"], interactive=False),
        gr.update(value=preset["CONDGRP_FINAL"], interactive=False),
        gr.update(value=preset["ATGF"], interactive=False),
        gr.update(value=preset["GVHD_FINAL"], interactive=False),
        hla_update,
    )


def lock_sex(baseline_sex):
    """Mirror the baseline sex into a scenario widget and keep it read-only."""
    if baseline_sex:
        return gr.update(value=baseline_sex, interactive=False)
    return gr.update(interactive=False)


def _copy_with_preset(*vals):
    """Copy one profile into another (shared by every "Copy ..." button).

    *vals* holds the source profile's feature values in ``ALL_FEATURES`` order,
    optionally followed by the source regimen selection.  Returns the feature
    values, an update for the target regimen dropdown, and the six preset
    updates for the target's ``_PRESET_TARGETS`` widgets.
    """
    n = len(ALL_FEATURES)
    feature_vals = vals[:n]
    regimen_value = vals[n] if len(vals) > n else None
    preset_updates = apply_grouped_preset(regimen_value)[1:]
    return (*feature_vals, gr.update(value=regimen_value), *preset_updates)


# ─────────────────────────────────────────────────────────────────────────────
# HTML RENDERERS
# ─────────────────────────────────────────────────────────────────────────────
# NOTE(review): in the revision under review the markup inside these string
# literals had been stripped (only interpolations and visible text survived).
# The element structure below follows what survived; the concrete styling
# values are a reconstruction and should be checked against the deployed UI.
def _esc(text):
    """Escape *text* for interpolation into HTML element content."""
    from html import escape  # function-scope import: keeps this helper self-contained

    return escape(str(text), quote=False)


def _stick_figure_svg(color, size=16):
    """Return an inline SVG stick figure drawn in *color*, *size* px wide."""
    h = round(size * 1.6)
    stroke = f'stroke="{color}" stroke-width="1.6" stroke-linecap="round"'
    return (
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{size}" height="{h}" viewBox="0 0 10 16">'
        f'<circle cx="5" cy="2.5" r="2.2" fill="{color}"/>'
        f'<line x1="5" y1="4.7" x2="5" y2="10.5" {stroke}/>'
        f'<line x1="1.5" y1="7" x2="8.5" y2="7" {stroke}/>'
        f'<line x1="5" y1="10.5" x2="2" y2="15.2" {stroke}/>'
        f'<line x1="5" y1="10.5" x2="8" y2="15.2" {stroke}/>'
        f'</svg>'
    )


def _icon_card_html(probability, outcome, panel_label="", panel_color="#1565c0"):
    """Render one outcome as a 10×10 icon array card.

    Args:
        probability: Event probability as a fraction (0–1).
        outcome: Outcome key; titled via ``OUTCOME_TITLES`` and then
            ``OUTCOME_DESCRIPTIONS``, falling back to the key itself.
        panel_label: Optional badge text shown above the title.
        panel_color: Badge background colour.
    """
    title = OUTCOME_TITLES.get(outcome, OUTCOME_DESCRIPTIONS.get(outcome, outcome))
    # Clamp so an out-of-range probability cannot yield a negative icon count.
    n_event = min(100, max(0, round(probability * 100)))
    n_no_event = 100 - n_event
    pct_str = f"{probability * 100:.1f}%"

    event_fig = _stick_figure_svg(EVENT_COLOR, size=13)
    no_event_fig = _stick_figure_svg(NO_EVENT_COLOR, size=13)
    rows_parts = []
    for row in range(10):
        cells = "".join(
            event_fig if row * 10 + col < n_event else no_event_fig
            for col in range(10)
        )
        rows_parts.append(f'<div style="display:flex;gap:1px;line-height:0;">{cells}</div>')
    grid_html = "\n".join(rows_parts)

    fig_e = _stick_figure_svg(EVENT_COLOR, size=11)
    fig_ne = _stick_figure_svg(NO_EVENT_COLOR, size=11)
    legend = (
        f'<div style="display:flex;gap:8px;align-items:center;justify-content:center;font-size:11px;">'
        f'<span>{fig_e}Event</span>'
        f'<span>({n_event}/100)</span>'
        f'<span>{fig_ne}No Event</span>'
        f'<span>({n_no_event}/100)</span>'
        f'</div>'
    )
    badge = (
        f'<div style="background:{panel_color};color:#ffffff;font-size:11px;'
        f'font-weight:600;padding:2px 8px;border-radius:4px;margin-bottom:4px;">'
        f'{_esc(panel_label)}</div>'
    ) if panel_label else ""

    return (
        f'<div style="border:1px solid #e0e0e0;border-radius:8px;padding:8px;text-align:center;">'
        f'{badge}'
        f'<div style="font-weight:600;font-size:13px;">{title}</div>'
        f'<div style="font-weight:700;font-size:18px;margin:2px 0;">{pct_str}</div>'
        f'<div style="display:inline-block;margin:4px 0;">{grid_html}</div>'
        f'<div>{legend}</div>'
        f'</div>'
    )


def _build_comparison_icon_grid(all_probs_list, labels, colors):
    """Render icon arrays for the baseline and every scenario.

    all_probs_list: list of dicts (baseline first, then scenarios)
    labels: list of display labels (may be user-typed; escaped here)
    colors: list of hex colors

    Layout: one row per scenario, columns = outcomes.
    """
    label_col = "flex:0 0 120px;"
    card_col = "flex:1 1 0;min-width:150px;"

    # Header row with outcome names (plus a spacer above the label column).
    outcome_headers = "".join(
        f'<div style="{card_col}text-align:center;font-weight:700;font-size:13px;">'
        f'{OUTCOME_TITLES.get(o, o)}</div>'
        for o in ICON_OUTCOMES
    )
    header_row = (
        f'<div style="display:flex;gap:8px;margin-bottom:6px;">'
        f'<div style="{label_col}"></div>'
        f'{outcome_headers}</div>'
    )

    rows_html = header_row
    for probs, label, color in zip(all_probs_list, labels, colors):
        row_label = (
            f'<div style="{label_col}font-weight:700;color:{color};'
            f'border-left:4px solid {color};padding-left:8px;">'
            f'{_esc(label)}</div>'
        )
        cards = "".join(
            f'<div style="{card_col}">'
            f'{_icon_card_html(probs[o], o, "", color)}'
            f'</div>'
            for o in ICON_OUTCOMES
        )
        rows_html += (
            f'<div style="display:flex;gap:8px;align-items:center;margin-bottom:10px;">'
            f'{row_label}{cards}</div>'
        )

    footnote = (
        f'<div style="font-size:12px;color:#555555;margin-top:6px;">'
        f'Each figure = 1 patient out of 100. '
        f'<span style="color:{EVENT_COLOR};">■ Red</span> = Event &nbsp; '
        f'<span style="color:{NO_EVENT_COLOR};">■ Green</span> = No Event'
        f'</div>'
    )
    return f'<div style="overflow-x:auto;">{rows_html}{footnote}</div>'


def _delta_color_html(delta, is_survival):
    """Render a signed percentage-point delta, coloured by direction.

    Grey for a negligible (or NaN) change, green for an improvement, red for a
    worsening.  For survival outcomes an increase is an improvement; for all
    others a decrease is.
    """
    if delta != delta or abs(delta) < 0.0005:  # NaN or effectively zero
        color = "#888888"
    elif (delta > 0 and is_survival) or (delta < 0 and not is_survival):
        color = "#2e7d32"
    else:
        color = "#c62828"
    sign = "+" if delta >= 0 else ""
    return f'<span style="color:{color};font-weight:600;">{sign}{delta*100:.1f}%</span>'


def _build_comparison_table_html(base_probs, base_ci, scenario_probs_list,
                                 scenario_ci_list, scenario_labels, scenario_colors):
    """Render the baseline-vs-scenarios probability table.

    Inverted layout: rows = Baseline + Scenarios, columns = Outcomes (those of
    ``REPORTING_OUTCOMES`` present in *base_probs*).  Each scenario cell shows
    the probability, its interval and the delta from baseline.  Scenario
    labels may be user-typed and are escaped here.
    """
    survival_outcomes = {"OS", "EFS"}
    nan_pair = (float("nan"), float("nan"))
    outcomes = [o for o in REPORTING_OUTCOMES if o in base_probs]
    cell = "padding:8px 10px;border-bottom:1px solid #e0e0e0;text-align:center;"
    label_cell = "padding:8px 10px;border-bottom:1px solid #e0e0e0;text-align:left;white-space:nowrap;"
    ci_style = "font-size:11px;color:#666666;"

    outcome_headers = "".join(
        f'<th style="{cell}font-weight:700;">{OUTCOME_DESCRIPTIONS.get(o, o)}</th>'
        for o in outcomes
    )
    header = (
        '<div style="overflow-x:auto;">'
        '<table style="border-collapse:collapse;width:100%;font-size:13px;">'
        "<thead><tr>"
        f'<th style="{label_cell}font-weight:700;">Scenario</th>'
        f"{outcome_headers}"
        "</tr></thead><tbody>"
    )

    # Baseline row
    baseline_cells = ""
    for o in outcomes:
        bp = base_probs[o]
        blo, bhi = base_ci.get(o, nan_pair)
        baseline_cells += (
            f'<td style="{cell}">'
            f'<div style="font-weight:700;">{bp*100:.1f}%</div>'
            f'<div style="{ci_style}">[{blo*100:.1f}%–{bhi*100:.1f}%]</div>'
            f"</td>"
        )
    rows = (
        f'<tr style="background:#f0f4ff;">'
        f'<td style="{label_cell}">'
        f'<div style="font-weight:700;color:#1565c0;">Baseline</div>'
        f"</td>"
        f"{baseline_cells}</tr>"
    )

    # Scenario rows
    for j, (sp_dict, sci_dict, s_label) in enumerate(
            zip(scenario_probs_list, scenario_ci_list, scenario_labels)):
        sc_color = scenario_colors[j % len(scenario_colors)]
        bg = "#fafbfc" if j % 2 == 0 else "#ffffff"
        scenario_cells = ""
        for o in outcomes:
            if o not in sp_dict:
                # Placeholder keeps later cells under the right header.
                scenario_cells += f'<td style="{cell}">—</td>'
                continue
            wp = sp_dict[o]
            wlo, whi = sci_dict.get(o, nan_pair)
            delta = wp - base_probs[o]
            scenario_cells += (
                f'<td style="{cell}">'
                f'<div style="font-weight:700;">{wp*100:.1f}%</div>'
                f'<div style="{ci_style}">[{wlo*100:.1f}%–{whi*100:.1f}%]</div>'
                f'<div style="font-size:12px;">{_delta_color_html(delta, o in survival_outcomes)}</div>'
                f"</td>"
            )
        rows += (
            f'<tr style="background:{bg};">'
            f'<td style="{label_cell}">'
            f'<div style="font-weight:700;color:{sc_color};">'
            f'<span style="display:inline-block;width:10px;height:10px;border-radius:50%;'
            f'background:{sc_color};margin-right:6px;"></span>'
            f"{_esc(s_label)}</div>"
            f"</td>"
            f"{scenario_cells}</tr>"
        )

    footer = (
        "</tbody></table>"
        '<div style="font-size:12px;color:#555555;margin-top:8px;">'
        'Δ from Baseline: <span style="color:#2e7d32;font-weight:600;">Green</span> = improvement&nbsp;&nbsp;'
        '<span style="color:#c62828;font-weight:600;">Red</span> = worsening&nbsp;&nbsp;|&nbsp;&nbsp;'
        "OS &amp; EFS: higher is better; all other outcomes: lower is better."
        "</div>"
        "</div>"
    )
    return header + rows + footer


def _build_violation_html(violations):
    """Render constraint violations as an alert box ("" when there are none)."""
    if not violations:
        return ""
    items = "".join(f"<li>{_esc(v)}</li>" for v in violations)
    return (
        f'<div style="background:#ffebee;border-left:4px solid #c62828;'
        f'padding:12px 16px;border-radius:4px;color:#b71c1c;">'
        f'<div style="font-weight:700;margin-bottom:6px;">'
        f'Constraint Violations — Analysis blocked</div>'
        f'<ul style="margin:0 0 6px 18px;padding:0;">{items}</ul>'
        f'<div style="font-size:12px;">'
        f'Please correct the above before running the comparison.</div>'
        f'</div>'
    )


# ─────────────────────────────────────────────────────────────────────────────
# MAIN PREDICT CALLBACK (Tab 1)
# ─────────────────────────────────────────────────────────────────────────────
def predict_gradio(*values):
    """Predict all outcomes for one profile (Tab 1 "Predict" button).

    Args:
        *values: Widget values in ``ALL_FEATURES`` order.

    Returns:
        ``(DataFrame, icon-array HTML, *SHAP plots in SHAP_ORDER)``.

    Raises:
        gr.Error: On missing fields or any failure in the inference helpers;
            the traceback is printed to the server log.
    """
    try:
        user_vals = _values_to_dict(values)
        _check_missing(user_vals)

        calibrated, _ = predict_with_comparison(user_vals)
        calibrated_probs, calibrated_intervals = calibrated

        rows = []
        for outcome in REPORTING_OUTCOMES:
            ci_low, ci_high = calibrated_intervals[outcome]
            rows.append({
                "Outcome": OUTCOME_DESCRIPTIONS[outcome],
                "Probability": f"{calibrated_probs[outcome] * 100:.1f}%",
                "95% CI": f"[{ci_low * 100:.1f}% – {ci_high * 100:.1f}%]",
            })
        df = pd.DataFrame(rows)

        shap_plots = create_all_shap_plots(user_vals, max_display=10)
        icon_arrays = create_all_icon_arrays(calibrated_probs)
        return (df, icon_arrays["__grid__"], *[shap_plots[o] for o in SHAP_ORDER])
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"{type(e).__name__}: {str(e)}") from e


# ─────────────────────────────────────────────────────────────────────────────
# CSS
# ─────────────────────────────────────────────────────────────────────────────
custom_css = """
.predict-button {
    background: linear-gradient(to right, #ff6b35, #ff8c42) !important;
    border: none !important;
    color: white !important;
    font-weight: bold !important;
    font-size: 16px !important;
    padding: 12px !important;
}
.predict-button:hover {
    background: linear-gradient(to right, #ff5722, #ff7b29) !important;
}
.copy-button {
    background: linear-gradient(to right, #388e3c, #66bb6a) !important;
    border: none !important;
    color: white !important;
    font-weight: 600 !important;
}
.copy-button:hover {
    background: linear-gradient(to right, #2e7d32, #43a047) !important;
}
.copy-from-predict-button {
    background: linear-gradient(to right, #6a1b9a, #ab47bc) !important;
    border: none !important;
    color: white !important;
    font-weight: 600 !important;
}
.copy-from-predict-button:hover {
    background: linear-gradient(to right, #4a148c, #8e24aa) !important;
}
.counterfactual-button {
    background: linear-gradient(to right, #1976d2, #42a5f5) !important;
    border: none !important;
    color: white !important;
    font-weight: bold !important;
    font-size: 15px !important;
    padding: 12px !important;
}
.counterfactual-button:hover {
    background: linear-gradient(to right, #1565c0, #1e88e5) !important;
}
.output-dataframe table td:first-child,
.output-dataframe table th:first-child {
    white-space: normal !important;
    word-break: break-word !important;
    min-width: 240px !important;
}
.constraint-info {
    background: #e8f5e9;
    border-left: 4px solid #388e3c;
    padding: 8px 14px;
    font-size: 12px;
    color: #1b5e20;
    border-radius: 4px;
    margin-bottom: 8px;
}
.scenario-panel-0 { border-left: 4px solid #e65100 !important; }
.scenario-panel-1 { border-left: 4px solid #6a1b9a !important; }
.scenario-panel-2 { border-left: 4px solid #1b5e20 !important; }
.scenario-panel-3 { border-left: 4px solid #0d47a1 !important; }
.scenario-panel-4 { border-left: 4px solid #b71c1c !important; }
"""


# ─────────────────────────────────────────────────────────────────────────────
# BUILD UI
# ─────────────────────────────────────────────────────────────────────────────
def _build_profile_panel(headings, regimen_label, regimen_info=None, suffix=""):
    """Create one patient / transplant / disease input panel and wire it up.

    Must be called inside an active ``gr.Blocks`` context.

    Args:
        headings: Three Markdown strings heading the patient, transplant and
            disease columns.
        regimen_label: Label of the published-regimen dropdown.
        regimen_info: Optional helper text for that dropdown.
        suffix: Passed to ``make_component`` to tag every widget label.

    Returns:
        ``(fields, regimen_dd)`` where *fields* maps feature name -> widget.
    """
    fields = {}
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(headings[0])
            for f in PATIENT_FEATURES:
                fields[f] = make_component(f, suffix)
        with gr.Column(scale=1):
            gr.Markdown(headings[1])
            regimen_dd = gr.Dropdown(
                choices=GROUPED_REGIMEN_CHOICES,
                value=None,
                label=regimen_label,
                info=regimen_info,
            )
            for f in _TRANSPLANT_LAYOUT_ORDER:
                fields[f] = make_component(f, suffix)
        with gr.Column(scale=1):
            gr.Markdown(headings[2])
            for f in DISEASE_FEATURES:
                fields[f] = make_component(f, suffix)

    fields["AGE"].change(get_age_group, fields["AGE"], fields["AGEGPFF"])
    # `.input` (user edits only), not `.change`: both handlers below reset the
    # dependent field, and `.change` also fires when a preset or a "Copy"
    # button sets the source widget, wiping the value that was just written.
    fields["VOC2YPR"].input(vocfrqpr_from_voc2ypr, fields["VOC2YPR"], fields["VOCFRQPR"])
    fields["DONORF"].input(fn=_hla_update_for_donor, inputs=fields["DONORF"], outputs=fields["HLA_FINAL"])
    regimen_dd.change(
        apply_grouped_preset,
        regimen_dd,
        [regimen_dd] + [fields[f] for f in _PRESET_TARGETS],
    )
    return fields, regimen_dd


with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
    gr.Markdown("# HCT Outcome Prediction Model")

    # ── shared state: how many counterfactual scenarios are active ────────────
    n_scenarios_state = gr.State(1)

    with gr.Tabs():

        # ══════════════════════════════════════════════════════════════════════
        # TAB 1 — PREDICT OUTCOMES
        # ══════════════════════════════════════════════════════════════════════
        with gr.Tab("Predict Outcomes"):
            gr.Markdown("Enter patient, transplant, and disease characteristics to predict outcomes.")

            inputs_dict, grouped_dd = _build_profile_panel(
                headings=(
                    "### Patient Characteristics",
                    "### Transplant Characteristics",
                    "### Disease Characteristics",
                ),
                regimen_label="Published conditioning regimen",
                regimen_info="Auto-fills Donor Type, Conditioning Intensity, Regimen, Serotherapy, GVHD Prophylaxis",
            )
            inputs_list = [inputs_dict[f] for f in ALL_FEATURES]

            predict_btn = gr.Button("Predict", elem_classes="predict-button", size="lg")

            gr.Markdown("---")
            gr.Markdown("## Prediction Results")
            output_table = gr.Dataframe(
                headers=["Outcome", "Probability", "95% CI"],
                label="Predicted Outcomes",
                elem_classes="output-dataframe",
                row_count=(len(REPORTING_OUTCOMES), "dynamic"),
                column_count=(3, "fixed"),
                wrap=True,
            )

            gr.Markdown("---")
            gr.Markdown("## Outcome Probability — Icon Arrays")
            icon_array_grid = gr.HTML()

            gr.Markdown("---")
            gr.Markdown("## SHAP — Feature Importance")
            with gr.Row():
                shap_dead = gr.Plot(label="Death")
                shap_gf = gr.Plot(label="Graft Failure")
                shap_agvhd = gr.Plot(label="Acute GvHD")
                shap_cgvhd = gr.Plot(label="Chronic GvHD")
            with gr.Row():
                shap_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
                shap_efs = gr.Plot(label="Event-Free Survival")
                shap_stroke = gr.Plot(label="Stroke Post-HCT")
                shap_os = gr.Plot(label="Overall Survival")

            predict_btn.click(
                fn=predict_gradio,
                inputs=inputs_list,
                # Plot order must match SHAP_ORDER (the callback's return order).
                outputs=[
                    output_table, icon_array_grid,
                    shap_dead, shap_gf, shap_agvhd, shap_cgvhd,
                    shap_vocpshi, shap_efs, shap_stroke, shap_os,
                ],
            )

        # ══════════════════════════════════════════════════════════════════════
        # TAB 2 — COUNTERFACTUAL ANALYSIS (dynamic multi-scenario)
        # ══════════════════════════════════════════════════════════════════════
        with gr.Tab("Counterfactual Scenarios"):
            gr.Markdown(
                "## Counterfactual Scenario Analysis\n"
                "Enter the **baseline** patient, then choose how many counterfactual "
                "scenarios you want to compare. Each scenario panel will appear below."
            )

            # ── Number of scenarios selector ──────────────────────────────────
            with gr.Row():
                n_scenarios_slider = gr.Slider(
                    minimum=1, maximum=MAX_SCENARIOS, step=1, value=1,
                    label=f"How many counterfactual scenarios do you want to compare? (1–{MAX_SCENARIOS})",
                    info="Adjust this first — scenario panels will appear/disappear below.",
                )

            gr.Markdown("---")

            # ── BASELINE ─────────────────────────────────────────────────────
            gr.Markdown("## Baseline Patient Profile")
            with gr.Row():
                copy_from_predict_btn = gr.Button(
                    "Copy from Predict tab → Baseline",
                    elem_classes="copy-from-predict-button",
                    size="sm",
                )

            wi_baseline_dict, wi_grouped_base = _build_profile_panel(
                headings=(
                    "### Patient Characteristics",
                    "### Transplant Characteristics",
                    "### Disease Characteristics",
                ),
                regimen_label="Published conditioning regimen (Baseline)",
            )
            wi_baseline_list = [wi_baseline_dict[f] for f in ALL_FEATURES]

            # Copy from Predict tab → Baseline
            copy_from_predict_btn.click(
                fn=_copy_with_preset,
                inputs=inputs_list + [grouped_dd],
                outputs=(
                    wi_baseline_list
                    + [wi_grouped_base]
                    + [wi_baseline_dict[f] for f in _PRESET_TARGETS]
                ),
            )

            gr.Markdown("---")

            # ── SCENARIO PANELS ───────────────────────────────────────────────
            scenario_dicts = []
            scenario_lists = []
            scenario_grouped_dds = []
            scenario_rows = []
            scenario_name_inputs = []   # one gr.Textbox per scenario for a custom name
            scenario_copy_btns = []

            for s_idx in range(MAX_SCENARIOS):
                color = SCENARIO_COLORS[s_idx % len(SCENARIO_COLORS)]
                label = f"Scenario {s_idx + 1}"
                suffix = f"({label})"

                # Only the first panel is visible until the slider is moved.
                with gr.Row(visible=(s_idx == 0)) as s_row:
                    scenario_rows.append(s_row)
                    with gr.Column():
                        gr.HTML(
                            f'<div style="border-left:4px solid {color};padding:6px 12px;'
                            f'font-size:16px;font-weight:700;color:{color};">'
                            f'Counterfactual {label}</div>'
                        )
                        with gr.Row():
                            s_name_input = gr.Textbox(
                                label=f"Scenario name ({label})",
                                value=label,
                                placeholder=f"e.g. {label}",
                                scale=2,
                            )
                            scenario_name_inputs.append(s_name_input)
                            copy_btn = gr.Button(
                                f"⬇ Copy Baseline → {label}",
                                elem_classes="copy-button",
                                size="sm",
                                scale=1,
                            )
                            scenario_copy_btns.append(copy_btn)

                        s_dict, s_grouped_dd = _build_profile_panel(
                            headings=(
                                f"#### Patient — {label}",
                                f"#### Transplant — {label}",
                                f"#### Disease — {label}",
                            ),
                            regimen_label=f"Published regimen ({label})",
                            suffix=suffix,
                        )

                scenario_dicts.append(s_dict)
                scenario_lists.append([s_dict[f] for f in ALL_FEATURES])
                scenario_grouped_dds.append(s_grouped_dd)

                # Mirror baseline sex into each scenario and lock it
                wi_baseline_dict["SEX"].change(
                    fn=lock_sex,
                    inputs=wi_baseline_dict["SEX"],
                    outputs=s_dict["SEX"],
                )
                s_dict["SEX"].interactive = False

                # Copy Baseline → this scenario
                copy_btn.click(
                    fn=_copy_with_preset,
                    inputs=wi_baseline_list + [wi_grouped_base],
                    outputs=(
                        scenario_lists[-1]
                        + [s_grouped_dd]
                        + [s_dict[f] for f in _PRESET_TARGETS]
                    ),
                )
                # Re-lock sex after copy
                copy_btn.click(
                    fn=lock_sex,
                    inputs=wi_baseline_dict["SEX"],
                    outputs=s_dict["SEX"],
                )

            # ── RUN ──────────────────────────────────────────────────────────
            gr.Markdown("---")
            wi_run_btn = gr.Button(
                "Run Counterfactual Comparison",
                elem_classes="counterfactual-button",
                size="lg",
            )

            # ── RESULTS ──────────────────────────────────────────────────────
            gr.Markdown("## Comparison Results")
            wi_violation_html = gr.HTML()

            gr.Markdown("### Outcome Probability Table")
            wi_table_html = gr.HTML()

            gr.Markdown("---")

            # Collapsible icon arrays section
            with gr.Accordion("Outcome Icon Arrays — Baseline vs Scenarios", open=False):
                gr.Markdown("*Icon arrays show each outcome probability per 100 patients.*")
                wi_icon_html = gr.HTML()

            gr.Markdown("---")

            # Collapsible SHAP section
            with gr.Accordion("SHAP Feature Importance", open=False):
                gr.Markdown("### Baseline")
                with gr.Row():
                    wi_shap_base = {o: gr.Plot(label=f"{o} — Baseline") for o in SHAP_ORDER}

                # One SHAP container per scenario slot (hidden until needed).
                # The heading lives inside the toggled container so that unused
                # slots do not leave an orphan "Scenario N" title behind.
                wi_shap_scenarios = []
                for s_idx in range(MAX_SCENARIOS):
                    with gr.Column(visible=(s_idx == 0)) as shap_container:
                        gr.Markdown(f"### Scenario {s_idx + 1}")
                        with gr.Row():
                            shap_plots_s = {
                                o: gr.Plot(label=f"{o} — Scenario {s_idx + 1}")
                                for o in SHAP_ORDER
                            }
                    wi_shap_scenarios.append((shap_container, shap_plots_s))

            # ── Slider → show/hide scenario panels + SHAP slots, sync state ───
            def _on_scenario_count_change(n):
                count = int(n)
                panel_updates = [gr.update(visible=(i < count)) for i in range(MAX_SCENARIOS)]
                shap_updates = [gr.update(visible=(i < count)) for i in range(MAX_SCENARIOS)]
                return (*panel_updates, *shap_updates, count)

            n_scenarios_slider.change(
                fn=_on_scenario_count_change,
                inputs=n_scenarios_slider,
                outputs=(
                    scenario_rows
                    + [container for container, _ in wi_shap_scenarios]
                    + [n_scenarios_state]
                ),
            )

            # ── RUN callback ──────────────────────────────────────────────────
            def run_counterfactual(*all_values):
                """Validate, predict and render the baseline and active scenarios.

                *all_values* layout (see ``all_run_inputs``): scenario count,
                baseline features, ``MAX_SCENARIOS`` scenario names, then
                ``MAX_SCENARIOS`` blocks of scenario features.  Returns values
                matching ``all_run_outputs``; when constraints are violated
                only the violation box is filled.
                """
                n_feat = len(ALL_FEATURES)
                n_shap = len(SHAP_ORDER)
                n_scen = max(1, min(MAX_SCENARIOS, int(all_values[0])))   # n_scenarios_state
                base_vals = all_values[1:1 + n_feat]

                # Scenario names come next (MAX_SCENARIOS of them)
                name_offset = 1 + n_feat
                scenario_names_raw = all_values[name_offset:name_offset + MAX_SCENARIOS]

                # Then the scenario feature blocks
                feat_offset = name_offset + MAX_SCENARIOS
                scenario_val_blocks = [
                    all_values[feat_offset + s * n_feat: feat_offset + (s + 1) * n_feat]
                    for s in range(MAX_SCENARIOS)
                ]

                # Placeholders for every SHAP plot slot (inactive ones stay None).
                base_shap_outputs = [None] * n_shap
                scene_shap_outputs = [[None] * n_shap for _ in range(MAX_SCENARIOS)]

                try:
                    base_dict = _values_to_dict(base_vals)
                    _check_missing(base_dict, "Baseline")

                    all_violations = []
                    scenario_dicts_active = []
                    for s in range(n_scen):
                        sd = _values_to_dict(scenario_val_blocks[s])
                        _check_missing(sd, f"Scenario {s+1}")
                        all_violations.extend(
                            _validate_counterfactual_constraints(base_dict, sd, f"Scenario {s+1}")
                        )
                        scenario_dicts_active.append(sd)

                    if all_violations:
                        return (
                            _build_violation_html(all_violations),
                            "",
                            "",
                            *base_shap_outputs,
                            *[p for plots in scene_shap_outputs for p in plots],
                        )

                    base_probs, base_ci = predict_all_outcomes(
                        base_dict, use_calibration=True, use_signed_voting=True, n_boot_ci=500
                    )

                    scen_probs_list = []
                    scen_ci_list = []
                    for sd in scenario_dicts_active:
                        sp, sci = predict_all_outcomes(
                            sd, use_calibration=True, use_signed_voting=True, n_boot_ci=500
                        )
                        scen_probs_list.append(sp)
                        scen_ci_list.append(sci)

                    # Use custom scenario names (fallback to "Scenario N" if blank)
                    scen_labels = []
                    for i in range(n_scen):
                        raw_name = scenario_names_raw[i] if i < len(scenario_names_raw) else ""
                        name = str(raw_name).strip() if raw_name else ""
                        scen_labels.append(name if name else f"Scenario {i+1}")

                    # Wrap around so zip() in the renderers never drops a scenario.
                    scen_colors = [
                        SCENARIO_COLORS[i % len(SCENARIO_COLORS)] for i in range(n_scen)
                    ]

                    table_html = _build_comparison_table_html(
                        base_probs, base_ci,
                        scen_probs_list, scen_ci_list,
                        scen_labels, scen_colors,
                    )
                    icon_html = _build_comparison_icon_grid(
                        [base_probs] + scen_probs_list,
                        ["Baseline"] + scen_labels,
                        ["#1565c0"] + scen_colors,
                    )

                    # SHAP
                    base_shap_plots = create_all_shap_plots(base_dict, max_display=10)
                    base_shap_outputs = [base_shap_plots[o] for o in SHAP_ORDER]
                    for s, sd in enumerate(scenario_dicts_active):
                        sp_plots = create_all_shap_plots(sd, max_display=10)
                        scene_shap_outputs[s] = [sp_plots[o] for o in SHAP_ORDER]

                    return (
                        "",
                        table_html,
                        icon_html,
                        *base_shap_outputs,
                        *[p for plots in scene_shap_outputs for p in plots],
                    )

                except Exception as e:
                    print(traceback.format_exc())
                    raise gr.Error(f"{type(e).__name__}: {str(e)}") from e

            # Flat output list: violation + table + icon + base SHAP + MAX × scenario SHAP
            all_run_outputs = (
                [wi_violation_html, wi_table_html, wi_icon_html]
                + [wi_shap_base[o] for o in SHAP_ORDER]
                + [shap_plots_s[o] for _, shap_plots_s in wi_shap_scenarios for o in SHAP_ORDER]
            )

            # Flat input list:
            # n_scenarios_state + wi_baseline_list + scenario_name_inputs + MAX_SCENARIOS × scenario_list
            all_run_inputs = (
                [n_scenarios_state]
                + wi_baseline_list
                + scenario_name_inputs
                + [feat for s_list in scenario_lists for feat in s_list]
            )

            wi_run_btn.click(
                fn=run_counterfactual,
                inputs=all_run_inputs,
                outputs=all_run_outputs,
            )


# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    demo.launch(ssr_mode=False)