---
license: mit
language:
  - en
library_name: mlx
tags:
  - mlx
  - differential-transformer
  - apple-silicon
  - research
  - ablation
---

diff-mlx: Stage 1 paired checkpoints (Differential Transformer vs vanilla MHA)

Final checkpoints from a small-scale, controlled, paired-init reproduction of the Differential Transformer (Ye et al., ICLR 2025; arXiv 2410.05258), implemented in MLX on Apple Silicon with custom Metal kernels.

Code, full writeup, and methodology: github.com/guygrigsby/diff-mlx
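For reference, the paper's differential attention splits the query and key projections into two groups and subtracts one softmax attention map from the other, scaled by a learnable scalar λ, for each head. The equations below are a summary of the paper's definition, not a transcription of this repo's Metal kernels. Those kernels may differ in details such as the per-head normalization the paper applies afterward.

$$
[Q_1; Q_2] = X W^Q, \qquad [K_1; K_2] = X W^K, \qquad V = X W^V
$$

$$
\mathrm{DiffAttn}(X) = \left( \mathrm{softmax}\!\left(\frac{Q_1 K_1^\top}{\sqrt{d}}\right) - \lambda\, \mathrm{softmax}\!\left(\frac{Q_2 K_2^\top}{\sqrt{d}}\right) \right) V
$$

$$
\lambda = \exp(\lambda_{q_1} \cdot \lambda_{k_1}) - \exp(\lambda_{q_2} \cdot \lambda_{k_2}) + \lambda_{\text{init}}
$$

In the paper, $\lambda_{q_i}, \lambda_{k_i}$ are learnable vectors, and $\lambda_{\text{init}} = 0.8 - 0.6\exp(-0.3(l - 1))$ for layer $l$.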

What's in here

| Path | Variant | Description |
|---|---|---|
| `diff/latest.safetensors` | Differential Attention | 162M params, 2.0B tokens, seed 0 |
| `vanilla/latest.safetensors` | Vanilla MHA baseline | 162M params, 2.0B tokens, seed 0 |

Each variant folder also contains its `config.json` and a training log, `metrics.jsonl`. The two models share a byte-identical paired init and identical data order, so any difference between them isolates the effect of the attention variant.
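A summary statistic like the "last 1000-step mean" train loss reported below can be recomputed from a JSONL log. The following is a minimal sketch that assumes, hypothetically, one JSON object per line with an integer `"step"` field and the loss under a `"train_loss"` key. Check the real field names in `metrics.jsonl` before relying on it.

```python
import json


def last_window_mean(lines, key="train_loss", window=1000):
    """Mean of `key` over the last `window` steps of a JSONL metrics log.

    Assumes (hypothetically) one JSON object per line with an integer
    "step" field; records that lack `key` (e.g. eval-only lines) are skipped.
    """
    records = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        rec = json.loads(line)
        if key in rec:
            records.append((rec["step"], rec[key]))
    last_step = max(step for step, _ in records)
    # Keep only the records from the final `window` steps.
    tail = [value for step, value in records if step > last_step - window]
    return sum(tail) / len(tail)
```

Because the function accepts any iterable of lines, an open file works directly: `with open("diff/metrics.jsonl") as f: print(last_window_mean(f))`.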

Model

  • Pre-norm LLaMA-style transformer: dim 768, 12 layers, interleaved RoPE, SwiGLU, RMSNorm, tied embeddings, vocab 100277 (cl100k_base).
  • Context length 2048. bf16 mixed precision.
  • Trained on a FineWeb-Edu sample, 2.0B tokens, effective batch 32, peak LR 4e-4, 1000-step warmup, on one M5 Max.
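The following back-of-envelope check is consistent with the stated ~162M parameters and with roughly 30.5k optimizer steps. It rests on two assumptions that the card does not state. First, the SwiGLU hidden size is 2048 (LLaMA's 8/3 × dim convention), with no bias terms. Second, "effective batch 32" means 32 sequences of 2048 tokens. The count models a vanilla block only and ignores any extra parameters in the diff attention.

```python
def llama_param_count(dim=768, n_layers=12, vocab=100277, ffn_hidden=2048):
    embed = vocab * dim          # tied: the output head reuses this matrix
    attn = 4 * dim * dim         # Wq, Wk, Wv, Wo (no biases, assumed)
    ffn = 3 * dim * ffn_hidden   # SwiGLU gate, up, and down projections
    norms = 2 * dim              # pre-attention and pre-FFN RMSNorm weights
    return embed + n_layers * (attn + ffn + norms) + dim  # + final RMSNorm


seq_len, batch = 2048, 32
tokens_per_step = seq_len * batch          # 65,536 tokens per optimizer step
total_steps = 2.0e9 / tokens_per_step      # ~30,518 steps for 2.0B tokens

print(f"params: {llama_param_count():,}")  # params: 161,966,592
print(f"tokens/step: {tokens_per_step:,}, steps: {total_steps:,.0f}")
```

Under these assumptions the tied embedding alone accounts for about 77M of the parameters, nearly half the model. At this scale that share is large.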

The headline (the interesting part)

On held-out validation, vanilla edges out diff at this scale, even though diff wins on train loss:

| Metric | diff | vanilla | δ (diff − vanilla) |
|---|---|---|---|
| final train loss (last 1000-step mean) | 3.0414 | 3.1526 | −0.111 (diff lower) |
| held-out val (75M tokens) @ step 30000 | 3.3616 | 3.3265 | +0.035 (vanilla lower) |

Diff's train-loss lead is memorization: its val loss rose over the final leg while train loss kept falling. A position-binned eval put vanilla uniformly ahead across the whole 2048-token window, with no widening of diff's deficit at later positions, so the architecture's long-context edge didn't show up here either.
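A position-binned eval of this kind can be sketched as follows. The sketch assumes per-token validation losses have already been computed for each 2048-token sequence. The input format and the 256-token default bin width are illustrative choices, not taken from the repo. The sign convention matches the δ column above, so a positive gap means vanilla is lower in that bin.

```python
from collections import defaultdict


def position_binned_loss(token_losses, bin_width=256):
    """Mean per-token loss within each position bin of the context window.

    token_losses: iterable of sequences, each a list of per-token losses
    ordered by position (0 .. seq_len-1). Hypothetical input format.
    Returns a list of (bin_start, bin_end, mean_loss) in position order.
    """
    sums = defaultdict(float)
    counts = defaultdict(int)
    for seq in token_losses:
        for pos, loss in enumerate(seq):
            b = pos // bin_width
            sums[b] += loss
            counts[b] += 1
    return [(b * bin_width, (b + 1) * bin_width, sums[b] / counts[b]) for b in sorted(sums)]


def binned_gap(diff_losses, vanilla_losses, bin_width=256):
    """Per-bin delta (diff - vanilla); positive means vanilla is lower there."""
    d = position_binned_loss(diff_losses, bin_width)
    v = position_binned_loss(vanilla_losses, bin_width)
    return [(start, end, dl - vl) for (start, end, dl), (_, _, vl) in zip(d, v)]
```

The result described above corresponds to a gap that is positive in every bin and shows no trend toward later positions.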

This sits far below the paper's 3B-param / 1T-token setup (about 18.5x fewer parameters and 500x fewer training tokens), so it refutes nothing about the paper. It's an honest negative for this small-scale, short-context, single-seed regime. Full discussion in the repo writeup.
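One rough way to put the scale gap in compute terms is the common ~6·N·D approximation for dense-transformer training FLOPs, where N is the parameter count and D is the number of training tokens. This is a standard estimate, not a measurement of either run.

```python
import math


def train_flops(n_params, n_tokens):
    # Common ~6*N*D approximation for dense-transformer training FLOPs.
    return 6 * n_params * n_tokens


ours = train_flops(162e6, 2.0e9)   # this card: 162M params, 2.0B tokens
paper = train_flops(3e9, 1e12)     # paper: 3B params, 1T tokens

print(f"params ratio: {3e9 / 162e6:.1f}x")      # 18.5x
print(f"token ratio:  {1e12 / 2.0e9:.0f}x")     # 500x
print(f"compute gap:  {paper / ours:,.0f}x")    # 9,259x
print(f"log10 gap:    {math.log10(paper / ours):.1f}")  # 4.0
```

By this estimate the token gap alone is nearly three orders of magnitude, and the combined compute gap is close to four.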

Loading

```python
import mlx.core as mx

# For .safetensors files, mx.load returns a dict mapping tensor names to mx.array
params = mx.load("diff/latest.safetensors")  # or vanilla/latest.safetensors
```

Model construction lives in the repo (model.py, Transformer(cfg, variant="diff"|"vanilla")).
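To list tensor names and shapes without MLX or the repo's model code, you can read the safetensors header directly. Per the safetensors format, a file starts with an 8-byte little-endian length, followed by a JSON header that maps each tensor name to its dtype, shape, and byte offsets. The helper names below are illustrative only.

```python
import json
import math
import struct


def read_safetensors_header(path):
    """Return the tensor index of a .safetensors file without reading weights."""
    with open(path, "rb") as f:
        # First 8 bytes: unsigned little-endian u64 giving the header length.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form string metadata
    return header


def count_params(header):
    # Total element count across all tensors (scalars have shape [] -> 1).
    return sum(math.prod(info["shape"]) for info in header.values())
```

For example, `sorted(set(read_safetensors_header("diff/latest.safetensors")) - set(read_safetensors_header("vanilla/latest.safetensors")))` lists the tensors that exist only in the diff variant.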

License

MIT.