lejepa / comparison.md (adipanda, commit 4ac79bb: "add mechanism section")

TC-LeJEPA ablation — results

Text-conditioned LeJEPA: baseline vs two conditioning architectures (FiLM, cross-attention) vs a wrong-label control. Backbone = ViT-Small/16 @ 128×128. Text tower = OpenCLIP ViT-B/32 (frozen). Dataset = CIFAR-100. Loss = (1-λ)·L_inv + λ·SIGReg, λ = 0.02. Artifacts: https://huggingface.co/adipanda/lejepa
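The objective on this line can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the LeJEPA reference implementation. L_inv is taken to be the mean squared distance from each view's projection to the per-image mean of the two views. SIGReg is replaced by an Epps-Pulley-style statistic: the weighted distance between the empirical characteristic function of 1-D projections and that of N(0, 1), averaged over random unit directions. The function names, the t-grid, the Gaussian weight and the number of directions are all illustrative choices.

```python
import math
import random


def epps_pulley_1d(xs, t_max=3.0, n_t=17):
    """n * integral over [-t_max, t_max] of |phi_hat(t) - exp(-t^2/2)|^2 * exp(-t^2/2) dt.

    phi_hat is the empirical characteristic function of the samples xs. The
    integrand is even, so we integrate [0, t_max] with the trapezoid rule and double.
    """
    n = len(xs)
    step = t_max / (n_t - 1)
    values = []
    for k in range(n_t):
        t = k * step
        re = sum(math.cos(t * x) for x in xs) / n
        im = sum(math.sin(t * x) for x in xs) / n
        gauss_cf = math.exp(-0.5 * t * t)  # characteristic function of N(0, 1)
        values.append(((re - gauss_cf) ** 2 + im ** 2) * gauss_cf)  # weight w(t) = exp(-t^2/2)
    integral = step * (sum(values) - 0.5 * (values[0] + values[-1]))
    return n * 2.0 * integral


def sigreg_sketch(embeddings, num_directions=16, seed=0):
    """ASSUMPTION: simplified stand-in for SIGReg (1-D Gaussianity test on random slices)."""
    rng = random.Random(seed)
    dim = len(embeddings[0])
    total = 0.0
    for _ in range(num_directions):
        direction = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(v * v for v in direction))
        direction = [v / norm for v in direction]
        projected = [sum(e * d for e, d in zip(emb, direction)) for emb in embeddings]
        total += epps_pulley_1d(projected)
    return total / num_directions


def invariance_loss(view_a, view_b):
    """ASSUMPTION: mean squared distance of each view to the per-image mean of both views."""
    total = 0.0
    for za, zb in zip(view_a, view_b):
        center = [(a + b) / 2 for a, b in zip(za, zb)]
        total += sum((a - c) ** 2 for a, c in zip(za, center))
        total += sum((b - c) ** 2 for b, c in zip(zb, center))
    return total / (2 * len(view_a))


def lejepa_loss(view_a, view_b, lam=0.02):
    """(1 - lam) * L_inv + lam * SIGReg, with the SIGReg sketch averaged over both views."""
    sig = 0.5 * (sigreg_sketch(view_a) + sigreg_sketch(view_b))
    return (1 - lam) * invariance_loss(view_a, view_b) + lam * sig


# Usage: an isotropic Gaussian cloud scores lower than a stretched copy of it.
rng = random.Random(1)
iso = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(256)]
stretched = [[3.0 * v for v in row] for row in iso]
print(sigreg_sketch(iso) < sigreg_sketch(stretched))  # True
```

With λ = 0.02 the invariance term dominates the total. The SIGReg-style term grows sharply for collapsed or stretched embedding clouds, which is the behavior the λ·SIGReg term is meant to penalize.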

Ablation table

| Variant | Linear probe top-1 | SIGReg↔acc Spearman ρ (n=6) | Loss smoothness (last 50%) | Notes |
|---|---|---|---|---|
| baseline | 0.3502 | -0.943 | 0.00270 | Vanilla LeJEPA, no text path. |
| film | 0.2941 | -1.000 | 0.00120 | MLP predictor with FiLM(text) on each hidden layer; (γ, β) initialized to (1, 0). |
| xattn | 0.2313 | -0.943 | 0.00282 | Patch tokens cross-attend to text; zero-initialized output projection. |
| wrong_text | 0.2303 | -0.886 | 0.00175 | xattn architecture with a permuted label-to-text map; probes whether text content matters. |
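The Notes column implies that both conditioning paths start out as identity maps. For FiLM this follows from (γ, β) = (1, 0), and for cross-attention from the zero-initialized output projection. The plain-Python sketch below illustrates that property. Two details are assumptions, since the report does not spell them out: γ and β are produced by a linear map from the text embedding whose weights start at zero (biases 1 and 0), and the cross-attention output is added residually to the patch tokens. The sketch uses a single head and no normalization, and all function names are hypothetical.

```python
import math


def matvec(matrix, vec):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


def zeros(rows, cols):
    """rows x cols matrix of zeros, used for the zero initializations."""
    return [[0.0] * cols for _ in range(rows)]


def film(hidden, text_emb, w_gamma, b_gamma, w_beta, b_beta):
    """FiLM: h' = gamma(t) * h + beta(t); ASSUMPTION: gamma, beta are linear in text embedding t."""
    gamma = [g + b for g, b in zip(matvec(w_gamma, text_emb), b_gamma)]
    beta = [g + b for g, b in zip(matvec(w_beta, text_emb), b_beta)]
    return [g * h + b for g, h, b in zip(gamma, hidden, beta)]


def cross_attend(patches, text_tokens, w_q, w_k, w_v, w_o):
    """Single-head cross-attention (patches -> text); ASSUMPTION: output added residually."""
    keys = [matvec(w_k, t) for t in text_tokens]
    values = [matvec(w_v, t) for t in text_tokens]
    scale = 1.0 / math.sqrt(len(keys[0]))
    out = []
    for x in patches:
        q = matvec(w_q, x)
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]  # numerically stable softmax
        norm = sum(weights)
        attended = [sum(w * v[j] for w, v in zip(weights, values)) / norm
                    for j in range(len(values[0]))]
        out.append([xi + oi for xi, oi in zip(x, matvec(w_o, attended))])
    return out


# FiLM with zero weights and biases (gamma, beta) = (1, 0) returns its input unchanged ...
hidden = [0.5, -1.0, 2.0, 0.0]
text = [0.2, 0.7, -0.3]
print(film(hidden, text, zeros(4, 3), [1.0] * 4, zeros(4, 3), [0.0] * 4) == hidden)  # True

# ... and so does residual cross-attention whose output projection w_o is all zeros.
patches = [[0.5, -1.0, 2.0, 0.0], [1.0, 1.0, 1.0, 1.0]]
text_tokens = [[0.2, 0.7, -0.3], [1.0, 0.0, 0.5]]
w_q = [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]]  # 2 x 4
w_k = [[0.5, -0.5, 0.1], [0.2, 0.2, 0.2]]           # 2 x 3
w_v = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]            # 2 x 3
print(cross_attend(patches, text_tokens, w_q, w_k, w_v, zeros(4, 2)) == patches)  # True
```

Starting from the identity means each conditioned predictor initially behaves like the unconditioned one, and the text path only gains influence as training moves γ, β or w_o away from their initial values.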

Per-variant findings

  • baseline: best probe top-1 0.3502 at epoch 29; final train LeJEPA loss 0.2282; final online-probe accuracy 0.2496.
  • film: best probe top-1 0.2941 at epoch 29; final train LeJEPA loss 0.0850; final online-probe accuracy 0.2028.
  • xattn: best probe top-1 0.2313 at epoch 29; final train LeJEPA loss 0.1334; final online-probe accuracy 0.1404.
  • wrong_text: best probe top-1 0.2303 at epoch 24; final train LeJEPA loss 0.1357; final online-probe accuracy 0.1456.

Research questions — direct answers

  1. Does text conditioning beat vanilla LeJEPA at matched epochs and λ? No. The baseline reaches 0.3502, 5.61 pp above the best conditioned variant (film, 0.2941).
  2. FiLM vs cross-attention: which wins, and by how much? FiLM wins: 0.2941 vs 0.2313 for xattn (Δ = +6.28 pp).
  3. Does text content matter (correct vs wrong vs none)? No (regularization-only effect). wrong_text 0.2303 ≈ xattn 0.2313 (Δ = -0.10 pp): at this scale, text content makes no measurable difference, and conditioning acts as a regularizer rather than a semantic signal.
    • baseline (no text): 0.3502; xattn: 0.2313; wrong_text: 0.2303
  4. Does the SIGReg↔accuracy correlation survive under conditioning? Yes: ρ stays strongly negative for every variant.
    • baseline: ρ = -0.943 (n=6)
    • film: ρ = -1.000 (n=6)
    • xattn: ρ = -0.943 (n=6)
    • wrong_text: ρ = -0.886 (n=6)
    • A strong negative ρ is the expected sign for LeJEPA: lower SIGReg ⇒ projections closer to an isotropic Gaussian ⇒ higher linear-probe accuracy. Conditioning preserves this relationship as long as |ρ| stays close to the baseline's 0.943.
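The ρ values above are Spearman rank correlations over six (SIGReg, accuracy) pairs. Without ties, ρ = 1 − 6·Σd²/(n(n² − 1)), which is 1 − Σd²/35 for n = 6. Σd² is always even for tie-free ranks, so ρ moves in steps of 2/35 ≈ 0.057, and −1.000, −0.943 and −0.886 are the three most negative values attainable. For example, one adjacent rank swap turns −1.000 into −0.943 (Σd² = 68), and two disjoint adjacent swaps give −0.886 (Σd² = 66). The following minimal pure-Python sketch computes ρ as the Pearson correlation of average ranks. The function names and the toy inputs are illustrative, and the report's six pairs are not reproduced here.

```python
def average_ranks(values):
    """1-based ranks; tied values share the mean of their sorted positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def spearman(xs, ys):
    """Spearman's rho: Pearson correlation of the two rank vectors."""
    rx, ry = average_ranks(xs), average_ranks(ys)
    n = len(xs)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    vx = sum((a - mx) ** 2 for a in rx)
    vy = sum((b - my) ** 2 for b in ry)
    return cov / (vx * vy) ** 0.5


# Synthetic example: SIGReg falls while accuracy rises over six checkpoints.
sigreg = [6.0, 5.0, 4.0, 3.0, 2.0, 1.0]
acc = [0.10, 0.15, 0.20, 0.25, 0.30, 0.35]
print(round(spearman(sigreg, acc), 3))  # -1.0
acc_one_swap = [0.10, 0.15, 0.20, 0.25, 0.35, 0.30]
print(round(spearman(sigreg, acc_one_swap), 3))  # -0.943
```

At n = 6 the gaps between variants' ρ values therefore amount to one or two rank swaps, so the comparison is coarse.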

Mechanism — why conditioning hurt at this scale

The conditioned predictors reach a 10–20× lower invariance loss than the baseline (L_inv at epoch 29: baseline 0.138, film 0.013, xattn 0.007), yet their linear-probe accuracy is strictly lower. Two independent observations explain this:

  1. Text shortcuts the invariance. Both views of an image share the same prompt, so the predictor can satisfy L_inv by routing information through the (identical-across-views) text channel rather than by learning view-invariant image features. The backbone has no reason to match views — the predictor does it for free. Low L_inv therefore stops implying a well-learned backbone.
  2. The projection space is text-controlled. t-SNE of the same 500 val images under three prompts ("the object", "the background", "the texture") shows a silhouette-by-prompt of +0.51 to +0.60 and a negative silhouette-by-image (−0.33 to −0.49) for every conditioned variant. Projections cluster by prompt, not by image. The predictor is effectively a text-indexed codebook; image content is secondary.
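The two silhouette numbers score the same set of points under two labelings. One groups points by prompt (three clusters). The other groups them by image, giving one cluster per image that holds that image's three prompt-conditioned projections. The minimal pure-Python sketch below uses the standard silhouette coefficient s = (b − a) / max(a, b) on a synthetic toy layout in which points cluster by prompt. It is not the report's evaluation code, the toy data is invented for illustration, and the report does not say whether silhouettes were computed on the t-SNE coordinates or on the raw projections.

```python
import math


def silhouette(points, labels):
    """Mean silhouette coefficient with Euclidean distance.

    For point i: a = mean distance to the other members of its own cluster,
    b = smallest mean distance to the members of any other cluster,
    s = (b - a) / max(a, b); singleton clusters score 0 by convention.
    """
    members = {}
    for i, lab in enumerate(labels):
        members.setdefault(lab, []).append(i)
    scores = []
    for i, lab in enumerate(labels):
        own = [j for j in members[lab] if j != i]
        if not own:
            scores.append(0.0)
            continue
        a = sum(math.dist(points[i], points[j]) for j in own) / len(own)
        b = min(
            sum(math.dist(points[i], points[j]) for j in idx) / len(idx)
            for other, idx in members.items()
            if other != lab
        )
        denom = max(a, b)
        scores.append((b - a) / denom if denom > 0 else 0.0)
    return sum(scores) / len(scores)


# Synthetic layout: the prompt moves a point far, the image only nudges it.
prompt_centers = [(0.0, 0.0), (10.0, 0.0), (0.0, 10.0)]           # 3 toy prompts
image_offsets = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]  # 4 toy images
points, by_prompt, by_image = [], [], []
for p, (cx, cy) in enumerate(prompt_centers):
    for i, (ox, oy) in enumerate(image_offsets):
        points.append((cx + ox, cy + oy))
        by_prompt.append(p)
        by_image.append(i)
print(silhouette(points, by_prompt) > 0, silhouette(points, by_image) < 0)  # True True
```

A positive silhouette by prompt together with a negative silhouette by image is the signature of a projection space organized by the conditioning text rather than by image content.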

The wrong_text control closes the loop: when each class's prompt is replaced by another class's prompt (one fixed permutation), accuracy is essentially unchanged (23.03% vs 23.13% for xattn). Because the text is then semantically uncoupled from the image, whatever the conditioned model gains above chance (1% on CIFAR-100) must come from conditioning on some constant-per-image signal acting as a regularizer, not from text semantics.
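The control only requires that every class be paired with some other class's prompt under one fixed mapping. In other words, the mapping must be a derangement of the class indices. The minimal sketch below assumes the "a photo of a <class>" template listed under the caveats and builds the derangement by reshuffling with a seeded RNG until no class maps to itself. The function names and seeds are hypothetical, not the report's code.

```python
import random


def fixed_derangement(num_classes, seed=0):
    """HYPOTHETICAL helper: permutation perm with perm[c] != c for every class c.

    Rejection sampling: reshuffle with a seeded RNG until no fixed points remain
    (about e ≈ 2.7 shuffles on average for many classes). Same seed, same mapping.
    """
    if num_classes < 2:
        raise ValueError("a derangement needs at least two classes")
    rng = random.Random(seed)
    perm = list(range(num_classes))
    while True:
        rng.shuffle(perm)
        if all(p != c for c, p in enumerate(perm)):
            return perm


def prompt_maps(class_names, seed=0):
    """Correct and wrong-label prompt maps keyed by class index (assumed template)."""
    correct = {c: f"a photo of a {name}" for c, name in enumerate(class_names)}
    perm = fixed_derangement(len(class_names), seed)
    wrong = {c: correct[perm[c]] for c in range(len(class_names))}
    return correct, wrong


names = ["apple", "bicycle", "castle", "dolphin"]
correct, wrong = prompt_maps(names, seed=42)
for c, name in enumerate(names):
    print(name, "->", wrong[c])
```

The wrong map uses exactly the same set of prompt strings as the correct one. The text distribution seen by the predictor is therefore unchanged, and only the class-to-text pairing is broken.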

The SIGReg↔accuracy correlation survives conditioning (|ρ| ≥ 0.89 for every variant), so the two-term LeJEPA loss remains a valid model-selection proxy even when the predictor is conditioned; it simply selects among worse solutions in the conditioned case.

Figures

  • Training loss and online probe accuracy per epoch, per variant.
  • SIGReg(val) vs linear probe: baseline.
  • SIGReg(val) vs linear probe: FiLM.
  • SIGReg(val) vs linear probe: cross-attention.
  • SIGReg(val) vs linear probe: wrong-label text.
  • t-SNE of projections under three conditioning prompts (FiLM).
  • t-SNE of projections under three conditioning prompts (xattn).
  • t-SNE of projections under three conditioning prompts (wrong_text).

Setup parity checklist

  • Backbone: ViT-Small/16 @ 128×128, drop_path_rate=0.1, no pretraining — identical across variants.
  • Data: CIFAR-100 train (50k) / test (10k); shared split, shared two-view augmentation.
  • Optimizer: AdamW, lr=1e-3, wd=5e-2, grad clip 1.0, bf16 autocast, cosine schedule + 2-epoch warmup.
  • λ: 0.02 for all variants.
  • Epochs & batch: 30 epochs × 512 batch × 2 views, identical.
  • Eval harness: same feature extractor, same StandardScaler + logistic regression probe on train/val, same SIGReg probe on ≤4k val projections.
  • Text tower: OpenCLIP ViT-B/32 (frozen), .eval(), no gradient. Only the text-projection layer inside the predictor is trainable on the text path.
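The optimizer line (AdamW, lr = 1e-3, cosine schedule with a 2-epoch warmup, 30 epochs) can be written as a per-step learning-rate function. This sketch rests on assumptions the report does not confirm. The schedule is stepped per iteration rather than per epoch, warmup is linear from 0, decay ends at a final learning rate of 0, and there are ceil(50,000 / 512) = 98 steps per epoch with the last partial batch kept.

```python
import math


def lr_at(step, total_steps, warmup_steps, base_lr=1e-3, final_lr=0.0):
    """ASSUMED schedule: linear warmup to base_lr, then cosine decay to final_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
    return final_lr + (base_lr - final_lr) * cosine


steps_per_epoch = math.ceil(50_000 / 512)  # 98, assuming the partial last batch is kept
total = 30 * steps_per_epoch
warmup = 2 * steps_per_epoch
for s in (0, warmup - 1, warmup, total // 2, total - 1):
    print(s, f"{lr_at(s, total, warmup):.2e}")
```

If the runs drop the last partial batch, steps_per_epoch becomes 97. The shape of the schedule is unchanged.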

Honest caveats

  • Single dataset (CIFAR-100, 32×32 images upsampled to 128×128). Findings may not transfer to ImageNet-scale pretraining.
  • Frozen text encoder: we use CLIP ViT-B/32 text features verbatim (only the predictor's text-projection layer adapts).
  • Ablation-scale: 30 epochs is short for SSL; absolute numbers are lower than published LeJEPA results.
  • Linear probe only: no fine-tuning, no transfer to downstream tasks, no k-NN or MLP probe.
  • Class-name prompts only: "a photo of a <class>". Richer prompts (scenes, attributes) are left for future work.
  • SIGReg on val uses single-view features rather than the two-view training population; it is still comparable across variants (same harness).

Suggested next experiments

  • Longer schedules (200+ epochs) and larger backbones (ViT-B/16).
  • Richer text prompts (captions, attributes) rather than class names alone.
  • Downstream transfer (CIFAR-10, STL-10, ImageNet-100 linear probe and full fine-tuning).
  • Unfreeze the text tower (or its last block) — does adaptation improve the conditioning signal?
  • Temporal extension à la LeWorldModel: condition the predictor on future-frame captions.
  • Per-patch prediction target (true JEPA-style masked prediction) rather than global projection, to match the instruction's 'target patch embeddings' framing more literally.