# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm
from typing import Union
from enum import IntEnum
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
from math import prod
from tvm.relay import TensorType
from tvm._ffi.base import _LIB, c_str
from tvm._ffi._ctypes.types import TVMValue, check_call
from tvm._ffi.runtime_ctypes import (
TVMArrayHandle,)
import ctypes
# ctypes signature of the destructor slot stored on a DLPack PyCapsule.
TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
# DLPack capsule names: "dltensor" marks a live capsule, "used_dltensor"
# marks one whose tensor has already been consumed by another framework.
_c_str_dltensor = c_str("dltensor")
_c_str_used_dltensor = c_str("used_dltensor")
def get_values_from_torch_tensors(tensors, num_args):
    """Pack torch tensors into a ctypes array of TVMValue handles.

    Each tensor is exported through DLPack, the DLTensor pointer inside the
    capsule is converted to a TVM array handle via ``TVMArrayFromDLPack``,
    and that handle is stored in the i-th TVMValue slot so the array can be
    passed directly to a packed TVM function.

    Parameters
    ----------
    tensors : sequence of torch.Tensor
        Tensors to hand over to TVM.
    num_args : int
        Number of TVMValue slots to allocate; expected to equal
        ``len(tensors)``.

    Returns
    -------
    ctypes array of TVMValue
        Array whose ``v_handle`` fields point at the converted tensors.

    Raises
    ------
    ValueError
        If a produced capsule is not a valid (unconsumed) "dltensor" capsule.
    """
    values = (TVMValue * num_args)()
    dlpack_tensors = [to_dlpack(torch_tensor) for torch_tensor in tensors]
    for i, dltensor in enumerate(dlpack_tensors):
        dltensor = ctypes.py_object(dltensor)
        if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor):
            ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
            # enforce type to make sure it works for all ctypes
            ptr = ctypes.cast(ptr, ctypes.c_void_p)
            handle = TVMArrayHandle()
            check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
            # ndarray = tvm.runtime.ndarray._make_array(handle, False, False)
            # Rename the capsule and clear its destructor: per the DLPack
            # protocol the consumer (TVM) now owns the DLTensor, so the
            # capsule must not free it again.
            ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
            ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
            # NOTE(review): the handle is never explicitly released here;
            # presumably the TVM runtime/caller frees it — confirm to rule
            # out a leak.
            values[i].v_handle = ctypes.cast(handle, ctypes.c_void_p)
        else:
            raise ValueError("Invalid DLTensor")
    return values
class TensorSupplyType(IntEnum):
    """Strategies for filling generated input tensors (see get_tensor_supply)."""
    Integer = 1  # random integers in [-2, 2]
    Uniform = 2  # uniform floats in [-1.0, 1.0]
    Normal = 3   # in-place normal_(-1.0, 1.0)
    Randn = 4    # torch.randn, then cast to the target dtype
    Zero = 5
    One = 6
def get_tensor_supply(supply_type: TensorSupplyType, opt_shapes: dict = None):
    """Build a factory that materializes torch tensors for TVM tensor specs.

    Parameters
    ----------
    supply_type : TensorSupplyType
        Fill strategy (random ints, uniform, normal, randn, zeros, ones).
    opt_shapes : dict, optional
        Mapping from dynamic (tir.Var) dimension names to concrete ints.
        Required whenever a tensor's shape contains symbolic variables.

    Returns
    -------
    Callable[[TensorType], torch.Tensor]
        Function producing a filled tensor on the current CUDA device.
    """

    def var_wrapper(v, opt_shapes):
        # Resolve one shape dimension: symbolic vars are looked up in
        # opt_shapes, constant dims come straight from the IntImm value.
        if isinstance(v, tvm.tir.Var):
            assert opt_shapes
            assert v.name in opt_shapes
            return opt_shapes[v.name]
        elif isinstance(v, tvm.tir.IntImm):
            return v.value
        else:
            raise RuntimeError("Not supported type: ", type(v))

    def get_tensor(tensor: TensorType) -> torch.Tensor:
        # getattr is the idiomatic spelling of the original
        # torch.__getattribute__ call; behavior is identical.
        dtype = getattr(torch, str(tensor.dtype))
        # NOTE: requires an initialized CUDA context.
        device = torch.cuda.current_device()
        shape = [var_wrapper(i, opt_shapes) for i in tensor.shape]
        if supply_type == TensorSupplyType.Integer:
            return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
        elif supply_type == TensorSupplyType.Uniform:
            return torch.empty(*shape, device=device, dtype=dtype).uniform_(-1.0, 1.0)
        elif supply_type == TensorSupplyType.Normal:
            return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0)
        elif supply_type == TensorSupplyType.Randn:
            return torch.randn(*shape, device=device).to(dtype)
        elif supply_type == TensorSupplyType.Zero:
            return torch.zeros(*shape, device=device, dtype=dtype)
        elif supply_type == TensorSupplyType.One:
            return torch.ones(*shape, device=device, dtype=dtype)
        else:
            raise NotImplementedError(supply_type)

    return get_tensor
def tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]):
    """Convert a TVM tensor object into a ``torch.Tensor``.

    ``tvm.nd.NDArray`` inputs go through DLPack; ``tvm.te.Tensor`` inputs go
    through a numpy host copy. Any other type raises ``RuntimeError``.
    """
    # The two isinstance checks are disjoint, so branch order is irrelevant.
    if isinstance(tensor, tvm.nd.NDArray):
        return from_dlpack(tensor)
    if isinstance(tensor, tvm.te.Tensor):
        return torch.from_numpy(tensor.numpy())
    raise RuntimeError("Not supported type: ", type(tensor))
def lazy_tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]):
    """Convert a TVM NDArray to a ``torch.Tensor`` without an extra host copy.

    The NDArray is first brought to host memory via ``asnumpy()``; the
    returned torch tensor then shares that host buffer.

    Parameters
    ----------
    tensor : tvm.nd.NDArray
        Source array. Other types raise ``RuntimeError``.

    Returns
    -------
    torch.Tensor
        Tensor sharing memory with the intermediate numpy array.
    """
    if isinstance(tensor, tvm.nd.NDArray):
        np_array = tensor.asnumpy()
        # BUGFIX: the previous implementation rebuilt the tensor from the raw
        # data pointer via ctypes/torch.frombuffer, which did NOT keep
        # np_array alive — once it was garbage collected the returned tensor
        # pointed at freed memory. torch.from_numpy shares the same buffer
        # and holds a reference to np_array, with identical shape and dtype.
        return torch.from_numpy(np_array)
    else:
        raise RuntimeError("Not supported type: ", type(tensor))
def lazy_torch_to_tvm_tensor(tensor):
    """Copy a ``torch.Tensor`` into a ``tvm.nd.NDArray``.

    The torch tensor's memory is reinterpreted as a numpy array via ctypes
    (no copy), then handed to ``tvm.nd.array``, which performs the copy.
    Raises ``RuntimeError`` for any other input type.
    """
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError("Not supported type: ", type(tensor))

    def view_as_numpy(address, shape, nbytes, np_dtype):
        # Reinterpret the raw allocation at `address` as a numpy array.
        raw = (ctypes.c_int8 * nbytes).from_address(address)
        return np.frombuffer(raw, dtype=np_dtype).reshape(shape)

    # e.g. "torch.float32" -> "float32", which numpy understands directly.
    np_dtype = str(tensor.dtype).replace("torch.", "")
    nbytes = prod(tensor.shape) * tensor.itemsize
    host_view = view_as_numpy(tensor.data_ptr(), tensor.shape, nbytes, np_dtype)
    # tvm.nd.array copies, so the view only needs to stay valid for this call
    # (the source tensor is alive for the duration as our parameter).
    return tvm.nd.array(host_view)
def np_float2np_bf16(arr):
    """Convert a float32 numpy array to bf16 stored as uint16.

    Uses round-to-nearest-even: add 0x7FFF plus the LSB of the would-be
    bf16 mantissa, then truncate the low 16 bits.
    """
    bits = arr.view("<u4")
    # LSB of the retained 16 bits decides ties toward the even mantissa.
    tie_to_even = (bits >> 16) & 1
    rounded = bits + (tie_to_even + 0x7FFF)
    return (rounded >> 16).astype("uint16")
def np_bf162np_float(arr):
    """Convert a bf16 numpy array (stored as uint16) back to float32.

    bf16 is the top half of an IEEE float32, so widening to uint32 and
    shifting into the high 16 bits reproduces the float exactly.
    """
    widened = arr.astype("uint32") << 16
    return widened.view("<f4")
|