|
|
|
|
|
|
|
|
|
|
|
import weakref |
|
|
|
|
|
import torch |
|
|
from torch.utils.checkpoint import check_backward_validity, detach_variable |
|
|
|
|
|
from internlm.core.context.random import ( |
|
|
get_current_mode, |
|
|
get_states, |
|
|
set_mode, |
|
|
set_seed_states, |
|
|
sync_states, |
|
|
) |
|
|
|
|
|
from .common import get_current_device |
|
|
|
|
|
|
|
|
def copy_to_device(obj, device):
    """Recursively move ``obj`` to ``device``.

    Tensors are moved and detached from the autograd graph, with their
    ``requires_grad`` flag preserved on the copy.  Lists, tuples and dicts
    are traversed recursively (containers are rebuilt with the same type);
    any other object is returned unchanged.
    """
    if torch.is_tensor(obj):
        moved = obj.to(device).detach()
        # ``detach`` clears requires_grad; restore it so the copy can still
        # participate in a fresh autograd graph.
        moved.requires_grad = obj.requires_grad
        return moved
    if isinstance(obj, (list, tuple)):
        items = [copy_to_device(item, device) for item in obj]
        return items if isinstance(obj, list) else tuple(items)
    if isinstance(obj, dict):
        return {key: copy_to_device(value, device) for key, value in obj.items()}
    return obj
|
|
|
|
|
|
|
|
class CheckpointFunction(torch.autograd.Function):
    """
    Activation-checkpointing autograd function, modified from
    ``torch.utils.checkpoint.CheckpointFunction``.

    ``forward`` runs ``run_function`` under ``torch.no_grad`` and stores only
    the inputs (optionally offloaded to CPU) together with the CPU RNG state
    and the parallel-mode seed states.  ``backward`` restores those RNG
    states, recomputes the forward pass under ``torch.enable_grad`` and then
    backpropagates through the recomputed graph, so activations never have to
    be kept alive between forward and backward.
    """

    @staticmethod
    def forward(ctx, run_function, activation_offload=False, *args):
        # Warns if none of the inputs requires grad (checkpointing would be pointless).
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.activation_offload = activation_offload
        ctx.device = get_current_device()

        # Capture the CPU RNG state and the parallel-mode seed states so the
        # recomputation in backward() reproduces the same random numbers
        # (e.g. identical dropout masks).
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        sync_states()
        ctx.fwd_seed_states = get_states(copy=True)
        ctx.fwd_current_mode = get_current_mode()

        # Remember whether autocast was active so the recomputation can run
        # under the same autocast context (guarded for old torch versions).
        if hasattr(torch, "is_autocast_enabled"):
            ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        else:
            ctx.had_autocast_in_fwd = False

        # With offload enabled, make sure the inputs used for the forward run
        # live on the current device.
        if activation_offload:
            inputs_cuda = copy_to_device(args, ctx.device)
        else:
            inputs_cuda = args

        # Run the forward pass without building an autograd graph; the graph
        # is rebuilt during backward().
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        # Split args into tensors and non-tensors: tensors go through
        # save_for_backward (or an offloaded CPU list), non-tensors are kept
        # directly on ctx.  ``tensor_indices`` records where each saved tensor
        # belongs so the original argument order can be restored in backward().
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                if activation_offload:
                    tensor_inputs.append(copy_to_device(arg, "cpu"))
                else:
                    tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        # Offloaded tensors bypass save_for_backward on purpose: they live on
        # CPU and are moved back to the device in backward().
        if activation_offload:
            ctx.tensor_inputs = tensor_inputs
        else:
            ctx.save_for_backward(*tensor_inputs)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        # Reentrant checkpointing is incompatible with torch.autograd.grad()
        # and with backward(inputs=...); detect and fail loudly.
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
                "passed to .backward(). Please use .backward() and do not pass its `inputs` argument."
            )

        # Rebuild the original argument list skeleton saved in forward().
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices

        if ctx.activation_offload:
            tensors = ctx.tensor_inputs
        else:
            tensors = ctx.saved_tensors

        # Snapshot the *current* (backward-time) RNG states so they can be
        # restored after the recomputation below.
        bwd_cpu_rng_state = torch.get_rng_state()
        sync_states()
        bwd_seed_states = get_states(copy=True)
        bwd_current_mode = get_current_mode()

        # Rewind RNG to the forward-time states so the recomputed forward is
        # bitwise-identical to the original one.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        for parallel_mode, state in ctx.fwd_seed_states.items():
            set_seed_states(parallel_mode, state)
        set_mode(ctx.fwd_current_mode)
        if ctx.activation_offload:
            # Bring the offloaded inputs back from CPU to the compute device.
            tensors = copy_to_device(tensors, ctx.device)

        # Slot the saved tensors back into their original positions, then
        # recompute the forward with grad enabled (and autocast if it was on).
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]
        detached_inputs = detach_variable(tuple(inputs))
        if ctx.had_autocast_in_fwd:
            with torch.enable_grad(), torch.cuda.amp.autocast():
                outputs = ctx.run_function(*detached_inputs)
        else:
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Restore the backward-time RNG states captured above.
        torch.set_rng_state(bwd_cpu_rng_state)
        for parallel_mode, state in bwd_seed_states.items():
            set_seed_states(parallel_mode, state)
        set_mode(bwd_current_mode)

        # Only differentiable outputs participate in backward; pair each with
        # its incoming gradient (args[i] aligns with outputs[i]).
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError("none of output has requires_grad=True," " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        # Two leading Nones match forward's (run_function, activation_offload)
        # parameters, which receive no gradient.
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
        return (None, None) + grads
|
|
|
|
|
|
|
|
def activation_checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
    """Checkpoint the computation while preserving the rng states, modified from
    Pytorch torch.utils.checkpoint.

    Args:
        function: Describe the forward pass function. It should know how to handle the input tuples.
        activation_offload: The variable to check whether we should offload activation to cpu
        args (list): Tuple containing the parameters of the function
        use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there
            might be more flexibility for user to define there checkpoint function

    Returns:
        Output of running function with provided args.
    """
    # Non-reentrant mode relies on saved-tensor hooks instead of a custom
    # autograd.Function, which gives user code more flexibility.
    if not use_reentrant:
        return _checkpoint_without_reentrant(function, activation_offload, *args)
    return CheckpointFunction.apply(function, activation_offload, *args)
|
|
|
|
|
|
|
|
def _checkpoint_without_reentrant(function, activation_offload=False, *args):
    """Non-reentrant activation checkpointing via ``saved_tensors_hooks``.

    During the outer forward, every tensor autograd would save is replaced by
    a lightweight ``Holder`` placeholder (nothing is kept).  On the first
    ``unpack`` call during backward, the whole forward is re-executed with the
    forward-time RNG states restored, and an inner pack hook fills ``storage``
    with the recomputed tensors, which are then handed back to autograd.

    Args:
        function: Forward-pass callable; receives ``*args``.
        activation_offload: If True, attempts to move tensor args between CPU
            and the current device around the forward pass.
        args: Positional arguments forwarded to ``function``.

    Returns:
        Output of running ``function(*args)``.
    """
    # Capture forward-time RNG states so the recomputation in unpack() can
    # reproduce the same random numbers (e.g. dropout masks).
    fwd_cpu_state = torch.get_rng_state()
    sync_states()
    fwd_seed_states = get_states(copy=True)
    fwd_current_mode = get_current_mode()

    # Remember whether autocast was active (guarded for old torch versions).
    if hasattr(torch, "is_autocast_enabled"):
        has_autocast_in_fwd = torch.is_autocast_enabled()
    else:
        has_autocast_in_fwd = False

    # Weak references everywhere so that tensors which autograd no longer
    # needs are not kept alive by the checkpoint machinery.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    class Holder:
        # Placeholder object saved in the graph instead of the real tensor.
        pass

    # BUGFIX: saved_tensors_hooks invokes the pack hook with the tensor being
    # saved; the previous ``def pack():`` (no parameter) raised TypeError as
    # soon as autograd saved a tensor.  The tensor is intentionally ignored —
    # only a Holder placeholder is recorded.
    def pack(x):
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        unpack_counter = 0

        # Empty storage means this is the first unpack: recompute the forward
        # once and populate storage for every holder.
        if len(storage) == 0:

            def inner_pack(inner):
                nonlocal unpack_counter
                unpack_counter += 1

                # If the holder died, autograd no longer needs this tensor.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return

                # Use detach here to ensure we don't keep the temporary
                # autograd graph created during the second forward alive.
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # Rewind RNG to the forward-time states for a bitwise-identical
            # recomputation.
            torch.set_rng_state(fwd_cpu_state)
            for parallel_mode, state in fwd_seed_states.items():
                set_seed_states(parallel_mode, state)
            set_mode(fwd_current_mode)

            # NOTE(review): rebinding the loop variable (``arg = arg.to(...)``)
            # does not modify ``args`` — this offload branch looks like a
            # no-op; confirm intended behavior.  ``device`` is the closure
            # variable assigned before the outer forward below.
            if activation_offload:
                for arg in args:
                    if torch.is_tensor(arg):
                        arg = arg.to(device=device)

            # Re-run the forward with grad enabled (and autocast if it was
            # on), letting inner_pack capture the recomputed saved tensors.
            if has_autocast_in_fwd:
                with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks(
                    inner_pack, inner_unpack
                ):
                    function(*args)
            else:
                with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    function(*args)

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this."
            )

        return storage[x]

    # Get the device early so the recomputation inside unpack() can refer to it.
    if activation_offload:
        device = get_current_device()

    # Outer forward: saves only Holder placeholders instead of activations.
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args)

    # NOTE(review): same rebinding no-op as above — ``args`` is not actually
    # moved to CPU here; confirm intended behavior.
    if activation_offload:
        for arg in args:
            if torch.is_tensor(arg):
                arg = arg.to(device="cpu")

    return output
|
|
|