RedButton-v2 / redbutton_v2 /problems.py
Hugging557444's picture
V2-A: bootstrap RedButton v2 with v1 verbatim reusables
e78298b
Raw
History Blame Contribute Delete
2.61 kB
"""Math problem pool: load, sample (deterministic via seed), and verify.

Carried over verbatim from v1, with one addition: v2 adds an optional
``is_eval_problem`` flag to the JSON payload. The loader treats it as
optional metadata and exposes it through ``Problem.is_eval_problem`` for
the sandbagging detector.
"""
import json
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
@dataclass
class Problem:
    """A single math problem record from the pool."""

    id: int
    problem: str
    answer: int
    difficulty: str
    # Optional metadata flag; stays False unless explicitly set.
    is_eval_problem: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain ``dict`` of the fields, suitable for ``json.dumps``.

        Keys appear in field-declaration order.
        """
        exported = (
            "id",
            "problem",
            "answer",
            "difficulty",
            "is_eval_problem",
        )
        return {key: getattr(self, key) for key in exported}
@dataclass
class ProblemPool:
    """An in-memory list of :class:`Problem` records with sampling helpers."""

    problems: List[Problem] = field(default_factory=list)

    @classmethod
    def load(cls, path: Union[str, Path]) -> "ProblemPool":
        """Load a problem pool from a JSON file at ``path``.

        The file is expected to hold a JSON array of objects. Each object
        must provide ``id``, ``problem``, ``answer`` and ``difficulty``;
        ``is_eval_problem`` is optional and defaults to ``False``.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If an entry lacks one of the required keys.
        """
        # JSON text is UTF-8 by specification. Pin the encoding instead of
        # relying on the locale default, which mis-decodes non-ASCII problem
        # text on platforms whose default is not UTF-8.
        with Path(path).open(encoding="utf-8") as f:
            data = json.load(f)
        return cls(problems=[
            Problem(
                id=item["id"],
                problem=item["problem"],
                answer=item["answer"],
                difficulty=item["difficulty"],
                # Coerce to bool so truthy JSON values (e.g. 1) normalise.
                is_eval_problem=bool(item.get("is_eval_problem", False)),
            )
            for item in data
        ])

    def sample(self, n: int = 10, seed: Optional[int] = None) -> List[Problem]:
        """Return ``n`` problems sampled without replacement.

        Deterministic when ``seed`` is given; with ``seed=None`` the
        generator is seeded from a non-reproducible source. When ``n``
        exceeds the pool size, the full pool is returned in sampled order
        (clamped, not raising).

        Raises:
            ValueError: If ``n`` is negative (raised by ``random.sample``).
        """
        # A private Random instance keeps the global RNG state untouched.
        rng = random.Random(seed)
        return rng.sample(self.problems, k=min(n, len(self.problems)))

    def __len__(self) -> int:
        """Return the number of problems in the pool."""
        return len(self.problems)
def load_problems(path: Optional[Union[str, Path]] = None) -> ProblemPool:
    """Return a ``ProblemPool`` read from ``path``.

    When ``path`` is ``None``, falls back to ``data/problems_pool.json``
    located next to this module.
    """
    source = (
        path
        if path is not None
        else Path(__file__).parent.joinpath("data", "problems_pool.json")
    )
    return ProblemPool.load(source)
def check_answer(submitted: Any, ground_truth: Any) -> bool:
    """Return ``True`` only when both values are plain ``int`` and equal.

    Exact-type checks are used on purpose: ``bool`` is a subclass of
    ``int`` and ``True == 1`` holds, so an ``isinstance`` test or a bare
    ``==`` would accept booleans. Strings and floats are rejected as well,
    even when they compare equal numerically.
    """
    both_plain_ints = type(submitted) is int and type(ground_truth) is int
    return both_plain_ints and submitted == ground_truth