| | |
| |
|
| | from __future__ import annotations |
| |
|
| | import functools |
| | import math |
| | from typing import TYPE_CHECKING |
| |
|
| | import torch |
| |
|
| | from . import _dtypes_impl, _util |
| | from ._normalizations import ArrayLike, KeepDims, normalizer |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from collections.abc import Sequence |
| |
|
| |
|
class LinAlgError(ValueError):
    """Error raised by the linear-algebra routines of this module.

    Subclasses ``ValueError`` (and therefore still ``Exception``) to mirror
    ``numpy.linalg.LinAlgError``. Code written for NumPy that catches
    ``ValueError`` therefore also catches linear-algebra failures raised here.
    """
| |
|
| |
|
def _atleast_float_1(a):
    """Return ``a`` with an inexact (floating-point or complex) dtype.

    Tensors that are already floating point or complex are returned as-is.
    Anything else is converted with ``.to`` to the default float dtype
    reported by ``_dtypes_impl.default_dtypes()``.
    """
    dtype = a.dtype
    if dtype.is_floating_point or dtype.is_complex:
        return a
    return a.to(_dtypes_impl.default_dtypes().float_dtype)
| |
|
| |
|
def _atleast_float_2(a, b):
    """Cast ``a`` and ``b`` to one shared inexact dtype and return both.

    The target is the promoted result type of the two operands. If that type
    is neither floating point nor complex, the default float dtype is used
    instead. Each operand goes through ``_util.cast_if_needed``.
    """
    common = _dtypes_impl.result_type_impl(a, b)
    is_inexact = common.is_floating_point or common.is_complex
    target = common if is_inexact else _dtypes_impl.default_dtypes().float_dtype
    return _util.cast_if_needed(a, target), _util.cast_if_needed(b, target)
| |
|
| |
|
def linalg_errors(func):
    """Decorate ``func`` so torch linear-algebra failures surface as ``LinAlgError``.

    A ``torch._C._LinAlgError`` raised by the wrapped call is re-raised as
    this module's ``LinAlgError`` with the same ``args``. The torch exception
    is attached as ``__cause__``. Every other exception propagates unchanged.
    ``functools.wraps`` keeps the wrapped function's name and docstring, and
    sets ``__wrapped__``.
    """

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        try:
            return func(*args, **kwds)
        except torch._C._LinAlgError as e:
            # Chain explicitly so tracebacks report the torch error as the
            # direct cause, not as a second error raised "during handling".
            raise LinAlgError(*e.args) from e

    return wrapped
| |
|
| |
|
| | |
| |
|
| |
|
@normalizer
@linalg_errors
def matrix_power(a: ArrayLike, n):
    """Raise the square matrix ``a`` to the integer power ``n``.

    Non-floating, non-complex input is first promoted to the default float
    dtype. The computation is delegated to ``torch.linalg.matrix_power``.
    """
    promoted = _atleast_float_1(a)
    return torch.linalg.matrix_power(promoted, n)
| |
|
| |
|
@normalizer
@linalg_errors
def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
    """Multiply the sequence of arrays ``inputs`` with ``torch.linalg.multi_dot``.

    ``out`` is not read in this body. Presumably the ``normalizer`` decorator
    handles it; TODO confirm.
    """
    product = torch.linalg.multi_dot(inputs)
    return product
| |
|
| |
|
| | |
| |
|
| |
|
@normalizer
@linalg_errors
def solve(a: ArrayLike, b: ArrayLike):
    """Solve the linear system ``a @ x = b`` for ``x``.

    Both operands are cast to one shared floating/complex dtype before
    calling ``torch.linalg.solve``.
    """
    lhs, rhs = _atleast_float_2(a, b)
    return torch.linalg.solve(lhs, rhs)
| |
|
| |
|
@normalizer
@linalg_errors
def lstsq(a: ArrayLike, b: ArrayLike, rcond=None):
    """Return the least-squares solution of ``a @ x = b`` from ``torch.linalg.lstsq``.

    Both operands are cast to one shared floating/complex dtype. ``rcond`` is
    forwarded unchanged. The LAPACK driver is ``"gels"`` if either tensor is
    on CUDA, and ``"gelsd"`` otherwise.
    """
    a, b = _atleast_float_2(a, b)
    # NOTE(review): "gelsd" presumably matches NumPy's solver, and "gels" is
    # used on CUDA because torch accepts only that driver there. Confirm
    # against the torch.linalg.lstsq documentation.
    if a.is_cuda or b.is_cuda:
        driver = "gels"
    else:
        driver = "gelsd"
    return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
| |
|
| |
|
@normalizer
@linalg_errors
def inv(a: ArrayLike):
    """Return the inverse of ``a`` computed by ``torch.linalg.inv``.

    Non-floating, non-complex input is promoted to the default float dtype
    first.
    """
    return torch.linalg.inv(_atleast_float_1(a))
| |
|
| |
|
@normalizer
@linalg_errors
def pinv(a: ArrayLike, rcond=1e-15, hermitian=False):
    """Return the Moore-Penrose pseudo-inverse of ``a``.

    ``rcond`` is forwarded to ``torch.linalg.pinv`` as its relative tolerance
    ``rtol``. ``hermitian`` is passed through unchanged.
    """
    promoted = _atleast_float_1(a)
    return torch.linalg.pinv(promoted, hermitian=hermitian, rtol=rcond)
| |
|
| |
|
@normalizer
@linalg_errors
def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None):
    """Solve the tensor equation for ``x`` via ``torch.linalg.tensorsolve``.

    The NumPy-style ``axes`` argument is forwarded as torch's ``dims``. Both
    operands are cast to one shared floating/complex dtype first.
    """
    lhs, rhs = _atleast_float_2(a, b)
    return torch.linalg.tensorsolve(lhs, rhs, dims=axes)
| |
|
| |
|
@normalizer
@linalg_errors
def tensorinv(a: ArrayLike, ind=2):
    """Return the tensor inverse of ``a`` from ``torch.linalg.tensorinv``.

    ``ind`` is forwarded unchanged. Non-floating, non-complex input is
    promoted to the default float dtype first.
    """
    promoted = _atleast_float_1(a)
    return torch.linalg.tensorinv(promoted, ind=ind)
| |
|
| |
|
| | |
| |
|
| |
|
@normalizer
@linalg_errors
def det(a: ArrayLike):
    """Return the determinant of ``a`` computed by ``torch.linalg.det``.

    Non-floating, non-complex input is promoted to the default float dtype
    first.
    """
    return torch.linalg.det(_atleast_float_1(a))
| |
|
| |
|
@normalizer
@linalg_errors
def slogdet(a: ArrayLike):
    """Return the sign and log-absolute-determinant of ``a``.

    The result is whatever ``torch.linalg.slogdet`` returns. Non-floating,
    non-complex input is promoted to the default float dtype first.
    """
    return torch.linalg.slogdet(_atleast_float_1(a))
| |
|
| |
|
@normalizer
@linalg_errors
def cond(x: ArrayLike, p=None):
    """Return the condition number of ``x`` in the norm selected by ``p``.

    ``x`` is promoted to an inexact dtype and passed to ``torch.linalg.cond``.
    To match ``numpy.linalg.cond``, a NaN condition number becomes ``inf``.
    The exception is a matrix that itself contains NaN entries: its NaN
    result is kept.

    Raises:
        LinAlgError: if ``x`` has no elements and one of its last two
            dimensions is zero.
    """
    x = _atleast_float_1(x)

    # Reject empty matrices up front, mirroring NumPy's empty-2d check:
    # the array has no elements *and* a matrix dimension is zero.
    if x.numel() == 0 and math.prod(x.shape[-2:]) == 0:
        raise LinAlgError("cond is not defined on empty arrays")

    result = torch.linalg.cond(x, p=p)

    # One flag per matrix in the batch, shaped like ``result``. It records
    # whether that matrix already contained a NaN. ``torch.linalg.cond`` has
    # already raised if ``x`` has fewer than two dimensions, so flattening
    # the last two dims is safe here.
    nan_in_input = torch.isnan(x).flatten(-2).any(-1)
    to_inf = torch.isnan(result) & ~nan_in_input
    return torch.where(to_inf, float("inf"), result)
| |
|
| |
|
@normalizer
@linalg_errors
def matrix_rank(a: ArrayLike, tol=None, hermitian=False):
    """Return the rank of ``a``.

    An input with fewer than two dimensions has rank 1 if any entry is
    non-zero, and 0 otherwise. For matrices:

    - With ``tol`` given, it is an absolute threshold (``atol=tol, rtol=0``).
    - Without ``tol``, a relative threshold of ``max(rows, cols) * eps`` of
      the input dtype is used (``atol=0``).

    ``hermitian`` is forwarded to ``torch.linalg.matrix_rank``.
    """
    a = _atleast_float_1(a)

    if a.ndim < 2:
        return int((a != 0).any())

    if tol is not None:
        atol, rtol = tol, 0
    else:
        atol = 0
        rtol = max(a.shape[-2:]) * torch.finfo(a.dtype).eps
    return torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian)
| |
|
| |
|
@normalizer
@linalg_errors
def norm(x: ArrayLike, ord=None, axis=None, keepdims: KeepDims = False):
    """Return the matrix or vector norm of ``x`` from ``torch.linalg.norm``.

    ``ord`` is forwarded unchanged, and ``axis`` is forwarded as torch's
    ``dim``. ``keepdims`` is not read in this body. Presumably ``normalizer``
    applies it based on the ``KeepDims`` annotation; TODO confirm.
    """
    promoted = _atleast_float_1(x)
    return torch.linalg.norm(promoted, ord=ord, dim=axis)
| |
|
| |
|
| | |
| |
|
| |
|
@normalizer
@linalg_errors
def cholesky(a: ArrayLike):
    """Return the Cholesky factor of ``a`` computed by ``torch.linalg.cholesky``.

    Non-floating, non-complex input is promoted to the default float dtype
    first.
    """
    return torch.linalg.cholesky(_atleast_float_1(a))
| |
|
| |
|
@normalizer
@linalg_errors
def qr(a: ArrayLike, mode="reduced"):
    """Return the QR decomposition of ``a`` from ``torch.linalg.qr``.

    For ``mode == "r"`` only the ``R`` factor is returned, as NumPy does in
    that mode. Every other mode returns torch's result unchanged.
    """
    a = _atleast_float_1(a)
    decomposition = torch.linalg.qr(a, mode=mode)
    return decomposition.R if mode == "r" else decomposition
| |
|
| |
|
@normalizer
@linalg_errors
def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False):
    """Return the singular value decomposition of ``a``.

    With ``compute_uv`` true, returns ``torch.linalg.svd(a, full_matrices=...)``.
    Otherwise, returns only the singular values from ``torch.linalg.svdvals``.
    ``hermitian`` is accepted for signature compatibility but is not used.
    """
    a = _atleast_float_1(a)
    if compute_uv:
        return torch.linalg.svd(a, full_matrices=full_matrices)
    return torch.linalg.svdvals(a)
| |
|
| |
|
| | |
| |
|
| |
|
@normalizer
@linalg_errors
def eig(a: ArrayLike):
    """Return eigenvalues and eigenvectors of ``a`` from ``torch.linalg.eig``.

    For real input, if torch returned complex eigenvalues whose imaginary
    parts are all exactly zero, both outputs are reduced to their real parts.
    """
    a = _atleast_float_1(a)
    eigenvalues, eigenvectors = torch.linalg.eig(a)

    # Short-circuit order matters: ``.imag`` is only touched once the
    # eigenvalues are known to be complex.
    purely_real = (
        not a.is_complex()
        and eigenvalues.is_complex()
        and (eigenvalues.imag == 0).all()
    )
    if purely_real:
        eigenvalues, eigenvectors = eigenvalues.real, eigenvectors.real
    return eigenvalues, eigenvectors
| |
|
| |
|
@normalizer
@linalg_errors
def eigh(a: ArrayLike, UPLO="L"):
    """Return the eigendecomposition of ``a`` from ``torch.linalg.eigh``.

    ``UPLO`` is forwarded unchanged. Non-floating, non-complex input is
    promoted to the default float dtype first.
    """
    promoted = _atleast_float_1(a)
    return torch.linalg.eigh(promoted, UPLO=UPLO)
| |
|
| |
|
@normalizer
@linalg_errors
def eigvals(a: ArrayLike):
    """Return the eigenvalues of ``a`` from ``torch.linalg.eigvals``.

    For real input, if the returned eigenvalues are complex with all
    imaginary parts exactly zero, only their real parts are returned.
    """
    a = _atleast_float_1(a)
    values = torch.linalg.eigvals(a)
    # Same short-circuit guard as in eig: only inspect ``.imag`` on complex output.
    if a.is_complex() or not values.is_complex():
        return values
    if (values.imag == 0).all():
        return values.real
    return values
| |
|
| |
|
@normalizer
@linalg_errors
def eigvalsh(a: ArrayLike, UPLO="L"):
    """Return the eigenvalues of ``a`` from ``torch.linalg.eigvalsh``.

    ``UPLO`` is forwarded unchanged. Non-floating, non-complex input is
    promoted to the default float dtype first.
    """
    return torch.linalg.eigvalsh(_atleast_float_1(a), UPLO=UPLO)
| |
|