File size: 5,616 Bytes
9823a7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm
from typing import Union
from enum import IntEnum
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
from math import prod

from tvm.relay import TensorType
from tvm._ffi.base import _LIB, c_str
from tvm._ffi._ctypes.types import TVMValue, check_call
from tvm._ffi.runtime_ctypes import (
    TVMArrayHandle,)
import ctypes

# ctypes signature of a PyCapsule destructor: void (*)(void*).
TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
# DLPack capsule names: "dltensor" marks an unconsumed capsule; after a
# consumer takes ownership it renames the capsule to "used_dltensor".
_c_str_dltensor = c_str("dltensor")
_c_str_used_dltensor = c_str("used_dltensor")


def get_values_from_torch_tensors(tensors, num_args):
    """Pack torch tensors into a C array of TVMValue handles for an FFI call.

    Each tensor is exported through DLPack and converted to a TVM array
    handle via ``TVMArrayFromDLPack``; the handle is stored in the i-th slot
    of a ``(TVMValue * num_args)`` ctypes array.

    Args:
        tensors: iterable of torch.Tensor objects to hand to a TVM function.
        num_args: number of slots in the returned array; presumably equals
            ``len(tensors)`` -- TODO confirm against callers.

    Returns:
        A ``(TVMValue * num_args)`` ctypes array whose ``v_handle`` fields
        point at the converted TVM array handles.

    Raises:
        ValueError: if a produced capsule is not a valid "dltensor" capsule.
    """
    values = (TVMValue * num_args)()
    dlpack_tensors = [to_dlpack(torch_tensor) for torch_tensor in tensors]
    for i, dltensor in enumerate(dlpack_tensors):
        dltensor = ctypes.py_object(dltensor)
        if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor):
            ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
            # enforce type to make sure it works for all ctypes
            ptr = ctypes.cast(ptr, ctypes.c_void_p)
            handle = TVMArrayHandle()
            check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
            # ndarray = tvm.runtime.ndarray._make_array(handle, False, False)
            # Per the DLPack protocol, mark the capsule consumed and clear its
            # destructor so the exporter does not free the underlying memory.
            ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
            ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
            values[i].v_handle = ctypes.cast(handle, ctypes.c_void_p)
        else:
            raise ValueError("Invalid DLTensor")
    return values


class TensorSupplyType(IntEnum):
    """Strategy for filling input tensors (consumed by ``get_tensor_supply``)."""

    Integer = 1  # random small integers
    Uniform = 2  # uniform-distributed floats
    Normal = 3  # normal-distributed floats
    Randn = 4  # standard normal, cast to the target dtype
    Zero = 5  # all zeros
    One = 6  # all ones


def get_tensor_supply(supply_type: TensorSupplyType, opt_shapes: dict = None):
    """Build a factory that materializes CUDA input tensors for a kernel.

    Args:
        supply_type: fill strategy for each tensor (see ``TensorSupplyType``).
        opt_shapes: mapping from dynamic-dimension (``tvm.tir.Var``) names to
            concrete int sizes; required whenever a shape dim is symbolic.

    Returns:
        A callable ``get_tensor(tensor: TensorType) -> torch.Tensor`` that
        allocates a tensor with the resolved shape/dtype on the current CUDA
        device, filled according to ``supply_type``.
    """

    def _resolve_dim(dim, opt_shapes):
        # Resolve one shape dimension: symbolic vars are looked up by name in
        # opt_shapes; constant dims are unwrapped from tvm IntImm.
        if isinstance(dim, tvm.tir.Var):
            assert opt_shapes
            assert dim.name in opt_shapes
            return opt_shapes[dim.name]
        elif isinstance(dim, tvm.tir.IntImm):
            return dim.value
        else:
            raise RuntimeError("Not supported type: ", type(dim))

    def get_tensor(tensor: TensorType) -> torch.Tensor:
        # getattr is the idiomatic spelling of the dtype-name -> torch.<dtype>
        # lookup (the original called torch.__getattribute__ directly).
        dtype = getattr(torch, str(tensor.dtype))
        device = torch.cuda.current_device()
        shape = [_resolve_dim(i, opt_shapes) for i in tensor.shape]
        if supply_type == TensorSupplyType.Integer:
            # Small integers in [-2, 2]; `high` is exclusive in torch.randint.
            return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
        elif supply_type == TensorSupplyType.Uniform:
            return torch.empty(*shape, device=device, dtype=dtype).uniform_(-1.0, 1.0)
        elif supply_type == TensorSupplyType.Normal:
            # normal_(mean, std): mean=-1.0, std=1.0.
            return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0)
        elif supply_type == TensorSupplyType.Randn:
            # randn does not support integer dtypes, so sample float then cast.
            return torch.randn(*shape, device=device).to(dtype)
        elif supply_type == TensorSupplyType.Zero:
            return torch.zeros(*shape, device=device, dtype=dtype)
        elif supply_type == TensorSupplyType.One:
            return torch.ones(*shape, device=device, dtype=dtype)
        else:
            raise NotImplementedError(supply_type)

    return get_tensor


def tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]):
    """Convert a TVM tensor to a torch.Tensor.

    ``tvm.te.Tensor`` inputs go through a numpy copy; ``tvm.nd.NDArray``
    inputs are converted zero-copy via DLPack.

    Raises:
        RuntimeError: for any other input type.
    """
    if isinstance(tensor, tvm.te.Tensor):
        return torch.from_numpy(tensor.numpy())
    if isinstance(tensor, tvm.nd.NDArray):
        return from_dlpack(tensor)
    raise RuntimeError("Not supported type: ", type(tensor))


def lazy_tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]):
    """Expose a TVM NDArray's host data as a torch.Tensor via a raw buffer view.

    NOTE(review): ``asnumpy()`` already copies device data into a fresh numpy
    array, so the "lazy" aspect is only that the torch tensor aliases that
    numpy buffer instead of copying it a second time -- confirm intent.

    NOTE(review): the returned tensor aliases ``np_array``'s buffer through a
    raw address while ``np_array`` itself goes out of scope here; verify the
    buffer's lifetime is guaranteed by a caller before relying on this.

    Raises:
        RuntimeError: if ``tensor`` is not a ``tvm.nd.NDArray`` (despite the
            ``tvm.te.Tensor`` option in the annotation, only NDArray is handled).
    """
    # It additionally needs the ctypes type as torch type
    def as_tensor(address, shape, elems_inbytes, torch_type):
        # Wrap the raw address as a ctypes byte array, then reinterpret it
        # with the requested torch dtype/shape (no copy).
        arr = (ctypes.c_int8 * elems_inbytes).from_address(address)
        return torch.frombuffer(arr, dtype=torch_type).view(*shape)

    if isinstance(tensor, tvm.nd.NDArray):
        np_array = tensor.asnumpy()
        shape = np_array.shape
        dtype = np_array.dtype
        # numpy dtype names ("float32", ...) match torch attribute names.
        torch_dtype = getattr(torch, str(dtype))
        num_elems_inbytes = prod(shape) * np_array.itemsize
        data_ptr = np_array.ctypes.data
        tensor = as_tensor(data_ptr, shape, num_elems_inbytes, torch_dtype)
        return tensor
    else:
        raise RuntimeError("Not supported type: ", type(tensor))


def lazy_torch_to_tvm_tensor(tensor):
    """Convert a torch.Tensor to a ``tvm.nd.NDArray`` via a numpy buffer view.

    The numpy view over the torch storage is zero-copy, but ``tvm.nd.array``
    then copies it, so the result does not alias the torch tensor's memory.

    NOTE(review): ``data_ptr()`` is dereferenced as host memory, so this
    presumably only works for CPU tensors -- confirm callers never pass CUDA
    tensors here.

    Raises:
        RuntimeError: if ``tensor`` is not a ``torch.Tensor``.
    """
    # It additionally needs the ctypes type as torch type
    def as_tensor(address, shape, elems_inbytes, numpy_type):
        # Wrap the raw address as a ctypes byte array, then reinterpret it
        # as a numpy array of the requested dtype/shape (no copy).
        arr = (ctypes.c_int8 * elems_inbytes).from_address(address)
        return np.frombuffer(arr, dtype=numpy_type).reshape(shape)

    if isinstance(tensor, torch.Tensor):
        data_ptr = tensor.data_ptr()
        shape = tensor.shape
        torch_dtype = tensor.dtype
        # torch dtype names ("torch.float32", ...) map to numpy dtype strings.
        numpy_dtype = str(torch_dtype).replace("torch.", "")
        num_elems_inbytes = prod(shape) * tensor.itemsize
        np_tensor = as_tensor(data_ptr, shape, num_elems_inbytes, numpy_dtype)
        tvm_tensor = tvm.nd.array(np_tensor)
        return tvm_tensor
    else:
        raise RuntimeError("Not supported type: ", type(tensor))


def np_float2np_bf16(arr):
    """Convert a numpy array of float32 to bfloat16 stored as uint16.

    Applies round-to-nearest-even when dropping the low 16 mantissa bits.
    """
    bits = arr.view("<u4")
    # Round-to-nearest-even: add 0x7FFF plus the lowest surviving mantissa
    # bit, so exact halfway cases round toward an even result.
    rounding = ((bits >> 16) & 1) + 0x7FFF
    return ((bits + rounding) >> 16).astype("uint16")


def np_bf162np_float(arr):
    """Convert a numpy array of bfloat16 (stored as uint16) to float32.

    The 16 payload bits become the high half of a float32 word; the low
    mantissa bits are zero-filled, so the conversion is exact.
    """
    widened = arr.astype("uint32") << 16
    return widened.view("<f4")