| | import random |
| |
|
| | import pytest |
| | import torch |
| |
|
| | import activation |
| |
|
| | from .utils import assert_close, opcheck |
| |
|
# Parameter grids swept by the parametrized tests in this module.
DTYPES = [torch.float, torch.bfloat16, torch.half]
NUM_TOKENS = [7, 83, 256, 2048]  # number of rows in the input
D = [1, 7, 512, 13824]  # size of the normalized (last) dimension
SEEDS = [0]
# Exercise at most two GPUs, and never name a device that does not exist.
# On a host without CUDA devices this list is empty, so pytest (by default)
# reports the parametrized tests as skipped instead of failing on "cuda:0".
CUDA_DEVICES = [f"cuda:{i}" for i in range(min(torch.cuda.device_count(), 2))]
| |
|
| |
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_rms_norm(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare ``activation``'s RMSNorm with ``torch.nn.RMSNorm``.

    For a random ``(num_tokens, d)`` input and a random ``(d,)`` weight the
    test checks that:

    * the raw op passes ``opcheck``;
    * the functional output matches the PyTorch reference;
    * the module output equals the functional output with zero tolerance;
    * the gradients w.r.t. the input and the weight, obtained by
      back-propagating through the module, match the reference gradients.

    Args:
        num_tokens: Number of rows in the input tensor.
        d: Size of the normalized last dimension.
        dtype: Floating-point dtype of the input and the weight.
        seed: Seed for Python's and PyTorch's random number generators.
        device: Device string (e.g. ``"cuda:0"``) the tensors are created on.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    # Scope the default device to this test. ``torch.set_default_device``
    # would change it for the whole process and leak into later tests.
    with torch.device(device):
        # Leaf tensors: autograd populates ``.grad`` on them without any
        # ``retain_grad()`` call.
        x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
        weight = torch.randn(d, dtype=dtype, requires_grad=True)
        eps = 1e-05

        # Independent copies for the reference path, so that each
        # implementation accumulates gradients into its own tensors.
        x_ref = x.detach().clone().requires_grad_(True)
        weight_ref = weight.detach().clone().requires_grad_(True)

        torch_layer = torch.nn.RMSNorm(d, eps=eps, dtype=dtype)
        torch_layer.weight = torch.nn.Parameter(weight_ref)

        op = activation.ops.rms_norm
        fn = activation.rms_norm
        layer = activation.layers.RMSNorm(d, eps=eps, dtype=dtype)
        layer.weight = torch.nn.Parameter(weight)

        opcheck(op, (x, weight, eps))

        out = fn(x, weight, eps)
        mod_out = layer(x)
        ref_out = torch_layer(x_ref)

        assert_close(out, ref_out)
        # Module and functional entry points must agree with zero tolerance.
        assert_close(mod_out, out, atol=0.0, rtol=0.0)

        # Random upstream gradient, scaled to unit norm.
        out_grad = torch.randn_like(out)
        out_grad = out_grad / out_grad.norm()

        ref_out.backward(out_grad)
        mod_out.backward(out_grad)

        assert_close(x.grad, x_ref.grad)
        # The weight gradient is compared with a looser relative tolerance.
        # NOTE(review): presumably because it is reduced over all rows and
        # accumulates more rounding error -- confirm.
        assert_close(layer.weight.grad, torch_layer.weight.grad, rtol=0.05)
| |
|