"""Composition-system wrapper components around the sympy problem generators."""

from abc import abstractmethod
from typing import Any

import sympy
from sympy import Float, Integer, Rational

from linalg_zero.generator.composition.composition import (
    ComponentResult,
    CompositionContext,
    ProblemComponent,
)
from linalg_zero.generator.entropy_control import EntropyConstraints
from linalg_zero.generator.generation_constraints import GenerationConstraints
from linalg_zero.generator.models import Task, Topic
from linalg_zero.generator.sympy.base import ProblemContext, ProblemTemplate, SympyProblemGenerator
from linalg_zero.generator.sympy.generators.determinant_generator import (
    DeterminantGenerator,
    DeterminantGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.frobenius_norm_generator import (
    FrobeniusNormGenerator,
    FrobeniusNormGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.linear_system_generator import (
    LinearSystemGenerator,
    LinearSystemGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_cofactor_generator import (
    MatrixCofactorGenerator,
    MatrixCofactorGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_inverse_generator import (
    MatrixInverseGenerator,
    MatrixInverseGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_matrix_generator import (
    MatrixMatrixMultiplicationGenerator,
    MatrixMatrixMultiplicationGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_rank_generator import (
    MatrixRankGenerator,
    MatrixRankGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_trace_generator import (
    MatrixTraceGenerator,
    MatrixTraceGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_transpose_generator import (
    MatrixTransposeGenerator,
    MatrixTransposeGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_vector_generator import (
    MatrixVectorMultiplicationGenerator,
    MatrixVectorMultiplicationGeneratorDependent,
)


class SympyGeneratorWrapperComponent(ProblemComponent):
    """Generic base class for wrapping sympy generators in the composition system.

    A wrapper instantiates ``generator_class`` on every ``generate()`` call.
    When the component is dependent (``is_independent`` is False), values
    produced by earlier components are looked up in the composition context,
    validated, and forwarded to the generator as keyword arguments.
    """

    def __init__(
        self,
        name: Task,
        generator_class: type[SympyProblemGenerator],
        component_type: Task,
        topic: Topic,
        constraints: dict[str, Any],
        entropy_constraints: EntropyConstraints,
        gen_constraints: GenerationConstraints | None = None,
        **kwargs: Any,
    ) -> None:
        """Store the wrapper configuration and initialise the base component.

        Args:
            name: Component name forwarded to ``ProblemComponent``.
            generator_class: Generator type instantiated by ``generate()``.
            component_type: Passed to the generator as ``problem_type``.
            topic: Passed to the generator as ``topic``.
            constraints: Component constraints; must contain a boolean
                ``"is_independent"`` entry.
            entropy_constraints: Forwarded to ``ProblemComponent``.
            gen_constraints: Optional generation constraints handed to the
                generator unchanged.
            **kwargs: Extra keyword arguments forwarded to ``ProblemComponent``.

        Raises:
            TypeError: If ``constraints["is_independent"]`` is missing or is
                not a ``bool``.
        """
        is_independent = constraints.get("is_independent")
        # Explicit raise instead of `assert`: assertions are stripped under -O.
        if not isinstance(is_independent, bool):
            raise TypeError(f"constraints['is_independent'] must be a bool, got {type(is_independent)}")

        super().__init__(name, is_independent=is_independent, entropy_constraints=entropy_constraints, **kwargs)
        self.constraints = constraints
        self.gen_constraints = gen_constraints
        self.generator_class = generator_class
        self.component_type = component_type
        self.topic = topic

    def get_generator_params(self, context: CompositionContext, input_names: list[str]) -> dict[str, Any]:
        """Extract previous component results to use as inputs.

        Args:
            context: Composition context holding ``component_results``.
            input_names: Names of the generator inputs to resolve.

        Returns:
            An empty dict for independent components. Otherwise a dict with,
            for every input name, the resolved value under ``<name>`` and the
            index of the component it came from under ``<name>_index``.

        Raises:
            ValueError: If an input has no entry in
                ``constraints["input_indices"]``, if the referenced template
                lacks the requested attribute or variable, or if the value
                fails validation.
            KeyError: If an input has no entry in ``constraints["sources"]``.
            TypeError: If the resolved value fails the matrix type validation.
        """
        if self.is_independent:
            return {}

        params: dict[str, Any] = {}
        input_indices = self.constraints["input_indices"]
        sources = self.constraints.get("sources", {})

        for input_name in input_names:
            # Validate we have an index for this input name
            if input_name not in input_indices:
                raise ValueError(f"Missing input_index for input '{input_name}'")
            component_index = input_indices[input_name]

            # Same exception type as a plain lookup, but with a readable message.
            if input_name not in sources:
                raise KeyError(f"Missing source for input '{input_name}'")
            source_type = sources[input_name]

            previous_result = context.component_results[component_index]
            previous_template = previous_result.template

            if source_type == "result":
                # Use the computed result of the previous component
                if not hasattr(previous_template, "sympy_solution"):
                    raise ValueError(f"Previous component result has no sympy_solution: {previous_result}")
                previous_sol = previous_template.sympy_solution
            else:
                # Any other source type names a variable of the previous component
                if not hasattr(previous_template, "variables"):
                    raise ValueError(f"Previous component result has no variables: {previous_result}")
                variables = previous_template.variables
                if source_type not in variables:
                    available_vars = list(variables.keys())
                    raise ValueError(
                        f"Variable '{source_type}' not found in component {component_index}. "
                        f"Available variables: {available_vars}"
                    )
                previous_sol = variables[source_type]

            self._validate_dependent_input(previous_sol)

            params[input_name] = previous_sol
            params[f"{input_name}_index"] = component_index

        return params

    def _get_input_validation_spec(self) -> dict[str, bool]:
        """Subclasses may override to declare constraints for dependent input.

        Supported flags:
        - require_matrix: input must be a sympy.Matrix
        - non_empty: rows > 0 and cols > 0
        - column_vector: cols == 1
        - square: rows == cols
        - numeric_only: all elements are numeric (Integer, Float, Rational)
        - invertible: matrix is square and its determinant is not zero
        """
        return {}

    def _validate_dependent_input(self, value: Any) -> None:
        """Validate dependent input according to subclass spec.

        An empty spec disables validation entirely. With a non-empty spec the
        value must always be a ``sympy.Matrix``; ``require_matrix`` only
        selects which error message is used.

        Raises:
            TypeError: If ``value`` is not a ``sympy.Matrix``.
            ValueError: If ``value`` violates one of the enabled flags.
        """
        spec = self._get_input_validation_spec()
        if not spec:
            return

        is_matrix = isinstance(value, sympy.Matrix)
        if spec.get("require_matrix", False) and not is_matrix:
            raise TypeError(f"Expected dependent input to be a sympy Matrix, got {type(value)}")
        if not is_matrix:
            raise TypeError(f"Dependent input must be a sympy Matrix-like with shape, got {type(value)}")

        rows, cols = value.shape
        if spec.get("non_empty", False) and (rows == 0 or cols == 0):
            raise ValueError(f"Dependent input matrix cannot be empty, got shape {value.shape}")
        if spec.get("column_vector", False) and cols != 1:
            raise ValueError(f"Dependent input must be a column vector with shape (n, 1), got shape {value.shape}")
        if spec.get("square", False) and rows != cols:
            raise ValueError(f"Dependent input must be square, got shape {value.shape}")
        if spec.get("numeric_only", False) and not all(
            isinstance(element, (Integer, Float, Rational)) for element in value
        ):
            raise ValueError("Dependent input must contain only numeric elements")
        # The shape test comes first so det() is never called on a non-square matrix.
        if spec.get("invertible", False) and (rows != cols or value.det() == 0):
            raise ValueError(
                f"Dependent input must be invertible (square with non-zero determinant), got shape {value.shape}"
            )

    @abstractmethod
    def get_input_name(self) -> list[str]:
        """Return the names of the generator inputs taken from earlier components."""

    def generate(self, context: CompositionContext) -> ComponentResult:
        """Run the wrapped generator once and record its output on ``context``.

        Args:
            context: Composition context supplying entropy, difficulty, the
                template engine and earlier component results. Its
                ``stepwise_results``, ``golden_result`` and ``_step_counter``
                are updated from the generation run.

        Returns:
            A ``ComponentResult`` holding a copy of the generated template,
            the entropy and tool-call counts read from the problem context,
            and the generator instance.
        """
        # This context is used for communication and state tracking
        problem_context = ProblemContext(
            entropy=context.entropy,
            difficulty_level=context.difficulty_level,
            step_counter=context._step_counter,
        )

        # Get any additional parameters for parameterized generation
        additional_params = self.get_generator_params(context, self.get_input_name())
        additional_params["constraints"] = self.constraints
        additional_params["gen_constraints"] = self.gen_constraints

        # Now, we perform the 3 key steps involved in component generation
        generator: SympyProblemGenerator = self.generator_class(
            difficulty_level=context.difficulty_level,
            problem_type=self.component_type,
            topic=self.topic,
            entropy=problem_context.entropy,
            is_independent=self.is_independent,
            template_engine=context.template_engine,
            local_index=context.local_index,
            **additional_params,
        )
        template: ProblemTemplate = generator.generate_mathematical_content(problem_context)
        generator.verify_problem(template)

        # Transfer the state of the problem context to the new problem template
        formatted_template = ProblemTemplate(
            expression=template.expression,
            variables=template.variables,
            sympy_solution=template.sympy_solution,
            lib_result=template.lib_result,
            context_info={**template.context_info},
            difficulty_markers=template.difficulty_markers,
            difficulty=template.difficulty,
        )

        context.stepwise_results.extend(problem_context.stepwise_results)
        context.golden_result.update(problem_context.golden_result)
        context._step_counter = problem_context._step_counter

        return ComponentResult(
            template=formatted_template,
            entropy_consumed=problem_context.used_entropy,
            tool_calls_used=problem_context.tool_calls_count,
            generator=generator,
        )


class MatrixVectorMultiplicationWrapperComponent(SympyGeneratorWrapperComponent):
    """Wrapper for the MatrixVectorMultiplicationGenerator."""

    def __init__(self, name: Task, **kwargs: Any) -> None:
        """Select the independent or dependent generator from ``kwargs["constraints"]``."""
        constraints = kwargs["constraints"]
        is_independent = constraints["is_independent"]
        generator_cls = (
            MatrixVectorMultiplicationGenerator if is_independent else MatrixVectorMultiplicationGeneratorDependent
        )
        super().__init__(
            name=name,
            generator_class=generator_cls,
            component_type=Task.ONE_MATRIX_VECTOR_MULTIPLICATION,
            topic=Topic.LINEAR_ALGEBRA,
            **kwargs,
        )

    def get_input_name(self) -> list[str]:
        """Return the single dependent input name, ``input_vector_b``."""
        return ["input_vector_b"]

    def entropy_weight(self) -> float:
        """Return 1.0 when independent, otherwise 0.5."""
        if self.is_independent:
            return 1.0
        # This component still needs to generate a matrix, even if a vector is
        # provided, so we provide 0.5 entropy weight.
        return 0.5

    def _get_input_validation_spec(self) -> dict[str, bool]:
        """Require a non-empty column vector."""
        return {"require_matrix": True, "non_empty": True, "column_vector": True}


class MatrixMatrixMultiplicationWrapperComponent(SympyGeneratorWrapperComponent):
    """Wrapper for MatrixMatrixMultiplicationGenerator and its dependent variant."""

    def __init__(self, name: Task, **kwargs: Any) -> None:
        """Select the independent or dependent generator from ``kwargs["constraints"]``."""
        # NOTE(review): assigned before super().__init__(), unlike the sibling
        # wrappers; presumably so entropy_weight() can read it if the base
        # initialiser calls it - confirm before removing.
        self.constraints = kwargs["constraints"]
        is_independent = self.constraints["is_independent"]
        generator_cls = (
            MatrixMatrixMultiplicationGenerator if is_independent else MatrixMatrixMultiplicationGeneratorDependent
        )
        super().__init__(
            name=name,
            generator_class=generator_cls,
            component_type=Task.ONE_MATRIX_MATRIX_MULTIPLICATION,
            topic=Topic.LINEAR_ALGEBRA,
            **kwargs,
        )

    def entropy_weight(self) -> float:
        """Return 1.0 when independent, 0.5 if matrix B is generated here, else 0.0."""
        # If independent, allocate all entropy
        if self.is_independent:
            return 1.0

        input_indices = self.constraints.get("input_indices", {})
        if "input_matrix_B" not in input_indices:
            # matrix_B is not provided, so it will be generated inside the component
            # allocate half of the total entropy amount
            return 0.5
        # All components provided, allocate no entropy
        return 0.0

    def get_input_name(self) -> list[str]:
        """Return the input names: both matrices when B is mapped, otherwise only A."""
        # generate() calls this for independent components too, whose constraints
        # may not carry "input_indices"; fall back to an empty mapping, exactly
        # as entropy_weight() does.
        input_indices = self.constraints.get("input_indices", {})
        if "input_matrix_B" in input_indices:
            return ["input_matrix_A", "input_matrix_B"]
        return ["input_matrix_A"]

    def _get_input_validation_spec(self) -> dict[str, bool]:
        """Require a non-empty matrix."""
        return {"require_matrix": True, "non_empty": True}


class LinearSystemSolverWrapperComponent(SympyGeneratorWrapperComponent):
    """Wrapper for the LinearSystemGenerator."""

    def __init__(self, name: Task, **kwargs: Any) -> None:
        """Select the independent or dependent generator from ``kwargs["constraints"]``."""
        constraints = kwargs["constraints"]
        is_independent = constraints["is_independent"]
        generator_cls = LinearSystemGenerator if is_independent else LinearSystemGeneratorDependent
        super().__init__(
            name=name,
            generator_class=generator_cls,
            component_type=Task.ONE_LINEAR_SYSTEM_SOLVER,
            topic=Topic.LINEAR_ALGEBRA,
            **kwargs,
        )

    def entropy_weight(self) -> float:
        """Return 1.0 when independent, otherwise 0.5."""
        if self.is_independent:
            return 1.0
        # This component still needs to generate a matrix, even if vector b is
        # provided, so we provide 0.5 entropy weight.
        return 0.5

    def get_input_name(self) -> list[str]:
        """Return the single dependent input name, ``input_vector_b``."""
        return ["input_vector_b"]

    def _get_input_validation_spec(self) -> dict[str, bool]:
        """Require a non-empty column vector."""
        return {"require_matrix": True, "non_empty": True, "column_vector": True}


class FrobeniusNormWrapperComponent(SympyGeneratorWrapperComponent):
    """Wrapper for the FrobeniusNormGenerator."""

    def __init__(self, name: Task, **kwargs: Any) -> None:
        """Select the independent or dependent generator from ``kwargs["constraints"]``."""
        constraints = kwargs["constraints"]
        is_independent = constraints["is_independent"]
        generator_cls = FrobeniusNormGenerator if is_independent else FrobeniusNormGeneratorDependent
        super().__init__(
            name=name,
            generator_class=generator_cls,
            component_type=Task.ONE_FROBENIUS_NORM,
            topic=Topic.LINEAR_ALGEBRA,
            **kwargs,
        )

    def get_input_name(self) -> list[str]:
        """Return the single dependent input name, ``input_matrix``."""
        return ["input_matrix"]

    def _get_input_validation_spec(self) -> dict[str, bool]:
        """Require a non-empty matrix."""
        return {"require_matrix": True, "non_empty": True}

    def entropy_weight(self) -> float:
        """Return 1.0 when independent, otherwise 0.0."""
        if self.is_independent:
            return 1.0
        return 0.0


class DeterminantWrapperComponent(SympyGeneratorWrapperComponent):
    """Wrapper for the DeterminantGenerator."""

    def __init__(self, name: Task, **kwargs: Any) -> None:
        """Select the independent or dependent generator from ``kwargs["constraints"]``."""
        constraints = kwargs["constraints"]
        is_independent = constraints["is_independent"]
        generator_cls = DeterminantGenerator if is_independent else DeterminantGeneratorDependent
        super().__init__(
            name=name,
            generator_class=generator_cls,
            component_type=Task.ONE_DETERMINANT,
            topic=Topic.LINEAR_ALGEBRA,
            **kwargs,
        )

    def entropy_weight(self) -> float:
        """Return 1.0 when independent, otherwise 0.0."""
        if self.is_independent:
            return 1.0
        return 0.0

    def get_input_name(self) -> list[str]:
        """Return the single dependent input name, ``input_matrix``."""
        return ["input_matrix"]

    def _get_input_validation_spec(self) -> dict[str, bool]:
        """Require a non-empty square matrix."""
        return {"require_matrix": True, "non_empty": True, "square": True}


class RankWrapperComponent(SympyGeneratorWrapperComponent):
    """Wrapper for the MatrixRankGenerator."""

    def __init__(self, name: Task, **kwargs: Any) -> None:
        """Select the independent or dependent generator from ``kwargs["constraints"]``."""
        constraints = kwargs["constraints"]
        is_independent = constraints["is_independent"]
        generator_cls = MatrixRankGenerator if is_independent else MatrixRankGeneratorDependent
        super().__init__(
            name=name,
            generator_class=generator_cls,
            component_type=Task.ONE_RANK,
            topic=Topic.LINEAR_ALGEBRA,
            **kwargs,
        )

    def entropy_weight(self) -> float:
        """Return 1.0 when independent, otherwise 0.0."""
        if self.is_independent:
            return 1.0
        return 0.0

    def get_input_name(self) -> list[str]:
        """Return the single dependent input name, ``input_matrix``."""
        return ["input_matrix"]

    def _get_input_validation_spec(self) -> dict[str, bool]:
        """Require a non-empty matrix with numeric elements only."""
        return {"require_matrix": True, "non_empty": True, "numeric_only": True}


class TransposeWrapperComponent(SympyGeneratorWrapperComponent):
    """Wrapper for the MatrixTransposeGenerator."""

    def __init__(self, name: Task, **kwargs: Any) -> None:
        """Select the independent or dependent generator from ``kwargs["constraints"]``."""
        constraints = kwargs["constraints"]
        is_independent = constraints["is_independent"]
        generator_cls = MatrixTransposeGenerator if is_independent else MatrixTransposeGeneratorDependent
        super().__init__(
            name=name,
            generator_class=generator_cls,
            component_type=Task.ONE_TRANSPOSE,
            topic=Topic.LINEAR_ALGEBRA,
            **kwargs,
        )

    def entropy_weight(self) -> float:
        """Return 1.0 when independent, otherwise 0.0."""
        if self.is_independent:
            return 1.0
        return 0.0

    def get_input_name(self) -> list[str]:
        """Return the single dependent input name, ``input_matrix``."""
        return ["input_matrix"]

    def _get_input_validation_spec(self) -> dict[str, bool]:
        """Require a non-empty matrix."""
        return {"require_matrix": True, "non_empty": True}


class MatrixTraceWrapperComponent(SympyGeneratorWrapperComponent):
    """Wrapper for the MatrixTraceGenerator."""

    def __init__(self, name: Task, **kwargs: Any) -> None:
        """Select the independent or dependent generator from ``kwargs["constraints"]``."""
        constraints = kwargs["constraints"]
        is_independent = constraints["is_independent"]
        generator_cls = MatrixTraceGenerator if is_independent else MatrixTraceGeneratorDependent
        super().__init__(
            name=name,
            generator_class=generator_cls,
            component_type=Task.ONE_TRACE,
            topic=Topic.LINEAR_ALGEBRA,
            **kwargs,
        )

    def entropy_weight(self) -> float:
        """Return 1.0 when independent, otherwise 0.0."""
        if self.is_independent:
            return 1.0
        return 0.0

    def get_input_name(self) -> list[str]:
        """Return the single dependent input name, ``input_matrix``."""
        return ["input_matrix"]

    def _get_input_validation_spec(self) -> dict[str, bool]:
        """Require a non-empty square matrix."""
        return {"require_matrix": True, "non_empty": True, "square": True}


class MatrixInverseWrapperComponent(SympyGeneratorWrapperComponent):
    """Wrapper for the MatrixInverseGenerator."""

    def __init__(self, name: Task, **kwargs: Any) -> None:
        """Select the independent or dependent generator from ``kwargs["constraints"]``."""
        constraints = kwargs["constraints"]
        is_independent = constraints["is_independent"]
        generator_cls = MatrixInverseGenerator if is_independent else MatrixInverseGeneratorDependent
        super().__init__(
            name=name,
            generator_class=generator_cls,
            component_type=Task.ONE_INVERSE,
            topic=Topic.LINEAR_ALGEBRA,
            **kwargs,
        )

    def entropy_weight(self) -> float:
        """Return 1.0 when independent, otherwise 0.0."""
        if self.is_independent:
            return 1.0
        return 0.0

    def get_input_name(self) -> list[str]:
        """Return the single dependent input name, ``input_matrix``."""
        return ["input_matrix"]

    def _get_input_validation_spec(self) -> dict[str, bool]:
        """Require a non-empty, square, invertible matrix."""
        return {"require_matrix": True, "non_empty": True, "square": True, "invertible": True}


class MatrixCofactorWrapperComponent(SympyGeneratorWrapperComponent):
    """Wrapper for the MatrixCofactorGenerator."""

    def __init__(self, name: Task, **kwargs: Any) -> None:
        """Select the independent or dependent generator from ``kwargs["constraints"]``."""
        constraints = kwargs["constraints"]
        is_independent = constraints["is_independent"]
        generator_cls = MatrixCofactorGenerator if is_independent else MatrixCofactorGeneratorDependent
        super().__init__(
            name=name,
            generator_class=generator_cls,
            component_type=Task.ONE_COFACTOR,
            topic=Topic.LINEAR_ALGEBRA,
            **kwargs,
        )

    def entropy_weight(self) -> float:
        """Return 1.0 when independent, otherwise 0.0."""
        if self.is_independent:
            return 1.0
        return 0.0

    def get_input_name(self) -> list[str]:
        """Return the single dependent input name, ``input_matrix``."""
        return ["input_matrix"]

    def _get_input_validation_spec(self) -> dict[str, bool]:
        """Require a non-empty square matrix."""
        return {"require_matrix": True, "non_empty": True, "square": True}