File size: 4,293 Bytes
0dd6c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from sympy import Matrix, ShapeError
from sympy.matrices.exceptions import NonSquareMatrixError
from sympy.solvers.solvers import NonInvertibleMatrixError

from linalg_zero.generator.difficulty_config import Precision
from linalg_zero.generator.sympy.template_engine import MathFormatter


def multiply_matrices(matrix_a: list[list[float]], matrix_b: list[list[float]]) -> list[list[float]]:
    """Multiplies two matrices, or a matrix and a vector, or two vectors.

    Examples:
        >>> multiply_matrices([[1, 2], [3, 4]], [[2, 0], [1, 3]])
        [[4, 6], [10, 12]]
        >>> multiply_matrices([[1, 2], [3, 4]], [[5], [6]])  # matrix x vector
        [[17], [39]]
        >>> multiply_matrices([[1, 0], [0, 1]], [[5, 6], [7, 8]])  # Identity x matrix
        [[5, 6], [7, 8]]

    Args:
        matrix_a (list[list[float]]): The first matrix or vector.
        matrix_b (list[list[float]]): The second matrix or vector.

    Returns:
        list[list[float]]: The product of the two matrices, matrix and vector, or two vectors.
    """
    try:
        sym_a = Matrix(matrix_a)
        sym_b = Matrix(matrix_b)
        result_matrix: Matrix = sym_a * sym_b
        result = MathFormatter.sympy_to_primitive(result_matrix, precision=Precision.MATRIX_VECTOR_MULTIPLICATION)

        if isinstance(result, list) and all(isinstance(row, list) for row in result):
            return result
        else:
            raise TypeError(f"Expected list of lists, got {type(result)}")
    except ShapeError as e:
        raise ValueError(f"Matrix dimensions incompatible for multiplication: {e}") from e


def solve_linear_system(matrix_a: list[list[float | int]], vector_b: list[float | int]) -> list[list[float | int]]:
    """Solve the linear system Ax = b for x using SymPy.

    Examples:
        >>> solve_linear_system([[2, 1], [1, 3]], [7, 8])
        [[2.0], [3.0]]
        >>> solve_linear_system([[1, 0], [0, 1]], [5, 3])  # Identity matrix
        [[5.0], [3.0]]

    Args:
        matrix_a: The coefficient matrix as a list of lists.
        vector_b: The right-hand side vector as a list.

    Returns:
        The solution vector x as a list.
    """
    try:
        sym_a = Matrix(matrix_a)
        sym_b = Matrix(vector_b)

        solution_matrix = sym_a.LUsolve(sym_b)

        result = MathFormatter.sympy_to_primitive(solution_matrix, precision=Precision.LINEAR_SYSTEM_SOLVER)

        if isinstance(result, list):
            return result

    except NonInvertibleMatrixError as e:
        raise NonInvertibleMatrixError(f"Cannot solve linear system: {e}") from e
    except ShapeError as e:
        raise ShapeError(f"Matrix dimensions incompatible for solving linear system: {e}") from e

    raise TypeError(f"Expected list, got {type(result)}")


def matrix_inverse(matrix: list[list[float | int]]) -> list[list[float | int]]:
    """Calculate the inverse of a square matrix using SymPy.

    The inverse of a matrix A is the unique matrix A⁻¹ such that A * A⁻¹ = A⁻¹ * A = I,
    where I is the identity matrix. Only defined for square, invertible matrices.

    Examples:
        >>> matrix_inverse([[1, 2], [3, 4]])
        [[-2.0, 1.0], [1.5, -0.5]]
        >>> matrix_inverse([[2, 0], [0, 3]])
        [[0.5, 0.0], [0.0, 0.33333333]]
        >>> matrix_inverse([[1]])  # 1x1 matrix
        [[1.0]]

    Args:
        matrix: The square invertible matrix as a list of lists.

    Returns:
        The inverse of the matrix as a list of lists.

    Raises:
        ValueError: If the matrix is not square or not invertible.
    """
    try:
        sym_matrix = Matrix(matrix)
        inverse_result = sym_matrix.inv()
        result = MathFormatter.sympy_to_primitive(inverse_result, precision=Precision.MATRIX_INVERSE)

        if isinstance(result, list) and all(isinstance(row, list) for row in result):
            return result

    except NonSquareMatrixError as e:
        raise ValueError("Matrix inverse is only defined for square matrices.") from e
    except NonInvertibleMatrixError as e:
        raise ValueError("Matrix is not invertible (determinant is zero).") from e
    except Exception as e:
        raise ValueError(f"Cannot calculate matrix inverse: {e}") from e

    raise TypeError(f"Expected list of lists, got {type(result)}")