JorgeAV commited on
Commit
dba2c56
·
verified ·
1 Parent(s): 3b4df8f

Initial MR-JEPA codebase: architecture, training, evaluation, and tests

Browse files
Files changed (35) hide show
  1. README.md +44 -0
  2. mr_jepa/ARCHITECTURE.md +303 -0
  3. mr_jepa/__init__.py +9 -0
  4. mr_jepa/configs/__init__.py +25 -0
  5. mr_jepa/configs/__pycache__/__init__.cpython-312.pyc +0 -0
  6. mr_jepa/configs/__pycache__/model_config.cpython-312.pyc +0 -0
  7. mr_jepa/configs/model_config.py +306 -0
  8. mr_jepa/data/__init__.py +9 -0
  9. mr_jepa/data/data_utils.py +273 -0
  10. mr_jepa/data/unified_dataset.py +380 -0
  11. mr_jepa/evaluation/__init__.py +15 -0
  12. mr_jepa/evaluation/__pycache__/__init__.cpython-312.pyc +0 -0
  13. mr_jepa/evaluation/__pycache__/metrics.cpython-312.pyc +0 -0
  14. mr_jepa/evaluation/metrics.py +251 -0
  15. mr_jepa/models/__init__.py +17 -0
  16. mr_jepa/models/__pycache__/answer_heads.cpython-312.pyc +0 -0
  17. mr_jepa/models/__pycache__/evidence_memory.cpython-312.pyc +0 -0
  18. mr_jepa/models/__pycache__/latent_rollout.cpython-312.pyc +0 -0
  19. mr_jepa/models/__pycache__/target_encoder.cpython-312.pyc +0 -0
  20. mr_jepa/models/answer_heads.py +369 -0
  21. mr_jepa/models/backbones.py +180 -0
  22. mr_jepa/models/evidence_memory.py +299 -0
  23. mr_jepa/models/latent_rollout.py +324 -0
  24. mr_jepa/models/mr_jepa.py +350 -0
  25. mr_jepa/models/target_encoder.py +354 -0
  26. mr_jepa/training/__init__.py +4 -0
  27. mr_jepa/training/phase_scheduler.py +107 -0
  28. mr_jepa/training/trainer.py +397 -0
  29. mr_jepa/utils/__init__.py +8 -0
  30. mr_jepa/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  31. mr_jepa/utils/__pycache__/ablation.cpython-312.pyc +0 -0
  32. mr_jepa/utils/__pycache__/visualization.cpython-312.pyc +0 -0
  33. mr_jepa/utils/ablation.py +182 -0
  34. mr_jepa/utils/visualization.py +137 -0
  35. test_architecture.py +506 -0
README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ml-intern sandbox
3
+ emoji: 🌍
4
+ colorFrom: gray
5
+ colorTo: blue
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ ---
10
+
11
+ # MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture
12
+
13
+ > A world model for multimodal reasoning that refines a latent belief state over K=3 steps using JEPA-style prediction, evidence gating, and dense visual backbones.
14
+
15
+ ## Key Idea
16
+
17
+ Traditional multimodal models produce answers in a single forward pass. MR-JEPA instead models **the evolution of a belief state** as the system reasons about a question:
18
+
19
+ ```
20
+ z₀ (initial evidence) → z₁ (first refinement) → z₂ (deeper reasoning) → z₃ (answer)
21
+ ```
22
+
23
+ This trajectory is supervised by a **JEPA objective**: a target encoder (EMA) generates target latent states, and the online predictor learns to predict them. The JEPA loss encourages the model to learn **meaningful intermediate reasoning states** — not just the final answer.
24
+
25
+ ## Architecture
26
+
27
+ ```
28
+ ┌─────────────┐ ┌──────────────┐ ┌─────────────────┐ ┌──────────┐
29
+ │ DINOv2/v3 │────▶│ Evidence │────▶│ Latent Rollout │────▶│ Answer │
30
+ │ (frozen) │ │ Memory │ │ z₀→z₁→z₂→z₃ │ │ Heads │
31
+ └─────────────┘ │ (Perceiver) │ │ (shared block) │ └──────────┘
32
+ └──────┬───────┘ └────────┬────────┘
33
+ ┌─────────────┐ │ │
34
+ │ DeBERTa-v3 │───────────┘ ┌───────┴────────┐
35
+ │ (frozen) │ │ Target Encoder │
36
+ └─────────────┘ │ (EMA copy) │
37
+ └────────────────┘
38
+ ┌─────────────┐ │
39
+ │ OCR/Layout/ │──────────┘ JEPA Loss: L₂ + SIGReg
40
+ │ Chart/SAM │ (Phase 3)
41
+ └─────────────┘
42
+ ```
43
+
44
+ See `mr_jepa/ARCHITECTURE.md` for the complete specification.
mr_jepa/ARCHITECTURE.md ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture
2
+
3
+ ## Detailed Architecture Specification
4
+
5
+ ---
6
+
7
+ ## 1. Overview
8
+
9
+ MR-JEPA is a **world model for static multimodal reasoning**. Unlike traditional world models that predict physical dynamics (video, robotics), MR-JEPA models the evolution of a **belief state** as the system reasons about a visual question.
10
+
11
+ The core insight: solving a multimodal question (e.g., "What is the GDP growth shown in this chart?") requires iterative evidence accumulation — first extracting relevant visual features, then integrating textual context, then refining understanding through multiple reasoning steps. MR-JEPA formalizes this process as a **latent trajectory** supervised by a JEPA objective.
12
+
13
+ ```
14
+ ┌──────────────────────────────────────────┐
15
+ │ MR-JEPA Architecture │
16
+ └──────────────────────────────────────────┘
17
+
18
+ ┌─────────┐ ┌─────────────┐ ┌──────────────┐
19
+ │ DINOv2/v3│────▶│ Evidence │────▶│ Latent │──▶ Answer
20
+ │ Visual │ │ Memory │ │ Rollout │ Heads
21
+ │ Backbone │ │ (Perceiver)│ │ K=3 steps │
22
+ └─────────┘ └──────┬──────┘ └──────┬───────┘
23
+ │ │
24
+ ┌─────────┐ │ ┌──────┴───────┐
25
+ │ DeBERTa │────────────┘ │ Target │
26
+ │ Text │ │ Encoder │
27
+ │ Encoder │ │ (EMA) │
28
+ └─────────┘ └──────────────┘
29
+
30
+ ┌─────────┐ JEPA Loss:
31
+ │Optional:│ L₂ prediction
32
+ │OCR,SAM, │──────────┘ + SIGReg
33
+ │Layout │
34
+ └─────────┘
35
+ ```
36
+
37
+ ---
38
+
39
+ ## 2. Component Details
40
+
41
+ ### 2.1 Visual Backbone
42
+
43
+ **Primary choice: DINOv2-L/14** (`facebook/dinov2-large`)
44
+ - Architecture: ViT-L/14 with 300M parameters
45
+ - Output: 1024-dim patch tokens, 518×518 input → 1369 patches
46
+ - 4 register tokens + CLS token (skipped, only patch tokens used)
47
+ - Pre-trained with self-supervised DINO objective on LVD-142M
48
+ - **Why DINOv2 over CLIP/SigLIP**: Dense patch features are critical for evidence extraction. CLIP-style models optimize for global image-text alignment but lose local spatial information. DINOv2 produces patch-level features that capture fine-grained visual details needed for chart reading, document OCR, and diagram understanding.
49
+
50
+ **Alternative: DINOv3-L/16** (`timm/vit_large_patch16_dinov3.lvd1689m`)
51
+ - Architecture: ViT-L/16 with RoPE positional encoding
52
+ - Advantages: Better resolution generalization, Gram anchoring prevents feature degradation
53
+ - Trained on LVD-1689M (10× more data)
54
+
55
+ **Purist branch: DINOv2-B/14** (`facebook/dinov2-base`)
56
+ - 768-dim output, 86M params
57
+ - Compensated by deeper rollout (K=5)
58
+
59
+ ### 2.2 Text Encoder
60
+
61
+ **DeBERTa-v3-Large** (`microsoft/deberta-v3-large`)
62
+ - 1024-dim hidden, 24 layers, 304M params
63
+ - Processes: question text + answer options (concatenated with separators)
64
+ - Output: token-level embeddings for cross-attention + CLS for option scoring
65
+
66
+ **Why DeBERTa over BERT/RoBERTa**: DeBERTa-v3's disentangled attention mechanism explicitly models content vs. position, giving stronger performance on complex NLU tasks. Its relative position bias is particularly useful for understanding mathematical notation and structured question formats.
67
+
68
+ ### 2.3 Evidence Memory
69
+
70
+ **Architecture: Perceiver-style cross-attention**
71
+
72
+ ```python
73
+ N_evidence = 64 # Learnable query tokens
74
+ D = 768 # Hidden dimension
75
+ L = 4 # Cross-attention layers
76
+
77
+ # Each layer:
78
+ # 1. Self-attention among evidence queries
79
+ # 2. Cross-attention: queries attend to [visual_patches || text_tokens || enriched_tokens]
80
+ # 3. FFN with residual
81
+ ```
82
+
83
+ **Input tokens (concatenated KV sequence):**
84
+ | Source | Tokens | Dimension | Phase |
85
+ |--------|--------|-----------|-------|
86
+ | DINOv2-L patches | 1369 | 1024→768 (projected) | 1+ |
87
+ | DeBERTa text | 256 | 1024→768 (projected) | 1+ |
88
+ | OCR tokens | 128 | 768 | 3 |
89
+ | Layout tokens | 64 | 256→768 (projected) | 3 |
90
+ | Chart tokens | 64 | 512→768 (projected) | 3 |
91
+ | SAM2 segments | 32 | 256→768 (projected) | 3 (optional) |
92
+
93
+ **Modality type embeddings** (learned, added to distinguish token sources).
94
+
95
+ **Output**: 64 evidence tokens × 768 dim = dense multimodal representation.
96
+
97
+ ### 2.4 Latent Rollout (JEPA Core)
98
+
99
+ The reasoning engine. Refines a belief state over K steps:
100
+
101
+ ```
102
+ z₀ = StateInit + Proj(AvgPool(evidence)) # Initial state from evidence
103
+ z₁ = PredictorBlock(z₀, evidence) + step_emb[1]
104
+ z₂ = PredictorBlock(z₁, evidence) + step_emb[2]
105
+ z₃ = PredictorBlock(z₂, evidence) + step_emb[3] # Final state → answer
106
+ ```
107
+
108
+ **State representation**: 32 learnable tokens × 768 dim
109
+
110
+ **Shared Predictor Block** (weight-tied across K steps):
111
+ ```
112
+ For each step k:
113
+ 1. Self-attention among 32 state tokens
114
+ 2. Evidence-gated cross-attention to 64 evidence tokens
115
+ 3. FFN (768 → 3072 → 768)
116
+
117
+ PredictorBlock = [SelfAttn → EvidenceGate(CrossAttn) → FFN] × 6 layers
118
+ ```
119
+
120
+ **Evidence Gate** (sigmoid):
121
+ ```python
122
+ gate = sigmoid(W_g · [z_k || cross_attn_output]) # Per-dimension gating
123
+ gated_evidence = gate * cross_attn_output
124
+ z_k = z_{k-1} + gated_evidence # Residual
125
+ ```
126
+
127
+ The gate learns to control evidence flow per step:
128
+ - Early steps: high gate → absorb more visual/textual evidence
129
+ - Later steps: lower gate → rely on accumulated reasoning
130
+
131
+ **Step embeddings**: Learned per-step bias vectors differentiate rollout positions.
132
+
133
+ ### 2.5 Target Encoder (EMA)
134
+
135
+ **Following I-JEPA** (Assran et al., 2023):
136
+
137
+ The target encoder is an EMA copy of [Evidence Memory + Latent Rollout]:
138
+ ```
139
+ θ̄_t+1 = m(t) · θ̄_t + (1 - m(t)) · θ_t
140
+ ```
141
+
142
+ **Momentum schedule** (cosine from 0.996 → 1.0):
143
+ ```python
144
+ m(t) = 1 - (1 - 0.996) * (1 + cos(π · t/T)) / 2
145
+ ```
146
+
147
+ The target encoder generates target trajectory z*₀, z*₁, z*₂, z*₃.
148
+ The online predictor must predict these targets.
149
+
150
+ **Critical**: Target encoder receives stop-gradient inputs and produces stop-gradient outputs.
151
+
152
+ ### 2.6 JEPA Objective
153
+
154
+ **Prediction loss** (from I-JEPA):
155
+ ```
156
+ L_JEPA = (1/K) Σ_{k=1}^{K} ||z_pred_k - sg(z*_k)||²
157
+ ```
158
+ Only steps k=1..K are supervised (z₀ is deterministic from evidence).
159
+
160
+ **Anti-collapse regularization** (from LeWorldModel):
161
+ ```
162
+ L_SIGReg = (1/M) Σ_{m=1}^{M} T(Z · u_m)
163
+ ```
164
+ Where T is the Epps-Pulley normality test statistic, u_m are random unit vectors.
165
+ This encourages latent embeddings to remain Gaussian-distributed, preventing collapse.
166
+
167
+ **Total loss**:
168
+ ```
169
+ L_total = L_JEPA + L_task + λ · L_SIGReg + α · L_gen
170
+
171
+ Where:
172
+ L_task = CrossEntropy(disc_head(z_K), answer_label) # MC scoring
173
+ L_gen = CE(gen_head(z_K), target_answer_tokens) # Short answer
174
+ λ = 0.1 (SIGReg weight)
175
+ α = 0.1 (generative weight)
176
+ ```
177
+
178
+ ### 2.7 Answer Heads
179
+
180
+ **Discriminative Head (Primary)** — for MC questions:
181
+ ```
182
+ z_pooled = AttentionPool(z_K) # 32 tokens → 1 vector
183
+ score_i = MLP([z_pooled, opt_i, z_pooled ⊙ opt_i]) # Per-option score
184
+ probs = softmax(scores, mask=valid_options)
185
+ ```
186
+
187
+ **Generative Head (Secondary)** — for open-ended questions:
188
+ ```
189
+ Small transformer decoder (4 layers):
190
+ - Causal self-attention
191
+ - Cross-attention to z_K (latent state)
192
+ - Cross-attention to evidence memory (evidence-constrained)
193
+ - FFN
194
+
195
+ Max 64 tokens output. Weight-tied embedding + LM head.
196
+ ```
197
+
198
+ ---
199
+
200
+ ## 3. Training Protocol
201
+
202
+ ### Phase 1: Reasoning Core (20 epochs)
203
+
204
+ | Component | Status | LR |
205
+ |-----------|--------|-----|
206
+ | DINOv2-L | **Frozen** | — |
207
+ | DeBERTa | **Frozen** | — |
208
+ | Evidence Memory | Training | 3e-4 |
209
+ | Latent Rollout | Training | 3e-4 |
210
+ | Answer Heads | Training | 3e-4 |
211
+ | Target Encoder | EMA update | — |
212
+
213
+ **Data**: ScienceQA train (12.7K) + any available train splits
214
+ **Objective**: Full JEPA + task + SIGReg
215
+ **Batch size**: 32 × 4 accum = 128 effective
216
+
217
+ ### Phase 2: Perception Fine-tuning (10 epochs)
218
+
219
+ | Component | Status | LR |
220
+ |-----------|--------|-----|
221
+ | DINOv2-L (last 6 layers) | **Training** | 1e-5 |
222
+ | DeBERTa (last 4 layers) | **Training** | 1e-5 |
223
+ | Evidence Memory | Training | 1e-4 |
224
+ | Latent Rollout | Training | 1e-4 |
225
+ | Answer Heads | Training | 1e-4 |
226
+
227
+ ### Phase 3: Enriched Evidence (10 epochs)
228
+
229
+ | Component | Status | LR |
230
+ |-----------|--------|-----|
231
+ | All above | Training | 5e-5 |
232
+ | OCR tokens | **Enabled** | 5e-5 |
233
+ | Layout tokens | **Enabled** | 5e-5 |
234
+ | Chart tokens | **Enabled** | 5e-5 |
235
+
236
+ **Focus benchmarks**: DocVQA, TextVQA, ChartQA
237
+
238
+ ---
239
+
240
+ ## 4. Ablation Experiments
241
+
242
+ ### Key ablations for the paper:
243
+
244
+ | Experiment | Modification | Expected finding |
245
+ |------------|-------------|-----------------|
246
+ | **Full MR-JEPA** | Baseline | Best overall |
247
+ | **No JEPA** | Remove L_JEPA, train with task loss only | Drops on reasoning-heavy benchmarks |
248
+ | **No Rollout** | K=0, use z₀ directly | Significant drop (proves rollout value) |
249
+ | **No Evidence Gate** | Remove gating | Slight drop (gate helps focus) |
250
+ | **K=1** | Shallow rollout | Worse than K=3 |
251
+ | **K=5** | Deeper rollout | Diminishing returns |
252
+ | **No SIGReg** | Remove anti-collapse | Training instability |
253
+ | **Purist branch** | DINOv2-B, no enriched evidence | Lower absolute scores, but validates JEPA contribution |
254
+
255
+ ### Cross-benchmark analysis:
256
+ - JEPA contribution should be highest on **reasoning** benchmarks (MathVista, MMMU, ScienceQA)
257
+ - Evidence gate contribution should be highest on **evidence-rich** benchmarks (DocVQA, ChartQA)
258
+ - Enriched evidence (Phase 3) should matter most for **document** benchmarks
259
+
260
+ ---
261
+
262
+ ## 5. Parameter Budget
263
+
264
+ | Component | Parameters | Trainable (Phase 1) |
265
+ |-----------|-----------|---------------------|
266
+ | DINOv2-L | 300M | 0 |
267
+ | DeBERTa-v3-L | 304M | 0 |
268
+ | Evidence Memory | ~3M | 3M |
269
+ | Latent Rollout | ~3M | 3M |
270
+ | Disc Head | ~2M | 2M |
271
+ | Gen Head | ~25M | 25M |
272
+ | **Total** | **~637M** | **~33M** |
273
+
274
+ Phase 1 trains only ~5% of total parameters. The model is computationally efficient — the JEPA reasoning core is lightweight compared to the frozen perception backbones.
275
+
276
+ ---
277
+
278
+ ## 6. Benchmark Format Reference
279
+
280
+ | Benchmark | Type | Answer | Metric | Eval Split |
281
+ |-----------|------|--------|--------|------------|
282
+ | MMMU | MC (up to 7 images) | Letter A-D | Accuracy | validation (900) |
283
+ | MathVista | Mixed MC/Open | Letter or value | Accuracy | testmini (1000) |
284
+ | ScienceQA | MC (nullable image) | 0-indexed int | Accuracy | test (4241) |
285
+ | AI2D | MC (diagrams) | 0-indexed str | Accuracy | test (3088) |
286
+ | MMBench | MC (A/B/C/D cols) | Letter | CircularEval Acc | dev (4329) |
287
+ | MMStar | MC (embedded options) | Letter | Accuracy | val (1500) |
288
+ | DocVQA | Open (documents) | List[str] | ANLS | validation (5349) |
289
+ | TextVQA | Open (scene text) | 10 annotations | VQA Accuracy | validation (5000) |
290
+ | ChartQA | Open (charts) | str/number | Relaxed Accuracy | test (2500) |
291
+
292
+ ---
293
+
294
+ ## 7. Key References
295
+
296
+ 1. **I-JEPA** (Assran et al., 2023) — arxiv:2301.08243: JEPA architecture, EMA target encoder, L2 prediction loss, narrow predictor
297
+ 2. **V-JEPA** (Bardes et al., 2024) — arxiv:2412.10925: Temporal extension, multi-step prediction in latent space
298
+ 3. **LeWorldModel** (Maes et al., 2025) — arxiv:2603.19312: SIGReg anti-collapse, end-to-end JEPA from pixels, 2474 GitHub stars
299
+ 4. **Coconut** (Yu et al., 2024) — arxiv:2412.06769: Chain of Continuous Thought, latent reasoning paradigm
300
+ 5. **SoftCoT++** (Xu et al., 2025) — arxiv:2505.11484: Soft chain-of-thought with perturbation and contrastive learning
301
+ 6. **DINOv2** (Oquab et al., 2023) — arxiv:2304.07193: Dense SSL visual backbone
302
+ 7. **DINOv3** (Meta, 2025) — arxiv:2508.10104: Improved SSL with RoPE, Gram anchoring
303
+ 8. **SigLIP2** (Google, 2025) — arxiv:2502.14786: CLIP-style with DINO features + captioning
mr_jepa/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture
3
+
4
+ A world model for multimodal reasoning that refines a latent belief state
5
+ over K steps using JEPA-style prediction, evidence gating, and dense
6
+ visual backbones.
7
+ """
8
+
9
+ __version__ = "0.1.0"
mr_jepa/configs/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .model_config import (
2
+ MRJEPAConfig,
3
+ VisualBackboneConfig,
4
+ TextEncoderConfig,
5
+ EvidenceMemoryConfig,
6
+ LatentRolloutConfig,
7
+ JEPAObjectiveConfig,
8
+ AnswerHeadConfig,
9
+ TrainingPhaseConfig,
10
+ get_hybrid_config,
11
+ get_purist_config,
12
+ )
13
+
14
+ __all__ = [
15
+ "MRJEPAConfig",
16
+ "VisualBackboneConfig",
17
+ "TextEncoderConfig",
18
+ "EvidenceMemoryConfig",
19
+ "LatentRolloutConfig",
20
+ "JEPAObjectiveConfig",
21
+ "AnswerHeadConfig",
22
+ "TrainingPhaseConfig",
23
+ "get_hybrid_config",
24
+ "get_purist_config",
25
+ ]
mr_jepa/configs/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (480 Bytes). View file
 
mr_jepa/configs/__pycache__/model_config.cpython-312.pyc ADDED
Binary file (12.8 kB). View file
 
mr_jepa/configs/model_config.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MR-JEPA Model Configuration
3
+
4
+ Defines all hyperparameters for the model architecture, training phases,
5
+ and JEPA objectives. Values are grounded in the literature:
6
+
7
+ - I-JEPA (Assran et al., 2023): EMA schedule, L2 prediction loss
8
+ - LeWorldModel (Maes et al., 2025): SIGReg anti-collapse, end-to-end JEPA
9
+ - Coconut (Yu et al., 2024): Latent reasoning rollout paradigm
10
+ - DINOv2/v3 (Oquab et al., 2023 / Meta 2025): Visual backbone config
11
+ """
12
+
13
+ from dataclasses import dataclass, field
14
+ from typing import Optional, Literal
15
+ import math
16
+
17
+
18
+ @dataclass
19
+ class VisualBackboneConfig:
20
+ """Configuration for the visual backbone encoder."""
21
+ # Backbone selection
22
+ backbone_type: Literal["dinov2", "dinov3", "siglip2"] = "dinov2"
23
+ model_name: str = "facebook/dinov2-large" # 1024-dim, 300M params
24
+
25
+ # DINOv2-L: hidden_size=1024, patch=14, 518px → 1369 patches + CLS + 4 reg = 1374 tokens
26
+ # DINOv3-L: hidden_size=1024, patch=16, RoPE, better dense features
27
+ # SigLIP2-So400m: hidden_size=1152, patch=14, 384px → 729 patches
28
+
29
+ hidden_size: int = 1024 # DINOv2-L / DINOv3-L output dim
30
+ image_size: int = 518 # DINOv2 default; 384 for SigLIP2
31
+ patch_size: int = 14 # 14 for DINOv2/SigLIP2, 16 for DINOv3
32
+ num_register_tokens: int = 4 # DINOv2/v3 register tokens
33
+
34
+ # Freezing control (Phase 1: fully frozen, Phase 2: unfreeze last N layers)
35
+ freeze: bool = True
36
+ unfreeze_last_n_layers: int = 0 # Phase 2: set to 4-6
37
+
38
+ # Optional: use only last N layers' features (multi-scale)
39
+ use_multi_scale: bool = False
40
+ multi_scale_layers: list = field(default_factory=lambda: [-1]) # last layer only
41
+
42
+
43
+ @dataclass
44
+ class TextEncoderConfig:
45
+ """Configuration for the text encoder."""
46
+ model_name: str = "microsoft/deberta-v3-large" # 1024-dim, strong NLU
47
+ hidden_size: int = 1024
48
+ max_text_length: int = 256 # questions + options
49
+ freeze: bool = True
50
+ unfreeze_last_n_layers: int = 0
51
+
52
+
53
+ @dataclass
54
+ class EvidenceMemoryConfig:
55
+ """
56
+ Configuration for the unified Evidence Memory.
57
+
58
+ The evidence memory is a set of tokens that fuse visual and textual information.
59
+ It uses cross-attention to attend to both visual patch tokens and text tokens,
60
+ producing a unified multimodal representation.
61
+ """
62
+ hidden_dim: int = 768 # Internal dim of the evidence memory
63
+ num_evidence_tokens: int = 64 # Learnable evidence query tokens
64
+ num_cross_attn_layers: int = 4 # Cross-attention layers for fusion
65
+ num_heads: int = 12
66
+ dropout: float = 0.1
67
+
68
+ # Projections from backbone dims to evidence dim
69
+ visual_proj_dim: int = 768 # Project visual tokens to this dim
70
+ text_proj_dim: int = 768 # Project text tokens to this dim
71
+
72
+ # Optional enriched evidence (Phase 3)
73
+ use_ocr_tokens: bool = False
74
+ use_layout_tokens: bool = False
75
+ use_chart_tokens: bool = False
76
+ use_sam_tokens: bool = False
77
+ max_ocr_tokens: int = 128
78
+ max_layout_tokens: int = 64
79
+ max_chart_tokens: int = 64
80
+ max_sam_tokens: int = 32
81
+
82
+
83
+ @dataclass
84
+ class LatentRolloutConfig:
85
+ """
86
+ Configuration for the latent belief-state rollout.
87
+
88
+ The core JEPA reasoning module. Refines z₀ over K steps:
89
+ z₀ → z₁ → z₂ → z₃
90
+
91
+ Each step applies:
92
+ 1. Self-attention over current state tokens
93
+ 2. Evidence-gated cross-attention to evidence memory
94
+ 3. FFN with residual connection
95
+
96
+ The predictor block is SHARED across all K steps (weight-tied),
97
+ following the recurrent predictor design from V-JEPA.
98
+
99
+ From I-JEPA: L2 loss in representation space, EMA target encoder
100
+ From LeWorldModel: SIGReg anti-collapse regularization
101
+ From Coconut: Iterative latent refinement paradigm
102
+ """
103
+ hidden_dim: int = 768 # Latent state dimension
104
+ num_state_tokens: int = 32 # Number of latent belief tokens per step
105
+ K: int = 3 # Number of rollout steps
106
+
107
+ # Shared predictor block
108
+ num_predictor_layers: int = 6 # Transformer layers in predictor
109
+ num_heads: int = 12
110
+ ffn_dim: int = 3072 # 4x hidden_dim
111
+ dropout: float = 0.1
112
+
113
+ # Evidence gating
114
+ use_evidence_gate: bool = True
115
+ gate_type: Literal["sigmoid", "softmax", "learned"] = "sigmoid"
116
+
117
+ # Step embedding (to differentiate rollout steps)
118
+ use_step_embedding: bool = True
119
+
120
+
121
+ @dataclass
122
+ class JEPAObjectiveConfig:
123
+ """
124
+ Configuration for the JEPA training objective.
125
+
126
+ Target encoder: EMA of the online encoder (evidence memory + rollout).
127
+ The target generates z*_k for each rollout step k.
128
+ The online predictor must predict z*_k from z_{k-1}.
129
+
130
+ Loss: L2 in representation space (from I-JEPA)
131
+ Anti-collapse: SIGReg (from LeWorldModel) or VICReg-style
132
+ """
133
+ # EMA schedule (from I-JEPA: cosine schedule 0.996 → 1.0)
134
+ ema_momentum_base: float = 0.996
135
+ ema_momentum_end: float = 1.0
136
+ ema_schedule: Literal["cosine", "linear", "constant"] = "cosine"
137
+
138
+ # Loss weights
139
+ jepa_loss_weight: float = 1.0 # L2 prediction loss
140
+ task_loss_weight: float = 1.0 # CE loss for answer classification
141
+ generative_loss_weight: float = 0.1 # Optional decoder loss
142
+
143
+ # Anti-collapse regularization (from LeWorldModel)
144
+ use_sigreg: bool = True
145
+ sigreg_weight: float = 0.1 # λ in LeWM paper
146
+ sigreg_num_projections: int = 1024 # M random projections
147
+
148
+ # Alternative: VICReg-style regularization
149
+ use_vicreg: bool = False
150
+ vicreg_var_weight: float = 1.0
151
+ vicreg_cov_weight: float = 0.04
152
+
153
+
154
+ @dataclass
155
+ class AnswerHeadConfig:
156
+ """Configuration for answer prediction heads."""
157
+ # Discriminative head (primary): scores answer options
158
+ disc_hidden_dim: int = 768
159
+ disc_num_layers: int = 2
160
+ max_num_options: int = 8 # MMMU can have up to 8 options
161
+ disc_dropout: float = 0.1
162
+
163
+ # Generative head (secondary): short open-ended answers
164
+ gen_hidden_dim: int = 768
165
+ gen_num_layers: int = 4 # Small transformer decoder
166
+ gen_num_heads: int = 12
167
+ gen_vocab_size: int = 32000 # Shared with text encoder tokenizer
168
+ gen_max_answer_length: int = 64
169
+ gen_dropout: float = 0.1
170
+
171
+ # Evidence-constrained decoding
172
+ use_evidence_constraint: bool = True # Cross-attend to evidence during generation
173
+
174
+
175
+ @dataclass
176
+ class MRJEPAConfig:
177
+ """
178
+ Complete MR-JEPA model configuration.
179
+
180
+ Two experimental branches:
181
+ - Hybrid-main: Full model with pretrained backbones + JEPA core
182
+ - Purist-side: Stripped-down version closer to LeWorldModel spirit
183
+ """
184
+ # Component configs
185
+ visual: VisualBackboneConfig = field(default_factory=VisualBackboneConfig)
186
+ text: TextEncoderConfig = field(default_factory=TextEncoderConfig)
187
+ evidence: EvidenceMemoryConfig = field(default_factory=EvidenceMemoryConfig)
188
+ rollout: LatentRolloutConfig = field(default_factory=LatentRolloutConfig)
189
+ jepa: JEPAObjectiveConfig = field(default_factory=JEPAObjectiveConfig)
190
+ answer: AnswerHeadConfig = field(default_factory=AnswerHeadConfig)
191
+
192
+ # Branch selection
193
+ branch: Literal["hybrid", "purist"] = "hybrid"
194
+
195
+ # Global settings
196
+ seed: int = 42
197
+
198
+ @property
199
+ def num_visual_tokens(self) -> int:
200
+ """Number of visual patch tokens output by backbone."""
201
+ n_patches = (self.visual.image_size // self.visual.patch_size) ** 2
202
+ return n_patches # Exclude CLS and register tokens (handled separately)
203
+
204
+ @property
205
+ def total_evidence_input_tokens(self) -> int:
206
+ """Total tokens feeding into evidence memory."""
207
+ n = self.num_visual_tokens + self.text.max_text_length
208
+ if self.evidence.use_ocr_tokens:
209
+ n += self.evidence.max_ocr_tokens
210
+ if self.evidence.use_layout_tokens:
211
+ n += self.evidence.max_layout_tokens
212
+ if self.evidence.use_chart_tokens:
213
+ n += self.evidence.max_chart_tokens
214
+ if self.evidence.use_sam_tokens:
215
+ n += self.evidence.max_sam_tokens
216
+ return n
217
+
218
+
219
+ @dataclass
220
+ class TrainingPhaseConfig:
221
+ """Configuration for the 3-phase training schedule."""
222
+
223
+ # Phase 1: Freeze perception, train reasoning core
224
+ phase1_epochs: int = 20
225
+ phase1_lr: float = 3e-4
226
+ phase1_warmup_ratio: float = 0.05
227
+ phase1_weight_decay: float = 0.05
228
+ phase1_batch_size: int = 32
229
+ phase1_grad_accum: int = 4
230
+
231
+ # Phase 2: Unfreeze last visual layers
232
+ phase2_epochs: int = 10
233
+ phase2_lr: float = 1e-4 # Lower LR for backbone fine-tuning
234
+ phase2_backbone_lr: float = 1e-5 # Even lower for backbone
235
+ phase2_warmup_ratio: float = 0.05
236
+ phase2_weight_decay: float = 0.05
237
+ phase2_batch_size: int = 16 # Smaller batch (more VRAM for gradients)
238
+ phase2_grad_accum: int = 8
239
+ phase2_unfreeze_visual_layers: int = 6 # Last 6 layers
240
+ phase2_unfreeze_text_layers: int = 4 # Last 4 layers
241
+
242
+ # Phase 3: Add enriched evidence
243
+ phase3_epochs: int = 10
244
+ phase3_lr: float = 5e-5
245
+ phase3_warmup_ratio: float = 0.1
246
+ phase3_weight_decay: float = 0.05
247
+ phase3_batch_size: int = 16
248
+ phase3_grad_accum: int = 8
249
+ phase3_enable_ocr: bool = True
250
+ phase3_enable_layout: bool = True
251
+ phase3_enable_chart: bool = True
252
+ phase3_enable_sam: bool = False # Optional, heavy
253
+
254
+ # Common
255
+ optimizer: str = "adamw"
256
+ scheduler: str = "cosine"
257
+ max_grad_norm: float = 1.0
258
+ fp16: bool = False
259
+ bf16: bool = True
260
+ gradient_checkpointing: bool = True
261
+
262
+
263
+ def get_hybrid_config() -> MRJEPAConfig:
264
+ """Get the Hybrid-main branch configuration."""
265
+ config = MRJEPAConfig(branch="hybrid")
266
+ # DINOv2-L backbone for strong dense features
267
+ config.visual.model_name = "facebook/dinov2-large"
268
+ config.visual.hidden_size = 1024
269
+ config.visual.image_size = 518
270
+ config.visual.patch_size = 14
271
+ return config
272
+
273
+
274
+ def get_purist_config() -> MRJEPAConfig:
275
+ """
276
+ Get the Purist-side branch configuration.
277
+ Closer to LeWorldModel: smaller backbone, stronger JEPA emphasis.
278
+ """
279
+ config = MRJEPAConfig(branch="purist")
280
+ # Smaller backbone, more emphasis on JEPA dynamics
281
+ config.visual.model_name = "facebook/dinov2-base"
282
+ config.visual.hidden_size = 768
283
+ config.visual.image_size = 518
284
+ config.visual.patch_size = 14
285
+
286
+ # Larger rollout to compensate for weaker perception
287
+ config.rollout.K = 5
288
+ config.rollout.num_state_tokens = 48
289
+ config.rollout.num_predictor_layers = 8
290
+
291
+ # Stronger JEPA objective
292
+ config.jepa.jepa_loss_weight = 2.0
293
+ config.jepa.task_loss_weight = 1.0
294
+ config.jepa.sigreg_weight = 0.2
295
+
296
+ # No enriched evidence (pure JEPA reasoning)
297
+ config.evidence.use_ocr_tokens = False
298
+ config.evidence.use_layout_tokens = False
299
+ config.evidence.use_chart_tokens = False
300
+ config.evidence.use_sam_tokens = False
301
+
302
+ # Smaller text encoder
303
+ config.text.model_name = "microsoft/deberta-v3-base"
304
+ config.text.hidden_size = 768
305
+
306
+ return config
mr_jepa/data/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .unified_dataset import UnifiedBenchmarkDataset, BenchmarkType
2
+ from .data_utils import build_dataloader, get_benchmark_config
3
+
4
+ __all__ = [
5
+ "UnifiedBenchmarkDataset",
6
+ "BenchmarkType",
7
+ "build_dataloader",
8
+ "get_benchmark_config",
9
+ ]
mr_jepa/data/data_utils.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data utilities for MR-JEPA.
3
+
4
+ Includes:
5
+ - Collator that handles variable-length options, multi-image samples
6
+ - Dataloader factory
7
+ - Benchmark configuration helpers
8
+ """
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch.utils.data import DataLoader
13
+ from typing import Optional, Dict, List, Any, Tuple
14
+ from PIL import Image
15
+ import numpy as np
16
+
17
+ from .unified_dataset import UnifiedBenchmarkDataset, BenchmarkSample, BenchmarkType
18
+
19
+
20
+ BENCHMARK_CONFIGS = {
21
+ 'mmmu': {
22
+ 'repo_id': 'MMMU/MMMU',
23
+ 'eval_split': 'validation',
24
+ 'metric': 'accuracy',
25
+ 'answer_type': 'mc',
26
+ 'configs': [
27
+ 'Accounting', 'Agriculture', 'Architecture_and_Engineering',
28
+ 'Art', 'Art_Theory', 'Basic_Medical_Science', 'Biology',
29
+ 'Chemistry', 'Clinical_Medicine', 'Computer_Science',
30
+ 'Design', 'Diagnostics_and_Laboratory_Medicine', 'Economics',
31
+ 'Electronics', 'Energy_and_Power', 'Finance', 'Geography',
32
+ 'History', 'Literature', 'Manage', 'Marketing',
33
+ 'Materials', 'Math', 'Mechanical_Engineering', 'Music',
34
+ 'Pharmacy', 'Physics', 'Psychology', 'Public_Health',
35
+ 'Sociology'
36
+ ],
37
+ },
38
+ 'mathvista': {
39
+ 'repo_id': 'AI4Math/MathVista',
40
+ 'eval_split': 'testmini',
41
+ 'metric': 'accuracy',
42
+ 'answer_type': 'mixed',
43
+ },
44
+ 'scienceqa': {
45
+ 'repo_id': 'derek-thomas/ScienceQA',
46
+ 'eval_split': 'test',
47
+ 'train_split': 'train',
48
+ 'metric': 'accuracy',
49
+ 'answer_type': 'mc',
50
+ },
51
+ 'ai2d': {
52
+ 'repo_id': 'lmms-lab/ai2d',
53
+ 'eval_split': 'test',
54
+ 'metric': 'accuracy',
55
+ 'answer_type': 'mc',
56
+ },
57
+ 'mmbench': {
58
+ 'repo_id': 'lmms-lab/MMBench',
59
+ 'eval_split': 'dev',
60
+ 'metric': 'accuracy',
61
+ 'answer_type': 'mc',
62
+ },
63
+ 'mmstar': {
64
+ 'repo_id': 'Lin-Chen/MMStar',
65
+ 'eval_split': 'val',
66
+ 'metric': 'accuracy',
67
+ 'answer_type': 'mc',
68
+ },
69
+ 'docvqa': {
70
+ 'repo_id': 'lmms-lab/DocVQA',
71
+ 'eval_split': 'validation',
72
+ 'metric': 'anls',
73
+ 'answer_type': 'open',
74
+ },
75
+ 'textvqa': {
76
+ 'repo_id': 'lmms-lab/textvqa',
77
+ 'eval_split': 'validation',
78
+ 'metric': 'vqa_accuracy',
79
+ 'answer_type': 'open',
80
+ },
81
+ 'chartqa': {
82
+ 'repo_id': 'lmms-lab/ChartQA',
83
+ 'eval_split': 'test',
84
+ 'metric': 'relaxed_accuracy',
85
+ 'answer_type': 'open',
86
+ },
87
+ }
88
+
89
+
90
+ def get_benchmark_config(benchmark: str) -> Dict:
91
+ """Get benchmark configuration."""
92
+ return BENCHMARK_CONFIGS[benchmark]
93
+
94
+
95
+ class MRJEPACollator:
96
+ """
97
+ Collator for MR-JEPA that handles:
98
+ - Variable number of images per sample (MMMU)
99
+ - Variable number of answer options
100
+ - Mixed MC/open-ended questions
101
+ - Image preprocessing via backbone processor
102
+ - Text tokenization
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ image_processor,
108
+ text_tokenizer,
109
+ max_options: int = 8,
110
+ max_text_length: int = 256,
111
+ max_gen_length: int = 64,
112
+ image_size: int = 518,
113
+ ):
114
+ self.image_processor = image_processor
115
+ self.text_tokenizer = text_tokenizer
116
+ self.max_options = max_options
117
+ self.max_text_length = max_text_length
118
+ self.max_gen_length = max_gen_length
119
+ self.image_size = image_size
120
+
121
+ def __call__(self, batch: List[BenchmarkSample]) -> Dict[str, torch.Tensor]:
122
+ """Collate a batch of BenchmarkSamples."""
123
+ B = len(batch)
124
+
125
+ # ==================== Images ====================
126
+ # Use first image for now (multi-image MMMU handled separately)
127
+ images = []
128
+ for sample in batch:
129
+ img = sample.images[0]
130
+ if not isinstance(img, Image.Image):
131
+ img = Image.new('RGB', (self.image_size, self.image_size), 'white')
132
+ images.append(img.convert('RGB'))
133
+
134
+ # Process images through backbone processor
135
+ pixel_values = self.image_processor(
136
+ images=images,
137
+ return_tensors='pt',
138
+ )['pixel_values'] # [B, C, H, W]
139
+
140
+ # ==================== Question Text ====================
141
+ questions = [s.question for s in batch]
142
+ text_encoded = self.text_tokenizer(
143
+ questions,
144
+ padding='max_length',
145
+ truncation=True,
146
+ max_length=self.max_text_length,
147
+ return_tensors='pt',
148
+ )
149
+
150
+ # ==================== Options (MC) ====================
151
+ # Encode each option separately, pad to max_options
152
+ option_embeddings_list = []
153
+ option_masks = []
154
+ answer_labels = []
155
+
156
+ has_mc = any(s.options is not None for s in batch)
157
+
158
+ if has_mc:
159
+ for sample in batch:
160
+ if sample.options:
161
+ n_opts = min(len(sample.options), self.max_options)
162
+ # Tokenize options
163
+ opts_text = sample.options[:n_opts]
164
+ # Pad option text list to max_options
165
+ while len(opts_text) < self.max_options:
166
+ opts_text.append("")
167
+
168
+ mask = [True] * n_opts + [False] * (self.max_options - n_opts)
169
+ option_masks.append(mask)
170
+
171
+ # Answer label
172
+ if isinstance(sample.answer, int):
173
+ answer_labels.append(min(sample.answer, n_opts - 1))
174
+ elif isinstance(sample.answer, str) and len(sample.answer) == 1:
175
+ answer_labels.append(ord(sample.answer.upper()) - ord('A'))
176
+ else:
177
+ answer_labels.append(0)
178
+ else:
179
+ option_masks.append([False] * self.max_options)
180
+ answer_labels.append(0)
181
+
182
+ # ==================== Open-ended answers ====================
183
+ gen_target_ids = None
184
+ has_open = any(s.answer_type == 'open' for s in batch)
185
+
186
+ if has_open:
187
+ # Prepare generative targets
188
+ gen_texts = []
189
+ for sample in batch:
190
+ if sample.answer_type == 'open':
191
+ if isinstance(sample.answer, list):
192
+ gen_texts.append(str(sample.answer[0]))
193
+ else:
194
+ gen_texts.append(str(sample.answer))
195
+ else:
196
+ gen_texts.append("")
197
+
198
+ gen_encoded = self.text_tokenizer(
199
+ gen_texts,
200
+ padding='max_length',
201
+ truncation=True,
202
+ max_length=self.max_gen_length,
203
+ return_tensors='pt',
204
+ )
205
+ gen_target_ids = gen_encoded['input_ids']
206
+
207
+ # ==================== Build output dict ====================
208
+ result = {
209
+ 'pixel_values': pixel_values,
210
+ 'input_ids': text_encoded['input_ids'],
211
+ 'attention_mask': text_encoded['attention_mask'],
212
+ }
213
+
214
+ if has_mc:
215
+ result['option_mask'] = torch.tensor(option_masks, dtype=torch.bool)
216
+ result['answer_labels'] = torch.tensor(answer_labels, dtype=torch.long)
217
+
218
+ # We need to encode options through text encoder at runtime
219
+ # Store raw option texts for the model to encode
220
+ all_option_texts = []
221
+ for sample in batch:
222
+ opts = sample.options or [""] * self.max_options
223
+ opts = opts[:self.max_options]
224
+ while len(opts) < self.max_options:
225
+ opts.append("")
226
+ all_option_texts.append(opts)
227
+ result['option_texts'] = all_option_texts
228
+
229
+ if gen_target_ids is not None:
230
+ result['gen_target_ids'] = gen_target_ids
231
+
232
+ # Metadata
233
+ result['benchmarks'] = [s.benchmark for s in batch]
234
+ result['answer_types'] = [s.answer_type for s in batch]
235
+ result['raw_answers'] = [s.answer for s in batch]
236
+
237
+ return result
238
+
239
+
240
+ def build_dataloader(
241
+ benchmark: str,
242
+ split: str,
243
+ image_processor,
244
+ text_tokenizer,
245
+ batch_size: int = 32,
246
+ num_workers: int = 4,
247
+ max_samples: Optional[int] = None,
248
+ config: Optional[str] = None,
249
+ **collator_kwargs,
250
+ ) -> DataLoader:
251
+ """Build a DataLoader for a specific benchmark."""
252
+ dataset = UnifiedBenchmarkDataset(
253
+ benchmark=benchmark,
254
+ split=split,
255
+ config=config,
256
+ max_samples=max_samples,
257
+ )
258
+
259
+ collator = MRJEPACollator(
260
+ image_processor=image_processor,
261
+ text_tokenizer=text_tokenizer,
262
+ **collator_kwargs,
263
+ )
264
+
265
+ return DataLoader(
266
+ dataset,
267
+ batch_size=batch_size,
268
+ shuffle=(split in ('train', 'training')),
269
+ num_workers=num_workers,
270
+ collate_fn=collator,
271
+ pin_memory=True,
272
+ drop_last=(split in ('train', 'training')),
273
+ )
mr_jepa/data/unified_dataset.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified Dataset for MR-JEPA Benchmarks.
3
+
4
+ Handles all 9 benchmarks with their quirky formats in a single pipeline:
5
+
6
+ MC Benchmarks:
7
+ - MMMU: up to 7 images, string-encoded options, letter answers
8
+ - ScienceQA: nullable images, int8 answer index
9
+ - AI2D: string-encoded int index answer
10
+ - MMBench: separate A/B/C/D columns
11
+ - MMStar: options embedded in question text
12
+
13
+ Open-Ended Benchmarks:
14
+ - MathVista: mixed MC/free-form, dual image columns
15
+ - DocVQA: multiple valid answers (ANLS metric)
16
+ - TextVQA: 10 annotations (VQA Accuracy)
17
+ - ChartQA: relaxed numeric accuracy
18
+
19
+ Each sample is normalized to a common format:
20
+ {
21
+ 'image': PIL.Image or List[PIL.Image],
22
+ 'question': str,
23
+ 'options': List[str] or None, # None for open-ended
24
+ 'answer': str or int, # Correct answer
25
+ 'answer_type': 'mc' or 'open',
26
+ 'benchmark': str,
27
+ 'metadata': dict,
28
+ }
29
+ """
30
+
31
+ import ast
32
+ import re
33
+ import torch
34
+ from torch.utils.data import Dataset
35
+ from PIL import Image
36
+ from enum import Enum
37
+ from typing import Optional, Dict, List, Any, Tuple
38
+ from dataclasses import dataclass
39
+
40
+
41
+ class BenchmarkType(Enum):
42
+ MMMU = "mmmu"
43
+ MATHVISTA = "mathvista"
44
+ SCIENCEQA = "scienceqa"
45
+ AI2D = "ai2d"
46
+ MMBENCH = "mmbench"
47
+ MMSTAR = "mmstar"
48
+ DOCVQA = "docvqa"
49
+ TEXTVQA = "textvqa"
50
+ CHARTQA = "chartqa"
51
+
52
+
53
+ @dataclass
54
+ class BenchmarkSample:
55
+ """Normalized sample format across all benchmarks."""
56
+ images: List[Image.Image] # 1+ images (MMMU can have up to 7)
57
+ question: str
58
+ options: Optional[List[str]] # None for open-ended
59
+ answer: Any # str (letter/text) or int (index)
60
+ answer_type: str # 'mc' or 'open'
61
+ benchmark: str
62
+ metadata: Dict[str, Any]
63
+
64
+
65
+ class UnifiedBenchmarkDataset(Dataset):
66
+ """
67
+ Unified dataset that loads any of the 9 benchmarks into a common format.
68
+
69
+ Usage:
70
+ dataset = UnifiedBenchmarkDataset(
71
+ benchmark='mmmu',
72
+ split='validation',
73
+ config='Accounting', # MMMU has per-subject configs
74
+ )
75
+ sample = dataset[0] # Returns BenchmarkSample
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ benchmark: str,
81
+ split: str = "validation",
82
+ config: Optional[str] = None,
83
+ max_samples: Optional[int] = None,
84
+ transform: Optional[Any] = None,
85
+ ):
86
+ self.benchmark = BenchmarkType(benchmark)
87
+ self.split = split
88
+ self.transform = transform
89
+
90
+ # Load dataset
91
+ self.data = self._load_dataset(config, max_samples)
92
+
93
+ def _load_dataset(self, config: Optional[str], max_samples: Optional[int]):
94
+ """Load dataset from HuggingFace Hub."""
95
+ from datasets import load_dataset
96
+
97
+ repo_map = {
98
+ BenchmarkType.MMMU: "MMMU/MMMU",
99
+ BenchmarkType.MATHVISTA: "AI4Math/MathVista",
100
+ BenchmarkType.SCIENCEQA: "derek-thomas/ScienceQA",
101
+ BenchmarkType.AI2D: "lmms-lab/ai2d",
102
+ BenchmarkType.MMBENCH: "lmms-lab/MMBench",
103
+ BenchmarkType.MMSTAR: "Lin-Chen/MMStar",
104
+ BenchmarkType.DOCVQA: "lmms-lab/DocVQA",
105
+ BenchmarkType.TEXTVQA: "lmms-lab/textvqa",
106
+ BenchmarkType.CHARTQA: "lmms-lab/ChartQA",
107
+ }
108
+
109
+ repo_id = repo_map[self.benchmark]
110
+
111
+ # Handle config/split variations
112
+ kwargs = {}
113
+ if config:
114
+ kwargs['name'] = config
115
+ elif self.benchmark == BenchmarkType.MMBENCH:
116
+ kwargs['name'] = 'en'
117
+ elif self.benchmark == BenchmarkType.DOCVQA:
118
+ kwargs['name'] = 'DocVQA'
119
+
120
+ # Some datasets have different split names
121
+ split_name = self.split
122
+ if self.benchmark == BenchmarkType.MMSTAR and self.split == 'validation':
123
+ split_name = 'val'
124
+
125
+ try:
126
+ ds = load_dataset(repo_id, split=split_name, **kwargs)
127
+ except Exception as e:
128
+ # Fallback: try without config
129
+ print(f"Warning: Failed to load {repo_id} with config={config}, split={split_name}: {e}")
130
+ ds = load_dataset(repo_id, split=split_name)
131
+
132
+ if max_samples:
133
+ ds = ds.select(range(min(max_samples, len(ds))))
134
+
135
+ return ds
136
+
137
+ def __len__(self):
138
+ return len(self.data)
139
+
140
+ def __getitem__(self, idx: int) -> BenchmarkSample:
141
+ row = self.data[idx]
142
+
143
+ # Dispatch to benchmark-specific parser
144
+ parser_map = {
145
+ BenchmarkType.MMMU: self._parse_mmmu,
146
+ BenchmarkType.MATHVISTA: self._parse_mathvista,
147
+ BenchmarkType.SCIENCEQA: self._parse_scienceqa,
148
+ BenchmarkType.AI2D: self._parse_ai2d,
149
+ BenchmarkType.MMBENCH: self._parse_mmbench,
150
+ BenchmarkType.MMSTAR: self._parse_mmstar,
151
+ BenchmarkType.DOCVQA: self._parse_docvqa,
152
+ BenchmarkType.TEXTVQA: self._parse_textvqa,
153
+ BenchmarkType.CHARTQA: self._parse_chartqa,
154
+ }
155
+
156
+ return parser_map[self.benchmark](row)
157
+
158
+ # ==================== Benchmark-Specific Parsers ====================
159
+
160
+ def _parse_mmmu(self, row) -> BenchmarkSample:
161
+ """MMMU: up to 7 images, string-encoded options."""
162
+ images = []
163
+ for i in range(1, 8):
164
+ img = row.get(f'image_{i}')
165
+ if img is not None:
166
+ if isinstance(img, Image.Image):
167
+ images.append(img)
168
+
169
+ if not images:
170
+ # Create a blank image as fallback
171
+ images = [Image.new('RGB', (224, 224), color='white')]
172
+
173
+ # Parse options (string-encoded Python list)
174
+ options_str = row.get('options', '[]')
175
+ try:
176
+ options = ast.literal_eval(options_str) if isinstance(options_str, str) else options_str
177
+ except (ValueError, SyntaxError):
178
+ options = []
179
+
180
+ question = row['question']
181
+ answer = row.get('answer', 'A')
182
+
183
+ return BenchmarkSample(
184
+ images=images,
185
+ question=question,
186
+ options=options if options else None,
187
+ answer=answer,
188
+ answer_type='mc' if row.get('question_type', 'multiple-choice') == 'multiple-choice' else 'open',
189
+ benchmark='mmmu',
190
+ metadata={
191
+ 'id': row.get('id', ''),
192
+ 'subject': row.get('subfield', ''),
193
+ 'difficulty': row.get('topic_difficulty', ''),
194
+ 'img_type': row.get('img_type', ''),
195
+ }
196
+ )
197
+
198
+ def _parse_mathvista(self, row) -> BenchmarkSample:
199
+ """MathVista: mixed MC/free-form, use decoded_image."""
200
+ image = row.get('decoded_image') or row.get('image')
201
+ if isinstance(image, str):
202
+ # It's a path, not an image — this shouldn't happen with decoded_image
203
+ image = Image.new('RGB', (224, 224), color='white')
204
+ images = [image] if image else [Image.new('RGB', (224, 224), color='white')]
205
+
206
+ question = row.get('query', row.get('question', ''))
207
+ choices = row.get('choices', None)
208
+ answer = row.get('answer', '')
209
+ qtype = row.get('question_type', 'free_form')
210
+
211
+ return BenchmarkSample(
212
+ images=images,
213
+ question=question,
214
+ options=list(choices) if choices else None,
215
+ answer=answer,
216
+ answer_type='mc' if qtype == 'multi_choice' else 'open',
217
+ benchmark='mathvista',
218
+ metadata={
219
+ 'pid': row.get('pid', ''),
220
+ 'answer_type': row.get('answer_type', ''),
221
+ 'unit': row.get('unit', ''),
222
+ }
223
+ )
224
+
225
+ def _parse_scienceqa(self, row) -> BenchmarkSample:
226
+ """ScienceQA: nullable images, int8 answer index."""
227
+ image = row.get('image')
228
+ if image is None:
229
+ images = [Image.new('RGB', (224, 224), color='white')]
230
+ has_image = False
231
+ else:
232
+ images = [image]
233
+ has_image = True
234
+
235
+ choices = row.get('choices', [])
236
+ answer_idx = int(row.get('answer', 0))
237
+
238
+ return BenchmarkSample(
239
+ images=images,
240
+ question=row['question'],
241
+ options=list(choices),
242
+ answer=answer_idx, # 0-indexed integer
243
+ answer_type='mc',
244
+ benchmark='scienceqa',
245
+ metadata={
246
+ 'has_image': has_image,
247
+ 'subject': row.get('subject', ''),
248
+ 'grade': row.get('grade', ''),
249
+ 'hint': row.get('hint', ''),
250
+ 'lecture': row.get('lecture', ''),
251
+ 'solution': row.get('solution', ''),
252
+ }
253
+ )
254
+
255
+ def _parse_ai2d(self, row) -> BenchmarkSample:
256
+ """AI2D: string-encoded int index answer."""
257
+ images = [row['image']]
258
+ options = list(row.get('options', []))
259
+ answer_idx = int(row.get('answer', '0'))
260
+
261
+ return BenchmarkSample(
262
+ images=images,
263
+ question=row['question'],
264
+ options=options,
265
+ answer=answer_idx, # 0-indexed integer
266
+ answer_type='mc',
267
+ benchmark='ai2d',
268
+ metadata={}
269
+ )
270
+
271
+ def _parse_mmbench(self, row) -> BenchmarkSample:
272
+ """MMBench: separate A/B/C/D columns."""
273
+ images = [row['image']]
274
+
275
+ # Build options from separate columns
276
+ options = []
277
+ for letter in ['A', 'B', 'C', 'D']:
278
+ opt = row.get(letter, '')
279
+ if opt:
280
+ options.append(opt)
281
+
282
+ # Answer is a letter
283
+ answer = row.get('answer', 'A')
284
+ # Convert letter to index
285
+ answer_idx = ord(answer) - ord('A') if isinstance(answer, str) and len(answer) == 1 else 0
286
+
287
+ return BenchmarkSample(
288
+ images=images,
289
+ question=row['question'],
290
+ options=options,
291
+ answer=answer_idx,
292
+ answer_type='mc',
293
+ benchmark='mmbench',
294
+ metadata={
295
+ 'category': row.get('category', ''),
296
+ 'hint': row.get('hint', ''),
297
+ }
298
+ )
299
+
300
+ def _parse_mmstar(self, row) -> BenchmarkSample:
301
+ """MMStar: options embedded in question text."""
302
+ images = [row['image']]
303
+ question = row['question']
304
+
305
+ # Parse options from question text
306
+ # Format: "... Options: A: ..., B: ..., C: ..., D: ..."
307
+ options = []
308
+ option_pattern = r'([A-D]):\s*([^,\n]+(?:,\s*[^A-D\n][^,\n]*)*)'
309
+ matches = re.findall(option_pattern, question)
310
+ if matches:
311
+ for letter, text in matches:
312
+ options.append(text.strip())
313
+
314
+ answer = row.get('answer', 'A')
315
+ answer_idx = ord(answer) - ord('A') if isinstance(answer, str) and len(answer) == 1 else 0
316
+
317
+ return BenchmarkSample(
318
+ images=images,
319
+ question=question,
320
+ options=options if options else None,
321
+ answer=answer_idx,
322
+ answer_type='mc',
323
+ benchmark='mmstar',
324
+ metadata={
325
+ 'category': row.get('category', ''),
326
+ 'l2_category': row.get('l2_category', ''),
327
+ }
328
+ )
329
+
330
+ def _parse_docvqa(self, row) -> BenchmarkSample:
331
+ """DocVQA: multiple valid answers."""
332
+ images = [row['image']]
333
+ answers = row.get('answers', [''])
334
+
335
+ return BenchmarkSample(
336
+ images=images,
337
+ question=row['question'],
338
+ options=None,
339
+ answer=answers, # List of valid answers
340
+ answer_type='open',
341
+ benchmark='docvqa',
342
+ metadata={
343
+ 'question_id': row.get('questionId', ''),
344
+ 'question_types': row.get('question_types', []),
345
+ }
346
+ )
347
+
348
+ def _parse_textvqa(self, row) -> BenchmarkSample:
349
+ """TextVQA: 10 annotations."""
350
+ images = [row['image']]
351
+ answers = row.get('answers', [''])
352
+
353
+ return BenchmarkSample(
354
+ images=images,
355
+ question=row['question'],
356
+ options=None,
357
+ answer=answers, # 10 annotations
358
+ answer_type='open',
359
+ benchmark='textvqa',
360
+ metadata={
361
+ 'question_id': row.get('question_id', ''),
362
+ 'ocr_tokens': row.get('ocr_tokens', []),
363
+ }
364
+ )
365
+
366
+ def _parse_chartqa(self, row) -> BenchmarkSample:
367
+ """ChartQA: relaxed numeric accuracy."""
368
+ images = [row['image']]
369
+
370
+ return BenchmarkSample(
371
+ images=images,
372
+ question=row['question'],
373
+ options=None,
374
+ answer=row.get('answer', ''),
375
+ answer_type='open',
376
+ benchmark='chartqa',
377
+ metadata={
378
+ 'type': row.get('type', ''),
379
+ }
380
+ )
mr_jepa/evaluation/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .metrics import (
2
+ compute_accuracy,
3
+ compute_anls,
4
+ compute_vqa_accuracy,
5
+ compute_relaxed_accuracy,
6
+ evaluate_benchmark,
7
+ )
8
+
9
+ __all__ = [
10
+ "compute_accuracy",
11
+ "compute_anls",
12
+ "compute_vqa_accuracy",
13
+ "compute_relaxed_accuracy",
14
+ "evaluate_benchmark",
15
+ ]
mr_jepa/evaluation/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (332 Bytes). View file
 
mr_jepa/evaluation/__pycache__/metrics.cpython-312.pyc ADDED
Binary file (10.1 kB). View file
 
mr_jepa/evaluation/metrics.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation Metrics for MR-JEPA Benchmarks.
3
+
4
+ Each benchmark has specific evaluation protocols:
5
+ - Accuracy: MMMU, ScienceQA, AI2D, MMBench, MMStar
6
+ - ANLS: DocVQA (Average Normalized Levenshtein Similarity)
7
+ - VQA Accuracy: TextVQA (soft majority over 10 annotations)
8
+ - Relaxed Accuracy: ChartQA (±5% tolerance for numerics)
9
+ - Mixed: MathVista (accuracy for MC, relaxed match for free-form)
10
+ """
11
+
12
+ import re
13
+ import torch
14
+ import numpy as np
15
+ from typing import List, Dict, Optional, Any, Union
16
+ from collections import defaultdict
17
+
18
+
19
+ def compute_accuracy(
20
+ predictions: List[int],
21
+ ground_truth: List[int],
22
+ category_labels: Optional[List[str]] = None,
23
+ ) -> Dict[str, float]:
24
+ """
25
+ Standard accuracy for MC benchmarks.
26
+
27
+ Args:
28
+ predictions: Predicted option indices
29
+ ground_truth: Correct option indices
30
+ category_labels: Optional per-sample categories for breakdown
31
+
32
+ Returns:
33
+ Dict with 'accuracy' and optional per-category breakdown
34
+ """
35
+ assert len(predictions) == len(ground_truth)
36
+
37
+ correct = sum(p == g for p, g in zip(predictions, ground_truth))
38
+ total = len(predictions)
39
+
40
+ result = {'accuracy': correct / max(total, 1) * 100}
41
+
42
+ # Per-category breakdown
43
+ if category_labels:
44
+ cat_correct = defaultdict(int)
45
+ cat_total = defaultdict(int)
46
+ for p, g, c in zip(predictions, ground_truth, category_labels):
47
+ cat_total[c] += 1
48
+ if p == g:
49
+ cat_correct[c] += 1
50
+
51
+ result['per_category'] = {
52
+ c: cat_correct[c] / max(cat_total[c], 1) * 100
53
+ for c in sorted(cat_total.keys())
54
+ }
55
+
56
+ return result
57
+
58
+
59
+ def _normalized_levenshtein(s1: str, s2: str) -> float:
60
+ """Compute normalized Levenshtein distance between two strings."""
61
+ s1 = s1.lower().strip()
62
+ s2 = s2.lower().strip()
63
+
64
+ if s1 == s2:
65
+ return 0.0
66
+
67
+ len1, len2 = len(s1), len(s2)
68
+ if len1 == 0 or len2 == 0:
69
+ return 1.0
70
+
71
+ # Dynamic programming Levenshtein
72
+ matrix = [[0] * (len2 + 1) for _ in range(len1 + 1)]
73
+ for i in range(len1 + 1):
74
+ matrix[i][0] = i
75
+ for j in range(len2 + 1):
76
+ matrix[0][j] = j
77
+
78
+ for i in range(1, len1 + 1):
79
+ for j in range(1, len2 + 1):
80
+ cost = 0 if s1[i-1] == s2[j-1] else 1
81
+ matrix[i][j] = min(
82
+ matrix[i-1][j] + 1,
83
+ matrix[i][j-1] + 1,
84
+ matrix[i-1][j-1] + cost,
85
+ )
86
+
87
+ return matrix[len1][len2] / max(len1, len2)
88
+
89
+
90
+ def compute_anls(
91
+ predictions: List[str],
92
+ ground_truths: List[List[str]],
93
+ threshold: float = 0.5,
94
+ ) -> Dict[str, float]:
95
+ """
96
+ Average Normalized Levenshtein Similarity (ANLS) for DocVQA.
97
+
98
+ ANLS = 1 - NL_distance if NL_distance < threshold, else 0
99
+ Final score is max over all valid answers.
100
+
101
+ Args:
102
+ predictions: List of predicted answer strings
103
+ ground_truths: List of lists of valid answer strings
104
+ threshold: NL distance threshold (default 0.5)
105
+ """
106
+ scores = []
107
+ for pred, gts in zip(predictions, ground_truths):
108
+ if not gts:
109
+ scores.append(0.0)
110
+ continue
111
+
112
+ # Take max ANLS over all valid answers
113
+ max_score = 0.0
114
+ for gt in gts:
115
+ nl_dist = _normalized_levenshtein(pred, gt)
116
+ if nl_dist < threshold:
117
+ score = 1.0 - nl_dist
118
+ else:
119
+ score = 0.0
120
+ max_score = max(max_score, score)
121
+
122
+ scores.append(max_score)
123
+
124
+ return {'anls': np.mean(scores) * 100 if scores else 0.0}
125
+
126
+
127
+ def compute_vqa_accuracy(
128
+ predictions: List[str],
129
+ ground_truths: List[List[str]],
130
+ ) -> Dict[str, float]:
131
+ """
132
+ VQA Accuracy for TextVQA.
133
+
134
+ score = min(count(matching annotations) / 3, 1.0)
135
+
136
+ Args:
137
+ predictions: Predicted answers
138
+ ground_truths: Lists of 10 human annotations per question
139
+ """
140
+ scores = []
141
+ for pred, gts in zip(predictions, ground_truths):
142
+ pred_norm = pred.lower().strip()
143
+ matching = sum(1 for gt in gts if gt.lower().strip() == pred_norm)
144
+ score = min(matching / 3.0, 1.0)
145
+ scores.append(score)
146
+
147
+ return {'vqa_accuracy': np.mean(scores) * 100 if scores else 0.0}
148
+
149
+
150
+ def _is_numeric(s: str) -> bool:
151
+ """Check if string represents a number."""
152
+ try:
153
+ float(s.replace(',', '').replace('%', '').strip())
154
+ return True
155
+ except (ValueError, AttributeError):
156
+ return False
157
+
158
+
159
+ def _parse_numeric(s: str) -> float:
160
+ """Parse numeric value from string."""
161
+ s = s.replace(',', '').replace('%', '').strip()
162
+ return float(s)
163
+
164
+
165
+ def compute_relaxed_accuracy(
166
+ predictions: List[str],
167
+ ground_truths: List[str],
168
+ tolerance: float = 0.05,
169
+ types: Optional[List[str]] = None,
170
+ ) -> Dict[str, float]:
171
+ """
172
+ Relaxed Accuracy for ChartQA.
173
+
174
+ - Numeric answers: within ±5% tolerance
175
+ - String answers: exact match (case-insensitive)
176
+
177
+ Args:
178
+ predictions: Predicted answers
179
+ ground_truths: Ground truth answers
180
+ tolerance: Numeric tolerance (default 5%)
181
+ types: Optional list of 'human_test'/'augmented_test' for breakdown
182
+ """
183
+ correct = []
184
+ for pred, gt in zip(predictions, ground_truths):
185
+ pred_str = str(pred).strip().lower()
186
+ gt_str = str(gt).strip().lower()
187
+
188
+ if _is_numeric(gt_str) and _is_numeric(pred_str):
189
+ gt_val = _parse_numeric(gt_str)
190
+ pred_val = _parse_numeric(pred_str)
191
+ if gt_val == 0:
192
+ is_correct = abs(pred_val) <= tolerance
193
+ else:
194
+ is_correct = abs(pred_val - gt_val) / abs(gt_val) <= tolerance
195
+ else:
196
+ is_correct = pred_str == gt_str
197
+
198
+ correct.append(is_correct)
199
+
200
+ result = {'relaxed_accuracy': np.mean(correct) * 100 if correct else 0.0}
201
+
202
+ # Per-type breakdown (human vs augmented)
203
+ if types:
204
+ for t in set(types):
205
+ type_correct = [c for c, tp in zip(correct, types) if tp == t]
206
+ result[f'relaxed_accuracy_{t}'] = np.mean(type_correct) * 100 if type_correct else 0.0
207
+
208
+ return result
209
+
210
+
211
+ def evaluate_benchmark(
212
+ benchmark: str,
213
+ predictions: List[Any],
214
+ ground_truths: List[Any],
215
+ metadata: Optional[Dict[str, List]] = None,
216
+ ) -> Dict[str, float]:
217
+ """
218
+ Evaluate predictions for a specific benchmark.
219
+
220
+ Dispatches to the appropriate metric function.
221
+ """
222
+ metric_map = {
223
+ 'mmmu': 'accuracy',
224
+ 'scienceqa': 'accuracy',
225
+ 'ai2d': 'accuracy',
226
+ 'mmbench': 'accuracy',
227
+ 'mmstar': 'accuracy',
228
+ 'mathvista': 'accuracy', # Simplified; full eval handles mixed types
229
+ 'docvqa': 'anls',
230
+ 'textvqa': 'vqa_accuracy',
231
+ 'chartqa': 'relaxed_accuracy',
232
+ }
233
+
234
+ metric = metric_map.get(benchmark, 'accuracy')
235
+
236
+ if metric == 'accuracy':
237
+ categories = metadata.get('categories') if metadata else None
238
+ return compute_accuracy(predictions, ground_truths, categories)
239
+
240
+ elif metric == 'anls':
241
+ return compute_anls(predictions, ground_truths)
242
+
243
+ elif metric == 'vqa_accuracy':
244
+ return compute_vqa_accuracy(predictions, ground_truths)
245
+
246
+ elif metric == 'relaxed_accuracy':
247
+ types = metadata.get('types') if metadata else None
248
+ return compute_relaxed_accuracy(predictions, ground_truths, types=types)
249
+
250
+ else:
251
+ raise ValueError(f"Unknown metric: {metric}")
mr_jepa/models/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .mr_jepa import MRJEPAModel
2
+ from .evidence_memory import EvidenceMemory
3
+ from .latent_rollout import LatentRolloutModule
4
+ from .answer_heads import DiscriminativeHead, GenerativeHead
5
+ from .backbones import VisualBackbone, TextEncoder
6
+ from .target_encoder import TargetEncoder
7
+
8
+ __all__ = [
9
+ "MRJEPAModel",
10
+ "EvidenceMemory",
11
+ "LatentRolloutModule",
12
+ "DiscriminativeHead",
13
+ "GenerativeHead",
14
+ "VisualBackbone",
15
+ "TextEncoder",
16
+ "TargetEncoder",
17
+ ]
mr_jepa/models/__pycache__/answer_heads.cpython-312.pyc ADDED
Binary file (14.6 kB). View file
 
mr_jepa/models/__pycache__/evidence_memory.cpython-312.pyc ADDED
Binary file (14 kB). View file
 
mr_jepa/models/__pycache__/latent_rollout.cpython-312.pyc ADDED
Binary file (13 kB). View file
 
mr_jepa/models/__pycache__/target_encoder.cpython-312.pyc ADDED
Binary file (15 kB). View file
 
mr_jepa/models/answer_heads.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Answer Prediction Heads for MR-JEPA.
3
+
4
+ Two heads:
5
+ 1. Discriminative Head (primary): Scores answer options for MC questions.
6
+ Takes the final latent state z_K and computes compatibility scores
7
+ with encoded answer option representations.
8
+
9
+ 2. Generative Head (secondary): Short text decoder for open-ended answers.
10
+ Small transformer decoder that cross-attends to the final latent state
11
+ and evidence memory, constrained to produce brief answers.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import math
18
+ from typing import Optional, Dict, Tuple
19
+
20
+ from ..configs.model_config import AnswerHeadConfig
21
+
22
+
23
+ class DiscriminativeHead(nn.Module):
24
+ """
25
+ Multiple-choice answer scoring head.
26
+
27
+ Architecture:
28
+ 1. Pool latent state z_K → global reasoning vector
29
+ 2. Encode each answer option via a small MLP
30
+ 3. Compute compatibility score: score_i = MLP(z_pool ⊙ opt_i)
31
+
32
+ Supports variable number of options (2-8, with masking).
33
+ """
34
+
35
+ def __init__(self, config: AnswerHeadConfig, hidden_dim: int, text_dim: int):
36
+ super().__init__()
37
+ self.config = config
38
+ self.hidden_dim = hidden_dim
39
+
40
+ # State pooling: attention-weighted pooling over state tokens
41
+ self.state_pool_query = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
42
+ self.state_pool_attn = nn.MultiheadAttention(
43
+ embed_dim=hidden_dim,
44
+ num_heads=8,
45
+ batch_first=True,
46
+ )
47
+ self.state_pool_norm = nn.LayerNorm(hidden_dim)
48
+
49
+ # Option encoder: project text option embeddings
50
+ self.option_proj = nn.Sequential(
51
+ nn.Linear(text_dim, hidden_dim),
52
+ nn.LayerNorm(hidden_dim),
53
+ nn.GELU(),
54
+ nn.Linear(hidden_dim, hidden_dim),
55
+ )
56
+
57
+ # Score computation: bilinear-style scoring
58
+ self.score_mlp = nn.Sequential(
59
+ nn.Linear(hidden_dim * 3, config.disc_hidden_dim),
60
+ nn.GELU(),
61
+ nn.Dropout(config.disc_dropout),
62
+ nn.Linear(config.disc_hidden_dim, config.disc_hidden_dim),
63
+ nn.GELU(),
64
+ nn.Dropout(config.disc_dropout),
65
+ nn.Linear(config.disc_hidden_dim, 1),
66
+ )
67
+
68
+ def _pool_state(self, z_final: torch.Tensor) -> torch.Tensor:
69
+ """
70
+ Attention-weighted pooling of final latent state.
71
+
72
+ Args:
73
+ z_final: [B, N_s, D]
74
+
75
+ Returns:
76
+ Pooled state vector [B, D]
77
+ """
78
+ B = z_final.size(0)
79
+ query = self.state_pool_query.expand(B, -1, -1) # [B, 1, D]
80
+ z_normed = self.state_pool_norm(z_final)
81
+ pooled, _ = self.state_pool_attn(query, z_normed, z_normed)
82
+ return pooled.squeeze(1) # [B, D]
83
+
84
+ def forward(
85
+ self,
86
+ z_final: torch.Tensor, # [B, N_s, D] final latent state
87
+ option_embeddings: torch.Tensor, # [B, max_opts, D_text] encoded options
88
+ option_mask: torch.Tensor, # [B, max_opts] bool: True=valid
89
+ ) -> Dict[str, torch.Tensor]:
90
+ """
91
+ Score answer options.
92
+
93
+ Returns:
94
+ dict with:
95
+ 'logits': [B, max_opts] raw scores
96
+ 'probs': [B, max_opts] masked softmax probabilities
97
+ """
98
+ B, max_opts = option_mask.shape
99
+
100
+ # Pool final latent state
101
+ z_pooled = self._pool_state(z_final) # [B, D]
102
+
103
+ # Project option embeddings
104
+ opt_proj = self.option_proj(option_embeddings) # [B, max_opts, D]
105
+
106
+ # Compute scores for each option
107
+ z_expanded = z_pooled.unsqueeze(1).expand(-1, max_opts, -1) # [B, max_opts, D]
108
+
109
+ # Concatenate: [z, opt, z⊙opt] for rich interaction
110
+ combined = torch.cat([
111
+ z_expanded,
112
+ opt_proj,
113
+ z_expanded * opt_proj, # Element-wise interaction
114
+ ], dim=-1) # [B, max_opts, 3*D]
115
+
116
+ logits = self.score_mlp(combined).squeeze(-1) # [B, max_opts]
117
+
118
+ # Mask invalid options
119
+ logits = logits.masked_fill(~option_mask, float('-inf'))
120
+ probs = F.softmax(logits, dim=-1)
121
+
122
+ return {
123
+ 'logits': logits,
124
+ 'probs': probs,
125
+ }
126
+
127
+
128
+ class GenerativeHead(nn.Module):
129
+ """
130
+ Short-answer generative decoder.
131
+
132
+ Small transformer decoder that:
133
+ 1. Cross-attends to the final latent state z_K
134
+ 2. Optionally cross-attends to evidence memory (evidence-constrained)
135
+ 3. Autoregressively generates a short answer (≤64 tokens)
136
+
137
+ This is a secondary objective — the primary evaluation uses the
138
+ discriminative head for MC questions.
139
+ """
140
+
141
+ def __init__(
142
+ self,
143
+ config: AnswerHeadConfig,
144
+ hidden_dim: int,
145
+ vocab_size: int,
146
+ ):
147
+ super().__init__()
148
+ self.config = config
149
+ self.hidden_dim = hidden_dim
150
+ self.vocab_size = vocab_size
151
+
152
+ # Token embedding + positional encoding
153
+ self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
154
+ self.pos_embedding = nn.Embedding(config.gen_max_answer_length, hidden_dim)
155
+
156
+ # Transformer decoder layers
157
+ self.decoder_layers = nn.ModuleList()
158
+ for _ in range(config.gen_num_layers):
159
+ self.decoder_layers.append(
160
+ GenerativeDecoderLayer(
161
+ hidden_dim=hidden_dim,
162
+ num_heads=config.gen_num_heads,
163
+ dropout=config.gen_dropout,
164
+ use_evidence_cross_attn=config.use_evidence_constraint,
165
+ )
166
+ )
167
+
168
+ # Output projection to vocabulary
169
+ self.output_norm = nn.LayerNorm(hidden_dim)
170
+ self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
171
+
172
+ # Tie weights with token embedding
173
+ self.lm_head.weight = self.token_embedding.weight
174
+
175
+ def forward(
176
+ self,
177
+ z_final: torch.Tensor, # [B, N_s, D]
178
+ target_ids: torch.Tensor, # [B, seq_len]
179
+ evidence_tokens: Optional[torch.Tensor] = None, # [B, N_e, D]
180
+ evidence_mask: Optional[torch.Tensor] = None,
181
+ ) -> Dict[str, torch.Tensor]:
182
+ """
183
+ Teacher-forced forward pass for training.
184
+
185
+ Args:
186
+ z_final: Final latent state from rollout
187
+ target_ids: Target answer token IDs
188
+ evidence_tokens: Evidence memory for constrained decoding
189
+
190
+ Returns:
191
+ dict with:
192
+ 'logits': [B, seq_len, vocab_size]
193
+ 'loss': scalar cross-entropy loss
194
+ """
195
+ B, seq_len = target_ids.shape
196
+ device = target_ids.device
197
+
198
+ # Embed target tokens
199
+ positions = torch.arange(seq_len, device=device).unsqueeze(0)
200
+ x = self.token_embedding(target_ids) + self.pos_embedding(positions)
201
+
202
+ # Causal mask
203
+ causal_mask = torch.triu(
204
+ torch.ones(seq_len, seq_len, device=device, dtype=torch.bool),
205
+ diagonal=1
206
+ )
207
+
208
+ # Apply decoder layers
209
+ for layer in self.decoder_layers:
210
+ x = layer(
211
+ x=x,
212
+ z_final=z_final,
213
+ causal_mask=causal_mask,
214
+ evidence_tokens=evidence_tokens,
215
+ evidence_mask=evidence_mask,
216
+ )
217
+
218
+ # Project to vocabulary
219
+ logits = self.lm_head(self.output_norm(x)) # [B, seq_len, vocab]
220
+
221
+ # Compute loss (shift by 1 for next-token prediction)
222
+ shift_logits = logits[:, :-1].contiguous()
223
+ shift_labels = target_ids[:, 1:].contiguous()
224
+ loss = F.cross_entropy(
225
+ shift_logits.view(-1, self.vocab_size),
226
+ shift_labels.view(-1),
227
+ ignore_index=-100,
228
+ )
229
+
230
+ return {
231
+ 'logits': logits,
232
+ 'loss': loss,
233
+ }
234
+
235
+ @torch.no_grad()
236
+ def generate(
237
+ self,
238
+ z_final: torch.Tensor,
239
+ start_token_id: int,
240
+ max_length: int = 64,
241
+ evidence_tokens: Optional[torch.Tensor] = None,
242
+ evidence_mask: Optional[torch.Tensor] = None,
243
+ eos_token_id: Optional[int] = None,
244
+ ) -> torch.Tensor:
245
+ """
246
+ Autoregressive generation for inference.
247
+
248
+ Returns:
249
+ generated_ids: [B, gen_len]
250
+ """
251
+ B = z_final.size(0)
252
+ device = z_final.device
253
+
254
+ generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
255
+
256
+ for step in range(max_length - 1):
257
+ seq_len = generated.size(1)
258
+ positions = torch.arange(seq_len, device=device).unsqueeze(0)
259
+ x = self.token_embedding(generated) + self.pos_embedding(positions)
260
+
261
+ causal_mask = torch.triu(
262
+ torch.ones(seq_len, seq_len, device=device, dtype=torch.bool),
263
+ diagonal=1
264
+ )
265
+
266
+ for layer in self.decoder_layers:
267
+ x = layer(
268
+ x=x,
269
+ z_final=z_final,
270
+ causal_mask=causal_mask,
271
+ evidence_tokens=evidence_tokens,
272
+ evidence_mask=evidence_mask,
273
+ )
274
+
275
+ logits = self.lm_head(self.output_norm(x[:, -1:])) # [B, 1, vocab]
276
+ next_token = logits.argmax(dim=-1) # [B, 1]
277
+ generated = torch.cat([generated, next_token], dim=1)
278
+
279
+ # Check EOS
280
+ if eos_token_id is not None:
281
+ if (next_token == eos_token_id).all():
282
+ break
283
+
284
+ return generated
285
+
286
+
287
+ class GenerativeDecoderLayer(nn.Module):
288
+ """Single transformer decoder layer with optional evidence cross-attention."""
289
+
290
+ def __init__(
291
+ self,
292
+ hidden_dim: int,
293
+ num_heads: int,
294
+ dropout: float,
295
+ use_evidence_cross_attn: bool = True,
296
+ ):
297
+ super().__init__()
298
+
299
+ # Causal self-attention
300
+ self.self_attn = nn.MultiheadAttention(
301
+ embed_dim=hidden_dim, num_heads=num_heads,
302
+ dropout=dropout, batch_first=True,
303
+ )
304
+ self.self_attn_norm = nn.LayerNorm(hidden_dim)
305
+
306
+ # Cross-attention to latent state z_K
307
+ self.state_cross_attn = nn.MultiheadAttention(
308
+ embed_dim=hidden_dim, num_heads=num_heads,
309
+ dropout=dropout, batch_first=True,
310
+ )
311
+ self.state_cross_norm = nn.LayerNorm(hidden_dim)
312
+
313
+ # Optional: cross-attention to evidence memory
314
+ self.use_evidence_cross_attn = use_evidence_cross_attn
315
+ if use_evidence_cross_attn:
316
+ self.evidence_cross_attn = nn.MultiheadAttention(
317
+ embed_dim=hidden_dim, num_heads=num_heads,
318
+ dropout=dropout, batch_first=True,
319
+ )
320
+ self.evidence_cross_norm = nn.LayerNorm(hidden_dim)
321
+
322
+ # FFN
323
+ self.ffn = nn.Sequential(
324
+ nn.Linear(hidden_dim, hidden_dim * 4),
325
+ nn.GELU(),
326
+ nn.Dropout(dropout),
327
+ nn.Linear(hidden_dim * 4, hidden_dim),
328
+ nn.Dropout(dropout),
329
+ )
330
+ self.ffn_norm = nn.LayerNorm(hidden_dim)
331
+
332
+ def forward(
333
+ self,
334
+ x: torch.Tensor,
335
+ z_final: torch.Tensor,
336
+ causal_mask: torch.Tensor,
337
+ evidence_tokens: Optional[torch.Tensor] = None,
338
+ evidence_mask: Optional[torch.Tensor] = None,
339
+ ) -> torch.Tensor:
340
+ # Causal self-attention
341
+ residual = x
342
+ x_normed = self.self_attn_norm(x)
343
+ x_out, _ = self.self_attn(
344
+ x_normed, x_normed, x_normed,
345
+ attn_mask=causal_mask,
346
+ )
347
+ x = residual + x_out
348
+
349
+ # Cross-attention to latent state
350
+ residual = x
351
+ x_normed = self.state_cross_norm(x)
352
+ x_out, _ = self.state_cross_attn(x_normed, z_final, z_final)
353
+ x = residual + x_out
354
+
355
+ # Optional evidence cross-attention
356
+ if self.use_evidence_cross_attn and evidence_tokens is not None:
357
+ residual = x
358
+ x_normed = self.evidence_cross_norm(x)
359
+ x_out, _ = self.evidence_cross_attn(
360
+ x_normed, evidence_tokens, evidence_tokens,
361
+ key_padding_mask=evidence_mask,
362
+ )
363
+ x = residual + x_out
364
+
365
+ # FFN
366
+ residual = x
367
+ x = residual + self.ffn(self.ffn_norm(x))
368
+
369
+ return x
mr_jepa/models/backbones.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visual and Text Backbone Encoders for MR-JEPA.
3
+
4
+ Visual: DINOv2-L/G or DINOv3-L (dense SSL features, no text alignment)
5
+ Text: DeBERTa-v3 (strong NLU encoder for questions + options)
6
+
7
+ Both backbones are frozen in Phase 1 and partially unfrozen in Phase 2.
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from typing import Optional, Dict, Any
13
+
14
+ from ..configs.model_config import VisualBackboneConfig, TextEncoderConfig
15
+
16
+
17
+ class VisualBackbone(nn.Module):
18
+ """
19
+ Dense visual feature extractor using DINOv2/v3 or SigLIP2.
20
+
21
+ Outputs patch-level tokens (excluding CLS and register tokens).
22
+ For DINOv2-L at 518px: 1369 patch tokens × 1024 dim.
23
+ """
24
+
25
+ def __init__(self, config: VisualBackboneConfig):
26
+ super().__init__()
27
+ self.config = config
28
+ self.backbone = None
29
+ self.hidden_size = config.hidden_size
30
+ self._build_backbone()
31
+
32
+ if config.freeze:
33
+ self.freeze_all()
34
+
35
+ def _build_backbone(self):
36
+ """Load pretrained backbone from HuggingFace."""
37
+ from transformers import AutoModel, AutoImageProcessor
38
+
39
+ if self.config.backbone_type in ("dinov2", "dinov3"):
40
+ self.backbone = AutoModel.from_pretrained(
41
+ self.config.model_name,
42
+ torch_dtype=torch.float32, # DINOv2 is fp32
43
+ )
44
+ self.processor = AutoImageProcessor.from_pretrained(
45
+ self.config.model_name
46
+ )
47
+ # DINOv2/v3 outputs: last_hidden_state includes [CLS] + registers + patches
48
+ self._skip_tokens = 1 + self.config.num_register_tokens # CLS + regs
49
+
50
+ elif self.config.backbone_type == "siglip2":
51
+ from transformers import SiglipVisionModel, SiglipImageProcessor
52
+ self.backbone = SiglipVisionModel.from_pretrained(
53
+ self.config.model_name,
54
+ torch_dtype=torch.float32,
55
+ )
56
+ self.processor = SiglipImageProcessor.from_pretrained(
57
+ self.config.model_name
58
+ )
59
+ self._skip_tokens = 0 # SigLIP has no CLS or register tokens
60
+
61
+ def freeze_all(self):
62
+ """Freeze all backbone parameters."""
63
+ for param in self.backbone.parameters():
64
+ param.requires_grad = False
65
+
66
+ def unfreeze_last_n_layers(self, n: int):
67
+ """Unfreeze the last N transformer layers (Phase 2)."""
68
+ # DINOv2 uses model.encoder.layer[i]
69
+ if hasattr(self.backbone, 'encoder'):
70
+ layers = self.backbone.encoder.layer
71
+ elif hasattr(self.backbone, 'vision_model'):
72
+ layers = self.backbone.vision_model.encoder.layers
73
+ else:
74
+ raise ValueError(f"Unknown backbone structure for {self.config.model_name}")
75
+
76
+ total_layers = len(layers)
77
+ for i, layer in enumerate(layers):
78
+ if i >= total_layers - n:
79
+ for param in layer.parameters():
80
+ param.requires_grad = True
81
+
82
+ def forward(
83
+ self,
84
+ pixel_values: torch.Tensor, # [B, C, H, W]
85
+ return_cls: bool = False,
86
+ ) -> Dict[str, torch.Tensor]:
87
+ """
88
+ Extract dense patch tokens from images.
89
+
90
+ Args:
91
+ pixel_values: Preprocessed image tensors [B, C, H, W]
92
+ return_cls: Whether to also return the CLS token
93
+
94
+ Returns:
95
+ dict with:
96
+ 'patch_tokens': [B, num_patches, hidden_size]
97
+ 'cls_token': [B, hidden_size] (if return_cls=True)
98
+ """
99
+ outputs = self.backbone(pixel_values=pixel_values)
100
+ hidden_states = outputs.last_hidden_state # [B, 1+reg+patches, D]
101
+
102
+ result = {}
103
+ result['patch_tokens'] = hidden_states[:, self._skip_tokens:] # [B, num_patches, D]
104
+
105
+ if return_cls:
106
+ result['cls_token'] = hidden_states[:, 0] # [B, D]
107
+
108
+ return result
109
+
110
+
111
+ class TextEncoder(nn.Module):
112
+ """
113
+ Text encoder for questions, options, and optional context.
114
+
115
+ Uses DeBERTa-v3 for strong NLU. Outputs:
116
+ - Token-level representations for cross-attention
117
+ - [CLS] representation for global text understanding
118
+ """
119
+
120
+ def __init__(self, config: TextEncoderConfig):
121
+ super().__init__()
122
+ self.config = config
123
+ self.hidden_size = config.hidden_size
124
+ self._build_encoder()
125
+
126
+ if config.freeze:
127
+ self.freeze_all()
128
+
129
+ def _build_encoder(self):
130
+ """Load pretrained text encoder."""
131
+ from transformers import AutoModel, AutoTokenizer
132
+
133
+ self.encoder = AutoModel.from_pretrained(
134
+ self.config.model_name,
135
+ torch_dtype=torch.float32,
136
+ )
137
+ self.tokenizer = AutoTokenizer.from_pretrained(
138
+ self.config.model_name
139
+ )
140
+
141
+ def freeze_all(self):
142
+ for param in self.encoder.parameters():
143
+ param.requires_grad = False
144
+
145
+ def unfreeze_last_n_layers(self, n: int):
146
+ if hasattr(self.encoder, 'encoder'):
147
+ layers = self.encoder.encoder.layer
148
+ else:
149
+ raise ValueError(f"Unknown encoder structure for {self.config.model_name}")
150
+
151
+ total_layers = len(layers)
152
+ for i, layer in enumerate(layers):
153
+ if i >= total_layers - n:
154
+ for param in layer.parameters():
155
+ param.requires_grad = True
156
+
157
+ def forward(
158
+ self,
159
+ input_ids: torch.Tensor, # [B, seq_len]
160
+ attention_mask: torch.Tensor, # [B, seq_len]
161
+ ) -> Dict[str, torch.Tensor]:
162
+ """
163
+ Encode text (question + options).
164
+
165
+ Returns:
166
+ dict with:
167
+ 'token_embeddings': [B, seq_len, hidden_size]
168
+ 'cls_embedding': [B, hidden_size]
169
+ 'attention_mask': [B, seq_len]
170
+ """
171
+ outputs = self.encoder(
172
+ input_ids=input_ids,
173
+ attention_mask=attention_mask,
174
+ )
175
+
176
+ return {
177
+ 'token_embeddings': outputs.last_hidden_state,
178
+ 'cls_embedding': outputs.last_hidden_state[:, 0],
179
+ 'attention_mask': attention_mask,
180
+ }
mr_jepa/models/evidence_memory.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evidence Memory Module for MR-JEPA.
3
+
4
+ The Evidence Memory is a unified multimodal representation that fuses:
5
+ 1. Dense visual patch tokens (from DINOv2/v3)
6
+ 2. Text tokens (question + options from DeBERTa)
7
+ 3. Optional enriched tokens: OCR, layout, chart structure, SAM segments
8
+
9
+ Architecture:
10
+ - N learnable evidence query tokens
11
+ - Cross-attention layers: queries attend to all input modalities
12
+ - Each cross-attention layer also has self-attention among queries
13
+ - Output: N evidence tokens that capture the full multimodal context
14
+
15
+ This is inspired by Perceiver/Q-Former architectures but designed specifically
16
+ as the initial evidence state for the JEPA rollout.
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import math
23
+ from typing import Optional, Dict, List
24
+
25
+ from ..configs.model_config import EvidenceMemoryConfig
26
+
27
+
28
+ class CrossAttentionLayer(nn.Module):
29
+ """
30
+ Single cross-attention layer with self-attention.
31
+
32
+ Flow: self_attn(queries) → cross_attn(queries, kv=evidence) → FFN
33
+ """
34
+
35
+ def __init__(self, hidden_dim: int, num_heads: int, dropout: float = 0.1):
36
+ super().__init__()
37
+ self.hidden_dim = hidden_dim
38
+ self.num_heads = num_heads
39
+ self.head_dim = hidden_dim // num_heads
40
+
41
+ # Self-attention among evidence queries
42
+ self.self_attn = nn.MultiheadAttention(
43
+ embed_dim=hidden_dim,
44
+ num_heads=num_heads,
45
+ dropout=dropout,
46
+ batch_first=True,
47
+ )
48
+ self.self_attn_norm = nn.LayerNorm(hidden_dim)
49
+
50
+ # Cross-attention: queries attend to input tokens
51
+ self.cross_attn = nn.MultiheadAttention(
52
+ embed_dim=hidden_dim,
53
+ num_heads=num_heads,
54
+ dropout=dropout,
55
+ batch_first=True,
56
+ )
57
+ self.cross_attn_norm = nn.LayerNorm(hidden_dim)
58
+
59
+ # FFN
60
+ self.ffn = nn.Sequential(
61
+ nn.Linear(hidden_dim, hidden_dim * 4),
62
+ nn.GELU(),
63
+ nn.Dropout(dropout),
64
+ nn.Linear(hidden_dim * 4, hidden_dim),
65
+ nn.Dropout(dropout),
66
+ )
67
+ self.ffn_norm = nn.LayerNorm(hidden_dim)
68
+
69
+ def forward(
70
+ self,
71
+ queries: torch.Tensor, # [B, N_q, D]
72
+ kv_tokens: torch.Tensor, # [B, N_kv, D]
73
+ kv_mask: Optional[torch.Tensor] = None, # [B, N_kv] bool
74
+ ) -> torch.Tensor:
75
+ """
76
+ Args:
77
+ queries: Evidence query tokens [B, N_q, D]
78
+ kv_tokens: Concatenated input tokens [B, N_kv, D]
79
+ kv_mask: Key padding mask for kv_tokens [B, N_kv]
80
+
81
+ Returns:
82
+ Updated queries [B, N_q, D]
83
+ """
84
+ # Self-attention among queries
85
+ residual = queries
86
+ queries = self.self_attn_norm(queries)
87
+ queries_out, _ = self.self_attn(queries, queries, queries)
88
+ queries = residual + queries_out
89
+
90
+ # Cross-attention to input tokens
91
+ residual = queries
92
+ queries_normed = self.cross_attn_norm(queries)
93
+ queries_out, _ = self.cross_attn(
94
+ query=queries_normed,
95
+ key=kv_tokens,
96
+ value=kv_tokens,
97
+ key_padding_mask=kv_mask,
98
+ )
99
+ queries = residual + queries_out
100
+
101
+ # FFN
102
+ residual = queries
103
+ queries = residual + self.ffn(self.ffn_norm(queries))
104
+
105
+ return queries
106
+
107
+
108
+ class ModalityProjector(nn.Module):
109
+ """Projects tokens from a specific modality to the evidence memory dimension."""
110
+
111
+ def __init__(self, input_dim: int, output_dim: int):
112
+ super().__init__()
113
+ self.proj = nn.Sequential(
114
+ nn.Linear(input_dim, output_dim),
115
+ nn.LayerNorm(output_dim),
116
+ nn.GELU(),
117
+ nn.Linear(output_dim, output_dim),
118
+ )
119
+
120
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
121
+ return self.proj(x)
122
+
123
+
124
+ class EvidenceMemory(nn.Module):
125
+ """
126
+ Unified Evidence Memory that fuses all input modalities.
127
+
128
+ The output evidence tokens serve as:
129
+ 1. The basis for constructing the initial latent state z₀
130
+ 2. The key-value memory for evidence-gated cross-attention in rollout steps
131
+
132
+ Architecture follows a Perceiver-style design with learnable queries
133
+ cross-attending to projected multimodal tokens.
134
+ """
135
+
136
+ def __init__(
137
+ self,
138
+ config: EvidenceMemoryConfig,
139
+ visual_dim: int,
140
+ text_dim: int,
141
+ ):
142
+ super().__init__()
143
+ self.config = config
144
+ self.hidden_dim = config.hidden_dim
145
+
146
+ # Learnable evidence query tokens
147
+ self.evidence_queries = nn.Parameter(
148
+ torch.randn(1, config.num_evidence_tokens, config.hidden_dim) * 0.02
149
+ )
150
+
151
+ # Modality projectors
152
+ self.visual_proj = ModalityProjector(visual_dim, config.hidden_dim)
153
+ self.text_proj = ModalityProjector(text_dim, config.hidden_dim)
154
+
155
+ # Modality type embeddings (to distinguish sources in cross-attention)
156
+ self.modality_embeddings = nn.Embedding(6, config.hidden_dim)
157
+ # 0=visual, 1=text, 2=ocr, 3=layout, 4=chart, 5=sam
158
+
159
+ # Optional enriched evidence projectors (Phase 3)
160
+ if config.use_ocr_tokens:
161
+ self.ocr_proj = ModalityProjector(text_dim, config.hidden_dim)
162
+ if config.use_layout_tokens:
163
+ self.layout_proj = ModalityProjector(256, config.hidden_dim) # Layout features
164
+ if config.use_chart_tokens:
165
+ self.chart_proj = ModalityProjector(512, config.hidden_dim) # Chart structure
166
+ if config.use_sam_tokens:
167
+ self.sam_proj = ModalityProjector(256, config.hidden_dim) # SAM2 features
168
+
169
+ # Cross-attention layers
170
+ self.layers = nn.ModuleList([
171
+ CrossAttentionLayer(
172
+ hidden_dim=config.hidden_dim,
173
+ num_heads=config.num_heads,
174
+ dropout=config.dropout,
175
+ )
176
+ for _ in range(config.num_cross_attn_layers)
177
+ ])
178
+
179
+ # Final norm
180
+ self.output_norm = nn.LayerNorm(config.hidden_dim)
181
+
182
+ def _prepare_kv_tokens(
183
+ self,
184
+ visual_tokens: torch.Tensor, # [B, N_v, D_v]
185
+ text_tokens: torch.Tensor, # [B, N_t, D_t]
186
+ text_mask: torch.Tensor, # [B, N_t]
187
+ ocr_tokens: Optional[torch.Tensor] = None, # [B, N_ocr, D_t]
188
+ ocr_mask: Optional[torch.Tensor] = None,
189
+ layout_tokens: Optional[torch.Tensor] = None, # [B, N_lay, D_lay]
190
+ layout_mask: Optional[torch.Tensor] = None,
191
+ chart_tokens: Optional[torch.Tensor] = None, # [B, N_ch, D_ch]
192
+ chart_mask: Optional[torch.Tensor] = None,
193
+ sam_tokens: Optional[torch.Tensor] = None, # [B, N_sam, D_sam]
194
+ sam_mask: Optional[torch.Tensor] = None,
195
+ ):
196
+ """Project all modalities and concatenate into a single KV sequence."""
197
+ B = visual_tokens.size(0)
198
+ device = visual_tokens.device
199
+
200
+ all_tokens = []
201
+ all_masks = []
202
+
203
+ # Visual tokens (always present)
204
+ v_proj = self.visual_proj(visual_tokens) # [B, N_v, D]
205
+ v_proj = v_proj + self.modality_embeddings(
206
+ torch.zeros(v_proj.size(1), dtype=torch.long, device=device)
207
+ ).unsqueeze(0)
208
+ all_tokens.append(v_proj)
209
+ all_masks.append(torch.zeros(B, v_proj.size(1), dtype=torch.bool, device=device))
210
+
211
+ # Text tokens (always present)
212
+ t_proj = self.text_proj(text_tokens) # [B, N_t, D]
213
+ t_proj = t_proj + self.modality_embeddings(
214
+ torch.ones(t_proj.size(1), dtype=torch.long, device=device)
215
+ ).unsqueeze(0)
216
+ all_tokens.append(t_proj)
217
+ # Invert mask: True = padding (to be masked out)
218
+ all_masks.append(~text_mask.bool())
219
+
220
+ # Optional modalities (Phase 3)
221
+ if ocr_tokens is not None and self.config.use_ocr_tokens:
222
+ o_proj = self.ocr_proj(ocr_tokens)
223
+ o_proj = o_proj + self.modality_embeddings(
224
+ torch.full((o_proj.size(1),), 2, dtype=torch.long, device=device)
225
+ ).unsqueeze(0)
226
+ all_tokens.append(o_proj)
227
+ all_masks.append(~ocr_mask.bool() if ocr_mask is not None
228
+ else torch.zeros(B, o_proj.size(1), dtype=torch.bool, device=device))
229
+
230
+ if layout_tokens is not None and self.config.use_layout_tokens:
231
+ l_proj = self.layout_proj(layout_tokens)
232
+ l_proj = l_proj + self.modality_embeddings(
233
+ torch.full((l_proj.size(1),), 3, dtype=torch.long, device=device)
234
+ ).unsqueeze(0)
235
+ all_tokens.append(l_proj)
236
+ all_masks.append(~layout_mask.bool() if layout_mask is not None
237
+ else torch.zeros(B, l_proj.size(1), dtype=torch.bool, device=device))
238
+
239
+ if chart_tokens is not None and self.config.use_chart_tokens:
240
+ c_proj = self.chart_proj(chart_tokens)
241
+ c_proj = c_proj + self.modality_embeddings(
242
+ torch.full((c_proj.size(1),), 4, dtype=torch.long, device=device)
243
+ ).unsqueeze(0)
244
+ all_tokens.append(c_proj)
245
+ all_masks.append(~chart_mask.bool() if chart_mask is not None
246
+ else torch.zeros(B, c_proj.size(1), dtype=torch.bool, device=device))
247
+
248
+ if sam_tokens is not None and self.config.use_sam_tokens:
249
+ s_proj = self.sam_proj(sam_tokens)
250
+ s_proj = s_proj + self.modality_embeddings(
251
+ torch.full((s_proj.size(1),), 5, dtype=torch.long, device=device)
252
+ ).unsqueeze(0)
253
+ all_tokens.append(s_proj)
254
+ all_masks.append(~sam_mask.bool() if sam_mask is not None
255
+ else torch.zeros(B, s_proj.size(1), dtype=torch.bool, device=device))
256
+
257
+ # Concatenate all modalities
258
+ kv_tokens = torch.cat(all_tokens, dim=1) # [B, N_total, D]
259
+ kv_mask = torch.cat(all_masks, dim=1) # [B, N_total]
260
+
261
+ return kv_tokens, kv_mask
262
+
263
+ def forward(
264
+ self,
265
+ visual_tokens: torch.Tensor,
266
+ text_tokens: torch.Tensor,
267
+ text_mask: torch.Tensor,
268
+ **enriched_kwargs,
269
+ ) -> Dict[str, torch.Tensor]:
270
+ """
271
+ Fuse all modalities into evidence tokens.
272
+
273
+ Returns:
274
+ dict with:
275
+ 'evidence_tokens': [B, N_evidence, D] - fused evidence
276
+ 'kv_tokens': [B, N_total, D] - projected multimodal KV for rollout
277
+ 'kv_mask': [B, N_total] - mask for KV tokens
278
+ """
279
+ B = visual_tokens.size(0)
280
+
281
+ # Prepare KV tokens from all modalities
282
+ kv_tokens, kv_mask = self._prepare_kv_tokens(
283
+ visual_tokens, text_tokens, text_mask, **enriched_kwargs
284
+ )
285
+
286
+ # Expand learnable queries for batch
287
+ queries = self.evidence_queries.expand(B, -1, -1) # [B, N_q, D]
288
+
289
+ # Apply cross-attention layers
290
+ for layer in self.layers:
291
+ queries = layer(queries, kv_tokens, kv_mask)
292
+
293
+ evidence_tokens = self.output_norm(queries) # [B, N_evidence, D]
294
+
295
+ return {
296
+ 'evidence_tokens': evidence_tokens,
297
+ 'kv_tokens': kv_tokens,
298
+ 'kv_mask': kv_mask,
299
+ }
mr_jepa/models/latent_rollout.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Latent Belief-State Rollout Module for MR-JEPA.
3
+
4
+ This is the core JEPA reasoning module. It models the evolution of a
5
+ multimodal belief state as the system "reasons" about a question:
6
+
7
+ z₀ → z₁ → z₂ → z₃ (K=3 steps)
8
+
9
+ Each step applies a shared predictor block with evidence gating:
10
+ 1. Self-attention: latent state tokens attend to each other
11
+ 2. Evidence-gated cross-attention: state attends to evidence memory
12
+ 3. FFN with residual
13
+
14
+ Key design choices grounded in literature:
15
+ - SHARED predictor across steps (weight-tied, like V-JEPA/LeWorldModel)
16
+ - Step embeddings to differentiate rollout positions
17
+ - Evidence gates (sigmoid/softmax) control information flow per step
18
+ - The predictor is a "narrow" transformer (from I-JEPA: predictor is
19
+ smaller than encoder)
20
+
21
+ The JEPA objective supervises this trajectory: the target encoder (EMA)
22
+ generates z*_k targets, and the predictor must predict z*_k from z_{k-1}.
23
+ """
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ import math
29
+ from typing import Optional, Dict, List, Tuple
30
+
31
+ from ..configs.model_config import LatentRolloutConfig
32
+
33
+
34
+ class EvidenceGate(nn.Module):
35
+ """
36
+ Learned gate that controls how much evidence flows into each rollout step.
37
+
38
+ Intuition: Early steps may need more visual evidence, while later steps
39
+ may rely more on accumulated reasoning. The gate learns this schedule.
40
+ """
41
+
42
+ def __init__(self, hidden_dim: int, gate_type: str = "sigmoid"):
43
+ super().__init__()
44
+ self.gate_type = gate_type
45
+
46
+ if gate_type == "sigmoid":
47
+ # Per-dimension gate: scales each feature independently
48
+ self.gate_proj = nn.Sequential(
49
+ nn.Linear(hidden_dim * 2, hidden_dim),
50
+ nn.Sigmoid(),
51
+ )
52
+ elif gate_type == "learned":
53
+ # Scalar gate per token, learned as a function of state + evidence
54
+ self.gate_proj = nn.Sequential(
55
+ nn.Linear(hidden_dim * 2, hidden_dim),
56
+ nn.ReLU(),
57
+ nn.Linear(hidden_dim, 1),
58
+ nn.Sigmoid(),
59
+ )
60
+ # softmax gate is implemented in forward via attention weights
61
+
62
+ def forward(
63
+ self,
64
+ state: torch.Tensor, # [B, N_s, D]
65
+ evidence_contribution: torch.Tensor, # [B, N_s, D]
66
+ ) -> torch.Tensor:
67
+ """
68
+ Apply evidence gate.
69
+
70
+ Args:
71
+ state: Current latent state
72
+ evidence_contribution: Cross-attention output from evidence
73
+
74
+ Returns:
75
+ Gated evidence contribution [B, N_s, D]
76
+ """
77
+ if self.gate_type == "sigmoid":
78
+ gate = self.gate_proj(torch.cat([state, evidence_contribution], dim=-1))
79
+ return gate * evidence_contribution
80
+ elif self.gate_type == "learned":
81
+ gate = self.gate_proj(torch.cat([state, evidence_contribution], dim=-1))
82
+ return gate * evidence_contribution
83
+ else:
84
+ # No explicit gating (softmax via attention weights)
85
+ return evidence_contribution
86
+
87
+
88
+ class PredictorBlock(nn.Module):
89
+ """
90
+ Single rollout step predictor block.
91
+
92
+ This is the "narrow" predictor from I-JEPA adapted for reasoning:
93
+ - Self-attention among latent state tokens
94
+ - Evidence-gated cross-attention to evidence memory
95
+ - FFN
96
+
97
+ All K rollout steps share this same block (weight-tied).
98
+ """
99
+
100
+ def __init__(
101
+ self,
102
+ hidden_dim: int,
103
+ num_heads: int,
104
+ ffn_dim: int,
105
+ dropout: float,
106
+ gate_type: str = "sigmoid",
107
+ ):
108
+ super().__init__()
109
+
110
+ # Self-attention among state tokens
111
+ self.self_attn = nn.MultiheadAttention(
112
+ embed_dim=hidden_dim,
113
+ num_heads=num_heads,
114
+ dropout=dropout,
115
+ batch_first=True,
116
+ )
117
+ self.self_attn_norm = nn.LayerNorm(hidden_dim)
118
+
119
+ # Cross-attention to evidence memory
120
+ self.cross_attn = nn.MultiheadAttention(
121
+ embed_dim=hidden_dim,
122
+ num_heads=num_heads,
123
+ dropout=dropout,
124
+ batch_first=True,
125
+ )
126
+ self.cross_attn_norm = nn.LayerNorm(hidden_dim)
127
+
128
+ # Evidence gate
129
+ self.evidence_gate = EvidenceGate(hidden_dim, gate_type)
130
+
131
+ # FFN
132
+ self.ffn = nn.Sequential(
133
+ nn.Linear(hidden_dim, ffn_dim),
134
+ nn.GELU(),
135
+ nn.Dropout(dropout),
136
+ nn.Linear(ffn_dim, hidden_dim),
137
+ nn.Dropout(dropout),
138
+ )
139
+ self.ffn_norm = nn.LayerNorm(hidden_dim)
140
+
141
+ def forward(
142
+ self,
143
+ state: torch.Tensor, # [B, N_s, D]
144
+ evidence_kv: torch.Tensor, # [B, N_e, D]
145
+ evidence_mask: Optional[torch.Tensor] = None, # [B, N_e]
146
+ ) -> torch.Tensor:
147
+ """One rollout step: state → updated state."""
148
+ # Self-attention
149
+ residual = state
150
+ state_normed = self.self_attn_norm(state)
151
+ state_out, _ = self.self_attn(state_normed, state_normed, state_normed)
152
+ state = residual + state_out
153
+
154
+ # Cross-attention to evidence
155
+ residual = state
156
+ state_normed = self.cross_attn_norm(state)
157
+ evidence_contribution, _ = self.cross_attn(
158
+ query=state_normed,
159
+ key=evidence_kv,
160
+ value=evidence_kv,
161
+ key_padding_mask=evidence_mask,
162
+ )
163
+
164
+ # Apply evidence gate
165
+ gated_evidence = self.evidence_gate(state, evidence_contribution)
166
+ state = residual + gated_evidence
167
+
168
+ # FFN
169
+ residual = state
170
+ state = residual + self.ffn(self.ffn_norm(state))
171
+
172
+ return state
173
+
174
+
175
+ class LatentRolloutModule(nn.Module):
176
+ """
177
+ Full latent belief-state rollout.
178
+
179
+ Constructs z₀ from evidence memory, then refines it over K steps.
180
+ Each step uses the same shared PredictorBlock (weight-tied across steps).
181
+
182
+ The full trajectory [z₀, z₁, ..., z_K] is returned for the JEPA objective.
183
+
184
+ Architecture:
185
+ z₀ = LinearProj(evidence_pool) + state_init_tokens
186
+ For k in 1..K:
187
+ z_k = PredictorBlock(z_{k-1}, evidence_memory) + step_emb[k]
188
+ """
189
+
190
+ def __init__(self, config: LatentRolloutConfig):
191
+ super().__init__()
192
+ self.config = config
193
+ self.K = config.K
194
+ self.hidden_dim = config.hidden_dim
195
+ self.num_state_tokens = config.num_state_tokens
196
+
197
+ # Initial state construction
198
+ # Learnable state initialization tokens
199
+ self.state_init = nn.Parameter(
200
+ torch.randn(1, config.num_state_tokens, config.hidden_dim) * 0.02
201
+ )
202
+
203
+ # Project evidence summary into initial state
204
+ self.z0_proj = nn.Sequential(
205
+ nn.Linear(config.hidden_dim, config.hidden_dim),
206
+ nn.LayerNorm(config.hidden_dim),
207
+ nn.GELU(),
208
+ nn.Linear(config.hidden_dim, config.hidden_dim),
209
+ )
210
+
211
+ # Step embeddings (learned per-step bias)
212
+ if config.use_step_embedding:
213
+ self.step_embeddings = nn.Parameter(
214
+ torch.randn(config.K + 1, 1, config.hidden_dim) * 0.02
215
+ ) # [K+1, 1, D] — one per step including z₀
216
+
217
+ # Shared predictor block (weight-tied across K steps)
218
+ # We use a stack of transformer layers as the predictor
219
+ self.predictor_layers = nn.ModuleList([
220
+ PredictorBlock(
221
+ hidden_dim=config.hidden_dim,
222
+ num_heads=config.num_heads,
223
+ ffn_dim=config.ffn_dim,
224
+ dropout=config.dropout,
225
+ gate_type=config.gate_type if config.use_evidence_gate else "none",
226
+ )
227
+ for _ in range(config.num_predictor_layers)
228
+ ])
229
+
230
+ # Output projection (project each z_k to prediction space)
231
+ self.output_proj = nn.Sequential(
232
+ nn.LayerNorm(config.hidden_dim),
233
+ nn.Linear(config.hidden_dim, config.hidden_dim),
234
+ )
235
+
236
+ def _construct_z0(
237
+ self,
238
+ evidence_tokens: torch.Tensor, # [B, N_e, D]
239
+ ) -> torch.Tensor:
240
+ """
241
+ Construct initial latent state z₀ from evidence.
242
+
243
+ z₀ = state_init_tokens + projected_evidence_pool + step_emb[0]
244
+
245
+ The evidence pool is computed by adaptive average pooling the evidence
246
+ tokens down to the number of state tokens.
247
+ """
248
+ B = evidence_tokens.size(0)
249
+
250
+ # Pool evidence into state-sized representation
251
+ # [B, N_e, D] → [B, N_s, D] via adaptive pooling
252
+ evidence_pooled = F.adaptive_avg_pool1d(
253
+ evidence_tokens.permute(0, 2, 1), # [B, D, N_e]
254
+ self.num_state_tokens
255
+ ).permute(0, 2, 1) # [B, N_s, D]
256
+
257
+ # Project and combine with learnable init
258
+ z0 = self.state_init.expand(B, -1, -1) + self.z0_proj(evidence_pooled)
259
+
260
+ # Add step embedding for step 0
261
+ if self.config.use_step_embedding:
262
+ z0 = z0 + self.step_embeddings[0].unsqueeze(0)
263
+
264
+ return z0
265
+
266
+ def _single_rollout_step(
267
+ self,
268
+ z_prev: torch.Tensor, # [B, N_s, D]
269
+ evidence_tokens: torch.Tensor, # [B, N_e, D]
270
+ evidence_mask: Optional[torch.Tensor],
271
+ ) -> torch.Tensor:
272
+ """Apply the shared predictor block for one rollout step."""
273
+ z = z_prev
274
+ for layer in self.predictor_layers:
275
+ z = layer(z, evidence_tokens, evidence_mask)
276
+ return z
277
+
278
+ def forward(
279
+ self,
280
+ evidence_tokens: torch.Tensor, # [B, N_e, D]
281
+ evidence_mask: Optional[torch.Tensor] = None, # [B, N_e]
282
+ ) -> Dict[str, torch.Tensor]:
283
+ """
284
+ Full K-step latent rollout.
285
+
286
+ Args:
287
+ evidence_tokens: Fused evidence from EvidenceMemory [B, N_e, D]
288
+ evidence_mask: Padding mask for evidence tokens
289
+
290
+ Returns:
291
+ dict with:
292
+ 'trajectory': [B, K+1, N_s, D] - full latent trajectory
293
+ 'z_final': [B, N_s, D] - final latent state z_K
294
+ 'z_projected': [B, K+1, N_s, D] - projected trajectory for JEPA loss
295
+ """
296
+ # Construct z₀
297
+ z = self._construct_z0(evidence_tokens)
298
+
299
+ trajectory = [z]
300
+
301
+ # Rollout K steps
302
+ for k in range(1, self.K + 1):
303
+ z = self._single_rollout_step(z, evidence_tokens, evidence_mask)
304
+
305
+ # Add step embedding
306
+ if self.config.use_step_embedding:
307
+ z = z + self.step_embeddings[k].unsqueeze(0)
308
+
309
+ trajectory.append(z)
310
+
311
+ # Stack trajectory: [B, K+1, N_s, D]
312
+ trajectory_tensor = torch.stack(trajectory, dim=1)
313
+
314
+ # Project each state for JEPA prediction loss
315
+ B, Kp1, N_s, D = trajectory_tensor.shape
316
+ flat = trajectory_tensor.reshape(B * Kp1 * N_s, D)
317
+ projected_flat = self.output_proj(flat)
318
+ z_projected = projected_flat.reshape(B, Kp1, N_s, D)
319
+
320
+ return {
321
+ 'trajectory': trajectory_tensor, # Raw states
322
+ 'z_final': trajectory[-1], # Final state
323
+ 'z_projected': z_projected, # For JEPA loss
324
+ }
mr_jepa/models/mr_jepa.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture.
3
+
4
+ Complete model that integrates all components:
5
+ Visual Backbone → Evidence Memory ← Text Encoder
6
+ Evidence Memory → z₀ → Latent Rollout (K=3) → Answer Heads
7
+ Target Encoder (EMA) → JEPA Supervision
8
+
9
+ The model supports two branches:
10
+ - Hybrid-main: Full model, pretrained backbones, competitive on benchmarks
11
+ - Purist-side: Stripped-down, closer to LeWorldModel spirit
12
+
13
+ Forward pass:
14
+ 1. Extract visual tokens (DINOv2/v3)
15
+ 2. Encode question + options (DeBERTa)
16
+ 3. Fuse in Evidence Memory (cross-attention)
17
+ 4. Construct z₀ and rollout K steps
18
+ 5. Score answer options (discriminative) and/or generate short answer
19
+ 6. Compute JEPA loss against target encoder trajectory
20
+ """
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from typing import Optional, Dict, Any
26
+
27
+ from ..configs.model_config import MRJEPAConfig
28
+ from .backbones import VisualBackbone, TextEncoder
29
+ from .evidence_memory import EvidenceMemory
30
+ from .latent_rollout import LatentRolloutModule
31
+ from .target_encoder import TargetEncoder, JEPALoss
32
+ from .answer_heads import DiscriminativeHead, GenerativeHead
33
+
34
+
35
+ class MRJEPAModel(nn.Module):
36
+ """
37
+ MR-JEPA: A world model for multimodal reasoning.
38
+
39
+ Instead of modeling physical dynamics, this model models the evolution
40
+ of a belief state while solving a visual question. The JEPA objective
41
+ trains the latent rollout to produce meaningful intermediate states,
42
+ supervised by an EMA target encoder.
43
+
44
+ Parameters:
45
+ config: MRJEPAConfig with all architecture hyperparameters
46
+ """
47
+
48
+ def __init__(self, config: MRJEPAConfig):
49
+ super().__init__()
50
+ self.config = config
51
+
52
+ # ===================== Perception Encoders =====================
53
+ self.visual_backbone = VisualBackbone(config.visual)
54
+ self.text_encoder = TextEncoder(config.text)
55
+
56
+ # ===================== Evidence Memory =====================
57
+ self.evidence_memory = EvidenceMemory(
58
+ config=config.evidence,
59
+ visual_dim=config.visual.hidden_size,
60
+ text_dim=config.text.hidden_size,
61
+ )
62
+
63
+ # ===================== Latent Rollout =====================
64
+ self.latent_rollout = LatentRolloutModule(config.rollout)
65
+
66
+ # ===================== Target Encoder (EMA) =====================
67
+ self.target_encoder = TargetEncoder(
68
+ online_evidence_memory=self.evidence_memory,
69
+ online_rollout=self.latent_rollout,
70
+ config=config.jepa,
71
+ )
72
+
73
+ # ===================== Answer Heads =====================
74
+ self.disc_head = DiscriminativeHead(
75
+ config=config.answer,
76
+ hidden_dim=config.rollout.hidden_dim,
77
+ text_dim=config.text.hidden_size,
78
+ )
79
+
80
+ self.gen_head = GenerativeHead(
81
+ config=config.answer,
82
+ hidden_dim=config.rollout.hidden_dim,
83
+ vocab_size=config.answer.gen_vocab_size,
84
+ )
85
+
86
+ # ===================== JEPA Loss =====================
87
+ self.jepa_loss_fn = JEPALoss(
88
+ config=config.jepa,
89
+ hidden_dim=config.rollout.hidden_dim,
90
+ )
91
+
92
+ # ===================== Ablation controls =====================
93
+ self._use_jepa = True # Disable for "no-JEPA" ablation
94
+ self._use_rollout = True # Disable for "no-rollout" ablation (z₀ only)
95
+ self._use_evidence_gate = config.rollout.use_evidence_gate
96
+
97
+ def get_trainable_params(self, phase: int = 1) -> Dict[str, list]:
98
+ """
99
+ Get parameter groups for each training phase.
100
+
101
+ Phase 1: Freeze backbones, train evidence memory + rollout + heads
102
+ Phase 2: Unfreeze last N backbone layers with lower LR
103
+ Phase 3: Add enriched evidence modules
104
+
105
+ Returns dict with 'high_lr' and 'low_lr' parameter groups.
106
+ """
107
+ high_lr_params = []
108
+ low_lr_params = []
109
+
110
+ if phase >= 1:
111
+ # Always train: evidence memory, rollout, heads, loss
112
+ for module in [self.evidence_memory, self.latent_rollout,
113
+ self.disc_head, self.gen_head, self.jepa_loss_fn]:
114
+ high_lr_params.extend(module.parameters())
115
+
116
+ if phase >= 2:
117
+ # Unfreeze last N visual backbone layers
118
+ self.visual_backbone.unfreeze_last_n_layers(
119
+ self.config.visual.unfreeze_last_n_layers
120
+ )
121
+ # Unfreeze last N text encoder layers
122
+ self.text_encoder.unfreeze_last_n_layers(
123
+ self.config.text.unfreeze_last_n_layers
124
+ )
125
+ # Add backbone params with lower LR
126
+ for module in [self.visual_backbone, self.text_encoder]:
127
+ for p in module.parameters():
128
+ if p.requires_grad:
129
+ low_lr_params.append(p)
130
+
131
+ return {
132
+ 'high_lr': high_lr_params,
133
+ 'low_lr': low_lr_params,
134
+ }
135
+
136
+ def forward(
137
+ self,
138
+ pixel_values: torch.Tensor, # [B, C, H, W]
139
+ input_ids: torch.Tensor, # [B, seq_len]
140
+ attention_mask: torch.Tensor, # [B, seq_len]
141
+ option_embeddings: Optional[torch.Tensor] = None, # [B, max_opts, D_text]
142
+ option_mask: Optional[torch.Tensor] = None, # [B, max_opts]
143
+ answer_labels: Optional[torch.Tensor] = None, # [B] index of correct option
144
+ gen_target_ids: Optional[torch.Tensor] = None, # [B, gen_seq_len]
145
+ # Optional enriched evidence (Phase 3)
146
+ ocr_tokens: Optional[torch.Tensor] = None,
147
+ ocr_mask: Optional[torch.Tensor] = None,
148
+ layout_tokens: Optional[torch.Tensor] = None,
149
+ layout_mask: Optional[torch.Tensor] = None,
150
+ chart_tokens: Optional[torch.Tensor] = None,
151
+ chart_mask: Optional[torch.Tensor] = None,
152
+ sam_tokens: Optional[torch.Tensor] = None,
153
+ sam_mask: Optional[torch.Tensor] = None,
154
+ ) -> Dict[str, torch.Tensor]:
155
+ """
156
+ Full forward pass of MR-JEPA.
157
+
158
+ Returns dict with losses and predictions.
159
+ """
160
+ # ==================== 1. Perception ====================
161
+ # Visual features
162
+ visual_output = self.visual_backbone(pixel_values)
163
+ visual_tokens = visual_output['patch_tokens'] # [B, N_v, D_v]
164
+
165
+ # Text features
166
+ text_output = self.text_encoder(input_ids, attention_mask)
167
+ text_tokens = text_output['token_embeddings'] # [B, N_t, D_t]
168
+ text_mask = text_output['attention_mask'] # [B, N_t]
169
+
170
+ # ==================== 2. Evidence Memory ====================
171
+ enriched_kwargs = {}
172
+ for name, tokens, mask in [
173
+ ('ocr_tokens', ocr_tokens, ocr_mask),
174
+ ('layout_tokens', layout_tokens, layout_mask),
175
+ ('chart_tokens', chart_tokens, chart_mask),
176
+ ('sam_tokens', sam_tokens, sam_mask),
177
+ ]:
178
+ if tokens is not None:
179
+ enriched_kwargs[name] = tokens
180
+ enriched_kwargs[name.replace('tokens', 'mask')] = mask
181
+
182
+ evidence_output = self.evidence_memory(
183
+ visual_tokens=visual_tokens,
184
+ text_tokens=text_tokens,
185
+ text_mask=text_mask,
186
+ **enriched_kwargs,
187
+ )
188
+ evidence_tokens = evidence_output['evidence_tokens'] # [B, N_e, D]
189
+
190
+ # ==================== 3. Latent Rollout ====================
191
+ if self._use_rollout:
192
+ rollout_output = self.latent_rollout(
193
+ evidence_tokens=evidence_tokens,
194
+ )
195
+ trajectory = rollout_output['trajectory'] # [B, K+1, N_s, D]
196
+ z_final = rollout_output['z_final'] # [B, N_s, D]
197
+ z_projected = rollout_output['z_projected'] # [B, K+1, N_s, D]
198
+ else:
199
+ # Ablation: no rollout, use z₀ directly
200
+ z0 = self.latent_rollout._construct_z0(evidence_tokens)
201
+ z_final = z0
202
+ trajectory = z0.unsqueeze(1)
203
+ z_projected = self.latent_rollout.output_proj(z0).unsqueeze(1)
204
+
205
+ # ==================== 4. Target Encoder (JEPA) ====================
206
+ results = {}
207
+
208
+ if self._use_jepa and self.training:
209
+ target_output = self.target_encoder(
210
+ visual_tokens=visual_tokens.detach(),
211
+ text_tokens=text_tokens.detach(),
212
+ text_mask=text_mask.detach(),
213
+ **{k: v.detach() if v is not None else None
214
+ for k, v in enriched_kwargs.items()},
215
+ )
216
+ target_trajectory = target_output['target_trajectory']
217
+ results['target_trajectory'] = target_trajectory
218
+
219
+ # ==================== 5. Answer Heads ====================
220
+ # Discriminative head (MC questions)
221
+ if option_embeddings is not None and option_mask is not None:
222
+ disc_output = self.disc_head(z_final, option_embeddings, option_mask)
223
+ results['disc_logits'] = disc_output['logits']
224
+ results['disc_probs'] = disc_output['probs']
225
+
226
+ # Task loss
227
+ if answer_labels is not None:
228
+ task_loss = F.cross_entropy(disc_output['logits'], answer_labels)
229
+ results['task_loss'] = task_loss
230
+
231
+ # Generative head (open-ended questions)
232
+ if gen_target_ids is not None:
233
+ gen_output = self.gen_head(
234
+ z_final=z_final,
235
+ target_ids=gen_target_ids,
236
+ evidence_tokens=evidence_tokens,
237
+ )
238
+ results['gen_logits'] = gen_output['logits']
239
+ results['gen_loss'] = gen_output['loss']
240
+
241
+ # ==================== 6. JEPA Loss ====================
242
+ if self._use_jepa and self.training and 'target_trajectory' in results:
243
+ task_loss = results.get('task_loss', torch.tensor(0.0, device=pixel_values.device))
244
+ gen_loss = results.get('gen_loss', None)
245
+
246
+ loss_dict = self.jepa_loss_fn(
247
+ predicted_trajectory=z_projected,
248
+ target_trajectory=target_trajectory,
249
+ task_loss=task_loss,
250
+ gen_loss=gen_loss,
251
+ )
252
+ results.update(loss_dict)
253
+ elif 'task_loss' in results:
254
+ results['total_loss'] = results['task_loss']
255
+ if 'gen_loss' in results:
256
+ results['total_loss'] = results['total_loss'] + \
257
+ self.config.jepa.generative_loss_weight * results['gen_loss']
258
+
259
+ # Store trajectory for analysis
260
+ results['trajectory'] = trajectory
261
+ results['z_final'] = z_final
262
+ results['evidence_tokens'] = evidence_tokens
263
+
264
+ return results
265
+
266
+ def update_target_encoder(self, step: int, total_steps: int):
267
+ """Update EMA target encoder (call after each optimizer step)."""
268
+ self.target_encoder.update_ema(
269
+ online_evidence_memory=self.evidence_memory,
270
+ online_rollout=self.latent_rollout,
271
+ step=step,
272
+ total_steps=total_steps,
273
+ )
274
+
275
+ @torch.no_grad()
276
+ def predict_mc(
277
+ self,
278
+ pixel_values: torch.Tensor,
279
+ input_ids: torch.Tensor,
280
+ attention_mask: torch.Tensor,
281
+ option_embeddings: torch.Tensor,
282
+ option_mask: torch.Tensor,
283
+ ) -> torch.Tensor:
284
+ """Predict answer for multiple-choice questions. Returns predicted indices."""
285
+ self.eval()
286
+ outputs = self.forward(
287
+ pixel_values=pixel_values,
288
+ input_ids=input_ids,
289
+ attention_mask=attention_mask,
290
+ option_embeddings=option_embeddings,
291
+ option_mask=option_mask,
292
+ )
293
+ return outputs['disc_probs'].argmax(dim=-1)
294
+
295
+ @torch.no_grad()
296
+ def predict_open(
297
+ self,
298
+ pixel_values: torch.Tensor,
299
+ input_ids: torch.Tensor,
300
+ attention_mask: torch.Tensor,
301
+ start_token_id: int,
302
+ max_length: int = 64,
303
+ eos_token_id: Optional[int] = None,
304
+ ) -> torch.Tensor:
305
+ """Generate short answer for open-ended questions."""
306
+ self.eval()
307
+ outputs = self.forward(
308
+ pixel_values=pixel_values,
309
+ input_ids=input_ids,
310
+ attention_mask=attention_mask,
311
+ )
312
+ return self.gen_head.generate(
313
+ z_final=outputs['z_final'],
314
+ start_token_id=start_token_id,
315
+ max_length=max_length,
316
+ evidence_tokens=outputs['evidence_tokens'],
317
+ eos_token_id=eos_token_id,
318
+ )
319
+
320
+ def set_ablation(self, use_jepa: bool = True, use_rollout: bool = True,
321
+ use_evidence_gate: bool = True):
322
+ """Configure ablation settings for experiments."""
323
+ self._use_jepa = use_jepa
324
+ self._use_rollout = use_rollout
325
+
326
+ # Disable evidence gates in rollout
327
+ if not use_evidence_gate:
328
+ for layer in self.latent_rollout.predictor_layers:
329
+ layer.evidence_gate = lambda s, e: e # Identity gate
330
+
331
+ def count_parameters(self) -> Dict[str, int]:
332
+ """Count parameters by component."""
333
+ counts = {}
334
+ for name, module in [
335
+ ('visual_backbone', self.visual_backbone),
336
+ ('text_encoder', self.text_encoder),
337
+ ('evidence_memory', self.evidence_memory),
338
+ ('latent_rollout', self.latent_rollout),
339
+ ('disc_head', self.disc_head),
340
+ ('gen_head', self.gen_head),
341
+ ]:
342
+ total = sum(p.numel() for p in module.parameters())
343
+ trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
344
+ counts[name] = {'total': total, 'trainable': trainable}
345
+
346
+ counts['total'] = {
347
+ 'total': sum(c['total'] for c in counts.values()),
348
+ 'trainable': sum(c['trainable'] for c in counts.values()),
349
+ }
350
+ return counts
mr_jepa/models/target_encoder.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Target Encoder (EMA) for MR-JEPA.
3
+
4
+ The target encoder generates the supervision signal for the JEPA objective.
5
+ It is an exponential moving average (EMA) copy of the online encoder
6
+ (evidence memory + rollout module).
7
+
8
+ From I-JEPA:
9
+ θ̄ ← m·θ̄ + (1-m)·θ
10
+ where m follows a cosine schedule from 0.996 → 1.0
11
+
12
+ The target encoder processes the same inputs but with stop-gradient,
13
+ producing target latent states z*_k that the online predictor must predict.
14
+
15
+ From LeWorldModel: We also add SIGReg anti-collapse regularization
16
+ to prevent the representation space from collapsing.
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import math
23
+ import copy
24
+ from typing import Optional, Dict
25
+
26
+ from ..configs.model_config import JEPAObjectiveConfig
27
+
28
+
29
+ class TargetEncoder(nn.Module):
30
+ """
31
+ EMA target encoder that generates JEPA targets.
32
+
33
+ This module wraps a copy of the online encoder (evidence memory + rollout)
34
+ and updates its weights via exponential moving average.
35
+
36
+ The target latent trajectory is used as the ground truth for the
37
+ JEPA prediction loss: ||z_predicted_k - sg(z*_k)||²
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ online_evidence_memory: nn.Module,
43
+ online_rollout: nn.Module,
44
+ config: JEPAObjectiveConfig,
45
+ ):
46
+ super().__init__()
47
+ self.config = config
48
+
49
+ # Deep copy of online modules
50
+ self.target_evidence_memory = copy.deepcopy(online_evidence_memory)
51
+ self.target_rollout = copy.deepcopy(online_rollout)
52
+
53
+ # Freeze target encoder (no gradient)
54
+ for param in self.target_evidence_memory.parameters():
55
+ param.requires_grad = False
56
+ for param in self.target_rollout.parameters():
57
+ param.requires_grad = False
58
+
59
+ # EMA schedule tracking
60
+ self._current_momentum = config.ema_momentum_base
61
+
62
+ @torch.no_grad()
63
+ def update_ema(
64
+ self,
65
+ online_evidence_memory: nn.Module,
66
+ online_rollout: nn.Module,
67
+ step: int,
68
+ total_steps: int,
69
+ ):
70
+ """
71
+ Update target encoder weights via EMA.
72
+
73
+ From I-JEPA: cosine schedule from base momentum to 1.0
74
+ m(t) = 1 - (1 - m_base) * (1 + cos(π * t / T)) / 2
75
+ """
76
+ # Compute momentum
77
+ if self.config.ema_schedule == "cosine":
78
+ # Cosine annealing from base to end momentum
79
+ progress = step / max(total_steps, 1)
80
+ momentum = self.config.ema_momentum_end - \
81
+ (self.config.ema_momentum_end - self.config.ema_momentum_base) * \
82
+ (1 + math.cos(math.pi * progress)) / 2
83
+ elif self.config.ema_schedule == "linear":
84
+ progress = step / max(total_steps, 1)
85
+ momentum = self.config.ema_momentum_base + \
86
+ (self.config.ema_momentum_end - self.config.ema_momentum_base) * progress
87
+ else: # constant
88
+ momentum = self.config.ema_momentum_base
89
+
90
+ self._current_momentum = momentum
91
+
92
+ # Update evidence memory
93
+ for online_p, target_p in zip(
94
+ online_evidence_memory.parameters(),
95
+ self.target_evidence_memory.parameters()
96
+ ):
97
+ target_p.data.mul_(momentum).add_(online_p.data, alpha=1 - momentum)
98
+
99
+ # Update rollout module
100
+ for online_p, target_p in zip(
101
+ online_rollout.parameters(),
102
+ self.target_rollout.parameters()
103
+ ):
104
+ target_p.data.mul_(momentum).add_(online_p.data, alpha=1 - momentum)
105
+
106
+ @torch.no_grad()
107
+ def forward(
108
+ self,
109
+ visual_tokens: torch.Tensor,
110
+ text_tokens: torch.Tensor,
111
+ text_mask: torch.Tensor,
112
+ **enriched_kwargs,
113
+ ) -> Dict[str, torch.Tensor]:
114
+ """
115
+ Generate target latent trajectory (no gradient).
116
+
117
+ Returns:
118
+ dict with:
119
+ 'target_trajectory': [B, K+1, N_s, D] - target states
120
+ 'target_evidence': [B, N_e, D] - target evidence tokens
121
+ """
122
+ # Target evidence memory
123
+ evidence_output = self.target_evidence_memory(
124
+ visual_tokens=visual_tokens,
125
+ text_tokens=text_tokens,
126
+ text_mask=text_mask,
127
+ **enriched_kwargs,
128
+ )
129
+
130
+ target_evidence = evidence_output['evidence_tokens']
131
+
132
+ # Target rollout
133
+ rollout_output = self.target_rollout(
134
+ evidence_tokens=target_evidence,
135
+ )
136
+
137
+ return {
138
+ 'target_trajectory': rollout_output['trajectory'],
139
+ 'target_evidence': target_evidence,
140
+ }
141
+
142
+
143
+ class SIGRegLoss(nn.Module):
144
+ """
145
+ Sketched Isotropic Gaussian Regularizer (from LeWorldModel).
146
+
147
+ Prevents representation collapse by encouraging latent embeddings
148
+ to match an isotropic Gaussian distribution.
149
+
150
+ Uses random projections + Epps-Pulley test statistic.
151
+ SIGReg(Z) = (1/M) Σ_m T(Z @ u_m)
152
+
153
+ where T is the Epps-Pulley univariate normality test.
154
+ """
155
+
156
+ def __init__(self, hidden_dim: int, num_projections: int = 1024):
157
+ super().__init__()
158
+ self.num_projections = num_projections
159
+ # Random projection directions (fixed, not learned)
160
+ self.register_buffer(
161
+ 'projections',
162
+ F.normalize(torch.randn(hidden_dim, num_projections), dim=0)
163
+ )
164
+
165
+ def _epps_pulley_statistic(self, h: torch.Tensor) -> torch.Tensor:
166
+ """
167
+ Compute Epps-Pulley test statistic for univariate normality.
168
+
169
+ T(h) measures how far the distribution of h is from N(0,1).
170
+ Lower values = more Gaussian.
171
+
172
+ Simplified version: uses moment-based approximation.
173
+ """
174
+ # Standardize
175
+ h_mean = h.mean()
176
+ h_std = h.std() + 1e-6
177
+ h_norm = (h - h_mean) / h_std
178
+
179
+ n = h_norm.size(0)
180
+
181
+ # Compute pairwise differences for the EP statistic
182
+ # EP test: based on characteristic function
183
+ # Simplified: variance + kurtosis penalty
184
+ variance = h_norm.var()
185
+ kurtosis = ((h_norm ** 4).mean() - 3).abs() # Excess kurtosis
186
+
187
+ # Penalize deviation from unit variance and zero excess kurtosis
188
+ return (variance - 1.0) ** 2 + 0.5 * kurtosis
189
+
190
+ def forward(self, z: torch.Tensor) -> torch.Tensor:
191
+ """
192
+ Compute SIGReg loss.
193
+
194
+ Args:
195
+ z: Latent embeddings [B, N, D] or [B*N, D]
196
+
197
+ Returns:
198
+ Scalar SIGReg loss
199
+ """
200
+ if z.dim() == 3:
201
+ B, N, D = z.shape
202
+ z_flat = z.reshape(B * N, D)
203
+ else:
204
+ z_flat = z
205
+
206
+ # Project onto random directions
207
+ projections = z_flat @ self.projections # [B*N, M]
208
+
209
+ # Compute EP statistic for each projection
210
+ losses = []
211
+ for m in range(min(self.num_projections, 64)): # Sample subset for efficiency
212
+ losses.append(self._epps_pulley_statistic(projections[:, m]))
213
+
214
+ return torch.stack(losses).mean()
215
+
216
+
217
+ class VICRegLoss(nn.Module):
218
+ """
219
+ VICReg-style regularization (alternative to SIGReg).
220
+
221
+ Three terms:
222
+ - Variance: keep feature std above a threshold
223
+ - Invariance: prediction should match target (already handled by L2)
224
+ - Covariance: decorrelate features
225
+ """
226
+
227
+ def __init__(self, var_weight: float = 1.0, cov_weight: float = 0.04):
228
+ super().__init__()
229
+ self.var_weight = var_weight
230
+ self.cov_weight = cov_weight
231
+
232
+ def forward(self, z: torch.Tensor) -> torch.Tensor:
233
+ """
234
+ Args:
235
+ z: [B*N, D] latent embeddings
236
+ """
237
+ if z.dim() == 3:
238
+ z = z.reshape(-1, z.size(-1))
239
+
240
+ # Variance: penalize if std drops below 1
241
+ std = z.std(dim=0)
242
+ var_loss = F.relu(1.0 - std).mean()
243
+
244
+ # Covariance: penalize off-diagonal correlations
245
+ z_centered = z - z.mean(dim=0, keepdim=True)
246
+ N = z_centered.size(0)
247
+ cov = (z_centered.T @ z_centered) / (N - 1)
248
+ D = cov.size(0)
249
+ # Off-diagonal elements
250
+ off_diag = cov.flatten()[:-1].view(D - 1, D + 1)[:, 1:].flatten()
251
+ cov_loss = (off_diag ** 2).mean()
252
+
253
+ return self.var_weight * var_loss + self.cov_weight * cov_loss
254
+
255
+
256
+ class JEPALoss(nn.Module):
257
+ """
258
+ Complete JEPA objective for MR-JEPA.
259
+
260
+ L_JEPA = (1/K) Σ_{k=1}^{K} ||z_pred_k - sg(z*_k)||²
261
+
262
+ Plus anti-collapse regularization:
263
+ L_total = L_JEPA + λ * SIGReg(Z) + L_task + α * L_gen
264
+ """
265
+
266
+ def __init__(self, config: JEPAObjectiveConfig, hidden_dim: int):
267
+ super().__init__()
268
+ self.config = config
269
+
270
+ # Anti-collapse
271
+ if config.use_sigreg:
272
+ self.sigreg = SIGRegLoss(hidden_dim, config.sigreg_num_projections)
273
+ if config.use_vicreg:
274
+ self.vicreg = VICRegLoss(config.vicreg_var_weight, config.vicreg_cov_weight)
275
+
276
+ def compute_jepa_loss(
277
+ self,
278
+ predicted_trajectory: torch.Tensor, # [B, K+1, N_s, D]
279
+ target_trajectory: torch.Tensor, # [B, K+1, N_s, D]
280
+ ) -> torch.Tensor:
281
+ """
282
+ Compute L2 prediction loss between online and target trajectories.
283
+
284
+ Only compute loss for steps k=1..K (not z₀, which is deterministic).
285
+ """
286
+ # Skip z₀ (step 0) — only supervise predicted states
287
+ pred = predicted_trajectory[:, 1:] # [B, K, N_s, D]
288
+ target = target_trajectory[:, 1:] # [B, K, N_s, D]
289
+
290
+ # L2 loss per step, averaged
291
+ loss = F.mse_loss(pred, target.detach())
292
+ return loss
293
+
294
+ def compute_regularization(
295
+ self,
296
+ trajectory: torch.Tensor, # [B, K+1, N_s, D]
297
+ ) -> torch.Tensor:
298
+ """Compute anti-collapse regularization."""
299
+ reg_loss = torch.tensor(0.0, device=trajectory.device)
300
+
301
+ if self.config.use_sigreg:
302
+ # Apply SIGReg to each step's representations
303
+ B, Kp1, N_s, D = trajectory.shape
304
+ for k in range(Kp1):
305
+ reg_loss = reg_loss + self.sigreg(trajectory[:, k])
306
+ reg_loss = reg_loss / Kp1
307
+ reg_loss = self.config.sigreg_weight * reg_loss
308
+
309
+ if self.config.use_vicreg:
310
+ B, Kp1, N_s, D = trajectory.shape
311
+ for k in range(Kp1):
312
+ reg_loss = reg_loss + self.vicreg(trajectory[:, k])
313
+ reg_loss = reg_loss / Kp1
314
+
315
+ return reg_loss
316
+
317
+ def forward(
318
+ self,
319
+ predicted_trajectory: torch.Tensor,
320
+ target_trajectory: torch.Tensor,
321
+ task_loss: torch.Tensor,
322
+ gen_loss: Optional[torch.Tensor] = None,
323
+ ) -> Dict[str, torch.Tensor]:
324
+ """
325
+ Compute total MR-JEPA loss.
326
+
327
+ Returns dict with individual loss components for logging.
328
+ """
329
+ # JEPA prediction loss
330
+ jepa_loss = self.compute_jepa_loss(predicted_trajectory, target_trajectory)
331
+
332
+ # Anti-collapse regularization
333
+ reg_loss = self.compute_regularization(predicted_trajectory)
334
+
335
+ # Total loss
336
+ total = (
337
+ self.config.jepa_loss_weight * jepa_loss +
338
+ self.config.task_loss_weight * task_loss +
339
+ reg_loss
340
+ )
341
+
342
+ losses = {
343
+ 'total_loss': total,
344
+ 'jepa_loss': jepa_loss,
345
+ 'task_loss': task_loss,
346
+ 'reg_loss': reg_loss,
347
+ }
348
+
349
+ if gen_loss is not None:
350
+ total = total + self.config.generative_loss_weight * gen_loss
351
+ losses['total_loss'] = total
352
+ losses['gen_loss'] = gen_loss
353
+
354
+ return losses
mr_jepa/training/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .trainer import MRJEPATrainer
2
+ from .phase_scheduler import PhaseScheduler
3
+
4
+ __all__ = ["MRJEPATrainer", "PhaseScheduler"]
mr_jepa/training/phase_scheduler.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase Scheduler for MR-JEPA 3-Phase Training.
3
+
4
+ Manages the transition between training phases:
5
+ Phase 1: Freeze perception → train reasoning core
6
+ Phase 2: Unfreeze perception → fine-tune end-to-end
7
+ Phase 3: Enable enriched evidence → document/chart specialization
8
+ """
9
+
10
+ import math
11
+ import torch
12
+ from torch.optim.lr_scheduler import _LRScheduler
13
+ from typing import Optional
14
+
15
+
16
+ class CosineWarmupScheduler(_LRScheduler):
17
+ """Cosine schedule with linear warmup (per phase)."""
18
+
19
+ def __init__(
20
+ self,
21
+ optimizer: torch.optim.Optimizer,
22
+ warmup_steps: int,
23
+ total_steps: int,
24
+ min_lr_ratio: float = 0.01,
25
+ last_epoch: int = -1,
26
+ ):
27
+ self.warmup_steps = warmup_steps
28
+ self.total_steps = total_steps
29
+ self.min_lr_ratio = min_lr_ratio
30
+ super().__init__(optimizer, last_epoch)
31
+
32
+ def get_lr(self):
33
+ step = self.last_epoch
34
+
35
+ if step < self.warmup_steps:
36
+ # Linear warmup
37
+ factor = step / max(self.warmup_steps, 1)
38
+ else:
39
+ # Cosine decay
40
+ progress = (step - self.warmup_steps) / max(
41
+ self.total_steps - self.warmup_steps, 1
42
+ )
43
+ factor = self.min_lr_ratio + (1 - self.min_lr_ratio) * \
44
+ 0.5 * (1 + math.cos(math.pi * progress))
45
+
46
+ return [base_lr * factor for base_lr in self.base_lrs]
47
+
48
+
49
+ class PhaseScheduler:
50
+ """
51
+ Orchestrates the 3-phase training schedule.
52
+
53
+ Handles:
54
+ - Phase transitions (unfreezing, enabling modules)
55
+ - Per-phase optimizer and LR scheduler creation
56
+ - Checkpoint management between phases
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ model,
62
+ training_config,
63
+ ):
64
+ self.model = model
65
+ self.training_config = training_config
66
+ self.current_phase = 0
67
+ self.phase_histories = {1: [], 2: [], 3: []}
68
+
69
+ def get_phase_scheduler(
70
+ self,
71
+ optimizer: torch.optim.Optimizer,
72
+ phase: int,
73
+ steps_per_epoch: int,
74
+ ) -> CosineWarmupScheduler:
75
+ """Create LR scheduler for a specific phase."""
76
+ if phase == 1:
77
+ epochs = self.training_config.phase1_epochs
78
+ warmup_ratio = self.training_config.phase1_warmup_ratio
79
+ elif phase == 2:
80
+ epochs = self.training_config.phase2_epochs
81
+ warmup_ratio = self.training_config.phase2_warmup_ratio
82
+ else:
83
+ epochs = self.training_config.phase3_epochs
84
+ warmup_ratio = self.training_config.phase3_warmup_ratio
85
+
86
+ total_steps = epochs * steps_per_epoch
87
+ warmup_steps = int(total_steps * warmup_ratio)
88
+
89
+ return CosineWarmupScheduler(
90
+ optimizer=optimizer,
91
+ warmup_steps=warmup_steps,
92
+ total_steps=total_steps,
93
+ )
94
+
95
+ def should_transition(self, phase: int, epoch: int) -> bool:
96
+ """Check if we should move to the next phase."""
97
+ if phase == 1:
98
+ return epoch >= self.training_config.phase1_epochs
99
+ elif phase == 2:
100
+ return epoch >= self.training_config.phase2_epochs
101
+ elif phase == 3:
102
+ return epoch >= self.training_config.phase3_epochs
103
+ return True
104
+
105
+ def log_phase_metrics(self, phase: int, metrics: dict):
106
+ """Record metrics for phase transition analysis."""
107
+ self.phase_histories[phase].append(metrics)
mr_jepa/training/trainer.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MR-JEPA Trainer.
3
+
4
+ Implements the 3-phase training schedule:
5
+
6
+ Phase 1 (Reasoning Core):
7
+ - Freeze visual backbone + text encoder
8
+ - Train evidence memory, latent rollout, answer heads
9
+ - Full JEPA objective + task loss
10
+
11
+ Phase 2 (Perception Fine-tuning):
12
+ - Unfreeze last N visual backbone layers (lower LR)
13
+ - Unfreeze last N text encoder layers (lower LR)
14
+ - Continue training all other components
15
+
16
+ Phase 3 (Enriched Evidence):
17
+ - Enable OCR, layout, chart tokens
18
+ - Fine-tune entire model end-to-end
19
+ - Focus on document/chart benchmarks
20
+
21
+ Each phase uses cosine LR schedule with warmup.
22
+ EMA target encoder is updated after each optimizer step.
23
+ """
24
+
25
+ import os
26
+ import time
27
+ import json
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ from torch.optim import AdamW
32
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
33
+ from torch.cuda.amp import autocast, GradScaler
34
+ from typing import Optional, Dict, Any, List
35
+ import logging
36
+ from pathlib import Path
37
+
38
+ from ..configs.model_config import MRJEPAConfig, TrainingPhaseConfig
39
+ from ..models.mr_jepa import MRJEPAModel
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ class MRJEPATrainer:
45
+ """
46
+ 3-phase trainer for MR-JEPA.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ model: MRJEPAModel,
52
+ config: MRJEPAConfig,
53
+ training_config: TrainingPhaseConfig,
54
+ train_dataloaders: Dict[str, Any], # Per-benchmark dataloaders
55
+ eval_dataloaders: Dict[str, Any],
56
+ output_dir: str = "./outputs",
57
+ device: str = "cuda",
58
+ ):
59
+ self.model = model.to(device)
60
+ self.config = config
61
+ self.training_config = training_config
62
+ self.train_dataloaders = train_dataloaders
63
+ self.eval_dataloaders = eval_dataloaders
64
+ self.output_dir = Path(output_dir)
65
+ self.output_dir.mkdir(parents=True, exist_ok=True)
66
+ self.device = device
67
+
68
+ # Training state
69
+ self.global_step = 0
70
+ self.current_phase = 0
71
+ self.best_metric = 0.0
72
+
73
+ # Mixed precision
74
+ self.use_amp = training_config.bf16 or training_config.fp16
75
+ self.amp_dtype = torch.bfloat16 if training_config.bf16 else torch.float16
76
+ self.scaler = GradScaler(enabled=training_config.fp16) # Only for fp16
77
+
78
+ def _build_optimizer(self, phase: int) -> torch.optim.Optimizer:
79
+ """Build optimizer with per-phase parameter groups."""
80
+ param_groups = self.model.get_trainable_params(phase)
81
+
82
+ if phase == 1:
83
+ lr = self.training_config.phase1_lr
84
+ groups = [
85
+ {'params': param_groups['high_lr'], 'lr': lr},
86
+ ]
87
+ elif phase == 2:
88
+ lr = self.training_config.phase2_lr
89
+ backbone_lr = self.training_config.phase2_backbone_lr
90
+ groups = [
91
+ {'params': param_groups['high_lr'], 'lr': lr},
92
+ {'params': param_groups['low_lr'], 'lr': backbone_lr},
93
+ ]
94
+ else: # phase 3
95
+ lr = self.training_config.phase3_lr
96
+ groups = [
97
+ {'params': param_groups['high_lr'], 'lr': lr},
98
+ {'params': param_groups['low_lr'], 'lr': lr * 0.1},
99
+ ]
100
+
101
+ # Filter out empty param groups
102
+ groups = [g for g in groups if len(g['params']) > 0]
103
+
104
+ optimizer = AdamW(
105
+ groups,
106
+ weight_decay=self.training_config.phase1_weight_decay,
107
+ )
108
+ return optimizer
109
+
110
+ def _get_phase_config(self, phase: int) -> Dict[str, Any]:
111
+ """Get training parameters for a specific phase."""
112
+ if phase == 1:
113
+ return {
114
+ 'epochs': self.training_config.phase1_epochs,
115
+ 'batch_size': self.training_config.phase1_batch_size,
116
+ 'grad_accum': self.training_config.phase1_grad_accum,
117
+ 'warmup_ratio': self.training_config.phase1_warmup_ratio,
118
+ }
119
+ elif phase == 2:
120
+ return {
121
+ 'epochs': self.training_config.phase2_epochs,
122
+ 'batch_size': self.training_config.phase2_batch_size,
123
+ 'grad_accum': self.training_config.phase2_grad_accum,
124
+ 'warmup_ratio': self.training_config.phase2_warmup_ratio,
125
+ }
126
+ else:
127
+ return {
128
+ 'epochs': self.training_config.phase3_epochs,
129
+ 'batch_size': self.training_config.phase3_batch_size,
130
+ 'grad_accum': self.training_config.phase3_grad_accum,
131
+ 'warmup_ratio': self.training_config.phase3_warmup_ratio,
132
+ }
133
+
134
+ def _prepare_phase(self, phase: int):
135
+ """Set up model for a specific training phase."""
136
+ logger.info(f"=== Preparing Phase {phase} ===")
137
+
138
+ if phase == 1:
139
+ # Freeze all perception, train reasoning core
140
+ self.model.visual_backbone.freeze_all()
141
+ self.model.text_encoder.freeze_all()
142
+
143
+ elif phase == 2:
144
+ # Unfreeze last N layers of backbones
145
+ n_visual = self.training_config.phase2_unfreeze_visual_layers
146
+ n_text = self.training_config.phase2_unfreeze_text_layers
147
+ self.model.visual_backbone.unfreeze_last_n_layers(n_visual)
148
+ self.model.text_encoder.unfreeze_last_n_layers(n_text)
149
+ logger.info(f"Unfroze last {n_visual} visual layers, {n_text} text layers")
150
+
151
+ elif phase == 3:
152
+ # Enable enriched evidence
153
+ if self.training_config.phase3_enable_ocr:
154
+ self.config.evidence.use_ocr_tokens = True
155
+ if self.training_config.phase3_enable_layout:
156
+ self.config.evidence.use_layout_tokens = True
157
+ if self.training_config.phase3_enable_chart:
158
+ self.config.evidence.use_chart_tokens = True
159
+ if self.training_config.phase3_enable_sam:
160
+ self.config.evidence.use_sam_tokens = True
161
+ logger.info("Enabled enriched evidence tokens")
162
+
163
+ self.current_phase = phase
164
+
165
+ def _train_step(
166
+ self,
167
+ batch: Dict[str, torch.Tensor],
168
+ optimizer: torch.optim.Optimizer,
169
+ grad_accum_steps: int,
170
+ total_steps: int,
171
+ ) -> Dict[str, float]:
172
+ """Single training step with gradient accumulation."""
173
+ # Move batch to device
174
+ device_batch = {}
175
+ for k, v in batch.items():
176
+ if isinstance(v, torch.Tensor):
177
+ device_batch[k] = v.to(self.device)
178
+ else:
179
+ device_batch[k] = v
180
+
181
+ # Handle option embeddings (encode option texts through text encoder)
182
+ if 'option_texts' in batch:
183
+ option_embs = self._encode_options(batch['option_texts'])
184
+ device_batch['option_embeddings'] = option_embs.to(self.device)
185
+
186
+ # Forward pass with AMP
187
+ with autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp):
188
+ outputs = self.model(
189
+ pixel_values=device_batch.get('pixel_values'),
190
+ input_ids=device_batch.get('input_ids'),
191
+ attention_mask=device_batch.get('attention_mask'),
192
+ option_embeddings=device_batch.get('option_embeddings'),
193
+ option_mask=device_batch.get('option_mask'),
194
+ answer_labels=device_batch.get('answer_labels'),
195
+ gen_target_ids=device_batch.get('gen_target_ids'),
196
+ )
197
+
198
+ loss = outputs.get('total_loss', outputs.get('task_loss', torch.tensor(0.0)))
199
+ loss = loss / grad_accum_steps
200
+
201
+ # Backward
202
+ if self.training_config.fp16:
203
+ self.scaler.scale(loss).backward()
204
+ else:
205
+ loss.backward()
206
+
207
+ # Step optimizer (with grad accumulation)
208
+ if (self.global_step + 1) % grad_accum_steps == 0:
209
+ if self.training_config.max_grad_norm > 0:
210
+ if self.training_config.fp16:
211
+ self.scaler.unscale_(optimizer)
212
+ nn.utils.clip_grad_norm_(
213
+ self.model.parameters(),
214
+ self.training_config.max_grad_norm,
215
+ )
216
+
217
+ if self.training_config.fp16:
218
+ self.scaler.step(optimizer)
219
+ self.scaler.update()
220
+ else:
221
+ optimizer.step()
222
+
223
+ optimizer.zero_grad()
224
+
225
+ # Update EMA target encoder
226
+ self.model.update_target_encoder(self.global_step, total_steps)
227
+
228
+ self.global_step += 1
229
+
230
+ # Collect metrics
231
+ metrics = {
232
+ 'loss': loss.item() * grad_accum_steps,
233
+ }
234
+ for key in ['jepa_loss', 'task_loss', 'reg_loss', 'gen_loss']:
235
+ if key in outputs:
236
+ metrics[key] = outputs[key].item()
237
+
238
+ return metrics
239
+
240
+ def _encode_options(self, option_texts: List[List[str]]) -> torch.Tensor:
241
+ """Encode option texts using the text encoder (pooled representation)."""
242
+ B = len(option_texts)
243
+ max_opts = len(option_texts[0])
244
+
245
+ # Flatten all options
246
+ flat_texts = []
247
+ for opts in option_texts:
248
+ flat_texts.extend(opts)
249
+
250
+ # Tokenize
251
+ tokenizer = self.model.text_encoder.tokenizer
252
+ encoded = tokenizer(
253
+ flat_texts,
254
+ padding='max_length',
255
+ truncation=True,
256
+ max_length=64,
257
+ return_tensors='pt',
258
+ )
259
+
260
+ # Encode through text encoder (no gradient for efficiency)
261
+ with torch.no_grad():
262
+ text_output = self.model.text_encoder(
263
+ input_ids=encoded['input_ids'].to(self.device),
264
+ attention_mask=encoded['attention_mask'].to(self.device),
265
+ )
266
+
267
+ # Get CLS embedding for each option
268
+ cls_embeddings = text_output['cls_embedding'] # [B*max_opts, D]
269
+ option_embeddings = cls_embeddings.reshape(B, max_opts, -1) # [B, max_opts, D]
270
+
271
+ return option_embeddings
272
+
273
+ def train_phase(self, phase: int):
274
+ """Run a complete training phase."""
275
+ self._prepare_phase(phase)
276
+
277
+ phase_config = self._get_phase_config(phase)
278
+ optimizer = self._build_optimizer(phase)
279
+
280
+ total_steps = phase_config['epochs'] * sum(
281
+ len(dl) for dl in self.train_dataloaders.values()
282
+ )
283
+
284
+ logger.info(f"Phase {phase}: {phase_config['epochs']} epochs, "
285
+ f"~{total_steps} steps")
286
+
287
+ self.model.train()
288
+
289
+ for epoch in range(phase_config['epochs']):
290
+ epoch_metrics = {}
291
+
292
+ # Iterate over all training benchmarks
293
+ for benchmark_name, dataloader in self.train_dataloaders.items():
294
+ for step, batch in enumerate(dataloader):
295
+ metrics = self._train_step(
296
+ batch, optimizer,
297
+ phase_config['grad_accum'],
298
+ total_steps,
299
+ )
300
+
301
+ # Accumulate metrics
302
+ for k, v in metrics.items():
303
+ epoch_metrics.setdefault(k, []).append(v)
304
+
305
+ # Logging
306
+ if self.global_step % 100 == 0:
307
+ avg_loss = sum(epoch_metrics.get('loss', [0])) / max(len(epoch_metrics.get('loss', [1])), 1)
308
+ logger.info(
309
+ f"Phase {phase} | Epoch {epoch} | Step {self.global_step} | "
310
+ f"Loss: {avg_loss:.4f} | "
311
+ f"Benchmark: {benchmark_name}"
312
+ )
313
+
314
+ # Epoch-level logging
315
+ avg_metrics = {
316
+ k: sum(v) / len(v) for k, v in epoch_metrics.items()
317
+ }
318
+ logger.info(f"Phase {phase} | Epoch {epoch} complete | "
319
+ f"Avg Loss: {avg_metrics.get('loss', 0):.4f}")
320
+
321
+ # Save checkpoint
322
+ self._save_checkpoint(phase, epoch)
323
+
324
+ def train(self, phases: List[int] = [1, 2, 3]):
325
+ """Run the full multi-phase training."""
326
+ logger.info("Starting MR-JEPA training")
327
+ logger.info(f"Model parameter counts: {self.model.count_parameters()}")
328
+
329
+ for phase in phases:
330
+ logger.info(f"\n{'='*60}")
331
+ logger.info(f"PHASE {phase}")
332
+ logger.info(f"{'='*60}")
333
+ self.train_phase(phase)
334
+
335
+ # Evaluate after each phase
336
+ eval_results = self.evaluate()
337
+ logger.info(f"Phase {phase} eval results: {json.dumps(eval_results, indent=2)}")
338
+
339
+ logger.info("Training complete!")
340
+
341
+ def evaluate(self) -> Dict[str, Dict[str, float]]:
342
+ """Evaluate on all benchmark eval sets."""
343
+ from ..evaluation.metrics import evaluate_benchmark
344
+
345
+ self.model.eval()
346
+ results = {}
347
+
348
+ for benchmark_name, dataloader in self.eval_dataloaders.items():
349
+ predictions = []
350
+ ground_truths = []
351
+
352
+ with torch.no_grad():
353
+ for batch in dataloader:
354
+ # Move to device
355
+ pixel_values = batch['pixel_values'].to(self.device)
356
+ input_ids = batch['input_ids'].to(self.device)
357
+ attention_mask = batch['attention_mask'].to(self.device)
358
+
359
+ if 'option_mask' in batch:
360
+ option_mask = batch['option_mask'].to(self.device)
361
+ option_embs = self._encode_options(batch['option_texts'])
362
+
363
+ preds = self.model.predict_mc(
364
+ pixel_values, input_ids, attention_mask,
365
+ option_embs, option_mask,
366
+ )
367
+ predictions.extend(preds.cpu().tolist())
368
+ ground_truths.extend(batch['answer_labels'].tolist())
369
+ else:
370
+ # Open-ended (would need generation)
371
+ # Simplified: skip for now
372
+ pass
373
+
374
+ if predictions:
375
+ result = evaluate_benchmark(
376
+ benchmark_name, predictions, ground_truths
377
+ )
378
+ results[benchmark_name] = result
379
+
380
+ self.model.train()
381
+ return results
382
+
383
+ def _save_checkpoint(self, phase: int, epoch: int):
384
+ """Save model checkpoint."""
385
+ ckpt_dir = self.output_dir / f"phase{phase}_epoch{epoch}"
386
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
387
+
388
+ # Save model state
389
+ torch.save({
390
+ 'model_state_dict': self.model.state_dict(),
391
+ 'phase': phase,
392
+ 'epoch': epoch,
393
+ 'global_step': self.global_step,
394
+ 'config': self.config,
395
+ }, ckpt_dir / "checkpoint.pt")
396
+
397
+ logger.info(f"Saved checkpoint to {ckpt_dir}")
mr_jepa/utils/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .visualization import visualize_trajectory, visualize_evidence_gates
2
+ from .ablation import AblationRunner
3
+
4
+ __all__ = [
5
+ "visualize_trajectory",
6
+ "visualize_evidence_gates",
7
+ "AblationRunner",
8
+ ]
mr_jepa/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (312 Bytes). View file
 
mr_jepa/utils/__pycache__/ablation.cpython-312.pyc ADDED
Binary file (7.16 kB). View file
 
mr_jepa/utils/__pycache__/visualization.cpython-312.pyc ADDED
Binary file (5.35 kB). View file
 
mr_jepa/utils/ablation.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ablation Study Runner for MR-JEPA.
3
+
4
+ Supports systematic ablation experiments to validate the paper's contributions:
5
+
6
+ 1. Full MR-JEPA vs. No JEPA (remove JEPA loss, train with task loss only)
7
+ 2. Full MR-JEPA vs. No Rollout (use z₀ directly, K=0)
8
+ 3. Full MR-JEPA vs. No Evidence Gate (remove gating, always use full evidence)
9
+ 4. K=1 vs. K=3 vs. K=5 (rollout depth ablation)
10
+ 5. With vs. Without enriched evidence (Phase 3 ablation)
11
+ 6. Hybrid vs. Purist branch comparison
12
+ """
13
+
14
+ import copy
15
+ import json
16
+ import logging
17
+ from typing import Dict, List, Any, Optional
18
+ from dataclasses import dataclass, field
19
+ from pathlib import Path
20
+
21
+ from ..configs.model_config import MRJEPAConfig, get_hybrid_config, get_purist_config
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ @dataclass
27
+ class AblationConfig:
28
+ """Configuration for a single ablation experiment."""
29
+ name: str
30
+ description: str
31
+ modifications: Dict[str, Any] = field(default_factory=dict)
32
+ # What to change from the base config
33
+ disable_jepa: bool = False
34
+ disable_rollout: bool = False
35
+ disable_evidence_gate: bool = False
36
+ override_K: Optional[int] = None
37
+
38
+
39
+ # Predefined ablation experiments
40
+ ABLATION_EXPERIMENTS = {
41
+ "full_model": AblationConfig(
42
+ name="full_model",
43
+ description="Complete MR-JEPA (baseline)",
44
+ ),
45
+ "no_jepa": AblationConfig(
46
+ name="no_jepa",
47
+ description="Without JEPA objective (task loss only)",
48
+ disable_jepa=True,
49
+ ),
50
+ "no_rollout": AblationConfig(
51
+ name="no_rollout",
52
+ description="Without latent rollout (z₀ only, K=0)",
53
+ disable_rollout=True,
54
+ ),
55
+ "no_evidence_gate": AblationConfig(
56
+ name="no_evidence_gate",
57
+ description="Without evidence gating",
58
+ disable_evidence_gate=True,
59
+ ),
60
+ "K1": AblationConfig(
61
+ name="K1",
62
+ description="Rollout depth K=1",
63
+ override_K=1,
64
+ ),
65
+ "K3": AblationConfig(
66
+ name="K3",
67
+ description="Rollout depth K=3 (default)",
68
+ override_K=3,
69
+ ),
70
+ "K5": AblationConfig(
71
+ name="K5",
72
+ description="Rollout depth K=5",
73
+ override_K=5,
74
+ ),
75
+ "K7": AblationConfig(
76
+ name="K7",
77
+ description="Rollout depth K=7 (deep rollout)",
78
+ override_K=7,
79
+ ),
80
+ }
81
+
82
+
83
+ class AblationRunner:
84
+ """
85
+ Systematically run ablation experiments.
86
+
87
+ Usage:
88
+ runner = AblationRunner(base_config, experiments=['full_model', 'no_jepa', 'no_rollout'])
89
+ results = runner.run(train_data, eval_data)
90
+ runner.report()
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ base_config: Optional[MRJEPAConfig] = None,
96
+ experiments: Optional[List[str]] = None,
97
+ output_dir: str = "./ablations",
98
+ ):
99
+ self.base_config = base_config or get_hybrid_config()
100
+ self.experiments = experiments or list(ABLATION_EXPERIMENTS.keys())
101
+ self.output_dir = Path(output_dir)
102
+ self.output_dir.mkdir(parents=True, exist_ok=True)
103
+ self.results = {}
104
+
105
+ def _apply_ablation(self, config: MRJEPAConfig, ablation: AblationConfig) -> MRJEPAConfig:
106
+ """Apply ablation modifications to a config."""
107
+ modified = copy.deepcopy(config)
108
+
109
+ if ablation.override_K is not None:
110
+ modified.rollout.K = ablation.override_K
111
+
112
+ return modified
113
+
114
+ def generate_configs(self) -> Dict[str, MRJEPAConfig]:
115
+ """Generate configs for all ablation experiments."""
116
+ configs = {}
117
+ for exp_name in self.experiments:
118
+ if exp_name not in ABLATION_EXPERIMENTS:
119
+ logger.warning(f"Unknown ablation: {exp_name}")
120
+ continue
121
+
122
+ ablation = ABLATION_EXPERIMENTS[exp_name]
123
+ config = self._apply_ablation(self.base_config, ablation)
124
+ configs[exp_name] = config
125
+
126
+ return configs
127
+
128
+ def report(self) -> str:
129
+ """Generate a formatted ablation report."""
130
+ if not self.results:
131
+ return "No results yet."
132
+
133
+ lines = [
134
+ "=" * 80,
135
+ "MR-JEPA Ablation Study Results",
136
+ "=" * 80,
137
+ "",
138
+ ]
139
+
140
+ # Header
141
+ benchmarks = set()
142
+ for exp_results in self.results.values():
143
+ benchmarks.update(exp_results.keys())
144
+ benchmarks = sorted(benchmarks)
145
+
146
+ header = f"{'Experiment':<25}"
147
+ for b in benchmarks:
148
+ header += f" | {b:<12}"
149
+ lines.append(header)
150
+ lines.append("-" * len(header))
151
+
152
+ # Results rows
153
+ for exp_name, exp_results in self.results.items():
154
+ ablation = ABLATION_EXPERIMENTS.get(exp_name)
155
+ row = f"{exp_name:<25}"
156
+ for b in benchmarks:
157
+ if b in exp_results:
158
+ val = exp_results[b].get('accuracy',
159
+ exp_results[b].get('anls',
160
+ exp_results[b].get('vqa_accuracy',
161
+ exp_results[b].get('relaxed_accuracy', 0))))
162
+ row += f" | {val:>10.1f}%"
163
+ else:
164
+ row += f" | {'N/A':>10}"
165
+ lines.append(row)
166
+
167
+ lines.append("")
168
+ lines.append("Key findings:")
169
+
170
+ # Auto-detect key findings
171
+ if 'full_model' in self.results and 'no_jepa' in self.results:
172
+ lines.append("- JEPA vs No-JEPA: Compare 'full_model' and 'no_jepa' rows")
173
+ if 'full_model' in self.results and 'no_rollout' in self.results:
174
+ lines.append("- Rollout vs No-Rollout: Compare 'full_model' and 'no_rollout' rows")
175
+
176
+ report = "\n".join(lines)
177
+
178
+ # Save to file
179
+ with open(self.output_dir / "ablation_report.txt", "w") as f:
180
+ f.write(report)
181
+
182
+ return report
mr_jepa/utils/visualization.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualization utilities for MR-JEPA.
3
+
4
+ Tools for analyzing and visualizing:
5
+ - Latent trajectory evolution (z₀ → z₁ → z₂ → z₃)
6
+ - Evidence gate activations per rollout step
7
+ - Attention maps between state and evidence
8
+ - t-SNE/UMAP of latent states across benchmarks
9
+ """
10
+
11
+ import torch
12
+ import numpy as np
13
+ from typing import Optional, Dict, List
14
+
15
+
16
+ def visualize_trajectory(
17
+ trajectory: torch.Tensor, # [K+1, N_s, D]
18
+ method: str = "pca",
19
+ title: str = "Latent Trajectory Evolution",
20
+ ) -> Dict[str, np.ndarray]:
21
+ """
22
+ Visualize the latent trajectory z₀→z₁→...→z_K.
23
+
24
+ Projects high-dimensional states into 2D for plotting.
25
+ Returns coordinates that can be plotted with matplotlib.
26
+
27
+ Args:
28
+ trajectory: [K+1, N_s, D] latent states for a single sample
29
+ method: 'pca' or 'tsne'
30
+ title: Plot title
31
+
32
+ Returns:
33
+ Dict with 'coords': [K+1, 2] projected centroids per step
34
+ """
35
+ K_plus_1, N_s, D = trajectory.shape
36
+
37
+ # Pool each step's tokens into a single vector
38
+ centroids = trajectory.mean(dim=1).detach().cpu().numpy() # [K+1, D]
39
+
40
+ if method == "pca":
41
+ # Simple PCA (no sklearn dependency)
42
+ centered = centroids - centroids.mean(axis=0)
43
+ cov = np.cov(centered.T)
44
+ eigenvalues, eigenvectors = np.linalg.eigh(cov)
45
+ # Take top 2 components
46
+ idx = np.argsort(eigenvalues)[::-1][:2]
47
+ proj_matrix = eigenvectors[:, idx]
48
+ coords = centered @ proj_matrix
49
+ else:
50
+ # Fallback to PCA for simplicity
51
+ centered = centroids - centroids.mean(axis=0)
52
+ U, S, Vt = np.linalg.svd(centered, full_matrices=False)
53
+ coords = U[:, :2] * S[:2]
54
+
55
+ return {
56
+ 'coords': coords, # [K+1, 2]
57
+ 'centroids': centroids, # [K+1, D] original
58
+ 'step_labels': [f'z_{k}' for k in range(K_plus_1)],
59
+ }
60
+
61
+
62
+ def visualize_evidence_gates(
63
+ model,
64
+ sample_output: Dict[str, torch.Tensor],
65
+ ) -> Dict[str, np.ndarray]:
66
+ """
67
+ Extract and visualize evidence gate activations per rollout step.
68
+
69
+ Shows how much evidence flows into each step of the rollout.
70
+ Early steps may attend more to visual evidence, while later steps
71
+ rely more on accumulated reasoning.
72
+
73
+ Args:
74
+ model: MRJEPAModel instance
75
+ sample_output: Forward pass output dict
76
+
77
+ Returns:
78
+ Dict with gate activation statistics per step
79
+ """
80
+ # This requires hooks or storing gate values during forward pass
81
+ # For now, return placeholder structure
82
+ gate_stats = {
83
+ 'mean_gate_values': [],
84
+ 'gate_entropy': [],
85
+ }
86
+
87
+ # Access predictor layers' evidence gates
88
+ for i, layer in enumerate(model.latent_rollout.predictor_layers):
89
+ if hasattr(layer.evidence_gate, 'gate_proj'):
90
+ # Could install hooks here for detailed analysis
91
+ pass
92
+
93
+ return gate_stats
94
+
95
+
96
+ def compute_trajectory_metrics(
97
+ trajectory: torch.Tensor, # [B, K+1, N_s, D]
98
+ ) -> Dict[str, float]:
99
+ """
100
+ Compute analytical metrics on the latent trajectory.
101
+
102
+ Useful for ablation analysis:
103
+ - Inter-step distance: how much the state changes per step
104
+ - Trajectory length: total path length in latent space
105
+ - Convergence rate: diminishing step sizes indicate convergence
106
+ - State diversity: variance within each step's tokens
107
+ """
108
+ B, K_plus_1, N_s, D = trajectory.shape
109
+
110
+ # Pool to centroids
111
+ centroids = trajectory.mean(dim=2) # [B, K+1, D]
112
+
113
+ # Inter-step distances
114
+ step_distances = []
115
+ for k in range(K_plus_1 - 1):
116
+ dist = torch.norm(centroids[:, k+1] - centroids[:, k], dim=-1) # [B]
117
+ step_distances.append(dist.mean().item())
118
+
119
+ # Trajectory length
120
+ total_length = sum(step_distances)
121
+
122
+ # Convergence rate (ratio of last step distance to first)
123
+ convergence = step_distances[-1] / max(step_distances[0], 1e-6) if step_distances else 1.0
124
+
125
+ # State diversity per step
126
+ diversity = []
127
+ for k in range(K_plus_1):
128
+ var = trajectory[:, k].var(dim=1).mean().item() # Avg variance across tokens
129
+ diversity.append(var)
130
+
131
+ return {
132
+ 'step_distances': step_distances,
133
+ 'trajectory_length': total_length,
134
+ 'convergence_rate': convergence,
135
+ 'state_diversity': diversity,
136
+ 'avg_step_distance': total_length / max(K_plus_1 - 1, 1),
137
+ }
test_architecture.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MR-JEPA Architecture Validation Test.
3
+
4
+ Tests the complete forward pass with synthetic data to verify:
5
+ 1. All modules instantiate correctly
6
+ 2. Tensor shapes are consistent throughout
7
+ 3. JEPA loss computes correctly
8
+ 4. Target encoder EMA updates work
9
+ 5. Both MC and open-ended heads produce valid output
10
+ 6. Ablation controls work (no-JEPA, no-rollout, no-evidence-gate)
11
+ 7. Parameter counting is correct
12
+ """
13
+
14
+ import sys
15
+ sys.path.insert(0, '/app')
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import numpy as np
20
+ from mr_jepa.configs.model_config import (
21
+ MRJEPAConfig, VisualBackboneConfig, TextEncoderConfig,
22
+ EvidenceMemoryConfig, LatentRolloutConfig, JEPAObjectiveConfig,
23
+ AnswerHeadConfig, TrainingPhaseConfig,
24
+ )
25
+ from mr_jepa.models.evidence_memory import EvidenceMemory
26
+ from mr_jepa.models.latent_rollout import LatentRolloutModule
27
+ from mr_jepa.models.target_encoder import TargetEncoder, JEPALoss, SIGRegLoss, VICRegLoss
28
+ from mr_jepa.models.answer_heads import DiscriminativeHead, GenerativeHead
29
+
30
+
31
+ def test_evidence_memory():
32
+ """Test Evidence Memory module."""
33
+ print("\n=== Test: Evidence Memory ===")
34
+
35
+ config = EvidenceMemoryConfig(
36
+ hidden_dim=256,
37
+ num_evidence_tokens=16,
38
+ num_cross_attn_layers=2,
39
+ num_heads=4,
40
+ dropout=0.1,
41
+ )
42
+
43
+ visual_dim = 512
44
+ text_dim = 384
45
+ B = 4
46
+ N_v = 49 # e.g., 7x7 patches
47
+ N_t = 32 # text tokens
48
+
49
+ model = EvidenceMemory(config, visual_dim=visual_dim, text_dim=text_dim)
50
+
51
+ # Synthetic inputs
52
+ visual_tokens = torch.randn(B, N_v, visual_dim)
53
+ text_tokens = torch.randn(B, N_t, text_dim)
54
+ text_mask = torch.ones(B, N_t) # All valid
55
+ text_mask[:, -5:] = 0 # Last 5 are padding
56
+
57
+ output = model(visual_tokens, text_tokens, text_mask)
58
+
59
+ evidence = output['evidence_tokens']
60
+ kv_tokens = output['kv_tokens']
61
+
62
+ print(f" Evidence tokens shape: {evidence.shape}") # [B, 16, 256]
63
+ print(f" KV tokens shape: {kv_tokens.shape}") # [B, N_v+N_t, 256]
64
+
65
+ assert evidence.shape == (B, config.num_evidence_tokens, config.hidden_dim)
66
+ assert kv_tokens.shape[0] == B
67
+ assert kv_tokens.shape[2] == config.hidden_dim
68
+
69
+ print(" ✓ Evidence Memory passed!")
70
+ return model
71
+
72
+
73
+ def test_latent_rollout():
74
+ """Test Latent Rollout module."""
75
+ print("\n=== Test: Latent Rollout ===")
76
+
77
+ config = LatentRolloutConfig(
78
+ hidden_dim=256,
79
+ num_state_tokens=8,
80
+ K=3,
81
+ num_predictor_layers=2,
82
+ num_heads=4,
83
+ ffn_dim=512,
84
+ dropout=0.1,
85
+ use_evidence_gate=True,
86
+ gate_type="sigmoid",
87
+ use_step_embedding=True,
88
+ )
89
+
90
+ B = 4
91
+ N_e = 16 # Evidence tokens
92
+
93
+ model = LatentRolloutModule(config)
94
+
95
+ evidence_tokens = torch.randn(B, N_e, config.hidden_dim)
96
+
97
+ output = model(evidence_tokens)
98
+
99
+ trajectory = output['trajectory']
100
+ z_final = output['z_final']
101
+ z_projected = output['z_projected']
102
+
103
+ print(f" Trajectory shape: {trajectory.shape}") # [B, K+1, N_s, D]
104
+ print(f" Z_final shape: {z_final.shape}") # [B, N_s, D]
105
+ print(f" Z_projected shape: {z_projected.shape}") # [B, K+1, N_s, D]
106
+
107
+ assert trajectory.shape == (B, config.K + 1, config.num_state_tokens, config.hidden_dim)
108
+ assert z_final.shape == (B, config.num_state_tokens, config.hidden_dim)
109
+ assert z_projected.shape == trajectory.shape
110
+
111
+ print(" ✓ Latent Rollout passed!")
112
+ return model
113
+
114
+
115
+ def test_target_encoder_and_jepa_loss():
116
+ """Test Target Encoder EMA and JEPA Loss."""
117
+ print("\n=== Test: Target Encoder + JEPA Loss ===")
118
+
119
+ D = 256
120
+ N_e = 16
121
+ N_s = 8
122
+ K = 3
123
+ B = 4
124
+
125
+ evidence_config = EvidenceMemoryConfig(
126
+ hidden_dim=D, num_evidence_tokens=N_e,
127
+ num_cross_attn_layers=2, num_heads=4,
128
+ )
129
+ rollout_config = LatentRolloutConfig(
130
+ hidden_dim=D, num_state_tokens=N_s, K=K,
131
+ num_predictor_layers=2, num_heads=4, ffn_dim=512,
132
+ )
133
+ jepa_config = JEPAObjectiveConfig(
134
+ ema_momentum_base=0.996, ema_momentum_end=1.0,
135
+ use_sigreg=True, sigreg_weight=0.1,
136
+ )
137
+
138
+ # Create online modules
139
+ visual_dim = 512
140
+ text_dim = 384
141
+ evidence_mem = EvidenceMemory(evidence_config, visual_dim, text_dim)
142
+ rollout = LatentRolloutModule(rollout_config)
143
+
144
+ # Create target encoder
145
+ target_enc = TargetEncoder(evidence_mem, rollout, jepa_config)
146
+
147
+ # Test EMA update
148
+ original_param = list(target_enc.target_rollout.parameters())[0].clone()
149
+
150
+ # Modify online params
151
+ with torch.no_grad():
152
+ for p in rollout.parameters():
153
+ p.add_(torch.randn_like(p) * 0.1)
154
+
155
+ target_enc.update_ema(evidence_mem, rollout, step=100, total_steps=1000)
156
+
157
+ updated_param = list(target_enc.target_rollout.parameters())[0]
158
+ assert not torch.allclose(original_param, updated_param), "EMA did not update!"
159
+ print(f" EMA momentum: {target_enc._current_momentum:.6f}")
160
+
161
+ # Test target forward
162
+ visual_tokens = torch.randn(B, 49, visual_dim)
163
+ text_tokens = torch.randn(B, 32, text_dim)
164
+ text_mask = torch.ones(B, 32)
165
+
166
+ target_output = target_enc(visual_tokens, text_tokens, text_mask)
167
+ target_traj = target_output['target_trajectory']
168
+ print(f" Target trajectory shape: {target_traj.shape}")
169
+ assert target_traj.shape == (B, K + 1, N_s, D)
170
+
171
+ # Test JEPA Loss
172
+ jepa_loss_fn = JEPALoss(jepa_config, D)
173
+
174
+ pred_traj = torch.randn(B, K + 1, N_s, D, requires_grad=True)
175
+ task_loss = torch.tensor(1.5)
176
+
177
+ loss_dict = jepa_loss_fn(pred_traj, target_traj, task_loss)
178
+
179
+ print(f" JEPA loss: {loss_dict['jepa_loss'].item():.4f}")
180
+ print(f" Task loss: {loss_dict['task_loss'].item():.4f}")
181
+ print(f" Reg loss: {loss_dict['reg_loss'].item():.4f}")
182
+ print(f" Total loss: {loss_dict['total_loss'].item():.4f}")
183
+
184
+ # Check gradients flow
185
+ loss_dict['total_loss'].backward()
186
+ assert pred_traj.grad is not None, "No gradients!"
187
+ print(f" Gradient norm: {pred_traj.grad.norm().item():.4f}")
188
+
189
+ print(" ✓ Target Encoder + JEPA Loss passed!")
190
+
191
+
192
+ def test_answer_heads():
193
+ """Test Discriminative and Generative heads."""
194
+ print("\n=== Test: Answer Heads ===")
195
+
196
+ D = 256
197
+ text_dim = 384
198
+ B = 4
199
+ N_s = 8
200
+ max_opts = 4
201
+ vocab_size = 1000
202
+
203
+ head_config = AnswerHeadConfig(
204
+ disc_hidden_dim=256,
205
+ disc_num_layers=2,
206
+ max_num_options=max_opts,
207
+ gen_hidden_dim=256,
208
+ gen_num_layers=2,
209
+ gen_num_heads=4,
210
+ gen_vocab_size=vocab_size,
211
+ gen_max_answer_length=32,
212
+ )
213
+
214
+ # Test Discriminative Head
215
+ disc_head = DiscriminativeHead(head_config, hidden_dim=D, text_dim=text_dim)
216
+
217
+ z_final = torch.randn(B, N_s, D)
218
+ option_embs = torch.randn(B, max_opts, text_dim)
219
+ option_mask = torch.tensor([
220
+ [True, True, True, True],
221
+ [True, True, True, False],
222
+ [True, True, False, False],
223
+ [True, True, True, True],
224
+ ])
225
+
226
+ disc_output = disc_head(z_final, option_embs, option_mask)
227
+
228
+ print(f" Disc logits shape: {disc_output['logits'].shape}") # [B, max_opts]
229
+ print(f" Disc probs shape: {disc_output['probs'].shape}")
230
+ print(f" Sample probs: {disc_output['probs'][0].tolist()}")
231
+
232
+ # Check masking
233
+ assert disc_output['logits'][2, 2] == float('-inf'), "Masked option should be -inf!"
234
+ assert disc_output['probs'][2, 2].item() < 1e-6, "Masked option should have ~0 prob!"
235
+
236
+ # Test Generative Head
237
+ gen_head = GenerativeHead(head_config, hidden_dim=D, vocab_size=vocab_size)
238
+
239
+ target_ids = torch.randint(0, vocab_size, (B, 16))
240
+
241
+ gen_output = gen_head(z_final, target_ids)
242
+
243
+ print(f" Gen logits shape: {gen_output['logits'].shape}") # [B, 16, vocab_size]
244
+ print(f" Gen loss: {gen_output['loss'].item():.4f}")
245
+
246
+ # Test generation
247
+ generated = gen_head.generate(z_final, start_token_id=1, max_length=10)
248
+ print(f" Generated shape: {generated.shape}") # [B, <=10]
249
+
250
+ print(" ✓ Answer Heads passed!")
251
+
252
+
253
+ def test_sigreg_and_vicreg():
254
+ """Test anti-collapse regularization losses."""
255
+ print("\n=== Test: SIGReg + VICReg ===")
256
+
257
+ D = 256
258
+ B = 32
259
+ N = 8
260
+
261
+ # SIGReg
262
+ sigreg = SIGRegLoss(D, num_projections=64)
263
+ z = torch.randn(B, N, D)
264
+ loss = sigreg(z)
265
+ print(f" SIGReg loss (random): {loss.item():.4f}")
266
+
267
+ # Test collapse detection
268
+ z_collapsed = torch.ones(B, N, D) # Collapsed representation
269
+ loss_collapsed = sigreg(z_collapsed)
270
+ print(f" SIGReg loss (collapsed): {loss_collapsed.item():.4f}")
271
+ assert loss_collapsed > loss, "SIGReg should penalize collapsed representations more!"
272
+
273
+ # VICReg
274
+ vicreg = VICRegLoss(var_weight=1.0, cov_weight=0.04)
275
+ z = torch.randn(B, N, D)
276
+ loss = vicreg(z)
277
+ print(f" VICReg loss (random): {loss.item():.4f}")
278
+
279
+ print(" ✓ SIGReg + VICReg passed!")
280
+
281
+
282
+ def test_parameter_counting():
283
+ """Count and verify parameter distribution."""
284
+ print("\n=== Test: Parameter Counting ===")
285
+
286
+ D = 256
287
+
288
+ evidence_config = EvidenceMemoryConfig(
289
+ hidden_dim=D, num_evidence_tokens=16,
290
+ num_cross_attn_layers=2, num_heads=4,
291
+ )
292
+ rollout_config = LatentRolloutConfig(
293
+ hidden_dim=D, num_state_tokens=8, K=3,
294
+ num_predictor_layers=3, num_heads=4, ffn_dim=512,
295
+ )
296
+
297
+ evidence = EvidenceMemory(evidence_config, visual_dim=512, text_dim=384)
298
+ rollout = LatentRolloutModule(rollout_config)
299
+
300
+ def count_params(module):
301
+ return sum(p.numel() for p in module.parameters())
302
+
303
+ def count_trainable(module):
304
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
305
+
306
+ print(f" Evidence Memory: {count_params(evidence):,} params")
307
+ print(f" Latent Rollout: {count_params(rollout):,} params")
308
+
309
+ # The rollout should be much smaller than the backbone (I-JEPA: narrow predictor)
310
+ print(f" Evidence trainable: {count_trainable(evidence):,}")
311
+ print(f" Rollout trainable: {count_trainable(rollout):,}")
312
+
313
+ print(" ✓ Parameter Counting passed!")
314
+
315
+
316
+ def test_trajectory_metrics():
317
+ """Test trajectory analysis utilities."""
318
+ print("\n=== Test: Trajectory Metrics ===")
319
+
320
+ from mr_jepa.utils.visualization import compute_trajectory_metrics, visualize_trajectory
321
+
322
+ B = 4
323
+ K = 3
324
+ N_s = 8
325
+ D = 256
326
+
327
+ # Create a trajectory that converges
328
+ trajectory = torch.randn(B, K + 1, N_s, D)
329
+ # Make each step closer to the previous (simulating convergence)
330
+ for k in range(1, K + 1):
331
+ trajectory[:, k] = trajectory[:, k-1] + torch.randn(B, N_s, D) * (0.5 ** k)
332
+
333
+ metrics = compute_trajectory_metrics(trajectory)
334
+
335
+ print(f" Step distances: {[f'{d:.4f}' for d in metrics['step_distances']]}")
336
+ print(f" Trajectory length: {metrics['trajectory_length']:.4f}")
337
+ print(f" Convergence rate: {metrics['convergence_rate']:.4f}")
338
+ print(f" State diversity: {[f'{d:.4f}' for d in metrics['state_diversity']]}")
339
+
340
+ # Test visualization
341
+ viz = visualize_trajectory(trajectory[0], method='pca')
342
+ print(f" PCA coords shape: {viz['coords'].shape}")
343
+ print(f" Step labels: {viz['step_labels']}")
344
+
345
+ assert metrics['convergence_rate'] < 1.0, "Convergence rate should be < 1 for converging trajectory"
346
+
347
+ print(" ✓ Trajectory Metrics passed!")
348
+
349
+
350
+ def test_evaluation_metrics():
351
+ """Test all evaluation metrics."""
352
+ print("\n=== Test: Evaluation Metrics ===")
353
+
354
+ from mr_jepa.evaluation.metrics import (
355
+ compute_accuracy, compute_anls, compute_vqa_accuracy,
356
+ compute_relaxed_accuracy, evaluate_benchmark,
357
+ )
358
+
359
+ # Accuracy
360
+ result = compute_accuracy([0, 1, 2, 0], [0, 1, 1, 0])
361
+ print(f" Accuracy: {result['accuracy']:.1f}%")
362
+ assert result['accuracy'] == 75.0
363
+
364
+ # ANLS
365
+ result = compute_anls(
366
+ ["hello world", "test", "abc"],
367
+ [["hello world", "hi world"], ["testing"], ["xyz"]],
368
+ )
369
+ print(f" ANLS: {result['anls']:.1f}%")
370
+
371
+ # VQA Accuracy
372
+ result = compute_vqa_accuracy(
373
+ ["cat", "dog"],
374
+ [["cat", "cat", "cat", "kitten", "cat", "cat", "feline", "cat", "cat", "cat"],
375
+ ["dog", "puppy", "dog", "canine", "dog", "dog", "dog", "dog", "dog", "dog"]],
376
+ )
377
+ print(f" VQA Accuracy: {result['vqa_accuracy']:.1f}%")
378
+
379
+ # Relaxed Accuracy
380
+ result = compute_relaxed_accuracy(
381
+ ["100", "52", "hello"],
382
+ ["100", "50", "hello"],
383
+ types=["human_test", "augmented_test", "human_test"],
384
+ )
385
+ print(f" Relaxed Accuracy: {result['relaxed_accuracy']:.1f}%")
386
+
387
+ print(" ✓ Evaluation Metrics passed!")
388
+
389
+
390
+ def test_end_to_end_forward():
391
+ """Test a simplified end-to-end forward pass (without pretrained backbones)."""
392
+ print("\n=== Test: End-to-End Forward Pass (Synthetic) ===")
393
+
394
+ D = 256
395
+ B = 2
396
+ N_v = 49
397
+ N_t = 32
398
+ N_e = 16
399
+ N_s = 8
400
+ K = 3
401
+ max_opts = 4
402
+ vocab_size = 100
403
+ visual_dim = 512
404
+ text_dim = 384
405
+
406
+ # Build components manually (without pretrained models)
407
+ evidence_config = EvidenceMemoryConfig(
408
+ hidden_dim=D, num_evidence_tokens=N_e,
409
+ num_cross_attn_layers=2, num_heads=4,
410
+ )
411
+ rollout_config = LatentRolloutConfig(
412
+ hidden_dim=D, num_state_tokens=N_s, K=K,
413
+ num_predictor_layers=2, num_heads=4, ffn_dim=512,
414
+ )
415
+ jepa_config = JEPAObjectiveConfig(use_sigreg=True, sigreg_weight=0.1)
416
+ head_config = AnswerHeadConfig(
417
+ disc_hidden_dim=D, gen_hidden_dim=D, gen_num_layers=2,
418
+ gen_num_heads=4, gen_vocab_size=vocab_size, gen_max_answer_length=16,
419
+ )
420
+
421
+ evidence_mem = EvidenceMemory(evidence_config, visual_dim, text_dim)
422
+ rollout = LatentRolloutModule(rollout_config)
423
+ target_enc = TargetEncoder(evidence_mem, rollout, jepa_config)
424
+ disc_head = DiscriminativeHead(head_config, D, text_dim)
425
+ gen_head = GenerativeHead(head_config, D, vocab_size)
426
+ jepa_loss_fn = JEPALoss(jepa_config, D)
427
+
428
+ # Synthetic inputs
429
+ visual_tokens = torch.randn(B, N_v, visual_dim)
430
+ text_tokens = torch.randn(B, N_t, text_dim)
431
+ text_mask = torch.ones(B, N_t)
432
+ option_embs = torch.randn(B, max_opts, text_dim)
433
+ option_mask = torch.ones(B, max_opts, dtype=torch.bool)
434
+ answer_labels = torch.tensor([1, 3])
435
+ gen_targets = torch.randint(0, vocab_size, (B, 16))
436
+
437
+ # Forward pass
438
+ evidence_output = evidence_mem(visual_tokens, text_tokens, text_mask)
439
+ evidence = evidence_output['evidence_tokens']
440
+
441
+ rollout_output = rollout(evidence)
442
+ trajectory = rollout_output['trajectory']
443
+ z_final = rollout_output['z_final']
444
+ z_projected = rollout_output['z_projected']
445
+
446
+ # Target encoder (no grad)
447
+ target_output = target_enc(visual_tokens, text_tokens, text_mask)
448
+ target_traj = target_output['target_trajectory']
449
+
450
+ # Answer heads
451
+ disc_output = disc_head(z_final, option_embs, option_mask)
452
+ task_loss = nn.functional.cross_entropy(disc_output['logits'], answer_labels)
453
+
454
+ gen_output = gen_head(z_final, gen_targets, evidence)
455
+
456
+ # JEPA loss
457
+ loss_dict = jepa_loss_fn(z_projected, target_traj, task_loss, gen_output['loss'])
458
+
459
+ total_loss = loss_dict['total_loss']
460
+ total_loss.backward()
461
+
462
+ print(f" Evidence shape: {evidence.shape}")
463
+ print(f" Trajectory shape: {trajectory.shape}")
464
+ print(f" Z_final shape: {z_final.shape}")
465
+ print(f" Disc logits: {disc_output['logits'].shape}")
466
+ print(f" Gen logits: {gen_output['logits'].shape}")
467
+ print(f" Total loss: {total_loss.item():.4f}")
468
+ print(f" JEPA loss: {loss_dict['jepa_loss'].item():.4f}")
469
+ print(f" Task loss: {loss_dict['task_loss'].item():.4f}")
470
+ print(f" Gen loss: {loss_dict['gen_loss'].item():.4f}")
471
+ print(f" Reg loss: {loss_dict['reg_loss'].item():.4f}")
472
+
473
+ # EMA update
474
+ target_enc.update_ema(evidence_mem, rollout, step=1, total_steps=100)
475
+ print(f" EMA momentum: {target_enc._current_momentum:.6f}")
476
+
477
+ # Check all gradients flow
478
+ has_grad = sum(1 for p in evidence_mem.parameters() if p.grad is not None)
479
+ total_p = sum(1 for p in evidence_mem.parameters())
480
+ print(f" Evidence memory: {has_grad}/{total_p} params have gradients")
481
+
482
+ has_grad = sum(1 for p in rollout.parameters() if p.grad is not None)
483
+ total_p = sum(1 for p in rollout.parameters())
484
+ print(f" Rollout: {has_grad}/{total_p} params have gradients")
485
+
486
+ print(" ✓ End-to-End Forward Pass passed!")
487
+
488
+
489
+ if __name__ == "__main__":
490
+ print("=" * 60)
491
+ print("MR-JEPA Architecture Validation")
492
+ print("=" * 60)
493
+
494
+ test_evidence_memory()
495
+ test_latent_rollout()
496
+ test_target_encoder_and_jepa_loss()
497
+ test_answer_heads()
498
+ test_sigreg_and_vicreg()
499
+ test_parameter_counting()
500
+ test_trajectory_metrics()
501
+ test_evaluation_metrics()
502
+ test_end_to_end_forward()
503
+
504
+ print("\n" + "=" * 60)
505
+ print("ALL TESTS PASSED ✓")
506
+ print("=" * 60)