File size: 9,915 Bytes
ee3e701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import weakref

import torch
from torch.utils.checkpoint import check_backward_validity, detach_variable

from internlm.core.context.random import (
    get_current_mode,
    get_states,
    set_mode,
    set_seed_states,
    sync_states,
)

from .common import get_current_device


def copy_to_device(obj, device):
    """Recursively copy ``obj`` onto ``device``.

    Tensors are moved with ``Tensor.to`` and detached from the autograd graph; the
    ``requires_grad`` flag of the source tensor is then re-applied to the copy.
    Lists, tuples and dicts are rebuilt with their items (dict values) copied the
    same way. Anything else is handed back untouched.

    Args:
        obj: A tensor, a (possibly nested) list/tuple/dict, or any other object.
        device: Target device, forwarded to ``Tensor.to``.

    Returns:
        The copied structure, or ``obj`` itself when it is neither a tensor nor one
        of the supported containers.
    """
    if not torch.is_tensor(obj):
        if isinstance(obj, list):
            return [copy_to_device(item, device) for item in obj]
        if isinstance(obj, tuple):
            return tuple(copy_to_device(item, device) for item in obj)
        if isinstance(obj, dict):
            return {key: copy_to_device(value, device) for key, value in obj.items()}
        return obj

    # Note: inside a no_grad context the moved tensor comes back with
    # requires_grad=False, so the flag is copied over explicitly.
    moved = obj.to(device).detach()
    moved.requires_grad = obj.requires_grad
    return moved


class CheckpointFunction(torch.autograd.Function):
    """
    Reentrant activation checkpoint implemented as a custom autograd Function.

    ``forward`` runs ``run_function`` under ``torch.no_grad()`` and records the inputs
    together with the CPU RNG state, the seed states and the current seed mode.
    ``backward`` re-installs those states, re-runs ``run_function`` with gradients
    enabled and back-propagates through the recomputed graph.
    """

    @staticmethod
    def forward(ctx, run_function, activation_offload=False, *args):  # pylint: disable=W1113
        """Run ``run_function(*args)`` without building an autograd graph.

        Args:
            ctx: Autograd context used to stash everything ``backward`` needs.
            run_function: Callable to checkpoint; it is called again in ``backward``.
            activation_offload: When true, tensor inputs are kept as CPU copies on ``ctx``
                instead of being registered with ``ctx.save_for_backward``.
            *args: Positional inputs of ``run_function``.

        Returns:
            Whatever ``run_function`` returns.
        """
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.activation_offload = activation_offload
        ctx.device = get_current_device()

        # preserve rng states
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        sync_states()
        ctx.fwd_seed_states = get_states(copy=True)
        ctx.fwd_current_mode = get_current_mode()

        if hasattr(torch, "is_autocast_enabled"):
            ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        else:
            ctx.had_autocast_in_fwd = False

        # Remember the autocast dtype used by the forward pass so that the recomputation
        # in backward does not silently fall back to the autocast default dtype.
        ctx.autocast_kwargs = {}
        if ctx.had_autocast_in_fwd:
            if hasattr(torch, "get_autocast_dtype"):
                ctx.autocast_kwargs["dtype"] = torch.get_autocast_dtype("cuda")
            elif hasattr(torch, "get_autocast_gpu_dtype"):
                ctx.autocast_kwargs["dtype"] = torch.get_autocast_gpu_dtype()

        if activation_offload:
            inputs_cuda = copy_to_device(args, ctx.device)
        else:
            inputs_cuda = args

        with torch.no_grad():
            outputs = run_function(*inputs_cuda)
        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                if activation_offload:
                    tensor_inputs.append(copy_to_device(arg, "cpu"))
                else:
                    tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        if activation_offload:
            # CPU copies are plain attributes: they are not handed to save_for_backward.
            ctx.tensor_inputs = tensor_inputs
        else:
            ctx.save_for_backward(*tensor_inputs)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward pass and back-propagate ``args`` through it.

        Args:
            ctx: Autograd context filled in by ``forward``.
            *args: Gradients with respect to the outputs of ``forward``.

        Returns:
            A tuple ``(None, None, *grads)``: no gradient for ``run_function`` and
            ``activation_offload``, then one entry per forward input (``None`` for
            non-tensor inputs).

        Raises:
            RuntimeError: If checkpointing is used with ``.grad()`` / an ``inputs`` argument
                to ``.backward()``, or if no recomputed output requires grad.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
                "passed to .backward(). Please use .backward() and do not pass its `inputs` argument."
            )
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices

        if ctx.activation_offload:
            tensors = ctx.tensor_inputs
        else:
            tensors = ctx.saved_tensors

        # store the current states
        bwd_cpu_rng_state = torch.get_rng_state()
        sync_states()
        bwd_seed_states = get_states(copy=True)
        bwd_current_mode = get_current_mode()

        # set the states to what it used to be
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        for parallel_mode, state in ctx.fwd_seed_states.items():
            set_seed_states(parallel_mode, state)
        set_mode(ctx.fwd_current_mode)

        try:
            if ctx.activation_offload:
                tensors = copy_to_device(tensors, ctx.device)

            # Fill in inputs with appropriate saved tensors.
            for i, idx in enumerate(tensor_indices):
                inputs[idx] = tensors[i]
            detached_inputs = detach_variable(tuple(inputs))
            if ctx.had_autocast_in_fwd:
                with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.autocast_kwargs):
                    outputs = ctx.run_function(*detached_inputs)
            else:
                with torch.enable_grad():
                    outputs = ctx.run_function(*detached_inputs)
        finally:
            # recover the rng states, even when the recomputation raised
            torch.set_rng_state(bwd_cpu_rng_state)
            for parallel_mode, state in bwd_seed_states.items():
                set_seed_states(parallel_mode, state)
            set_mode(bwd_current_mode)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensor that requires grad
        outputs_with_grad = []
        args_with_grad = []
        for output, grad_output in zip(outputs, args):
            if torch.is_tensor(output) and output.requires_grad:
                outputs_with_grad.append(output)
                args_with_grad.append(grad_output)
        if len(outputs_with_grad) == 0:
            raise RuntimeError("none of output has requires_grad=True," " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
        # The two leading None entries match run_function and activation_offload.
        return (None, None) + grads


def activation_checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
    """Run ``function(*args)`` under activation checkpointing.

    Adapted from PyTorch's ``torch.utils.checkpoint``; this entry point only selects
    which of the two checkpoint implementations in this module handles the call.

    Args:
        function: The forward computation to checkpoint. It is called with ``*args``.
        activation_offload: Flag forwarded to the selected implementation to request
            offloading of activations to CPU.
        args: Positional arguments passed through to ``function``.
        use_reentrant: When True (default) the call goes through ``CheckpointFunction.apply``;
            when False it goes through ``_checkpoint_without_reentrant``.

    Returns:
        The value returned by the selected checkpoint implementation.
    """
    if not use_reentrant:
        return _checkpoint_without_reentrant(function, activation_offload, *args)

    return CheckpointFunction.apply(function, activation_offload, *args)


def _checkpoint_without_reentrant(function, activation_offload=False, *args):  # pylint: disable=W1113
    """Checkpoint ``function(*args)`` using saved-tensor hooks (non-reentrant variant).

    During the forward pass every tensor autograd wants to save is replaced by a
    lightweight ``Holder`` object. The first time one of those holders is unpacked,
    ``function`` is re-run with the forward RNG/seed states installed and the
    recomputed tensors are stored, detached, keyed by their holder.

    Args:
        function: The forward computation to checkpoint.
        activation_offload: When true, the inputs are kept as detached CPU copies
            (see ``copy_to_device``) between forward and recomputation, and copied
            back to the current device right before ``function`` is re-run.
        *args: Positional inputs of ``function``.

    Returns:
        The output of ``function(*args)``.

    Raises:
        RuntimeError: From the unpack hook, if a saved tensor is requested that the
            recomputation did not produce.
    """
    # store rng_state
    fwd_cpu_state = torch.get_rng_state()
    sync_states()
    fwd_seed_states = get_states(copy=True)
    fwd_current_mode = get_current_mode()

    # check if use autocast
    if hasattr(torch, "is_autocast_enabled"):
        has_autocast_in_fwd = torch.is_autocast_enabled()
    else:
        has_autocast_in_fwd = False

    # Remember the forward autocast dtype so the recomputation uses the same one
    # instead of the autocast default.
    autocast_kwargs = {}
    if has_autocast_in_fwd:
        if hasattr(torch, "get_autocast_dtype"):
            autocast_kwargs["dtype"] = torch.get_autocast_dtype("cuda")
        elif hasattr(torch, "get_autocast_gpu_dtype"):
            autocast_kwargs["dtype"] = torch.get_autocast_gpu_dtype()

    # get device if we need to offload the activation; always bound so that the
    # unpack hook never hits an undefined name
    device = get_current_device() if activation_offload else None

    # Inputs used for the recomputation. Replaced by CPU copies after the forward
    # pass when offloading is requested.
    saved_args = args

    # using WeakKeyDictionary to store all the activation the first time we call unpack
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    # class for weakref.ref
    class Holder:
        pass

    # return a Holder object for later unpack process
    def pack(_tensor):
        # autograd calls the pack hook with the tensor being saved; the tensor itself
        # is deliberately not kept, only a placeholder is returned.
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    # unpack hook
    def unpack(x):
        unpack_counter = 0

        # re-compute all the activation inside the function when we first call unpack
        if len(storage) == 0:

            def inner_pack(inner):
                nonlocal unpack_counter
                unpack_counter += 1

                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                holder = weak_holder_list[unpack_counter - 1]()
                if holder is None:
                    return

                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[holder] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # remember the states of the backward pass so they can be put back afterwards
            bwd_cpu_state = torch.get_rng_state()
            sync_states()
            bwd_seed_states = get_states(copy=True)
            bwd_current_mode = get_current_mode()

            # restore rng state of the forward pass
            torch.set_rng_state(fwd_cpu_state)
            for parallel_mode, state in fwd_seed_states.items():
                set_seed_states(parallel_mode, state)
            set_mode(fwd_current_mode)

            try:
                # reload args into device if needed
                if activation_offload:
                    inputs = copy_to_device(saved_args, device)
                else:
                    inputs = saved_args

                # rerun forward, the inner_pack will store all the activations in storage
                if has_autocast_in_fwd:
                    with torch.enable_grad(), torch.cuda.amp.autocast(
                        **autocast_kwargs
                    ), torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                        function(*inputs)
                else:
                    with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                        function(*inputs)
            finally:
                # recover the rng states of the backward pass
                torch.set_rng_state(bwd_cpu_state)
                for parallel_mode, state in bwd_seed_states.items():
                    set_seed_states(parallel_mode, state)
                set_mode(bwd_current_mode)

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this."
            )

        return storage[x]

    # run function with pack and unpack as saved_tensors_hooks
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args)

    # offload activation if needed: from here on the hooks only reference the CPU copies
    if activation_offload:
        saved_args = copy_to_device(args, "cpu")

    return output