| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn import init | |
| from torch.nn.parameter import Parameter | |
class LayerNormalization4DCF(nn.Module):
    """Layer normalization over dims 1 and 3 of a 4-D tensor.

    The mean and the biased (population) variance are computed jointly over
    dims 1 and 3, separately for every (dim 0, dim 2) position. The
    normalized tensor is then passed through a learnable elementwise affine
    transform whose parameters have shape ``[1, D1, 1, D3]``.

    The shape comments below label the input ``[B, C, T, F]`` (dim 2 = T).
    The class name suggests C = channels and F = frequency — TODO confirm
    against callers; only ``ndim == 4`` is enforced here.

    Args:
        input_dimension: Two-element sequence giving the sizes of dims 1 and
            3 of the input; used to shape ``gamma`` and ``beta``.
        eps: Constant added to the variance before the square root, keeping
            the division finite when a slice has zero variance.

    Raises:
        ValueError: If ``input_dimension`` does not have exactly two entries.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let e.g. a 3-element sequence through and
        # silently ignore its last entry.
        if len(input_dimension) != 2:
            raise ValueError(
                "Expect input_dimension to have 2 entries, but got {}".format(
                    len(input_dimension)
                )
            )
        param_size = [1, input_dimension[0], 1, input_dimension[1]]
        # Affine parameters start as the identity transform (scale 1, shift 0).
        self.gamma = Parameter(torch.ones(param_size, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(param_size, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over dims 1 and 3 and apply the affine transform.

        Args:
            x: 4-D tensor whose dims 1 and 3 must broadcast against
                ``gamma`` / ``beta`` (shape ``[1, D1, 1, D3]``).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        stat_dim = (1, 3)
        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B,1,T,1]
        # Biased variance (divide by N), as is standard for layer norm.
        std_ = torch.sqrt(
            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )  # [B,1,T,1]
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta  # [B,C,T,F]
        return x_hat