Spaces:
Sleeping
Sleeping
# Copyright (c) 2023 NVIDIA CORPORATION.  All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
| import warp | |
def device_to_jax(wp_device):
    """Return the Jax device that corresponds to a Warp device.

    Args:
        wp_device: Any value accepted by ``warp.get_device``.

    Returns:
        For a CUDA Warp device, the entry of ``jax.devices("cuda")`` at the
        Warp device's ordinal; otherwise the first entry of
        ``jax.devices("cpu")``.

    Raises:
        RuntimeError: If Jax reports no device at the required position.
    """
    import jax

    resolved = warp.get_device(wp_device)

    # Non-CUDA Warp devices always map onto the first Jax CPU device.
    if not resolved.is_cuda:
        host_devices = jax.devices("cpu")
        if not host_devices:
            raise RuntimeError(f"Jax device corresponding to '{wp_device}' is not available")
        return host_devices[0]

    # CUDA devices are matched by position: the Warp ordinal indexes Jax's CUDA device list.
    gpu_devices = jax.devices("cuda")
    if resolved.ordinal < len(gpu_devices):
        return gpu_devices[resolved.ordinal]
    raise RuntimeError(f"Jax device corresponding to '{wp_device}' is not available")
def device_from_jax(jax_device):
    """Return the Warp device that corresponds to a Jax device.

    Args:
        jax_device: Object exposing ``platform`` and ``id`` attributes.

    Returns:
        ``warp.get_device("cpu")`` when the platform is ``"cpu"``, or
        ``warp.get_cuda_device(jax_device.id)`` when it is ``"gpu"``.

    Raises:
        RuntimeError: If the platform is neither ``"cpu"`` nor ``"gpu"``.
    """
    platform = jax_device.platform

    if platform == "cpu":
        return warp.get_device("cpu")

    if platform == "gpu":
        # The Jax device id is used directly as the Warp CUDA device ordinal.
        return warp.get_cuda_device(jax_device.id)

    raise RuntimeError(f"Unknown or unsupported Jax device platform '{jax_device.platform}'")
def to_jax(wp_array):
    """Convert a Warp array to a Jax array through DLPack.

    Args:
        wp_array: Value handed to ``warp.to_dlpack``.

    Returns:
        The result of ``jax.dlpack.from_dlpack`` applied to the exported
        DLPack object.
    """
    from jax import dlpack as jax_dlpack

    exported = warp.to_dlpack(wp_array)
    return jax_dlpack.from_dlpack(exported)
def from_jax(jax_array, dtype=None):
    """Convert a Jax array to a Warp array through DLPack.

    Args:
        jax_array: Value handed to ``jax.dlpack.to_dlpack``.
        dtype: Forwarded unchanged to ``warp.from_dlpack`` (defaults to ``None``).

    Returns:
        The result of ``warp.from_dlpack`` applied to the exported DLPack
        object.
    """
    from jax import dlpack as jax_dlpack

    exported = jax_dlpack.to_dlpack(jax_array)
    return warp.from_dlpack(exported, dtype=dtype)