'
f'{fig_e}Event'
f'({n_event}/100)'
f'{fig_ne}No Event'
f'({n_no_event}/100)'
f'
'
)
badge = (
f'
'
f'{panel_label}
'
) if panel_label else ""
return (
f'
'
f'{badge}'
f'
{title}
'
f'
{pct_str}
'
f'
{grid_html}
'
f'
{legend}
'
f'
'
)
def _build_comparison_icon_grid(all_probs_list, labels, colors):
"""
Build an HTML grid of icon-array cards comparing several outcome sets.

all_probs_list: list of dicts keyed by outcome code (baseline first, then
scenarios). Every dict must contain each key in ICON_OUTCOMES, because
probs[o] is indexed directly.
labels: list of display labels, one per entry of all_probs_list.
colors: list of hex colors, one per entry of all_probs_list.
zip() stops at the shortest of the three lists, so extra entries are ignored.

Layout: a header row of outcome titles (OUTCOME_TITLES, falling back to the
outcome code), then one row per probability set with one card per outcome
in ICON_OUTCOMES, followed by a footnote explaining the icon colours.
Returns the assembled HTML string.
"""
# NOTE(review): the HTML markup inside the string literals below appears to
# be missing in this copy (several literals span lines, which is not valid
# Python) — confirm against the source of truth before editing them.
# Header row with outcome names
outcome_headers = "".join(
f'
'
f'{OUTCOME_TITLES.get(o, o)}
'
for o in ICON_OUTCOMES
)
header_row = (
f'
'
f'{outcome_headers}
'
)
rows_html = header_row
# One row per probability set; the enumerate index `i` is not used.
for i, (probs, label, color) in enumerate(zip(all_probs_list, labels, colors)):
# Row label. NOTE(review): `label` is interpolated without HTML escaping, so
# callers must pass markup-safe text.
row_label = (
f'
'
f'{label}
'
)
# One icon card per outcome, drawn in this row's colour.
cards = "".join(
f'
'
f'{_icon_card_html(probs[o], o, "", color)}'
f'
'
for o in ICON_OUTCOMES
)
rows_html += (
f'
'
f'{row_label}{cards}
'
)
# Legend: each figure is one patient out of 100; red = event, green = no event.
footnote = (
f'
'
f'Each figure = 1 patient out of 100. '
f'■ Red = Event '
f'■ Green = No Event'
f'
'
)
return f'
{rows_html}{footnote}
'
def _delta_color_html(delta, is_survival):
if abs(delta) < 0.0005:
color = "#888888"
elif (delta > 0 and is_survival) or (delta < 0 and not is_survival):
color = "#2e7d32"
else:
color = "#c62828"
sign = "+" if delta >= 0 else ""
return f'{sign}{delta*100:.1f}%'
def _build_comparison_table_html(base_probs, base_ci, scenario_probs_list, scenario_ci_list, scenario_labels, scenario_colors):
"""
Build the baseline-vs-scenarios comparison table as an HTML string.

Inverted layout: rows = Baseline + Scenarios, columns = Outcomes.
Columns are the REPORTING_OUTCOMES keys present in base_probs, titled with
OUTCOME_DESCRIPTIONS (falling back to the outcome code).

base_probs / base_ci: dicts keyed by outcome code; base_ci values are
(low, high) pairs, and a missing entry renders with NaN bounds.
scenario_probs_list / scenario_ci_list / scenario_labels: parallel lists,
one entry per scenario; zip() stops at the shortest of the three.
scenario_colors: indexed modulo its length, so it may be shorter than the
scenario list but must be non-empty when any scenario is present.

Each cell shows the probability as a percentage with its interval.
Scenario cells also show the change from baseline, rendered by
_delta_color_html, with OS and EFS treated as "higher is better".
Returns header + rows + footer.
"""
# NOTE(review): the HTML markup inside the string literals below appears to
# be missing in this copy (several literals span lines, which is not valid
# Python) — confirm against the source of truth before editing them.
survival_outcomes = {"OS", "EFS"}
# Build outcome column headers
outcome_headers = "".join(
f"
{OUTCOME_DESCRIPTIONS.get(o, o)}
"
for o in REPORTING_OUTCOMES if o in base_probs
)
# Header: a "Scenario" column followed by one column per outcome.
header = (
"
"
"
"
"
"
"
Scenario
"
f"{outcome_headers}"
"
"
)
rows = ""
# Baseline row
baseline_cells = ""
for o in REPORTING_OUTCOMES:
if o not in base_probs:
continue
bp = base_probs[o]
blo, bhi = base_ci.get(o, (float("nan"), float("nan")))
baseline_cells += (
f"
"
f"
{bp*100:.1f}%
"
f"
[{blo*100:.1f}%–{bhi*100:.1f}%]
"
f"
"
)
rows += (
f"
"
f"
"
f"
"
f"Baseline
"
f"{baseline_cells}
"
)
# Scenario rows
for j, (sp_dict, sci_dict, s_label) in enumerate(zip(scenario_probs_list, scenario_ci_list, scenario_labels)):
# sc_color and bg (alternating row shade) are not interpolated by any
# visible literal; presumably they feed the missing row markup — confirm.
sc_color = scenario_colors[j % len(scenario_colors)]
bg = "#fafbfc" if j % 2 == 0 else "#ffffff"
scenario_cells = ""
for o in REPORTING_OUTCOMES:
# NOTE(review): an outcome present in base_probs but absent from sp_dict
# is skipped, which shifts this row's cells left relative to the header.
if o not in base_probs or o not in sp_dict:
continue
bp = base_probs[o]
wp = sp_dict[o]
wlo, whi = sci_dict.get(o, (float("nan"), float("nan")))
delta = wp - bp
is_surv = o in survival_outcomes
scenario_cells += (
f"
"
f"
{wp*100:.1f}%
"
f"
[{wlo*100:.1f}%–{whi*100:.1f}%]
"
f"
{_delta_color_html(delta, is_surv)}
"
f"
"
)
# NOTE(review): s_label is interpolated without HTML escaping, so callers
# must pass markup-safe text.
rows += (
f"
"
f"
"
f"
"
f""
f"{s_label}
"
f"{scenario_cells}
"
)
# Footer legend explaining the Δ colouring and the direction of "better".
footer = (
"
"
"
"
"Δ from Baseline: Green = improvement "
"Red = worsening | "
"OS & EFS: higher is better; all other outcomes: lower is better."
"
"
)
return header + rows + footer
# Render constraint-violation messages as an HTML notice.
# Returns "" when `violations` is empty or falsy, so the results area stays
# blank. Otherwise the output is a heading, one entry per message, and a hint
# to fix the inputs before re-running the comparison.
# NOTE(review): each message is interpolated without HTML escaping.
# NOTE(review): the HTML markup inside the string literals below appears to
# be missing in this copy (several literals span lines, which is not valid
# Python) — confirm against the source of truth before editing them.
def _build_violation_html(violations):
if not violations:
return ""
items = "".join(f"
{v}
" for v in violations)
return (
f'
'
f'
'
f'Constraint Violations — Analysis blocked
'
f'
{items}
'
f'
'
f'Please correct the above before running the comparison.
'
f'
'
)
# ─────────────────────────────────────────────────────────────────────────────
# MAIN PREDICT CALLBACK (Tab 1)
# ─────────────────────────────────────────────────────────────────────────────
def predict_gradio(*values):
    """Gradio callback for the Tab 1 "Predict" button.

    ``values`` are the Tab 1 component values; they are converted with
    _values_to_dict. Any blank field (None, "" or a float NaN) aborts with
    a message listing the blank feature names.

    Returns a tuple matching ``predict_btn.click``'s outputs:
    - a DataFrame with one row per REPORTING_OUTCOMES entry (description,
      probability and 95% CI, as percentages);
    - the icon-array HTML grid;
    - the SHAP plots for DEAD, GF, AGVHD, CGVHD, VOCPSHI, EFS, STROKEHI and OS.

    Raises:
        gr.Error: wraps any exception raised along the way (its traceback is
            printed first), formatted as "<ExceptionType>: <message>".
    """
    shap_order = ("DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "EFS", "STROKEHI", "OS")

    def _is_blank(value):
        # None, empty string, or a float NaN all count as "not filled in".
        if value is None or value == "":
            return True
        return isinstance(value, float) and pd.isna(value)

    def _table_row(code, probs, intervals):
        # One display row: description, point estimate and interval in %.
        description = OUTCOME_DESCRIPTIONS[code]
        prob = probs[code]
        low, high = intervals[code]
        return {
            "Outcome": description,
            "Probability": f"{prob * 100:.1f}%",
            "95% CI": f"[{low * 100:.1f}% – {high * 100:.1f}%]",
        }

    try:
        features = _values_to_dict(values)
        blanks = [name for name, value in features.items() if _is_blank(value)]
        if blanks:
            raise ValueError(f"Please fill in all fields. Missing: {', '.join(blanks)}")

        # Only the calibrated (probabilities, intervals) pair is displayed.
        (probs, intervals), _ = predict_with_comparison(features)
        table = pd.DataFrame([_table_row(code, probs, intervals) for code in REPORTING_OUTCOMES])

        shap_by_outcome = create_all_shap_plots(features, max_display=10)
        icon_arrays = create_all_icon_arrays(probs)
        return (table, icon_arrays["__grid__"], *(shap_by_outcome[code] for code in shap_order))
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"{type(e).__name__}: {str(e)}")
# ─────────────────────────────────────────────────────────────────────────────
# CSS
# ─────────────────────────────────────────────────────────────────────────────
# Stylesheet passed to gr.Blocks(css=...).
# - The button classes ("predict-button", "copy-button",
#   "copy-from-predict-button", "counterfactual-button") are attached to
#   buttons via elem_classes.
# - "output-dataframe" lets the first column of the Tab 1 results table wrap
#   and gives it a minimum width.
# NOTE(review): .constraint-info and .scenario-panel-0..4 are not referenced
# by the visible Python code. They may be used in generated HTML elsewhere,
# so confirm before removing them.
custom_css = """
.predict-button {
background: linear-gradient(to right, #ff6b35, #ff8c42) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 16px !important; padding: 12px !important;
}
.predict-button:hover { background: linear-gradient(to right, #ff5722, #ff7b29) !important; }
.copy-button {
background: linear-gradient(to right, #388e3c, #66bb6a) !important;
border: none !important; color: white !important; font-weight: 600 !important;
}
.copy-button:hover { background: linear-gradient(to right, #2e7d32, #43a047) !important; }
.copy-from-predict-button {
background: linear-gradient(to right, #6a1b9a, #ab47bc) !important;
border: none !important; color: white !important; font-weight: 600 !important;
}
.copy-from-predict-button:hover { background: linear-gradient(to right, #4a148c, #8e24aa) !important; }
.counterfactual-button {
background: linear-gradient(to right, #1976d2, #42a5f5) !important;
border: none !important; color: white !important;
font-weight: bold !important; font-size: 15px !important; padding: 12px !important;
}
.counterfactual-button:hover { background: linear-gradient(to right, #1565c0, #1e88e5) !important; }
.output-dataframe table td:first-child,
.output-dataframe table th:first-child {
white-space: normal !important; word-break: break-word !important; min-width: 240px !important;
}
.constraint-info {
background: #e8f5e9; border-left: 4px solid #388e3c;
padding: 8px 14px; font-size: 12px; color: #1b5e20;
border-radius: 4px; margin-bottom: 8px;
}
.scenario-panel-0 { border-left: 4px solid #e65100 !important; }
.scenario-panel-1 { border-left: 4px solid #6a1b9a !important; }
.scenario-panel-2 { border-left: 4px solid #1b5e20 !important; }
.scenario-panel-3 { border-left: 4px solid #0d47a1 !important; }
.scenario-panel-4 { border-left: 4px solid #b71c1c !important; }
"""
# ─────────────────────────────────────────────────────────────────────────────
# BUILD UI
# ─────────────────────────────────────────────────────────────────────────────
# Build the Gradio app. Tab 1 predicts outcomes for one patient; Tab 2
# compares a baseline patient against up to MAX_SCENARIOS counterfactual
# scenarios. Component creation order defines the on-screen layout.
with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
gr.Markdown("# HCT Outcome Prediction Model")
# ── shared state: how many counterfactual scenarios are active ────────────
# Kept in sync with the scenario slider and read as the first input of
# run_counterfactual.
n_scenarios_state = gr.State(1)
with gr.Tabs():
# ══════════════════════════════════════════════════════════════════════
# TAB 1 — PREDICT OUTCOMES
# ══════════════════════════════════════════════════════════════════════
with gr.Tab("Predict Outcomes"):
gr.Markdown("Enter patient, transplant, and disease characteristics to predict outcomes.")
# Feature name -> input component for Tab 1.
inputs_dict = {}
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Patient Characteristics")
for f in PATIENT_FEATURES:
inputs_dict[f] = make_component(f)
with gr.Column(scale=1):
gr.Markdown("### Transplant Characteristics")
# Choosing a published regimen fills the transplant fields below via
# apply_grouped_preset (wired after the layout).
grouped_dd = gr.Dropdown(
choices=GROUPED_REGIMEN_CHOICES, value=None,
label="Published conditioning regimen",
info="Auto-fills Donor Type, Conditioning Intensity, Regimen, Serotherapy, GVHD Prophylaxis",
)
# Keep handles to the fields that the regimen preset and the donor-type
# rule write to.
p_donorf = inputs_dict["DONORF"] = make_component("DONORF")
inputs_dict["GRAFTYPE"] = make_component("GRAFTYPE")
p_condgrpf = inputs_dict["CONDGRPF"] = make_component("CONDGRPF")
p_condgrp_final = inputs_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL")
p_atgf = inputs_dict["ATGF"] = make_component("ATGF")
p_gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
p_hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL")
with gr.Column(scale=1):
gr.Markdown("### Disease Characteristics")
for f in DISEASE_FEATURES:
inputs_dict[f] = make_component(f)
# Derived-field wiring: AGE -> AGEGPFF, VOC2YPR -> VOCFRQPR,
# DONORF -> HLA_FINAL, regimen preset -> transplant fields.
inputs_dict["AGE"].change(get_age_group, inputs_dict["AGE"], inputs_dict["AGEGPFF"])
inputs_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, inputs_dict["VOC2YPR"], inputs_dict["VOCFRQPR"])
p_donorf.change(fn=_hla_update_for_donor, inputs=p_donorf, outputs=p_hla_final)
grouped_dd.change(
apply_grouped_preset, grouped_dd,
[grouped_dd, p_donorf, p_condgrpf, p_condgrp_final, p_atgf, p_gvhd_final, p_hla_final],
)
# Inputs in ALL_FEATURES order; presumably the order _values_to_dict
# expects — verify.
inputs_list = [inputs_dict[f] for f in ALL_FEATURES]
predict_btn = gr.Button("Predict", elem_classes="predict-button", size="lg")
gr.Markdown("---")
gr.Markdown("## Prediction Results")
output_table = gr.Dataframe(
headers=["Outcome", "Probability", "95% CI"],
label="Predicted Outcomes",
elem_classes="output-dataframe",
row_count=(len(REPORTING_OUTCOMES), "dynamic"),
column_count=(3, "fixed"),
wrap=True,
)
gr.Markdown("---")
gr.Markdown("## Outcome Probability — Icon Arrays")
icon_array_grid = gr.HTML()
gr.Markdown("---")
gr.Markdown("## SHAP — Feature Importance")
with gr.Row():
shap_dead = gr.Plot(label="Death")
shap_gf = gr.Plot(label="Graft Failure")
shap_agvhd = gr.Plot(label="Acute GvHD")
shap_cgvhd = gr.Plot(label="Chronic GvHD")
with gr.Row():
shap_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
shap_efs = gr.Plot(label="Event-Free Survival")
shap_stroke = gr.Plot(label="Stroke Post-HCT")
shap_os = gr.Plot(label="Overall Survival")
# Output order must match the tuple returned by predict_gradio.
predict_btn.click(
fn=predict_gradio,
inputs=inputs_list,
outputs=[
output_table, icon_array_grid,
shap_dead, shap_gf, shap_agvhd, shap_cgvhd,
shap_vocpshi, shap_efs, shap_stroke, shap_os,
],
)
# ══════════════════════════════════════════════════════════════════════
# TAB 2 — COUNTERFACTUAL ANALYSIS (dynamic multi-scenario)
# ══════════════════════════════════════════════════════════════════════
with gr.Tab("Counterfactual Scenarios"):
gr.Markdown(
"## Counterfactual Scenario Analysis\n"
"Enter the **baseline** patient, then choose how many counterfactual "
"scenarios you want to compare. Each scenario panel will appear below."
)
# ── Number of scenarios selector ──────────────────────────────────
with gr.Row():
n_scenarios_slider = gr.Slider(
minimum=1, maximum=MAX_SCENARIOS, step=1, value=1,
label=f"How many counterfactual scenarios do you want to compare? (1–{MAX_SCENARIOS})",
info="Adjust this first — scenario panels will appear/disappear below.",
)
gr.Markdown("---")
# ── BASELINE ─────────────────────────────────────────────────────
gr.Markdown("## Baseline Patient Profile")
with gr.Row():
copy_from_predict_btn = gr.Button(
"Copy from Predict tab → Baseline",
elem_classes="copy-from-predict-button",
size="sm",
)
# Baseline inputs mirror Tab 1's layout and derived-field wiring.
wi_baseline_dict = {}
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Patient Characteristics")
for f in PATIENT_FEATURES:
wi_baseline_dict[f] = make_component(f)
with gr.Column(scale=1):
gr.Markdown("### Transplant Characteristics")
wi_grouped_base = gr.Dropdown(
choices=GROUPED_REGIMEN_CHOICES, value=None,
label="Published conditioning regimen (Baseline)",
)
wb_donorf = wi_baseline_dict["DONORF"] = make_component("DONORF")
wi_baseline_dict["GRAFTYPE"] = make_component("GRAFTYPE")
wb_condgrpf = wi_baseline_dict["CONDGRPF"] = make_component("CONDGRPF")
wb_condgrp_final = wi_baseline_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL")
wb_atgf = wi_baseline_dict["ATGF"] = make_component("ATGF")
wb_gvhd_final = wi_baseline_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
wb_hla_final = wi_baseline_dict["HLA_FINAL"] = make_component("HLA_FINAL")
with gr.Column(scale=1):
gr.Markdown("### Disease Characteristics")
for f in DISEASE_FEATURES:
wi_baseline_dict[f] = make_component(f)
wi_baseline_dict["AGE"].change(get_age_group, wi_baseline_dict["AGE"], wi_baseline_dict["AGEGPFF"])
wi_baseline_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, wi_baseline_dict["VOC2YPR"], wi_baseline_dict["VOCFRQPR"])
wb_donorf.change(fn=_hla_update_for_donor, inputs=wb_donorf, outputs=wb_hla_final)
wi_grouped_base.change(
apply_grouped_preset, wi_grouped_base,
[wi_grouped_base, wb_donorf, wb_condgrpf, wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
)
# Baseline inputs in ALL_FEATURES order (same order as Tab 1's inputs_list).
wi_baseline_list = [wi_baseline_dict[f] for f in ALL_FEATURES]
# Copy from Predict tab → Baseline
def _copy_predict_to_baseline(*vals):
    """Copy Tab 1's feature values and regimen choice into the baseline panel.

    ``vals`` holds the Tab 1 feature values (ALL_FEATURES order), optionally
    followed by the selected published regimen. The return value is:
    - those feature values;
    - an update re-selecting the regimen in the baseline dropdown;
    - the six transplant-field updates that apply_grouped_preset produces
      for that regimen.
    """
    feature_count = len(ALL_FEATURES)
    regimen = vals[feature_count] if len(vals) > feature_count else None
    preset = apply_grouped_preset(regimen)
    dropdown_update = gr.update(value=regimen)
    # Element 0 of the preset targets the regimen dropdown itself; it is
    # replaced by dropdown_update, so only the remaining six are forwarded.
    donor, cond, cond_final, atg, gvhd, hla = preset[1:]
    return (
        *vals[:feature_count],
        dropdown_update,
        donor, cond, cond_final, atg, gvhd, hla,
    )
# Copy Tab 1 values (plus its regimen dropdown) into the baseline panel.
copy_from_predict_btn.click(
fn=_copy_predict_to_baseline,
inputs=inputs_list + [grouped_dd],
outputs=wi_baseline_list + [wi_grouped_base, wb_donorf, wb_condgrpf,
wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
)
gr.Markdown("---")
# ── SCENARIO PANELS ───────────────────────────────────────────────
# Per-scenario component handles, one entry per slot 0..MAX_SCENARIOS-1.
# NOTE(review): several of these lists (e.g. scenario_dicts, the *_handles
# lists, scenario_copy_btns) are only appended to in the visible code.
scenario_dicts = []
scenario_lists = []
scenario_grouped_dds = []
scenario_rows = []
scenario_name_inputs = [] # one gr.Textbox per scenario holding its display name
scenario_donorf_handles = []
scenario_hla_handles = []
scenario_voc2ypr_handles = []
scenario_age_handles = []
scenario_agegp_handles = []
scenario_copy_btns = []
# Build every scenario panel up front. Only slot 0 starts visible; the
# slider toggles the rest via scenario_rows.
for s_idx in range(MAX_SCENARIOS):
# NOTE(review): `color` is not used by any visible literal; presumably it
# feeds the (missing) header markup below — confirm.
color = SCENARIO_COLORS[s_idx]
label = f"Scenario {s_idx + 1}"
suffix = f"({label})"
visible_init = (s_idx == 0)
with gr.Row(visible=visible_init) as s_row:
scenario_rows.append(s_row)
with gr.Column():
# NOTE(review): the HTML markup in this literal appears to be missing in
# this copy (it spans lines, which is not valid Python) — confirm.
gr.HTML(
f'
'
f'Counterfactual {label}
'
)
with gr.Row():
# Scenario display name, used as the row label in the comparison results.
s_name_input = gr.Textbox(
label=f"Scenario name ({label})",
value=label,
placeholder=f"e.g. {label}",
scale=2,
)
scenario_name_inputs.append(s_name_input)
copy_btn = gr.Button(
f"⬇ Copy Baseline → {label}",
elem_classes="copy-button",
size="sm",
scale=1,
)
scenario_copy_btns.append(copy_btn)
s_dict = {}
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(f"#### Patient — {label}")
for f in PATIENT_FEATURES:
s_dict[f] = make_component(f, suffix)
with gr.Column(scale=1):
gr.Markdown(f"#### Transplant — {label}")
s_grouped_dd = gr.Dropdown(
choices=GROUPED_REGIMEN_CHOICES, value=None,
label=f"Published regimen ({label})",
)
scenario_grouped_dds.append(s_grouped_dd)
s_donorf = s_dict["DONORF"] = make_component("DONORF", suffix)
s_dict["GRAFTYPE"] = make_component("GRAFTYPE", suffix)
s_condgrpf = s_dict["CONDGRPF"] = make_component("CONDGRPF", suffix)
s_condgrp_final = s_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL", suffix)
s_atgf = s_dict["ATGF"] = make_component("ATGF", suffix)
s_gvhd_final = s_dict["GVHD_FINAL"] = make_component("GVHD_FINAL", suffix)
s_hla_final = s_dict["HLA_FINAL"] = make_component("HLA_FINAL", suffix)
with gr.Column(scale=1):
gr.Markdown(f"#### Disease — {label}")
for f in DISEASE_FEATURES:
s_dict[f] = make_component(f, suffix)
scenario_dicts.append(s_dict)
# This scenario's inputs in ALL_FEATURES order; run_counterfactual reads
# them as one contiguous block per slot.
scenario_lists.append([s_dict[f] for f in ALL_FEATURES])
scenario_donorf_handles.append(s_donorf)
scenario_hla_handles.append(s_hla_final)
scenario_voc2ypr_handles.append(s_dict["VOC2YPR"])
scenario_age_handles.append(s_dict["AGE"])
scenario_agegp_handles.append(s_dict["AGEGPFF"])
# Wire up constraints within each scenario panel
s_dict["AGE"].change(get_age_group, s_dict["AGE"], s_dict["AGEGPFF"])
s_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, s_dict["VOC2YPR"], s_dict["VOCFRQPR"])
s_donorf.change(fn=_hla_update_for_donor, inputs=s_donorf, outputs=s_hla_final)
s_grouped_dd.change(
apply_grouped_preset, s_grouped_dd,
[s_grouped_dd, s_donorf, s_condgrpf, s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final],
)
# Mirror baseline sex into each scenario and lock it
# (lock_sex is presumably returning an update carrying the baseline value — confirm).
wi_baseline_dict["SEX"].change(
fn=lock_sex,
inputs=wi_baseline_dict["SEX"],
outputs=s_dict["SEX"],
)
# NOTE(review): set after construction; this assumes Gradio reads the
# attribute when the Blocks config is built — verify it takes effect.
s_dict["SEX"].interactive = False
# Copy Baseline → this scenario
# NOTE(review): none of the *_ref parameters are used by _copy; its logic
# matches _copy_predict_to_baseline, so the factory adds nothing per slot.
def _make_copy_fn(s_grouped_dd_ref, s_donorf_ref, s_condgrpf_ref,
s_condgrp_final_ref, s_atgf_ref, s_gvhd_final_ref,
s_hla_final_ref):
def _copy(*vals):
n = len(ALL_FEATURES)
feature_vals = vals[:n]
regimen_value = vals[n] if len(vals) > n else None
preset_outputs = apply_grouped_preset(regimen_value)
regimen_dd_upd = gr.update(value=regimen_value)
# Drop preset element 0 (the dropdown's own update) in favour of regimen_dd_upd.
d_upd, c_upd, cr_upd, a_upd, g_upd, h_upd = preset_outputs[1:]
return (*list(feature_vals), regimen_dd_upd,
d_upd, c_upd, cr_upd, a_upd, g_upd, h_upd)
return _copy
# `outputs` is evaluated now, so scenario_lists[-1] is this slot's list.
copy_btn.click(
fn=_make_copy_fn(s_grouped_dd, s_donorf, s_condgrpf,
s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final),
inputs=wi_baseline_list + [wi_grouped_base],
outputs=(scenario_lists[-1]
+ [s_grouped_dd, s_donorf, s_condgrpf,
s_condgrp_final, s_atgf, s_gvhd_final, s_hla_final]),
)
# Re-lock sex after copy (a second, independent click handler that also
# writes this scenario's SEX field).
copy_btn.click(
fn=lock_sex,
inputs=wi_baseline_dict["SEX"],
outputs=s_dict["SEX"],
)
# ── Slider → show/hide scenario panels ───────────────────────────
# Slot i is visible iff i < n.
def _update_scenario_visibility(n):
return [gr.update(visible=(i < int(n))) for i in range(MAX_SCENARIOS)]
n_scenarios_slider.change(
fn=_update_scenario_visibility,
inputs=n_scenarios_slider,
outputs=scenario_rows,
)
# Keep n_scenarios_state (first input of run_counterfactual) in sync with the slider.
n_scenarios_slider.change(
fn=lambda n: int(n),
inputs=n_scenarios_slider,
outputs=n_scenarios_state,
)
# ── RUN ──────────────────────────────────────────────────────────
gr.Markdown("---")
wi_run_btn = gr.Button(
"Run Counterfactual Comparison",
elem_classes="counterfactual-button",
size="lg",
)
# ── RESULTS ──────────────────────────────────────────────────────
gr.Markdown("## Comparison Results")
wi_violation_html = gr.HTML()
gr.Markdown("### Outcome Probability Table")
wi_table_html = gr.HTML()
gr.Markdown("---")
# Collapsible icon-array section (closed by default).
with gr.Accordion("Outcome Icon Arrays — Baseline vs Scenarios", open=False):
gr.Markdown("*Icon arrays show each outcome probability per 100 patients.*")
wi_icon_html = gr.HTML()
gr.Markdown("---")
# Collapsible SHAP section (closed by default); one Plot per SHAP_ORDER outcome.
with gr.Accordion("SHAP Feature Importance", open=False):
gr.Markdown("### Baseline")
with gr.Row():
wi_shap_base = {o: gr.Plot(label=f"{o} — Baseline") for o in SHAP_ORDER}
# One SHAP section per scenario slot. The "### Scenario N" heading and its
# plot row share a single Column. That Column is what the slider toggles, so
# an unused slot hides its heading together with its plots (toggling only
# the Row would leave every heading on screen).
wi_shap_scenarios = []  # [(container, {outcome: gr.Plot}), ...], one per slot
for s_idx in range(MAX_SCENARIOS):
    with gr.Column(visible=(s_idx == 0)) as shap_group:
        gr.Markdown(f"### Scenario {s_idx + 1}")
        with gr.Row():
            shap_plots_s = {o: gr.Plot(label=f"{o} — Scenario {s_idx + 1}") for o in SHAP_ORDER}
    wi_shap_scenarios.append((shap_group, shap_plots_s))
# Also toggle the SHAP sections with the slider: slot i is shown iff i < n.
def _update_shap_visibility(n):
    return [gr.update(visible=(i < int(n))) for i in range(MAX_SCENARIOS)]
n_scenarios_slider.change(
    fn=_update_shap_visibility,
    inputs=n_scenarios_slider,
    outputs=[group for group, _ in wi_shap_scenarios],
)
# ── RUN callback ──────────────────────────────────────────────────
def run_counterfactual(*all_values):
    """Gradio callback for the "Run Counterfactual Comparison" button.

    ``all_values`` follows the layout of ``all_run_inputs``:
      [0]                                number of active scenarios
      next len(ALL_FEATURES) values      baseline feature values
      next MAX_SCENARIOS values          scenario display names (free text)
      then MAX_SCENARIOS blocks of len(ALL_FEATURES) scenario feature values

    Returns a flat tuple matching ``all_run_outputs``:
    - violation HTML, comparison-table HTML and icon-grid HTML;
    - one baseline SHAP plot per SHAP_ORDER entry;
    - one plot per SHAP_ORDER entry for every scenario slot (None for slots
      beyond the active count).
    If any active scenario breaks a counterfactual constraint, only the
    violation HTML is filled in.

    Raises:
        gr.Error: wraps any other failure (its traceback is printed first).
    """
    import html  # escape user-typed scenario names before they reach HTML

    n_feat = len(ALL_FEATURES)
    n_shap = len(SHAP_ORDER)
    # Clamp so an out-of-range count can never index past the
    # MAX_SCENARIOS scenario blocks below.
    n_scen = max(0, min(int(all_values[0]), MAX_SCENARIOS))
    base_vals = all_values[1 : 1 + n_feat]
    # Scenario names come next (MAX_SCENARIOS of them)
    name_offset = 1 + n_feat
    scenario_names_raw = all_values[name_offset : name_offset + MAX_SCENARIOS]
    # Then the scenario feature blocks
    feat_offset = name_offset + MAX_SCENARIOS
    scenario_val_blocks = [
        all_values[feat_offset + s * n_feat : feat_offset + (s + 1) * n_feat]
        for s in range(MAX_SCENARIOS)
    ]
    # Placeholder SHAP outputs sized from SHAP_ORDER, so the returned tuple
    # always has as many entries as all_run_outputs.
    base_shap_outputs = [None] * n_shap
    scene_shap_outputs = [[None] * n_shap for _ in range(MAX_SCENARIOS)]
    try:
        base_dict = _values_to_dict(base_vals)
        _check_missing(base_dict, "Baseline")
        all_violations = []
        scenario_dicts_active = []
        for s in range(n_scen):
            sd = _values_to_dict(scenario_val_blocks[s])
            _check_missing(sd, f"Scenario {s+1}")
            v = _validate_counterfactual_constraints(base_dict, sd, f"Scenario {s+1}")
            all_violations.extend(v)
            scenario_dicts_active.append(sd)
        if all_violations:
            return (
                _build_violation_html(all_violations), "", "",
                *base_shap_outputs,
                *[p for plots in scene_shap_outputs for p in plots],
            )
        base_probs, base_ci = predict_all_outcomes(
            base_dict, use_calibration=True, use_signed_voting=True, n_boot_ci=500
        )
        scen_probs_list = []
        scen_ci_list = []
        for sd in scenario_dicts_active:
            sp, sci = predict_all_outcomes(
                sd, use_calibration=True, use_signed_voting=True, n_boot_ci=500
            )
            scen_probs_list.append(sp)
            scen_ci_list.append(sci)
        # Use custom scenario names, falling back to "Scenario N" when blank.
        # The names are free text and are rendered inside HTML, so escape them.
        scen_labels = []
        for i in range(n_scen):
            raw_name = scenario_names_raw[i] if i < len(scenario_names_raw) else ""
            name = str(raw_name).strip() if raw_name else ""
            scen_labels.append(html.escape(name) if name else f"Scenario {i+1}")
        scen_colors = SCENARIO_COLORS[:n_scen]
        table_html = _build_comparison_table_html(
            base_probs, base_ci, scen_probs_list, scen_ci_list, scen_labels, scen_colors
        )
        icon_html = _build_comparison_icon_grid(
            [base_probs] + scen_probs_list,
            ["Baseline"] + scen_labels,
            ["#1565c0"] + scen_colors,
        )
        # SHAP
        base_shap_plots = create_all_shap_plots(base_dict, max_display=10)
        base_shap_outputs = [base_shap_plots[o] for o in SHAP_ORDER]
        for s, sd in enumerate(scenario_dicts_active):
            sp_plots = create_all_shap_plots(sd, max_display=10)
            scene_shap_outputs[s] = [sp_plots[o] for o in SHAP_ORDER]
        return (
            "",
            table_html,
            icon_html,
            *base_shap_outputs,
            *[p for plots in scene_shap_outputs for p in plots],
        )
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"{type(e).__name__}: {str(e)}") from e
# Outputs in the exact order run_counterfactual returns them:
# - violation, table and icon HTML;
# - the baseline SHAP plots;
# - every scenario slot's SHAP plots (each group ordered by SHAP_ORDER).
all_run_outputs = [wi_violation_html, wi_table_html, wi_icon_html]
all_run_outputs.extend(wi_shap_base[o] for o in SHAP_ORDER)
for _slot_container, _slot_plots in wi_shap_scenarios:
    all_run_outputs.extend(_slot_plots[o] for o in SHAP_ORDER)
# Inputs in the layout run_counterfactual unpacks:
# - the scenario count;
# - the baseline features;
# - every scenario-name box;
# - each slot's features in ALL_FEATURES order.
all_run_inputs = [n_scenarios_state, *wi_baseline_list, *scenario_name_inputs]
for _slot_inputs in scenario_lists:
    all_run_inputs.extend(_slot_inputs)
wi_run_btn.click(
    fn=run_counterfactual,
    inputs=all_run_inputs,
    outputs=all_run_outputs,
)
# ─────────────────────────────────────────────────────────────────────────────
# Start the Gradio server only when this file is executed directly.
# ssr_mode=False disables Gradio's server-side rendering for this launch.
if __name__ == "__main__":
demo.launch(ssr_mode=False)