"""Split the model H cohort into train, test and balanced cross-validation folds.

The train set retains the class ratio, sex and age distributions of the full
dataset. Each patient appears in either the train set or the test set, never
both.

The train data is further split into balanced folds for cross-validation.
Patient IDs for the train set, the test set and every fold are saved for use
in subsequent scripts.
"""

import contextlib
import numpy as np
import os
import pandas as pd
import pickle
import sys
import yaml
import splitting

with open("./training/config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

# Set to False to run the split (and write its log) without saving the
# train/test IDs and fold assignments to disk.
save_cohort_info = True

model_type = config["model_settings"]["model_type"]
cohort_info_dir = config["outputs"]["cohort_info_dir"]

# Column names in the demographics table used to define and balance the split.
TARGET_COL = "ExacWithin3Months"
ID_COL = "StudyId"
SEX_COL = "Sex_F"

LOG_DIR = "./training/logging"


def _print_split_summary(full, train, test):
    """Print sex proportions, then mean age, for the full, train and test sets.

    Lets the log show whether the split kept the demographic distributions.
    """
    for frame in (full, train, test):
        print(frame[SEX_COL].value_counts(normalize=True))
    for frame in (full, train, test):
        print(frame["Age"].mean())


def _to_object_array(folds):
    """Pack one patient-ID collection per fold into a 1-D object array.

    ``np.array(folds, dtype=object)`` returns a 1-D array only when the folds
    differ in length; if every fold has the same length it builds a 2-D array
    instead, so the saved layout would depend on the cohort size. Filling a
    pre-allocated 1-D array always yields exactly one entry per fold.
    """
    packed = np.empty(len(folds), dtype=object)
    for i, fold in enumerate(folds):
        packed[i] = fold
    return packed


def _save_pickle(obj, filename):
    """Pickle ``obj`` to ``filename`` inside the cohort info directory."""
    with open(os.path.join(cohort_info_dir, filename), "wb") as f:
        pickle.dump(obj, f)


# open(..., "w") fails if the directory is missing, so create it first.
os.makedirs(LOG_DIR, exist_ok=True)
log_path = os.path.join(LOG_DIR, "split_train_test_" + model_type + ".log")

# Everything printed below goes to the log file; the context managers restore
# stdout and close the file when the script finishes or raises.
with open(log_path, "w") as log, contextlib.redirect_stdout(log):
    demographics = pd.read_pickle(
        os.path.join(
            config["outputs"]["processed_data_dir"],
            "demographics_{}.pkl".format(model_type),
        )
    )

    # Subject-wise train/test split (test_size=0.2), stratified by sex and
    # binned age as requested by the flags below.
    train_data, test_data = splitting.subject_wise_train_test_split(
        data=demographics,
        target_col=TARGET_COL,
        id_col=ID_COL,
        test_size=0.2,
        stratify_by_sex=True,
        sex_col=SEX_COL,
        stratify_by_age=True,
        age_bin_col="AgeBinned",
    )
    _print_split_summary(demographics, train_data, test_data)

    train_ids = train_data[ID_COL].unique()
    test_ids = test_data[ID_COL].unique()

    # A patient in both sets would leak information into evaluation; fail
    # loudly rather than silently saving an invalid split.
    shared_ids = set(train_ids) & set(test_ids)
    if shared_ids:
        raise RuntimeError(
            "{} patient(s) appear in both the train and test sets".format(
                len(shared_ids)
            )
        )

    # Split the training patients into 5 cross-validation folds.
    fold_patients = splitting.subject_wise_kfold_split(
        train_data=train_data,
        target_col=TARGET_COL,
        id_col=ID_COL,
        num_folds=5,
        sex_col=SEX_COL,
        age_col="Age",
        stratify_by_sex=True,
        print_log=True,
    )

    # NOTE(review): presumably one collection of patient IDs per fold —
    # confirm against splitting.subject_wise_kfold_split.
    fold_patients = _to_object_array(fold_patients)

    if save_cohort_info:
        os.makedirs(cohort_info_dir, exist_ok=True)
        _save_pickle(list(test_ids), "test_ids_" + model_type + ".pkl")
        _save_pickle(list(train_ids), "train_ids_" + model_type + ".pkl")
        print("Train and test patient IDs saved")

        # allow_pickle is required because the array holds Python objects.
        np.save(
            os.path.join(cohort_info_dir, "fold_patients_" + model_type + ".npy"),
            fold_patients,
            allow_pickle=True,
        )
        print("Cross validation fold information saved")
|