Spaces:
Sleeping
Sleeping
File size: 1,511 Bytes
# Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import warp
def device_to_jax(wp_device):
    """Return the Jax device that corresponds to a Warp device.

    Args:
        wp_device: Any device identifier accepted by ``warp.get_device``.

    Returns:
        For a CUDA Warp device, the entry of ``jax.devices("cuda")`` at the
        Warp device's ordinal. For any other Warp device, the first entry of
        ``jax.devices("cpu")``.

    Raises:
        RuntimeError: If the matching Jax device is not available.
    """
    # Jax is imported inside the function rather than at module level.
    import jax

    warp_dev = warp.get_device(wp_device)

    if not warp_dev.is_cuda:
        host_candidates = jax.devices("cpu")
        if not host_candidates:
            raise RuntimeError(f"Jax device corresponding to '{wp_device}' is not available")
        return host_candidates[0]

    gpu_candidates = jax.devices("cuda")
    # The Warp CUDA ordinal is used directly as an index into Jax's CUDA device list.
    if warp_dev.ordinal < len(gpu_candidates):
        return gpu_candidates[warp_dev.ordinal]
    raise RuntimeError(f"Jax device corresponding to '{wp_device}' is not available")
def device_from_jax(jax_device):
    """Return the Warp device that corresponds to a Jax device.

    Args:
        jax_device: A Jax device; only its ``platform`` and ``id`` attributes
            are read.

    Returns:
        The Warp ``"cpu"`` device when the platform is ``"cpu"``, or the Warp
        CUDA device whose ordinal equals ``jax_device.id`` when the platform
        is ``"gpu"``.

    Raises:
        RuntimeError: For any other platform.
    """
    platform = jax_device.platform

    if platform == "cpu":
        return warp.get_device("cpu")

    if platform == "gpu":
        # NOTE(review): the Jax device id is passed straight through as the CUDA
        # ordinal; confirm this holds when Jax ids are not process-local.
        return warp.get_cuda_device(jax_device.id)

    raise RuntimeError(f"Unknown or unsupported Jax device platform '{jax_device.platform}'")
def to_jax(wp_array):
    """Convert a Warp array to a Jax array via DLPack.

    Args:
        wp_array: The Warp array to export with ``warp.to_dlpack``.

    Returns:
        The Jax array produced by ``jax.dlpack.from_dlpack``.
    """
    import jax.dlpack

    # Export from Warp first, then hand the DLPack object to Jax.
    # NOTE(review): DLPack exchange is normally zero-copy, so the result
    # presumably shares memory with wp_array — confirm for the versions in use.
    exported = warp.to_dlpack(wp_array)
    return jax.dlpack.from_dlpack(exported)
def from_jax(jax_array, dtype=None):
    """Convert a Jax array to a Warp array via DLPack.

    Args:
        jax_array: The Jax array to export with ``jax.dlpack.to_dlpack``.
        dtype: Optional dtype forwarded unchanged to ``warp.from_dlpack``.

    Returns:
        The Warp array produced by ``warp.from_dlpack``.
    """
    import jax.dlpack

    # Export from Jax first, then let Warp import the DLPack object.
    exported = jax.dlpack.to_dlpack(jax_array)
    return warp.from_dlpack(exported, dtype=dtype)
|