| import torch |
|
|
| import triton |
| import triton.language as tl |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
@triton.jit
def dequant_8bit_kernel(
    a_ptr,
    out_ptr,
    code_ptr,
    absmax_ptr,
    n,
    QUANT_BLOCK: tl.constexpr,
    SPLIT_SIZE: tl.constexpr,
):
    """Dequantize one SPLIT_SIZE-wide slice of the flat input per program.

    Program ``pid`` covers element indices
    ``[pid * SPLIT_SIZE, (pid + 1) * SPLIT_SIZE)``; indices at or beyond ``n``
    are masked out of both the loads and the final store.

    Args:
        a_ptr: Pointer to the quantized codes.
        out_ptr: Pointer to the output buffer for dequantized values.
        code_ptr: Pointer to the lookup table indexed by each code.
        absmax_ptr: Pointer to the per-block scales (one per QUANT_BLOCK
            consecutive elements).
        n: Total number of elements to process.
        QUANT_BLOCK: Number of consecutive elements sharing one absmax scale.
        SPLIT_SIZE: Number of elements handled by one program instance.
    """
    # Flat element indices owned by this program instance.
    elem_idx = tl.program_id(axis=0) * SPLIT_SIZE + tl.arange(0, SPLIT_SIZE)
    in_bounds = elem_idx < n

    dequantized = dequant_8bit_blockwise_kernel_util(
        a_ptr, elem_idx, code_ptr, absmax_ptr, in_bounds, QUANT_BLOCK
    )
    tl.store(out_ptr + elem_idx, dequantized, mask=in_bounds)
|
|
|
|
def dequant_8bit_blockwise(
    a: torch.Tensor,
    absmax: torch.Tensor,
    quant_state_code: torch.Tensor,
    quant_blocksize: int = 64,
    dtype: torch.dtype = None,
    out: torch.Tensor = None,
):
    """Launch the Triton kernel that dequantizes a blockwise 8-bit tensor.

    Args:
        a: Tensor of quantized codes (the kernel casts each loaded value to
            uint8 and uses it as an index into ``quant_state_code``).
        absmax: Per-block scales; element ``i`` of ``a`` uses
            ``absmax[i // quant_blocksize]``.
        quant_state_code: Lookup table mapping a code to its normalized value.
        quant_blocksize: Number of consecutive elements sharing one scale.
        dtype: Output dtype; only consulted when ``out`` is not supplied.
        out: Optional preallocated output tensor, written in place.

    Returns:
        The tensor holding the dequantized values (``out`` when it was given).

    Raises:
        ValueError: If both ``out`` and ``dtype`` are ``None``.

    NOTE(review): the kernel addresses memory by flat element offset, so the
    tensors are presumably expected to be contiguous -- confirm with callers.
    """
    total_elements = a.numel()

    if out is None:
        if dtype is None:
            raise ValueError("If out is None, dtype must be specified")
        out = torch.empty_like(a, dtype=dtype, device=a.device)

    # Each program instance dequantizes this many elements.
    elements_per_program = 256
    launch_grid = (triton.cdiv(total_elements, elements_per_program),)

    dequant_8bit_kernel[launch_grid](
        a,
        out,
        quant_state_code,
        absmax,
        total_elements,
        quant_blocksize,
        elements_per_program,
    )
    return out
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
@triton.jit
def quantize_8bit_blockwise_kernel(
    A_ptr,
    code_ptr,
    absmax_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    CODE_SIZE: tl.constexpr,
    SPLIT_NUM_BLOCKS: tl.constexpr,
):
    """Quantize SPLIT_NUM_BLOCKS consecutive quantization blocks per program.

    Each program loads ``SPLIT_NUM_BLOCKS * BLOCK_SIZE`` consecutive elements
    (zero-filled past ``n_elements``), quantizes them against the code table,
    and writes one absmax per block plus one uint8 code per element.

    Args:
        A_ptr: Pointer to the input values.
        code_ptr: Pointer to the code table with CODE_SIZE entries.
        absmax_ptr: Pointer to the per-block absmax output.
        out_ptr: Pointer to the uint8 code output.
        n_elements: Total number of input elements.
        BLOCK_SIZE: Number of elements per quantization block.
        CODE_SIZE: Number of entries in the code table.
        SPLIT_NUM_BLOCKS: Number of quantization blocks handled per program.
    """
    block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
    thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)

    offsets = block_start_idx * BLOCK_SIZE + thread_idx
    mask = offsets < n_elements

    # Out-of-range lanes read as 0.0 so they cannot raise a block's absmax.
    A = tl.load(A_ptr + offsets, mask=mask, other=0.0)

    quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS)

    # A block exists only if its first element lies inside the input. Masking
    # the absmax store keeps the last program from writing past the end of the
    # absmax buffer when the block count is not a multiple of SPLIT_NUM_BLOCKS.
    block_idx = block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS)
    block_mask = block_idx * BLOCK_SIZE < n_elements
    tl.store(absmax_ptr + block_idx, absmax, mask=block_mask)
    tl.store(out_ptr + offsets, quantized, mask=mask)
|
|
|
|
def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None):
    """Quantize ``A`` blockwise to 8-bit codes using the Triton kernel.

    Args:
        A: Input tensor to quantize.
        code: Code table tensor; its element count is passed to the kernel as
            the table size.
        blocksize: Number of consecutive elements that share one absmax.
        absmax: Optional preallocated per-block absmax buffer. When omitted,
            one entry per block is allocated on ``A``'s device with ``A``'s
            dtype.
        out: Optional preallocated output buffer. When omitted, a flat uint8
            tensor with one entry per element of ``A`` is allocated.

    Returns:
        A tuple ``(out, absmax)`` where ``out`` is reshaped to ``A.shape``.
    """
    n = A.numel()
    # Ceiling division written with floor division on negated operands.
    num_blocks = -(n // -blocksize)

    if absmax is None:
        absmax = torch.empty((num_blocks,), device=A.device, dtype=A.dtype)
    if out is None:
        out = torch.empty_like(A.flatten(), dtype=torch.uint8)

    # Number of quantization blocks processed by one program instance.
    blocks_per_program = 1
    launch_grid = (triton.cdiv(num_blocks, blocks_per_program),)

    quantize_8bit_blockwise_kernel[launch_grid](
        A_ptr=A,
        code_ptr=code,
        absmax_ptr=absmax,
        out_ptr=out,
        n_elements=n,
        BLOCK_SIZE=blocksize,
        CODE_SIZE=code.numel(),
        SPLIT_NUM_BLOCKS=blocks_per_program,
    )

    reshaped_out = out.reshape(A.shape)
    return reshaped_out, absmax
|
|
|
|
@triton.jit
def quantize_8bit_blockwise_kernel_util(
    a,
    code_ptr,
    CODE_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    N_PER_TH: tl.constexpr,
):
    """Quantize already-loaded values to the nearest code-table entry.

    The input is viewed as ``N_PER_TH`` rows of ``BLOCK_SIZE`` values. Each row
    is divided by its own absolute maximum and clamped to ``[-1, 1]``; every
    value is then mapped to the index of the closest entry found by a binary
    search over the table at ``code_ptr``.

    Args:
        a: Flat block of ``N_PER_TH * BLOCK_SIZE`` loaded values.
        code_ptr: Pointer to the code table with CODE_SIZE entries.
        CODE_SIZE: Number of entries in the code table.
        BLOCK_SIZE: Number of elements per quantization block.
        N_PER_TH: Number of quantization blocks contained in ``a``.

    Returns:
        ``(codes, absmax)``: flat uint8 table indices, one per input value,
        and the per-row absolute maxima.

    NOTE(review): the search presumably relies on the code table being sorted
    in ascending order -- confirm. A row whose absmax is 0 divides by zero
    here; verify that callers tolerate the resulting codes.
    """
    # One row per quantization block so reductions run along axis 1.
    rows = tl.reshape(a, (N_PER_TH, BLOCK_SIZE))

    absmax = tl.max(tl.abs(rows), axis=1)

    scaled = rows / absmax[:, None]
    scaled = tl.clamp(scaled, -1.0, 1.0)

    # Inclusive index bounds of the search interval for every element.
    lo = tl.zeros((N_PER_TH, BLOCK_SIZE), dtype=tl.int32)
    hi = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)

    # Eight halvings shrink a 256-entry index range down to two adjacent
    # entries; the step count is fixed rather than derived from CODE_SIZE.
    for _ in range(8):
        mid = (lo + hi) // 2
        mid_val = tl.load(code_ptr + mid)
        above_mid = scaled > mid_val
        lo = tl.where(above_mid, mid, lo)
        hi = tl.where(above_mid, hi, mid)

    # Choose whichever bracketing entry is closer; ties go to the lower index.
    lo_err = tl.abs(scaled - tl.load(code_ptr + lo))
    hi_err = tl.abs(scaled - tl.load(code_ptr + hi))
    nearest = tl.where(lo_err <= hi_err, lo, hi).to(tl.uint8)

    codes = tl.reshape(nearest, (BLOCK_SIZE * N_PER_TH,))
    return codes, absmax
|
|
|
|
@triton.jit
def dequant_8bit_blockwise_kernel_util(
    a_ptr,
    offsets,
    code_ptr,
    absmax_ptr,
    mask,
    BLOCK_SIZE: tl.constexpr,
):
    """Dequantize the elements at ``offsets`` and return the values.

    Each code is loaded, cast to uint8, looked up in the code table, and
    multiplied by the absmax of the quantization block it belongs to.

    Args:
        a_ptr: Pointer to the quantized codes.
        offsets: Flat element indices to process.
        code_ptr: Pointer to the code table indexed by each code.
        absmax_ptr: Pointer to the per-block scales.
        mask: Lanes to process; masked-off lanes read code 0 and absmax 0.0.
        BLOCK_SIZE: Number of consecutive elements sharing one absmax scale.

    Returns:
        The dequantized values for the requested offsets. Callers should apply
        the same mask when storing them.
    """
    codes = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8)
    code_vals = tl.load(code_ptr + codes, mask)

    # Every BLOCK_SIZE consecutive elements share one scale entry.
    scale_idx = offsets // BLOCK_SIZE
    scales = tl.load(absmax_ptr + scale_idx, mask=mask, other=0.0, eviction_policy="evict_last")

    return code_vals * scales
|
|