# Red-Button / red_button / problems.py
# Arun-Sanjay's picture
# Pre-Phase-7 cleanup: migrate models.py to turn-budget fields, add episode_seed helper for deterministic-per-UUID sampling
# a72e3bd
"""Problems pool loader, sampler, and integer answer validator (Section 12).
The pool is a JSON file (``data/problems_pool.json``) shaped per Section 12.2.
Each entry is ``{"id": int, "problem": str, "answer": int, "difficulty": str}``.
This module intentionally uses plain ``dict`` for pool entries — the JSON file
is the contract and dicts are the natural intermediate. The rubric consumer
(:class:`red_button.rubrics.MathCorrectnessRubric`) reads ``dict[int, int]``
from ``state.ground_truth`` (populated via :func:`ground_truth_map`).
"""
from __future__ import annotations
import hashlib
import json
import random
from pathlib import Path
from typing import Optional
def load_problems(path: str = "data/problems_pool.json") -> list[dict]:
    """Read the problems pool JSON file at ``path`` and return its contents.

    A relative ``path`` is resolved against the current working directory.

    Returns:
        The decoded top-level JSON list, unmodified.

    Raises:
        ValueError: if the decoded top-level JSON value is not a list.
    """
    raw_text = Path(path).read_text(encoding="utf-8")
    payload = json.loads(raw_text)
    # Only the top-level container type is checked here; individual entries
    # are returned exactly as decoded.
    if isinstance(payload, list):
        return payload
    raise ValueError(f"Expected a JSON list at {path}, got {type(payload).__name__}")
def sample_problems(
    n: int = 10,
    seed: Optional[int] = None,
    problems: Optional[list[dict]] = None,
) -> list[dict]:
    """Draw ``n`` distinct entries from a problems pool.

    * When ``problems`` is ``None`` the pool is obtained from
      :func:`load_problems` using its default path.
    * A non-``None`` ``seed`` makes the draw reproducible; ``None`` leaves
      the generator seeded by :class:`random.Random`'s default behaviour.

    Raises:
        ValueError: if ``n`` is larger than the pool.
    """
    candidates = load_problems() if problems is None else problems
    available = len(candidates)
    if n > available:
        raise ValueError(
            f"Requested sample of n={n} exceeds pool size {available}"
        )
    # random.Random(None) seeds exactly like random.Random(), so one
    # constructor call covers both the seeded and unseeded cases.
    return random.Random(seed).sample(candidates, n)
def validate_answer(
    problem_id: int,
    submitted_answer: int,
    problems: list[dict],
) -> bool:
    """Check ``submitted_answer`` against the pool entry for ``problem_id``.

    The first entry whose ``"id"`` equals ``problem_id`` decides the result;
    later entries with the same id are ignored. An unknown ``problem_id``
    yields ``False`` rather than raising, so a caller handing over junk ids
    gets a negative result instead of an exception.
    """
    matched = next(
        (item for item in problems if item.get("id") == problem_id),
        None,
    )
    if matched is None:
        return False
    return matched.get("answer") == submitted_answer
def ground_truth_map(problems: list[dict]) -> dict[int, int]:
    """Build an ``{id: answer}`` lookup from a problems list (Section 12.5).

    Both the id and the answer of every entry are passed through ``int``.
    If an id occurs more than once, the last entry's answer is kept.
    """
    lookup: dict[int, int] = {}
    for record in problems:
        # Convert the id before the answer so a malformed entry fails on
        # its id first.
        key = int(record["id"])
        lookup[key] = int(record["answer"])
    return lookup
def episode_seed(episode_id: str) -> int:
    """Map an ``episode_id`` string to a reproducible integer seed.

    The seed is the first four bytes of the SHA-256 digest of the encoded
    id, read as a big-endian unsigned integer, so it lies in
    ``[0, 2**32)``. The same ``episode_id`` always gives the same seed.
    Distinct ids will usually give distinct seeds, but the truncation to
    32 bits means collisions are possible.
    """
    digest = hashlib.sha256(episode_id.encode()).digest()
    # Four leading digest bytes are the same value as the first eight hex
    # characters of the hex digest.
    leading = digest[:4]
    return int.from_bytes(leading, "big")