Red-Button / tests /test_problems.py
Arun-Sanjay's picture
Pre-Phase-7 cleanup: migrate models.py to turn-budget fields, add episode_seed helper for deterministic-per-UUID sampling
a72e3bd
"""Tests for red_button.problems (PROJECT.md Section 12)."""
from __future__ import annotations
from pathlib import Path
import pytest
from red_button.problems import (
episode_seed,
ground_truth_map,
load_problems,
sample_problems,
validate_answer,
)
POOL_PATH = str(Path(__file__).resolve().parents[1] / "data" / "problems_pool.json")
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def pool() -> list[dict]:
return load_problems(POOL_PATH)
# ---------------------------------------------------------------------------
# load_problems + pool structure
# ---------------------------------------------------------------------------
def test_load_problems_returns_list_of_dicts_with_required_keys(pool: list[dict]) -> None:
assert isinstance(pool, list)
required = {"id", "problem", "answer", "difficulty"}
for entry in pool:
assert isinstance(entry, dict)
assert required.issubset(entry.keys()), f"missing keys in {entry}"
def test_all_ids_are_unique(pool: list[dict]) -> None:
ids = [e["id"] for e in pool]
assert len(ids) == len(set(ids))
def test_all_answers_are_integers(pool: list[dict]) -> None:
# Explicit type check rules out bool (which is a subclass of int).
for entry in pool:
assert type(entry["answer"]) is int, (
f"non-int answer {entry['answer']!r} in problem id={entry['id']}"
)
def test_pool_size_meets_target(pool: list[dict]) -> None:
# Section 12.3 target: 300 GSM8K + 200 MATH = 500. Generated pool size
# must be >= 500. Lower the floor only with a documented rationale.
assert len(pool) >= 500
def test_all_difficulty_labels_are_valid(pool: list[dict]) -> None:
valid = {"easy", "medium", "hard"}
for entry in pool:
assert entry["difficulty"] in valid, f"bad difficulty in {entry}"
# ---------------------------------------------------------------------------
# sample_problems
# ---------------------------------------------------------------------------
def test_sample_problems_seeded_returns_correct_count(pool: list[dict]) -> None:
sample = sample_problems(n=10, seed=42, problems=pool)
assert len(sample) == 10
def test_sample_problems_seeded_is_deterministic(pool: list[dict]) -> None:
a = sample_problems(n=10, seed=42, problems=pool)
b = sample_problems(n=10, seed=42, problems=pool)
assert [e["id"] for e in a] == [e["id"] for e in b]
def test_sample_problems_different_seeds_differ(pool: list[dict]) -> None:
a_ids = {e["id"] for e in sample_problems(n=10, seed=42, problems=pool)}
b_ids = {e["id"] for e in sample_problems(n=10, seed=43, problems=pool)}
assert a_ids != b_ids
def test_sample_problems_unseeded_returns_correct_count(pool: list[dict]) -> None:
sample = sample_problems(n=10, problems=pool)
assert len(sample) == 10
def test_sample_problems_n_equals_twenty_works(pool: list[dict]) -> None:
sample = sample_problems(n=20, seed=7, problems=pool)
assert len(sample) == 20
# Sampling is without replacement.
assert len({e["id"] for e in sample}) == 20
def test_sample_problems_n_exceeds_pool_raises(pool: list[dict]) -> None:
with pytest.raises(ValueError):
sample_problems(n=10000, seed=0, problems=pool)
# ---------------------------------------------------------------------------
# validate_answer
# ---------------------------------------------------------------------------
def test_validate_answer_correct_returns_true(pool: list[dict]) -> None:
first = pool[0]
assert validate_answer(first["id"], first["answer"], pool) is True
def test_validate_answer_wrong_returns_false(pool: list[dict]) -> None:
first = pool[0]
wrong = first["answer"] + 99999
assert validate_answer(first["id"], wrong, pool) is False
def test_validate_answer_missing_id_returns_false(pool: list[dict]) -> None:
# Missing id returns False, not an exception.
assert validate_answer(10_000_000, 42, pool) is False
# ---------------------------------------------------------------------------
# ground_truth_map
# ---------------------------------------------------------------------------
def test_ground_truth_map_entries_count_matches_input(pool: list[dict]) -> None:
gt = ground_truth_map(pool)
assert len(gt) == len(pool)
def test_ground_truth_map_keys_are_ints_values_are_ints(pool: list[dict]) -> None:
gt = ground_truth_map(pool)
for k, v in gt.items():
assert type(k) is int
assert type(v) is int
# ---------------------------------------------------------------------------
# episode_seed
# ---------------------------------------------------------------------------
def test_episode_seed_is_deterministic() -> None:
assert episode_seed("abc") == episode_seed("abc")
def test_episode_seed_differs_for_different_ids() -> None:
assert episode_seed("abc") != episode_seed("def")
def test_episode_seed_returns_int() -> None:
seed = episode_seed("x")
# Guard against ``bool`` being an ``int`` subclass: require exact type.
assert isinstance(seed, int)
assert type(seed) is int
def test_sample_problems_with_episode_seed_is_deterministic(pool: list[dict]) -> None:
seed = episode_seed("test-1")
a = sample_problems(n=10, seed=seed, problems=pool)
b = sample_problems(n=10, seed=seed, problems=pool)
assert [e["id"] for e in a] == [e["id"] for e in b]
def test_sample_problems_with_different_episode_seeds_differ(pool: list[dict]) -> None:
a_ids = {
e["id"]
for e in sample_problems(n=10, seed=episode_seed("test-1"), problems=pool)
}
b_ids = {
e["id"]
for e in sample_problems(n=10, seed=episode_seed("test-2"), problems=pool)
}
assert a_ids != b_ids