BaRISTA / barista /data /splitter.py
savaw's picture
Upload folder using huggingface_hub
a35137b verified
import copy
import functools
import os
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from barista.data.metadata import Metadata
from barista.models.utils import seed_everything
# Split strategies accepted by Splitter.set_splits_for_subject. Any other
# value makes that method print a warning and fall back to "shuffle".
_SUPPORTED_SPLITS = ["shuffle", "chronological"]
class Splitter:
    """Assign train/val/test split labels to segment rows of a metadata table.

    Two strategies are implemented:

    * ``"shuffle"``: groups of consecutive segments are shuffled and the first
      groups become ``"val"``, the next ones ``"test"`` and the rest
      ``"train"``.
    * ``"chronological"``: the rows are cut, in order, into contiguous blocks
      sized by ``config.run_ratios`` and labelled with ``config.run_splits``.

    Args:
        config: Splitting configuration. It is read both with ``.get(...)``
            and with attribute access (e.g. ``config.segment_length_s``), so
            it must support both -- presumably an OmegaConf-style config;
            confirm against the caller. Keys read by this class:
            ``segment_length_s``, ``split_together_length_s`` (optional),
            ``val_ratio``, ``test_ratio``, ``run_ratios``, ``run_splits`` and
            ``splitter_seed`` (optional, defaults to 0).
        subjects: Stored on the instance; not used by the methods of this
            class.
        experiment: Experiment name used to select rows in the metadata.
        use_fixed_seed: When true, the decorated methods run under
            ``config["splitter_seed"]`` instead of the current global seed.
    """

    def __init__(
        self,
        config: Dict,
        subjects: List,
        experiment: str,
        use_fixed_seed: bool = False,
    ):
        self.config = config
        self.subjects = subjects
        self.experiment = experiment
        self.use_fixed_seed = use_fixed_seed

    def _use_configured_seed(func):
        """Decorate a method so that it runs under the configured splitter seed.

        When ``self.use_fixed_seed`` is false the method is called unchanged.
        Otherwise ``seed_everything(config["splitter_seed"])`` is called
        before the method and ``seed_everything(<previous seed>)`` afterwards,
        where the previous seed is read from the ``PL_GLOBAL_SEED``
        environment variable (0 when unset).

        NOTE(review): "changing back" re-seeds with the previous seed value;
        whether that resumes or restarts the RNG streams depends on
        ``seed_everything``, which is not defined in this file.
        """

        @functools.wraps(func)  # keep the wrapped method's name and docstring
        def wrapper(self, *args, **kwargs):
            if not self.use_fixed_seed:
                return func(self, *args, **kwargs)

            prev_seed = int(os.environ.get("PL_GLOBAL_SEED", 0))
            new_seed = int(self.config.get("splitter_seed", 0))
            print(f"Changing seed from {prev_seed} to {new_seed} for splitting")
            seed_everything(new_seed)
            try:
                return func(self, *args, **kwargs)
            finally:
                # Put the previous seed back even if the split raised, so a
                # failure here does not leave the splitter seed active.
                print(f"Changing back seed from {new_seed} to {prev_seed}.")
                seed_everything(prev_seed)

        return wrapper

    @_use_configured_seed
    def set_splits_for_subject(
        self,
        subject: str,
        metadata: "Metadata",
        split_method: str = "shuffle",
    ) -> "Metadata":
        """Label every row of ``subject`` (for this experiment) with a split.

        With ``split_method="shuffle"``, runs of
        ``split_together_length_s // segment_length_s`` consecutive rows are
        kept in the same split. Rows that do not fill a complete group at the
        end stay in ``"train"``.

        NOTE: grouping more than one segment together requires the subject's
        rows to be consecutive (index step of exactly 1) in ``metadata``.

        Args:
            subject: Value matched against the ``"subject"`` column.
            metadata: Table to label; it is modified in place.
            split_method: ``"shuffle"`` or ``"chronological"``. Unsupported
                values fall back to ``"shuffle"`` with a printed warning.

        Returns:
            The same ``metadata`` object, with its ``"split"`` column updated.

        Raises:
            AssertionError: If grouping is requested but the subject's rows
                are not consecutive.
        """
        # Set default if necessary.
        if split_method not in _SUPPORTED_SPLITS:
            print(f"[Warning] Setting split_method={split_method} to 'shuffle'")
            split_method = "shuffle"

        # The grouping interval can never be shorter than one segment.
        segment_length_s = self.config.segment_length_s
        split_together_length_s = max(
            self.config.get("split_together_length_s", segment_length_s),
            segment_length_s,
        )

        subject_rows_indices = metadata.get_indices_matching_cols_values(
            ["subject", "experiment"], [subject, self.experiment]
        )

        # Bail out before dispatching: the chronological helper would
        # otherwise reject the empty index list with an unrelated ValueError.
        n_segments = len(subject_rows_indices)
        if n_segments == 0:
            print(
                f"[WARNING] No rows found for the subject {subject} and experiment {self.experiment} in metadata"
            )
            return metadata

        if split_method == "chronological":
            return self._set_splits_across_time(
                metadata, subject_rows_indices=subject_rows_indices
            )

        split_together_count = int(split_together_length_s // segment_length_s)

        consecutive = bool(
            (np.diff(np.asarray(subject_rows_indices)) == 1).all()
        )
        if split_together_count > 1:
            assert (
                consecutive
            ), "subject rows are not consecutive, can't do splitting together"

        if consecutive:
            # One entry per complete group: the index of its first row.
            first_row = subject_rows_indices[0]
            groups = list(
                range(
                    first_row,
                    first_row + n_segments - split_together_count + 1,
                    split_together_count,
                )
            )
        else:
            # Only reachable with split_together_count == 1 (asserted above),
            # so every row is its own group. Copy so the caller's list is not
            # shuffled.
            groups = list(subject_rows_indices)

        np.random.shuffle(groups)

        # At least one group is always requested for val and for test, even
        # when the configured ratio rounds down to zero.
        val_size = max(int(self.config.val_ratio * len(groups)), 1)
        test_size = max(int(self.config.test_ratio * len(groups)), 1)

        offsets = np.arange(split_together_count)
        val_indices = [
            row for start in groups[:val_size] for row in offsets + start
        ]
        test_indices = [
            row
            for start in groups[val_size : val_size + test_size]
            for row in offsets + start
        ]

        # Everything starts as train; val/test rows are then overwritten.
        metadata.set_col_to_value(subject_rows_indices, "split", "train")
        metadata.set_col_to_value(val_indices, "split", "val")
        metadata.set_col_to_value(test_indices, "split", "test")
        return metadata

    @_use_configured_seed
    def resplit_for_subject(
        self,
        subject_session: str,
        metadata: "Metadata",
        split_method: str,
    ) -> "Metadata":
        """Recompute the splits of one subject session.

        Only ``split_method="chronological"`` changes anything; any other
        value prints a warning and returns ``metadata`` untouched.

        Args:
            subject_session: Value matched against the ``"subject_session"``
                column.
            metadata: Table to relabel; it is modified in place.
            split_method: Requested strategy.

        Returns:
            The same ``metadata`` object.
        """
        if split_method != "chronological":
            print("[WARNING] Resplitting only for chronological; splits unchanged")
            return metadata

        return self._set_splits_across_time(
            metadata, subject_session=subject_session
        )

    def __check_contiguous(self, subject_rows_indices, check_monotonic_only=False):
        """Assert that the row indices are ordered as chronological splitting needs.

        Args:
            subject_rows_indices: Row indices to validate.
            check_monotonic_only: When true only require strictly increasing
                indices; otherwise require steps of exactly one.

        Raises:
            AssertionError: If the indices violate the requested ordering.
        """
        steps = np.diff(np.asarray(subject_rows_indices))
        if check_monotonic_only:
            assert (
                steps >= 1
            ).all(), "subject rows are not strictly increasing, can't do splitting together"
        else:  # we need to be exactly increments of one.
            assert (
                steps == 1
            ).all(), "subject rows are not consecutive, can't do splitting together"

    @_use_configured_seed
    def _set_splits_across_time(
        self,
        metadata: "Metadata",
        subject_rows_indices: Optional[List] = None,
        subject_session: str = "",
        return_splitted_indices: bool = False,
        check_monotonic_only: bool = False,
        verbose: bool = False,
    ) -> Union["Metadata", Tuple["Metadata", List]]:
        """Label rows with contiguous, in-order blocks (chronological split).

        Block ``k`` receives ``int(run_ratios[k] * n_rows)`` rows and the
        label ``run_splits[k]``; the last block absorbs whatever remains after
        truncation.

        Args:
            metadata: Table to label; it is modified in place.
            subject_rows_indices: Rows to split. Takes priority over
                ``subject_session`` when non-empty.
            subject_session: Used to look the rows up (together with
                ``self.experiment``) when no row indices are given.
            return_splitted_indices: Also return the per-block index lists.
            check_monotonic_only: Relax the ordering check from "steps of
                exactly one" to "strictly increasing".
            verbose: Print the ratios, split names and resulting counts.

        Returns:
            ``metadata``, or ``(metadata, splitted_indices)`` when
            ``return_splitted_indices`` is true.

        Raises:
            ValueError: If neither row indices nor a subject session is given.
            AssertionError: If the rows fail the ordering check, or
                ``run_ratios`` and ``run_splits`` differ in length.
        """
        # Length-based emptiness test so array-like indices work as well.
        has_rows = (
            subject_rows_indices is not None and len(subject_rows_indices) > 0
        )
        if not has_rows and not subject_session:
            raise ValueError(
                "Need to either pass complete subject session name or subject_row_indices"
            )

        if not has_rows:  # Prioritize using the subject_row_indices if given.
            subject_rows_indices = metadata.get_indices_matching_cols_values(
                ["subject_session", "experiment"],
                [subject_session, self.experiment],
            )

        self.__check_contiguous(
            subject_rows_indices, check_monotonic_only=check_monotonic_only
        )

        n_segments = len(subject_rows_indices)
        assert len(self.config.run_ratios) == len(self.config.run_splits)

        counts = (np.array(self.config.run_ratios) * n_segments).astype(int)
        # Truncation can leave rows unassigned; give them to the last block.
        counts[-1] = n_segments - sum(counts[:-1])

        if verbose:
            print(f"subject_session: {subject_session}")
            print(f"RATIOS: {self.config.run_ratios}")
            print(f"self.config.run_splits: {self.config.run_splits}")
            print(f"COUNTS: {counts}")

        splitted_indices = []
        offset = 0
        for count, split in zip(counts, self.config.run_splits):
            block_indices = subject_rows_indices[offset : offset + count]
            splitted_indices.append(block_indices)
            metadata.set_col_to_value(block_indices, "split", split)
            offset += count

        self._check_split_labels(metadata, subject_session, subject_rows_indices)

        if return_splitted_indices:
            return metadata, splitted_indices
        return metadata

    def _check_split_labels(self, metadata, subject_session, subject_rows_indices=None):
        """Print a message for every split that lacks one of the two labels.

        NOTE: Not using asserts because the initial default splits might not
        have both labels, but the ones computed offline and provided through
        the .pkl file will satisfy the requirement.

        Args:
            metadata: Table whose ``"label"`` and ``"split"`` columns are
                inspected.
            subject_session: Session whose rows are checked. When empty and
                ``subject_rows_indices`` is given, the check is restricted to
                those rows instead (an empty session name matches no row and
                would report every label as missing).
            subject_rows_indices: Optional rows to restrict the check to when
                no session name is available.
        """
        restrict_to_rows = not subject_session and subject_rows_indices is not None
        if restrict_to_rows:
            allowed_rows = set(subject_rows_indices)
            columns = ["experiment", "label", "split"]
            fixed_values = [self.experiment]
        else:
            columns = ["subject_session", "experiment", "label", "split"]
            fixed_values = [subject_session, self.experiment]

        for split in np.unique(self.config.run_splits):
            for label in range(2):  # magic 2 = positive/negative labels
                matches = metadata.get_indices_matching_cols_values(
                    columns, fixed_values + [label, split]
                )
                if restrict_to_rows:
                    matches = [row for row in matches if row in allowed_rows]
                if len(matches) == 0:
                    print(f"split {split} missing label {label}")