qbhf2's picture
added NvidiaWarp and GarmentCode repos
66c9c8a
# Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import ctypes
import os
import unittest
import numpy as np
import warp as wp
from warp.tests.unittest_utils import *
# Initialize Warp at import time, before the kernel and tests below are declared.
wp.init()
# Kernel used by the tests to modify a shared buffer in place: every thread
# adds one to the element of `a` at its own thread index.
@wp.kernel
def inc(a: wp.array(dtype=float)):
    idx = wp.tid()
    a[idx] = a[idx] + 1.0
def test_dlpack_warp_to_warp(test, device):
    """Round-trip a Warp array through DLPack and compare it with the source.

    Args:
        test: Test-case object providing ``assertEqual``.
        device: Warp device on which the array is created and the kernel runs.
    """
    source = wp.array(data=np.arange(10, dtype=np.float32), device=device)
    alias = wp.from_dlpack(wp.to_dlpack(source))

    # The re-imported array must report the same pointer and metadata.
    for attr in ("ptr", "device", "dtype", "shape", "strides"):
        test.assertEqual(getattr(source, attr), getattr(alias, attr))

    assert_np_equal(source.numpy(), alias.numpy())

    # Run the increment kernel on the alias; both arrays must still agree.
    wp.launch(inc, dim=alias.size, inputs=[alias], device=device)

    assert_np_equal(source.numpy(), alias.numpy())
def test_dlpack_dtypes_and_shapes(test, device):
    """Check DLPack round trips for scalar, vector and matrix dtypes.

    Each case creates a zero-filled Warp array, exports it with
    ``wp.to_dlpack`` and re-imports it with ``wp.from_dlpack`` (optionally
    requesting a different dtype), then compares pointer, device, dtype,
    dimensionality, shape and strides of the two arrays.

    Args:
        test: Test-case object providing ``assertEqual``.
        device: Warp device on which the arrays are allocated.
    """

    def assert_same_buffer(first, second):
        # Both arrays must report the same pointer and the same device.
        test.assertEqual(first.ptr, second.ptr)
        test.assertEqual(first.device, second.device)

    # automatically determine scalar dtype
    def check_scalar_implicit(dtype):
        src = wp.zeros(10, dtype=dtype, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src))

        assert_same_buffer(src, dst)
        for attr in ("dtype", "shape", "strides"):
            test.assertEqual(getattr(src, attr), getattr(dst, attr))

    # explicitly specify scalar dtype
    def check_scalar_explicit(dtype, target_dtype):
        src = wp.zeros(10, dtype=dtype, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src), dtype=target_dtype)

        assert_same_buffer(src, dst)
        test.assertEqual(src.dtype, dtype)
        test.assertEqual(dst.dtype, target_dtype)
        test.assertEqual(src.shape, dst.shape)
        test.assertEqual(src.strides, dst.strides)

    # convert vector arrays to scalar arrays
    def check_vector_as_scalars(vec_dtype):
        scalar_type = vec_dtype._wp_scalar_type_
        # Byte size of the ctypes type in `_type_`; expected innermost stride.
        scalar_size = ctypes.sizeof(vec_dtype._type_)

        vec_arr = wp.zeros(10, dtype=vec_dtype, device=device)
        flat_arr = wp.from_dlpack(wp.to_dlpack(vec_arr), dtype=scalar_type)

        assert_same_buffer(vec_arr, flat_arr)
        # The scalar view gains one trailing dimension for the components.
        test.assertEqual(flat_arr.ndim, vec_arr.ndim + 1)
        test.assertEqual(vec_arr.dtype, vec_dtype)
        test.assertEqual(flat_arr.dtype, scalar_type)
        test.assertEqual(flat_arr.shape, (*vec_arr.shape, vec_dtype._length_))
        test.assertEqual(flat_arr.strides, (*vec_arr.strides, scalar_size))

    # convert scalar arrays to vector arrays
    def check_scalars_as_vector(vec_dtype):
        scalar_type = vec_dtype._wp_scalar_type_
        scalar_size = ctypes.sizeof(vec_dtype._type_)

        flat_arr = wp.zeros((10, vec_dtype._length_), dtype=scalar_type, device=device)
        vec_arr = wp.from_dlpack(wp.to_dlpack(flat_arr), dtype=vec_dtype)

        assert_same_buffer(flat_arr, vec_arr)
        # The vector view drops the trailing component dimension.
        test.assertEqual(vec_arr.ndim, flat_arr.ndim - 1)
        test.assertEqual(flat_arr.dtype, scalar_type)
        test.assertEqual(vec_arr.dtype, vec_dtype)
        test.assertEqual(flat_arr.shape, (*vec_arr.shape, vec_dtype._length_))
        test.assertEqual(flat_arr.strides, (*vec_arr.strides, scalar_size))

    # convert matrix arrays to scalar arrays
    def check_matrix_as_scalars(mat_dtype):
        scalar_type = mat_dtype._wp_scalar_type_
        scalar_size = ctypes.sizeof(mat_dtype._type_)

        mat_arr = wp.zeros(10, dtype=mat_dtype, device=device)
        flat_arr = wp.from_dlpack(wp.to_dlpack(mat_arr), dtype=scalar_type)

        assert_same_buffer(mat_arr, flat_arr)
        # The scalar view gains two trailing dimensions (rows, columns).
        test.assertEqual(flat_arr.ndim, mat_arr.ndim + 2)
        test.assertEqual(mat_arr.dtype, mat_dtype)
        test.assertEqual(flat_arr.dtype, scalar_type)
        test.assertEqual(flat_arr.shape, (*mat_arr.shape, *mat_dtype._shape_))
        test.assertEqual(
            flat_arr.strides,
            (*mat_arr.strides, scalar_size * mat_dtype._shape_[1], scalar_size),
        )

    # convert scalar arrays to matrix arrays
    def check_scalars_as_matrix(mat_dtype):
        scalar_type = mat_dtype._wp_scalar_type_
        scalar_size = ctypes.sizeof(mat_dtype._type_)

        flat_arr = wp.zeros((10, *mat_dtype._shape_), dtype=scalar_type, device=device)
        mat_arr = wp.from_dlpack(wp.to_dlpack(flat_arr), dtype=mat_dtype)

        assert_same_buffer(flat_arr, mat_arr)
        # The matrix view drops the two trailing dimensions.
        test.assertEqual(mat_arr.ndim, flat_arr.ndim - 2)
        test.assertEqual(flat_arr.dtype, scalar_type)
        test.assertEqual(mat_arr.dtype, mat_dtype)
        test.assertEqual(flat_arr.shape, (*mat_arr.shape, *mat_dtype._shape_))
        test.assertEqual(
            flat_arr.strides,
            (*mat_arr.strides, scalar_size * mat_dtype._shape_[1], scalar_size),
        )

    scalar_types = wp.types.scalar_types

    for scalar in scalar_types:
        check_scalar_implicit(scalar)

    for scalar in scalar_types:
        check_scalar_explicit(scalar, scalar)

    # test signed/unsigned conversions
    for signed, unsigned in (
        (wp.int8, wp.uint8),
        (wp.int16, wp.uint16),
        (wp.int32, wp.uint32),
        (wp.int64, wp.uint64),
    ):
        check_scalar_explicit(signed, unsigned)
        check_scalar_explicit(unsigned, signed)

    # Vectors of length 2..5 for every scalar type, followed by the
    # predefined quaternion, transform and spatial vector types.
    vec_types = [wp.types.vector(length, scalar) for scalar in scalar_types for length in (2, 3, 4, 5)]
    vec_types += [
        wp.quath,
        wp.quatf,
        wp.quatd,
        wp.transformh,
        wp.transformf,
        wp.transformd,
        wp.spatial_vectorh,
        wp.spatial_vectorf,
        wp.spatial_vectord,
    ]

    for vec_type in vec_types:
        check_vector_as_scalars(vec_type)
        check_scalars_as_vector(vec_type)

    # Square and non-square matrices for every scalar type, followed by the
    # predefined spatial matrix types.
    mat_shapes = ((2, 2), (3, 3), (4, 4), (5, 5), (2, 3), (3, 2), (3, 4), (4, 3))
    mat_types = [wp.types.matrix(shape, scalar) for scalar in scalar_types for shape in mat_shapes]
    mat_types += [wp.spatial_matrixh, wp.spatial_matrixf, wp.spatial_matrixd]

    for mat_type in mat_types:
        check_matrix_as_scalars(mat_type)
        check_scalars_as_matrix(mat_type)
def test_dlpack_warp_to_torch(test, device):
    """Export a Warp array to Torch via DLPack and compare the two views.

    Args:
        test: Test-case object providing ``assertEqual``.
        device: Warp device on which the array is created and the kernel runs.
    """
    from torch.utils import dlpack as torch_dlpack

    warp_arr = wp.array(data=np.arange(10, dtype=np.float32), device=device)
    tensor = torch_dlpack.from_dlpack(wp.to_dlpack(warp_arr))

    item_size = wp.types.type_size_in_bytes(warp_arr.dtype)

    test.assertEqual(warp_arr.ptr, tensor.data_ptr())
    test.assertEqual(warp_arr.device, wp.device_from_torch(tensor.device))
    test.assertEqual(warp_arr.dtype, wp.torch.dtype_from_torch(tensor.dtype))
    test.assertEqual(warp_arr.shape, tuple(tensor.shape))

    # Scale Torch's strides by the element size in bytes before comparing
    # them with Warp's strides.
    byte_strides = tuple(stride * item_size for stride in tensor.stride())
    test.assertEqual(warp_arr.strides, byte_strides)

    assert_np_equal(warp_arr.numpy(), tensor.cpu().numpy())

    # Modify the data from the Warp side ...
    wp.launch(inc, dim=warp_arr.size, inputs=[warp_arr], device=device)
    assert_np_equal(warp_arr.numpy(), tensor.cpu().numpy())

    # ... and from the Torch side; both views must agree after each change.
    tensor += 1
    assert_np_equal(warp_arr.numpy(), tensor.cpu().numpy())
def test_dlpack_torch_to_warp(test, device):
    """Import a Torch tensor into Warp via DLPack and compare the two views.

    Args:
        test: Test-case object providing ``assertEqual``.
        device: Warp device; the tensor is created on the matching Torch device.
    """
    import torch
    from torch.utils import dlpack as torch_dlpack

    tensor = torch.arange(10, dtype=torch.float32, device=wp.device_to_torch(device))
    warp_arr = wp.from_dlpack(torch_dlpack.to_dlpack(tensor))

    item_size = wp.types.type_size_in_bytes(warp_arr.dtype)

    test.assertEqual(warp_arr.ptr, tensor.data_ptr())
    test.assertEqual(warp_arr.device, wp.device_from_torch(tensor.device))
    test.assertEqual(warp_arr.dtype, wp.torch.dtype_from_torch(tensor.dtype))
    test.assertEqual(warp_arr.shape, tuple(tensor.shape))

    # Scale Torch's strides by the element size in bytes before comparing
    # them with Warp's strides.
    byte_strides = tuple(stride * item_size for stride in tensor.stride())
    test.assertEqual(warp_arr.strides, byte_strides)

    assert_np_equal(warp_arr.numpy(), tensor.cpu().numpy())

    # Modify the data from the Warp side ...
    wp.launch(inc, dim=warp_arr.size, inputs=[warp_arr], device=device)
    assert_np_equal(warp_arr.numpy(), tensor.cpu().numpy())

    # ... and from the Torch side; both views must agree after each change.
    tensor += 1
    assert_np_equal(warp_arr.numpy(), tensor.cpu().numpy())
def test_dlpack_warp_to_jax(test, device):
    """Export a Warp array to JAX and compare the JAX arrays with the source.

    The export is done twice: once through the generic DLPack functions and
    once through the ``wp.to_jax`` wrapper.

    Args:
        test: Test-case object providing ``assertEqual``.
        device: Warp device on which the array is created and the kernel runs.
    """
    import jax
    from jax import dlpack as jax_dlpack

    a = wp.array(data=np.arange(10, dtype=np.float32), device=device)

    # use generic dlpack conversion
    j1 = jax_dlpack.from_dlpack(wp.to_dlpack(a))

    # use jax wrapper
    j2 = wp.to_jax(a)

    for jax_arr in (j1, j2):
        test.assertEqual(a.ptr, jax_arr.unsafe_buffer_pointer())
    for jax_arr in (j1, j2):
        test.assertEqual(a.device, wp.device_from_jax(jax_arr.device()))
    for jax_arr in (j1, j2):
        test.assertEqual(a.shape, jax_arr.shape)
    for jax_arr in (j1, j2):
        assert_np_equal(a.numpy(), np.asarray(jax_arr))

    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    wp.synchronize_device(device)

    # HACK? Run a no-op operation so that Jax flags the arrays as dirty
    # and gets the latest values, which were modified by Warp.
    j1 += 0
    j2 += 0

    for jax_arr in (j1, j2):
        assert_np_equal(a.numpy(), np.asarray(jax_arr))
def test_dlpack_jax_to_warp(test, device):
    """Import a JAX array into Warp and compare the Warp arrays with the source.

    The import is done twice: once through the generic DLPack functions and
    once through the ``wp.from_jax`` wrapper.

    Args:
        test: Test-case object providing ``assertEqual``.
        device: Warp device; the JAX array is created on the matching JAX device.
    """
    import jax
    from jax import dlpack as jax_dlpack

    with jax.default_device(wp.device_to_jax(device)):
        j = jax.numpy.arange(10, dtype=jax.numpy.float32)

        # use generic dlpack conversion
        a1 = wp.from_dlpack(jax_dlpack.to_dlpack(j))

        # use jax wrapper
        a2 = wp.from_jax(j)

        for warp_arr in (a1, a2):
            test.assertEqual(warp_arr.ptr, j.unsafe_buffer_pointer())
        for warp_arr in (a1, a2):
            test.assertEqual(warp_arr.device, wp.device_from_jax(j.device()))
        for warp_arr in (a1, a2):
            test.assertEqual(warp_arr.shape, j.shape)
        for warp_arr in (a1, a2):
            assert_np_equal(warp_arr.numpy(), np.asarray(j))

        wp.launch(inc, dim=a1.size, inputs=[a1], device=device)
        wp.synchronize_device(device)

        # HACK? Run a no-op operation so that Jax flags the array as dirty
        # and gets the latest values, which were modified by Warp.
        j += 0

        for warp_arr in (a1, a2):
            assert_np_equal(warp_arr.numpy(), np.asarray(j))
class TestDLPack(unittest.TestCase):
    """Empty test case; the DLPack test functions are attached to it at module level."""
# Devices used for the Warp-only DLPack tests.
devices = get_test_devices()
# Register the Warp-only tests on TestDLPack for the devices above.
add_function_test(TestDLPack, "test_dlpack_warp_to_warp", test_dlpack_warp_to_warp, devices=devices)
add_function_test(TestDLPack, "test_dlpack_dtypes_and_shapes", test_dlpack_dtypes_and_shapes, devices=devices)
# torch interop via dlpack
# The whole section is best-effort: any failure (including a missing Torch
# install) only prints a message and leaves the Torch tests unregistered.
try:
    import torch
    import torch.utils.dlpack

    # check which Warp devices work with Torch
    # CUDA devices may fail if Torch was not compiled with CUDA support
    test_devices = get_test_devices()
    torch_compatible_devices = []

    for d in test_devices:
        try:
            # Probe the device with a tiny tensor operation.
            t = torch.arange(10, device=wp.device_to_torch(d))
            t += 1
        except Exception as e:
            print(f"Skipping Torch DLPack tests on device '{d}' due to exception: {e}")
        else:
            torch_compatible_devices.append(d)

    # Only register the Torch tests when at least one device passed the probe.
    if torch_compatible_devices:
        add_function_test(
            TestDLPack,
            "test_dlpack_warp_to_torch",
            test_dlpack_warp_to_torch,
            devices=torch_compatible_devices,
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_torch_to_warp",
            test_dlpack_torch_to_warp,
            devices=torch_compatible_devices,
        )

except Exception as e:
    print(f"Skipping Torch DLPack tests due to exception: {e}")
# jax interop via dlpack
# The whole section is best-effort: any failure (including a missing Jax
# install) only prints a message and leaves the Jax tests unregistered.
try:
    # prevent Jax from gobbling up GPU memory
    # (both variables are set before Jax is imported below)
    os.environ.update(
        {
            "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
            "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
        }
    )

    import jax
    import jax.dlpack

    # check which Warp devices work with Jax
    # CUDA devices may fail if Jax cannot find a CUDA Toolkit
    test_devices = get_test_devices()
    jax_compatible_devices = []

    for d in test_devices:
        try:
            # Probe the device with a tiny array operation.
            with jax.default_device(wp.device_to_jax(d)):
                j = jax.numpy.arange(10, dtype=jax.numpy.float32)
                j += 1
        except Exception as e:
            print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
        else:
            jax_compatible_devices.append(d)

    # Only register the Jax tests when at least one device passed the probe.
    if jax_compatible_devices:
        add_function_test(
            TestDLPack,
            "test_dlpack_warp_to_jax",
            test_dlpack_warp_to_jax,
            devices=jax_compatible_devices,
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_jax_to_warp",
            test_dlpack_jax_to_warp,
            devices=jax_compatible_devices,
        )

except Exception as e:
    print(f"Skipping Jax DLPack tests due to exception: {e}")
if __name__ == "__main__":
    # When run as a script: clear the Warp kernel cache, then run the tests
    # in this module with verbose output.
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2)