Spaces:
Sleeping
Sleeping
File size: 2,951 Bytes
eb9c81a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 | import torch
import torch.nn as nn
import torch.nn.functional as F
class LayerNorm(nn.Module):
    """Applies layer normalization to the input tensor.

    Thin wrapper around ``torch.nn.LayerNorm`` whose normalized shape can be
    derived from the full expected input shape.

    Arguments
    ---------
    input_size : int or tuple
        The expected size of the trailing dimension(s) to be normalized.
        Ignored when ``input_shape`` is given.
    input_shape : tuple
        The expected shape of the input. Every dimension from index 2 onwards
        (``input_shape[2:]``) is normalized, e.g. the last dimension of a 3d
        input or the last two dimensions of a 4d input.
    eps : float
        Value added to the variance estimate to improve numerical stability.
    elementwise_affine : bool
        If True, this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Raises
    ------
    TypeError
        If neither ``input_size`` nor ``input_shape`` is provided.

    Example
    -------
    >>> x = torch.randn(100, 101, 128)
    >>> norm = LayerNorm(input_shape=x.shape)
    >>> output = norm(x)
    >>> output.shape
    torch.Size([100, 101, 128])
    """

    def __init__(
        self,
        input_size=None,
        input_shape=None,
        eps=1e-05,
        elementwise_affine=True,
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if input_shape is not None:
            # input_shape takes precedence: normalize everything after the
            # first two dimensions.
            input_size = input_shape[2:]
        elif input_size is None:
            # Fail early with a clear message instead of the opaque
            # "'NoneType' object is not iterable" raised by torch.nn.LayerNorm.
            raise TypeError(
                "LayerNorm requires either input_size or input_shape"
            )
        self.norm = torch.nn.LayerNorm(
            input_size,
            eps=self.eps,
            elementwise_affine=self.elementwise_affine,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channels)
            Input to normalize. 3d or 4d tensors are expected; the trailing
            dimensions must match the normalized shape given at construction.

        Returns
        -------
        The normalized outputs, with the same shape as ``x``.
        """
        return self.norm(x)
class LayerNormCN(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.

    ``data_format`` gives the ordering of the dimensions in the inputs.
    ``channels_last`` corresponds to inputs with shape
    ``(batch_size, ..., channels)``, e.g. ``(batch_size, height, width, channels)``,
    while ``channels_first`` corresponds to inputs with shape
    ``(batch_size, channels, ...)``, e.g. ``(batch_size, channels, time)`` or
    ``(batch_size, channels, height, width)``.

    Arguments
    ---------
    normalized_shape : int
        Number of channels normalized over. Also the length of the learnable
        per-channel ``weight`` (initialized to ones) and ``bias`` (zeros).
    eps : float
        Value added to the variance for numerical stability.
    data_format : str
        Either ``"channels_last"`` or ``"channels_first"``.

    Raises
    ------
    NotImplementedError
        If ``data_format`` is not one of the supported values.
    """  # noqa: E501

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(
                f"Unsupported data_format: {data_format!r}"
            )
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalize ``x`` over its channel dimension and apply the affine map.

        Arguments
        ---------
        x : torch.Tensor
            Channels on the last dim (``channels_last``) or on dim 1
            (``channels_first``).

        Returns
        -------
        torch.Tensor with the same shape as ``x``.
        """
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            # Manual layer norm over dim 1 (biased variance, as F.layer_norm).
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            # Reshape the (C,) affine params to (C, 1, ..., 1) so they line up
            # with dim 1 for any input rank; a fixed [:, None] only broadcasts
            # correctly for 3d (batch, channels, time) inputs.
            affine_shape = (-1,) + (1,) * (x.dim() - 2)
            return (
                self.weight.reshape(affine_shape) * x
                + self.bias.reshape(affine_shape)
            )
        # Only reachable if data_format was changed after construction.
        raise NotImplementedError(
            f"Unsupported data_format: {self.data_format!r}"
        )