File size: 7,479 Bytes
66c9c8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# Copyright (c) 2022 NVIDIA CORPORATION.  All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import warp as wp


class Tape:
    """
    Record kernel launches within a Tape scope to enable automatic differentiation.
    Gradients can be computed after the operations have been recorded on the tape via
    ``tape.backward()``.

    Only one tape can be active at a time: entering a tape while another one is
    registered on the Warp runtime raises a ``RuntimeError``.

    Example
    -------

    .. code-block:: python

        tape = wp.Tape()

        # forward pass
        with tape:
            wp.launch(kernel=compute1, inputs=[a, b], device="cuda")
            wp.launch(kernel=compute2, inputs=[c, d], device="cuda")
            wp.launch(kernel=loss, inputs=[d, l], device="cuda")

        # reverse pass
        tape.backward(l)

    Gradients can be accessed via the ``tape.gradients`` dictionary, e.g.:

    .. code-block:: python

        print(tape.gradients[a])

    """

    def __init__(self):
        # maps every array / struct instance touched by the backward pass
        # (or registered through record_func) to its adjoint
        self.gradients = {}
        # arrays whose gradients were supplied through ``backward(grads=...)``;
        # zero() leaves these untouched
        self.const_gradients = set()
        # recorded operations in forward order: either a 6-element list
        # [kernel, dim, max_blocks, inputs, outputs, device] or a plain callable
        self.launches = []

        # kept for backward compatibility; not read by any method of this class
        self.loss = None

    def __enter__(self):
        """Register this tape as the active tape on the Warp runtime."""
        if wp.context.runtime.tape is not None:
            raise RuntimeError("Warp: Error, entering a tape while one is already active")

        wp.context.runtime.tape = self

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Unregister the active tape; exceptions from the ``with`` body are not suppressed."""
        if wp.context.runtime.tape is None:
            raise RuntimeError("Warp: Error, ended tape capture, but tape not present")

        wp.context.runtime.tape = None

    def backward(self, loss: wp.array = None, grads: dict = None):
        """
        Evaluate the backward pass of the recorded operations on the tape.
        A single-element array ``loss`` or a dictionary of arrays ``grads``
        can be provided to assign the incoming gradients for the reverse-mode
        automatic differentiation pass.

        After the call, the gradient of any array that took part in the
        recorded launches can be retrieved with ``tape.gradients[array]``
        (or directly through ``array.grad``).

        Args:
            loss (wp.array): A single-element array that holds the loss function value whose gradient is to be computed
            grads (dict): A dictionary of arrays that map from Warp arrays to their incoming gradients

        Raises:
            RuntimeError: If ``loss`` is not a single scalar element, or if it
                was not created with ``requires_grad=True``.
        """
        # An explicit None check is used instead of ``if loss:`` so the result
        # does not depend on how the array type implements truthiness (an
        # empty array must be rejected, not silently skipped).
        if loss is not None:
            if loss.size != 1 or wp.types.type_length(loss.dtype) > 1:
                raise RuntimeError("Can only return gradients for scalar loss functions.")

            if not loss.requires_grad:
                raise RuntimeError(
                    "Scalar loss arrays should have requires_grad=True set before calling Tape.backward()"
                )

            # seed the reverse pass: d(loss)/d(loss) = 1
            loss.grad.fill_(1.0)

        # simply apply dict grads to objects
        # this is just for backward compat. with
        # existing code before we added wp.array.grad attribute
        if grads:
            for a, g in grads.items():
                a.grad = g
                # remember caller-provided gradients so zero() does not clear them
                self.const_gradients.add(a)

        # replay the recorded operations in reverse order
        for launch in reversed(self.launches):
            if callable(launch):
                # custom backward function registered through record_func()
                launch()
                continue

            kernel, dim, max_blocks, inputs, outputs, device = launch[:6]

            # look up adjoints (inputs first, then outputs);
            # todo: only allocate outputs if necessary
            adj_inputs = [self.get_adjoint(a) for a in inputs]
            adj_outputs = [self.get_adjoint(a) for a in outputs]

            wp.launch(
                kernel=kernel,
                dim=dim,
                inputs=inputs,
                outputs=outputs,
                adj_inputs=adj_inputs,
                adj_outputs=adj_outputs,
                device=device,
                adjoint=True,
                max_blocks=max_blocks,
            )

    def record_launch(self, kernel, dim, max_blocks, inputs, outputs, device):
        """Record a kernel launch on the tape so that backward() can replay its adjoint."""
        self.launches.append([kernel, dim, max_blocks, inputs, outputs, device])

    def record_func(self, backward, arrays):
        """
        Records a custom function to be executed only in the backward pass.

        Args:
            backward (Callable): A callable Python object (can be any function) that will be executed in the backward pass.
            arrays (list): A list of arrays that are used by the function for gradient tracking.

        Raises:
            RuntimeError: If any entry of ``arrays`` is not a ``wp.array`` or has
                no gradient array. The tape is left unmodified in that case.
        """
        # materialize once so that one-shot iterables can be traversed twice
        arrays = list(arrays)

        # validate everything up front so a failure does not leave a
        # half-recorded operation on the tape
        for a in arrays:
            if not isinstance(a, wp.array) or a.grad is None:
                raise RuntimeError(
                    f"Array {a} is not of type wp.array or is missing a gradient array. Set array parameter requires_grad=True during instantiation."
                )

        self.launches.append(backward)

        # track the gradients so that zero() / reset() clear them
        for a in arrays:
            self.gradients[a] = a.grad

    def get_adjoint(self, a):
        """
        Return the adjoint of a kernel parameter.

        * Values that are neither arrays nor struct instances are returned unchanged.
        * Arrays return their ``grad`` array (``None`` if they have none).
        * Struct instances return a new instance of the same struct type whose
          array fields hold the corresponding gradients (``None`` where absent)
          and whose remaining fields are copied from ``a``.

        Every gradient handed out is also stored in ``self.gradients`` so it
        can be zeroed later.
        """
        if wp.types.is_array(a):
            if a.grad is None:
                return None

            # keep track of all gradients used by the tape (for zeroing)
            self.gradients[a] = a.grad
            return a.grad

        if not isinstance(a, wp.codegen.StructInstance):
            # if input is a simple type (e.g.: float, vec3, etc) then
            # no gradient needed (we only return gradients through arrays and structs)
            return a

        adj = a._cls()
        for name, _ in a._cls.ctype._fields_:
            # skip private ctypes fields
            if name.startswith("_"):
                continue

            if isinstance(a._cls.vars[name].type, wp.array):
                arr = getattr(a, name)
                # tolerate unset array fields as well as arrays without a gradient
                if arr is not None and arr.grad is not None:
                    grad = self.gradients[arr] = arr.grad
                else:
                    grad = None
                setattr(adj, name, grad)
            else:
                setattr(adj, name, getattr(a, name))

        self.gradients[a] = adj
        return adj

    def reset(self):
        """
        Clear all operations recorded on the tape and zero out all gradients.
        """
        self.launches = []
        self.zero()

    def zero(self):
        """
        Zero out all gradients recorded on the tape.

        Gradients that were supplied by the caller through
        ``backward(grads=...)`` are left untouched.
        """
        for a, g in self.gradients.items():
            if a in self.const_gradients:
                continue

            if isinstance(a, wp.codegen.StructInstance):
                for name, var in g._cls.vars.items():
                    if isinstance(var.type, wp.array) and var.requires_grad:
                        field_grad = getattr(g, name)
                        # get_adjoint() stores None for array fields without a gradient
                        if field_grad is not None:
                            field_grad.zero_()
            else:
                g.zero_()