File size: 8,493 Bytes
a35137b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 | import copy
import os
from typing import Dict, List
import numpy as np
import torch
from barista.data.metadata import Metadata
from barista.models.utils import seed_everything
_SUPPORTED_SPLITS = ["shuffle", "chronological"]
class Splitter:
"""Helper class to handle train/test/val splitting."""
def __init__(
self,
config: Dict,
subjects: List,
experiment: str,
use_fixed_seed: bool = False,
):
self.config = config
self.subjects = subjects
self.experiment = experiment
self.use_fixed_seed = use_fixed_seed
def _use_configured_seed(func):
"""Decorator for changing seed for a specific function"""
def wrapper(self, *args, **kwargs):
if not self.use_fixed_seed:
return func(self, *args, **kwargs)
prev_seed = int(os.environ.get("PL_GLOBAL_SEED", 0))
new_seed = int(self.config.get("splitter_seed", 0))
print(
f"Changing seed from {prev_seed} to {new_seed} for splitting"
)
seed_everything(new_seed)
out = func(self, *args, **kwargs)
print(f"Changing back seed from {new_seed} to {prev_seed}.")
seed_everything(prev_seed)
return out
return wrapper
@_use_configured_seed
def set_splits_for_subject(
self,
subject: str,
metadata: Metadata,
split_method: str = "shuffle"
) -> Metadata:
"""Set train/validation/test split
Every `split_together_length_s` will be splitted into one of the train/val/test
NOTE: This function assumes the segments are in order and consecutive in metadata if you want
to use split together multiple consecutive segments
"""
# Set default if necessary.
if split_method not in _SUPPORTED_SPLITS:
print(f"[Warning] Setting split_method={split_method} to 'shuffle'")
split_method = "shuffle"
# Ensure the split together length is at least as long as the segments.
# Setting allows to split time series based on intervals > neural segment length.
split_together_length_s = max(
self.config.get("split_together_length_s", self.config.segment_length_s),
self.config.segment_length_s
)
subject_rows_indices = metadata.get_indices_matching_cols_values(
["subject", "experiment"], [subject, self.experiment]
)
if split_method == "chronological":
return self._set_splits_across_time(
metadata, subject_rows_indices=subject_rows_indices
)
split_together_count = int(
split_together_length_s // self.config.segment_length_s
)
consecutive = (torch.diff(torch.tensor(subject_rows_indices)) == 1).all()
if split_together_count > 1:
assert (
consecutive
), "subject rows are not consecutive, can't do splitting together"
n_segments = len(subject_rows_indices)
if n_segments == 0:
print(
f"[WARNING] No rows found for the subject {subject} and experiment {self.experiment} in metadata"
)
return metadata
starting_ind = subject_rows_indices[0]
if consecutive:
groups = list(
range(
starting_ind,
starting_ind + n_segments - split_together_count + 1,
split_together_count,
)
)
else:
# we've asserted that split_together_count is 1 in this case
groups = copy.deepcopy(subject_rows_indices)
np.random.shuffle(groups)
val_size = max(int(self.config.val_ratio * len(groups)), 1)
test_size = max(int(self.config.test_ratio * len(groups)), 1)
val_indices = []
for group_starting_idx in groups[:val_size]:
group_elem_indices = np.arange(split_together_count) + group_starting_idx
val_indices.extend(group_elem_indices)
test_indices = []
for group_starting_idx in groups[val_size : val_size + test_size]:
group_elem_indices = np.arange(split_together_count) + group_starting_idx
test_indices.extend(group_elem_indices)
metadata.set_col_to_value(subject_rows_indices, "split", "train")
metadata.set_col_to_value(val_indices, "split", "val")
metadata.set_col_to_value(test_indices, "split", "test")
return metadata
@_use_configured_seed
def resplit_for_subject(
self,
subject_session: str,
metadata: Metadata,
split_method: str,
) -> Metadata:
if split_method == "chronological":
return self._set_splits_across_time(
metadata, subject_session=subject_session
)
else:
print("[WARNING] Resplitting only for chronological; splits unchanged")
return metadata
def __check_contiguous(self, subject_rows_indices, check_monotonic_only=False):
if check_monotonic_only:
assert (
torch.diff(torch.tensor(subject_rows_indices)) >= 1
).all(), "subject rows are not consecutive, can't do splitting together"
else: # we need to be exactly increments of one.
assert (
torch.diff(torch.tensor(subject_rows_indices)) == 1
).all(), "subject rows are not consecutive, can't do splitting together"
@_use_configured_seed
def _set_splits_across_time(
self,
metadata: Metadata,
subject_rows_indices: list = [],
subject_session: str = "",
return_splitted_indices: bool = False,
check_monotonic_only: bool = False,
verbose: bool = False,
) -> Metadata:
if not subject_rows_indices and not subject_session:
raise ValueError(
"Need to either pass complete subject session name or subject_row_indices"
)
if (
not subject_rows_indices
): # Prioritize using the subject_row_indices if given.
subject_rows_indices = metadata.get_indices_matching_cols_values(
["subject_session", "experiment"], [subject_session, self.experiment]
)
self.__check_contiguous(
subject_rows_indices, check_monotonic_only=check_monotonic_only
)
n_segments = len(subject_rows_indices)
assert len(self.config.run_ratios) == len(self.config.run_splits)
counts = (np.array(self.config.run_ratios) * n_segments).astype(int)
counts[-1] = n_segments - sum(counts[:-1])
if verbose:
print(f"subject_session: {subject_session}")
print(f"RATIOS: {self.config.run_ratios}")
print(f"self.config.run_splits: {self.config.run_splits}")
print(f"COUNTS: {counts}")
if return_splitted_indices:
splitted_indices = []
sum_now = 0
for c, split in zip(counts, self.config.run_splits):
label_split_indices = subject_rows_indices[sum_now : sum_now + c]
if return_splitted_indices:
splitted_indices.append(label_split_indices)
sum_now += c
metadata.set_col_to_value(label_split_indices, "split", split)
self._check_split_labels(metadata, subject_session)
if return_splitted_indices:
return metadata, splitted_indices
return metadata
def _check_split_labels(self, metadata, subject_session):
# Check that both labels available in each split.
# NOTE: Not using asserts because the initial default splits might not have
# both, but the ones computed offline will and provided through the .pkl file
# will satisfy requirement.
for split in np.unique(self.config.run_splits):
for i in range(2): # magic 2 = positive/negative labels
if (
len(
metadata.get_indices_matching_cols_values(
["subject_session", "experiment", "label", "split"],
[subject_session, self.experiment, i, split],
)
)
== 0
):
print(f"split {split} missing label {i}")
|