# TC-LeJEPA ablation — results
|
|
Text-conditioned LeJEPA: baseline vs two conditioning architectures (FiLM, cross-attention) vs a wrong-label control. Backbone = ViT-Small/16 @ 128×128. Text tower = OpenCLIP ViT-B/32 (frozen). Dataset = CIFAR-100. Loss = `(1-λ)·L_inv + λ·SIGReg`, λ = 0.02. Artifacts: https://huggingface.co/adipanda/lejepa
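The two-term objective can be sketched in a few lines. This is illustrative, not the repo's actual API: `lejepa_loss` and its signature are made up here, and the SIGReg term is assumed to be precomputed on the same projection batch.

```python
import numpy as np

def lejepa_loss(pred, target, sigreg, lam=0.02):
    """Combined objective: (1 - lam) * L_inv + lam * SIGReg.

    pred, target: (B, D) predictor outputs and target projections
    (the target branch is treated as stop-gradient upstream);
    sigreg: scalar SIGReg penalty for the same batch.
    """
    l_inv = np.mean((pred - target) ** 2)  # invariance term L_inv
    return (1 - lam) * l_inv + lam * sigreg
```

At λ = 0.02 the objective is dominated by `L_inv`, which matters for the mechanism analysis later in this report: a predictor that drives `L_inv` to zero nearly minimises the whole loss.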
|
|
## Ablation table
|
|
| Variant | Linear probe top-1 | SIGReg↔acc Spearman ρ | Loss smoothness (last 50%) | Notes |
|---|---|---|---|---|
| `baseline` | 0.3502 | -0.943 (n=6) | 0.00270 | Vanilla LeJEPA, no text path. |
| `film` | 0.2941 | -1.000 (n=6) | 0.00120 | MLP predictor with FiLM(text) on each hidden layer; γ,β init=(1,0). |
| `xattn` | 0.2313 | -0.943 (n=6) | 0.00282 | Patch tokens cross-attend to text; zero-init output proj. |
| `wrong_text` | 0.2303 | -0.886 (n=6) | 0.00175 | xattn architecture with permuted label-text map — probes whether text **content** matters. |
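The FiLM path in the table can be sketched as a feature-wise modulation whose γ,β are linear in the text embedding and initialised to (1, 0), so the layer starts as an exact identity — the same start-from-baseline spirit as the zero-init output projection in `xattn`. A minimal numpy sketch (names illustrative):

```python
import numpy as np

class FiLM:
    """Feature-wise linear modulation of a hidden layer: gamma(t) * h + beta(t).

    gamma and beta are linear in the text embedding t. Weights are zero-init
    with biases (1, 0), so the layer is exactly the identity at initialisation
    (the table's gamma,beta init=(1,0)). Names are illustrative.
    """
    def __init__(self, text_dim, hidden_dim):
        self.W_gamma = np.zeros((text_dim, hidden_dim))
        self.b_gamma = np.ones(hidden_dim)
        self.W_beta = np.zeros((text_dim, hidden_dim))
        self.b_beta = np.zeros(hidden_dim)

    def __call__(self, h, t):
        gamma = t @ self.W_gamma + self.b_gamma  # (B, hidden_dim)
        beta = t @ self.W_beta + self.b_beta
        return gamma * h + beta
```

With this initialisation every conditioned variant begins training with the same forward pass as an unconditioned predictor; the text path is introduced only through gradient updates.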
|
|
## Per-variant findings
|
|
- **`baseline`** — best probe top-1 **0.3502** at epoch 29; final train `lejepa`=0.2282, final online probe=0.2496.
- **`film`** — best probe top-1 **0.2941** at epoch 29; final train `lejepa`=0.0850, final online probe=0.2028.
- **`xattn`** — best probe top-1 **0.2313** at epoch 29; final train `lejepa`=0.1334, final online probe=0.1404.
- **`wrong_text`** — best probe top-1 **0.2303** at epoch 24; final train `lejepa`=0.1357, final online probe=0.1456.

## Research questions — direct answers

1. **Does text conditioning beat vanilla LeJEPA at matched epochs and λ?** — **No**: baseline at 0.3502 beats the best conditioned variant (film, 0.2941).
2. **FiLM vs cross-attention — which wins and by how much?** — **FiLM wins**: 0.2941 vs xattn 0.2313 (Δ = +6.28 pp).
3. **Does text content matter (correct vs wrong vs none)?** — **No (regularization-only effect)**: wrong_text 0.2303 ≈ xattn 0.2313 (Δ = -0.10 pp). Text **content** does not matter at this scale — conditioning acts as a regularizer, not a semantic signal.
   - baseline (no text): 0.3502; xattn: 0.2313; wrong_text: 0.2303
4. **Does the SIGReg↔accuracy correlation survive under conditioning?**
   - `baseline`: ρ = -0.943 (n=6)
   - `film`: ρ = -1.000 (n=6)
   - `xattn`: ρ = -0.943 (n=6)
   - `wrong_text`: ρ = -0.886 (n=6)
   - A strong negative ρ is the *expected* sign for LeJEPA: lower SIGReg ⇒ closer to iso-Gaussian in projection space ⇒ higher linear-probe accuracy. Conditioning preserves the relationship whenever |ρ| stays close to the baseline's.
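The ρ values above are rank correlations over the n = 6 checkpointed (SIGReg, accuracy) pairs. A minimal numpy implementation (Pearson on ranks; no tie handling, which suffices for distinct checkpoint values):

```python
import numpy as np

def spearman_rho(x, y):
    """Spearman rank correlation computed as Pearson correlation of ranks.
    No tie handling -- fine for distinct per-checkpoint values."""
    rx = np.argsort(np.argsort(x)).astype(float)
    ry = np.argsort(np.argsort(y)).astype(float)
    rx -= rx.mean()
    ry -= ry.mean()
    return float((rx @ ry) / np.sqrt((rx @ rx) * (ry @ ry)))
```

`film`'s ρ = -1.000 means SIGReg and probe accuracy were strictly anti-monotone across all six checkpoints.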
|
|
## Mechanism — why conditioning hurt at this scale
|
|
The conditioned predictors achieve a **10–20× lower** invariance loss than baseline (`inv` at epoch 29: baseline 0.138, film 0.013, xattn 0.007), yet their linear-probe accuracy is **strictly lower**. Two orthogonal observations explain this:
|
|
1. **Text shortcuts the invariance.** Both views of an image share the same prompt, so the predictor can satisfy `L_inv` by routing information through the (identical-across-views) text channel rather than by learning view-invariant image features. The backbone has no reason to match views — the predictor does it for free. Low `L_inv` therefore stops implying a well-learned backbone.
2. **The projection space is text-controlled.** t-SNE of the same 500 val images under three prompts (`"the object"`, `"the background"`, `"the texture"`) shows a **silhouette-by-prompt of +0.51 to +0.60** and a **negative silhouette-by-image (−0.33 to −0.49)** for every conditioned variant. Projections cluster by *prompt*, not by *image*. The predictor is effectively a text-indexed codebook; image content is secondary.
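The silhouette diagnostics in observation 2 score the same projections twice: once with the prompt as the cluster label (silhouette-by-prompt) and once with the image id (silhouette-by-image). A numpy-only sketch of the standard mean silhouette coefficient (what `sklearn.metrics.silhouette_score` computes):

```python
import numpy as np

def silhouette(X, labels):
    """Mean silhouette coefficient over all samples.
    a = mean intra-cluster distance, b = min mean distance to another
    cluster, s = (b - a) / max(a, b). Singletons are skipped."""
    X = np.asarray(X, float)
    labels = np.asarray(labels)
    D = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(-1))
    uniq = np.unique(labels)
    scores = []
    for i in range(len(X)):
        same = labels == labels[i]
        same[i] = False
        if not same.any():
            continue  # silhouette undefined for singleton clusters
        a = D[i, same].mean()
        b = min(D[i, labels == c].mean() for c in uniq if c != labels[i])
        scores.append((b - a) / max(a, b))
    return float(np.mean(scores))
```

A positive silhouette-by-prompt together with a negative silhouette-by-image means each projection sits closer to *other images under the same prompt* than to *its own image under a different prompt*.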
|
|
The `wrong_text` control closes the loop: when we swap each class's prompt for another class's prompt (fixed permutation), accuracy **does not change** (0.2303 vs 0.2313). Since the text content is now semantically uncoupled from the image, the residual gain above random chance must be a *regularization* effect from conditioning on any constant-per-image signal — not from text semantics.
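A sketch of the wrong-label prompt map, assuming the fixed permutation has no fixed points (so no class keeps its own text); the actual permutation and seed used in the run are not specified, so these are illustrative:

```python
import numpy as np

def permuted_prompt_map(class_names, seed=0):
    """wrong_text control: give every class another class's prompt.

    Assumes a derangement (permutation with no fixed points); the run's
    actual permutation is unspecified. Prompt template from the setup.
    """
    rng = np.random.default_rng(seed)
    n = len(class_names)
    perm = rng.permutation(n)
    while (perm == np.arange(n)).any():  # redraw until no class maps to itself
        perm = rng.permutation(n)
    return {class_names[i]: f"a photo of a {class_names[perm[i]]}"
            for i in range(n)}
```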
|
|
The SIGReg↔accuracy correlation survives conditioning (|ρ| ≥ 0.886 for every variant), so the two-term LeJEPA loss remains a valid model-selection proxy even when the predictor is conditioned; it just selects among *worse* solutions in the conditioned case.
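The text shortcut in observation 1 reduces to a two-line toy example: a "predictor" that ignores the image and passes through the (shared-across-views) text embedding satisfies the invariance term exactly, with zero image information. This is a didactic illustration, not the actual predictor.

```python
import numpy as np

rng = np.random.default_rng(0)
# two augmented views of a batch of 8 images, plus one shared prompt embedding
view1 = rng.normal(size=(8, 16))
view2 = rng.normal(size=(8, 16))
text = rng.normal(size=(8, 16))  # identical across the two views of each image

shortcut = lambda img, txt: txt  # predictor that ignores the image entirely

l_inv = np.mean((shortcut(view1, text) - shortcut(view2, text)) ** 2)
# l_inv == 0.0 exactly: invariance satisfied without any view-invariant features
```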
|
|
## Figures
|
|
|  |
| *Training loss and online probe accuracy per epoch, per variant.* |
|
|
|  |
| *SIGReg(val) vs linear probe — baseline.* |
|
|
|  |
| *SIGReg(val) vs linear probe — FiLM.* |
|
|
|  |
| *SIGReg(val) vs linear probe — cross-attention.* |
|
|
|  |
| *SIGReg(val) vs linear probe — wrong-label text.* |
|
|
|  |
| *t-SNE of projections under three conditioning prompts (FiLM).* |
|
|
|  |
| *t-SNE of projections under three conditioning prompts (xattn).* |
|
|
|  |
| *t-SNE of projections under three conditioning prompts (wrong_text).* |
|
|
|
|
## Setup parity checklist
|
|
- **Backbone**: ViT-Small/16 @ 128×128, `drop_path_rate=0.1`, no pretraining — identical across variants.
- **Data**: CIFAR-100 train (50k) / test (10k); shared split, shared two-view augmentation.
- **Optimizer**: AdamW, `lr=1e-3`, `wd=5e-2`, grad clip 1.0, bf16 autocast, cosine schedule + 2-epoch warmup.
- **λ**: 0.02 for all variants.
- **Epochs & batch**: 30 epochs × 512 batch × 2 views, identical.
- **Eval harness**: same feature extractor, same StandardScaler + logistic-regression probe on train/val, same SIGReg probe on ≤4k val projections.
- **Text tower**: OpenCLIP ViT-B/32 (frozen), `.eval()`, no gradient. Only the text-projection layer inside the predictor is trainable on the text path.
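The probe in this harness can be sketched with scikit-learn; the solver defaults and `max_iter` here are illustrative, not necessarily the repo's settings:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

def linear_probe_top1(train_x, train_y, val_x, val_y):
    """Harness sketch: standardise frozen backbone features, fit a
    logistic-regression probe on train, report top-1 accuracy on val."""
    scaler = StandardScaler().fit(train_x)
    clf = LogisticRegression(max_iter=1000).fit(
        scaler.transform(train_x), train_y)
    return clf.score(scaler.transform(val_x), val_y)
```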
|
|
## Honest caveats
|
|
- **Single dataset** (CIFAR-100, 128×128 upsampled). Findings may not transfer to ImageNet-scale pretraining.
- **Frozen text encoder**: we use CLIP ViT-B/32 text features verbatim (only the predictor's text-projection layer adapts).
- **Ablation-scale**: 30 epochs is short for SSL; absolute numbers are lower than published LeJEPA results.
- **Linear probe only**: no fine-tuning, no transfer to downstream tasks, no k-NN or MLP probe.
- **Class-name prompts only**: `"a photo of a <class>"`. Richer prompts (scenes, attributes) are left for future work.
- **SIGReg on val** uses single-view features rather than the two-view training population; it is still comparable across variants (same harness).
|
|
## Suggested next experiments
|
|
- Longer schedules (200+ epochs) and larger backbones (ViT-B/16).
- Richer text prompts (captions, attributes) rather than class names alone.
- Downstream transfer (CIFAR-10, STL-10, ImageNet-100 linear probe and full fine-tuning).
- **Unfreeze** the text tower (or its last block) — does adaptation improve the conditioning signal?
- Temporal extension à la **LeWorldModel**: condition the predictor on future-frame captions.
- Per-patch prediction targets (true JEPA-style masked prediction) rather than a global projection, matching the original 'target patch embeddings' framing more literally.
|
|