File size: 7,038 Bytes
0dd6c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
from collections.abc import Callable
from typing import Any

from sympy import Matrix
from sympy.matrices.exceptions import NonSquareMatrixError
from transformers.utils.chat_template_utils import get_json_schema

from linalg_zero.generator.difficulty_config import Precision
from linalg_zero.generator.sympy.template_engine import MathFormatter
from linalg_zero.shared.types import assert_lib_returns


def matrix_transpose(matrix: list[list[float | int]]) -> list[list[float | int]]:
    """Return the transpose of a matrix.

    Args:
        matrix: Matrix represented as a list of rows (list[list[float | int]]).

    Returns:
        list[list[float | int]]: Transposed matrix (rows and columns swapped).

    Examples:
        >>> matrix_transpose([[1, 2, 3], [4, 5, 6]])
        [[1, 4], [2, 5], [3, 6]]
        >>> matrix_transpose([[1]])
        [[1]]
    """
    # NOTE: the docstring above is consumed by get_json_schema (via get_tools())
    # to build this tool's schema, so it must stay verbatim.
    try:
        transposed = Matrix(matrix).T
        primitive = MathFormatter.sympy_to_primitive(transposed, precision=Precision.MATRIX_TRANSPOSE)
    except Exception as exc:
        # Any failure while building, transposing or formatting is reported as ValueError.
        raise ValueError(f"Cannot calculate matrix transpose: {exc}") from exc

    # The formatter must hand back a nested list; anything else is a type error.
    is_nested_list = isinstance(primitive, list) and all(isinstance(row, list) for row in primitive)
    if not is_nested_list:
        raise TypeError(f"Expected list of lists, got {type(primitive)}")
    return primitive


def matrix_cofactor(matrix: list[list[float | int]]) -> list[list[float | int]]:
    """Return the cofactor matrix of a square matrix.

    Args:
        matrix: Square matrix as a list of rows (list[list[float | int]], n x n).

    Returns:
        list[list[float | int]]: Cofactor matrix with the same shape as the input.

    Raises:
        ValueError: If the input matrix is not square.

    Examples:
        >>> matrix_cofactor([[1, 2], [3, 4]])
        [[4, -3], [-2, 1]]
        >>> matrix_cofactor([[1]])
        [[1]]
    """
    # NOTE: the docstring above is consumed by get_json_schema (via get_tools())
    # to build this tool's schema, so it must stay verbatim.
    try:
        cofactors = Matrix(matrix).cofactor_matrix()
        primitive = MathFormatter.sympy_to_primitive(cofactors, precision=Precision.MATRIX_COFACTOR)
    except NonSquareMatrixError as exc:
        # Non-square input gets a dedicated message; this clause must precede
        # the generic handler below.
        raise ValueError(f"Matrix must be square for cofactor calculation: {exc}") from exc
    except Exception as exc:
        raise ValueError(f"Cannot calculate cofactor matrix: {exc}") from exc

    # Only a list of row-lists is an acceptable formatted result.
    if not (isinstance(primitive, list) and all(isinstance(row, list) for row in primitive)):
        raise TypeError(f"Expected list of lists, got {type(primitive)}")
    return primitive


def determinant(matrix: list[list[float | int]]) -> float:
    """Return the determinant of a square matrix.

    Args:
        matrix: Square matrix as a list of rows (list[list[float | int]], n x n).

    Returns:
        float: Determinant value.

    Examples:
        >>> determinant([[1, 2], [3, 4]])
        -2.0
        >>> determinant([[2, 0], [0, 3]])
        6.0
    """
    # NOTE: the docstring above is consumed by get_json_schema (via get_tools())
    # to build this tool's schema, so it must stay verbatim.
    try:
        det_value = Matrix(matrix).det()
        primitive = MathFormatter.sympy_to_primitive(det_value, precision=Precision.DETERMINANT)
    except NonSquareMatrixError as exc:
        raise ValueError("Matrix must be square") from exc
    except Exception as exc:
        raise ValueError(f"Cannot calculate determinant: {exc}") from exc

    if not isinstance(primitive, int | float):
        raise TypeError(f"Expected numeric result, got {type(primitive)}")
    # Normalise ints to float so callers always receive a float.
    return float(primitive)


def frobenius_norm(matrix: list[list[float | int]]) -> float:
    """Return the Frobenius norm of a matrix.

    Args:
        matrix: Matrix as a list of rows (list[list[float | int]]).

    Returns:
        float: Frobenius norm value.

    Examples:
        >>> frobenius_norm([[1, 2], [3, 4]])
        5.48
        >>> frobenius_norm([[0, 0], [0, 0]])
        0.0
    """
    # NOTE: the docstring above is consumed by get_json_schema (via get_tools())
    # to build this tool's schema, so it must stay verbatim.
    try:
        # Matrix.norm() with no order argument is used as the Frobenius norm,
        # i.e. sqrt of the sum of squared entries.
        norm_value = Matrix(matrix).norm()
        primitive = MathFormatter.sympy_to_primitive(norm_value, precision=Precision.FROBENIUS_NORM)
    except Exception as exc:
        raise ValueError(f"Cannot calculate Frobenius norm: {exc}") from exc

    if not isinstance(primitive, int | float):
        raise TypeError(f"Expected numeric result, got {type(primitive)}")
    return float(primitive)


def matrix_rank(matrix: list[list[float | int]]) -> int:
    """Return the rank of a matrix.

    Args:
        matrix: Matrix as a list of rows (list[list[float | int]]).

    Returns:
        int: Rank (non-negative integer).

    Examples:
        >>> matrix_rank([[1, 2], [3, 4]])
        2
        >>> matrix_rank([[1, 2], [2, 4]])
        1
    """
    # NOTE: the docstring above is consumed by get_json_schema (via get_tools())
    # to build this tool's schema, so it must stay verbatim.
    # NOTE(review): with float entries, sympy's zero test may be affected by
    # rounding; confirm inputs are exact (ints) if that matters.
    try:
        rank_value = Matrix(matrix).rank()
    except Exception as exc:
        raise ValueError(f"Cannot calculate matrix rank: {exc}") from exc

    # Unlike the other tools, the rank is returned directly without the
    # MathFormatter precision step.
    if not isinstance(rank_value, int):
        raise TypeError(f"Expected integer result, got {type(rank_value)}")
    return rank_value


def matrix_trace(matrix: list[list[float | int]]) -> float:
    """Return the trace of a square matrix.

    Args:
        matrix: Square matrix as a list of rows (list[list[float | int]], n x n).

    Returns:
        float: Trace (sum of diagonal entries).

    Examples:
        >>> matrix_trace([[1, 2], [3, 4]])
        5.0
        >>> matrix_trace([[5]])
        5.0
    """
    # NOTE: the docstring above is consumed by get_json_schema (via get_tools())
    # to build this tool's schema, so it must stay verbatim.
    try:
        diagonal_sum = Matrix(matrix).trace()
        primitive = MathFormatter.sympy_to_primitive(diagonal_sum, precision=Precision.MATRIX_TRACE)
    except NonSquareMatrixError as exc:
        raise ValueError("Trace is only defined for square matrices.") from exc
    except Exception as exc:
        raise ValueError(f"Cannot calculate matrix trace: {exc}") from exc

    if not isinstance(primitive, int | float):
        raise TypeError(f"Expected numeric result, got {type(primitive)}")
    return float(primitive)


def get_lib() -> dict[str, Callable[..., Any]]:
    """Return a fresh mapping from tool name to library function.

    The mapping's insertion order is the order reported by ``get_lib_fn_names``
    and used by ``get_tools``.
    """
    matrix_valued = {
        "matrix_transpose": matrix_transpose,
        "matrix_cofactor": matrix_cofactor,
    }
    scalar_valued = {
        "determinant": determinant,
        "frobenius_norm": frobenius_norm,
        "matrix_rank": matrix_rank,
        "matrix_trace": matrix_trace,
    }
    # Dict union keeps insertion order: matrix-valued tools first, then scalars.
    return matrix_valued | scalar_valued


def get_lib_fn_names() -> list[str]:
    """Return the library's tool names, in ``get_lib()`` order."""
    return [*get_lib()]


def get_tools() -> list[dict[str, Any]]:
    """Return one JSON tool schema per library function, in ``get_lib()`` order.

    Each schema is produced by ``get_json_schema`` from the function's signature
    and docstring.
    """
    return list(map(get_json_schema, get_lib().values()))


def get_lib_types_list() -> list[type]:
    """Check the library's return types against the allowed set and return the result.

    GRPO reward functions rely on math-verify, which is well tested for a limited
    set of result types. This delegates to ``assert_lib_returns`` with the allowed
    set ``{float, int, list}`` and the functions from ``get_lib()``. Other types may
    work too, but keep the library's return types within this set and run this
    check to confirm they agree.
    """
    allowed_return_types = {float, int, list}
    return assert_lib_returns(allowed_return_types, get_lib())