# Diabetes-TNX / app.py
# vafaei_ar
# FM selection and model added.
# 813cf60
import gradio as gr
import pandas as pd
import joblib
import os
from sklearn.ensemble import RandomForestClassifier # Assuming RF, add others if needed
# Integer codes for the categorical inputs; predict_diabetes encodes the
# dropdown selections with these maps before building the model input.
# NOTE(review): the codes presumably mirror the encoding used when the models
# were trained — confirm before changing any value.
SEX_MAP = {
    'Female': 0,
    'Male': 1,
}
ETHNICITY_MAP = {
    'Not Hispanic or Latino': 0,
    'Hispanic or Latino': 1,
}
MARITAL_STATUS_MAP = {
    'Single': 0,
    'Married': 1,
}
RACE_MAP = {
    'White': 0,
    'Asian': 1,
    'American Indian or Alaska Native': 2,
    'Native Hawaiian or Other Pacific Islander': 3,
    'Black or African American': 4,
    'Other Race': 5,
    'Unknown': 6,
}

# Dropdown choices are simply the map labels, in the maps' insertion order.
RACE_CHOICES = [*RACE_MAP]
SEX_CHOICES = [*SEX_MAP]
ETHNICITY_CHOICES = [*ETHNICITY_MAP]
MARITAL_STATUS_CHOICES = [*MARITAL_STATUS_MAP]

# Directory holding the serialized (.joblib) models.
MODEL_DIR = "./models"
# Filenames of the pre-trained models, keyed by model type and then by time
# period ("diabetes" plus the 24/36/48-month variants).
_MODEL_FILES = {
    "classical": {
        "diabetes": "Logistic regression_diabetes.joblib",
        "24mths": "Logistic regression_diabetes_24mths.joblib",
        "36mths": "Logistic regression_diabetes_36mths.joblib",
        "48mths": "Logistic regression_diabetes_48mths.joblib",
    },
    "foundation": {
        "diabetes": "FM_Logistic regression_diabetes.joblib",
        "24mths": "FM_Logistic regression_diabetes_24mths.joblib",
        "36mths": "FM_Logistic regression_diabetes_36mths.joblib",
        "48mths": "FM_Logistic regression_diabetes_48mths.joblib",
    },
}


def get_available_models(model_dir=None):
    """Return the model filename table, keyed by model type and time period.

    The model directory is created if it does not exist yet.

    Args:
        model_dir: Directory to inspect; defaults to ``MODEL_DIR``.

    Returns:
        ``{"classical": {...}, "foundation": {...}}`` where each inner dict
        maps a time period ("diabetes", "24mths", "36mths", "48mths") to a
        model filename. Both inner dicts are empty when the directory contains
        no top-level ``.joblib`` file. The shape is the same in both cases, so
        callers can always use ``result[model_type].get(time_period)``.
        Individual filenames are not checked for existence. The result is a
        fresh copy that callers may mutate freely.
    """
    if model_dir is None:
        model_dir = MODEL_DIR
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(model_dir, exist_ok=True)
    has_models = any(name.endswith(".joblib") for name in os.listdir(model_dir))
    if not has_models:
        return {model_type: {} for model_type in _MODEL_FILES}
    return {model_type: dict(files) for model_type, files in _MODEL_FILES.items()}
# Feature columns in the exact order the trained models expect; predict_diabetes
# builds its input DataFrame with these columns.
# IMPORTANT: this order must match the training data.
EXPECTED_COLUMNS = [
    'sex', 'race', 'ethnicity', 'marital_status', 'Prior_Mean_Glu',
    # PT_ELX_GRP_1..10 and PT_ELX_GRP_13..31; groups 11 and 12 are not inputs.
    # NOTE(review): presumably Elixhauser comorbidity groups with the diabetes
    # groups left out to avoid leaking the outcome — confirm.
    *(f'PT_ELX_GRP_{group}' for group in (*range(1, 11), *range(13, 32))),
    'MOF', 'SDOH', 'Gallstone',
    'acei_drug', 'statin_drug', 'diuretic_drug', 'antiplatelet_drug',
    'anticoagulant_drug', 'nsaid_drug', 'ppi_drug', 'beta_blokers_drug',
    'vasodilators_drug', 'caaa_drug', 'ccb_drug', 'paaab_drug',
    'age', 'BMI', 'Body_weight', 'SBP', 'DBP',
    'Mean_AST', 'Mean_ALT', 'Mean_TBIL', 'Mean_ALP', 'Mean_Hgb', 'Mean_HCT',
    'Mean_Cr', 'Mean_PLT', 'Mean_WBC', 'Mean_BUN', 'Mean_AGAP', 'Mean_Protein',
    'Smoking', 'eGFR', 'ED_visits', 'LOS',
    'Prediabetes', 'Alcohol_use', 'Famly_hist_diabetes', 'NAFLD',
    'Hist_Gesta_diabetes', 'Pregnancy', 'numof_med_visits',
    'History_AP_necrosis', 'Necrosectomy', 'Steroids_drugs',
    'oral_contraceptive', 'cholelithiasis', 'acute_cholecystitis',
    'hypertriglyceridemia',
]
def predict_diabetes(model_type, time_period, sex, race, ethnicity, marital_status, Prior_Mean_Glu,
                     PT_ELX_GRP_1, PT_ELX_GRP_2, PT_ELX_GRP_3, PT_ELX_GRP_4,
                     PT_ELX_GRP_5, PT_ELX_GRP_6, PT_ELX_GRP_7, PT_ELX_GRP_8,
                     PT_ELX_GRP_9, PT_ELX_GRP_10, PT_ELX_GRP_13, PT_ELX_GRP_14,
                     PT_ELX_GRP_15, PT_ELX_GRP_16, PT_ELX_GRP_17, PT_ELX_GRP_18,
                     PT_ELX_GRP_19, PT_ELX_GRP_20, PT_ELX_GRP_21, PT_ELX_GRP_22,
                     PT_ELX_GRP_23, PT_ELX_GRP_24, PT_ELX_GRP_25, PT_ELX_GRP_26,
                     PT_ELX_GRP_27, PT_ELX_GRP_28, PT_ELX_GRP_29, PT_ELX_GRP_30,
                     PT_ELX_GRP_31, MOF, SDOH, Gallstone, acei_drug, statin_drug,
                     diuretic_drug, antiplatelet_drug, anticoagulant_drug,
                     nsaid_drug, ppi_drug, beta_blokers_drug, vasodilators_drug,
                     caaa_drug, ccb_drug, paaab_drug, age, BMI, Body_weight,
                     SBP, DBP, Mean_AST, Mean_ALT, Mean_TBIL, Mean_ALP,
                     Mean_Hgb, Mean_HCT, Mean_Cr, Mean_PLT, Mean_WBC, Mean_BUN,
                     Mean_AGAP, Mean_Protein, Smoking, eGFR, ED_visits, LOS,
                     Prediabetes, Alcohol_use, Famly_hist_diabetes, NAFLD,
                     Hist_Gesta_diabetes, Pregnancy, numof_med_visits,
                     History_AP_necrosis, Necrosectomy, Steroids_drugs,
                     oral_contraceptive, cholelithiasis, acute_cholecystitis,
                     hypertriglyceridemia):
    """Predict diabetes for one patient with the selected pre-trained model.

    Gradio calls this with one positional argument per input component, so
    the parameter order must match the interface's ``inputs`` list. The
    feature parameters are named, and ordered, exactly like
    ``EXPECTED_COLUMNS``; the model input is built from those names.

    Args:
        model_type: "classical" or "foundation" (keys of get_available_models()).
        time_period: "diabetes", "24mths", "36mths" or "48mths".
        sex, race, ethnicity, marital_status: dropdown labels, encoded with
            SEX_MAP, RACE_MAP, ETHNICITY_MAP and MARITAL_STATUS_MAP.
        All remaining parameters: numeric feature values (converted to float).

    Returns:
        A human-readable result string. Problems are returned as messages
        rather than raised, so they appear in the output textbox.
    """
    # Snapshot the arguments before any other local exists: at this point
    # locals() holds exactly the parameters, keyed by name in signature order.
    features = dict(locals())
    del features['model_type'], features['time_period']

    if not model_type or not time_period:
        return "Please select both model type and time period."

    # Empty Gradio fields arrive as None; report them instead of crashing in
    # float(None) or a *_MAP[None] lookup.
    missing = [name for name, value in features.items() if value is None]
    if missing:
        return "Please provide a value for: " + ", ".join(missing) + "."

    # `.get(...) or {}` tolerates an unknown model type and an empty model table.
    model_name = (get_available_models().get(model_type) or {}).get(time_period)
    if not model_name:
        return "Selected model not found. Please check the model type and time period."
    model_path = os.path.join(MODEL_DIR, model_name)
    if not os.path.exists(model_path):
        return f"Model file {model_name} not found in {MODEL_DIR}."
    try:
        # joblib.load unpickles arbitrary objects: only load trusted local files.
        model = joblib.load(model_path)
    except Exception as e:
        return f"Error loading model: {e}"

    # Encode dropdown labels with their *_MAP integer codes; every other
    # feature is numeric.
    categorical_maps = {
        'sex': SEX_MAP,
        'race': RACE_MAP,
        'ethnicity': ETHNICITY_MAP,
        'marital_status': MARITAL_STATUS_MAP,
    }
    try:
        input_data = {
            name: categorical_maps[name][value] if name in categorical_maps else float(value)
            for name, value in features.items()
        }
    except (KeyError, TypeError, ValueError) as e:
        return f"Invalid input value: {e}"

    # Create DataFrame in the column order the models were trained with.
    try:
        df = pd.DataFrame([input_data], columns=EXPECTED_COLUMNS)
    except Exception as e:
        return f"Error creating DataFrame: {e}. Check EXPECTED_COLUMNS and input_data keys."

    try:
        if model_type == "foundation":
            # Foundation models predict from TabPFN embeddings of the input row.
            try:
                # Make sure TabPFN is importable before unpickling its model.
                import tabpfn  # noqa: F401
                embedder_path = os.path.join(MODEL_DIR, "FM", "TabPFN_model_chunk_0.joblib")
                embedder = joblib.load(embedder_path)
                embeddings = embedder.get_embeddings(df)
                # NOTE(review): 768 is hard-coded; presumably the flattened
                # embedding size for one sample — confirm against training.
                X = pd.DataFrame(data=embeddings.reshape(768, -1).T)
                prediction = model.predict(X)
            except Exception as e:
                return f"Error in foundation model preprocessing: {e}"
        else:
            # Classical models consume the feature row directly.
            prediction = model.predict(df)
        result = prediction[0]
        if result == 1:
            return "Prediction: Positive for Diabetes"
        return "Prediction: Negative for Diabetes"
    except Exception as e:
        return f"Error during prediction: {e}"
# Gradio input components. Gradio passes their values to predict_diabetes
# POSITIONALLY, so this list must follow that function's parameter order:
# model type, time period, then the features in EXPECTED_COLUMNS order.
# (Previously Age/BMI/Smoking were listed right after Marital Status, which
# shifted every value from "Prior Mean Glu" to "Mean Protein" into the wrong
# model feature.)
inputs = [
    gr.Dropdown(choices=["classical", "foundation"], label="Model Type"),
    gr.Dropdown(choices=["diabetes", "24mths", "36mths", "48mths"], label="Time Period"),
    gr.Dropdown(choices=SEX_CHOICES, label="Sex"),
    gr.Dropdown(choices=RACE_CHOICES, label="Race"),
    gr.Dropdown(choices=ETHNICITY_CHOICES, label="Ethnicity"),
    gr.Dropdown(choices=MARITAL_STATUS_CHOICES, label="Marital Status"),
    gr.Number(label="Prior Mean Glu"),
    # PT_ELX_GRP_1..10 and PT_ELX_GRP_13..31 (no 11/12), as in EXPECTED_COLUMNS.
    *[gr.Number(label=f"PT_ELX_GRP_{group}") for group in (*range(1, 11), *range(13, 32))],
    gr.Number(label="MOF"),
    gr.Number(label="SDOH"),
    gr.Number(label="Gallstone"),
    gr.Number(label="ACE Inhibitor Drug"),
    gr.Number(label="Statin Drug"),
    gr.Number(label="Diuretic Drug"),
    gr.Number(label="Antiplatelet Drug"),
    gr.Number(label="Anticoagulant Drug"),
    gr.Number(label="NSAID Drug"),
    gr.Number(label="PPI Drug"),
    gr.Number(label="Beta Blockers Drug"),
    gr.Number(label="Vasodilators Drug"),
    gr.Number(label="CAAA Drug"),
    gr.Number(label="CCB Drug"),
    gr.Number(label="PAAAB Drug"),
    gr.Number(label="Age"),
    gr.Number(label="BMI"),
    gr.Number(label="Body Weight (kg)"),
    gr.Number(label="SBP (Systolic Blood Pressure)"),
    gr.Number(label="DBP (Diastolic Blood Pressure)"),
    gr.Number(label="Mean AST"),
    gr.Number(label="Mean ALT"),
    gr.Number(label="Mean TBIL"),
    gr.Number(label="Mean ALP"),
    gr.Number(label="Mean Hgb"),
    gr.Number(label="Mean HCT"),
    gr.Number(label="Mean Cr"),
    gr.Number(label="Mean PLT"),
    gr.Number(label="Mean WBC"),
    gr.Number(label="Mean BUN"),
    gr.Number(label="Mean AGAP"),
    gr.Number(label="Mean Protein"),
    gr.Number(label="Smoking"),
    gr.Number(label="eGFR"),
    gr.Number(label="ED Visits"),
    gr.Number(label="LOS (Length of Stay)"),
    gr.Number(label="Prediabetes"),
    gr.Number(label="Alcohol Use"),
    gr.Number(label="Family History of Diabetes"),
    gr.Number(label="NAFLD"),
    gr.Number(label="History of Gestational Diabetes"),
    gr.Number(label="Pregnancy"),
    gr.Number(label="Number of Medical Visits"),
    gr.Number(label="History AP Necrosis"),
    gr.Number(label="Necrosectomy"),
    gr.Number(label="Steroids Drugs"),
    gr.Number(label="Oral Contraceptive"),
    gr.Number(label="Cholelithiasis"),
    gr.Number(label="Acute Cholecystitis"),
    gr.Number(label="Hypertriglyceridemia"),
]
# Output component: the prediction (or error) text returned by predict_diabetes.
output = gr.Textbox(label="Prediction Result")

# Wire the inputs, the prediction function and the output into one Gradio app,
# with result flagging turned off.
iface = gr.Interface(
    title="Diabetes Prediction",
    description="Enter patient data to predict diabetes. Ensure your models are in the 'models' directory.",
    fn=predict_diabetes,
    inputs=inputs,
    outputs=output,
    allow_flagging='never',
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    iface.launch()