JorgeAV committed on
Commit bac1326 · verified · 1 Parent(s): 7741500

fix: ARCHITECTURE.md — complete ablation table with all 13 experiments, CLI flags, dinov2/loss_fn/sigreg/vicreg ablations, footnote on no_rollout→no_jepa

Files changed (1)
  1. mr_jepa/ARCHITECTURE.md +31 -23
mr_jepa/ARCHITECTURE.md CHANGED
@@ -29,7 +29,7 @@ The core insight: solving a multimodal question (e.g., "What is the GDP growth s
 
  ┌─────────┐ JEPA Loss:
  │Optional:│ SmoothL1/Cosine
- │OCR,SAM, │──────────┘ + SIGReg
  │Layout │
  └─────────┘
  ```
@@ -151,30 +151,28 @@ The online predictor must predict these targets.
 
  ### 2.6 JEPA Objective
 
- **Prediction loss** (hybrid branch — SmoothL1 from I-JEPA, more robust than L2):
  ```
  L_JEPA = (1/K) Σ_{k=1}^{K} SmoothL1(z_pred_k, sg(z*_k))
  ```
  Only steps k=1..K are supervised (z₀ is deterministic from evidence).
 
- **Alternative (purist branch)**: Cosine similarity loss.
 
- **Anti-collapse regularization** (from LeWorldModel — SIGReg):
- ```
- L_SIGReg = (1/M) Σ_{m=1}^{M} T(Z · u_m)
- ```
- Where T is the Epps-Pulley normality test statistic, u_m are random unit vectors.
- This encourages latent embeddings to remain Gaussian-distributed, preventing collapse.
-
- **Alternative (hybrid branch)**: VICReg (variance-invariance-covariance) regularization.
 
  **Total loss**:
  ```
- L_total = L_JEPA + L_task + λ · L_reg + α · L_gen
 
  Where:
  L_task = CrossEntropy(disc_head(z_K), answer_label) # MC scoring
  L_gen = CE(gen_head(z_K), target_answer_tokens) # Short answer (Phase 3)
  λ = 0.1 (regularization weight)
  α = 0.1 (generative weight)
  ```
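Review note: the prediction loss in this hunk, `L_JEPA = (1/K) Σ SmoothL1(z_pred_k, sg(z*_k))`, can be illustrated with a minimal pure-Python sketch. The function names `smooth_l1` and `jepa_loss` and the `beta` threshold are assumptions for illustration and are not taken from the repository; the stop-gradient `sg(·)` is implicit here because the targets are plain floats.

```python
from typing import Sequence


def smooth_l1(pred: Sequence[float], target: Sequence[float], beta: float = 1.0) -> float:
    """Mean SmoothL1 distance between two equal-length vectors.

    Quadratic for small errors (|d| < beta), linear for large ones,
    which is why it is less sensitive to outliers than plain L2.
    """
    if len(pred) != len(target) or len(pred) == 0:
        raise ValueError("pred and target must be non-empty and the same length")
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total / len(pred)


def jepa_loss(z_pred: Sequence[Sequence[float]],
              z_target: Sequence[Sequence[float]],
              beta: float = 1.0) -> float:
    """Average SmoothL1 over rollout steps k = 1..K.

    Both arguments hold steps 1..K only; z_0 is excluded because it is
    computed deterministically from the evidence and is not supervised.
    """
    K = len(z_pred)
    if K == 0 or K != len(z_target):
        raise ValueError("need K >= 1 predicted steps and matching targets")
    return sum(smooth_l1(p, t, beta) for p, t in zip(z_pred, z_target)) / K


# Hypothetical 2-step rollout with 2-dimensional latents.
loss = jepa_loss([[0.0, 0.0], [1.0, 3.0]], [[0.5, 0.0], [1.0, 1.0]])
```

With `K = 0` the function raises, which mirrors the footnote in the ablation table: without rollout steps there is no trajectory to supervise.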
@@ -241,23 +239,33 @@ Qwen3.5-4B (or SmolLM3-3B) decoder:
 
  ## 4. Ablation Experiments
 
- ### Key ablations for the paper:
 
- | Experiment | Modification | Expected finding |
- |------------|-------------|-----------------|
- | **Full MR-JEPA** | Baseline | Best overall |
- | **No JEPA** | Remove L_JEPA, train with task loss only | Drops on reasoning-heavy benchmarks |
- | **No Rollout** | K=0, use z₀ directly | Significant drop (proves rollout value) |
- | **No Evidence Gate** | Remove gating | Slight drop (gate helps focus) |
- | **K=1** | Shallow rollout | Worse than K=3 |
- | **K=5** | Deeper rollout | Diminishing returns |
- | **No SIGReg** | Remove anti-collapse | Training instability |
- | **Purist branch** | DINOv3-B, no enriched evidence | Lower absolute scores, but validates JEPA contribution |
 
  ### Cross-benchmark analysis:
  - JEPA contribution should be highest on **reasoning** benchmarks (MathVista, MMMU, ScienceQA)
  - Evidence gate contribution should be highest on **evidence-rich** benchmarks (DocVQA, ChartQA)
  - Enriched evidence (Phase 3) should matter most for **document** benchmarks
 
  ---
 
 
@@ -29,7 +29,7 @@
 
  ┌─────────┐ JEPA Loss:
  │Optional:│ SmoothL1/Cosine
+ │OCR,SAM, │──────────┘ + SIGReg/VICReg
  │Layout │
  └─────────┘
  ```
 
@@ -151,30 +151,28 @@
 
  ### 2.6 JEPA Objective
 
+ **Prediction loss** (hybrid branch — SmoothL1, more robust than L2):
  ```
  L_JEPA = (1/K) Σ_{k=1}^{K} SmoothL1(z_pred_k, sg(z*_k))
  ```
  Only steps k=1..K are supervised (z₀ is deterministic from evidence).
 
+ **Alternative losses** (ablation):
+ - **MSE** (L2): original I-JEPA loss
+ - **Cosine**: 1 - cos_sim, used in purist branch
 
+ **Anti-collapse regularization**:
+ - **SIGReg** (from LeWorldModel): Epps-Pulley normality test on random projections, encourages Gaussian-distributed latents
+ - **VICReg**: variance (keep std ≥ 1) + covariance (decorrelate features) regularization
 
  **Total loss**:
  ```
+ L_total = w_jepa · L_JEPA + w_task · L_task + λ · L_reg + α · L_gen
 
  Where:
  L_task = CrossEntropy(disc_head(z_K), answer_label) # MC scoring
  L_gen = CE(gen_head(z_K), target_answer_tokens) # Short answer (Phase 3)
+ L_reg = SIGReg and/or VICReg
  λ = 0.1 (regularization weight)
  α = 0.1 (generative weight)
  ```
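Review note: the added VICReg bullet (keep per-feature std ≥ 1, decorrelate features) and the reweighted `L_total` can be sketched in pure Python as below. This is an illustrative sketch under stated assumptions, not the repository's code: the names `vicreg_terms` and `total_loss` are hypothetical, the default `w_jepa = w_task = 1.0` is assumed (the diff gives no values), and only `λ = 0.1` and `α = 0.1` come from the document. How the two VICReg terms are weighted against each other is also not specified in the diff, so the sketch returns them separately.

```python
import math
from typing import List, Sequence, Tuple


def vicreg_terms(Z: Sequence[Sequence[float]],
                 gamma: float = 1.0,
                 eps: float = 1e-4) -> Tuple[float, float]:
    """Variance and covariance penalties for a batch Z of shape (N, D).

    variance term:   hinge that is positive when a feature's std < gamma
    covariance term: sum of squared off-diagonal covariances, divided by D
    """
    n, d = len(Z), len(Z[0])
    if n < 2:
        raise ValueError("need at least two samples to estimate variance")
    means = [sum(row[j] for row in Z) / n for j in range(d)]
    centered: List[List[float]] = [[row[j] - means[j] for j in range(d)] for row in Z]
    # Unbiased sample covariance matrix (D x D).
    cov = [[sum(c[i] * c[j] for c in centered) / (n - 1) for j in range(d)]
           for i in range(d)]
    var_term = sum(max(0.0, gamma - math.sqrt(cov[j][j] + eps)) for j in range(d)) / d
    cov_term = sum(cov[i][j] ** 2 for i in range(d) for j in range(d) if i != j) / d
    return var_term, cov_term


def total_loss(l_jepa: float, l_task: float, l_reg: float, l_gen: float = 0.0,
               w_jepa: float = 1.0, w_task: float = 1.0,  # assumed defaults
               lam: float = 0.1, alpha: float = 0.1) -> float:
    """L_total = w_jepa*L_JEPA + w_task*L_task + lam*L_reg + alpha*L_gen."""
    return w_jepa * l_jepa + w_task * l_task + lam * l_reg + alpha * l_gen


# A fully collapsed batch (every embedding identical) is penalized by the
# variance term; a spread-out, decorrelated batch incurs no penalty.
collapsed = vicreg_terms([[1.0, 2.0]] * 4)
healthy = vicreg_terms([[2.0, 0.0], [-2.0, 0.0], [0.0, 2.0], [0.0, -2.0]])
```

Setting `w_jepa = 0.0` in `total_loss` gives a task-only objective, which is one way the `no_jepa` ablation could be expressed.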
 
@@ -241,23 +239,33 @@
 
  ## 4. Ablation Experiments
 
+ ### Complete ablation matrix
+
+ Each experiment maps 1:1 to a CLI flag in `train_mrjepa.py`.
+
+ | Experiment | CLI flag | Modification | Expected finding |
+ |------------|----------|-------------|-----------------|
+ | `hybrid_main` | *(default)* | Full model (DINOv3-L, K=3, SmoothL1+VICReg) | Best overall |
+ | `no_jepa` | `--no_jepa` | Remove L_JEPA, task loss only | Drops on reasoning-heavy benchmarks |
+ | `no_rollout` | `--no_rollout` | K=0, use z₀ directly (also disables JEPA¹) | Significant drop (proves rollout value) |
+ | `no_gate` | `--no_evidence_gate` | Remove sigmoid evidence gating | Slight drop (gate helps focus) |
+ | `K1` | `--K 1` | Shallow rollout | Worse than K=3 |
+ | `K5` | `--K 5` | Deeper rollout | Diminishing returns |
+ | `K7` | `--K 7` | Very deep rollout | Overfitting / diminishing returns |
+ | `dinov2_ablation` | `--backbone dinov2` | DINOv2-L/14 instead of DINOv3-L/16 | DINOv3 > DINOv2 due to Gram anchoring + RoPE |
+ | `mse_loss` | `--loss_fn mse` | MSE (L2) JEPA loss (original I-JEPA) | Slightly worse than SmoothL1 |
+ | `cosine_loss` | `--loss_fn cosine` | Cosine similarity JEPA loss | Better for purist, similar for hybrid |
+ | `no_sigreg` | `--no_sigreg` | Disable SIGReg anti-collapse | Training instability / representation collapse |
+ | `vicreg_only` | `--no_sigreg --use_vicreg` | VICReg only (no SIGReg) | Alternative anti-collapse strategy |
+ | `purist` | `--purist` | DINOv3-B, K=5, Cosine+SIGReg, no enriched ev. | Lower absolute, validates JEPA contribution |
 
+ ¹ `no_rollout` also disables JEPA because with K=0 there is only z₀ — no trajectory to supervise. To test JEPA in isolation, use `--no_jepa` with K>0.
 
  ### Cross-benchmark analysis:
  - JEPA contribution should be highest on **reasoning** benchmarks (MathVista, MMMU, ScienceQA)
  - Evidence gate contribution should be highest on **evidence-rich** benchmarks (DocVQA, ChartQA)
  - Enriched evidence (Phase 3) should matter most for **document** benchmarks
+ - DINOv3 vs DINOv2 gap should be largest on **fine-grained visual** benchmarks (AI2D, ChartQA)
 
  ---
 
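Review note: the flag-to-experiment mapping and the footnote rule (`--no_rollout` forces K=0 and therefore also disables JEPA) can be sketched with `argparse`. This is a hypothetical reconstruction, not the actual parser in `train_mrjepa.py`: the flag names are taken from the table, but the default choice strings `dinov3` and `smooth_l1`, the helper names `build_parser` and `resolve`, and the resolution logic are assumptions. The `--purist` preset (DINOv3-B, K=5, Cosine+SIGReg) is exposed only as a flag and is not expanded here.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags listed in the ablation table."""
    p = argparse.ArgumentParser(description="MR-JEPA ablation flags (sketch)")
    p.add_argument("--no_jepa", action="store_true", help="task loss only")
    p.add_argument("--no_rollout", action="store_true", help="K=0, use z_0 directly")
    p.add_argument("--no_evidence_gate", action="store_true")
    p.add_argument("--K", type=int, default=3, help="number of rollout steps")
    # 'dinov3' and 'smooth_l1' are assumed names for the default choices.
    p.add_argument("--backbone", choices=["dinov3", "dinov2"], default="dinov3")
    p.add_argument("--loss_fn", choices=["smooth_l1", "mse", "cosine"],
                   default="smooth_l1")
    p.add_argument("--no_sigreg", action="store_true")
    p.add_argument("--use_vicreg", action="store_true")
    p.add_argument("--purist", action="store_true")
    return p


def resolve(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse flags and apply the footnote rule for --no_rollout."""
    args = build_parser().parse_args(argv)
    if args.no_rollout:
        # With K=0 there is only z_0, so there is no trajectory for
        # L_JEPA to supervise; JEPA is disabled as a consequence.
        args.K = 0
        args.no_jepa = True
    return args


no_rollout = resolve(["--no_rollout"])
k5 = resolve(["--K", "5"])
vicreg_only = resolve(["--no_sigreg", "--use_vicreg"])
```

To isolate the effect of the JEPA loss, the footnote's advice corresponds to `resolve(["--no_jepa"])`, which leaves `K` at its default of 3.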