# Download the model from the Hub
pip install "huggingface_hub[hf_xet]"

huggingface-cli download --local-dir diff-mlx guygrigsby/diff-mlx

diff-mlx: Stage 1 paired checkpoints (Differential Transformer vs vanilla MHA)

Final checkpoints from a small-scale, controlled, paired-init reproduction of the Differential Transformer (Ye et al., ICLR 2025; arXiv 2410.05258), implemented in MLX on Apple Silicon with custom Metal kernels.

Code, full writeup, and methodology: github.com/guygrigsby/diff-mlx

What's in here

Path | Variant | Description
--- | --- | ---
diff/latest.safetensors | Differential Attention | 162M params, 2.0B tokens, seed 0
vanilla/latest.safetensors | Vanilla MHA baseline | 162M params, 2.0B tokens, seed 0

Each variant folder also has its config.json and training metrics.jsonl. The two models share a byte-identical paired init and identical data order, so the difference between them isolates the attention variant.

Model

  • Pre-norm LLaMA-style transformer: dim 768, 12 layers, interleaved RoPE, SwiGLU, RMSNorm, tied embeddings, vocab 100277 (cl100k_base).
  • Context length 2048. bf16 mixed precision.
  • Trained on a FineWeb-Edu sample, 2.0B tokens, effective batch 32, peak LR 4e-4, 1000-step warmup, on one M5 Max.
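As a sanity check on the numbers above, here is a back-of-the-envelope sketch of the parameter count and token budget. It is not taken from the repo. The SwiGLU hidden width (2048), the bias-free projections, the reading of "effective batch 32" as 32 sequences per optimizer step, and step 30000 as the end of the run are all assumptions made for illustration. The estimate also ignores any small extra parameters specific to the differential variant.

```python
# Settings stated in the model card.
DIM = 768
N_LAYERS = 12
VOCAB = 100277
CONTEXT = 2048

# ASSUMPTIONS (not stated in the card):
FFN_HIDDEN = 2048        # SwiGLU hidden width, guessed as 8/3 * DIM
EFFECTIVE_BATCH = 32     # read as sequences per optimizer step
FINAL_STEP = 30000       # step of the last reported validation eval


def estimate_params(dim, n_layers, vocab, ffn_hidden):
    """Rough parameter count for a bias-free LLaMA-style stack with tied embeddings."""
    embed = vocab * dim              # tied: the embedding doubles as the output head
    attn = 4 * dim * dim             # Q, K, V and output projections
    mlp = 3 * dim * ffn_hidden       # gate, up and down projections
    norms = 2 * dim                  # two RMSNorm gain vectors per block
    final_norm = dim
    return embed + n_layers * (attn + mlp + norms) + final_norm


def tokens_seen(steps, batch, context):
    """Tokens processed if every step consumes `batch` full-length sequences."""
    return steps * batch * context


n_params = estimate_params(DIM, N_LAYERS, VOCAB, FFN_HIDDEN)
n_tokens = tokens_seen(FINAL_STEP, EFFECTIVE_BATCH, CONTEXT)
print(f"params ~ {n_params / 1e6:.1f}M")
print(f"tokens ~ {n_tokens / 1e9:.2f}B")
```

Under these assumptions the estimate comes out close to the 162M params and 2.0B tokens quoted in the card. Nearly half of the parameters sit in the tied embedding table because of the large cl100k_base vocabulary.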

The headline (the interesting part)

On held-out validation, vanilla edges out diff at this scale, even though diff wins on train loss:

metric | diff | vanilla | δ (diff − vanilla)
--- | --- | --- | ---
final train loss (last 1000-step mean) | 3.0414 | 3.1526 | −0.111 (diff lower)
held-out val (75M tok) @ step 30000 | 3.3616 | 3.3265 | +0.035 (vanilla lower)
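The "last 1000-step mean" in the train-loss row can be recomputed from each variant's metrics.jsonl. The sketch below shows one way to do it. The field names `step` and `loss` are hypothetical, since the card does not document the log schema, so check the file and adjust the keys before relying on it.

```python
import json


def trailing_mean_loss(path, window=1000, step_key="step", loss_key="loss"):
    """Mean of `loss_key` over the last `window` steps logged in a JSONL file.

    `step_key` and `loss_key` are assumed field names, not confirmed by the repo.
    """
    rows = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            # Skip records (e.g. eval-only rows) that lack either field.
            if step_key in record and loss_key in record:
                rows.append((record[step_key], record[loss_key]))
    if not rows:
        raise ValueError(f"no records with {step_key!r} and {loss_key!r} in {path}")
    last_step = max(step for step, _ in rows)
    recent = [loss for step, loss in rows if step > last_step - window]
    return sum(recent) / len(recent)


# Hypothetical usage after downloading the repo:
#   print(trailing_mean_loss("diff/metrics.jsonl"))
#   print(trailing_mean_loss("vanilla/metrics.jsonl"))
```

Averaging over a trailing window instead of quoting the final step smooths out batch-to-batch noise, which matters when the gap between the two runs is only about 0.1.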

Diff's train-loss lead is memorization: its val loss rose over the final leg while train loss kept falling. A position-binned eval put vanilla uniformly ahead across the whole 2048-token window, with no widening of diff's deficit at later positions, so the architecture's long-context edge didn't show up here either.
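A position-binned eval of the kind described above can be sketched as follows. This is an illustrative reconstruction, not the repo's evaluation code: the function names, the choice of 8 bins, and the input layout (one list of per-token losses per evaluated sequence) are all assumptions.

```python
def position_binned_loss(per_token_loss, n_bins=8):
    """Average per-token loss within equal-width bins of sequence position.

    per_token_loss: list of sequences, each a list of losses of the same length.
    """
    seq_len = len(per_token_loss[0])
    if seq_len % n_bins != 0:
        raise ValueError("sequence length must be divisible by n_bins")
    width = seq_len // n_bins
    sums = [0.0] * n_bins
    counts = [0] * n_bins
    for seq in per_token_loss:
        if len(seq) != seq_len:
            raise ValueError("all sequences must have the same length")
        for pos, loss in enumerate(seq):
            b = pos // width
            sums[b] += loss
            counts[b] += 1
    return [s / c for s, c in zip(sums, counts)]


def binned_gap(diff_losses, vanilla_losses, n_bins=8):
    """Per-bin loss gap (diff minus vanilla); positive means vanilla is ahead."""
    d = position_binned_loss(diff_losses, n_bins)
    v = position_binned_loss(vanilla_losses, n_bins)
    return [a - b for a, b in zip(d, v)]


# Toy example with 4-token sequences and 2 bins.
print(binned_gap([[2.0, 2.0, 1.0, 1.0]], [[1.5, 1.5, 0.5, 0.5]], n_bins=2))
```

In this layout, the result reported in the card corresponds to a gap that is positive in every bin and shows no trend across bins. A long-context advantage for diff would instead appear as a gap that shrinks or turns negative in the later bins.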

This run uses roughly 18x fewer parameters and 500x fewer training tokens than the paper's 3B-param / 1T-token setup, so it refutes nothing about the paper. It's an honest negative for this small-scale, short-context, single-seed regime. Full discussion in the repo writeup.

Loading

import mlx.core as mx
params = mx.load("diff/latest.safetensors")  # or vanilla/latest.safetensors

Model construction lives in the repo (model.py, Transformer(cfg, variant="diff"|"vanilla")).
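Before wiring the weights into the repo's Transformer, it can help to inspect what the checkpoint contains. The helper below is a small sketch with a hypothetical name (`summarize_checkpoint`). It relies only on each array exposing `.size` and `.dtype`, which MLX arrays do, and makes no assumption about the parameter names stored in the file.

```python
def summarize_checkpoint(params):
    """Return (total element count, {dtype name: element count}).

    params: mapping of parameter name -> array-like with .size and .dtype,
    such as the dict returned by mx.load on a .safetensors file.
    """
    total = 0
    by_dtype = {}
    for _name, arr in params.items():
        total += arr.size
        key = str(arr.dtype)
        by_dtype[key] = by_dtype.get(key, 0) + arr.size
    return total, by_dtype


# Usage on a machine with MLX installed, after downloading the repo:
#   import mlx.core as mx
#   total, by_dtype = summarize_checkpoint(mx.load("diff/latest.safetensors"))
#   print(f"{total / 1e6:.1f}M parameters", by_dtype)
```

If the checkpoint stores the tied embedding once, the total should land near the 162M figure in the table above. A noticeably larger count would suggest the output head was saved separately.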

License

MIT.
