Spaces:
Running on Zero
Running on Zero
| import json | |
| from types import TracebackType | |
| from typing import Any | |
| from linalg_zero.generator.models import ComponentResult, DifficultyCategory | |
| from linalg_zero.generator.sympy.template_engine import TemplateEngine | |
| from linalg_zero.shared.types import LibTypes | |
class ProblemContext:
    """
    Context manager for state information around the resolution process.

    The context keeps an entropy budget (``entropy``) together with the amount
    already consumed (``used_entropy``), counts tool calls, and accumulates the
    per-step records (``stepwise_results``) plus the final answer
    (``golden_result``) produced while a problem is being resolved.
    """

    # Slack used when comparing floating point entropy amounts to the budget.
    _ENTROPY_TOLERANCE = 1e-12

    def __init__(self, entropy: float, difficulty_level: DifficultyCategory, step_counter: int):
        """
        Args:
            entropy: Total entropy budget available to this context.
            difficulty_level: Difficulty category stored on the context.
            step_counter: Numeric id assigned to the first recorded step.
        """
        self.entropy = entropy
        self.difficulty_level = difficulty_level
        self.used_entropy = 0.0
        self.tool_calls_count = 0
        self.stepwise_results: list[dict[str, Any]] = []
        self.golden_result: dict[str, str] = {}
        self._step_counter = step_counter
        self.constraints: dict[str, Any] = {}

    def __enter__(self) -> "ProblemContext":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Nothing to release; returning None lets any exception propagate.
        pass

    def record_entropy_usage(self, amount: float) -> None:
        """
        Record entropy usage for tracking problem complexity.
        """
        self.used_entropy += amount

    def allocate_entropy(self, entropy: float | None) -> float:
        """
        Resolve and consume an entropy amount based on the given value or use
        the entire remaining budget if None is provided. Passing explicit
        values allows entropy to be allocated multiple times across a context
        lifetime.

        The chosen amount is recorded against the context budget.

        Args:
            entropy: Amount to allocate, or None to take everything that is left.

        Returns:
            The amount that was allocated and recorded.

        Raises:
            ValueError: If no budget is left, or the request exceeds what remains.
        """
        remaining = self.entropy - self.used_entropy
        if remaining <= self._ENTROPY_TOLERANCE:
            raise ValueError(f"Entropy budget exceeded: remaining {remaining:.3f}")

        # An explicit amount is used as-is; this happens when a generator needs
        # several variable allocations during its lifetime.
        amount = remaining if entropy is None else entropy

        if amount > remaining + self._ENTROPY_TOLERANCE:
            raise ValueError(f"Entropy budget exceeded: request {amount:.3f}, remaining {remaining:.3f}")

        self.record_entropy_usage(amount)
        return amount

    def _prepare_verification_data(self, input_data: dict[str, Any]) -> dict[str, Any]:
        """
        Prepare verification data by extracting dependencies and input fields.

        The caller's ``input_data`` is left untouched; extraction happens on a
        shallow copy.

        Args:
            input_data: Step inputs. Must contain ``generator_type``; may contain
                ``dependent_on`` and any number of ``input_*`` fields.

        Returns:
            A dict holding ``dependent_on`` (when present), each ``input_*``
            field JSON-encoded, ``generator_type``, and ``input`` with the
            JSON-encoded remainder of the data.

        Raises:
            AssertionError: If the number of ``input_*`` fields differs from the
                number of dependencies.
            KeyError: If ``generator_type`` is missing.
        """
        # Work on a copy so recording a step never strips keys from the dict
        # the caller still holds.
        remaining = dict(input_data)
        verification: dict[str, Any] = {}

        # Handle dependencies
        num_dependencies = 0
        dependent_on = remaining.pop("dependent_on", None)
        if dependent_on is not None:
            verification["dependent_on"] = dependent_on
            num_dependencies = len(dependent_on)

        # Extract and JSON-encode all input_* fields
        input_keys = [key for key in remaining if key.startswith("input_")]
        for key in input_keys:
            verification[key] = json.dumps(remaining.pop(key))

        # Raised explicitly (rather than via `assert`) so the check survives -O.
        if len(input_keys) != num_dependencies:
            raise AssertionError("Number of inputs and dependencies must match")

        # Add generator type and remaining input data
        verification["generator_type"] = remaining.pop("generator_type")
        verification["input"] = json.dumps(remaining)
        return verification

    def record_tool_call(
        self,
        function_name: str,
        result: LibTypes,
        input_data: dict[str, Any],
        is_final: bool = False,
    ) -> str:
        """
        Record a tool call with its result. It tracks the dependencies between
        steps which will later be used to verify correctness during GRPO.

        Args:
            function_name: Name of the tool that was called.
            result: Value returned by the tool; stored JSON-encoded.
            input_data: Inputs of the call, see ``_prepare_verification_data``.
            is_final: When True, the result is also stored as the golden result.

        Returns:
            The id of this step, i.e. the step counter value as a string.
        """
        self.tool_calls_count += 1
        step_id = str(self._step_counter)

        if result is not None:
            result_json = json.dumps(result)
            step_data = {
                "tool": function_name,
                "result": result_json,
                "step_id": step_id,
                "verification": self._prepare_verification_data(input_data),
            }

            if is_final:
                self.golden_result = {"final_answer": result_json, "from_step_id": step_id}

            self.stepwise_results.append(step_data)

        # NOTE(review): the counter advances for every call, including calls
        # with a None result, so a returned id is never handed out twice.
        # Confirm this matches the intended step numbering.
        self._step_counter += 1
        return step_id
class CompositionContext(ProblemContext):
    """
    Extends the base ProblemContext to support shared state and global variables
    across composed problem components.

    In addition to the base bookkeeping it stores the results reported by each
    component, the template engine and a local index.
    """

    def __init__(
        self,
        entropy: float,
        difficulty_level: DifficultyCategory,
        step_counter: int,
        template_engine: TemplateEngine,
        local_index: int,
    ):
        """
        Args:
            entropy: Total entropy budget, forwarded to ProblemContext.
            difficulty_level: Difficulty category, forwarded to ProblemContext.
            step_counter: Id of the first recorded step, forwarded to ProblemContext.
            template_engine: Template engine kept on the context.
            local_index: Index kept on the context.
                NOTE(review): presumably the position of this context within
                the composition -- confirm against the callers.
        """
        super().__init__(entropy, difficulty_level, step_counter)
        self.component_results: list[ComponentResult] = []
        self.template_engine = template_engine
        self.local_index = local_index

    def record_component_result(self, result: ComponentResult) -> None:
        """
        Record the result of a component execution.

        The entropy budget is validated before anything is stored, so a
        rejected result leaves the context exactly as it was.

        Args:
            result: Component result; its ``entropy_consumed`` and
                ``tool_calls_used`` are added to the context totals.

        Raises:
            ValueError: If accepting the result would exceed the entropy budget.
        """
        # Check first, commit afterwards: a failed call must not leave a stray
        # result in the list or an over-budget used_entropy behind.
        projected_usage = self.used_entropy + result.entropy_consumed
        if projected_usage > self.entropy + 1e-12:
            raise ValueError(f"Entropy budget exceeded: used {projected_usage:.3f}, available {self.entropy:.3f}")

        self.component_results.append(result)
        # Go through the shared helper so entropy accounting stays in one place.
        self.record_entropy_usage(result.entropy_consumed)
        self.tool_calls_count += result.tool_calls_used