| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | from tests.test_utils import assert_expected |
| |
|
| | from torchmultimodal.modules.layers.normalizations import ( |
| | Fp32GroupNorm, |
| | Fp32LayerNorm, |
| | RMSNorm, |
| | SimpleRMSNorm, |
| | ) |
| |
|
| |
|
| | def test_fp32layernorm(): |
| | x = torch.ones(1, 1, dtype=torch.float16) |
| | norm = Fp32LayerNorm(1) |
| | output = norm(x) |
| | assert output.dtype == torch.float16 |
| |
|
| |
|
| | def test_fp32groupnorm(): |
| | x = torch.ones(2, 4, dtype=torch.float16) |
| | norm = Fp32GroupNorm(2, 4) |
| | output = norm(x) |
| | assert output.dtype == torch.float16 |
| |
|
| |
|
| | def test_rms_norm_core_algo(): |
| | """compare RMSNorm with RMSNorm using F.norm version""" |
| | dims = 10 |
| | rms_norm = RMSNorm(dims) |
| |
|
| | input_ones = torch.ones(dims, dtype=torch.float) |
| |
|
| | input_fixed = torch.tensor( |
| | [0.999, 1.1111, 2.222, 3.333, 4.444, 5.555, 6.678, 7.987, 8.123, 9.101010], |
| | dtype=torch.float16, |
| | ) |
| | fixed_expected = torch.tensor( |
| | [ |
| | 0.1749, |
| | 0.1946, |
| | 0.3892, |
| | 0.5835, |
| | 0.7783, |
| | 0.9727, |
| | 1.1699, |
| | 1.3984, |
| | 1.4229, |
| | 1.5938, |
| | ], |
| | dtype=torch.float, |
| | ) |
| |
|
| | output_fixed = rms_norm(input_fixed) |
| | output_ones = rms_norm(input_ones) |
| |
|
| | assert_expected(output_ones, input_ones) |
| | assert_expected(output_fixed, fixed_expected, atol=1e-04, rtol=1e-05) |
| | assert output_fixed.dtype == torch.float32 |
| |
|
| |
|
| | def test_simple_rmsnorm(): |
| | dims = 12 |
| | srms_norm = SimpleRMSNorm(dims) |
| |
|
| | input_bf16_ones = torch.ones(dims, dtype=torch.bfloat16) |
| |
|
| | input_fixed_fp32 = torch.tensor( |
| | [ |
| | 0.999, |
| | 1.1111, |
| | 2.222, |
| | 3.333, |
| | 4.444, |
| | 5.555, |
| | 6.678, |
| | 7.987, |
| | 8.123, |
| | 9.101010, |
| | 110.00, |
| | 120.2589, |
| | ], |
| | dtype=torch.float32, |
| | ) |
| |
|
| | expected_output_bf16_ones = torch.tensor( |
| | [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], |
| | dtype=torch.bfloat16, |
| | ) |
| | expected_output_fixed = torch.tensor( |
| | [ |
| | 0.0211, |
| | 0.0235, |
| | 0.0469, |
| | 0.0704, |
| | 0.0939, |
| | 0.1174, |
| | 0.1411, |
| | 0.1687, |
| | 0.1716, |
| | 0.1923, |
| | 2.3238, |
| | 2.5405, |
| | ], |
| | dtype=torch.float32, |
| | ) |
| |
|
| | actual_output_bf16_ones = srms_norm(input_bf16_ones) |
| | actual_output_fixed = srms_norm(input_fixed_fp32) |
| |
|
| | |
| | assert_expected( |
| | actual_output_bf16_ones, expected_output_bf16_ones, atol=1e-04, rtol=1e-05 |
| | ) |
| | assert actual_output_bf16_ones.dtype == torch.bfloat16 |
| |
|
| | |
| | assert_expected(actual_output_fixed, expected_output_fixed, atol=1e-04, rtol=1e-05) |
| | assert actual_output_fixed.dtype == torch.float32 |
| |
|