# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import warp as wp


class Tape:
    """
    Record kernel launches within a Tape scope to enable automatic differentiation.
    Gradients can be computed after the operations have been recorded on the tape via
    ``tape.backward()``.

    Example
    -------

    .. code-block:: python

        tape = wp.Tape()

        # forward pass
        with tape:
            wp.launch(kernel=compute1, inputs=[a, b], device="cuda")
            wp.launch(kernel=compute2, inputs=[c, d], device="cuda")
            wp.launch(kernel=loss, inputs=[d, l], device="cuda")

        # reverse pass
        tape.backward(l)

    Gradients can be accessed via the ``tape.gradients`` dictionary, e.g.:

    .. code-block:: python

        print(tape.gradients[a])
    """

    def __init__(self):
        self.gradients = {}
        self.const_gradients = set()
        self.launches = []

        self.loss = None

    def __enter__(self):
        # activate this tape so that subsequent wp.launch() calls are recorded on it
        if wp.context.runtime.tape is not None:
            raise RuntimeError("Warp: Error, entering a tape while one is already active")

        wp.context.runtime.tape = self

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if wp.context.runtime.tape is None:
            raise RuntimeError("Warp: Error, ended tape capture, but tape not present")

        wp.context.runtime.tape = None

    # after running backward, the gradient of a tensor may be retrieved via:
    #
    #   adj_tensor = tape.gradients[tensor]
    #
    def backward(self, loss: wp.array = None, grads: dict = None):
        """
        Evaluate the backward pass of the recorded operations on the tape.

        A single-element array ``loss`` or a dictionary of arrays ``grads``
        can be provided to assign the incoming gradients for the reverse-mode
        automatic differentiation pass.

        Args:
            loss (wp.array): A single-element array holding the loss value whose gradient is to be computed
            grads (dict): A dictionary mapping Warp arrays to their incoming gradients
        """

        # if a scalar loss was specified, initialize
        # a 'seed' gradient of one for it
        if loss:
            if loss.size > 1 or wp.types.type_length(loss.dtype) > 1:
                raise RuntimeError("Can only return gradients for scalar loss functions.")

            if not loss.requires_grad:
                raise RuntimeError(
                    "Scalar loss arrays should have requires_grad=True set before calling Tape.backward()"
                )

            # set the seed grad to 1.0
            loss.grad.fill_(1.0)

        # apply the user-provided dict of gradients directly to the arrays;
        # this exists for backward compatibility with code written before
        # the wp.array.grad attribute was added
        if grads:
            for a, g in grads.items():
                a.grad = g
                self.const_gradients.add(a)

        # replay launches in reverse order, running their adjoint kernels
        for launch in reversed(self.launches):
            if callable(launch):
                # custom backward function recorded via record_func()
                launch()
            else:
                kernel, dim, max_blocks, inputs, outputs, device = launch

                # lookup adjoint inputs
                adj_inputs = [self.get_adjoint(a) for a in inputs]

                # lookup adjoint outputs, todo: only allocate outputs if necessary
                adj_outputs = [self.get_adjoint(a) for a in outputs]

                wp.launch(
                    kernel=kernel,
                    dim=dim,
                    inputs=inputs,
                    outputs=outputs,
                    adj_inputs=adj_inputs,
                    adj_outputs=adj_outputs,
                    device=device,
                    adjoint=True,
                    max_blocks=max_blocks,
                )

    # record a kernel launch on the tape (called from wp.launch() while a tape is active)
    def record_launch(self, kernel, dim, max_blocks, inputs, outputs, device):
        self.launches.append([kernel, dim, max_blocks, inputs, outputs, device])

    def record_func(self, backward, arrays):
        """
        Records a custom function to be executed only in the backward pass.

        Args:
            backward (Callable): A Python callable to be executed during the backward pass.
            arrays (list): A list of arrays used by the function, required for gradient tracking.
        """
        self.launches.append(backward)

        for a in arrays:
            if isinstance(a, wp.array) and a.grad:
                self.gradients[a] = a.grad
            else:
                raise RuntimeError(
                    f"Array {a} is not of type wp.array or is missing a gradient array. "
                    "Set array parameter requires_grad=True during instantiation."
                )
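
    # Usage sketch for record_func() (``external_adjoint`` is a hypothetical
    # user function that accumulates into x.grad):
    #
    #   x = wp.zeros(n, dtype=float, requires_grad=True)
    #   tape.record_func(backward=lambda: external_adjoint(x), arrays=[x])
    #
    # the callable is then executed in reverse order alongside the adjoint
    # kernel launches when tape.backward() is called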

    # returns the adjoint of a kernel parameter
    def get_adjoint(self, a):
        if not wp.types.is_array(a) and not isinstance(a, wp.codegen.StructInstance):
            # if the input is a simple type (e.g. float, vec3) then no gradient
            # is needed (gradients are only returned through arrays and structs)
            return a
        elif wp.types.is_array(a) and a.grad:
            # keep track of all gradients used by the tape (for zeroing)
            # ignore the scalar loss since we don't want to clear its grad
            self.gradients[a] = a.grad
            return a.grad
        elif isinstance(a, wp.codegen.StructInstance):
            # build a shadow struct that holds gradient arrays for array
            # members and passes all other values through unchanged
            adj = a._cls()
            for name, _ in a._cls.ctype._fields_:
                if name.startswith("_"):
                    continue
                if isinstance(a._cls.vars[name].type, wp.array):
                    arr = getattr(a, name)
                    if arr.grad:
                        grad = self.gradients[arr] = arr.grad
                    else:
                        grad = None
                    setattr(adj, name, grad)
                else:
                    setattr(adj, name, getattr(a, name))

            self.gradients[a] = adj
            return adj

        return None

    def reset(self):
        """
        Clear all operations recorded on the tape and zero out all gradients.
        """
        self.launches = []
        self.zero()

    def zero(self):
        """
        Zero out all gradients recorded on the tape.
        """
        for a, g in self.gradients.items():
            if a not in self.const_gradients:
                if isinstance(a, wp.codegen.StructInstance):
                    # zero the gradient arrays held by the shadow struct
                    for name in g._cls.vars:
                        if isinstance(g._cls.vars[name].type, wp.array) and g._cls.vars[name].requires_grad:
                            getattr(g, name).zero_()
                else:
                    g.zero_()
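

# Usage sketch (illustrative): a single tape can be reused across optimization
# iterations by resetting it each step; ``step_kernel``, ``loss_kernel``,
# ``x``, ``l``, and ``n`` are assumed to be defined elsewhere, with
# requires_grad=True set on the differentiated arrays:
#
#   tape = wp.Tape()
#   for _ in range(100):
#       tape.reset()  # clear recorded launches and zero accumulated grads
#       with tape:
#           wp.launch(kernel=step_kernel, dim=n, inputs=[x], device="cuda")
#           wp.launch(kernel=loss_kernel, dim=1, inputs=[x, l], device="cuda")
#       tape.backward(l)  # seeds l.grad with 1.0 and runs adjoint kernels
#       # ...apply a gradient-descent update to x using x.grad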