import itertools import unittest from typing import Optional, Tuple, Union import torch from utils import make_non_contiguous, parametrize, precision from sglang.test.test_utils import CustomTestCase torch.manual_seed(1234) class TestNorm(CustomTestCase): M = [4096, 1024] N = [4096, 4096 + 13] dtype = [torch.float16, torch.bfloat16] def _forward_native( self, x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float = 1e-6, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: orig_dtype = x.dtype x = x.to(torch.float32) if residual is not None: x = x + residual.to(torch.float32) residual = x.to(orig_dtype) variance = x.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(variance + variance_epsilon) x = x.to(orig_dtype) * weight if residual is None: return x else: return x, residual def _norm(self, x, eps): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) def _gemma3_rmsnorm_native( self, x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float = 1e-6 ): output = self._norm(x.float(), variance_epsilon) output = output * (1.0 + weight.float()) return output.type_as(x) def _gemma_rmsnorm_native( self, x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float = 1e-6, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: orig_dtype = x.dtype if residual is not None: x = x + residual residual = x x = x.float() variance = x.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(variance + variance_epsilon) x = x * (1.0 + weight.float()) x = x.to(orig_dtype) return x if residual is None else (x, residual) def _norm_test(self, m, n, dtype): x = torch.randn([m, n], dtype=dtype) x = make_non_contiguous(x) hidden_size = x.size(-1) weight = torch.randn(hidden_size, dtype=dtype) variance_epsilon = 1e-6 out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon) ref_out = self._forward_native(x, weight, variance_epsilon) atol = rtol = precision[ref_out.dtype] torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) ref_x = x.clone() residual = torch.randn([m, hidden_size], dtype=dtype) ref_residual = residual.clone() torch.ops.sgl_kernel.fused_add_rmsnorm_cpu( x, residual, weight, variance_epsilon ) ref_x, ref_residual = self._forward_native( ref_x, weight, variance_epsilon, ref_residual ) torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol) torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) def _l2norm_test(self, m, n, dtype): x = torch.randn([m, n], dtype=dtype) hidden_size = x.size(-1) fake_ones_weight = torch.ones(hidden_size, dtype=dtype) variance_epsilon = 1e-6 out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon) ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon) atol = rtol = precision[ref_out.dtype] torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) def _gemma_rmsnorm_test(self, m, n, dtype): x = torch.randn([m, n], dtype=dtype) x = make_non_contiguous(x) hidden_size = x.size(-1) weight = torch.randn(hidden_size, dtype=dtype) variance_epsilon = 1e-6 out = torch.ops.sgl_kernel.gemma_rmsnorm_cpu(x, weight, variance_epsilon) ref_out = self._gemma_rmsnorm_native(x, weight, variance_epsilon) atol = rtol = precision[ref_out.dtype] torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) ref_x = x.clone() residual = torch.randn([m, hidden_size], dtype=dtype) ref_residual = residual.clone() torch.ops.sgl_kernel.gemma_fused_add_rmsnorm_cpu( x, residual, weight, variance_epsilon ) ref_x, ref_residual = self._gemma_rmsnorm_native( ref_x, weight, variance_epsilon, ref_residual ) torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol) torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) def _gemma3_rmsnorm_test(self, m, n, dtype): x_list = [ torch.randn([m, n], dtype=dtype), torch.randn([1, m, 2, n], dtype=dtype), ] for x in x_list: x = make_non_contiguous(x) hidden_size = x.size(-1) weight = torch.randn(hidden_size, dtype=dtype) variance_epsilon = 1e-6 out = torch.ops.sgl_kernel.gemma3_rmsnorm_cpu(x, weight, variance_epsilon) ref_out = self._gemma3_rmsnorm_native(x, weight, variance_epsilon) atol = rtol = precision[ref_out.dtype] torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) def test_norm(self): for params in itertools.product(self.M, self.N, self.dtype): with self.subTest(m=params[0], n=params[1], dtype=params[2]): self._norm_test(*params) self._l2norm_test(*params) self._gemma_rmsnorm_test(*params) self._gemma3_rmsnorm_test(*params) class TestFusedRMSNormGated(CustomTestCase): M = [4096, 1024] N = [4096, 4096 + 13] dtype = [torch.float16, torch.bfloat16] def _forward_native( self, hidden_states: torch.Tensor, weight: torch.Tensor, variance_epsilon: float = 1e-6, gate: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) # Norm before gate hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) hidden_states = weight * hidden_states.to(input_dtype) hidden_states = hidden_states * torch.nn.functional.silu(gate.to(torch.float32)) return hidden_states.to(input_dtype) def _norm_test(self, m, n, dtype): x = torch.randn([m, n], dtype=dtype) x = make_non_contiguous(x) batch_size = x.size(0) hidden_size = x.size(-1) weight = torch.randn(hidden_size, dtype=dtype) variance_epsilon = 1e-6 gate = torch.randn([batch_size, hidden_size], dtype=dtype) out = torch.ops.sgl_kernel.fused_rmsnorm_gated_cpu( x, weight, gate, variance_epsilon ) ref_out = self._forward_native(x, weight, variance_epsilon, gate) atol = rtol = precision[ref_out.dtype] * 2 torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol) def test_norm(self): for params in itertools.product(self.M, self.N, self.dtype): with self.subTest(m=params[0], n=params[1], dtype=params[2]): self._norm_test(*params) class TestLayerNorm(CustomTestCase): def _forward_native( self, x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: orig_dtype = x.dtype x = x.to(torch.float32) if residual is not None: x = x + residual.to(torch.float32) residual = x.to(orig_dtype) variance, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0) x = (x - mean) * torch.rsqrt(variance + variance_epsilon) x = x.to(orig_dtype) * weight if residual is None: return x else: return x, residual @parametrize( m=[4096, 1024], n=[4096, 4109], dtype=[torch.float16, torch.bfloat16], ) def test_norm(self, m: int, n: int, dtype: torch.dtype) -> None: x_ln = torch.randn([m, n], dtype=dtype) x_ln = make_non_contiguous(x_ln) ref_x_ln = x_ln.clone() hidden_size = x_ln.size(-1) weight = torch.randn(hidden_size, dtype=dtype) variance_epsilon = 1e-6 torch.ops.sgl_kernel.layernorm_cpu(x_ln, weight, variance_epsilon) ref_ln_out = self._forward_native(ref_x_ln, weight, variance_epsilon) atol = rtol = precision[ref_ln_out.dtype] torch.testing.assert_close(x_ln, ref_ln_out, atol=atol, rtol=rtol) x_add_ln = torch.randn([m, n], dtype=dtype) x_add_ln = make_non_contiguous(x_add_ln) ref_x_add_ln = x_add_ln.clone() residual = torch.randn([m, hidden_size], dtype=dtype) ref_residual = residual.clone() torch.ops.sgl_kernel.fused_add_layernorm_cpu( x_add_ln, residual, weight, variance_epsilon ) ref_add_ln_out, ref_residual = self._forward_native( ref_x_add_ln, weight, variance_epsilon, ref_residual ) torch.testing.assert_close(x_add_ln, ref_add_ln_out, atol=atol, rtol=rtol) torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol) if __name__ == "__main__": unittest.main()