File size: 7,078 Bytes
"""Architecture ablation harness.

Trains each variant for the SAME fixed token budget on the SAME data, then emits
a markdown results table + JSON. The point of an ablation is a controlled
comparison: one thing changes per row, everything else held fixed.

    python scripts/ablate.py --data-dir data/fwedu --tokens 300000000
    python scripts/ablate.py --dry-run   # tiny synthetic, for CI/sanity
    python scripts/ablate.py --data-dir data/fwedu --only baseline no_qk_norm

Each variant only overrides what it's testing; the rest inherits from BASE below.
Results land in results/ablations/<name>/ (checkpoints + metrics.jsonl) and a
combined docs/ABLATIONS.md + results/ablations.json.
"""
from __future__ import annotations
import os
import sys
import json
import argparse
from pathlib import Path
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "src"))
from matilda.config import ModelConfig # noqa: E402
from matilda.train import Trainer, TrainConfig # noqa: E402
from matilda.model import Transformer # noqa: E402
from matilda.data import SyntheticStream, BinStream, shard_paths # noqa: E402
# Shared defaults. Every variant starts from these two dicts and overrides only
# the keys it is testing.
BASE_MODEL = {
    "vocab_size": 50257,
    "max_seq_len": 1024,
    "d_model": 768,
    "n_layers": 12,
    "n_heads": 12,
    "n_kv_heads": 4,
}

import torch  # noqa: E402

# Probed once at import time; selects the device, dtype and compile defaults.
_HAS_CUDA = torch.cuda.is_available()

BASE_TRAIN = {
    "seq_len": 1024,
    "batch_size": 24,
    "grad_accum": 8,
    "warmup_steps": 80,
    "lr": 6e-4,
    "log_every": 20,
    "ckpt_every": 100000,
    # Without CUDA: CPU, float32, no compilation.
    "device": "cuda" if _HAS_CUDA else "cpu",
    "dtype": "bfloat16" if _HAS_CUDA else "float32",
    "compile": _HAS_CUDA,
}
# The ablation grid. Each entry names a variant and lists only the keys that
# differ from the base config: "model" overrides the model kwargs, and the
# optional "train" overrides the training kwargs.
VARIANTS = [
    # Reference run: no overrides.
    {
        "name": "baseline",
        "model": {},
    },
    # NOTE(review): this row sets two keys (qk_norm and attn_logit_softcap);
    # presumably the softcap stands in for QK-norm -- confirm that is intended.
    {
        "name": "no_qk_norm",
        "model": {"qk_norm": False, "attn_logit_softcap": 20.0},
    },
    # Full multi-head attention (no GQA): as many KV heads as query heads.
    {
        "name": "mha",
        "model": {"n_kv_heads": 12},
    },
    # Multi-query attention (extreme GQA): a single KV head.
    {
        "name": "mqa",
        "model": {"n_kv_heads": 1},
    },
    # Same model as baseline; only the optimizer changes.
    {
        "name": "muon",
        "model": {},
        "train": {"optimizer": "muon"},
    },
]
def steps_for_tokens(tokens, tcfg: TrainConfig) -> int:
    """Convert a token budget into a step count for the given train config.

    Tokens per step is taken as ``batch_size * grad_accum * seq_len`` read from
    ``tcfg``. The quotient is rounded to the nearest integer and never drops
    below 1.
    """
    tokens_per_step = tcfg.batch_size * tcfg.grad_accum * tcfg.seq_len
    n_steps = round(tokens / tokens_per_step)
    return n_steps if n_steps > 1 else 1
def final_metrics(ckpt_dir, tail=5) -> dict:
    """Average the last `tail` logged steps from ``<ckpt_dir>/metrics.jsonl``.

    Only rows whose ``"event"`` is ``"step"`` are considered. Each metric is
    averaged over the rows that actually carry a non-None value for it, so a
    step that logged ``mfu: null`` no longer drags the mean towards zero.

    Args:
        ckpt_dir: Directory expected to contain ``metrics.jsonl``.
        tail: Number of trailing step rows to average over.

    Returns:
        A dict with keys ``"loss"`` and ``"mfu"`` (rounded to 4 decimals) and
        ``"tokens_per_s"`` (rounded to an int). A value is ``None`` when the
        file is missing, no step was logged, or that metric never appeared in
        the averaged rows.
    """
    empty = {"loss": None, "mfu": None, "tokens_per_s": None}
    path = os.path.join(ckpt_dir, "metrics.jsonl")
    try:
        with open(path, encoding="utf-8") as f:
            rows = [json.loads(line) for line in f if line.strip()]
    except FileNotFoundError:
        return empty  # crashed before logging a step

    steps = [r for r in rows if r.get("event") == "step"]
    if not steps:
        return empty
    last = steps[-tail:]

    def avg(key):
        # Divide by the number of values actually present, not by len(last).
        values = [s[key] for s in last if s.get(key) is not None]
        return sum(values) / len(values) if values else None

    loss = avg("loss")
    mfu = avg("mfu")
    tokens_per_s = avg("tokens_per_s")
    return {
        "loss": None if loss is None else round(loss, 4),
        "mfu": None if mfu is None else round(mfu, 4),
        "tokens_per_s": None if tokens_per_s is None else round(tokens_per_s),
    }
def run_variant(v, tokens, data_dir, dry_run, results_root):
    """Train one ablation variant and return its row for the results table.

    The variant's "model" / "train" overrides are layered on top of BASE_MODEL /
    BASE_TRAIN. With ``dry_run`` the model and batch are shrunk and the run is
    forced onto CPU in float32 without compilation. Checkpoints go to
    ``<results_root>/<variant name>``, and the final metrics are read back from
    that directory after training.

    Args:
        v: Variant dict with "name" and optional "model" / "train" overrides.
        tokens: Token budget, converted to a step count via steps_for_tokens.
        data_dir: Shard directory; a falsy value selects the synthetic stream.
        dry_run: Use the shrunken config and the synthetic stream.
        results_root: Parent directory for per-variant output.

    Returns:
        Dict with name, active_params_m, n_kv_heads, qk_norm, optimizer,
        final_loss, mfu and tokens_per_s.
    """
    name = v["name"]

    model_kwargs = dict(BASE_MODEL)
    model_kwargs.update(v.get("model", {}))
    train_kwargs = dict(BASE_TRAIN)
    train_kwargs.update(v.get("train", {}))

    if dry_run:  # shrink for CPU/CI
        kv_heads = min(4, model_kwargs.get("n_kv_heads", 4))
        model_kwargs.update(
            d_model=64,
            n_layers=2,
            n_heads=4,
            n_kv_heads=kv_heads,
            max_seq_len=64,
        )
        if name == "mha":
            model_kwargs["n_kv_heads"] = 4
        train_kwargs.update(
            seq_len=64,
            batch_size=4,
            grad_accum=1,
            warmup_steps=2,
            device="cpu",
            dtype="float32",
            compile=False,
        )

    ckpt_dir = os.path.join(results_root, name)
    train_kwargs["ckpt_dir"] = ckpt_dir

    mcfg = ModelConfig(**model_kwargs)
    # A throwaway TrainConfig supplies the batch geometry for the step count.
    total_steps = steps_for_tokens(tokens, TrainConfig(**train_kwargs))
    tcfg = TrainConfig(total_steps=total_steps, **train_kwargs)

    # Parameter count as reported by num_params(non_embedding=True).
    n_active = Transformer(mcfg).num_params(non_embedding=True)

    use_synthetic = dry_run or not data_dir
    if use_synthetic:
        stream = SyntheticStream(
            mcfg.vocab_size, tcfg.batch_size, tcfg.seq_len,
            seed=0, device=tcfg.device,
        )
    else:
        stream = BinStream(
            shard_paths(data_dir), tcfg.batch_size, tcfg.seq_len,
            seed=0, device=tcfg.device,
        )

    print(f"\n=== variant: {name} | steps={tcfg.total_steps} "
          f"active={n_active/1e6:.1f}M kv_heads={mcfg.n_kv_heads} "
          f"qk_norm={mcfg.qk_norm} ===")

    trainer = Trainer(mcfg, tcfg, stream)
    trainer.train()

    metrics = final_metrics(ckpt_dir)
    row = {
        "name": name,
        "active_params_m": round(n_active / 1e6, 1),
        "n_kv_heads": mcfg.n_kv_heads,
        "qk_norm": mcfg.qk_norm,
        "optimizer": tcfg.optimizer,
        "final_loss": metrics["loss"],
        "mfu": metrics["mfu"],
        "tokens_per_s": metrics.get("tokens_per_s"),
    }
    return row
def write_table(rows, tokens, out_md, out_json):
    """Write the combined ablation results as JSON and as a markdown table.

    Args:
        rows: List of per-variant result dicts. Each needs "name",
            "active_params_m", "n_kv_heads", "qk_norm", "final_loss" and "mfu";
            "optimizer" (shown as "adamw" if absent) and "tokens_per_s" are
            optional.
        tokens: Token budget per variant, reported in both outputs.
        out_md: Path of the markdown file to (over)write.
        out_json: Path of the JSON file to (over)write.

    Missing metrics (None) are rendered as "—" in the table. Parent directories
    are created when needed; a bare filename writes into the current directory.
    """
    for target in (out_md, out_json):
        parent = os.path.dirname(target)
        if parent:  # dirname is "" for a bare filename; makedirs("") would raise
            os.makedirs(parent, exist_ok=True)

    with open(out_json, "w", encoding="utf-8") as f:
        json.dump({"tokens_per_variant": tokens, "rows": rows}, f, indent=2)

    lines = [
        "# Architecture Ablations",
        "",
        f"Each variant trained on **{tokens/1e6:.0f}M tokens** of the same data, "
        "identical except for the column under test.",
        "",
        "| Variant | Active params | n_kv_heads | QK-norm | Optimizer | Final loss | MFU | tok/s |",
        "|---------|--------------|-----------|---------|-----------|-----------|-----|-------|",
    ]
    for r in rows:
        mfu = f"{r['mfu']*100:.1f}%" if r["mfu"] is not None else "—"
        tps = f"{r['tokens_per_s']:,}" if r.get("tokens_per_s") else "—"
        loss = r["final_loss"] if r["final_loss"] is not None else "—"
        lines.append(f"| {r['name']} | {r['active_params_m']}M | {r['n_kv_heads']} "
                     f"| {r['qk_norm']} | {r.get('optimizer','adamw')} | {loss} "
                     f"| {mfu} | {tps} |")

    # Explicit UTF-8: the table contains "—", which not every locale can encode.
    with open(out_md, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    print(f"\nwrote {out_md} and {out_json}")
def main():
    """Command-line entry point: run the selected variants and write the report.

    Parses the command line, trains each selected variant in VARIANTS order via
    run_variant, writes docs/ABLATIONS.md and <results>/ablations.json, and
    prints a one-line summary per variant. With --dry-run the token budget is
    fixed at 50,000 regardless of --tokens.

    Exits through argparse's error path (status 2) when --only names a variant
    that does not exist. Without this check a typo silently ran fewer variants
    and still overwrote the report.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data-dir", default=None)
    ap.add_argument("--tokens", type=int, default=300_000_000)
    ap.add_argument("--dry-run", action="store_true")
    ap.add_argument("--only", nargs="*", default=None,
                    help="subset of variant names to run")
    ap.add_argument("--results", default=str(ROOT / "results" / "ablations"))
    args = ap.parse_args()

    if args.only is not None:
        known = [v["name"] for v in VARIANTS]
        unknown = [n for n in args.only if n not in known]
        if unknown:
            ap.error(f"unknown variant(s) for --only: {', '.join(unknown)} "
                     f"(choose from: {', '.join(known)})")

    tokens = 50_000 if args.dry_run else args.tokens
    variants = [v for v in VARIANTS
                if args.only is None or v["name"] in args.only]
    rows = [run_variant(v, tokens, args.data_dir, args.dry_run, args.results)
            for v in variants]
    write_table(rows, tokens,
                out_md=str(ROOT / "docs" / "ABLATIONS.md"),
                out_json=str(Path(args.results) / "ablations.json"))
    print("\n".join(f" {r['name']:12} loss={r['final_loss']} mfu={r['mfu']}"
                    for r in rows))


if __name__ == "__main__":
    main()
|