Initial MR-JEPA codebase: architecture, training, evaluation, and tests
Browse files- README.md +44 -0
- mr_jepa/ARCHITECTURE.md +303 -0
- mr_jepa/__init__.py +9 -0
- mr_jepa/configs/__init__.py +25 -0
- mr_jepa/configs/__pycache__/__init__.cpython-312.pyc +0 -0
- mr_jepa/configs/__pycache__/model_config.cpython-312.pyc +0 -0
- mr_jepa/configs/model_config.py +306 -0
- mr_jepa/data/__init__.py +9 -0
- mr_jepa/data/data_utils.py +273 -0
- mr_jepa/data/unified_dataset.py +380 -0
- mr_jepa/evaluation/__init__.py +15 -0
- mr_jepa/evaluation/__pycache__/__init__.cpython-312.pyc +0 -0
- mr_jepa/evaluation/__pycache__/metrics.cpython-312.pyc +0 -0
- mr_jepa/evaluation/metrics.py +251 -0
- mr_jepa/models/__init__.py +17 -0
- mr_jepa/models/__pycache__/answer_heads.cpython-312.pyc +0 -0
- mr_jepa/models/__pycache__/evidence_memory.cpython-312.pyc +0 -0
- mr_jepa/models/__pycache__/latent_rollout.cpython-312.pyc +0 -0
- mr_jepa/models/__pycache__/target_encoder.cpython-312.pyc +0 -0
- mr_jepa/models/answer_heads.py +369 -0
- mr_jepa/models/backbones.py +180 -0
- mr_jepa/models/evidence_memory.py +299 -0
- mr_jepa/models/latent_rollout.py +324 -0
- mr_jepa/models/mr_jepa.py +350 -0
- mr_jepa/models/target_encoder.py +354 -0
- mr_jepa/training/__init__.py +4 -0
- mr_jepa/training/phase_scheduler.py +107 -0
- mr_jepa/training/trainer.py +397 -0
- mr_jepa/utils/__init__.py +8 -0
- mr_jepa/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- mr_jepa/utils/__pycache__/ablation.cpython-312.pyc +0 -0
- mr_jepa/utils/__pycache__/visualization.cpython-312.pyc +0 -0
- mr_jepa/utils/ablation.py +182 -0
- mr_jepa/utils/visualization.py +137 -0
- test_architecture.py +506 -0
README.md
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: ml-intern sandbox
|
| 3 |
+
emoji: 🌍
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture
|
| 12 |
+
|
| 13 |
+
> A world model for multimodal reasoning that refines a latent belief state over K=3 steps using JEPA-style prediction, evidence gating, and dense visual backbones.
|
| 14 |
+
|
| 15 |
+
## Key Idea
|
| 16 |
+
|
| 17 |
+
Traditional multimodal models produce answers in a single forward pass. MR-JEPA instead models **the evolution of a belief state** as the system reasons about a question:
|
| 18 |
+
|
| 19 |
+
```
|
| 20 |
+
z₀ (initial evidence) → z₁ (first refinement) → z₂ (deeper reasoning) → z₃ (answer)
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
This trajectory is supervised by a **JEPA objective**: a target encoder (EMA) generates target latent states, and the online predictor learns to predict them. The JEPA loss encourages the model to learn **meaningful intermediate reasoning states** — not just the final answer.
|
| 24 |
+
|
| 25 |
+
## Architecture
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
┌─────────────┐ ┌──────────────┐ ┌─────────────────┐ ┌──────────┐
|
| 29 |
+
│ DINOv2/v3 │────▶│ Evidence │────▶│ Latent Rollout │────▶│ Answer │
|
| 30 |
+
│ (frozen) │ │ Memory │ │ z₀→z₁→z₂→z₃ │ │ Heads │
|
| 31 |
+
└─────────────┘ │ (Perceiver) │ │ (shared block) │ └──────────┘
|
| 32 |
+
└──────┬───────┘ └────────┬────────┘
|
| 33 |
+
┌─────────────┐ │ │
|
| 34 |
+
│ DeBERTa-v3 │───────────┘ ┌───────┴────────┐
|
| 35 |
+
│ (frozen) │ │ Target Encoder │
|
| 36 |
+
└─────────────┘ │ (EMA copy) │
|
| 37 |
+
└────────────────┘
|
| 38 |
+
┌─────────────┐ │
|
| 39 |
+
│ OCR/Layout/ │──────────┘ JEPA Loss: L₂ + SIGReg
|
| 40 |
+
│ Chart/SAM │ (Phase 3)
|
| 41 |
+
└─────────────┘
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
See `mr_jepa/ARCHITECTURE.md` for the complete specification.
|
mr_jepa/ARCHITECTURE.md
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture
|
| 2 |
+
|
| 3 |
+
## Detailed Architecture Specification
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## 1. Overview
|
| 8 |
+
|
| 9 |
+
MR-JEPA is a **world model for static multimodal reasoning**. Unlike traditional world models that predict physical dynamics (video, robotics), MR-JEPA models the evolution of a **belief state** as the system reasons about a visual question.
|
| 10 |
+
|
| 11 |
+
The core insight: solving a multimodal question (e.g., "What is the GDP growth shown in this chart?") requires iterative evidence accumulation — first extracting relevant visual features, then integrating textual context, then refining understanding through multiple reasoning steps. MR-JEPA formalizes this process as a **latent trajectory** supervised by a JEPA objective.
|
| 12 |
+
|
| 13 |
+
```
|
| 14 |
+
┌──────────────────────────────────────────┐
|
| 15 |
+
│ MR-JEPA Architecture │
|
| 16 |
+
└──────────────────────────────────────────┘
|
| 17 |
+
|
| 18 |
+
┌─────────┐ ┌─────────────┐ ┌──────────────┐
|
| 19 |
+
│ DINOv2/v3│────▶│ Evidence │────▶│ Latent │──▶ Answer
|
| 20 |
+
│ Visual │ │ Memory │ │ Rollout │ Heads
|
| 21 |
+
│ Backbone │ │ (Perceiver)│ │ K=3 steps │
|
| 22 |
+
└─────────┘ └──────┬──────┘ └──────┬───────┘
|
| 23 |
+
│ │
|
| 24 |
+
┌─────────┐ │ ┌──────┴───────┐
|
| 25 |
+
│ DeBERTa │────────────┘ │ Target │
|
| 26 |
+
│ Text │ │ Encoder │
|
| 27 |
+
│ Encoder │ │ (EMA) │
|
| 28 |
+
└─────────┘ └──────────────┘
|
| 29 |
+
│
|
| 30 |
+
┌─────────┐ JEPA Loss:
|
| 31 |
+
│Optional:│ L₂ prediction
|
| 32 |
+
│OCR,SAM, │──────────┘ + SIGReg
|
| 33 |
+
│Layout │
|
| 34 |
+
└─────────┘
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
## 2. Component Details
|
| 40 |
+
|
| 41 |
+
### 2.1 Visual Backbone
|
| 42 |
+
|
| 43 |
+
**Primary choice: DINOv2-L/14** (`facebook/dinov2-large`)
|
| 44 |
+
- Architecture: ViT-L/14 with 300M parameters
|
| 45 |
+
- Output: 1024-dim patch tokens, 518×518 input → 1369 patches
|
| 46 |
+
- 4 register tokens + CLS token (skipped, only patch tokens used)
|
| 47 |
+
- Pre-trained with self-supervised DINO objective on LVD-142M
|
| 48 |
+
- **Why DINOv2 over CLIP/SigLIP**: Dense patch features are critical for evidence extraction. CLIP-style models optimize for global image-text alignment but lose local spatial information. DINOv2 produces patch-level features that capture fine-grained visual details needed for chart reading, document OCR, and diagram understanding.
|
| 49 |
+
|
| 50 |
+
**Alternative: DINOv3-L/16** (`timm/vit_large_patch16_dinov3.lvd1689m`)
|
| 51 |
+
- Architecture: ViT-L/16 with RoPE positional encoding
|
| 52 |
+
- Advantages: Better resolution generalization, Gram anchoring prevents feature degradation
|
| 53 |
+
- Trained on LVD-1689M (10× more data)
|
| 54 |
+
|
| 55 |
+
**Purist branch: DINOv2-B/14** (`facebook/dinov2-base`)
|
| 56 |
+
- 768-dim output, 86M params
|
| 57 |
+
- Compensated by deeper rollout (K=5)
|
| 58 |
+
|
| 59 |
+
### 2.2 Text Encoder
|
| 60 |
+
|
| 61 |
+
**DeBERTa-v3-Large** (`microsoft/deberta-v3-large`)
|
| 62 |
+
- 1024-dim hidden, 24 layers, 304M params
|
| 63 |
+
- Processes: question text + answer options (concatenated with separators)
|
| 64 |
+
- Output: token-level embeddings for cross-attention + CLS for option scoring
|
| 65 |
+
|
| 66 |
+
**Why DeBERTa over BERT/RoBERTa**: DeBERTa-v3's disentangled attention mechanism explicitly models content vs. position, giving stronger performance on complex NLU tasks. Its relative position bias is particularly useful for understanding mathematical notation and structured question formats.
|
| 67 |
+
|
| 68 |
+
### 2.3 Evidence Memory
|
| 69 |
+
|
| 70 |
+
**Architecture: Perceiver-style cross-attention**
|
| 71 |
+
|
| 72 |
+
```python
|
| 73 |
+
N_evidence = 64 # Learnable query tokens
|
| 74 |
+
D = 768 # Hidden dimension
|
| 75 |
+
L = 4 # Cross-attention layers
|
| 76 |
+
|
| 77 |
+
# Each layer:
|
| 78 |
+
# 1. Self-attention among evidence queries
|
| 79 |
+
# 2. Cross-attention: queries attend to [visual_patches || text_tokens || enriched_tokens]
|
| 80 |
+
# 3. FFN with residual
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
**Input tokens (concatenated KV sequence):**
|
| 84 |
+
| Source | Tokens | Dimension | Phase |
|
| 85 |
+
|--------|--------|-----------|-------|
|
| 86 |
+
| DINOv2-L patches | 1369 | 1024→768 (projected) | 1+ |
|
| 87 |
+
| DeBERTa text | 256 | 1024→768 (projected) | 1+ |
|
| 88 |
+
| OCR tokens | 128 | 768 | 3 |
|
| 89 |
+
| Layout tokens | 64 | 256→768 (projected) | 3 |
|
| 90 |
+
| Chart tokens | 64 | 512→768 (projected) | 3 |
|
| 91 |
+
| SAM2 segments | 32 | 256→768 (projected) | 3 (optional) |
|
| 92 |
+
|
| 93 |
+
**Modality type embeddings** (learned, added to distinguish token sources).
|
| 94 |
+
|
| 95 |
+
**Output**: 64 evidence tokens × 768 dim = dense multimodal representation.
|
| 96 |
+
|
| 97 |
+
### 2.4 Latent Rollout (JEPA Core)
|
| 98 |
+
|
| 99 |
+
The reasoning engine. Refines a belief state over K steps:
|
| 100 |
+
|
| 101 |
+
```
|
| 102 |
+
z₀ = StateInit + Proj(AvgPool(evidence)) # Initial state from evidence
|
| 103 |
+
z₁ = PredictorBlock(z₀, evidence) + step_emb[1]
|
| 104 |
+
z₂ = PredictorBlock(z₁, evidence) + step_emb[2]
|
| 105 |
+
z₃ = PredictorBlock(z₂, evidence) + step_emb[3] # Final state → answer
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
**State representation**: 32 learnable tokens × 768 dim
|
| 109 |
+
|
| 110 |
+
**Shared Predictor Block** (weight-tied across K steps):
|
| 111 |
+
```
|
| 112 |
+
For each step k:
|
| 113 |
+
1. Self-attention among 32 state tokens
|
| 114 |
+
2. Evidence-gated cross-attention to 64 evidence tokens
|
| 115 |
+
3. FFN (768 → 3072 → 768)
|
| 116 |
+
|
| 117 |
+
PredictorBlock = [SelfAttn → EvidenceGate(CrossAttn) → FFN] × 6 layers
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
**Evidence Gate** (sigmoid):
|
| 121 |
+
```python
|
| 122 |
+
gate = sigmoid(W_g · [z_k || cross_attn_output]) # Per-dimension gating
|
| 123 |
+
gated_evidence = gate * cross_attn_output
|
| 124 |
+
z_k = z_{k-1} + gated_evidence # Residual
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
The gate learns to control evidence flow per step:
|
| 128 |
+
- Early steps: high gate → absorb more visual/textual evidence
|
| 129 |
+
- Later steps: lower gate → rely on accumulated reasoning
|
| 130 |
+
|
| 131 |
+
**Step embeddings**: Learned per-step bias vectors differentiate rollout positions.
|
| 132 |
+
|
| 133 |
+
### 2.5 Target Encoder (EMA)
|
| 134 |
+
|
| 135 |
+
**Following I-JEPA** (Assran et al., 2023):
|
| 136 |
+
|
| 137 |
+
The target encoder is an EMA copy of [Evidence Memory + Latent Rollout]:
|
| 138 |
+
```
|
| 139 |
+
θ̄_t+1 = m(t) · θ̄_t + (1 - m(t)) · θ_t
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
**Momentum schedule** (cosine from 0.996 → 1.0):
|
| 143 |
+
```python
|
| 144 |
+
m(t) = 1 - (1 - 0.996) * (1 + cos(π · t/T)) / 2
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
The target encoder generates target trajectory z*₀, z*₁, z*₂, z*₃.
|
| 148 |
+
The online predictor must predict these targets.
|
| 149 |
+
|
| 150 |
+
**Critical**: Target encoder receives stop-gradient inputs and produces stop-gradient outputs.
|
| 151 |
+
|
| 152 |
+
### 2.6 JEPA Objective
|
| 153 |
+
|
| 154 |
+
**Prediction loss** (from I-JEPA):
|
| 155 |
+
```
|
| 156 |
+
L_JEPA = (1/K) Σ_{k=1}^{K} ||z_pred_k - sg(z*_k)||²
|
| 157 |
+
```
|
| 158 |
+
Only steps k=1..K are supervised (z₀ is deterministic from evidence).
|
| 159 |
+
|
| 160 |
+
**Anti-collapse regularization** (from LeWorldModel):
|
| 161 |
+
```
|
| 162 |
+
L_SIGReg = (1/M) Σ_{m=1}^{M} T(Z · u_m)
|
| 163 |
+
```
|
| 164 |
+
Where T is the Epps-Pulley normality test statistic, u_m are random unit vectors.
|
| 165 |
+
This encourages latent embeddings to remain Gaussian-distributed, preventing collapse.
|
| 166 |
+
|
| 167 |
+
**Total loss**:
|
| 168 |
+
```
|
| 169 |
+
L_total = L_JEPA + L_task + λ · L_SIGReg + α · L_gen
|
| 170 |
+
|
| 171 |
+
Where:
|
| 172 |
+
L_task = CrossEntropy(disc_head(z_K), answer_label) # MC scoring
|
| 173 |
+
L_gen = CE(gen_head(z_K), target_answer_tokens) # Short answer
|
| 174 |
+
λ = 0.1 (SIGReg weight)
|
| 175 |
+
α = 0.1 (generative weight)
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
### 2.7 Answer Heads
|
| 179 |
+
|
| 180 |
+
**Discriminative Head (Primary)** — for MC questions:
|
| 181 |
+
```
|
| 182 |
+
z_pooled = AttentionPool(z_K) # 32 tokens → 1 vector
|
| 183 |
+
score_i = MLP([z_pooled, opt_i, z_pooled ⊙ opt_i]) # Per-option score
|
| 184 |
+
probs = softmax(scores, mask=valid_options)
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
**Generative Head (Secondary)** — for open-ended questions:
|
| 188 |
+
```
|
| 189 |
+
Small transformer decoder (4 layers):
|
| 190 |
+
- Causal self-attention
|
| 191 |
+
- Cross-attention to z_K (latent state)
|
| 192 |
+
- Cross-attention to evidence memory (evidence-constrained)
|
| 193 |
+
- FFN
|
| 194 |
+
|
| 195 |
+
Max 64 tokens output. Weight-tied embedding + LM head.
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
---
|
| 199 |
+
|
| 200 |
+
## 3. Training Protocol
|
| 201 |
+
|
| 202 |
+
### Phase 1: Reasoning Core (20 epochs)
|
| 203 |
+
|
| 204 |
+
| Component | Status | LR |
|
| 205 |
+
|-----------|--------|-----|
|
| 206 |
+
| DINOv2-L | **Frozen** | — |
|
| 207 |
+
| DeBERTa | **Frozen** | — |
|
| 208 |
+
| Evidence Memory | Training | 3e-4 |
|
| 209 |
+
| Latent Rollout | Training | 3e-4 |
|
| 210 |
+
| Answer Heads | Training | 3e-4 |
|
| 211 |
+
| Target Encoder | EMA update | — |
|
| 212 |
+
|
| 213 |
+
**Data**: ScienceQA train (12.7K) + any available train splits
|
| 214 |
+
**Objective**: Full JEPA + task + SIGReg
|
| 215 |
+
**Batch size**: 32 × 4 accum = 128 effective
|
| 216 |
+
|
| 217 |
+
### Phase 2: Perception Fine-tuning (10 epochs)
|
| 218 |
+
|
| 219 |
+
| Component | Status | LR |
|
| 220 |
+
|-----------|--------|-----|
|
| 221 |
+
| DINOv2-L (last 6 layers) | **Training** | 1e-5 |
|
| 222 |
+
| DeBERTa (last 4 layers) | **Training** | 1e-5 |
|
| 223 |
+
| Evidence Memory | Training | 1e-4 |
|
| 224 |
+
| Latent Rollout | Training | 1e-4 |
|
| 225 |
+
| Answer Heads | Training | 1e-4 |
|
| 226 |
+
|
| 227 |
+
### Phase 3: Enriched Evidence (10 epochs)
|
| 228 |
+
|
| 229 |
+
| Component | Status | LR |
|
| 230 |
+
|-----------|--------|-----|
|
| 231 |
+
| All above | Training | 5e-5 |
|
| 232 |
+
| OCR tokens | **Enabled** | 5e-5 |
|
| 233 |
+
| Layout tokens | **Enabled** | 5e-5 |
|
| 234 |
+
| Chart tokens | **Enabled** | 5e-5 |
|
| 235 |
+
|
| 236 |
+
**Focus benchmarks**: DocVQA, TextVQA, ChartQA
|
| 237 |
+
|
| 238 |
+
---
|
| 239 |
+
|
| 240 |
+
## 4. Ablation Experiments
|
| 241 |
+
|
| 242 |
+
### Key ablations for the paper:
|
| 243 |
+
|
| 244 |
+
| Experiment | Modification | Expected finding |
|
| 245 |
+
|------------|-------------|-----------------|
|
| 246 |
+
| **Full MR-JEPA** | Baseline | Best overall |
|
| 247 |
+
| **No JEPA** | Remove L_JEPA, train with task loss only | Drops on reasoning-heavy benchmarks |
|
| 248 |
+
| **No Rollout** | K=0, use z₀ directly | Significant drop (proves rollout value) |
|
| 249 |
+
| **No Evidence Gate** | Remove gating | Slight drop (gate helps focus) |
|
| 250 |
+
| **K=1** | Shallow rollout | Worse than K=3 |
|
| 251 |
+
| **K=5** | Deeper rollout | Diminishing returns |
|
| 252 |
+
| **No SIGReg** | Remove anti-collapse | Training instability |
|
| 253 |
+
| **Purist branch** | DINOv2-B, no enriched evidence | Lower absolute scores, but validates JEPA contribution |
|
| 254 |
+
|
| 255 |
+
### Cross-benchmark analysis:
|
| 256 |
+
- JEPA contribution should be highest on **reasoning** benchmarks (MathVista, MMMU, ScienceQA)
|
| 257 |
+
- Evidence gate contribution should be highest on **evidence-rich** benchmarks (DocVQA, ChartQA)
|
| 258 |
+
- Enriched evidence (Phase 3) should matter most for **document** benchmarks
|
| 259 |
+
|
| 260 |
+
---
|
| 261 |
+
|
| 262 |
+
## 5. Parameter Budget
|
| 263 |
+
|
| 264 |
+
| Component | Parameters | Trainable (Phase 1) |
|
| 265 |
+
|-----------|-----------|---------------------|
|
| 266 |
+
| DINOv2-L | 300M | 0 |
|
| 267 |
+
| DeBERTa-v3-L | 304M | 0 |
|
| 268 |
+
| Evidence Memory | ~3M | 3M |
|
| 269 |
+
| Latent Rollout | ~3M | 3M |
|
| 270 |
+
| Disc Head | ~2M | 2M |
|
| 271 |
+
| Gen Head | ~25M | 25M |
|
| 272 |
+
| **Total** | **~637M** | **~33M** |
|
| 273 |
+
|
| 274 |
+
Phase 1 trains only ~5% of total parameters. The model is computationally efficient — the JEPA reasoning core is lightweight compared to the frozen perception backbones.
|
| 275 |
+
|
| 276 |
+
---
|
| 277 |
+
|
| 278 |
+
## 6. Benchmark Format Reference
|
| 279 |
+
|
| 280 |
+
| Benchmark | Type | Answer | Metric | Eval Split |
|
| 281 |
+
|-----------|------|--------|--------|------------|
|
| 282 |
+
| MMMU | MC (up to 7 images) | Letter A-D | Accuracy | validation (900) |
|
| 283 |
+
| MathVista | Mixed MC/Open | Letter or value | Accuracy | testmini (1000) |
|
| 284 |
+
| ScienceQA | MC (nullable image) | 0-indexed int | Accuracy | test (4241) |
|
| 285 |
+
| AI2D | MC (diagrams) | 0-indexed str | Accuracy | test (3088) |
|
| 286 |
+
| MMBench | MC (A/B/C/D cols) | Letter | CircularEval Acc | dev (4329) |
|
| 287 |
+
| MMStar | MC (embedded options) | Letter | Accuracy | val (1500) |
|
| 288 |
+
| DocVQA | Open (documents) | List[str] | ANLS | validation (5349) |
|
| 289 |
+
| TextVQA | Open (scene text) | 10 annotations | VQA Accuracy | validation (5000) |
|
| 290 |
+
| ChartQA | Open (charts) | str/number | Relaxed Accuracy | test (2500) |
|
| 291 |
+
|
| 292 |
+
---
|
| 293 |
+
|
| 294 |
+
## 7. Key References
|
| 295 |
+
|
| 296 |
+
1. **I-JEPA** (Assran et al., 2023) — arxiv:2301.08243: JEPA architecture, EMA target encoder, L2 prediction loss, narrow predictor
|
| 297 |
+
2. **V-JEPA** (Bardes et al., 2024) — arxiv:2412.10925: Temporal extension, multi-step prediction in latent space
|
| 298 |
+
3. **LeWorldModel** (Maes et al., 2025) — arxiv:2603.19312: SIGReg anti-collapse, end-to-end JEPA from pixels, 2474 GitHub stars
|
| 299 |
+
4. **Coconut** (Yu et al., 2024) — arxiv:2412.06769: Chain of Continuous Thought, latent reasoning paradigm
|
| 300 |
+
5. **SoftCoT++** (Xu et al., 2025) — arxiv:2505.11484: Soft chain-of-thought with perturbation and contrastive learning
|
| 301 |
+
6. **DINOv2** (Oquab et al., 2023) — arxiv:2304.07193: Dense SSL visual backbone
|
| 302 |
+
7. **DINOv3** (Meta, 2025) — arxiv:2508.10104: Improved SSL with RoPE, Gram anchoring
|
| 303 |
+
8. **SigLIP2** (Google, 2025) — arxiv:2502.14786: CLIP-style with DINO features + captioning
|
mr_jepa/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture
|
| 3 |
+
|
| 4 |
+
A world model for multimodal reasoning that refines a latent belief state
|
| 5 |
+
over K steps using JEPA-style prediction, evidence gating, and dense
|
| 6 |
+
visual backbones.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
__version__ = "0.1.0"
|
mr_jepa/configs/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .model_config import (
|
| 2 |
+
MRJEPAConfig,
|
| 3 |
+
VisualBackboneConfig,
|
| 4 |
+
TextEncoderConfig,
|
| 5 |
+
EvidenceMemoryConfig,
|
| 6 |
+
LatentRolloutConfig,
|
| 7 |
+
JEPAObjectiveConfig,
|
| 8 |
+
AnswerHeadConfig,
|
| 9 |
+
TrainingPhaseConfig,
|
| 10 |
+
get_hybrid_config,
|
| 11 |
+
get_purist_config,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"MRJEPAConfig",
|
| 16 |
+
"VisualBackboneConfig",
|
| 17 |
+
"TextEncoderConfig",
|
| 18 |
+
"EvidenceMemoryConfig",
|
| 19 |
+
"LatentRolloutConfig",
|
| 20 |
+
"JEPAObjectiveConfig",
|
| 21 |
+
"AnswerHeadConfig",
|
| 22 |
+
"TrainingPhaseConfig",
|
| 23 |
+
"get_hybrid_config",
|
| 24 |
+
"get_purist_config",
|
| 25 |
+
]
|
mr_jepa/configs/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (480 Bytes). View file
|
|
|
mr_jepa/configs/__pycache__/model_config.cpython-312.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
mr_jepa/configs/model_config.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MR-JEPA Model Configuration
|
| 3 |
+
|
| 4 |
+
Defines all hyperparameters for the model architecture, training phases,
|
| 5 |
+
and JEPA objectives. Values are grounded in the literature:
|
| 6 |
+
|
| 7 |
+
- I-JEPA (Assran et al., 2023): EMA schedule, L2 prediction loss
|
| 8 |
+
- LeWorldModel (Maes et al., 2025): SIGReg anti-collapse, end-to-end JEPA
|
| 9 |
+
- Coconut (Yu et al., 2024): Latent reasoning rollout paradigm
|
| 10 |
+
- DINOv2/v3 (Oquab et al., 2023 / Meta 2025): Visual backbone config
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from typing import Optional, Literal
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class VisualBackboneConfig:
|
| 20 |
+
"""Configuration for the visual backbone encoder."""
|
| 21 |
+
# Backbone selection
|
| 22 |
+
backbone_type: Literal["dinov2", "dinov3", "siglip2"] = "dinov2"
|
| 23 |
+
model_name: str = "facebook/dinov2-large" # 1024-dim, 300M params
|
| 24 |
+
|
| 25 |
+
# DINOv2-L: hidden_size=1024, patch=14, 518px → 1369 patches + CLS + 4 reg = 1374 tokens
|
| 26 |
+
# DINOv3-L: hidden_size=1024, patch=16, RoPE, better dense features
|
| 27 |
+
# SigLIP2-So400m: hidden_size=1152, patch=14, 384px → 729 patches
|
| 28 |
+
|
| 29 |
+
hidden_size: int = 1024 # DINOv2-L / DINOv3-L output dim
|
| 30 |
+
image_size: int = 518 # DINOv2 default; 384 for SigLIP2
|
| 31 |
+
patch_size: int = 14 # 14 for DINOv2/SigLIP2, 16 for DINOv3
|
| 32 |
+
num_register_tokens: int = 4 # DINOv2/v3 register tokens
|
| 33 |
+
|
| 34 |
+
# Freezing control (Phase 1: fully frozen, Phase 2: unfreeze last N layers)
|
| 35 |
+
freeze: bool = True
|
| 36 |
+
unfreeze_last_n_layers: int = 0 # Phase 2: set to 4-6
|
| 37 |
+
|
| 38 |
+
# Optional: use only last N layers' features (multi-scale)
|
| 39 |
+
use_multi_scale: bool = False
|
| 40 |
+
multi_scale_layers: list = field(default_factory=lambda: [-1]) # last layer only
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
|
| 44 |
+
class TextEncoderConfig:
|
| 45 |
+
"""Configuration for the text encoder."""
|
| 46 |
+
model_name: str = "microsoft/deberta-v3-large" # 1024-dim, strong NLU
|
| 47 |
+
hidden_size: int = 1024
|
| 48 |
+
max_text_length: int = 256 # questions + options
|
| 49 |
+
freeze: bool = True
|
| 50 |
+
unfreeze_last_n_layers: int = 0
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
|
| 54 |
+
class EvidenceMemoryConfig:
|
| 55 |
+
"""
|
| 56 |
+
Configuration for the unified Evidence Memory.
|
| 57 |
+
|
| 58 |
+
The evidence memory is a set of tokens that fuse visual and textual information.
|
| 59 |
+
It uses cross-attention to attend to both visual patch tokens and text tokens,
|
| 60 |
+
producing a unified multimodal representation.
|
| 61 |
+
"""
|
| 62 |
+
hidden_dim: int = 768 # Internal dim of the evidence memory
|
| 63 |
+
num_evidence_tokens: int = 64 # Learnable evidence query tokens
|
| 64 |
+
num_cross_attn_layers: int = 4 # Cross-attention layers for fusion
|
| 65 |
+
num_heads: int = 12
|
| 66 |
+
dropout: float = 0.1
|
| 67 |
+
|
| 68 |
+
# Projections from backbone dims to evidence dim
|
| 69 |
+
visual_proj_dim: int = 768 # Project visual tokens to this dim
|
| 70 |
+
text_proj_dim: int = 768 # Project text tokens to this dim
|
| 71 |
+
|
| 72 |
+
# Optional enriched evidence (Phase 3)
|
| 73 |
+
use_ocr_tokens: bool = False
|
| 74 |
+
use_layout_tokens: bool = False
|
| 75 |
+
use_chart_tokens: bool = False
|
| 76 |
+
use_sam_tokens: bool = False
|
| 77 |
+
max_ocr_tokens: int = 128
|
| 78 |
+
max_layout_tokens: int = 64
|
| 79 |
+
max_chart_tokens: int = 64
|
| 80 |
+
max_sam_tokens: int = 32
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@dataclass
|
| 84 |
+
class LatentRolloutConfig:
|
| 85 |
+
"""
|
| 86 |
+
Configuration for the latent belief-state rollout.
|
| 87 |
+
|
| 88 |
+
The core JEPA reasoning module. Refines z₀ over K steps:
|
| 89 |
+
z₀ → z₁ → z₂ → z₃
|
| 90 |
+
|
| 91 |
+
Each step applies:
|
| 92 |
+
1. Self-attention over current state tokens
|
| 93 |
+
2. Evidence-gated cross-attention to evidence memory
|
| 94 |
+
3. FFN with residual connection
|
| 95 |
+
|
| 96 |
+
The predictor block is SHARED across all K steps (weight-tied),
|
| 97 |
+
following the recurrent predictor design from V-JEPA.
|
| 98 |
+
|
| 99 |
+
From I-JEPA: L2 loss in representation space, EMA target encoder
|
| 100 |
+
From LeWorldModel: SIGReg anti-collapse regularization
|
| 101 |
+
From Coconut: Iterative latent refinement paradigm
|
| 102 |
+
"""
|
| 103 |
+
hidden_dim: int = 768 # Latent state dimension
|
| 104 |
+
num_state_tokens: int = 32 # Number of latent belief tokens per step
|
| 105 |
+
K: int = 3 # Number of rollout steps
|
| 106 |
+
|
| 107 |
+
# Shared predictor block
|
| 108 |
+
num_predictor_layers: int = 6 # Transformer layers in predictor
|
| 109 |
+
num_heads: int = 12
|
| 110 |
+
ffn_dim: int = 3072 # 4x hidden_dim
|
| 111 |
+
dropout: float = 0.1
|
| 112 |
+
|
| 113 |
+
# Evidence gating
|
| 114 |
+
use_evidence_gate: bool = True
|
| 115 |
+
gate_type: Literal["sigmoid", "softmax", "learned"] = "sigmoid"
|
| 116 |
+
|
| 117 |
+
# Step embedding (to differentiate rollout steps)
|
| 118 |
+
use_step_embedding: bool = True
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@dataclass
|
| 122 |
+
class JEPAObjectiveConfig:
|
| 123 |
+
"""
|
| 124 |
+
Configuration for the JEPA training objective.
|
| 125 |
+
|
| 126 |
+
Target encoder: EMA of the online encoder (evidence memory + rollout).
|
| 127 |
+
The target generates z*_k for each rollout step k.
|
| 128 |
+
The online predictor must predict z*_k from z_{k-1}.
|
| 129 |
+
|
| 130 |
+
Loss: L2 in representation space (from I-JEPA)
|
| 131 |
+
Anti-collapse: SIGReg (from LeWorldModel) or VICReg-style
|
| 132 |
+
"""
|
| 133 |
+
# EMA schedule (from I-JEPA: cosine schedule 0.996 → 1.0)
|
| 134 |
+
ema_momentum_base: float = 0.996
|
| 135 |
+
ema_momentum_end: float = 1.0
|
| 136 |
+
ema_schedule: Literal["cosine", "linear", "constant"] = "cosine"
|
| 137 |
+
|
| 138 |
+
# Loss weights
|
| 139 |
+
jepa_loss_weight: float = 1.0 # L2 prediction loss
|
| 140 |
+
task_loss_weight: float = 1.0 # CE loss for answer classification
|
| 141 |
+
generative_loss_weight: float = 0.1 # Optional decoder loss
|
| 142 |
+
|
| 143 |
+
# Anti-collapse regularization (from LeWorldModel)
|
| 144 |
+
use_sigreg: bool = True
|
| 145 |
+
sigreg_weight: float = 0.1 # λ in LeWM paper
|
| 146 |
+
sigreg_num_projections: int = 1024 # M random projections
|
| 147 |
+
|
| 148 |
+
# Alternative: VICReg-style regularization
|
| 149 |
+
use_vicreg: bool = False
|
| 150 |
+
vicreg_var_weight: float = 1.0
|
| 151 |
+
vicreg_cov_weight: float = 0.04
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@dataclass
|
| 155 |
+
class AnswerHeadConfig:
|
| 156 |
+
"""Configuration for answer prediction heads."""
|
| 157 |
+
# Discriminative head (primary): scores answer options
|
| 158 |
+
disc_hidden_dim: int = 768
|
| 159 |
+
disc_num_layers: int = 2
|
| 160 |
+
max_num_options: int = 8 # MMMU can have up to 8 options
|
| 161 |
+
disc_dropout: float = 0.1
|
| 162 |
+
|
| 163 |
+
# Generative head (secondary): short open-ended answers
|
| 164 |
+
gen_hidden_dim: int = 768
|
| 165 |
+
gen_num_layers: int = 4 # Small transformer decoder
|
| 166 |
+
gen_num_heads: int = 12
|
| 167 |
+
gen_vocab_size: int = 32000 # Shared with text encoder tokenizer
|
| 168 |
+
gen_max_answer_length: int = 64
|
| 169 |
+
gen_dropout: float = 0.1
|
| 170 |
+
|
| 171 |
+
# Evidence-constrained decoding
|
| 172 |
+
use_evidence_constraint: bool = True # Cross-attend to evidence during generation
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@dataclass
|
| 176 |
+
class MRJEPAConfig:
|
| 177 |
+
"""
|
| 178 |
+
Complete MR-JEPA model configuration.
|
| 179 |
+
|
| 180 |
+
Two experimental branches:
|
| 181 |
+
- Hybrid-main: Full model with pretrained backbones + JEPA core
|
| 182 |
+
- Purist-side: Stripped-down version closer to LeWorldModel spirit
|
| 183 |
+
"""
|
| 184 |
+
# Component configs
|
| 185 |
+
visual: VisualBackboneConfig = field(default_factory=VisualBackboneConfig)
|
| 186 |
+
text: TextEncoderConfig = field(default_factory=TextEncoderConfig)
|
| 187 |
+
evidence: EvidenceMemoryConfig = field(default_factory=EvidenceMemoryConfig)
|
| 188 |
+
rollout: LatentRolloutConfig = field(default_factory=LatentRolloutConfig)
|
| 189 |
+
jepa: JEPAObjectiveConfig = field(default_factory=JEPAObjectiveConfig)
|
| 190 |
+
answer: AnswerHeadConfig = field(default_factory=AnswerHeadConfig)
|
| 191 |
+
|
| 192 |
+
# Branch selection
|
| 193 |
+
branch: Literal["hybrid", "purist"] = "hybrid"
|
| 194 |
+
|
| 195 |
+
# Global settings
|
| 196 |
+
seed: int = 42
|
| 197 |
+
|
| 198 |
+
@property
|
| 199 |
+
def num_visual_tokens(self) -> int:
|
| 200 |
+
"""Number of visual patch tokens output by backbone."""
|
| 201 |
+
n_patches = (self.visual.image_size // self.visual.patch_size) ** 2
|
| 202 |
+
return n_patches # Exclude CLS and register tokens (handled separately)
|
| 203 |
+
|
| 204 |
+
@property
|
| 205 |
+
def total_evidence_input_tokens(self) -> int:
|
| 206 |
+
"""Total tokens feeding into evidence memory."""
|
| 207 |
+
n = self.num_visual_tokens + self.text.max_text_length
|
| 208 |
+
if self.evidence.use_ocr_tokens:
|
| 209 |
+
n += self.evidence.max_ocr_tokens
|
| 210 |
+
if self.evidence.use_layout_tokens:
|
| 211 |
+
n += self.evidence.max_layout_tokens
|
| 212 |
+
if self.evidence.use_chart_tokens:
|
| 213 |
+
n += self.evidence.max_chart_tokens
|
| 214 |
+
if self.evidence.use_sam_tokens:
|
| 215 |
+
n += self.evidence.max_sam_tokens
|
| 216 |
+
return n
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
@dataclass
|
| 220 |
+
class TrainingPhaseConfig:
|
| 221 |
+
"""Configuration for the 3-phase training schedule."""
|
| 222 |
+
|
| 223 |
+
# Phase 1: Freeze perception, train reasoning core
|
| 224 |
+
phase1_epochs: int = 20
|
| 225 |
+
phase1_lr: float = 3e-4
|
| 226 |
+
phase1_warmup_ratio: float = 0.05
|
| 227 |
+
phase1_weight_decay: float = 0.05
|
| 228 |
+
phase1_batch_size: int = 32
|
| 229 |
+
phase1_grad_accum: int = 4
|
| 230 |
+
|
| 231 |
+
# Phase 2: Unfreeze last visual layers
|
| 232 |
+
phase2_epochs: int = 10
|
| 233 |
+
phase2_lr: float = 1e-4 # Lower LR for backbone fine-tuning
|
| 234 |
+
phase2_backbone_lr: float = 1e-5 # Even lower for backbone
|
| 235 |
+
phase2_warmup_ratio: float = 0.05
|
| 236 |
+
phase2_weight_decay: float = 0.05
|
| 237 |
+
phase2_batch_size: int = 16 # Smaller batch (more VRAM for gradients)
|
| 238 |
+
phase2_grad_accum: int = 8
|
| 239 |
+
phase2_unfreeze_visual_layers: int = 6 # Last 6 layers
|
| 240 |
+
phase2_unfreeze_text_layers: int = 4 # Last 4 layers
|
| 241 |
+
|
| 242 |
+
# Phase 3: Add enriched evidence
|
| 243 |
+
phase3_epochs: int = 10
|
| 244 |
+
phase3_lr: float = 5e-5
|
| 245 |
+
phase3_warmup_ratio: float = 0.1
|
| 246 |
+
phase3_weight_decay: float = 0.05
|
| 247 |
+
phase3_batch_size: int = 16
|
| 248 |
+
phase3_grad_accum: int = 8
|
| 249 |
+
phase3_enable_ocr: bool = True
|
| 250 |
+
phase3_enable_layout: bool = True
|
| 251 |
+
phase3_enable_chart: bool = True
|
| 252 |
+
phase3_enable_sam: bool = False # Optional, heavy
|
| 253 |
+
|
| 254 |
+
# Common
|
| 255 |
+
optimizer: str = "adamw"
|
| 256 |
+
scheduler: str = "cosine"
|
| 257 |
+
max_grad_norm: float = 1.0
|
| 258 |
+
fp16: bool = False
|
| 259 |
+
bf16: bool = True
|
| 260 |
+
gradient_checkpointing: bool = True
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def get_hybrid_config() -> MRJEPAConfig:
|
| 264 |
+
"""Get the Hybrid-main branch configuration."""
|
| 265 |
+
config = MRJEPAConfig(branch="hybrid")
|
| 266 |
+
# DINOv2-L backbone for strong dense features
|
| 267 |
+
config.visual.model_name = "facebook/dinov2-large"
|
| 268 |
+
config.visual.hidden_size = 1024
|
| 269 |
+
config.visual.image_size = 518
|
| 270 |
+
config.visual.patch_size = 14
|
| 271 |
+
return config
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def get_purist_config() -> MRJEPAConfig:
|
| 275 |
+
"""
|
| 276 |
+
Get the Purist-side branch configuration.
|
| 277 |
+
Closer to LeWorldModel: smaller backbone, stronger JEPA emphasis.
|
| 278 |
+
"""
|
| 279 |
+
config = MRJEPAConfig(branch="purist")
|
| 280 |
+
# Smaller backbone, more emphasis on JEPA dynamics
|
| 281 |
+
config.visual.model_name = "facebook/dinov2-base"
|
| 282 |
+
config.visual.hidden_size = 768
|
| 283 |
+
config.visual.image_size = 518
|
| 284 |
+
config.visual.patch_size = 14
|
| 285 |
+
|
| 286 |
+
# Larger rollout to compensate for weaker perception
|
| 287 |
+
config.rollout.K = 5
|
| 288 |
+
config.rollout.num_state_tokens = 48
|
| 289 |
+
config.rollout.num_predictor_layers = 8
|
| 290 |
+
|
| 291 |
+
# Stronger JEPA objective
|
| 292 |
+
config.jepa.jepa_loss_weight = 2.0
|
| 293 |
+
config.jepa.task_loss_weight = 1.0
|
| 294 |
+
config.jepa.sigreg_weight = 0.2
|
| 295 |
+
|
| 296 |
+
# No enriched evidence (pure JEPA reasoning)
|
| 297 |
+
config.evidence.use_ocr_tokens = False
|
| 298 |
+
config.evidence.use_layout_tokens = False
|
| 299 |
+
config.evidence.use_chart_tokens = False
|
| 300 |
+
config.evidence.use_sam_tokens = False
|
| 301 |
+
|
| 302 |
+
# Smaller text encoder
|
| 303 |
+
config.text.model_name = "microsoft/deberta-v3-base"
|
| 304 |
+
config.text.hidden_size = 768
|
| 305 |
+
|
| 306 |
+
return config
|
mr_jepa/data/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .unified_dataset import UnifiedBenchmarkDataset, BenchmarkType
|
| 2 |
+
from .data_utils import build_dataloader, get_benchmark_config
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"UnifiedBenchmarkDataset",
|
| 6 |
+
"BenchmarkType",
|
| 7 |
+
"build_dataloader",
|
| 8 |
+
"get_benchmark_config",
|
| 9 |
+
]
|
mr_jepa/data/data_utils.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data utilities for MR-JEPA.
|
| 3 |
+
|
| 4 |
+
Includes:
|
| 5 |
+
- Collator that handles variable-length options, multi-image samples
|
| 6 |
+
- Dataloader factory
|
| 7 |
+
- Benchmark configuration helpers
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
from typing import Optional, Dict, List, Any, Tuple
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from .unified_dataset import UnifiedBenchmarkDataset, BenchmarkSample, BenchmarkType
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
BENCHMARK_CONFIGS = {
|
| 21 |
+
'mmmu': {
|
| 22 |
+
'repo_id': 'MMMU/MMMU',
|
| 23 |
+
'eval_split': 'validation',
|
| 24 |
+
'metric': 'accuracy',
|
| 25 |
+
'answer_type': 'mc',
|
| 26 |
+
'configs': [
|
| 27 |
+
'Accounting', 'Agriculture', 'Architecture_and_Engineering',
|
| 28 |
+
'Art', 'Art_Theory', 'Basic_Medical_Science', 'Biology',
|
| 29 |
+
'Chemistry', 'Clinical_Medicine', 'Computer_Science',
|
| 30 |
+
'Design', 'Diagnostics_and_Laboratory_Medicine', 'Economics',
|
| 31 |
+
'Electronics', 'Energy_and_Power', 'Finance', 'Geography',
|
| 32 |
+
'History', 'Literature', 'Manage', 'Marketing',
|
| 33 |
+
'Materials', 'Math', 'Mechanical_Engineering', 'Music',
|
| 34 |
+
'Pharmacy', 'Physics', 'Psychology', 'Public_Health',
|
| 35 |
+
'Sociology'
|
| 36 |
+
],
|
| 37 |
+
},
|
| 38 |
+
'mathvista': {
|
| 39 |
+
'repo_id': 'AI4Math/MathVista',
|
| 40 |
+
'eval_split': 'testmini',
|
| 41 |
+
'metric': 'accuracy',
|
| 42 |
+
'answer_type': 'mixed',
|
| 43 |
+
},
|
| 44 |
+
'scienceqa': {
|
| 45 |
+
'repo_id': 'derek-thomas/ScienceQA',
|
| 46 |
+
'eval_split': 'test',
|
| 47 |
+
'train_split': 'train',
|
| 48 |
+
'metric': 'accuracy',
|
| 49 |
+
'answer_type': 'mc',
|
| 50 |
+
},
|
| 51 |
+
'ai2d': {
|
| 52 |
+
'repo_id': 'lmms-lab/ai2d',
|
| 53 |
+
'eval_split': 'test',
|
| 54 |
+
'metric': 'accuracy',
|
| 55 |
+
'answer_type': 'mc',
|
| 56 |
+
},
|
| 57 |
+
'mmbench': {
|
| 58 |
+
'repo_id': 'lmms-lab/MMBench',
|
| 59 |
+
'eval_split': 'dev',
|
| 60 |
+
'metric': 'accuracy',
|
| 61 |
+
'answer_type': 'mc',
|
| 62 |
+
},
|
| 63 |
+
'mmstar': {
|
| 64 |
+
'repo_id': 'Lin-Chen/MMStar',
|
| 65 |
+
'eval_split': 'val',
|
| 66 |
+
'metric': 'accuracy',
|
| 67 |
+
'answer_type': 'mc',
|
| 68 |
+
},
|
| 69 |
+
'docvqa': {
|
| 70 |
+
'repo_id': 'lmms-lab/DocVQA',
|
| 71 |
+
'eval_split': 'validation',
|
| 72 |
+
'metric': 'anls',
|
| 73 |
+
'answer_type': 'open',
|
| 74 |
+
},
|
| 75 |
+
'textvqa': {
|
| 76 |
+
'repo_id': 'lmms-lab/textvqa',
|
| 77 |
+
'eval_split': 'validation',
|
| 78 |
+
'metric': 'vqa_accuracy',
|
| 79 |
+
'answer_type': 'open',
|
| 80 |
+
},
|
| 81 |
+
'chartqa': {
|
| 82 |
+
'repo_id': 'lmms-lab/ChartQA',
|
| 83 |
+
'eval_split': 'test',
|
| 84 |
+
'metric': 'relaxed_accuracy',
|
| 85 |
+
'answer_type': 'open',
|
| 86 |
+
},
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_benchmark_config(benchmark: str) -> Dict:
|
| 91 |
+
"""Get benchmark configuration."""
|
| 92 |
+
return BENCHMARK_CONFIGS[benchmark]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class MRJEPACollator:
|
| 96 |
+
"""
|
| 97 |
+
Collator for MR-JEPA that handles:
|
| 98 |
+
- Variable number of images per sample (MMMU)
|
| 99 |
+
- Variable number of answer options
|
| 100 |
+
- Mixed MC/open-ended questions
|
| 101 |
+
- Image preprocessing via backbone processor
|
| 102 |
+
- Text tokenization
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
image_processor,
|
| 108 |
+
text_tokenizer,
|
| 109 |
+
max_options: int = 8,
|
| 110 |
+
max_text_length: int = 256,
|
| 111 |
+
max_gen_length: int = 64,
|
| 112 |
+
image_size: int = 518,
|
| 113 |
+
):
|
| 114 |
+
self.image_processor = image_processor
|
| 115 |
+
self.text_tokenizer = text_tokenizer
|
| 116 |
+
self.max_options = max_options
|
| 117 |
+
self.max_text_length = max_text_length
|
| 118 |
+
self.max_gen_length = max_gen_length
|
| 119 |
+
self.image_size = image_size
|
| 120 |
+
|
| 121 |
+
def __call__(self, batch: List[BenchmarkSample]) -> Dict[str, torch.Tensor]:
|
| 122 |
+
"""Collate a batch of BenchmarkSamples."""
|
| 123 |
+
B = len(batch)
|
| 124 |
+
|
| 125 |
+
# ==================== Images ====================
|
| 126 |
+
# Use first image for now (multi-image MMMU handled separately)
|
| 127 |
+
images = []
|
| 128 |
+
for sample in batch:
|
| 129 |
+
img = sample.images[0]
|
| 130 |
+
if not isinstance(img, Image.Image):
|
| 131 |
+
img = Image.new('RGB', (self.image_size, self.image_size), 'white')
|
| 132 |
+
images.append(img.convert('RGB'))
|
| 133 |
+
|
| 134 |
+
# Process images through backbone processor
|
| 135 |
+
pixel_values = self.image_processor(
|
| 136 |
+
images=images,
|
| 137 |
+
return_tensors='pt',
|
| 138 |
+
)['pixel_values'] # [B, C, H, W]
|
| 139 |
+
|
| 140 |
+
# ==================== Question Text ====================
|
| 141 |
+
questions = [s.question for s in batch]
|
| 142 |
+
text_encoded = self.text_tokenizer(
|
| 143 |
+
questions,
|
| 144 |
+
padding='max_length',
|
| 145 |
+
truncation=True,
|
| 146 |
+
max_length=self.max_text_length,
|
| 147 |
+
return_tensors='pt',
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# ==================== Options (MC) ====================
|
| 151 |
+
# Encode each option separately, pad to max_options
|
| 152 |
+
option_embeddings_list = []
|
| 153 |
+
option_masks = []
|
| 154 |
+
answer_labels = []
|
| 155 |
+
|
| 156 |
+
has_mc = any(s.options is not None for s in batch)
|
| 157 |
+
|
| 158 |
+
if has_mc:
|
| 159 |
+
for sample in batch:
|
| 160 |
+
if sample.options:
|
| 161 |
+
n_opts = min(len(sample.options), self.max_options)
|
| 162 |
+
# Tokenize options
|
| 163 |
+
opts_text = sample.options[:n_opts]
|
| 164 |
+
# Pad option text list to max_options
|
| 165 |
+
while len(opts_text) < self.max_options:
|
| 166 |
+
opts_text.append("")
|
| 167 |
+
|
| 168 |
+
mask = [True] * n_opts + [False] * (self.max_options - n_opts)
|
| 169 |
+
option_masks.append(mask)
|
| 170 |
+
|
| 171 |
+
# Answer label
|
| 172 |
+
if isinstance(sample.answer, int):
|
| 173 |
+
answer_labels.append(min(sample.answer, n_opts - 1))
|
| 174 |
+
elif isinstance(sample.answer, str) and len(sample.answer) == 1:
|
| 175 |
+
answer_labels.append(ord(sample.answer.upper()) - ord('A'))
|
| 176 |
+
else:
|
| 177 |
+
answer_labels.append(0)
|
| 178 |
+
else:
|
| 179 |
+
option_masks.append([False] * self.max_options)
|
| 180 |
+
answer_labels.append(0)
|
| 181 |
+
|
| 182 |
+
# ==================== Open-ended answers ====================
|
| 183 |
+
gen_target_ids = None
|
| 184 |
+
has_open = any(s.answer_type == 'open' for s in batch)
|
| 185 |
+
|
| 186 |
+
if has_open:
|
| 187 |
+
# Prepare generative targets
|
| 188 |
+
gen_texts = []
|
| 189 |
+
for sample in batch:
|
| 190 |
+
if sample.answer_type == 'open':
|
| 191 |
+
if isinstance(sample.answer, list):
|
| 192 |
+
gen_texts.append(str(sample.answer[0]))
|
| 193 |
+
else:
|
| 194 |
+
gen_texts.append(str(sample.answer))
|
| 195 |
+
else:
|
| 196 |
+
gen_texts.append("")
|
| 197 |
+
|
| 198 |
+
gen_encoded = self.text_tokenizer(
|
| 199 |
+
gen_texts,
|
| 200 |
+
padding='max_length',
|
| 201 |
+
truncation=True,
|
| 202 |
+
max_length=self.max_gen_length,
|
| 203 |
+
return_tensors='pt',
|
| 204 |
+
)
|
| 205 |
+
gen_target_ids = gen_encoded['input_ids']
|
| 206 |
+
|
| 207 |
+
# ==================== Build output dict ====================
|
| 208 |
+
result = {
|
| 209 |
+
'pixel_values': pixel_values,
|
| 210 |
+
'input_ids': text_encoded['input_ids'],
|
| 211 |
+
'attention_mask': text_encoded['attention_mask'],
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
if has_mc:
|
| 215 |
+
result['option_mask'] = torch.tensor(option_masks, dtype=torch.bool)
|
| 216 |
+
result['answer_labels'] = torch.tensor(answer_labels, dtype=torch.long)
|
| 217 |
+
|
| 218 |
+
# We need to encode options through text encoder at runtime
|
| 219 |
+
# Store raw option texts for the model to encode
|
| 220 |
+
all_option_texts = []
|
| 221 |
+
for sample in batch:
|
| 222 |
+
opts = sample.options or [""] * self.max_options
|
| 223 |
+
opts = opts[:self.max_options]
|
| 224 |
+
while len(opts) < self.max_options:
|
| 225 |
+
opts.append("")
|
| 226 |
+
all_option_texts.append(opts)
|
| 227 |
+
result['option_texts'] = all_option_texts
|
| 228 |
+
|
| 229 |
+
if gen_target_ids is not None:
|
| 230 |
+
result['gen_target_ids'] = gen_target_ids
|
| 231 |
+
|
| 232 |
+
# Metadata
|
| 233 |
+
result['benchmarks'] = [s.benchmark for s in batch]
|
| 234 |
+
result['answer_types'] = [s.answer_type for s in batch]
|
| 235 |
+
result['raw_answers'] = [s.answer for s in batch]
|
| 236 |
+
|
| 237 |
+
return result
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def build_dataloader(
|
| 241 |
+
benchmark: str,
|
| 242 |
+
split: str,
|
| 243 |
+
image_processor,
|
| 244 |
+
text_tokenizer,
|
| 245 |
+
batch_size: int = 32,
|
| 246 |
+
num_workers: int = 4,
|
| 247 |
+
max_samples: Optional[int] = None,
|
| 248 |
+
config: Optional[str] = None,
|
| 249 |
+
**collator_kwargs,
|
| 250 |
+
) -> DataLoader:
|
| 251 |
+
"""Build a DataLoader for a specific benchmark."""
|
| 252 |
+
dataset = UnifiedBenchmarkDataset(
|
| 253 |
+
benchmark=benchmark,
|
| 254 |
+
split=split,
|
| 255 |
+
config=config,
|
| 256 |
+
max_samples=max_samples,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
collator = MRJEPACollator(
|
| 260 |
+
image_processor=image_processor,
|
| 261 |
+
text_tokenizer=text_tokenizer,
|
| 262 |
+
**collator_kwargs,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
return DataLoader(
|
| 266 |
+
dataset,
|
| 267 |
+
batch_size=batch_size,
|
| 268 |
+
shuffle=(split in ('train', 'training')),
|
| 269 |
+
num_workers=num_workers,
|
| 270 |
+
collate_fn=collator,
|
| 271 |
+
pin_memory=True,
|
| 272 |
+
drop_last=(split in ('train', 'training')),
|
| 273 |
+
)
|
mr_jepa/data/unified_dataset.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified Dataset for MR-JEPA Benchmarks.
|
| 3 |
+
|
| 4 |
+
Handles all 9 benchmarks with their quirky formats in a single pipeline:
|
| 5 |
+
|
| 6 |
+
MC Benchmarks:
|
| 7 |
+
- MMMU: up to 7 images, string-encoded options, letter answers
|
| 8 |
+
- ScienceQA: nullable images, int8 answer index
|
| 9 |
+
- AI2D: string-encoded int index answer
|
| 10 |
+
- MMBench: separate A/B/C/D columns
|
| 11 |
+
- MMStar: options embedded in question text
|
| 12 |
+
|
| 13 |
+
Open-Ended Benchmarks:
|
| 14 |
+
- MathVista: mixed MC/free-form, dual image columns
|
| 15 |
+
- DocVQA: multiple valid answers (ANLS metric)
|
| 16 |
+
- TextVQA: 10 annotations (VQA Accuracy)
|
| 17 |
+
- ChartQA: relaxed numeric accuracy
|
| 18 |
+
|
| 19 |
+
Each sample is normalized to a common format:
|
| 20 |
+
{
|
| 21 |
+
'image': PIL.Image or List[PIL.Image],
|
| 22 |
+
'question': str,
|
| 23 |
+
'options': List[str] or None, # None for open-ended
|
| 24 |
+
'answer': str or int, # Correct answer
|
| 25 |
+
'answer_type': 'mc' or 'open',
|
| 26 |
+
'benchmark': str,
|
| 27 |
+
'metadata': dict,
|
| 28 |
+
}
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
import ast
|
| 32 |
+
import re
|
| 33 |
+
import torch
|
| 34 |
+
from torch.utils.data import Dataset
|
| 35 |
+
from PIL import Image
|
| 36 |
+
from enum import Enum
|
| 37 |
+
from typing import Optional, Dict, List, Any, Tuple
|
| 38 |
+
from dataclasses import dataclass
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class BenchmarkType(Enum):
|
| 42 |
+
MMMU = "mmmu"
|
| 43 |
+
MATHVISTA = "mathvista"
|
| 44 |
+
SCIENCEQA = "scienceqa"
|
| 45 |
+
AI2D = "ai2d"
|
| 46 |
+
MMBENCH = "mmbench"
|
| 47 |
+
MMSTAR = "mmstar"
|
| 48 |
+
DOCVQA = "docvqa"
|
| 49 |
+
TEXTVQA = "textvqa"
|
| 50 |
+
CHARTQA = "chartqa"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
|
| 54 |
+
class BenchmarkSample:
|
| 55 |
+
"""Normalized sample format across all benchmarks."""
|
| 56 |
+
images: List[Image.Image] # 1+ images (MMMU can have up to 7)
|
| 57 |
+
question: str
|
| 58 |
+
options: Optional[List[str]] # None for open-ended
|
| 59 |
+
answer: Any # str (letter/text) or int (index)
|
| 60 |
+
answer_type: str # 'mc' or 'open'
|
| 61 |
+
benchmark: str
|
| 62 |
+
metadata: Dict[str, Any]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class UnifiedBenchmarkDataset(Dataset):
|
| 66 |
+
"""
|
| 67 |
+
Unified dataset that loads any of the 9 benchmarks into a common format.
|
| 68 |
+
|
| 69 |
+
Usage:
|
| 70 |
+
dataset = UnifiedBenchmarkDataset(
|
| 71 |
+
benchmark='mmmu',
|
| 72 |
+
split='validation',
|
| 73 |
+
config='Accounting', # MMMU has per-subject configs
|
| 74 |
+
)
|
| 75 |
+
sample = dataset[0] # Returns BenchmarkSample
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(
|
| 79 |
+
self,
|
| 80 |
+
benchmark: str,
|
| 81 |
+
split: str = "validation",
|
| 82 |
+
config: Optional[str] = None,
|
| 83 |
+
max_samples: Optional[int] = None,
|
| 84 |
+
transform: Optional[Any] = None,
|
| 85 |
+
):
|
| 86 |
+
self.benchmark = BenchmarkType(benchmark)
|
| 87 |
+
self.split = split
|
| 88 |
+
self.transform = transform
|
| 89 |
+
|
| 90 |
+
# Load dataset
|
| 91 |
+
self.data = self._load_dataset(config, max_samples)
|
| 92 |
+
|
| 93 |
+
def _load_dataset(self, config: Optional[str], max_samples: Optional[int]):
|
| 94 |
+
"""Load dataset from HuggingFace Hub."""
|
| 95 |
+
from datasets import load_dataset
|
| 96 |
+
|
| 97 |
+
repo_map = {
|
| 98 |
+
BenchmarkType.MMMU: "MMMU/MMMU",
|
| 99 |
+
BenchmarkType.MATHVISTA: "AI4Math/MathVista",
|
| 100 |
+
BenchmarkType.SCIENCEQA: "derek-thomas/ScienceQA",
|
| 101 |
+
BenchmarkType.AI2D: "lmms-lab/ai2d",
|
| 102 |
+
BenchmarkType.MMBENCH: "lmms-lab/MMBench",
|
| 103 |
+
BenchmarkType.MMSTAR: "Lin-Chen/MMStar",
|
| 104 |
+
BenchmarkType.DOCVQA: "lmms-lab/DocVQA",
|
| 105 |
+
BenchmarkType.TEXTVQA: "lmms-lab/textvqa",
|
| 106 |
+
BenchmarkType.CHARTQA: "lmms-lab/ChartQA",
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
repo_id = repo_map[self.benchmark]
|
| 110 |
+
|
| 111 |
+
# Handle config/split variations
|
| 112 |
+
kwargs = {}
|
| 113 |
+
if config:
|
| 114 |
+
kwargs['name'] = config
|
| 115 |
+
elif self.benchmark == BenchmarkType.MMBENCH:
|
| 116 |
+
kwargs['name'] = 'en'
|
| 117 |
+
elif self.benchmark == BenchmarkType.DOCVQA:
|
| 118 |
+
kwargs['name'] = 'DocVQA'
|
| 119 |
+
|
| 120 |
+
# Some datasets have different split names
|
| 121 |
+
split_name = self.split
|
| 122 |
+
if self.benchmark == BenchmarkType.MMSTAR and self.split == 'validation':
|
| 123 |
+
split_name = 'val'
|
| 124 |
+
|
| 125 |
+
try:
|
| 126 |
+
ds = load_dataset(repo_id, split=split_name, **kwargs)
|
| 127 |
+
except Exception as e:
|
| 128 |
+
# Fallback: try without config
|
| 129 |
+
print(f"Warning: Failed to load {repo_id} with config={config}, split={split_name}: {e}")
|
| 130 |
+
ds = load_dataset(repo_id, split=split_name)
|
| 131 |
+
|
| 132 |
+
if max_samples:
|
| 133 |
+
ds = ds.select(range(min(max_samples, len(ds))))
|
| 134 |
+
|
| 135 |
+
return ds
|
| 136 |
+
|
| 137 |
+
def __len__(self):
|
| 138 |
+
return len(self.data)
|
| 139 |
+
|
| 140 |
+
def __getitem__(self, idx: int) -> BenchmarkSample:
|
| 141 |
+
row = self.data[idx]
|
| 142 |
+
|
| 143 |
+
# Dispatch to benchmark-specific parser
|
| 144 |
+
parser_map = {
|
| 145 |
+
BenchmarkType.MMMU: self._parse_mmmu,
|
| 146 |
+
BenchmarkType.MATHVISTA: self._parse_mathvista,
|
| 147 |
+
BenchmarkType.SCIENCEQA: self._parse_scienceqa,
|
| 148 |
+
BenchmarkType.AI2D: self._parse_ai2d,
|
| 149 |
+
BenchmarkType.MMBENCH: self._parse_mmbench,
|
| 150 |
+
BenchmarkType.MMSTAR: self._parse_mmstar,
|
| 151 |
+
BenchmarkType.DOCVQA: self._parse_docvqa,
|
| 152 |
+
BenchmarkType.TEXTVQA: self._parse_textvqa,
|
| 153 |
+
BenchmarkType.CHARTQA: self._parse_chartqa,
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
return parser_map[self.benchmark](row)
|
| 157 |
+
|
| 158 |
+
# ==================== Benchmark-Specific Parsers ====================
|
| 159 |
+
|
| 160 |
+
def _parse_mmmu(self, row) -> BenchmarkSample:
|
| 161 |
+
"""MMMU: up to 7 images, string-encoded options."""
|
| 162 |
+
images = []
|
| 163 |
+
for i in range(1, 8):
|
| 164 |
+
img = row.get(f'image_{i}')
|
| 165 |
+
if img is not None:
|
| 166 |
+
if isinstance(img, Image.Image):
|
| 167 |
+
images.append(img)
|
| 168 |
+
|
| 169 |
+
if not images:
|
| 170 |
+
# Create a blank image as fallback
|
| 171 |
+
images = [Image.new('RGB', (224, 224), color='white')]
|
| 172 |
+
|
| 173 |
+
# Parse options (string-encoded Python list)
|
| 174 |
+
options_str = row.get('options', '[]')
|
| 175 |
+
try:
|
| 176 |
+
options = ast.literal_eval(options_str) if isinstance(options_str, str) else options_str
|
| 177 |
+
except (ValueError, SyntaxError):
|
| 178 |
+
options = []
|
| 179 |
+
|
| 180 |
+
question = row['question']
|
| 181 |
+
answer = row.get('answer', 'A')
|
| 182 |
+
|
| 183 |
+
return BenchmarkSample(
|
| 184 |
+
images=images,
|
| 185 |
+
question=question,
|
| 186 |
+
options=options if options else None,
|
| 187 |
+
answer=answer,
|
| 188 |
+
answer_type='mc' if row.get('question_type', 'multiple-choice') == 'multiple-choice' else 'open',
|
| 189 |
+
benchmark='mmmu',
|
| 190 |
+
metadata={
|
| 191 |
+
'id': row.get('id', ''),
|
| 192 |
+
'subject': row.get('subfield', ''),
|
| 193 |
+
'difficulty': row.get('topic_difficulty', ''),
|
| 194 |
+
'img_type': row.get('img_type', ''),
|
| 195 |
+
}
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
def _parse_mathvista(self, row) -> BenchmarkSample:
|
| 199 |
+
"""MathVista: mixed MC/free-form, use decoded_image."""
|
| 200 |
+
image = row.get('decoded_image') or row.get('image')
|
| 201 |
+
if isinstance(image, str):
|
| 202 |
+
# It's a path, not an image — this shouldn't happen with decoded_image
|
| 203 |
+
image = Image.new('RGB', (224, 224), color='white')
|
| 204 |
+
images = [image] if image else [Image.new('RGB', (224, 224), color='white')]
|
| 205 |
+
|
| 206 |
+
question = row.get('query', row.get('question', ''))
|
| 207 |
+
choices = row.get('choices', None)
|
| 208 |
+
answer = row.get('answer', '')
|
| 209 |
+
qtype = row.get('question_type', 'free_form')
|
| 210 |
+
|
| 211 |
+
return BenchmarkSample(
|
| 212 |
+
images=images,
|
| 213 |
+
question=question,
|
| 214 |
+
options=list(choices) if choices else None,
|
| 215 |
+
answer=answer,
|
| 216 |
+
answer_type='mc' if qtype == 'multi_choice' else 'open',
|
| 217 |
+
benchmark='mathvista',
|
| 218 |
+
metadata={
|
| 219 |
+
'pid': row.get('pid', ''),
|
| 220 |
+
'answer_type': row.get('answer_type', ''),
|
| 221 |
+
'unit': row.get('unit', ''),
|
| 222 |
+
}
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
def _parse_scienceqa(self, row) -> BenchmarkSample:
|
| 226 |
+
"""ScienceQA: nullable images, int8 answer index."""
|
| 227 |
+
image = row.get('image')
|
| 228 |
+
if image is None:
|
| 229 |
+
images = [Image.new('RGB', (224, 224), color='white')]
|
| 230 |
+
has_image = False
|
| 231 |
+
else:
|
| 232 |
+
images = [image]
|
| 233 |
+
has_image = True
|
| 234 |
+
|
| 235 |
+
choices = row.get('choices', [])
|
| 236 |
+
answer_idx = int(row.get('answer', 0))
|
| 237 |
+
|
| 238 |
+
return BenchmarkSample(
|
| 239 |
+
images=images,
|
| 240 |
+
question=row['question'],
|
| 241 |
+
options=list(choices),
|
| 242 |
+
answer=answer_idx, # 0-indexed integer
|
| 243 |
+
answer_type='mc',
|
| 244 |
+
benchmark='scienceqa',
|
| 245 |
+
metadata={
|
| 246 |
+
'has_image': has_image,
|
| 247 |
+
'subject': row.get('subject', ''),
|
| 248 |
+
'grade': row.get('grade', ''),
|
| 249 |
+
'hint': row.get('hint', ''),
|
| 250 |
+
'lecture': row.get('lecture', ''),
|
| 251 |
+
'solution': row.get('solution', ''),
|
| 252 |
+
}
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
def _parse_ai2d(self, row) -> BenchmarkSample:
|
| 256 |
+
"""AI2D: string-encoded int index answer."""
|
| 257 |
+
images = [row['image']]
|
| 258 |
+
options = list(row.get('options', []))
|
| 259 |
+
answer_idx = int(row.get('answer', '0'))
|
| 260 |
+
|
| 261 |
+
return BenchmarkSample(
|
| 262 |
+
images=images,
|
| 263 |
+
question=row['question'],
|
| 264 |
+
options=options,
|
| 265 |
+
answer=answer_idx, # 0-indexed integer
|
| 266 |
+
answer_type='mc',
|
| 267 |
+
benchmark='ai2d',
|
| 268 |
+
metadata={}
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
def _parse_mmbench(self, row) -> BenchmarkSample:
|
| 272 |
+
"""MMBench: separate A/B/C/D columns."""
|
| 273 |
+
images = [row['image']]
|
| 274 |
+
|
| 275 |
+
# Build options from separate columns
|
| 276 |
+
options = []
|
| 277 |
+
for letter in ['A', 'B', 'C', 'D']:
|
| 278 |
+
opt = row.get(letter, '')
|
| 279 |
+
if opt:
|
| 280 |
+
options.append(opt)
|
| 281 |
+
|
| 282 |
+
# Answer is a letter
|
| 283 |
+
answer = row.get('answer', 'A')
|
| 284 |
+
# Convert letter to index
|
| 285 |
+
answer_idx = ord(answer) - ord('A') if isinstance(answer, str) and len(answer) == 1 else 0
|
| 286 |
+
|
| 287 |
+
return BenchmarkSample(
|
| 288 |
+
images=images,
|
| 289 |
+
question=row['question'],
|
| 290 |
+
options=options,
|
| 291 |
+
answer=answer_idx,
|
| 292 |
+
answer_type='mc',
|
| 293 |
+
benchmark='mmbench',
|
| 294 |
+
metadata={
|
| 295 |
+
'category': row.get('category', ''),
|
| 296 |
+
'hint': row.get('hint', ''),
|
| 297 |
+
}
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
def _parse_mmstar(self, row) -> BenchmarkSample:
|
| 301 |
+
"""MMStar: options embedded in question text."""
|
| 302 |
+
images = [row['image']]
|
| 303 |
+
question = row['question']
|
| 304 |
+
|
| 305 |
+
# Parse options from question text
|
| 306 |
+
# Format: "... Options: A: ..., B: ..., C: ..., D: ..."
|
| 307 |
+
options = []
|
| 308 |
+
option_pattern = r'([A-D]):\s*([^,\n]+(?:,\s*[^A-D\n][^,\n]*)*)'
|
| 309 |
+
matches = re.findall(option_pattern, question)
|
| 310 |
+
if matches:
|
| 311 |
+
for letter, text in matches:
|
| 312 |
+
options.append(text.strip())
|
| 313 |
+
|
| 314 |
+
answer = row.get('answer', 'A')
|
| 315 |
+
answer_idx = ord(answer) - ord('A') if isinstance(answer, str) and len(answer) == 1 else 0
|
| 316 |
+
|
| 317 |
+
return BenchmarkSample(
|
| 318 |
+
images=images,
|
| 319 |
+
question=question,
|
| 320 |
+
options=options if options else None,
|
| 321 |
+
answer=answer_idx,
|
| 322 |
+
answer_type='mc',
|
| 323 |
+
benchmark='mmstar',
|
| 324 |
+
metadata={
|
| 325 |
+
'category': row.get('category', ''),
|
| 326 |
+
'l2_category': row.get('l2_category', ''),
|
| 327 |
+
}
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
def _parse_docvqa(self, row) -> BenchmarkSample:
|
| 331 |
+
"""DocVQA: multiple valid answers."""
|
| 332 |
+
images = [row['image']]
|
| 333 |
+
answers = row.get('answers', [''])
|
| 334 |
+
|
| 335 |
+
return BenchmarkSample(
|
| 336 |
+
images=images,
|
| 337 |
+
question=row['question'],
|
| 338 |
+
options=None,
|
| 339 |
+
answer=answers, # List of valid answers
|
| 340 |
+
answer_type='open',
|
| 341 |
+
benchmark='docvqa',
|
| 342 |
+
metadata={
|
| 343 |
+
'question_id': row.get('questionId', ''),
|
| 344 |
+
'question_types': row.get('question_types', []),
|
| 345 |
+
}
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
def _parse_textvqa(self, row) -> BenchmarkSample:
|
| 349 |
+
"""TextVQA: 10 annotations."""
|
| 350 |
+
images = [row['image']]
|
| 351 |
+
answers = row.get('answers', [''])
|
| 352 |
+
|
| 353 |
+
return BenchmarkSample(
|
| 354 |
+
images=images,
|
| 355 |
+
question=row['question'],
|
| 356 |
+
options=None,
|
| 357 |
+
answer=answers, # 10 annotations
|
| 358 |
+
answer_type='open',
|
| 359 |
+
benchmark='textvqa',
|
| 360 |
+
metadata={
|
| 361 |
+
'question_id': row.get('question_id', ''),
|
| 362 |
+
'ocr_tokens': row.get('ocr_tokens', []),
|
| 363 |
+
}
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
def _parse_chartqa(self, row) -> BenchmarkSample:
|
| 367 |
+
"""ChartQA: relaxed numeric accuracy."""
|
| 368 |
+
images = [row['image']]
|
| 369 |
+
|
| 370 |
+
return BenchmarkSample(
|
| 371 |
+
images=images,
|
| 372 |
+
question=row['question'],
|
| 373 |
+
options=None,
|
| 374 |
+
answer=row.get('answer', ''),
|
| 375 |
+
answer_type='open',
|
| 376 |
+
benchmark='chartqa',
|
| 377 |
+
metadata={
|
| 378 |
+
'type': row.get('type', ''),
|
| 379 |
+
}
|
| 380 |
+
)
|
mr_jepa/evaluation/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .metrics import (
|
| 2 |
+
compute_accuracy,
|
| 3 |
+
compute_anls,
|
| 4 |
+
compute_vqa_accuracy,
|
| 5 |
+
compute_relaxed_accuracy,
|
| 6 |
+
evaluate_benchmark,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"compute_accuracy",
|
| 11 |
+
"compute_anls",
|
| 12 |
+
"compute_vqa_accuracy",
|
| 13 |
+
"compute_relaxed_accuracy",
|
| 14 |
+
"evaluate_benchmark",
|
| 15 |
+
]
|
mr_jepa/evaluation/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (332 Bytes). View file
|
|
|
mr_jepa/evaluation/__pycache__/metrics.cpython-312.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
mr_jepa/evaluation/metrics.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation Metrics for MR-JEPA Benchmarks.
|
| 3 |
+
|
| 4 |
+
Each benchmark has specific evaluation protocols:
|
| 5 |
+
- Accuracy: MMMU, ScienceQA, AI2D, MMBench, MMStar
|
| 6 |
+
- ANLS: DocVQA (Average Normalized Levenshtein Similarity)
|
| 7 |
+
- VQA Accuracy: TextVQA (soft majority over 10 annotations)
|
| 8 |
+
- Relaxed Accuracy: ChartQA (±5% tolerance for numerics)
|
| 9 |
+
- Mixed: MathVista (accuracy for MC, relaxed match for free-form)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import re
|
| 13 |
+
import torch
|
| 14 |
+
import numpy as np
|
| 15 |
+
from typing import List, Dict, Optional, Any, Union
|
| 16 |
+
from collections import defaultdict
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def compute_accuracy(
|
| 20 |
+
predictions: List[int],
|
| 21 |
+
ground_truth: List[int],
|
| 22 |
+
category_labels: Optional[List[str]] = None,
|
| 23 |
+
) -> Dict[str, float]:
|
| 24 |
+
"""
|
| 25 |
+
Standard accuracy for MC benchmarks.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
predictions: Predicted option indices
|
| 29 |
+
ground_truth: Correct option indices
|
| 30 |
+
category_labels: Optional per-sample categories for breakdown
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
Dict with 'accuracy' and optional per-category breakdown
|
| 34 |
+
"""
|
| 35 |
+
assert len(predictions) == len(ground_truth)
|
| 36 |
+
|
| 37 |
+
correct = sum(p == g for p, g in zip(predictions, ground_truth))
|
| 38 |
+
total = len(predictions)
|
| 39 |
+
|
| 40 |
+
result = {'accuracy': correct / max(total, 1) * 100}
|
| 41 |
+
|
| 42 |
+
# Per-category breakdown
|
| 43 |
+
if category_labels:
|
| 44 |
+
cat_correct = defaultdict(int)
|
| 45 |
+
cat_total = defaultdict(int)
|
| 46 |
+
for p, g, c in zip(predictions, ground_truth, category_labels):
|
| 47 |
+
cat_total[c] += 1
|
| 48 |
+
if p == g:
|
| 49 |
+
cat_correct[c] += 1
|
| 50 |
+
|
| 51 |
+
result['per_category'] = {
|
| 52 |
+
c: cat_correct[c] / max(cat_total[c], 1) * 100
|
| 53 |
+
for c in sorted(cat_total.keys())
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
return result
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _normalized_levenshtein(s1: str, s2: str) -> float:
|
| 60 |
+
"""Compute normalized Levenshtein distance between two strings."""
|
| 61 |
+
s1 = s1.lower().strip()
|
| 62 |
+
s2 = s2.lower().strip()
|
| 63 |
+
|
| 64 |
+
if s1 == s2:
|
| 65 |
+
return 0.0
|
| 66 |
+
|
| 67 |
+
len1, len2 = len(s1), len(s2)
|
| 68 |
+
if len1 == 0 or len2 == 0:
|
| 69 |
+
return 1.0
|
| 70 |
+
|
| 71 |
+
# Dynamic programming Levenshtein
|
| 72 |
+
matrix = [[0] * (len2 + 1) for _ in range(len1 + 1)]
|
| 73 |
+
for i in range(len1 + 1):
|
| 74 |
+
matrix[i][0] = i
|
| 75 |
+
for j in range(len2 + 1):
|
| 76 |
+
matrix[0][j] = j
|
| 77 |
+
|
| 78 |
+
for i in range(1, len1 + 1):
|
| 79 |
+
for j in range(1, len2 + 1):
|
| 80 |
+
cost = 0 if s1[i-1] == s2[j-1] else 1
|
| 81 |
+
matrix[i][j] = min(
|
| 82 |
+
matrix[i-1][j] + 1,
|
| 83 |
+
matrix[i][j-1] + 1,
|
| 84 |
+
matrix[i-1][j-1] + cost,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
return matrix[len1][len2] / max(len1, len2)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def compute_anls(
|
| 91 |
+
predictions: List[str],
|
| 92 |
+
ground_truths: List[List[str]],
|
| 93 |
+
threshold: float = 0.5,
|
| 94 |
+
) -> Dict[str, float]:
|
| 95 |
+
"""
|
| 96 |
+
Average Normalized Levenshtein Similarity (ANLS) for DocVQA.
|
| 97 |
+
|
| 98 |
+
ANLS = 1 - NL_distance if NL_distance < threshold, else 0
|
| 99 |
+
Final score is max over all valid answers.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
predictions: List of predicted answer strings
|
| 103 |
+
ground_truths: List of lists of valid answer strings
|
| 104 |
+
threshold: NL distance threshold (default 0.5)
|
| 105 |
+
"""
|
| 106 |
+
scores = []
|
| 107 |
+
for pred, gts in zip(predictions, ground_truths):
|
| 108 |
+
if not gts:
|
| 109 |
+
scores.append(0.0)
|
| 110 |
+
continue
|
| 111 |
+
|
| 112 |
+
# Take max ANLS over all valid answers
|
| 113 |
+
max_score = 0.0
|
| 114 |
+
for gt in gts:
|
| 115 |
+
nl_dist = _normalized_levenshtein(pred, gt)
|
| 116 |
+
if nl_dist < threshold:
|
| 117 |
+
score = 1.0 - nl_dist
|
| 118 |
+
else:
|
| 119 |
+
score = 0.0
|
| 120 |
+
max_score = max(max_score, score)
|
| 121 |
+
|
| 122 |
+
scores.append(max_score)
|
| 123 |
+
|
| 124 |
+
return {'anls': np.mean(scores) * 100 if scores else 0.0}
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def compute_vqa_accuracy(
|
| 128 |
+
predictions: List[str],
|
| 129 |
+
ground_truths: List[List[str]],
|
| 130 |
+
) -> Dict[str, float]:
|
| 131 |
+
"""
|
| 132 |
+
VQA Accuracy for TextVQA.
|
| 133 |
+
|
| 134 |
+
score = min(count(matching annotations) / 3, 1.0)
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
predictions: Predicted answers
|
| 138 |
+
ground_truths: Lists of 10 human annotations per question
|
| 139 |
+
"""
|
| 140 |
+
scores = []
|
| 141 |
+
for pred, gts in zip(predictions, ground_truths):
|
| 142 |
+
pred_norm = pred.lower().strip()
|
| 143 |
+
matching = sum(1 for gt in gts if gt.lower().strip() == pred_norm)
|
| 144 |
+
score = min(matching / 3.0, 1.0)
|
| 145 |
+
scores.append(score)
|
| 146 |
+
|
| 147 |
+
return {'vqa_accuracy': np.mean(scores) * 100 if scores else 0.0}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _is_numeric(s: str) -> bool:
|
| 151 |
+
"""Check if string represents a number."""
|
| 152 |
+
try:
|
| 153 |
+
float(s.replace(',', '').replace('%', '').strip())
|
| 154 |
+
return True
|
| 155 |
+
except (ValueError, AttributeError):
|
| 156 |
+
return False
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _parse_numeric(s: str) -> float:
|
| 160 |
+
"""Parse numeric value from string."""
|
| 161 |
+
s = s.replace(',', '').replace('%', '').strip()
|
| 162 |
+
return float(s)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def compute_relaxed_accuracy(
|
| 166 |
+
predictions: List[str],
|
| 167 |
+
ground_truths: List[str],
|
| 168 |
+
tolerance: float = 0.05,
|
| 169 |
+
types: Optional[List[str]] = None,
|
| 170 |
+
) -> Dict[str, float]:
|
| 171 |
+
"""
|
| 172 |
+
Relaxed Accuracy for ChartQA.
|
| 173 |
+
|
| 174 |
+
- Numeric answers: within ±5% tolerance
|
| 175 |
+
- String answers: exact match (case-insensitive)
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
predictions: Predicted answers
|
| 179 |
+
ground_truths: Ground truth answers
|
| 180 |
+
tolerance: Numeric tolerance (default 5%)
|
| 181 |
+
types: Optional list of 'human_test'/'augmented_test' for breakdown
|
| 182 |
+
"""
|
| 183 |
+
correct = []
|
| 184 |
+
for pred, gt in zip(predictions, ground_truths):
|
| 185 |
+
pred_str = str(pred).strip().lower()
|
| 186 |
+
gt_str = str(gt).strip().lower()
|
| 187 |
+
|
| 188 |
+
if _is_numeric(gt_str) and _is_numeric(pred_str):
|
| 189 |
+
gt_val = _parse_numeric(gt_str)
|
| 190 |
+
pred_val = _parse_numeric(pred_str)
|
| 191 |
+
if gt_val == 0:
|
| 192 |
+
is_correct = abs(pred_val) <= tolerance
|
| 193 |
+
else:
|
| 194 |
+
is_correct = abs(pred_val - gt_val) / abs(gt_val) <= tolerance
|
| 195 |
+
else:
|
| 196 |
+
is_correct = pred_str == gt_str
|
| 197 |
+
|
| 198 |
+
correct.append(is_correct)
|
| 199 |
+
|
| 200 |
+
result = {'relaxed_accuracy': np.mean(correct) * 100 if correct else 0.0}
|
| 201 |
+
|
| 202 |
+
# Per-type breakdown (human vs augmented)
|
| 203 |
+
if types:
|
| 204 |
+
for t in set(types):
|
| 205 |
+
type_correct = [c for c, tp in zip(correct, types) if tp == t]
|
| 206 |
+
result[f'relaxed_accuracy_{t}'] = np.mean(type_correct) * 100 if type_correct else 0.0
|
| 207 |
+
|
| 208 |
+
return result
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def evaluate_benchmark(
|
| 212 |
+
benchmark: str,
|
| 213 |
+
predictions: List[Any],
|
| 214 |
+
ground_truths: List[Any],
|
| 215 |
+
metadata: Optional[Dict[str, List]] = None,
|
| 216 |
+
) -> Dict[str, float]:
|
| 217 |
+
"""
|
| 218 |
+
Evaluate predictions for a specific benchmark.
|
| 219 |
+
|
| 220 |
+
Dispatches to the appropriate metric function.
|
| 221 |
+
"""
|
| 222 |
+
metric_map = {
|
| 223 |
+
'mmmu': 'accuracy',
|
| 224 |
+
'scienceqa': 'accuracy',
|
| 225 |
+
'ai2d': 'accuracy',
|
| 226 |
+
'mmbench': 'accuracy',
|
| 227 |
+
'mmstar': 'accuracy',
|
| 228 |
+
'mathvista': 'accuracy', # Simplified; full eval handles mixed types
|
| 229 |
+
'docvqa': 'anls',
|
| 230 |
+
'textvqa': 'vqa_accuracy',
|
| 231 |
+
'chartqa': 'relaxed_accuracy',
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
metric = metric_map.get(benchmark, 'accuracy')
|
| 235 |
+
|
| 236 |
+
if metric == 'accuracy':
|
| 237 |
+
categories = metadata.get('categories') if metadata else None
|
| 238 |
+
return compute_accuracy(predictions, ground_truths, categories)
|
| 239 |
+
|
| 240 |
+
elif metric == 'anls':
|
| 241 |
+
return compute_anls(predictions, ground_truths)
|
| 242 |
+
|
| 243 |
+
elif metric == 'vqa_accuracy':
|
| 244 |
+
return compute_vqa_accuracy(predictions, ground_truths)
|
| 245 |
+
|
| 246 |
+
elif metric == 'relaxed_accuracy':
|
| 247 |
+
types = metadata.get('types') if metadata else None
|
| 248 |
+
return compute_relaxed_accuracy(predictions, ground_truths, types=types)
|
| 249 |
+
|
| 250 |
+
else:
|
| 251 |
+
raise ValueError(f"Unknown metric: {metric}")
|
mr_jepa/models/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .mr_jepa import MRJEPAModel
|
| 2 |
+
from .evidence_memory import EvidenceMemory
|
| 3 |
+
from .latent_rollout import LatentRolloutModule
|
| 4 |
+
from .answer_heads import DiscriminativeHead, GenerativeHead
|
| 5 |
+
from .backbones import VisualBackbone, TextEncoder
|
| 6 |
+
from .target_encoder import TargetEncoder
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"MRJEPAModel",
|
| 10 |
+
"EvidenceMemory",
|
| 11 |
+
"LatentRolloutModule",
|
| 12 |
+
"DiscriminativeHead",
|
| 13 |
+
"GenerativeHead",
|
| 14 |
+
"VisualBackbone",
|
| 15 |
+
"TextEncoder",
|
| 16 |
+
"TargetEncoder",
|
| 17 |
+
]
|
mr_jepa/models/__pycache__/answer_heads.cpython-312.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
mr_jepa/models/__pycache__/evidence_memory.cpython-312.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
mr_jepa/models/__pycache__/latent_rollout.cpython-312.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
mr_jepa/models/__pycache__/target_encoder.cpython-312.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
mr_jepa/models/answer_heads.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Answer Prediction Heads for MR-JEPA.
|
| 3 |
+
|
| 4 |
+
Two heads:
|
| 5 |
+
1. Discriminative Head (primary): Scores answer options for MC questions.
|
| 6 |
+
Takes the final latent state z_K and computes compatibility scores
|
| 7 |
+
with encoded answer option representations.
|
| 8 |
+
|
| 9 |
+
2. Generative Head (secondary): Short text decoder for open-ended answers.
|
| 10 |
+
Small transformer decoder that cross-attends to the final latent state
|
| 11 |
+
and evidence memory, constrained to produce brief answers.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
import math
|
| 18 |
+
from typing import Optional, Dict, Tuple
|
| 19 |
+
|
| 20 |
+
from ..configs.model_config import AnswerHeadConfig
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DiscriminativeHead(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
Multiple-choice answer scoring head.
|
| 26 |
+
|
| 27 |
+
Architecture:
|
| 28 |
+
1. Pool latent state z_K → global reasoning vector
|
| 29 |
+
2. Encode each answer option via a small MLP
|
| 30 |
+
3. Compute compatibility score: score_i = MLP(z_pool ⊙ opt_i)
|
| 31 |
+
|
| 32 |
+
Supports variable number of options (2-8, with masking).
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, config: AnswerHeadConfig, hidden_dim: int, text_dim: int):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.config = config
|
| 38 |
+
self.hidden_dim = hidden_dim
|
| 39 |
+
|
| 40 |
+
# State pooling: attention-weighted pooling over state tokens
|
| 41 |
+
self.state_pool_query = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
|
| 42 |
+
self.state_pool_attn = nn.MultiheadAttention(
|
| 43 |
+
embed_dim=hidden_dim,
|
| 44 |
+
num_heads=8,
|
| 45 |
+
batch_first=True,
|
| 46 |
+
)
|
| 47 |
+
self.state_pool_norm = nn.LayerNorm(hidden_dim)
|
| 48 |
+
|
| 49 |
+
# Option encoder: project text option embeddings
|
| 50 |
+
self.option_proj = nn.Sequential(
|
| 51 |
+
nn.Linear(text_dim, hidden_dim),
|
| 52 |
+
nn.LayerNorm(hidden_dim),
|
| 53 |
+
nn.GELU(),
|
| 54 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Score computation: bilinear-style scoring
|
| 58 |
+
self.score_mlp = nn.Sequential(
|
| 59 |
+
nn.Linear(hidden_dim * 3, config.disc_hidden_dim),
|
| 60 |
+
nn.GELU(),
|
| 61 |
+
nn.Dropout(config.disc_dropout),
|
| 62 |
+
nn.Linear(config.disc_hidden_dim, config.disc_hidden_dim),
|
| 63 |
+
nn.GELU(),
|
| 64 |
+
nn.Dropout(config.disc_dropout),
|
| 65 |
+
nn.Linear(config.disc_hidden_dim, 1),
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def _pool_state(self, z_final: torch.Tensor) -> torch.Tensor:
|
| 69 |
+
"""
|
| 70 |
+
Attention-weighted pooling of final latent state.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
z_final: [B, N_s, D]
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
Pooled state vector [B, D]
|
| 77 |
+
"""
|
| 78 |
+
B = z_final.size(0)
|
| 79 |
+
query = self.state_pool_query.expand(B, -1, -1) # [B, 1, D]
|
| 80 |
+
z_normed = self.state_pool_norm(z_final)
|
| 81 |
+
pooled, _ = self.state_pool_attn(query, z_normed, z_normed)
|
| 82 |
+
return pooled.squeeze(1) # [B, D]
|
| 83 |
+
|
| 84 |
+
def forward(
|
| 85 |
+
self,
|
| 86 |
+
z_final: torch.Tensor, # [B, N_s, D] final latent state
|
| 87 |
+
option_embeddings: torch.Tensor, # [B, max_opts, D_text] encoded options
|
| 88 |
+
option_mask: torch.Tensor, # [B, max_opts] bool: True=valid
|
| 89 |
+
) -> Dict[str, torch.Tensor]:
|
| 90 |
+
"""
|
| 91 |
+
Score answer options.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
dict with:
|
| 95 |
+
'logits': [B, max_opts] raw scores
|
| 96 |
+
'probs': [B, max_opts] masked softmax probabilities
|
| 97 |
+
"""
|
| 98 |
+
B, max_opts = option_mask.shape
|
| 99 |
+
|
| 100 |
+
# Pool final latent state
|
| 101 |
+
z_pooled = self._pool_state(z_final) # [B, D]
|
| 102 |
+
|
| 103 |
+
# Project option embeddings
|
| 104 |
+
opt_proj = self.option_proj(option_embeddings) # [B, max_opts, D]
|
| 105 |
+
|
| 106 |
+
# Compute scores for each option
|
| 107 |
+
z_expanded = z_pooled.unsqueeze(1).expand(-1, max_opts, -1) # [B, max_opts, D]
|
| 108 |
+
|
| 109 |
+
# Concatenate: [z, opt, z⊙opt] for rich interaction
|
| 110 |
+
combined = torch.cat([
|
| 111 |
+
z_expanded,
|
| 112 |
+
opt_proj,
|
| 113 |
+
z_expanded * opt_proj, # Element-wise interaction
|
| 114 |
+
], dim=-1) # [B, max_opts, 3*D]
|
| 115 |
+
|
| 116 |
+
logits = self.score_mlp(combined).squeeze(-1) # [B, max_opts]
|
| 117 |
+
|
| 118 |
+
# Mask invalid options
|
| 119 |
+
logits = logits.masked_fill(~option_mask, float('-inf'))
|
| 120 |
+
probs = F.softmax(logits, dim=-1)
|
| 121 |
+
|
| 122 |
+
return {
|
| 123 |
+
'logits': logits,
|
| 124 |
+
'probs': probs,
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class GenerativeHead(nn.Module):
|
| 129 |
+
"""
|
| 130 |
+
Short-answer generative decoder.
|
| 131 |
+
|
| 132 |
+
Small transformer decoder that:
|
| 133 |
+
1. Cross-attends to the final latent state z_K
|
| 134 |
+
2. Optionally cross-attends to evidence memory (evidence-constrained)
|
| 135 |
+
3. Autoregressively generates a short answer (≤64 tokens)
|
| 136 |
+
|
| 137 |
+
This is a secondary objective — the primary evaluation uses the
|
| 138 |
+
discriminative head for MC questions.
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
def __init__(
|
| 142 |
+
self,
|
| 143 |
+
config: AnswerHeadConfig,
|
| 144 |
+
hidden_dim: int,
|
| 145 |
+
vocab_size: int,
|
| 146 |
+
):
|
| 147 |
+
super().__init__()
|
| 148 |
+
self.config = config
|
| 149 |
+
self.hidden_dim = hidden_dim
|
| 150 |
+
self.vocab_size = vocab_size
|
| 151 |
+
|
| 152 |
+
# Token embedding + positional encoding
|
| 153 |
+
self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
|
| 154 |
+
self.pos_embedding = nn.Embedding(config.gen_max_answer_length, hidden_dim)
|
| 155 |
+
|
| 156 |
+
# Transformer decoder layers
|
| 157 |
+
self.decoder_layers = nn.ModuleList()
|
| 158 |
+
for _ in range(config.gen_num_layers):
|
| 159 |
+
self.decoder_layers.append(
|
| 160 |
+
GenerativeDecoderLayer(
|
| 161 |
+
hidden_dim=hidden_dim,
|
| 162 |
+
num_heads=config.gen_num_heads,
|
| 163 |
+
dropout=config.gen_dropout,
|
| 164 |
+
use_evidence_cross_attn=config.use_evidence_constraint,
|
| 165 |
+
)
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# Output projection to vocabulary
|
| 169 |
+
self.output_norm = nn.LayerNorm(hidden_dim)
|
| 170 |
+
self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
|
| 171 |
+
|
| 172 |
+
# Tie weights with token embedding
|
| 173 |
+
self.lm_head.weight = self.token_embedding.weight
|
| 174 |
+
|
| 175 |
+
def forward(
|
| 176 |
+
self,
|
| 177 |
+
z_final: torch.Tensor, # [B, N_s, D]
|
| 178 |
+
target_ids: torch.Tensor, # [B, seq_len]
|
| 179 |
+
evidence_tokens: Optional[torch.Tensor] = None, # [B, N_e, D]
|
| 180 |
+
evidence_mask: Optional[torch.Tensor] = None,
|
| 181 |
+
) -> Dict[str, torch.Tensor]:
|
| 182 |
+
"""
|
| 183 |
+
Teacher-forced forward pass for training.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
z_final: Final latent state from rollout
|
| 187 |
+
target_ids: Target answer token IDs
|
| 188 |
+
evidence_tokens: Evidence memory for constrained decoding
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
dict with:
|
| 192 |
+
'logits': [B, seq_len, vocab_size]
|
| 193 |
+
'loss': scalar cross-entropy loss
|
| 194 |
+
"""
|
| 195 |
+
B, seq_len = target_ids.shape
|
| 196 |
+
device = target_ids.device
|
| 197 |
+
|
| 198 |
+
# Embed target tokens
|
| 199 |
+
positions = torch.arange(seq_len, device=device).unsqueeze(0)
|
| 200 |
+
x = self.token_embedding(target_ids) + self.pos_embedding(positions)
|
| 201 |
+
|
| 202 |
+
# Causal mask
|
| 203 |
+
causal_mask = torch.triu(
|
| 204 |
+
torch.ones(seq_len, seq_len, device=device, dtype=torch.bool),
|
| 205 |
+
diagonal=1
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
# Apply decoder layers
|
| 209 |
+
for layer in self.decoder_layers:
|
| 210 |
+
x = layer(
|
| 211 |
+
x=x,
|
| 212 |
+
z_final=z_final,
|
| 213 |
+
causal_mask=causal_mask,
|
| 214 |
+
evidence_tokens=evidence_tokens,
|
| 215 |
+
evidence_mask=evidence_mask,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# Project to vocabulary
|
| 219 |
+
logits = self.lm_head(self.output_norm(x)) # [B, seq_len, vocab]
|
| 220 |
+
|
| 221 |
+
# Compute loss (shift by 1 for next-token prediction)
|
| 222 |
+
shift_logits = logits[:, :-1].contiguous()
|
| 223 |
+
shift_labels = target_ids[:, 1:].contiguous()
|
| 224 |
+
loss = F.cross_entropy(
|
| 225 |
+
shift_logits.view(-1, self.vocab_size),
|
| 226 |
+
shift_labels.view(-1),
|
| 227 |
+
ignore_index=-100,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
return {
|
| 231 |
+
'logits': logits,
|
| 232 |
+
'loss': loss,
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
@torch.no_grad()
|
| 236 |
+
def generate(
|
| 237 |
+
self,
|
| 238 |
+
z_final: torch.Tensor,
|
| 239 |
+
start_token_id: int,
|
| 240 |
+
max_length: int = 64,
|
| 241 |
+
evidence_tokens: Optional[torch.Tensor] = None,
|
| 242 |
+
evidence_mask: Optional[torch.Tensor] = None,
|
| 243 |
+
eos_token_id: Optional[int] = None,
|
| 244 |
+
) -> torch.Tensor:
|
| 245 |
+
"""
|
| 246 |
+
Autoregressive generation for inference.
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
generated_ids: [B, gen_len]
|
| 250 |
+
"""
|
| 251 |
+
B = z_final.size(0)
|
| 252 |
+
device = z_final.device
|
| 253 |
+
|
| 254 |
+
generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
|
| 255 |
+
|
| 256 |
+
for step in range(max_length - 1):
|
| 257 |
+
seq_len = generated.size(1)
|
| 258 |
+
positions = torch.arange(seq_len, device=device).unsqueeze(0)
|
| 259 |
+
x = self.token_embedding(generated) + self.pos_embedding(positions)
|
| 260 |
+
|
| 261 |
+
causal_mask = torch.triu(
|
| 262 |
+
torch.ones(seq_len, seq_len, device=device, dtype=torch.bool),
|
| 263 |
+
diagonal=1
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
for layer in self.decoder_layers:
|
| 267 |
+
x = layer(
|
| 268 |
+
x=x,
|
| 269 |
+
z_final=z_final,
|
| 270 |
+
causal_mask=causal_mask,
|
| 271 |
+
evidence_tokens=evidence_tokens,
|
| 272 |
+
evidence_mask=evidence_mask,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
logits = self.lm_head(self.output_norm(x[:, -1:])) # [B, 1, vocab]
|
| 276 |
+
next_token = logits.argmax(dim=-1) # [B, 1]
|
| 277 |
+
generated = torch.cat([generated, next_token], dim=1)
|
| 278 |
+
|
| 279 |
+
# Check EOS
|
| 280 |
+
if eos_token_id is not None:
|
| 281 |
+
if (next_token == eos_token_id).all():
|
| 282 |
+
break
|
| 283 |
+
|
| 284 |
+
return generated
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class GenerativeDecoderLayer(nn.Module):
|
| 288 |
+
"""Single transformer decoder layer with optional evidence cross-attention."""
|
| 289 |
+
|
| 290 |
+
def __init__(
|
| 291 |
+
self,
|
| 292 |
+
hidden_dim: int,
|
| 293 |
+
num_heads: int,
|
| 294 |
+
dropout: float,
|
| 295 |
+
use_evidence_cross_attn: bool = True,
|
| 296 |
+
):
|
| 297 |
+
super().__init__()
|
| 298 |
+
|
| 299 |
+
# Causal self-attention
|
| 300 |
+
self.self_attn = nn.MultiheadAttention(
|
| 301 |
+
embed_dim=hidden_dim, num_heads=num_heads,
|
| 302 |
+
dropout=dropout, batch_first=True,
|
| 303 |
+
)
|
| 304 |
+
self.self_attn_norm = nn.LayerNorm(hidden_dim)
|
| 305 |
+
|
| 306 |
+
# Cross-attention to latent state z_K
|
| 307 |
+
self.state_cross_attn = nn.MultiheadAttention(
|
| 308 |
+
embed_dim=hidden_dim, num_heads=num_heads,
|
| 309 |
+
dropout=dropout, batch_first=True,
|
| 310 |
+
)
|
| 311 |
+
self.state_cross_norm = nn.LayerNorm(hidden_dim)
|
| 312 |
+
|
| 313 |
+
# Optional: cross-attention to evidence memory
|
| 314 |
+
self.use_evidence_cross_attn = use_evidence_cross_attn
|
| 315 |
+
if use_evidence_cross_attn:
|
| 316 |
+
self.evidence_cross_attn = nn.MultiheadAttention(
|
| 317 |
+
embed_dim=hidden_dim, num_heads=num_heads,
|
| 318 |
+
dropout=dropout, batch_first=True,
|
| 319 |
+
)
|
| 320 |
+
self.evidence_cross_norm = nn.LayerNorm(hidden_dim)
|
| 321 |
+
|
| 322 |
+
# FFN
|
| 323 |
+
self.ffn = nn.Sequential(
|
| 324 |
+
nn.Linear(hidden_dim, hidden_dim * 4),
|
| 325 |
+
nn.GELU(),
|
| 326 |
+
nn.Dropout(dropout),
|
| 327 |
+
nn.Linear(hidden_dim * 4, hidden_dim),
|
| 328 |
+
nn.Dropout(dropout),
|
| 329 |
+
)
|
| 330 |
+
self.ffn_norm = nn.LayerNorm(hidden_dim)
|
| 331 |
+
|
| 332 |
+
def forward(
|
| 333 |
+
self,
|
| 334 |
+
x: torch.Tensor,
|
| 335 |
+
z_final: torch.Tensor,
|
| 336 |
+
causal_mask: torch.Tensor,
|
| 337 |
+
evidence_tokens: Optional[torch.Tensor] = None,
|
| 338 |
+
evidence_mask: Optional[torch.Tensor] = None,
|
| 339 |
+
) -> torch.Tensor:
|
| 340 |
+
# Causal self-attention
|
| 341 |
+
residual = x
|
| 342 |
+
x_normed = self.self_attn_norm(x)
|
| 343 |
+
x_out, _ = self.self_attn(
|
| 344 |
+
x_normed, x_normed, x_normed,
|
| 345 |
+
attn_mask=causal_mask,
|
| 346 |
+
)
|
| 347 |
+
x = residual + x_out
|
| 348 |
+
|
| 349 |
+
# Cross-attention to latent state
|
| 350 |
+
residual = x
|
| 351 |
+
x_normed = self.state_cross_norm(x)
|
| 352 |
+
x_out, _ = self.state_cross_attn(x_normed, z_final, z_final)
|
| 353 |
+
x = residual + x_out
|
| 354 |
+
|
| 355 |
+
# Optional evidence cross-attention
|
| 356 |
+
if self.use_evidence_cross_attn and evidence_tokens is not None:
|
| 357 |
+
residual = x
|
| 358 |
+
x_normed = self.evidence_cross_norm(x)
|
| 359 |
+
x_out, _ = self.evidence_cross_attn(
|
| 360 |
+
x_normed, evidence_tokens, evidence_tokens,
|
| 361 |
+
key_padding_mask=evidence_mask,
|
| 362 |
+
)
|
| 363 |
+
x = residual + x_out
|
| 364 |
+
|
| 365 |
+
# FFN
|
| 366 |
+
residual = x
|
| 367 |
+
x = residual + self.ffn(self.ffn_norm(x))
|
| 368 |
+
|
| 369 |
+
return x
|
mr_jepa/models/backbones.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Visual and Text Backbone Encoders for MR-JEPA.
|
| 3 |
+
|
| 4 |
+
Visual: DINOv2-L/G or DINOv3-L (dense SSL features, no text alignment)
|
| 5 |
+
Text: DeBERTa-v3 (strong NLU encoder for questions + options)
|
| 6 |
+
|
| 7 |
+
Both backbones are frozen in Phase 1 and partially unfrozen in Phase 2.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from typing import Optional, Dict, Any
|
| 13 |
+
|
| 14 |
+
from ..configs.model_config import VisualBackboneConfig, TextEncoderConfig
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class VisualBackbone(nn.Module):
|
| 18 |
+
"""
|
| 19 |
+
Dense visual feature extractor using DINOv2/v3 or SigLIP2.
|
| 20 |
+
|
| 21 |
+
Outputs patch-level tokens (excluding CLS and register tokens).
|
| 22 |
+
For DINOv2-L at 518px: 1369 patch tokens × 1024 dim.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, config: VisualBackboneConfig):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.config = config
|
| 28 |
+
self.backbone = None
|
| 29 |
+
self.hidden_size = config.hidden_size
|
| 30 |
+
self._build_backbone()
|
| 31 |
+
|
| 32 |
+
if config.freeze:
|
| 33 |
+
self.freeze_all()
|
| 34 |
+
|
| 35 |
+
def _build_backbone(self):
|
| 36 |
+
"""Load pretrained backbone from HuggingFace."""
|
| 37 |
+
from transformers import AutoModel, AutoImageProcessor
|
| 38 |
+
|
| 39 |
+
if self.config.backbone_type in ("dinov2", "dinov3"):
|
| 40 |
+
self.backbone = AutoModel.from_pretrained(
|
| 41 |
+
self.config.model_name,
|
| 42 |
+
torch_dtype=torch.float32, # DINOv2 is fp32
|
| 43 |
+
)
|
| 44 |
+
self.processor = AutoImageProcessor.from_pretrained(
|
| 45 |
+
self.config.model_name
|
| 46 |
+
)
|
| 47 |
+
# DINOv2/v3 outputs: last_hidden_state includes [CLS] + registers + patches
|
| 48 |
+
self._skip_tokens = 1 + self.config.num_register_tokens # CLS + regs
|
| 49 |
+
|
| 50 |
+
elif self.config.backbone_type == "siglip2":
|
| 51 |
+
from transformers import SiglipVisionModel, SiglipImageProcessor
|
| 52 |
+
self.backbone = SiglipVisionModel.from_pretrained(
|
| 53 |
+
self.config.model_name,
|
| 54 |
+
torch_dtype=torch.float32,
|
| 55 |
+
)
|
| 56 |
+
self.processor = SiglipImageProcessor.from_pretrained(
|
| 57 |
+
self.config.model_name
|
| 58 |
+
)
|
| 59 |
+
self._skip_tokens = 0 # SigLIP has no CLS or register tokens
|
| 60 |
+
|
| 61 |
+
def freeze_all(self):
|
| 62 |
+
"""Freeze all backbone parameters."""
|
| 63 |
+
for param in self.backbone.parameters():
|
| 64 |
+
param.requires_grad = False
|
| 65 |
+
|
| 66 |
+
def unfreeze_last_n_layers(self, n: int):
|
| 67 |
+
"""Unfreeze the last N transformer layers (Phase 2)."""
|
| 68 |
+
# DINOv2 uses model.encoder.layer[i]
|
| 69 |
+
if hasattr(self.backbone, 'encoder'):
|
| 70 |
+
layers = self.backbone.encoder.layer
|
| 71 |
+
elif hasattr(self.backbone, 'vision_model'):
|
| 72 |
+
layers = self.backbone.vision_model.encoder.layers
|
| 73 |
+
else:
|
| 74 |
+
raise ValueError(f"Unknown backbone structure for {self.config.model_name}")
|
| 75 |
+
|
| 76 |
+
total_layers = len(layers)
|
| 77 |
+
for i, layer in enumerate(layers):
|
| 78 |
+
if i >= total_layers - n:
|
| 79 |
+
for param in layer.parameters():
|
| 80 |
+
param.requires_grad = True
|
| 81 |
+
|
| 82 |
+
def forward(
|
| 83 |
+
self,
|
| 84 |
+
pixel_values: torch.Tensor, # [B, C, H, W]
|
| 85 |
+
return_cls: bool = False,
|
| 86 |
+
) -> Dict[str, torch.Tensor]:
|
| 87 |
+
"""
|
| 88 |
+
Extract dense patch tokens from images.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
pixel_values: Preprocessed image tensors [B, C, H, W]
|
| 92 |
+
return_cls: Whether to also return the CLS token
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
dict with:
|
| 96 |
+
'patch_tokens': [B, num_patches, hidden_size]
|
| 97 |
+
'cls_token': [B, hidden_size] (if return_cls=True)
|
| 98 |
+
"""
|
| 99 |
+
outputs = self.backbone(pixel_values=pixel_values)
|
| 100 |
+
hidden_states = outputs.last_hidden_state # [B, 1+reg+patches, D]
|
| 101 |
+
|
| 102 |
+
result = {}
|
| 103 |
+
result['patch_tokens'] = hidden_states[:, self._skip_tokens:] # [B, num_patches, D]
|
| 104 |
+
|
| 105 |
+
if return_cls:
|
| 106 |
+
result['cls_token'] = hidden_states[:, 0] # [B, D]
|
| 107 |
+
|
| 108 |
+
return result
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class TextEncoder(nn.Module):
|
| 112 |
+
"""
|
| 113 |
+
Text encoder for questions, options, and optional context.
|
| 114 |
+
|
| 115 |
+
Uses DeBERTa-v3 for strong NLU. Outputs:
|
| 116 |
+
- Token-level representations for cross-attention
|
| 117 |
+
- [CLS] representation for global text understanding
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
def __init__(self, config: TextEncoderConfig):
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.config = config
|
| 123 |
+
self.hidden_size = config.hidden_size
|
| 124 |
+
self._build_encoder()
|
| 125 |
+
|
| 126 |
+
if config.freeze:
|
| 127 |
+
self.freeze_all()
|
| 128 |
+
|
| 129 |
+
def _build_encoder(self):
|
| 130 |
+
"""Load pretrained text encoder."""
|
| 131 |
+
from transformers import AutoModel, AutoTokenizer
|
| 132 |
+
|
| 133 |
+
self.encoder = AutoModel.from_pretrained(
|
| 134 |
+
self.config.model_name,
|
| 135 |
+
torch_dtype=torch.float32,
|
| 136 |
+
)
|
| 137 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 138 |
+
self.config.model_name
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
def freeze_all(self):
|
| 142 |
+
for param in self.encoder.parameters():
|
| 143 |
+
param.requires_grad = False
|
| 144 |
+
|
| 145 |
+
def unfreeze_last_n_layers(self, n: int):
|
| 146 |
+
if hasattr(self.encoder, 'encoder'):
|
| 147 |
+
layers = self.encoder.encoder.layer
|
| 148 |
+
else:
|
| 149 |
+
raise ValueError(f"Unknown encoder structure for {self.config.model_name}")
|
| 150 |
+
|
| 151 |
+
total_layers = len(layers)
|
| 152 |
+
for i, layer in enumerate(layers):
|
| 153 |
+
if i >= total_layers - n:
|
| 154 |
+
for param in layer.parameters():
|
| 155 |
+
param.requires_grad = True
|
| 156 |
+
|
| 157 |
+
def forward(
|
| 158 |
+
self,
|
| 159 |
+
input_ids: torch.Tensor, # [B, seq_len]
|
| 160 |
+
attention_mask: torch.Tensor, # [B, seq_len]
|
| 161 |
+
) -> Dict[str, torch.Tensor]:
|
| 162 |
+
"""
|
| 163 |
+
Encode text (question + options).
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
dict with:
|
| 167 |
+
'token_embeddings': [B, seq_len, hidden_size]
|
| 168 |
+
'cls_embedding': [B, hidden_size]
|
| 169 |
+
'attention_mask': [B, seq_len]
|
| 170 |
+
"""
|
| 171 |
+
outputs = self.encoder(
|
| 172 |
+
input_ids=input_ids,
|
| 173 |
+
attention_mask=attention_mask,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
return {
|
| 177 |
+
'token_embeddings': outputs.last_hidden_state,
|
| 178 |
+
'cls_embedding': outputs.last_hidden_state[:, 0],
|
| 179 |
+
'attention_mask': attention_mask,
|
| 180 |
+
}
|
mr_jepa/models/evidence_memory.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evidence Memory Module for MR-JEPA.
|
| 3 |
+
|
| 4 |
+
The Evidence Memory is a unified multimodal representation that fuses:
|
| 5 |
+
1. Dense visual patch tokens (from DINOv2/v3)
|
| 6 |
+
2. Text tokens (question + options from DeBERTa)
|
| 7 |
+
3. Optional enriched tokens: OCR, layout, chart structure, SAM segments
|
| 8 |
+
|
| 9 |
+
Architecture:
|
| 10 |
+
- N learnable evidence query tokens
|
| 11 |
+
- Cross-attention layers: queries attend to all input modalities
|
| 12 |
+
- Each cross-attention layer also has self-attention among queries
|
| 13 |
+
- Output: N evidence tokens that capture the full multimodal context
|
| 14 |
+
|
| 15 |
+
This is inspired by Perceiver/Q-Former architectures but designed specifically
|
| 16 |
+
as the initial evidence state for the JEPA rollout.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
import math
|
| 23 |
+
from typing import Optional, Dict, List
|
| 24 |
+
|
| 25 |
+
from ..configs.model_config import EvidenceMemoryConfig
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class CrossAttentionLayer(nn.Module):
|
| 29 |
+
"""
|
| 30 |
+
Single cross-attention layer with self-attention.
|
| 31 |
+
|
| 32 |
+
Flow: self_attn(queries) → cross_attn(queries, kv=evidence) → FFN
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, hidden_dim: int, num_heads: int, dropout: float = 0.1):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.hidden_dim = hidden_dim
|
| 38 |
+
self.num_heads = num_heads
|
| 39 |
+
self.head_dim = hidden_dim // num_heads
|
| 40 |
+
|
| 41 |
+
# Self-attention among evidence queries
|
| 42 |
+
self.self_attn = nn.MultiheadAttention(
|
| 43 |
+
embed_dim=hidden_dim,
|
| 44 |
+
num_heads=num_heads,
|
| 45 |
+
dropout=dropout,
|
| 46 |
+
batch_first=True,
|
| 47 |
+
)
|
| 48 |
+
self.self_attn_norm = nn.LayerNorm(hidden_dim)
|
| 49 |
+
|
| 50 |
+
# Cross-attention: queries attend to input tokens
|
| 51 |
+
self.cross_attn = nn.MultiheadAttention(
|
| 52 |
+
embed_dim=hidden_dim,
|
| 53 |
+
num_heads=num_heads,
|
| 54 |
+
dropout=dropout,
|
| 55 |
+
batch_first=True,
|
| 56 |
+
)
|
| 57 |
+
self.cross_attn_norm = nn.LayerNorm(hidden_dim)
|
| 58 |
+
|
| 59 |
+
# FFN
|
| 60 |
+
self.ffn = nn.Sequential(
|
| 61 |
+
nn.Linear(hidden_dim, hidden_dim * 4),
|
| 62 |
+
nn.GELU(),
|
| 63 |
+
nn.Dropout(dropout),
|
| 64 |
+
nn.Linear(hidden_dim * 4, hidden_dim),
|
| 65 |
+
nn.Dropout(dropout),
|
| 66 |
+
)
|
| 67 |
+
self.ffn_norm = nn.LayerNorm(hidden_dim)
|
| 68 |
+
|
| 69 |
+
def forward(
|
| 70 |
+
self,
|
| 71 |
+
queries: torch.Tensor, # [B, N_q, D]
|
| 72 |
+
kv_tokens: torch.Tensor, # [B, N_kv, D]
|
| 73 |
+
kv_mask: Optional[torch.Tensor] = None, # [B, N_kv] bool
|
| 74 |
+
) -> torch.Tensor:
|
| 75 |
+
"""
|
| 76 |
+
Args:
|
| 77 |
+
queries: Evidence query tokens [B, N_q, D]
|
| 78 |
+
kv_tokens: Concatenated input tokens [B, N_kv, D]
|
| 79 |
+
kv_mask: Key padding mask for kv_tokens [B, N_kv]
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
Updated queries [B, N_q, D]
|
| 83 |
+
"""
|
| 84 |
+
# Self-attention among queries
|
| 85 |
+
residual = queries
|
| 86 |
+
queries = self.self_attn_norm(queries)
|
| 87 |
+
queries_out, _ = self.self_attn(queries, queries, queries)
|
| 88 |
+
queries = residual + queries_out
|
| 89 |
+
|
| 90 |
+
# Cross-attention to input tokens
|
| 91 |
+
residual = queries
|
| 92 |
+
queries_normed = self.cross_attn_norm(queries)
|
| 93 |
+
queries_out, _ = self.cross_attn(
|
| 94 |
+
query=queries_normed,
|
| 95 |
+
key=kv_tokens,
|
| 96 |
+
value=kv_tokens,
|
| 97 |
+
key_padding_mask=kv_mask,
|
| 98 |
+
)
|
| 99 |
+
queries = residual + queries_out
|
| 100 |
+
|
| 101 |
+
# FFN
|
| 102 |
+
residual = queries
|
| 103 |
+
queries = residual + self.ffn(self.ffn_norm(queries))
|
| 104 |
+
|
| 105 |
+
return queries
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class ModalityProjector(nn.Module):
|
| 109 |
+
"""Projects tokens from a specific modality to the evidence memory dimension."""
|
| 110 |
+
|
| 111 |
+
def __init__(self, input_dim: int, output_dim: int):
|
| 112 |
+
super().__init__()
|
| 113 |
+
self.proj = nn.Sequential(
|
| 114 |
+
nn.Linear(input_dim, output_dim),
|
| 115 |
+
nn.LayerNorm(output_dim),
|
| 116 |
+
nn.GELU(),
|
| 117 |
+
nn.Linear(output_dim, output_dim),
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
return self.proj(x)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class EvidenceMemory(nn.Module):
|
| 125 |
+
"""
|
| 126 |
+
Unified Evidence Memory that fuses all input modalities.
|
| 127 |
+
|
| 128 |
+
The output evidence tokens serve as:
|
| 129 |
+
1. The basis for constructing the initial latent state z₀
|
| 130 |
+
2. The key-value memory for evidence-gated cross-attention in rollout steps
|
| 131 |
+
|
| 132 |
+
Architecture follows a Perceiver-style design with learnable queries
|
| 133 |
+
cross-attending to projected multimodal tokens.
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
def __init__(
|
| 137 |
+
self,
|
| 138 |
+
config: EvidenceMemoryConfig,
|
| 139 |
+
visual_dim: int,
|
| 140 |
+
text_dim: int,
|
| 141 |
+
):
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.config = config
|
| 144 |
+
self.hidden_dim = config.hidden_dim
|
| 145 |
+
|
| 146 |
+
# Learnable evidence query tokens
|
| 147 |
+
self.evidence_queries = nn.Parameter(
|
| 148 |
+
torch.randn(1, config.num_evidence_tokens, config.hidden_dim) * 0.02
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Modality projectors
|
| 152 |
+
self.visual_proj = ModalityProjector(visual_dim, config.hidden_dim)
|
| 153 |
+
self.text_proj = ModalityProjector(text_dim, config.hidden_dim)
|
| 154 |
+
|
| 155 |
+
# Modality type embeddings (to distinguish sources in cross-attention)
|
| 156 |
+
self.modality_embeddings = nn.Embedding(6, config.hidden_dim)
|
| 157 |
+
# 0=visual, 1=text, 2=ocr, 3=layout, 4=chart, 5=sam
|
| 158 |
+
|
| 159 |
+
# Optional enriched evidence projectors (Phase 3)
|
| 160 |
+
if config.use_ocr_tokens:
|
| 161 |
+
self.ocr_proj = ModalityProjector(text_dim, config.hidden_dim)
|
| 162 |
+
if config.use_layout_tokens:
|
| 163 |
+
self.layout_proj = ModalityProjector(256, config.hidden_dim) # Layout features
|
| 164 |
+
if config.use_chart_tokens:
|
| 165 |
+
self.chart_proj = ModalityProjector(512, config.hidden_dim) # Chart structure
|
| 166 |
+
if config.use_sam_tokens:
|
| 167 |
+
self.sam_proj = ModalityProjector(256, config.hidden_dim) # SAM2 features
|
| 168 |
+
|
| 169 |
+
# Cross-attention layers
|
| 170 |
+
self.layers = nn.ModuleList([
|
| 171 |
+
CrossAttentionLayer(
|
| 172 |
+
hidden_dim=config.hidden_dim,
|
| 173 |
+
num_heads=config.num_heads,
|
| 174 |
+
dropout=config.dropout,
|
| 175 |
+
)
|
| 176 |
+
for _ in range(config.num_cross_attn_layers)
|
| 177 |
+
])
|
| 178 |
+
|
| 179 |
+
# Final norm
|
| 180 |
+
self.output_norm = nn.LayerNorm(config.hidden_dim)
|
| 181 |
+
|
| 182 |
+
def _prepare_kv_tokens(
|
| 183 |
+
self,
|
| 184 |
+
visual_tokens: torch.Tensor, # [B, N_v, D_v]
|
| 185 |
+
text_tokens: torch.Tensor, # [B, N_t, D_t]
|
| 186 |
+
text_mask: torch.Tensor, # [B, N_t]
|
| 187 |
+
ocr_tokens: Optional[torch.Tensor] = None, # [B, N_ocr, D_t]
|
| 188 |
+
ocr_mask: Optional[torch.Tensor] = None,
|
| 189 |
+
layout_tokens: Optional[torch.Tensor] = None, # [B, N_lay, D_lay]
|
| 190 |
+
layout_mask: Optional[torch.Tensor] = None,
|
| 191 |
+
chart_tokens: Optional[torch.Tensor] = None, # [B, N_ch, D_ch]
|
| 192 |
+
chart_mask: Optional[torch.Tensor] = None,
|
| 193 |
+
sam_tokens: Optional[torch.Tensor] = None, # [B, N_sam, D_sam]
|
| 194 |
+
sam_mask: Optional[torch.Tensor] = None,
|
| 195 |
+
):
|
| 196 |
+
"""Project all modalities and concatenate into a single KV sequence."""
|
| 197 |
+
B = visual_tokens.size(0)
|
| 198 |
+
device = visual_tokens.device
|
| 199 |
+
|
| 200 |
+
all_tokens = []
|
| 201 |
+
all_masks = []
|
| 202 |
+
|
| 203 |
+
# Visual tokens (always present)
|
| 204 |
+
v_proj = self.visual_proj(visual_tokens) # [B, N_v, D]
|
| 205 |
+
v_proj = v_proj + self.modality_embeddings(
|
| 206 |
+
torch.zeros(v_proj.size(1), dtype=torch.long, device=device)
|
| 207 |
+
).unsqueeze(0)
|
| 208 |
+
all_tokens.append(v_proj)
|
| 209 |
+
all_masks.append(torch.zeros(B, v_proj.size(1), dtype=torch.bool, device=device))
|
| 210 |
+
|
| 211 |
+
# Text tokens (always present)
|
| 212 |
+
t_proj = self.text_proj(text_tokens) # [B, N_t, D]
|
| 213 |
+
t_proj = t_proj + self.modality_embeddings(
|
| 214 |
+
torch.ones(t_proj.size(1), dtype=torch.long, device=device)
|
| 215 |
+
).unsqueeze(0)
|
| 216 |
+
all_tokens.append(t_proj)
|
| 217 |
+
# Invert mask: True = padding (to be masked out)
|
| 218 |
+
all_masks.append(~text_mask.bool())
|
| 219 |
+
|
| 220 |
+
# Optional modalities (Phase 3)
|
| 221 |
+
if ocr_tokens is not None and self.config.use_ocr_tokens:
|
| 222 |
+
o_proj = self.ocr_proj(ocr_tokens)
|
| 223 |
+
o_proj = o_proj + self.modality_embeddings(
|
| 224 |
+
torch.full((o_proj.size(1),), 2, dtype=torch.long, device=device)
|
| 225 |
+
).unsqueeze(0)
|
| 226 |
+
all_tokens.append(o_proj)
|
| 227 |
+
all_masks.append(~ocr_mask.bool() if ocr_mask is not None
|
| 228 |
+
else torch.zeros(B, o_proj.size(1), dtype=torch.bool, device=device))
|
| 229 |
+
|
| 230 |
+
if layout_tokens is not None and self.config.use_layout_tokens:
|
| 231 |
+
l_proj = self.layout_proj(layout_tokens)
|
| 232 |
+
l_proj = l_proj + self.modality_embeddings(
|
| 233 |
+
torch.full((l_proj.size(1),), 3, dtype=torch.long, device=device)
|
| 234 |
+
).unsqueeze(0)
|
| 235 |
+
all_tokens.append(l_proj)
|
| 236 |
+
all_masks.append(~layout_mask.bool() if layout_mask is not None
|
| 237 |
+
else torch.zeros(B, l_proj.size(1), dtype=torch.bool, device=device))
|
| 238 |
+
|
| 239 |
+
if chart_tokens is not None and self.config.use_chart_tokens:
|
| 240 |
+
c_proj = self.chart_proj(chart_tokens)
|
| 241 |
+
c_proj = c_proj + self.modality_embeddings(
|
| 242 |
+
torch.full((c_proj.size(1),), 4, dtype=torch.long, device=device)
|
| 243 |
+
).unsqueeze(0)
|
| 244 |
+
all_tokens.append(c_proj)
|
| 245 |
+
all_masks.append(~chart_mask.bool() if chart_mask is not None
|
| 246 |
+
else torch.zeros(B, c_proj.size(1), dtype=torch.bool, device=device))
|
| 247 |
+
|
| 248 |
+
if sam_tokens is not None and self.config.use_sam_tokens:
|
| 249 |
+
s_proj = self.sam_proj(sam_tokens)
|
| 250 |
+
s_proj = s_proj + self.modality_embeddings(
|
| 251 |
+
torch.full((s_proj.size(1),), 5, dtype=torch.long, device=device)
|
| 252 |
+
).unsqueeze(0)
|
| 253 |
+
all_tokens.append(s_proj)
|
| 254 |
+
all_masks.append(~sam_mask.bool() if sam_mask is not None
|
| 255 |
+
else torch.zeros(B, s_proj.size(1), dtype=torch.bool, device=device))
|
| 256 |
+
|
| 257 |
+
# Concatenate all modalities
|
| 258 |
+
kv_tokens = torch.cat(all_tokens, dim=1) # [B, N_total, D]
|
| 259 |
+
kv_mask = torch.cat(all_masks, dim=1) # [B, N_total]
|
| 260 |
+
|
| 261 |
+
return kv_tokens, kv_mask
|
| 262 |
+
|
| 263 |
+
def forward(
|
| 264 |
+
self,
|
| 265 |
+
visual_tokens: torch.Tensor,
|
| 266 |
+
text_tokens: torch.Tensor,
|
| 267 |
+
text_mask: torch.Tensor,
|
| 268 |
+
**enriched_kwargs,
|
| 269 |
+
) -> Dict[str, torch.Tensor]:
|
| 270 |
+
"""
|
| 271 |
+
Fuse all modalities into evidence tokens.
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
dict with:
|
| 275 |
+
'evidence_tokens': [B, N_evidence, D] - fused evidence
|
| 276 |
+
'kv_tokens': [B, N_total, D] - projected multimodal KV for rollout
|
| 277 |
+
'kv_mask': [B, N_total] - mask for KV tokens
|
| 278 |
+
"""
|
| 279 |
+
B = visual_tokens.size(0)
|
| 280 |
+
|
| 281 |
+
# Prepare KV tokens from all modalities
|
| 282 |
+
kv_tokens, kv_mask = self._prepare_kv_tokens(
|
| 283 |
+
visual_tokens, text_tokens, text_mask, **enriched_kwargs
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Expand learnable queries for batch
|
| 287 |
+
queries = self.evidence_queries.expand(B, -1, -1) # [B, N_q, D]
|
| 288 |
+
|
| 289 |
+
# Apply cross-attention layers
|
| 290 |
+
for layer in self.layers:
|
| 291 |
+
queries = layer(queries, kv_tokens, kv_mask)
|
| 292 |
+
|
| 293 |
+
evidence_tokens = self.output_norm(queries) # [B, N_evidence, D]
|
| 294 |
+
|
| 295 |
+
return {
|
| 296 |
+
'evidence_tokens': evidence_tokens,
|
| 297 |
+
'kv_tokens': kv_tokens,
|
| 298 |
+
'kv_mask': kv_mask,
|
| 299 |
+
}
|
mr_jepa/models/latent_rollout.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Latent Belief-State Rollout Module for MR-JEPA.
|
| 3 |
+
|
| 4 |
+
This is the core JEPA reasoning module. It models the evolution of a
|
| 5 |
+
multimodal belief state as the system "reasons" about a question:
|
| 6 |
+
|
| 7 |
+
z₀ → z₁ → z₂ → z₃ (K=3 steps)
|
| 8 |
+
|
| 9 |
+
Each step applies a shared predictor block with evidence gating:
|
| 10 |
+
1. Self-attention: latent state tokens attend to each other
|
| 11 |
+
2. Evidence-gated cross-attention: state attends to evidence memory
|
| 12 |
+
3. FFN with residual
|
| 13 |
+
|
| 14 |
+
Key design choices grounded in literature:
|
| 15 |
+
- SHARED predictor across steps (weight-tied, like V-JEPA/LeWorldModel)
|
| 16 |
+
- Step embeddings to differentiate rollout positions
|
| 17 |
+
- Evidence gates (sigmoid/softmax) control information flow per step
|
| 18 |
+
- The predictor is a "narrow" transformer (from I-JEPA: predictor is
|
| 19 |
+
smaller than encoder)
|
| 20 |
+
|
| 21 |
+
The JEPA objective supervises this trajectory: the target encoder (EMA)
|
| 22 |
+
generates z*_k targets, and the predictor must predict z*_k from z_{k-1}.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
import math
|
| 29 |
+
from typing import Optional, Dict, List, Tuple
|
| 30 |
+
|
| 31 |
+
from ..configs.model_config import LatentRolloutConfig
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class EvidenceGate(nn.Module):
|
| 35 |
+
"""
|
| 36 |
+
Learned gate that controls how much evidence flows into each rollout step.
|
| 37 |
+
|
| 38 |
+
Intuition: Early steps may need more visual evidence, while later steps
|
| 39 |
+
may rely more on accumulated reasoning. The gate learns this schedule.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, hidden_dim: int, gate_type: str = "sigmoid"):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.gate_type = gate_type
|
| 45 |
+
|
| 46 |
+
if gate_type == "sigmoid":
|
| 47 |
+
# Per-dimension gate: scales each feature independently
|
| 48 |
+
self.gate_proj = nn.Sequential(
|
| 49 |
+
nn.Linear(hidden_dim * 2, hidden_dim),
|
| 50 |
+
nn.Sigmoid(),
|
| 51 |
+
)
|
| 52 |
+
elif gate_type == "learned":
|
| 53 |
+
# Scalar gate per token, learned as a function of state + evidence
|
| 54 |
+
self.gate_proj = nn.Sequential(
|
| 55 |
+
nn.Linear(hidden_dim * 2, hidden_dim),
|
| 56 |
+
nn.ReLU(),
|
| 57 |
+
nn.Linear(hidden_dim, 1),
|
| 58 |
+
nn.Sigmoid(),
|
| 59 |
+
)
|
| 60 |
+
# softmax gate is implemented in forward via attention weights
|
| 61 |
+
|
| 62 |
+
def forward(
|
| 63 |
+
self,
|
| 64 |
+
state: torch.Tensor, # [B, N_s, D]
|
| 65 |
+
evidence_contribution: torch.Tensor, # [B, N_s, D]
|
| 66 |
+
) -> torch.Tensor:
|
| 67 |
+
"""
|
| 68 |
+
Apply evidence gate.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
state: Current latent state
|
| 72 |
+
evidence_contribution: Cross-attention output from evidence
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Gated evidence contribution [B, N_s, D]
|
| 76 |
+
"""
|
| 77 |
+
if self.gate_type == "sigmoid":
|
| 78 |
+
gate = self.gate_proj(torch.cat([state, evidence_contribution], dim=-1))
|
| 79 |
+
return gate * evidence_contribution
|
| 80 |
+
elif self.gate_type == "learned":
|
| 81 |
+
gate = self.gate_proj(torch.cat([state, evidence_contribution], dim=-1))
|
| 82 |
+
return gate * evidence_contribution
|
| 83 |
+
else:
|
| 84 |
+
# No explicit gating (softmax via attention weights)
|
| 85 |
+
return evidence_contribution
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class PredictorBlock(nn.Module):
|
| 89 |
+
"""
|
| 90 |
+
Single rollout step predictor block.
|
| 91 |
+
|
| 92 |
+
This is the "narrow" predictor from I-JEPA adapted for reasoning:
|
| 93 |
+
- Self-attention among latent state tokens
|
| 94 |
+
- Evidence-gated cross-attention to evidence memory
|
| 95 |
+
- FFN
|
| 96 |
+
|
| 97 |
+
All K rollout steps share this same block (weight-tied).
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(
|
| 101 |
+
self,
|
| 102 |
+
hidden_dim: int,
|
| 103 |
+
num_heads: int,
|
| 104 |
+
ffn_dim: int,
|
| 105 |
+
dropout: float,
|
| 106 |
+
gate_type: str = "sigmoid",
|
| 107 |
+
):
|
| 108 |
+
super().__init__()
|
| 109 |
+
|
| 110 |
+
# Self-attention among state tokens
|
| 111 |
+
self.self_attn = nn.MultiheadAttention(
|
| 112 |
+
embed_dim=hidden_dim,
|
| 113 |
+
num_heads=num_heads,
|
| 114 |
+
dropout=dropout,
|
| 115 |
+
batch_first=True,
|
| 116 |
+
)
|
| 117 |
+
self.self_attn_norm = nn.LayerNorm(hidden_dim)
|
| 118 |
+
|
| 119 |
+
# Cross-attention to evidence memory
|
| 120 |
+
self.cross_attn = nn.MultiheadAttention(
|
| 121 |
+
embed_dim=hidden_dim,
|
| 122 |
+
num_heads=num_heads,
|
| 123 |
+
dropout=dropout,
|
| 124 |
+
batch_first=True,
|
| 125 |
+
)
|
| 126 |
+
self.cross_attn_norm = nn.LayerNorm(hidden_dim)
|
| 127 |
+
|
| 128 |
+
# Evidence gate
|
| 129 |
+
self.evidence_gate = EvidenceGate(hidden_dim, gate_type)
|
| 130 |
+
|
| 131 |
+
# FFN
|
| 132 |
+
self.ffn = nn.Sequential(
|
| 133 |
+
nn.Linear(hidden_dim, ffn_dim),
|
| 134 |
+
nn.GELU(),
|
| 135 |
+
nn.Dropout(dropout),
|
| 136 |
+
nn.Linear(ffn_dim, hidden_dim),
|
| 137 |
+
nn.Dropout(dropout),
|
| 138 |
+
)
|
| 139 |
+
self.ffn_norm = nn.LayerNorm(hidden_dim)
|
| 140 |
+
|
| 141 |
+
def forward(
|
| 142 |
+
self,
|
| 143 |
+
state: torch.Tensor, # [B, N_s, D]
|
| 144 |
+
evidence_kv: torch.Tensor, # [B, N_e, D]
|
| 145 |
+
evidence_mask: Optional[torch.Tensor] = None, # [B, N_e]
|
| 146 |
+
) -> torch.Tensor:
|
| 147 |
+
"""One rollout step: state → updated state."""
|
| 148 |
+
# Self-attention
|
| 149 |
+
residual = state
|
| 150 |
+
state_normed = self.self_attn_norm(state)
|
| 151 |
+
state_out, _ = self.self_attn(state_normed, state_normed, state_normed)
|
| 152 |
+
state = residual + state_out
|
| 153 |
+
|
| 154 |
+
# Cross-attention to evidence
|
| 155 |
+
residual = state
|
| 156 |
+
state_normed = self.cross_attn_norm(state)
|
| 157 |
+
evidence_contribution, _ = self.cross_attn(
|
| 158 |
+
query=state_normed,
|
| 159 |
+
key=evidence_kv,
|
| 160 |
+
value=evidence_kv,
|
| 161 |
+
key_padding_mask=evidence_mask,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
# Apply evidence gate
|
| 165 |
+
gated_evidence = self.evidence_gate(state, evidence_contribution)
|
| 166 |
+
state = residual + gated_evidence
|
| 167 |
+
|
| 168 |
+
# FFN
|
| 169 |
+
residual = state
|
| 170 |
+
state = residual + self.ffn(self.ffn_norm(state))
|
| 171 |
+
|
| 172 |
+
return state
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class LatentRolloutModule(nn.Module):
|
| 176 |
+
"""
|
| 177 |
+
Full latent belief-state rollout.
|
| 178 |
+
|
| 179 |
+
Constructs z₀ from evidence memory, then refines it over K steps.
|
| 180 |
+
Each step uses the same shared PredictorBlock (weight-tied across steps).
|
| 181 |
+
|
| 182 |
+
The full trajectory [z₀, z₁, ..., z_K] is returned for the JEPA objective.
|
| 183 |
+
|
| 184 |
+
Architecture:
|
| 185 |
+
z₀ = LinearProj(evidence_pool) + state_init_tokens
|
| 186 |
+
For k in 1..K:
|
| 187 |
+
z_k = PredictorBlock(z_{k-1}, evidence_memory) + step_emb[k]
|
| 188 |
+
"""
|
| 189 |
+
|
| 190 |
+
def __init__(self, config: LatentRolloutConfig):
|
| 191 |
+
super().__init__()
|
| 192 |
+
self.config = config
|
| 193 |
+
self.K = config.K
|
| 194 |
+
self.hidden_dim = config.hidden_dim
|
| 195 |
+
self.num_state_tokens = config.num_state_tokens
|
| 196 |
+
|
| 197 |
+
# Initial state construction
|
| 198 |
+
# Learnable state initialization tokens
|
| 199 |
+
self.state_init = nn.Parameter(
|
| 200 |
+
torch.randn(1, config.num_state_tokens, config.hidden_dim) * 0.02
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# Project evidence summary into initial state
|
| 204 |
+
self.z0_proj = nn.Sequential(
|
| 205 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 206 |
+
nn.LayerNorm(config.hidden_dim),
|
| 207 |
+
nn.GELU(),
|
| 208 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
# Step embeddings (learned per-step bias)
|
| 212 |
+
if config.use_step_embedding:
|
| 213 |
+
self.step_embeddings = nn.Parameter(
|
| 214 |
+
torch.randn(config.K + 1, 1, config.hidden_dim) * 0.02
|
| 215 |
+
) # [K+1, 1, D] — one per step including z₀
|
| 216 |
+
|
| 217 |
+
# Shared predictor block (weight-tied across K steps)
|
| 218 |
+
# We use a stack of transformer layers as the predictor
|
| 219 |
+
self.predictor_layers = nn.ModuleList([
|
| 220 |
+
PredictorBlock(
|
| 221 |
+
hidden_dim=config.hidden_dim,
|
| 222 |
+
num_heads=config.num_heads,
|
| 223 |
+
ffn_dim=config.ffn_dim,
|
| 224 |
+
dropout=config.dropout,
|
| 225 |
+
gate_type=config.gate_type if config.use_evidence_gate else "none",
|
| 226 |
+
)
|
| 227 |
+
for _ in range(config.num_predictor_layers)
|
| 228 |
+
])
|
| 229 |
+
|
| 230 |
+
# Output projection (project each z_k to prediction space)
|
| 231 |
+
self.output_proj = nn.Sequential(
|
| 232 |
+
nn.LayerNorm(config.hidden_dim),
|
| 233 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
def _construct_z0(
|
| 237 |
+
self,
|
| 238 |
+
evidence_tokens: torch.Tensor, # [B, N_e, D]
|
| 239 |
+
) -> torch.Tensor:
|
| 240 |
+
"""
|
| 241 |
+
Construct initial latent state z₀ from evidence.
|
| 242 |
+
|
| 243 |
+
z₀ = state_init_tokens + projected_evidence_pool + step_emb[0]
|
| 244 |
+
|
| 245 |
+
The evidence pool is computed by adaptive average pooling the evidence
|
| 246 |
+
tokens down to the number of state tokens.
|
| 247 |
+
"""
|
| 248 |
+
B = evidence_tokens.size(0)
|
| 249 |
+
|
| 250 |
+
# Pool evidence into state-sized representation
|
| 251 |
+
# [B, N_e, D] → [B, N_s, D] via adaptive pooling
|
| 252 |
+
evidence_pooled = F.adaptive_avg_pool1d(
|
| 253 |
+
evidence_tokens.permute(0, 2, 1), # [B, D, N_e]
|
| 254 |
+
self.num_state_tokens
|
| 255 |
+
).permute(0, 2, 1) # [B, N_s, D]
|
| 256 |
+
|
| 257 |
+
# Project and combine with learnable init
|
| 258 |
+
z0 = self.state_init.expand(B, -1, -1) + self.z0_proj(evidence_pooled)
|
| 259 |
+
|
| 260 |
+
# Add step embedding for step 0
|
| 261 |
+
if self.config.use_step_embedding:
|
| 262 |
+
z0 = z0 + self.step_embeddings[0].unsqueeze(0)
|
| 263 |
+
|
| 264 |
+
return z0
|
| 265 |
+
|
| 266 |
+
def _single_rollout_step(
|
| 267 |
+
self,
|
| 268 |
+
z_prev: torch.Tensor, # [B, N_s, D]
|
| 269 |
+
evidence_tokens: torch.Tensor, # [B, N_e, D]
|
| 270 |
+
evidence_mask: Optional[torch.Tensor],
|
| 271 |
+
) -> torch.Tensor:
|
| 272 |
+
"""Apply the shared predictor block for one rollout step."""
|
| 273 |
+
z = z_prev
|
| 274 |
+
for layer in self.predictor_layers:
|
| 275 |
+
z = layer(z, evidence_tokens, evidence_mask)
|
| 276 |
+
return z
|
| 277 |
+
|
| 278 |
+
def forward(
|
| 279 |
+
self,
|
| 280 |
+
evidence_tokens: torch.Tensor, # [B, N_e, D]
|
| 281 |
+
evidence_mask: Optional[torch.Tensor] = None, # [B, N_e]
|
| 282 |
+
) -> Dict[str, torch.Tensor]:
|
| 283 |
+
"""
|
| 284 |
+
Full K-step latent rollout.
|
| 285 |
+
|
| 286 |
+
Args:
|
| 287 |
+
evidence_tokens: Fused evidence from EvidenceMemory [B, N_e, D]
|
| 288 |
+
evidence_mask: Padding mask for evidence tokens
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
dict with:
|
| 292 |
+
'trajectory': [B, K+1, N_s, D] - full latent trajectory
|
| 293 |
+
'z_final': [B, N_s, D] - final latent state z_K
|
| 294 |
+
'z_projected': [B, K+1, N_s, D] - projected trajectory for JEPA loss
|
| 295 |
+
"""
|
| 296 |
+
# Construct z₀
|
| 297 |
+
z = self._construct_z0(evidence_tokens)
|
| 298 |
+
|
| 299 |
+
trajectory = [z]
|
| 300 |
+
|
| 301 |
+
# Rollout K steps
|
| 302 |
+
for k in range(1, self.K + 1):
|
| 303 |
+
z = self._single_rollout_step(z, evidence_tokens, evidence_mask)
|
| 304 |
+
|
| 305 |
+
# Add step embedding
|
| 306 |
+
if self.config.use_step_embedding:
|
| 307 |
+
z = z + self.step_embeddings[k].unsqueeze(0)
|
| 308 |
+
|
| 309 |
+
trajectory.append(z)
|
| 310 |
+
|
| 311 |
+
# Stack trajectory: [B, K+1, N_s, D]
|
| 312 |
+
trajectory_tensor = torch.stack(trajectory, dim=1)
|
| 313 |
+
|
| 314 |
+
# Project each state for JEPA prediction loss
|
| 315 |
+
B, Kp1, N_s, D = trajectory_tensor.shape
|
| 316 |
+
flat = trajectory_tensor.reshape(B * Kp1 * N_s, D)
|
| 317 |
+
projected_flat = self.output_proj(flat)
|
| 318 |
+
z_projected = projected_flat.reshape(B, Kp1, N_s, D)
|
| 319 |
+
|
| 320 |
+
return {
|
| 321 |
+
'trajectory': trajectory_tensor, # Raw states
|
| 322 |
+
'z_final': trajectory[-1], # Final state
|
| 323 |
+
'z_projected': z_projected, # For JEPA loss
|
| 324 |
+
}
|
mr_jepa/models/mr_jepa.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MR-JEPA: Multimodal Reasoning via Joint-Embedding Predictive Architecture.
|
| 3 |
+
|
| 4 |
+
Complete model that integrates all components:
|
| 5 |
+
Visual Backbone → Evidence Memory ← Text Encoder
|
| 6 |
+
Evidence Memory → z₀ → Latent Rollout (K=3) → Answer Heads
|
| 7 |
+
Target Encoder (EMA) → JEPA Supervision
|
| 8 |
+
|
| 9 |
+
The model supports two branches:
|
| 10 |
+
- Hybrid-main: Full model, pretrained backbones, competitive on benchmarks
|
| 11 |
+
- Purist-side: Stripped-down, closer to LeWorldModel spirit
|
| 12 |
+
|
| 13 |
+
Forward pass:
|
| 14 |
+
1. Extract visual tokens (DINOv2/v3)
|
| 15 |
+
2. Encode question + options (DeBERTa)
|
| 16 |
+
3. Fuse in Evidence Memory (cross-attention)
|
| 17 |
+
4. Construct z₀ and rollout K steps
|
| 18 |
+
5. Score answer options (discriminative) and/or generate short answer
|
| 19 |
+
6. Compute JEPA loss against target encoder trajectory
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
from typing import Optional, Dict, Any
|
| 26 |
+
|
| 27 |
+
from ..configs.model_config import MRJEPAConfig
|
| 28 |
+
from .backbones import VisualBackbone, TextEncoder
|
| 29 |
+
from .evidence_memory import EvidenceMemory
|
| 30 |
+
from .latent_rollout import LatentRolloutModule
|
| 31 |
+
from .target_encoder import TargetEncoder, JEPALoss
|
| 32 |
+
from .answer_heads import DiscriminativeHead, GenerativeHead
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class MRJEPAModel(nn.Module):
|
| 36 |
+
"""
|
| 37 |
+
MR-JEPA: A world model for multimodal reasoning.
|
| 38 |
+
|
| 39 |
+
Instead of modeling physical dynamics, this model models the evolution
|
| 40 |
+
of a belief state while solving a visual question. The JEPA objective
|
| 41 |
+
trains the latent rollout to produce meaningful intermediate states,
|
| 42 |
+
supervised by an EMA target encoder.
|
| 43 |
+
|
| 44 |
+
Parameters:
|
| 45 |
+
config: MRJEPAConfig with all architecture hyperparameters
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, config: MRJEPAConfig):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.config = config
|
| 51 |
+
|
| 52 |
+
# ===================== Perception Encoders =====================
|
| 53 |
+
self.visual_backbone = VisualBackbone(config.visual)
|
| 54 |
+
self.text_encoder = TextEncoder(config.text)
|
| 55 |
+
|
| 56 |
+
# ===================== Evidence Memory =====================
|
| 57 |
+
self.evidence_memory = EvidenceMemory(
|
| 58 |
+
config=config.evidence,
|
| 59 |
+
visual_dim=config.visual.hidden_size,
|
| 60 |
+
text_dim=config.text.hidden_size,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# ===================== Latent Rollout =====================
|
| 64 |
+
self.latent_rollout = LatentRolloutModule(config.rollout)
|
| 65 |
+
|
| 66 |
+
# ===================== Target Encoder (EMA) =====================
|
| 67 |
+
self.target_encoder = TargetEncoder(
|
| 68 |
+
online_evidence_memory=self.evidence_memory,
|
| 69 |
+
online_rollout=self.latent_rollout,
|
| 70 |
+
config=config.jepa,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# ===================== Answer Heads =====================
|
| 74 |
+
self.disc_head = DiscriminativeHead(
|
| 75 |
+
config=config.answer,
|
| 76 |
+
hidden_dim=config.rollout.hidden_dim,
|
| 77 |
+
text_dim=config.text.hidden_size,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
self.gen_head = GenerativeHead(
|
| 81 |
+
config=config.answer,
|
| 82 |
+
hidden_dim=config.rollout.hidden_dim,
|
| 83 |
+
vocab_size=config.answer.gen_vocab_size,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# ===================== JEPA Loss =====================
|
| 87 |
+
self.jepa_loss_fn = JEPALoss(
|
| 88 |
+
config=config.jepa,
|
| 89 |
+
hidden_dim=config.rollout.hidden_dim,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# ===================== Ablation controls =====================
|
| 93 |
+
self._use_jepa = True # Disable for "no-JEPA" ablation
|
| 94 |
+
self._use_rollout = True # Disable for "no-rollout" ablation (z₀ only)
|
| 95 |
+
self._use_evidence_gate = config.rollout.use_evidence_gate
|
| 96 |
+
|
| 97 |
+
def get_trainable_params(self, phase: int = 1) -> Dict[str, list]:
|
| 98 |
+
"""
|
| 99 |
+
Get parameter groups for each training phase.
|
| 100 |
+
|
| 101 |
+
Phase 1: Freeze backbones, train evidence memory + rollout + heads
|
| 102 |
+
Phase 2: Unfreeze last N backbone layers with lower LR
|
| 103 |
+
Phase 3: Add enriched evidence modules
|
| 104 |
+
|
| 105 |
+
Returns dict with 'high_lr' and 'low_lr' parameter groups.
|
| 106 |
+
"""
|
| 107 |
+
high_lr_params = []
|
| 108 |
+
low_lr_params = []
|
| 109 |
+
|
| 110 |
+
if phase >= 1:
|
| 111 |
+
# Always train: evidence memory, rollout, heads, loss
|
| 112 |
+
for module in [self.evidence_memory, self.latent_rollout,
|
| 113 |
+
self.disc_head, self.gen_head, self.jepa_loss_fn]:
|
| 114 |
+
high_lr_params.extend(module.parameters())
|
| 115 |
+
|
| 116 |
+
if phase >= 2:
|
| 117 |
+
# Unfreeze last N visual backbone layers
|
| 118 |
+
self.visual_backbone.unfreeze_last_n_layers(
|
| 119 |
+
self.config.visual.unfreeze_last_n_layers
|
| 120 |
+
)
|
| 121 |
+
# Unfreeze last N text encoder layers
|
| 122 |
+
self.text_encoder.unfreeze_last_n_layers(
|
| 123 |
+
self.config.text.unfreeze_last_n_layers
|
| 124 |
+
)
|
| 125 |
+
# Add backbone params with lower LR
|
| 126 |
+
for module in [self.visual_backbone, self.text_encoder]:
|
| 127 |
+
for p in module.parameters():
|
| 128 |
+
if p.requires_grad:
|
| 129 |
+
low_lr_params.append(p)
|
| 130 |
+
|
| 131 |
+
return {
|
| 132 |
+
'high_lr': high_lr_params,
|
| 133 |
+
'low_lr': low_lr_params,
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
def forward(
|
| 137 |
+
self,
|
| 138 |
+
pixel_values: torch.Tensor, # [B, C, H, W]
|
| 139 |
+
input_ids: torch.Tensor, # [B, seq_len]
|
| 140 |
+
attention_mask: torch.Tensor, # [B, seq_len]
|
| 141 |
+
option_embeddings: Optional[torch.Tensor] = None, # [B, max_opts, D_text]
|
| 142 |
+
option_mask: Optional[torch.Tensor] = None, # [B, max_opts]
|
| 143 |
+
answer_labels: Optional[torch.Tensor] = None, # [B] index of correct option
|
| 144 |
+
gen_target_ids: Optional[torch.Tensor] = None, # [B, gen_seq_len]
|
| 145 |
+
# Optional enriched evidence (Phase 3)
|
| 146 |
+
ocr_tokens: Optional[torch.Tensor] = None,
|
| 147 |
+
ocr_mask: Optional[torch.Tensor] = None,
|
| 148 |
+
layout_tokens: Optional[torch.Tensor] = None,
|
| 149 |
+
layout_mask: Optional[torch.Tensor] = None,
|
| 150 |
+
chart_tokens: Optional[torch.Tensor] = None,
|
| 151 |
+
chart_mask: Optional[torch.Tensor] = None,
|
| 152 |
+
sam_tokens: Optional[torch.Tensor] = None,
|
| 153 |
+
sam_mask: Optional[torch.Tensor] = None,
|
| 154 |
+
) -> Dict[str, torch.Tensor]:
|
| 155 |
+
"""
|
| 156 |
+
Full forward pass of MR-JEPA.
|
| 157 |
+
|
| 158 |
+
Returns dict with losses and predictions.
|
| 159 |
+
"""
|
| 160 |
+
# ==================== 1. Perception ====================
|
| 161 |
+
# Visual features
|
| 162 |
+
visual_output = self.visual_backbone(pixel_values)
|
| 163 |
+
visual_tokens = visual_output['patch_tokens'] # [B, N_v, D_v]
|
| 164 |
+
|
| 165 |
+
# Text features
|
| 166 |
+
text_output = self.text_encoder(input_ids, attention_mask)
|
| 167 |
+
text_tokens = text_output['token_embeddings'] # [B, N_t, D_t]
|
| 168 |
+
text_mask = text_output['attention_mask'] # [B, N_t]
|
| 169 |
+
|
| 170 |
+
# ==================== 2. Evidence Memory ====================
|
| 171 |
+
enriched_kwargs = {}
|
| 172 |
+
for name, tokens, mask in [
|
| 173 |
+
('ocr_tokens', ocr_tokens, ocr_mask),
|
| 174 |
+
('layout_tokens', layout_tokens, layout_mask),
|
| 175 |
+
('chart_tokens', chart_tokens, chart_mask),
|
| 176 |
+
('sam_tokens', sam_tokens, sam_mask),
|
| 177 |
+
]:
|
| 178 |
+
if tokens is not None:
|
| 179 |
+
enriched_kwargs[name] = tokens
|
| 180 |
+
enriched_kwargs[name.replace('tokens', 'mask')] = mask
|
| 181 |
+
|
| 182 |
+
evidence_output = self.evidence_memory(
|
| 183 |
+
visual_tokens=visual_tokens,
|
| 184 |
+
text_tokens=text_tokens,
|
| 185 |
+
text_mask=text_mask,
|
| 186 |
+
**enriched_kwargs,
|
| 187 |
+
)
|
| 188 |
+
evidence_tokens = evidence_output['evidence_tokens'] # [B, N_e, D]
|
| 189 |
+
|
| 190 |
+
# ==================== 3. Latent Rollout ====================
|
| 191 |
+
if self._use_rollout:
|
| 192 |
+
rollout_output = self.latent_rollout(
|
| 193 |
+
evidence_tokens=evidence_tokens,
|
| 194 |
+
)
|
| 195 |
+
trajectory = rollout_output['trajectory'] # [B, K+1, N_s, D]
|
| 196 |
+
z_final = rollout_output['z_final'] # [B, N_s, D]
|
| 197 |
+
z_projected = rollout_output['z_projected'] # [B, K+1, N_s, D]
|
| 198 |
+
else:
|
| 199 |
+
# Ablation: no rollout, use z₀ directly
|
| 200 |
+
z0 = self.latent_rollout._construct_z0(evidence_tokens)
|
| 201 |
+
z_final = z0
|
| 202 |
+
trajectory = z0.unsqueeze(1)
|
| 203 |
+
z_projected = self.latent_rollout.output_proj(z0).unsqueeze(1)
|
| 204 |
+
|
| 205 |
+
# ==================== 4. Target Encoder (JEPA) ====================
|
| 206 |
+
results = {}
|
| 207 |
+
|
| 208 |
+
if self._use_jepa and self.training:
|
| 209 |
+
target_output = self.target_encoder(
|
| 210 |
+
visual_tokens=visual_tokens.detach(),
|
| 211 |
+
text_tokens=text_tokens.detach(),
|
| 212 |
+
text_mask=text_mask.detach(),
|
| 213 |
+
**{k: v.detach() if v is not None else None
|
| 214 |
+
for k, v in enriched_kwargs.items()},
|
| 215 |
+
)
|
| 216 |
+
target_trajectory = target_output['target_trajectory']
|
| 217 |
+
results['target_trajectory'] = target_trajectory
|
| 218 |
+
|
| 219 |
+
# ==================== 5. Answer Heads ====================
|
| 220 |
+
# Discriminative head (MC questions)
|
| 221 |
+
if option_embeddings is not None and option_mask is not None:
|
| 222 |
+
disc_output = self.disc_head(z_final, option_embeddings, option_mask)
|
| 223 |
+
results['disc_logits'] = disc_output['logits']
|
| 224 |
+
results['disc_probs'] = disc_output['probs']
|
| 225 |
+
|
| 226 |
+
# Task loss
|
| 227 |
+
if answer_labels is not None:
|
| 228 |
+
task_loss = F.cross_entropy(disc_output['logits'], answer_labels)
|
| 229 |
+
results['task_loss'] = task_loss
|
| 230 |
+
|
| 231 |
+
# Generative head (open-ended questions)
|
| 232 |
+
if gen_target_ids is not None:
|
| 233 |
+
gen_output = self.gen_head(
|
| 234 |
+
z_final=z_final,
|
| 235 |
+
target_ids=gen_target_ids,
|
| 236 |
+
evidence_tokens=evidence_tokens,
|
| 237 |
+
)
|
| 238 |
+
results['gen_logits'] = gen_output['logits']
|
| 239 |
+
results['gen_loss'] = gen_output['loss']
|
| 240 |
+
|
| 241 |
+
# ==================== 6. JEPA Loss ====================
|
| 242 |
+
if self._use_jepa and self.training and 'target_trajectory' in results:
|
| 243 |
+
task_loss = results.get('task_loss', torch.tensor(0.0, device=pixel_values.device))
|
| 244 |
+
gen_loss = results.get('gen_loss', None)
|
| 245 |
+
|
| 246 |
+
loss_dict = self.jepa_loss_fn(
|
| 247 |
+
predicted_trajectory=z_projected,
|
| 248 |
+
target_trajectory=target_trajectory,
|
| 249 |
+
task_loss=task_loss,
|
| 250 |
+
gen_loss=gen_loss,
|
| 251 |
+
)
|
| 252 |
+
results.update(loss_dict)
|
| 253 |
+
elif 'task_loss' in results:
|
| 254 |
+
results['total_loss'] = results['task_loss']
|
| 255 |
+
if 'gen_loss' in results:
|
| 256 |
+
results['total_loss'] = results['total_loss'] + \
|
| 257 |
+
self.config.jepa.generative_loss_weight * results['gen_loss']
|
| 258 |
+
|
| 259 |
+
# Store trajectory for analysis
|
| 260 |
+
results['trajectory'] = trajectory
|
| 261 |
+
results['z_final'] = z_final
|
| 262 |
+
results['evidence_tokens'] = evidence_tokens
|
| 263 |
+
|
| 264 |
+
return results
|
| 265 |
+
|
| 266 |
+
def update_target_encoder(self, step: int, total_steps: int):
|
| 267 |
+
"""Update EMA target encoder (call after each optimizer step)."""
|
| 268 |
+
self.target_encoder.update_ema(
|
| 269 |
+
online_evidence_memory=self.evidence_memory,
|
| 270 |
+
online_rollout=self.latent_rollout,
|
| 271 |
+
step=step,
|
| 272 |
+
total_steps=total_steps,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
@torch.no_grad()
|
| 276 |
+
def predict_mc(
|
| 277 |
+
self,
|
| 278 |
+
pixel_values: torch.Tensor,
|
| 279 |
+
input_ids: torch.Tensor,
|
| 280 |
+
attention_mask: torch.Tensor,
|
| 281 |
+
option_embeddings: torch.Tensor,
|
| 282 |
+
option_mask: torch.Tensor,
|
| 283 |
+
) -> torch.Tensor:
|
| 284 |
+
"""Predict answer for multiple-choice questions. Returns predicted indices."""
|
| 285 |
+
self.eval()
|
| 286 |
+
outputs = self.forward(
|
| 287 |
+
pixel_values=pixel_values,
|
| 288 |
+
input_ids=input_ids,
|
| 289 |
+
attention_mask=attention_mask,
|
| 290 |
+
option_embeddings=option_embeddings,
|
| 291 |
+
option_mask=option_mask,
|
| 292 |
+
)
|
| 293 |
+
return outputs['disc_probs'].argmax(dim=-1)
|
| 294 |
+
|
| 295 |
+
@torch.no_grad()
|
| 296 |
+
def predict_open(
|
| 297 |
+
self,
|
| 298 |
+
pixel_values: torch.Tensor,
|
| 299 |
+
input_ids: torch.Tensor,
|
| 300 |
+
attention_mask: torch.Tensor,
|
| 301 |
+
start_token_id: int,
|
| 302 |
+
max_length: int = 64,
|
| 303 |
+
eos_token_id: Optional[int] = None,
|
| 304 |
+
) -> torch.Tensor:
|
| 305 |
+
"""Generate short answer for open-ended questions."""
|
| 306 |
+
self.eval()
|
| 307 |
+
outputs = self.forward(
|
| 308 |
+
pixel_values=pixel_values,
|
| 309 |
+
input_ids=input_ids,
|
| 310 |
+
attention_mask=attention_mask,
|
| 311 |
+
)
|
| 312 |
+
return self.gen_head.generate(
|
| 313 |
+
z_final=outputs['z_final'],
|
| 314 |
+
start_token_id=start_token_id,
|
| 315 |
+
max_length=max_length,
|
| 316 |
+
evidence_tokens=outputs['evidence_tokens'],
|
| 317 |
+
eos_token_id=eos_token_id,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
def set_ablation(self, use_jepa: bool = True, use_rollout: bool = True,
|
| 321 |
+
use_evidence_gate: bool = True):
|
| 322 |
+
"""Configure ablation settings for experiments."""
|
| 323 |
+
self._use_jepa = use_jepa
|
| 324 |
+
self._use_rollout = use_rollout
|
| 325 |
+
|
| 326 |
+
# Disable evidence gates in rollout
|
| 327 |
+
if not use_evidence_gate:
|
| 328 |
+
for layer in self.latent_rollout.predictor_layers:
|
| 329 |
+
layer.evidence_gate = lambda s, e: e # Identity gate
|
| 330 |
+
|
| 331 |
+
def count_parameters(self) -> Dict[str, int]:
|
| 332 |
+
"""Count parameters by component."""
|
| 333 |
+
counts = {}
|
| 334 |
+
for name, module in [
|
| 335 |
+
('visual_backbone', self.visual_backbone),
|
| 336 |
+
('text_encoder', self.text_encoder),
|
| 337 |
+
('evidence_memory', self.evidence_memory),
|
| 338 |
+
('latent_rollout', self.latent_rollout),
|
| 339 |
+
('disc_head', self.disc_head),
|
| 340 |
+
('gen_head', self.gen_head),
|
| 341 |
+
]:
|
| 342 |
+
total = sum(p.numel() for p in module.parameters())
|
| 343 |
+
trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
|
| 344 |
+
counts[name] = {'total': total, 'trainable': trainable}
|
| 345 |
+
|
| 346 |
+
counts['total'] = {
|
| 347 |
+
'total': sum(c['total'] for c in counts.values()),
|
| 348 |
+
'trainable': sum(c['trainable'] for c in counts.values()),
|
| 349 |
+
}
|
| 350 |
+
return counts
|
mr_jepa/models/target_encoder.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Target Encoder (EMA) for MR-JEPA.
|
| 3 |
+
|
| 4 |
+
The target encoder generates the supervision signal for the JEPA objective.
|
| 5 |
+
It is an exponential moving average (EMA) copy of the online encoder
|
| 6 |
+
(evidence memory + rollout module).
|
| 7 |
+
|
| 8 |
+
From I-JEPA:
|
| 9 |
+
θ̄ ← m·θ̄ + (1-m)·θ
|
| 10 |
+
where m follows a cosine schedule from 0.996 → 1.0
|
| 11 |
+
|
| 12 |
+
The target encoder processes the same inputs but with stop-gradient,
|
| 13 |
+
producing target latent states z*_k that the online predictor must predict.
|
| 14 |
+
|
| 15 |
+
From LeWorldModel: We also add SIGReg anti-collapse regularization
|
| 16 |
+
to prevent the representation space from collapsing.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
import math
|
| 23 |
+
import copy
|
| 24 |
+
from typing import Optional, Dict
|
| 25 |
+
|
| 26 |
+
from ..configs.model_config import JEPAObjectiveConfig
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TargetEncoder(nn.Module):
|
| 30 |
+
"""
|
| 31 |
+
EMA target encoder that generates JEPA targets.
|
| 32 |
+
|
| 33 |
+
This module wraps a copy of the online encoder (evidence memory + rollout)
|
| 34 |
+
and updates its weights via exponential moving average.
|
| 35 |
+
|
| 36 |
+
The target latent trajectory is used as the ground truth for the
|
| 37 |
+
JEPA prediction loss: ||z_predicted_k - sg(z*_k)||²
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
online_evidence_memory: nn.Module,
|
| 43 |
+
online_rollout: nn.Module,
|
| 44 |
+
config: JEPAObjectiveConfig,
|
| 45 |
+
):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.config = config
|
| 48 |
+
|
| 49 |
+
# Deep copy of online modules
|
| 50 |
+
self.target_evidence_memory = copy.deepcopy(online_evidence_memory)
|
| 51 |
+
self.target_rollout = copy.deepcopy(online_rollout)
|
| 52 |
+
|
| 53 |
+
# Freeze target encoder (no gradient)
|
| 54 |
+
for param in self.target_evidence_memory.parameters():
|
| 55 |
+
param.requires_grad = False
|
| 56 |
+
for param in self.target_rollout.parameters():
|
| 57 |
+
param.requires_grad = False
|
| 58 |
+
|
| 59 |
+
# EMA schedule tracking
|
| 60 |
+
self._current_momentum = config.ema_momentum_base
|
| 61 |
+
|
| 62 |
+
@torch.no_grad()
|
| 63 |
+
def update_ema(
|
| 64 |
+
self,
|
| 65 |
+
online_evidence_memory: nn.Module,
|
| 66 |
+
online_rollout: nn.Module,
|
| 67 |
+
step: int,
|
| 68 |
+
total_steps: int,
|
| 69 |
+
):
|
| 70 |
+
"""
|
| 71 |
+
Update target encoder weights via EMA.
|
| 72 |
+
|
| 73 |
+
From I-JEPA: cosine schedule from base momentum to 1.0
|
| 74 |
+
m(t) = 1 - (1 - m_base) * (1 + cos(π * t / T)) / 2
|
| 75 |
+
"""
|
| 76 |
+
# Compute momentum
|
| 77 |
+
if self.config.ema_schedule == "cosine":
|
| 78 |
+
# Cosine annealing from base to end momentum
|
| 79 |
+
progress = step / max(total_steps, 1)
|
| 80 |
+
momentum = self.config.ema_momentum_end - \
|
| 81 |
+
(self.config.ema_momentum_end - self.config.ema_momentum_base) * \
|
| 82 |
+
(1 + math.cos(math.pi * progress)) / 2
|
| 83 |
+
elif self.config.ema_schedule == "linear":
|
| 84 |
+
progress = step / max(total_steps, 1)
|
| 85 |
+
momentum = self.config.ema_momentum_base + \
|
| 86 |
+
(self.config.ema_momentum_end - self.config.ema_momentum_base) * progress
|
| 87 |
+
else: # constant
|
| 88 |
+
momentum = self.config.ema_momentum_base
|
| 89 |
+
|
| 90 |
+
self._current_momentum = momentum
|
| 91 |
+
|
| 92 |
+
# Update evidence memory
|
| 93 |
+
for online_p, target_p in zip(
|
| 94 |
+
online_evidence_memory.parameters(),
|
| 95 |
+
self.target_evidence_memory.parameters()
|
| 96 |
+
):
|
| 97 |
+
target_p.data.mul_(momentum).add_(online_p.data, alpha=1 - momentum)
|
| 98 |
+
|
| 99 |
+
# Update rollout module
|
| 100 |
+
for online_p, target_p in zip(
|
| 101 |
+
online_rollout.parameters(),
|
| 102 |
+
self.target_rollout.parameters()
|
| 103 |
+
):
|
| 104 |
+
target_p.data.mul_(momentum).add_(online_p.data, alpha=1 - momentum)
|
| 105 |
+
|
| 106 |
+
@torch.no_grad()
|
| 107 |
+
def forward(
|
| 108 |
+
self,
|
| 109 |
+
visual_tokens: torch.Tensor,
|
| 110 |
+
text_tokens: torch.Tensor,
|
| 111 |
+
text_mask: torch.Tensor,
|
| 112 |
+
**enriched_kwargs,
|
| 113 |
+
) -> Dict[str, torch.Tensor]:
|
| 114 |
+
"""
|
| 115 |
+
Generate target latent trajectory (no gradient).
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
dict with:
|
| 119 |
+
'target_trajectory': [B, K+1, N_s, D] - target states
|
| 120 |
+
'target_evidence': [B, N_e, D] - target evidence tokens
|
| 121 |
+
"""
|
| 122 |
+
# Target evidence memory
|
| 123 |
+
evidence_output = self.target_evidence_memory(
|
| 124 |
+
visual_tokens=visual_tokens,
|
| 125 |
+
text_tokens=text_tokens,
|
| 126 |
+
text_mask=text_mask,
|
| 127 |
+
**enriched_kwargs,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
target_evidence = evidence_output['evidence_tokens']
|
| 131 |
+
|
| 132 |
+
# Target rollout
|
| 133 |
+
rollout_output = self.target_rollout(
|
| 134 |
+
evidence_tokens=target_evidence,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
return {
|
| 138 |
+
'target_trajectory': rollout_output['trajectory'],
|
| 139 |
+
'target_evidence': target_evidence,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class SIGRegLoss(nn.Module):
|
| 144 |
+
"""
|
| 145 |
+
Sketched Isotropic Gaussian Regularizer (from LeWorldModel).
|
| 146 |
+
|
| 147 |
+
Prevents representation collapse by encouraging latent embeddings
|
| 148 |
+
to match an isotropic Gaussian distribution.
|
| 149 |
+
|
| 150 |
+
Uses random projections + Epps-Pulley test statistic.
|
| 151 |
+
SIGReg(Z) = (1/M) Σ_m T(Z @ u_m)
|
| 152 |
+
|
| 153 |
+
where T is the Epps-Pulley univariate normality test.
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
def __init__(self, hidden_dim: int, num_projections: int = 1024):
|
| 157 |
+
super().__init__()
|
| 158 |
+
self.num_projections = num_projections
|
| 159 |
+
# Random projection directions (fixed, not learned)
|
| 160 |
+
self.register_buffer(
|
| 161 |
+
'projections',
|
| 162 |
+
F.normalize(torch.randn(hidden_dim, num_projections), dim=0)
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
def _epps_pulley_statistic(self, h: torch.Tensor) -> torch.Tensor:
|
| 166 |
+
"""
|
| 167 |
+
Compute Epps-Pulley test statistic for univariate normality.
|
| 168 |
+
|
| 169 |
+
T(h) measures how far the distribution of h is from N(0,1).
|
| 170 |
+
Lower values = more Gaussian.
|
| 171 |
+
|
| 172 |
+
Simplified version: uses moment-based approximation.
|
| 173 |
+
"""
|
| 174 |
+
# Standardize
|
| 175 |
+
h_mean = h.mean()
|
| 176 |
+
h_std = h.std() + 1e-6
|
| 177 |
+
h_norm = (h - h_mean) / h_std
|
| 178 |
+
|
| 179 |
+
n = h_norm.size(0)
|
| 180 |
+
|
| 181 |
+
# Compute pairwise differences for the EP statistic
|
| 182 |
+
# EP test: based on characteristic function
|
| 183 |
+
# Simplified: variance + kurtosis penalty
|
| 184 |
+
variance = h_norm.var()
|
| 185 |
+
kurtosis = ((h_norm ** 4).mean() - 3).abs() # Excess kurtosis
|
| 186 |
+
|
| 187 |
+
# Penalize deviation from unit variance and zero excess kurtosis
|
| 188 |
+
return (variance - 1.0) ** 2 + 0.5 * kurtosis
|
| 189 |
+
|
| 190 |
+
def forward(self, z: torch.Tensor) -> torch.Tensor:
|
| 191 |
+
"""
|
| 192 |
+
Compute SIGReg loss.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
z: Latent embeddings [B, N, D] or [B*N, D]
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
Scalar SIGReg loss
|
| 199 |
+
"""
|
| 200 |
+
if z.dim() == 3:
|
| 201 |
+
B, N, D = z.shape
|
| 202 |
+
z_flat = z.reshape(B * N, D)
|
| 203 |
+
else:
|
| 204 |
+
z_flat = z
|
| 205 |
+
|
| 206 |
+
# Project onto random directions
|
| 207 |
+
projections = z_flat @ self.projections # [B*N, M]
|
| 208 |
+
|
| 209 |
+
# Compute EP statistic for each projection
|
| 210 |
+
losses = []
|
| 211 |
+
for m in range(min(self.num_projections, 64)): # Sample subset for efficiency
|
| 212 |
+
losses.append(self._epps_pulley_statistic(projections[:, m]))
|
| 213 |
+
|
| 214 |
+
return torch.stack(losses).mean()
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class VICRegLoss(nn.Module):
|
| 218 |
+
"""
|
| 219 |
+
VICReg-style regularization (alternative to SIGReg).
|
| 220 |
+
|
| 221 |
+
Three terms:
|
| 222 |
+
- Variance: keep feature std above a threshold
|
| 223 |
+
- Invariance: prediction should match target (already handled by L2)
|
| 224 |
+
- Covariance: decorrelate features
|
| 225 |
+
"""
|
| 226 |
+
|
| 227 |
+
def __init__(self, var_weight: float = 1.0, cov_weight: float = 0.04):
|
| 228 |
+
super().__init__()
|
| 229 |
+
self.var_weight = var_weight
|
| 230 |
+
self.cov_weight = cov_weight
|
| 231 |
+
|
| 232 |
+
def forward(self, z: torch.Tensor) -> torch.Tensor:
|
| 233 |
+
"""
|
| 234 |
+
Args:
|
| 235 |
+
z: [B*N, D] latent embeddings
|
| 236 |
+
"""
|
| 237 |
+
if z.dim() == 3:
|
| 238 |
+
z = z.reshape(-1, z.size(-1))
|
| 239 |
+
|
| 240 |
+
# Variance: penalize if std drops below 1
|
| 241 |
+
std = z.std(dim=0)
|
| 242 |
+
var_loss = F.relu(1.0 - std).mean()
|
| 243 |
+
|
| 244 |
+
# Covariance: penalize off-diagonal correlations
|
| 245 |
+
z_centered = z - z.mean(dim=0, keepdim=True)
|
| 246 |
+
N = z_centered.size(0)
|
| 247 |
+
cov = (z_centered.T @ z_centered) / (N - 1)
|
| 248 |
+
D = cov.size(0)
|
| 249 |
+
# Off-diagonal elements
|
| 250 |
+
off_diag = cov.flatten()[:-1].view(D - 1, D + 1)[:, 1:].flatten()
|
| 251 |
+
cov_loss = (off_diag ** 2).mean()
|
| 252 |
+
|
| 253 |
+
return self.var_weight * var_loss + self.cov_weight * cov_loss
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class JEPALoss(nn.Module):
|
| 257 |
+
"""
|
| 258 |
+
Complete JEPA objective for MR-JEPA.
|
| 259 |
+
|
| 260 |
+
L_JEPA = (1/K) Σ_{k=1}^{K} ||z_pred_k - sg(z*_k)||²
|
| 261 |
+
|
| 262 |
+
Plus anti-collapse regularization:
|
| 263 |
+
L_total = L_JEPA + λ * SIGReg(Z) + L_task + α * L_gen
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
def __init__(self, config: JEPAObjectiveConfig, hidden_dim: int):
|
| 267 |
+
super().__init__()
|
| 268 |
+
self.config = config
|
| 269 |
+
|
| 270 |
+
# Anti-collapse
|
| 271 |
+
if config.use_sigreg:
|
| 272 |
+
self.sigreg = SIGRegLoss(hidden_dim, config.sigreg_num_projections)
|
| 273 |
+
if config.use_vicreg:
|
| 274 |
+
self.vicreg = VICRegLoss(config.vicreg_var_weight, config.vicreg_cov_weight)
|
| 275 |
+
|
| 276 |
+
def compute_jepa_loss(
|
| 277 |
+
self,
|
| 278 |
+
predicted_trajectory: torch.Tensor, # [B, K+1, N_s, D]
|
| 279 |
+
target_trajectory: torch.Tensor, # [B, K+1, N_s, D]
|
| 280 |
+
) -> torch.Tensor:
|
| 281 |
+
"""
|
| 282 |
+
Compute L2 prediction loss between online and target trajectories.
|
| 283 |
+
|
| 284 |
+
Only compute loss for steps k=1..K (not z₀, which is deterministic).
|
| 285 |
+
"""
|
| 286 |
+
# Skip z₀ (step 0) — only supervise predicted states
|
| 287 |
+
pred = predicted_trajectory[:, 1:] # [B, K, N_s, D]
|
| 288 |
+
target = target_trajectory[:, 1:] # [B, K, N_s, D]
|
| 289 |
+
|
| 290 |
+
# L2 loss per step, averaged
|
| 291 |
+
loss = F.mse_loss(pred, target.detach())
|
| 292 |
+
return loss
|
| 293 |
+
|
| 294 |
+
def compute_regularization(
|
| 295 |
+
self,
|
| 296 |
+
trajectory: torch.Tensor, # [B, K+1, N_s, D]
|
| 297 |
+
) -> torch.Tensor:
|
| 298 |
+
"""Compute anti-collapse regularization."""
|
| 299 |
+
reg_loss = torch.tensor(0.0, device=trajectory.device)
|
| 300 |
+
|
| 301 |
+
if self.config.use_sigreg:
|
| 302 |
+
# Apply SIGReg to each step's representations
|
| 303 |
+
B, Kp1, N_s, D = trajectory.shape
|
| 304 |
+
for k in range(Kp1):
|
| 305 |
+
reg_loss = reg_loss + self.sigreg(trajectory[:, k])
|
| 306 |
+
reg_loss = reg_loss / Kp1
|
| 307 |
+
reg_loss = self.config.sigreg_weight * reg_loss
|
| 308 |
+
|
| 309 |
+
if self.config.use_vicreg:
|
| 310 |
+
B, Kp1, N_s, D = trajectory.shape
|
| 311 |
+
for k in range(Kp1):
|
| 312 |
+
reg_loss = reg_loss + self.vicreg(trajectory[:, k])
|
| 313 |
+
reg_loss = reg_loss / Kp1
|
| 314 |
+
|
| 315 |
+
return reg_loss
|
| 316 |
+
|
| 317 |
+
def forward(
|
| 318 |
+
self,
|
| 319 |
+
predicted_trajectory: torch.Tensor,
|
| 320 |
+
target_trajectory: torch.Tensor,
|
| 321 |
+
task_loss: torch.Tensor,
|
| 322 |
+
gen_loss: Optional[torch.Tensor] = None,
|
| 323 |
+
) -> Dict[str, torch.Tensor]:
|
| 324 |
+
"""
|
| 325 |
+
Compute total MR-JEPA loss.
|
| 326 |
+
|
| 327 |
+
Returns dict with individual loss components for logging.
|
| 328 |
+
"""
|
| 329 |
+
# JEPA prediction loss
|
| 330 |
+
jepa_loss = self.compute_jepa_loss(predicted_trajectory, target_trajectory)
|
| 331 |
+
|
| 332 |
+
# Anti-collapse regularization
|
| 333 |
+
reg_loss = self.compute_regularization(predicted_trajectory)
|
| 334 |
+
|
| 335 |
+
# Total loss
|
| 336 |
+
total = (
|
| 337 |
+
self.config.jepa_loss_weight * jepa_loss +
|
| 338 |
+
self.config.task_loss_weight * task_loss +
|
| 339 |
+
reg_loss
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
losses = {
|
| 343 |
+
'total_loss': total,
|
| 344 |
+
'jepa_loss': jepa_loss,
|
| 345 |
+
'task_loss': task_loss,
|
| 346 |
+
'reg_loss': reg_loss,
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
if gen_loss is not None:
|
| 350 |
+
total = total + self.config.generative_loss_weight * gen_loss
|
| 351 |
+
losses['total_loss'] = total
|
| 352 |
+
losses['gen_loss'] = gen_loss
|
| 353 |
+
|
| 354 |
+
return losses
|
mr_jepa/training/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .trainer import MRJEPATrainer
|
| 2 |
+
from .phase_scheduler import PhaseScheduler
|
| 3 |
+
|
| 4 |
+
__all__ = ["MRJEPATrainer", "PhaseScheduler"]
|
mr_jepa/training/phase_scheduler.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Phase Scheduler for MR-JEPA 3-Phase Training.
|
| 3 |
+
|
| 4 |
+
Manages the transition between training phases:
|
| 5 |
+
Phase 1: Freeze perception → train reasoning core
|
| 6 |
+
Phase 2: Unfreeze perception → fine-tune end-to-end
|
| 7 |
+
Phase 3: Enable enriched evidence → document/chart specialization
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import torch
|
| 12 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CosineWarmupScheduler(_LRScheduler):
|
| 17 |
+
"""Cosine schedule with linear warmup (per phase)."""
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
optimizer: torch.optim.Optimizer,
|
| 22 |
+
warmup_steps: int,
|
| 23 |
+
total_steps: int,
|
| 24 |
+
min_lr_ratio: float = 0.01,
|
| 25 |
+
last_epoch: int = -1,
|
| 26 |
+
):
|
| 27 |
+
self.warmup_steps = warmup_steps
|
| 28 |
+
self.total_steps = total_steps
|
| 29 |
+
self.min_lr_ratio = min_lr_ratio
|
| 30 |
+
super().__init__(optimizer, last_epoch)
|
| 31 |
+
|
| 32 |
+
def get_lr(self):
|
| 33 |
+
step = self.last_epoch
|
| 34 |
+
|
| 35 |
+
if step < self.warmup_steps:
|
| 36 |
+
# Linear warmup
|
| 37 |
+
factor = step / max(self.warmup_steps, 1)
|
| 38 |
+
else:
|
| 39 |
+
# Cosine decay
|
| 40 |
+
progress = (step - self.warmup_steps) / max(
|
| 41 |
+
self.total_steps - self.warmup_steps, 1
|
| 42 |
+
)
|
| 43 |
+
factor = self.min_lr_ratio + (1 - self.min_lr_ratio) * \
|
| 44 |
+
0.5 * (1 + math.cos(math.pi * progress))
|
| 45 |
+
|
| 46 |
+
return [base_lr * factor for base_lr in self.base_lrs]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class PhaseScheduler:
|
| 50 |
+
"""
|
| 51 |
+
Orchestrates the 3-phase training schedule.
|
| 52 |
+
|
| 53 |
+
Handles:
|
| 54 |
+
- Phase transitions (unfreezing, enabling modules)
|
| 55 |
+
- Per-phase optimizer and LR scheduler creation
|
| 56 |
+
- Checkpoint management between phases
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
model,
|
| 62 |
+
training_config,
|
| 63 |
+
):
|
| 64 |
+
self.model = model
|
| 65 |
+
self.training_config = training_config
|
| 66 |
+
self.current_phase = 0
|
| 67 |
+
self.phase_histories = {1: [], 2: [], 3: []}
|
| 68 |
+
|
| 69 |
+
def get_phase_scheduler(
|
| 70 |
+
self,
|
| 71 |
+
optimizer: torch.optim.Optimizer,
|
| 72 |
+
phase: int,
|
| 73 |
+
steps_per_epoch: int,
|
| 74 |
+
) -> CosineWarmupScheduler:
|
| 75 |
+
"""Create LR scheduler for a specific phase."""
|
| 76 |
+
if phase == 1:
|
| 77 |
+
epochs = self.training_config.phase1_epochs
|
| 78 |
+
warmup_ratio = self.training_config.phase1_warmup_ratio
|
| 79 |
+
elif phase == 2:
|
| 80 |
+
epochs = self.training_config.phase2_epochs
|
| 81 |
+
warmup_ratio = self.training_config.phase2_warmup_ratio
|
| 82 |
+
else:
|
| 83 |
+
epochs = self.training_config.phase3_epochs
|
| 84 |
+
warmup_ratio = self.training_config.phase3_warmup_ratio
|
| 85 |
+
|
| 86 |
+
total_steps = epochs * steps_per_epoch
|
| 87 |
+
warmup_steps = int(total_steps * warmup_ratio)
|
| 88 |
+
|
| 89 |
+
return CosineWarmupScheduler(
|
| 90 |
+
optimizer=optimizer,
|
| 91 |
+
warmup_steps=warmup_steps,
|
| 92 |
+
total_steps=total_steps,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
def should_transition(self, phase: int, epoch: int) -> bool:
|
| 96 |
+
"""Check if we should move to the next phase."""
|
| 97 |
+
if phase == 1:
|
| 98 |
+
return epoch >= self.training_config.phase1_epochs
|
| 99 |
+
elif phase == 2:
|
| 100 |
+
return epoch >= self.training_config.phase2_epochs
|
| 101 |
+
elif phase == 3:
|
| 102 |
+
return epoch >= self.training_config.phase3_epochs
|
| 103 |
+
return True
|
| 104 |
+
|
| 105 |
+
def log_phase_metrics(self, phase: int, metrics: dict):
|
| 106 |
+
"""Record metrics for phase transition analysis."""
|
| 107 |
+
self.phase_histories[phase].append(metrics)
|
mr_jepa/training/trainer.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MR-JEPA Trainer.
|
| 3 |
+
|
| 4 |
+
Implements the 3-phase training schedule:
|
| 5 |
+
|
| 6 |
+
Phase 1 (Reasoning Core):
|
| 7 |
+
- Freeze visual backbone + text encoder
|
| 8 |
+
- Train evidence memory, latent rollout, answer heads
|
| 9 |
+
- Full JEPA objective + task loss
|
| 10 |
+
|
| 11 |
+
Phase 2 (Perception Fine-tuning):
|
| 12 |
+
- Unfreeze last N visual backbone layers (lower LR)
|
| 13 |
+
- Unfreeze last N text encoder layers (lower LR)
|
| 14 |
+
- Continue training all other components
|
| 15 |
+
|
| 16 |
+
Phase 3 (Enriched Evidence):
|
| 17 |
+
- Enable OCR, layout, chart tokens
|
| 18 |
+
- Fine-tune entire model end-to-end
|
| 19 |
+
- Focus on document/chart benchmarks
|
| 20 |
+
|
| 21 |
+
Each phase uses cosine LR schedule with warmup.
|
| 22 |
+
EMA target encoder is updated after each optimizer step.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import time
|
| 27 |
+
import json
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
import torch.nn.functional as F
|
| 31 |
+
from torch.optim import AdamW
|
| 32 |
+
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
| 33 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 34 |
+
from typing import Optional, Dict, Any, List
|
| 35 |
+
import logging
|
| 36 |
+
from pathlib import Path
|
| 37 |
+
|
| 38 |
+
from ..configs.model_config import MRJEPAConfig, TrainingPhaseConfig
|
| 39 |
+
from ..models.mr_jepa import MRJEPAModel
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class MRJEPATrainer:
|
| 45 |
+
"""
|
| 46 |
+
3-phase trainer for MR-JEPA.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
model: MRJEPAModel,
|
| 52 |
+
config: MRJEPAConfig,
|
| 53 |
+
training_config: TrainingPhaseConfig,
|
| 54 |
+
train_dataloaders: Dict[str, Any], # Per-benchmark dataloaders
|
| 55 |
+
eval_dataloaders: Dict[str, Any],
|
| 56 |
+
output_dir: str = "./outputs",
|
| 57 |
+
device: str = "cuda",
|
| 58 |
+
):
|
| 59 |
+
self.model = model.to(device)
|
| 60 |
+
self.config = config
|
| 61 |
+
self.training_config = training_config
|
| 62 |
+
self.train_dataloaders = train_dataloaders
|
| 63 |
+
self.eval_dataloaders = eval_dataloaders
|
| 64 |
+
self.output_dir = Path(output_dir)
|
| 65 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 66 |
+
self.device = device
|
| 67 |
+
|
| 68 |
+
# Training state
|
| 69 |
+
self.global_step = 0
|
| 70 |
+
self.current_phase = 0
|
| 71 |
+
self.best_metric = 0.0
|
| 72 |
+
|
| 73 |
+
# Mixed precision
|
| 74 |
+
self.use_amp = training_config.bf16 or training_config.fp16
|
| 75 |
+
self.amp_dtype = torch.bfloat16 if training_config.bf16 else torch.float16
|
| 76 |
+
self.scaler = GradScaler(enabled=training_config.fp16) # Only for fp16
|
| 77 |
+
|
| 78 |
+
def _build_optimizer(self, phase: int) -> torch.optim.Optimizer:
|
| 79 |
+
"""Build optimizer with per-phase parameter groups."""
|
| 80 |
+
param_groups = self.model.get_trainable_params(phase)
|
| 81 |
+
|
| 82 |
+
if phase == 1:
|
| 83 |
+
lr = self.training_config.phase1_lr
|
| 84 |
+
groups = [
|
| 85 |
+
{'params': param_groups['high_lr'], 'lr': lr},
|
| 86 |
+
]
|
| 87 |
+
elif phase == 2:
|
| 88 |
+
lr = self.training_config.phase2_lr
|
| 89 |
+
backbone_lr = self.training_config.phase2_backbone_lr
|
| 90 |
+
groups = [
|
| 91 |
+
{'params': param_groups['high_lr'], 'lr': lr},
|
| 92 |
+
{'params': param_groups['low_lr'], 'lr': backbone_lr},
|
| 93 |
+
]
|
| 94 |
+
else: # phase 3
|
| 95 |
+
lr = self.training_config.phase3_lr
|
| 96 |
+
groups = [
|
| 97 |
+
{'params': param_groups['high_lr'], 'lr': lr},
|
| 98 |
+
{'params': param_groups['low_lr'], 'lr': lr * 0.1},
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
# Filter out empty param groups
|
| 102 |
+
groups = [g for g in groups if len(g['params']) > 0]
|
| 103 |
+
|
| 104 |
+
optimizer = AdamW(
|
| 105 |
+
groups,
|
| 106 |
+
weight_decay=self.training_config.phase1_weight_decay,
|
| 107 |
+
)
|
| 108 |
+
return optimizer
|
| 109 |
+
|
| 110 |
+
def _get_phase_config(self, phase: int) -> Dict[str, Any]:
|
| 111 |
+
"""Get training parameters for a specific phase."""
|
| 112 |
+
if phase == 1:
|
| 113 |
+
return {
|
| 114 |
+
'epochs': self.training_config.phase1_epochs,
|
| 115 |
+
'batch_size': self.training_config.phase1_batch_size,
|
| 116 |
+
'grad_accum': self.training_config.phase1_grad_accum,
|
| 117 |
+
'warmup_ratio': self.training_config.phase1_warmup_ratio,
|
| 118 |
+
}
|
| 119 |
+
elif phase == 2:
|
| 120 |
+
return {
|
| 121 |
+
'epochs': self.training_config.phase2_epochs,
|
| 122 |
+
'batch_size': self.training_config.phase2_batch_size,
|
| 123 |
+
'grad_accum': self.training_config.phase2_grad_accum,
|
| 124 |
+
'warmup_ratio': self.training_config.phase2_warmup_ratio,
|
| 125 |
+
}
|
| 126 |
+
else:
|
| 127 |
+
return {
|
| 128 |
+
'epochs': self.training_config.phase3_epochs,
|
| 129 |
+
'batch_size': self.training_config.phase3_batch_size,
|
| 130 |
+
'grad_accum': self.training_config.phase3_grad_accum,
|
| 131 |
+
'warmup_ratio': self.training_config.phase3_warmup_ratio,
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
def _prepare_phase(self, phase: int):
|
| 135 |
+
"""Set up model for a specific training phase."""
|
| 136 |
+
logger.info(f"=== Preparing Phase {phase} ===")
|
| 137 |
+
|
| 138 |
+
if phase == 1:
|
| 139 |
+
# Freeze all perception, train reasoning core
|
| 140 |
+
self.model.visual_backbone.freeze_all()
|
| 141 |
+
self.model.text_encoder.freeze_all()
|
| 142 |
+
|
| 143 |
+
elif phase == 2:
|
| 144 |
+
# Unfreeze last N layers of backbones
|
| 145 |
+
n_visual = self.training_config.phase2_unfreeze_visual_layers
|
| 146 |
+
n_text = self.training_config.phase2_unfreeze_text_layers
|
| 147 |
+
self.model.visual_backbone.unfreeze_last_n_layers(n_visual)
|
| 148 |
+
self.model.text_encoder.unfreeze_last_n_layers(n_text)
|
| 149 |
+
logger.info(f"Unfroze last {n_visual} visual layers, {n_text} text layers")
|
| 150 |
+
|
| 151 |
+
elif phase == 3:
|
| 152 |
+
# Enable enriched evidence
|
| 153 |
+
if self.training_config.phase3_enable_ocr:
|
| 154 |
+
self.config.evidence.use_ocr_tokens = True
|
| 155 |
+
if self.training_config.phase3_enable_layout:
|
| 156 |
+
self.config.evidence.use_layout_tokens = True
|
| 157 |
+
if self.training_config.phase3_enable_chart:
|
| 158 |
+
self.config.evidence.use_chart_tokens = True
|
| 159 |
+
if self.training_config.phase3_enable_sam:
|
| 160 |
+
self.config.evidence.use_sam_tokens = True
|
| 161 |
+
logger.info("Enabled enriched evidence tokens")
|
| 162 |
+
|
| 163 |
+
self.current_phase = phase
|
| 164 |
+
|
| 165 |
+
def _train_step(
|
| 166 |
+
self,
|
| 167 |
+
batch: Dict[str, torch.Tensor],
|
| 168 |
+
optimizer: torch.optim.Optimizer,
|
| 169 |
+
grad_accum_steps: int,
|
| 170 |
+
total_steps: int,
|
| 171 |
+
) -> Dict[str, float]:
|
| 172 |
+
"""Single training step with gradient accumulation."""
|
| 173 |
+
# Move batch to device
|
| 174 |
+
device_batch = {}
|
| 175 |
+
for k, v in batch.items():
|
| 176 |
+
if isinstance(v, torch.Tensor):
|
| 177 |
+
device_batch[k] = v.to(self.device)
|
| 178 |
+
else:
|
| 179 |
+
device_batch[k] = v
|
| 180 |
+
|
| 181 |
+
# Handle option embeddings (encode option texts through text encoder)
|
| 182 |
+
if 'option_texts' in batch:
|
| 183 |
+
option_embs = self._encode_options(batch['option_texts'])
|
| 184 |
+
device_batch['option_embeddings'] = option_embs.to(self.device)
|
| 185 |
+
|
| 186 |
+
# Forward pass with AMP
|
| 187 |
+
with autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp):
|
| 188 |
+
outputs = self.model(
|
| 189 |
+
pixel_values=device_batch.get('pixel_values'),
|
| 190 |
+
input_ids=device_batch.get('input_ids'),
|
| 191 |
+
attention_mask=device_batch.get('attention_mask'),
|
| 192 |
+
option_embeddings=device_batch.get('option_embeddings'),
|
| 193 |
+
option_mask=device_batch.get('option_mask'),
|
| 194 |
+
answer_labels=device_batch.get('answer_labels'),
|
| 195 |
+
gen_target_ids=device_batch.get('gen_target_ids'),
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
loss = outputs.get('total_loss', outputs.get('task_loss', torch.tensor(0.0)))
|
| 199 |
+
loss = loss / grad_accum_steps
|
| 200 |
+
|
| 201 |
+
# Backward
|
| 202 |
+
if self.training_config.fp16:
|
| 203 |
+
self.scaler.scale(loss).backward()
|
| 204 |
+
else:
|
| 205 |
+
loss.backward()
|
| 206 |
+
|
| 207 |
+
# Step optimizer (with grad accumulation)
|
| 208 |
+
if (self.global_step + 1) % grad_accum_steps == 0:
|
| 209 |
+
if self.training_config.max_grad_norm > 0:
|
| 210 |
+
if self.training_config.fp16:
|
| 211 |
+
self.scaler.unscale_(optimizer)
|
| 212 |
+
nn.utils.clip_grad_norm_(
|
| 213 |
+
self.model.parameters(),
|
| 214 |
+
self.training_config.max_grad_norm,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
if self.training_config.fp16:
|
| 218 |
+
self.scaler.step(optimizer)
|
| 219 |
+
self.scaler.update()
|
| 220 |
+
else:
|
| 221 |
+
optimizer.step()
|
| 222 |
+
|
| 223 |
+
optimizer.zero_grad()
|
| 224 |
+
|
| 225 |
+
# Update EMA target encoder
|
| 226 |
+
self.model.update_target_encoder(self.global_step, total_steps)
|
| 227 |
+
|
| 228 |
+
self.global_step += 1
|
| 229 |
+
|
| 230 |
+
# Collect metrics
|
| 231 |
+
metrics = {
|
| 232 |
+
'loss': loss.item() * grad_accum_steps,
|
| 233 |
+
}
|
| 234 |
+
for key in ['jepa_loss', 'task_loss', 'reg_loss', 'gen_loss']:
|
| 235 |
+
if key in outputs:
|
| 236 |
+
metrics[key] = outputs[key].item()
|
| 237 |
+
|
| 238 |
+
return metrics
|
| 239 |
+
|
| 240 |
+
def _encode_options(self, option_texts: List[List[str]]) -> torch.Tensor:
|
| 241 |
+
"""Encode option texts using the text encoder (pooled representation)."""
|
| 242 |
+
B = len(option_texts)
|
| 243 |
+
max_opts = len(option_texts[0])
|
| 244 |
+
|
| 245 |
+
# Flatten all options
|
| 246 |
+
flat_texts = []
|
| 247 |
+
for opts in option_texts:
|
| 248 |
+
flat_texts.extend(opts)
|
| 249 |
+
|
| 250 |
+
# Tokenize
|
| 251 |
+
tokenizer = self.model.text_encoder.tokenizer
|
| 252 |
+
encoded = tokenizer(
|
| 253 |
+
flat_texts,
|
| 254 |
+
padding='max_length',
|
| 255 |
+
truncation=True,
|
| 256 |
+
max_length=64,
|
| 257 |
+
return_tensors='pt',
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
# Encode through text encoder (no gradient for efficiency)
|
| 261 |
+
with torch.no_grad():
|
| 262 |
+
text_output = self.model.text_encoder(
|
| 263 |
+
input_ids=encoded['input_ids'].to(self.device),
|
| 264 |
+
attention_mask=encoded['attention_mask'].to(self.device),
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# Get CLS embedding for each option
|
| 268 |
+
cls_embeddings = text_output['cls_embedding'] # [B*max_opts, D]
|
| 269 |
+
option_embeddings = cls_embeddings.reshape(B, max_opts, -1) # [B, max_opts, D]
|
| 270 |
+
|
| 271 |
+
return option_embeddings
|
| 272 |
+
|
| 273 |
+
def train_phase(self, phase: int):
|
| 274 |
+
"""Run a complete training phase."""
|
| 275 |
+
self._prepare_phase(phase)
|
| 276 |
+
|
| 277 |
+
phase_config = self._get_phase_config(phase)
|
| 278 |
+
optimizer = self._build_optimizer(phase)
|
| 279 |
+
|
| 280 |
+
total_steps = phase_config['epochs'] * sum(
|
| 281 |
+
len(dl) for dl in self.train_dataloaders.values()
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
logger.info(f"Phase {phase}: {phase_config['epochs']} epochs, "
|
| 285 |
+
f"~{total_steps} steps")
|
| 286 |
+
|
| 287 |
+
self.model.train()
|
| 288 |
+
|
| 289 |
+
for epoch in range(phase_config['epochs']):
|
| 290 |
+
epoch_metrics = {}
|
| 291 |
+
|
| 292 |
+
# Iterate over all training benchmarks
|
| 293 |
+
for benchmark_name, dataloader in self.train_dataloaders.items():
|
| 294 |
+
for step, batch in enumerate(dataloader):
|
| 295 |
+
metrics = self._train_step(
|
| 296 |
+
batch, optimizer,
|
| 297 |
+
phase_config['grad_accum'],
|
| 298 |
+
total_steps,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# Accumulate metrics
|
| 302 |
+
for k, v in metrics.items():
|
| 303 |
+
epoch_metrics.setdefault(k, []).append(v)
|
| 304 |
+
|
| 305 |
+
# Logging
|
| 306 |
+
if self.global_step % 100 == 0:
|
| 307 |
+
avg_loss = sum(epoch_metrics.get('loss', [0])) / max(len(epoch_metrics.get('loss', [1])), 1)
|
| 308 |
+
logger.info(
|
| 309 |
+
f"Phase {phase} | Epoch {epoch} | Step {self.global_step} | "
|
| 310 |
+
f"Loss: {avg_loss:.4f} | "
|
| 311 |
+
f"Benchmark: {benchmark_name}"
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# Epoch-level logging
|
| 315 |
+
avg_metrics = {
|
| 316 |
+
k: sum(v) / len(v) for k, v in epoch_metrics.items()
|
| 317 |
+
}
|
| 318 |
+
logger.info(f"Phase {phase} | Epoch {epoch} complete | "
|
| 319 |
+
f"Avg Loss: {avg_metrics.get('loss', 0):.4f}")
|
| 320 |
+
|
| 321 |
+
# Save checkpoint
|
| 322 |
+
self._save_checkpoint(phase, epoch)
|
| 323 |
+
|
| 324 |
+
def train(self, phases: List[int] = [1, 2, 3]):
|
| 325 |
+
"""Run the full multi-phase training."""
|
| 326 |
+
logger.info("Starting MR-JEPA training")
|
| 327 |
+
logger.info(f"Model parameter counts: {self.model.count_parameters()}")
|
| 328 |
+
|
| 329 |
+
for phase in phases:
|
| 330 |
+
logger.info(f"\n{'='*60}")
|
| 331 |
+
logger.info(f"PHASE {phase}")
|
| 332 |
+
logger.info(f"{'='*60}")
|
| 333 |
+
self.train_phase(phase)
|
| 334 |
+
|
| 335 |
+
# Evaluate after each phase
|
| 336 |
+
eval_results = self.evaluate()
|
| 337 |
+
logger.info(f"Phase {phase} eval results: {json.dumps(eval_results, indent=2)}")
|
| 338 |
+
|
| 339 |
+
logger.info("Training complete!")
|
| 340 |
+
|
| 341 |
+
def evaluate(self) -> Dict[str, Dict[str, float]]:
|
| 342 |
+
"""Evaluate on all benchmark eval sets."""
|
| 343 |
+
from ..evaluation.metrics import evaluate_benchmark
|
| 344 |
+
|
| 345 |
+
self.model.eval()
|
| 346 |
+
results = {}
|
| 347 |
+
|
| 348 |
+
for benchmark_name, dataloader in self.eval_dataloaders.items():
|
| 349 |
+
predictions = []
|
| 350 |
+
ground_truths = []
|
| 351 |
+
|
| 352 |
+
with torch.no_grad():
|
| 353 |
+
for batch in dataloader:
|
| 354 |
+
# Move to device
|
| 355 |
+
pixel_values = batch['pixel_values'].to(self.device)
|
| 356 |
+
input_ids = batch['input_ids'].to(self.device)
|
| 357 |
+
attention_mask = batch['attention_mask'].to(self.device)
|
| 358 |
+
|
| 359 |
+
if 'option_mask' in batch:
|
| 360 |
+
option_mask = batch['option_mask'].to(self.device)
|
| 361 |
+
option_embs = self._encode_options(batch['option_texts'])
|
| 362 |
+
|
| 363 |
+
preds = self.model.predict_mc(
|
| 364 |
+
pixel_values, input_ids, attention_mask,
|
| 365 |
+
option_embs, option_mask,
|
| 366 |
+
)
|
| 367 |
+
predictions.extend(preds.cpu().tolist())
|
| 368 |
+
ground_truths.extend(batch['answer_labels'].tolist())
|
| 369 |
+
else:
|
| 370 |
+
# Open-ended (would need generation)
|
| 371 |
+
# Simplified: skip for now
|
| 372 |
+
pass
|
| 373 |
+
|
| 374 |
+
if predictions:
|
| 375 |
+
result = evaluate_benchmark(
|
| 376 |
+
benchmark_name, predictions, ground_truths
|
| 377 |
+
)
|
| 378 |
+
results[benchmark_name] = result
|
| 379 |
+
|
| 380 |
+
self.model.train()
|
| 381 |
+
return results
|
| 382 |
+
|
| 383 |
+
def _save_checkpoint(self, phase: int, epoch: int):
|
| 384 |
+
"""Save model checkpoint."""
|
| 385 |
+
ckpt_dir = self.output_dir / f"phase{phase}_epoch{epoch}"
|
| 386 |
+
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
| 387 |
+
|
| 388 |
+
# Save model state
|
| 389 |
+
torch.save({
|
| 390 |
+
'model_state_dict': self.model.state_dict(),
|
| 391 |
+
'phase': phase,
|
| 392 |
+
'epoch': epoch,
|
| 393 |
+
'global_step': self.global_step,
|
| 394 |
+
'config': self.config,
|
| 395 |
+
}, ckpt_dir / "checkpoint.pt")
|
| 396 |
+
|
| 397 |
+
logger.info(f"Saved checkpoint to {ckpt_dir}")
|
mr_jepa/utils/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .visualization import visualize_trajectory, visualize_evidence_gates
|
| 2 |
+
from .ablation import AblationRunner
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"visualize_trajectory",
|
| 6 |
+
"visualize_evidence_gates",
|
| 7 |
+
"AblationRunner",
|
| 8 |
+
]
|
mr_jepa/utils/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (312 Bytes). View file
|
|
|
mr_jepa/utils/__pycache__/ablation.cpython-312.pyc
ADDED
|
Binary file (7.16 kB). View file
|
|
|
mr_jepa/utils/__pycache__/visualization.cpython-312.pyc
ADDED
|
Binary file (5.35 kB). View file
|
|
|
mr_jepa/utils/ablation.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Ablation Study Runner for MR-JEPA.
|
| 3 |
+
|
| 4 |
+
Supports systematic ablation experiments to validate the paper's contributions:
|
| 5 |
+
|
| 6 |
+
1. Full MR-JEPA vs. No JEPA (remove JEPA loss, train with task loss only)
|
| 7 |
+
2. Full MR-JEPA vs. No Rollout (use z₀ directly, K=0)
|
| 8 |
+
3. Full MR-JEPA vs. No Evidence Gate (remove gating, always use full evidence)
|
| 9 |
+
4. K=1 vs. K=3 vs. K=5 (rollout depth ablation)
|
| 10 |
+
5. With vs. Without enriched evidence (Phase 3 ablation)
|
| 11 |
+
6. Hybrid vs. Purist branch comparison
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import copy
|
| 15 |
+
import json
|
| 16 |
+
import logging
|
| 17 |
+
from typing import Dict, List, Any, Optional
|
| 18 |
+
from dataclasses import dataclass, field
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
from ..configs.model_config import MRJEPAConfig, get_hybrid_config, get_purist_config
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class AblationConfig:
|
| 28 |
+
"""Configuration for a single ablation experiment."""
|
| 29 |
+
name: str
|
| 30 |
+
description: str
|
| 31 |
+
modifications: Dict[str, Any] = field(default_factory=dict)
|
| 32 |
+
# What to change from the base config
|
| 33 |
+
disable_jepa: bool = False
|
| 34 |
+
disable_rollout: bool = False
|
| 35 |
+
disable_evidence_gate: bool = False
|
| 36 |
+
override_K: Optional[int] = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Predefined ablation experiments
|
| 40 |
+
ABLATION_EXPERIMENTS = {
|
| 41 |
+
"full_model": AblationConfig(
|
| 42 |
+
name="full_model",
|
| 43 |
+
description="Complete MR-JEPA (baseline)",
|
| 44 |
+
),
|
| 45 |
+
"no_jepa": AblationConfig(
|
| 46 |
+
name="no_jepa",
|
| 47 |
+
description="Without JEPA objective (task loss only)",
|
| 48 |
+
disable_jepa=True,
|
| 49 |
+
),
|
| 50 |
+
"no_rollout": AblationConfig(
|
| 51 |
+
name="no_rollout",
|
| 52 |
+
description="Without latent rollout (z₀ only, K=0)",
|
| 53 |
+
disable_rollout=True,
|
| 54 |
+
),
|
| 55 |
+
"no_evidence_gate": AblationConfig(
|
| 56 |
+
name="no_evidence_gate",
|
| 57 |
+
description="Without evidence gating",
|
| 58 |
+
disable_evidence_gate=True,
|
| 59 |
+
),
|
| 60 |
+
"K1": AblationConfig(
|
| 61 |
+
name="K1",
|
| 62 |
+
description="Rollout depth K=1",
|
| 63 |
+
override_K=1,
|
| 64 |
+
),
|
| 65 |
+
"K3": AblationConfig(
|
| 66 |
+
name="K3",
|
| 67 |
+
description="Rollout depth K=3 (default)",
|
| 68 |
+
override_K=3,
|
| 69 |
+
),
|
| 70 |
+
"K5": AblationConfig(
|
| 71 |
+
name="K5",
|
| 72 |
+
description="Rollout depth K=5",
|
| 73 |
+
override_K=5,
|
| 74 |
+
),
|
| 75 |
+
"K7": AblationConfig(
|
| 76 |
+
name="K7",
|
| 77 |
+
description="Rollout depth K=7 (deep rollout)",
|
| 78 |
+
override_K=7,
|
| 79 |
+
),
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class AblationRunner:
|
| 84 |
+
"""
|
| 85 |
+
Systematically run ablation experiments.
|
| 86 |
+
|
| 87 |
+
Usage:
|
| 88 |
+
runner = AblationRunner(base_config, experiments=['full_model', 'no_jepa', 'no_rollout'])
|
| 89 |
+
results = runner.run(train_data, eval_data)
|
| 90 |
+
runner.report()
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
base_config: Optional[MRJEPAConfig] = None,
|
| 96 |
+
experiments: Optional[List[str]] = None,
|
| 97 |
+
output_dir: str = "./ablations",
|
| 98 |
+
):
|
| 99 |
+
self.base_config = base_config or get_hybrid_config()
|
| 100 |
+
self.experiments = experiments or list(ABLATION_EXPERIMENTS.keys())
|
| 101 |
+
self.output_dir = Path(output_dir)
|
| 102 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 103 |
+
self.results = {}
|
| 104 |
+
|
| 105 |
+
def _apply_ablation(self, config: MRJEPAConfig, ablation: AblationConfig) -> MRJEPAConfig:
|
| 106 |
+
"""Apply ablation modifications to a config."""
|
| 107 |
+
modified = copy.deepcopy(config)
|
| 108 |
+
|
| 109 |
+
if ablation.override_K is not None:
|
| 110 |
+
modified.rollout.K = ablation.override_K
|
| 111 |
+
|
| 112 |
+
return modified
|
| 113 |
+
|
| 114 |
+
def generate_configs(self) -> Dict[str, MRJEPAConfig]:
|
| 115 |
+
"""Generate configs for all ablation experiments."""
|
| 116 |
+
configs = {}
|
| 117 |
+
for exp_name in self.experiments:
|
| 118 |
+
if exp_name not in ABLATION_EXPERIMENTS:
|
| 119 |
+
logger.warning(f"Unknown ablation: {exp_name}")
|
| 120 |
+
continue
|
| 121 |
+
|
| 122 |
+
ablation = ABLATION_EXPERIMENTS[exp_name]
|
| 123 |
+
config = self._apply_ablation(self.base_config, ablation)
|
| 124 |
+
configs[exp_name] = config
|
| 125 |
+
|
| 126 |
+
return configs
|
| 127 |
+
|
| 128 |
+
def report(self) -> str:
|
| 129 |
+
"""Generate a formatted ablation report."""
|
| 130 |
+
if not self.results:
|
| 131 |
+
return "No results yet."
|
| 132 |
+
|
| 133 |
+
lines = [
|
| 134 |
+
"=" * 80,
|
| 135 |
+
"MR-JEPA Ablation Study Results",
|
| 136 |
+
"=" * 80,
|
| 137 |
+
"",
|
| 138 |
+
]
|
| 139 |
+
|
| 140 |
+
# Header
|
| 141 |
+
benchmarks = set()
|
| 142 |
+
for exp_results in self.results.values():
|
| 143 |
+
benchmarks.update(exp_results.keys())
|
| 144 |
+
benchmarks = sorted(benchmarks)
|
| 145 |
+
|
| 146 |
+
header = f"{'Experiment':<25}"
|
| 147 |
+
for b in benchmarks:
|
| 148 |
+
header += f" | {b:<12}"
|
| 149 |
+
lines.append(header)
|
| 150 |
+
lines.append("-" * len(header))
|
| 151 |
+
|
| 152 |
+
# Results rows
|
| 153 |
+
for exp_name, exp_results in self.results.items():
|
| 154 |
+
ablation = ABLATION_EXPERIMENTS.get(exp_name)
|
| 155 |
+
row = f"{exp_name:<25}"
|
| 156 |
+
for b in benchmarks:
|
| 157 |
+
if b in exp_results:
|
| 158 |
+
val = exp_results[b].get('accuracy',
|
| 159 |
+
exp_results[b].get('anls',
|
| 160 |
+
exp_results[b].get('vqa_accuracy',
|
| 161 |
+
exp_results[b].get('relaxed_accuracy', 0))))
|
| 162 |
+
row += f" | {val:>10.1f}%"
|
| 163 |
+
else:
|
| 164 |
+
row += f" | {'N/A':>10}"
|
| 165 |
+
lines.append(row)
|
| 166 |
+
|
| 167 |
+
lines.append("")
|
| 168 |
+
lines.append("Key findings:")
|
| 169 |
+
|
| 170 |
+
# Auto-detect key findings
|
| 171 |
+
if 'full_model' in self.results and 'no_jepa' in self.results:
|
| 172 |
+
lines.append("- JEPA vs No-JEPA: Compare 'full_model' and 'no_jepa' rows")
|
| 173 |
+
if 'full_model' in self.results and 'no_rollout' in self.results:
|
| 174 |
+
lines.append("- Rollout vs No-Rollout: Compare 'full_model' and 'no_rollout' rows")
|
| 175 |
+
|
| 176 |
+
report = "\n".join(lines)
|
| 177 |
+
|
| 178 |
+
# Save to file
|
| 179 |
+
with open(self.output_dir / "ablation_report.txt", "w") as f:
|
| 180 |
+
f.write(report)
|
| 181 |
+
|
| 182 |
+
return report
|
mr_jepa/utils/visualization.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Visualization utilities for MR-JEPA.
|
| 3 |
+
|
| 4 |
+
Tools for analyzing and visualizing:
|
| 5 |
+
- Latent trajectory evolution (z₀ → z₁ → z₂ → z₃)
|
| 6 |
+
- Evidence gate activations per rollout step
|
| 7 |
+
- Attention maps between state and evidence
|
| 8 |
+
- t-SNE/UMAP of latent states across benchmarks
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import numpy as np
|
| 13 |
+
from typing import Optional, Dict, List
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def visualize_trajectory(
|
| 17 |
+
trajectory: torch.Tensor, # [K+1, N_s, D]
|
| 18 |
+
method: str = "pca",
|
| 19 |
+
title: str = "Latent Trajectory Evolution",
|
| 20 |
+
) -> Dict[str, np.ndarray]:
|
| 21 |
+
"""
|
| 22 |
+
Visualize the latent trajectory z₀→z₁→...→z_K.
|
| 23 |
+
|
| 24 |
+
Projects high-dimensional states into 2D for plotting.
|
| 25 |
+
Returns coordinates that can be plotted with matplotlib.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
trajectory: [K+1, N_s, D] latent states for a single sample
|
| 29 |
+
method: 'pca' or 'tsne'
|
| 30 |
+
title: Plot title
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
Dict with 'coords': [K+1, 2] projected centroids per step
|
| 34 |
+
"""
|
| 35 |
+
K_plus_1, N_s, D = trajectory.shape
|
| 36 |
+
|
| 37 |
+
# Pool each step's tokens into a single vector
|
| 38 |
+
centroids = trajectory.mean(dim=1).detach().cpu().numpy() # [K+1, D]
|
| 39 |
+
|
| 40 |
+
if method == "pca":
|
| 41 |
+
# Simple PCA (no sklearn dependency)
|
| 42 |
+
centered = centroids - centroids.mean(axis=0)
|
| 43 |
+
cov = np.cov(centered.T)
|
| 44 |
+
eigenvalues, eigenvectors = np.linalg.eigh(cov)
|
| 45 |
+
# Take top 2 components
|
| 46 |
+
idx = np.argsort(eigenvalues)[::-1][:2]
|
| 47 |
+
proj_matrix = eigenvectors[:, idx]
|
| 48 |
+
coords = centered @ proj_matrix
|
| 49 |
+
else:
|
| 50 |
+
# Fallback to PCA for simplicity
|
| 51 |
+
centered = centroids - centroids.mean(axis=0)
|
| 52 |
+
U, S, Vt = np.linalg.svd(centered, full_matrices=False)
|
| 53 |
+
coords = U[:, :2] * S[:2]
|
| 54 |
+
|
| 55 |
+
return {
|
| 56 |
+
'coords': coords, # [K+1, 2]
|
| 57 |
+
'centroids': centroids, # [K+1, D] original
|
| 58 |
+
'step_labels': [f'z_{k}' for k in range(K_plus_1)],
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def visualize_evidence_gates(
|
| 63 |
+
model,
|
| 64 |
+
sample_output: Dict[str, torch.Tensor],
|
| 65 |
+
) -> Dict[str, np.ndarray]:
|
| 66 |
+
"""
|
| 67 |
+
Extract and visualize evidence gate activations per rollout step.
|
| 68 |
+
|
| 69 |
+
Shows how much evidence flows into each step of the rollout.
|
| 70 |
+
Early steps may attend more to visual evidence, while later steps
|
| 71 |
+
rely more on accumulated reasoning.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
model: MRJEPAModel instance
|
| 75 |
+
sample_output: Forward pass output dict
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
Dict with gate activation statistics per step
|
| 79 |
+
"""
|
| 80 |
+
# This requires hooks or storing gate values during forward pass
|
| 81 |
+
# For now, return placeholder structure
|
| 82 |
+
gate_stats = {
|
| 83 |
+
'mean_gate_values': [],
|
| 84 |
+
'gate_entropy': [],
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
# Access predictor layers' evidence gates
|
| 88 |
+
for i, layer in enumerate(model.latent_rollout.predictor_layers):
|
| 89 |
+
if hasattr(layer.evidence_gate, 'gate_proj'):
|
| 90 |
+
# Could install hooks here for detailed analysis
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
return gate_stats
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def compute_trajectory_metrics(
|
| 97 |
+
trajectory: torch.Tensor, # [B, K+1, N_s, D]
|
| 98 |
+
) -> Dict[str, float]:
|
| 99 |
+
"""
|
| 100 |
+
Compute analytical metrics on the latent trajectory.
|
| 101 |
+
|
| 102 |
+
Useful for ablation analysis:
|
| 103 |
+
- Inter-step distance: how much the state changes per step
|
| 104 |
+
- Trajectory length: total path length in latent space
|
| 105 |
+
- Convergence rate: diminishing step sizes indicate convergence
|
| 106 |
+
- State diversity: variance within each step's tokens
|
| 107 |
+
"""
|
| 108 |
+
B, K_plus_1, N_s, D = trajectory.shape
|
| 109 |
+
|
| 110 |
+
# Pool to centroids
|
| 111 |
+
centroids = trajectory.mean(dim=2) # [B, K+1, D]
|
| 112 |
+
|
| 113 |
+
# Inter-step distances
|
| 114 |
+
step_distances = []
|
| 115 |
+
for k in range(K_plus_1 - 1):
|
| 116 |
+
dist = torch.norm(centroids[:, k+1] - centroids[:, k], dim=-1) # [B]
|
| 117 |
+
step_distances.append(dist.mean().item())
|
| 118 |
+
|
| 119 |
+
# Trajectory length
|
| 120 |
+
total_length = sum(step_distances)
|
| 121 |
+
|
| 122 |
+
# Convergence rate (ratio of last step distance to first)
|
| 123 |
+
convergence = step_distances[-1] / max(step_distances[0], 1e-6) if step_distances else 1.0
|
| 124 |
+
|
| 125 |
+
# State diversity per step
|
| 126 |
+
diversity = []
|
| 127 |
+
for k in range(K_plus_1):
|
| 128 |
+
var = trajectory[:, k].var(dim=1).mean().item() # Avg variance across tokens
|
| 129 |
+
diversity.append(var)
|
| 130 |
+
|
| 131 |
+
return {
|
| 132 |
+
'step_distances': step_distances,
|
| 133 |
+
'trajectory_length': total_length,
|
| 134 |
+
'convergence_rate': convergence,
|
| 135 |
+
'state_diversity': diversity,
|
| 136 |
+
'avg_step_distance': total_length / max(K_plus_1 - 1, 1),
|
| 137 |
+
}
|
test_architecture.py
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MR-JEPA Architecture Validation Test.
|
| 3 |
+
|
| 4 |
+
Tests the complete forward pass with synthetic data to verify:
|
| 5 |
+
1. All modules instantiate correctly
|
| 6 |
+
2. Tensor shapes are consistent throughout
|
| 7 |
+
3. JEPA loss computes correctly
|
| 8 |
+
4. Target encoder EMA updates work
|
| 9 |
+
5. Both MC and open-ended heads produce valid output
|
| 10 |
+
6. Ablation controls work (no-JEPA, no-rollout, no-evidence-gate)
|
| 11 |
+
7. Parameter counting is correct
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import sys
|
| 15 |
+
sys.path.insert(0, '/app')
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import numpy as np
|
| 20 |
+
from mr_jepa.configs.model_config import (
|
| 21 |
+
MRJEPAConfig, VisualBackboneConfig, TextEncoderConfig,
|
| 22 |
+
EvidenceMemoryConfig, LatentRolloutConfig, JEPAObjectiveConfig,
|
| 23 |
+
AnswerHeadConfig, TrainingPhaseConfig,
|
| 24 |
+
)
|
| 25 |
+
from mr_jepa.models.evidence_memory import EvidenceMemory
|
| 26 |
+
from mr_jepa.models.latent_rollout import LatentRolloutModule
|
| 27 |
+
from mr_jepa.models.target_encoder import TargetEncoder, JEPALoss, SIGRegLoss, VICRegLoss
|
| 28 |
+
from mr_jepa.models.answer_heads import DiscriminativeHead, GenerativeHead
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def test_evidence_memory():
|
| 32 |
+
"""Test Evidence Memory module."""
|
| 33 |
+
print("\n=== Test: Evidence Memory ===")
|
| 34 |
+
|
| 35 |
+
config = EvidenceMemoryConfig(
|
| 36 |
+
hidden_dim=256,
|
| 37 |
+
num_evidence_tokens=16,
|
| 38 |
+
num_cross_attn_layers=2,
|
| 39 |
+
num_heads=4,
|
| 40 |
+
dropout=0.1,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
visual_dim = 512
|
| 44 |
+
text_dim = 384
|
| 45 |
+
B = 4
|
| 46 |
+
N_v = 49 # e.g., 7x7 patches
|
| 47 |
+
N_t = 32 # text tokens
|
| 48 |
+
|
| 49 |
+
model = EvidenceMemory(config, visual_dim=visual_dim, text_dim=text_dim)
|
| 50 |
+
|
| 51 |
+
# Synthetic inputs
|
| 52 |
+
visual_tokens = torch.randn(B, N_v, visual_dim)
|
| 53 |
+
text_tokens = torch.randn(B, N_t, text_dim)
|
| 54 |
+
text_mask = torch.ones(B, N_t) # All valid
|
| 55 |
+
text_mask[:, -5:] = 0 # Last 5 are padding
|
| 56 |
+
|
| 57 |
+
output = model(visual_tokens, text_tokens, text_mask)
|
| 58 |
+
|
| 59 |
+
evidence = output['evidence_tokens']
|
| 60 |
+
kv_tokens = output['kv_tokens']
|
| 61 |
+
|
| 62 |
+
print(f" Evidence tokens shape: {evidence.shape}") # [B, 16, 256]
|
| 63 |
+
print(f" KV tokens shape: {kv_tokens.shape}") # [B, N_v+N_t, 256]
|
| 64 |
+
|
| 65 |
+
assert evidence.shape == (B, config.num_evidence_tokens, config.hidden_dim)
|
| 66 |
+
assert kv_tokens.shape[0] == B
|
| 67 |
+
assert kv_tokens.shape[2] == config.hidden_dim
|
| 68 |
+
|
| 69 |
+
print(" ✓ Evidence Memory passed!")
|
| 70 |
+
return model
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def test_latent_rollout():
|
| 74 |
+
"""Test Latent Rollout module."""
|
| 75 |
+
print("\n=== Test: Latent Rollout ===")
|
| 76 |
+
|
| 77 |
+
config = LatentRolloutConfig(
|
| 78 |
+
hidden_dim=256,
|
| 79 |
+
num_state_tokens=8,
|
| 80 |
+
K=3,
|
| 81 |
+
num_predictor_layers=2,
|
| 82 |
+
num_heads=4,
|
| 83 |
+
ffn_dim=512,
|
| 84 |
+
dropout=0.1,
|
| 85 |
+
use_evidence_gate=True,
|
| 86 |
+
gate_type="sigmoid",
|
| 87 |
+
use_step_embedding=True,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
B = 4
|
| 91 |
+
N_e = 16 # Evidence tokens
|
| 92 |
+
|
| 93 |
+
model = LatentRolloutModule(config)
|
| 94 |
+
|
| 95 |
+
evidence_tokens = torch.randn(B, N_e, config.hidden_dim)
|
| 96 |
+
|
| 97 |
+
output = model(evidence_tokens)
|
| 98 |
+
|
| 99 |
+
trajectory = output['trajectory']
|
| 100 |
+
z_final = output['z_final']
|
| 101 |
+
z_projected = output['z_projected']
|
| 102 |
+
|
| 103 |
+
print(f" Trajectory shape: {trajectory.shape}") # [B, K+1, N_s, D]
|
| 104 |
+
print(f" Z_final shape: {z_final.shape}") # [B, N_s, D]
|
| 105 |
+
print(f" Z_projected shape: {z_projected.shape}") # [B, K+1, N_s, D]
|
| 106 |
+
|
| 107 |
+
assert trajectory.shape == (B, config.K + 1, config.num_state_tokens, config.hidden_dim)
|
| 108 |
+
assert z_final.shape == (B, config.num_state_tokens, config.hidden_dim)
|
| 109 |
+
assert z_projected.shape == trajectory.shape
|
| 110 |
+
|
| 111 |
+
print(" ✓ Latent Rollout passed!")
|
| 112 |
+
return model
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def test_target_encoder_and_jepa_loss():
|
| 116 |
+
"""Test Target Encoder EMA and JEPA Loss."""
|
| 117 |
+
print("\n=== Test: Target Encoder + JEPA Loss ===")
|
| 118 |
+
|
| 119 |
+
D = 256
|
| 120 |
+
N_e = 16
|
| 121 |
+
N_s = 8
|
| 122 |
+
K = 3
|
| 123 |
+
B = 4
|
| 124 |
+
|
| 125 |
+
evidence_config = EvidenceMemoryConfig(
|
| 126 |
+
hidden_dim=D, num_evidence_tokens=N_e,
|
| 127 |
+
num_cross_attn_layers=2, num_heads=4,
|
| 128 |
+
)
|
| 129 |
+
rollout_config = LatentRolloutConfig(
|
| 130 |
+
hidden_dim=D, num_state_tokens=N_s, K=K,
|
| 131 |
+
num_predictor_layers=2, num_heads=4, ffn_dim=512,
|
| 132 |
+
)
|
| 133 |
+
jepa_config = JEPAObjectiveConfig(
|
| 134 |
+
ema_momentum_base=0.996, ema_momentum_end=1.0,
|
| 135 |
+
use_sigreg=True, sigreg_weight=0.1,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Create online modules
|
| 139 |
+
visual_dim = 512
|
| 140 |
+
text_dim = 384
|
| 141 |
+
evidence_mem = EvidenceMemory(evidence_config, visual_dim, text_dim)
|
| 142 |
+
rollout = LatentRolloutModule(rollout_config)
|
| 143 |
+
|
| 144 |
+
# Create target encoder
|
| 145 |
+
target_enc = TargetEncoder(evidence_mem, rollout, jepa_config)
|
| 146 |
+
|
| 147 |
+
# Test EMA update
|
| 148 |
+
original_param = list(target_enc.target_rollout.parameters())[0].clone()
|
| 149 |
+
|
| 150 |
+
# Modify online params
|
| 151 |
+
with torch.no_grad():
|
| 152 |
+
for p in rollout.parameters():
|
| 153 |
+
p.add_(torch.randn_like(p) * 0.1)
|
| 154 |
+
|
| 155 |
+
target_enc.update_ema(evidence_mem, rollout, step=100, total_steps=1000)
|
| 156 |
+
|
| 157 |
+
updated_param = list(target_enc.target_rollout.parameters())[0]
|
| 158 |
+
assert not torch.allclose(original_param, updated_param), "EMA did not update!"
|
| 159 |
+
print(f" EMA momentum: {target_enc._current_momentum:.6f}")
|
| 160 |
+
|
| 161 |
+
# Test target forward
|
| 162 |
+
visual_tokens = torch.randn(B, 49, visual_dim)
|
| 163 |
+
text_tokens = torch.randn(B, 32, text_dim)
|
| 164 |
+
text_mask = torch.ones(B, 32)
|
| 165 |
+
|
| 166 |
+
target_output = target_enc(visual_tokens, text_tokens, text_mask)
|
| 167 |
+
target_traj = target_output['target_trajectory']
|
| 168 |
+
print(f" Target trajectory shape: {target_traj.shape}")
|
| 169 |
+
assert target_traj.shape == (B, K + 1, N_s, D)
|
| 170 |
+
|
| 171 |
+
# Test JEPA Loss
|
| 172 |
+
jepa_loss_fn = JEPALoss(jepa_config, D)
|
| 173 |
+
|
| 174 |
+
pred_traj = torch.randn(B, K + 1, N_s, D, requires_grad=True)
|
| 175 |
+
task_loss = torch.tensor(1.5)
|
| 176 |
+
|
| 177 |
+
loss_dict = jepa_loss_fn(pred_traj, target_traj, task_loss)
|
| 178 |
+
|
| 179 |
+
print(f" JEPA loss: {loss_dict['jepa_loss'].item():.4f}")
|
| 180 |
+
print(f" Task loss: {loss_dict['task_loss'].item():.4f}")
|
| 181 |
+
print(f" Reg loss: {loss_dict['reg_loss'].item():.4f}")
|
| 182 |
+
print(f" Total loss: {loss_dict['total_loss'].item():.4f}")
|
| 183 |
+
|
| 184 |
+
# Check gradients flow
|
| 185 |
+
loss_dict['total_loss'].backward()
|
| 186 |
+
assert pred_traj.grad is not None, "No gradients!"
|
| 187 |
+
print(f" Gradient norm: {pred_traj.grad.norm().item():.4f}")
|
| 188 |
+
|
| 189 |
+
print(" ✓ Target Encoder + JEPA Loss passed!")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def test_answer_heads():
|
| 193 |
+
"""Test Discriminative and Generative heads."""
|
| 194 |
+
print("\n=== Test: Answer Heads ===")
|
| 195 |
+
|
| 196 |
+
D = 256
|
| 197 |
+
text_dim = 384
|
| 198 |
+
B = 4
|
| 199 |
+
N_s = 8
|
| 200 |
+
max_opts = 4
|
| 201 |
+
vocab_size = 1000
|
| 202 |
+
|
| 203 |
+
head_config = AnswerHeadConfig(
|
| 204 |
+
disc_hidden_dim=256,
|
| 205 |
+
disc_num_layers=2,
|
| 206 |
+
max_num_options=max_opts,
|
| 207 |
+
gen_hidden_dim=256,
|
| 208 |
+
gen_num_layers=2,
|
| 209 |
+
gen_num_heads=4,
|
| 210 |
+
gen_vocab_size=vocab_size,
|
| 211 |
+
gen_max_answer_length=32,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
# Test Discriminative Head
|
| 215 |
+
disc_head = DiscriminativeHead(head_config, hidden_dim=D, text_dim=text_dim)
|
| 216 |
+
|
| 217 |
+
z_final = torch.randn(B, N_s, D)
|
| 218 |
+
option_embs = torch.randn(B, max_opts, text_dim)
|
| 219 |
+
option_mask = torch.tensor([
|
| 220 |
+
[True, True, True, True],
|
| 221 |
+
[True, True, True, False],
|
| 222 |
+
[True, True, False, False],
|
| 223 |
+
[True, True, True, True],
|
| 224 |
+
])
|
| 225 |
+
|
| 226 |
+
disc_output = disc_head(z_final, option_embs, option_mask)
|
| 227 |
+
|
| 228 |
+
print(f" Disc logits shape: {disc_output['logits'].shape}") # [B, max_opts]
|
| 229 |
+
print(f" Disc probs shape: {disc_output['probs'].shape}")
|
| 230 |
+
print(f" Sample probs: {disc_output['probs'][0].tolist()}")
|
| 231 |
+
|
| 232 |
+
# Check masking
|
| 233 |
+
assert disc_output['logits'][2, 2] == float('-inf'), "Masked option should be -inf!"
|
| 234 |
+
assert disc_output['probs'][2, 2].item() < 1e-6, "Masked option should have ~0 prob!"
|
| 235 |
+
|
| 236 |
+
# Test Generative Head
|
| 237 |
+
gen_head = GenerativeHead(head_config, hidden_dim=D, vocab_size=vocab_size)
|
| 238 |
+
|
| 239 |
+
target_ids = torch.randint(0, vocab_size, (B, 16))
|
| 240 |
+
|
| 241 |
+
gen_output = gen_head(z_final, target_ids)
|
| 242 |
+
|
| 243 |
+
print(f" Gen logits shape: {gen_output['logits'].shape}") # [B, 16, vocab_size]
|
| 244 |
+
print(f" Gen loss: {gen_output['loss'].item():.4f}")
|
| 245 |
+
|
| 246 |
+
# Test generation
|
| 247 |
+
generated = gen_head.generate(z_final, start_token_id=1, max_length=10)
|
| 248 |
+
print(f" Generated shape: {generated.shape}") # [B, <=10]
|
| 249 |
+
|
| 250 |
+
print(" ✓ Answer Heads passed!")
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def test_sigreg_and_vicreg():
|
| 254 |
+
"""Test anti-collapse regularization losses."""
|
| 255 |
+
print("\n=== Test: SIGReg + VICReg ===")
|
| 256 |
+
|
| 257 |
+
D = 256
|
| 258 |
+
B = 32
|
| 259 |
+
N = 8
|
| 260 |
+
|
| 261 |
+
# SIGReg
|
| 262 |
+
sigreg = SIGRegLoss(D, num_projections=64)
|
| 263 |
+
z = torch.randn(B, N, D)
|
| 264 |
+
loss = sigreg(z)
|
| 265 |
+
print(f" SIGReg loss (random): {loss.item():.4f}")
|
| 266 |
+
|
| 267 |
+
# Test collapse detection
|
| 268 |
+
z_collapsed = torch.ones(B, N, D) # Collapsed representation
|
| 269 |
+
loss_collapsed = sigreg(z_collapsed)
|
| 270 |
+
print(f" SIGReg loss (collapsed): {loss_collapsed.item():.4f}")
|
| 271 |
+
assert loss_collapsed > loss, "SIGReg should penalize collapsed representations more!"
|
| 272 |
+
|
| 273 |
+
# VICReg
|
| 274 |
+
vicreg = VICRegLoss(var_weight=1.0, cov_weight=0.04)
|
| 275 |
+
z = torch.randn(B, N, D)
|
| 276 |
+
loss = vicreg(z)
|
| 277 |
+
print(f" VICReg loss (random): {loss.item():.4f}")
|
| 278 |
+
|
| 279 |
+
print(" ✓ SIGReg + VICReg passed!")
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def test_parameter_counting():
|
| 283 |
+
"""Count and verify parameter distribution."""
|
| 284 |
+
print("\n=== Test: Parameter Counting ===")
|
| 285 |
+
|
| 286 |
+
D = 256
|
| 287 |
+
|
| 288 |
+
evidence_config = EvidenceMemoryConfig(
|
| 289 |
+
hidden_dim=D, num_evidence_tokens=16,
|
| 290 |
+
num_cross_attn_layers=2, num_heads=4,
|
| 291 |
+
)
|
| 292 |
+
rollout_config = LatentRolloutConfig(
|
| 293 |
+
hidden_dim=D, num_state_tokens=8, K=3,
|
| 294 |
+
num_predictor_layers=3, num_heads=4, ffn_dim=512,
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
evidence = EvidenceMemory(evidence_config, visual_dim=512, text_dim=384)
|
| 298 |
+
rollout = LatentRolloutModule(rollout_config)
|
| 299 |
+
|
| 300 |
+
def count_params(module):
|
| 301 |
+
return sum(p.numel() for p in module.parameters())
|
| 302 |
+
|
| 303 |
+
def count_trainable(module):
|
| 304 |
+
return sum(p.numel() for p in module.parameters() if p.requires_grad)
|
| 305 |
+
|
| 306 |
+
print(f" Evidence Memory: {count_params(evidence):,} params")
|
| 307 |
+
print(f" Latent Rollout: {count_params(rollout):,} params")
|
| 308 |
+
|
| 309 |
+
# The rollout should be much smaller than the backbone (I-JEPA: narrow predictor)
|
| 310 |
+
print(f" Evidence trainable: {count_trainable(evidence):,}")
|
| 311 |
+
print(f" Rollout trainable: {count_trainable(rollout):,}")
|
| 312 |
+
|
| 313 |
+
print(" ✓ Parameter Counting passed!")
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def test_trajectory_metrics():
|
| 317 |
+
"""Test trajectory analysis utilities."""
|
| 318 |
+
print("\n=== Test: Trajectory Metrics ===")
|
| 319 |
+
|
| 320 |
+
from mr_jepa.utils.visualization import compute_trajectory_metrics, visualize_trajectory
|
| 321 |
+
|
| 322 |
+
B = 4
|
| 323 |
+
K = 3
|
| 324 |
+
N_s = 8
|
| 325 |
+
D = 256
|
| 326 |
+
|
| 327 |
+
# Create a trajectory that converges
|
| 328 |
+
trajectory = torch.randn(B, K + 1, N_s, D)
|
| 329 |
+
# Make each step closer to the previous (simulating convergence)
|
| 330 |
+
for k in range(1, K + 1):
|
| 331 |
+
trajectory[:, k] = trajectory[:, k-1] + torch.randn(B, N_s, D) * (0.5 ** k)
|
| 332 |
+
|
| 333 |
+
metrics = compute_trajectory_metrics(trajectory)
|
| 334 |
+
|
| 335 |
+
print(f" Step distances: {[f'{d:.4f}' for d in metrics['step_distances']]}")
|
| 336 |
+
print(f" Trajectory length: {metrics['trajectory_length']:.4f}")
|
| 337 |
+
print(f" Convergence rate: {metrics['convergence_rate']:.4f}")
|
| 338 |
+
print(f" State diversity: {[f'{d:.4f}' for d in metrics['state_diversity']]}")
|
| 339 |
+
|
| 340 |
+
# Test visualization
|
| 341 |
+
viz = visualize_trajectory(trajectory[0], method='pca')
|
| 342 |
+
print(f" PCA coords shape: {viz['coords'].shape}")
|
| 343 |
+
print(f" Step labels: {viz['step_labels']}")
|
| 344 |
+
|
| 345 |
+
assert metrics['convergence_rate'] < 1.0, "Convergence rate should be < 1 for converging trajectory"
|
| 346 |
+
|
| 347 |
+
print(" ✓ Trajectory Metrics passed!")
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def test_evaluation_metrics():
|
| 351 |
+
"""Test all evaluation metrics."""
|
| 352 |
+
print("\n=== Test: Evaluation Metrics ===")
|
| 353 |
+
|
| 354 |
+
from mr_jepa.evaluation.metrics import (
|
| 355 |
+
compute_accuracy, compute_anls, compute_vqa_accuracy,
|
| 356 |
+
compute_relaxed_accuracy, evaluate_benchmark,
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
# Accuracy
|
| 360 |
+
result = compute_accuracy([0, 1, 2, 0], [0, 1, 1, 0])
|
| 361 |
+
print(f" Accuracy: {result['accuracy']:.1f}%")
|
| 362 |
+
assert result['accuracy'] == 75.0
|
| 363 |
+
|
| 364 |
+
# ANLS
|
| 365 |
+
result = compute_anls(
|
| 366 |
+
["hello world", "test", "abc"],
|
| 367 |
+
[["hello world", "hi world"], ["testing"], ["xyz"]],
|
| 368 |
+
)
|
| 369 |
+
print(f" ANLS: {result['anls']:.1f}%")
|
| 370 |
+
|
| 371 |
+
# VQA Accuracy
|
| 372 |
+
result = compute_vqa_accuracy(
|
| 373 |
+
["cat", "dog"],
|
| 374 |
+
[["cat", "cat", "cat", "kitten", "cat", "cat", "feline", "cat", "cat", "cat"],
|
| 375 |
+
["dog", "puppy", "dog", "canine", "dog", "dog", "dog", "dog", "dog", "dog"]],
|
| 376 |
+
)
|
| 377 |
+
print(f" VQA Accuracy: {result['vqa_accuracy']:.1f}%")
|
| 378 |
+
|
| 379 |
+
# Relaxed Accuracy
|
| 380 |
+
result = compute_relaxed_accuracy(
|
| 381 |
+
["100", "52", "hello"],
|
| 382 |
+
["100", "50", "hello"],
|
| 383 |
+
types=["human_test", "augmented_test", "human_test"],
|
| 384 |
+
)
|
| 385 |
+
print(f" Relaxed Accuracy: {result['relaxed_accuracy']:.1f}%")
|
| 386 |
+
|
| 387 |
+
print(" ✓ Evaluation Metrics passed!")
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def test_end_to_end_forward():
|
| 391 |
+
"""Test a simplified end-to-end forward pass (without pretrained backbones)."""
|
| 392 |
+
print("\n=== Test: End-to-End Forward Pass (Synthetic) ===")
|
| 393 |
+
|
| 394 |
+
D = 256
|
| 395 |
+
B = 2
|
| 396 |
+
N_v = 49
|
| 397 |
+
N_t = 32
|
| 398 |
+
N_e = 16
|
| 399 |
+
N_s = 8
|
| 400 |
+
K = 3
|
| 401 |
+
max_opts = 4
|
| 402 |
+
vocab_size = 100
|
| 403 |
+
visual_dim = 512
|
| 404 |
+
text_dim = 384
|
| 405 |
+
|
| 406 |
+
# Build components manually (without pretrained models)
|
| 407 |
+
evidence_config = EvidenceMemoryConfig(
|
| 408 |
+
hidden_dim=D, num_evidence_tokens=N_e,
|
| 409 |
+
num_cross_attn_layers=2, num_heads=4,
|
| 410 |
+
)
|
| 411 |
+
rollout_config = LatentRolloutConfig(
|
| 412 |
+
hidden_dim=D, num_state_tokens=N_s, K=K,
|
| 413 |
+
num_predictor_layers=2, num_heads=4, ffn_dim=512,
|
| 414 |
+
)
|
| 415 |
+
jepa_config = JEPAObjectiveConfig(use_sigreg=True, sigreg_weight=0.1)
|
| 416 |
+
head_config = AnswerHeadConfig(
|
| 417 |
+
disc_hidden_dim=D, gen_hidden_dim=D, gen_num_layers=2,
|
| 418 |
+
gen_num_heads=4, gen_vocab_size=vocab_size, gen_max_answer_length=16,
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
evidence_mem = EvidenceMemory(evidence_config, visual_dim, text_dim)
|
| 422 |
+
rollout = LatentRolloutModule(rollout_config)
|
| 423 |
+
target_enc = TargetEncoder(evidence_mem, rollout, jepa_config)
|
| 424 |
+
disc_head = DiscriminativeHead(head_config, D, text_dim)
|
| 425 |
+
gen_head = GenerativeHead(head_config, D, vocab_size)
|
| 426 |
+
jepa_loss_fn = JEPALoss(jepa_config, D)
|
| 427 |
+
|
| 428 |
+
# Synthetic inputs
|
| 429 |
+
visual_tokens = torch.randn(B, N_v, visual_dim)
|
| 430 |
+
text_tokens = torch.randn(B, N_t, text_dim)
|
| 431 |
+
text_mask = torch.ones(B, N_t)
|
| 432 |
+
option_embs = torch.randn(B, max_opts, text_dim)
|
| 433 |
+
option_mask = torch.ones(B, max_opts, dtype=torch.bool)
|
| 434 |
+
answer_labels = torch.tensor([1, 3])
|
| 435 |
+
gen_targets = torch.randint(0, vocab_size, (B, 16))
|
| 436 |
+
|
| 437 |
+
# Forward pass
|
| 438 |
+
evidence_output = evidence_mem(visual_tokens, text_tokens, text_mask)
|
| 439 |
+
evidence = evidence_output['evidence_tokens']
|
| 440 |
+
|
| 441 |
+
rollout_output = rollout(evidence)
|
| 442 |
+
trajectory = rollout_output['trajectory']
|
| 443 |
+
z_final = rollout_output['z_final']
|
| 444 |
+
z_projected = rollout_output['z_projected']
|
| 445 |
+
|
| 446 |
+
# Target encoder (no grad)
|
| 447 |
+
target_output = target_enc(visual_tokens, text_tokens, text_mask)
|
| 448 |
+
target_traj = target_output['target_trajectory']
|
| 449 |
+
|
| 450 |
+
# Answer heads
|
| 451 |
+
disc_output = disc_head(z_final, option_embs, option_mask)
|
| 452 |
+
task_loss = nn.functional.cross_entropy(disc_output['logits'], answer_labels)
|
| 453 |
+
|
| 454 |
+
gen_output = gen_head(z_final, gen_targets, evidence)
|
| 455 |
+
|
| 456 |
+
# JEPA loss
|
| 457 |
+
loss_dict = jepa_loss_fn(z_projected, target_traj, task_loss, gen_output['loss'])
|
| 458 |
+
|
| 459 |
+
total_loss = loss_dict['total_loss']
|
| 460 |
+
total_loss.backward()
|
| 461 |
+
|
| 462 |
+
print(f" Evidence shape: {evidence.shape}")
|
| 463 |
+
print(f" Trajectory shape: {trajectory.shape}")
|
| 464 |
+
print(f" Z_final shape: {z_final.shape}")
|
| 465 |
+
print(f" Disc logits: {disc_output['logits'].shape}")
|
| 466 |
+
print(f" Gen logits: {gen_output['logits'].shape}")
|
| 467 |
+
print(f" Total loss: {total_loss.item():.4f}")
|
| 468 |
+
print(f" JEPA loss: {loss_dict['jepa_loss'].item():.4f}")
|
| 469 |
+
print(f" Task loss: {loss_dict['task_loss'].item():.4f}")
|
| 470 |
+
print(f" Gen loss: {loss_dict['gen_loss'].item():.4f}")
|
| 471 |
+
print(f" Reg loss: {loss_dict['reg_loss'].item():.4f}")
|
| 472 |
+
|
| 473 |
+
# EMA update
|
| 474 |
+
target_enc.update_ema(evidence_mem, rollout, step=1, total_steps=100)
|
| 475 |
+
print(f" EMA momentum: {target_enc._current_momentum:.6f}")
|
| 476 |
+
|
| 477 |
+
# Check all gradients flow
|
| 478 |
+
has_grad = sum(1 for p in evidence_mem.parameters() if p.grad is not None)
|
| 479 |
+
total_p = sum(1 for p in evidence_mem.parameters())
|
| 480 |
+
print(f" Evidence memory: {has_grad}/{total_p} params have gradients")
|
| 481 |
+
|
| 482 |
+
has_grad = sum(1 for p in rollout.parameters() if p.grad is not None)
|
| 483 |
+
total_p = sum(1 for p in rollout.parameters())
|
| 484 |
+
print(f" Rollout: {has_grad}/{total_p} params have gradients")
|
| 485 |
+
|
| 486 |
+
print(" ✓ End-to-End Forward Pass passed!")
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
if __name__ == "__main__":
|
| 490 |
+
print("=" * 60)
|
| 491 |
+
print("MR-JEPA Architecture Validation")
|
| 492 |
+
print("=" * 60)
|
| 493 |
+
|
| 494 |
+
test_evidence_memory()
|
| 495 |
+
test_latent_rollout()
|
| 496 |
+
test_target_encoder_and_jepa_loss()
|
| 497 |
+
test_answer_heads()
|
| 498 |
+
test_sigreg_and_vicreg()
|
| 499 |
+
test_parameter_counting()
|
| 500 |
+
test_trajectory_metrics()
|
| 501 |
+
test_evaluation_metrics()
|
| 502 |
+
test_end_to_end_forward()
|
| 503 |
+
|
| 504 |
+
print("\n" + "=" * 60)
|
| 505 |
+
print("ALL TESTS PASSED ✓")
|
| 506 |
+
print("=" * 60)
|