# File: linalg-zero / linalg_zero/generator/difficulty_config.py
# Repository page metadata: atomwalk12, "initial commit", 0dd6c2f
from __future__ import annotations
import random
from dataclasses import dataclass
from enum import Enum
from linalg_zero.generator.models import DifficultyCategory, Task
from linalg_zero.shared.utils import get_logger
logger = get_logger(__name__)
# Deterministic controls
#
# Why both set_seed and DETERMINISTIC_BASE_SEED exist:
# - set_seed(seed) initializes Python/NumPy/SymPy RNGs. It makes a single
# process deterministic only when the sequence of RNG calls is identical.
# Across phases (e.g., analysis vs generation), RNG call order/count differs
# (entropy sampling timing, retries, filters), so outputs can diverge despite
# using the same seed.
# - When DETERMINISTIC_MODE is True, factories reseed per question using a
# stable function of (DETERMINISTIC_BASE_SEED, problem_type, topic, index).
# This pins each question's randomness to its identity, making results
# invariant to incidental RNG call ordering differences between phases.
# - If a CLI --seed is provided and DETERMINISTIC_MODE is True, we set
# DETERMINISTIC_BASE_SEED = --seed so users can reproduce/scan deterministic
# sequences without code changes. When DETERMINISTIC_MODE is False, the base
# seed is ignored and set_seed controls reproducibility as usual.
#
# Master switch for the per-question reseeding scheme described in the
# "Deterministic controls" note above.
DETERMINISTIC_MODE: bool = True

# Base value mixed into each question's seed while DETERMINISTIC_MODE is on;
# per the note above it is ignored when the mode is off.
DETERMINISTIC_BASE_SEED: int = 146_959_810
class Precision(Enum):
    """Precision used when formatting mathematical expressions.

    Every operation-specific entry currently carries the value ``2``. Under
    standard ``Enum`` semantics, names that repeat an earlier value become
    aliases of the first member defined with it, so all of the ``2`` entries
    resolve to ``MATRIX_VECTOR_MULTIPLICATION`` (same object, same ``.name``)
    and iterating the enum yields only that member and ``FULL``. Reading
    ``.value`` through any of the names still gives ``2``.
    """

    # Canonical member for value 2 (must stay first to keep alias resolution).
    MATRIX_VECTOR_MULTIPLICATION = 2

    # Aliases of the member above (identical value).
    MATRIX_MATRIX_MULTIPLICATION = 2
    LINEAR_SYSTEM_SOLVER = 2
    DETERMINANT = 2
    FROBENIUS_NORM = 2
    MATRIX_RANK = 2
    MATRIX_TRANSPOSE = 2
    MATRIX_INVERSE = 2
    MATRIX_TRACE = 2
    MATRIX_COFACTOR = 2

    # NOTE(review): presumably a sentinel meaning "no rounding" - confirm
    # against the formatting code that consumes this enum.
    FULL = -1
class ToolCallDifficulty(Enum):
    """Difficulty tiers expressed as a tool-call count (1, 2 or 3)."""

    SINGLE_TOOL = 1  # one tool call
    DUAL_TOOL = 2  # two tool calls
    MULTI_TOOL = 3  # three tool calls
@dataclass(frozen=True)
class ProblemConfig:
    """Immutable set of problem-generation parameters for one difficulty tier.

    Attributes:
        target_tool_calls: Tool-call count associated with this configuration.
        matrix_size_range: Two integers handed to ``random.randint`` as its
            inclusive lower and upper bounds when a size is sampled.
        allow_rationals: Boolean flag; not read in this class. Presumably
            controls whether rational entries may be generated - confirm
            against the generators that consume it.
    """

    target_tool_calls: int
    matrix_size_range: tuple[int, int]
    allow_rationals: bool

    def get_random_matrix_size(self) -> int:
        """Sample a matrix size from ``matrix_size_range`` (both ends inclusive).

        Draws from the module-level ``random`` generator, so the result
        follows whatever seed that generator currently has.
        """
        bounds = self.matrix_size_range
        return random.randint(*bounds)
# Possible entropy ranges:
# Moderate variability:
# - 1 tool call: (1.2, 1.8)
# - 2 tool calls: (2.6, 3.6)
# - 3 tool calls: (3.8, 5.2)
# High variability:
# - 1 tool call: (1.0, 2.0)
# - 2 tool calls: (2.0, 4.0)
# - 3 tool calls: (3.0, 6.0)
# Low variability:
# - 1 tool call: (1.4, 1.6)
# - 2 tool calls: (2.8, 3.2)
# - 3 tool calls: (4.2, 4.8)
# Preset configurations, one per difficulty tier; get_problem_config() picks
# among them.
# NOTE(review): all three presets set target_tool_calls=1, including the ones
# returned for the two- and three-tool-call categories - confirm this is
# intentional.
EASY_PROBLEM_CONFIG = ProblemConfig(
    target_tool_calls=1,
    matrix_size_range=(2, 3),
    allow_rationals=False,
)
MEDIUM_PROBLEM_CONFIG = ProblemConfig(
    target_tool_calls=1,
    matrix_size_range=(2, 3),
    allow_rationals=False,
)
HARD_PROBLEM_CONFIG = ProblemConfig(
    target_tool_calls=1,
    matrix_size_range=(2, 2),
    allow_rationals=False,
)
def determine_difficulty(problem_type: Task) -> DifficultyCategory:
    """Map a problem type to its difficulty category via its name prefix.

    Args:
        problem_type: Task whose ``name`` begins with ``ONE_``, ``TWO_`` or
            ``THREE_``.

    Returns:
        ``ONE_TOOL_CALL``, ``TWO_TOOL_CALLS`` or ``THREE_TOOL_CALLS`` for the
        respective prefix.

    Raises:
        ValueError: If the name starts with none of the known prefixes.
    """
    type_name = problem_type.name

    # The prefixes are mutually exclusive, so the order of checks is irrelevant.
    if type_name.startswith("ONE_"):
        return DifficultyCategory.ONE_TOOL_CALL
    if type_name.startswith("TWO_"):
        return DifficultyCategory.TWO_TOOL_CALLS
    if type_name.startswith("THREE_"):
        return DifficultyCategory.THREE_TOOL_CALLS

    raise ValueError(f"Invalid problem type: {problem_type}")
def get_problem_config(difficulty: DifficultyCategory) -> ProblemConfig:
    """Return the preset ``ProblemConfig`` for a difficulty category.

    Args:
        difficulty: One of the ``DifficultyCategory`` tool-call categories.

    Returns:
        ``EASY_PROBLEM_CONFIG`` for one tool call, ``MEDIUM_PROBLEM_CONFIG``
        for two, ``HARD_PROBLEM_CONFIG`` for three.

    Raises:
        ValueError: If ``difficulty`` equals none of the known categories.
    """
    presets = (
        (DifficultyCategory.ONE_TOOL_CALL, EASY_PROBLEM_CONFIG),
        (DifficultyCategory.TWO_TOOL_CALLS, MEDIUM_PROBLEM_CONFIG),
        (DifficultyCategory.THREE_TOOL_CALLS, HARD_PROBLEM_CONFIG),
    )
    # Compared with == (rather than hashed) to keep equality semantics intact.
    for category, preset in presets:
        if difficulty == category:
            return preset
    raise ValueError(f"Invalid difficulty category: {difficulty}")
def validate_tool_calls(expected: int, actual: int, problem_type: Task) -> bool:
    """Check that a problem used exactly the expected number of tool calls.

    Args:
        expected: Tool-call count the problem type is supposed to use.
        actual: Tool-call count that was observed.
        problem_type: Problem type being checked; only used in the error text.

    Returns:
        ``True`` when the counts match (a mismatch never returns).

    Raises:
        ValueError: If ``actual`` differs from ``expected``.
    """
    if actual == expected:
        return True

    mismatch_message = (
        f"Problem type '{problem_type}' expected {expected} tool calls, "
        f"but used {actual}. This violates the difficulty system."
    )
    raise ValueError(mismatch_message)