import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter


class LayerNormalization4DCF(nn.Module):
    """Layer normalization of a 4-D tensor over dimensions 1 and 3.

    Statistics are computed jointly over dims 1 and 3 for every
    (dim 0, dim 2) position, after which a learnable scale ``gamma`` and
    shift ``beta`` are applied.  Both parameters have shape
    ``[1, input_dimension[0], 1, input_dimension[1]]`` and are broadcast
    against the input, so dims 1 and 3 of the input must be
    broadcast-compatible with ``input_dimension``.

    NOTE(review): the class name and the shape comments suggest a
    ``[B, C, T, F]`` layout (batch, channel, time, frequency) -- confirm
    against callers.

    Args:
        input_dimension: Length-2 sequence giving the sizes of dims 1 and 3.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        assert len(input_dimension) == 2
        dim1_size, dim3_size = input_dimension[0], input_dimension[1]
        affine_shape = [1, dim1_size, 1, dim3_size]

        # Allocate both affine parameters first (gamma, then beta, so the
        # registration order is preserved), then fill them in place:
        # scale starts at one, shift at zero.
        self.gamma = Parameter(torch.Tensor(*affine_shape).to(torch.float32))
        self.beta = Parameter(torch.Tensor(*affine_shape).to(torch.float32))
        init.ones_(self.gamma)
        init.zeros_(self.beta)

        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over dims (1, 3) and apply the affine transform.

        Args:
            x: Tensor with exactly four dimensions.

        Returns:
            Tensor of the same shape as ``x`` (given broadcast-compatible
            parameters).

        Raises:
            ValueError: If ``x`` does not have four dimensions.
        """
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))

        reduce_dims = (1, 3)
        # keepdim=True leaves singleton dims 1 and 3, i.e. [B,1,T,1].
        mean = x.mean(dim=reduce_dims, keepdim=True)
        variance = x.var(dim=reduce_dims, unbiased=False, keepdim=True)
        # eps goes inside the square root; result is also [B,1,T,1].
        denom = torch.sqrt(variance + self.eps)

        normalized = (x - mean) / denom
        return normalized * self.gamma + self.beta