#!/usr/bin/env python
"""Layer-normalization variants for 3D feature tensors shaped (N, C, T)."""

import torch as th
import torch.nn as nn


class ChannelwiseLayerNorm(nn.LayerNorm):
    """
    Channel-wise layer normalization based on nn.LayerNorm

    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape
    """

    def __init__(self, *args, **kwargs):
        super(ChannelwiseLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, x):
        """Apply LayerNorm over the channel dim C of a (N, C, T) tensor.

        Raises:
            RuntimeError: if ``x`` is not a 3D tensor.
        """
        if x.dim() != 3:
            # BUG FIX: was self.__name__, which does not exist on instances
            # and raised AttributeError instead of the intended RuntimeError.
            raise RuntimeError("{} requires a 3D tensor input".format(
                self.__class__.__name__))
        # nn.LayerNorm normalizes the last dimension, so move C to the end,
        # normalize, then restore the original (N, C, T) layout.
        x = th.transpose(x, 1, 2)
        x = super().forward(x)
        x = th.transpose(x, 1, 2)
        return x


class CumLN(nn.Module):
    """
    Cumulative Global layer normalization

    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape
    """

    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        super(CumLN, self).__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.normalized_dim = dim
        if elementwise_affine:
            # Per-channel scale/shift, broadcast over the time dimension.
            self.beta = nn.Parameter(th.zeros(dim, 1))
            self.gamma = nn.Parameter(th.ones(dim, 1))
        else:
            # CONSISTENCY FIX: register the same attribute names used in the
            # affine branch (was "weight"/"bias", which nothing else reads).
            self.register_parameter("gamma", None)
            self.register_parameter("beta", None)

    def forward(self, x):
        """Normalize each frame t using statistics over channels and frames 0..t.

        Raises:
            RuntimeError: if ``x`` is not a 3D tensor.
        """
        if x.dim() != 3:
            raise RuntimeError("{} requires a 3D tensor input".format(
                self.__class__.__name__))
        batch, chan, spec_len = x.size()
        # Running sums over (channels, time) accumulated along the time axis;
        # cumsum makes each position t hold the total over frames 0..t.
        cum_sum = th.cumsum(x.sum(1, keepdim=True), dim=-1)
        cum_pow_sum = th.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1)
        # Element count behind each cumulative statistic: C * (t + 1).
        cnt = th.arange(
            start=chan,
            end=chan * (spec_len + 1),
            step=chan,
            dtype=x.dtype,
            device=x.device).view(1, 1, -1)
        cum_mean = cum_sum / cnt
        # Var = E[x^2] - (E[x])^2 over the cumulative window.
        cum_var = cum_pow_sum / cnt - cum_mean.pow(2)
        normalized_x = (x - cum_mean) / (cum_var + self.eps).sqrt()
        if self.elementwise_affine:
            normalized_x = self.gamma * normalized_x + self.beta
        return normalized_x

    def extra_repr(self):
        return ("{normalized_dim}, eps={eps}, "
                "elementwise_affine={elementwise_affine}".format(
                    **self.__dict__))


class GlobalLayerNorm(nn.Module):
    """
    Global layer normalization

    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape
    """

    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.eps = eps
        self.normalized_dim = dim
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.beta = nn.Parameter(th.zeros(dim, 1))
            self.gamma = nn.Parameter(th.ones(dim, 1))
        else:
            # CONSISTENCY FIX: same attribute names as the affine branch
            # (was "weight"/"bias", which nothing else reads).
            self.register_parameter("gamma", None)
            self.register_parameter("beta", None)

    def forward(self, x):
        """Normalize using mean/variance over both channel and time dims.

        Raises:
            RuntimeError: if ``x`` is not a 3D tensor.
        """
        if x.dim() != 3:
            # BUG FIX: was self.__name__, which does not exist on instances
            # and raised AttributeError instead of the intended RuntimeError.
            raise RuntimeError("{} requires a 3D tensor input".format(
                self.__class__.__name__))
        # Statistics are global per sample: computed over (C, T) jointly.
        mean = th.mean(x, (1, 2), keepdim=True)
        var = th.mean((x - mean)**2, (1, 2), keepdim=True)
        if self.elementwise_affine:
            x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta
        else:
            x = (x - mean) / th.sqrt(var + self.eps)
        return x

    def extra_repr(self):
        return ("{normalized_dim}, eps={eps}, "
                "elementwise_affine={elementwise_affine}".format(
                    **self.__dict__))