| import gradio as gr |
| import pandas as pd |
| import numpy as np |
| import copy |
| import traceback |
|
|
| from inference import ( |
| FEATURE_NAMES, |
| REPORTING_OUTCOMES, |
| OUTCOME_DESCRIPTIONS, |
| OUTCOMES, |
| SHAP_OUTCOMES, |
| predict_with_comparison, |
| predict_all_outcomes, |
| create_all_shap_plots, |
| create_all_icon_arrays, |
| ) |
|
|
|
|
| |
| |
| |
|
|
# Allowed values for each categorical feature. These lists feed the choices
# of the input widgets (see _FEATURE_META), and the selected strings are passed
# unchanged to the `inference` functions. The spellings presumably must match
# the model's training categories exactly — confirm before editing any label.
AGEGPFF_CHOICES = ["<=10", "11-17", "18-29", "30-49", ">=50"]
SEX_CHOICES = ["Male", "Female"]
KPS_CHOICES = ["<90", "≥ 90"]
DONORF_CHOICES = [
    "HLA identical sibling",
    "HLA mismatch relative",
    "Matched unrelated donor",
    "Mismatched unrelated donor or cord blood",
]
GRAFTYPE_CHOICES = ["Bone marrow", "Peripheral blood", "Cord blood"]
CONDGRPF_CHOICES = ["MAC", "RIC", "NMA"]
CONDGRP_FINAL_CHOICES = [
    "TBI/Cy",
    "TBI/Cy/Flu",
    "TBI/Cy/Flu/TT",
    "TBI/Mel",
    "TBI/Flu",
    "TBI alone (300/400/600cGy)",
    "Bu/Cy",
    "Bu/Mel",
    "Flu/Bu/TT",
    "Flu/Bu",
    "Flu/Mel/TT",
    "Flu/Mel",
    "Cy/Flu",
    "Treosulfan",
    "Cy alone",
    "Flud",
    "TLI",
]
ATGF_CHOICES = ["ATG", "Alemtuzumab", "None"]
GVHD_FINAL_CHOICES = [
    "Ex-vivo T-cell depletion",
    "CD34 selection",
    "Post-CY + siro +- MMF",
    "Post-CY + MMF + CNI",
    "CNI + MMF",
    "CNI + MTX",
    "CNI alone",
    "CNI + siro",
    "Siro alone",
    "MMF + MTX",
    "MMF + siro",
    "MMF alone",
    "MTX alone",
    "MTX + siro",
]
HLA_FINAL_CHOICES = ["8/8", "7/8", "≤ 6/8"]
RCMVPR_CHOICES = ["Negative", "Positive"]
EXCHTFPR_CHOICES = ["No", "Yes"]
VOC2YPR_CHOICES = ["No", "Yes"]
VOCFRQPR_CHOICES = ["< 3/yr", "≥ 3/yr"]
SCATXRSN_CHOICES = [
    "CNS event",
    "Acute chest Syndrome",
    "Recurrent vaso-occlusive pain",
    "Recurrent priapism",
    "Excessive transfusion requirements/iron overload",
    "Cardio-pulmonary",
    "Chronic transfusion",
    "Asymptomatic",
    "Renal insufficiency",
    "Splenic sequestration",
    "Avascular necrosis",
    "Hodgkin lymphoma",
]

# Features whose widget values are coerced to float (see _values_to_dict).
NUM_COLS_SET = {"AGE", "NACS2YR"}
# Upper bound on the number of counterfactual scenario panels in the UI.
MAX_SCENARIOS = 5
|
|
# Choices for the "Published conditioning regimen" dropdown, as
# (display label, value) pairs. Each donor category opens with a header
# entry whose value starts with "__header_"; it is not a real regimen and is
# collected into HEADER_VALUES so it can be recognised and ignored.
# Each group: (header label, header value, regimen names shown as-is).
_REGIMEN_GROUPS = (
    ("── HLA IDENTICAL ──", "__header_hla_identical__", (
        "Hsieh et al 2014",
        "Krishnamurti et al 2019",
        "King et al 2015",
        "Walters et al 1996",
    )),
    ("── HLA MISMATCHED ──", "__header_hla_mismatched__", (
        "Bolanos-Meade et al 2022 (HLA Mismatch)",
        "Patel et al 2020 (HLA Mismatch)",
    )),
    ("── MATCHED UNRELATED ──", "__header_matched_unrelated__", (
        "L Krishnamurti et al 2019",
        "Shenoy et al 2016",
    )),
    ("── MISMATCHED UNRELATED / CORD BLOOD ──", "__header_mismatched_cord__", (
        "Bolanos-Meade et al 2022 (Mismatched/Cord)",
        "Patel et al 2020 (Mismatched/Cord)",
    )),
    ("── CUSTOM ──", "__header_custom__", ("Custom",)),
)
GROUPED_REGIMEN_CHOICES = [
    choice
    for header_label, header_value, regimen_names in _REGIMEN_GROUPS
    for choice in [(header_label, header_value)] + [(name, name) for name in regimen_names]
]
HEADER_VALUES = {value for _, value in GROUPED_REGIMEN_CHOICES if value.startswith("__header_")}
|
|
# Field values auto-filled when a published regimen is picked in the
# "Published conditioning regimen" dropdown. Keys match the non-header values
# of GROUPED_REGIMEN_CHOICES (except "Custom"). Each row lists, in order:
#   CONDGRPF, CONDGRP_FINAL, ATGF, GVHD_FINAL, HLA_FINAL, DONORF
# which is also the key order of every resulting preset dict.
_PRESET_FIELDS = ("CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL", "HLA_FINAL", "DONORF")
PUBLISHED_PRESETS = {
    regimen: dict(zip(_PRESET_FIELDS, row))
    for regimen, row in (
        ("Hsieh et al 2014",
         ("NMA", "TBI alone (300/400/600cGy)", "Alemtuzumab", "Siro alone", "8/8", "HLA identical sibling")),
        ("Krishnamurti et al 2019",
         ("MAC", "Flu/Bu", "ATG", "CNI + MTX", "8/8", "HLA identical sibling")),
        ("King et al 2015",
         ("RIC", "Flu/Mel", "Alemtuzumab", "CNI + MTX", "8/8", "HLA identical sibling")),
        ("Walters et al 1996",
         ("MAC", "Bu/Cy", "ATG", "CNI + MTX", "8/8", "HLA identical sibling")),
        ("Bolanos-Meade et al 2022 (HLA Mismatch)",
         ("NMA", "TBI/Cy/Flu", "ATG", "Post-CY + siro +- MMF", "7/8", "HLA mismatch relative")),
        ("Patel et al 2020 (HLA Mismatch)",
         ("NMA", "TBI/Cy/Flu/TT", "ATG", "Post-CY + siro +- MMF", "7/8", "HLA mismatch relative")),
        ("L Krishnamurti et al 2019",
         ("MAC", "Flu/Bu", "ATG", "CNI + MTX", "8/8", "Matched unrelated donor")),
        ("Shenoy et al 2016",
         ("RIC", "Flu/Mel", "Alemtuzumab", "CNI + MTX", "8/8", "Matched unrelated donor")),
        ("Bolanos-Meade et al 2022 (Mismatched/Cord)",
         ("NMA", "TBI/Cy/Flu", "ATG", "Post-CY + siro +- MMF", "7/8", "Mismatched unrelated donor or cord blood")),
        ("Patel et al 2020 (Mismatched/Cord)",
         ("NMA", "TBI/Cy/Flu/TT", "ATG", "Post-CY + siro +- MMF", "7/8", "Mismatched unrelated donor or cord blood")),
    )
}
|
|
# Donor categories that drive HLA-matching rules: matched sibling (MSD) and
# matched unrelated (MUD) donors are fixed at 8/8, while the mismatched
# categories may not be 8/8 (enforced in the UI and in the counterfactual
# constraint checks).
MSD_DONOR = "HLA identical sibling"
MUD_DONOR = "Matched unrelated donor"
LOCKED_88_DONORS = {MSD_DONOR, MUD_DONOR}
NON_88_DONORS = {
    "HLA mismatch relative",
    "Mismatched unrelated donor or cord blood",
}

# Published regimens grouped by donor category; presumably consumed further
# down the file (e.g. by the scenario UI) — confirm usage before renaming.
MSD_REGIMENS = {
    "Hsieh et al 2014",
    "Krishnamurti et al 2019",
    "King et al 2015",
    "Walters et al 1996",
}
MMUD_REGIMENS = {
    "Bolanos-Meade et al 2022 (Mismatched/Cord)",
    "Patel et al 2020 (Mismatched/Cord)",
    "Bolanos-Meade et al 2022 (HLA Mismatch)",
    "Patel et al 2020 (HLA Mismatch)",
}
|
|
# Feature groups as laid out in the UI columns. ALL_FEATURES fixes the
# positional order in which widget values are passed to callbacks and paired
# with feature names (see _values_to_dict).
PATIENT_FEATURES = ["AGE", "AGEGPFF", "SEX", "KPS", "RCMVPR"]
DONOR_FEATURES = [
    "DONORF", "GRAFTYPE", "HLA_FINAL",
    "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL",
]
DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
ALL_FEATURES = [*PATIENT_FEATURES, *DONOR_FEATURES, *DISEASE_FEATURES]

# Outcomes rendered as icon-array columns, with their short card titles.
ICON_OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"]
OUTCOME_TITLES = {
    "DEAD": "Death",
    "GF": "Graft Failure",
    "AGVHD": "Acute GvHD",
    "CGVHD": "Chronic GvHD",
    "VOCPSHI": "Vaso-Occlusive Crisis",
    "STROKEHI": "Stroke Post-HCT",
}
# Icon colours: red figures = event, green figures = no event.
EVENT_COLOR = "#e53935"
NO_EVENT_COLOR = "#43a047"
# Display order of the per-outcome SHAP plots.
SHAP_ORDER = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "EFS", "STROKEHI", "OS"]

# One accent colour per counterfactual scenario (index = scenario number - 1).
SCENARIO_COLORS = ["#e65100", "#6a1b9a", "#1b5e20", "#0d47a1", "#b71c1c"]
|
|
|
|
| |
| |
| |
|
|
# Widget metadata for every model feature, consumed by make_component():
#   label   - base widget label (a scenario suffix may be appended)
#   type    - "number", "dropdown" or "radio"
#   choices - allowed values (dropdown / radio only)
#   kwargs  - extra Gradio constructor arguments; for choice widgets the
#             "interactive" and "info" entries are handled explicitly.
_FEATURE_META = {
    # --- Patient characteristics -------------------------------------------
    "AGE": {
        "label": "Age (years)",
        "type": "number",
        "kwargs": {"minimum": 0, "maximum": 100, "step": 1},
    },
    "AGEGPFF": {
        "label": "Age Group",
        "type": "dropdown",
        "choices": AGEGPFF_CHOICES,
        # Read-only; populated from AGE by a change handler.
        "kwargs": {"interactive": False, "info": "Auto-filled from Age"},
    },
    "SEX": {"label": "Sex", "type": "radio", "choices": SEX_CHOICES},
    "KPS": {"label": "Karnofsky Performance Status", "type": "radio", "choices": KPS_CHOICES},
    "RCMVPR": {"label": "Recipient CMV Status", "type": "radio", "choices": RCMVPR_CHOICES},
    # --- Donor / transplant characteristics --------------------------------
    "DONORF": {"label": "Donor Type", "type": "dropdown", "choices": DONORF_CHOICES},
    "GRAFTYPE": {"label": "Graft Type", "type": "radio", "choices": GRAFTYPE_CHOICES},
    "HLA_FINAL": {"label": "HLA Matching", "type": "radio", "choices": HLA_FINAL_CHOICES},
    "CONDGRPF": {"label": "Conditioning Intensity", "type": "radio", "choices": CONDGRPF_CHOICES},
    "CONDGRP_FINAL": {"label": "Conditioning Regimen", "type": "dropdown", "choices": CONDGRP_FINAL_CHOICES},
    "ATGF": {"label": "Serotherapy", "type": "radio", "choices": ATGF_CHOICES},
    "GVHD_FINAL": {"label": "GvHD Prophylaxis", "type": "dropdown", "choices": GVHD_FINAL_CHOICES},
    # --- Disease characteristics -------------------------------------------
    "NACS2YR": {
        "label": "Acute Chest Syndrome Episodes (2 yrs pre-HCT)",
        "type": "number",
        "kwargs": {"minimum": 0, "maximum": 50, "step": 1},
    },
    "EXCHTFPR": {"label": "Exchange Transfusion Pre-HCT", "type": "radio", "choices": EXCHTFPR_CHOICES},
    "VOC2YPR": {"label": "Vaso-Occlusive Crisis in 2 yrs Pre-HCT", "type": "radio", "choices": VOC2YPR_CHOICES},
    "VOCFRQPR": {"label": "VoC Frequency Pre-HCT", "type": "radio", "choices": VOCFRQPR_CHOICES},
    "SCATXRSN": {"label": "Primary Reason for HCT", "type": "dropdown", "choices": SCATXRSN_CHOICES},
}
|
|
|
|
def make_component(feature: str, suffix: str = ""):
    """Build the Gradio input widget for one model feature.

    Args:
        feature: key into _FEATURE_META.
        suffix: optional text appended (after a space) to the widget label,
            e.g. "(Scenario 1)", so duplicated widgets stay distinguishable.

    Returns:
        A gr.Number, gr.Dropdown or gr.Radio with no initial value.

    Raises:
        KeyError: if ``feature`` has no metadata entry.
        ValueError: if the metadata declares an unknown widget type.
    """
    meta = _FEATURE_META[feature]
    base_label = meta["label"]
    label = f"{base_label} {suffix}".strip() if suffix else base_label
    widget_type = meta["type"]
    extra = dict(meta.get("kwargs", {}))  # copy: pops below must not mutate the metadata

    if widget_type == "number":
        return gr.Number(label=label, value=None, **extra)

    widget_classes = {"dropdown": gr.Dropdown, "radio": gr.Radio}
    if widget_type not in widget_classes:
        raise ValueError(f"Unknown component type '{widget_type}' for feature '{feature}'")

    # Choice widgets get "interactive" (default True) and "info" as explicit
    # arguments; any other kwargs are forwarded untouched.
    interactive = extra.pop("interactive", True)
    info = extra.pop("info", None)
    return widget_classes[widget_type](
        label=label,
        choices=meta["choices"],
        value=None,
        interactive=interactive,
        info=info,
        **extra,
    )
|
|
|
|
| |
| |
| |
|
|
def _hla_update_for_donor(donor_value):
    """Return a Gradio update restricting the HLA widget to the donor type.

    * Donors in LOCKED_88_DONORS (MSD / MUD): only "8/8", selected and locked.
    * Donors in NON_88_DONORS: "7/8" or "≤ 6/8", selection cleared, editable.
    * Empty or unrecognised donor: all HLA choices, editable, current value
      left untouched (no ``value`` in the update).

    NOTE(review): this is wired to the donor widget's ``.change`` event, which
    Gradio also fires after programmatic updates (applying a published preset,
    copying a profile). For mismatched donors that clears an HLA value the
    preset/copy has just set — confirm whether that is intended.
    """
    if donor_value:
        if donor_value in LOCKED_88_DONORS:
            return gr.update(choices=["8/8"], value="8/8", interactive=False)
        if donor_value in NON_88_DONORS:
            return gr.update(choices=["7/8", "≤ 6/8"], value=None, interactive=True)
    return gr.update(choices=HLA_FINAL_CHOICES, interactive=True)
|
|
|
|
def _validate_counterfactual_constraints(base_dict, wi_dict, label=""):
    """Check a counterfactual ("what-if") profile against its baseline.

    Args:
        base_dict: baseline feature dict (feature name -> value).
        wi_dict: counterfactual feature dict with the same keys.
        label: optional scenario name; when non-empty every message is
            prefixed with ``"[label] "``.

    Returns:
        list[str]: human-readable violation messages, empty when all rules
        pass. Rules checked:
          * Sex must not differ from the baseline (only when both are set).
          * Age must not decrease (non-numeric ages are skipped).
          * MSD / MUD donors require 8/8 HLA; mismatched donors cannot be 8/8.
          * Post-Cy GvHD prophylaxis is rejected for an HLA-identical sibling.
    """
    violations = []
    # `tag` already ends with a space, so messages append directly after it.
    tag = f"[{label}] " if label else ""

    if base_dict.get("SEX") and wi_dict.get("SEX"):
        if base_dict["SEX"] != wi_dict["SEX"]:
            violations.append(f"{tag}Immutable feature: Sex cannot be changed.")

    base_age, wi_age = base_dict.get("AGE"), wi_dict.get("AGE")
    if base_age is not None and wi_age is not None:
        try:
            if float(wi_age) < float(base_age):
                violations.append(f"{tag}Age cannot be decreased ({base_age} → {wi_age}).")
        except (TypeError, ValueError):
            pass  # non-numeric ages cannot be compared; skip this rule

    wi_donor = wi_dict.get("DONORF")
    wi_hla = wi_dict.get("HLA_FINAL")
    if wi_donor and wi_hla:
        if wi_donor in LOCKED_88_DONORS and wi_hla != "8/8":
            violations.append(f"{tag}HLA constraint: '{wi_donor}' requires 8/8 HLA.")
        elif wi_donor in NON_88_DONORS and wi_hla == "8/8":
            violations.append(f"{tag}HLA constraint: '{wi_donor}' cannot have 8/8 HLA.")

    if wi_donor:
        wi_gvhd = wi_dict.get("GVHD_FINAL", "")
        if wi_donor == MSD_DONOR and wi_gvhd in {"Post-CY + siro +- MMF", "Post-CY + MMF + CNI"}:
            violations.append(
                f"{tag}Post-Cy GVHD prophylaxis is inconsistent with donor '{MSD_DONOR}'."
            )
    return violations
|
|
|
|
def _values_to_dict(values):
    """Pair positional widget values with feature names.

    Values are matched to ALL_FEATURES by position (``zip`` semantics: any
    surplus on either side is dropped). Features in NUM_COLS_SET are coerced
    to float, with blank or unparseable input becoming None; every other
    value is passed through unchanged.
    """
    def _as_float(raw):
        # The blank test sits inside the try so odd inputs that cannot be
        # compared also fall back to None.
        try:
            return float(raw) if raw not in (None, "") else None
        except (TypeError, ValueError):
            return None

    return {
        name: _as_float(raw) if name in NUM_COLS_SET else raw
        for name, raw in zip(ALL_FEATURES, values)
    }
|
|
|
|
| def _check_missing(user_vals, label=""): |
| missing = [f for f, v in user_vals.items() |
| if v is None or v == "" or (isinstance(v, float) and pd.isna(v))] |
| if missing: |
| raise ValueError( |
| f"{'[' + label + '] ' if label else ''}Please fill in all fields. " |
| f"Missing: {', '.join(missing)}" |
| ) |
|
|
|
|
def get_age_group(age):
    """Map an age in years to its AGEGPFF category label.

    Args:
        age: number or numeric string; None / "" mean "not entered".

    Returns:
        One of "<=10", "11-17", "18-29", "30-49", ">=50", or "" when the age
        is blank, not numeric, or NaN. Each band is tested with ``<=`` on its
        upper bound, so fractional ages fall into the next band up
        (e.g. 10.5 -> "11-17").
    """
    if age is None or age == "":
        return ""
    try:
        age = float(age)
    except (ValueError, TypeError):
        return ""
    # NaN fails every comparison below and would otherwise land in ">=50".
    if np.isnan(age):
        return ""
    if age <= 10:
        return "<=10"
    if age <= 17:
        return "11-17"
    if age <= 29:
        return "18-29"
    if age <= 49:
        return "30-49"
    return ">=50"
|
|
|
|
def vocfrqpr_from_voc2ypr(voc_status):
    """Derive the VOC-frequency widget state from the "VOC in 2 yrs" answer.

    "No" is treated as implying "< 3/yr": the frequency is filled in and
    locked. Any other answer (including an empty one) clears the frequency
    and makes it editable.
    """
    no_recent_voc = voc_status == "No"
    return gr.update(
        value="< 3/yr" if no_recent_voc else None,
        interactive=not no_recent_voc,
    )
|
|
|
|
def apply_grouped_preset(selected_value):
    """Return Gradio updates for the regimen dropdown and its linked fields.

    The 7-tuple targets, in order: the regimen dropdown itself, DONORF,
    CONDGRPF, CONDGRP_FINAL, ATGF, GVHD_FINAL, HLA_FINAL.

    * Empty selection or a group header: clear the dropdown, leave fields.
    * "Custom": make all six fields editable without changing their values.
    * A published regimen: fill the six fields from PUBLISHED_PRESETS and
      lock them; HLA choices are narrowed to those valid for the donor.
    * Anything else: no-op for every target.
    """
    field_count = 6
    if not selected_value or selected_value in HEADER_VALUES:
        return (gr.update(value=None),) + (gr.update(),) * field_count

    if selected_value == "Custom":
        return (gr.update(),) + tuple(gr.update(interactive=True) for _ in range(field_count))

    preset = PUBLISHED_PRESETS.get(selected_value)
    if not preset:
        return (gr.update(),) * (field_count + 1)

    donor = preset["DONORF"]
    if donor in LOCKED_88_DONORS:
        hla_choices = ["8/8"]
    elif donor in NON_88_DONORS:
        hla_choices = ["7/8", "≤ 6/8"]
    else:
        hla_choices = HLA_FINAL_CHOICES

    locked_fields = [
        gr.update(value=preset[key], interactive=False)
        for key in ("DONORF", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL")
    ]
    hla_update = gr.update(value=preset["HLA_FINAL"], interactive=False, choices=hla_choices)
    return (gr.update(), *locked_fields, hla_update)
|
|
|
|
def lock_sex(baseline_sex):
    """Return a Gradio update that makes a scenario's Sex widget read-only.

    Sex is immutable across counterfactual scenarios, so the widget is always
    disabled; when the baseline has a value it is copied in as well.
    """
    update_kwargs = {"interactive": False}
    if baseline_sex:
        update_kwargs = {"value": baseline_sex, "interactive": False}
    return gr.update(**update_kwargs)
|
|
|
|
| |
| |
| |
|
|
| def _stick_figure_svg(color, size=16): |
| h = round(size * 1.6) |
| return ( |
| f'<svg xmlns="http://www.w3.org/2000/svg" width="{size}" height="{h}" ' |
| f'viewBox="0 0 20 32" style="display:block;flex-shrink:0;" ' |
| f'stroke="{color}" stroke-width="2.2" stroke-linecap="round" fill="none">' |
| f'<circle cx="10" cy="5" r="3.8" fill="{color}" stroke="none"/>' |
| f'<line x1="10" y1="9" x2="10" y2="20"/>' |
| f'<line x1="3" y1="13" x2="17" y2="13"/>' |
| f'<line x1="10" y1="20" x2="4" y2="30"/>' |
| f'<line x1="10" y1="20" x2="16" y2="30"/>' |
| f'</svg>' |
| ) |
|
|
|
|
def _icon_card_html(probability, outcome, panel_label="", panel_color="#1565c0"):
    """Render one outcome as a 10x10 icon-array card (HTML string).

    Args:
        probability: event probability on a 0-1 scale; shown as a percentage
            with one decimal and as the number of red figures out of 100.
        outcome: outcome code; the title comes from OUTCOME_TITLES, then
            OUTCOME_DESCRIPTIONS, falling back to the code itself.
        panel_label: optional badge text shown above the title (HTML-escaped).
        panel_color: badge background colour.

    Raises:
        ValueError: if ``probability`` is NaN (it cannot be rounded to a count).
    """
    from html import escape

    title = OUTCOME_TITLES.get(outcome, OUTCOME_DESCRIPTIONS.get(outcome, outcome))
    # Red-figure count out of 100, clamped so a probability slightly outside
    # [0, 1] cannot produce a negative "No Event" count. int() normalises
    # numpy scalars, whose round() may return a float.
    n_event = min(100, max(0, int(round(probability * 100))))
    n_no_event = 100 - n_event
    pct_str = f"{probability * 100:.1f}%"

    # Grid is filled row by row: the first n_event figures are red.
    event_fig = _stick_figure_svg(EVENT_COLOR, size=13)
    no_event_fig = _stick_figure_svg(NO_EVENT_COLOR, size=13)
    rows_parts = []
    for row in range(10):
        cells = "".join(
            event_fig if row * 10 + col < n_event else no_event_fig
            for col in range(10)
        )
        rows_parts.append(
            f'<div style="display:flex;justify-content:center;gap:1px;margin-bottom:1px;">{cells}</div>'
        )
    grid_html = "\n".join(rows_parts)

    fig_e = _stick_figure_svg(EVENT_COLOR, size=11)
    fig_ne = _stick_figure_svg(NO_EVENT_COLOR, size=11)
    legend = (
        f'<div style="display:inline-grid;grid-template-columns:13px 1fr 36px;'
        f'align-items:center;gap:3px;row-gap:3px;">'
        f'{fig_e}<span style="color:{EVENT_COLOR};font-weight:700;font-size:9px;">Event</span>'
        f'<span style="color:#888;font-size:8px;">({n_event}/100)</span>'
        f'{fig_ne}<span style="color:{NO_EVENT_COLOR};font-weight:700;font-size:9px;">No Event</span>'
        f'<span style="color:#888;font-size:8px;">({n_no_event}/100)</span>'
        f'</div>'
    )
    badge = (
        f'<div style="background:{panel_color};color:#fff;font-size:8px;font-weight:700;'
        f'border-radius:3px;padding:1px 5px;margin-bottom:2px;display:inline-block;">'
        f'{escape(str(panel_label))}</div>'
    ) if panel_label else ""

    return (
        f'<div style="background:#fff;border:1px solid #e0e0e0;border-radius:7px;'
        f'padding:6px 5px;text-align:center;font-family:\'Segoe UI\',Arial,sans-serif;'
        f'box-shadow:0 2px 4px rgba(0,0,0,0.06);box-sizing:border-box;'
        f'display:flex;flex-direction:column;align-items:center;">'
        f'{badge}'
        f'<div style="min-height:26px;display:flex;align-items:center;justify-content:center;'
        f'font-size:10px;font-weight:700;color:#222;line-height:1.3;margin-bottom:1px;">{title}</div>'
        f'<div style="font-size:18px;font-weight:800;color:{EVENT_COLOR};'
        f'line-height:1;margin-bottom:3px;">{pct_str}</div>'
        f'<div style="margin-bottom:3px;">{grid_html}</div>'
        f'<div>{legend}</div>'
        f'</div>'
    )
|
|
|
|
def _build_comparison_icon_grid(all_probs_list, labels, colors):
    """Render icon-array cards for several profiles as one HTML grid.

    Args:
        all_probs_list: list of dicts mapping outcome code -> probability
            (0-1 scale), baseline first, then scenarios. Every code in
            ICON_OUTCOMES must be present.
        labels: display label per profile. HTML-escaped because scenario
            names can be typed by the user.
        colors: hex colour per profile, used for the row badge.

    Layout: one row per profile, one column per outcome in ICON_OUTCOMES.
    Rows stop at the shortest of the three lists (``zip`` semantics).

    Returns:
        HTML string with a header row, one card row per profile and a legend.
    """
    from html import escape

    outcome_headers = "".join(
        f'<div style="flex:1 1 0%;min-width:0;text-align:center;font-size:11px;'
        f'font-weight:700;color:#555;padding:4px 2px;">'
        f'{OUTCOME_TITLES.get(o, o)}</div>'
        for o in ICON_OUTCOMES
    )
    # Left padding reserves room for the row-label column (104px + 6px gap).
    header_row = (
        f'<div style="display:flex;gap:6px;margin-bottom:4px;padding-left:110px;">'
        f'{outcome_headers}</div>'
    )

    row_parts = [header_row]
    for probs, label, color in zip(all_probs_list, labels, colors):
        row_label = (
            f'<div style="width:104px;flex-shrink:0;display:flex;align-items:center;'
            f'justify-content:flex-end;padding-right:6px;">'
            f'<span style="background:{color};color:#fff;font-size:9px;font-weight:700;'
            f'border-radius:4px;padding:2px 6px;white-space:nowrap;overflow:hidden;'
            f'text-overflow:ellipsis;max-width:100px;">{escape(str(label))}</span></div>'
        )
        cards = "".join(
            f'<div style="flex:1 1 0%;min-width:0;">'
            f'{_icon_card_html(probs[o], o, "", color)}'
            f'</div>'
            for o in ICON_OUTCOMES
        )
        row_parts.append(
            f'<div style="display:flex;gap:6px;margin-bottom:8px;align-items:stretch;">'
            f'{row_label}{cards}</div>'
        )
    rows_html = "".join(row_parts)

    footnote = (
        f'<div style="font-size:10px;color:#888;text-align:center;margin-top:2px;">'
        f'Each figure = 1 patient out of 100. '
        f'<span style="color:{EVENT_COLOR};font-weight:600;">■ Red = Event</span> '
        f'<span style="color:{NO_EVENT_COLOR};font-weight:600;">■ Green = No Event</span>'
        f'</div>'
    )
    return f'<div style="font-family:\'Segoe UI\',Arial,sans-serif;padding:4px 0;">{rows_html}{footnote}</div>'
|
|
|
|
| def _delta_color_html(delta, is_survival): |
| if abs(delta) < 0.0005: |
| color = "#888888" |
| elif (delta > 0 and is_survival) or (delta < 0 and not is_survival): |
| color = "#2e7d32" |
| else: |
| color = "#c62828" |
| sign = "+" if delta >= 0 else "" |
| return f'<span style="color:{color};font-weight:700;">{sign}{delta*100:.1f}%</span>' |
|
|
|
|
def _build_comparison_table_html(base_probs, base_ci, scenario_probs_list, scenario_ci_list, scenario_labels, scenario_colors):
    """Render the baseline-vs-scenarios comparison table as an HTML string.

    Inverted layout: rows = Baseline + Scenarios, columns = Outcomes.

    Args:
        base_probs: outcome code -> baseline probability (0-1 scale).
        base_ci: outcome code -> (low, high) baseline interval; missing
            entries render as "nan%".
        scenario_probs_list: one outcome -> probability mapping per scenario.
        scenario_ci_list: one outcome -> (low, high) mapping per scenario.
        scenario_labels: display name per scenario (HTML-escaped, since
            scenario names can be user-typed).
        scenario_colors: hex colours, cycled by scenario index.

    Columns are the outcomes in REPORTING_OUTCOMES that appear in
    ``base_probs``. A scenario lacking one of them gets a "—" placeholder cell
    so its row stays aligned with the header.
    """
    from html import escape

    survival_outcomes = {"OS", "EFS"}
    nan_ci = (float("nan"), float("nan"))
    columns = [o for o in REPORTING_OUTCOMES if o in base_probs]

    outcome_headers = "".join(
        f"<th style='padding:8px 10px;text-align:center;border-bottom:2px solid #ccd;"
        f"color:#333;font-size:11px;'>{OUTCOME_DESCRIPTIONS.get(o, o)}</th>"
        for o in columns
    )
    header = (
        "<div style='overflow-x:auto;'>"
        "<table style='width:100%;border-collapse:collapse;"
        "font-family:\"Segoe UI\",Arial,sans-serif;font-size:12px;'>"
        "<thead><tr style='background:#f0f4f8;'>"
        "<th style='padding:9px 12px;text-align:left;border-bottom:2px solid #ccd;"
        "color:#333;min-width:120px;'>Scenario</th>"
        f"{outcome_headers}"
        "</tr></thead><tbody>"
    )

    # Baseline row: probability and interval only (no delta).
    baseline_cells = []
    for o in columns:
        bp = base_probs[o]
        blo, bhi = base_ci.get(o, nan_ci)
        baseline_cells.append(
            f"<td style='padding:7px 10px;text-align:center;'>"
            f"<div style='font-weight:700;color:#1565c0;font-size:13px;'>{bp*100:.1f}%</div>"
            f"<div style='color:#5c7fa8;font-size:9px;'>[{blo*100:.1f}%–{bhi*100:.1f}%]</div>"
            f"</td>"
        )
    baseline_html = "".join(baseline_cells)
    row_parts = [
        f"<tr style='background:#e8f0fb;'>"
        f"<td style='padding:8px 12px;font-weight:700;color:#1565c0;border-right:2px solid #ccd;'>"
        f"<div style='display:flex;align-items:center;gap:6px;'>"
        f"<span style='display:inline-block;width:10px;height:10px;border-radius:50%;"
        f"background:#1565c0;flex-shrink:0;'></span>Baseline</div></td>"
        f"{baseline_html}</tr>"
    ]

    # Scenario rows: probability, interval, and coloured delta vs baseline.
    missing_cell = "<td style='padding:7px 10px;text-align:center;color:#bbb;'>—</td>"
    for j, (sp_dict, sci_dict, s_label) in enumerate(
        zip(scenario_probs_list, scenario_ci_list, scenario_labels)
    ):
        sc_color = scenario_colors[j % len(scenario_colors)]
        bg = "#fafbfc" if j % 2 == 0 else "#ffffff"
        cells = []
        for o in columns:
            if o not in sp_dict:
                cells.append(missing_cell)
                continue
            wp = sp_dict[o]
            wlo, whi = sci_dict.get(o, nan_ci)
            delta = wp - base_probs[o]
            delta_html = _delta_color_html(delta, o in survival_outcomes)
            cells.append(
                f"<td style='padding:7px 10px;text-align:center;'>"
                f"<div style='font-weight:700;color:{sc_color};font-size:13px;'>{wp*100:.1f}%</div>"
                f"<div style='color:#888;font-size:9px;'>[{wlo*100:.1f}%–{whi*100:.1f}%]</div>"
                f"<div style='font-size:10px;margin-top:1px;'>{delta_html}</div>"
                f"</td>"
            )
        cells_html = "".join(cells)
        safe_label = escape(str(s_label))
        row_parts.append(
            f"<tr style='background:{bg};'>"
            f"<td style='padding:8px 12px;font-weight:600;border-right:2px solid #ccd;'>"
            f"<div style='display:flex;align-items:center;gap:6px;'>"
            f"<span style='display:inline-block;width:10px;height:10px;border-radius:50%;"
            f"background:{sc_color};flex-shrink:0;'></span>"
            f"<span style='color:{sc_color};'>{safe_label}</span></div></td>"
            f"{cells_html}</tr>"
        )

    footer = (
        "</tbody></table></div>"
        "<div style='font-size:10.5px;color:#888;margin-top:6px;'>"
        "Δ from Baseline: <span style='color:#2e7d32;font-weight:600;'>Green = improvement</span> "
        "<span style='color:#c62828;font-weight:600;'>Red = worsening</span> | "
        "OS & EFS: higher is better; all other outcomes: lower is better."
        "</div>"
    )
    return header + "".join(row_parts) + footer
|
|
|
|
| def _build_violation_html(violations): |
| if not violations: |
| return "" |
| items = "".join(f"<li style='margin-bottom:6px;'>{v}</li>" for v in violations) |
| return ( |
| f'<div style="background:#fff3e0;border:2px solid #e65100;border-radius:8px;' |
| f'padding:14px 18px;font-family:\'Segoe UI\',Arial,sans-serif;margin-bottom:12px;">' |
| f'<div style="font-weight:700;font-size:14px;color:#bf360c;margin-bottom:8px;">' |
| f'Constraint Violations — Analysis blocked</div>' |
| f'<ul style="margin:0;padding-left:20px;color:#6d1f00;font-size:13px;">{items}</ul>' |
| f'<div style="margin-top:10px;font-size:11px;color:#888;">' |
| f'Please correct the above before running the comparison.</div>' |
| f'</div>' |
| ) |
|
|
|
|
| |
| |
| |
|
|
def predict_gradio(*values):
    """Gradio callback for the "Predict" button.

    Args:
        *values: widget values in ALL_FEATURES order.

    Returns:
        A 10-tuple: the results DataFrame (Outcome / Probability / 95% CI, one
        row per REPORTING_OUTCOMES entry), the icon-array HTML grid, then one
        SHAP figure per outcome in SHAP_ORDER (DEAD, GF, AGVHD, CGVHD, VOCPSHI,
        EFS, STROKEHI, OS) — the order the click handler's outputs expect.

    Raises:
        gr.Error: wraps any failure, including missing inputs, so Gradio shows
            it in the UI; the full traceback is printed to stdout first.
    """
    try:
        user_vals = _values_to_dict(values)
        _check_missing(user_vals)

        calibrated, _ = predict_with_comparison(user_vals)
        calibrated_probs, calibrated_intervals = calibrated

        rows = []
        for outcome in REPORTING_OUTCOMES:
            ci_low, ci_high = calibrated_intervals[outcome]
            rows.append({
                "Outcome": OUTCOME_DESCRIPTIONS[outcome],
                "Probability": f"{calibrated_probs[outcome] * 100:.1f}%",
                "95% CI": f"[{ci_low * 100:.1f}% – {ci_high * 100:.1f}%]",
            })
        df = pd.DataFrame(rows)

        shap_plots = create_all_shap_plots(user_vals, max_display=10)
        icon_arrays = create_all_icon_arrays(calibrated_probs)

        return (
            df,
            icon_arrays["__grid__"],
            *(shap_plots[outcome] for outcome in SHAP_ORDER),
        )
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"{type(e).__name__}: {str(e)}") from e
|
|
|
|
| |
| |
| |
|
|
# Page CSS for gr.Blocks(css=...). Class names are attached to widgets via
# elem_classes (e.g. "predict-button", "copy-button", "output-dataframe").
# The .scenario-panel-N border rules are generated from SCENARIO_COLORS so
# they cannot drift from the colours used for scenario badges and tables.
custom_css = """
.predict-button {
    background: linear-gradient(to right, #ff6b35, #ff8c42) !important;
    border: none !important; color: white !important;
    font-weight: bold !important; font-size: 16px !important; padding: 12px !important;
}
.predict-button:hover { background: linear-gradient(to right, #ff5722, #ff7b29) !important; }
.copy-button {
    background: linear-gradient(to right, #388e3c, #66bb6a) !important;
    border: none !important; color: white !important; font-weight: 600 !important;
}
.copy-button:hover { background: linear-gradient(to right, #2e7d32, #43a047) !important; }
.copy-from-predict-button {
    background: linear-gradient(to right, #6a1b9a, #ab47bc) !important;
    border: none !important; color: white !important; font-weight: 600 !important;
}
.copy-from-predict-button:hover { background: linear-gradient(to right, #4a148c, #8e24aa) !important; }
.counterfactual-button {
    background: linear-gradient(to right, #1976d2, #42a5f5) !important;
    border: none !important; color: white !important;
    font-weight: bold !important; font-size: 15px !important; padding: 12px !important;
}
.counterfactual-button:hover { background: linear-gradient(to right, #1565c0, #1e88e5) !important; }
.output-dataframe table td:first-child,
.output-dataframe table th:first-child {
    white-space: normal !important; word-break: break-word !important; min-width: 240px !important;
}
.constraint-info {
    background: #e8f5e9; border-left: 4px solid #388e3c;
    padding: 8px 14px; font-size: 12px; color: #1b5e20;
    border-radius: 4px; margin-bottom: 8px;
}
""" + "".join(
    f".scenario-panel-{idx} {{ border-left: 4px solid {panel_color} !important; }}\n"
    for idx, panel_color in enumerate(SCENARIO_COLORS)
)
|
|
|
|
| |
| |
| |
|
|
| with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo: |
| gr.Markdown("# HCT Outcome Prediction Model") |
|
|
| |
| n_scenarios_state = gr.State(1) |
|
|
| with gr.Tabs(): |
|
|
| |
| |
| |
| with gr.Tab("Predict Outcomes"): |
| gr.Markdown("Enter patient, transplant, and disease characteristics to predict outcomes.") |
|
|
| inputs_dict = {} |
|
|
| with gr.Row(): |
| with gr.Column(scale=1): |
| gr.Markdown("### Patient Characteristics") |
| for f in PATIENT_FEATURES: |
| inputs_dict[f] = make_component(f) |
|
|
| with gr.Column(scale=1): |
| gr.Markdown("### Transplant Characteristics") |
| grouped_dd = gr.Dropdown( |
| choices=GROUPED_REGIMEN_CHOICES, value=None, |
| label="Published conditioning regimen", |
| info="Auto-fills Donor Type, Conditioning Intensity, Regimen, Serotherapy, GVHD Prophylaxis", |
| ) |
| p_donorf = inputs_dict["DONORF"] = make_component("DONORF") |
| inputs_dict["GRAFTYPE"] = make_component("GRAFTYPE") |
| p_condgrpf = inputs_dict["CONDGRPF"] = make_component("CONDGRPF") |
| p_condgrp_final = inputs_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL") |
| p_atgf = inputs_dict["ATGF"] = make_component("ATGF") |
| p_gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL") |
| p_hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL") |
|
|
| with gr.Column(scale=1): |
| gr.Markdown("### Disease Characteristics") |
| for f in DISEASE_FEATURES: |
| inputs_dict[f] = make_component(f) |
|
|
| inputs_dict["AGE"].change(get_age_group, inputs_dict["AGE"], inputs_dict["AGEGPFF"]) |
| inputs_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, inputs_dict["VOC2YPR"], inputs_dict["VOCFRQPR"]) |
| p_donorf.change(fn=_hla_update_for_donor, inputs=p_donorf, outputs=p_hla_final) |
| grouped_dd.change( |
| apply_grouped_preset, grouped_dd, |
| [grouped_dd, p_donorf, p_condgrpf, p_condgrp_final, p_atgf, p_gvhd_final, p_hla_final], |
| ) |
|
|
| inputs_list = [inputs_dict[f] for f in ALL_FEATURES] |
| predict_btn = gr.Button("Predict", elem_classes="predict-button", size="lg") |
|
|
| gr.Markdown("---") |
| gr.Markdown("## Prediction Results") |
| output_table = gr.Dataframe( |
| headers=["Outcome", "Probability", "95% CI"], |
| label="Predicted Outcomes", |
| elem_classes="output-dataframe", |
| row_count=(len(REPORTING_OUTCOMES), "dynamic"), |
| column_count=(3, "fixed"), |
| wrap=True, |
| ) |
|
|
| gr.Markdown("---") |
| gr.Markdown("## Outcome Probability — Icon Arrays") |
| icon_array_grid = gr.HTML() |
|
|
| gr.Markdown("---") |
| gr.Markdown("## SHAP — Feature Importance") |
| with gr.Row(): |
| shap_dead = gr.Plot(label="Death") |
| shap_gf = gr.Plot(label="Graft Failure") |
| shap_agvhd = gr.Plot(label="Acute GvHD") |
| shap_cgvhd = gr.Plot(label="Chronic GvHD") |
| with gr.Row(): |
| shap_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT") |
| shap_efs = gr.Plot(label="Event-Free Survival") |
| shap_stroke = gr.Plot(label="Stroke Post-HCT") |
| shap_os = gr.Plot(label="Overall Survival") |
|
|
| predict_btn.click( |
| fn=predict_gradio, |
| inputs=inputs_list, |
| outputs=[ |
| output_table, icon_array_grid, |
| shap_dead, shap_gf, shap_agvhd, shap_cgvhd, |
| shap_vocpshi, shap_efs, shap_stroke, shap_os, |
| ], |
| ) |
|
|
| |
| |
| |
| with gr.Tab("Counterfactual Scenarios"): |
| gr.Markdown( |
| "## Counterfactual Scenario Analysis\n" |
| "Enter the **baseline** patient, then choose how many counterfactual " |
| "scenarios you want to compare. Each scenario panel will appear below." |
| ) |
|
|
| |
| with gr.Row(): |
| n_scenarios_slider = gr.Slider( |
| minimum=1, maximum=MAX_SCENARIOS, step=1, value=1, |
| label=f"How many counterfactual scenarios do you want to compare? (1–{MAX_SCENARIOS})", |
| info="Adjust this first — scenario panels will appear/disappear below.", |
| ) |
|
|
| gr.Markdown("---") |
|
|
| |
| gr.Markdown("## Baseline Patient Profile") |
|
|
| with gr.Row(): |
| copy_from_predict_btn = gr.Button( |
| "Copy from Predict tab → Baseline", |
| elem_classes="copy-from-predict-button", |
| size="sm", |
| ) |
|
|
| wi_baseline_dict = {} |
|
|
| with gr.Row(): |
| with gr.Column(scale=1): |
| gr.Markdown("### Patient Characteristics") |
| for f in PATIENT_FEATURES: |
| wi_baseline_dict[f] = make_component(f) |
|
|
| with gr.Column(scale=1): |
| gr.Markdown("### Transplant Characteristics") |
| wi_grouped_base = gr.Dropdown( |
| choices=GROUPED_REGIMEN_CHOICES, value=None, |
| label="Published conditioning regimen (Baseline)", |
| ) |
| wb_donorf = wi_baseline_dict["DONORF"] = make_component("DONORF") |
| wi_baseline_dict["GRAFTYPE"] = make_component("GRAFTYPE") |
| wb_condgrpf = wi_baseline_dict["CONDGRPF"] = make_component("CONDGRPF") |
| wb_condgrp_final = wi_baseline_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL") |
| wb_atgf = wi_baseline_dict["ATGF"] = make_component("ATGF") |
| wb_gvhd_final = wi_baseline_dict["GVHD_FINAL"] = make_component("GVHD_FINAL") |
| wb_hla_final = wi_baseline_dict["HLA_FINAL"] = make_component("HLA_FINAL") |
|
|
| with gr.Column(scale=1): |
| gr.Markdown("### Disease Characteristics") |
| for f in DISEASE_FEATURES: |
| wi_baseline_dict[f] = make_component(f) |
|
|
| wi_baseline_dict["AGE"].change(get_age_group, wi_baseline_dict["AGE"], wi_baseline_dict["AGEGPFF"]) |
| wi_baseline_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, wi_baseline_dict["VOC2YPR"], wi_baseline_dict["VOCFRQPR"]) |
| wb_donorf.change(fn=_hla_update_for_donor, inputs=wb_donorf, outputs=wb_hla_final) |
| wi_grouped_base.change( |
| apply_grouped_preset, wi_grouped_base, |
| [wi_grouped_base, wb_donorf, wb_condgrpf, wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final], |
| ) |
|
|
| wi_baseline_list = [wi_baseline_dict[f] for f in ALL_FEATURES] |
|
|
| |
def _copy_predict_to_baseline(*vals):
    """Mirror the prediction form's inputs into the Baseline panel.

    ``vals`` carries one value per ``ALL_FEATURES`` entry, optionally followed by
    the selected published regimen (``None`` when absent).

    Returns a tuple in this order:
      * the feature values, unchanged;
      * an update that selects the same regimen in the Baseline regimen dropdown;
      * the six transplant-field updates that ``apply_grouped_preset`` produces
        for that regimen.

    The final six entries target DONORF, CONDGRPF, CONDGRP_FINAL, ATGF,
    GVHD_FINAL and HLA_FINAL, matching the click handler's outputs list.
    """
    n_features = len(ALL_FEATURES)
    copied_features = vals[:n_features]
    regimen = vals[n_features] if n_features < len(vals) else None

    # The preset's first element targets the regimen dropdown itself. It is
    # discarded in favour of an explicit update carrying the copied selection.
    (_preset_dropdown_upd, donorf, condgrpf, condgrp_final,
     atgf, gvhd_final, hla_final) = apply_grouped_preset(regimen)
    dropdown_upd = gr.update(value=regimen)

    return (
        *copied_features,
        dropdown_upd,
        donorf, condgrpf, condgrp_final, atgf, gvhd_final, hla_final,
    )
|
|
# Copy every input of the earlier form (inputs_list plus its grouped_dd regimen
# dropdown, both defined earlier in the file) into the Baseline panel.
# NOTE(review): if ALL_FEATURES includes the transplant keys (DONORF ... HLA_FINAL),
# those components appear twice in `outputs`: once inside wi_baseline_list and
# again as preset targets. The preset update then follows the copied value;
# confirm Gradio applies duplicate outputs in list order.
copy_from_predict_btn.click(
    fn=_copy_predict_to_baseline,
    inputs=inputs_list + [grouped_dd],
    outputs=wi_baseline_list + [wi_grouped_base, wb_donorf, wb_condgrpf,
                                wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
)

gr.Markdown("---")

# Per-scenario registries, filled by the scenario loop that follows. Index i of
# each list belongs to "Scenario i + 1".
scenario_dicts = []
scenario_lists = []
scenario_grouped_dds = []
scenario_rows = []
scenario_name_inputs = []

# Handles to individual per-scenario components.
# NOTE(review): nothing in the remainder of this file section reads these lists
# (only scenario_rows, scenario_name_inputs and scenario_lists are consumed
# later). They may be dead.
scenario_donorf_handles = []
scenario_hla_handles = []
scenario_voc2ypr_handles = []
scenario_age_handles = []
scenario_agegp_handles = []
scenario_copy_btns = []
|
|
# Build all MAX_SCENARIOS counterfactual panels up front. Only the first is
# visible initially; the n_scenarios_slider handlers show or hide the rest.
for s_idx in range(MAX_SCENARIOS):
    color = SCENARIO_COLORS[s_idx]
    label = f"Scenario {s_idx + 1}"
    suffix = f"({label})"
    visible_init = (s_idx == 0)

    with gr.Row(visible=visible_init) as s_row:
        scenario_rows.append(s_row)
        with gr.Column():
            # Colored banner identifying this scenario.
            gr.HTML(
                f'<div style="background:{color};color:#fff;font-weight:700;'
                f'font-size:15px;padding:8px 14px;border-radius:6px;margin-bottom:6px;">'
                f'Counterfactual {label}</div>'
            )
            with gr.Row():
                # Editable display name. run_counterfactual uses it as this
                # scenario's label in the results, falling back to "Scenario N" if blank.
                s_name_input = gr.Textbox(
                    label=f"Scenario name ({label})",
                    value=label,
                    placeholder=f"e.g. {label}",
                    scale=2,
                )
                scenario_name_inputs.append(s_name_input)

                copy_btn = gr.Button(
                    f"⬇ Copy Baseline → {label}",
                    elem_classes="copy-button",
                    size="sm",
                    scale=1,
                )
                scenario_copy_btns.append(copy_btn)

            # Same three-column layout as the baseline panel. The suffix keeps
            # component labels distinct per scenario.
            s_dict = {}
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown(f"#### Patient — {label}")
                    for f in PATIENT_FEATURES:
                        s_dict[f] = make_component(f, suffix)

                with gr.Column(scale=1):
                    gr.Markdown(f"#### Transplant — {label}")
                    s_grouped_dd = gr.Dropdown(
                        choices=GROUPED_REGIMEN_CHOICES, value=None,
                        label=f"Published regimen ({label})",
                    )
                    scenario_grouped_dds.append(s_grouped_dd)
                    s_donorf = s_dict["DONORF"] = make_component("DONORF", suffix)
                    s_dict["GRAFTYPE"] = make_component("GRAFTYPE", suffix)
                    s_condgrpf = s_dict["CONDGRPF"] = make_component("CONDGRPF", suffix)
                    s_condgrp_final = s_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL", suffix)
                    s_atgf = s_dict["ATGF"] = make_component("ATGF", suffix)
                    s_gvhd_final = s_dict["GVHD_FINAL"] = make_component("GVHD_FINAL", suffix)
                    s_hla_final = s_dict["HLA_FINAL"] = make_component("HLA_FINAL", suffix)

                with gr.Column(scale=1):
                    gr.Markdown(f"#### Disease — {label}")
                    for f in DISEASE_FEATURES:
                        s_dict[f] = make_component(f, suffix)

    # Register this scenario's components. scenario_lists holds them in
    # ALL_FEATURES order for positional handlers.
    scenario_dicts.append(s_dict)
    scenario_lists.append([s_dict[f] for f in ALL_FEATURES])
    scenario_donorf_handles.append(s_donorf)
    scenario_hla_handles.append(s_hla_final)
    scenario_voc2ypr_handles.append(s_dict["VOC2YPR"])
    scenario_age_handles.append(s_dict["AGE"])
    scenario_agegp_handles.append(s_dict["AGEGPFF"])

    # Same derived-field wiring as the baseline panel, scoped to this scenario.
    s_dict["AGE"].change(get_age_group, s_dict["AGE"], s_dict["AGEGPFF"])
    s_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, s_dict["VOC2YPR"], s_dict["VOCFRQPR"])
    s_donorf.change(fn=_hla_update_for_donor, inputs=s_donorf, outputs=s_hla_final)
    s_grouped_dd.change(
        apply_grouped_preset, s_grouped_dd,
        [s_grouped_dd, s_donorf, s_condgrpf, s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final],
    )

    # The scenario's SEX follows the baseline via lock_sex and is made read-only.
    # NOTE(review): assigning .interactive after construction relies on Gradio
    # reading the attribute when the app config is built. Passing
    # interactive=False through make_component would be more explicit.
    wi_baseline_dict["SEX"].change(
        fn=lock_sex,
        inputs=wi_baseline_dict["SEX"],
        outputs=s_dict["SEX"],
    )
    s_dict["SEX"].interactive = False

    # Copy handler: baseline features + baseline regimen -> this scenario.
    # NOTE(review): the seven *_ref parameters are never used inside _copy, so
    # the factory adds nothing over a plain function. _copy has the same logic
    # as _copy_predict_to_baseline.
    def _make_copy_fn(s_grouped_dd_ref, s_donorf_ref, s_condgrpf_ref,
                      s_condgrp_final_ref, s_atgf_ref, s_gvhd_final_ref,
                      s_hla_final_ref):
        def _copy(*vals):
            n = len(ALL_FEATURES)
            feature_vals = vals[:n]
            regimen_value = vals[n] if len(vals) > n else None
            preset_outputs = apply_grouped_preset(regimen_value)
            regimen_dd_upd = gr.update(value=regimen_value)
            d_upd, c_upd, cr_upd, a_upd, g_upd, h_upd = preset_outputs[1:]
            return (*list(feature_vals), regimen_dd_upd,
                    d_upd, c_upd, cr_upd, a_upd, g_upd, h_upd)
        return _copy

    # scenario_lists[-1] is the list appended for this iteration's scenario.
    copy_btn.click(
        fn=_make_copy_fn(s_grouped_dd, s_donorf, s_condgrpf,
                         s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final),
        inputs=wi_baseline_list + [wi_grouped_base],
        outputs=(scenario_lists[-1]
                 + [s_grouped_dd, s_donorf, s_condgrpf,
                    s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final]),
    )
    # Re-apply the SEX lock when copying as well.
    copy_btn.click(
        fn=lock_sex,
        inputs=wi_baseline_dict["SEX"],
        outputs=s_dict["SEX"],
    )
|
|
| |
def _update_scenario_visibility(n):
    """Return one visibility update per scenario panel: the first ``n`` shown, the rest hidden."""
    shown_count = int(n)
    return [gr.update(visible=idx < shown_count) for idx in range(MAX_SCENARIOS)]
|
|
# Show or hide the scenario panels to match the slider.
n_scenarios_slider.change(
    fn=_update_scenario_visibility,
    inputs=n_scenarios_slider,
    outputs=scenario_rows,
)
# Mirror the slider value into n_scenarios_state, which run_counterfactual
# receives as its first positional input.
n_scenarios_slider.change(
    fn=lambda n: int(n),
    inputs=n_scenarios_slider,
    outputs=n_scenarios_state,
)

gr.Markdown("---")
wi_run_btn = gr.Button(
    "Run Counterfactual Comparison",
    elem_classes="counterfactual-button",
    size="lg",
)

# Result panes. run_counterfactual returns their contents as the first three
# entries of its output tuple.
gr.Markdown("## Comparison Results")
wi_violation_html = gr.HTML()
gr.Markdown("### Outcome Probability Table")
wi_table_html = gr.HTML()

gr.Markdown("---")
# Collapsed by default.
with gr.Accordion("Outcome Icon Arrays — Baseline vs Scenarios", open=False):
    gr.Markdown("*Icon arrays show each outcome probability per 100 patients.*")
    wi_icon_html = gr.HTML()
|
gr.Markdown("---")
# Per-outcome SHAP plots for the baseline and for each scenario slot. They start
# empty and are filled by run_counterfactual.
with gr.Accordion("SHAP Feature Importance", open=False):
    gr.Markdown("### Baseline")
    with gr.Row():
        wi_shap_base = {o: gr.Plot(label=f"{o} — Baseline") for o in SHAP_ORDER}

    # One (container, {outcome: plot}) pair per scenario slot. The heading sits
    # INSIDE the visibility-toggled container. Previously it sat outside the
    # toggled Row, so hidden scenarios still showed orphan "### Scenario N" headings.
    wi_shap_scenarios = []
    for s_idx in range(MAX_SCENARIOS):
        with gr.Column(visible=(s_idx == 0)) as shap_row:
            gr.Markdown(f"### Scenario {s_idx + 1}")
            with gr.Row():
                shap_plots_s = {o: gr.Plot(label=f"{o} — Scenario {s_idx + 1}") for o in SHAP_ORDER}
        wi_shap_scenarios.append((shap_row, shap_plots_s))


def _update_shap_visibility(n):
    """Return one visibility update per scenario SHAP container: the first ``n`` shown."""
    return [gr.update(visible=(i < int(n))) for i in range(MAX_SCENARIOS)]


# Keep the SHAP containers in step with the scenario-count slider.
n_scenarios_slider.change(
    fn=_update_shap_visibility,
    inputs=n_scenarios_slider,
    outputs=[row for row, _ in wi_shap_scenarios],
)
|
|
| |
def run_counterfactual(*all_values):
    """Validate and score the baseline and every active counterfactual scenario.

    ``all_values`` arrives positionally, in the order built by ``all_run_inputs``:

    * ``all_values[0]``: number of active scenarios (``n_scenarios_state``);
    * the next ``len(ALL_FEATURES)`` values: baseline features;
    * the next ``MAX_SCENARIOS`` values: scenario display names;
    * the remaining ``MAX_SCENARIOS * len(ALL_FEATURES)`` values: scenario
      features, one ``ALL_FEATURES``-ordered block per scenario slot.

    Returns:
        A flat tuple matching ``all_run_outputs``, in this order:
        violation HTML, comparison-table HTML, icon-array HTML, the baseline's
        SHAP plots in ``SHAP_ORDER``, then each scenario slot's SHAP plots in
        ``SHAP_ORDER``. Inactive slots get ``None`` plots. If any scenario
        violates a counterfactual constraint, only the violation HTML is filled
        (table and icons empty, plots cleared) and no prediction is run.

    Raises:
        gr.Error: a ``gr.Error`` raised by a helper propagates unchanged. Any
            other exception is printed with its traceback and re-raised as
            ``gr.Error`` carrying the original type name and message.
    """
    try:
        n_feat = len(ALL_FEATURES)
        n_shap = len(SHAP_ORDER)
        # Clamp so an out-of-range count can never index past the scenario blocks.
        n_scen = max(0, min(int(all_values[0]), MAX_SCENARIOS))
        base_vals = all_values[1 : 1 + n_feat]

        # Scenario names follow the baseline features.
        name_offset = 1 + n_feat
        scenario_names_raw = all_values[name_offset : name_offset + MAX_SCENARIOS]

        # Then one feature block per scenario slot. All slots are always sent.
        feat_offset = name_offset + MAX_SCENARIOS
        scenario_val_blocks = [
            all_values[feat_offset + s * n_feat : feat_offset + (s + 1) * n_feat]
            for s in range(MAX_SCENARIOS)
        ]

        # None clears a gr.Plot. The lists are sized from SHAP_ORDER (not a
        # hard-coded count) so the tuple always has one entry per plot output.
        base_shap_outputs = [None] * n_shap
        scene_shap_outputs = [[None] * n_shap for _ in range(MAX_SCENARIOS)]

        base_dict = _values_to_dict(base_vals)
        _check_missing(base_dict, "Baseline")

        # Validate every active scenario against the baseline before predicting.
        all_violations = []
        scenario_dicts_active = []
        for s in range(n_scen):
            sd = _values_to_dict(scenario_val_blocks[s])
            _check_missing(sd, f"Scenario {s+1}")
            v = _validate_counterfactual_constraints(base_dict, sd, f"Scenario {s+1}")
            all_violations.extend(v)
            scenario_dicts_active.append(sd)

        if all_violations:
            return (
                _build_violation_html(all_violations), "", "",
                *base_shap_outputs,
                *[p for plots in scene_shap_outputs for p in plots],
            )

        predict_kwargs = {"use_calibration": True, "use_signed_voting": True, "n_boot_ci": 500}
        base_probs, base_ci = predict_all_outcomes(base_dict, **predict_kwargs)

        scen_probs_list = []
        scen_ci_list = []
        for sd in scenario_dicts_active:
            sp, sci = predict_all_outcomes(sd, **predict_kwargs)
            scen_probs_list.append(sp)
            scen_ci_list.append(sci)

        # Blank or missing names fall back to the default "Scenario N" label.
        scen_labels = []
        for i in range(n_scen):
            raw_name = scenario_names_raw[i] if i < len(scenario_names_raw) else ""
            name = str(raw_name).strip() if raw_name else ""
            scen_labels.append(name if name else f"Scenario {i+1}")

        # list() so the concatenation below also works if SCENARIO_COLORS is a tuple.
        scen_colors = list(SCENARIO_COLORS[:n_scen])

        table_html = _build_comparison_table_html(
            base_probs, base_ci, scen_probs_list, scen_ci_list, scen_labels, scen_colors
        )
        icon_html = _build_comparison_icon_grid(
            [base_probs] + scen_probs_list,
            ["Baseline"] + scen_labels,
            ["#1565c0"] + scen_colors,
        )

        base_shap_plots = create_all_shap_plots(base_dict, max_display=10)
        base_shap_outputs = [base_shap_plots[o] for o in SHAP_ORDER]

        for s, sd in enumerate(scenario_dicts_active):
            sp_plots = create_all_shap_plots(sd, max_display=10)
            scene_shap_outputs[s] = [sp_plots[o] for o in SHAP_ORDER]

        return (
            "",
            table_html,
            icon_html,
            *base_shap_outputs,
            *[p for plots in scene_shap_outputs for p in plots],
        )

    except gr.Error:
        # Already a user-facing message; re-wrapping would prefix it with "Error: ".
        raise
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"{type(e).__name__}: {str(e)}") from e
|
|
| |
# Output order must match the tuples returned by run_counterfactual:
#   1. the three HTML panes;
#   2. the baseline SHAP plots in SHAP_ORDER;
#   3. each scenario slot's SHAP plots in SHAP_ORDER.
all_run_outputs = (
    [wi_violation_html, wi_table_html, wi_icon_html]
    + [wi_shap_base[o] for o in SHAP_ORDER]
    + [shap_plots_s[o] for _, shap_plots_s in wi_shap_scenarios for o in SHAP_ORDER]
)

# Input order must match run_counterfactual's positional slicing:
#   1. scenario count;
#   2. baseline features (ALL_FEATURES order);
#   3. all MAX_SCENARIOS scenario names;
#   4. every scenario's feature block, one after another.
# All MAX_SCENARIOS slots are always sent; run_counterfactual only reads the
# first n of them.
all_run_inputs = (
    [n_scenarios_state]
    + wi_baseline_list
    + scenario_name_inputs
    + [feat for s_list in scenario_lists for feat in s_list]
)

wi_run_btn.click(
    fn=run_counterfactual,
    inputs=all_run_inputs,
    outputs=all_run_outputs,
)
|
|
|
|
| |
if __name__ == "__main__":
    # Launch the app with server-side rendering explicitly switched off.
    launch_options = {"ssr_mode": False}
    demo.launch(**launch_options)