copd-model-c / training /train_test_split.py
IamGrooooot's picture
Initial release: 72-hour COPD exacerbation prediction model
e69d4e4
"""Splits the model C cohort and patient days into stratified train and test sets.
The train set retains these characteristics of the full data set:
Exac days to non-exac days ratio (within 5%). Individual patients can only appear in
either train or test.
Sex ratio (within 0.05)
Age distribution (minimum p-value for Kolmogorov-Smirnov test=0.9)
This script also splits the train data into balanced folds for cross-validation. Patient
IDs for train, test and all data folds are stored for use in subsequent scripts.
All data sets are divided into train and test and stored in separate folders.
"""
import numpy as np
import os
import pandas as pd
import pickle
from lenusml import splits
# Locations of the raw COPD dataset and of the split outputs
# (replace the placeholders before running).
data_dir = '<YOUR_DATA_PATH>/copd-dataset/'
output_train_data_dir = '<YOUR_DATA_PATH>/train_data'
output_test_data_dir = '<YOUR_DATA_PATH>/test_data'

# Destination for patient ID lists and cross-validation fold information;
# the flag controls whether those files are written at all.
cohort_info_dir = '../data/cohort_info/'
save_cohort_info = True

# Load the exacerbation-level data set from the dataset directory.
exac_data_path = os.path.join(data_dir, 'exac_data.pkl')
data = pd.read_pickle(exac_data_path)
##########################################
# Prepare demographic info for splitting
##########################################
# Calculate decimal age on DateOfEvent
data['DateOfBirth'] = pd.to_datetime(data['DateOfBirth'], utc=True)


def calculate_age_decimal(dob, date):
    """Return the age in decimal years at ``date`` for someone born on ``dob``.

    Accepts either scalar timestamps (returns a float) or datetime Series of
    equal length (returns a float Series, computed element-wise).

    Args:
        dob: Date(s) of birth.
        date: Date(s) at which the age is evaluated.

    Returns:
        (days + seconds / 86400) / 365.2425 for the difference ``date - dob``,
        i.e. years of mean Gregorian length; sub-second parts are ignored.
        Missing dates (NaT) yield NaN.
    """
    age = date - dob
    if isinstance(age, pd.Series):
        # Timedelta Series expose their components via the .dt accessor.
        days, seconds = age.dt.days, age.dt.seconds
    else:
        days, seconds = age.days, age.seconds
    return (days + seconds / 86400.0) / 365.2425


# Vectorised over all rows. A row-wise DataFrame.apply(axis=1) is a Python-level
# loop over every patient-day and returns a DataFrame when `data` is empty.
# NOTE(review): assumes DateOfEvent is already a datetime column compatible with
# the UTC DateOfBirth above (the original row-wise version required the same).
data['Age'] = calculate_age_decimal(data['DateOfBirth'], data['DateOfEvent'])
data = data.drop(columns=['DateOfBirth'])
##########################################
# Merge with COPD status details
##########################################
# Per-patient link (StudyId -> CopdStatusDetailsId) and the status details table,
# both pipe-delimited.
patient_details = pd.read_csv(os.path.join(data_dir, 'CopdDatasetPatientDetails.txt'),
                              usecols=['StudyId', 'CopdStatusDetailsId'],
                              delimiter="|")
copd_status = pd.read_csv(os.path.join(data_dir, 'CopdDatasetCopdStatusDetails.txt'),
                          usecols=['Id', 'SmokingStatus', 'RequiredAcuteNIV',
                                   'RequiredICUAdmission',
                                   'LungFunction_FEV1PercentPredicted',
                                   'LabsHighestEosinophilCount'],
                          delimiter="|")
# Strip out % signs from spirometry and convert to float
copd_status['LungFunction_FEV1PercentPredicted'] = copd_status[
    'LungFunction_FEV1PercentPredicted'].str.strip('%').astype('float')
# validate='many_to_one' makes pandas raise MergeError if the right-hand key is
# not unique. Otherwise a duplicated key would silently duplicate rows and skew
# the class/sex ratios used for the split below.
patient_details = patient_details.merge(
    copd_status, left_on='CopdStatusDetailsId', right_on='Id',
    how='left', validate='many_to_one').drop(columns=['CopdStatusDetailsId', 'Id'])
data = data.merge(patient_details, on='StudyId', how='left',
                  validate='many_to_one')
#################################
# Define train and test cohorts
#################################
print('Split data into train and test')
# Both tolerances handed to the splitter are 5% of the ratio observed in the
# full data set.
# Class ratio: normalized count of IsExac label 0 over that of label 1.
exac_shares = data.IsExac.value_counts(normalize=True)
class_ratio_tolerance = 0.05 * exac_shares[0] / exac_shares[1]
print("Class ratio tolerance: ", class_ratio_tolerance)
# Sex ratio: normalized count of 'M' over that of 'F'.
sex_shares = data.Sex.value_counts(normalize=True)
sex_ratio_tolerance = 0.05 * sex_shares['M'] / sex_shares['F']
print("Sex ratio tolerance: ", sex_ratio_tolerance)
# Grouped by StudyId, stratified on IsExac and Sex, seeded for reproducibility.
train_data, test_data, train_ids, test_ids = splits.train_test_stratified_class_sex(
    data=data,
    id_column='StudyId',
    class_column='IsExac',
    sex_column='Sex',
    train_proportion=0.85,
    proportion_tolerance=0.05,
    class_ratio_tolerance=class_ratio_tolerance,
    sex_ratio_tolerance=sex_ratio_tolerance,
    random_seed=42,
)
#################################
# Create cross validation folds
#################################
# K=5 folds over the train set grouped by StudyId; the fold class-ratio
# tolerance reuses the train/test class ratio tolerance.
fold_proportions, fold_class_ratios, fold_patients = splits.group_kfold_class_balanced(
    data=train_data,
    id_column='StudyId',
    class_column='IsExac',
    K=5,
    fold_proportion_tolerance=0.05,
    fold_class_ratio_tolerance=class_ratio_tolerance,
    random_seed=42,
)
if save_cohort_info:
    os.makedirs(cohort_info_dir, exist_ok=True)

    # Patient ID lists, each pickled as a plain list.
    for pickle_name, id_collection in (("test_ids.pkl", test_ids),
                                       ("train_ids.pkl", train_ids)):
        with open(os.path.join(cohort_info_dir, pickle_name), 'wb') as f:
            pickle.dump(list(id_collection), f)
    print('Train and test patient IDs saved')

    # Per-fold statistics, each pickled as a plain list.
    for pickle_name, fold_stats in (("fold_proportions.pkl", fold_proportions),
                                    ("fold_class_ratios.pkl", fold_class_ratios)):
        with open(os.path.join(cohort_info_dir, pickle_name), 'wb') as f:
            pickle.dump(list(fold_stats), f)

    # NOTE(review): if fold_patients is a list of differently-sized ID arrays,
    # NumPy >= 1.24 refuses to build the implied ragged array inside np.save
    # unless it is already an object-dtype array. Confirm what
    # splits.group_kfold_class_balanced returns.
    np.save(os.path.join(cohort_info_dir, 'fold_patients.npy'), fold_patients,
            allow_pickle=True)
    print('Cross validation fold information saved')
###############################
# Save train and test sets
###############################
# (DataFrame, destination directory, file name) for each split.
split_outputs = (
    (train_data, output_train_data_dir, 'train_data.pkl'),
    (test_data, output_test_data_dir, 'test_data.pkl'),
)
# Create all output directories first, then write the pickles.
for _, target_dir, _ in split_outputs:
    os.makedirs(target_dir, exist_ok=True)
# Save exac and patient details info
for frame, target_dir, file_name in split_outputs:
    frame.to_pickle(os.path.join(target_dir, file_name))
print('Patient details/exac data saved')