---
tags:
  - multimodal
  - reasoning
  - jepa
  - world-model
  - vision-language
license: apache-2.0
---

# MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture

> A world model for multimodal reasoning that refines a latent belief state over K=3 steps using JEPA-style prediction, evidence gating, and dense visual backbones.

## Key Idea

Traditional multimodal models produce answers in a single forward pass. MR-JEPA instead models **the evolution of a belief state** as the system reasons about a question:

```
z₀ (initial evidence) → z₁ (first refinement) → z₂ (deeper reasoning) → z₃ (answer)
```

This trajectory is supervised by a **JEPA objective**: an EMA target encoder produces target latent states, and the online predictor learns to predict them. The loss therefore shapes the **intermediate reasoning states** along the trajectory, not just the final answer.
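As a rough illustration of this objective, the sketch below rolls a latent state forward with one shared step function and averages a SmoothL1 distance between the online states and a target trajectory. It is a toy in plain Python, not the code in `target_encoder.py`: the names `rollout` and `jepa_loss`, the halving "predictor", and the hand-written targets are assumptions made for the example. In the model, the targets would come from the EMA target encoder.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def smooth_l1(pred: Sequence[float], target: Sequence[float], beta: float = 1.0) -> float:
    """Mean SmoothL1 distance: quadratic below beta, linear above it."""
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total / len(pred)


def rollout(z0: Vector, step: Callable[[Vector], Vector], k: int = 3) -> List[Vector]:
    """Apply the same step function k times and return [z0, z1, ..., zk]."""
    states = [z0]
    for _ in range(k):
        states.append(step(states[-1]))
    return states


def jepa_loss(online: List[Vector], targets: List[Vector]) -> float:
    """Average the per-step loss over z1..zK (z0 is the shared starting point)."""
    pairs = list(zip(online[1:], targets[1:]))
    return sum(smooth_l1(p, t) for p, t in pairs) / len(pairs)


# Toy data: the hypothetical predictor halves every coordinate.
online = rollout([1.0, -2.0], lambda z: [0.5 * x for x in z], k=3)
targets = [[1.0, -2.0], [0.5, -1.0], [0.25, -0.5], [0.0, 0.0]]
loss = jepa_loss(online, targets)
```

Only the last step disagrees with its target here, so the loss is small but non-zero. Supervising every step, rather than only `z3`, is what gives the intermediate states a training signal.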

---

## Architecture

```
┌──────────────┐     ┌────────────────────┐     ┌──────────────────┐     ┌───────────────┐
│  DINOv3-L/16 │────▶│   Evidence Memory  │────▶│  Latent Rollout  │────▶│ Disc. Head    │
│   (frozen)   │     │ (Perceiver Resampl)│     │  z₀→z₁→z₂→z₃    │     │ (MC scoring)  │
└──────────────┘     └────────┬───────────┘     │ (shared block)   │     └───────────────┘
                              │                 └────────┬─────────┘     ┌───────────────┐
┌──────────────┐              │                          │          ├───▶│ Gen. Decoder   │
│ Qwen3-Embed  │──────────────┘                  ┌───────┴────────┐     │ (Qwen3.5-4B)  │
│   0.6B       │                                 │ Target Encoder │     └───────────────┘
│   (frozen)   │                                 │  (EMA copy)    │
└──────────────┘                                 └────────────────┘

┌──────────────┐                                  JEPA Loss:
│ Phase 3 opt: │──────────┘                       SmoothL1/Cosine
│ PaddleOCR-VL │                                  + SIGReg (purist)
│ SAM 3.1      │                                  / VICReg (hybrid)
└──────────────┘
```
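The target encoder in the diagram is an EMA copy of the online weights, and the component stack gives its momentum as "cosine 0.996→1.0". One plausible reading of that schedule is sketched below in plain Python. The exact formula, the function names `ema_momentum` and `ema_update`, and the use of flat lists in place of parameter tensors are assumptions for illustration, not the repository's implementation.

```python
import math
from typing import List


def ema_momentum(step: int, total_steps: int, start: float = 0.996, end: float = 1.0) -> float:
    """Cosine ramp from `start` at step 0 to `end` at `total_steps` (assumed form)."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return end - (end - start) * 0.5 * (1.0 + math.cos(math.pi * progress))


def ema_update(target: List[float], online: List[float], momentum: float) -> List[float]:
    """Move each target weight a fraction (1 - momentum) toward the online weight."""
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target, online)]


# The target drifts slowly toward the online weights early in training
# and stops moving once the momentum reaches 1.0.
target_weights = [1.0, 2.0]
online_weights = [0.0, 0.0]
early = ema_update(target_weights, online_weights, ema_momentum(0, 100))
late = ema_update(target_weights, online_weights, ema_momentum(100, 100))
```

Because no gradient flows through this update, the target states act as slowly moving regression targets for the online predictor.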

---

## Component Stack

| Module | Primary Choice | Alternative | Notes |
|--------|---------------|-------------|-------|
| **Visual backbone** | [`timm/vit_large_patch16_dinov3.lvd1689m`](https://hf.co/timm/vit_large_patch16_dinov3.lvd1689m) — DINOv3-L/16, 1024-dim, 300M | DINOv3-B/16 (purist); DINOv2-L/14 (ablation) | Frozen Phase 1; last 6 layers unfrozen Phase 2 |
| **Text encoder** | [`Qwen/Qwen3-Embedding-0.6B`](https://hf.co/Qwen/Qwen3-Embedding-0.6B) — 1024-dim, 596M | Qwen3-Embedding-4B (heavier); EmbeddingGemma-300M (lighter) | Frozen Phase 1; last 4 layers unfrozen Phase 2 |
| **Evidence memory** | Perceiver Resampler, 64 queries, 4 cross-attn layers | Q-Former as baseline | Modality-typed tokens (visual/text/OCR/layout/chart/SAM) |
| **OCR / doc / charts** | [`PaddlePaddle/PaddleOCR-VL-1.5`](https://hf.co/PaddlePaddle/PaddleOCR-VL-1.5) — 958M | MinerU2.5 for heavy PDF parsing | Phase 3 only, offline token extraction |
| **Segmentation** | [`jetjodh/sam3.1`](https://hf.co/jetjodh/sam3.1) — SAM 3.1, non-gated mirror | SAM 2.1-Large (stable) | Phase 3 optional, offline mask extraction |
| **Latent rollout** | Shared transformer predictor, 6 layers, K=3 | Per-step unshared blocks (ablation) | Weight-tied across steps; sigmoid evidence gates |
| **Target encoder** | EMA copy of evidence + rollout (momentum on a cosine schedule, 0.996→1.0) | Frozen target (ablation baseline) | EMA target as in I-JEPA |
| **JEPA loss** | SmoothL1 + VICReg (hybrid); Cosine + SIGReg (purist) | MSE (ablation) | SIGReg emphasis in purist branch |
| **Disc. head** | MLP/bilinear scorer | Cross-encoder scorer (ablation) | Attention-pooled z_K × option embeddings |
| **Gen. decoder** | [`Qwen/Qwen3.5-4B`](https://hf.co/Qwen/Qwen3.5-4B) — 4.7B, multimodal | [`HuggingFaceTB/SmolLM3-3B`](https://hf.co/HuggingFaceTB/SmolLM3-3B) (cheaper); Gemma3-4B | Phase 3+, cross-attends to z_K + evidence |
| **Teacher/baseline** | InternVL3.5 / Qwen3-VL | External comparison only | NOT used as internal module |
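The latent rollout row mentions weight-tied steps with sigmoid evidence gates. The table does not say how the gate is computed, so the sketch below assumes one simple form: a per-coordinate sigmoid gate that interpolates the state toward the evidence, with the same gate parameters reused at every step. All names (`evidence_gate`, `gated_step`, `w_z`, `w_e`, `bias`) are hypothetical.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def evidence_gate(z_i: float, e_i: float, w_z: float, w_e: float, bias: float) -> float:
    """Gate in (0, 1) computed from the current state and the evidence (assumed form)."""
    return sigmoid(w_z * z_i + w_e * e_i + bias)


def gated_step(z: List[float], evidence: List[float],
               w_z: float = 0.0, w_e: float = 0.0, bias: float = 0.0) -> List[float]:
    """One refinement step: move each coordinate toward the evidence by the gate value."""
    out = []
    for z_i, e_i in zip(z, evidence):
        g = evidence_gate(z_i, e_i, w_z, w_e, bias)
        out.append(z_i + g * (e_i - z_i))
    return out


# Weight tying: the same gate parameters are reused for all K=3 steps.
z = [0.0, 2.0]
evidence = [1.0, 0.0]
trajectory = [z]
for _ in range(3):
    trajectory.append(gated_step(trajectory[-1], evidence))
```

With all gate parameters at zero the gate is exactly 0.5, so each step halves the remaining distance to the evidence. A strongly negative bias closes the gate and leaves the state almost unchanged.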

---

## Training Protocol

### Phase 1: Reasoning Core (15–20 epochs)
- **Freeze** all perception (DINOv3 + Qwen3-Embedding)
- **Train** evidence memory + latent rollout + discriminative head
- Full JEPA loss + task loss
- LR: 3e-4, effective batch: 64

### Phase 2: Perception Fine-tuning (10 epochs)
- **Unfreeze** the last 6 DINOv3 layers and the last 4 Qwen3-Embedding layers (LR: 1e-5)
- Continue training the reasoning core (LR: 1e-4)

### Phase 3: Enriched Evidence + Generative Decoder (10 epochs)
- **Enable** PaddleOCR-VL tokens, SAM 3.1 masks, layout/chart tokens
- **Attach** Qwen3.5-4B generative decoder for open-ended answers
- End-to-end fine-tuning, LR: 5e-5
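The three phases above can be written down as a small schedule table. The sketch below is one hypothetical way to encode it. It is not the contents of `phase_scheduler.py`, and it makes two assumptions: Phase 1 runs for 20 epochs (the upper end of the 15–20 range), and Phase 3 applies its single learning rate to both the perception layers and the reasoning core.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class Phase:
    name: str
    epochs: int
    core_lr: float
    perception_lr: Optional[float]  # None means the perception backbones stay frozen


PHASES = (
    Phase("reasoning_core", epochs=20, core_lr=3e-4, perception_lr=None),
    Phase("perception_finetune", epochs=10, core_lr=1e-4, perception_lr=1e-5),
    Phase("enriched_generative", epochs=10, core_lr=5e-5, perception_lr=5e-5),
)


def phase_for_epoch(epoch: int, phases: Sequence[Phase] = PHASES) -> Phase:
    """Return the phase that a zero-based global epoch index falls into."""
    start = 0
    for phase in phases:
        if epoch < start + phase.epochs:
            return phase
        start += phase.epochs
    raise ValueError(f"epoch {epoch} is past the end of the schedule ({start} epochs)")
```

For example, `phase_for_epoch(19)` is still the reasoning-core phase, while `phase_for_epoch(20)` switches to perception fine-tuning with the lower learning rates.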

---

## Target Benchmarks (9)

| Benchmark | Type | Metric | Key Challenge |
|-----------|------|--------|---------------|
| MMMU | MC (multi-image) | Accuracy | Multi-discipline, up to 7 images |
| MathVista | Mixed MC/Open | Accuracy | Mathematical reasoning |
| ScienceQA | MC | Accuracy | Scientific diagrams, nullable images |
| AI2D | MC | Accuracy | Science diagram comprehension |
| MMBench | MC | CircularEval Acc | General visual understanding |
| MMStar | MC | Accuracy | Vision-dependent questions |
| DocVQA | Open | ANLS | Document text extraction |
| TextVQA | Open | VQA Accuracy | Scene text reading |
| ChartQA | Open | Relaxed Accuracy | Chart data extraction |
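Two of the metrics in this table are less familiar than plain accuracy. The sketch below shows per-sample versions of ANLS (as commonly defined for DocVQA, with a 0.5 threshold) and relaxed accuracy (as commonly defined for ChartQA, with a 5% numeric tolerance). It is a minimal illustration and not the code in `metrics.py`; the lowercasing and whitespace stripping are assumed normalization choices.

```python
from typing import Sequence


def levenshtein(a: str, b: str) -> int:
    """Edit distance via the standard two-row dynamic program."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = curr
    return prev[-1]


def anls(prediction: str, answers: Sequence[str], threshold: float = 0.5) -> float:
    """Best (1 - normalized edit distance) over the references, zeroed at the threshold."""
    best = 0.0
    pred = prediction.strip().lower()
    for answer in answers:
        gold = answer.strip().lower()
        longest = max(len(pred), len(gold))
        nld = levenshtein(pred, gold) / longest if longest else 0.0
        best = max(best, 1.0 - nld if nld < threshold else 0.0)
    return best


def relaxed_match(prediction: str, target: str, tolerance: float = 0.05) -> bool:
    """Numeric answers may deviate by `tolerance` (relative); others must match exactly."""
    try:
        pred, gold = float(prediction), float(target)
    except ValueError:
        return prediction.strip().lower() == target.strip().lower()
    if gold == 0.0:
        return pred == 0.0
    return abs(pred - gold) / abs(gold) <= tolerance
```

The dataset-level score is the mean of these per-sample values. A prediction of `"helo"` against the reference `"hello"` scores 0.8 under ANLS, and `"10.4"` counts as a relaxed match for `"10"`.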

---

## Experimental Branches

### Hybrid-main (competitive)
- DINOv3-L backbone, SmoothL1 + VICReg, K=3
- Full enriched evidence in Phase 3
- Target: state-of-the-art on all benchmarks

### Purist-side (scientific validation)
- DINOv3-B backbone, Cosine + SIGReg, K=5
- No enriched evidence, pure JEPA reasoning
- Target: demonstrate JEPA contributes beyond perception

---

## Ablation Experiments

Each experiment corresponds to a CLI flag, or a combination of flags, in `train_mrjepa.py`.

| Experiment | CLI flag | Modification | Purpose |
|------------|----------|-------------|---------|
| `hybrid_main` | *(default)* | Full model | Baseline |
| `no_jepa` | `--no_jepa` | Remove L_JEPA, task loss only | Validate JEPA objective |
| `no_rollout` | `--no_rollout` | K=0, use z₀ directly | Validate iterative refinement |
| `no_gate` | `--no_evidence_gate` | Remove evidence gating | Validate adaptive evidence flow |
| `K1` / `K5` / `K7` | `--K 1/5/7` | Vary rollout depth | Find optimal depth |
| `dinov2_ablation` | `--backbone dinov2` | DINOv2-L/14 backbone | DINOv3 vs DINOv2 |
| `mse_loss` | `--loss_fn mse` | MSE (L2) JEPA loss | Original I-JEPA loss |
| `cosine_loss` | `--loss_fn cosine` | Cosine similarity JEPA loss | Purist-style loss |
| `no_sigreg` | `--no_sigreg` | Disable SIGReg anti-collapse | Test regularization |
| `vicreg_only` | `--no_sigreg --use_vicreg` | VICReg only | Alternative anti-collapse |
| `purist` | `--purist` | DINOv3-B, K=5, Cosine+SIGReg | Isolate JEPA contribution |
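To make the flag table concrete, here is a hypothetical `argparse` reconstruction of that interface. The flag names are taken from the table. The default values `"dinov3"` and `"smooth_l1"`, the helper `effective_k`, and the precedence between `--no_rollout`, `--purist`, and `--K` are assumptions, and the real `train_mrjepa.py` may differ.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser exposing the ablation flags listed in the table (illustrative only)."""
    parser = argparse.ArgumentParser(description="Sketch of the MR-JEPA ablation flags")
    parser.add_argument("--no_jepa", action="store_true", help="task loss only")
    parser.add_argument("--no_rollout", action="store_true", help="use z0 directly")
    parser.add_argument("--no_evidence_gate", action="store_true")
    parser.add_argument("--K", type=int, default=3, help="rollout depth")
    parser.add_argument("--backbone", default="dinov3")    # default name is assumed
    parser.add_argument("--loss_fn", default="smooth_l1")  # default name is assumed
    parser.add_argument("--no_sigreg", action="store_true")
    parser.add_argument("--use_vicreg", action="store_true")
    parser.add_argument("--purist", action="store_true")
    return parser


def effective_k(args: argparse.Namespace) -> int:
    """Resolve the rollout depth; this precedence order is an assumption."""
    if args.no_rollout:
        return 0
    if args.purist:
        return 5
    return args.K


# Example: the `vicreg_only` experiment combines two flags.
vicreg_only = build_parser().parse_args(["--no_sigreg", "--use_vicreg"])
```

Keeping every ablation behind a flag means one script can reproduce each row of the table, and the resolved settings can be logged next to the results file.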

---

## Project Structure

```
MR-JEPA/
├── README.md                    # This file
├── train_mrjepa.py              # Complete training script (CLI, all ablations)
├── test_architecture.py         # Architecture validation tests (synthetic data)

├── mr_jepa/
│   ├── __init__.py
│   ├── ARCHITECTURE.md          # Detailed architecture specification
│   │
│   ├── configs/
│   │   ├── __init__.py
│   │   └── model_config.py      # All hyperparameter dataclasses
│   │
│   ├── models/
│   │   ├── __init__.py
│   │   ├── mr_jepa.py           # Main model (integrates all components)
│   │   ├── backbones.py         # Visual (DINOv3/v2) + Text (Qwen3-Embedding)
│   │   ├── evidence_memory.py   # Perceiver Resampler multimodal fusion
│   │   ├── latent_rollout.py    # K-step shared predictor + evidence gates
│   │   ├── target_encoder.py    # EMA encoder + JEPA/SIGReg/VICReg losses
│   │   └── answer_heads.py      # Discriminative (MC) + Generative (open-ended)
│   │
│   ├── data/
│   │   ├── __init__.py
│   │   ├── unified_dataset.py   # 9-benchmark unified loader with format quirks
│   │   └── data_utils.py        # Collator, dataloader factory, benchmark configs
│   │
│   ├── training/
│   │   ├── __init__.py
│   │   ├── trainer.py           # 3-phase training loop
│   │   └── phase_scheduler.py   # Phase transitions, LR scheduling
│   │
│   ├── evaluation/
│   │   ├── __init__.py
│   │   └── metrics.py           # Accuracy, ANLS, VQA Acc, Relaxed Acc
│   │
│   └── utils/
│       ├── __init__.py
│       ├── visualization.py     # Trajectory PCA, gate analysis
│       └── ablation.py          # Systematic ablation runner

├── results/                     # Training results (auto-pushed)
│   ├── hybrid_main.json
│   ├── no_jepa.json
│   ├── no_rollout.json
│   └── ...

└── checkpoints/                 # Best model checkpoints (auto-pushed)
    ├── hybrid_main_best.pt
    └── ...
```

---

## Paper Contribution

> **A world model for multimodal reasoning**: We demonstrate that modeling the evolution of a latent belief state via JEPA-style prediction improves performance on static multimodal benchmarks over single-pass baselines. The evidence-gated rollout with K=3 steps learns meaningful intermediate reasoning states, which we validate through ablation studies across 9 benchmarks. The latent trajectory is supervised by the JEPA objective against targets from an EMA target encoder, not by human chain-of-thought annotations, showing that self-supervised dynamics training transfers to discriminative reasoning tasks.

---

## Key References

1. **I-JEPA** (Assran et al., 2023), [arXiv:2301.08243](https://arxiv.org/abs/2301.08243): JEPA architecture, EMA target encoder, L2 prediction loss, narrow predictor
2. **LeWorldModel** (Maes et al., 2025), [arXiv:2603.19312](https://arxiv.org/abs/2603.19312): SIGReg anti-collapse, end-to-end JEPA
3. **Coconut** (Hao et al., 2024), [arXiv:2412.06769](https://arxiv.org/abs/2412.06769): Chain of Continuous Thought, latent reasoning
4. **DINOv3** (Meta, 2025), [arXiv:2508.10104](https://arxiv.org/abs/2508.10104): Dense SSL with RoPE + Gram anchoring
5. **SoftCoT++** (Xu et al., 2025), [arXiv:2505.11484](https://arxiv.org/abs/2505.11484): Soft chain-of-thought with contrastive learning

## License

Apache-2.0