# Diabetes-TNX / app.py
# vafaei_ar
# Initial Space with model + Gradio app
# 2d1beae
import gradio as gr
import pandas as pd
import joblib
import os
from sklearn.ensemble import RandomForestClassifier # Assuming RF, add others if needed
# Integer encodings for the categorical inputs. Each label's code is its
# position in the tuple, so the order of the labels below is significant.
SEX_MAP = {label: code for code, label in enumerate(('Female', 'Male'))}
ETHNICITY_MAP = {
    label: code
    for code, label in enumerate(('Not Hispanic or Latino', 'Hispanic or Latino'))
}
MARITAL_STATUS_MAP = {label: code for code, label in enumerate(('Single', 'Married'))}
RACE_MAP = {
    label: code
    for code, label in enumerate((
        'White',
        'Asian',
        'American Indian or Alaska Native',
        'Native Hawaiian or Other Pacific Islander',
        'Black or African American',
        'Other Race',
        'Unknown',
    ))
}

# Dropdown choices are the human-readable labels (the map keys), listed in
# the same order as their codes.
RACE_CHOICES = [*RACE_MAP]
SEX_CHOICES = [*SEX_MAP]
ETHNICITY_CHOICES = [*ETHNICITY_MAP]
MARITAL_STATUS_CHOICES = [*MARITAL_STATUS_MAP]
MODEL_DIR = "./models"

# Single placeholder entry shown when no model file is available.
# predict_diabetes recognises it by the "No models found" text, so keep
# that wording intact.
_NO_MODELS_MESSAGE = "No models found. Please add .joblib models to the 'models' directory."


def get_available_models():
    """Return the model file names that can be offered in the model dropdown.

    The models directory is created if it does not exist yet. Only regular
    files whose name ends in ``.joblib`` are listed, sorted by name so the
    dropdown order does not depend on the file system's listing order.

    Returns:
        A sorted list of ``.joblib`` file names found in ``MODEL_DIR``, or a
        one-element list holding a placeholder message when there are none.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(MODEL_DIR, exist_ok=True)
    models = sorted(
        entry
        for entry in os.listdir(MODEL_DIR)
        if entry.endswith(".joblib")
        and os.path.isfile(os.path.join(MODEL_DIR, entry))
    )
    return models or [_NO_MODELS_MESSAGE]
# Feature columns in the order the model expects them.
# IMPORTANT: this order must match the training data, so do not regroup or
# re-sort the entries below.
EXPECTED_COLUMNS = [
    # Categorical inputs, then the prior mean glucose value
    'sex', 'race', 'ethnicity', 'marital_status',
    'Prior_Mean_Glu',
    # PT_ELX_GRP_1 .. PT_ELX_GRP_31; numbers 11 and 12 are not in the list
    *(f'PT_ELX_GRP_{group}' for group in range(1, 32) if group not in (11, 12)),
    'MOF', 'SDOH', 'Gallstone',
    # *_drug columns
    'acei_drug', 'statin_drug', 'diuretic_drug',
    'antiplatelet_drug', 'anticoagulant_drug', 'nsaid_drug',
    'ppi_drug', 'beta_blokers_drug', 'vasodilators_drug',
    'caaa_drug', 'ccb_drug', 'paaab_drug',
    # Age, body measurements and blood pressure
    'age', 'BMI', 'Body_weight', 'SBP', 'DBP',
    # Mean_* columns
    'Mean_AST', 'Mean_ALT', 'Mean_TBIL', 'Mean_ALP',
    'Mean_Hgb', 'Mean_HCT', 'Mean_Cr', 'Mean_PLT',
    'Mean_WBC', 'Mean_BUN', 'Mean_AGAP', 'Mean_Protein',
    # Remaining history / utilisation columns
    'Smoking', 'eGFR', 'ED_visits', 'LOS',
    'Prediabetes', 'Alcohol_use', 'Famly_hist_diabetes', 'NAFLD',
    'Hist_Gesta_diabetes', 'Pregnancy', 'numof_med_visits',
    'History_AP_necrosis', 'Necrosectomy', 'Steroids_drugs',
    'oral_contraceptive', 'cholelithiasis', 'acute_cholecystitis',
    'hypertriglyceridemia',
]
def _encode_patient_inputs(raw_inputs):
    """Convert raw form values into the numeric feature dict fed to the model.

    Args:
        raw_inputs: Mapping from feature name to the value received from the UI.

    Returns:
        A ``(features, problems)`` tuple. ``features`` maps every successfully
        encoded name in ``EXPECTED_COLUMNS`` to its numeric value;
        ``problems`` lists, in column order, the names whose value was
        missing or could not be encoded.
    """
    categorical_maps = {
        'sex': SEX_MAP,
        'race': RACE_MAP,
        'ethnicity': ETHNICITY_MAP,
        'marital_status': MARITAL_STATUS_MAP,
    }
    features = {}
    problems = []
    for column in EXPECTED_COLUMNS:
        value = raw_inputs.get(column)
        try:
            if column in categorical_maps:
                features[column] = categorical_maps[column][value]
            else:
                features[column] = float(value)
        except (KeyError, TypeError, ValueError):
            # None (field left empty), an unknown label, or a non-numeric value.
            problems.append(column)
    return features, problems


def predict_diabetes(model_name, sex, race, ethnicity, marital_status, Prior_Mean_Glu,
                     PT_ELX_GRP_1, PT_ELX_GRP_2, PT_ELX_GRP_3, PT_ELX_GRP_4,
                     PT_ELX_GRP_5, PT_ELX_GRP_6, PT_ELX_GRP_7, PT_ELX_GRP_8,
                     PT_ELX_GRP_9, PT_ELX_GRP_10, PT_ELX_GRP_13, PT_ELX_GRP_14,
                     PT_ELX_GRP_15, PT_ELX_GRP_16, PT_ELX_GRP_17, PT_ELX_GRP_18,
                     PT_ELX_GRP_19, PT_ELX_GRP_20, PT_ELX_GRP_21, PT_ELX_GRP_22,
                     PT_ELX_GRP_23, PT_ELX_GRP_24, PT_ELX_GRP_25, PT_ELX_GRP_26,
                     PT_ELX_GRP_27, PT_ELX_GRP_28, PT_ELX_GRP_29, PT_ELX_GRP_30,
                     PT_ELX_GRP_31, MOF, SDOH, Gallstone, acei_drug, statin_drug,
                     diuretic_drug, antiplatelet_drug, anticoagulant_drug,
                     nsaid_drug, ppi_drug, beta_blokers_drug, vasodilators_drug,
                     caaa_drug, ccb_drug, paaab_drug, age, BMI, Body_weight,
                     SBP, DBP, Mean_AST, Mean_ALT, Mean_TBIL, Mean_ALP,
                     Mean_Hgb, Mean_HCT, Mean_Cr, Mean_PLT, Mean_WBC, Mean_BUN,
                     Mean_AGAP, Mean_Protein, Smoking, eGFR, ED_visits, LOS,
                     Prediabetes, Alcohol_use, Famly_hist_diabetes, NAFLD,
                     Hist_Gesta_diabetes, Pregnancy, numof_med_visits,
                     History_AP_necrosis, Necrosectomy, Steroids_drugs,
                     oral_contraceptive, cholelithiasis, acute_cholecystitis,
                     hypertriglyceridemia):
    """Run the selected model on one patient's data and describe the outcome.

    Args:
        model_name: File name of a model inside ``MODEL_DIR``.
        sex, race, ethnicity, marital_status: Labels that must be keys of
            ``SEX_MAP``, ``RACE_MAP``, ``ETHNICITY_MAP`` and
            ``MARITAL_STATUS_MAP`` respectively.
        Every other parameter: a value convertible with ``float()``; the
            parameter name is the feature's column name.

    Returns:
        A message string: the prediction ("Positive"/"Negative for
        Diabetes"), or a description of what went wrong (no/invalid model,
        missing inputs, load or prediction error). Errors are returned
        rather than raised so they show up in the result text box.
    """
    # Snapshot the arguments by name before any other local variable exists,
    # so each feature can be looked up by its column name.
    raw_inputs = dict(locals())

    invalid_model_message = "Please select a valid model from the 'models/' directory."
    if not model_name or "No models found" in model_name:
        return invalid_model_message
    # Accept bare file names only: a value containing path components could
    # otherwise make joblib.load() read a file outside MODEL_DIR.
    if os.path.basename(model_name) != model_name:
        return invalid_model_message

    model_path = os.path.join(MODEL_DIR, model_name)
    if not os.path.isfile(model_path):
        return f"Model file {model_name} not found in {MODEL_DIR}."

    # Validate before loading the model: empty form fields arrive as None and
    # would otherwise raise from float()/the label lookup.
    features, problems = _encode_patient_inputs(raw_inputs)
    if problems:
        return "Missing or invalid value for: " + ", ".join(problems)

    try:
        # NOTE: joblib files are pickle-based and can execute code when
        # loaded; only trusted files should ever be placed in MODEL_DIR.
        model = joblib.load(model_path)
    except Exception as e:  # UI boundary: report instead of crashing
        return f"Error loading model: {e}"

    # Build a single-row frame with the columns in training order.
    try:
        df = pd.DataFrame([features], columns=EXPECTED_COLUMNS)
    except Exception as e:
        return f"Error creating DataFrame: {e}. Check EXPECTED_COLUMNS and input_data keys."

    try:
        # predict() returns one entry per row; there is exactly one row.
        result = model.predict(df)[0]
    except Exception as e:
        return f"Error during prediction: {e}"

    if result == 1:
        return "Prediction: Positive for Diabetes"
    return "Prediction: Negative for Diabetes"
# Input widgets, in the same order as predict_diabetes' parameters: the model
# selector, the four categorical dropdowns, then one numeric field for each
# remaining feature.
inputs = [
    gr.Dropdown(label="Select Model", choices=get_available_models()),
    gr.Dropdown(label="Sex", choices=SEX_CHOICES),
    gr.Dropdown(label="Race", choices=RACE_CHOICES),
    gr.Dropdown(label="Ethnicity", choices=ETHNICITY_CHOICES),
    gr.Dropdown(label="Marital Status", choices=MARITAL_STATUS_CHOICES),
    *(
        gr.Number(label=caption)
        for caption in (
            "Prior Mean Glu",
            "PT_ELX_GRP_1", "PT_ELX_GRP_2", "PT_ELX_GRP_3", "PT_ELX_GRP_4",
            "PT_ELX_GRP_5", "PT_ELX_GRP_6", "PT_ELX_GRP_7", "PT_ELX_GRP_8",
            "PT_ELX_GRP_9", "PT_ELX_GRP_10", "PT_ELX_GRP_13", "PT_ELX_GRP_14",
            "PT_ELX_GRP_15", "PT_ELX_GRP_16", "PT_ELX_GRP_17", "PT_ELX_GRP_18",
            "PT_ELX_GRP_19", "PT_ELX_GRP_20", "PT_ELX_GRP_21", "PT_ELX_GRP_22",
            "PT_ELX_GRP_23", "PT_ELX_GRP_24", "PT_ELX_GRP_25", "PT_ELX_GRP_26",
            "PT_ELX_GRP_27", "PT_ELX_GRP_28", "PT_ELX_GRP_29", "PT_ELX_GRP_30",
            "PT_ELX_GRP_31",
            "MOF", "SDOH", "Gallstone",
            "ACE Inhibitor Drug", "Statin Drug", "Diuretic Drug",
            "Antiplatelet Drug", "Anticoagulant Drug",
            "NSAID Drug", "PPI Drug", "Beta Blockers Drug",
            "Vasodilators Drug", "CAAA Drug", "CCB Drug",
            "PAAAB Drug", "Age", "BMI",
            "Body Weight (kg)", "SBP (Systolic Blood Pressure)",
            "DBP (Diastolic Blood Pressure)", "Mean AST", "Mean ALT",
            "Mean TBIL", "Mean ALP", "Mean Hgb",
            "Mean HCT", "Mean Cr", "Mean PLT",
            "Mean WBC", "Mean BUN", "Mean AGAP",
            "Mean Protein", "Smoking", "eGFR",
            "ED Visits", "LOS (Length of Stay)", "Prediabetes",
            "Alcohol Use", "Family History of Diabetes",
            "NAFLD", "History of Gestational Diabetes",
            "Pregnancy", "Number of Medical Visits",
            "History AP Necrosis", "Necrosectomy",
            "Steroids Drugs", "Oral Contraceptive",
            "Cholelithiasis", "Acute Cholecystitis",
            "Hypertriglyceridemia",
        )
    ),
]

# Text box that displays the message returned by predict_diabetes.
output = gr.Textbox(label="Prediction Result")

# Assemble the interface; it is only launched when this file runs as a script.
iface = gr.Interface(
    title="Diabetes Prediction",
    description="Enter patient data to predict diabetes. Ensure your models are in the 'models' directory.",
    fn=predict_diabetes,
    inputs=inputs,
    outputs=output,
    allow_flagging='never',
)

if __name__ == "__main__":
    iface.launch()