final write-up
Browse files- comparison.md +87 -0
comparison.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TC-LeJEPA ablation — results
|
| 2 |
+
|
| 3 |
+
Text-conditioned LeJEPA: baseline vs two conditioning architectures (FiLM, cross-attention) vs a wrong-label control. Backbone = ViT-Small/16 @ 128×128. Text tower = OpenCLIP ViT-B/32 (frozen). Dataset = CIFAR-100. Loss = `(1-λ)·L_inv + λ·SIGReg`, λ = 0.02. Artifacts: https://huggingface.co/adipanda/lejepa
|
| 4 |
+
|
| 5 |
+
## Ablation table
|
| 6 |
+
|
| 7 |
+
| Variant | Linear probe top-1 | SIGReg↔acc Spearman | Loss smoothness (last 50%) | Notes |
|
| 8 |
+
|---|---|---|---|---|
|
| 9 |
+
| `baseline` | 0.3502 | -0.943 (n=6) | 0.00270 | vanilla LeJEPA, no text path. |
|
| 10 |
+
| `film` | 0.2941 | -1.000 (n=6) | 0.00120 | MLP predictor with FiLM(text) on each hidden layer; γ,β init=(1,0). |
|
| 11 |
+
| `xattn` | 0.2313 | -0.943 (n=6) | 0.00282 | Patch tokens cross-attend to text; zero-init output proj. |
|
| 12 |
+
| `wrong_text` | 0.2303 | -0.886 (n=6) | 0.00175 | xattn architecture with permuted label-text map — probes whether text **content** matters. |
|
| 13 |
+
|
| 14 |
+
## Per-variant findings
|
| 15 |
+
|
| 16 |
+
- **`baseline`** — best probe top-1 **0.3502** at epoch 29; final train `lejepa`=0.2282, final online probe=0.2496.
|
| 17 |
+
- **`film`** — best probe top-1 **0.2941** at epoch 29; final train `lejepa`=0.0850, final online probe=0.2028.
|
| 18 |
+
- **`xattn`** — best probe top-1 **0.2313** at epoch 29; final train `lejepa`=0.1334, final online probe=0.1404.
|
| 19 |
+
- **`wrong_text`** — best probe top-1 **0.2303** at epoch 24; final train `lejepa`=0.1357, final online probe=0.1456.
|
| 20 |
+
|
| 21 |
+
## Research questions — direct answers
|
| 22 |
+
|
| 23 |
+
1. **Does text conditioning beat vanilla LeJEPA at matched epochs and λ?** — **No**, baseline at 0.3502 ≥ best conditioned (film, 0.2941).
|
| 24 |
+
2. **FiLM vs cross-attention — which wins and by how much?** — **FiLM wins**: 0.2941 vs xattn 0.2313 (Δ = +6.28 pp).
|
| 25 |
+
3. **Does text content matter (correct vs wrong vs none)?** — **No (regularization-only effect)**: wrong_text 0.2303 ≈ xattn 0.2313 (Δ = -0.10 pp). Text **content** does not matter at this scale — conditioning is acting as a regularizer, not a semantic signal.
|
| 26 |
+
- baseline (no text): 0.3502 xattn: 0.2313 wrong_text: 0.2303
|
| 27 |
+
4. **Does the SIGReg↔accuracy correlation survive under conditioning?**
|
| 28 |
+
- `baseline`: ρ = -0.943 (n=6)
|
| 29 |
+
- `film`: ρ = -1.000 (n=6)
|
| 30 |
+
- `xattn`: ρ = -0.943 (n=6)
|
| 31 |
+
- `wrong_text`: ρ = -0.886 (n=6)
|
| 32 |
+
- A strong negative ρ is the *expected* sign for LeJEPA: lower SIGReg ⇒ closer to iso-Gaussian in projection space ⇒ higher linear probe. Conditioning preserves this whenever |ρ| stays close to baseline's.
|
| 33 |
+
|
| 34 |
+
## Figures
|
| 35 |
+
|
| 36 |
+

|
| 37 |
+
*Training loss and online probe accuracy per epoch, per variant.*
|
| 38 |
+
|
| 39 |
+

|
| 40 |
+
*SIGReg(val) vs linear probe — baseline.*
|
| 41 |
+
|
| 42 |
+

|
| 43 |
+
*SIGReg(val) vs linear probe — FiLM.*
|
| 44 |
+
|
| 45 |
+

|
| 46 |
+
*SIGReg(val) vs linear probe — cross-attention.*
|
| 47 |
+
|
| 48 |
+

|
| 49 |
+
*SIGReg(val) vs linear probe — wrong-label text.*
|
| 50 |
+
|
| 51 |
+

|
| 52 |
+
*t-SNE of projections under three conditioning prompts (FiLM).*
|
| 53 |
+
|
| 54 |
+

|
| 55 |
+
*t-SNE of projections under three conditioning prompts (xattn).*
|
| 56 |
+
|
| 57 |
+

|
| 58 |
+
*t-SNE of projections under three conditioning prompts (wrong_text).*
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
## Setup parity checklist
|
| 62 |
+
|
| 63 |
+
- **Backbone**: ViT-Small/16 @ 128×128, `drop_path_rate=0.1`, no pretraining — identical across variants.
|
| 64 |
+
- **Data**: CIFAR-100 train (50k) / test (10k); shared split, shared two-view augmentation.
|
| 65 |
+
- **Optimizer**: AdamW, `lr=1e-3`, `wd=5e-2`, grad clip 1.0, bf16 autocast, cosine schedule + 2-epoch warmup.
|
| 66 |
+
- **λ**: 0.02 for all variants.
|
| 67 |
+
- **Epochs & batch**: 30 epochs × 512 batch × 2 views, identical.
|
| 68 |
+
- **Eval harness**: same feature extractor, same StandardScaler + logistic regression probe on train/val, same SIGReg probe on ≤4k val projections.
|
| 69 |
+
- **Text tower**: OpenCLIP ViT-B/32 (frozen), `.eval()`, no gradient. Only the text-projection layer inside the predictor is trainable on the text path.
|
| 70 |
+
|
| 71 |
+
## Honest caveats
|
| 72 |
+
|
| 73 |
+
- **Single dataset** (CIFAR-100, 128×128 upsampled). Findings may not transfer to ImageNet-scale pretraining.
|
| 74 |
+
- **Frozen text encoder**: we use CLIP ViT-B/32 text features verbatim (only the predictor's text-projection layer adapts).
|
| 75 |
+
- **Ablation-scale**: 30 epochs is short for SSL; absolute numbers are lower than published LeJEPA results.
|
| 76 |
+
- **Linear probe only**: no fine-tuning, no transfer to downstream tasks, no k-NN or MLP probe.
|
| 77 |
+
- **Class-name prompts only**: `"a photo of a <class>"`. Richer prompts (scenes, attributes) are left for future work.
|
| 78 |
+
- **SIGReg on val** uses single-view features rather than the two-view training population; it is still comparable across variants (same harness).
|
| 79 |
+
|
| 80 |
+
## Suggested next experiments
|
| 81 |
+
|
| 82 |
+
- Longer schedules (200+ epochs) and larger backbones (ViT-B/16).
|
| 83 |
+
- Richer text prompts (captions, attributes) rather than class names alone.
|
| 84 |
+
- Downstream transfer (CIFAR-10, STL-10, ImageNet-100 linear probe and full fine-tuning).
|
| 85 |
+
- **Unfreeze** the text tower (or its last block) — does adaptation improve the conditioning signal?
|
| 86 |
+
- Temporal extension à la **LeWorldModel**: condition the predictor on future-frame captions.
|
| 87 |
+
- Per-patch prediction target (true JEPA-style masked prediction) rather than global projection, to match the instruction's 'target patch embeddings' framing more literally.
|