shivapriyasom commited on
Commit
f62406f
·
verified ·
1 Parent(s): 4a16c34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -1
app.py CHANGED
@@ -332,4 +332,133 @@ with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
332
  """
333
  )
334
 
335
- inputs_dict =
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  """
333
  )
334
 
335
+ inputs_dict = {}
336
+
337
+ with gr.Row():
338
+ # Patient
339
+ with gr.Column(scale=1):
340
+ gr.Markdown("### Patient Characteristics")
341
+ for f in PATIENT_FEATURES:
342
+ inputs_dict[f] = make_component(f)
343
+
344
+ # Transplant
345
+ with gr.Column(scale=1):
346
+ gr.Markdown("### Transplant Characteristics")
347
+
348
+ grouped_regimen_dropdown = gr.Dropdown(
349
+ choices=GROUPED_REGIMEN_CHOICES,
350
+ value=None,
351
+ label="Published conditioning regimen",
352
+ info="Auto-fills Donor Type, Conditioning Intensity, Conditioning Regimen, "
353
+ "Serotherapy and GVHD Prophylaxis",
354
+ )
355
+
356
+ donorf_comp = inputs_dict["DONORF"] = make_component("DONORF")
357
+ inputs_dict["GRAFTYPE"] = make_component("GRAFTYPE")
358
+ condgrpf = inputs_dict["CONDGRPF"] = make_component("CONDGRPF")
359
+ condgrp_final = inputs_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL")
360
+ atgf = inputs_dict["ATGF"] = make_component("ATGF")
361
+ gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
362
+ hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL")
363
+
364
+ # Disease
365
+ with gr.Column(scale=1):
366
+ gr.Markdown("### Disease Characteristics")
367
+ for f in DISEASE_FEATURES:
368
+ inputs_dict[f] = make_component(f)
369
+
370
+ # reactive callbacks
371
+ inputs_dict["AGE"].change(
372
+ fn=get_age_group,
373
+ inputs=inputs_dict["AGE"],
374
+ outputs=inputs_dict["AGEGPFF"],
375
+ )
376
+
377
+ inputs_dict["VOC2YPR"].change(
378
+ fn=vocfrqpr_from_voc2ypr,
379
+ inputs=inputs_dict["VOC2YPR"],
380
+ outputs=inputs_dict["VOCFRQPR"],
381
+ )
382
+
383
+ grouped_regimen_dropdown.change(
384
+ fn=apply_grouped_preset,
385
+ inputs=grouped_regimen_dropdown,
386
+ outputs=[
387
+ grouped_regimen_dropdown,
388
+ donorf_comp, condgrpf, condgrp_final, atgf, gvhd_final, hla_final,
389
+ ],
390
+ )
391
+
392
+ inputs_list = [inputs_dict[f] for f in ALL_FEATURES]
393
+
394
+ btn = gr.Button("Predict", elem_classes="predict-button", size="lg")
395
+
396
+ gr.Markdown("---")
397
+ gr.Markdown("## Prediction Results")
398
+ gr.Markdown("### Predicted Outcomes")
399
+
400
+ with gr.Column():
401
+ output_table = gr.Dataframe(
402
+ headers=["Outcome", "Probability", "95% CI"],
403
+ label="",
404
+ row_count=(len(REPORTING_OUTCOMES), "dynamic"),
405
+ column_count=(3, "fixed"),
406
+ )
407
+
408
+ gr.Markdown("---")
409
+ gr.Markdown("## Icon Arrays")
410
+
411
+ with gr.Row():
412
+ with gr.Column():
413
+ icon_dead = gr.Plot(label="Death")
414
+ with gr.Column():
415
+ icon_gf = gr.Plot(label="Graft Failure")
416
+ with gr.Column():
417
+ icon_agvhd = gr.Plot(label="Acute Graft-versus-Host Disease")
418
+
419
+ with gr.Row():
420
+ with gr.Column():
421
+ icon_cgvhd = gr.Plot(label="Chronic Graft-versus-Host Disease")
422
+ with gr.Column():
423
+ icon_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
424
+ with gr.Column():
425
+ icon_stroke = gr.Plot(label="Stroke Post-HCT")
426
+
427
+ gr.Markdown("---")
428
+ gr.Markdown("## SHAP - Feature Importance")
429
+
430
+ with gr.Row():
431
+ with gr.Column():
432
+ shap_dead = gr.Plot(label="Death")
433
+ with gr.Column():
434
+ shap_gf = gr.Plot(label="Graft Failure")
435
+ with gr.Column():
436
+ shap_agvhd = gr.Plot(label="Acute Graft-versus-Host Disease")
437
+ with gr.Column():
438
+ shap_cgvhd = gr.Plot(label="Chronic Graft-versus-Host Disease")
439
+
440
+ with gr.Row():
441
+ with gr.Column():
442
+ shap_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
443
+ with gr.Column():
444
+ shap_efs = gr.Plot(label="Event-Free Survival")
445
+ with gr.Column():
446
+ shap_stroke = gr.Plot(label="Stroke Post-HCT")
447
+ with gr.Column():
448
+ shap_os = gr.Plot(label="Overall Survival")
449
+
450
+ btn.click(
451
+ fn=predict_gradio,
452
+ inputs=inputs_list,
453
+ outputs=[
454
+ output_table,
455
+ icon_dead, icon_gf, icon_agvhd, icon_cgvhd, icon_vocpshi, icon_stroke,
456
+ shap_dead, shap_gf, shap_agvhd, shap_cgvhd,
457
+ shap_vocpshi, shap_efs, shap_stroke, shap_os,
458
+ ],
459
+ )
460
+
461
+ if __name__ == "__main__":
462
+ demo.launch(
463
+ ssr_mode=False,
464
+ )