File size: 7,078 Bytes
ac618f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4d2cf2
 
 
ac618f3
f4d2cf2
 
 
 
ac618f3
 
 
 
 
 
 
3ac4183
ac618f3
 
 
 
 
 
 
 
 
 
f4d2cf2
 
 
 
 
 
ac618f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ac4183
ac618f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ac4183
 
ac618f3
 
 
 
 
 
3ac4183
 
ac618f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Architecture ablation harness.

Trains each variant for the SAME fixed token budget on the SAME data, then emits
a markdown results table + JSON. The point of an ablation is a controlled
comparison: one thing changes per row, everything else held fixed.

    python scripts/ablate.py --data-dir data/fwedu --tokens 300000000
    python scripts/ablate.py --dry-run            # tiny synthetic, for CI/sanity
    python scripts/ablate.py --data-dir data/fwedu --only baseline no_qk_norm

Each variant overrides only what it is testing; everything else is inherited from
BASE_MODEL / BASE_TRAIN below. Results land in results/ablations/<name>/
(checkpoints + metrics.jsonl), plus a combined docs/ABLATIONS.md and
results/ablations/ablations.json.
"""

from __future__ import annotations

import os
import sys
import json
import argparse
from pathlib import Path

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "src"))

from matilda.config import ModelConfig                          # noqa: E402
from matilda.train import Trainer, TrainConfig                  # noqa: E402
from matilda.model import Transformer                           # noqa: E402
from matilda.data import SyntheticStream, BinStream, shard_paths  # noqa: E402

# Shared model config: every variant starts from this real architecture and
# only swaps out the keys it is ablating.
BASE_MODEL = {
    "vocab_size": 50257,
    "max_seq_len": 1024,
    "d_model": 768,
    "n_layers": 12,
    "n_heads": 12,
    "n_kv_heads": 4,      # GQA baseline: 12 query heads over 4 KV heads
}
import torch  # noqa: E402

_HAS_CUDA = torch.cuda.is_available()
# Shared training config. On a CUDA box: bf16 + torch.compile; otherwise a
# plain fp32 CPU run.
BASE_TRAIN = {
    "seq_len": 1024,
    "batch_size": 24,
    "grad_accum": 8,
    "warmup_steps": 80,
    "lr": 6e-4,
    "log_every": 20,
    # well above the ~1.5k steps a 300M-token run takes at these batch settings
    "ckpt_every": 100000,
    "device": "cuda" if _HAS_CUDA else "cpu",
    "dtype": "bfloat16" if _HAS_CUDA else "float32",
    "compile": _HAS_CUDA,
}

# Ablation rows. Each entry inherits BASE_MODEL / BASE_TRAIN and lists only
# the keys it overrides.
VARIANTS = [
    {"name": "baseline", "model": {}},
    # NOTE(review): this row changes two knobs (QK-norm off AND a logit softcap
    # of 20.0), presumably to keep training stable without QK-norm -- confirm;
    # the results table only surfaces the qk_norm column.
    {"name": "no_qk_norm", "model": {"qk_norm": False, "attn_logit_softcap": 20.0}},
    {"name": "mha", "model": {"n_kv_heads": 12}},   # full multi-head (no GQA)
    {"name": "mqa", "model": {"n_kv_heads": 1}},    # multi-query (extreme GQA)
    {"name": "muon", "model": {}, "train": {"optimizer": "muon"}},  # optimizer swap only
]


def steps_for_tokens(tokens, tcfg: TrainConfig) -> int:
    """Convert a token budget into an optimizer-step count (never below 1).

    One step consumes ``batch_size * grad_accum * seq_len`` tokens; the
    quotient is rounded with Python's built-in ``round``.
    """
    per_step = tcfg.batch_size * tcfg.grad_accum * tcfg.seq_len
    n_steps = round(tokens / per_step)
    return n_steps if n_steps > 1 else 1


def final_metrics(ckpt_dir, tail=5) -> dict:
    """Average the last `tail` logged steps from ``<ckpt_dir>/metrics.jsonl``.

    Only JSON rows whose ``"event"`` is ``"step"`` are considered. Each metric
    is averaged over the rows in the window that actually report it (non-None),
    so a missing value no longer drags the mean toward zero.

    Returns:
        ``{"loss", "mfu", "tokens_per_s"}`` -- loss/mfu rounded to 4 places,
        tokens_per_s rounded to an int. A value is None when the file is
        missing, no step was logged, or no row in the window reports it.
    """
    missing = {"loss": None, "mfu": None, "tokens_per_s": None}
    path = os.path.join(ckpt_dir, "metrics.jsonl")
    try:
        with open(path, encoding="utf-8") as f:
            raw = [line for line in f if line.strip()]
    except FileNotFoundError:
        return missing                       # crashed before logging a step
    rows = []
    for line in raw:
        try:
            rows.append(json.loads(line))
        except json.JSONDecodeError:
            continue                         # e.g. a half-written last line after a crash
    steps = [r for r in rows if isinstance(r, dict) and r.get("event") == "step"]
    if not steps:
        return missing
    last = steps[-tail:]

    def avg(key):
        # Mean over rows that report `key`; None if none of them do.
        vals = [s[key] for s in last if s.get(key) is not None]
        return sum(vals) / len(vals) if vals else None

    def rounded(key, ndigits=None):
        # round(x, None) returns an int, matching the original tokens_per_s rounding.
        val = avg(key)
        return None if val is None else round(val, ndigits)

    return {"loss": rounded("loss", 4), "mfu": rounded("mfu", 4),
            "tokens_per_s": rounded("tokens_per_s")}


def run_variant(v, tokens, data_dir, dry_run, results_root):
    """Train one ablation variant and return its row for the results table.

    Args:
        v: a VARIANTS entry -- ``"name"`` plus optional ``"model"`` / ``"train"``
            override dicts layered on top of BASE_MODEL / BASE_TRAIN.
        tokens: token budget, converted to ``total_steps`` via steps_for_tokens.
        data_dir: shard directory read by BinStream; when falsy (or in a dry
            run) a SyntheticStream is used instead.
        dry_run: shrink model, batch and sequence length to a tiny CPU/float32
            config for CI and sanity checks.
        results_root: parent directory; this variant's checkpoints and
            metrics.jsonl go to ``<results_root>/<name>``.

    Returns:
        dict with name, active (non-embedding) params in millions, n_kv_heads,
        qk_norm, optimizer, final_loss, mfu and tokens_per_s.
    """
    model = {**BASE_MODEL, **v.get("model", {})}
    train = {**BASE_TRAIN, **v.get("train", {})}
    if dry_run:                                   # shrink for CPU/CI
        dry_heads = 4
        # Cap KV heads at the shrunken head count: "mha" (12 KV heads at full
        # size) becomes 4 == n_heads, i.e. still full multi-head, and "mqa"
        # keeps its single KV head.
        model.update(d_model=64, n_layers=2, n_heads=dry_heads,
                     n_kv_heads=min(dry_heads, model.get("n_kv_heads", dry_heads)),
                     max_seq_len=64)
        train.update(seq_len=64, batch_size=4, grad_accum=1, warmup_steps=2,
                     device="cpu", dtype="float32", compile=False)
    ckpt_dir = os.path.join(results_root, v["name"])
    train["ckpt_dir"] = ckpt_dir

    mcfg = ModelConfig(**model)
    # A throwaway TrainConfig supplies batch/accum/seq_len for the step count.
    tcfg = TrainConfig(total_steps=steps_for_tokens(tokens, TrainConfig(**train)),
                       **train)
    # Model built only to count params; the instance is not kept.
    active = Transformer(mcfg).num_params(non_embedding=True)

    if dry_run or not data_dir:
        stream = SyntheticStream(mcfg.vocab_size, tcfg.batch_size, tcfg.seq_len,
                                 seed=0, device=tcfg.device)
    else:
        stream = BinStream(shard_paths(data_dir), tcfg.batch_size, tcfg.seq_len,
                           seed=0, device=tcfg.device)

    print(f"\n=== variant: {v['name']} | steps={tcfg.total_steps} "
          f"active={active/1e6:.1f}M kv_heads={mcfg.n_kv_heads} "
          f"qk_norm={mcfg.qk_norm} ===")
    Trainer(mcfg, tcfg, stream).train()
    # NOTE(review): metrics are read back from <ckpt_dir>/metrics.jsonl; if
    # Trainer appends rather than truncates, rerunning into the same
    # results_root could average in stale rows -- confirm.
    m = final_metrics(ckpt_dir)
    return {"name": v["name"], "active_params_m": round(active / 1e6, 1),
            "n_kv_heads": mcfg.n_kv_heads, "qk_norm": mcfg.qk_norm,
            "optimizer": tcfg.optimizer,
            "final_loss": m["loss"], "mfu": m["mfu"],
            "tokens_per_s": m.get("tokens_per_s")}


def write_table(rows, tokens, out_md, out_json):
    """Write ablation results as JSON and as a markdown table.

    Args:
        rows: result dicts as produced by run_variant.
        tokens: per-variant token budget, shown in the table's intro line.
        out_md: markdown file to (over)write; parent dirs are created.
        out_json: JSON file to (over)write; parent dirs are created.
    """
    for path in (out_md, out_json):
        parent = os.path.dirname(path)
        if parent:                   # bare filename -> cwd; makedirs("") would raise
            os.makedirs(parent, exist_ok=True)
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump({"tokens_per_variant": tokens, "rows": rows}, f, indent=2)
    lines = [
        "# Architecture Ablations",
        "",
        f"Each variant trained on **{tokens/1e6:.0f}M tokens** of the same data, "
        "identical except for the column under test.",
        "",
        "| Variant | Active params | n_kv_heads | QK-norm | Optimizer | Final loss | MFU | tok/s |",
        "|---------|--------------|-----------|---------|-----------|-----------|-----|-------|",
    ]
    for r in rows:
        mfu = f"{r['mfu']*100:.1f}%" if r["mfu"] is not None else "—"
        tps = f"{r['tokens_per_s']:,}" if r.get("tokens_per_s") else "—"
        loss = r["final_loss"] if r["final_loss"] is not None else "—"
        lines.append(f"| {r['name']} | {r['active_params_m']}M | {r['n_kv_heads']} "
                     f"| {r['qk_norm']} | {r.get('optimizer','adamw')} | {loss} "
                     f"| {mfu} | {tps} |")
    # utf-8 explicitly: the "—" placeholder is not encodable in every locale default.
    with open(out_md, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    print(f"\nwrote {out_md} and {out_json}")


def main():
    """Parse CLI flags, run the selected variants, and write the combined table.

    ``--only`` naming a variant not in VARIANTS (or given with no names) is a
    usage error. Previously that silently ran nothing and then overwrote the
    markdown table with an empty one.

    NOTE: ``--dry-run`` still writes to ``--out-md`` and ``--results``; point
    them elsewhere (e.g. in CI) to keep real results from being overwritten.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data-dir", default=None)
    ap.add_argument("--tokens", type=int, default=300_000_000)
    ap.add_argument("--dry-run", action="store_true")
    ap.add_argument("--only", nargs="*", default=None,
                    help="subset of variant names to run")
    ap.add_argument("--results", default=str(ROOT / "results" / "ablations"))
    ap.add_argument("--out-md", default=str(ROOT / "docs" / "ABLATIONS.md"),
                    help="where to write the markdown results table")
    args = ap.parse_args()

    if args.only is not None:
        known = [v["name"] for v in VARIANTS]
        unknown = sorted(set(args.only) - set(known))
        if unknown:
            ap.error(f"unknown variant(s): {', '.join(unknown)} "
                     f"(choose from: {', '.join(known)})")
        if not args.only:
            ap.error("--only requires at least one variant name")

    tokens = 50_000 if args.dry_run else args.tokens   # dry run: ~200 tiny steps
    variants = [v for v in VARIANTS
                if args.only is None or v["name"] in args.only]
    rows = [run_variant(v, tokens, args.data_dir, args.dry_run, args.results)
            for v in variants]
    write_table(rows, tokens,
                out_md=args.out_md,
                out_json=str(Path(args.results) / "ablations.json"))
    print("\n".join(f"  {r['name']:12} loss={r['final_loss']} mfu={r['mfu']}"
                    for r in rows))


if __name__ == "__main__":
    main()