| import numpy as np | |
| from easy_tpp.utils import is_torch_available | |
def is_tensor(x):
    """
    Report whether `x` is an array-like tensor object.

    `x` qualifies when it is an `np.ndarray`, or, if torch is available,
    a `torch.Tensor`.

    Returns:
        bool: `True` if `x` is an instance of one of the accepted types.
    """
    accepted_types = [np.ndarray]
    if is_torch_available():
        # torch is imported lazily so this module works without it installed.
        import torch

        accepted_types.append(torch.Tensor)
    return isinstance(x, tuple(accepted_types))
def _is_numpy(x):
    """Private check: is `x` an instance of `np.ndarray`?"""
    ndarray_type = np.ndarray
    return isinstance(x, ndarray_type)


def is_numpy_array(x):
    """
    Report whether `x` is a numpy array.

    Returns:
        bool: `True` if `x` is an instance of `np.ndarray`, else `False`.
    """
    verdict = _is_numpy(x)
    return verdict
def _is_torch(x):
    """Private check: is `x` an instance of `torch.Tensor`? Imports torch."""
    import torch

    tensor_type = torch.Tensor
    return isinstance(x, tensor_type)


def is_torch_tensor(x):
    """
    Report whether `x` is a torch tensor.

    Safe to call even if torch is not installed: the torch import is only
    attempted after `is_torch_available()` reports that torch is available.

    Returns:
        bool: `False` when torch is unavailable, otherwise whether `x` is
        an instance of `torch.Tensor`.
    """
    if not is_torch_available():
        return False
    return _is_torch(x)
def _is_torch_device(x):
    """Private check: is `x` an instance of `torch.device`? Imports torch."""
    import torch

    device_type = torch.device
    return isinstance(x, device_type)


def is_torch_device(x):
    """
    Report whether `x` is a torch device.

    Safe to call even if torch is not installed: the torch import is only
    attempted after `is_torch_available()` reports that torch is available.

    Returns:
        bool: `False` when torch is unavailable, otherwise whether `x` is
        an instance of `torch.device`.
    """
    if not is_torch_available():
        return False
    return _is_torch_device(x)
def _is_torch_dtype(x):
    """
    Private check: does `x` denote a `torch.dtype`? Imports torch.

    A string is first resolved as an attribute name on the `torch` module;
    a name that torch does not define yields `False`.
    """
    import torch

    candidate = x
    if isinstance(candidate, str):
        # A missing attribute resolves to None, which is never a torch.dtype,
        # so unknown names fall through to a False result.
        candidate = getattr(torch, candidate, None)
    return isinstance(candidate, torch.dtype)


def is_torch_dtype(x):
    """
    Report whether `x` is a torch dtype (or a string naming one).

    Safe to call even if torch is not installed: the torch import is only
    attempted after `is_torch_available()` reports that torch is available.

    Returns:
        bool: `False` when torch is unavailable, otherwise the result of
        the dtype check on `x`.
    """
    if not is_torch_available():
        return False
    return _is_torch_dtype(x)