| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | A collection of utilities for ensuring that training can always occur. Heavily influenced by the |
| | [toma](https://github.com/BlackHC/toma) library. |
| | """ |
| |
|
| | import functools |
| | import gc |
| | import importlib |
| | import inspect |
| | import warnings |
| | from typing import Optional |
| |
|
| | import torch |
| | from packaging import version |
| |
|
| | from .imports import ( |
| | is_cuda_available, |
| | is_hpu_available, |
| | is_ipex_available, |
| | is_mlu_available, |
| | is_mps_available, |
| | is_musa_available, |
| | is_npu_available, |
| | is_sdaa_available, |
| | is_xpu_available, |
| | ) |
| | from .versions import compare_versions |
| |
|
| |
|
def clear_device_cache(garbage_collection=False):
    """
    Clears the device cache by calling `torch.{backend}.empty_cache`. Can also run `gc.collect()`, but do note that
    this is a *considerable* slowdown and should be used sparingly.
    """
    if garbage_collection:
        gc.collect()

    # Probe the accelerator backends in priority order and empty the cache of the
    # first one that is available; the remaining backends are never touched.
    backends = (
        (is_xpu_available, lambda: torch.xpu.empty_cache()),
        (is_mlu_available, lambda: torch.mlu.empty_cache()),
        (is_sdaa_available, lambda: torch.sdaa.empty_cache()),
        (is_musa_available, lambda: torch.musa.empty_cache()),
        (is_npu_available, lambda: torch.npu.empty_cache()),
        (lambda: is_mps_available(min_version="2.0"), lambda: torch.mps.empty_cache()),
        (is_cuda_available, lambda: torch.cuda.empty_cache()),
        # HPU intentionally has no empty-cache call here.
        (is_hpu_available, lambda: None),
    )
    for is_available, empty_cache in backends:
        if is_available():
            empty_cache()
            break
| |
|
| |
|
def release_memory(*objects):
    """
    Releases memory from `objects` by setting them to `None` and calls `gc.collect()` and the active backend's
    `empty_cache` (via `clear_device_cache`). Returned objects should be reassigned to the same variables.

    Args:
        objects (`Iterable`):
            An iterable of objects
    Returns:
        A list of `None` objects to replace `objects`

    Example:

    ```python
    >>> import torch
    >>> from accelerate.utils import release_memory

    >>> a = torch.ones(1000, 1000).cuda()
    >>> b = torch.ones(1000, 1000).cuda()
    >>> a, b = release_memory(a, b)
    ```
    """
    # `objects` is always a tuple here (collected via `*args`), so the previous
    # `isinstance(objects, list)` check was dead code; drop every reference by
    # returning a same-length list of `None`s for the caller to reassign.
    objects = [None] * len(objects)
    clear_device_cache(garbage_collection=True)
    return objects
| |
|
| |
|
def should_reduce_batch_size(exception: Exception) -> bool:
    """
    Checks if `exception` relates to CUDA out-of-memory, XPU out-of-memory, CUDNN not supported, or CPU out-of-memory

    Args:
        exception (`Exception`):
            An exception
    """
    oom_messages = (
        " out of memory.",
        "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.",
        "DefaultCPUAllocator: can't allocate memory",
        "FATAL ERROR :: MODULE:PT_DEVMEM Allocation failed",
    )
    # Only single-argument RuntimeErrors carry the message formats we recognize.
    if not isinstance(exception, RuntimeError) or len(exception.args) != 1:
        return False
    message = exception.args[0]
    return any(statement in message for statement in oom_messages)
| |
|
| |
|
def find_executable_batch_size(
    function: Optional[callable] = None,
    starting_batch_size: int = 128,
    reduce_batch_size_fn: Optional[callable] = None,
):
    """
    A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
    CUDNN, the batch size is reduced (by default multiplied by 0.9) and passed to `function` again.

    `function` must take in a `batch_size` parameter as its first argument.

    Args:
        function (`callable`, *optional*):
            A function to wrap
        starting_batch_size (`int`, *optional*):
            The batch size to try and fit into memory
        reduce_batch_size_fn (`callable`, *optional*):
            A zero-argument callable returning the next batch size to try. Defaults to shrinking the current batch
            size by 10% (`int(batch_size * 0.9)`).

    Example:

    ```python
    >>> from accelerate.utils import find_executable_batch_size


    >>> @find_executable_batch_size(starting_batch_size=128)
    ... def train(batch_size, model, optimizer):
    ...     ...


    >>> train(model, optimizer)
    ```
    """
    if function is None:
        # Used as `@find_executable_batch_size(...)`: return a decorator carrying *all* of the
        # configuration. Previously `reduce_batch_size_fn` was dropped here, so a user-supplied
        # reducer was silently ignored in the decorator-with-arguments form.
        return functools.partial(
            find_executable_batch_size,
            starting_batch_size=starting_batch_size,
            reduce_batch_size_fn=reduce_batch_size_fn,
        )

    batch_size = starting_batch_size
    if reduce_batch_size_fn is None:
        # Default policy: shrink the current batch size by 10% on every OOM-like failure.
        def reduce_batch_size_fn():
            nonlocal batch_size
            batch_size = int(batch_size * 0.9)
            return batch_size

    @functools.wraps(function)
    def decorator(*args, **kwargs):
        nonlocal batch_size
        clear_device_cache(garbage_collection=True)
        params = list(inspect.signature(function).parameters.keys())
        # Guard against users passing `batch_size` themselves; the decorator injects it as the
        # first positional argument.
        if len(params) < (len(args) + 1):
            arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])])
            raise TypeError(
                f"Batch size was passed into `{function.__name__}` as the first argument when called. "
                f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
            )
        while True:
            if batch_size == 0:
                raise RuntimeError("No executable batch size found, reached zero.")
            try:
                return function(batch_size, *args, **kwargs)
            except Exception as e:
                # Only retry on OOM-style failures; anything else propagates unchanged.
                if should_reduce_batch_size(e):
                    clear_device_cache(garbage_collection=True)
                    batch_size = reduce_batch_size_fn()
                else:
                    raise

    return decorator
| |
|
| |
|
def get_xpu_available_memory(device_index: int) -> int:
    """
    Returns the available XPU memory (in bytes) for `device_index`.

    Prefers `torch.xpu.mem_get_info` (PyTorch >= 2.6), then IPEX's `mem_get_info` (IPEX >= 2.5). If neither is
    usable, falls back to `torch.xpu.max_memory_allocated` — which is *not* the true free memory — and emits a
    warning.
    """
    if version.parse(torch.__version__).release >= version.parse("2.6").release:
        # `torch.xpu.mem_get_info` may still be unsupported on some platforms/drivers even on
        # PyTorch >= 2.6, so fall through to the warning path on any failure.
        try:
            return torch.xpu.mem_get_info(device_index)[0]
        except Exception:
            pass
    elif is_ipex_available():
        # Bind the `metadata` submodule explicitly: the file-level `import importlib` alone does
        # not guarantee `importlib.metadata` is importable as an attribute.
        import importlib.metadata

        ipex_version = version.parse(importlib.metadata.version("intel_extension_for_pytorch"))
        if compare_versions(ipex_version, ">=", "2.5"):
            from intel_extension_for_pytorch.xpu import mem_get_info

            return mem_get_info(device_index)[0]

    warnings.warn(
        "The XPU `mem_get_info` API is available in IPEX version >=2.5 or PyTorch >=2.6. The current returned available memory is incorrect. Please consider upgrading your IPEX or PyTorch version."
    )
    return torch.xpu.max_memory_allocated(device_index)
| |
|