Instructions to use galqiwi/flute_kernels with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use galqiwi/flute_kernels with Kernels:
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("galqiwi/flute_kernels")
- Notebooks
- Google Colab
- Kaggle
File size: 6,829 Bytes
67a5826 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 | import math
import torch
from typing import Tuple, Union, Optional, NamedTuple
# Integer dtype used as the storage word when bit-packing tensors.
PackedDType = torch.int16
# Width of one storage word in bits (16 for torch.int16).
PackedNumBits: int = torch.iinfo(PackedDType).bits

# Descriptive aliases for torch.Tensor. They document the expected dtype /
# role of a tensor in signatures but are not enforced at runtime.
FloatTensorType = UInt8TensorType = Int16TensorType = torch.Tensor
Int32TensorType = BinaryTensorType = PackedBinaryTensorType = torch.Tensor
# https://stackoverflow.com/questions/55918468/convert-integer-to-pytorch-tensor-of-binary-bits
def to_binary(tensor: UInt8TensorType, num_bits: int, legacy: bool = True) -> BinaryTensorType:
    """Expand every uint8 element of ``tensor`` into ``num_bits`` boolean bits.

    Args:
        tensor: A ``torch.uint8`` tensor of any shape (including empty) whose
            values all fit in ``num_bits`` bits.
        num_bits: Number of bits per element; must be at most 8.
        legacy: If ``True`` (the default), the bits of each element are
            ordered most-significant first. If ``False``, they are ordered
            least-significant first.

    Returns:
        A ``torch.bool`` tensor of shape ``(*tensor.shape, num_bits)``.

    Raises:
        TypeError: If ``tensor`` is not ``torch.uint8``.
        NotImplementedError: If ``num_bits > 8``.
        OverflowError: If any element of ``tensor`` exceeds
            ``2 ** num_bits - 1``.
    """
    if tensor.dtype != torch.uint8:
        raise TypeError(f"`tensor` must be torch.uint8, got {tensor.dtype}")
    if num_bits > 8:
        raise NotImplementedError(
            f"`num_bits` > 8 is not supported, got {num_bits}")

    # Largest value representable in `num_bits` bits, built explicitly as
    # uint8 so the comparison below stays in the input's dtype.
    bits_max = torch.tensor(
        2 ** num_bits - 1,
        dtype=torch.uint8,
        device=tensor.device)
    # `Tensor.max()` raises on an empty tensor. An empty tensor has nothing
    # that could overflow, so only range-check non-empty inputs.
    if tensor.numel() > 0 and tensor.max() > bits_max:
        raise OverflowError(
            f"`tensor` contains values that do not fit in {num_bits} bits")

    if legacy is True:
        # When using `torch.compile`, the `pow` ops
        # requires floating point numbers, but the
        # `bitwise_and` requires integers.
        # Mask is [2**(num_bits-1), ..., 2, 1]: most-significant bit first.
        mask = 2 ** torch.arange(
            num_bits - 1, -1, -1,
            dtype=torch.float32,
            device=tensor.device)
        mask = mask.to(dtype=torch.uint8)
    else:
        # 1. The above casting is not necessary for PyTorch>=2.1
        # 2. We no longer reverse the bits directions
        # Mask is [1, 2, ..., 2**(num_bits-1)]: least-significant bit first.
        mask = 2 ** torch.arange(
            num_bits,
            dtype=torch.uint8,
            device=tensor.device)

    # [*tensor.shape, 1] & [num_bits] broadcasts to [*tensor.shape, num_bits];
    # `ne` already yields a bool tensor.
    return (
        tensor
        .unsqueeze(dim=-1)
        .bitwise_and(mask)
        .ne(0))
def from_binary(tensor: BinaryTensorType, num_bits: int, legacy: bool = True) -> UInt8TensorType:
    """Collapse the trailing boolean bit axis of ``tensor`` into uint8 values.

    This is the inverse of ``to_binary`` for the same ``num_bits`` and
    ``legacy`` settings: with ``legacy=True`` the last axis is read
    most-significant bit first, otherwise least-significant bit first.

    Args:
        tensor: ``torch.bool`` tensor whose last dimension equals ``num_bits``.
        num_bits: Number of bits per value; must be at most 8.
        legacy: Bit order selector (see above).

    Returns:
        A ``torch.uint8`` tensor of shape ``tensor.shape[:-1]``.

    Raises:
        TypeError: If ``tensor`` is not boolean.
        ValueError: If the last dimension is not ``num_bits``.
        NotImplementedError: If ``num_bits > 8``.
    """
    if tensor.dtype != torch.bool:
        raise TypeError
    if tensor.shape[-1] != num_bits:
        raise ValueError
    if num_bits > 8:
        raise NotImplementedError

    if legacy is True:
        # Weights 2**(num_bits-1), ..., 2, 1, computed in float32 then cast.
        exponents = torch.arange(
            num_bits - 1, -1, -1,
            dtype=torch.float32,
            device=tensor.device)
        weights = (2 ** exponents).to(dtype=torch.uint8)
    else:
        # Weights 1, 2, ..., 2**(num_bits-1).
        weights = 2 ** torch.arange(
            num_bits,
            dtype=torch.uint8,
            device=tensor.device)

    # Weighted sum over the bit axis; the reduction widens the dtype, so
    # cast the result back to uint8.
    weighted = weights * tensor.to(dtype=torch.uint8)
    return weighted.sum(dim=-1).to(dtype=torch.uint8)
def pack_bools_into_integers(
    tensor: BinaryTensorType,
    packed_dtype: torch.dtype,
    legacy: bool = False,
) -> Tuple[PackedBinaryTensorType, int]:
    """Pack a flat boolean tensor into integers of ``packed_dtype``.

    The input is zero-padded at the end to a multiple of the dtype's bit
    width ``W``, split into consecutive groups of ``W`` booleans, and element
    ``i`` of each group becomes bit ``i`` (least-significant first) of one
    output integer. Both code paths produce the same bit layout. The legacy
    path shifts all columns at once and reduces with a sum, then casts back.
    The default path ORs one column at a time so intermediates stay in
    ``packed_dtype``.

    Args:
        tensor: 1-D ``torch.bool`` tensor.
        packed_dtype: One of ``torch.uint8``, ``torch.int16``, ``torch.int32``.
        legacy: Select the legacy sum-based packing path.

    Returns:
        ``(packed, padding_length)``. ``packed`` is a 1-D ``packed_dtype``
        tensor with ``ceil(len / W)`` elements. ``padding_length`` is the
        number of zero bits appended.

    Raises:
        ValueError: If ``tensor`` is not 1-D.
        TypeError: If ``tensor`` is not boolean.
        NotImplementedError: If ``packed_dtype`` is unsupported.
    """
    is_flat = tensor.ndim == 1 and tensor.shape[-1] == tensor.numel()
    if not is_flat:
        raise ValueError
    if tensor.dtype != torch.bool:
        raise TypeError
    if packed_dtype not in (torch.uint8, torch.int16, torch.int32):
        raise NotImplementedError

    width = torch.iinfo(packed_dtype).bits
    # Zero bits needed to reach a whole number of words (0 when aligned).
    padding_length = (-tensor.shape[-1]) % width
    if padding_length > 0:
        tensor = torch.cat([tensor, tensor.new_zeros(padding_length)], dim=-1)

    # [num_words, width]
    rows = tensor.view(tensor.shape[-1] // width, width)

    if legacy is True:
        # [1, width]: shift every column by its position in a single op.
        shifts = torch.arange(
            width,
            dtype=packed_dtype,
            device=rows.device).unsqueeze(0)
        packed_tensor = (rows << shifts).sum(dim=-1).to(dtype=packed_dtype)
    else:
        packed_tensor = torch.zeros(
            rows.shape[0],
            dtype=packed_dtype,
            device=rows.device)
        for position in range(width):
            packed_tensor |= rows[:, position].to(packed_dtype) << position

    return packed_tensor, padding_length
def unpack_integers_into_bools(
    packed_tensor: PackedBinaryTensorType,
    padding_length: int,
    packed_dtype: torch.dtype,
) -> BinaryTensorType:
    """Unpack integers back into a flat boolean tensor.

    Reverses ``pack_bools_into_integers``. Bit ``i`` (least-significant
    first) of each packed integer becomes element ``i`` of that integer's
    group of booleans. The last ``padding_length`` booleans (padding added
    during packing) are then dropped.

    Args:
        packed_tensor: 1-D tensor of packed integers with dtype
            ``packed_dtype``.
        padding_length: Number of trailing padding bits to drop.
        packed_dtype: One of ``torch.uint8``, ``torch.int16``,
            ``torch.int32``.

    Returns:
        A 1-D ``torch.bool`` tensor with
        ``packed_tensor.numel() * bits(packed_dtype)`` elements, minus the
        last ``padding_length`` of them.

    Raises:
        ValueError: If ``packed_tensor`` is not 1-D.
        TypeError: If ``packed_tensor.dtype`` differs from ``packed_dtype``.
        NotImplementedError: If ``packed_dtype`` is unsupported.
    """
    if packed_tensor.ndim != 1:
        raise ValueError(
            f"`packed_tensor` must be 1-D, got {packed_tensor.ndim} dims")
    if packed_tensor.dtype != packed_dtype:
        raise TypeError(
            f"`packed_tensor` has dtype {packed_tensor.dtype}, "
            f"expected {packed_dtype}")
    if packed_dtype not in [torch.uint8, torch.int16, torch.int32]:
        raise NotImplementedError(f"Unsupported packed dtype {packed_dtype}")

    # number of bits in the packed dtype
    packed_num_bits = torch.iinfo(packed_dtype).bits

    # [1, packed_num_bits]: single-bit masks 1 << i, least-significant first.
    # For signed dtypes the top mask is the sign bit and is negative
    # (e.g. -2**15 for int16).
    bits = packed_tensor.new_tensor(
        1,
        dtype=packed_dtype)
    bits = bits << torch.arange(
        packed_num_bits,
        dtype=packed_dtype,
        device=packed_tensor.device)
    bits = torch.unsqueeze(
        bits,
        dim=0)

    # [num_packed, packed_num_bits]
    masked = torch.unsqueeze(packed_tensor, dim=-1) & bits
    # Test with `!= 0` rather than `> 0`. For signed dtypes (int16 and int32),
    # an isolated sign bit is negative and would be misread as unset by
    # `> 0`. For uint8 the two tests are equivalent.
    unpacked_tensor = (masked != 0).view(-1)
    if padding_length > 0:
        unpacked_tensor = unpacked_tensor[:-padding_length]
    return unpacked_tensor
def pack_integer_tensors(
    tensor: UInt8TensorType,
    num_bits: int,
) -> PackedBinaryTensorType:
    """Bit-pack a uint8 tensor of ``num_bits``-wide integers into ``PackedDType`` words.

    This path differs from the legacy helpers in four ways (intended for
    faster dequantization):

    1. Each value's bits are expanded least-significant first
       (``to_binary(..., legacy=False)``).
    2. The bit stream is packed into ``PackedDType`` words (``torch.int16``
       in this module) with the default, non-legacy packer.
    3. ``num_bits == 3`` would need a special implementation and currently
       raises ``NotImplementedError``. This is checked before any other
       validation or work.
    4. Padding is not supported: ``tensor.numel() * num_bits`` must be a
       multiple of the packed word width.

    Args:
        tensor: ``torch.uint8`` tensor whose values fit in ``num_bits`` bits.
        num_bits: Bits per value; at most 8 and not 3.

    Returns:
        A 1-D ``PackedDType`` tensor of packed words.

    Raises:
        NotImplementedError: If ``num_bits == 3`` (or > 8, via ``to_binary``).
        TypeError / OverflowError: Propagated from ``to_binary``.
        ValueError: If the bit stream would need padding.
    """
    # Fail fast: this case is never supported, so skip the bit expansion.
    if num_bits == 3:
        raise NotImplementedError(
            "`num_bits=3` requires a special packing layout that is not "
            "implemented")

    # [*tensor.shape, num_bits]
    binary_tensor = to_binary(
        tensor=tensor,
        num_bits=num_bits,
        legacy=False)

    # [tensor.numel() x num_bits]
    binary_tensor = binary_tensor.view(
        tensor.numel() * num_bits)
    binary_tensor = binary_tensor.contiguous()

    # [tensor.numel() x num_bits / bits(PackedDType)]
    packed_tensor, padding_length = pack_bools_into_integers(
        tensor=binary_tensor,
        packed_dtype=PackedDType)
    if padding_length != 0:
        raise ValueError(
            f"Packing would require {padding_length} bits of padding, "
            "which is not supported; tensor.numel() * num_bits must be a "
            "multiple of the packed word width")
    return packed_tensor
|