vafaei_ar committed on
Commit
813cf60
·
1 Parent(s): 01529ed

FM selection and model added.

Browse files
Files changed (1) hide show
  1. app.py +68 -15
app.py CHANGED
@@ -26,14 +26,41 @@ MARITAL_STATUS_CHOICES = list(MARITAL_STATUS_MAP.keys())
26
 
27
  MODEL_DIR = "./models"
28
 
 
 
 
 
 
 
 
 
 
29
  def get_available_models():
30
  if not os.path.exists(MODEL_DIR):
31
  os.makedirs(MODEL_DIR) # Create models directory if it doesn't exist
32
- return ["No models found. Please add .joblib models to the 'models' directory."]
 
33
  models = [f for f in os.listdir(MODEL_DIR) if f.endswith(".joblib")]
34
  if not models:
35
- return ["No models found. Please add .joblib models to the 'models' directory."]
36
- return models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # Define all features in the order your model expects them
39
  # IMPORTANT: This order must match the training data
@@ -60,7 +87,7 @@ EXPECTED_COLUMNS = [
60
  'hypertriglyceridemia'
61
  ]
62
 
63
- def predict_diabetes(model_name, sex, race, ethnicity, marital_status, Prior_Mean_Glu,
64
  PT_ELX_GRP_1, PT_ELX_GRP_2, PT_ELX_GRP_3, PT_ELX_GRP_4,
65
  PT_ELX_GRP_5, PT_ELX_GRP_6, PT_ELX_GRP_7, PT_ELX_GRP_8,
66
  PT_ELX_GRP_9, PT_ELX_GRP_10, PT_ELX_GRP_13, PT_ELX_GRP_14,
@@ -81,8 +108,14 @@ def predict_diabetes(model_name, sex, race, ethnicity, marital_status, Prior_Mea
81
  oral_contraceptive, cholelithiasis, acute_cholecystitis,
82
  hypertriglyceridemia):
83
 
84
- if not model_name or "No models found" in model_name:
85
- return "Please select a valid model from the 'models/' directory."
 
 
 
 
 
 
86
 
87
  model_path = os.path.join(MODEL_DIR, model_name)
88
  if not os.path.exists(model_path):
@@ -138,11 +171,30 @@ def predict_diabetes(model_name, sex, race, ethnicity, marital_status, Prior_Mea
138
 
139
  # Make prediction
140
  try:
141
- prediction = model.predict(df)
142
- # You might need to access the first element if prediction is an array
143
- # e.g., result = prediction[0]
144
- # Also, convert to a more human-readable output
145
- result = prediction[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  if result == 1:
147
  return "Prediction: Positive for Diabetes"
148
  else:
@@ -152,11 +204,15 @@ def predict_diabetes(model_name, sex, race, ethnicity, marital_status, Prior_Mea
152
 
153
  # Define Gradio inputs
154
  inputs = [
155
- gr.Dropdown(choices=get_available_models(), label="Select Model"),
 
156
  gr.Dropdown(choices=SEX_CHOICES, label="Sex"),
157
  gr.Dropdown(choices=RACE_CHOICES, label="Race"),
158
  gr.Dropdown(choices=ETHNICITY_CHOICES, label="Ethnicity"),
159
  gr.Dropdown(choices=MARITAL_STATUS_CHOICES, label="Marital Status"),
 
 
 
160
  gr.Number(label="Prior Mean Glu"),
161
  gr.Number(label="PT_ELX_GRP_1"),
162
  gr.Number(label="PT_ELX_GRP_2"),
@@ -202,8 +258,6 @@ inputs = [
202
  gr.Number(label="CAAA Drug"),
203
  gr.Number(label="CCB Drug"),
204
  gr.Number(label="PAAAB Drug"),
205
- gr.Number(label="Age"),
206
- gr.Number(label="BMI"),
207
  gr.Number(label="Body Weight (kg)"),
208
  gr.Number(label="SBP (Systolic Blood Pressure)"),
209
  gr.Number(label="DBP (Diastolic Blood Pressure)"),
@@ -219,7 +273,6 @@ inputs = [
219
  gr.Number(label="Mean BUN"),
220
  gr.Number(label="Mean AGAP"),
221
  gr.Number(label="Mean Protein"),
222
- gr.Number(label="Smoking"),
223
  gr.Number(label="eGFR"),
224
  gr.Number(label="ED Visits"),
225
  gr.Number(label="LOS (Length of Stay)"),
 
26
 
27
  MODEL_DIR = "./models"
28
 
29
+ # def get_available_models():
30
+ # if not os.path.exists(MODEL_DIR):
31
+ # os.makedirs(MODEL_DIR) # Create models directory if it doesn't exist
32
+ # return ["No models found. Please add .joblib models to the 'models' directory."]
33
+ # models = [f for f in os.listdir(MODEL_DIR) if f.endswith(".joblib")]
34
+ # if not models:
35
+ # return ["No models found. Please add .joblib models to the 'models' directory."]
36
+ # return models
37
+
def get_available_models(model_dir=None):
    """Return the known model files grouped by model type and time period.

    The result always has the same shape, so callers can index it as
    ``result[model_type][time_period]`` without checking which branch
    produced it::

        {
            "classical":  {"diabetes": <name or None>, "24mths": ..., ...},
            "foundation": {"diabetes": <name or None>, "24mths": ..., ...},
        }

    A value is the ``.joblib`` file name when that file is present in the
    model directory, and ``None`` when it is not. A falsy value therefore
    means "selected model not found".

    Args:
        model_dir: Directory to scan. Defaults to the module-level
            ``MODEL_DIR`` when omitted.

    Returns:
        dict: Nested mapping ``model type -> time period -> file name or None``.
    """
    # Resolve the default at call time so MODEL_DIR remains the single
    # source of truth for callers that pass nothing.
    if model_dir is None:
        model_dir = MODEL_DIR

    # Create the models directory if it doesn't exist. exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(model_dir, exist_ok=True)

    # A set gives O(1) membership tests for the lookups below.
    present = {
        name for name in os.listdir(model_dir) if name.endswith(".joblib")
    }

    # Expected file name for every (model type, time period) combination.
    expected = {
        "classical": {
            "diabetes": "Logistic regression_diabetes.joblib",
            "24mths": "Logistic regression_diabetes_24mths.joblib",
            "36mths": "Logistic regression_diabetes_36mths.joblib",
            "48mths": "Logistic regression_diabetes_48mths.joblib",
        },
        "foundation": {
            "diabetes": "FM_Logistic regression_diabetes.joblib",
            "24mths": "FM_Logistic regression_diabetes_24mths.joblib",
            "36mths": "FM_Logistic regression_diabetes_36mths.joblib",
            "48mths": "FM_Logistic regression_diabetes_48mths.joblib",
        },
    }

    # Keep only the names that are actually on disk; missing ones become None
    # instead of changing the container type (the old empty-directory result
    # used lists, which broke string-key indexing).
    return {
        model_type: {
            period: (file_name if file_name in present else None)
            for period, file_name in by_period.items()
        }
        for model_type, by_period in expected.items()
    }
64
 
65
  # Define all features in the order your model expects them
66
  # IMPORTANT: This order must match the training data
 
87
  'hypertriglyceridemia'
88
  ]
89
 
90
+ def predict_diabetes(model_type, time_period, sex, race, ethnicity, marital_status, Prior_Mean_Glu,
91
  PT_ELX_GRP_1, PT_ELX_GRP_2, PT_ELX_GRP_3, PT_ELX_GRP_4,
92
  PT_ELX_GRP_5, PT_ELX_GRP_6, PT_ELX_GRP_7, PT_ELX_GRP_8,
93
  PT_ELX_GRP_9, PT_ELX_GRP_10, PT_ELX_GRP_13, PT_ELX_GRP_14,
 
108
  oral_contraceptive, cholelithiasis, acute_cholecystitis,
109
  hypertriglyceridemia):
110
 
111
+ if not model_type or not time_period:
112
+ return "Please select both model type and time period."
113
+
114
+ model_dict = get_available_models()
115
+ model_name = model_dict[model_type][time_period]
116
+
117
+ if not model_name:
118
+ return "Selected model not found. Please check the model type and time period."
119
 
120
  model_path = os.path.join(MODEL_DIR, model_name)
121
  if not os.path.exists(model_path):
 
171
 
172
  # Make prediction
173
  try:
174
+ if model_type == "foundation":
175
+ # Load the TabPFN model for preprocessing
176
+ try:
177
+ import numpy as np
178
+ import tabpfn
179
+ clf = joblib.load('models/FM/TabPFN_model_chunk_0.joblib')
180
+ # Get embeddings for the input data
181
+ X = clf.get_embeddings(df)
182
+ print(X.shape)
183
+ # X = np.concatenate(X,axis=1)
184
+ # X = np.swapaxes(X,0,1)
185
+ X = X.reshape(768 ,-1)
186
+ print(X.shape)
187
+ X = pd.DataFrame(data=X.T)
188
+ # Make prediction using the processed data
189
+ prediction = model.predict(X)
190
+ except Exception as e:
191
+ return f"Error in foundation model preprocessing: {e}"
192
+ else:
193
+ # For classical models, use the data directly
194
+ prediction = model.predict(df)
195
+
196
+ # Convert prediction to human-readable output
197
+ result = prediction[0]
198
  if result == 1:
199
  return "Prediction: Positive for Diabetes"
200
  else:
 
204
 
205
  # Define Gradio inputs
206
  inputs = [
207
+ gr.Dropdown(choices=["classical", "foundation"], label="Model Type"),
208
+ gr.Dropdown(choices=["diabetes", "24mths", "36mths", "48mths"], label="Time Period"),
209
  gr.Dropdown(choices=SEX_CHOICES, label="Sex"),
210
  gr.Dropdown(choices=RACE_CHOICES, label="Race"),
211
  gr.Dropdown(choices=ETHNICITY_CHOICES, label="Ethnicity"),
212
  gr.Dropdown(choices=MARITAL_STATUS_CHOICES, label="Marital Status"),
213
+ gr.Number(label="Age"),
214
+ gr.Number(label="BMI"),
215
+ gr.Number(label="Smoking"),
216
  gr.Number(label="Prior Mean Glu"),
217
  gr.Number(label="PT_ELX_GRP_1"),
218
  gr.Number(label="PT_ELX_GRP_2"),
 
258
  gr.Number(label="CAAA Drug"),
259
  gr.Number(label="CCB Drug"),
260
  gr.Number(label="PAAAB Drug"),
 
 
261
  gr.Number(label="Body Weight (kg)"),
262
  gr.Number(label="SBP (Systolic Blood Pressure)"),
263
  gr.Number(label="DBP (Diastolic Blood Pressure)"),
 
273
  gr.Number(label="Mean BUN"),
274
  gr.Number(label="Mean AGAP"),
275
  gr.Number(label="Mean Protein"),
 
276
  gr.Number(label="eGFR"),
277
  gr.Number(label="ED Visits"),
278
  gr.Number(label="LOS (Length of Stay)"),