File size: 1,228 Bytes
36c95ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
from packaging import version
def torch_version() -> str:
    """Return ``torch.__version__`` without its local build suffix.

    Everything from the first ``+`` onwards (e.g. ``+cu111`` or ``+cpu``) is
    dropped, so ``"1.8.1+cu111"`` becomes ``"1.8.1"``. A version string that
    contains no ``+`` is returned unchanged.
    """
    release, _sep, _local = torch.__version__.partition('+')
    return release
def torch_version_geq(major: int, minor: int, patch: int = 0) -> bool:
    """Return whether the installed PyTorch is at least ``major.minor.patch``.

    The version string comes from :func:`torch_version`, so any local build
    suffix (``+cu*`` / ``+cpu``) is ignored. Comparison uses
    ``packaging.version`` (PEP 440 ordering): pre-releases such as
    ``1.8.0a0`` sort *before* ``1.8.0``.

    Args:
        major: required major version.
        minor: required minor version.
        patch: required patch version. Defaults to ``0``; because PEP 440
            zero-pads release segments, ``1.8.0`` compares equal to ``1.8``,
            so two-argument calls keep their previous meaning.

    Returns:
        ``True`` if the installed version is greater than or equal to the
        requested one, ``False`` otherwise.
    """
    installed = version.parse(torch_version())
    required = version.parse(f"{major}.{minor}.{patch}")
    return installed >= required
if version.parse(torch_version()) > version.parse("1.7.1"):
    # PyTorch newer than 1.7.1: use `torch.linalg.solve(A, B)` as-is.
    # TODO: remove the type: ignore once Python 3.6 is deprecated; the
    # original authors found `torch.linalg` missing for Python 3.6 with
    # PyTorch 1.7.0 / 1.7.1.
    from torch.linalg import solve  # type: ignore
else:
    from torch import solve as _solve

    def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """Solve ``A @ X = B`` for ``X`` on PyTorch <= 1.7.1.

        Exposes the same ``(A, B)`` argument order as ``torch.linalg.solve``.
        The legacy ``torch.solve`` takes the right-hand side first
        (``torch.solve(B, A)``) and returns a named tuple; only its
        ``solution`` field is returned here.
        """
        # NOTE(review): legacy `torch.solve` may reject a 1-D `B` that
        # `torch.linalg.solve` would accept as a vector -- confirm callers
        # only pass matrix right-hand sides.
        legacy_result = _solve(B, A)
        return legacy_result.solution
# Bind `linalg_qr` to a QR decomposition: `torch.linalg.qr` on PyTorch newer
# than 1.7.1, otherwise the legacy `torch.qr`.
if version.parse(torch_version()) > version.parse("1.7.1"):
    # TODO: remove the type: ignore once Python 3.6 is deprecated.
    # It turns out that PyTorch has no attribute `torch.linalg` for
    # Python 3.6 / PyTorch 1.7.0, 1.7.1
    from torch.linalg import qr as linalg_qr  # type: ignore
else:
    # `linalg_qr` is not used in the visible code; presumably it is
    # re-exported for other modules, hence the `noqa: F401` -- confirm.
    # NOTE(review): legacy `torch.qr(input, some=True)` and
    # `torch.linalg.qr(A, mode="reduced")` take different keyword arguments;
    # callers passing anything beyond the input tensor may behave
    # differently across the two branches.
    from torch import qr as linalg_qr  # type: ignore # noqa: F401
|