| | """Module to perform splitting of data into train/test or K-folds.""" |
| |
|
| | import pandas as pd |
| | import numpy as np |
| | from sklearn.model_selection import StratifiedGroupKFold |
| |
|
| |
|
def subject_wise_train_test_split(
    data,
    target_col,
    id_col,
    test_size,
    stratify_by_sex=False,
    sex_col=None,
    stratify_by_age=False,
    age_bin_col=None,
    shuffle=False,
    random_state=None,
):
    """Subject wise splitting data into train and test sets.

    Splits data into train and test sets ensuring that the same patient can only appear
    in either the train or the test set. Stratifies data according to class label, with
    additional options to stratify by sex and age. The input dataframe is not modified.

    The test set is the first fold of a ``StratifiedGroupKFold`` with
    ``round(1 / test_size)`` folds. The realised test fraction is therefore approximately
    ``1 / round(1 / test_size)``, not exactly ``test_size``.

    Parameters
    ----------
    data : pd.DataFrame
        dataframe containing features and target.
    target_col : str
        name of target column.
    id_col : str
        name of patient id column.
    test_size : float
        proportion of the dataset to include in the test split. Must be strictly
        between 0 and 1.
    stratify_by_sex : bool, optional
        option to stratify data by sex, by default False.
    sex_col : str, optional
        name of sex column, by default None. Required if stratify_by_sex is True.
    stratify_by_age : bool, optional
        option to stratify data by age, by default False.
    age_bin_col : str, optional
        name of binned age column, by default None. Required if stratify_by_age is True.
    shuffle : bool, optional
        whether to shuffle each class's samples before splitting into batches, by default
        False.
    random_state : int, optional
        when shuffle is True, random_state affects the ordering of the indices, by default
        None.

    Returns
    -------
    train_data : pd.DataFrame
        copy of the train rows, stratified by class (and by age/sex if requested). The
        original index labels are kept.
    test_data : pd.DataFrame
        copy of the test rows, stratified by class (and by age/sex if requested). The
        original index labels are kept.

    Raises
    ------
    ValueError
        if stratify_by_age or stratify_by_sex is True but the respective column is not
        provided, or if test_size is not strictly between 0 and 1.
    """
    if stratify_by_age and age_bin_col is None:
        raise ValueError(
            "Parameter stratify_by_age is True but age_bin_col not provided."
        )
    if stratify_by_sex and sex_col is None:
        raise ValueError("Parameter stratify_by_sex is True but sex_col not provided.")
    if not 0 < test_size < 1:
        raise ValueError(
            f"test_size must be strictly between 0 and 1, got {test_size}."
        )

    # Build the stratification label as a standalone Series so the caller's dataframe
    # is never modified. Components are joined with "_" so that multi-character values
    # cannot collide (e.g. "1" + "10" vs "11" + "0").
    extra_cols = []
    if stratify_by_sex:
        extra_cols.append(sex_col)
    if stratify_by_age:
        extra_cols.append(age_bin_col)
    if extra_cols:
        strat_labels = data[target_col].astype(str)
        for col in extra_cols:
            strat_labels = strat_labels + "_" + data[col].astype(str)
    else:
        strat_labels = data[target_col]

    num_folds = round(1 / test_size)
    sgkf = StratifiedGroupKFold(
        n_splits=num_folds, shuffle=shuffle, random_state=random_state
    )

    # Only the first fold is needed: its held-out part becomes the test set.
    train_idx, test_idx = next(sgkf.split(data, strat_labels, groups=data[id_col]))

    # split() yields positional indices, so select with iloc; copy so callers get
    # independent frames (no chained-assignment surprises).
    train_data = data.iloc[train_idx].copy()
    test_data = data.iloc[test_idx].copy()
    return train_data, test_data
| |
|
| |
|
def subject_wise_kfold_split(
    train_data,
    target_col,
    id_col,
    num_folds,
    sex_col=None,
    age_col=None,
    stratify_by_sex=False,
    stratify_by_age=False,
    age_bin_col=None,
    shuffle=False,
    random_state=None,
    print_log=False,
):
    """Subject wise splitting data into balanced K-folds.

    Splits data into K-folds ensuring that the same patient can only appear in either
    the train or the validation set of a fold. Stratifies data according to class label,
    with additional options to stratify by sex and age. The input dataframe is not
    modified.

    Parameters
    ----------
    train_data : pd.DataFrame
        dataframe containing features and target.
    target_col : str
        name of target column.
    id_col : str
        name of patient id column.
    num_folds : int
        number of folds.
    sex_col : str, optional
        name of sex column, by default None. Required if stratify_by_sex is True. Can be
        included when stratify_by_sex is False to get info on sex ratio across folds.
    age_col : str, optional
        name of age column, by default None. Column must be a continuous variable. Can be
        included to get info on mean age across folds.
    stratify_by_sex : bool, optional
        option to stratify data by sex, by default False.
    stratify_by_age : bool, optional
        option to stratify data by age, by default False. The binned age (age_bin_col) is
        used for stratifying by age rather than age_col.
    age_bin_col : str, optional
        name of binned age column, by default None. Required if stratify_by_age is True.
    shuffle : bool, optional
        whether to shuffle each class's samples before splitting into batches, by default
        False.
    random_state : int, optional
        when shuffle is True, random_state affects the ordering of the indices, by default
        None.
    print_log : bool, optional
        flag to print, for the train portion of each fold, the fraction of rows whose
        target (and sex) equals 1 and the mean of age_col, by default False.

    Returns
    -------
    validation_fold_ids : np.ndarray of object, shape (num_folds,)
        element k is an array of the unique validation patient IDs for fold k.

    Raises
    ------
    ValueError
        if stratify_by_age or stratify_by_sex is True but the respective column is not
        provided for stratifying.
    """
    if stratify_by_age and age_bin_col is None:
        raise ValueError(
            "Parameter stratify_by_age is True but age_bin_col not provided."
        )
    if stratify_by_sex and sex_col is None:
        raise ValueError("Parameter stratify_by_sex is True but sex_col not provided.")

    # Build the stratification label as a standalone Series so the caller's dataframe
    # is never modified. Components are joined with "_" so that multi-character values
    # cannot collide (e.g. "1" + "10" vs "11" + "0").
    extra_cols = []
    if stratify_by_sex:
        extra_cols.append(sex_col)
    if stratify_by_age:
        extra_cols.append(age_bin_col)
    if extra_cols:
        strat_labels = train_data[target_col].astype(str)
        for col in extra_cols:
            strat_labels = strat_labels + "_" + train_data[col].astype(str)
    else:
        strat_labels = train_data[target_col]

    sgkf_train = StratifiedGroupKFold(
        n_splits=num_folds, shuffle=shuffle, random_state=random_state
    )

    validation_fold_ids = []
    class_fold_ratios = []
    sex_fold_ratios = []
    fold_mean_ages = []
    for train_index, validation_index in sgkf_train.split(
        train_data, strat_labels, groups=train_data[id_col]
    ):
        validation_fold_ids.append(train_data[id_col].iloc[validation_index].unique())

        if not print_log:
            continue  # fold statistics are only used for logging

        # Groups (patients) never straddle train/validation, so the positional train
        # index selects exactly the patients not in this validation fold.
        train_fold_data = train_data.iloc[train_index]
        # Fraction of rows labelled 1 (assumes 0/1-coded columns); unlike
        # value_counts()[1] this yields 0 instead of a KeyError when 1 is absent.
        class_fold_ratios.append((train_fold_data[target_col] == 1).mean())
        if sex_col is not None:
            sex_fold_ratios.append((train_fold_data[sex_col] == 1).mean())
        if age_col is not None:
            fold_mean_ages.append(train_fold_data[age_col].mean())

    if print_log:
        print("Fold proportions:")
        print("Train class ratio:", class_fold_ratios)
        if sex_col is not None:
            print("Sex class ratio:", sex_fold_ratios)
        if age_col is not None:
            print("Mean age:", fold_mean_ages)

    # Fill element-wise: np.asarray(list_of_arrays, dtype=object) would silently build a
    # 2-D array when every fold happens to have the same number of patients.
    fold_ids_array = np.empty(len(validation_fold_ids), dtype=object)
    for fold_number, ids in enumerate(validation_fold_ids):
        fold_ids_array[fold_number] = ids
    return fold_ids_array
| |
|
| |
|
def get_cv_fold_indices(validation_fold_ids, train_data, id_col):
    """
    Find train/val row positions for each fold and format for cross validation.

    Creates a tuple with training and validation row positions for each K-fold using
    validation_fold_ids. The listed patients are assigned to the validation portion of
    that fold, and all other patients are assigned to the train portion. Every patient ID
    listed for a fold must be contained in train_data.

    The positions are integer offsets into train_data (not index labels), which is what
    sklearn's ``cross_validate`` expects from its ``cv`` argument. This stays correct
    when train_data has a non-default index, e.g. one produced by a prior ``iloc``
    selection.

    Parameters
    ----------
    validation_fold_ids : iterable of array-like
        patient IDs in the validation portion of each of the K folds (e.g. the output of
        subject_wise_kfold_split).
    train_data : pd.DataFrame
        train data (must contain id_col).
    id_col : str
        name of column containing patient ID.

    Returns
    -------
    cross_val_fold_indices : list of tuples
        K ``(train_positions, validation_positions)`` tuples of integer numpy arrays.

    Raises
    ------
    ValueError
        if a fold lists a patient ID that does not appear in train_data[id_col].
    """
    patient_ids = train_data[id_col]

    cross_val_fold_indices = []
    for fold_number, fold in enumerate(validation_fold_ids):
        fold_ids = pd.Series(np.asarray(fold).ravel())

        missing = fold_ids[~fold_ids.isin(patient_ids)]
        if not missing.empty:
            raise ValueError(
                f"Fold {fold_number} contains patient IDs not present in train_data: "
                f"{missing.tolist()}"
            )

        # Compute membership once and derive both portions from the same mask.
        in_fold = patient_ids.isin(fold_ids).to_numpy()
        fold_val_positions = np.flatnonzero(in_fold)
        fold_train_positions = np.flatnonzero(~in_fold)

        cross_val_fold_indices.append((fold_train_positions, fold_val_positions))
    return cross_val_fold_indices
| |
|