Stage 4B: Larger Specialist with Cosine Loss
Tried the natural next knobs on Stage 4's specialist student: a roughly 5× larger model (15.67M vs. 3.27M parameters), a cosine similarity loss on the full 768-D pooled teacher output, and a longer schedule.
Setup
- Architecture: depth 8, embed 384, 6 heads, MLP ratio 4, patch 16 (15.67M parameters)
- Target: full 768-D pooled layernormed output from EUPE-ViT-B (not the 40-dim subset used in Stage 4)
- Loss: 1 - cosine_similarity(student_output, teacher_target)
- Schedule: 15 epochs × 117,266 COCO train images, batch 16, AdamW lr 5e-4, cosine schedule with 3% warmup
- Eval: apply Stage 0 classifier weights to the 40 classifier-relevant dims of the student's 768-D output; sweep threshold
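The loss and schedule listed above can be sketched in a few lines. This is a minimal plain-Python illustration, not the contents of train.py: the function names `cosine_loss` and `lr_at_step` are hypothetical, the linear warmup and decay to zero are assumptions (the setup only says "cosine schedule with 3% warmup"), and the step count assumes the last partial batch of each epoch is kept.

```python
import math


def cosine_loss(student, teacher, eps=1e-8):
    """Return 1 - cosine similarity between two equal-length vectors."""
    dot = sum(s * t for s, t in zip(student, teacher))
    norm_s = math.sqrt(sum(s * s for s in student))
    norm_t = math.sqrt(sum(t * t for t in teacher))
    return 1.0 - dot / max(norm_s * norm_t, eps)


def lr_at_step(step, total_steps, base_lr=5e-4, warmup_frac=0.03):
    """Linear warmup, then cosine decay to zero (assumed schedule shape)."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Step counts implied by the setup: 117,266 images, batch 16, 15 epochs.
steps_per_epoch = math.ceil(117_266 / 16)  # 7330, keeping the last partial batch
total_steps = 15 * steps_per_epoch         # 109950
```

In an actual training loop these would operate on batched tensors rather than Python lists; the sketch only pins down the arithmetic.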
Result
| Stage | Student params | Loss | F1 | Checkpoint |
|---|---|---|---|---|
| 4 | 3.27M | MSE on 40-D | 0.717 | ep3 |
| 4B | 15.67M | cosine on 768-D | 0.726 | ep10 (shipped) |
| 0 | 85.64M (ViT-B) | baseline | 0.889 | n/a |
Cosine loss converged in epoch 1 (0.072 → 0.061) and stayed flat through epoch 15. F1 peaked at 0.726 at epoch 10; epoch 15 drifted down to 0.723. The shipped student_final.safetensors is the epoch 10 checkpoint.
What this says
The student reproduces the teacher's pooled feature geometry well in aggregate (cosine ≈ 0.94 across 768 dims), but the 40 classifier-relevant dims are not all equal. Even a small average error on those specific axes destroys Stage 0's precision: every epoch shows precision around 0.57 and recall approaching 1.0, i.e., the student is consistently over-firing.
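To make the over-firing pattern concrete, here is a small sketch of a threshold sweep of the kind described in the eval setup. It assumes the Stage 0 score has the form sum(pos_dims) - sum(neg_dims); the index lists `pos_dims` and `neg_dims`, all function names, and the toy data are hypothetical and not taken from the actual eval code.

```python
def classifier_score(features, pos_dims, neg_dims):
    """Assumed score form: sum of positive dims minus sum of negative dims."""
    return sum(features[i] for i in pos_dims) - sum(features[i] for i in neg_dims)


def f1_from(precision, recall):
    """Harmonic mean of precision and recall (0.0 if both are zero)."""
    total = precision + recall
    return 2.0 * precision * recall / total if total > 0 else 0.0


def precision_recall_f1(scores, labels, threshold):
    """Predict positive when score >= threshold, then compute P, R, F1."""
    preds = [s >= threshold for s in scores]
    tp = sum(1 for p, y in zip(preds, labels) if p and y)
    fp = sum(1 for p, y in zip(preds, labels) if p and not y)
    fn = sum(1 for p, y in zip(preds, labels) if not p and y)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall, f1_from(precision, recall)


def sweep_threshold(scores, labels, thresholds):
    """Return (threshold, precision, recall, f1) for the best-F1 threshold."""
    best = None
    for t in thresholds:
        p, r, f1 = precision_recall_f1(scores, labels, t)
        if best is None or f1 > best[3]:
            best = (t, p, r, f1)
    return best


# Toy 3-D features standing in for the 40 classifier-relevant dims.
features = [[0.9, 0.4, 0.1], [0.7, 0.5, 0.2], [0.6, 0.3, 0.5], [0.2, 0.1, 0.4]]
labels = [True, True, False, False]
scores = [classifier_score(f, pos_dims=[0, 1], neg_dims=[2]) for f in features]

# A threshold that is too low fires on everything: perfect recall, poor precision.
over_firing = precision_recall_f1(scores, labels, threshold=-1.0)
best = sweep_threshold(scores, labels, thresholds=[-1.0, 0.0, 0.5, 1.5])
```

Plugging in the reported operating point, `f1_from(0.57, 1.0)` gives about 0.726, which is consistent with the Stage 4B F1 in the results table.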
Two candidate next iterations:
- Dim-weighted cosine: scale the cosine loss by a per-dim importance weight, with the 40 classifier-relevant dims weighted heavily. The student would then be pushed to reproduce those specific dims accurately, rather than matching the teacher with good average fidelity across all 768 dims.
- Direct classifier supervision: train the student to minimize |score_student - score_teacher|, where score = sum(pos_dims) - sum(neg_dims), instead of matching the 768-D vector.
Either is cheaper than further capacity/epoch scaling.
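Both candidate losses can be sketched as follows. This is one possible reading, not an implemented or tested design: the weighted variant folds the per-dim weights into the inner product and the norms, the 10:1 weight ratio is an arbitrary placeholder, and every name here (`dim_weights`, `weighted_cosine_loss`, `score_loss`, the index lists) is hypothetical.

```python
import math


def dim_weights(dim, important_dims, heavy=10.0, light=1.0):
    """Per-dim weights: `heavy` on the important dims, `light` elsewhere."""
    important = set(important_dims)
    return [heavy if i in important else light for i in range(dim)]


def weighted_cosine_loss(student, teacher, weights, eps=1e-8):
    """1 - cosine similarity under a weighted inner product."""
    dot = sum(w * s * t for w, s, t in zip(weights, student, teacher))
    norm_s = math.sqrt(sum(w * s * s for w, s in zip(weights, student)))
    norm_t = math.sqrt(sum(w * t * t for w, t in zip(weights, teacher)))
    return 1.0 - dot / max(norm_s * norm_t, eps)


def score(features, pos_dims, neg_dims):
    """Assumed classifier score: sum(pos_dims) - sum(neg_dims)."""
    return sum(features[i] for i in pos_dims) - sum(features[i] for i in neg_dims)


def score_loss(student, teacher, pos_dims, neg_dims):
    """Direct supervision on the scalar score: |score_student - score_teacher|."""
    return abs(score(student, pos_dims, neg_dims) - score(teacher, pos_dims, neg_dims))


# Toy example: the student is wrong only on dim 1, which is marked important.
teacher = [1.0, 0.0, 1.0]
student = [1.0, 1.0, 1.0]
uniform = weighted_cosine_loss(student, teacher, [1.0, 1.0, 1.0])
weighted = weighted_cosine_loss(student, teacher, dim_weights(3, [1]))
direct = score_loss(student, teacher, pos_dims=[0, 1], neg_dims=[2])
```

In this toy case the weighted loss is larger than the uniform one because the error sits on the heavily weighted dim. Note that any cosine-style loss is invariant to the overall scale of the student vector, so the weighted variant constrains direction rather than absolute magnitudes; the direct score loss does not have that property.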
Files
- student.py: architecture
- prepare_targets_768.py: builds the 768-D teacher target tensor from the ViT-B cache
- train.py: training loop
- student_ep{5,10,15}.safetensors: intermediate checkpoints
- student_final.safetensors: final weights
- training_log.json: per-epoch loss + F1