qbhf2's picture
added NvidiaWarp and GarmentCode repos
66c9c8a
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import unittest
import numpy as np
import warp as wp
from warp.tests.unittest_utils import *
# Initialize the Warp runtime at import time, before any kernels or arrays in this module are used.
wp.init()
# Atomic increment that also records, per thread, the counter value it obtained.
# Keeping that value in ``thread_values`` lets the replay function registered
# for this func hand every thread the same value again during the backward
# pass, instead of re-running the atomic and advancing the counter a second time.
@wp.func
def reversible_increment(
    counter: wp.array(dtype=int),
    counter_index: int,
    value: int,
    thread_values: wp.array(dtype=int),
    tid: int,
):
    # wp.atomic_add yields the counter value from before this thread's addition.
    slot = wp.atomic_add(counter, counter_index, value)
    # Remember which value this thread received so the replay can reproduce it.
    thread_values[tid] = slot
    return slot
@wp.func_replay(reversible_increment)
def replay_reversible_increment(
    counter: wp.array(dtype=int),
    counter_index: int,
    value: int,
    thread_values: wp.array(dtype=int),
    tid: int,
):
    # Replay variant used when the forward pass is re-evaluated for the
    # backward pass: the counter is left alone and the value this thread
    # recorded during the forward pass is returned instead.
    return thread_values[tid]
def test_custom_replay_grad(test, device):
    """Check that gradients are correct for a kernel using ``reversible_increment``.

    Each of the 128 threads claims a slot from a shared counter that starts at
    0 and advances by 1. The slots 0..127 are therefore each written exactly
    once with ``output[slot] = input[slot] ** 2``. Seeding the output adjoint
    with ones gives an expected input gradient of ``2 * input``. During the
    backward pass, the ``func_replay`` override supplies the slot recorded in
    ``thread_values`` rather than incrementing the counter again.
    """
    thread_count = 128

    @wp.kernel
    def run_atomic_add(
        values: wp.array(dtype=float),
        counter: wp.array(dtype=int),
        thread_values: wp.array(dtype=int),
        squared: wp.array(dtype=float),
    ):
        tid = wp.tid()
        slot = reversible_increment(counter, 0, 1, thread_values, tid)
        squared[slot] = values[slot] ** 2.0

    counter = wp.zeros(1, dtype=wp.int32, device=device)
    thread_ids = wp.zeros(thread_count, dtype=wp.int32, device=device)
    inputs = wp.array(np.arange(thread_count, dtype=np.float32), device=device, requires_grad=True)
    outputs = wp.zeros_like(inputs)

    tape = wp.Tape()
    with tape:
        wp.launch(
            run_atomic_add,
            dim=thread_count,
            inputs=[inputs, counter, thread_ids],
            outputs=[outputs],
            device=device,
        )

    # Unit adjoint on every output element.
    output_adjoint = wp.array(np.ones(thread_count, dtype=np.float32), device=device)
    tape.backward(grads={outputs: output_adjoint})

    # d(x**2)/dx = 2x for every element.
    assert_np_equal(inputs.grad.numpy(), 2.0 * inputs.numpy(), tol=1e-4)
@wp.func
def overload_fn(x: float, y: float):
    # (float, float) overload: returns (3x + y/3, y**2.5). Its adjoint is
    # replaced by the custom gradient registered via wp.func_grad for it.
    linear_term = x * 3.0 + y / 3.0
    power_term = y**2.5
    return linear_term, power_term
@wp.func_grad(overload_fn)
def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
    # Custom adjoint for the (float, float) overload of overload_fn.
    # adj_ret0 / adj_ret1 are the adjoints of its two return values.
    # These expressions are not the analytic derivatives of the forward
    # outputs; test_custom_overload_grad asserts these exact values, which is
    # how it confirms that this override is the one applied.
    wp.adjoint[y] += y * adj_ret1 * 3.0
    wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
@wp.struct
class MyStruct:
    # Field order matters: test_custom_overload_grad reads the scalar as record
    # field 0 and the vector as record field 1 of the arrays' numpy views.
    scalar: float
    vec: wp.vec3
@wp.func
def overload_fn(x: MyStruct):
    # MyStruct overload of overload_fn, returning three values:
    #   4 * vec.x * vec.y * vec.z, |vec|, and sqrt(scalar).
    # Its adjoint is replaced by the custom gradient registered right after it.
    product_term = x.vec[0] * x.vec[1] * x.vec[2] * 4.0
    length_term = wp.length(x.vec)
    root_term = x.scalar**0.5
    return product_term, length_term, root_term
@wp.func_grad(overload_fn)
def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: float):
    # Custom adjoint for the MyStruct overload of overload_fn; adj_ret0..adj_ret2
    # are the adjoints of its three return values.
    # The scale factors, and which adj_ret feeds which component, are not the
    # analytic derivatives of the forward outputs. test_custom_overload_grad
    # asserts exactly these expressions, presumably to detect that this
    # override (rather than an auto-generated adjoint) was used.
    wp.adjoint[x.scalar] += x.scalar * adj_ret0 * 10.0
    # Accumulate into the individual components of the vec field's adjoint.
    wp.adjoint[x.vec][0] += adj_ret0 * x.vec[1] * x.vec[2] * 20.0
    wp.adjoint[x.vec][1] += adj_ret1 * x.vec[0] * x.vec[2] * 30.0
    wp.adjoint[x.vec][2] += adj_ret2 * x.vec[0] * x.vec[1] * 40.0
@wp.kernel
def run_overload_float_fn(
    xs: wp.array(dtype=float),
    ys: wp.array(dtype=float),
    output0: wp.array(dtype=float),
    output1: wp.array(dtype=float),
):
    # One thread per element: call the (float, float) overload of overload_fn
    # and store each of its two results in its own output array.
    tid = wp.tid()
    first, second = overload_fn(xs[tid], ys[tid])
    output0[tid] = first
    output1[tid] = second
@wp.kernel
def run_overload_struct_fn(
    xs: wp.array(dtype=MyStruct),
    output: wp.array(dtype=float),
):
    # One thread per element: call the MyStruct overload of overload_fn and
    # store the sum of its three results (every result receives adjoint 1).
    tid = wp.tid()
    r0, r1, r2 = overload_fn(xs[tid])
    output[tid] = r0 + r1 + r2
def test_custom_overload_grad(test, device):
    """Check that each overload of ``overload_fn`` uses its registered custom gradient.

    Both the (float, float) and the MyStruct overloads have custom adjoints
    registered via ``wp.func_grad``. The expected gradients below are those
    adjoint expressions evaluated with every output adjoint seeded to 1.

    All arrays, gradient seeds and launches are created on ``device``. This
    ensures the test really runs on every device it is registered for, instead
    of always on Warp's default device.
    """
    dim = 3

    # --- (float, float) overload ---------------------------------------------
    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True, device=device)
    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True, device=device)
    out0_float = wp.zeros(dim, device=device)
    out1_float = wp.zeros(dim, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            run_overload_float_fn,
            dim=dim,
            inputs=[xs_float, ys_float],
            outputs=[out0_float, out1_float],
            device=device,
        )
    tape.backward(
        grads={
            out0_float: wp.array(np.ones(dim), dtype=wp.float32, device=device),
            out1_float: wp.array(np.ones(dim), dtype=wp.float32, device=device),
        }
    )

    # Custom adjoint: dx = 42*x*adj0 + 10*y*adj1, dy = 3*y*adj1 (with adj = 1).
    assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
    assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)

    # --- MyStruct overload ---------------------------------------------------
    x0 = MyStruct()
    x0.vec = wp.vec3(1.0, 2.0, 3.0)
    x0.scalar = 4.0

    x1 = MyStruct()
    x1.vec = wp.vec3(5.0, 6.0, 7.0)
    x1.scalar = -1.0

    x2 = MyStruct()
    x2.vec = wp.vec3(8.0, 9.0, 10.0)
    x2.scalar = 19.0

    xs_struct = wp.array([x0, x1, x2], dtype=MyStruct, requires_grad=True, device=device)
    out_struct = wp.zeros(dim, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            run_overload_struct_fn,
            dim=dim,
            inputs=[xs_struct],
            outputs=[out_struct],
            device=device,
        )
    tape.backward(grads={out_struct: wp.array(np.ones(dim), dtype=wp.float32, device=device)})

    xs_struct_np = xs_struct.numpy()
    struct_grads = xs_struct.grad.numpy()

    # Record field 0 is MyStruct.scalar and field 1 is MyStruct.vec.
    assert_np_equal(
        np.array([g[0] for g in struct_grads]),
        np.array([g[0] * 10.0 for g in xs_struct_np]),
    )
    assert_np_equal(
        np.array([g[1][0] for g in struct_grads]),
        np.array([g[1][1] * g[1][2] * 20.0 for g in xs_struct_np]),
    )
    assert_np_equal(
        np.array([g[1][1] for g in struct_grads]),
        np.array([g[1][0] * g[1][2] * 30.0 for g in xs_struct_np]),
    )
    assert_np_equal(
        np.array([g[1][2] for g in struct_grads]),
        np.array([g[1][0] * g[1][1] * 40.0 for g in xs_struct_np]),
    )
devices = get_test_devices()


class TestGradCustoms(unittest.TestCase):
    # Intentionally empty: the test methods are attached below via
    # add_function_test for the devices in ``devices``.
    pass


for _test_name, _test_func in (
    ("test_custom_replay_grad", test_custom_replay_grad),
    ("test_custom_overload_grad", test_custom_overload_grad),
):
    add_function_test(TestGradCustoms, _test_name, _test_func, devices=devices)
# Keep the module namespace free of the loop variables.
del _test_name, _test_func


if __name__ == "__main__":
    # Clear Warp's kernel cache before running so kernels are rebuilt for this run.
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=False)