Spaces:
Sleeping
Sleeping
File size: 13,367 Bytes
66c9c8a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 | # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import warp
import ctypes
from warp.thirdparty.dlpack import (
DLManagedTensor,
DLDevice,
DLDeviceType,
DLDataType,
DLDataTypeCode,
DLTensor,
_c_str_dltensor,
)
# Declare result/argument types for the CPython C-API functions used below.
# Without explicit restype/argtypes, ctypes assumes c_int and would truncate
# 64-bit pointers on LP64 platforms.
ctypes.pythonapi.PyMem_RawMalloc.restype = ctypes.c_void_p
ctypes.pythonapi.PyMem_RawFree.argtypes = [ctypes.c_void_p]
# PyCapsule_New(pointer, name, destructor) -> new capsule object
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
ctypes.pythonapi.PyCapsule_IsValid.restype = ctypes.c_int
ctypes.pythonapi.PyCapsule_IsValid.argtypes = [ctypes.py_object, ctypes.c_char_p]
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
class _Holder:
    """Keeps a Warp array alive while an exported DLPack tensor references it."""

    def __init__(self, wp_array) -> None:
        self.wp_array = wp_array

    def _as_manager_ctx(self) -> ctypes.c_void_p:
        """Return a C pointer suitable for DLManagedTensor.manager_ctx.

        Bumps the refcount of both the holder and its py_object wrapper so
        neither is collected while the consumer owns the tensor; the matching
        Py_DecRef calls happen in the deleter callback.
        """
        obj_ref = ctypes.py_object(self)
        obj_ref_ptr = ctypes.pointer(obj_ref)
        ctypes.pythonapi.Py_IncRef(obj_ref)
        ctypes.pythonapi.Py_IncRef(ctypes.py_object(obj_ref_ptr))
        return ctypes.cast(obj_ref_ptr, ctypes.c_void_p)
@ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def _warp_array_deleter(handle: ctypes.c_void_p) -> None:
    """A function to deallocate the memory of a Warp array."""
    # `handle` is the address of the DLManagedTensor allocated via
    # PyMem_RawMalloc in to_dlpack().
    dl_managed_tensor = DLManagedTensor.from_address(handle)
    # Recover the pointer-to-py_object stashed in manager_ctx by
    # _Holder._as_manager_ctx() and undo both Py_IncRef calls made there.
    py_obj_ptr = ctypes.cast(dl_managed_tensor.manager_ctx, ctypes.POINTER(ctypes.py_object))
    py_obj = py_obj_ptr.contents
    ctypes.pythonapi.Py_DecRef(py_obj)
    ctypes.pythonapi.Py_DecRef(ctypes.py_object(py_obj_ptr))
    # Release the raw allocation holding the DLManagedTensor struct itself.
    ctypes.pythonapi.PyMem_RawFree(handle)
@ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def _warp_pycapsule_deleter(handle: ctypes.c_void_p) -> None:
    """A function to deallocate a pycapsule that wraps a Warp array."""
    pycapsule: ctypes.py_object = ctypes.cast(handle, ctypes.py_object)
    # Only free if the capsule still carries the "dltensor" name: a consumer
    # that took ownership renames the capsule, making it invalid here, in
    # which case the consumer's deleter is responsible for cleanup.
    if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
        dl_managed_tensor = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)
        _warp_array_deleter(dl_managed_tensor)
        # Clear the destructor so it cannot run a second time.
        ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, None)
def device_to_dlpack(wp_device) -> DLDevice:
    """Build the DLDevice descriptor corresponding to a Warp device.

    Raises RuntimeError for devices that are neither CPU nor CUDA.
    """
    device = warp.get_device(wp_device)
    if device.is_cpu:
        kind, ordinal = DLDeviceType.kDLCPU, 0
    elif device.is_cuda:
        kind, ordinal = DLDeviceType.kDLCUDA, device.ordinal
    else:
        raise RuntimeError("Unhandled device type converting to dlpack")
    result = DLDevice()
    result.device_type = kind
    result.device_id = ordinal
    return result
def dtype_to_dlpack(wp_dtype) -> DLDataType:
    """Map a Warp scalar dtype to its DLPack (type_code, bits, lanes) triple.

    Raises RuntimeError when the dtype has no DLPack equivalent.
    """
    conversion_table = {
        warp.int8: (DLDataTypeCode.kDLInt, 8, 1),
        warp.uint8: (DLDataTypeCode.kDLUInt, 8, 1),
        warp.int16: (DLDataTypeCode.kDLInt, 16, 1),
        warp.uint16: (DLDataTypeCode.kDLUInt, 16, 1),
        warp.int32: (DLDataTypeCode.kDLInt, 32, 1),
        warp.uint32: (DLDataTypeCode.kDLUInt, 32, 1),
        warp.int64: (DLDataTypeCode.kDLInt, 64, 1),
        warp.uint64: (DLDataTypeCode.kDLUInt, 64, 1),
        warp.float16: (DLDataTypeCode.kDLFloat, 16, 1),
        warp.float32: (DLDataTypeCode.kDLFloat, 32, 1),
        warp.float64: (DLDataTypeCode.kDLFloat, 64, 1),
    }
    if wp_dtype in conversion_table:
        return conversion_table[wp_dtype]
    raise RuntimeError(f"No conversion from Warp type {wp_dtype} to DLPack type")
def dtype_from_dlpack(dl_dtype):
    """Map a DLPack DLDataType to the corresponding Warp scalar dtype.

    Raises RuntimeError for bit booleans, complex types, and anything
    without a Warp equivalent.
    """
    # unpack to tuple for easier comparison
    key = (dl_dtype.type_code.value, dl_dtype.bits)
    if key == (DLDataTypeCode.kDLUInt, 1):
        raise RuntimeError("Warp does not support bit boolean types")
    if key in ((DLDataTypeCode.kDLComplex, 64), (DLDataTypeCode.kDLComplex, 128)):
        raise RuntimeError("Warp does not support complex types")
    scalar_table = {
        (DLDataTypeCode.kDLInt, 8): warp.types.int8,
        (DLDataTypeCode.kDLInt, 16): warp.types.int16,
        (DLDataTypeCode.kDLInt, 32): warp.types.int32,
        (DLDataTypeCode.kDLInt, 64): warp.types.int64,
        (DLDataTypeCode.kDLUInt, 8): warp.types.uint8,
        (DLDataTypeCode.kDLUInt, 16): warp.types.uint16,
        (DLDataTypeCode.kDLUInt, 32): warp.types.uint32,
        (DLDataTypeCode.kDLUInt, 64): warp.types.uint64,
        (DLDataTypeCode.kDLFloat, 16): warp.types.float16,
        (DLDataTypeCode.kDLFloat, 32): warp.types.float32,
        (DLDataTypeCode.kDLFloat, 64): warp.types.float64,
    }
    if key in scalar_table:
        return scalar_table[key]
    raise RuntimeError(f"Unknown dlpack datatype {key}")
def device_from_dlpack(dl_device):
    """Return the Warp device string ("cpu" or "cuda:N") for a DLDevice.

    Raises RuntimeError for device types Warp cannot represent.
    """
    kind = dl_device.device_type.value
    # CUDA-pinned host memory maps to the Warp CPU device.
    if kind in (DLDeviceType.kDLCPU, DLDeviceType.kDLCUDAHost):
        return "cpu"
    # Managed (unified) memory is addressed through the owning CUDA device.
    if kind in (DLDeviceType.kDLCUDA, DLDeviceType.kDLCUDAManaged):
        return f"cuda:{dl_device.device_id}"
    raise RuntimeError(f"Unknown device type from dlpack: {dl_device.device_type.value}")
def shape_to_dlpack(shape):
    """Pack a shape tuple into a ctypes int64 array for DLTensor.shape."""
    return (ctypes.c_int64 * len(shape))(*shape)
def strides_to_dlpack(strides, dtype):
    """Convert byte strides to element strides packed for DLTensor.strides.

    Parameters
    ----------
    strides : Sequence[int]
        Strides in bytes, one entry per dimension. Each entry is assumed to
        be a multiple of the dtype size.
    dtype
        Warp scalar dtype whose size (in bytes) divides each stride.

    Returns
    -------
    ctypes array of c_int64 holding the strides in element counts.
    """
    item_size = int(warp.types.type_size_in_bytes(dtype))
    # Use integer floor division: the previous int(a / b) went through float
    # division, which silently loses precision for strides above 2**53 bytes.
    element_strides = [int(s) // item_size for s in strides]
    return (ctypes.c_int64 * len(strides))(*element_strides)
def to_dlpack(wp_array: warp.array):
    """Convert a Warp array to a DLPack capsule consumable by other frameworks.

    Parameters
    ----------
    wp_array : warp.array
        The source Warp array that will be exported.

    Returns
    -------
    pycapsule : PyCapsule
        A pycapsule containing a DLManagedTensor that can be converted
        to other array formats without copying the underlying memory.

    Raises
    ------
    RuntimeError
        If the array has a structured dtype, which DLPack cannot represent.
    """
    # DLPack does not support structured arrays
    if isinstance(wp_array.dtype, warp.codegen.Struct):
        raise RuntimeError("Cannot convert structured Warp arrays to DLPack.")

    # The holder keeps the source array (and the shape/stride buffers below)
    # alive until the consumer invokes the deleter.
    holder = _Holder(wp_array)

    # allocate DLManagedTensor (freed by _warp_array_deleter via PyMem_RawFree)
    size = ctypes.c_size_t(ctypes.sizeof(DLManagedTensor))
    dl_managed_tensor = DLManagedTensor.from_address(ctypes.pythonapi.PyMem_RawMalloc(size))

    # handle vector types
    if hasattr(wp_array.dtype, "_wp_scalar_type_"):
        # vector type, flatten the dimensions into one tuple
        target_dtype = wp_array.dtype._wp_scalar_type_
        target_ndim = wp_array.ndim + len(wp_array.dtype._shape_)
        target_shape = (*wp_array.shape, *wp_array.dtype._shape_)
        dtype_strides = warp.types.strides_from_shape(wp_array.dtype._shape_, wp_array.dtype._wp_scalar_type_)
        target_strides = (*wp_array.strides, *dtype_strides)
    else:
        # scalar type
        target_dtype = wp_array.dtype
        target_ndim = wp_array.ndim
        target_shape = wp_array.shape
        target_strides = wp_array.strides

    # store the shape and stride arrays with the holder to prevent them from getting deallocated
    holder._shape = shape_to_dlpack(target_shape)
    holder._strides = strides_to_dlpack(target_strides, target_dtype)

    if wp_array.pinned:
        # Pinned host memory is advertised as kDLCUDAHost. Build a proper
        # DLDevice struct: assigning the bare enum value to dl_tensor.device
        # (a DLDevice structure field) would raise a TypeError in ctypes.
        dl_device = DLDevice()
        dl_device.device_type = DLDeviceType.kDLCUDAHost
        dl_device.device_id = 0
    else:
        dl_device = device_to_dlpack(wp_array.device)

    # set Tensor attributes
    dl_managed_tensor.dl_tensor.data = wp_array.ptr
    dl_managed_tensor.dl_tensor.device = dl_device
    dl_managed_tensor.dl_tensor.ndim = target_ndim
    dl_managed_tensor.dl_tensor.dtype = dtype_to_dlpack(target_dtype)
    dl_managed_tensor.dl_tensor.shape = holder._shape
    dl_managed_tensor.dl_tensor.strides = holder._strides
    dl_managed_tensor.dl_tensor.byte_offset = 0
    dl_managed_tensor.manager_ctx = holder._as_manager_ctx()
    dl_managed_tensor.deleter = _warp_array_deleter

    pycapsule = ctypes.pythonapi.PyCapsule_New(
        ctypes.byref(dl_managed_tensor),
        _c_str_dltensor,
        _warp_pycapsule_deleter,
    )
    return pycapsule
def dtype_is_compatible(dl_dtype, wp_dtype) -> bool:
    """Check whether a DLPack dtype can be viewed as the given Warp dtype.

    Parameters
    ----------
    dl_dtype
        A DLDataType (has ``type_code`` and ``bits``).
    wp_dtype
        A Warp scalar dtype to compare against.

    Returns
    -------
    bool
        True when the bit width and type class match. Signed/unsigned
        integers of the same width are considered interchangeable.

    Raises
    ------
    RuntimeError
        For sub-byte, bfloat, complex, or unrecognized type codes.
    """
    if dl_dtype.bits % 8 != 0:
        raise RuntimeError("Data types with less than 8 bits are not supported")
    type_code = dl_dtype.type_code.value
    if type_code == DLDataTypeCode.kDLFloat:
        if dl_dtype.bits == 16:
            return wp_dtype == warp.float16
        elif dl_dtype.bits == 32:
            return wp_dtype == warp.float32
        elif dl_dtype.bits == 64:
            return wp_dtype == warp.float64
        # unsupported float width (e.g. 128-bit): previously fell through and
        # returned None implicitly; return False explicitly instead
        return False
    elif type_code == DLDataTypeCode.kDLInt or type_code == DLDataTypeCode.kDLUInt:
        if dl_dtype.bits == 8:
            return wp_dtype == warp.int8 or wp_dtype == warp.uint8
        elif dl_dtype.bits == 16:
            return wp_dtype == warp.int16 or wp_dtype == warp.uint16
        elif dl_dtype.bits == 32:
            return wp_dtype == warp.int32 or wp_dtype == warp.uint32
        elif dl_dtype.bits == 64:
            return wp_dtype == warp.int64 or wp_dtype == warp.uint64
        # unsupported integer width (e.g. 128-bit)
        return False
    elif type_code == DLDataTypeCode.kDLBfloat:
        raise RuntimeError("Bfloat data type is not supported")
    elif type_code == DLDataTypeCode.kDLComplex:
        raise RuntimeError("Complex data types are not supported")
    else:
        raise RuntimeError(f"Unsupported dlpack dtype {(str(dl_dtype.type_code), dl_dtype.bits)}")
def from_dlpack(pycapsule, dtype=None) -> warp.array:
    """Convert a DLPack tensor capsule into a Warp array without copying.

    Parameters
    ----------
    pycapsule : PyCapsule
        A pycapsule wrapping a dlpack tensor that will be converted.
    dtype : optional
        Warp dtype to interpret the data as. When None the scalar dtype is
        auto-detected; a vector/matrix dtype absorbs the matching trailing
        dimensions of the source tensor.

    Returns
    -------
    warp.array
        A new Warp array that uses the same underlying memory as the input
        pycapsule (no copy is made; the capsule is kept referenced).
    """
    assert ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor)
    # Unwrap the capsule to the DLManagedTensor struct it carries.
    dl_managed_tensor = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)
    dl_managed_tensor_ptr = ctypes.cast(dl_managed_tensor, ctypes.POINTER(DLManagedTensor))
    dl_managed_tensor = dl_managed_tensor_ptr.contents
    dlt = dl_managed_tensor.dl_tensor
    assert isinstance(dlt, DLTensor)
    device = device_from_dlpack(dlt.device)
    pinned = dlt.device.device_type.value == DLDeviceType.kDLCUDAHost
    shape = tuple(dlt.shape[dim] for dim in range(dlt.ndim))
    # DLPack strides are in elements; Warp expects bytes, so scale by itemsize.
    itemsize = dlt.dtype.bits // 8
    if dlt.strides:
        strides = tuple(dlt.strides[dim] * itemsize for dim in range(dlt.ndim))
    else:
        # NULL strides: the tensor is compact and row-major.
        strides = None

    # handle multi-lane dtypes as another dimension
    if dlt.dtype.lanes > 1:
        shape = (*shape, dlt.dtype.lanes)
        if strides is not None:
            strides = (*strides, itemsize)

    if dtype is None:
        # automatically detect dtype
        dtype = dtype_from_dlpack(dlt.dtype)

    elif hasattr(dtype, "_wp_scalar_type_"):
        # handle vector/matrix types
        if not dtype_is_compatible(dlt.dtype, dtype._wp_scalar_type_):
            raise RuntimeError(f"Incompatible data types: {dlt.dtype} and {dtype}")

        dtype_shape = dtype._shape_
        dtype_dims = len(dtype._shape_)
        # the requested vector/matrix shape must match the innermost dims
        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:
            raise RuntimeError(
                f"Could not convert DLPack tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}"
            )

        if strides is not None:
            # ensure the inner strides are contiguous
            stride = itemsize
            for i in range(dtype_dims):
                if strides[-i - 1] != stride:
                    raise RuntimeError(
                        f"Could not convert DLPack tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
                    )
                stride *= dtype_shape[-i - 1]
            # fold the inner dims into the dtype; keep at least one dim
            strides = tuple(strides[:-dtype_dims]) or (itemsize,)

        shape = tuple(shape[:-dtype_dims]) or (1,)

    elif not dtype_is_compatible(dlt.dtype, dtype):
        # incompatible dtype requested
        raise RuntimeError(f"Incompatible data types: {dlt.dtype} and {dtype}")

    # owner=False: the capsule's deleter, not Warp, frees the memory
    a = warp.types.array(
        ptr=dlt.data, dtype=dtype, shape=shape, strides=strides, copy=False, owner=False, device=device, pinned=pinned
    )
    # keep a reference to the capsule so it doesn't get deleted
    a._pycapsule = pycapsule
    return a
|