| from typing import List, Optional, Tuple | |
| import torch | |
| from sglang.srt.layers.quantization.int8_kernel import ( | |
| per_token_group_quant_int8, | |
| w8a8_block_int8_matmul, | |
| ) | |
def apply_w8a8_block_int8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply a linear layer whose weight is block-quantized to int8.

    Activations are quantized on the fly with ``per_token_group_quant_int8``
    (group size ``block_size[1]``) and then multiplied against the int8
    ``weight`` by ``w8a8_block_int8_matmul``.

    Args:
        input: Activation tensor. All leading dimensions are flattened into a
            single row dimension; the last dimension is the reduction dim.
        weight: Quantized weight; ``weight.shape[0]`` is the number of output
            features.
        block_size: Quantization block size ``[block_n, block_k]``.
        weight_scale: Block-wise scales for ``weight``, forwarded to the
            matmul kernel.
        input_scale: Must be ``None``; activation scales are always computed
            dynamically here.
        bias: Optional bias added to the matmul result.

    Returns:
        Tensor of dtype ``input.dtype`` with shape
        ``(*input.shape[:-1], weight.shape[0])``.

    Raises:
        ValueError: if ``input_scale`` is not ``None``.
    """
    if input_scale is not None:
        raise ValueError(
            "apply_w8a8_block_int8_linear quantizes activations dynamically; "
            "a static input_scale is not supported"
        )
    # Flatten the leading dimensions so the int8 kernels see a 2D matrix.
    # reshape (unlike view) also accepts non-contiguous inputs, copying only
    # when a view is impossible.
    input_2d = input.reshape(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    # Dynamic per-token, per-group activation quantization along the last dim.
    q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1])
    output = w8a8_block_int8_matmul(
        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
    )

    if bias is not None:
        output = output + bias
    # A bias of a different dtype may promote the sum; cast back to input dtype.
    return output.to(dtype=input.dtype).reshape(output_shape)
def input_to_int8(
    x: torch.Tensor, dtype: torch.dtype = torch.int8
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to an integer dtype using one tensor-wise scale.

    Symmetric quantization: the largest absolute value in ``x`` is mapped to
    ``torch.iinfo(dtype).max``. Every scaled element is rounded to the nearest
    integer and saturated to the dtype's range.

    Args:
        x: Tensor to quantize. Must be non-empty (``aminmax`` raises on an
            empty tensor).
        dtype: Target integer dtype (default ``torch.int8``).

    Returns:
        ``(x_q, inv_scale)``. ``x_q`` is a contiguous tensor of ``dtype``.
        ``inv_scale`` is a float32 scalar tensor such that
        ``x_q * inv_scale`` approximates ``x``.
    """
    iinfo = torch.iinfo(dtype)
    min_val, max_val = x.aminmax()
    # Floor the magnitude so an all-zero tensor does not divide by zero.
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    int8_min, int8_max = iinfo.min, iinfo.max
    scale = int8_max / amax
    # Round to nearest before the integer cast. A bare .to(dtype) truncates
    # toward zero, biasing values toward 0 and raising the worst-case error
    # from half a quantization step to almost a full step.
    x_scl_sat = (x * scale).round().clamp(min=int8_min, max=int8_max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
def block_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> torch.Tensor:
    """Dequantize a block-wise quantized 2D tensor.

    Element ``x_q_block[r, c]`` is multiplied by the scale of the block that
    contains it, ``x_s[r // block_n, c // block_k]``. Edge blocks may be
    partial when ``n`` or ``k`` is not a multiple of the block size.

    Args:
        x_q_block: 2D quantized tensor of shape ``(n, k)``.
        x_s: Per-block scales of shape
            ``(ceil(n / block_n), ceil(k / block_k))``.
        block_size: ``[block_n, block_k]``.

    Returns:
        A new float32 tensor of shape ``(n, k)``. Neither input is modified.

    Raises:
        ValueError: if ``x_s`` does not hold exactly one scale per block.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    if tuple(x_s.shape) != (n_tiles, k_tiles):
        raise ValueError(
            f"x_s has shape {tuple(x_s.shape)}, expected {(n_tiles, k_tiles)} "
            f"for a {(n, k)} tensor with block size {(block_n, block_k)}"
        )

    # Expand every per-block scale over the rows/columns it covers, then trim
    # the padding introduced by partial edge blocks. Casting the scales to
    # float32 on the input's device keeps the product float32 and avoids a
    # cross-device elementwise op.
    scale = (
        x_s.to(device=x_q_block.device, dtype=torch.float32)
        .repeat_interleave(block_n, dim=0)[:n]
        .repeat_interleave(block_k, dim=1)[:, :k]
    )
    # Out-of-place multiply: .to(torch.float32) is a no-op for a float32
    # input, so an in-place update would overwrite the caller's tensor.
    return x_q_block.to(torch.float32) * scale
Xet Storage Details
- Size: 2.36 kB
- Xet hash: ce5272764f40f598ea59131561430df27a9e1640591427f716d294052ff3f59c
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.