JorgeAV committed on
Commit b10dd74 · verified
1 Parent(s): 9ad9d20

Update README with adapted component stack (DINOv3, Qwen3, SAM3.1, etc.)

Files changed (1)
  1. README.md +183 -22
README.md CHANGED
@@ -1,11 +1,11 @@
1
  ---
2
- title: ml-intern sandbox
3
- emoji: 🌍
4
- colorFrom: gray
5
- colorTo: blue
6
- sdk: docker
7
- app_port: 7860
8
- pinned: false
9
  ---
10
 
11
  # MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture
@@ -22,23 +22,184 @@ z₀ (initial evidence) → z₁ (first refinement) → z₂ (deeper reasoning)
22
 
23
  This trajectory is supervised by a **JEPA objective**: a target encoder (EMA) generates target latent states, and the online predictor learns to predict them. The JEPA loss encourages the model to learn **meaningful intermediate reasoning states** — not just the final answer.
24
 
25
  ## Architecture
26
 
27
  ```
28
- ┌─────────────┐     ┌──────────────┐     ┌─────────────────┐     ┌──────────┐
29
- │ DINOv2/v3   │────▶│ Evidence     │────▶│ Latent Rollout  │────▶│ Answer   │
30
- │ (frozen)    │     │ Memory       │     │ z₀→z₁→z₂→z₃     │     │ Heads    │
31
- └─────────────┘     │ (Perceiver)  │     │ (shared block)  │     └──────────┘
32
-                     └──────┬───────┘     └────────┬────────┘
33
- ┌─────────────┐            │                      │
34
- │ DeBERTa-v3  │────────────┘              ┌───────┴────────┐
35
- │ (frozen)    │                           │ Target Encoder │
36
- └─────────────┘                           │ (EMA copy)     │
37
-                                           └────────────────┘
38
- ┌─────────────┐
39
- │ OCR/Layout/ │──────────JEPA Loss: L₂ + SIGReg
40
- │ Chart/SAM   │ (Phase 3)
41
- └─────────────┘
42
  ```
43
 
44
- See `mr_jepa/ARCHITECTURE.md` for the complete specification.
 
1
  ---
2
+ tags:
3
+ - multimodal
4
+ - reasoning
5
+ - jepa
6
+ - world-model
7
+ - vision-language
8
+ license: apache-2.0
9
  ---
10
 
11
  # MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture
 
22
 
23
  This trajectory is supervised by a **JEPA objective**: a target encoder (EMA) generates target latent states, and the online predictor learns to predict them. The JEPA loss encourages the model to learn **meaningful intermediate reasoning states** — not just the final answer.
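
As a rough illustration of that objective, the sketch below computes a Smooth L1 loss between predicted and target latent states and averages it over the rollout steps. Plain Python lists stand in for tensors. The function names, the per-step averaging, and the choice of Smooth L1 (one of the losses listed under Component Stack) are assumptions made for illustration, not the repository's implementation; in a real training loop the targets would come from the EMA encoder with gradients stopped.

```python
def smooth_l1(pred, target, beta=1.0):
    """Smooth L1 (Huber-style) loss, averaged over the elements of one vector."""
    total = 0.0
    for p, t in zip(pred, target):
        diff = abs(p - t)
        # Quadratic near zero, linear for large errors.
        total += 0.5 * diff * diff / beta if diff < beta else diff - 0.5 * beta
    return total / len(pred)


def jepa_trajectory_loss(predicted_states, target_states):
    """Average the per-step loss over z_1..z_K; targets are treated as constants."""
    assert len(predicted_states) == len(target_states)
    step_losses = [smooth_l1(p, t) for p, t in zip(predicted_states, target_states)]
    return sum(step_losses) / len(step_losses)


# Two rollout steps, two latent dimensions each (toy numbers).
predicted = [[0.0, 0.5], [1.0, 2.0]]
targets = [[0.0, 1.0], [1.0, 4.0]]
loss = jepa_trajectory_loss(predicted, targets)
print(loss)  # 0.40625
```

Averaging over every step, rather than scoring only the final state, is what pushes the intermediate states to carry information; whether the project weights all steps equally is not stated here.
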
24
 
25
+ ---
26
+
27
  ## Architecture
28
 
29
  ```
30
+ ┌─────────────┐     ┌────────────────────┐     ┌─────────────────┐     ┌───────────────┐
31
+ │ DINOv3-L/16 │────▶│ Evidence Memory    │────▶│ Latent Rollout  │────▶│ Disc. Head    │
32
+ │ (frozen)    │     │ (Perceiver Resampl)│     │ z₀→z₁→z₂→z₃     │     │ (MC scoring)  │
33
+ └─────────────┘     └────────┬───────────┘     │ (shared block)  │     └───────────────┘
34
+                              │                 └────────┬────────┘     ┌───────────────┐
35
+ ┌─────────────┐              │                          ├─────────────▶│ Gen. Decoder  │
36
+ │ Qwen3-Embed │──────────────┘                  ┌───────┴────────┐     │ (Qwen3.5-4B) │
37
+ │ 0.6B        │                                 │ Target Encoder │     └───────────────┘
38
+ │ (frozen)    │                                 │ (EMA copy)     │
39
+ └─────────────┘                                 └────────────────┘
40
+
41
+ ┌──────────────┐                                JEPA Loss:
42
+ │ Phase 3 opt: │──────────┘                     SmoothL1/Cosine
43
+ │ PaddleOCR-VL │                                + SIGReg (purist)
44
+ │ SAM 3.1      │                                / VICReg (hybrid)
45
+ └──────────────┘
46
+ ```
47
+
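
The rollout in the diagram (z₀→z₁→z₂→z₃ through one shared block) can be pictured with a toy sketch like the one below. It is only meant to show the control flow: the same block is reused at every step, and a sigmoid gate scales how much evidence enters each update. The gate formula, the `tanh` stand-in for the predictor, and all function names are hypothetical; the actual block is described as a 6-layer transformer.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def shared_block(state, gated_evidence, scale=0.5):
    """Toy stand-in for the shared predictor block (hypothetical, not a transformer)."""
    return [math.tanh(scale * (s + e)) for s, e in zip(state, gated_evidence)]


def latent_rollout(z0, evidence, k_steps=3):
    """Apply the same block K times and return the trajectory [z0, z1, ..., zK]."""
    states = [list(z0)]
    z = list(z0)
    for _ in range(k_steps):
        # Per-dimension gate in (0, 1) deciding how much evidence flows in this step.
        gates = [sigmoid(s * e) for s, e in zip(z, evidence)]
        gated = [g * e for g, e in zip(gates, evidence)]
        update = shared_block(z, gated)
        # Residual update: the state is refined rather than replaced.
        z = [s + u for s, u in zip(z, update)]
        states.append(z)
    return states


trajectory = latent_rollout(z0=[0.1, -0.2, 0.3], evidence=[1.0, 0.5, -1.0], k_steps=3)
print(len(trajectory))  # 4 states: z0 plus three refinements
```

Because the block is weight-tied, changing `k_steps` alters the depth of reasoning without adding parameters, which is what the `K1` / `K5` / `K7` ablations vary.
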
48
+ ---
49
+
50
+ ## Component Stack
51
+
52
+ | Module | Primary Choice | Alternative | Notes |
53
+ |--------|---------------|-------------|-------|
54
+ | **Visual backbone** | [`timm/vit_large_patch16_dinov3.lvd1689m`](https://hf.co/timm/vit_large_patch16_dinov3.lvd1689m) — DINOv3-L/16, 1024-dim, 300M | DINOv3-B/16 (purist); DINOv2-L/14 (ablation) | Frozen Phase 1; last 6 layers unfrozen Phase 2 |
55
+ | **Text encoder** | [`Qwen/Qwen3-Embedding-0.6B`](https://hf.co/Qwen/Qwen3-Embedding-0.6B) — 1024-dim, 596M | Qwen3-Embedding-4B (heavier); EmbeddingGemma-300M (lighter) | Frozen Phase 1; last 4 layers unfrozen Phase 2 |
56
+ | **Evidence memory** | Perceiver Resampler, 64 queries, 4 cross-attn layers | Q-Former as baseline | Modality-typed tokens (visual/text/OCR/layout/chart/SAM) |
57
+ | **OCR / doc / charts** | [`PaddlePaddle/PaddleOCR-VL-1.5`](https://hf.co/PaddlePaddle/PaddleOCR-VL-1.5) — 958M | MinerU2.5 for heavy PDF parsing | Phase 3 only, offline token extraction |
58
+ | **Segmentation** | [`facebook/sam3.1`](https://hf.co/facebook/sam3.1) — SAM 3.1, 3.3GB, gated | SAM 2.1-Large (stable) | Phase 3 optional, offline mask extraction |
59
+ | **Latent rollout** | Shared transformer predictor, 6 layers, K=3 | Per-step unshared blocks (ablation) | Weight-tied across steps; sigmoid evidence gates |
60
+ | **Target encoder** | EMA copy (cosine 0.996→1.0) of evidence+rollout | Frozen target (ablation baseline) | From I-JEPA |
61
+ | **JEPA loss** | SmoothL1 + VICReg (hybrid); Cosine + SIGReg (purist) | MSE (ablation) | SIGReg emphasis in purist branch |
62
+ | **Disc. head** | MLP/bilinear scorer | Cross-encoder scorer (ablation) | Attention-pooled z_K × option embeddings |
63
+ | **Gen. decoder** | [`Qwen/Qwen3.5-4B`](https://hf.co/Qwen/Qwen3.5-4B) — 4.7B, multimodal | [`HuggingFaceTB/SmolLM3-3B`](https://hf.co/HuggingFaceTB/SmolLM3-3B) (cheaper); Gemma3-4B | Phase 3+, cross-attends to z_K + evidence |
64
+ | **Teacher/baseline** | InternVL3.5 / Qwen3-VL | External comparison only | NOT used as internal module |
65
+
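
The target-encoder row gives an EMA momentum that moves from 0.996 to 1.0 on a cosine schedule. A minimal sketch of one common way to write that schedule, together with the parameter update it feeds, is shown below. The exact cosine form (the BYOL-style ramp) and the function names are assumptions; parameters are plain floats in a dict purely for illustration.

```python
import math


def ema_momentum(step, total_steps, base=0.996, final=1.0):
    """Cosine ramp from `base` at step 0 to `final` at `total_steps`."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return final - (final - base) * (math.cos(math.pi * progress) + 1.0) / 2.0


def ema_update(target_params, online_params, momentum):
    """target <- m * target + (1 - m) * online, applied per named parameter."""
    return {
        name: momentum * target_params[name] + (1.0 - momentum) * online_params[name]
        for name in target_params
    }


target = {"w": 1.0}
online = {"w": 0.0}
target = ema_update(target, online, ema_momentum(step=0, total_steps=1000))
```

As the momentum approaches 1.0 the target encoder changes more and more slowly, so late in training the prediction targets are close to fixed.
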
66
+ ---
67
+
68
+ ## Training Protocol
69
+
70
+ ### Phase 1: Reasoning Core (15–20 epochs)
71
+ - **Freeze** all perception (DINOv3 + Qwen3-Embedding)
72
+ - **Train** evidence memory + latent rollout + discriminative head
73
+ - Full JEPA loss + task loss
74
+ - LR: 3e-4, effective batch: 64
75
+
76
+ ### Phase 2: Perception Fine-tuning (10 epochs)
77
+ - **Unfreeze** last 6 DINOv3 layers + last 4 Qwen3-Embedding layers (1e-5)
78
+ - Continue training reasoning core (1e-4)
79
+
80
+ ### Phase 3: Enriched Evidence + Generative Decoder (10 epochs)
81
+ - **Enable** PaddleOCR-VL tokens, SAM 3.1 masks, layout/chart tokens
82
+ - **Attach** Qwen3.5-4B generative decoder for open-ended answers
83
+ - End-to-end fine-tuning, LR: 5e-5
84
+
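
The three phases can be written down as a small schedule table. The sketch below is one hypothetical encoding: the class and field names are illustrative, Phase 1 uses the lower bound of its 15–20 epoch range, and applying 5e-5 to every parameter group in Phase 3 is an assumption about what "end-to-end" means here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Phase:
    name: str
    epochs: int
    core_lr: float        # evidence memory, latent rollout, answer heads
    perception_lr: float  # 0.0 means the backbones stay frozen


# Learning rates and epoch counts follow the protocol text; see the lead-in
# for the two places where an assumption was needed.
PHASES = [
    Phase("reasoning_core", epochs=15, core_lr=3e-4, perception_lr=0.0),
    Phase("perception_finetune", epochs=10, core_lr=1e-4, perception_lr=1e-5),
    Phase("enriched_evidence", epochs=10, core_lr=5e-5, perception_lr=5e-5),
]


def phase_for_epoch(epoch):
    """Return the phase that a zero-based global epoch index falls into."""
    start = 0
    for phase in PHASES:
        if epoch < start + phase.epochs:
            return phase
        start += phase.epochs
    raise ValueError(f"epoch {epoch} is past the end of the schedule")


print(phase_for_epoch(15).name)  # perception_finetune
```
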
85
+ ---
86
+
87
+ ## Target Benchmarks (9)
88
+
89
+ | Benchmark | Type | Metric | Key Challenge |
90
+ |-----------|------|--------|---------------|
91
+ | MMMU | MC (multi-image) | Accuracy | Multi-discipline, up to 7 images |
92
+ | MathVista | Mixed MC/Open | Accuracy | Mathematical reasoning |
93
+ | ScienceQA | MC | Accuracy | Scientific diagrams, nullable images |
94
+ | AI2D | MC | Accuracy | Science diagram comprehension |
95
+ | MMBench | MC | CircularEval Acc | General visual understanding |
96
+ | MMStar | MC | Accuracy | Vision-dependent questions |
97
+ | DocVQA | Open | ANLS | Document text extraction |
98
+ | TextVQA | Open | VQA Accuracy | Scene text reading |
99
+ | ChartQA | Open | Relaxed Accuracy | Chart data extraction |
100
+
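
Two of the metrics in the table are less familiar than plain accuracy. The sketch below shows the usual definitions of ANLS (DocVQA) and the relaxed numeric match (ChartQA) in pure Python. It is an illustrative version, not the project's `metrics.py`: the lowercasing and whitespace stripping, the 0.5 ANLS threshold, and the 5% tolerance follow the conventions commonly used for these benchmarks and are assumptions here.

```python
def levenshtein(a, b):
    """Edit distance between two strings (insert, delete, substitute)."""
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            cost = 0 if char_a == char_b else 1
            current.append(min(previous[j] + 1, current[j - 1] + 1, previous[j - 1] + cost))
        previous = current
    return previous[-1]


def anls(predictions, references, threshold=0.5):
    """Average Normalized Levenshtein Similarity; `references` holds a list of answers per question."""
    scores = []
    for prediction, answers in zip(predictions, references):
        best = 0.0
        for answer in answers:
            p, a = prediction.strip().lower(), answer.strip().lower()
            longest = max(len(p), len(a))
            distance = levenshtein(p, a) / longest if longest else 0.0
            # Answers that are too far from the reference score zero.
            best = max(best, 1.0 - distance if distance < threshold else 0.0)
        scores.append(best)
    return sum(scores) / len(scores)


def relaxed_match(prediction, answer, tolerance=0.05):
    """Numbers may deviate by up to `tolerance` (relative); other answers must match exactly."""
    try:
        predicted, expected = float(prediction), float(answer)
    except ValueError:
        return prediction.strip().lower() == answer.strip().lower()
    if expected == 0.0:
        return predicted == 0.0
    return abs(predicted - expected) / abs(expected) <= tolerance
```

For example, `anls(["helo"], [["hello"]])` gives 0.8 (one edit over five characters), while `relaxed_match("104", "100")` is true and `relaxed_match("106", "100")` is false.
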
101
+ ---
102
+
103
+ ## Experimental Branches
104
+
105
+ ### Hybrid-main (competitive)
106
+ - DINOv3-L backbone, SmoothL1 + VICReg, K=3
107
+ - Full enriched evidence in Phase 3
108
+ - Target: state-of-the-art on all benchmarks
109
+
110
+ ### Purist-side (scientific validation)
111
+ - DINOv3-B backbone, Cosine + SIGReg, K=5
112
+ - No enriched evidence, pure JEPA reasoning
113
+ - Target: demonstrate JEPA contributes beyond perception
114
+
115
+ ---
116
+
117
+ ## Ablation Experiments
118
+
119
+ | Experiment | Modification | Purpose |
120
+ |------------|-------------|---------|
121
+ | `hybrid_main` | Full model | Baseline |
122
+ | `no_jepa` | Remove L_JEPA, task loss only | Validate JEPA objective |
123
+ | `no_rollout` | K=0, use z₀ directly | Validate iterative refinement |
124
+ | `no_gate` | Remove evidence gating | Validate adaptive evidence flow |
125
+ | `K1` / `K5` / `K7` | Vary rollout depth | Find optimal depth |
126
+ | `dinov2_ablation` | DINOv2-L/14 backbone | DINOv3 vs DINOv2 |
127
+ | `purist` | DINOv3-B, no enriched ev., SIGReg | Isolate JEPA contribution |
128
+ | `mse_loss` / `cosine_loss` | Alternative JEPA losses | Loss function ablation |
129
+
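
One way to keep the ablations in the table reproducible is to express each as a set of overrides on a single default configuration. The sketch below does that with a frozen dataclass. `RunConfig`, its field names, and the string values are hypothetical and do not reflect `model_config.py`; the `purist` entry follows the table row only (DINOv3-B, no enriched evidence, SIGReg).

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class RunConfig:
    """Hypothetical config; field names are illustrative, not the repository's."""
    backbone: str = "dinov3_l16"
    rollout_steps: int = 3
    jepa_weight: float = 1.0
    jepa_loss: str = "smooth_l1"
    regularizer: str = "vicreg"
    evidence_gate: bool = True
    enriched_evidence: bool = True


# Each experiment name maps to the fields it changes relative to hybrid_main.
ABLATIONS = {
    "hybrid_main": {},
    "no_jepa": {"jepa_weight": 0.0},
    "no_rollout": {"rollout_steps": 0},
    "no_gate": {"evidence_gate": False},
    "K1": {"rollout_steps": 1},
    "K5": {"rollout_steps": 5},
    "K7": {"rollout_steps": 7},
    "dinov2_ablation": {"backbone": "dinov2_l14"},
    "purist": {"backbone": "dinov3_b16", "enriched_evidence": False, "regularizer": "sigreg"},
    "mse_loss": {"jepa_loss": "mse"},
    "cosine_loss": {"jepa_loss": "cosine"},
}


def build_config(name):
    """Apply one ablation's overrides to the default (hybrid-main) configuration."""
    return replace(RunConfig(), **ABLATIONS[name])


print(build_config("no_rollout").rollout_steps)  # 0
```

Keeping every run as a diff against one baseline makes it easy to confirm that an ablation changes exactly one factor.
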
130
+ ---
131
+
132
+ ## Project Structure
133
+
134
+ ```
135
+ MR-JEPA/
136
+ ├── README.md # This file
137
+ ├── train_mrjepa.py # Complete training script (CLI, all ablations)
138
+ ├── test_architecture.py # Architecture validation tests (synthetic data)
139
+
140
+ ├── mr_jepa/
141
+ │ ├── __init__.py
142
+ │ ├── ARCHITECTURE.md # Detailed architecture specification
143
+ │ │
144
+ │ ├── configs/
145
+ │ │ ├── __init__.py
146
+ │ │ └── model_config.py # All hyperparameter dataclasses
147
+ │ │
148
+ │ ├── models/
149
+ │ │ ├── __init__.py
150
+ │ │ ├── mr_jepa.py # Main model (integrates all components)
151
+ │ │ ├── backbones.py # Visual (DINOv3/v2) + Text (Qwen3-Embedding)
152
+ │ │ ├── evidence_memory.py # Perceiver Resampler multimodal fusion
153
+ │ │ ├── latent_rollout.py # K-step shared predictor + evidence gates
154
+ │ │ ├── target_encoder.py # EMA encoder + JEPA/SIGReg/VICReg losses
155
+ │ │ └── answer_heads.py # Discriminative (MC) + Generative (open-ended)
156
+ │ │
157
+ │ ├── data/
158
+ │ │ ├── __init__.py
159
+ │ │ ├── unified_dataset.py # 9-benchmark unified loader with format quirks
160
+ │ │ └── data_utils.py # Collator, dataloader factory, benchmark configs
161
+ │ │
162
+ │ ├── training/
163
+ │ │ ├── __init__.py
164
+ │ │ ├── trainer.py # 3-phase training loop
165
+ │ │ └── phase_scheduler.py # Phase transitions, LR scheduling
166
+ │ │
167
+ │ ├── evaluation/
168
+ │ │ ├── __init__.py
169
+ │ │ └── metrics.py # Accuracy, ANLS, VQA Acc, Relaxed Acc
170
+ │ │
171
+ │ └── utils/
172
+ │   ├── __init__.py
173
+ │   ├── visualization.py # Trajectory PCA, gate analysis
174
+ │   └── ablation.py # Systematic ablation runner
175
+
176
+ ├── results/ # Training results (auto-pushed)
177
+ │ ├── hybrid_main.json
178
+ │ ├── no_jepa.json
179
+ │ ├── no_rollout.json
180
+ │ └── ...
181
+
182
+ └── checkpoints/ # Best model checkpoints (auto-pushed)
183
+   ├── hybrid_main_best.pt
184
+   └── ...
185
  ```
186
 
187
+ ---
188
+
189
+ ## Paper Contribution
190
+
191
+ > **A world model for multimodal reasoning**: We demonstrate that modeling the evolution of a latent belief state via JEPA-style prediction improves performance on static multimodal benchmarks compared to single-pass baselines. The evidence-gated rollout with K=3 steps learns meaningful intermediate reasoning states, validated through ablation studies across 9 benchmarks. The latent trajectory is supervised by the JEPA objective, with targets produced by an EMA target encoder rather than by human chain-of-thought annotations, showing that self-supervised dynamics training transfers to discriminative reasoning tasks.
192
+
193
+ ---
194
+
195
+ ## Key References
196
+
197
+ 1. **I-JEPA** (Assran et al., 2023), [arXiv:2301.08243](https://arxiv.org/abs/2301.08243): JEPA architecture, EMA target, L2 loss, narrow predictor
198
+ 2. **LeWorldModel** (Maes et al., 2026), [arXiv:2603.19312](https://arxiv.org/abs/2603.19312): SIGReg anti-collapse, end-to-end JEPA
199
+ 3. **Coconut** (Hao et al., 2024), [arXiv:2412.06769](https://arxiv.org/abs/2412.06769): Chain of Continuous Thought, latent reasoning
200
+ 4. **DINOv3** (Meta, 2025), [arXiv:2508.10104](https://arxiv.org/abs/2508.10104): Dense SSL with RoPE + Gram anchoring
201
+ 5. **SoftCoT++** (Xu et al., 2025), [arXiv:2505.11484](https://arxiv.org/abs/2505.11484): Soft chain-of-thought with contrastive learning
202
+
203
+ ## License
204
+
205
+ Apache-2.0