# Source: build-tools/bitsandbytes/backends/triton/kernels_4bit.py
# Uploaded by salmankhanpm ("Add files using upload-large-folder tool", commit dc9bb20, verified).
import torch
import triton
import triton.language as tl
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeFP4
# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# triton.Config({"SPLIT_NUM_BLOCKS": 4}),
# triton.Config({"SPLIT_NUM_BLOCKS": 8}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_fp4_blockwise_kernel(
    A_ptr,
    absmax_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    SPLIT_NUM_BLOCKS: tl.constexpr,
):
    """Quantize a flat tensor to packed FP4 codes, one absmax per ``BLOCK_SIZE`` elements.

    Each program handles ``2 * SPLIT_NUM_BLOCKS`` consecutive blocks. For every block
    it stores the block's absolute maximum to ``absmax_ptr``. It then normalizes the
    block by that value and maps each element to a 4-bit FP4 code (bit 3 = sign).
    Codes are packed two per byte into ``out_ptr``; the lower-indexed element of each
    pair goes into the high nibble.

    Args:
        A_ptr: input values, read as a flat array of ``n_elements`` entries.
        absmax_ptr: per-block absmax output, indexed by block number.
        out_ptr: packed uint8 output (two codes per byte).
        n_elements: number of valid input elements; lanes past it load 0.0.
        BLOCK_SIZE: elements per quantization block (must be even, because
            elements are packed in pairs).
        SPLIT_NUM_BLOCKS: number of block pairs processed per program.
    """
    PAIRED_SPLIT_NUM_BLOCKS: tl.constexpr = SPLIT_NUM_BLOCKS * 2
    block_start_idx = tl.program_id(0) * PAIRED_SPLIT_NUM_BLOCKS
    thread_idx = tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS * BLOCK_SIZE)
    offsets = block_start_idx * BLOCK_SIZE + thread_idx
    mask = offsets < n_elements
    A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
    # View the loaded span as (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE) to process several blocks at once.
    A_reshaped = tl.reshape(A, (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE))
    # Per-block absmax.
    absmax = tl.max(tl.abs(A_reshaped), axis=1)
    # Absmax slots owned by this program. The launch grid is rounded up, so the last
    # program can own slots for block pairs that lie entirely past the input. Storing
    # those unconditionally could write past the end of the absmax buffer. Only pairs
    # that start inside the input are stored. A trailing, half-filled pair still gets
    # the (zero) absmax of its padding block, as before.
    # NOTE(review): assumes the caller sizes absmax to a whole number of block pairs.
    absmax_idx = block_start_idx + tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS)
    absmax_mask = (absmax_idx // 2) * (2 * BLOCK_SIZE) < n_elements
    tl.store(absmax_ptr + absmax_idx, absmax, mask=absmax_mask)
    A_normalized = A_reshaped / absmax[:, None]
    A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
    sign = tl.where(A_normalized < 0, 0b1000, 0b0000)
    A_absf = tl.abs(A_normalized)
    # Decision tree on |x|. Each threshold is the midpoint between two adjacent FP4
    # magnitudes (0, 0.00520833, 0.1666, 0.25, 0.3333, 0.5, 0.6666, 1.0).
    result = tl.where(
        A_absf > 0.29166667,
        tl.where(
            A_absf > 0.583333, tl.where(A_absf > 0.8333333, 0b011, 0b010), tl.where(A_absf > 0.4166667, 0b101, 0b100)
        ),
        tl.where(
            A_absf > 0.0859375,
            tl.where(A_absf > 0.20833333, 0b0111, 0b0110),
            tl.where(A_absf > 0.00260417, 0b0001, 0b0000),
        ),
    )
    # The magnitude code is < 8, so XOR with the sign bit is the same as adding it.
    quantized = (result ^ sign).to(tl.uint8)
    # Pack neighbouring codes: the first of each pair goes to the high nibble.
    quantized = quantized.reshape((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE // 2, 2))
    left, right = quantized.split()
    packed = left << 4 | (right & 0xF)
    packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
    out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
    # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
    out_mask = out_offsets < (n_elements - n_elements // 2)
    tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeNF4
# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# triton.Config({"SPLIT_NUM_BLOCKS": 4}),
# triton.Config({"SPLIT_NUM_BLOCKS": 8}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_nf4_blockwise_kernel(
    A_ptr,
    absmax_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    SPLIT_NUM_BLOCKS: tl.constexpr,
):
    """Quantize a flat tensor to packed NF4 codes, one absmax per ``BLOCK_SIZE`` elements.

    Each program handles ``2 * SPLIT_NUM_BLOCKS`` consecutive blocks. For every block
    it stores the block's absolute maximum to ``absmax_ptr``. It then normalizes the
    block by that value and maps each element to the 4-bit index of the nearest NF4
    level. Codes are packed two per byte into ``out_ptr``; the lower-indexed element
    of each pair goes into the high nibble.

    Args:
        A_ptr: input values, read as a flat array of ``n_elements`` entries.
        absmax_ptr: per-block absmax output, indexed by block number.
        out_ptr: packed uint8 output (two codes per byte).
        n_elements: number of valid input elements; lanes past it load 0.0.
        BLOCK_SIZE: elements per quantization block (must be even, because
            elements are packed in pairs).
        SPLIT_NUM_BLOCKS: number of block pairs processed per program.
    """
    PAIRED_SPLIT_NUM_BLOCKS: tl.constexpr = SPLIT_NUM_BLOCKS * 2
    block_start_idx = tl.program_id(0) * PAIRED_SPLIT_NUM_BLOCKS
    thread_idx = tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS * BLOCK_SIZE)
    offsets = block_start_idx * BLOCK_SIZE + thread_idx
    mask = offsets < n_elements
    A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
    # View the loaded span as (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE) to process several blocks at once.
    A_reshaped = tl.reshape(A, (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE))
    # Per-block absmax.
    absmax = tl.max(tl.abs(A_reshaped), axis=1)
    # Absmax slots owned by this program. The launch grid is rounded up, so the last
    # program can own slots for block pairs that lie entirely past the input. Storing
    # those unconditionally could write past the end of the absmax buffer. Only pairs
    # that start inside the input are stored. A trailing, half-filled pair still gets
    # the (zero) absmax of its padding block, as before.
    # NOTE(review): assumes the caller sizes absmax to a whole number of block pairs.
    absmax_idx = block_start_idx + tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS)
    absmax_mask = (absmax_idx // 2) * (2 * BLOCK_SIZE) < n_elements
    tl.store(absmax_ptr + absmax_idx, absmax, mask=absmax_mask)
    A_normalized = A_reshaped / absmax[:, None]
    A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
    # Decision tree on x. Each threshold is the midpoint between two adjacent NF4 levels.
    result = tl.where(
        A_normalized > 0.03979014977812767,
        tl.where(
            A_normalized > 0.3893125355243683,
            tl.where(
                A_normalized > 0.6427869200706482,
                tl.where(A_normalized > 0.8614784181118011, 0b1111, 0b1110),
                tl.where(A_normalized > 0.5016634166240692, 0b1101, 0b1100),
            ),
            tl.where(
                A_normalized > 0.2035212516784668,
                tl.where(A_normalized > 0.2920137718319893, 0b1011, 0b1010),
                tl.where(A_normalized > 0.1202552504837513, 0b1001, 0b1000),
            ),
        ),
        tl.where(
            A_normalized > -0.33967943489551544,
            tl.where(
                A_normalized > -0.13791173323988914,
                tl.where(A_normalized > -0.045525018125772476, 0b0111, 0b0110),
                tl.where(A_normalized > -0.23460740596055984, 0b0101, 0b0100),
            ),
            tl.where(
                A_normalized > -0.6106329262256622,
                tl.where(A_normalized > -0.4599952697753906, 0b0011, 0b0010),
                tl.where(A_normalized > -0.8480964004993439, 0b0001, 0b0000),
            ),
        ),
    )
    quantized = result.to(tl.uint8)
    # Pack neighbouring codes: the first of each pair goes to the high nibble.
    quantized = quantized.reshape((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE // 2, 2))
    left, right = quantized.split()
    packed = left << 4 | (right & 0xF)
    packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
    out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
    # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
    out_mask = out_offsets < (n_elements - n_elements // 2)
    tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)
def quantize_4bit_blockwise_triton(A, blocksize, quant_type, blocks, absmax, num_elements, quantized_out):
    """Launch the Triton 4-bit blockwise quantization kernel selected by ``quant_type``.

    Args:
        A: tensor to quantize.
        blocksize: quantization block size, forwarded as the kernel's ``BLOCK_SIZE``.
        quant_type: ``"fp4"`` selects the FP4 kernel; any other value selects NF4.
        blocks: count used to size the launch grid. Each program covers
            4 block pairs (8 blocks), so this is presumably the number of block
            pairs. NOTE(review): confirm against the caller.
        absmax: output buffer for per-block absmax values.
        num_elements: number of valid elements in ``A``.
        quantized_out: output buffer for the packed 4-bit codes.

    Returns:
        The ``(quantized_out, absmax)`` buffers, filled in place.
    """
    pairs_per_program = 4
    kernel = quantize_fp4_blockwise_kernel if quant_type == "fp4" else quantize_nf4_blockwise_kernel
    # Autotuning alternative: grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
    launch_grid = (triton.cdiv(blocks, pairs_per_program),)
    kernel[launch_grid](
        A_ptr=A,
        absmax_ptr=absmax,
        out_ptr=quantized_out,
        n_elements=num_elements,
        BLOCK_SIZE=blocksize,
        SPLIT_NUM_BLOCKS=pairs_per_program,
    )
    return quantized_out, absmax
@triton.jit
def dequant_4bit_body_util(a, offsets, quant_ptr, absmax_ptr, n_elems, QUANT_BLOCK: tl.constexpr):
    """Decode packed 4-bit codes through an explicit lookup table and scale by absmax.

    Each byte of ``a`` holds two codes. The high nibble is the first (lower-index)
    value of the pair, and the low nibble is the second. This matches the packing
    done by the quantize kernels in this file.

    Args:
        a: loaded packed bytes (assumed uint8, so each nibble is 0..15).
        offsets: positions of ``a`` in the packed input; used to locate each byte's absmax.
        quant_ptr: lookup table indexed by 4-bit code (read at indices 0..15).
        absmax_ptr: per-block scales, one per ``QUANT_BLOCK`` decoded values.
        n_elems: number of valid packed bytes; lanes at or past it use a scale of 1.0.
        QUANT_BLOCK: quantization block size in decoded values.

    Returns:
        A tensor twice the length of ``a``, with the two values of each byte
        interleaved (high-nibble value first).
    """
    # Each byte carries two values, so one absmax block spans QUANT_BLOCK // 2 bytes.
    PAIRED_QUANT_BLOCK: tl.constexpr = QUANT_BLOCK // 2
    mask = offsets < n_elems
    # High nibble = first value of the pair; low nibble = second value.
    first_code = a >> 4
    second_code = a & 0xF
    abs_offsets = offsets // PAIRED_QUANT_BLOCK
    # Masked lanes get a neutral scale instead of reading past the absmax buffer.
    absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=1.0, eviction_policy="evict_last")
    # Table lookup for each code.
    first_level = tl.load(quant_ptr + first_code, eviction_policy="evict_last")
    second_level = tl.load(quant_ptr + second_code, eviction_policy="evict_last")
    first_val = first_level * absmax
    second_val = second_level * absmax
    # Restore element order: [first, second, first, second, ...].
    out_dq = tl.interleave(first_val, second_val)
    return out_dq
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dDequantizeFP4Tree
@triton.jit
def dequantize_fp4_tree(val, absmax):
    """Decode 4-bit FP4 codes and scale them by ``absmax``.

    Bit 3 of each code is the sign. Bits 2..0 select one of eight magnitudes:
    0b100 -> 0.3333, 0b101 -> 0.5, 0b110 -> 0.1667, 0b111 -> 0.25,
    0b000 -> 0.0, 0b001 -> 0.00520833, 0b010 -> 0.6667, 0b011 -> 1.0.

    Args:
        val: tensor of 4-bit codes (only the low four bits are inspected).
        absmax: per-element scale, broadcast against ``val``.

    Returns:
        ``magnitude * sign * absmax`` for every code.
    """
    is_negative = (val & 0b1000) != 0
    bit2 = (val & 0b0100) != 0
    bit1 = (val & 0b0010) != 0
    bit0 = (val & 0b0001) != 0
    # Magnitudes for codes with bit 2 set (x1xx).
    mag_bit2_set = tl.where(
        bit1,
        tl.where(bit0, 0.25, 0.16666667),  # x111, x110
        tl.where(bit0, 0.5, 0.33333333),  # x101, x100
    )
    # Magnitudes for codes with bit 2 clear (x0xx).
    mag_bit2_clear = tl.where(
        bit1,
        tl.where(bit0, 1.0, 0.66666667),  # x011, x010
        tl.where(bit0, 0.00520833, 0.0),  # x001, x000
    )
    magnitude = tl.where(bit2, mag_bit2_set, mag_bit2_clear)
    sign = tl.where(is_negative, -1.0, 1.0)
    return magnitude * sign * absmax
@triton.jit
def dequant_fp4_body_util(a, offsets, absmax_ptr, n_elems, QUANT_BLOCK: tl.constexpr):
    """Decode packed FP4 bytes, two codes per byte, scaled by each block's absmax.

    The high nibble of each byte is the first value of the pair, and the low nibble
    is the second. Lanes at or past ``n_elems`` use a scale of 1.0.

    Returns:
        A tensor twice the length of ``a`` with both values of each byte interleaved
        (high-nibble value first).
    """
    # Two decoded values per byte, so one absmax block spans QUANT_BLOCK // 2 bytes.
    bytes_per_block: tl.constexpr = QUANT_BLOCK // 2
    in_range = offsets < n_elems
    scale = tl.load(absmax_ptr + offsets // bytes_per_block, mask=in_range, other=1.0, eviction_policy="evict_last")
    first_vals = dequantize_fp4_tree(a >> 4, scale)
    second_vals = dequantize_fp4_tree(a & 0xF, scale)
    return tl.interleave(first_vals, second_vals)
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dDequantizeNF4
@triton.jit
def dequantize_nf4_tree(val):
    """Map 4-bit NF4 codes to their normalized (not yet absmax-scaled) levels.

    The nested ``tl.where`` tree tests bit 3, then bits 2, 1 and 0 of each code.
    The 16 levels are ascending in code order: 0b0000 -> -1.0, 0b0111 -> 0.0,
    0b1111 -> 1.0. Callers in this file multiply the result by the block absmax.

    Args:
        val: tensor of 4-bit codes; only the low four bits are inspected.

    Returns:
        Tensor of NF4 levels, one per element of ``val``.
    """
    # val: tl.tensor (uint8)
    cond0 = (val & 0b1000) == 0b1000  # bit 3: set for the strictly positive levels (codes 8..15)
    cond1 = (val & 0b0100) == 0b0100  # bit 2
    cond2 = (val & 0b0010) == 0b0010  # bit 1
    cond3 = (val & 0b0001) == 0b0001  # bit 0
    # Positive branch (val & 0b1000) == 8
    branch_pos = tl.where(
        cond1,
        tl.where(
            cond2,
            tl.where(cond3, 1.0, 0.7229568362236023),  # 1111, 1110
            tl.where(cond3, 0.5626170039176941, 0.44070982933044434),  # 1101, 1100
        ),
        tl.where(
            cond2,
            tl.where(cond3, 0.33791524171829224, 0.24611230194568634),  # 1011, 1010
            tl.where(cond3, 0.16093020141124725, 0.07958029955625534),  # 1001, 1000
        ),
    )
    # Non-positive branch (val & 0b1000) == 0; code 0111 is the exact zero level.
    branch_neg = tl.where(
        cond1,
        tl.where(
            cond2,
            tl.where(cond3, 0.0, -0.09105003625154495),  # 0111, 0110
            tl.where(cond3, -0.18477343022823334, -0.28444138169288635),  # 0101, 0100
        ),
        tl.where(
            cond2,
            tl.where(cond3, -0.39491748809814453, -0.5250730514526367),  # 0011, 0010
            tl.where(cond3, -0.6961928009986877, -1.0),  # 0001, 0000
        ),
    )
    return tl.where(cond0, branch_pos, branch_neg)
@triton.jit
def dequant_nf4_body_util(a, offsets, absmax_ptr, n_elems, QUANT_BLOCK: tl.constexpr):
    """Decode packed NF4 bytes, two codes per byte, scaled by each block's absmax.

    Each byte of ``a`` holds two codes. The high nibble is the first (lower-index)
    value of the pair, and the low nibble is the second. This matches the packing
    done by the quantize kernels in this file.

    Args:
        a: loaded packed bytes (assumed uint8).
        offsets: positions of ``a`` in the packed input; used to locate each byte's absmax.
        absmax_ptr: per-block scales, one per ``QUANT_BLOCK`` decoded values.
        n_elems: number of valid packed bytes; lanes at or past it use a scale of 1.0.
        QUANT_BLOCK: quantization block size in decoded values.

    Returns:
        A tensor twice the length of ``a``, with the two values of each byte
        interleaved (high-nibble value first).
    """
    # Each byte carries two values, so one absmax block spans QUANT_BLOCK // 2 bytes.
    PAIRED_QUANT_BLOCK: tl.constexpr = QUANT_BLOCK // 2
    mask = offsets < n_elems
    # High nibble = first value of the pair; low nibble = second value.
    first_code = a >> 4
    second_code = a & 0xF
    abs_offsets = offsets // PAIRED_QUANT_BLOCK
    # Masked lanes get a neutral scale instead of reading past the absmax buffer.
    absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=1.0, eviction_policy="evict_last")
    first_val = dequantize_nf4_tree(first_code) * absmax
    second_val = dequantize_nf4_tree(second_code) * absmax
    # Restore element order: [first, second, first, second, ...].
    out_dq = tl.interleave(first_val, second_val)
    return out_dq
# All such kernels are similar, so maybe code can be generalised.
# @triton.autotune(
# configs=[
# # # triton.Config({'SPLIT_SIZE': 64}),
# # # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({'SPLIT_SIZE': 128}),
# triton.Config({'SPLIT_SIZE': 128}, num_warps = 32, num_stages = 2),
# # # triton.Config({'SPLIT_SIZE': 128}, num_warps = 4, num_stages = 4),
# # # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({'SPLIT_SIZE': 256}),
# triton.Config({'SPLIT_SIZE': 256}, num_warps = 32, num_stages = 2),
# # triton.Config({'SPLIT_SIZE': 256}, num_warps = 4, num_stages = 4),
# triton.Config({'SPLIT_SIZE': 512}),
# triton.Config({'SPLIT_SIZE': 512}, num_warps = 32, num_stages = 2),
# # triton.Config({'SPLIT_SIZE': 512}, num_warps = 4, num_stages = 4),
# # # # triton.Config({'SPLIT_SIZE': 512, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 512, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 512, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 512, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # # triton.Config({'SPLIT_SIZE': 1024}),
# # # # triton.Config({'SPLIT_SIZE': 2048}),
# # # # triton.Config({'SPLIT_SIZE': 4096}),
# # # # triton.Config({'SPLIT_SIZE': 8192}),
# # # # triton.Config({'SPLIT_SIZE': 16384}),
# ],
# key=['num_paired_elements'],
# )
@triton.jit
def dequant_4bit_kernel(
    a_ptr,
    c_ptr,
    quant_ptr,
    absmax_ptr,
    num_paired_elements,
    num_output_elements,
    QUANT_BLOCK: tl.constexpr,
    SPLIT_SIZE: tl.constexpr,
):
    """Decode packed 4-bit bytes from ``a_ptr`` into ``c_ptr`` via the table at ``quant_ptr``.

    Program ``pid`` reads ``SPLIT_SIZE`` packed bytes starting at ``pid * SPLIT_SIZE``.
    It writes the ``2 * SPLIT_SIZE`` decoded values starting at ``2 * pid * SPLIT_SIZE``.
    Reads are masked by ``num_paired_elements``; writes by ``num_output_elements``.
    """
    pid = tl.program_id(axis=0)  # 1D launch grid
    src_offsets = pid * SPLIT_SIZE + tl.arange(0, SPLIT_SIZE)
    packed = tl.load(a_ptr + src_offsets, mask=src_offsets < num_paired_elements, eviction_policy="evict_first")
    decoded = dequant_4bit_body_util(
        a=packed,
        offsets=src_offsets,
        quant_ptr=quant_ptr,
        absmax_ptr=absmax_ptr,
        n_elems=num_paired_elements,
        QUANT_BLOCK=QUANT_BLOCK,
    )
    dst_offsets = pid * SPLIT_SIZE * 2 + tl.arange(0, SPLIT_SIZE * 2)
    tl.store(c_ptr + dst_offsets, decoded, mask=dst_offsets < num_output_elements)
# @triton.autotune(
# configs=[
# triton.Config({'SPLIT_SIZE': 128}, num_warps = 32, num_stages = 2),
# triton.Config({'SPLIT_SIZE': 256}),
# triton.Config({'SPLIT_SIZE': 256}, num_warps = 32, num_stages = 2),
# triton.Config({'SPLIT_SIZE': 512}),
# triton.Config({'SPLIT_SIZE': 512}, num_warps = 32, num_stages = 2),
# triton.Config({'SPLIT_SIZE': 1024}, num_warps = 32, num_stages = 2),
# ],
# key=['num_paired_elements'],
# )
@triton.jit
def dequant_fp4_kernel(
    a_ptr,
    c_ptr,
    absmax_ptr,
    num_paired_elements,
    num_output_elements,
    QUANT_BLOCK: tl.constexpr,
    SPLIT_SIZE: tl.constexpr,
):
    """Decode packed FP4 bytes from ``a_ptr`` into ``c_ptr``.

    Program ``pid`` reads ``SPLIT_SIZE`` packed bytes starting at ``pid * SPLIT_SIZE``.
    It writes the ``2 * SPLIT_SIZE`` decoded values starting at ``2 * pid * SPLIT_SIZE``.
    Reads are masked by ``num_paired_elements``; writes by ``num_output_elements``.
    """
    pid = tl.program_id(axis=0)  # 1D launch grid
    src_offsets = pid * SPLIT_SIZE + tl.arange(0, SPLIT_SIZE)
    packed = tl.load(a_ptr + src_offsets, mask=src_offsets < num_paired_elements, eviction_policy="evict_first")
    decoded = dequant_fp4_body_util(
        a=packed,
        offsets=src_offsets,
        absmax_ptr=absmax_ptr,
        n_elems=num_paired_elements,
        QUANT_BLOCK=QUANT_BLOCK,
    )
    dst_offsets = pid * SPLIT_SIZE * 2 + tl.arange(0, SPLIT_SIZE * 2)
    tl.store(c_ptr + dst_offsets, decoded, mask=dst_offsets < num_output_elements)
# @triton.autotune(
# configs=[
# triton.Config({'SPLIT_SIZE': 128}, num_warps = 32, num_stages = 2),
# triton.Config({'SPLIT_SIZE': 256}),
# triton.Config({'SPLIT_SIZE': 256}, num_warps = 32, num_stages = 2),
# triton.Config({'SPLIT_SIZE': 512}),
# triton.Config({'SPLIT_SIZE': 512}, num_warps = 32, num_stages = 2),
# triton.Config({'SPLIT_SIZE': 1024}, num_warps = 32, num_stages = 2),
# ],
# key=['num_paired_elements'],
# )
@triton.jit
def dequant_nf4_kernel(
    a_ptr,
    c_ptr,
    absmax_ptr,
    num_paired_elements,
    num_output_elements,
    QUANT_BLOCK: tl.constexpr,
    SPLIT_SIZE: tl.constexpr,
):
    """Decode packed NF4 bytes from ``a_ptr`` into ``c_ptr``.

    Program ``pid`` reads ``SPLIT_SIZE`` packed bytes starting at ``pid * SPLIT_SIZE``.
    It writes the ``2 * SPLIT_SIZE`` decoded values starting at ``2 * pid * SPLIT_SIZE``.
    Reads are masked by ``num_paired_elements``; writes by ``num_output_elements``.
    """
    pid = tl.program_id(axis=0)  # 1D launch grid
    src_offsets = pid * SPLIT_SIZE + tl.arange(0, SPLIT_SIZE)
    packed = tl.load(a_ptr + src_offsets, mask=src_offsets < num_paired_elements, eviction_policy="evict_first")
    decoded = dequant_nf4_body_util(
        a=packed,
        offsets=src_offsets,
        absmax_ptr=absmax_ptr,
        n_elems=num_paired_elements,
        QUANT_BLOCK=QUANT_BLOCK,
    )
    dst_offsets = pid * SPLIT_SIZE * 2 + tl.arange(0, SPLIT_SIZE * 2)
    tl.store(c_ptr + dst_offsets, decoded, mask=dst_offsets < num_output_elements)
def dequantize_4bit_impl(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    """Dequantize packed 4-bit ``A`` into ``out`` with the FP4 or NF4 Triton kernel.

    ``A`` is treated as a flat array of bytes, each holding two 4-bit codes, so up to
    ``2 * A.numel()`` values are decoded; stores are bounded by ``out.numel()``.
    ``quant_type == "fp4"`` selects the FP4 kernel; any other value selects NF4.
    ``dtype`` is accepted for interface compatibility but is not used here.
    """
    packed_count = A.numel()
    output_count = out.numel()
    split_size = 256  # packed bytes handled per program
    # Autotuning alternative: grid = lambda META: (triton.cdiv(packed_count, META['SPLIT_SIZE']), )
    launch_grid = (triton.cdiv(packed_count, split_size),)
    kernel = dequant_fp4_kernel if quant_type == "fp4" else dequant_nf4_kernel
    kernel[launch_grid](A, out, absmax, packed_count, output_count, blocksize, split_size)
def dequantize_4bit_impl_passing_code(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    code: torch.Tensor,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    """Dequantize packed 4-bit ``A`` into ``out`` using the explicit lookup table ``code``.

    Each 4-bit code indexes ``code`` and is scaled by its block's ``absmax`` entry.
    Up to ``2 * A.numel()`` values are decoded; stores are bounded by ``out.numel()``.
    ``dtype`` is accepted for interface compatibility but is not used here.
    """
    packed_count, output_count = A.numel(), out.numel()
    split_size = 256  # packed bytes handled per program
    # Autotuning alternative: grid = lambda META: (triton.cdiv(packed_count, META['SPLIT_SIZE']), )
    launch_grid = (triton.cdiv(packed_count, split_size),)
    dequant_4bit_kernel[launch_grid](A, out, code, absmax, packed_count, output_count, blocksize, split_size)
######################### Fallback dequantization functions #########################
## for debug ##
# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # #
# # triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# #
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "large"}, num_stages=2, num_warps=32),
# # # triton.Config({'SPLIT_NUM_BLOCKS': 2, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=2, num_warps=32),
# # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=2, num_warps=32),
# # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_NUM_BLOCKS': 8, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# ],
# key=["n_elements", "BLOCK_SIZE"],
# )
@triton.jit
def quantize_4bit_blockwise_kernel(
    A_ptr,
    code_ptr,
    absmax_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    CODE_SIZE: tl.constexpr,
    SPLIT_NUM_BLOCKS: tl.constexpr,
):
    """Quantize a flat tensor to packed 4-bit indices into an arbitrary lookup table.

    Each program handles ``2 * SPLIT_NUM_BLOCKS`` consecutive blocks. For every block
    it stores the block's absolute maximum to ``absmax_ptr`` and normalizes the block
    by it. It then binary-searches ``code_ptr`` for the two levels bracketing each value
    and keeps the nearer one (ties go to the lower index). Indices are packed two per
    byte into ``out_ptr``; the lower-indexed element of each pair goes into the high
    nibble.

    Args:
        A_ptr: input values, read as a flat array of ``n_elements`` entries.
        code_ptr: table of ``CODE_SIZE`` levels. NOTE(review): the binary search
            assumes it is sorted ascending.
        absmax_ptr: per-block absmax output, indexed by block number.
        out_ptr: packed uint8 output (two indices per byte).
        n_elements: number of valid input elements; lanes past it load 0.0.
        BLOCK_SIZE: elements per quantization block (must be even).
        CODE_SIZE: number of table entries. Must be <= 16: the search runs 4 steps,
            and indices are packed into nibbles.
        SPLIT_NUM_BLOCKS: number of block pairs processed per program.
    """
    PAIRED_SPLIT_NUM_BLOCKS: tl.constexpr = SPLIT_NUM_BLOCKS * 2
    block_start_idx = tl.program_id(0) * PAIRED_SPLIT_NUM_BLOCKS
    thread_idx = tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS * BLOCK_SIZE)
    offsets = block_start_idx * BLOCK_SIZE + thread_idx
    mask = offsets < n_elements
    A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
    # View the loaded span as (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE) to process several blocks at once.
    A_reshaped = tl.reshape(A, (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE))
    # Per-block absmax.
    absmax = tl.max(tl.abs(A_reshaped), axis=1)
    # Only store absmax for block pairs that start inside the input. Slots of pairs
    # lying entirely past the input (possible when the launch grid is rounded up)
    # could otherwise be written past the end of the absmax buffer.
    # NOTE(review): assumes the caller sizes absmax to a whole number of block pairs.
    absmax_idx = block_start_idx + tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS)
    absmax_mask = (absmax_idx // 2) * (2 * BLOCK_SIZE) < n_elements
    tl.store(absmax_ptr + absmax_idx, absmax, mask=absmax_mask)
    A_normalized = A_reshaped / absmax[:, None]
    A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
    # Binary search for the bracketing pair [lower_pivot, upper_pivot] of table indices.
    lower_pivot = tl.zeros((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32)
    upper_pivot = tl.full((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
    for _ in range(4):  # 4 halvings shrink [0, 15] to adjacent indices; enough for CODE_SIZE <= 16
        pivot = (lower_pivot + upper_pivot) // 2
        val = tl.load(code_ptr + pivot)
        is_higher = A_normalized > val  # code[pivot]
        lower_pivot = tl.where(is_higher, pivot, lower_pivot)
        upper_pivot = tl.where(is_higher, upper_pivot, pivot)
    # Choose closest level
    lower_val = tl.load(code_ptr + lower_pivot)
    upper_val = tl.load(code_ptr + upper_pivot)
    lower_dist = tl.abs(A_normalized - lower_val)
    upper_dist = tl.abs(A_normalized - upper_val)
    quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
    # Pack neighbouring indices: the first of each pair goes to the high nibble.
    # split() is used rather than tl.reduce, because reduce does not guarantee the
    # order in which elements reach the combine function.
    quantized = quantized.reshape((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE // 2, 2))
    left, right = quantized.split()
    packed = left << 4 | (right & 0xF)
    packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
    out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
    # NOTE(review): the FP4/NF4 kernels in this file mask with ceil(n / 2). This one uses
    # floor, so for odd n the last element's nibble is not written. Confirm which size
    # the caller allocates for out.
    out_mask = out_offsets < n_elements // 2
    tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)