# Flow Matching fMRI Encoder
A two-stage architecture for predicting fMRI BOLD responses from naturalistic video stimuli (Friends TV show + Movie10 dataset) using Conditional Flow Matching.
---
## Overview
The pipeline predicts brain activity from stimulus features in two sequential stages:
| Stage | Model | Goal |
|---|---|---|
| **1** | `MultiSubjectConvLinearEncoder` | Predict a **Mean Anchor**: a deterministic per-voxel fMRI estimate shared across subjects |
| **2** | `CFM` (Conditional Flow Matching) | Learn a **per-subject neural vector field** that refines the Mean Anchor into a sharper, stochastic fMRI prediction |
The design mirrors the MedARC approach: Stage 1 provides a stable conditional mean $\mu$; Stage 2 integrates a continuous normalizing flow $\phi_t$ conditioned on $\mu$ to draw samples from the learned conditional distribution over voxel activations.
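For reference, the Stage 2 objective can be written in the optimal-transport CFM form used by Matcha-TTS. This is a sketch that assumes the repository keeps that formulation unchanged. The symbols $v_\theta$ (the learned vector field), $x_0 \sim \mathcal{N}(0, I)$ (noise) and $x_1$ (the target fMRI sample) are notation introduced here, not names taken from the code; $\sigma_{\min}$ corresponds to the `sigma_min` hyperparameter.
$$x_t = \big(1 - (1 - \sigma_{\min})\,t\big)\,x_0 + t\,x_1, \qquad u_t = \frac{d x_t}{dt} = x_1 - (1 - \sigma_{\min})\,x_0$$
$$\mathcal{L}_{\text{CFM}} = \mathbb{E}_{t \sim \mathcal{U}[0,1],\; x_0,\; x_1}\,\big\| v_\theta(x_t, t \mid \mu) - u_t \big\|^2$$
Under this form the regression target $u_t$ is constant along each straight-line path from $x_0$ to $x_1$, which is why a small number of Euler steps can be sufficient at inference.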
---
## Architecture
### Stage 1: Mean Anchor (`medarc_architecture.py`)
```
Features (N, T, D_i) → DepthConv1d + Linear → embed (N, T, d)
                          ↓
        [Shared Decoder] + [Subject Decoders]
                          ↓
            fMRI Prediction (N, S, T, V)
```
- **`MultiSubjectConvLinearEncoder`**: Projects each feature stream through a `LinearConv` (depthwise conv + linear) to a shared embedding dimension `d = 192`.
- Global average pooling across the feature stack combines multi-model features.
- A shared linear decoder and per-subject linear decoders combine additively to predict `V` voxels for each of `S` subjects.
- Trained with MSE loss against ground-truth BOLD.
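The pooling and additive decoding steps above can be sketched with plain NumPy. This is an illustration of the tensor shapes only, not the repository's implementation: the sizes, the random weights, and the omission of bias terms and the convolutional front end are all assumptions made for brevity.
```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (hypothetical, far smaller than the real model).
N, T, d, S, V = 2, 10, 8, 4, 16   # batch, time, embed dim, subjects, voxels
F = 3                             # number of feature streams

# One embedding per feature stream, assumed already projected to dim d.
stream_embeds = rng.standard_normal((F, N, T, d))

# Global average pooling across the feature stack -> (N, T, d).
embed = stream_embeds.mean(axis=0)

# Shared decoder (d -> V) plus one linear decoder per subject (S, d, V).
w_shared = rng.standard_normal((d, V)) * 0.1
w_subject = rng.standard_normal((S, d, V)) * 0.1

shared_out = embed @ w_shared                               # (N, T, V)
subject_out = np.einsum("ntd,sdv->nstv", embed, w_subject)  # (N, S, T, V)

# Additive combination: the shared term is broadcast over the subject axis.
pred = shared_out[:, None, :, :] + subject_out              # (N, S, T, V)
print(pred.shape)
```
Because both decoders are linear, the prediction for subject `s` equals a single linear map with weights `w_shared + w_subject[s]`; the split lets the shared term capture structure common to all subjects.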
### Stage 2: Neural Vector Field (`matcha_architecture.py`)
```
x1 (B, V, T) → proj_in (V → d) → latent x1 (B, d, T)
mu (B, V, T) → proj_in (V → d) → latent mu (B, d, T)
                       ↓
         OT-CFM loss (vector field u_t)
           Matcha-TTS U-Net estimator
        [Conformer / Transformer blocks]
                       ↓
latent pred (B, d, T) → proj_out (d → V) → fMRI (B, V, T)
```
- **Latent Bottleneck**: fMRI voxels (`V ≈ 1000`) are projected down to a dense latent dimension (`d = 128`) before any convolution, reducing the first-layer parameter count from ~6M to ~98K and preventing gradient collapse.
- **`CFM`** wraps a Matcha-TTS style **1D U-Net** (`Decoder`) with ResNet-1D blocks and Conformer/Transformer attention at each scale.
- At inference, noise $z \sim \mathcal{N}(0,I)$ is integrated from $t=0$ to $t=1$ over 25 Euler steps conditioned on $\mu$.
- An auxiliary reconstruction loss (weight `0.1`) on `proj_in → proj_out` trains the projection pair jointly with the vector field.
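The inference step described above (noise integrated from $t=0$ to $t=1$ with fixed-step Euler) can be sketched as follows. The function names `euler_sample` and `toy_field` are hypothetical, and `toy_field` is a closed-form stand-in for the learned U-Net: it is the straight-line field that carries any starting point to $\mu$ at $t=1$, chosen only so the result can be checked by hand.
```python
import numpy as np

def euler_sample(vector_field, mu, n_steps=25, seed=0):
    """Integrate dx/dt = vector_field(x, t, mu) from t=0 to t=1 with fixed-step Euler."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(mu.shape)      # z ~ N(0, I)
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = k / n_steps                    # t runs over 0, dt, ..., 1 - dt
        x = x + dt * vector_field(x, t, mu)
    return x

def toy_field(x, t, mu):
    # Stand-in for the learned estimator (an assumption for this sketch):
    # the straight-line field that moves x onto mu by t = 1.
    return (mu - x) / (1.0 - t)

mu = np.full((2, 128, 16), 0.5)            # latent-shaped anchor (B, d, T)
sample = euler_sample(toy_field, mu)
print(np.abs(sample - mu).max())           # numerically close to zero
```
With the real model, `vector_field` would be the trained estimator and the output would be a latent sample that still needs `proj_out` to return to voxel space; different seeds would then give different samples around the Mean Anchor.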
---
## Data
| Source | Content | Usage |
|---|---|---|
| **Friends** (seasons 1–7) | fMRI BOLD + multimodal features | Train (S1), Val (S6, S7) |
| **Movie10** | fMRI BOLD + multimodal features | Supplementary val (Figures, Life, Bourne, Wolf) |
Subjects used: **1, 2, 3, 5**.
### Feature Models
The encoder can ingest intermediate activations from any combination of:
| Key | Model |
|---|---|
| `internvl3_8b`, `internvl3_14b` | InternVL3 vision-language model |
| `qwen-2-5-omni-3b`, `qwen-2-5-omni-7b` | Qwen2.5-Omni audio-video model |
| `whisper` | OpenAI Whisper (audio) |
| `llama_3.2_1b`, `llama_3.2_3b` | LLaMA 3.2 (text) |
| `vjepa2` | V-JEPA 2 (video) |
Active features are set in `config.yml` under `include_features`.
---
## File Structure
```
flow_matching/
├── config.yml                  # Main training config (GPU, full data)
├── debug_config.yml            # Fast local debug config (CPU, tiny data)
├── environment.yml             # Conda environment spec
│
├── src/
│   ├── training.py             # Two-stage training loop + evaluation
│   ├── matcha_architecture.py  # CFM + Matcha-TTS U-Net decoder
│   ├── medarc_architecture.py  # Stage 1 MultiSubjectConvLinearEncoder
│   ├── data.py                 # Algonauts2025 dataset + loaders
│   ├── metric.py               # Pearson's r voxel-wise scoring
│   ├── visualize.py            # Loss curve plotting
│   └── inference.py            # Standalone inference helper
│
├── test/
│   ├── overfit_test.py         # Tiny-batch overfit sanity check for Stage 2
│   ├── check_pearson.py        # Load checkpoints and plot per-voxel Pearson's r heatmaps
│   └── debug_training.py       # End-to-end smoke test
│
├── experiments/
│   └── *.ipynb                 # Analysis notebooks (RSA, OOD, brain region plots)
│
└── Matcha-TTS/                 # Vendored Matcha-TTS source (U-Net + solver)
```
---
## Training
### Full training (server)
```bash
cd flow_matching
python src/training.py --cfg-path config.yml
```
Checkpoints are written to `output/two_stage_encoding/`:
- `stage1_best.pt`: best Stage 1 model by validation Pearson's r
- `stage2_epoch_N.pt`: Stage 2 snapshot every 5 epochs
### Local debug (CPU, tiny model)
```bash
python src/training.py --cfg-path debug_config.yml
```
---
## Evaluation
### Pearson's r heatmaps
Loads all available Stage 1 and Stage 2 checkpoints, evaluates on the configured validation set, and saves per-subject per-voxel Pearson's r heatmaps to `output/two_stage_encoding/heatmaps/`.
```bash
python test/check_pearson.py
```
**Output per checkpoint:**
```
Stage 1 Overall Pearson's r: 0.1832
Stage 1 - Sub 1 Mean Pearson's r: 0.1754
Stage 2 Epoch 5 Overall Pearson's r: 0.2110
Stage 2 Epoch 5 - Sub 1 Mean Pearson's r: 0.2043
...
```
### Tiny-batch overfit test
Confirms that Stage 2 can memorize a single training batch. If the loss does not approach `0` within 500 steps, suspect a bug in the model, the loss, or the data pipeline before launching a full training run.
```bash
python test/overfit_test.py --cfg-path config.yml --subject-idx 0 --steps 500
```
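The idea behind this check is generic and can be illustrated without the repository code. The sketch below uses a linear model and plain gradient descent as a stand-in for Stage 2 (an assumption for illustration only; the sizes and learning rate are hypothetical): it trains on one fixed batch and asserts that the loss collapses.
```python
import numpy as np

rng = np.random.default_rng(0)

# One fixed batch: 32 samples, 4 input features, 3 targets (hypothetical sizes).
n = 32
x = rng.standard_normal((n, 4))
y = x @ rng.standard_normal((4, 3))   # targets a linear model can fit exactly

w = np.zeros((4, 3))
lr = 0.5
losses = []
for step in range(500):
    err = x @ w - y
    # Loss: half the squared error, summed over targets, averaged over samples.
    losses.append(0.5 * float((err ** 2).sum()) / n)
    grad = x.T @ err / n              # gradient of that loss with respect to w
    w -= lr * grad

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.2e}")
assert losses[-1] < 1e-6 * losses[0], "failed to memorize a single batch"
```
A model that cannot pass a check of this kind on its own training batch has a problem unrelated to generalization, which is why it is a cheap first diagnostic.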
---
## Key Hyperparameters
| Parameter | Value | Location |
|---|---|---|
| Stage 1 embed dim | 192 | `config.yml / stage1.model.embed_dim` |
| Stage 1 encoder kernel | 45 | `config.yml / stage1.model.encoder_kernel_size` |
| Stage 1 LR | 3e-4 | `config.yml / stage1.lr` |
| Stage 2 latent dim | 128 | `config.yml / stage2.latent_dim` |
| Stage 2 U-Net channels | [256, 256] | `config.yml / stage2.decoder.channels` |
| Stage 2 block type | Conformer | `config.yml / stage2.decoder.*_block_type` |
| Stage 2 LR | 3e-4 | `config.yml / stage2.lr` |
| Euler integration steps | 25 | `config.yml / stage2.n_timesteps` |
| CFM σ_min | 1e-4 | `config.yml / stage2.cfm.sigma_min` |
---
## Metric
Evaluation uses **voxel-wise Pearson's r** averaged across subjects:
$$r_v = \frac{\sum_t (y_v^t - \bar{y}_v)(\hat{y}_v^t - \bar{\hat{y}}_v)}{\sqrt{\sum_t (y_v^t - \bar{y}_v)^2}\,\sqrt{\sum_t (\hat{y}_v^t - \bar{\hat{y}}_v)^2}}$$
The scalar reported is the mean over all `V` voxels and all `S` subjects.
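A minimal NumPy sketch of the voxel-wise score follows. It computes the standard Pearson correlation, with both time series mean-centered before normalizing; the function name `voxelwise_pearson`, the `(T, V)` array layout, and the `eps` stabilizer are assumptions for illustration, not the contents of `metric.py`.
```python
import numpy as np

def voxelwise_pearson(y_true, y_pred, eps=1e-8):
    """Pearson's r per voxel for arrays shaped (T, V); returns shape (V,)."""
    yt = y_true - y_true.mean(axis=0, keepdims=True)   # center over time
    yp = y_pred - y_pred.mean(axis=0, keepdims=True)
    num = (yt * yp).sum(axis=0)
    den = np.linalg.norm(yt, axis=0) * np.linalg.norm(yp, axis=0)
    return num / (den + eps)                           # eps guards constant voxels

rng = np.random.default_rng(0)
y = rng.standard_normal((100, 5))                      # T=100 time points, V=5 voxels
noisy = y + 0.5 * rng.standard_normal((100, 5))

r_same = voxelwise_pearson(y, 2.0 * y + 3.0)           # exact linear relation
r_flip = voxelwise_pearson(y, -y)                      # sign-flipped prediction
r_noisy = voxelwise_pearson(y, noisy)
print(r_noisy.mean())                                  # scalar summary over voxels
```
Here `r_same` is close to 1 and `r_flip` close to -1 for every voxel, since Pearson's r is invariant to positive scaling and offset. For the reported scalar, the per-voxel values would be averaged over voxels and then over subjects.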