Spaces:
Running on Zero
Running on Zero
File size: 10,408 Bytes
0dd6c2f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 | from typing import Any
from typing_extensions import override
from linalg_zero.generator.context import CompositionContext
from linalg_zero.generator.entropy_control import EntropyConstraints
from linalg_zero.generator.models import (
ComponentResult,
CompositeResultBuilder,
DifficultyCategory,
Task,
Topic,
)
from linalg_zero.generator.sympy.base import (
CompositionStrategy,
ProblemComponent,
ProblemContext,
ProblemTemplate,
SympyProblemGenerator,
)
from linalg_zero.generator.sympy.template_engine import TemplateEngine
from linalg_zero.grpo.verify import verify_answers
from linalg_zero.shared.types import LibTypes
from linalg_zero.shared.utils import get_logger
# Module-level logger from the project's get_logger helper, named after this module.
# NOTE(review): not referenced anywhere in the code shown here — confirm it is used, or drop it.
logger = get_logger(__name__)
class SequentialComposition(CompositionStrategy):
    """
    Sequential composition strategy.

    Executes components in order, where each component can use results
    from previous components. Useful for multi-step problems.
    """

    def compose(self, components: list[ProblemComponent], base_context: CompositionContext) -> list[ComponentResult]:
        """Execute ``components`` in order and collect their results.

        Each component samples its own entropy from its ``entropy_constraints``
        (a range or a fixed value per component, rather than one shared budget
        split uniformly). The component's ``entropy_weight()`` (clamped at 0) is
        used only to validate the samples: a zero-weight component must receive
        zero entropy, and a non-zero-weight component must receive non-zero entropy.

        Args:
            components: Components to execute, in order.
            base_context: Shared composition context. Its ``entropy`` is
                overwritten with the sum of all per-component allocations. Each
                executed component's result, stepwise results, golden result and
                step counter are merged back into it.

        Returns:
            Results of the executed components, in order. Components whose
            ``can_execute(base_context)`` is False are skipped.

        Raises:
            ValueError: If the total weight is not positive (including an empty
                ``components`` list), if a component's constraints sample ``None``,
                or if a weight and its allocation are inconsistent.
        """
        weights = [max(0, component.entropy_weight()) for component in components]
        if sum(weights) <= 0:
            raise ValueError("Total modules must be > 0 for composite problems")

        allocations = self._sample_allocations(components, weights)

        # The composite budget is the sum of the per-component allocations.
        # NOTE(review): allocations of components later skipped by can_execute still count here.
        base_context.entropy = float(sum(allocations))

        results: list[ComponentResult] = []
        for local_index, (component, component_entropy) in enumerate(zip(components, allocations, strict=True)):
            if not component.can_execute(base_context):
                continue

            # Fresh context carrying only this component's allocated entropy.
            component_context = CompositionContext(
                component_entropy,
                base_context.difficulty_level,
                base_context._step_counter,
                template_engine=base_context.template_engine,
                local_index=local_index,
            )
            # NOTE[atom]: these variables can be useful to share state between components
            component_context.constraints = base_context.constraints.copy()
            # Pass previous component results to enable sequential data flow
            component_context.component_results = base_context.component_results.copy()

            result = component.generate(component_context)

            # Merge this component's outputs back into the shared context so the
            # next component (and the caller) can see them.
            base_context.record_component_result(result)
            base_context.stepwise_results.extend(component_context.stepwise_results)
            base_context.golden_result.update(component_context.golden_result)
            base_context._step_counter = component_context._step_counter
            results.append(result)

        return results

    @staticmethod
    def _sample_allocations(components: list[ProblemComponent], weights: list[float]) -> list[float]:
        """Sample one entropy value per component and validate it against that component's weight."""
        allocations: list[float] = []
        for component in components:
            constraints: EntropyConstraints = component.entropy_constraints
            entropy = constraints.sample_entropy()
            # Explicit check rather than `assert`, so it is not stripped under `python -O`.
            if entropy is None:
                raise ValueError(f"Entropy constraints sampled no value for component {component.name}")
            allocations.append(float(entropy))

        for alloc, weight, component in zip(allocations, weights, components, strict=True):
            if weight == 0 and alloc != 0:
                raise ValueError(f"Weight is 0 but allocation is {alloc} for component {component.name}")
            if weight != 0 and alloc == 0:
                raise ValueError(f"Weight is {weight} but allocation is 0 for component {component.name}")

        # After validation, zero-weight components already have a zero allocation;
        # this only normalises them to a plain 0.0.
        return [0.0 if weight == 0 else alloc for alloc, weight in zip(allocations, weights, strict=True)]
class CompositeProblem(SympyProblemGenerator):
    """
    Generator for composite mathematical problems.

    Combines multiple ProblemComponent instances using a CompositionStrategy
    to create complex, multi-part mathematical problems.
    """

    def __init__(
        self,
        components: list[ProblemComponent],
        composition_strategy: CompositionStrategy,
        template_engine: TemplateEngine,
        difficulty_level: DifficultyCategory,
        problem_type: Task,
        topic: Topic,
    ):
        """Store the components and the strategy used to compose them.

        The base generator is initialised with ``entropy=0.0``,
        ``local_index=-1`` and empty ``constraints``. The real entropy budget is
        set later by the composition strategy.
        """
        super().__init__(
            entropy=0.0,
            difficulty_level=difficulty_level,
            problem_type=problem_type,
            topic=topic,
            template_engine=template_engine,
            local_index=-1,
            constraints={},
        )
        self.components = components
        self.composition_strategy = composition_strategy
        # No global sample_args: composition strategy allocates per-component entropy

    @override
    def generate_mathematical_content(self, context: ProblemContext) -> ProblemTemplate:
        """Run the composition strategy over all components and build one composite template.

        Raises:
            ValueError: If the strategy executed no components.
        """
        # Convert to CompositionContext with a temporary zero budget; the strategy sets the proper budget
        comp_context = CompositionContext(
            0.0,
            context.difficulty_level,
            context._step_counter,
            self.template_engine,
            local_index=-1,
        )
        comp_context.constraints = context.constraints.copy()

        # Execute all components and store their results
        component_results = self.composition_strategy.compose(self.components, comp_context)
        if not component_results:
            raise ValueError("No components could be executed")

        # This is a helper class to aggregate component results and create a composite template
        builder = CompositeResultBuilder(self.composition_strategy)
        for result in component_results:
            builder.add_component_result(result)
        template = builder.build_template(comp_context, component_results)

        # Transfer state back to original context
        self._transfer_context_state(comp_context, context)
        return template

    def _transfer_context_state(self, comp_context: CompositionContext, original_context: ProblemContext) -> None:
        """Copy entropy, tool-call tracking, step results and the step counter back to the original context.

        Lists/dicts are assigned by reference, not copied.
        """
        original_context.entropy = comp_context.entropy
        original_context.used_entropy = comp_context.used_entropy
        original_context.tool_calls_count = comp_context.tool_calls_count
        original_context.stepwise_results = comp_context.stepwise_results
        original_context.golden_result = comp_context.golden_result
        original_context._step_counter = comp_context._step_counter

    @override
    def get_template_variables(self, template: ProblemTemplate) -> dict[str, Any]:
        """Not used for composite problems."""
        raise NotImplementedError("Not used for composite problems.")

    @override
    def format_question(self, template: ProblemTemplate) -> str:
        """Format composite problem as natural language multi-step question.

        Raises:
            ValueError: If the template does not hold more than one expression,
                or the composition type is not supported.
        """
        composition_type = template.context_info["composition_type"]
        if not (isinstance(template.expression, list) and len(template.expression) > 1):
            raise ValueError("Composite problem should have multiple expressions.")
        if composition_type == SequentialComposition.__name__:
            return self._format_sequential_question(template)
        raise ValueError(f"Unknown composition type: {composition_type}")

    def _format_sequential_question(self, template: ProblemTemplate) -> str:
        """Render each component's question as a numbered ``Step i:`` line.

        Raises:
            ValueError: If the template carries no component results.
        """
        component_results: list[ComponentResult] = template.context_info.get("component_results", [])
        if not component_results:
            raise ValueError("Sequential composition requires component results with generators")

        step_descriptions = []
        for i, result in enumerate(component_results, 1):
            formatted_question = result.generator.format_question(result.template)
            # Lower-case the first character so it reads naturally after "Step i: ".
            # Slicing (instead of indexing [0]) keeps an empty sub-question from raising IndexError.
            formatted_question = formatted_question[:1].lower() + formatted_question[1:]
            step_descriptions.append(f"Step {i}: {formatted_question}")
        return "\n".join(step_descriptions)

    @override
    def format_solution(self, template: ProblemTemplate) -> str:
        """Format composite problem solution using MathFormatter for clean output.

        Raises:
            TypeError: If ``template.sympy_solution`` is not a list.
            ValueError: If it holds fewer than two solutions.
        """
        component_results: list[ComponentResult] = template.context_info.get("component_results", [])
        if not isinstance(template.sympy_solution, list):
            raise TypeError("The sympy solution should be a list because the number of provided components is a list.")
        # Mirrors format_question's "more than one expression" rule; previously an empty list slipped through.
        if len(template.sympy_solution) < 2:
            raise ValueError("Composite problem should have multiple solutions.")
        return self.template_engine.format_composite_answer(template.sympy_solution, component_results)

    @override
    def verify_problem(self, template: ProblemTemplate) -> bool:
        """Check each component's sympy solution against its library result.

        Each sympy solution is converted to a primitive at its generator's
        precision before comparison with ``verify_answers``.

        Returns:
            True when every component verifies.

        Raises:
            TypeError: If the solutions/results are not lists, or an element is not a ``LibTypes`` value.
            ValueError: If verification fails for any component, or the three
                lists differ in length (from ``zip(strict=True)``).
        """
        lib_results = template.lib_result
        sympy_solutions = template.sympy_solution
        # Explicit checks instead of `assert`, so validation survives `python -O`.
        if not isinstance(sympy_solutions, list):
            raise TypeError(f"Expected sympy_solution to be a list, got {type(sympy_solutions).__name__}")
        if not isinstance(lib_results, list):
            raise TypeError(f"Expected lib_result to be a list, got {type(lib_results).__name__}")

        component_results: list[ComponentResult] = template.context_info["component_results"]
        for sympy_solution, lib_result, result in zip(sympy_solutions, lib_results, component_results, strict=True):
            precision = result.generator.precision
            primitive = self.formatter.sympy_to_primitive(sympy_solution, precision=precision)
            if not isinstance(lib_result, LibTypes):
                raise TypeError(f"Unexpected lib result type: {type(lib_result).__name__}")
            if not isinstance(primitive, LibTypes):
                raise TypeError(f"Unexpected converted sympy solution type: {type(primitive).__name__}")
            if not verify_answers(primitive, lib_result):
                raise ValueError(f"Verification failed: sympy={primitive} vs lib={lib_result}")
        return True
|