HONEST-RL-Calibrator / data /tests /test_integration.py
Rushhaabhhh's picture
HONEST-RL-Calibrator-v0
3040767 verified
"""End-to-end integration test for the data-sampler adapter layer.
Simulates exactly what server/environment.py does when it imports and
calls the adapter generate() functions.
Run from the project root:
PYTHONPATH=. pytest data/tests/test_integration.py -v
"""
from __future__ import annotations
import json
import warnings
import pytest
# ---------------------------------------------------------------------------
# 1. Import the adapter functions the same way the environment would
# ---------------------------------------------------------------------------
from data.sampler.math_gen_adapter import generate as math_generate
from data.sampler.code_gen_adapter import generate as code_generate
from data.sampler.logic_gen_adapter import generate as logic_generate
from data.sampler.environment_adapter import get_sampler
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _call(fn, difficulty: int, seed: int = 0):
"""Call fn, suppressing fallback warnings for empty buckets."""
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
return fn(difficulty, seed=seed)
# ---------------------------------------------------------------------------
# 2. Each adapter's generate() at every difficulty 1–5
# ---------------------------------------------------------------------------
class TestAdapterSignatures:
"""Verify (str, str, str) contract is preserved for every domain × difficulty."""
@pytest.mark.parametrize("diff", [1, 2, 3, 4, 5])
def test_math_generate_returns_str_triple(self, diff):
q, a, pid = _call(math_generate, diff)
assert isinstance(q, str) and len(q) > 0, f"diff={diff}: question is empty"
assert isinstance(a, str), f"diff={diff}: answer is not str"
assert isinstance(pid, str) and len(pid) > 0, f"diff={diff}: problem_id is empty"
@pytest.mark.parametrize("diff", [1, 2, 3, 4, 5])
def test_code_generate_returns_str_triple(self, diff):
q, a, pid = _call(code_generate, diff)
assert isinstance(q, str) and len(q) > 0, f"diff={diff}: question is empty"
assert isinstance(a, str), f"diff={diff}: answer is not str"
assert isinstance(pid, str) and len(pid) > 0, f"diff={diff}: problem_id is empty"
@pytest.mark.parametrize("diff", [1, 2, 3, 4, 5])
def test_logic_generate_returns_str_triple(self, diff):
q, a, pid = _call(logic_generate, diff)
assert isinstance(q, str) and len(q) > 0, f"diff={diff}: question is empty"
assert isinstance(a, str), f"diff={diff}: answer is not str"
assert isinstance(pid, str) and len(pid) > 0, f"diff={diff}: problem_id is empty"
def test_logic_answer_is_valid_json_dict_at_d3(self):
"""ZebraLogic canonical answer (difficulty >= 3) must be a JSON-parseable dict."""
_, a, _ = _call(logic_generate, 3)
parsed = json.loads(a)
assert isinstance(parsed, dict) and len(parsed) > 0
def test_adapters_are_deterministic_with_seed(self):
q1, a1, p1 = math_generate(2, seed=42)
q2, a2, p2 = math_generate(2, seed=42)
assert q1 == q2 and a1 == a2 and p1 == p2
def test_adapters_vary_without_seed(self):
"""Two calls without a seed should (almost always) return different questions."""
results = {math_generate(3)[0] for _ in range(5)}
assert len(results) > 1, "Five un-seeded calls all returned the same question"
# ---------------------------------------------------------------------------
# 3. verify() on known-correct and known-wrong answers
# ---------------------------------------------------------------------------
class TestVerifyDispatch:
"""Confirm UnifiedSampler.verify() dispatches correctly and returns bool."""
@pytest.fixture(scope="class")
def sampler(self):
return get_sampler()
def _sample_id_and_answer(self, sampler, domain: str, generate_fn, difficulty: int):
"""Sample a problem and use the returned problem_id directly."""
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
_, canonical, pid = generate_fn(difficulty, seed=7)
assert pid in sampler._by_id, (
f"Returned problem_id={pid!r} not in sampler._by_id for domain={domain}"
)
return pid, canonical
# -- Correct answers should pass --
def test_math_correct_answer_passes(self, sampler):
pid, canon = self._sample_id_and_answer(sampler, "math", math_generate, 1)
assert sampler.verify(pid, canon) is True
def test_code_correct_answer_passes(self, sampler):
pid, canon = self._sample_id_and_answer(sampler, "code", code_generate, 1)
assert sampler.verify(pid, canon) is True
def test_logic_correct_answer_passes(self, sampler):
pid, canon = self._sample_id_and_answer(sampler, "logic", logic_generate, 3)
assert sampler.verify(pid, canon) is True
# -- Wrong answers should fail --
def test_math_wrong_answer_fails(self, sampler):
pid, _ = self._sample_id_and_answer(sampler, "math", math_generate, 1)
assert sampler.verify(pid, "999999") is False
def test_logic_wrong_answer_fails(self, sampler):
pid, _ = self._sample_id_and_answer(sampler, "logic", logic_generate, 3)
# Completely wrong JSON grid
wrong = json.dumps({"House 1": {"Name": "WRONG", "Pet": "WRONG", "Drink": "WRONG"}})
assert sampler.verify(pid, wrong) is False
def test_verify_returns_bool_type(self, sampler):
"""Ensure return type is exactly bool, not a truthy/falsy value."""
pid, canon = self._sample_id_and_answer(sampler, "math", math_generate, 2)
result = sampler.verify(pid, canon)
assert type(result) is bool # noqa: E721
def test_verify_unknown_id_returns_false(self, sampler):
assert sampler.verify("__nonexistent__", "42") is False
# -- verify() never raises --
def test_verify_does_not_raise_on_garbage_input(self, sampler):
pid, _ = self._sample_id_and_answer(sampler, "math", math_generate, 1)
# All of these should return False, never raise
for bad in ["", "\x00\xff", "null", "[]", "{}", "NaN", " "]:
result = sampler.verify(pid, bad)
assert isinstance(result, bool), f"verify() raised or returned non-bool for input={bad!r}"
# ---------------------------------------------------------------------------
# 4. Singleton is shared across adapter modules
# ---------------------------------------------------------------------------
class TestSingleton:
def test_singleton_identity(self):
"""All three adapters share the exact same UnifiedSampler instance."""
from data.sampler.environment_adapter import get_sampler as ga
s1 = ga()
s2 = ga()
assert s1 is s2
def test_singleton_loaded_once(self):
"""Second call to get_sampler() does not re-load data (same object)."""
s1 = get_sampler()
count_before = s1.total_count()
s2 = get_sampler()
assert s2.total_count() == count_before
assert s1 is s2