shivapriyasom commited on
Commit
0cb567b
·
verified ·
1 Parent(s): c24ded0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -21
app.py CHANGED
@@ -10,7 +10,7 @@ from inference import (
10
  SHAP_OUTCOMES,
11
  predict_with_comparison,
12
  create_all_shap_plots,
13
- create_all_pie_charts,
14
  )
15
 
16
 
@@ -179,7 +179,6 @@ def vocfrqpr_from_voc2ypr(voc_status):
179
 
180
 
181
  def apply_grouped_preset(selected_value):
182
- # Header row clicked — reset dropdown, leave all fields unchanged
183
  if not selected_value or selected_value in HEADER_VALUES:
184
  return [gr.update(value=None)] + [gr.update()] * 6
185
 
@@ -188,7 +187,7 @@ def apply_grouped_preset(selected_value):
188
  return [gr.update()] * 7
189
 
190
  return [
191
- gr.update(), # leave dropdown showing selection
192
  gr.update(value=preset["DONORF"]),
193
  gr.update(value=preset["CONDGRPF"]),
194
  gr.update(value=preset["CONDGRP_FINAL"]),
@@ -280,16 +279,19 @@ def predict_gradio(*values):
280
  df = pd.DataFrame(rows)
281
 
282
  shap_plots = create_all_shap_plots(user_vals, max_display=10)
283
- pie_charts = create_all_pie_charts(calibrated_probs)
 
 
 
284
 
285
  return (
286
  df,
287
- pie_charts["DEAD"],
288
- pie_charts["GF"],
289
- pie_charts["AGVHD"],
290
- pie_charts["CGVHD"],
291
- pie_charts["VOCPSHI"],
292
- pie_charts["STROKEHI"],
293
  shap_plots["DEAD"],
294
  shap_plots["GF"],
295
  shap_plots["AGVHD"],
@@ -331,7 +333,7 @@ custom_css = """
331
  # Gradio UI
332
  # ---------------------------------------------------------------------------
333
 
334
- with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
335
  gr.Markdown(
336
  """
337
  # HCT Outcome Prediction Model
@@ -388,7 +390,6 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
388
  outputs=inputs_dict["VOCFRQPR"],
389
  )
390
 
391
- # outputs[0] is the dropdown itself so clicking a header resets it to None
392
  grouped_regimen_dropdown.change(
393
  fn=apply_grouped_preset,
394
  inputs=grouped_regimen_dropdown,
@@ -415,23 +416,23 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
415
  )
416
 
417
  gr.Markdown("---")
418
- gr.Markdown("## Pie Charts")
419
 
420
  with gr.Row():
421
  with gr.Column():
422
- pie_dead = gr.Plot(label="Death")
423
  with gr.Column():
424
- pie_gf = gr.Plot(label="Graft Failure")
425
  with gr.Column():
426
- pie_agvhd = gr.Plot(label="Acute Graft-versus-Host Disease")
427
 
428
  with gr.Row():
429
  with gr.Column():
430
- pie_cgvhd = gr.Plot(label="Chronic Graft-versus-Host Disease")
431
  with gr.Column():
432
- pie_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
433
  with gr.Column():
434
- pie_stroke = gr.Plot(label="Stroke Post-HCT")
435
 
436
  gr.Markdown("---")
437
  gr.Markdown("## SHAP - Feature Importance")
@@ -461,7 +462,7 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
461
  inputs=inputs_list,
462
  outputs=[
463
  output_table,
464
- pie_dead, pie_gf, pie_agvhd, pie_cgvhd, pie_vocpshi, pie_stroke,
465
  shap_dead, shap_gf, shap_agvhd, shap_cgvhd,
466
  shap_vocpshi, shap_efs, shap_stroke, shap_os,
467
  ],
@@ -469,4 +470,4 @@ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
469
 
470
 
471
  if __name__ == "__main__":
472
- demo.launch(ssr_mode=False, css=custom_css)
 
10
  SHAP_OUTCOMES,
11
  predict_with_comparison,
12
  create_all_shap_plots,
13
+ icon_array,
14
  )
15
 
16
 
 
179
 
180
 
181
  def apply_grouped_preset(selected_value):
 
182
  if not selected_value or selected_value in HEADER_VALUES:
183
  return [gr.update(value=None)] + [gr.update()] * 6
184
 
 
187
  return [gr.update()] * 7
188
 
189
  return [
190
+ gr.update(),
191
  gr.update(value=preset["DONORF"]),
192
  gr.update(value=preset["CONDGRPF"]),
193
  gr.update(value=preset["CONDGRP_FINAL"]),
 
279
  df = pd.DataFrame(rows)
280
 
281
  shap_plots = create_all_shap_plots(user_vals, max_display=10)
282
+
283
+ # Icon arrays for each outcome
284
+ icon_outcomes = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"]
285
+ icon_plots = {o: icon_array(calibrated_probs[o], o) for o in icon_outcomes}
286
 
287
  return (
288
  df,
289
+ icon_plots["DEAD"],
290
+ icon_plots["GF"],
291
+ icon_plots["AGVHD"],
292
+ icon_plots["CGVHD"],
293
+ icon_plots["VOCPSHI"],
294
+ icon_plots["STROKEHI"],
295
  shap_plots["DEAD"],
296
  shap_plots["GF"],
297
  shap_plots["AGVHD"],
 
333
  # Gradio UI
334
  # ---------------------------------------------------------------------------
335
 
336
+ with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
337
  gr.Markdown(
338
  """
339
  # HCT Outcome Prediction Model
 
390
  outputs=inputs_dict["VOCFRQPR"],
391
  )
392
 
 
393
  grouped_regimen_dropdown.change(
394
  fn=apply_grouped_preset,
395
  inputs=grouped_regimen_dropdown,
 
416
  )
417
 
418
  gr.Markdown("---")
419
+ gr.Markdown("## Icon Arrays")
420
 
421
  with gr.Row():
422
  with gr.Column():
423
+ icon_dead = gr.Plot(label="Death")
424
  with gr.Column():
425
+ icon_gf = gr.Plot(label="Graft Failure")
426
  with gr.Column():
427
+ icon_agvhd = gr.Plot(label="Acute Graft-versus-Host Disease")
428
 
429
  with gr.Row():
430
  with gr.Column():
431
+ icon_cgvhd = gr.Plot(label="Chronic Graft-versus-Host Disease")
432
  with gr.Column():
433
+ icon_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
434
  with gr.Column():
435
+ icon_stroke = gr.Plot(label="Stroke Post-HCT")
436
 
437
  gr.Markdown("---")
438
  gr.Markdown("## SHAP - Feature Importance")
 
462
  inputs=inputs_list,
463
  outputs=[
464
  output_table,
465
+ icon_dead, icon_gf, icon_agvhd, icon_cgvhd, icon_vocpshi, icon_stroke,
466
  shap_dead, shap_gf, shap_agvhd, shap_cgvhd,
467
  shap_vocpshi, shap_efs, shap_stroke, shap_os,
468
  ],
 
470
 
471
 
472
  if __name__ == "__main__":
473
+ demo.launch(ssr_mode=False)