| import pytest |
| import torch |
| from sgl_kernel import fp8_scaled_mm |
|
|
|
|
| def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): |
| o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) |
| o = o.to(torch.float32) |
| temp1 = o * scale_a.view(-1, 1) |
| temp2 = temp1 * scale_b.view(1, -1) |
| final = temp2.to(out_dtype) |
| if bias is not None: |
| final = final + bias.view(1, -1) |
| return final |
|
|
|
|
| def _test_accuracy_once(M, N, K, with_bias, out_dtype, device): |
| fp8_info = torch.finfo(torch.float8_e4m3fn) |
| fp8_max, fp8_min = fp8_info.max, fp8_info.min |
| a_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max |
| a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) |
| b_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max |
| b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) |
| scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001 |
| scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001 |
| if with_bias: |
| bias = torch.randn((N,), device=device, dtype=out_dtype) |
| else: |
| bias = None |
| b_fp8 = b_fp8.t() |
| o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) |
| o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) |
| rtol = 0.02 |
| atol = 1 |
| torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) |
| print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") |
|
|
|
|
| @pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096]) |
| @pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096]) |
| @pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384]) |
| @pytest.mark.parametrize("with_bias", [True, False]) |
| @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) |
| def test_accuracy(M, N, K, with_bias, out_dtype): |
| _test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda") |
|
|
|
|
| if __name__ == "__main__": |
| pytest.main([__file__]) |
|
|