# Flow Matching fMRI Encoder

A two-stage architecture for predicting fMRI BOLD responses from naturalistic video stimuli (Friends TV show + Movie10 dataset) using Conditional Flow Matching.

---

## Overview

The pipeline predicts brain activity from stimulus features in two sequential stages:

| Stage | Model | Goal |
|---|---|---|
| **1** | `MultiSubjectConvLinearEncoder` | Predict a **Mean Anchor**: a deterministic per-voxel fMRI estimate from an encoder shared across subjects |
| **2** | `CFM` (Conditional Flow Matching) | Learn a **per-subject neural vector field** that refines the Mean Anchor into a sharper, stochastic fMRI prediction |

The design mirrors the MedARC approach: Stage 1 provides a stable conditional mean $\mu$; Stage 2 integrates a continuous normalizing flow $\phi_t$ conditioned on $\mu$ to draw samples that approximate the conditional distribution of voxel activations.
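
For concreteness, the objective below is the OT-CFM formulation used by Matcha-TTS, whose U-Net and solver are vendored in this repo; that Stage 2 keeps it unchanged is an assumption. Here $x_0 \sim \mathcal{N}(0, I)$ is the noise sample, $x_1$ the latent target, $\psi_t$ the nearly straight conditional path between them, $v_\theta$ the U-Net estimator, and $\sigma_{\min}$ the CFM hyperparameter listed under Key Hyperparameters:

$$\psi_t(x_0) = \bigl(1 - (1 - \sigma_{\min})\,t\bigr)\,x_0 + t\,x_1, \qquad u_t = \frac{d\psi_t}{dt} = x_1 - (1 - \sigma_{\min})\,x_0$$

$$\mathcal{L}_{\text{CFM}} = \mathbb{E}_{t \sim \mathcal{U}[0,1],\; x_0 \sim \mathcal{N}(0,I)} \bigl\| v_\theta\bigl(\psi_t(x_0), t \mid \mu\bigr) - u_t \bigr\|^2$$

At $t = 1$ the path reaches $x_1 + \sigma_{\min} x_0$, i.e. the target up to noise of scale $\sigma_{\min}$.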

---

## Architecture

### Stage 1: Mean Anchor (`medarc_architecture.py`)

```
Features (N, T, D_i)  →  DepthConv1d + Linear  →  embed (N, T, d)
                                                         ↓
                                               [Shared Decoder]  +  [Subject Decoders]
                                                         ↓
                                               fMRI Prediction (N, S, T, V)
```

- **`MultiSubjectConvLinearEncoder`**: Projects each feature stream through a `LinearConv` (depthwise conv + linear) to a shared embedding dimension `d = 192`.
- Global average pooling across the feature stack combines multi-model features.
- A shared linear decoder and per-subject linear decoders combine additively to predict `V` voxels for each of `S` subjects.
- Trained with MSE loss against ground-truth BOLD.
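
The Stage 1 forward pass described above can be sketched in NumPy as follows. This is not the `medarc_architecture.py` implementation (which is PyTorch): `depthwise_conv1d`, `stage1_forward`, `mse_loss`, and the `params` layout are hypothetical names, no nonlinearity is assumed, and "global average pooling across the feature stack" is read as a plain mean over the per-stream embeddings.

```python
import numpy as np


def depthwise_conv1d(x, kernel):
    """'Same'-padded depthwise conv along time (cross-correlation, as in PyTorch).

    x: (N, T, D) features; kernel: (K, D), one temporal filter per channel.
    """
    T, K = x.shape[1], kernel.shape[0]
    left = K // 2
    xp = np.pad(x, ((0, 0), (left, K - 1 - left), (0, 0)))
    out = np.zeros(x.shape)
    for k in range(K):
        out += xp[:, k:k + T, :] * kernel[k]  # kernel[k] broadcasts over (N, T)
    return out


def stage1_forward(features, params):
    """Hypothetical Mean Anchor forward pass; returns predictions of shape (N, S, T, V)."""
    embeds = [
        depthwise_conv1d(x, kernel) @ W + b           # LinearConv: D_i -> d
        for x, (kernel, W, b) in zip(features, params["streams"])
    ]
    z = np.mean(np.stack(embeds), axis=0)             # average over streams: (N, T, d)
    shared = z @ params["W_shared"]                   # shared decoder: (N, T, V)
    per_subject = np.einsum("ntd,sdv->nstv", z, params["W_subj"])  # (N, S, T, V)
    return shared[:, None] + per_subject              # additive combination


def mse_loss(pred, target):
    return float(np.mean((pred - target) ** 2))


# Usage with random weights; the feature widths D_i are made up for illustration.
rng = np.random.default_rng(0)
N, T, d, V, S, K = 2, 10, 192, 1000, 4, 45
dims = [64, 32]
features = [rng.standard_normal((N, T, D)) for D in dims]
params = {
    "streams": [(0.1 * rng.standard_normal((K, D)),
                 0.01 * rng.standard_normal((D, d)),
                 np.zeros(d)) for D in dims],
    "W_shared": 0.01 * rng.standard_normal((d, V)),
    "W_subj": 0.01 * rng.standard_normal((S, d, V)),
}
pred = stage1_forward(features, params)
print(pred.shape)  # (2, 4, 10, 1000)
```

Summing the shared and per-subject decoder outputs lets every subject reuse the common mapping while the subject heads only need to model individual deviations.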

### Stage 2: Neural Vector Field (`matcha_architecture.py`)

```
x1 (B, V, T)  →  proj_in (V → d)  →  latent x1 (B, d, T)
mu (B, V, T)  →  proj_in (V → d)  →  latent mu (B, d, T)
                                              ↓
                              OT-CFM loss (vector field u_t)
                              Matcha-TTS U-Net estimator
                              [Conformer / Transformer blocks]
                                              ↓
                              latent pred (B, d, T)  →  proj_out (d → V)  →  fMRI (B, V, T)
```
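
A minimal NumPy sketch of the OT-CFM training target in latent space, assuming Stage 2 follows the Matcha-TTS CFM loss (noise-to-data path with `sigma_min`, MSE on the velocity). `proj`, `cfm_training_pair`, and `cfm_loss` are hypothetical helpers, `W_in` stands in for `proj_in`, and `estimator` stands in for the U-Net `Decoder` as any callable `(y, mu, t) -> velocity`:

```python
import numpy as np

SIGMA_MIN = 1e-4  # CFM sigma_min from the hyperparameter table


def proj(W, x):
    """Linear map over the channel axis: W (out, in), x (B, in, T) -> (B, out, T)."""
    return np.einsum("oi,bit->bot", W, x)


def cfm_training_pair(x1, rng, sigma_min=SIGMA_MIN):
    """Sample t and noise x0, then build the path point y and target velocity u."""
    t = rng.uniform(size=(x1.shape[0], 1, 1))       # one time per batch element
    x0 = rng.standard_normal(x1.shape)              # x0 ~ N(0, I)
    y = (1 - (1 - sigma_min) * t) * x0 + t * x1     # point on the conditional path
    u = x1 - (1 - sigma_min) * x0                   # its time derivative
    return t, y, u


def cfm_loss(estimator, x1, mu, rng, sigma_min=SIGMA_MIN):
    """MSE between predicted and target velocities on latent tensors (B, d, T)."""
    t, y, u = cfm_training_pair(x1, rng, sigma_min)
    v = estimator(y, mu, t)
    return float(np.mean((v - u) ** 2))


def zero_field(y, mu, t):
    """Placeholder estimator that always predicts zero velocity."""
    return np.zeros_like(y)


# Usage: project voxels into the latent space, then score the placeholder estimator.
rng = np.random.default_rng(0)
B, V, d, T = 2, 1000, 128, 16
W_in = rng.standard_normal((d, V)) / np.sqrt(V)     # stands in for proj_in
x1_lat = proj(W_in, rng.standard_normal((B, V, T)))
mu_lat = proj(W_in, rng.standard_normal((B, V, T)))
print(cfm_loss(zero_field, x1_lat, mu_lat, rng))
```

Because the target velocity is available in closed form for every sampled `t`, training needs no ODE solves; integration only happens at inference.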

- **Latent Bottleneck**: fMRI voxels (`V ≈ 1000`) are projected down to a dense latent dimension (`d = 128`) before any convolution, reducing the first-layer parameter count from ~6M to ~98K and preventing gradient collapse.
- **`CFM`** wraps a Matcha-TTS style **1D U-Net** (`Decoder`) with ResNet-1D blocks and Conformer/Transformer attention at each scale.
- At inference, a noise sample $z \sim \mathcal{N}(0,I)$ is transported from $t=0$ to $t=1$ by integrating the learned vector field over 25 Euler steps, conditioned on $\mu$.
- An auxiliary reconstruction loss (weight `0.1`) on `proj_in → proj_out` trains the projection pair jointly with the vector field.
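
Inference and the auxiliary loss can be sketched the same way. The fixed-step Euler loop below follows the behavior described above (25 steps from $t=0$ to $t=1$); `euler_sample`, `stage2_predict`, and `total_loss` are hypothetical names, the estimator is again any callable `(x, mu, t) -> velocity`, and treating the reconstruction term as an MSE is an assumption.

```python
import numpy as np


def proj(W, x):
    """Linear map over the channel axis: W (out, in), x (B, in, T) -> (B, out, T)."""
    return np.einsum("oi,bit->bot", W, x)


def euler_sample(estimator, mu, n_timesteps=25, rng=None):
    """Integrate dx/dt = v(x, mu, t) from noise at t=0 to t=1 with fixed Euler steps."""
    if rng is None:
        rng = np.random.default_rng()
    x = rng.standard_normal(mu.shape)               # z ~ N(0, I) in latent shape
    ts = np.linspace(0.0, 1.0, n_timesteps + 1)
    for t0, t1 in zip(ts[:-1], ts[1:]):
        x = x + (t1 - t0) * estimator(x, mu, t0)
    return x


def stage2_predict(estimator, W_in, W_out, mu_voxels, n_timesteps=25, rng=None):
    """Mean Anchor (B, V, T) -> latent -> sampled latent -> fMRI (B, V, T)."""
    mu_lat = proj(W_in, mu_voxels)
    x_lat = euler_sample(estimator, mu_lat, n_timesteps, rng)
    return proj(W_out, x_lat)


def total_loss(cfm_term, W_in, W_out, x1_voxels, recon_weight=0.1):
    """CFM loss plus the weighted proj_in -> proj_out reconstruction term."""
    recon = np.mean((proj(W_out, proj(W_in, x1_voxels)) - x1_voxels) ** 2)
    return cfm_term + recon_weight * float(recon)


def toward_mu(x, mu, t):
    """Toy field that moves each sample in a straight line onto mu by t=1."""
    return (mu - x) / (1.0 - t)  # t < 1 at every Euler step, so no division by zero


# Usage with identity projections, for illustration only.
rng = np.random.default_rng(0)
B, V, d, T = 2, 6, 6, 5
W_in, W_out = np.eye(d, V), np.eye(V, d)
mu_voxels = rng.standard_normal((B, V, T))
pred = stage2_predict(toward_mu, W_in, W_out, mu_voxels, rng=rng)
print(np.allclose(pred, mu_voxels))  # True
```

A learned estimator replaces `toward_mu` in practice; the toy field only shows that the loop runs from noise at `t=0` to the endpoint at `t=1`.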

---

## Data

| Source | Content | Usage |
|---|---|---|
| **Friends** (seasons 1–7) | fMRI BOLD + multimodal features | Train (S1), Val (S6, S7) |
| **Movie10** | fMRI BOLD + multimodal features | Supplementary val (Figures, Life, Bourne, Wolf) |

Subjects used: **1, 2, 3, 5**.

### Feature Models

The encoder can ingest intermediate activations from any combination of:

| Key | Model |
|---|---|
| `internvl3_8b`, `internvl3_14b` | InternVL3 vision-language model |
| `qwen-2-5-omni-3b`, `qwen-2-5-omni-7b` | Qwen2.5-Omni audio-video model |
| `whisper` | OpenAI Whisper (audio) |
| `llama_3.2_1b`, `llama_3.2_3b` | LLaMA 3.2 (text) |
| `vjepa2` | V-JEPA 2 (video) |

Active features are set in `config.yml` under `include_features`.

---

## File Structure

```
flow_matching/
├── config.yml               # Main training config (GPU, full data)
├── debug_config.yml         # Fast local debug config (CPU, tiny data)
├── environment.yml          # Conda environment spec
│
├── src/
│   ├── training.py          # Two-stage training loop + evaluation
│   ├── matcha_architecture.py  # CFM + Matcha-TTS U-Net decoder
│   ├── medarc_architecture.py  # Stage 1 MultiSubjectConvLinearEncoder
│   ├── data.py              # Algonauts2025 dataset + loaders
│   ├── metric.py            # Pearson's r voxel-wise scoring
│   ├── visualize.py         # Loss curve plotting
│   └── inference.py         # Standalone inference helper
│
├── test/
│   ├── overfit_test.py      # Tiny-batch overfit sanity check for Stage 2
│   ├── check_pearson.py     # Load checkpoints and plot per-voxel Pearson's r heatmaps
│   └── debug_training.py    # End-to-end smoke test
│
├── experiments/
│   └── *.ipynb              # Analysis notebooks (RSA, OOD, brain region plots)
│
```

---

## Training

### Full training (server)

```bash
cd flow_matching
python src/training.py --cfg-path config.yml
```

Checkpoints are written to `output/two_stage_encoding/`:
- `stage1_best.pt`: best Stage 1 model by validation Pearson's r
- `stage2_epoch_N.pt`: Stage 2 snapshot every 5 epochs

### Local debug (CPU, tiny model)

```bash
python src/training.py --cfg-path debug_config.yml
```

---

## Evaluation

### Pearson's r heatmaps

Loads all available Stage 1 and Stage 2 checkpoints, evaluates on the configured validation set, and saves per-subject per-voxel Pearson's r heatmaps to `output/two_stage_encoding/heatmaps/`.

```bash
python test/check_pearson.py
```

**Output per checkpoint:**
```
Stage 1 Overall Pearson's r: 0.1832
Stage 1 - Sub 1 Mean Pearson's r: 0.1754
Stage 2 Epoch 5 Overall Pearson's r: 0.2110
Stage 2 Epoch 5 - Sub 1 Mean Pearson's r: 0.2043
...
```

### Tiny-batch overfit test

Confirms that Stage 2 can memorize a single training batch. If the loss does not approach `0` within 500 steps, something in the architecture or training loop is preventing it from fitting even a single batch.

```bash
python test/overfit_test.py --cfg-path config.yml --subject-idx 0 --steps 500
```

---

## Key Hyperparameters

| Parameter | Value | Location |
|---|---|---|
| Stage 1 embed dim | 192 | `config.yml / stage1.model.embed_dim` |
| Stage 1 encoder kernel | 45 | `config.yml / stage1.model.encoder_kernel_size` |
| Stage 1 LR | 3e-4 | `config.yml / stage1.lr` |
| Stage 2 latent dim | 128 | `config.yml / stage2.latent_dim` |
| Stage 2 U-Net channels | [256, 256] | `config.yml / stage2.decoder.channels` |
| Stage 2 block type | Conformer | `config.yml / stage2.decoder.*_block_type` |
| Stage 2 LR | 3e-4 | `config.yml / stage2.lr` |
| Euler integration steps | 25 | `config.yml / stage2.n_timesteps` |
| CFM σ_min | 1e-4 | `config.yml / stage2.cfm.sigma_min` |

---

## Metric

Evaluation uses **voxel-wise Pearson's r** averaged across subjects:

$$r_v = \frac{\sum_t (y_v^t - \bar{y}_v)(\hat{y}_v^t - \bar{\hat{y}}_v)}{\sqrt{\sum_t (y_v^t - \bar{y}_v)^2}\,\sqrt{\sum_t (\hat{y}_v^t - \bar{\hat{y}}_v)^2}}$$

The scalar reported is the mean over all `V` voxels and all `S` subjects.
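
The metric can be computed in a few lines of NumPy. This is a sketch of the formula above, not the contents of `metric.py`; the function names and the small `eps` in the denominator (to avoid dividing by zero for constant voxels) are assumptions.

```python
import numpy as np


def voxelwise_pearson(y, y_hat, eps=1e-8):
    """Pearson's r for each voxel. y, y_hat: (T, V) time x voxel. Returns (V,)."""
    yc = y - y.mean(axis=0)            # center each voxel's time series
    hc = y_hat - y_hat.mean(axis=0)
    num = (yc * hc).sum(axis=0)
    den = np.linalg.norm(yc, axis=0) * np.linalg.norm(hc, axis=0)
    return num / (den + eps)


def mean_pearson(ys, y_hats):
    """Mean r_v over all voxels and subjects; ys / y_hats are lists of (T, V) arrays."""
    return float(np.mean([voxelwise_pearson(y, h) for y, h in zip(ys, y_hats)]))


rng = np.random.default_rng(0)
y = rng.standard_normal((200, 3))
print(voxelwise_pearson(y, 2 * y + 1))  # approximately [1. 1. 1.]
```

Because every subject shares the same `V`, averaging the stacked per-subject vectors gives the same value as averaging per-subject means.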