# Flow Matching fMRI Encoder

A two-stage architecture that predicts fMRI BOLD responses to naturalistic video stimuli (the *Friends* TV show and the Movie10 dataset) using Conditional Flow Matching.

---

## Overview

The pipeline predicts brain activity in two sequential stages:

| Stage | Model | Goal |
|---|---|---|
| **1** | `MultiSubjectConvLinearEncoder` | Predict a **Mean Anchor**: a deterministic per-voxel fMRI estimate for each subject, produced by an encoder shared across subjects |
| **2** | `CFM` (Conditional Flow Matching) | Learn a **per-subject neural vector field** that refines the Mean Anchor into a sharper, stochastic fMRI prediction |

The design mirrors the MedARC approach: Stage 1 provides a stable conditional mean $\mu$; Stage 2 integrates a continuous normalizing flow $\phi_t$ conditioned on $\mu$ to draw samples from the learned conditional distribution of voxel activations given $\mu$.
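
For reference, this is the optimal-transport CFM objective as used in Matcha-TTS. We assume the vendored `CFM` keeps it unchanged; here $x_1$ and $\mu$ would be the latents produced by `proj_in`, $z$ is Gaussian noise, and $v_\theta$ is the U-Net estimator:

$$
\begin{aligned}
\phi_t(z) &= \bigl(1 - (1 - \sigma_{\min})\,t\bigr)\,z + t\,x_1, \qquad z \sim \mathcal{N}(0, I),\; t \sim \mathcal{U}[0, 1],\\
u_t &= \frac{d}{dt}\,\phi_t(z) = x_1 - (1 - \sigma_{\min})\,z,\\
\mathcal{L}_{\mathrm{CFM}} &= \mathbb{E}_{t,\,z,\,x_1}\bigl\lVert v_\theta\bigl(\phi_t(z), \mu, t\bigr) - u_t \bigr\rVert^2 .
\end{aligned}
$$

At $t = 1$ the path ends at $\sigma_{\min} z + x_1$, so with $\sigma_{\min} = 10^{-4}$ the integrated sample lands essentially on the target.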

---

## Architecture

### Stage 1: Mean Anchor (`medarc_architecture.py`)

```
Features (N, T, D_i) → DepthConv1d + Linear → embed (N, T, d)
                              ↓
             [Shared Decoder] + [Subject Decoders]
                              ↓
                fMRI Prediction (N, S, T, V)
```

- **`MultiSubjectConvLinearEncoder`**: projects each feature stream through a `LinearConv` (depthwise conv + linear) to a shared embedding dimension `d = 192`.
- Averaging over the stack of projected feature streams (global average pooling) fuses the features from the different models.
- A shared linear decoder and per-subject linear decoders combine additively to predict `V` voxels for each of `S` subjects.
- Trained with MSE loss against ground-truth BOLD.
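
The Stage 1 forward pass described above can be sketched in NumPy. This is a hypothetical illustration, not the code in `medarc_architecture.py`: the function names, the random initialization, the absence of biases and normalization, and applying the depthwise conv before the linear projection are all assumptions. Sizes are shrunk from `d = 192` and kernel 45 so the example runs instantly.

```python
import numpy as np


def depthwise_conv1d(x, kernels):
    """'Same'-padded depthwise 1D convolution along the time axis.

    x: (N, T, D) features; kernels: (D, K), one odd-length kernel per channel.
    """
    t = x.shape[1]
    k = kernels.shape[1]
    pad = k // 2
    xp = np.pad(x, ((0, 0), (pad, pad), (0, 0)))
    out = np.zeros_like(x, dtype=float)
    for j in range(k):
        out += xp[:, j:j + t, :] * kernels[:, j]  # tap j, broadcast over (N, T)
    return out


def init_params(feature_dims, d, v, s, kernel_size, rng):
    """Random weights for a hypothetical Stage 1 encoder (biases omitted)."""
    streams = {
        name: {
            "conv": rng.standard_normal((dim, kernel_size)) / kernel_size,
            "proj": rng.standard_normal((dim, d)) / np.sqrt(dim),
        }
        for name, dim in feature_dims.items()
    }
    return {
        "streams": streams,
        "shared_dec": rng.standard_normal((d, v)) / np.sqrt(d),      # shared across subjects
        "subject_dec": rng.standard_normal((s, d, v)) / np.sqrt(d),  # one decoder per subject
    }


def stage1_forward(features, params):
    """features: {name: (N, T, D_i)} -> Mean Anchor of shape (N, S, T, V)."""
    embeds = []
    for name, x in features.items():
        p = params["streams"][name]
        embeds.append(depthwise_conv1d(x, p["conv"]) @ p["proj"])   # (N, T, d)
    embed = np.mean(np.stack(embeds), axis=0)                        # average over streams
    shared = embed @ params["shared_dec"]                            # (N, T, V)
    subject = np.einsum("ntd,sdv->nstv", embed, params["subject_dec"])  # (N, S, T, V)
    return shared[:, None] + subject                                 # additive combination


def mse_loss(pred, target):
    return float(np.mean((pred - target) ** 2))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    feats = {"whisper": rng.standard_normal((2, 30, 16)),
             "vjepa2": rng.standard_normal((2, 30, 24))}
    dims = {name: x.shape[2] for name, x in feats.items()}
    params = init_params(dims, d=8, v=10, s=4, kernel_size=5, rng=rng)
    mu = stage1_forward(feats, params)
    print(mu.shape)  # (2, 4, 30, 10)
```

Because the shared decoder output is broadcast over the subject axis before the per-subject term is added, every subject's prediction is "common response + subject-specific correction", which is what makes the anchor usable across subjects.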

### Stage 2: Neural Vector Field (`matcha_architecture.py`)

```
x1 (B, V, T) → proj_in (V → d) → latent x1 (B, d, T)
mu (B, V, T) → proj_in (V → d) → latent mu (B, d, T)
                        ↓
         OT-CFM loss (vector field u_t)
         Matcha-TTS U-Net estimator
         [Conformer / Transformer blocks]
                        ↓
latent pred (B, d, T) → proj_out (d → V) → fMRI (B, V, T)
```

- **Latent bottleneck**: fMRI voxels (`V ≈ 1000`) are projected down to a dense latent dimension (`d = 128`) before any convolution, which reduces the first-layer parameter count from ~6M to ~98K and prevents gradient collapse.
- **`CFM`** wraps a Matcha-TTS-style **1D U-Net** (`Decoder`) with ResNet-1D blocks and Conformer/Transformer attention at each scale.
- At inference, latent noise $z \sim \mathcal{N}(0, I)$ is integrated along the learned vector field from $t=0$ to $t=1$ in 25 Euler steps, conditioned on the latent $\mu$; `proj_out` then maps the result back to voxel space.
- An auxiliary reconstruction loss (weight `0.1`) between `proj_out(proj_in(x1))` and `x1` trains the projection pair jointly with the vector field.
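
The Stage 2 objective and sampler described in the list above can be sketched in NumPy. This is a hypothetical illustration of the standard OT-CFM recipe that Matcha-TTS uses, not the code in `matcha_architecture.py`: `estimator` is any callable standing in for the U-Net, `w_in`/`w_out` are plain matrices standing in for `proj_in`/`proj_out`, masks are omitted, and no gradients are computed, so it shows what is evaluated rather than how it is trained.

```python
import numpy as np

SIGMA_MIN = 1e-4    # stage2.cfm.sigma_min
RECON_WEIGHT = 0.1  # weight of the auxiliary proj_in -> proj_out reconstruction loss


def project(w, x):
    """Linear map over the channel axis: w (C_out, C_in), x (B, C_in, T) -> (B, C_out, T)."""
    return np.einsum("oc,bct->bot", w, x)


def stage2_loss(x1, mu, w_in, w_out, estimator, rng, sigma_min=SIGMA_MIN):
    """OT-CFM loss in latent space plus the auxiliary reconstruction loss.

    x1, mu: (B, V, T) target BOLD and Stage 1 Mean Anchor.
    w_in: (d, V) stand-in for proj_in; w_out: (V, d) stand-in for proj_out.
    estimator(y, mu_latent, t) -> (B, d, T) stands in for the U-Net.
    """
    z1, z_mu = project(w_in, x1), project(w_in, mu)       # latents (B, d, T)
    t = rng.random((x1.shape[0], 1, 1))                   # one t ~ U[0, 1] per sample
    z = rng.standard_normal(z1.shape)                     # noise z ~ N(0, I)
    y = (1 - (1 - sigma_min) * t) * z + t * z1            # point phi_t(z) on the path
    u = z1 - (1 - sigma_min) * z                          # target vector field u_t
    flow_loss = np.mean((estimator(y, z_mu, t) - u) ** 2)
    recon_loss = np.mean((project(w_out, z1) - x1) ** 2)  # proj_out(proj_in(x1)) vs x1
    return float(flow_loss + RECON_WEIGHT * recon_loss)


def sample_euler(mu, w_in, w_out, estimator, rng, n_timesteps=25):
    """Integrate latent noise from t=0 to t=1 with explicit Euler steps, then decode."""
    z_mu = project(w_in, mu)
    y = rng.standard_normal(z_mu.shape)                   # z ~ N(0, I) in latent space
    ts = np.linspace(0.0, 1.0, n_timesteps + 1)
    for t0, t1 in zip(ts[:-1], ts[1:]):
        t = np.full((y.shape[0], 1, 1), t0)
        y = y + (t1 - t0) * estimator(y, z_mu, t)         # y <- y + dt * v(y, mu, t)
    return project(w_out, y)                              # back to voxels: (B, V, T)
```

A field of the form `(mu_latent - y) / (1 - t)` moves any noise sample exactly onto the latent anchor under these Euler steps, which makes a convenient sanity check for a sampler. Note also that the flow loss lives entirely in latent space, so in this sketch only the reconstruction term ties `proj_out` to `proj_in`.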

---

## Data

| Source | Content | Usage |
|---|---|---|
| **Friends** (seasons 1-7) | fMRI BOLD + multimodal features | Train (S1), Val (S6, S7) |
| **Movie10** | fMRI BOLD + multimodal features | Supplementary validation (Figures, Life, Bourne, Wolf) |

Subjects used: **1, 2, 3, 5**.

### Feature Models

The encoder can ingest intermediate activations from any combination of:

| Key | Model |
|---|---|
| `internvl3_8b`, `internvl3_14b` | InternVL3 vision-language model |
| `qwen-2-5-omni-3b`, `qwen-2-5-omni-7b` | Qwen2.5-Omni audio-video model |
| `whisper` | OpenAI Whisper (audio) |
| `llama_3.2_1b`, `llama_3.2_3b` | LLaMA 3.2 (text) |
| `vjepa2` | V-JEPA 2 (video) |

Active features are set in `config.yml` under `include_features`.

---

## File Structure

```
flow_matching/
├── config.yml                 # Main training config (GPU, full data)
├── debug_config.yml           # Fast local debug config (CPU, tiny data)
├── environment.yml            # Conda environment spec
│
├── src/
│   ├── training.py            # Two-stage training loop + evaluation
│   ├── matcha_architecture.py # CFM + Matcha-TTS U-Net decoder
│   ├── medarc_architecture.py # Stage 1 MultiSubjectConvLinearEncoder
│   ├── data.py                # Algonauts2025 dataset + loaders
│   ├── metric.py              # Voxel-wise Pearson's r scoring
│   ├── visualize.py           # Loss curve plotting
│   └── inference.py           # Standalone inference helper
│
├── test/
│   ├── overfit_test.py        # Tiny-batch overfit sanity check for Stage 2
│   ├── check_pearson.py       # Load checkpoints and plot per-voxel Pearson's r heatmaps
│   └── debug_training.py      # End-to-end smoke test
│
├── experiments/
│   └── *.ipynb                # Analysis notebooks (RSA, OOD, brain region plots)
│
└── Matcha-TTS/                # Vendored Matcha-TTS source (U-Net + solver)
```

---

## Training

### Full training (server)

```bash
cd flow_matching
python src/training.py --cfg-path config.yml
```

Checkpoints are written to `output/two_stage_encoding/`:
- `stage1_best.pt`: the best Stage 1 model by validation Pearson's r
- `stage2_epoch_N.pt`: a Stage 2 snapshot, saved every 5 epochs
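
Because plain string sorting places `stage2_epoch_10.pt` before `stage2_epoch_5.pt`, any script that walks these snapshots should order them by the parsed epoch number. A minimal sketch; the helper name `stage2_checkpoints` is hypothetical and not part of the repository:

```python
import re
from pathlib import Path

# Matches the Stage 2 snapshot naming described above, e.g. "stage2_epoch_15.pt".
STAGE2_NAME = re.compile(r"^stage2_epoch_(\d+)\.pt$")


def stage2_checkpoints(ckpt_dir="output/two_stage_encoding"):
    """Return Stage 2 snapshot paths ordered by epoch number, oldest first."""
    found = []
    for path in Path(ckpt_dir).iterdir():
        match = STAGE2_NAME.match(path.name)
        if match:
            found.append((int(match.group(1)), path))
    found.sort(key=lambda item: item[0])  # numeric order, so epoch 5 precedes epoch 10
    return [path for _, path in found]
```

`stage2_checkpoints()[-1]` then gives the most recent snapshot; `stage1_best.pt` is ignored by the pattern.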

### Local debug (CPU, tiny model)

```bash
python src/training.py --cfg-path debug_config.yml
```

---

## Evaluation

### Pearson's r heatmaps

`test/check_pearson.py` loads all available Stage 1 and Stage 2 checkpoints, evaluates them on the configured validation set, and saves per-subject, per-voxel Pearson's r heatmaps to `output/two_stage_encoding/heatmaps/`.

```bash
python test/check_pearson.py
```

**Output per checkpoint:**
```
Stage 1 Overall Pearson's r: 0.1832
Stage 1 - Sub 1 Mean Pearson's r: 0.1754
Stage 2 Epoch 5 Overall Pearson's r: 0.2110
Stage 2 Epoch 5 - Sub 1 Mean Pearson's r: 0.2043
...
```

### Tiny-batch overfit test

Checks that Stage 2 can memorize a single training batch. If the loss does not approach `0` within 500 steps, the model or training loop cannot fit even this trivial case and should be debugged before a full run; passing is a necessary sanity check, not evidence that the model will generalize.

```bash
python test/overfit_test.py --cfg-path config.yml --subject-idx 0 --steps 500
```

---

## Key Hyperparameters

| Parameter | Value | Location |
|---|---|---|
| Stage 1 embed dim | 192 | `config.yml / stage1.model.embed_dim` |
| Stage 1 encoder kernel | 45 | `config.yml / stage1.model.encoder_kernel_size` |
| Stage 1 LR | 3e-4 | `config.yml / stage1.lr` |
| Stage 2 latent dim | 128 | `config.yml / stage2.latent_dim` |
| Stage 2 U-Net channels | [256, 256] | `config.yml / stage2.decoder.channels` |
| Stage 2 block type | Conformer | `config.yml / stage2.decoder.*_block_type` |
| Stage 2 LR | 3e-4 | `config.yml / stage2.lr` |
| Euler integration steps | 25 | `config.yml / stage2.n_timesteps` |
| CFM σ_min | 1e-4 | `config.yml / stage2.cfm.sigma_min` |

---
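
## Metric

Evaluation uses **voxel-wise Pearson's r** between the measured BOLD time series $y_v^t$ and the predicted series $\hat{y}_v^t$ of voxel $v$ at time $t$:

$$r_v = \frac{\sum_t (y_v^t - \bar{y}_v)(\hat{y}_v^t - \bar{\hat{y}}_v)}{\sqrt{\sum_t (y_v^t - \bar{y}_v)^2}\,\sqrt{\sum_t (\hat{y}_v^t - \bar{\hat{y}}_v)^2}}$$

where $\bar{y}_v$ and $\bar{\hat{y}}_v$ are the temporal means of the two series. The reported scalar is the mean of $r_v$ over all `V` voxels and all `S` subjects.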
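
`src/metric.py` is the reference implementation; the sketch below is an independent NumPy version of voxel-wise Pearson's r, assuming targets and predictions are arrays of shape (S, T, V) with the same voxel count for every subject. The small `eps` guarding constant time series is an assumption; how `metric.py` handles that case is not stated here.

```python
import numpy as np


def voxelwise_pearson(y_true, y_pred, eps=1e-8):
    """Pearson's r for every voxel along the time axis.

    y_true, y_pred: (T, V) arrays. Returns a (V,) array of r_v.
    """
    yt = y_true - y_true.mean(axis=0, keepdims=True)  # center each voxel's time series
    yp = y_pred - y_pred.mean(axis=0, keepdims=True)
    num = np.sum(yt * yp, axis=0)
    den = np.sqrt(np.sum(yt ** 2, axis=0)) * np.sqrt(np.sum(yp ** 2, axis=0))
    return num / (den + eps)  # eps avoids 0/0 for constant time series


def overall_score(y_true, y_pred):
    """y_true, y_pred: (S, T, V). Mean r_v over all voxels and all subjects."""
    per_subject = np.stack([voxelwise_pearson(yt, yp) for yt, yp in zip(y_true, y_pred)])
    return float(per_subject.mean())
```

Since every subject contributes the same number of voxels, averaging the stacked (S, V) array equals averaging the per-subject means, which matches the "Overall" and "Sub N Mean" lines printed by `check_pearson.py`.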