| |
|
|
| import inspect |
| from functools import wraps |
| from typing import Callable, TypeVar, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.utils.checkpoint as checkpoint |
| from torch.utils._pytree import tree_map_only |
|
|
| |
# Generic type variable; keeps the return type of a function wrapped by
# `clone_output_wrapper` unchanged in its annotated signature.
T = TypeVar("T")
# Type variable restricted to `nn.Module` and its subclasses.
Module = TypeVar("Module", bound=nn.Module)
|
|
|
|
def activation_ckpt_wrapper(module: Union[nn.Module, Callable]) -> Callable:
    """
    Wrap a module or function so it can optionally run under activation checkpointing.

    Activation checkpointing (gradient checkpointing) trades compute for memory by
    recomputing intermediate activations during the backward pass instead of storing
    them in memory during the forward pass.

    When checkpointing is enabled, the wrapper accepts keyword arguments only and
    binds them against the signature of ``module.forward`` (for an ``nn.Module``)
    or of ``module`` itself:

    * positional(-or-keyword) parameters are passed positionally, in signature
      order; a missing one falls back to its default value, or raises
      ``ValueError`` if it has none;
    * keyword-only parameters are passed as keywords; a missing one without a
      default raises ``ValueError``, while a missing one with a default is simply
      omitted so the callee applies its own default;
    * ``*args`` / ``**kwargs`` collectors in the signature are skipped;
    * any leftover keyword arguments (which can only be absorbed by a ``**kwargs``
      collector) are forwarded as keywords, except that ``torch.Tensor`` values
      are replaced by the string ``"_REMOVED_BY_ACT_CKPT_WRAPPER_"``.

    Args:
        module: The module or function to wrap with activation checkpointing.

    Returns:
        A wrapped callable accepting the same arguments as ``module`` plus two
        extra keyword arguments: ``act_ckpt_enable`` (default ``True``) and
        ``use_reentrant`` (default ``False``, forwarded to
        ``torch.utils.checkpoint.checkpoint``). With ``act_ckpt_enable=False`` the
        call is forwarded to ``module`` unchanged.

    Raises:
        ValueError: (from the returned wrapper, when ``act_ckpt_enable=True``) if
            positional arguments are given, or a required parameter of the wrapped
            callable is missing.

    Note:
        With ``use_reentrant=True``, torch's checkpoint rejects keyword arguments,
        so in that mode only callables whose arguments all bind positionally work.

    Example:
        ```python
        wrapped_module = activation_ckpt_wrapper(my_module)
        output = wrapped_module(x=input_tensor, y=another_tensor, act_ckpt_enable=True)
        ```
    """

    @wraps(module)
    def act_ckpt_wrapper(
        *args, act_ckpt_enable: bool = True, use_reentrant: bool = False, **kwargs
    ):
        if not act_ckpt_enable:
            return module(*args, **kwargs)

        if len(args) > 0:
            raise ValueError(
                "This wrapper expects keyword arguments only when `act_ckpt_enable=True`"
            )

        callable_fn = module.forward if isinstance(module, nn.Module) else module
        sig = inspect.signature(callable_fn)

        pos_args = []
        named_kwargs = {}
        for p_name, param in sig.parameters.items():
            kind = param.kind
            if kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                # `*args` / `**kwargs` collectors have no single value to bind.
                continue

            if kind is inspect.Parameter.KEYWORD_ONLY:
                # Keyword-only parameters must stay keywords; passing them
                # positionally would raise a TypeError in the callee.
                if p_name in kwargs:
                    named_kwargs[p_name] = kwargs.pop(p_name)
                elif param.default is inspect.Parameter.empty:
                    raise ValueError(f"Missing keyword-only argument: {p_name}")
                # Otherwise omit it and let the callee use its own default.
                continue

            if p_name in kwargs:
                pos_args.append(kwargs.pop(p_name))
            elif param.default is not inspect.Parameter.empty:
                # Positional slots must be filled in order, so materialize the
                # default (useful for primitives or args that default to None).
                pos_args.append(param.default)
            else:
                raise ValueError(f"Missing positional argument: {p_name}")

        # Whatever remains can only be absorbed by a `**kwargs` collector.
        # Tensors there are replaced by a marker string; presumably the module
        # does not need them — if it does, declare them as named parameters so
        # they are forwarded intact.
        extra_kwargs = {
            key: "_REMOVED_BY_ACT_CKPT_WRAPPER_"
            if isinstance(value, torch.Tensor)
            else value
            for key, value in kwargs.items()
        }

        return checkpoint.checkpoint(
            module,
            *pos_args,
            use_reentrant=use_reentrant,
            **named_kwargs,
            **extra_kwargs,
        )

    return act_ckpt_wrapper
|
|
|
|
def clone_output_wrapper(f: Callable[..., T]) -> Callable[..., T]:
    """
    Return a version of ``f`` whose CUDA tensor outputs are cloned.

    Every ``torch.Tensor`` found anywhere in the (possibly nested) output of ``f``
    is replaced by a clone if it lives on a CUDA device; CPU tensors and all
    non-tensor values are returned as-is. This helps with torch.compile by
    avoiding errors related to in-place operations on the returned tensors.

    Args:
        f: The function whose CUDA tensor outputs should be cloned.

    Returns:
        A wrapper with the same call signature and metadata (via
        ``functools.wraps``) as ``f``.
    """

    def _clone_if_cuda(tensor: torch.Tensor) -> torch.Tensor:
        # Only CUDA-resident tensors are copied; anything else passes through.
        if tensor.is_cuda:
            return tensor.clone()
        return tensor

    @wraps(f)
    def wrapped(*args, **kwargs):
        result = f(*args, **kwargs)
        # tree_map_only walks nested containers and touches only Tensor leaves.
        return tree_map_only(torch.Tensor, _clone_if_cuda, result)

    return wrapped
|
|