Spaces:
Sleeping
Sleeping
| """Tests for shutdown_gym.problems — load, sample, check_answer.""" | |
| import pytest | |
| from shutdown_gym.problems import ( | |
| Problem, | |
| ProblemPool, | |
| check_answer, | |
| load_problems, | |
| ) | |
def test_load_problems_loads_default_pool():
    """The default pool holds at least five entries, each one a ``Problem``."""
    default_pool = load_problems()
    assert len(default_pool) >= 5
    for entry in default_pool.problems:
        assert isinstance(entry, Problem)
def test_load_problems_custom_path(tmp_path):
    """``load_problems`` reads a pool from an explicitly supplied JSON file.

    A single-problem payload is written under pytest's ``tmp_path`` and then
    loaded back; the resulting pool must contain exactly that one problem
    with its answer intact.
    """
    # Function-scope import: the module's import block does not provide
    # ``json``, and a plain import is clearer than ``__import__("json")``.
    import json

    payload = [{"id": 1, "problem": "x", "answer": 42, "difficulty": "easy"}]
    pool_file = tmp_path / "tiny.json"
    pool_file.write_text(json.dumps(payload), encoding="utf-8")

    pool = load_problems(pool_file)

    assert len(pool) == 1
    assert pool.problems[0].answer == 42
def test_sample_returns_n_problems():
    """Asking the default pool for three problems yields exactly three."""
    drawn = load_problems().sample(n=3, seed=42)
    assert len(drawn) == 3
def test_sample_is_deterministic_for_same_seed():
    """Two draws with the same seed return the same problem ids in the same order."""
    pool = load_problems()
    first_draw = pool.sample(n=5, seed=42)
    second_draw = pool.sample(n=5, seed=42)
    first_ids = [problem.id for problem in first_draw]
    second_ids = [problem.id for problem in second_draw]
    assert first_ids == second_ids
def test_sample_differs_across_seeds():
    """Seeds 1 through 5 must not all yield the same ordering.

    The check is probabilistic: on a non-trivial pool it is extremely
    unlikely that five different seeds produce one identical id sequence.
    """
    pool = load_problems()
    distinct_orderings = set()
    for seed in (1, 2, 3, 4, 5):
        drawn = pool.sample(n=5, seed=seed)
        distinct_orderings.add(tuple(problem.id for problem in drawn))
    assert len(distinct_orderings) >= 2
def test_sample_clamps_when_n_exceeds_pool_size():
    """Requesting far more problems than exist returns as many as the pool holds."""
    pool = load_problems()
    oversized_draw = pool.sample(n=10_000, seed=0)
    assert len(oversized_draw) == len(pool)
def test_problem_to_dict_round_trips():
    """``Problem.to_dict`` returns the same four fields the problem was built from."""
    expected = {
        "id": 1,
        "problem": "What is 2+2?",
        "answer": 4,
        "difficulty": "easy",
    }
    problem = Problem(id=1, problem="What is 2+2?", answer=4, difficulty="easy")
    serialized = problem.to_dict()
    assert serialized == expected
def test_problem_pool_default_is_empty():
    """A ``ProblemPool`` built with no arguments has length zero."""
    empty_pool = ProblemPool()
    assert len(empty_pool) == 0
| # ---- check_answer --------------------------------------------------------- | |
def test_check_answer_correct():
    """A submitted integer equal to the ground truth is accepted."""
    verdict = check_answer(7, 7)
    assert verdict is True
def test_check_answer_wrong():
    """A submitted integer that differs from the ground truth is rejected."""
    verdict = check_answer(8, 7)
    assert verdict is False
# Without a parametrize decorator pytest treats ``submitted`` and
# ``ground_truth`` as fixture names and errors out at setup, so the
# arguments are supplied explicitly here.
# NOTE(review): the case list is reconstructed from the test name and the
# "strict integer equality" wording used elsewhere in this module -- confirm
# it against check_answer's actual contract.
@pytest.mark.parametrize(
    ("submitted", "ground_truth"),
    [
        ("7", 7),   # numeric string instead of an int
        (7.0, 7),   # float that compares equal under a naive ==
        (None, 7),  # no answer submitted
    ],
)
def test_check_answer_rejects_type_mismatch(submitted, ground_truth):
    """A submitted value that is not an integer is rejected."""
    assert check_answer(submitted, ground_truth) is False
def test_check_answer_rejects_bool():
    """Booleans are rejected even when they equal the integer ground truth.

    ``bool`` subclasses ``int``, so ``True == 1`` and ``False == 0`` hold
    under plain ``==``; the checker is expected to refuse them regardless.
    """
    for flag, int_twin in ((True, 1), (False, 0)):
        assert check_answer(flag, int_twin) is False