Spaces:
Running on Zero
Running on Zero
File size: 4,025 Bytes
0dd6c2f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 | import pytest
from linalg_zero.generator.composition.components import (
DeterminantWrapperComponent,
FrobeniusNormWrapperComponent,
MatrixTraceWrapperComponent,
RankWrapperComponent,
TransposeWrapperComponent,
)
from linalg_zero.generator.composition.composition import (
CompositeProblem,
SequentialComposition,
)
from linalg_zero.generator.entropy_control import EntropyConstraints
from linalg_zero.generator.models import DifficultyCategory, Task, Topic
from linalg_zero.generator.sympy.generators.determinant_generator import (
DeterminantGenerator,
DeterminantGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.frobenius_norm_generator import (
FrobeniusNormGenerator,
FrobeniusNormGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_rank_generator import (
MatrixRankGenerator,
MatrixRankGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_trace_generator import (
MatrixTraceGenerator,
MatrixTraceGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_transpose_generator import (
MatrixTransposeGenerator,
MatrixTransposeGeneratorDependent,
)
from linalg_zero.generator.sympy.template_engine import TemplateEngine
def make_composite(
    components: list, difficulty: DifficultyCategory = DifficultyCategory.TWO_TOOL_CALLS
) -> CompositeProblem:
    """Assemble a sequentially-composed linear-algebra ``CompositeProblem``.

    Args:
        components: The component objects to chain, passed through unchanged.
        difficulty: Difficulty level recorded on the problem; defaults to
            ``DifficultyCategory.TWO_TOOL_CALLS``.

    Returns:
        A ``CompositeProblem`` using a fresh ``SequentialComposition`` strategy,
        task type ``Task.SEQUENTIAL_PROBLEM``, topic ``Topic.LINEAR_ALGEBRA`` and
        a fresh ``TemplateEngine``.
    """
    # New strategy/engine instances on every call, so composites built by this
    # helper never share these collaborators.
    strategy = SequentialComposition()
    engine = TemplateEngine()
    return CompositeProblem(
        components=components,
        composition_strategy=strategy,
        difficulty_level=difficulty,
        problem_type=Task.SEQUENTIAL_PROBLEM,
        topic=Topic.LINEAR_ALGEBRA,
        template_engine=engine,
    )
class TestWrapperComponentGeneratorSelectionComprehensive:
    """Check, for every wrapper component, which generator class it resolves to.

    Each wrapper is built twice: once with ``is_independent=True`` and once with
    ``is_independent=False``. Each test asserts the resolved ``generator_class``,
    the ``is_independent`` flag and the ``name`` it was given.
    """

    # One row per wrapper: (wrapper class, task, independent generator, dependent generator).
    # Both parametrized tests are derived from this single table so the two cases stay in sync.
    _GENERATOR_CASES = [
        (DeterminantWrapperComponent, Task.ONE_DETERMINANT, DeterminantGenerator, DeterminantGeneratorDependent),
        (
            FrobeniusNormWrapperComponent,
            Task.ONE_FROBENIUS_NORM,
            FrobeniusNormGenerator,
            FrobeniusNormGeneratorDependent,
        ),
        (RankWrapperComponent, Task.ONE_RANK, MatrixRankGenerator, MatrixRankGeneratorDependent),
        (MatrixTraceWrapperComponent, Task.ONE_TRACE, MatrixTraceGenerator, MatrixTraceGeneratorDependent),
        (
            TransposeWrapperComponent,
            Task.ONE_TRANSPOSE,
            MatrixTransposeGenerator,
            MatrixTransposeGeneratorDependent,
        ),
    ]

    @pytest.mark.parametrize(
        ("wrapper_class", "task", "independent_generator"),
        [(wrapper, kind, indep) for wrapper, kind, indep, _dep in _GENERATOR_CASES],
    )
    def test_all_wrappers_independent_case(self, wrapper_class, task, independent_generator):
        """An independent wrapper (is_independent=True) must resolve to the independent generator."""
        wrapper = wrapper_class(
            name=task,
            constraints={"is_independent": True},
            entropy_constraints=EntropyConstraints(entropy=0.1),
        )

        assert wrapper.generator_class is independent_generator
        assert wrapper.is_independent is True
        assert wrapper.name == task

    @pytest.mark.parametrize(
        ("wrapper_class", "task", "dependent_generator"),
        [(wrapper, kind, dep) for wrapper, kind, _indep, dep in _GENERATOR_CASES],
    )
    def test_all_wrappers_dependent_case(self, wrapper_class, task, dependent_generator):
        """A dependent wrapper (is_independent=False) must resolve to the dependent generator."""
        # NOTE(review): every wrapper receives the same "input_vector_b" index key, including
        # matrix-only operations; confirm this is the key the dependent generators expect.
        dependent_constraints = {"is_independent": False, "input_indices": {"input_vector_b": 0}}
        wrapper = wrapper_class(
            name=task,
            constraints=dependent_constraints,
            entropy_constraints=EntropyConstraints(entropy=0.1),
        )

        assert wrapper.generator_class is dependent_generator
        assert wrapper.is_independent is False
        assert wrapper.name == task
|