"""Helpers for detecting optional backends (PyTorch, torchvision) at runtime."""

import functools
import importlib.util
import sys
from collections import OrderedDict
from typing import Union, Tuple

from easy_tpp.utils.log_utils import default_logger as logger

# importlib.metadata joined the standard library in Python 3.8; older
# interpreters use the importlib_metadata backport instead.
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata


def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
    """Report whether ``pkg_name`` can be imported, optionally with its version.

    Args:
        pkg_name: Name handed to ``importlib.util.find_spec`` and to the
            distribution-metadata version lookup.
        return_version: When True, return ``(exists, version)`` instead of
            just ``exists``.

    Returns:
        ``exists`` is True when an import spec was found for ``pkg_name``.
        ``version`` is the installed distribution version, or ``"N/A"`` when
        the package is missing or has no distribution metadata.
    """
    package_exists = importlib.util.find_spec(pkg_name) is not None
    package_version = "N/A"
    if package_exists:
        # Grabbing the version tells a real installed distribution apart from
        # a stray "pkg_name" directory on the path. A missing distribution is
        # tolerated on purpose: the package is still reported as available,
        # only its version stays "N/A".
        try:
            package_version = importlib_metadata.version(pkg_name)
        except importlib_metadata.PackageNotFoundError:
            pass
        logger.debug(f"Detected {pkg_name} version {package_version}")
    if return_version:
        return package_exists, package_version
    return package_exists


# Availability is probed once, at import time, and cached in module globals.
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_torch_available, _torch_version = _is_package_available("torch", return_version=True)


def is_torch_available():
    """Return the cached result of the import-time PyTorch probe."""
    return _torch_available


def get_torch_version():
    """Return the detected PyTorch version string (``"N/A"`` if unknown)."""
    return _torch_version


def is_torchvision_available():
    """Return the cached result of the import-time torchvision probe."""
    return _torchvision_available


def is_torch_cuda_available():
    """Return True when PyTorch is installed and reports a usable CUDA device."""
    if not is_torch_available():
        return False
    import torch

    return torch.cuda.is_available()


def is_torch_mps_available():
    """Return True when PyTorch is installed and the MPS backend is usable.

    Constructing ``torch.device('mps')`` only builds a device descriptor and
    never touches the hardware, so it succeeds on every PyTorch build that
    knows the ``mps`` device type. The backend itself is queried instead.
    """
    if not is_torch_available():
        return False
    import torch

    # PyTorch builds without the ``torch.backends.mps`` attribute have no MPS
    # support at all, so a missing attribute means "not available".
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is None:
        return False
    return bool(mps_backend.is_available())


def is_torch_gpu_available():
    """Return True when either a CUDA or an MPS device is usable."""
    return is_torch_cuda_available() or is_torch_mps_available()


def torch_only_method(fn):
    """Decorate ``fn`` so that calling it without PyTorch raises ImportError.

    The check runs on every call, not at decoration time, so decorating a
    function is always safe even when PyTorch is absent.

    Raises:
        ImportError: From the returned wrapper, when PyTorch was not detected.
    """

    @functools.wraps(fn)  # keep fn's __name__ / __doc__ on the wrapper
    def wrapper(*args, **kwargs):
        if not _torch_available:
            raise ImportError(
                "You need to install pytorch to use this method or class."
            )
        return fn(*args, **kwargs)

    return wrapper


# docstyle-ignore
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
TORCHVISION_IMPORT_ERROR = """
{0} requires the Torchvision library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""

# Maps a backend name to (availability check, error-message template).
BACKENDS_MAPPING = OrderedDict(
    [
        ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
        ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
    ]
)


def requires_backends(obj, backends):
    """Raise ImportError if any backend required by ``obj`` is unavailable.

    Args:
        obj: The function, class or instance that needs the backends; its
            name is substituted into the error message.
        backends: A backend name, or a list/tuple of names, each of which
            must be a key of ``BACKENDS_MAPPING``.

    Raises:
        ImportError: With the concatenated messages of every missing backend.
        KeyError: If a requested name is not in ``BACKENDS_MAPPING``.
    """
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    # Instances have no __name__, so fall back to the name of their class.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__

    checks = (BACKENDS_MAPPING[backend] for backend in backends)
    failed = [msg.format(name) for available, msg in checks if not available()]
    if failed:
        raise ImportError("".join(failed))