Upload README.md with huggingface_hub
README.md
ADDED
@@ -0,0 +1,59 @@
---
license: mit
language:
- en
library_name: mlx
tags:
- mlx
- differential-transformer
- apple-silicon
- research
- ablation
---

# diff-mlx: Stage 1 paired checkpoints (Differential Transformer vs vanilla MHA)

Final checkpoints from a small-scale, controlled, paired-init reproduction of the **Differential Transformer** (Ye et al., ICLR 2025; [arXiv 2410.05258](https://arxiv.org/abs/2410.05258)), implemented in MLX on Apple Silicon with custom Metal kernels.

Code, full writeup, and methodology: **[github.com/guygrigsby/diff-mlx](https://github.com/guygrigsby/diff-mlx)**

## What's in here

| Path | Variant | Description |
|---|---|---|
| `diff/latest.safetensors` | Differential Attention | 162M params, 2.0B tokens, seed 0 |
| `vanilla/latest.safetensors` | Vanilla MHA baseline | 162M params, 2.0B tokens, seed 0 |

Each variant folder also includes its `config.json` and training `metrics.jsonl`. The two models share a **byte-identical paired init** and identical data order, so their difference isolates the attention variant.
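
A minimal sketch of fetching one variant's files with `huggingface_hub`, assuming the per-variant paths are `<variant>/latest.safetensors`, `<variant>/config.json`, and `<variant>/metrics.jsonl` as the table and paragraph above suggest. The helper names `variant_files` and `download_variant` are hypothetical and not part of the repo.

```python
from pathlib import Path

REPO_ID = "guygrigsby/diff-mlx"
VARIANTS = ("diff", "vanilla")
# Assumption: file names inferred from the README table and paragraph.
FILES = ("latest.safetensors", "config.json", "metrics.jsonl")


def variant_files(variant: str) -> list[str]:
    """Return the repo-relative paths for one variant folder."""
    if variant not in VARIANTS:
        raise ValueError(f"unknown variant: {variant!r}")
    return [f"{variant}/{name}" for name in FILES]


def download_variant(variant: str, local_dir: str = "diff-mlx") -> list[Path]:
    """Fetch one variant's files from the Hub (needs network access)."""
    # Imported lazily so variant_files works without huggingface_hub installed.
    from huggingface_hub import hf_hub_download

    return [
        Path(hf_hub_download(repo_id=REPO_ID, filename=name, local_dir=local_dir))
        for name in variant_files(variant)
    ]
```

Calling `download_variant("diff")` and `download_variant("vanilla")` should leave both folders side by side under `diff-mlx/`, which keeps the paired checkpoints easy to compare.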

## Model

- Pre-norm LLaMA-style transformer: dim 768, 12 layers, RoPE (interleaved), SwiGLU, RMSNorm, tied embeddings, vocab 100277 (cl100k_base).
- Context length 2048. bf16 mixed precision.
- Trained on a FineWeb-Edu sample, 2.0B tokens, effective batch 32, peak LR 4e-4, 1000-step warmup, single M5 Max.
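
As a rough sanity check on the figures above, the parameter count and step count can be estimated from the listed hyperparameters. This sketch assumes a SwiGLU hidden width of 2048 (8/3 × dim), bias-free projections, and an effective batch counted in 2048-token sequences. None of these are stated in this README, and the estimate ignores the small extra parameters (such as the λ terms) that differential attention adds.

```python
# Architecture and training numbers quoted in the list above.
DIM = 768
LAYERS = 12
VOCAB = 100_277
CONTEXT = 2048
EFFECTIVE_BATCH = 32            # assumption: counted in sequences
TOTAL_TOKENS = 2_000_000_000

# Assumption (not stated in the source): SwiGLU hidden width of 8/3 * dim.
FFN_HIDDEN = 2048


def estimate_params() -> int:
    """Approximate parameter count for the vanilla variant."""
    embedding = VOCAB * DIM            # tied embeddings, counted once
    attention = 4 * DIM * DIM          # q, k, v, o projections, no biases
    swiglu = 3 * DIM * FFN_HIDDEN      # gate, up, down projections
    norms = 2 * DIM                    # two RMSNorm gains per layer
    per_layer = attention + swiglu + norms
    return embedding + LAYERS * per_layer + DIM   # + final RMSNorm gain


def tokens_per_step() -> int:
    """Tokens consumed by one optimizer step."""
    return EFFECTIVE_BATCH * CONTEXT


def approx_steps() -> int:
    """Whole optimizer steps needed to consume the token budget."""
    return TOTAL_TOKENS // tokens_per_step()
```

Under these assumptions the estimate lands at about 162M parameters, and 65,536 tokens per step gives roughly 30.5k steps for 2.0B tokens, which is consistent with a final evaluation at step 30000.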

## Headline result (the interesting part)

On held-out validation, **vanilla edges out diff** at this scale, despite diff winning on train loss:

| metric | diff | vanilla | δ (diff − vanilla) |
|---|---|---|---|
| Final train loss (last 1000-step mean) | 3.0414 | 3.1526 | −0.111 (diff lower) |
| Held-out val (75M tok) @ step 30000 | 3.3616 | 3.3265 | +0.035 (vanilla lower) |

Diff's train-loss advantage is memorization: its val loss *rose* over the final leg while train loss kept falling. A position-binned eval found vanilla uniformly better across the whole 2048-token window, with no widening of diff's deficit at later positions, so the architecture's signature long-context advantage did not appear here either.
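
The gap between train and validation loss can be read directly off the table. The sketch below recomputes it from the reported values. It treats the two rows as comparable even though the train figure is a 1000-step mean and the validation figure is a single evaluation, and it assumes the losses are in nats per token when converting to a perplexity ratio.

```python
import math

# Values copied from the results table above.
TRAIN = {"diff": 3.0414, "vanilla": 3.1526}   # last 1000-step mean
VAL = {"diff": 3.3616, "vanilla": 3.3265}     # held-out, step 30000


def gap(variant: str) -> float:
    """Validation loss minus train loss for one variant."""
    return VAL[variant] - TRAIN[variant]


def perplexity_ratio() -> float:
    """exp(val_diff - val_vanilla); assumes losses are in nats per token."""
    return math.exp(VAL["diff"] - VAL["vanilla"])
```

Here `gap("diff")` is about 0.320 against about 0.174 for `gap("vanilla")`, so diff's train/val gap is nearly twice as large, and the validation difference corresponds to roughly 3.6% higher perplexity for diff.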

This run is far smaller than the paper's 3B-param / 1T-token setup (about 19× fewer parameters and 500× fewer tokens), so it refutes nothing about the paper. It is an honest negative result for this small-scale, short-context, single-seed regime. See the repo writeup for the full discussion.
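
To put the scale difference in terms of training compute, a common rule of thumb is about 6 FLOPs per parameter per token. The sketch below applies that approximation to the two setups. It is a back-of-the-envelope estimate, not a measurement from either run, and the function names are illustrative.

```python
import math

# Figures from the text: this run vs. the paper's 3B-param / 1T-token setup.
THIS_RUN = {"params": 162e6, "tokens": 2.0e9}
PAPER = {"params": 3e9, "tokens": 1e12}


def train_flops(params: float, tokens: float) -> float:
    """Rule-of-thumb training compute: ~6 FLOPs per parameter per token."""
    return 6.0 * params * tokens


def scale_gap() -> dict[str, float]:
    """Ratios of the paper's setup to this run."""
    compute = train_flops(**PAPER) / train_flops(**THIS_RUN)
    return {
        "params_ratio": PAPER["params"] / THIS_RUN["params"],
        "tokens_ratio": PAPER["tokens"] / THIS_RUN["tokens"],
        "compute_ratio": compute,
        "compute_orders_of_magnitude": math.log10(compute),
    }
```

Under this approximation the paper's setup uses on the order of 9,000× more training compute, which is close to four orders of magnitude.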

## Loading

```python
import mlx.core as mx

# mx.load on a .safetensors file returns a dict of parameter name -> mx.array
params = mx.load("diff/latest.safetensors")  # or vanilla/latest.safetensors
```

Model construction lives in the repo (`model.py`, `Transformer(cfg, variant="diff"|"vanilla")`).
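
Before wiring the weights into the repo's `Transformer`, it can help to check what a checkpoint actually contains. The helper below is a hypothetical sketch, not part of the repo. It only assumes that each loaded value exposes a `.shape` attribute, as `mx.array` does.

```python
import math
from collections.abc import Mapping


def summarize_params(params: Mapping) -> dict:
    """Count tensors and elements in a name -> array mapping.

    Relies only on each value exposing `.shape`.
    """
    sizes = {name: math.prod(arr.shape) for name, arr in params.items()}
    largest = max(sizes, key=sizes.get) if sizes else None
    return {
        "tensors": len(sizes),
        "elements": sum(sizes.values()),
        "largest": largest,
    }


# Usage on a machine with MLX installed (not executed here):
#   import mlx.core as mx
#   print(summarize_params(mx.load("diff/latest.safetensors")))
```

Running this on both checkpoints and comparing the summaries is a quick way to confirm the two variants are the same size before comparing their behaviour.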

## License

MIT.