| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import collections |
| import platform |
| import re |
| import socket |
| from codecs import encode |
| from collections import OrderedDict |
| from functools import partial, reduce |
| from types import MethodType |
|
|
| import numpy as np |
| import torch |
| from packaging.version import Version |
| from safetensors.torch import save_file as safe_save_file |
|
|
| from ..commands.config.default import write_basic_config |
| from ..logging import get_logger |
| from ..state import PartialState |
| from .constants import FSDP_PYTORCH_VERSION |
| from .dataclasses import DistributedType |
| from .imports import ( |
| is_deepspeed_available, |
| is_numpy_available, |
| is_torch_distributed_available, |
| is_torch_xla_available, |
| is_weights_only_available, |
| ) |
| from .modeling import id_tensor_storage |
| from .transformer_engine import convert_model |
| from .versions import is_torch_version |
|
|
|
|
| logger = get_logger(__name__) |
|
|
|
|
| if is_torch_xla_available(): |
| import torch_xla.core.xla_model as xm |
|
|
|
|
| def is_compiled_module(module: torch.nn.Module) -> bool: |
| """ |
| Check whether the module was compiled with torch.compile() |
| """ |
| if not hasattr(torch, "_dynamo"): |
| return False |
|
|
| return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) |
|
|
|
|
| def has_compiled_regions(module: torch.nn.Module) -> bool: |
| """ |
| Check whether the module has submodules that were compiled with `torch.compile()`. |
| """ |
| if not hasattr(torch, "_dynamo"): |
| return False |
|
|
| if module._modules: |
| for submodule in module.modules(): |
| if isinstance(submodule, torch._dynamo.eval_frame.OptimizedModule): |
| return True |
|
|
| return False |
|
|
|
|
| def is_repeated_blocks(module: torch.nn.Module) -> bool: |
| """ |
| Check whether the module is a repeated block, i.e. `torch.nn.ModuleList` with all children of the same class. This |
| is useful to determine whether we should apply regional compilation to the module. |
| """ |
|
|
| return isinstance(module, torch.nn.ModuleList) and all(isinstance(m, module[0].__class__) for m in module) |
|
|
|
|
| def has_repeated_blocks(module: torch.nn.Module) -> bool: |
| """ |
| Check whether the module has repeated blocks, i.e. `torch.nn.ModuleList` with all children of the same class, at |
| any level of the module hierarchy. This is useful to determine whether we should apply regional compilation to the |
| module. |
| """ |
| if module._modules: |
| for submodule in module.modules(): |
| if is_repeated_blocks(submodule): |
| return True |
|
|
| return False |
|
|
|
|
| def compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module: |
| """ |
| Performs regional compilation where we target repeated blocks of the same class and compile them sequentially to |
| hit the compiler's cache. For example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be |
| accessed as `model.transformer.h[0]`. The rest of the model (e.g. model.lm_head) is compiled separately. |
| |
| This allows us to speed up the compilation overhead / cold start of models like LLMs and Transformers in general. |
| See https://pytorch.org/tutorials/recipes/regional_compilation.html for more details. |
| |
| Args: |
| module (`torch.nn.Module`): |
| The model to compile. |
| **compile_kwargs: |
| Additional keyword arguments to pass to `torch.compile()`. |
| |
| Returns: |
| `torch.nn.Module`: A new instance of the model with some compiled regions. |
| |
| Example: |
| ```python |
| >>> from accelerate.utils import compile_regions |
| >>> from transformers import AutoModelForCausalLM |
| |
| >>> model = AutoModelForCausalLM.from_pretrained("gpt2") |
| >>> compiled_model = compile_regions(model, mode="reduce-overhead") |
| >>> compiled_model.transformer.h[0] |
| OptimizedModule( |
| (_orig_mod): GPT2Block( |
| (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True) |
| (attn): GPT2Attention( |
| (c_attn): Conv1D(nf=2304, nx=768) |
| (c_proj): Conv1D(nf=768, nx=768) |
| (attn_dropout): Dropout(p=0.1, inplace=False) |
| (resid_dropout): Dropout(p=0.1, inplace=False) |
| ) |
| (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True) |
| (mlp): GPT2MLP( |
| (c_fc): Conv1D(nf=3072, nx=768) |
| (c_proj): Conv1D(nf=768, nx=3072) |
| (act): NewGELUActivation() |
| (dropout): Dropout(p=0.1, inplace=False) |
| ) |
| ) |
| ) |
| ``` |
| """ |
|
|
| def _compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module: |
| if is_repeated_blocks(module): |
| new_module = torch.nn.ModuleList() |
| for submodule in module: |
| new_module.append(torch.compile(submodule, **compile_kwargs)) |
| elif has_repeated_blocks(module): |
| new_module = module.__class__.__new__(module.__class__) |
| new_module.__dict__.update(module.__dict__) |
| new_module._modules = {} |
| for name, submodule in module.named_children(): |
| new_module.add_module(name, _compile_regions(submodule, **compile_kwargs)) |
| else: |
| new_module = torch.compile(module, **compile_kwargs) |
|
|
| return new_module |
|
|
| new_module = _compile_regions(module, **compile_kwargs) |
|
|
| if "_orig_mod" not in new_module.__dict__: |
| |
| new_module.__dict__["_orig_mod"] = module |
|
|
| return new_module |
|
|
|
|
| def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs): |
| """ |
| Performs regional compilation the same way as `compile_regions`, but specifically for `DeepSpeedEngine.module`. |
| Since the model is wrapped in a `DeepSpeedEngine` and has many added hooks, offloaded parameters, etc that |
| `torch.compile(...)` interferes with, version of trgional compilation uses the inplace `module.compile()` method |
| instead. |
| |
| Args: |
| module (`torch.nn.Module`): |
| The model to compile. |
| **compile_kwargs: |
| Additional keyword arguments to pass to `module.compile()`. |
| """ |
|
|
| if is_repeated_blocks(module): |
| for submodule in module: |
| submodule.compile(**compile_kwargs) |
| elif has_repeated_blocks(module): |
| for child in module.children(): |
| compile_regions_deepspeed(child, **compile_kwargs) |
| else: |
| module.compile(**compile_kwargs) |
|
|
|
|
| def model_has_dtensor(model: torch.nn.Module) -> bool: |
| """ |
| Check if the model has DTensor parameters. |
| |
| Args: |
| model (`torch.nn.Module`): |
| The model to check. |
| |
| Returns: |
| `bool`: Whether the model has DTensor parameters. |
| """ |
| if is_torch_version(">=", "2.5.0"): |
| from torch.distributed.tensor import DTensor |
| else: |
| |
| from torch.distributed._tensor import DTensor |
|
|
| return any(isinstance(p, DTensor) for p in model.parameters()) |
|
|
|
|
| def extract_model_from_parallel( |
| model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False |
| ): |
| """ |
| Extract a model from its distributed containers. |
| |
| Args: |
| model (`torch.nn.Module`): |
| The model to extract. |
| keep_fp32_wrapper (`bool`, *optional*): |
| Whether to remove mixed precision hooks from the model. |
| keep_torch_compile (`bool`, *optional*): |
| Whether to unwrap compiled model. |
| recursive (`bool`, *optional*, defaults to `False`): |
| Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers |
| recursively, not just the top-level distributed containers. |
| |
| Returns: |
| `torch.nn.Module`: The extracted model. |
| """ |
| options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel) |
|
|
| is_compiled = is_compiled_module(model) |
| has_compiled = has_compiled_regions(model) |
|
|
| if is_compiled: |
| compiled_model = model |
| model = model._orig_mod |
| elif has_compiled: |
| compiled_model = model |
| model = model.__dict__["_orig_mod"] |
|
|
| if is_deepspeed_available(): |
| from deepspeed import DeepSpeedEngine |
|
|
| options += (DeepSpeedEngine,) |
|
|
| if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available(): |
| from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP |
|
|
| options += (FSDP,) |
|
|
| while isinstance(model, options): |
| model = model.module |
|
|
| if recursive: |
| |
| def _recursive_unwrap(module): |
| |
| |
| if hasattr(module, "module"): |
| unwrapped_module = _recursive_unwrap(module.module) |
| else: |
| unwrapped_module = module |
| |
| for name, child in unwrapped_module.named_children(): |
| setattr(unwrapped_module, name, _recursive_unwrap(child)) |
| return unwrapped_module |
|
|
| |
| model = _recursive_unwrap(model) |
|
|
| if not keep_fp32_wrapper: |
| forward = model.forward |
| original_forward = model.__dict__.pop("_original_forward", None) |
| if original_forward is not None: |
| while hasattr(forward, "__wrapped__"): |
| forward = forward.__wrapped__ |
| if forward == original_forward: |
| break |
| model.forward = MethodType(forward, model) |
| if getattr(model, "_converted_to_transformer_engine", False): |
| convert_model(model, to_transformer_engine=False) |
|
|
| if keep_torch_compile: |
| if is_compiled: |
| compiled_model._orig_mod = model |
| model = compiled_model |
| elif has_compiled: |
| compiled_model.__dict__["_orig_mod"] = model |
| model = compiled_model |
|
|
| return model |
|
|
|
|
| def wait_for_everyone(): |
| """ |
| Introduces a blocking point in the script, making sure all processes have reached this point before continuing. |
| |
| <Tip warning={true}> |
| |
| Make sure all processes will reach this instruction otherwise one of your processes will hang forever. |
| |
| </Tip> |
| """ |
| PartialState().wait_for_everyone() |
|
|
|
|
| def clean_state_dict_for_safetensors(state_dict: dict): |
| """ |
| Cleans the state dictionary from a model and removes tensor aliasing if present. |
| |
| Args: |
| state_dict (`dict`): |
| The state dictionary from a model |
| """ |
| ptrs = collections.defaultdict(list) |
| |
| for name, tensor in state_dict.items(): |
| if not isinstance(tensor, str): |
| ptrs[id_tensor_storage(tensor)].append(name) |
|
|
| |
| shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} |
| warn_names = set() |
| for names in shared_ptrs.values(): |
| |
| |
| |
| |
| |
| found_names = [name for name in names if name in state_dict] |
| warn_names.update(found_names[1:]) |
| for name in found_names[1:]: |
| del state_dict[name] |
| if len(warn_names) > 0: |
| logger.warning( |
| f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading", |
| ) |
| state_dict = {k: v.contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()} |
| return state_dict |
|
|
|
|
| def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False): |
| """ |
| Save the data to disk. Use in place of `torch.save()`. |
| |
| Args: |
| obj: |
| The data to save |
| f: |
| The file (or file-like object) to use to save the data |
| save_on_each_node (`bool`, *optional*, defaults to `False`): |
| Whether to only save on the global main process |
| safe_serialization (`bool`, *optional*, defaults to `False`): |
| Whether to save `obj` using `safetensors` or the traditional PyTorch way (that uses `pickle`). |
| """ |
| |
| |
| |
| |
| if PartialState().distributed_type == DistributedType.XLA: |
| obj = xm._maybe_convert_to_cpu(obj) |
| |
| if safe_serialization: |
| save_func = partial(safe_save_file, metadata={"format": "pt"}) |
| if isinstance(obj, OrderedDict): |
| obj = clean_state_dict_for_safetensors(obj) |
| else: |
| save_func = torch.save |
|
|
| if PartialState().is_main_process and not save_on_each_node: |
| save_func(obj, f) |
| elif PartialState().is_local_main_process and save_on_each_node: |
| save_func(obj, f) |
|
|
|
|
| |
| |
| np_core = np._core if is_numpy_available("2.0.0") else np.core |
| TORCH_SAFE_GLOBALS = [ |
| |
| np_core.multiarray._reconstruct, |
| np.ndarray, |
| |
| encode, |
| np.dtype, |
| ] |
|
|
| if is_numpy_available("1.25.0"): |
| TORCH_SAFE_GLOBALS.append(np.dtypes.UInt32DType) |
|
|
|
|
| def load(f, map_location=None, **kwargs): |
| """ |
| Compatible drop-in replacement of `torch.load()` which allows for `weights_only` to be used if `torch` version is |
| 2.4.0 or higher. Otherwise will ignore the kwarg. |
| |
| Will also add (and then remove) an exception for numpy arrays |
| |
| Args: |
| f: |
| The file (or file-like object) to use to load the data |
| map_location: |
| a function, `torch.device`, string or a dict specifying how to remap storage locations |
| **kwargs: |
| Additional keyword arguments to pass to `torch.load()`. |
| """ |
| try: |
| if is_weights_only_available(): |
| old_safe_globals = torch.serialization.get_safe_globals() |
| if "weights_only" not in kwargs: |
| kwargs["weights_only"] = True |
| torch.serialization.add_safe_globals(TORCH_SAFE_GLOBALS) |
| else: |
| kwargs.pop("weights_only", None) |
| loaded_obj = torch.load(f, map_location=map_location, **kwargs) |
| finally: |
| if is_weights_only_available(): |
| torch.serialization.clear_safe_globals() |
| if old_safe_globals: |
| torch.serialization.add_safe_globals(old_safe_globals) |
| return loaded_obj |
|
|
|
|
| def get_pretty_name(obj): |
| """ |
| Gets a pretty name from `obj`. |
| """ |
| if not hasattr(obj, "__qualname__") and not hasattr(obj, "__name__"): |
| obj = getattr(obj, "__class__", obj) |
| if hasattr(obj, "__qualname__"): |
| return obj.__qualname__ |
| if hasattr(obj, "__name__"): |
| return obj.__name__ |
| return str(obj) |
|
|
|
|
| def merge_dicts(source, destination): |
| """ |
| Recursively merges two dictionaries. |
| |
| Args: |
| source (`dict`): The dictionary to merge into `destination`. |
| destination (`dict`): The dictionary to merge `source` into. |
| """ |
| for key, value in source.items(): |
| if isinstance(value, dict): |
| node = destination.setdefault(key, {}) |
| merge_dicts(value, node) |
| else: |
| destination[key] = value |
|
|
| return destination |
|
|
|
|
| def is_port_in_use(port: int = None) -> bool: |
| """ |
| Checks if a port is in use on `localhost`. Useful for checking if multiple `accelerate launch` commands have been |
| run and need to see if the port is already in use. |
| """ |
| if port is None: |
| port = 29500 |
| with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: |
| return s.connect_ex(("localhost", port)) == 0 |
|
|
|
|
| def get_free_port() -> int: |
| """ |
| Gets a free port on `localhost`. Useful for automatic port selection when port 0 is specified in distributed |
| training scenarios. |
| |
| Returns: |
| int: An available port number |
| """ |
| with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: |
| s.bind(("", 0)) |
| return s.getsockname()[1] |
|
|
|
|
| def convert_bytes(size): |
| "Converts `size` from bytes to the largest possible unit" |
| for x in ["bytes", "KB", "MB", "GB", "TB"]: |
| if size < 1024.0: |
| return f"{round(size, 2)} {x}" |
| size /= 1024.0 |
|
|
| return f"{round(size, 2)} PB" |
|
|
|
|
| def check_os_kernel(): |
| """Warns if the kernel version is below the recommended minimum on Linux.""" |
| |
| info = platform.uname() |
| system = info.system |
| if system != "Linux": |
| return |
|
|
| _, version, *_ = re.split(r"(\d+\.\d+\.\d+)", info.release) |
| min_version = "5.5.0" |
| if Version(version) < Version(min_version): |
| msg = ( |
| f"Detected kernel version {version}, which is below the recommended minimum of {min_version}; this can " |
| "cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher." |
| ) |
| logger.warning(msg, main_process_only=True) |
|
|
|
|
| def recursive_getattr(obj, attr: str): |
| """ |
| Recursive `getattr`. |
| |
| Args: |
| obj: |
| A class instance holding the attribute. |
| attr (`str`): |
| The attribute that is to be retrieved, e.g. 'attribute1.attribute2'. |
| """ |
|
|
| def _getattr(obj, attr): |
| return getattr(obj, attr) |
|
|
| return reduce(_getattr, [obj] + attr.split(".")) |
|
|
|
|
| def get_module_children_bottom_up(model: torch.nn.Module, return_fqns: bool = False) -> list[torch.nn.Module]: |
| """Traverse the model in bottom-up order and return the children modules in that order. |
| |
| Args: |
| model (`torch.nn.Module`): the model to get the children of |
| |
| Returns: |
| `list[torch.nn.Module]`: a list of children modules of `model` in bottom-up order. The last element is the |
| `model` itself. |
| """ |
| top = model if not return_fqns else ("", model) |
| stack = [top] |
| ordered_modules = [] |
| while stack: |
| current_module = stack.pop() |
| if return_fqns: |
| current_module_name, current_module = current_module |
| for name, attr in current_module.named_children(): |
| if isinstance(attr, torch.nn.Module): |
| if return_fqns: |
| child_name = current_module_name + "." + name if current_module_name else name |
| stack.append((child_name, attr)) |
| else: |
| stack.append(attr) |
| if return_fqns: |
| ordered_modules.append((current_module_name, current_module)) |
| else: |
| ordered_modules.append(current_module) |
| return ordered_modules[::-1] |
|
|