mfarnas committed on
Commit
ad12767
·
1 Parent(s): e3a752a

clean debug

Browse files
src/model_utils.py CHANGED
@@ -5,12 +5,12 @@ from catboost import CatBoostClassifier
5
  # from lightgbm import LGBMClassifier
6
  from sklearn.ensemble import RandomForestClassifier
7
 
8
- # MODEL_DIR = Path("saved_models")
9
  # MODEL_DIR.mkdir(exist_ok=True)
10
 
11
  import yaml
12
 
13
- def load_model_params(model_type, mode="ensemble", path=Path("src/params") / "model_params.yaml"):
14
  if mode not in ["ensemble", "single_model"]:
15
  raise ValueError("mode must be either 'ensemble' or 'single_model'")
16
 
@@ -46,28 +46,9 @@ def get_model(model_type, mode="ensemble", best_iter=None):
46
  else:
47
  raise ValueError(f"Unsupported model type: {model_type}")
48
 
49
- # def save_model(model, user_model_name, metrics_result_single=None):
50
- # timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
51
- # filename = f"{timestamp}_{user_model_name}_single.pkl"
52
- # filepath = MODEL_DIR / filename
53
-
54
- # single_model_data = {
55
- # "timestamp": timestamp,
56
- # "model_name": user_model_name,
57
- # "target_col": st.session_state.target_col if "target_col" in st.session_state else "UNKNOWN",
58
- # "model": model,
59
- # "best_iteration": st.session_state.best_iteration,
60
- # "metrics_result_single": metrics_result_single
61
- # }
62
-
63
- # with open(filepath, "wb") as f:
64
- # pickle.dump(single_model_data, f)
65
- # return filename
66
-
67
  def save_model(model, user_model_name, metrics_result_single=None):
68
  from datetime import datetime
69
  import io
70
- # import uuid
71
  import pickle
72
  import json
73
  import pyarrow as pa
@@ -135,33 +116,11 @@ def save_model(model, user_model_name, metrics_result_single=None):
135
  path_or_fileobj=buf
136
  )
137
 
138
- print('filename SAVEEEEEEEEEEEEE', filename)
139
- st.warning(f'SAVEEEEEEEEEEEEE {filename}')
140
  return filename
141
 
142
- # def save_model_ensemble(models, user_model_name, best_iterations=None, fold_scores=None, metrics_result_ensemble=None):
143
- # timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
144
- # filename = f"{timestamp}_{user_model_name}_ensemble.pkl"
145
- # filepath = MODEL_DIR / filename
146
-
147
- # ensemble_data = {
148
- # "timestamp": timestamp,
149
- # "model_name": user_model_name,
150
- # "target_col": st.session_state.target_col if "target_col" in st.session_state else "UNKNOWN",
151
- # "models": models,
152
- # "best_iterations": best_iterations,
153
- # "fold_scores": fold_scores,
154
- # "metrics_result_ensemble": metrics_result_ensemble
155
- # }
156
-
157
- # with open(filepath, "wb") as f:
158
- # pickle.dump(ensemble_data, f)
159
- # return filename
160
-
161
  def save_model_ensemble(models, user_model_name, best_iterations=None, fold_scores=None, metrics_result_ensemble=None):
162
  from datetime import datetime
163
  import io
164
- # import uuid
165
  import pickle
166
  import json
167
  import pyarrow as pa
@@ -225,21 +184,8 @@ def save_model_ensemble(models, user_model_name, best_iterations=None, fold_scor
225
  path_or_fileobj=buf
226
  )
227
 
228
- print('filename SAVEEEEEEEEEEEEEEEE', filename)
229
- st.warning(f'SAVEEEEEEEEEEEEEEEEEEE {filename}')
230
  return filename
231
 
232
-
233
- # def load_model(model_name):
234
- # filepath = MODEL_DIR / f"{model_name}.pkl"
235
- # if not filepath.exists():
236
- # raise FileNotFoundError(f"Model file not found: {filepath}")
237
-
238
- # with open(filepath, "rb") as f:
239
- # single_model_data = pickle.load(f)
240
-
241
- # return single_model_data
242
-
243
  def load_model(model_name):
244
  from huggingface_hub import login, hf_hub_download
245
  import pyarrow.parquet as pq
@@ -250,18 +196,6 @@ def load_model(model_name):
250
  if "HF_TOKEN" in os.environ:
251
  login(token=os.environ["HF_TOKEN"])
252
 
253
- # files = hf_hub_download(
254
- # repo_id=os.environ["HF_REPO_ID"],
255
- # repo_type="dataset",
256
- # token=os.environ["HF_TOKEN"],
257
- # filename=None, # Get whole repo listing
258
- # cache_dir=None,
259
- # local_dir=None,
260
- # local_dir_use_symlinks=False,
261
- # force_download=False,
262
- # resume_download=True
263
- # )
264
-
265
  from huggingface_hub import HfApi
266
  api = HfApi(token=os.environ["HF_TOKEN"])
267
  all_files = api.list_repo_files(repo_id=os.environ["HF_REPO_ID"], repo_type="dataset")
@@ -277,8 +211,6 @@ def load_model(model_name):
277
  token=os.environ["HF_TOKEN"]
278
  )
279
  table = pq.read_table(downloaded)
280
- print("tableeeeeee")
281
- st.dataframe(table)
282
  row = table.to_pylist()[0]
283
  if row["filename"] == model_name.replace("parquet", "pkl"):
284
  target_file = downloaded
@@ -289,26 +221,11 @@ def load_model(model_name):
289
 
290
  model_bytes = row["model_file"]["bytes"]
291
 
292
- print("LOADDDDDDDDDDDDDDDDDDDDDDDDDD")
293
- print('row["filename"]', row["filename"])
294
- print('model_name.replace("parquet", "pkl")', model_name.replace("parquet", "pkl"))
295
  return pickle.loads(model_bytes)
296
 
297
-
298
- # def load_model_ensemble(filename):
299
- # filepath = MODEL_DIR / f"{filename}.pkl"
300
- # if not filepath.exists():
301
- # raise FileNotFoundError(f"Model file not found: {filepath}")
302
-
303
- # with open(filepath, "rb") as f:
304
- # ensemble_data = pickle.load(f)
305
-
306
- # return ensemble_data
307
-
308
  def load_model_ensemble(filename):
309
  return load_model(filename)
310
 
311
-
312
  def ensemble_predict(models, X, cat_features):
313
  preds = sum([model.predict_proba(X)[:, 1] for model in models]) / len(models)
314
  return preds
 
5
  # from lightgbm import LGBMClassifier
6
  from sklearn.ensemble import RandomForestClassifier
7
 
8
+ MODEL_DIR = Path("src/params")
9
  # MODEL_DIR.mkdir(exist_ok=True)
10
 
11
  import yaml
12
 
13
+ def load_model_params(model_type, mode="ensemble", path=MODEL_DIR / "model_params.yaml"):
14
  if mode not in ["ensemble", "single_model"]:
15
  raise ValueError("mode must be either 'ensemble' or 'single_model'")
16
 
 
46
  else:
47
  raise ValueError(f"Unsupported model type: {model_type}")
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def save_model(model, user_model_name, metrics_result_single=None):
50
  from datetime import datetime
51
  import io
 
52
  import pickle
53
  import json
54
  import pyarrow as pa
 
116
  path_or_fileobj=buf
117
  )
118
 
 
 
119
  return filename
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  def save_model_ensemble(models, user_model_name, best_iterations=None, fold_scores=None, metrics_result_ensemble=None):
122
  from datetime import datetime
123
  import io
 
124
  import pickle
125
  import json
126
  import pyarrow as pa
 
184
  path_or_fileobj=buf
185
  )
186
 
 
 
187
  return filename
188
 
 
 
 
 
 
 
 
 
 
 
 
189
  def load_model(model_name):
190
  from huggingface_hub import login, hf_hub_download
191
  import pyarrow.parquet as pq
 
196
  if "HF_TOKEN" in os.environ:
197
  login(token=os.environ["HF_TOKEN"])
198
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  from huggingface_hub import HfApi
200
  api = HfApi(token=os.environ["HF_TOKEN"])
201
  all_files = api.list_repo_files(repo_id=os.environ["HF_REPO_ID"], repo_type="dataset")
 
211
  token=os.environ["HF_TOKEN"]
212
  )
213
  table = pq.read_table(downloaded)
 
 
214
  row = table.to_pylist()[0]
215
  if row["filename"] == model_name.replace("parquet", "pkl"):
216
  target_file = downloaded
 
221
 
222
  model_bytes = row["model_file"]["bytes"]
223
 
 
 
 
224
  return pickle.loads(model_bytes)
225
 
 
 
 
 
 
 
 
 
 
 
 
226
  def load_model_ensemble(filename):
227
  return load_model(filename)
228
 
 
229
  def ensemble_predict(models, X, cat_features):
230
  preds = sum([model.predict_proba(X)[:, 1] for model in models]) / len(models)
231
  return preds
src/model_utils_ori.py DELETED
@@ -1,114 +0,0 @@
1
- import streamlit as st
2
- import pickle
3
- import catboost
4
- from datetime import datetime
5
- from pathlib import Path
6
-
7
- from catboost import CatBoostClassifier
8
- # from xgboost import XGBClassifier
9
- # from lightgbm import LGBMClassifier
10
- from sklearn.ensemble import RandomForestClassifier
11
-
12
- MODEL_DIR = Path("saved_models")
13
- MODEL_DIR.mkdir(exist_ok=True)
14
-
15
- import yaml
16
-
17
- def load_model_params(model_type, mode="ensemble", path=Path("params") / "model_params.yaml"):
18
- if mode not in ["ensemble", "single_model"]:
19
- raise ValueError("mode must be either 'ensemble' or 'single_model'")
20
-
21
- if model_type not in ["CatBoost", "XGBoost", "LightGBM", "RandomForest"]:
22
- raise ValueError("model_type must be one of 'CatBoost', 'XGBoost', 'LightGBM', or 'RandomForest'")
23
-
24
- with open(path, "r") as f:
25
- all_params = yaml.safe_load(f)
26
-
27
- params = all_params[model_type][mode]
28
- if "random_seed" in params:
29
- st.session_state.random_seed = params["random_seed"]
30
-
31
- return params
32
-
33
- def get_model(model_type, mode="ensemble", best_iter=None):
34
- params = load_model_params(model_type, mode)
35
-
36
- # iter is set for single_model mode, where
37
- if best_iter is not None:
38
- params['iterations'] = best_iter
39
- # if "random_seed" in st.session_state:
40
- # random_seed = st.session_state.random_seed
41
-
42
- if model_type == "CatBoost":
43
- return CatBoostClassifier(**params)
44
- # elif model_type == "XGBoost":
45
- # return XGBClassifier(**params, use_label_encoder=False, eval_metric="logloss")
46
- # elif model_type == "LightGBM":
47
- # return LGBMClassifier(**params)
48
- elif model_type == "RandomForest":
49
- return RandomForestClassifier(**params)
50
- else:
51
- raise ValueError(f"Unsupported model type: {model_type}")
52
-
53
- def save_model(model, user_model_name, metrics_result_single=None):
54
- timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
55
- filename = f"{timestamp}_{user_model_name}_single.pkl"
56
- filepath = MODEL_DIR / filename
57
-
58
- single_model_data = {
59
- "timestamp": timestamp,
60
- "model_name": user_model_name,
61
- "target_col": st.session_state.target_col if "target_col" in st.session_state else "UNKNOWN",
62
- "model": model,
63
- "best_iteration": st.session_state.best_iteration,
64
- "metrics_result_single": metrics_result_single
65
- }
66
-
67
- with open(filepath, "wb") as f:
68
- pickle.dump(single_model_data, f)
69
- return filename
70
-
71
- def load_model(model_name):
72
- filepath = MODEL_DIR / f"{model_name}.pkl"
73
- if not filepath.exists():
74
- raise FileNotFoundError(f"Model file not found: {filepath}")
75
-
76
- with open(filepath, "rb") as f:
77
- single_model_data = pickle.load(f)
78
-
79
- return single_model_data
80
-
81
- def save_model_ensemble(models, user_model_name, best_iterations=None, fold_scores=None, metrics_result_ensemble=None):
82
- timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
83
- filename = f"{timestamp}_{user_model_name}_ensemble.pkl"
84
- filepath = MODEL_DIR / filename
85
-
86
- ensemble_data = {
87
- "timestamp": timestamp,
88
- "model_name": user_model_name,
89
- "target_col": st.session_state.target_col if "target_col" in st.session_state else "UNKNOWN",
90
- "models": models,
91
- "best_iterations": best_iterations,
92
- "fold_scores": fold_scores,
93
- "metrics_result_ensemble": metrics_result_ensemble
94
- }
95
-
96
- with open(filepath, "wb") as f:
97
- pickle.dump(ensemble_data, f)
98
- return filename
99
-
100
-
101
- def load_model_ensemble(filename):
102
- filepath = MODEL_DIR / f"{filename}.pkl"
103
- if not filepath.exists():
104
- raise FileNotFoundError(f"Model file not found: {filepath}")
105
-
106
- with open(filepath, "rb") as f:
107
- ensemble_data = pickle.load(f)
108
-
109
- return ensemble_data
110
-
111
-
112
- def ensemble_predict(models, X, cat_features):
113
- preds = sum([model.predict_proba(X)[:, 1] for model in models]) / len(models)
114
- return preds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/pages/1_Individual_Predictions.py CHANGED
@@ -90,7 +90,6 @@ with st.form("individual_form"):
90
  'CYTARABINE',
91
  'ETOPOSIDE',
92
  'FLUDARABINE',
93
- # 'GEMCITABIBE',
94
  'GEMCITABINE',
95
  'MELPHALAN',
96
  'METHOTREXATE',
@@ -187,9 +186,7 @@ if submitted:
187
  pred = 0.0
188
  else:
189
  if "ensemble" in st.session_state.selected_model:
190
- # ensemble = True
191
- # ensemble model prediction
192
- # if ensemble:
193
  models = load_model_ensemble(st.session_state.selected_model)
194
  models = models["model"]
195
  pred = ensemble_predict(models, X, cat_features)
 
90
  'CYTARABINE',
91
  'ETOPOSIDE',
92
  'FLUDARABINE',
 
93
  'GEMCITABINE',
94
  'MELPHALAN',
95
  'METHOTREXATE',
 
186
  pred = 0.0
187
  else:
188
  if "ensemble" in st.session_state.selected_model:
189
+ # ensemble prediction
 
 
190
  models = load_model_ensemble(st.session_state.selected_model)
191
  models = models["model"]
192
  pred = ensemble_predict(models, X, cat_features)
src/pages/2_Bulk_Predictions.py CHANGED
@@ -11,8 +11,6 @@ sidebar()
11
 
12
  st.title("📊 Bulk Patient Predictions")
13
 
14
- # training_preproc_cols = []
15
-
16
  uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
17
  if uploaded_file:
18
  df = pd.read_csv(uploaded_file, header=1)
@@ -21,7 +19,6 @@ if uploaded_file:
21
 
22
  if st.button("Preprocess"):
23
  df_proc = preprocess(df)
24
- # print('df_proc', df_proc.columns) # Debugging line to check processed columns
25
  edited_df = st.data_editor(df_proc, num_rows="dynamic")
26
  st.session_state.bulk_input_df = edited_df
27
 
@@ -39,7 +36,6 @@ if uploaded_file:
39
  st.session_state.best_iterations = ensemble_data.get("best_iterations", [])
40
  st.session_state.fold_scores = ensemble_data.get("fold_scores", [])
41
 
42
- # st.success(f"Loaded ensemble: {ensemble_data['model_name']} from {ensemble_data['timestamp']}")
43
  except Exception as e:
44
  st.error(f"Error loading ensemble: {str(e)}")
45
  else:
 
11
 
12
  st.title("📊 Bulk Patient Predictions")
13
 
 
 
14
  uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
15
  if uploaded_file:
16
  df = pd.read_csv(uploaded_file, header=1)
 
19
 
20
  if st.button("Preprocess"):
21
  df_proc = preprocess(df)
 
22
  edited_df = st.data_editor(df_proc, num_rows="dynamic")
23
  st.session_state.bulk_input_df = edited_df
24
 
 
36
  st.session_state.best_iterations = ensemble_data.get("best_iterations", [])
37
  st.session_state.fold_scores = ensemble_data.get("fold_scores", [])
38
 
 
39
  except Exception as e:
40
  st.error(f"Error loading ensemble: {str(e)}")
41
  else:
src/sidebar.py CHANGED
@@ -7,25 +7,6 @@ import pyarrow.parquet as pq
7
 
8
  st.session_state.orig_train_cols = ['EPI/ID numbers', 'Recipient_gender', 'Recepient_DOB', 'Recepient_Nationality', 'Hematological Diagnosis', 'Date of first diagnosis/BMBx date', 'Recepient_Blood group before HSCT', 'Donor_DOB', 'Donor_gender', 'D_Blood group', 'R_HLA_A', 'R_HLA_B', 'R_HLA_C', 'R_HLA_DR', 'R_HLA_DQ', 'D_HLA_A', 'D_HLA_B', 'D_HLA_C', 'D_HLA_DR', 'D_HLA_DQ', 'Number of lines of Rx before HSCT', 'PreHSCT conditioning regimen+/-ATG+/-TBI', 'HSCT_date', 'Source of cells', 'Donor_relation to recipient', 'HLA match ratio', 'Post HSCT regimen', 'First_GVHD prophylaxis', 'GVHD', 'Acute GVHD(<100 days)', 'Chronic GVHD>100 days', 'Acute+Chronic', 'GVHD severity', 'R_HLA_A1', 'R_HLA_A2', 'R_HLA_B1', 'R_HLA_B2', 'R_HLA_C1', 'R_HLA_C2', 'R_HLA_DR1', 'R_HLA_DR2', 'R_HLA_DQ1', 'R_HLA_DQ2', 'D_HLA_A1', 'D_HLA_A2', 'D_HLA_B1', 'D_HLA_B2', 'D_HLA_C1', 'D_HLA_C2', 'D_HLA_DR1', 'D_HLA_DR2', 'D_HLA_DQ1', 'D_HLA_DQ2', 'R_HLA_A_1', 'R_HLA_A_11', 'R_HLA_A_12', 'R_HLA_A_2', 'R_HLA_A_20', 'R_HLA_A_23', 'R_HLA_A_24', 'R_HLA_A_25', 'R_HLA_A_26', 'R_HLA_A_29', 'R_HLA_A_3', 'R_HLA_A_30', 'R_HLA_A_31', 'R_HLA_A_32', 'R_HLA_A_33', 'R_HLA_A_34', 'R_HLA_A_4', 'R_HLA_A_66', 'R_HLA_A_68', 'R_HLA_A_69', 'R_HLA_A_7', 'R_HLA_A_74', 'R_HLA_A_8', 'R_HLA_A_X', 'R_HLA_B_13', 'R_HLA_B_14', 'R_HLA_B_15', 'R_HLA_B_18', 'R_HLA_B_23', 'R_HLA_B_24', 'R_HLA_B_27', 'R_HLA_B_35', 'R_HLA_B_37', 'R_HLA_B_38', 'R_HLA_B_39', 'R_HLA_B_40', 'R_HLA_B_41', 'R_HLA_B_42', 'R_HLA_B_44', 'R_HLA_B_45', 'R_HLA_B_46', 'R_HLA_B_49', 'R_HLA_B_50', 'R_HLA_B_51', 'R_HLA_B_52', 'R_HLA_B_53', 'R_HLA_B_55', 'R_HLA_B_56', 'R_HLA_B_57', 'R_HLA_B_58', 'R_HLA_B_7', 'R_HLA_B_73', 'R_HLA_B_8', 'R_HLA_B_81', 'R_HLA_B_X', 'R_HLA_C_1', 'R_HLA_C_12', 'R_HLA_C_14', 'R_HLA_C_15', 'R_HLA_C_16', 'R_HLA_C_17', 'R_HLA_C_18', 'R_HLA_C_2', 'R_HLA_C_3', 'R_HLA_C_38', 'R_HLA_C_4', 'R_HLA_C_49', 'R_HLA_C_5', 'R_HLA_C_50', 'R_HLA_C_6', 'R_HLA_C_7', 'R_HLA_C_8', 'R_HLA_C_X', 'R_HLA_DR_1', 'R_HLA_DR_10', 'R_HLA_DR_11', 'R_HLA_DR_12', 
'R_HLA_DR_13', 'R_HLA_DR_14', 'R_HLA_DR_15', 'R_HLA_DR_16', 'R_HLA_DR_17', 'R_HLA_DR_2', 'R_HLA_DR_3', 'R_HLA_DR_4', 'R_HLA_DR_5', 'R_HLA_DR_6', 'R_HLA_DR_7', 'R_HLA_DR_8', 'R_HLA_DR_9', 'R_HLA_DR_X', 'R_HLA_DQ_1', 'R_HLA_DQ_11', 'R_HLA_DQ_15', 'R_HLA_DQ_16', 'R_HLA_DQ_2', 'R_HLA_DQ_3', 'R_HLA_DQ_301', 'R_HLA_DQ_4', 'R_HLA_DQ_5', 'R_HLA_DQ_6', 'R_HLA_DQ_7', 'R_HLA_DQ_X', 'D_HLA_A_1', 'D_HLA_A_11', 'D_HLA_A_12', 'D_HLA_A_2', 'D_HLA_A_23', 'D_HLA_A_24', 'D_HLA_A_25', 'D_HLA_A_26', 'D_HLA_A_29', 'D_HLA_A_3', 'D_HLA_A_30', 'D_HLA_A_31', 'D_HLA_A_32', 'D_HLA_A_33', 'D_HLA_A_34', 'D_HLA_A_66', 'D_HLA_A_68', 'D_HLA_A_69', 'D_HLA_A_7', 'D_HLA_A_74', 'D_HLA_A_8', 'D_HLA_A_X', 'D_HLA_B_13', 'D_HLA_B_14', 'D_HLA_B_15', 'D_HLA_B_17', 'D_HLA_B_18', 'D_HLA_B_23', 'D_HLA_B_24', 'D_HLA_B_27', 'D_HLA_B_35', 'D_HLA_B_37', 'D_HLA_B_38', 'D_HLA_B_39', 'D_HLA_B_40', 'D_HLA_B_41', 'D_HLA_B_42', 'D_HLA_B_44', 'D_HLA_B_45', 'D_HLA_B_48', 'D_HLA_B_49', 'D_HLA_B_50', 'D_HLA_B_51', 'D_HLA_B_52', 'D_HLA_B_53', 'D_HLA_B_55', 'D_HLA_B_56', 'D_HLA_B_57', 'D_HLA_B_58', 'D_HLA_B_7', 'D_HLA_B_73', 'D_HLA_B_8', 'D_HLA_B_81', 'D_HLA_B_X', 'D_HLA_C_1', 'D_HLA_C_12', 'D_HLA_C_14', 'D_HLA_C_15', 'D_HLA_C_16', 'D_HLA_C_17', 'D_HLA_C_18', 'D_HLA_C_2', 'D_HLA_C_3', 'D_HLA_C_38', 'D_HLA_C_4', 'D_HLA_C_49', 'D_HLA_C_5', 'D_HLA_C_50', 'D_HLA_C_6', 'D_HLA_C_7', 'D_HLA_C_8', 'D_HLA_C_X', 'D_HLA_DR_1', 'D_HLA_DR_10', 'D_HLA_DR_11', 'D_HLA_DR_12', 'D_HLA_DR_13', 'D_HLA_DR_14', 'D_HLA_DR_15', 'D_HLA_DR_16', 'D_HLA_DR_17', 'D_HLA_DR_2', 'D_HLA_DR_3', 'D_HLA_DR_4', 'D_HLA_DR_5', 'D_HLA_DR_6', 'D_HLA_DR_7', 'D_HLA_DR_8', 'D_HLA_DR_9', 'D_HLA_DR_X', 'D_HLA_DQ_1', 'D_HLA_DQ_11', 'D_HLA_DQ_15', 'D_HLA_DQ_16', 'D_HLA_DQ_2', 'D_HLA_DQ_3', 'D_HLA_DQ_301', 'D_HLA_DQ_4', 'D_HLA_DQ_5', 'D_HLA_DQ_6', 'D_HLA_DQ_7', 'D_HLA_DQ_X', 'Recepient_DOB_Year', 'Donor_DOB_Year', 'HSCT_date_Year', 'R_Age_at_transplant', 'D_Age_at_transplant', 'Age_Gap_R_D', 'PreHSCT_ALEMTUZUMAB', 'PreHSCT_ATG', 'PreHSCT_BEAM', 'PreHSCT_BUSULFAN', 
'PreHSCT_CAMPATH', 'PreHSCT_CARMUSTINE', 'PreHSCT_CLOFARABINE', 'PreHSCT_CYCLOPHOSPHAMIDE', 'PreHSCT_CYCLOSPORIN', 'PreHSCT_CYTARABINE', 'PreHSCT_ETOPOSIDE', 'PreHSCT_FLUDARABINE', 'PreHSCT_GEMCITABINE', 'PreHSCT_MELPHALAN', 'PreHSCT_METHOTREXATE', 'PreHSCT_OTHER', 'PreHSCT_RANIMUSTINE', 'PreHSCT_REDUCEDCONDITIONING', 'PreHSCT_RITUXIMAB', 'PreHSCT_SIROLIMUS', 'PreHSCT_TBI', 'PreHSCT_THIOTEPA', 'PreHSCT_TREOSULFAN', 'PreHSCT_UA', 'PreHSCT_VORNOSTAT', 'PreHSCT_X', 'First_GVHD_prophylaxis_ABATACEPT', 'First_GVHD_prophylaxis_ALEMTUZUMAB', 'First_GVHD_prophylaxis_ATG', 'First_GVHD_prophylaxis_CYCLOPHOSPHAMIDE', 'First_GVHD_prophylaxis_CYCLOSPORIN', 'First_GVHD_prophylaxis_IMATINIB', 'First_GVHD_prophylaxis_LEFLUNOMIDE', 'First_GVHD_prophylaxis_METHOTREXATE', 'First_GVHD_prophylaxis_MMF', 'First_GVHD_prophylaxis_NONE', 'First_GVHD_prophylaxis_RUXOLITINIB', 'First_GVHD_prophylaxis_SIROLIMUS', 'First_GVHD_prophylaxis_STEROID', 'First_GVHD_prophylaxis_TAC', 'First_GVHD_prophylaxis_TACROLIMUS', 'First_GVHD_prophylaxis_X', 'Recepient_Blood group before HSCT_MergePlusMinus', 'D_Blood group_MergePlusMinus', 'R_Age_at_transplant_cutoff16', 'R_Age_at_transplant_cutoff18', 'D_Age_at_transplant_cutoff16', 'D_Age_at_transplant_cutoff18', 'Relation_and_Recipient_Gender', 'Relation_and_Donor_Gender', 'Relation_and_Recipient_and_Donor_Gender', 'Recepient_Nationality_Geographical', 'Recepient_Nationality_Cultural', 'Recepient_Nationality_Regional_Income', 'Recepient_Nationality_Regional_WHO', 'Hematological Diagnosis_Grouped', 'Hematological Diagnosis_Malignant', 'PreHSCT_MTX', 'First_GVHD_prophylaxis_MTX']
9
 
10
- # def sidebar():
11
- # APP_DIR = Path(__file__).parent
12
- # MODELS_DIR = APP_DIR / "saved_models"
13
-
14
- # # Shared dropdown in the sidebar
15
- # def get_model_options():
16
- # models = ["Default"]
17
- # model_files = glob.glob(str(MODELS_DIR / "*.pkl")) + glob.glob(str(MODELS_DIR / "*.cbm"))
18
-
19
- # for m in model_files:
20
- # models.append(Path(m).stem)
21
- # return sorted(set(models))
22
-
23
- # if 'selected_model' not in st.session_state:
24
- # st.session_state.selected_model = "Default"
25
-
26
- # st.sidebar.title("Model Selection")
27
- # st.session_state.selected_model = st.sidebar.selectbox("Model", get_model_options())
28
-
29
  def sidebar():
30
  def get_model_options():
31
  models = ["Default_ensemble"]
@@ -48,16 +29,10 @@ def sidebar():
48
  except Exception as e:
49
  st.warning(f"Skipping model file due to error: {f} ({e})")
50
 
51
- # todel
52
- print(sorted(set(models)))
53
- st.warning(sorted(set(models)))
54
  return sorted(set(models))
55
 
56
  if 'selected_model' not in st.session_state:
57
  st.session_state.selected_model = "Default_ensemble"
58
 
59
  st.sidebar.title("Model Selection")
60
- st.session_state.selected_model = st.sidebar.selectbox("Model", get_model_options())
61
-
62
- # todel
63
- st.info(f"{st.session_state.selected_model} is chosen!")
 
7
 
8
  st.session_state.orig_train_cols = ['EPI/ID numbers', 'Recipient_gender', 'Recepient_DOB', 'Recepient_Nationality', 'Hematological Diagnosis', 'Date of first diagnosis/BMBx date', 'Recepient_Blood group before HSCT', 'Donor_DOB', 'Donor_gender', 'D_Blood group', 'R_HLA_A', 'R_HLA_B', 'R_HLA_C', 'R_HLA_DR', 'R_HLA_DQ', 'D_HLA_A', 'D_HLA_B', 'D_HLA_C', 'D_HLA_DR', 'D_HLA_DQ', 'Number of lines of Rx before HSCT', 'PreHSCT conditioning regimen+/-ATG+/-TBI', 'HSCT_date', 'Source of cells', 'Donor_relation to recipient', 'HLA match ratio', 'Post HSCT regimen', 'First_GVHD prophylaxis', 'GVHD', 'Acute GVHD(<100 days)', 'Chronic GVHD>100 days', 'Acute+Chronic', 'GVHD severity', 'R_HLA_A1', 'R_HLA_A2', 'R_HLA_B1', 'R_HLA_B2', 'R_HLA_C1', 'R_HLA_C2', 'R_HLA_DR1', 'R_HLA_DR2', 'R_HLA_DQ1', 'R_HLA_DQ2', 'D_HLA_A1', 'D_HLA_A2', 'D_HLA_B1', 'D_HLA_B2', 'D_HLA_C1', 'D_HLA_C2', 'D_HLA_DR1', 'D_HLA_DR2', 'D_HLA_DQ1', 'D_HLA_DQ2', 'R_HLA_A_1', 'R_HLA_A_11', 'R_HLA_A_12', 'R_HLA_A_2', 'R_HLA_A_20', 'R_HLA_A_23', 'R_HLA_A_24', 'R_HLA_A_25', 'R_HLA_A_26', 'R_HLA_A_29', 'R_HLA_A_3', 'R_HLA_A_30', 'R_HLA_A_31', 'R_HLA_A_32', 'R_HLA_A_33', 'R_HLA_A_34', 'R_HLA_A_4', 'R_HLA_A_66', 'R_HLA_A_68', 'R_HLA_A_69', 'R_HLA_A_7', 'R_HLA_A_74', 'R_HLA_A_8', 'R_HLA_A_X', 'R_HLA_B_13', 'R_HLA_B_14', 'R_HLA_B_15', 'R_HLA_B_18', 'R_HLA_B_23', 'R_HLA_B_24', 'R_HLA_B_27', 'R_HLA_B_35', 'R_HLA_B_37', 'R_HLA_B_38', 'R_HLA_B_39', 'R_HLA_B_40', 'R_HLA_B_41', 'R_HLA_B_42', 'R_HLA_B_44', 'R_HLA_B_45', 'R_HLA_B_46', 'R_HLA_B_49', 'R_HLA_B_50', 'R_HLA_B_51', 'R_HLA_B_52', 'R_HLA_B_53', 'R_HLA_B_55', 'R_HLA_B_56', 'R_HLA_B_57', 'R_HLA_B_58', 'R_HLA_B_7', 'R_HLA_B_73', 'R_HLA_B_8', 'R_HLA_B_81', 'R_HLA_B_X', 'R_HLA_C_1', 'R_HLA_C_12', 'R_HLA_C_14', 'R_HLA_C_15', 'R_HLA_C_16', 'R_HLA_C_17', 'R_HLA_C_18', 'R_HLA_C_2', 'R_HLA_C_3', 'R_HLA_C_38', 'R_HLA_C_4', 'R_HLA_C_49', 'R_HLA_C_5', 'R_HLA_C_50', 'R_HLA_C_6', 'R_HLA_C_7', 'R_HLA_C_8', 'R_HLA_C_X', 'R_HLA_DR_1', 'R_HLA_DR_10', 'R_HLA_DR_11', 'R_HLA_DR_12', 
'R_HLA_DR_13', 'R_HLA_DR_14', 'R_HLA_DR_15', 'R_HLA_DR_16', 'R_HLA_DR_17', 'R_HLA_DR_2', 'R_HLA_DR_3', 'R_HLA_DR_4', 'R_HLA_DR_5', 'R_HLA_DR_6', 'R_HLA_DR_7', 'R_HLA_DR_8', 'R_HLA_DR_9', 'R_HLA_DR_X', 'R_HLA_DQ_1', 'R_HLA_DQ_11', 'R_HLA_DQ_15', 'R_HLA_DQ_16', 'R_HLA_DQ_2', 'R_HLA_DQ_3', 'R_HLA_DQ_301', 'R_HLA_DQ_4', 'R_HLA_DQ_5', 'R_HLA_DQ_6', 'R_HLA_DQ_7', 'R_HLA_DQ_X', 'D_HLA_A_1', 'D_HLA_A_11', 'D_HLA_A_12', 'D_HLA_A_2', 'D_HLA_A_23', 'D_HLA_A_24', 'D_HLA_A_25', 'D_HLA_A_26', 'D_HLA_A_29', 'D_HLA_A_3', 'D_HLA_A_30', 'D_HLA_A_31', 'D_HLA_A_32', 'D_HLA_A_33', 'D_HLA_A_34', 'D_HLA_A_66', 'D_HLA_A_68', 'D_HLA_A_69', 'D_HLA_A_7', 'D_HLA_A_74', 'D_HLA_A_8', 'D_HLA_A_X', 'D_HLA_B_13', 'D_HLA_B_14', 'D_HLA_B_15', 'D_HLA_B_17', 'D_HLA_B_18', 'D_HLA_B_23', 'D_HLA_B_24', 'D_HLA_B_27', 'D_HLA_B_35', 'D_HLA_B_37', 'D_HLA_B_38', 'D_HLA_B_39', 'D_HLA_B_40', 'D_HLA_B_41', 'D_HLA_B_42', 'D_HLA_B_44', 'D_HLA_B_45', 'D_HLA_B_48', 'D_HLA_B_49', 'D_HLA_B_50', 'D_HLA_B_51', 'D_HLA_B_52', 'D_HLA_B_53', 'D_HLA_B_55', 'D_HLA_B_56', 'D_HLA_B_57', 'D_HLA_B_58', 'D_HLA_B_7', 'D_HLA_B_73', 'D_HLA_B_8', 'D_HLA_B_81', 'D_HLA_B_X', 'D_HLA_C_1', 'D_HLA_C_12', 'D_HLA_C_14', 'D_HLA_C_15', 'D_HLA_C_16', 'D_HLA_C_17', 'D_HLA_C_18', 'D_HLA_C_2', 'D_HLA_C_3', 'D_HLA_C_38', 'D_HLA_C_4', 'D_HLA_C_49', 'D_HLA_C_5', 'D_HLA_C_50', 'D_HLA_C_6', 'D_HLA_C_7', 'D_HLA_C_8', 'D_HLA_C_X', 'D_HLA_DR_1', 'D_HLA_DR_10', 'D_HLA_DR_11', 'D_HLA_DR_12', 'D_HLA_DR_13', 'D_HLA_DR_14', 'D_HLA_DR_15', 'D_HLA_DR_16', 'D_HLA_DR_17', 'D_HLA_DR_2', 'D_HLA_DR_3', 'D_HLA_DR_4', 'D_HLA_DR_5', 'D_HLA_DR_6', 'D_HLA_DR_7', 'D_HLA_DR_8', 'D_HLA_DR_9', 'D_HLA_DR_X', 'D_HLA_DQ_1', 'D_HLA_DQ_11', 'D_HLA_DQ_15', 'D_HLA_DQ_16', 'D_HLA_DQ_2', 'D_HLA_DQ_3', 'D_HLA_DQ_301', 'D_HLA_DQ_4', 'D_HLA_DQ_5', 'D_HLA_DQ_6', 'D_HLA_DQ_7', 'D_HLA_DQ_X', 'Recepient_DOB_Year', 'Donor_DOB_Year', 'HSCT_date_Year', 'R_Age_at_transplant', 'D_Age_at_transplant', 'Age_Gap_R_D', 'PreHSCT_ALEMTUZUMAB', 'PreHSCT_ATG', 'PreHSCT_BEAM', 'PreHSCT_BUSULFAN', 
'PreHSCT_CAMPATH', 'PreHSCT_CARMUSTINE', 'PreHSCT_CLOFARABINE', 'PreHSCT_CYCLOPHOSPHAMIDE', 'PreHSCT_CYCLOSPORIN', 'PreHSCT_CYTARABINE', 'PreHSCT_ETOPOSIDE', 'PreHSCT_FLUDARABINE', 'PreHSCT_GEMCITABINE', 'PreHSCT_MELPHALAN', 'PreHSCT_METHOTREXATE', 'PreHSCT_OTHER', 'PreHSCT_RANIMUSTINE', 'PreHSCT_REDUCEDCONDITIONING', 'PreHSCT_RITUXIMAB', 'PreHSCT_SIROLIMUS', 'PreHSCT_TBI', 'PreHSCT_THIOTEPA', 'PreHSCT_TREOSULFAN', 'PreHSCT_UA', 'PreHSCT_VORNOSTAT', 'PreHSCT_X', 'First_GVHD_prophylaxis_ABATACEPT', 'First_GVHD_prophylaxis_ALEMTUZUMAB', 'First_GVHD_prophylaxis_ATG', 'First_GVHD_prophylaxis_CYCLOPHOSPHAMIDE', 'First_GVHD_prophylaxis_CYCLOSPORIN', 'First_GVHD_prophylaxis_IMATINIB', 'First_GVHD_prophylaxis_LEFLUNOMIDE', 'First_GVHD_prophylaxis_METHOTREXATE', 'First_GVHD_prophylaxis_MMF', 'First_GVHD_prophylaxis_NONE', 'First_GVHD_prophylaxis_RUXOLITINIB', 'First_GVHD_prophylaxis_SIROLIMUS', 'First_GVHD_prophylaxis_STEROID', 'First_GVHD_prophylaxis_TAC', 'First_GVHD_prophylaxis_TACROLIMUS', 'First_GVHD_prophylaxis_X', 'Recepient_Blood group before HSCT_MergePlusMinus', 'D_Blood group_MergePlusMinus', 'R_Age_at_transplant_cutoff16', 'R_Age_at_transplant_cutoff18', 'D_Age_at_transplant_cutoff16', 'D_Age_at_transplant_cutoff18', 'Relation_and_Recipient_Gender', 'Relation_and_Donor_Gender', 'Relation_and_Recipient_and_Donor_Gender', 'Recepient_Nationality_Geographical', 'Recepient_Nationality_Cultural', 'Recepient_Nationality_Regional_Income', 'Recepient_Nationality_Regional_WHO', 'Hematological Diagnosis_Grouped', 'Hematological Diagnosis_Malignant', 'PreHSCT_MTX', 'First_GVHD_prophylaxis_MTX']
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def sidebar():
11
  def get_model_options():
12
  models = ["Default_ensemble"]
 
29
  except Exception as e:
30
  st.warning(f"Skipping model file due to error: {f} ({e})")
31
 
 
 
 
32
  return sorted(set(models))
33
 
34
  if 'selected_model' not in st.session_state:
35
  st.session_state.selected_model = "Default_ensemble"
36
 
37
  st.sidebar.title("Model Selection")
38
+ st.session_state.selected_model = st.sidebar.selectbox("Model", get_model_options())