shivapriyasom commited on
Commit
54614a0
Β·
verified Β·
1 Parent(s): f62406f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -38
app.py CHANGED
@@ -6,11 +6,14 @@ from inference import (
6
  FEATURE_NAMES,
7
  REPORTING_OUTCOMES,
8
  OUTCOME_DESCRIPTIONS,
 
 
9
  predict_with_comparison,
10
  create_all_shap_plots,
11
- icon_array,
12
  )
13
 
 
14
  # ---------------------------------------------------------------------------
15
  # Choice lists
16
  # ---------------------------------------------------------------------------
@@ -52,6 +55,7 @@ SCATXRSN_CHOICES = [
52
  "Hodgkin lymphoma",
53
  ]
54
 
 
55
  # ---------------------------------------------------------------------------
56
  # Grouped published-regimen dropdown
57
  # ---------------------------------------------------------------------------
@@ -132,6 +136,7 @@ PUBLISHED_PRESETS = {
132
  },
133
  }
134
 
 
135
  # ---------------------------------------------------------------------------
136
  # Feature groupings
137
  # ---------------------------------------------------------------------------
@@ -142,6 +147,7 @@ DONOR_FEATURES = ["DONORF", "GRAFTYPE", "HLA_FINAL",
142
  DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
143
  ALL_FEATURES = PATIENT_FEATURES + DONOR_FEATURES + DISEASE_FEATURES
144
 
 
145
  # ---------------------------------------------------------------------------
146
  # Utility callbacks
147
  # ---------------------------------------------------------------------------
@@ -173,6 +179,7 @@ def vocfrqpr_from_voc2ypr(voc_status):
173
 
174
 
175
  def apply_grouped_preset(selected_value):
 
176
  if not selected_value or selected_value in HEADER_VALUES:
177
  return [gr.update(value=None)] + [gr.update()] * 6
178
 
@@ -181,7 +188,7 @@ def apply_grouped_preset(selected_value):
181
  return [gr.update()] * 7
182
 
183
  return [
184
- gr.update(),
185
  gr.update(value=preset["DONORF"]),
186
  gr.update(value=preset["CONDGRPF"]),
187
  gr.update(value=preset["CONDGRP_FINAL"]),
@@ -190,6 +197,7 @@ def apply_grouped_preset(selected_value):
190
  gr.update(value=preset["HLA_FINAL"]),
191
  ]
192
 
 
193
  # ---------------------------------------------------------------------------
194
  # Component factory
195
  # ---------------------------------------------------------------------------
@@ -238,6 +246,7 @@ def make_component(name: str):
238
  else:
239
  return gr.Textbox(label=name)
240
 
 
241
  # ---------------------------------------------------------------------------
242
  # Prediction callback
243
  # ---------------------------------------------------------------------------
@@ -255,7 +264,7 @@ def predict_gradio(*values):
255
  f"Please fill in all fields before predicting.\nMissing: {', '.join(missing)}"
256
  )
257
 
258
- calibrated, _ = predict_with_comparison(user_vals)
259
  calibrated_probs, calibrated_intervals = calibrated
260
 
261
  rows = []
@@ -270,19 +279,12 @@ def predict_gradio(*values):
270
  })
271
  df = pd.DataFrame(rows)
272
 
273
- shap_plots = create_all_shap_plots(user_vals, max_display=10)
274
-
275
- icon_outcomes = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"]
276
- icon_plots = {o: icon_array(calibrated_probs[o], o) for o in icon_outcomes}
277
 
278
  return (
279
  df,
280
- icon_plots["DEAD"],
281
- icon_plots["GF"],
282
- icon_plots["AGVHD"],
283
- icon_plots["CGVHD"],
284
- icon_plots["VOCPSHI"],
285
- icon_plots["STROKEHI"],
286
  shap_plots["DEAD"],
287
  shap_plots["GF"],
288
  shap_plots["AGVHD"],
@@ -301,6 +303,7 @@ def predict_gradio(*values):
301
  print("=" * 60)
302
  raise gr.Error(f"{type(e).__name__}: {str(e)}\n\nSee terminal for full traceback.")
303
 
 
304
  # ---------------------------------------------------------------------------
305
  # CSS
306
  # ---------------------------------------------------------------------------
@@ -335,13 +338,13 @@ with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
335
  inputs_dict = {}
336
 
337
  with gr.Row():
338
- # Patient
339
  with gr.Column(scale=1):
340
  gr.Markdown("### Patient Characteristics")
341
  for f in PATIENT_FEATURES:
342
  inputs_dict[f] = make_component(f)
343
 
344
- # Transplant
345
  with gr.Column(scale=1):
346
  gr.Markdown("### Transplant Characteristics")
347
 
@@ -361,13 +364,13 @@ with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
361
  gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
362
  hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL")
363
 
364
- # Disease
365
  with gr.Column(scale=1):
366
  gr.Markdown("### Disease Characteristics")
367
  for f in DISEASE_FEATURES:
368
  inputs_dict[f] = make_component(f)
369
 
370
- # reactive callbacks
371
  inputs_dict["AGE"].change(
372
  fn=get_age_group,
373
  inputs=inputs_dict["AGE"],
@@ -380,6 +383,7 @@ with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
380
  outputs=inputs_dict["VOCFRQPR"],
381
  )
382
 
 
383
  grouped_regimen_dropdown.change(
384
  fn=apply_grouped_preset,
385
  inputs=grouped_regimen_dropdown,
@@ -402,27 +406,18 @@ with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
402
  headers=["Outcome", "Probability", "95% CI"],
403
  label="",
404
  row_count=(len(REPORTING_OUTCOMES), "dynamic"),
405
- column_count=(3, "fixed"),
406
  )
407
 
408
  gr.Markdown("---")
409
- gr.Markdown("## Icon Arrays")
410
-
411
- with gr.Row():
412
- with gr.Column():
413
- icon_dead = gr.Plot(label="Death")
414
- with gr.Column():
415
- icon_gf = gr.Plot(label="Graft Failure")
416
- with gr.Column():
417
- icon_agvhd = gr.Plot(label="Acute Graft-versus-Host Disease")
418
 
419
- with gr.Row():
420
- with gr.Column():
421
- icon_cgvhd = gr.Plot(label="Chronic Graft-versus-Host Disease")
422
- with gr.Column():
423
- icon_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
424
- with gr.Column():
425
- icon_stroke = gr.Plot(label="Stroke Post-HCT")
426
 
427
  gr.Markdown("---")
428
  gr.Markdown("## SHAP - Feature Importance")
@@ -452,13 +447,12 @@ with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
452
  inputs=inputs_list,
453
  outputs=[
454
  output_table,
455
- icon_dead, icon_gf, icon_agvhd, icon_cgvhd, icon_vocpshi, icon_stroke,
456
  shap_dead, shap_gf, shap_agvhd, shap_cgvhd,
457
  shap_vocpshi, shap_efs, shap_stroke, shap_os,
458
  ],
459
  )
460
 
 
461
  if __name__ == "__main__":
462
- demo.launch(
463
- ssr_mode=False,
464
- )
 
6
  FEATURE_NAMES,
7
  REPORTING_OUTCOMES,
8
  OUTCOME_DESCRIPTIONS,
9
+ OUTCOMES,
10
+ SHAP_OUTCOMES,
11
  predict_with_comparison,
12
  create_all_shap_plots,
13
+ create_all_icon_arrays,
14
  )
15
 
16
+
17
  # ---------------------------------------------------------------------------
18
  # Choice lists
19
  # ---------------------------------------------------------------------------
 
55
  "Hodgkin lymphoma",
56
  ]
57
 
58
+
59
  # ---------------------------------------------------------------------------
60
  # Grouped published-regimen dropdown
61
  # ---------------------------------------------------------------------------
 
136
  },
137
  }
138
 
139
+
140
  # ---------------------------------------------------------------------------
141
  # Feature groupings
142
  # ---------------------------------------------------------------------------
 
147
  DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
148
  ALL_FEATURES = PATIENT_FEATURES + DONOR_FEATURES + DISEASE_FEATURES
149
 
150
+
151
  # ---------------------------------------------------------------------------
152
  # Utility callbacks
153
  # ---------------------------------------------------------------------------
 
179
 
180
 
181
  def apply_grouped_preset(selected_value):
182
+ # Header row clicked β€” reset dropdown, leave all fields unchanged
183
  if not selected_value or selected_value in HEADER_VALUES:
184
  return [gr.update(value=None)] + [gr.update()] * 6
185
 
 
188
  return [gr.update()] * 7
189
 
190
  return [
191
+ gr.update(), # leave dropdown showing selection
192
  gr.update(value=preset["DONORF"]),
193
  gr.update(value=preset["CONDGRPF"]),
194
  gr.update(value=preset["CONDGRP_FINAL"]),
 
197
  gr.update(value=preset["HLA_FINAL"]),
198
  ]
199
 
200
+
201
  # ---------------------------------------------------------------------------
202
  # Component factory
203
  # ---------------------------------------------------------------------------
 
246
  else:
247
  return gr.Textbox(label=name)
248
 
249
+
250
  # ---------------------------------------------------------------------------
251
  # Prediction callback
252
  # ---------------------------------------------------------------------------
 
264
  f"Please fill in all fields before predicting.\nMissing: {', '.join(missing)}"
265
  )
266
 
267
+ calibrated, uncalibrated = predict_with_comparison(user_vals)
268
  calibrated_probs, calibrated_intervals = calibrated
269
 
270
  rows = []
 
279
  })
280
  df = pd.DataFrame(rows)
281
 
282
+ shap_plots = create_all_shap_plots(user_vals, max_display=10)
283
+ icon_arrays = create_all_icon_arrays(calibrated_probs)
 
 
284
 
285
  return (
286
  df,
287
+ icon_arrays["__grid__"], # single combined 4Γ—2 grid HTML
 
 
 
 
 
288
  shap_plots["DEAD"],
289
  shap_plots["GF"],
290
  shap_plots["AGVHD"],
 
303
  print("=" * 60)
304
  raise gr.Error(f"{type(e).__name__}: {str(e)}\n\nSee terminal for full traceback.")
305
 
306
+
307
  # ---------------------------------------------------------------------------
308
  # CSS
309
  # ---------------------------------------------------------------------------
 
338
  inputs_dict = {}
339
 
340
  with gr.Row():
341
+ # ── Patient Characteristics ──────────────────────────────────────
342
  with gr.Column(scale=1):
343
  gr.Markdown("### Patient Characteristics")
344
  for f in PATIENT_FEATURES:
345
  inputs_dict[f] = make_component(f)
346
 
347
+ # ── Transplant Characteristics ───────────────────────────────────
348
  with gr.Column(scale=1):
349
  gr.Markdown("### Transplant Characteristics")
350
 
 
364
  gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
365
  hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL")
366
 
367
+ # ── Disease Characteristics ──────────────────────────────────────
368
  with gr.Column(scale=1):
369
  gr.Markdown("### Disease Characteristics")
370
  for f in DISEASE_FEATURES:
371
  inputs_dict[f] = make_component(f)
372
 
373
+ # ── Reactive callbacks ───────────────────────────────────────────────
374
  inputs_dict["AGE"].change(
375
  fn=get_age_group,
376
  inputs=inputs_dict["AGE"],
 
383
  outputs=inputs_dict["VOCFRQPR"],
384
  )
385
 
386
+ # outputs[0] is the dropdown itself so clicking a header resets it to None
387
  grouped_regimen_dropdown.change(
388
  fn=apply_grouped_preset,
389
  inputs=grouped_regimen_dropdown,
 
406
  headers=["Outcome", "Probability", "95% CI"],
407
  label="",
408
  row_count=(len(REPORTING_OUTCOMES), "dynamic"),
409
+ col_count=(3, "fixed"),
410
  )
411
 
412
  gr.Markdown("---")
413
+ gr.Markdown("## Outcome Probability β€” Icon Arrays")
414
+ gr.Markdown(
415
+ "_Each figure represents 1 patient out of 100 with similar characteristics. "
416
+ "Colored figures indicate the predicted probability of each outcome._"
417
+ )
 
 
 
 
418
 
419
+ # Single HTML component holds the entire 4Γ—2 grid
420
+ icon_array_grid = gr.HTML(label="")
 
 
 
 
 
421
 
422
  gr.Markdown("---")
423
  gr.Markdown("## SHAP - Feature Importance")
 
447
  inputs=inputs_list,
448
  outputs=[
449
  output_table,
450
+ icon_array_grid,
451
  shap_dead, shap_gf, shap_agvhd, shap_cgvhd,
452
  shap_vocpshi, shap_efs, shap_stroke, shap_os,
453
  ],
454
  )
455
 
456
+
457
  if __name__ == "__main__":
458
+ demo.launch(ssr_mode=False)