File size: 6,829 Bytes
67a5826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import math
import torch
from typing import Tuple, Union, Optional, NamedTuple

# Integer dtype that `pack_integer_tensors` hands to
# `pack_bools_into_integers` as the packed storage type.
PackedDType = torch.int16
# Bit width of a single `PackedDType` element, as reported by `torch.iinfo`.
PackedNumBits = torch.iinfo(PackedDType).bits
# Annotation-only aliases. Every one of them is plain `torch.Tensor`; the
# names only document the dtype a function expects, and nothing is enforced
# by the alias itself (the functions check dtypes explicitly where needed).
FloatTensorType = torch.Tensor
UInt8TensorType = torch.Tensor
Int16TensorType = torch.Tensor
Int32TensorType = torch.Tensor
BinaryTensorType = torch.Tensor
PackedBinaryTensorType = torch.Tensor


# https://stackoverflow.com/questions/55918468/convert-integer-to-pytorch-tensor-of-binary-bits
def to_binary(tensor: UInt8TensorType, num_bits: int, legacy: bool = True) -> BinaryTensorType:
    """Expand every uint8 element into `num_bits` boolean bit flags.

    Args:
        tensor: `torch.uint8` tensor of any shape. Every value must fit in
            `num_bits` bits. An empty tensor is accepted and yields an empty
            result.
        num_bits: Number of bits extracted per element (at most 8).
        legacy: When `True`, the new trailing axis is ordered from the most
            significant bit to the least significant one. Otherwise index
            `i` of the trailing axis holds bit `i` (least significant first).

    Returns:
        Bool tensor of shape `[*tensor.shape, num_bits]`.

    Raises:
        TypeError: If `tensor` is not `torch.uint8`.
        NotImplementedError: If `num_bits` is greater than 8.
        OverflowError: If any element is larger than `2 ** num_bits - 1`.
    """
    if tensor.dtype != torch.uint8:
        raise TypeError
    if num_bits > 8:
        raise NotImplementedError

    # Explicit casting, and the following code will
    # raise an Error if casting leads to overflow
    bits_max = torch.tensor(
        2 ** num_bits - 1,
        dtype=torch.uint8,
        device=tensor.device)
    # `Tensor.max()` without a `dim` raises on empty tensors, and an empty
    # tensor cannot overflow anyway, so only run the check on real data.
    if tensor.numel() > 0 and tensor.max() > bits_max:
        raise OverflowError

    if legacy is True:
        # When using `torch.compile`, the `pow` ops
        # requires floating point numbers, but the
        # `bitwise_and` requires integers.
        mask = 2 ** torch.arange(
            num_bits - 1, -1, -1,
            dtype=torch.float32,
            device=tensor.device)
        mask = mask.to(dtype=torch.uint8)
    else:
        # 1. The above casting is not necessary for PyTorch>=2.1
        # 2. We no longer reverse the bits directions
        mask = 2 ** torch.arange(
            num_bits,
            dtype=torch.uint8,
            device=tensor.device)

    # Broadcast each value against the per-bit masks; `ne` already
    # returns a bool tensor.
    return (
        tensor
        .unsqueeze(dim=-1)
        .bitwise_and(mask)
        .ne(0))


def from_binary(tensor: BinaryTensorType, num_bits: int, legacy: bool = True) -> UInt8TensorType:
    """Collapse a trailing axis of bit flags into uint8 values.

    Args:
        tensor: Bool tensor whose last dimension has exactly `num_bits`
            entries.
        num_bits: Number of bits per value (at most 8).
        legacy: When `True`, the trailing axis is read most significant bit
            first. Otherwise index `i` carries the weight `2 ** i`.

    Returns:
        `torch.uint8` tensor with the trailing axis reduced away.

    Raises:
        TypeError: If `tensor` is not a bool tensor.
        ValueError: If the last dimension is not `num_bits` long.
        NotImplementedError: If `num_bits` is greater than 8.
    """
    if tensor.dtype != torch.bool:
        raise TypeError
    if tensor.shape[-1] != num_bits:
        raise ValueError
    if num_bits > 8:
        raise NotImplementedError

    if legacy is True:
        # Powers are computed in float32 and cast back afterwards
        # (most significant bit first).
        exponents = torch.arange(
            num_bits - 1, -1, -1,
            dtype=torch.float32,
            device=tensor.device)
        weights = (2 ** exponents).to(dtype=torch.uint8)
    else:
        # Least significant bit first, computed directly in uint8.
        exponents = torch.arange(
            num_bits,
            dtype=torch.uint8,
            device=tensor.device)
        weights = 2 ** exponents

    # Weight each flag by its power of two and add the weights up per value.
    flags = tensor.to(dtype=torch.uint8)
    weighted = weights * flags
    return weighted.sum(dim=-1).to(dtype=torch.uint8)


def pack_bools_into_integers(
    tensor: BinaryTensorType,
    packed_dtype: torch.dtype,
    legacy: bool = False,
) -> Tuple[PackedBinaryTensorType, int]:
    """Pack a flat bool tensor into integer words, first flag in the lowest bit.

    The input is zero-padded at the end up to a multiple of the word width,
    then every run of `bits` flags becomes one element of `packed_dtype`,
    with flag `i` of the run stored in bit `i` of the word.

    Args:
        tensor: 1-D bool tensor.
        packed_dtype: One of `torch.uint8`, `torch.int16`, `torch.int32`.
        legacy: When `True`, shift the whole matrix at once and reduce with
            `torch.sum`; otherwise OR the bit columns one by one directly in
            `packed_dtype`.

    Returns:
        A `(packed_tensor, padding_length)` pair, where `padding_length` is
        the number of `False` flags appended to fill the last word.

    Raises:
        ValueError: If `tensor` is not 1-D.
        TypeError: If `tensor` is not a bool tensor.
        NotImplementedError: If `packed_dtype` is not supported.
    """
    # A 1-D tensor always has `shape[-1] == numel()`, so the rank check
    # alone covers the original guard.
    if tensor.ndim != 1:
        raise ValueError
    if tensor.dtype != torch.bool:
        raise TypeError
    if packed_dtype not in (torch.uint8, torch.int16, torch.int32):
        raise NotImplementedError

    # Width of one packed word, in bits.
    word_bits = torch.iinfo(packed_dtype).bits

    # Pad with `False` so the length is a whole number of words.
    leftover = tensor.shape[-1] % word_bits
    padding_length = word_bits - leftover if leftover > 0 else 0
    if padding_length > 0:
        tensor = torch.cat(
            [tensor, tensor.new_zeros(padding_length)],
            dim=-1)

    # [num_words, word_bits]
    rows = tensor.view(tensor.shape[-1] // word_bits, word_bits)

    if legacy is True:
        # [1, word_bits] shift amounts, broadcast over all rows.
        shifts = torch.arange(
            word_bits,
            dtype=packed_dtype,
            device=rows.device).unsqueeze(dim=0)
        packed_tensor = (rows << shifts).sum(dim=-1).to(dtype=packed_dtype)
        return packed_tensor, padding_length

    # Accumulate one bit column at a time so that all arithmetic stays in
    # `packed_dtype` instead of being promoted by a reduction.
    packed_tensor = torch.zeros(
        rows.shape[0],
        dtype=packed_dtype,
        device=rows.device)
    for bit in range(word_bits):
        packed_tensor |= rows[:, bit].to(packed_dtype) << bit
    return packed_tensor, padding_length


def unpack_integers_into_bools(
    packed_tensor: PackedBinaryTensorType,
    padding_length: int,
    packed_dtype: torch.dtype,
) -> BinaryTensorType:
    """Expand packed integer words back into a flat bool tensor.

    Bit `i` of every word becomes flag `i` of that word's run, i.e. the
    least significant bit comes first.

    Args:
        packed_tensor: 1-D tensor of dtype `packed_dtype`.
        padding_length: Number of trailing flags to drop from the flattened
            result (the padding added when packing).
        packed_dtype: One of `torch.uint8`, `torch.int16`, `torch.int32`.

    Returns:
        1-D bool tensor with `packed_tensor.numel() * bits - padding_length`
        entries, where `bits` is the width of `packed_dtype`.

    Raises:
        ValueError: If `packed_tensor` is not 1-D.
        TypeError: If `packed_tensor.dtype` differs from `packed_dtype`.
        NotImplementedError: If `packed_dtype` is not supported.
    """
    if packed_tensor.ndim != 1:
        raise ValueError
    if packed_tensor.dtype != packed_dtype:
        raise TypeError
    if packed_dtype not in [torch.uint8, torch.int16, torch.int32]:
        raise NotImplementedError

    # number of bits in the packed dtype
    packed_num_bits = torch.iinfo(packed_dtype).bits

    # [1, packed_num_bits] single-bit masks: 1 << 0, 1 << 1, ...
    bits = packed_tensor.new_tensor(
        1,
        dtype=packed_dtype)
    bits = bits << torch.arange(
        packed_num_bits,
        dtype=packed_dtype,
        device=packed_tensor.device)
    bits = torch.unsqueeze(
        bits,
        dim=0)

    # [num_words, packed_num_bits]
    unpacked_tensor = torch.unsqueeze(
        packed_tensor,
        dim=-1)
    unpacked_tensor = unpacked_tensor & bits

    # For signed dtypes the top mask is the sign bit (a negative number), so
    # the test must be `!= 0` rather than `> 0`. For uint8 both tests agree,
    # which lets every accepted dtype (including int16) share this line.
    unpacked_tensor = unpacked_tensor != 0

    unpacked_tensor = unpacked_tensor.to(dtype=torch.bool)
    unpacked_tensor = unpacked_tensor.view(-1)
    if padding_length > 0:
        unpacked_tensor = unpacked_tensor[:-padding_length]
    return unpacked_tensor


def pack_integer_tensors(
    tensor: UInt8TensorType,
    num_bits: int,
) -> PackedBinaryTensorType:
    """Pack a tensor of small unsigned integers into `PackedDType` words.

    Every element is expanded into `num_bits` flags with
    `to_binary(..., legacy=False)` (least significant bit first), the flags
    are flattened, and the flat sequence is packed with
    `pack_bools_into_integers` using the module-level `PackedDType`.

    Args:
        tensor: Tensor forwarded to `to_binary`, which requires
            `torch.uint8` values that fit in `num_bits` bits.
        num_bits: Number of bits kept per element.

    Returns:
        1-D tensor of dtype `PackedDType`.

    Raises:
        NotImplementedError: If `num_bits` is 3 (no implementation for that
            width here); raised only after `to_binary` has run its own
            validation.
        ValueError: If `tensor.numel() * num_bits` is not a multiple of the
            packed word width; padding is not supported by this function.
    """
    # [*tensor.shape, num_bits]
    bit_flags = to_binary(
        tensor=tensor,
        num_bits=num_bits,
        legacy=False)

    if num_bits == 3:
        raise NotImplementedError

    # [tensor.numel() * num_bits]
    total_bits = tensor.numel() * num_bits
    flat_flags = bit_flags.view(total_bits).contiguous()

    # [tensor.numel() * num_bits / bit width of PackedDType]
    packed_tensor, padding_length = pack_bools_into_integers(
        tensor=flat_flags,
        packed_dtype=PackedDType)

    # A non-zero padding means the bits did not fill whole words.
    if padding_length != 0:
        raise ValueError
    return packed_tensor