adipanda commited on
Commit
a5eba26
·
verified ·
1 Parent(s): 70d6b42

final write-up

Browse files
Files changed (1) hide show
  1. comparison.md +87 -0
comparison.md ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TC-LeJEPA ablation — results
2
+
3
+ Text-conditioned LeJEPA: baseline vs two conditioning architectures (FiLM, cross-attention) vs a wrong-label control. Backbone = ViT-Small/16 @ 128×128. Text tower = OpenCLIP ViT-B/32 (frozen). Dataset = CIFAR-100. Loss = `(1-λ)·L_inv + λ·SIGReg`, λ = 0.02. Artifacts: https://huggingface.co/adipanda/lejepa
4
+
5
+ ## Ablation table
6
+
7
+ | Variant | Linear probe top-1 | SIGReg↔acc Spearman | Loss smoothness (last 50%) | Notes |
8
+ |---|---|---|---|---|
9
+ | `baseline` | 0.3502 | -0.943 (n=6) | 0.00270 | vanilla LeJEPA, no text path. |
10
+ | `film` | 0.2941 | -1.000 (n=6) | 0.00120 | MLP predictor with FiLM(text) on each hidden layer; γ,β init=(1,0). |
11
+ | `xattn` | 0.2313 | -0.943 (n=6) | 0.00282 | Patch tokens cross-attend to text; zero-init output proj. |
12
+ | `wrong_text` | 0.2303 | -0.886 (n=6) | 0.00175 | xattn architecture with permuted label-text map — probes whether text **content** matters. |
13
+
14
+ ## Per-variant findings
15
+
16
+ - **`baseline`** — best probe top-1 **0.3502** at epoch 29; final train `lejepa`=0.2282, final online probe=0.2496.
17
+ - **`film`** — best probe top-1 **0.2941** at epoch 29; final train `lejepa`=0.0850, final online probe=0.2028.
18
+ - **`xattn`** — best probe top-1 **0.2313** at epoch 29; final train `lejepa`=0.1334, final online probe=0.1404.
19
+ - **`wrong_text`** — best probe top-1 **0.2303** at epoch 24; final train `lejepa`=0.1357, final online probe=0.1456.
20
+
21
+ ## Research questions — direct answers
22
+
23
+ 1. **Does text conditioning beat vanilla LeJEPA at matched epochs and λ?** — **No**, baseline at 0.3502 ≥ best conditioned (film, 0.2941).
24
+ 2. **FiLM vs cross-attention — which wins and by how much?** — **FiLM wins**: 0.2941 vs xattn 0.2313 (Δ = +6.28 pp).
25
+ 3. **Does text content matter (correct vs wrong vs none)?** — **No (regularization-only effect)**: wrong_text 0.2303 ≈ xattn 0.2313 (Δ = -0.10 pp). Text **content** does not matter at this scale — conditioning is acting as a regularizer, not a semantic signal.
26
+ - baseline (no text): 0.3502 xattn: 0.2313 wrong_text: 0.2303
27
+ 4. **Does the SIGReg↔accuracy correlation survive under conditioning?**
28
+ - `baseline`: ρ = -0.943 (n=6)
29
+ - `film`: ρ = -1.000 (n=6)
30
+ - `xattn`: ρ = -0.943 (n=6)
31
+ - `wrong_text`: ρ = -0.886 (n=6)
32
+ - A strong negative ρ is the *expected* sign for LeJEPA: lower SIGReg ⇒ closer to iso-Gaussian in projection space ⇒ higher linear probe. Conditioning preserves this whenever |ρ| stays close to baseline's.
33
+
34
+ ## Figures
35
+
36
+ ![Training loss and online probe accuracy per epoch, per variant.](figures/loss_curves.png)
37
+ *Training loss and online probe accuracy per epoch, per variant.*
38
+
39
+ ![SIGReg(val) vs linear probe — baseline.](figures/sigreg_vs_acc_baseline.png)
40
+ *SIGReg(val) vs linear probe — baseline.*
41
+
42
+ ![SIGReg(val) vs linear probe — FiLM.](figures/sigreg_vs_acc_film.png)
43
+ *SIGReg(val) vs linear probe — FiLM.*
44
+
45
+ ![SIGReg(val) vs linear probe — cross-attention.](figures/sigreg_vs_acc_xattn.png)
46
+ *SIGReg(val) vs linear probe — cross-attention.*
47
+
48
+ ![SIGReg(val) vs linear probe — wrong-label text.](figures/sigreg_vs_acc_wrong_text.png)
49
+ *SIGReg(val) vs linear probe — wrong-label text.*
50
+
51
+ ![t-SNE of projections under three conditioning prompts (FiLM).](figures/tsne_steer_film.png)
52
+ *t-SNE of projections under three conditioning prompts (FiLM).*
53
+
54
+ ![t-SNE of projections under three conditioning prompts (xattn).](figures/tsne_steer_xattn.png)
55
+ *t-SNE of projections under three conditioning prompts (xattn).*
56
+
57
+ ![t-SNE of projections under three conditioning prompts (wrong_text).](figures/tsne_steer_wrong_text.png)
58
+ *t-SNE of projections under three conditioning prompts (wrong_text).*
59
+
60
+
61
+ ## Setup parity checklist
62
+
63
+ - **Backbone**: ViT-Small/16 @ 128×128, `drop_path_rate=0.1`, no pretraining — identical across variants.
64
+ - **Data**: CIFAR-100 train (50k) / test (10k); shared split, shared two-view augmentation.
65
+ - **Optimizer**: AdamW, `lr=1e-3`, `wd=5e-2`, grad clip 1.0, bf16 autocast, cosine schedule + 2-epoch warmup.
66
+ - **λ**: 0.02 for all variants.
67
+ - **Epochs & batch**: 30 epochs × 512 batch × 2 views, identical.
68
+ - **Eval harness**: same feature extractor, same StandardScaler + logistic regression probe on train/val, same SIGReg probe on ≤4k val projections.
69
+ - **Text tower**: OpenCLIP ViT-B/32 (frozen), `.eval()`, no gradient. Only the text-projection layer inside the predictor is trainable on the text path.
70
+
71
+ ## Honest caveats
72
+
73
+ - **Single dataset** (CIFAR-100, 128×128 upsampled). Findings may not transfer to ImageNet-scale pretraining.
74
+ - **Frozen text encoder**: we use CLIP ViT-B/32 text features verbatim (only the predictor's text-projection layer adapts).
75
+ - **Ablation-scale**: 30 epochs is short for SSL; absolute numbers are lower than published LeJEPA results.
76
+ - **Linear probe only**: no fine-tuning, no transfer to downstream tasks, no k-NN or MLP probe.
77
+ - **Class-name prompts only**: `"a photo of a <class>"`. Richer prompts (scenes, attributes) are left for future work.
78
+ - **SIGReg on val** uses single-view features rather than the two-view training population; it is still comparable across variants (same harness).
79
+
80
+ ## Suggested next experiments
81
+
82
+ - Longer schedules (200+ epochs) and larger backbones (ViT-B/16).
83
+ - Richer text prompts (captions, attributes) rather than class names alone.
84
+ - Downstream transfer (CIFAR-10, STL-10, ImageNet-100 linear probe and full fine-tuning).
85
+ - **Unfreeze** the text tower (or its last block) — does adaptation improve the conditioning signal?
86
+ - Temporal extension à la **LeWorldModel**: condition the predictor on future-frame captions.
87
+ - Per-patch prediction target (true JEPA-style masked prediction) rather than global projection, to match the instruction's 'target patch embeddings' framing more literally.