Spaces:
Sleeping
Sleeping
File size: 5,932 Bytes
0738d13 a72e3bd 0738d13 a72e3bd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 | """Tests for red_button.problems (PROJECT.md Section 12)."""
from __future__ import annotations
from pathlib import Path
import pytest
from red_button.problems import (
episode_seed,
ground_truth_map,
load_problems,
sample_problems,
validate_answer,
)
# Absolute path to the generated problem pool: <repo>/data/problems_pool.json,
# resolved relative to this test file's grandparent directory.
POOL_PATH = str(
    Path(__file__).resolve().parents[1].joinpath("data", "problems_pool.json")
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def pool() -> list[dict]:
    """Load the problem pool from ``POOL_PATH`` once per test module.

    Because the fixture is module-scoped, every test in this file receives
    the same list object; tests must treat it as read-only.
    """
    problems = load_problems(POOL_PATH)
    return problems
# ---------------------------------------------------------------------------
# load_problems + pool structure
# ---------------------------------------------------------------------------
def test_load_problems_returns_list_of_dicts_with_required_keys(pool: list[dict]) -> None:
assert isinstance(pool, list)
required = {"id", "problem", "answer", "difficulty"}
for entry in pool:
assert isinstance(entry, dict)
assert required.issubset(entry.keys()), f"missing keys in {entry}"
def test_all_ids_are_unique(pool: list[dict]) -> None:
ids = [e["id"] for e in pool]
assert len(ids) == len(set(ids))
def test_all_answers_are_integers(pool: list[dict]) -> None:
# Explicit type check rules out bool (which is a subclass of int).
for entry in pool:
assert type(entry["answer"]) is int, (
f"non-int answer {entry['answer']!r} in problem id={entry['id']}"
)
def test_pool_size_meets_target(pool: list[dict]) -> None:
# Section 12.3 target: 300 GSM8K + 200 MATH = 500. Generated pool size
# must be >= 500. Lower the floor only with a documented rationale.
assert len(pool) >= 500
def test_all_difficulty_labels_are_valid(pool: list[dict]) -> None:
valid = {"easy", "medium", "hard"}
for entry in pool:
assert entry["difficulty"] in valid, f"bad difficulty in {entry}"
# ---------------------------------------------------------------------------
# sample_problems
# ---------------------------------------------------------------------------
def test_sample_problems_seeded_returns_correct_count(pool: list[dict]) -> None:
    """A seeded draw returns exactly the requested number of problems."""
    requested = 10
    drawn = sample_problems(n=requested, seed=42, problems=pool)
    assert len(drawn) == requested
def test_sample_problems_seeded_is_deterministic(pool: list[dict]) -> None:
    """Two draws with the same seed yield the same ids in the same order."""

    def draw_ids() -> list:
        return [p["id"] for p in sample_problems(n=10, seed=42, problems=pool)]

    assert draw_ids() == draw_ids()
def test_sample_problems_different_seeds_differ(pool: list[dict]) -> None:
    """Seeds 42 and 43 select different sets of problems."""
    id_sets = [
        {p["id"] for p in sample_problems(n=10, seed=s, problems=pool)}
        for s in (42, 43)
    ]
    assert id_sets[0] != id_sets[1]
def test_sample_problems_unseeded_returns_correct_count(pool: list[dict]) -> None:
    """Omitting ``seed`` still yields exactly the requested number of problems."""
    requested = 10
    drawn = sample_problems(n=requested, problems=pool)
    assert len(drawn) == requested
def test_sample_problems_n_equals_twenty_works(pool: list[dict]) -> None:
    """A 20-problem draw returns 20 entries with no repeated ids."""
    requested = 20
    drawn = sample_problems(n=requested, seed=7, problems=pool)
    assert len(drawn) == requested
    # Sampling is without replacement, so every id must be distinct.
    distinct_ids = {p["id"] for p in drawn}
    assert len(distinct_ids) == requested
def test_sample_problems_n_exceeds_pool_raises(pool: list[dict]) -> None:
    """Requesting more problems than the pool holds raises ``ValueError``."""
    # Derive n from the pool instead of hard-coding a large number. This keeps
    # the test exercising the error path however large the pool grows, and
    # probes the exact boundary (one more than available).
    with pytest.raises(ValueError):
        sample_problems(n=len(pool) + 1, seed=0, problems=pool)
# ---------------------------------------------------------------------------
# validate_answer
# ---------------------------------------------------------------------------
def test_validate_answer_correct_returns_true(pool: list[dict]) -> None:
    """The stored answer for the first problem validates as exactly ``True``."""
    target = pool[0]
    verdict = validate_answer(target["id"], target["answer"], pool)
    assert verdict is True
def test_validate_answer_wrong_returns_false(pool: list[dict]) -> None:
    """An answer that differs from the stored one validates as exactly ``False``."""
    target = pool[0]
    incorrect = target["answer"] + 99999
    verdict = validate_answer(target["id"], incorrect, pool)
    assert verdict is False
def test_validate_answer_missing_id_returns_false(pool: list[dict]) -> None:
    """Validating an id absent from the pool returns ``False`` rather than raising."""
    missing_id = 10_000_000
    # Guard the precondition: if the pool ever contained this id, the test
    # would silently be checking a wrong answer instead of a missing id.
    assert missing_id not in {entry["id"] for entry in pool}, (
        f"id {missing_id} unexpectedly present in pool"
    )
    assert validate_answer(missing_id, 42, pool) is False
# ---------------------------------------------------------------------------
# ground_truth_map
# ---------------------------------------------------------------------------
def test_ground_truth_map_entries_count_matches_input(pool: list[dict]) -> None:
    """``ground_truth_map`` yields as many entries as there are input problems."""
    mapping = ground_truth_map(pool)
    assert len(pool) == len(mapping)
def test_ground_truth_map_keys_are_ints_values_are_ints(pool: list[dict]) -> None:
    """Every key and value in the ground-truth map is exactly ``int`` (not ``bool``)."""
    mapping = ground_truth_map(pool)
    for key in mapping:
        assert type(key) is int
        assert type(mapping[key]) is int
# ---------------------------------------------------------------------------
# episode_seed
# ---------------------------------------------------------------------------
def test_episode_seed_is_deterministic() -> None:
    """Deriving a seed from the same episode id twice gives the same value."""
    first = episode_seed("abc")
    second = episode_seed("abc")
    assert first == second
def test_episode_seed_differs_for_different_ids() -> None:
    """Distinct episode ids (``abc`` vs ``def``) map to distinct seeds."""
    seed_abc = episode_seed("abc")
    seed_def = episode_seed("def")
    assert seed_abc != seed_def
def test_episode_seed_returns_int() -> None:
    """``episode_seed`` returns exactly an ``int``.

    The exact-type comparison also rejects ``bool`` (an ``int`` subclass),
    so a separate ``isinstance`` check would be redundant.
    """
    seed = episode_seed("x")
    assert type(seed) is int, (
        f"episode_seed returned {type(seed).__name__}, expected int"
    )
def test_sample_problems_with_episode_seed_is_deterministic(pool: list[dict]) -> None:
    """A seed from ``episode_seed`` makes repeated draws reproducible."""
    seed = episode_seed("test-1")

    def draw_ids() -> list:
        return [p["id"] for p in sample_problems(n=10, seed=seed, problems=pool)]

    assert draw_ids() == draw_ids()
def test_sample_problems_with_different_episode_seeds_differ(pool: list[dict]) -> None:
    """Seeds derived from different episode ids select different problem sets."""
    id_sets = [
        {
            p["id"]
            for p in sample_problems(n=10, seed=episode_seed(episode), problems=pool)
        }
        for episode in ("test-1", "test-2")
    ]
    assert id_sets[0] != id_sets[1]
|