Hanrui / sglang /test /srt /cpu /test_norm.py
Lekr0's picture
Add files using upload-large-folder tool
a402b9b verified
import itertools
import unittest
from typing import Optional, Tuple, Union
import torch
from utils import make_non_contiguous, parametrize, precision
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestNorm(CustomTestCase):
M = [4096, 1024]
N = [4096, 4096 + 13]
dtype = [torch.float16, torch.bfloat16]
def _forward_native(
self,
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float = 1e-6,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
x = x.to(orig_dtype) * weight
if residual is None:
return x
else:
return x, residual
def _norm(self, x, eps):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
def _gemma3_rmsnorm_native(
self, x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float = 1e-6
):
output = self._norm(x.float(), variance_epsilon)
output = output * (1.0 + weight.float())
return output.type_as(x)
def _gemma_rmsnorm_native(
self,
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float = 1e-6,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype
if residual is not None:
x = x + residual
residual = x
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
x = x * (1.0 + weight.float())
x = x.to(orig_dtype)
return x if residual is None else (x, residual)
def _norm_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype)
x = make_non_contiguous(x)
hidden_size = x.size(-1)
weight = torch.randn(hidden_size, dtype=dtype)
variance_epsilon = 1e-6
out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon)
ref_out = self._forward_native(x, weight, variance_epsilon)
atol = rtol = precision[ref_out.dtype]
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
ref_x = x.clone()
residual = torch.randn([m, hidden_size], dtype=dtype)
ref_residual = residual.clone()
torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
x, residual, weight, variance_epsilon
)
ref_x, ref_residual = self._forward_native(
ref_x, weight, variance_epsilon, ref_residual
)
torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol)
torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol)
def _l2norm_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype)
hidden_size = x.size(-1)
fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
variance_epsilon = 1e-6
out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)
atol = rtol = precision[ref_out.dtype]
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
def _gemma_rmsnorm_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype)
x = make_non_contiguous(x)
hidden_size = x.size(-1)
weight = torch.randn(hidden_size, dtype=dtype)
variance_epsilon = 1e-6
out = torch.ops.sgl_kernel.gemma_rmsnorm_cpu(x, weight, variance_epsilon)
ref_out = self._gemma_rmsnorm_native(x, weight, variance_epsilon)
atol = rtol = precision[ref_out.dtype]
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
ref_x = x.clone()
residual = torch.randn([m, hidden_size], dtype=dtype)
ref_residual = residual.clone()
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm_cpu(
x, residual, weight, variance_epsilon
)
ref_x, ref_residual = self._gemma_rmsnorm_native(
ref_x, weight, variance_epsilon, ref_residual
)
torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol)
torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol)
def _gemma3_rmsnorm_test(self, m, n, dtype):
x_list = [
torch.randn([m, n], dtype=dtype),
torch.randn([1, m, 2, n], dtype=dtype),
]
for x in x_list:
x = make_non_contiguous(x)
hidden_size = x.size(-1)
weight = torch.randn(hidden_size, dtype=dtype)
variance_epsilon = 1e-6
out = torch.ops.sgl_kernel.gemma3_rmsnorm_cpu(x, weight, variance_epsilon)
ref_out = self._gemma3_rmsnorm_native(x, weight, variance_epsilon)
atol = rtol = precision[ref_out.dtype]
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
def test_norm(self):
for params in itertools.product(self.M, self.N, self.dtype):
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
self._norm_test(*params)
self._l2norm_test(*params)
self._gemma_rmsnorm_test(*params)
self._gemma3_rmsnorm_test(*params)
class TestFusedRMSNormGated(CustomTestCase):
M = [4096, 1024]
N = [4096, 4096 + 13]
dtype = [torch.float16, torch.bfloat16]
def _forward_native(
self,
hidden_states: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float = 1e-6,
gate: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
# Norm before gate
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
hidden_states = weight * hidden_states.to(input_dtype)
hidden_states = hidden_states * torch.nn.functional.silu(gate.to(torch.float32))
return hidden_states.to(input_dtype)
def _norm_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype)
x = make_non_contiguous(x)
batch_size = x.size(0)
hidden_size = x.size(-1)
weight = torch.randn(hidden_size, dtype=dtype)
variance_epsilon = 1e-6
gate = torch.randn([batch_size, hidden_size], dtype=dtype)
out = torch.ops.sgl_kernel.fused_rmsnorm_gated_cpu(
x, weight, gate, variance_epsilon
)
ref_out = self._forward_native(x, weight, variance_epsilon, gate)
atol = rtol = precision[ref_out.dtype] * 2
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
def test_norm(self):
for params in itertools.product(self.M, self.N, self.dtype):
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
self._norm_test(*params)
class TestLayerNorm(CustomTestCase):
def _forward_native(
self,
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
x = (x - mean) * torch.rsqrt(variance + variance_epsilon)
x = x.to(orig_dtype) * weight
if residual is None:
return x
else:
return x, residual
@parametrize(
m=[4096, 1024],
n=[4096, 4109],
dtype=[torch.float16, torch.bfloat16],
)
def test_norm(self, m: int, n: int, dtype: torch.dtype) -> None:
x_ln = torch.randn([m, n], dtype=dtype)
x_ln = make_non_contiguous(x_ln)
ref_x_ln = x_ln.clone()
hidden_size = x_ln.size(-1)
weight = torch.randn(hidden_size, dtype=dtype)
variance_epsilon = 1e-6
torch.ops.sgl_kernel.layernorm_cpu(x_ln, weight, variance_epsilon)
ref_ln_out = self._forward_native(ref_x_ln, weight, variance_epsilon)
atol = rtol = precision[ref_ln_out.dtype]
torch.testing.assert_close(x_ln, ref_ln_out, atol=atol, rtol=rtol)
x_add_ln = torch.randn([m, n], dtype=dtype)
x_add_ln = make_non_contiguous(x_add_ln)
ref_x_add_ln = x_add_ln.clone()
residual = torch.randn([m, hidden_size], dtype=dtype)
ref_residual = residual.clone()
torch.ops.sgl_kernel.fused_add_layernorm_cpu(
x_add_ln, residual, weight, variance_epsilon
)
ref_add_ln_out, ref_residual = self._forward_native(
ref_x_add_ln, weight, variance_epsilon, ref_residual
)
torch.testing.assert_close(x_add_ln, ref_add_ln_out, atol=atol, rtol=rtol)
torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol)
if __name__ == "__main__":
unittest.main()