# NOTE(review): removed non-code residue ("Spaces / Sleeping") left over from web-page extraction.
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import unittest

import numpy as np

import warp as wp
from warp.tests.unittest_utils import *

# initialize the Warp runtime before any kernels/arrays are created
wp.init()
# atomic add function that memorizes which thread incremented the counter
# so that the correct counter value per thread can be used in the replay
# phase of the backward pass
# NOTE(review): @wp.func restored — the mangled source dropped all decorators;
# this function is called from inside a launched kernel, which requires it.
@wp.func
def reversible_increment(
    counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
):
    # record the pre-increment counter value for this thread so the replay
    # function can return the same index deterministically in the backward pass
    next_index = wp.atomic_add(counter, counter_index, value)
    thread_values[tid] = next_index
    return next_index
# replay version of reversible_increment: during the backward pass it returns
# the counter value memorized in the forward pass instead of re-running the
# (non-deterministic) atomic add.
# NOTE(review): @wp.func_replay registration restored per Warp's custom-replay
# API — confirm against the warp.func_replay documentation.
@wp.func_replay(reversible_increment)
def replay_reversible_increment(
    counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
):
    return thread_values[tid]
def test_custom_replay_grad(test, device):
    """Verify gradients are correct when a custom replay function is used.

    Each thread squares one input element at an index obtained from an atomic
    counter; the replay function must reproduce the same index per thread so
    the backward pass matches the forward pass.
    """
    num_threads = 128
    counter = wp.zeros(1, dtype=wp.int32, device=device)
    thread_ids = wp.zeros(num_threads, dtype=wp.int32, device=device)
    inputs = wp.array(np.arange(num_threads, dtype=np.float32), device=device, requires_grad=True)
    outputs = wp.zeros_like(inputs)

    # NOTE(review): @wp.kernel restored — this function is passed to wp.launch below.
    @wp.kernel
    def run_atomic_add(
        input: wp.array(dtype=float),
        counter: wp.array(dtype=int),
        thread_values: wp.array(dtype=int),
        output: wp.array(dtype=float),
    ):
        tid = wp.tid()
        idx = reversible_increment(counter, 0, 1, thread_values, tid)
        output[idx] = input[idx] ** 2.0

    tape = wp.Tape()
    with tape:
        wp.launch(
            run_atomic_add, dim=num_threads, inputs=[inputs, counter, thread_ids], outputs=[outputs], device=device
        )
    tape.backward(grads={outputs: wp.array(np.ones(num_threads, dtype=np.float32), device=device)})
    # d(x^2)/dx = 2x
    assert_np_equal(inputs.grad.numpy(), 2.0 * inputs.numpy(), tol=1e-4)
# float overload used by test_custom_overload_grad (a struct overload of the
# same name is defined below).
# NOTE(review): @wp.func restored — the mangled source dropped all decorators.
@wp.func
def overload_fn(x: float, y: float):
    return x * 3.0 + y / 3.0, y**2.5
# custom gradient for the float overload_fn; the factors are deliberately
# non-physical so the test can detect that this adjoint (not autodiff) ran.
# NOTE(review): @wp.func_grad registration restored per Warp's custom-gradient
# API — confirm against the warp.func_grad documentation.
@wp.func_grad(overload_fn)
def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
    wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
    wp.adjoint[y] += y * adj_ret1 * 3.0
# NOTE(review): @wp.struct restored — MyStruct is used as wp.array(dtype=MyStruct) below,
# which requires a Warp struct type.
@wp.struct
class MyStruct:
    # scalar component exercised by the custom gradient below
    scalar: float
    # 3-vector component exercised by the custom gradient below
    vec: wp.vec3
# struct overload of overload_fn (same name as the float version above).
# NOTE(review): @wp.func restored — the mangled source dropped all decorators.
@wp.func
def overload_fn(x: MyStruct):
    return x.vec[0] * x.vec[1] * x.vec[2] * 4.0, wp.length(x.vec), x.scalar**0.5
# custom gradient for the struct overload_fn; arbitrary factors chosen so the
# assertions in test_custom_overload_grad can verify this adjoint was used.
# NOTE(review): @wp.func_grad registration restored per Warp's custom-gradient
# API — confirm against the warp.func_grad documentation.
@wp.func_grad(overload_fn)
def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: float):
    wp.adjoint[x.scalar] += x.scalar * adj_ret0 * 10.0
    wp.adjoint[x.vec][0] += adj_ret0 * x.vec[1] * x.vec[2] * 20.0
    wp.adjoint[x.vec][1] += adj_ret1 * x.vec[0] * x.vec[2] * 30.0
    wp.adjoint[x.vec][2] += adj_ret2 * x.vec[0] * x.vec[1] * 40.0
# kernel driving the float overload of overload_fn.
# NOTE(review): @wp.kernel restored — this function is passed to wp.launch below.
@wp.kernel
def run_overload_float_fn(
    xs: wp.array(dtype=float), ys: wp.array(dtype=float), output0: wp.array(dtype=float), output1: wp.array(dtype=float)
):
    i = wp.tid()
    out0, out1 = overload_fn(xs[i], ys[i])
    output0[i] = out0
    output1[i] = out1
# kernel driving the struct overload of overload_fn; sums the three returns
# so a single scalar output carries gradient to every struct member.
# NOTE(review): @wp.kernel restored — this function is passed to wp.launch below.
@wp.kernel
def run_overload_struct_fn(xs: wp.array(dtype=MyStruct), output: wp.array(dtype=float)):
    i = wp.tid()
    out0, out1, out2 = overload_fn(xs[i])
    output[i] = out0 + out1 + out2
def test_custom_overload_grad(test, device):
    """Verify that custom gradients are applied for both overloads of overload_fn.

    The float overload is checked against the factors 42/10/3 used in its
    custom adjoint; the struct overload against the factors 10/20/30/40.
    """
    # --- float overload ---
    dim = 3
    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True)
    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True)
    out0_float = wp.zeros(dim)
    out1_float = wp.zeros(dim)
    tape = wp.Tape()
    with tape:
        wp.launch(run_overload_float_fn, dim=dim, inputs=[xs_float, ys_float], outputs=[out0_float, out1_float])
    tape.backward(
        grads={
            out0_float: wp.array(np.ones(dim), dtype=wp.float32),
            out1_float: wp.array(np.ones(dim), dtype=wp.float32),
        }
    )
    # expected values follow directly from overload_fn_grad(float, float, ...)
    assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
    assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)

    # --- struct overload ---
    x0 = MyStruct()
    x0.vec = wp.vec3(1.0, 2.0, 3.0)
    x0.scalar = 4.0
    x1 = MyStruct()
    x1.vec = wp.vec3(5.0, 6.0, 7.0)
    x1.scalar = -1.0
    x2 = MyStruct()
    x2.vec = wp.vec3(8.0, 9.0, 10.0)
    x2.scalar = 19.0
    xs_struct = wp.array([x0, x1, x2], dtype=MyStruct, requires_grad=True)
    out_struct = wp.zeros(dim)
    tape = wp.Tape()
    with tape:
        wp.launch(run_overload_struct_fn, dim=dim, inputs=[xs_struct], outputs=[out_struct])
    tape.backward(grads={out_struct: wp.array(np.ones(dim), dtype=wp.float32)})
    xs_struct_np = xs_struct.numpy()
    struct_grads = xs_struct.grad.numpy()
    # expected values follow directly from overload_fn_grad(MyStruct, ...);
    # g[0] is the scalar member, g[1] the vec3 member of each struct record
    # fmt: off
    assert_np_equal(
        np.array([g[0] for g in struct_grads]),
        np.array([g[0] * 10.0 for g in xs_struct_np]))
    assert_np_equal(
        np.array([g[1][0] for g in struct_grads]),
        np.array([g[1][1] * g[1][2] * 20.0 for g in xs_struct_np]))
    assert_np_equal(
        np.array([g[1][1] for g in struct_grads]),
        np.array([g[1][0] * g[1][2] * 30.0 for g in xs_struct_np]))
    assert_np_equal(
        np.array([g[1][2] for g in struct_grads]),
        np.array([g[1][0] * g[1][1] * 40.0 for g in xs_struct_np]))
    # fmt: on
devices = get_test_devices()


class TestGradCustoms(unittest.TestCase):
    # test methods are attached dynamically by add_function_test below,
    # one per device in `devices`
    pass


add_function_test(TestGradCustoms, "test_custom_replay_grad", test_custom_replay_grad, devices=devices)
add_function_test(TestGradCustoms, "test_custom_overload_grad", test_custom_overload_grad, devices=devices)

if __name__ == "__main__":
    # start from a clean kernel cache so codegen changes are always exercised
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=False)