'
f'{fig_e}Event'
f'({n_event}/100)'
f'{fig_ne}No Event'
f'({n_no_event}/100)'
f'
'
)
badge = (
f'
'
f'{panel_label}
'
) if panel_label else ""
return (
f'
'
f'{badge}'
f'
{title}
'
f'
{pct_str}
'
f'
{grid_html}
'
f'
{legend}
'
f'
'
)
def _render_icon_carousel_page(probs_dict, idx):
"""Render the full carousel HTML for a given outcome index (no JS).

Args:
    probs_dict: Mapping of outcome key -> probability (the dict stored in the
        icon-carousel State). An outcome absent from the mapping is drawn
        with probability 0.0.
    idx: Position in ``ICON_OUTCOMES`` of the outcome to display. Not
        bounds-checked here; the prev/next callbacks clamp it before calling.

Returns:
    HTML string containing the outcome title, an "idx+1 / total" page
    counter, one breadcrumb dot per outcome, the icon-array card from
    ``_render_single_icon_card`` and the red/green legend footnote.
"""
outcome = ICON_OUTCOMES[idx]
total = len(ICON_OUTCOMES)
card_html = _render_single_icon_card(probs_dict.get(outcome, 0.0), outcome)
# Human-readable title; falls back to the raw outcome key when no title exists.
label = OUTCOME_TITLES.get(outcome, outcome)
# Breadcrumb dots
# NOTE(review): the dot f-string below is empty in this view, so `active`
# (true only for the current page) is never used -- confirm the dot markup
# was not lost.
dots = ""
for i, o in enumerate(ICON_OUTCOMES):
active = i == idx
dots += (
f''
)
# Legend shown under every page: 1 figure = 1 patient out of 100,
# red = event, green = no event.
footnote = (
f'
'
f'Each figure = 1 patient out of 100. '
f'■ Red = Event '
f'■ Green = No Event'
f'
'
)
# Assemble title, page counter, dots, card and footnote into one fragment.
return (
f'
'
f'
{label}
'
f'
{idx+1} / {total}
'
f'
{dots}
'
f'
{card_html}
'
f'{footnote}'
f'
'
)
def _render_comparison_icon_page(all_probs_list, labels, colors, idx):
"""Render comparison icon carousel page for a given outcome index (no JS).

Shows one labelled icon-array card per scenario, all for the same outcome.

Args:
    all_probs_list: Sequence of per-scenario dicts mapping outcome key ->
        probability; an outcome missing from a dict is drawn as 0.0.
    labels: Per-scenario display labels, parallel to ``all_probs_list``.
    colors: Per-scenario colours, parallel to ``all_probs_list``; passed to
        ``_render_single_icon_card`` together with the label.
    idx: Position in ``ICON_OUTCOMES`` of the outcome to display; not
        bounds-checked here.

Returns:
    HTML string: outcome title, "idx+1 / total" counter, breadcrumb dots,
    one block per scenario (label + card), then the legend footnote.
"""
outcome = ICON_OUTCOMES[idx]
total = len(ICON_OUTCOMES)
out_label = OUTCOME_TITLES.get(outcome, outcome)
# Dots breadcrumb
# NOTE(review): the dot f-string is empty in this view, so `active` is
# never used -- confirm the dot markup was not lost.
dots = ""
for i in range(total):
active = i == idx
dots += (
f''
)
# One card per scenario.
# NOTE(review): zip() stops at the shortest of the three sequences, so if
# fewer colours than scenarios are passed the extra scenarios are silently
# omitted (the comparison table instead cycles colours with modulo).
rows_html = ""
for probs, label, color in zip(all_probs_list, labels, colors):
prob = probs.get(outcome, 0.0)
card = _render_single_icon_card(prob, outcome, label, color)
rows_html += (
f'
'
f'
{label}
'
f'
{card}
'
f'
'
)
# Same legend text as the single-patient carousel page.
footnote = (
f'
'
f'Each figure = 1 patient out of 100. '
f'■ Red = Event '
f'■ Green = No Event'
f'
'
)
return (
f'
'
f'
{out_label}
'
f'
{idx+1} / {total}
'
f'
{dots}
'
f'{rows_html}'
f'{footnote}'
f'
'
)
# ─────────────────────────────────────────────────────────────────────────────
# OTHER HTML RENDERERS
# ─────────────────────────────────────────────────────────────────────────────
def _delta_color_html(delta, is_survival):
if abs(delta) < 0.0005:
color = "#888888"
elif (delta > 0 and is_survival) or (delta < 0 and not is_survival):
color = "#2e7d32"
else:
color = "#c62828"
sign = "+" if delta >= 0 else ""
return f'{sign}{delta*100:.1f}%'
def _build_comparison_table_html(base_probs, base_ci, scenario_probs_list, scenario_ci_list, scenario_labels, scenario_colors):
"""Build the HTML table comparing the baseline with each counterfactual scenario.

Columns are the outcomes of ``REPORTING_OUTCOMES`` present in ``base_probs``.
The first row is the baseline (probability and interval per outcome); each
following row is a scenario showing its probability, interval and the
change from baseline formatted by ``_delta_color_html``.

Args:
    base_probs: Dict outcome -> baseline probability (fraction; shown x100).
    base_ci: Dict outcome -> (low, high) baseline interval; a missing
        outcome is rendered from (nan, nan).
    scenario_probs_list: One dict outcome -> probability per scenario.
    scenario_ci_list: One dict outcome -> (low, high) per scenario.
    scenario_labels: Row label per scenario.
    scenario_colors: Colours cycled across scenarios by index (modulo).
        Must be non-empty when there is at least one scenario, otherwise
        the modulo raises ZeroDivisionError.

Returns:
    The HTML string ``header + rows + footer``.
"""
# Outcomes for which a higher probability is the better result.
survival_outcomes = {"OS", "EFS"}
# One header cell per reported outcome that the baseline actually has.
outcome_headers = "".join(
f"
{OUTCOME_DESCRIPTIONS.get(o, o)}
"
for o in REPORTING_OUTCOMES if o in base_probs
)
header = (
"
"
"
"
"
"
"
Scenario
"
f"{outcome_headers}"
"
"
)
# Baseline row: probability plus interval for each reported outcome.
rows = ""
baseline_cells = ""
for o in REPORTING_OUTCOMES:
if o not in base_probs:
continue
bp = base_probs[o]
blo, bhi = base_ci.get(o, (float("nan"), float("nan")))
baseline_cells += (
f"
"
f"
{bp*100:.1f}%
"
f"
[{blo*100:.1f}%–{bhi*100:.1f}%]
"
f"
"
)
rows += (
f"
"
f"
"
f"
"
f"Baseline
"
f"{baseline_cells}
"
)
# One row per scenario; zip() stops at the shortest of the three lists.
# NOTE(review): `sc_color` and `bg` (alternating row shade) do not appear in
# the strings as shown here -- confirm the row markup that uses them.
for j, (sp_dict, sci_dict, s_label) in enumerate(zip(scenario_probs_list, scenario_ci_list, scenario_labels)):
sc_color = scenario_colors[j % len(scenario_colors)]
bg = "#fafbfc" if j % 2 == 0 else "#ffffff"
scenario_cells = ""
# NOTE(review): outcomes missing from this scenario's dict are skipped,
# so the row can have fewer cells than the header has columns.
for o in REPORTING_OUTCOMES:
if o not in base_probs or o not in sp_dict:
continue
bp = base_probs[o]
wp = sp_dict[o]
wlo, whi = sci_dict.get(o, (float("nan"), float("nan")))
# Positive delta means the scenario probability exceeds the baseline.
delta = wp - bp
is_surv = o in survival_outcomes
scenario_cells += (
f"
"
f"
{wp*100:.1f}%
"
f"
[{wlo*100:.1f}%–{whi*100:.1f}%]
"
f"
{_delta_color_html(delta, is_surv)}
"
f"
"
)
rows += (
f"
"
f"
"
f"
"
f""
f"{s_label}
"
f"{scenario_cells}
"
)
# Legend explaining the delta colours and which direction is better.
footer = (
"
"
"
"
"Δ from Baseline: Green = improvement "
"Red = worsening | "
"OS & EFS: higher is better; all other outcomes: lower is better."
"
"
)
return header + rows + footer
def _build_violation_html(violations):
"""Render constraint-violation messages as an HTML warning block.

Args:
    violations: Sequence of message strings; each is interpolated into the
        HTML as-is.

Returns:
    "" when ``violations`` is empty (or otherwise falsy); otherwise a block
    titled "Constraint Violations — Analysis blocked" that lists every
    message and asks the user to correct them before running the comparison.
"""
if not violations:
return ""
# NOTE(review): messages are not HTML-escaped; if they can ever contain
# user-typed text, escape them (e.g. html.escape) before interpolation.
items = "".join(f"
{v}
" for v in violations)
return (
f'
'
f'
'
f'Constraint Violations — Analysis blocked
'
f'
{items}
'
f'
'
f'Please correct the above before running the comparison.
'
f'
'
)
# ─────────────────────────────────────────────────────────────────────────────
# SHAP CAROUSEL HELPERS
# ─────────────────────────────────────────────────────────────────────────────
def _shap_counter_html(idx):
"""Return the SHAP-carousel breadcrumb HTML.

Lists the display label of every SHAP plot, in ``SHAP_ORDER`` order
(falling back to the raw key when ``SHAP_LABELS`` has no entry), joined
with " · ".

Args:
    idx: Index into ``SHAP_ORDER`` of the plot currently shown.
"""
# NOTE(review): `idx` and the enumerate index `i` are not used by the strings
# as shown here, so the current plot is not visibly highlighted -- confirm
# whether highlighting markup was lost.
labels = [SHAP_LABELS.get(o, o) for o in SHAP_ORDER]
items = " · ".join(
f'{l}'
for i, l in enumerate(labels)
)
return f'
{items}
'
# ─────────────────────────────────────────────────────────────────────────────
# MAIN PREDICT CALLBACK (Tab 1)
# ─────────────────────────────────────────────────────────────────────────────
def predict_gradio(*values):
    """Gradio callback behind the Tab 1 "Predict" button.

    Maps the positional input-component values to a feature dict, rejects the
    request if any field is blank, runs the calibrated model and builds every
    output wired to ``predict_btn.click``.

    Args:
        *values: Component values, in the order of ``inputs_list``.

    Returns:
        An 8-tuple, in order: the results DataFrame (Outcome / Probability /
        95% CI), the calibrated-probability dict (icon-carousel State), icon
        index 0, the first icon-carousel page HTML, the dict of SHAP plots
        (State), SHAP index 0, the first SHAP plot, and the SHAP breadcrumb.

    Raises:
        gr.Error: wrapping any exception raised along the way (including the
            blank-field ValueError); the traceback is printed first.
    """

    def _is_blank(value):
        # None, "" and float NaN all count as a field that was not filled in.
        if value is None or value == "":
            return True
        return isinstance(value, float) and pd.isna(value)

    try:
        user_vals = _values_to_dict(values)
        missing = [name for name, value in user_vals.items() if _is_blank(value)]
        if missing:
            raise ValueError(f"Please fill in all fields. Missing: {', '.join(missing)}")

        # Only the calibrated (probs, intervals) pair is used on this tab.
        (probs, intervals), _ = predict_with_comparison(user_vals)

        def _format_row(outcome):
            # One display row: description, point estimate and interval in %.
            description = OUTCOME_DESCRIPTIONS[outcome]
            point = probs[outcome]
            low, high = intervals[outcome]
            return {
                "Outcome": description,
                "Probability": f"{point * 100:.1f}%",
                "95% CI": f"[{low * 100:.1f}% – {high * 100:.1f}%]",
            }

        results_df = pd.DataFrame([_format_row(o) for o in REPORTING_OUTCOMES])
        shap_plots = create_all_shap_plots(user_vals, max_display=10)

        first = 0  # both carousels restart on their first page
        return (
            results_df,
            probs,  # icon-carousel State
            first,  # icon index State
            _render_icon_carousel_page(probs, first),
            shap_plots,  # SHAP plots State
            first,  # SHAP index State
            shap_plots[SHAP_ORDER[first]],
            _shap_counter_html(first),
        )
    except Exception as exc:
        print(traceback.format_exc())
        raise gr.Error(f"{type(exc).__name__}: {str(exc)}")
# ─────────────────────────────────────────────────────────────────────────────
# CSS
# ─────────────────────────────────────────────────────────────────────────────
# Stylesheet passed to gr.Blocks(css=...). Rules target classes attached via
# `elem_classes` (e.g. predict-button, output-dataframe, icon-nav-button,
# shap-nav-button, copy-from-predict-button): gradient action/navigation
# buttons, a wrapping wide first column for the results table, a green
# constraint-info banner, and a coloured left border per scenario panel
# (scenario-panel-0..4). copy-button, counterfactual-button, constraint-info
# and scenario-panel-N are presumably applied to components outside this
# view -- confirm before removing any of them.
custom_css = """
.predict-button {
background: linear-gradient(to right, #ff6b35, #ff8c42) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 16px !important; padding: 12px !important;
}
.predict-button:hover { background: linear-gradient(to right, #ff5722, #ff7b29) !important; }
.copy-button {
background: linear-gradient(to right, #388e3c, #66bb6a) !important;
border: none !important; color: white !important; font-weight: 600 !important;
}
.copy-button:hover { background: linear-gradient(to right, #2e7d32, #43a047) !important; }
.copy-from-predict-button {
background: linear-gradient(to right, #6a1b9a, #ab47bc) !important;
border: none !important; color: white !important; font-weight: 600 !important;
}
.copy-from-predict-button:hover { background: linear-gradient(to right, #4a148c, #8e24aa) !important; }
.counterfactual-button {
background: linear-gradient(to right, #1976d2, #42a5f5) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 15px !important; padding: 12px !important;
}
.counterfactual-button:hover { background: linear-gradient(to right, #1565c0, #1e88e5) !important; }
.shap-nav-button {
background: linear-gradient(to right, #37474f, #546e7a) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 18px !important;
padding: 8px 22px !important; border-radius: 6px !important;
}
.shap-nav-button:hover { background: linear-gradient(to right, #263238, #455a64) !important; }
.icon-nav-button {
background: linear-gradient(to right, #1565c0, #1e88e5) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 18px !important;
padding: 8px 22px !important; border-radius: 6px !important;
}
.icon-nav-button:hover { background: linear-gradient(to right, #0d47a1, #1565c0) !important; }
.output-dataframe table td:first-child,
.output-dataframe table th:first-child {
white-space: normal !important; word-break: break-word !important; min-width: 240px !important;
}
.constraint-info {
background: #e8f5e9; border-left: 4px solid #388e3c;
padding: 8px 14px; font-size: 12px; color: #1b5e20;
border-radius: 4px; margin-bottom: 8px;
}
.scenario-panel-0 { border-left: 4px solid #e65100 !important; }
.scenario-panel-1 { border-left: 4px solid #6a1b9a !important; }
.scenario-panel-2 { border-left: 4px solid #1b5e20 !important; }
.scenario-panel-3 { border-left: 4px solid #0d47a1 !important; }
.scenario-panel-4 { border-left: 4px solid #b71c1c !important; }
"""
# ─────────────────────────────────────────────────────────────────────────────
# BUILD UI
# ─────────────────────────────────────────────────────────────────────────────
with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
gr.Markdown("# HCT Outcome Prediction Model")
n_scenarios_state = gr.State(1)
with gr.Tabs():
# ══════════════════════════════════════════════════════════════════════
# TAB 1 — PREDICT OUTCOMES
# ══════════════════════════════════════════════════════════════════════
with gr.Tab("Predict Outcomes"):
gr.Markdown("Enter patient, transplant, and disease characteristics to predict outcomes.")
inputs_dict = {}
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Patient Characteristics")
for f in PATIENT_FEATURES:
inputs_dict[f] = make_component(f)
with gr.Column(scale=1):
gr.Markdown("### Transplant Characteristics")
grouped_dd = gr.Dropdown(
choices=GROUPED_REGIMEN_CHOICES, value=None,
label="Published conditioning regimen",
info="Auto-fills Donor Type, Conditioning Intensity, Regimen, Serotherapy, GVHD Prophylaxis",
)
p_donorf = inputs_dict["DONORF"] = make_component("DONORF")
inputs_dict["GRAFTYPE"] = make_component("GRAFTYPE")
p_condgrpf = inputs_dict["CONDGRPF"] = make_component("CONDGRPF")
p_condgrp_final = inputs_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL")
p_atgf = inputs_dict["ATGF"] = make_component("ATGF")
p_gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
p_hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL")
with gr.Column(scale=1):
gr.Markdown("### Disease Characteristics")
for f in DISEASE_FEATURES:
inputs_dict[f] = make_component(f)
inputs_dict["AGE"].change(get_age_group, inputs_dict["AGE"], inputs_dict["AGEGPFF"])
inputs_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, inputs_dict["VOC2YPR"], inputs_dict["VOCFRQPR"])
p_donorf.change(fn=_hla_update_for_donor, inputs=p_donorf, outputs=p_hla_final)
grouped_dd.change(
apply_grouped_preset, grouped_dd,
[grouped_dd, p_donorf, p_condgrpf, p_condgrp_final, p_atgf, p_gvhd_final, p_hla_final],
)
inputs_list = [inputs_dict[f] for f in ALL_FEATURES]
predict_btn = gr.Button("Predict", elem_classes="predict-button", size="lg")
gr.Markdown("---")
gr.Markdown("## Prediction Results")
output_table = gr.Dataframe(
headers=["Outcome", "Probability", "95% CI"],
label="Predicted Outcomes",
elem_classes="output-dataframe",
row_count=(len(REPORTING_OUTCOMES), "dynamic"),
column_count=(3, "fixed"),
wrap=True,
)
# ── Icon Array Carousel (Tab 1) — Python-driven ───────────────────
gr.Markdown("---")
gr.Markdown("## Outcome Probability — Icon Arrays")
gr.Markdown("*Use the ← → arrows to browse outcomes one at a time.*")
# State: store probs dict + current index
icon_probs_state = gr.State(None)
icon_idx_state = gr.State(0)
with gr.Row():
icon_prev_btn = gr.Button("←", elem_classes="icon-nav-button", size="sm", scale=0)
icon_display = gr.HTML(scale=4)
icon_next_btn = gr.Button("→", elem_classes="icon-nav-button", size="sm", scale=0)
def _on_icon_prev(idx, probs):
if probs is None:
return idx, ""
new_idx = max(0, idx - 1)
return new_idx, _render_icon_carousel_page(probs, new_idx)
def _on_icon_next(idx, probs):
if probs is None:
return idx, ""
new_idx = min(len(ICON_OUTCOMES) - 1, idx + 1)
return new_idx, _render_icon_carousel_page(probs, new_idx)
icon_prev_btn.click(
fn=_on_icon_prev,
inputs=[icon_idx_state, icon_probs_state],
outputs=[icon_idx_state, icon_display],
)
icon_next_btn.click(
fn=_on_icon_next,
inputs=[icon_idx_state, icon_probs_state],
outputs=[icon_idx_state, icon_display],
)
# ── SHAP Carousel (Tab 1) ─────────────────────────────────────────
gr.Markdown("---")
gr.Markdown("## SHAP — Feature Importance")
gr.Markdown("*Use the ← → buttons to browse SHAP plots one at a time.*")
shap_plots_state = gr.State(None)
shap_idx_state = gr.State(0)
with gr.Row():
shap_prev_btn = gr.Button("←", elem_classes="shap-nav-button", size="sm", scale=0)
shap_crumb = gr.HTML(value="", scale=4)
shap_next_btn = gr.Button("→", elem_classes="shap-nav-button", size="sm", scale=0)
shap_display = gr.Plot(label="SHAP Feature Importance")
def _on_shap_prev(idx, plots):
new_idx = max(0, idx - 1)
plot = plots[SHAP_ORDER[new_idx]] if plots else None
return new_idx, plot, _shap_counter_html(new_idx)
def _on_shap_next(idx, plots):
new_idx = min(len(SHAP_ORDER) - 1, idx + 1)
plot = plots[SHAP_ORDER[new_idx]] if plots else None
return new_idx, plot, _shap_counter_html(new_idx)
shap_prev_btn.click(
fn=_on_shap_prev,
inputs=[shap_idx_state, shap_plots_state],
outputs=[shap_idx_state, shap_display, shap_crumb],
)
shap_next_btn.click(
fn=_on_shap_next,
inputs=[shap_idx_state, shap_plots_state],
outputs=[shap_idx_state, shap_display, shap_crumb],
)
predict_btn.click(
fn=predict_gradio,
inputs=inputs_list,
outputs=[
output_table,
icon_probs_state, # ← store probs dict
icon_idx_state, # ← reset to 0
icon_display, # ← show first card
shap_plots_state,
shap_idx_state,
shap_display,
shap_crumb,
],
)
# ══════════════════════════════════════════════════════════════════════
# TAB 2 — COUNTERFACTUAL ANALYSIS
# ══════════════════════════════════════════════════════════════════════
with gr.Tab("Counterfactual Scenarios"):
gr.Markdown(
"## Counterfactual Scenario Analysis\n"
"Enter the **baseline** patient, then choose how many counterfactual "
"scenarios you want to compare. Each scenario panel will appear below."
)
with gr.Row():
n_scenarios_slider = gr.Slider(
minimum=1, maximum=MAX_SCENARIOS, step=1, value=1,
label=f"How many counterfactual scenarios do you want to compare? (1–{MAX_SCENARIOS})",
info="Adjust this first — scenario panels will appear/disappear below.",
)
gr.Markdown("---")
# ── BASELINE ─────────────────────────────────────────────────────
gr.Markdown("## Baseline Patient Profile")
with gr.Row():
copy_from_predict_btn = gr.Button(
"Copy from Predict tab → Baseline",
elem_classes="copy-from-predict-button",
size="sm",
)
wi_baseline_dict = {}
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Patient Characteristics")
for f in PATIENT_FEATURES:
wi_baseline_dict[f] = make_component(f)
with gr.Column(scale=1):
gr.Markdown("### Transplant Characteristics")
wi_grouped_base = gr.Dropdown(
choices=GROUPED_REGIMEN_CHOICES, value=None,
label="Published conditioning regimen (Baseline)",
)
wb_donorf = wi_baseline_dict["DONORF"] = make_component("DONORF")
wi_baseline_dict["GRAFTYPE"] = make_component("GRAFTYPE")
wb_condgrpf = wi_baseline_dict["CONDGRPF"] = make_component("CONDGRPF")
wb_condgrp_final = wi_baseline_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL")
wb_atgf = wi_baseline_dict["ATGF"] = make_component("ATGF")
wb_gvhd_final = wi_baseline_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
wb_hla_final = wi_baseline_dict["HLA_FINAL"] = make_component("HLA_FINAL")
with gr.Column(scale=1):
gr.Markdown("### Disease Characteristics")
for f in DISEASE_FEATURES:
wi_baseline_dict[f] = make_component(f)
wi_baseline_dict["AGE"].change(get_age_group, wi_baseline_dict["AGE"], wi_baseline_dict["AGEGPFF"])
wi_baseline_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, wi_baseline_dict["VOC2YPR"], wi_baseline_dict["VOCFRQPR"])
wb_donorf.change(fn=_hla_update_for_donor, inputs=wb_donorf, outputs=wb_hla_final)
wi_grouped_base.change(
apply_grouped_preset, wi_grouped_base,
[wi_grouped_base, wb_donorf, wb_condgrpf, wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
)
wi_baseline_list = [wi_baseline_dict[f] for f in ALL_FEATURES]
def _copy_predict_to_baseline(*vals):
n = len(ALL_FEATURES)
feature_vals = vals[:n]
regimen_value = vals[n] if len(vals) > n else None
preset_outputs = apply_grouped_preset(regimen_value)
regimen_dd_upd = gr.update(value=regimen_value)
donorf_upd, condgrpf_upd, condgrp_final_upd, atgf_upd, gvhd_final_upd, hla_final_upd = preset_outputs[1:]
return (*list(feature_vals), regimen_dd_upd, donorf_upd, condgrpf_upd,
condgrp_final_upd, atgf_upd, gvhd_final_upd, hla_final_upd)
copy_from_predict_btn.click(
fn=_copy_predict_to_baseline,
inputs=inputs_list + [grouped_dd],
outputs=wi_baseline_list + [wi_grouped_base, wb_donorf, wb_condgrpf,
wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
)
gr.Markdown("---")
# ── SCENARIO PANELS ───────────────────────────────────────────────
scenario_dicts = []
scenario_lists = []
scenario_grouped_dds = []
scenario_rows = []
scenario_name_inputs = []
scenario_donorf_handles = []
scenario_hla_handles = []
scenario_voc2ypr_handles = []
scenario_age_handles = []
scenario_agegp_handles = []
scenario_copy_btns = []
for s_idx in range(MAX_SCENARIOS):
color = SCENARIO_COLORS[s_idx]
label = f"Scenario {s_idx + 1}"
suffix = f"({label})"
visible_init = (s_idx == 0)
with gr.Row(visible=visible_init) as s_row:
scenario_rows.append(s_row)
with gr.Column():
gr.HTML(
f'