# flute_kernels / torch-ext / flute / packbits_utils.py
# Committer: galqiwi
# Initial source: FLUTE kernel scaffold (vendored CUTLASS, split TUs)
# Commit: 67a5826 (verified)
import math
import torch
from typing import Tuple, Union, Optional, NamedTuple
# Integer container used for packed bit-streams, and its width in bits.
PackedDType = torch.int16
PackedNumBits: int = torch.iinfo(PackedDType).bits

# Documentation-only aliases: every one of these is plain `torch.Tensor`;
# the alias name records which dtype the annotated code expects.
FloatTensorType = torch.Tensor
UInt8TensorType = Int16TensorType = Int32TensorType = torch.Tensor
BinaryTensorType = PackedBinaryTensorType = torch.Tensor
# https://stackoverflow.com/questions/55918468/convert-integer-to-pytorch-tensor-of-binary-bits
def to_binary(tensor: UInt8TensorType, num_bits: int, legacy: bool = True) -> BinaryTensorType:
    """Expand every uint8 element of ``tensor`` into ``num_bits`` boolean bits.

    Args:
        tensor: ``torch.uint8`` tensor of any shape; every element must fit
            in ``num_bits`` bits.
        num_bits: Number of bits per element (at most 8).
        legacy: Bit order of the new trailing dimension. ``True`` puts the
            most significant bit first (index 0 holds bit ``num_bits - 1``);
            any other value puts the least significant bit first.

    Returns:
        A ``torch.bool`` tensor of shape ``[*tensor.shape, num_bits]``.

    Raises:
        TypeError: If ``tensor`` is not ``torch.uint8``.
        NotImplementedError: If ``num_bits > 8``.
        OverflowError: If any element exceeds ``2 ** num_bits - 1``.
    """
    if tensor.dtype != torch.uint8:
        raise TypeError(f"expected a torch.uint8 tensor, got {tensor.dtype}")
    if num_bits > 8:
        raise NotImplementedError(f"num_bits must be <= 8, got {num_bits}")

    # `tensor.max()` raises on an empty tensor, and an empty tensor cannot
    # overflow anyway, so only range-check non-empty input.
    if tensor.numel() > 0:
        bits_max = torch.tensor(
            2 ** num_bits - 1,
            dtype=torch.uint8,
            device=tensor.device)
        if tensor.max() > bits_max:
            raise OverflowError(
                f"values must fit in {num_bits} bits (max {2 ** num_bits - 1})")

    if legacy is True:
        # MSB first. When using `torch.compile`, the `pow` op requires
        # floating point numbers while `bitwise_and` requires integers,
        # hence the float32 -> uint8 round trip.
        mask = 2 ** torch.arange(
            num_bits - 1, -1, -1,
            dtype=torch.float32,
            device=tensor.device)
        mask = mask.to(dtype=torch.uint8)
    else:
        # LSB first. The float round trip is not necessary for PyTorch>=2.1.
        mask = 2 ** torch.arange(
            num_bits,
            dtype=torch.uint8,
            device=tensor.device)

    # Broadcast [..., 1] & [num_bits] -> [..., num_bits]; `ne` already
    # returns a bool tensor.
    return tensor.unsqueeze(dim=-1).bitwise_and(mask).ne(0)
def from_binary(tensor: BinaryTensorType, num_bits: int, legacy: bool = True) -> UInt8TensorType:
    """Collapse a trailing dimension of ``num_bits`` booleans into uint8 values.

    This mirrors ``to_binary`` for the same ``num_bits``/``legacy`` pair.
    ``legacy=True`` treats index 0 of the last dimension as the most
    significant bit; any other value treats it as the least significant bit.

    Returns:
        A ``torch.uint8`` tensor with the last dimension reduced away.

    Raises:
        TypeError: If ``tensor`` is not ``torch.bool``.
        ValueError: If the last dimension does not have length ``num_bits``.
        NotImplementedError: If ``num_bits > 8``.
    """
    if tensor.dtype != torch.bool:
        raise TypeError
    if tensor.shape[-1] != num_bits:
        raise ValueError
    if num_bits > 8:
        raise NotImplementedError

    device = tensor.device
    if legacy is True:
        # MSB-first place values. Built in float32 because `pow` under
        # `torch.compile` wants floating point, then cast to uint8.
        exponents = torch.arange(num_bits - 1, -1, -1, dtype=torch.float32, device=device)
        weights = (2 ** exponents).to(dtype=torch.uint8)
    else:
        # LSB-first place values, computed directly in uint8.
        exponents = torch.arange(num_bits, dtype=torch.uint8, device=device)
        weights = 2 ** exponents

    # Multiply each bit by its place value and add them up; `sum` on an
    # integer tensor promotes to int64, so cast the result back to uint8.
    place_values = tensor.to(dtype=torch.uint8) * weights
    return place_values.sum(dim=-1).to(dtype=torch.uint8)
def pack_bools_into_integers(
    tensor: BinaryTensorType,
    packed_dtype: torch.dtype,
    legacy: bool = False,
) -> Tuple[PackedBinaryTensorType, int]:
    """Pack a 1-D boolean tensor into integers of ``packed_dtype``.

    Consecutive groups of ``B = torch.iinfo(packed_dtype).bits`` booleans
    become one integer, with element ``j`` of a group stored in bit ``j``
    (LSB first). For signed dtypes the last element of each group lands in
    the sign bit. The input is zero-padded at the end to a multiple of ``B``.

    Args:
        tensor: 1-D ``torch.bool`` tensor.
        packed_dtype: One of ``torch.uint8``, ``torch.int16``, ``torch.int32``.
        legacy: If ``True``, use the original shift-and-sum implementation;
            otherwise OR each bit column in one at a time. The bits in a group
            are disjoint, so both yield the same bit pattern.

    Returns:
        ``(packed_tensor, padding_length)``: ``packed_tensor`` has shape
        ``[ceil(len(tensor) / B)]`` and ``padding_length`` is the number of
        zero bits appended before packing.

    Raises:
        ValueError: If ``tensor`` is not 1-D.
        TypeError: If ``tensor`` is not ``torch.bool``.
        NotImplementedError: If ``packed_dtype`` is not supported.
    """
    if tensor.ndim != 1:
        raise ValueError(f"expected a 1-D tensor, got shape {tuple(tensor.shape)}")
    if tensor.dtype != torch.bool:
        raise TypeError(f"expected a torch.bool tensor, got {tensor.dtype}")
    if packed_dtype not in (torch.uint8, torch.int16, torch.int32):
        raise NotImplementedError(f"unsupported packed_dtype: {packed_dtype}")

    # number of bits in the packed dtype
    packed_num_bits = torch.iinfo(packed_dtype).bits

    # Zero-pad the tail so the length is a multiple of `packed_num_bits`.
    remainder = tensor.shape[-1] % packed_num_bits
    padding_length = packed_num_bits - remainder if remainder > 0 else 0
    if padding_length > 0:
        tensor = torch.cat([tensor, tensor.new_zeros(padding_length)], dim=-1)

    # [num_groups, packed_num_bits]; exact integer division (no float rounding).
    tensor = tensor.view(tensor.shape[-1] // packed_num_bits, packed_num_bits)

    if legacy is True:
        # [1, packed_num_bits]: shift amount for each column.
        shifts = torch.arange(
            packed_num_bits,
            dtype=packed_dtype,
            device=tensor.device).unsqueeze(dim=0)
        # Bits are disjoint, so summing the shifted values equals OR-ing them.
        # `sum` accumulates in int64; cast back to the packed dtype.
        packed_tensor = torch.sum(tensor << shifts, dim=-1).to(dtype=packed_dtype)
    else:
        # Compute in the target dtype (e.g. int16) rather than int64 by
        # OR-ing each boolean column in at its bit position.
        packed_tensor = torch.zeros(
            tensor.shape[0],
            dtype=packed_dtype,
            device=tensor.device)
        for bit in range(packed_num_bits):
            packed_tensor |= tensor[:, bit].to(packed_dtype) << bit

    return packed_tensor, padding_length
def unpack_integers_into_bools(
    packed_tensor: PackedBinaryTensorType,
    padding_length: int,
    packed_dtype: torch.dtype,
) -> BinaryTensorType:
    """Expand packed integers back into a flat boolean tensor.

    Bit ``j`` of each integer (LSB first) becomes element ``j`` of that
    integer's group of ``torch.iinfo(packed_dtype).bits`` booleans; groups
    are concatenated in order. This is the inverse of
    ``pack_bools_into_integers`` with the same ``packed_dtype``.

    Args:
        packed_tensor: 1-D integer tensor whose dtype equals ``packed_dtype``.
        padding_length: Number of trailing bits to drop (the padding that was
            appended when packing). Values ``<= 0`` drop nothing.
        packed_dtype: One of ``torch.uint8``, ``torch.int16``, ``torch.int32``.

    Returns:
        1-D ``torch.bool`` tensor of length
        ``packed_tensor.numel() * bits - padding_length``.

    Raises:
        ValueError: If ``packed_tensor`` is not 1-D.
        TypeError: If ``packed_tensor.dtype != packed_dtype``.
        NotImplementedError: If ``packed_dtype`` is not supported.
    """
    if packed_tensor.ndim != 1:
        raise ValueError(
            f"expected a 1-D tensor, got shape {tuple(packed_tensor.shape)}")
    if packed_tensor.dtype != packed_dtype:
        raise TypeError(
            f"packed_tensor has dtype {packed_tensor.dtype}, expected {packed_dtype}")
    if packed_dtype not in (torch.uint8, torch.int16, torch.int32):
        raise NotImplementedError(f"unsupported packed_dtype: {packed_dtype}")

    # number of bits in the packed dtype
    packed_num_bits = torch.iinfo(packed_dtype).bits

    # [1, packed_num_bits]: single-bit masks 1 << j, in the packed dtype.
    one = packed_tensor.new_tensor(1, dtype=packed_dtype)
    masks = one << torch.arange(
        packed_num_bits,
        dtype=packed_dtype,
        device=packed_tensor.device)
    masks = torch.unsqueeze(masks, dim=0)

    # [N, packed_num_bits]. For signed dtypes (int16, int32) the top mask is
    # the sign bit, e.g. 0b1000...0 == -2**15 for int16, so a set bit can
    # yield a negative value: test with `!= 0`, which is correct for every
    # supported dtype (for uint8 it is equivalent to `> 0`).
    unpacked_tensor = (torch.unsqueeze(packed_tensor, dim=-1) & masks) != 0

    # [N * packed_num_bits]
    unpacked_tensor = unpacked_tensor.view(-1)
    if padding_length > 0:
        unpacked_tensor = unpacked_tensor[:-padding_length]
    return unpacked_tensor
def pack_integer_tensors(
    tensor: UInt8TensorType,
    num_bits: int,
) -> PackedBinaryTensorType:
    """Pack ``num_bits``-bit uint8 codes into a flat ``PackedDType`` bit-stream.

    Compared with a generic bit-packing routine, this layout is chosen for
    faster dequantization:

    1. Bits are emitted LSB first (``to_binary(..., legacy=False)``).
    2. The container dtype is ``PackedDType`` (``torch.int16`` in this module).
    3. ``num_bits == 3`` would need a special implementation, which is not
       provided here.
    4. Padding is not supported: ``tensor.numel() * num_bits`` must be a
       multiple of ``PackedNumBits``.

    Returns:
        1-D ``PackedDType`` tensor of length
        ``tensor.numel() * num_bits // PackedNumBits``.

    Raises:
        TypeError, NotImplementedError, OverflowError: Propagated from
            ``to_binary`` for invalid ``tensor``/``num_bits``.
        NotImplementedError: If ``num_bits == 3``.
        ValueError: If the bit-stream length is not a multiple of
            ``PackedNumBits`` (packing would require padding).
    """
    # [*tensor.shape, num_bits], least significant bit first.
    binary_tensor = to_binary(
        tensor=tensor,
        num_bits=num_bits,
        legacy=False)
    if num_bits == 3:
        raise NotImplementedError("num_bits=3 requires a special packing implementation")

    # Reject inputs that would need padding before doing the packing work.
    total_bits = tensor.numel() * num_bits
    if total_bits % PackedNumBits != 0:
        raise ValueError(
            f"tensor.numel() * num_bits = {total_bits} is not a multiple of "
            f"{PackedNumBits}; padding is not supported")

    # [tensor.numel() * num_bits]; `view` of the freshly computed bits is
    # already contiguous.
    binary_tensor = binary_tensor.view(total_bits)

    # [tensor.numel() * num_bits / PackedNumBits]
    packed_tensor, _ = pack_bools_into_integers(
        tensor=binary_tensor,
        packed_dtype=PackedDType)
    return packed_tensor