| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import importlib |
| | import importlib.metadata |
| | import os |
| | import sys |
| | import warnings |
| | from functools import lru_cache, wraps |
| |
|
| | import torch |
| | from packaging import version |
| | from packaging.version import parse |
| |
|
| | from .environment import parse_flag_from_env, patch_environment, str_to_bool |
| | from .versions import compare_versions, is_torch_version |
| |
|
| |
|
| | |
# Setting USE_TORCH_XLA to a false value skips the torch_xla probe below even when torch_xla is installed.
USE_TORCH_XLA = parse_flag_from_env("USE_TORCH_XLA", default=True)

# Becomes True only when both torch_xla imports below succeed.
_torch_xla_available = False
if USE_TORCH_XLA:
    try:
        import torch_xla.core.xla_model as xm  # noqa: F401
        import torch_xla.runtime  # also binds the module-level `torch_xla` name
    except ImportError:
        pass
    else:
        _torch_xla_available = True

# NOTE(review): looks like a legacy alias of `_torch_xla_available`; confirm nothing still reads it before removing.
_tpu_available = _torch_xla_available

# Evaluated once at import time and returned as-is by is_torch_distributed_available().
_torch_distributed_available = torch.distributed.is_available()
| |
|
| |
|
def _is_package_available(pkg_name, metadata_name=None):
    """
    Return whether `pkg_name` is importable *and* backed by installed distribution metadata.

    Requiring distribution metadata (not just a module spec) means a bare same-named directory on `sys.path` is
    presumably not mistaken for the real package.

    Args:
        pkg_name: importable module name, looked up with `importlib.util.find_spec`.
        metadata_name: distribution name to look up with `importlib.metadata` when it differs from the module name
            (e.g. `"ms-amp"` for the `msamp` module). Defaults to `pkg_name`.

    Returns:
        `True` when both the module spec and the distribution metadata are found, `False` otherwise (always a bool;
        previously a missing module spec fell through and returned `None`).
    """
    if importlib.util.find_spec(pkg_name) is None:
        return False
    try:
        importlib.metadata.metadata(pkg_name if metadata_name is None else metadata_name)
    except importlib.metadata.PackageNotFoundError:
        return False
    return True
| |
|
| |
|
def is_torch_distributed_available() -> bool:
    """Return the `torch.distributed.is_available()` result cached at module import time."""
    cached_result = _torch_distributed_available
    return cached_result
| |
|
| |
|
def is_xccl_available():
    """
    Return whether PyTorch's XCCL distributed backend is available.

    The c10d probe is only consulted on torch >= 2.7.0; older torch versions report `False`. Builds without
    `torch.distributed` also report `False` instead of failing on the missing `distributed_c10d` submodule.
    """
    if not is_torch_version(">=", "2.7.0"):
        # The former `is_ipex_available()` branch here returned the same `False` and could only emit an unrelated
        # IPEX/torch version-mismatch warning, so it was dropped.
        return False
    # `torch.distributed.distributed_c10d` is only populated when torch.distributed is available on this build.
    if not _torch_distributed_available:
        return False
    return torch.distributed.distributed_c10d.is_xccl_available()
| |
|
| |
|
def is_ccl_available():
    """
    Return whether the Intel(R) oneCCL Bindings for PyTorch* module (`oneccl_bindings_for_pytorch`) can be found.

    Only the module spec is looked up; the bindings are not imported. The previous `try: pass / except ImportError`
    wrapper could never raise, so its installation hint was unreachable and has been removed.
    """
    return importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
| |
|
| |
|
def get_ccl_version():
    """
    Return the installed version string of the `oneccl_bind_pt` distribution.

    Raises:
        importlib.metadata.PackageNotFoundError: if that distribution is not installed.
    """
    distribution_name = "oneccl_bind_pt"
    return importlib.metadata.version(distribution_name)
| |
|
| |
|
def is_import_timer_available():
    """Return whether the `import_timer` package is installed."""
    return _is_package_available(pkg_name="import_timer")
| |
|
| |
|
def is_pynvml_available():
    """
    Return whether the `pynvml` module is installed.

    The module is accepted whether its metadata belongs to a `pynvml` distribution or to `nvidia-ml-py`.
    """
    if _is_package_available("pynvml"):
        return True
    return _is_package_available("pynvml", metadata_name="nvidia-ml-py")
| |
|
| |
|
def is_pytest_available():
    """Return whether the `pytest` package is installed."""
    return _is_package_available(pkg_name="pytest")
| |
|
| |
|
def is_msamp_available():
    """Return whether MS-AMP is installed (module `msamp`, distribution `ms-amp`)."""
    return _is_package_available("msamp", metadata_name="ms-amp")
| |
|
| |
|
def is_schedulefree_available():
    """Return whether the `schedulefree` package is installed."""
    return _is_package_available(pkg_name="schedulefree")
| |
|
| |
|
def is_transformer_engine_available():
    """
    Return whether a Transformer Engine package is installed.

    When an HPU is available the Intel build (`intel_transformer_engine`) is checked; otherwise `transformer_engine`.
    """
    module_name, distribution_name = (
        ("intel_transformer_engine", "intel-transformer-engine")
        if is_hpu_available()
        else ("transformer_engine", "transformer-engine")
    )
    return _is_package_available(module_name, distribution_name)
| |
|
| |
|
def is_transformer_engine_mxfp8_available():
    """
    Return whether Transformer Engine reports MXFP8 support.

    Returns `False` when `transformer_engine` is not installed. Otherwise it returns the first element of
    `transformer_engine.pytorch.fp8.check_mxfp8_support()`.
    """
    if not _is_package_available("transformer_engine", "transformer-engine"):
        return False

    import transformer_engine.pytorch as te

    mxfp8_support = te.fp8.check_mxfp8_support()
    return mxfp8_support[0]
| |
|
| |
|
def is_lomo_available():
    """Return whether the `lomo_optim` package is installed."""
    return _is_package_available(pkg_name="lomo_optim")
| |
|
| |
|
def is_cuda_available():
    """
    Return `torch.cuda.is_available()` evaluated with `PYTORCH_NVML_BASED_CUDA_CHECK=1` set.

    The intent is an NVML-based probe that does not initialize CUDA in this process. The variable is set through the
    `patch_environment` context manager, presumably only for the duration of the call.
    """
    with patch_environment(PYTORCH_NVML_BASED_CUDA_CHECK="1"):
        return torch.cuda.is_available()
| |
|
| |
|
@lru_cache
def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
    """
    Return whether `torch_xla` was imported successfully, optionally requiring a specific XLA device type.

    To run a native PyTorch job in an environment where `torch_xla` is installed, set `USE_TORCH_XLA` to a false
    value.

    Args:
        check_is_tpu: also require `torch_xla.runtime.device_type()` to be `"TPU"`.
        check_is_gpu: also require the device type to be `"GPU"` or `"CUDA"`.

    The two flags are mutually exclusive. Results are memoized per argument combination by `lru_cache`.
    """
    assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."

    if not _torch_xla_available:
        return False
    if check_is_gpu or check_is_tpu:
        device_kind = torch_xla.runtime.device_type()
        return device_kind in ["GPU", "CUDA"] if check_is_gpu else device_kind == "TPU"
    return True
| |
|
| |
|
def is_torchao_available():
    """Return whether `torchao` is installed at version 0.6.1 or newer."""
    if not _is_package_available("torchao"):
        return False
    installed_version = version.parse(importlib.metadata.version("torchao"))
    return compare_versions(installed_version, ">=", "0.6.1")
| |
|
| |
|
def is_deepspeed_available():
    """Return whether the `deepspeed` package is installed."""
    return _is_package_available(pkg_name="deepspeed")
| |
|
| |
|
def is_pippy_available():
    """Return whether the installed torch is at least 2.4.0, which is the only condition this check applies."""
    minimum_torch = "2.4.0"
    return is_torch_version(">=", minimum_torch)
| |
|
| |
|
def is_bf16_available(ignore_tpu=False):
    """
    Return whether bf16 is supported by the first accelerator backend detected.

    Backends are probed in order: XLA TPU, CUDA, MLU, XPU, MPS. On a TPU the answer is `not ignore_tpu`. For the other
    backends, the backend's own bf16 query decides; for MPS that is `torch.backends.mps.is_macos_or_newer(14, 0)`.
    If no accelerator is detected, bf16 is reported as supported.
    """
    if is_torch_xla_available(check_is_tpu=True):
        return not ignore_tpu

    # Lambdas defer attribute access (e.g. `torch.mlu`) until the backend is known to be present.
    backend_probes = (
        (is_cuda_available, lambda: torch.cuda.is_bf16_supported()),
        (is_mlu_available, lambda: torch.mlu.is_bf16_supported()),
        (is_xpu_available, lambda: torch.xpu.is_bf16_supported()),
        (is_mps_available, lambda: torch.backends.mps.is_macos_or_newer(14, 0)),
    )
    for backend_present, supports_bf16 in backend_probes:
        if backend_present():
            return supports_bf16()
    return True
| |
|
| |
|
def is_fp16_available():
    """Return whether fp16 is supported; only Habana Gaudi1 devices report it as unsupported."""
    return not is_habana_gaudi1()
| |
|
| |
|
def is_fp8_available():
    """Return whether any fp8 backend is usable: MS-AMP, Transformer Engine, or torchao (checked in that order)."""
    if is_msamp_available() or is_transformer_engine_available():
        return True
    return is_torchao_available()
| |
|
| |
|
def is_4bit_bnb_available():
    """Return whether `bitsandbytes` >= 0.39.0 is installed (the version gate used here for 4-bit support)."""
    if not _is_package_available("bitsandbytes"):
        return False
    installed_version = version.parse(importlib.metadata.version("bitsandbytes"))
    return compare_versions(installed_version, ">=", "0.39.0")
| |
|
| |
|
def is_8bit_bnb_available():
    """Return whether `bitsandbytes` >= 0.37.2 is installed (the version gate used here for 8-bit support)."""
    if not _is_package_available("bitsandbytes"):
        return False
    installed_version = version.parse(importlib.metadata.version("bitsandbytes"))
    return compare_versions(installed_version, ">=", "0.37.2")
| |
|
| |
|
def is_bnb_available(min_version=None):
    """
    Return whether `bitsandbytes` is installed, optionally at version `min_version` or newer.

    Args:
        min_version: inclusive lower bound on the installed version; `None` skips the version check.
    """
    installed = _is_package_available("bitsandbytes")
    if min_version is None or not installed:
        return installed
    installed_version = version.parse(importlib.metadata.version("bitsandbytes"))
    return compare_versions(installed_version, ">=", min_version)
| |
|
| |
|
def is_bitsandbytes_multi_backend_available():
    """Return whether the installed `bitsandbytes` lists `"multi_backend"` in its `features` attribute."""
    if is_bnb_available():
        import bitsandbytes as bnb

        return "multi_backend" in getattr(bnb, "features", set())
    return False
| |
|
| |
|
def is_torchvision_available():
    """Return whether the `torchvision` package is installed."""
    return _is_package_available(pkg_name="torchvision")
| |
|
| |
|
def is_megatron_lm_available():
    """
    Return whether the Megatron-LM integration is enabled and usable.

    All of the following must hold:
    the `ACCELERATE_USE_MEGATRON_LM` environment flag is true; a `megatron` module can be found; the `megatron-core`
    distribution is at version 0.8.0 or newer; and a `megatron.training` submodule can be found.

    Any exception raised while resolving the version or the submodule is reported via `warnings.warn` and results in
    `False` (deliberate best-effort behavior).

    Returns:
        A bool. Previously the success path leaked the raw `ModuleSpec`/`None` returned by `find_spec`.
    """
    if str_to_bool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) != 1:
        return False
    if importlib.util.find_spec("megatron") is None:
        return False
    try:
        megatron_version = parse(importlib.metadata.version("megatron-core"))
        if compare_versions(megatron_version, ">=", "0.8.0"):
            return importlib.util.find_spec(".training", "megatron") is not None
    except Exception as e:
        warnings.warn(f"Parse Megatron version failed. Exception:{e}")
    return False
| |
|
| |
|
def is_transformers_available():
    """Return whether the `transformers` package is installed."""
    return _is_package_available(pkg_name="transformers")
| |
|
| |
|
def is_datasets_available():
    """Return whether the `datasets` package is installed."""
    return _is_package_available(pkg_name="datasets")
| |
|
| |
|
def is_peft_available():
    """Return whether the `peft` package is installed."""
    return _is_package_available(pkg_name="peft")
| |
|
| |
|
def is_timm_available():
    """Return whether the `timm` package is installed."""
    return _is_package_available(pkg_name="timm")
| |
|
| |
|
def is_triton_available():
    """
    Return whether `triton` is installed.

    When an XPU is available, the `triton` module must come from the `pytorch-triton-xpu` distribution instead.
    """
    distribution_name = "pytorch-triton-xpu" if is_xpu_available() else None
    return _is_package_available("triton", distribution_name)
| |
|
| |
|
def is_aim_available():
    """Return whether `aim` is installed at a version below 4.0.0."""
    if not _is_package_available("aim"):
        return False
    installed_version = version.parse(importlib.metadata.version("aim"))
    return compare_versions(installed_version, "<", "4.0.0")
| |
|
| |
|
def is_tensorboard_available():
    """Return whether either `tensorboard` or `tensorboardX` is installed."""
    if _is_package_available("tensorboard"):
        return True
    return _is_package_available("tensorboardX")
| |
|
| |
|
def is_wandb_available():
    """Return whether the `wandb` package is installed."""
    return _is_package_available(pkg_name="wandb")
| |
|
| |
|
def is_comet_ml_available():
    """Return whether the `comet_ml` package is installed."""
    return _is_package_available(pkg_name="comet_ml")
| |
|
| |
|
def is_swanlab_available():
    """Return whether the `swanlab` package is installed."""
    return _is_package_available(pkg_name="swanlab")
| |
|
| |
|
def is_trackio_available():
    """Return whether `trackio` is installed; always `False` on Python versions older than 3.10."""
    if sys.version_info < (3, 10):
        return False
    return _is_package_available("trackio")
| |
|
| |
|
def is_boto3_available():
    """Return whether the `boto3` package is installed."""
    return _is_package_available(pkg_name="boto3")
| |
|
| |
|
def is_rich_available():
    """
    Return whether rich output should be used.

    This requires the `rich` package to be installed and the `ACCELERATE_ENABLE_RICH` environment flag to be set.
    The flag defaults to off.
    """
    if not _is_package_available("rich"):
        return False
    return parse_flag_from_env("ACCELERATE_ENABLE_RICH", False)
| |
|
| |
|
def is_sagemaker_available():
    """Return whether the `sagemaker` package is installed."""
    return _is_package_available(pkg_name="sagemaker")
| |
|
| |
|
def is_tqdm_available():
    """Return whether the `tqdm` package is installed."""
    return _is_package_available(pkg_name="tqdm")
| |
|
| |
|
def is_clearml_available():
    """Return whether the `clearml` package is installed."""
    return _is_package_available(pkg_name="clearml")
| |
|
| |
|
def is_pandas_available():
    """Return whether the `pandas` package is installed."""
    return _is_package_available(pkg_name="pandas")
| |
|
| |
|
def is_matplotlib_available():
    """Return whether the `matplotlib` package is installed."""
    return _is_package_available(pkg_name="matplotlib")
| |
|
| |
|
def is_mlflow_available():
    """
    Return whether MLflow is installed.

    The `mlflow` module is accepted if its metadata belongs either to the full `mlflow` distribution or to
    `mlflow-skinny`.
    """
    if _is_package_available("mlflow"):
        return True
    if importlib.util.find_spec("mlflow") is None:
        return False
    try:
        importlib.metadata.metadata("mlflow-skinny")
    except importlib.metadata.PackageNotFoundError:
        return False
    return True
| |
|
| |
|
def is_mps_available(min_version="1.12"):
    """
    Return whether the Apple MPS backend can be used.

    Requires torch >= `min_version` (default "1.12"), and MPS must be reported as both available and built into the
    installed torch.
    """
    if not is_torch_version(">=", min_version):
        return False
    mps_backend = torch.backends.mps
    return mps_backend.is_available() and mps_backend.is_built()
| |
|
| |
|
def is_ipex_available():
    """
    Return whether Intel Extension for PyTorch (IPEX) is installed with a major.minor version matching torch.

    Returns `False` when IPEX is not findable or has no distribution metadata. On a major.minor mismatch with the
    installed torch, it emits a `warnings.warn` message and returns `False`.
    """

    def _major_minor(full_version):
        # e.g. "2.3.1+cpu" -> "2.3"
        parsed = version.parse(full_version)
        return f"{parsed.major}.{parsed.minor}"

    # Looked up before anything else, matching the original evaluation order.
    torch_full_version = importlib.metadata.version("torch")
    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
        return False
    try:
        ipex_full_version = importlib.metadata.version("intel_extension_for_pytorch")
    except importlib.metadata.PackageNotFoundError:
        return False

    torch_mm = _major_minor(torch_full_version)
    ipex_mm = _major_minor(ipex_full_version)
    if torch_mm == ipex_mm:
        return True
    warnings.warn(
        f"Intel Extension for PyTorch {ipex_mm} needs to work with PyTorch {ipex_mm}.*,"
        f" but PyTorch {torch_full_version} is found. Please switch to the matching version and run again."
    )
    return False
| |
|
| |
|
@lru_cache
def is_mlu_available(check_device=False):
    """
    Return whether a Cambricon MLU is available through `torch_mlu`.

    The check calls `torch.mlu.is_available()` with `PYTORCH_CNDEV_BASED_MLU_CHECK=1` set, a CNDEV-based probe
    intended to leave the MLU driver uninitialized. `check_device` is accepted for signature parity but is not used.
    """
    if importlib.util.find_spec("torch_mlu") is not None:
        import torch_mlu  # noqa: F401

        with patch_environment(PYTORCH_CNDEV_BASED_MLU_CHECK="1"):
            return torch.mlu.is_available()
    return False
| |
|
| |
|
@lru_cache
def is_musa_available(check_device=False):
    """
    Return whether `torch_musa` is installed and `torch.musa` reports availability.

    Args:
        check_device: also call `torch.musa.device_count()`. A `RuntimeError` from that call is treated as "no MUSA".
    """
    if importlib.util.find_spec("torch_musa") is None:
        return False

    import torch_musa  # noqa: F401

    if not check_device:
        return hasattr(torch, "musa") and torch.musa.is_available()
    try:
        _ = torch.musa.device_count()
        return torch.musa.is_available()
    except RuntimeError:
        return False
| |
|
| |
|
@lru_cache
def is_npu_available(check_device=False):
    """
    Return whether `torch_npu` is installed and `torch.npu` reports availability.

    Args:
        check_device: also call `torch.npu.device_count()`. A `RuntimeError` from that call is treated as "no NPU".
    """
    if importlib.util.find_spec("torch_npu") is None:
        return False

    # The import is guarded broadly: any exception while importing torch_npu is reported as "not available".
    try:
        import torch_npu  # noqa: F401
    except Exception:
        return False

    if not check_device:
        return hasattr(torch, "npu") and torch.npu.is_available()
    try:
        _ = torch.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False
| |
|
| |
|
@lru_cache
def is_sdaa_available(check_device=False):
    """
    Return whether `torch_sdaa` is installed and `torch.sdaa` reports availability.

    Args:
        check_device: also call `torch.sdaa.device_count()`. A `RuntimeError` from that call is treated as "no SDAA".
    """
    if importlib.util.find_spec("torch_sdaa") is None:
        return False

    import torch_sdaa  # noqa: F401

    if not check_device:
        return hasattr(torch, "sdaa") and torch.sdaa.is_available()
    try:
        _ = torch.sdaa.device_count()
        return torch.sdaa.is_available()
    except RuntimeError:
        return False
| |
|
| |
|
@lru_cache
def is_hpu_available(init_hccl=False):
    """
    Return whether Habana's PyTorch integration is installed and `torch.hpu` reports availability.

    Args:
        init_hccl: also import `habana_frameworks.torch.distributed.hccl`. The module is imported only for its side
            effects, presumably HCCL backend registration.
    """
    for required_module in ("habana_frameworks", "habana_frameworks.torch"):
        if importlib.util.find_spec(required_module) is None:
            return False

    import habana_frameworks.torch  # noqa: F401

    if init_hccl:
        import habana_frameworks.torch.distributed.hccl as hccl  # noqa: F401

    return hasattr(torch, "hpu") and torch.hpu.is_available()
| |
|
| |
|
def is_habana_gaudi1():
    """Return whether an HPU is available and its device type is `synDeviceGaudi` (first-generation Gaudi)."""
    if not is_hpu_available():
        return False

    import habana_frameworks.torch.utils.experimental as htexp

    return bool(htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi)
| |
|
| |
|
@lru_cache
def is_xpu_available(check_device=False):
    """
    Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or via stock PyTorch (>=2.4) and
    potentially if a XPU is in the environment.

    Args:
        check_device: also call `torch.xpu.device_count()`. A `RuntimeError` from that call is treated as "no XPU".

    Returns:
        `True` if an XPU backend reports availability, `False` otherwise. Memoized per argument by `lru_cache`.
    """
    if is_ipex_available():
        # Imported only for its side effects (presumably enabling XPU support on `torch`).
        import intel_extension_for_pytorch  # noqa: F401
    elif is_torch_version("<", "2.4"):
        # Without IPEX, stock PyTorch >= 2.4 is required. The previous `<= "2.3"` bound let 2.3.x patch releases
        # (e.g. 2.3.1) through, contradicting that requirement.
        return False

    if check_device:
        try:
            _ = torch.xpu.device_count()
            return torch.xpu.is_available()
        except RuntimeError:
            return False
    return hasattr(torch, "xpu") and torch.xpu.is_available()
| |
|
| |
|
def is_dvclive_available():
    """Return whether the `dvclive` package is installed."""
    return _is_package_available(pkg_name="dvclive")
| |
|
| |
|
def is_torchdata_available():
    """Return whether the `torchdata` package is installed."""
    return _is_package_available(pkg_name="torchdata")
| |
|
| |
|
| | |
def is_torchdata_stateful_dataloader_available():
    """Return whether `torchdata` >= 0.8.0 is installed (the version gate used here for the stateful dataloader)."""
    if not _is_package_available("torchdata"):
        return False
    installed_version = version.parse(importlib.metadata.version("torchdata"))
    return compare_versions(installed_version, ">=", "0.8.0")
| |
|
| |
|
def torchao_required(func):
    """
    Decorate `func` so that each call first checks `is_torchao_available()`.

    Raises:
        ImportError: at call time (not decoration time) when torchao is not available.
    """

    @wraps(func)
    def guarded(*args, **kwargs):
        if is_torchao_available():
            return func(*args, **kwargs)
        raise ImportError(
            "`torchao` is not available, please install it before calling this function via `pip install torchao`."
        )

    return guarded
| |
|
| |
|
| | |
def deepspeed_required(func):
    """
    Decorate `func` so that calls fail when an `Accelerator` is configured for something other than DeepSpeed.

    The check runs on every call. If no `AcceleratorState` has been initialized yet (its shared state is empty), the
    call proceeds unchecked.

    Raises:
        ValueError: at call time, when an initialized state's `distributed_type` is not `DistributedType.DEEPSPEED`.
    """

    @wraps(func)
    def guarded(*args, **kwargs):
        from accelerate.state import AcceleratorState
        from accelerate.utils.dataclasses import DistributedType

        state_initialized = AcceleratorState._shared_state != {}
        if state_initialized and AcceleratorState().distributed_type != DistributedType.DEEPSPEED:
            raise ValueError(
                "DeepSpeed is not enabled, please make sure that an `Accelerator` is configured for `deepspeed` "
                "before calling this function."
            )
        return func(*args, **kwargs)

    return guarded
| |
|
| |
|
def is_weights_only_available():
    """
    Return whether the installed torch is 2.4.0 or newer.

    NOTE(review): presumably gates `torch.load(weights_only=...)` support; confirm against callers.
    """
    minimum_torch = "2.4.0"
    return is_torch_version(">=", minimum_torch)
| |
|
| |
|
def is_numpy_available(min_version="1.25.0"):
    """
    Return whether NumPy is installed at version `min_version` or newer.

    Args:
        min_version: inclusive lower bound on the installed NumPy version.

    Returns:
        `False` if the `numpy` distribution is not installed (previously this raised
        `importlib.metadata.PackageNotFoundError`); otherwise the result of the version comparison.
    """
    if not _is_package_available("numpy"):
        return False
    numpy_version = parse(importlib.metadata.version("numpy"))
    return compare_versions(numpy_version, ">=", min_version)
| |
|