|
|
from enum import auto, Enum
from functools import partial
from typing import Any, Dict, Iterator, Tuple

import torch
import torch.nn as nn
from torch.autograd.graph import save_on_cpu
from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs
from torch.utils.checkpoint import checkpoint

_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module"


class CheckpointImpl(Enum):
    # Maps to ``use_reentrant=True`` in ``torch.utils.checkpoint.checkpoint``.
    REENTRANT = auto()
    # Maps to ``use_reentrant=False``, i.e. the non-reentrant implementation.
    NO_REENTRANT = auto()


class CheckpointWrapper(torch.nn.Module):
    """
    An nn.Module that wraps another nn.Module with checkpointing. Note that this
    module is not meant to be used directly, but instead it is to be used
    through the ``checkpoint_wrapper`` function.
    """

    def __init__(
        self,
        mod: torch.nn.Module,
        checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
        offload_to_cpu: bool = False,
        checkpoint_fn=None,
        *checkpoint_fn_args,
        **checkpoint_fn_kwargs,
    ):
        super().__init__()
        self._checkpoint_wrapped_module = mod
        self.checkpoint_impl = checkpoint_impl
        self.offload_to_cpu = offload_to_cpu
        if self.offload_to_cpu:
            # Activations are offloaded to CPU in forward() via ``save_on_cpu``,
            # so no checkpoint function is needed.
            self.checkpoint_fn = None
        else:
            if checkpoint_fn is None:
                # Use the default ``torch.utils.checkpoint.checkpoint``,
                # selecting reentrant vs. non-reentrant based on
                # ``checkpoint_impl``.
                self.checkpoint_fn = partial(
                    checkpoint,
                    use_reentrant=(
                        self.checkpoint_impl == CheckpointImpl.REENTRANT
                    ),
                )
            else:
                # Use the user-provided checkpoint function with any
                # additional args/kwargs bound to it.
                self.checkpoint_fn = partial(
                    checkpoint_fn,
                    *checkpoint_fn_args,
                    **checkpoint_fn_kwargs,
                )
        # state_dict post hook strips the wrapper prefix so that this module's
        # state_dict can be loaded into non-checkpoint-wrapped modules.
        self._register_state_dict_hook(self._post_state_dict_hook)
        # load_state_dict pre-hook adds the prefix back so that a
        # non-wrapped state_dict can be loaded into this wrapper.
        self._register_load_state_dict_pre_hook(
            self._pre_load_state_dict_hook, with_module=True
        )

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self._checkpoint_wrapped_module, name)

    def __getitem__(self, key: int) -> Any:
        """Forward indexing calls in case the module is a nn.Sequential."""
        return self._checkpoint_wrapped_module.__getitem__(key)

    def forward(self, *args, **kwargs):
        if self.offload_to_cpu:
            # Save activations to pinned CPU memory and bring them back to the
            # device during backward, instead of recomputing them.
            with save_on_cpu(pin_memory=True):
                return self._checkpoint_wrapped_module(*args, **kwargs)
        else:
            # The reentrant checkpoint implementation only accepts positional
            # inputs, so keyword arguments are packed into flat positional
            # args here and unpacked inside the checkpointed function.
            if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}:
                # Pack the args and kwargs.
                flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs)

                # Function that only takes (packed) positional args, unpacks
                # them into the original args and kwargs, and runs the wrapped
                # module.
                def my_function(*inputs):
                    # Unpack back into args and kwargs.
                    unpacked_args, unpacked_kwargs = _unpack_kwargs(
                        inputs, kwarg_keys
                    )
                    # Run the original module.
                    return self._checkpoint_wrapped_module(
                        *unpacked_args, **unpacked_kwargs
                    )

                # Pass the positional-only function into the reentrant
                # checkpoint API.
                return self.checkpoint_fn(
                    my_function,
                    *flat_args,
                )
            else:
                return self.checkpoint_fn(
                    self._checkpoint_wrapped_module,
                    *args,
                    **kwargs
                )

    def named_parameters(
        self,
        *args,
        **kwargs,
    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Overrides :meth:`named_parameters()` to intercept parameter names and
        remove all occurrences of ``_CHECKPOINT_PREFIX``.
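
        Example (a minimal illustrative sketch; ``lin`` is a stand-in module,
        not part of this file)::

            lin = nn.Linear(4, 4)
            wrapped = checkpoint_wrapper(lin)
            # Names are yielded without the internal wrapper prefix.
            assert [n for n, _ in wrapped.named_parameters()] == ["weight", "bias"]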
|
|
""" |
|
|
for param_name, param in super().named_parameters(*args, **kwargs): |
|
|
yield param_name.replace(f"{_CHECKPOINT_PREFIX}.", ""), param |
|
|
|
|
|
    @staticmethod
    def _post_state_dict_hook(
        module: nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> Dict[str, Any]:
        """
        _post_state_dict_hook() is called after the state_dict() of this
        module is executed. For ``checkpoint_wrapper``, it strips the
        checkpoint-wrapped module prefix so that this module can be loaded into
        non-checkpointed modules. It would still be able to be loaded into
        checkpoint-wrapped modules as this class adds the prefix back before
        loading the state_dict.
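
        Example (an illustrative round-trip sketch; ``lin`` is a stand-in
        module, not part of this file)::

            lin = nn.Linear(4, 4)
            wrapped = checkpoint_wrapper(lin)
            # Keys match the unwrapped module, so checkpoints are interchangeable.
            assert set(wrapped.state_dict()) == set(lin.state_dict())
            lin.load_state_dict(wrapped.state_dict())
            wrapped.load_state_dict(lin.state_dict())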
|
|
""" |
|
|
_replace_by_prefix(state_dict, f"{prefix}{_CHECKPOINT_PREFIX}.", prefix) |
|
|
return state_dict |
|
|
|
|
|
    @staticmethod
    def _pre_load_state_dict_hook(
        module: nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> None:
        """
        ``_pre_load_state_dict_hook`` is called before ``self._load_from_state_dict()``
        is called. For ``checkpoint_wrapper``, it adds back the module
        prefix so that non-checkpointed modules can be loaded into
        checkpoint_wrapper modules properly.
        """
        _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}.")


def checkpoint_wrapper(
    module: torch.nn.Module,
    checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
    offload_to_cpu: bool = False,
    checkpoint_fn=None,
    *checkpoint_fn_args,
    **checkpoint_fn_kwargs,
) -> torch.nn.Module:
    """
    A convenience wrapper for activation checkpointing. If the module is wrapped
    with this function, all subsequent calls to the module will automatically
    perform checkpointing without the user having to explicitly call the
    ``checkpoint`` function.

    Usage::

        checkpointed_module = checkpoint_wrapper(module)
        outputs = checkpointed_module(inputs)

    Args:
        module (nn.Module):
            The module to be wrapped.
        checkpoint_impl (Optional[CheckpointImpl]):
            The checkpointing implementation to use. Note that this will only
            be passed into the ``torch.utils.checkpoint.checkpoint``
            implementation, and is ignored if a custom ``checkpoint_fn`` is
            specified. Note that for implementations using the reentrant checkpoint
            from ``torch.utils.checkpoint``, keyword arguments are only
            supported if ``checkpoint_impl`` is passed as ``CheckpointImpl.REENTRANT``.
        offload_to_cpu (Optional[bool]):
            Whether to offload activations of this wrapped module to CPU. Note
            that if this is specified, the ``checkpoint_impl`` and ``checkpoint_fn``
            arguments will be ignored in favor of the activations being
            offloaded to CPU. Default is ``False``. Wrappers with activation
            offload can be composed with ones that do recomputation-based
            checkpointing to trade off increased compute versus increased CPU
            memory usage and additional H2D transfers; see the example below.
        checkpoint_fn (Optional[Callable]):
            Functional checkpoint implementation to use. If this is specified,
            it will be used over the default ``torch.utils.checkpoint.checkpoint``
            implementation and the ``checkpoint_impl`` argument will be ignored.
        *checkpoint_fn_args (Sequence[Any]): Arguments to pass into ``checkpoint_fn``.
        **checkpoint_fn_kwargs (Dict[str, Any]): Keyword arguments to pass into ``checkpoint_fn``.
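
    Example (an illustrative sketch of composing activation offload with
    recomputation-based checkpointing; ``block`` is a stand-in submodule and
    the nesting order shown is just one possible choice)::

        block = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
        # The inner wrapper recomputes activations in backward; the outer
        # wrapper offloads the tensors it saves to CPU.
        wrapped = checkpoint_wrapper(
            checkpoint_wrapper(block, checkpoint_impl=CheckpointImpl.REENTRANT),
            offload_to_cpu=True,
        )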
|
|
|
|
|
    Returns:
        (nn.Module):
            Wrapped module
    """
    return CheckpointWrapper(
        module,
        checkpoint_impl,
        offload_to_cpu,
        checkpoint_fn,
        *checkpoint_fn_args,
        **checkpoint_fn_kwargs,
    )


def apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=lambda _: True
):
    """
    Applies :func:`checkpoint_wrapper` to modules within `model` based on a user-defined
    configuration. For each module within `model`, the `check_fn` is used to decide
    whether `module` should be wrapped with :func:`checkpoint_wrapper` or not.

    Note::
        This function modifies `model` in place and replaces appropriate layers with
        their checkpoint-wrapped modules.
    Note::
        This function will not wrap the overall root module. If this is needed, please directly use
        :class:`CheckpointWrapper`.
    Usage::

        model = nn.Sequential(
            nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10)
        )
        check_fn = lambda l: isinstance(l, nn.Linear)
        apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn)
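
    A different wrapper configuration can be supplied through ``functools.partial``.
    The sketch below is illustrative and assumes a fresh ``model`` like the one
    above::

        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=partial(checkpoint_wrapper, offload_to_cpu=True),
            check_fn=lambda l: isinstance(l, nn.Linear),
        )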
|
|
    Args:
        model (nn.Module):
            The model whose submodules should be wrapped with activation checkpointing.
        checkpoint_wrapper_fn (Optional[Callable[nn.Module]]):
            A ``Callable`` which will wrap modules.
        check_fn (Optional[Callable[nn.Module, nn.Module]]):
            A lambda function which will be passed each child submodule of ``model`` and returns
            ``True`` or ``False`` depending on whether the submodule should be wrapped.
    Returns: None (`model` is modified in place)
    """
    from torch.distributed.fsdp.wrap import _recursive_wrap, lambda_auto_wrap_policy

    _recursive_wrap(
        module=model,
        auto_wrap_policy=partial(lambda_auto_wrap_policy, lambda_fn=check_fn),
        wrapper_cls=checkpoint_wrapper_fn,
        ignored_modules=set(),
        ignored_params=set(),
        only_wrap_children=True,
    )
|
|
|