from __future__ import annotations

import pytest

# Skip the whole module when the MLX runtime cannot be imported.
mlx = pytest.importorskip("mlx.core")

from sgjm.training.config import TrainingConfig
from sgjm.training.data import ByteDataset, synthetic_corpus


def _smoke_dataset(cfg: TrainingConfig, corpus_seed: int) -> ByteDataset:
    """Wrap a 4096-unit synthetic corpus in a ByteDataset using cfg's seq_len."""
    return ByteDataset(synthetic_corpus(4096, seed=corpus_seed), cfg.optim.seq_len)


def test_mlx_evaluate_sgjm_returns_finite_metrics() -> None:
    """SGJM evaluation yields rates within [0, 1] and a positive compute cost."""
    from sgjm.eval.mlx_metrics import evaluate_sgjm
    from sgjm.training.mlx_backend.model import SGJM

    cfg = TrainingConfig.smoke()
    sgjm_model = SGJM(cfg.model)
    dataset = _smoke_dataset(cfg, corpus_seed=0)

    result = evaluate_sgjm(
        sgjm_model,
        cfg,
        dataset,
        n_batches=2,
        n_distractors=4,
        n_merge_pairs=128,
        drafts_per_step=2,
        seed=42,
    )

    assert 0.0 <= result.branch_acceptance_rate <= 1.0
    assert 0.0 <= result.jepa_top1_acc <= 1.0
    assert result.compute_per_accepted_token > 0


def test_mlx_evaluate_baseline_returns_finite_metrics() -> None:
    """Baseline evaluation reports a positive token NLL and compute cost."""
    from sgjm.eval.mlx_metrics import evaluate_baseline
    from sgjm.training.mlx_backend.baseline import BaselineLM

    cfg = TrainingConfig.smoke()
    baseline_model = BaselineLM(cfg.model)
    dataset = _smoke_dataset(cfg, corpus_seed=0)

    result = evaluate_baseline(baseline_model, cfg, dataset, n_batches=2, seed=42)

    assert result.token_nll > 0
    assert result.compute_per_token > 0


def test_mlx_evaluate_sgjm_non_nan_fields() -> None:
    """No float field of the SGJM metrics is NaN, apart from the merge fields."""
    from sgjm.eval.mlx_metrics import evaluate_sgjm
    from sgjm.training.mlx_backend.model import SGJM

    cfg = TrainingConfig.smoke()
    sgjm_model = SGJM(cfg.model)
    dataset = _smoke_dataset(cfg, corpus_seed=1)

    result = evaluate_sgjm(
        sgjm_model,
        cfg,
        dataset,
        n_batches=2,
        n_distractors=4,
        n_merge_pairs=128,
        drafts_per_step=2,
        seed=99,
    )

    # These two fields are exempt from the NaN check.
    nan_ok = {"merge_precision_js", "random_pair_js"}
    for k, v in result.to_dict().items():
        if k in nan_ok or not isinstance(v, float):
            continue
        # NaN is the only float value that compares unequal to itself.
        assert v == v, f"{k} is NaN"


def test_mlx_evaluate_baseline_token_count_correct() -> None:
    """n_tokens equals n_batches * batch_size * seq_len for the baseline run."""
    from sgjm.eval.mlx_metrics import evaluate_baseline
    from sgjm.training.mlx_backend.baseline import BaselineLM

    cfg = TrainingConfig.smoke()
    baseline_model = BaselineLM(cfg.model)
    dataset = _smoke_dataset(cfg, corpus_seed=2)

    n_batches = 3
    result = evaluate_baseline(
        baseline_model, cfg, dataset, n_batches=n_batches, seed=7
    )

    tokens_per_batch = cfg.optim.batch_size * cfg.optim.seq_len
    assert result.n_tokens == n_batches * tokens_per_batch