| import itertools |
| from typing import Optional, Tuple |
|
|
| import pytest |
| import torch |
| from sgl_kernel import awq_dequantize |
|
|
|
|
def reverse_awq_order(t: torch.Tensor):
    """Permute the columns of ``t`` from AWQ packing order into logical order.

    The last dimension is split into consecutive groups of eight columns
    (one int32 holds 32 // 4 = 8 four-bit values). Within every group, the
    columns are gathered in the order ``[0, 4, 1, 5, 2, 6, 3, 7]``. The
    gathered values are then masked to their low 4 bits.

    Raises whatever ``Tensor.view`` raises if ``t.shape[-1]`` is not a
    multiple of 8.
    """
    values_per_word = 32 // 4
    interleave = [0, 4, 1, 5, 2, 6, 3, 7]

    # Build one flat column index: [0..N) reshaped into rows of eight,
    # each row permuted by `interleave`, then flattened back.
    column_index = (
        torch.arange(t.shape[-1], dtype=torch.int32, device=t.device)
        .view(-1, values_per_word)[:, interleave]
        .view(-1)
    )
    return t[:, column_index] & 0xF
|
|
|
|
| |
| |
| |
def awq_dequantize_torch(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
) -> torch.Tensor:
    """Reference (pure PyTorch) AWQ 4-bit dequantization.

    Each int32 in ``qweight`` and ``qzeros`` is split into eight 4-bit codes.
    The codes are put back into logical column order with
    ``reverse_awq_order``. The result is ``(weight - zero) * scale``, where
    every row of ``qzeros`` (after unpacking) and of ``scales`` is repeated
    ``group_size`` times along dim 0.

    Args:
        qweight: 2-D packed int32 weights.
        qzeros: 2-D packed int32 zero points.
        scales: 2-D scales. The output dtype follows PyTorch type promotion
            of an int8 tensor with ``scales``.
        group_size: Number of rows sharing one scale/zero row. ``-1`` means
            a single group covering all ``qweight.shape[0]`` rows.

    Returns:
        Tensor with ``qweight.shape[0]`` rows and ``8 * qweight.shape[1]``
        columns.
    """
    rows_per_group = qweight.shape[0] if group_size == -1 else group_size

    bits = 4
    low_nibble = (1 << bits) - 1
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    def _unpack(packed: torch.Tensor) -> torch.Tensor:
        # (rows, cols) int32 -> (rows, cols, 8) shifted copies -> (rows, cols * 8).
        # The int8 cast keeps the low byte; only the low nibble is used afterwards.
        shifted = torch.bitwise_right_shift(packed[:, :, None], shifts[None, None, :])
        flat = shifted.to(torch.int8).view(packed.shape[0], -1)
        return reverse_awq_order(flat) & low_nibble

    weight_codes = _unpack(qweight)
    zero_codes = _unpack(qzeros).repeat_interleave(rows_per_group, dim=0)
    expanded_scales = scales.repeat_interleave(rows_per_group, dim=0)
    return (weight_codes - zero_codes) * expanded_scales
|
|
|
|
def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Invoke the ``sgl_kernel.awq_dequantize`` kernel under test.

    This gives the kernel a signature parallel to ``awq_dequantize_torch``.
    It has no ``group_size`` argument, because the kernel call here receives
    only the three tensors.
    """
    dequantized = awq_dequantize(qweight, scales, qzeros)
    return dequantized
|
|
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
@pytest.mark.parametrize(
    "qweight_row,qweight_col,is_bf16_act",
    list(
        itertools.product(
            [3584, 18944, 128, 256, 512, 1024, 1536],
            [448, 576, 4736, 16, 32, 64, 128, 72],
            [True, False],
        )
    ),
)
def test_awq_dequant_compare_implementations(
    qweight_row: int, qweight_col: int, is_bf16_act: bool
):
    """Check the sgl_kernel AWQ dequant kernel against the PyTorch reference.

    Random packed weights and zeros are dequantized by both
    ``awq_dequantize_torch`` and ``sglang_awq_dequantize``. The outputs are
    compared in float32. A single quantization group spans all rows
    (``group_size == qweight_row``), so ``scales`` and ``qzeros`` have one row.
    """
    device = torch.device("cuda")
    int32_info = torch.iinfo(torch.int32)

    # Sample the full signed int32 range (`high` is exclusive). With a lower
    # bound of 0, bit 31 is never set. The top nibble of every packed word
    # would then stay in [0, 7], and codes 8..15 in that slot would go untested.
    qweight = torch.randint(
        int32_info.min,
        int32_info.max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8  # eight 4-bit values per packed int32

    act_dtype = torch.bfloat16 if is_bf16_act else torch.float16
    scales = torch.rand(scales_row, scales_col, dtype=act_dtype, device=device)

    qzeros = torch.randint(
        int32_info.min,
        int32_info.max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

    torch.testing.assert_close(
        torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    )
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly fails
    # (non-zero exit) when any test fails, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))
|
|