Update README with adapted component stack (DINOv3, Qwen3, SAM3.1, etc.)
Browse files
README.md
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
---
|
| 10 |
|
| 11 |
# MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture
|
|
@@ -22,23 +22,184 @@ z₀ (initial evidence) → z₁ (first refinement) → z₂ (deeper reasoning)
|
|
| 22 |
|
| 23 |
This trajectory is supervised by a **JEPA objective**: a target encoder (EMA) generates target latent states, and the online predictor learns to predict them. The JEPA loss encourages the model to learn **meaningful intermediate reasoning states** — not just the final answer.
|
| 24 |
|
|
|
|
|
|
|
| 25 |
## Architecture
|
| 26 |
|
| 27 |
```
|
| 28 |
-
┌─────────────┐ ┌──────────────┐ ┌─────────────────┐ ┌──────────┐
|
| 29 |
-
│
|
| 30 |
-
│
|
| 31 |
-
└─────────────┘
|
| 32 |
-
|
| 33 |
-
┌─────────────┐
|
| 34 |
-
│
|
| 35 |
-
│
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
│
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
```
|
| 43 |
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
tags:
|
| 3 |
+
- multimodal
|
| 4 |
+
- reasoning
|
| 5 |
+
- jepa
|
| 6 |
+
- world-model
|
| 7 |
+
- vision-language
|
| 8 |
+
license: apache-2.0
|
| 9 |
---
|
| 10 |
|
| 11 |
# MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture
|
|
|
|
| 22 |
|
| 23 |
This trajectory is supervised by a **JEPA objective**: a target encoder (EMA) generates target latent states, and the online predictor learns to predict them. The JEPA loss encourages the model to learn **meaningful intermediate reasoning states** — not just the final answer.
|
| 24 |
|
| 25 |
+
---
|
| 26 |
+
|
| 27 |
## Architecture
|
| 28 |
|
| 29 |
```
|
| 30 |
+
┌──────────────┐ ┌────────────────────┐ ┌──────────────────┐ ┌───────────────┐
|
| 31 |
+
│ DINOv3-L/16 │────▶│ Evidence Memory │────▶│ Latent Rollout │────▶│ Disc. Head │
|
| 32 |
+
│ (frozen) │ │ (Perceiver Resampl)│ │ z₀→z₁→z₂→z₃ │ │ (MC scoring) │
|
| 33 |
+
└──────────────┘ └────────┬───────────┘ │ (shared block) │ └───────────────┘
|
| 34 |
+
│ └────────┬─────────┘ ┌───────────────┐
|
| 35 |
+
┌──────────────┐ │ │ ├───▶│ Gen. Decoder │
|
| 36 |
+
│ Qwen3-Embed │──────────────┘ ┌───────┴────────┐ │ (Qwen3.5-4B) │
|
| 37 |
+
│ 0.6B │ │ Target Encoder │ └───────────────┘
|
| 38 |
+
│ (frozen) │ │ (EMA copy) │
|
| 39 |
+
└──────────────┘ └────────────────┘
|
| 40 |
+
│
|
| 41 |
+
┌──────────────┐ JEPA Loss:
|
| 42 |
+
│ Phase 3 opt: │──────────┘ SmoothL1/Cosine
|
| 43 |
+
│ PaddleOCR-VL │ + SIGReg (purist)
|
| 44 |
+
│ SAM 3.1 │ / VICReg (hybrid)
|
| 45 |
+
└──────────────┘
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## Component Stack
|
| 51 |
+
|
| 52 |
+
| Module | Primary Choice | Alternative | Notes |
|
| 53 |
+
|--------|---------------|-------------|-------|
|
| 54 |
+
| **Visual backbone** | [`timm/vit_large_patch16_dinov3.lvd1689m`](https://hf.co/timm/vit_large_patch16_dinov3.lvd1689m) — DINOv3-L/16, 1024-dim, 300M | DINOv3-B/16 (purist); DINOv2-L/14 (ablation) | Frozen Phase 1; last 6 layers unfrozen Phase 2 |
|
| 55 |
+
| **Text encoder** | [`Qwen/Qwen3-Embedding-0.6B`](https://hf.co/Qwen/Qwen3-Embedding-0.6B) — 1024-dim, 596M | Qwen3-Embedding-4B (heavier); EmbeddingGemma-300M (lighter) | Frozen Phase 1; last 4 layers unfrozen Phase 2 |
|
| 56 |
+
| **Evidence memory** | Perceiver Resampler, 64 queries, 4 cross-attn layers | Q-Former as baseline | Modality-typed tokens (visual/text/OCR/layout/chart/SAM) |
|
| 57 |
+
| **OCR / doc / charts** | [`PaddlePaddle/PaddleOCR-VL-1.5`](https://hf.co/PaddlePaddle/PaddleOCR-VL-1.5) — 958M | MinerU2.5 for heavy PDF parsing | Phase 3 only, offline token extraction |
|
| 58 |
+
| **Segmentation** | [`facebook/sam3.1`](https://hf.co/facebook/sam3.1) — SAM 3.1, 3.3GB, gated | SAM 2.1-Large (stable) | Phase 3 optional, offline mask extraction |
|
| 59 |
+
| **Latent rollout** | Shared transformer predictor, 6 layers, K=3 | Per-step unshared blocks (ablation) | Weight-tied across steps; sigmoid evidence gates |
|
| 60 |
+
| **Target encoder** | EMA copy (cosine 0.996→1.0) of evidence+rollout | Frozen target (ablation baseline) | From I-JEPA |
|
| 61 |
+
| **JEPA loss** | SmoothL1 + VICReg (hybrid); Cosine + SIGReg (purist) | MSE (ablation) | SIGReg emphasis in purist branch |
|
| 62 |
+
| **Disc. head** | MLP/bilinear scorer | Cross-encoder scorer (ablation) | Attention-pooled z_K × option embeddings |
|
| 63 |
+
| **Gen. decoder** | [`Qwen/Qwen3.5-4B`](https://hf.co/Qwen/Qwen3.5-4B) — 4.7B, multimodal | [`HuggingFaceTB/SmolLM3-3B`](https://hf.co/HuggingFaceTB/SmolLM3-3B) (cheaper); Gemma3-4B | Phase 3+, cross-attends to z_K + evidence |
|
| 64 |
+
| **Teacher/baseline** | InternVL3.5 / Qwen3-VL | External comparison only | NOT used as internal module |
|
| 65 |
+
|
| 66 |
+
---
|
| 67 |
+
|
| 68 |
+
## Training Protocol
|
| 69 |
+
|
| 70 |
+
### Phase 1: Reasoning Core (15–20 epochs)
|
| 71 |
+
- **Freeze** all perception (DINOv3 + Qwen3-Embedding)
|
| 72 |
+
- **Train** evidence memory + latent rollout + discriminative head
|
| 73 |
+
- Full JEPA loss + task loss
|
| 74 |
+
- LR: 3e-4, effective batch: 64
|
| 75 |
+
|
| 76 |
+
### Phase 2: Perception Fine-tuning (10 epochs)
|
| 77 |
+
- **Unfreeze** last 6 DINOv3 layers + last 4 Qwen3-Embedding layers (1e-5)
|
| 78 |
+
- Continue training reasoning core (1e-4)
|
| 79 |
+
|
| 80 |
+
### Phase 3: Enriched Evidence + Generative Decoder (10 epochs)
|
| 81 |
+
- **Enable** PaddleOCR-VL tokens, SAM 3.1 masks, layout/chart tokens
|
| 82 |
+
- **Attach** Qwen3.5-4B generative decoder for open-ended answers
|
| 83 |
+
- End-to-end fine-tuning, LR: 5e-5
|
| 84 |
+
|
| 85 |
+
---
|
| 86 |
+
|
| 87 |
+
## Target Benchmarks (9)
|
| 88 |
+
|
| 89 |
+
| Benchmark | Type | Metric | Key Challenge |
|
| 90 |
+
|-----------|------|--------|---------------|
|
| 91 |
+
| MMMU | MC (multi-image) | Accuracy | Multi-discipline, up to 7 images |
|
| 92 |
+
| MathVista | Mixed MC/Open | Accuracy | Mathematical reasoning |
|
| 93 |
+
| ScienceQA | MC | Accuracy | Scientific diagrams, nullable images |
|
| 94 |
+
| AI2D | MC | Accuracy | Science diagram comprehension |
|
| 95 |
+
| MMBench | MC | CircularEval Acc | General visual understanding |
|
| 96 |
+
| MMStar | MC | Accuracy | Vision-dependent questions |
|
| 97 |
+
| DocVQA | Open | ANLS | Document text extraction |
|
| 98 |
+
| TextVQA | Open | VQA Accuracy | Scene text reading |
|
| 99 |
+
| ChartQA | Open | Relaxed Accuracy | Chart data extraction |
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
## Experimental Branches
|
| 104 |
+
|
| 105 |
+
### Hybrid-main (competitive)
|
| 106 |
+
- DINOv3-L backbone, SmoothL1 + VICReg, K=3
|
| 107 |
+
- Full enriched evidence in Phase 3
|
| 108 |
+
- Target: state-of-the-art on all benchmarks
|
| 109 |
+
|
| 110 |
+
### Purist-side (scientific validation)
|
| 111 |
+
- DINOv3-B backbone, Cosine + SIGReg, K=5
|
| 112 |
+
- No enriched evidence, pure JEPA reasoning
|
| 113 |
+
- Target: demonstrate JEPA contributes beyond perception
|
| 114 |
+
|
| 115 |
+
---
|
| 116 |
+
|
| 117 |
+
## Ablation Experiments
|
| 118 |
+
|
| 119 |
+
| Experiment | Modification | Purpose |
|
| 120 |
+
|------------|-------------|---------|
|
| 121 |
+
| `hybrid_main` | Full model | Baseline |
|
| 122 |
+
| `no_jepa` | Remove L_JEPA, task loss only | Validate JEPA objective |
|
| 123 |
+
| `no_rollout` | K=0, use z₀ directly | Validate iterative refinement |
|
| 124 |
+
| `no_gate` | Remove evidence gating | Validate adaptive evidence flow |
|
| 125 |
+
| `K1` / `K5` / `K7` | Vary rollout depth | Find optimal depth |
|
| 126 |
+
| `dinov2_ablation` | DINOv2-L/14 backbone | DINOv3 vs DINOv2 |
|
| 127 |
+
| `purist` | DINOv3-B, no enriched ev., SIGReg | Isolate JEPA contribution |
|
| 128 |
+
| `mse_loss` / `cosine_loss` | Alternative JEPA losses | Loss function ablation |
|
| 129 |
+
|
| 130 |
+
---
|
| 131 |
+
|
| 132 |
+
## Project Structure
|
| 133 |
+
|
| 134 |
+
```
|
| 135 |
+
MR-JEPA/
|
| 136 |
+
├── README.md # This file
|
| 137 |
+
├── train_mrjepa.py # Complete training script (CLI, all ablations)
|
| 138 |
+
├── test_architecture.py # Architecture validation tests (synthetic data)
|
| 139 |
+
│
|
| 140 |
+
├── mr_jepa/
|
| 141 |
+
│ ├── __init__.py
|
| 142 |
+
│ ├── ARCHITECTURE.md # Detailed architecture specification
|
| 143 |
+
│ │
|
| 144 |
+
│ ├── configs/
|
| 145 |
+
│ │ ├── __init__.py
|
| 146 |
+
│ │ └── model_config.py # All hyperparameter dataclasses
|
| 147 |
+
│ │
|
| 148 |
+
│ ├── models/
|
| 149 |
+
│ │ ├── __init__.py
|
| 150 |
+
│ │ ├── mr_jepa.py # Main model (integrates all components)
|
| 151 |
+
│ │ ├── backbones.py # Visual (DINOv3/v2) + Text (Qwen3-Embedding)
|
| 152 |
+
│ │ ├── evidence_memory.py # Perceiver Resampler multimodal fusion
|
| 153 |
+
│ │ ├── latent_rollout.py # K-step shared predictor + evidence gates
|
| 154 |
+
│ │ ├── target_encoder.py # EMA encoder + JEPA/SIGReg/VICReg losses
|
| 155 |
+
│ │ └── answer_heads.py # Discriminative (MC) + Generative (open-ended)
|
| 156 |
+
│ │
|
| 157 |
+
│ ├── data/
|
| 158 |
+
│ │ ├── __init__.py
|
| 159 |
+
│ │ ├── unified_dataset.py # 9-benchmark unified loader with format quirks
|
| 160 |
+
│ │ └── data_utils.py # Collator, dataloader factory, benchmark configs
|
| 161 |
+
│ │
|
| 162 |
+
│ ├── training/
|
| 163 |
+
│ │ ├── __init__.py
|
| 164 |
+
│ │ ├── trainer.py # 3-phase training loop
|
| 165 |
+
│ │ └── phase_scheduler.py # Phase transitions, LR scheduling
|
| 166 |
+
│ ��
|
| 167 |
+
│ ├── evaluation/
|
| 168 |
+
│ │ ├── __init__.py
|
| 169 |
+
│ │ └── metrics.py # Accuracy, ANLS, VQA Acc, Relaxed Acc
|
| 170 |
+
│ │
|
| 171 |
+
│ └── utils/
|
| 172 |
+
│ ├── __init__.py
|
| 173 |
+
│ ├── visualization.py # Trajectory PCA, gate analysis
|
| 174 |
+
│ └── ablation.py # Systematic ablation runner
|
| 175 |
+
│
|
| 176 |
+
├── results/ # Training results (auto-pushed)
|
| 177 |
+
│ ├── hybrid_main.json
|
| 178 |
+
│ ├── no_jepa.json
|
| 179 |
+
│ ├── no_rollout.json
|
| 180 |
+
│ └── ...
|
| 181 |
+
│
|
| 182 |
+
└── checkpoints/ # Best model checkpoints (auto-pushed)
|
| 183 |
+
├── hybrid_main_best.pt
|
| 184 |
+
└── ...
|
| 185 |
```
|
| 186 |
|
| 187 |
+
---
|
| 188 |
+
|
| 189 |
+
## Paper Contribution
|
| 190 |
+
|
| 191 |
+
> **A world model for multimodal reasoning**: We demonstrate that modeling the evolution of a latent belief state via JEPA-style prediction improves performance on static multimodal benchmarks compared to single-pass baselines. The evidence-gated rollout with K=3 steps learns meaningful intermediate reasoning states, validated through ablation studies across 9 benchmarks. The JEPA objective (not human chain-of-thought) supervises a latent trajectory generated by an EMA target encoder, showing that self-supervised dynamics training transfers to discriminative reasoning tasks.
|
| 192 |
+
|
| 193 |
+
---
|
| 194 |
+
|
| 195 |
+
## Key References
|
| 196 |
+
|
| 197 |
+
1. **I-JEPA** (Assran et al., 2023) — [arxiv:2301.08243](https://arxiv.org/abs/2301.08243): JEPA architecture, EMA target, L2 loss, narrow predictor
|
| 198 |
+
2. **LeWorldModel** (Maes et al., 2025) — [arxiv:2603.19312](https://arxiv.org/abs/2603.19312): SIGReg anti-collapse, end-to-end JEPA
|
| 199 |
+
3. **Coconut** (Yu et al., 2024) — [arxiv:2412.06769](https://arxiv.org/abs/2412.06769): Chain of Continuous Thought, latent reasoning
|
| 200 |
+
4. **DINOv3** (Meta, 2025) — [arxiv:2508.10104](https://arxiv.org/abs/2508.10104): Dense SSL with RoPE + Gram anchoring
|
| 201 |
+
5. **SoftCoT++** (Xu et al., 2025) — [arxiv:2505.11484](https://arxiv.org/abs/2505.11484): Soft chain-of-thought with contrastive learning
|
| 202 |
+
|
| 203 |
+
## License
|
| 204 |
+
|
| 205 |
+
Apache-2.0
|