File size: 1,363 Bytes
ac618f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""Ablation harness: variants run, metrics captured, table emitted (dry CPU)."""

import os
import sys
import json
from pathlib import Path

# Directory two levels above this file (the parent of this file's directory);
# presumably the repository root — confirm against the project layout.
ROOT = Path(__file__).resolve().parent.parent
# Prepend <ROOT>/scripts to the module search path so that `import ablate`
# below resolves from that directory first (hence the E402 suppression there).
sys.path.insert(0, str(ROOT / "scripts"))

import ablate  # noqa: E402


def test_steps_for_tokens():
    """Check ``ablate.steps_for_tokens`` for a 300M-token budget.

    With ``batch_size=24``, ``grad_accum=8`` and ``seq_len=1024`` each step
    covers 24 * 8 * 1024 = 196,608 tokens, so 300M tokens is about 1525.9
    steps; the function is expected to report 1526.
    """
    from matilda.train import TrainConfig

    cfg = TrainConfig(seq_len=1024, grad_accum=8, batch_size=24)
    n_steps = ablate.steps_for_tokens(300_000_000, cfg)
    assert n_steps == 1526


def test_run_variant_dry_and_table(tmp_path):
    """Dry-run two ablation variants, then render their comparison table.

    Runs the ``baseline`` and ``mqa`` entries of ``ablate.VARIANTS`` through
    ``ablate.run_variant`` with ``dry_run=True`` and a 50k-token budget,
    checking that every returned row carries ``final_loss`` and ``mfu`` and
    that the ``mqa`` row reports ``n_kv_heads == 1``. The rows are then handed
    to ``ablate.write_table``; the Markdown output must mention both variants
    and an MFU column, and the JSON output must hold exactly two rows.
    """
    budget = 50_000
    results_root = str(tmp_path / "abl")

    def _lookup(name):
        # First variant whose "name" matches; raises StopIteration if absent.
        return next(spec for spec in ablate.VARIANTS if spec["name"] == name)

    collected = []
    for variant_name in ("baseline", "mqa"):
        outcome = ablate.run_variant(
            _lookup(variant_name),
            tokens=budget,
            data_dir=None,
            dry_run=True,
            results_root=results_root,
        )
        collected.append(outcome)
        # Both metrics must have been captured even in dry-run mode.
        assert outcome["final_loss"] is not None
        assert outcome["mfu"] is not None

    # The second run is "mqa"; its override collapses KV heads to one.
    mqa_row = collected[1]
    assert mqa_row["n_kv_heads"] == 1

    md_path = tmp_path / "ABLATIONS.md"
    json_path = tmp_path / "ablations.json"
    ablate.write_table(collected, budget, str(md_path), str(json_path))
    assert md_path.exists()
    assert json_path.exists()

    table = md_path.read_text()
    for needle in ("baseline", "mqa", "MFU"):
        assert needle in table

    payload = json.loads(json_path.read_text())
    assert len(payload["rows"]) == 2