---
license: cc-by-nc-sa-4.0
tags:
- fMRI
- neuroscience
- foundation_model
---
# Flexibrain
Flexibrain is a voxel-level fMRI representation learning framework for pretraining and downstream classification. It keeps fMRI volumes on a fixed 96 x 96 x 96 input grid, reads each sample's voxel spacing and repetition time (TR) from the NIfTI header, and resizes the patch embedding kernels so that patches are defined in physical units (space and time) rather than in voxels and frames. The resulting tokens are learned with a Mamba-JEPA backbone.
<p align="center">
<img src="assets/pipeline.png" width="900" alt="FlexiBrain framework pipeline">
</p>
## Installation
The code was tested on an NVIDIA L40 GPU with Python 3.10, PyTorch 2.1.2, CUDA 12.1, `causal-conv1d`, `mamba-ssm`, and `flash-attn`.
```bash
conda create -n flexibrain python=3.10
conda activate flexibrain
pip install -r requirements.txt
pip install -e .
```
Check the CLI:
```bash
python -m flexibrain --help
python -m flexibrain pretrain --help
python -m flexibrain downstream --help
```
## Data Preparation
Each sample should be a 4D NIfTI file shaped as:
```text
96 x 96 x 96 x T
```
Flexibrain uses the NIfTI header to read voxel spacing and TR. If a dataset has missing TR metadata, fix the header before training or pass an explicit fallback with `--default-tr` / `data.default_tr`.
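To check a dataset's headers before training, a small helper along these lines may be useful. It is a sketch, not Flexibrain's own loader: the function names are hypothetical, it assumes the fourth zoom is the TR in seconds, and it treats a missing or non-positive TR as "missing metadata". `nibabel` is imported lazily and is only needed for the file-reading wrapper.

```python
from typing import Optional, Sequence, Tuple


def resolve_spacing_and_tr(
    zooms: Sequence[float], default_tr: Optional[float] = None
) -> Tuple[Tuple[float, float, float], float]:
    """Split NIfTI zooms into (voxel spacing, TR), applying a fallback TR if needed."""
    if len(zooms) < 3:
        raise ValueError(f"expected at least 3 zooms, got {len(zooms)}")
    spacing = (float(zooms[0]), float(zooms[1]), float(zooms[2]))
    # Assumption: a missing or non-positive 4th zoom means the header has no usable TR.
    tr = float(zooms[3]) if len(zooms) > 3 else 0.0
    if tr <= 0.0:
        if default_tr is None:
            raise ValueError("TR missing from header and no fallback given")
        tr = float(default_tr)
    return spacing, tr


def read_spacing_and_tr(path: str, default_tr: Optional[float] = None):
    """Read zooms from a NIfTI file with nibabel (requires `pip install nibabel`)."""
    import nibabel as nib  # imported lazily so the helper above has no dependencies

    header = nib.load(path).header
    return resolve_spacing_and_tr(header.get_zooms(), default_tr)
```

Note that some headers store the time unit in milliseconds; this sketch does not convert units, so inspect suspiciously large TR values by hand.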
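`T_prime` and `tau_seconds` control the selected temporal length. `kt` is the number of frames that span `tau_seconds` at the sample's TR, and `T_selected` is the resulting number of frames: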
```text
kt = round(tau_seconds / TR)
T_selected = T_prime * kt
```
The preprocessing script can convert native/T1/MNI-space inputs to 96 x 96 x 96, apply sample-wise global z-score normalization over foreground voxels, and write 4D NIfTI outputs:
```bash
python data_process.py \
--input-root /path/to/input_root \
--output-root /path/to/output_root \
--spaces all \
--groups class0,class1,class2
```
Expected grouped input layout:
```text
input_root/
|-- nativespace/class0/*.nii.gz
|-- t1space/class0/*.nii.gz
`-- mnispace/class0/*.nii.gz
```
If files are not organized by group subfolders, omit `--groups`. For MNI-space inputs, provide `--template-mask` when the default mask is not available.
Pretraining list files contain one NIfTI path per line:
```text
/path/to/sub-0001_bold.nii.gz
/path/to/sub-0002_bold.nii.gz
```
Downstream classification uses the same list format plus a CSV label table:
```csv
Subject,Group_idx
003_S_0908,2
011_S_0002,1
1001,0
```
Default label fields are `Subject` and `Group_idx`. `path_id_mode=auto` supports ADNI-style IDs such as `003_S_0908`, ADHD-style filenames, and fallback digit IDs.
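To see how paths and label rows might be joined, consider the sketch below. It is not the repository's `path_id_mode=auto` implementation: the regular expression for ADNI-style IDs and the "first run of digits in the filename" fallback are assumptions, and ADHD-style filenames are not modelled because their format is not described here.

```python
import csv
import re
from pathlib import Path
from typing import Dict, Optional

# Assumption: ADNI-style IDs look like three digits, "_S_", four digits.
ADNI_ID = re.compile(r"\d{3}_S_\d{4}")


def load_labels(
    csv_path: str, subject_field: str = "Subject", label_field: str = "Group_idx"
) -> Dict[str, int]:
    """Map subject ID -> integer class index from the label table."""
    with open(csv_path, newline="", encoding="utf-8") as f:
        return {
            row[subject_field].strip(): int(row[label_field])
            for row in csv.DictReader(f)
        }


def subject_id_from_path(path: str) -> Optional[str]:
    """Extract a subject ID from a filename: ADNI pattern first, then digits."""
    name = Path(path).name
    match = ADNI_ID.search(name)
    if match:
        return match.group(0)
    match = re.search(r"\d+", name)  # fallback: first run of digits
    return match.group(0) if match else None
```

Running such a check over every list file before training makes it easy to spot paths whose extracted ID has no row in the CSV.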
## Pretraining
Run from a config:
```bash
python -m flexibrain pretrain --config configs/pretrain_example.yaml
```
Or use CLI arguments:
```bash
python -m flexibrain pretrain \
--train-list /path/to/pretrain_train.txt \
--val-list /path/to/pretrain_val.txt \
--checkpoint-dir ./checkpoints/pretrain/example \
--log-dir ./logs/pretrain/example \
--embed-dim 512 \
--depth 24 \
--predictor-depth 2 \
--bimamba-type v2 \
--if-devide-out \
--batch-size 4 \
--epochs 30 \
--lr 5e-4 \
--weight-decay 0.05 \
--warmup-epochs 3 \
--mask-ratio 0.65 \
--grad-accumulation-steps 4 \
--t-prime 30 \
--tau-seconds 6.0 \
--use-amp
```
Outputs:
```text
checkpoint_latest.pt
checkpoint_best.pt
pretrain_*.log
```
## Downstream Classification
Run from a config:
```bash
python -m flexibrain downstream --config configs/downstream_example.yaml
```
Or use CLI arguments:
```bash
python -m flexibrain downstream \
--train-list /path/to/downstream_train.txt \
--val-list /path/to/downstream_val.txt \
--test-list /path/to/downstream_test.txt \
--csv /path/to/labels.csv \
--pretrain-checkpoint /path/to/checkpoint_best.pt \
--num-classes 3 \
--head-type transformer \
--batch-size 8 \
--epochs 30 \
--lr 1e-5 \
--lr-backbone 6e-6 \
--lr-head 6e-5 \
--checkpoint-dir ./checkpoints/downstream/example \
--log-dir ./logs/downstream/example \
--use-amp
```
During downstream training, validation metrics select `downstream_best.pt`. The test set is evaluated once at the end after loading that best validation checkpoint, and the final metrics are written to `test_metrics.json`.
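The selection protocol described above can be sketched independently of any model code. Everything here is illustrative: `train_with_selection` and its callback arguments are hypothetical, the validation metric is assumed to be a single scalar where higher is better, and keeping the best state in memory stands in for saving and reloading `downstream_best.pt`.

```python
import json
from typing import Any, Callable, Dict


def train_with_selection(
    epochs: int,
    train_epoch: Callable[[int], Any],
    eval_val: Callable[[Any], float],
    eval_test: Callable[[Any], Dict[str, Any]],
    out_path: str,
) -> Dict[str, Any]:
    """Pick the best epoch on validation, then evaluate the test set exactly once."""
    best_score, best_state = float("-inf"), None
    for epoch in range(epochs):
        state = train_epoch(epoch)
        score = eval_val(state)
        if score > best_score:
            # Stand-in for writing downstream_best.pt.
            best_score, best_state = score, state
    # The test set is touched only here, after model selection has finished.
    metrics = eval_test(best_state)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    return metrics
```

Because the test split never influences which checkpoint is kept, the reported test metrics are not biased by checkpoint selection.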
## Configuration
The YAML configs mirror the CLI options. Keep private paths in local config files and leave shared configs as portable examples. The provided examples use placeholder paths under `data/`:
```text
configs/pretrain_example.yaml
configs/downstream_example.yaml
```
## Checkpoint Compatibility
The downstream loader can initialize from the original pretraining checkpoint path format:
```text
/path/to/checkpoint_best.pt
```
When `use_checkpoint_config: true`, model-shape settings stored in the checkpoint are applied before loading the backbone.
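The effect of `use_checkpoint_config` can be pictured as a selective dictionary merge. The sketch below is a guess at the behavior, not the loader's code: `resolve_model_config` is a hypothetical helper, and the list of model-shape keys is an assumption derived from the CLI flags `--embed-dim`, `--depth`, and `--predictor-depth`.

```python
from typing import Any, Dict, Iterable

# Hypothetical set of "model-shape" keys; the real list may differ.
SHAPE_KEYS = ("embed_dim", "depth", "predictor_depth")


def resolve_model_config(
    run_cfg: Dict[str, Any],
    ckpt_cfg: Dict[str, Any],
    use_checkpoint_config: bool,
    shape_keys: Iterable[str] = SHAPE_KEYS,
) -> Dict[str, Any]:
    """Override shape settings with checkpoint values; leave other options alone."""
    resolved = dict(run_cfg)  # copy so the caller's config is not mutated
    if use_checkpoint_config:
        for key in shape_keys:
            if key in ckpt_cfg:
                resolved[key] = ckpt_cfg[key]
    return resolved
```

Under this reading, training options such as the learning rate still come from the downstream config, while the backbone dimensions follow the checkpoint so that its weights load without shape mismatches.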
## License
This repository is provided for non-commercial research use under CC BY-NC-SA 4.0. See `LICENSE` and `NOTICE` for license boundaries and preserved notices.