SGJM / tests /test_eval.py
adampippert's picture
SGJM 2026.6.5 — code/docs
e51ccda verified
Raw
History Blame Contribute Delete
4.81 kB
import math

import pytest

torch = pytest.importorskip("torch")

from sgjm.eval.checkpoint import load_checkpoint
from sgjm.eval.metrics import (
    BaselineEvalMetrics,
    SGJMEvalMetrics,
    compare,
    evaluate_baseline,
    evaluate_sgjm,
)
from sgjm.training.config import TrainingConfig
from sgjm.training.data import ByteDataset, synthetic_corpus
from sgjm.training.torch_backend.baseline import BaselineLM
from sgjm.training.torch_backend.model import SGJM
from sgjm.training.torch_backend.trainer import train
def _smoke_cfg(arch: str, ckpt_dir) -> TrainingConfig:
    """Build a smoke-sized TrainingConfig for ``arch``.

    Args:
        arch: architecture name stored on the config (e.g. "sgjm" / "baseline").
        ckpt_dir: directory for checkpoints; anything ``str()``-able, such as a
            ``pathlib.Path`` from ``tmp_path``.

    Returns:
        A fresh ``TrainingConfig.smoke()`` with ``arch`` and ``checkpoint_dir`` set.
    """
    config = TrainingConfig.smoke()
    config.checkpoint_dir = str(ckpt_dir)
    config.arch = arch
    return config
def test_baseline_model_param_count_matches_sgjm_total():
    """BaselineLM built from the SGJM-25M model config is parameter-matched to SGJM.

    The two models must have parameter counts within +/-10% of each other,
    otherwise the SGJM-vs-baseline comparison is not apples-to-apples.
    """
    cfg = TrainingConfig.sgjm_25m()
    sgjm_model = SGJM(cfg.model)
    baseline_model = BaselineLM(cfg.model)
    sgjm_n, base_n = sgjm_model.num_parameters(), baseline_model.num_parameters()
    lower, upper = 0.9 * sgjm_n, 1.1 * sgjm_n
    assert lower <= base_n <= upper, (
        f"baseline {base_n/1e6:.2f}M vs sgjm {sgjm_n/1e6:.2f}M too different"
    )
def test_evaluate_sgjm_returns_finite_metrics():
    """evaluate_sgjm on a fresh smoke SGJM returns in-range, finite metrics.

    Rates/accuracies must lie in [0, 1], compute per accepted token must be
    positive, and every float metric except the exempted JS-divergence ones
    must be finite (neither NaN nor +/-inf).
    """
    cfg = TrainingConfig.smoke()
    model = SGJM(cfg.model)
    corpus = synthetic_corpus(4096, seed=0)
    ds = ByteDataset(corpus, cfg.optim.seq_len)
    metrics = evaluate_sgjm(
        model,
        cfg,
        ds,
        n_batches=2,
        n_distractors=4,
        n_merge_pairs=128,
        drafts_per_step=2,
        device="cpu",
    )
    assert 0.0 <= metrics.branch_acceptance_rate <= 1.0
    assert 0.0 <= metrics.jepa_top1_acc <= 1.0
    assert metrics.compute_per_accepted_token > 0
    # merge_precision_js can legitimately be NaN when too few pairs land
    # inside the merge radius; random_pair_js is exempted alongside it.
    # Everything else must be finite. A bare `v == v` check only rejects
    # NaN and lets +/-inf through, so use math.isfinite instead.
    nan_ok = {"merge_precision_js", "random_pair_js"}
    for k, v in metrics.to_dict().items():
        if isinstance(v, float) and k not in nan_ok:
            assert math.isfinite(v), f"{k} is not finite: {v!r}"
def test_evaluate_baseline_returns_finite_metrics():
    """evaluate_baseline on a fresh smoke BaselineLM returns positive, finite metrics.

    `> 0` on its own accepts +inf (and the test name promises finiteness), so
    each scalar metric is also checked with math.isfinite, matching the SGJM
    evaluation test.
    """
    cfg = TrainingConfig.smoke()
    model = BaselineLM(cfg.model)
    corpus = synthetic_corpus(4096, seed=0)
    ds = ByteDataset(corpus, cfg.optim.seq_len)
    metrics = evaluate_baseline(model, cfg, ds, n_batches=2, device="cpu")
    for name in ("token_nll", "token_ppl", "compute_per_token"):
        value = getattr(metrics, name)
        assert math.isfinite(value), f"{name} is not finite: {value!r}"
    assert metrics.token_nll > 0
    assert metrics.compute_per_token > 0
def test_compare_constructs_report():
    """compare() passes the gate when SGJM matches baseline NLL at lower compute.

    Both runs report token_nll=1.5; SGJM spends 1e6 compute per accepted token
    versus the baseline's 2e6 per token, so the compute advantage exceeds 1.
    """
    sgjm_fields = dict(
        n_tokens=100,
        n_positions=50,
        token_nll=1.5,
        token_ppl=4.5,
        branch_acceptance_rate=0.6,
        jepa_top1_acc=0.4,
        jepa_chance_top1=0.1,
        merge_precision_js=0.01,
        random_pair_js=0.05,
        merge_precision_advantage=5.0,
        compute_per_accepted_token=1e6,
    )
    baseline_fields = dict(
        n_tokens=100,
        token_nll=1.5,
        token_ppl=4.5,
        compute_per_token=2e6,
    )
    report = compare(
        SGJMEvalMetrics(**sgjm_fields),
        BaselineEvalMetrics(**baseline_fields),
    )
    assert report.gate_passed
    assert report.compute_advantage > 1.0
def test_compare_fails_gate_on_high_nll():
    """compare() fails the gate, citing NLL, when SGJM's NLL is far above baseline.

    SGJM reports token_nll=3.0 against the baseline's 1.5; at least one gate
    reason must mention "nll".
    """
    sgjm = SGJMEvalMetrics(
        n_tokens=100,
        n_positions=50,
        token_nll=3.0,
        token_ppl=20.0,
        branch_acceptance_rate=0.6,
        jepa_top1_acc=0.4,
        jepa_chance_top1=0.1,
        merge_precision_js=0.01,
        random_pair_js=0.05,
        merge_precision_advantage=5.0,
        compute_per_accepted_token=1e6,
    )
    baseline = BaselineEvalMetrics(
        n_tokens=100,
        token_nll=1.5,
        token_ppl=4.5,
        compute_per_token=2e6,
    )
    report = compare(sgjm, baseline)
    assert not report.gate_passed
    nll_reasons = [reason for reason in report.gate_reasons if "nll" in reason]
    assert nll_reasons
def test_end_to_end_train_then_eval(tmp_path):
    """Smoke-train SGJM and the baseline, reload both checkpoints, and compare.

    Phases run in a fixed order: train both, check both checkpoint paths,
    reload both, then evaluate both on a shared byte dataset.
    """
    cfgs = {
        "sgjm": _smoke_cfg("sgjm", tmp_path / "sgjm"),
        "baseline": _smoke_cfg("baseline", tmp_path / "baseline"),
    }

    # Dicts preserve insertion order, so "sgjm" is always handled before "baseline".
    results = {arch: train(cfg, backend="cpu") for arch, cfg in cfgs.items()}
    for result in results.values():
        assert result.checkpoint_path is not None

    loaded = {
        arch: load_checkpoint(result.checkpoint_path, device="cpu")
        for arch, result in results.items()
    }
    for arch, checkpoint in loaded.items():
        assert checkpoint.arch == arch

    eval_ds = ByteDataset(synthetic_corpus(2048, seed=0), cfgs["sgjm"].optim.seq_len)
    sgjm_metrics = evaluate_sgjm(
        loaded["sgjm"].model,
        cfgs["sgjm"],
        eval_ds,
        n_batches=2,
        n_distractors=4,
        n_merge_pairs=64,
        drafts_per_step=2,
        device="cpu",
    )
    base_metrics = evaluate_baseline(
        loaded["baseline"].model,
        cfgs["baseline"],
        eval_ds,
        n_batches=2,
        device="cpu",
    )

    report = compare(sgjm_metrics, base_metrics)
    # Smoke-trained models are not expected to clear the quality gate, so only
    # the report's shape is checked: a boolean verdict plus positive NLLs.
    assert isinstance(report.gate_passed, bool)
    assert report.sgjm.token_nll > 0
    assert report.baseline.token_nll > 0