| import contextlib |
|
|
| import torch |
| from torch.distributed.fsdp._common_utils import _get_module_fsdp_state_if_fully_sharded_module, _module_handle |
| from torch.distributed.fsdp._runtime_utils import ( |
| _post_backward_hook, |
| _pre_backward_hook, |
| ) |
| from torch.utils.checkpoint import ( |
| _get_autocast_kwargs, |
| _get_device_module, |
| _infer_device_type, |
| check_backward_validity, |
| detach_variable, |
| get_device_states, |
| set_device_states, |
| ) |
|
|
|
|
| class CheckpointFunction(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, run_function, preserve_rng_state, *args): |
| check_backward_validity(args) |
| ctx.run_function = run_function |
| ctx.preserve_rng_state = preserve_rng_state |
| |
| ctx.device = _infer_device_type(*args) |
| ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(ctx.device) |
| if preserve_rng_state: |
| ctx.fwd_cpu_state = torch.get_rng_state() |
| |
| |
| |
| |
| ctx.had_device_in_fwd = False |
| device_module = _get_device_module(ctx.device) |
| if getattr(device_module, "_initialized", False): |
| ctx.had_device_in_fwd = True |
| ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args) |
|
|
| |
| |
| ctx.inputs = [] |
| ctx.tensor_indices = [] |
| tensor_inputs = [] |
| for i, arg in enumerate(args): |
| if torch.is_tensor(arg): |
| tensor_inputs.append(arg) |
| ctx.tensor_indices.append(i) |
| ctx.inputs.append(None) |
| else: |
| ctx.inputs.append(arg) |
|
|
| ctx.save_for_backward(*tensor_inputs) |
|
|
| with torch.no_grad(): |
| outputs = run_function(*args) |
|
|
| |
| if not isinstance(ctx.run_function, torch.nn.Module): |
| ctx.patch_module = ctx.run_function.__self__ |
| else: |
| ctx.patch_module = ctx.run_function |
| state = _get_module_fsdp_state_if_fully_sharded_module(ctx.patch_module) |
| if state: |
| handle = _module_handle(state, ctx.patch_module) |
| if handle: |
| handle._needs_pre_backward_unshard = True |
| return outputs |
|
|
| @staticmethod |
| def backward(ctx, *args): |
| if not torch.autograd._is_checkpoint_valid(): |
| raise RuntimeError( |
| "When use_reentrant=True, torch.utils.checkpoint is incompatible" |
| " with .grad() or passing an `inputs` parameter to .backward()." |
| " To resolve this error, you can either set use_reentrant=False," |
| " or call .backward() without passing the `inputs` argument." |
| ) |
| |
| handle = None |
| state = _get_module_fsdp_state_if_fully_sharded_module(ctx.patch_module) |
| if state: |
| handle = _module_handle(state, ctx.patch_module) |
| if handle: |
| _pre_backward_hook(state, ctx.patch_module, handle, None) |
|
|
| |
| inputs = list(ctx.inputs) |
| tensor_indices = ctx.tensor_indices |
| tensors = ctx.saved_tensors |
|
|
| |
| for i, idx in enumerate(tensor_indices): |
| inputs[idx] = tensors[i] |
|
|
| |
| |
| |
| rng_devices = [] |
| if ctx.preserve_rng_state and ctx.had_device_in_fwd: |
| rng_devices = ctx.fwd_devices |
| with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device): |
| if ctx.preserve_rng_state: |
| torch.set_rng_state(ctx.fwd_cpu_state) |
| if ctx.had_device_in_fwd: |
| set_device_states(ctx.fwd_devices, ctx.fwd_device_states) |
| detached_inputs = detach_variable(tuple(inputs)) |
|
|
| device_autocast_ctx = ( |
| torch.amp.autocast(device_type=ctx.device, **ctx.device_autocast_kwargs) |
| if torch.amp.is_autocast_available(ctx.device) |
| else contextlib.nullcontext() |
| ) |
| with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): |
| outputs = ctx.run_function(*detached_inputs) |
|
|
| if isinstance(outputs, torch.Tensor): |
| outputs = (outputs,) |
|
|
| |
| outputs_with_grad = [] |
| args_with_grad = [] |
| for i in range(len(outputs)): |
| if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: |
| outputs_with_grad.append(outputs[i]) |
| args_with_grad.append(args[i]) |
| if len(outputs_with_grad) == 0: |
| raise RuntimeError("none of output has requires_grad=True, this checkpoint() is not necessary") |
| torch.autograd.backward(outputs_with_grad, args_with_grad) |
| grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) |
|
|
| |
| if handle: |
| _post_backward_hook(state, handle, None) |
|
|
| return (None, None) + grads |
|
|