# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from tests.test_utils import assert_expected
from torchmultimodal.modules.layers.normalizations import (
Fp32GroupNorm,
Fp32LayerNorm,
RMSNorm,
SimpleRMSNorm,
)
def test_fp32layernorm():
    """Fp32LayerNorm must return float16 output when given float16 input."""
    layer_norm = Fp32LayerNorm(1)
    half_input = torch.ones(1, 1, dtype=torch.float16)
    # Only the output dtype is checked here, not the normalized values.
    result = layer_norm(half_input)
    assert result.dtype == torch.float16
def test_fp32groupnorm():
    """Fp32GroupNorm must return float16 output when given float16 input."""
    # NOTE(review): presumably (num_groups=2, num_channels=4), mirroring
    # torch.nn.GroupNorm's signature; confirm against Fp32GroupNorm.
    group_norm = Fp32GroupNorm(2, 4)
    result = group_norm(torch.ones(2, 4, dtype=torch.float16))
    assert result.dtype == torch.float16
def test_rms_norm_core_algo():
    """Check RMSNorm against hand-computed reference values.

    Two inputs are covered:
    * an all-ones float32 vector, which must come back unchanged;
    * a fixed float16 vector, whose output must match pre-computed reference
      values within tolerance and must be float32.
    """
    dims = 10
    rms_norm = RMSNorm(dims)

    # The RMS of an all-ones vector is 1, so the output should equal the input
    # (presumably relies on the learnable scale initializing to ones).
    input_ones = torch.ones(dims, dtype=torch.float)

    # float16 input: the dtype assertion below expects a float32 result.
    input_fixed = torch.tensor(
        [0.999, 1.1111, 2.222, 3.333, 4.444, 5.555, 6.678, 7.987, 8.123, 9.101010],
        dtype=torch.float16,
    )
    # Reference outputs. NOTE(review): these appear consistent with
    # x / sqrt(mean(x ** 2)) with a unit scale; confirm against RMSNorm's eps.
    fixed_expected = torch.tensor(
        [
            0.1749,
            0.1946,
            0.3892,
            0.5835,
            0.7783,
            0.9727,
            1.1699,
            1.3984,
            1.4229,
            1.5938,
        ],
        dtype=torch.float,
    )

    output_fixed = rms_norm(input_fixed)
    output_ones = rms_norm(input_ones)

    assert_expected(output_ones, input_ones)
    assert_expected(output_fixed, fixed_expected, atol=1e-04, rtol=1e-05)
    assert output_fixed.dtype == torch.float32
def test_simple_rmsnorm():
    """Check SimpleRMSNorm values and dtype preservation for bf16 and fp32 input.

    An all-ones bfloat16 vector must come back as ones, still in bfloat16.
    A fixed float32 vector must match pre-computed reference values, still in
    float32.
    """
    dims = 12
    srms_norm = SimpleRMSNorm(dims)
    tolerances = {"atol": 1e-04, "rtol": 1e-05}

    # Case 1: all-ones bfloat16 input.
    ones_bf16 = torch.ones(dims, dtype=torch.bfloat16)
    want_ones_bf16 = torch.tensor([1.0] * dims, dtype=torch.bfloat16)

    # Case 2: fixed float32 input, including two large outliers at the end.
    fixed_fp32 = torch.tensor(
        [
            0.999, 1.1111, 2.222, 3.333, 4.444, 5.555,
            6.678, 7.987, 8.123, 9.101010, 110.00, 120.2589,
        ],
        dtype=torch.float32,
    )
    # NOTE(review): these references appear consistent with
    # x * sqrt(dims) / ||x||_2; confirm against SimpleRMSNorm's definition.
    want_fixed_fp32 = torch.tensor(
        [
            0.0211, 0.0235, 0.0469, 0.0704, 0.0939, 0.1174,
            0.1411, 0.1687, 0.1716, 0.1923, 2.3238, 2.5405,
        ],
        dtype=torch.float32,
    )

    got_ones_bf16 = srms_norm(ones_bf16)
    got_fixed_fp32 = srms_norm(fixed_fp32)

    # Ones case: values and preserved dtype.
    assert_expected(got_ones_bf16, want_ones_bf16, **tolerances)
    assert got_ones_bf16.dtype == torch.bfloat16

    # Fixed case: values and preserved dtype.
    assert_expected(got_fixed_fp32, want_fixed_fp32, **tolerances)
    assert got_fixed_fp32.dtype == torch.float32
|