SGJM / tests /test_eval_mlx.py
adampippert's picture
SGJM 2026.6.5 — code/docs
e51ccda verified
Raw
History Blame Contribute Delete
2.84 kB
from __future__ import annotations
import pytest
mlx = pytest.importorskip("mlx.core")
from sgjm.training.config import TrainingConfig
from sgjm.training.data import ByteDataset, synthetic_corpus
def test_mlx_evaluate_sgjm_returns_finite_metrics():
    """Check that evaluate_sgjm on an MLX SGJM model yields in-range metrics.

    A model is built from the smoke training config and evaluated on a
    synthetic byte corpus. The two rate-style metrics must lie in [0, 1]
    and the compute-per-accepted-token figure must be strictly positive.
    """
    from sgjm.eval.mlx_metrics import evaluate_sgjm
    from sgjm.training.mlx_backend.model import SGJM

    config = TrainingConfig.smoke()
    sgjm_model = SGJM(config.model)
    dataset = ByteDataset(synthetic_corpus(4096, seed=0), config.optim.seq_len)

    # Small evaluation budget; values are passed through unchanged as keywords.
    eval_options = {
        "n_batches": 2,
        "n_distractors": 4,
        "n_merge_pairs": 128,
        "drafts_per_step": 2,
        "seed": 42,
    }
    result = evaluate_sgjm(sgjm_model, config, dataset, **eval_options)

    acceptance_rate = result.branch_acceptance_rate
    assert 0.0 <= acceptance_rate and acceptance_rate <= 1.0

    top1_accuracy = result.jepa_top1_acc
    assert 0.0 <= top1_accuracy and top1_accuracy <= 1.0

    assert result.compute_per_accepted_token > 0
def test_mlx_evaluate_baseline_returns_finite_metrics():
    """Check that evaluate_baseline on an MLX BaselineLM yields positive metrics.

    A baseline model is built from the smoke training config and evaluated
    on a synthetic byte corpus. Both the token NLL and the compute-per-token
    figure must be strictly positive.
    """
    from sgjm.eval.mlx_metrics import evaluate_baseline
    from sgjm.training.mlx_backend.baseline import BaselineLM

    config = TrainingConfig.smoke()
    baseline_model = BaselineLM(config.model)

    raw_bytes = synthetic_corpus(4096, seed=0)
    dataset = ByteDataset(raw_bytes, config.optim.seq_len)

    result = evaluate_baseline(
        baseline_model,
        config,
        dataset,
        n_batches=2,
        seed=42,
    )

    assert result.token_nll > 0
    assert result.compute_per_token > 0
def test_mlx_evaluate_sgjm_non_nan_fields():
    """Check that no float field of the SGJM metrics is NaN, merge fields aside.

    A model is built from the smoke training config, evaluated on a
    synthetic byte corpus, and every float value in ``metrics.to_dict()``
    is checked for NaN. The ``merge_precision_js`` and ``random_pair_js``
    entries are exempt from the check.

    Note that only NaN is rejected; infinite values are not flagged.
    """
    import math

    from sgjm.eval.mlx_metrics import evaluate_sgjm
    from sgjm.training.mlx_backend.model import SGJM

    cfg = TrainingConfig.smoke()
    model = SGJM(cfg.model)
    corpus = synthetic_corpus(4096, seed=1)
    ds = ByteDataset(corpus, cfg.optim.seq_len)
    metrics = evaluate_sgjm(
        model,
        cfg,
        ds,
        n_batches=2,
        n_distractors=4,
        n_merge_pairs=128,
        drafts_per_step=2,
        seed=99,
    )

    # Fields that are allowed to be NaN and are therefore skipped.
    nan_ok = {"merge_precision_js", "random_pair_js"}
    for k, v in metrics.to_dict().items():
        if isinstance(v, float) and k not in nan_ok:
            # math.isnan states the intent directly, unlike the `v == v` trick.
            assert not math.isnan(v), f"{k} is NaN"
def test_mlx_evaluate_baseline_token_count_correct():
    """Check that the baseline metrics report the expected token count.

    The expected count is ``n_batches * batch_size * seq_len``, with batch
    size and sequence length taken from the smoke config's ``optim`` section.
    """
    from sgjm.eval.mlx_metrics import evaluate_baseline
    from sgjm.training.mlx_backend.baseline import BaselineLM

    config = TrainingConfig.smoke()
    baseline_model = BaselineLM(config.model)
    dataset = ByteDataset(synthetic_corpus(4096, seed=2), config.optim.seq_len)

    n_batches = 3
    result = evaluate_baseline(
        baseline_model, config, dataset, n_batches=n_batches, seed=7
    )

    tokens_per_batch = config.optim.batch_size * config.optim.seq_len
    expected_tokens = n_batches * tokens_per_batch
    assert result.n_tokens == expected_tokens