#!/usr/bin/env python3
"""Production safety tests for key pipeline utilities."""
from __future__ import annotations
import json
import sys
import tempfile
import unittest
from unittest import mock
from pathlib import Path
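
# `datasets` is optional in this environment; the tests that need it skip themselves.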
try:
    from datasets import Dataset
except ModuleNotFoundError:  # pragma: no cover - optional test dependency in this environment
    Dataset = None
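
# Make the repository root importable so `app` and the `scripts` package resolve
# when the tests run straight from a source checkout.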
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

try:
    import app
except Exception:  # pragma: no cover - optional test dependency in this environment
    app = None

try:
    from scripts import eval_sota
except Exception:  # pragma: no cover - optional test dependency in this environment
    eval_sota = None

try:
    from scripts import train_sota
except Exception:  # pragma: no cover - optional test dependency in this environment
    train_sota = None


@unittest.skipUnless(app is not None, "app runtime dependencies are not installed")
class AppUtilityTests(unittest.TestCase):
    def test_validate_repo_id_accepts_valid(self) -> None:
        self.assertEqual(
            app.validate_repo_id("NorthernTribe-Research/math_trainer", "Model repo"),
            "NorthernTribe-Research/math_trainer",
        )

    def test_validate_repo_id_rejects_invalid(self) -> None:
        with self.assertRaises(ValueError):
            app.validate_repo_id("invalid repo id", "Model repo")

    def test_merge_log_chunk_truncates(self) -> None:
        merged = app._merge_log_chunk("a" * 9, "b" * 9, max_chars=10)
        self.assertEqual(len(merged), 10)
        self.assertTrue(merged.endswith("b" * 9))

    def test_build_stage_timeline_returns_list_markup(self) -> None:
        stage_meta = {"start_stage": 1, "stage_count": 2, "completed": 1, "active_stage": 2}
        html = app._build_stage_timeline({}, stage_meta)
        self.assertIn("ops-stage-list", html)
        self.assertIn("ops-stage-item", html)

    def test_validate_stage_window_rejects_overflow(self) -> None:
        with self.assertRaises(ValueError):
            app.validate_stage_window(app.TEMPLATE_STAGE_COUNT, 2)

    def test_build_recent_runs_panel_markup(self) -> None:
        summary = {
            "recent_runs": [
                {
                    "run_label": "run-20260101-000000",
                    "result": "completed",
                    "duration_seconds": 42,
                    "finished_at_utc": "2026-01-01 00:00:42 UTC",
                    "evaluation": {"pass_at_1": 0.11, "pass_at_k": 0.27, "evaluated_rows": 128},
                }
            ]
        }
        html = app._build_recent_runs_panel(summary)
        self.assertIn("ops-run-list", html)
        self.assertIn("run-20260101-000000", html)
        self.assertIn("completed", html)

    def test_run_result_badge_class_handles_preflight_variants(self) -> None:
        self.assertEqual(app._run_result_badge_class("preflight_passed"), "ok")
        self.assertEqual(app._run_result_badge_class("preflight passed"), "ok")

    def test_persist_run_artifacts_updates_history(self) -> None:
        with tempfile.TemporaryDirectory() as tmpdir:
            history_path = Path(tmpdir) / "run_history.json"
            records_dir = Path(tmpdir) / "run_records"
            summary = {
                "run_label": "run-20260102-030405",
                "result": "completed",
                "started_at_utc": "2026-01-02 03:04:05 UTC",
                "finished_at_utc": "2026-01-02 03:04:35 UTC",
                "evaluation": {"pass_at_1": 0.1, "pass_at_k": 0.2, "evaluated_rows": 64},
            }
            with mock.patch.object(app, "RUN_HISTORY_PATH", history_path):
                with mock.patch.object(app, "RUN_RECORDS_DIR", records_dir):
                    warning = app.persist_run_artifacts(summary)
            self.assertIsNone(warning)
            self.assertTrue(history_path.exists())
            payload = json.loads(history_path.read_text(encoding="utf-8"))
            self.assertEqual(payload[0]["run_label"], "run-20260102-030405")
            self.assertEqual(payload[0]["result"], "completed")
            self.assertTrue((records_dir / "run-20260102-030405.json").exists())


@unittest.skipUnless(eval_sota is not None, "eval_sota runtime dependencies are not installed")
class EvalUtilityTests(unittest.TestCase):
    def test_parse_numeric_fraction(self) -> None:
        value = eval_sota.parse_numeric_value("3/4")
        self.assertIsNotNone(value)
        # The bare assert narrows Optional[float] for static type checkers.
        assert value is not None
        self.assertAlmostEqual(value, 0.75, places=8)

    def test_match_candidate_boxed(self) -> None:
        result = eval_sota.match_candidate(r"\boxed{42}", ["42"])
        self.assertTrue(result["match"])
        self.assertTrue(result["boxed"] or result["exact"])

    def test_infer_response_profile_handles_formal_and_non_formal_rows(self) -> None:
        formal_row = {"family": "formal_proof", "difficulty": "formal_proof"}
        simple_row = {"family": "problem_solving", "difficulty": "basic"}
        self.assertEqual(eval_sota.infer_response_profile(formal_row), "lean_formal")
        self.assertEqual(eval_sota.infer_response_profile(simple_row), "simple")


@unittest.skipUnless(train_sota is not None, "train_sota runtime dependencies are not installed")
class TrainUtilityTests(unittest.TestCase):
    def test_as_bool_conversions(self) -> None:
        self.assertTrue(train_sota.as_bool("yes"))
        self.assertFalse(train_sota.as_bool("no"))
        self.assertTrue(train_sota.as_bool(True))
        self.assertFalse(train_sota.as_bool(None, default=False))

    def test_canonical_difficulty_mappings(self) -> None:
        self.assertEqual(train_sota.canonical_difficulty("basic_to_intermediate"), "simple")
        self.assertEqual(train_sota.canonical_difficulty("formal_proof"), "lean_formal")
        self.assertEqual(train_sota.canonical_difficulty("olympiad"), "advanced")

    def test_apply_filters_include_bands_and_require_lean_formal(self) -> None:
        if Dataset is None:
            self.skipTest("datasets is not installed")
        dataset = Dataset.from_dict(
            {
                "family": ["formal_proof", "problem_solving", "competition"],
                "task_type": ["theorem_proving", "word_problem", "olympiad"],
                "source_dataset": ["src-a", "src-b", "src-c"],
                "difficulty": ["formal_proof", "basic_to_intermediate", "olympiad"],
                "conjecture_id": ["c1", "c2", "c3"],
                "sample_weight": [1.0, 1.0, 1.0],
            }
        )
        filtered = train_sota.apply_filters(
            dataset,
            {
                "include_difficulty_bands": ["lean_formal", "simple"],
                "require_lean_formal": True,
            },
        )
        self.assertEqual(len(filtered), 1)
        self.assertEqual(filtered[0]["family"], "formal_proof")
        self.assertEqual(filtered[0]["difficulty"], "formal_proof")
    def test_build_tokenizer_falls_back_when_protobuf_missing(self) -> None:
        class DummyTokenizer:
            def __init__(self) -> None:
                self.pad_token = None
                self.eos_token = "<eos>"
                self.unk_token = "<unk>"

            def add_special_tokens(self, tokens):
                self.pad_token = tokens.get("pad_token")

        calls = []

        def fake_from_pretrained(*args, **kwargs):
            calls.append(kwargs.get("use_fast"))
            if kwargs.get("use_fast"):
                raise ImportError("requires the protobuf library")
            return DummyTokenizer()

        with mock.patch.object(train_sota.AutoTokenizer, "from_pretrained", side_effect=fake_from_pretrained):
            tok = train_sota.build_tokenizer({"base_model": "dummy/model", "trust_remote_code": False})
        self.assertEqual(calls, [True, False])
        self.assertEqual(tok.pad_token, "<eos>")


@unittest.skipUnless(eval_sota is not None, "eval_sota runtime dependencies are not installed")
class EvalTokenizerFallbackTests(unittest.TestCase):
    def test_eval_tokenizer_falls_back_when_protobuf_missing(self) -> None:
        class DummyTokenizer:
            def __init__(self) -> None:
                self.pad_token = None
                self.eos_token = "<eos>"
                self.unk_token = "<unk>"

            def add_special_tokens(self, tokens):
                self.pad_token = tokens.get("pad_token")

        class DummyModel:
            def eval(self):
                return None

        calls = []

        def fake_tok_from_pretrained(*args, **kwargs):
            calls.append(kwargs.get("use_fast"))
            if kwargs.get("use_fast"):
                raise ImportError("requires the protobuf library")
            return DummyTokenizer()

        with mock.patch.object(eval_sota.AutoTokenizer, "from_pretrained", side_effect=fake_tok_from_pretrained):
            with mock.patch.object(eval_sota.AutoModelForCausalLM, "from_pretrained", return_value=DummyModel()):
                model, tok = eval_sota.load_model_and_tokenizer(
                    base_model="dummy/model",
                    adapter_path=None,
                    trust_remote_code=False,
                )
        self.assertIsNotNone(model)
        self.assertEqual(calls, [True, False])
        self.assertEqual(tok.pad_token, "<eos>")
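

# Guard rails for continuous mode: repeated failures must halt the loop, and a
# cancel request during the restart cooldown must stop it.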
@unittest.skipUnless(app is not None, "app runtime dependencies are not installed")
class ContinuousModeSafetyTests(unittest.TestCase):
    def test_continuous_mode_halts_after_consecutive_failures(self) -> None:
        original_max = app.CONTINUOUS_MAX_CONSECUTIVE_FAILURES
        original_delay = app.CONTINUOUS_RESTART_DELAY_SECONDS
        app.CONTINUOUS_MAX_CONSECUTIVE_FAILURES = 2
        app.CONTINUOUS_RESTART_DELAY_SECONDS = 0
        self.addCleanup(setattr, app, "CONTINUOUS_MAX_CONSECUTIVE_FAILURES", original_max)
        self.addCleanup(setattr, app, "CONTINUOUS_RESTART_DELAY_SECONDS", original_delay)
        def fake_pipeline_core(**kwargs):
            summary = json.dumps({"result": "failed"})
            yield "line-1", "Failed", summary

        with mock.patch.object(app, "run_pipeline_core", side_effect=fake_pipeline_core):
            outputs = list(
                app.run_pipeline(
                    dataset_repo_id="owner/dataset",
                    model_repo_id="owner/model",
                    base_model_id="model/base",
                    autonomous_mode=False,
                    continuous_mode=True,
                    start_stage=1,
                    max_stages=1,
                    run_eval=False,
                    eval_k=1,
                    eval_samples=50,
                    enforce_quality_gate=False,
                    gate_min_pass_at_1=0.0,
                    gate_min_pass_at_k=0.0,
                    gate_min_rows=10,
                    push_to_hub=False,
                    force_redownload=False,
                    preflight_only=False,
                )
            )
        self.assertGreaterEqual(len(outputs), 3)
        last_status = outputs[-1][1]
        self.assertIn("halted", last_status.lower())

    def test_continuous_mode_cooldown_stops_on_cancel(self) -> None:
        original_max = app.CONTINUOUS_MAX_CONSECUTIVE_FAILURES
        original_delay = app.CONTINUOUS_RESTART_DELAY_SECONDS
        app.CONTINUOUS_MAX_CONSECUTIVE_FAILURES = 3
        app.CONTINUOUS_RESTART_DELAY_SECONDS = 1
        self.addCleanup(setattr, app, "CONTINUOUS_MAX_CONSECUTIVE_FAILURES", original_max)
        self.addCleanup(setattr, app, "CONTINUOUS_RESTART_DELAY_SECONDS", original_delay)

        def fake_pipeline_core(**kwargs):
            summary = json.dumps({"result": "completed"})
            yield "line-1", "Completed", summary

        with mock.patch.object(app, "run_pipeline_core", side_effect=fake_pipeline_core):
            with mock.patch.object(app, "is_cancel_requested", return_value=True):
                outputs = list(
                    app.run_pipeline(
                        dataset_repo_id="owner/dataset",
                        model_repo_id="owner/model",
                        base_model_id="model/base",
                        autonomous_mode=False,
                        continuous_mode=True,
                        start_stage=1,
                        max_stages=1,
                        run_eval=False,
                        eval_k=1,
                        eval_samples=50,
                        enforce_quality_gate=False,
                        gate_min_pass_at_1=0.0,
                        gate_min_pass_at_k=0.0,
                        gate_min_rows=10,
                        push_to_hub=False,
                        force_redownload=False,
                        preflight_only=False,
                    )
                )
        self.assertGreaterEqual(len(outputs), 3)
        self.assertIn("stopped", outputs[-1][1].lower())


if __name__ == "__main__":
    unittest.main(verbosity=2)