import os import shutil import urllib.request from pathlib import Path import gradio as gr import pandas as pd from pandas.api.types import CategoricalDtype from pycaret.classification import load_model, predict_model # Optional: load example data (not required for predictions, but kept since it exists in your repo) # If the file is missing, the app still runs. try: _ex_data = pd.read_csv("example_data2.csv") except Exception: _ex_data = None MODEL_BASENAME = "final_model" # pycaret load_model expects the basename MODEL_FILE = f"{MODEL_BASENAME}.pkl" # this should exist locally in your Space repo MODEL_URL = "https://github.com/fmegahed/tavr_paper/blob/main/data/final_model.pkl?raw=true" _MODEL = None def _ensure_model_file() -> None: """ Ensure final_model.pkl exists locally. If it is missing, try to download it once as a fallback. """ if Path(MODEL_FILE).exists(): return # Fallback: download if the repo file is missing for some reason with urllib.request.urlopen(MODEL_URL) as response, open(MODEL_FILE, "wb") as out_file: shutil.copyfileobj(response, out_file) def _get_model(): """ Load and cache the PyCaret model once per process. """ global _MODEL if _MODEL is None: _ensure_model_file() _MODEL = load_model(MODEL_BASENAME) return _MODEL def predict( age, female, race, elective, aweekend, zipinc_qrtl, hosp_region, hosp_division, hosp_locteach, hosp_bedsize, h_contrl, pay, anemia, atrial_fibrillation, cancer, cardiac_arrhythmias, carotid_artery_disease, chronic_kidney_disease, chronic_pulmonary_disease, coagulopathy, depression, diabetes_mellitus, drug_abuse, dyslipidemia, endocarditis, family_history, fluid_and_electrolyte_disorder, heart_failure, hypertension, known_cad, liver_disease, obesity, peripheral_vascular_disease, prior_cabg, prior_icd, prior_mi, prior_pci, prior_ppm, prior_tia_stroke, pulmonary_circulation_disorder, smoker, valvular_disease, weight_loss, endovascular_tavr, transapical_tavr, ): df = pd.DataFrame.from_dict( { "age": [age], "female": [female], "race": [race], "elective": [elective], "aweekend": [aweekend], "zipinc_qrtl": [zipinc_qrtl], "hosp_region": [hosp_region], "hosp_division": [hosp_division], "hosp_locteach": [hosp_locteach], "hosp_bedsize": [hosp_bedsize], "h_contrl": [h_contrl], "pay": [pay], "anemia": [anemia], "atrial_fibrillation": [atrial_fibrillation], "cancer": [cancer], "cardiac_arrhythmias": [cardiac_arrhythmias], "carotid_artery_disease": [carotid_artery_disease], "chronic_kidney_disease": [chronic_kidney_disease], "chronic_pulmonary_disease": [chronic_pulmonary_disease], "coagulopathy": [coagulopathy], "depression": [depression], "diabetes_mellitus": [diabetes_mellitus], "drug_abuse": [drug_abuse], "dyslipidemia": [dyslipidemia], "endocarditis": [endocarditis], "family_history": [family_history], "fluid_and_electrolyte_disorder": [fluid_and_electrolyte_disorder], "heart_failure": [heart_failure], "hypertension": [hypertension], "known_cad": [known_cad], "liver_disease": [liver_disease], "obesity": [obesity], "peripheral_vascular_disease": [peripheral_vascular_disease], "prior_cabg": [prior_cabg], "prior_icd": [prior_icd], "prior_mi": [prior_mi], "prior_pci": [prior_pci], "prior_ppm": [prior_ppm], "prior_tia_stroke": [prior_tia_stroke], "pulmonary_circulation_disorder": [pulmonary_circulation_disorder], "smoker": [smoker], "valvular_disease": [valvular_disease], "weight_loss": [weight_loss], "endovascular_tavr": [endovascular_tavr], "transapical_tavr": [transapical_tavr], } ) # Convert object columns to categorical obj_cols = df.select_dtypes(include=["object"]).columns for c in obj_cols: df[c] = df[c].astype("category") # Convert ordinal column to ordered categorical ordinal_cat = CategoricalDtype( categories=["FirstQ", "SecondQ", "ThirdQ", "FourthQ"], ordered=True, ) df["zipinc_qrtl"] = df["zipinc_qrtl"].astype(ordinal_cat) model = _get_model() pred = predict_model(model, df, raw_score=True) # These column names depend on how the model was trained and saved # This matches your original code: Score_Yes for death, Score_No for survival. return { "Death %": round(100 * float(pred["Score_Yes"].iloc[0]), 2), "Survival %": round(100 * float(pred["Score_No"].iloc[0]), 2), "Predicting Death Outcome": str(pred["Label"].iloc[0]), } inputs = [ gr.Slider(minimum=18, maximum=100, value=80, label="Age"), gr.Dropdown(choices=["Female", "Male"], value="Female", label="Sex"), gr.Dropdown( choices=[ "Asian or Pacific Islander", "Black", "Hispanic", "Native American", "White", "Other", ], value="White", label="Race", ), gr.Radio(choices=["Elective", "NonElective"], value="Elective", label="Elective"), gr.Radio(choices=["No", "Yes"], value="No", label="Weekend"), gr.Radio( choices=["FirstQ", "SecondQ", "ThirdQ", "FourthQ"], value="SecondQ", label="Zip Income Quartile", ), gr.Radio( choices=["Midwest", "Northeast", "South", "West"], value="South", label="Hospital Region", ), gr.Radio( choices=[ "New England", "Middle Atlantic", "East North Central", "West North Central", "South Atlantic", "East South Central", "West South Central", "Mountain", "Pacific", ], value="South Atlantic", label="Hospital Division", ), gr.Radio( choices=["Urban teaching", "Urban nonteaching", "Rural"], value="Urban teaching", label="Hospital Location/Teaching", ), gr.Radio(choices=["Small", "Medium", "Large"], value="Large", label="Hospital Bedsize"), gr.Radio( choices=["Government_nonfederal", "Private_invest_own", "Private_not_profit"], value="Private_not_profit", label="Hospital Control", ), gr.Dropdown( choices=["Private insurance", "Medicare", "Medicaid", "Self-pay", "No charge", "Other"], value="Medicare", label="Payee", ), # Comorbidities gr.Radio(choices=["No", "Yes"], value="Yes", label="Anemia"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Atrial Fibrillation"), gr.Radio(choices=["No", "Yes"], value="No", label="Cancer"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Cardiac Arrhythmias"), gr.Radio(choices=["No", "Yes"], value="No", label="Carotid Artery Disease"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Chronic Kidney Disease"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Chronic Pulmonary Disease"), gr.Radio(choices=["No", "Yes"], value="No", label="Coagulopathy"), gr.Radio(choices=["No", "Yes"], value="No", label="Depression"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Diabetes Mellitus"), gr.Radio(choices=["No", "Yes"], value="No", label="Drug Abuse"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Dyslipidemia"), gr.Radio(choices=["No", "Yes"], value="No", label="Endocarditis"), gr.Radio(choices=["No", "Yes"], value="No", label="Family History"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Fluid and Electrolyte Disorder"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Heart Failure"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Hypertension"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Known CAD"), gr.Radio(choices=["No", "Yes"], value="No", label="Liver Disease"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Obesity"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Peripheral Vascular Disease"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Prior CABG"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Prior ICD"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Prior MI"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Prior PCI"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Prior PPM"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Prior TIA Stroke"), gr.Radio(choices=["No", "Yes"], value="No", label="Pulmonary Circulation Disorder"), gr.Radio(choices=["No", "Yes"], value="No", label="Smoker"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Valvular Disease"), gr.Radio(choices=["No", "Yes"], value="No", label="Weight Loss"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Endovascular TAVR"), gr.Radio(choices=["No", "Yes"], value="Yes", label="Transapical TAVR"), ] description_html = """
This app predicts in-hospital mortality after TAVR using a finalized logistic regression model with L2 penalty,
based on national inpatient data from 2012–2019 (HCUP NIS).
Published paper:
Alhwiti, T., Aldrugh, S., & Megahed, F. M. (2023), Scientific Reports