File size: 1,049 Bytes
eefb734 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
class LayerNormalization4DCF(nn.Module):
    """Layer normalization over dims 1 and 3 of a 4-D tensor, with a learnable
    elementwise affine transform.

    For every index pair along dims 0 and 2, ``x`` is standardized using the
    mean and the biased (population) variance computed jointly over dims 1
    and 3. It is then scaled by ``gamma`` and shifted by ``beta``; both have
    shape ``[1, input_dimension[0], 1, input_dimension[1]]``.

    The class name and the ``[B,1,T,1]`` shape comments suggest a
    (batch, channel, time, freq) layout. Under that layout, statistics are
    shared across channels and frequency bins for each batch item and time
    frame. This is an assumption; confirm against callers.

    Args:
        input_dimension: two-element sequence ``(size_dim1, size_dim3)`` that
            shapes the affine parameters.
        eps: added to the variance before the square root for numerical
            stability.

    Raises:
        ValueError: if ``input_dimension`` does not have exactly two entries.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives
        # ``python -O``; ValueError also matches what ``forward`` raises.
        if len(input_dimension) != 2:
            raise ValueError(
                "Expect input_dimension to have 2 entries, but got {}".format(
                    len(input_dimension)
                )
            )
        param_size = [1, input_dimension[0], 1, input_dimension[1]]
        # Scale starts at 1 and shift at 0, so the affine step is initially
        # the identity. Allocate the filled float32 tensors directly instead
        # of creating uninitialized storage and filling it afterwards.
        self.gamma = Parameter(torch.ones(param_size, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(param_size, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over dims (1, 3) and apply the affine transform.

        Args:
            x: 4-D tensor whose dims 1 and 3 must broadcast against
                ``gamma``/``beta``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.ndim != 4:
            raise ValueError(
                "Expect x to have 4 dimensions, but got {}".format(x.ndim)
            )
        stat_dim = (1, 3)
        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B,1,T,1]
        # Biased variance (divide by N), as in standard layer norm.
        std_ = torch.sqrt(
            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )  # [B,1,T,1] -- same reduced shape as mu_
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
        return x_hat
|