Upload folder using huggingface_hub
Browse files- LICENSE +14 -0
- LICENSE.md +14 -0
- README.md +255 -5
- barista/config/braintreebank.yaml +117 -0
- barista/config/model.yaml +36 -0
- barista/config/train.yaml +15 -0
- barista/data/atlas.py +251 -0
- barista/data/available_sessions.py +28 -0
- barista/data/braintreebank_data_helpers.py +741 -0
- barista/data/braintreebank_dataset.py +230 -0
- barista/data/braintreebank_dataset_spatial_groupings.py +149 -0
- barista/data/braintreebank_wrapper.py +1186 -0
- barista/data/dataframe_wrapper.py +268 -0
- barista/data/fileprogresstracker.py +93 -0
- barista/data/metadata.py +175 -0
- barista/data/metadata_spatial_groups.py +60 -0
- barista/data/splitter.py +237 -0
- barista/generate_chronological_folds.ipynb +626 -0
- barista/models/TSEncoder2D.py +213 -0
- barista/models/mlp.py +60 -0
- barista/models/model.py +68 -0
- barista/models/spatial_encoder.py +276 -0
- barista/models/tokenized_batched_item.py +132 -0
- barista/models/tokenizer.py +238 -0
- barista/models/transformer.py +318 -0
- barista/models/utils.py +23 -0
- barista/prepare_segments.py +27 -0
- barista/train.py +368 -0
- barista/utility_scripts/aggregate_runs.py +161 -0
- barista/utility_scripts/run_finetune_folds.sh +276 -0
- barista/utility_scripts/run_finetune_random_splits.sh +267 -0
- pretrained_models/chans_chans.ckpt +3 -0
- pretrained_models/lobes_chans.ckpt +3 -0
- pretrained_models/parcels_chans.ckpt +3 -0
- requirements.txt +15 -0
- setup.py +25 -0
LICENSE
CHANGED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
This software is Copyright © 2025 The University of Southern California. All Rights Reserved.
|
| 2 |
+
|
| 3 |
+
Permission to use, copy, modify, and distribute this software and its documentation for educational, research and non-profit purposes, without fee, and without a written agreement is hereby granted, provided that the above copyright notice, this paragraph and the following three paragraphs appear in all copies.
|
| 4 |
+
|
| 5 |
+
Permission to make commercial use of this software may be obtained by contacting:\
|
| 6 |
+
USC Stevens Center for Innovation\
|
| 7 |
+
University of Southern California\
|
| 8 |
+
1150 S. Olive Street, Suite 2300\
|
| 9 |
+
Los Angeles, CA 90115, USA\
|
| 10 |
+
E-mail to: info@stevens.usc.edu and cc to: accounting@stevens.usc.edu
|
| 11 |
+
|
| 12 |
+
This software program and documentation are copyrighted by The University of Southern California. The software program and documentation are supplied "as is", without any accompanying services from USC. USC does not warrant that the operation of the program will be uninterrupted or error-free. The end-user understands that the program was developed for research purposes and is advised not to rely exclusively on the program for any reason.
|
| 13 |
+
|
| 14 |
+
IN NO EVENT SHALL THE UNIVERSITY OF SOUTHERN CALIFORNIA BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF SOUTHERN CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. THE UNIVERSITY OF SOUTHERN CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF SOUTHERN CALIFORNIA HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
|
LICENSE.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
This software is Copyright © 2025 The University of Southern California. All Rights Reserved.
|
| 2 |
+
|
| 3 |
+
Permission to use, copy, modify, and distribute this software and its documentation for educational, research and non-profit purposes, without fee, and without a written agreement is hereby granted, provided that the above copyright notice, this paragraph and the following three paragraphs appear in all copies.
|
| 4 |
+
|
| 5 |
+
Permission to make commercial use of this software may be obtained by contacting:\
|
| 6 |
+
USC Stevens Center for Innovation\
|
| 7 |
+
University of Southern California\
|
| 8 |
+
1150 S. Olive Street, Suite 2300\
|
| 9 |
+
Los Angeles, CA 90115, USA\
|
| 10 |
+
E-mail to: info@stevens.usc.edu and cc to: accounting@stevens.usc.edu
|
| 11 |
+
|
| 12 |
+
This software program and documentation are copyrighted by The University of Southern California. The software program and documentation are supplied "as is", without any accompanying services from USC. USC does not warrant that the operation of the program will be uninterrupted or error-free. The end-user understands that the program was developed for research purposes and is advised not to rely exclusively on the program for any reason.
|
| 13 |
+
|
| 14 |
+
IN NO EVENT SHALL THE UNIVERSITY OF SOUTHERN CALIFORNIA BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF SOUTHERN CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. THE UNIVERSITY OF SOUTHERN CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF SOUTHERN CALIFORNIA HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
|
README.md
CHANGED
|
@@ -1,5 +1,255 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language: en
|
| 3 |
+
tags:
|
| 4 |
+
- ieeg
|
| 5 |
+
- bci
|
| 6 |
+
- neuroscience
|
| 7 |
+
- foundation-model
|
| 8 |
+
- neurips-2025
|
| 9 |
+
arxiv: 2512.12135
|
| 10 |
+
metrics:
|
| 11 |
+
- accuracy
|
| 12 |
+
license: other
|
| 13 |
+
license_link: LICENSE
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# BaRISTA ☕
|
| 17 |
+
|
| 18 |
+
[](https://www.python.org/)
|
| 19 |
+
[](https://openreview.net/forum?id=LDjBDk3Czb)
|
| 20 |
+
|
| 21 |
+
This repository contains the official PyTorch implementation of [**BaRISTA** (Brain Scale Informed Spatiotemporal Representation of Human Intracranial Neural Activity)](#publication).
|
| 22 |
+
|
| 23 |
+
## Table of Contents
|
| 24 |
+
- [Installation](#installation)
|
| 25 |
+
- [Data Preparation](#data-preparation)
|
| 26 |
+
- [Data Segmentation](#data-segmentation)
|
| 27 |
+
- [Finetuning the Model](#finetuning-the-model)
|
| 28 |
+
- [Additional Scripts](#additional-scripts)
|
| 29 |
+
- [Publication](#publication)
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
## Installation
|
| 33 |
+
|
| 34 |
+
We recommend setting up a virtual environment to manage dependencies.
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
# 1. Create and activate a virtual environment
|
| 38 |
+
python -m venv barista_venv
|
| 39 |
+
source barista_venv/bin/activate
|
| 40 |
+
|
| 41 |
+
# 2. Install the package in editable mode
|
| 42 |
+
python -m pip install -e .
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## Data Preparation
|
| 46 |
+
|
| 47 |
+
1. Download the data from the [Brain Treebank website](https://braintreebank.dev/). You will also need the `clean_laplacian.json` file from the [PopT codebase](https://github.com/czlwang/PopulationTransformer/blob/main/electrode_selections/clean_laplacian.json).
|
| 48 |
+
|
| 49 |
+
2. Update the `dataset_dir` config in `barista/config/braintreebank.yaml` to point to the raw data path.
|
| 50 |
+
|
| 51 |
+
The data directory should have the following structure:
|
| 52 |
+
|
| 53 |
+
<details> <summary><strong>Click to expand full directory tree</strong></summary>
|
| 54 |
+
|
| 55 |
+
```
|
| 56 |
+
braintreebank_data
|
| 57 |
+
|__corrupted_elec.json
|
| 58 |
+
|__clean_laplacian.json
|
| 59 |
+
|__all_subject_data
|
| 60 |
+
| |__ sub_1_trial000.h5
|
| 61 |
+
| |__ sub_1_trial001.h5
|
| 62 |
+
| |__ sub_1_trial002.h5
|
| 63 |
+
| |__ sub_2_trial000.h5
|
| 64 |
+
| |
|
| 65 |
+
| ...
|
| 66 |
+
|
|
| 67 |
+
|__ electrode_labels
|
| 68 |
+
| |__ sub_1
|
| 69 |
+
| | |__ electrode_labels.json
|
| 70 |
+
| |__ sub_2
|
| 71 |
+
| | |__ electrode_labels.json
|
| 72 |
+
| ...
|
| 73 |
+
|
|
| 74 |
+
|__ localization
|
| 75 |
+
| |__ elec_coords_full.csv
|
| 76 |
+
| |__ sub_1
|
| 77 |
+
| | |__ depth-wm.csv
|
| 78 |
+
| |__ sub_2
|
| 79 |
+
| | |__ depth-wm.csv
|
| 80 |
+
| ...
|
| 81 |
+
|
|
| 82 |
+
|__ subject_metadata
|
| 83 |
+
| |__ sub_1_trial000_metadata.json
|
| 84 |
+
| |__ sub_1_trial001_metadata.json
|
| 85 |
+
| |__ sub_1_trial002_metadata.json
|
| 86 |
+
| |__ sub_2_trial000_metadata.json
|
| 87 |
+
| |
|
| 88 |
+
| ...
|
| 89 |
+
|
|
| 90 |
+
|__ subject_timings
|
| 91 |
+
| |__ sub_1_trial000_timings.csv
|
| 92 |
+
| |__ sub_1_trial001_timings.csv
|
| 93 |
+
| |__ sub_1_trial002_timings.csv
|
| 94 |
+
| |__ sub_2_trial000_timings.csv
|
| 95 |
+
| |
|
| 96 |
+
| ...
|
| 97 |
+
|
|
| 98 |
+
|__ transcripts
|
| 99 |
+
| |__ ant-man
|
| 100 |
+
| | |__ features.csv
|
| 101 |
+
| |__ aquaman
|
| 102 |
+
| | |__ features.csv
|
| 103 |
+
| ......
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
</details>
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
## Data Segmentation
|
| 110 |
+
|
| 111 |
+
You must segment the data **before training**. The required arguments depend on the experiment:
|
| 112 |
+
|
| 113 |
+
| Experiment Type | `force_nonoverlap` | `experiment` options |
|
| 114 |
+
|--------------------------------------------------|----------------------|----------------------|
|
| 115 |
+
| **1. Random splits**, non-overlapping neural segments (Main Analysis in the paper) | `True` | `sentence_onset`, `speech_vs_nonspeech` |
|
| 116 |
+
| **2. Chronological splits**, increased labels (Appendix K in the paper) | `False` | `sentence_onset_time`, `speech_vs_nonspeech_time`, `volume`, `optical_flow` |
|
| 117 |
+
|
| 118 |
+
### 1. Generating Random Splits with Non-Overlapping Neural Segments
|
| 119 |
+
|
| 120 |
+
To generate the random splits with non-overlapping neural segments, as used for the main analysis (Section 4), you will need to run the following:
|
| 121 |
+
|
| 122 |
+
```bash
|
| 123 |
+
python barista/prepare_segments.py \
|
| 124 |
+
--config barista/config/braintreebank.yaml \
|
| 125 |
+
--experiment <sentence_onset|speech_vs_nonspeech>
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
> ⚠️ Ensure `force_nonoverlap` in `barista/config/braintreebank.yaml` is set to `True` for this experiment. Incorrect settings will produce invalid splits.
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
This setting should **only** be used with the `sentence_onset` and `speech_vs_nonspeech` experiments.
|
| 132 |
+
|
| 133 |
+
### 2. Generating Chronological Splits with Increased Label Data
|
| 134 |
+
We can also generate chronological splitting (splitting sessions based on time rather than random shuffling). This approach enables us to increase the number of labeled segments for finetuning by allowing overlap between segments within the same split, while preventing information leakage (i.e., no overlapping neural segments) between train and test splits. To generate the chronological splits used for the evaluation in Appendix K, there are two steps to follow.
|
| 135 |
+
|
| 136 |
+
First, you will need to segment the data using the following command:
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
python barista/prepare_segments.py \
|
| 140 |
+
--config barista/config/braintreebank.yaml \
|
| 141 |
+
--experiment <sentence_onset_time|speech_vs_nonspeech_time|volume|optical_flow>
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
> ⚠️ Ensure `force_nonoverlap` in `barista/config/braintreebank.yaml` is set to `False` for this experiment. Incorrect settings will produce invalid splits.
|
| 145 |
+
|
| 146 |
+
This setting should **only** be used with the `sentence_onset_time`, `speech_vs_nonspeech_time`, `volume`, and `optical_flow` experiments.
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
Second, you will need to generate the 5 chronological folds to use during evaluation. To create these different folds, we use the `barista/generate_chronological_folds.ipynb` notebook. This notebook will automatically generate 5 different train/valid/test splits across time, while ensuring that all generated splits have both positive and negative labels present. To use the notebook, take the following steps:
|
| 150 |
+
|
| 151 |
+
1. Open `generate_chronological_folds.ipynb`
|
| 152 |
+
|
| 153 |
+
2. Update the `_METADATA_FNAMES` variable with the metadata hash string produced from the previous step.
|
| 154 |
+
|
| 155 |
+
3. Run the notebook to generate the 5 train/valid/test fold pickle files.
|
| 156 |
+
|
| 157 |
+
The notebook will output a pickle file in the same directory as the specified metadata file and it will be dynamically loaded during train/eval time to ensure the right chronological split fold is used.
|
| 158 |
+
|
| 159 |
+
## Finetuning the Model
|
| 160 |
+
To finetune the model,
|
| 161 |
+
|
| 162 |
+
1. Update the `finetune_sessions` field in `barista/config/braintreebank.yaml` to the desired finetuning session.
|
| 163 |
+
|
| 164 |
+
2. Use the following command to run finetuning:
|
| 165 |
+
|
| 166 |
+
```bash
|
| 167 |
+
python barista/train.py
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
It is important to ensure the `braintreebank.yaml` fields match precisely with the config used during segmentation generation, including the `experiment` field. Otherwise, the metadata hash string will not match and the experiment will fail. For the chronological folds, the experiment will also fail if the pickle file outlined in the second step of [Generating chronological splits with increased label data](#generating-chronological-splits-with-increased-label-data) hasn't been generated.
|
| 171 |
+
|
| 172 |
+
### Loading Pretrained Model
|
| 173 |
+
|
| 174 |
+
Pretrained models are available under `pretrained_models/`. Set the `checkpoint_path` in `barista/config/train.yaml` to the specific pretrained model path. e.g. `checkpoint_path: pretrained_models/parcels_chans.ckpt`.
|
| 175 |
+
|
| 176 |
+
> ⚠️ You also need to set the `tokenizer.spatial_grouping` in `barista/config/model.yaml` accordingly for each of the models.
|
| 177 |
+
|
| 178 |
+
| Checkpoint Name | `tokenizer.spatial_grouping` |
|
| 179 |
+
| -------------------- | ---------------------------- |
|
| 180 |
+
| `chans_chans.ckpt` | `coords` |
|
| 181 |
+
| `parcels_chans.ckpt` | `destrieux` |
|
| 182 |
+
| `lobes_chans.ckpt` | `lobes` |
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
Alternatively, you can pass these as extra arguments to the train command:
|
| 186 |
+
|
| 187 |
+
**Example finetuning command for Parcel level model**
|
| 188 |
+
```bash
|
| 189 |
+
python barista/train.py \
|
| 190 |
+
--override \
|
| 191 |
+
tokenizer.spatial_grouping="destrieux" \
|
| 192 |
+
checkpoint_path="pretrained_models/parcels_chans.ckpt"
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
## Additional Scripts
|
| 196 |
+
|
| 197 |
+
You can also use the scripts under `barista/utility_scripts` to run the model for a specific setting across different finetuning seeds.
|
| 198 |
+
The run outputs are saved in the results directory specified in the script and can be easily aggregated using `aggregate_runs.py` across different subjects, models, and folds.
|
| 199 |
+
|
| 200 |
+
**Example usage for random splits**
|
| 201 |
+
```bash
|
| 202 |
+
./barista/utility_scripts/run_finetune_random_splits.sh \
|
| 203 |
+
--spe destrieux \
|
| 204 |
+
--checkpoint "pretrained_models/parcels_chans.ckpt" \
|
| 205 |
+
--session HOLDSUBJ_1_HS1_1 \
|
| 206 |
+
--gpu 0 \
|
| 207 |
+
--exp sentence_onset
|
| 208 |
+
```
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
**Example usage for chronological fold**
|
| 212 |
+
```bash
|
| 213 |
+
./barista/utility_scripts/run_finetune_folds.sh \
|
| 214 |
+
--spe destrieux \
|
| 215 |
+
--checkpoint "pretrained_models/parcels_chans.ckpt" \
|
| 216 |
+
--session HOLDSUBJ_1_HS1_1 \
|
| 217 |
+
--gpu 0 \
|
| 218 |
+
--fold 0 \
|
| 219 |
+
--exp sentence_onset_time
|
| 220 |
+
```
|
| 221 |
+
|
| 222 |
+
### Aggregating Results
|
| 223 |
+
|
| 224 |
+
You can use `barista/utility_scripts/aggregate_runs.py` to get the average results as a markdown table:
|
| 225 |
+
|
| 226 |
+
```bash
|
| 227 |
+
python barista/utility_scripts/aggregate_runs.py \
|
| 228 |
+
--results_dir <results|results_folds>
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
## Publication
|
| 233 |
+
[Oganesian, L. L.\*, Hashemi, S.\*, Shanechi, M. M. BaRISTA: Brain Scale Informed Spatiotemporal Representation of Human Intracranial Neural Activity. In Advances in Neural Information Processing Systems 2025.](https://openreview.net/forum?id=LDjBDk3Czb)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
**Citation**
|
| 237 |
+
```
|
| 238 |
+
@inproceedings{
|
| 239 |
+
oganesian2025barista,
|
| 240 |
+
title={BaRISTA: Brain Scale Informed Spatiotemporal Representation of Human Intracranial Neural Activity},
|
| 241 |
+
author={Oganesian, Lucine L. and Hashemi, Saba and Shanechi, Maryam M.},
|
| 242 |
+
booktitle={Advances in Neural Information Processing Systems},
|
| 243 |
+
year={2025},
|
| 244 |
+
url={https://openreview.net/pdf?id=LDjBDk3Czb}
|
| 245 |
+
}
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
## License
|
| 250 |
+
Copyright (c) 2025 University of Southern California <br />
|
| 251 |
+
See full notice in [LICENSE.md](LICENSE.md) <br />
|
| 252 |
+
Lucine L. Oganesian, Saba Hashemi, and Maryam M. Shanechi <br />
|
| 253 |
+
Shanechi Lab, University of Southern California
|
| 254 |
+
|
| 255 |
+
|
barista/config/braintreebank.yaml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Directory where the raw data exists.
|
| 2 |
+
dataset_dir: "braintreebank_raw"
|
| 3 |
+
## Directory where to save the preprocessed data.
|
| 4 |
+
save_dir: "braintreebank_data_segments"
|
| 5 |
+
## Directory where to store cached stage 1 preprocessed data (i.e., filtered, rereferenced) to then segment.
|
| 6 |
+
stage1_cache_dir: "braintreebank_processed_raw_cache"
|
| 7 |
+
|
| 8 |
+
samp_frequency: 2048 # in Hz. Default: 2048.
|
| 9 |
+
segment_length_s: 3
|
| 10 |
+
region_filtering:
|
| 11 |
+
active: True
|
| 12 |
+
# Use region names that partially match the Destrieux column in the
|
| 13 |
+
# localization file to exclude channels.
|
| 14 |
+
filters:
|
| 15 |
+
- GRID
|
| 16 |
+
- VENT
|
| 17 |
+
|
| 18 |
+
aggregate_labels:
|
| 19 |
+
nan_threshold: 1 # value between 0 and 1, drop segments with more than this percentage of NaNs
|
| 20 |
+
type: threshold # threshold | mean
|
| 21 |
+
threshold: 0.5
|
| 22 |
+
|
| 23 |
+
quantile_numerical_labels:
|
| 24 |
+
active: True
|
| 25 |
+
lower_threshold: 0.25
|
| 26 |
+
higher_threshold: 0.75
|
| 27 |
+
|
| 28 |
+
force_balanced: True
|
| 29 |
+
force_nonoverlap: True
|
| 30 |
+
|
| 31 |
+
## NOTE: val_ratio and test_ratio only used for shuffle & random splits.
|
| 32 |
+
val_ratio: 0.1
|
| 33 |
+
test_ratio: 0.1
|
| 34 |
+
|
| 35 |
+
## NOTE: run_ratios only used for chronological splits; use val_ratio and test_ratio in
|
| 36 |
+
## dataset/single/base.yaml for shuffle & random splits.
|
| 37 |
+
run_ratios: [0.8, 0.1, 0.1]
|
| 38 |
+
run_splits: ["train", "val", "test"]
|
| 39 |
+
chron_fold_num: 0 # Chronological fold number to use. Default is ratios & splits in config.
|
| 40 |
+
|
| 41 |
+
## This is the step size used when generating negative sample segments for sentence_onset*
|
| 42 |
+
## and speech_vs_nonspeech* tasks.
|
| 43 |
+
nonword_stepsize_s: # leave empty for no nonword overlap (i.e., step = segment length)
|
| 44 |
+
|
| 45 |
+
trial_alignment: center # center only supported for now. Can extend to other alignments as desired.
|
| 46 |
+
subjects_to_process: # list of which subjects to process, set empty to run for all available
|
| 47 |
+
# - SUBJ_1
|
| 48 |
+
# - SUBJ_2
|
| 49 |
+
# - SUBJ_3
|
| 50 |
+
# - SUBJ_4
|
| 51 |
+
# - SUBJ_5
|
| 52 |
+
# - SUBJ_6
|
| 53 |
+
# - SUBJ_7
|
| 54 |
+
# - SUBJ_8
|
| 55 |
+
# - SUBJ_9
|
| 56 |
+
# - SUBJ_10
|
| 57 |
+
- HOLDSUBJ_1
|
| 58 |
+
- HOLDSUBJ_2
|
| 59 |
+
- HOLDSUBJ_3
|
| 60 |
+
- HOLDSUBJ_4
|
| 61 |
+
- HOLDSUBJ_6
|
| 62 |
+
- HOLDSUBJ_7
|
| 63 |
+
- HOLDSUBJ_10
|
| 64 |
+
|
| 65 |
+
# Options:
|
| 66 |
+
# "speech_vs_nonspeech" | "sentence_onset" [random split]
|
| 67 |
+
# "sentence_onset_time" | "speech_vs_nonspeech_time" | "volume" | "optical_flow" [chronological split]
|
| 68 |
+
experiment: "sentence_onset_time"
|
| 69 |
+
|
| 70 |
+
### Dataset processing
|
| 71 |
+
skip_segment_generation_completely: False
|
| 72 |
+
force_reprocess_stage1: False
|
| 73 |
+
force_reprocess_stage2: False
|
| 74 |
+
force_recreate_spatial_groupings: False
|
| 75 |
+
|
| 76 |
+
processing_save_interval: 100 # save files every # of segments
|
| 77 |
+
processing_log_interval: 50
|
| 78 |
+
|
| 79 |
+
use_fixed_seed_for_splitter: True
|
| 80 |
+
split_together_length_s: 3 # Note: Recommended to use the same value as segment_length_s above
|
| 81 |
+
|
| 82 |
+
shuffle_dataloader: True
|
| 83 |
+
|
| 84 |
+
# Note: Recommendation is to use the full subject_session label here.
|
| 85 |
+
pretrain_sessions:
|
| 86 |
+
- SUBJ_1_S1_0
|
| 87 |
+
# - SUBJ_1_S1_2
|
| 88 |
+
# - SUBJ_2_S2_0
|
| 89 |
+
# - SUBJ_2_S2_1
|
| 90 |
+
# - SUBJ_2_S2_2
|
| 91 |
+
# - SUBJ_2_S2_3
|
| 92 |
+
# - SUBJ_2_S2_4
|
| 93 |
+
# - SUBJ_3_S3_1
|
| 94 |
+
# - SUBJ_3_S3_2
|
| 95 |
+
# - SUBJ_4_S4_1
|
| 96 |
+
# - SUBJ_5_S5_0
|
| 97 |
+
# - SUBJ_6_S6_0
|
| 98 |
+
# - SUBJ_6_S6_1
|
| 99 |
+
# - SUBJ_7_S7_1
|
| 100 |
+
# - SUBJ_8_S8_0
|
| 101 |
+
# - SUBJ_9_S9_0
|
| 102 |
+
# - SUBJ_10_S10_1
|
| 103 |
+
finetune_sessions:
|
| 104 |
+
# - SUBJ_2_S2_5 # Pseudo held out
|
| 105 |
+
# - SUBJ_4_S4_2 # Pseudo held out
|
| 106 |
+
- HOLDSUBJ_1_HS1_1
|
| 107 |
+
# - HOLDSUBJ_2_HS2_6
|
| 108 |
+
# - HOLDSUBJ_3_HS3_0
|
| 109 |
+
# - HOLDSUBJ_4_HS4_0
|
| 110 |
+
# - HOLDSUBJ_6_HS6_4
|
| 111 |
+
# - HOLDSUBJ_7_HS7_0
|
| 112 |
+
# - HOLDSUBJ_10_HS10_0
|
| 113 |
+
|
| 114 |
+
spatial_groupings_to_create:
|
| 115 |
+
- coords
|
| 116 |
+
- destrieux
|
| 117 |
+
- lobes
|
barista/config/model.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
backbone:
|
| 2 |
+
num_layers: 12
|
| 3 |
+
d_hidden: 64
|
| 4 |
+
d_input: ${backbone.d_hidden} # same as d hidden
|
| 5 |
+
d_out: ${backbone.d_hidden} # same as d hidden
|
| 6 |
+
mlp_ratio: 4
|
| 7 |
+
norm: rmsnorm
|
| 8 |
+
norm_eps: 1e-8
|
| 9 |
+
activation: gelu
|
| 10 |
+
num_heads: 4
|
| 11 |
+
max_position: 1024
|
| 12 |
+
dropout: 0.1
|
| 13 |
+
|
| 14 |
+
tokenizer:
|
| 15 |
+
temporal_encoder:
|
| 16 |
+
input_dims: 128
|
| 17 |
+
output_dims: 128
|
| 18 |
+
hidden_dims: 5
|
| 19 |
+
depth: 4 # Zero-index (will have 5 convolution blocks all together)
|
| 20 |
+
kernel_size: 3
|
| 21 |
+
stride: 1
|
| 22 |
+
enable_checkpointing: False
|
| 23 |
+
|
| 24 |
+
temporal_subsegment_len: 512
|
| 25 |
+
temporal_subsegment_step: 512
|
| 26 |
+
|
| 27 |
+
samp_frequency: 2048
|
| 28 |
+
num_seconds: 3
|
| 29 |
+
|
| 30 |
+
d_hidden: ${backbone.d_input}
|
| 31 |
+
|
| 32 |
+
add_spatial_encoding: True
|
| 33 |
+
spatial_grouping: destrieux # coords | destrieux | lobes
|
| 34 |
+
|
| 35 |
+
embedding_max_dim: # leave empty for no normalization of embeddings
|
| 36 |
+
embedding_init_scale: 1.0
|
barista/config/train.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed: 0
|
| 2 |
+
checkpoint_path: "pretrained_models/parcels_chans.ckpt"
|
| 3 |
+
device: cuda:0
|
| 4 |
+
epochs: 30
|
| 5 |
+
dataloader:
|
| 6 |
+
drop_last: False
|
| 7 |
+
drop_last_val: False
|
| 8 |
+
num_workers: 16
|
| 9 |
+
batch_size: 128
|
| 10 |
+
persistent_workers: False
|
| 11 |
+
pin_memory: True
|
| 12 |
+
optimization:
|
| 13 |
+
finetune_lr: 1e-4
|
| 14 |
+
new_param_lr: 1e-3
|
| 15 |
+
freeze_tokenizer: True
|
barista/data/atlas.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Enums for the various spatial scales explored.
|
| 2 |
+
|
| 3 |
+
Useful references for the atlas parcels:
|
| 4 |
+
https://pmc.ncbi.nlm.nih.gov/articles/PMC2937159/pdf/nihms213933.pdf
|
| 5 |
+
https://surfer.nmr.mgh.harvard.edu/pub/articles/HBM09-Destrieux-Sulcal.pdf
|
| 6 |
+
|
| 7 |
+
Useful references for mapping atlas parcels to lobes (see below):
|
| 8 |
+
https://surfer.nmr.mgh.harvard.edu/fswiki/CorticalParcellation
|
| 9 |
+
https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2012.00171/full#h12
|
| 10 |
+
"""
|
| 11 |
+
import enum
|
| 12 |
+
|
| 13 |
+
UNKNWON_STR = "UNKNOWN"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class EnumWithUnknown(enum.Enum):
|
| 17 |
+
@classmethod
|
| 18 |
+
def get_enum(cls, value):
|
| 19 |
+
value = (value or UNKNWON_STR).upper()
|
| 20 |
+
try:
|
| 21 |
+
return cls[value]
|
| 22 |
+
except KeyError as e:
|
| 23 |
+
raise NotImplementedError(
|
| 24 |
+
f"Unknown value '{value}' for enum {cls.__name__}"
|
| 25 |
+
) from e
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Destrieux(EnumWithUnknown):
|
| 29 |
+
UNKNOWN = 0
|
| 30 |
+
LEFT_AMYGDALA = 1
|
| 31 |
+
LEFT_HIPPOCAMPUS = 2
|
| 32 |
+
LEFT_INF_LAT_VENT = 3
|
| 33 |
+
LEFT_PUTAMEN = 4
|
| 34 |
+
RIGHT_AMYGDALA = 5
|
| 35 |
+
RIGHT_HIPPOCAMPUS = 6
|
| 36 |
+
RIGHT_INF_LAT_VENT = 7
|
| 37 |
+
RIGHT_PUTAMEN = 8
|
| 38 |
+
CTX_LH_G_INS_LG_AND_S_CENT_INS = 9
|
| 39 |
+
CTX_LH_G_AND_S_CINGUL_ANT = 10
|
| 40 |
+
CTX_LH_G_AND_S_CINGUL_MID_ANT = 11
|
| 41 |
+
CTX_LH_G_AND_S_CINGUL_MID_POST = 12
|
| 42 |
+
CTX_LH_G_AND_S_SUBCENTRAL = 13
|
| 43 |
+
CTX_LH_G_CINGUL_POST_DORSAL = 14
|
| 44 |
+
CTX_LH_G_FRONT_INF_OPERCULAR = 15
|
| 45 |
+
CTX_LH_G_FRONT_INF_ORBITAL = 16
|
| 46 |
+
CTX_LH_G_FRONT_INF_TRIANGUL = 17
|
| 47 |
+
CTX_LH_G_FRONT_MIDDLE = 18
|
| 48 |
+
CTX_LH_G_FRONT_SUP = 19
|
| 49 |
+
CTX_LH_G_INSULAR_SHORT = 20
|
| 50 |
+
CTX_LH_G_OC_TEMP_MED_PARAHIP = 21
|
| 51 |
+
CTX_LH_G_OCCIPITAL_MIDDLE = 22
|
| 52 |
+
CTX_LH_G_ORBITAL = 23
|
| 53 |
+
CTX_LH_G_PARIET_INF_ANGULAR = 24
|
| 54 |
+
CTX_LH_G_PARIET_INF_SUPRAMAR = 25
|
| 55 |
+
CTX_LH_G_PARIETAL_SUP = 26
|
| 56 |
+
CTX_LH_G_POSTCENTRAL = 27
|
| 57 |
+
CTX_LH_G_PRECENTRAL = 28
|
| 58 |
+
CTX_LH_G_PRECUNEUS = 29
|
| 59 |
+
CTX_LH_G_RECTUS = 30
|
| 60 |
+
CTX_LH_G_TEMP_SUP_G_T_TRANSV = 31
|
| 61 |
+
CTX_LH_G_TEMP_SUP_LATERAL = 32
|
| 62 |
+
CTX_LH_G_TEMP_SUP_PLAN_POLAR = 33
|
| 63 |
+
CTX_LH_G_TEMP_SUP_PLAN_TEMPO = 34
|
| 64 |
+
CTX_LH_G_TEMPORAL_INF = 35
|
| 65 |
+
CTX_LH_G_TEMPORAL_MIDDLE = 36
|
| 66 |
+
CTX_LH_LAT_FIS_ANT_HORIZONT = 37
|
| 67 |
+
CTX_LH_LAT_FIS_ANT_VERTICAL = 38
|
| 68 |
+
CTX_LH_LAT_FIS_POST = 39
|
| 69 |
+
CTX_LH_POLE_TEMPORAL = 40
|
| 70 |
+
CTX_LH_S_CALCARINE = 41
|
| 71 |
+
CTX_LH_S_CENTRAL = 42
|
| 72 |
+
CTX_LH_S_CINGUL_MARGINALIS = 43
|
| 73 |
+
CTX_LH_S_CIRCULAR_INSULA_ANT = 44
|
| 74 |
+
CTX_LH_S_CIRCULAR_INSULA_INF = 45
|
| 75 |
+
CTX_LH_S_CIRCULAR_INSULA_SUP = 46
|
| 76 |
+
CTX_LH_S_COLLAT_TRANSV_ANT = 47
|
| 77 |
+
CTX_LH_S_FRONT_INF = 48
|
| 78 |
+
CTX_LH_S_FRONT_MIDDLE = 49
|
| 79 |
+
CTX_LH_S_FRONT_SUP = 50
|
| 80 |
+
CTX_LH_S_INTRAPARIET_AND_P_TRANS = 51
|
| 81 |
+
CTX_LH_S_OC_TEMP_MED_AND_LINGUAL = 52
|
| 82 |
+
CTX_LH_S_ORBITAL_H_SHAPED = 53
|
| 83 |
+
CTX_LH_S_ORBITAL_LATERAL = 54
|
| 84 |
+
CTX_LH_S_ORBITAL_MED_OLFACT = 55
|
| 85 |
+
CTX_LH_S_PARIETO_OCCIPITAL = 56
|
| 86 |
+
CTX_LH_S_PERICALLOSAL = 57
|
| 87 |
+
CTX_LH_S_POSTCENTRAL = 58
|
| 88 |
+
CTX_LH_S_PRECENTRAL_INF_PART = 59
|
| 89 |
+
CTX_LH_S_PRECENTRAL_SUP_PART = 60
|
| 90 |
+
CTX_LH_S_SUBORBITAL = 61
|
| 91 |
+
CTX_LH_S_SUBPARIETAL = 62
|
| 92 |
+
CTX_LH_S_TEMPORAL_INF = 63
|
| 93 |
+
CTX_LH_S_TEMPORAL_SUP = 64
|
| 94 |
+
CTX_LH_S_TEMPORAL_TRANSVERSE = 65
|
| 95 |
+
CTX_RH_G_INS_LG_AND_S_CENT_INS = 66
|
| 96 |
+
CTX_RH_G_AND_S_CINGUL_ANT = 67
|
| 97 |
+
CTX_RH_G_AND_S_CINGUL_MID_ANT = 68
|
| 98 |
+
CTX_RH_G_AND_S_CINGUL_MID_POST = 69
|
| 99 |
+
CTX_RH_G_AND_S_FRONTOMARGIN = 70
|
| 100 |
+
CTX_RH_G_AND_S_PARACENTRAL = 71
|
| 101 |
+
CTX_RH_G_AND_S_SUBCENTRAL = 72
|
| 102 |
+
CTX_RH_G_CINGUL_POST_DORSAL = 73
|
| 103 |
+
CTX_RH_G_FRONT_INF_OPERCULAR = 74
|
| 104 |
+
CTX_RH_G_FRONT_INF_ORBITAL = 75
|
| 105 |
+
CTX_RH_G_FRONT_INF_TRIANGUL = 76
|
| 106 |
+
CTX_RH_G_FRONT_MIDDLE = 77
|
| 107 |
+
CTX_RH_G_FRONT_SUP = 78
|
| 108 |
+
CTX_RH_G_INSULAR_SHORT = 79
|
| 109 |
+
CTX_RH_G_OC_TEMP_LAT_FUSIFOR = 80
|
| 110 |
+
CTX_RH_G_OC_TEMP_MED_PARAHIP = 81
|
| 111 |
+
CTX_RH_G_ORBITAL = 82
|
| 112 |
+
CTX_RH_G_PARIET_INF_ANGULAR = 83
|
| 113 |
+
CTX_RH_G_PARIET_INF_SUPRAMAR = 84
|
| 114 |
+
CTX_RH_G_PRECENTRAL = 85
|
| 115 |
+
CTX_RH_G_RECTUS = 86
|
| 116 |
+
CTX_RH_G_TEMP_SUP_G_T_TRANSV = 87
|
| 117 |
+
CTX_RH_G_TEMP_SUP_LATERAL = 88
|
| 118 |
+
CTX_RH_G_TEMP_SUP_PLAN_POLAR = 89
|
| 119 |
+
CTX_RH_G_TEMP_SUP_PLAN_TEMPO = 90
|
| 120 |
+
CTX_RH_G_TEMPORAL_INF = 91
|
| 121 |
+
CTX_RH_G_TEMPORAL_MIDDLE = 92
|
| 122 |
+
CTX_RH_LAT_FIS_ANT_HORIZONT = 93
|
| 123 |
+
CTX_RH_LAT_FIS_ANT_VERTICAL = 94
|
| 124 |
+
CTX_RH_LAT_FIS_POST = 95
|
| 125 |
+
CTX_RH_POLE_TEMPORAL = 96
|
| 126 |
+
CTX_RH_S_CENTRAL = 97
|
| 127 |
+
CTX_RH_S_CINGUL_MARGINALIS = 98
|
| 128 |
+
CTX_RH_S_CIRCULAR_INSULA_ANT = 99
|
| 129 |
+
CTX_RH_S_CIRCULAR_INSULA_INF = 100
|
| 130 |
+
CTX_RH_S_CIRCULAR_INSULA_SUP = 101
|
| 131 |
+
CTX_RH_S_COLLAT_TRANSV_ANT = 102
|
| 132 |
+
CTX_RH_S_FRONT_INF = 103
|
| 133 |
+
CTX_RH_S_FRONT_MIDDLE = 104
|
| 134 |
+
CTX_RH_S_FRONT_SUP = 105
|
| 135 |
+
CTX_RH_S_INTRAPARIET_AND_P_TRANS = 106
|
| 136 |
+
CTX_RH_S_OC_TEMP_LAT = 107
|
| 137 |
+
CTX_RH_S_OC_TEMP_MED_AND_LINGUAL = 108
|
| 138 |
+
CTX_RH_S_ORBITAL_H_SHAPED = 109
|
| 139 |
+
CTX_RH_S_ORBITAL_LATERAL = 110
|
| 140 |
+
CTX_RH_S_ORBITAL_MED_OLFACT = 111
|
| 141 |
+
CTX_RH_S_PERICALLOSAL = 112
|
| 142 |
+
CTX_RH_S_POSTCENTRAL = 113
|
| 143 |
+
CTX_RH_S_PRECENTRAL_INF_PART = 114
|
| 144 |
+
CTX_RH_S_PRECENTRAL_SUP_PART = 115
|
| 145 |
+
CTX_RH_S_SUBORBITAL = 116
|
| 146 |
+
CTX_RH_S_SUBPARIETAL = 117
|
| 147 |
+
CTX_RH_S_TEMPORAL_INF = 118
|
| 148 |
+
CTX_RH_S_TEMPORAL_SUP = 119
|
| 149 |
+
CTX_RH_S_TEMPORAL_TRANSVERSE = 120
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Lobes(EnumWithUnknown):
    """Maps the Desikan-Killany Atlas regions to lobes.

    Many members intentionally share the same integer value: every cortical
    region belonging to a lobe maps to that lobe's single id.  Assuming
    ``EnumWithUnknown`` is an ``enum.Enum`` subclass, members after the first
    with a given value become aliases of it, so ``Lobes(5)`` resolves to the
    first left-frontal member.  Look-ups by *name* (one per region) are the
    intended access pattern.
    """
    UNKNOWN = 0

    ## Amygdala (Left, Right)
    LEFT_AMYGDALA = 1
    RIGHT_AMYGDALA = 2

    ## Hippocampus (Left, Right)
    LEFT_HIPPOCAMPUS = 3
    RIGHT_HIPPOCAMPUS = 4

    ## Frontal Lobe (Left) -- all share lobe id 5.
    CTX_LH_SUPERIORFRONTAL = 5
    CTX_LH_ROSTRALMIDDLEFRONTAL = 5
    CTX_LH_CAUDALMIDDLEFRONTAL = 5
    CTX_LH_PARSOPERCULARIS = 5
    CTX_LH_PARSORBITALIS = 5
    CTX_LH_PARSTRIANGULARIS = 5
    CTX_LH_LATERALORBITOFRONTAL = 5
    CTX_LH_MEDIALORBITOFRONTAL = 5
    CTX_LH_PRECENTRAL = 5
    CTX_LH_PARACENTRAL = 5

    ## Frontal Cortex (Right) -- all share lobe id 6.
    CTX_RH_SUPERIORFRONTAL = 6
    CTX_RH_ROSTRALMIDDLEFRONTAL = 6
    CTX_RH_CAUDALMIDDLEFRONTAL = 6
    CTX_RH_PARSOPERCULARIS = 6
    CTX_RH_PARSORBITALIS = 6
    CTX_RH_PARSTRIANGULARIS = 6
    CTX_RH_LATERALORBITOFRONTAL = 6
    CTX_RH_MEDIALORBITOFRONTAL = 6
    CTX_RH_PRECENTRAL = 6
    CTX_RH_PARACENTRAL = 6
    # Frontal pole should go here in the future

    ## Parietal Lobe (Left) -- all share lobe id 7.
    CTX_LH_SUPERIORPARIETAL = 7
    CTX_LH_INFERIORPARIETAL = 7
    CTX_LH_SUPRAMARGINAL = 7
    CTX_LH_POSTCENTRAL = 7
    CTX_LH_PRECUNEUS = 7

    ## Parietal Lobe (Right) -- all share lobe id 8.
    CTX_RH_SUPERIORPARIETAL = 8
    CTX_RH_INFERIORPARIETAL = 8
    CTX_RH_SUPRAMARGINAL = 8
    CTX_RH_POSTCENTRAL = 8
    CTX_RH_PRECUNEUS = 8

    ## Temporal Lobe (Left) -- all share lobe id 9.
    CTX_LH_SUPERIORTEMPORAL = 9
    CTX_LH_MIDDLETEMPORAL = 9
    CTX_LH_INFERIORTEMPORAL = 9
    CTX_LH_BANKSSTS = 9
    CTX_LH_FUSIFORM = 9
    CTX_LH_TRANSVERSETEMPORAL = 9
    CTX_LH_ENTORHINAL = 9
    CTX_LH_TEMPORALPOLE = 9
    CTX_LH_PARAHIPPOCAMPAL = 9

    ## Temporal Lobe (Right) -- all share lobe id 10.
    CTX_RH_SUPERIORTEMPORAL = 10
    CTX_RH_MIDDLETEMPORAL = 10
    CTX_RH_INFERIORTEMPORAL = 10
    CTX_RH_BANKSSTS = 10
    CTX_RH_FUSIFORM = 10
    CTX_RH_TRANSVERSETEMPORAL = 10
    CTX_RH_ENTORHINAL = 10
    CTX_RH_TEMPORALPOLE = 10
    CTX_RH_PARAHIPPOCAMPAL = 10

    ## Occipital Lobe (Left) - ENUM 11 RESERVED

    ## Occipital Lobe (Right) - ENUM 12 RESERVED

    ## Cingulate (Left) -- all share id 13.
    CTX_LH_ROSTRALANTERIORCINGULATE = 13
    CTX_LH_CAUDALANTERIORCINGULATE = 13
    CTX_LH_POSTERIORCINGULATE = 13
    CTX_LH_ISTHMUSCINGULATE = 13

    ## Cingulate (Right) -- all share id 14.
    CTX_RH_ROSTRALANTERIORCINGULATE = 14
    CTX_RH_CAUDALANTERIORCINGULATE = 14
    CTX_RH_POSTERIORCINGULATE = 14
    CTX_RH_ISTHMUSCINGULATE = 14

    ## Insula (Left, Right)
    CTX_LH_INSULA = 15
    CTX_RH_INSULA = 16

    ## Putamen (Left, Right)
    LEFT_PUTAMEN = 17
    RIGHT_PUTAMEN = 18

    ## Ventricles (Left, Right)
    LEFT_INF_LAT_VENT = 19
    RIGHT_INF_LAT_VENT = 20
|
barista/data/available_sessions.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
def enumval_formatter(subject, trial_list):
    """Build session keys of the form ``S<subject>_<trial>`` for each trial."""
    return [f"S{subject}_{t}" for t in trial_list]
|
| 5 |
+
|
| 6 |
+
def holdout_enumval_formatter(subject, trial_list):
    """Build held-out session keys of the form ``HS<subject>_<trial>``."""
    return [f"HS{subject}_{t}" for t in trial_list]
|
| 8 |
+
|
| 9 |
+
class BrainTreebankAvailableSessions(Enum):
    """Enumerates the BrainTreebank subject/trial sessions used by this project.

    Each member's value is a *list* of session keys (``S<subject>_<trial>`` or
    ``HS<subject>_<trial>`` for held-out trials).  Because lists are unhashable,
    value-based lookup (``BrainTreebankAvailableSessions([...])``) is not
    available; members are accessed by name.
    """
    SUBJ_1: list = enumval_formatter("1", ["0", "2"])
    SUBJ_2: list = enumval_formatter("2", ["0", "1", "2", "3", "4", "5"])
    SUBJ_3: list = enumval_formatter("3", ["1", "2"])
    SUBJ_4: list = enumval_formatter("4", ["1", "2"])
    SUBJ_5: list = enumval_formatter("5", ["0"])
    SUBJ_6: list = enumval_formatter("6", ["0", "1"])
    SUBJ_7: list = enumval_formatter("7", ["1"])
    SUBJ_8: list = enumval_formatter("8", ["0"])
    SUBJ_9: list = enumval_formatter("9", ["0"])
    SUBJ_10: list = enumval_formatter("10", ["1"])

    ## Heldout trials.
    # NOTE: subjects 5, 8, and 9 have no held-out trial (each exposes a
    # single trial above).
    HOLDSUBJ_1: list = holdout_enumval_formatter("1", ["1"])
    HOLDSUBJ_2: list = holdout_enumval_formatter("2", ["6"])
    HOLDSUBJ_3: list = holdout_enumval_formatter("3", ["0"])
    HOLDSUBJ_4: list = holdout_enumval_formatter("4", ["0"])
    HOLDSUBJ_6: list = holdout_enumval_formatter("6", ["4"])
    HOLDSUBJ_7: list = holdout_enumval_formatter("7", ["0"])
    HOLDSUBJ_10: list = holdout_enumval_formatter("10", ["0"])
|
barista/data/braintreebank_data_helpers.py
ADDED
|
@@ -0,0 +1,741 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Code to handle data I/O, parsing, and data/feature preprocessing for the BrainTreebank dataset.
|
| 2 |
+
|
| 3 |
+
Functionality in this module is based on the implementations found in the following
|
| 4 |
+
repositories, but have been modified as needed to be used as outlined in the BaRISTA paper:
|
| 5 |
+
https://github.com/czlwang/BrainBERT/tree/master/data
|
| 6 |
+
https://github.com/czlwang/PopulationTransformer/tree/main/data
|
| 7 |
+
https://github.com/czlwang/brain_treebank_code_release/tree/master/data
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
from collections import OrderedDict
|
| 13 |
+
from enum import Enum
|
| 14 |
+
from typing import Dict, List, Union
|
| 15 |
+
|
| 16 |
+
import h5py
|
| 17 |
+
import numpy as np
|
| 18 |
+
import ordered_set
|
| 19 |
+
import pandas as pd
|
| 20 |
+
import scipy
|
| 21 |
+
import sklearn.preprocessing as sk_preprocessing
|
| 22 |
+
|
| 23 |
+
# Data frame column IDs for *_timings.csv and features.csv files.
_START_COL = "start"  # word start (movie time; compared against trigger movie_time)
_END_COL = "end"  # word end (movie time)
_LBL_COL = "pos"  # presumably part-of-speech label -- see features.csv schema
_TRIG_TIME_COL = "movie_time"  # movie time at which a trigger was sent
_START_WALLTIME = "start_time"  # wall-clock time at which a trigger was sent
_TRIG_IDX_COL = "index"  # neural-data sample index at trigger onset
_EST_IDX_COL = "est_idx"  # estimated first neural sample of a word (computed here)
_EST_END_IDX_COL = "est_end_idx"  # estimated last neural sample of a word (computed here)
_WORD_TIME_COL = "word_time"
_WORD_TEXT_COL = "text"
_IS_ONSET_COL = "is_onset"
_IS_OFFSET_COL = "is_offset"

# Data frame column IDs elec_coords_full.csv file.
_ELECTRODE_INFO = "Electrode"  # electrode channel label column
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class BrainTreebankDatasetNames(Enum):
    """Dataset modes: pretraining plus the downstream decoding tasks."""

    PRETRAIN = "pretrain"

    ## Random splits downstream tasks.
    SENTENCE_ONSET = "sentence_onset"
    SPEECH_VS_NONSPEECH = "speech_vs_nonspeech"

    ## Chronological split downstream tasks.
    SENTENCE_ONSET_TIME = "sentence_onset_time"
    SPEECH_VS_NONSPEECH_TIME = "speech_vs_nonspeech_time"
    VOLUME = "volume"
    OPTICAL_FLOW = "optical_flow"

    @classmethod
    def get_modes(cls, modes_str: Union[str, List[str]]):
        """Convert a mode string (or a list of them) into enum member(s)."""
        if isinstance(modes_str, str):
            return cls(modes_str)
        return [cls(entry) for entry in modes_str]

    def get_abbrv(self, c=1) -> str:
        """Abbreviate the mode name: first ``c`` chars of each '_'-separated part."""
        return "".join(part[:c] for part in self.value.split("_"))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class BrainTreebankDatasetPathManager:
    """Manage file paths for the Brain Treebank dataset.

    Expected dataset directory structure:
    braintreebank_data
    |__ corrupted_elec.json
    |__ clean_laplacian.json
    |__ all_subject_data
    |   |__ sub_<s>_trial00<t>.h5 ...
    |__ electrode_labels
    |   |__ sub_<s>/electrode_labels.json ...
    |__ localization
    |   |__ elec_coords_full.csv
    |   |__ sub_<s>/depth-wm.csv ...
    |__ subject_metadata
    |   |__ sub_<s>_trial00<t>_metadata.json ...
    |__ subject_timings
    |   |__ sub_<s>_trial00<t>_timings.csv ...
    |__ transcripts
        |__ <movie>/features.csv ...
    """

    def __init__(self, dataset_dir: str):
        self.dataset_dir = dataset_dir

        # Path template for the neural data h5 file.
        self.neural_data_file = os.path.join(
            self.dataset_dir,
            "all_subject_data",
            "sub_{}_trial00{}.h5",
        )

        # Path template for electrode channel name meta information.
        self.raw_electrodes_meta_file = os.path.join(
            self.dataset_dir, "electrode_labels", "sub_{}", "electrode_labels.json"
        )

        # Path template for the brain regions csv file.
        self.regions_file = os.path.join(
            self.dataset_dir, "localization", "sub_{}", "depth-wm.csv"
        )

        # Path template for trial movie trigger times, used to align
        # features with neural activity.
        self.movie_triggers_file = os.path.join(
            self.dataset_dir, "subject_timings", "sub_{}_trial00{}_timings.csv"
        )

        # Path template for trial meta information.
        self.trial_meta = os.path.join(
            self.dataset_dir, "subject_metadata", "sub_{}_trial00{}_metadata.json"
        )

        # Path template for the extracted features csv file (keyed by movie).
        self.features_file = os.path.join(
            self.dataset_dir, "transcripts", "{}", "features.csv"
        )

        self._CORRUPTED_ELECTRODES_PATH = os.path.join(
            self.dataset_dir, "corrupted_elec.json"
        )
        self._CLEAN_LAPLACIAN = os.path.join(
            self.dataset_dir, "clean_laplacian.json"
        )

    def format_subject(self, subject: str) -> str:
        """AvailableSessions stores subjects as SUBJ_#. Strips 'SUBJ' prefix here."""
        return subject.split("_")[-1]

    def format_session(self, session: str) -> str:
        """AvailableSessions stores subject sessions with a prefix as (H)S_#. Strips prefix here."""
        return session.split("_")[-1]

    def get_raw_data_filepath(self, subject: str, session: str) -> str:
        """Get raw data file path for a given subject and trial.

        Args:
            subject: subject str e.g. SUBJ_1 (or bare "1")
            session: trial str e.g. S1_0 (or bare "0")
        """
        return self.neural_data_file.format(
            self.format_subject(subject), self.format_session(session)
        )

    def get_raw_electrode_channel_names_filepath(self, subject: str) -> str:
        """Path to the subject's electrode_labels.json file."""
        return self.raw_electrodes_meta_file.format(self.format_subject(subject))

    def get_localization_filepath(self, subject: str) -> str:
        """Path to the subject's depth-wm.csv localization file."""
        return self.regions_file.format(self.format_subject(subject))

    def get_noise_area_filepath(self) -> str:
        """Path to the dataset-level corrupted_elec.json file."""
        return self._CORRUPTED_ELECTRODES_PATH

    def get_clean_laplacian_filepath(self) -> str:
        """Path to the dataset-level clean_laplacian.json file."""
        return self._CLEAN_LAPLACIAN

    def get_movie_triggers_filepath(self, subject: str, trial: str) -> str:
        """Path to the subject/trial timings csv file."""
        return self.movie_triggers_file.format(
            self.format_subject(subject), self.format_session(trial)
        )

    # BUGFIX: this method was annotated ``-> str`` but it has always returned
    # a (filepath, movie title) tuple; the annotation now matches the behavior.
    def get_features_filepath(self, subject: str, trial: str) -> tuple:
        """Resolve the features.csv path for the movie watched in this trial.

        Reads the trial metadata json to discover which movie was shown.

        Returns:
            (features_csv_path, movie_title) tuple.
        """
        with open(
            self.trial_meta.format(
                self.format_subject(subject), self.format_session(trial)
            ),
            "r",
        ) as f:
            meta_dict = json.load(f)
            title = meta_dict["title"]
            movie_id = meta_dict["filename"]

        print(f"Loading features for movie {title}.")
        return self.features_file.format(movie_id), title
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class BrainTreebankDatasetRawDataHelper:
|
| 214 |
+
"""Manages loading data from the BrainTreebank dataset files.
|
| 215 |
+
|
| 216 |
+
Check each method docstring for file information.
|
| 217 |
+
"""
|
| 218 |
+
    def __init__(
        self,
        path_manager: BrainTreebankDatasetPathManager,
        samp_frequency: int = 2048,
    ):
        """Initialize the raw-data helper.

        Args:
            path_manager: resolves all dataset file paths.
            samp_frequency: neural recording sampling rate in Hz (default 2048).
        """
        self.path_manager = path_manager
        self.samp_frequency = samp_frequency
        # Lazily-populated caches: subject -> localization DataFrame, and
        # timings-file basename -> triggers DataFrame (see
        # get_channel_localization_raw / _get_trial_triggers).
        self.localization_df = {}
        self.trial_triggers_cache = {}
|
| 227 |
+
|
| 228 |
+
    def get_raw_file(
        self,
        subject: str,
        trial: str,
    ) -> dict:
        """Load a subject/trial's raw neural recordings from its h5 file.

        Args:
            subject: str. Subject to index by (e.g. "SUBJ_1").
            trial: str. Subject trial to index by (e.g. "S1_0").

        Returns:
            A dictionary containing the following keys:
                data: np.ndarray (n_samples x channels) -- actual recordings
                time: np.ndarray (n_samples) -- timestamps when movie trigger times recorded
                samp_frequency: sampling rate Hz
                electrode_info: list of channel names; indices are in order of columns in data
        """
        path = self.path_manager.get_raw_data_filepath(subject, trial)
        with h5py.File(path, "r") as hf:
            raw_data = hf["data"]

            channel_labels = self.get_electrode_info(subject)

            raw_data_n_channels = len(raw_data.keys())
            if subject == "SUBJ_1" or subject == "HOLDSUBJ_1":
                raw_data_n_channels -= 1  # Will ignore last channel for subject 1 based on dataset author's comment
            assert (
                len(channel_labels) == raw_data_n_channels
            ), "Channel count mismatch between h5 and json."

            # Extracts a numpy array from h5 dataset (may take a few minutes).
            # Dataset names follow the pattern "electrode_<i>", matching the
            # order of channel_labels.
            electrode_data = []
            for i in range(len(channel_labels)):
                electrode_data.append(raw_data[f"electrode_{i}"][:])

            # Stacked as channels x samples; transposed below for the caller.
            electrode_data = np.stack(electrode_data)

            return {
                "data": electrode_data.T,  # n_samples x n_channels
                "time": self._extract_neural_timestamps(subject, trial, electrode_data),
                "samp_frequency": self.samp_frequency,
                "electrode_info": channel_labels,
            }
|
| 272 |
+
|
| 273 |
+
def get_corrupted_elecs(self, subject: str) -> List[str]:
|
| 274 |
+
"""
|
| 275 |
+
Returns:
|
| 276 |
+
a list of strings corresponding to corrupted electrode channel names.
|
| 277 |
+
"""
|
| 278 |
+
with open(self.path_manager.get_noise_area_filepath(), "r") as f:
|
| 279 |
+
corrupted_elecs = json.load(f)
|
| 280 |
+
return corrupted_elecs[f"subject{self.path_manager.format_subject(subject)}"]
|
| 281 |
+
|
| 282 |
+
def get_clean_elecs(self, subject: str) -> List[str]:
|
| 283 |
+
"""
|
| 284 |
+
Returns:
|
| 285 |
+
a list of strings corresponding to clean electrode channel names.
|
| 286 |
+
"""
|
| 287 |
+
with open(self.path_manager.get_clean_laplacian_filepath(), "r") as f:
|
| 288 |
+
elecs = json.load(f)
|
| 289 |
+
return elecs[f"sub_{self.path_manager.format_subject(subject)}"]
|
| 290 |
+
|
| 291 |
+
def _elec_name_strip(self, x):
|
| 292 |
+
return x.replace("*", "").replace("#", "").replace("_", "")
|
| 293 |
+
|
| 294 |
+
def get_electrode_info(self, subject: str) -> List[str]:
|
| 295 |
+
"""
|
| 296 |
+
Returns list of electrodes for the specified trial.
|
| 297 |
+
NOTE: the order of these labels is important. Their position corresponds with a row in data.h5
|
| 298 |
+
"""
|
| 299 |
+
with open(
|
| 300 |
+
self.path_manager.get_raw_electrode_channel_names_filepath(subject), "r"
|
| 301 |
+
) as f:
|
| 302 |
+
electrode_labels = json.load(f)
|
| 303 |
+
|
| 304 |
+
electrode_labels = [self._elec_name_strip(e) for e in electrode_labels]
|
| 305 |
+
return electrode_labels
|
| 306 |
+
|
| 307 |
+
def get_channel_localization_raw(self, subject: str) -> dict:
|
| 308 |
+
# Lazy loading.
|
| 309 |
+
if subject not in self.localization_df:
|
| 310 |
+
df = pd.read_csv(self.path_manager.get_localization_filepath(subject))
|
| 311 |
+
df[_ELECTRODE_INFO] = df[_ELECTRODE_INFO].apply(self._elec_name_strip)
|
| 312 |
+
self.localization_df[subject] = df
|
| 313 |
+
return self.localization_df[subject]
|
| 314 |
+
|
| 315 |
+
    def get_channel_localization(
        self, subject: str, channel_name: str
    ) -> dict:
        """Extract localization information for given subject and channel label.

        Channel localization info is a pandas DataFrame with the headers:
            ID: electrode channel ID
            Z: Z coordinate (subject specific, to the best of our understanding)
            X: X coordinate (subject specific, to the best of our understanding)
            Y: Y coordinate (subject specific, to the best of our understanding)
            Hemisphere: 0 (right) vs 1 (left)
            Subject: sub_<id>
            Electrode: Electrode channel label
            Region: region based on Destrieux atlas

        NOTE: https://surfer.nmr.mgh.harvard.edu/fswiki/CorticalParcellation

        Returns:
            Empty dict if ``channel_name`` is not found; otherwise a dict with:
                hemi: hemisphere ("L"/"R"/"UNKNOWN")
                region_info: Destrieux parcel info
                channel_stem: electrode name with trailing digits removed
                coords: LIP coords as an int64 numpy array
        """
        df = self.get_channel_localization_raw(subject)
        channel_row = df.loc[df[_ELECTRODE_INFO] == channel_name]

        if len(channel_row) == 0:
            return {}

        def parse_region_str(region_str):
            # Destrieux strings come in three shapes:
            #   "ctx_lh_<region>"  -> underscore-separated, hemi in part 1
            #   "Left-<region>"    -> hyphen-separated, hemi in part 0
            #   "unknown"          -> no localization available
            if "_" in region_str:
                split_region_str = region_str.split("_")
                hemi = "L" if split_region_str[1].lower() == "lh" else "R"
                region_info = "_".join(split_region_str[2:])
            elif "-" in region_str and "_" not in region_str:
                split_region_str = region_str.split("-")
                hemi = "L" if split_region_str[0].lower() == "left" else "R"
                region_info = split_region_str[-1]
            elif region_str.lower() == "unknown":
                hemi = "UNKNOWN"
                region_info = "UNKNOWN"
            else:
                raise ValueError(f"Unsupported region_str: {region_str}.")
            return hemi, region_info

        hemi, region_info = parse_region_str(channel_row.iloc[0]["Destrieux"])
        channel_stem, _ = BrainTreebankDatasetRawDataHelper.stem_electrode_name(
            channel_name
        )
        # NOTE(review): columns "L", "I", "P" are assumed to hold the LIP
        # coordinates as whole numbers -- astype(int64) truncates otherwise.
        coords = channel_row.iloc[0][["L", "I", "P"]].to_numpy().astype(np.int64)
        return {
            "hemi": hemi,
            "region_info": region_info,
            "channel_stem": channel_stem,
            "coords": coords,
        }
|
| 372 |
+
|
| 373 |
+
@classmethod
|
| 374 |
+
def stem_electrode_name(cls, name):
|
| 375 |
+
"""Need to stem the electrode channel names to find neighbors.
|
| 376 |
+
|
| 377 |
+
Functionality from the BrainBERT repository:
|
| 378 |
+
https://github.com/czlwang/BrainBERT/tree/master/data
|
| 379 |
+
"""
|
| 380 |
+
# names look like 'O1aIb4', 'O1aIb5', 'O1aIb6', 'O1aIb7'
|
| 381 |
+
# names look like 'T1b2
|
| 382 |
+
name = name.replace("*", "") # some stems have * in name
|
| 383 |
+
found_stem_end = False
|
| 384 |
+
stem, num = [], []
|
| 385 |
+
for c in reversed(name):
|
| 386 |
+
if c.isalpha():
|
| 387 |
+
found_stem_end = True
|
| 388 |
+
if found_stem_end:
|
| 389 |
+
stem.append(c)
|
| 390 |
+
else:
|
| 391 |
+
num.append(c)
|
| 392 |
+
return "".join(reversed(stem)), int("".join(reversed(num)))
|
| 393 |
+
|
| 394 |
+
@classmethod
|
| 395 |
+
def get_all_laplacian_electrodes(cls, elec_list):
|
| 396 |
+
"""Select for channels that have neighbors needed for Laplacian rereferencing.
|
| 397 |
+
|
| 398 |
+
Functionality from the BrainBERT repository:
|
| 399 |
+
https://github.com/czlwang/BrainBERT/tree/master/data
|
| 400 |
+
"""
|
| 401 |
+
stems = [
|
| 402 |
+
BrainTreebankDatasetRawDataHelper.stem_electrode_name(e) for e in elec_list
|
| 403 |
+
]
|
| 404 |
+
|
| 405 |
+
def has_nbrs(stem, stems):
|
| 406 |
+
(x, y) = stem
|
| 407 |
+
return ((x, y + 1) in stems) and ((x, y - 1) in stems)
|
| 408 |
+
|
| 409 |
+
laplacian_stems = [x for x in stems if has_nbrs(x, stems)]
|
| 410 |
+
electrodes = [f"{x}{y}" for (x, y) in laplacian_stems]
|
| 411 |
+
return electrodes
|
| 412 |
+
|
| 413 |
+
def _get_trial_triggers(self, subject: str, trial: str) -> pd.DataFrame:
|
| 414 |
+
"""
|
| 415 |
+
Returns:
|
| 416 |
+
a pandas DataFrame with the following column headers:
|
| 417 |
+
type: trigger type
|
| 418 |
+
movie_time: movie time at which trigger was sent
|
| 419 |
+
start_time: wall clock time at which trigger was sent
|
| 420 |
+
end_time: wall clock time at which trigger concluded
|
| 421 |
+
trig_type: type of trigger token sent (movie beginning/end/pause/unpause)
|
| 422 |
+
index: neural data samples that recorded the beginning of the trigger
|
| 423 |
+
diff: ??
|
| 424 |
+
"""
|
| 425 |
+
movie_triggers_fpath = self.path_manager.get_movie_triggers_filepath(
|
| 426 |
+
subject, trial
|
| 427 |
+
)
|
| 428 |
+
triggers_cache_key = os.path.basename(movie_triggers_fpath)
|
| 429 |
+
# Use lazy loading of movie triggers to save on compute in the future.
|
| 430 |
+
if triggers_cache_key in self.trial_triggers_cache:
|
| 431 |
+
df = self.trial_triggers_cache[triggers_cache_key]
|
| 432 |
+
else:
|
| 433 |
+
df = pd.read_csv(movie_triggers_fpath)
|
| 434 |
+
self.trial_triggers_cache[triggers_cache_key] = df
|
| 435 |
+
return df
|
| 436 |
+
|
| 437 |
+
    def _get_trial_features(self, subject: str, trial: str) -> pd.DataFrame:
        """Load the trial's word/scene features and align them to neural samples.

        NOTE: the original annotation said ``List[Dict]``; the method returns
        a pandas DataFrame (improved annotation only, behavior unchanged).

        Returns:
            a pandas DataFrame with the following column headers:
                'bin_head',
                'charecter_num',
                'delta_magnitude',
                'delta_mel',
                'delta_pitch',
                'delta_rms',
                'deprel',
                'end',
                'est_idx', = estimated first neural sample
                'est_end_idx', = estimated last neural sample
                'face_num',
                'gpt2_surprisal',
                'head',
                'idx_in_sentence',
                'is_onset',
                'lemma',
                'magnitude',
                'max_global_angle',
                'max_global_magnitude',
                'max_mean_magnitude',
                'max_mean_pixel_brightness',
                'max_mean_pixel_difference',
                'max_median_magnitude',
                'max_vector_angle',
                'max_vector_magnitude',
                'mean_pixel_brightness',
                'mel',
                'min_mean_pixel_brightness',
                'min_mean_pixel_difference',
                'onset_diff',
                'phoneme_num',
                'pitch',
                'pos',
                'prev_word_idx',
                'rms',
                'sentence',
                'sentence_idx',
                'speaker',
                'start',
                'syllable',
                'text',
                'word_diff',
                'word_idx',
                'word_length'

        See dataset technical paper for full explanation: https://braintreebank.dev/.
        """
        features_filename, movie_title = self.path_manager.get_features_filepath(
            subject, trial
        )

        # features.csv carries its index in an unnamed first column.
        df = pd.read_csv(features_filename).set_index("Unnamed: 0")
        df = df.dropna().reset_index(drop=True)  # Drop rows with NaN word times.
        trig_df = self._get_trial_triggers(subject, trial)
        df = self._add_estimated_sample_index(df, trig_df)
        df = df.dropna().reset_index(drop=True)  # Drop rows with NaN sample times.
        return df
|
| 498 |
+
|
| 499 |
+
def get_features(
|
| 500 |
+
self, subject: str, trial: str, feature_name: str, n_samples: int
|
| 501 |
+
) -> np.ndarray:
|
| 502 |
+
df = self._get_trial_features(subject, trial)
|
| 503 |
+
|
| 504 |
+
if feature_name == "volume":
|
| 505 |
+
feature_vals = df.rms
|
| 506 |
+
elif (
|
| 507 |
+
feature_name == "sentence_onset"
|
| 508 |
+
or feature_name == "sentence_onset_time"
|
| 509 |
+
):
|
| 510 |
+
feature_vals = df.is_onset
|
| 511 |
+
elif (
|
| 512 |
+
feature_name == "speech_vs_nonspeech"
|
| 513 |
+
or feature_name == "speech_vs_nonspeech_time"
|
| 514 |
+
):
|
| 515 |
+
feature_vals = np.ones(df.size)
|
| 516 |
+
elif feature_name == "optical_flow":
|
| 517 |
+
feature_vals = df.max_global_magnitude
|
| 518 |
+
else:
|
| 519 |
+
raise ValueError(f"Unsupported feature_name: {feature_name}")
|
| 520 |
+
|
| 521 |
+
label_intervals = list(zip(df[_EST_IDX_COL].array, df[_EST_END_IDX_COL].array))
|
| 522 |
+
label_init = lambda x: (
|
| 523 |
+
0
|
| 524 |
+
if x
|
| 525 |
+
in [
|
| 526 |
+
"speech_vs_nonspeech",
|
| 527 |
+
"speech_vs_nonspeech_time",
|
| 528 |
+
"sentence_onset",
|
| 529 |
+
"sentence_onset_time",
|
| 530 |
+
]
|
| 531 |
+
else np.nan
|
| 532 |
+
)
|
| 533 |
+
labels = np.ones(n_samples) * label_init(feature_name)
|
| 534 |
+
for label_ind, label_interval in enumerate(label_intervals):
|
| 535 |
+
if feature_name != "sentence_onset" and feature_name != "sentence_onset_time":
|
| 536 |
+
labels[int(label_interval[0]) : int(label_interval[1])] = feature_vals[
|
| 537 |
+
label_ind
|
| 538 |
+
]
|
| 539 |
+
else:
|
| 540 |
+
# sentence_onset has to only handle putting labels for onset words
|
| 541 |
+
labels[int(label_interval[0]) : int(label_interval[1])] = (
|
| 542 |
+
1 if feature_vals[label_ind] else np.nan
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
return labels, label_intervals
|
| 546 |
+
|
| 547 |
+
def _estimate_sample_index(self, t, near_t, near_trig):
|
| 548 |
+
"""Estimates the word onset data sample by interpolation from nearest trigger.
|
| 549 |
+
|
| 550 |
+
Source:
|
| 551 |
+
quickstart.ipynb notebook on https://braintreebank.dev/
|
| 552 |
+
|
| 553 |
+
Args:
|
| 554 |
+
t - word movie time
|
| 555 |
+
near_t - nearest trigger movie time
|
| 556 |
+
near_trig - nearest trigger sample index
|
| 557 |
+
|
| 558 |
+
Returns:
|
| 559 |
+
Estimated word onset sample index.
|
| 560 |
+
"""
|
| 561 |
+
trig_diff = (t - near_t) * self.samp_frequency
|
| 562 |
+
return round(near_trig + trig_diff)
|
| 563 |
+
|
| 564 |
+
    def _add_estimated_sample_index(self, w_df, t_df):
        """Computes and adds data sample indices to annotated movie word onsets.

        Source:
            quickstart.ipynb notebook on https://braintreebank.dev/

        Args:
            w_df - movie annotated words data frame
            t_df - computer triggers data frame

        Returns:
            Movie annotated words data frame augmented with estimated data sample indices
        """
        tmp_w_df = w_df.copy(deep=True)
        # NOTE(review): assumes t_df has a contiguous 0..len-1 index (uses
        # len(t_df) - 1 as the last row label) — confirm upstream reset_index.
        last_t = t_df.loc[len(t_df) - 1, _TRIG_TIME_COL]
        for i, t, endt in zip(w_df.index, w_df[_START_COL], w_df[_END_COL]):
            if t > last_t:  # If movie continues after triggers
                break

            # Find nearest movie time index for start.
            idx = (abs(t_df[_TRIG_TIME_COL] - t)).idxmin()
            # NOTE(review): this row copy looks redundant since tmp_w_df is
            # already a deep copy of w_df — confirm before removing.
            tmp_w_df.loc[i, :] = w_df.loc[i, :]
            tmp_w_df.loc[i, _EST_IDX_COL] = self._estimate_sample_index(
                t, t_df.loc[idx, _TRIG_TIME_COL], t_df.loc[idx, _TRIG_IDX_COL]
            )

            # Find nearest movie time index for end.
            end_idx = (abs(t_df[_TRIG_TIME_COL] - endt)).idxmin()
            tmp_w_df.loc[i, _EST_END_IDX_COL] = self._estimate_sample_index(
                endt,
                t_df.loc[end_idx, _TRIG_TIME_COL],
                t_df.loc[end_idx, _TRIG_IDX_COL],
            )

        return tmp_w_df
|
| 599 |
+
|
| 600 |
+
def _extract_neural_timestamps(self, subject: str, trial: str, data: np.ndarray):
|
| 601 |
+
"""Extracts wall clock timestamps associated with recorded triggers.
|
| 602 |
+
|
| 603 |
+
NOTE: Not all samples will have a timestamp.
|
| 604 |
+
"""
|
| 605 |
+
t_df = self._get_trial_triggers(subject, trial)
|
| 606 |
+
timestamps = np.ones(data.shape[-1]) * np.nan
|
| 607 |
+
for sample_index, sample_walltime in zip(
|
| 608 |
+
t_df[_TRIG_IDX_COL], t_df[_START_WALLTIME]
|
| 609 |
+
):
|
| 610 |
+
timestamps[int(sample_index)] = sample_walltime
|
| 611 |
+
return timestamps
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
class BrainTreebankDatasetPreprocessor:
    """Helper class to preprocess the raw BrainTreebank neural data.

    Recommended flow:
        filter_data -> rereference

    filter_data() currently performs:
        notch filtering

    Functionality partially utilizes implementations from the BrainBERT repository:
    https://github.com/czlwang/BrainBERT/tree/master/data
    """

    def __init__(self, config: Dict):
        # Config must expose `samp_frequency` (attribute access), e.g. a
        # DictConfig — presumably Hz; confirm against the dataset config.
        self.config = config

        # For notch filtering: 60 Hz mains fundamental plus harmonics.
        self.freqs_to_filter = [60, 120, 180, 240, 300, 360]

    def notch_filter(self, data: np.ndarray, freq: float, Q: float = 30) -> np.ndarray:
        """Notch filters input data along the time (last) axis.

        Args:
            data: np.ndarray shape (n_channels, n_samples).
            freq: Center frequency (Hz) to reject.
            Q: Quality factor of the notch; higher -> narrower notch.

        Returns filtered signal with the same shape as the input.
        """
        # iirnotch expects the frequency normalized to the Nyquist rate.
        w0 = freq / (self.config.samp_frequency / 2)
        b, a = scipy.signal.iirnotch(w0, Q)
        # NOTE: lfilter is causal (adds phase delay); filtfilt would be
        # zero-phase but would change the existing preprocessing output.
        y = scipy.signal.lfilter(b, a, data, axis=-1)
        return y

    def filter_data(self, data_arr: np.ndarray) -> np.ndarray:
        """Filters data based on provided config.

        Applies one notch filter per frequency in `self.freqs_to_filter`.

        Args:
            data_arr: np.ndarray shape (n_channels, n_samples).

        Returns filtered signal.
        """
        for f in self.freqs_to_filter:
            data_arr = self.notch_filter(data_arr, f)
        return data_arr

    def _get_all_adj_electrodes(
        self, selected_electrodes: List[str], all_electrodes: List[str]
    ):
        """Extracts all adjacent electrodes to use with Laplacian rereferencing.

        Args:
            selected_electrodes: labels of the electrodes to rereference.
            all_electrodes: labels of all recorded electrodes (index order
                matches rows of the full data array).

        Returns:
            elec2neighbors_dict: OrderedDict mapping each selected electrode
                label to its two shaft-neighbor labels.
            neighbor_label2id: dict mapping each unique neighbor label to its
                row index in the full data array.

        Raises:
            AssertionError: if a selected electrode does not have exactly two
                neighbors on its shaft.
        """
        # Build a set once so the per-electrode membership tests below are
        # O(1) instead of O(n_electrodes) list scans.
        all_electrode_stems = {
            BrainTreebankDatasetRawDataHelper.stem_electrode_name(l)
            for l in all_electrodes
        }

        elec2neighbors_dict, unique_neighbors = OrderedDict(), ordered_set.OrderedSet()
        for selected_electrode in selected_electrodes:
            stem, num = BrainTreebankDatasetRawDataHelper.stem_electrode_name(
                selected_electrode
            )
            nbrs = [
                n
                for n in [(stem, num - 1), (stem, num + 1)]
                if n in all_electrode_stems
            ]

            assert len(nbrs) == 2, "Neighbors must be 2 for Laplacian rereferencing."

            elec2neighbors_dict[selected_electrode] = [
                e_stem + str(num_stem) for (e_stem, num_stem) in nbrs
            ]
            unique_neighbors.update(elec2neighbors_dict[selected_electrode])

        neighbor_label2id = {
            elec: all_electrodes.index(elec) for elec in unique_neighbors
        }
        return elec2neighbors_dict, neighbor_label2id

    def _laplacian_rereference(
        self,
        selected_data: np.ndarray,
        selected_electrodes: List[str],
        all_data: np.ndarray,
        all_electrodes: List[str],
    ):
        """Subtracts the mean of each electrode's two shaft neighbors.

        Args:
            selected_data: np.ndarray shape (n_selected_channels, n_samples), corresponding
                to the selected electrodes.
            selected_electrodes: List[str], labels corresponding to selected electrodes
                (e.g., "clean" electrodes).
            all_data: np.ndarray shape (n_total_channels, n_samples).
            all_electrodes: List[str], labels corresponding to all electrodes.

        Returns:
            Rereferenced data, same shape as `selected_data`.
        """
        elec2neighbors_dict, neighbor_label2id = self._get_all_adj_electrodes(
            selected_electrodes, all_electrodes
        )

        # Gather neighbor signals: (n_selected, 2, n_samples).
        selected_neighbor_data = [
            [
                all_data[neighbor_label2id[nghbr_elec], ...]
                for nghbr_elec in elec2neighbors_dict[elec]
            ]
            for elec in selected_electrodes
        ]
        selected_neighbor_data = np.array(selected_neighbor_data)
        # Neighbors must go through the same notch filtering as the selected
        # channels before subtraction.
        selected_neighbor_data = self.filter_data(selected_neighbor_data)

        assert selected_data.shape == (
            selected_neighbor_data.shape[0],
            selected_neighbor_data.shape[-1],
        )
        # Average across the neighbor axis, then subtract.
        ref_data = selected_data - np.mean(selected_neighbor_data, axis=1)
        return ref_data

    def rereference_data(self, **rereference_kwargs) -> np.ndarray:
        """Rereferences electrode data based on provided reference electrodes.

        Check _laplacian_rereference() above for required arguments.
        """
        data = self._laplacian_rereference(**rereference_kwargs)
        return data

    def zscore_data(self, data: np.ndarray) -> np.ndarray:
        """Z-scores each channel (row) across the time axis."""
        # StandardScaler normalizes columns, so transpose in and out.
        data = (
            sk_preprocessing.StandardScaler(with_mean=True, with_std=True)
            .fit_transform(data.T)
            .T
        )
        return data
|
barista/data/braintreebank_dataset.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict, defaultdict, namedtuple
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
from typing import List, Optional, Union
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
from barista.data.braintreebank_wrapper import BrainTreebankWrapper
|
| 8 |
+
from omegaconf import DictConfig, OmegaConf
|
| 9 |
+
from torch.utils.data import DataLoader, Dataset
|
| 10 |
+
|
| 11 |
+
# Per-sample metadata attached to every DataPoint.
# NOTE(review): the namedtuple typename is "Metadata" while the variable is
# DatapointMetadata, so reprs show "Metadata(...)" — confirm before renaming
# (saved/pickled instances would break).
DatapointMetadata = namedtuple(
    "Metadata",
    ["subject_session", "subject"],
)

# One dataset item: signal `x`, its `label`, and a DatapointMetadata.
# All three fields default to None.
DataPoint = namedtuple(
    "DataPoint",
    ["x", "label", "metadata"],
    defaults=(None,) * 3
)

# Collated batch: `x` holds one stacked tensor per subject_session;
# `labels` and `subject_sessions` are aligned across the retained samples.
BatchItem = namedtuple(
    "BatchItem",
    [
        "x",
        "labels",
        "subject_sessions",
    ],
)

# Torch version with any local build suffix removed
# (e.g. "2.1.0+cu118" -> "2.1.0"); used for version-gated torch.load behavior.
torch_version = torch.__version__.split("+")[0]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class BrainTreebankDataset(Dataset):
    def __init__(
        self,
        config: "Union[OmegaConf, DictConfig]",
        max_cache_size: int = 5000,
        include_subject_sessions: Optional[List[str]] = None,
        exclude_subject_sessions: Optional[List[str]] = None,
    ):
        """BrainTreebank Dataset class.

        Args:
            config: OmegaConf or DictConfig.
            max_cache_size: int. The segment cache size to use to avoid
                reloading segments.
            include_subject_sessions: Optional list of str corresponding to
                the subject_sessions to keep/use in the dataset. When empty
                or None, falls back to `config.finetune_sessions`.
            exclude_subject_sessions: Optional list of str corresponding to
                the subject_sessions to discard/not use in the dataset.

        Note:
            Defaults are None rather than mutable `[]` literals (shared
            across calls); the falsy check below treats both identically.
        """
        self.config = config

        self.dataset = BrainTreebankWrapper(config)
        self.metadata = self.dataset.metadata
        if self.config.get("shuffle_dataloader", True):
            print("Shuffling metadata.")
            self.metadata.shuffle()

        if not include_subject_sessions:
            print(
                f"Including only finetune sessions specified in config: {config.finetune_sessions}"
            )
            include_subject_sessions = list(config.finetune_sessions)

        self._reduce_metadata(
            subject_sessions=include_subject_sessions,
            keep=True
        )

        if exclude_subject_sessions:
            self._reduce_metadata(
                subject_sessions=exclude_subject_sessions,
                keep=False
            )

        self.max_cache_size = max_cache_size
        # Insertion-ordered cache of loaded segment files; the oldest entry
        # is evicted when the cache is full (FIFO, not true LRU).
        self.data_cache = OrderedDict()

    @staticmethod
    def _supports_weights_only(version: str) -> bool:
        """True if this torch version's `torch.load` accepts `weights_only`.

        Compares numeric (major, minor, patch) components instead of raw
        string comparison, which misorders versions (e.g. "2.10.0" < "2.2.1"
        lexicographically).

        Args:
            version: torch version string with any "+local" suffix removed.
        """
        parts = []
        for token in version.split(".")[:3]:
            digits = ""
            for ch in token:
                if ch.isdigit():
                    digits += ch
                else:
                    break  # stop at pre-release suffixes like "0a0"
            parts.append(int(digits) if digits else 0)
        return tuple(parts) > (2, 2, 1)

    def check_no_common_segment(self, train_dataset, val_dataset, test_dataset):
        """Double checking paths for no overlap in splits.

        Raises:
            AssertionError: if any segment path appears in two splits.
        """
        train_paths = set(train_dataset.dataset.metadata.get_unique_values_in_col("path"))
        val_paths = set(val_dataset.dataset.metadata.get_unique_values_in_col("path"))
        test_paths = set(test_dataset.dataset.metadata.get_unique_values_in_col("path"))

        assert not train_paths.intersection(test_paths)
        assert not train_paths.intersection(val_paths)
        assert not val_paths.intersection(test_paths)

    def _reduce_metadata(self, subject_sessions: List[str], keep=True):
        """Reduce metadata by either keeping OR discarding the specified subject_sessions.

        Args:
            subject_sessions: list of str corresponding to subject session identifiers.
            keep: bool. If true, keep the specified subject sessions, otherwise discard.
        """
        if not isinstance(subject_sessions, list):
            subject_sessions = [subject_sessions]

        # Single regex alternation so the reduction is one vectorized match.
        combined_pattern = "|".join(subject_sessions)

        self.metadata.reduce_based_on_col_value(
            col_name="subject_session",
            value=combined_pattern,
            regex=True,
            keep=keep,
        )

        summary_str = self.metadata.get_summary_str()
        print(f"Reduced dataset: {summary_str}")

    def set_split(self, split: str):
        """Restricts this dataset's metadata to one split ("train"/"val"/"test")."""
        self.metadata.reduce_based_on_col_value(col_name="split", value=split)

    def get_dataloader(self, split: str, train_config: "Union[DictConfig, OmegaConf]"):
        """Returns a DataLoader over a deep copy of this dataset restricted to `split`.

        Args:
            split: one of "train", "val", "test".
            train_config: config providing `dataloader.*` settings.
        """
        split_dataset = deepcopy(self)
        split_dataset.set_split(split=split)

        if split == "test":
            # Don't drop any samples for test for consistency across different batch size.
            drop_last = False
        elif split == "train":
            drop_last = train_config.dataloader.drop_last
        else:  # split == "val"
            drop_last = train_config.dataloader.get(
                "drop_last_val",
                train_config.dataloader.drop_last
            )

        return DataLoader(
            split_dataset,
            batch_size=train_config.dataloader.batch_size,
            collate_fn=collate_with_metadata_fn_group_subjects,
            num_workers=train_config.dataloader.num_workers,
            persistent_workers=train_config.dataloader.persistent_workers,
            pin_memory=train_config.dataloader.pin_memory,
            drop_last=drop_last,
        )

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """Loads (with caching) the segment for `idx` and returns a DataPoint.

        Raises:
            ValueError: if the fallback label stored in the segment file is None.
        """
        meta_row = self.metadata[idx]
        segment_path = meta_row["path"]

        if segment_path not in self.data_cache:
            data_file = torch.load(
                segment_path,
                weights_only=self._supports_weights_only(torch_version),
            )
            if len(self.data_cache) >= self.max_cache_size:
                # Evict the oldest cached segment to bound memory use.
                first_path = next(iter(self.data_cache))
                self.data_cache.pop(first_path)
            self.data_cache[segment_path] = data_file

        else:
            data_file = self.data_cache[segment_path]

        metadata = DatapointMetadata(
            subject_session=meta_row.subject_session,
            subject=meta_row.subject,
        )

        # Prefer an explicit per-row label; otherwise look it up in the
        # segment file by experiment name.
        if "label" in meta_row and not pd.isna(meta_row.label):
            label = torch.tensor((meta_row.label,))
        else:
            label = data_file[meta_row.experiment]
            if label is None:
                raise ValueError("Label cannot be None in the data_file.")

        datapoint = DataPoint(
            x=data_file["x"],
            label=label,
            metadata=metadata,
        )
        return datapoint
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def collate_with_metadata_fn_group_subjects(batch: "List[DataPoint]"):
    """Collates a batch into per-session groups of stacked tensors.

    Datapoints whose signal is entirely zero are dropped. `x` in the returned
    BatchItem is a list with one stacked tensor per subject_session present in
    the batch. `labels` is stacked into a single tensor only when every
    retained label has the same shape; otherwise it stays a list.

    Args:
        batch: list of DataPoint namedtuples from BrainTreebankDataset.

    Returns:
        BatchItem(x, labels, subject_sessions).
    """
    x, labels, subject_sessions = [], [], []
    labels_dims, labels_seq_lens = [], []

    # Bucket datapoint indices by their subject_session, preserving order.
    session_to_indices = defaultdict(list)
    for i, datapoint in enumerate(batch):
        session_to_indices[datapoint.metadata.subject_session].append(i)

    for indices in session_to_indices.values():
        session_x = []
        for i in indices:
            datapoint = batch[i]

            # Skip all zero sessions
            if torch.all(datapoint.x == 0):
                continue

            session_x.append(datapoint.x)
            labels.append(datapoint.label)

            subject_sessions.append(datapoint.metadata.subject_session)

            labels_dims.append(datapoint.label.shape[-1])
            labels_seq_lens.append(datapoint.label.shape[0])

        if session_x:
            x.append(torch.stack(session_x, dim=0))

    # Guard against an empty result (every datapoint filtered out) before
    # touching labels_dims[0]; stack only when label shapes are uniform.
    if (
        labels
        and all(d == labels_dims[0] for d in labels_dims)
        and all(s == labels_seq_lens[0] for s in labels_seq_lens)
    ):
        labels = torch.stack(labels, dim=0)

    return BatchItem(
        x=x,
        labels=labels,
        subject_sessions=subject_sessions,
    )
|
barista/data/braintreebank_dataset_spatial_groupings.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
|
| 5 |
+
import barista.data.atlas as atlas_enums
|
| 6 |
+
from barista.data.metadata_spatial_groups import (
|
| 7 |
+
MetadataSpatialGroupRow,
|
| 8 |
+
SpatialGroupingName,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
# Per-axis size bound for raw coordinate components; assumes electrode
# coordinates fall in [0, 200) — TODO confirm against the dataset.
XYZ_MAX = 200
|
| 12 |
+
|
| 13 |
+
class BrainTreebankSpatialGroupingsHelper:
    """
    Helper class to generate spatial groups rows

    Creating new spatial groupings should be added here.
    """

    def __init__(self, config, dataset_name: str):
        self.config = config
        self.dataset_name = dataset_name

    def get_spatial_groupings(
        self,
        subject: str,
        session: str,
        coords: List[Tuple],
        localization: pd.DataFrame,
    ) -> "List[MetadataSpatialGroupRow]":
        """Builds one MetadataSpatialGroupRow per configured spatial grouping.

        Args:
            subject: subject identifier.
            session: session identifier.
            coords: per-channel coordinate tuples (presumably (L, I, P) —
                confirm against the localization file columns).
            localization: localization dataframe, one row per channel.

        Returns:
            List of MetadataSpatialGroupRow, one per entry in
            config.spatial_groupings_to_create.

        Raises:
            NotImplementedError: for unrecognized grouping names.
        """
        rows = []
        for spatial_grouping in self.config.spatial_groupings_to_create:
            sg = SpatialGroupingName(spatial_grouping)
            if sg == SpatialGroupingName.COORDS:
                # Raw coordinates: each of the 3 axes is its own component.
                group_components = coords
                n_effective_components = 3
                max_elements_for_component = (XYZ_MAX, XYZ_MAX, XYZ_MAX)
                padding_indices = (None, None, None)

            elif sg == SpatialGroupingName.DESTRIEUX:
                (
                    group_components,
                    n_effective_components,
                    max_elements_for_component,
                    padding_indices,
                ) = self._get_grouping_based_on_loc_file(
                    subject=subject,
                    coords=coords,
                    localization=localization,
                    localization_col="Destrieux",
                    enum_class=atlas_enums.Destrieux,
                )

            elif sg == SpatialGroupingName.LOBES:
                (
                    group_components,
                    n_effective_components,
                    max_elements_for_component,
                    padding_indices,
                ) = self._get_grouping_based_on_loc_file(
                    subject=subject,
                    coords=coords,
                    localization=localization,
                    localization_col="DesikanKilliany",
                    enum_class=atlas_enums.Lobes,
                )

            else:
                raise NotImplementedError()

            group_ids = self._get_group_ids_based_on_group_components(
                group_components, n_effective_components
            )

            assert len(max_elements_for_component) >= n_effective_components
            assert len(padding_indices) >= n_effective_components

            row = MetadataSpatialGroupRow(
                dataset=self.dataset_name,
                subject=subject,
                session=session,
                subject_session=f"{subject}_{session}",
                name=sg.value,
                n_effective_components=n_effective_components,
                max_elements_for_component=max_elements_for_component,
                padding_indices=padding_indices,
                group_components=group_components,
                group_ids=group_ids,
            )
            rows.append(row)
        return rows

    def _get_grouping_based_on_loc_file(
        self,
        subject: str,
        coords: List[Tuple],
        localization: pd.DataFrame,
        localization_col: str,
        enum_class,
    ):
        """Maps each channel coordinate to an atlas region via the localization file.

        Args:
            subject: subject identifier (used in error reporting).
            coords: per-channel coordinate tuples matching (L, I, P) columns.
            localization: dataframe with L/I/P columns plus `localization_col`.
            localization_col: column holding the atlas region name.
            enum_class: atlas enum exposing `get_enum` and an UNKNOWN member.

        Returns:
            Tuple of (group_components, n_effective_components,
            max_elements_for_component, padding_indices).

        Raises:
            ValueError: when a channel coordinate has no localization row.
        """
        # Build the coordinate -> row lookup once; the previous per-channel
        # linear scan over the dataframe was O(n_channels * n_rows).
        # setdefault keeps the FIRST row for duplicate coordinates, matching
        # the old first-match-wins behavior.
        coord_to_row = {}
        for i in range(len(localization)):
            loc = localization.iloc[i]
            coord_to_row.setdefault((loc.L, loc.I, loc.P), loc)

        group_components = []
        for coord in coords:
            loc = coord_to_row.get(coord)
            if loc is None:
                raise ValueError(
                    f"Channel not found in localization file for {subject}"
                )
            # Normalize region labels (e.g. "G-temp-sup" -> "G_TEMP_SUP") to
            # match the enum member naming convention.
            identifier_value = loc[localization_col].replace("-", "_").upper()
            enum_i = enum_class.get_enum(identifier_value)
            group_components.append((enum_i.value, identifier_value))

        # A single effective component: the atlas region id.
        max_elements_for_component = (max(v.value for v in enum_class) + 1,)
        padding_indices = (enum_class.UNKNOWN.value,)
        n_effective_components = 1

        return (
            group_components,
            n_effective_components,
            max_elements_for_component,
            padding_indices,
        )

    def _get_group_ids_based_on_group_components(
        self, group_components: List[Tuple], n_effective_components: int
    ) -> List[int]:
        """Assigns a dense integer id to each distinct component prefix.

        Args:
            group_components: per-channel component tuples.
            n_effective_components: number of leading components that define
                group identity. (Parameter name fixes the previous
                "componetns" typo; the internal call site is positional.)

        Returns:
            Per-channel list of group ids, numbered in first-seen order.
        """
        groups_to_id_mapping = dict()
        group_id = 0
        group_ids = []
        for components in group_components:
            group = components[:n_effective_components]
            if group not in groups_to_id_mapping:
                chan_group_id = group_id
                groups_to_id_mapping[group] = group_id
                group_id += 1
            else:
                chan_group_id = groups_to_id_mapping[group]
            group_ids.append(chan_group_id)

        return group_ids
|
barista/data/braintreebank_wrapper.py
ADDED
|
@@ -0,0 +1,1186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Code to handle preprocessing, segmenting and labeling the BrainTreebank dataset.
|
| 2 |
+
|
| 3 |
+
Preprocessing and segmentation functionality is based on the implementations found in the
|
| 4 |
+
following repositories, but has been modified as needed to be used for the evaluation scheme
|
| 5 |
+
outlined in the BaRISTA paper:
|
| 6 |
+
https://github.com/czlwang/BrainBERT/tree/master/data
|
| 7 |
+
https://github.com/czlwang/PopulationTransformer/tree/main/data
|
| 8 |
+
https://github.com/czlwang/brain_treebank_code_release/tree/master/data
|
| 9 |
+
"""
|
| 10 |
+
import dataclasses
|
| 11 |
+
import einops
|
| 12 |
+
import hashlib
|
| 13 |
+
import numpy as np
|
| 14 |
+
from omegaconf import DictConfig, OmegaConf
|
| 15 |
+
import os
|
| 16 |
+
import pandas as pd
|
| 17 |
+
import pickle
|
| 18 |
+
import torch
|
| 19 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
from barista.data.available_sessions import BrainTreebankAvailableSessions
|
| 22 |
+
from barista.data.braintreebank_data_helpers import (
|
| 23 |
+
BrainTreebankDatasetNames,
|
| 24 |
+
BrainTreebankDatasetPathManager,
|
| 25 |
+
BrainTreebankDatasetPreprocessor,
|
| 26 |
+
BrainTreebankDatasetRawDataHelper,
|
| 27 |
+
)
|
| 28 |
+
from barista.data.braintreebank_dataset_spatial_groupings import (
|
| 29 |
+
BrainTreebankSpatialGroupingsHelper,
|
| 30 |
+
)
|
| 31 |
+
from barista.data.metadata import Metadata, MetadataRow, MetadataSpatialGroupRow
|
| 32 |
+
from barista.data.splitter import Splitter
|
| 33 |
+
from barista.data.fileprogresstracker import FileProgressTracker
|
| 34 |
+
|
| 35 |
+
# Native sampling rate of the BrainTreebank recordings; resampling to any other
# rate is not yet implemented (see _process_single_session_raw_data).
_DEFAULT_FS = 2048  # Hz


# Torch version string with any local-build suffix (after "+") stripped.
# NOTE(review): unused in this module's visible code — presumably consumed
# elsewhere; confirm before removing.
torch_version = torch.__version__.split("+")[0]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class BrainTreebankWrapper:
|
| 42 |
+
    def __init__(self, config: Union[DictConfig, OmegaConf], only_segment_generation=False):
        """Prepare the BrainTreebank dataset end to end.

        Stage 1 processes the raw recordings (filtering, re-referencing) unless
        processed files already exist or ``config.force_reprocess_stage1`` is set.
        Stage 2 cuts the processed recordings into segments and writes the
        metadata CSV.

        Args:
            config: OmegaConf/DictConfig with dataset, preprocessing, and
                segmentation settings.
            only_segment_generation: Forwarded to ``process_segments``;
                presumably restricts stage 2 to segment generation only —
                TODO confirm against ``process_segments``.
        """
        self.config = config

        self._setup_helpers()

        self.spatial_groups_helper = BrainTreebankSpatialGroupingsHelper(
            self.config, dataset_name=self.name
        )

        # Hash string identifier corresponding to the preprocessing config used.
        self.segments_processing_str, self.segments_processing_hash_str = (
            self._get_segments_processing_hash(
                segment_length_s=self.config.segment_length_s,
            )
        )

        # Raw data processing (e.g., filtering).
        if not self._is_raw_data_processed() or self.config.force_reprocess_stage1:
            print(
                "Processed raw dataset does not exist or reprocessing is enabled, processing starts."
            )
            self._process_raw_data()
            print(f"Raw data processing complete: {self._processed_raw_data_dir}")
        else:
            print("Processed raw data exists")

        # Processing of segments from processed raw data
        os.makedirs(self._processed_segments_data_dir, exist_ok=True)

        self.metadata = self._load_metadata()

        # Empty the metadata since segments do not exist
        # NOTE(review): the metadata loaded just above is immediately replaced
        # here; unless _load_metadata has side effects the load is redundant —
        # confirm before simplifying.
        self.metadata = self._initialize_metadata()

        # Process the segments now
        self.process_segments(only_segment_generation)
        print(f"Segments are processed and ready to use. Metadata path: {self.metadata_path}")
|
| 79 |
+
|
| 80 |
+
    @property
    def name(self) -> str:
        """Canonical dataset identifier used in metadata rows and cache paths."""
        return "BrainTreebank"
|
| 83 |
+
|
| 84 |
+
@property
|
| 85 |
+
def available_sessions(self) -> Dict[str, List]:
|
| 86 |
+
return {
|
| 87 |
+
k.name: k.value
|
| 88 |
+
for k in BrainTreebankAvailableSessions
|
| 89 |
+
if not self.config.subjects_to_process
|
| 90 |
+
or k.name in self.config.subjects_to_process
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
    @property
    def experiment(self):
        """Experiment/task name, taken verbatim from the config."""
        return self.config.experiment
|
| 96 |
+
|
| 97 |
+
@property
|
| 98 |
+
def metadata_path(self):
|
| 99 |
+
return os.path.join(
|
| 100 |
+
self.config.save_dir,
|
| 101 |
+
self.experiment,
|
| 102 |
+
f"metadata_{self.segments_processing_hash_str}.csv",
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
    def _setup_helpers(self):
        """Instantiate the path/IO/preprocessing helpers and the split manager.

        Must run before anything touching ``available_sessions``-dependent
        state: the Splitter below reads the available subjects.
        """
        # Resolves dataset file locations from the configured root directory.
        self.path_manager = BrainTreebankDatasetPathManager(
            dataset_dir=self.config.dataset_dir,
        )
        # Reads raw HDF5/trigger files and channel localization tables.
        self.raw_data_helper = BrainTreebankDatasetRawDataHelper(self.path_manager)
        # Applies filtering, re-referencing, and z-scoring.
        self.raw_data_preprocessor = BrainTreebankDatasetPreprocessor(self.config)
        # Enum-like dataset name derived from the experiment string.
        self.experiment_dataset_name = BrainTreebankDatasetNames.get_modes(
            self.config.experiment
        )

        # Falls back to the native 2048 Hz rate when not configured.
        self.samp_frequency = self.config.get("samp_frequency", _DEFAULT_FS)
        self.splitter = Splitter(
            config=self.config,
            subjects=list(self.available_sessions.keys()),
            experiment=self.experiment,
            use_fixed_seed=self.config.use_fixed_seed_for_splitter,
        )
|
| 122 |
+
|
| 123 |
+
    def _process_raw_data(self):
        """Run stage-1 raw processing for every subject/session.

        The per-session worker (_process_single_session_raw_data) performs its
        own skip/force-reprocess and cache-symlink checks, so it is called for
        every session; the existence check here only drives the log message.
        """
        os.makedirs(self._processed_raw_data_dir, exist_ok=True)

        for subject in self.available_sessions.keys():
            print(f"Raw data processing for subject {subject} starts.")

            sessions_count = len(self.available_sessions[subject])
            for i, session in enumerate(self.available_sessions[subject]):
                processed_file_path = self._get_processed_raw_data_file_path(
                    subject=subject, session=session
                )
                if os.path.exists(processed_file_path):
                    print(
                        f"Skipping session {session} ({i+1}/{sessions_count}), "
                        f"processed raw data exists in {processed_file_path}."
                    )
                else:
                    print(
                        f"Processing session {session} ({i+1}/{sessions_count})..."
                    )

                # NOTE(review): invoked unconditionally so that
                # config.force_reprocess_stage1 (honored inside the worker) can
                # still rebuild existing files — confirm against original intent.
                self._process_single_session_raw_data(
                    subject=subject, session=session
                )
|
| 147 |
+
|
| 148 |
+
    def _process_single_session_raw_data(self, subject: str, session: str):
        """Filter and re-reference one session's raw data, then persist it.

        The processed tensor dict is written to a shared cache location and
        symlinked into the per-run save path; when the cache location is not
        writable, the file is written directly to the save path instead.

        Args:
            subject: Subject name.
            session: Session name.
        """
        save_path = self._get_processed_raw_data_file_path(
            subject=subject, session=session
        )
        cache_dir, cache_path = self._get_processed_raw_data_file_path_cache(
            subject=subject, session=session
        )

        # Fast exits: reuse an existing output or link to an existing cache,
        # unless a full stage-1 reprocess was requested.
        if not self.config.force_reprocess_stage1:
            if os.path.isfile(save_path):
                print(f"Skipping raw processing for {subject} {session}")
                return

            if os.path.isfile(cache_path):
                print(
                    f"Making symlink for raw processed file for {subject} {session}"
                )
                os.symlink(src=cache_path, dst=save_path)
                return

        raw_data_dict = self.raw_data_helper.get_raw_file(subject, session)
        electrodes = raw_data_dict["electrode_info"]

        ## Clean the electrodes based on corrupted channel meta information.
        selected_electrodes = self.raw_data_helper.get_clean_elecs(subject)
        # Every clean electrode must actually exist in this session's data.
        assert len(set(selected_electrodes).intersection(set(electrodes))) == len(
            selected_electrodes
        )

        selected_elecs_inds = [
            i for i, e in enumerate(electrodes) if e in selected_electrodes
        ]
        electrode_data = raw_data_dict["data"][:, np.array(selected_elecs_inds)]
        electrode_data = (
            electrode_data.T
        )  # Preprocessor requires (n_channels, n_samples)

        ## Resample the data if self.samp_frequency != default_fs
        if self.samp_frequency != _DEFAULT_FS:
            raise NotImplementedError(
                f"Resampling {self.name} dataset not yet supported."
            )

        ## Filter the data (e.g., notch).
        electrode_data = self.raw_data_preprocessor.filter_data(electrode_data)

        ## Do rerefencing.
        # Re-referencing may need the unselected channels too, hence the full
        # data/electrode arrays are passed alongside the cleaned selection.
        electrode_data = self.raw_data_preprocessor.rereference_data(
            selected_data=electrode_data,
            selected_electrodes=selected_electrodes,
            all_data=raw_data_dict["data"].T,
            all_electrodes=raw_data_dict["electrode_info"],
        )

        save_dict = dict(
            data=torch.tensor(electrode_data.T),  # (n_samples, n_channels)
            time=torch.tensor(raw_data_dict["time"]),
            samp_frequency=self.samp_frequency,
            electrode_info=selected_electrodes,
        )

        # Prefer writing to the shared cache and symlinking into place; fall
        # back to a direct write if the cache location is unusable.
        try:
            os.makedirs(cache_dir, exist_ok=True)
            torch.save(save_dict, cache_path)
            print(f"Raw processed file created in {cache_path}")
            os.symlink(src=cache_path, dst=save_path)
            print(f"Raw processed file symlink created in {save_path}")
        except (OSError, PermissionError, FileNotFoundError):
            torch.save(save_dict, save_path)
            print(f"Raw processed file created in {save_path}")
|
| 218 |
+
|
| 219 |
+
def _is_raw_data_processed(self):
|
| 220 |
+
if not os.path.exists(self._processed_raw_data_dir):
|
| 221 |
+
return False
|
| 222 |
+
|
| 223 |
+
files_exist = []
|
| 224 |
+
for subject in self.available_sessions.keys():
|
| 225 |
+
for session in self.available_sessions[subject]:
|
| 226 |
+
path = self._get_processed_raw_data_file_path(
|
| 227 |
+
subject=subject, session=session
|
| 228 |
+
)
|
| 229 |
+
files_exist.append(os.path.exists(path))
|
| 230 |
+
return np.array(files_exist).all()
|
| 231 |
+
|
| 232 |
+
def _get_file_progress_tracker_save_path(self, subject: str, session: str) -> str:
|
| 233 |
+
filename = f"{subject}_{session}_processing_status.json"
|
| 234 |
+
return os.path.join(self._processed_segments_data_dir, filename)
|
| 235 |
+
|
| 236 |
+
def _get_channels_region_info(
|
| 237 |
+
self,
|
| 238 |
+
subject: str,
|
| 239 |
+
electrode_info: List[str],
|
| 240 |
+
) -> List[Tuple]:
|
| 241 |
+
"""
|
| 242 |
+
Generate a list of Channels each including region information of the channel.
|
| 243 |
+
"""
|
| 244 |
+
channels, coords, channel_inds_to_remove = [], [], []
|
| 245 |
+
for channel_ind, channel_name in enumerate(electrode_info):
|
| 246 |
+
localization_info = self.raw_data_helper.get_channel_localization(
|
| 247 |
+
subject, channel_name
|
| 248 |
+
)
|
| 249 |
+
if not localization_info:
|
| 250 |
+
raise ValueError(
|
| 251 |
+
f"Couldn't found elec {channel_name} for subject {subject}"
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
assert (
|
| 255 |
+
"coords" in localization_info
|
| 256 |
+
), "localization_info incomplete, missing coords"
|
| 257 |
+
coord = localization_info.pop("coords")
|
| 258 |
+
|
| 259 |
+
## Remove channels from regions specified in the config file.
|
| 260 |
+
if self.config.region_filtering.active:
|
| 261 |
+
match = False
|
| 262 |
+
for filtered_region in self.config.region_filtering.filters:
|
| 263 |
+
component_info = localization_info['region_info']
|
| 264 |
+
match = filtered_region.lower() in component_info.lower()
|
| 265 |
+
if match:
|
| 266 |
+
break
|
| 267 |
+
|
| 268 |
+
if match:
|
| 269 |
+
channel_inds_to_remove.append(channel_ind)
|
| 270 |
+
continue
|
| 271 |
+
|
| 272 |
+
coords.append((coord[0], coord[1], coord[2]))
|
| 273 |
+
channels.append((
|
| 274 |
+
localization_info['hemi'],
|
| 275 |
+
localization_info['region_info'],
|
| 276 |
+
localization_info['channel_stem'],
|
| 277 |
+
))
|
| 278 |
+
|
| 279 |
+
return channels, coords, channel_inds_to_remove
|
| 280 |
+
|
| 281 |
+
def _create_spatial_groupings(
|
| 282 |
+
self, subject: str, session: str, coords: List[Tuple]
|
| 283 |
+
):
|
| 284 |
+
localization = self.raw_data_helper.get_channel_localization_raw(subject)
|
| 285 |
+
rows = self.spatial_groups_helper.get_spatial_groupings(
|
| 286 |
+
subject,
|
| 287 |
+
session,
|
| 288 |
+
coords,
|
| 289 |
+
localization,
|
| 290 |
+
)
|
| 291 |
+
for row in rows:
|
| 292 |
+
self.metadata.add_spatial_group(row)
|
| 293 |
+
print(f"Add spatial group {row.name} for {row.subject_session}")
|
| 294 |
+
|
| 295 |
+
self.metadata.save(self.metadata_path)
|
| 296 |
+
|
| 297 |
+
def _spatial_groupings_exist_for_subject(self, subject: str, session: str):
|
| 298 |
+
for spatial_grouping in self.config.spatial_groupings_to_create:
|
| 299 |
+
sg = self.metadata.get_spatial_grouping(
|
| 300 |
+
subject_session=f"{subject}_{session}", name=spatial_grouping
|
| 301 |
+
)
|
| 302 |
+
if sg is None:
|
| 303 |
+
return False
|
| 304 |
+
return True
|
| 305 |
+
|
| 306 |
+
def _save_segment(
|
| 307 |
+
self,
|
| 308 |
+
subject: str,
|
| 309 |
+
session: str,
|
| 310 |
+
segment_data: torch.tensor,
|
| 311 |
+
segment_time: torch.tensor,
|
| 312 |
+
segment_labels: torch.tensor,
|
| 313 |
+
segment_id: int,
|
| 314 |
+
segment_seq_len: int,
|
| 315 |
+
file_progress_tracker: FileProgressTracker,
|
| 316 |
+
is_last_segment: bool
|
| 317 |
+
) -> dict:
|
| 318 |
+
"""Process and save one segment to file."""
|
| 319 |
+
|
| 320 |
+
segment_data = {
|
| 321 |
+
"x": segment_data.float().clone(),
|
| 322 |
+
"timestamps": segment_time.clone(),
|
| 323 |
+
self.experiment: segment_labels.clone(),
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
segment_label = self._get_segment_label(segment_labels)
|
| 327 |
+
segment_filename = f"{subject}_{session}_{segment_id}.pt"
|
| 328 |
+
segment_path = os.path.join(self._processed_segments_data_dir, segment_filename)
|
| 329 |
+
torch.save(segment_data, segment_path)
|
| 330 |
+
|
| 331 |
+
meta_row = MetadataRow(
|
| 332 |
+
dataset=self.name,
|
| 333 |
+
subject=subject,
|
| 334 |
+
session=session,
|
| 335 |
+
subject_session=f"{subject}_{session}",
|
| 336 |
+
experiment=self.experiment,
|
| 337 |
+
seq_len=segment_seq_len,
|
| 338 |
+
d_input=np.prod(segment_data["x"].shape),
|
| 339 |
+
d_data=segment_data["x"].shape,
|
| 340 |
+
path=segment_path,
|
| 341 |
+
split="train",
|
| 342 |
+
filename=segment_filename,
|
| 343 |
+
processing_str=self.segments_processing_str,
|
| 344 |
+
label=segment_label,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
self.metadata.concat(pd.DataFrame([meta_row]))
|
| 348 |
+
|
| 349 |
+
if segment_id % self.config.processing_save_interval == 0 or is_last_segment:
|
| 350 |
+
self.metadata.save(self.metadata_path)
|
| 351 |
+
file_progress_tracker.update_last_file_ind(
|
| 352 |
+
file_ind=-1, ending_ind=-1, segment_id=segment_id
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
    def _create_segments_for_subject_session(
        self,
        subject: str,
        session: str,
        segment_length_s: int,
        file_progress_tracker: FileProgressTracker,
    ) -> int:
        """
        Cut one session's processed recording into labeled segments and save them.

        Args:
            subject: str. Subject name.
            session: str. Session name.
            segment_length_s: desired segment length in seconds
            file_progress_tracker: tracker of last segment info that is processed

        Returns:
            Number of newly added segments.
            NOTE(review): the loop below processes ``data.shape[0] - last_segment_id - 1``
            segments but ``data.shape[0] - last_segment_id`` is returned — this is
            consistent only if ``last_segment_id`` uses a -1 "nothing done" sentinel;
            confirm against FileProgressTracker.
        """
        processed_raw_data_path = self._get_processed_raw_data_file_path(
            subject=subject, session=session
        )
        preprocessed_data_dict = torch.load(processed_raw_data_path, weights_only=False)

        data = preprocessed_data_dict["data"].T  # (n_channels, n_samples)

        electrode_names = preprocessed_data_dict["electrode_info"]
        channels, coords, channel_inds_to_remove = self._get_channels_region_info(
            subject, electrode_names
        )
        assert len(electrode_names) - len(channel_inds_to_remove) == len(channels)

        if channel_inds_to_remove:  # Channels and coords already have these indices removed.
            print(
                f"Dropping {len(channel_inds_to_remove)} channels out of {len(electrode_names)} because missing."
            )
            # Drop the same indices from the data rows and the name list so all
            # three stay aligned with `channels`/`coords`.
            channels_to_keep = np.delete(
                np.arange(data.shape[0]), channel_inds_to_remove
            )
            data = data[channels_to_keep, ...]
            electrode_names = [
                electrode_names[i]
                for i in range(len(electrode_names))
                if i not in channel_inds_to_remove
            ]

        assert data.shape[0] == len(channels)

        # Spatial groupings are (re)registered even when segment processing is
        # already complete, so the metadata always carries them.
        self._create_spatial_groupings(subject, session, coords)

        if (
            file_progress_tracker.is_completed()
            and not self.config.force_reprocess_stage2
        ):
            return 0

        # Segment the neural activity data into segments of segment_length_s seconds.
        n_steps_in_one_segment = int(self.samp_frequency * segment_length_s)
        data, labels, data_sample_indices = self._get_experiment_data_and_labels(
            subject,
            session,
            data,
            n_steps_in_one_segment,
            time=preprocessed_data_dict["time"],
            samp_frequency=preprocessed_data_dict["samp_frequency"],
            electrode_info=preprocessed_data_dict["electrode_info"],
        )

        # Get the file index of previously processed files
        _, _, last_segment_id = file_progress_tracker.get_last_file_ind()

        print(
            f"{last_segment_id+1} segment(s) already processed for subject {subject} session {session}."
        )

        # Resume after the last saved segment.
        for segment_ind in range(last_segment_id + 1, data.shape[0]):
            segment_data = data[segment_ind, ...]  # (n_channels, segment_len)
            segment_label = labels[segment_ind, ...]

            # Normalize current segment
            segment_data = torch.tensor(
                self.raw_data_preprocessor.zscore_data(segment_data)
            )
            segment_data = segment_data.T  # (segment_len, n_channels)

            self._save_segment(
                subject,
                session=session,
                segment_data=segment_data,
                segment_time=data_sample_indices[segment_ind, ...],
                segment_labels=segment_label,
                segment_id=segment_ind,
                segment_seq_len=n_steps_in_one_segment,
                file_progress_tracker=file_progress_tracker,
                is_last_segment=(segment_ind == data.shape[0] - 1),
            )

        return data.shape[0] - last_segment_id
|
| 451 |
+
|
| 452 |
+
def _generate_segmented_data(
|
| 453 |
+
self,
|
| 454 |
+
data: torch.Tensor,
|
| 455 |
+
n_steps_in_one_segment: int,
|
| 456 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 457 |
+
"""
|
| 458 |
+
Segment data of shape (channels x time_samples) to (number_of_segments x channels x n_steps_in_one_segment).
|
| 459 |
+
It will truncate extra samples.
|
| 460 |
+
|
| 461 |
+
Returns segmented data and also indices corresponding to original data tensor.
|
| 462 |
+
"""
|
| 463 |
+
# Truncate time series to a divisible length by the desired window size.
|
| 464 |
+
cutoff_len = int(data.shape[-1] - data.shape[-1] % n_steps_in_one_segment)
|
| 465 |
+
data = data[..., :cutoff_len]
|
| 466 |
+
data_sample_indices = torch.arange(data.shape[-1])
|
| 467 |
+
data = einops.rearrange(data, "c (ns sl) -> ns c sl", sl=n_steps_in_one_segment)
|
| 468 |
+
data_sample_indices = data_sample_indices.reshape(
|
| 469 |
+
[-1, n_steps_in_one_segment]
|
| 470 |
+
) # (n_segments, segment_length)
|
| 471 |
+
|
| 472 |
+
return data, data_sample_indices
|
| 473 |
+
|
| 474 |
+
def _get_experiment_data_and_labels(
|
| 475 |
+
self,
|
| 476 |
+
subject: str,
|
| 477 |
+
session: str,
|
| 478 |
+
raw_data: torch.Tensor,
|
| 479 |
+
n_steps_in_one_segment: int,
|
| 480 |
+
**kwargs, ## Needed for child classes.
|
| 481 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 482 |
+
"""
|
| 483 |
+
Generate data and labels pairs. The data is reshaped to segments, which is done either by chunking
|
| 484 |
+
or by word-based segmenting based on the given experiment.
|
| 485 |
+
|
| 486 |
+
Args:
|
| 487 |
+
subject: str. Current data's subject name.
|
| 488 |
+
session: str. Current data's session name.
|
| 489 |
+
raw_data: a tensor of shape (n_channels x n_total_samples)
|
| 490 |
+
n_steps_in_one_segment: int. Number of samples we want in one segment.
|
| 491 |
+
|
| 492 |
+
Output:
|
| 493 |
+
data: a tensor of shape (n_segments x n_channels x n_steps_in_one_segment)
|
| 494 |
+
labels: a tensor of shape (n_segments x n_steps_in_one_segment)
|
| 495 |
+
data_sample_indices: a tensor of shape (n_segments x n_steps_in_one_segment)
|
| 496 |
+
containing indices of samples of the raw data each item in data corresponds to
|
| 497 |
+
"""
|
| 498 |
+
if self.experiment_dataset_name == self._pretrain_enum:
|
| 499 |
+
data, data_sample_indices = self._generate_segmented_data(
|
| 500 |
+
raw_data, n_steps_in_one_segment
|
| 501 |
+
)
|
| 502 |
+
labels = torch.tensor(np.ones_like(data_sample_indices) * np.nan) # dummy
|
| 503 |
+
return data, labels, data_sample_indices
|
| 504 |
+
|
| 505 |
+
# Get associated experiment labels
|
| 506 |
+
raw_labels, label_intervals = self.raw_data_helper.get_features(
|
| 507 |
+
subject, session, self.experiment, raw_data.shape[-1]
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
if (
|
| 511 |
+
self.experiment_dataset_name == BrainTreebankDatasetNames.SENTENCE_ONSET
|
| 512 |
+
or self.experiment_dataset_name
|
| 513 |
+
== BrainTreebankDatasetNames.SPEECH_VS_NONSPEECH
|
| 514 |
+
or self.experiment_dataset_name
|
| 515 |
+
== BrainTreebankDatasetNames.SENTENCE_ONSET_TIME
|
| 516 |
+
or self.experiment_dataset_name
|
| 517 |
+
== BrainTreebankDatasetNames.SPEECH_VS_NONSPEECH_TIME
|
| 518 |
+
):
|
| 519 |
+
data, labels, data_sample_indices = (
|
| 520 |
+
self._generate_data_and_labels_by_speech(
|
| 521 |
+
raw_data, n_steps_in_one_segment, raw_labels
|
| 522 |
+
)
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
elif (
|
| 526 |
+
self.experiment_dataset_name == BrainTreebankDatasetNames.VOLUME
|
| 527 |
+
or self.experiment_dataset_name == BrainTreebankDatasetNames.OPTICAL_FLOW
|
| 528 |
+
):
|
| 529 |
+
# label switch point will be the the neural activity index that corresponds to the word onset
|
| 530 |
+
label_switchpoints = np.array(
|
| 531 |
+
[elem[0] for elem in label_intervals], dtype=int
|
| 532 |
+
)
|
| 533 |
+
data, data_sample_indices, _ = self._generate_word_aligned_segments(
|
| 534 |
+
raw_data, n_steps_in_one_segment, label_switchpoints
|
| 535 |
+
)
|
| 536 |
+
# data_sample_indices are neural activity indice that corresponds to the segment start
|
| 537 |
+
# which is label switch points - segment len / 2 * sampling rate
|
| 538 |
+
|
| 539 |
+
start = (
|
| 540 |
+
int(data.shape[-1] / 2)
|
| 541 |
+
if self.config.trial_alignment == "center"
|
| 542 |
+
else 0
|
| 543 |
+
)
|
| 544 |
+
valid_label_switchpoints = data_sample_indices[start :: data.shape[-1]]
|
| 545 |
+
|
| 546 |
+
labels = raw_labels[valid_label_switchpoints]
|
| 547 |
+
labels = einops.repeat(labels, "n -> n l", l=data.shape[-1])
|
| 548 |
+
|
| 549 |
+
if self.config.quantile_numerical_labels.active:
|
| 550 |
+
labels = self._generate_quartile_labels(labels)
|
| 551 |
+
|
| 552 |
+
data_sample_indices = data_sample_indices.reshape(
|
| 553 |
+
(data.shape[0], data.shape[-1])
|
| 554 |
+
)
|
| 555 |
+
labels = torch.from_numpy(labels)
|
| 556 |
+
|
| 557 |
+
return data, labels, data_sample_indices
|
| 558 |
+
|
| 559 |
+
def _generate_data_and_labels_by_segments(
|
| 560 |
+
self,
|
| 561 |
+
raw_data: torch.Tensor,
|
| 562 |
+
n_steps_in_one_segment: int,
|
| 563 |
+
raw_labels: np.ndarray,
|
| 564 |
+
):
|
| 565 |
+
"""
|
| 566 |
+
Generate data and labels pairs by chunking the full session
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
raw_data: a tensor of shape (N_channels x N_total_samples)
|
| 570 |
+
n_steps_in_one_segment: number of samples we want in one segment
|
| 571 |
+
raw_labels: a numpy array of length N_total_samples containing labels
|
| 572 |
+
corresponding to each sample
|
| 573 |
+
|
| 574 |
+
Output:
|
| 575 |
+
data: a tensor of shape (N_segments x N_channels x n_steps_in_one_segment)
|
| 576 |
+
labels: a tensor of shape (N_segments x n_steps_in_one_segment)
|
| 577 |
+
data_sample_indices: a tensor of shape (N_segments x n_steps_in_one_segment)
|
| 578 |
+
containing indices of samples of the raw data each item in data corresponds to
|
| 579 |
+
"""
|
| 580 |
+
data, data_sample_indices = self._generate_segmented_data(
|
| 581 |
+
raw_data, n_steps_in_one_segment
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
# data: N x channels x n_steps_in_one_segment
|
| 585 |
+
cutoff_len = data.shape[0] * data.shape[-1]
|
| 586 |
+
|
| 587 |
+
labels = raw_labels[..., :cutoff_len]
|
| 588 |
+
labels = einops.rearrange(labels, "(ns sl) -> ns sl", sl=n_steps_in_one_segment)
|
| 589 |
+
|
| 590 |
+
assert labels.shape[0] == data.shape[0]
|
| 591 |
+
|
| 592 |
+
if self.config.quantile_numerical_labels.active:
|
| 593 |
+
labels = self._generate_quartile_labels(labels)
|
| 594 |
+
|
| 595 |
+
labels = torch.from_numpy(labels)
|
| 596 |
+
return data, labels, data_sample_indices
|
| 597 |
+
|
| 598 |
+
def _generate_quartile_labels(self, feature_values: np.ndarray) -> np.ndarray:
|
| 599 |
+
"""
|
| 600 |
+
Convert float labels based on quantile values: values in the top quantile will be assigned 1,
|
| 601 |
+
values in the bottom quantile will be assigned 0, and all others will be assigned NaN.
|
| 602 |
+
"""
|
| 603 |
+
valid_inds = ~np.isnan(feature_values)
|
| 604 |
+
lower_thresh, higher_thresh = np.quantile(
|
| 605 |
+
feature_values[valid_inds],
|
| 606 |
+
[
|
| 607 |
+
self.config.quantile_numerical_labels.lower_threshold,
|
| 608 |
+
self.config.quantile_numerical_labels.higher_threshold,
|
| 609 |
+
],
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
valid_inds = np.logical_or(
|
| 613 |
+
feature_values <= lower_thresh, feature_values >= higher_thresh
|
| 614 |
+
)
|
| 615 |
+
new_feature_values = feature_values.copy()
|
| 616 |
+
new_feature_values[~valid_inds] = np.nan
|
| 617 |
+
new_feature_values[feature_values <= lower_thresh] = 0
|
| 618 |
+
new_feature_values[feature_values >= higher_thresh] = 1
|
| 619 |
+
|
| 620 |
+
return new_feature_values
|
| 621 |
+
|
| 622 |
+
    def _generate_word_aligned_segments(
        self,
        raw_data: torch.Tensor,
        n_steps_in_one_segment: int,
        label_switchpoints: np.ndarray,
    ):
        """Cut windows of raw data centered on each label switchpoint (word onset).

        Args:
            raw_data: tensor of shape (n_channels, n_total_samples).
            n_steps_in_one_segment: window length in samples.
            label_switchpoints: sample indices of the word onsets.

        Returns:
            word_aligned_samples: (n_kept_segments, n_channels, segment_length).
            word_aligned_inds: flat tensor of raw-data indices for the kept segments.
            all_word_aligned_inds: flat tensor of raw-data indices for ALL valid
                onsets, including the ones skipped by the non-overlap rule.

        Raises:
            NotImplementedError: for any ``config.trial_alignment`` other than "center".
        NOTE(review): if no switchpoint yields a fully in-bounds window, the
        torch.cat/torch.stack calls below receive empty lists and raise — confirm
        callers guarantee at least one valid onset.
        """
        if self.config.trial_alignment == "center":
            half_window = int(n_steps_in_one_segment / 2)
            start_inds = label_switchpoints - half_window  # start of word boundaries
            # Keep only windows that fit entirely inside the recording.
            valid_start_inds = start_inds[
                np.logical_and(
                    start_inds >= 0,
                    start_inds + n_steps_in_one_segment < raw_data.shape[-1],
                )
            ]

            all_word_aligned_inds, word_aligned_inds, word_aligned_samples = (
                [],
                [],
                [],
            )
            ## Note that the positive samples will most likely have overlaps between the windows.
            for samp_ind, samp_start_ind in enumerate(valid_start_inds):
                # inds in neural activity for this word
                inds_to_query = torch.arange(
                    samp_start_ind, samp_start_ind + n_steps_in_one_segment
                )
                all_word_aligned_inds.append(inds_to_query)

                ## Explicitly avoiding overlapping positive samples here:
                ## skip this onset if its window starts before the previously
                ## kept window ends.
                if (
                    self.config.force_nonoverlap
                    and samp_ind > 0
                    and samp_start_ind <= word_aligned_inds[-1][-1]
                ):
                    continue

                word_aligned_samples.append(raw_data[:, inds_to_query])
                word_aligned_inds.append(inds_to_query)

            print(
                f"Using only {len(word_aligned_inds)} out of {len(all_word_aligned_inds)} word-aligned segments."
            )
            all_word_aligned_inds = torch.cat(all_word_aligned_inds)
            word_aligned_inds = torch.cat(
                word_aligned_inds
            )  # (n_segments * segment_length)
            word_aligned_samples = torch.stack(
                word_aligned_samples
            )  # (n_segments, n_channels, segment_length)

            if self.config.force_nonoverlap:
                # Non-overlap invariant: no raw-data sample is used twice.
                assert len(torch.unique(word_aligned_inds)) == len(word_aligned_inds)

        else:
            raise NotImplementedError("Only center trial alignment supported.")

        return word_aligned_samples, word_aligned_inds, all_word_aligned_inds
|
| 680 |
+
|
| 681 |
+
def _generate_data_and_labels_by_speech(
    self,
    raw_data: torch.Tensor,
    n_steps_in_one_segment: int,
    labels: np.ndarray,
):
    """
    Generate data and labels pairs by segmenting based on words.

    This function will first create word-aligned non-overlapping segments and
    then assign labels to each word. For speech_vs_nonspeech(_time) and
    sentence_onset(_time) tasks, it then chunks the data and uses segments that
    don't overlap with any word to generate negative labels. Note, this function
    can generate either non-overlapping **or** overlapping word center-aligned
    segments -- based on user preference. In the former case with non-overlapping
    segments, not all parts of the data will be used, since this is word-based.

    Args:
        raw_data: a tensor of shape (n_channels x n_total_samples)
        n_steps_in_one_segment: number of samples we want in one segment
        labels: a numpy array of length n_total_samples containing labels
            corresponding to each sample (NaN marks unlabeled samples)

    Output:
        data: a tensor of shape (n_segments x n_channels x n_steps_in_one_segment)
        labels: a tensor of shape (n_segments x n_steps_in_one_segment)
        data_sample_indices: a tensor of shape (n_segments x n_steps_in_one_segment)
            containing indices of samples of the raw data each item in data corresponds to.
    """
    # NOTE: The reason why label_intervals/word start times are not used as the switchpoints is
    # because sentence onset true labels don't include all words, but only words that are onsets.
    # Using word start times as switch points will generate more word aligned segments than is
    # correct / needed. As such, here we use the raw labels directly to determine switchpoints.
    label_switchpoints = np.where(
        np.logical_and(
            # All switch points should have delta with previous sample greater than 0.
            np.concatenate((np.array([0]), np.diff(np.nan_to_num(labels)))) > 0,
            ~np.isnan(labels),
        )
    )[0]
    out = self._generate_word_aligned_segments(
        raw_data, n_steps_in_one_segment, label_switchpoints
    )
    word_aligned_samples, word_aligned_inds, all_word_aligned_inds = out

    if self.config.force_nonoverlap:
        # Negative candidates: fixed, non-overlapping windows tiled over the
        # recording that touch no word window and contain no NaN labels.
        data_sample_indices = torch.arange(raw_data.shape[-1])
        is_unaligned_inds = np.logical_and(
            ~np.isin(data_sample_indices, np.unique(all_word_aligned_inds)),
            ~np.isnan(labels),
        )
        # Truncate time series to a divisible length by the desired window size.
        cutoff_len = int(
            raw_data.shape[-1] - raw_data.shape[-1] % n_steps_in_one_segment
        )
        is_unaligned_inds = np.reshape(
            is_unaligned_inds[..., :cutoff_len], (-1, n_steps_in_one_segment)
        )
        # A window counts as a negative only if *every* sample in it is unaligned.
        unaligned_inds = np.where(np.all(is_unaligned_inds, axis=1))[0]
        unaligned_word_samples = torch.stack(
            [
                raw_data[
                    :,
                    start_ind
                    * n_steps_in_one_segment : (start_ind + 1)
                    * n_steps_in_one_segment,
                ]
                for start_ind in unaligned_inds
            ]
        )

        word_aligned_data_sample_inds = torch.reshape(
            word_aligned_inds, (-1, n_steps_in_one_segment)
        )
        unaligned_data_sample_inds = torch.reshape(
            data_sample_indices[:cutoff_len], (-1, n_steps_in_one_segment)
        )[unaligned_inds]

    else:  # not self.config.force_nonoverlap
        # setting self.config.nonword_stepsize_s=segment_length should yield non overlap
        # NOTE(review): this writes the default back into self.config (mutates
        # the shared config object) — confirm this persistence is intended.
        if self.config.nonword_stepsize_s is None:
            self.config.nonword_stepsize_s = self.config.segment_length_s

        offset = int(self.samp_frequency * self.config.nonword_stepsize_s)
        # Computation for n_rows: https://stackoverflow.com/a/53580139
        n_rows = ((raw_data.shape[-1] - n_steps_in_one_segment) // offset) + 1

        # One row of sample indices per sliding window of stride `offset`.
        data_sample_indices = np.array(
            [
                np.arange(i * offset, i * offset + n_steps_in_one_segment)
                for i in range(n_rows)
            ]
        )

        is_unaligned_inds = np.logical_and(
            ~np.isin(data_sample_indices, np.unique(all_word_aligned_inds)),
            # NOTE: The second conditional is necessary because in the sentence onset case,
            # regions with speech that aren't sentence onsets are labelled with nans.
            # These should also be considered when labeling negatives.
            ~np.isnan(
                labels[data_sample_indices.flatten()].reshape(
                    data_sample_indices.shape
                )
            ),
        )
        unaligned_inds = np.where(np.all(is_unaligned_inds, axis=1))[0]

        unaligned_word_samples = torch.stack(
            [
                raw_data[
                    :,
                    start_ind * offset : start_ind * offset
                    + n_steps_in_one_segment,
                ]
                for start_ind in unaligned_inds
            ]
        )

        data_sample_indices = torch.tensor(data_sample_indices)

        word_aligned_data_sample_inds = torch.reshape(
            word_aligned_inds, (-1, n_steps_in_one_segment)
        )
        unaligned_data_sample_inds = data_sample_indices[unaligned_inds]

    n_word_aligned_samples = word_aligned_samples.shape[0]
    n_unaligned_word_samples = unaligned_word_samples.shape[0]

    num_samples = n_unaligned_word_samples + n_word_aligned_samples

    if self.config.force_balanced:
        # Subsample both classes (without replacement) down to the size of the
        # smaller one so positives and negatives are balanced.
        num_samples = min(n_unaligned_word_samples, n_word_aligned_samples) * 2

        word_aligned_to_use = np.sort(
            np.random.choice(
                range(n_word_aligned_samples),
                replace=False,
                size=num_samples // 2,
            )
        )
        word_aligned_samples = word_aligned_samples[word_aligned_to_use, ...]
        word_aligned_data_sample_inds = word_aligned_data_sample_inds[
            word_aligned_to_use
        ]

        unaligned_to_use = np.sort(
            np.random.choice(
                range(n_unaligned_word_samples),
                replace=False,
                size=num_samples // 2,
            )
        )
        unaligned_word_samples = unaligned_word_samples[unaligned_to_use, ...]
        unaligned_data_sample_inds = unaligned_data_sample_inds[unaligned_to_use]

        n_word_aligned_samples = word_aligned_samples.shape[0]
        n_unaligned_word_samples = unaligned_word_samples.shape[0]

    # Concatenate data: positives first, then negatives (re-sorted below).
    data = torch.empty(
        n_word_aligned_samples + n_unaligned_word_samples,
        *word_aligned_samples.shape[1:],
    )
    data[:n_word_aligned_samples] = word_aligned_samples
    data[n_word_aligned_samples:] = unaligned_word_samples

    num_channels = raw_data.shape[0]
    assert data.shape == (
        num_samples,
        num_channels,
        n_steps_in_one_segment,
    )

    # Concatenate labels. NOTE: this rebinds the `labels` parameter; per-sample
    # input labels are no longer accessible after this point.
    labels = torch.zeros(num_samples, n_steps_in_one_segment)
    labels[:n_word_aligned_samples] = 1

    # Concatenate sample indices
    data_sample_indices = torch.empty(
        n_word_aligned_samples + n_unaligned_word_samples,
        n_steps_in_one_segment,
    )
    data_sample_indices[:n_word_aligned_samples] = word_aligned_data_sample_inds
    data_sample_indices[n_word_aligned_samples:] = unaligned_data_sample_inds

    ## Putting the samples back in temporally sorted order.
    sorted_inds = torch.argsort(data_sample_indices[:, 0])
    data_sample_indices = data_sample_indices[sorted_inds, ...]
    data = data[sorted_inds, ...]
    labels = labels[sorted_inds, ...]
    return data, labels, data_sample_indices
|
| 872 |
+
|
| 873 |
+
def _aggregate_labels(self, labels: torch.Tensor) -> float:
    """
    Return one label for each segment in batch instead of having one label for
    each timepoint.

    Args:
        labels: 1D tensor of per-timepoint labels; may contain NaNs.

    Returns:
        NaN when the NaN fraction reaches ``aggregate_labels.nan_threshold``;
        otherwise the NaN-ignoring mean (type "mean") or a 0/1 vote comparing
        the positive fraction against ``aggregate_labels.threshold`` (type
        "threshold").

    Raises:
        ValueError: if ``aggregate_labels.type`` is not a known aggregation.
    """
    nan_numels = torch.isnan(labels).sum()

    # Too many missing timepoints -> treat the whole segment as unlabeled.
    if nan_numels / len(labels) >= self.config.aggregate_labels.nan_threshold:
        return torch.nan

    agg_type = self.config.aggregate_labels.type
    if agg_type == "mean":
        return float(labels.nanmean())
    if agg_type == "threshold":
        non_nan_numels = len(labels) - nan_numels
        return int(
            (
                labels.nansum() / non_nan_numels
                > self.config.aggregate_labels.threshold
            ).long()
        )

    # Previously an unknown type fell through and raised UnboundLocalError on
    # the return; fail loudly with a clear message instead.
    raise ValueError(f"Unknown aggregate_labels.type: {agg_type!r}")
|
| 895 |
+
|
| 896 |
+
def _get_segment_label(self, labels: torch.tensor) -> float:
    """Collapse a segment's per-timepoint labels into one scalar label."""
    # Pretraining segments carry no supervision signal, so there is no label.
    if self.experiment_dataset_name == self._pretrain_enum:
        return np.nan
    return self._aggregate_labels(labels)
|
| 902 |
+
|
| 903 |
+
def _process_segments_and_update_metadata_file(self):
    """
    Process data files of subjects and add/update segments.

    Iterates over every available (subject, session) pair, skips pairs that
    are already fully processed (unless reprocessing is forced), creates
    segments for the rest, assigns splits, and saves the updated metadata
    after each session.
    """
    number_of_added_segments = 0
    for subject in self.available_sessions.keys():
        for session in self.available_sessions[subject]:
            print(
                f"Segment processing for subject {subject} session {session} starts."
            )

            # Check status of processing
            file_progress_tracker = FileProgressTracker(
                save_path=self._get_file_progress_tracker_save_path(
                    subject, session
                ),
                experiment=self.experiment,
            )

            if self.config.force_reprocess_stage2:
                # Drop all previously recorded segments for this
                # subject/session/experiment so they get rebuilt from scratch.
                corresponding_indices_to_remove = (
                    self.metadata.get_indices_matching_cols_values(
                        ["subject", "session", "experiment"],
                        [subject, session, self.experiment],
                    )
                )
                self.metadata.drop_rows_based_on_indices(
                    corresponding_indices_to_remove
                )

                file_progress_tracker.reset_process()
                print(
                    f"Force reprocessing active, removed subject: {subject} session: "
                    f"{session} experiment: {self.experiment} from metadata, will "
                    f"start processing from the first file."
                )

            if file_progress_tracker.is_completed():
                sp_exist = self._spatial_groupings_exist_for_subject(
                    subject, session
                )
                if sp_exist and not self.config.force_recreate_spatial_groupings:
                    print(
                        f"Subject {subject} data already processed completely, skipping."
                    )
                    continue
                else:
                    # Segments exist but spatial groupings must be rebuilt.
                    print(
                        f"Subject {subject} data already processed completely,"
                        " but force recreate spatial groupings is active,"
                        " will recreate spatial groups"
                    )

            number_of_added_segments_for_subject_session = (
                self._create_segments_for_subject_session(
                    subject,
                    session,
                    self.config.segment_length_s,
                    file_progress_tracker,
                )
            )

            print(
                f"Added {number_of_added_segments_for_subject_session} new segments for subject {subject} session {session}"
            )

            # Informational only: count segments whose label could not be set.
            nan_labels = self.metadata.get_indices_matching_cols_values(
                ["subject", "session", "experiment", "label"],
                [subject, session, self.experiment, None],
            )
            print(
                f"{len(nan_labels)} segments for this subject session have nan labels"
            )

            number_of_added_segments += number_of_added_segments_for_subject_session

            self.metadata = self.splitter.set_splits_for_subject(
                subject, self.metadata, self._split_method
            )
            file_progress_tracker.mark_completion_status()
            # Save after each session so progress survives interruptions.
            self.metadata.save(self.metadata_path)

    print(f"Metadata saved in {self.metadata_path}")
    print(f"Added {number_of_added_segments} new segments")

    summary_str = self.metadata.get_summary_str()
    print(f"{self.name} dataset, full metadata summary: {summary_str}")
|
| 990 |
+
|
| 991 |
+
def _filter_metadata_for_the_run(self):
    """
    Do filtering on metadata based on experiment design

    # NOTE: Add stuff that are run dependent but do **not** alter the saved metadata here.
    """
    # Return only needed experiment
    self.metadata.reduce_based_on_col_value("experiment", self.experiment)

    # Drop rows with no label if not pretraining
    if not self.experiment_dataset_name == self._pretrain_enum:
        n_dropped = self.metadata.reduce_based_on_col_value(
            "label", None, keep=False
        )
        print(f"Dropping {n_dropped} segments with no labels")

    # Chronologically-split tasks can optionally load precomputed
    # chronological fold boundaries (ratios + split assignments).
    if self.experiment_dataset_name in (
        BrainTreebankDatasetNames.SPEECH_VS_NONSPEECH_TIME,
        BrainTreebankDatasetNames.SENTENCE_ONSET_TIME,
        BrainTreebankDatasetNames.VOLUME,
        BrainTreebankDatasetNames.OPTICAL_FLOW
    ):

        curr_fold = self.config.get("chron_fold_num", None)
        if curr_fold is not None:
            print(f"Using chronological fold: {curr_fold}.")
            # Folds are keyed by the same processing hash as the metadata so
            # they can only be applied to the matching preprocessing config.
            folds_path = os.path.join(
                self.config.save_dir,
                self.experiment,
                f"metadata_{self.segments_processing_hash_str}_folds.pkl",
            )
            try:
                with open(
                    folds_path,
                    "rb",
                ) as f:
                    folds_info = pickle.load(f)
            except FileNotFoundError as e:
                # NOTE(review): exits with status 0 (success) on a missing
                # file, and the caught exception `e` is unused — consider a
                # nonzero exit code or re-raising.
                print(f"File {folds_path} not found. Generate the folds for the metadata ({self.metadata_path}) using `barista/generate_chronological_folds` notebook.")
                exit(0)

            assert (
                len(self.config.finetune_sessions) == 1
            ), "Only one finetune session expected."

            subject_session = self.config.finetune_sessions[0]
            self.config.run_ratios = [
                # In case values were saved out as non-primitive float type.
                float(elem) for elem in folds_info[subject_session][curr_fold][0]
            ]
            self.config.run_splits = folds_info[subject_session][curr_fold][1]

        else:  # no chron_fold_num specified.
            print("Using default run chronological ratios and splits.")

    # Re-assign splits for the sessions used in this finetuning run.
    for subject_session in self.config.finetune_sessions:
        self.splitter.resplit_for_subject(
            subject_session,
            self.metadata,
            self._split_method,
        )

    summary_str = self.metadata.get_summary_str()
    print(f"{self.name} dataset, current run summary: {summary_str}")
|
| 1055 |
+
|
| 1056 |
+
def process_segments(self, only_segment_generation=False):
    """
    Prepare this dataset's segments and metadata.

    Loads previously processed metadata when present, optionally (re)generates
    segments on disk, and finally applies run-specific metadata filtering
    unless ``only_segment_generation`` is set.
    """
    # Reuse metadata from earlier processing runs when available.
    cached_metadata = self._load_metadata()
    if cached_metadata is not None:
        self.metadata = cached_metadata

    if not self.config.skip_segment_generation_completely:
        self._process_segments_and_update_metadata_file()

    if only_segment_generation:
        return
    self._filter_metadata_for_the_run()
|
| 1067 |
+
|
| 1068 |
+
@property
def _split_method(self):
    """Pick the train/val/test split strategy for the current experiment."""
    random_split_tasks = (
        BrainTreebankDatasetNames.SPEECH_VS_NONSPEECH,
        BrainTreebankDatasetNames.SENTENCE_ONSET,
    )
    if self.experiment_dataset_name in random_split_tasks:
        # Random splits require non-overlapping segments, otherwise train
        # and test windows could share samples.
        assert self.config.force_nonoverlap is True, "Set force_nonoverlap to True for random split segments"
        return "shuffle"

    # Everything else should just be split chronologically.
    if self.experiment_dataset_name != BrainTreebankDatasetNames.PRETRAIN:
        assert self.config.force_nonoverlap is False, "Set force_nonoverlap to False for chronological segments"
    return "chronological"
|
| 1082 |
+
|
| 1083 |
+
@property
def _pretrain_enum(self) -> BrainTreebankDatasetNames:
    """Enum member that identifies the pretraining dataset."""
    return BrainTreebankDatasetNames.PRETRAIN
|
| 1086 |
+
|
| 1087 |
+
def get_raw_data_file_path(self, subject: str, session: str):
    """
    Return the raw data file path for a subject/session pair.

    Bug fix: the path manager lookup result was previously computed but never
    returned, so this method always yielded None.
    """
    return self.path_manager.get_raw_data_filepath(subject, session)
|
| 1089 |
+
|
| 1090 |
+
@property
def _processed_raw_data_dir(self):
    """
    Directory holding the processed raw data, i.e., after filtering and
    referencing.
    """
    # The directory name encodes preprocessing settings so different
    # pipelines never collide on disk.
    return os.path.join(self.config.save_dir, self._get_processed_raw_data_dir_name)
|
| 1099 |
+
|
| 1100 |
+
@property
def _get_processed_raw_data_dir_name(self):
    """Canonical directory name encoding the stage-1 preprocessing settings."""
    freq = self.samp_frequency
    return f"processed_raw_{freq}Hz_notch_laplacianref_clnLap"
|
| 1103 |
+
|
| 1104 |
+
@property
def _processed_segments_data_dir(self):
    """Data dir for the segmented trials corresponding to a particular experimental config."""
    # The processing hash ties the directory to one preprocessing config.
    segments_dir = f"processed_segments_{self.segments_processing_hash_str}"
    return os.path.join(self.config.save_dir, self.experiment, segments_dir)
|
| 1112 |
+
|
| 1113 |
+
def _load_metadata(self) -> Optional[Metadata]:
    """Load previously saved metadata from disk, or return None when absent."""
    if not os.path.exists(self.metadata_path):
        return None
    metadata = Metadata(load_path=self.metadata_path)
    print(f"Metadata loaded from {self.metadata_path}")
    return metadata
|
| 1119 |
+
|
| 1120 |
+
def _initialize_metadata(self) -> Metadata:
    """Build an empty Metadata object with the expected column layout."""
    # Column names come straight from the dataclass schemas so the dataframe
    # layout stays in sync with the MetadataRow definitions.
    segment_cols = [f.name for f in dataclasses.fields(MetadataRow)]
    spatial_cols = [f.name for f in dataclasses.fields(MetadataSpatialGroupRow)]

    metadata = Metadata(
        df=pd.DataFrame(columns=segment_cols),
        spatial_group_df=pd.DataFrame(columns=spatial_cols),
    )
    print(f"Metadata initialized: {self.metadata_path}")
    return metadata
|
| 1130 |
+
|
| 1131 |
+
def _get_processed_raw_data_file_path(self, subject, session):
    """Full path of the processed raw-data tensor for one subject/session."""
    return os.path.join(self._processed_raw_data_dir, f"{subject}_{session}.pt")
|
| 1134 |
+
|
| 1135 |
+
def _get_processed_raw_data_file_path_cache(self, subject, session):
    """
    Return (cache_dir, cached_file_path) for this subject/session's processed
    raw data. The cache mirrors the processed-raw directory layout under the
    stage-1 cache root.
    """
    cache_dir = os.path.join(
        self.config.stage1_cache_dir,
        self._get_processed_raw_data_dir_name,
    )
    print(f"Cache dir: {cache_dir}")
    cached_file = os.path.join(cache_dir, f"{subject}_{session}.pt")
    return cache_dir, cached_file
|
| 1143 |
+
|
| 1144 |
+
def _get_segments_processing_hash(self, segment_length_s):
    """
    returns a tuple where the key is the processing str, value is the hashed key.
    actual str can be found in metadata.

    this part can be overwritten by each dataset class based on specific settings
    """
    # Base string: sampling rate, z-scoring flag, segment length, split ratios.
    # Any change to these settings changes the hash and therefore the on-disk
    # segments directory.
    processing_str = (
        f"{self.config.samp_frequency}Hz_zscrTrue"
        f"_segment_length{segment_length_s}_val_ratio{self.config.val_ratio:.1e}_test_ratio{self.config.test_ratio:.1e}"
    )

    if self.experiment_dataset_name != self._pretrain_enum:
        processing_str += f"_trial_align{self.config.trial_alignment}"

    if self.config.quantile_numerical_labels.active:
        processing_str += f"quantile_numerical_labels_L{self.config.quantile_numerical_labels.lower_threshold}_H{self.config.quantile_numerical_labels.higher_threshold}"

    processing_str += self.config.dataset_dir
    processing_str += "_laplacian"

    if self.config.region_filtering.active:
        # Sort so equivalent filter lists hash identically regardless of
        # their order in the config.
        # NOTE(review): this in-place sort mutates the shared config object.
        self.config.region_filtering['filters'].sort()
        filter_str = (
            f"_region_filtered_{str(self.config.region_filtering.filters)}"
        )
        processing_str += filter_str

    if not self.config.force_balanced:
        processing_str += "_all_labels"

    if self._split_method == "chronological":
        processing_str += "_chronosplit"
        if not self.config.force_nonoverlap:
            processing_str += "_overlapsegs"

    processing_str += "_use_clean_laplacian"
    processing_str += "_aggregate_label" + str(self.config.aggregate_labels)

    # Short, stable fingerprint used in directory/file names.
    hash_str = hashlib.sha256(bytes(processing_str, "utf-8")).hexdigest()[:5]
    print(f"HASHSTR: {hash_str}")
    return processing_str, hash_str
|
barista/data/dataframe_wrapper.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
from typing import List, Optional, Union
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DataframeWrapper:
|
| 9 |
+
"""
|
| 10 |
+
A wrapper for a pandas DataFrame
|
| 11 |
+
|
| 12 |
+
This class provide extra functionality over pd.DataFrame and abstracts
|
| 13 |
+
the dependency on pandas dataframe (for the most part).
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(
    self,
    df: Optional[pd.DataFrame] = None,
    load_path: Optional[str] = None,
) -> None:
    """
    Wrap an existing dataframe or load one from ``load_path``.

    Exactly one of ``df`` / ``load_path`` may be provided.
    """
    if df is not None and load_path is not None:
        raise ValueError("Only one of inner df or load path should be set")

    # Prefer the in-memory dataframe; otherwise read from disk.
    self._df: pd.DataFrame = df if df is not None else self.load(load_path)
|
| 28 |
+
|
| 29 |
+
def copy(self):
    """Return a new wrapper holding a deep copy of the dataframe."""
    return self.__class__(df=self._df.copy(deep=True))
|
| 32 |
+
|
| 33 |
+
@classmethod
def merge(
    cls,
    metadatas: List["DataframeWrapper"],
    drop_duplicate: bool = False,
    merge_columns: Union[str, List[str], None] = None,
    keep="first",
) -> "DataframeWrapper":
    """
    Merge metadata's dataframes
    If drop_duplicate = True, only one row from rows having same `merge_columns` will remain
    based on `keep` strategy. Default to using all columns.
    """
    combined = pd.concat([m._df for m in metadatas], ignore_index=True)
    if drop_duplicate:
        combined = combined.drop_duplicates(subset=merge_columns, keep=keep)
    return cls(combined)
|
| 51 |
+
|
| 52 |
+
@property
def columns(self):
    """Column labels of the wrapped dataframe."""
    return self._df.columns
|
| 55 |
+
|
| 56 |
+
def concat(self, new_df: pd.DataFrame):
    """Append ``new_df``'s rows to the wrapped dataframe (index reset)."""
    frames = [self._df, new_df]
    self._df = pd.concat(frames, ignore_index=True, sort=True)
|
| 58 |
+
|
| 59 |
+
def shuffle(self, column: Optional[str] = None) -> None:
    """Shuffle the metadata table rows, or only a column if specified"""
    # Fixed seed keeps shuffles reproducible across runs.
    permuted = self._df.sample(frac=1, random_state=42).reset_index(drop=True)
    if column is None:
        self._df = permuted
    else:
        # Only the requested column is permuted; others keep their order.
        self._df[column] = permuted[column]
|
| 67 |
+
|
| 68 |
+
def clear(self) -> None:
    """Setting the metadata to empty table"""
    # Keep columns/dtypes while dropping every row.
    self._df = self._df.head(0)
|
| 71 |
+
|
| 72 |
+
def is_empty(self) -> bool:
    """True when the wrapped table has no rows."""
    return self._df.shape[0] == 0
|
| 74 |
+
|
| 75 |
+
def __getitem__(self, idx: int) -> pd.Series:
    """Get a metadata table row"""
    # Positional (iloc) indexing, independent of the dataframe's index labels.
    row = self._df.iloc[idx]
    return row
|
| 78 |
+
|
| 79 |
+
def apply_fn_on_all_rows(self, col_name: str, fn: callable) -> pd.Series:
    """Apply a function on each row of the dataframe"""
    column = self._df[col_name]
    return column.apply(fn)
|
| 82 |
+
|
| 83 |
+
def get_unique_values_in_col(
    self, col_name: str, indices: Optional[List[int]] = None
) -> List:
    """
    Get unique values of a column, optionally restricted to positional `indices`.

    Returns a list of unique values in order of first appearance. (The return
    annotation previously said ``np.ndarray``, but the method has always
    returned a list.)
    """
    values = self._df[col_name]
    if indices is not None:
        values = values.iloc[indices]
    return list(values.unique())
|
| 91 |
+
|
| 92 |
+
def get_indices_matching_cols_values(
    self, col_names: List, values: List, contains: bool = False, check_range: bool = False
) -> List[int]:
    """
    Get indices of the rows that their value of specified `col_names`
    match the values in the `values` list.

    A value can be a (min, max) tuple for continuous columns (set
    `check_range=True`), or a list (set `contains=True`) to match rows whose
    value is in the list. `None`/NaN matches null entries.

    Fixes: compare against None by identity (`is None`, not `==`), and only
    call pd.isnull on scalars — pd.isnull(list) returns an array, which
    previously made the `elif` condition ambiguous for list values when
    contains=False.
    """
    assert len(col_names) == len(values)

    # Start with all rows matching; AND in one condition per column.
    mask = pd.Series(True, range(len(self)))
    for col_name, value in zip(col_names, values):
        if check_range and isinstance(value, tuple):
            assert len(value) == 2, "For a range provide min and max value"
            min_val, max_val = value
            mask &= (self._df[col_name] >= min_val) & (self._df[col_name] <= max_val)
        elif contains and isinstance(value, list):
            mask &= self._df[col_name].isin(value)
        elif value is None or (np.isscalar(value) and pd.isnull(value)):
            mask &= self._df[col_name].isnull()
        else:
            mask &= self._df[col_name] == value

    return self._df.index[mask].tolist()
|
| 119 |
+
|
| 120 |
+
def get_column_max_value(self, col_name: str):
    """Largest value in the given column."""
    column = self._df[col_name]
    return column.max()
|
| 122 |
+
|
| 123 |
+
def set_col_to_value(self, indices: List[int], col: str, value):
    """Assign `value` to column `col` for the rows at the given index labels."""
    self._df.loc[indices, col] = value
|
| 125 |
+
|
| 126 |
+
def save(self, path: str) -> None:
    """Save metadata table to csv after converting lists and tuples to strings"""

    def convert_complex_data(val, delimiter=","):
        # Lists/tuples are flattened to "[a,b,c]" strings; `load` reverses
        # this conversion when reading the CSV back.
        if isinstance(val, (list, tuple)):
            return "[" + delimiter.join(map(str, val)) + "]"
        elif isinstance(val, (dict, torch.Tensor, np.ndarray)):
            # These types cannot round-trip through the CSV representation.
            raise TypeError(
                f"Only columns of type list and tuple can be converted and saved, but received {type(val)}."
            )
        else:
            return val

    # Deep-copy so the conversion never mutates the live dataframe.
    metadata_save = deepcopy(self._df)
    if len(metadata_save) > 0:
        for col in metadata_save.columns:
            metadata_save[col] = metadata_save[col].apply(convert_complex_data)
    metadata_save.to_csv(path, index=False)
|
| 144 |
+
|
| 145 |
+
def load(self, path: str) -> pd.DataFrame:
    """
    Read a metadata CSV and undo the string conversion applied by `save`.

    Bracketed strings are parsed back into lists (of ints/floats/None/str);
    the "group_components" column is parsed into lists of tuples; legacy
    "channels"/"coords" columns are discarded (set to NaN).
    """
    metadata = pd.read_csv(path)

    def convert_from_string(val, delimiter=","):
        # Check if the value is a list or tuple
        if isinstance(val, str) and (
            (val.startswith("[") and val.endswith("]"))
            or (val.startswith("(") and val.endswith(")"))
        ):
            val = val[1:-1]
            # Attempt to convert to a list of floats or ints
            val_split = val.split(delimiter)
            converted = []
            for item in val_split:
                try:
                    # Heuristic: decimal point or exponent marker => float.
                    if "." in item or "e-" in item or "e+" in item:
                        converted.append(float(item))
                    elif item == "None" or item == "":
                        converted.append(None)
                    else:
                        converted.append(int(item))
                except Exception:
                    # Not numeric -> keep the raw string element.
                    converted.append(item)
            return converted
        return val

    def convert_channels_string_to_tuples(val: str):
        # Parses a "[(a, b), (c, d), ...]" style string into a list of tuples.
        if val.startswith("[") and val.endswith("]"):
            val = val[1:-1]

        def convert_channel_value(ch_val: str):
            if ch_val.isnumeric():
                return int(ch_val)
            elif (ch_val.startswith("'") and ch_val.endswith("'")) or (
                ch_val.startswith('"') and ch_val.endswith('"')
            ):
                # Strip the quoting that repr() added around strings.
                return ch_val[1:-1]
            return ch_val

        try:
            return [
                tuple(
                    [convert_channel_value(c) for c in ch_info_str[1:].split(", ")]
                )
                for ch_info_str in val[:-1].split("),")
            ]
        except ValueError as e:
            # Fallback: keep elements as raw strings when conversion fails.
            return [
                tuple(ch_info_str[1:].split(", "))
                for ch_info_str in val[:-1].split("),")
            ]

    # Apply conversion to each column
    for col in metadata.columns:
        if col == "channels" or col == "coords":  # keeping for backward compatibility
            metadata[col] = np.nan
        elif col == "group_components":
            # Only do conversion for unique channel str since many segments have same channels
            unique_str = metadata[col].unique()
            channel_dict = {
                c: convert_channels_string_to_tuples(c) for c in unique_str
            }
            metadata[col] = metadata[col].apply(lambda c: channel_dict[c])
        else:
            metadata[col] = metadata[col].apply(convert_from_string)
    return metadata
|
| 211 |
+
|
| 212 |
+
def drop_rows_based_on_indices(self, indices: List[int]) -> None:
    """Remove the rows at the given index labels and renumber the index from 0."""
    remaining = self._df.drop(indices)
    self._df = remaining.reset_index(drop=True)
|
| 215 |
+
|
| 216 |
+
def reduce_based_on_col_value(
    self,
    col_name: str,
    value: Union[str, float],
    regex: bool = False,
    keep: bool = True,
) -> int:
    """
    Filter rows based on `value` of the column `col_name`.
    Pass None as value if want to check for nan values.

    Args:
        col_name: name of the column to test.
        value: exact value to compare against, a regex pattern when
            `regex=True`, or None to match null entries.
        regex: whether to use regex expression (contains) or exact value.
        keep: whether to keep the matching values rows or the rows that do not match.

    Returns:
        Number of dropped rows.
    """
    if not regex:
        # Fix: identity check for None instead of `== None` (avoids __eq__
        # dispatch and matches the documented "check for nan values" intent).
        if value is None:
            indices = self._df[col_name].isnull()
        else:
            indices = self._df[col_name] == value
    else:
        indices = self._df[col_name].str.contains(value)

    if not keep:
        indices = ~indices

    self._df = self._df[indices].reset_index(drop=True)
    # Return annotation fixed from -> None to -> int: this is the drop count.
    return (~indices).sum()
|
| 245 |
+
|
| 246 |
+
def __len__(self):
|
| 247 |
+
return len(self._df)
|
| 248 |
+
|
| 249 |
+
def _get_column_mapping_dict_from_dataframe(self, key_col: str, value_col: str, df: Optional[None] = None):
|
| 250 |
+
"""
|
| 251 |
+
Get a dictionary containing `key_col` column values as keys and
|
| 252 |
+
`value_col` column values as values
|
| 253 |
+
"""
|
| 254 |
+
|
| 255 |
+
if df is None:
|
| 256 |
+
df = self._df
|
| 257 |
+
|
| 258 |
+
unique_keys_index = (
|
| 259 |
+
df.dropna(subset=value_col)
|
| 260 |
+
.drop_duplicates(subset=key_col, keep="first")
|
| 261 |
+
.index
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
keys = df.loc[unique_keys_index, key_col]
|
| 265 |
+
values = df.loc[unique_keys_index, value_col]
|
| 266 |
+
|
| 267 |
+
output = dict(zip(keys, values))
|
| 268 |
+
return output
|
barista/data/fileprogresstracker.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
class FileProgressTracker:
    """Manage loading and storing the latest completely processed file index.

    Progress is persisted as JSON with the structure:
    {
        [experiment]: {
            [self._file_ind_key]: int,
            [self._ending_ind_key]: int,
            [self._segment_id_key]: int
        }
    }
    """

    def __init__(self, save_path: str, experiment: str):
        self.path = save_path
        self.experiment = experiment
        # JSON field names used inside each experiment's entry.
        self._file_ind_key = "file_ind"
        self._ending_ind_key = "ending_ind"
        self._segment_id_key = "segment_id"
        self._completed_key = "is_completed"

    def _load_file(self) -> dict:
        """Read the progress file, seeding a fresh entry for this experiment.

        Returns:
            The full progress dictionary (layout described in the class docstring).
        """
        state: dict = {}
        if os.path.exists(self.path):
            with open(self.path) as handle:
                state = json.load(handle)

        # First time this experiment is seen: start from the very beginning.
        state.setdefault(
            self.experiment,
            {
                self._file_ind_key: 0,
                self._ending_ind_key: 0,
                self._segment_id_key: -1,
                self._completed_key: False,
            },
        )
        return state

    def _update_file(self, update_dict: dict) -> None:
        """Merge `update_dict` into this experiment's entry and rewrite the file."""
        state = self._load_file()
        state[self.experiment].update(update_dict)

        with open(self.path, "w+") as handle:
            json.dump(state, handle)

    def get_last_file_ind(self) -> Tuple[int, int, int]:
        """Get the last file that was processed for this experiment.

        Returns:
            (file index, ending index within the file, segment id) of the last
            processed file.
        """
        entry = self._load_file()[self.experiment]
        return (
            entry[self._file_ind_key],
            entry[self._ending_ind_key],
            entry[self._segment_id_key],
        )

    def update_last_file_ind(
        self, file_ind: int, ending_ind: int, segment_id: int
    ) -> None:
        """Record the last processed file for this experiment, leaving any other stored fields intact."""
        progress = {
            self._file_ind_key: file_ind,
            self._ending_ind_key: ending_ind,
            self._segment_id_key: segment_id,
        }
        self._update_file(progress)

    def mark_completion_status(self, completed: bool = True) -> None:
        """Flag whether this experiment's processing has finished."""
        self._update_file({self._completed_key: completed})

    def is_completed(self) -> bool:
        """Return True if this experiment was previously marked as completed."""
        entry = self._load_file()[self.experiment]
        return entry.get(self._completed_key, False)

    def reset_process(self) -> None:
        """Reset file processing status"""
        self.mark_completion_status(completed=False)
        self.update_last_file_ind(0, 0, -1)
|
barista/data/metadata.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
from typing import Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
from barista.data.dataframe_wrapper import DataframeWrapper
|
| 8 |
+
from barista.data.metadata_spatial_groups import (
|
| 9 |
+
MetadataSpatialGroupRow,
|
| 10 |
+
MetadataSpatialGroups,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclasses.dataclass
class MetadataRow:
    """One row of segment metadata: identifies a recorded segment and where it is stored."""

    dataset: str
    subject: str
    session: str
    subject_session: str  # combined subject/session key used for lookups elsewhere
    experiment: str
    d_input: int  # input dimensionality -- presumably channel count; TODO confirm with producer
    d_data: torch.Size  # full shape of the stored data tensor
    split: str  # split assignment (e.g. train/val/test -- see Splitter)
    path: str
    filename: str
    processing_str: str  # description of the preprocessing applied -- assumption; verify against writer
    seq_len: int  # sequence length in samples
    label: Optional[float]  # task label; None when the segment is unlabeled
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Metadata(DataframeWrapper):
    """
    Metadata class to keep track of all segment meta information.

    Optionally carries a companion `MetadataSpatialGroups` table holding
    per-subject spatial grouping info; it is loaded from / saved to a sibling
    "<name>_spatial_groups.csv" next to the main metadata csv. When that
    companion file is absent, `_spatial_groups` stays None.
    """

    def __init__(self, df=None, load_path=None, spatial_group_df=None):
        if df is None:
            assert spatial_group_df is None

        super().__init__(df, load_path)

        self._spatial_groups = None
        if load_path is not None:
            try:
                self._spatial_groups = MetadataSpatialGroups(
                    load_path=self._get_spatial_group_path(load_path)
                )
            except FileNotFoundError:
                # Older metadata files ship without a spatial-groups companion csv.
                pass
        elif spatial_group_df is not None:
            self._spatial_groups = MetadataSpatialGroups(df=spatial_group_df)

    def _get_spatial_group_path(self, path: str) -> str:
        """Derive the companion spatial-groups csv path from the metadata csv path."""
        suffix = ".csv"
        new_path = path[: -len(suffix)]
        spatial_path = f"{new_path}_spatial_groups{suffix}"
        return spatial_path

    def save(self, path: str) -> None:
        """Save the metadata csv and, when present, its spatial-groups companion csv."""
        super().save(path)
        # Fix: _spatial_groups is None when no companion file existed at load
        # time; the unconditional call crashed with AttributeError here.
        if self._spatial_groups is not None:
            self._spatial_groups.save(self._get_spatial_group_path(path))

    @classmethod
    def merge(
        cls,
        metadatas: List["Metadata"],
        drop_duplicate: bool = False,
        merge_columns: Union[str, List[str], None] = None,
        keep="first",
    ) -> "Metadata":
        """Merge several Metadata objects (and their spatial-group tables) into one."""
        new_metadata = super().merge(metadatas, drop_duplicate, merge_columns, keep)

        # Add spatial groups.
        # NOTE(review): assumes every input metadata carries a spatial-group
        # table; a None entry would fail inside MetadataSpatialGroups.merge.
        spatial_groups = [m._spatial_groups for m in metadatas]
        merged_spatial_groups = MetadataSpatialGroups.merge(
            spatial_groups,
            drop_duplicate=True,
            merge_columns=[
                "dataset",
                "subject_session",
                "name",
            ],
        )
        new_metadata._spatial_groups = merged_spatial_groups
        return new_metadata

    def get_subject_session_d_input(self) -> dict:
        """Map each subject_session to its d_input value."""
        return self._get_column_mapping_dict_from_dataframe(
            key_col="subject_session",
            value_col="d_input",
        )

    def get_subjects(self) -> dict:
        """Unique subject identifiers present in the metadata."""
        return self.get_unique_values_in_col("subject")

    def _shape_str_to_list(self, value) -> List[int]:
        """Convert a comma-separated shape string (e.g. "128,2048") into a list of ints.

        Non-string inputs (already-parsed shapes) pass through unchanged.
        Annotation fixed: the function returns a list, not a tuple.
        """
        if not isinstance(value, str):
            return value
        return [int(a) for a in value.split(",")]

    def get_subject_session_full_d_data(self) -> Dict[str, List[int]]:
        """
        Returns a dict containing subject_session to data shape.
        """
        my_dict = self._get_column_mapping_dict_from_dataframe(
            key_col="subject_session",
            value_col="d_data",
        )
        return {k: self._shape_str_to_list(v) for k, v in my_dict.items()}

    def get_labels_count_summary(self) -> dict:
        """Count segments per (split, label) pair; returns {split: {label: count}}."""
        splits = self.get_unique_values_in_col("split")
        labels = self.get_unique_values_in_col("label")

        labels_count = defaultdict(dict)
        for split in splits:
            for label in labels:
                count = len(
                    self.get_indices_matching_cols_values(
                        ["split", "label"],
                        [split, label],
                    )
                )
                labels_count[split][label] = count
        return labels_count

    def get_summary_str(self) -> str:
        """Human-readable one-line summary of subjects and per-split label counts."""
        subjects = self.get_unique_values_in_col("subject")
        labels_count = self.get_labels_count_summary()

        summary_str = f"Metadata for {len(subjects)} subjects ({subjects})"

        for split, labels in labels_count.items():
            for label, count in labels.items():
                summary_str += f", {count} {split} segments with label {label}"

        return summary_str

    ########################### spatial group related ###########################

    def add_spatial_group(self, spatial_group_row: MetadataSpatialGroupRow):
        """
        Add (or overwrite) the spatial group.
        """
        self._spatial_groups.remove_spatial_group(
            spatial_group_row.subject_session, spatial_group_row.name
        )
        self._spatial_groups.concat(pd.DataFrame([spatial_group_row]))

    def get_spatial_grouping(
        self, subject_session: str, name: str
    ) -> Optional[MetadataSpatialGroupRow]:
        """
        Return spatial grouping information for spatial grouping `name` and subject_session `subject_session`'s.

        Spatial grouping is MetadataSpatialGroupRow which the most important property is group_components
        which is a list of tuples that contains group info for each channel of the data,
        and group_ids which is a list of integer that specify which group each channel belongs to.
        """
        return self._spatial_groups.get_spatial_grouping(subject_session, name)

    def get_spatial_grouping_id_hashmap(self, name: str) -> Dict[str, List[int]]:
        """
        Return spatial grouping dictionary which maps each subject_session to list of group ids which is a list of
        length channels specifying which group each channel belongs to.

        # NOTE Don't use during forward because of the copy
        """
        temp_copy = self._spatial_groups.copy()
        temp_copy.reduce_based_on_col_value(col_name="name", value=name, keep=True)
        return temp_copy._get_column_mapping_dict_from_dataframe(
            "subject_session", "group_ids"
        )
|
barista/data/metadata_spatial_groups.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from typing import List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
from barista.data.dataframe_wrapper import DataframeWrapper
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclasses.dataclass
class MetadataSpatialGroupRow:
    """One spatial grouping of a subject/session's channels, identified by (subject_session, name)."""

    dataset: str
    subject: str
    session: str
    subject_session: str  # combined subject/session key used for lookups
    name: str  # name/identifier of the spatial grouping
    n_effective_components: int
    max_elements_for_component: (
        Tuple  # tuple of size n_effective_components (or larger)
    )
    padding_indices: Tuple  # tuple of size n_effective_components (or larger)
    group_components: List  # list of len number of channels -- List tuples that contains group info for each channel, useful for spatial encoding
    group_ids: List  # list of len number of channels -- List of int specifying which group each channel belongs to, useful for spatial masking
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SpatialGroupingName(Enum):
    """Identifiers of the known spatial grouping schemes (values match the `name` column)."""

    COORDS = "coords"
    DESTRIEUX = "destrieux"
    LOBES = "lobes"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class MetadataSpatialGroups(DataframeWrapper):
    """Table of spatial grouping rows, uniquely keyed by (subject_session, name)."""

    def _get_spatial_grouping_index(
        self, subject_session: str, name: str
    ) -> Optional[int]:
        """Find the row index matching (subject_session, name); None when absent.

        Raises:
            AssertionError: if the pair matches more than one row.
        """
        indices = self.get_indices_matching_cols_values(
            ["subject_session", "name"], [subject_session, name]
        )
        if len(indices) == 0:
            return None
        assert (
            len(indices) == 1
        ), f"More than one results for spatial grouping '{name}' for '{subject_session}'"

        return indices[0]

    def get_spatial_grouping(
        self, subject_session: str, name: str
    ) -> Optional[MetadataSpatialGroupRow]:
        """Return the matching row as a MetadataSpatialGroupRow, or None when absent.

        Annotation fixed: the original claimed a non-optional return while
        explicitly returning None on a miss.
        """
        idx = self._get_spatial_grouping_index(subject_session, name)
        if idx is None:
            return None
        a = self._df.iloc[idx].to_dict()
        # Derived column not part of the dataclass schema; drop before construction.
        if "uniq_group_components" in a:
            del a["uniq_group_components"]
        return MetadataSpatialGroupRow(**a)

    def remove_spatial_group(self, subject_session: str, name: str) -> int:
        """Remove the matching row; returns the number of rows removed (0 or 1)."""
        idx = self._get_spatial_grouping_index(subject_session, name)
        if idx is None:
            return 0
        # Fix: drop_rows_based_on_indices returns None, so the original
        # `return self.drop_rows_based_on_indices([idx])` violated the -> int
        # contract; report the count explicitly instead.
        self.drop_rows_based_on_indices([idx])
        return 1
|
barista/data/splitter.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import os
|
| 3 |
+
from typing import Dict, List
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from barista.data.metadata import Metadata
|
| 9 |
+
from barista.models.utils import seed_everything
|
| 10 |
+
|
| 11 |
+
# Split strategies understood by Splitter; any other value falls back to "shuffle".
_SUPPORTED_SPLITS = ["shuffle", "chronological"]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Splitter:
    """Helper class to handle train/test/val splitting."""

    def __init__(
        self,
        config: Dict,
        subjects: List,
        experiment: str,
        use_fixed_seed: bool = False,
    ):
        # NOTE(review): `config` is typed Dict but accessed both via .get(...)
        # and attribute lookup (e.g. config.segment_length_s) below -- it is
        # presumably an omegaconf-style object; confirm with callers.
        self.config = config
        self.subjects = subjects
        self.experiment = experiment

        self.use_fixed_seed = use_fixed_seed

    def _use_configured_seed(func):
        """Decorator for changing seed for a specific function.

        When `use_fixed_seed` is set, reseeds with `config["splitter_seed"]`
        before the call and restores the previous PL_GLOBAL_SEED afterwards;
        otherwise the wrapped function runs untouched.
        """

        def wrapper(self, *args, **kwargs):
            if not self.use_fixed_seed:
                return func(self, *args, **kwargs)

            prev_seed = int(os.environ.get("PL_GLOBAL_SEED", 0))
            new_seed = int(self.config.get("splitter_seed", 0))

            print(
                f"Changing seed from {prev_seed} to {new_seed} for splitting"
            )
            seed_everything(new_seed)

            out = func(self, *args, **kwargs)

            print(f"Changing back seed from {new_seed} to {prev_seed}.")
            seed_everything(prev_seed)

            return out

        return wrapper

    @_use_configured_seed
    def set_splits_for_subject(
        self,
        subject: str,
        metadata: "Metadata",
        split_method: str = "shuffle"
    ) -> "Metadata":
        """Set train/validation/test split.

        Every `split_together_length_s` will be splitted into one of the train/val/test.

        NOTE: This function assumes the segments are in order and consecutive in metadata if you want
        to use split together multiple consecutive segments.
        """
        # Set default if necessary.
        if split_method not in _SUPPORTED_SPLITS:
            print(f"[Warning] Setting split_method={split_method} to 'shuffle'")
            split_method = "shuffle"

        # Ensure the split together length is at least as long as the segments.
        # Setting allows to split time series based on intervals > neural segment length.
        split_together_length_s = max(
            self.config.get("split_together_length_s", self.config.segment_length_s),
            self.config.segment_length_s
        )

        subject_rows_indices = metadata.get_indices_matching_cols_values(
            ["subject", "experiment"], [subject, self.experiment]
        )

        if split_method == "chronological":
            return self._set_splits_across_time(
                metadata, subject_rows_indices=subject_rows_indices
            )

        split_together_count = int(
            split_together_length_s // self.config.segment_length_s
        )
        consecutive = (torch.diff(torch.tensor(subject_rows_indices)) == 1).all()

        if split_together_count > 1:
            assert (
                consecutive
            ), "subject rows are not consecutive, can't do splitting together"

        n_segments = len(subject_rows_indices)
        if n_segments == 0:
            print(
                f"[WARNING] No rows found for the subject {subject} and experiment {self.experiment} in metadata"
            )
            return metadata

        starting_ind = subject_rows_indices[0]

        if consecutive:
            # One entry per split-together group of consecutive segments.
            groups = list(
                range(
                    starting_ind,
                    starting_ind + n_segments - split_together_count + 1,
                    split_together_count,
                )
            )
        else:
            # we've asserted that split_together_count is 1 in this case
            groups = copy.deepcopy(subject_rows_indices)

        np.random.shuffle(groups)

        # At least one group each for val and test, even for tiny datasets.
        val_size = max(int(self.config.val_ratio * len(groups)), 1)
        test_size = max(int(self.config.test_ratio * len(groups)), 1)

        val_indices = []
        for group_starting_idx in groups[:val_size]:
            group_elem_indices = np.arange(split_together_count) + group_starting_idx
            val_indices.extend(group_elem_indices)

        test_indices = []
        for group_starting_idx in groups[val_size : val_size + test_size]:
            group_elem_indices = np.arange(split_together_count) + group_starting_idx
            test_indices.extend(group_elem_indices)

        # Everything starts as train; val/test assignments then overwrite.
        metadata.set_col_to_value(subject_rows_indices, "split", "train")
        metadata.set_col_to_value(val_indices, "split", "val")
        metadata.set_col_to_value(test_indices, "split", "test")

        return metadata

    @_use_configured_seed
    def resplit_for_subject(
        self,
        subject_session: str,
        metadata: "Metadata",
        split_method: str,
    ) -> "Metadata":
        """Re-assign splits for one subject_session; only chronological resplitting is supported."""
        if split_method == "chronological":
            return self._set_splits_across_time(
                metadata, subject_session=subject_session
            )
        else:
            print("[WARNING] Resplitting only for chronological; splits unchanged")
            return metadata

    def __check_contiguous(self, subject_rows_indices, check_monotonic_only=False):
        """Assert the row indices are strictly increasing (or exact +1 steps by default)."""
        if check_monotonic_only:
            assert (
                torch.diff(torch.tensor(subject_rows_indices)) >= 1
            ).all(), "subject rows are not consecutive, can't do splitting together"
        else:  # we need to be exactly increments of one.
            assert (
                torch.diff(torch.tensor(subject_rows_indices)) == 1
            ).all(), "subject rows are not consecutive, can't do splitting together"

    @_use_configured_seed
    def _set_splits_across_time(
        self,
        metadata: "Metadata",
        subject_rows_indices: Optional[list] = None,
        subject_session: str = "",
        return_splitted_indices: bool = False,
        check_monotonic_only: bool = False,
        verbose: bool = False,
    ) -> "Metadata":
        """Assign chronological splits: config.run_ratios slices of the (contiguous) rows
        get the corresponding config.run_splits labels, in order.

        Fix: the mutable default argument (`= []`) was replaced by None; the
        truthiness-based handling below is unchanged, so callers behave the same.
        """
        if not subject_rows_indices and not subject_session:
            raise ValueError(
                "Need to either pass complete subject session name or subject_row_indices"
            )

        if (
            not subject_rows_indices
        ):  # Prioritize using the subject_row_indices if given.
            subject_rows_indices = metadata.get_indices_matching_cols_values(
                ["subject_session", "experiment"], [subject_session, self.experiment]
            )

        self.__check_contiguous(
            subject_rows_indices, check_monotonic_only=check_monotonic_only
        )

        n_segments = len(subject_rows_indices)

        assert len(self.config.run_ratios) == len(self.config.run_splits)

        # Integer counts per split; the last split absorbs the rounding remainder.
        counts = (np.array(self.config.run_ratios) * n_segments).astype(int)
        counts[-1] = n_segments - sum(counts[:-1])

        if verbose:
            print(f"subject_session: {subject_session}")
            print(f"RATIOS: {self.config.run_ratios}")
            print(f"self.config.run_splits: {self.config.run_splits}")
            print(f"COUNTS: {counts}")

        if return_splitted_indices:
            splitted_indices = []
        sum_now = 0
        for c, split in zip(counts, self.config.run_splits):
            label_split_indices = subject_rows_indices[sum_now : sum_now + c]
            if return_splitted_indices:
                splitted_indices.append(label_split_indices)

            sum_now += c
            metadata.set_col_to_value(label_split_indices, "split", split)

        self._check_split_labels(metadata, subject_session)
        if return_splitted_indices:
            return metadata, splitted_indices
        return metadata

    def _check_split_labels(self, metadata, subject_session):
        """Warn (not assert) when a split is missing one of the two labels."""
        # Check that both labels available in each split.
        # NOTE: Not using asserts because the initial default splits might not have
        # both, but the ones computed offline will and provided through the .pkl file
        # will satisfy requirement.
        for split in np.unique(self.config.run_splits):
            for i in range(2):  # magic 2 = positive/negative labels
                if (
                    len(
                        metadata.get_indices_matching_cols_values(
                            ["subject_session", "experiment", "label", "split"],
                            [subject_session, self.experiment, i, split],
                        )
                    )
                    == 0
                ):
                    print(f"split {split} missing label {i}")
|
barista/generate_chronological_folds.ipynb
ADDED
|
@@ -0,0 +1,626 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "6d5e7d9f",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"### Chronological split generation.\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"The following is code used to generate the chronological splits based on the presence of positive and negative samples. This is more of an issue for the speech/sentence tasks, but the same approach is also used for the volume and optical flow tasks."
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": 1,
|
| 16 |
+
"id": "70411f5f",
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"outputs": [],
|
| 19 |
+
"source": [
|
| 20 |
+
"%load_ext autoreload\n",
|
| 21 |
+
"%autoreload 2\n"
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"execution_count": null,
|
| 27 |
+
"id": "2d6f1fed",
|
| 28 |
+
"metadata": {},
|
| 29 |
+
"outputs": [],
|
| 30 |
+
"source": [
|
| 31 |
+
"from barista.data.metadata import Metadata\n",
|
| 32 |
+
"from collections import Counter, defaultdict\n",
|
| 33 |
+
"import numpy as np\n",
|
| 34 |
+
"import os\n",
|
| 35 |
+
"from pathlib import Path"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "code",
|
| 40 |
+
"execution_count": 3,
|
| 41 |
+
"id": "b579134b",
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"outputs": [],
|
| 44 |
+
"source": [
|
| 45 |
+
"def load_metadata(metadata_path):\n",
|
| 46 |
+
" return Metadata(load_path=metadata_path)"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "code",
|
| 51 |
+
"execution_count": 4,
|
| 52 |
+
"id": "d17fbaaa",
|
| 53 |
+
"metadata": {},
|
| 54 |
+
"outputs": [],
|
| 55 |
+
"source": [
|
| 56 |
+
"def generate_folds(subject_rows_indices, per_label_subject_rows_indices,\n",
|
| 57 |
+
" bucket_size=0.05, step_size=1, base_step_size=1,\n",
|
| 58 |
+
" window=4, base_window=1, **folds_kwargss):\n",
|
| 59 |
+
" assert window % 4 == 0, \"Window should be divisible by 4\"\n",
|
| 60 |
+
"\n",
|
| 61 |
+
" bucket_len = int(bucket_size * len(subject_rows_indices)) # bucket size in samples\n",
|
| 62 |
+
" buckets = np.arange(subject_rows_indices[0], subject_rows_indices[-1], bucket_len)\n",
|
| 63 |
+
" print(f\"Buckets: {buckets}\")\n",
|
| 64 |
+
"\n",
|
| 65 |
+
" ## Magic number 2 everywhere corresponds to the 0/1 (negative/positive) labels.\n",
|
| 66 |
+
" ## First, sum the unique label counts per bucket according to the specifications provided.\n",
|
| 67 |
+
" bucket_counts = [{} for i in range(len(buckets)-1)]\n",
|
| 68 |
+
" for bucket_ind in range(0, len(bucket_counts), base_step_size):\n",
|
| 69 |
+
" bucket_start = buckets[bucket_ind]\n",
|
| 70 |
+
" bucket_end = bucket_start + base_window * bucket_len\n",
|
| 71 |
+
" for i in range(2):\n",
|
| 72 |
+
" bucket_counts[bucket_ind][i] = np.sum(np.logical_and(\n",
|
| 73 |
+
" per_label_subject_rows_indices[i] >= bucket_start,\n",
|
| 74 |
+
" per_label_subject_rows_indices[i] < bucket_end\n",
|
| 75 |
+
" ))\n",
|
| 76 |
+
"\n",
|
| 77 |
+
" ## Count the residual samples in the last bucket.\n",
|
| 78 |
+
" for i in range(2):\n",
|
| 79 |
+
" bucket_counts[-1][i] += np.sum(\n",
|
| 80 |
+
" per_label_subject_rows_indices[i] >= bucket_end\n",
|
| 81 |
+
" )\n",
|
| 82 |
+
" print(f\"bucket_counts: {bucket_counts}\")\n",
|
| 83 |
+
"\n",
|
| 84 |
+
" return _find_folds(bucket_counts, step_size, window, bucket_size, **folds_kwargss)\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"def _find_folds(bucket_counts, step_size, window, bucket_size, num_folds=5):\n",
|
| 88 |
+
" \"\"\"Logic to find all legitimate folds such that train and test are separated with valid, e.g.,\n",
|
| 89 |
+
" \n",
|
| 90 |
+
" [train, valid, test]\n",
|
| 91 |
+
" [test, valid, train]\n",
|
| 92 |
+
" [train, valid (0.05), test, valid(0.05), train]\n",
|
| 93 |
+
" \"\"\"\n",
|
| 94 |
+
" all_folds, all_folds_splits = [], []\n",
|
| 95 |
+
" head, tail = 0, len(bucket_counts) - window\n",
|
| 96 |
+
" use_tail, quad_window = 0, int(window / 4)\n",
|
| 97 |
+
" while len(all_folds) < num_folds:\n",
|
| 98 |
+
" curr_ind = tail if use_tail else head\n",
|
| 99 |
+
" found = False\n",
|
| 100 |
+
" while not found and curr_ind >= 0 and curr_ind <= len(bucket_counts) - window:\n",
|
| 101 |
+
" ## Check that any of the validation buckets has both sets of labels.\n",
|
| 102 |
+
" val_found = False\n",
|
| 103 |
+
" for check_i in range(quad_window):\n",
|
| 104 |
+
" val_found |= bucket_counts[curr_ind + check_i][0] > 0 and bucket_counts[curr_ind + check_i][1] > 0\n",
|
| 105 |
+
" for check_i in range(window - quad_window, window):\n",
|
| 106 |
+
" val_found |= bucket_counts[curr_ind + check_i][0] > 0 and bucket_counts[curr_ind + check_i][1] > 0\n",
|
| 107 |
+
"\n",
|
| 108 |
+
" ## Check that any of the test buckets for test data has both labels.\n",
|
| 109 |
+
" test_found = False\n",
|
| 110 |
+
" for check_i in range(quad_window, 3*quad_window):\n",
|
| 111 |
+
" test_found |= bucket_counts[curr_ind + check_i][0] > 0 and bucket_counts[curr_ind + check_i][1] > 0\n",
|
| 112 |
+
"\n",
|
| 113 |
+
" found = val_found & test_found\n",
|
| 114 |
+
" if found:\n",
|
| 115 |
+
" found_ind = curr_ind\n",
|
| 116 |
+
" curr_ind += -step_size if use_tail else step_size\n",
|
| 117 |
+
"\n",
|
| 118 |
+
" val_test_interval = np.array([found_ind, found_ind + window]) * bucket_size\n",
|
| 119 |
+
" \n",
|
| 120 |
+
" this_fold = [bucket_size, (window-2)*bucket_size, bucket_size]\n",
|
| 121 |
+
" this_fold_splits = [\"val\", \"test\", \"val\"]\n",
|
| 122 |
+
" if 1.0 - val_test_interval[-1] > 0:\n",
|
| 123 |
+
" this_fold.append(1.0 - val_test_interval[-1])\n",
|
| 124 |
+
" this_fold_splits.append('train')\n",
|
| 125 |
+
" if val_test_interval[0] > 0:\n",
|
| 126 |
+
" this_fold = [val_test_interval[0]] + this_fold\n",
|
| 127 |
+
" this_fold_splits = ['train'] + this_fold_splits\n",
|
| 128 |
+
"\n",
|
| 129 |
+
" assert np.sum(this_fold) == 1.0\n",
|
| 130 |
+
" all_folds.append(this_fold)\n",
|
| 131 |
+
" all_folds_splits.append(this_fold_splits)\n",
|
| 132 |
+
"\n",
|
| 133 |
+
" if use_tail:\n",
|
| 134 |
+
" tail = curr_ind - 1 * step_size\n",
|
| 135 |
+
" else:\n",
|
| 136 |
+
" head = curr_ind + 1 * step_size\n",
|
| 137 |
+
" use_tail = 1 - use_tail\n",
|
| 138 |
+
"\n",
|
| 139 |
+
" return all_folds, all_folds_splits\n"
|
| 140 |
+
]
|
| 141 |
+
},
|
| 142 |
+
{
|
| 143 |
+
"cell_type": "code",
|
| 144 |
+
"execution_count": null,
|
| 145 |
+
"id": "34c7aa28",
|
| 146 |
+
"metadata": {},
|
| 147 |
+
"outputs": [
|
| 148 |
+
{
|
| 149 |
+
"name": "stdout",
|
| 150 |
+
"output_type": "stream",
|
| 151 |
+
"text": [
|
| 152 |
+
"Buckets: [ 0 154 308 462 616 770 924 1078 1232 1386 1540 1694 1848 2002\n",
|
| 153 |
+
" 2156 2310 2464 2618 2772 2926 3080]\n",
|
| 154 |
+
"bucket_counts: [{0: 68, 1: 86}, {0: 92, 1: 62}, {0: 123, 1: 31}, {0: 42, 1: 112}, {0: 25, 1: 129}, {0: 76, 1: 78}, {0: 65, 1: 89}, {0: 81, 1: 73}, {0: 65, 1: 89}, {0: 33, 1: 121}, {0: 23, 1: 131}, {0: 65, 1: 89}, {0: 75, 1: 79}, {0: 106, 1: 48}, {0: 51, 1: 103}, {0: 103, 1: 51}, {0: 74, 1: 80}, {0: 62, 1: 92}, {0: 154, 1: 0}, {0: 160, 1: 0}]\n",
|
| 155 |
+
"Buckets: [ 0 165 330 495 660 825 990 1155 1320 1485 1650 1815 1980 2145\n",
|
| 156 |
+
" 2310 2475 2640 2805 2970 3135]\n",
|
| 157 |
+
"bucket_counts: [{0: 75, 1: 90}, {0: 117, 1: 48}, {0: 116, 1: 49}, {0: 35, 1: 130}, {0: 19, 1: 146}, {0: 40, 1: 125}, {0: 86, 1: 79}, {0: 48, 1: 117}, {0: 115, 1: 50}, {0: 50, 1: 115}, {0: 28, 1: 137}, {0: 26, 1: 139}, {0: 121, 1: 44}, {0: 95, 1: 70}, {0: 73, 1: 92}, {0: 83, 1: 82}, {0: 105, 1: 60}, {0: 88, 1: 77}, {0: 330, 1: 0}]\n",
|
| 158 |
+
"Buckets: [3086 3187 3288 3389 3490 3591 3692 3793 3894 3995 4096 4197 4298 4399\n",
|
| 159 |
+
" 4500 4601 4702 4803 4904 5005 5106]\n",
|
| 160 |
+
"bucket_counts: [{0: 78, 1: 23}, {0: 46, 1: 55}, {0: 68, 1: 33}, {0: 101, 1: 0}, {0: 95, 1: 6}, {0: 30, 1: 71}, {0: 17, 1: 84}, {0: 42, 1: 59}, {0: 25, 1: 76}, {0: 48, 1: 53}, {0: 21, 1: 80}, {0: 31, 1: 70}, {0: 26, 1: 75}, {0: 25, 1: 76}, {0: 74, 1: 27}, {0: 33, 1: 68}, {0: 39, 1: 62}, {0: 59, 1: 42}, {0: 45, 1: 56}, {0: 113, 1: 0}]\n",
|
| 161 |
+
"Buckets: [3300 3588 3876 4164 4452 4740 5028 5316 5604 5892 6180 6468 6756 7044\n",
|
| 162 |
+
" 7332 7620 7908 8196 8484 8772 9060]\n",
|
| 163 |
+
"bucket_counts: [{0: 231, 1: 57}, {0: 124, 1: 164}, {0: 195, 1: 93}, {0: 288, 1: 0}, {0: 246, 1: 42}, {0: 64, 1: 224}, {0: 95, 1: 193}, {0: 85, 1: 203}, {0: 45, 1: 243}, {0: 120, 1: 168}, {0: 42, 1: 246}, {0: 115, 1: 173}, {0: 57, 1: 231}, {0: 71, 1: 217}, {0: 236, 1: 52}, {0: 85, 1: 203}, {0: 109, 1: 179}, {0: 193, 1: 95}, {0: 184, 1: 104}, {0: 302, 1: 0}]\n",
|
| 164 |
+
"Buckets: [5118 5184 5250 5316 5382 5448 5514 5580 5646 5712 5778 5844 5910 5976\n",
|
| 165 |
+
" 6042 6108 6174 6240 6306 6372 6438]\n",
|
| 166 |
+
"bucket_counts: [{0: 66, 1: 0}, {0: 66, 1: 0}, {0: 39, 1: 27}, {0: 46, 1: 20}, {0: 9, 1: 57}, {0: 38, 1: 28}, {0: 13, 1: 53}, {0: 19, 1: 47}, {0: 26, 1: 40}, {0: 20, 1: 46}, {0: 18, 1: 48}, {0: 12, 1: 54}, {0: 28, 1: 38}, {0: 34, 1: 32}, {0: 49, 1: 17}, {0: 28, 1: 38}, {0: 35, 1: 31}, {0: 25, 1: 41}, {0: 19, 1: 47}, {0: 76, 1: 2}]\n",
|
| 167 |
+
"Buckets: [ 9074 9140 9206 9272 9338 9404 9470 9536 9602 9668 9734 9800\n",
|
| 168 |
+
" 9866 9932 9998 10064 10130 10196 10262 10328 10394]\n",
|
| 169 |
+
"bucket_counts: [{0: 66, 1: 0}, {0: 66, 1: 0}, {0: 35, 1: 31}, {0: 58, 1: 8}, {0: 18, 1: 48}, {0: 36, 1: 30}, {0: 9, 1: 57}, {0: 20, 1: 46}, {0: 20, 1: 46}, {0: 22, 1: 44}, {0: 16, 1: 50}, {0: 7, 1: 59}, {0: 18, 1: 48}, {0: 28, 1: 38}, {0: 55, 1: 11}, {0: 39, 1: 27}, {0: 35, 1: 31}, {0: 21, 1: 45}, {0: 19, 1: 47}, {0: 81, 1: 3}]\n",
|
| 170 |
+
"Buckets: [6450 6529 6608 6687 6766 6845 6924 7003 7082 7161 7240 7319 7398 7477\n",
|
| 171 |
+
" 7556 7635 7714 7793 7872 7951 8030]\n",
|
| 172 |
+
"bucket_counts: [{0: 79, 1: 0}, {0: 79, 1: 0}, {0: 64, 1: 15}, {0: 52, 1: 27}, {0: 19, 1: 60}, {0: 51, 1: 28}, {0: 27, 1: 52}, {0: 20, 1: 59}, {0: 9, 1: 70}, {0: 23, 1: 56}, {0: 18, 1: 61}, {0: 56, 1: 23}, {0: 14, 1: 65}, {0: 26, 1: 53}, {0: 6, 1: 73}, {0: 37, 1: 42}, {0: 46, 1: 33}, {0: 25, 1: 54}, {0: 54, 1: 25}, {0: 92, 1: 1}]\n",
|
| 173 |
+
"Buckets: [10412 10509 10606 10703 10800 10897 10994 11091 11188 11285 11382 11479\n",
|
| 174 |
+
" 11576 11673 11770 11867 11964 12061 12158 12255 12352]\n",
|
| 175 |
+
"bucket_counts: [{0: 97, 1: 0}, {0: 86, 1: 11}, {0: 76, 1: 21}, {0: 24, 1: 73}, {0: 53, 1: 44}, {0: 36, 1: 61}, {0: 8, 1: 89}, {0: 17, 1: 80}, {0: 45, 1: 52}, {0: 97, 1: 0}, {0: 69, 1: 28}, {0: 56, 1: 41}, {0: 32, 1: 65}, {0: 21, 1: 76}, {0: 12, 1: 85}, {0: 28, 1: 69}, {0: 39, 1: 58}, {0: 28, 1: 69}, {0: 43, 1: 54}, {0: 110, 1: 1}]\n",
|
| 176 |
+
"Buckets: [8044 8095 8146 8197 8248 8299 8350 8401 8452 8503 8554 8605 8656 8707\n",
|
| 177 |
+
" 8758 8809 8860 8911 8962 9013 9064]\n",
|
| 178 |
+
"bucket_counts: [{0: 51, 1: 0}, {0: 51, 1: 0}, {0: 43, 1: 8}, {0: 4, 1: 47}, {0: 16, 1: 35}, {0: 16, 1: 35}, {0: 18, 1: 33}, {0: 6, 1: 45}, {0: 37, 1: 14}, {0: 42, 1: 9}, {0: 8, 1: 43}, {0: 0, 1: 51}, {0: 24, 1: 27}, {0: 51, 1: 0}, {0: 51, 1: 0}, {0: 28, 1: 23}, {0: 8, 1: 43}, {0: 24, 1: 27}, {0: 12, 1: 39}, {0: 24, 1: 35}]\n",
|
| 179 |
+
"Buckets: [12366 12499 12632 12765 12898 13031 13164 13297 13430 13563 13696 13829\n",
|
| 180 |
+
" 13962 14095 14228 14361 14494 14627 14760 14893 15026]\n",
|
| 181 |
+
"bucket_counts: [{0: 133, 1: 0}, {0: 133, 1: 0}, {0: 74, 1: 59}, {0: 9, 1: 124}, {0: 57, 1: 76}, {0: 22, 1: 111}, {0: 60, 1: 73}, {0: 48, 1: 85}, {0: 71, 1: 62}, {0: 133, 1: 0}, {0: 24, 1: 109}, {0: 15, 1: 118}, {0: 71, 1: 62}, {0: 133, 1: 0}, {0: 133, 1: 0}, {0: 50, 1: 83}, {0: 42, 1: 91}, {0: 30, 1: 103}, {0: 39, 1: 94}, {0: 60, 1: 87}]\n",
|
| 182 |
+
"Buckets: [9072 9112 9152 9192 9232 9272 9312 9352 9392 9432 9472 9512 9552 9592\n",
|
| 183 |
+
" 9632 9672 9712 9752 9792 9832 9872]\n",
|
| 184 |
+
"bucket_counts: [{0: 30, 1: 10}, {0: 29, 1: 11}, {0: 39, 1: 1}, {0: 15, 1: 25}, {0: 12, 1: 28}, {0: 27, 1: 13}, {0: 12, 1: 28}, {0: 16, 1: 24}, {0: 21, 1: 19}, {0: 20, 1: 20}, {0: 17, 1: 23}, {0: 18, 1: 22}, {0: 11, 1: 29}, {0: 15, 1: 25}, {0: 24, 1: 16}, {0: 19, 1: 21}, {0: 17, 1: 23}, {0: 30, 1: 10}, {0: 10, 1: 30}, {0: 24, 1: 28}]\n",
|
| 185 |
+
"Buckets: [15040 15079 15118 15157 15196 15235 15274 15313 15352 15391 15430 15469\n",
|
| 186 |
+
" 15508 15547 15586 15625 15664 15703 15742 15781]\n",
|
| 187 |
+
"bucket_counts: [{0: 35, 1: 4}, {0: 25, 1: 14}, {0: 38, 1: 1}, {0: 17, 1: 22}, {0: 7, 1: 32}, {0: 32, 1: 7}, {0: 12, 1: 27}, {0: 17, 1: 22}, {0: 15, 1: 24}, {0: 14, 1: 25}, {0: 19, 1: 20}, {0: 18, 1: 21}, {0: 7, 1: 32}, {0: 13, 1: 26}, {0: 21, 1: 18}, {0: 17, 1: 22}, {0: 20, 1: 19}, {0: 27, 1: 12}, {0: 36, 1: 42}]\n",
|
| 188 |
+
"Buckets: [ 9884 9942 10000 10058 10116 10174 10232 10290 10348 10406 10464 10522\n",
|
| 189 |
+
" 10580 10638 10696 10754 10812 10870 10928 10986 11044]\n",
|
| 190 |
+
"bucket_counts: [{0: 48, 1: 10}, {0: 39, 1: 19}, {0: 44, 1: 14}, {0: 23, 1: 35}, {0: 27, 1: 31}, {0: 17, 1: 41}, {0: 25, 1: 33}, {0: 9, 1: 49}, {0: 19, 1: 39}, {0: 18, 1: 40}, {0: 58, 1: 0}, {0: 58, 1: 0}, {0: 12, 1: 46}, {0: 13, 1: 45}, {0: 12, 1: 46}, {0: 14, 1: 44}, {0: 25, 1: 33}, {0: 14, 1: 44}, {0: 38, 1: 20}, {0: 76, 1: 0}]\n",
|
| 191 |
+
"Buckets: [15820 15877 15934 15991 16048 16105 16162 16219 16276 16333 16390 16447\n",
|
| 192 |
+
" 16504 16561 16618 16675 16732 16789 16846 16903 16960]\n",
|
| 193 |
+
"bucket_counts: [{0: 48, 1: 9}, {0: 38, 1: 19}, {0: 45, 1: 12}, {0: 19, 1: 38}, {0: 7, 1: 50}, {0: 32, 1: 25}, {0: 22, 1: 35}, {0: 15, 1: 42}, {0: 14, 1: 43}, {0: 16, 1: 41}, {0: 44, 1: 13}, {0: 57, 1: 0}, {0: 21, 1: 36}, {0: 16, 1: 41}, {0: 15, 1: 42}, {0: 13, 1: 44}, {0: 25, 1: 32}, {0: 23, 1: 34}, {0: 41, 1: 16}, {0: 61, 1: 0}]\n"
|
| 194 |
+
]
|
| 195 |
+
}
|
| 196 |
+
],
|
| 197 |
+
"source": [
|
| 198 |
+
"## Specify all subjects to compute the chronological folds for.\n",
|
| 199 |
+
"## By default we have the held out sessions (val/test) listed here.\n",
|
| 200 |
+
"ALL_SUBJECTS = [\n",
|
| 201 |
+
" \"HOLDSUBJ_1_HS1_1\",\n",
|
| 202 |
+
" \"HOLDSUBJ_2_HS2_6\",\n",
|
| 203 |
+
" \"HOLDSUBJ_3_HS3_0\",\n",
|
| 204 |
+
" \"HOLDSUBJ_4_HS4_0\",\n",
|
| 205 |
+
" \"HOLDSUBJ_6_HS6_4\",\n",
|
| 206 |
+
" \"HOLDSUBJ_7_HS7_0\",\n",
|
| 207 |
+
" \"HOLDSUBJ_10_HS10_0\",\n",
|
| 208 |
+
"\n",
|
| 209 |
+
" # \"SUBJ_2_S2_5\",\n",
|
| 210 |
+
" # \"SUBJ_4_S4_2\",\n",
|
| 211 |
+
"]\n",
|
| 212 |
+
"\n",
|
| 213 |
+
"## List all the metadata files that correspond to the segments to preprocess. Can optionally use\n",
|
| 214 |
+
"## keyword identifiers for each of the metadata files that need to be processed.\n",
|
| 215 |
+
"_METADATA_FNAMES = {\n",
|
| 216 |
+
" 'default_metadata': 'metadata_ee8e0.csv',\n",
|
| 217 |
+
"}\n",
|
| 218 |
+
"\n",
|
| 219 |
+
"## List all experiments for which the folds should be computed.\n",
|
| 220 |
+
"# _ALL_EXPERIMENTS = [\"sentence_onset_time\", \"speech_vs_nonspeech_time\", \"volume\", \"optical_flow\"]\n",
|
| 221 |
+
"_ALL_EXPERIMENTS = [\"sentence_onset_time\", \"speech_vs_nonspeech_time\"]\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"_SEGMENT_DIR = 'braintreebank_data_segments/{0}'\n",
|
| 224 |
+
"\n",
|
| 225 |
+
"## These are the recommended default settings for computing the folds.\n",
|
| 226 |
+
"bucket_size = 0.05 # Each bucket is 5% duration in samples\n",
|
| 227 |
+
"base_step_size = 1 # We take increments of base_step_size * 5% in samples when constructing buckets.\n",
|
| 228 |
+
"base_window = 1 # Count number of samples per base_window * 5% interval per bucket. Should match base_step_size ideally.\n",
|
| 229 |
+
"step_size = 2 # We take increments of step_size * bucket_size (5%) when looking for buckets.\n",
|
| 230 |
+
"window = 4 # Targeting 20% of data for val and test (i.e., 4 buckets combined for val and test).\n",
|
| 231 |
+
"num_folds = 5 # Number of folds to generate.\n",
|
| 232 |
+
"\n",
|
| 233 |
+
"subject_folds = {}\n",
|
| 234 |
+
"for metadata_setting in _METADATA_FNAMES.keys():\n",
|
| 235 |
+
" metadata_setting_folds = defaultdict(dict)\n",
|
| 236 |
+
"\n",
|
| 237 |
+
" for subject_session in ALL_SUBJECTS:\n",
|
| 238 |
+
" for experiment in _ALL_EXPERIMENTS:\n",
|
| 239 |
+
"\n",
|
| 240 |
+
" fpath = _SEGMENT_DIR.format(experiment)\n",
|
| 241 |
+
" metadata_fname = _METADATA_FNAMES[metadata_setting]\n",
|
| 242 |
+
" metadata = load_metadata(os.path.join(fpath, metadata_fname))\n",
|
| 243 |
+
"\n",
|
| 244 |
+
" subject_rows_indices = metadata.get_indices_matching_cols_values(\n",
|
| 245 |
+
" [\"subject_session\", \"experiment\"], [subject_session, experiment]\n",
|
| 246 |
+
" )\n",
|
| 247 |
+
"\n",
|
| 248 |
+
" per_label_subject_rows_indices = [0, 0]\n",
|
| 249 |
+
" for i in range(2): # 2 = negative/positive labels.\n",
|
| 250 |
+
" per_label_subject_rows_indices[i] = (\n",
|
| 251 |
+
" metadata.get_indices_matching_cols_values(\n",
|
| 252 |
+
" [\"subject_session\", \"experiment\", \"label\"],\n",
|
| 253 |
+
" [subject_session, experiment, i],\n",
|
| 254 |
+
" )\n",
|
| 255 |
+
" )\n",
|
| 256 |
+
"\n",
|
| 257 |
+
" all_folds, all_folds_splits = generate_folds(\n",
|
| 258 |
+
" subject_rows_indices,\n",
|
| 259 |
+
" per_label_subject_rows_indices,\n",
|
| 260 |
+
" bucket_size,\n",
|
| 261 |
+
" step_size,\n",
|
| 262 |
+
" base_step_size,\n",
|
| 263 |
+
" window,\n",
|
| 264 |
+
" base_window,\n",
|
| 265 |
+
" num_folds=num_folds\n",
|
| 266 |
+
" )\n",
|
| 267 |
+
"\n",
|
| 268 |
+
" metadata_setting_folds[subject_session][experiment] = (all_folds, all_folds_splits)\n",
|
| 269 |
+
"\n",
|
| 270 |
+
" subject_folds[metadata_setting] = metadata_setting_folds"
|
| 271 |
+
]
|
| 272 |
+
},
|
| 273 |
+
{
|
| 274 |
+
"cell_type": "code",
|
| 275 |
+
"execution_count": 6,
|
| 276 |
+
"id": "aacfb210",
|
| 277 |
+
"metadata": {},
|
| 278 |
+
"outputs": [
|
| 279 |
+
{
|
| 280 |
+
"name": "stdout",
|
| 281 |
+
"output_type": "stream",
|
| 282 |
+
"text": [
|
| 283 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_1_HS1_1, experiment:sentence_onset_time\n",
|
| 284 |
+
"\n",
|
| 285 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 286 |
+
"Split statistics: {'train': Counter({1: 1252, 0: 1218}), 'val': Counter({1: 198, 0: 110}), 'test': Counter({0: 215, 1: 93})}\n",
|
| 287 |
+
"Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
|
| 288 |
+
"Split statistics: {'train': Counter({1: 1371, 0: 1097}), 'val': Counter({0: 228, 1: 82}), 'test': Counter({0: 218, 1: 90})}\n",
|
| 289 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 290 |
+
"Split statistics: {'train': Counter({0: 1296, 1: 1174}), 'val': Counter({1: 202, 0: 106}), 'test': Counter({1: 167, 0: 141})}\n",
|
| 291 |
+
"Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
|
| 292 |
+
"Split statistics: {'train': Counter({1: 1262, 0: 1208}), 'val': Counter({0: 178, 1: 130}), 'test': Counter({0: 157, 1: 151})}\n",
|
| 293 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 294 |
+
"Split statistics: {'train': Counter({0: 1355, 1: 1115}), 'val': Counter({1: 175, 0: 133}), 'test': Counter({1: 253, 0: 55})}\n",
|
| 295 |
+
"\n",
|
| 296 |
+
"\n",
|
| 297 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_2_HS2_6, experiment:sentence_onset_time\n",
|
| 298 |
+
"\n",
|
| 299 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 300 |
+
"Split statistics: {'train': Counter({1: 905, 0: 722}), 'val': Counter({0: 179, 1: 23}), 'test': Counter({0: 115, 1: 88})}\n",
|
| 301 |
+
"Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
|
| 302 |
+
"Split statistics: {'train': Counter({1: 860, 0: 765}), 'val': Counter({0: 140, 1: 64}), 'test': Counter({0: 111, 1: 92})}\n",
|
| 303 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 304 |
+
"Split statistics: {'train': Counter({0: 834, 1: 793}), 'val': Counter({0: 133, 1: 69}), 'test': Counter({1: 154, 0: 49})}\n",
|
| 305 |
+
"Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
|
| 306 |
+
"Split statistics: {'train': Counter({0: 856, 1: 771}), 'val': Counter({1: 146, 0: 56}), 'test': Counter({0: 104, 1: 99})}\n",
|
| 307 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 308 |
+
"Split statistics: {'train': Counter({0: 890, 1: 737}), 'val': Counter({1: 146, 0: 56}), 'test': Counter({1: 133, 0: 70})}\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"\n",
|
| 311 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_3_HS3_0, experiment:sentence_onset_time\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 314 |
+
"Split statistics: {'train': Counter({1: 618, 0: 449}), 'val': Counter({0: 111, 1: 21}), 'test': Counter({0: 106, 1: 27})}\n",
|
| 315 |
+
"Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
|
| 316 |
+
"Split statistics: {'train': Counter({1: 552, 0: 513}), 'val': Counter({0: 101, 1: 33}), 'test': Counter({1: 81, 0: 52})}\n",
|
| 317 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 318 |
+
"Split statistics: {'train': Counter({0: 586, 1: 481}), 'val': Counter({1: 104, 0: 28}), 'test': Counter({1: 81, 0: 52})}\n",
|
| 319 |
+
"Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
|
| 320 |
+
"Split statistics: {'train': Counter({1: 539, 0: 528}), 'val': Counter({1: 76, 0: 56}), 'test': Counter({0: 82, 1: 51})}\n",
|
| 321 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 322 |
+
"Split statistics: {'train': Counter({0: 589, 1: 478}), 'val': Counter({1: 92, 0: 40}), 'test': Counter({1: 96, 0: 37})}\n",
|
| 323 |
+
"\n",
|
| 324 |
+
"\n",
|
| 325 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_4_HS4_0, experiment:sentence_onset_time\n",
|
| 326 |
+
"\n",
|
| 327 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 328 |
+
"Split statistics: {'train': Counter({1: 754, 0: 523}), 'val': Counter({0: 130, 1: 28}), 'test': Counter({0: 144, 1: 15})}\n",
|
| 329 |
+
"Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
|
| 330 |
+
"Split statistics: {'train': Counter({1: 689, 0: 586}), 'val': Counter({0: 125, 1: 35}), 'test': Counter({0: 86, 1: 73})}\n",
|
| 331 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 332 |
+
"Split statistics: {'train': Counter({0: 680, 1: 597}), 'val': Counter({1: 119, 0: 39}), 'test': Counter({1: 81, 0: 78})}\n",
|
| 333 |
+
"Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
|
| 334 |
+
"Split statistics: {'train': Counter({0: 710, 1: 567}), 'val': Counter({1: 102, 0: 56}), 'test': Counter({1: 128, 0: 31})}\n",
|
| 335 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 336 |
+
"Split statistics: {'train': Counter({0: 691, 1: 586}), 'val': Counter({1: 98, 0: 60}), 'test': Counter({1: 113, 0: 46})}\n",
|
| 337 |
+
"\n",
|
| 338 |
+
"\n",
|
| 339 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_6_HS6_4, experiment:sentence_onset_time\n",
|
| 340 |
+
"\n",
|
| 341 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 342 |
+
"Split statistics: {'train': Counter({1: 459, 0: 365}), 'val': Counter({0: 55, 1: 47}), 'test': Counter({0: 94, 1: 8})}\n",
|
| 343 |
+
"Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
|
| 344 |
+
"Split statistics: {'train': Counter({0: 448, 1: 374}), 'val': Counter({1: 70, 0: 34}), 'test': Counter({1: 70, 0: 32})}\n",
|
| 345 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 346 |
+
"Split statistics: {'train': Counter({0: 458, 1: 366}), 'val': Counter({1: 79, 0: 23}), 'test': Counter({1: 69, 0: 33})}\n",
|
| 347 |
+
"Run_ratio: [0.5, 0.05, 0.1, 0.05, 0.29999999999999993]\n",
|
| 348 |
+
"Split statistics: {'train': Counter({0: 427, 1: 397}), 'val': Counter({0: 59, 1: 43}), 'test': Counter({1: 74, 0: 28})}\n",
|
| 349 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 350 |
+
"Split statistics: {'train': Counter({0: 427, 1: 397}), 'val': Counter({1: 62, 0: 40}), 'test': Counter({1: 55, 0: 47})}\n",
|
| 351 |
+
"\n",
|
| 352 |
+
"\n",
|
| 353 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_7_HS7_0, experiment:sentence_onset_time\n",
|
| 354 |
+
"\n",
|
| 355 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 356 |
+
"Split statistics: {'train': Counter({1: 358, 0: 293}), 'val': Counter({0: 44, 1: 36}), 'test': Counter({0: 69, 1: 12})}\n",
|
| 357 |
+
"Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
|
| 358 |
+
"Split statistics: {'train': Counter({0: 330, 1: 319}), 'val': Counter({0: 45, 1: 37}), 'test': Counter({1: 50, 0: 31})}\n",
|
| 359 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 360 |
+
"Split statistics: {'train': Counter({0: 337, 1: 314}), 'val': Counter({1: 50, 0: 30}), 'test': Counter({1: 42, 0: 39})}\n",
|
| 361 |
+
"Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
|
| 362 |
+
"Split statistics: {'train': Counter({0: 333, 1: 318}), 'val': Counter({1: 43, 0: 37}), 'test': Counter({1: 45, 0: 36})}\n",
|
| 363 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 364 |
+
"Split statistics: {'train': Counter({0: 331, 1: 320}), 'val': Counter({0: 41, 1: 39}), 'test': Counter({1: 47, 0: 34})}\n",
|
| 365 |
+
"\n",
|
| 366 |
+
"\n",
|
| 367 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_10_HS10_0, experiment:sentence_onset_time\n",
|
| 368 |
+
"\n",
|
| 369 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 370 |
+
"Split statistics: {'train': Counter({1: 510, 0: 435}), 'val': Counter({0: 70, 1: 46}), 'test': Counter({0: 84, 1: 33})}\n",
|
| 371 |
+
"Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
|
| 372 |
+
"Split statistics: {'train': Counter({1: 495, 0: 447}), 'val': Counter({0: 76, 1: 43}), 'test': Counter({0: 66, 1: 51})}\n",
|
| 373 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 374 |
+
"Split statistics: {'train': Counter({0: 511, 1: 434}), 'val': Counter({1: 80, 0: 36}), 'test': Counter({1: 75, 0: 42})}\n",
|
| 375 |
+
"Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
|
| 376 |
+
"Split statistics: {'train': Counter({0: 534, 1: 411}), 'val': Counter({1: 89, 0: 27}), 'test': Counter({1: 89, 0: 28})}\n",
|
| 377 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 378 |
+
"Split statistics: {'train': Counter({1: 509, 0: 436}), 'val': Counter({0: 74, 1: 42}), 'test': Counter({0: 79, 1: 38})}\n",
|
| 379 |
+
"\n",
|
| 380 |
+
"\n",
|
| 381 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_1_HS1_1, experiment:speech_vs_nonspeech_time\n",
|
| 382 |
+
"\n",
|
| 383 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 384 |
+
"Split statistics: {'train': Counter({1: 1333, 0: 1307}), 'val': Counter({1: 220, 0: 110}), 'test': Counter({0: 233, 1: 97})}\n",
|
| 385 |
+
"Run_ratio: [0.75, 0.05, 0.1, 0.05, 0.04999999999999993]\n",
|
| 386 |
+
"Split statistics: {'train': Counter({1: 1431, 0: 1209}), 'val': Counter({0: 248, 1: 82}), 'test': Counter({0: 193, 1: 137})}\n",
|
| 387 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 388 |
+
"Split statistics: {'train': Counter({0: 1457, 1: 1183}), 'val': Counter({1: 263, 0: 67}), 'test': Counter({1: 204, 0: 126})}\n",
|
| 389 |
+
"Run_ratio: [0.55, 0.05, 0.1, 0.05, 0.25]\n",
|
| 390 |
+
"Split statistics: {'train': Counter({0: 1335, 1: 1305}), 'val': Counter({1: 231, 0: 99}), 'test': Counter({0: 216, 1: 114})}\n",
|
| 391 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 392 |
+
"Split statistics: {'train': Counter({0: 1431, 1: 1209}), 'val': Counter({1: 189, 0: 141}), 'test': Counter({1: 252, 0: 78})}\n",
|
| 393 |
+
"\n",
|
| 394 |
+
"\n",
|
| 395 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_2_HS2_6, experiment:speech_vs_nonspeech_time\n",
|
| 396 |
+
"\n",
|
| 397 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 398 |
+
"Split statistics: {'train': Counter({1: 2573, 0: 2048}), 'val': Counter({0: 519, 1: 57}), 'test': Counter({0: 320, 1: 257})}\n",
|
| 399 |
+
"Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
|
| 400 |
+
"Split statistics: {'train': Counter({1: 2511, 0: 2108}), 'val': Counter({0: 396, 1: 182}), 'test': Counter({0: 383, 1: 194})}\n",
|
| 401 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 402 |
+
"Split statistics: {'train': Counter({0: 2399, 1: 2222}), 'val': Counter({0: 329, 1: 247}), 'test': Counter({1: 418, 0: 159})}\n",
|
| 403 |
+
"Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
|
| 404 |
+
"Split statistics: {'train': Counter({0: 2431, 1: 2190}), 'val': Counter({1: 434, 0: 142}), 'test': Counter({0: 314, 1: 263})}\n",
|
| 405 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 406 |
+
"Split statistics: {'train': Counter({0: 2565, 1: 2056}), 'val': Counter({1: 418, 0: 158}), 'test': Counter({1: 413, 0: 164})}\n",
|
| 407 |
+
"\n",
|
| 408 |
+
"\n",
|
| 409 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_3_HS3_0, experiment:speech_vs_nonspeech_time\n",
|
| 410 |
+
"\n",
|
| 411 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 412 |
+
"Split statistics: {'train': Counter({1: 630, 0: 443}), 'val': Counter({0: 125, 1: 7}), 'test': Counter({0: 101, 1: 32})}\n",
|
| 413 |
+
"Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
|
| 414 |
+
"Split statistics: {'train': Counter({1: 555, 0: 515}), 'val': Counter({0: 105, 1: 30}), 'test': Counter({1: 84, 0: 49})}\n",
|
| 415 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 416 |
+
"Split statistics: {'train': Counter({0: 588, 1: 485}), 'val': Counter({1: 95, 0: 37}), 'test': Counter({1: 89, 0: 44})}\n",
|
| 417 |
+
"Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
|
| 418 |
+
"Split statistics: {'train': Counter({1: 544, 0: 529}), 'val': Counter({1: 85, 0: 47}), 'test': Counter({0: 93, 1: 40})}\n",
|
| 419 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 420 |
+
"Split statistics: {'train': Counter({0: 604, 1: 469}), 'val': Counter({1: 100, 0: 32}), 'test': Counter({1: 100, 0: 33})}\n",
|
| 421 |
+
"\n",
|
| 422 |
+
"\n",
|
| 423 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_4_HS4_0, experiment:speech_vs_nonspeech_time\n",
|
| 424 |
+
"\n",
|
| 425 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 426 |
+
"Split statistics: {'train': Counter({1: 872, 0: 693}), 'val': Counter({0: 122, 1: 72}), 'test': Counter({0: 162, 1: 33})}\n",
|
| 427 |
+
"Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
|
| 428 |
+
"Split statistics: {'train': Counter({1: 805, 0: 758}), 'val': Counter({0: 146, 1: 50}), 'test': Counter({1: 122, 0: 73})}\n",
|
| 429 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 430 |
+
"Split statistics: {'train': Counter({0: 863, 1: 702}), 'val': Counter({1: 125, 0: 69}), 'test': Counter({1: 150, 0: 45})}\n",
|
| 431 |
+
"Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
|
| 432 |
+
"Split statistics: {'train': Counter({0: 883, 1: 682}), 'val': Counter({1: 132, 0: 62}), 'test': Counter({1: 163, 0: 32})}\n",
|
| 433 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 434 |
+
"Split statistics: {'train': Counter({1: 852, 0: 713}), 'val': Counter({0: 102, 1: 92}), 'test': Counter({0: 162, 1: 33})}\n",
|
| 435 |
+
"\n",
|
| 436 |
+
"\n",
|
| 437 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_6_HS6_4, experiment:speech_vs_nonspeech_time\n",
|
| 438 |
+
"\n",
|
| 439 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 440 |
+
"Split statistics: {'train': Counter({1: 1153, 0: 988}), 'val': Counter({0: 142, 1: 124}), 'test': Counter({0: 207, 1: 60})}\n",
|
| 441 |
+
"Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
|
| 442 |
+
"Split statistics: {'train': Counter({0: 1168, 1: 971}), 'val': Counter({1: 165, 0: 103}), 'test': Counter({1: 201, 0: 66})}\n",
|
| 443 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 444 |
+
"Split statistics: {'train': Counter({0: 1149, 1: 992}), 'val': Counter({1: 162, 0: 104}), 'test': Counter({1: 183, 0: 84})}\n",
|
| 445 |
+
"Run_ratio: [0.5, 0.05, 0.1, 0.05, 0.29999999999999993]\n",
|
| 446 |
+
"Split statistics: {'train': Counter({0: 1089, 1: 1052}), 'val': Counter({0: 157, 1: 109}), 'test': Counter({1: 176, 0: 91})}\n",
|
| 447 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 448 |
+
"Split statistics: {'train': Counter({0: 1095, 1: 1046}), 'val': Counter({1: 178, 0: 88}), 'test': Counter({0: 154, 1: 113})}\n",
|
| 449 |
+
"\n",
|
| 450 |
+
"\n",
|
| 451 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_7_HS7_0, experiment:speech_vs_nonspeech_time\n",
|
| 452 |
+
"\n",
|
| 453 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 454 |
+
"Split statistics: {'train': Counter({1: 349, 0: 275}), 'val': Counter({0: 52, 1: 26}), 'test': Counter({0: 63, 1: 15})}\n",
|
| 455 |
+
"Run_ratio: [0.75, 0.05, 0.1, 0.05, 0.04999999999999993]\n",
|
| 456 |
+
"Split statistics: {'train': Counter({0: 313, 1: 311}), 'val': Counter({1: 48, 0: 30}), 'test': Counter({0: 47, 1: 31})}\n",
|
| 457 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 458 |
+
"Split statistics: {'train': Counter({0: 322, 1: 302}), 'val': Counter({1: 54, 0: 24}), 'test': Counter({0: 44, 1: 34})}\n",
|
| 459 |
+
"Run_ratio: [0.55, 0.05, 0.1, 0.05, 0.25]\n",
|
| 460 |
+
"Split statistics: {'train': Counter({0: 331, 1: 293}), 'val': Counter({1: 39, 0: 39}), 'test': Counter({1: 58, 0: 20})}\n",
|
| 461 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 462 |
+
"Split statistics: {'train': Counter({0: 324, 1: 300}), 'val': Counter({1: 45, 0: 33}), 'test': Counter({1: 45, 0: 33})}\n",
|
| 463 |
+
"\n",
|
| 464 |
+
"\n",
|
| 465 |
+
"metadata_setting:default_metadata, subject_session:HOLDSUBJ_10_HS10_0, experiment:speech_vs_nonspeech_time\n",
|
| 466 |
+
"\n",
|
| 467 |
+
"Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
|
| 468 |
+
"Split statistics: {'train': Counter({1: 494, 0: 422}), 'val': Counter({0: 67, 1: 47}), 'test': Counter({0: 83, 1: 31})}\n",
|
| 469 |
+
"Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
|
| 470 |
+
"Split statistics: {'train': Counter({1: 493, 0: 422}), 'val': Counter({0: 83, 1: 32}), 'test': Counter({0: 67, 1: 47})}\n",
|
| 471 |
+
"Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
|
| 472 |
+
"Split statistics: {'train': Counter({0: 496, 1: 420}), 'val': Counter({1: 92, 0: 22}), 'test': Counter({1: 60, 0: 54})}\n",
|
| 473 |
+
"Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
|
| 474 |
+
"Split statistics: {'train': Counter({0: 509, 1: 407}), 'val': Counter({1: 82, 0: 32}), 'test': Counter({1: 83, 0: 31})}\n",
|
| 475 |
+
"Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
|
| 476 |
+
"Split statistics: {'train': Counter({1: 476, 0: 440}), 'val': Counter({0: 71, 1: 43}), 'test': Counter({0: 61, 1: 53})}\n",
|
| 477 |
+
"\n",
|
| 478 |
+
"\n"
|
| 479 |
+
]
|
| 480 |
+
}
|
| 481 |
+
],
|
| 482 |
+
"source": [
|
| 483 |
+
"## Following code will compute the statistics associated with each fold.\n",
|
| 484 |
+
"all_output_dicts = {}\n",
|
| 485 |
+
"for metadata_setting in _METADATA_FNAMES.keys():\n",
|
| 486 |
+
" output_dict = {} # {experiment_name: {subject_session: [(ratio1, split1), (ratio2, split2), ...]}}\n",
|
| 487 |
+
" for experiment in _ALL_EXPERIMENTS:\n",
|
| 488 |
+
" output_dict[experiment] = {}\n",
|
| 489 |
+
"\n",
|
| 490 |
+
" fpath = _SEGMENT_DIR.format(experiment)\n",
|
| 491 |
+
" metadata_fname = _METADATA_FNAMES[metadata_setting]\n",
|
| 492 |
+
" metadata = load_metadata(os.path.join(fpath, metadata_fname))\n",
|
| 493 |
+
"\n",
|
| 494 |
+
" for subject_session in ALL_SUBJECTS:\n",
|
| 495 |
+
" print(\n",
|
| 496 |
+
" f'metadata_setting:{metadata_setting}, '\n",
|
| 497 |
+
" f'subject_session:{subject_session}, '\n",
|
| 498 |
+
" f'experiment:{experiment}\\n'\n",
|
| 499 |
+
" )\n",
|
| 500 |
+
"\n",
|
| 501 |
+
" subject_rows_indices = metadata.get_indices_matching_cols_values(\n",
|
| 502 |
+
" [\"subject_session\", \"experiment\"], [subject_session, experiment]\n",
|
| 503 |
+
" )\n",
|
| 504 |
+
" n_segments = len(subject_rows_indices)\n",
|
| 505 |
+
"\n",
|
| 506 |
+
" folds, splits = subject_folds[metadata_setting][subject_session][experiment]\n",
|
| 507 |
+
" out_tuples = []\n",
|
| 508 |
+
" for run_ratio, run_splits in zip(folds, splits):\n",
|
| 509 |
+
" counts = (np.array(run_ratio) * n_segments).astype(int)\n",
|
| 510 |
+
" counts[-1] = n_segments - sum(counts[:-1])\n",
|
| 511 |
+
"\n",
|
| 512 |
+
" print(f\"Run_ratio: {run_ratio}\")\n",
|
| 513 |
+
"\n",
|
| 514 |
+
" agg_split_counts = {'train': Counter(), 'val': Counter(), 'test': Counter()}\n",
|
| 515 |
+
" sum_now = 0\n",
|
| 516 |
+
" for c, split in zip(counts, run_splits):\n",
|
| 517 |
+
" label_split_indices = subject_rows_indices[sum_now : sum_now + c]\n",
|
| 518 |
+
" sum_now += c\n",
|
| 519 |
+
" agg_split_counts[split].update(\n",
|
| 520 |
+
" metadata._df.iloc[label_split_indices].label.to_numpy()\n",
|
| 521 |
+
" )\n",
|
| 522 |
+
"\n",
|
| 523 |
+
" print(f'Split statistics: {agg_split_counts}')\n",
|
| 524 |
+
" out_tuples.append((run_ratio, run_splits))\n",
|
| 525 |
+
" print('\\n')\n",
|
| 526 |
+
"\n",
|
| 527 |
+
" output_dict[experiment][subject_session] = out_tuples\n",
|
| 528 |
+
"\n",
|
| 529 |
+
" all_output_dicts[metadata_setting] = output_dict"
|
| 530 |
+
]
|
| 531 |
+
},
|
| 532 |
+
{
|
| 533 |
+
"cell_type": "code",
|
| 534 |
+
"execution_count": 7,
|
| 535 |
+
"id": "22372166",
|
| 536 |
+
"metadata": {},
|
| 537 |
+
"outputs": [
|
| 538 |
+
{
|
| 539 |
+
"name": "stdout",
|
| 540 |
+
"output_type": "stream",
|
| 541 |
+
"text": [
|
| 542 |
+
"/data/seyedesa/njepa/public_release_test/data_nov30_15_00/sentence_onset_time\n",
|
| 543 |
+
"/data/seyedesa/njepa/public_release_test/data_nov30_15_00/speech_vs_nonspeech_time\n"
|
| 544 |
+
]
|
| 545 |
+
}
|
| 546 |
+
],
|
| 547 |
+
"source": [
|
| 548 |
+
"## Save out the data in the format expected in braintreebank_dataset.py.\n",
|
| 549 |
+
"import pickle\n",
|
| 550 |
+
"\n",
|
| 551 |
+
"for fb_setting, fb_setting_output in all_output_dicts.items():\n",
|
| 552 |
+
" out_fname = f\"{Path(_METADATA_FNAMES[fb_setting]).stem}_folds.pkl\"\n",
|
| 553 |
+
"\n",
|
| 554 |
+
" for experiment, experiment_output in fb_setting_output.items():\n",
|
| 555 |
+
" out_path = _SEGMENT_DIR.format(experiment)\n",
|
| 556 |
+
" print(out_path)\n",
|
| 557 |
+
" with open(os.path.join(out_path, out_fname), 'wb') as file:\n",
|
| 558 |
+
" pickle.dump(experiment_output, file)"
|
| 559 |
+
]
|
| 560 |
+
},
|
| 561 |
+
{
|
| 562 |
+
"cell_type": "code",
|
| 563 |
+
"execution_count": 8,
|
| 564 |
+
"id": "1e6e1189",
|
| 565 |
+
"metadata": {},
|
| 566 |
+
"outputs": [
|
| 567 |
+
{
|
| 568 |
+
"name": "stdout",
|
| 569 |
+
"output_type": "stream",
|
| 570 |
+
"text": [
|
| 571 |
+
"sentence_onset_time\n",
|
| 572 |
+
"{'HOLDSUBJ_1_HS1_1': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_2_HS2_6': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_3_HS3_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_4_HS4_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_6_HS6_4': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.5, 0.05, 0.1, 0.05, 0.29999999999999993], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 
'train'])], 'HOLDSUBJ_7_HS7_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_10_HS10_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])]}\n",
|
| 573 |
+
"\n",
|
| 574 |
+
"\n",
|
| 575 |
+
"speech_vs_nonspeech_time\n",
|
| 576 |
+
"{'HOLDSUBJ_1_HS1_1': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.75, 0.05, 0.1, 0.05, 0.04999999999999993], ['train', 'val', 'test', 'val', 'train']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.55, 0.05, 0.1, 0.05, 0.25], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_2_HS2_6': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_3_HS3_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_4_HS4_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_6_HS6_4': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.5, 0.05, 0.1, 0.05, 0.29999999999999993], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 
'val', 'train'])], 'HOLDSUBJ_7_HS7_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.75, 0.05, 0.1, 0.05, 0.04999999999999993], ['train', 'val', 'test', 'val', 'train']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.55, 0.05, 0.1, 0.05, 0.25], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_10_HS10_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])]}\n",
|
| 577 |
+
"\n",
|
| 578 |
+
"\n"
|
| 579 |
+
]
|
| 580 |
+
}
|
| 581 |
+
],
|
| 582 |
+
"source": [
|
| 583 |
+
"## Checking output was correct.\n",
|
| 584 |
+
"for fb_setting, fb_setting_output in all_output_dicts.items():\n",
|
| 585 |
+
" out_fname = f\"{Path(_METADATA_FNAMES[fb_setting]).stem}_folds.pkl\"\n",
|
| 586 |
+
"\n",
|
| 587 |
+
" for experiment, experiment_output in fb_setting_output.items():\n",
|
| 588 |
+
" out_path = _SEGMENT_DIR.format(experiment)\n",
|
| 589 |
+
" with open(os.path.join(out_path, out_fname), 'rb') as file:\n",
|
| 590 |
+
" datatmp = pickle.load(file)\n",
|
| 591 |
+
" print(experiment)\n",
|
| 592 |
+
" print(datatmp)\n",
|
| 593 |
+
" print('\\n')"
|
| 594 |
+
]
|
| 595 |
+
},
|
| 596 |
+
{
|
| 597 |
+
"cell_type": "code",
|
| 598 |
+
"execution_count": null,
|
| 599 |
+
"id": "77052232",
|
| 600 |
+
"metadata": {},
|
| 601 |
+
"outputs": [],
|
| 602 |
+
"source": []
|
| 603 |
+
}
|
| 604 |
+
],
|
| 605 |
+
"metadata": {
|
| 606 |
+
"kernelspec": {
|
| 607 |
+
"display_name": "venv",
|
| 608 |
+
"language": "python",
|
| 609 |
+
"name": "python3"
|
| 610 |
+
},
|
| 611 |
+
"language_info": {
|
| 612 |
+
"codemirror_mode": {
|
| 613 |
+
"name": "ipython",
|
| 614 |
+
"version": 3
|
| 615 |
+
},
|
| 616 |
+
"file_extension": ".py",
|
| 617 |
+
"mimetype": "text/x-python",
|
| 618 |
+
"name": "python",
|
| 619 |
+
"nbconvert_exporter": "python",
|
| 620 |
+
"pygments_lexer": "ipython3",
|
| 621 |
+
"version": "3.8.10"
|
| 622 |
+
}
|
| 623 |
+
},
|
| 624 |
+
"nbformat": 4,
|
| 625 |
+
"nbformat_minor": 5
|
| 626 |
+
}
|
barista/models/TSEncoder2D.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Source code based on publicly released dilated CNN models as found in
|
| 2 |
+
## SimTS model: https://github.com/xingyu617/SimTS_Representation_Learning/blob/main/models/dilation.py
|
| 3 |
+
## and
|
| 4 |
+
## TS2Vec repo: https://github.com/zhihanyue/ts2vec/blob/main/models/dilated_conv.py
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch.utils.checkpoint import checkpoint
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def init_weights(m):
    """Initialize a conv or linear layer's parameters in place.

    Conv2d layers receive Kaiming-normal weights and zeroed biases; Linear
    layers receive zeroed biases (their weights keep PyTorch's default init).
    Modules of any other type are left untouched.

    Relevant reading material:
    https://pytorch.org/docs/stable/nn.init.html
    https://github.com/pytorch/vision/blob/309bd7a1512ad9ff0e9729fbdad043cb3472e4cb/torchvision/models/densenet.py#L203

    Args:
        m: nn.Module whose parameters should be initialized.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
        # Layers built with bias=False have m.bias is None; skip them instead
        # of crashing on the attribute access.
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, nn.Linear):
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SamePadConv(nn.Module):
    """1D-style convolution (expressed as Conv2d) that preserves temporal length."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
    ):
        """Padded convolution so input and output share the same time length.

        Args:
            in_channels: Input channel count.
            out_channels: Output channel count.
            kernel_size: Temporal kernel size.
            stride: Temporal stride.
            dilation: Temporal dilation.
            groups: Grouped-convolution group count.
        """
        super().__init__()
        # Effective temporal span of the kernel once dilation is applied.
        self.receptive_field = (kernel_size - 1) * dilation + 1
        half_pad = self.receptive_field // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            (1, kernel_size),
            padding=(0, half_pad),
            stride=(1, stride),
            dilation=(1, dilation),
            groups=groups,
        )

        init_weights(self.conv)

        # An even receptive field pads one extra timestep; trim it post-conv.
        self.remove = 1 if self.receptive_field % 2 == 0 else 0

    def forward(self, x):
        out = self.conv(x)
        return out[:, :, :, : -self.remove] if self.remove > 0 else out
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        dilation,
        final=False,
        enable_checkpointing=False,
    ):
        """Residual block of two same-padded convolutions.

        Args:
            in_channels: Input channel count.
            out_channels: Output channel count.
            kernel_size: Convolution kernel size.
            stride: Convolution stride size.
            dilation: Convolution dilation amount.
            final: Whether this is the last block in the stack; forces a
                projection on the residual path.
            enable_checkpointing: Checkpoint intermediate activations to trade
                compute for memory. Default False.
        """
        super().__init__()

        self.enable_checkpointing = enable_checkpointing

        self.conv1 = SamePadConv(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
        )
        self.conv2 = SamePadConv(
            out_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
        )

        # The residual path needs a 1x1 projection whenever shapes can differ:
        # channel mismatch, a strided block (two convs => stride**2 overall),
        # or the designated final block.
        needs_projection = in_channels != out_channels or final or stride != 1
        if needs_projection:
            self.projector = nn.Conv2d(
                in_channels, out_channels, kernel_size=(1, 1), stride=(1, stride**2),
            )
            init_weights(self.projector)
        else:
            self.projector = None

    def _forward_mini_block(self, x: torch.tensor, block_num: int):
        # conv -> layer norm over time -> GELU
        conv = self.conv1 if block_num == 1 else self.conv2
        x = conv(x)
        return F.gelu(F.layer_norm(x, (x.shape[-1],)))

    def forward(self, x: torch.tensor):
        residual = self.projector(x) if self.projector is not None else x

        if self.enable_checkpointing:
            for block_num in (1, 2):
                x = checkpoint(self._forward_mini_block, x, block_num, use_reentrant=False)
        else:
            x = self._forward_mini_block(x, block_num=1)
            x = self._forward_mini_block(x, block_num=2)

        return residual + x
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class DilatedConvEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        channels,
        kernel_size,
        stride=1,
        enable_checkpointing=False,
    ):
        """Dilated CNN: a stack of ConvBlocks with exponentially growing dilation.

        See ConvBlock for argument definitions.
        """
        super().__init__()

        self.enable_checkpointing = enable_checkpointing

        blocks = []
        prev_width = in_channels
        last_index = len(channels) - 1
        for i, width in enumerate(channels):
            blocks.append(
                ConvBlock(
                    prev_width,
                    width,
                    kernel_size=kernel_size,
                    stride=stride,
                    dilation=2**i,  # receptive field doubles every block
                    final=(i == last_index),
                    enable_checkpointing=enable_checkpointing,
                )
            )
            prev_width = width
        self.net = nn.ModuleList(blocks)

    def forward(self, x: torch.tensor):
        for block in self.net:
            x = block(x)
        return x
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class TSEncoder2D(nn.Module):
    def __init__(
        self,
        input_dims,
        output_dims,
        hidden_dims=64,
        depth=10,
        kernel_size=3,
        stride=1,
        enable_checkpointing=False,
    ):
        """Temporal encoder built on a dilated-CNN feature extractor.

        Original source implementation:
        TS2Vec Encoder: https://github.com/zhihanyue/ts2vec/blob/main/models/encoder.py

        See ConvBlock function for argument definitions.
        """
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.enable_checkpointing = enable_checkpointing

        # `depth` hidden blocks of equal width, then one projection block to
        # the requested output width.
        channel_plan = [hidden_dims] * depth + [output_dims]
        self.feature_extractor = DilatedConvEncoder(
            input_dims,
            channel_plan,
            kernel_size=kernel_size,
            stride=stride,
            enable_checkpointing=self.enable_checkpointing,
        )

    def forward(self, x: torch.tensor):
        """
        Args:
            x: torch.tensor of shape (1, 1, B * T * D, N) with time (N) along the last axis.
                Note: the additional (1, 1) for the first two axes lets 2D convs perform
                1D convolution operations.
                Note: B=Batch, T=Number of segments, D=Channels.

        Returns:
            Temporal encoded version of the input tensor of shape (1, 1, B * T * D, N)
        """
        return self.feature_extractor(x)
|
barista/models/mlp.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from barista.models.utils import get_activation_function
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MLP(nn.Module):
    def __init__(
        self,
        d_input: int,
        d_out: int,
        layer_list: List = None,
        dropout: float = 0.1,
        bias: bool = True,
        use_first_dropout: bool = True,
        use_final_dropout: bool = False,
        use_final_activation: bool = False,
        activation: str = "linear",
        use_identity_stub: bool = True,
        **kwargs
    ):
        """Configurable multi-layer perceptron.

        Args:
            d_input: Input feature dimension.
            d_out: Output feature dimension.
            layer_list: Optional widths of hidden linear layers. When None the
                network is just (optionally an identity stub and) one final layer.
            dropout: Dropout probability used between layers.
            bias: Whether linear layers carry bias terms.
            use_first_dropout: Apply dropout to the input before any layer.
            use_final_dropout: Apply dropout after the final layer.
            use_final_activation: Apply the activation after the final layer.
            activation: Activation name resolved via get_activation_function.
            use_identity_stub: When there are no hidden layers, append an
                nn.Identity so `self.layers` is non-empty.
            **kwargs: Ignored; accepted for config-driven construction.
        """
        super(MLP, self).__init__()

        self.d_input = d_input
        self.d_out = d_out
        self.layer_list = layer_list
        self.dropout = dropout
        self.use_first_dropout = use_first_dropout
        self.use_final_dropout = use_final_dropout
        self.use_final_activation = use_final_activation
        self.activation_fn = get_activation_function(activation)

        # BUG FIX: dropout used to be built inside forward() as
        # nn.Dropout(self.dropout)(x). A freshly constructed module defaults to
        # training mode, so dropout stayed active even after model.eval().
        # Registering it once here lets train()/eval() toggle it correctly
        # (nn.Dropout has no parameters, so the state_dict is unchanged).
        self.dropout_layer = nn.Dropout(self.dropout)

        current_dim = self.d_input
        self.layers = nn.ModuleList()
        if self.layer_list is not None:
            for dim in self.layer_list:
                self.layers.append(nn.Linear(current_dim, dim, bias=bias))
                current_dim = dim
        elif use_identity_stub:
            self.layers.append(nn.Identity())

        self.final_layer = nn.Linear(current_dim, self.d_out, bias=bias)

    def forward(self, x, *args, **kwargs):
        """Apply (dropout ->) [linear -> activation -> dropout]* -> final linear."""
        if self.use_first_dropout:
            x = self.dropout_layer(x)
        for layer in self.layers:
            x = layer(x)
            x = self.activation_fn(x)
            x = self.dropout_layer(x)
        x = self.final_layer(x)
        if self.use_final_activation:
            x = self.activation_fn(x)
        if self.use_final_dropout:
            x = self.dropout_layer(x)
        return x
|
| 59 |
+
|
| 60 |
+
|
barista/models/model.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from omegaconf import DictConfig
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
from barista.data.metadata import Metadata
|
| 7 |
+
from barista.models.tokenizer import Tokenizer
|
| 8 |
+
from barista.models.transformer import Transformer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Barista(nn.Module):
    """End-to-end model: Tokenizer -> Transformer backbone -> optional task head."""

    def __init__(self, model_config: DictConfig, metadata: Metadata, *args, **kwargs):
        """Build the tokenizer and transformer backbone from the config.

        Args:
            model_config: Config holding `tokenizer` and `backbone` sections;
                `backbone.d_hidden` is reused as the task head's input width.
            metadata: Dataset metadata forwarded to the tokenizer.
        """
        super().__init__(*args, **kwargs)
        self.metadata = metadata

        self.tokenizer = Tokenizer(
            config=model_config.tokenizer,
            metadata=self.metadata,
        )

        self.backbone = Transformer(
            **model_config.backbone,
        )

        self.d_hidden = model_config.backbone.d_hidden

        # Downstream head is attached lazily via create_downstream_head().
        self.head = None

    def create_downstream_head(self, n_chans, output_dim):
        """Attach a task head: learned token weighting plus a linear classifier.

        Args:
            n_chans: Channel count of the downstream subject/session.
            output_dim: Number of classifier outputs.
        """
        # One learned scalar weight per (channel x subsegment) token position.
        self.channel_weights = nn.Linear(
            n_chans * self.tokenizer.num_subsegments,
            1,
            bias=False,
        )
        self.binary_classifier = nn.Linear(
            self.d_hidden, output_dim
        )

    def get_latent_embeddings(self, x: torch.Tensor, subject_sessions: List):
        """Tokenize the input and run the tokens through the backbone.

        Args:
            x: Input signal batch. NOTE(review): forward() indexes x[0], so x is
                presumably a list/tuple of tensors — confirm against Tokenizer.
            subject_sessions: Subject-session identifiers aligned with x.

        Returns:
            Latent token embeddings produced by the transformer backbone.
        """
        # Get tokens
        tokenized_x = self.tokenizer(x, subject_sessions, output_as_list=False)

        # Pass through transformer
        latents = self.backbone(
            x=tokenized_x.tokens,
            seq_lens=tokenized_x.seq_lens,
            position_ids=tokenized_x.position_ids,
        )

        return latents

    def forward(self, x: torch.Tensor, subject_sessions: List):
        """Full pass: latents -> learned token weighting -> classifier logits."""
        latents = self.get_latent_embeddings(x, subject_sessions)

        # Pass through Task head
        batch_size = x[0].shape[0]
        # Collapse all of a sample's tokens into one axis, then reduce that
        # axis with the learned channel weights before classifying.
        latents_reshaped = latents.reshape(batch_size, -1, latents.shape[-1])
        x = self.channel_weights(latents_reshaped.permute(0, 2, 1)).squeeze(dim=-1)
        x = self.binary_classifier(x)

        return x

    def get_task_params(self):
        """Named parameters of the task head (valid after create_downstream_head)."""
        return [*self.channel_weights.named_parameters(), *self.binary_classifier.named_parameters()]

    def get_upstream_params(self):
        """Named parameters of the tokenizer and transformer backbone."""
        return [*self.tokenizer.named_parameters(), *self.backbone.named_parameters()]
|
barista/models/spatial_encoder.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
import einops
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SpatialEncoderMeta:
    def __init__(self, subject_session_spatial_groups=None):
        """Metadata object with subject session information for spatial encoding."""
        self.subject_session_spatial_groups = subject_session_spatial_groups

    @property
    def num_region_info(self):
        """Number of effective spatial components, shared by all subject sessions."""
        group_values = self.subject_session_spatial_groups.values()
        unique_counts = {g.n_effective_components for g in group_values}

        assert len(unique_counts) == 1, (
            "Doesn't support variable number of effective components for different subject_sessions"
        )

        self._num_region_info = unique_counts.pop()
        return self._num_region_info

    @property
    def embedding_table_configs(self):
        """Per-component embedding table config: {i: {num_embeddings, padding_idx}}."""
        group_values = list(self.subject_session_spatial_groups.values())
        configs = {}
        for component in range(self.num_region_info):
            embedding_sizes = {
                g.max_elements_for_component[component] for g in group_values
            }
            padding_values = {
                g.padding_indices[component] for g in group_values
            }

            assert len(embedding_sizes) == 1, (
                "Doesn't support variable number of max components for different subject_sessions, "
                "change to use max of values across the subject if it is not important."
            )
            assert len(padding_values) == 1, (
                "Doesn't support variable number of padding indices for different subject_sessions, "
                "change to use max of values across the subject if it is not important."
            )

            configs[component] = {
                'num_embeddings': embedding_sizes.pop(),
                'padding_idx': padding_values.pop()
            }

        return configs
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class BaseSpatialEncoder(ABC, nn.Module):
    """Abstract class definition for spatial encoding modules.

    Implement this interface to try new spatial encoding approaches in the tokenizer.
    """
    _SUBJ_SESH_QUERY_HASH_STR = "{0}_queryvec"

    def __init__(
        self,
        dim_h: int,
        spatial_encoder_meta: SpatialEncoderMeta,
    ):
        super().__init__()
        self.dim_h = dim_h
        self.spatial_encoder_meta = spatial_encoder_meta

        self._construct_region_encoding_meta()

    def _construct_region_encoding_meta(self):
        """Constructs a hashmap of channel region information -> query vector for spatial encoding."""
        session_groups = self.spatial_encoder_meta.subject_session_spatial_groups
        for subject_session, spatial_groups in session_groups.items():
            n_components = spatial_groups.n_effective_components
            rows = [
                tuple(map(int, entry[:n_components]))
                for entry in spatial_groups.group_components
            ]
            query_vector = self._transform_query_vector(torch.tensor(rows))

            # Stored as a non-persistent buffer: moves with the module's device
            # but is intentionally kept out of the state_dict.
            buffer_name = BaseSpatialEncoder._SUBJ_SESH_QUERY_HASH_STR.format(subject_session)
            self.register_buffer(buffer_name, query_vector, persistent=False)

    def _transform_query_vector(self, query_vector: torch.Tensor):
        # Identity by default; subclasses may cast/reshape the raw query vector.
        return query_vector

    def get_embedding_table_query_vector(self, subject_session: str) -> torch.Tensor:
        buffer_name = BaseSpatialEncoder._SUBJ_SESH_QUERY_HASH_STR.format(subject_session)
        return self._buffers[buffer_name].to(torch.long)

    def update_for_new_sessions(self,
                                new_subject_session_spatial_groups):
        # Rebuild the cached per-session query vectors for the new session set.
        self.spatial_encoder_meta.subject_session_spatial_groups = new_subject_session_spatial_groups
        self._construct_region_encoding_meta()
        return []

    @abstractmethod
    def _encode(self, x: torch.tensor) -> torch.tensor:
        pass

    @abstractmethod
    def _get_position_encoding(
        self, x: torch.tensor, subject_session: str
    ) -> torch.tensor:
        pass

    def forward(
        self,
        x: torch.tensor,
        subject_session: str,
        timepoints: int = 1,
        mask: torch.tensor = None,
    ) -> torch.tensor:
        """
        Args:
            x: torch.tensor of shape (B, T*R, D). Time-space interleaved tokens of dim D.

        Returns:
            A torch.tensor of shape (B, T*R, D) that is the encoding corresponding to
            the input token x.
        """
        session_PE = self._get_position_encoding(x, subject_session)
        assert (
            x.shape[-1] == session_PE.shape[-1]
        ), f"Region dimension mismatch: {x.shape[-1]} vs {session_PE.shape[-1]}."

        # Tile the per-region encoding over batch and time so a given region
        # receives the same spatial encoding at every timestep.
        position_encoding = einops.repeat(
            session_PE, "r d -> b (t r) d", b=x.shape[0], t=timepoints
        )

        if mask is not None:
            position_encoding = position_encoding[:, mask, :]

        assert (
            x.shape == position_encoding.shape
        ), "Output position encoding does not match in shape"
        return position_encoding
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class EmbeddingTable(BaseSpatialEncoder):
    def __init__(
        self,
        dim_h: int,
        spatial_encoder_meta: SpatialEncoderMeta,
        embedding_max_dim: Optional[float] = None,
        embedding_init_scale: float = 1.0
    ):
        """A lookup table of different embeddings for different spatial fields.

        Args:
            dim_h: hidden dimension of the produced spatial encodings.
            spatial_encoder_meta: cached per-session spatial grouping metadata.
            embedding_max_dim: if not None, passed as `max_norm` to nn.Embedding.
            embedding_init_scale: std of the normal init for embedding rows.
        """
        super().__init__(dim_h, spatial_encoder_meta)

        # Remember construction hyper-parameters so that tables rebuilt in
        # `update_for_new_sessions` stay consistent with the originals
        # (previously the rebuilt tables silently dropped max_norm and the
        # newly added rows ignored embedding_init_scale).
        self.embedding_max_dim = embedding_max_dim
        self.embedding_init_scale = embedding_init_scale

        # Create the embeddings.
        self.subcomponent_embedding_info = self.spatial_encoder_meta.embedding_table_configs
        subcomponent_dims = self._get_subcomponent_dims()

        self.subcomponent_embeddings = nn.ModuleDict()
        for (
            subcomponent_ind,
            subcomponent_config,
        ) in self.subcomponent_embedding_info.items():
            subcomponent_dim = subcomponent_dims[subcomponent_ind]

            self.subcomponent_embeddings[str(subcomponent_ind)] = nn.Embedding(
                subcomponent_config["num_embeddings"],
                subcomponent_dim,
                padding_idx=subcomponent_config["padding_idx"],
                max_norm=embedding_max_dim,
            )

            self.init_weights_for_embeddings(
                self.subcomponent_embeddings[str(subcomponent_ind)],
                embedding_init_scale
            )

    @abstractmethod
    def _get_subcomponent_dims(self):
        """Maps subcomponent index -> embedding dimension; implemented by subclasses."""
        raise NotImplementedError

    def update_for_new_sessions(self, new_subject_session_spatial_groups):
        """Grow the embedding tables to cover new subject sessions.

        Existing rows are kept; only the additional rows required by the new
        configs are created. Returns the names of all (re)created parameters so
        a caller can, e.g., restrict finetuning to them.
        """
        new_params = super().update_for_new_sessions(new_subject_session_spatial_groups)

        subcomponent_embedding_info = self.spatial_encoder_meta.embedding_table_configs
        for subcomponent_ind, subcomponent_config in subcomponent_embedding_info.items():
            prev_embeddings = self.subcomponent_embeddings[str(subcomponent_ind)]
            n_rows, subcomponent_dim = prev_embeddings.weight.shape

            if subcomponent_config['num_embeddings'] == n_rows:
                # no need to add any new embedding
                continue

            new_embeddings = torch.empty(
                subcomponent_config['num_embeddings'] - n_rows,
                subcomponent_dim,
                device=prev_embeddings.weight.device
            )
            # FIX: match the init distribution used at construction time
            # (previously always std=1.0, ignoring embedding_init_scale).
            nn.init.normal_(new_embeddings, std=self.embedding_init_scale)

            new_data = torch.cat((prev_embeddings.weight.data, new_embeddings))

            # FIX: rebuild with the same max_norm constraint as the original
            # table (previously dropped on rebuild), and re-zero the padding
            # row in case padding_idx falls within the freshly added rows.
            rebuilt = nn.Embedding(
                subcomponent_config["num_embeddings"],
                subcomponent_dim,
                padding_idx=subcomponent_config["padding_idx"],
                max_norm=self.embedding_max_dim,
            )
            rebuilt.weight.data = new_data
            rebuilt._fill_padding_idx_with_zero()
            self.subcomponent_embeddings[str(subcomponent_ind)] = rebuilt

        new_params.extend([n for n, _ in self.named_parameters()])

        return new_params

    def init_weights_for_embeddings(self, embedding_table: nn.Embedding, embedding_init_scale: float = 1.0):
        """Normal(0, embedding_init_scale) init with the padding row zeroed."""
        nn.init.normal_(embedding_table.weight, std=embedding_init_scale)
        embedding_table._fill_padding_idx_with_zero()

    def _transform_query_vector(self, query_vector: torch.Tensor):
        # Store the query transposed, i.e. (n_components, n_regions), so that
        # `_encode` can iterate components along dim 0.
        return query_vector.to(torch.float).T

    def _get_position_encoding(
        self, _: torch.tensor, subject_session: str
    ) -> torch.tensor:
        """Returns the encoding vector based on a subject session query."""
        session_region_query = self.get_embedding_table_query_vector(
            subject_session
        )
        single_session_PE = self._encode(session_region_query)
        return single_session_PE
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class EmbeddingTablePool(EmbeddingTable):
    """EmbeddingTable variant where every spatial sub-component embeds into the
    full hidden dimension and the per-component encodings are summed."""

    def _get_subcomponent_dims(self):
        # Full width for every component because `_encode` sums (does not
        # concatenate) the per-component embeddings.
        return {k: self.dim_h for k in self.subcomponent_embedding_info.keys()}

    def _encode(self, x: torch.tensor) -> torch.tensor:
        """
        Args:
            x: integer index tensor whose dim 0 enumerates the spatial
               sub-components (one row of embedding-table indices per
               component, as produced by `_transform_query_vector`).

        Returns:
            A torch.tensor of shape (x.shape[1], dim_h): the encoding per
            region with the embeddings of all sub-components summed together
            (e.g., x,y,z LPI coordinates).
        """
        # BUG FIX: use `x.device` instead of `x.get_device()` — the latter
        # returns -1 for CPU tensors, which torch.zeros rejects as a device.
        PE = torch.zeros((x.shape[0], x.shape[1], self.dim_h), device=x.device)
        for subcomponent_ind in range(x.shape[0]):
            subcomponent_x = x[subcomponent_ind, ...]
            PE[subcomponent_ind, ...] = self.subcomponent_embeddings[
                str(subcomponent_ind)
            ](subcomponent_x)
        return torch.sum(PE, axis=0)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def create_spatial_encoder(
    dim_h: int,
    subject_session_spatial_groups=None,
    embedding_max_dim=None,
    embedding_init_scale=1.0,
) -> BaseSpatialEncoder:
    """Creates the spatial encoder and the cached spatial encoding information needed during forward passes."""
    meta = SpatialEncoderMeta(subject_session_spatial_groups)
    return EmbeddingTablePool(
        dim_h,
        meta,
        embedding_max_dim,
        embedding_init_scale,
    )
|
barista/models/tokenized_batched_item.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import einops
|
| 3 |
+
import torch
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclasses.dataclass
class TokenizedBatchedItem:
    """
    tokens: (B_i, N, D)
    position_ids: (B_i, N)
    temporal_group_ids: (B_i, N)
    spatial_group_ids: (B_i, N)
    seq_lens: List[int]
    spatial_embeddings: (B_i, N, D)


    NOTE: Assumption: Either seq_lens length is one, or B_i is one, i.e. we either
    have a batched tensor or a list of single tensors.
    """
    tokens: torch.Tensor
    position_ids: torch.Tensor
    seq_lens: List[int]
    spatial_embeddings: Optional[torch.Tensor]
    temporal_group_ids: Optional[torch.Tensor]
    spatial_group_ids: Optional[torch.Tensor]
    subject_sessions: List[str]

    @classmethod
    def get_as_one_sequence(
        cls, tokenized_items_list: List["TokenizedBatchedItem"]
    ) -> "TokenizedBatchedItem":
        """
        Generate a long concatenated sequence from a list of TokenizedBatchedItem
        """
        (
            seq_lens,
            tokens_list,
            position_ids,
            temporal_group_ids,
            spatial_group_ids,
            spatial_embeddings_list,
            subject_sessions_list,
        ) = ([], [], [], [], [], [], [])
        for item in tokenized_items_list:
            batch_size = item.tokens.shape[0]

            # Flatten each (B, N, D) batch into a single (B*N, D) run.
            tokens_list.append(einops.rearrange(item.tokens, "b n d -> (b n) d"))
            if item.spatial_embeddings is not None:
                spatial_embeddings_list.append(
                    einops.rearrange(item.spatial_embeddings, "b n d -> (b n) d")
                )

            if item.position_ids is not None:
                position_ids.append(item.position_ids.flatten())

            if item.temporal_group_ids is not None:
                temporal_group_ids.append(item.temporal_group_ids.flatten())

            if item.spatial_group_ids is not None:
                spatial_group_ids.append(item.spatial_group_ids.flatten())

            # One seq_len / subject_session entry per datapoint in the batch.
            seq_lens.extend(item.seq_lens * batch_size)
            subject_sessions_list.extend(item.subject_sessions * batch_size)

        tokens = torch.cat(tokens_list).unsqueeze(dim=0)
        assert tokens.shape[:2] == (1, sum(seq_lens))

        if len(spatial_embeddings_list) > 0:
            spatial_embeddings = torch.cat(spatial_embeddings_list).unsqueeze(dim=0)
            assert spatial_embeddings.shape[:2] == (1, sum(seq_lens))
        else:
            spatial_embeddings = None

        if len(position_ids) > 0:
            position_ids = torch.cat(position_ids).unsqueeze(dim=0)
            assert position_ids.shape == (1, sum(seq_lens))
        else:
            position_ids = None

        if len(temporal_group_ids) > 0:
            temporal_group_ids = torch.cat(temporal_group_ids).unsqueeze(dim=0)
            assert temporal_group_ids.shape == (1, sum(seq_lens))
        else:
            temporal_group_ids = None

        if len(spatial_group_ids) > 0:
            spatial_group_ids = torch.cat(spatial_group_ids).unsqueeze(dim=0)
            assert spatial_group_ids.shape == (1, sum(seq_lens))
        else:
            spatial_group_ids = None

        return TokenizedBatchedItem(
            tokens=tokens,
            position_ids=position_ids,
            temporal_group_ids=temporal_group_ids,
            spatial_group_ids=spatial_group_ids,
            seq_lens=seq_lens,
            spatial_embeddings=spatial_embeddings,
            subject_sessions=subject_sessions_list
        )

    def get_as_list_items(self) -> List["TokenizedBatchedItem"]:
        """
        Note: this does not exactly reverse `get_as_one_sequence` because it does not batch items with the
        same seq length together
        """
        tokenized_items_list = []
        cur_total_len = 0
        for seq_ind, seq_len in enumerate(self.seq_lens):
            window = slice(cur_total_len, cur_total_len + seq_len)

            def _sliced(t):
                # BUG FIX: Optional fields must stay None. The old code sliced
                # temporal_group_ids / spatial_group_ids without the None guard
                # that position_ids / spatial_embeddings already had, raising a
                # TypeError whenever those fields were None.
                return None if t is None else t[:, window]

            tokens = TokenizedBatchedItem(
                tokens=self.tokens[:, window],
                position_ids=_sliced(self.position_ids),
                temporal_group_ids=_sliced(self.temporal_group_ids),
                spatial_group_ids=_sliced(self.spatial_group_ids),
                spatial_embeddings=_sliced(self.spatial_embeddings),
                seq_lens=[seq_len],
                # BUG FIX: wrap in a list so the field matches List[str]; a bare
                # str would be iterated character-by-character if the item were
                # fed back into `get_as_one_sequence` (which uses list.extend).
                subject_sessions=[self.subject_sessions[seq_ind]]
            )
            cur_total_len += seq_len

            tokenized_items_list.append(tokens)

        return tokenized_items_list
|
barista/models/tokenizer.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import einops
|
| 2 |
+
from omegaconf import DictConfig
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from typing import Dict, List, Union
|
| 6 |
+
|
| 7 |
+
import barista.models.spatial_encoder as spe
|
| 8 |
+
from barista.data.metadata import Metadata
|
| 9 |
+
from barista.models.mlp import MLP
|
| 10 |
+
from barista.models.tokenized_batched_item import TokenizedBatchedItem
|
| 11 |
+
from barista.models.TSEncoder2D import TSEncoder2D
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Tokenizer(nn.Module):
    """Turns raw neural recordings of shape (batch, time, channels) into
    time-space interleaved tokens, optionally augmented with a learned
    per-region spatial encoding."""

    def __init__(
        self,
        config: DictConfig,
        metadata: Metadata,
    ):
        super().__init__()

        self.metadata = metadata
        self.config = config

        self.subjects = metadata.get_subjects()

        # Number of sliding temporal windows per segment:
        # floor((total_samples - window_len) / step) + 1.
        self.num_subsegments = int(
            (
                self.config.samp_frequency * self.config.num_seconds
                - self.config.temporal_subsegment_len
            )
            // (self.config.temporal_subsegment_step)
            + 1
        )

        self.dim_h = self.config.d_hidden

        self._build_temporal_encoder()

        self._build_temporal_pooler()

        self._build_spatial_encoder()

    def _build_temporal_encoder(self):
        """Builds the per-subsegment temporal feature encoder (TSEncoder2D)."""
        # Each subsegment is processed as a single-channel signal.
        self.config.temporal_encoder.input_dims = 1
        self.config.temporal_encoder.output_dims = 1
        self.temporal_encoder = TSEncoder2D(**self.config.temporal_encoder)

    def _build_temporal_pooler(self):
        """Builds the MLP that maps a subsegment's time axis to the hidden dim."""
        self.temporal_pooler = MLP(
            d_input=self.config.temporal_subsegment_len,
            d_out=self.dim_h,
            dropout=0.0,
            bias=False,
        )

    def _build_spatial_encoder(self):
        """Collects per-session spatial groupings and builds the spatial encoder."""
        self.subject_session_spatial_groups = {}
        for sub_sesh in self.metadata.get_subject_session_d_input().keys():
            spatial_grouping = self.metadata.get_spatial_grouping(
                subject_session=sub_sesh, name=self.config.spatial_grouping
            )
            self.subject_session_spatial_groups[sub_sesh] = spatial_grouping

        self.spatial_encoder = spe.create_spatial_encoder(
            dim_h=self.dim_h,
            subject_session_spatial_groups=self.subject_session_spatial_groups,
            embedding_max_dim=self.config.get('embedding_max_dim', None),
            embedding_init_scale=self.config.get('embedding_init_scale', 1.0),
        )

    def update_for_new_sessions(
        self,
        new_session_d_input_dict: Dict[str, int],
        new_metadata: Metadata,
    ) -> List:
        """Re-points the tokenizer at a new set of subject sessions.

        Rebuilds the spatial grouping cache from `new_metadata` and, when
        spatial encoding is enabled, grows the spatial encoder's tables.

        Returns:
            List of parameter names (prefixed with "spatial_encoder.") that
            the spatial encoder reported as new/recreated.
        """
        self.subject_session_spatial_groups = {}
        for sub_sesh in new_session_d_input_dict.keys():
            spatial_grouping = new_metadata.get_spatial_grouping(
                subject_session=sub_sesh, name=self.config.spatial_grouping
            )
            self.subject_session_spatial_groups[sub_sesh] = spatial_grouping

        self.metadata = new_metadata


        new_params = []
        if self.config.add_spatial_encoding:
            new_se_params = self.spatial_encoder.update_for_new_sessions(
                new_subject_session_spatial_groups=self.subject_session_spatial_groups
            )

            new_params.extend([f"spatial_encoder.{n}" for n in new_se_params])

        return new_params

    def _tokenize_for_batch_tensor(
        self,
        x: Union[torch.Tensor, List],
        subject_session: str,
        add_spatial_encoding_to_tokens: bool = True,
    ) -> torch.tensor:
        """
        Args:
            x: Input tensor of shape (B, N, D) or a list of tensors each of shape (N_i, D_i)
                B: Batch size
                N: Time points
                R: Channel dim

        Returns:
            Tokenized version of the same data as a TokenizedBatchedItem object.
        """
        batch_size, num_timepoints, num_channels = x.shape

        x = einops.rearrange(x, "b n d -> b d n")

        # NOTE that unfold doesn't copy the memory, so if step is less than size (sliding window)
        # and any of the shared elements are changed, all occurrences of that element in patches will change
        x = x.unfold(
            dimension=-1,
            size=self.config.temporal_subsegment_len,
            step=self.config.temporal_subsegment_step,
        )  # (B D num_subsegments subseg_len)

        collapsed_x = einops.rearrange(
            x, "b d t n -> (b t d) n"
        )  # (B * T * D, N)

        transposed_tokens = einops.rearrange(
            collapsed_x, "btd n -> 1 1 btd n"
        )  # (1, 1, B * T * D, N)

        collapsed_tokens = self.temporal_encoder(transposed_tokens)
        collapsed_tokens = collapsed_tokens.squeeze()  # (B * T * D, N)

        # "Time" dimension to hidden dimension. Using a fully connected layer here.
        collapsed_tokens = self.temporal_pooler(
            collapsed_tokens
        )  # (B * T * D, N) -> (B * T * D, HID_D)

        collapsed_tokens_full = collapsed_tokens

        # Create the time-space interleaved tokens.
        tokens = einops.rearrange(
            collapsed_tokens_full,
            "(b t d) dh -> b (t d) dh",
            b=batch_size,
            t=self.num_subsegments,
        )

        seqlen_timepoints = self.num_subsegments

        if self.config.add_spatial_encoding:
            spatial_encoding = self.spatial_encoder(
                tokens,
                subject_session=subject_session,
                timepoints=seqlen_timepoints,
            )

            # Make sure regions at different timestamps have same spatial encoding
            assert (
                seqlen_timepoints == 1
                or spatial_encoding[0, 0, 0] == spatial_encoding[0, num_channels, 0]
            )

            if add_spatial_encoding_to_tokens:
                tokens = tokens + spatial_encoding

        else:  # not self.config.add_spatial_encoding
            spatial_encoding = None

        # Token order is (t d): all channels of timestep 0, then timestep 1, ...
        # so the temporal group id repeats for each channel within a timestep.
        temporal_group_ids = torch.arange(seqlen_timepoints, device=x.device)
        temporal_group_ids = einops.repeat(
            temporal_group_ids,
            "t -> b (t d)",
            b=batch_size,
            d=num_channels
        )
        # Make sure different regions at same timestamps have same positional encoding
        assert seqlen_timepoints == 1 or (
            temporal_group_ids[0, 0] == temporal_group_ids[0, 1]
            and temporal_group_ids[0, 0]
            != temporal_group_ids[
                0, num_channels
            ]
        )

        position_ids = temporal_group_ids.clone()

        return TokenizedBatchedItem(
            tokens=tokens,
            position_ids=position_ids,
            spatial_group_ids=None,
            temporal_group_ids=temporal_group_ids,
            seq_lens=[tokens.shape[1]],
            spatial_embeddings=spatial_encoding,
            subject_sessions=[subject_session]
        )

    def forward(
        self,
        x: List,
        subject_sessions: List,
        output_as_list: bool = False,
        add_spatial_encoding_to_tokens: bool = True,
    ) -> Union[TokenizedBatchedItem, List[TokenizedBatchedItem]]:
        """
        Args:
            x: A list of tensors each of shape (B_i, N_i, D_i)
                B: Batch size
                N: Time points
                D: Channel dim
            subject_sessions: list of strings corresponding to subject_session identifier
            output_as_list: if True, will output a list of TokenizedBatchedItem, each correspond to one subject,
                if False, will merge all as a long sequence
            add_spatial_encoding_to_tokens: bool. Adds spatial encoding to tokens

        Returns:
            TokenizedBatchItem if output_as_list is False, else list of TokenizedBatchItem objects.
        """
        passed_datapoints = 0
        tokenized_items_list = []

        for x_item in x:
            # NOTE(review): subject_sessions is indexed by the running datapoint
            # count, which assumes one subject_session entry per datapoint
            # (not per list element) — verify against the caller.
            tokenized_item = self._tokenize_for_batch_tensor(
                x_item,
                subject_sessions[passed_datapoints],
                add_spatial_encoding_to_tokens=add_spatial_encoding_to_tokens,
            )

            tokenized_items_list.append(tokenized_item)
            passed_datapoints += x_item.shape[0]

        if output_as_list:
            return tokenized_items_list

        return TokenizedBatchedItem.get_as_one_sequence(tokenized_items_list)
|
barista/models/transformer.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import xformers.ops as xops
|
| 5 |
+
from einops import rearrange, repeat
|
| 6 |
+
|
| 7 |
+
from barista.models.utils import get_activation_function
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class RotaryEmbedding(nn.Module):
    """Precomputes rotary (RoPE) cos/sin tables and serves them per position id."""

    def __init__(self, d_head, base=10000, max_position=1024):
        super().__init__()

        self.d_head = d_head
        self.max_position = max_position

        # Standard RoPE inverse frequencies: base^(-2i/d_head) for i in [0, d_head/2).
        exponents = torch.arange(0, self.d_head, 2, dtype=torch.float32) / self.d_head
        self.register_buffer("inv_freq", 1 / (base ** exponents))
        self.build_cache()

    def build_cache(self):
        """Precompute the cos/sin tables for all positions up to max_position."""
        positions = torch.arange(
            self.max_position,
            dtype=self.inv_freq.dtype,
        )
        angles = torch.outer(positions, self.inv_freq)  # (self.max_position, d//2)

        # Duplicate so the table spans the full head dimension.
        table = torch.cat((angles, angles), dim=-1)  # (self.max_position, d)
        dtype = torch.get_default_dtype()
        self.register_buffer(
            "cos_cached", table.cos().to(dtype), persistent=False
        )  # (self.max_position, d)
        self.register_buffer(
            "sin_cached", table.sin().to(dtype), persistent=False
        )  # (self.max_position, d)

    def forward(self, position_ids):
        """Returns the rotation matrices"""
        # Gather one row per position id and insert a broadcast axis for the
        # heads -> [bs, seq_len, 1, head_dim].
        cos = self.cos_cached[position_ids].unsqueeze(2)
        sin = self.sin_cached[position_ids].unsqueeze(2)
        return cos, sin
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    lower, upper = x[..., :half], x[..., half:]
    return torch.cat((-upper, lower), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """
    Applies the rotation matrices on query and key tensors
    q: B x seq_len x num_head x head_dim
    k: B x seq_len x num_head x head_dim
    """
    # `.to(q)` / `.to(k)` match the cached tables' dtype/device to the inputs.
    q_embed = q * cos.to(q) + rotate_half(q) * sin.to(q)  # [bs, seq_len, num_heads, head_dim]
    k_embed = k * cos.to(k) + rotate_half(k) * sin.to(k)  # [bs, seq_len, num_heads, head_dim]
    return q_embed, k_embed
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction) with a learned gain."""

    def __init__(self, d_hidden, eps=1e-6):
        """
        https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/llama/modeling_llama.py#L74
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_hidden))
        self.variance_epsilon = eps

    def forward(self, x):
        orig_dtype = x.dtype
        # Variance computed in float32 for numerical stability, then the
        # result is cast back to the input dtype.
        mean_square = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return (self.weight * normalized).to(orig_dtype)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class SelfAttention(nn.Module):
    """Multi-head self-attention backed by xformers' memory-efficient kernel.

    Variable-length sequences are packed into a single batch row (B == 1) and
    separated with a BlockDiagonalMask built from `seq_lens`.
    """

    def __init__(
        self, d_hidden, num_heads=8, dropout=0.1, **kwargs
    ):
        super().__init__()
        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.d_head = self.d_hidden // self.num_heads
        self.dropout = nn.Dropout(dropout)

        assert (
            self.d_hidden % self.num_heads == 0
        ), f"Number of attention heads: {self.num_heads} must divide embedding dimension: {self.d_hidden}."

        # Fused q/k/v projection; split with .chunk in get_qkv.
        self.qkv_proj = nn.Linear(self.d_hidden, 3 * self.d_hidden, bias=True)
        self.o_proj = nn.Linear(self.d_hidden, self.d_hidden, bias=True)


    def get_qkv(self, x):
        """Project x (B, N, d_hidden) into per-head q/k/v of shape (B, N, H, d_head)."""
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        q = rearrange(q, "b n (h d_h) -> b n h d_h", h=self.num_heads)
        k = rearrange(k, "b n (h d_h) -> b n h d_h", h=self.num_heads)
        v = rearrange(v, "b n (h d_h) -> b n h d_h", h=self.num_heads)
        return q, k, v

    def get_attention_out(self, q, k, v, seq_lens=None):
        """Runs attention and the output projection.

        Returns (output, attention_weights); the weights are always None since
        the memory-efficient kernel never materializes them.
        """
        attention_weights = None

        attention_out = self.get_memory_efficient_attention(q, k, v, seq_lens)

        attention_out = self.dropout(attention_out)
        attention_out = rearrange(attention_out, "b n h d_h -> b n (h d_h)")
        out = self.o_proj(attention_out)
        return out, attention_weights

    def get_memory_efficient_attention(self, q, k, v, seq_lens=None):
        """xformers memory-efficient attention, block-diagonal for packed sequences."""
        if seq_lens is not None and q.shape[0] == 1:
            # BUG FIX: only build/move the mask when packed sequences are used.
            # Previously `attn_bias.to(q.device)` ran unconditionally and raised
            # AttributeError whenever attn_bias was None (regular batched input).
            attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens).to(q.device)
        else:
            attn_bias = None

        assert q.shape[-2:] == (
            self.num_heads,
            self.d_head,
        )
        attention_out = xops.memory_efficient_attention(
            q,
            k,
            v,
            p=0,
            attn_bias=attn_bias,
        )
        return attention_out


    def forward(self, x, seq_lens=None, **kwargs):
        """
        Args:
            x: (B, N, d_hidden) token tensor; B == 1 signals packed sequences
               and then `seq_lens` is required.
            seq_lens: per-sequence lengths for packed input, or None.

        Returns:
            (output of shape (B, N, d_hidden), None) — weights are never materialized.
        """
        if seq_lens is None and x.shape[0] == 1:
            raise ValueError(
                f"'seq_lens' for memory efficient attention with variable length sequences (x.shape[0] == 1) must be non-None."
            )
        q, k, v = self.get_qkv(x)
        out, att_weights = self.get_attention_out(q, k, v, seq_lens)
        return out, att_weights
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class RotarySelfAttention(SelfAttention):
    """Self-attention with rotary position embeddings (RoPE) applied to q/k."""

    def __init__(
        self,
        d_hidden,
        num_heads=8,
        max_position=1024,
        dropout=0.1,
        **kwargs,
    ):
        super().__init__(
            d_hidden=d_hidden,
            num_heads=num_heads,
            dropout=dropout,
        )
        self.max_position = max_position
        self.rotary_emb = RotaryEmbedding(self.d_head, max_position=self.max_position)

    def forward(self, x, position_ids=None, seq_lens=None):
        packed = x.shape[0] == 1
        if seq_lens is None and packed:
            raise ValueError(
                "'seq_lens' for memory efficient attention with variable length sequences (x.shape[0] == 1) must be non-None."
            )

        # Default positions: restart at 0 for every packed sequence, or a
        # plain 0..n-1 ramp repeated across the batch dimension.
        if position_ids is None:
            if packed:
                per_seq = [torch.arange(n_, device=x.device, dtype=int) for n_ in seq_lens]
                position_ids = torch.cat(per_seq).unsqueeze(dim=0)
            else:
                position_ids = repeat(
                    torch.arange(x.shape[1], device=x.device, dtype=int), "n -> b n", b=x.shape[0])

        q, k, v = self.get_qkv(x)

        # Rotate queries/keys by position; keep values in the rotated dtype.
        cos, sin = self.rotary_emb(position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        v = v.to(q)

        return self.get_attention_out(q, k, v, seq_lens)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class GatedTransformerMLP(nn.Module):
    """SwiGLU-style gated feed-forward block: down(act(gate(x)) * up(x))."""

    def __init__(self, d_hidden, mlp_ratio=4, activation="silu", dropout=0.1):
        super().__init__()
        d_feedforward = mlp_ratio * d_hidden
        self.gate_proj = nn.Linear(d_hidden, d_feedforward, bias=True)
        self.down_proj = nn.Linear(d_feedforward, d_hidden, bias=True)
        self.up_proj = nn.Linear(d_hidden, d_feedforward, bias=True)
        self.activation_fn = get_activation_function(activation)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # Gate branch is activated, then multiplied elementwise with the
        # linear up-projection branch before projecting back down.
        gated = self.activation_fn(self.gate_proj(x)) * self.up_proj(x)
        hidden = self.dropout1(gated)
        return self.dropout2(self.down_proj(hidden))
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer block: rotary self-attention + gated MLP, residual each."""

    def __init__(
        self,
        d_hidden,
        mlp_ratio=4,
        norm="rmsnorm",
        norm_eps=1e-6,
        activation="silu",
        num_heads=8,
        dropout=0.1,
        **attention_module_kwargs,
    ):
        super().__init__()
        self.d_hidden = d_hidden

        attention_cls = RotarySelfAttention

        self.attention = attention_cls(
            d_hidden=d_hidden,
            num_heads=num_heads,
            dropout=dropout,
            **attention_module_kwargs,
        )
        self.mlp = GatedTransformerMLP(
            d_hidden=d_hidden,
            mlp_ratio=mlp_ratio,
            activation=activation,
            dropout=dropout,
        )
        self.dropout = nn.Dropout(dropout)

        # Two independent pre-norms: one per sub-block.
        norm_kind = norm.lower()
        if norm_kind == "rmsnorm":
            norm_cls = RMSNorm
        elif norm_kind == "layernorm":
            norm_cls = nn.LayerNorm
        else:
            raise NotImplementedError()
        self.norm1 = norm_cls(d_hidden, eps=norm_eps)
        self.norm2 = norm_cls(d_hidden, eps=norm_eps)

    def forward(self, x, position_ids=None, seq_lens=None):
        # Attention sub-block (pre-norm residual).
        attn_out, att_weights = self.attention(
            x=self.norm1(x),
            position_ids=position_ids,
            seq_lens=seq_lens,
        )
        x = x + self.dropout(attn_out)

        # Feed-forward sub-block (pre-norm residual).
        x = x + self.mlp(self.norm2(x))

        return x, att_weights
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class Transformer(nn.Module):
    """Stack of pre-norm TransformerEncoderLayer blocks with a final norm.

    Args mirror TransformerEncoderLayer; extra keyword args are forwarded to
    the attention module of every layer.
    """

    def __init__(
        self,
        num_layers,
        d_hidden,
        mlp_ratio=4,
        norm="rmsnorm",
        norm_eps=1e-6,
        activation="gelu",
        num_heads=8,
        dropout=0.1,
        **attention_module_kwargs,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_hidden=d_hidden,
                    mlp_ratio=mlp_ratio,
                    norm=norm,
                    norm_eps=norm_eps,
                    activation=activation,
                    num_heads=num_heads,
                    dropout=dropout,
                    **attention_module_kwargs,
                )
                for _ in range(num_layers)
            ]
        )

        if norm.lower() == "rmsnorm":
            self.norm = RMSNorm(d_hidden, eps=norm_eps)
        elif norm.lower() == "layernorm":
            self.norm = nn.LayerNorm(d_hidden, eps=norm_eps)
        else:
            # BUGFIX: an unknown `norm` previously left `self.norm` undefined,
            # deferring the failure to an AttributeError in forward(); fail
            # fast here, consistent with TransformerEncoderLayer.
            raise NotImplementedError()

    def forward(self, x, position_ids=None, seq_lens=None, **kwargs):
        # Per-layer attention weights are collected for potential inspection,
        # though only the hidden states are returned.
        weights_list = []
        for layer in self.layers:
            x, weights = layer(
                x=x,
                position_ids=position_ids,
                seq_lens=seq_lens,
            )
            weights_list.append(weights)

        if self.norm:
            x = self.norm(x)

        return x
|
barista/models/utils.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_activation_function(activation_str):
    """Map an activation name (case-insensitive) to a callable.

    Supported: "relu", "gelu", "silu", and "linear" (identity).

    Raises:
        ValueError: for unsupported names. Previously unknown names —
        including "silu", the default used by the transformer MLP —
        silently returned None, which crashed later with a TypeError
        when the activation was called.
    """
    name = activation_str.lower()
    if name == "relu":
        return nn.ReLU()
    elif name == "linear":
        return lambda x: x
    elif name == "gelu":
        return nn.GELU()
    elif name == "silu":
        return nn.SiLU()
    raise ValueError(f"Unsupported activation: {activation_str}")
|
| 15 |
+
|
| 16 |
+
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Also exports PL_GLOBAL_SEED so Lightning-style worker processes
    can pick up the same seed.
    """
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    print(f"Random seed set as {seed}")
|
barista/prepare_segments.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Script to preprocess and prepare data segments.
|
| 2 |
+
|
| 3 |
+
Example usage:
|
| 4 |
+
python prepare_segments.py --config config/braintreebank_config.yaml --experiment sentence_onset
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
from omegaconf import OmegaConf
|
| 9 |
+
|
| 10 |
+
from barista.data.braintreebank_wrapper import BrainTreebankWrapper
|
| 11 |
+
|
| 12 |
+
if __name__ == "__main__":
|
| 13 |
+
parser = argparse.ArgumentParser()
|
| 14 |
+
|
| 15 |
+
parser.add_argument("--config", required=True, type=str, help="path to config for segmentation")
|
| 16 |
+
parser.add_argument("--experiment", required=True, type=str, help="experiment to segment data for")
|
| 17 |
+
|
| 18 |
+
args = parser.parse_args()
|
| 19 |
+
|
| 20 |
+
print(f"Loading config: {args.config}")
|
| 21 |
+
config = OmegaConf.load(args.config)
|
| 22 |
+
|
| 23 |
+
## Instantiating BrainTreebankWrapper will be default handle all preprocessing.
|
| 24 |
+
## If preprocessing is complete, then the dataset will be ready to use for training.
|
| 25 |
+
config.experiment = args.experiment
|
| 26 |
+
print(f"Segmenting data for experiment {args.experiment}")
|
| 27 |
+
braintreebank_wrapper = BrainTreebankWrapper(config, only_segment_generation=True)
|
barista/train.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import copy
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from omegaconf import OmegaConf
|
| 7 |
+
from sklearn.metrics import roc_auc_score
|
| 8 |
+
from torch import nn, optim
|
| 9 |
+
|
| 10 |
+
from barista.data.braintreebank_dataset import BrainTreebankDataset
|
| 11 |
+
from barista.models.model import Barista
|
| 12 |
+
from barista.models.utils import seed_everything
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def parse_args():
    """Parse command line arguments for the fine-tuning entry point."""
    parser = argparse.ArgumentParser(
        description="Fine-tune Barista model on BrainTreebank dataset"
    )

    # Three config-file paths, all sharing the same option shape.
    config_options = [
        ("--dataset_config", "barista/config/braintreebank.yaml", "Path to dataset configuration file"),
        ("--train_config", "barista/config/train.yaml", "Path to training configuration file"),
        ("--model_config", "barista/config/model.yaml", "Path to model configuration file"),
    ]
    for flag, default, help_text in config_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)

    parser.add_argument(
        "--override",
        type=str,
        nargs="+",
        default=[],
        help="Override config parameters (e.g., --override epochs=50 optimization.finetune_lr=1e-4)",
    )
    return parser.parse_args()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def load_configs(args):
    """Load dataset/train/model configuration files from the CLI paths."""
    dataset_config = OmegaConf.load(args.dataset_config)
    train_config = OmegaConf.load(args.train_config)
    model_config = OmegaConf.load(args.model_config)

    # Fine-tuning is defined for a single session; more than one is ambiguous.
    n_ft_sessions = len(dataset_config.finetune_sessions)
    assert n_ft_sessions == 1, "Specify one session for finetuning"

    return dataset_config, train_config, model_config
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def apply_overrides(config_dict, overrides):
    """Apply command-line overrides to configs using dot notation.

    Args:
        config_dict: mapping of config-name -> OmegaConf config
            (e.g. {"dataset": ..., "train": ..., "model": ...}).
        overrides: list of "dotted.key=value" strings.

    Returns:
        Mapping of config-name -> (possibly merged) config. Overrides are
        merged into every config whose top-level keys intersect the
        override's top-level keys; the "train" config additionally receives
        all overrides.

    Raises:
        ValueError: if an override is not of the form key=value.
    """
    if not overrides:
        return config_dict

    override_dict = {}
    for override in overrides:
        if "=" not in override:
            raise ValueError(
                f"Invalid override format: {override}. Expected format: key=value"
            )

        key, value = override.split("=", 1)
        value = _coerce_override_value(value)

        # Build the nested dict addressed by the dotted key.
        keys = key.split(".")
        current = override_dict
        for k in keys[:-1]:
            if k not in current:
                current[k] = {}
            current = current[k]
        current[keys[-1]] = value

    # Convert override dict to OmegaConf and merge
    override_conf = OmegaConf.create(override_dict)

    # Determine which config to merge based on keys
    merged_configs = {}
    for config_name, config in config_dict.items():
        config_keys = set(OmegaConf.to_container(config).keys())
        override_keys = set(override_dict.keys())

        if config_keys.intersection(override_keys):
            merged_configs[config_name] = OmegaConf.merge(config, override_conf)
        else:
            merged_configs[config_name] = config

    # The train config always absorbs every override.
    if merged_configs.get("train") is not None:
        merged_configs["train"] = OmegaConf.merge(
            merged_configs["train"], override_conf
        )

    return merged_configs


def _coerce_override_value(value):
    """Best-effort conversion of an override string to int/float/list/bool.

    BUGFIX: the old code gated numeric parsing on ``value.isnumeric()``,
    which is False for any string containing '.', '-', '+' or 'e' — so
    overrides such as ``finetune_lr=1e-4`` (the documented example) silently
    stayed strings and the float branch was unreachable. Try int, then float.
    """
    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        pass
    if value.startswith("[") or value in ("True", "False"):  # list, bool
        # NOTE(security): eval on CLI-supplied text; acceptable for a local
        # training script, but never expose this to untrusted input.
        return eval(value)
    return value
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def setup_dataloaders(dataset_config, train_config):
    """Initialize dataset and create train/val/test dataloaders."""
    dataset = BrainTreebankDataset(dataset_config)

    loaders = {
        split: dataset.get_dataloader(split, train_config)
        for split in ("train", "val", "test")
    }

    print(f"Train: {len(loaders['train'].dataset.metadata)} samples")
    print(f"Val: {len(loaders['val'].dataset.metadata)} samples")
    print(f"Test: {len(loaders['test'].dataset.metadata)} samples")

    # Sanity check: the three splits must be disjoint at the segment level.
    dataset.check_no_common_segment(loaders["train"], loaders["val"], loaders["test"])

    return dataset, loaders["train"], loaders["val"], loaders["test"]
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_optimizer(model, finetune_lr=1e-4, new_param_lr=1e-3):
    """Build AdamW with two LR groups: pretrained (upstream) vs new task params."""

    def _trainable(named_params):
        # Keep only parameters that still require gradients (frozen ones
        # are excluded from the optimizer entirely).
        return [p for _, p in named_params if p.requires_grad]

    upstream_params = _trainable(model.get_upstream_params())
    task_params = _trainable(model.get_task_params())

    param_groups = [
        {"params": upstream_params, "lr": finetune_lr},
        {"params": task_params, "lr": new_param_lr},
    ]

    return optim.AdamW(param_groups, lr=finetune_lr, weight_decay=1e-2)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def get_lr_scheduler(optimizer):
    """Linear warmup for 5 epochs (0.2 -> 1.0), then exponential decay (gamma=0.99)."""
    warmup_epochs = 5

    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.2,
        end_factor=1.0,
        total_iters=warmup_epochs,
    )
    decay = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    return torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [warmup, decay],
        milestones=[warmup_epochs],
    )
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def load_pretrained_weights(model, checkpoint_path, device):
    """Load a state dict from `checkpoint_path` into `model` and return the model.

    NOTE(review): an earlier comment claimed masked_recon / multi_head_fc
    layers were excluded, but this load is unfiltered — confirm intent.
    """
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print(f"Pretrained weights loaded from {checkpoint_path}")
    return model
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def freeze_tokenizer(model):
    """Disable gradient updates for every parameter of the model's tokenizer."""
    for _, param in model.tokenizer.named_parameters():
        param.requires_grad = False
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def print_number_of_parmas(model):
    """Print total and trainable parameter counts.

    (Public name keeps the original "parmas" spelling for existing callers.)
    """
    total_params = 0
    trainable_params = 0
    for p in model.parameters():
        count = p.numel()
        total_params += count
        if p.requires_grad:
            trainable_params += count

    print(f"Model parameters: {total_params}\t Trainable params: {trainable_params}")
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def run_epoch(
    model, dataloader, criterion, device, optimizer=None, scheduler=None, train=False
):
    """Run one epoch of training or evaluation.

    Args:
        model: Barista model; called as model(x, subject_sessions=...).
        dataloader: yields batches exposing .x (list of tensors), .labels
            and .subject_sessions.
        criterion: loss over (logits, labels).
        device: target device for batch tensors.
        optimizer, scheduler: required only when ``train`` is True.
        train: enables gradients, optimizer steps, and one scheduler step
            at the end of the epoch.

    Returns:
        (average loss per sample, ROC-AUC). AUC is NaN when it cannot be
        computed (e.g. only one class present in this epoch's labels).
    """
    if train:
        model.train()
    else:
        model.eval()

    all_preds = []
    all_labels = []
    running_loss = 0

    for batch in dataloader:
        x = [x_item.to(device) for x_item in batch.x]
        y = batch.labels.flatten().long().to(device)

        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            logits = model(
                x,
                subject_sessions=batch.subject_sessions,
            )
            loss = criterion(logits, y)

        if train:
            loss.backward()
            optimizer.step()

        # Accumulate loss weighted by batch size for a true per-sample mean.
        running_loss += loss.item() * y.size(0)

        # Positive-class probabilities for binary ROC-AUC.
        probs = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
        labels = y.detach().cpu().numpy()

        all_preds.append(probs)
        all_labels.append(labels)

    if train:
        # step scheduler at epoch interval
        scheduler.step()

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    try:
        auc = roc_auc_score(all_labels, all_preds)
    except ValueError:
        # BUGFIX: was a bare ``except:`` that also swallowed
        # KeyboardInterrupt/SystemExit; roc_auc_score raises ValueError
        # when only one class is present.
        auc = float("nan")

    avg_loss = running_loss / len(dataloader.dataset)
    return avg_loss, auc
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def finetune_model(model, train_dataloader, val_dataloader, train_config, device):
    """Finetune the model; return (best-by-val-AUC snapshot, criterion)."""
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer(
        model,
        finetune_lr=train_config.optimization.finetune_lr,
        new_param_lr=train_config.optimization.new_param_lr,
    )
    scheduler = get_lr_scheduler(optimizer)

    num_epochs = train_config.epochs
    best_val_auc = -1
    best_state = None

    for epoch in range(num_epochs):
        train_loss, train_auc = run_epoch(
            model, train_dataloader, criterion, device, optimizer, scheduler, train=True
        )
        val_loss, val_auc = evaluate_model(model, val_dataloader, criterion, device)

        print(
            f"Epoch {epoch+1}/{num_epochs} "
            f"- Train Loss: {train_loss:.4f}, AUC: {train_auc:.4f} "
            f"- Val Loss: {val_loss:.4f}, AUC: {val_auc:.4f}"
        )

        # Snapshot whenever validation AUC improves (always on the first epoch).
        if best_state is None or val_auc > best_val_auc:
            best_val_auc = val_auc
            best_state = {
                "epoch": epoch + 1,
                "model": copy.deepcopy(model.state_dict()),
                "optimizer": copy.deepcopy(optimizer.state_dict()),
                "scheduler": copy.deepcopy(scheduler.state_dict()),
                "val_auc": val_auc,
            }

    return best_state, criterion
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def evaluate_model(model, test_dataloader, criterion, device):
    """Evaluate the model on a held-out dataloader; returns (loss, AUC)."""
    return run_epoch(model, test_dataloader, criterion, device, train=False)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def main():
    """Main training pipeline: configs -> data -> model -> finetune -> test."""
    # Parse arguments, load configs, and apply any CLI overrides.
    args = parse_args()
    dataset_config, train_config, model_config = load_configs(args)

    configs = apply_overrides(
        {"dataset": dataset_config, "train": train_config, "model": model_config},
        args.override,
    )
    dataset_config = configs["dataset"]
    train_config = configs["train"]
    model_config = configs["model"]

    # Set random seed
    seed_everything(train_config.seed)

    # Setup data
    dataset, train_dataloader, val_dataloader, test_dataloader = setup_dataloaders(
        dataset_config, train_config
    )

    # Channel count of the (single) fine-tuning session.
    ft_session = dataset_config.finetune_sessions[0]
    session_dims = dataset.metadata.get_subject_session_full_d_data()
    ft_session_n_chans = session_dims[ft_session][-1]

    # Initialize model
    device = train_config.device
    model = Barista(model_config, dataset.metadata)

    # Optionally start from pretrained weights (and freeze the tokenizer).
    if train_config.checkpoint_path:
        print("Running pretrained model")
        model = load_pretrained_weights(model, train_config.checkpoint_path, device)

        if train_config.optimization.freeze_tokenizer:
            freeze_tokenizer(model)
    else:
        print("Running non-pretrained model")

    # Create downstream head and move to device
    model.create_downstream_head(n_chans=ft_session_n_chans, output_dim=2)
    model.to(device)

    print_number_of_parmas(model)

    # Finetune, then evaluate both the final and the best checkpoints.
    best_state, criterion = finetune_model(
        model, train_dataloader, val_dataloader, train_config, device
    )
    print(f"\nBEST VAL AUC: {best_state['val_auc']:.4f}")

    _, last_test_auc = evaluate_model(model, test_dataloader, criterion, device)
    print(f"LAST TEST AUC: {last_test_auc:.4f}")

    # Restore the best-by-validation weights before the final test pass.
    model.load_state_dict(best_state["model"])
    _, test_auc = evaluate_model(model, test_dataloader, criterion, device)

    print(f"BEST TEST AUC: {test_auc:.4f}")
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
if __name__ == "__main__":
|
| 368 |
+
main()
|
barista/utility_scripts/aggregate_runs.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import glob
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
KEY = 'TEST' # Options: 'VAL', 'TEST', 'LAST_TEST'
|
| 10 |
+
|
| 11 |
+
def parse_summary(path):
    """Parse an aggregated summary.txt.

    Returns (model_name, "mean ± std") on success, or None when the file is
    missing or does not contain the expected KEY_MEAN/KEY_STD/Checkpoint
    fields.
    """
    try:
        # Close the file deterministically; open(path).read() leaked the
        # handle until garbage collection.
        with open(path) as f:
            txt = f.read()
        mean = float(re.search(rf"{KEY}_MEAN=([0-9.]+)", txt).group(1))
        std = float(re.search(rf"{KEY}_STD=([0-9.]+)", txt).group(1))
        ckpt_line = re.search(r"Checkpoint:\s*(.*)", txt).group(1)
        model = os.path.basename(ckpt_line).replace(".ckpt", "")
        return model, f"{mean:.3f} ± {std:.3f}"
    except (OSError, AttributeError, ValueError):
        # BUGFIX: narrowed from a bare ``except:``. An unreadable file or
        # unmatched pattern (re.search -> None -> AttributeError) means
        # "no usable summary"; anything else should propagate.
        return None
|
| 21 |
+
|
| 22 |
+
def parse_from_seeds(folder):
    """Aggregate per-seed log files in `folder`.

    Returns (model_name, "mean ± std") over the AUC values found in
    seed_*.log files, or None when no logs/values exist or KEY is
    unsupported. Warns when fewer than the expected number of seed logs
    produced a value.
    """
    logs = sorted(glob.glob(os.path.join(folder, "seed_*.log")))
    expected_seeds = 5

    if not logs:
        print(f"WARNING: No seed logs found in {folder}")
        return None

    # BUGFIX: for KEY == "TEST" the old pattern r"TEST AUC:" also matched
    # the "LAST TEST AUC:" line as a substring and could capture the wrong
    # value; the fixed-width negative lookbehind restricts it to the
    # best-model "BEST TEST AUC:" line.
    if KEY == "TEST":
        auc_pattern = r"(?<!LAST )TEST AUC:\s*([0-9.]+)"
    elif KEY == "LAST_TEST":
        auc_pattern = r"LAST TEST AUC:\s*([0-9.]+)"
    else:
        return None

    ckpt_pattern = r"'checkpoint_path':\s*'([^']*)'"

    vals, model_name, valid_logs = [], None, 0

    for log in logs:
        try:
            with open(log) as f:
                txt = f.read()
            m = re.search(auc_pattern, txt)
            if m:
                vals.append(float(m.group(1)))
                valid_logs += 1

            cm = re.search(ckpt_pattern, txt)
            if cm:
                ckpt_path = cm.group(1)
                model_name = os.path.basename(ckpt_path).replace(".ckpt", "")
        except (OSError, ValueError):
            # Skip unreadable/garbled logs; keep aggregating the rest.
            pass

    # BUGFIX: an empty checkpoint_path yields model_name == "", which is
    # falsy — the old ``model_name or "unknown"`` ran first and made the
    # "random" branch unreachable. Check for "" (random init) before
    # falling back to "unknown" (no checkpoint line found at all).
    if model_name == '':
        model_name = "random"
    model_name = model_name or "unknown"

    if valid_logs != expected_seeds and model_name != 'random':
        print(f"WARNING: Incomplete seeds for {model_name} in {folder} "
              f"(found {valid_logs}/{expected_seeds})")

    if not vals:
        return None

    mean, std = float(np.mean(vals)), float(np.std(vals))
    return model_name, f"{mean:.3f} ± {std:.3f}"
|
| 67 |
+
|
| 68 |
+
def parse_summary_or_seeds(folder):
    """Prefer the aggregated summary.txt; fall back to per-seed log parsing."""
    summary_path = os.path.join(folder, "summary.txt")
    if os.path.exists(summary_path):
        result = parse_summary(summary_path)
        if result is not None:
            return result
    return parse_from_seeds(folder)
|
| 75 |
+
|
| 76 |
+
def extract_mean(x):
    """Pull the mean out of a "mean ± std" string; NaN for anything else."""
    if not (isinstance(x, str) and "±" in x):
        return np.nan
    return float(x.partition("±")[0].strip())
|
| 80 |
+
|
| 81 |
+
def main():
    """Aggregate per-run AUC results under --results_dir into a markdown table.

    Folder names are parsed positionally after splitting on "_": parts[1] is
    the subject id and parts[4] (optionally extended by parts[5]/parts[6])
    is the task name; a "foldN" part, if present, gives the fold index.
    NOTE(review): this naming scheme is assumed from the split indices —
    confirm against the run scripts that create the result folders.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--results_dir", type=str, default="results", help="Path to results folder")
    args = parser.parse_args()
    ROOT = args.results_dir

    # rows: (task, model, subject, fold, "mean ± std") tuples; the sets track
    # the distinct values seen so the table axes can be built afterwards.
    rows, subjects, tasks, models, folds = [], set(), set(), set(), set()

    # Collect data from folders
    for folder in os.listdir(ROOT):
        fpath = os.path.join(ROOT, folder)
        if not os.path.isdir(fpath):
            continue

        parts = folder.split("_")
        # Folders with fewer than 6 "_"-separated parts do not match the
        # expected naming scheme and are skipped.
        if len(parts) < 6:
            continue

        subj = parts[1]
        task = parts[4]
        # Multi-word task names ("sentence_onset", "speech_vs_nonspeech", ...)
        # are re-assembled from the following parts.
        if len(parts) > 5 and parts[5] in ["onset", "vs", "nonspeech", "speech", "time"]:
            task += f"_{parts[5]}"
        if len(parts) > 6 and parts[6] == "nonspeech":
            task += f"_{parts[6]}"

        # Optional fold index encoded as a "foldN" part.
        fold = None
        for p in parts:
            if p.startswith("fold"):
                fold = int(p.replace("fold", ""))
                folds.add(fold)
                break

        parsed = parse_summary_or_seeds(fpath)
        if not parsed:
            continue

        model, value = parsed
        subjects.add(subj)
        tasks.add(task)
        models.add(model)
        rows.append((task, model, subj, fold, value))

    # Build DataFrame
    # Subject ids are numeric strings; sort them numerically for the columns.
    subjects = sorted(subjects, key=lambda x: int(x))
    df = pd.DataFrame(columns=["task", "model", "fold"] + subjects)

    # One row per (task, model, fold); fold None collects fold-less runs.
    for task in sorted(tasks):
        for model in sorted(models):
            all_folds = sorted(folds) + [None]
            for fold in all_folds:
                subset = [(s, v) for t, m, s, f, v in rows if t == task and m == model and f == fold]
                if not subset:
                    continue
                row = {"task": task, "model": model, "fold": fold if fold is not None else ""}
                for subj, val in subset:
                    row[subj] = val
                df.loc[len(df)] = row

    # Add AVG column
    # NOTE: DataFrame.applymap is deprecated in pandas >= 2.1 (use .map);
    # kept here for compatibility with older pandas versions.
    subj_cols = [c for c in df.columns if c not in ["task", "model", "fold"]]
    df["avg"] = df[subj_cols].applymap(extract_mean).mean(axis=1)
    df["avg"] = df["avg"].apply(lambda x: f"{x:.3f}" if pd.notnull(x) else "")

    # Add final AVG rows per (task, model)
    avg_rows = []
    for (task, model), group in df.groupby(["task", "model"]):
        subj_avgs = {}
        # Average each subject's per-fold means (only cells of the form
        # "mean ± std" contribute).
        for subj in subj_cols:
            vals = [float(v.split("±")[0].strip()) for v in group[subj] if isinstance(v, str) and "±" in v]
            subj_avgs[subj] = f"{np.mean(vals):.3f}" if vals else ""
        overall_vals = [float(v) for v in subj_avgs.values() if v != ""]
        overall_avg = f"{np.mean(overall_vals):.3f}" if overall_vals else ""
        row = {"task": task, "model": model, "fold": "AVG", "avg": overall_avg}
        row.update(subj_avgs)
        avg_rows.append(row)

    df = pd.concat([df, pd.DataFrame(avg_rows)], ignore_index=True)
    print(df.to_markdown(index=False))
|
| 159 |
+
|
| 160 |
+
if __name__ == "__main__":
|
| 161 |
+
main()
|
barista/utility_scripts/run_finetune_folds.sh
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Usage:
|
| 4 |
+
# ./run_finetune_folds.sh --spe coords --checkpoint "pretrained_models/chans_chans.ckpt" --session HOLDSUBJ_2_HS2_6 --gpu 1 --fold 0 --exp sentence_onset_time
|
| 5 |
+
# ./run_finetune_folds.sh --spe destrieux --checkpoint "pretrained_models/parcels_chans.ckpt" --session HOLDSUBJ_2_HS2_6 --gpu 2 --fold 1 --exp speech_vs_nonspeech_time
|
| 6 |
+
|
| 7 |
+
# Default values
|
| 8 |
+
GPU=0
|
| 9 |
+
SEEDS=(0 1 2 3 4)
|
| 10 |
+
SESSION=""
|
| 11 |
+
CHECKPOINT=""
|
| 12 |
+
DATASET_CONFIG="barista/config/braintreebank.yaml"
|
| 13 |
+
TRAIN_CONFIG="barista/config/train.yaml"
|
| 14 |
+
MODEL_CONFIG="barista/config/model.yaml"
|
| 15 |
+
SPATIAL_GROUPING="coords"
|
| 16 |
+
EXPERIMENT="sentence_onset_time"
|
| 17 |
+
FOLD_NUM=0
|
| 18 |
+
|
| 19 |
+
# Parse arguments
|
| 20 |
+
while [[ $# -gt 0 ]]; do
|
| 21 |
+
case $1 in
|
| 22 |
+
--session)
|
| 23 |
+
SESSION="$2"
|
| 24 |
+
shift 2
|
| 25 |
+
;;
|
| 26 |
+
--checkpoint)
|
| 27 |
+
CHECKPOINT="$2"
|
| 28 |
+
shift 2
|
| 29 |
+
;;
|
| 30 |
+
--gpu)
|
| 31 |
+
GPU="$2"
|
| 32 |
+
shift 2
|
| 33 |
+
;;
|
| 34 |
+
--fold)
|
| 35 |
+
FOLD_NUM="$2"
|
| 36 |
+
shift 2
|
| 37 |
+
;;
|
| 38 |
+
--seeds)
|
| 39 |
+
IFS=',' read -ra SEEDS <<< "$2"
|
| 40 |
+
shift 2
|
| 41 |
+
;;
|
| 42 |
+
--dataset_config)
|
| 43 |
+
DATASET_CONFIG="$2"
|
| 44 |
+
shift 2
|
| 45 |
+
;;
|
| 46 |
+
--exp)
|
| 47 |
+
EXPERIMENT="$2"
|
| 48 |
+
shift 2
|
| 49 |
+
;;
|
| 50 |
+
--train_config)
|
| 51 |
+
TRAIN_CONFIG="$2"
|
| 52 |
+
shift 2
|
| 53 |
+
;;
|
| 54 |
+
--spe)
|
| 55 |
+
SPATIAL_GROUPING="$2"
|
| 56 |
+
shift 2
|
| 57 |
+
;;
|
| 58 |
+
--model_config)
|
| 59 |
+
MODEL_CONFIG="$2"
|
| 60 |
+
shift 2
|
| 61 |
+
;;
|
| 62 |
+
*)
|
| 63 |
+
echo "Unknown argument: $1"
|
| 64 |
+
echo "Usage: $0 --session <session_name> --checkpoint <checkpoint_path> [--gpu <gpu_id>] [--seeds <seed_list>]"
|
| 65 |
+
echo "Example: $0 --session session1 --checkpoint checkpoints/model.pt --gpu 0 --seeds 42,123,456,789,1024"
|
| 66 |
+
exit 1
|
| 67 |
+
;;
|
| 68 |
+
esac
|
| 69 |
+
done
|
| 70 |
+
|
| 71 |
+
# Validate required arguments.
# Both --session and --checkpoint are mandatory per the usage string above;
# the original only checked --session, letting an empty checkpoint path
# silently start fine-tuning from random weights.
if [ -z "$SESSION" ]; then
    echo "Error: --session is required"
    exit 1
fi

if [ -z "$CHECKPOINT" ]; then
    echo "Error: --checkpoint is required"
    exit 1
fi
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
NUM_SEEDS=${#SEEDS[@]}
|
| 79 |
+
|
| 80 |
+
# Create output directory
|
| 81 |
+
OUTPUT_DIR="results_folds/${SESSION}_${EXPERIMENT}_fold${FOLD_NUM}_model${SPATIAL_GROUPING}_$(date +%Y%m%d_%H%M%S)"
|
| 82 |
+
|
| 83 |
+
mkdir -p "$OUTPUT_DIR"
|
| 84 |
+
|
| 85 |
+
echo "=========================================="
|
| 86 |
+
echo "Sequential Multi-Seed Fine-tuning"
|
| 87 |
+
echo "=========================================="
|
| 88 |
+
echo "Session: $SESSION"
|
| 89 |
+
echo "Checkpoint: $CHECKPOINT"
|
| 90 |
+
echo "GPU: $GPU"
|
| 91 |
+
echo "Seeds: ${SEEDS[@]}"
|
| 92 |
+
echo "Number of runs: $NUM_SEEDS"
|
| 93 |
+
echo "Output Directory: $OUTPUT_DIR"
|
| 94 |
+
echo "=========================================="
|
| 95 |
+
echo ""
|
| 96 |
+
|
| 97 |
+
# Arrays to store results
|
| 98 |
+
VAL_AUCS=()
|
| 99 |
+
BEST_TEST_AUCS=()
|
| 100 |
+
LAST_TEST_AUCS=()
|
| 101 |
+
FAILED_SEEDS=()
|
| 102 |
+
|
| 103 |
+
# Run jobs sequentially
|
| 104 |
+
for i in $(seq 0 $(($NUM_SEEDS - 1))); do
|
| 105 |
+
SEED=${SEEDS[$i]}
|
| 106 |
+
|
| 107 |
+
LOG_FILE="$OUTPUT_DIR/seed_${SEED}.log"
|
| 108 |
+
|
| 109 |
+
echo "=========================================="
|
| 110 |
+
echo "Running job $((i+1))/$NUM_SEEDS: Seed=$SEED"
|
| 111 |
+
echo "=========================================="
|
| 112 |
+
echo "Log file: $LOG_FILE"
|
| 113 |
+
echo ""
|
| 114 |
+
|
| 115 |
+
# Run training
|
| 116 |
+
CUDA_VISIBLE_DEVICES=$GPU python barista/train.py \
|
| 117 |
+
--dataset_config "$DATASET_CONFIG" \
|
| 118 |
+
--train_config "$TRAIN_CONFIG" \
|
| 119 |
+
--model_config "$MODEL_CONFIG" \
|
| 120 |
+
--override \
|
| 121 |
+
seed=$SEED \
|
| 122 |
+
device=cuda:0 \
|
| 123 |
+
checkpoint_path="$CHECKPOINT" \
|
| 124 |
+
force_nonoverlap=False \
|
| 125 |
+
experiment=$EXPERIMENT \
|
| 126 |
+
chron_fold_num=$FOLD_NUM \
|
| 127 |
+
tokenizer.spatial_grouping="$SPATIAL_GROUPING" \
|
| 128 |
+
"finetune_sessions=['$SESSION']" \
|
| 129 |
+
2>&1 | tee "$LOG_FILE"
|
| 130 |
+
|
| 131 |
+
# Check if job completed successfully
|
| 132 |
+
if [ ${PIPESTATUS[0]} -eq 0 ]; then
|
| 133 |
+
echo ""
|
| 134 |
+
echo "✓ Job $((i+1)) completed successfully"
|
| 135 |
+
|
| 136 |
+
# Extract results from the run log.
# Anchor "BEST TEST AUC" at line start so it cannot match a longer line that
# merely contains the phrase (keeps this in sync with
# run_finetune_random_splits.sh, which already anchors it).
VAL_AUC=$(grep "BEST VAL AUC" "$LOG_FILE" | awk '{print $NF}')
BEST_TEST_AUC=$(grep "^BEST TEST AUC" "$LOG_FILE" | tail -1 | awk '{print $NF}')
LAST_TEST_AUC=$(grep "LAST TEST AUC" "$LOG_FILE" | awk '{print $NF}')
|
| 140 |
+
|
| 141 |
+
if [ ! -z "$VAL_AUC" ] && [ ! -z "$BEST_TEST_AUC" ] && [ ! -z "$LAST_TEST_AUC" ]; then
|
| 142 |
+
VAL_AUCS+=($VAL_AUC)
|
| 143 |
+
BEST_TEST_AUCS+=($BEST_TEST_AUC)
|
| 144 |
+
LAST_TEST_AUCS+=($LAST_TEST_AUC)
|
| 145 |
+
echo " Val AUC: $VAL_AUC"
|
| 146 |
+
echo " Best Test AUC: $BEST_TEST_AUC"
|
| 147 |
+
echo " Last Test AUC: $LAST_TEST_AUC"
|
| 148 |
+
else
|
| 149 |
+
echo " Warning: Could not extract AUC values"
|
| 150 |
+
FAILED_SEEDS+=($SEED)
|
| 151 |
+
fi
|
| 152 |
+
else
|
| 153 |
+
echo ""
|
| 154 |
+
echo "✗ Job $((i+1)) failed"
|
| 155 |
+
FAILED_SEEDS+=($SEED)
|
| 156 |
+
fi
|
| 157 |
+
|
| 158 |
+
echo ""
|
| 159 |
+
done
|
| 160 |
+
|
| 161 |
+
echo "=========================================="
|
| 162 |
+
echo "All jobs completed!"
|
| 163 |
+
echo "=========================================="
|
| 164 |
+
echo ""
|
| 165 |
+
|
| 166 |
+
# Calculate statistics using Python
|
| 167 |
+
STATS_SCRIPT="$OUTPUT_DIR/calculate_stats.py"
|
| 168 |
+
cat > "$STATS_SCRIPT" << 'EOF'
|
| 169 |
+
import sys
|
| 170 |
+
import numpy as np
|
| 171 |
+
|
| 172 |
+
def calculate_stats(values):
|
| 173 |
+
if len(values) == 0:
|
| 174 |
+
return None, None
|
| 175 |
+
arr = np.array(values, dtype=float)
|
| 176 |
+
return np.mean(arr), np.std(arr)
|
| 177 |
+
|
| 178 |
+
# Read values from command line
|
| 179 |
+
val_aucs = [float(x) for x in sys.argv[1].split(',') if x]
|
| 180 |
+
best_test_aucs = [float(x) for x in sys.argv[2].split(',') if x]
|
| 181 |
+
last_test_aucs = [float(x) for x in sys.argv[3].split(',') if x]
|
| 182 |
+
|
| 183 |
+
val_mean, val_std = calculate_stats(val_aucs)
|
| 184 |
+
best_test_mean, best_test_std = calculate_stats(best_test_aucs)
|
| 185 |
+
last_test_mean, last_test_std = calculate_stats(last_test_aucs)
|
| 186 |
+
|
| 187 |
+
print(f"VAL_MEAN={val_mean:.4f}")
|
| 188 |
+
print(f"VAL_STD={val_std:.4f}")
|
| 189 |
+
print(f"BEST_TEST_MEAN={best_test_mean:.4f}")
|
| 190 |
+
print(f"BEST_TEST_STD={best_test_std:.4f}")
|
| 191 |
+
print(f"LAST_TEST_MEAN={last_test_mean:.4f}")
|
| 192 |
+
print(f"LAST_TEST_STD={last_test_std:.4f}")
|
| 193 |
+
|
| 194 |
+
# Print individual values
|
| 195 |
+
print("\nIndividual Results:")
|
| 196 |
+
for i, (val, test, last_test) in enumerate(zip(val_aucs, best_test_aucs, last_test_aucs), 1):
|
| 197 |
+
print(f" Run {i}: Val AUC = {val:.4f}, Best Test AUC = {test:.4f}, Last Test AUC = {last_test:.4f}")
|
| 198 |
+
EOF
|
| 199 |
+
|
| 200 |
+
# Convert arrays to comma-separated strings
|
| 201 |
+
VAL_AUCS_STR=$(IFS=,; echo "${VAL_AUCS[*]}")
|
| 202 |
+
BEST_TEST_AUCS_STR=$(IFS=,; echo "${BEST_TEST_AUCS[*]}")
|
| 203 |
+
LAST_TEST_AUCS_STR=$(IFS=,; echo "${LAST_TEST_AUCS[*]}")
|
| 204 |
+
|
| 205 |
+
# Calculate and display statistics
|
| 206 |
+
if [ ${#BEST_TEST_AUCS[@]} -gt 0 ]; then
|
| 207 |
+
echo "=========================================="
|
| 208 |
+
echo "FINAL RESULTS"
|
| 209 |
+
echo "=========================================="
|
| 210 |
+
|
| 211 |
+
STATS_OUTPUT=$(python "$STATS_SCRIPT" "$VAL_AUCS_STR" "$BEST_TEST_AUCS_STR" "$LAST_TEST_AUCS_STR")
|
| 212 |
+
echo "$STATS_OUTPUT"
|
| 213 |
+
|
| 214 |
+
VAL_MEAN=$(awk -F= '/^VAL_MEAN=/{print $2; exit}' <<<"$STATS_OUTPUT")
|
| 215 |
+
VAL_STD=$(awk -F= '/^VAL_STD=/{print $2; exit}' <<<"$STATS_OUTPUT")
|
| 216 |
+
BEST_TEST_MEAN=$(awk -F= '/^BEST_TEST_MEAN=/{print $2; exit}' <<<"$STATS_OUTPUT")
|
| 217 |
+
BEST_TEST_STD=$(awk -F= '/^BEST_TEST_STD=/{print $2; exit}' <<<"$STATS_OUTPUT")
|
| 218 |
+
LAST_TEST_MEAN=$(awk -F= '/^LAST_TEST_MEAN=/{print $2; exit}' <<<"$STATS_OUTPUT")
|
| 219 |
+
LAST_TEST_STD=$(awk -F= '/^LAST_TEST_STD=/{print $2; exit}' <<<"$STATS_OUTPUT")
|
| 220 |
+
|
| 221 |
+
echo ""
|
| 222 |
+
echo "Summary:"
|
| 223 |
+
echo " Validation AUC: ${VAL_MEAN} ± ${VAL_STD}"
|
| 224 |
+
echo " Test AUC: ${BEST_TEST_MEAN} ± ${BEST_TEST_STD}"
|
| 225 |
+
echo " Last Test AUC: ${LAST_TEST_MEAN} ± ${LAST_TEST_STD}"
|
| 226 |
+
echo ""
|
| 227 |
+
echo "Successful runs: ${#BEST_TEST_AUCS[@]}/$NUM_SEEDS"
|
| 228 |
+
|
| 229 |
+
if [ ${#FAILED_SEEDS[@]} -gt 0 ]; then
|
| 230 |
+
echo "Failed seeds: ${FAILED_SEEDS[@]}"
|
| 231 |
+
fi
|
| 232 |
+
|
| 233 |
+
echo "=========================================="
|
| 234 |
+
|
| 235 |
+
# Save summary to file
|
| 236 |
+
SUMMARY_FILE="$OUTPUT_DIR/summary.txt"
|
| 237 |
+
{
|
| 238 |
+
echo "Summary Report - $(date)"
|
| 239 |
+
echo "=================================="
|
| 240 |
+
echo "Session: $SESSION"
|
| 241 |
+
echo "Checkpoint: $CHECKPOINT"
|
| 242 |
+
echo "GPU: $GPU"
|
| 243 |
+
echo "Seeds: ${SEEDS[@]}"
|
| 244 |
+
echo ""
|
| 245 |
+
echo "FINAL RESULTS"
|
| 246 |
+
echo "=================================="
|
| 247 |
+
echo "$STATS_OUTPUT"
|
| 248 |
+
echo ""
|
| 249 |
+
echo "Summary:"
|
| 250 |
+
echo " Validation AUC: ${VAL_MEAN} ± ${VAL_STD}"
|
| 251 |
+
echo " BEST Test AUC: ${BEST_TEST_MEAN} ± ${BEST_TEST_STD}"
|
| 252 |
+
echo " Last Test AUC: ${LAST_TEST_MEAN} ± ${LAST_TEST_STD}"
|
| 253 |
+
echo ""
|
| 254 |
+
echo "Successful runs: ${#BEST_TEST_AUCS[@]}/$NUM_SEEDS"
|
| 255 |
+
if [ ${#FAILED_SEEDS[@]} -gt 0 ]; then
|
| 256 |
+
echo "Failed seeds: ${FAILED_SEEDS[@]}"
|
| 257 |
+
fi
|
| 258 |
+
} > "$SUMMARY_FILE"
|
| 259 |
+
|
| 260 |
+
echo ""
|
| 261 |
+
echo "Summary saved to: $SUMMARY_FILE"
|
| 262 |
+
echo "All logs saved to: $OUTPUT_DIR"
|
| 263 |
+
else
|
| 264 |
+
echo "ERROR: No successful runs completed"
|
| 265 |
+
exit 1
|
| 266 |
+
fi
|
| 267 |
+
|
| 268 |
+
# Clean up temporary script
|
| 269 |
+
rm "$STATS_SCRIPT"
|
| 270 |
+
|
| 271 |
+
# Exit with error if any jobs failed
|
| 272 |
+
if [ ${#FAILED_SEEDS[@]} -gt 0 ]; then
|
| 273 |
+
exit 1
|
| 274 |
+
fi
|
| 275 |
+
|
| 276 |
+
exit 0
|
barista/utility_scripts/run_finetune_random_splits.sh
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Usage:
|
| 4 |
+
# ./run_finetune_random_splits.sh --spe coords --checkpoint "pretrained_models/chans_chans.ckpt" --session HOLDSUBJ_2_HS2_6 --gpu 1 --exp sentence_onset
|
| 5 |
+
# ./run_finetune_random_splits.sh --spe destrieux --checkpoint "pretrained_models/parcels_chans.ckpt" --session HOLDSUBJ_2_HS2_6 --gpu 2 --exp speech_vs_nonspeech
|
| 6 |
+
|
| 7 |
+
# Default values
|
| 8 |
+
GPU=0
|
| 9 |
+
SEEDS=(0 1 2 3 4)
|
| 10 |
+
SESSION=""
|
| 11 |
+
CHECKPOINT=""
|
| 12 |
+
DATASET_CONFIG="barista/config/braintreebank.yaml"
|
| 13 |
+
TRAIN_CONFIG="barista/config/train.yaml"
|
| 14 |
+
MODEL_CONFIG="barista/config/model.yaml"
|
| 15 |
+
SPATIAL_GROUPING="coords"
|
| 16 |
+
EXPERIMENT="sentence_onset"
|
| 17 |
+
|
| 18 |
+
# Parse arguments
|
| 19 |
+
while [[ $# -gt 0 ]]; do
|
| 20 |
+
case $1 in
|
| 21 |
+
--session)
|
| 22 |
+
SESSION="$2"
|
| 23 |
+
shift 2
|
| 24 |
+
;;
|
| 25 |
+
--checkpoint)
|
| 26 |
+
CHECKPOINT="$2"
|
| 27 |
+
shift 2
|
| 28 |
+
;;
|
| 29 |
+
--gpu)
|
| 30 |
+
GPU="$2"
|
| 31 |
+
shift 2
|
| 32 |
+
;;
|
| 33 |
+
--seeds)
|
| 34 |
+
IFS=',' read -ra SEEDS <<< "$2"
|
| 35 |
+
shift 2
|
| 36 |
+
;;
|
| 37 |
+
--dataset_config)
|
| 38 |
+
DATASET_CONFIG="$2"
|
| 39 |
+
shift 2
|
| 40 |
+
;;
|
| 41 |
+
--train_config)
|
| 42 |
+
TRAIN_CONFIG="$2"
|
| 43 |
+
shift 2
|
| 44 |
+
;;
|
| 45 |
+
--exp)
|
| 46 |
+
EXPERIMENT="$2"
|
| 47 |
+
shift 2
|
| 48 |
+
;;
|
| 49 |
+
--spe)
|
| 50 |
+
SPATIAL_GROUPING="$2"
|
| 51 |
+
shift 2
|
| 52 |
+
;;
|
| 53 |
+
--model_config)
|
| 54 |
+
MODEL_CONFIG="$2"
|
| 55 |
+
shift 2
|
| 56 |
+
;;
|
| 57 |
+
*)
|
| 58 |
+
echo "Unknown argument: $1"
|
| 59 |
+
exit 1
|
| 60 |
+
;;
|
| 61 |
+
esac
|
| 62 |
+
done
|
| 63 |
+
|
| 64 |
+
# Validate required arguments.
# --checkpoint is mandatory per the usage examples above; the original only
# checked --session, letting an empty checkpoint path silently start
# fine-tuning from random weights.
if [ -z "$SESSION" ]; then
    echo "Error: --session is required"
    exit 1
fi

if [ -z "$CHECKPOINT" ]; then
    echo "Error: --checkpoint is required"
    exit 1
fi
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
NUM_SEEDS=${#SEEDS[@]}
|
| 72 |
+
|
| 73 |
+
# Create output directory.
# Underscore before the timestamp keeps the directory name readable and
# matches the naming scheme used by run_finetune_folds.sh.
OUTPUT_DIR="results/${SESSION}_${EXPERIMENT}_model${SPATIAL_GROUPING}_$(date +%Y%m%d_%H%M%S)"
mkdir -p "$OUTPUT_DIR"
|
| 76 |
+
|
| 77 |
+
echo "=========================================="
|
| 78 |
+
echo "Sequential Multi-Seed Fine-tuning"
|
| 79 |
+
echo "=========================================="
|
| 80 |
+
echo "Session: $SESSION"
|
| 81 |
+
echo "Checkpoint: $CHECKPOINT"
|
| 82 |
+
echo "GPU: $GPU"
|
| 83 |
+
echo "Seeds: ${SEEDS[@]}"
|
| 84 |
+
echo "Number of runs: $NUM_SEEDS"
|
| 85 |
+
echo "Output Directory: $OUTPUT_DIR"
|
| 86 |
+
echo "=========================================="
|
| 87 |
+
echo ""
|
| 88 |
+
|
| 89 |
+
# Arrays to store results
|
| 90 |
+
VAL_AUCS=()
|
| 91 |
+
BEST_TEST_AUCS=()
|
| 92 |
+
LAST_TEST_AUCS=()
|
| 93 |
+
FAILED_SEEDS=()
|
| 94 |
+
|
| 95 |
+
# Run jobs sequentially
|
| 96 |
+
for i in $(seq 0 $(($NUM_SEEDS - 1))); do
|
| 97 |
+
SEED=${SEEDS[$i]}
|
| 98 |
+
|
| 99 |
+
LOG_FILE="$OUTPUT_DIR/seed_${SEED}.log"
|
| 100 |
+
|
| 101 |
+
echo "=========================================="
|
| 102 |
+
echo "Running job $((i+1))/$NUM_SEEDS: Seed=$SEED"
|
| 103 |
+
echo "=========================================="
|
| 104 |
+
echo "Log file: $LOG_FILE"
|
| 105 |
+
echo ""
|
| 106 |
+
|
| 107 |
+
# Run training
|
| 108 |
+
CUDA_VISIBLE_DEVICES=$GPU python barista/train.py \
|
| 109 |
+
--dataset_config "$DATASET_CONFIG" \
|
| 110 |
+
--train_config "$TRAIN_CONFIG" \
|
| 111 |
+
--model_config "$MODEL_CONFIG" \
|
| 112 |
+
--override \
|
| 113 |
+
seed=$SEED \
|
| 114 |
+
device=cuda:0 \
|
| 115 |
+
checkpoint_path="$CHECKPOINT" \
|
| 116 |
+
force_nonoverlap=True \
|
| 117 |
+
experiment="$EXPERIMENT" \
|
| 118 |
+
tokenizer.spatial_grouping="$SPATIAL_GROUPING" \
|
| 119 |
+
"finetune_sessions=['$SESSION']" \
|
| 120 |
+
2>&1 | tee "$LOG_FILE"
|
| 121 |
+
|
| 122 |
+
# Check if job completed successfully
|
| 123 |
+
if [ ${PIPESTATUS[0]} -eq 0 ]; then
|
| 124 |
+
echo ""
|
| 125 |
+
echo "✓ Job $((i+1)) completed successfully"
|
| 126 |
+
|
| 127 |
+
# Extract results from log file
|
| 128 |
+
VAL_AUC=$(grep "BEST VAL AUC" "$LOG_FILE" | awk '{print $NF}')
|
| 129 |
+
BEST_TEST_AUC=$(grep "^BEST TEST AUC" "$LOG_FILE" | tail -1 | awk '{print $NF}')
|
| 130 |
+
LAST_TEST_AUC=$(grep "LAST TEST AUC" "$LOG_FILE" | awk '{print $NF}')
|
| 131 |
+
|
| 132 |
+
if [ ! -z "$VAL_AUC" ] && [ ! -z "$BEST_TEST_AUC" ] && [ ! -z "$LAST_TEST_AUC" ]; then
|
| 133 |
+
VAL_AUCS+=($VAL_AUC)
|
| 134 |
+
BEST_TEST_AUCS+=($BEST_TEST_AUC)
|
| 135 |
+
LAST_TEST_AUCS+=($LAST_TEST_AUC)
|
| 136 |
+
echo " Val AUC: $VAL_AUC"
|
| 137 |
+
echo " Test AUC: $BEST_TEST_AUC"
|
| 138 |
+
echo " Last Test AUC: $LAST_TEST_AUC"
|
| 139 |
+
else
|
| 140 |
+
echo " Warning: Could not extract AUC values"
|
| 141 |
+
FAILED_SEEDS+=($SEED)
|
| 142 |
+
fi
|
| 143 |
+
else
|
| 144 |
+
echo ""
|
| 145 |
+
echo "✗ Job $((i+1)) failed"
|
| 146 |
+
FAILED_SEEDS+=($SEED)
|
| 147 |
+
fi
|
| 148 |
+
|
| 149 |
+
echo ""
|
| 150 |
+
done
|
| 151 |
+
|
| 152 |
+
echo "=========================================="
|
| 153 |
+
echo "All jobs completed!"
|
| 154 |
+
echo "=========================================="
|
| 155 |
+
echo ""
|
| 156 |
+
|
| 157 |
+
# Calculate statistics using Python
|
| 158 |
+
STATS_SCRIPT="$OUTPUT_DIR/calculate_stats.py"
|
| 159 |
+
cat > "$STATS_SCRIPT" << 'EOF'
|
| 160 |
+
import sys
|
| 161 |
+
import numpy as np
|
| 162 |
+
|
| 163 |
+
def calculate_stats(values):
|
| 164 |
+
if len(values) == 0:
|
| 165 |
+
return None, None
|
| 166 |
+
arr = np.array(values, dtype=float)
|
| 167 |
+
return np.mean(arr), np.std(arr)
|
| 168 |
+
|
| 169 |
+
# Read values from command line
|
| 170 |
+
val_aucs = [float(x) for x in sys.argv[1].split(',') if x]
|
| 171 |
+
best_test_aucs = [float(x) for x in sys.argv[2].split(',') if x]
|
| 172 |
+
last_test_aucs = [float(x) for x in sys.argv[3].split(',') if x]
|
| 173 |
+
|
| 174 |
+
val_mean, val_std = calculate_stats(val_aucs)
|
| 175 |
+
best_test_mean, best_test_std = calculate_stats(best_test_aucs)
|
| 176 |
+
last_test_mean, last_test_std = calculate_stats(last_test_aucs)
|
| 177 |
+
|
| 178 |
+
print(f"VAL_MEAN={val_mean:.4f}")
|
| 179 |
+
print(f"VAL_STD={val_std:.4f}")
|
| 180 |
+
print(f"BEST_TEST_MEAN={best_test_mean:.4f}")
|
| 181 |
+
print(f"BEST_TEST_STD={best_test_std:.4f}")
|
| 182 |
+
print(f"LAST_TEST_MEAN={last_test_mean:.4f}")
|
| 183 |
+
print(f"LAST_TEST_STD={last_test_std:.4f}")
|
| 184 |
+
|
| 185 |
+
# Print individual values
|
| 186 |
+
print("\nIndividual Results:")
|
| 187 |
+
for i, (val, test, last_test) in enumerate(zip(val_aucs, best_test_aucs, last_test_aucs), 1):
|
| 188 |
+
print(f" Run {i}: Val AUC = {val:.4f}, Best Test AUC = {test:.4f}, Last Test AUC = {last_test:.4f}")
|
| 189 |
+
EOF
|
| 190 |
+
|
| 191 |
+
# Convert arrays to comma-separated strings
|
| 192 |
+
VAL_AUCS_STR=$(IFS=,; echo "${VAL_AUCS[*]}")
|
| 193 |
+
BEST_TEST_AUCS_STR=$(IFS=,; echo "${BEST_TEST_AUCS[*]}")
|
| 194 |
+
LAST_TEST_AUCS_STR=$(IFS=,; echo "${LAST_TEST_AUCS[*]}")
|
| 195 |
+
|
| 196 |
+
# Calculate and display statistics
|
| 197 |
+
if [ ${#BEST_TEST_AUCS[@]} -gt 0 ]; then
|
| 198 |
+
echo "=========================================="
|
| 199 |
+
echo "FINAL RESULTS"
|
| 200 |
+
echo "=========================================="
|
| 201 |
+
|
| 202 |
+
STATS_OUTPUT=$(python "$STATS_SCRIPT" "$VAL_AUCS_STR" "$BEST_TEST_AUCS_STR" "$LAST_TEST_AUCS_STR")
|
| 203 |
+
echo "$STATS_OUTPUT"
|
| 204 |
+
|
| 205 |
+
VAL_MEAN=$(awk -F= '/^VAL_MEAN=/{print $2; exit}' <<<"$STATS_OUTPUT")
|
| 206 |
+
VAL_STD=$(awk -F= '/^VAL_STD=/{print $2; exit}' <<<"$STATS_OUTPUT")
|
| 207 |
+
BEST_TEST_MEAN=$(awk -F= '/^BEST_TEST_MEAN=/{print $2; exit}' <<<"$STATS_OUTPUT")
|
| 208 |
+
BEST_TEST_STD=$(awk -F= '/^BEST_TEST_STD=/{print $2; exit}' <<<"$STATS_OUTPUT")
|
| 209 |
+
LAST_TEST_MEAN=$(awk -F= '/^LAST_TEST_MEAN=/{print $2; exit}' <<<"$STATS_OUTPUT")
|
| 210 |
+
LAST_TEST_STD=$(awk -F= '/^LAST_TEST_STD=/{print $2; exit}' <<<"$STATS_OUTPUT")
|
| 211 |
+
|
| 212 |
+
echo ""
|
| 213 |
+
echo "Summary:"
|
| 214 |
+
echo " Validation AUC: ${VAL_MEAN} ± ${VAL_STD}"
|
| 215 |
+
echo " Best Test AUC: ${BEST_TEST_MEAN} ± ${BEST_TEST_STD}"
|
| 216 |
+
echo " Last Test AUC: ${LAST_TEST_MEAN} ± ${LAST_TEST_STD}"
|
| 217 |
+
echo ""
|
| 218 |
+
echo "Successful runs: ${#BEST_TEST_AUCS[@]}/$NUM_SEEDS"
|
| 219 |
+
|
| 220 |
+
if [ ${#FAILED_SEEDS[@]} -gt 0 ]; then
|
| 221 |
+
echo "Failed seeds: ${FAILED_SEEDS[@]}"
|
| 222 |
+
fi
|
| 223 |
+
|
| 224 |
+
echo "=========================================="
|
| 225 |
+
|
| 226 |
+
# Save summary to file
|
| 227 |
+
SUMMARY_FILE="$OUTPUT_DIR/summary.txt"
|
| 228 |
+
{
|
| 229 |
+
echo "Summary Report - $(date)"
|
| 230 |
+
echo "=================================="
|
| 231 |
+
echo "Session: $SESSION"
|
| 232 |
+
echo "Checkpoint: $CHECKPOINT"
|
| 233 |
+
echo "GPU: $GPU"
|
| 234 |
+
echo "Seeds: ${SEEDS[@]}"
|
| 235 |
+
echo ""
|
| 236 |
+
echo "FINAL RESULTS"
|
| 237 |
+
echo "=================================="
|
| 238 |
+
echo "$STATS_OUTPUT"
|
| 239 |
+
echo ""
|
| 240 |
+
echo "Summary:"
|
| 241 |
+
echo " Validation AUC: ${VAL_MEAN} ± ${VAL_STD}"
|
| 242 |
+
echo " Test AUC: ${BEST_TEST_MEAN} ± ${BEST_TEST_STD}"
|
| 243 |
+
echo " Last Test AUC: ${LAST_TEST_MEAN} ± ${LAST_TEST_STD}"
|
| 244 |
+
echo ""
|
| 245 |
+
echo "Successful runs: ${#BEST_TEST_AUCS[@]}/$NUM_SEEDS"
|
| 246 |
+
if [ ${#FAILED_SEEDS[@]} -gt 0 ]; then
|
| 247 |
+
echo "Failed seeds: ${FAILED_SEEDS[@]}"
|
| 248 |
+
fi
|
| 249 |
+
} > "$SUMMARY_FILE"
|
| 250 |
+
|
| 251 |
+
echo ""
|
| 252 |
+
echo "Summary saved to: $SUMMARY_FILE"
|
| 253 |
+
echo "All logs saved to: $OUTPUT_DIR"
|
| 254 |
+
else
|
| 255 |
+
echo "ERROR: No successful runs completed"
|
| 256 |
+
exit 1
|
| 257 |
+
fi
|
| 258 |
+
|
| 259 |
+
# Clean up temporary script
|
| 260 |
+
rm "$STATS_SCRIPT"
|
| 261 |
+
|
| 262 |
+
# Exit with error if any jobs failed
|
| 263 |
+
if [ ${#FAILED_SEEDS[@]} -gt 0 ]; then
|
| 264 |
+
exit 1
|
| 265 |
+
fi
|
| 266 |
+
|
| 267 |
+
exit 0
|
pretrained_models/chans_chans.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:400eeacc0697004cb81c9ecf754859da184ffeea40afc8ee7b5930c3b997e1d0
|
| 3 |
+
size 3538414
|
pretrained_models/lobes_chans.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d810338a4929df0fb2421f342b3ee859f9fef269e35fb4f2fd9c55347a63324a
|
| 3 |
+
size 3389478
|
pretrained_models/parcels_chans.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6c234517d286a8e710b09716dc88c713618670df523cfffb89e4c9073f2657c1
|
| 3 |
+
size 3415452
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.4.0
|
| 2 |
+
einops==0.8.0
|
| 3 |
+
h5py==3.11.0
|
| 4 |
+
ipykernel==6.29.5
|
| 5 |
+
ipython==8.12.3
|
| 6 |
+
jupyter-client==8.6.3
|
| 7 |
+
jupyter-core==5.7.2
|
| 8 |
+
numpy==1.24.4
|
| 9 |
+
omegaconf==2.3.0
|
| 10 |
+
ordered-set==4.1.0
|
| 11 |
+
pandas==2.0.3
|
| 12 |
+
scikit-learn==1.3.2
|
| 13 |
+
scipy==1.10.1
|
| 14 |
+
xformers==0.0.27.post2
|
| 15 |
+
tabulate==0.9.0
|
setup.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import find_packages, setup

# Pinned runtime dependencies. Drop blank lines and comments so they never
# leak into install_requires (splitlines() alone keeps them verbatim).
with open("requirements.txt", "r", encoding="utf-8") as f:
    requirements = [
        line.strip()
        for line in f
        if line.strip() and not line.strip().startswith("#")
    ]

# Read the long description with a context manager so the file handle is
# closed deterministically; the original bare open() relied on GC.
with open("README.md", "r", encoding="utf-8") as f:
    long_description = f.read()

setup(
    name="barista",
    version="1.0.0",
    description=(
        "PyTorch implementation of BaRISTA: Brain Scale Informed "
        "Spatiotemporal Representation of Human Intracranial Neural Activity"
    ),
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="Lucine L. Oganesian, Saba Hashemi, Maryam M. Shanechi",
    author_email="shanechi@usc.edu",
    url="https://github.com/ShanechiLab/BaRISTA",  # change to actual repo URL
    packages=find_packages(),
    python_requires=">=3.8",
    install_requires=requirements,
    include_package_data=True,
    entry_points={
        "console_scripts": [
            "barista-train=barista.train:main",
            "barista-prepare=barista.prepare_segments:main",
        ],
    },
)
|