| import gradio as gr |
| import pandas as pd |
| import numpy as np |
| import copy |
| import traceback |
|
|
| from inference import ( |
| FEATURE_NAMES, |
| REPORTING_OUTCOMES, |
| OUTCOME_DESCRIPTIONS, |
| OUTCOMES, |
| SHAP_OUTCOMES, |
| predict_with_comparison, |
| predict_all_outcomes, |
| create_all_shap_plots, |
| create_all_icon_arrays, |
| ) |
|
|
|
|
| |
| |
| |
|
|
# Option lists for the categorical inputs. Each list is referenced from
# `_FEATURE_META` as the `choices` of the matching dropdown / radio widget.
#
# NOTE(review): the comparison signs in the KPS, HLA and VOC-frequency labels
# were mis-encoded in this file (they showed up as "β₯" / "β€"); they are
# restored to "≥" / "≤" here. Confirm the restored strings match the category
# labels the `inference` module expects.
AGEGPFF_CHOICES = ["<=10", "11-17", "18-29", "30-49", ">=50"]
SEX_CHOICES = ["Male", "Female"]
KPS_CHOICES = ["<90", "≥ 90"]
DONORF_CHOICES = [
    "HLA identical sibling", "HLA mismatch relative",
    "Matched unrelated donor", "Mismatched unrelated donor or cord blood",
]
GRAFTYPE_CHOICES = ["Bone marrow", "Peripheral blood", "Cord blood"]
CONDGRPF_CHOICES = ["MAC", "RIC", "NMA"]
CONDGRP_FINAL_CHOICES = [
    "TBI/Cy", "TBI/Cy/Flu", "TBI/Cy/Flu/TT", "TBI/Mel", "TBI/Flu",
    "TBI alone (300/400/600cGy)", "Bu/Cy", "Bu/Mel", "Flu/Bu/TT",
    "Flu/Bu", "Flu/Mel/TT", "Flu/Mel", "Cy/Flu", "Treosulfan",
    "Cy alone", "Flud", "TLI",
]
ATGF_CHOICES = ["ATG", "Alemtuzumab", "None"]
GVHD_FINAL_CHOICES = [
    "Ex-vivo T-cell depletion", "CD34 selection", "Post-CY + siro +- MMF",
    "Post-CY + MMF + CNI", "CNI + MMF", "CNI + MTX", "CNI alone",
    "CNI + siro", "Siro alone", "MMF + MTX", "MMF + siro", "MMF alone",
    "MTX alone", "MTX + siro",
]
HLA_FINAL_CHOICES = ["8/8", "7/8", "≤ 6/8"]
RCMVPR_CHOICES = ["Negative", "Positive"]
EXCHTFPR_CHOICES = ["No", "Yes"]
VOC2YPR_CHOICES = ["No", "Yes"]
VOCFRQPR_CHOICES = ["< 3/yr", "≥ 3/yr"]
SCATXRSN_CHOICES = [
    "CNS event", "Acute chest Syndrome", "Recurrent vaso-occlusive pain",
    "Recurrent priapism", "Excessive transfusion requirements/iron overload",
    "Cardio-pulmonary", "Chronic transfusion", "Asymptomatic",
    "Renal insufficiency", "Splenic sequestration", "Avascular necrosis",
    "Hodgkin lymphoma",
]
|
|
# Features whose raw widget value `_values_to_dict` coerces to float; every
# other feature is passed through unchanged.
NUM_COLS_SET = {"AGE", "NACS2YR"}
# NOTE(review): presumably the upper bound on counterfactual scenarios (there
# are five scenario colours / CSS panels); its use is not visible in this part
# of the file - confirm.
MAX_SCENARIOS = 5
|
|
# (label, value) pairs for the "Published conditioning regimen" dropdown.
# Group headings are ordinary entries whose value starts with "__header_";
# `apply_grouped_preset` clears the dropdown again when one of them is picked.
#
# NOTE(review): the heading decorations were mis-encoded in this file (each
# rule character showed up as "β"); they are restored as box-drawing rules.
# Only the labels are affected - the values are plain ASCII.
GROUPED_REGIMEN_CHOICES = [
    ("── HLA IDENTICAL ──", "__header_hla_identical__"),
    ("Hsieh et al 2014", "Hsieh et al 2014"),
    ("Krishnamurti et al 2019", "Krishnamurti et al 2019"),
    ("King et al 2015", "King et al 2015"),
    ("Walters et al 1996", "Walters et al 1996"),
    ("── HLA MISMATCHED ──", "__header_hla_mismatched__"),
    ("Bolanos-Meade et al 2022 (HLA Mismatch)", "Bolanos-Meade et al 2022 (HLA Mismatch)"),
    ("Patel et al 2020 (HLA Mismatch)", "Patel et al 2020 (HLA Mismatch)"),
    ("── MATCHED UNRELATED ──", "__header_matched_unrelated__"),
    ("L Krishnamurti et al 2019", "L Krishnamurti et al 2019"),
    ("Shenoy et al 2016", "Shenoy et al 2016"),
    ("── MISMATCHED UNRELATED / CORD BLOOD ──", "__header_mismatched_cord__"),
    ("Bolanos-Meade et al 2022 (Mismatched/Cord)", "Bolanos-Meade et al 2022 (Mismatched/Cord)"),
    ("Patel et al 2020 (Mismatched/Cord)", "Patel et al 2020 (Mismatched/Cord)"),
    ("── CUSTOM ──", "__header_custom__"),
    ("Custom", "Custom"),
]
# Values of the heading rows, i.e. entries that are not real selections.
HEADER_VALUES = {
    value for _, value in GROUPED_REGIMEN_CHOICES if value.startswith("__header_")
}
|
|
# Transplant settings auto-filled by `apply_grouped_preset` for each published
# regimen. Rows are written as tuples in the column order of `_PRESET_FIELDS`
# and expanded into one dict per regimen below.
_PRESET_FIELDS = ("CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL", "HLA_FINAL", "DONORF")
_PRESET_ROWS = {
    "Hsieh et al 2014": (
        "NMA", "TBI alone (300/400/600cGy)", "Alemtuzumab", "Siro alone",
        "8/8", "HLA identical sibling",
    ),
    "Krishnamurti et al 2019": (
        "MAC", "Flu/Bu", "ATG", "CNI + MTX",
        "8/8", "HLA identical sibling",
    ),
    "King et al 2015": (
        "RIC", "Flu/Mel", "Alemtuzumab", "CNI + MTX",
        "8/8", "HLA identical sibling",
    ),
    "Walters et al 1996": (
        "MAC", "Bu/Cy", "ATG", "CNI + MTX",
        "8/8", "HLA identical sibling",
    ),
    "Bolanos-Meade et al 2022 (HLA Mismatch)": (
        "NMA", "TBI/Cy/Flu", "ATG", "Post-CY + siro +- MMF",
        "7/8", "HLA mismatch relative",
    ),
    "Patel et al 2020 (HLA Mismatch)": (
        "NMA", "TBI/Cy/Flu/TT", "ATG", "Post-CY + siro +- MMF",
        "7/8", "HLA mismatch relative",
    ),
    "L Krishnamurti et al 2019": (
        "MAC", "Flu/Bu", "ATG", "CNI + MTX",
        "8/8", "Matched unrelated donor",
    ),
    "Shenoy et al 2016": (
        "RIC", "Flu/Mel", "Alemtuzumab", "CNI + MTX",
        "8/8", "Matched unrelated donor",
    ),
    "Bolanos-Meade et al 2022 (Mismatched/Cord)": (
        "NMA", "TBI/Cy/Flu", "ATG", "Post-CY + siro +- MMF",
        "7/8", "Mismatched unrelated donor or cord blood",
    ),
    "Patel et al 2020 (Mismatched/Cord)": (
        "NMA", "TBI/Cy/Flu/TT", "ATG", "Post-CY + siro +- MMF",
        "7/8", "Mismatched unrelated donor or cord blood",
    ),
}
PUBLISHED_PRESETS = {
    regimen: dict(zip(_PRESET_FIELDS, row)) for regimen, row in _PRESET_ROWS.items()
}
|
|
# Donor categories and the HLA rule attached to them: the "locked" donors are
# forced to 8/8 by `_hla_update_for_donor`, the others may not be 8/8.
MSD_DONOR = "HLA identical sibling"
MUD_DONOR = "Matched unrelated donor"
LOCKED_88_DONORS = {
    MSD_DONOR,
    MUD_DONOR,
}
NON_88_DONORS = {
    "HLA mismatch relative",
    "Mismatched unrelated donor or cord blood",
}

# Regimen groupings by donor class.
# NOTE(review): neither set is referenced in the visible part of the file.
MSD_REGIMENS = {
    "Hsieh et al 2014",
    "Krishnamurti et al 2019",
    "King et al 2015",
    "Walters et al 1996",
}
MMUD_REGIMENS = {
    "Bolanos-Meade et al 2022 (Mismatched/Cord)",
    "Patel et al 2020 (Mismatched/Cord)",
    "Bolanos-Meade et al 2022 (HLA Mismatch)",
    "Patel et al 2020 (HLA Mismatch)",
}

# Input features by form column. ALL_FEATURES fixes the positional order in
# which widget values are collected and zipped back in `_values_to_dict`.
PATIENT_FEATURES = ["AGE", "AGEGPFF", "SEX", "KPS", "RCMVPR"]
DONOR_FEATURES = [
    "DONORF", "GRAFTYPE", "HLA_FINAL", "CONDGRPF",
    "CONDGRP_FINAL", "ATGF", "GVHD_FINAL",
]
DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
ALL_FEATURES = [*PATIENT_FEATURES, *DONOR_FEATURES, *DISEASE_FEATURES]
|
|
| |
# Outcomes paged through by the icon-array carousel, in display order, and the
# short titles shown on their cards.
ICON_OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"]
OUTCOME_TITLES = dict(zip(
    ICON_OUTCOMES,
    [
        "Death",
        "Graft Failure",
        "Acute GvHD",
        "Chronic GvHD",
        "Vaso-Occlusive Crisis",
        "Stroke Post-HCT",
    ],
))

# Icon colours: red figures experience the event, green ones do not.
EVENT_COLOR = "#e53935"
NO_EVENT_COLOR = "#43a047"

# Outcomes paged through by the SHAP carousel, in display order, and the
# labels used in its breadcrumb.
SHAP_ORDER = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "EFS", "STROKEHI", "OS"]
SHAP_LABELS = dict(zip(
    SHAP_ORDER,
    [
        "Death",
        "Graft Failure",
        "Acute GvHD",
        "Chronic GvHD",
        "Vaso-Occlusive Crisis Post-HCT",
        "Event-Free Survival",
        "Stroke Post-HCT",
        "Overall Survival",
    ],
))

# One accent colour per scenario; the same values appear in the
# `.scenario-panel-N` CSS rules.
SCENARIO_COLORS = ["#e65100", "#6a1b9a", "#1b5e20", "#0d47a1", "#b71c1c"]
|
|
|
|
| |
| |
| |
|
|
def _number_meta(label, maximum):
    """Describe a numeric input limited to 0..maximum in steps of 1."""
    return {
        "label": label,
        "type": "number",
        "kwargs": {"minimum": 0, "maximum": maximum, "step": 1},
    }


def _choice_meta(kind, label, choices, **kwargs):
    """Describe a "dropdown" or "radio" input; extra widget kwargs are optional."""
    meta = {"label": label, "type": kind, "choices": choices}
    if kwargs:
        # The "kwargs" key is only present when there is something to pass on.
        meta["kwargs"] = kwargs
    return meta


# Widget description for every model feature, consumed by `make_component`.
_FEATURE_META = {
    "AGE": _number_meta("Age (years)", 100),
    "AGEGPFF": _choice_meta(
        "dropdown", "Age Group", AGEGPFF_CHOICES,
        interactive=False, info="Auto-filled from Age",
    ),
    "SEX": _choice_meta("radio", "Sex", SEX_CHOICES),
    "KPS": _choice_meta("radio", "Karnofsky Performance Status", KPS_CHOICES),
    "RCMVPR": _choice_meta("radio", "Recipient CMV Status", RCMVPR_CHOICES),
    "DONORF": _choice_meta("dropdown", "Donor Type", DONORF_CHOICES),
    "GRAFTYPE": _choice_meta("radio", "Graft Type", GRAFTYPE_CHOICES),
    "HLA_FINAL": _choice_meta("radio", "HLA Matching", HLA_FINAL_CHOICES),
    "CONDGRPF": _choice_meta("radio", "Conditioning Intensity", CONDGRPF_CHOICES),
    "CONDGRP_FINAL": _choice_meta("dropdown", "Conditioning Regimen", CONDGRP_FINAL_CHOICES),
    "ATGF": _choice_meta("radio", "Serotherapy", ATGF_CHOICES),
    "GVHD_FINAL": _choice_meta("dropdown", "GvHD Prophylaxis", GVHD_FINAL_CHOICES),
    "NACS2YR": _number_meta("Acute Chest Syndrome Episodes (2 yrs pre-HCT)", 50),
    "EXCHTFPR": _choice_meta("radio", "Exchange Transfusion Pre-HCT", EXCHTFPR_CHOICES),
    "VOC2YPR": _choice_meta("radio", "Vaso-Occlusive Crisis in 2 yrs Pre-HCT", VOC2YPR_CHOICES),
    "VOCFRQPR": _choice_meta("radio", "VoC Frequency Pre-HCT", VOCFRQPR_CHOICES),
    "SCATXRSN": _choice_meta("dropdown", "Primary Reason for HCT", SCATXRSN_CHOICES),
}
|
|
|
|
def make_component(feature: str, suffix: str = ""):
    """Create the Gradio input widget described by ``_FEATURE_META[feature]``.

    Args:
        feature: Key into ``_FEATURE_META``.
        suffix: Optional text appended to the widget label after a space.

    Returns:
        A ``gr.Number``, ``gr.Dropdown`` or ``gr.Radio`` with no initial value.

    Raises:
        KeyError: If ``feature`` has no metadata entry.
        ValueError: If the entry's ``type`` is not number/dropdown/radio.
    """
    spec = _FEATURE_META[feature]
    caption = spec["label"]
    if suffix:
        caption = f"{caption} {suffix}".strip()
    widget_type = spec["type"]
    # Copy so that popping below never mutates the shared metadata table.
    options = dict(spec.get("kwargs", {}))

    if widget_type == "number":
        return gr.Number(label=caption, value=None, **options)

    choice_widgets = {"dropdown": gr.Dropdown, "radio": gr.Radio}
    if widget_type not in choice_widgets:
        raise ValueError(f"Unknown component type '{widget_type}' for feature '{feature}'")

    # `interactive` and `info` get explicit defaults; anything left in
    # `options` is forwarded untouched.
    is_interactive = options.pop("interactive", True)
    hint = options.pop("info", None)
    return choice_widgets[widget_type](
        label=caption,
        choices=spec["choices"],
        value=None,
        interactive=is_interactive,
        info=hint,
        **options,
    )
|
|
|
|
| |
| |
| |
|
|
def _hla_update_for_donor(donor_value):
    """Return the HLA Matching widget update implied by the chosen donor type.

    * Donors in ``LOCKED_88_DONORS``: only "8/8" is offered, selected and locked.
    * Donors in ``NON_88_DONORS``: every option except "8/8" is offered and the
      current selection is cleared.
    * No donor, or an unrecognised one: all options, editable.

    The non-8/8 list is derived from ``HLA_FINAL_CHOICES`` rather than repeated
    as literals, so the widget can never offer a label that differs from the
    canonical option list.
    """
    if donor_value:
        if donor_value in LOCKED_88_DONORS:
            return gr.update(choices=["8/8"], value="8/8", interactive=False)
        if donor_value in NON_88_DONORS:
            partial_matches = [c for c in HLA_FINAL_CHOICES if c != "8/8"]
            return gr.update(choices=partial_matches, value=None, interactive=True)
    return gr.update(choices=HLA_FINAL_CHOICES, interactive=True)
|
|
|
|
| def _validate_counterfactual_constraints(base_dict, wi_dict, label=""): |
| violations = [] |
| tag = f"[{label}] " if label else "" |
|
|
| if base_dict.get("SEX") and wi_dict.get("SEX"): |
| if base_dict["SEX"] != wi_dict["SEX"]: |
| violations.append(f"{tag} Immutable feature: Sex cannot be changed.") |
|
|
| base_age, wi_age = base_dict.get("AGE"), wi_dict.get("AGE") |
| if base_age is not None and wi_age is not None: |
| try: |
| if float(wi_age) < float(base_age): |
| violations.append(f"{tag}Age cannot be decreased ({base_age} β {wi_age}).") |
| except (TypeError, ValueError): |
| pass |
|
|
| wi_donor = wi_dict.get("DONORF") |
| wi_hla = wi_dict.get("HLA_FINAL") |
| if wi_donor and wi_hla: |
| if wi_donor in LOCKED_88_DONORS and wi_hla != "8/8": |
| violations.append(f"{tag}HLA constraint: '{wi_donor}' requires 8/8 HLA.") |
| elif wi_donor in NON_88_DONORS and wi_hla == "8/8": |
| violations.append(f"{tag}HLA constraint: '{wi_donor}' cannot have 8/8 HLA.") |
|
|
| if wi_donor: |
| wi_gvhd = wi_dict.get("GVHD_FINAL", "") |
| if wi_donor == MSD_DONOR and wi_gvhd in {"Post-CY + siro +- MMF", "Post-CY + MMF + CNI"}: |
| violations.append( |
| f"{tag} Post-Cy GVHD prophylaxis is inconsistent with donor '{MSD_DONOR}'." |
| ) |
| return violations |
|
|
|
|
def _values_to_dict(values):
    """Pair positional widget values with ``ALL_FEATURES`` names.

    Features listed in ``NUM_COLS_SET`` are converted to ``float``; ``None``,
    the empty string and anything that cannot be converted become ``None``.
    All other features keep their raw value.
    """

    def _as_number(raw):
        try:
            return None if raw in (None, "") else float(raw)
        except (TypeError, ValueError):
            return None

    return {
        name: (_as_number(raw) if name in NUM_COLS_SET else raw)
        for name, raw in zip(ALL_FEATURES, values)
    }
|
|
|
|
| def _check_missing(user_vals, label=""): |
| missing = [f for f, v in user_vals.items() |
| if v is None or v == "" or (isinstance(v, float) and pd.isna(v))] |
| if missing: |
| raise ValueError( |
| f"{'[' + label + '] ' if label else ''}Please fill in all fields. " |
| f"Missing: {', '.join(missing)}" |
| ) |
|
|
|
|
def get_age_group(age):
    """Map an age in years to its age-group label.

    Args:
        age: Number or numeric string; ``None`` and ``""`` mean "not entered".

    Returns:
        One of "<=10", "11-17", "18-29", "30-49", ">=50", or ``""`` when the
        age is missing, not numeric, or NaN.
    """
    if age is None or age == "":
        return ""
    try:
        years = float(age)
    except (ValueError, TypeError):
        return ""
    # NaN compares False against every bound and used to fall through to
    # ">=50"; treat it as "no age" instead.
    if pd.isna(years):
        return ""

    upper_bounds = ((10, "<=10"), (17, "11-17"), (29, "18-29"), (49, "30-49"))
    for upper, group in upper_bounds:
        if years <= upper:
            return group
    return ">=50"
|
|
|
|
def vocfrqpr_from_voc2ypr(voc_status):
    """Return the VoC-frequency widget update implied by the VoC yes/no answer.

    "No" pins the frequency to "< 3/yr" and locks the widget; any other
    answer clears the frequency and makes the widget editable.
    """
    had_no_crisis = voc_status == "No"
    if not had_no_crisis:
        return gr.update(value=None, interactive=True)
    return gr.update(value="< 3/yr", interactive=False)
|
|
|
|
def apply_grouped_preset(selected_value):
    """Translate a regimen-dropdown selection into updates for seven widgets.

    The returned tuple is ordered as: regimen dropdown, donor type,
    conditioning intensity, conditioning regimen, serotherapy, GVHD
    prophylaxis, HLA matching.

    * Nothing selected or a group heading: the dropdown is cleared, the other
      widgets are left alone.
    * "Custom": the six transplant widgets become editable.
    * A published regimen: the six widgets are set from ``PUBLISHED_PRESETS``
      and locked; HLA choices are narrowed according to the preset's donor.
    * Unknown value: no changes.

    NOTE(review): the donor widget also has a ``change`` listener that
    recomputes HLA; if that listener fires for programmatic updates it may
    clear/unlock the HLA value set here for mismatched donors - verify.
    """
    n_transplant_fields = 6

    if not selected_value or selected_value in HEADER_VALUES:
        untouched = tuple(gr.update() for _ in range(n_transplant_fields))
        return (gr.update(value=None),) + untouched

    if selected_value == "Custom":
        unlocked = tuple(gr.update(interactive=True) for _ in range(n_transplant_fields))
        return (gr.update(),) + unlocked

    preset = PUBLISHED_PRESETS.get(selected_value)
    if not preset:
        return tuple(gr.update() for _ in range(n_transplant_fields + 1))

    donor = preset["DONORF"]
    if donor in LOCKED_88_DONORS:
        hla_choices = ["8/8"]
    elif donor in NON_88_DONORS:
        # Derived from the canonical list instead of repeating the labels.
        hla_choices = [c for c in HLA_FINAL_CHOICES if c != "8/8"]
    else:
        hla_choices = HLA_FINAL_CHOICES

    locked_fields = ("DONORF", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL")
    return (
        gr.update(),
        *(gr.update(value=preset[name], interactive=False) for name in locked_fields),
        gr.update(value=preset["HLA_FINAL"], interactive=False, choices=hla_choices),
    )
|
|
|
|
def lock_sex(baseline_sex):
    """Return an update that locks a Sex widget, copying the baseline value if set."""
    if not baseline_sex:
        return gr.update(interactive=False)
    return gr.update(value=baseline_sex, interactive=False)
|
|
|
|
| |
| |
| |
|
|
| def _stick_figure_svg(color, size=16): |
| h = round(size * 1.6) |
| return ( |
| f'<svg xmlns="http://www.w3.org/2000/svg" width="{size}" height="{h}" ' |
| f'viewBox="0 0 20 32" style="display:block;flex-shrink:0;" ' |
| f'stroke="{color}" stroke-width="2.2" stroke-linecap="round" fill="none">' |
| f'<circle cx="10" cy="5" r="3.8" fill="{color}" stroke="none"/>' |
| f'<line x1="10" y1="9" x2="10" y2="20"/>' |
| f'<line x1="3" y1="13" x2="17" y2="13"/>' |
| f'<line x1="10" y1="20" x2="4" y2="30"/>' |
| f'<line x1="10" y1="20" x2="16" y2="30"/>' |
| f'</svg>' |
| ) |
|
|
|
|
def _render_single_icon_card(probability, outcome, panel_label="", panel_color="#1565c0"):
    """Render one outcome as a static HTML card (no JavaScript).

    The card shows the outcome title, the probability as a percentage, a
    10x10 grid of stick figures in which the first ``round(probability * 100)``
    figures are event-coloured, and a legend with both counts.

    Args:
        probability: Event probability; expected to lie in [0, 1].
        outcome: Outcome code; its title is looked up in ``OUTCOME_TITLES``,
            then ``OUTCOME_DESCRIPTIONS``, falling back to the code itself.
        panel_label: Optional badge text shown above the title.
        panel_color: Background colour of that badge.
    """
    heading = OUTCOME_TITLES.get(outcome, OUTCOME_DESCRIPTIONS.get(outcome, outcome))
    event_count = round(probability * 100)
    no_event_count = 100 - event_count
    percent_text = f"{probability * 100:.1f}%"

    # Only two distinct grid icons exist, so build each once and reuse it.
    event_icon = _stick_figure_svg(EVENT_COLOR, size=13)
    no_event_icon = _stick_figure_svg(NO_EVENT_COLOR, size=13)
    icons = [
        event_icon if position < event_count else no_event_icon
        for position in range(100)
    ]
    grid_html = "\n".join(
        '<div style="display:flex;justify-content:center;gap:1px;margin-bottom:1px;">'
        + "".join(icons[start:start + 10])
        + '</div>'
        for start in range(0, 100, 10)
    )

    def _legend_entry(color, caption, count):
        return (
            f'{_stick_figure_svg(color, size=11)}'
            f'<span style="color:{color};font-weight:700;font-size:9px;">{caption}</span>'
            f'<span style="color:#888;font-size:8px;">({count}/100)</span>'
        )

    legend = (
        '<div style="display:inline-grid;grid-template-columns:13px 1fr 36px;'
        'align-items:center;gap:3px;row-gap:3px;">'
        + _legend_entry(EVENT_COLOR, "Event", event_count)
        + _legend_entry(NO_EVENT_COLOR, "No Event", no_event_count)
        + '</div>'
    )

    badge = ""
    if panel_label:
        badge = (
            f'<div style="background:{panel_color};color:#fff;font-size:8px;font-weight:700;'
            'border-radius:3px;padding:1px 5px;margin-bottom:2px;display:inline-block;">'
            f'{panel_label}</div>'
        )

    pieces = [
        '<div style="background:#fff;border:1px solid #e0e0e0;border-radius:7px;',
        'padding:10px 8px;text-align:center;font-family:\'Segoe UI\',Arial,sans-serif;',
        'box-shadow:0 2px 4px rgba(0,0,0,0.06);box-sizing:border-box;max-width:340px;',
        'margin:0 auto;display:flex;flex-direction:column;align-items:center;">',
        badge,
        '<div style="min-height:30px;display:flex;align-items:center;justify-content:center;',
        f'font-size:13px;font-weight:700;color:#222;line-height:1.3;margin-bottom:4px;">{heading}</div>',
        f'<div style="font-size:22px;font-weight:800;color:{EVENT_COLOR};',
        f'line-height:1;margin-bottom:6px;">{percent_text}</div>',
        f'<div style="margin-bottom:4px;">{grid_html}</div>',
        f'<div>{legend}</div>',
        '</div>',
    ]
    return "".join(pieces)
|
|
|
|
def _render_icon_carousel_page(probs_dict, idx):
    """Render one page of the single-patient icon carousel as HTML (no JS).

    The page holds the outcome title, an "n / total" counter, one position dot
    per entry of ``ICON_OUTCOMES`` (the current one enlarged and blue), the
    icon card for ``ICON_OUTCOMES[idx]`` and a colour footnote. An outcome
    missing from ``probs_dict`` is drawn with probability 0.0.
    """
    outcome = ICON_OUTCOMES[idx]
    page_count = len(ICON_OUTCOMES)
    card = _render_single_icon_card(probs_dict.get(outcome, 0.0), outcome)
    heading = OUTCOME_TITLES.get(outcome, outcome)

    def _dot(is_current):
        diameter = "10" if is_current else "7"
        fill = "#1565c0" if is_current else "#ccc"
        return (
            f'<span style="display:inline-block;width:{diameter}px;'
            f'height:{diameter}px;border-radius:50%;'
            f'background:{fill};'
            'margin:0 3px;vertical-align:middle;"></span>'
        )

    dots = "".join(_dot(position == idx) for position in range(page_count))

    sections = [
        '<div style="font-family:\'Segoe UI\',Arial,sans-serif;padding:4px 0;text-align:center;">',
        f'<div style="font-size:15px;font-weight:700;color:#1565c0;margin-bottom:2px;">{heading}</div>',
        f'<div style="font-size:11px;color:#888;margin-bottom:10px;">{idx+1} / {page_count}</div>',
        f'<div style="margin-bottom:2px;">{dots}</div>',
        f'<div style="margin-top:10px;">{card}</div>',
        # Footnote explaining the icon colours.
        '<div style="font-size:10px;color:#888;text-align:center;margin-top:10px;">',
        'Each figure = 1 patient out of 100. ',
        f'<span style="color:{EVENT_COLOR};font-weight:600;">■ Red = Event</span> ',
        f'<span style="color:{NO_EVENT_COLOR};font-weight:600;">■ Green = No Event</span>',
        '</div>',
        '</div>',
    ]
    return "".join(sections)
|
|
|
|
def _render_comparison_icon_page(all_probs_list, labels, colors, idx):
    """Render one page of the comparison icon carousel as HTML (no JS).

    For ``ICON_OUTCOMES[idx]`` one row is produced per (probabilities, label,
    colour) triple: a coloured label tab next to that panel's icon card. The
    three sequences are walked in lockstep, so the shortest one limits the
    number of rows. An outcome missing from a probability dict is drawn as 0.0.
    """
    outcome = ICON_OUTCOMES[idx]
    page_count = len(ICON_OUTCOMES)
    heading = OUTCOME_TITLES.get(outcome, outcome)

    def _dot(is_current):
        diameter = "10" if is_current else "7"
        fill = "#1565c0" if is_current else "#ccc"
        return (
            f'<span style="display:inline-block;width:{diameter}px;'
            f'height:{diameter}px;border-radius:50%;'
            f'background:{fill};'
            'margin:0 3px;vertical-align:middle;"></span>'
        )

    dots = "".join(_dot(position == idx) for position in range(page_count))

    panel_rows = []
    for panel_probs, panel_label, panel_color in zip(all_probs_list, labels, colors):
        card = _render_single_icon_card(
            panel_probs.get(outcome, 0.0), outcome, panel_label, panel_color
        )
        panel_rows.append(
            '<div style="display:flex;align-items:flex-start;gap:10px;margin-bottom:12px;">'
            '<div style="width:80px;flex-shrink:0;font-size:9px;font-weight:700;'
            'text-align:center;border-radius:4px;padding:4px 6px;color:#fff;'
            f'background:{panel_color};white-space:normal;word-break:break-word;">{panel_label}</div>'
            f'<div style="flex:1;">{card}</div>'
            '</div>'
        )

    sections = [
        '<div style="font-family:\'Segoe UI\',Arial,sans-serif;padding:4px 0;text-align:center;">',
        f'<div style="font-size:15px;font-weight:700;color:#1565c0;margin-bottom:2px;">{heading}</div>',
        f'<div style="font-size:11px;color:#888;margin-bottom:6px;">{idx+1} / {page_count}</div>',
        f'<div style="margin-bottom:10px;">{dots}</div>',
        "".join(panel_rows),
        # Footnote explaining the icon colours.
        '<div style="font-size:10px;color:#888;text-align:center;margin-top:4px;">',
        'Each figure = 1 patient out of 100. ',
        f'<span style="color:{EVENT_COLOR};font-weight:600;">■ Red = Event</span> ',
        f'<span style="color:{NO_EVENT_COLOR};font-weight:600;">■ Green = No Event</span>',
        '</div>',
        '</div>',
    ]
    return "".join(sections)
|
|
|
|
| |
| |
| |
|
|
| def _delta_color_html(delta, is_survival): |
| if abs(delta) < 0.0005: |
| color = "#888888" |
| elif (delta > 0 and is_survival) or (delta < 0 and not is_survival): |
| color = "#2e7d32" |
| else: |
| color = "#c62828" |
| sign = "+" if delta >= 0 else "" |
| return f'<span style="color:{color};font-weight:700;">{sign}{delta*100:.1f}%</span>' |
|
|
|
|
def _build_comparison_table_html(base_probs, base_ci, scenario_probs_list, scenario_ci_list, scenario_labels, scenario_colors):
    """Build the HTML table comparing the baseline with each scenario.

    One column is rendered per entry of ``REPORTING_OUTCOMES`` that is present
    in ``base_probs``. The first row is the baseline; each following row is a
    scenario showing its probability, confidence interval and the colour-coded
    difference from the baseline.

    Args:
        base_probs: Outcome code -> baseline probability.
        base_ci: Outcome code -> (low, high) interval; missing entries print
            as NaN.
        scenario_probs_list: One probability dict per scenario.
        scenario_ci_list: One interval dict per scenario.
        scenario_labels: One display label per scenario.
        scenario_colors: Colours cycled over the scenarios; must be non-empty
            when there is at least one scenario.

    Returns:
        The table plus legend as a single HTML string.
    """
    survival_outcomes = {"OS", "EFS"}
    missing_ci = (float("nan"), float("nan"))
    # The column set is fixed once so that header, baseline and every scenario
    # row always have the same number of cells.
    columns = [o for o in REPORTING_OUTCOMES if o in base_probs]

    outcome_headers = "".join(
        f"<th style='padding:8px 10px;text-align:center;border-bottom:2px solid #ccd;"
        f"color:#333;font-size:11px;'>{OUTCOME_DESCRIPTIONS.get(o, o)}</th>"
        for o in columns
    )
    header = (
        "<div style='overflow-x:auto;'>"
        "<table style='width:100%;border-collapse:collapse;"
        "font-family:\"Segoe UI\",Arial,sans-serif;font-size:12px;'>"
        "<thead><tr style='background:#f0f4f8;'>"
        "<th style='padding:9px 12px;text-align:left;border-bottom:2px solid #ccd;"
        "color:#333;min-width:120px;'>Scenario</th>"
        f"{outcome_headers}"
        "</tr></thead><tbody>"
    )

    baseline_cells = []
    for o in columns:
        bp = base_probs[o]
        blo, bhi = base_ci.get(o, missing_ci)
        baseline_cells.append(
            f"<td style='padding:7px 10px;text-align:center;'>"
            f"<div style='font-weight:700;color:#1565c0;font-size:13px;'>{bp*100:.1f}%</div>"
            f"<div style='color:#5c7fa8;font-size:9px;'>[{blo*100:.1f}%–{bhi*100:.1f}%]</div>"
            f"</td>"
        )
    rows = [
        f"<tr style='background:#e8f0fb;'>"
        f"<td style='padding:8px 12px;font-weight:700;color:#1565c0;border-right:2px solid #ccd;'>"
        f"<div style='display:flex;align-items:center;gap:6px;'>"
        f"<span style='display:inline-block;width:10px;height:10px;border-radius:50%;"
        f"background:#1565c0;flex-shrink:0;'></span>Baseline</div></td>"
        f"{''.join(baseline_cells)}</tr>"
    ]

    scenarios = zip(scenario_probs_list, scenario_ci_list, scenario_labels)
    for j, (sp_dict, sci_dict, s_label) in enumerate(scenarios):
        sc_color = scenario_colors[j % len(scenario_colors)]
        bg = "#fafbfc" if j % 2 == 0 else "#ffffff"
        scenario_cells = []
        for o in columns:
            if o not in sp_dict:
                # Keep the column: skipping the cell would shift every later
                # value of this row under the wrong header.
                scenario_cells.append(
                    "<td style='padding:7px 10px;text-align:center;color:#888;'>—</td>"
                )
                continue
            wp = sp_dict[o]
            wlo, whi = sci_dict.get(o, missing_ci)
            delta_html = _delta_color_html(wp - base_probs[o], o in survival_outcomes)
            scenario_cells.append(
                f"<td style='padding:7px 10px;text-align:center;'>"
                f"<div style='font-weight:700;color:{sc_color};font-size:13px;'>{wp*100:.1f}%</div>"
                f"<div style='color:#888;font-size:9px;'>[{wlo*100:.1f}%–{whi*100:.1f}%]</div>"
                f"<div style='font-size:10px;margin-top:1px;'>{delta_html}</div>"
                f"</td>"
            )
        rows.append(
            f"<tr style='background:{bg};'>"
            f"<td style='padding:8px 12px;font-weight:600;border-right:2px solid #ccd;'>"
            f"<div style='display:flex;align-items:center;gap:6px;'>"
            f"<span style='display:inline-block;width:10px;height:10px;border-radius:50%;"
            f"background:{sc_color};flex-shrink:0;'></span>"
            f"<span style='color:{sc_color};'>{s_label}</span></div></td>"
            f"{''.join(scenario_cells)}</tr>"
        )

    footer = (
        "</tbody></table></div>"
        "<div style='font-size:10.5px;color:#888;margin-top:6px;'>"
        "Δ from Baseline: <span style='color:#2e7d32;font-weight:600;'>Green = improvement</span> "
        "<span style='color:#c62828;font-weight:600;'>Red = worsening</span> | "
        "OS & EFS: higher is better; all other outcomes: lower is better."
        "</div>"
    )
    return header + "".join(rows) + footer
|
|
|
|
| def _build_violation_html(violations): |
| if not violations: |
| return "" |
| items = "".join(f"<li style='margin-bottom:6px;'>{v}</li>" for v in violations) |
| return ( |
| f'<div style="background:#fff3e0;border:2px solid #e65100;border-radius:8px;' |
| f'padding:14px 18px;font-family:\'Segoe UI\',Arial,sans-serif;margin-bottom:12px;">' |
| f'<div style="font-weight:700;font-size:14px;color:#bf360c;margin-bottom:8px;">' |
| f'Constraint Violations β Analysis blocked</div>' |
| f'<ul style="margin:0;padding-left:20px;color:#6d1f00;font-size:13px;">{items}</ul>' |
| f'<div style="margin-top:10px;font-size:11px;color:#888;">' |
| f'Please correct the above before running the comparison.</div>' |
| f'</div>' |
| ) |
|
|
|
|
| |
| |
| |
|
|
def _shap_counter_html(idx):
    """Render the SHAP carousel breadcrumb: all outcome labels, current one bold.

    Labels follow ``SHAP_ORDER`` and come from ``SHAP_LABELS`` (falling back to
    the outcome code). The entry at position ``idx`` is bold and blue; the
    others are light grey. Entries are separated by a middle dot, which was
    mis-encoded in the file and is restored here.
    """
    crumbs = []
    for position, outcome in enumerate(SHAP_ORDER):
        is_current = position == idx
        weight = "700" if is_current else "400"
        shade = "#1565c0" if is_current else "#aaa"
        crumbs.append(
            f'<span style="font-weight:{weight};'
            f'color:{shade};font-size:11px;">{SHAP_LABELS.get(outcome, outcome)}</span>'
        )
    items = " · ".join(crumbs)
    return f'<div style="text-align:center;padding:4px 0 2px;">{items}</div>'
|
|
|
|
| |
| |
| |
|
|
def predict_gradio(*values):
    """Gradio callback for the "Predict" button.

    Args:
        *values: Widget values in ``ALL_FEATURES`` order.

    Returns:
        An 8-tuple matching the button's outputs: results table (DataFrame),
        calibrated probabilities (carousel state), icon page index 0, HTML of
        the first icon page, all SHAP plots (carousel state), SHAP index 0,
        the first SHAP plot, and the SHAP breadcrumb HTML.

    Raises:
        gr.Error: For any failure, including missing inputs. The traceback is
            printed to stdout and the original exception is chained.
    """
    try:
        user_vals = _values_to_dict(values)
        # Shared validation helper; same message as the scenario checks.
        _check_missing(user_vals)

        calibrated, _ = predict_with_comparison(user_vals)
        calibrated_probs, calibrated_intervals = calibrated

        rows = []
        for outcome in REPORTING_OUTCOMES:
            ci_low, ci_high = calibrated_intervals[outcome]
            rows.append({
                "Outcome": OUTCOME_DESCRIPTIONS[outcome],
                "Probability": f"{calibrated_probs[outcome] * 100:.1f}%",
                "95% CI": f"[{ci_low * 100:.1f}% – {ci_high * 100:.1f}%]",
            })
        df = pd.DataFrame(rows)

        shap_plots = create_all_shap_plots(user_vals, max_display=10)

        # Both carousels start on their first page.
        icon_html = _render_icon_carousel_page(calibrated_probs, 0)
        first_shap = shap_plots[SHAP_ORDER[0]]
        shap_crumb = _shap_counter_html(0)

        return (
            df,
            calibrated_probs,
            0,
            icon_html,
            shap_plots,
            0,
            first_shap,
            shap_crumb,
        )
    except Exception as e:
        # Top-level UI boundary: log the full traceback for the operator and
        # show the user a short message.
        print(traceback.format_exc())
        raise gr.Error(f"{type(e).__name__}: {str(e)}") from e
|
|
|
|
| |
| |
| |
|
|
# Stylesheet passed to `gr.Blocks(css=...)`. Rules are attached to widgets via
# `elem_classes`: the visible part of the UI uses `predict-button`,
# `output-dataframe`, `icon-nav-button` and `shap-nav-button`.
# NOTE(review): the copy / counterfactual / constraint-info / scenario-panel
# classes are presumably used by the counterfactual tab, which is not visible
# here. The five `.scenario-panel-N` border colours repeat SCENARIO_COLORS.
custom_css = """
.predict-button {
background: linear-gradient(to right, #ff6b35, #ff8c42) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 16px !important; padding: 12px !important;
}
.predict-button:hover { background: linear-gradient(to right, #ff5722, #ff7b29) !important; }
.copy-button {
background: linear-gradient(to right, #388e3c, #66bb6a) !important;
border: none !important; color: white !important; font-weight: 600 !important;
}
.copy-button:hover { background: linear-gradient(to right, #2e7d32, #43a047) !important; }
.copy-from-predict-button {
background: linear-gradient(to right, #6a1b9a, #ab47bc) !important;
border: none !important; color: white !important; font-weight: 600 !important;
}
.copy-from-predict-button:hover { background: linear-gradient(to right, #4a148c, #8e24aa) !important; }
.counterfactual-button {
background: linear-gradient(to right, #1976d2, #42a5f5) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 15px !important; padding: 12px !important;
}
.counterfactual-button:hover { background: linear-gradient(to right, #1565c0, #1e88e5) !important; }
.shap-nav-button {
background: linear-gradient(to right, #37474f, #546e7a) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 18px !important;
padding: 8px 22px !important; border-radius: 6px !important;
}
.shap-nav-button:hover { background: linear-gradient(to right, #263238, #455a64) !important; }
.icon-nav-button {
background: linear-gradient(to right, #1565c0, #1e88e5) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 18px !important;
padding: 8px 22px !important; border-radius: 6px !important;
}
.icon-nav-button:hover { background: linear-gradient(to right, #0d47a1, #1565c0) !important; }
.output-dataframe table td:first-child,
.output-dataframe table th:first-child {
white-space: normal !important; word-break: break-word !important; min-width: 240px !important;
}
.constraint-info {
background: #e8f5e9; border-left: 4px solid #388e3c;
padding: 8px 14px; font-size: 12px; color: #1b5e20;
border-radius: 4px; margin-bottom: 8px;
}
.scenario-panel-0 { border-left: 4px solid #e65100 !important; }
.scenario-panel-1 { border-left: 4px solid #6a1b9a !important; }
.scenario-panel-2 { border-left: 4px solid #1b5e20 !important; }
.scenario-panel-3 { border-left: 4px solid #0d47a1 !important; }
.scenario-panel-4 { border-left: 4px solid #b71c1c !important; }
"""
|
|
|
|
| |
| |
| |
|
|
| with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo: |
| gr.Markdown("# HCT Outcome Prediction Model") |
|
|
| n_scenarios_state = gr.State(1) |
|
|
| with gr.Tabs(): |
|
|
| |
| |
| |
| with gr.Tab("Predict Outcomes"): |
| gr.Markdown("Enter patient, transplant, and disease characteristics to predict outcomes.") |
|
|
| inputs_dict = {} |
|
|
| with gr.Row(): |
| with gr.Column(scale=1): |
| gr.Markdown("### Patient Characteristics") |
| for f in PATIENT_FEATURES: |
| inputs_dict[f] = make_component(f) |
|
|
| with gr.Column(scale=1): |
| gr.Markdown("### Transplant Characteristics") |
| grouped_dd = gr.Dropdown( |
| choices=GROUPED_REGIMEN_CHOICES, value=None, |
| label="Published conditioning regimen", |
| info="Auto-fills Donor Type, Conditioning Intensity, Regimen, Serotherapy, GVHD Prophylaxis", |
| ) |
| p_donorf = inputs_dict["DONORF"] = make_component("DONORF") |
| inputs_dict["GRAFTYPE"] = make_component("GRAFTYPE") |
| p_condgrpf = inputs_dict["CONDGRPF"] = make_component("CONDGRPF") |
| p_condgrp_final = inputs_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL") |
| p_atgf = inputs_dict["ATGF"] = make_component("ATGF") |
| p_gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL") |
| p_hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL") |
|
|
| with gr.Column(scale=1): |
| gr.Markdown("### Disease Characteristics") |
| for f in DISEASE_FEATURES: |
| inputs_dict[f] = make_component(f) |
|
|
| inputs_dict["AGE"].change(get_age_group, inputs_dict["AGE"], inputs_dict["AGEGPFF"]) |
| inputs_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, inputs_dict["VOC2YPR"], inputs_dict["VOCFRQPR"]) |
| p_donorf.change(fn=_hla_update_for_donor, inputs=p_donorf, outputs=p_hla_final) |
| grouped_dd.change( |
| apply_grouped_preset, grouped_dd, |
| [grouped_dd, p_donorf, p_condgrpf, p_condgrp_final, p_atgf, p_gvhd_final, p_hla_final], |
| ) |
|
|
| inputs_list = [inputs_dict[f] for f in ALL_FEATURES] |
| predict_btn = gr.Button("Predict", elem_classes="predict-button", size="lg") |
|
|
| gr.Markdown("---") |
| gr.Markdown("## Prediction Results") |
| output_table = gr.Dataframe( |
| headers=["Outcome", "Probability", "95% CI"], |
| label="Predicted Outcomes", |
| elem_classes="output-dataframe", |
| row_count=(len(REPORTING_OUTCOMES), "dynamic"), |
| column_count=(3, "fixed"), |
| wrap=True, |
| ) |
|
|
| |
# Icon-array carousel: heading, per-session state and the prev / display /
# next row. The heading dash and the arrow glyphs were mis-encoded in the
# file (both buttons showed the same character); they are restored here.
gr.Markdown("---")
gr.Markdown("## Outcome Probability — Icon Arrays")
gr.Markdown("*Use the ← → arrows to browse outcomes one at a time.*")

# Probabilities of the last prediction and the page currently shown.
icon_probs_state = gr.State(None)
icon_idx_state = gr.State(0)

with gr.Row():
    icon_prev_btn = gr.Button("←", elem_classes="icon-nav-button", size="sm", scale=0)
    icon_display = gr.HTML(scale=4)
    icon_next_btn = gr.Button("→", elem_classes="icon-nav-button", size="sm", scale=0)
|
|
| def _on_icon_prev(idx, probs): |
| if probs is None: |
| return idx, "" |
| new_idx = max(0, idx - 1) |
| return new_idx, _render_icon_carousel_page(probs, new_idx) |
|
|
| def _on_icon_next(idx, probs): |
| if probs is None: |
| return idx, "" |
| new_idx = min(len(ICON_OUTCOMES) - 1, idx + 1) |
| return new_idx, _render_icon_carousel_page(probs, new_idx) |
|
|
# Both arrows read the current page index and the stored probabilities, and
# write back the new index together with the rendered page.
icon_prev_btn.click(
    fn=_on_icon_prev,
    inputs=[icon_idx_state, icon_probs_state],
    outputs=[icon_idx_state, icon_display],
)
icon_next_btn.click(
    fn=_on_icon_next,
    inputs=[icon_idx_state, icon_probs_state],
    outputs=[icon_idx_state, icon_display],
)
|
|
| |
# SHAP carousel for the Predict tab.
# The dash and arrow glyphs below replace mis-decoded characters that made the
# heading unreadable and gave both navigation buttons the same label.
gr.Markdown("---")
gr.Markdown("## SHAP — Feature Importance")
gr.Markdown("*Use the ← → buttons to browse SHAP plots one at a time.*")

# Per-session state read by the navigation handlers: the stored plots (None
# until a prediction has run) and the current position in SHAP_ORDER.
shap_plots_state = gr.State(None)
shap_idx_state = gr.State(0)

with gr.Row():
    shap_prev_btn = gr.Button("←", elem_classes="shap-nav-button", size="sm", scale=0)
    shap_crumb = gr.HTML(value="", scale=4)
    shap_next_btn = gr.Button("→", elem_classes="shap-nav-button", size="sm", scale=0)

shap_display = gr.Plot(label="SHAP Feature Importance")
|
|
| def _on_shap_prev(idx, plots): |
| new_idx = max(0, idx - 1) |
| plot = plots[SHAP_ORDER[new_idx]] if plots else None |
| return new_idx, plot, _shap_counter_html(new_idx) |
|
|
| def _on_shap_next(idx, plots): |
| new_idx = min(len(SHAP_ORDER) - 1, idx + 1) |
| plot = plots[SHAP_ORDER[new_idx]] if plots else None |
| return new_idx, plot, _shap_counter_html(new_idx) |
|
|
# Both SHAP arrows read the current index and the stored plots, and write
# back the new index, the figure to show and the counter HTML.
shap_prev_btn.click(
    fn=_on_shap_prev,
    inputs=[shap_idx_state, shap_plots_state],
    outputs=[shap_idx_state, shap_display, shap_crumb],
)
shap_next_btn.click(
    fn=_on_shap_next,
    inputs=[shap_idx_state, shap_plots_state],
    outputs=[shap_idx_state, shap_display, shap_crumb],
)

# Predict fills the table and both carousels. The handler's return values
# are assigned to these eight outputs positionally.
predict_btn.click(
    fn=predict_gradio,
    inputs=inputs_list,
    outputs=[
        output_table,
        icon_probs_state,
        icon_idx_state,
        icon_display,
        shap_plots_state,
        shap_idx_state,
        shap_display,
        shap_crumb,
    ],
)
|
|
| |
| |
| |
| with gr.Tab("Counterfactual Scenarios"): |
# Intro text, scenario-count slider and the "copy from Predict" button.
# The en dash, em dash and arrow below replace mis-decoded characters in the
# slider label, its info text and the button text.
gr.Markdown(
    "## Counterfactual Scenario Analysis\n"
    "Enter the **baseline** patient, then choose how many counterfactual "
    "scenarios you want to compare. Each scenario panel will appear below."
)

with gr.Row():
    n_scenarios_slider = gr.Slider(
        minimum=1, maximum=MAX_SCENARIOS, step=1, value=1,
        label=f"How many counterfactual scenarios do you want to compare? (1–{MAX_SCENARIOS})",
        info="Adjust this first — scenario panels will appear/disappear below.",
    )

gr.Markdown("---")

gr.Markdown("## Baseline Patient Profile")

with gr.Row():
    copy_from_predict_btn = gr.Button(
        "Copy from Predict tab → Baseline",
        elem_classes="copy-from-predict-button",
        size="sm",
    )
|
|
# Baseline patient inputs, keyed by feature name.
# NOTE(review): the Row/Column nesting is reconstructed from the order of the
# widgets — confirm against the intended layout.
wi_baseline_dict = {}

with gr.Row():
    with gr.Column(scale=1):
        gr.Markdown("### Patient Characteristics")
        for f in PATIENT_FEATURES:
            wi_baseline_dict[f] = make_component(f)

    with gr.Column(scale=1):
        gr.Markdown("### Transplant Characteristics")
        wi_grouped_base = gr.Dropdown(
            choices=GROUPED_REGIMEN_CHOICES, value=None,
            label="Published conditioning regimen (Baseline)",
        )
        # Local names are kept for the fields targeted by the wiring below.
        wb_donorf = wi_baseline_dict["DONORF"] = make_component("DONORF")
        wi_baseline_dict["GRAFTYPE"] = make_component("GRAFTYPE")
        wb_condgrpf = wi_baseline_dict["CONDGRPF"] = make_component("CONDGRPF")
        wb_condgrp_final = wi_baseline_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL")
        wb_atgf = wi_baseline_dict["ATGF"] = make_component("ATGF")
        wb_gvhd_final = wi_baseline_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
        wb_hla_final = wi_baseline_dict["HLA_FINAL"] = make_component("HLA_FINAL")

    with gr.Column(scale=1):
        gr.Markdown("### Disease Characteristics")
        for f in DISEASE_FEATURES:
            wi_baseline_dict[f] = make_component(f)

# Same dependent-field wiring as the Predict tab, applied to the baseline.
wi_baseline_dict["AGE"].change(get_age_group, wi_baseline_dict["AGE"], wi_baseline_dict["AGEGPFF"])
wi_baseline_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, wi_baseline_dict["VOC2YPR"], wi_baseline_dict["VOCFRQPR"])
wb_donorf.change(fn=_hla_update_for_donor, inputs=wb_donorf, outputs=wb_hla_final)
wi_grouped_base.change(
    apply_grouped_preset, wi_grouped_base,
    [wi_grouped_base, wb_donorf, wb_condgrpf, wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
)

# Baseline components in ALL_FEATURES order.
wi_baseline_list = [wi_baseline_dict[f] for f in ALL_FEATURES]
|
|
def _copy_predict_to_baseline(*vals):
    """Map Predict-tab values onto the baseline panel's outputs.

    ``vals`` holds one value per ``ALL_FEATURES`` entry, optionally followed
    by the selected published regimen. The result is those feature values,
    then an update setting the baseline regimen dropdown, then the six field
    updates ``apply_grouped_preset`` returns after its first element.
    """
    n_features = len(ALL_FEATURES)
    copied = vals[:n_features]
    regimen = vals[n_features] if len(vals) > n_features else None

    preset = apply_grouped_preset(regimen)
    dropdown_update = gr.update(value=regimen)
    # Unpacking (rather than slicing) insists on exactly six preset updates.
    donor, intensity, regimen_detail, atg, gvhd, hla = preset[1:]

    return copied + (dropdown_update, donor, intensity,
                     regimen_detail, atg, gvhd, hla)
|
|
# Inputs: every Predict-tab feature plus its regimen dropdown.
# Outputs: every baseline feature, then the baseline regimen dropdown and the
# six preset-controlled baseline fields (listed again after wi_baseline_list
# so the preset updates are applied as well).
copy_from_predict_btn.click(
    fn=_copy_predict_to_baseline,
    inputs=inputs_list + [grouped_dd],
    outputs=wi_baseline_list + [wi_grouped_base, wb_donorf, wb_condgrpf,
                                wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
)

gr.Markdown("---")
|
|
| |
# Per-scenario bookkeeping. Each list receives one entry per loop iteration
# below, so entry i belongs to scenario i + 1.
scenario_dicts = []
scenario_lists = []
scenario_grouped_dds = []
scenario_rows = []
scenario_name_inputs = []

scenario_donorf_handles = []
scenario_hla_handles = []
scenario_voc2ypr_handles = []
scenario_age_handles = []
scenario_agegp_handles = []
scenario_copy_btns = []


def _make_copy_fn(s_grouped_dd_ref, s_donorf_ref, s_condgrpf_ref,
                  s_condgrp_final_ref, s_atgf_ref, s_gvhd_final_ref,
                  s_hla_final_ref):
    """Return the click handler that copies the baseline into one scenario.

    The component references are accepted so existing call sites keep
    working; the handler itself never reads them. It is defined once here
    rather than being rebuilt on every loop iteration.
    """
    def _copy(*vals):
        # vals: one value per ALL_FEATURES entry, optionally followed by the
        # baseline's selected published regimen.
        n = len(ALL_FEATURES)
        feature_vals = vals[:n]
        regimen_value = vals[n] if len(vals) > n else None
        preset_outputs = apply_grouped_preset(regimen_value)
        regimen_dd_upd = gr.update(value=regimen_value)
        d_upd, c_upd, cr_upd, a_upd, g_upd, h_upd = preset_outputs[1:]
        return (*feature_vals, regimen_dd_upd,
                d_upd, c_upd, cr_upd, a_upd, g_upd, h_upd)
    return _copy


# Build all MAX_SCENARIOS panels up front; only the first starts visible.
# NOTE(review): the Row/Column nesting is reconstructed from the order of the
# widgets — confirm against the intended layout.
for s_idx in range(MAX_SCENARIOS):
    color = SCENARIO_COLORS[s_idx]
    label = f"Scenario {s_idx + 1}"
    suffix = f"({label})"
    visible_init = (s_idx == 0)

    with gr.Row(visible=visible_init) as s_row:
        scenario_rows.append(s_row)
        with gr.Column():
            gr.HTML(
                f'<div style="background:{color};color:#fff;font-weight:700;'
                f'font-size:15px;padding:8px 14px;border-radius:6px;margin-bottom:6px;">'
                f'Counterfactual {label}</div>'
            )
            with gr.Row():
                s_name_input = gr.Textbox(
                    label=f"Scenario name ({label})",
                    value=label,
                    placeholder=f"e.g. {label}",
                    scale=2,
                )
                scenario_name_inputs.append(s_name_input)

                # The arrow glyphs replace mis-decoded characters in the
                # previous button text.
                copy_btn = gr.Button(
                    f"⬇ Copy Baseline → {label}",
                    elem_classes="copy-button",
                    size="sm",
                    scale=1,
                )
                scenario_copy_btns.append(copy_btn)

            s_dict = {}
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown(f"#### Patient — {label}")
                    for f in PATIENT_FEATURES:
                        s_dict[f] = make_component(f, suffix)

                with gr.Column(scale=1):
                    gr.Markdown(f"#### Transplant — {label}")
                    s_grouped_dd = gr.Dropdown(
                        choices=GROUPED_REGIMEN_CHOICES, value=None,
                        label=f"Published regimen ({label})",
                    )
                    scenario_grouped_dds.append(s_grouped_dd)
                    s_donorf = s_dict["DONORF"] = make_component("DONORF", suffix)
                    s_dict["GRAFTYPE"] = make_component("GRAFTYPE", suffix)
                    s_condgrpf = s_dict["CONDGRPF"] = make_component("CONDGRPF", suffix)
                    s_condgrp_final = s_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL", suffix)
                    s_atgf = s_dict["ATGF"] = make_component("ATGF", suffix)
                    s_gvhd_final = s_dict["GVHD_FINAL"] = make_component("GVHD_FINAL", suffix)
                    s_hla_final = s_dict["HLA_FINAL"] = make_component("HLA_FINAL", suffix)

                with gr.Column(scale=1):
                    gr.Markdown(f"#### Disease — {label}")
                    for f in DISEASE_FEATURES:
                        s_dict[f] = make_component(f, suffix)

    scenario_dicts.append(s_dict)
    scenario_lists.append([s_dict[f] for f in ALL_FEATURES])
    scenario_donorf_handles.append(s_donorf)
    scenario_hla_handles.append(s_hla_final)
    scenario_voc2ypr_handles.append(s_dict["VOC2YPR"])
    scenario_age_handles.append(s_dict["AGE"])
    scenario_agegp_handles.append(s_dict["AGEGPFF"])

    # Same dependent-field wiring as the baseline panel.
    s_dict["AGE"].change(get_age_group, s_dict["AGE"], s_dict["AGEGPFF"])
    s_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, s_dict["VOC2YPR"], s_dict["VOCFRQPR"])
    s_donorf.change(fn=_hla_update_for_donor, inputs=s_donorf, outputs=s_hla_final)
    s_grouped_dd.change(
        apply_grouped_preset, s_grouped_dd,
        [s_grouped_dd, s_donorf, s_condgrpf, s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final],
    )

    # The scenario's SEX field is driven by the baseline's SEX field and is
    # not editable on its own.
    wi_baseline_dict["SEX"].change(
        fn=lock_sex,
        inputs=wi_baseline_dict["SEX"],
        outputs=s_dict["SEX"],
    )
    s_dict["SEX"].interactive = False

    # Copy button: baseline features + baseline regimen in, this scenario's
    # features + regimen dropdown + preset-controlled fields out.
    copy_btn.click(
        fn=_make_copy_fn(s_grouped_dd, s_donorf, s_condgrpf,
                         s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final),
        inputs=wi_baseline_list + [wi_grouped_base],
        outputs=(scenario_lists[-1]
                 + [s_grouped_dd, s_donorf, s_condgrpf,
                    s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final]),
    )
    copy_btn.click(
        fn=lock_sex,
        inputs=wi_baseline_dict["SEX"],
        outputs=s_dict["SEX"],
    )
|
|
| |
def _update_scenario_visibility(n):
    """Return one visibility update per scenario slot.

    Slot ``i`` (0-based) is marked visible when ``i < int(n)``; the list
    always has ``MAX_SCENARIOS`` entries.
    """
    updates = []
    for slot in range(MAX_SCENARIOS):
        shown = slot < int(n)
        updates.append(gr.update(visible=shown))
    return updates
|
|
# Moving the slider shows or hides the scenario panels ...
n_scenarios_slider.change(
    fn=_update_scenario_visibility,
    inputs=n_scenarios_slider,
    outputs=scenario_rows,
)
# ... and stores the chosen count as an int in n_scenarios_state, which the
# run handler receives as its first input.
n_scenarios_slider.change(
    fn=lambda n: int(n),
    inputs=n_scenarios_slider,
    outputs=n_scenarios_state,
)

gr.Markdown("---")
wi_run_btn = gr.Button(
    "Run Counterfactual Comparison",
    elem_classes="counterfactual-button",
    size="lg",
)

# Result placeholders, filled by the run handler.
gr.Markdown("## Comparison Results")
wi_violation_html = gr.HTML()
gr.Markdown("### Outcome Probability Table")
wi_table_html = gr.HTML()

gr.Markdown("---")
|
|
| |
# Collapsible comparison carousel (baseline plus scenarios).
# The dash and arrow glyphs replace mis-decoded characters in the previous
# title, hint text and button labels.
with gr.Accordion("Outcome Icon Arrays — Baseline vs Scenarios", open=False):
    gr.Markdown(
        "*One row per scenario. Use the ← → arrows to browse outcomes one at a time.*"
    )

    # State read by the navigation handlers: probabilities, row labels and
    # row colours (all None until a comparison has run) and the page index.
    wi_cmp_probs_state = gr.State(None)
    wi_cmp_labels_state = gr.State(None)
    wi_cmp_colors_state = gr.State(None)
    wi_cmp_idx_state = gr.State(0)

    with gr.Row():
        wi_cmp_prev_btn = gr.Button("←", elem_classes="icon-nav-button", size="sm", scale=0)
        wi_icon_html = gr.HTML(scale=4)
        wi_cmp_next_btn = gr.Button("→", elem_classes="icon-nav-button", size="sm", scale=0)
|
|
| def _on_cmp_prev(idx, probs_list, labels, colors): |
| if probs_list is None: |
| return idx, "" |
| new_idx = max(0, idx - 1) |
| return new_idx, _render_comparison_icon_page(probs_list, labels, colors, new_idx) |
|
|
| def _on_cmp_next(idx, probs_list, labels, colors): |
| if probs_list is None: |
| return idx, "" |
| new_idx = min(len(ICON_OUTCOMES) - 1, idx + 1) |
| return new_idx, _render_comparison_icon_page(probs_list, labels, colors, new_idx) |
|
|
# Both arrows read the page index plus the stored probabilities, labels and
# colours, and write back the new index and the rendered page.
wi_cmp_prev_btn.click(
    fn=_on_cmp_prev,
    inputs=[wi_cmp_idx_state, wi_cmp_probs_state, wi_cmp_labels_state, wi_cmp_colors_state],
    outputs=[wi_cmp_idx_state, wi_icon_html],
)
wi_cmp_next_btn.click(
    fn=_on_cmp_next,
    inputs=[wi_cmp_idx_state, wi_cmp_probs_state, wi_cmp_labels_state, wi_cmp_colors_state],
    outputs=[wi_cmp_idx_state, wi_icon_html],
)

gr.Markdown("---")
|
|
| |
| with gr.Accordion("SHAP Feature Importance", open=False): |
| |
# Baseline SHAP browser.
# The dash and arrow glyphs replace mis-decoded characters in the previous
# button and plot labels.
gr.Markdown("### Baseline")
# Stored baseline plots (None until a comparison has run) and the current
# position in SHAP_ORDER.
wi_shap_base_store = gr.State(None)
wi_shap_base_idx = gr.State(0)

with gr.Row():
    wb_shap_prev = gr.Button("←", elem_classes="shap-nav-button", size="sm", scale=0)
    wb_shap_crumb = gr.HTML(value="", scale=4)
    wb_shap_next = gr.Button("→", elem_classes="shap-nav-button", size="sm", scale=0)
wb_shap_plot = gr.Plot(label="Baseline — SHAP")
|
|
| def _wb_prev(idx, plots): |
| new_idx = max(0, idx - 1) |
| plot = plots[SHAP_ORDER[new_idx]] if plots else None |
| return new_idx, plot, _shap_counter_html(new_idx) |
|
|
| def _wb_next(idx, plots): |
| new_idx = min(len(SHAP_ORDER) - 1, idx + 1) |
| plot = plots[SHAP_ORDER[new_idx]] if plots else None |
| return new_idx, plot, _shap_counter_html(new_idx) |
|
|
# Arguments are positional: handler, inputs (index, stored plots), outputs
# (index, plot, counter HTML).
wb_shap_prev.click(_wb_prev, [wi_shap_base_idx, wi_shap_base_store],
                   [wi_shap_base_idx, wb_shap_plot, wb_shap_crumb])
wb_shap_next.click(_wb_next, [wi_shap_base_idx, wi_shap_base_store],
                   [wi_shap_base_idx, wb_shap_plot, wb_shap_crumb])
|
|
| |
# Per-scenario SHAP browser parts. Each list receives one entry per loop
# iteration below, so entry i belongs to scenario i + 1.
wi_shap_scen_stores = []
wi_shap_scen_idxs = []
wi_shap_scen_plots = []
wi_shap_scen_crumbs = []
wi_shap_scen_rows = []


def _make_prev(st, ix):
    """Return a "previous plot" handler for one scenario's SHAP browser.

    ``st`` and ``ix`` are accepted so existing call sites keep working; the
    handler never reads them. Defined once here instead of per iteration.
    """
    def fn(idx, plots):
        if not plots:
            # Nothing to browse yet: keep the index and show no counter.
            return idx, None, ""
        new_idx = max(0, idx - 1)
        return new_idx, plots[SHAP_ORDER[new_idx]], _shap_counter_html(new_idx)
    return fn


def _make_next(st, ix):
    """Return a "next plot" handler for one scenario's SHAP browser.

    ``st`` and ``ix`` are accepted so existing call sites keep working; the
    handler never reads them. Defined once here instead of per iteration.
    """
    def fn(idx, plots):
        if not plots:
            # Without this guard the hidden index advanced and a counter was
            # shown although there was no plot to display.
            return idx, None, ""
        new_idx = min(len(SHAP_ORDER) - 1, idx + 1)
        return new_idx, plots[SHAP_ORDER[new_idx]], _shap_counter_html(new_idx)
    return fn


# Build all MAX_SCENARIOS browsers up front; only the first starts visible.
# NOTE(review): the Row/Column nesting is reconstructed from the order of the
# widgets — confirm against the intended layout.
for s_idx in range(MAX_SCENARIOS):
    with gr.Row(visible=(s_idx == 0)) as shap_row:
        wi_shap_scen_rows.append(shap_row)
        with gr.Column():
            scen_color = SCENARIO_COLORS[s_idx]
            gr.HTML(
                f'<div style="font-weight:700;font-size:13px;color:{scen_color};">'
                f'Scenario {s_idx+1}</div>'
            )
            # Stored plots for this scenario and the position in SHAP_ORDER.
            s_store = gr.State(None)
            s_idx_s = gr.State(0)
            wi_shap_scen_stores.append(s_store)
            wi_shap_scen_idxs.append(s_idx_s)

            # The dash and arrow glyphs replace mis-decoded characters in the
            # previous button and plot labels.
            with gr.Row():
                s_prev_btn = gr.Button("←", elem_classes="shap-nav-button", size="sm", scale=0)
                s_crumb = gr.HTML(value="", scale=4)
                s_next_btn = gr.Button("→", elem_classes="shap-nav-button", size="sm", scale=0)
            s_plot = gr.Plot(label=f"Scenario {s_idx+1} — SHAP")

            wi_shap_scen_plots.append(s_plot)
            wi_shap_scen_crumbs.append(s_crumb)

            s_prev_btn.click(
                _make_prev(s_store, s_idx_s),
                [s_idx_s, s_store],
                [s_idx_s, s_plot, s_crumb],
            )
            s_next_btn.click(
                _make_next(s_store, s_idx_s),
                [s_idx_s, s_store],
                [s_idx_s, s_plot, s_crumb],
            )
|
|
| |
def _update_shap_vis(n):
    """Return one visibility update per scenario SHAP row.

    Row ``i`` (0-based) is marked visible when ``i < int(n)``; the list
    always has ``MAX_SCENARIOS`` entries.
    """
    updates = []
    for slot in range(MAX_SCENARIOS):
        shown = slot < int(n)
        updates.append(gr.update(visible=shown))
    return updates
|
|
# Moving the scenario-count slider also shows or hides the per-scenario
# SHAP rows.
n_scenarios_slider.change(
    fn=_update_shap_vis,
    inputs=n_scenarios_slider,
    outputs=wi_shap_scen_rows,
)
|
|
| |
def run_counterfactual(*all_values):
    """Run the baseline and every active scenario and build all result outputs.

    Args:
        *all_values: Flat positional inputs, in this order: the number of
            active scenarios, ``len(ALL_FEATURES)`` baseline values,
            ``MAX_SCENARIOS`` scenario names, then ``MAX_SCENARIOS`` blocks of
            ``len(ALL_FEATURES)`` scenario values.

    Returns:
        A tuple of ``11 + 4 * MAX_SCENARIOS`` values: violation HTML, table
        HTML, the comparison carousel state (probabilities, labels, colours,
        page index) and its first page, the baseline SHAP group (stored
        plots, index, first plot, counter HTML), then one such group per
        scenario slot. Inactive slots get ``(None, 0, None, "")``. When
        constraint violations are found, only the violation HTML is filled.

    Raises:
        gr.Error: If anything fails. A ``gr.Error`` raised by a helper is
            passed through unchanged; any other exception is wrapped as
            ``"<ExceptionType>: <message>"`` with the original as its cause.
    """
    n_feat = len(ALL_FEATURES)
    n_scen = int(all_values[0])
    base_vals = all_values[1 : 1 + n_feat]

    name_offset = 1 + n_feat
    scenario_names_raw = all_values[name_offset : name_offset + MAX_SCENARIOS]

    feat_offset = name_offset + MAX_SCENARIOS
    scenario_val_blocks = [
        all_values[feat_offset + s * n_feat : feat_offset + (s + 1) * n_feat]
        for s in range(MAX_SCENARIOS)
    ]

    def _empty_outputs(violation_html=""):
        # Blank result set: everything cleared except the violation message.
        out = [violation_html, "", None, None, None, 0, ""]
        # One (store, index, plot, counter) group for the baseline and one
        # per scenario slot.
        out += [None, 0, None, ""] * (1 + MAX_SCENARIOS)
        return tuple(out)

    try:
        base_dict = _values_to_dict(base_vals)
        _check_missing(base_dict, "Baseline")

        all_violations = []
        scenario_dicts_active = []
        for s in range(n_scen):
            scen_name = f"Scenario {s+1}"
            sd = _values_to_dict(scenario_val_blocks[s])
            _check_missing(sd, scen_name)
            all_violations.extend(
                _validate_counterfactual_constraints(base_dict, sd, scen_name)
            )
            scenario_dicts_active.append(sd)

        if all_violations:
            # Skip all prediction work; report the violations only.
            return _empty_outputs(_build_violation_html(all_violations))

        base_probs, base_ci = predict_all_outcomes(
            base_dict, use_calibration=True, use_signed_voting=True, n_boot_ci=500
        )

        scen_probs_list = []
        scen_ci_list = []
        for sd in scenario_dicts_active:
            sp, sci = predict_all_outcomes(
                sd, use_calibration=True, use_signed_voting=True, n_boot_ci=500
            )
            scen_probs_list.append(sp)
            scen_ci_list.append(sci)

        # Blank or missing scenario names fall back to "Scenario N".
        scen_labels = []
        for i in range(n_scen):
            raw_name = scenario_names_raw[i] if i < len(scenario_names_raw) else ""
            name = str(raw_name).strip() if raw_name else ""
            scen_labels.append(name or f"Scenario {i+1}")

        # list() keeps the concatenation below valid whether the palette is
        # a list or a tuple.
        scen_colors = list(SCENARIO_COLORS[:n_scen])

        table_html = _build_comparison_table_html(
            base_probs, base_ci, scen_probs_list, scen_ci_list, scen_labels, scen_colors
        )

        # Comparison carousel rows: baseline first, then each scenario.
        all_probs_list = [base_probs] + scen_probs_list
        all_labels = ["Baseline"] + scen_labels
        all_colors = ["#1565c0"] + scen_colors
        icon_html_first = _render_comparison_icon_page(all_probs_list, all_labels, all_colors, 0)

        base_shap_plots = create_all_shap_plots(base_dict, max_display=10)

        out = [
            "",
            table_html,
            all_probs_list,
            all_labels,
            all_colors,
            0,
            icon_html_first,
            base_shap_plots,
            0,
            base_shap_plots[SHAP_ORDER[0]],
            _shap_counter_html(0),
        ]
        for s in range(MAX_SCENARIOS):
            if s < n_scen:
                sp_plots = create_all_shap_plots(scenario_dicts_active[s], max_display=10)
                out += [sp_plots, 0, sp_plots[SHAP_ORDER[0]], _shap_counter_html(0)]
            else:
                out += [None, 0, None, ""]

        return tuple(out)

    except gr.Error:
        # Already a user-facing message; re-wrapping would prefix it with
        # "Error: ".
        raise
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"{type(e).__name__}: {e}") from e
|
|
| |
# Output components in the order run_counterfactual returns its values:
# seven comparison outputs, the baseline SHAP group, then one
# (store, index, plot, counter) group per scenario slot.
all_run_outputs = (
    [wi_violation_html, wi_table_html,
     wi_cmp_probs_state, wi_cmp_labels_state, wi_cmp_colors_state,
     wi_cmp_idx_state, wi_icon_html]
    + [wi_shap_base_store, wi_shap_base_idx, wb_shap_plot, wb_shap_crumb]
    + [item for s in range(MAX_SCENARIOS)
       for item in (wi_shap_scen_stores[s], wi_shap_scen_idxs[s],
                    wi_shap_scen_plots[s], wi_shap_scen_crumbs[s])]
)

# Input components in the order run_counterfactual slices its arguments:
# scenario count, baseline features, scenario names, then every scenario's
# features flattened scenario by scenario.
all_run_inputs = (
    [n_scenarios_state]
    + wi_baseline_list
    + scenario_name_inputs
    + [feat for s_list in scenario_lists for feat in s_list]
)

wi_run_btn.click(
    fn=run_counterfactual,
    inputs=all_run_inputs,
    outputs=all_run_outputs,
)
|
|
|
|
| |
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(ssr_mode=False)