copd-model-h / training /split_train_test_val.py
IamGrooooot's picture
Initial Upload
000de75
"""Splits the model H cohort into train, test and balanced cross validation folds.
The train set retains class ratio, sex and age distributions from the full dataset.
Each patient appears in exactly one of the train or test sets, never in both.
This script also splits the train data into balanced folds for cross-validation. Patient
IDs for train, test and all data folds are stored for use in subsequent scripts.
"""
import contextlib
import os
import pickle
import sys

import numpy as np
import pandas as pd
import yaml

import splitting
# Load the pipeline configuration. The path is relative, so it is resolved
# against the current working directory. The file handle gets its own name so
# that `config` only ever refers to the parsed settings, never to the open
# stream, and the encoding is pinned so decoding does not depend on the
# platform default.
with open("./training/config.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)
##############################################################
# Load correct files
##############################################################
# When False the split is still computed and logged, but no ID files are written.
save_cohort_info = True

# Specify which model to perform split on
model_type = config["model_settings"]["model_type"]

# Setup log file. The log directory is created up front so that open() does not
# fail with FileNotFoundError when the folder is missing.
log_path = "./training/logging/split_train_test_" + model_type + ".log"
os.makedirs(os.path.dirname(log_path), exist_ok=True)

# Everything printed inside this block is written to the log file (as is
# anything the splitting helpers print to sys.stdout). Leaving the block,
# normally or through an exception, restores the previous sys.stdout and
# closes the log file instead of leaving the handle open.
with open(log_path, "w") as log, contextlib.redirect_stdout(log):
    demographics = pd.read_pickle(
        os.path.join(
            config["outputs"]["processed_data_dir"],
            "demographics_{}.pkl".format(model_type),
        )
    )

    ##############################################################
    # Split data into a train and hold out test set
    ##############################################################
    # NOTE(review): the helper lives in the project-local `splitting` module,
    # which is not visible here; presumably it keeps all rows of one StudyId on
    # the same side of the split - confirm against its implementation.
    train_data, test_data = splitting.subject_wise_train_test_split(
        data=demographics,
        target_col="ExacWithin3Months",
        id_col="StudyId",
        test_size=0.2,
        stratify_by_sex=True,
        sex_col="Sex_F",
        stratify_by_age=True,
        age_bin_col="AgeBinned",
    )

    # Log the Sex_F proportions and then the mean Age for the full cohort, the
    # train set and the test set (in that order) so the three can be compared.
    cohorts = (demographics, train_data, test_data)
    for cohort in cohorts:
        print(cohort.Sex_F.value_counts() / cohort.Sex_F.count())
    for cohort in cohorts:
        print(cohort.Age.mean())

    train_ids = train_data.StudyId.unique()
    test_ids = test_data.StudyId.unique()

    ##############################################################
    # Split training data into groups for cross validation
    ##############################################################
    # NOTE(review): return structure of the helper is not visible from this
    # file; it is converted to an object array below before being saved.
    fold_patients = splitting.subject_wise_kfold_split(
        train_data=train_data,
        target_col="ExacWithin3Months",
        id_col="StudyId",
        num_folds=5,
        sex_col="Sex_F",
        age_col="Age",
        stratify_by_sex=True,
        print_log=True,
    )

    ##############################################################
    # Save cohort info
    ##############################################################
    # dtype="object" lets NumPy hold the fold structure as generic Python
    # objects, which is why np.save below needs allow_pickle=True.
    fold_patients = np.array(fold_patients, dtype="object")

    if save_cohort_info:
        cohort_info_dir = config["outputs"]["cohort_info_dir"]
        os.makedirs(cohort_info_dir, exist_ok=True)

        # Save train and test ID info (test first, then train), each as a
        # pickled plain list of the unique StudyId values.
        for prefix, ids in (("test_ids_", test_ids), ("train_ids_", train_ids)):
            ids_path = os.path.join(cohort_info_dir, prefix + model_type + ".pkl")
            with open(ids_path, "wb") as f:
                pickle.dump(list(ids), f)
        print("Train and test patient IDs saved")

        # Save cross validation fold info
        np.save(
            os.path.join(cohort_info_dir, "fold_patients_" + model_type + ".npy"),
            fold_patients,
            allow_pickle=True,
        )
        print("Cross validation fold information saved")