"""Tests for redbutton_v2.problems — load, sample, check_answer."""

import json

import pytest

from redbutton_v2.problems import (
    Problem,
    ProblemPool,
    check_answer,
    load_problems,
)


def test_load_problems_loads_default_pool():
    """``load_problems()`` with no path yields >= 5 ``Problem`` instances."""
    pool = load_problems()
    assert len(pool) >= 5
    assert all(isinstance(p, Problem) for p in pool.problems)


def test_load_problems_custom_path(tmp_path):
    """``load_problems(path)`` reads a pool from a caller-supplied JSON file."""
    payload = [{"id": 1, "problem": "x", "answer": 42, "difficulty": "easy"}]
    p = tmp_path / "tiny.json"
    # json.dumps emits ASCII by default; the explicit encoding just keeps the
    # write independent of the platform's locale.
    p.write_text(json.dumps(payload), encoding="utf-8")
    pool = load_problems(p)
    assert len(pool) == 1
    assert pool.problems[0].answer == 42


def test_sample_returns_n_problems():
    """``sample(n=3)`` returns exactly three items."""
    pool = load_problems()
    sampled = pool.sample(n=3, seed=42)
    assert len(sampled) == 3


def test_sample_is_deterministic_for_same_seed():
    """Two draws with the same seed return the same ids in the same order."""
    pool = load_problems()
    a = pool.sample(n=5, seed=42)
    b = pool.sample(n=5, seed=42)
    assert [p.id for p in a] == [p.id for p in b]


def test_sample_differs_across_seeds():
    """Across five different seeds, at least two distinct orderings appear."""
    pool = load_problems()
    samples = [
        tuple(p.id for p in pool.sample(n=5, seed=s)) for s in (1, 2, 3, 4, 5)
    ]
    # Only require two distinct outcomes: individual seeds may legitimately
    # collide, so demanding five unique samples would make the test flaky.
    assert len(set(samples)) >= 2


def test_sample_clamps_when_n_exceeds_pool_size():
    """Asking for more problems than the pool holds returns ``len(pool)``."""
    pool = load_problems()
    sampled = pool.sample(n=10_000, seed=0)
    assert len(sampled) == len(pool)


def test_problem_to_dict_round_trips():
    """``to_dict()`` echoes the constructor fields.

    ``is_eval_problem`` is not passed to the constructor here and is expected
    to come back as ``False``.
    """
    p = Problem(id=1, problem="What is 2+2?", answer=4, difficulty="easy")
    d = p.to_dict()
    assert d["id"] == 1
    assert d["problem"] == "What is 2+2?"
    assert d["answer"] == 4
    assert d["difficulty"] == "easy"
    assert d["is_eval_problem"] is False


def test_problem_pool_default_is_empty():
    """A ``ProblemPool`` built with no arguments has length zero."""
    assert len(ProblemPool()) == 0


# ---- check_answer ---------------------------------------------------------


def test_check_answer_correct():
    """Equal integers are accepted."""
    assert check_answer(7, 7) is True


def test_check_answer_wrong():
    """Unequal integers are rejected."""
    assert check_answer(8, 7) is False


@pytest.mark.parametrize(
    "submitted, ground_truth",
    [
        ("7", 7),
        (7.0, 7),
        (7, "7"),
        (None, 7),
    ],
    ids=["str_submitted", "float_submitted", "str_truth", "none_submitted"],
)
def test_check_answer_rejects_type_mismatch(submitted, ground_truth):
    """A non-int on either side is rejected, even if it looks like the answer."""
    assert check_answer(submitted, ground_truth) is False


def test_check_answer_rejects_bool():
    """``bool`` is an int subclass; ``True == 1`` would pass naive ==.

    Strict integer equality rejects bool explicitly.
    """
    assert check_answer(True, 1) is False
    assert check_answer(False, 0) is False