mfarnas commited on
Commit
8ea1e26
·
1 Parent(s): 5cd3a8b

initial commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ src/saved_models/*.pkl filter=lfs diff=lfs merge=lfs -text
requirements.txt CHANGED
@@ -1,3 +1,12 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
 
1
+ catboost==1.2.8
2
+ huggingface_hub==0.33.2
3
+ numpy==1.26.4
4
+ pandas==2.3.0
5
+ pyarrow==16.1.0
6
+ PyYAML==6.0.2
7
+ scikit_learn==1.5.1
8
+ streamlit==1.46.1
9
+
10
+ # altair
11
+ # pandas
12
+ # streamlit
src/GVHD_Predictions_App.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ st.set_page_config(page_title="GVHD Predictions", layout="wide")
4
+ st.title("GVHD Predictions App")
src/inference_utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import streamlit as st
3
+ from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score, brier_score_loss, log_loss
4
+
5
+ def compute_metrics(y_true, y_pred_proba, threshold=0.5):
6
+ y_pred = (y_pred_proba >= threshold).astype(int)
7
+ return {
8
+ "AUC": roc_auc_score(y_true, y_pred_proba),
9
+ "F1": f1_score(y_true, y_pred),
10
+ "Accuracy": accuracy_score(y_true, y_pred),
11
+ "Precision": precision_score(y_true, y_pred),
12
+ "Recall": recall_score(y_true, y_pred),
13
+ "BrierScore": brier_score_loss(y_true, y_pred_proba),
14
+ "Logloss": log_loss(y_true, y_pred_proba),
15
+ }
16
+
17
+ def add_predictions(df, probs):
18
+ df['Predicted Probability'] = probs
19
+ df['GVHD Prediction'] = ['POSITIVE' if p > 0.5 else 'NEGATIVE' for p in probs]
20
+
21
+ df_with_gt = df[['Predicted Probability', 'GVHD Prediction']].join(st.session_state.targets_df)
22
+
23
+ # Define cell-level styling
24
+ def highlight_prediction(val):
25
+ if val == "POSITIVE":
26
+ return "background-color: #d4edda; color: #155724; text-align: center;"
27
+ elif val == "NEGATIVE":
28
+ return "background-color: #f8d7da; color: #721c24; text-align: center;"
29
+ return "text-align: center;"
30
+
31
+ # Apply color and alignment
32
+ df_styled = (
33
+ df_with_gt.style
34
+ .applymap(highlight_prediction, subset=["GVHD Prediction"])
35
+ .set_properties(**{'text-align': 'center'}) # Apply center alignment to all cells
36
+ )
37
+
38
+ return df_styled
src/model_utils.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pickle
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+
6
+ from catboost import CatBoostClassifier
7
+ # from xgboost import XGBClassifier
8
+ # from lightgbm import LGBMClassifier
9
+ from sklearn.ensemble import RandomForestClassifier
10
+
11
+ import json
12
+ import uuid
13
+ import io
14
+ from datetime import datetime
15
+ from typing import Any, Dict, Optional
16
+ import pickle
17
+ import pyarrow as pa
18
+ import pyarrow.parquet as pq
19
+ from huggingface_hub import CommitScheduler
20
+
21
+ MODEL_DIR = Path("saved_models")
22
+ MODEL_DIR.mkdir(exist_ok=True)
23
+
24
+ import yaml
25
+
26
+ def load_model_params(model_type, mode="ensemble", path=Path("params") / "model_params.yaml"):
27
+ if mode not in ["ensemble", "single_model"]:
28
+ raise ValueError("mode must be either 'ensemble' or 'single_model'")
29
+
30
+ if model_type not in ["CatBoost", "XGBoost", "LightGBM", "RandomForest"]:
31
+ raise ValueError("model_type must be one of 'CatBoost', 'XGBoost', 'LightGBM', or 'RandomForest'")
32
+
33
+ with open(path, "r") as f:
34
+ all_params = yaml.safe_load(f)
35
+
36
+ params = all_params[model_type][mode]
37
+ if "random_seed" in params:
38
+ st.session_state.random_seed = params["random_seed"]
39
+
40
+ return params
41
+
42
+ def get_model(model_type, mode="ensemble", best_iter=None):
43
+ params = load_model_params(model_type, mode)
44
+
45
+ # iter is set for single_model mode, where
46
+ if best_iter is not None:
47
+ params['iterations'] = best_iter
48
+ # if "random_seed" in st.session_state:
49
+ # random_seed = st.session_state.random_seed
50
+
51
+ if model_type == "CatBoost":
52
+ return CatBoostClassifier(**params)
53
+ # elif model_type == "XGBoost":
54
+ # return XGBClassifier(**params, use_label_encoder=False, eval_metric="logloss")
55
+ # elif model_type == "LightGBM":
56
+ # return LGBMClassifier(**params)
57
+ elif model_type == "RandomForest":
58
+ return RandomForestClassifier(**params)
59
+ else:
60
+ raise ValueError(f"Unsupported model type: {model_type}")
61
+
62
+ # def save_model(model, user_model_name, metrics_result_single=None):
63
+ # timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
64
+ # filename = f"{timestamp}_{user_model_name}_single.pkl"
65
+ # filepath = MODEL_DIR / filename
66
+
67
+ # single_model_data = {
68
+ # "timestamp": timestamp,
69
+ # "model_name": user_model_name,
70
+ # "target_col": st.session_state.target_col if "target_col" in st.session_state else "UNKNOWN",
71
+ # "model": model,
72
+ # "best_iteration": st.session_state.best_iteration,
73
+ # "metrics_result_single": metrics_result_single
74
+ # }
75
+
76
+ # with open(filepath, "wb") as f:
77
+ # pickle.dump(single_model_data, f)
78
+ # return filename
79
+
80
+ def save_model(model, user_model_name, metrics_result_single=None):
81
+ from datetime import datetime
82
+ import io
83
+ import uuid
84
+ import pickle
85
+ import json
86
+ import pyarrow as pa
87
+ import pyarrow.parquet as pq
88
+ from huggingface_hub import CommitScheduler
89
+
90
+ timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
91
+ filename = f"{timestamp}_{user_model_name}_single.pkl"
92
+
93
+ # Prepare model dict (same as before)
94
+ model_data = {
95
+ "timestamp": timestamp,
96
+ "model_name": user_model_name,
97
+ "target_col": st.session_state.get("target_col", "UNKNOWN"),
98
+ "model": model,
99
+ "best_iteration": st.session_state.get("best_iteration"),
100
+ "metrics_result_single": metrics_result_single,
101
+ }
102
+
103
+ # Serialize (pickle) to bytes
104
+ model_bytes = pickle.dumps(model_data)
105
+
106
+ # Prepare Parquet row
107
+ row = {
108
+ "filename": filename,
109
+ "timestamp": timestamp,
110
+ "type": "single",
111
+ "model_file": {"path": filename, "bytes": model_bytes},
112
+ }
113
+
114
+ table = pa.Table.from_pylist([row])
115
+ table = table.replace_schema_metadata({
116
+ "huggingface": json.dumps({"info": {
117
+ "features": {
118
+ "filename": {"_type": "Value", "dtype": "string"},
119
+ "timestamp": {"_type": "Value", "dtype": "string"},
120
+ "type": {"_type": "Value", "dtype": "string"},
121
+ "model_file": {"_type": "Value", "dtype": "binary"},
122
+ }
123
+ }})
124
+ })
125
+
126
+ # Write to in-memory buffer
127
+ buf = io.BytesIO()
128
+ pq.write_table(table, buf)
129
+ buf.seek(0)
130
+
131
+ # Upload to HF dataset
132
+ scheduler = CommitScheduler(
133
+ repo_id=st.secrets["HF_REPO_ID"],
134
+ repo_type="dataset",
135
+ path_in_repo="models",
136
+ token=st.secrets["HF_TOKEN"],
137
+ private=True,
138
+ folder_path="dummy"
139
+ )
140
+ scheduler.api.upload_file(
141
+ repo_id=st.secrets["HF_REPO_ID"],
142
+ repo_type="dataset",
143
+ path_in_repo=f"models/{uuid.uuid4()}.parquet",
144
+ path_or_fileobj=buf
145
+ )
146
+
147
+ return filename
148
+
149
+ # def save_model_ensemble(models, user_model_name, best_iterations=None, fold_scores=None, metrics_result_ensemble=None):
150
+ # timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
151
+ # filename = f"{timestamp}_{user_model_name}_ensemble.pkl"
152
+ # filepath = MODEL_DIR / filename
153
+
154
+ # ensemble_data = {
155
+ # "timestamp": timestamp,
156
+ # "model_name": user_model_name,
157
+ # "target_col": st.session_state.target_col if "target_col" in st.session_state else "UNKNOWN",
158
+ # "models": models,
159
+ # "best_iterations": best_iterations,
160
+ # "fold_scores": fold_scores,
161
+ # "metrics_result_ensemble": metrics_result_ensemble
162
+ # }
163
+
164
+ # with open(filepath, "wb") as f:
165
+ # pickle.dump(ensemble_data, f)
166
+ # return filename
167
+
168
+ def save_model_ensemble(models, user_model_name, best_iterations=None, fold_scores=None, metrics_result_ensemble=None):
169
+ from datetime import datetime
170
+ import io
171
+ import uuid
172
+ import pickle
173
+ import json
174
+ import pyarrow as pa
175
+ import pyarrow.parquet as pq
176
+ from huggingface_hub import CommitScheduler
177
+
178
+ timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
179
+ filename = f"{timestamp}_{user_model_name}_ensemble.pkl"
180
+
181
+ ensemble_data = {
182
+ "timestamp": timestamp,
183
+ "model_name": user_model_name,
184
+ "target_col": st.session_state.get("target_col", "UNKNOWN"),
185
+ "models": models,
186
+ "best_iterations": best_iterations,
187
+ "fold_scores": fold_scores,
188
+ "metrics_result_ensemble": metrics_result_ensemble,
189
+ }
190
+
191
+ model_bytes = pickle.dumps(ensemble_data)
192
+
193
+ row = {
194
+ "filename": filename,
195
+ "timestamp": timestamp,
196
+ "type": "ensemble",
197
+ "model_file": {"path": filename, "bytes": model_bytes},
198
+ }
199
+
200
+ table = pa.Table.from_pylist([row])
201
+ table = table.replace_schema_metadata({
202
+ "huggingface": json.dumps({"info": {
203
+ "features": {
204
+ "filename": {"_type": "Value", "dtype": "string"},
205
+ "timestamp": {"_type": "Value", "dtype": "string"},
206
+ "type": {"_type": "Value", "dtype": "string"},
207
+ "model_file": {"_type": "Value", "dtype": "binary"},
208
+ }
209
+ }})
210
+ })
211
+
212
+ buf = io.BytesIO()
213
+ pq.write_table(table, buf)
214
+ buf.seek(0)
215
+
216
+ scheduler = CommitScheduler(
217
+ repo_id=st.secrets["HF_REPO_ID"],
218
+ repo_type="dataset",
219
+ path_in_repo="models",
220
+ token=st.secrets["HF_TOKEN"],
221
+ private=True,
222
+ folder_path="dummy"
223
+ )
224
+ scheduler.api.upload_file(
225
+ repo_id=st.secrets["HF_REPO_ID"],
226
+ repo_type="dataset",
227
+ path_in_repo=f"models/{uuid.uuid4()}.parquet",
228
+ path_or_fileobj=buf
229
+ )
230
+
231
+ return filename
232
+
233
+
234
+ # def load_model(model_name):
235
+ # filepath = MODEL_DIR / f"{model_name}.pkl"
236
+ # if not filepath.exists():
237
+ # raise FileNotFoundError(f"Model file not found: {filepath}")
238
+
239
+ # with open(filepath, "rb") as f:
240
+ # single_model_data = pickle.load(f)
241
+
242
+ # return single_model_data
243
+
244
+ def load_model(model_name):
245
+ from huggingface_hub import hf_hub_download
246
+ import pyarrow.parquet as pq
247
+ import pickle
248
+
249
+ files = hf_hub_download(
250
+ repo_id=st.secrets["HF_REPO_ID"],
251
+ repo_type="dataset",
252
+ token=st.secrets["HF_TOKEN"],
253
+ filename=None, # Get whole repo listing
254
+ cache_dir=None,
255
+ local_dir=None,
256
+ local_dir_use_symlinks=False,
257
+ force_download=False,
258
+ resume_download=True
259
+ )
260
+
261
+ from huggingface_hub import HfApi
262
+ api = HfApi(token=st.secrets["HF_TOKEN"])
263
+ all_files = api.list_repo_files(repo_id=st.secrets["HF_REPO_ID"], repo_type="dataset")
264
+ model_files = [f for f in all_files if f.startswith("models/") and f.endswith(".parquet")]
265
+
266
+ # Find matching filename
267
+ target_file = None
268
+ for f in model_files:
269
+ downloaded = hf_hub_download(
270
+ repo_id=st.secrets["HF_REPO_ID"],
271
+ repo_type="dataset",
272
+ filename=f,
273
+ token=st.secrets["HF_TOKEN"]
274
+ )
275
+ table = pq.read_table(downloaded)
276
+ row = table.to_pylist()[0]
277
+ if row["filename"] == model_name:
278
+ target_file = downloaded
279
+ break
280
+
281
+ if not target_file:
282
+ raise FileNotFoundError(f"Model {model_name} not found in repo.")
283
+
284
+ model_bytes = row["model_file"]["bytes"]
285
+ return pickle.loads(model_bytes)
286
+
287
+
288
+ # def load_model_ensemble(filename):
289
+ # filepath = MODEL_DIR / f"{filename}.pkl"
290
+ # if not filepath.exists():
291
+ # raise FileNotFoundError(f"Model file not found: {filepath}")
292
+
293
+ # with open(filepath, "rb") as f:
294
+ # ensemble_data = pickle.load(f)
295
+
296
+ # return ensemble_data
297
+
298
+ def load_model_ensemble(filename):
299
+ return load_model(filename)
300
+
301
+
302
+ def ensemble_predict(models, X, cat_features):
303
+ preds = sum([model.predict_proba(X)[:, 1] for model in models]) / len(models)
304
+ return preds
src/model_utils_ori.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pickle
3
+ import catboost
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+
7
+ from catboost import CatBoostClassifier
8
+ # from xgboost import XGBClassifier
9
+ # from lightgbm import LGBMClassifier
10
+ from sklearn.ensemble import RandomForestClassifier
11
+
12
+ MODEL_DIR = Path("saved_models")
13
+ MODEL_DIR.mkdir(exist_ok=True)
14
+
15
+ import yaml
16
+
17
+ def load_model_params(model_type, mode="ensemble", path=Path("params") / "model_params.yaml"):
18
+ if mode not in ["ensemble", "single_model"]:
19
+ raise ValueError("mode must be either 'ensemble' or 'single_model'")
20
+
21
+ if model_type not in ["CatBoost", "XGBoost", "LightGBM", "RandomForest"]:
22
+ raise ValueError("model_type must be one of 'CatBoost', 'XGBoost', 'LightGBM', or 'RandomForest'")
23
+
24
+ with open(path, "r") as f:
25
+ all_params = yaml.safe_load(f)
26
+
27
+ params = all_params[model_type][mode]
28
+ if "random_seed" in params:
29
+ st.session_state.random_seed = params["random_seed"]
30
+
31
+ return params
32
+
33
+ def get_model(model_type, mode="ensemble", best_iter=None):
34
+ params = load_model_params(model_type, mode)
35
+
36
+ # iter is set for single_model mode, where
37
+ if best_iter is not None:
38
+ params['iterations'] = best_iter
39
+ # if "random_seed" in st.session_state:
40
+ # random_seed = st.session_state.random_seed
41
+
42
+ if model_type == "CatBoost":
43
+ return CatBoostClassifier(**params)
44
+ # elif model_type == "XGBoost":
45
+ # return XGBClassifier(**params, use_label_encoder=False, eval_metric="logloss")
46
+ # elif model_type == "LightGBM":
47
+ # return LGBMClassifier(**params)
48
+ elif model_type == "RandomForest":
49
+ return RandomForestClassifier(**params)
50
+ else:
51
+ raise ValueError(f"Unsupported model type: {model_type}")
52
+
53
+ def save_model(model, user_model_name, metrics_result_single=None):
54
+ timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
55
+ filename = f"{timestamp}_{user_model_name}_single.pkl"
56
+ filepath = MODEL_DIR / filename
57
+
58
+ single_model_data = {
59
+ "timestamp": timestamp,
60
+ "model_name": user_model_name,
61
+ "target_col": st.session_state.target_col if "target_col" in st.session_state else "UNKNOWN",
62
+ "model": model,
63
+ "best_iteration": st.session_state.best_iteration,
64
+ "metrics_result_single": metrics_result_single
65
+ }
66
+
67
+ with open(filepath, "wb") as f:
68
+ pickle.dump(single_model_data, f)
69
+ return filename
70
+
71
+ def load_model(model_name):
72
+ filepath = MODEL_DIR / f"{model_name}.pkl"
73
+ if not filepath.exists():
74
+ raise FileNotFoundError(f"Model file not found: {filepath}")
75
+
76
+ with open(filepath, "rb") as f:
77
+ single_model_data = pickle.load(f)
78
+
79
+ return single_model_data
80
+
81
+ def save_model_ensemble(models, user_model_name, best_iterations=None, fold_scores=None, metrics_result_ensemble=None):
82
+ timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
83
+ filename = f"{timestamp}_{user_model_name}_ensemble.pkl"
84
+ filepath = MODEL_DIR / filename
85
+
86
+ ensemble_data = {
87
+ "timestamp": timestamp,
88
+ "model_name": user_model_name,
89
+ "target_col": st.session_state.target_col if "target_col" in st.session_state else "UNKNOWN",
90
+ "models": models,
91
+ "best_iterations": best_iterations,
92
+ "fold_scores": fold_scores,
93
+ "metrics_result_ensemble": metrics_result_ensemble
94
+ }
95
+
96
+ with open(filepath, "wb") as f:
97
+ pickle.dump(ensemble_data, f)
98
+ return filename
99
+
100
+
101
+ def load_model_ensemble(filename):
102
+ filepath = MODEL_DIR / f"{filename}.pkl"
103
+ if not filepath.exists():
104
+ raise FileNotFoundError(f"Model file not found: {filepath}")
105
+
106
+ with open(filepath, "rb") as f:
107
+ ensemble_data = pickle.load(f)
108
+
109
+ return ensemble_data
110
+
111
+
112
+ def ensemble_predict(models, X, cat_features):
113
+ preds = sum([model.predict_proba(X)[:, 1] for model in models]) / len(models)
114
+ return preds
src/pages/1_Individual_Predictions.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from model_utils import load_model, load_model_ensemble, ensemble_predict
4
+ from preprocess_utils import load_train_features
5
+ from preprocess_utils import preprocess_pipeline as preprocess
6
+ from inference_utils import add_predictions
7
+ from sidebar import sidebar
8
+
9
+ # Initialize sidebar
10
+ sidebar()
11
+
12
+ st.title("👤 Individual Patient Prediction")
13
+
14
+ with st.form("individual_form"):
15
+ st.subheader("Recipient Information", divider=True)
16
+ gender = st.radio("Recipient Gender", ['MALE', 'FEMALE'], index=None)
17
+ dob = st.date_input("Recipient DOB", value="2000-01-31", format="DD/MM/YYYY")
18
+ nationality = st.selectbox("Recipient Nationality", sorted([
19
+ 'EMIRATI', 'EGYPTIAN', 'BANGLADESHI', 'AFGHAN', 'SYRIAN', 'INDIAN', 'PAKISTANI',
20
+ 'YEMENI', 'JORDANIAN', 'OMANI', 'FILIPINO', 'SUDANESE', 'MOROCCAN',
21
+ 'PALESTINIAN', 'ETHIOPIAN', 'AMERICAN', 'ALGERIAN', 'INDONESIAN', 'LEBANESE',
22
+ 'SAUDI', 'SRI LANKAN', 'SOMALI', 'FIJI', 'NEW ZEALANDER', 'COMORAN',
23
+ 'MAURITANIA', 'KUWAIT', 'BRITISH', 'UZBEKISTANI', 'ERITREAN', 'IRAQI'
24
+ ]), index=None)
25
+ diagnosis = st.selectbox("Hematological Diagnosis", sorted([
26
+ 'ACUTE MYELOID LEUKEMIA', 'ALPHA THALSSEMIA', 'AMYLOIDOSIS', 'APLASTIC ANEMIA', 'BALL',
27
+ 'BETA THALESSEMIA', 'BLASTIC PLASMACYTOID DENDRITRIC CELL NEOPLASM',
28
+ 'CHRONIC GRANULOMATOUS DISEASE', 'CHRONIC LYMPHOCYTIC LEUKEMIA', 'CML',
29
+ 'COMBINED VARIABLE IMMUNODEFICIENCY', 'DYSKERATOSIS CONGENTIA', 'FANCONI ANEMIA',
30
+ 'GLANZMANN THROMBASTHENIA', 'HEMOPHAGOCYTIC LYMPHOHISTIOCYTOSIS (HLH)',
31
+ 'HEREDITARY SPHEROCYTOSIS', 'HODGKIN LYMPHOMA', 'HYPOGAMMAGLOBULINEMIA',
32
+ 'LANGERHANS CELL HISTIOCYTOSIS', 'MYELODYSPLASTIC SYNDROME', 'MEDULLOBLASTOMA',
33
+ 'MULTIPLE MYELOMA', 'MYELOFIBROSIS', 'MYELOPROLIFERATIVE DISORDER',
34
+ 'NEUROBLASTOMA', 'NON HODGKIN LYMPHOMA', 'OTHER', 'PAROXYSMAL NOCTURNAL HEMOGLOBINURIA',
35
+ 'PLASMA CELL LEUKEMIA', 'SCID', 'SICKLE CELL DISEASE', 'TALL', 'X-LINKED HYPER IGM SYNDROME'
36
+ ]), index=None)
37
+ diagnosis_date = st.date_input("Date of First Diagnosis / BMBx", value="2000-01-31", format="DD/MM/YYYY")
38
+
39
+ recipient_blood_group = st.radio("Recipient Blood Group", ['A+', 'A-', 'B+', 'B-', 'O+', 'O-', 'AB+', 'AB-', 'Unknown'], key="recipient_blood_group", index=None)
40
+
41
+ st.markdown("###### Recipient HLA Alleles")
42
+ r_hla_a = st.multiselect("R_HLA_A", options=['Unknown', 'SELF'], max_selections=2, accept_new_options=True)
43
+ r_hla_b = st.multiselect("R_HLA_B", options=['Unknown', 'SELF'], max_selections=2, accept_new_options=True)
44
+ r_hla_c = st.multiselect("R_HLA_C", options=['Unknown', 'SELF'], max_selections=2, accept_new_options=True)
45
+ r_hla_dr = st.multiselect("R_HLA_DR", options=['Unknown', 'SELF'], max_selections=2, accept_new_options=True)
46
+ r_hla_dq = st.multiselect("R_HLA_DQ", options=['Unknown', 'SELF'], max_selections=2, accept_new_options=True)
47
+
48
+ st.subheader("Donor Information", divider=True)
49
+ donor_relation = st.radio("Donor Relation to Recipient", [
50
+ 'SELF', 'SIBLING', 'FIRST DEGREE RELATIVE', 'SECOND DEGREE RELATIVE', 'RELATED', 'UNRELATED', 'Unknown',
51
+ ], index=None)
52
+
53
+ if donor_relation == 'SELF':
54
+ # If the donor is the recipient, set the donor
55
+ st.session_state.SELF = True
56
+ else:
57
+ st.session_state.SELF = False
58
+
59
+ donor_gender = st.radio("Donor Gender", ['MALE', 'FEMALE'], index=None)
60
+
61
+ donor_dob = st.date_input("Donor DOB", value="2000-01-31", format="DD/MM/YYYY")
62
+
63
+ donor_blood_group = st.radio("Donor Blood Group", ['A+', 'A-', 'B+', 'B-', 'O+', 'O-', 'AB+', 'AB-', 'Unknown'], key="donor_blood_group", index=None)
64
+
65
+ st.markdown("###### Donor HLA Alleles")
66
+ d_hla_a = st.multiselect("D_HLA_A", options=['Unknown', 'SELF'], max_selections=2, accept_new_options=True)
67
+ d_hla_b = st.multiselect("D_HLA_B", options=['Unknown', 'SELF'], max_selections=2, accept_new_options=True)
68
+ d_hla_c = st.multiselect("D_HLA_C", options=['Unknown', 'SELF'], max_selections=2, accept_new_options=True)
69
+ d_hla_dr = st.multiselect("D_HLA_DR", options=['Unknown', 'SELF'], max_selections=2, accept_new_options=True)
70
+ d_hla_dq = st.multiselect("D_HLA_DQ", options=['Unknown', 'SELF'], max_selections=2, accept_new_options=True)
71
+
72
+ st.subheader("Treatment Details", divider=True)
73
+ lines_of_rx = st.selectbox("Number of Lines of Rx Before HSCT", [0, 1, 2, 3, 4, 5, 6, 7, 'Unknown'], index=None)
74
+ conditioning = st.multiselect("Pre-HSCT Conditioning Regimen", sorted([
75
+ 'ALEMTUZUMAB', 'ATG', 'BEAM', 'BUSULFAN', 'CAMPATH', 'CARMUSTINE', 'CLOFARABINE',
76
+ 'CYCLOPHOSPHAMIDE', 'CYCLOSPORIN', 'CYTARABINE', 'ETOPOSIDE', 'FLUDARABINE',
77
+ 'GEMCITABINE', 'MELPHALAN', 'MTX', 'OTHER', 'RANIMUSTINE', 'REDUCEDCONDITIONING',
78
+ 'RITUXIMAB', 'SIROLIMUS', 'TBI', 'THIOTEPA', 'TREOSULFAN', 'UA', 'VORNOSTAT'
79
+ ]), placeholder="Choose an option(s)")
80
+
81
+ st.subheader("HSCT Details", divider=True)
82
+ hsct_date = st.date_input("HSCT Date", value="2000-01-31", format="DD/MM/YYYY")
83
+ cell_source = st.radio("Source of Cells", sorted(['BONE MARROW', 'PERIPHERAL BLOOD', 'UMBILICAL CORD', 'PBSC', 'Unknown']), index=None)
84
+ hla_match = st.radio("HLA Match Ratio", ['FULL', 'PARTIAL', 'HAPLOIDENTICAL', 'Unknown'], index=None)
85
+
86
+ st.subheader("Post-HSCT Treatment and GVHD Prophylaxis", divider=True)
87
+ post_hsct_regimen = st.radio("Post-HSCT Regimen", ['YES', 'NO', 'IVIG', 'Unknown'], index=None)
88
+
89
+ gvhd_prophylaxis = st.multiselect("First GVHD Prophylaxis", [
90
+ 'NONE'] + sorted(['ABATACEPT', 'ALEMTUZUMAB', 'ATG', 'CYCLOPHOSPHAMIDE', #'CYCLOSPOPRIN', 'CYCLOSPRIN',
91
+ 'CYCLOSPORIN', 'IMATINIB', 'LEFLUNOMIDE', 'MMF', 'MTX',
92
+ 'RUXOLITINIB', 'SIROLIMUS', 'STEROID', 'TAC'
93
+ ]), placeholder="Choose an option(s)")
94
+
95
+ submitted = st.form_submit_button("PREDICT", type="primary")
96
+
97
+ if submitted:
98
+ # single model
99
+ model = load_model(st.session_state.selected_model)
100
+
101
+ # Collect input values in a dict
102
+ input_dict = {
103
+ "Recipient_gender": gender,
104
+ "Recepient_DOB": dob.strftime("%d/%m/%Y"),
105
+ "Recepient_Nationality": nationality,
106
+ "Hematological Diagnosis": diagnosis,
107
+ "Date of first diagnosis/BMBx date": diagnosis_date.strftime("%d/%m/%Y"),
108
+ "Recepient_Blood group before HSCT": recipient_blood_group if recipient_blood_group != "Unknown" else "X",
109
+ "Donor_DOB": donor_dob.strftime("%d/%m/%Y"),
110
+ "Donor_gender": donor_gender,
111
+ "D_Blood group": donor_blood_group if donor_blood_group != "Unknown" else "X",
112
+ "R_HLA_A": r_hla_a,
113
+ "R_HLA _B": r_hla_b,
114
+ "R_HLA _C": r_hla_c,
115
+ "R_HLA _DR": r_hla_dr,
116
+ "R_HLA _DQ": r_hla_dq,
117
+ "D_HLA_A": d_hla_a,
118
+ "D_HLA _B": d_hla_b,
119
+ "D_HLA_C": d_hla_c,
120
+ "D_HLA_DR": d_hla_dr,
121
+ "D_HLA _DQ": d_hla_dq,
122
+ "Number of lines of Rx before HSCT": lines_of_rx,
123
+ "PreHSCT conditioning regimen+/-ATG+/-TBI": conditioning,
124
+ "HSCT_date": hsct_date.strftime("%d/%m/%Y"),
125
+ "Source of cells": cell_source,
126
+ "Donor_relation to recipient": donor_relation,
127
+ "HLA match ratio": hla_match,
128
+ "Post HSCT regimen": post_hsct_regimen,
129
+ "First_GVHD prophylaxis": gvhd_prophylaxis
130
+ }
131
+
132
+ # You will need to transform these values into proper numeric or encoded inputs for your model
133
+ X = pd.DataFrame([input_dict]) # Placeholder
134
+ st.dataframe(X, use_container_width=True)
135
+ X.to_csv("/home/muhammadridzuan/2025_GVHD/GVHD_App/saved_models/test_individual_input2.csv", index=False)
136
+
137
+ # Define features
138
+ train_features, cat_features = load_train_features()
139
+ X = preprocess(X)
140
+ X = X[train_features]
141
+ st.write("Processed Input Data:")
142
+ st.dataframe(X, use_container_width=True)
143
+
144
+ if st.session_state.SELF:
145
+ prob = 0.0
146
+ else:
147
+ prob = model.predict_proba(X)[0][1]
148
+
149
+ result_df = pd.DataFrame()
150
+ result_df = add_predictions(result_df, [prob])
151
+
152
+ st.write("Predictions:")
153
+ st.dataframe(result_df, use_container_width=False, width=300)
src/pages/2_Bulk_Predictions.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from model_utils import load_model, load_model_ensemble, ensemble_predict
4
+ from preprocess_utils import load_train_features
5
+ from preprocess_utils import preprocess_pipeline as preprocess
6
+ from inference_utils import add_predictions, compute_metrics
7
+ from sidebar import sidebar
8
+
9
+ # Initialize sidebar
10
+ sidebar()
11
+
12
+ st.title("📊 Bulk Patient Predictions")
13
+
14
+ uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
15
+ if uploaded_file:
16
+ df = pd.read_csv(uploaded_file, header=1)
17
+ st.write("Raw Data:")
18
+ st.dataframe(df)
19
+
20
+ if st.button("Preprocess"):
21
+ df_proc = preprocess(df)
22
+ edited_df = st.data_editor(df_proc, num_rows="dynamic")
23
+ st.session_state.bulk_input_df = edited_df
24
+
25
+ if st.button("Predict"):
26
+ if "bulk_input_df" not in st.session_state:
27
+ st.warning("Please preprocess data first.")
28
+ else:
29
+ if "ensemble" in st.session_state.selected_model:
30
+ # ensemble model
31
+ ensemble = True
32
+ try:
33
+ ensemble_data = load_model_ensemble(st.session_state.selected_model)
34
+ st.session_state.trained_models = ensemble_data["models"]
35
+ models = ensemble_data["models"]
36
+ st.session_state.best_iterations = ensemble_data.get("best_iterations", [])
37
+ st.session_state.fold_scores = ensemble_data.get("fold_scores", [])
38
+
39
+ except Exception as e:
40
+ st.error(f"Error loading ensemble: {str(e)}")
41
+ else:
42
+ # single model
43
+ ensemble = False
44
+ model_dict = load_model(st.session_state.selected_model)
45
+ model = model_dict["model"]
46
+
47
+ df = st.session_state.bulk_input_df
48
+
49
+ # Define the target column (customize this based on your use case)
50
+ target_col = "GVHD" # or "Acute GVHD(<100 days)", etc.
51
+
52
+ # Optional filtering depending on target choice
53
+ if target_col in ["Acute GVHD(<100 days)", "Chronic GVHD>100 days"]:
54
+ df = df[df[target_col] != 3]
55
+
56
+ y = df[target_col]
57
+
58
+ # Define features
59
+ train_features, cat_features = load_train_features()
60
+
61
+ X = df[train_features]
62
+
63
+ # Convert categorical columns to strings (CatBoost handles them)
64
+ for col in cat_features:
65
+ X[col] = X[col].astype(str)
66
+
67
+ # ensemble model prediction
68
+ if ensemble:
69
+ preds = ensemble_predict(models, X, cat_features)
70
+ metrics_result_ensemble = compute_metrics(y, preds)
71
+ else:
72
+ # single model prediction
73
+ preds = model.predict_proba(X)[:, 1]
74
+ metrics_result_single = compute_metrics(y, preds)
75
+
76
+ st.session_state.targets_df = y
77
+ styled = add_predictions(X.copy(), preds)
78
+ st.write("Predictions:")
79
+ st.dataframe(styled, use_container_width=False, width=300)
80
+
81
+ if not ensemble:
82
+ st.write("Single Model Predictions:")
83
+ for metric, value in metrics_result_single.items():
84
+ st.write(f" **{metric}**: {value:.3f}")
85
+ else:
86
+ st.write("Ensemble Predictions:")
87
+ for metric, value in metrics_result_ensemble.items():
88
+ st.write(f" **{metric}**: {value:.3f}")
89
+
90
+ # Find difference in columns between uploaded data and training features
91
+ missing_features = set(st.session_state.orig_train_cols).union(train_features) - set(df.columns)
92
+ missing_features = set([i if i[-2:] != "_X" else '' for i in missing_features])
93
+ missing_features = sorted(list(missing_features))
94
+
95
+ new_features = set(df.columns) - set(st.session_state.orig_train_cols).union(train_features)
96
+ new_features = sorted(list(new_features))
97
+ if missing_features:
98
+ st.warning(f"**Missing features in uploaded data:** \n{''' \n'''.join(missing_features)}")
99
+
100
+ if new_features:
101
+ st.warning(f"**New features in uploaded data not in training set:** \n{''' \n'''.join(new_features)}")
src/pages/3_Preprocessing_and_Training.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ from catboost import CatBoostClassifier, cv, Pool
5
+ from model_utils import get_model, save_model, save_model_ensemble, ensemble_predict
6
+ from preprocess_utils import load_train_features
7
+ from preprocess_utils import preprocess_pipeline as preprocess
8
+ from inference_utils import compute_metrics
9
+ from sidebar import sidebar
10
+ from sklearn.model_selection import StratifiedKFold
11
+
12
+ # Initialize sidebar
13
+ sidebar()
14
+
15
+ st.title("🧪 Preprocessing & Training")
16
+
17
+ uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
18
+ if uploaded_file:
19
+ df = pd.read_csv(uploaded_file, header=1)
20
+ st.write("Raw Data:")
21
+ st.dataframe(df)
22
+
23
+ st.session_state.target_col = st.selectbox(
24
+ "Select target column to predict:",
25
+ options=[
26
+ "GVHD",
27
+ "Acute GVHD(<100 days)",
28
+ "Chronic GVHD>100 days",
29
+ ],
30
+ index=0
31
+ )
32
+
33
+ if st.button("Preprocess"):
34
+ df_proc = preprocess(df)
35
+
36
+ # TODO: Remove. Temp
37
+ st.session_state.orig_train_cols = df_proc.columns.tolist()
38
+
39
+ edited_df = st.data_editor(df_proc, num_rows="dynamic")
40
+ st.session_state.edited_df = edited_df
41
+
42
+
43
+ if st.button("Re-train"):
44
+ if "edited_df" not in st.session_state:
45
+ st.warning("Please preprocess and edit data first.")
46
+ else:
47
+ # Model selection
48
+ model_type = "CatBoost" # Fixed to CatBoost for now
49
+
50
+ df = st.session_state.edited_df.copy()
51
+ target_col = st.session_state.target_col
52
+
53
+ if target_col in ["Acute GVHD(<100 days)", "Chronic GVHD>100 days"]:
54
+ df = df[df[target_col] != 3]
55
+
56
+ y = df[target_col]
57
+ st.write(df[target_col].value_counts())
58
+ train_features, cat_features = load_train_features()
59
+ X = df[train_features]
60
+
61
+ for col in cat_features:
62
+ X[col] = X[col].astype(str)
63
+
64
+ st.info("Running 5-Fold cross-validation with model saving...")
65
+ skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
66
+
67
+ fold_models = []
68
+ fold_scores = []
69
+ best_iterations = []
70
+
71
+ for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), start=1):
72
+ st.write(f"Training Fold {fold}...")
73
+
74
+ X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
75
+ y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
76
+
77
+ train_pool = Pool(X_train, y_train, cat_features=cat_features)
78
+ val_pool = Pool(X_val, y_val, cat_features=cat_features)
79
+
80
+ model = get_model(model_type, mode="ensemble")
81
+
82
+ if model_type == "CatBoost":
83
+ model.fit(
84
+ X_train, y_train,
85
+ eval_set=(X_val, y_val),
86
+ cat_features=cat_features,
87
+ use_best_model=True,
88
+ )
89
+ else:
90
+ model.fit(X_train, y_train)
91
+
92
+ best_iter = model.get_best_iteration()
93
+ best_iterations.append(best_iter)
94
+
95
+ fold_models.append(model)
96
+ val_preds = model.predict_proba(X_val)[:, 1]
97
+ fold_scores.append(model.eval_metrics(val_pool, ["AUC", "F1", "Accuracy", "Precision", "Recall", "BrierScore", "Logloss"], best_iter))
98
+
99
+ st.success(f"Fold {fold} trained. Best iteration: {best_iter}")
100
+
101
+ st.session_state.trained_models = fold_models
102
+ st.session_state.fold_scores = fold_scores
103
+ st.session_state.best_iterations = best_iterations
104
+
105
+ ### TURN OFF SINGLE MODEL TRAINING ####
106
+ # Single model training
107
+ st.session_state.best_iteration = np.max(st.session_state.best_iterations) # if "best_iterations" in st.session_state else 5000
108
+
109
+ final_model = get_model(model_type, mode="ensemble", best_iter=st.session_state.best_iteration)
110
+ if model_type == "CatBoost":
111
+ final_model.fit(
112
+ X, y,
113
+ cat_features=cat_features,
114
+ )
115
+ else:
116
+ final_model.fit(X, y)
117
+ st.session_state.trained_model = final_model
118
+
119
+ st.success("All folds completed. Models saved for ensembling.")
120
+
121
+ # CV summary metrics
122
+ if "fold_scores" in st.session_state:
123
+ st.subheader("Cross-Validation Metrics (5-Fold)")
124
+ metrics = ["AUC", "F1", "Accuracy", "Precision", "Recall", "BrierScore", "Logloss"]
125
+ scores = st.session_state.fold_scores
126
+
127
+ for metric in metrics:
128
+ values = [score[metric][-1] for score in scores] # last = best_iteration
129
+ mean_val = sum(values) / len(values)
130
+ std_val = pd.Series(values).std()
131
+ st.write(f"**{metric}**: {mean_val:.3f} ± {std_val:.3f}")
132
+
133
+ # Single & ensemble evaluation
134
+ if "trained_model" in st.session_state or "trained_models" in st.session_state:
135
+ st.subheader("🔮 Ensemble Evaluation (on Training Data)")
136
+
137
+ models = st.session_state.trained_models
138
+ ### TURN OFF SINGLE MODEL EVALUATION ###
139
+ single_model = st.session_state.trained_model
140
+
141
+ df = st.session_state.edited_df.copy()
142
+ target_col = st.session_state.target_col
143
+ # st.session_state.targets_df = df[["GVHD", "Acute GVHD(<100 days)", "Chronic GVHD>100 days"]]
144
+
145
+ if target_col in ["Acute GVHD(<100 days)", "Chronic GVHD>100 days"]:
146
+ df = df[df[target_col] != 3]
147
+
148
+ y = df[target_col]
149
+ st.session_state.targets_df = y
150
+ train_features, cat_features = load_train_features()
151
+ X = df[train_features]
152
+
153
+ for col in cat_features:
154
+ X[col] = X[col].astype(str)
155
+
156
+ ### TURN OFF SINGLE MODEL EVALUATION ###
157
+ y_pred_prob_single = single_model.predict_proba(X)[:, 1]
158
+ metrics_result_single = compute_metrics(y, y_pred_prob_single)
159
+
160
+ y_pred_prob_ensemble = ensemble_predict(models, X, cat_features)
161
+ metrics_result_ensemble = compute_metrics(y, y_pred_prob_ensemble)
162
+
163
+ ### TURN OFF SINGLE MODEL EVALUATION ###
164
+ st.write("Single Model Predictions:")
165
+ for metric, value in metrics_result_single.items():
166
+ st.write(f"**{metric}**: {value:.3f}")
167
+
168
+ st.write("Ensemble Predictions:")
169
+ for metric, value in metrics_result_ensemble.items():
170
+ st.write(f"**{metric}**: {value:.3f}")
171
+
172
+
173
+ user_model_name = st.text_input("Enter model name to be saved:")
174
+
175
+ if user_model_name:
176
+ ### TURN OFF SINGLE MODEL SAVING ###
177
+ filename = save_model(st.session_state.trained_model, user_model_name, metrics_result_single)
178
+
179
+ filename = save_model_ensemble(
180
+ st.session_state.trained_models,
181
+ user_model_name,
182
+ best_iterations=st.session_state.best_iterations,
183
+ fold_scores=st.session_state.fold_scores,
184
+ metrics_result_ensemble=metrics_result_ensemble
185
+ )
186
+
187
+ st.success(f"{filename} is successfully saved!")
188
+ st.success(f"Ensemble saved as {filename}_ensemble")
189
+
190
+ else:
191
+ st.info("Train a model first before saving.")
src/params/model_params.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CatBoost:
2
+ ensemble:
3
+ learning_rate: 0.1
4
+ depth: 12
5
+ loss_function: Logloss
6
+ random_seed: 0
7
+ l2_leaf_reg: 7
8
+ subsample: 0.7
9
+ grow_policy: Lossguide # SymmetricTree or Depthwise or Lossguide
10
+ bagging_temperature: 1
11
+ random_strength: 5
12
+ min_data_in_leaf: 5
13
+ iterations: 10000
14
+ early_stopping_rounds: 50
15
+ custom_loss: ['AUC', "F1", "Accuracy", "Precision", "Recall", "BrierScore", "Logloss"]
16
+ verbose: False
17
+
18
+ # lr1e1_d12_l27_ss07_gpLg_bag1_rs5_m5
19
+
20
+ single_model:
21
+ # in this mode, the model is trained on the entire dataset using the best_iter obtained from cross-validation
22
+ learning_rate: 0.1
23
+ depth: 12
24
+ loss_function: Logloss
25
+ random_seed: 0
26
+ l2_leaf_reg: 7
27
+ subsample: 0.7
28
+ grow_policy: Lossguide # SymmetricTree or Depthwise or Lossguide
29
+ bagging_temperature: 1
30
+ random_strength: 5
31
+ min_data_in_leaf: 5
32
+ custom_loss: ['AUC', "F1", "Accuracy", "Precision", "Recall", "BrierScore", "Logloss"]
33
+ verbose: False
34
+
src/preprocess_utils.py ADDED
@@ -0,0 +1,928 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import re
4
+ from sklearn.preprocessing import MultiLabelBinarizer
5
+
6
+ # Constants
7
+ UNKNOWN_TOKEN = "X"
8
+ DATE_FORMAT = '%d/%m/%Y'
9
+ BLOOD_GROUP_COLS = ["D_Blood group", "Recepient_Blood group before HSCT"]
10
+ NATIONALITY_CORRECTIONS = {
11
+ "AFGHANISTAN": "AFGHAN",
12
+ "ALGERIA": "ALGERIAN",
13
+ "EMARATI": "EMIRATI",
14
+ "UAE": "EMIRATI",
15
+ "PHILIPPINO": "FILIPINO",
16
+ "JORDAN": "JORDANIAN",
17
+ "JORDANI": "JORDANIAN",
18
+ "PAKISTAN": "PAKISTANI",
19
+ "PAKISTANII": "PAKISTANI",
20
+ "PALESTINE": "PALESTINIAN",
21
+ "PALESTENIAN": "PALESTINIAN",
22
+ "USA": "AMERICAN",
23
+ }
24
+ # 1. Regional Grouping (Geography-Based)
25
+ regional_grouping = {
26
+ # Middle East
27
+ 'EMIRATI': 'Middle East',
28
+ 'OMANI': 'Middle East',
29
+ 'SAUDI': 'Middle East',
30
+ 'KUWAIT': 'Middle East',
31
+ 'JORDANIAN': 'Middle East',
32
+ 'LEBANESE': 'Middle East',
33
+ 'IRAQI': 'Middle East',
34
+ 'SYRIAN': 'Middle East',
35
+ 'YEMENI': 'Middle East',
36
+ 'PALESTINIAN': 'Middle East',
37
+
38
+ # North Africa
39
+ 'EGYPTIAN': 'North Africa',
40
+ 'SUDANESE': 'North Africa',
41
+ 'ALGERIAN': 'North Africa',
42
+ 'MOROCCAN': 'North Africa',
43
+ 'MAURITANIA': 'North Africa',
44
+ 'COMORAN': 'North Africa',
45
+
46
+ # South Asia
47
+ 'INDIAN': 'South Asia',
48
+ 'PAKISTANI': 'South Asia',
49
+ 'BANGLADESHI': 'South Asia',
50
+ 'SRI LANKAN': 'South Asia',
51
+ 'AFGHAN': 'South Asia',
52
+
53
+ # Southeast Asia
54
+ 'FILIPINO': 'Southeast Asia',
55
+ 'INDONESIAN': 'Southeast Asia',
56
+
57
+ # East Africa
58
+ 'ETHIOPIAN': 'East Africa',
59
+ 'SOMALI': 'East Africa',
60
+ 'ERITREAN': 'East Africa',
61
+
62
+ # Central Asia
63
+ 'UZBEKISTANI': 'Central Asia',
64
+
65
+ # Western Nations / Oceania / Americas
66
+ 'AMERICAN': 'Western',
67
+ 'BRITISH': 'Western',
68
+ 'NEW ZEALANDER': 'Oceania',
69
+ 'FIJI': 'Oceania'
70
+ }
71
+
72
+ # 2. Cultural-Linguistic Grouping
73
+ cultural_grouping = {
74
+ 'EMIRATI': 'Arab',
75
+ 'OMANI': 'Arab',
76
+ 'SAUDI': 'Arab',
77
+ 'KUWAIT': 'Arab',
78
+ 'JORDANIAN': 'Arab',
79
+ 'LEBANESE': 'Arab',
80
+ 'IRAQI': 'Arab',
81
+ 'SYRIAN': 'Arab',
82
+ 'YEMENI': 'Arab',
83
+ 'PALESTINIAN': 'Arab',
84
+ 'EGYPTIAN': 'Arab',
85
+ 'SUDANESE': 'Arab-African',
86
+ 'ALGERIAN': 'Arab',
87
+ 'MOROCCAN': 'Arab',
88
+ 'MAURITANIA': 'Arab',
89
+ 'COMORAN': 'Arab-African',
90
+ 'INDIAN': 'South Asian',
91
+ 'PAKISTANI': 'South Asian',
92
+ 'BANGLADESHI': 'South Asian',
93
+ 'SRI LANKAN': 'South Asian',
94
+ 'AFGHAN': 'South Asian',
95
+ 'FILIPINO': 'Southeast Asian',
96
+ 'INDONESIAN': 'Southeast Asian',
97
+ 'ETHIOPIAN': 'East African',
98
+ 'SOMALI': 'East African',
99
+ 'ERITREAN': 'East African',
100
+ 'UZBEKISTANI': 'Central Asian',
101
+ 'AMERICAN': 'Western/English-speaking',
102
+ 'BRITISH': 'Western/English-speaking',
103
+ 'NEW ZEALANDER': 'Western/English-speaking',
104
+ 'FIJI': 'Pacific Islander'
105
+ }
106
+
107
+ # 3. World Bank Income Grouping
108
+ income_grouping = {
109
+ 'EMIRATI': 'High income',
110
+ 'OMANI': 'High income',
111
+ 'SAUDI': 'High income',
112
+ 'KUWAIT': 'High income',
113
+ 'JORDANIAN': 'Upper-middle income',
114
+ 'LEBANESE': 'Upper-middle income',
115
+ 'IRAQI': 'Upper-middle income',
116
+ 'SYRIAN': 'Low income',
117
+ 'YEMENI': 'Low income',
118
+ 'PALESTINIAN': 'Lower-middle income',
119
+ 'EGYPTIAN': 'Lower-middle income',
120
+ 'SUDANESE': 'Low income',
121
+ 'ALGERIAN': 'Lower-middle income',
122
+ 'MOROCCAN': 'Lower-middle income',
123
+ 'MAURITANIA': 'Low income',
124
+ 'COMORAN': 'Low income',
125
+ 'INDIAN': 'Lower-middle income',
126
+ 'PAKISTANI': 'Lower-middle income',
127
+ 'BANGLADESHI': 'Lower-middle income',
128
+ 'SRI LANKAN': 'Lower-middle income',
129
+ 'AFGHAN': 'Low income',
130
+ 'FILIPINO': 'Lower-middle income',
131
+ 'INDONESIAN': 'Lower-middle income',
132
+ 'ETHIOPIAN': 'Low income',
133
+ 'SOMALI': 'Low income',
134
+ 'ERITREAN': 'Low income',
135
+ 'UZBEKISTANI': 'Lower-middle income',
136
+ 'AMERICAN': 'High income',
137
+ 'BRITISH': 'High income',
138
+ 'NEW ZEALANDER': 'High income',
139
+ 'FIJI': 'Upper-middle income'
140
+ }
141
+
142
+ # 4. WHO Regional Office Grouping
143
+ who_region_grouping = {
144
+ 'EMIRATI': 'EMRO',
145
+ 'OMANI': 'EMRO',
146
+ 'SAUDI': 'EMRO',
147
+ 'KUWAIT': 'EMRO',
148
+ 'JORDANIAN': 'EMRO',
149
+ 'LEBANESE': 'EMRO',
150
+ 'IRAQI': 'EMRO',
151
+ 'SYRIAN': 'EMRO',
152
+ 'YEMENI': 'EMRO',
153
+ 'PALESTINIAN': 'EMRO',
154
+ 'EGYPTIAN': 'EMRO',
155
+ 'SUDANESE': 'EMRO',
156
+ 'ALGERIAN': 'AFRO',
157
+ 'MOROCCAN': 'EMRO',
158
+ 'MAURITANIA': 'AFRO',
159
+ 'COMORAN': 'AFRO',
160
+ 'INDIAN': 'SEARO',
161
+ 'PAKISTANI': 'EMRO',
162
+ 'BANGLADESHI': 'SEARO',
163
+ 'SRI LANKAN': 'SEARO',
164
+ 'AFGHAN': 'EMRO',
165
+ 'FILIPINO': 'WPRO',
166
+ 'INDONESIAN': 'SEARO',
167
+ 'ETHIOPIAN': 'AFRO',
168
+ 'SOMALI': 'EMRO',
169
+ 'ERITREAN': 'AFRO',
170
+ 'UZBEKISTANI': 'EURO',
171
+ 'AMERICAN': 'AMRO',
172
+ 'BRITISH': 'EURO',
173
+ 'NEW ZEALANDER': 'WPRO',
174
+ 'FIJI': 'WPRO'
175
+ }
176
+ groupings = {
177
+ 'Recepient_Nationality_Geographical': regional_grouping,
178
+ 'Recepient_Nationality_Cultural': cultural_grouping,
179
+ 'Recepient_Nationality_Regional_Income': income_grouping,
180
+ 'Recepient_Nationality_Regional_WHO': who_region_grouping
181
+ }
182
+
183
+ # FIRST_GVHD_PROPHYLAXIS_CORRECTIONS
184
+ DRUG_SPELLING_CORRECTIONS = {
185
+ "CYCLOSPOPRIN": "CYCLOSPORIN",
186
+ "CYCLOSPRIN": "CYCLOSPORIN",
187
+ "CYCLOSPOROIN": "CYCLOSPORIN",
188
+ "CY": "CYCLOSPORIN",
189
+ "TAC": "TACROLIMUS", # no TACROLIMUS in new dataset, only TAC
190
+ "MTX": "METHOTREXATE", # one METHOTREXATE in new dataset (ID 118), replaced with MTX
191
+ "BUDESONIDE": "STEROID", # 3 BUDESONIDE in new dataset (ID 259, 263, 273), replaced with STEROID
192
+ "STEROIDS": "STEROID", # 6 STEROIDS in new dataset (ID 172, 175, 140, 146, 152, 166), replaced with STEROID
193
+ "ATG.": "ATG",
194
+ "FLUDARABINIE": "FLUDARABINE",
195
+ "FLUDRABINE":"FLUDARABINE",
196
+ "BUSULPHAN": "BUSULFAN",
197
+ "MEPHALAN": "MELPHALAN",
198
+ }
199
+ GENDER_MAP = {
200
+ 0: "MALE", 1: "FEMALE", 2: UNKNOWN_TOKEN,
201
+ "0": "MALE", "1": "FEMALE", "2": UNKNOWN_TOKEN
202
+ }
203
+ RELATION_CORRECTIONS = {
204
+ r"(?i)BROTHER": "SIBLING",
205
+ r"(?i)SISTER": "SIBLING",
206
+ r"(?i)FATHER": "FIRST DEGREE RELATIVE",
207
+ r"(?i)MOTHER": "FIRST DEGREE RELATIVE",
208
+ r"(?i)SON": "FIRST DEGREE RELATIVE",
209
+ r"(?i)DAUGHTER": "FIRST DEGREE RELATIVE",
210
+ r"(?i)COUSIN": "SECOND DEGREE RELATIVE",
211
+ r"(?i)UNCLE": "SECOND DEGREE RELATIVE",
212
+ r"(?i)AUNT": "SECOND DEGREE RELATIVE",
213
+ r"(?i)other": UNKNOWN_TOKEN
214
+ }
215
+ STRING_NORMALIZATION_MAP = {
216
+ r"(?i)unknown": UNKNOWN_TOKEN, r"(?i)unkown": UNKNOWN_TOKEN,
217
+ r"(?i)Unknwon": UNKNOWN_TOKEN, np.nan: UNKNOWN_TOKEN,
218
+ r"(?i)\bMale\b": "MALE", r"(?i)\bFemale\b": "FEMALE",
219
+ "1o": "10", r"(?i)Umbilical Cord": "UMBILICAL CORD",
220
+ r"(?i)Umbilical Cord blood": "UMBILICAL CORD",
221
+ r"(?i)Bone Marrow": "BONE MARROW", "MDS": "MYELODYSPLASTIC SYNDROME"
222
+ }
223
+ diagnosis_group_map = {
224
+ "MYELOPROLIFERATIVE DISORDER": "MYELOPROLIFERATIVE NEOPLASMS",
225
+ "CML": "MYELOPROLIFERATIVE NEOPLASMS",
226
+ "MYELOFIBROSIS": "MYELOPROLIFERATIVE NEOPLASMS",
227
+ "NON-HODGKIN LYMPHOMA": "LYMPHOMA",
228
+ 'NON HODGKIN LYMPHOMA': "LYMPHOMA",
229
+ "HODGKIN LYMPHOMA": "LYMPHOMA",
230
+ "BETA THALASSEMIA": "RED CELL DISORDERS",
231
+ 'BETA THALESSEMIA': "RED CELL DISORDERS",
232
+ "ALPHA THALASSEMIA": "RED CELL DISORDERS",
233
+ "ALPHA THALESSEMIA": "RED CELL DISORDERS",
234
+ "ALPHA THALSSEMIA": "RED CELL DISORDERS",
235
+ "HEREDITARY SPHEROCYTOSIS": "RED CELL DISORDERS",
236
+ "SICKLE CELL DISEASE": "RED CELL DISORDERS",
237
+ "APLASTIC ANEMIA": "BMF SYNDROMES",
238
+ "FANCONI ANEMIA": "BMF SYNDROMES",
239
+ "DYSKERATOSIS CONGENITA": "BMF SYNDROMES",
240
+ 'DYSKERATOSIS CONGENTIA': "BMF SYNDROMES",
241
+ "CHRONIC GRANULOMATOUS DISEASE": "IMMUNE DISORDERS",
242
+ "COMBINED VARIABLE IMMUNODEFICIENCY": "IMMUNE DISORDERS",
243
+ "SCID": "IMMUNE DISORDERS",
244
+
245
+ ## check this one
246
+ "X-LINKED HYPERGAMMAGLOBULINEMIA": "IMMUNE DISORDERS",
247
+ '-LINKED HYPERGAMMAGLOBULINEMIA': "IMMUNE DISORDERS",
248
+ '-LINKED HYPER IGM SYNDROME': "IMMUNE DISORDERS",
249
+ "HYPOGAMMAGLOBULINEMIA": "IMMUNE DISORDERS",
250
+
251
+
252
+ ## check this one
253
+ "GLANZMANN": "OTHER",
254
+ 'GLANZMANN THROMBASTHENIA': "OTHER",
255
+
256
+ "CLL": "OTHER",
257
+ "PNH": "OTHER",
258
+ "HLH": "OTHER",
259
+ "LANGERHANS CELL HISTIOCYTOSIS": "OTHER",
260
+ "BLASTIC PLASMACYTOID DENDRITIC CELL NEOPLASM": "OTHER",
261
+ 'BLASTIC PLASMACYTOID DENDRITRIC CELL NEOPLASM': "OTHER",
262
+ "B-ALL": "ALL",
263
+ "BALL": "ALL",
264
+ "TALL": "ALL",
265
+ "T-ALL": "ALL",
266
+ "AML": "AML",
267
+ "ACUTE MYELOID LEUKEMIA": "AML"
268
+ }
269
+
270
+ # # 0 nonmalignant; 1: malignant
271
+ malignant_map = {
272
+ 'AML': 1,
273
+ 'RED CELL DISORDERS': 0,
274
+ 'AMYLOIDOSIS': 0,
275
+ 'BMF SYNDROMES': 0,
276
+ 'ALL': 1,
277
+ 'OTHER': 0,
278
+ 'IMMUNE DISORDERS': 0,
279
+ 'CHRONIC LYMPHOCYTIC LEUKEMIA': 1,
280
+ 'MYELOPROLIFERATIVE NEOPLASMS': 1, # note: CML is malignant; not sure about MYELOPROLIFERATIVE DISORDER & MYELOFIBROSIS
281
+ 'HEMOPHAGOCYTIC LYMPHOHISTIOCYTOSIS (HLH)': 0,
282
+ 'LYMPHOMA': 1,
283
+ 'MYELODYSPLASTIC SYNDROME': 1,
284
+ 'MEDULLOBLASTOMA': 0,
285
+ 'MULTIPLE MYELOMA': 0,
286
+ 'NEUROBLASTOMA': 0,
287
+ 'PAROXYSMAL NOCTURNAL HEMOGLOBINURIA': 0,
288
+ 'PLASMA CELL LEUKEMIA': 0
289
+ }
290
+
291
+ def load_train_features():
292
+ # Define features
293
+ HLA_sub12 = [
294
+
295
+ # Recipient - HLA-A
296
+ 'R_HLA_A_1', 'R_HLA_A_2', 'R_HLA_A_3', 'R_HLA_A_4', 'R_HLA_A_7', 'R_HLA_A_8',
297
+ 'R_HLA_A_11', 'R_HLA_A_12', 'R_HLA_A_20', 'R_HLA_A_23', 'R_HLA_A_24', 'R_HLA_A_25',
298
+ 'R_HLA_A_26', 'R_HLA_A_29', 'R_HLA_A_30', 'R_HLA_A_31', 'R_HLA_A_32', 'R_HLA_A_33',
299
+ 'R_HLA_A_34', 'R_HLA_A_66', 'R_HLA_A_68', 'R_HLA_A_69', 'R_HLA_A_74', 'R_HLA_A_X',
300
+
301
+ # Recipient - HLA-B
302
+ 'R_HLA_B_7', 'R_HLA_B_8', 'R_HLA_B_13', 'R_HLA_B_14', 'R_HLA_B_15', 'R_HLA_B_18',
303
+ 'R_HLA_B_23', 'R_HLA_B_24', 'R_HLA_B_27', 'R_HLA_B_35', 'R_HLA_B_37', 'R_HLA_B_38',
304
+ 'R_HLA_B_39', 'R_HLA_B_40', 'R_HLA_B_41', 'R_HLA_B_42', 'R_HLA_B_44', 'R_HLA_B_45',
305
+ 'R_HLA_B_46', 'R_HLA_B_49', 'R_HLA_B_50', 'R_HLA_B_51', 'R_HLA_B_52', 'R_HLA_B_53',
306
+ 'R_HLA_B_55', 'R_HLA_B_56', 'R_HLA_B_57', 'R_HLA_B_58', 'R_HLA_B_73', 'R_HLA_B_81',
307
+ 'R_HLA_B_X',
308
+
309
+ # Recipient - HLA-C
310
+ 'R_HLA_C_1', 'R_HLA_C_2', 'R_HLA_C_3', 'R_HLA_C_4', 'R_HLA_C_5', 'R_HLA_C_6',
311
+ 'R_HLA_C_7', 'R_HLA_C_8', 'R_HLA_C_12', 'R_HLA_C_14', 'R_HLA_C_15', 'R_HLA_C_16',
312
+ 'R_HLA_C_17', 'R_HLA_C_18', 'R_HLA_C_38', 'R_HLA_C_49', 'R_HLA_C_50', 'R_HLA_C_X',
313
+
314
+ # Recipient - HLA-DR
315
+ 'R_HLA_DR_1', 'R_HLA_DR_2', 'R_HLA_DR_3', 'R_HLA_DR_4', 'R_HLA_DR_5', 'R_HLA_DR_6',
316
+ 'R_HLA_DR_7', 'R_HLA_DR_8', 'R_HLA_DR_9', 'R_HLA_DR_10', 'R_HLA_DR_11', 'R_HLA_DR_12',
317
+ 'R_HLA_DR_13', 'R_HLA_DR_14', 'R_HLA_DR_15', 'R_HLA_DR_16', 'R_HLA_DR_17', 'R_HLA_DR_X',
318
+
319
+ # Recipient - HLA-DQ
320
+ 'R_HLA_DQ_1', 'R_HLA_DQ_2', 'R_HLA_DQ_3', 'R_HLA_DQ_4', 'R_HLA_DQ_5', 'R_HLA_DQ_6',
321
+ 'R_HLA_DQ_7', 'R_HLA_DQ_11', 'R_HLA_DQ_15', 'R_HLA_DQ_16', 'R_HLA_DQ_301', 'R_HLA_DQ_X',
322
+
323
+ # Donor - HLA-A
324
+ 'D_HLA_A_1', 'D_HLA_A_2', 'D_HLA_A_3', 'D_HLA_A_8', 'D_HLA_A_11', 'D_HLA_A_12',
325
+ 'D_HLA_A_23', 'D_HLA_A_24', 'D_HLA_A_25', 'D_HLA_A_26', 'D_HLA_A_29', 'D_HLA_A_30',
326
+ 'D_HLA_A_31', 'D_HLA_A_32', 'D_HLA_A_33', 'D_HLA_A_34', 'D_HLA_A_66', 'D_HLA_A_68',
327
+ 'D_HLA_A_69', 'D_HLA_A_7', 'D_HLA_A_74', 'D_HLA_A_X',
328
+
329
+ # Donor - HLA-B
330
+ 'D_HLA_B_7', 'D_HLA_B_8', 'D_HLA_B_13', 'D_HLA_B_14', 'D_HLA_B_15', 'D_HLA_B_17',
331
+ 'D_HLA_B_18', 'D_HLA_B_23', 'D_HLA_B_24', 'D_HLA_B_27', 'D_HLA_B_35', 'D_HLA_B_37',
332
+ 'D_HLA_B_38', 'D_HLA_B_39', 'D_HLA_B_40', 'D_HLA_B_41', 'D_HLA_B_42', 'D_HLA_B_44',
333
+ 'D_HLA_B_45', 'D_HLA_B_48', 'D_HLA_B_49', 'D_HLA_B_50', 'D_HLA_B_51', 'D_HLA_B_52',
334
+ 'D_HLA_B_53', 'D_HLA_B_55', 'D_HLA_B_56', 'D_HLA_B_57', 'D_HLA_B_58', 'D_HLA_B_73',
335
+ 'D_HLA_B_81', 'D_HLA_B_X',
336
+
337
+ # Donor - HLA-C
338
+ 'D_HLA_C_1', 'D_HLA_C_2', 'D_HLA_C_3', 'D_HLA_C_4', 'D_HLA_C_5', 'D_HLA_C_6',
339
+ 'D_HLA_C_7', 'D_HLA_C_8', 'D_HLA_C_12', 'D_HLA_C_14', 'D_HLA_C_15', 'D_HLA_C_16',
340
+ 'D_HLA_C_17', 'D_HLA_C_18', 'D_HLA_C_38', 'D_HLA_C_49', 'D_HLA_C_50', 'D_HLA_C_X',
341
+
342
+ # Donor - HLA-DR
343
+ 'D_HLA_DR_1', 'D_HLA_DR_2', 'D_HLA_DR_3', 'D_HLA_DR_4', 'D_HLA_DR_5', 'D_HLA_DR_6',
344
+ 'D_HLA_DR_7', 'D_HLA_DR_8', 'D_HLA_DR_9', 'D_HLA_DR_10', 'D_HLA_DR_11', 'D_HLA_DR_12',
345
+ 'D_HLA_DR_13', 'D_HLA_DR_14', 'D_HLA_DR_15', 'D_HLA_DR_16', 'D_HLA_DR_17', 'D_HLA_DR_X',
346
+
347
+ # Donor - HLA-DQ
348
+ 'D_HLA_DQ_1', 'D_HLA_DQ_2', 'D_HLA_DQ_3', 'D_HLA_DQ_4', 'D_HLA_DQ_5', 'D_HLA_DQ_6',
349
+ 'D_HLA_DQ_7', 'D_HLA_DQ_11', 'D_HLA_DQ_15', 'D_HLA_DQ_16', 'D_HLA_DQ_301', 'D_HLA_DQ_X'
350
+ ]
351
+
352
+
353
+ HLA_sub12_without_X = [i for i in HLA_sub12 if "_X" not in i]
354
+
355
+ prehsct_onehot = [
356
+ 'PreHSCT_ALEMTUZUMAB',
357
+ 'PreHSCT_ATG',
358
+ 'PreHSCT_BEAM',
359
+ 'PreHSCT_BUSULFAN',
360
+ 'PreHSCT_CAMPATH',
361
+ 'PreHSCT_CARMUSTINE',
362
+ 'PreHSCT_CLOFARABINE',
363
+ 'PreHSCT_CYCLOPHOSPHAMIDE',
364
+ 'PreHSCT_CYCLOSPORIN',
365
+ 'PreHSCT_CYTARABINE',
366
+ 'PreHSCT_ETOPOSIDE',
367
+ 'PreHSCT_FLUDARABINE',
368
+ 'PreHSCT_GEMCITABINE',
369
+ 'PreHSCT_MELPHALAN',
370
+ 'PreHSCT_MTX',
371
+ 'PreHSCT_OTHER',
372
+ 'PreHSCT_RANIMUSTINE',
373
+ 'PreHSCT_REDUCEDCONDITIONING',
374
+ 'PreHSCT_RITUXIMAB',
375
+ 'PreHSCT_SIROLIMUS',
376
+ 'PreHSCT_TBI',
377
+ 'PreHSCT_THIOTEPA',
378
+ 'PreHSCT_TREOSULFAN',
379
+ 'PreHSCT_UA',
380
+ 'PreHSCT_VORNOSTAT',
381
+ ]
382
+
383
+ first_prophylaxis_onehot = [
384
+ 'First_GVHD_prophylaxis_ABATACEPT',
385
+ 'First_GVHD_prophylaxis_ALEMTUZUMAB',
386
+ 'First_GVHD_prophylaxis_ATG',
387
+ 'First_GVHD_prophylaxis_CYCLOPHOSPHAMIDE',
388
+ 'First_GVHD_prophylaxis_CYCLOSPORIN',
389
+ 'First_GVHD_prophylaxis_IMATINIB',
390
+ 'First_GVHD_prophylaxis_LEFLUNOMIDE',
391
+ 'First_GVHD_prophylaxis_MMF',
392
+ 'First_GVHD_prophylaxis_MTX',
393
+ 'First_GVHD_prophylaxis_NONE',
394
+ 'First_GVHD_prophylaxis_RUXOLITINIB',
395
+ 'First_GVHD_prophylaxis_SIROLIMUS',
396
+ 'First_GVHD_prophylaxis_STEROID',
397
+ 'First_GVHD_prophylaxis_TAC',
398
+ ]
399
+
400
+ train_features = [[
401
+ 'Recipient_gender',
402
+ 'R_Age_at_transplant_cutoff18',
403
+ 'Recepient_Nationality_Cultural',
404
+ 'Hematological Diagnosis_Grouped',
405
+ 'Recepient_Blood group before HSCT_MergePlusMinus',
406
+ 'D_Age_at_transplant_cutoff18',
407
+ 'Age_Gap_R_D',
408
+ 'Donor_gender',
409
+ 'D_Blood group_MergePlusMinus',
410
+ 'Number of lines of Rx before HSCT',
411
+ 'Source of cells',
412
+ 'Donor_relation to recipient',
413
+ ] + HLA_sub12_without_X + prehsct_onehot + first_prophylaxis_onehot][0]
414
+
415
+ # Categorical features
416
+ cat_features = [
417
+ 'Recipient_gender',
418
+ 'Recepient_Nationality_Cultural',
419
+ 'Hematological Diagnosis_Grouped',
420
+ 'Recepient_Blood group before HSCT_MergePlusMinus',
421
+ 'Donor_gender',
422
+ 'D_Blood group_MergePlusMinus',
423
+ 'Source of cells',
424
+ 'Donor_relation to recipient',
425
+ ]
426
+
427
+ return train_features, cat_features
428
+
429
+ def load_dataset(file_path: str) -> pd.DataFrame:
430
+ """Load dataset from CSV file and drop columns with all missing values"""
431
+ df = pd.read_csv(file_path, header=1)
432
+ return df.dropna(axis=1, how="all")
433
+
434
+ def normalize_strings(df: pd.DataFrame) -> pd.DataFrame:
435
+ """
436
+ Standardize string values across the dataset:
437
+ - Replace variations of unknown/NA with consistent token
438
+ - Correct common misspellings and abbreviations
439
+ - Capitalize all strings for consistency
440
+ - Strip leading/trailing whitespace
441
+ """
442
+ # Apply global string replacements
443
+ df = df.replace(STRING_NORMALIZATION_MAP, regex=True)
444
+
445
+ # Handle nationality-specific replacements
446
+ non_nationality_cols = [col for col in df.columns if "Nationality" not in col]
447
+ df[non_nationality_cols] = df[non_nationality_cols].replace(
448
+ {r"(?i)\buk\b": UNKNOWN_TOKEN}, regex=True
449
+ )
450
+
451
+ # Handle non-HLA specific replacements
452
+ non_hla_cols = [col for col in df.columns if "HLA" not in col]
453
+ df[non_hla_cols] = df[non_hla_cols].replace(
454
+ {r"(?i)\bna\b": UNKNOWN_TOKEN}, regex=True
455
+ )
456
+
457
+ # Capitalize all string values
458
+ df = df.applymap(lambda x: x.upper() if isinstance(x, str) else x)
459
+
460
+ # Strip whitespace
461
+ return df.applymap(lambda x: x.strip() if isinstance(x, str) else x)
462
+
463
+ def clean_blood_group_columns(df: pd.DataFrame, columns: list) -> pd.DataFrame:
464
+ """Remove spaces from specified blood group columns"""
465
+ for col in columns:
466
+ df[col] = df[col].str.replace(" ", "")
467
+ return df
468
+
469
+ def process_hla_columns(df: pd.DataFrame) -> pd.DataFrame:
470
+ """
471
+ Clean and process HLA columns by:
472
+ 1. Splitting combined HLA values into separate columns
473
+ 2. Standardizing missing value representation
474
+ 3. Sorting allele values numerically
475
+ 4. Recombining cleaned values
476
+ """
477
+ # Padding function to ensure 2 elements, filling with 'NA'. Used for Individual_Predictions
478
+ def pad_list(val):
479
+ if not isinstance(val, list):
480
+ val = []
481
+ return (val + ['NA', 'NA'])[:2]
482
+
483
+ hla_columns = [col for col in df.columns if "R_HLA" in col or "D_HLA" in col]
484
+ # hla_columns = ['R_HLA_A', 'R_HLA_B', 'R_HLA_C', 'R_HLA_DR', 'R_HLA_DQ',
485
+ # 'D_HLA_A', 'D_HLA_B', 'D_HLA_C', 'D_HLA_DR', 'D_HLA_DQ']
486
+
487
+ for col in hla_columns:
488
+ # Handle special NA representation
489
+ df[col] = df[col].replace({"NA": "NA&NA"})
490
+
491
+ # Split into two separate columns
492
+ split_cols = [f"{col}1", f"{col}2"]
493
+
494
+ if type(df[col].iloc[0]) != list: # and "&" in df[col].iloc[0]:
495
+ df[split_cols] = df[col].str.split("&", expand=True)
496
+ elif type(df[col].iloc[0]) == list:
497
+ df[col] = df[col].apply(pad_list)
498
+ df[split_cols] = pd.DataFrame(df[col].tolist(), index=df.index)
499
+
500
+ # Standardize missing values
501
+ missing_indicators = {" ", "NA", "N/A", UNKNOWN_TOKEN, "''", '""', "", "B1", None}
502
+ df[split_cols] = df[split_cols].replace(missing_indicators, np.nan)
503
+
504
+ # Convert to numeric and handle zeros
505
+ df[split_cols] = df[split_cols].apply(pd.to_numeric, errors='coerce')
506
+ df[split_cols] = df[split_cols].replace(0, np.nan)
507
+
508
+ # Sort values numerically
509
+ df[split_cols] = np.sort(df[split_cols], axis=1)
510
+
511
+ # Convert numbers to integers, missing to 'X'
512
+ df[split_cols] = df[split_cols].applymap(lambda x: str(int(x)) if pd.notna(x) else UNKNOWN_TOKEN)
513
+
514
+ # Recombine cleaned values
515
+ df[col] = df[split_cols].astype(str).agg("&".join, axis=1)
516
+
517
+ return df
518
+
519
+
520
+ def cast_as_int_if_possible(x):
521
+ try:
522
+ i = int(x)
523
+ # Only return int if conversion is lossless (e.g., avoid converting '5.5' -> 5)
524
+ if float(x) == i:
525
+ return i
526
+ except:
527
+ pass
528
+ return x
529
+
530
+ def HLA_unique_alleles(df, HLA_col1, HLA_col2):
531
+ HLA_col1_unique = df[HLA_col1].unique()
532
+ HLA_col2_unique = df[HLA_col2].unique()
533
+
534
+ HLA_col1_unique = [cast_as_int_if_possible(val) for val in HLA_col1_unique]
535
+ HLA_col2_unique = [cast_as_int_if_possible(val) for val in HLA_col2_unique]
536
+
537
+ unique_set = set(HLA_col1_unique).union(set(HLA_col2_unique))
538
+
539
+ # Replace NaN with "X"
540
+ unique_set = {(UNKNOWN_TOKEN if pd.isna(item) else str(item)) for item in unique_set}
541
+ print('unique_set', unique_set)
542
+ return sorted(unique_set)
543
+
544
+ def expand_HLA_cols_(df, HLA_col1, HLA_col2):
545
+ HLA_uniques = HLA_unique_alleles(df, HLA_col1, HLA_col2)
546
+
547
+ col_name = HLA_col1[:-1] # get "R_HLA_A" from "R_HLA_A1"
548
+ for i in HLA_uniques:
549
+ df[f"{col_name}_{i}"] = 0
550
+ df.loc[df[HLA_col1]==i, f"{col_name}_{i}"] = 1 # or = 1
551
+ df.loc[df[HLA_col2]==i, f"{col_name}_{i}"] = 1 # or = 1
552
+
553
+ return df
554
+
555
+ def expand_HLA_cols(df):
556
+ df = expand_HLA_cols_(df, HLA_col1="R_HLA_A1", HLA_col2="R_HLA_A2")
557
+ df = expand_HLA_cols_(df, HLA_col1="R_HLA_B1", HLA_col2="R_HLA_B2")
558
+ df = expand_HLA_cols_(df, HLA_col1="R_HLA_C1", HLA_col2="R_HLA_C2")
559
+ df = expand_HLA_cols_(df, HLA_col1="R_HLA_DR1", HLA_col2="R_HLA_DR2")
560
+ df = expand_HLA_cols_(df, HLA_col1="R_HLA_DQ1", HLA_col2="R_HLA_DQ2")
561
+
562
+ df = expand_HLA_cols_(df, HLA_col1="D_HLA_A1", HLA_col2="D_HLA_A2")
563
+ df = expand_HLA_cols_(df, HLA_col1="D_HLA_B1", HLA_col2="D_HLA_B2")
564
+ df = expand_HLA_cols_(df, HLA_col1="D_HLA_C1", HLA_col2="D_HLA_C2")
565
+ df = expand_HLA_cols_(df, HLA_col1="D_HLA_DR1", HLA_col2="D_HLA_DR2")
566
+ df = expand_HLA_cols_(df, HLA_col1="D_HLA_DQ1", HLA_col2="D_HLA_DQ2")
567
+ return df
568
+
569
+ def correct_nationalities(df: pd.DataFrame, column: str) -> pd.DataFrame:
570
+ """Standardize nationality names using predefined corrections"""
571
+ df[column] = df[column].replace(NATIONALITY_CORRECTIONS)
572
+ return df
573
+
574
+ def correct_indiv_drug_name(drug_list):
575
+ # Find all the drug names and separators in the string
576
+ parts = re.split(r'([ /+])', drug_list) # Split but keep the separators
577
+
578
+ corrected_parts = []
579
+
580
+ for part in parts:
581
+ # If the part is a drug name, apply the correction
582
+ if part.strip() and part.strip() not in {'', ' ', '/', '+'}:
583
+ corrected_part = DRUG_SPELLING_CORRECTIONS.get(part.strip(), part.strip())
584
+ corrected_parts.append(corrected_part)
585
+ else:
586
+ # If it's a separator (/, +, space), just keep it
587
+ corrected_parts.append(part)
588
+
589
+ # Join the parts back together
590
+ return ''.join(corrected_parts)
591
+
592
+ def correct_drug_name_in_list(df: pd.DataFrame, column: str) -> pd.DataFrame:
593
+ """Standardize drug names in a list using predefined corrections, preserving separators."""
594
+ # Apply the correction function to each entry in the specified column
595
+ df[column] = df[column].apply(correct_indiv_drug_name)
596
+
597
+ return df
598
+
599
+ def standardize_compound_columns(df: pd.DataFrame, columns: list) -> pd.DataFrame:
600
+ """
601
+ Process columns with compound values by:
602
+ 1. Removing spaces
603
+ 2. Standardizing separators
604
+ 3. Sorting components alphabetically
605
+ """
606
+ for col in columns:
607
+ if col in df.columns and type(df[col].iloc[0]) != list:
608
+ # Clean string values
609
+ df[col] = df[col].str.replace(" ", "").str.replace("+", "/").str.replace(",", "/")
610
+
611
+ # Split, remove empty parts, sort, and join
612
+ df[col] = df[col].apply(
613
+ lambda x: "/".join(sorted([part for part in x.split("/") if part])) if isinstance(x, str) else x
614
+ )
615
+ return df
616
+
617
+ def standardize_gender(df: pd.DataFrame) -> pd.DataFrame:
618
+ """Standardize donor gender values and infer from relationship where possible"""
619
+ # Apply gender mapping
620
+ df["Donor_gender"] = df["Donor_gender"].replace(GENDER_MAP)
621
+ df["Recipient_gender"] = df["Recipient_gender"].replace(GENDER_MAP)
622
+
623
+ # Infer gender from relationship
624
+ gender_map = {
625
+ "BROTHER": "MALE", "SISTER": "FEMALE",
626
+ "FATHER": "MALE", "MOTHER": "FEMALE",
627
+ "SON": "MALE", "DAUGHTER": "FEMALE",
628
+ "UNCLE": "MALE", "AUNT": "FEMALE"
629
+ }
630
+ for relationship, gender in gender_map.items():
631
+ mask = df["Donor_relation to recipient"] == relationship
632
+ df.loc[mask, "Donor_gender"] = gender
633
+
634
+ return df
635
+
636
+ def correct_donor_relationships(df: pd.DataFrame) -> pd.DataFrame:
637
+ """Standardize relationship categories using predefined corrections"""
638
+ return df.replace({"Donor_relation to recipient": RELATION_CORRECTIONS}, regex=True)
639
+
640
+ def handle_self_donor_consistency(df: pd.DataFrame) -> pd.DataFrame:
641
+ """
642
+ Ensure data consistency for self-donors by:
643
+ 1. Setting HLA values to 'SELF&SELF'
644
+ 2. Verifying matching demographics
645
+ """
646
+ self_mask = df["Donor_relation to recipient"] == "SELF"
647
+
648
+ # Set HLA values for self-donors
649
+ hla_cols = [col for col in df.columns if "R_HLA" in col or "D_HLA" in col]
650
+ df.loc[self_mask, hla_cols] = "SELF&SELF"
651
+
652
+ # Verify demographic consistency
653
+ assert df.loc[self_mask, "Recipient_gender"].equals(
654
+ df.loc[self_mask, "Donor_gender"]
655
+ ), "Recipient/Donor gender mismatch for self-donors"
656
+
657
+ assert df.loc[self_mask, "Recepient_Blood group before HSCT"].equals(
658
+ df.loc[self_mask, "D_Blood group"]
659
+ ), "Blood group mismatch for self-donors"
660
+
661
+ assert df.loc[self_mask, "Recepient_DOB"].equals(
662
+ df.loc[self_mask, "Donor_DOB"]
663
+ ), "DOB mismatch for self-donors"
664
+
665
+ return df
666
+
667
+ def safe_extract_year(date_str: str) -> str:
668
+ """
669
+ Safely extract year from date string:
670
+ - Returns year as integer if valid
671
+ - Returns UNKNOWN_TOKEN for invalid/missing dates
672
+ """
673
+ if not isinstance(date_str, str) or date_str == UNKNOWN_TOKEN:
674
+ return UNKNOWN_TOKEN
675
+
676
+ try:
677
+ # Handle special cases like "35 YEAR OLD"
678
+ if "YEAR" in date_str:
679
+ return UNKNOWN_TOKEN
680
+
681
+ parts = date_str.split("/")
682
+ if len(parts) < 3:
683
+ return UNKNOWN_TOKEN
684
+
685
+ year_part = parts[-1].strip()
686
+ return int(year_part) if year_part.isdigit() else UNKNOWN_TOKEN
687
+ except (ValueError, TypeError):
688
+ return UNKNOWN_TOKEN
689
+
690
+ def calculate_ages(df: pd.DataFrame) -> pd.DataFrame:
691
+ """
692
+ Calculate:
693
+ 1. Recipient age at transplant
694
+ 2. Donor age at transplant
695
+ 3. Age gap between recipient and donor
696
+ """
697
+ # Extract years safely
698
+ df["Recepient_DOB_Year"] = df["Recepient_DOB"].apply(safe_extract_year)
699
+ df["Donor_DOB_Year"] = df["Donor_DOB"].apply(safe_extract_year)
700
+ df["HSCT_date_Year"] = df["HSCT_date"].apply(safe_extract_year)
701
+
702
+ # Calculate ages with safe conversion
703
+ def calculate_age_diff(row, dob_col, transplant_col):
704
+ try:
705
+ return int(row[transplant_col]) - int(row[dob_col])
706
+ except (TypeError, ValueError):
707
+ return UNKNOWN_TOKEN
708
+
709
+ df["R_Age_at_transplant"] = df.apply(
710
+ lambda row: calculate_age_diff(row, "Recepient_DOB_Year", "HSCT_date_Year"),
711
+ axis=1
712
+ )
713
+
714
+ df["D_Age_at_transplant"] = df.apply(
715
+ lambda row: calculate_age_diff(row, "Donor_DOB_Year", "HSCT_date_Year"),
716
+ axis=1
717
+ )
718
+
719
+ df["Age_Gap_R_D"] = df.apply(
720
+ lambda row: calculate_age_diff(row, "Donor_DOB_Year", "Recepient_DOB_Year"),
721
+ axis=1
722
+ )
723
+
724
+ return df
725
+
726
+ # Utility Function: Split and One-Hot Encode Drug Regimens
727
+ def split_and_one_hot_encode(df, column_name, prefix):
728
+ """
729
+ Splits entries in a column by "/" and one-hot encodes the resulting tokens.
730
+
731
+ Args:
732
+ df (pd.DataFrame): Input dataframe
733
+ column_name (str): Name of the column to process
734
+ prefix (str): Prefix for the resulting one-hot encoded columns
735
+
736
+ Returns:
737
+ pd.DataFrame: DataFrame with one-hot encoded columns added
738
+ """
739
+ if type(df[column_name].iloc[0]) != list:
740
+ df[column_name] = df[column_name].fillna("").apply(lambda x: re.split(r'[/]', x) if x else [])
741
+ else:
742
+ pass
743
+
744
+ mlb = MultiLabelBinarizer()
745
+ encoded_df = pd.DataFrame(
746
+ mlb.fit_transform(df[column_name]),
747
+ columns=[f"{prefix}_{drug.strip()}" for drug in mlb.classes_],
748
+ index=df.index
749
+ )
750
+
751
+ df = pd.concat([df, encoded_df], axis=1)
752
+ return df
753
+
754
+ # Normalize Blood Groups (Remove +/-)
755
+ def merge_blood_groups(df, column, new_col):
756
+ """
757
+ Removes '+' and '-' from blood group values.
758
+
759
+ Args:
760
+ df (pd.DataFrame): Input dataframe
761
+ column (str): Column name to normalize
762
+ new_col (str): New column name for cleaned values
763
+
764
+ Returns:
765
+ pd.DataFrame: Updated dataframe
766
+ """
767
+ df[new_col] = df[column].apply(lambda x: re.sub(r'[+-]', '', x) if pd.notnull(x) else np.nan)
768
+ return df
769
+
770
+ def binarize_age(df, age_col, cutoff, new_col):
771
+ """
772
+ Binarizes age column based on a cutoff. Non-numeric values are left as-is.
773
+
774
+ Args:
775
+ df (pd.DataFrame): Input dataframe
776
+ age_col (str): Column name containing age
777
+ cutoff (int): Age cutoff
778
+ new_col (str): New binary column name
779
+
780
+ Returns:
781
+ pd.DataFrame: Updated dataframe
782
+ """
783
+ def binarize_or_keep(val):
784
+ try:
785
+ return int(val >= cutoff)
786
+ except TypeError:
787
+ return val # Leave strings or non-numeric values unchanged
788
+
789
+ df[new_col] = df[age_col].apply(binarize_or_keep)
790
+ return df
791
+
792
+ # Create Composite Gender & Relation Columns
793
+ def add_gender_relation_features(df):
794
+ """
795
+ Creates new columns combining donor relation with recipient and donor genders.
796
+
797
+ Returns:
798
+ pd.DataFrame: Updated dataframe
799
+ """
800
+ df["Relation_and_Recipient_Gender"] = df["Donor_relation to recipient"] + " R_" + df["Recipient_gender"]
801
+ df["Relation_and_Donor_Gender"] = df["Donor_relation to recipient"] + " D_" + df["Donor_gender"]
802
+ df["Relation_and_Recipient_and_Donor_Gender"] = (
803
+ df["Donor_relation to recipient"] + " R_" + df["Recipient_gender"] + " D_" + df["Donor_gender"]
804
+ )
805
+ return df
806
+
807
+ # Nationality-Based Groupings
808
+ def apply_nationality_groupings(df, column, grouping_dicts):
809
+ """
810
+ Applies multiple groupings based on nationality.
811
+
812
+ Args:
813
+ df (pd.DataFrame): Input dataframe
814
+ column (str): Column to group by
815
+ grouping_dicts (dict): Dictionary of {new_col_name: mapping_dict}
816
+
817
+ Returns:
818
+ pd.DataFrame: Updated dataframe
819
+ """
820
+ for new_col, mapping in grouping_dicts.items():
821
+ df[new_col] = df[column].replace(mapping)
822
+ return df
823
+
824
+ # Group and Binarize Diagnosis
825
+ def group_and_binarize_diagnosis(df, original_col, group_map, malignant_map):
826
+ """
827
+ Groups diagnosis into categories and flags as malignant or not.
828
+
829
+ Args:
830
+ df (pd.DataFrame): Input dataframe
831
+ original_col (str): Original diagnosis column
832
+ group_map (dict): Mapping of diagnoses to groups
833
+ malignant_map (dict): Mapping of groups to binary malignancy label
834
+
835
+ Returns:
836
+ pd.DataFrame: Updated dataframe
837
+ """
838
+ grouped_col = f"{original_col}_Grouped"
839
+ malignant_col = f"{original_col}_Malignant"
840
+
841
+ df[grouped_col] = df[original_col].replace(group_map)
842
+ df[malignant_col] = df[grouped_col].replace(malignant_map)
843
+ return df
844
+
845
+ def preprocess_pipeline(df) -> pd.DataFrame:
846
+ """
847
+ Full preprocessing pipeline:
848
+ 1. Load and initial cleaning
849
+ 2. String normalization
850
+ 3. Special column processing
851
+ 4. Data corrections
852
+ 5. Feature engineering
853
+ """
854
+ df = df.dropna(axis=1, how="all")
855
+
856
+ # Special column processing
857
+ # Strip leading/trailing spaces from column names
858
+ df.columns = df.columns.str.strip()
859
+ # Remove spaces from HLA columns
860
+ df.columns = [col.replace(" ", "") if "_HLA" in col else col for col in df.columns]
861
+
862
+ # String handling
863
+ df = normalize_strings(df)
864
+ df = clean_blood_group_columns(df, BLOOD_GROUP_COLS)
865
+
866
+ # Data corrections
867
+ df = correct_nationalities(df, "Recepient_Nationality")
868
+ df = correct_drug_name_in_list(df, "PreHSCT conditioning regimen+/-ATG+/-TBI")
869
+ df = correct_drug_name_in_list(df, "First_GVHD prophylaxis")
870
+ # df = correct_drug_name_in_list(df, "Post HSCT regimen")
871
+ df = standardize_compound_columns(
872
+ df,
873
+ ["PreHSCT conditioning regimen+/-ATG+/-TBI", "First_GVHD prophylaxis"]
874
+ )
875
+ df = standardize_gender(df)
876
+ df = correct_donor_relationships(df)
877
+
878
+ if "SELF" in df["Donor_relation to recipient"].unique():
879
+ df = handle_self_donor_consistency(df)
880
+
881
+ # HLA processing
882
+ df = process_hla_columns(df)
883
+ df = expand_HLA_cols(df)
884
+
885
+ # Feature engineering
886
+ df = calculate_ages(df)
887
+
888
+ # Final missing value handling
889
+ df = df.fillna(UNKNOWN_TOKEN)
890
+
891
+ # One-hot encode multi-drug regimen columns
892
+ df = split_and_one_hot_encode(df, 'PreHSCT conditioning regimen+/-ATG+/-TBI', 'PreHSCT')
893
+ df = split_and_one_hot_encode(df, 'First_GVHD prophylaxis', 'First_GVHD_prophylaxis')
894
+ # df = split_and_one_hot_encode(df, 'Post HSCT regimen', 'PostHSCT')
895
+
896
+ # Normalize blood groups
897
+ df = merge_blood_groups(df, "Recepient_Blood group before HSCT", "Recepient_Blood group before HSCT_MergePlusMinus")
898
+ df = merge_blood_groups(df, "D_Blood group", "D_Blood group_MergePlusMinus")
899
+
900
+ # Binarize ages
901
+ df = binarize_age(df, "R_Age_at_transplant", 16, "R_Age_at_transplant_cutoff16")
902
+ df = binarize_age(df, "R_Age_at_transplant", 18, "R_Age_at_transplant_cutoff18")
903
+ df = binarize_age(df, "D_Age_at_transplant", 16, "D_Age_at_transplant_cutoff16")
904
+ df = binarize_age(df, "D_Age_at_transplant", 18, "D_Age_at_transplant_cutoff18")
905
+
906
+ # Gender/Relation features
907
+ df = add_gender_relation_features(df)
908
+
909
+ # Group nationalities
910
+ df = apply_nationality_groupings(df, 'Recepient_Nationality', groupings)
911
+
912
+ # Group and binarize diagnosis
913
+ df = group_and_binarize_diagnosis(df, 'Hematological Diagnosis', diagnosis_group_map, malignant_map)
914
+
915
+ df = df.replace(UNKNOWN_TOKEN, np.nan)
916
+
917
+ # Add columns for new dfs for features that exist in the original dataset but not in the new one
918
+ for feature in load_train_features()[0]:
919
+ if ("_HLA" in feature or "First_GVHD_prophylaxis_" in feature or "PreHSCT_" in feature) and feature not in df.columns:
920
+ df[feature] = 0
921
+
922
+ return df
923
+
924
+ if __name__ == "__main__":
925
+ processed_data = preprocess_pipeline(
926
+ "/home/muhammadridzuan/2025_GVHD/2024_GVHD_SSMC/GVHD_Intel_data_MBZUAI_1.2.csv"
927
+ )
928
+ processed_data.to_csv("preprocessed_gvhd_data.csv", index=False)
src/saved_models/250706_150941_corr_drug_names_single.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69baff3c0aaedf52175dfb01c7663031e988b668eb8c7b4fa03d920de43265ce
3
+ size 149312
src/saved_models/250706_150942_corr_drug_names_ensemble.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71fe613ec10104d24e5d4623f053e46bf1abce9da257b8190ca6cea4a72ed7a5
3
+ size 855627
src/sidebar.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from pathlib import Path
3
+ import glob
4
+ from huggingface_hub import HfApi, hf_hub_download
5
+ import pyarrow.parquet as pq
6
+
7
+ # def sidebar():
8
+ # APP_DIR = Path(__file__).parent
9
+ # MODELS_DIR = APP_DIR / "saved_models"
10
+
11
+ # # Shared dropdown in the sidebar
12
+ # def get_model_options():
13
+ # models = ["Default"]
14
+ # model_files = glob.glob(str(MODELS_DIR / "*.pkl")) + glob.glob(str(MODELS_DIR / "*.cbm"))
15
+
16
+ # for m in model_files:
17
+ # models.append(Path(m).stem)
18
+ # return sorted(set(models))
19
+
20
+ # if 'selected_model' not in st.session_state:
21
+ # st.session_state.selected_model = "Default"
22
+
23
+ # st.sidebar.title("Model Selection")
24
+ # st.session_state.selected_model = st.sidebar.selectbox("Model", get_model_options())
25
+
26
+ def sidebar():
27
+ def get_model_options():
28
+ models = ["Default"]
29
+ api = HfApi(token=st.secrets["HF_TOKEN"])
30
+ all_files = api.list_repo_files(repo_id=st.secrets["HF_REPO_ID"], repo_type="dataset")
31
+ parquet_files = [f for f in all_files if f.startswith("models/") and f.endswith(".parquet")]
32
+
33
+ for f in parquet_files:
34
+ try:
35
+ # Download and read Parquet file
36
+ downloaded = hf_hub_download(
37
+ repo_id=st.secrets["HF_REPO_ID"],
38
+ repo_type="dataset",
39
+ filename=f,
40
+ token=st.secrets["HF_TOKEN"]
41
+ )
42
+ table = pq.read_table(downloaded)
43
+ row = table.to_pylist()[0]
44
+ models.append(row["filename"])
45
+ except Exception as e:
46
+ st.warning(f"Skipping model file due to error: {f} ({e})")
47
+
48
+ return sorted(set(models))
49
+
50
+ if 'selected_model' not in st.session_state:
51
+ st.session_state.selected_model = "Default"
52
+
53
+ st.sidebar.title("Model Selection")
54
+ st.session_state.selected_model = st.sidebar.selectbox("Model", get_model_options())