| """ |
| Checkpoint Utils for Models |
| |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) |
| Please cite our work if the code is helpful to you. |
| """ |
|
|
| import torch |
|
|
|
|
class CheckpointFunction(torch.autograd.Function):
    """Autograd function that recomputes ``run_function`` in the backward pass.

    Forward runs ``run_function`` under ``torch.no_grad()``, so no intermediate
    activations are recorded. Backward re-runs it with grad enabled and
    differentiates the recomputed outputs.

    ``apply`` takes ``(run_function, length, *args)``:

    * The first ``length`` entries of ``args`` are passed positionally to
      ``run_function``.
    * The remaining entries are extra tensors (e.g. parameters) that
      ``run_function`` uses without receiving them as arguments. They are
      passed here so autograd can route gradients to them.

    NOTE(review): inputs are kept as plain ``ctx`` attributes instead of
    ``ctx.save_for_backward``, so in-place modification of them between
    forward and backward is not detected. RNG state (e.g. dropout) and
    autocast state are not replayed during recomputation either. Confirm that
    ``run_function`` is deterministic under these conditions.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        # No graph is recorded here; it is rebuilt from scratch in backward.
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # needs_input_grad mirrors apply()'s arguments:
        # (run_function, length, *args). Drop the first two entries so the
        # flags line up with input_tensors + input_params.
        needs_grad = ctx.needs_input_grad[2:]

        # Build fresh detached leaves for the recomputation. Only inputs that
        # required grad in forward are marked as requiring grad. Calling
        # requires_grad_(True) on integer tensors would raise, and
        # non-tensor inputs are passed through unchanged.
        leaves = []
        for value, wants_grad in zip(ctx.input_tensors, needs_grad):
            if isinstance(value, torch.Tensor):
                value = value.detach()
                if wants_grad:
                    value.requires_grad_(True)
            leaves.append(value)

        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors: hand run_function views instead of the leaves.
            shallow_copies = [
                x.view_as(x) if isinstance(x, torch.Tensor) else x for x in leaves
            ]
            output_tensors = ctx.run_function(*shallow_copies)

        if isinstance(output_tensors, torch.Tensor):
            output_tensors = (output_tensors,)
        # Only outputs that are part of the recomputed graph can be
        # differentiated; autograd.grad raises on the others.
        diff_pairs = [
            (out, grad)
            for out, grad in zip(output_tensors, output_grads)
            if isinstance(out, torch.Tensor) and out.requires_grad
        ]

        # Differentiate only w.r.t. tensors that need a gradient. Passing a
        # frozen parameter to autograd.grad raises even with allow_unused=True.
        candidates = leaves + ctx.input_params
        wanted = [i for i, flag in enumerate(needs_grad) if flag]
        input_grads = [None] * len(candidates)
        if wanted and diff_pairs:
            outs, grads_out = zip(*diff_pairs)
            computed = torch.autograd.grad(
                outs,
                [candidates[i] for i in wanted],
                grads_out,
                allow_unused=True,
            )
            for i, grad in zip(wanted, computed):
                input_grads[i] = grad

        # Drop references so recomputed activations and inputs can be freed.
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # No gradients for run_function and length.
        return (None, None) + tuple(input_grads)
|
|
|
|
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the arguments to pass to `func`. Any iterable is accepted;
                   it is materialized into a tuple exactly once.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    :return: the output of `func(*inputs)`.
    """
    # Materialize once, so len() works for generators/iterators and the
    # argument count matches the values actually passed to `func`.
    inputs = tuple(inputs)
    if not flag:
        return func(*inputs)
    return CheckpointFunction.apply(func, len(inputs), *inputs, *params)
|
|