Spaces:
Running
Running
| #!/usr/bin/env python | |
| import torch as th | |
| import torch.nn as nn | |
class ChannelwiseLayerNorm(nn.LayerNorm):
    """
    Channel-wise layer normalization based on nn.LayerNorm.

    Normalizes over the channel dimension for each frame by transposing
    to [N, T, C], applying nn.LayerNorm over the last dimension, and
    transposing back.

    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape
    """

    def __init__(self, *args, **kwargs):
        super(ChannelwiseLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x (Tensor): input of shape [N, C, T]
        Returns:
            Tensor: normalized tensor of shape [N, C, T]
        Raises:
            RuntimeError: if x is not a 3D tensor
        """
        if x.dim() != 3:
            # BUG FIX: instances have no __name__ attribute, so the original
            # self.__name__ raised AttributeError instead of the intended
            # RuntimeError; use the class name (as CumLN already does).
            raise RuntimeError("{} requires a 3D tensor input".format(
                self.__class__.__name__))
        # [N, C, T] -> [N, T, C] so LayerNorm normalizes over C
        x = th.transpose(x, 1, 2)
        x = super().forward(x)
        # back to [N, C, T]
        x = th.transpose(x, 1, 2)
        return x
class CumLN(nn.Module):
    """
    Cumulative global layer normalization (cLN).

    At each frame t, mean and variance are computed over all channels and
    all frames up to and including t, so the normalization is causal.

    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape
    """

    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        """
        Args:
            dim (int): channel size C
            eps (float): numerical-stability constant added to the variance
            elementwise_affine (bool): learn per-channel gamma/beta if True
        """
        super(CumLN, self).__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.normalized_dim = dim
        if elementwise_affine:
            self.beta = nn.Parameter(th.zeros(dim, 1))
            self.gamma = nn.Parameter(th.ones(dim, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        if x.dim() != 3:
            raise RuntimeError("{} requires a 3D tensor input".format(
                self.__class__.__name__))
        num_chan, num_frames = x.shape[1], x.shape[2]
        # per-frame sums over channels, then inclusive prefix sums along time
        # (th.cumsum yields the running cumulative totals)
        frame_sum = x.sum(1, keepdim=True)
        frame_sq_sum = x.pow(2).sum(1, keepdim=True)
        cum_sum = th.cumsum(frame_sum, dim=-1)
        cum_sq_sum = th.cumsum(frame_sq_sum, dim=-1)
        # number of elements contributing at each frame: C, 2C, ..., T*C
        counts = th.arange(
            start=num_chan,
            end=num_chan * (num_frames + 1),
            step=num_chan,
            dtype=x.dtype,
            device=x.device).view(1, 1, -1)
        mean = cum_sum / counts
        # E[x^2] - E[x]^2 form of the cumulative variance
        var = cum_sq_sum / counts - mean.pow(2)
        out = (x - mean) / (var + self.eps).sqrt()
        if self.elementwise_affine:
            out = self.gamma * out + self.beta
        return out

    def extra_repr(self):
        return ("{normalized_dim}, eps={eps}, "
                "elementwise_affine={elementwise_affine}").format(**self.__dict__)
class GlobalLayerNorm(nn.Module):
    """
    Global layer normalization (gLN).

    Normalizes with a single mean/variance computed over both the channel
    and time dimensions of each utterance, with optional per-channel
    affine parameters gamma/beta.

    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape
    """

    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        """
        Args:
            dim (int): channel size C (shape of the affine parameters)
            eps (float): numerical-stability constant added to the variance
            elementwise_affine (bool): learn per-channel gamma/beta if True
        """
        super(GlobalLayerNorm, self).__init__()
        self.eps = eps
        self.normalized_dim = dim
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.beta = nn.Parameter(th.zeros(dim, 1))
            self.gamma = nn.Parameter(th.ones(dim, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """
        Args:
            x (Tensor): input of shape [N, C, T]
        Returns:
            Tensor: normalized tensor of shape [N, C, T]
        Raises:
            RuntimeError: if x is not a 3D tensor
        """
        if x.dim() != 3:
            # BUG FIX: instances have no __name__ attribute, so the original
            # self.__name__ raised AttributeError instead of the intended
            # RuntimeError; use the class name (as CumLN already does).
            raise RuntimeError("{} requires a 3D tensor input".format(
                self.__class__.__name__))
        # calculate the mean, variance over the channel and time dimensions
        mean = th.mean(x, (1, 2), keepdim=True)
        var = th.mean((x - mean)**2, (1, 2), keepdim=True)
        if self.elementwise_affine:
            x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta
        else:
            x = (x - mean) / th.sqrt(var + self.eps)
        return x

    def extra_repr(self):
        return "{normalized_dim}, eps={eps}, " \
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)