import pytest
torch = pytest.importorskip("torch")
from sgjm.eval.checkpoint import load_checkpoint
from sgjm.eval.metrics import (
BaselineEvalMetrics,
SGJMEvalMetrics,
compare,
evaluate_baseline,
evaluate_sgjm,
)
from sgjm.training.config import TrainingConfig
from sgjm.training.data import ByteDataset, synthetic_corpus
from sgjm.training.torch_backend.baseline import BaselineLM
from sgjm.training.torch_backend.model import SGJM
from sgjm.training.torch_backend.trainer import train
def _smoke_cfg(arch: str, ckpt_dir) -> TrainingConfig:
    """Return ``TrainingConfig.smoke()`` with ``arch`` and ``checkpoint_dir`` overridden.

    ``ckpt_dir`` is stored as its ``str()`` form, so path-like objects are
    accepted.
    """
    smoke = TrainingConfig.smoke()
    # Right-hand side is evaluated first, then the targets are assigned
    # left-to-right: ``arch`` before ``checkpoint_dir``.
    smoke.arch, smoke.checkpoint_dir = arch, str(ckpt_dir)
    return smoke
def test_baseline_model_param_count_matches_sgjm_total():
    """BaselineLM and SGJM built from the same model config have similar sizes."""
    cfg = TrainingConfig.sgjm_25m()
    sgjm_model = SGJM(cfg.model)
    baseline_model = BaselineLM(cfg.model)

    sgjm_n = sgjm_model.num_parameters()
    base_n = baseline_model.num_parameters()

    # A +/- 10% band around the SGJM count keeps the comparison fair.
    lower_bound = 0.9 * sgjm_n
    upper_bound = 1.1 * sgjm_n
    assert lower_bound <= base_n <= upper_bound, (
        f"baseline {base_n/1e6:.2f}M vs sgjm {sgjm_n/1e6:.2f}M too different"
    )
def test_evaluate_sgjm_returns_finite_metrics():
    """evaluate_sgjm on a fresh smoke model yields in-range, non-NaN metrics."""
    cfg = TrainingConfig.smoke()
    model = SGJM(cfg.model)
    dataset = ByteDataset(synthetic_corpus(4096, seed=0), cfg.optim.seq_len)

    eval_options = {
        "n_batches": 2,
        "n_distractors": 4,
        "n_merge_pairs": 128,
        "drafts_per_step": 2,
        "device": "cpu",
    }
    metrics = evaluate_sgjm(model, cfg, dataset, **eval_options)

    # Rates and accuracies are fractions; compute cost must be positive.
    assert 0.0 <= metrics.branch_acceptance_rate <= 1.0
    assert 0.0 <= metrics.jepa_top1_acc <= 1.0
    assert metrics.compute_per_accepted_token > 0

    # The JS-divergence fields may legitimately be NaN when too few pairs
    # land inside the merge radius; every other float field must not be NaN.
    nan_tolerated = {"merge_precision_js", "random_pair_js"}
    for name, value in metrics.to_dict().items():
        if name in nan_tolerated or not isinstance(value, float):
            continue
        # NaN is the only float that compares unequal to itself.
        assert value == value, f"{name} is NaN"
def test_evaluate_baseline_returns_finite_metrics():
    """evaluate_baseline on a fresh smoke model reports positive NLL and compute."""
    cfg = TrainingConfig.smoke()
    baseline_model = BaselineLM(cfg.model)
    dataset = ByteDataset(synthetic_corpus(4096, seed=0), cfg.optim.seq_len)

    result = evaluate_baseline(
        baseline_model,
        cfg,
        dataset,
        n_batches=2,
        device="cpu",
    )

    # A NaN value would also fail these comparisons.
    assert result.token_nll > 0
    assert result.compute_per_token > 0
def test_compare_constructs_report():
    """compare() passes the gate when NLL matches and SGJM compute is lower."""
    sgjm_metrics = SGJMEvalMetrics(
        n_tokens=100,
        n_positions=50,
        token_nll=1.5,
        token_ppl=4.5,
        branch_acceptance_rate=0.6,
        jepa_top1_acc=0.4,
        jepa_chance_top1=0.1,
        merge_precision_js=0.01,
        random_pair_js=0.05,
        merge_precision_advantage=5.0,
        compute_per_accepted_token=1e6,
    )
    baseline_metrics = BaselineEvalMetrics(
        n_tokens=100,
        token_nll=1.5,
        token_ppl=4.5,
        compute_per_token=2e6,
    )

    outcome = compare(sgjm_metrics, baseline_metrics)

    assert outcome.gate_passed
    assert outcome.compute_advantage > 1.0
def test_compare_fails_gate_on_high_nll():
    """compare() rejects the gate and cites NLL when SGJM NLL is well above baseline."""
    # SGJM NLL (3.0) is twice the baseline's (1.5); all other fields are favourable.
    sgjm_metrics = SGJMEvalMetrics(
        n_tokens=100,
        n_positions=50,
        token_nll=3.0,
        token_ppl=20.0,
        branch_acceptance_rate=0.6,
        jepa_top1_acc=0.4,
        jepa_chance_top1=0.1,
        merge_precision_js=0.01,
        random_pair_js=0.05,
        merge_precision_advantage=5.0,
        compute_per_accepted_token=1e6,
    )
    baseline_metrics = BaselineEvalMetrics(
        n_tokens=100,
        token_nll=1.5,
        token_ppl=4.5,
        compute_per_token=2e6,
    )

    outcome = compare(sgjm_metrics, baseline_metrics)

    assert not outcome.gate_passed
    assert any("nll" in reason for reason in outcome.gate_reasons)
def test_end_to_end_train_then_eval(tmp_path):
    """Train both architectures on CPU, reload their checkpoints, evaluate, compare."""
    sgjm_cfg = _smoke_cfg("sgjm", tmp_path / "sgjm")
    base_cfg = _smoke_cfg("baseline", tmp_path / "baseline")

    # Train SGJM first, then the baseline; both runs must write a checkpoint.
    sgjm_run = train(sgjm_cfg, backend="cpu")
    base_run = train(base_cfg, backend="cpu")
    assert sgjm_run.checkpoint_path is not None
    assert base_run.checkpoint_path is not None

    # Reloaded checkpoints must report the architecture they were trained as.
    sgjm_ckpt = load_checkpoint(sgjm_run.checkpoint_path, device="cpu")
    base_ckpt = load_checkpoint(base_run.checkpoint_path, device="cpu")
    assert sgjm_ckpt.arch == "sgjm"
    assert base_ckpt.arch == "baseline"

    # Both models are evaluated on the same dataset, sized by the SGJM seq_len.
    dataset = ByteDataset(synthetic_corpus(2048, seed=0), sgjm_cfg.optim.seq_len)
    sgjm_metrics = evaluate_sgjm(
        sgjm_ckpt.model,
        sgjm_cfg,
        dataset,
        n_batches=2,
        n_distractors=4,
        n_merge_pairs=64,
        drafts_per_step=2,
        device="cpu",
    )
    base_metrics = evaluate_baseline(
        base_ckpt.model,
        base_cfg,
        dataset,
        n_batches=2,
        device="cpu",
    )

    outcome = compare(sgjm_metrics, base_metrics)
    assert isinstance(outcome.gate_passed, bool)
    # Smoke-trained models are not expected to pass the gate; only the
    # report's shape is verified here.
    assert outcome.sgjm.token_nll > 0
    assert outcome.baseline.token_nll > 0
|