| |
|
|
| import pytest |
| import sgl_kernel |
| import torch |
| from sgl_kernel.utils import is_arch_support_pdl |
|
|
|
|
| def llama_rms_norm(x, w, eps=1e-6): |
| orig_dtype = x.dtype |
| x = x.float() |
| variance = x.pow(2).mean(dim=-1, keepdim=True) |
| x = x * torch.rsqrt(variance + eps) |
| x = x * w.float() |
| x = x.to(orig_dtype) |
| return x |
|
|
|
|
| def gemma_rms_norm(x, w, eps=1e-6): |
| orig_dtype = x.dtype |
| x = x.float() |
| variance = x.pow(2).mean(dim=-1, keepdim=True) |
| x = x * torch.rsqrt(variance + eps) |
| x = x * (1.0 + w.float()) |
| x = x.to(orig_dtype) |
| return x |
|
|
|
|
| def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6): |
| orig_dtype = x.dtype |
| x = x + residual |
| residual = x |
| x = x.float() |
| variance = x.pow(2).mean(dim=-1, keepdim=True) |
| x = x * torch.rsqrt(variance + eps) |
| x = x * (1.0 + w.float()) |
| x = x.to(orig_dtype) |
| return x, residual |
|
|
|
|
| def fused_add_rms_norm(x, residual, weight, eps): |
| orig_dtype = x.dtype |
| x = x.to(torch.float32) |
| x = x + residual.to(torch.float32) |
| residual = x.to(orig_dtype) |
|
|
| variance = x.pow(2).mean(dim=-1, keepdim=True) |
| x = x * torch.rsqrt(variance + eps) |
| x = (x * weight.float()).to(orig_dtype) |
| return x, residual |
|
|
|
|
| @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) |
| @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) |
| @pytest.mark.parametrize("dtype", [torch.float16]) |
| @pytest.mark.parametrize("specify_out", [True, False]) |
| def test_norm(batch_size, hidden_size, dtype, specify_out): |
| x = torch.randn(batch_size, hidden_size).to(0).to(dtype) |
| w = torch.randn(hidden_size).to(0).to(dtype) |
|
|
| y_ref = llama_rms_norm(x, w) |
| enable_pdl = is_arch_support_pdl() |
| if specify_out: |
| y = torch.empty_like(x) |
| sgl_kernel.rmsnorm(x, w, out=y, enable_pdl=enable_pdl) |
| else: |
| y = sgl_kernel.rmsnorm(x, w, enable_pdl=enable_pdl) |
|
|
| torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) |
|
|
|
|
| @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) |
| @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) |
| @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) |
| def test_fused_add_rmsnorm(batch_size, hidden_size, dtype): |
| eps = 1e-6 |
|
|
| x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") |
| residual = torch.randn_like(x) |
| weight = torch.randn(hidden_size, dtype=dtype, device="cuda") |
|
|
| x_native, residual_native = fused_add_rms_norm( |
| x.clone(), residual.clone(), weight, eps |
| ) |
|
|
| x_fused = x.clone() |
| residual_fused = residual.clone() |
| enable_pdl = is_arch_support_pdl() |
| sgl_kernel.fused_add_rmsnorm( |
| x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl |
| ) |
|
|
| torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) |
| torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) |
|
|
|
|
| @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) |
| @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) |
| @pytest.mark.parametrize("dtype", [torch.float16]) |
| @pytest.mark.parametrize("specify_out", [True, False]) |
| def test_gemma_norm(batch_size, hidden_size, dtype, specify_out): |
| x = torch.randn(batch_size, hidden_size).to(0).to(dtype) |
| w = torch.randn(hidden_size).to(0).to(dtype) |
|
|
| y_ref = gemma_rms_norm(x, w) |
| enable_pdl = is_arch_support_pdl() |
| if specify_out: |
| y = torch.empty_like(x) |
| sgl_kernel.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl) |
| else: |
| y = sgl_kernel.gemma_rmsnorm(x, w, enable_pdl=enable_pdl) |
|
|
| torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) |
|
|
|
|
| @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) |
| @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) |
| @pytest.mark.parametrize("dtype", [torch.float16]) |
| def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype): |
| eps = 1e-6 |
|
|
| x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") |
| residual = torch.randn_like(x) |
| weight = torch.randn(hidden_size, dtype=dtype, device="cuda") |
|
|
| x_native, residual_native = gemma_fused_add_rms_norm( |
| x.clone(), residual.clone(), weight, eps |
| ) |
|
|
| x_fused = x.clone() |
| residual_fused = residual.clone() |
| enable_pdl = is_arch_support_pdl() |
| sgl_kernel.gemma_fused_add_rmsnorm( |
| x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl |
| ) |
|
|
| torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) |
| torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) |
|
|
|
|
| if __name__ == "__main__": |
| pytest.main([__file__]) |
|
|