shivapriyasom commited on
Commit
7bae749
·
verified ·
1 Parent(s): 6568f04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +279 -26
app.py CHANGED
@@ -73,6 +73,8 @@ GROUPED_REGIMEN_CHOICES = [
73
  ("── MISMATCHED UNRELATED / CORD BLOOD ──", "__header_mismatched_cord__"),
74
  ("Bolanos-Meade et al 2022 (Mismatched/Cord)", "Bolanos-Meade et al 2022 (Mismatched/Cord)"),
75
  ("Patel et al 2020 (Mismatched/Cord)", "Patel et al 2020 (Mismatched/Cord)"),
 
 
76
  ]
77
  HEADER_VALUES = {v for _, v in GROUPED_REGIMEN_CHOICES if v.startswith("__header_")}
78
 
@@ -89,6 +91,19 @@ PUBLISHED_PRESETS = {
89
  "Patel et al 2020 (Mismatched/Cord)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu/TT", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "Mismatched unrelated donor or cord blood"},
90
  }
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  PATIENT_FEATURES = ["AGE", "AGEGPFF", "SEX", "KPS", "RCMVPR"]
93
  DONOR_FEATURES = ["DONORF", "GRAFTYPE", "HLA_FINAL", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL"]
94
  DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
@@ -108,6 +123,95 @@ NO_EVENT_COLOR = "#43a047"
108
  SHAP_ORDER = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "EFS", "STROKEHI", "OS"]
109
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # ─────────────────────────────────────────────────────────────────────────────
112
  # SHARED HELPERS
113
  # ─────────────────────────────────────────────────────────────────────────────
@@ -156,20 +260,57 @@ def vocfrqpr_from_voc2ypr(voc_status):
156
 
157
 
158
  def apply_grouped_preset(selected_value):
 
 
 
 
 
 
 
 
 
159
  if not selected_value or selected_value in HEADER_VALUES:
160
- return [gr.update(value=None)] + [gr.update()] * 6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  preset = PUBLISHED_PRESETS.get(selected_value)
162
  if not preset:
163
- return [gr.update()] * 7
164
- return [
 
 
 
 
 
 
 
165
  gr.update(),
166
- gr.update(value=preset["DONORF"]),
167
- gr.update(value=preset["CONDGRPF"]),
168
- gr.update(value=preset["CONDGRP_FINAL"]),
169
- gr.update(value=preset["ATGF"]),
170
- gr.update(value=preset["GVHD_FINAL"]),
171
- gr.update(value=preset["HLA_FINAL"]),
172
- ]
173
 
174
 
175
  def copy_fields(*vals):
@@ -297,10 +438,6 @@ def _icon_card_html(probability: float, outcome: str,
297
 
298
 
299
  def _build_comparison_icon_grid(base_probs: dict, wi_probs: dict) -> str:
300
- """
301
- Side-by-side icon array grid.
302
- Layout: 3 outcome columns per row, each column has Baseline card + What-If card side by side.
303
- """
304
  rows_html = ""
305
  for row_start in range(0, len(ICON_OUTCOMES), 3):
306
  chunk = ICON_OUTCOMES[row_start: row_start + 3]
@@ -398,6 +535,22 @@ def _build_comparison_table_html(base_probs, base_ci, wi_probs, wi_ci) -> str:
398
  return header + rows + footer
399
 
400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  # ─────────────────────────────────────────────────────────────────────────────
402
  # MAIN PREDICT CALLBACK
403
  # ─────────────────────────────────────────────────────────────────────────────
@@ -452,19 +605,28 @@ def predict_gradio(*values):
452
  def run_whatif_predict(*all_values):
453
  """
454
  Receives 2 × len(ALL_FEATURES) values: first block = baseline, second = what-if.
455
- Returns: table_html, icon_grid_html,
456
  base_shap×8, wi_shap×8 (order = SHAP_ORDER)
457
  """
458
  n = len(ALL_FEATURES)
459
  baseline_vals = all_values[:n]
460
  whatif_vals = all_values[n:]
461
 
 
 
 
462
  try:
463
  base_dict = _values_to_dict(baseline_vals)
464
  whatif_dict = _values_to_dict(whatif_vals)
465
  _check_missing(base_dict, "Baseline")
466
  _check_missing(whatif_dict, "What-If")
467
 
 
 
 
 
 
 
468
  base_probs, base_ci = predict_all_outcomes(
469
  base_dict, use_calibration=True, use_signed_voting=True, n_boot_ci=500
470
  )
@@ -479,6 +641,7 @@ def run_whatif_predict(*all_values):
479
  wi_shap = create_all_shap_plots(whatif_dict, max_display=10)
480
 
481
  return (
 
482
  table_html,
483
  icon_html,
484
  *[base_shap[o] for o in SHAP_ORDER],
@@ -489,6 +652,17 @@ def run_whatif_predict(*all_values):
489
  raise gr.Error(f"{type(e).__name__}: {str(e)}")
490
 
491
 
 
 
 
 
 
 
 
 
 
 
 
492
  # ─────────────────────────────────────────────────────────────────────────────
493
  # CSS
494
  # ─────────────────────────────────────────────────────────────────────────────
@@ -516,6 +690,12 @@ custom_css = """
516
  border: none !important; color: white !important; font-weight: 600 !important;
517
  }
518
  .copy-from-predict-button:hover { background: linear-gradient(to right, #4a148c, #8e24aa) !important; }
 
 
 
 
 
 
519
  /* Ensure Outcome column text is never truncated in the results table */
520
  .output-dataframe table td:first-child,
521
  .output-dataframe table th:first-child {
@@ -523,6 +703,16 @@ custom_css = """
523
  word-break: break-word !important;
524
  min-width: 240px !important;
525
  }
 
 
 
 
 
 
 
 
 
 
526
  """
527
 
528
 
@@ -530,7 +720,6 @@ custom_css = """
530
  # BUILD UI
531
  # ─────────────────────────────────────────────────────────────────────────────
532
 
533
- # FIX 1: css moved from gr.Blocks() to demo.launch() for Gradio 6.0
534
  with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
535
  gr.Markdown("# HCT Outcome Prediction Model")
536
 
@@ -572,6 +761,14 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
572
 
573
  inputs_dict["AGE"].change(get_age_group, inputs_dict["AGE"], inputs_dict["AGEGPFF"])
574
  inputs_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, inputs_dict["VOC2YPR"], inputs_dict["VOCFRQPR"])
 
 
 
 
 
 
 
 
575
  grouped_dd.change(
576
  apply_grouped_preset, grouped_dd,
577
  [grouped_dd, p_donorf, p_condgrpf, p_condgrp_final, p_atgf, p_gvhd_final, p_hla_final],
@@ -584,7 +781,6 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
584
  gr.Markdown("---")
585
  gr.Markdown("## Prediction Results")
586
 
587
- # FIX 2: col_count renamed to column_count in Gradio 6.0
588
  output_table = gr.Dataframe(
589
  headers=["Outcome", "Probability", "95% CI"],
590
  label="Predicted Outcomes",
@@ -622,7 +818,7 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
622
  )
623
 
624
  # ══════════════════════════════════════════════════════════════════════
625
- # TAB 2 — WHAT-IF ANALYSIS
626
  # ══════════════════════════════════════════════════════════════════════
627
  with gr.Tab("Counterfactual Scenarios"):
628
  gr.Markdown(
@@ -634,6 +830,19 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
634
  "- *Copy Baseline → What-If* — mirrors the baseline into the what-if panel so you only change what differs"
635
  )
636
 
 
 
 
 
 
 
 
 
 
 
 
 
 
637
  wi_baseline_dict = {}
638
  wi_whatif_dict = {}
639
 
@@ -673,6 +882,14 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
673
 
674
  wi_baseline_dict["AGE"].change(get_age_group, wi_baseline_dict["AGE"], wi_baseline_dict["AGEGPFF"])
675
  wi_baseline_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, wi_baseline_dict["VOC2YPR"], wi_baseline_dict["VOCFRQPR"])
 
 
 
 
 
 
 
 
676
  wi_grouped_base.change(
677
  apply_grouped_preset, wi_grouped_base,
678
  [wi_grouped_base, wb_donorf, wb_condgrpf, wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
@@ -683,7 +900,7 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
683
  # Wire "Copy from Predict tab" → baseline fields
684
  copy_from_predict_btn.click(
685
  fn=copy_fields,
686
- inputs=inputs_list, # Tab 1 components
687
  outputs=wi_baseline_list,
688
  )
689
 
@@ -707,7 +924,7 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
707
  wi_grouped_wi = gr.Dropdown(
708
  choices=GROUPED_REGIMEN_CHOICES, value=None,
709
  label="Published conditioning regimen (Counterfactual)",
710
- info="Auto-fills what-if transplant fields",
711
  )
712
  ww_donorf = wi_whatif_dict["DONORF"] = make_component("DONORF", "(What-If)")
713
  wi_whatif_dict["GRAFTYPE"] = make_component("GRAFTYPE", "(What-If)")
@@ -718,12 +935,30 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
718
  ww_hla_final = wi_whatif_dict["HLA_FINAL"] = make_component("HLA_FINAL", "(What-If)")
719
 
720
  with gr.Column(scale=1):
721
- gr.Markdown("### Counterfactual — Disease Characteristics")
722
  for f in DISEASE_FEATURES:
723
  wi_whatif_dict[f] = make_component(f, "(Counterfactual)")
724
 
725
  wi_whatif_dict["AGE"].change(get_age_group, wi_whatif_dict["AGE"], wi_whatif_dict["AGEGPFF"])
726
  wi_whatif_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, wi_whatif_dict["VOC2YPR"], wi_whatif_dict["VOCFRQPR"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
727
  wi_grouped_wi.change(
728
  apply_grouped_preset, wi_grouped_wi,
729
  [wi_grouped_wi, ww_donorf, ww_condgrpf, ww_condgrp_final, ww_atgf, ww_gvhd_final, ww_hla_final],
@@ -733,20 +968,38 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
733
 
734
  copy_to_whatif_btn.click(fn=copy_fields, inputs=wi_baseline_list, outputs=wi_whatif_list)
735
 
 
 
 
 
 
 
 
 
 
 
 
 
736
  # ── RUN ──────────────────────────────────────────────────────────
737
  gr.Markdown("---")
738
- wi_run_btn = gr.Button("Run Counterfactual Comparison", elem_classes="Counterfactual-button", size="lg")
 
 
 
 
739
 
740
  # ── RESULTS ──────────────────────────────────────────────────────
741
  gr.Markdown("## Comparison Results")
742
 
 
 
 
743
  gr.Markdown("### Outcome Probability Table")
744
  wi_table_html = gr.HTML()
745
 
746
  gr.Markdown("---")
747
  gr.Markdown("### Outcome Icon Arrays — Baseline vs Counterfactual")
748
- # FIX 3: sanitize=False removed (parameter no longer exists in Gradio 6.0)
749
- gr.Markdown("<small>Each outcome shows Baseline and What-If side by side.</small>")
750
  wi_icon_html = gr.HTML()
751
 
752
  gr.Markdown("---")
@@ -768,7 +1021,7 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
768
  wi_shap_wi_dead = gr.Plot(label="Death — Counterfactual")
769
  wi_shap_wi_gf = gr.Plot(label="Graft Failure — Counterfactual")
770
  wi_shap_wi_agvhd = gr.Plot(label="Acute GvHD — Counterfactual")
771
- wi_shap_wi_cgvhd = gr.Plot(label="Chronic GvHD —Counterfactual")
772
  with gr.Row():
773
  wi_shap_wi_vocpshi = gr.Plot(label="VOC Post-HCT — Counterfactual")
774
  wi_shap_wi_efs = gr.Plot(label="Event-Free Survival — Counterfactual")
@@ -779,6 +1032,7 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
779
  fn=run_whatif_predict,
780
  inputs=wi_baseline_list + wi_whatif_list,
781
  outputs=[
 
782
  wi_table_html,
783
  wi_icon_html,
784
  # baseline SHAP (SHAP_ORDER: DEAD GF AGVHD CGVHD VOCPSHI EFS STROKEHI OS)
@@ -793,5 +1047,4 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
793
 
794
  # ─────────────────────────────────────────────────────────────────────────────
795
  if __name__ == "__main__":
796
- # FIX 1 (continued): css passed to launch() instead of gr.Blocks()
797
  demo.launch(ssr_mode=False, css=custom_css)
 
73
  ("── MISMATCHED UNRELATED / CORD BLOOD ──", "__header_mismatched_cord__"),
74
  ("Bolanos-Meade et al 2022 (Mismatched/Cord)", "Bolanos-Meade et al 2022 (Mismatched/Cord)"),
75
  ("Patel et al 2020 (Mismatched/Cord)", "Patel et al 2020 (Mismatched/Cord)"),
76
+ ("── CUSTOM ──", "__header_custom__"),
77
+ ("Custom", "Custom"),
78
  ]
79
  HEADER_VALUES = {v for _, v in GROUPED_REGIMEN_CHOICES if v.startswith("__header_")}
80
 
 
91
  "Patel et al 2020 (Mismatched/Cord)": {"CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu/TT", "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF", "HLA_FINAL": "7/8", "DONORF": "Mismatched unrelated donor or cord blood"},
92
  }
93
 
94
+ # Donor types that strictly lock HLA to 8/8
95
+ MSD_DONOR = "HLA identical sibling"
96
+ MUD_DONOR = "Matched unrelated donor"
97
+ LOCKED_88_DONORS = {MSD_DONOR, MUD_DONOR}
98
+
99
+ # Donor types that cannot use 8/8
100
+ NON_88_DONORS = {"HLA mismatch relative", "Mismatched unrelated donor or cord blood"}
101
+
102
+ # Mapping of which published regimens belong to which donor category
103
+ MSD_REGIMENS = {"Hsieh et al 2014", "Krishnamurti et al 2019", "King et al 2015", "Walters et al 1996"}
104
+ MMUD_REGIMENS = {"Bolanos-Meade et al 2022 (Mismatched/Cord)", "Patel et al 2020 (Mismatched/Cord)",
105
+ "Bolanos-Meade et al 2022 (HLA Mismatch)", "Patel et al 2020 (HLA Mismatch)"}
106
+
107
  PATIENT_FEATURES = ["AGE", "AGEGPFF", "SEX", "KPS", "RCMVPR"]
108
  DONOR_FEATURES = ["DONORF", "GRAFTYPE", "HLA_FINAL", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL"]
109
  DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
 
123
  SHAP_ORDER = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "EFS", "STROKEHI", "OS"]
124
 
125
 
126
+ # ─────────────────────────────────────────────────────────────────────────────
127
+ # CONSTRAINT VALIDATION HELPERS
128
+ # ─────────────────────────────────────────────────────────────────────────────
129
+
130
+ def _get_hla_constraint_for_donor(donor):
131
+ """
132
+ Returns (allowed_hla_choices, locked_to_88, locked_from_88)
133
+ - locked_to_88: donor must have 8/8 (MSD or MUD)
134
+ - locked_from_88: donor cannot have 8/8 (mismatch/cord)
135
+ """
136
+ if donor in LOCKED_88_DONORS:
137
+ return ["8/8"], True, False
138
+ elif donor in NON_88_DONORS:
139
+ return ["7/8", "≤ 6/8"], False, True
140
+ return HLA_FINAL_CHOICES, False, False
141
+
142
+
143
+ def _validate_counterfactual_constraints(base_dict, wi_dict):
144
+ """
145
+ Returns a list of constraint violation messages (strings).
146
+ Empty list means no violations.
147
+ """
148
+ violations = []
149
+
150
+ # 1. Sex is immutable — cannot be changed
151
+ if base_dict.get("SEX") and wi_dict.get("SEX"):
152
+ if base_dict["SEX"] != wi_dict["SEX"]:
153
+ violations.append(
154
+ "❌ Immutable feature: Sex cannot be changed in a counterfactual scenario."
155
+ )
156
+
157
+ # 2. Age cannot be decreased
158
+ base_age = base_dict.get("AGE")
159
+ wi_age = wi_dict.get("AGE")
160
+ if base_age is not None and wi_age is not None:
161
+ try:
162
+ if float(wi_age) < float(base_age):
163
+ violations.append(
164
+ f"❌ Unacceptable counterfactual: Age cannot be decreased "
165
+ f"(baseline {base_age} → what-if {wi_age})."
166
+ )
167
+ except (TypeError, ValueError):
168
+ pass
169
+
170
+ # 3. Donor type ↔ HLA matching constraint
171
+ wi_donor = wi_dict.get("DONORF")
172
+ wi_hla = wi_dict.get("HLA_FINAL")
173
+ if wi_donor and wi_hla:
174
+ if wi_donor in LOCKED_88_DONORS and wi_hla != "8/8":
175
+ violations.append(
176
+ f"❌ HLA constraint: Donor type '{wi_donor}' requires strictly 8/8 HLA matching."
177
+ )
178
+ elif wi_donor in NON_88_DONORS and wi_hla == "8/8":
179
+ violations.append(
180
+ f"❌ HLA constraint: Donor type '{wi_donor}' cannot have 8/8 HLA matching."
181
+ )
182
+
183
+ # 4. Cross-donor regimen mismatch (MMUD regimen for MSD, or MSD regimen for MMUD)
184
+ base_donor = base_dict.get("DONORF")
185
+ wi_regimen_label = wi_dict.get("__regimen_label__", "") # injected below if available
186
+
187
+ # We detect this via DONORF in the what-if vs the preset's expected donor
188
+ if wi_donor:
189
+ # Check if what-if donor is MSD but the preset fields look like an MMUD preset
190
+ wi_condgrp = wi_dict.get("CONDGRP_FINAL", "")
191
+ wi_gvhd = wi_dict.get("GVHD_FINAL", "")
192
+ # PostCY GVHD prophylaxis is characteristic of haploidentical / MMUD regimens
193
+ is_mmud_style_gvhd = wi_gvhd in {"Post-CY + siro +- MMF", "Post-CY + MMF + CNI"}
194
+ if wi_donor == MSD_DONOR and is_mmud_style_gvhd:
195
+ violations.append(
196
+ "❌ Unacceptable counterfactual: Post-cyclophosphamide GVHD prophylaxis "
197
+ "(characteristic of MMUD/haploidentical regimens) is inconsistent with "
198
+ f"donor type '{MSD_DONOR}'."
199
+ )
200
+
201
+ return violations
202
+
203
+
204
+ def _hla_update_for_donor(donor_value):
205
+ """Return gr.update for HLA dropdown based on donor selection."""
206
+ if not donor_value:
207
+ return gr.update(choices=HLA_FINAL_CHOICES, interactive=True)
208
+ if donor_value in LOCKED_88_DONORS:
209
+ return gr.update(choices=["8/8"], value="8/8", interactive=False)
210
+ elif donor_value in NON_88_DONORS:
211
+ return gr.update(choices=["7/8", "≤ 6/8"], value=None, interactive=True)
212
+ return gr.update(choices=HLA_FINAL_CHOICES, interactive=True)
213
+
214
+
215
  # ─────────────────────────────────────────────────────────────────────────────
216
  # SHARED HELPERS
217
  # ─────────────────────────────────────────────────────────────────────────────
 
260
 
261
 
262
  def apply_grouped_preset(selected_value):
263
+ """
264
+ Applies a published preset or enables Custom mode.
265
+ Returns updates for:
266
+ [dropdown, donorf, condgrpf, condgrp_final, atgf, gvhd_final, hla_final,
267
+ condgrpf_interactive, condgrp_final_interactive, atgf_interactive,
268
+ gvhd_final_interactive, donorf_interactive]
269
+ When a published preset is selected, transplant fields are locked (non-interactive).
270
+ When Custom is selected, transplant fields become interactive (except HLA which follows donor).
271
+ """
272
  if not selected_value or selected_value in HEADER_VALUES:
273
+ return (
274
+ gr.update(value=None), # dropdown reset
275
+ gr.update(), # donorf
276
+ gr.update(), # condgrpf
277
+ gr.update(), # condgrp_final
278
+ gr.update(), # atgf
279
+ gr.update(), # gvhd_final
280
+ gr.update(), # hla_final
281
+ )
282
+
283
+ if selected_value == "Custom":
284
+ # Unlock all transplant fields for custom entry
285
+ return (
286
+ gr.update(),
287
+ gr.update(interactive=True), # donorf
288
+ gr.update(interactive=True), # condgrpf
289
+ gr.update(interactive=True), # condgrp_final
290
+ gr.update(interactive=True), # atgf
291
+ gr.update(interactive=True), # gvhd_final
292
+ gr.update(interactive=True), # hla_final — will be further constrained by donor change
293
+ )
294
+
295
  preset = PUBLISHED_PRESETS.get(selected_value)
296
  if not preset:
297
+ return (gr.update(),) * 7
298
+
299
+ donor = preset["DONORF"]
300
+ hla_upd = _hla_update_for_donor(donor)
301
+ # Override the value with the preset's HLA
302
+ hla_upd_dict = hla_upd if isinstance(hla_upd, dict) else {}
303
+ hla_final_update = gr.update(value=preset["HLA_FINAL"], interactive=False, choices=["8/8"] if donor in LOCKED_88_DONORS else (["7/8", "≤ 6/8"] if donor in NON_88_DONORS else HLA_FINAL_CHOICES))
304
+
305
+ return (
306
  gr.update(),
307
+ gr.update(value=preset["DONORF"], interactive=False),
308
+ gr.update(value=preset["CONDGRPF"], interactive=False),
309
+ gr.update(value=preset["CONDGRP_FINAL"], interactive=False),
310
+ gr.update(value=preset["ATGF"], interactive=False),
311
+ gr.update(value=preset["GVHD_FINAL"], interactive=False),
312
+ hla_final_update,
313
+ )
314
 
315
 
316
  def copy_fields(*vals):
 
438
 
439
 
440
  def _build_comparison_icon_grid(base_probs: dict, wi_probs: dict) -> str:
 
 
 
 
441
  rows_html = ""
442
  for row_start in range(0, len(ICON_OUTCOMES), 3):
443
  chunk = ICON_OUTCOMES[row_start: row_start + 3]
 
535
  return header + rows + footer
536
 
537
 
538
+ def _build_violation_html(violations: list) -> str:
539
+ if not violations:
540
+ return ""
541
+ items = "".join(f"<li style='margin-bottom:6px;'>{v}</li>" for v in violations)
542
+ return (
543
+ f'<div style="background:#fff3e0;border:2px solid #e65100;border-radius:8px;'
544
+ f'padding:14px 18px;font-family:\'Segoe UI\',Arial,sans-serif;margin-bottom:12px;">'
545
+ f'<div style="font-weight:700;font-size:14px;color:#bf360c;margin-bottom:8px;">'
546
+ f'⚠️ Constraint Violations — Analysis blocked</div>'
547
+ f'<ul style="margin:0;padding-left:20px;color:#6d1f00;font-size:13px;">{items}</ul>'
548
+ f'<div style="margin-top:10px;font-size:11px;color:#888;">'
549
+ f'Please correct the above before running the counterfactual comparison.</div>'
550
+ f'</div>'
551
+ )
552
+
553
+
554
  # ─────────────────────────────────────────────────────────────────────────────
555
  # MAIN PREDICT CALLBACK
556
  # ─────────────────────────────────────────────────────────────────────────────
 
605
  def run_whatif_predict(*all_values):
606
  """
607
  Receives 2 × len(ALL_FEATURES) values: first block = baseline, second = what-if.
608
+ Returns: violation_html, table_html, icon_grid_html,
609
  base_shap×8, wi_shap×8 (order = SHAP_ORDER)
610
  """
611
  n = len(ALL_FEATURES)
612
  baseline_vals = all_values[:n]
613
  whatif_vals = all_values[n:]
614
 
615
+ # Number of outputs: 1 violation + 1 table + 1 icon + 8 base shap + 8 wi shap = 19
616
+ empty_shaps = [None] * 8
617
+
618
  try:
619
  base_dict = _values_to_dict(baseline_vals)
620
  whatif_dict = _values_to_dict(whatif_vals)
621
  _check_missing(base_dict, "Baseline")
622
  _check_missing(whatif_dict, "What-If")
623
 
624
+ # ── Constraint validation ────────────────────────────────────────────
625
+ violations = _validate_counterfactual_constraints(base_dict, whatif_dict)
626
+ if violations:
627
+ violation_html = _build_violation_html(violations)
628
+ return (violation_html, "", "", *empty_shaps, *empty_shaps)
629
+
630
  base_probs, base_ci = predict_all_outcomes(
631
  base_dict, use_calibration=True, use_signed_voting=True, n_boot_ci=500
632
  )
 
641
  wi_shap = create_all_shap_plots(whatif_dict, max_display=10)
642
 
643
  return (
644
+ "", # no violations
645
  table_html,
646
  icon_html,
647
  *[base_shap[o] for o in SHAP_ORDER],
 
652
  raise gr.Error(f"{type(e).__name__}: {str(e)}")
653
 
654
 
655
+ # ─────────────────────────────────────────────────────────────────────────────
656
+ # WHAT-IF SEX LOCK CALLBACK
657
+ # ─────────────────────────────────────────────────────────────────────────────
658
+
659
+ def lock_whatif_sex(baseline_sex):
660
+ """Mirror baseline sex into what-if and lock it."""
661
+ if baseline_sex:
662
+ return gr.update(value=baseline_sex, interactive=False)
663
+ return gr.update(interactive=False)
664
+
665
+
666
  # ─────────────────────────────────────────────────────────────────────────────
667
  # CSS
668
  # ─────────────────────────────────────────────────────────────────────────────
 
690
  border: none !important; color: white !important; font-weight: 600 !important;
691
  }
692
  .copy-from-predict-button:hover { background: linear-gradient(to right, #4a148c, #8e24aa) !important; }
693
+ .counterfactual-button {
694
+ background: linear-gradient(to right, #1976d2, #42a5f5) !important;
695
+ border: none !important; color: white !important;
696
+ font-weight: bold !important; font-size: 15px !important; padding: 12px !important;
697
+ }
698
+ .counterfactual-button:hover { background: linear-gradient(to right, #1565c0, #1e88e5) !important; }
699
  /* Ensure Outcome column text is never truncated in the results table */
700
  .output-dataframe table td:first-child,
701
  .output-dataframe table th:first-child {
 
703
  word-break: break-word !important;
704
  min-width: 240px !important;
705
  }
706
+ /* Constraint info box */
707
+ .constraint-info {
708
+ background: #e8f5e9;
709
+ border-left: 4px solid #388e3c;
710
+ padding: 8px 14px;
711
+ font-size: 12px;
712
+ color: #1b5e20;
713
+ border-radius: 4px;
714
+ margin-bottom: 8px;
715
+ }
716
  """
717
 
718
 
 
720
  # BUILD UI
721
  # ─────────────────────────────────────────────────────────────────────────────
722
 
 
723
  with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
724
  gr.Markdown("# HCT Outcome Prediction Model")
725
 
 
761
 
762
  inputs_dict["AGE"].change(get_age_group, inputs_dict["AGE"], inputs_dict["AGEGPFF"])
763
  inputs_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, inputs_dict["VOC2YPR"], inputs_dict["VOCFRQPR"])
764
+
765
+ # HLA locking on Tab 1 based on donor selection
766
+ p_donorf.change(
767
+ fn=_hla_update_for_donor,
768
+ inputs=p_donorf,
769
+ outputs=p_hla_final,
770
+ )
771
+
772
  grouped_dd.change(
773
  apply_grouped_preset, grouped_dd,
774
  [grouped_dd, p_donorf, p_condgrpf, p_condgrp_final, p_atgf, p_gvhd_final, p_hla_final],
 
781
  gr.Markdown("---")
782
  gr.Markdown("## Prediction Results")
783
 
 
784
  output_table = gr.Dataframe(
785
  headers=["Outcome", "Probability", "95% CI"],
786
  label="Predicted Outcomes",
 
818
  )
819
 
820
  # ══════════════════════════════════════════════════════════════════════
821
+ # TAB 2 — WHAT-IF / COUNTERFACTUAL ANALYSIS
822
  # ══════════════════════════════════════════════════════════════════════
823
  with gr.Tab("Counterfactual Scenarios"):
824
  gr.Markdown(
 
830
  "- *Copy Baseline → What-If* — mirrors the baseline into the what-if panel so you only change what differs"
831
  )
832
 
833
+ # Constraint legend
834
+ gr.HTML(
835
+ '<div class="constraint-info">'
836
+ '<strong>Active constraints:</strong> '
837
+ '(1) <em>Sex</em> is immutable and locked in the counterfactual. &nbsp;|&nbsp; '
838
+ '(2) Transplant characteristics must be changed as <em>grouped clinical scenarios</em> via the published regimen dropdown (or Custom). &nbsp;|&nbsp; '
839
+ '(3) <em>MSD / MUD donors</em> are locked to 8/8 HLA; mismatch/cord donors cannot use 8/8. &nbsp;|&nbsp; '
840
+ '(4) VoC = No → VOC frequency defaults to &lt;3/yr. &nbsp;|&nbsp; '
841
+ '(5) Age cannot be <em>decreased</em>. &nbsp;|&nbsp; '
842
+ '(6) MMUD-style regimens cannot be assigned to MSD donors and vice-versa.'
843
+ '</div>'
844
+ )
845
+
846
  wi_baseline_dict = {}
847
  wi_whatif_dict = {}
848
 
 
882
 
883
  wi_baseline_dict["AGE"].change(get_age_group, wi_baseline_dict["AGE"], wi_baseline_dict["AGEGPFF"])
884
  wi_baseline_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, wi_baseline_dict["VOC2YPR"], wi_baseline_dict["VOCFRQPR"])
885
+
886
+ # Baseline donor → HLA locking
887
+ wb_donorf.change(
888
+ fn=_hla_update_for_donor,
889
+ inputs=wb_donorf,
890
+ outputs=wb_hla_final,
891
+ )
892
+
893
  wi_grouped_base.change(
894
  apply_grouped_preset, wi_grouped_base,
895
  [wi_grouped_base, wb_donorf, wb_condgrpf, wb_condgrp_final, wb_atgf, wb_gvhd_final, wb_hla_final],
 
900
  # Wire "Copy from Predict tab" → baseline fields
901
  copy_from_predict_btn.click(
902
  fn=copy_fields,
903
+ inputs=inputs_list,
904
  outputs=wi_baseline_list,
905
  )
906
 
 
924
  wi_grouped_wi = gr.Dropdown(
925
  choices=GROUPED_REGIMEN_CHOICES, value=None,
926
  label="Published conditioning regimen (Counterfactual)",
927
+ info="Auto-fills what-if transplant fields. Use 'Custom' to enter manually.",
928
  )
929
  ww_donorf = wi_whatif_dict["DONORF"] = make_component("DONORF", "(What-If)")
930
  wi_whatif_dict["GRAFTYPE"] = make_component("GRAFTYPE", "(What-If)")
 
935
  ww_hla_final = wi_whatif_dict["HLA_FINAL"] = make_component("HLA_FINAL", "(What-If)")
936
 
937
  with gr.Column(scale=1):
938
+ gr.Markdown("### Counterfactual — Disease Characteristics")
939
  for f in DISEASE_FEATURES:
940
  wi_whatif_dict[f] = make_component(f, "(Counterfactual)")
941
 
942
  wi_whatif_dict["AGE"].change(get_age_group, wi_whatif_dict["AGE"], wi_whatif_dict["AGEGPFF"])
943
  wi_whatif_dict["VOC2YPR"].change(vocfrqpr_from_voc2ypr, wi_whatif_dict["VOC2YPR"], wi_whatif_dict["VOCFRQPR"])
944
+
945
+ # ── CONSTRAINT: Sex is immutable — mirror baseline sex and lock what-if sex ──
946
+ wi_baseline_dict["SEX"].change(
947
+ fn=lock_whatif_sex,
948
+ inputs=wi_baseline_dict["SEX"],
949
+ outputs=wi_whatif_dict["SEX"],
950
+ )
951
+ # Also lock on copy
952
+ wi_whatif_dict["SEX"].interactive = False
953
+
954
+ # ── CONSTRAINT: Counterfactual donor → HLA locking ──────────────
955
+ ww_donorf.change(
956
+ fn=_hla_update_for_donor,
957
+ inputs=ww_donorf,
958
+ outputs=ww_hla_final,
959
+ )
960
+
961
+ # Preset selection locks transplant fields; Custom unlocks them
962
  wi_grouped_wi.change(
963
  apply_grouped_preset, wi_grouped_wi,
964
  [wi_grouped_wi, ww_donorf, ww_condgrpf, ww_condgrp_final, ww_atgf, ww_gvhd_final, ww_hla_final],
 
968
 
969
  copy_to_whatif_btn.click(fn=copy_fields, inputs=wi_baseline_list, outputs=wi_whatif_list)
970
 
971
+ # After copying baseline → whatif, re-apply sex lock and HLA lock
972
+ copy_to_whatif_btn.click(
973
+ fn=lock_whatif_sex,
974
+ inputs=wi_baseline_dict["SEX"],
975
+ outputs=wi_whatif_dict["SEX"],
976
+ )
977
+ copy_to_whatif_btn.click(
978
+ fn=_hla_update_for_donor,
979
+ inputs=ww_donorf,
980
+ outputs=ww_hla_final,
981
+ )
982
+
983
  # ── RUN ──────────────────────────────────────────────────────────
984
  gr.Markdown("---")
985
+ wi_run_btn = gr.Button(
986
+ "Run Counterfactual Comparison",
987
+ elem_classes="counterfactual-button",
988
+ size="lg",
989
+ )
990
 
991
  # ── RESULTS ──────────────────────────────────────────────────────
992
  gr.Markdown("## Comparison Results")
993
 
994
+ # Constraint violation banner (shown if validation fails)
995
+ wi_violation_html = gr.HTML()
996
+
997
  gr.Markdown("### Outcome Probability Table")
998
  wi_table_html = gr.HTML()
999
 
1000
  gr.Markdown("---")
1001
  gr.Markdown("### Outcome Icon Arrays — Baseline vs Counterfactual")
1002
+ gr.Markdown("<small>Each outcome shows Baseline and What-If side by side.</small>")
 
1003
  wi_icon_html = gr.HTML()
1004
 
1005
  gr.Markdown("---")
 
1021
  wi_shap_wi_dead = gr.Plot(label="Death — Counterfactual")
1022
  wi_shap_wi_gf = gr.Plot(label="Graft Failure — Counterfactual")
1023
  wi_shap_wi_agvhd = gr.Plot(label="Acute GvHD — Counterfactual")
1024
+ wi_shap_wi_cgvhd = gr.Plot(label="Chronic GvHD — Counterfactual")
1025
  with gr.Row():
1026
  wi_shap_wi_vocpshi = gr.Plot(label="VOC Post-HCT — Counterfactual")
1027
  wi_shap_wi_efs = gr.Plot(label="Event-Free Survival — Counterfactual")
 
1032
  fn=run_whatif_predict,
1033
  inputs=wi_baseline_list + wi_whatif_list,
1034
  outputs=[
1035
+ wi_violation_html,
1036
  wi_table_html,
1037
  wi_icon_html,
1038
  # baseline SHAP (SHAP_ORDER: DEAD GF AGVHD CGVHD VOCPSHI EFS STROKEHI OS)
 
1047
 
1048
  # ─────────────────────────────────────────────────────────────────────────────
1049
  if __name__ == "__main__":
 
1050
  demo.launch(ssr_mode=False, css=custom_css)