# Flow Matching fMRI Encoder
A two-stage architecture for predicting fMRI BOLD responses from naturalistic video stimuli (Friends TV show + Movie10 dataset) using Conditional Flow Matching.
---
## Overview
The pipeline predicts brain activity from stimulus features in two sequential stages:
| Stage | Model | Goal |
|---|---|---|
| **1** | `MultiSubjectConvLinearEncoder` | Predict a **Mean Anchor**: a deterministic per-voxel fMRI estimate for each subject, produced by an encoder shared across subjects |
| **2** | `CFM` (Conditional Flow Matching) | Learn a **per-subject neural vector field** that refines the Mean Anchor into a sharper, stochastic fMRI prediction |
The design mirrors the MedARC approach: Stage 1 provides a stable conditional mean $\mu$; Stage 2 integrates a continuous normalizing flow $\phi_t$ conditioned on $\mu$ to draw samples from the learned conditional distribution over voxel activations, rather than predicting only its mean.
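Written out, and assuming the standard conditional-flow formulation used by Matcha-TTS (whose `CFM` Stage 2 vendors), sampling solves an ODE in the latent space that starts from Gaussian noise. Here $v_\theta$ stands for the U-Net estimator and $\mu$ for the projected (latent) anchor:

$$\frac{d\phi_t}{dt} = v_\theta(\phi_t, \mu, t), \qquad \phi_0 = z \sim \mathcal{N}(0, I), \qquad \hat{\mathbf{y}} = \mathrm{proj\_out}(\phi_1)$$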
---
## Architecture
### Stage 1: Mean Anchor (`medarc_architecture.py`)
```
Features (N, T, D_i) → DepthConv1d + Linear → embed (N, T, d)
↓
[Shared Decoder] + [Subject Decoders]
↓
fMRI Prediction (N, S, T, V)
```
- **`MultiSubjectConvLinearEncoder`**: Projects each feature stream through a `LinearConv` (depthwise conv + linear) to a shared embedding dimension `d = 192`.
- Global average pooling across the feature stack combines multi-model features.
- A shared linear decoder and per-subject linear decoders combine additively to predict `V` voxels for each of `S` subjects.
- Trained with MSE loss against ground-truth BOLD.
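The Stage 1 forward pass described above can be sketched in NumPy. This is a minimal sketch under stated assumptions: the depthwise convolution uses "same" zero padding and runs before the linear projection, pooling is a plain mean over feature streams, and both decoders are linear maps with biases. The function names and the `params` layout are hypothetical and are not the API of `medarc_architecture.py`.

```python
import numpy as np

def depthwise_conv1d(x, kernel):
    """Depthwise 1D convolution over time with 'same' zero padding.

    x: (N, T, D) features; kernel: (K, D), one temporal filter per channel, K odd.
    """
    K, T = kernel.shape[0], x.shape[1]
    pad = K // 2
    xp = np.pad(x, ((0, 0), (pad, pad), (0, 0)))
    out = np.zeros(x.shape, dtype=float)
    for k in range(K):
        out += xp[:, k:k + T, :] * kernel[k]   # kernel[k] broadcasts over (N, T)
    return out

def stage1_forward(features, params):
    """features: list of (N, T, D_i) arrays, one per feature model.

    Returns predictions of shape (N, S, T, V).
    """
    embeds = []
    for x, (kernel, W, b) in zip(features, params["streams"]):
        h = depthwise_conv1d(x, kernel)          # (N, T, D_i)
        embeds.append(h @ W + b)                 # project to shared dim d
    z = np.mean(embeds, axis=0)                  # average over feature streams -> (N, T, d)
    shared = z @ params["shared_W"] + params["shared_b"]          # (N, T, V)
    per_subj = np.einsum("ntd,sdv->nstv", z, params["subj_W"])    # (N, S, T, V)
    per_subj += params["subj_b"][None, :, None, :]
    return shared[:, None] + per_subj            # shared and subject decoders add up

# Usage with random weights (hypothetical sizes; d=192 and K=45 follow config.yml).
rng = np.random.default_rng(0)
N, T, d, V, S, K = 2, 50, 192, 1000, 4, 45
dims = [64, 32]                                  # two hypothetical feature streams
features = [rng.standard_normal((N, T, D)) for D in dims]
params = {
    "streams": [(rng.standard_normal((K, D)) * 0.01,
                 rng.standard_normal((D, d)) * 0.01,
                 np.zeros(d)) for D in dims],
    "shared_W": rng.standard_normal((d, V)) * 0.01, "shared_b": np.zeros(V),
    "subj_W": rng.standard_normal((S, d, V)) * 0.01, "subj_b": np.zeros((S, V)),
}
pred = stage1_forward(features, params)
print(pred.shape)  # (2, 4, 50, 1000)
```

Keeping the subject decoders additive on top of a shared one lets most of the mapping be learned from all subjects' data, while each subject only contributes a residual correction.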
### Stage 2: Neural Vector Field (`matcha_architecture.py`)
```
x1 (B, V, T) → proj_in (V → d) → latent x1 (B, d, T)
mu (B, V, T) → proj_in (V → d) → latent mu (B, d, T)
↓
OT-CFM loss (vector field u_t)
Matcha-TTS U-Net estimator
[Conformer / Transformer blocks]
↓
latent pred (B, d, T) → proj_out (d → V) → fMRI (B, V, T)
```
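The OT-CFM objective in the diagram and the fixed-step Euler sampler can be sketched in NumPy. The sketch assumes the formulation of Matcha-TTS's `CFM`: the interpolation $x_t = (1 - (1-\sigma_{\min})t)\,x_0 + t\,x_1$ with target $u_t = x_1 - (1-\sigma_{\min})\,x_0$, and plain Euler steps at inference. `estimator` stands in for the U-Net, `proj_in`/`proj_out` are modeled as bias-free matrices `W_in`/`W_out`, and the auxiliary term is assumed to reconstruct the BOLD target. All function names are hypothetical.

```python
import numpy as np

SIGMA_MIN = 1e-4  # matches config.yml / stage2.cfm.sigma_min

def ot_cfm_pair(x1, rng, sigma_min=SIGMA_MIN):
    """Sample a training point on the OT-CFM path (Matcha-TTS style).

    x1: (B, d, T) latent targets. Returns t, the noisy point x_t and target u_t.
    """
    t = rng.random((x1.shape[0], 1, 1))            # one t ~ U[0, 1) per sample
    x0 = rng.standard_normal(x1.shape)             # Gaussian noise
    xt = (1 - (1 - sigma_min) * t) * x0 + t * x1   # straight-line interpolation
    ut = x1 - (1 - sigma_min) * x0                 # target velocity, constant in t
    return t, xt, ut

def cfm_loss(estimator, x1, mu, rng):
    """MSE between the estimated field v(x_t, mu, t) and the OT target u_t."""
    t, xt, ut = ot_cfm_pair(x1, rng)
    return float(np.mean((estimator(xt, mu, t) - ut) ** 2))

def stage2_loss(estimator, W_in, W_out, y, mu, rng, aux_weight=0.1):
    """y, mu: (B, V, T) BOLD target and Stage 1 anchor; W_in: (d, V), W_out: (V, d)."""
    x1 = np.einsum("dv,bvt->bdt", W_in, y)         # proj_in on the target
    mu_lat = np.einsum("dv,bvt->bdt", W_in, mu)    # same proj_in on the anchor
    recon = np.einsum("vd,bdt->bvt", W_out, x1)    # proj_out back to voxels
    aux = float(np.mean((recon - y) ** 2))         # auxiliary reconstruction term
    return cfm_loss(estimator, x1, mu_lat, rng) + aux_weight * aux

def euler_sample(estimator, mu, rng, n_steps=25):
    """Integrate dx/dt = v(x, mu, t) from noise at t=0 to t=1 with fixed Euler steps."""
    x = rng.standard_normal(mu.shape)
    ts = np.linspace(0.0, 1.0, n_steps + 1)
    for t0, t1 in zip(ts[:-1], ts[1:]):
        t = np.full((mu.shape[0], 1, 1), t0)
        x = x + (t1 - t0) * estimator(x, mu, t)
    return x

# Toy check with an oracle field whose target is exactly mu (not a trained model).
rng = np.random.default_rng(0)
mu = rng.standard_normal((2, 128, 30))             # (B, d, T) latent anchor
oracle = lambda x, m, t: (m - (1 - SIGMA_MIN) * x) / (1 - (1 - SIGMA_MIN) * t)
print(cfm_loss(oracle, mu, mu, rng))               # near zero: the oracle equals u_t
print(np.abs(euler_sample(oracle, mu, rng) - mu).max())  # small, about sigma_min * |noise|
```

For a single deterministic target the OT path is a straight line, so Euler integration with the oracle field is exact and ends at $x_1 + \sigma_{\min} x_0$; a trained estimator only approximates this field, which is why the number of steps matters in practice.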
- **Latent Bottleneck**: fMRI voxels (`V ≈ 1000`) are projected down to a dense latent dimension (`d = 128`) before any convolution, reducing the first-layer parameter count from ~6M to ~98K and preventing gradient collapse.
- **`CFM`** wraps a Matcha-TTS style **1D U-Net** (`Decoder`) with ResNet-1D blocks and Conformer/Transformer attention at each scale.
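- At inference, noise $z \sim \mathcal{N}(0,I)$ is transported to a latent fMRI sample by integrating the learned vector field from $t=0$ to $t=1$ in 25 Euler steps, conditioned on $\mu$; `proj_out` then maps the result back to voxel space.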
- An auxiliary reconstruction loss (weight `0.1`) on `proj_in β†’ proj_out` trains the projection pair jointly with the vector field.
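The ~6M and ~98K figures for the latent bottleneck are consistent with one specific layer: a bias-free, kernel-size-3 `Conv1d` that takes the channel-wise concatenation of $x_t$ and $\mu$ ($2C$ channels) and outputs $C$ channels, with $C = V$ without the bottleneck and $C = d$ with it. That layer shape is an assumption chosen because it reproduces both numbers, not something confirmed from `matcha_architecture.py`; the helper names are hypothetical.

```python
def conv1d_params(c_in, c_out, kernel_size, bias=True):
    """Parameter count of a 1D convolution layer (weights plus optional bias)."""
    return c_in * c_out * kernel_size + (c_out if bias else 0)

def linear_params(n_in, n_out, bias=True):
    """Parameter count of a dense layer (weights plus optional bias)."""
    return n_in * n_out + (n_out if bias else 0)

V, d = 1000, 128
# Assumed first U-Net conv: [x_t, mu] concatenated on channels (2C in), C out, kernel 3.
without_bottleneck = conv1d_params(2 * V, V, 3, bias=False)  # 6,000,000
with_bottleneck = conv1d_params(2 * d, d, 3, bias=False)     # 98,304
# Cost of the projection pair itself (assuming biased linear layers).
projection_cost = linear_params(V, d) + linear_params(d, V)  # 257,128
print(without_bottleneck, with_bottleneck, projection_cost)  # 6000000 98304 257128
```

Under this assumption the two projections add about 257K parameters, so even counting them the input side of the network shrinks by more than an order of magnitude.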
---
## Data
| Source | Content | Usage |
|---|---|---|
| **Friends** (seasons 1–7) | fMRI BOLD + multimodal features | Train (S1), Val (S6, S7) |
| **Movie10** | fMRI BOLD + multimodal features | Supplementary val (Figures, Life, Bourne, Wolf) |
Subjects used: **1, 2, 3, 5**.
### Feature Models
The encoder can ingest intermediate activations from any combination of:
| Key | Model |
|---|---|
| `internvl3_8b`, `internvl3_14b` | InternVL3 vision-language model |
| `qwen-2-5-omni-3b`, `qwen-2-5-omni-7b` | Qwen2.5-Omni audio-video model |
| `whisper` | OpenAI Whisper (audio) |
| `llama_3.2_1b`, `llama_3.2_3b` | LLaMA 3.2 (text) |
| `vjepa2` | V-JEPA 2 (video) |
Active features are set in `config.yml` under `include_features`.
---
## File Structure
```
flow_matching/
├── config.yml                    # Main training config (GPU, full data)
├── debug_config.yml              # Fast local debug config (CPU, tiny data)
├── environment.yml               # Conda environment spec
│
├── src/
│   ├── training.py               # Two-stage training loop + evaluation
│   ├── matcha_architecture.py    # CFM + Matcha-TTS U-Net decoder
│   ├── medarc_architecture.py    # Stage 1 MultiSubjectConvLinearEncoder
│   ├── data.py                   # Algonauts2025 dataset + loaders
│   ├── metric.py                 # Pearson's r voxel-wise scoring
│   ├── visualize.py              # Loss curve plotting
│   └── inference.py              # Standalone inference helper
│
├── test/
│   ├── overfit_test.py           # Tiny-batch overfit sanity check for Stage 2
│   ├── check_pearson.py          # Load checkpoints and plot per-voxel Pearson's r heatmaps
│   └── debug_training.py         # End-to-end smoke test
│
├── experiments/
│   └── *.ipynb                   # Analysis notebooks (RSA, OOD, brain region plots)
│
└── Matcha-TTS/                   # Vendored Matcha-TTS source (U-Net + solver)
```
---
## Training
### Full training (server)
```bash
cd flow_matching
python src/training.py --cfg-path config.yml
```
Checkpoints are written to `output/two_stage_encoding/`:
- `stage1_best.pt`: best Stage 1 model by validation Pearson's r
- `stage2_epoch_N.pt`: Stage 2 snapshot every 5 epochs
### Local debug (CPU, tiny model)
```bash
python src/training.py --cfg-path debug_config.yml
```
---
## Evaluation
### Pearson's r heatmaps
Loads all available Stage 1 and Stage 2 checkpoints, evaluates on the configured validation set, and saves per-subject per-voxel Pearson's r heatmaps to `output/two_stage_encoding/heatmaps/`.
```bash
python test/check_pearson.py
```
**Output per checkpoint:**
```
Stage 1 Overall Pearson's r: 0.1832
Stage 1 - Sub 1 Mean Pearson's r: 0.1754
Stage 2 Epoch 5 Overall Pearson's r: 0.2110
Stage 2 Epoch 5 - Sub 1 Mean Pearson's r: 0.2043
...
```
### Tiny-batch overfit test
Checks that Stage 2 can memorize a single training batch. If the loss does not approach `0` within 500 steps, the model or training setup is likely broken, since it cannot fit even one batch.
```bash
python test/overfit_test.py --cfg-path config.yml --subject-idx 0 --steps 500
```
---
## Key Hyperparameters
| Parameter | Value | Location |
|---|---|---|
| Stage 1 embed dim | 192 | `config.yml / stage1.model.embed_dim` |
| Stage 1 encoder kernel | 45 | `config.yml / stage1.model.encoder_kernel_size` |
| Stage 1 LR | 3e-4 | `config.yml / stage1.lr` |
| Stage 2 latent dim | 128 | `config.yml / stage2.latent_dim` |
| Stage 2 U-Net channels | [256, 256] | `config.yml / stage2.decoder.channels` |
| Stage 2 block type | Conformer | `config.yml / stage2.decoder.*_block_type` |
| Stage 2 LR | 3e-4 | `config.yml / stage2.lr` |
| Euler integration steps | 25 | `config.yml / stage2.n_timesteps` |
| CFM σ_min | 1e-4 | `config.yml / stage2.cfm.sigma_min` |
---
## Metric
Evaluation uses **voxel-wise Pearson's r** averaged across subjects:
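$$r_v = \frac{\sum_t (y_v^t - \bar{y}_v)(\hat{y}_v^t - \bar{\hat{y}}_v)}{\sqrt{\sum_t (y_v^t - \bar{y}_v)^2}\,\sqrt{\sum_t (\hat{y}_v^t - \bar{\hat{y}}_v)^2}}$$
where $y_v^t$ and $\hat{y}_v^t$ are the measured and predicted BOLD signal of voxel $v$ at time $t$, and $\bar{y}_v$, $\bar{\hat{y}}_v$ are their means over time.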
The scalar reported is the mean over all `V` voxels and all `S` subjects.
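The metric can be sketched in NumPy as follows. It assumes predictions and targets are arranged as `(S, T, V)` with each subject's validation clips concatenated along time; the function names and the `eps` guard against zero-variance voxels are hypothetical and not taken from `metric.py`.

```python
import numpy as np

def voxelwise_pearson(y, y_hat, eps=1e-8):
    """Pearson's r per voxel along the time axis.

    y, y_hat: arrays of shape (..., T, V); returns shape (..., V).
    """
    yc = y - y.mean(axis=-2, keepdims=True)          # centre each voxel's time series
    pc = y_hat - y_hat.mean(axis=-2, keepdims=True)
    num = (yc * pc).sum(axis=-2)
    den = np.sqrt((yc ** 2).sum(axis=-2) * (pc ** 2).sum(axis=-2))
    return num / (den + eps)

def mean_pearson(y, y_hat):
    """Scalar score: mean of r_v over all voxels and subjects; y, y_hat: (S, T, V)."""
    return float(voxelwise_pearson(y, y_hat).mean())

rng = np.random.default_rng(0)
y = rng.standard_normal((4, 200, 1000))              # 4 subjects, 200 TRs, 1000 voxels
y_hat = y + rng.standard_normal(y.shape)             # prediction with unit-variance noise
print(mean_pearson(y, y_hat))                        # roughly 0.7, i.e. about 1/sqrt(2)
```

Centering before normalizing is what makes $r_v$ invariant to per-voxel offsets and scales in the prediction, so a model is not penalized for BOLD baseline or amplitude differences.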