savaw committed on
Commit a35137b · verified · 1 Parent(s): fe8f329

Upload folder using huggingface_hub
LICENSE CHANGED
@@ -0,0 +1,14 @@
+ This software is Copyright © 2025 The University of Southern California. All Rights Reserved.
+
+ Permission to use, copy, modify, and distribute this software and its documentation for educational, research and non-profit purposes, without fee, and without a written agreement is hereby granted, provided that the above copyright notice, this paragraph and the following three paragraphs appear in all copies.
+
+ Permission to make commercial use of this software may be obtained by contacting:\
+ USC Stevens Center for Innovation\
+ University of Southern California\
+ 1150 S. Olive Street, Suite 2300\
+ Los Angeles, CA 90115, USA\
+ E-mail to: info@stevens.usc.edu and cc to: accounting@stevens.usc.edu
+
+ This software program and documentation are copyrighted by The University of Southern California. The software program and documentation are supplied "as is", without any accompanying services from USC. USC does not warrant that the operation of the program will be uninterrupted or error-free. The end-user understands that the program was developed for research purposes and is advised not to rely exclusively on the program for any reason.
+
+ IN NO EVENT SHALL THE UNIVERSITY OF SOUTHERN CALIFORNIA BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF SOUTHERN CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. THE UNIVERSITY OF SOUTHERN CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF SOUTHERN CALIFORNIA HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
LICENSE.md ADDED
@@ -0,0 +1,14 @@
+ This software is Copyright © 2025 The University of Southern California. All Rights Reserved.
+
+ Permission to use, copy, modify, and distribute this software and its documentation for educational, research and non-profit purposes, without fee, and without a written agreement is hereby granted, provided that the above copyright notice, this paragraph and the following three paragraphs appear in all copies.
+
+ Permission to make commercial use of this software may be obtained by contacting:\
+ USC Stevens Center for Innovation\
+ University of Southern California\
+ 1150 S. Olive Street, Suite 2300\
+ Los Angeles, CA 90115, USA\
+ E-mail to: info@stevens.usc.edu and cc to: accounting@stevens.usc.edu
+
+ This software program and documentation are copyrighted by The University of Southern California. The software program and documentation are supplied "as is", without any accompanying services from USC. USC does not warrant that the operation of the program will be uninterrupted or error-free. The end-user understands that the program was developed for research purposes and is advised not to rely exclusively on the program for any reason.
+
+ IN NO EVENT SHALL THE UNIVERSITY OF SOUTHERN CALIFORNIA BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF SOUTHERN CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. THE UNIVERSITY OF SOUTHERN CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF SOUTHERN CALIFORNIA HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
README.md CHANGED
@@ -1,5 +1,255 @@
- ---
- license: other
- license_name: usc
- license_link: LICENSE
- ---
+ ---
+ language: en
+ tags:
+ - ieeg
+ - bci
+ - neuroscience
+ - foundation-model
+ - neurips-2025
+ arxiv: 2512.12135
+ metrics:
+ - accuracy
+ license: other
+ license_link: LICENSE
+ ---
+
+ # BaRISTA ☕
+
+ [![Python](https://img.shields.io/badge/python-3.8%2B-blue)](https://www.python.org/)
+ [![NeurIPS 2025](https://img.shields.io/badge/NeurIPS-2025-a55eea)](https://openreview.net/forum?id=LDjBDk3Czb)
+
+ This repository contains the official PyTorch implementation of [**BaRISTA** (Brain Scale Informed Spatiotemporal Representation of Human Intracranial Neural Activity)](#publication).
+
+ ## Table of Contents
+ - [Installation](#installation)
+ - [Data Preparation](#data-preparation)
+ - [Data Segmentation](#data-segmentation)
+ - [Finetuning the Model](#finetuning-the-model)
+ - [Additional Scripts](#additional-scripts)
+ - [Publication](#publication)
+
+ ---
+ ## Installation
+
+ We recommend setting up a virtual environment to manage dependencies.
+
+ ```bash
+ # 1. Create and activate a virtual environment
+ python -m venv barista_venv
+ source barista_venv/bin/activate
+
+ # 2. Install the package in editable mode
+ python -m pip install -e .
+ ```
+
+ ## Data Preparation
+
+ 1. Download the data from the [Brain Treebank website](https://braintreebank.dev/). You will also need the `clean_laplacian.json` file from the [PopT codebase](https://github.com/czlwang/PopulationTransformer/blob/main/electrode_selections/clean_laplacian.json).
+
+ 2. Update the `dataset_dir` config in `barista/config/braintreebank.yaml` to point to the raw data path.
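For example, the relevant field in `barista/config/braintreebank.yaml` might be set as follows (the path shown is illustrative, not a required location):

```yaml
## Directory where the raw data exists (illustrative path).
dataset_dir: "/data/braintreebank_raw"
```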
+
+ The data directory should have the following structure:
+
+ <details> <summary><strong>Click to expand full directory tree</strong></summary>
+
+ ```
+ braintreebank_data
+ |__corrupted_elec.json
+ |__clean_laplacian.json
+ |__all_subject_data
+ | |__ sub_1_trial000.h5
+ | |__ sub_1_trial001.h5
+ | |__ sub_1_trial002.h5
+ | |__ sub_2_trial000.h5
+ | |
+ | ...
+ |
+ |__ electrode_labels
+ | |__ sub_1
+ | | |__ electrode_labels.json
+ | |__ sub_2
+ | | |__ electrode_labels.json
+ | ...
+ |
+ |__ localization
+ | |__ elec_coords_full.csv
+ | |__ sub_1
+ | | |__ depth-wm.csv
+ | |__ sub_2
+ | | |__ depth-wm.csv
+ | ...
+ |
+ |__ subject_metadata
+ | |__ sub_1_trial000_metadata.json
+ | |__ sub_1_trial001_metadata.json
+ | |__ sub_1_trial002_metadata.json
+ | |__ sub_2_trial000_metadata.json
+ | |
+ | ...
+ |
+ |__ subject_timings
+ | |__ sub_1_trial000_timings.csv
+ | |__ sub_1_trial001_timings.csv
+ | |__ sub_1_trial002_timings.csv
+ | |__ sub_2_trial000_timings.csv
+ | |
+ | ...
+ |
+ |__ transcripts
+ | |__ ant-man
+ | | |__ features.csv
+ | |__ aquaman
+ | | |__ features.csv
+ | ...
+ ```
+
+ </details>
+
+
+ ## Data Segmentation
+
+ You must segment the data **before training**. The required arguments depend on the experiment:
+
+ | Experiment Type | `force_nonoverlap` | `experiment` options |
+ |-----------------|--------------------|----------------------|
+ | **1. Random splits**, non-overlapping neural segments (Main Analysis in the paper) | `True` | `sentence_onset`, `speech_vs_nonspeech` |
+ | **2. Chronological splits**, increased labels (Appendix K in the paper) | `False` | `sentence_onset_time`, `speech_vs_nonspeech_time`, `volume`, `optical_flow` |
+
+ ### 1. Generating Random Splits with Non-Overlapping Neural Segments
+
+ To generate the random splits with non-overlapping neural segments, as used for the main analysis (Section 4), run the following:
+
+ ```bash
+ python barista/prepare_segments.py \
+ --config barista/config/braintreebank.yaml \
+ --experiment <sentence_onset|speech_vs_nonspeech>
+ ```
+
+ > ⚠️ Ensure `force_nonoverlap` in `barista/config/braintreebank.yaml` is set to `True` for this experiment. Incorrect settings will produce invalid splits.
+
+
+ This setting should **only** be used with the `sentence_onset` and `speech_vs_nonspeech` experiments.
+
+ ### 2. Generating Chronological Splits with Increased Label Data
+ We can also generate chronological splits (splitting sessions based on time rather than random shuffling). This approach increases the number of labeled segments available for finetuning by allowing overlap between segments within the same split, while preventing information leakage (i.e., no overlapping neural segments) between the train and test splits. To generate the chronological splits used for the evaluation in Appendix K, there are two steps to follow.
+
+ First, segment the data using the following command:
+
+ ```bash
+ python barista/prepare_segments.py \
+ --config barista/config/braintreebank.yaml \
+ --experiment <sentence_onset_time|speech_vs_nonspeech_time|volume|optical_flow>
+ ```
+
+ > ⚠️ Ensure `force_nonoverlap` in `barista/config/braintreebank.yaml` is set to `False` for this experiment. Incorrect settings will produce invalid splits.
+
+ This setting should **only** be used with the `sentence_onset_time`, `speech_vs_nonspeech_time`, `volume`, and `optical_flow` experiments.
+
+
+ Second, you will need to generate the 5 chronological folds to use during evaluation. To create these folds, we use the `data/generate_chronological_folds.ipynb` notebook. This notebook will automatically generate 5 different train/valid/test splits across time, while ensuring that every generated split contains both positive and negative labels. To use the notebook, take the following steps:
+
+ 1. Open `generate_chronological_folds.ipynb`
+
+ 2. Update the `_METADATA_FNAMES` variable with the metadata hash string produced in the previous step.
+
+ 3. Run the notebook to generate the 5 train/valid/test fold pickle files.
+
+ The notebook writes a pickle file to the same directory as the specified metadata file; it is loaded dynamically at train/eval time to ensure the right chronological fold is used.
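As a rough illustration of how such a fold file could be inspected, a minimal sketch follows. The file name and the internal structure of the pickled object here are assumptions for demonstration, not the notebook's documented output; adapt them to what your run of the notebook actually produces.

```python
import pickle
import tempfile

# Illustrative helper: load the chronological fold definitions from a pickle.
def load_folds(path):
    with open(path, "rb") as f:
        return pickle.load(f)

# Example with a dummy 5-fold train/valid/test structure (structure assumed).
dummy_folds = [{"train": [0, 1], "valid": [2], "test": [3]} for _ in range(5)]
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
    pickle.dump(dummy_folds, tmp)
    fold_path = tmp.name

folds = load_folds(fold_path)
assert len(folds) == 5  # one entry per chronological fold
```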
+
+ ## Finetuning the Model
+ To finetune the model:
+
+ 1. Update the `finetune_sessions` field in `barista/config/braintreebank.yaml` to the desired finetuning session.
+
+ 2. Use the following command to run finetuning:
+
+ ```bash
+ python barista/train.py
+ ```
+
+ It is important that the `braintreebank.yaml` fields, including the `experiment` field, match the config used during segment generation exactly; otherwise the metadata hash string will not match and the experiment will fail. For the chronological folds, the experiment will also fail if the pickle file described in the second step of [Generating chronological splits with increased label data](#2-generating-chronological-splits-with-increased-label-data) hasn't been generated.
+
+
+ ### Loading Pretrained Model
+
+ Pretrained models are available under `pretrained_models/`. Set `checkpoint_path` in `barista/config/train.yaml` to the desired pretrained model path, e.g. `checkpoint_path: pretrained_models/parcels_chans.ckpt`.
+
+ > ⚠️ You also need to set `tokenizer.spatial_grouping` in `barista/config/model.yaml` accordingly for each model.
+
+ | Checkpoint Name | `tokenizer.spatial_grouping` |
+ | -------------------- | ---------------------------- |
+ | `chans_chans.ckpt` | `coords` |
+ | `parcels_chans.ckpt` | `destrieux` |
+ | `lobes_chans.ckpt` | `lobes` |
+
+
+ Alternatively, you can pass these as extra arguments to the train command:
+
+ **Example finetuning command for the parcel-level model**
+ ```bash
+ python barista/train.py \
+ --override \
+ tokenizer.spatial_grouping="destrieux" \
+ checkpoint_path="pretrained_models/parcels_chans.ckpt"
+ ```
+
+ ## Additional Scripts
+
+ You can also use the scripts under `barista/utility_scripts` to run the model for a specific setting across different finetuning seeds.
+ The run outputs are saved in the results directory specified in the script and can be aggregated with `aggregate_runs.py` across different subjects, models, and folds.
+
+ **Example usage for random splits**
+ ```bash
+ ./barista/utility_scripts/run_finetune_random_splits.sh \
+ --spe destrieux \
+ --checkpoint "pretrained_models/parcels_chans.ckpt" \
+ --session HOLDSUBJ_1_HS1_1 \
+ --gpu 0 \
+ --exp sentence_onset
+ ```
+
+
+ **Example usage for a chronological fold**
+ ```bash
+ ./barista/utility_scripts/run_finetune_folds.sh \
+ --spe destrieux \
+ --checkpoint "pretrained_models/parcels_chans.ckpt" \
+ --session HOLDSUBJ_1_HS1_1 \
+ --gpu 0 \
+ --fold 0 \
+ --exp sentence_onset_time
+ ```
+
+ ### Aggregating Results
+
+ You can use `barista/utility_scripts/aggregate_runs.py` to get the average results as a markdown table:
+
+ ```bash
+ python barista/utility_scripts/aggregate_runs.py \
+ --results_dir <results|results_folds>
+ ```
+
+
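For intuition, the core of what such aggregation amounts to can be sketched as follows. The per-run dictionaries, the `accuracy` metric name, and the table format here are assumptions for illustration, not the actual `aggregate_runs.py` interface.

```python
import statistics

# Hypothetical per-seed run results (metric name "accuracy" is assumed).
runs = [{"accuracy": 0.81}, {"accuracy": 0.79}, {"accuracy": 0.83}]

# Average the metric across seeds.
mean_acc = statistics.mean(r["accuracy"] for r in runs)

# Render a one-row markdown table, similar in spirit to the script's output.
table = f"| model | mean accuracy |\n|---|---|\n| parcels_chans | {mean_acc:.3f} |"
print(table)
```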
+ ## Publication
+ [Oganesian, L. L.\*, Hashemi, S.\*, Shanechi, M. M. BaRISTA: Brain Scale Informed Spatiotemporal Representation of Human Intracranial Neural Activity. In Advances in Neural Information Processing Systems, 2025.](https://openreview.net/forum?id=LDjBDk3Czb)
+
+
+ **Citation**
+ ```
+ @inproceedings{
+ oganesian2025barista,
+ title={BaRISTA: Brain Scale Informed Spatiotemporal Representation of Human Intracranial Neural Activity},
+ author={Oganesian, Lucine L. and Hashemi, Saba and Shanechi, Maryam M.},
+ booktitle={Advances in Neural Information Processing Systems},
+ year={2025},
+ url={https://openreview.net/pdf?id=LDjBDk3Czb}
+ }
+ ```
+
+
+ ## License
+ Copyright (c) 2025 University of Southern California <br />
+ See full notice in [LICENSE.md](LICENSE.md) <br />
+ Lucine L. Oganesian, Saba Hashemi, and Maryam M. Shanechi <br />
+ Shanechi Lab, University of Southern California
+
+
barista/config/braintreebank.yaml ADDED
@@ -0,0 +1,117 @@
+ ## Directory where the raw data exists.
+ dataset_dir: "braintreebank_raw"
+ ## Directory where to save the preprocessed data.
+ save_dir: "braintreebank_data_segments"
+ ## Directory where to store cached stage 1 preprocessed data (i.e., filtered, rereferenced) to then segment.
+ stage1_cache_dir: "braintreebank_processed_raw_cache"
+
+ samp_frequency: 2048 # in Hz. Default: 2048.
+ segment_length_s: 3
+ region_filtering:
+   active: True
+   # Use region names that partially match the Destrieux column in the
+   # localization file to exclude channels.
+   filters:
+     - GRID
+     - VENT
+
+ aggregate_labels:
+   nan_threshold: 1 # value between 0 and 1; drop segments with more than this fraction of NaNs
+   type: threshold # threshold | mean
+   threshold: 0.5
+
+ quantile_numerical_labels:
+   active: True
+   lower_threshold: 0.25
+   higher_threshold: 0.75
+
+ force_balanced: True
+ force_nonoverlap: True
+
+ ## NOTE: val_ratio and test_ratio only used for shuffle & random splits.
+ val_ratio: 0.1
+ test_ratio: 0.1
+
+ ## NOTE: run_ratios only used for chronological splits; use val_ratio and test_ratio in
+ ## dataset/single/base.yaml for shuffle & random splits.
+ run_ratios: [0.8, 0.1, 0.1]
+ run_splits: ["train", "val", "test"]
+ chron_fold_num: 0 # Chronological fold number to use. Default is ratios & splits in config.
+
+ ## This is the step size used when generating negative sample segments for sentence_onset*
+ ## and speech_vs_nonspeech* tasks.
+ nonword_stepsize_s: # leave empty for no nonword overlap (i.e., step = segment length)
+
+ trial_alignment: center # center only supported for now. Can extend to other alignments as desired.
+ subjects_to_process: # list of which subjects to process; leave empty to run for all available
+   # - SUBJ_1
+   # - SUBJ_2
+   # - SUBJ_3
+   # - SUBJ_4
+   # - SUBJ_5
+   # - SUBJ_6
+   # - SUBJ_7
+   # - SUBJ_8
+   # - SUBJ_9
+   # - SUBJ_10
+   - HOLDSUBJ_1
+   - HOLDSUBJ_2
+   - HOLDSUBJ_3
+   - HOLDSUBJ_4
+   - HOLDSUBJ_6
+   - HOLDSUBJ_7
+   - HOLDSUBJ_10
+
+ # Options:
+ # "speech_vs_nonspeech" | "sentence_onset" [random split]
+ # "sentence_onset_time" | "speech_vs_nonspeech_time" | "volume" | "optical_flow" [chronological split]
+ experiment: "sentence_onset_time"
+
+ ### Dataset processing
+ skip_segment_generation_completely: False
+ force_reprocess_stage1: False
+ force_reprocess_stage2: False
+ force_recreate_spatial_groupings: False
+
+ processing_save_interval: 100 # save files every # of segments
+ processing_log_interval: 50
+
+ use_fixed_seed_for_splitter: True
+ split_together_length_s: 3 # Note: recommended to use the same value as segment_length_s above
+
+ shuffle_dataloader: True
+
+ # Note: recommendation is to use the full subject_session label here.
+ pretrain_sessions:
+   - SUBJ_1_S1_0
+   # - SUBJ_1_S1_2
+   # - SUBJ_2_S2_0
+   # - SUBJ_2_S2_1
+   # - SUBJ_2_S2_2
+   # - SUBJ_2_S2_3
+   # - SUBJ_2_S2_4
+   # - SUBJ_3_S3_1
+   # - SUBJ_3_S3_2
+   # - SUBJ_4_S4_1
+   # - SUBJ_5_S5_0
+   # - SUBJ_6_S6_0
+   # - SUBJ_6_S6_1
+   # - SUBJ_7_S7_1
+   # - SUBJ_8_S8_0
+   # - SUBJ_9_S9_0
+   # - SUBJ_10_S10_1
+ finetune_sessions:
+   # - SUBJ_2_S2_5 # Pseudo held out
+   # - SUBJ_4_S4_2 # Pseudo held out
+   - HOLDSUBJ_1_HS1_1
+   # - HOLDSUBJ_2_HS2_6
+   # - HOLDSUBJ_3_HS3_0
+   # - HOLDSUBJ_4_HS4_0
+   # - HOLDSUBJ_6_HS6_4
+   # - HOLDSUBJ_7_HS7_0
+   # - HOLDSUBJ_10_HS10_0
+
+ spatial_groupings_to_create:
+   - coords
+   - destrieux
+   - lobes
barista/config/model.yaml ADDED
@@ -0,0 +1,36 @@
+ backbone:
+   num_layers: 12
+   d_hidden: 64
+   d_input: ${backbone.d_hidden} # same as d_hidden
+   d_out: ${backbone.d_hidden} # same as d_hidden
+   mlp_ratio: 4
+   norm: rmsnorm
+   norm_eps: 1e-8
+   activation: gelu
+   num_heads: 4
+   max_position: 1024
+   dropout: 0.1
+
+ tokenizer:
+   temporal_encoder:
+     input_dims: 128
+     output_dims: 128
+     hidden_dims: 5
+     depth: 4 # Zero-indexed (will have 5 convolution blocks altogether)
+     kernel_size: 3
+     stride: 1
+     enable_checkpointing: False
+
+   temporal_subsegment_len: 512
+   temporal_subsegment_step: 512
+
+   samp_frequency: 2048
+   num_seconds: 3
+
+   d_hidden: ${backbone.d_input}
+
+   add_spatial_encoding: True
+   spatial_grouping: destrieux # coords | destrieux | lobes
+
+   embedding_max_dim: # leave empty for no normalization of embeddings
+   embedding_init_scale: 1.0
barista/config/train.yaml ADDED
@@ -0,0 +1,15 @@
+ seed: 0
+ checkpoint_path: "pretrained_models/parcels_chans.ckpt"
+ device: cuda:0
+ epochs: 30
+ dataloader:
+   drop_last: False
+   drop_last_val: False
+   num_workers: 16
+   batch_size: 128
+   persistent_workers: False
+   pin_memory: True
+ optimization:
+   finetune_lr: 1e-4
+   new_param_lr: 1e-3
+   freeze_tokenizer: True
barista/data/atlas.py ADDED
@@ -0,0 +1,251 @@
+ """Enums for the various spatial scales explored.
+
+ Useful references for the atlas parcels:
+ https://pmc.ncbi.nlm.nih.gov/articles/PMC2937159/pdf/nihms213933.pdf
+ https://surfer.nmr.mgh.harvard.edu/pub/articles/HBM09-Destrieux-Sulcal.pdf
+
+ Useful references for mapping atlas parcels to lobes (see below):
+ https://surfer.nmr.mgh.harvard.edu/fswiki/CorticalParcellation
+ https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2012.00171/full#h12
+ """
+ import enum
+
+ UNKNOWN_STR = "UNKNOWN"
+
+
+ class EnumWithUnknown(enum.Enum):
+     @classmethod
+     def get_enum(cls, value):
+         value = (value or UNKNOWN_STR).upper()
+         try:
+             return cls[value]
+         except KeyError as e:
+             raise NotImplementedError(
+                 f"Unknown value '{value}' for enum {cls.__name__}"
+             ) from e
+
+
+ class Destrieux(EnumWithUnknown):
+     UNKNOWN = 0
+     LEFT_AMYGDALA = 1
+     LEFT_HIPPOCAMPUS = 2
+     LEFT_INF_LAT_VENT = 3
+     LEFT_PUTAMEN = 4
+     RIGHT_AMYGDALA = 5
+     RIGHT_HIPPOCAMPUS = 6
+     RIGHT_INF_LAT_VENT = 7
+     RIGHT_PUTAMEN = 8
+     CTX_LH_G_INS_LG_AND_S_CENT_INS = 9
+     CTX_LH_G_AND_S_CINGUL_ANT = 10
+     CTX_LH_G_AND_S_CINGUL_MID_ANT = 11
+     CTX_LH_G_AND_S_CINGUL_MID_POST = 12
+     CTX_LH_G_AND_S_SUBCENTRAL = 13
+     CTX_LH_G_CINGUL_POST_DORSAL = 14
+     CTX_LH_G_FRONT_INF_OPERCULAR = 15
+     CTX_LH_G_FRONT_INF_ORBITAL = 16
+     CTX_LH_G_FRONT_INF_TRIANGUL = 17
+     CTX_LH_G_FRONT_MIDDLE = 18
+     CTX_LH_G_FRONT_SUP = 19
+     CTX_LH_G_INSULAR_SHORT = 20
+     CTX_LH_G_OC_TEMP_MED_PARAHIP = 21
+     CTX_LH_G_OCCIPITAL_MIDDLE = 22
+     CTX_LH_G_ORBITAL = 23
+     CTX_LH_G_PARIET_INF_ANGULAR = 24
+     CTX_LH_G_PARIET_INF_SUPRAMAR = 25
+     CTX_LH_G_PARIETAL_SUP = 26
+     CTX_LH_G_POSTCENTRAL = 27
+     CTX_LH_G_PRECENTRAL = 28
+     CTX_LH_G_PRECUNEUS = 29
+     CTX_LH_G_RECTUS = 30
+     CTX_LH_G_TEMP_SUP_G_T_TRANSV = 31
+     CTX_LH_G_TEMP_SUP_LATERAL = 32
+     CTX_LH_G_TEMP_SUP_PLAN_POLAR = 33
+     CTX_LH_G_TEMP_SUP_PLAN_TEMPO = 34
+     CTX_LH_G_TEMPORAL_INF = 35
+     CTX_LH_G_TEMPORAL_MIDDLE = 36
+     CTX_LH_LAT_FIS_ANT_HORIZONT = 37
+     CTX_LH_LAT_FIS_ANT_VERTICAL = 38
+     CTX_LH_LAT_FIS_POST = 39
+     CTX_LH_POLE_TEMPORAL = 40
+     CTX_LH_S_CALCARINE = 41
+     CTX_LH_S_CENTRAL = 42
+     CTX_LH_S_CINGUL_MARGINALIS = 43
+     CTX_LH_S_CIRCULAR_INSULA_ANT = 44
+     CTX_LH_S_CIRCULAR_INSULA_INF = 45
+     CTX_LH_S_CIRCULAR_INSULA_SUP = 46
+     CTX_LH_S_COLLAT_TRANSV_ANT = 47
+     CTX_LH_S_FRONT_INF = 48
+     CTX_LH_S_FRONT_MIDDLE = 49
+     CTX_LH_S_FRONT_SUP = 50
+     CTX_LH_S_INTRAPARIET_AND_P_TRANS = 51
+     CTX_LH_S_OC_TEMP_MED_AND_LINGUAL = 52
+     CTX_LH_S_ORBITAL_H_SHAPED = 53
+     CTX_LH_S_ORBITAL_LATERAL = 54
+     CTX_LH_S_ORBITAL_MED_OLFACT = 55
+     CTX_LH_S_PARIETO_OCCIPITAL = 56
+     CTX_LH_S_PERICALLOSAL = 57
+     CTX_LH_S_POSTCENTRAL = 58
+     CTX_LH_S_PRECENTRAL_INF_PART = 59
+     CTX_LH_S_PRECENTRAL_SUP_PART = 60
+     CTX_LH_S_SUBORBITAL = 61
+     CTX_LH_S_SUBPARIETAL = 62
+     CTX_LH_S_TEMPORAL_INF = 63
+     CTX_LH_S_TEMPORAL_SUP = 64
+     CTX_LH_S_TEMPORAL_TRANSVERSE = 65
+     CTX_RH_G_INS_LG_AND_S_CENT_INS = 66
+     CTX_RH_G_AND_S_CINGUL_ANT = 67
+     CTX_RH_G_AND_S_CINGUL_MID_ANT = 68
+     CTX_RH_G_AND_S_CINGUL_MID_POST = 69
+     CTX_RH_G_AND_S_FRONTOMARGIN = 70
+     CTX_RH_G_AND_S_PARACENTRAL = 71
+     CTX_RH_G_AND_S_SUBCENTRAL = 72
+     CTX_RH_G_CINGUL_POST_DORSAL = 73
+     CTX_RH_G_FRONT_INF_OPERCULAR = 74
+     CTX_RH_G_FRONT_INF_ORBITAL = 75
+     CTX_RH_G_FRONT_INF_TRIANGUL = 76
+     CTX_RH_G_FRONT_MIDDLE = 77
+     CTX_RH_G_FRONT_SUP = 78
+     CTX_RH_G_INSULAR_SHORT = 79
+     CTX_RH_G_OC_TEMP_LAT_FUSIFOR = 80
+     CTX_RH_G_OC_TEMP_MED_PARAHIP = 81
+     CTX_RH_G_ORBITAL = 82
+     CTX_RH_G_PARIET_INF_ANGULAR = 83
+     CTX_RH_G_PARIET_INF_SUPRAMAR = 84
+     CTX_RH_G_PRECENTRAL = 85
+     CTX_RH_G_RECTUS = 86
+     CTX_RH_G_TEMP_SUP_G_T_TRANSV = 87
+     CTX_RH_G_TEMP_SUP_LATERAL = 88
+     CTX_RH_G_TEMP_SUP_PLAN_POLAR = 89
+     CTX_RH_G_TEMP_SUP_PLAN_TEMPO = 90
+     CTX_RH_G_TEMPORAL_INF = 91
+     CTX_RH_G_TEMPORAL_MIDDLE = 92
+     CTX_RH_LAT_FIS_ANT_HORIZONT = 93
+     CTX_RH_LAT_FIS_ANT_VERTICAL = 94
+     CTX_RH_LAT_FIS_POST = 95
+     CTX_RH_POLE_TEMPORAL = 96
+     CTX_RH_S_CENTRAL = 97
+     CTX_RH_S_CINGUL_MARGINALIS = 98
+     CTX_RH_S_CIRCULAR_INSULA_ANT = 99
+     CTX_RH_S_CIRCULAR_INSULA_INF = 100
+     CTX_RH_S_CIRCULAR_INSULA_SUP = 101
+     CTX_RH_S_COLLAT_TRANSV_ANT = 102
+     CTX_RH_S_FRONT_INF = 103
+     CTX_RH_S_FRONT_MIDDLE = 104
+     CTX_RH_S_FRONT_SUP = 105
+     CTX_RH_S_INTRAPARIET_AND_P_TRANS = 106
+     CTX_RH_S_OC_TEMP_LAT = 107
+     CTX_RH_S_OC_TEMP_MED_AND_LINGUAL = 108
+     CTX_RH_S_ORBITAL_H_SHAPED = 109
+     CTX_RH_S_ORBITAL_LATERAL = 110
+     CTX_RH_S_ORBITAL_MED_OLFACT = 111
+     CTX_RH_S_PERICALLOSAL = 112
+     CTX_RH_S_POSTCENTRAL = 113
+     CTX_RH_S_PRECENTRAL_INF_PART = 114
+     CTX_RH_S_PRECENTRAL_SUP_PART = 115
+     CTX_RH_S_SUBORBITAL = 116
+     CTX_RH_S_SUBPARIETAL = 117
+     CTX_RH_S_TEMPORAL_INF = 118
+     CTX_RH_S_TEMPORAL_SUP = 119
+     CTX_RH_S_TEMPORAL_TRANSVERSE = 120
+
+
+ class Lobes(EnumWithUnknown):
+     """Maps the Desikan-Killiany Atlas regions to lobes."""
+     UNKNOWN = 0
+
+     ## Amygdala (Left, Right)
+     LEFT_AMYGDALA = 1
+     RIGHT_AMYGDALA = 2
+
+     ## Hippocampus (Left, Right)
+     LEFT_HIPPOCAMPUS = 3
+     RIGHT_HIPPOCAMPUS = 4
+
+     ## Frontal Lobe (Left)
+     CTX_LH_SUPERIORFRONTAL = 5
+     CTX_LH_ROSTRALMIDDLEFRONTAL = 5
+     CTX_LH_CAUDALMIDDLEFRONTAL = 5
+     CTX_LH_PARSOPERCULARIS = 5
+     CTX_LH_PARSORBITALIS = 5
+     CTX_LH_PARSTRIANGULARIS = 5
+     CTX_LH_LATERALORBITOFRONTAL = 5
+     CTX_LH_MEDIALORBITOFRONTAL = 5
+     CTX_LH_PRECENTRAL = 5
+     CTX_LH_PARACENTRAL = 5
+
+     ## Frontal Lobe (Right)
+     CTX_RH_SUPERIORFRONTAL = 6
+     CTX_RH_ROSTRALMIDDLEFRONTAL = 6
+     CTX_RH_CAUDALMIDDLEFRONTAL = 6
+     CTX_RH_PARSOPERCULARIS = 6
+     CTX_RH_PARSORBITALIS = 6
+     CTX_RH_PARSTRIANGULARIS = 6
+     CTX_RH_LATERALORBITOFRONTAL = 6
+     CTX_RH_MEDIALORBITOFRONTAL = 6
+     CTX_RH_PRECENTRAL = 6
+     CTX_RH_PARACENTRAL = 6
+     # Frontal pole should go here in the future
+
+     ## Parietal Lobe (Left)
+     CTX_LH_SUPERIORPARIETAL = 7
+     CTX_LH_INFERIORPARIETAL = 7
+     CTX_LH_SUPRAMARGINAL = 7
+     CTX_LH_POSTCENTRAL = 7
+     CTX_LH_PRECUNEUS = 7
+
+     ## Parietal Lobe (Right)
+     CTX_RH_SUPERIORPARIETAL = 8
+     CTX_RH_INFERIORPARIETAL = 8
+     CTX_RH_SUPRAMARGINAL = 8
+     CTX_RH_POSTCENTRAL = 8
+     CTX_RH_PRECUNEUS = 8
+
+     ## Temporal Lobe (Left)
+     CTX_LH_SUPERIORTEMPORAL = 9
+     CTX_LH_MIDDLETEMPORAL = 9
+     CTX_LH_INFERIORTEMPORAL = 9
+     CTX_LH_BANKSSTS = 9
+     CTX_LH_FUSIFORM = 9
+     CTX_LH_TRANSVERSETEMPORAL = 9
+     CTX_LH_ENTORHINAL = 9
+     CTX_LH_TEMPORALPOLE = 9
+     CTX_LH_PARAHIPPOCAMPAL = 9
+
+     ## Temporal Lobe (Right)
+     CTX_RH_SUPERIORTEMPORAL = 10
+     CTX_RH_MIDDLETEMPORAL = 10
+     CTX_RH_INFERIORTEMPORAL = 10
+     CTX_RH_BANKSSTS = 10
+     CTX_RH_FUSIFORM = 10
+     CTX_RH_TRANSVERSETEMPORAL = 10
+     CTX_RH_ENTORHINAL = 10
+     CTX_RH_TEMPORALPOLE = 10
+     CTX_RH_PARAHIPPOCAMPAL = 10
+
+     ## Occipital Lobe (Left) - ENUM 11 RESERVED
+
+     ## Occipital Lobe (Right) - ENUM 12 RESERVED
+
+     ## Cingulate (Left)
+     CTX_LH_ROSTRALANTERIORCINGULATE = 13
+     CTX_LH_CAUDALANTERIORCINGULATE = 13
+     CTX_LH_POSTERIORCINGULATE = 13
+     CTX_LH_ISTHMUSCINGULATE = 13
+
+     ## Cingulate (Right)
+     CTX_RH_ROSTRALANTERIORCINGULATE = 14
+     CTX_RH_CAUDALANTERIORCINGULATE = 14
+     CTX_RH_POSTERIORCINGULATE = 14
+     CTX_RH_ISTHMUSCINGULATE = 14
+
+     ## Insula (Left, Right)
+     CTX_LH_INSULA = 15
+     CTX_RH_INSULA = 16
+
+     ## Putamen (Left, Right)
+     LEFT_PUTAMEN = 17
+     RIGHT_PUTAMEN = 18
+
+     ## Ventricles (Left, Right)
+     LEFT_INF_LAT_VENT = 19
+     RIGHT_INF_LAT_VENT = 20
barista/data/available_sessions.py ADDED
@@ -0,0 +1,28 @@
+ from enum import Enum
+
+ def enumval_formatter(subject, trial_list):
+     return [f"S{subject}_{trial}" for trial in trial_list]
+
+ def holdout_enumval_formatter(subject, trial_list):
+     return [f"HS{subject}_{trial}" for trial in trial_list]
+
+ class BrainTreebankAvailableSessions(Enum):
+     SUBJ_1: list = enumval_formatter("1", ["0", "2"])
+     SUBJ_2: list = enumval_formatter("2", ["0", "1", "2", "3", "4", "5"])
+     SUBJ_3: list = enumval_formatter("3", ["1", "2"])
+     SUBJ_4: list = enumval_formatter("4", ["1", "2"])
+     SUBJ_5: list = enumval_formatter("5", ["0"])
+     SUBJ_6: list = enumval_formatter("6", ["0", "1"])
+     SUBJ_7: list = enumval_formatter("7", ["1"])
+     SUBJ_8: list = enumval_formatter("8", ["0"])
+     SUBJ_9: list = enumval_formatter("9", ["0"])
+     SUBJ_10: list = enumval_formatter("10", ["1"])
+
+     ## Heldout trials.
+     HOLDSUBJ_1: list = holdout_enumval_formatter("1", ["1"])
+     HOLDSUBJ_2: list = holdout_enumval_formatter("2", ["6"])
+     HOLDSUBJ_3: list = holdout_enumval_formatter("3", ["0"])
+     HOLDSUBJ_4: list = holdout_enumval_formatter("4", ["0"])
+     HOLDSUBJ_6: list = holdout_enumval_formatter("6", ["4"])
+     HOLDSUBJ_7: list = holdout_enumval_formatter("7", ["0"])
+     HOLDSUBJ_10: list = holdout_enumval_formatter("10", ["0"])
barista/data/braintreebank_data_helpers.py ADDED
@@ -0,0 +1,741 @@
1
+ """Code to handle data I/O, parsing, and data/feature preprocessing for the BrainTreebank dataset.
2
+
3
+ Functionality in this module is based on the implementations found in the following
4
+ repositories, but has been modified as needed to be used as outlined in the BaRISTA paper:
5
+ https://github.com/czlwang/BrainBERT/tree/master/data
6
+ https://github.com/czlwang/PopulationTransformer/tree/main/data
7
+ https://github.com/czlwang/brain_treebank_code_release/tree/master/data
8
+ """
9
+
10
+ import json
11
+ import os
12
+ from collections import OrderedDict
13
+ from enum import Enum
14
+ from typing import Dict, List, Union
15
+
16
+ import h5py
17
+ import numpy as np
18
+ import ordered_set
19
+ import pandas as pd
20
+ import scipy
21
+ import sklearn.preprocessing as sk_preprocessing
22
+
23
+ # Data frame column IDs for *_timings.csv and features.csv files.
24
+ _START_COL = "start"
25
+ _END_COL = "end"
26
+ _LBL_COL = "pos"
27
+ _TRIG_TIME_COL = "movie_time"
28
+ _START_WALLTIME = "start_time"
29
+ _TRIG_IDX_COL = "index"
30
+ _EST_IDX_COL = "est_idx"
31
+ _EST_END_IDX_COL = "est_end_idx"
32
+ _WORD_TIME_COL = "word_time"
33
+ _WORD_TEXT_COL = "text"
34
+ _IS_ONSET_COL = "is_onset"
35
+ _IS_OFFSET_COL = "is_offset"
36
+
37
+ # Data frame column IDs for the elec_coords_full.csv file.
38
+ _ELECTRODE_INFO = "Electrode"
39
+
40
+
41
+ class BrainTreebankDatasetNames(Enum):
42
+ PRETRAIN = "pretrain"
43
+
44
+ ## Random splits downstream tasks.
45
+ SENTENCE_ONSET = "sentence_onset"
46
+ SPEECH_VS_NONSPEECH = "speech_vs_nonspeech"
47
+
48
+ ## Chronological split downstream tasks.
49
+ SENTENCE_ONSET_TIME = "sentence_onset_time"
50
+ SPEECH_VS_NONSPEECH_TIME = "speech_vs_nonspeech_time"
51
+ VOLUME = "volume"
52
+ OPTICAL_FLOW = "optical_flow"
53
+
54
+ @classmethod
55
+ def get_modes(cls, modes_str: Union[str, List[str]]):
56
+ if isinstance(modes_str, str):
57
+ return cls(modes_str)
58
+ else:
59
+ modes = [cls(mode_str) for mode_str in modes_str]
60
+ return modes
61
+
62
+ def get_abbrv(self, c=1) -> str:
63
+ return "".join([b[:c] for b in self.value.split("_")])
64
+
65
+
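As a quick sanity check, the `get_abbrv` one-liner above can be pulled out as a standalone function (a sketch; `value` here stands in for the enum's `self.value`):

```python
# Standalone sketch of BrainTreebankDatasetNames.get_abbrv: abbreviate each
# underscore-separated word of the enum value to its first c characters.
def get_abbrv(value: str, c: int = 1) -> str:
    return "".join(b[:c] for b in value.split("_"))

print(get_abbrv("sentence_onset"))          # -> "so"
print(get_abbrv("speech_vs_nonspeech", 2))  # -> "spvsno"
```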
66
+ class BrainTreebankDatasetPathManager:
67
+ """Manage file paths for Brain Treebank dataset
68
+
69
+ Expected dataset directory structure:
70
+ braintreebank_data
71
+ |__corrupted_elec.json
72
+ |__clean_laplacian.json
73
+ |__all_subject_data
74
+ | |__ sub_1_trial000.h5
75
+ | |__ sub_1_trial001.h5
76
+ | |__ sub_1_trial002.h5
77
+ | |__ sub_2_trial000.h5
78
+ | |
79
+ | ...
80
+ |
81
+ |__ electrode_labels
82
+ | |__ sub_1
83
+ | | |__ electrode_labels.json
84
+ | |__ sub_2
85
+ | | |__ electrode_labels.json
86
+ | ...
87
+ |
88
+ |__ localization
89
+ | |__ elec_coords_full.csv
90
+ | |__ sub_1
91
+ | | |__ depth-wm.csv
92
+ | |__ sub_2
93
+ | | |__ depth-wm.csv
94
+ | ...
95
+ |
96
+ |__ subject_metadata
97
+ | |__ sub_1_trial000_metadata.json
98
+ | |__ sub_1_trial001_metadata.json
99
+ | |__ sub_1_trial002_metadata.json
100
+ | |__ sub_2_trial000_metadata.json
101
+ | |
102
+ | ...
103
+ |
104
+ |__ subject_timings
105
+ | |__ sub_1_trial000_timings.csv
106
+ | |__ sub_1_trial001_timings.csv
107
+ | |__ sub_1_trial002_timings.csv
108
+ | |__ sub_2_trial000_timings.csv
109
+ | |
110
+ | ...
111
+ |
112
+ |__ transcripts
113
+ | |__ ant-man
114
+ | | |__ features.csv
115
+ | |__ aquaman
116
+ | | |__ features.csv
117
+ | ......
118
+ """
119
+
120
+ def __init__(self, dataset_dir: str):
121
+ self.dataset_dir = dataset_dir
122
+
123
+ # Path to neural data h5 file.
124
+ self.neural_data_file = os.path.join(
125
+ self.dataset_dir,
126
+ "all_subject_data",
127
+ "sub_{}_trial00{}.h5",
128
+ )
129
+
130
+ # Path to electrode channel name meta information.
131
+ self.raw_electrodes_meta_file = os.path.join(
132
+ self.dataset_dir, "electrode_labels", "sub_{}", "electrode_labels.json"
133
+ )
134
+
135
+ # Path to brain regions csv file.
136
+ self.regions_file = os.path.join(
137
+ self.dataset_dir, "localization", "sub_{}", "depth-wm.csv"
138
+ )
139
+
140
+ # Path to trial movie trigger times to align features with neural activity.
141
+ self.movie_triggers_file = os.path.join(
142
+ self.dataset_dir, "subject_timings", "sub_{}_trial00{}_timings.csv"
143
+ )
144
+
145
+ # Path to trial meta information.
146
+ self.trial_meta = os.path.join(
147
+ self.dataset_dir, "subject_metadata", "sub_{}_trial00{}_metadata.json"
148
+ )
149
+
150
+ # Path to extracted features csv file.
151
+ self.features_file = os.path.join(
152
+ self.dataset_dir, "transcripts", "{}", "features.csv"
153
+ )
154
+
155
+ self._CORRUPTED_ELECTRODES_PATH = os.path.join(
156
+ self.dataset_dir, "corrupted_elec.json"
157
+ )
158
+ self._CLEAN_LAPLACIAN = os.path.join(
159
+ self.dataset_dir, "clean_laplacian.json"
160
+ )
161
+
162
+ def format_subject(self, subject: str) -> str:
163
+ """AvailableSessions stores subjects as SUBJ_#. Strips 'SUBJ' prefix here."""
164
+ return subject.split("_")[-1]
165
+
166
+ def format_session(self, session: str) -> str:
167
+ """AvailableSessions stores subject sessions with a prefix as (H)S_#. Strips prefix here."""
168
+ return session.split("_")[-1]
169
+
170
+ def get_raw_data_filepath(self, subject: str, session: str) -> str:
171
+ """Get raw data file path for a given subject and trial.
172
+
173
+ Args:
174
+ subject: subject str, e.g. "1"
175
+ session: trial str, e.g. "0"
176
+ """
177
+ return self.neural_data_file.format(
178
+ self.format_subject(subject), self.format_session(session)
179
+ )
180
+
181
+ def get_raw_electrode_channel_names_filepath(self, subject: str) -> str:
182
+ return self.raw_electrodes_meta_file.format(self.format_subject(subject))
183
+
184
+ def get_localization_filepath(self, subject: str) -> str:
185
+ return self.regions_file.format(self.format_subject(subject))
186
+
187
+ def get_noise_area_filepath(self) -> str:
188
+ return self._CORRUPTED_ELECTRODES_PATH
189
+
190
+ def get_clean_laplacian_filepath(self) -> str:
191
+ return self._CLEAN_LAPLACIAN
192
+
193
+ def get_movie_triggers_filepath(self, subject: str, trial: str) -> str:
194
+ return self.movie_triggers_file.format(
195
+ self.format_subject(subject), self.format_session(trial)
196
+ )
197
+
198
+ def get_features_filepath(self, subject: str, trial: str) -> tuple:
199
+ with open(
200
+ self.trial_meta.format(
201
+ self.format_subject(subject), self.format_session(trial)
202
+ ),
203
+ "r",
204
+ ) as f:
205
+ meta_dict = json.load(f)
206
+ title = meta_dict["title"]
207
+ movie_id = meta_dict["filename"]
208
+
209
+ print(f"Loading features for movie {title}.")
210
+ return self.features_file.format(movie_id), title
211
+
212
+
213
+ class BrainTreebankDatasetRawDataHelper:
214
+ """Manages loading data from the BrainTreebank dataset files.
215
+
216
+ Check each method docstring for file information.
217
+ """
218
+ def __init__(
219
+ self,
220
+ path_manager: BrainTreebankDatasetPathManager,
221
+ samp_frequency: int = 2048,
222
+ ):
223
+ self.path_manager = path_manager
224
+ self.samp_frequency = samp_frequency
225
+ self.localization_df = {}
226
+ self.trial_triggers_cache = {}
227
+
228
+ def get_raw_file(
229
+ self,
230
+ subject: str,
231
+ trial: str,
232
+ ) -> dict:
233
+ """Loads raw neural recordings and channel metadata for a subject/trial from its h5 file.
234
+
235
+ Args:
236
+ subject: str or int. Subject to index by.
237
+ trial: str or int. Subject trial to index by.
238
+
239
+ Returns:
240
+ A dictionary containing following keys:
241
+ data: np.ndarray (n_samples x channels) -- actual recordings
242
+ time: np.ndarray (n_samples) -- wall clock timestamps at samples where movie triggers were recorded (NaN elsewhere)
243
+ samp_frequency: sampling rate Hz
244
+ electrode_info: list of channel names, indices are in order of columns in data
245
+ """
246
+ path = self.path_manager.get_raw_data_filepath(subject, trial)
247
+ with h5py.File(path, "r") as hf:
248
+ raw_data = hf["data"]
249
+
250
+ channel_labels = self.get_electrode_info(subject)
251
+
252
+ raw_data_n_channels = len(raw_data.keys())
253
+ if subject == "SUBJ_1" or subject == "HOLDSUBJ_1":
254
+ raw_data_n_channels -= 1 # Will ignore last channel for subject 1 based on dataset author's comment
255
+ assert (
256
+ len(channel_labels) == raw_data_n_channels
257
+ ), "Channel count mismatch between h5 and json."
258
+
259
+ # Extracts a numpy array from h5 dataset (may take a few minutes).
260
+ electrode_data = []
261
+ for i in range(len(channel_labels)):
262
+ electrode_data.append(raw_data[f"electrode_{i}"][:])
263
+
264
+ electrode_data = np.stack(electrode_data)
265
+
266
+ return {
267
+ "data": electrode_data.T, # n_samples x n_channels
268
+ "time": self._extract_neural_timestamps(subject, trial, electrode_data),
269
+ "samp_frequency": self.samp_frequency,
270
+ "electrode_info": channel_labels,
271
+ }
272
+
273
+ def get_corrupted_elecs(self, subject: str) -> List[str]:
274
+ """
275
+ Returns:
276
+ a list of strings corresponding to corrupted electrode channel names.
277
+ """
278
+ with open(self.path_manager.get_noise_area_filepath(), "r") as f:
279
+ corrupted_elecs = json.load(f)
280
+ return corrupted_elecs[f"subject{self.path_manager.format_subject(subject)}"]
281
+
282
+ def get_clean_elecs(self, subject: str) -> List[str]:
283
+ """
284
+ Returns:
285
+ a list of strings corresponding to clean electrode channel names.
286
+ """
287
+ with open(self.path_manager.get_clean_laplacian_filepath(), "r") as f:
288
+ elecs = json.load(f)
289
+ return elecs[f"sub_{self.path_manager.format_subject(subject)}"]
290
+
291
+ def _elec_name_strip(self, x):
292
+ return x.replace("*", "").replace("#", "").replace("_", "")
293
+
294
+ def get_electrode_info(self, subject: str) -> List[str]:
295
+ """
296
+ Returns list of electrodes for the specified trial.
297
+ NOTE: the order of these labels is important. Their position corresponds with a row in data.h5
298
+ """
299
+ with open(
300
+ self.path_manager.get_raw_electrode_channel_names_filepath(subject), "r"
301
+ ) as f:
302
+ electrode_labels = json.load(f)
303
+
304
+ electrode_labels = [self._elec_name_strip(e) for e in electrode_labels]
305
+ return electrode_labels
306
+
307
+ def get_channel_localization_raw(self, subject: str) -> pd.DataFrame:
308
+ # Lazy loading.
309
+ if subject not in self.localization_df:
310
+ df = pd.read_csv(self.path_manager.get_localization_filepath(subject))
311
+ df[_ELECTRODE_INFO] = df[_ELECTRODE_INFO].apply(self._elec_name_strip)
312
+ self.localization_df[subject] = df
313
+ return self.localization_df[subject]
314
+
315
+ def get_channel_localization(
316
+ self, subject: str, channel_name: str
317
+ ) -> dict:
318
+ """Extract localization information for given subject and channel label.
319
+
320
+ Channel localization info is a pandas DataFrame with the headers:
321
+ ID: electrode channel ID
322
+ Z: Z coordinate (subject specific, to the best of our understanding)
323
+ X: X coordinate (subject specific, to the best of our understanding)
324
+ Y: Y coordinate (subject specific, to the best of our understanding)
325
+ Hemisphere: 0 (right) vs 1 (left)
326
+ Subject: sub_<id>
327
+ Electrode: Electrode channel label
328
+ Region: region based on Destrieux atlas
329
+
330
+ NOTE: https://surfer.nmr.mgh.harvard.edu/fswiki/CorticalParcellation
331
+
332
+ Returns:
333
+ Dictionary with the following keys:
334
+ hemi: hemisphere
335
+ region_info: Destrieux parcel info
336
+ channel_stem: electrode name
337
+ coords: LIP coords
338
+ """
339
+ df = self.get_channel_localization_raw(subject)
340
+ channel_row = df.loc[df[_ELECTRODE_INFO] == channel_name]
341
+
342
+ if len(channel_row) == 0:
343
+ return {}
344
+
345
+ def parse_region_str(region_str):
346
+ if "_" in region_str:
347
+ split_region_str = region_str.split("_")
348
+ hemi = "L" if split_region_str[1].lower() == "lh" else "R"
349
+ region_info = "_".join(split_region_str[2:])
350
+ elif "-" in region_str and "_" not in region_str:
351
+ split_region_str = region_str.split("-")
352
+ hemi = "L" if split_region_str[0].lower() == "left" else "R"
353
+ region_info = split_region_str[-1]
354
+ elif region_str.lower() == "unknown":
355
+ hemi = "UNKNOWN"
356
+ region_info = "UNKNOWN"
357
+ else:
358
+ raise ValueError(f"Unsupported region_str: {region_str}.")
359
+ return hemi, region_info
360
+
361
+ hemi, region_info = parse_region_str(channel_row.iloc[0]["Destrieux"])
362
+ channel_stem, _ = BrainTreebankDatasetRawDataHelper.stem_electrode_name(
363
+ channel_name
364
+ )
365
+ coords = channel_row.iloc[0][["L", "I", "P"]].to_numpy().astype(np.int64)
366
+ return {
367
+ "hemi": hemi,
368
+ "region_info": region_info,
369
+ "channel_stem": channel_stem,
370
+ "coords": coords,
371
+ }
372
+
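The nested `parse_region_str` above handles two label formats plus an unknown case. A standalone mirror of that logic, exercised on hypothetical label strings in each format (the example labels are illustrative, not taken from the dataset):

```python
def parse_region_str(region_str: str):
    """Mirror of the nested parse_region_str above (standalone sketch)."""
    if "_" in region_str:
        # Format: <prefix>_<lh|rh>_<region parts...>
        parts = region_str.split("_")
        hemi = "L" if parts[1].lower() == "lh" else "R"
        region_info = "_".join(parts[2:])
    elif "-" in region_str:
        # Format: <Left|Right>-<region>
        parts = region_str.split("-")
        hemi = "L" if parts[0].lower() == "left" else "R"
        region_info = parts[-1]
    elif region_str.lower() == "unknown":
        hemi = region_info = "UNKNOWN"
    else:
        raise ValueError(f"Unsupported region_str: {region_str}.")
    return hemi, region_info

print(parse_region_str("ctx_lh_G_temp_sup"))  # -> ("L", "G_temp_sup")
print(parse_region_str("Right-Amygdala"))     # -> ("R", "Amygdala")
```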
373
+ @classmethod
374
+ def stem_electrode_name(cls, name):
375
+ """Stems an electrode channel name into (probe stem, contact number) to find neighbors.
376
+
377
+ Functionality from the BrainBERT repository:
378
+ https://github.com/czlwang/BrainBERT/tree/master/data
379
+ """
380
+ # names look like 'O1aIb4', 'O1aIb5', 'O1aIb6', 'O1aIb7'
381
+ # names look like 'T1b2'
382
+ name = name.replace("*", "") # some stems have * in name
383
+ found_stem_end = False
384
+ stem, num = [], []
385
+ for c in reversed(name):
386
+ if c.isalpha():
387
+ found_stem_end = True
388
+ if found_stem_end:
389
+ stem.append(c)
390
+ else:
391
+ num.append(c)
392
+ return "".join(reversed(stem)), int("".join(reversed(num)))
393
+
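Pulled out as a standalone function, the stemming logic above splits off the trailing contact number while keeping interior digits in the stem:

```python
def stem_electrode_name(name: str):
    # Split a channel name like "O1aIb4" into its probe stem and contact number.
    # Walk from the right: digits before the first alphabetic character are the
    # contact number; everything from that character leftward is the stem.
    name = name.replace("*", "")
    found_stem_end = False
    stem, num = [], []
    for c in reversed(name):
        if c.isalpha():
            found_stem_end = True
        if found_stem_end:
            stem.append(c)
        else:
            num.append(c)
    return "".join(reversed(stem)), int("".join(reversed(num)))

print(stem_electrode_name("O1aIb4"))  # -> ("O1aIb", 4)
print(stem_electrode_name("T1b12"))   # -> ("T1b", 12)
```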
394
+ @classmethod
395
+ def get_all_laplacian_electrodes(cls, elec_list):
396
+ """Select for channels that have neighbors needed for Laplacian rereferencing.
397
+
398
+ Functionality from the BrainBERT repository:
399
+ https://github.com/czlwang/BrainBERT/tree/master/data
400
+ """
401
+ stems = [
402
+ BrainTreebankDatasetRawDataHelper.stem_electrode_name(e) for e in elec_list
403
+ ]
404
+
405
+ def has_nbrs(stem, stems):
406
+ (x, y) = stem
407
+ return ((x, y + 1) in stems) and ((x, y - 1) in stems)
408
+
409
+ laplacian_stems = [x for x in stems if has_nbrs(x, stems)]
410
+ electrodes = [f"{x}{y}" for (x, y) in laplacian_stems]
411
+ return electrodes
412
+
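A small worked example of the neighbor selection above: only contacts with both an immediate predecessor and successor on the same probe survive, since Laplacian rereferencing needs a neighbor on each side. (The stemmer here is a simplified sketch that only strips trailing digits; the electrode names are made up.)

```python
def stem(name):
    # Simplified stemmer: trailing digits are the contact number.
    i = len(name)
    while i > 0 and name[i - 1].isdigit():
        i -= 1
    return name[:i], int(name[i:])

def laplacian_electrodes(elec_list):
    stems = [stem(e) for e in elec_list]

    def has_nbrs(s):
        (x, y) = s
        return (x, y + 1) in stems and (x, y - 1) in stems

    return [f"{x}{y}" for (x, y) in stems if has_nbrs((x, y))]

# Only the middle contact of the T1b run has both neighbors:
print(laplacian_electrodes(["T1b1", "T1b2", "T1b3", "F2a1"]))  # -> ["T1b2"]
```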
413
+ def _get_trial_triggers(self, subject: str, trial: str) -> pd.DataFrame:
414
+ """
415
+ Returns:
416
+ a pandas DataFrame with the following column headers:
417
+ type: trigger type
418
+ movie_time: movie time at which trigger was sent
419
+ start_time: wall clock time at which trigger was sent
420
+ end_time: wall clock time at which trigger concluded
421
+ trig_type: type of trigger token sent (movie beginning/end/pause/unpause)
422
+ index: neural data samples that recorded the beginning of the trigger
423
+ diff: ??
424
+ """
425
+ movie_triggers_fpath = self.path_manager.get_movie_triggers_filepath(
426
+ subject, trial
427
+ )
428
+ triggers_cache_key = os.path.basename(movie_triggers_fpath)
429
+ # Use lazy loading of movie triggers to save on compute in the future.
430
+ if triggers_cache_key in self.trial_triggers_cache:
431
+ df = self.trial_triggers_cache[triggers_cache_key]
432
+ else:
433
+ df = pd.read_csv(movie_triggers_fpath)
434
+ self.trial_triggers_cache[triggers_cache_key] = df
435
+ return df
436
+
437
+ def _get_trial_features(self, subject: str, trial: str) -> pd.DataFrame:
438
+ """
439
+ Returns:
440
+ a pandas DataFrame with the following column headers:
441
+ 'bin_head',
442
+ 'charecter_num',
443
+ 'delta_magnitude',
444
+ 'delta_mel',
445
+ 'delta_pitch',
446
+ 'delta_rms',
447
+ 'deprel',
448
+ 'end',
449
+ 'est_idx', = estimated first neural sample
450
+ 'est_end_idx', = estimated last neural sample
451
+ 'face_num',
452
+ 'gpt2_surprisal',
453
+ 'head',
454
+ 'idx_in_sentence',
455
+ 'is_onset',
456
+ 'lemma',
457
+ 'magnitude',
458
+ 'max_global_angle',
459
+ 'max_global_magnitude',
460
+ 'max_mean_magnitude',
461
+ 'max_mean_pixel_brightness',
462
+ 'max_mean_pixel_difference',
463
+ 'max_median_magnitude',
464
+ 'max_vector_angle',
465
+ 'max_vector_magnitude',
466
+ 'mean_pixel_brightness',
467
+ 'mel',
468
+ 'min_mean_pixel_brightness',
469
+ 'min_mean_pixel_difference',
470
+ 'onset_diff',
471
+ 'phoneme_num',
472
+ 'pitch',
473
+ 'pos',
474
+ 'prev_word_idx',
475
+ 'rms',
476
+ 'sentence',
477
+ 'sentence_idx',
478
+ 'speaker',
479
+ 'start',
480
+ 'syllable',
481
+ 'text',
482
+ 'word_diff',
483
+ 'word_idx',
484
+ 'word_length'
485
+
486
+ See dataset technical paper for full explanation: https://braintreebank.dev/.
487
+ """
488
+ features_filename, movie_title = self.path_manager.get_features_filepath(
489
+ subject, trial
490
+ )
491
+
492
+ df = pd.read_csv(features_filename).set_index("Unnamed: 0")
493
+ df = df.dropna().reset_index(drop=True) # Drop rows with NaN word times.
494
+ trig_df = self._get_trial_triggers(subject, trial)
495
+ df = self._add_estimated_sample_index(df, trig_df)
496
+ df = df.dropna().reset_index(drop=True) # Drop rows with NaN sample times.
497
+ return df
498
+
499
+ def get_features(
500
+ self, subject: str, trial: str, feature_name: str, n_samples: int
501
+ ) -> np.ndarray:
502
+ df = self._get_trial_features(subject, trial)
503
+
504
+ if feature_name == "volume":
505
+ feature_vals = df.rms
506
+ elif (
507
+ feature_name == "sentence_onset"
508
+ or feature_name == "sentence_onset_time"
509
+ ):
510
+ feature_vals = df.is_onset
511
+ elif (
512
+ feature_name == "speech_vs_nonspeech"
513
+ or feature_name == "speech_vs_nonspeech_time"
514
+ ):
515
+ feature_vals = np.ones(len(df))  # One positive label per word row.
516
+ elif feature_name == "optical_flow":
517
+ feature_vals = df.max_global_magnitude
518
+ else:
519
+ raise ValueError(f"Unsupported feature_name: {feature_name}")
520
+
521
+ label_intervals = list(zip(df[_EST_IDX_COL].array, df[_EST_END_IDX_COL].array))
522
+ label_init = lambda x: (
523
+ 0
524
+ if x
525
+ in [
526
+ "speech_vs_nonspeech",
527
+ "speech_vs_nonspeech_time",
528
+ "sentence_onset",
529
+ "sentence_onset_time",
530
+ ]
531
+ else np.nan
532
+ )
533
+ labels = np.ones(n_samples) * label_init(feature_name)
534
+ for label_ind, label_interval in enumerate(label_intervals):
535
+ if feature_name != "sentence_onset" and feature_name != "sentence_onset_time":
536
+ labels[int(label_interval[0]) : int(label_interval[1])] = feature_vals[
537
+ label_ind
538
+ ]
539
+ else:
540
+ # sentence_onset has to only handle putting labels for onset words
541
+ labels[int(label_interval[0]) : int(label_interval[1])] = (
542
+ 1 if feature_vals[label_ind] else np.nan
543
+ )
544
+
545
+ return labels, label_intervals
546
+
547
+ def _estimate_sample_index(self, t, near_t, near_trig):
548
+ """Estimates the word onset data sample by interpolation from nearest trigger.
549
+
550
+ Source:
551
+ quickstart.ipynb notebook on https://braintreebank.dev/
552
+
553
+ Args:
554
+ t - word movie time
555
+ near_t - nearest trigger movie time
556
+ near_trig - nearest trigger sample index
557
+
558
+ Returns:
559
+ Estimated word onset sample index.
560
+ """
561
+ trig_diff = (t - near_t) * self.samp_frequency
562
+ return round(near_trig + trig_diff)
563
+
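The interpolation above is a one-liner worth seeing with numbers. A standalone sketch at the module's 2048 Hz sampling rate (the trigger values are hypothetical):

```python
SAMP_FREQUENCY = 2048  # Hz, the dataset's sampling rate used throughout this module.

def estimate_sample_index(t, near_t, near_trig, fs=SAMP_FREQUENCY):
    # Linear interpolation: offset from the nearest trigger in seconds,
    # converted to samples and added to that trigger's sample index.
    return round(near_trig + (t - near_t) * fs)

# A word 0.5 s after a trigger recorded at sample 20480:
print(estimate_sample_index(10.5, 10.0, 20480))  # -> 21504
```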
564
+ def _add_estimated_sample_index(self, w_df, t_df):
565
+ """Computes and adds data sample indices to annotated movie word onsets.
566
+
567
+ Source:
568
+ quickstart.ipynb notebook on https://braintreebank.dev/
569
+
570
+ Args:
571
+ w_df - movie annotated words data frame
572
+ t_df - computer triggers data frame
573
+
574
+ Returns:
575
+ Movie annotated words data frame augmented with estimated data sample indices
576
+ """
577
+ tmp_w_df = w_df.copy(deep=True)
578
+ last_t = t_df.loc[len(t_df) - 1, _TRIG_TIME_COL]
579
+ for i, t, endt in zip(w_df.index, w_df[_START_COL], w_df[_END_COL]):
580
+ if t > last_t: # If movie continues after triggers
581
+ break
582
+
583
+ # Find nearest movie time index for start.
584
+ idx = (abs(t_df[_TRIG_TIME_COL] - t)).idxmin()
585
+ tmp_w_df.loc[i, :] = w_df.loc[i, :]
586
+ tmp_w_df.loc[i, _EST_IDX_COL] = self._estimate_sample_index(
587
+ t, t_df.loc[idx, _TRIG_TIME_COL], t_df.loc[idx, _TRIG_IDX_COL]
588
+ )
589
+
590
+ # Find nearest movie time index for end.
591
+ end_idx = (abs(t_df[_TRIG_TIME_COL] - endt)).idxmin()
592
+ tmp_w_df.loc[i, _EST_END_IDX_COL] = self._estimate_sample_index(
593
+ endt,
594
+ t_df.loc[end_idx, _TRIG_TIME_COL],
595
+ t_df.loc[end_idx, _TRIG_IDX_COL],
596
+ )
597
+
598
+ return tmp_w_df
599
+
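End to end, the alignment above amounts to: for each word time, find the trigger nearest in movie time and interpolate a sample index from it. A minimal sketch on tiny hypothetical trigger/word tables (column names match the timings and transcript files):

```python
import pandas as pd

FS = 2048  # sampling rate (Hz)

trig = pd.DataFrame({"movie_time": [10.0, 20.0, 30.0], "index": [20480, 40960, 61440]})
words = pd.DataFrame({"start": [10.5, 19.9], "end": [10.9, 20.3]})

est_idx = []
for t in words["start"]:
    # Snap to the trigger nearest in movie time, then interpolate in samples.
    i = (trig["movie_time"] - t).abs().idxmin()
    est_idx.append(round(trig.loc[i, "index"] + (t - trig.loc[i, "movie_time"]) * FS))
words["est_idx"] = est_idx

print(words["est_idx"].tolist())  # -> [21504, 40755]
```

Note that the second word (19.9 s) snaps to the 20.0 s trigger and interpolates backwards, which is why the estimate can be below that trigger's sample index.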
600
+ def _extract_neural_timestamps(self, subject: str, trial: str, data: np.ndarray):
601
+ """Extracts wall clock timestamps associated with recorded triggers.
602
+
603
+ NOTE: Not all samples will have a timestamp.
604
+ """
605
+ t_df = self._get_trial_triggers(subject, trial)
606
+ timestamps = np.ones(data.shape[-1]) * np.nan
607
+ for sample_index, sample_walltime in zip(
608
+ t_df[_TRIG_IDX_COL], t_df[_START_WALLTIME]
609
+ ):
610
+ timestamps[int(sample_index)] = sample_walltime
611
+ return timestamps
612
+
613
+
614
+ class BrainTreebankDatasetPreprocessor:
615
+ """Helper class to preprocess the raw BrainTreebank neural data.
616
+
617
+ Recommended flow:
618
+ filter_data -> rereference
619
+
620
+ filter_data() currently performs:
621
+ notch filtering
622
+
623
+ Functionality partially utilizes implementations from the BrainBERT repository:
624
+ https://github.com/czlwang/BrainBERT/tree/master/data
625
+ """
626
+
627
+ def __init__(self, config: Dict):
628
+ self.config = config
629
+
630
+ # For notch filtering.
631
+ self.freqs_to_filter = [60, 120, 180, 240, 300, 360]
632
+
633
+ def notch_filter(self, data: np.ndarray, freq: float, Q: int = 30) -> np.ndarray:
634
+ """Notch filters input data along time axis.
635
+
636
+ Args:
637
+ data: np.ndarray shape (n_channels, n_samples)
638
+
639
+ Returns filtered signal.
640
+ """
641
+ w0 = freq / (self.config.samp_frequency / 2)
642
+ b, a = scipy.signal.iirnotch(w0, Q)
643
+ y = scipy.signal.lfilter(b, a, data, axis=-1)
644
+ return y
645
+
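A quick numeric check of the notch filter: a pure 60 Hz sinusoid at the dataset's 2048 Hz rate should be strongly attenuated once the filter transient has decayed (a sketch, not part of the pipeline):

```python
import numpy as np
import scipy.signal

fs = 2048.0
t = np.arange(0, 2.0, 1.0 / fs)
x = np.sin(2 * np.pi * 60 * t)  # pure 60 Hz line-noise component

# Same construction as notch_filter() above: normalized frequency, Q = 30.
b, a = scipy.signal.iirnotch(60.0 / (fs / 2), Q=30)
y = scipy.signal.lfilter(b, a, x)

# After the filter settles, the 60 Hz component is strongly attenuated:
print(np.abs(y[-512:]).max() < 0.1 * np.abs(x[-512:]).max())  # -> True
```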
646
+ def filter_data(self, data_arr: np.ndarray):
647
+ """Filters data based on provided config.
648
+
649
+ Args:
650
+ data: np.ndarray shape (n_channels, n_samples)
651
+
652
+ Returns filtered signal.
653
+ """
654
+ for f in self.freqs_to_filter:
655
+ data_arr = self.notch_filter(data_arr, f)
656
+ return data_arr
657
+
658
+ def _get_all_adj_electrodes(
659
+ self, selected_electrodes: List[str], all_electrodes: List[str]
660
+ ):
661
+ """Extracts all adjacent electrodes to use with Laplacian rereferencing."""
662
+ all_electrode_stems = [
663
+ BrainTreebankDatasetRawDataHelper.stem_electrode_name(l)
664
+ for l in all_electrodes
665
+ ]
666
+
667
+ elec2neighbors_dict, unique_neighbors = OrderedDict(), ordered_set.OrderedSet()
668
+ for selected_electrode in selected_electrodes:
669
+ stem, num = BrainTreebankDatasetRawDataHelper.stem_electrode_name(
670
+ selected_electrode
671
+ )
672
+ nbrs = [
673
+ n
674
+ for n in [(stem, num - 1), (stem, num + 1)]
675
+ if n in all_electrode_stems
676
+ ]
677
+
678
+ assert len(nbrs) == 2, "Neighbors must be 2 for Laplacian rereferencing."
679
+
680
+ elec2neighbors_dict[selected_electrode] = [
681
+ e_stem + str(num_stem) for (e_stem, num_stem) in nbrs
682
+ ]
683
+ unique_neighbors.update(elec2neighbors_dict[selected_electrode])
684
+
685
+ neighbor_label2id = {
686
+ elec: all_electrodes.index(elec) for elec in unique_neighbors
687
+ }
688
+ return elec2neighbors_dict, neighbor_label2id
689
+
690
+ def _laplacian_rereference(
691
+ self,
692
+ selected_data: np.ndarray,
693
+ selected_electrodes: List[str],
694
+ all_data: np.ndarray,
695
+ all_electrodes: List[str],
696
+ ):
697
+ """
698
+ Args:
699
+ selected_data: np.ndarray shape (n_selected_channels, n_samples), corresponding
700
+ to the selected electrodes.
701
+ selected_electrodes: List[str], labels corresponding to selected electrodes
702
+ (e.g., "clean" electrodes).
703
+ all_data: np.ndarray shape (n_total_channels, n_samples).
704
+ all_electrodes: List[str], labels corresponding to all electrodes.
705
+ """
706
+ elec2neighbors_dict, neighbor_label2id = self._get_all_adj_electrodes(
707
+ selected_electrodes, all_electrodes
708
+ )
709
+
710
+ selected_neighbor_data = [
711
+ [
712
+ all_data[neighbor_label2id[nghbr_elec], ...]
713
+ for nghbr_elec in elec2neighbors_dict[elec]
714
+ ]
715
+ for elec in selected_electrodes
716
+ ]
717
+ selected_neighbor_data = np.array(selected_neighbor_data)
718
+ selected_neighbor_data = self.filter_data(selected_neighbor_data)
719
+
720
+ assert selected_data.shape == (
721
+ selected_neighbor_data.shape[0],
722
+ selected_neighbor_data.shape[-1],
723
+ )
724
+ ref_data = selected_data - np.mean(selected_neighbor_data, axis=1)
725
+ return ref_data
726
+
727
+ def rereference_data(self, **rereference_kwargs) -> np.ndarray:
728
+ """Rereferences electrode data based on provided reference electrodes.
729
+
730
+ Check _laplacian_rereference() above for required arguments.
731
+ """
732
+ data = self._laplacian_rereference(**rereference_kwargs)
733
+ return data
734
+
735
+ def zscore_data(self, data: np.ndarray) -> np.ndarray:
736
+ data = (
737
+ sk_preprocessing.StandardScaler(with_mean=True, with_std=True)
738
+ .fit_transform(data.T)
739
+ .T
740
+ )
741
+ return data
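`zscore_data` transposes around `StandardScaler` because the scaler normalizes per column, while the data is (n_channels, n_samples). A small self-contained check on random data (a sketch, independent of the class above):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=2.0, size=(4, 1000))  # (n_channels, n_samples)

# Transpose so each channel is a column, z-score, transpose back --
# the same double-transpose pattern as zscore_data() above.
z = StandardScaler(with_mean=True, with_std=True).fit_transform(data.T).T

print(np.allclose(z.mean(axis=1), 0.0, atol=1e-9))  # -> True
print(np.allclose(z.std(axis=1), 1.0, atol=1e-9))   # -> True
```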
barista/data/braintreebank_dataset.py ADDED
@@ -0,0 +1,230 @@
1
+ from collections import OrderedDict, defaultdict, namedtuple
2
+ from copy import deepcopy
3
+ from typing import List, Optional, Union
4
+
5
+ import pandas as pd
6
+ import torch
7
+ from barista.data.braintreebank_wrapper import BrainTreebankWrapper
8
+ from omegaconf import DictConfig, OmegaConf
9
+ from torch.utils.data import DataLoader, Dataset
10
+
11
+ DatapointMetadata = namedtuple(
12
+ "Metadata",
13
+ ["subject_session", "subject"],
14
+ )
15
+
16
+ DataPoint = namedtuple(
17
+ "DataPoint",
18
+ ["x", "label", "metadata"],
19
+ defaults=(None,) * 3
20
+ )
21
+
22
+ BatchItem = namedtuple(
23
+ "BatchItem",
24
+ [
25
+ "x",
26
+ "labels",
27
+ "subject_sessions",
28
+ ],
29
+ )
30
+
31
+ torch_version = torch.__version__.split("+")[0]
32
+
33
+
34
+ class BrainTreebankDataset(Dataset):
35
+ def __init__(
36
+ self,
37
+ config: Union[OmegaConf, DictConfig],
38
+ max_cache_size: int = 5000,
39
+ include_subject_sessions: Optional[List[str]] = [],
40
+ exclude_subject_sessions: Optional[List[str]] = [],
41
+ ):
42
+ """BrainTreebank Dataset class.
43
+
44
+ Args:
45
+ config: OmegaConf or DictConfig.
46
+ max_cache_size: int. The segment cache size to use to avoid
47
+ reloading segments.
48
+ include_subject_sessions: Optional list of str corresponding to
49
+ the subject_sessions to keep/use in the dataset
50
+ exclude_subject_sessions: Optional list of str corresponding to
51
+ the subject_sessions to discard/not use in the dataset.
52
+ """
53
+ self.config = config
54
+
55
+ self.dataset = BrainTreebankWrapper(config)
56
+ self.metadata = self.dataset.metadata
57
+ if self.config.get("shuffle_dataloader", True):
58
+ print("Shuffling metadata.")
59
+ self.metadata.shuffle()
60
+
61
+ if not include_subject_sessions:
62
+ print(
63
+ f"Including only finetune sessions specified in config: {config.finetune_sessions}"
64
+ )
65
+ include_subject_sessions = list(config.finetune_sessions)
66
+
67
+ self._reduce_metadata(
68
+ subject_sessions=include_subject_sessions,
69
+ keep=True
70
+ )
71
+
72
+ if exclude_subject_sessions:
73
+ self._reduce_metadata(
74
+ subject_sessions=exclude_subject_sessions,
75
+ keep=False
76
+ )
77
+
78
+ self.max_cache_size = max_cache_size
79
+ self.data_cache = OrderedDict()
80
+
81
+ def check_no_common_segment(self, train_dataset, val_dataset, test_dataset):
82
+ """Double checking paths for no overlap in splits."""
83
+ train_paths = set(train_dataset.dataset.metadata.get_unique_values_in_col("path"))
84
+ val_paths = set(val_dataset.dataset.metadata.get_unique_values_in_col("path"))
85
+ test_paths = set(test_dataset.dataset.metadata.get_unique_values_in_col("path"))
86
+
87
+ assert not train_paths.intersection(test_paths)
88
+ assert not train_paths.intersection(val_paths)
89
+ assert not val_paths.intersection(test_paths)
90
+
91
+ def _reduce_metadata(self, subject_sessions: List[str], keep=True):
92
+ """Reduce metadata by either keeping OR discarding the specified subject_sessions.
93
+
94
+ Args:
95
+ subject_sessions: list of str corresponding to subject session identifiers.
96
+ keep: bool. If true, keep the specified subject sessions, otherwise discard.
97
+ """
98
+ if not isinstance(subject_sessions, list):
99
+ subject_sessions = [subject_sessions]
100
+
101
+ combined_pattern = "|".join(subject_sessions)
102
+
103
+ self.metadata.reduce_based_on_col_value(
104
+ col_name="subject_session",
105
+ value=combined_pattern,
106
+ regex=True,
107
+ keep=keep,
108
+ )
109
+
110
+ summary_str = self.metadata.get_summary_str()
111
+ print(f"Reduced dataset: {summary_str}")
112
+
113
+ def set_split(self, split: str):
114
+ self.metadata.reduce_based_on_col_value(col_name="split", value=split)
115
+
116
+ def get_dataloader(self, split: str, train_config: Union[DictConfig, OmegaConf]):
117
+ split_dataset = deepcopy(self)
118
+ split_dataset.set_split(split=split)
119
+
120
+ if split == "test":
121
+ # Don't drop any samples for test for consistency across different batch size.
122
+ drop_last = False
123
+ elif split == "train":
124
+ drop_last = train_config.dataloader.drop_last
125
+ else: # split == "val"
126
+ drop_last = train_config.dataloader.get(
127
+ "drop_last_val",
128
+ train_config.dataloader.drop_last
129
+ )
130
+
131
+ return DataLoader(
132
+ split_dataset,
133
+ batch_size=train_config.dataloader.batch_size,
134
+ collate_fn=collate_with_metadata_fn_group_subjects,
135
+ num_workers=train_config.dataloader.num_workers,
136
+ persistent_workers=train_config.dataloader.persistent_workers,
137
+ pin_memory=train_config.dataloader.pin_memory,
138
+ drop_last=drop_last,
139
+ )
140
+
141
+ def __len__(self):
142
+ return len(self.metadata)
143
+
144
+ def __getitem__(self, idx):
145
+ meta_row = self.metadata[idx]
146
+ segment_path = meta_row["path"]
147
+
148
+ if segment_path not in self.data_cache:
149
+ data_file = torch.load(
150
+ segment_path, weights_only=(torch_version > "2.2.1")
151
+ )
152
+ if len(self.data_cache) >= self.max_cache_size:
153
+ first_path = next(iter(self.data_cache))
154
+ self.data_cache.pop(first_path)
155
+ self.data_cache[segment_path] = data_file
156
+
157
+ else:
158
+ data_file = self.data_cache[segment_path]
159
+
160
+ metadata = DatapointMetadata(
161
+ subject_session=meta_row.subject_session,
162
+ subject=meta_row.subject,
163
+ )
164
+
165
+ if "label" in meta_row and not pd.isna(meta_row.label):
166
+ label = torch.tensor((meta_row.label,))
167
+ else:
168
+ label = data_file[meta_row.experiment]
169
+ if label is None:
170
+ raise ValueError("Label cannot be None in the data_file.")
171
+
172
+ datapoint = DataPoint(
173
+ x=data_file["x"],
174
+ label=label,
175
+ metadata=metadata,
176
+ )
177
+ return datapoint
178
+
179
+
180
+ def collate_with_metadata_fn_group_subjects(batch: List[DataPoint]):
181
+ """Returns a list of batched tensors, each for one session."""
182
+ x, labels, subject_sessions = (
183
+ [],
184
+ [],
185
+ [],
186
+ )
187
+ x_dims, labels_dims = [], []
188
+ x_seq_lens, labels_seq_lens = [], []
189
+
190
+ x_dict = defaultdict(list)
191
+ for i, datapoint in enumerate(batch):
192
+ ss = datapoint.metadata.subject_session
193
+ x_dict[ss].append(i)
194
+
195
+ for sub_sesh_list in x_dict.values():
196
+ sub_sesh_x = []
197
+ for i in sub_sesh_list:
198
+ datapoint = batch[i]
199
+
200
+ # Skip all zero sessions
201
+ if torch.all(datapoint.x == 0):
202
+ continue
203
+
204
+ sub_sesh_x.append(datapoint.x)
205
+ labels.append(datapoint.label)
206
+
207
+ subject_sessions.append(datapoint.metadata.subject_session)
208
+
209
+ x_dims.append(datapoint.x.shape[-1])
210
+ labels_dims.append(datapoint.label.shape[-1])
211
+
212
+ x_seq_lens.append(datapoint.x.shape[0])
213
+ labels_seq_lens.append(datapoint.label.shape[0])
214
+
215
+ if sub_sesh_x:
216
+ sub_sesh_x = torch.stack(sub_sesh_x, dim=0)
217
+ x.append(sub_sesh_x)
218
+
219
+
220
+ if (torch.tensor(labels_dims) == labels_dims[0]).all() and (
221
+ torch.tensor(labels_seq_lens) == labels_seq_lens[0]
222
+ ).all():
223
+ labels = torch.stack(labels, dim=0)
224
+
225
+ batch = BatchItem(
226
+ x=x,
227
+ labels=labels,
228
+ subject_sessions=subject_sessions,
229
+ )
230
+ return batch
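The collate function above first buckets datapoint indices by their `subject_session` key before stacking. A minimal standalone sketch of that bucketing step (the helper name `group_indices_by_session` is illustrative, not part of the repository):

```python
from collections import defaultdict

def group_indices_by_session(sessions):
    """Map each subject_session key to the indices of datapoints that share it."""
    groups = defaultdict(list)
    for i, ss in enumerate(sessions):
        groups[ss].append(i)
    return dict(groups)

# Example: four datapoints from two sessions
print(group_indices_by_session(["s1", "s2", "s1", "s1"]))
# → {'s1': [0, 2, 3], 's2': [1]}
```

Each bucket is then stacked into one tensor per session, which is why the collated `x` is a list of tensors rather than a single batch tensor.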
barista/data/braintreebank_dataset_spatial_groupings.py ADDED
@@ -0,0 +1,149 @@
+from typing import List, Tuple
+
+import pandas as pd
+
+import barista.data.atlas as atlas_enums
+from barista.data.metadata_spatial_groups import (
+    MetadataSpatialGroupRow,
+    SpatialGroupingName,
+)
+
+XYZ_MAX = 200
+
+class BrainTreebankSpatialGroupingsHelper:
+    """
+    Helper class to generate spatial group rows.
+
+    New spatial groupings should be added here.
+    """
+
+    def __init__(self, config, dataset_name: str):
+        self.config = config
+        self.dataset_name = dataset_name
+
+    def get_spatial_groupings(
+        self,
+        subject: str,
+        session: str,
+        coords: List[Tuple],
+        localization: pd.DataFrame,
+    ) -> List[MetadataSpatialGroupRow]:
+        rows = []
+        for spatial_grouping in self.config.spatial_groupings_to_create:
+            sg = SpatialGroupingName(spatial_grouping)
+            if sg == SpatialGroupingName.COORDS:
+                group_components = coords
+                n_effective_components = 3
+                max_elements_for_component = (XYZ_MAX, XYZ_MAX, XYZ_MAX)
+                padding_indices = (None, None, None)
+
+            elif sg == SpatialGroupingName.DESTRIEUX:
+                (
+                    group_components,
+                    n_effective_components,
+                    max_elements_for_component,
+                    padding_indices,
+                ) = self._get_grouping_based_on_loc_file(
+                    subject=subject,
+                    coords=coords,
+                    localization=localization,
+                    localization_col="Destrieux",
+                    enum_class=atlas_enums.Destrieux,
+                )
+
+            elif sg == SpatialGroupingName.LOBES:
+                (
+                    group_components,
+                    n_effective_components,
+                    max_elements_for_component,
+                    padding_indices,
+                ) = self._get_grouping_based_on_loc_file(
+                    subject=subject,
+                    coords=coords,
+                    localization=localization,
+                    localization_col="DesikanKilliany",
+                    enum_class=atlas_enums.Lobes,
+                )
+
+            else:
+                raise NotImplementedError()
+
+            group_ids = self._get_group_ids_based_on_group_components(
+                group_components, n_effective_components
+            )
+
+            assert len(max_elements_for_component) >= n_effective_components
+            assert len(padding_indices) >= n_effective_components
+
+            row = MetadataSpatialGroupRow(
+                dataset=self.dataset_name,
+                subject=subject,
+                session=session,
+                subject_session=f"{subject}_{session}",
+                name=sg.value,
+                n_effective_components=n_effective_components,
+                max_elements_for_component=max_elements_for_component,
+                padding_indices=padding_indices,
+                group_components=group_components,
+                group_ids=group_ids,
+            )
+            rows.append(row)
+        return rows
+
+    def _get_grouping_based_on_loc_file(
+        self,
+        subject: str,
+        coords: List[Tuple],
+        localization: pd.DataFrame,
+        localization_col: str,
+        enum_class,
+    ):
+        group_components = []
+        for coord in coords:
+            found = False
+
+            for i in range(len(localization)):
+                loc = localization.iloc[i]
+
+                df_coord = (loc.L, loc.I, loc.P)
+
+                if df_coord == coord:
+                    identifier_value = loc[localization_col].replace("-", "_").upper()
+                    enum_i = enum_class.get_enum(identifier_value)
+                    group_components.append((enum_i.value, identifier_value))
+                    found = True
+                    break
+
+            if not found:
+                raise ValueError(
+                    f"Channel not found in localization file for {subject}"
+                )
+
+        max_elements_for_component = (max([v.value for v in enum_class]) + 1,)
+        padding_indices = (enum_class.UNKNOWN.value,)
+        n_effective_components = 1
+
+        return (
+            group_components,
+            n_effective_components,
+            max_elements_for_component,
+            padding_indices,
+        )
+
+    def _get_group_ids_based_on_group_components(
+        self, group_components: List[Tuple], n_effective_components: int
+    ) -> List[int]:
+        groups_to_id_mapping = dict()
+        group_id = 0
+        group_ids = []
+        for components in group_components:
+            group = components[:n_effective_components]
+            if group not in groups_to_id_mapping:
+                chan_group_id = group_id
+                groups_to_id_mapping[group] = group_id
+                group_id += 1
+            else:
+                chan_group_id = groups_to_id_mapping[group]
+            group_ids.append(chan_group_id)
+
+        return group_ids
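The group-id assignment at the end of this file gives each distinct component prefix an incremental integer id, in order of first appearance. A simplified sketch of that logic (the function name `assign_group_ids` is illustrative, not part of the repository):

```python
def assign_group_ids(group_components, n_effective_components=1):
    """Assign incremental integer ids to component tuples.

    Tuples sharing the same first n_effective_components elements
    receive the same id; ids are allocated in order of first appearance.
    """
    mapping = {}
    group_ids = []
    for components in group_components:
        key = tuple(components[:n_effective_components])
        if key not in mapping:
            mapping[key] = len(mapping)  # next unused id
        group_ids.append(mapping[key])
    return group_ids

# Two channels in the same atlas region share a group id:
print(assign_group_ids([(5, "HIPPOCAMPUS"), (7, "AMYGDALA"), (5, "HIPPOCAMPUS")]))
# → [0, 1, 0]
```

With `n_effective_components=1` only the enum value is compared, so the human-readable region name tagging along in each tuple does not affect grouping.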
barista/data/braintreebank_wrapper.py ADDED
@@ -0,0 +1,1186 @@
+"""Code to handle preprocessing, segmenting and labeling the BrainTreebank dataset.
+
+Preprocessing and segmentation functionality is based on the implementations found in the
+following repositories, but has been modified as needed to be used for the evaluation scheme
+outlined in the BaRISTA paper:
+https://github.com/czlwang/BrainBERT/tree/master/data
+https://github.com/czlwang/PopulationTransformer/tree/main/data
+https://github.com/czlwang/brain_treebank_code_release/tree/master/data
+"""
+import dataclasses
+import einops
+import hashlib
+import numpy as np
+from omegaconf import DictConfig, OmegaConf
+import os
+import pandas as pd
+import pickle
+import torch
+from typing import Dict, List, Optional, Tuple, Union
+
+from barista.data.available_sessions import BrainTreebankAvailableSessions
+from barista.data.braintreebank_data_helpers import (
+    BrainTreebankDatasetNames,
+    BrainTreebankDatasetPathManager,
+    BrainTreebankDatasetPreprocessor,
+    BrainTreebankDatasetRawDataHelper,
+)
+from barista.data.braintreebank_dataset_spatial_groupings import (
+    BrainTreebankSpatialGroupingsHelper,
+)
+from barista.data.metadata import Metadata, MetadataRow, MetadataSpatialGroupRow
+from barista.data.splitter import Splitter
+from barista.data.fileprogresstracker import FileProgressTracker
+
+_DEFAULT_FS = 2048  # Hz
+
+
+torch_version = torch.__version__.split("+")[0]
+
+
+class BrainTreebankWrapper:
+    def __init__(self, config: Union[DictConfig, OmegaConf], only_segment_generation=False):
+        self.config = config
+
+        self._setup_helpers()
+
+        self.spatial_groups_helper = BrainTreebankSpatialGroupingsHelper(
+            self.config, dataset_name=self.name
+        )
+
+        # Hash string identifier corresponding to the preprocessing config used.
+        self.segments_processing_str, self.segments_processing_hash_str = (
+            self._get_segments_processing_hash(
+                segment_length_s=self.config.segment_length_s,
+            )
+        )
+
+        # Raw data processing (e.g., filtering).
+        if not self._is_raw_data_processed() or self.config.force_reprocess_stage1:
+            print(
+                "Processed raw dataset does not exist or reprocessing is enabled, processing starts."
+            )
+            self._process_raw_data()
+            print(f"Raw data processing complete: {self._processed_raw_data_dir}")
+        else:
+            print("Processed raw data exists")
+
+        # Processing of segments from processed raw data
+        os.makedirs(self._processed_segments_data_dir, exist_ok=True)
+
+        self.metadata = self._load_metadata()
+
+        # Empty the metadata since segments do not exist
+        self.metadata = self._initialize_metadata()
+
+        # Process the segments now
+        self.process_segments(only_segment_generation)
+        print(f"Segments are processed and ready to use. Metadata path: {self.metadata_path}")
+
+    @property
+    def name(self) -> str:
+        return "BrainTreebank"
+
+    @property
+    def available_sessions(self) -> Dict[str, List]:
+        return {
+            k.name: k.value
+            for k in BrainTreebankAvailableSessions
+            if not self.config.subjects_to_process
+            or k.name in self.config.subjects_to_process
+        }
+
+    @property
+    def experiment(self):
+        return self.config.experiment
+
+    @property
+    def metadata_path(self):
+        return os.path.join(
+            self.config.save_dir,
+            self.experiment,
+            f"metadata_{self.segments_processing_hash_str}.csv",
+        )
+
+    def _setup_helpers(self):
+        self.path_manager = BrainTreebankDatasetPathManager(
+            dataset_dir=self.config.dataset_dir,
+        )
+        self.raw_data_helper = BrainTreebankDatasetRawDataHelper(self.path_manager)
+        self.raw_data_preprocessor = BrainTreebankDatasetPreprocessor(self.config)
+        self.experiment_dataset_name = BrainTreebankDatasetNames.get_modes(
+            self.config.experiment
+        )
+
+        self.samp_frequency = self.config.get("samp_frequency", _DEFAULT_FS)
+        self.splitter = Splitter(
+            config=self.config,
+            subjects=list(self.available_sessions.keys()),
+            experiment=self.experiment,
+            use_fixed_seed=self.config.use_fixed_seed_for_splitter,
+        )
+
+    def _process_raw_data(self):
+        os.makedirs(self._processed_raw_data_dir, exist_ok=True)
+
+        for subject in self.available_sessions.keys():
+            print(f"Raw data processing for subject {subject} starts.")
+
+            sessions_count = len(self.available_sessions[subject])
+            for i, session in enumerate(self.available_sessions[subject]):
+                processed_file_path = self._get_processed_raw_data_file_path(
+                    subject=subject, session=session
+                )
+                if os.path.exists(processed_file_path):
+                    print(
+                        f"Skipping session {session} ({i+1}/{sessions_count}), "
+                        f"processed raw data exists in {processed_file_path}."
+                    )
+                else:
+                    print(
+                        f"Processing session {session} ({i+1}/{sessions_count})..."
+                    )
+
+                    self._process_single_session_raw_data(
+                        subject=subject, session=session
+                    )
+
+    def _process_single_session_raw_data(self, subject: str, session: str):
+        save_path = self._get_processed_raw_data_file_path(
+            subject=subject, session=session
+        )
+        cache_dir, cache_path = self._get_processed_raw_data_file_path_cache(
+            subject=subject, session=session
+        )
+
+        if not self.config.force_reprocess_stage1:
+            if os.path.isfile(save_path):
+                print(f"Skipping raw processing for {subject} {session}")
+                return
+
+            if os.path.isfile(cache_path):
+                print(
+                    f"Making symlink for raw processed file for {subject} {session}"
+                )
+                os.symlink(src=cache_path, dst=save_path)
+                return
+
+        raw_data_dict = self.raw_data_helper.get_raw_file(subject, session)
+        electrodes = raw_data_dict["electrode_info"]
+
+        ## Clean the electrodes based on corrupted channel meta information.
+        selected_electrodes = self.raw_data_helper.get_clean_elecs(subject)
+        assert len(set(selected_electrodes).intersection(set(electrodes))) == len(
+            selected_electrodes
+        )
+
+        selected_elecs_inds = [
+            i for i, e in enumerate(electrodes) if e in selected_electrodes
+        ]
+        electrode_data = raw_data_dict["data"][:, np.array(selected_elecs_inds)]
+        electrode_data = (
+            electrode_data.T
+        )  # Preprocessor requires (n_channels, n_samples)
+
+        ## Resample the data if self.samp_frequency != default_fs
+        if self.samp_frequency != _DEFAULT_FS:
+            raise NotImplementedError(
+                f"Resampling {self.name} dataset not yet supported."
+            )
+
+        ## Filter the data (e.g., notch).
+        electrode_data = self.raw_data_preprocessor.filter_data(electrode_data)
+
+        ## Do rereferencing.
+        electrode_data = self.raw_data_preprocessor.rereference_data(
+            selected_data=electrode_data,
+            selected_electrodes=selected_electrodes,
+            all_data=raw_data_dict["data"].T,
+            all_electrodes=raw_data_dict["electrode_info"],
+        )
+
+        save_dict = dict(
+            data=torch.tensor(electrode_data.T),  # (n_samples, n_channels)
+            time=torch.tensor(raw_data_dict["time"]),
+            samp_frequency=self.samp_frequency,
+            electrode_info=selected_electrodes,
+        )
+
+        try:
+            os.makedirs(cache_dir, exist_ok=True)
+            torch.save(save_dict, cache_path)
+            print(f"Raw processed file created in {cache_path}")
+            os.symlink(src=cache_path, dst=save_path)
+            print(f"Raw processed file symlink created in {save_path}")
+        except (OSError, PermissionError, FileNotFoundError):
+            torch.save(save_dict, save_path)
+            print(f"Raw processed file created in {save_path}")
+
+    def _is_raw_data_processed(self):
+        if not os.path.exists(self._processed_raw_data_dir):
+            return False
+
+        files_exist = []
+        for subject in self.available_sessions.keys():
+            for session in self.available_sessions[subject]:
+                path = self._get_processed_raw_data_file_path(
+                    subject=subject, session=session
+                )
+                files_exist.append(os.path.exists(path))
+        return np.array(files_exist).all()
+
+    def _get_file_progress_tracker_save_path(self, subject: str, session: str) -> str:
+        filename = f"{subject}_{session}_processing_status.json"
+        return os.path.join(self._processed_segments_data_dir, filename)
+
+    def _get_channels_region_info(
+        self,
+        subject: str,
+        electrode_info: List[str],
+    ) -> List[Tuple]:
+        """
+        Generate a list of channels, each including region information for the channel.
+        """
+        channels, coords, channel_inds_to_remove = [], [], []
+        for channel_ind, channel_name in enumerate(electrode_info):
+            localization_info = self.raw_data_helper.get_channel_localization(
+                subject, channel_name
+            )
+            if not localization_info:
+                raise ValueError(
+                    f"Couldn't find electrode {channel_name} for subject {subject}"
+                )
+
+            assert (
+                "coords" in localization_info
+            ), "localization_info incomplete, missing coords"
+            coord = localization_info.pop("coords")
+
+            ## Remove channels from regions specified in the config file.
+            if self.config.region_filtering.active:
+                match = False
+                for filtered_region in self.config.region_filtering.filters:
+                    component_info = localization_info['region_info']
+                    match = filtered_region.lower() in component_info.lower()
+                    if match:
+                        break
+
+                if match:
+                    channel_inds_to_remove.append(channel_ind)
+                    continue
+
+            coords.append((coord[0], coord[1], coord[2]))
+            channels.append((
+                localization_info['hemi'],
+                localization_info['region_info'],
+                localization_info['channel_stem'],
+            ))
+
+        return channels, coords, channel_inds_to_remove
+
+    def _create_spatial_groupings(
+        self, subject: str, session: str, coords: List[Tuple]
+    ):
+        localization = self.raw_data_helper.get_channel_localization_raw(subject)
+        rows = self.spatial_groups_helper.get_spatial_groupings(
+            subject,
+            session,
+            coords,
+            localization,
+        )
+        for row in rows:
+            self.metadata.add_spatial_group(row)
+            print(f"Add spatial group {row.name} for {row.subject_session}")
+
+        self.metadata.save(self.metadata_path)
+
+    def _spatial_groupings_exist_for_subject(self, subject: str, session: str):
+        for spatial_grouping in self.config.spatial_groupings_to_create:
+            sg = self.metadata.get_spatial_grouping(
+                subject_session=f"{subject}_{session}", name=spatial_grouping
+            )
+            if sg is None:
+                return False
+        return True
+
+    def _save_segment(
+        self,
+        subject: str,
+        session: str,
+        segment_data: torch.Tensor,
+        segment_time: torch.Tensor,
+        segment_labels: torch.Tensor,
+        segment_id: int,
+        segment_seq_len: int,
+        file_progress_tracker: FileProgressTracker,
+        is_last_segment: bool
+    ) -> None:
+        """Process and save one segment to file."""
+
+        segment_data = {
+            "x": segment_data.float().clone(),
+            "timestamps": segment_time.clone(),
+            self.experiment: segment_labels.clone(),
+        }
+
+        segment_label = self._get_segment_label(segment_labels)
+        segment_filename = f"{subject}_{session}_{segment_id}.pt"
+        segment_path = os.path.join(self._processed_segments_data_dir, segment_filename)
+        torch.save(segment_data, segment_path)
+
+        meta_row = MetadataRow(
+            dataset=self.name,
+            subject=subject,
+            session=session,
+            subject_session=f"{subject}_{session}",
+            experiment=self.experiment,
+            seq_len=segment_seq_len,
+            d_input=np.prod(segment_data["x"].shape),
+            d_data=segment_data["x"].shape,
+            path=segment_path,
+            split="train",
+            filename=segment_filename,
+            processing_str=self.segments_processing_str,
+            label=segment_label,
+        )
+
+        self.metadata.concat(pd.DataFrame([meta_row]))
+
+        if segment_id % self.config.processing_save_interval == 0 or is_last_segment:
+            self.metadata.save(self.metadata_path)
+            file_progress_tracker.update_last_file_ind(
+                file_ind=-1, ending_ind=-1, segment_id=segment_id
+            )
+
+    def _create_segments_for_subject_session(
+        self,
+        subject: str,
+        session: str,
+        segment_length_s: int,
+        file_progress_tracker: FileProgressTracker,
+    ) -> int:
+        """
+        Args:
+            subject: str. Subject name.
+            session: str. Session name.
+            segment_length_s: desired segment length in seconds
+            file_progress_tracker: tracker of the last segment that was processed
+
+        Returns:
+            Number of newly added segments.
+        """
+        processed_raw_data_path = self._get_processed_raw_data_file_path(
+            subject=subject, session=session
+        )
+        preprocessed_data_dict = torch.load(processed_raw_data_path, weights_only=False)
+
+        data = preprocessed_data_dict["data"].T  # (n_channels, n_samples)
+
+        electrode_names = preprocessed_data_dict["electrode_info"]
+        channels, coords, channel_inds_to_remove = self._get_channels_region_info(
+            subject, electrode_names
+        )
+        assert len(electrode_names) - len(channel_inds_to_remove) == len(channels)
+
+        if channel_inds_to_remove:  # Channels and coords already have these indices removed.
+            print(
+                f"Dropping {len(channel_inds_to_remove)} channels out of {len(electrode_names)} due to region filtering."
+            )
+            channels_to_keep = np.delete(
+                np.arange(data.shape[0]), channel_inds_to_remove
+            )
+            data = data[channels_to_keep, ...]
+            electrode_names = [
+                electrode_names[i]
+                for i in range(len(electrode_names))
+                if i not in channel_inds_to_remove
+            ]
+
+        assert data.shape[0] == len(channels)
+
+        self._create_spatial_groupings(subject, session, coords)
+
+        if (
+            file_progress_tracker.is_completed()
+            and not self.config.force_reprocess_stage2
+        ):
+            return 0
+
+        # Segment the neural activity data into segments of segment_length_s seconds.
+        n_steps_in_one_segment = int(self.samp_frequency * segment_length_s)
+        data, labels, data_sample_indices = self._get_experiment_data_and_labels(
+            subject,
+            session,
+            data,
+            n_steps_in_one_segment,
+            time=preprocessed_data_dict["time"],
+            samp_frequency=preprocessed_data_dict["samp_frequency"],
+            electrode_info=preprocessed_data_dict["electrode_info"],
+        )
+
+        # Get the file index of previously processed files
+        _, _, last_segment_id = file_progress_tracker.get_last_file_ind()
+
+        print(
+            f"{last_segment_id+1} segment(s) already processed for subject {subject} session {session}."
+        )
+
+        for segment_ind in range(last_segment_id + 1, data.shape[0]):
+            segment_data = data[segment_ind, ...]  # (n_channels, segment_len)
+            segment_label = labels[segment_ind, ...]
+
+            # Normalize current segment
+            segment_data = torch.tensor(
+                self.raw_data_preprocessor.zscore_data(segment_data)
+            )
+            segment_data = segment_data.T  # (segment_len, n_channels)
+
+            self._save_segment(
+                subject,
+                session=session,
+                segment_data=segment_data,
+                segment_time=data_sample_indices[segment_ind, ...],
+                segment_labels=segment_label,
+                segment_id=segment_ind,
+                segment_seq_len=n_steps_in_one_segment,
+                file_progress_tracker=file_progress_tracker,
+                is_last_segment=(segment_ind == data.shape[0] - 1),
+            )
+
+        return data.shape[0] - (last_segment_id + 1)
+
+    def _generate_segmented_data(
+        self,
+        data: torch.Tensor,
+        n_steps_in_one_segment: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Segment data of shape (channels x time_samples) to (number_of_segments x channels x n_steps_in_one_segment).
+        It will truncate extra samples.
+
+        Returns segmented data and also indices corresponding to the original data tensor.
+        """
+        # Truncate the time series to a length divisible by the desired window size.
+        cutoff_len = int(data.shape[-1] - data.shape[-1] % n_steps_in_one_segment)
+        data = data[..., :cutoff_len]
+        data_sample_indices = torch.arange(data.shape[-1])
+        data = einops.rearrange(data, "c (ns sl) -> ns c sl", sl=n_steps_in_one_segment)
+        data_sample_indices = data_sample_indices.reshape(
+            [-1, n_steps_in_one_segment]
+        )  # (n_segments, segment_length)
+
+        return data, data_sample_indices
+
+    def _get_experiment_data_and_labels(
+        self,
+        subject: str,
+        session: str,
+        raw_data: torch.Tensor,
+        n_steps_in_one_segment: int,
+        **kwargs,  ## Needed for child classes.
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Generate data and label pairs. The data is reshaped into segments, either by chunking
+        or by word-based segmenting, depending on the given experiment.
+
+        Args:
+            subject: str. Current data's subject name.
+            session: str. Current data's session name.
+            raw_data: a tensor of shape (n_channels x n_total_samples)
+            n_steps_in_one_segment: int. Number of samples we want in one segment.
+
+        Output:
+            data: a tensor of shape (n_segments x n_channels x n_steps_in_one_segment)
+            labels: a tensor of shape (n_segments x n_steps_in_one_segment)
+            data_sample_indices: a tensor of shape (n_segments x n_steps_in_one_segment)
+                containing the indices of the raw-data samples each item in data corresponds to
+        """
+        if self.experiment_dataset_name == self._pretrain_enum:
+            data, data_sample_indices = self._generate_segmented_data(
+                raw_data, n_steps_in_one_segment
+            )
+            labels = torch.tensor(np.ones_like(data_sample_indices) * np.nan)  # dummy
+            return data, labels, data_sample_indices
+
+        # Get associated experiment labels
+        raw_labels, label_intervals = self.raw_data_helper.get_features(
+            subject, session, self.experiment, raw_data.shape[-1]
+        )
+
+        if (
+            self.experiment_dataset_name == BrainTreebankDatasetNames.SENTENCE_ONSET
+            or self.experiment_dataset_name
+            == BrainTreebankDatasetNames.SPEECH_VS_NONSPEECH
+            or self.experiment_dataset_name
+            == BrainTreebankDatasetNames.SENTENCE_ONSET_TIME
+            or self.experiment_dataset_name
+            == BrainTreebankDatasetNames.SPEECH_VS_NONSPEECH_TIME
+        ):
+            data, labels, data_sample_indices = (
+                self._generate_data_and_labels_by_speech(
+                    raw_data, n_steps_in_one_segment, raw_labels
+                )
+            )
+
+        elif (
+            self.experiment_dataset_name == BrainTreebankDatasetNames.VOLUME
+            or self.experiment_dataset_name == BrainTreebankDatasetNames.OPTICAL_FLOW
+        ):
+            # A label switchpoint is the neural activity index that corresponds to the word onset.
+            label_switchpoints = np.array(
+                [elem[0] for elem in label_intervals], dtype=int
+            )
+            data, data_sample_indices, _ = self._generate_word_aligned_segments(
+                raw_data, n_steps_in_one_segment, label_switchpoints
+            )
+            # data_sample_indices are the neural activity indices that correspond to the segment
+            # start, which is the label switchpoint minus segment_len / 2 * sampling rate.
+
+            start = (
+                int(data.shape[-1] / 2)
+                if self.config.trial_alignment == "center"
+                else 0
+            )
+            valid_label_switchpoints = data_sample_indices[start::data.shape[-1]]
+
+            labels = raw_labels[valid_label_switchpoints]
+            labels = einops.repeat(labels, "n -> n l", l=data.shape[-1])
+
+            if self.config.quantile_numerical_labels.active:
+                labels = self._generate_quartile_labels(labels)
+
+            data_sample_indices = data_sample_indices.reshape(
+                (data.shape[0], data.shape[-1])
+            )
+            labels = torch.from_numpy(labels)
+
+        return data, labels, data_sample_indices
+
+    def _generate_data_and_labels_by_segments(
+        self,
+        raw_data: torch.Tensor,
+        n_steps_in_one_segment: int,
+        raw_labels: np.ndarray,
+    ):
+        """
+        Generate data and label pairs by chunking the full session.
+
+        Args:
+            raw_data: a tensor of shape (N_channels x N_total_samples)
+            n_steps_in_one_segment: number of samples we want in one segment
+            raw_labels: a numpy array of length N_total_samples containing labels
+                corresponding to each sample
+
+        Output:
+            data: a tensor of shape (N_segments x N_channels x n_steps_in_one_segment)
+            labels: a tensor of shape (N_segments x n_steps_in_one_segment)
+            data_sample_indices: a tensor of shape (N_segments x n_steps_in_one_segment)
+                containing the indices of the raw-data samples each item in data corresponds to
+        """
+        data, data_sample_indices = self._generate_segmented_data(
+            raw_data, n_steps_in_one_segment
+        )
+
+        # data: N x channels x n_steps_in_one_segment
+        cutoff_len = data.shape[0] * data.shape[-1]
+
+        labels = raw_labels[..., :cutoff_len]
+        labels = einops.rearrange(labels, "(ns sl) -> ns sl", sl=n_steps_in_one_segment)
+
+        assert labels.shape[0] == data.shape[0]
+
+        if self.config.quantile_numerical_labels.active:
+            labels = self._generate_quartile_labels(labels)
+
+        labels = torch.from_numpy(labels)
+        return data, labels, data_sample_indices
+
+    def _generate_quartile_labels(self, feature_values: np.ndarray) -> np.ndarray:
+        """
+        Convert float labels based on quantile values: values in the top quantile are assigned 1,
+        values in the bottom quantile are assigned 0, and all others are assigned NaN.
+        """
+        valid_inds = ~np.isnan(feature_values)
+        lower_thresh, higher_thresh = np.quantile(
+            feature_values[valid_inds],
+            [
+                self.config.quantile_numerical_labels.lower_threshold,
+                self.config.quantile_numerical_labels.higher_threshold,
+            ],
+        )
+
+        valid_inds = np.logical_or(
+            feature_values <= lower_thresh, feature_values >= higher_thresh
+        )
+        new_feature_values = feature_values.copy()
+        new_feature_values[~valid_inds] = np.nan
+        new_feature_values[feature_values <= lower_thresh] = 0
+        new_feature_values[feature_values >= higher_thresh] = 1
+
+        return new_feature_values
+
+    def _generate_word_aligned_segments(
+        self,
+        raw_data: torch.Tensor,
+        n_steps_in_one_segment: int,
+        label_switchpoints: np.ndarray,
+    ):
+        if self.config.trial_alignment == "center":
+            half_window = int(n_steps_in_one_segment / 2)
+            start_inds = label_switchpoints - half_window  # start of word boundaries
+            valid_start_inds = start_inds[
+                np.logical_and(
+                    start_inds >= 0,
+                    start_inds + n_steps_in_one_segment < raw_data.shape[-1],
+                )
+            ]
+
+            all_word_aligned_inds, word_aligned_inds, word_aligned_samples = (
+                [],
+                [],
+                [],
+            )
+            ## Note that the positive samples will most likely have overlaps between the windows.
+            for samp_ind, samp_start_ind in enumerate(valid_start_inds):
+                # Indices in the neural activity for this word
+                inds_to_query = torch.arange(
+                    samp_start_ind, samp_start_ind + n_steps_in_one_segment
+                )
+                all_word_aligned_inds.append(inds_to_query)
+
+                ## Explicitly avoiding overlapping positive samples here.
+                if (
+                    self.config.force_nonoverlap
+                    and samp_ind > 0
+                    and samp_start_ind <= word_aligned_inds[-1][-1]
+                ):
+                    continue
+
+                word_aligned_samples.append(raw_data[:, inds_to_query])
+                word_aligned_inds.append(inds_to_query)
+
+            print(
+                f"Using only {len(word_aligned_inds)} out of {len(all_word_aligned_inds)} word-aligned segments."
+            )
+            all_word_aligned_inds = torch.cat(all_word_aligned_inds)
+            word_aligned_inds = torch.cat(
+                word_aligned_inds
+            )  # (n_segments * segment_length)
+            word_aligned_samples = torch.stack(
+                word_aligned_samples
+            )  # (n_segments, n_channels, segment_length)
+
+            if self.config.force_nonoverlap:
+                assert len(torch.unique(word_aligned_inds)) == len(word_aligned_inds)
+
+        else:
+            raise NotImplementedError("Only center trial alignment supported.")
+
+        return word_aligned_samples, word_aligned_inds, all_word_aligned_inds
+
+    def _generate_data_and_labels_by_speech(
+        self,
+        raw_data: torch.Tensor,
+        n_steps_in_one_segment: int,
+        labels: np.ndarray,
+    ):
+        """
+        Generate data and label pairs by segmenting based on words.
+
+        This function will first create word-aligned non-overlapping segments and
+        then assign labels to each word. For speech_vs_nonspeech(_time) and
+        sentence_onset(_time) tasks, it then chunks the data and uses segments that
+        don't overlap with any word to generate negative labels. Note, this function
+        can generate either non-overlapping **or** overlapping word center-aligned
+        segments -- based on user preference. In the former case with non-overlapping
+        segments, not all parts of the data will be used, since this is word-based.
+
+        Args:
+            raw_data: a tensor of shape (n_channels x n_total_samples)
+            n_steps_in_one_segment: number of samples we want in one segment
+            labels: a numpy array of length n_total_samples containing labels
+                corresponding to each sample
+
+        Output:
+            data: a tensor of shape (n_segments x n_channels x n_steps_in_one_segment)
+            labels: a tensor of shape (n_segments x n_steps_in_one_segment)
707
+ data_sample_indices: a tensor of shape (n_segments x n_steps_in_one_segment)
708
+ containing indices of samples of the raw data each item in data corresponds to.
709
+ """
710
+ # NOTE: The reason why label_intervals/word start times are not used as the switchpoints is
711
+ # because sentence onset true labels don't include all words, but only words that are onsets.
712
+ # Using word start times as switch points will generate more word aligned segments than is
713
+ # correct / needed. As such, here we use the raw labels directly to determine switchpoints.
714
+ label_switchpoints = np.where(
715
+ np.logical_and(
716
+ # All switch points should have delta with previous sample greater than 0.
717
+ np.concatenate((np.array([0]), np.diff(np.nan_to_num(labels)))) > 0,
718
+ ~np.isnan(labels),
719
+ )
720
+ )[0]
721
+ out = self._generate_word_aligned_segments(
722
+ raw_data, n_steps_in_one_segment, label_switchpoints
723
+ )
724
+ word_aligned_samples, word_aligned_inds, all_word_aligned_inds = out
725
+
726
+ if self.config.force_nonoverlap:
727
+ data_sample_indices = torch.arange(raw_data.shape[-1])
728
+ is_unaligned_inds = np.logical_and(
729
+ ~np.isin(data_sample_indices, np.unique(all_word_aligned_inds)),
730
+ ~np.isnan(labels),
731
+ )
732
+ # Truncate time series to a divisible length by the desired window size.
733
+ cutoff_len = int(
734
+ raw_data.shape[-1] - raw_data.shape[-1] % n_steps_in_one_segment
735
+ )
736
+ is_unaligned_inds = np.reshape(
737
+ is_unaligned_inds[..., :cutoff_len], (-1, n_steps_in_one_segment)
738
+ )
739
+ unaligned_inds = np.where(np.all(is_unaligned_inds, axis=1))[0]
740
+ unaligned_word_samples = torch.stack(
741
+ [
742
+ raw_data[
743
+ :,
744
+ start_ind
745
+ * n_steps_in_one_segment : (start_ind + 1)
746
+ * n_steps_in_one_segment,
747
+ ]
748
+ for start_ind in unaligned_inds
749
+ ]
750
+ )
751
+
752
+ word_aligned_data_sample_inds = torch.reshape(
753
+ word_aligned_inds, (-1, n_steps_in_one_segment)
754
+ )
755
+ unaligned_data_sample_inds = torch.reshape(
756
+ data_sample_indices[:cutoff_len], (-1, n_steps_in_one_segment)
757
+ )[unaligned_inds]
758
+
759
+ else: # not self.config.force_nonoverlap
760
+ # setting self.config.nonword_stepsize_s=segment_length should yield non overlap
761
+ if self.config.nonword_stepsize_s is None:
762
+ self.config.nonword_stepsize_s = self.config.segment_length_s
763
+
764
+ offset = int(self.samp_frequency * self.config.nonword_stepsize_s)
765
+ # Computation for n_rows: https://stackoverflow.com/a/53580139
766
+ n_rows = ((raw_data.shape[-1] - n_steps_in_one_segment) // offset) + 1
767
+
768
+ data_sample_indices = np.array(
769
+ [
770
+ np.arange(i * offset, i * offset + n_steps_in_one_segment)
771
+ for i in range(n_rows)
772
+ ]
773
+ )
774
+
775
+ is_unaligned_inds = np.logical_and(
776
+ ~np.isin(data_sample_indices, np.unique(all_word_aligned_inds)),
777
+ # NOTE: The second conditional is necessary because in the sentence onset case,
778
+ # regions with speech that aren't sentence onsets are labelled with nans.
779
+ # These should also be considered when labeling negatives.
780
+ ~np.isnan(
781
+ labels[data_sample_indices.flatten()].reshape(
782
+ data_sample_indices.shape
783
+ )
784
+ ),
785
+ )
786
+ unaligned_inds = np.where(np.all(is_unaligned_inds, axis=1))[0]
787
+
788
+ unaligned_word_samples = torch.stack(
789
+ [
790
+ raw_data[
791
+ :,
792
+ start_ind * offset : start_ind * offset
793
+ + n_steps_in_one_segment,
794
+ ]
795
+ for start_ind in unaligned_inds
796
+ ]
797
+ )
798
+
799
+ data_sample_indices = torch.tensor(data_sample_indices)
800
+
801
+ word_aligned_data_sample_inds = torch.reshape(
802
+ word_aligned_inds, (-1, n_steps_in_one_segment)
803
+ )
804
+ unaligned_data_sample_inds = data_sample_indices[unaligned_inds]
805
+
806
+ n_word_aligned_samples = word_aligned_samples.shape[0]
807
+ n_unaligned_word_samples = unaligned_word_samples.shape[0]
808
+
809
+ num_samples = n_unaligned_word_samples + n_word_aligned_samples
810
+
811
+ if self.config.force_balanced:
812
+ num_samples = min(n_unaligned_word_samples, n_word_aligned_samples) * 2
813
+
814
+ word_aligned_to_use = np.sort(
815
+ np.random.choice(
816
+ range(n_word_aligned_samples),
817
+ replace=False,
818
+ size=num_samples // 2,
819
+ )
820
+ )
821
+ word_aligned_samples = word_aligned_samples[word_aligned_to_use, ...]
822
+ word_aligned_data_sample_inds = word_aligned_data_sample_inds[
823
+ word_aligned_to_use
824
+ ]
825
+
826
+ unaligned_to_use = np.sort(
827
+ np.random.choice(
828
+ range(n_unaligned_word_samples),
829
+ replace=False,
830
+ size=num_samples // 2,
831
+ )
832
+ )
833
+ unaligned_word_samples = unaligned_word_samples[unaligned_to_use, ...]
834
+ unaligned_data_sample_inds = unaligned_data_sample_inds[unaligned_to_use]
835
+
836
+ n_word_aligned_samples = word_aligned_samples.shape[0]
837
+ n_unaligned_word_samples = unaligned_word_samples.shape[0]
838
+
839
+ # Concatenate data
840
+ data = torch.empty(
841
+ n_word_aligned_samples + n_unaligned_word_samples,
842
+ *word_aligned_samples.shape[1:],
843
+ )
844
+ data[:n_word_aligned_samples] = word_aligned_samples
845
+ data[n_word_aligned_samples:] = unaligned_word_samples
846
+
847
+ num_channels = raw_data.shape[0]
848
+ assert data.shape == (
849
+ num_samples,
850
+ num_channels,
851
+ n_steps_in_one_segment,
852
+ )
853
+
854
+ # Concatenate labels
855
+ labels = torch.zeros(num_samples, n_steps_in_one_segment)
856
+ labels[:n_word_aligned_samples] = 1
857
+
858
+ # Concatenate sample indices
859
+ data_sample_indices = torch.empty(
860
+ n_word_aligned_samples + n_unaligned_word_samples,
861
+ n_steps_in_one_segment,
862
+ )
863
+ data_sample_indices[:n_word_aligned_samples] = word_aligned_data_sample_inds
864
+ data_sample_indices[n_word_aligned_samples:] = unaligned_data_sample_inds
865
+
866
+ ## Putting the samples back in temporally sorted order.
867
+ sorted_inds = torch.argsort(data_sample_indices[:, 0])
868
+ data_sample_indices = data_sample_indices[sorted_inds, ...]
869
+ data = data[sorted_inds, ...]
870
+ labels = labels[sorted_inds, ...]
871
+ return data, labels, data_sample_indices
872
+
873
+ def _aggregate_labels(self, labels: torch.Tensor) -> float:
874
+ """
875
+ Return one label for each segment in batch instead of having one label for each timepoint
876
+ """
877
+
878
+ nan_numels = torch.isnan(labels).sum()
879
+
880
+ if nan_numels / len(labels) >= self.config.aggregate_labels.nan_threshold:
881
+ label = torch.nan
882
+ elif self.config.aggregate_labels.type == "mean":
883
+ label = labels.nanmean()
884
+ label = float(label)
885
+ elif self.config.aggregate_labels.type == "threshold":
886
+ non_nan_numels = len(labels) - nan_numels
887
+ label = int(
888
+ (
889
+ labels.nansum() / non_nan_numels
890
+ > self.config.aggregate_labels.threshold
891
+ ).long()
892
+ )
893
+
894
+ return label
895
+
896
+ def _get_segment_label(self, labels: torch.tensor) -> float:
897
+ if self.experiment_dataset_name == self._pretrain_enum:
898
+ return np.nan # pretraining data has no labels
899
+
900
+ agg_label = self._aggregate_labels(labels)
901
+ return agg_label
902
+
903
+ def _process_segments_and_update_metadata_file(self):
904
+ """
905
+ Process data files of subjects and add/update segments
906
+ """
907
+ number_of_added_segments = 0
908
+ for subject in self.available_sessions.keys():
909
+ for session in self.available_sessions[subject]:
910
+ print(
911
+ f"Segment processing for subject {subject} session {session} starts."
912
+ )
913
+
914
+ # Check status of processing
915
+ file_progress_tracker = FileProgressTracker(
916
+ save_path=self._get_file_progress_tracker_save_path(
917
+ subject, session
918
+ ),
919
+ experiment=self.experiment,
920
+ )
921
+
922
+ if self.config.force_reprocess_stage2:
923
+ corresponding_indices_to_remove = (
924
+ self.metadata.get_indices_matching_cols_values(
925
+ ["subject", "session", "experiment"],
926
+ [subject, session, self.experiment],
927
+ )
928
+ )
929
+ self.metadata.drop_rows_based_on_indices(
930
+ corresponding_indices_to_remove
931
+ )
932
+
933
+ file_progress_tracker.reset_process()
934
+ print(
935
+ f"Force reprocessing active, removed subject: {subject} session: "
936
+ f"{session} experiment: {self.experiment} from metadata, will "
937
+ f"start processing from the first file."
938
+ )
939
+
940
+ if file_progress_tracker.is_completed():
941
+ sp_exist = self._spatial_groupings_exist_for_subject(
942
+ subject, session
943
+ )
944
+ if sp_exist and not self.config.force_recreate_spatial_groupings:
945
+ print(
946
+ f"Subject {subject} data already processed completely, skipping."
947
+ )
948
+ continue
949
+ else:
950
+ print(
951
+ f"Subject {subject} data already processed completely,"
952
+ " but force recreate spatial groupings is active,"
953
+ " will recreate spatial groups"
954
+ )
955
+
956
+ number_of_added_segments_for_subject_session = (
957
+ self._create_segments_for_subject_session(
958
+ subject,
959
+ session,
960
+ self.config.segment_length_s,
961
+ file_progress_tracker,
962
+ )
963
+ )
964
+
965
+ print(
966
+ f"Added {number_of_added_segments_for_subject_session} new segments for subject {subject} session {session}"
967
+ )
968
+
969
+ nan_labels = self.metadata.get_indices_matching_cols_values(
970
+ ["subject", "session", "experiment", "label"],
971
+ [subject, session, self.experiment, None],
972
+ )
973
+ print(
974
+ f"{len(nan_labels)} segments for this subject session have nan labels"
975
+ )
976
+
977
+ number_of_added_segments += number_of_added_segments_for_subject_session
978
+
979
+ self.metadata = self.splitter.set_splits_for_subject(
980
+ subject, self.metadata, self._split_method
981
+ )
982
+ file_progress_tracker.mark_completion_status()
983
+ self.metadata.save(self.metadata_path)
984
+
985
+ print(f"Metadata saved in {self.metadata_path}")
986
+ print(f"Added {number_of_added_segments} new segments")
987
+
988
+ summary_str = self.metadata.get_summary_str()
989
+ print(f"{self.name} dataset, full metadata summary: {summary_str}")
990
+
991
+ def _filter_metadata_for_the_run(self):
992
+ """
993
+ Do filtering on metadata based on experiment design
994
+
995
+ # NOTE: Add stuff that are run dependent but do **not** alter the saved metadata here.
996
+ """
997
+ # Return only needed experiment
998
+ self.metadata.reduce_based_on_col_value("experiment", self.experiment)
999
+
1000
+ # Drop rows with no label if not pretraining
1001
+ if not self.experiment_dataset_name == self._pretrain_enum:
1002
+ n_dropped = self.metadata.reduce_based_on_col_value(
1003
+ "label", None, keep=False
1004
+ )
1005
+ print(f"Dropping {n_dropped} segments with no labels")
1006
+
1007
+ if self.experiment_dataset_name in (
1008
+ BrainTreebankDatasetNames.SPEECH_VS_NONSPEECH_TIME,
1009
+ BrainTreebankDatasetNames.SENTENCE_ONSET_TIME,
1010
+ BrainTreebankDatasetNames.VOLUME,
1011
+ BrainTreebankDatasetNames.OPTICAL_FLOW
1012
+ ):
1013
+
1014
+ curr_fold = self.config.get("chron_fold_num", None)
1015
+ if curr_fold is not None:
1016
+ print(f"Using chronological fold: {curr_fold}.")
1017
+ folds_path = os.path.join(
1018
+ self.config.save_dir,
1019
+ self.experiment,
1020
+ f"metadata_{self.segments_processing_hash_str}_folds.pkl",
1021
+ )
1022
+ try:
1023
+ with open(
1024
+ folds_path,
1025
+ "rb",
1026
+ ) as f:
1027
+ folds_info = pickle.load(f)
1028
+ except FileNotFoundError as e:
1029
+ print(f"File {folds_path} not found. Generate the folds for the metadata ({self.metadata_path}) using `barista/generate_chronological_folds` notebook.")
1030
+ exit(0)
1031
+
1032
+ assert (
1033
+ len(self.config.finetune_sessions) == 1
1034
+ ), "Only one finetune session expected."
1035
+
1036
+ subject_session = self.config.finetune_sessions[0]
1037
+ self.config.run_ratios = [
1038
+ # In case values were saved out as non-primitive float type.
1039
+ float(elem) for elem in folds_info[subject_session][curr_fold][0]
1040
+ ]
1041
+ self.config.run_splits = folds_info[subject_session][curr_fold][1]
1042
+
1043
+ else: # no chron_fold_num specified.
1044
+ print("Using default run chronological ratios and splits.")
1045
+
1046
+ for subject_session in self.config.finetune_sessions:
1047
+ self.splitter.resplit_for_subject(
1048
+ subject_session,
1049
+ self.metadata,
1050
+ self._split_method,
1051
+ )
1052
+
1053
+ summary_str = self.metadata.get_summary_str()
1054
+ print(f"{self.name} dataset, current run summary: {summary_str}")
1055
+
1056
+ def process_segments(self, only_segment_generation=False):
1057
+ # Load the metadata in this dataset to have info from previously precessed segments.
1058
+ old_metadata = self._load_metadata()
1059
+ if old_metadata is not None:
1060
+ self.metadata = old_metadata
1061
+
1062
+ if not self.config.skip_segment_generation_completely:
1063
+ self._process_segments_and_update_metadata_file()
1064
+
1065
+ if not only_segment_generation:
1066
+ self._filter_metadata_for_the_run()
1067
+
1068
+ @property
1069
+ def _split_method(self):
1070
+ if self.experiment_dataset_name in (
1071
+ BrainTreebankDatasetNames.SPEECH_VS_NONSPEECH,
1072
+ BrainTreebankDatasetNames.SENTENCE_ONSET,
1073
+ ):
1074
+ assert self.config.force_nonoverlap is True, "Set force_nonoverlap to True for random split segments"
1075
+ return "shuffle"
1076
+ # Everything else should just be split chronologically.
1077
+
1078
+ if self.experiment_dataset_name != BrainTreebankDatasetNames.PRETRAIN:
1079
+ assert self.config.force_nonoverlap is False, "Set force_nonoverlap to False for chronological segments"
1080
+
1081
+ return "chronological"
1082
+
1083
+ @property
1084
+ def _pretrain_enum(self) -> BrainTreebankDatasetNames:
1085
+ return BrainTreebankDatasetNames.PRETRAIN
1086
+
1087
+ def get_raw_data_file_path(self, subject: str, session: str):
1088
+ self.path_manager.get_raw_data_filepath(subject, session)
1089
+
1090
+ @property
1091
+ def _processed_raw_data_dir(self):
1092
+ """
1093
+ Filename for processed raw data, i.e., filtering and referencing
1094
+ """
1095
+ return os.path.join(
1096
+ self.config.save_dir,
1097
+ self._get_processed_raw_data_dir_name,
1098
+ )
1099
+
1100
+ @property
1101
+ def _get_processed_raw_data_dir_name(self):
1102
+ return f"processed_raw_{self.samp_frequency}Hz_notch_laplacianref_clnLap"
1103
+
1104
+ @property
1105
+ def _processed_segments_data_dir(self):
1106
+ """Data dir for the segmented trials corresponding to a particular experimental config."""
1107
+ return os.path.join(
1108
+ self.config.save_dir,
1109
+ self.experiment,
1110
+ f"processed_segments_{self.segments_processing_hash_str}",
1111
+ )
1112
+
1113
+ def _load_metadata(self) -> Optional[Metadata]:
1114
+ if os.path.exists(self.metadata_path):
1115
+ metadata = Metadata(load_path=self.metadata_path)
1116
+ print(f"Metadata loaded from {self.metadata_path}")
1117
+ return metadata
1118
+ return None
1119
+
1120
+ def _initialize_metadata(self) -> Metadata:
1121
+ columns = [f.name for f in dataclasses.fields(MetadataRow)]
1122
+ metadata_df = pd.DataFrame(columns=columns)
1123
+
1124
+ columns = [f.name for f in dataclasses.fields(MetadataSpatialGroupRow)]
1125
+ spatial_group_df = pd.DataFrame(columns=columns)
1126
+
1127
+ metadata = Metadata(df=metadata_df, spatial_group_df=spatial_group_df)
1128
+ print(f"Metadata initialized: {self.metadata_path}")
1129
+ return metadata
1130
+
1131
+ def _get_processed_raw_data_file_path(self, subject, session):
1132
+ filename = f"{subject}_{session}.pt"
1133
+ return os.path.join(self._processed_raw_data_dir, filename)
1134
+
1135
+ def _get_processed_raw_data_file_path_cache(self, subject, session):
1136
+ filename = f"{subject}_{session}.pt"
1137
+ path = os.path.join(
1138
+ self.config.stage1_cache_dir,
1139
+ self._get_processed_raw_data_dir_name,
1140
+ )
1141
+ print(f"Cache dir: {path}")
1142
+ return path, os.path.join(path, filename)
1143
+
1144
+ def _get_segments_processing_hash(self, segment_length_s):
1145
+ """
1146
+ returns a tuple where the key is the processing str, value is the hashed key.
1147
+ actual str can be found in metadata.
1148
+
1149
+ this part can be overwritten by each dataset class based on specific settings
1150
+ """
1151
+
1152
+ processing_str = (
1153
+ f"{self.config.samp_frequency}Hz_zscrTrue"
1154
+ f"_segment_length{segment_length_s}_val_ratio{self.config.val_ratio:.1e}_test_ratio{self.config.test_ratio:.1e}"
1155
+ )
1156
+
1157
+ if self.experiment_dataset_name != self._pretrain_enum:
1158
+ processing_str += f"_trial_align{self.config.trial_alignment}"
1159
+
1160
+ if self.config.quantile_numerical_labels.active:
1161
+ processing_str += f"quantile_numerical_labels_L{self.config.quantile_numerical_labels.lower_threshold}_H{self.config.quantile_numerical_labels.higher_threshold}"
1162
+
1163
+ processing_str += self.config.dataset_dir
1164
+ processing_str += "_laplacian"
1165
+
1166
+ if self.config.region_filtering.active:
1167
+ self.config.region_filtering['filters'].sort()
1168
+ filter_str = (
1169
+ f"_region_filtered_{str(self.config.region_filtering.filters)}"
1170
+ )
1171
+ processing_str += filter_str
1172
+
1173
+ if not self.config.force_balanced:
1174
+ processing_str += "_all_labels"
1175
+
1176
+ if self._split_method == "chronological":
1177
+ processing_str += "_chronosplit"
1178
+ if not self.config.force_nonoverlap:
1179
+ processing_str += "_overlapsegs"
1180
+
1181
+ processing_str += "_use_clean_laplacian"
1182
+ processing_str += "_aggregate_label" + str(self.config.aggregate_labels)
1183
+
1184
+ hash_str = hashlib.sha256(bytes(processing_str, "utf-8")).hexdigest()[:5]
1185
+ print(f"HASHSTR: {hash_str}")
1186
+ return processing_str, hash_str
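The quantile labeling step at the top of this file (bottom quantile mapped to 0, top quantile to 1, everything in between masked out) can be sketched standalone with NumPy. The function name `quantile_binarize` and the threshold arguments here are illustrative stand-ins for `config.quantile_numerical_labels`, not the repo's API:

```python
import numpy as np


def quantile_binarize(feature_values, lower_q=0.25, higher_q=0.75):
    """Map the bottom quantile to 0, the top quantile to 1, and NaN otherwise."""
    valid = ~np.isnan(feature_values)
    lower_thresh, higher_thresh = np.quantile(
        feature_values[valid], [lower_q, higher_q]
    )
    # Start from all-NaN so mid-range values (and original NaNs) stay masked.
    out = np.full_like(feature_values, np.nan, dtype=float)
    out[feature_values <= lower_thresh] = 0.0
    out[feature_values >= higher_thresh] = 1.0
    return out


labels = quantile_binarize(np.arange(10, dtype=float), 0.2, 0.8)
print(labels)  # extremes get 0/1, the middle stays NaN
```

Because NaN comparisons are always False, missing feature values fall through both assignments and remain NaN, matching the behavior of `_quantile`-style labeling above.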
barista/data/dataframe_wrapper.py ADDED
@@ -0,0 +1,268 @@
+ from copy import deepcopy
+ import numpy as np
+ import pandas as pd
+ import torch
+ from typing import List, Optional, Union
+
+
+ class DataframeWrapper:
+     """
+     A wrapper for a pandas DataFrame.
+
+     This class provides extra functionality over pd.DataFrame and abstracts
+     away the dependency on pandas (for the most part).
+     """
+
+     def __init__(
+         self,
+         df: Optional[pd.DataFrame] = None,
+         load_path: Optional[str] = None,
+     ) -> None:
+         if df is not None and load_path is not None:
+             raise ValueError("Only one of inner df or load path should be set")
+
+         if df is not None:
+             self._df: pd.DataFrame = df
+         else:
+             self._df: pd.DataFrame = self.load(load_path)
+
+     def copy(self):
+         new_df = self._df.copy(deep=True)
+         return self.__class__(df=new_df)
+
+     @classmethod
+     def merge(
+         cls,
+         metadatas: List["DataframeWrapper"],
+         drop_duplicate: bool = False,
+         merge_columns: Union[str, List[str], None] = None,
+         keep="first",
+     ) -> "DataframeWrapper":
+         """
+         Merge the metadatas' dataframes.
+
+         If drop_duplicate is True, only one row among rows sharing the same
+         `merge_columns` values will remain, based on the `keep` strategy.
+         Defaults to using all columns.
+         """
+         metadata_dfs = [m._df for m in metadatas]
+         df = pd.concat(metadata_dfs, ignore_index=True)
+         if drop_duplicate:
+             df = df.drop_duplicates(subset=merge_columns, keep=keep)
+         return cls(df)
+
+     @property
+     def columns(self):
+         return self._df.columns
+
+     def concat(self, new_df: pd.DataFrame):
+         self._df = pd.concat([self._df, new_df], ignore_index=True, sort=True)
+
+     def shuffle(self, column: Optional[str] = None) -> None:
+         """Shuffle the metadata table rows, or only a single column if specified"""
+         shuffled = self._df.sample(frac=1, random_state=42).reset_index(drop=True)
+
+         if column is not None:
+             self._df[column] = shuffled[column]
+         else:
+             self._df = shuffled
+
+     def clear(self) -> None:
+         """Reset the metadata to an empty table"""
+         self._df = self._df.head(0)
+
+     def is_empty(self) -> bool:
+         return len(self._df) == 0
+
+     def __getitem__(self, idx: int) -> pd.Series:
+         """Get a metadata table row"""
+         return self._df.iloc[idx]
+
+     def apply_fn_on_all_rows(self, col_name: str, fn: callable) -> pd.Series:
+         """Apply a function to each row of the given dataframe column"""
+         return self._df[col_name].apply(fn)
+
+     def get_unique_values_in_col(
+         self, col_name: str, indices: Optional[List[int]] = None
+     ) -> list:
+         """Get the unique values of a column"""
+         values = self._df[col_name]
+         if indices is not None:
+             values = values.iloc[indices]
+         return list(values.unique())
+
+     def get_indices_matching_cols_values(
+         self,
+         col_names: List,
+         values: List,
+         contains: bool = False,
+         check_range: bool = False,
+     ) -> List[int]:
+         """
+         Get the indices of the rows whose values in the specified `col_names`
+         match the values in the `values` list.
+
+         A value can be a tuple of two for continuous values (specify
+         `check_range=True`). It can also be a list, in which case, if
+         `contains=True`, the row value is checked for membership in the list.
+         """
+
+         assert len(col_names) == len(values)
+
+         mask = pd.Series(True, index=range(len(self)))
+         for col_name, value in zip(col_names, values):
+             if check_range and isinstance(value, tuple):
+                 assert len(value) == 2, "For a range provide min and max value"
+                 min_val, max_val = value
+                 mask &= (self._df[col_name] >= min_val) & (
+                     self._df[col_name] <= max_val
+                 )
+             elif contains and isinstance(value, list):
+                 mask &= self._df[col_name].isin(value)
+             elif value is None or pd.isnull(value):
+                 mask &= self._df[col_name].isnull()
+             else:
+                 mask &= self._df[col_name] == value
+
+         return self._df.index[mask].tolist()
+
+     def get_column_max_value(self, col_name: str):
+         return self._df[col_name].max()
+
+     def set_col_to_value(self, indices: List[int], col: str, value):
+         self._df.loc[indices, col] = value
+
+     def save(self, path: str) -> None:
+         """Save the metadata table to csv after converting lists and tuples to strings"""
+
+         def convert_complex_data(val, delimiter=","):
+             if isinstance(val, (list, tuple)):
+                 return "[" + delimiter.join(map(str, val)) + "]"
+             elif isinstance(val, (dict, torch.Tensor, np.ndarray)):
+                 raise TypeError(
+                     f"Only columns of type list and tuple can be converted and saved, but received {type(val)}."
+                 )
+             else:
+                 return val
+
+         metadata_save = deepcopy(self._df)
+         if len(metadata_save) > 0:
+             for col in metadata_save.columns:
+                 metadata_save[col] = metadata_save[col].apply(convert_complex_data)
+         metadata_save.to_csv(path, index=False)
+
+     def load(self, path: str) -> pd.DataFrame:
+         metadata = pd.read_csv(path)
+
+         def convert_from_string(val, delimiter=","):
+             # Check if the value is a list or tuple.
+             if isinstance(val, str) and (
+                 (val.startswith("[") and val.endswith("]"))
+                 or (val.startswith("(") and val.endswith(")"))
+             ):
+                 val = val[1:-1]
+                 # Attempt to convert to a list of floats or ints.
+                 val_split = val.split(delimiter)
+                 converted = []
+                 for item in val_split:
+                     try:
+                         if "." in item or "e-" in item or "e+" in item:
+                             converted.append(float(item))
+                         elif item == "None" or item == "":
+                             converted.append(None)
+                         else:
+                             converted.append(int(item))
+                     except Exception:
+                         converted.append(item)
+                 return converted
+             return val
+
+         def convert_channels_string_to_tuples(val: str):
+             if val.startswith("[") and val.endswith("]"):
+                 val = val[1:-1]
+
+             def convert_channel_value(ch_val: str):
+                 if ch_val.isnumeric():
+                     return int(ch_val)
+                 elif (ch_val.startswith("'") and ch_val.endswith("'")) or (
+                     ch_val.startswith('"') and ch_val.endswith('"')
+                 ):
+                     return ch_val[1:-1]
+                 return ch_val
+
+             try:
+                 return [
+                     tuple(
+                         [convert_channel_value(c) for c in ch_info_str[1:].split(", ")]
+                     )
+                     for ch_info_str in val[:-1].split("),")
+                 ]
+             except ValueError:
+                 return [
+                     tuple(ch_info_str[1:].split(", "))
+                     for ch_info_str in val[:-1].split("),")
+                 ]
+
+         # Apply the conversion to each column.
+         for col in metadata.columns:
+             if col == "channels" or col == "coords":  # keeping for backward compatibility
+                 metadata[col] = np.nan
+             elif col == "group_components":
+                 # Only convert unique channel strings since many segments share the same channels.
+                 unique_str = metadata[col].unique()
+                 channel_dict = {
+                     c: convert_channels_string_to_tuples(c) for c in unique_str
+                 }
+                 metadata[col] = metadata[col].apply(lambda c: channel_dict[c])
+             else:
+                 metadata[col] = metadata[col].apply(convert_from_string)
+         return metadata
+
+     def drop_rows_based_on_indices(self, indices: List[int]) -> None:
+         """Drop certain rows based on a list of indices"""
+         self._df = self._df.drop(indices).reset_index(drop=True)
+
+     def reduce_based_on_col_value(
+         self,
+         col_name: str,
+         value: Union[str, float],
+         regex: bool = False,
+         keep: bool = True,
+     ) -> int:
+         """
+         Filter rows based on the `value` of the column `col_name`.
+         Pass None as the value to check for nan values.
+
+         regex: whether to use a regex expression (contains) or the exact value
+         keep: whether to keep the matching rows or the rows that do not match
+
+         Returns the number of dropped rows.
+         """
+         if not regex:
+             if value is None:
+                 indices = self._df[col_name].isnull()
+             else:
+                 indices = self._df[col_name] == value
+         else:
+             indices = self._df[col_name].str.contains(value)
+
+         if not keep:
+             indices = ~indices
+
+         self._df = self._df[indices].reset_index(drop=True)
+         return (~indices).sum()
+
+     def __len__(self):
+         return len(self._df)
+
+     def _get_column_mapping_dict_from_dataframe(
+         self, key_col: str, value_col: str, df: Optional[pd.DataFrame] = None
+     ):
+         """
+         Get a dictionary mapping `key_col` column values (keys) to
+         `value_col` column values (values).
+         """
+
+         if df is None:
+             df = self._df
+
+         unique_keys_index = (
+             df.dropna(subset=value_col)
+             .drop_duplicates(subset=key_col, keep="first")
+             .index
+         )
+
+         keys = df.loc[unique_keys_index, key_col]
+         values = df.loc[unique_keys_index, value_col]
+
+         output = dict(zip(keys, values))
+         return output
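The column-matching logic in `get_indices_matching_cols_values` boils down to AND-ing one boolean mask per (column, value) pair, with `None` interpreted as "match null entries". A minimal standalone sketch of that core loop, on a toy frame rather than the real metadata schema:

```python
import pandas as pd

df = pd.DataFrame(
    {"subject": ["s1", "s1", "s2"], "label": [0.0, None, 1.0]}
)

# AND together one boolean mask per (column, value) pair,
# treating value=None as "match null entries".
mask = pd.Series(True, index=range(len(df)))
for col, value in [("subject", "s1"), ("label", None)]:
    if value is None:
        mask &= df[col].isnull()
    else:
        mask &= df[col] == value

print(df.index[mask].tolist())  # → [1]
```

Starting from an all-True mask makes an empty query match every row, which is why the wrapper can also use the same method to select all rows for a subject/session pair.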
barista/data/fileprogresstracker.py ADDED
@@ -0,0 +1,93 @@
1
+ import json
2
+ import os
3
+ from typing import Tuple
4
+
5
+ class FileProgressTracker:
6
+ """Manage loading and storing latest completely processed file index
7
+
8
+ This class save information required to continue processing in a file.
9
+ The file structure will be:
10
+ {
11
+ [experiment]: {
12
+ [self._file_ind_key]: int,
13
+ [self._ending_ind_key]: int,
14
+ [self._segment_id_key]: int
15
+ }
16
+ }
17
+ """
18
+
19
+ def __init__(self, save_path: str, experiment: str):
20
+ self.path = save_path
21
+        self.experiment = experiment
+        self._file_ind_key = "file_ind"
+        self._ending_ind_key = "ending_ind"
+        self._segment_id_key = "segment_id"
+        self._completed_key = "is_completed"
+
+    def _load_file(self) -> dict:
+        """Load processing info from file.
+
+        Returns:
+            A dictionary with the structure described in the class docstring
+        """
+        data = {}
+        if os.path.exists(self.path):
+            with open(self.path) as f:
+                data = json.load(f)
+
+        if self.experiment not in data:
+            data[self.experiment] = {
+                self._file_ind_key: 0,
+                self._ending_ind_key: 0,
+                self._segment_id_key: -1,
+                self._completed_key: False,
+            }
+
+        return data
+
+    def _update_file(self, update_dict: dict) -> None:
+        """Update the specified keys in the file."""
+
+        data = self._load_file()
+        data[self.experiment].update(update_dict)
+
+        with open(self.path, "w+") as f:
+            json.dump(data, f)
+
+    def get_last_file_ind(self) -> Tuple[int, int, int]:
+        """Get the last file that was processed for this experiment.
+
+        Returns:
+            A tuple containing the file index, the ending index within the file, and
+            the segment id of the last processed file
+        """
+        data = self._load_file()
+        return (
+            data[self.experiment][self._file_ind_key],
+            data[self.experiment][self._ending_ind_key],
+            data[self.experiment][self._segment_id_key],
+        )
+
+    def update_last_file_ind(
+        self, file_ind: int, ending_ind: int, segment_id: int
+    ) -> None:
+        """Update the last-processed-file info for this experiment without touching other entries in the file."""
+
+        self._update_file(
+            {
+                self._file_ind_key: file_ind,
+                self._ending_ind_key: ending_ind,
+                self._segment_id_key: segment_id,
+            }
+        )
+
+    def mark_completion_status(self, completed: bool = True) -> None:
+        self._update_file({self._completed_key: completed})
+
+    def is_completed(self) -> bool:
+        data = self._load_file()
+        return data[self.experiment].get(self._completed_key, False)
+
+    def reset_process(self) -> None:
+        """Reset the file processing status."""
+        self.mark_completion_status(completed=False)
+        self.update_last_file_ind(0, 0, -1)
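The resume logic above boils down to a small read-modify-write cycle on a single JSON file keyed by experiment name. A minimal, self-contained sketch of that cycle (the function names and `_DEFAULT_ENTRY` constant are illustrative, not part of the repo; the JSON key names mirror the diff):

```python
import json
import os
import tempfile

# Default entry for an experiment that has no recorded progress yet.
_DEFAULT_ENTRY = {"file_ind": 0, "ending_ind": 0, "segment_id": -1, "is_completed": False}


def load_progress(path, experiment):
    """Read the JSON file (if any) and ensure a default entry for `experiment`."""
    data = {}
    if os.path.exists(path):
        with open(path) as f:
            data = json.load(f)
    data.setdefault(experiment, dict(_DEFAULT_ENTRY))
    return data


def update_progress(path, experiment, **updates):
    """Update only the given keys for `experiment`, leaving other experiments intact."""
    data = load_progress(path, experiment)
    data[experiment].update(updates)
    with open(path, "w") as f:
        json.dump(data, f)
```

Because every update rewrites the whole file through a fresh load, partial updates to one experiment never clobber the progress recorded for another.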
barista/data/metadata.py ADDED
@@ -0,0 +1,175 @@
+ import dataclasses
+ from collections import defaultdict
+ from typing import Dict, List, Optional, Union
+
+ import pandas as pd
+ import torch
+
+ from barista.data.dataframe_wrapper import DataframeWrapper
+ from barista.data.metadata_spatial_groups import (
+     MetadataSpatialGroupRow,
+     MetadataSpatialGroups,
+ )
+
+
+ @dataclasses.dataclass
+ class MetadataRow:
+     dataset: str
+     subject: str
+     session: str
+     subject_session: str
+     experiment: str
+     d_input: int
+     d_data: torch.Size
+     split: str
+     path: str
+     filename: str
+     processing_str: str
+     seq_len: int
+     label: Optional[float]
+
+
+ class Metadata(DataframeWrapper):
+     """
+     Metadata class to keep track of all segment meta information.
+     """
+
+     def __init__(self, df=None, load_path=None, spatial_group_df=None):
+         if df is None:
+             assert spatial_group_df is None
+
+         super().__init__(df, load_path)
+
+         self._spatial_groups = None
+         if load_path is not None:
+             try:
+                 self._spatial_groups = MetadataSpatialGroups(
+                     load_path=self._get_spatial_group_path(load_path)
+                 )
+             except FileNotFoundError:
+                 pass
+         elif spatial_group_df is not None:
+             self._spatial_groups = MetadataSpatialGroups(df=spatial_group_df)
+
+     def _get_spatial_group_path(self, path: str) -> str:
+         suffix = ".csv"
+         new_path = path[: -len(suffix)]
+         spatial_path = f"{new_path}_spatial_groups{suffix}"
+         return spatial_path
+
+     def save(self, path: str) -> None:
+         super().save(path)
+         if self._spatial_groups is not None:
+             self._spatial_groups.save(self._get_spatial_group_path(path))
+
+     @classmethod
+     def merge(
+         cls,
+         metadatas: List["Metadata"],
+         drop_duplicate: bool = False,
+         merge_columns: Union[str, List[str], None] = None,
+         keep: str = "first",
+     ) -> "Metadata":
+         new_metadata = super().merge(metadatas, drop_duplicate, merge_columns, keep)
+
+         # Add spatial groups
+         spatial_groups = [m._spatial_groups for m in metadatas]
+         merged_spatial_groups = MetadataSpatialGroups.merge(
+             spatial_groups,
+             drop_duplicate=True,
+             merge_columns=[
+                 "dataset",
+                 "subject_session",
+                 "name",
+             ],
+         )
+         new_metadata._spatial_groups = merged_spatial_groups
+         return new_metadata
+
+     def get_subject_session_d_input(self) -> dict:
+         return self._get_column_mapping_dict_from_dataframe(
+             key_col="subject_session",
+             value_col="d_input",
+         )
+
+     def get_subjects(self) -> List:
+         return self.get_unique_values_in_col("subject")
+
+     def _shape_str_to_list(self, value) -> List[int]:
+         if not isinstance(value, str):
+             return value
+         return [int(a) for a in value.split(",")]
+
+     def get_subject_session_full_d_data(self) -> Dict[str, List[int]]:
+         """
+         Returns a dict mapping each subject_session to its data shape.
+         """
+         my_dict = self._get_column_mapping_dict_from_dataframe(
+             key_col="subject_session",
+             value_col="d_data",
+         )
+         return {k: self._shape_str_to_list(v) for k, v in my_dict.items()}
+
+     def get_labels_count_summary(self) -> dict:
+         splits = self.get_unique_values_in_col("split")
+         labels = self.get_unique_values_in_col("label")
+
+         labels_count = defaultdict(dict)
+         for split in splits:
+             for label in labels:
+                 count = len(
+                     self.get_indices_matching_cols_values(
+                         ["split", "label"],
+                         [split, label],
+                     )
+                 )
+                 labels_count[split][label] = count
+         return labels_count
+
+     def get_summary_str(self) -> str:
+         subjects = self.get_unique_values_in_col("subject")
+         labels_count = self.get_labels_count_summary()
+
+         summary_str = f"Metadata for {len(subjects)} subjects ({subjects})"
+
+         for split, labels in labels_count.items():
+             for label, count in labels.items():
+                 summary_str += f", {count} {split} segments with label {label}"
+
+         return summary_str
+
+     ########################### spatial group related ###########################
+
+     def add_spatial_group(self, spatial_group_row: MetadataSpatialGroupRow):
+         """
+         Add (or overwrite) the spatial group.
+         """
+         self._spatial_groups.remove_spatial_group(
+             spatial_group_row.subject_session, spatial_group_row.name
+         )
+         self._spatial_groups.concat(pd.DataFrame([spatial_group_row]))
+
+     def get_spatial_grouping(
+         self, subject_session: str, name: str
+     ) -> Optional[MetadataSpatialGroupRow]:
+         """
+         Return spatial grouping information for the spatial grouping `name` of `subject_session`.
+
+         The result is a MetadataSpatialGroupRow whose most important properties are
+         group_components, a list of tuples containing group info for each channel of
+         the data, and group_ids, a list of integers specifying which group each
+         channel belongs to.
+         """
+
+         return self._spatial_groups.get_spatial_grouping(subject_session, name)
+
+     def get_spatial_grouping_id_hashmap(self, name: str) -> Dict[str, List[int]]:
+         """
+         Return a dictionary mapping each subject_session to its list of group ids: a
+         list of length n_channels specifying which group each channel belongs to.
+
+         # NOTE Don't use during forward passes because of the copy
+         """
+         temp_copy = self._spatial_groups.copy()
+         temp_copy.reduce_based_on_col_value(col_name="name", value=name, keep=True)
+         return temp_copy._get_column_mapping_dict_from_dataframe(
+             "subject_session", "group_ids"
+         )
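The spatial groups are persisted as a side-car CSV next to the main metadata file. A standalone sketch of the naming convention implemented by `_get_spatial_group_path` above (written as a free function for illustration):

```python
# Derive the side-car CSV path for spatial groups from the main metadata CSV path:
# "<stem>.csv" -> "<stem>_spatial_groups.csv".
def get_spatial_group_path(path: str) -> str:
    suffix = ".csv"
    return f"{path[: -len(suffix)]}_spatial_groups{suffix}"
```

Because the suffix is inserted before the `.csv` extension, the side-car file always sorts next to its metadata file in a directory listing.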
barista/data/metadata_spatial_groups.py ADDED
@@ -0,0 +1,60 @@
+ import dataclasses
+ from enum import Enum
+ from typing import List, Optional, Tuple
+
+ from barista.data.dataframe_wrapper import DataframeWrapper
+
+
+ @dataclasses.dataclass
+ class MetadataSpatialGroupRow:
+     dataset: str
+     subject: str
+     session: str
+     subject_session: str
+     name: str  # name/identifier of the spatial grouping
+     n_effective_components: int
+     max_elements_for_component: (
+         Tuple  # tuple of size n_effective_components (or larger)
+     )
+     padding_indices: Tuple  # tuple of size n_effective_components (or larger)
+     group_components: List  # list of length n_channels -- tuples with group info for each channel, useful for spatial encoding
+     group_ids: List  # list of length n_channels -- ints specifying which group each channel belongs to, useful for spatial masking
+
+
+ class SpatialGroupingName(Enum):
+     COORDS = "coords"
+     DESTRIEUX = "destrieux"
+     LOBES = "lobes"
+
+
+ class MetadataSpatialGroups(DataframeWrapper):
+     def _get_spatial_grouping_index(
+         self, subject_session: str, name: str
+     ) -> Optional[int]:
+         indices = self.get_indices_matching_cols_values(
+             ["subject_session", "name"], [subject_session, name]
+         )
+         if len(indices) == 0:
+             return None
+         assert (
+             len(indices) == 1
+         ), f"More than one result for spatial grouping '{name}' for '{subject_session}'"
+
+         return indices[0]
+
+     def get_spatial_grouping(
+         self, subject_session: str, name: str
+     ) -> Optional[MetadataSpatialGroupRow]:
+         idx = self._get_spatial_grouping_index(subject_session, name)
+         if idx is None:
+             return None
+         row_dict = self._df.iloc[idx].to_dict()
+         # Drop derived columns that are not fields of the dataclass.
+         row_dict.pop("uniq_group_components", None)
+         return MetadataSpatialGroupRow(**row_dict)
+
+     def remove_spatial_group(self, subject_session: str, name: str) -> int:
+         idx = self._get_spatial_grouping_index(subject_session, name)
+         if idx is None:
+             return 0
+         return self.drop_rows_based_on_indices([idx])
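The lookup in `_get_spatial_grouping_index` relies on `(subject_session, name)` identifying at most one row: zero matches mean "absent", more than one means the table is corrupted. A self-contained sketch of that pattern over a plain pandas DataFrame (the frame and function here are illustrative, not the repo's `DataframeWrapper` API):

```python
import pandas as pd


def get_unique_row_index(df: pd.DataFrame, subject_session: str, name: str):
    """Return the single index matching (subject_session, name), or None."""
    mask = (df["subject_session"] == subject_session) & (df["name"] == name)
    indices = df.index[mask].tolist()
    if len(indices) == 0:
        return None
    # A duplicate pair indicates a corrupted table, not a recoverable state.
    assert len(indices) == 1, (
        f"More than one result for spatial grouping '{name}' for '{subject_session}'"
    )
    return indices[0]
```

Distinguishing "absent" (None) from "duplicated" (assertion) lets callers like `remove_spatial_group` treat a missing row as a no-op while still failing loudly on inconsistent data.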
barista/data/splitter.py ADDED
@@ -0,0 +1,237 @@
+ import copy
+ import os
+ from typing import Dict, List, Optional
+
+ import numpy as np
+ import torch
+
+ from barista.data.metadata import Metadata
+ from barista.models.utils import seed_everything
+
+ _SUPPORTED_SPLITS = ["shuffle", "chronological"]
+
+
+ class Splitter:
+     """Helper class to handle train/test/val splitting."""
+
+     def __init__(
+         self,
+         config: Dict,
+         subjects: List,
+         experiment: str,
+         use_fixed_seed: bool = False,
+     ):
+         self.config = config
+         self.subjects = subjects
+         self.experiment = experiment
+
+         self.use_fixed_seed = use_fixed_seed
+
+     def _use_configured_seed(func):
+         """Decorator that switches to the configured seed for a specific function."""
+
+         def wrapper(self, *args, **kwargs):
+             if not self.use_fixed_seed:
+                 return func(self, *args, **kwargs)
+
+             prev_seed = int(os.environ.get("PL_GLOBAL_SEED", 0))
+             new_seed = int(self.config.get("splitter_seed", 0))
+
+             print(
+                 f"Changing seed from {prev_seed} to {new_seed} for splitting"
+             )
+             seed_everything(new_seed)
+
+             out = func(self, *args, **kwargs)
+
+             print(f"Changing seed back from {new_seed} to {prev_seed}.")
+             seed_everything(prev_seed)
+
+             return out
+
+         return wrapper
+
+     @_use_configured_seed
+     def set_splits_for_subject(
+         self,
+         subject: str,
+         metadata: Metadata,
+         split_method: str = "shuffle",
+     ) -> Metadata:
+         """Set the train/validation/test split.
+
+         Every `split_together_length_s` window of segments is assigned as a whole to
+         one of train/val/test.
+
+         NOTE: To split multiple consecutive segments together, this function assumes
+         the segments are ordered and consecutive in the metadata.
+         """
+         # Set default if necessary.
+         if split_method not in _SUPPORTED_SPLITS:
+             print(f"[Warning] Setting split_method={split_method} to 'shuffle'")
+             split_method = "shuffle"
+
+         # Ensure the split-together length is at least as long as the segments.
+         # This setting allows splitting the time series on intervals longer than the
+         # neural segment length.
+         split_together_length_s = max(
+             self.config.get("split_together_length_s", self.config.segment_length_s),
+             self.config.segment_length_s,
+         )
+
+         subject_rows_indices = metadata.get_indices_matching_cols_values(
+             ["subject", "experiment"], [subject, self.experiment]
+         )
+
+         if split_method == "chronological":
+             return self._set_splits_across_time(
+                 metadata, subject_rows_indices=subject_rows_indices
+             )
+
+         split_together_count = int(
+             split_together_length_s // self.config.segment_length_s
+         )
+         consecutive = (torch.diff(torch.tensor(subject_rows_indices)) == 1).all()
+
+         if split_together_count > 1:
+             assert (
+                 consecutive
+             ), "subject rows are not consecutive, can't split them together"
+
+         n_segments = len(subject_rows_indices)
+         if n_segments == 0:
+             print(
+                 f"[WARNING] No rows found for subject {subject} and experiment {self.experiment} in metadata"
+             )
+             return metadata
+
+         starting_ind = subject_rows_indices[0]
+
+         if consecutive:
+             groups = list(
+                 range(
+                     starting_ind,
+                     starting_ind + n_segments - split_together_count + 1,
+                     split_together_count,
+                 )
+             )
+         else:
+             # We've asserted that split_together_count is 1 in this case.
+             groups = copy.deepcopy(subject_rows_indices)
+
+         np.random.shuffle(groups)
+
+         val_size = max(int(self.config.val_ratio * len(groups)), 1)
+         test_size = max(int(self.config.test_ratio * len(groups)), 1)
+
+         val_indices = []
+         for group_starting_idx in groups[:val_size]:
+             group_elem_indices = np.arange(split_together_count) + group_starting_idx
+             val_indices.extend(group_elem_indices)
+
+         test_indices = []
+         for group_starting_idx in groups[val_size : val_size + test_size]:
+             group_elem_indices = np.arange(split_together_count) + group_starting_idx
+             test_indices.extend(group_elem_indices)
+
+         metadata.set_col_to_value(subject_rows_indices, "split", "train")
+         metadata.set_col_to_value(val_indices, "split", "val")
+         metadata.set_col_to_value(test_indices, "split", "test")
+
+         return metadata
+
+     @_use_configured_seed
+     def resplit_for_subject(
+         self,
+         subject_session: str,
+         metadata: Metadata,
+         split_method: str,
+     ) -> Metadata:
+         if split_method == "chronological":
+             return self._set_splits_across_time(
+                 metadata, subject_session=subject_session
+             )
+         else:
+             print("[WARNING] Resplitting is only supported for chronological; splits unchanged")
+             return metadata
+
+     def __check_contiguous(self, subject_rows_indices, check_monotonic_only=False):
+         if check_monotonic_only:
+             assert (
+                 torch.diff(torch.tensor(subject_rows_indices)) >= 1
+             ).all(), "subject rows are not monotonically increasing, can't split them together"
+         else:  # We need increments of exactly one.
+             assert (
+                 torch.diff(torch.tensor(subject_rows_indices)) == 1
+             ).all(), "subject rows are not consecutive, can't split them together"
+
+     @_use_configured_seed
+     def _set_splits_across_time(
+         self,
+         metadata: Metadata,
+         subject_rows_indices: Optional[list] = None,
+         subject_session: str = "",
+         return_splitted_indices: bool = False,
+         check_monotonic_only: bool = False,
+         verbose: bool = False,
+     ) -> Metadata:
+         if not subject_rows_indices and not subject_session:
+             raise ValueError(
+                 "Need to pass either the complete subject session name or subject_rows_indices"
+             )
+
+         if (
+             not subject_rows_indices
+         ):  # Prioritize using subject_rows_indices if given.
+             subject_rows_indices = metadata.get_indices_matching_cols_values(
+                 ["subject_session", "experiment"], [subject_session, self.experiment]
+             )
+
+         self.__check_contiguous(
+             subject_rows_indices, check_monotonic_only=check_monotonic_only
+         )
+
+         n_segments = len(subject_rows_indices)
+
+         assert len(self.config.run_ratios) == len(self.config.run_splits)
+
+         counts = (np.array(self.config.run_ratios) * n_segments).astype(int)
+         counts[-1] = n_segments - sum(counts[:-1])
+
+         if verbose:
+             print(f"subject_session: {subject_session}")
+             print(f"RATIOS: {self.config.run_ratios}")
+             print(f"self.config.run_splits: {self.config.run_splits}")
+             print(f"COUNTS: {counts}")
+
+         if return_splitted_indices:
+             splitted_indices = []
+         sum_now = 0
+         for c, split in zip(counts, self.config.run_splits):
+             label_split_indices = subject_rows_indices[sum_now : sum_now + c]
+             if return_splitted_indices:
+                 splitted_indices.append(label_split_indices)
+
+             sum_now += c
+             metadata.set_col_to_value(label_split_indices, "split", split)
+
+         self._check_split_labels(metadata, subject_session)
+         if return_splitted_indices:
+             return metadata, splitted_indices
+         return metadata
+
+     def _check_split_labels(self, metadata, subject_session):
+         # Check that both labels are available in each split.
+         # NOTE: Not using asserts because the initial default splits might not have
+         # both, but the folds computed offline and provided through the .pkl file
+         # will satisfy the requirement.
+         for split in np.unique(self.config.run_splits):
+             for i in range(2):  # Magic 2 = positive/negative labels.
+                 if (
+                     len(
+                         metadata.get_indices_matching_cols_values(
+                             ["subject_session", "experiment", "label", "split"],
+                             [subject_session, self.experiment, i, split],
+                         )
+                     )
+                     == 0
+                 ):
+                     print(f"split {split} missing label {i}")
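`_set_splits_across_time` carves the contiguous segment indices into consecutive blocks whose sizes follow `run_ratios`, with the last block absorbing the rounding remainder so every segment is assigned exactly once. A standalone sketch of that counting scheme (`chronological_split` is an illustrative name, not a repo function):

```python
import numpy as np


def chronological_split(indices, ratios, splits):
    """Assign consecutive blocks of `indices` to `splits` according to `ratios`."""
    assert len(ratios) == len(splits)
    n = len(indices)
    counts = (np.array(ratios) * n).astype(int)
    counts[-1] = n - counts[:-1].sum()  # last split absorbs the rounding remainder

    assignment, start = {}, 0
    for count, split in zip(counts, splits):
        assignment[split] = list(indices[start : start + count])
        start += count
    return assignment
```

Because the blocks are taken in order rather than shuffled, earlier segments always land in earlier splits, which is exactly what a chronological (no-leakage-across-time) split requires.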
barista/generate_chronological_folds.ipynb ADDED
@@ -0,0 +1,626 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "6d5e7d9f",
+ "metadata": {},
+ "source": [
+ "### Chronological split generation.\n",
+ "\n",
+ "The following code generates the chronological splits based on the presence of positive and negative samples. This is more of an issue for the speech/sentence tasks, but the same approach is also used for the volume and optical flow tasks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "70411f5f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2d6f1fed",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from barista.data.metadata import Metadata\n",
+ "from collections import Counter, defaultdict\n",
+ "import numpy as np\n",
+ "import os\n",
+ "from pathlib import Path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "b579134b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def load_metadata(metadata_path):\n",
+ "    return Metadata(load_path=metadata_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "d17fbaaa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def generate_folds(subject_rows_indices, per_label_subject_rows_indices,\n",
+ "                   bucket_size=0.05, step_size=1, base_step_size=1,\n",
+ "                   window=4, base_window=1, **folds_kwargs):\n",
+ "    assert window % 4 == 0, \"Window should be divisible by 4\"\n",
+ "\n",
+ "    bucket_len = int(bucket_size * len(subject_rows_indices))  # bucket size in samples\n",
+ "    buckets = np.arange(subject_rows_indices[0], subject_rows_indices[-1], bucket_len)\n",
+ "    print(f\"Buckets: {buckets}\")\n",
+ "\n",
+ "    ## Magic number 2 everywhere corresponds to the 0/1 (negative/positive) labels.\n",
+ "    ## First, sum the unique label counts per bucket according to the specifications provided.\n",
+ "    bucket_counts = [{} for i in range(len(buckets) - 1)]\n",
+ "    for bucket_ind in range(0, len(bucket_counts), base_step_size):\n",
+ "        bucket_start = buckets[bucket_ind]\n",
+ "        bucket_end = bucket_start + base_window * bucket_len\n",
+ "        for i in range(2):\n",
+ "            bucket_counts[bucket_ind][i] = np.sum(np.logical_and(\n",
+ "                per_label_subject_rows_indices[i] >= bucket_start,\n",
+ "                per_label_subject_rows_indices[i] < bucket_end\n",
+ "            ))\n",
+ "\n",
+ "    ## Count the residual samples in the last bucket.\n",
+ "    for i in range(2):\n",
+ "        bucket_counts[-1][i] += np.sum(\n",
+ "            per_label_subject_rows_indices[i] >= bucket_end\n",
+ "        )\n",
+ "    print(f\"bucket_counts: {bucket_counts}\")\n",
+ "\n",
+ "    return _find_folds(bucket_counts, step_size, window, bucket_size, **folds_kwargs)\n",
+ "\n",
+ "\n",
+ "def _find_folds(bucket_counts, step_size, window, bucket_size, num_folds=5):\n",
+ "    \"\"\"Logic to find all legitimate folds such that train and test are separated by valid, e.g.,\n",
+ "\n",
+ "    [train, valid, test]\n",
+ "    [test, valid, train]\n",
+ "    [train, valid (0.05), test, valid (0.05), train]\n",
+ "    \"\"\"\n",
+ "    all_folds, all_folds_splits = [], []\n",
+ "    head, tail = 0, len(bucket_counts) - window\n",
+ "    use_tail, quad_window = 0, int(window / 4)\n",
+ "    while len(all_folds) < num_folds:\n",
+ "        curr_ind = tail if use_tail else head\n",
+ "        found = False\n",
+ "        while not found and curr_ind >= 0 and curr_ind <= len(bucket_counts) - window:\n",
+ "            ## Check that any of the validation buckets has both labels.\n",
+ "            val_found = False\n",
+ "            for check_i in range(quad_window):\n",
+ "                val_found |= bucket_counts[curr_ind + check_i][0] > 0 and bucket_counts[curr_ind + check_i][1] > 0\n",
+ "            for check_i in range(window - quad_window, window):\n",
+ "                val_found |= bucket_counts[curr_ind + check_i][0] > 0 and bucket_counts[curr_ind + check_i][1] > 0\n",
+ "\n",
+ "            ## Check that any of the test buckets has both labels.\n",
+ "            test_found = False\n",
+ "            for check_i in range(quad_window, 3 * quad_window):\n",
+ "                test_found |= bucket_counts[curr_ind + check_i][0] > 0 and bucket_counts[curr_ind + check_i][1] > 0\n",
+ "\n",
+ "            found = val_found & test_found\n",
+ "            if found:\n",
+ "                found_ind = curr_ind\n",
+ "            curr_ind += -step_size if use_tail else step_size\n",
+ "\n",
+ "        val_test_interval = np.array([found_ind, found_ind + window]) * bucket_size\n",
+ "\n",
+ "        this_fold = [bucket_size, (window - 2) * bucket_size, bucket_size]\n",
+ "        this_fold_splits = [\"val\", \"test\", \"val\"]\n",
+ "        if 1.0 - val_test_interval[-1] > 0:\n",
+ "            this_fold.append(1.0 - val_test_interval[-1])\n",
+ "            this_fold_splits.append('train')\n",
+ "        if val_test_interval[0] > 0:\n",
+ "            this_fold = [val_test_interval[0]] + this_fold\n",
+ "            this_fold_splits = ['train'] + this_fold_splits\n",
+ "\n",
+ "        assert np.sum(this_fold) == 1.0\n",
+ "        all_folds.append(this_fold)\n",
+ "        all_folds_splits.append(this_fold_splits)\n",
+ "\n",
+ "        if use_tail:\n",
+ "            tail = curr_ind - 1 * step_size\n",
+ "        else:\n",
+ "            head = curr_ind + 1 * step_size\n",
+ "        use_tail = 1 - use_tail\n",
+ "\n",
+ "    return all_folds, all_folds_splits\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "34c7aa28",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Buckets: [ 0 154 308 462 616 770 924 1078 1232 1386 1540 1694 1848 2002\n",
+ " 2156 2310 2464 2618 2772 2926 3080]\n",
+ "bucket_counts: [{0: 68, 1: 86}, {0: 92, 1: 62}, {0: 123, 1: 31}, {0: 42, 1: 112}, {0: 25, 1: 129}, {0: 76, 1: 78}, {0: 65, 1: 89}, {0: 81, 1: 73}, {0: 65, 1: 89}, {0: 33, 1: 121}, {0: 23, 1: 131}, {0: 65, 1: 89}, {0: 75, 1: 79}, {0: 106, 1: 48}, {0: 51, 1: 103}, {0: 103, 1: 51}, {0: 74, 1: 80}, {0: 62, 1: 92}, {0: 154, 1: 0}, {0: 160, 1: 0}]\n",
+ "Buckets: [ 0 165 330 495 660 825 990 1155 1320 1485 1650 1815 1980 2145\n",
+ " 2310 2475 2640 2805 2970 3135]\n",
+ "bucket_counts: [{0: 75, 1: 90}, {0: 117, 1: 48}, {0: 116, 1: 49}, {0: 35, 1: 130}, {0: 19, 1: 146}, {0: 40, 1: 125}, {0: 86, 1: 79}, {0: 48, 1: 117}, {0: 115, 1: 50}, {0: 50, 1: 115}, {0: 28, 1: 137}, {0: 26, 1: 139}, {0: 121, 1: 44}, {0: 95, 1: 70}, {0: 73, 1: 92}, {0: 83, 1: 82}, {0: 105, 1: 60}, {0: 88, 1: 77}, {0: 330, 1: 0}]\n",
+ "Buckets: [3086 3187 3288 3389 3490 3591 3692 3793 3894 3995 4096 4197 4298 4399\n",
+ " 4500 4601 4702 4803 4904 5005 5106]\n",
+ "bucket_counts: [{0: 78, 1: 23}, {0: 46, 1: 55}, {0: 68, 1: 33}, {0: 101, 1: 0}, {0: 95, 1: 6}, {0: 30, 1: 71}, {0: 17, 1: 84}, {0: 42, 1: 59}, {0: 25, 1: 76}, {0: 48, 1: 53}, {0: 21, 1: 80}, {0: 31, 1: 70}, {0: 26, 1: 75}, {0: 25, 1: 76}, {0: 74, 1: 27}, {0: 33, 1: 68}, {0: 39, 1: 62}, {0: 59, 1: 42}, {0: 45, 1: 56}, {0: 113, 1: 0}]\n",
+ "Buckets: [3300 3588 3876 4164 4452 4740 5028 5316 5604 5892 6180 6468 6756 7044\n",
+ " 7332 7620 7908 8196 8484 8772 9060]\n",
+ "bucket_counts: [{0: 231, 1: 57}, {0: 124, 1: 164}, {0: 195, 1: 93}, {0: 288, 1: 0}, {0: 246, 1: 42}, {0: 64, 1: 224}, {0: 95, 1: 193}, {0: 85, 1: 203}, {0: 45, 1: 243}, {0: 120, 1: 168}, {0: 42, 1: 246}, {0: 115, 1: 173}, {0: 57, 1: 231}, {0: 71, 1: 217}, {0: 236, 1: 52}, {0: 85, 1: 203}, {0: 109, 1: 179}, {0: 193, 1: 95}, {0: 184, 1: 104}, {0: 302, 1: 0}]\n",
+ "Buckets: [5118 5184 5250 5316 5382 5448 5514 5580 5646 5712 5778 5844 5910 5976\n",
+ " 6042 6108 6174 6240 6306 6372 6438]\n",
+ "bucket_counts: [{0: 66, 1: 0}, {0: 66, 1: 0}, {0: 39, 1: 27}, {0: 46, 1: 20}, {0: 9, 1: 57}, {0: 38, 1: 28}, {0: 13, 1: 53}, {0: 19, 1: 47}, {0: 26, 1: 40}, {0: 20, 1: 46}, {0: 18, 1: 48}, {0: 12, 1: 54}, {0: 28, 1: 38}, {0: 34, 1: 32}, {0: 49, 1: 17}, {0: 28, 1: 38}, {0: 35, 1: 31}, {0: 25, 1: 41}, {0: 19, 1: 47}, {0: 76, 1: 2}]\n",
+ "Buckets: [ 9074 9140 9206 9272 9338 9404 9470 9536 9602 9668 9734 9800\n",
+ " 9866 9932 9998 10064 10130 10196 10262 10328 10394]\n",
+ "bucket_counts: [{0: 66, 1: 0}, {0: 66, 1: 0}, {0: 35, 1: 31}, {0: 58, 1: 8}, {0: 18, 1: 48}, {0: 36, 1: 30}, {0: 9, 1: 57}, {0: 20, 1: 46}, {0: 20, 1: 46}, {0: 22, 1: 44}, {0: 16, 1: 50}, {0: 7, 1: 59}, {0: 18, 1: 48}, {0: 28, 1: 38}, {0: 55, 1: 11}, {0: 39, 1: 27}, {0: 35, 1: 31}, {0: 21, 1: 45}, {0: 19, 1: 47}, {0: 81, 1: 3}]\n",
+ "Buckets: [6450 6529 6608 6687 6766 6845 6924 7003 7082 7161 7240 7319 7398 7477\n",
+ " 7556 7635 7714 7793 7872 7951 8030]\n",
+ "bucket_counts: [{0: 79, 1: 0}, {0: 79, 1: 0}, {0: 64, 1: 15}, {0: 52, 1: 27}, {0: 19, 1: 60}, {0: 51, 1: 28}, {0: 27, 1: 52}, {0: 20, 1: 59}, {0: 9, 1: 70}, {0: 23, 1: 56}, {0: 18, 1: 61}, {0: 56, 1: 23}, {0: 14, 1: 65}, {0: 26, 1: 53}, {0: 6, 1: 73}, {0: 37, 1: 42}, {0: 46, 1: 33}, {0: 25, 1: 54}, {0: 54, 1: 25}, {0: 92, 1: 1}]\n",
+ "Buckets: [10412 10509 10606 10703 10800 10897 10994 11091 11188 11285 11382 11479\n",
+ " 11576 11673 11770 11867 11964 12061 12158 12255 12352]\n",
+ "bucket_counts: [{0: 97, 1: 0}, {0: 86, 1: 11}, {0: 76, 1: 21}, {0: 24, 1: 73}, {0: 53, 1: 44}, {0: 36, 1: 61}, {0: 8, 1: 89}, {0: 17, 1: 80}, {0: 45, 1: 52}, {0: 97, 1: 0}, {0: 69, 1: 28}, {0: 56, 1: 41}, {0: 32, 1: 65}, {0: 21, 1: 76}, {0: 12, 1: 85}, {0: 28, 1: 69}, {0: 39, 1: 58}, {0: 28, 1: 69}, {0: 43, 1: 54}, {0: 110, 1: 1}]\n",
+ "Buckets: [8044 8095 8146 8197 8248 8299 8350 8401 8452 8503 8554 8605 8656 8707\n",
+ " 8758 8809 8860 8911 8962 9013 9064]\n",
+ "bucket_counts: [{0: 51, 1: 0}, {0: 51, 1: 0}, {0: 43, 1: 8}, {0: 4, 1: 47}, {0: 16, 1: 35}, {0: 16, 1: 35}, {0: 18, 1: 33}, {0: 6, 1: 45}, {0: 37, 1: 14}, {0: 42, 1: 9}, {0: 8, 1: 43}, {0: 0, 1: 51}, {0: 24, 1: 27}, {0: 51, 1: 0}, {0: 51, 1: 0}, {0: 28, 1: 23}, {0: 8, 1: 43}, {0: 24, 1: 27}, {0: 12, 1: 39}, {0: 24, 1: 35}]\n",
+ "Buckets: [12366 12499 12632 12765 12898 13031 13164 13297 13430 13563 13696 13829\n",
+ " 13962 14095 14228 14361 14494 14627 14760 14893 15026]\n",
+ "bucket_counts: [{0: 133, 1: 0}, {0: 133, 1: 0}, {0: 74, 1: 59}, {0: 9, 1: 124}, {0: 57, 1: 76}, {0: 22, 1: 111}, {0: 60, 1: 73}, {0: 48, 1: 85}, {0: 71, 1: 62}, {0: 133, 1: 0}, {0: 24, 1: 109}, {0: 15, 1: 118}, {0: 71, 1: 62}, {0: 133, 1: 0}, {0: 133, 1: 0}, {0: 50, 1: 83}, {0: 42, 1: 91}, {0: 30, 1: 103}, {0: 39, 1: 94}, {0: 60, 1: 87}]\n",
+ "Buckets: [9072 9112 9152 9192 9232 9272 9312 9352 9392 9432 9472 9512 9552 9592\n",
+ " 9632 9672 9712 9752 9792 9832 9872]\n",
+ "bucket_counts: [{0: 30, 1: 10}, {0: 29, 1: 11}, {0: 39, 1: 1}, {0: 15, 1: 25}, {0: 12, 1: 28}, {0: 27, 1: 13}, {0: 12, 1: 28}, {0: 16, 1: 24}, {0: 21, 1: 19}, {0: 20, 1: 20}, {0: 17, 1: 23}, {0: 18, 1: 22}, {0: 11, 1: 29}, {0: 15, 1: 25}, {0: 24, 1: 16}, {0: 19, 1: 21}, {0: 17, 1: 23}, {0: 30, 1: 10}, {0: 10, 1: 30}, {0: 24, 1: 28}]\n",
+ "Buckets: [15040 15079 15118 15157 15196 15235 15274 15313 15352 15391 15430 15469\n",
+ " 15508 15547 15586 15625 15664 15703 15742 15781]\n",
+ "bucket_counts: [{0: 35, 1: 4}, {0: 25, 1: 14}, {0: 38, 1: 1}, {0: 17, 1: 22}, {0: 7, 1: 32}, {0: 32, 1: 7}, {0: 12, 1: 27}, {0: 17, 1: 22}, {0: 15, 1: 24}, {0: 14, 1: 25}, {0: 19, 1: 20}, {0: 18, 1: 21}, {0: 7, 1: 32}, {0: 13, 1: 26}, {0: 21, 1: 18}, {0: 17, 1: 22}, {0: 20, 1: 19}, {0: 27, 1: 12}, {0: 36, 1: 42}]\n",
+ "Buckets: [ 9884 9942 10000 10058 10116 10174 10232 10290 10348 10406 10464 10522\n",
+ " 10580 10638 10696 10754 10812 10870 10928 10986 11044]\n",
+ "bucket_counts: [{0: 48, 1: 10}, {0: 39, 1: 19}, {0: 44, 1: 14}, {0: 23, 1: 35}, {0: 27, 1: 31}, {0: 17, 1: 41}, {0: 25, 1: 33}, {0: 9, 1: 49}, {0: 19, 1: 39}, {0: 18, 1: 40}, {0: 58, 1: 0}, {0: 58, 1: 0}, {0: 12, 1: 46}, {0: 13, 1: 45}, {0: 12, 1: 46}, {0: 14, 1: 44}, {0: 25, 1: 33}, {0: 14, 1: 44}, {0: 38, 1: 20}, {0: 76, 1: 0}]\n",
+ "Buckets: [15820 15877 15934 15991 16048 16105 16162 16219 16276 16333 16390 16447\n",
+ " 16504 16561 16618 16675 16732 16789 16846 16903 16960]\n",
+ "bucket_counts: [{0: 48, 1: 9}, {0: 38, 1: 19}, {0: 45, 1: 12}, {0: 19, 1: 38}, {0: 7, 1: 50}, {0: 32, 1: 25}, {0: 22, 1: 35}, {0: 15, 1: 42}, {0: 14, 1: 43}, {0: 16, 1: 41}, {0: 44, 1: 13}, {0: 57, 1: 0}, {0: 21, 1: 36}, {0: 16, 1: 41}, {0: 15, 1: 42}, {0: 13, 1: 44}, {0: 25, 1: 32}, {0: 23, 1: 34}, {0: 41, 1: 16}, {0: 61, 1: 0}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "## Specify all subjects to compute the chronological folds for.\n",
+ "## By default we have the held-out sessions (val/test) listed here.\n",
+ "ALL_SUBJECTS = [\n",
+ "    \"HOLDSUBJ_1_HS1_1\",\n",
+ "    \"HOLDSUBJ_2_HS2_6\",\n",
+ "    \"HOLDSUBJ_3_HS3_0\",\n",
+ "    \"HOLDSUBJ_4_HS4_0\",\n",
+ "    \"HOLDSUBJ_6_HS6_4\",\n",
+ "    \"HOLDSUBJ_7_HS7_0\",\n",
+ "    \"HOLDSUBJ_10_HS10_0\",\n",
+ "\n",
+ "    # \"SUBJ_2_S2_5\",\n",
+ "    # \"SUBJ_4_S4_2\",\n",
+ "]\n",
+ "\n",
+ "## List all the metadata files that correspond to the segments to preprocess. Keyword\n",
+ "## identifiers can optionally be used for each of the metadata files to be processed.\n",
+ "_METADATA_FNAMES = {\n",
+ "    'default_metadata': 'metadata_ee8e0.csv',\n",
+ "}\n",
+ "\n",
+ "## List all experiments for which the folds should be computed.\n",
+ "# _ALL_EXPERIMENTS = [\"sentence_onset_time\", \"speech_vs_nonspeech_time\", \"volume\", \"optical_flow\"]\n",
+ "_ALL_EXPERIMENTS = [\"sentence_onset_time\", \"speech_vs_nonspeech_time\"]\n",
+ "\n",
+ "_SEGMENT_DIR = 'braintreebank_data_segments/{0}'\n",
+ "\n",
+ "## These are the recommended default settings for computing the folds.\n",
+ "bucket_size = 0.05  # Each bucket is 5% of the duration, in samples.\n",
+ "base_step_size = 1  # We take increments of base_step_size * 5% in samples when constructing buckets.\n",
+ "base_window = 1  # Count samples per base_window * 5% interval per bucket. Should ideally match base_step_size.\n",
+ "step_size = 2  # We take increments of step_size * bucket_size (5%) when looking for buckets.\n",
+ "window = 4  # Targeting 20% of the data for val and test (i.e., 4 buckets combined for val and test).\n",
+ "num_folds = 5  # Number of folds to generate.\n",
+ "\n",
+ "subject_folds = {}\n",
+ "for metadata_setting in _METADATA_FNAMES.keys():\n",
+ "    metadata_setting_folds = defaultdict(dict)\n",
+ "\n",
+ "    for subject_session in ALL_SUBJECTS:\n",
+ "        for experiment in _ALL_EXPERIMENTS:\n",
+ "\n",
+ "            fpath = _SEGMENT_DIR.format(experiment)\n",
+ "            metadata_fname = _METADATA_FNAMES[metadata_setting]\n",
+ "            metadata = load_metadata(os.path.join(fpath, metadata_fname))\n",
+ "\n",
+ "            subject_rows_indices = metadata.get_indices_matching_cols_values(\n",
+ "                [\"subject_session\", \"experiment\"], [subject_session, experiment]\n",
+ "            )\n",
+ "\n",
+ "            per_label_subject_rows_indices = [0, 0]\n",
+ "            for i in range(2):  # 2 = negative/positive labels.\n",
+ "                per_label_subject_rows_indices[i] = (\n",
+ "                    metadata.get_indices_matching_cols_values(\n",
+ "                        [\"subject_session\", \"experiment\", \"label\"],\n",
+ "                        [subject_session, experiment, i],\n",
+ "                    )\n",
+ "                )\n",
+ "\n",
+ "            all_folds, all_folds_splits = generate_folds(\n",
+ "                subject_rows_indices,\n",
+ "                per_label_subject_rows_indices,\n",
260
+ " bucket_size,\n",
261
+ " step_size,\n",
262
+ " base_step_size,\n",
263
+ " window,\n",
264
+ " base_window,\n",
265
+ " num_folds=num_folds\n",
266
+ " )\n",
267
+ "\n",
268
+ " metadata_setting_folds[subject_session][experiment] = (all_folds, all_folds_splits)\n",
269
+ "\n",
270
+ " subject_folds[metadata_setting] = metadata_setting_folds"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": 6,
276
+ "id": "aacfb210",
277
+ "metadata": {},
278
+ "outputs": [
279
+ {
280
+ "name": "stdout",
281
+ "output_type": "stream",
282
+ "text": [
283
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_1_HS1_1, experiment:sentence_onset_time\n",
284
+ "\n",
285
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
286
+ "Split statistics: {'train': Counter({1: 1252, 0: 1218}), 'val': Counter({1: 198, 0: 110}), 'test': Counter({0: 215, 1: 93})}\n",
287
+ "Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
288
+ "Split statistics: {'train': Counter({1: 1371, 0: 1097}), 'val': Counter({0: 228, 1: 82}), 'test': Counter({0: 218, 1: 90})}\n",
289
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
290
+ "Split statistics: {'train': Counter({0: 1296, 1: 1174}), 'val': Counter({1: 202, 0: 106}), 'test': Counter({1: 167, 0: 141})}\n",
291
+ "Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
292
+ "Split statistics: {'train': Counter({1: 1262, 0: 1208}), 'val': Counter({0: 178, 1: 130}), 'test': Counter({0: 157, 1: 151})}\n",
293
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
294
+ "Split statistics: {'train': Counter({0: 1355, 1: 1115}), 'val': Counter({1: 175, 0: 133}), 'test': Counter({1: 253, 0: 55})}\n",
295
+ "\n",
296
+ "\n",
297
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_2_HS2_6, experiment:sentence_onset_time\n",
298
+ "\n",
299
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
300
+ "Split statistics: {'train': Counter({1: 905, 0: 722}), 'val': Counter({0: 179, 1: 23}), 'test': Counter({0: 115, 1: 88})}\n",
301
+ "Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
302
+ "Split statistics: {'train': Counter({1: 860, 0: 765}), 'val': Counter({0: 140, 1: 64}), 'test': Counter({0: 111, 1: 92})}\n",
303
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
304
+ "Split statistics: {'train': Counter({0: 834, 1: 793}), 'val': Counter({0: 133, 1: 69}), 'test': Counter({1: 154, 0: 49})}\n",
305
+ "Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
306
+ "Split statistics: {'train': Counter({0: 856, 1: 771}), 'val': Counter({1: 146, 0: 56}), 'test': Counter({0: 104, 1: 99})}\n",
307
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
308
+ "Split statistics: {'train': Counter({0: 890, 1: 737}), 'val': Counter({1: 146, 0: 56}), 'test': Counter({1: 133, 0: 70})}\n",
309
+ "\n",
310
+ "\n",
311
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_3_HS3_0, experiment:sentence_onset_time\n",
312
+ "\n",
313
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
314
+ "Split statistics: {'train': Counter({1: 618, 0: 449}), 'val': Counter({0: 111, 1: 21}), 'test': Counter({0: 106, 1: 27})}\n",
315
+ "Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
316
+ "Split statistics: {'train': Counter({1: 552, 0: 513}), 'val': Counter({0: 101, 1: 33}), 'test': Counter({1: 81, 0: 52})}\n",
317
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
318
+ "Split statistics: {'train': Counter({0: 586, 1: 481}), 'val': Counter({1: 104, 0: 28}), 'test': Counter({1: 81, 0: 52})}\n",
319
+ "Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
320
+ "Split statistics: {'train': Counter({1: 539, 0: 528}), 'val': Counter({1: 76, 0: 56}), 'test': Counter({0: 82, 1: 51})}\n",
321
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
322
+ "Split statistics: {'train': Counter({0: 589, 1: 478}), 'val': Counter({1: 92, 0: 40}), 'test': Counter({1: 96, 0: 37})}\n",
323
+ "\n",
324
+ "\n",
325
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_4_HS4_0, experiment:sentence_onset_time\n",
326
+ "\n",
327
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
328
+ "Split statistics: {'train': Counter({1: 754, 0: 523}), 'val': Counter({0: 130, 1: 28}), 'test': Counter({0: 144, 1: 15})}\n",
329
+ "Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
330
+ "Split statistics: {'train': Counter({1: 689, 0: 586}), 'val': Counter({0: 125, 1: 35}), 'test': Counter({0: 86, 1: 73})}\n",
331
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
332
+ "Split statistics: {'train': Counter({0: 680, 1: 597}), 'val': Counter({1: 119, 0: 39}), 'test': Counter({1: 81, 0: 78})}\n",
333
+ "Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
334
+ "Split statistics: {'train': Counter({0: 710, 1: 567}), 'val': Counter({1: 102, 0: 56}), 'test': Counter({1: 128, 0: 31})}\n",
335
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
336
+ "Split statistics: {'train': Counter({0: 691, 1: 586}), 'val': Counter({1: 98, 0: 60}), 'test': Counter({1: 113, 0: 46})}\n",
337
+ "\n",
338
+ "\n",
339
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_6_HS6_4, experiment:sentence_onset_time\n",
340
+ "\n",
341
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
342
+ "Split statistics: {'train': Counter({1: 459, 0: 365}), 'val': Counter({0: 55, 1: 47}), 'test': Counter({0: 94, 1: 8})}\n",
343
+ "Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
344
+ "Split statistics: {'train': Counter({0: 448, 1: 374}), 'val': Counter({1: 70, 0: 34}), 'test': Counter({1: 70, 0: 32})}\n",
345
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
346
+ "Split statistics: {'train': Counter({0: 458, 1: 366}), 'val': Counter({1: 79, 0: 23}), 'test': Counter({1: 69, 0: 33})}\n",
347
+ "Run_ratio: [0.5, 0.05, 0.1, 0.05, 0.29999999999999993]\n",
348
+ "Split statistics: {'train': Counter({0: 427, 1: 397}), 'val': Counter({0: 59, 1: 43}), 'test': Counter({1: 74, 0: 28})}\n",
349
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
350
+ "Split statistics: {'train': Counter({0: 427, 1: 397}), 'val': Counter({1: 62, 0: 40}), 'test': Counter({1: 55, 0: 47})}\n",
351
+ "\n",
352
+ "\n",
353
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_7_HS7_0, experiment:sentence_onset_time\n",
354
+ "\n",
355
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
356
+ "Split statistics: {'train': Counter({1: 358, 0: 293}), 'val': Counter({0: 44, 1: 36}), 'test': Counter({0: 69, 1: 12})}\n",
357
+ "Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
358
+ "Split statistics: {'train': Counter({0: 330, 1: 319}), 'val': Counter({0: 45, 1: 37}), 'test': Counter({1: 50, 0: 31})}\n",
359
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
360
+ "Split statistics: {'train': Counter({0: 337, 1: 314}), 'val': Counter({1: 50, 0: 30}), 'test': Counter({1: 42, 0: 39})}\n",
361
+ "Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
362
+ "Split statistics: {'train': Counter({0: 333, 1: 318}), 'val': Counter({1: 43, 0: 37}), 'test': Counter({1: 45, 0: 36})}\n",
363
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
364
+ "Split statistics: {'train': Counter({0: 331, 1: 320}), 'val': Counter({0: 41, 1: 39}), 'test': Counter({1: 47, 0: 34})}\n",
365
+ "\n",
366
+ "\n",
367
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_10_HS10_0, experiment:sentence_onset_time\n",
368
+ "\n",
369
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
370
+ "Split statistics: {'train': Counter({1: 510, 0: 435}), 'val': Counter({0: 70, 1: 46}), 'test': Counter({0: 84, 1: 33})}\n",
371
+ "Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
372
+ "Split statistics: {'train': Counter({1: 495, 0: 447}), 'val': Counter({0: 76, 1: 43}), 'test': Counter({0: 66, 1: 51})}\n",
373
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
374
+ "Split statistics: {'train': Counter({0: 511, 1: 434}), 'val': Counter({1: 80, 0: 36}), 'test': Counter({1: 75, 0: 42})}\n",
375
+ "Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
376
+ "Split statistics: {'train': Counter({0: 534, 1: 411}), 'val': Counter({1: 89, 0: 27}), 'test': Counter({1: 89, 0: 28})}\n",
377
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
378
+ "Split statistics: {'train': Counter({1: 509, 0: 436}), 'val': Counter({0: 74, 1: 42}), 'test': Counter({0: 79, 1: 38})}\n",
379
+ "\n",
380
+ "\n",
381
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_1_HS1_1, experiment:speech_vs_nonspeech_time\n",
382
+ "\n",
383
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
384
+ "Split statistics: {'train': Counter({1: 1333, 0: 1307}), 'val': Counter({1: 220, 0: 110}), 'test': Counter({0: 233, 1: 97})}\n",
385
+ "Run_ratio: [0.75, 0.05, 0.1, 0.05, 0.04999999999999993]\n",
386
+ "Split statistics: {'train': Counter({1: 1431, 0: 1209}), 'val': Counter({0: 248, 1: 82}), 'test': Counter({0: 193, 1: 137})}\n",
387
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
388
+ "Split statistics: {'train': Counter({0: 1457, 1: 1183}), 'val': Counter({1: 263, 0: 67}), 'test': Counter({1: 204, 0: 126})}\n",
389
+ "Run_ratio: [0.55, 0.05, 0.1, 0.05, 0.25]\n",
390
+ "Split statistics: {'train': Counter({0: 1335, 1: 1305}), 'val': Counter({1: 231, 0: 99}), 'test': Counter({0: 216, 1: 114})}\n",
391
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
392
+ "Split statistics: {'train': Counter({0: 1431, 1: 1209}), 'val': Counter({1: 189, 0: 141}), 'test': Counter({1: 252, 0: 78})}\n",
393
+ "\n",
394
+ "\n",
395
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_2_HS2_6, experiment:speech_vs_nonspeech_time\n",
396
+ "\n",
397
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
398
+ "Split statistics: {'train': Counter({1: 2573, 0: 2048}), 'val': Counter({0: 519, 1: 57}), 'test': Counter({0: 320, 1: 257})}\n",
399
+ "Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
400
+ "Split statistics: {'train': Counter({1: 2511, 0: 2108}), 'val': Counter({0: 396, 1: 182}), 'test': Counter({0: 383, 1: 194})}\n",
401
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
402
+ "Split statistics: {'train': Counter({0: 2399, 1: 2222}), 'val': Counter({0: 329, 1: 247}), 'test': Counter({1: 418, 0: 159})}\n",
403
+ "Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
404
+ "Split statistics: {'train': Counter({0: 2431, 1: 2190}), 'val': Counter({1: 434, 0: 142}), 'test': Counter({0: 314, 1: 263})}\n",
405
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
406
+ "Split statistics: {'train': Counter({0: 2565, 1: 2056}), 'val': Counter({1: 418, 0: 158}), 'test': Counter({1: 413, 0: 164})}\n",
407
+ "\n",
408
+ "\n",
409
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_3_HS3_0, experiment:speech_vs_nonspeech_time\n",
410
+ "\n",
411
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
412
+ "Split statistics: {'train': Counter({1: 630, 0: 443}), 'val': Counter({0: 125, 1: 7}), 'test': Counter({0: 101, 1: 32})}\n",
413
+ "Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
414
+ "Split statistics: {'train': Counter({1: 555, 0: 515}), 'val': Counter({0: 105, 1: 30}), 'test': Counter({1: 84, 0: 49})}\n",
415
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
416
+ "Split statistics: {'train': Counter({0: 588, 1: 485}), 'val': Counter({1: 95, 0: 37}), 'test': Counter({1: 89, 0: 44})}\n",
417
+ "Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
418
+ "Split statistics: {'train': Counter({1: 544, 0: 529}), 'val': Counter({1: 85, 0: 47}), 'test': Counter({0: 93, 1: 40})}\n",
419
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
420
+ "Split statistics: {'train': Counter({0: 604, 1: 469}), 'val': Counter({1: 100, 0: 32}), 'test': Counter({1: 100, 0: 33})}\n",
421
+ "\n",
422
+ "\n",
423
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_4_HS4_0, experiment:speech_vs_nonspeech_time\n",
424
+ "\n",
425
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
426
+ "Split statistics: {'train': Counter({1: 872, 0: 693}), 'val': Counter({0: 122, 1: 72}), 'test': Counter({0: 162, 1: 33})}\n",
427
+ "Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
428
+ "Split statistics: {'train': Counter({1: 805, 0: 758}), 'val': Counter({0: 146, 1: 50}), 'test': Counter({1: 122, 0: 73})}\n",
429
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
430
+ "Split statistics: {'train': Counter({0: 863, 1: 702}), 'val': Counter({1: 125, 0: 69}), 'test': Counter({1: 150, 0: 45})}\n",
431
+ "Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
432
+ "Split statistics: {'train': Counter({0: 883, 1: 682}), 'val': Counter({1: 132, 0: 62}), 'test': Counter({1: 163, 0: 32})}\n",
433
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
434
+ "Split statistics: {'train': Counter({1: 852, 0: 713}), 'val': Counter({0: 102, 1: 92}), 'test': Counter({0: 162, 1: 33})}\n",
435
+ "\n",
436
+ "\n",
437
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_6_HS6_4, experiment:speech_vs_nonspeech_time\n",
438
+ "\n",
439
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
440
+ "Split statistics: {'train': Counter({1: 1153, 0: 988}), 'val': Counter({0: 142, 1: 124}), 'test': Counter({0: 207, 1: 60})}\n",
441
+ "Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
442
+ "Split statistics: {'train': Counter({0: 1168, 1: 971}), 'val': Counter({1: 165, 0: 103}), 'test': Counter({1: 201, 0: 66})}\n",
443
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
444
+ "Split statistics: {'train': Counter({0: 1149, 1: 992}), 'val': Counter({1: 162, 0: 104}), 'test': Counter({1: 183, 0: 84})}\n",
445
+ "Run_ratio: [0.5, 0.05, 0.1, 0.05, 0.29999999999999993]\n",
446
+ "Split statistics: {'train': Counter({0: 1089, 1: 1052}), 'val': Counter({0: 157, 1: 109}), 'test': Counter({1: 176, 0: 91})}\n",
447
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
448
+ "Split statistics: {'train': Counter({0: 1095, 1: 1046}), 'val': Counter({1: 178, 0: 88}), 'test': Counter({0: 154, 1: 113})}\n",
449
+ "\n",
450
+ "\n",
451
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_7_HS7_0, experiment:speech_vs_nonspeech_time\n",
452
+ "\n",
453
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
454
+ "Split statistics: {'train': Counter({1: 349, 0: 275}), 'val': Counter({0: 52, 1: 26}), 'test': Counter({0: 63, 1: 15})}\n",
455
+ "Run_ratio: [0.75, 0.05, 0.1, 0.05, 0.04999999999999993]\n",
456
+ "Split statistics: {'train': Counter({0: 313, 1: 311}), 'val': Counter({1: 48, 0: 30}), 'test': Counter({0: 47, 1: 31})}\n",
457
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
458
+ "Split statistics: {'train': Counter({0: 322, 1: 302}), 'val': Counter({1: 54, 0: 24}), 'test': Counter({0: 44, 1: 34})}\n",
459
+ "Run_ratio: [0.55, 0.05, 0.1, 0.05, 0.25]\n",
460
+ "Split statistics: {'train': Counter({0: 331, 1: 293}), 'val': Counter({1: 39, 0: 39}), 'test': Counter({1: 58, 0: 20})}\n",
461
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
462
+ "Split statistics: {'train': Counter({0: 324, 1: 300}), 'val': Counter({1: 45, 0: 33}), 'test': Counter({1: 45, 0: 33})}\n",
463
+ "\n",
464
+ "\n",
465
+ "metadata_setting:default_metadata, subject_session:HOLDSUBJ_10_HS10_0, experiment:speech_vs_nonspeech_time\n",
466
+ "\n",
467
+ "Run_ratio: [0.05, 0.1, 0.05, 0.8]\n",
468
+ "Split statistics: {'train': Counter({1: 494, 0: 422}), 'val': Counter({0: 67, 1: 47}), 'test': Counter({0: 83, 1: 31})}\n",
469
+ "Run_ratio: [0.8, 0.05, 0.1, 0.05]\n",
470
+ "Split statistics: {'train': Counter({1: 493, 0: 422}), 'val': Counter({0: 83, 1: 32}), 'test': Counter({0: 67, 1: 47})}\n",
471
+ "Run_ratio: [0.2, 0.05, 0.1, 0.05, 0.6]\n",
472
+ "Split statistics: {'train': Counter({0: 496, 1: 420}), 'val': Counter({1: 92, 0: 22}), 'test': Counter({1: 60, 0: 54})}\n",
473
+ "Run_ratio: [0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996]\n",
474
+ "Split statistics: {'train': Counter({0: 509, 1: 407}), 'val': Counter({1: 82, 0: 32}), 'test': Counter({1: 83, 0: 31})}\n",
475
+ "Run_ratio: [0.4, 0.05, 0.1, 0.05, 0.3999999999999999]\n",
476
+ "Split statistics: {'train': Counter({1: 476, 0: 440}), 'val': Counter({0: 71, 1: 43}), 'test': Counter({0: 61, 1: 53})}\n",
477
+ "\n",
478
+ "\n"
479
+ ]
480
+ }
481
+ ],
482
+ "source": [
483
+ "## Following code will compute the statistics associated with each fold.\n",
484
+ "all_output_dicts = {}\n",
485
+ "for metadata_setting in _METADATA_FNAMES.keys():\n",
486
+ " output_dict = {} # {experiment_name: {subject_session: [(ratio1, split1), (ratio2, split2), ...]}}\n",
487
+ " for experiment in _ALL_EXPERIMENTS:\n",
488
+ " output_dict[experiment] = {}\n",
489
+ "\n",
490
+ " fpath = _SEGMENT_DIR.format(experiment)\n",
491
+ " metadata_fname = _METADATA_FNAMES[metadata_setting]\n",
492
+ " metadata = load_metadata(os.path.join(fpath, metadata_fname))\n",
493
+ "\n",
494
+ " for subject_session in ALL_SUBJECTS:\n",
495
+ " print(\n",
496
+ " f'metadata_setting:{metadata_setting}, '\n",
497
+ " f'subject_session:{subject_session}, '\n",
498
+ " f'experiment:{experiment}\\n'\n",
499
+ " )\n",
500
+ "\n",
501
+ " subject_rows_indices = metadata.get_indices_matching_cols_values(\n",
502
+ " [\"subject_session\", \"experiment\"], [subject_session, experiment]\n",
503
+ " )\n",
504
+ " n_segments = len(subject_rows_indices)\n",
505
+ "\n",
506
+ " folds, splits = subject_folds[metadata_setting][subject_session][experiment]\n",
507
+ " out_tuples = []\n",
508
+ " for run_ratio, run_splits in zip(folds, splits):\n",
509
+ " counts = (np.array(run_ratio) * n_segments).astype(int)\n",
510
+ " counts[-1] = n_segments - sum(counts[:-1])\n",
511
+ "\n",
512
+ " print(f\"Run_ratio: {run_ratio}\")\n",
513
+ "\n",
514
+ " agg_split_counts = {'train': Counter(), 'val': Counter(), 'test': Counter()}\n",
515
+ " sum_now = 0\n",
516
+ " for c, split in zip(counts, run_splits):\n",
517
+ " label_split_indices = subject_rows_indices[sum_now : sum_now + c]\n",
518
+ " sum_now += c\n",
519
+ " agg_split_counts[split].update(\n",
520
+ " metadata._df.iloc[label_split_indices].label.to_numpy()\n",
521
+ " )\n",
522
+ "\n",
523
+ " print(f'Split statistics: {agg_split_counts}')\n",
524
+ " out_tuples.append((run_ratio, run_splits))\n",
525
+ " print('\\n')\n",
526
+ "\n",
527
+ " output_dict[experiment][subject_session] = out_tuples\n",
528
+ "\n",
529
+ " all_output_dicts[metadata_setting] = output_dict"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "code",
534
+ "execution_count": 7,
535
+ "id": "22372166",
536
+ "metadata": {},
537
+ "outputs": [
538
+ {
539
+ "name": "stdout",
540
+ "output_type": "stream",
541
+ "text": [
542
+ "/data/seyedesa/njepa/public_release_test/data_nov30_15_00/sentence_onset_time\n",
543
+ "/data/seyedesa/njepa/public_release_test/data_nov30_15_00/speech_vs_nonspeech_time\n"
544
+ ]
545
+ }
546
+ ],
547
+ "source": [
548
+ "## Save out the data in the format expected in braintreebank_dataset.py.\n",
549
+ "import pickle\n",
550
+ "\n",
551
+ "for fb_setting, fb_setting_output in all_output_dicts.items():\n",
552
+ " out_fname = f\"{Path(_METADATA_FNAMES[fb_setting]).stem}_folds.pkl\"\n",
553
+ "\n",
554
+ " for experiment, experiment_output in fb_setting_output.items():\n",
555
+ " out_path = _SEGMENT_DIR.format(experiment)\n",
556
+ " print(out_path)\n",
557
+ " with open(os.path.join(out_path, out_fname), 'wb') as file:\n",
558
+ " pickle.dump(experiment_output, file)"
559
+ ]
560
+ },
561
+ {
562
+ "cell_type": "code",
563
+ "execution_count": 8,
564
+ "id": "1e6e1189",
565
+ "metadata": {},
566
+ "outputs": [
567
+ {
568
+ "name": "stdout",
569
+ "output_type": "stream",
570
+ "text": [
571
+ "sentence_onset_time\n",
572
+ "{'HOLDSUBJ_1_HS1_1': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_2_HS2_6': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_3_HS3_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_4_HS4_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_6_HS6_4': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.5, 0.05, 0.1, 0.05, 0.29999999999999993], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 
'val', 'train'])], 'HOLDSUBJ_7_HS7_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_10_HS10_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])]}\n",
573
+ "\n",
574
+ "\n",
575
+ "speech_vs_nonspeech_time\n",
576
+ "{'HOLDSUBJ_1_HS1_1': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.75, 0.05, 0.1, 0.05, 0.04999999999999993], ['train', 'val', 'test', 'val', 'train']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.55, 0.05, 0.1, 0.05, 0.25], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_2_HS2_6': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_3_HS3_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_4_HS4_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_6_HS6_4': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.5, 0.05, 0.1, 0.05, 0.29999999999999993], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 
'val', 'train'])], 'HOLDSUBJ_7_HS7_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.75, 0.05, 0.1, 0.05, 0.04999999999999993], ['train', 'val', 'test', 'val', 'train']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.55, 0.05, 0.1, 0.05, 0.25], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])], 'HOLDSUBJ_10_HS10_0': [([0.05, 0.1, 0.05, 0.8], ['val', 'test', 'val', 'train']), ([0.8, 0.05, 0.1, 0.05], ['train', 'val', 'test', 'val']), ([0.2, 0.05, 0.1, 0.05, 0.6], ['train', 'val', 'test', 'val', 'train']), ([0.6000000000000001, 0.05, 0.1, 0.05, 0.19999999999999996], ['train', 'val', 'test', 'val', 'train']), ([0.4, 0.05, 0.1, 0.05, 0.3999999999999999], ['train', 'val', 'test', 'val', 'train'])]}\n",
577
+ "\n",
578
+ "\n"
579
+ ]
580
+ }
581
+ ],
582
+ "source": [
583
+ "## Checking output was correct.\n",
584
+ "for fb_setting, fb_setting_output in all_output_dicts.items():\n",
585
+ " out_fname = f\"{Path(_METADATA_FNAMES[fb_setting]).stem}_folds.pkl\"\n",
586
+ "\n",
587
+ " for experiment, experiment_output in fb_setting_output.items():\n",
588
+ " out_path = _SEGMENT_DIR.format(experiment)\n",
589
+ " with open(os.path.join(out_path, out_fname), 'rb') as file:\n",
590
+ " datatmp = pickle.load(file)\n",
591
+ " print(experiment)\n",
592
+ " print(datatmp)\n",
593
+ " print('\\n')"
594
+ ]
595
+ },
596
+ {
597
+ "cell_type": "code",
598
+ "execution_count": null,
599
+ "id": "77052232",
600
+ "metadata": {},
601
+ "outputs": [],
602
+ "source": []
603
+ }
604
+ ],
605
+ "metadata": {
606
+ "kernelspec": {
607
+ "display_name": "venv",
608
+ "language": "python",
609
+ "name": "python3"
610
+ },
611
+ "language_info": {
612
+ "codemirror_mode": {
613
+ "name": "ipython",
614
+ "version": 3
615
+ },
616
+ "file_extension": ".py",
617
+ "mimetype": "text/x-python",
618
+ "name": "python",
619
+ "nbconvert_exporter": "python",
620
+ "pygments_lexer": "ipython3",
621
+ "version": "3.8.10"
622
+ }
623
+ },
624
+ "nbformat": 4,
625
+ "nbformat_minor": 5
626
+ }
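The fold ratios printed above (e.g., `[0.05, 0.1, 0.05, 0.8]` paired with `['val', 'test', 'val', 'train']`) describe contiguous chronological chunks of each subject's segments. A minimal standalone sketch of that partitioning logic — an illustration mirroring the counting in the statistics cell, not the actual `generate_folds` implementation:

```python
def split_indices(indices, run_ratio, run_splits):
    """Partition a chronologically ordered index list into contiguous
    train/val/test chunks according to the given fold ratios.

    Each ratio is converted to an integer count; the final chunk absorbs
    any rounding remainder, matching the statistics cell above.
    """
    n = len(indices)
    counts = [int(r * n) for r in run_ratio]
    counts[-1] = n - sum(counts[:-1])  # last chunk absorbs rounding error

    out = {"train": [], "val": [], "test": []}
    start = 0
    for c, split in zip(counts, run_splits):
        out[split].extend(indices[start : start + c])
        start += c
    return out

# One of the default five folds: 5% val, 10% test, 5% val, 80% train.
parts = split_indices(list(range(100)), [0.05, 0.1, 0.05, 0.8],
                      ["val", "test", "val", "train"])
sizes = {k: len(v) for k, v in parts.items()}
print(sizes)  # {'train': 80, 'val': 10, 'test': 10}
```

Note that val and test stay temporally contiguous within the recording, so chronological structure is preserved in the held-out splits.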
barista/models/TSEncoder2D.py ADDED
@@ -0,0 +1,213 @@
1
+ ## Source code based on publicly released dilated CNN models as found in
2
+ ## SimTS model: https://github.com/xingyu617/SimTS_Representation_Learning/blob/main/models/dilation.py
3
+ ## and
4
+ ## TS2Vec repo: https://github.com/zhihanyue/ts2vec/blob/main/models/dilated_conv.py
5
+
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ from torch.utils.checkpoint import checkpoint
10
+
11
+
12
+ def init_weights(m):
13
+ """
14
+ Relevant reading material:
15
+ https://pytorch.org/docs/stable/nn.init.html
16
+ https://github.com/pytorch/vision/blob/309bd7a1512ad9ff0e9729fbdad043cb3472e4cb/torchvision/models/densenet.py#L203
17
+ """
18
+ if isinstance(m, nn.Conv2d):
19
+ nn.init.kaiming_normal_(m.weight)
20
+ m.bias.data.fill_(0.0)
21
+ elif isinstance(m, nn.Linear):
22
+ nn.init.constant_(m.bias, 0)
23
+
24
+
25
+ class SamePadConv(nn.Module):
26
+ def __init__(
27
+ self,
28
+ in_channels,
29
+ out_channels,
30
+ kernel_size,
31
+ stride=1,
32
+ dilation=1,
33
+ groups=1,
34
+ ):
35
+ """Padded convolution to ensure same sized input and output."""
36
+ super().__init__()
37
+ self.receptive_field = (kernel_size - 1) * dilation + 1
38
+ padding = self.receptive_field // 2
39
+ self.conv = nn.Conv2d(
40
+ in_channels,
41
+ out_channels,
42
+ (1, kernel_size),
43
+ padding=(0, padding),
44
+ stride=(1, stride),
45
+ dilation=(1, dilation),
46
+ groups=groups,
47
+ )
48
+
49
+ init_weights(self.conv)
50
+
51
+ self.remove = 1 if self.receptive_field % 2 == 0 else 0
52
+
53
+ def forward(self, x):
54
+ out = self.conv(x)
55
+ if self.remove > 0:
56
+ out = out[:, :, :, : -self.remove]
57
+ return out
58
+
59
+
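The same-length guarantee in `SamePadConv` follows from the standard convolution output-length formula, with one trailing sample trimmed when the receptive field is even. A quick standalone check of that arithmetic (pure Python, no torch needed; an illustration, not part of the module):

```python
def same_pad_output_len(length, kernel_size, dilation, stride=1):
    """Mirror SamePadConv's length arithmetic for the stride-1 case:
    out = floor((L + 2*padding - dilation*(kernel_size-1) - 1) / stride) + 1,
    then trim one sample when the receptive field is even."""
    receptive_field = (kernel_size - 1) * dilation + 1
    padding = receptive_field // 2
    out = (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    if receptive_field % 2 == 0:
        out -= 1  # matches the `self.remove` trim in forward()
    return out

# With stride 1 the output length matches the input for any kernel/dilation.
for k in (2, 3, 5):
    for d in (1, 2, 4, 8):
        assert same_pad_output_len(100, k, d) == 100
```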
60
+ class ConvBlock(nn.Module):
61
+ def __init__(
62
+ self,
63
+ in_channels,
64
+ out_channels,
65
+ kernel_size,
66
+ stride,
67
+ dilation,
68
+ final=False,
69
+ enable_checkpointing=False,
70
+ ):
71
+ """
72
+ Convolutional block implementation.
73
+
74
+ Consists of two convolution layers with a residual connection.
75
+
76
+ Args:
77
+ in_channels: int. Input channel count.
78
+ out_channels: int. Output channel count.
79
+ kernel_size: int. Convolution kernel size.
80
+ stride: int. Convolution stride size.
81
+ dilation: int. Convolution dilation amount.
82
+ final: bool. Whether this is the final convolutional block in the stack. Only relevant for
83
+ using a projection head for the residual stream.
84
+ enable_checkpointing: bool. Enable gradient checkpointing of intermediate activations if
85
+ desired. Default False.
86
+ """
87
+ super().__init__()
88
+
89
+ self.enable_checkpointing = enable_checkpointing
90
+
91
+ self.conv1 = SamePadConv(
92
+ in_channels,
93
+ out_channels,
94
+ kernel_size,
95
+ stride=stride,
96
+ dilation=dilation,
97
+ )
98
+
99
+ self.conv2 = SamePadConv(
100
+ out_channels,
101
+ out_channels,
102
+ kernel_size,
103
+ stride=stride,
104
+ dilation=dilation,
105
+ )
106
+
107
+ self.projector = (
108
+ nn.Conv2d(
109
+ in_channels, out_channels, kernel_size=(1, 1), stride=(1, stride**2),
110
+ )
111
+ if in_channels != out_channels or final or stride != 1
112
+ else None
113
+ )
114
+ if self.projector is not None:
115
+ init_weights(self.projector)
116
+
117
+ def _forward_mini_block(self, x: torch.tensor, block_num: int):
118
+ x = self.conv1(x) if block_num == 1 else self.conv2(x)
119
+ x = F.layer_norm(x, (x.shape[-1],))
120
+ x = F.gelu(x)
121
+ return x
122
+
123
+ def forward(self, x: torch.tensor):
124
+ residual = x if self.projector is None else self.projector(x)
125
+
126
+ if self.enable_checkpointing:
127
+ x = checkpoint(self._forward_mini_block, x, 1, use_reentrant=False)
128
+ x = checkpoint(self._forward_mini_block, x, 2, use_reentrant=False)
129
+ else:
130
+ x = self._forward_mini_block(x, block_num=1)
131
+ x = self._forward_mini_block(x, block_num=2)
132
+
133
+ return x + residual
134
+
135
+
136
+ class DilatedConvEncoder(nn.Module):
137
+ def __init__(
138
+ self,
139
+ in_channels,
140
+ channels,
141
+ kernel_size,
142
+ stride=1,
143
+ enable_checkpointing=False,
144
+ ):
145
+ """Dilated CNN implementation. See ConvBlock for argument definitions."""
146
+ super().__init__()
147
+
148
+ self.enable_checkpointing = enable_checkpointing
149
+
150
+ self.net = nn.ModuleList(
151
+ [
152
+ ConvBlock(
153
+ channels[i - 1] if i > 0 else in_channels,
154
+ channels[i],
155
+ kernel_size=kernel_size,
156
+ stride=stride,
157
+ dilation=2**i,
158
+ final=(i == len(channels) - 1),
159
+ enable_checkpointing=enable_checkpointing,
160
+ )
161
+ for i in range(len(channels))
162
+ ]
163
+ )
164
+
165
+ def forward(self, x: torch.tensor):
166
+ for layer in self.net:
167
+ x = layer(x)
168
+ return x
169
+
170
+
171
+ class TSEncoder2D(nn.Module):
172
+ def __init__(
173
+ self,
174
+ input_dims,
175
+ output_dims,
176
+ hidden_dims=64,
177
+ depth=10,
178
+ kernel_size=3,
179
+ stride=1,
180
+ enable_checkpointing=False,
181
+ ):
182
+ """
183
+ Original source implementation:
184
+ TS2Vec Encoder: https://github.com/zhihanyue/ts2vec/blob/main/models/encoder.py
185
+
186
+ See ConvBlock function for argument definitions.
187
+ """
188
+ super().__init__()
189
+ self.input_dims = input_dims
190
+ self.output_dims = output_dims
191
+ self.hidden_dims = hidden_dims
192
+ self.enable_checkpointing = enable_checkpointing
193
+
194
+ self.feature_extractor = DilatedConvEncoder(
195
+ input_dims,
196
+ [hidden_dims] * depth + [output_dims],
197
+ kernel_size=kernel_size,
198
+ stride=stride,
199
+ enable_checkpointing=self.enable_checkpointing,
200
+ )
201
+
202
+ def forward(self, x: torch.tensor):
203
+ """
204
+ Args:
205
+ x: torch.tensor of shape (1, 1, B * T * D, N) with time (N) along the last axis.
206
+ Note: the additional (1, 1) for the first two axies is to use 2D convs for
207
+ 1D convolution operations.
208
+ Note: B=Batch, T=Number of segments, D=Channels.
209
+
210
+ Returns:
211
+ Temporal encoded version of the input tensor of shape (1, 1, B * T * D, N)
212
+ """
213
+ return self.feature_extractor(x)
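The `SamePadConv` arithmetic above (receptive field `(kernel_size - 1) * dilation + 1`, half padding, and trimming one trailing element when the receptive field is even) can be checked with plain integer math. This is an illustrative sketch, not repo code; the function name is made up.

```python
# Sketch (not from the repo): SamePadConv's "same" padding arithmetic,
# verified with plain integers instead of torch.
def same_pad_output_len(n, kernel_size, dilation):
    """Output length of a stride-1 conv using SamePadConv's padding scheme."""
    receptive_field = (kernel_size - 1) * dilation + 1
    padding = receptive_field // 2
    out_len = n + 2 * padding - (receptive_field - 1)  # stride-1 conv output length
    # An even receptive field leaves the padded output one element too long,
    # which is why SamePadConv trims one trailing element.
    remove = 1 if receptive_field % 2 == 0 else 0
    return out_len - remove

# Length is preserved for odd and even receptive fields alike.
assert same_pad_output_len(100, kernel_size=3, dilation=1) == 100  # rf = 3 (odd)
assert same_pad_output_len(100, kernel_size=2, dilation=3) == 100  # rf = 4 (even)
```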
barista/models/mlp.py ADDED
@@ -0,0 +1,60 @@
+from typing import List, Optional
+
+import torch.nn as nn
+
+from barista.models.utils import get_activation_function
+
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        d_input: int,
+        d_out: int,
+        layer_list: Optional[List[int]] = None,
+        dropout: float = 0.1,
+        bias: bool = True,
+        use_first_dropout: bool = True,
+        use_final_dropout: bool = False,
+        use_final_activation: bool = False,
+        activation: str = "linear",
+        use_identity_stub: bool = True,
+        **kwargs
+    ):
+        super().__init__()
+
+        self.d_input = d_input
+        self.d_out = d_out
+        self.layer_list = layer_list
+        # Instantiate dropout once so it is disabled by `model.eval()`; a fresh
+        # nn.Dropout constructed inside forward() would always apply dropout.
+        self.dropout = nn.Dropout(dropout)
+        self.use_first_dropout = use_first_dropout
+        self.use_final_dropout = use_final_dropout
+        self.use_final_activation = use_final_activation
+        self.activation_fn = get_activation_function(activation)
+
+        current_dim = self.d_input
+        self.layers = nn.ModuleList()
+        if self.layer_list is not None:
+            for dim in self.layer_list:
+                self.layers.append(nn.Linear(current_dim, dim, bias=bias))
+                current_dim = dim
+        elif use_identity_stub:
+            self.layers.append(nn.Identity())
+
+        self.final_layer = nn.Linear(current_dim, self.d_out, bias=bias)
+
+    def forward(self, x, *args, **kwargs):
+        if self.use_first_dropout:
+            x = self.dropout(x)
+        for layer in self.layers:
+            x = layer(x)
+            x = self.activation_fn(x)
+            x = self.dropout(x)
+        x = self.final_layer(x)
+        if self.use_final_activation:
+            x = self.activation_fn(x)
+        if self.use_final_dropout:
+            x = self.dropout(x)
+        return x
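As a sanity check on how `d_input`, `layer_list`, and `d_out` determine the linear-layer shapes in an MLP like the one above, here is a small sketch (illustrative names, not repo code):

```python
# Sketch: the (in_features, out_features) chain an MLP builds from its arguments.
def linear_layer_shapes(d_input, layer_list, d_out):
    shapes = []
    current_dim = d_input
    for dim in (layer_list or []):        # hidden layers, if any
        shapes.append((current_dim, dim))
        current_dim = dim
    shapes.append((current_dim, d_out))   # final projection
    return shapes

# With no layer_list the model reduces to a single d_input -> d_out projection.
assert linear_layer_shapes(128, None, 10) == [(128, 10)]
assert linear_layer_shapes(128, [64, 32], 10) == [(128, 64), (64, 32), (32, 10)]
```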
barista/models/model.py ADDED
@@ -0,0 +1,68 @@
+from omegaconf import DictConfig
+import torch
+import torch.nn as nn
+from typing import List
+
+from barista.data.metadata import Metadata
+from barista.models.tokenizer import Tokenizer
+from barista.models.transformer import Transformer
+
+
+class Barista(nn.Module):
+    def __init__(self, model_config: DictConfig, metadata: Metadata, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.metadata = metadata
+
+        self.tokenizer = Tokenizer(
+            config=model_config.tokenizer,
+            metadata=self.metadata,
+        )
+
+        self.backbone = Transformer(
+            **model_config.backbone,
+        )
+
+        self.d_hidden = model_config.backbone.d_hidden
+
+        self.head = None
+
+    def create_downstream_head(self, n_chans, output_dim):
+        self.channel_weights = nn.Linear(
+            n_chans * self.tokenizer.num_subsegments,
+            1,
+            bias=False,
+        )
+        self.binary_classifier = nn.Linear(self.d_hidden, output_dim)
+
+    def get_latent_embeddings(self, x: torch.Tensor, subject_sessions: List):
+        # Get tokens
+        tokenized_x = self.tokenizer(x, subject_sessions, output_as_list=False)
+
+        # Pass through transformer
+        latents = self.backbone(
+            x=tokenized_x.tokens,
+            seq_lens=tokenized_x.seq_lens,
+            position_ids=tokenized_x.position_ids,
+        )
+
+        return latents
+
+    def forward(self, x: torch.Tensor, subject_sessions: List):
+        latents = self.get_latent_embeddings(x, subject_sessions)
+
+        # Pass through task head
+        batch_size = x[0].shape[0]
+        latents_reshaped = latents.reshape(batch_size, -1, latents.shape[-1])
+        x = self.channel_weights(latents_reshaped.permute(0, 2, 1)).squeeze(dim=-1)
+        x = self.binary_classifier(x)
+
+        return x
+
+    def get_task_params(self):
+        return [*self.channel_weights.named_parameters(), *self.binary_classifier.named_parameters()]
+
+    def get_upstream_params(self):
+        return [*self.tokenizer.named_parameters(), *self.backbone.named_parameters()]
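The task head above pools token latents with a learned per-token weight (`channel_weights`) before classification. A minimal sketch of that weighted pooling with plain lists (illustrative, not repo code):

```python
# Sketch: weighted pooling of token latents, as the bias-free channel_weights
# linear layer does over the (tokens) axis. latents: (tokens, hidden).
def weighted_pool(latents, weights):
    hidden = len(latents[0])
    pooled = [0.0] * hidden
    for w, token in zip(weights, latents):
        for j in range(hidden):
            pooled[j] += w * token[j]  # one scalar weight per token
    return pooled

latents = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 tokens, hidden dim 2
weights = [0.5, 0.25, 0.25]
assert weighted_pool(latents, weights) == [2.5, 3.5]
```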
barista/models/spatial_encoder.py ADDED
@@ -0,0 +1,276 @@
+from abc import ABC, abstractmethod
+import einops
+import torch
+import torch.nn as nn
+from typing import Optional
+
+
+class SpatialEncoderMeta:
+    def __init__(self, subject_session_spatial_groups=None):
+        """Metadata object with subject-session information for spatial encoding."""
+        self.subject_session_spatial_groups = subject_session_spatial_groups
+
+    @property
+    def num_region_info(self):
+        n_effective_components_across_sessions = set(
+            a.n_effective_components for a in self.subject_session_spatial_groups.values()
+        )
+
+        assert len(n_effective_components_across_sessions) == 1, (
+            "Doesn't support a variable number of effective components for different subject_sessions"
+        )
+
+        self._num_region_info = n_effective_components_across_sessions.pop()
+        return self._num_region_info
+
+    @property
+    def embedding_table_configs(self):
+        configs = {}
+        for i in range(self.num_region_info):
+            n_embeddings_for_components_set = set(
+                a.max_elements_for_component[i] for a in self.subject_session_spatial_groups.values()
+            )
+            padding_indices_set = set(
+                a.padding_indices[i] for a in self.subject_session_spatial_groups.values()
+            )
+
+            assert len(n_embeddings_for_components_set) == 1, (
+                "Doesn't support a variable number of max components for different subject_sessions; "
+                "change to use the max of values across subjects if it is not important."
+            )
+            assert len(padding_indices_set) == 1, (
+                "Doesn't support variable padding indices for different subject_sessions; "
+                "change to use the max of values across subjects if it is not important."
+            )
+
+            configs[i] = {
+                'num_embeddings': n_embeddings_for_components_set.pop(),
+                'padding_idx': padding_indices_set.pop()
+            }
+
+        return configs
+
+
+class BaseSpatialEncoder(ABC, nn.Module):
+    """Abstract class definition for spatial encoding modules.
+
+    Implement this interface to try new spatial encoding approaches in the tokenizer.
+    """
+    _SUBJ_SESH_QUERY_HASH_STR = "{0}_queryvec"
+
+    def __init__(
+        self,
+        dim_h: int,
+        spatial_encoder_meta: SpatialEncoderMeta,
+    ):
+        super().__init__()
+        self.dim_h = dim_h
+        self.spatial_encoder_meta = spatial_encoder_meta
+
+        self._construct_region_encoding_meta()
+
+    def _construct_region_encoding_meta(self):
+        """Constructs a hashmap of channel region information -> query vector for spatial encoding."""
+        for (
+            subject_session,
+            spatial_groups,
+        ) in self.spatial_encoder_meta.subject_session_spatial_groups.items():
+            query_vector = torch.tensor(
+                [tuple(map(int, e[:spatial_groups.n_effective_components])) for e in spatial_groups.group_components]
+            )
+
+            query_vector = self._transform_query_vector(query_vector)
+
+            self.register_buffer(
+                BaseSpatialEncoder._SUBJ_SESH_QUERY_HASH_STR.format(subject_session),
+                query_vector, persistent=False
+            )
+
+    def _transform_query_vector(self, query_vector: torch.Tensor):
+        return query_vector
+
+    def get_embedding_table_query_vector(self, subject_session: str) -> torch.Tensor:
+        return self._buffers[BaseSpatialEncoder._SUBJ_SESH_QUERY_HASH_STR.format(subject_session)].to(torch.long)
+
+    def update_for_new_sessions(self, new_subject_session_spatial_groups):
+        self.spatial_encoder_meta.subject_session_spatial_groups = new_subject_session_spatial_groups
+        self._construct_region_encoding_meta()
+        return []
+
+    @abstractmethod
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        pass
+
+    @abstractmethod
+    def _get_position_encoding(
+        self, x: torch.Tensor, subject_session: str
+    ) -> torch.Tensor:
+        pass
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        subject_session: str,
+        timepoints: int = 1,
+        mask: torch.Tensor = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: torch.Tensor of shape (B, T*R, D). Time-space interleaved tokens of dim D.
+
+        Returns:
+            A torch.Tensor of shape (B, T*R, D) that is the encoding corresponding to
+            the input token x.
+        """
+        session_PE = self._get_position_encoding(x, subject_session)
+        assert (
+            x.shape[-1] == session_PE.shape[-1]
+        ), f"Region dimension mismatch: {x.shape[-1]} vs {session_PE.shape[-1]}."
+
+        position_encoding = einops.repeat(
+            session_PE, "r d -> b (t r) d", b=x.shape[0], t=timepoints
+        )
+
+        if mask is not None:
+            position_encoding = position_encoding[:, mask, :]
+
+        assert (
+            x.shape == position_encoding.shape
+        ), "Output position encoding does not match in shape"
+        return position_encoding
+
+
+class EmbeddingTable(BaseSpatialEncoder):
+    def __init__(
+        self,
+        dim_h: int,
+        spatial_encoder_meta: SpatialEncoderMeta,
+        embedding_max_dim: Optional[float] = None,
+        embedding_init_scale: float = 1.0
+    ):
+        """A lookup table of different embeddings for different spatial fields."""
+        super().__init__(dim_h, spatial_encoder_meta)
+
+        # Create the embeddings.
+        self.subcomponent_embedding_info = self.spatial_encoder_meta.embedding_table_configs
+        subcomponent_dims = self._get_subcomponent_dims()
+
+        self.subcomponent_embeddings = nn.ModuleDict()
+        for (
+            subcomponent_ind,
+            subcomponent_config,
+        ) in self.subcomponent_embedding_info.items():
+            subcomponent_dim = subcomponent_dims[subcomponent_ind]
+
+            self.subcomponent_embeddings[str(subcomponent_ind)] = nn.Embedding(
+                subcomponent_config["num_embeddings"],
+                subcomponent_dim,
+                padding_idx=subcomponent_config["padding_idx"],
+                max_norm=embedding_max_dim,
+            )
+
+            self.init_weights_for_embeddings(
+                self.subcomponent_embeddings[str(subcomponent_ind)],
+                embedding_init_scale
+            )
+
+    @abstractmethod
+    def _get_subcomponent_dims(self):
+        raise NotImplementedError
+
+    def update_for_new_sessions(self, new_subject_session_spatial_groups):
+        """Adds new embedding-table rows based on new subject-session information."""
+        new_params = super().update_for_new_sessions(new_subject_session_spatial_groups)
+
+        subcomponent_embedding_info = self.spatial_encoder_meta.embedding_table_configs
+        for subcomponent_ind, subcomponent_config in subcomponent_embedding_info.items():
+            prev_embeddings = self.subcomponent_embeddings[str(subcomponent_ind)]
+            n_rows, subcomponent_dim = prev_embeddings.weight.shape
+
+            if subcomponent_config['num_embeddings'] == n_rows:
+                # No new embeddings need to be added.
+                continue
+
+            new_embeddings = torch.empty(
+                subcomponent_config['num_embeddings'] - n_rows,
+                subcomponent_dim,
+                device=prev_embeddings.weight.device
+            )
+            nn.init.normal_(new_embeddings)
+
+            new_data = torch.cat((prev_embeddings.weight.data, new_embeddings))
+
+            self.subcomponent_embeddings[str(subcomponent_ind)] = nn.Embedding(
+                subcomponent_config["num_embeddings"],
+                subcomponent_dim,
+                padding_idx=subcomponent_config["padding_idx"],
+            )
+            self.subcomponent_embeddings[str(subcomponent_ind)].weight.data = new_data
+
+        new_params.extend([n for n, _ in self.named_parameters()])
+
+        return new_params
+
+    def init_weights_for_embeddings(self, embedding_table: nn.Embedding, embedding_init_scale: float = 1.0):
+        nn.init.normal_(embedding_table.weight, std=embedding_init_scale)
+        embedding_table._fill_padding_idx_with_zero()
+
+    def _transform_query_vector(self, query_vector: torch.Tensor):
+        return query_vector.to(torch.float).T
+
+    def _get_position_encoding(
+        self, _: torch.Tensor, subject_session: str
+    ) -> torch.Tensor:
+        """Returns the encoding vector based on a subject-session query."""
+        session_region_query = self.get_embedding_table_query_vector(subject_session)
+        single_session_PE = self._encode(session_region_query)
+        return single_session_PE
+
+
+class EmbeddingTablePool(EmbeddingTable):
+    def _get_subcomponent_dims(self):
+        return {k: self.dim_h for k in self.subcomponent_embedding_info.keys()}
+
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: torch.Tensor of shape (B, T*R, D). Time-space interleaved tokens of dim D.
+
+        Returns:
+            A torch.Tensor of shape (B, T*R, D) that is the encoding corresponding to
+            the input token. If a token has multiple spatial fields, the encodings for
+            these fields are summed together before being returned (e.g.,
+            x,y,z LPI coordinates).
+        """
+        PE = torch.zeros((x.shape[0], x.shape[1], self.dim_h), device=x.device)
+        for subcomponent_ind in range(x.shape[0]):
+            subcomponent_x = x[subcomponent_ind, ...]
+            PE[subcomponent_ind, ...] = self.subcomponent_embeddings[
+                str(subcomponent_ind)
+            ](subcomponent_x)
+        return torch.sum(PE, dim=0)
+
+
+def create_spatial_encoder(
+    dim_h: int,
+    subject_session_spatial_groups=None,
+    embedding_max_dim=None,
+    embedding_init_scale=1.0,
+) -> BaseSpatialEncoder:
+    """Creates the spatial encoder and the cached spatial encoding information needed during forward passes."""
+    spatial_encoder_meta = SpatialEncoderMeta(subject_session_spatial_groups)
+
+    spatial_encoder = EmbeddingTablePool(
+        dim_h,
+        spatial_encoder_meta,
+        embedding_max_dim,
+        embedding_init_scale
+    )
+
+    return spatial_encoder
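`EmbeddingTablePool._encode` looks up one embedding per spatial subcomponent (e.g., binned x, y, z coordinates) and sums them into a single encoding. A minimal sketch of that pooling with plain dictionaries (illustrative names and tables, not repo code):

```python
# Sketch: summing per-component embedding lookups, as EmbeddingTablePool does.
def pooled_encoding(component_ids, tables):
    """component_ids: one integer ID per spatial component.
    tables: one {id: vector} lookup table per component.
    Returns the elementwise sum of the per-component embedding vectors."""
    dim = len(next(iter(tables[0].values())))
    pooled = [0.0] * dim
    for comp_ind, comp_id in enumerate(component_ids):
        vec = tables[comp_ind][comp_id]
        pooled = [p + v for p, v in zip(pooled, vec)]
    return pooled

# Two components with 2-dim embeddings: a token's encoding is the sum of its
# x-table and y-table rows.
x_table = {0: [1.0, 0.0], 1: [0.5, 0.5]}
y_table = {0: [0.0, 1.0], 1: [0.25, 0.25]}
assert pooled_encoding([1, 0], [x_table, y_table]) == [0.5, 1.5]
```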
barista/models/tokenized_batched_item.py ADDED
@@ -0,0 +1,132 @@
+import dataclasses
+import einops
+import torch
+from typing import List, Optional
+
+
+@dataclasses.dataclass
+class TokenizedBatchedItem:
+    """
+    tokens: (B_i, N, D)
+    position_ids: (B_i, N)
+    temporal_group_ids: (B_i, N)
+    spatial_group_ids: (B_i, N)
+    seq_lens: List[int]
+    spatial_embeddings: (B_i, N, D)
+
+    NOTE: Assumption: either seq_lens has length one, or B_i is one, i.e. we either
+    have a batched tensor or a list of single tensors.
+    """
+    tokens: torch.Tensor
+    position_ids: torch.Tensor
+    seq_lens: List[int]
+    spatial_embeddings: Optional[torch.Tensor]
+    temporal_group_ids: Optional[torch.Tensor]
+    spatial_group_ids: Optional[torch.Tensor]
+    subject_sessions: List[str]
+
+    @classmethod
+    def get_as_one_sequence(
+        cls, tokenized_items_list: List["TokenizedBatchedItem"]
+    ) -> "TokenizedBatchedItem":
+        """
+        Generate one long concatenated sequence from a list of TokenizedBatchedItem.
+        """
+        (
+            seq_lens,
+            tokens_list,
+            position_ids,
+            temporal_group_ids,
+            spatial_group_ids,
+            spatial_embeddings_list,
+            subject_sessions_list,
+        ) = ([], [], [], [], [], [], [])
+        for item in tokenized_items_list:
+            batch_size = item.tokens.shape[0]
+
+            tokens_list.append(einops.rearrange(item.tokens, "b n d -> (b n) d"))
+            if item.spatial_embeddings is not None:
+                spatial_embeddings_list.append(
+                    einops.rearrange(item.spatial_embeddings, "b n d -> (b n) d")
+                )
+
+            if item.position_ids is not None:
+                position_ids.append(item.position_ids.flatten())
+
+            if item.temporal_group_ids is not None:
+                temporal_group_ids.append(item.temporal_group_ids.flatten())
+
+            if item.spatial_group_ids is not None:
+                spatial_group_ids.append(item.spatial_group_ids.flatten())
+
+            seq_lens.extend(item.seq_lens * batch_size)
+            subject_sessions_list.extend(item.subject_sessions * batch_size)
+
+        tokens = torch.cat(tokens_list).unsqueeze(dim=0)
+        assert tokens.shape[:2] == (1, sum(seq_lens))
+
+        if len(spatial_embeddings_list) > 0:
+            spatial_embeddings = torch.cat(spatial_embeddings_list).unsqueeze(dim=0)
+            assert spatial_embeddings.shape[:2] == (1, sum(seq_lens))
+        else:
+            spatial_embeddings = None
+
+        if len(position_ids) > 0:
+            position_ids = torch.cat(position_ids).unsqueeze(dim=0)
+            assert position_ids.shape == (1, sum(seq_lens))
+        else:
+            position_ids = None
+
+        if len(temporal_group_ids) > 0:
+            temporal_group_ids = torch.cat(temporal_group_ids).unsqueeze(dim=0)
+            assert temporal_group_ids.shape == (1, sum(seq_lens))
+        else:
+            temporal_group_ids = None
+
+        if len(spatial_group_ids) > 0:
+            spatial_group_ids = torch.cat(spatial_group_ids).unsqueeze(dim=0)
+            assert spatial_group_ids.shape == (1, sum(seq_lens))
+        else:
+            spatial_group_ids = None
+
+        return TokenizedBatchedItem(
+            tokens=tokens,
+            position_ids=position_ids,
+            temporal_group_ids=temporal_group_ids,
+            spatial_group_ids=spatial_group_ids,
+            seq_lens=seq_lens,
+            spatial_embeddings=spatial_embeddings,
+            subject_sessions=subject_sessions_list
+        )
+
+    def get_as_list_items(self) -> List["TokenizedBatchedItem"]:
+        """
+        Note: this does not exactly reverse `get_as_one_sequence` because it does not
+        batch items with the same sequence length together.
+        """
+        tokenized_items_list = []
+        cur_total_len = 0
+        for seq_ind, seq_len in enumerate(self.seq_lens):
+            item_slice = slice(cur_total_len, cur_total_len + seq_len)
+            tokens = TokenizedBatchedItem(
+                tokens=self.tokens[:, item_slice],
+                position_ids=None if self.position_ids is None
+                else self.position_ids[:, item_slice],
+                temporal_group_ids=None if self.temporal_group_ids is None
+                else self.temporal_group_ids[:, item_slice],
+                spatial_group_ids=None if self.spatial_group_ids is None
+                else self.spatial_group_ids[:, item_slice],
+                spatial_embeddings=None if self.spatial_embeddings is None
+                else self.spatial_embeddings[:, item_slice],
+                seq_lens=[seq_len],
+                # Wrapped in a list so the field stays List[str] and round-trips
+                # through get_as_one_sequence.
+                subject_sessions=[self.subject_sessions[seq_ind]]
+            )
+            cur_total_len += seq_len
+
+            tokenized_items_list.append(tokens)
+
+        return tokenized_items_list
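The pack/unpack pattern used by `get_as_one_sequence` and `get_as_list_items` (flatten variable-length sequences into one long sequence, remember `seq_lens`, slice them back out) can be sketched with plain lists (illustrative, not repo code):

```python
# Sketch: packing variable-length sequences into one flat sequence with a
# seq_lens record, then splitting them back out by cumulative offsets.
def pack(sequences):
    flat = [tok for seq in sequences for tok in seq]
    seq_lens = [len(seq) for seq in sequences]
    return flat, seq_lens

def unpack(flat, seq_lens):
    out, start = [], 0
    for n in seq_lens:
        out.append(flat[start:start + n])
        start += n
    return out

seqs = [["a", "b"], ["c"], ["d", "e", "f"]]
flat, lens = pack(seqs)
assert flat == ["a", "b", "c", "d", "e", "f"] and lens == [2, 1, 3]
assert unpack(flat, lens) == seqs  # per-item round trip
```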
barista/models/tokenizer.py ADDED
@@ -0,0 +1,238 @@
+import einops
+from omegaconf import DictConfig
+import torch
+import torch.nn as nn
+from typing import Dict, List, Union
+
+import barista.models.spatial_encoder as spe
+from barista.data.metadata import Metadata
+from barista.models.mlp import MLP
+from barista.models.tokenized_batched_item import TokenizedBatchedItem
+from barista.models.TSEncoder2D import TSEncoder2D
+
+
+class Tokenizer(nn.Module):
+    def __init__(
+        self,
+        config: DictConfig,
+        metadata: Metadata,
+    ):
+        super().__init__()
+
+        self.metadata = metadata
+        self.config = config
+
+        self.subjects = metadata.get_subjects()
+
+        # Number of sliding windows of length temporal_subsegment_len, taken with
+        # step temporal_subsegment_step over samp_frequency * num_seconds samples.
+        self.num_subsegments = int(
+            (
+                self.config.samp_frequency * self.config.num_seconds
+                - self.config.temporal_subsegment_len
+            )
+            // (self.config.temporal_subsegment_step)
+            + 1
+        )
+
+        self.dim_h = self.config.d_hidden
+
+        self._build_temporal_encoder()
+        self._build_temporal_pooler()
+        self._build_spatial_encoder()
+
+    def _build_temporal_encoder(self):
+        self.config.temporal_encoder.input_dims = 1
+        self.config.temporal_encoder.output_dims = 1
+        self.temporal_encoder = TSEncoder2D(**self.config.temporal_encoder)
+
+    def _build_temporal_pooler(self):
+        self.temporal_pooler = MLP(
+            d_input=self.config.temporal_subsegment_len,
+            d_out=self.dim_h,
+            dropout=0.0,
+            bias=False,
+        )
+
+    def _build_spatial_encoder(self):
+        self.subject_session_spatial_groups = {}
+        for sub_sesh in self.metadata.get_subject_session_d_input().keys():
+            spatial_grouping = self.metadata.get_spatial_grouping(
+                subject_session=sub_sesh, name=self.config.spatial_grouping
+            )
+            self.subject_session_spatial_groups[sub_sesh] = spatial_grouping
+
+        self.spatial_encoder = spe.create_spatial_encoder(
+            dim_h=self.dim_h,
+            subject_session_spatial_groups=self.subject_session_spatial_groups,
+            embedding_max_dim=self.config.get('embedding_max_dim', None),
+            embedding_init_scale=self.config.get('embedding_init_scale', 1.0),
+        )
+
+    def update_for_new_sessions(
+        self,
+        new_session_d_input_dict: Dict[str, int],
+        new_metadata: Metadata,
+    ) -> List:
+        self.subject_session_spatial_groups = {}
+        for sub_sesh in new_session_d_input_dict.keys():
+            spatial_grouping = new_metadata.get_spatial_grouping(
+                subject_session=sub_sesh, name=self.config.spatial_grouping
+            )
+            self.subject_session_spatial_groups[sub_sesh] = spatial_grouping
+
+        self.metadata = new_metadata
+
+        new_params = []
+        if self.config.add_spatial_encoding:
+            new_se_params = self.spatial_encoder.update_for_new_sessions(
+                new_subject_session_spatial_groups=self.subject_session_spatial_groups
+            )
+
+            new_params.extend([f"spatial_encoder.{n}" for n in new_se_params])
+
+        return new_params
+
+    def _tokenize_for_batch_tensor(
+        self,
+        x: Union[torch.Tensor, List],
+        subject_session: str,
+        add_spatial_encoding_to_tokens: bool = True,
+    ) -> TokenizedBatchedItem:
+        """
+        Args:
+            x: Input tensor of shape (B, N, D) or a list of tensors each of shape (N_i, D_i)
+                B: Batch size
+                N: Time points
+                D: Channel dim
+
+        Returns:
+            Tokenized version of the same data as a TokenizedBatchedItem object.
+        """
+        batch_size, num_timepoints, num_channels = x.shape
+
+        x = einops.rearrange(x, "b n d -> b d n")
+
+        # NOTE: unfold doesn't copy memory, so if step is less than size (sliding
+        # window) and any shared element is changed, all occurrences of that element
+        # in the patches will change too.
+        x = x.unfold(
+            dimension=-1,
+            size=self.config.temporal_subsegment_len,
+            step=self.config.temporal_subsegment_step,
+        )  # (B, D, num_subsegments, subseg_len)
+
+        collapsed_x = einops.rearrange(x, "b d t n -> (b t d) n")  # (B * T * D, N)
+
+        transposed_tokens = einops.rearrange(
+            collapsed_x, "btd n -> 1 1 btd n"
+        )  # (1, 1, B * T * D, N)
+
+        collapsed_tokens = self.temporal_encoder(transposed_tokens)
+        collapsed_tokens = collapsed_tokens.squeeze()  # (B * T * D, N)
+
+        # "Time" dimension to hidden dimension, using a fully connected layer.
+        collapsed_tokens = self.temporal_pooler(
+            collapsed_tokens
+        )  # (B * T * D, N) -> (B * T * D, HID_D)
+
+        collapsed_tokens_full = collapsed_tokens
+
+        # Create the time-space interleaved tokens.
+        tokens = einops.rearrange(
+            collapsed_tokens_full,
+            "(b t d) dh -> b (t d) dh",
+            b=batch_size,
+            t=self.num_subsegments,
+        )
+
+        seqlen_timepoints = self.num_subsegments
+
+        if self.config.add_spatial_encoding:
+            spatial_encoding = self.spatial_encoder(
+                tokens,
+                subject_session=subject_session,
+                timepoints=seqlen_timepoints,
+            )
+
+            # Make sure regions at different timestamps share the same spatial encoding.
+            assert (
+                seqlen_timepoints == 1
+                or spatial_encoding[0, 0, 0] == spatial_encoding[0, num_channels, 0]
+            )
+
+            if add_spatial_encoding_to_tokens:
+                tokens = tokens + spatial_encoding
+        else:  # not self.config.add_spatial_encoding
+            spatial_encoding = None
+
+        temporal_group_ids = torch.arange(seqlen_timepoints, device=x.device)
+        temporal_group_ids = einops.repeat(
+            temporal_group_ids,
+            "t -> b (t d)",
+            b=batch_size,
+            d=num_channels
+        )
+        # Make sure different regions at the same timestamp share a positional
+        # encoding, while different timestamps do not.
+        assert seqlen_timepoints == 1 or (
+            temporal_group_ids[0, 0] == temporal_group_ids[0, 1]
+            and temporal_group_ids[0, 0] != temporal_group_ids[0, num_channels]
+        )
+
+        position_ids = temporal_group_ids.clone()
+
+        return TokenizedBatchedItem(
+            tokens=tokens,
+            position_ids=position_ids,
+            spatial_group_ids=None,
+            temporal_group_ids=temporal_group_ids,
+            seq_lens=[tokens.shape[1]],
+            spatial_embeddings=spatial_encoding,
+            subject_sessions=[subject_session]
+        )
+
+    def forward(
+        self,
+        x: List,
+        subject_sessions: List,
+        output_as_list: bool = False,
+        add_spatial_encoding_to_tokens: bool = True,
+    ) -> Union[TokenizedBatchedItem, List[TokenizedBatchedItem]]:
+        """
+        Args:
+            x: A list of tensors each of shape (B_i, N_i, D_i)
+                B: Batch size
+                N: Time points
+                D: Channel dim
+            subject_sessions: list of strings corresponding to subject_session identifiers
+            output_as_list: if True, outputs a list of TokenizedBatchedItem, each
+                corresponding to one subject; if False, merges all into one long sequence
+            add_spatial_encoding_to_tokens: bool. Adds spatial encoding to the tokens.
+
+        Returns:
+            TokenizedBatchedItem if output_as_list is False, else a list of
+            TokenizedBatchedItem objects.
+        """
+        passed_datapoints = 0
+        tokenized_items_list = []
+
+        for x_item in x:
+            tokenized_item = self._tokenize_for_batch_tensor(
+                x_item,
+                subject_sessions[passed_datapoints],
+                add_spatial_encoding_to_tokens=add_spatial_encoding_to_tokens,
+            )
+
+            tokenized_items_list.append(tokenized_item)
+            passed_datapoints += x_item.shape[0]
+
+        if output_as_list:
+            return tokenized_items_list
+
+        return TokenizedBatchedItem.get_as_one_sequence(tokenized_items_list)
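The `num_subsegments` expression in the Tokenizer's `__init__` is the standard sliding-window count, `(n_samples - window) // step + 1`. A plain-integer sketch (the numeric values below are made up for the check, not from a repo config):

```python
# Sketch: the tokenizer's sliding-window count, reproduced with plain integers.
def num_subsegments(samp_frequency, num_seconds, subsegment_len, subsegment_step):
    """Windows of length subsegment_len, taken every subsegment_step samples."""
    n_samples = samp_frequency * num_seconds
    return (n_samples - subsegment_len) // subsegment_step + 1

# 256 Hz * 4 s = 1024 samples; 128-sample windows every 64 samples -> 15 windows.
assert num_subsegments(256, 4, 128, 64) == 15
# Non-overlapping windows (step == len) tile the signal exactly.
assert num_subsegments(256, 4, 128, 128) == 8
```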
barista/models/transformer.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+
+ import torch
+ import torch.nn as nn
+ import xformers.ops as xops
+ from einops import rearrange, repeat
+
+ from barista.models.utils import get_activation_function
+
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, d_head, base=10000, max_position=1024):
+         super().__init__()
+
+         self.d_head = d_head
+         self.max_position = max_position
+
+         inv_freq = 1 / (
+             base
+             ** (torch.arange(0, self.d_head, 2, dtype=torch.float32) / self.d_head)
+         )
+         self.register_buffer("inv_freq", inv_freq)
+         self.build_cache()
+
+     def build_cache(self):
+         t = torch.arange(
+             self.max_position,
+             dtype=self.inv_freq.dtype,
+         )
+         freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # (self.max_position, d//2)
+
+         emb = torch.cat((freqs, freqs), dim=-1)  # (self.max_position, d)
+         dtype = torch.get_default_dtype()
+         self.register_buffer(
+             "cos_cached", emb.cos().to(dtype), persistent=False
+         )  # (self.max_position, d)
+         self.register_buffer(
+             "sin_cached", emb.sin().to(dtype), persistent=False
+         )  # (self.max_position, d)
+
+     def forward(self, position_ids):
+         """Returns the rotation matrices."""
+         cos = self.cos_cached[position_ids].unsqueeze(2)  # [bs, seq_len, 1, head_dim]
+         sin = self.sin_cached[position_ids].unsqueeze(2)  # [bs, seq_len, 1, head_dim]
+         return cos, sin
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin):
+     """
+     Applies the rotation matrices on query and key tensors.
+     q: B x seq_len x num_head x head_dim
+     k: B x seq_len x num_head x head_dim
+     """
+     q_embed = (q * cos.to(q)) + (
+         rotate_half(q) * sin.to(q)
+     )  # [bs, seq_len, num_heads, head_dim]
+     k_embed = (k * cos.to(k)) + (
+         rotate_half(k) * sin.to(k)
+     )  # [bs, seq_len, num_heads, head_dim]
+     return q_embed, k_embed
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, d_hidden, eps=1e-6):
+         """
+         https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/llama/modeling_llama.py#L74
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(d_hidden))
+         self.variance_epsilon = eps
+
+     def forward(self, x):
+         input_dtype = x.dtype
+         variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
+         x = x * torch.rsqrt(variance + self.variance_epsilon)
+         return (self.weight * x).to(input_dtype)
+
+
+ class SelfAttention(nn.Module):
+     def __init__(self, d_hidden, num_heads=8, dropout=0.1, **kwargs):
+         super().__init__()
+         self.d_hidden = d_hidden
+         self.num_heads = num_heads
+         self.d_head = self.d_hidden // self.num_heads
+         self.dropout = nn.Dropout(dropout)
+
+         assert (
+             self.d_hidden % self.num_heads == 0
+         ), f"Number of attention heads: {self.num_heads} must divide embedding dimension: {self.d_hidden}."
+
+         self.qkv_proj = nn.Linear(self.d_hidden, 3 * self.d_hidden, bias=True)
+         self.o_proj = nn.Linear(self.d_hidden, self.d_hidden, bias=True)
+
+     def get_qkv(self, x):
+         q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
+
+         q = rearrange(q, "b n (h d_h) -> b n h d_h", h=self.num_heads)
+         k = rearrange(k, "b n (h d_h) -> b n h d_h", h=self.num_heads)
+         v = rearrange(v, "b n (h d_h) -> b n h d_h", h=self.num_heads)
+         return q, k, v
+
+     def get_attention_out(self, q, k, v, seq_lens=None):
+         attention_weights = None
+
+         attention_out = self.get_memory_efficient_attention(q, k, v, seq_lens)
+
+         attention_out = self.dropout(attention_out)
+         attention_out = rearrange(attention_out, "b n h d_h -> b n (h d_h)")
+         out = self.o_proj(attention_out)
+         return out, attention_weights
+
+     def get_memory_efficient_attention(self, q, k, v, seq_lens=None):
+         if seq_lens is not None and q.shape[0] == 1:
+             attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
+         else:
+             attn_bias = None
+
+         # Only move the bias when one exists; calling .to(...) on None crashes.
+         if attn_bias is not None:
+             attn_bias = attn_bias.to(q.device)
+
+         assert q.shape[-2:] == (
+             self.num_heads,
+             self.d_head,
+         )
+         attention_out = xops.memory_efficient_attention(
+             q,
+             k,
+             v,
+             p=0,
+             attn_bias=attn_bias,
+         )
+         return attention_out
+
+     def forward(self, x, seq_lens=None, **kwargs):
+         if seq_lens is None and x.shape[0] == 1:
+             raise ValueError(
+                 "'seq_lens' for memory efficient attention with variable length sequences (x.shape[0] == 1) must be non-None."
+             )
+         q, k, v = self.get_qkv(x)
+         out, att_weights = self.get_attention_out(q, k, v, seq_lens)
+         return out, att_weights
+
+
+ class RotarySelfAttention(SelfAttention):
+     def __init__(
+         self,
+         d_hidden,
+         num_heads=8,
+         max_position=1024,
+         dropout=0.1,
+         **kwargs,
+     ):
+         super().__init__(
+             d_hidden=d_hidden,
+             num_heads=num_heads,
+             dropout=dropout,
+         )
+         self.max_position = max_position
+         self.rotary_emb = RotaryEmbedding(self.d_head, max_position=self.max_position)
+
+     def forward(self, x, position_ids=None, seq_lens=None):
+         if seq_lens is None and x.shape[0] == 1:
+             raise ValueError(
+                 "'seq_lens' for memory efficient attention with variable length sequences (x.shape[0] == 1) must be non-None."
+             )
+
+         if position_ids is None:
+             if x.shape[0] == 1:
+                 position_ids = [
+                     torch.arange(seq_len_, device=x.device, dtype=int)
+                     for seq_len_ in seq_lens
+                 ]
+                 position_ids = torch.cat(position_ids).unsqueeze(dim=0)
+             else:
+                 position_ids = repeat(
+                     torch.arange(x.shape[1], device=x.device, dtype=int),
+                     "n -> b n",
+                     b=x.shape[0],
+                 )
+
+         q, k, v = self.get_qkv(x)
+
+         cos, sin = self.rotary_emb(position_ids)
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+         v = v.to(q)
+
+         out, att_weights = self.get_attention_out(q, k, v, seq_lens)
+         return out, att_weights
+
+
+ class GatedTransformerMLP(nn.Module):
+     def __init__(self, d_hidden, mlp_ratio=4, activation="silu", dropout=0.1):
+         super().__init__()
+         d_feedforward = mlp_ratio * d_hidden
+         self.gate_proj = nn.Linear(d_hidden, d_feedforward, bias=True)
+         self.down_proj = nn.Linear(d_feedforward, d_hidden, bias=True)
+         self.up_proj = nn.Linear(d_hidden, d_feedforward, bias=True)
+         self.activation_fn = get_activation_function(activation)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+
+     def forward(self, x):
+         x = self.dropout1(self.activation_fn(self.gate_proj(x)) * self.up_proj(x))
+         return self.dropout2(self.down_proj(x))
+
+
+ class TransformerEncoderLayer(nn.Module):
+     def __init__(
+         self,
+         d_hidden,
+         mlp_ratio=4,
+         norm="rmsnorm",
+         norm_eps=1e-6,
+         activation="silu",
+         num_heads=8,
+         dropout=0.1,
+         **attention_module_kwargs,
+     ):
+         super().__init__()
+         self.d_hidden = d_hidden
+
+         attention_cls = RotarySelfAttention
+
+         self.attention = attention_cls(
+             d_hidden=d_hidden,
+             num_heads=num_heads,
+             dropout=dropout,
+             **attention_module_kwargs,
+         )
+         self.mlp = GatedTransformerMLP(
+             d_hidden=d_hidden,
+             mlp_ratio=mlp_ratio,
+             activation=activation,
+             dropout=dropout,
+         )
+         self.dropout = nn.Dropout(dropout)
+
+         if norm.lower() == "rmsnorm":
+             self.norm1 = RMSNorm(d_hidden, eps=norm_eps)
+             self.norm2 = RMSNorm(d_hidden, eps=norm_eps)
+         elif norm.lower() == "layernorm":
+             self.norm1 = nn.LayerNorm(d_hidden, eps=norm_eps)
+             self.norm2 = nn.LayerNorm(d_hidden, eps=norm_eps)
+         else:
+             raise NotImplementedError()
+
+     def forward(self, x, position_ids=None, seq_lens=None):
+         residual = x
+         x = self.norm1(x)
+         x, att_weights = self.attention(
+             x=x,
+             position_ids=position_ids,
+             seq_lens=seq_lens,
+         )
+         x = self.dropout(x)
+         x = residual + x
+
+         residual = x
+         x = self.norm2(x)
+         x = self.mlp(x)
+         x = residual + x
+
+         return x, att_weights
+
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         num_layers,
+         d_hidden,
+         mlp_ratio=4,
+         norm="rmsnorm",
+         norm_eps=1e-6,
+         activation="gelu",
+         num_heads=8,
+         dropout=0.1,
+         **attention_module_kwargs,
+     ):
+         super().__init__()
+         self.layers = nn.ModuleList(
+             [
+                 TransformerEncoderLayer(
+                     d_hidden=d_hidden,
+                     mlp_ratio=mlp_ratio,
+                     norm=norm,
+                     norm_eps=norm_eps,
+                     activation=activation,
+                     num_heads=num_heads,
+                     dropout=dropout,
+                     **attention_module_kwargs,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         if norm.lower() == "rmsnorm":
+             self.norm = RMSNorm(d_hidden, eps=norm_eps)
+         elif norm.lower() == "layernorm":
+             self.norm = nn.LayerNorm(d_hidden, eps=norm_eps)
+         else:
+             # Keep the attribute defined so forward() can check it safely.
+             self.norm = None
+
+     def forward(self, x, position_ids=None, seq_lens=None, **kwargs):
+         weights_list = []
+         for layer in self.layers:
+             x, weights = layer(
+                 x=x,
+                 position_ids=position_ids,
+                 seq_lens=seq_lens,
+             )
+             weights_list.append(weights)
+
+         if self.norm is not None:
+             x = self.norm(x)
+
+         return x
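Since `rotate_half` and `apply_rotary_pos_emb` are the heart of the rotary positional encoding above, a torch-free sketch may help make the geometry concrete: each `(i, i + d_head/2)` pair of head dimensions is rotated by a position-dependent angle, which leaves vector norms unchanged. The NumPy helpers below (`rotary_angles`, `apply_rotary`) are illustrative stand-ins, not part of the module.

```python
import numpy as np

def rotate_half(x):
    # Split the last dimension in two and rotate: (x1, x2) -> (-x2, x1).
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate((-x2, x1), axis=-1)

def rotary_angles(positions, d_head, base=10000.0):
    # One frequency per pair of dims, duplicated to cover the full head dim,
    # mirroring RotaryEmbedding.build_cache.
    inv_freq = 1.0 / (base ** (np.arange(0, d_head, 2) / d_head))
    freqs = np.outer(positions, inv_freq)          # (seq_len, d_head // 2)
    emb = np.concatenate((freqs, freqs), axis=-1)  # (seq_len, d_head)
    return np.cos(emb), np.sin(emb)

def apply_rotary(q, cos, sin):
    # Same formula as apply_rotary_pos_emb, for a (seq_len, d_head) array.
    return q * cos + rotate_half(q) * sin

rng = np.random.default_rng(0)
q = rng.standard_normal((6, 8))
cos, sin = rotary_angles(np.arange(6), d_head=8)
q_rot = apply_rotary(q, cos, sin)
# The rotation is orthogonal per 2-D pair, so vector norms are unchanged.
print(np.allclose(np.linalg.norm(q_rot, axis=-1), np.linalg.norm(q, axis=-1)))  # -> True
```

Position 0 has angle 0 everywhere, so the first row of `q_rot` equals the first row of `q`; only relative positions matter to the resulting attention scores.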
barista/models/utils.py ADDED
@@ -0,0 +1,23 @@
+ import numpy as np
+ import os
+ import random
+ import torch
+ import torch.nn as nn
+
+
+ def get_activation_function(activation_str):
+     if activation_str.lower() == "relu":
+         return nn.ReLU()
+     elif activation_str.lower() == "linear":
+         return lambda x: x
+     elif activation_str.lower() == "gelu":
+         return nn.GELU()
+     elif activation_str.lower() == "silu":
+         # GatedTransformerMLP defaults to "silu"; returning None here would
+         # crash at forward time.
+         return nn.SiLU()
+     raise ValueError(f"Unknown activation function: {activation_str}")
+
+
+ def seed_everything(seed):
+     os.environ["PL_GLOBAL_SEED"] = str(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     print(f"Random seed set as {seed}")
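`seed_everything` works by seeding every RNG the training stack touches so that runs are reproducible. A minimal, torch-free sketch of the same idea (`seed_basic` is a hypothetical helper covering only the stdlib and NumPy generators):

```python
import random
import numpy as np

def seed_basic(seed):
    # Torch-free subset of seed_everything, for illustration only.
    random.seed(seed)
    np.random.seed(seed)

seed_basic(7)
a = (random.random(), np.random.rand(3))
seed_basic(7)
b = (random.random(), np.random.rand(3))
# Re-seeding replays the exact same draws from both generators.
print(a[0] == b[0] and np.array_equal(a[1], b[1]))  # -> True
```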
barista/prepare_segments.py ADDED
@@ -0,0 +1,27 @@
+ """Script to preprocess and prepare data segments.
+
+ Example usage:
+     python prepare_segments.py --config config/braintreebank_config.yaml --experiment sentence_onset
+ """
+
+ import argparse
+ from omegaconf import OmegaConf
+
+ from barista.data.braintreebank_wrapper import BrainTreebankWrapper
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--config", required=True, type=str, help="path to config for segmentation")
+     parser.add_argument("--experiment", required=True, type=str, help="experiment to segment data for")
+
+     args = parser.parse_args()
+
+     print(f"Loading config: {args.config}")
+     config = OmegaConf.load(args.config)
+
+     ## Instantiating BrainTreebankWrapper will by default handle all preprocessing.
+     ## If preprocessing is complete, then the dataset will be ready to use for training.
+     config.experiment = args.experiment
+     print(f"Segmenting data for experiment {args.experiment}")
+     braintreebank_wrapper = BrainTreebankWrapper(config, only_segment_generation=True)
barista/train.py ADDED
@@ -0,0 +1,368 @@
+ import argparse
+ import copy
+
+ import numpy as np
+ import torch
+ from omegaconf import OmegaConf
+ from sklearn.metrics import roc_auc_score
+ from torch import nn, optim
+
+ from barista.data.braintreebank_dataset import BrainTreebankDataset
+ from barista.models.model import Barista
+ from barista.models.utils import seed_everything
+
+
+ def parse_args():
+     """Parse command line arguments."""
+     parser = argparse.ArgumentParser(
+         description="Fine-tune Barista model on BrainTreebank dataset"
+     )
+     parser.add_argument(
+         "--dataset_config",
+         type=str,
+         default="barista/config/braintreebank.yaml",
+         help="Path to dataset configuration file",
+     )
+     parser.add_argument(
+         "--train_config",
+         type=str,
+         default="barista/config/train.yaml",
+         help="Path to training configuration file",
+     )
+     parser.add_argument(
+         "--model_config",
+         type=str,
+         default="barista/config/model.yaml",
+         help="Path to model configuration file",
+     )
+     parser.add_argument(
+         "--override",
+         type=str,
+         nargs="+",
+         default=[],
+         help="Override config parameters (e.g., --override epochs=50 optimization.finetune_lr=1e-4)",
+     )
+     return parser.parse_args()
+
+
+ def load_configs(args):
+     """Load all configuration files."""
+     dataset_config = OmegaConf.load(args.dataset_config)
+     train_config = OmegaConf.load(args.train_config)
+     model_config = OmegaConf.load(args.model_config)
+
+     assert (
+         len(dataset_config.finetune_sessions) == 1
+     ), "Specify one session for finetuning"
+
+     return dataset_config, train_config, model_config
+
+
+ def apply_overrides(config_dict, overrides):
+     """Apply command-line overrides to configs using dot notation."""
+     if not overrides:
+         return config_dict
+
+     override_dict = {}
+     for override in overrides:
+         if "=" not in override:
+             raise ValueError(
+                 f"Invalid override format: {override}. Expected format: key=value"
+             )
+
+         key, value = override.split("=", 1)
+
+         # Coerce ints, floats, lists, and booleans; leave everything else a
+         # string. (str.isnumeric() is False for "1.5" or "1e-4", so floats
+         # must be parsed with a conversion attempt rather than a check.)
+         try:
+             value = int(value)
+         except ValueError:
+             try:
+                 value = float(value)
+             except ValueError:
+                 if value.startswith("[") or value in ("True", "False"):  # list, bool
+                     value = eval(value)
+
+         keys = key.split(".")
+         current = override_dict
+         for k in keys[:-1]:
+             if k not in current:
+                 current[k] = {}
+             current = current[k]
+         current[keys[-1]] = value
+
+     # Convert override dict to OmegaConf and merge
+     override_conf = OmegaConf.create(override_dict)
+
+     # Determine which config to merge based on keys
+     merged_configs = {}
+     for config_name, config in config_dict.items():
+         config_keys = set(OmegaConf.to_container(config).keys())
+         override_keys = set(override_dict.keys())
+
+         if config_keys.intersection(override_keys):
+             merged_configs[config_name] = OmegaConf.merge(config, override_conf)
+         else:
+             merged_configs[config_name] = config
+
+     if merged_configs.get("train") is not None:
+         merged_configs["train"] = OmegaConf.merge(
+             merged_configs["train"], override_conf
+         )
+
+     return merged_configs
+
+
+ def setup_dataloaders(dataset_config, train_config):
+     """Initialize dataset and create dataloaders."""
+     dataset = BrainTreebankDataset(dataset_config)
+
+     train_dataloader = dataset.get_dataloader("train", train_config)
+     val_dataloader = dataset.get_dataloader("val", train_config)
+     test_dataloader = dataset.get_dataloader("test", train_config)
+
+     print(f"Train: {len(train_dataloader.dataset.metadata)} samples")
+     print(f"Val: {len(val_dataloader.dataset.metadata)} samples")
+     print(f"Test: {len(test_dataloader.dataset.metadata)} samples")
+
+     dataset.check_no_common_segment(train_dataloader, val_dataloader, test_dataloader)
+
+     return dataset, train_dataloader, val_dataloader, test_dataloader
+
+
+ def get_optimizer(model, finetune_lr=1e-4, new_param_lr=1e-3):
+     """Create optimizer with different learning rates for task and upstream parameters."""
+     task_params, upstream_params = [], []
+
+     for _, p in model.get_task_params():
+         if p.requires_grad:
+             task_params.append(p)
+
+     for _, p in model.get_upstream_params():
+         if p.requires_grad:
+             upstream_params.append(p)
+
+     params = [
+         {"params": upstream_params, "lr": finetune_lr},
+         {"params": task_params, "lr": new_param_lr},
+     ]
+
+     optimizer = optim.AdamW(params, lr=finetune_lr, weight_decay=1e-2)
+     return optimizer
+
+
+ def get_lr_scheduler(optimizer):
+     """Create learning rate scheduler with warmup and exponential decay."""
+     milestone = 5
+
+     lr_schedulers_list = [
+         torch.optim.lr_scheduler.LinearLR(
+             optimizer,
+             start_factor=0.2,
+             end_factor=1.0,
+             total_iters=milestone,
+         ),
+         torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99),
+     ]
+
+     lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+         optimizer,
+         lr_schedulers_list,
+         milestones=[milestone],
+     )
+     return lr_scheduler
+
+
+ def load_pretrained_weights(model, checkpoint_path, device):
+     """Load pretrained weights from a checkpoint file."""
+     checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
+     model.load_state_dict(checkpoint)
+     print(f"Pretrained weights loaded from {checkpoint_path}")
+     return model
+
+
+ def freeze_tokenizer(model):
+     """Freeze all tokenizer parameters."""
+     for _, p in model.tokenizer.named_parameters():
+         p.requires_grad = False
+
+
+ def print_number_of_params(model):
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     total_params = sum(p.numel() for p in model.parameters())
+
+     print(f"Model parameters: {total_params}\t Trainable params: {trainable_params}")
+
+
+ def run_epoch(
+     model, dataloader, criterion, device, optimizer=None, scheduler=None, train=False
+ ):
+     """Run one epoch of training or evaluation."""
+     if train:
+         model.train()
+     else:
+         model.eval()
+
+     all_preds = []
+     all_labels = []
+     running_loss = 0
+
+     for batch in dataloader:
+         x = [x_item.to(device) for x_item in batch.x]
+         y = batch.labels.flatten().long().to(device)
+
+         if train:
+             optimizer.zero_grad()
+
+         with torch.set_grad_enabled(train):
+             logits = model(
+                 x,
+                 subject_sessions=batch.subject_sessions,
+             )
+             loss = criterion(logits, y)
+
+         if train:
+             loss.backward()
+             optimizer.step()
+
+         running_loss += loss.item() * y.size(0)
+
+         probs = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
+         labels = y.detach().cpu().numpy()
+
+         all_preds.append(probs)
+         all_labels.append(labels)
+
+     if train:
+         # Step the scheduler once per epoch.
+         scheduler.step()
+
+     all_preds = np.concatenate(all_preds)
+     all_labels = np.concatenate(all_labels)
+
+     try:
+         auc = roc_auc_score(all_labels, all_preds)
+     except ValueError:
+         # Only one class present in this split; AUC is undefined.
+         auc = float("nan")
+
+     avg_loss = running_loss / len(dataloader.dataset)
+     return avg_loss, auc
+
+
+ def finetune_model(model, train_dataloader, val_dataloader, train_config, device):
+     """Finetune the model and track best validation performance."""
+     criterion = nn.CrossEntropyLoss()
+     optimizer = get_optimizer(
+         model,
+         finetune_lr=train_config.optimization.finetune_lr,
+         new_param_lr=train_config.optimization.new_param_lr,
+     )
+     scheduler = get_lr_scheduler(optimizer)
+
+     best_val_auc = -1
+     best_state = None
+     num_epochs = train_config.epochs
+
+     for epoch in range(num_epochs):
+         train_loss, train_auc = run_epoch(
+             model, train_dataloader, criterion, device, optimizer, scheduler, train=True
+         )
+         val_loss, val_auc = evaluate_model(model, val_dataloader, criterion, device)
+
+         print(
+             f"Epoch {epoch+1}/{num_epochs} "
+             f"- Train Loss: {train_loss:.4f}, AUC: {train_auc:.4f} "
+             f"- Val Loss: {val_loss:.4f}, AUC: {val_auc:.4f}"
+         )
+
+         # Track best model by validation AUC
+         if best_state is None or val_auc > best_val_auc:
+             best_val_auc = val_auc
+             best_state = {
+                 "epoch": epoch + 1,
+                 "model": copy.deepcopy(model.state_dict()),
+                 "optimizer": copy.deepcopy(optimizer.state_dict()),
+                 "scheduler": copy.deepcopy(scheduler.state_dict()),
+                 "val_auc": val_auc,
+             }
+
+     return best_state, criterion
+
+
+ def evaluate_model(model, test_dataloader, criterion, device):
+     """Evaluate model on a held-out set."""
+     test_loss, test_auc = run_epoch(
+         model, test_dataloader, criterion, device, train=False
+     )
+     return test_loss, test_auc
+
+
+ def main():
+     """Main training pipeline."""
+     # Parse arguments and load configs
+     args = parse_args()
+     dataset_config, train_config, model_config = load_configs(args)
+
+     configs = {"dataset": dataset_config, "train": train_config, "model": model_config}
+     configs = apply_overrides(configs, args.override)
+     dataset_config = configs["dataset"]
+     train_config = configs["train"]
+     model_config = configs["model"]
+
+     # Set random seed
+     seed_everything(train_config.seed)
+
+     # Setup data
+     dataset, train_dataloader, val_dataloader, test_dataloader = setup_dataloaders(
+         dataset_config, train_config
+     )
+
+     # Get fine-tuning session info
+     ft_session = dataset_config.finetune_sessions[0]
+     ft_session_n_chans = dataset.metadata.get_subject_session_full_d_data()[ft_session][
+         -1
+     ]
+
+     # Initialize model
+     device = train_config.device
+     model = Barista(model_config, dataset.metadata)
+
+     # Load pretrained weights
+     if train_config.checkpoint_path:
+         print("Running pretrained model")
+         model = load_pretrained_weights(model, train_config.checkpoint_path, device)
+
+         # Freeze tokenizer
+         if train_config.optimization.freeze_tokenizer:
+             freeze_tokenizer(model)
+
+     else:
+         print("Running non-pretrained model")
+
+     # Create downstream head and move to device
+     model.create_downstream_head(n_chans=ft_session_n_chans, output_dim=2)
+     model.to(device)
+
+     print_number_of_params(model)
+
+     # Finetune model
+     best_state, criterion = finetune_model(
+         model, train_dataloader, val_dataloader, train_config, device
+     )
+     print(f"\nBEST VAL AUC: {best_state['val_auc']:.4f}")
+
+     # Evaluate on test set with the last-epoch model
+     _, last_test_auc = evaluate_model(model, test_dataloader, criterion, device)
+     print(f"LAST TEST AUC: {last_test_auc:.4f}")
+
+     # Load best model for testing
+     model.load_state_dict(best_state["model"])
+
+     # Evaluate on test set with the best-validation model
+     _, test_auc = evaluate_model(model, test_dataloader, criterion, device)
+     print(f"BEST TEST AUC: {test_auc:.4f}")
+
+
+ if __name__ == "__main__":
+     main()
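The dot-notation handling inside `apply_overrides` can be exercised in isolation. `parse_overrides` below is a hypothetical helper that mirrors the key-splitting and value-coercion steps but skips the OmegaConf merge:

```python
def parse_overrides(overrides):
    # Mirrors apply_overrides: split "a.b.c=value" into nested dicts,
    # coercing ints and floats and leaving other values as strings.
    out = {}
    for item in overrides:
        key, value = item.split("=", 1)
        try:
            value = int(value)
        except ValueError:
            try:
                value = float(value)
            except ValueError:
                pass  # keep as string
        node = out
        keys = key.split(".")
        for k in keys[:-1]:
            node = node.setdefault(k, {})
        node[keys[-1]] = value
    return out

print(parse_overrides(["epochs=50", "optimization.finetune_lr=1e-4"]))
# -> {'epochs': 50, 'optimization': {'finetune_lr': 0.0001}}
```

This is the same shape that gets handed to `OmegaConf.create` before merging into the loaded configs.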
barista/utility_scripts/aggregate_runs.py ADDED
@@ -0,0 +1,161 @@
+ import argparse
+ import glob
+ import os
+ import re
+
+ import numpy as np
+ import pandas as pd
+
+ KEY = 'TEST'  # Options: 'VAL', 'TEST', 'LAST_TEST'
+
+
+ def parse_summary(path):
+     try:
+         txt = open(path).read()
+         mean = float(re.search(rf"{KEY}_MEAN=([0-9.]+)", txt).group(1))
+         std = float(re.search(rf"{KEY}_STD=([0-9.]+)", txt).group(1))
+         ckpt_line = re.search(r"Checkpoint:\s*(.*)", txt).group(1)
+         model = os.path.basename(ckpt_line).replace(".ckpt", "")
+         return model, f"{mean:.3f} ± {std:.3f}"
+     except (OSError, AttributeError, ValueError):
+         return None
+
+
+ def parse_from_seeds(folder):
+     logs = sorted(glob.glob(os.path.join(folder, "seed_*.log")))
+     expected_seeds = 5
+
+     if not logs:
+         print(f"WARNING: No seed logs found in {folder}")
+         return None
+
+     # Anchor on the full "BEST ..." / "LAST ..." prefixes: a bare "TEST AUC:"
+     # pattern would match the "LAST TEST AUC" line first, since it is printed
+     # before "BEST TEST AUC" in the training logs.
+     auc_pattern = r"BEST TEST AUC:\s*([0-9.]+)" if KEY == "TEST" else \
+         r"LAST TEST AUC:\s*([0-9.]+)" if KEY == "LAST_TEST" else \
+         r"BEST VAL AUC:\s*([0-9.]+)" if KEY == "VAL" else None
+     if auc_pattern is None:
+         return None
+
+     ckpt_pattern = r"'checkpoint_path':\s*'([^']*)'"
+
+     vals, model_name, valid_logs = [], None, 0
+
+     for log in logs:
+         try:
+             txt = open(log).read()
+             m = re.search(auc_pattern, txt)
+             if m:
+                 vals.append(float(m.group(1)))
+                 valid_logs += 1
+
+             cm = re.search(ckpt_pattern, txt)
+             if cm:
+                 ckpt_path = cm.group(1)
+                 model_name = os.path.basename(ckpt_path).replace(".ckpt", "")
+         except OSError:
+             pass
+
+     # An empty checkpoint path means no pretrained weights, i.e. a random
+     # baseline; check this before the "or" fallback swallows the empty string.
+     if model_name == '':
+         model_name = "random"
+     model_name = model_name or "unknown"
+
+     if valid_logs != expected_seeds and model_name != 'random':
+         print(f"WARNING: Incomplete seeds for {model_name} in {folder} "
+               f"(found {valid_logs}/{expected_seeds})")
+
+     if not vals:
+         return None
+
+     mean, std = float(np.mean(vals)), float(np.std(vals))
+     return model_name, f"{mean:.3f} ± {std:.3f}"
+
+
+ def parse_summary_or_seeds(folder):
+     summary_path = os.path.join(folder, "summary.txt")
+     if os.path.exists(summary_path):
+         parsed = parse_summary(summary_path)
+         if parsed:
+             return parsed
+     return parse_from_seeds(folder)
+
+
+ def extract_mean(x):
+     if isinstance(x, str) and "±" in x:
+         return float(x.split("±")[0].strip())
+     return np.nan
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--results_dir", type=str, default="results", help="Path to results folder")
+     args = parser.parse_args()
+     ROOT = args.results_dir
+
+     rows, subjects, tasks, models, folds = [], set(), set(), set(), set()
+
+     # Collect data from folders
+     for folder in os.listdir(ROOT):
+         fpath = os.path.join(ROOT, folder)
+         if not os.path.isdir(fpath):
+             continue
+
+         parts = folder.split("_")
+         if len(parts) < 6:
+             continue
+
+         subj = parts[1]
+         task = parts[4]
+         if len(parts) > 5 and parts[5] in ["onset", "vs", "nonspeech", "speech", "time"]:
+             task += f"_{parts[5]}"
+         if len(parts) > 6 and parts[6] == "nonspeech":
+             task += f"_{parts[6]}"
+
+         fold = None
+         for p in parts:
+             if p.startswith("fold"):
+                 fold = int(p.replace("fold", ""))
+                 folds.add(fold)
+                 break
+
+         parsed = parse_summary_or_seeds(fpath)
+         if not parsed:
+             continue
+
+         model, value = parsed
+         subjects.add(subj)
+         tasks.add(task)
+         models.add(model)
+         rows.append((task, model, subj, fold, value))
+
+     # Build DataFrame
+     subjects = sorted(subjects, key=lambda x: int(x))
+     df = pd.DataFrame(columns=["task", "model", "fold"] + subjects)
+
+     for task in sorted(tasks):
+         for model in sorted(models):
+             all_folds = sorted(folds) + [None]
+             for fold in all_folds:
+                 subset = [(s, v) for t, m, s, f, v in rows if t == task and m == model and f == fold]
+                 if not subset:
+                     continue
+                 row = {"task": task, "model": model, "fold": fold if fold is not None else ""}
+                 for subj, val in subset:
+                     row[subj] = val
+                 df.loc[len(df)] = row
+
+     # Add AVG column
+     subj_cols = [c for c in df.columns if c not in ["task", "model", "fold"]]
+     df["avg"] = df[subj_cols].applymap(extract_mean).mean(axis=1)
+     df["avg"] = df["avg"].apply(lambda x: f"{x:.3f}" if pd.notnull(x) else "")
+
+     # Add final AVG rows per (task, model)
+     avg_rows = []
+     for (task, model), group in df.groupby(["task", "model"]):
+         subj_avgs = {}
+         for subj in subj_cols:
+             vals = [float(v.split("±")[0].strip()) for v in group[subj] if isinstance(v, str) and "±" in v]
+             subj_avgs[subj] = f"{np.mean(vals):.3f}" if vals else ""
+         overall_vals = [float(v) for v in subj_avgs.values() if v != ""]
+         overall_avg = f"{np.mean(overall_vals):.3f}" if overall_vals else ""
+         row = {"task": task, "model": model, "fold": "AVG", "avg": overall_avg}
+         row.update(subj_avgs)
+         avg_rows.append(row)
+
+     df = pd.concat([df, pd.DataFrame(avg_rows)], ignore_index=True)
+     print(df.to_markdown(index=False))
+
+
+ if __name__ == "__main__":
+     main()
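For reference, the `summary.txt` format that `parse_summary` expects can be exercised in isolation. The field names come from the regexes above; the concrete numbers and checkpoint path are made up for illustration:

```python
import os
import re

KEY = "TEST"
# Hypothetical summary.txt contents matching parse_summary's regexes.
summary = (
    "TEST_MEAN=0.812\n"
    "TEST_STD=0.014\n"
    "Checkpoint: pretrained_models/chans_chans.ckpt\n"
)

# Same extraction steps as parse_summary.
mean = float(re.search(rf"{KEY}_MEAN=([0-9.]+)", summary).group(1))
std = float(re.search(rf"{KEY}_STD=([0-9.]+)", summary).group(1))
ckpt_line = re.search(r"Checkpoint:\s*(.*)", summary).group(1)
model = os.path.basename(ckpt_line).replace(".ckpt", "")

print(model, f"{mean:.3f} ± {std:.3f}")  # -> chans_chans 0.812 ± 0.014
```

The `"mean ± std"` string produced here is exactly what `extract_mean` later splits apart when computing the per-subject averages.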
barista/utility_scripts/run_finetune_folds.sh ADDED
@@ -0,0 +1,276 @@
+ #!/bin/bash
+
+ # Usage:
+ #   ./run_finetune_folds.sh --spe coords --checkpoint "pretrained_models/chans_chans.ckpt" --session HOLDSUBJ_2_HS2_6 --gpu 1 --fold 0 --exp sentence_onset_time
+ #   ./run_finetune_folds.sh --spe destrieux --checkpoint "pretrained_models/parcels_chans.ckpt" --session HOLDSUBJ_2_HS2_6 --gpu 2 --fold 1 --exp speech_vs_nonspeech_time
+
+ # Default values
+ GPU=0
+ SEEDS=(0 1 2 3 4)
+ SESSION=""
+ CHECKPOINT=""
+ DATASET_CONFIG="barista/config/braintreebank.yaml"
+ TRAIN_CONFIG="barista/config/train.yaml"
+ MODEL_CONFIG="barista/config/model.yaml"
+ SPATIAL_GROUPING="coords"
+ EXPERIMENT="sentence_onset_time"
+ FOLD_NUM=0
+
+ # Parse arguments
+ while [[ $# -gt 0 ]]; do
+     case $1 in
+         --session)
+             SESSION="$2"
+             shift 2
+             ;;
+         --checkpoint)
+             CHECKPOINT="$2"
+             shift 2
+             ;;
+         --gpu)
+             GPU="$2"
+             shift 2
+             ;;
+         --fold)
+             FOLD_NUM="$2"
+             shift 2
+             ;;
+         --seeds)
+             IFS=',' read -ra SEEDS <<< "$2"
+             shift 2
+             ;;
+         --dataset_config)
+             DATASET_CONFIG="$2"
+             shift 2
+             ;;
+         --exp)
+             EXPERIMENT="$2"
+             shift 2
+             ;;
+         --train_config)
+             TRAIN_CONFIG="$2"
+             shift 2
+             ;;
+         --spe)
+             SPATIAL_GROUPING="$2"
+             shift 2
+             ;;
+         --model_config)
+             MODEL_CONFIG="$2"
+             shift 2
+             ;;
+         *)
+             echo "Unknown argument: $1"
+             echo "Usage: $0 --session <session_name> --checkpoint <checkpoint_path> [--gpu <gpu_id>] [--seeds <seed_list>]"
+             echo "Example: $0 --session session1 --checkpoint checkpoints/model.pt --gpu 0 --seeds 42,123,456,789,1024"
+             exit 1
+             ;;
+     esac
+ done
+
+ # Validate required arguments
+ if [ -z "$SESSION" ]; then
+     echo "Error: --session is required"
+     exit 1
+ fi
+
+ NUM_SEEDS=${#SEEDS[@]}
+
+ # Create output directory
+ OUTPUT_DIR="results_folds/${SESSION}_${EXPERIMENT}_fold${FOLD_NUM}_model${SPATIAL_GROUPING}_$(date +%Y%m%d_%H%M%S)"
+
+ mkdir -p "$OUTPUT_DIR"
+
+ echo "=========================================="
+ echo "Sequential Multi-Seed Fine-tuning"
+ echo "=========================================="
+ echo "Session: $SESSION"
+ echo "Checkpoint: $CHECKPOINT"
+ echo "GPU: $GPU"
+ echo "Seeds: ${SEEDS[@]}"
+ echo "Number of runs: $NUM_SEEDS"
+ echo "Output Directory: $OUTPUT_DIR"
+ echo "=========================================="
+ echo ""
+
+ # Arrays to store results
+ VAL_AUCS=()
+ BEST_TEST_AUCS=()
+ LAST_TEST_AUCS=()
+ FAILED_SEEDS=()
+
+ # Run jobs sequentially
+ for i in $(seq 0 $(($NUM_SEEDS - 1))); do
+     SEED=${SEEDS[$i]}
+
+     LOG_FILE="$OUTPUT_DIR/seed_${SEED}.log"
+
+     echo "=========================================="
+     echo "Running job $((i+1))/$NUM_SEEDS: Seed=$SEED"
+     echo "=========================================="
+     echo "Log file: $LOG_FILE"
+     echo ""
+
+     # Run training
+     CUDA_VISIBLE_DEVICES=$GPU python barista/train.py \
+         --dataset_config "$DATASET_CONFIG" \
+         --train_config "$TRAIN_CONFIG" \
+         --model_config "$MODEL_CONFIG" \
+         --override \
+         seed=$SEED \
+         device=cuda:0 \
+         checkpoint_path="$CHECKPOINT" \
+         force_nonoverlap=False \
+         experiment=$EXPERIMENT \
+         chron_fold_num=$FOLD_NUM \
+         tokenizer.spatial_grouping="$SPATIAL_GROUPING" \
+         "finetune_sessions=['$SESSION']" \
+         2>&1 | tee "$LOG_FILE"
+
+     # Check if job completed successfully
+     if [ ${PIPESTATUS[0]} -eq 0 ]; then
+         echo ""
+         echo "✓ Job $((i+1)) completed successfully"
+
+         # Extract results from log file
+         VAL_AUC=$(grep "BEST VAL AUC" "$LOG_FILE" | awk '{print $NF}')
+         BEST_TEST_AUC=$(grep "BEST TEST AUC" "$LOG_FILE" | tail -1 | awk '{print $NF}')
+         LAST_TEST_AUC=$(grep "LAST TEST AUC" "$LOG_FILE" | awk '{print $NF}')
+
+         if [ ! -z "$VAL_AUC" ] && [ ! -z "$BEST_TEST_AUC" ] && [ ! -z "$LAST_TEST_AUC" ]; then
+             VAL_AUCS+=($VAL_AUC)
+             BEST_TEST_AUCS+=($BEST_TEST_AUC)
+             LAST_TEST_AUCS+=($LAST_TEST_AUC)
+             echo "  Val AUC: $VAL_AUC"
+             echo "  Best Test AUC: $BEST_TEST_AUC"
+             echo "  Last Test AUC: $LAST_TEST_AUC"
+         else
+             echo "  Warning: Could not extract AUC values"
+             FAILED_SEEDS+=($SEED)
+         fi
+     else
+         echo ""
+         echo "✗ Job $((i+1)) failed"
+         FAILED_SEEDS+=($SEED)
+     fi
+
+     echo ""
+ done
+
+ echo "=========================================="
+ echo "All jobs completed!"
+ echo "=========================================="
+ echo ""
+
+ # Calculate statistics using Python
+ STATS_SCRIPT="$OUTPUT_DIR/calculate_stats.py"
+ cat > "$STATS_SCRIPT" << 'EOF'
+ import sys
+ import numpy as np
+
+ def calculate_stats(values):
+     if len(values) == 0:
+         return None, None
+     arr = np.array(values, dtype=float)
+     return np.mean(arr), np.std(arr)
+
+ # Read values from command line
+ val_aucs = [float(x) for x in sys.argv[1].split(',') if x]
+ best_test_aucs = [float(x) for x in sys.argv[2].split(',') if x]
+ last_test_aucs = [float(x) for x in sys.argv[3].split(',') if x]
182
+
183
+ val_mean, val_std = calculate_stats(val_aucs)
184
+ best_test_mean, best_test_std = calculate_stats(best_test_aucs)
185
+ last_test_mean, last_test_std = calculate_stats(last_test_aucs)
186
+
187
+ print(f"VAL_MEAN={val_mean:.4f}")
188
+ print(f"VAL_STD={val_std:.4f}")
189
+ print(f"BEST_TEST_MEAN={best_test_mean:.4f}")
190
+ print(f"BEST_TEST_STD={best_test_std:.4f}")
191
+ print(f"LAST_TEST_MEAN={last_test_mean:.4f}")
192
+ print(f"LAST_TEST_STD={last_test_std:.4f}")
193
+
194
+ # Print individual values
195
+ print("\nIndividual Results:")
196
+ for i, (val, test, last_test) in enumerate(zip(val_aucs, best_test_aucs, last_test_aucs), 1):
197
+ print(f" Run {i}: Val AUC = {val:.4f}, Best Test AUC = {test:.4f}, Last Test AUC = {last_test:.4f}")
198
+ EOF
199
+
200
+ # Convert arrays to comma-separated strings
201
+ VAL_AUCS_STR=$(IFS=,; echo "${VAL_AUCS[*]}")
202
+ BEST_TEST_AUCS_STR=$(IFS=,; echo "${BEST_TEST_AUCS[*]}")
203
+ LAST_TEST_AUCS_STR=$(IFS=,; echo "${LAST_TEST_AUCS[*]}")
204
+
205
+ # Calculate and display statistics
206
+ if [ ${#BEST_TEST_AUCS[@]} -gt 0 ]; then
207
+ echo "=========================================="
208
+ echo "FINAL RESULTS"
209
+ echo "=========================================="
210
+
211
+ STATS_OUTPUT=$(python "$STATS_SCRIPT" "$VAL_AUCS_STR" "$BEST_TEST_AUCS_STR" "$LAST_TEST_AUCS_STR")
212
+ echo "$STATS_OUTPUT"
213
+
214
+ VAL_MEAN=$(awk -F= '/^VAL_MEAN=/{print $2; exit}' <<<"$STATS_OUTPUT")
215
+ VAL_STD=$(awk -F= '/^VAL_STD=/{print $2; exit}' <<<"$STATS_OUTPUT")
216
+ BEST_TEST_MEAN=$(awk -F= '/^BEST_TEST_MEAN=/{print $2; exit}' <<<"$STATS_OUTPUT")
217
+ BEST_TEST_STD=$(awk -F= '/^BEST_TEST_STD=/{print $2; exit}' <<<"$STATS_OUTPUT")
218
+ LAST_TEST_MEAN=$(awk -F= '/^LAST_TEST_MEAN=/{print $2; exit}' <<<"$STATS_OUTPUT")
219
+ LAST_TEST_STD=$(awk -F= '/^LAST_TEST_STD=/{print $2; exit}' <<<"$STATS_OUTPUT")
220
+
221
+ echo ""
222
+ echo "Summary:"
223
+ echo " Validation AUC: ${VAL_MEAN} ± ${VAL_STD}"
224
+ echo " Test AUC: ${BEST_TEST_MEAN} ± ${BEST_TEST_STD}"
225
+ echo " Last Test AUC: ${LAST_TEST_MEAN} ± ${LAST_TEST_STD}"
226
+ echo ""
227
+ echo "Successful runs: ${#BEST_TEST_AUCS[@]}/$NUM_SEEDS"
228
+
229
+ if [ ${#FAILED_SEEDS[@]} -gt 0 ]; then
230
+ echo "Failed seeds: ${FAILED_SEEDS[@]}"
231
+ fi
232
+
233
+ echo "=========================================="
234
+
235
+ # Save summary to file
236
+ SUMMARY_FILE="$OUTPUT_DIR/summary.txt"
237
+ {
238
+ echo "Summary Report - $(date)"
239
+ echo "=================================="
240
+ echo "Session: $SESSION"
241
+ echo "Checkpoint: $CHECKPOINT"
242
+ echo "GPU: $GPU"
243
+ echo "Seeds: ${SEEDS[@]}"
244
+ echo ""
245
+ echo "FINAL RESULTS"
246
+ echo "=================================="
247
+ echo "$STATS_OUTPUT"
248
+ echo ""
249
+ echo "Summary:"
250
+ echo " Validation AUC: ${VAL_MEAN} ± ${VAL_STD}"
251
+ echo " BEST Test AUC: ${BEST_TEST_MEAN} ± ${BEST_TEST_STD}"
252
+ echo " Last Test AUC: ${LAST_TEST_MEAN} ± ${LAST_TEST_STD}"
253
+ echo ""
254
+ echo "Successful runs: ${#BEST_TEST_AUCS[@]}/$NUM_SEEDS"
255
+ if [ ${#FAILED_SEEDS[@]} -gt 0 ]; then
256
+ echo "Failed seeds: ${FAILED_SEEDS[@]}"
257
+ fi
258
+ } > "$SUMMARY_FILE"
259
+
260
+ echo ""
261
+ echo "Summary saved to: $SUMMARY_FILE"
262
+ echo "All logs saved to: $OUTPUT_DIR"
263
+ else
264
+ echo "ERROR: No successful runs completed"
265
+ exit 1
266
+ fi
267
+
268
+ # Clean up temporary script
269
+ rm "$STATS_SCRIPT"
270
+
271
+ # Exit with error if any jobs failed
272
+ if [ ${#FAILED_SEEDS[@]} -gt 0 ]; then
273
+ exit 1
274
+ fi
275
+
276
+ exit 0
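The `calculate_stats.py` heredoc above uses `np.mean`/`np.std` (population standard deviation, `ddof=0`). As a sanity check, here is a dependency-free sketch of the same computation using only the standard library; the function name and sample values are illustrative, not part of the repo:

```python
import statistics

def calculate_stats(csv):
    """Parse a comma-separated AUC string and return (mean, population std),
    mirroring np.mean / np.std(ddof=0) in the embedded calculate_stats.py."""
    values = [float(x) for x in csv.split(",") if x]
    if not values:
        return None, None
    # statistics.pstdev is the population std, matching numpy's default
    return statistics.fmean(values), statistics.pstdev(values)

mean, std = calculate_stats("0.8,0.9")
print(f"BEST_TEST_MEAN={mean:.4f}")  # prints BEST_TEST_MEAN=0.8500
print(f"BEST_TEST_STD={std:.4f}")    # prints BEST_TEST_STD=0.0500
```

Note that with only a handful of seeds the population std understates run-to-run variability slightly compared to the sample std (`ddof=1`); the scripts report the former.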
barista/utility_scripts/run_finetune_random_splits.sh ADDED
@@ -0,0 +1,267 @@
+ #!/bin/bash
+
+ # Usage:
+ # ./run_finetune_random_splits.sh --spe coords --checkpoint "pretrained_models/chans_chans.ckpt" --session HOLDSUBJ_2_HS2_6 --gpu 1 --exp sentence_onset
+ # ./run_finetune_random_splits.sh --spe destrieux --checkpoint "pretrained_models/parcels_chans.ckpt" --session HOLDSUBJ_2_HS2_6 --gpu 2 --exp speech_vs_nonspeech
+
+ # Default values
+ GPU=0
+ SEEDS=(0 1 2 3 4)
+ SESSION=""
+ CHECKPOINT=""
+ DATASET_CONFIG="barista/config/braintreebank.yaml"
+ TRAIN_CONFIG="barista/config/train.yaml"
+ MODEL_CONFIG="barista/config/model.yaml"
+ SPATIAL_GROUPING="coords"
+ EXPERIMENT="sentence_onset"
+
+ # Parse arguments
+ while [[ $# -gt 0 ]]; do
+     case $1 in
+         --session)
+             SESSION="$2"
+             shift 2
+             ;;
+         --checkpoint)
+             CHECKPOINT="$2"
+             shift 2
+             ;;
+         --gpu)
+             GPU="$2"
+             shift 2
+             ;;
+         --seeds)
+             IFS=',' read -ra SEEDS <<< "$2"
+             shift 2
+             ;;
+         --dataset_config)
+             DATASET_CONFIG="$2"
+             shift 2
+             ;;
+         --train_config)
+             TRAIN_CONFIG="$2"
+             shift 2
+             ;;
+         --exp)
+             EXPERIMENT="$2"
+             shift 2
+             ;;
+         --spe)
+             SPATIAL_GROUPING="$2"
+             shift 2
+             ;;
+         --model_config)
+             MODEL_CONFIG="$2"
+             shift 2
+             ;;
+         *)
+             echo "Unknown argument: $1"
+             echo "Usage: $0 --session <session_name> --checkpoint <checkpoint_path> [--gpu <gpu_id>] [--seeds <seed_list>] [--exp <experiment>] [--spe <spatial_grouping>] [--dataset_config <path>] [--train_config <path>] [--model_config <path>]"
+             exit 1
+             ;;
+     esac
+ done
+
+ # Validate required arguments
+ if [ -z "$SESSION" ]; then
+     echo "Error: --session is required"
+     exit 1
+ fi
+
+ NUM_SEEDS=${#SEEDS[@]}
+
+ # Create output directory
+ OUTPUT_DIR="results/${SESSION}_${EXPERIMENT}_model${SPATIAL_GROUPING}_$(date +%Y%m%d_%H%M%S)"
+ mkdir -p "$OUTPUT_DIR"
+
+ echo "=========================================="
+ echo "Sequential Multi-Seed Fine-tuning"
+ echo "=========================================="
+ echo "Session: $SESSION"
+ echo "Checkpoint: $CHECKPOINT"
+ echo "GPU: $GPU"
+ echo "Seeds: ${SEEDS[*]}"
+ echo "Number of runs: $NUM_SEEDS"
+ echo "Output Directory: $OUTPUT_DIR"
+ echo "=========================================="
+ echo ""
+
+ # Arrays to store results
+ VAL_AUCS=()
+ BEST_TEST_AUCS=()
+ LAST_TEST_AUCS=()
+ FAILED_SEEDS=()
+
+ # Run jobs sequentially
+ for i in $(seq 0 $((NUM_SEEDS - 1))); do
+     SEED=${SEEDS[$i]}
+     LOG_FILE="$OUTPUT_DIR/seed_${SEED}.log"
+
+     echo "=========================================="
+     echo "Running job $((i+1))/$NUM_SEEDS: Seed=$SEED"
+     echo "=========================================="
+     echo "Log file: $LOG_FILE"
+     echo ""
+
+     # Run training
+     CUDA_VISIBLE_DEVICES=$GPU python barista/train.py \
+         --dataset_config "$DATASET_CONFIG" \
+         --train_config "$TRAIN_CONFIG" \
+         --model_config "$MODEL_CONFIG" \
+         --override \
+         seed=$SEED \
+         device=cuda:0 \
+         checkpoint_path="$CHECKPOINT" \
+         force_nonoverlap=True \
+         experiment="$EXPERIMENT" \
+         tokenizer.spatial_grouping="$SPATIAL_GROUPING" \
+         "finetune_sessions=['$SESSION']" \
+         2>&1 | tee "$LOG_FILE"
+
+     # Check if job completed successfully
+     if [ ${PIPESTATUS[0]} -eq 0 ]; then
+         echo ""
+         echo "✓ Job $((i+1)) completed successfully"
+
+         # Extract results from log file (take the last match in case of repeats)
+         VAL_AUC=$(grep "BEST VAL AUC" "$LOG_FILE" | tail -1 | awk '{print $NF}')
+         BEST_TEST_AUC=$(grep "^BEST TEST AUC" "$LOG_FILE" | tail -1 | awk '{print $NF}')
+         LAST_TEST_AUC=$(grep "LAST TEST AUC" "$LOG_FILE" | tail -1 | awk '{print $NF}')
+
+         if [ -n "$VAL_AUC" ] && [ -n "$BEST_TEST_AUC" ] && [ -n "$LAST_TEST_AUC" ]; then
+             VAL_AUCS+=("$VAL_AUC")
+             BEST_TEST_AUCS+=("$BEST_TEST_AUC")
+             LAST_TEST_AUCS+=("$LAST_TEST_AUC")
+             echo "  Val AUC: $VAL_AUC"
+             echo "  Best Test AUC: $BEST_TEST_AUC"
+             echo "  Last Test AUC: $LAST_TEST_AUC"
+         else
+             echo "  Warning: Could not extract AUC values"
+             FAILED_SEEDS+=("$SEED")
+         fi
+     else
+         echo ""
+         echo "✗ Job $((i+1)) failed"
+         FAILED_SEEDS+=("$SEED")
+     fi
+
+     echo ""
+ done
+
+ echo "=========================================="
+ echo "All jobs completed!"
+ echo "=========================================="
+ echo ""
+
+ # Calculate statistics using Python
+ STATS_SCRIPT="$OUTPUT_DIR/calculate_stats.py"
+ cat > "$STATS_SCRIPT" << 'EOF'
+ import sys
+ import numpy as np
+
+ def calculate_stats(values):
+     if len(values) == 0:
+         return None, None
+     arr = np.array(values, dtype=float)
+     return np.mean(arr), np.std(arr)
+
+ # Read values from command line
+ val_aucs = [float(x) for x in sys.argv[1].split(',') if x]
+ best_test_aucs = [float(x) for x in sys.argv[2].split(',') if x]
+ last_test_aucs = [float(x) for x in sys.argv[3].split(',') if x]
+
+ val_mean, val_std = calculate_stats(val_aucs)
+ best_test_mean, best_test_std = calculate_stats(best_test_aucs)
+ last_test_mean, last_test_std = calculate_stats(last_test_aucs)
+
+ print(f"VAL_MEAN={val_mean:.4f}")
+ print(f"VAL_STD={val_std:.4f}")
+ print(f"BEST_TEST_MEAN={best_test_mean:.4f}")
+ print(f"BEST_TEST_STD={best_test_std:.4f}")
+ print(f"LAST_TEST_MEAN={last_test_mean:.4f}")
+ print(f"LAST_TEST_STD={last_test_std:.4f}")
+
+ # Print individual values
+ print("\nIndividual Results:")
+ for i, (val, test, last_test) in enumerate(zip(val_aucs, best_test_aucs, last_test_aucs), 1):
+     print(f"  Run {i}: Val AUC = {val:.4f}, Best Test AUC = {test:.4f}, Last Test AUC = {last_test:.4f}")
+ EOF
+
+ # Convert arrays to comma-separated strings
+ VAL_AUCS_STR=$(IFS=,; echo "${VAL_AUCS[*]}")
+ BEST_TEST_AUCS_STR=$(IFS=,; echo "${BEST_TEST_AUCS[*]}")
+ LAST_TEST_AUCS_STR=$(IFS=,; echo "${LAST_TEST_AUCS[*]}")
+
+ # Calculate and display statistics
+ if [ ${#BEST_TEST_AUCS[@]} -gt 0 ]; then
+     echo "=========================================="
+     echo "FINAL RESULTS"
+     echo "=========================================="
+
+     STATS_OUTPUT=$(python "$STATS_SCRIPT" "$VAL_AUCS_STR" "$BEST_TEST_AUCS_STR" "$LAST_TEST_AUCS_STR")
+     echo "$STATS_OUTPUT"
+
+     VAL_MEAN=$(awk -F= '/^VAL_MEAN=/{print $2; exit}' <<<"$STATS_OUTPUT")
+     VAL_STD=$(awk -F= '/^VAL_STD=/{print $2; exit}' <<<"$STATS_OUTPUT")
+     BEST_TEST_MEAN=$(awk -F= '/^BEST_TEST_MEAN=/{print $2; exit}' <<<"$STATS_OUTPUT")
+     BEST_TEST_STD=$(awk -F= '/^BEST_TEST_STD=/{print $2; exit}' <<<"$STATS_OUTPUT")
+     LAST_TEST_MEAN=$(awk -F= '/^LAST_TEST_MEAN=/{print $2; exit}' <<<"$STATS_OUTPUT")
+     LAST_TEST_STD=$(awk -F= '/^LAST_TEST_STD=/{print $2; exit}' <<<"$STATS_OUTPUT")
+
+     echo ""
+     echo "Summary:"
+     echo "  Validation AUC: ${VAL_MEAN} ± ${VAL_STD}"
+     echo "  Best Test AUC: ${BEST_TEST_MEAN} ± ${BEST_TEST_STD}"
+     echo "  Last Test AUC: ${LAST_TEST_MEAN} ± ${LAST_TEST_STD}"
+     echo ""
+     echo "Successful runs: ${#BEST_TEST_AUCS[@]}/$NUM_SEEDS"
+
+     if [ ${#FAILED_SEEDS[@]} -gt 0 ]; then
+         echo "Failed seeds: ${FAILED_SEEDS[*]}"
+     fi
+
+     echo "=========================================="
+
+     # Save summary to file
+     SUMMARY_FILE="$OUTPUT_DIR/summary.txt"
+     {
+         echo "Summary Report - $(date)"
+         echo "=================================="
+         echo "Session: $SESSION"
+         echo "Checkpoint: $CHECKPOINT"
+         echo "GPU: $GPU"
+         echo "Seeds: ${SEEDS[*]}"
+         echo ""
+         echo "FINAL RESULTS"
+         echo "=================================="
+         echo "$STATS_OUTPUT"
+         echo ""
+         echo "Summary:"
+         echo "  Validation AUC: ${VAL_MEAN} ± ${VAL_STD}"
+         echo "  Best Test AUC: ${BEST_TEST_MEAN} ± ${BEST_TEST_STD}"
+         echo "  Last Test AUC: ${LAST_TEST_MEAN} ± ${LAST_TEST_STD}"
+         echo ""
+         echo "Successful runs: ${#BEST_TEST_AUCS[@]}/$NUM_SEEDS"
+         if [ ${#FAILED_SEEDS[@]} -gt 0 ]; then
+             echo "Failed seeds: ${FAILED_SEEDS[*]}"
+         fi
+     } > "$SUMMARY_FILE"
+
+     echo ""
+     echo "Summary saved to: $SUMMARY_FILE"
+     echo "All logs saved to: $OUTPUT_DIR"
+ else
+     echo "ERROR: No successful runs completed"
+     exit 1
+ fi
+
+ # Clean up temporary script
+ rm "$STATS_SCRIPT"
+
+ # Exit with error if any jobs failed
+ if [ ${#FAILED_SEEDS[@]} -gt 0 ]; then
+     exit 1
+ fi
+
+ exit 0
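Both scripts recover `KEY=VALUE` pairs from the stats output with `awk -F= '/^KEY=/{print $2; exit}'`. An equivalent helper in Python (hypothetical, for illustration only; the sample numbers are made up):

```python
def extract_stat(stats_output, key):
    """Return the value of the first line starting with 'key=',
    like awk -F= '/^KEY=/{print $2; exit}' in the scripts above."""
    prefix = key + "="
    for line in stats_output.splitlines():
        if line.startswith(prefix):
            return line[len(prefix):]
    return None  # key not present

sample = "VAL_MEAN=0.9123\nVAL_STD=0.0151\nBEST_TEST_MEAN=0.8842"
print(extract_stat(sample, "VAL_MEAN"))  # prints 0.9123
```

Matching on the start of the line (the `^` anchor in awk) matters here: a plain `grep KEY=` would also match `BEST_TEST_MEAN=` when looking for `TEST_MEAN=`.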
pretrained_models/chans_chans.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:400eeacc0697004cb81c9ecf754859da184ffeea40afc8ee7b5930c3b997e1d0
+ size 3538414
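The three `.ckpt` entries above are Git LFS pointer files, not the weights themselves (fetch the real checkpoints with `git lfs pull`). Each pointer is a few `key value` lines, as shown; a minimal sketch of parsing that format (the function name is illustrative, not part of the repo):

```python
def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file into a dict of its key/value fields."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")  # split on the first space only
        fields[key] = value
    return fields

pointer = (
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:400eeacc0697004cb81c9ecf754859da184ffeea40afc8ee7b5930c3b997e1d0\n"
    "size 3538414\n"
)
info = parse_lfs_pointer(pointer)
print(int(info["size"]))  # prints 3538414
```

The `size` field is the byte count of the real object, so a quick way to spot an unfetched checkpoint is a `.ckpt` file on disk that is only a few hundred bytes.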
pretrained_models/lobes_chans.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d810338a4929df0fb2421f342b3ee859f9fef269e35fb4f2fd9c55347a63324a
+ size 3389478
pretrained_models/parcels_chans.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6c234517d286a8e710b09716dc88c713618670df523cfffb89e4c9073f2657c1
+ size 3415452
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ torch==2.4.0
+ einops==0.8.0
+ h5py==3.11.0
+ ipykernel==6.29.5
+ ipython==8.12.3
+ jupyter-client==8.6.3
+ jupyter-core==5.7.2
+ numpy==1.24.4
+ omegaconf==2.3.0
+ ordered-set==4.1.0
+ pandas==2.0.3
+ scikit-learn==1.3.2
+ scipy==1.10.1
+ xformers==0.0.27.post2
+ tabulate==0.9.0
setup.py ADDED
@@ -0,0 +1,25 @@
+ from setuptools import find_packages, setup
+
+ with open("requirements.txt", "r") as f:
+     requirements = f.read().splitlines()
+
+ setup(
+     name="barista",
+     version="1.0.0",
+     description="PyTorch implementation of BaRISTA: Brain Scale Informed Spatiotemporal Representation of Human Intracranial Neural Activity",
+     long_description=open("README.md", encoding="utf-8").read(),
+     long_description_content_type="text/markdown",
+     author="Lucine L. Oganesian, Saba Hashemi, Maryam M. Shanechi",
+     author_email="shanechi@usc.edu",
+     url="https://github.com/ShanechiLab/BaRISTA",  # change to actual repo URL
+     packages=find_packages(),
+     python_requires=">=3.8",
+     install_requires=requirements,
+     include_package_data=True,
+     entry_points={
+         "console_scripts": [
+             "barista-train=barista.train:main",
+             "barista-prepare=barista.prepare_segments:main",
+         ],
+     },
+ )
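`setup.py` pins dependencies by reading `requirements.txt` verbatim, where every line is an exact `name==version` pin. A small sketch of splitting such a pin (the helper name is illustrative, not part of the repo):

```python
def parse_pin(line):
    """Split a 'name==version' requirement pin into its two parts."""
    name, _, version = line.strip().partition("==")
    return name, version

print(parse_pin("torch==2.4.0"))  # prints ('torch', '2.4.0')
```

Because `install_requires` receives these exact pins, installing the package will force those specific versions; relaxing the pins to ranges is a deliberate packaging decision, not something this sketch does.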