| """Code to handle preprocessing, segmenting and labeling the BrainTreebank dataset. |
| |
| Preprocessing and segmentation functionality is based on the implementations found in the |
| following repositories, but has been modified as needed to be used for the evaluation scheme |
| outlined in the BaRISTA paper: |
| https://github.com/czlwang/BrainBERT/tree/master/data |
| https://github.com/czlwang/PopulationTransformer/tree/main/data |
| https://github.com/czlwang/brain_treebank_code_release/tree/master/data |
| """ |
| import dataclasses |
| import einops |
| import hashlib |
| import numpy as np |
| from omegaconf import DictConfig, OmegaConf |
| import os |
| import pandas as pd |
| import pickle |
| import torch |
| from typing import Dict, List, Optional, Tuple, Union |
|
|
| from barista.data.available_sessions import BrainTreebankAvailableSessions |
| from barista.data.braintreebank_data_helpers import ( |
| BrainTreebankDatasetNames, |
| BrainTreebankDatasetPathManager, |
| BrainTreebankDatasetPreprocessor, |
| BrainTreebankDatasetRawDataHelper, |
| ) |
| from barista.data.braintreebank_dataset_spatial_groupings import ( |
| BrainTreebankSpatialGroupingsHelper, |
| ) |
| from barista.data.metadata import Metadata, MetadataRow, MetadataSpatialGroupRow |
| from barista.data.splitter import Splitter |
| from barista.data.fileprogresstracker import FileProgressTracker |
|
|
# Native sampling frequency (Hz) of the BrainTreebank recordings; used as the
# fallback when the config does not specify `samp_frequency`.
_DEFAULT_FS = 2048




# Torch version with any local build tag (e.g. "+cu118") stripped.
# NOTE(review): this value is not referenced anywhere in the visible code --
# confirm it is used elsewhere before removing.
torch_version = torch.__version__.split("+")[0]
|
|
|
|
| class BrainTreebankWrapper: |
    def __init__(self, config: Union[DictConfig, OmegaConf], only_segment_generation=False):
        """Run the full two-stage dataset preparation pipeline on construction.

        Stage 1 filters/rereferences the raw recordings (skipped when outputs
        already exist on disk unless `force_reprocess_stage1`); stage 2 cuts
        them into labeled segments and builds/updates the metadata file.

        Args:
            config: experiment/dataset configuration (OmegaConf / DictConfig).
            only_segment_generation: when True, segments are generated but the
                run-specific metadata filtering step is skipped.
        """
        self.config = config

        self._setup_helpers()

        self.spatial_groups_helper = BrainTreebankSpatialGroupingsHelper(
            self.config, dataset_name=self.name
        )

        # Human-readable processing description and its hash; the hash keys the
        # metadata file so different segmentations do not collide.
        self.segments_processing_str, self.segments_processing_hash_str = (
            self._get_segments_processing_hash(
                segment_length_s=self.config.segment_length_s,
            )
        )

        # Stage 1: raw data preprocessing.
        if not self._is_raw_data_processed() or self.config.force_reprocess_stage1:
            print(
                "Processed raw dataset does not exist or reprocessing is enabled, processing starts."
            )
            self._process_raw_data()
            print(f"Raw data processing complete: {self._processed_raw_data_dir}")
        else:
            print("Processed raw data exists")

        os.makedirs(self._processed_segments_data_dir, exist_ok=True)

        self.metadata = self._load_metadata()

        # NOTE(review): this immediately overwrites the metadata assigned just
        # above -- confirm `_initialize_metadata` consumes the loaded value;
        # otherwise the `_load_metadata` call above is a dead store
        # (`process_segments` re-loads it anyway).
        self.metadata = self._initialize_metadata()

        # Stage 2: segmentation + labeling.
        self.process_segments(only_segment_generation)
        print(f"Segments are processed and ready to use. Metadata path: {self.metadata_path}")
|
|
    @property
    def name(self) -> str:
        """Canonical dataset name used in metadata rows and log messages."""
        return "BrainTreebank"
|
|
| @property |
| def available_sessions(self) -> Dict[str, List]: |
| return { |
| k.name: k.value |
| for k in BrainTreebankAvailableSessions |
| if not self.config.subjects_to_process |
| or k.name in self.config.subjects_to_process |
| } |
|
|
    @property
    def experiment(self):
        """Name of the experiment/task this run processes (from the config)."""
        return self.config.experiment
| |
| @property |
| def metadata_path(self): |
| return os.path.join( |
| self.config.save_dir, |
| self.experiment, |
| f"metadata_{self.segments_processing_hash_str}.csv", |
| ) |
| |
    def _setup_helpers(self):
        """Instantiate path/raw-data/preprocessing helpers and the splitter.

        Must run before anything reads `self.samp_frequency`, `self.splitter`
        or `self.experiment_dataset_name`; `__init__` calls it first.
        """
        self.path_manager = BrainTreebankDatasetPathManager(
            dataset_dir=self.config.dataset_dir,
        )
        self.raw_data_helper = BrainTreebankDatasetRawDataHelper(self.path_manager)
        self.raw_data_preprocessor = BrainTreebankDatasetPreprocessor(self.config)
        # Maps the configured experiment string onto its dataset-name enum.
        self.experiment_dataset_name = BrainTreebankDatasetNames.get_modes(
            self.config.experiment
        )

        # Sampling frequency falls back to the dataset's native rate (2048 Hz).
        self.samp_frequency = self.config.get("samp_frequency", _DEFAULT_FS)
        self.splitter = Splitter(
            config=self.config,
            subjects=list(self.available_sessions.keys()),
            experiment=self.experiment,
            use_fixed_seed=self.config.use_fixed_seed_for_splitter,
        )
| |
    def _process_raw_data(self):
        """Run stage-1 raw preprocessing for every configured subject/session.

        Sessions whose processed file already exists are skipped here;
        `_process_single_session_raw_data` performs the actual per-session work.
        """
        os.makedirs(self._processed_raw_data_dir, exist_ok=True)

        for subject in self.available_sessions.keys():
            print(f"Raw data processing for subject {subject} starts.")

            sessions_count = len(self.available_sessions[subject])
            for i, session in enumerate(self.available_sessions[subject]):
                processed_file_path = self._get_processed_raw_data_file_path(
                    subject=subject, session=session
                )
                if os.path.exists(processed_file_path):
                    print(
                        f"Skipping session {session} ({i+1}/{sessions_count}), "
                        f"processed raw data exists in {processed_file_path}."
                    )
                else:
                    print(
                        f"Processing session {session} ({i+1}/{sessions_count})..."
                    )

                    self._process_single_session_raw_data(
                        subject=subject, session=session
                    )
|
|
| def _process_single_session_raw_data(self, subject: str, session: str): |
| save_path = self._get_processed_raw_data_file_path( |
| subject=subject, session=session |
| ) |
| cache_dir, cache_path = self._get_processed_raw_data_file_path_cache( |
| subject=subject, session=session |
| ) |
|
|
| if not self.config.force_reprocess_stage1: |
| if os.path.isfile(save_path): |
| print(f"Skipping raw processing for {subject} {session}") |
| return |
|
|
| if os.path.isfile(cache_path): |
| print( |
| f"Making symlink for raw processed file for {subject} {session}" |
| ) |
| os.symlink(src=cache_path, dst=save_path) |
| return |
|
|
| raw_data_dict = self.raw_data_helper.get_raw_file(subject, session) |
| electrodes = raw_data_dict["electrode_info"] |
|
|
| |
| selected_electrodes = self.raw_data_helper.get_clean_elecs(subject) |
| assert len(set(selected_electrodes).intersection(set(electrodes))) == len( |
| selected_electrodes |
| ) |
|
|
| selected_elecs_inds = [ |
| i for i, e in enumerate(electrodes) if e in selected_electrodes |
| ] |
| electrode_data = raw_data_dict["data"][:, np.array(selected_elecs_inds)] |
| electrode_data = ( |
| electrode_data.T |
| ) |
|
|
| |
| if self.samp_frequency != _DEFAULT_FS: |
| raise NotImplementedError( |
| f"Resampling {self.name} dataset not yet supported." |
| ) |
|
|
| |
| electrode_data = self.raw_data_preprocessor.filter_data(electrode_data) |
|
|
| |
| electrode_data = self.raw_data_preprocessor.rereference_data( |
| selected_data=electrode_data, |
| selected_electrodes=selected_electrodes, |
| all_data=raw_data_dict["data"].T, |
| all_electrodes=raw_data_dict["electrode_info"], |
| ) |
|
|
| save_dict = dict( |
| data=torch.tensor(electrode_data.T), |
| time=torch.tensor(raw_data_dict["time"]), |
| samp_frequency=self.samp_frequency, |
| electrode_info=selected_electrodes, |
| ) |
|
|
| try: |
| os.makedirs(cache_dir, exist_ok=True) |
| torch.save(save_dict, cache_path) |
| print(f"Raw processed file created in {cache_path}") |
| os.symlink(src=cache_path, dst=save_path) |
| print(f"Raw processed file symlink created in {save_path}") |
| except (OSError, PermissionError, FileNotFoundError): |
| torch.save(save_dict, save_path) |
| print(f"Raw processed file created in {save_path}") |
| |
| def _is_raw_data_processed(self): |
| if not os.path.exists(self._processed_raw_data_dir): |
| return False |
|
|
| files_exist = [] |
| for subject in self.available_sessions.keys(): |
| for session in self.available_sessions[subject]: |
| path = self._get_processed_raw_data_file_path( |
| subject=subject, session=session |
| ) |
| files_exist.append(os.path.exists(path)) |
| return np.array(files_exist).all() |
|
|
| def _get_file_progress_tracker_save_path(self, subject: str, session: str) -> str: |
| filename = f"{subject}_{session}_processing_status.json" |
| return os.path.join(self._processed_segments_data_dir, filename) |
|
|
| def _get_channels_region_info( |
| self, |
| subject: str, |
| electrode_info: List[str], |
| ) -> List[Tuple]: |
| """ |
| Generate a list of Channels each including region information of the channel. |
| """ |
| channels, coords, channel_inds_to_remove = [], [], [] |
| for channel_ind, channel_name in enumerate(electrode_info): |
| localization_info = self.raw_data_helper.get_channel_localization( |
| subject, channel_name |
| ) |
| if not localization_info: |
| raise ValueError( |
| f"Couldn't found elec {channel_name} for subject {subject}" |
| ) |
|
|
| assert ( |
| "coords" in localization_info |
| ), "localization_info incomplete, missing coords" |
| coord = localization_info.pop("coords") |
|
|
| |
| if self.config.region_filtering.active: |
| match = False |
| for filtered_region in self.config.region_filtering.filters: |
| component_info = localization_info['region_info'] |
| match = filtered_region.lower() in component_info.lower() |
| if match: |
| break |
|
|
| if match: |
| channel_inds_to_remove.append(channel_ind) |
| continue |
|
|
| coords.append((coord[0], coord[1], coord[2])) |
| channels.append(( |
| localization_info['hemi'], |
| localization_info['region_info'], |
| localization_info['channel_stem'], |
| )) |
|
|
| return channels, coords, channel_inds_to_remove |
|
|
    def _create_spatial_groupings(
        self, subject: str, session: str, coords: List[Tuple]
    ):
        """Compute spatial groupings of channels for one subject/session and
        register them in the metadata.

        Args:
            subject: subject name.
            session: session name.
            coords: (x, y, z) coordinates of the kept channels, in data order.
        """
        localization = self.raw_data_helper.get_channel_localization_raw(subject)
        rows = self.spatial_groups_helper.get_spatial_groupings(
            subject,
            session,
            coords,
            localization,
        )
        for row in rows:
            self.metadata.add_spatial_group(row)
            print(f"Add spatial group {row.name} for {row.subject_session}")

        # Persist immediately so groupings survive an interrupted run.
        self.metadata.save(self.metadata_path)
|
|
| def _spatial_groupings_exist_for_subject(self, subject: str, session: str): |
| for spatial_grouping in self.config.spatial_groupings_to_create: |
| sg = self.metadata.get_spatial_grouping( |
| subject_session=f"{subject}_{session}", name=spatial_grouping |
| ) |
| if sg is None: |
| return False |
| return True |
|
|
    def _save_segment(
        self,
        subject: str,
        session: str,
        segment_data: torch.Tensor,
        segment_time: torch.Tensor,
        segment_labels: torch.Tensor,
        segment_id: int,
        segment_seq_len: int,
        file_progress_tracker: FileProgressTracker,
        is_last_segment: bool
    ) -> None:
        """Process and save one segment to file.

        Writes the segment tensor dict to disk, appends a metadata row, and
        periodically checkpoints metadata plus the progress tracker.

        Args:
            subject: subject name.
            session: session name.
            segment_data: (channels x time) tensor of this segment's data.
            segment_time: per-sample indices/timestamps for the segment.
            segment_labels: per-timepoint labels for the segment.
            segment_id: running index of the segment within this session.
            segment_seq_len: number of timepoints in the segment.
            file_progress_tracker: resume-state tracker, updated on checkpoint.
            is_last_segment: forces a final checkpoint when True.
        """

        # Rebinds `segment_data` from tensor to the on-disk dict; the labels
        # are stored under the experiment name as key.
        segment_data = {
            "x": segment_data.float().clone(),
            "timestamps": segment_time.clone(),
            self.experiment: segment_labels.clone(),
        }

        segment_label = self._get_segment_label(segment_labels)
        segment_filename = f"{subject}_{session}_{segment_id}.pt"
        segment_path = os.path.join(self._processed_segments_data_dir, segment_filename)
        torch.save(segment_data, segment_path)

        meta_row = MetadataRow(
            dataset=self.name,
            subject=subject,
            session=session,
            subject_session=f"{subject}_{session}",
            experiment=self.experiment,
            seq_len=segment_seq_len,
            d_input=np.prod(segment_data["x"].shape),
            d_data=segment_data["x"].shape,
            path=segment_path,
            # Placeholder split; the Splitter assigns real splits afterwards.
            split="train",
            filename=segment_filename,
            processing_str=self.segments_processing_str,
            label=segment_label,
        )

        self.metadata.concat(pd.DataFrame([meta_row]))

        # Checkpoint metadata + progress periodically and on the final segment.
        if segment_id % self.config.processing_save_interval == 0 or is_last_segment:
            self.metadata.save(self.metadata_path)
            file_progress_tracker.update_last_file_ind(
                file_ind=-1, ending_ind=-1, segment_id=segment_id
            )
|
|
| def _create_segments_for_subject_session( |
| self, |
| subject: str, |
| session: str, |
| segment_length_s: int, |
| file_progress_tracker: FileProgressTracker, |
| ) -> int: |
| """ |
| Args: |
| subject: str. Subject name. |
| session: str. Session name. |
| segment_length_s: desired segment length in seconds |
| file_progress_tracker: tracker of last segment info that is processed |
| |
| Returns: |
| Number of newly added segments. |
| """ |
| processed_raw_data_path = self._get_processed_raw_data_file_path( |
| subject=subject, session=session |
| ) |
| preprocessed_data_dict = torch.load(processed_raw_data_path, weights_only=False) |
|
|
| data = preprocessed_data_dict["data"].T |
|
|
| electrode_names = preprocessed_data_dict["electrode_info"] |
| channels, coords, channel_inds_to_remove = self._get_channels_region_info( |
| subject, electrode_names |
| ) |
| assert len(electrode_names) - len(channel_inds_to_remove) == len(channels) |
|
|
| if channel_inds_to_remove: |
| print( |
| f"Dropping {len(channel_inds_to_remove)} channels out of {len(electrode_names)} because missing." |
| ) |
| channels_to_keep = np.delete( |
| np.arange(data.shape[0]), channel_inds_to_remove |
| ) |
| data = data[channels_to_keep, ...] |
| electrode_names = [ |
| electrode_names[i] |
| for i in range(len(electrode_names)) |
| if i not in channel_inds_to_remove |
| ] |
|
|
| assert data.shape[0] == len(channels) |
|
|
| self._create_spatial_groupings(subject, session, coords) |
|
|
| if ( |
| file_progress_tracker.is_completed() |
| and not self.config.force_reprocess_stage2 |
| ): |
| return 0 |
|
|
| |
| n_steps_in_one_segment = int(self.samp_frequency * segment_length_s) |
| data, labels, data_sample_indices = self._get_experiment_data_and_labels( |
| subject, |
| session, |
| data, |
| n_steps_in_one_segment, |
| time=preprocessed_data_dict["time"], |
| samp_frequency=preprocessed_data_dict["samp_frequency"], |
| electrode_info=preprocessed_data_dict["electrode_info"], |
| ) |
|
|
| |
| _, _, last_segment_id = file_progress_tracker.get_last_file_ind() |
|
|
| print( |
| f"{last_segment_id+1} segment(s) already processed for subject {subject} session {session}." |
| ) |
|
|
| for segment_ind in range(last_segment_id + 1, data.shape[0]): |
| segment_data = data[segment_ind, ...] |
| segment_label = labels[segment_ind, ...] |
|
|
| |
| segment_data = torch.tensor( |
| self.raw_data_preprocessor.zscore_data(segment_data) |
| ) |
| segment_data = segment_data.T |
|
|
| self._save_segment( |
| subject, |
| session=session, |
| segment_data=segment_data, |
| segment_time=data_sample_indices[segment_ind, ...], |
| segment_labels=segment_label, |
| segment_id=segment_ind, |
| segment_seq_len=n_steps_in_one_segment, |
| file_progress_tracker=file_progress_tracker, |
| is_last_segment=(segment_ind == data.shape[0] - 1), |
| ) |
|
|
| return data.shape[0] - last_segment_id |
|
|
| def _generate_segmented_data( |
| self, |
| data: torch.Tensor, |
| n_steps_in_one_segment: int, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Segment data of shape (channels x time_samples) to (number_of_segments x channels x n_steps_in_one_segment). |
| It will truncate extra samples. |
| |
| Returns segmented data and also indices corresponding to original data tensor. |
| """ |
| |
| cutoff_len = int(data.shape[-1] - data.shape[-1] % n_steps_in_one_segment) |
| data = data[..., :cutoff_len] |
| data_sample_indices = torch.arange(data.shape[-1]) |
| data = einops.rearrange(data, "c (ns sl) -> ns c sl", sl=n_steps_in_one_segment) |
| data_sample_indices = data_sample_indices.reshape( |
| [-1, n_steps_in_one_segment] |
| ) |
|
|
| return data, data_sample_indices |
|
|
| def _get_experiment_data_and_labels( |
| self, |
| subject: str, |
| session: str, |
| raw_data: torch.Tensor, |
| n_steps_in_one_segment: int, |
| **kwargs, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| """ |
| Generate data and labels pairs. The data is reshaped to segments, which is done either by chunking |
| or by word-based segmenting based on the given experiment. |
| |
| Args: |
| subject: str. Current data's subject name. |
| session: str. Current data's session name. |
| raw_data: a tensor of shape (n_channels x n_total_samples) |
| n_steps_in_one_segment: int. Number of samples we want in one segment. |
| |
| Output: |
| data: a tensor of shape (n_segments x n_channels x n_steps_in_one_segment) |
| labels: a tensor of shape (n_segments x n_steps_in_one_segment) |
| data_sample_indices: a tensor of shape (n_segments x n_steps_in_one_segment) |
| containing indices of samples of the raw data each item in data corresponds to |
| """ |
| if self.experiment_dataset_name == self._pretrain_enum: |
| data, data_sample_indices = self._generate_segmented_data( |
| raw_data, n_steps_in_one_segment |
| ) |
| labels = torch.tensor(np.ones_like(data_sample_indices) * np.nan) |
| return data, labels, data_sample_indices |
|
|
| |
| raw_labels, label_intervals = self.raw_data_helper.get_features( |
| subject, session, self.experiment, raw_data.shape[-1] |
| ) |
|
|
| if ( |
| self.experiment_dataset_name == BrainTreebankDatasetNames.SENTENCE_ONSET |
| or self.experiment_dataset_name |
| == BrainTreebankDatasetNames.SPEECH_VS_NONSPEECH |
| or self.experiment_dataset_name |
| == BrainTreebankDatasetNames.SENTENCE_ONSET_TIME |
| or self.experiment_dataset_name |
| == BrainTreebankDatasetNames.SPEECH_VS_NONSPEECH_TIME |
| ): |
| data, labels, data_sample_indices = ( |
| self._generate_data_and_labels_by_speech( |
| raw_data, n_steps_in_one_segment, raw_labels |
| ) |
| ) |
|
|
| elif ( |
| self.experiment_dataset_name == BrainTreebankDatasetNames.VOLUME |
| or self.experiment_dataset_name == BrainTreebankDatasetNames.OPTICAL_FLOW |
| ): |
| |
| label_switchpoints = np.array( |
| [elem[0] for elem in label_intervals], dtype=int |
| ) |
| data, data_sample_indices, _ = self._generate_word_aligned_segments( |
| raw_data, n_steps_in_one_segment, label_switchpoints |
| ) |
| |
| |
|
|
| start = ( |
| int(data.shape[-1] / 2) |
| if self.config.trial_alignment == "center" |
| else 0 |
| ) |
| valid_label_switchpoints = data_sample_indices[start :: data.shape[-1]] |
|
|
| labels = raw_labels[valid_label_switchpoints] |
| labels = einops.repeat(labels, "n -> n l", l=data.shape[-1]) |
|
|
| if self.config.quantile_numerical_labels.active: |
| labels = self._generate_quartile_labels(labels) |
|
|
| data_sample_indices = data_sample_indices.reshape( |
| (data.shape[0], data.shape[-1]) |
| ) |
| labels = torch.from_numpy(labels) |
|
|
| return data, labels, data_sample_indices |
|
|
| def _generate_data_and_labels_by_segments( |
| self, |
| raw_data: torch.Tensor, |
| n_steps_in_one_segment: int, |
| raw_labels: np.ndarray, |
| ): |
| """ |
| Generate data and labels pairs by chunking the full session |
| |
| Args: |
| raw_data: a tensor of shape (N_channels x N_total_samples) |
| n_steps_in_one_segment: number of samples we want in one segment |
| raw_labels: a numpy array of length N_total_samples containing labels |
| corresponding to each sample |
| |
| Output: |
| data: a tensor of shape (N_segments x N_channels x n_steps_in_one_segment) |
| labels: a tensor of shape (N_segments x n_steps_in_one_segment) |
| data_sample_indices: a tensor of shape (N_segments x n_steps_in_one_segment) |
| containing indices of samples of the raw data each item in data corresponds to |
| """ |
| data, data_sample_indices = self._generate_segmented_data( |
| raw_data, n_steps_in_one_segment |
| ) |
|
|
| |
| cutoff_len = data.shape[0] * data.shape[-1] |
|
|
| labels = raw_labels[..., :cutoff_len] |
| labels = einops.rearrange(labels, "(ns sl) -> ns sl", sl=n_steps_in_one_segment) |
|
|
| assert labels.shape[0] == data.shape[0] |
|
|
| if self.config.quantile_numerical_labels.active: |
| labels = self._generate_quartile_labels(labels) |
|
|
| labels = torch.from_numpy(labels) |
| return data, labels, data_sample_indices |
|
|
| def _generate_quartile_labels(self, feature_values: np.ndarray) -> np.ndarray: |
| """ |
| Convert float labels based on quantile values: values in the top quantile will be assigned 1, |
| values in the bottom quantile will be assigned 0, and all others will be assigned NaN. |
| """ |
| valid_inds = ~np.isnan(feature_values) |
| lower_thresh, higher_thresh = np.quantile( |
| feature_values[valid_inds], |
| [ |
| self.config.quantile_numerical_labels.lower_threshold, |
| self.config.quantile_numerical_labels.higher_threshold, |
| ], |
| ) |
|
|
| valid_inds = np.logical_or( |
| feature_values <= lower_thresh, feature_values >= higher_thresh |
| ) |
| new_feature_values = feature_values.copy() |
| new_feature_values[~valid_inds] = np.nan |
| new_feature_values[feature_values <= lower_thresh] = 0 |
| new_feature_values[feature_values >= higher_thresh] = 1 |
|
|
| return new_feature_values |
|
|
    def _generate_word_aligned_segments(
        self,
        raw_data: torch.Tensor,
        n_steps_in_one_segment: int,
        label_switchpoints: np.ndarray,
    ):
        """Cut one segment of data around every label switchpoint.

        With "center" alignment each segment is centered on a switchpoint;
        windows that would run past either end of the recording are dropped.
        When `config.force_nonoverlap` is set, windows overlapping a previously
        kept window are skipped.

        Returns:
            word_aligned_samples: (n_kept x n_channels x n_steps) tensor.
            word_aligned_inds: 1-D tensor, concatenated sample indices of the
                kept windows.
            all_word_aligned_inds: 1-D tensor, concatenated sample indices of
                all valid windows (kept or skipped).

        Raises:
            NotImplementedError: for any alignment other than "center".
        """
        if self.config.trial_alignment == "center":
            half_window = int(n_steps_in_one_segment / 2)
            start_inds = label_switchpoints - half_window
            # Keep only windows that fit fully inside the recording.
            valid_start_inds = start_inds[
                np.logical_and(
                    start_inds >= 0,
                    start_inds + n_steps_in_one_segment < raw_data.shape[-1],
                )
            ]

            all_word_aligned_inds, word_aligned_inds, word_aligned_samples = (
                [],
                [],
                [],
            )

            for samp_ind, samp_start_ind in enumerate(valid_start_inds):
                inds_to_query = torch.arange(
                    samp_start_ind, samp_start_ind + n_steps_in_one_segment
                )
                all_word_aligned_inds.append(inds_to_query)

                # Overlap check against the last *kept* window's final sample.
                if (
                    self.config.force_nonoverlap
                    and samp_ind > 0
                    and samp_start_ind <= word_aligned_inds[-1][-1]
                ):
                    continue

                word_aligned_samples.append(raw_data[:, inds_to_query])
                word_aligned_inds.append(inds_to_query)

            print(
                f"Using only {len(word_aligned_inds)} out of {len(all_word_aligned_inds)} word-aligned segments."
            )
            # NOTE(review): torch.cat/stack raise on empty lists -- confirm at
            # least one valid switchpoint always exists for these sessions.
            all_word_aligned_inds = torch.cat(all_word_aligned_inds)
            word_aligned_inds = torch.cat(
                word_aligned_inds
            )
            word_aligned_samples = torch.stack(
                word_aligned_samples
            )

            if self.config.force_nonoverlap:
                assert len(torch.unique(word_aligned_inds)) == len(word_aligned_inds)

        else:
            raise NotImplementedError("Only center trial alignment supported.")

        return word_aligned_samples, word_aligned_inds, all_word_aligned_inds
|
|
    def _generate_data_and_labels_by_speech(
        self,
        raw_data: torch.Tensor,
        n_steps_in_one_segment: int,
        labels: np.ndarray,
    ):
        """
        Generate data and labels pairs by segmenting based on words.

        This function will first create word-aligned non-overlapping segments and
        then assign labels to each word. For speech_vs_nonspeech(_time) and
        sentence_onset(_time) tasks, it then chunks the data and uses segments that
        don't overlap with any word to generate negative labels. Note, this function
        can generate either non-overlapping **or** overlapping word center-aligned
        segments -- based on user preference. In the former case with non-overlapping
        segments, not all parts of the data will be used, since this is word-based.

        Args:
            raw_data: a tensor of shape (n_channels x n_total_samples)
            n_steps_in_one_segment: number of samples we want in one segment
            labels: a numpy array of length n_total_samples containing labels
                corresponding to each sample

        Output:
            data: a tensor of shape (n_segments x n_channels x n_steps_in_one_segment)
            labels: a tensor of shape (n_segments x n_steps_in_one_segment)
            data_sample_indices: a tensor of shape (n_segments x n_steps_in_one_segment)
                containing indices of samples of the raw data each item in data corresponds to.
        """
        # A switchpoint is a sample where the (NaN-squashed) label value jumps
        # up and the label itself is not NaN, i.e. a word/event onset.
        label_switchpoints = np.where(
            np.logical_and(
                np.concatenate((np.array([0]), np.diff(np.nan_to_num(labels)))) > 0,
                ~np.isnan(labels),
            )
        )[0]
        out = self._generate_word_aligned_segments(
            raw_data, n_steps_in_one_segment, label_switchpoints
        )
        word_aligned_samples, word_aligned_inds, all_word_aligned_inds = out

        if self.config.force_nonoverlap:
            # Negative (non-word) candidates: fixed, non-overlapping chunks
            # whose samples touch no word window and have no NaN labels.
            data_sample_indices = torch.arange(raw_data.shape[-1])
            is_unaligned_inds = np.logical_and(
                ~np.isin(data_sample_indices, np.unique(all_word_aligned_inds)),
                ~np.isnan(labels),
            )
            cutoff_len = int(
                raw_data.shape[-1] - raw_data.shape[-1] % n_steps_in_one_segment
            )
            is_unaligned_inds = np.reshape(
                is_unaligned_inds[..., :cutoff_len], (-1, n_steps_in_one_segment)
            )
            unaligned_inds = np.where(np.all(is_unaligned_inds, axis=1))[0]
            unaligned_word_samples = torch.stack(
                [
                    raw_data[
                        :,
                        start_ind
                        * n_steps_in_one_segment : (start_ind + 1)
                        * n_steps_in_one_segment,
                    ]
                    for start_ind in unaligned_inds
                ]
            )

            word_aligned_data_sample_inds = torch.reshape(
                word_aligned_inds, (-1, n_steps_in_one_segment)
            )
            unaligned_data_sample_inds = torch.reshape(
                data_sample_indices[:cutoff_len], (-1, n_steps_in_one_segment)
            )[unaligned_inds]

        else:
            # Overlapping mode: slide a window with a configurable step to
            # produce the negative candidates.
            if self.config.nonword_stepsize_s is None:
                self.config.nonword_stepsize_s = self.config.segment_length_s

            offset = int(self.samp_frequency * self.config.nonword_stepsize_s)
            n_rows = ((raw_data.shape[-1] - n_steps_in_one_segment) // offset) + 1

            data_sample_indices = np.array(
                [
                    np.arange(i * offset, i * offset + n_steps_in_one_segment)
                    for i in range(n_rows)
                ]
            )

            is_unaligned_inds = np.logical_and(
                ~np.isin(data_sample_indices, np.unique(all_word_aligned_inds)),
                # Reject windows containing any NaN-labeled sample.
                ~np.isnan(
                    labels[data_sample_indices.flatten()].reshape(
                        data_sample_indices.shape
                    )
                ),
            )
            unaligned_inds = np.where(np.all(is_unaligned_inds, axis=1))[0]

            unaligned_word_samples = torch.stack(
                [
                    raw_data[
                        :,
                        start_ind * offset : start_ind * offset
                        + n_steps_in_one_segment,
                    ]
                    for start_ind in unaligned_inds
                ]
            )

            data_sample_indices = torch.tensor(data_sample_indices)

            word_aligned_data_sample_inds = torch.reshape(
                word_aligned_inds, (-1, n_steps_in_one_segment)
            )
            unaligned_data_sample_inds = data_sample_indices[unaligned_inds]

        n_word_aligned_samples = word_aligned_samples.shape[0]
        n_unaligned_word_samples = unaligned_word_samples.shape[0]

        num_samples = n_unaligned_word_samples + n_word_aligned_samples

        if self.config.force_balanced:
            num_samples = min(n_unaligned_word_samples, n_word_aligned_samples) * 2

        # NOTE(review): the two subsamplings below use the global, unseeded
        # np.random state -- confirm a seed is set upstream for reproducibility.
        # Also, when force_balanced is off and one class is smaller than
        # num_samples // 2, np.random.choice(replace=False) raises ValueError
        # -- confirm force_balanced is required for these tasks.
        word_aligned_to_use = np.sort(
            np.random.choice(
                range(n_word_aligned_samples),
                replace=False,
                size=num_samples // 2,
            )
        )
        word_aligned_samples = word_aligned_samples[word_aligned_to_use, ...]
        word_aligned_data_sample_inds = word_aligned_data_sample_inds[
            word_aligned_to_use
        ]

        unaligned_to_use = np.sort(
            np.random.choice(
                range(n_unaligned_word_samples),
                replace=False,
                size=num_samples // 2,
            )
        )
        unaligned_word_samples = unaligned_word_samples[unaligned_to_use, ...]
        unaligned_data_sample_inds = unaligned_data_sample_inds[unaligned_to_use]

        n_word_aligned_samples = word_aligned_samples.shape[0]
        n_unaligned_word_samples = unaligned_word_samples.shape[0]

        # Concatenate positives (word-aligned) first, negatives after.
        data = torch.empty(
            n_word_aligned_samples + n_unaligned_word_samples,
            *word_aligned_samples.shape[1:],
        )
        data[:n_word_aligned_samples] = word_aligned_samples
        data[n_word_aligned_samples:] = unaligned_word_samples

        num_channels = raw_data.shape[0]
        assert data.shape == (
            num_samples,
            num_channels,
            n_steps_in_one_segment,
        )

        # Binary labels: 1 for word-aligned segments, 0 for non-word segments.
        labels = torch.zeros(num_samples, n_steps_in_one_segment)
        labels[:n_word_aligned_samples] = 1

        data_sample_indices = torch.empty(
            n_word_aligned_samples + n_unaligned_word_samples,
            n_steps_in_one_segment,
        )
        data_sample_indices[:n_word_aligned_samples] = word_aligned_data_sample_inds
        data_sample_indices[n_word_aligned_samples:] = unaligned_data_sample_inds

        # Restore chronological order over all segments.
        sorted_inds = torch.argsort(data_sample_indices[:, 0])
        data_sample_indices = data_sample_indices[sorted_inds, ...]
        data = data[sorted_inds, ...]
        labels = labels[sorted_inds, ...]
        return data, labels, data_sample_indices
|
|
| def _aggregate_labels(self, labels: torch.Tensor) -> float: |
| """ |
| Return one label for each segment in batch instead of having one label for each timepoint |
| """ |
|
|
| nan_numels = torch.isnan(labels).sum() |
|
|
| if nan_numels / len(labels) >= self.config.aggregate_labels.nan_threshold: |
| label = torch.nan |
| elif self.config.aggregate_labels.type == "mean": |
| label = labels.nanmean() |
| label = float(label) |
| elif self.config.aggregate_labels.type == "threshold": |
| non_nan_numels = len(labels) - nan_numels |
| label = int( |
| ( |
| labels.nansum() / non_nan_numels |
| > self.config.aggregate_labels.threshold |
| ).long() |
| ) |
|
|
| return label |
|
|
| def _get_segment_label(self, labels: torch.tensor) -> float: |
| if self.experiment_dataset_name == self._pretrain_enum: |
| return np.nan |
|
|
| agg_label = self._aggregate_labels(labels) |
| return agg_label |
|
|
    def _process_segments_and_update_metadata_file(self):
        """
        Process data files of subjects and add/update segments

        Iterates every subject/session, resumes from the per-session progress
        tracker, assigns splits, and persists metadata as it goes.
        """
        number_of_added_segments = 0
        for subject in self.available_sessions.keys():
            for session in self.available_sessions[subject]:
                print(
                    f"Segment processing for subject {subject} session {session} starts."
                )

                # Records the last segment written so interrupted runs resume.
                file_progress_tracker = FileProgressTracker(
                    save_path=self._get_file_progress_tracker_save_path(
                        subject, session
                    ),
                    experiment=self.experiment,
                )

                if self.config.force_reprocess_stage2:
                    # Drop previously generated rows for this subject/session/
                    # experiment so they are rebuilt from scratch.
                    corresponding_indices_to_remove = (
                        self.metadata.get_indices_matching_cols_values(
                            ["subject", "session", "experiment"],
                            [subject, session, self.experiment],
                        )
                    )
                    self.metadata.drop_rows_based_on_indices(
                        corresponding_indices_to_remove
                    )

                    file_progress_tracker.reset_process()
                    print(
                        f"Force reprocessing active, removed subject: {subject} session: "
                        f"{session} experiment: {self.experiment} from metadata, will "
                        f"start processing from the first file."
                    )

                if file_progress_tracker.is_completed():
                    sp_exist = self._spatial_groupings_exist_for_subject(
                        subject, session
                    )
                    if sp_exist and not self.config.force_recreate_spatial_groupings:
                        print(
                            f"Subject {subject} data already processed completely, skipping."
                        )
                        continue
                    else:
                        print(
                            f"Subject {subject} data already processed completely,"
                            " but force recreate spatial groupings is active,"
                            " will recreate spatial groups"
                        )

                number_of_added_segments_for_subject_session = (
                    self._create_segments_for_subject_session(
                        subject,
                        session,
                        self.config.segment_length_s,
                        file_progress_tracker,
                    )
                )

                print(
                    f"Added {number_of_added_segments_for_subject_session} new segments for subject {subject} session {session}"
                )

                # NOTE(review): label None is used here to look up NaN-labeled
                # rows -- confirm Metadata stores missing labels as None.
                nan_labels = self.metadata.get_indices_matching_cols_values(
                    ["subject", "session", "experiment", "label"],
                    [subject, session, self.experiment, None],
                )
                print(
                    f"{len(nan_labels)} segments for this subject session have nan labels"
                )

                number_of_added_segments += number_of_added_segments_for_subject_session

                # Assign train/val/test splits and persist after each session.
                self.metadata = self.splitter.set_splits_for_subject(
                    subject, self.metadata, self._split_method
                )
                file_progress_tracker.mark_completion_status()
                self.metadata.save(self.metadata_path)

        print(f"Metadata saved in {self.metadata_path}")
        print(f"Added {number_of_added_segments} new segments")

        summary_str = self.metadata.get_summary_str()
        print(f"{self.name} dataset, full metadata summary: {summary_str}")
|
|
| def _filter_metadata_for_the_run(self): |
| """ |
| Do filtering on metadata based on experiment design |
| |
| # NOTE: Add stuff that are run dependent but do **not** alter the saved metadata here. |
| """ |
| |
| self.metadata.reduce_based_on_col_value("experiment", self.experiment) |
|
|
| |
| if not self.experiment_dataset_name == self._pretrain_enum: |
| n_dropped = self.metadata.reduce_based_on_col_value( |
| "label", None, keep=False |
| ) |
| print(f"Dropping {n_dropped} segments with no labels") |
|
|
| if self.experiment_dataset_name in ( |
| BrainTreebankDatasetNames.SPEECH_VS_NONSPEECH_TIME, |
| BrainTreebankDatasetNames.SENTENCE_ONSET_TIME, |
| BrainTreebankDatasetNames.VOLUME, |
| BrainTreebankDatasetNames.OPTICAL_FLOW |
| ): |
| |
| curr_fold = self.config.get("chron_fold_num", None) |
| if curr_fold is not None: |
| print(f"Using chronological fold: {curr_fold}.") |
| folds_path = os.path.join( |
| self.config.save_dir, |
| self.experiment, |
| f"metadata_{self.segments_processing_hash_str}_folds.pkl", |
| ) |
| try: |
| with open( |
| folds_path, |
| "rb", |
| ) as f: |
| folds_info = pickle.load(f) |
| except FileNotFoundError as e: |
| print(f"File {folds_path} not found. Generate the folds for the metadata ({self.metadata_path}) using `barista/generate_chronological_folds` notebook.") |
| exit(0) |
| |
| assert ( |
| len(self.config.finetune_sessions) == 1 |
| ), "Only one finetune session expected." |
|
|
| subject_session = self.config.finetune_sessions[0] |
| self.config.run_ratios = [ |
| |
| float(elem) for elem in folds_info[subject_session][curr_fold][0] |
| ] |
| self.config.run_splits = folds_info[subject_session][curr_fold][1] |
|
|
| else: |
| print("Using default run chronological ratios and splits.") |
|
|
| for subject_session in self.config.finetune_sessions: |
| self.splitter.resplit_for_subject( |
| subject_session, |
| self.metadata, |
| self._split_method, |
| ) |
|
|
| summary_str = self.metadata.get_summary_str() |
| print(f"{self.name} dataset, current run summary: {summary_str}") |
|
|
| def process_segments(self, only_segment_generation=False): |
| |
| old_metadata = self._load_metadata() |
| if old_metadata is not None: |
| self.metadata = old_metadata |
|
|
| if not self.config.skip_segment_generation_completely: |
| self._process_segments_and_update_metadata_file() |
|
|
| if not only_segment_generation: |
| self._filter_metadata_for_the_run() |
|
|
| @property |
| def _split_method(self): |
| if self.experiment_dataset_name in ( |
| BrainTreebankDatasetNames.SPEECH_VS_NONSPEECH, |
| BrainTreebankDatasetNames.SENTENCE_ONSET, |
| ): |
| assert self.config.force_nonoverlap is True, "Set force_nonoverlap to True for random split segments" |
| return "shuffle" |
| |
| |
| if self.experiment_dataset_name != BrainTreebankDatasetNames.PRETRAIN: |
| assert self.config.force_nonoverlap is False, "Set force_nonoverlap to False for chronological segments" |
| |
| return "chronological" |
| |
| @property |
| def _pretrain_enum(self) -> BrainTreebankDatasetNames: |
| return BrainTreebankDatasetNames.PRETRAIN |
| |
| def get_raw_data_file_path(self, subject: str, session: str): |
| self.path_manager.get_raw_data_filepath(subject, session) |
|
|
| @property |
| def _processed_raw_data_dir(self): |
| """ |
| Filename for processed raw data, i.e., filtering and referencing |
| """ |
| return os.path.join( |
| self.config.save_dir, |
| self._get_processed_raw_data_dir_name, |
| ) |
|
|
| @property |
| def _get_processed_raw_data_dir_name(self): |
| return f"processed_raw_{self.samp_frequency}Hz_notch_laplacianref_clnLap" |
|
|
| @property |
| def _processed_segments_data_dir(self): |
| """Data dir for the segmented trials corresponding to a particular experimental config.""" |
| return os.path.join( |
| self.config.save_dir, |
| self.experiment, |
| f"processed_segments_{self.segments_processing_hash_str}", |
| ) |
| |
| def _load_metadata(self) -> Optional[Metadata]: |
| if os.path.exists(self.metadata_path): |
| metadata = Metadata(load_path=self.metadata_path) |
| print(f"Metadata loaded from {self.metadata_path}") |
| return metadata |
| return None |
| |
| def _initialize_metadata(self) -> Metadata: |
| columns = [f.name for f in dataclasses.fields(MetadataRow)] |
| metadata_df = pd.DataFrame(columns=columns) |
| |
| columns = [f.name for f in dataclasses.fields(MetadataSpatialGroupRow)] |
| spatial_group_df = pd.DataFrame(columns=columns) |
| |
| metadata = Metadata(df=metadata_df, spatial_group_df=spatial_group_df) |
| print(f"Metadata initialized: {self.metadata_path}") |
| return metadata |
| |
| def _get_processed_raw_data_file_path(self, subject, session): |
| filename = f"{subject}_{session}.pt" |
| return os.path.join(self._processed_raw_data_dir, filename) |
|
|
| def _get_processed_raw_data_file_path_cache(self, subject, session): |
| filename = f"{subject}_{session}.pt" |
| path = os.path.join( |
| self.config.stage1_cache_dir, |
| self._get_processed_raw_data_dir_name, |
| ) |
| print(f"Cache dir: {path}") |
| return path, os.path.join(path, filename) |
|
|
    def _get_segments_processing_hash(self, segment_length_s) -> Tuple[str, str]:
        """Build the segment-processing identifier string and its short hash.

        The string encodes the config options that affect segment generation
        (sampling rate, segment length, split ratios, filtering, ...), so
        distinct configurations map to distinct cache directories. The full
        string is kept in metadata; the 5-character hash is used in
        directory/file names. Subclasses may override this based on
        dataset-specific settings.

        Args:
            segment_length_s: segment duration in seconds, included verbatim
                in the identifier string.

        Returns:
            Tuple ``(processing_str, hash_str)`` where ``hash_str`` is the
            first 5 hex characters of SHA-256 of ``processing_str``.
        """

        processing_str = (
            f"{self.config.samp_frequency}Hz_zscrTrue"
            f"_segment_length{segment_length_s}_val_ratio{self.config.val_ratio:.1e}_test_ratio{self.config.test_ratio:.1e}"
        )

        # Trial alignment only affects non-pretraining datasets.
        if self.experiment_dataset_name != self._pretrain_enum:
            processing_str += f"_trial_align{self.config.trial_alignment}"

        if self.config.quantile_numerical_labels.active:
            processing_str += f"quantile_numerical_labels_L{self.config.quantile_numerical_labels.lower_threshold}_H{self.config.quantile_numerical_labels.higher_threshold}"

        processing_str += self.config.dataset_dir
        processing_str += "_laplacian"

        if self.config.region_filtering.active:
            # Sort so the same filter set always yields the same string.
            # NOTE(review): this sorts self.config.region_filtering.filters
            # in place, i.e. it mutates the config as a side effect.
            self.config.region_filtering['filters'].sort()
            filter_str = (
                f"_region_filtered_{str(self.config.region_filtering.filters)}"
            )
            processing_str += filter_str

        if not self.config.force_balanced:
            processing_str += "_all_labels"

        if self._split_method == "chronological":
            processing_str += "_chronosplit"
            if not self.config.force_nonoverlap:
                processing_str += "_overlapsegs"

        processing_str += "_use_clean_laplacian"
        processing_str += "_aggregate_label" + str(self.config.aggregate_labels)

        # Short (5 hex chars) hash keeps directory names manageable; any
        # change to the string composition above invalidates previously
        # generated processed-data directories.
        hash_str = hashlib.sha256(bytes(processing_str, "utf-8")).hexdigest()[:5]
        print(f"HASHSTR: {hash_str}")
        return processing_str, hash_str
|
|