fix: ARCHITECTURE.md — complete ablation table with all 13 experiments, CLI flags, dinov2/loss_fn/sigreg/vicreg ablations, footnote on no_rollout→no_jepa
mr_jepa/ARCHITECTURE.md CHANGED (+31 -23)
|
@@ -29,7 +29,7 @@ The core insight: solving a multimodal question (e.g., "What is the GDP growth s
|
|
| 29 |
│
|
| 30 |
┌─────────┐ JEPA Loss:
|
| 31 |
│Optional:│ SmoothL1/Cosine
|
| 32 |
-
│OCR,SAM, │──────────┘ + SIGReg
|
| 33 |
│Layout │
|
| 34 |
└─────────┘
|
| 35 |
```
|
|
@@ -151,30 +151,28 @@ The online predictor must predict these targets.
|
|
| 151 |
|
| 152 |
### 2.6 JEPA Objective
|
| 153 |
|
| 154 |
-
**Prediction loss** (hybrid branch — SmoothL1
|
| 155 |
```
|
| 156 |
L_JEPA = (1/K) Σ_{k=1}^{K} SmoothL1(z_pred_k, sg(z*_k))
|
| 157 |
```
|
| 158 |
Only steps k=1..K are supervised (z₀ is deterministic from evidence).
|
| 159 |
|
| 160 |
-
**Alternative
|
| 161 |
|
| 162 |
-
**Anti-collapse regularization**
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
```
|
| 166 |
-
Where T is the Epps-Pulley normality test statistic, u_m are random unit vectors.
|
| 167 |
-
This encourages latent embeddings to remain Gaussian-distributed, preventing collapse.
|
| 168 |
-
|
| 169 |
-
**Alternative (hybrid branch)**: VICReg (variance-invariance-covariance) regularization.
|
| 170 |
|
| 171 |
**Total loss**:
|
| 172 |
```
|
| 173 |
-
L_total = L_JEPA + L_task + λ · L_reg + α · L_gen
|
| 174 |
|
| 175 |
Where:
|
| 176 |
L_task = CrossEntropy(disc_head(z_K), answer_label) # MC scoring
|
| 177 |
L_gen = CE(gen_head(z_K), target_answer_tokens) # Short answer (Phase 3)
|
| 178 |
λ = 0.1 (regularization weight)
|
| 179 |
α = 0.1 (generative weight)
|
| 180 |
```
|
|
@@ -241,23 +239,33 @@ Qwen3.5-4B (or SmolLM3-3B) decoder:
|
|
| 241 |
|
| 242 |
## 4. Ablation Experiments
|
| 243 |
|
| 244 |
-
###
|
| 245 |
|
| 246 |
-
|
| 247 |
-
|------------|-------------|-----------------|
|
| 248 |
-
| **Full MR-JEPA** | Baseline | Best overall |
|
| 249 |
-
| **No JEPA** | Remove L_JEPA, train with task loss only | Drops on reasoning-heavy benchmarks |
|
| 250 |
-
| **No Rollout** | K=0, use z₀ directly | Significant drop (proves rollout value) |
|
| 251 |
-
| **No Evidence Gate** | Remove gating | Slight drop (gate helps focus) |
|
| 252 |
-
| **K=1** | Shallow rollout | Worse than K=3 |
|
| 253 |
-
| **K=5** | Deeper rollout | Diminishing returns |
|
| 254 |
-
| **No SIGReg** | Remove anti-collapse | Training instability |
|
| 255 |
-
| **Purist branch** | DINOv3-B, no enriched evidence | Lower absolute scores, but validates JEPA contribution |
|
| 256 |
|
| 257 |
### Cross-benchmark analysis:
|
| 258 |
- JEPA contribution should be highest on **reasoning** benchmarks (MathVista, MMMU, ScienceQA)
|
| 259 |
- Evidence gate contribution should be highest on **evidence-rich** benchmarks (DocVQA, ChartQA)
|
| 260 |
- Enriched evidence (Phase 3) should matter most for **document** benchmarks
|
| 261 |
|
| 262 |
---
|
| 263 |
|
|
|
|
| 29 |
│
|
| 30 |
┌─────────┐ JEPA Loss:
|
| 31 |
│Optional:│ SmoothL1/Cosine
|
| 32 |
+
│OCR,SAM, │──────────┘ + SIGReg/VICReg
|
| 33 |
│Layout │
|
| 34 |
└─────────┘
|
| 35 |
```
|
|
|
|
| 151 |
|
| 152 |
### 2.6 JEPA Objective
|
| 153 |
|
| 154 |
+
**Prediction loss** (hybrid branch — SmoothL1, more robust than L2):
|
| 155 |
```
|
| 156 |
L_JEPA = (1/K) Σ_{k=1}^{K} SmoothL1(z_pred_k, sg(z*_k))
|
| 157 |
```
|
| 158 |
Only steps k=1..K are supervised (z₀ is deterministic from evidence).
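
The prediction loss above can be sketched in plain Python. This is a minimal illustration, not the repository's implementation: the function names `smooth_l1` and `jepa_loss` and the `beta=1.0` threshold are assumptions, and the stop-gradient `sg(·)` is represented only by treating the targets as constants.

```python
def smooth_l1(pred, target, beta=1.0):
    """Mean SmoothL1 loss between two equal-length vectors.

    Quadratic for |diff| < beta, linear beyond it (beta is an assumed default).
    """
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total / len(pred)


def jepa_loss(z_pred, z_target):
    """Average SmoothL1 over rollout steps k = 1..K.

    z_pred, z_target: lists of K latent vectors. z_0 is not included,
    since it is deterministic from the evidence. Targets are plain
    constants here, standing in for sg(z*_k).
    """
    K = len(z_pred)
    return sum(smooth_l1(p, t) for p, t in zip(z_pred, z_target)) / K


# One step matches its target exactly, the other does not.
loss = jepa_loss([[0.0, 0.0], [1.0, 1.0]], [[0.5, 2.0], [1.0, 1.0]])
```

The linear regime for large residuals is what makes SmoothL1 less sensitive to outlier targets than a pure L2 loss.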
|
| 159 |
|
| 160 |
+
**Alternative losses** (ablation):
|
| 161 |
+
- **MSE** (L2): original I-JEPA loss
|
| 162 |
+
- **Cosine**: 1 - cos_sim, used in purist branch
|
| 163 |
|
| 164 |
+
**Anti-collapse regularization**:
|
| 165 |
+
- **SIGReg** (from LeWorldModel): Epps-Pulley normality test on random projections, encourages Gaussian-distributed latents
|
| 166 |
+
- **VICReg**: variance (keep std ≥ 1) + covariance (decorrelate features) regularization
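
As an illustration of the VICReg bullet, the variance and covariance terms can be sketched in pure Python following the standard VICReg formulation (hinge on the per-dimension standard deviation, squared off-diagonal covariance). The function name `vicreg_terms` and the `eps` value are assumptions; how the repository weights or batches these terms is not shown in this diff.

```python
import math


def vicreg_terms(z, eps=1e-4):
    """Return (variance_loss, covariance_loss) for a batch of embeddings.

    z: list of n vectors of dimension d (n >= 2).
    variance_loss   penalises dimensions whose std falls below 1.
    covariance_loss penalises correlation between different dimensions.
    """
    n, d = len(z), len(z[0])
    means = [sum(row[j] for row in z) / n for j in range(d)]
    centered = [[row[j] - means[j] for j in range(d)] for row in z]
    # Unbiased sample covariance matrix (d x d).
    cov = [
        [sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)]
        for i in range(d)
    ]
    var_loss = sum(max(0.0, 1.0 - math.sqrt(cov[j][j] + eps)) for j in range(d)) / d
    cov_loss = sum(cov[i][j] ** 2 for i in range(d) for j in range(d) if i != j) / d
    return var_loss, cov_loss


# A fully collapsed batch: every embedding is identical.
collapsed = vicreg_terms([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]])
# Two perfectly correlated dimensions with healthy variance.
correlated = vicreg_terms([[-2.0, -2.0], [0.0, 0.0], [2.0, 2.0]])
```

A collapsed batch is penalised almost entirely through the variance term, while redundant dimensions are penalised through the covariance term.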
|
| 167 |
|
| 168 |
**Total loss**:
|
| 169 |
```
|
| 170 |
+
L_total = w_jepa · L_JEPA + w_task · L_task + λ · L_reg + α · L_gen
|
| 171 |
|
| 172 |
Where:
|
| 173 |
L_task = CrossEntropy(disc_head(z_K), answer_label) # MC scoring
|
| 174 |
L_gen = CE(gen_head(z_K), target_answer_tokens) # Short answer (Phase 3)
|
| 175 |
+
L_reg = SIGReg and/or VICReg
|
| 176 |
λ = 0.1 (regularization weight)
|
| 177 |
α = 0.1 (generative weight)
|
| 178 |
```
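
The weighted sum can be sketched as follows. The function name `total_loss` is hypothetical, and the defaults `w_jepa = w_task = 1.0` are assumptions, since the text gives values only for λ and α.

```python
def total_loss(l_jepa, l_task, l_reg, l_gen,
               w_jepa=1.0, w_task=1.0, lam=0.1, alpha=0.1):
    """L_total = w_jepa*L_JEPA + w_task*L_task + lam*L_reg + alpha*L_gen.

    lam and alpha follow the values stated in the text; w_jepa and
    w_task default to 1.0 here purely as an assumption.
    """
    return w_jepa * l_jepa + w_task * l_task + lam * l_reg + alpha * l_gen


full = total_loss(1.0, 2.0, 3.0, 4.0)
# One possible way to express a task-loss-only ablation: zero the JEPA weight.
task_only = total_loss(1.0, 2.0, 3.0, 4.0, w_jepa=0.0)
```

Exposing the weights as parameters lets an ablation switch a term off without changing the rest of the objective.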
|
|
|
|
| 239 |
|
| 240 |
## 4. Ablation Experiments
|
| 241 |
|
| 242 |
+
### Complete ablation matrix
|
| 243 |
+
|
| 244 |
+
Each experiment corresponds to one CLI flag combination in `train_mrjepa.py` (the default run uses no flags).
|
| 245 |
+
|
| 246 |
+
| Experiment | CLI flag | Modification | Expected finding |
|
| 247 |
+
|------------|----------|-------------|-----------------|
|
| 248 |
+
| `hybrid_main` | *(default)* | Full model (DINOv3-L, K=3, SmoothL1+VICReg) | Best overall |
|
| 249 |
+
| `no_jepa` | `--no_jepa` | Remove L_JEPA, task loss only | Drops on reasoning-heavy benchmarks |
|
| 250 |
+
| `no_rollout` | `--no_rollout` | K=0, use z₀ directly (also disables JEPA¹) | Significant drop (proves rollout value) |
|
| 251 |
+
| `no_gate` | `--no_evidence_gate` | Remove sigmoid evidence gating | Slight drop (gate helps focus) |
|
| 252 |
+
| `K1` | `--K 1` | Shallow rollout | Worse than K=3 |
|
| 253 |
+
| `K5` | `--K 5` | Deeper rollout | Diminishing returns |
|
| 254 |
+
| `K7` | `--K 7` | Very deep rollout | Overfitting / diminishing returns |
|
| 255 |
+
| `dinov2_ablation` | `--backbone dinov2` | DINOv2-L/14 instead of DINOv3-L/16 | DINOv3 > DINOv2 due to Gram anchoring + RoPE |
|
| 256 |
+
| `mse_loss` | `--loss_fn mse` | MSE (L2) JEPA loss (original I-JEPA) | Slightly worse than SmoothL1 |
|
| 257 |
+
| `cosine_loss` | `--loss_fn cosine` | Cosine similarity JEPA loss | Better for purist, similar for hybrid |
|
| 258 |
+
| `no_sigreg` | `--no_sigreg` | Disable SIGReg anti-collapse | Training instability / representation collapse |
|
| 259 |
+
| `vicreg_only` | `--no_sigreg --use_vicreg` | VICReg only (no SIGReg) | Alternative anti-collapse strategy |
|
| 260 |
+
| `purist` | `--purist` | DINOv3-B, K=5, Cosine+SIGReg, no enriched evidence | Lower absolute scores, but validates JEPA contribution |
|
| 261 |
|
| 262 |
+
¹ `no_rollout` also disables JEPA because with K=0 there is only z₀ — no trajectory to supervise. To test JEPA in isolation, use `--no_jepa` with K>0.
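
The flags in the table, and the coupling described in the footnote, can be sketched with `argparse`. This is a hypothetical reconstruction, not the actual parser in `train_mrjepa.py`: the flag names come from the table, but the helper names `build_parser` and `resolve`, the choice strings `dinov3` and `smooth_l1`, and the defaults are assumptions. The `--purist` preset is only declared here, not expanded.

```python
import argparse


def build_parser():
    """Parser mirroring the flags listed in the ablation table (assumed layout)."""
    p = argparse.ArgumentParser(description="MR-JEPA ablation flags (sketch)")
    p.add_argument("--no_jepa", action="store_true", help="task loss only")
    p.add_argument("--no_rollout", action="store_true", help="K=0, use z0 directly")
    p.add_argument("--no_evidence_gate", action="store_true")
    p.add_argument("--K", type=int, default=3, help="number of rollout steps")
    p.add_argument("--backbone", choices=["dinov3", "dinov2"], default="dinov3")
    p.add_argument("--loss_fn", choices=["smooth_l1", "mse", "cosine"],
                   default="smooth_l1")
    p.add_argument("--no_sigreg", action="store_true")
    p.add_argument("--use_vicreg", action="store_true")
    p.add_argument("--purist", action="store_true")
    return p


def resolve(argv):
    """Parse argv and apply the footnote rule: no rollout implies no JEPA."""
    args = build_parser().parse_args(argv)
    if args.no_rollout:
        args.K = 0            # only z0 exists
        args.no_jepa = True   # no trajectory left to supervise
    return args


no_rollout = resolve(["--no_rollout"])
vicreg_only = resolve(["--no_sigreg", "--use_vicreg"])
```

Resolving the coupling in one place keeps the `no_rollout` and `no_jepa` runs distinguishable: `--no_jepa` alone leaves K at its default.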
|
| 263 |
|
| 264 |
### Cross-benchmark analysis:
|
| 265 |
- JEPA contribution should be highest on **reasoning** benchmarks (MathVista, MMMU, ScienceQA)
|
| 266 |
- Evidence gate contribution should be highest on **evidence-rich** benchmarks (DocVQA, ChartQA)
|
| 267 |
- Enriched evidence (Phase 3) should matter most for **document** benchmarks
|
| 268 |
+
- DINOv3 vs DINOv2 gap should be largest on **fine-grained visual** benchmarks (AI2D, ChartQA)
|
| 269 |
|
| 270 |
---
|
| 271 |
|