from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


# Wrapper factory handed to apply_activation_checkpointing: calling
# non_reentrant_wrapper(module) is equivalent to
# checkpoint_wrapper(module, checkpoint_impl=CheckpointImpl.NO_REENTRANT).
non_reentrant_wrapper = partial(
    checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
)


def _parse_ratio(p):
    """Return ``p`` as a number, parsing it first if it is a string.

    Strings may be a plain number (``"0.5"``) or a single ratio
    ``"<num>/<den>"`` (``"1/3"``, ``"2 / 3"``); for these forms the result
    is the same float that ``eval`` would have produced. This deliberately
    replaces a previous ``eval(p)``, which would execute arbitrary code taken
    from a config/CLI value. Non-string values are returned unchanged.

    Raises:
        ValueError: if the string is neither a number nor a single ratio.
        ZeroDivisionError: if the ratio's denominator is zero.
    """
    if not isinstance(p, str):
        return p
    numerator, slash, denominator = p.partition("/")
    if slash:
        return float(numerator) / float(denominator)
    return float(p)


def apply_checkpointing(model, block, p):
    """
    Apply selective activation checkpointing (AC) to ``block`` submodules.

    Selectivity is defined by a ratio ``p``: AC is applied to a fraction
    ``p`` of all ``block`` instances inside ``model``. ``p`` is meant to be
    in the range [0, 1]. With the logic below, values <= 0 checkpoint no
    block and values >= 1 checkpoint every block.

    Some examples:
        p = 0: no AC for any block, same as `fsdp_activation_checkpointing=False`
        p = 1: AC on every block, i.e. "full AC".
        p = 1/2: [ac, no-ac, ac, no-ac, ...]
        p = 1/3: [no-ac, ac, no-ac, no-ac, ac, no-ac, ...]
        p = 2/3: [ac, no-ac, ac, ac, no-ac, ac, ...]
    Since blocks are homogeneous, AC blocks are spaced evenly among all
    blocks.

    Implementation:
        For a ratio p, AC should be applied once every "1/p" blocks. The
        first AC block can be as early as the 0th block or as late as the
        "1/p"-th block; we pick the middle one, the (0.5/p)-th block. AC is
        therefore applied on the (0.5/p)-th, (1.5/p)-th, (2.5/p)-th, ...
        blocks, rounded to integers. Equivalently, the k-th block (counting
        from 1) is checkpointed when k * p reaches the next threshold in
        0.5, 1.5, 2.5, ...; the check function tracks that threshold in
        ``cut_off`` as blocks are visited one by one.

    Args:
        model: the module passed to ``apply_activation_checkpointing``. The
            return value of that call is discarded, so the wrapping
            presumably happens in place.
        block: the class (or tuple of classes, as accepted by
            ``isinstance``) whose instances are candidates for AC.
        p: the AC ratio, as a number or a string such as ``"1/3"`` or
            ``"0.5"``.

    Raises:
        ValueError: if ``p`` is a string that is neither a number nor a
            single ``"a/b"`` ratio.
        ZeroDivisionError: if ``p`` is a ratio string with a zero
            denominator.
    """
    ratio = _parse_ratio(p)

    # Blocks are counted in the order apply_activation_checkpointing hands
    # them to check_fn (presumably module traversal order -- TODO confirm).
    block_idx = 0
    # Next threshold for k * ratio; advances by 1 each time a block is picked.
    cut_off = 1 / 2

    def selective_checkpointing(submodule):
        nonlocal block_idx, cut_off
        if not isinstance(submodule, block):
            return False
        block_idx += 1  # 1-indexed position of this block
        if block_idx * ratio >= cut_off:
            cut_off += 1
            return True
        return False

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=selective_checkpointing,
    )