| import functools | |
| import json | |
| import logging | |
| import os | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import torch | |
| import triton | |
| import triton.language as tl | |
| from sglang.srt.utils import get_device_name, is_cuda | |
# Cache the result of is_cuda() once at import time; it gates the optional
# sgl_kernel imports below.
_is_cuda = is_cuda()
if _is_cuda:
    # Temporary
    # Prefer the unified 8-bit group-quant op. If this sgl_kernel build does not
    # export it, fall back to the int8-only op and record which one is available
    # so sglang_per_token_group_quant_int8 can dispatch on the flag.
    try:
        from sgl_kernel import sgl_per_token_group_quant_8bit

        enable_sgl_per_token_group_quant_8bit = True
    except ImportError:
        from sgl_kernel import sgl_per_token_group_quant_int8

        enable_sgl_per_token_group_quant_8bit = False

# NOTE(review): when _is_cuda is false, neither the sgl_kernel ops nor the
# enable_sgl_per_token_group_quant_8bit flag are defined in this module.
logger = logging.getLogger(__name__)
@triton.jit
def _per_token_quant_int8(
    x_ptr,
    xq_ptr,
    scale_ptr,
    x_sum_ptr,
    stride_x,
    stride_xq,
    N,
    CAL_SUM: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Triton kernel: symmetric int8 quantization of one row per program.

    Each program handles the row selected by ``tl.program_id(0)``. It loads up
    to ``BLOCK`` elements (masked to the first ``N``), computes
    ``scale = max(absmax, 1e-10) / 127``, and stores the rounded quantized
    values plus the scale. When ``CAL_SUM`` is set, the float32 sum of the
    row's inputs is also stored.

    Args:
        x_ptr: Pointer to the input rows.
        xq_ptr: Pointer to the int8 output rows.
        scale_ptr: Pointer to the per-row scale output (one element per row).
        x_sum_ptr: Pointer to the per-row sum output; only written when
            ``CAL_SUM`` is true.
        stride_x: Element stride between consecutive input rows.
        stride_xq: Element stride between consecutive output rows.
        N: Number of valid columns in a row.
        CAL_SUM: Compile-time flag enabling the row-sum output.
        BLOCK: Compile-time tile width; the caller passes a power of two >= N.
    """
    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
    row_id = tl.program_id(0)

    cols = tl.arange(0, BLOCK)
    mask = cols < N

    # Masked-off lanes load 0.0, so they affect neither the absmax nor the sum.
    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    # The 1e-10 floor keeps the division below finite for an all-zero row.
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    x_q = x * (127 / absmax)
    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
    if CAL_SUM:
        x_sum = tl.sum(x, axis=0)
        tl.store(x_sum_ptr + row_id, x_sum.to(x_sum_ptr.dtype.element_ty))

    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
    tl.store(scale_ptr + row_id, scale_x.to(scale_ptr.dtype.element_ty))
def per_token_quant_int8(x, scale_dtype=torch.float32, cal_sum=False):
    """Quantize `x` to int8 with one symmetric scale per row of the last dim.

    Args:
        x: Contiguous input tensor. Every slice along the last dimension is
            quantized independently.
        scale_dtype: dtype of the returned scales.
        cal_sum: When true, also return the per-row sum of the inputs.

    Returns:
        ``(x_q, scales)`` or, when `cal_sum` is true, ``(x_q, scales, x_sum)``:
        `x_q` is int8 with the shape of `x`, `scales` has shape
        ``x.shape[:-1] + (1,)``, and `x_sum` has shape ``x.shape[:-1]`` and the
        dtype of `x`.

    Raises:
        AssertionError: If `x` is not contiguous.
    """
    # Validate before allocating any output buffers.
    assert x.is_contiguous()

    N = x.shape[-1]
    M = x.numel() // N

    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=scale_dtype)
    if cal_sum:
        x_sum = torch.empty(x.shape[:-1], device=x.device, dtype=x.dtype)
    else:
        x_sum = None

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)

    # `x` is contiguous, so consecutive rows are exactly N elements apart.
    # Use N rather than x.stride(-2): PyTorch lets a size-1 dimension carry an
    # arbitrary stride while still reporting is_contiguous(), and a 1-D input
    # has no dimension -2 at all. `x_q` shares the same dense layout.
    _per_token_quant_int8[(M,)](
        x,
        x_q,
        scales,
        x_sum,
        stride_x=N,
        stride_xq=N,
        N=N,
        CAL_SUM=cal_sum,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    if cal_sum:
        return x_q, scales, x_sum
    else:
        return x_q, scales
@triton.jit
def _per_token_group_quant_int8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Avoid to divide zero
    eps,
    # Information for int8
    int8_min,
    int8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group quantization on a
    tensor.

    This function converts the tensor values into int8 values.

    Each program handles the group selected by ``tl.program_id(0)``: it loads
    up to ``BLOCK`` elements (masked to the first ``N``), derives
    ``scale = max(absmax, eps) / int8_max``, and stores the clamped quotient
    ``y / scale`` cast to the output element type, plus the scale itself.
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    # Input and output are advanced by the same stride, so both buffers are
    # addressed with the same group layout.
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / int8_max
    y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
def per_token_group_quant_int8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` per token group using the Triton kernel.

    The last dimension of `x` is split into consecutive groups of
    `group_size` elements; each group is quantized with its own scale.

    Args:
        x: Contiguous input tensor whose last dimension is a multiple of
            `group_size`.
        group_size: Number of elements that share one scale.
        eps: Lower bound applied to a group's absolute maximum to avoid
            dividing by zero.
        dtype: Integer dtype of the quantized output; its min/max are taken
            from `torch.iinfo`. Only `torch.int8` is supported for now.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor (same shape as
        `x`) and the float32 scales of shape
        ``x.shape[:-1] + (x.shape[-1] // group_size,)``.
    """
    last_dim = x.shape[-1]
    assert (
        last_dim % group_size == 0
    ), "the last dimension of `x` cannot be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    q_info = torch.iinfo(dtype)
    q_max, q_min = q_info.max, q_info.min

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    scale_shape = x.shape[:-1] + (last_dim // group_size,)
    x_s = torch.empty(scale_shape, device=x.device, dtype=torch.float32)

    # One program per group; the tile width is the group size rounded up to a
    # power of two.
    num_groups = x.numel() // group_size
    block = triton.next_power_of_2(group_size)
    # heuristics for number of warps
    warps = min(max(block // 256, 1), 8)

    _per_token_group_quant_int8[(num_groups,)](
        x,
        x_q,
        x_s,
        group_size,  # stride between consecutive groups
        group_size,  # number of valid columns per group
        eps,
        int8_min=q_min,
        int8_max=q_max,
        BLOCK=block,
        num_warps=warps,
        num_stages=1,
    )

    return x_q, x_s
def sglang_per_token_group_quant_int8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.int8,
    enable_v2: Optional[bool] = None,
):
    """Quantize `x` per token group using the sgl_kernel ops.

    Args:
        x: Contiguous input tensor whose last dimension is a multiple of
            `group_size`.
        group_size: Number of elements that share one scale.
        eps: Forwarded to the kernel op as the zero-division guard.
        dtype: Integer dtype of the quantized output; its min/max are taken
            from `torch.iinfo`.
        enable_v2: Forwarded to `sgl_per_token_group_quant_8bit` when that op
            is available; must be falsy when only the int8 op was imported.

    Returns:
        The quantized tensor (same shape as `x`) and the float32 scales of
        shape ``x.shape[:-1] + (x.shape[-1] // group_size,)``.
    """
    last_dim = x.shape[-1]
    assert (
        last_dim % group_size == 0
    ), "the last dimension of `x` cannot be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    q_info = torch.iinfo(dtype)
    q_max, q_min = q_info.max, q_info.min

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    scale_shape = x.shape[:-1] + (last_dim // group_size,)
    x_s = torch.empty(scale_shape, device=x.device, dtype=torch.float32)

    # Temporary
    if not enable_sgl_per_token_group_quant_8bit:
        # The int8-only op has no v2 variant.
        assert not enable_v2
        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, q_min, q_max)
        return x_q, x_s

    sgl_per_token_group_quant_8bit(
        x, x_q, x_s, group_size, eps, q_min, q_max, enable_v2=enable_v2
    )
    return x_q, x_s
@triton.jit
def _w8a8_block_int8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
    tensor `C`.

    Each program computes one ``BLOCK_SIZE_M x BLOCK_SIZE_N`` tile of `C`. For
    every K-step the partial product is multiplied by the matching scales
    from `As` (one per row of the tile) and `Bs` (one per column of the tile)
    before being accumulated in float32.
    """

    # Map the 1-D program id to a (pid_m, pid_n) output tile, walking
    # GROUP_SIZE_M row-tiles at a time.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column offsets wrap modulo M / N so loads for a partial edge tile
    # stay in range; the out-of-range results are masked out at the store.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # One A-scale per tile row; B-scales are selected per column by the
    # column's quantization group along N.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the K tail of the last step; masked lanes contribute zeros.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # A single scale index is used for the whole K-step, taken from the
        # step's starting column.
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the accumulator to the element type of C (float32 for anything
    # other than bfloat16/float16).
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Unwrapped offsets are used for the store so the mask can drop rows and
    # columns that fall outside the M x N output.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
@functools.lru_cache
def get_w8a8_block_int8_configs(
    N: int, K: int, block_n: int, block_k: int
) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the w8a8 block int8 kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the w8a8 block int8 kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.

    Results are memoized per ``(N, K, block_n, block_k)``, so the config file
    is looked up and parsed (and the info/warning message logged) only once
    per shape instead of on every matmul call. The returned dictionary is
    shared between calls and must be treated as read-only.

    Returns:
        The parsed configuration map keyed by integer batch size, or None when
        no config file exists for this shape and device.
    """

    # First look up if an optimized configuration is available in the configs
    # directory
    device_name = get_device_name().replace(" ", "_")
    json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json"

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info(
                "Using configuration from %s for W8A8 Block INT8 kernel.",
                config_file_path,
            )
            # If a configuration has been found, return it
            # JSON object keys are strings; convert them back to int batch sizes.
            return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    logger.warning(
        (
            "Using default W8A8 Block INT8 kernel config. Performance might be sub-optimal! "
            "Config file not found at %s"
        ),
        config_file_path,
    )
    return None
def w8a8_block_int8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """This function performs matrix multiplication with block-wise quantization.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.

    Args:
        A: The input tensor, e.g., activation. Must be contiguous; all leading
            dimensions are flattened into the row dimension.
        B: The input tensor, e.g., weight. Must be a contiguous 2-D tensor of
            shape (N, K).
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of matmul, of shape ``A.shape[:-1] + (N,)``.

    Raises:
        AssertionError: If the shapes of the inputs and scales are inconsistent
            or `A` / `B` are not contiguous.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)

    # The kernel addresses A, As and C as 2-D (M, ...) matrices through a
    # single row stride. Flatten the leading dimensions explicitly instead of
    # trusting stride(-2): a size-1 dimension may carry an arbitrary stride
    # while the tensor still reports is_contiguous(), As is not required to be
    # contiguous at all, and 1-D inputs have no dimension -2. For an As that is
    # already 2-D the reshape leaves its strides unchanged.
    A_2d = A.reshape(M, K)
    As_2d = As.reshape(M, As.shape[-1])
    C_2d = C.view(M, N)

    configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
    if configs:
        # If an optimal configuration map has been found, look up the
        # optimal config
        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
    else:
        # Default config
        # The kernel picks one scale index per K-step (k_start // group_k), so
        # a K tile should not straddle two quantization groups; using the
        # quantization block sizes themselves as the tile sizes satisfies that.
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_size[0],
            "BLOCK_SIZE_K": block_size[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 3,
        }

    def grid(META):
        # One program per output tile.
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    # B is stored as (N, K) and Bs as (N-blocks, K-blocks); passing their
    # strides swapped presents both to the kernel in (K, N) order.
    _w8a8_block_int8_matmul[grid](
        A_2d,
        B,
        C_2d,
        As_2d,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A_2d.stride(0),
        A_2d.stride(1),
        B.stride(1),
        B.stride(0),
        C_2d.stride(0),
        C_2d.stride(1),
        As_2d.stride(0),
        As_2d.stride(1),
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    # C_2d is a view of C, so the kernel's writes are visible through C.
    return C
Xet Storage Details
- Size:
- 13.1 kB
- Xet hash:
- 3f4c19b46866b76928048088464420b7f448acee1ae115886d1c86c0ed86449a
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.