---
tags:
- multimodal
- reasoning
- jepa
- world-model
- vision-language
license: apache-2.0
---

# MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture

> A world model for multimodal reasoning that refines a latent belief state over K=3 steps using JEPA-style prediction, evidence gating, and dense visual backbones.

## Key Idea

Traditional multimodal models produce answers in a single forward pass. MR-JEPA instead models **the evolution of a belief state** as the system reasons about a question:

```
z₀ (initial evidence) → z₁ (first refinement) → z₂ (deeper reasoning) → z₃ (answer)
```

This trajectory is supervised by a **JEPA objective**: a target encoder (an EMA copy of the online network) produces target latent states, and the online predictor learns to predict them. The JEPA loss therefore pushes the model to learn **meaningful intermediate reasoning states**, not just the final answer.

---

## Architecture

```
┌──────────────┐     ┌────────────────────┐     ┌──────────────────┐     ┌───────────────┐
│ DINOv3-L/16  │────▶│  Evidence Memory   │────▶│  Latent Rollout  │────▶│  Disc. Head   │
│   (frozen)   │     │ (Perceiver Resampl)│     │   z₀→z₁→z₂→z₃    │     │ (MC scoring)  │
└──────────────┘     └────────┬───────────┘     │  (shared block)  │     └───────────────┘
                              │                 └────────┬─────────┘     ┌───────────────┐
┌──────────────┐              │                          ├──────────────▶│ Gen. Decoder  │
│ Qwen3-Embed  │──────────────┤                  ┌───────┴────────┐      │ (Qwen3.5-4B)  │
│     0.6B     │              │                  │ Target Encoder │      └───────────────┘
│   (frozen)   │              │                  │   (EMA copy)   │
└──────────────┘              │                  └────────────────┘
                              │                  JEPA Loss:
┌──────────────┐              │                  SmoothL1/Cosine
│ Phase 3 opt: │──────────────┘                  + SIGReg (purist)
│ PaddleOCR-VL │                                 / VICReg (hybrid)
│   SAM 3.1    │
└──────────────┘
```

---

## Component Stack

| Module | Primary Choice | Alternative | Notes |
|--------|---------------|-------------|-------|
| **Visual backbone** | [`timm/vit_large_patch16_dinov3.lvd1689m`](https://hf.co/timm/vit_large_patch16_dinov3.lvd1689m): DINOv3-L/16, 1024-dim, 300M params | DINOv3-B/16 (purist); DINOv2-L/14 (ablation) | Frozen in Phase 1; last 6 layers unfrozen in Phase 2 |
| **Text encoder** | [`Qwen/Qwen3-Embedding-0.6B`](https://hf.co/Qwen/Qwen3-Embedding-0.6B): 1024-dim, 596M params | Qwen3-Embedding-4B (heavier); EmbeddingGemma-300M (lighter) | Frozen in Phase 1; last 4 layers unfrozen in Phase 2 |
| **Evidence memory** | Perceiver Resampler, 64 queries, 4 cross-attention layers | Q-Former (baseline) | Modality-typed tokens (visual/text/OCR/layout/chart/SAM) |
| **OCR / docs / charts** | [`PaddlePaddle/PaddleOCR-VL-1.5`](https://hf.co/PaddlePaddle/PaddleOCR-VL-1.5): 958M params | MinerU2.5 for heavy PDF parsing | Phase 3 only; tokens extracted offline |
| **Segmentation** | [`jetjodh/sam3.1`](https://hf.co/jetjodh/sam3.1): SAM 3.1, non-gated mirror | SAM 2.1-Large (stable) | Optional in Phase 3; masks extracted offline |
| **Latent rollout** | Shared transformer predictor, 6 layers, K=3 | Per-step unshared blocks (ablation) | Weight-tied across steps; sigmoid evidence gates |
| **Target encoder** | EMA copy of evidence memory + rollout (cosine momentum schedule 0.996→1.0) | Frozen target (ablation baseline) | Follows I-JEPA |
| **JEPA loss** | SmoothL1 + VICReg (hybrid); Cosine + SIGReg (purist) | MSE (ablation) | SIGReg emphasized in the purist branch |
| **Disc. head** | MLP/bilinear scorer | Cross-encoder scorer (ablation) | Attention-pooled z_K × option embeddings |
| **Gen. decoder** | [`Qwen/Qwen3.5-4B`](https://hf.co/Qwen/Qwen3.5-4B): 4.7B params, multimodal | [`HuggingFaceTB/SmolLM3-3B`](https://hf.co/HuggingFaceTB/SmolLM3-3B) (cheaper); Gemma3-4B | Phase 3 onward; cross-attends to z_K + evidence |
| **Teacher/baseline** | InternVL3.5 / Qwen3-VL | External comparison only | NOT used as an internal module |

---

## Training Protocol

### Phase 1: Reasoning Core (15–20 epochs)

- **Freeze** all perception modules (DINOv3 + Qwen3-Embedding)
- **Train** evidence memory + latent rollout + discriminative head
- Loss: full JEPA loss + task loss
- LR: 3e-4, effective batch size: 64

### Phase 2: Perception Fine-tuning (10 epochs)

- **Unfreeze** the last 6 DINOv3 layers + the last 4 Qwen3-Embedding layers (LR 1e-5)
- Continue training the reasoning core (LR 1e-4)

### Phase 3: Enriched Evidence + Generative Decoder (10 epochs)

- **Enable** PaddleOCR-VL tokens, SAM 3.1 masks, and layout/chart tokens
- **Attach** the Qwen3.5-4B generative decoder for open-ended answers
- End-to-end fine-tuning, LR: 5e-5

---

## Target Benchmarks (9)

| Benchmark | Type | Metric | Key Challenge |
|-----------|------|--------|---------------|
| MMMU | MC (multi-image) | Accuracy | Multi-discipline, up to 7 images |
| MathVista | Mixed MC/Open | Accuracy | Mathematical reasoning |
| ScienceQA | MC | Accuracy | Scientific diagrams, nullable images |
| AI2D | MC | Accuracy | Science diagram comprehension |
| MMBench | MC | CircularEval Acc | General visual understanding |
| MMStar | MC | Accuracy | Vision-dependent questions |
| DocVQA | Open | ANLS | Document text extraction |
| TextVQA | Open | VQA Accuracy | Scene text reading |
| ChartQA | Open | Relaxed Accuracy | Chart data extraction |

---

## Experimental Branches

### Hybrid-main (competitive)

- DINOv3-L backbone, SmoothL1 + VICReg, K=3
- Full enriched evidence in Phase 3
- Target: state-of-the-art on all benchmarks

### Purist-side (scientific validation)

- DINOv3-B backbone, Cosine + SIGReg, K=5
- No enriched evidence; pure JEPA reasoning
- Target: demonstrate that the JEPA objective contributes beyond perception

---

## Ablation Experiments

Each experiment maps 1:1 to a CLI flag in `train_mrjepa.py`.

| Experiment | CLI flag | Modification | Purpose |
|------------|----------|-------------|---------|
| `hybrid_main` | *(default)* | Full model | Baseline |
| `no_jepa` | `--no_jepa` | Remove L_JEPA, task loss only | Validate JEPA objective |
| `no_rollout` | `--no_rollout` | K=0, use z₀ directly | Validate iterative refinement |
| `no_gate` | `--no_evidence_gate` | Remove evidence gating | Validate adaptive evidence flow |
| `K1` / `K5` / `K7` | `--K 1/5/7` | Vary rollout depth | Find optimal depth |
| `dinov2_ablation` | `--backbone dinov2` | DINOv2-L/14 backbone | DINOv3 vs DINOv2 |
| `mse_loss` | `--loss_fn mse` | MSE (L2) JEPA loss | Original I-JEPA loss |
| `cosine_loss` | `--loss_fn cosine` | Cosine similarity JEPA loss | Purist-style loss |
| `no_sigreg` | `--no_sigreg` | Disable SIGReg anti-collapse | Test regularization |
| `vicreg_only` | `--no_sigreg --use_vicreg` | VICReg only | Alternative anti-collapse |
| `purist` | `--purist` | DINOv3-B, K=5, Cosine+SIGReg | Isolate JEPA contribution |

---

## Project Structure

```
MR-JEPA/
├── README.md                  # This file
├── train_mrjepa.py            # Complete training script (CLI, all ablations)
├── test_architecture.py       # Architecture validation tests (synthetic data)
│
├── mr_jepa/
│   ├── __init__.py
│   ├── ARCHITECTURE.md        # Detailed architecture specification
│   │
│   ├── configs/
│   │   ├── __init__.py
│   │   └── model_config.py    # All hyperparameter dataclasses
│   │
│   ├── models/
│   │   ├── __init__.py
│   │   ├── mr_jepa.py         # Main model (integrates all components)
│   │   ├── backbones.py       # Visual (DINOv3/v2) + Text (Qwen3-Embedding)
│   │   ├── evidence_memory.py # Perceiver Resampler multimodal fusion
│   │   ├── latent_rollout.py  # K-step shared predictor + evidence gates
│   │   ├── target_encoder.py  # EMA encoder + JEPA/SIGReg/VICReg losses
│   │   └── answer_heads.py    # Discriminative (MC) + Generative (open-ended)
│   │
│   ├── data/
│   │   ├── __init__.py
│   │   ├── unified_dataset.py # 9-benchmark unified loader with format quirks
│   │   └── data_utils.py      # Collator, dataloader factory, benchmark configs
│   │
│   ├── training/
│   │   ├── __init__.py
│   │   ├── trainer.py         # 3-phase training loop
│   │   └── phase_scheduler.py # Phase transitions, LR scheduling
│   │
│   ├── evaluation/
│   │   ├── __init__.py
│   │   └── metrics.py         # Accuracy, ANLS, VQA Acc, Relaxed Acc
│   │
│   └── utils/
│       ├── __init__.py
│       ├── visualization.py   # Trajectory PCA, gate analysis
│       └── ablation.py        # Systematic ablation runner
│
├── results/                   # Training results (auto-pushed)
│   ├── hybrid_main.json
│   ├── no_jepa.json
│   ├── no_rollout.json
│   └── ...
│
└── checkpoints/               # Best model checkpoints (auto-pushed)
    ├── hybrid_main_best.pt
    └── ...
```

---

## Paper Contribution

> **A world model for multimodal reasoning**: We demonstrate that modeling the evolution of a latent belief state via JEPA-style prediction improves performance on static multimodal benchmarks compared to single-pass baselines. The evidence-gated rollout with K=3 steps learns meaningful intermediate reasoning states, validated through ablation studies across 9 benchmarks. The latent trajectory is supervised by the JEPA objective, with targets produced by an EMA target encoder, rather than by human-written chain-of-thought. This shows that self-supervised dynamics training transfers to discriminative reasoning tasks.

---

## Key References

1. **I-JEPA** (Assran et al., 2023), [arxiv:2301.08243](https://arxiv.org/abs/2301.08243): JEPA architecture, EMA target encoder, L2 prediction loss, narrow predictor
2. **LeWorldModel** (Maes et al., 2026), [arxiv:2603.19312](https://arxiv.org/abs/2603.19312): SIGReg anti-collapse, end-to-end JEPA
3. **Coconut** (Hao et al., 2024), [arxiv:2412.06769](https://arxiv.org/abs/2412.06769): Chain of Continuous Thought, latent reasoning
4. **DINOv3** (Meta, 2025), [arxiv:2508.10104](https://arxiv.org/abs/2508.10104): Dense SSL with RoPE + Gram anchoring
5. **SoftCoT++** (Xu et al., 2025), [arxiv:2505.11484](https://arxiv.org/abs/2505.11484): Soft chain-of-thought with contrastive learning

## License

Apache-2.0
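
## Sketch: EMA Target-Encoder Momentum

The Component Stack describes the target encoder as an EMA copy of the evidence memory and rollout. Its momentum follows a cosine schedule from 0.996 to 1.0, as in I-JEPA. Below is one common way to write that schedule and the EMA step. It assumes a cosine ramp over a known total step count, and it uses plain Python lists as stand-ins for parameter tensors. `ema_momentum` and `ema_update` are hypothetical helper names, not functions taken from `train_mrjepa.py`.

```python
import math


def ema_momentum(step: int, total_steps: int,
                 m_start: float = 0.996, m_end: float = 1.0) -> float:
    """Cosine ramp of the EMA momentum from m_start (step 0) to m_end (last step).

    Hypothetical helper; the exact schedule used by MR-JEPA is an assumption.
    """
    progress = min(max(step / total_steps, 0.0), 1.0)
    # (1 + cos(pi * progress)) / 2 falls smoothly from 1 to 0 over training.
    return m_end - (m_end - m_start) * (1.0 + math.cos(math.pi * progress)) / 2.0


def ema_update(target: list[float], online: list[float], momentum: float) -> list[float]:
    """One EMA step: target <- momentum * target + (1 - momentum) * online.

    Lists stand in for parameter tensors. In an autograd framework this update
    would run with gradient tracking disabled, so the target receives no gradients.
    """
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target, online)]


if __name__ == "__main__":
    total = 1000
    target, online = [0.0, 0.0], [1.0, -1.0]
    for step in range(total):
        target = ema_update(target, online, ema_momentum(step, total))
    print(ema_momentum(0, total), ema_momentum(total, total), target)
```

As the momentum approaches 1.0 late in training, the target encoder drifts more and more slowly. This gives the online predictor a stable set of latent targets to regress onto.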
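
## Sketch: Per-Step JEPA Loss

The JEPA objective compares the online rollout's predicted states with the EMA target's states at every rollout step. The sketch below averages a per-step distance over steps 1..K, using one of two distances:

- SmoothL1 for the hybrid branch.
- A cosine distance (1 - cosine similarity) for the purist branch.

It assumes one pooled latent vector per step and omits the VICReg/SIGReg anti-collapse terms. `smooth_l1`, `cosine_distance` and `jepa_rollout_loss` are hypothetical names.

```python
import math

Vector = list[float]


def smooth_l1(pred: Vector, target: Vector, beta: float = 1.0) -> float:
    """Mean Smooth L1: quadratic for |diff| < beta, linear beyond it."""
    total = 0.0
    for p, t in zip(pred, target):
        diff = abs(p - t)
        total += 0.5 * diff * diff / beta if diff < beta else diff - 0.5 * beta
    return total / len(pred)


def cosine_distance(pred: Vector, target: Vector, eps: float = 1e-8) -> float:
    """1 - cosine similarity: 0 for aligned vectors, 2 for opposite ones."""
    dot = sum(p * t for p, t in zip(pred, target))
    norms = math.sqrt(sum(p * p for p in pred)) * math.sqrt(sum(t * t for t in target))
    return 1.0 - dot / max(norms, eps)


def jepa_rollout_loss(predicted: list[Vector], targets: list[Vector],
                      loss_fn: str = "smooth_l1") -> float:
    """Average the per-step loss between online predictions and EMA targets.

    predicted[k] is the online state after rollout step k + 1; targets[k] is the
    matching EMA-target state, which training would treat as a constant (stop-gradient).
    """
    if len(predicted) != len(targets) or not predicted:
        raise ValueError("need one target state per predicted rollout step")
    distance = {"smooth_l1": smooth_l1, "cosine": cosine_distance}[loss_fn]
    return sum(distance(p, t) for p, t in zip(predicted, targets)) / len(predicted)
```

According to the Component Stack table, each branch would add its own anti-collapse term: VICReg alongside SmoothL1 in the hybrid branch, and SIGReg alongside the cosine distance in the purist branch. Both are left out here to keep the sketch focused on the prediction term.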
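
## Sketch: MMBench CircularEval

The Target Benchmarks table scores MMBench with CircularEval. Under CircularEval, a multiple-choice question counts as correct only if the model picks the right answer under every circular shift of the option list, which penalizes position bias. Below is a minimal sketch, assuming a predictor callable that returns the index of its chosen option. The `circular_eval` and `circular_accuracy` names and signatures are hypothetical.

```python
from typing import Callable, Sequence

# A predictor maps (question, options) to the index of the chosen option.
Predictor = Callable[[str, Sequence[str]], int]


def circular_eval(question: str, options: Sequence[str], answer_idx: int,
                  predict: Predictor) -> bool:
    """True only if `predict` is right for every circular shift of `options`."""
    n = len(options)
    for shift in range(n):
        # Rotating left by `shift` moves original option i to position (i - shift) % n.
        rotated = list(options[shift:]) + list(options[:shift])
        if predict(question, rotated) != (answer_idx - shift) % n:
            return False
    return True


def circular_accuracy(items: Sequence[tuple[str, Sequence[str], int]],
                      predict: Predictor) -> float:
    """Fraction of (question, options, answer_idx) items that pass CircularEval."""
    results = [circular_eval(q, opts, ans, predict) for q, opts, ans in items]
    return sum(results) / len(results)
```

A predictor that always answers "A" can be right on the unshifted question by luck but fails as soon as the correct option moves. A predictor that actually reads the option contents passes every shift.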
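
## Sketch: ANLS for DocVQA

DocVQA is scored with ANLS (Average Normalized Levenshtein Similarity), computed in four steps:

1. Compare each prediction with every accepted answer.
2. Turn the normalized edit distance NL into a score of 1 - NL when NL is below a threshold τ (0.5 by convention), and 0 otherwise.
3. Keep the best score per question.
4. Average over questions.

The sketch assumes simple lowercase-and-whitespace normalization. `levenshtein` and `anls` are hypothetical helpers, not necessarily what `mr_jepa/evaluation/metrics.py` contains.

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance via the classic two-row dynamic program."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,               # deletion
                            curr[j - 1] + 1,           # insertion
                            prev[j - 1] + (ca != cb)))  # substitution
        prev = curr
    return prev[-1]


def _normalize(text: str) -> str:
    # Assumed normalization: lowercase and collapse whitespace.
    return " ".join(text.lower().split())


def anls(predictions: list[str], references: list[list[str]], tau: float = 0.5) -> float:
    """Mean over questions of the best thresholded similarity to any accepted answer."""
    scores = []
    for pred, golds in zip(predictions, references):
        p = _normalize(pred)
        best = 0.0
        for gold in golds:
            g = _normalize(gold)
            denom = max(len(p), len(g))
            nl = levenshtein(p, g) / denom if denom else 0.0
            best = max(best, 1.0 - nl if nl < tau else 0.0)
        scores.append(best)
    return sum(scores) / len(scores)
```

The threshold makes the metric forgiving of small OCR slips ("helo" vs "hello" scores 0.8) while giving no partial credit to answers that are mostly wrong.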
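
## Sketch: Relaxed Accuracy for ChartQA

ChartQA uses relaxed accuracy. A numeric answer counts as correct if it lies within 5% relative error of the reference; non-numeric answers need an exact, case-insensitive match. The sketch follows that convention. Two details are assumptions carried over from commonly used evaluation scripts rather than details stated in this README:

- A trailing `%` is treated as a percentage.
- A zero reference falls back to string matching.

The function names are hypothetical.

```python
from typing import Optional


def _to_float(text: str) -> Optional[float]:
    """Parse a number, treating a trailing '%' as a percentage; None if not numeric."""
    text = text.strip()
    try:
        if text.endswith("%"):
            return float(text.rstrip("%")) / 100.0
        return float(text)
    except ValueError:
        return None


def relaxed_match(prediction: str, target: str, max_relative_change: float = 0.05) -> bool:
    pred_num, target_num = _to_float(prediction), _to_float(target)
    if pred_num is not None and target_num:  # both numeric, non-zero reference
        return abs(pred_num - target_num) / abs(target_num) <= max_relative_change
    # Non-numeric (or zero) references fall back to case-insensitive exact match.
    return prediction.strip().lower() == target.strip().lower()


def relaxed_accuracy(predictions: list[str], targets: list[str]) -> float:
    """Fraction of predictions that relaxed-match their reference."""
    return sum(relaxed_match(p, t) for p, t in zip(predictions, targets)) / len(targets)
```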
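
## Sketch: VQA Accuracy for TextVQA

TextVQA uses the standard VQA accuracy, where each question comes with 10 human answers. Against any 9 of the 10 answers, a prediction scores min(1, matches / 3); the final score averages this over the 10 leave-one-out subsets. The sketch implements that averaging. However, it replaces the official answer normalization (punctuation, articles, number words) with plain lowercasing and whitespace cleanup, so its scores may differ from the official evaluator. `vqa_accuracy` is a hypothetical name.

```python
def _normalize(answer: str) -> str:
    # Simplified normalization; the official VQA evaluator does considerably more.
    return " ".join(answer.lower().split())


def vqa_accuracy(prediction: str, human_answers: list[str]) -> float:
    """min(1, matches / 3), averaged over all leave-one-out subsets of human_answers."""
    pred = _normalize(prediction)
    answers = [_normalize(a) for a in human_answers]
    scores = []
    for i in range(len(answers)):
        others = answers[:i] + answers[i + 1:]  # drop one annotator
        matches = sum(a == pred for a in others)
        scores.append(min(1.0, matches / 3.0))
    return sum(scores) / len(scores)
```

As a worked case, suppose exactly 3 of 10 annotators agree with the prediction. The 3 subsets that drop an agreeing annotator score 2/3, and the other 7 score 1, for an average of (3 × 2/3 + 7) / 10 = 0.9.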