import unittest

import torch

from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_fp8,
    w8a8_block_fp8_matmul,
)
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.test_utils import CustomTestCase
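
# These tests build exactly-representable FP8 ground truths: random values are
# scaled so every quantization group's abs-max lands on the FP8 format's fmax,
# cast to FP8 and back, then multiplied by a random per-group scale. The
# kernels under test should therefore recover the quantized tensors and scales
# almost exactly.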

register_cuda_ci(est_time=132, suite="stage-b-test-large-1-gpu")


class TestFP8Base(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.M = 256
        # N is deliberately not a multiple of group_size to cover the
        # non-aligned path
        cls.N = 1024 + 64
        cls.K = 512
        cls.group_size = 128
        cls.quant_type = torch.float8_e4m3fn
        cls.output_type = torch.bfloat16
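
    # _make_A returns (A, quant_A, scale): a float32 activation of shape
    # (M, K) together with its exact per-token-group FP8 quantization. Each
    # length-group_size group is built from representable FP8 values times a
    # known per-group scale, so the quantization is lossless by construction.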
    @staticmethod
    def _make_A(M, K, group_size, out_dtype):
        quant_A = torch.rand(
            M, K // group_size, group_size, dtype=torch.float32, device="cuda"
        )
        # map [0, 1) to [-1, 1)
        quant_A = quant_A * 2 - 1
        # scale each group so its abs-max lands exactly on fmax
        finfo = torch.finfo(out_dtype)
        fmax = finfo.max
        scaling = fmax / quant_A.abs().amax(-1, keepdim=True)
        quant_A *= scaling
        quant_A = quant_A.to(out_dtype).to(torch.float32)
        # draw a random per-group scale and reconstruct the dequantized A
        scale = torch.rand(M, K // group_size, dtype=torch.float32, device="cuda")
        scale /= fmax
        A = quant_A * scale[..., None]
        A = A.reshape(M, K)
        quant_A = quant_A.reshape(M, K).to(out_dtype)
        return A, quant_A, scale
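
    # _make_B is the weight-side analogue: it returns a float32 (K, N) weight
    # and its exact FP8 quantization in group_size x group_size blocks. K and
    # N are first padded up to multiples of group_size, then cropped, so
    # non-aligned shapes are exercised as well.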
    @staticmethod
    def _make_B(K, N, group_size, out_dtype):
        def _aligned_size(a, b):
            return (a + b - 1) // b * b

        K_aligned = _aligned_size(K, group_size)
        N_aligned = _aligned_size(N, group_size)
        quant_B = torch.rand(
            K_aligned // group_size,
            group_size,
            N_aligned // group_size,
            group_size,
            dtype=torch.float32,
            device="cuda",
        )
        # map [0, 1) to [-1, 1)
        quant_B = quant_B * 2 - 1
        # scale each block so its abs-max lands exactly on fmax
        finfo = torch.finfo(out_dtype)
        fmax = finfo.max
        scaling = fmax / quant_B.abs().amax((1, 3), keepdim=True)
        quant_B *= scaling
        quant_B = quant_B.to(out_dtype).to(torch.float32)
        # draw a random per-block scale and reconstruct the dequantized B
        scale = torch.rand(
            K_aligned // group_size,
            1,
            N_aligned // group_size,
            1,
            dtype=torch.float32,
            device="cuda",
        )
        scale /= fmax
        B = quant_B * scale
        B = B.reshape(K_aligned, N_aligned)[:K, :N]
        quant_B = quant_B.reshape(K_aligned, N_aligned).to(out_dtype)[:K, :N]
        scale = scale.reshape(K_aligned // group_size, N_aligned // group_size)
        return B, quant_B, scale


class TestPerTokenGroupQuantFP8(TestFP8Base):
    def test_per_token_group_quant_fp8(self):
        if torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("FP8 kernels in this test require SM90 or newer")
        A, A_quant_gt, scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        A_quant, scale = per_token_group_quant_fp8(x=A, group_size=self.group_size)
        torch.testing.assert_close(scale, scale_gt)
        # compare in float16, where every e4m3 value is exactly representable;
        # allow a negligible fraction of mismatches (e.g. rounding ties)
        diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
        diff_count = (diff > 1e-5).count_nonzero()
        self.assertLess((diff_count / diff.numel()).item(), 1e-4)


class TestW8A8BlockFP8Matmul(TestFP8Base):
    def test_w8a8_block_fp8_matmul(self):
        if torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("FP8 kernels in this test require SM90 or newer")
        A, A_quant_gt, A_scale_gt = self._make_A(
            M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
        )
        B, B_quant_gt, B_scale_gt = self._make_B(
            K=self.K, N=self.N, group_size=self.group_size, out_dtype=self.quant_type
        )
        # ground truth in the output dtype; the FP8 kernel accumulates in a
        # different order and precision, hence the loose atol below
        C_gt = A.to(self.output_type) @ B.to(self.output_type)
        C = w8a8_block_fp8_matmul(
            A=A_quant_gt,
            B=B_quant_gt.T.contiguous(),
            As=A_scale_gt,
            Bs=B_scale_gt.T.contiguous(),
            block_size=[128, 128],
            output_dtype=self.output_type,
        )
        torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4)


if __name__ == "__main__":
    unittest.main()