Matt300209's picture
Upload folder using huggingface_hub
9823a7e verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import ctypes
import operator
from functools import reduce
from logging import getLogger
import torch
import torch.nn as nn
logger = getLogger(__name__)
from typing import List, Union, Optional
from bitblas.cache import global_operator_cache, get_database_path
from bitblas import Matmul, MatmulConfig
from bitblas.quantization.utils import general_compress
from bitblas import auto_detect_nvidia_target
BITBLAS_DATABASE_PATH = get_database_path()
def unpack_qzeros(qzeros, bits):
    """Unpack bit-packed zero points and apply the ``+ 1`` offset.

    ``qzeros`` is reinterpreted as int32 words; every word holds
    ``32 // bits`` fields of ``bits`` bits each, lowest bits first.  Each
    field becomes its own output column, so the result has
    ``qzeros.shape[1] * (32 // bits)`` columns (after the int32 view).

    Args:
        qzeros: 2-D tensor of packed zero points (viewed as ``torch.int32``).
        bits: Width in bits of a single packed field.

    Returns:
        An int8 tensor on the same device as ``qzeros`` containing
        ``(field + 1) & (2**bits - 1)`` for every field.
    """
    qzeros = qzeros.view(torch.int32)
    elems_per_int32 = 32 // bits
    n_rows, n_words = qzeros.shape

    # One shift amount per field inside a word: 0, bits, 2 * bits, ...
    shifts = torch.arange(
        0, elems_per_int32 * bits, bits, dtype=torch.int32, device=qzeros.device)

    # (rows, words, 1) >> (fields,) -> (rows, words, fields).  Flattening the
    # last two axes puts field ``i`` of word ``w`` at column
    # ``w * elems_per_int32 + i``.  The cast to int8 keeps only the low byte
    # of each shifted word; the mask below discards the bits above ``bits``.
    unpacked_zeros = ((qzeros.unsqueeze(-1) >> shifts).reshape(
        n_rows, n_words * elems_per_int32).to(torch.int8))

    # Follow the instruction in AutoGPTQ qlinear_cuda_old.py line 303
    # NOTE: the `+ 1` is applied on the int8 values (i.e. after the cast) and
    # any carry out of the `bits`-wide field is removed by the mask.
    return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1)
# For gptqv2 from gptqmodel
def unpack_qzeros_v2(qzeros, bits):
    """Unpack bit-packed zero points without any offset (GPTQ v2 layout).

    ``qzeros`` is reinterpreted as int32 words; every word holds
    ``32 // bits`` fields of ``bits`` bits each, lowest bits first.  Each
    field becomes its own output column.  Unlike ``unpack_qzeros`` no ``+ 1``
    is applied to the unpacked values.

    Args:
        qzeros: 2-D tensor of packed zero points (viewed as ``torch.int32``).
        bits: Width in bits of a single packed field.

    Returns:
        An int8 tensor on the same device as ``qzeros`` containing
        ``field & (2**bits - 1)`` for every field.
    """
    qzeros = qzeros.view(torch.int32)
    elems_per_int32 = 32 // bits
    n_rows, n_words = qzeros.shape

    # One shift amount per field inside a word: 0, bits, 2 * bits, ...
    shifts = torch.arange(
        0, elems_per_int32 * bits, bits, dtype=torch.int32, device=qzeros.device)

    # (rows, words, 1) >> (fields,) -> (rows, words, fields).  Flattening the
    # last two axes puts field ``i`` of word ``w`` at column
    # ``w * elems_per_int32 + i``.  The cast to int8 keeps only the low byte
    # of each shifted word; the mask below discards the bits above ``bits``.
    unpacked_zeros = ((qzeros.unsqueeze(-1) >> shifts).reshape(
        n_rows, n_words * elems_per_int32).to(torch.int8))

    return torch.bitwise_and(unpacked_zeros, 2**bits - 1)
def unpack_qweight(qweight, bits):
    """Unpack bit-packed weights into one int8 value per column.

    ``qweight`` is reinterpreted as int8 bytes; every byte holds
    ``8 // bits`` fields of ``bits`` bits each, lowest bits first.  Each field
    becomes its own output column, so the result has
    ``qweight.shape[1] * (8 // bits)`` columns (after the int8 view).

    Args:
        qweight: 2-D tensor of packed weights (viewed as ``torch.int8``).
        bits: Width in bits of a single packed field.

    Returns:
        An int8 tensor on the same device as ``qweight`` containing
        ``field & (2**bits - 1)`` for every field.
    """
    qweight = qweight.view(torch.int8)
    elems_per_int8 = 8 // bits
    n_rows, n_bytes = qweight.shape

    # One shift amount per field inside a byte: 0, bits, 2 * bits, ...
    shifts = torch.arange(
        0, elems_per_int8 * bits, bits, dtype=torch.int8, device=qweight.device)

    # (rows, bytes, 1) >> (fields,) -> (rows, bytes, fields).  Flattening the
    # last two axes puts field ``i`` of byte ``b`` at column
    # ``b * elems_per_int8 + i``.  The shift on int8 is sign-propagating; the
    # mask below removes the propagated sign bits.
    unpacked_weight = (qweight.unsqueeze(-1) >> shifts).reshape(
        n_rows, n_bytes * elems_per_int8)

    return torch.bitwise_and(unpacked_weight, 2**bits - 1)
class Linear(nn.Module):
    """Linear layer whose forward pass calls a BitBLAS ``Matmul`` operator.

    The layer runs in one of two modes, selected in ``__init__``:

    * *consistent* (``A_dtype == W_dtype``): a single ``weight`` buffer is
      used and no scales / zeros are passed to the kernel.
    * *quantized* (``A_dtype != W_dtype``): the packed ``qweight`` buffer is
      used together with ``scales`` and ``zeros`` when the operator config
      enables them.

    Operators are looked up in ``global_operator_cache`` and created (and
    optionally tuned) on a cache miss.
    """

    # Default set of M values handed to ``MatmulConfig``.
    opt_M = [16, 32, 64, 128, 256, 512]
    STORAGE_DTYPE = "int8"  # assume int8 storage
    TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE)
    # Mapping from torch dtypes to the dtype names used in BitBLAS configs.
    BITBLAS_DTYPES = {
        torch.float32: "float32",
        torch.float16: "float16",
        torch.half: "float16",
        torch.int8: "int8",
    }

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        A_dtype: str = "float16",
        W_dtype: str = "float16",
        accum_dtype: str = "float16",
        out_dtype: str = "float16",
        # configs for weight only quantization
        group_size: int = -1,
        with_scaling: Optional[bool] = None,
        with_zeros: bool = False,
        zeros_mode: Optional[str] = None,
        opt_M: Union[int, List[int]] = opt_M,
        # performance related configs
        enable_tuning: bool = True,
        fast_decoding: Optional[bool] = None,
        propagate_b: bool = False,
    ):
        """Build the BitBLAS operator and allocate placeholder buffers.

        Args:
            in_features: Size of the reduction dimension (passed as ``K``).
            out_features: Size of the output dimension (passed as ``N``).
            bias: Whether a ``bias`` buffer is allocated and passed to the
                kernel.
            A_dtype: Name of the activation dtype; must be an attribute of
                ``torch`` because it is also used for the scale / zero / bias
                buffers.
            W_dtype: Name of the weight dtype.
            accum_dtype: Name of the accumulation dtype.
            out_dtype: Name of the output dtype.
            group_size: Quantization group size; ``-1`` or ``None`` means one
                group spanning ``in_features``.
            with_scaling: Forwarded to ``MatmulConfig``.
            with_zeros: Forwarded to ``MatmulConfig``.
            zeros_mode: Forwarded to ``MatmulConfig``; ``"quantized"`` also
                selects the packed layout of the ``zeros`` buffer.
            opt_M: optimize range of the input shape for dynamic symbolic.
                If the input shape is a range, we will optimize the matmul
                with dynamic symbolic.  If the input shape is int, we will
                optimize the matmul with static symbolic.
            enable_tuning: Whether a newly created operator is tuned before
                being cached.
            fast_decoding: Forwarded to ``MatmulConfig``.
            propagate_b: Forwarded to ``MatmulConfig``.

        Raises:
            ValueError: If the feature sizes or the group size are invalid
                (see ``_validate_parameters``).
        """
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.opt_M = opt_M
        self.group_size = self._set_group_size(group_size, in_features)
        self.torch_dtype = getattr(torch, A_dtype)
        # NOTE: the misspelled attribute name is kept for backward
        # compatibility; prefer the ``consistent`` property.
        self.is_consitent = A_dtype == W_dtype
        self.zeros_mode = zeros_mode
        self._validate_parameters(self.group_size, in_features, out_features)
        self._configure_bitblas_matmul(
            A_dtype,
            W_dtype,
            accum_dtype,
            out_dtype,
            with_scaling,
            with_zeros,
            zeros_mode,
            enable_tuning,
            fast_decoding,
            bias,
            propagate_b,
        )
        self._initialize_buffers(in_features, out_features, bias)

    def init_params(self):
        """Cache raw data pointers of the kernel operands in ``self.q_params``.

        The operand order is ``weight`` (+ ``bias``) in consistent mode, and
        ``qweight`` (+ ``scales``) (+ ``zeros``) (+ ``bias``) otherwise, where
        the optional entries follow the operator config.
        """
        # eliminate runtime overhead like exllama state
        config = self.bitblas_matmul.config
        if self.consistent:
            param_list = [self.weight]
        else:
            param_list = [self.qweight]
            if config.with_scaling:
                param_list.append(self.scales)
            if config.with_zeros:
                param_list.append(self.zeros)
        if config.with_bias:
            param_list.append(self.bias)
        self.q_params = [ctypes.c_void_p(arr.data_ptr()) for arr in param_list]

    def _validate_parameters(self, group_size, in_features, out_features):
        """Check the layer dimensions.

        Raises:
            ValueError: If ``in_features`` or ``out_features`` is not a
                multiple of 16, if ``group_size`` is not positive, or if
                ``in_features`` is not a multiple of ``group_size``.
        """
        if in_features % 16 != 0 or out_features % 16 != 0:
            raise ValueError("`in_features` and `out_features` must be divisible by 16.")
        # Guard first: a zero group size would otherwise surface as a bare
        # ZeroDivisionError from the modulo below.
        if group_size <= 0:
            raise ValueError("`group_size` must be a positive integer (or -1 for a single group).")
        if in_features % group_size != 0:
            raise ValueError("`in_features` must be divisible by `group_size`.")

    def _set_group_size(self, group_size, in_features):
        """Resolve ``-1`` / ``None`` to a single group spanning ``in_features``."""
        if group_size is None or group_size == -1:
            return in_features
        return group_size

    def _initialize_buffers(self, in_features, out_features, bias):
        """Register zero-filled placeholder buffers for the kernel operands.

        Consistent mode registers ``weight``; otherwise ``qweight``,
        ``scales`` and ``zeros`` are registered.  ``bias`` is a buffer when
        requested and the plain attribute ``None`` otherwise.
        """
        n_groups = in_features // self.group_size

        if self.consistent:
            # A dense weight has one entry per (output, input) pair; the
            # group count only applies to scales / zeros.
            self.register_buffer(
                "weight",
                torch.zeros((out_features, in_features), dtype=self.torch_dtype),
            )
        else:
            self.register_buffer(
                "qweight",
                torch.zeros(
                    self.bitblas_matmul.retrieve_weight_shape(),
                    dtype=self.TORCH_STORAGE_DTYPE,
                ),
            )
            self.register_buffer(
                "scales",
                torch.zeros((out_features, n_groups), dtype=self.torch_dtype),
            )
            if self.zeros_mode == "quantized":
                # Bit width of the storage dtype, e.g. "int8" -> 8.
                storage_nbit = int("".join(c for c in self.STORAGE_DTYPE if c.isdigit()))
                self.register_buffer(
                    "zeros",
                    torch.zeros(
                        (n_groups, out_features // storage_nbit * self.bits),
                        dtype=self.TORCH_STORAGE_DTYPE,
                    ),
                )
            else:
                self.register_buffer(
                    "zeros",
                    torch.zeros((out_features, n_groups), dtype=self.torch_dtype),
                )

        if bias:
            self.register_buffer("bias", torch.zeros((out_features,), dtype=self.torch_dtype))
        else:
            self.bias = None

    def _configure_bitblas_matmul(
        self,
        A_dtype,
        W_dtype,
        accum_dtype,
        out_dtype,
        with_scaling,
        with_zeros,
        zeros_mode,
        enable_tuning,
        fast_decoding,
        bias,
        propagate_b,
    ):
        """Create the ``MatmulConfig`` and bind the matching operator.

        Sets ``self.bitblas_matmul``, ``self.bits`` and
        ``self.source_format`` from the resolved operator.
        """
        matmul_config = MatmulConfig(
            M=self.opt_M,
            N=self.out_features,
            K=self.in_features,
            A_dtype=A_dtype,
            W_dtype=W_dtype,
            accum_dtype=accum_dtype,
            out_dtype=out_dtype,
            storage_dtype=self.STORAGE_DTYPE,
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            group_size=self.group_size,
            fast_decoding=fast_decoding,
            with_bias=bias,
            propagate_b=propagate_b,
            zeros_mode=zeros_mode,
        )
        self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, enable_tuning)
        self.bits = self.bitblas_matmul.bit
        self.source_format = self.bitblas_matmul.source_format

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        """Return the operator for ``config``, creating it on a cache miss.

        The global cache is populated from the on-disk database the first
        time it is found empty.  A newly created operator is tuned and
        persisted to the database only when ``enable_tuning`` is true.
        """
        BITBLAS_TARGET = auto_detect_nvidia_target()
        if global_operator_cache.size() == 0:
            global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
            logger.info("Loaded %s operators from database.", global_operator_cache.size())

        bitblas_matmul = global_operator_cache.get(config)
        if bitblas_matmul is not None:
            logger.info("BitBLAS Operator found in global_operator_cache.")
            return bitblas_matmul

        # should disable tuning for the first time because we may require loading bitblas operator from database.
        bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False)
        if enable_tuning:
            bitblas_matmul.hardware_aware_finetune(topk=20)
            global_operator_cache.add(config, bitblas_matmul)
            global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
            logger.info("BitBLAS Tuning done, appended operator to global_operator_cache.")
        else:
            logger.info("BitBLAS Operator created.")
        return bitblas_matmul

    def warmup(self, topk=20):
        """Run hardware-aware tuning on the bound operator."""
        self.bitblas_matmul.hardware_aware_finetune(topk=topk)

    def forward(self, A, output=None):
        """Run the matmul kernel on ``A``.

        Args:
            A: Input tensor; its last dimension is the reduction dimension.
            output: Optional pre-allocated output tensor.  When omitted a
                zero-filled tensor of shape ``A.shape[:-1] + (out_features,)``
                with the operator's output dtype is allocated on ``A.device``.

        Returns:
            The output tensor (the one passed in, if any).
        """
        A = self.bitblas_matmul.transform_input(A)
        stream = torch.cuda.current_stream()

        A_void = ctypes.c_void_p(A.data_ptr())
        stream_handle = ctypes.c_void_p(stream.cuda_stream)
        # can be lifted to post init.
        self.init_params()
        args = [A_void, *self.q_params]
        if output is None:
            output = torch.zeros(
                A.shape[:-1] + (self.out_features,),
                dtype=getattr(torch, self.bitblas_matmul.out_dtype),
                device=A.device)
        args.append(ctypes.c_void_p(output.data_ptr()))
        if self.bitblas_matmul.dynamic_range is not None:
            # m is the product of every dimension of A except the last one.
            m = reduce(operator.mul, A.shape[:-1], 1)
            args.append(m)
        args.append(stream_handle)
        self.bitblas_matmul.lib.call(*args)

        return output

    def load_and_transform_weight(
        self,
        weight: torch.Tensor,
        scales: torch.Tensor = None,
        zeros: torch.Tensor = None,
        bias: torch.Tensor = None,
    ):
        """Install weights after passing them through the operator transform.

        In consistent mode ``scales`` and ``zeros`` must be ``None`` and the
        transformed weight is stored as an ``nn.Parameter``.  Otherwise the
        transformed weight replaces ``qweight`` and any of ``scales`` /
        ``zeros`` that is given replaces the corresponding buffer.  ``bias``
        is stored whenever it is given.
        """
        if self.consistent:
            assert scales is None, "scales should be None for consistent mode."
            assert zeros is None, "zeros should be None for consistent mode."
            weight = self.bitblas_matmul.transform_weight(weight)
            self.weight = nn.Parameter(weight)
        else:
            weight = self.bitblas_matmul.transform_weight(weight)
            self.qweight = weight
            if scales is not None:
                self.scales = scales
            if zeros is not None:
                self.zeros = zeros
        if bias is not None:
            self.bias = bias

    def _repack_from_gptq_common(self, gptq_module, unpack_zeros, device):
        """Shared body of the GPTQ repack entry points.

        ``unpack_zeros`` is the function used to turn the packed ``qzeros``
        of ``gptq_module`` into integer zero points; it is the only step that
        differs between the two GPTQ formats.
        """
        # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed.
        qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE)
        intweight = unpack_qweight(qweight, self.bits).contiguous()
        if self.bitblas_matmul.weight_transform is not None:
            # The transform is applied to a CPU copy; the result is moved to
            # the requested device afterwards.
            qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).to(device)
        self.qweight = qweight

        # scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed.
        self.scales = gptq_module.scales.T.contiguous().view(self.torch_dtype)

        # qzeros should be dequantized to int zeros.
        intzeros = unpack_zeros(gptq_module.qzeros, self.bits).T.contiguous()
        zeros_mode = self.bitblas_matmul.config.zeros_mode
        if zeros_mode == "original":
            self.zeros = intzeros.to(torch.float16).contiguous()
        elif zeros_mode == "rescale":
            # Written in place so the existing ``zeros`` buffer is reused.
            self.zeros[:, :] = intzeros.to(torch.float16)[:, :] * self.scales[:, :]
        elif zeros_mode == "quantized":
            packed_zeros = general_compress(intzeros.T.contiguous().cpu().numpy(), self.bits)
            self.zeros = (
                torch.Tensor(packed_zeros).to(self.qweight.device).to(
                    self.zeros.dtype).contiguous())
        else:
            raise ValueError(f"Unsupported zeros type: {self.bitblas_matmul.config.zeros_mode}")

        if self.bias is not None:
            self.bias = gptq_module.bias.data.to(torch.float16).contiguous()

    def repack_from_gptq(self, gptq_module, device="cuda"):
        """Load weights from a GPTQ module whose zeros need the ``+ 1`` offset.

        Args:
            gptq_module: Module exposing ``qweight``, ``scales``, ``qzeros``
                and ``bias``.
            device: Device the transformed weight is moved to.

        Raises:
            ValueError: If the operator's ``zeros_mode`` is not one of
                ``"original"``, ``"rescale"`` or ``"quantized"``.
        """
        self._repack_from_gptq_common(gptq_module, unpack_qzeros, device)

    def repack_from_gptq_v2(self, gptq_module, device="cuda"):
        """Load weights from a GPTQ v2 module (zeros stored without offset).

        Args:
            gptq_module: Module exposing ``qweight``, ``scales``, ``qzeros``
                and ``bias``.
            device: Device the transformed weight is moved to.

        Raises:
            ValueError: If the operator's ``zeros_mode`` is not one of
                ``"original"``, ``"rescale"`` or ``"quantized"``.
        """
        self._repack_from_gptq_common(gptq_module, unpack_qzeros_v2, device)

    @property
    def consistent(self):
        """Whether activations and weights share the same dtype."""
        return self.is_consitent
__all__ = ["Linear"]