File size: 2,614 Bytes
e78298b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""Math problem pool: load, sample (deterministic via seed), and verify.

Carried over from v1 verbatim. v2 adds an ``is_eval_problem`` flag in the
JSON payload; the loader treats it as optional metadata and exposes it
through ``Problem.is_eval_problem`` for the sandbagging detector.
"""

import json
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union


@dataclass
class Problem:
    """A single math problem record with its expected integer answer."""

    # Identifier of the problem.
    id: int
    # Text of the problem statement.
    problem: str
    # Expected answer, stored as an integer.
    answer: int
    # Difficulty label, kept as an opaque string.
    difficulty: str
    # Optional metadata flag; stays False unless explicitly provided.
    is_eval_problem: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict suitable for JSON serialisation.

        Keys appear in declaration order: ``id``, ``problem``, ``answer``,
        ``difficulty``, ``is_eval_problem``.
        """
        exported = ("id", "problem", "answer", "difficulty", "is_eval_problem")
        return {name: getattr(self, name) for name in exported}


@dataclass
class ProblemPool:
    """In-memory collection of ``Problem`` records with seedable sampling."""

    # Problems in the order they were supplied (file order when built by ``load``).
    problems: List[Problem] = field(default_factory=list)

    @classmethod
    def load(cls, path: Union[str, Path]) -> "ProblemPool":
        """Load a problem pool from a JSON file at ``path``.

        The file is expected to hold a JSON array of objects. Each object
        must provide ``id``, ``problem``, ``answer`` and ``difficulty``;
        ``is_eval_problem`` is optional and is coerced with ``bool``,
        defaulting to ``False`` when absent.

        Raises:
            OSError: if the file cannot be opened.
            UnicodeDecodeError: if the file is not valid UTF-8.
            json.JSONDecodeError: if the content is not valid JSON.
            KeyError: if an entry lacks one of the required keys.
        """
        # Pin the encoding: without it the platform locale decides how the
        # file is decoded, which can corrupt or reject non-ASCII problem
        # text on systems whose default encoding is not UTF-8.
        with Path(path).open(encoding="utf-8") as f:
            data = json.load(f)
        return cls(problems=[
            Problem(
                id=item["id"],
                problem=item["problem"],
                answer=item["answer"],
                difficulty=item["difficulty"],
                is_eval_problem=bool(item.get("is_eval_problem", False)),
            )
            for item in data
        ])

    def sample(self, n: int = 10, seed: Optional[int] = None) -> List[Problem]:
        """Return ``n`` problems sampled without replacement.

        Deterministic when ``seed`` is given. When ``n`` exceeds the
        pool size, the request is clamped to the pool size instead of
        raising, so every problem is returned (in sampled order).

        Raises:
            ValueError: if ``n`` is negative (raised by ``random.sample``).
        """
        # A dedicated generator keeps sampling independent of the global
        # ``random`` module state.
        rng = random.Random(seed)
        return rng.sample(self.problems, k=min(n, len(self.problems)))

    def __len__(self) -> int:
        """Return the number of problems in the pool."""
        return len(self.problems)


def load_problems(path: Optional[Union[str, Path]] = None) -> ProblemPool:
    """Return a ``ProblemPool`` read from ``path``.

    When ``path`` is ``None``, the pool is read from
    ``data/problems_pool.json`` located next to this module.
    """
    source = (
        path
        if path is not None
        else Path(__file__).parent.joinpath("data", "problems_pool.json")
    )
    return ProblemPool.load(source)


def check_answer(submitted: Any, ground_truth: Any) -> bool:
    """Return ``True`` only when both values are plain ``int`` and equal.

    Exact type identity is required on both sides, so ``str``, ``float``
    and ``bool`` operands are rejected outright — a naive ``==`` would
    otherwise accept ``True == 1``.
    """
    # ``type(x) is int`` (rather than ``isinstance``) is deliberate: it
    # excludes ``bool``, which is a subclass of ``int``.
    both_plain_ints = type(submitted) is int and type(ground_truth) is int
    return both_plain_ints and submitted == ground_truth