Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import joblib | |
| import os | |
| from sklearn.ensemble import RandomForestClassifier # Assuming RF, add others if needed | |
# Label -> integer code for each categorical field. The labels double as the
# dropdown options; the codes are what gets written into the feature row.
# NOTE(review): presumably these codes mirror the encoding used when the
# models were trained -- confirm before changing any label or its position.
SEX_MAP = {label: code for code, label in enumerate(("Female", "Male"))}
ETHNICITY_MAP = {
    label: code
    for code, label in enumerate(
        ("Not Hispanic or Latino", "Hispanic or Latino")
    )
}
MARITAL_STATUS_MAP = {
    label: code for code, label in enumerate(("Single", "Married"))
}
RACE_MAP = {
    label: code
    for code, label in enumerate(
        (
            "White",
            "Asian",
            "American Indian or Alaska Native",
            "Native Hawaiian or Other Pacific Islander",
            "Black or African American",
            "Other Race",
            "Unknown",
        )
    )
}

# Dropdown options: the map keys, in definition order.
RACE_CHOICES = [*RACE_MAP]
SEX_CHOICES = [*SEX_MAP]
ETHNICITY_CHOICES = [*ETHNICITY_MAP]
MARITAL_STATUS_CHOICES = [*MARITAL_STATUS_MAP]
# Directory (relative to the working directory) that holds the .joblib models.
MODEL_DIR = "./models"


def get_available_models():
    """Return the expected model filenames grouped by type and time period.

    ``MODEL_DIR`` is created if it does not exist yet. If the directory
    contains no ``.joblib`` file at all, both groups come back empty.
    Otherwise the fixed catalogue of expected filenames is returned; the
    individual files are NOT checked for existence, so callers still have to
    verify the path before loading.

    Returns:
        dict: ``{"classical": {period: filename}, "foundation": {...}}``.
        Both values are always dicts (possibly empty), so
        ``result[model_type].get(period)`` is safe in every case.
    """
    # exist_ok avoids a race between an existence check and the creation.
    os.makedirs(MODEL_DIR, exist_ok=True)

    has_any_model = any(
        entry.endswith(".joblib") for entry in os.listdir(MODEL_DIR)
    )
    if not has_any_model:
        # Same mapping shape as the populated case; the previous empty lists
        # made ``result[model_type][period]`` fail with a TypeError.
        return {"classical": {}, "foundation": {}}

    # Filenames keyed by time period, for each model family.
    return {
        "classical": {
            "diabetes": "Logistic regression_diabetes.joblib",
            "24mths": "Logistic regression_diabetes_24mths.joblib",
            "36mths": "Logistic regression_diabetes_36mths.joblib",
            "48mths": "Logistic regression_diabetes_48mths.joblib",
        },
        "foundation": {
            "diabetes": "FM_Logistic regression_diabetes.joblib",
            "24mths": "FM_Logistic regression_diabetes_24mths.joblib",
            "36mths": "FM_Logistic regression_diabetes_36mths.joblib",
            "48mths": "FM_Logistic regression_diabetes_48mths.joblib",
        },
    }
# Feature columns, in the order the models expect them.
# IMPORTANT: this order must match the training data.
# The spellings (e.g. 'Famly_hist_diabetes', 'beta_blokers_drug') are kept
# exactly as they are -- presumably they mirror the training column names.
EXPECTED_COLUMNS = [
    'sex', 'race', 'ethnicity', 'marital_status',
    'Prior_Mean_Glu',
    # PT_ELX_GRP_1..10 and PT_ELX_GRP_13..31; 11 and 12 are not in the list.
    *(f'PT_ELX_GRP_{group}' for group in (*range(1, 11), *range(13, 32))),
    'MOF', 'SDOH', 'Gallstone',
    # *_drug columns
    'acei_drug', 'statin_drug', 'diuretic_drug', 'antiplatelet_drug',
    'anticoagulant_drug', 'nsaid_drug', 'ppi_drug', 'beta_blokers_drug',
    'vasodilators_drug', 'caaa_drug', 'ccb_drug', 'paaab_drug',
    'age', 'BMI', 'Body_weight', 'SBP', 'DBP',
    # Mean_* columns
    'Mean_AST', 'Mean_ALT', 'Mean_TBIL', 'Mean_ALP', 'Mean_Hgb',
    'Mean_HCT', 'Mean_Cr', 'Mean_PLT', 'Mean_WBC', 'Mean_BUN',
    'Mean_AGAP', 'Mean_Protein',
    'Smoking', 'eGFR', 'ED_visits', 'LOS',
    'Prediabetes', 'Alcohol_use', 'Famly_hist_diabetes', 'NAFLD',
    'Hist_Gesta_diabetes', 'Pregnancy', 'numof_med_visits',
    'History_AP_necrosis', 'Necrosectomy', 'Steroids_drugs',
    'oral_contraceptive', 'cholelithiasis', 'acute_cholecystitis',
    'hypertriglyceridemia',
]
def _encode_features(raw):
    """Convert raw UI values into the numeric feature dict, keyed by column.

    Args:
        raw: mapping from column name to the value received from the UI. It
            must contain every name in ``EXPECTED_COLUMNS``.

    Returns:
        dict: one entry per column in ``EXPECTED_COLUMNS``. The four
        categorical columns hold their integer code; all others are floats.

    Raises:
        ValueError: if a categorical value is not a known label, or a numeric
            field is empty or not convertible to ``float``.
    """
    categorical = {
        'sex': SEX_MAP,
        'race': RACE_MAP,
        'ethnicity': ETHNICITY_MAP,
        'marital_status': MARITAL_STATUS_MAP,
    }
    encoded = {}
    for column in EXPECTED_COLUMNS:
        value = raw[column]
        mapping = categorical.get(column)
        try:
            encoded[column] = float(value) if mapping is None else mapping[value]
        except (KeyError, TypeError, ValueError):
            # Typically None from an unselected dropdown or blank number box.
            raise ValueError(
                f"missing or invalid value for '{column}': {value!r}"
            ) from None
    return encoded


def predict_diabetes(model_type, time_period, sex, race, ethnicity, marital_status, Prior_Mean_Glu,
                     PT_ELX_GRP_1, PT_ELX_GRP_2, PT_ELX_GRP_3, PT_ELX_GRP_4,
                     PT_ELX_GRP_5, PT_ELX_GRP_6, PT_ELX_GRP_7, PT_ELX_GRP_8,
                     PT_ELX_GRP_9, PT_ELX_GRP_10, PT_ELX_GRP_13, PT_ELX_GRP_14,
                     PT_ELX_GRP_15, PT_ELX_GRP_16, PT_ELX_GRP_17, PT_ELX_GRP_18,
                     PT_ELX_GRP_19, PT_ELX_GRP_20, PT_ELX_GRP_21, PT_ELX_GRP_22,
                     PT_ELX_GRP_23, PT_ELX_GRP_24, PT_ELX_GRP_25, PT_ELX_GRP_26,
                     PT_ELX_GRP_27, PT_ELX_GRP_28, PT_ELX_GRP_29, PT_ELX_GRP_30,
                     PT_ELX_GRP_31, MOF, SDOH, Gallstone, acei_drug, statin_drug,
                     diuretic_drug, antiplatelet_drug, anticoagulant_drug,
                     nsaid_drug, ppi_drug, beta_blokers_drug, vasodilators_drug,
                     caaa_drug, ccb_drug, paaab_drug, age, BMI, Body_weight,
                     SBP, DBP, Mean_AST, Mean_ALT, Mean_TBIL, Mean_ALP,
                     Mean_Hgb, Mean_HCT, Mean_Cr, Mean_PLT, Mean_WBC, Mean_BUN,
                     Mean_AGAP, Mean_Protein, Smoking, eGFR, ED_visits, LOS,
                     Prediabetes, Alcohol_use, Famly_hist_diabetes, NAFLD,
                     Hist_Gesta_diabetes, Pregnancy, numof_med_visits,
                     History_AP_necrosis, Necrosectomy, Steroids_drugs,
                     oral_contraceptive, cholelithiasis, acute_cholecystitis,
                     hypertriglyceridemia):
    """Run the selected model on one patient record and describe the result.

    Args:
        model_type: model family key, e.g. ``"classical"`` or ``"foundation"``.
        time_period: key of the model within the family, e.g. ``"24mths"``.
        sex, race, ethnicity, marital_status: labels that are keys of
            ``SEX_MAP``, ``RACE_MAP``, ``ETHNICITY_MAP`` and
            ``MARITAL_STATUS_MAP`` respectively.
        All remaining parameters: values convertible to ``float``. Their
            names and order are those of ``EXPECTED_COLUMNS``.

    Returns:
        str: ``"Prediction: Positive for Diabetes"`` when the model predicts
        ``1``, ``"Prediction: Negative for Diabetes"`` for any other
        prediction, or a message describing what went wrong. This function
        reports problems through its return value instead of raising.
    """
    # Snapshot the arguments before any other local exists, so the feature
    # row can be assembled by column name rather than spelled out by hand.
    raw_inputs = dict(locals())

    if not model_type or not time_period:
        return "Please select both model type and time period."

    # Look up defensively: an unknown key or an empty catalogue must produce
    # the "not found" message rather than an unhandled KeyError/TypeError.
    catalogue = get_available_models().get(model_type)
    model_name = catalogue.get(time_period) if isinstance(catalogue, dict) else None
    if not model_name:
        return "Selected model not found. Please check the model type and time period."

    model_path = os.path.join(MODEL_DIR, model_name)
    if not os.path.exists(model_path):
        return f"Model file {model_name} not found in {MODEL_DIR}."
    try:
        # joblib.load unpickles the file; only trusted model files belong here.
        model = joblib.load(model_path)
    except Exception as e:
        return f"Error loading model: {e}"

    # Prepare data for prediction.
    try:
        input_data = _encode_features(raw_inputs)
    except ValueError as e:
        return f"Invalid input: {e}"

    # Create the DataFrame with the columns in the expected order.
    try:
        df = pd.DataFrame([input_data], columns=EXPECTED_COLUMNS)
    except Exception as e:
        return f"Error creating DataFrame: {e}. Check EXPECTED_COLUMNS and input_data keys."

    # Make the prediction.
    try:
        if model_type == "foundation":
            try:
                # Imported up front so a missing package is reported clearly.
                import tabpfn  # noqa: F401

                embedder_path = os.path.join(
                    MODEL_DIR, "FM", "TabPFN_model_chunk_0.joblib"
                )
                clf = joblib.load(embedder_path)
                # The downstream model is fed embeddings, not raw features.
                X = clf.get_embeddings(df)
                # NOTE(review): assumes the embeddings of this single row
                # flatten to exactly 768 values -- confirm against training.
                X = X.reshape(768, -1)
                X = pd.DataFrame(data=X.T)
                prediction = model.predict(X)
            except Exception as e:
                return f"Error in foundation model preprocessing: {e}"
        else:
            # Classical models take the feature row directly.
            prediction = model.predict(df)

        # Convert the prediction to human-readable output.
        if prediction[0] == 1:
            return "Prediction: Positive for Diabetes"
        return "Prediction: Negative for Diabetes"
    except Exception as e:
        return f"Error during prediction: {e}"
# Define Gradio inputs.
#
# Gradio hands the component values to the prediction function positionally,
# so the components MUST be listed in exactly the order of predict_diabetes'
# parameters: model type, time period, the four categorical fields, then the
# numeric fields in feature order. Listing Age / BMI / Smoking ahead of
# "Prior Mean Glu" (as was done before) silently shifts the values into the
# wrong features without raising any error.
_NUMERIC_INPUT_LABELS = [
    "Prior Mean Glu",
    # PT_ELX_GRP_1..10 and PT_ELX_GRP_13..31
    *(f"PT_ELX_GRP_{group}" for group in (*range(1, 11), *range(13, 32))),
    "MOF",
    "SDOH",
    "Gallstone",
    "ACE Inhibitor Drug",
    "Statin Drug",
    "Diuretic Drug",
    "Antiplatelet Drug",
    "Anticoagulant Drug",
    "NSAID Drug",
    "PPI Drug",
    "Beta Blockers Drug",
    "Vasodilators Drug",
    "CAAA Drug",
    "CCB Drug",
    "PAAAB Drug",
    "Age",
    "BMI",
    "Body Weight (kg)",
    "SBP (Systolic Blood Pressure)",
    "DBP (Diastolic Blood Pressure)",
    "Mean AST",
    "Mean ALT",
    "Mean TBIL",
    "Mean ALP",
    "Mean Hgb",
    "Mean HCT",
    "Mean Cr",
    "Mean PLT",
    "Mean WBC",
    "Mean BUN",
    "Mean AGAP",
    "Mean Protein",
    "Smoking",
    "eGFR",
    "ED Visits",
    "LOS (Length of Stay)",
    "Prediabetes",
    "Alcohol Use",
    "Family History of Diabetes",
    "NAFLD",
    "History of Gestational Diabetes",
    "Pregnancy",
    "Number of Medical Visits",
    "History AP Necrosis",
    "Necrosectomy",
    "Steroids Drugs",
    "Oral Contraceptive",
    "Cholelithiasis",
    "Acute Cholecystitis",
    "Hypertriglyceridemia",
]

inputs = [
    gr.Dropdown(choices=["classical", "foundation"], label="Model Type"),
    gr.Dropdown(choices=["diabetes", "24mths", "36mths", "48mths"], label="Time Period"),
    gr.Dropdown(choices=SEX_CHOICES, label="Sex"),
    gr.Dropdown(choices=RACE_CHOICES, label="Race"),
    gr.Dropdown(choices=ETHNICITY_CHOICES, label="Ethnicity"),
    gr.Dropdown(choices=MARITAL_STATUS_CHOICES, label="Marital Status"),
    # One numeric box per remaining feature, in feature order.
    *(gr.Number(label=text) for text in _NUMERIC_INPUT_LABELS),
]
# Text box that shows the string returned by the prediction function.
output = gr.Textbox(label="Prediction Result")

# Build the interface: prediction function, its input components, and the
# output component, followed by the page text and flagging setting.
iface = gr.Interface(
    predict_diabetes,
    inputs,
    output,
    title="Diabetes Prediction",
    description="Enter patient data to predict diabetes. Ensure your models are in the 'models' directory.",
    allow_flagging='never',
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()