Spaces:
Build error
Build error
| import gradio as gr | |
| import pandas as pd | |
| import traceback | |
| from inference import ( | |
| FEATURE_NAMES, | |
| REPORTING_OUTCOMES, | |
| OUTCOME_DESCRIPTIONS, | |
| OUTCOMES, | |
| SHAP_OUTCOMES, | |
| predict_with_comparison, | |
| create_all_shap_plots, | |
| icon_array, | |
| ) | |
# ---------------------------------------------------------------------------
# Choice lists
# ---------------------------------------------------------------------------
# Dropdown options for each categorical input (see make_component). The
# selected string is passed as-is into predict_with_comparison via
# predict_gradio, so these presumably must match the category labels the
# ``inference`` module expects exactly — confirm before editing any entry.
# NOTE(review): "≥"/"≤" below were restored from the mis-encoded "β₯"/"β€";
# confirm ``inference`` uses the same characters.
AGEGPFF_CHOICES = ["<=10", "11-17", "18-29", "30-49", ">=50"]
SEX_CHOICES = ["Male", "Female"]
KPS_CHOICES = ["<90", "≥ 90"]
DONORF_CHOICES = [
    "HLA identical sibling",
    "HLA mismatch relative",
    "Matched unrelated donor",
    "Mismatched unrelated donor or cord blood",
]
GRAFTYPE_CHOICES = ["Bone marrow", "Peripheral blood", "Cord blood"]
CONDGRPF_CHOICES = ["MAC", "RIC", "NMA"]
CONDGRP_FINAL_CHOICES = [
    "TBI/Cy", "TBI/Cy/Flu", "TBI/Cy/Flu/TT", "TBI/Mel", "TBI/Flu",
    "TBI alone (300/400/600cGy)", "Bu/Cy", "Bu/Mel", "Flu/Bu/TT",
    "Flu/Bu", "Flu/Mel/TT", "Flu/Mel", "Cy/Flu", "Treosulfan",
    "Cy alone", "Flud", "TLI",
]
ATGF_CHOICES = ["ATG", "Alemtuzumab", "None"]
GVHD_FINAL_CHOICES = [
    "Ex-vivo T-cell depletion", "CD34 selection", "Post-CY + siro +- MMF",
    "Post-CY + MMF + CNI", "CNI + MMF", "CNI + MTX", "CNI alone",
    "CNI + siro", "Siro alone", "MMF + MTX", "MMF + siro", "MMF alone",
    "MTX alone", "MTX + siro",
]
HLA_FINAL_CHOICES = ["8/8", "7/8", "≤ 6/8"]
RCMVPR_CHOICES = ["Negative", "Positive"]
EXCHTFPR_CHOICES = ["No", "Yes"]
VOC2YPR_CHOICES = ["No", "Yes"]
VOCFRQPR_CHOICES = ["< 3/yr", "≥ 3/yr"]
SCATXRSN_CHOICES = [
    "CNS event", "Acute chest Syndrome", "Recurrent vaso-occlusive pain",
    "Recurrent priapism", "Excessive transfusion requirements/iron overload",
    "Cardio-pulmonary", "Chronic transfusion", "Asymptomatic",
    "Renal insufficiency", "Splenic sequestration", "Avascular necrosis",
    "Hodgkin lymphoma",
]
# ---------------------------------------------------------------------------
# Grouped published-regimen dropdown
# ---------------------------------------------------------------------------
# (title, header_key, studies) per donor-type group, in display order. Each
# study string is used as both the dropdown label and its value, and the value
# is the key looked up in PUBLISHED_PRESETS.
_REGIMEN_GROUPS = [
    ("HLA IDENTICAL", "hla_identical", [
        "Hsieh et al 2014",
        "Krishnamurti et al 2019",
        "King et al 2015",
        "Walters et al 1996",
    ]),
    ("HLA MISMATCHED", "hla_mismatched", [
        "Bolanos-Meade et al 2022 (HLA Mismatch)",
        "Patel et al 2020 (HLA Mismatch)",
    ]),
    ("MATCHED UNRELATED", "matched_unrelated", [
        "L Krishnamurti et al 2019",
        "Shenoy et al 2016",
    ]),
    ("MISMATCHED UNRELATED / CORD BLOOD", "mismatched_cord", [
        "Bolanos-Meade et al 2022 (Mismatched/Cord)",
        "Patel et al 2020 (Mismatched/Cord)",
    ]),
]


def _build_grouped_choices(groups):
    """Flatten regimen groups into Gradio ``(label, value)`` dropdown choices.

    Each group contributes one header entry (``"── TITLE ──"`` with value
    ``"__header_<key>__"``) followed by one ``(study, study)`` entry per study.
    """
    choices = []
    for title, key, studies in groups:
        choices.append((f"── {title} ──", f"__header_{key}__"))
        choices.extend((study, study) for study in studies)
    return choices


GROUPED_REGIMEN_CHOICES = _build_grouped_choices(_REGIMEN_GROUPS)
# Values of the group-header entries; apply_grouped_preset treats a header
# selection like an empty one and clears the dropdown.
HEADER_VALUES = {v for _, v in GROUPED_REGIMEN_CHOICES if v.startswith("__header_")}
def _preset(donorf, condgrpf, condgrp_final, atgf, gvhd_final, hla_final):
    """Return one published regimen's transplant settings keyed by feature name."""
    return {
        "CONDGRPF": condgrpf,
        "CONDGRP_FINAL": condgrp_final,
        "ATGF": atgf,
        "GVHD_FINAL": gvhd_final,
        "HLA_FINAL": hla_final,
        "DONORF": donorf,
    }


# Study name -> transplant settings copied into the UI by apply_grouped_preset.
# Arguments are (donor type, intensity, regimen, serotherapy, GVHD prophylaxis,
# HLA matching); every value is an entry of the matching *_CHOICES list.
PUBLISHED_PRESETS = {
    # HLA Identical Sibling
    "Hsieh et al 2014": _preset(
        "HLA identical sibling", "NMA", "TBI alone (300/400/600cGy)",
        "Alemtuzumab", "Siro alone", "8/8",
    ),
    "Krishnamurti et al 2019": _preset(
        "HLA identical sibling", "MAC", "Flu/Bu",
        "ATG", "CNI + MTX", "8/8",
    ),
    "King et al 2015": _preset(
        "HLA identical sibling", "RIC", "Flu/Mel",
        "Alemtuzumab", "CNI + MTX", "8/8",
    ),
    "Walters et al 1996": _preset(
        "HLA identical sibling", "MAC", "Bu/Cy",
        "ATG", "CNI + MTX", "8/8",
    ),
    # HLA Mismatch Relative
    "Bolanos-Meade et al 2022 (HLA Mismatch)": _preset(
        "HLA mismatch relative", "NMA", "TBI/Cy/Flu",
        "ATG", "Post-CY + siro +- MMF", "7/8",
    ),
    "Patel et al 2020 (HLA Mismatch)": _preset(
        "HLA mismatch relative", "NMA", "TBI/Cy/Flu/TT",
        "ATG", "Post-CY + siro +- MMF", "7/8",
    ),
    # Matched Unrelated Donor
    "L Krishnamurti et al 2019": _preset(
        "Matched unrelated donor", "MAC", "Flu/Bu",
        "ATG", "CNI + MTX", "8/8",
    ),
    "Shenoy et al 2016": _preset(
        "Matched unrelated donor", "RIC", "Flu/Mel",
        "Alemtuzumab", "CNI + MTX", "8/8",
    ),
    # Mismatched Unrelated Donor or Cord Blood
    "Bolanos-Meade et al 2022 (Mismatched/Cord)": _preset(
        "Mismatched unrelated donor or cord blood", "NMA", "TBI/Cy/Flu",
        "ATG", "Post-CY + siro +- MMF", "7/8",
    ),
    "Patel et al 2020 (Mismatched/Cord)": _preset(
        "Mismatched unrelated donor or cord blood", "NMA", "TBI/Cy/Flu/TT",
        "ATG", "Post-CY + siro +- MMF", "7/8",
    ),
}
# ---------------------------------------------------------------------------
# Feature groupings
# ---------------------------------------------------------------------------
# One group per input column of the UI. ALL_FEATURES fixes the positional
# order of the Predict button's inputs, which predict_gradio zips back into
# a feature-name -> value mapping.
PATIENT_FEATURES = [
    "AGE",
    "AGEGPFF",
    "SEX",
    "KPS",
    "RCMVPR",
]
DONOR_FEATURES = [
    "DONORF",
    "GRAFTYPE",
    "HLA_FINAL",
    "CONDGRPF",
    "CONDGRP_FINAL",
    "ATGF",
    "GVHD_FINAL",
]
DISEASE_FEATURES = [
    "NACS2YR",
    "EXCHTFPR",
    "VOC2YPR",
    "VOCFRQPR",
    "SCATXRSN",
]
ALL_FEATURES = [*PATIENT_FEATURES, *DONOR_FEATURES, *DISEASE_FEATURES]
# ---------------------------------------------------------------------------
# Utility callbacks
# ---------------------------------------------------------------------------
def get_age_group(age):
    """Map an age in years to its AGEGPFF bucket label.

    Buckets are upper-inclusive: ``<=10``, ``11-17`` (up to 17), ``18-29``,
    ``30-49`` and ``>=50`` (anything above 49). Fractional ages fall into the
    first bucket whose upper bound they do not exceed.

    Args:
        age: The age value (number or numeric string); may be empty.

    Returns:
        The bucket label, or ``""`` when ``age`` is ``None``, ``""``, not
        convertible to a float, or NaN.
    """
    import math

    if age is None or age == "":
        return ""
    try:
        years = float(age)
    except (ValueError, TypeError):
        return ""
    # NaN compares False against every bound and would otherwise be
    # misclassified as ">=50".
    if math.isnan(years):
        return ""
    for upper, label in ((10, "<=10"), (17, "11-17"), (29, "18-29"), (49, "30-49")):
        if years <= upper:
            return label
    return ">=50"
def vocfrqpr_from_voc2ypr(voc_status):
    """Keep the VOC-frequency field consistent with the VOC2YPR answer.

    An answer of ``"No"`` pins the frequency to ``"< 3/yr"`` and locks the
    field; any other value (including an empty one) clears it and unlocks it
    so the user must choose.
    """
    no_recent_voc = voc_status == "No"
    return gr.update(
        value="< 3/yr" if no_recent_voc else None,
        interactive=not no_recent_voc,
    )
def apply_grouped_preset(selected_value):
    """Translate a published-regimen selection into component updates.

    Returns a list of seven ``gr.update`` objects, ordered as: the regimen
    dropdown itself, then DONORF, CONDGRPF, CONDGRP_FINAL, ATGF, GVHD_FINAL,
    HLA_FINAL.

    * Empty selection or a group header: clear the regimen dropdown and
      leave the six transplant fields unchanged.
    * Value with no entry in PUBLISHED_PRESETS: change nothing.
    * Known study: keep the regimen dropdown as is and write the preset's
      values into the six transplant fields.
    """
    filled_fields = (
        "DONORF", "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL", "HLA_FINAL",
    )
    unchanged_fields = [gr.update()] * len(filled_fields)

    is_placeholder = not selected_value or selected_value in HEADER_VALUES
    if is_placeholder:
        return [gr.update(value=None), *unchanged_fields]

    preset = PUBLISHED_PRESETS.get(selected_value)
    if not preset:
        return [gr.update(), *unchanged_fields]

    return [gr.update(), *(gr.update(value=preset[field]) for field in filled_fields)]
# ---------------------------------------------------------------------------
# Component factory
# ---------------------------------------------------------------------------
def make_component(name: str):
    """Create a new Gradio input component for feature ``name``.

    Called while laying out the UI; each call builds a fresh component.
    AGE and NACS2YR become ``gr.Number`` fields, AGEGPFF a read-only
    ``gr.Textbox`` (filled from AGE by get_age_group), the categorical
    features a ``gr.Dropdown`` over their choice list, and any other name
    falls back to a ``gr.Textbox`` labelled with the raw name.
    """
    if name == "AGE":
        return gr.Number(label="Age at transplant (years)", minimum=0, maximum=120)
    if name == "AGEGPFF":
        return gr.Textbox(label="Age group (Auto-filled)", interactive=False)
    if name == "NACS2YR":
        return gr.Number(
            label="Number of Acute Chest Syndromes within 2 years pre-HCT",
            minimum=0,
        )

    # Categorical inputs: feature name -> (choices, label).
    dropdowns = {
        "SEX": (SEX_CHOICES, "Sex"),
        "KPS": (KPS_CHOICES, "Karnofsky/Lansky Performance Score at HCT"),
        "DONORF": (DONORF_CHOICES, "Donor type"),
        "GRAFTYPE": (GRAFTYPE_CHOICES, "Graft type"),
        "CONDGRPF": (CONDGRPF_CHOICES, "Conditioning intensity"),
        "CONDGRP_FINAL": (CONDGRP_FINAL_CHOICES, "Conditioning Regimen"),
        "ATGF": (ATGF_CHOICES, "Serotherapy"),
        "GVHD_FINAL": (GVHD_FINAL_CHOICES, "GVHD Prophylaxis"),
        "HLA_FINAL": (HLA_FINAL_CHOICES, "Donor-Recipient HLA Matching"),
        "RCMVPR": (RCMVPR_CHOICES, "Recipient CMV serostatus"),
        "EXCHTFPR": (EXCHTFPR_CHOICES, "Exchange transfusion required?"),
        "VOC2YPR": (
            VOC2YPR_CHOICES,
            "VOC requiring hospitalization within 2 years pre-HCT?",
        ),
        "VOCFRQPR": (VOCFRQPR_CHOICES, "Frequency of VOC hospitalizations"),
        "SCATXRSN": (SCATXRSN_CHOICES, "Reason for Transplant"),
    }
    if name not in dropdowns:
        return gr.Textbox(label=name)
    choices, label = dropdowns[name]
    return gr.Dropdown(choices, label=label)
# ---------------------------------------------------------------------------
# Prediction callback
# ---------------------------------------------------------------------------
# Outcomes rendered as icon arrays / SHAP plots, in the order of the matching
# output components passed to the Predict button.
_ICON_OUTCOMES = ("DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI")
_SHAP_PLOT_OUTCOMES = ("DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "EFS", "STROKEHI", "OS")


def _missing_features(user_vals):
    """Return, in input order, the feature names whose value is empty.

    Empty means ``None``, ``""`` or a float NaN.
    """
    return [
        name
        for name, value in user_vals.items()
        if value is None
        or value == ""
        or (isinstance(value, float) and pd.isna(value))
    ]


def _build_results_table(probs, intervals):
    """Format calibrated probabilities and 95% CIs as the results DataFrame.

    One row per REPORTING_OUTCOMES entry, with columns "Outcome",
    "Probability" and "95% CI", all as display strings in percent.
    """
    rows = []
    for outcome in REPORTING_OUTCOMES:
        ci_low, ci_high = intervals[outcome]
        rows.append({
            "Outcome": OUTCOME_DESCRIPTIONS[outcome],
            "Probability": f"{probs[outcome] * 100:.1f}%",
            "95% CI": f"[{ci_low * 100:.1f}% - {ci_high * 100:.1f}%]",
        })
    return pd.DataFrame(rows)


def predict_gradio(*values):
    """Run the model on the UI inputs and build every result component.

    Args:
        *values: Input component values, positionally in ALL_FEATURES order.

    Returns:
        A 15-tuple: the results DataFrame, six icon-array plots
        (DEAD, GF, AGVHD, CGVHD, VOCPSHI, STROKEHI), then eight SHAP plots
        (DEAD, GF, AGVHD, CGVHD, VOCPSHI, EFS, STROKEHI, OS) — the order of
        the Predict button's ``outputs`` list.

    Raises:
        gr.Error: If any input is empty (message lists the missing feature
            names), or if inference/plotting fails (the full traceback is
            printed to the terminal).
    """
    user_vals = dict(zip(ALL_FEATURES, values))

    # Validate outside the catch-all below: a missing field is a user error,
    # not a crash, so it gets a clean message and no traceback dump.
    missing = _missing_features(user_vals)
    if missing:
        raise gr.Error(
            f"Please fill in all fields before predicting.\nMissing: {', '.join(missing)}"
        )

    try:
        # Only the calibrated results are displayed.
        calibrated, _uncalibrated = predict_with_comparison(user_vals)
        calibrated_probs, calibrated_intervals = calibrated
        df = _build_results_table(calibrated_probs, calibrated_intervals)
        shap_plots = create_all_shap_plots(user_vals, max_display=10)
        icon_plots = [icon_array(calibrated_probs[o], o) for o in _ICON_OUTCOMES]
        return (df, *icon_plots, *(shap_plots[o] for o in _SHAP_PLOT_OUTCOMES))
    except Exception as e:
        tb = traceback.format_exc()
        print("=" * 60)
        print("ERROR IN predict_gradio:")
        print(tb)
        print("=" * 60)
        raise gr.Error(
            f"{type(e).__name__}: {e}\n\nSee terminal for full traceback."
        ) from e
# ---------------------------------------------------------------------------
# CSS (passed to launch() in Gradio 6+)
# ---------------------------------------------------------------------------
# Styles the Predict button, which opts in via elem_classes="predict-button".
# The `!important` flags are presumably needed to win over the active theme's
# own button styles — confirm if the theme changes.
custom_css = """
.predict-button {
background: linear-gradient(to right, #ff6b35, #ff8c42) !important;
border: none !important;
color: white !important;
font-weight: bold !important;
font-size: 16px !important;
padding: 12px !important;
}
.predict-button:hover {
background: linear-gradient(to right, #ff5722, #ff7b29) !important;
}
"""
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Layout: three input columns (patient / transplant / disease), a Predict
# button, then the results table, icon arrays and SHAP plots.
with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
    gr.Markdown(
        """
# HCT Outcome Prediction Model
Enter patient, transplant, and disease characteristics to predict outcomes.
"""
    )
    # Feature name -> input component, filled column by column below.
    inputs_dict = {}
    with gr.Row():
        # ── Patient Characteristics ──────────────────────────────────────
        with gr.Column(scale=1):
            gr.Markdown("### Patient Characteristics")
            for f in PATIENT_FEATURES:
                inputs_dict[f] = make_component(f)
        # ── Transplant Characteristics ───────────────────────────────────
        with gr.Column(scale=1):
            gr.Markdown("### Transplant Characteristics")
            # Helper selector, not a model input: choosing a published regimen
            # pre-fills the transplant fields via apply_grouped_preset.
            grouped_regimen_dropdown = gr.Dropdown(
                choices=GROUPED_REGIMEN_CHOICES,
                value=None,
                label="Published conditioning regimen",
                # apply_grouped_preset also writes HLA_FINAL, so list it too.
                info="Auto-fills Donor Type, Conditioning Intensity, Conditioning Regimen, "
                "Serotherapy, GVHD Prophylaxis and Donor-Recipient HLA Matching",
            )
            # Created one by one (not looped) to keep handles on the
            # components the preset callback writes to.
            donorf_comp = inputs_dict["DONORF"] = make_component("DONORF")
            inputs_dict["GRAFTYPE"] = make_component("GRAFTYPE")
            condgrpf = inputs_dict["CONDGRPF"] = make_component("CONDGRPF")
            condgrp_final = inputs_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL")
            atgf = inputs_dict["ATGF"] = make_component("ATGF")
            gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
            hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL")
        # ── Disease Characteristics ──────────────────────────────────────
        with gr.Column(scale=1):
            gr.Markdown("### Disease Characteristics")
            for f in DISEASE_FEATURES:
                inputs_dict[f] = make_component(f)

    # ── Reactive callbacks ───────────────────────────────────────────────
    # Derive the read-only age group from the entered age.
    inputs_dict["AGE"].change(
        fn=get_age_group,
        inputs=inputs_dict["AGE"],
        outputs=inputs_dict["AGEGPFF"],
    )
    # Pin/unlock VOC frequency depending on the VOC hospitalization answer.
    inputs_dict["VOC2YPR"].change(
        fn=vocfrqpr_from_voc2ypr,
        inputs=inputs_dict["VOC2YPR"],
        outputs=inputs_dict["VOCFRQPR"],
    )
    # Output order must match the list returned by apply_grouped_preset.
    grouped_regimen_dropdown.change(
        fn=apply_grouped_preset,
        inputs=grouped_regimen_dropdown,
        outputs=[
            grouped_regimen_dropdown,
            donorf_comp, condgrpf, condgrp_final, atgf, gvhd_final, hla_final,
        ],
    )

    # Positional order must match ALL_FEATURES, which predict_gradio zips against.
    inputs_list = [inputs_dict[f] for f in ALL_FEATURES]
    btn = gr.Button("Predict", elem_classes="predict-button", size="lg")

    gr.Markdown("---")
    gr.Markdown("## Prediction Results")
    gr.Markdown("### Predicted Outcomes")
    with gr.Column():
        output_table = gr.Dataframe(
            headers=["Outcome", "Probability", "95% CI"],
            label="",
            row_count=(len(REPORTING_OUTCOMES), "dynamic"),
            column_count=(3, "fixed"),  # fixed: col_count → column_count (Gradio 6)
        )

    gr.Markdown("---")
    gr.Markdown("## Icon Arrays")
    with gr.Row():
        with gr.Column():
            icon_dead = gr.Plot(label="Death")
        with gr.Column():
            icon_gf = gr.Plot(label="Graft Failure")
        with gr.Column():
            icon_agvhd = gr.Plot(label="Acute Graft-versus-Host Disease")
    with gr.Row():
        with gr.Column():
            icon_cgvhd = gr.Plot(label="Chronic Graft-versus-Host Disease")
        with gr.Column():
            icon_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
        with gr.Column():
            icon_stroke = gr.Plot(label="Stroke Post-HCT")

    gr.Markdown("---")
    gr.Markdown("## SHAP - Feature Importance")
    with gr.Row():
        with gr.Column():
            shap_dead = gr.Plot(label="Death")
        with gr.Column():
            shap_gf = gr.Plot(label="Graft Failure")
        with gr.Column():
            shap_agvhd = gr.Plot(label="Acute Graft-versus-Host Disease")
        with gr.Column():
            shap_cgvhd = gr.Plot(label="Chronic Graft-versus-Host Disease")
    with gr.Row():
        with gr.Column():
            shap_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
        with gr.Column():
            shap_efs = gr.Plot(label="Event-Free Survival")
        with gr.Column():
            shap_stroke = gr.Plot(label="Stroke Post-HCT")
        with gr.Column():
            shap_os = gr.Plot(label="Overall Survival")

    # Output order must match the tuple returned by predict_gradio.
    btn.click(
        fn=predict_gradio,
        inputs=inputs_list,
        outputs=[
            output_table,
            icon_dead, icon_gf, icon_agvhd, icon_cgvhd, icon_vocpshi, icon_stroke,
            shap_dead, shap_gf, shap_agvhd, shap_cgvhd,
            shap_vocpshi, shap_efs, shap_stroke, shap_os,
        ],
    )

if __name__ == "__main__":
    demo.launch(
        ssr_mode=False,
        css=custom_css,  # css moved to launch() in Gradio 6
    )