File size: 4,025 Bytes
0dd6c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pytest

from linalg_zero.generator.composition.components import (
    DeterminantWrapperComponent,
    FrobeniusNormWrapperComponent,
    MatrixTraceWrapperComponent,
    RankWrapperComponent,
    TransposeWrapperComponent,
)
from linalg_zero.generator.composition.composition import (
    CompositeProblem,
    SequentialComposition,
)
from linalg_zero.generator.entropy_control import EntropyConstraints
from linalg_zero.generator.models import DifficultyCategory, Task, Topic
from linalg_zero.generator.sympy.generators.determinant_generator import (
    DeterminantGenerator,
    DeterminantGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.frobenius_norm_generator import (
    FrobeniusNormGenerator,
    FrobeniusNormGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_rank_generator import (
    MatrixRankGenerator,
    MatrixRankGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_trace_generator import (
    MatrixTraceGenerator,
    MatrixTraceGeneratorDependent,
)
from linalg_zero.generator.sympy.generators.matrix_transpose_generator import (
    MatrixTransposeGenerator,
    MatrixTransposeGeneratorDependent,
)
from linalg_zero.generator.sympy.template_engine import TemplateEngine


def make_composite(
    components: list, difficulty: DifficultyCategory = DifficultyCategory.TWO_TOOL_CALLS
) -> CompositeProblem:
    """Assemble a sequential linear-algebra ``CompositeProblem`` for tests.

    Args:
        components: The component objects to chain, passed through unchanged.
        difficulty: Difficulty level for the composite; defaults to
            ``DifficultyCategory.TWO_TOOL_CALLS``.

    Returns:
        A ``CompositeProblem`` that uses a fresh ``SequentialComposition``
        strategy and a fresh ``TemplateEngine``, typed as
        ``Task.SEQUENTIAL_PROBLEM`` under ``Topic.LINEAR_ALGEBRA``.
    """
    # Each call gets its own strategy and engine instances, so no state is
    # shared between composites built by separate tests.
    strategy = SequentialComposition()
    engine = TemplateEngine()
    return CompositeProblem(
        components=components,
        composition_strategy=strategy,
        difficulty_level=difficulty,
        problem_type=Task.SEQUENTIAL_PROBLEM,
        topic=Topic.LINEAR_ALGEBRA,
        template_engine=engine,
    )


class TestWrapperComponentGeneratorSelectionComprehensive:
    """Verify generator selection uniformly across every wrapper component.

    Each wrapper is constructed once with ``is_independent=True`` and once with
    ``is_independent=False``. The tests then check which generator class it
    picked, its reported independence flag, and that it keeps the task it was
    named with.
    """

    @pytest.mark.parametrize(
        ("wrapper_class", "task", "independent_generator"),
        [
            (DeterminantWrapperComponent, Task.ONE_DETERMINANT, DeterminantGenerator),
            (FrobeniusNormWrapperComponent, Task.ONE_FROBENIUS_NORM, FrobeniusNormGenerator),
            (RankWrapperComponent, Task.ONE_RANK, MatrixRankGenerator),
            (MatrixTraceWrapperComponent, Task.ONE_TRACE, MatrixTraceGenerator),
            (TransposeWrapperComponent, Task.ONE_TRANSPOSE, MatrixTransposeGenerator),
        ],
    )
    def test_all_wrappers_independent_case(self, wrapper_class, task, independent_generator):
        """With ``is_independent=True`` the wrapper must pick its standalone generator."""
        constraints = {"is_independent": True}
        entropy = EntropyConstraints(entropy=0.1)
        component = wrapper_class(name=task, constraints=constraints, entropy_constraints=entropy)

        # Identity check: the exact generator class, not a subclass or lookalike.
        assert component.generator_class is independent_generator
        assert component.is_independent is True
        assert component.name == task

    @pytest.mark.parametrize(
        ("wrapper_class", "task", "dependent_generator"),
        [
            (DeterminantWrapperComponent, Task.ONE_DETERMINANT, DeterminantGeneratorDependent),
            (FrobeniusNormWrapperComponent, Task.ONE_FROBENIUS_NORM, FrobeniusNormGeneratorDependent),
            (RankWrapperComponent, Task.ONE_RANK, MatrixRankGeneratorDependent),
            (MatrixTraceWrapperComponent, Task.ONE_TRACE, MatrixTraceGeneratorDependent),
            (TransposeWrapperComponent, Task.ONE_TRANSPOSE, MatrixTransposeGeneratorDependent),
        ],
    )
    def test_all_wrappers_dependent_case(self, wrapper_class, task, dependent_generator):
        """With ``is_independent=False`` the wrapper must pick its dependent generator."""
        # Dependent components also receive an ``input_indices`` mapping.
        # NOTE(review): the ``input_vector_b`` key is reused for every wrapper
        # here; confirm it is the key each component actually expects.
        constraints = {"is_independent": False, "input_indices": {"input_vector_b": 0}}
        entropy = EntropyConstraints(entropy=0.1)
        component = wrapper_class(name=task, constraints=constraints, entropy_constraints=entropy)

        assert component.generator_class is dependent_generator
        assert component.is_independent is False
        assert component.name == task