atomwalk12's picture
initial commit
0dd6c2f
from sympy import Matrix, ShapeError
from sympy.matrices.exceptions import NonSquareMatrixError
from sympy.solvers.solvers import NonInvertibleMatrixError
from linalg_zero.generator.difficulty_config import Precision
from linalg_zero.generator.sympy.template_engine import MathFormatter
def multiply_matrices(matrix_a: list[list[float]], matrix_b: list[list[float]]) -> list[list[float]]:
    """Return the matrix product ``matrix_a @ matrix_b`` computed with SymPy.

    Works for any operands whose inner dimensions agree: two matrices, a
    matrix and a column vector, or a row vector and a column vector.

    Examples:
        >>> multiply_matrices([[1, 2], [3, 4]], [[2, 0], [1, 3]])
        [[4, 6], [10, 12]]
        >>> multiply_matrices([[1, 2], [3, 4]], [[5], [6]]) # matrix x vector
        [[17], [39]]
        >>> multiply_matrices([[1, 0], [0, 1]], [[5, 6], [7, 8]]) # Identity x matrix
        [[5, 6], [7, 8]]

    Args:
        matrix_a (list[list[float]]): Left operand, given as a list of rows.
        matrix_b (list[list[float]]): Right operand, given as a list of rows.

    Returns:
        list[list[float]]: The product as a list of rows, converted by
        ``MathFormatter.sympy_to_primitive`` at the configured precision.

    Raises:
        ValueError: If the operand shapes cannot be multiplied (wraps SymPy's ShapeError).
        TypeError: If the formatter does not return a list of lists.
    """
    try:
        product = Matrix(matrix_a) * Matrix(matrix_b)
        converted = MathFormatter.sympy_to_primitive(product, precision=Precision.MATRIX_VECTOR_MULTIPLICATION)
    except ShapeError as exc:
        raise ValueError(f"Matrix dimensions incompatible for multiplication: {exc}") from exc

    # Guard the declared return type: the formatter may, in principle, hand back a scalar.
    is_nested = isinstance(converted, list) and all(isinstance(row, list) for row in converted)
    if not is_nested:
        raise TypeError(f"Expected list of lists, got {type(converted)}")
    return converted
def solve_linear_system(matrix_a: list[list[float | int]], vector_b: list[float | int]) -> list[list[float | int]]:
    """Solve the linear system Ax = b for x using SymPy's LU solver.

    Examples:
        >>> solve_linear_system([[2, 1], [1, 3]], [7, 11])
        [[2.0], [3.0]]
        >>> solve_linear_system([[1, 0], [0, 1]], [5, 3])  # Identity matrix
        [[5.0], [3.0]]

    Args:
        matrix_a: The coefficient matrix A as a list of rows.
        vector_b: The right-hand side b as a flat list, one entry per row of A.

    Returns:
        The solution x as a column vector, i.e. a list of single-element rows,
        converted by ``MathFormatter.sympy_to_primitive`` at the configured precision.

    Raises:
        NonInvertibleMatrixError: Re-raised with context when SymPy reports that
            A is not invertible (no unique solution).
        ShapeError: Re-raised with context when A and b have incompatible dimensions.
        TypeError: If the formatter does not return a list.
    """
    try:
        # A flat ``vector_b`` becomes a column Matrix, which is what LUsolve expects.
        solution = Matrix(matrix_a).LUsolve(Matrix(vector_b))
        result = MathFormatter.sympy_to_primitive(solution, precision=Precision.LINEAR_SYSTEM_SOLVER)
    except NonInvertibleMatrixError as e:
        raise NonInvertibleMatrixError(f"Cannot solve linear system: {e}") from e
    except ShapeError as e:
        raise ShapeError(f"Matrix dimensions incompatible for solving linear system: {e}") from e

    if not isinstance(result, list):
        raise TypeError(f"Expected list, got {type(result)}")
    return result
def matrix_inverse(matrix: list[list[float | int]]) -> list[list[float | int]]:
    """Return the inverse of a square matrix, computed with SymPy.

    For an invertible square matrix A, the inverse A⁻¹ is the unique matrix
    satisfying A * A⁻¹ = A⁻¹ * A = I, where I is the identity matrix.

    Examples:
        >>> matrix_inverse([[1, 2], [3, 4]])
        [[-2.0, 1.0], [1.5, -0.5]]
        >>> matrix_inverse([[2, 0], [0, 3]])
        [[0.5, 0.0], [0.0, 0.33333333]]
        >>> matrix_inverse([[1]]) # 1x1 matrix
        [[1.0]]

    Args:
        matrix: A square, invertible matrix given as a list of rows.

    Returns:
        The inverse as a list of rows, converted by
        ``MathFormatter.sympy_to_primitive`` at the configured precision.

    Raises:
        ValueError: If the matrix is not square, is singular, or any other
            error occurs while inverting or formatting it.
        TypeError: If the formatter does not return a list of lists.
    """
    try:
        inverted = Matrix(matrix).inv()
        converted = MathFormatter.sympy_to_primitive(inverted, precision=Precision.MATRIX_INVERSE)
    except NonSquareMatrixError as err:
        raise ValueError("Matrix inverse is only defined for square matrices.") from err
    except NonInvertibleMatrixError as err:
        raise ValueError("Matrix is not invertible (determinant is zero).") from err
    except Exception as err:
        # Catch-all: every remaining failure is surfaced to callers as ValueError.
        raise ValueError(f"Cannot calculate matrix inverse: {err}") from err

    if not (isinstance(converted, list) and all(isinstance(row, list) for row in converted)):
        raise TypeError(f"Expected list of lists, got {type(converted)}")
    return converted