File size: 4,334 Bytes
0dd6c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import annotations

import random
from dataclasses import dataclass
from enum import Enum

from linalg_zero.generator.models import DifficultyCategory, Task
from linalg_zero.shared.utils import get_logger

logger = get_logger(__name__)

# Deterministic controls
#
# Why both set_seed and DETERMINISTIC_BASE_SEED exist:
# - set_seed(seed) initializes Python/NumPy/SymPy RNGs. It makes a single
#   process deterministic only when the sequence of RNG calls is identical.
#   Across phases (e.g., analysis vs generation), RNG call order/count differs
#   (entropy sampling timing, retries, filters), so outputs can diverge despite
#   using the same seed.
# - When DETERMINISTIC_MODE is True, factories reseed per question using a
#   stable function of (DETERMINISTIC_BASE_SEED, problem_type, topic, index).
#   This pins each question's randomness to its identity, making results
#   invariant to incidental RNG call ordering differences between phases.
# - If a CLI --seed is provided and DETERMINISTIC_MODE is True, we set
#   DETERMINISTIC_BASE_SEED = --seed so users can reproduce/scan deterministic
#   sequences without code changes. When DETERMINISTIC_MODE is False, the base
#   seed is ignored and set_seed controls reproducibility as usual.
#
# Switch for the per-question reseeding described in the notes above.
DETERMINISTIC_MODE: bool = True
# Default base seed fed into the per-question seed; per the notes above it is
# replaced by the CLI --seed value when DETERMINISTIC_MODE is True.
DETERMINISTIC_BASE_SEED: int = 146959810


class Precision(Enum):
    """Precision for formatting mathematical expressions.

    One member per operation, plus ``FULL``. The integer value is presumably
    the number of decimal places used when formatting that operation's
    result, with ``FULL`` (-1) acting as a "do not round" sentinel -- TODO
    confirm against the formatting code, which is not visible here.

    NOTE(review): every operation member currently shares the value 2. In a
    Python ``Enum``, members with equal values are aliases of the first one
    defined, so ``Precision.DETERMINANT is
    Precision.MATRIX_VECTOR_MULTIPLICATION``, its ``.name`` is
    ``"MATRIX_VECTOR_MULTIPLICATION"``, and iterating the enum yields only
    ``MATRIX_VECTOR_MULTIPLICATION`` and ``FULL``. Reading ``.value`` is
    unaffected; do not rely on member identity or ``.name`` to tell
    operations apart.
    """

    # Defined first, so this is the canonical member for the value 2.
    MATRIX_VECTOR_MULTIPLICATION = 2
    # The remaining value-2 members are aliases of the member above.
    MATRIX_MATRIX_MULTIPLICATION = 2
    LINEAR_SYSTEM_SOLVER = 2
    DETERMINANT = 2
    FROBENIUS_NORM = 2
    MATRIX_RANK = 2
    MATRIX_TRANSPOSE = 2
    MATRIX_INVERSE = 2
    MATRIX_TRACE = 2
    MATRIX_COFACTOR = 2
    # Only member with a distinct value.
    FULL = -1


class ToolCallDifficulty(Enum):
    """Tool call based difficulty levels.

    The three members carry the distinct integer values 1, 2 and 3.
    Presumably the value mirrors how many tool calls a problem at that level
    needs -- confirm against the callers, which are not visible here.
    """

    SINGLE_TOOL = 1
    DUAL_TOOL = 2
    MULTI_TOOL = 3


@dataclass(frozen=True)
class ProblemConfig:
    """Immutable bundle of generation settings for one difficulty tier.

    Instances are frozen, so fields cannot be reassigned after construction.

    Attributes:
        target_tool_calls: Tool-call count associated with this tier
            (presumably the number a generated problem should use -- confirm
            against the generators).
        matrix_size_range: ``(low, high)`` bounds handed to
            ``random.randint``; both ends are inclusive.
        allow_rationals: Flag consumed outside this class (presumably whether
            rational matrix entries are permitted -- confirm against the
            generators).
    """

    target_tool_calls: int
    matrix_size_range: tuple[int, int]
    allow_rationals: bool

    def get_random_matrix_size(self) -> int:
        """Draw one matrix size from ``matrix_size_range``, ends included.

        Uses the module-level ``random`` generator, so the outcome follows
        whatever seed that generator currently has.
        """
        size_bounds = self.matrix_size_range
        return random.randint(*size_bounds)


# Possible entropy ranges:
# Moderate variability:
#   - 1 tool call: (1.2, 1.8)
#   - 2 tool calls: (2.6, 3.6)
#   - 3 tool calls: (3.8, 5.2)

# High variability:
#   - 1 tool call: (1.0, 2.0)
#   - 2 tool calls: (2.0, 4.0)
#   - 3 tool calls: (3.0, 6.0)

# Low variability:
#   - 1 tool call: (1.4, 1.6)
#   - 2 tool calls: (2.8, 3.2)
#   - 3 tool calls: (4.2, 4.8)

# Shared ProblemConfig instances, one per difficulty tier; get_problem_config
# returns these for the ONE_/TWO_/THREE_ tool-call categories respectively.
# NOTE(review): all three tiers set target_tool_calls=1, including the ones
# returned for the TWO_TOOL_CALLS and THREE_TOOL_CALLS categories, and the
# hard tier pins the matrix size to 2 -- confirm both are intentional.
EASY_PROBLEM_CONFIG = ProblemConfig(target_tool_calls=1, matrix_size_range=(2, 3), allow_rationals=False)
MEDIUM_PROBLEM_CONFIG = ProblemConfig(target_tool_calls=1, matrix_size_range=(2, 3), allow_rationals=False)
HARD_PROBLEM_CONFIG = ProblemConfig(target_tool_calls=1, matrix_size_range=(2, 2), allow_rationals=False)


def determine_difficulty(problem_type: Task) -> DifficultyCategory:
    """Map a problem type to its difficulty category via its name prefix.

    Args:
        problem_type: Object whose ``name`` attribute starts with ``THREE_``,
            ``TWO_`` or ``ONE_``.

    Returns:
        ``THREE_TOOL_CALLS``, ``TWO_TOOL_CALLS`` or ``ONE_TOOL_CALL`` from
        ``DifficultyCategory``, matching the prefix.

    Raises:
        ValueError: If the name starts with none of the three prefixes.
    """
    type_name = problem_type.name

    if type_name.startswith("THREE_"):
        return DifficultyCategory.THREE_TOOL_CALLS
    if type_name.startswith("TWO_"):
        return DifficultyCategory.TWO_TOOL_CALLS
    if type_name.startswith("ONE_"):
        return DifficultyCategory.ONE_TOOL_CALL

    raise ValueError(f"Invalid problem type: {problem_type}")


def get_problem_config(difficulty: DifficultyCategory) -> ProblemConfig:
    """Return the shared ``ProblemConfig`` for a difficulty category.

    Args:
        difficulty: Category to look up; compared with ``==`` against the
            three ``DifficultyCategory`` members handled below.

    Returns:
        ``EASY_PROBLEM_CONFIG`` for ``ONE_TOOL_CALL``,
        ``MEDIUM_PROBLEM_CONFIG`` for ``TWO_TOOL_CALLS`` and
        ``HARD_PROBLEM_CONFIG`` for ``THREE_TOOL_CALLS``.

    Raises:
        ValueError: If ``difficulty`` equals none of the three categories.
    """
    if difficulty == DifficultyCategory.ONE_TOOL_CALL:
        return EASY_PROBLEM_CONFIG

    if difficulty == DifficultyCategory.TWO_TOOL_CALLS:
        return MEDIUM_PROBLEM_CONFIG

    if difficulty == DifficultyCategory.THREE_TOOL_CALLS:
        return HARD_PROBLEM_CONFIG

    raise ValueError(f"Invalid difficulty category: {difficulty}")


def validate_tool_calls(expected: int, actual: int, problem_type: Task) -> bool:
    """Check that a problem used exactly the expected number of tool calls.

    Args:
        expected: Tool-call count the problem should have used.
        actual: Tool-call count the problem did use.
        problem_type: Problem type, used only in the error message.

    Returns:
        ``True`` when the counts match; the function never returns ``False``.

    Raises:
        ValueError: If ``actual`` differs from ``expected``.
    """
    if not (actual != expected):
        return True

    mismatch_report = (
        f"Problem type '{problem_type}' expected {expected} tool calls, "
        f"but used {actual}. This violates the difficulty system."
    )
    raise ValueError(mismatch_report)