build-tools / bitsandbytes /backends /triton /kernels_8bit_quant.py
salmankhanpm's picture
Add files using upload-large-folder tool
dc9bb20 verified
import torch
import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_SIZE': 64}),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128}),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_SIZE": 256}),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# triton.Config({"SPLIT_SIZE": 512}),
# # triton.Config({'SPLIT_SIZE': 1024}),
# ],
# key=["num_paired_elements", "QUANT_BLOCK"],
# )
@triton.jit
def dequant_8bit_kernel(
a_ptr,
out_ptr,
code_ptr,
absmax_ptr,
n,
QUANT_BLOCK: tl.constexpr,
SPLIT_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * SPLIT_SIZE
offsets = block_start + tl.arange(0, SPLIT_SIZE)
mask = offsets < n
out_dq = dequant_8bit_blockwise_kernel_util(a_ptr, offsets, code_ptr, absmax_ptr, mask, QUANT_BLOCK)
tl.store(out_ptr + offsets, out_dq, mask)
def dequant_8bit_blockwise(
a: torch.Tensor,
absmax: torch.Tensor,
quant_state_code: torch.Tensor,
quant_blocksize: int = 64,
dtype: torch.dtype = None,
out: torch.Tensor = None,
):
n = a.numel()
if out is None:
if dtype is None:
raise ValueError("If out is None, dtype must be specified")
out = torch.empty_like(a, dtype=dtype, device=a.device)
SPLIT_SIZE = 256
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
grid = (triton.cdiv(n, SPLIT_SIZE),)
dequant_8bit_kernel[grid](
a,
out,
quant_state_code,
absmax,
n,
quant_blocksize,
SPLIT_SIZE,
)
return out
# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_8bit_blockwise_kernel(
A_ptr,
code_ptr,
absmax_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
CODE_SIZE: tl.constexpr,
SPLIT_NUM_BLOCKS: tl.constexpr,
):
block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
offsets = block_start_idx * BLOCK_SIZE + thread_idx
mask = offsets < n_elements
A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS)
tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)
tl.store(out_ptr + offsets, quantized, mask=mask)
def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None):
n = A.numel()
blocks = -(n // -blocksize)
if absmax is None:
absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
if out is None:
out = torch.empty_like(A.flatten(), dtype=torch.uint8)
split_num_blocks = 1
grid = (triton.cdiv(blocks, split_num_blocks),)
# grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
quantize_8bit_blockwise_kernel[grid](
A_ptr=A,
code_ptr=code,
absmax_ptr=absmax,
out_ptr=out,
n_elements=n,
BLOCK_SIZE=blocksize,
CODE_SIZE=code.numel(),
SPLIT_NUM_BLOCKS=split_num_blocks,
# num_warps=1,
# num_stages=2,
)
out = out.reshape(A.shape)
return out, absmax
@triton.jit
def quantize_8bit_blockwise_kernel_util(
a,
code_ptr,
CODE_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
N_PER_TH: tl.constexpr,
):
# To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS)
a_reshaped = tl.reshape(a, (N_PER_TH, BLOCK_SIZE))
# Calculating absmax for each block
absmax = tl.max(tl.abs(a_reshaped), axis=1)
a_normalized = a_reshaped / absmax[:, None]
a_normalized = tl.clamp(a_normalized, -1.0, 1.0)
lower_pivot = tl.zeros((N_PER_TH, BLOCK_SIZE), dtype=tl.int32)
upper_pivot = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
# ceil(log2(code_size)) = 8, actually, in general case should be input parameter
for _ in range(8):
pivot = (lower_pivot + upper_pivot) // 2
val = tl.load(code_ptr + pivot)
is_higher = a_normalized > val # code[pivot]
lower_pivot = tl.where(is_higher, pivot, lower_pivot)
upper_pivot = tl.where(is_higher, upper_pivot, pivot)
# Choose closest level
lower_val = tl.load(code_ptr + lower_pivot)
upper_val = tl.load(code_ptr + upper_pivot)
lower_dist = tl.abs(a_normalized - lower_val)
upper_dist = tl.abs(a_normalized - upper_val)
quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
# too slow approach
# diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
# quantized = tl.argmin(diff, axis=2).to(tl.uint8)
quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * N_PER_TH,))
return quantized_flat, absmax
@triton.jit
def dequant_8bit_blockwise_kernel_util(
a_ptr,
offsets,
code_ptr,
absmax_ptr,
mask,
BLOCK_SIZE: tl.constexpr,
):
a = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8)
scaled_int8 = tl.load(code_ptr + a, mask)
# Load scales
absmax_offsets = offsets // BLOCK_SIZE
absmax = tl.load(absmax_ptr + absmax_offsets, mask=mask, other=0.0, eviction_policy="evict_last")
# Apply scales
out_dq = scaled_int8 * absmax
return out_dq