---
license: mit
language:
- en
library_name: mlx
tags:
- mlx
- differential-transformer
- apple-silicon
- research
- ablation
---

# diff-mlx: Stage 1 paired checkpoints (Differential Transformer vs vanilla MHA)

Final checkpoints from a small-scale, controlled, paired-init reproduction of the **Differential Transformer** (Ye et al., ICLR 2025; [arXiv 2410.05258](https://arxiv.org/abs/2410.05258)), implemented in MLX on Apple Silicon with custom Metal kernels.
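For readers new to the paper, here is a minimal single-head NumPy sketch of differential attention as the paper describes it: two causal softmax attention maps are computed from split query/key projections, and the second is subtracted from the first, scaled by a learned scalar λ. This is only a readable reference, not the repo's MLX/Metal implementation. It omits the per-head normalization and the `(1 - λ_init)` output scaling the paper applies after attention, and the function names (`diff_attention`, `reparam_lambda`, `lambda_init`) are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax; masked (-inf) entries become exactly 0.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def lambda_init(layer_idx):
    # Depth-dependent initial lambda from the paper; layer_idx is 1-based.
    return 0.8 - 0.6 * np.exp(-0.3 * (layer_idx - 1))


def reparam_lambda(lq1, lk1, lq2, lk2, lam_init):
    # lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + lambda_init
    return np.exp(lq1 @ lk1) - np.exp(lq2 @ lk2) + lam_init


def diff_attention(x, wq, wk, wv, lam, causal=True):
    """Single differential-attention head.

    x: (n, d_model); wq, wk, wv: (d_model, 2*d). Returns (n, 2*d).
    """
    n = x.shape[0]
    q, k, v = x @ wq, x @ wk, x @ wv
    d = q.shape[-1] // 2
    # Split queries and keys into two halves, one per attention map.
    q1, q2 = q[:, :d], q[:, d:]
    k1, k2 = k[:, :d], k[:, d:]
    s1 = q1 @ k1.T / np.sqrt(d)
    s2 = q2 @ k2.T / np.sqrt(d)
    if causal:
        # Block attention to future positions (strict upper triangle).
        mask = np.triu(np.ones((n, n), dtype=bool), k=1)
        s1 = np.where(mask, -np.inf, s1)
        s2 = np.where(mask, -np.inf, s2)
    # Differential map: first softmax minus lambda times the second.
    attn = softmax(s1) - lam * softmax(s2)
    return attn @ v
```

One sanity property follows directly from the formula: if both halves of the query and key projections are identical and λ = 1, the two maps cancel and the head outputs zeros.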

Code, full writeup, and methodology: **[github.com/guygrigsby/diff-mlx](https://github.com/guygrigsby/diff-mlx)**

## What's in here

| Path | Variant | Description |
|---|---|---|
| `diff/latest.safetensors` | Differential Attention | 162M params, 2.0B tokens, seed 0 |
| `vanilla/latest.safetensors` | Vanilla MHA baseline | 162M params, 2.0B tokens, seed 0 |

Each variant folder also has its `config.json` and training `metrics.jsonl`. The two models share a **byte-identical paired init** and identical data order, so the difference between them isolates the attention variant.
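The "last 1000-step mean" train loss reported further down can in principle be recomputed from a variant's `metrics.jsonl`. The sketch below assumes each line is a JSON object with a step counter and a train-loss field; the key names `step` and `train_loss` are hypothetical and should be adjusted to whatever the repo actually logs.

```python
import json


def last_n_step_mean(lines, n=1000, step_key="step", loss_key="train_loss"):
    """Mean of `loss_key` over records logged within the last `n` steps.

    `lines` is any iterable of JSON-lines strings (e.g. an open file).
    Key names are assumptions, not the repo's confirmed schema.
    """
    records = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        rec = json.loads(line)
        # Skip records (e.g. validation entries) that lack a train loss.
        if step_key in rec and loss_key in rec:
            records.append((rec[step_key], rec[loss_key]))
    if not records:
        raise ValueError("no records with both step and loss fields")
    last_step = max(step for step, _ in records)
    window = [loss for step, loss in records if step > last_step - n]
    return sum(window) / len(window)
```

Usage, under the same schema assumption: `with open("diff/metrics.jsonl") as f: print(last_n_step_mean(f))`.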

## Model

- Pre-norm LLaMA-style transformer: dim 768, 12 layers, interleaved RoPE, SwiGLU, RMSNorm, tied embeddings, vocab 100277 (cl100k_base).
- Context length 2048. bf16 mixed precision.
- Trained on a FineWeb-Edu sample, 2.0B tokens, effective batch 32, peak LR 4e-4, 1000-step warmup, on one M5 Max.
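A back-of-envelope check shows the stated sizes are mutually consistent. The parameter count below is for the vanilla layout and assumes a SwiGLU hidden size of 2048 (≈ 8/3 × 768, not stated in this card), bias-free projections, and one shared embedding matrix; with those assumptions it lands at about 162M. The token arithmetic assumes every sequence in the effective batch is a full 2048 tokens, which puts step 30000 at roughly 2.0B tokens.

```python
def transformer_params(dim, n_layers, vocab, ffn_hidden):
    """Approximate parameter count for a pre-norm LLaMA-style decoder.

    Assumptions: tied input/output embeddings, no biases,
    RMSNorm with a single weight vector.
    """
    embed = vocab * dim                # shared embedding / output head
    attn = 4 * dim * dim               # wq, wk, wv, wo
    ffn = 3 * dim * ffn_hidden         # SwiGLU gate, up, down projections
    norms = 2 * dim                    # two RMSNorms per block
    final_norm = dim
    return embed + n_layers * (attn + ffn + norms) + final_norm


def tokens_per_step(batch, ctx):
    # Assumes every sequence in the effective batch is packed to full length.
    return batch * ctx


n = transformer_params(dim=768, n_layers=12, vocab=100277, ffn_hidden=2048)
print(f"params ~ {n / 1e6:.1f}M")
print(f"tokens by step 30000 ~ {30_000 * tokens_per_step(32, 2048) / 1e9:.2f}B")
```

If the diff variant follows the paper, it adds only small per-layer extras (the λ vectors and per-head norm weights), so its count should stay close to this figure.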

## The headline (the interesting part)

On held-out validation, **vanilla edges out diff** at this scale, even though diff wins on train loss:

| metric | diff | vanilla | δ (diff − vanilla) |
|---|---|---|---|
| final train loss (last 1000-step mean) | 3.0414 | 3.1526 | −0.111 (diff lower) |
| held-out val (75M tok) @ step 30000 | 3.3616 | 3.3265 | +0.035 (vanilla lower) |
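Loss gaps are easier to read as perplexity ratios. Assuming the reported losses are mean per-token cross-entropy in nats (the usual convention, not stated explicitly here), `exp(diff - vanilla)` is the ratio of the two perplexities: about 0.895 on train (diff roughly 10% lower) and about 1.036 on validation (diff roughly 3.6% higher).

```python
import math

# Numbers copied from the table above.
results = {
    "train": {"diff": 3.0414, "vanilla": 3.1526},
    "val": {"diff": 3.3616, "vanilla": 3.3265},
}


def delta_and_ppl_ratio(diff_loss, vanilla_loss):
    """Return (diff - vanilla) loss and ppl(diff) / ppl(vanilla)."""
    delta = diff_loss - vanilla_loss
    # exp(a) / exp(b) == exp(a - b), so the ratio needs only the delta.
    return delta, math.exp(delta)


for split, r in results.items():
    d, ratio = delta_and_ppl_ratio(r["diff"], r["vanilla"])
    print(f"{split}: delta={d:+.4f}, ppl(diff)/ppl(vanilla)={ratio:.3f}")
```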

Diff's train-loss lead is memorization: its val loss *rose* over the final leg while train loss kept falling. A position-binned eval put vanilla uniformly ahead across the whole 2048-token window, with no widening of diff's deficit at later positions, so the architecture's long-context edge didn't show up here either.
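A position-binned eval of this kind averages per-token loss within contiguous slices of the context window, then compares the two models bin by bin. The sketch below is one straightforward way to do that from per-token losses; it is not the repo's eval code, and the function names and bin layout (e.g. 8 bins of 256 tokens for a 2048 window) are hypothetical.

```python
def position_binned_means(per_token_losses, num_bins):
    """Mean loss per contiguous position bin.

    per_token_losses: list of sequences, each a list of per-position
    losses of the same length T (T must be divisible by num_bins).
    """
    seq_len = len(per_token_losses[0])
    if seq_len % num_bins:
        raise ValueError("sequence length must be divisible by num_bins")
    width = seq_len // num_bins
    sums = [0.0] * num_bins
    counts = [0] * num_bins
    for seq in per_token_losses:
        if len(seq) != seq_len:
            raise ValueError("all sequences must share one length")
        for pos, loss in enumerate(seq):
            b = pos // width  # bin index for this position
            sums[b] += loss
            counts[b] += 1
    return [s / c for s, c in zip(sums, counts)]


def binned_gap(diff_losses, vanilla_losses, num_bins):
    """Per-bin (diff - vanilla) gap; positive means vanilla is ahead."""
    d = position_binned_means(diff_losses, num_bins)
    v = position_binned_means(vanilla_losses, num_bins)
    return [a - b for a, b in zip(d, v)]
```

A long-context advantage for diff would show up as this gap shrinking (or turning negative) in the later bins.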

This sits **three orders of magnitude below** the paper's 3B-param / 1T-token setup, so it refutes nothing about the paper. It's an honest negative for this small-scale, short-context, single-seed regime. Full discussion in the repo writeup.
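How large the scale gap looks depends on the axis. The sketch compares parameters, training tokens, and approximate training compute using the common `6 * N * D` FLOPs rule of thumb (an approximation, not a measured figure): the gap is about 18.5× in parameters, 500× in tokens, and roughly 9,300× in compute.

```python
import math

PAPER_PARAMS, PAPER_TOKENS = 3e9, 1e12  # paper's 3B / 1T setup
OURS_PARAMS, OURS_TOKENS = 162e6, 2.0e9  # these checkpoints


def scale_gap():
    param_ratio = PAPER_PARAMS / OURS_PARAMS
    token_ratio = PAPER_TOKENS / OURS_TOKENS
    # Training FLOPs ~ 6 * params * tokens; the constant cancels in the ratio.
    compute_ratio = (6 * PAPER_PARAMS * PAPER_TOKENS) / (6 * OURS_PARAMS * OURS_TOKENS)
    return param_ratio, token_ratio, compute_ratio


for name, ratio in zip(("params", "tokens", "compute"), scale_gap()):
    print(f"{name}: {ratio:,.1f}x  (~{math.log10(ratio):.1f} orders of magnitude)")
```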

## Loading

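```python
import mlx.core as mx
# On a .safetensors file, mx.load returns a dict: parameter name -> mx.array.
params = mx.load("diff/latest.safetensors")  # or vanilla/latest.safetensors
# Quick sanity check: tensor count and total parameter count.
n_params = sum(v.size for v in params.values())
print(f"{len(params)} tensors, {n_params:,} parameters")
```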

Model construction lives in the repo (`model.py`, `Transformer(cfg, variant="diff"|"vanilla")`).

## License

MIT.