| import itertools | |
| import unittest | |
| import torch | |
| from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm | |
| from sglang.test.test_utils import CustomTestCase | |
class TestRMSNorm(CustomTestCase):
    """Compare ``RMSNorm``'s default call path with ``forward_native`` on CUDA.

    The class-level lists below define a parameter grid; ``test_rms_norm``
    runs every combination as its own sub-test.
    """

    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 4096]
    HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
    ADD_RESIDUAL = [False, True]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        """Skip the class when CUDA is unavailable, else default tensors to CUDA.

        unittest invokes this hook on the class itself with no arguments, so
        it has to be a classmethod.

        Raises:
            unittest.SkipTest: If ``torch.cuda.is_available()`` is false.
        """
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        # NOTE: nothing in this class restores the previous default device.
        torch.set_default_device("cuda")

    def _run_rms_norm_test(self, num_tokens, hidden_size, add_residual, dtype, seed):
        """Check one configuration of ``RMSNorm`` against ``forward_native``.

        Args:
            num_tokens: Number of rows in the random input.
            hidden_size: Number of columns in the input; also the size passed
                to the ``RMSNorm`` constructor.
            add_residual: Whether a random residual tensor is passed along
                with the input (``None`` is passed otherwise).
            dtype: Torch dtype used for both the layer and the tensors.
            seed: Value given to ``torch.manual_seed`` before any sampling.
        """
        torch.manual_seed(seed)

        layer = RMSNorm(hidden_size).to(dtype=dtype)
        layer.weight.data.normal_(mean=1.0, std=0.1)

        # Shrink the random inputs by a factor of 1 / (2 * hidden_size).
        scale = 1 / (2 * hidden_size)
        x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
        residual = torch.randn_like(x) * scale if add_residual else None

        with torch.inference_mode():
            # The reference is computed before the default path on purpose.
            # NOTE(review): the default path may update its inputs in place;
            # confirm before reordering these two calls.
            ref_out = layer.forward_native(x, residual)
            out = layer(x, residual)

        if add_residual:
            # With a residual, both calls return an indexable pair.
            # NOTE(review): presumably (normalized output, updated residual);
            # verify against the layer implementation.
            self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2))
            self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2))
        else:
            self.assertTrue(torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2))

    def test_rms_norm(self):
        """Run the comparison over the full parameter grid, one sub-test each."""
        for num_tokens, hidden_size, add_residual, dtype, seed in itertools.product(
            self.NUM_TOKENS,
            self.HIDDEN_SIZES,
            self.ADD_RESIDUAL,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=num_tokens,
                hidden_size=hidden_size,
                add_residual=add_residual,
                dtype=dtype,
                seed=seed,
            ):
                self._run_rms_norm_test(
                    num_tokens, hidden_size, add_residual, dtype, seed
                )
class TestGemmaRMSNorm(CustomTestCase):
    """Compare ``GemmaRMSNorm``'s default call path with ``forward_native`` on CUDA.

    The class-level lists below define a parameter grid;
    ``test_gemma_rms_norm`` runs every combination as its own sub-test.
    """

    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 4096]
    HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
    ADD_RESIDUAL = [False, True]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        """Skip the class when CUDA is unavailable, else default tensors to CUDA.

        unittest invokes this hook on the class itself with no arguments, so
        it has to be a classmethod.

        Raises:
            unittest.SkipTest: If ``torch.cuda.is_available()`` is false.
        """
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        # NOTE: nothing in this class restores the previous default device.
        torch.set_default_device("cuda")

    def _run_gemma_rms_norm_test(
        self, num_tokens, hidden_size, add_residual, dtype, seed
    ):
        """Check one configuration of ``GemmaRMSNorm`` against ``forward_native``.

        Args:
            num_tokens: Number of rows in the random input.
            hidden_size: Number of columns in the input; also the size passed
                to the ``GemmaRMSNorm`` constructor.
            add_residual: Whether a random residual tensor is passed along
                with the input (``None`` is passed otherwise).
            dtype: Torch dtype used for both the layer and the tensors.
            seed: Value given to ``torch.manual_seed`` before any sampling.
        """
        torch.manual_seed(seed)

        layer = GemmaRMSNorm(hidden_size).to(dtype=dtype)
        layer.weight.data.normal_(mean=1.0, std=0.1)

        # Shrink the random inputs by a factor of 1 / (2 * hidden_size).
        scale = 1 / (2 * hidden_size)
        x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
        residual = torch.randn_like(x) * scale if add_residual else None

        with torch.inference_mode():
            # The reference is computed before the default path on purpose.
            # NOTE(review): the default path may update its inputs in place;
            # confirm before reordering these two calls.
            ref_out = layer.forward_native(x, residual)
            out = layer(x, residual)

        if add_residual:
            # With a residual, both calls return an indexable pair.
            # NOTE(review): presumably (normalized output, updated residual);
            # verify against the layer implementation.
            self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-3, rtol=1e-3))
            self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-3, rtol=1e-3))
        else:
            self.assertTrue(torch.allclose(out, ref_out, atol=1e-3, rtol=1e-3))

    def test_gemma_rms_norm(self):
        """Run the comparison over the full parameter grid, one sub-test each."""
        for num_tokens, hidden_size, add_residual, dtype, seed in itertools.product(
            self.NUM_TOKENS,
            self.HIDDEN_SIZES,
            self.ADD_RESIDUAL,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=num_tokens,
                hidden_size=hidden_size,
                add_residual=add_residual,
                dtype=dtype,
                seed=seed,
            ):
                self._run_gemma_rms_norm_test(
                    num_tokens, hidden_size, add_residual, dtype, seed
                )
# Script entry point: when this file is executed directly, hand control to
# the unittest runner with verbosity level 2.
if __name__ == "__main__":
    unittest.main(verbosity=2)
Xet Storage Details
- Size:
- 3.76 kB
- Xet hash:
- 4f17d629e18babf35dd8ae20c30bfd8638968de38ead01f8a832cb4b3598fd04
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.