| --- |
| tags: |
| - multimodal |
| - reasoning |
| - jepa |
| - world-model |
| - vision-language |
| license: apache-2.0 |
| --- |
| |
| # MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture |
|
|
| > A world model for multimodal reasoning that refines a latent belief state over K=3 steps using JEPA-style prediction, evidence gating, and dense visual backbones. |
|
|
| ## Key Idea |
|
|
| Traditional multimodal models produce answers in a single forward pass. MR-JEPA instead models **the evolution of a belief state** as the system reasons about a question: |
|
|
| ``` |
| z₀ (initial evidence) → z₁ (first refinement) → z₂ (deeper reasoning) → z₃ (answer) |
| ``` |
|
|
| This trajectory is supervised by a **JEPA objective**: a target encoder (EMA) generates target latent states, and the online predictor learns to predict them. The JEPA loss encourages the model to learn **meaningful intermediate reasoning states** — not just the final answer. |
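
The trajectory supervision described above can be sketched in plain Python. This is a minimal illustration, assuming the predictor emits one latent vector per rollout step and the EMA target encoder supplies a matching target per step that is treated as a constant. The function names and the `beta` threshold are hypothetical and are not taken from the repository.

```python
from typing import Sequence


def smooth_l1(pred: float, target: float, beta: float = 1.0) -> float:
    """Elementwise SmoothL1: quadratic for small errors, linear beyond beta."""
    diff = abs(pred - target)
    if diff < beta:
        return 0.5 * diff * diff / beta
    return diff - 0.5 * beta


def trajectory_jepa_loss(
    predicted: Sequence[Sequence[float]],
    targets: Sequence[Sequence[float]],
    beta: float = 1.0,
) -> float:
    """Mean SmoothL1 over every step and every latent dimension.

    predicted[k] is the online prediction for step k+1 and targets[k] is the
    target latent for the same step (assumed layout, for illustration only).
    """
    if len(predicted) != len(targets):
        raise ValueError("predicted and targets must cover the same steps")
    total, count = 0.0, 0
    for z_pred, z_tgt in zip(predicted, targets):
        if len(z_pred) != len(z_tgt):
            raise ValueError("latent dimensions must match at every step")
        for p, t in zip(z_pred, z_tgt):
            total += smooth_l1(p, t, beta)
            count += 1
    return total / count
```

Because every step contributes to the mean, an intermediate state that drifts from its target is penalized even when the final state is correct.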
|
|
| --- |
|
|
| ## Architecture |
|
|
| ``` |
| ┌──────────────┐     ┌────────────────────┐     ┌──────────────────┐     ┌───────────────┐ |
| │ DINOv3-L/16  │────▶│  Evidence Memory   │────▶│  Latent Rollout  │────▶│  Disc. Head   │ |
| │ (frozen)     │     │ (Perceiver Resampl)│     │  z₀→z₁→z₂→z₃     │     │ (MC scoring)  │ |
| └──────────────┘     └────────┬───────────┘     │  (shared block)  │     └───────────────┘ |
|                               │                 └────────┬─────────┘     ┌───────────────┐ |
| ┌──────────────┐              │                          │          ├───▶│ Gen. Decoder  │ |
| │ Qwen3-Embed  │──────────────┘                  ┌───────┴────────┐      │ (Qwen3.5-4B)  │ |
| │ 0.6B         │                                 │ Target Encoder │      └───────────────┘ |
| │ (frozen)     │                                 │ (EMA copy)     │ |
| └──────────────┘                                 └────────────────┘ |
|                           │ |
| ┌──────────────┐                                 JEPA Loss: |
| │ Phase 3 opt: │──────────┘                      SmoothL1/Cosine |
| │ PaddleOCR-VL │                                 + SIGReg (purist) |
| │ SAM 3.1      │                                 / VICReg (hybrid) |
| └──────────────┘ |
| ``` |
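
The rollout stage in the diagram can be illustrated with a toy NumPy sketch. It assumes the evidence memory has been pooled to a single vector and uses one linear layer with `tanh` as a stand-in for the shared predictor block. The class name, the gate parameterization, and the residual update are all assumptions made for illustration, not the repository's implementation.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


class SharedRolloutBlock:
    """Toy stand-in for the shared predictor: one weight set reused at every step."""

    def __init__(self, dim: int, seed: int = 0) -> None:
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(dim)
        self.w_state = rng.normal(0.0, scale, (dim, dim))
        self.w_evidence = rng.normal(0.0, scale, (dim, dim))
        self.w_gate = rng.normal(0.0, scale, (2 * dim, dim))

    def step(self, z: np.ndarray, evidence: np.ndarray) -> np.ndarray:
        # Sigmoid gate in (0, 1) per dimension: how much evidence flows into this step.
        gate = sigmoid(np.concatenate([z, evidence]) @ self.w_gate)
        update = np.tanh(z @ self.w_state + gate * (evidence @ self.w_evidence))
        return z + update  # residual refinement of the belief state


def rollout(block: SharedRolloutBlock, z0: np.ndarray,
            evidence: np.ndarray, num_steps: int = 3) -> list:
    """Return [z0, z1, ..., zK]; the same block is applied at every step."""
    trajectory = [z0]
    for _ in range(num_steps):
        trajectory.append(block.step(trajectory[-1], evidence))
    return trajectory
```

With `num_steps=0` the function returns only `z0`, which corresponds to answering from the initial evidence state without refinement.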
|
|
| --- |
|
|
| ## Component Stack |
|
|
| | Module | Primary Choice | Alternative | Notes | |
| |--------|---------------|-------------|-------| |
| | **Visual backbone** | [`timm/vit_large_patch16_dinov3.lvd1689m`](https://hf.co/timm/vit_large_patch16_dinov3.lvd1689m) — DINOv3-L/16, 1024-dim, 300M | DINOv3-B/16 (purist); DINOv2-L/14 (ablation) | Frozen Phase 1; last 6 layers unfrozen Phase 2 | |
| | **Text encoder** | [`Qwen/Qwen3-Embedding-0.6B`](https://hf.co/Qwen/Qwen3-Embedding-0.6B) — 1024-dim, 596M | Qwen3-Embedding-4B (heavier); EmbeddingGemma-300M (lighter) | Frozen Phase 1; last 4 layers unfrozen Phase 2 | |
| | **Evidence memory** | Perceiver Resampler, 64 queries, 4 cross-attn layers | Q-Former as baseline | Modality-typed tokens (visual/text/OCR/layout/chart/SAM) | |
| | **OCR / doc / charts** | [`PaddlePaddle/PaddleOCR-VL-1.5`](https://hf.co/PaddlePaddle/PaddleOCR-VL-1.5) — 958M | MinerU2.5 for heavy PDF parsing | Phase 3 only, offline token extraction | |
| | **Segmentation** | [`jetjodh/sam3.1`](https://hf.co/jetjodh/sam3.1) — SAM 3.1, non-gated mirror | SAM 2.1-Large (stable) | Phase 3 optional, offline mask extraction | |
| | **Latent rollout** | Shared transformer predictor, 6 layers, K=3 | Per-step unshared blocks (ablation) | Weight-tied across steps; sigmoid evidence gates | |
| **Target encoder** | EMA copy of evidence + rollout (momentum on a cosine schedule, 0.996→1.0) | Frozen target (ablation baseline) | EMA target design from I-JEPA |
| | **JEPA loss** | SmoothL1 + VICReg (hybrid); Cosine + SIGReg (purist) | MSE (ablation) | SIGReg emphasis in purist branch | |
| | **Disc. head** | MLP/bilinear scorer | Cross-encoder scorer (ablation) | Attention-pooled z_K × option embeddings | |
| | **Gen. decoder** | [`Qwen/Qwen3.5-4B`](https://hf.co/Qwen/Qwen3.5-4B) — 4.7B, multimodal | [`HuggingFaceTB/SmolLM3-3B`](https://hf.co/HuggingFaceTB/SmolLM3-3B) (cheaper); Gemma3-4B | Phase 3+, cross-attends to z_K + evidence | |
| | **Teacher/baseline** | InternVL3.5 / Qwen3-VL | External comparison only | NOT used as internal module | |
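
The EMA target encoder row mentions a cosine momentum schedule from 0.996 to 1.0. A minimal sketch of one common way to write such a schedule follows. The exact formula, the per-step granularity, and the function names are assumptions, and parameters are shown as flat lists of floats to keep the example dependency-free.

```python
import math
from typing import List, Sequence


def ema_momentum(step: int, total_steps: int,
                 base: float = 0.996, final: float = 1.0) -> float:
    """Cosine ramp: returns `base` at step 0 and `final` at `total_steps`."""
    if total_steps <= 0:
        return final
    progress = min(max(step / total_steps, 0.0), 1.0)
    return final - (final - base) * (math.cos(math.pi * progress) + 1.0) / 2.0


def ema_update(target: Sequence[float], online: Sequence[float],
               momentum: float) -> List[float]:
    """target <- momentum * target + (1 - momentum) * online, elementwise."""
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target, online)]
```

As the momentum approaches 1.0 the target encoder changes more and more slowly, so late in training the prediction targets are nearly stationary.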
|
|
| --- |
|
|
| ## Training Protocol |
|
|
| ### Phase 1: Reasoning Core (15–20 epochs) |
| - **Freeze** all perception (DINOv3 + Qwen3-Embedding) |
| - **Train** evidence memory + latent rollout + discriminative head |
| - Full JEPA loss + task loss |
| - LR: 3e-4, effective batch: 64 |
|
|
| ### Phase 2: Perception Fine-tuning (10 epochs) |
| - **Unfreeze** last 6 DINOv3 layers + last 4 Qwen3-Embedding layers (LR: 1e-5) |
| - Continue training reasoning core (LR: 1e-4) |
|
|
| ### Phase 3: Enriched Evidence + Generative Decoder (10 epochs) |
| - **Enable** PaddleOCR-VL tokens, SAM 3.1 masks, layout/chart tokens |
| - **Attach** Qwen3.5-4B generative decoder for open-ended answers |
| - End-to-end fine-tuning, LR: 5e-5 |
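
The three phases above can be captured as a small configuration table. The sketch below is illustrative only: `PhaseConfig`, `PHASES`, and `phase_at_epoch` are hypothetical names, the Phase 3 entry assumes that "end-to-end" means a single LR of 5e-5 for all trainable parts, and phase boundaries use the upper epoch bound of each phase.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class PhaseConfig:
    name: str
    epochs: Tuple[int, int]         # (min, max) epochs stated for the phase
    core_lr: float                  # LR for evidence memory, rollout, and heads
    perception_lr: Optional[float]  # None means the backbones stay frozen
    enriched_evidence: bool         # OCR / SAM / layout / chart tokens enabled
    generative_decoder: bool        # generative decoder attached


PHASES = (
    PhaseConfig("reasoning_core", (15, 20), 3e-4, None, False, False),
    PhaseConfig("perception_finetune", (10, 10), 1e-4, 1e-5, False, False),
    # Assumption: "end-to-end" is read as one LR shared by all trainable parts.
    PhaseConfig("enriched_generative", (10, 10), 5e-5, 5e-5, True, True),
)


def phase_at_epoch(epoch: int,
                   phases: Sequence[PhaseConfig] = PHASES) -> PhaseConfig:
    """Return the phase for a 0-based global epoch (upper epoch bounds used)."""
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    start = 0
    for phase in phases:
        end = start + phase.epochs[1]
        if epoch < end:
            return phase
        start = end
    raise ValueError(f"epoch {epoch} is past the end of the schedule")
```

Under these assumptions, epochs 0 to 19 fall in Phase 1, 20 to 29 in Phase 2, and 30 to 39 in Phase 3.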
|
|
| --- |
|
|
| ## Target Benchmarks (9) |
|
|
| | Benchmark | Type | Metric | Key Challenge | |
| |-----------|------|--------|---------------| |
| | MMMU | MC (multi-image) | Accuracy | Multi-discipline, up to 7 images | |
| | MathVista | Mixed MC/Open | Accuracy | Mathematical reasoning | |
| | ScienceQA | MC | Accuracy | Scientific diagrams, nullable images | |
| | AI2D | MC | Accuracy | Science diagram comprehension | |
| | MMBench | MC | CircularEval Acc | General visual understanding | |
| | MMStar | MC | Accuracy | Vision-dependent questions | |
| | DocVQA | Open | ANLS | Document text extraction | |
| | TextVQA | Open | VQA Accuracy | Scene text reading | |
| | ChartQA | Open | Relaxed Accuracy | Chart data extraction | |
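
The open-ended metrics in the table have compact definitions. The sketch below shows simplified per-question versions: ANLS with the usual 0.5 threshold, ChartQA-style relaxed accuracy with a 5% numeric tolerance, and the common `min(1, matches / 3)` form of VQA accuracy. Answer normalization is reduced to lowercasing and stripping, which is an assumption; official evaluation scripts apply more preprocessing.

```python
from typing import Optional, Sequence


def levenshtein(a: str, b: str) -> int:
    """Edit distance with unit-cost insert, delete, and substitute."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = curr
    return prev[-1]


def anls_score(prediction: str, references: Sequence[str],
               threshold: float = 0.5) -> float:
    """Best normalized similarity over references; zero at or above the threshold."""
    pred = prediction.strip().lower()
    best = 0.0
    for ref in references:
        ref_norm = ref.strip().lower()
        longest = max(len(pred), len(ref_norm))
        dist = levenshtein(pred, ref_norm) / longest if longest else 0.0
        best = max(best, 1.0 - dist if dist < threshold else 0.0)
    return best


def _to_float(text: str) -> Optional[float]:
    try:
        return float(text.strip())
    except ValueError:
        return None


def relaxed_match(prediction: str, target: str, tolerance: float = 0.05) -> bool:
    """Numeric answers may deviate by `tolerance` (relative); others must match exactly."""
    p, t = _to_float(prediction), _to_float(target)
    if p is not None and t is not None:
        if t == 0:
            return p == 0
        return abs(p - t) / abs(t) <= tolerance
    return prediction.strip().lower() == target.strip().lower()


def vqa_accuracy(prediction: str, human_answers: Sequence[str]) -> float:
    """Simplified VQA accuracy: full credit once three annotators agree."""
    pred = prediction.strip().lower()
    matches = sum(1 for ans in human_answers if ans.strip().lower() == pred)
    return min(1.0, matches / 3.0)
```

Benchmark-level scores are the mean of these per-question values over the evaluation split.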
|
|
| --- |
|
|
| ## Experimental Branches |
|
|
| ### Hybrid-main (competitive) |
| - DINOv3-L backbone, SmoothL1 + VICReg, K=3 |
| - Full enriched evidence in Phase 3 |
| - Target: state-of-the-art on all benchmarks |
|
|
| ### Purist-side (scientific validation) |
| - DINOv3-B backbone, Cosine + SIGReg, K=5 |
| - No enriched evidence, pure JEPA reasoning |
| - Target: demonstrate JEPA contributes beyond perception |
|
|
| --- |
|
|
| ## Ablation Experiments |
|
|
| Each experiment corresponds to a CLI flag, or a combination of flags, in `train_mrjepa.py`; `hybrid_main` is the default run with no flags. |
|
|
| | Experiment | CLI flag | Modification | Purpose | |
| |------------|----------|-------------|---------| |
| | `hybrid_main` | *(default)* | Full model | Baseline | |
| | `no_jepa` | `--no_jepa` | Remove L_JEPA, task loss only | Validate JEPA objective | |
| | `no_rollout` | `--no_rollout` | K=0, use z₀ directly | Validate iterative refinement | |
| | `no_gate` | `--no_evidence_gate` | Remove evidence gating | Validate adaptive evidence flow | |
| | `K1` / `K5` / `K7` | `--K 1/5/7` | Vary rollout depth | Find optimal depth | |
| | `dinov2_ablation` | `--backbone dinov2` | DINOv2-L/14 backbone | DINOv3 vs DINOv2 | |
| | `mse_loss` | `--loss_fn mse` | MSE (L2) JEPA loss | Original I-JEPA loss | |
| | `cosine_loss` | `--loss_fn cosine` | Cosine similarity JEPA loss | Purist-style loss | |
| | `no_sigreg` | `--no_sigreg` | Disable SIGReg anti-collapse | Test regularization | |
| | `vicreg_only` | `--no_sigreg --use_vicreg` | VICReg only | Alternative anti-collapse | |
| | `purist` | `--purist` | DINOv3-B, K=5, Cosine+SIGReg | Isolate JEPA contribution | |
|
|
| --- |
|
|
| ## Project Structure |
|
|
| ``` |
| MR-JEPA/ |
| ├── README.md                     # This file |
| ├── train_mrjepa.py               # Complete training script (CLI, all ablations) |
| ├── test_architecture.py          # Architecture validation tests (synthetic data) |
| │ |
| ├── mr_jepa/ |
| │   ├── __init__.py |
| │   ├── ARCHITECTURE.md           # Detailed architecture specification |
| │   │ |
| │   ├── configs/ |
| │   │   ├── __init__.py |
| │   │   └── model_config.py       # All hyperparameter dataclasses |
| │   │ |
| │   ├── models/ |
| │   │   ├── __init__.py |
| │   │   ├── mr_jepa.py            # Main model (integrates all components) |
| │   │   ├── backbones.py          # Visual (DINOv3/v2) + Text (Qwen3-Embedding) |
| │   │   ├── evidence_memory.py    # Perceiver Resampler multimodal fusion |
| │   │   ├── latent_rollout.py     # K-step shared predictor + evidence gates |
| │   │   ├── target_encoder.py     # EMA encoder + JEPA/SIGReg/VICReg losses |
| │   │   └── answer_heads.py       # Discriminative (MC) + Generative (open-ended) |
| │   │ |
| │   ├── data/ |
| │   │   ├── __init__.py |
| │   │   ├── unified_dataset.py    # 9-benchmark unified loader with format quirks |
| │   │   └── data_utils.py         # Collator, dataloader factory, benchmark configs |
| │   │ |
| │   ├── training/ |
| │   │   ├── __init__.py |
| │   │   ├── trainer.py            # 3-phase training loop |
| │   │   └── phase_scheduler.py    # Phase transitions, LR scheduling |
| │   │ |
| │   ├── evaluation/ |
| │   │   ├── __init__.py |
| │   │   └── metrics.py            # Accuracy, ANLS, VQA Acc, Relaxed Acc |
| │   │ |
| │   └── utils/ |
| │       ├── __init__.py |
| │       ├── visualization.py      # Trajectory PCA, gate analysis |
| │       └── ablation.py           # Systematic ablation runner |
| │ |
| ├── results/                      # Training results (auto-pushed) |
| │   ├── hybrid_main.json |
| │   ├── no_jepa.json |
| │   ├── no_rollout.json |
| │   └── ... |
| │ |
| └── checkpoints/                  # Best model checkpoints (auto-pushed) |
|     ├── hybrid_main_best.pt |
|     └── ... |
| ``` |
|
|
| --- |
|
|
| ## Paper Contribution |
|
|
| > **A world model for multimodal reasoning**: We demonstrate that modeling the evolution of a latent belief state via JEPA-style prediction improves performance on static multimodal benchmarks compared to single-pass baselines. The evidence-gated rollout with K=3 steps learns meaningful intermediate reasoning states, validated through ablation studies across 9 benchmarks. The latent trajectory is supervised by the JEPA objective against targets generated by an EMA target encoder, not by human chain-of-thought annotations, showing that self-supervised dynamics training transfers to discriminative reasoning tasks. |
|
|
| --- |
|
|
| ## Key References |
|
|
| 1. **I-JEPA** (Assran et al., 2023) — [arxiv:2301.08243](https://arxiv.org/abs/2301.08243): JEPA architecture, EMA target encoder, L2 prediction loss, narrow predictor |
| 2. **LeWorldModel** (Maes et al., 2026) — [arxiv:2603.19312](https://arxiv.org/abs/2603.19312): SIGReg anti-collapse, end-to-end JEPA |
| 3. **Coconut** (Hao et al., 2024) — [arxiv:2412.06769](https://arxiv.org/abs/2412.06769): Chain of Continuous Thought, latent reasoning |
| 4. **DINOv3** (Meta, 2025) — [arxiv:2508.10104](https://arxiv.org/abs/2508.10104): Dense SSL with RoPE + Gram anchoring |
| 5. **SoftCoT++** (Xu et al., 2025) — [arxiv:2505.11484](https://arxiv.org/abs/2505.11484): Soft chain-of-thought with contrastive learning |
|
|
| ## License |
|
|
| Apache-2.0 |
|
|