# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import ctypes
import numpy
import warp
# return the warp device corresponding to a torch device
def device_from_torch(torch_device):
    """Return the Warp device that corresponds to a torch device.

    Args:
        torch_device: The torch device. It is converted with ``str()`` and the
            resulting string is handed to ``warp.get_device``.

    Returns:
        The object returned by ``warp.get_device`` for that string.
    """
    device_alias = str(torch_device)
    return warp.get_device(device_alias)
def device_to_torch(wp_device):
    """Return the torch device string that corresponds to a Warp device.

    Args:
        wp_device: Anything accepted by ``warp.get_device``.

    Returns:
        str: ``str(device)`` for CPU devices and primary devices, or
        ``"cuda:<ordinal>"`` for a non-primary CUDA device that reports UVA.

    Raises:
        RuntimeError: If the device matches none of the cases above.
    """
    resolved = warp.get_device(wp_device)

    # CPU devices and primary contexts are returned by their own string form
    if resolved.is_cpu or resolved.is_primary:
        return str(resolved)

    if not (resolved.is_cuda and resolved.is_uva):
        raise RuntimeError(f"Warp device {resolved} is not compatible with torch")

    # it's not a primary context, but torch can access the data ptr directly thanks to UVA
    return f"cuda:{resolved.ordinal}"
def dtype_from_torch(torch_dtype):
    """Return the Warp dtype corresponding to a torch dtype.

    The lookup table is cached on the function object (``type_map``) and is
    built on the first call so that torch is only imported when needed.

    Args:
        torch_dtype: The torch dtype to translate.

    Returns:
        The Warp type registered for ``torch_dtype`` in the lookup table.

    Raises:
        TypeError: If ``torch_dtype`` has no entry in the lookup table.
    """
    table = dtype_from_torch.type_map

    if table is None:
        # first call: import torch lazily and populate the cache
        import torch

        table = {
            torch.float64: warp.float64,
            torch.float32: warp.float32,
            torch.float16: warp.float16,
            torch.int64: warp.int64,
            torch.int32: warp.int32,
            torch.int16: warp.int16,
            torch.int8: warp.int8,
            torch.uint8: warp.uint8,
            torch.bool: warp.bool,
            # torch.bfloat16, torch.complex64 and torch.complex128
            # are currently unsupported by Warp
        }
        dtype_from_torch.type_map = table

    mapped = table.get(torch_dtype)
    if mapped is None:
        raise TypeError(f"Invalid or unsupported data type: {torch_dtype}")
    return mapped


# cache slot for the lazily-built lookup table
dtype_from_torch.type_map = None
def dtype_is_compatible(torch_dtype, warp_dtype):
    """Evaluate whether a torch dtype is compatible with a Warp dtype.

    The table of compatible types is cached on the function object
    (``compatible_sets``) and is built on the first call so that torch is only
    imported when needed.

    Args:
        torch_dtype: The torch dtype of the source data.
        warp_dtype: The Warp dtype to check. If it has a ``_wp_scalar_type_``
            attribute, that attribute is what gets compared.

    Returns:
        bool: True if the Warp type is in the set registered for ``torch_dtype``.

    Raises:
        TypeError: If ``torch_dtype`` has no entry in the table.
    """
    table = dtype_is_compatible.compatible_sets

    if table is None:
        # first call: import torch lazily and populate the cache
        import torch

        table = {
            torch.float64: {warp.float64},
            torch.float32: {warp.float32},
            torch.float16: {warp.float16},
            # allow aliasing integer tensors as signed or unsigned integer arrays
            torch.int64: {warp.int64, warp.uint64},
            torch.int32: {warp.int32, warp.uint32},
            torch.int16: {warp.int16, warp.uint16},
            torch.int8: {warp.int8, warp.uint8},
            torch.uint8: {warp.uint8, warp.int8},
            torch.bool: {warp.bool, warp.uint8, warp.int8},
            # torch.bfloat16, torch.complex64 and torch.complex128
            # are currently unsupported by Warp
        }
        dtype_is_compatible.compatible_sets = table

    allowed = table.get(torch_dtype)
    if allowed is None:
        raise TypeError(f"Invalid or unsupported data type: {torch_dtype}")

    # types that carry a `_wp_scalar_type_` attribute are compared by that
    # scalar type; everything else is compared directly
    candidate = getattr(warp_dtype, "_wp_scalar_type_", warp_dtype)
    return candidate in allowed


# cache slot for the lazily-built compatibility table
dtype_is_compatible.compatible_sets = None
# wrap a torch tensor to a wp array, data is not copied
def from_torch(t, dtype=None, requires_grad=None, grad=None):
    """Wrap a PyTorch tensor to a Warp array without copying the data.

    Args:
        t (torch.Tensor): The torch tensor to wrap.
        dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.
        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.
        grad (warp.array or torch.Tensor, optional): Gradient to attach to the resulting array. A `warp.array` is used as-is; a torch tensor is wrapped with the same `dtype`. When given, the tensor's own `.grad` is not consulted.

    Returns:
        warp.array: The wrapped array.

    Raises:
        RuntimeError: If `dtype` is not compatible with the tensor's dtype, or if `dtype` has a `_shape_` and the tensor's trailing dimensions or inner strides do not match it.
        TypeError: If the tensor's dtype is not supported by the dtype lookup helpers.
        ValueError: If `grad` is neither a `warp.array` nor a `torch.Tensor`.
    """
    if dtype is None:
        dtype = dtype_from_torch(t.dtype)
    elif not dtype_is_compatible(t.dtype, dtype):
        raise RuntimeError(f"Incompatible data types: {t.dtype} and {dtype}")

    # torch reports strides in elements; scale by the size of the underlying
    # ctype to express them in bytes
    elem_size = ctypes.sizeof(dtype._type_)

    shape = tuple(t.shape)
    strides = tuple(elem_size * step for step in t.stride())

    # a target type with a `_shape_` absorbs the trailing tensor dimensions:
    # verify that they match, then drop them from the array shape
    if hasattr(dtype, "_shape_"):
        inner_shape = dtype._shape_
        inner_rank = len(inner_shape)

        if inner_rank > len(shape) or inner_shape != shape[-inner_rank:]:
            raise RuntimeError(
                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {inner_shape}"
            )

        # walk the inner dimensions from the innermost outwards and make sure
        # each byte stride is exactly what a densely packed layout would give
        expected = elem_size
        for actual, extent in zip(reversed(strides), reversed(inner_shape)):
            if actual != expected:
                raise RuntimeError(
                    f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
                )
            expected *= extent

        # fall back to a single element when no outer dimensions remain
        shape = tuple(shape[:-inner_rank]) or (1,)
        strides = tuple(strides[:-inner_rank]) or (elem_size,)

    if requires_grad is None:
        requires_grad = t.requires_grad

    if grad is not None:
        # an explicit gradient was supplied; wrap it unless it is already a warp array
        if not isinstance(grad, warp.array):
            import torch

            if not isinstance(grad, torch.Tensor):
                raise ValueError(f"Invalid gradient type: {type(grad)}")
            grad = from_torch(grad, dtype=dtype)
    elif requires_grad:
        # wrap the tensor's own gradient, creating a zero-filled one first if needed
        if t.grad is None:
            import torch

            t.grad = torch.zeros_like(t, requires_grad=False)
        grad = from_torch(t.grad, dtype=dtype)

    wrapped = warp.types.array(
        ptr=t.data_ptr(),
        dtype=dtype,
        shape=shape,
        strides=strides,
        device=device_from_torch(t.device),
        copy=False,
        owner=False,
        grad=grad,
        requires_grad=requires_grad,
    )

    # save a reference to the source tensor, otherwise it will be deallocated
    wrapped._tensor = t
    return wrapped
def to_torch(a, requires_grad=None):
    """Convert a Warp array to a PyTorch tensor without copying the data.

    Args:
        a (warp.array): The Warp array to convert.
        requires_grad (bool, optional): Whether the resulting tensor should convert the array's gradient, if it exists, to a grad tensor. Defaults to the array's `requires_grad` value.

    Returns:
        torch.Tensor: The converted tensor.

    Raises:
        RuntimeError: If the array's dtype is a Warp struct, or its device is neither a CPU nor a CUDA device.
    """
    import torch

    if requires_grad is None:
        requires_grad = a.requires_grad

    # Torch does not support structured arrays
    if isinstance(a.dtype, warp.codegen.Struct):
        raise RuntimeError("Cannot convert structured Warp arrays to Torch.")

    # pick the wrapping strategy for this device; it is applied to the array
    # itself and, when requested, to its gradient
    if a.device.is_cpu:

        def wrap(arr):
            # Torch has an issue wrapping CPU objects that support the
            # __array_interface__ protocol, so go through an ndarray first,
            # see https://pearu.github.io/array_interface_pytorch.html
            return torch.as_tensor(numpy.asarray(arr))

    elif a.device.is_cuda:

        def wrap(arr):
            # Torch does support the __cuda_array_interface__ correctly,
            # but we must be sure to maintain a reference to the owning
            # object to prevent memory allocs going out of scope
            return torch.as_tensor(arr, device=device_to_torch(a.device))

    else:
        raise RuntimeError("Unsupported device")

    tensor = wrap(a)
    tensor.requires_grad = requires_grad
    if requires_grad and a.requires_grad:
        tensor.grad = wrap(a.grad)
    return tensor
def stream_from_torch(stream_or_device=None):
    """Convert from a PyTorch CUDA stream to a Warp.Stream.

    Args:
        stream_or_device: Either a ``torch.cuda.Stream`` or a value accepted by
            ``torch.cuda.current_stream`` (including ``None``), in which case
            the stream returned by that call is used.

    Returns:
        warp.Stream: A Warp stream created from the torch stream's
        ``cuda_stream`` handle on the matching Warp device.
    """
    import torch

    # anything that is not already a torch stream is assumed to be a torch device
    torch_stream = (
        stream_or_device
        if isinstance(stream_or_device, torch.cuda.Stream)
        else torch.cuda.current_stream(stream_or_device)
    )

    wp_device = device_from_torch(torch_stream.device)
    wp_stream = warp.Stream(wp_device, cuda_stream=torch_stream.cuda_stream)

    # save a reference to the source stream, otherwise it may be destroyed
    wp_stream._torch_stream = torch_stream
    return wp_stream
def stream_to_torch(stream_or_device=None):
    """Convert from a Warp.Stream to a PyTorch CUDA stream.

    Args:
        stream_or_device: Either a ``warp.Stream`` or a value accepted by
            ``warp.get_device`` (including ``None``), in which case that
            device's ``stream`` attribute is used.

    Returns:
        torch.cuda.ExternalStream: A torch stream created from the Warp
        stream's ``cuda_stream`` handle on the matching torch device.
    """
    import torch

    # anything that is not already a warp stream is assumed to be a warp device
    wp_stream = (
        stream_or_device
        if isinstance(stream_or_device, warp.Stream)
        else warp.get_device(stream_or_device).stream
    )

    torch_device = device_to_torch(wp_stream.device)
    external = torch.cuda.ExternalStream(wp_stream.cuda_stream, device=torch_device)

    # save a reference to the source stream, otherwise it may be destroyed
    external._warp_stream = wp_stream
    return external
|