---
tags:
- multimodal
- reasoning
- jepa
- world-model
- vision-language
license: apache-2.0
---
# MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture
> A world model for multimodal reasoning that refines a latent belief state over K=3 steps using JEPA-style prediction, evidence gating, and dense visual backbones.
## Key Idea
Traditional multimodal models produce answers in a single forward pass. MR-JEPA instead models **the evolution of a belief state** as the system reasons about a question:
```
z₀ (initial evidence) → z₁ (first refinement) → z₂ (deeper reasoning) → z₃ (answer)
```
This trajectory is supervised by a **JEPA objective**: a target encoder (EMA) generates target latent states, and the online predictor learns to predict them. The JEPA loss encourages the model to learn **meaningful intermediate reasoning states** — not just the final answer.
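
The objective above can be illustrated with a small sketch. It assumes a SmoothL1 distance averaged uniformly over the refined states; the function names are hypothetical, plain Python lists stand in for tensors, and the step weighting used by the actual training script may differ.

```python
from typing import List, Sequence

Vector = Sequence[float]


def smooth_l1(pred: Vector, target: Vector, beta: float = 1.0) -> float:
    """Mean SmoothL1 (Huber-style) distance between two latent vectors."""
    total = 0.0
    for p, t in zip(pred, target):
        diff = abs(p - t)
        # Quadratic near zero, linear for large errors.
        total += 0.5 * diff * diff / beta if diff < beta else diff - 0.5 * beta
    return total / len(pred)


def jepa_trajectory_loss(online: List[Vector], target: List[Vector]) -> float:
    """Average the per-step loss over the refined states.

    `online` holds the predictor's states and `target` the EMA encoder's
    states. The targets are treated as constants: in a real implementation
    no gradient would flow into them.
    """
    if len(online) != len(target):
        raise ValueError("online and target trajectories must have equal length")
    return sum(smooth_l1(o, t) for o, t in zip(online, target)) / len(online)


# Toy trajectory with K=3 refined states of dimension 2 (illustrative values).
online_states = [[0.0, 0.0], [0.5, 0.5], [1.0, 3.0]]
target_states = [[0.0, 0.0], [0.5, 1.0], [1.0, 1.0]]
loss = jepa_trajectory_loss(online_states, target_states)
```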
---
## Architecture
```
┌──────────────┐     ┌────────────────────┐     ┌──────────────────┐     ┌───────────────┐
│ DINOv3-L/16  │────▶│  Evidence Memory   │────▶│  Latent Rollout  │────▶│  Disc. Head   │
│  (frozen)    │     │ (Perceiver Resampl)│     │   z₀→z₁→z₂→z₃    │     │ (MC scoring)  │
└──────────────┘     └────────┬───────────┘     │  (shared block)  │     └───────────────┘
                              │                 └────────┬─────────┘     ┌───────────────┐
┌──────────────┐              │                          │          ├───▶│ Gen. Decoder  │
│ Qwen3-Embed  │──────────────┘                  ┌───────┴────────┐      │ (Qwen3.5-4B)  │
│     0.6B     │                                 │ Target Encoder │      └───────────────┘
│  (frozen)    │                                 │   (EMA copy)   │
└──────────────┘                                 └────────────────┘
                                                         │
┌──────────────┐                                    JEPA Loss:
│ Phase 3 opt: │──────────┘                         SmoothL1/Cosine
│ PaddleOCR-VL │                                    + SIGReg (purist)
│ SAM 3.1      │                                    / VICReg (hybrid)
└──────────────┘
```
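
To make the "shared block" in the rollout box concrete, here is a toy sketch of a weight-tied, evidence-gated refinement loop. It is only an illustration: in the real model each step is a transformer predictor, whereas this version uses a single scalar sigmoid gate and a pooled evidence vector, and every name in it (`gated_step`, `rollout`, `gate_weight`, `gate_bias`) is hypothetical. The point is that one set of parameters is reused for all K steps and that the gate controls how much evidence enters each update.

```python
import math
from typing import List

Vector = List[float]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gated_step(z: Vector, evidence: Vector, gate_weight: float, gate_bias: float) -> Vector:
    """One toy refinement step: a scalar gate decides how much evidence to mix in."""
    # The gate input is the dot product between the current state and the evidence.
    score = sum(zi * ei for zi, ei in zip(z, evidence))
    gate = sigmoid(gate_weight * score + gate_bias)
    # Convex combination of the old state and the evidence, weighted by the gate.
    return [zi + gate * (ei - zi) for zi, ei in zip(z, evidence)]


def rollout(z0: Vector, evidence: Vector, K: int = 3,
            gate_weight: float = 1.0, gate_bias: float = 0.0) -> List[Vector]:
    """Apply the same step K times (weight tying) and return [z0, z1, ..., zK]."""
    trajectory = [list(z0)]
    for _ in range(K):
        trajectory.append(gated_step(trajectory[-1], evidence, gate_weight, gate_bias))
    return trajectory


# K=3 gives the four states z0..z3 shown in the diagram.
trajectory = rollout(z0=[0.0, 0.0], evidence=[1.0, -1.0], K=3)
```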
---
## Component Stack
| Module | Primary Choice | Alternative | Notes |
|--------|---------------|-------------|-------|
| **Visual backbone** | [`timm/vit_large_patch16_dinov3.lvd1689m`](https://hf.co/timm/vit_large_patch16_dinov3.lvd1689m) — DINOv3-L/16, 1024-dim, 300M | DINOv3-B/16 (purist); DINOv2-L/14 (ablation) | Frozen Phase 1; last 6 layers unfrozen Phase 2 |
| **Text encoder** | [`Qwen/Qwen3-Embedding-0.6B`](https://hf.co/Qwen/Qwen3-Embedding-0.6B) — 1024-dim, 596M | Qwen3-Embedding-4B (heavier); EmbeddingGemma-300M (lighter) | Frozen Phase 1; last 4 layers unfrozen Phase 2 |
| **Evidence memory** | Perceiver Resampler, 64 queries, 4 cross-attn layers | Q-Former as baseline | Modality-typed tokens (visual/text/OCR/layout/chart/SAM) |
| **OCR / doc / charts** | [`PaddlePaddle/PaddleOCR-VL-1.5`](https://hf.co/PaddlePaddle/PaddleOCR-VL-1.5) — 958M | MinerU2.5 for heavy PDF parsing | Phase 3 only, offline token extraction |
| **Segmentation** | [`jetjodh/sam3.1`](https://hf.co/jetjodh/sam3.1) — SAM 3.1, non-gated mirror | SAM 2.1-Large (stable) | Phase 3 optional, offline mask extraction |
| **Latent rollout** | Shared transformer predictor, 6 layers, K=3 | Per-step unshared blocks (ablation) | Weight-tied across steps; sigmoid evidence gates |
| **Target encoder** | EMA copy of evidence+rollout (momentum on a cosine schedule, 0.996→1.0) | Frozen target (ablation baseline) | From I-JEPA |
| **JEPA loss** | SmoothL1 + VICReg (hybrid); Cosine + SIGReg (purist) | MSE (ablation) | SIGReg emphasis in purist branch |
| **Disc. head** | MLP/bilinear scorer | Cross-encoder scorer (ablation) | Attention-pooled z_K × option embeddings |
| **Gen. decoder** | [`Qwen/Qwen3.5-4B`](https://hf.co/Qwen/Qwen3.5-4B) — 4.7B, multimodal | [`HuggingFaceTB/SmolLM3-3B`](https://hf.co/HuggingFaceTB/SmolLM3-3B) (cheaper); Gemma3-4B | Phase 3+, cross-attends to z_K + evidence |
| **Teacher/baseline** | InternVL3.5 / Qwen3-VL | External comparison only | NOT used as internal module |
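
The target-encoder row can be read as an exponential moving average whose momentum rises from 0.996 to 1.0 over training. The sketch below assumes a cosine interpolation between the two endpoints, which is one common reading of "cosine 0.996→1.0"; the function names are hypothetical and flat lists stand in for parameter tensors.

```python
import math
from typing import List


def ema_momentum(step: int, total_steps: int,
                 start: float = 0.996, end: float = 1.0) -> float:
    """Cosine interpolation of the EMA momentum from `start` to `end`."""
    progress = min(max(step / max(total_steps, 1), 0.0), 1.0)
    return end - (end - start) * (math.cos(math.pi * progress) + 1.0) / 2.0


def ema_update(target: List[float], online: List[float], momentum: float) -> List[float]:
    """Move each target parameter a small step toward its online counterpart."""
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target, online)]


# Early in training the target still tracks the online weights slowly;
# at the final step the momentum reaches 1.0 and the target stops moving.
target_params = ema_update([1.0, 2.0], [0.0, 0.0], ema_momentum(0, 1000))
```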
---
## Training Protocol
### Phase 1: Reasoning Core (15–20 epochs)
- **Freeze** all perception (DINOv3 + Qwen3-Embedding)
- **Train** evidence memory + latent rollout + discriminative head
- Full JEPA loss + task loss
- LR: 3e-4, effective batch: 64
### Phase 2: Perception Fine-tuning (10 epochs)
- **Unfreeze** last 6 DINOv3 layers + last 4 Qwen3-Embedding layers (1e-5)
- Continue training reasoning core (1e-4)
### Phase 3: Enriched Evidence + Generative Decoder (10 epochs)
- **Enable** PaddleOCR-VL tokens, SAM 3.1 masks, layout/chart tokens
- **Attach** Qwen3.5-4B generative decoder for open-ended answers
- End-to-end fine-tuning, LR: 5e-5
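
The three phases can be captured as data, which makes the schedule easy to inspect and test. This is a sketch rather than the contents of `phase_scheduler.py`: the `PhaseConfig` class, its field names, and `phase_for_epoch` are hypothetical. Two values are assumptions flagged in the comments: Phase 1 uses the upper bound of its 15–20 epoch range, and Phase 3 is assumed to keep the Phase 2 unfreezing and apply 5e-5 to every trainable group.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class PhaseConfig:
    name: str
    epochs: int
    core_lr: float                  # evidence memory, latent rollout, answer heads
    perception_lr: Optional[float]  # None keeps both backbones frozen
    unfrozen_visual_layers: int = 0
    unfrozen_text_layers: int = 0
    enriched_evidence: bool = False
    generative_decoder: bool = False


PHASES: List[PhaseConfig] = [
    # Assumption: upper bound of the documented 15-20 epoch range.
    PhaseConfig("phase1_reasoning_core", epochs=20, core_lr=3e-4, perception_lr=None),
    PhaseConfig("phase2_perception", epochs=10, core_lr=1e-4, perception_lr=1e-5,
                unfrozen_visual_layers=6, unfrozen_text_layers=4),
    # Assumption: Phase 3 keeps the Phase 2 unfreezing and uses 5e-5 everywhere.
    PhaseConfig("phase3_enriched", epochs=10, core_lr=5e-5, perception_lr=5e-5,
                unfrozen_visual_layers=6, unfrozen_text_layers=4,
                enriched_evidence=True, generative_decoder=True),
]


def phase_for_epoch(epoch: int) -> PhaseConfig:
    """Return the phase that a zero-based global epoch index falls into."""
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    start = 0
    for phase in PHASES:
        if epoch < start + phase.epochs:
            return phase
        start += phase.epochs
    raise ValueError(f"epoch {epoch} is past the end of the schedule")
```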
---
## Target Benchmarks (9)
| Benchmark | Type | Metric | Key Challenge |
|-----------|------|--------|---------------|
| MMMU | MC (multi-image) | Accuracy | Multi-discipline, up to 7 images |
| MathVista | Mixed MC/Open | Accuracy | Mathematical reasoning |
| ScienceQA | MC | Accuracy | Scientific diagrams, nullable images |
| AI2D | MC | Accuracy | Science diagram comprehension |
| MMBench | MC | CircularEval Acc | General visual understanding |
| MMStar | MC | Accuracy | Vision-dependent questions |
| DocVQA | Open | ANLS | Document text extraction |
| TextVQA | Open | VQA Accuracy | Scene text reading |
| ChartQA | Open | Relaxed Accuracy | Chart data extraction |
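
Two of the open-ended metrics in the table are simple enough to sketch directly. The code below follows the usual definitions: ANLS scores a prediction by normalized Levenshtein similarity with a 0.5 cut-off, and relaxed accuracy accepts numeric answers within 5% of the reference. It is not the contents of `metrics.py`; the function names are hypothetical, and the normalization (lower-casing and stripping only, no handling of percent signs or units) is a simplifying assumption.

```python
from typing import Sequence


def levenshtein(a: str, b: str) -> int:
    """Edit distance via the classic two-row dynamic programme."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]


def anls_score(prediction: str, answers: Sequence[str], threshold: float = 0.5) -> float:
    """Best similarity over all reference answers; zero when too dissimilar."""
    pred = prediction.strip().lower()
    best = 0.0
    for answer in answers:
        gold = answer.strip().lower()
        longest = max(len(pred), len(gold))
        distance = levenshtein(pred, gold) / longest if longest else 0.0
        if distance < threshold:
            best = max(best, 1.0 - distance)
    return best


def relaxed_match(prediction: str, answer: str, tolerance: float = 0.05) -> bool:
    """Numeric answers may deviate by `tolerance`; other answers must match exactly."""
    try:
        pred_value, gold_value = float(prediction), float(answer)
    except ValueError:
        return prediction.strip().lower() == answer.strip().lower()
    if gold_value == 0.0:
        return pred_value == 0.0
    return abs(pred_value - gold_value) / abs(gold_value) <= tolerance
```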
---
## Experimental Branches
### Hybrid-main (competitive)
- DINOv3-L backbone, SmoothL1 + VICReg, K=3
- Full enriched evidence in Phase 3
- Target: state-of-the-art on all benchmarks
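
The VICReg part of the hybrid loss guards against representation collapse. The sketch below shows the variance and covariance terms as defined in the VICReg paper, computed over a batch of latent vectors with plain Python lists. How these terms are weighted against SmoothL1 in this project is not stated here, so no combined loss is shown, and the function names are hypothetical.

```python
import math
from typing import List

Batch = List[List[float]]  # shape: (batch, dim)


def variance_term(z: Batch, gamma: float = 1.0, eps: float = 1e-4) -> float:
    """Hinge on the per-dimension standard deviation: penalize dims with std < gamma."""
    n, d = len(z), len(z[0])
    loss = 0.0
    for j in range(d):
        column = [row[j] for row in z]
        mean = sum(column) / n
        var = sum((x - mean) ** 2 for x in column) / (n - 1)
        loss += max(0.0, gamma - math.sqrt(var + eps))
    return loss / d


def covariance_term(z: Batch) -> float:
    """Sum of squared off-diagonal covariances, divided by the dimension."""
    n, d = len(z), len(z[0])
    means = [sum(row[j] for row in z) / n for j in range(d)]
    loss = 0.0
    for i in range(d):
        for j in range(d):
            if i == j:
                continue
            cov = sum((row[i] - means[i]) * (row[j] - means[j]) for row in z) / (n - 1)
            loss += cov * cov
    return loss / d


collapsed = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]                # every sample identical
spread = [[2.0, 0.0], [-2.0, 0.0], [0.0, 2.0], [0.0, -2.0]]     # decorrelated, high variance
```

A collapsed batch incurs a variance penalty close to `gamma`, while the spread batch incurs neither penalty.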
### Purist-side (scientific validation)
- DINOv3-B backbone, Cosine + SIGReg, K=5
- No enriched evidence, pure JEPA reasoning
- Target: demonstrate JEPA contributes beyond perception
---
## Ablation Experiments
Each experiment corresponds to a flag (or flag combination) of `train_mrjepa.py`; the default run takes no flags.
| Experiment | CLI flag | Modification | Purpose |
|------------|----------|-------------|---------|
| `hybrid_main` | *(default)* | Full model | Baseline |
| `no_jepa` | `--no_jepa` | Remove L_JEPA, task loss only | Validate JEPA objective |
| `no_rollout` | `--no_rollout` | K=0, use z₀ directly | Validate iterative refinement |
| `no_gate` | `--no_evidence_gate` | Remove evidence gating | Validate adaptive evidence flow |
| `K1` / `K5` / `K7` | `--K 1/5/7` | Vary rollout depth | Find optimal depth |
| `dinov2_ablation` | `--backbone dinov2` | DINOv2-L/14 backbone | DINOv3 vs DINOv2 |
| `mse_loss` | `--loss_fn mse` | MSE (L2) JEPA loss | Original I-JEPA loss |
| `cosine_loss` | `--loss_fn cosine` | Cosine similarity JEPA loss | Purist-style loss |
| `no_sigreg` | `--no_sigreg` | Disable SIGReg anti-collapse | Test regularization |
| `vicreg_only` | `--no_sigreg --use_vicreg` | VICReg only | Alternative anti-collapse |
| `purist` | `--purist` | DINOv3-B, K=5, Cosine+SIGReg | Isolate JEPA contribution |
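
The table translates directly into a lookup from experiment name to command-line arguments, which is roughly what an ablation runner needs. The flags and the script name below are copied from the table; the `ABLATIONS` dictionary, the `build_command` helper, and the use of a bare `python` interpreter are illustrative assumptions, not the contents of `ablation.py`.

```python
from typing import Dict, List

# Experiment name -> flags, copied from the table above.
ABLATIONS: Dict[str, List[str]] = {
    "hybrid_main": [],
    "no_jepa": ["--no_jepa"],
    "no_rollout": ["--no_rollout"],
    "no_gate": ["--no_evidence_gate"],
    "K1": ["--K", "1"],
    "K5": ["--K", "5"],
    "K7": ["--K", "7"],
    "dinov2_ablation": ["--backbone", "dinov2"],
    "mse_loss": ["--loss_fn", "mse"],
    "cosine_loss": ["--loss_fn", "cosine"],
    "no_sigreg": ["--no_sigreg"],
    "vicreg_only": ["--no_sigreg", "--use_vicreg"],
    "purist": ["--purist"],
}


def build_command(experiment: str, script: str = "train_mrjepa.py") -> List[str]:
    """Return the argv list for one experiment; raises KeyError for unknown names."""
    return ["python", script, *ABLATIONS[experiment]]


# The list form can be passed to subprocess.run without shell quoting issues.
commands = {name: build_command(name) for name in ABLATIONS}
```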
---
## Project Structure
```
MR-JEPA/
├── README.md                      # This file
├── train_mrjepa.py                # Complete training script (CLI, all ablations)
├── test_architecture.py           # Architecture validation tests (synthetic data)
│
├── mr_jepa/
│   ├── __init__.py
│   ├── ARCHITECTURE.md            # Detailed architecture specification
│   │
│   ├── configs/
│   │   ├── __init__.py
│   │   └── model_config.py        # All hyperparameter dataclasses
│   │
│   ├── models/
│   │   ├── __init__.py
│   │   ├── mr_jepa.py             # Main model (integrates all components)
│   │   ├── backbones.py           # Visual (DINOv3/v2) + Text (Qwen3-Embedding)
│   │   ├── evidence_memory.py     # Perceiver Resampler multimodal fusion
│   │   ├── latent_rollout.py      # K-step shared predictor + evidence gates
│   │   ├── target_encoder.py      # EMA encoder + JEPA/SIGReg/VICReg losses
│   │   └── answer_heads.py        # Discriminative (MC) + Generative (open-ended)
│   │
│   ├── data/
│   │   ├── __init__.py
│   │   ├── unified_dataset.py     # 9-benchmark unified loader with format quirks
│   │   └── data_utils.py          # Collator, dataloader factory, benchmark configs
│   │
│   ├── training/
│   │   ├── __init__.py
│   │   ├── trainer.py             # 3-phase training loop
│   │   └── phase_scheduler.py     # Phase transitions, LR scheduling
│   │
│   ├── evaluation/
│   │   ├── __init__.py
│   │   └── metrics.py             # Accuracy, ANLS, VQA Acc, Relaxed Acc
│   │
│   └── utils/
│       ├── __init__.py
│       ├── visualization.py       # Trajectory PCA, gate analysis
│       └── ablation.py            # Systematic ablation runner
│
├── results/                       # Training results (auto-pushed)
│   ├── hybrid_main.json
│   ├── no_jepa.json
│   ├── no_rollout.json
│   └── ...
│
└── checkpoints/                   # Best model checkpoints (auto-pushed)
    ├── hybrid_main_best.pt
    └── ...
```
---
## Paper Contribution
> **A world model for multimodal reasoning**: We demonstrate that modeling the evolution of a latent belief state via JEPA-style prediction improves performance on static multimodal benchmarks compared to single-pass baselines. The evidence-gated rollout with K=3 steps learns meaningful intermediate reasoning states, validated through ablation studies across 9 benchmarks. The JEPA objective (not human chain-of-thought) supervises a latent trajectory generated by an EMA target encoder, showing that self-supervised dynamics training transfers to discriminative reasoning tasks.
---
## Key References
1. **I-JEPA** (Assran et al., 2023) — [arxiv:2301.08243](https://arxiv.org/abs/2301.08243): JEPA architecture, EMA target encoder, L2 prediction loss, narrow predictor
2. **LeWorldModel** (Maes et al., 2025) — [arxiv:2603.19312](https://arxiv.org/abs/2603.19312): SIGReg anti-collapse, end-to-end JEPA
3. **Coconut** (Hao et al., 2024) — [arxiv:2412.06769](https://arxiv.org/abs/2412.06769): Chain of Continuous Thought, latent reasoning
4. **DINOv3** (Meta, 2025) — [arxiv:2508.10104](https://arxiv.org/abs/2508.10104): Dense SSL with RoPE + Gram anchoring
5. **SoftCoT++** (Xu et al., 2025) — [arxiv:2505.11484](https://arxiv.org/abs/2505.11484): Soft chain-of-thought with contrastive learning
## License
Apache-2.0