vanilla-large-parity-3B

A 1.31B-param vanilla GPT (24 layers, 16 heads, d=2048; no SAE bottleneck) trained to val-loss parity with markhenry/cayley-large-2L-mlp_in-20B, the CayleySAE flagship on the same chassis. It uses the same backbone and the same training recipe except for sparsity_mode=none (and the token budget); the SAE bottleneck is the only architectural difference.

Headline

| | cayley-large-2L-mlp_in-20B | vanilla-large-parity-3B |
|---|---|---|
| val_loss (CE, nats) | 2.8081 | 2.8284 |
| training tokens | 20.0B | 3.0B |
| iters | 12,716 | 5,723 |
| n_layer / d_model / heads | 24 / 2048 / 16 | 24 / 2048 / 16 |
| pos enc | learned | learned |
| optimizer | Muon + AdamW (decoupled) | Muon + AdamW (decoupled) |

Δ vs target: +0.020 nats (+0.72%), inside the requested ±1% parity band [2.780, 2.836]. Vanilla comes within 1% of the CayleySAE flagship's val-loss using 6.7× fewer tokens (3.0B vs 20.0B).
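The parity check is plain arithmetic on the headline numbers. The snippet below reproduces the Δ, the ±1% band, and the token ratio from the table values only; nothing is read from either checkpoint.

```python
# Headline numbers copied from the table above.
flagship_val = 2.8081      # cayley-large-2L-mlp_in-20B
vanilla_val = 2.8284       # vanilla-large-parity-3B (this run)
flagship_tokens = 20.0e9
vanilla_tokens = 3.0e9

# Gap to the target in nats (the cross-entropy uses natural log), absolute and relative.
delta = vanilla_val - flagship_val
rel_gap = delta / flagship_val

# ±1% parity band around the target val-loss.
band_lo, band_hi = 0.99 * flagship_val, 1.01 * flagship_val
in_band = band_lo <= vanilla_val <= band_hi

# How many times fewer tokens the vanilla run used.
token_ratio = flagship_tokens / vanilla_tokens

print(f"delta = {delta:+.4f} nats ({rel_gap:.2%})")
print(f"band  = [{band_lo:.4f}, {band_hi:.4f}], in band: {in_band}")
print(f"token ratio = {token_ratio:.2f}x")
```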

Training recipe (exact, from the producing script)

n_layer=24, n_head=16, n_embd=2048
seq_len=1024, pos_encoding=learned, vocab=50304
sparsity_mode=none
optimizer: Muon + AdamW (decoupled)
muon_lr   = 8e-3 → 1e-4
adamw_lr  = 3e-4 → 1e-5
schedule: linear_warmdown, warmup=200, warmdown_frac=0.5
batch_size = 64, gradient_accumulation_steps = 8
tokens/iter = 64 × 1024 × 8 = 524,288
max_tokens = 3.0B → max_iters = 5,723
warmdown starts at iter 2,861 (50% point)
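The derived numbers in the recipe follow from the batch settings. Below is a minimal sketch that recomputes them and models the linear_warmdown schedule. The schedule shape (linear warmup from near zero, a flat plateau, then a linear decay to the final LR) is an assumption inferred from the parameter names; the producing script may differ in details such as the warmup starting value or off-by-one handling. `lr_at` is a hypothetical helper.

```python
import math

# Values copied from the training recipe above.
batch_size, seq_len, grad_accum = 64, 1024, 8
max_tokens = 3.0e9
warmup, warmdown_frac = 200, 0.5

tokens_per_iter = batch_size * seq_len * grad_accum      # tokens per optimizer step
max_iters = math.ceil(max_tokens / tokens_per_iter)      # rounded up to cover 3.0B tokens
warmdown_start = int(max_iters * (1 - warmdown_frac))    # 50% point


def lr_at(it: int, peak: float, final: float) -> float:
    """Hypothetical linear_warmdown shape (assumed, not the script's code):
    linear warmup to `peak` over `warmup` iters, hold at `peak`, then decay
    linearly from `peak` at warmdown_start to `final` at max_iters."""
    if it < warmup:
        return peak * (it + 1) / warmup
    if it < warmdown_start:
        return peak
    frac = min(1.0, (it - warmdown_start) / (max_iters - warmdown_start))
    return peak + frac * (final - peak)


# Assumed: the same shape drives both optimizers, each with its own endpoints.
checkpoints = (0, 199, warmdown_start, max_iters)
muon = [lr_at(i, 8e-3, 1e-4) for i in checkpoints]
adamw = [lr_at(i, 3e-4, 1e-5) for i in checkpoints]

print(tokens_per_iter, max_iters, warmdown_start)
print("muon :", muon)
print("adamw:", adamw)
```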

Hardware: 4× NVIDIA H200 (143 GB), 4-GPU DDP, ~3.4h wall clock. Throughput: 242k tok/s aggregate. VRAM peak (rank 0): 125 GB.
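The quoted throughput and wall clock are consistent with each other. At a steady 242k tok/s, 3.0B tokens take about 3.44 h of pure training time; this arithmetic ignores eval and checkpointing overhead.

```python
tokens = 3.0e9         # training tokens (max_tokens)
throughput = 242e3     # aggregate tokens/s across the 4 GPUs
n_gpus = 4

wall_hours = tokens / throughput / 3600   # training time at steady throughput
per_gpu = throughput / n_gpus             # tokens/s per H200

print(f"{wall_hours:.2f} h, {per_gpu / 1e3:.1f}k tok/s per GPU")  # prints "3.44 h, 60.5k tok/s per GPU"
```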

Why D = 3B?

The token budget was chosen by fitting a Chinchilla-form scaling law, val(N, D) = E + A·N^(−α) + B·D^(−β), to nine calibration probes. The table lists those nine probes followed by this run for comparison:

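| n_layer | d_model | N (params) | D (tokens) | observed val |
|---|---|---|---|---|
| 6 | 512 | 44.7M | 0.5B | 4.244 |
| 6 | 512 | 44.7M | 1.0B | 3.890 |
| 6 | 512 | 44.7M | 2.0B | 3.702 |
| 12 | 512 | 63.5M | 1.0B | 3.759 |
| 12 | 512 | 63.5M | 2.0B | 3.558 |
| 24 | 1024 | 354M | 0.5B | 3.573 |
| 24 | 1024 | 354M | 1.0B | 3.284 |
| 24 | 1024 | 354M | 2.0B | 3.090 |
| 24 | 1024 | 354M | 4.0B | 2.958 |
| 24 | 2048 | 1311M | 3.0B | 2.828 (this run) |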

A fit on the nine calibration points predicted that D ≈ 2.93B would reach val 2.804. This 3.0B-token run actually landed at 2.828, 0.024 nats above the prediction and within the scale of the fit residuals. A second-iteration refit that includes this point would raise the predicted token budget to D ≈ 3.35B, but 2.828 is already inside the requested ±1% band, so this checkpoint ships as is.
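A minimal numpy sketch of such a fit on the nine calibration probes, assuming an unweighted least-squares objective on val itself. The original fit's objective, weighting, and optimizer are not stated here; the Chinchilla paper, for instance, minimizes a Huber loss on log-space predictions with L-BFGS. This sketch swaps in a simpler method: once α and β are fixed, E, A and B enter linearly, so it grid-searches the exponents and solves the rest by ordinary least squares, keeping only positive E, A, B. The fitted constants, and therefore the prediction, will not necessarily match the 2.93B / 2.804 figures. `fit_scaling_law`, `predict` and `tokens_for_target` are hypothetical helpers.

```python
import numpy as np

# The nine calibration probes from the table: (N params, D tokens, observed val).
probes = np.array([
    (44.7e6, 0.5e9, 4.244), (44.7e6, 1.0e9, 3.890), (44.7e6, 2.0e9, 3.702),
    (63.5e6, 1.0e9, 3.759), (63.5e6, 2.0e9, 3.558),
    (354e6, 0.5e9, 3.573), (354e6, 1.0e9, 3.284),
    (354e6, 2.0e9, 3.090), (354e6, 4.0e9, 2.958),
])
N, D, y = probes.T


def fit_scaling_law(N, D, y, grid=np.linspace(0.05, 0.80, 76)):
    """Grid-search (alpha, beta); with the exponents fixed, E, A, B enter
    linearly and are solved by ordinary least squares. Keeps the lowest-RMS
    fit among those with E, A, B all positive."""
    best = None
    for alpha in grid:
        for beta in grid:
            X = np.column_stack([np.ones_like(N), N ** -alpha, D ** -beta])
            coef, *_ = np.linalg.lstsq(X, y, rcond=None)
            if np.any(coef <= 0):
                continue
            rms = float(np.sqrt(np.mean((X @ coef - y) ** 2)))
            if best is None or rms < best[0]:
                best = (rms, *coef, alpha, beta)
    return best  # (rms, E, A, B, alpha, beta)


rms, E, A, B, alpha, beta = fit_scaling_law(N, D, y)


def predict(n, d):
    """val(N, D) = E + A*N^-alpha + B*D^-beta under the fitted constants."""
    return E + A * n ** -alpha + B * d ** -beta


def tokens_for_target(n, target):
    """Solve val(n, D) = target for D: D = (B / (target - E - A*n^-alpha))^(1/beta)."""
    gap = target - E - A * n ** -alpha
    if gap <= 0:
        raise ValueError("target is at or below E + A*N^-alpha; unreachable at this N")
    return (B / gap) ** (1.0 / beta)


N_big = 1311e6  # this run's N
print(f"rms={rms:.4f} E={E:.3f} A={A:.4g} alpha={alpha:.2f} B={B:.4g} beta={beta:.2f}")
print(f"predicted val at D=3.0B: {predict(N_big, 3.0e9):.3f}")
```

`tokens_for_target(N_big, 2.8081)` then returns the budget this particular fit asks for at the flagship's val-loss. It raises if the fitted N-limited floor E + A·N^(−α) already sits above the target.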

Param count formula: 12 · n_layer · d² + vocab · d (vocab=50304).
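The formula can be checked against the table. This is a short sketch, assuming the standard GPT-2-style block with a 4× MLP, which is what the 12·d² term encodes. It reproduces 63.5M, 353.5M (listed as 354M) and 1311.0M. The learned position table is not counted and would add seq_len · d ≈ 2.1M for this model.

```python
def param_count(n_layer: int, d: int, vocab: int = 50304) -> int:
    """N = 12 * n_layer * d^2 + vocab * d.
    12*d^2 per layer is the usual attention (4*d^2: Q, K, V, output) plus
    MLP (8*d^2: two d x 4d matrices) count; vocab*d is the token embedding.
    Biases, LayerNorm weights and the learned position table are excluded."""
    return 12 * n_layer * d * d + vocab * d


for n_layer, d in [(12, 512), (24, 1024), (24, 2048)]:
    print(f"n_layer={n_layer:2d} d={d:4d}  N={param_count(n_layer, d) / 1e6:.1f}M")
```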

Lineage

Predecessor / target: markhenry/cayley-large-2L-mlp_in-20B (CayleySAE flagship, val 2.8081 at 20.0B tokens)

Sibling vanilla baselines at the same chassis:

This run is the flagship-recipe-matched parity twin: same Muon/AdamW LRs, same warmdown shape (warmdown_frac=0.5), same warmup=200, and same learned positional encoding as the canonical CayleySAE large run. That keeps the comparison apples-to-apples for alignment-tax measurement.

Files

  • ckpt.pt — full checkpoint (model + Muon momentum + AdamW state + iter_num, best_val_loss, wandb_step_offset). 5.46 GB.
  • config.json — training config snapshot.
  • train_vanilla_large_parity_3B.sh — the script that produced this checkpoint.

Loading

```python
import torch

# weights_only=False performs full unpickling; only use it on files you trust.
ckpt = torch.load("ckpt.pt", map_location="cpu", weights_only=False)
# ckpt contains: model, optimizer_states, iter_num=5723, best_val_loss=2.8284, config, model_config
```

The included optimizer state is suitable for further training. Inference consumers should pull ckpt["model"] and build the model from config.json.
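For inference, the following is a minimal sketch that returns the weights and the config. It assumes ckpt["model"] is a plain state dict that may have been saved from a DDP- or torch.compile-wrapped module, which would prefix parameter names with "module." or "_orig_mod."; whether this checkpoint actually carries such prefixes is an assumption, and unprefixed keys pass through unchanged. The GPT model class is not among the files listed above, so constructing the model is left to the caller. `clean_state_dict` and `load_for_inference` are hypothetical helpers.

```python
import json


def clean_state_dict(state_dict):
    """Strip wrapper prefixes that DDP ("module.") and torch.compile
    ("_orig_mod.") add to parameter names, in any order or combination."""
    prefixes = ("module.", "_orig_mod.")
    cleaned = {}
    for key, value in state_dict.items():
        while key.startswith(prefixes):
            key = key.split(".", 1)[1]  # drop the leading wrapper segment
        cleaned[key] = value
    return cleaned


def load_for_inference(ckpt_path="ckpt.pt", config_path="config.json"):
    """Return (weights, config). Building the model from `config` needs the
    project's own GPT class, which is not shipped with this checkpoint."""
    import torch  # lazy import: clean_state_dict above does not need torch

    # weights_only=False performs full unpickling; only use it on trusted files.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    with open(config_path) as f:
        config = json.load(f)
    return clean_state_dict(ckpt["model"]), config
```

A caller would then construct the model from `config` with the project's GPT class (hypothetical here) and call `model.load_state_dict(weights)`.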

Project context

Part of the Sparse NanoGPT / CayleySAE research program. This is the vanilla-baseline counterpart to cayley-large-2L-mlp_in-20B for alignment-tax measurements at the 1.3B chassis.

Wandb: probe-vanilla-vanilla-large-parity-3b

— Claude (Opus 4.7)
