| |
| |
| from bitblas import tvm |
| from typing import Union |
| from enum import IntEnum |
| import numpy as np |
| import torch |
| from torch.utils.dlpack import from_dlpack, to_dlpack |
| from math import prod |
|
|
| from tvm.relay import TensorType |
| from tvm._ffi.base import _LIB, c_str |
| from tvm._ffi._ctypes.types import TVMValue, check_call |
| from tvm._ffi.runtime_ctypes import ( |
| TVMArrayHandle,) |
| import ctypes |
|
|
# ctypes prototype of a PyCapsule destructor: void (*)(void *).
TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
# Capsule names from the DLPack exchange convention: "dltensor" marks a
# live, unconsumed capsule; "used_dltensor" marks one whose contents have
# already been taken over by a consumer.
_c_str_dltensor = c_str("dltensor")
_c_str_used_dltensor = c_str("used_dltensor")
|
|
|
|
def get_values_from_torch_tensors(tensors, num_args):
    """Pack torch tensors into a C array of TVMValue handles suitable for a
    low-level TVM function call.

    Each tensor is exported via DLPack; ``TVMArrayFromDLPack`` wraps the
    DLTensor in a TVM array handle, after which the capsule is renamed to
    "used_dltensor" and its destructor cleared so the DLTensor cannot be
    consumed (or freed) a second time — TVM now owns it.

    Parameters
    ----------
    tensors : sequence of torch.Tensor
    num_args : int
        Length of the TVMValue array; callers pass len(tensors).

    Returns
    -------
    (TVMValue * num_args) ctypes array whose ``v_handle`` fields point at
    the TVM array handles.

    Raises
    ------
    ValueError
        If a produced capsule is not a valid "dltensor" capsule.
    """
    values = (TVMValue * num_args)()
    dlpack_tensors = [to_dlpack(torch_tensor) for torch_tensor in tensors]
    for i, dltensor in enumerate(dlpack_tensors):
        # ctypes.pythonapi calls take a py_object, not the raw capsule.
        dltensor = ctypes.py_object(dltensor)
        if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor):
            ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
            # Normalize the returned pointer to c_void_p before handing it to TVM.
            ptr = ctypes.cast(ptr, ctypes.c_void_p)
            handle = TVMArrayHandle()
            check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
            # Mark the capsule consumed (DLPack convention) and drop its
            # destructor — must happen AFTER TVMArrayFromDLPack succeeds.
            ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
            ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
            values[i].v_handle = ctypes.cast(handle, ctypes.c_void_p)
        else:
            raise ValueError("Invalid DLTensor")
    return values
|
|
|
|
# Strategies for filling the input tensors handed to a kernel under test.
# Functional IntEnum form: members are auto-numbered 1..6 in declaration
# order, identical to the explicit values they replace.
TensorSupplyType = IntEnum(
    "TensorSupplyType",
    ["Integer", "Uniform", "Normal", "Randn", "Zero", "One"],
)
TensorSupplyType.__doc__ = "How to generate input tensor contents."
|
|
|
|
def get_tensor_supply(supply_type: TensorSupplyType, opt_shapes: dict = None):
    """Return a factory that builds a random torch tensor matching a TVM
    tensor spec, filled according to ``supply_type``.

    Parameters
    ----------
    supply_type : TensorSupplyType
        Fill strategy (small integers, uniform, normal, ...).
    opt_shapes : dict, optional
        Mapping from dynamic-dimension (tir.Var) names to concrete ints;
        required whenever the spec contains symbolic shape entries.

    Returns
    -------
    Callable[[TensorType], torch.Tensor]
        ``get_tensor(tensor)`` allocating on the current CUDA device.
    """

    def _concrete_dim(v, opt_shapes):
        # Resolve one shape entry: a tir.Var is looked up by name in
        # opt_shapes; a tir.IntImm yields its literal value.
        if isinstance(v, tvm.tir.Var):
            assert opt_shapes
            assert v.name in opt_shapes
            return opt_shapes[v.name]
        elif isinstance(v, tvm.tir.IntImm):
            return v.value
        else:
            raise RuntimeError("Not supported type: ", type(v))

    def get_tensor(tensor: TensorType) -> torch.Tensor:
        # getattr is the idiomatic way to map a dtype string such as
        # "float16" onto the torch dtype object (was torch.__getattribute__).
        dtype = getattr(torch, str(tensor.dtype))
        # NOTE(review): assumes a CUDA device is available — confirm callers
        # only invoke this in GPU contexts.
        device = torch.cuda.current_device()
        shape = [_concrete_dim(i, opt_shapes) for i in tensor.shape]
        if supply_type == TensorSupplyType.Integer:
            # Small integer range keeps low-precision accumulations exact.
            return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
        elif supply_type == TensorSupplyType.Uniform:
            return torch.empty(*shape, device=device, dtype=dtype).uniform_(-1.0, 1.0)
        elif supply_type == TensorSupplyType.Normal:
            return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0)
        elif supply_type == TensorSupplyType.Randn:
            # randn does not support all dtypes directly, so sample then cast.
            return torch.randn(*shape, device=device).to(dtype)
        elif supply_type == TensorSupplyType.Zero:
            return torch.zeros(*shape, device=device, dtype=dtype)
        elif supply_type == TensorSupplyType.One:
            return torch.ones(*shape, device=device, dtype=dtype)
        else:
            raise NotImplementedError(supply_type)

    return get_tensor
|
|
|
|
def tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]):
    """Convert a TVM tensor object to a torch.Tensor.

    An ``nd.NDArray`` is shared via DLPack; a ``te.Tensor`` is round-tripped
    through numpy. Any other type raises RuntimeError.
    """
    if isinstance(tensor, tvm.nd.NDArray):
        return from_dlpack(tensor)
    if isinstance(tensor, tvm.te.Tensor):
        return torch.from_numpy(tensor.numpy())
    raise RuntimeError("Not supported type: ", type(tensor))
|
|
|
|
def lazy_tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]):
    """Convert a tvm.nd.NDArray to a torch.Tensor without copying the
    intermediate host numpy array a second time.

    Bug fix: the previous implementation rebuilt the buffer through a raw
    ``ctypes.from_address`` view of a *local* numpy array and returned a
    ``torch.frombuffer`` tensor over it. Nothing kept that numpy array
    alive, so once it was garbage-collected the returned tensor pointed at
    freed memory. ``torch.from_numpy`` shares the same buffer zero-copy
    AND holds a reference to the numpy array, which fixes the lifetime.

    Raises
    ------
    RuntimeError
        If ``tensor`` is not a tvm.nd.NDArray.
    """
    if isinstance(tensor, tvm.nd.NDArray):
        # NOTE: asnumpy() still performs one device->host copy; only the
        # numpy->torch step below is zero-copy.
        np_array = tensor.asnumpy()
        return torch.from_numpy(np_array)
    else:
        raise RuntimeError("Not supported type: ", type(tensor))
|
|
|
|
def lazy_torch_to_tvm_tensor(tensor):
    """Convert a torch.Tensor to a tvm.nd.NDArray by wrapping the tensor's
    buffer in a numpy view and letting ``tvm.nd.array`` copy from it.

    The numpy view is zero-copy; ``tvm.nd.array`` then copies while the
    torch tensor is still alive in this frame, so no dangling reference
    escapes the function.

    Raises
    ------
    RuntimeError
        If ``tensor`` is not a torch.Tensor.
    """

    def as_tensor(address, shape, elems_inbytes, numpy_type):
        # Zero-copy reinterpretation of raw memory at `address` as a numpy
        # array; valid only while the owner of that memory stays alive.
        arr = (ctypes.c_int8 * elems_inbytes).from_address(address)
        return np.frombuffer(arr, dtype=numpy_type).reshape(shape)

    if isinstance(tensor, torch.Tensor):
        # NOTE(review): raw-pointer reinterpretation assumes `tensor` is a
        # contiguous CPU tensor — a CUDA data_ptr or non-contiguous strides
        # would produce garbage here; confirm against callers.
        data_ptr = tensor.data_ptr()
        shape = tensor.shape
        torch_dtype = tensor.dtype
        # e.g. torch.float16 -> "float16"; presumably every dtype used here
        # has a same-named numpy equivalent — verify for exotic dtypes.
        numpy_dtype = str(torch_dtype).replace("torch.", "")
        num_elems_inbytes = prod(shape) * tensor.itemsize
        np_tensor = as_tensor(data_ptr, shape, num_elems_inbytes, numpy_dtype)
        tvm_tensor = tvm.nd.array(np_tensor)
        return tvm_tensor
    else:
        raise RuntimeError("Not supported type: ", type(tensor))
|
|
|
|
def np_float2np_bf16(arr):
    """Convert a float32 numpy array to bfloat16 stored as uint16.

    Round-to-nearest-even: add 0x7FFF plus the lowest retained mantissa
    bit to the raw float32 bits, then keep the top 16 bits.
    """
    bits = arr.view("<u4")
    rounding = ((bits >> 16) & 1) + 0x7FFF
    return ((bits + rounding) >> 16).astype("uint16")
|
|
|
|
def np_bf162np_float(arr):
    """Expand bfloat16 values (held as uint16) back to float32.

    bf16 is the high half of an IEEE-754 float32, so shifting the bits
    left by 16 and reinterpreting them yields the exact float value.
    """
    widened = arr.astype("uint32") << 16
    return widened.view("<f4")
|
|