| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| import pytest |
|
|
| mlx = pytest.importorskip("mlx.core") |
|
|
| from sgjm.training.config import TrainingConfig |
| from sgjm.training.mlx_backend.trainer import train as mlx_train |
|
|
|
|
def _train_checkpoints(tmp_path: Path) -> tuple[Path, Path]:
    """Run one smoke training per architecture and return both checkpoints.

    An ``"sgjm"`` model is trained first, then a ``"baseline"`` model. Each
    run starts from ``TrainingConfig.smoke()``, is dispatched to the MLX
    trainer with the ``"mlx"`` backend argument, and writes its checkpoints
    into a sub-directory of ``tmp_path`` named after the architecture.

    Args:
        tmp_path: Directory under which the per-architecture checkpoint
            directories are placed.

    Returns:
        ``(sgjm_checkpoint, baseline_checkpoint)`` as reported by the trainer
        results' ``checkpoint_path`` attribute.
    """
    outcomes = {}
    for arch_name in ("sgjm", "baseline"):
        cfg = TrainingConfig.smoke()
        cfg.arch = arch_name
        cfg.checkpoint_dir = str(tmp_path / arch_name)
        outcomes[arch_name] = mlx_train(cfg, "mlx")

    # Stop right here if either run reported no checkpoint path.
    sgjm_ckpt = outcomes["sgjm"].checkpoint_path
    assert sgjm_ckpt is not None
    baseline_ckpt = outcomes["baseline"].checkpoint_path
    assert baseline_ckpt is not None
    return sgjm_ckpt, baseline_ckpt
|
|
|
|
def test_eval_cli_mlx_backend_runs(tmp_path):
    """Drive the eval CLI with ``--backend mlx`` and inspect its JSON report.

    Freshly trained SGJM and baseline checkpoints are passed to the eval
    entry point together with small evaluation settings. The test then
    checks the returned exit code, that the report file was written, and
    that the report has both a ``"sgjm"`` and a ``"baseline"`` entry.
    """
    from sgjm.eval.__main__ import main

    sgjm_path, baseline_path = _train_checkpoints(tmp_path)
    report_path = tmp_path / "report.json"

    # Insertion order is preserved, so flattening yields flag/value pairs
    # in exactly the order they are listed here.
    cli_options = {
        "--sgjm": str(sgjm_path),
        "--baseline": str(baseline_path),
        "--backend": "mlx",
        "--batches": "2",
        "--n-distractors": "4",
        "--n-merge-pairs": "32",
        "--seed": "42",
        "--report": str(report_path),
    }
    argv = [token for flag_and_value in cli_options.items() for token in flag_and_value]
    exit_code = main(argv)

    # Exit codes 0 and 1 are both accepted; anything else fails the test.
    assert exit_code in (0, 1)
    assert report_path.exists()

    report = json.loads(report_path.read_text())
    for section in ("sgjm", "baseline"):
        assert section in report
|
|
|
|
def test_eval_cli_mlx_arg_accepted():
    """Check that a ``--backend`` option listing ``mlx`` as a choice parses it.

    The parser is constructed locally with the choices ``auto``, ``cuda``,
    ``rocm``, ``cpu`` and ``mlx``; parsing ``--backend mlx`` must succeed
    (argparse would raise ``SystemExit`` for a rejected choice) and yield
    ``"mlx"`` as the parsed value.

    NOTE(review): the parser used by ``sgjm.eval.__main__`` is never invoked
    here, so this test does not by itself prove the real CLI accepts
    ``--backend mlx`` -- confirm whether it should exercise the real parser.
    """
    import argparse

    # Imported only so the test fails if the eval entry point cannot be
    # imported; it is intentionally not called.
    from sgjm.eval.__main__ import main  # noqa: F401

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--backend",
        choices=["auto", "cuda", "rocm", "cpu", "mlx"],
        default="auto",
    )
    namespace = parser.parse_args(["--backend", "mlx"])
    assert namespace.backend == "mlx"
|
|