# TC-LeJEPA ablation — results
Text-conditioned LeJEPA: baseline vs two conditioning architectures (FiLM, cross-attention) vs a wrong-label control. Backbone = ViT-Small/16 @ 128×128. Text tower = OpenCLIP ViT-B/32 (frozen). Dataset = CIFAR-100. Loss = `(1-λ)·L_inv + λ·SIGReg`, λ = 0.02. Artifacts: https://huggingface.co/adipanda/lejepa
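The objective above trades an invariance term against SIGReg. As we understand the LeJEPA paper, SIGReg projects embeddings onto random unit directions and scores each 1-D projection against N(0, 1) with an Epps-Pulley-style characteristic-function statistic. The numpy sketch below illustrates that idea together with the λ-weighting. It is not the code used for these runs: the names `sliced_gaussianity_stat` and `lejepa_objective`, the slice count, the integration range `t_max`, the quadrature grid, and the Gaussian weight are all assumptions.

```python
import numpy as np


def sliced_gaussianity_stat(Z, n_slices=32, t_max=3.0, n_t=31, seed=0):
    """SIGReg-style statistic (hypothetical sketch, not the reference code).

    Projects embeddings Z of shape (N, D) onto random unit directions and, per
    1-D slice, integrates the weighted squared gap between the empirical
    characteristic function (ECF) and the N(0, 1) characteristic function.
    Returns N times the slice-averaged integral; larger means less Gaussian.
    """
    rng = np.random.default_rng(seed)
    n, d = Z.shape
    dirs = rng.normal(size=(d, n_slices))
    dirs /= np.linalg.norm(dirs, axis=0, keepdims=True)  # unit-norm directions
    proj = Z @ dirs                                       # (N, n_slices)

    t = np.linspace(0.0, t_max, n_t)                      # |gap|^2 is even in t
    tx = proj[:, :, None] * t                             # (N, n_slices, n_t)
    ecf_re = np.cos(tx).mean(axis=0)                      # real part of the ECF
    ecf_im = np.sin(tx).mean(axis=0)                      # imaginary part
    target = np.exp(-0.5 * t**2)                          # CF of N(0, 1) is real
    weight = np.exp(-0.5 * t**2)                          # assumed weighting w(t)
    f = ((ecf_re - target) ** 2 + ecf_im**2) * weight     # (n_slices, n_t)

    dt = t[1] - t[0]
    half = dt * (f.sum(axis=1) - 0.5 * (f[:, 0] + f[:, -1]))  # trapezoid on [0, t_max]
    return float(n * (2.0 * half).mean())                 # x2 covers [-t_max, 0]


def lejepa_objective(inv_loss, sigreg, lam=0.02):
    """Total loss as written in this document: (1 - lam) * L_inv + lam * SIGReg."""
    return (1.0 - lam) * inv_loss + lam * sigreg


rng = np.random.default_rng(1)
Z_gauss = rng.normal(size=(1000, 16))          # isotropic N(0, I): the SIGReg target
Z_collapsed = np.zeros((1000, 16))             # every embedding identical
Z_scaled = 3.0 * rng.normal(size=(1000, 16))   # Gaussian, but the wrong scale

stats = {name: sliced_gaussianity_stat(Z) for name, Z in
         [("gauss", Z_gauss), ("collapsed", Z_collapsed), ("scaled", Z_scaled)]}
print(stats)  # the isotropic Gaussian should score far below the other two
print(lejepa_objective(inv_loss=0.5, sigreg=2.0))  # synthetic inputs
```

With λ = 0.02 the invariance term carries 49× the weight of SIGReg, so SIGReg only dominates the total when embeddings are far from isotropic (collapsed or badly scaled).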
## Ablation table
| Variant | Best linear-probe top-1 | SIGReg↔acc Spearman ρ | Loss smoothness (last 50% of training) | Notes |
|---|---|---|---|---|
| `baseline` | 0.3502 | -0.943 (n=6) | 0.00270 | vanilla LeJEPA, no text path. |
| `film` | 0.2941 | -1.000 (n=6) | 0.00120 | MLP predictor with FiLM(text) on each hidden layer; γ,β init=(1,0). |
| `xattn` | 0.2313 | -0.943 (n=6) | 0.00282 | Patch tokens cross-attend to text; zero-init output proj. |
| `wrong_text` | 0.2303 | -0.886 (n=6) | 0.00175 | xattn architecture with permuted label-text map — probes whether text **content** matters. |
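Both conditioned predictors are initialised so that the text path is a no-op at step 0: FiLM starts at γ = 1, β = 0, and the cross-attention output projection starts at zero, so its residual branch adds nothing. The numpy sketch below checks that property. The layer shapes, the linear text-to-(γ, β) map with zero weights and (1, 0) biases, the single-head attention, and every name are hypothetical, not taken from the training code.

```python
import numpy as np


class FiLM:
    """Hypothetical FiLM layer: h -> gamma(t) * h + beta(t).

    gamma and beta are linear in the text embedding t; zero weights plus biases
    of 1 and 0 give gamma = 1, beta = 0 at init, i.e. an identity layer.
    """

    def __init__(self, text_dim, hidden_dim):
        self.W_gamma = np.zeros((text_dim, hidden_dim))
        self.b_gamma = np.ones(hidden_dim)
        self.W_beta = np.zeros((text_dim, hidden_dim))
        self.b_beta = np.zeros(hidden_dim)

    def __call__(self, h, t):
        gamma = t @ self.W_gamma + self.b_gamma   # (B, hidden_dim)
        beta = t @ self.W_beta + self.b_beta      # (B, hidden_dim)
        return gamma * h + beta


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)       # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


class ZeroInitCrossAttention:
    """Hypothetical single-head cross-attention: patch tokens attend to text tokens.

    The output projection W_o starts at zero, so the residual output equals the
    input patches at init.
    """

    def __init__(self, dim, text_dim, seed=0):
        rng = np.random.default_rng(seed)
        s = 1.0 / np.sqrt(dim)
        self.W_q = rng.normal(scale=s, size=(dim, dim))
        self.W_k = rng.normal(scale=s, size=(text_dim, dim))
        self.W_v = rng.normal(scale=s, size=(text_dim, dim))
        self.W_o = np.zeros((dim, dim))

    def __call__(self, patches, text_tokens):
        q = patches @ self.W_q                    # (B, P, dim)
        k = text_tokens @ self.W_k                # (B, T, dim)
        v = text_tokens @ self.W_v                # (B, T, dim)
        attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1]))  # (B, P, T)
        return patches + (attn @ v) @ self.W_o    # residual update


rng = np.random.default_rng(0)
B, P, T, D, D_TXT = 2, 4, 3, 8, 5   # batch, patches, text tokens, dims (arbitrary)
h = rng.normal(size=(B, D))
t = rng.normal(size=(B, D_TXT))
patches = rng.normal(size=(B, P, D))
text_tokens = rng.normal(size=(B, T, D_TXT))

film = FiLM(D_TXT, D)
xattn = ZeroInitCrossAttention(D, D_TXT)
print(np.allclose(film(h, t), h), np.allclose(xattn(patches, text_tokens), patches))
```

With this initialisation the text can only influence the predictor once gradients move γ/β or `W_o` away from their starting values.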
## Per-variant findings
- **`baseline`** — best probe top-1 **0.3502** at epoch 29; final train `lejepa`=0.2282, final online probe=0.2496.
- **`film`** — best probe top-1 **0.2941** at epoch 29; final train `lejepa`=0.0850, final online probe=0.2028.
- **`xattn`** — best probe top-1 **0.2313** at epoch 29; final train `lejepa`=0.1334, final online probe=0.1404.
- **`wrong_text`** — best probe top-1 **0.2303** at epoch 24; final train `lejepa`=0.1357, final online probe=0.1456.
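The offline probe numbers come from a StandardScaler + logistic-regression probe on frozen features (see the Setup parity checklist). A minimal scikit-learn sketch of such a probe follows, assuming the harness uses scikit-learn; `linear_probe_top1`, the `max_iter` value, the default regularisation, and the synthetic features are assumptions rather than the settings used here.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def linear_probe_top1(train_x, train_y, val_x, val_y, max_iter=1000):
    """Fit StandardScaler + logistic regression on frozen features; return val top-1."""
    probe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=max_iter))
    probe.fit(train_x, train_y)
    # .score returns mean accuracy, which equals top-1 for single-label classification.
    return probe.score(val_x, val_y)


# Synthetic stand-in for frozen backbone features (NOT CIFAR-100 embeddings).
rng = np.random.default_rng(0)
n_classes, dim = 5, 32
class_means = rng.normal(scale=3.0, size=(n_classes, dim))


def sample(n):
    y = rng.integers(0, n_classes, size=n)
    return class_means[y] + rng.normal(size=(n, dim)), y


train_x, train_y = sample(500)
val_x, val_y = sample(200)
acc = linear_probe_top1(train_x, train_y, val_x, val_y)
print(f"top-1 = {acc:.4f}")
```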
## Research questions — direct answers
1. **Does text conditioning beat vanilla LeJEPA at matched epochs and λ?** **No.** Baseline reaches 0.3502, above the best conditioned variant (film, 0.2941) by 5.61 pp.
2. **FiLM vs cross-attention: which wins and by how much?** **FiLM wins**: 0.2941 vs xattn 0.2313 (Δ = +6.28 pp).
3. **Does text content matter (correct vs wrong vs none)?** **No (regularization-only effect)**: wrong_text 0.2303 ≈ xattn 0.2313 (Δ = -0.10 pp). Text **content** does not measurably matter at this scale; conditioning acts as a regularizer, not as a semantic signal.
   - baseline (no text): 0.3502, xattn: 0.2313, wrong_text: 0.2303
4. **Does the SIGReg↔accuracy correlation survive under conditioning?** **Yes**: |ρ| ≥ 0.886 for every variant.
   - `baseline`: ρ = -0.943 (n=6)
   - `film`: ρ = -1.000 (n=6)
   - `xattn`: ρ = -0.943 (n=6)
   - `wrong_text`: ρ = -0.886 (n=6)
   - A strongly negative ρ is the *expected* sign for LeJEPA: lower SIGReg ⇒ projections closer to an isotropic Gaussian ⇒ higher linear-probe accuracy. Conditioning preserves this relationship as long as |ρ| stays close to the baseline's 0.943.
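With only six points per variant, Spearman ρ can take very few values. For tie-free ranks Σd² is always even, so ρ moves in steps of 12/210 ≈ 0.057, and -1.000, -0.943 and -0.886 are exactly the three most negative attainable values. The pure-Python sketch below enumerates all 720 rank orderings to show this and to compute an exact one-sided permutation p-value. It assumes no ties; the function names and the example SIGReg/accuracy series are hypothetical, not the measured values.

```python
from fractions import Fraction
from itertools import permutations


def ranks(values):
    """1-based ranks (1 = smallest); assumes no ties."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    r = [0] * len(values)
    for rank, i in enumerate(order, start=1):
        r[i] = rank
    return r


def sum_sq_rank_diff(rx, ry):
    return sum((a - b) ** 2 for a, b in zip(rx, ry))


def spearman(rx, ry):
    """Spearman rho for tie-free ranks: 1 - 6 * sum(d^2) / (n (n^2 - 1))."""
    n = len(rx)
    return 1 - 6 * sum_sq_rank_diff(rx, ry) / (n * (n * n - 1))


def exact_lower_p(rx, ry):
    """Exact P(rho <= observed) when every ordering of ry is equally likely."""
    observed = sum_sq_rank_diff(rx, ry)   # larger sum(d^2) <=> more negative rho
    perms = list(permutations(range(1, len(rx) + 1)))
    hits = sum(sum_sq_rank_diff(rx, p) >= observed for p in perms)
    return Fraction(hits, len(perms))


n = 6
identity = list(range(1, n + 1))
attainable = sorted({spearman(identity, p) for p in permutations(identity)})
print("most negative attainable rho:", [round(v, 3) for v in attainable[:3]])

# Hypothetical SIGReg / accuracy series for six checkpoints (NOT the measured values).
sigreg = [9.1, 7.4, 6.0, 5.2, 4.4, 4.6]
acc = [0.08, 0.15, 0.21, 0.26, 0.30, 0.31]
rs, ra = ranks(sigreg), ranks(acc)
print("rho =", round(spearman(rs, ra), 3), "exact one-sided p =", exact_lower_p(rs, ra))
```

Assuming tie-free ranks, ρ = -0.886 corresponds to an exact one-sided p of 12/720 ≈ 0.017, and ρ = -0.943 to 6/720 ≈ 0.008. Six points are enough to rule out a random ordering, but too few to rank variants by ρ: -0.943 and -1.000 are a single step apart.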
## Mechanism — why conditioning hurt at this scale
The conditioned predictors reach a **10–20× lower** invariance loss than baseline
(`inv` at epoch 29: baseline 0.138 vs film 0.013 and xattn 0.007, i.e. ≈10.6× and ≈19.7× lower),
yet their linear-probe accuracy is **strictly lower**. Two complementary observations explain this:
1. **Text shortcuts the invariance.** Both views of an image share the same prompt, so
the predictor can satisfy `L_inv` by routing information through the (identical-across-views)
text channel rather than by learning view-invariant image features. The backbone then has no
incentive to make the views agree: the predictor does that for free. A low `L_inv` therefore
no longer implies a well-trained backbone.
2. **The projection space is text-controlled.** t-SNE of the same 500 val images under
three prompts (`"the object"`, `"the background"`, `"the texture"`) shows a
**silhouette-by-prompt of +0.51 to +0.60** and a **negative silhouette-by-image
(−0.33 to −0.49)** for every conditioned variant. Projections cluster by *prompt*, not
by *image*. The predictor is effectively a text-indexed codebook; image content is secondary.
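Silhouette-by-prompt and silhouette-by-image score the same point cloud under two labelings: the prompt each projection was produced with, and the image it came from. The numpy sketch below uses the same definition as `sklearn.metrics.silhouette_score` with Euclidean distance (assuming every cluster has at least two members) on synthetic projections in which a strong per-prompt offset dominates a weaker per-image code. The data, dimensions, and offset size are invented for illustration. The text does not say whether its silhouettes were computed on t-SNE coordinates or on raw projections; the function works on either.

```python
import numpy as np


def silhouette(X, labels):
    """Mean silhouette coefficient, Euclidean distance; every cluster needs >= 2 points."""
    X = np.asarray(X, dtype=float)
    D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)  # (N, N) distances
    _, inv = np.unique(labels, return_inverse=True)
    onehot = np.eye(inv.max() + 1)[inv]                          # (N, C)
    counts = onehot.sum(axis=0)                                  # cluster sizes
    sums = D @ onehot                                            # summed distance to each cluster
    idx = np.arange(len(X))
    a = sums[idx, inv] / (counts[inv] - 1)  # own-cluster mean, self excluded (D[i, i] = 0)
    mean_to = sums / counts
    mean_to[idx, inv] = np.inf              # ignore own cluster when finding the nearest other
    b = mean_to.min(axis=1)
    return float(np.mean((b - a) / np.maximum(a, b)))


# Synthetic "projections": 100 images x 3 prompts, prompt offset >> image code.
rng = np.random.default_rng(0)
n_images, n_prompts, dim = 100, 3, 16
prompt_offsets = 20.0 * np.eye(n_prompts, dim)   # strong text-driven shift per prompt
image_codes = rng.normal(size=(n_images, dim))   # weaker image-specific component
X = np.concatenate([image_codes + prompt_offsets[k] for k in range(n_prompts)])
prompt_labels = np.repeat(np.arange(n_prompts), n_images)
image_labels = np.tile(np.arange(n_images), n_prompts)

by_prompt, by_image = silhouette(X, prompt_labels), silhouette(X, image_labels)
print(f"silhouette by prompt: {by_prompt:+.2f}, by image: {by_image:+.2f}")
```

In this toy, each image's three copies land in three different prompt clusters. A point is therefore closer, on average, to other images' points under the same prompt than to its own copies, which is what drives the by-image score negative.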
The `wrong_text` control closes the loop: when each class's prompt is swapped for another
class's prompt (fixed permutation), accuracy **barely changes** (23.03% vs 23.13%, a 0.10 pp gap).
Since the text content is then semantically decoupled from the image, the residual gain above
random is best explained as a *regularization* effect of conditioning on a per-image signal
that is identical across both views, not as a contribution of text semantics.
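The shortcut in observation 1 only requires the conditioning signal to be identical across the two views of an image. It does not require that signal to describe the image. The toy numpy sketch below makes this concrete with a hypothetical linear predictor `feats @ W_img + text @ W_txt`, using the mean squared gap between the two views' predictions as a stand-in for `L_inv`. When the predictor ignores the image path (`W_img = 0`), that gap is exactly zero whether the class-to-prompt map is correct or a fixed cyclic permutation, even though the two views' backbone features are unrelated. The toy leaves out the target branch and SIGReg, and every name and dimension is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes, n_images, text_dim, feat_dim, out_dim = 10, 64, 8, 16, 4

text_table = rng.normal(size=(n_classes, text_dim))  # stand-in for frozen class-prompt embeddings
labels = rng.integers(0, n_classes, size=n_images)
# Backbone features for two views that are deliberately NOT view-invariant.
feats_v1 = rng.normal(size=(n_images, feat_dim))
feats_v2 = rng.normal(size=(n_images, feat_dim))
W_txt = rng.normal(size=(text_dim, out_dim))
W_img = rng.normal(size=(feat_dim, out_dim))


def predict(feats, text, W_img, W_txt):
    """Hypothetical linear conditioned predictor."""
    return feats @ W_img + text @ W_txt


def view_disagreement(p1, p2):
    """Toy invariance loss: mean squared gap between the two views' predictions."""
    return float(np.mean((p1 - p2) ** 2))


prompt_maps = {
    "correct": np.arange(n_classes),
    "wrong_text": np.roll(np.arange(n_classes), 1),  # every class gets another class's prompt
}
results = {}
for name, class_to_prompt in prompt_maps.items():
    text = text_table[class_to_prompt[labels]]       # same text for both views of an image
    results[name] = {
        path: view_disagreement(predict(feats_v1, text, W, W_txt),
                                predict(feats_v2, text, W, W_txt))
        for path, W in [("text_only", np.zeros_like(W_img)), ("uses_image", W_img)]
    }
print(results)
```

This does not show whether text helps or hurts the probe. It shows only why a near-zero `L_inv` says little about the backbone once a view-shared signal is available, which is consistent with `wrong_text` ≈ `xattn`.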
The SIGReg↔accuracy correlation survives conditioning (|ρ| ≥ 0.886 for every variant, n = 6 points each),
so SIGReg remains a valid label-free model-selection proxy even when the predictor
is conditioned; it just selects among *worse* solutions in the conditioned case.
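Using SIGReg as a label-free model-selection proxy amounts to picking, within one variant, the checkpoint with the lowest validation SIGReg. The sketch below does that and reports the regret, i.e. probe accuracy lost relative to the checkpoint a labelled oracle would pick. The checkpoint values are hypothetical, and `select_by_sigreg` / `selection_regret` are illustrative names.

```python
def select_by_sigreg(checkpoints):
    """Label-free choice: the checkpoint with the lowest validation SIGReg."""
    return min(checkpoints, key=lambda c: c["sigreg"])


def selection_regret(checkpoints):
    """Probe accuracy lost by selecting on SIGReg instead of on the probe itself."""
    best_acc = max(c["acc"] for c in checkpoints)
    return best_acc - select_by_sigreg(checkpoints)["acc"]


# Hypothetical per-checkpoint values (NOT this run's numbers). The two lowest-SIGReg
# checkpoints are swapped relative to accuracy, which gives rho = -0.943 over six points.
ckpts = [
    {"epoch": 4, "sigreg": 9.1, "acc": 0.08},
    {"epoch": 9, "sigreg": 7.4, "acc": 0.15},
    {"epoch": 14, "sigreg": 6.0, "acc": 0.21},
    {"epoch": 19, "sigreg": 5.2, "acc": 0.26},
    {"epoch": 24, "sigreg": 4.4, "acc": 0.30},
    {"epoch": 29, "sigreg": 4.6, "acc": 0.31},
]
chosen = select_by_sigreg(ckpts)
print(chosen["epoch"], round(selection_regret(ckpts), 4))
```

The regret is measured within a single variant. No within-variant selection can close the gap between the conditioned variants and baseline.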
## Figures
![Training loss and online probe accuracy per epoch, per variant.](figures/loss_curves.png)
*Training loss and online probe accuracy per epoch, per variant.*
![SIGReg(val) vs linear probe — baseline.](figures/sigreg_vs_acc_baseline.png)
*SIGReg(val) vs linear probe — baseline.*
![SIGReg(val) vs linear probe — FiLM.](figures/sigreg_vs_acc_film.png)
*SIGReg(val) vs linear probe — FiLM.*
![SIGReg(val) vs linear probe — cross-attention.](figures/sigreg_vs_acc_xattn.png)
*SIGReg(val) vs linear probe — cross-attention.*
![SIGReg(val) vs linear probe — wrong-label text.](figures/sigreg_vs_acc_wrong_text.png)
*SIGReg(val) vs linear probe — wrong-label text.*
![t-SNE of projections under three conditioning prompts (FiLM).](figures/tsne_steer_film.png)
*t-SNE of projections under three conditioning prompts (FiLM).*
![t-SNE of projections under three conditioning prompts (xattn).](figures/tsne_steer_xattn.png)
*t-SNE of projections under three conditioning prompts (xattn).*
![t-SNE of projections under three conditioning prompts (wrong_text).](figures/tsne_steer_wrong_text.png)
*t-SNE of projections under three conditioning prompts (wrong_text).*
## Setup parity checklist
- **Backbone**: ViT-Small/16 @ 128×128, `drop_path_rate=0.1`, no pretraining — identical across variants.
- **Data**: CIFAR-100 train (50k) / test (10k); shared split, shared two-view augmentation.
- **Optimizer**: AdamW, `lr=1e-3`, `wd=5e-2`, grad clip 1.0, bf16 autocast, cosine schedule + 2-epoch warmup.
- **λ**: 0.02 for all variants.
- **Epochs & batch**: 30 epochs × 512 batch × 2 views, identical.
- **Eval harness**: same feature extractor, same StandardScaler + logistic regression probe on train/val, same SIGReg probe on ≤4k val projections.
- **Text tower**: OpenCLIP ViT-B/32 (frozen), `.eval()`, no gradient. Only the text-projection layer inside the predictor is trainable on the text path.
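The checklist specifies AdamW at `lr=1e-3` with a cosine schedule and a 2-epoch warmup. A pure-Python sketch of one common form of that schedule follows: linear warmup, then half-cosine decay. Per-step updates, decay to zero, and `drop_last=True` (so 50,000 // 512 = 97 steps per epoch) are assumptions, not confirmed settings. `lr_factor` returns a multiplier on the base learning rate, which is the form `torch.optim.lr_scheduler.LambdaLR` expects.

```python
import math

STEPS_PER_EPOCH = 50_000 // 512       # 97, assuming drop_last=True
WARMUP_STEPS = 2 * STEPS_PER_EPOCH    # 2-epoch warmup
TOTAL_STEPS = 30 * STEPS_PER_EPOCH    # 30 epochs
BASE_LR = 1e-3


def lr_factor(step, total_steps=TOTAL_STEPS, warmup_steps=WARMUP_STEPS):
    """Multiplier on the base LR: linear warmup, then half-cosine decay to 0."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, WARMUP_STEPS - 1, WARMUP_STEPS, (WARMUP_STEPS + TOTAL_STEPS) // 2, TOTAL_STEPS):
    print(step, BASE_LR * lr_factor(step))
```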
## Honest caveats
- **Single dataset** (CIFAR-100, 128×128 upsampled). Findings may not transfer to ImageNet-scale pretraining.
- **Frozen text encoder**: we use CLIP ViT-B/32 text features verbatim (only the predictor's text-projection layer adapts), so the text features themselves cannot adapt to CIFAR-100 or to the SSL objective.
- **Ablation-scale**: 30 epochs is short for SSL; absolute numbers are lower than published LeJEPA results.
- **Linear probe only**: no fine-tuning, no transfer to downstream tasks, no k-NN or MLP probe.
- **Class-name prompts only**: `"a photo of a <class>"`. Richer prompts (scenes, attributes) are left for future work.
- **SIGReg on val** uses single-view features rather than the two-view training population; it is still comparable across variants (same harness).
## Suggested next experiments
- Longer schedules (200+ epochs) and larger backbones (ViT-B/16).
- Richer text prompts (captions, attributes) rather than class names alone.
- Downstream transfer (CIFAR-10, STL-10, ImageNet-100 linear probe and full fine-tuning).
- **Unfreeze** the text tower (or its last block) — does adaptation improve the conditioning signal?
- Temporal extension à la **LeWorldModel**: condition the predictor on future-frame captions.
- Per-patch prediction target (true JEPA-style masked prediction) rather than global projection, to match the instruction's 'target patch embeddings' framing more literally.