# mypy: allow-untyped-defs
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import Any, Optional

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    _checkpoint_without_reentrant_generator,
    _DEFAULT_DETERMINISM_MODE,
)

from .contract import _State, contract


@contextmanager
def _no_hook(module: nn.Module, user_ctx: Optional[AbstractContextManager] = None):
    r"""
    Disable hooks installed by checkpoint to avoid unintentional recursion
    during backward recomputation.
    """
    with user_ctx if user_ctx else nullcontext():
        orig_enable_hook = checkpoint.state(module).enable_hook
        checkpoint.state(module).enable_hook = False
        try:
            yield
        finally:
            checkpoint.state(module).enable_hook = orig_enable_hook
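
# ``_no_hook`` is used below (via ``context_fns`` in the pre-forward hook) as the
# recomputation context, so re-running the module's forward during backward
# recomputation does not re-enter the checkpoint hooks and recurse.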


class _CheckpointState(_State):
    # Whether the pre-/post-forward hooks should run checkpointing logic for this
    # module (temporarily disabled by _no_hook during recomputation).
    enable_hook: bool = False
    # Holds the in-flight non-reentrant checkpoint generator between the
    # pre-forward and post-forward hooks.
    _ac_generator: Optional[Generator[None, None, None]]


@contract(_CheckpointState)
def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
    r"""
    This is a composable activation checkpointing API. Unlike functional
    activation checkpointing APIs, this one does not require changing model
    source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
    this one does not modify model structure or fully-qualified names either.
    Under the hood, it registers activation checkpointing logic as pre- and
    post-forward hooks. Hence, this API can be easily applied to any model or
    sub-modules in the model.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> model = MyModel()
        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
        >>> model(torch.zeros(2, 10)).sum().backward()
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint")

    use_reentrant = kwargs.pop("use_reentrant", False)
    if use_reentrant:
        raise NotImplementedError(
            "use_reentrant=True is not supported in composable checkpoint. "
            "Please use torch.utils.checkpoint.checkpoint instead."
        )
    preserve_rng_state = kwargs.pop("preserve_rng_state", True)
    user_context_fns = kwargs.pop("context_fn", None)
    determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
    debug = kwargs.pop("debug", False)
    if kwargs:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )
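
    # The optional kwargs popped above mirror the non-reentrant
    # ``torch.utils.checkpoint.checkpoint`` API; an illustrative call:
    #   checkpoint(model.l1, preserve_rng_state=False, debug=True)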

    def forward_pre_hook(
        module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> None:
        if checkpoint.state(module).enable_hook:

            def context_fns():
                if user_context_fns is not None:
                    ctx1, ctx2 = user_context_fns()
                    return ctx1, _no_hook(module, ctx2)
                else:
                    return nullcontext(), _no_hook(module)

            gen = _checkpoint_without_reentrant_generator(
                module,
                preserve_rng_state,
                context_fns,
                determinism_check,
                debug,
                *args,
                **kwargs,
            )
            checkpoint.state(module)._ac_generator = gen
            next(gen)
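
    # The pre-forward hook above advances the non-reentrant checkpoint generator to
    # its first yield, entering the checkpointing context before the module's forward
    # runs; the post-forward hook below resumes the generator so that it can exit
    # that context once the forward output is available.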
    def forward_hook(module: nn.Module, inputs: tuple[Any, ...], output: Any) -> Any:
        if checkpoint.state(module).enable_hook:
            try:
                gen = checkpoint.state(module)._ac_generator
                assert gen is not None
                next(gen)
            except StopIteration:
                pass
            else:
                raise RuntimeError(
                    "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
                )

        # Ensure that we no longer hold on to the generator. always_call=True helps
        # ensure we clear this even if the forward pass raises an exception.
        checkpoint.state(module)._ac_generator = None

    checkpoint.state(module).enable_hook = True
    module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
    module.register_forward_hook(forward_hook, prepend=True, always_call=True)
    return module
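

# ``checkpoint`` mutates the module in place (registering hooks and attaching
# per-module state accessed via ``checkpoint.state(module)``) and returns the same
# instance, so the return value can be ignored as in the docstring example.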