# Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import ctypes
import os
import unittest

import numpy as np

import warp as wp
from warp.tests.unittest_utils import *

wp.init()


# Kernel used by the tests below to modify an array in place: adds 1.0 to each element.
@wp.kernel
def inc(a: wp.array(dtype=float)):
    tid = wp.tid()
    a[tid] = a[tid] + 1.0


def test_dlpack_warp_to_warp(test, device):
    """Round-trip a Warp array through DLPack and compare the two handles.

    Checks that pointer, device, dtype, shape and strides match, and that the
    contents still compare equal after ``inc`` is launched on the imported array.
    """
    src = wp.array(data=np.arange(10, dtype=np.float32), device=device)
    dst = wp.from_dlpack(wp.to_dlpack(src))

    for attr in ("ptr", "device", "dtype", "shape", "strides"):
        test.assertEqual(getattr(src, attr), getattr(dst, attr))

    assert_np_equal(src.numpy(), dst.numpy())

    # modify through the imported handle, then compare against the source handle
    wp.launch(inc, dim=dst.size, inputs=[dst], device=device)

    assert_np_equal(src.numpy(), dst.numpy())


def test_dlpack_dtypes_and_shapes(test, device):
    """Exercise DLPack round trips across scalar, vector and matrix dtypes.

    Each nested helper creates a zero-filled array, passes it through
    ``wp.to_dlpack`` / ``wp.from_dlpack`` (optionally requesting a target dtype)
    and compares pointer, device, dtype, shape and strides of the result with
    the source array.
    """

    def expect_same_storage(src, dst):
        # both handles must report the same pointer and device
        test.assertEqual(src.ptr, dst.ptr)
        test.assertEqual(src.device, dst.device)

    # automatically determine scalar dtype
    def wrap_scalar_tensor_implicit(dtype):
        src = wp.zeros(10, dtype=dtype, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src))

        expect_same_storage(src, dst)
        test.assertEqual(src.dtype, dst.dtype)
        test.assertEqual(src.shape, dst.shape)
        test.assertEqual(src.strides, dst.strides)

    # explicitly specify scalar dtype
    def wrap_scalar_tensor_explicit(dtype, target_dtype):
        src = wp.zeros(10, dtype=dtype, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src), dtype=target_dtype)

        expect_same_storage(src, dst)
        test.assertEqual(src.dtype, dtype)
        test.assertEqual(dst.dtype, target_dtype)
        test.assertEqual(src.shape, dst.shape)
        test.assertEqual(src.strides, dst.strides)

    # convert vector arrays to scalar arrays
    def wrap_vector_to_scalar_tensor(vec_dtype):
        elem_type = vec_dtype._wp_scalar_type_
        elem_bytes = ctypes.sizeof(vec_dtype._type_)

        src = wp.zeros(10, dtype=vec_dtype, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src), dtype=elem_type)

        expect_same_storage(src, dst)
        test.assertEqual(dst.ndim, src.ndim + 1)
        test.assertEqual(src.dtype, vec_dtype)
        test.assertEqual(dst.dtype, elem_type)
        test.assertEqual(dst.shape, (*src.shape, vec_dtype._length_))
        test.assertEqual(dst.strides, (*src.strides, elem_bytes))

    # convert scalar arrays to vector arrays
    def wrap_scalar_to_vector_tensor(vec_dtype):
        elem_type = vec_dtype._wp_scalar_type_
        elem_bytes = ctypes.sizeof(vec_dtype._type_)

        src = wp.zeros((10, vec_dtype._length_), dtype=elem_type, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src), dtype=vec_dtype)

        expect_same_storage(src, dst)
        test.assertEqual(dst.ndim, src.ndim - 1)
        test.assertEqual(src.dtype, elem_type)
        test.assertEqual(dst.dtype, vec_dtype)
        test.assertEqual(src.shape, (*dst.shape, vec_dtype._length_))
        test.assertEqual(src.strides, (*dst.strides, elem_bytes))

    # convert matrix arrays to scalar arrays
    def wrap_matrix_to_scalar_tensor(mat_dtype):
        elem_type = mat_dtype._wp_scalar_type_
        elem_bytes = ctypes.sizeof(mat_dtype._type_)
        row_bytes = elem_bytes * mat_dtype._shape_[1]

        src = wp.zeros(10, dtype=mat_dtype, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src), dtype=elem_type)

        expect_same_storage(src, dst)
        test.assertEqual(dst.ndim, src.ndim + 2)
        test.assertEqual(src.dtype, mat_dtype)
        test.assertEqual(dst.dtype, elem_type)
        test.assertEqual(dst.shape, (*src.shape, *mat_dtype._shape_))
        test.assertEqual(dst.strides, (*src.strides, row_bytes, elem_bytes))

    # convert scalar arrays to matrix arrays
    def wrap_scalar_to_matrix_tensor(mat_dtype):
        elem_type = mat_dtype._wp_scalar_type_
        elem_bytes = ctypes.sizeof(mat_dtype._type_)
        row_bytes = elem_bytes * mat_dtype._shape_[1]

        src = wp.zeros((10, *mat_dtype._shape_), dtype=elem_type, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src), dtype=mat_dtype)

        expect_same_storage(src, dst)
        test.assertEqual(dst.ndim, src.ndim - 2)
        test.assertEqual(src.dtype, elem_type)
        test.assertEqual(dst.dtype, mat_dtype)
        test.assertEqual(src.shape, (*dst.shape, *mat_dtype._shape_))
        test.assertEqual(src.strides, (*dst.strides, row_bytes, elem_bytes))

    for scalar in wp.types.scalar_types:
        wrap_scalar_tensor_implicit(scalar)

    for scalar in wp.types.scalar_types:
        wrap_scalar_tensor_explicit(scalar, scalar)

    # test signed/unsigned conversions
    signed_unsigned_pairs = (
        (wp.int8, wp.uint8),
        (wp.int16, wp.uint16),
        (wp.int32, wp.uint32),
        (wp.int64, wp.uint64),
    )
    for signed, unsigned in signed_unsigned_pairs:
        wrap_scalar_tensor_explicit(signed, unsigned)
        wrap_scalar_tensor_explicit(unsigned, signed)

    # generic vectors of every scalar type, followed by the named vector-like types
    vec_types = [
        wp.types.vector(vec_len, scalar)
        for scalar in wp.types.scalar_types
        for vec_len in (2, 3, 4, 5)
    ]
    vec_types += [
        wp.quath,
        wp.quatf,
        wp.quatd,
        wp.transformh,
        wp.transformf,
        wp.transformd,
        wp.spatial_vectorh,
        wp.spatial_vectorf,
        wp.spatial_vectord,
    ]

    for vec_type in vec_types:
        wrap_vector_to_scalar_tensor(vec_type)
        wrap_scalar_to_vector_tensor(vec_type)

    # square and non-square matrices of every scalar type, followed by the spatial matrices
    mat_shapes = [(2, 2), (3, 3), (4, 4), (5, 5), (2, 3), (3, 2), (3, 4), (4, 3)]
    mat_types = [
        wp.types.matrix(mat_shape, scalar)
        for scalar in wp.types.scalar_types
        for mat_shape in mat_shapes
    ]
    mat_types += [wp.spatial_matrixh, wp.spatial_matrixf, wp.spatial_matrixd]

    for mat_type in mat_types:
        wrap_matrix_to_scalar_tensor(mat_type)
        wrap_scalar_to_matrix_tensor(mat_type)


def _check_torch_sharing(test, device, a, t):
    """Run the checks shared by both Torch tests on Warp array ``a`` and Torch tensor ``t``.

    Compares pointer, device, dtype, shape and strides, then compares the
    contents before and after each side modifies the data.
    """
    # Torch strides are multiplied by the item size to compare with the Warp strides
    item_size = wp.types.type_size_in_bytes(a.dtype)

    test.assertEqual(a.ptr, t.data_ptr())
    test.assertEqual(a.device, wp.device_from_torch(t.device))
    test.assertEqual(a.dtype, wp.torch.dtype_from_torch(t.dtype))
    test.assertEqual(a.shape, tuple(t.shape))
    test.assertEqual(a.strides, tuple(s * item_size for s in t.stride()))

    assert_np_equal(a.numpy(), t.cpu().numpy())

    # modify from the Warp side
    wp.launch(inc, dim=a.size, inputs=[a], device=device)

    assert_np_equal(a.numpy(), t.cpu().numpy())

    # modify from the Torch side
    t += 1

    assert_np_equal(a.numpy(), t.cpu().numpy())


def test_dlpack_warp_to_torch(test, device):
    """Export a Warp array to Torch through DLPack and run the sharing checks."""
    import torch.utils.dlpack

    warp_array = wp.array(data=np.arange(10, dtype=np.float32), device=device)
    torch_tensor = torch.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array))

    _check_torch_sharing(test, device, warp_array, torch_tensor)


def test_dlpack_torch_to_warp(test, device):
    """Import a Torch tensor into Warp through DLPack and run the sharing checks."""
    import torch
    import torch.utils.dlpack

    torch_tensor = torch.arange(10, dtype=torch.float32, device=wp.device_to_torch(device))
    warp_array = wp.from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor))

    _check_torch_sharing(test, device, warp_array, torch_tensor)


def test_dlpack_warp_to_jax(test, device):
    """Export a Warp array to Jax, via generic DLPack and via ``wp.to_jax``, and compare."""
    import jax
    import jax.dlpack

    a = wp.array(data=np.arange(10, dtype=np.float32), device=device)

    # use generic dlpack conversion
    j1 = jax.dlpack.from_dlpack(wp.to_dlpack(a))

    # use jax wrapper
    j2 = wp.to_jax(a)

    for j in (j1, j2):
        test.assertEqual(a.ptr, j.unsafe_buffer_pointer())
    for j in (j1, j2):
        test.assertEqual(a.device, wp.device_from_jax(j.device()))
    for j in (j1, j2):
        test.assertEqual(a.shape, j.shape)
    for j in (j1, j2):
        assert_np_equal(a.numpy(), np.asarray(j))

    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    wp.synchronize_device(device)

    # HACK? Run a no-op operation so that Jax flags the arrays as dirty
    # and gets the latest values, which were modified by Warp.
    j1 += 0
    j2 += 0

    for j in (j1, j2):
        assert_np_equal(a.numpy(), np.asarray(j))


def test_dlpack_jax_to_warp(test, device):
    """Import a Jax array into Warp, via generic DLPack and via ``wp.from_jax``, and compare."""
    import jax
    import jax.dlpack

    with jax.default_device(wp.device_to_jax(device)):
        j = jax.numpy.arange(10, dtype=jax.numpy.float32)

        # use generic dlpack conversion
        a1 = wp.from_dlpack(jax.dlpack.to_dlpack(j))

        # use jax wrapper
        a2 = wp.from_jax(j)

        for arr in (a1, a2):
            test.assertEqual(arr.ptr, j.unsafe_buffer_pointer())
        for arr in (a1, a2):
            test.assertEqual(arr.device, wp.device_from_jax(j.device()))
        for arr in (a1, a2):
            test.assertEqual(arr.shape, j.shape)
        for arr in (a1, a2):
            assert_np_equal(arr.numpy(), np.asarray(j))

        wp.launch(inc, dim=a1.size, inputs=[a1], device=device)
        wp.synchronize_device(device)

        # HACK? Run a no-op operation so that Jax flags the array as dirty
        # and gets the latest values, which were modified by Warp.
        j += 0

        for arr in (a1, a2):
            assert_np_equal(arr.numpy(), np.asarray(j))


class TestDLPack(unittest.TestCase):
    # test methods are attached below via add_function_test
    pass


devices = get_test_devices()

add_function_test(TestDLPack, "test_dlpack_warp_to_warp", test_dlpack_warp_to_warp, devices=devices)
add_function_test(TestDLPack, "test_dlpack_dtypes_and_shapes", test_dlpack_dtypes_and_shapes, devices=devices)


# torch interop via dlpack
try:
    import torch
    import torch.utils.dlpack

    # check which Warp devices work with Torch
    # CUDA devices may fail if Torch was not compiled with CUDA support
    test_devices = get_test_devices()
    torch_compatible_devices = []
    for d in test_devices:
        try:
            t = torch.arange(10, device=wp.device_to_torch(d))
            t += 1
        except Exception as e:
            print(f"Skipping Torch DLPack tests on device '{d}' due to exception: {e}")
        else:
            torch_compatible_devices.append(d)

    if torch_compatible_devices:
        add_function_test(
            TestDLPack,
            "test_dlpack_warp_to_torch",
            test_dlpack_warp_to_torch,
            devices=torch_compatible_devices,
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_torch_to_warp",
            test_dlpack_torch_to_warp,
            devices=torch_compatible_devices,
        )

except Exception as e:
    print(f"Skipping Torch DLPack tests due to exception: {e}")


# jax interop via dlpack
try:
    # prevent Jax from gobbling up GPU memory
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

    import jax
    import jax.dlpack

    # check which Warp devices work with Jax
    # CUDA devices may fail if Jax cannot find a CUDA Toolkit
    test_devices = get_test_devices()
    jax_compatible_devices = []
    for d in test_devices:
        try:
            with jax.default_device(wp.device_to_jax(d)):
                j = jax.numpy.arange(10, dtype=jax.numpy.float32)
                j += 1
                jax_compatible_devices.append(d)
        except Exception as e:
            print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")

    if jax_compatible_devices:
        add_function_test(
            TestDLPack,
            "test_dlpack_warp_to_jax",
            test_dlpack_warp_to_jax,
            devices=jax_compatible_devices,
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_jax_to_warp",
            test_dlpack_jax_to_warp,
            devices=jax_compatible_devices,
        )

except Exception as e:
    print(f"Skipping Jax DLPack tests due to exception: {e}")


if __name__ == "__main__":
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2)