# swc2: update change 2 (commit 7eddfc5)
#!/usr/bin/env python
import torch as th
import torch.nn as nn
class ChannelwiseLayerNorm(nn.LayerNorm):
    """
    Channel-wise layer normalization based on nn.LayerNorm.

    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with the same shape
    """
    def __init__(self, *args, **kwargs):
        super(ChannelwiseLayerNorm, self).__init__(*args, **kwargs)
    def forward(self, x):
        """
        Args:
            x: 3D tensor, N x C x T
        Returns:
            3D tensor, N x C x T, normalized along the channel dimension
        Raises:
            RuntimeError: if x is not a 3D tensor
        """
        if x.dim() != 3:
            # BUG FIX: was self.__name__, which does not exist on an
            # instance and raised AttributeError instead of RuntimeError
            raise RuntimeError("{} requires a 3D tensor input".format(
                type(self).__name__))
        # nn.LayerNorm normalizes over the last dimension, so move the
        # channel axis there, normalize, then move it back
        x = th.transpose(x, 1, 2)
        x = super().forward(x)
        x = th.transpose(x, 1, 2)
        return x
class CumLN(nn.Module):
    """
    Cumulative global layer normalization (cLN).

    Statistics at frame t are accumulated over all channels and over all
    frames up to and including t, so the normalization is causal.

    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with the same shape
    """
    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        """
        Args:
            dim: channel size C of the expected input
            eps: small constant added to the variance for stability
            elementwise_affine: if True, learn per-channel gamma/beta
        """
        super(CumLN, self).__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.normalized_dim = dim
        if elementwise_affine:
            self.beta = nn.Parameter(th.zeros(dim, 1))
            self.gamma = nn.Parameter(th.ones(dim, 1))
        else:
            # BUG FIX: was register_parameter("weight"/"bias"), which
            # does not match the "gamma"/"beta" names used in the affine
            # branch, so self.gamma/self.beta did not exist at all here
            self.register_parameter("gamma", None)
            self.register_parameter("beta", None)
    def forward(self, x):
        """
        Args:
            x: 3D tensor, N x C x T
        Returns:
            3D tensor, N x C x T, causally normalized
        Raises:
            RuntimeError: if x is not a 3D tensor
        """
        if x.dim() != 3:
            raise RuntimeError("{} requires a 3D tensor input".format(
                type(self).__name__))
        batch, chan, spec_len = x.size()
        # th.cumsum accumulates running sums along the time axis
        cum_sum = th.cumsum(x.sum(1, keepdim=True), dim=-1)
        cum_pow_sum = th.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1)
        # element count contributing to the statistics at each frame:
        # C, 2C, ..., T*C
        cnt = th.arange(
            start=chan,
            end=chan * (spec_len + 1),
            step=chan,
            dtype=x.dtype,
            device=x.device).view(1, 1, -1)
        cum_mean = cum_sum / cnt
        # Var[X] = E[X^2] - E[X]^2
        cum_var = cum_pow_sum / cnt - cum_mean.pow(2)
        normalized_x = (x - cum_mean) / (cum_var + self.eps).sqrt()
        if self.elementwise_affine:
            normalized_x = self.gamma * normalized_x + self.beta
        return normalized_x
    def extra_repr(self):
        return "{normalized_dim}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
class GlobalLayerNorm(nn.Module):
    """
    Global layer normalization (gLN).

    Mean and variance are computed over both the channel and the time
    dimensions of each utterance.

    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with the same shape
    """
    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        """
        Args:
            dim: channel size C of the expected input
            eps: small constant added to the variance for stability
            elementwise_affine: if True, learn per-channel gamma/beta
        """
        super(GlobalLayerNorm, self).__init__()
        self.eps = eps
        self.normalized_dim = dim
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.beta = nn.Parameter(th.zeros(dim, 1))
            self.gamma = nn.Parameter(th.ones(dim, 1))
        else:
            # BUG FIX: was register_parameter("weight"/"bias"), which
            # does not match the "gamma"/"beta" names used in the affine
            # branch, so self.gamma/self.beta did not exist at all here
            self.register_parameter("gamma", None)
            self.register_parameter("beta", None)
    def forward(self, x):
        """
        Args:
            x: 3D tensor, N x C x T
        Returns:
            3D tensor, N x C x T, globally normalized
        Raises:
            RuntimeError: if x is not a 3D tensor
        """
        if x.dim() != 3:
            # BUG FIX: was self.__name__, which does not exist on an
            # instance and raised AttributeError instead of RuntimeError
            raise RuntimeError("{} requires a 3D tensor input".format(
                type(self).__name__))
        # calculate the mean, variance over the channel and time dimensions
        mean = th.mean(x, (1, 2), keepdim=True)
        var = th.mean((x - mean)**2, (1, 2), keepdim=True)
        if self.elementwise_affine:
            x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta
        else:
            x = (x - mean) / th.sqrt(var + self.eps)
        return x
    def extra_repr(self):
        return "{normalized_dim}, eps={eps}, " \
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)