Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.autograd import Variable | |
| import torch.nn as nn | |
| import numpy as np | |
| from typing import List | |
| from torch.nn.modules.batchnorm import _BatchNorm | |
| from collections.abc import Iterable | |
def norm(x, dims: List[int], EPS: float = 1e-8):
    """Standardise ``x`` to zero mean and unit variance over ``dims``.

    Args:
        x: Tensor to normalise.
        dims: Axes over which the mean and the (biased) variance are taken.
        EPS: Small constant added to the variance before the square root.

    Returns:
        Tensor with the same shape as ``x``.
    """
    centered = x - x.mean(dim=dims, keepdim=True)
    # Population variance (unbiased=False), kept broadcastable against x.
    variance = x.var(dim=dims, unbiased=False, keepdim=True)
    return centered / (variance + EPS).sqrt()
def glob_norm(x, ESP: float = 1e-8):
    """Normalise ``x`` over every axis except the first one.

    Args:
        x: Tensor to normalise; axis 0 is left out of the statistics.
        ESP: Small constant forwarded to ``norm`` as its ``EPS`` argument.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # All axes from 1 up to the last one.
    dims: List[int] = list(range(1, x.dim()))
    return norm(x, dims, ESP)
class MLayerNorm(nn.Module):
    """Base class for layer norms that share a per-channel gain and bias.

    Subclasses implement ``forward`` and call ``apply_gain_and_bias`` on the
    normalised tensor.

    Args:
        channel_size: Number of channels, i.e. the size of axis 1 of the input.
    """

    def __init__(self, channel_size):
        super().__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
        # The bias starts at zero so a freshly built layer applies an identity
        # affine transform, matching the other norms in this file.
        self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """Scale and shift ``normed_x`` per channel.

        Assumes input of size ``[batch, channel, *]``.
        """
        # Move the channel axis last so the 1-D parameters broadcast over it,
        # then move it back.
        return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1)

    def forward(self, x, EPS: float = 1e-8):
        """Normalise ``x``; must be provided by subclasses.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        # Fail loudly instead of silently returning None.
        raise NotImplementedError(
            "{} does not implement forward()".format(type(self).__name__)
        )
class GlobalLN(MLayerNorm):
    """Layer norm whose statistics span every axis except the first."""

    def forward(self, x, EPS: float = 1e-8):
        """Normalise ``x`` with ``glob_norm`` and apply the channel affine.

        Args:
            x: Input tensor; axis 1 is treated as the channel axis by
                ``apply_gain_and_bias``.
            EPS: Small constant added to the variance.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return self.apply_gain_and_bias(glob_norm(x, EPS))
class ChannelLN(MLayerNorm):
    """Layer norm whose statistics are taken over axis 1 only."""

    def forward(self, x, EPS: float = 1e-8):
        """Normalise every position of ``x`` across axis 1.

        Args:
            x: Input tensor with the channel axis at position 1.
            EPS: Small constant added to the variance.

        Returns:
            Tensor with the same shape as ``x``.
        """
        channel_mean = x.mean(dim=1, keepdim=True)
        # Population variance (unbiased=False) across the channel axis.
        channel_var = x.var(dim=1, unbiased=False, keepdim=True)
        normed = (x - channel_mean) / torch.sqrt(channel_var + EPS)
        return self.apply_gain_and_bias(normed)
| # class CumulateLN(MLayerNorm): | |
| # def forward(self, x, EPS: float = 1e-8): | |
| # batch, channels, time = x.size() | |
| # cum_sum = torch.cumsum(x.sum(1, keepdim=True), dim=1) | |
| # cum_pow_sum = torch.cumsum(x.pow(2).sum(1, keepdim=True), dim=1) | |
| # cnt = torch.arange( | |
| # start=channels, end=channels * (time + 1), step=channels, dtype=x.dtype, device=x.device | |
| # ).view(1, 1, -1) | |
| # cum_mean = cum_sum / cnt | |
| # cum_var = (cum_pow_sum / cnt) - cum_mean.pow(2) | |
| # return self.apply_gain_and_bias((x - cum_mean) / (cum_var + EPS).sqrt()) | |
class BatchNorm(_BatchNorm):
    """Wrapper class for pytorch BatchNorm1D and BatchNorm2D.

    Accepts 2D, 3D and 4D inputs in a single class.
    """

    def _check_input_dim(self, input):
        """Reject inputs whose rank is outside the supported 2-4 range.

        Raises:
            ValueError: If ``input`` has fewer than 2 or more than 4 dims.
        """
        if input.dim() < 2 or input.dim() > 4:
            # The message lists every rank the check above accepts.
            raise ValueError(
                f"expected 2D, 3D or 4D input (got {input.dim()}D input)"
            )
class CumulativeLayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` applied over axis 1 of a channels-first tensor.

    Despite the name, no cumulative statistics are computed here: the input
    is transposed so axis 1 is last, normalised by ``nn.LayerNorm``, and
    transposed back.

    Args:
        dim: Normalised shape handed to ``nn.LayerNorm`` (the channel size).
        elementwise_affine: Whether ``nn.LayerNorm`` learns a gain and bias.
    """

    def __init__(self, dim, elementwise_affine=True):
        super().__init__(dim, elementwise_affine=elementwise_affine)

    def forward(self, x):
        """Normalise ``x`` (e.g. N x C x L) across its channel axis."""
        # N x C x L -> N x L x C so that LayerNorm sees channels last.
        channels_last = x.transpose(1, -1)
        normed = super().forward(channels_last)
        # Restore the original N x C x L layout.
        return normed.transpose(1, -1)
| class CumulateLN(nn.Module): | |
| def __init__(self, dimension, eps=1e-8, trainable=True): | |
| super(CumulateLN, self).__init__() | |
| self.eps = eps | |
| if trainable: | |
| self.gain = nn.Parameter(torch.ones(1, dimension, 1)) | |
| self.bias = nn.Parameter(torch.zeros(1, dimension, 1)) | |
| else: | |
| self.gain = Variable(torch.ones(1, dimension, 1), requires_grad=False) | |
| self.bias = Variable(torch.zeros(1, dimension, 1), requires_grad=False) | |
| def forward(self, input): | |
| # input size: (Batch, Freq, Time) | |
| # cumulative mean for each time step | |
| batch_size = input.size(0) | |
| channel = input.size(1) | |
| time_step = input.size(2) | |
| step_sum = input.sum(1) # B, T | |
| step_pow_sum = input.pow(2).sum(1) # B, T | |
| cum_sum = torch.cumsum(step_sum, dim=1) # B, T | |
| cum_pow_sum = torch.cumsum(step_pow_sum, dim=1) # B, T | |
| entry_cnt = np.arange(channel, channel * (time_step + 1), channel) | |
| entry_cnt = torch.from_numpy(entry_cnt).type(input.type()) | |
| entry_cnt = entry_cnt.view(1, -1).expand_as(cum_sum) | |
| cum_mean = cum_sum / entry_cnt # B, T | |
| cum_var = (cum_pow_sum - 2 * cum_mean * cum_sum) / entry_cnt + cum_mean.pow( | |
| 2 | |
| ) # B, T | |
| cum_std = (cum_var + self.eps).sqrt() # B, T | |
| cum_mean = cum_mean.unsqueeze(1) | |
| cum_std = cum_std.unsqueeze(1) | |
| x = (input - cum_mean.expand_as(input)) / cum_std.expand_as(input) | |
| return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type( | |
| x.type() | |
| ) | |
class LayerNormalization4D(nn.Module):
    """Layer norm with gain and bias of shape ``[1, d0, 1, d1]``.

    Statistics are taken over axes (1, 3) when ``d1 > 1`` and over axis 1
    only otherwise.

    Args:
        input_dimension: Two sizes ``(d0, d1)`` used for axes 1 and 3 of the
            parameters.
        eps: Small constant added to the variance before the square root.
    """

    def __init__(self, input_dimension: Iterable, eps: float = 1e-5):
        super().__init__()
        assert len(input_dimension) == 2
        channels, last = input_dimension[0], input_dimension[1]
        shape = (1, channels, 1, last)
        # A singleton last axis carries no statistics, so leave it out.
        self.dim = (1, 3) if last > 1 else (1,)
        # Identity affine at initialisation: unit gain, zero bias.
        self.gamma = nn.Parameter(torch.ones(shape, dtype=torch.float32))
        self.beta = nn.Parameter(torch.zeros(shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        """Normalise ``x`` over ``self.dim`` and apply gain and bias.

        NOTE(review): the parameter shape suggests ``x`` is 4-D with sizes
        ``d0`` and ``d1`` on axes 1 and 3 - confirm against callers.
        """
        mean = x.mean(dim=self.dim, keepdim=True)
        variance = x.var(dim=self.dim, unbiased=False, keepdim=True)
        normalized = (x - mean) / torch.sqrt(variance + self.eps)
        return normalized * self.gamma + self.beta
# Short aliases. As module-level names they can also be resolved by `get`,
# which looks identifiers up in this module's globals.
bN = BatchNorm
cLN = CumulateLN
gLN = GlobalLN
LN = CumulativeLayerNorm
LN4D = LayerNormalization4D
def get(identifier):
    """Resolve a normalization identifier to an object of this module.

    Args:
        identifier (str or Callable or None): ``None``, something callable
            (returned unchanged), or the name of a module-level object such
            as a norm class or one of its aliases.

    Returns:
        The callable that was passed in, the module-level object registered
        under the given name, or None when ``identifier`` is None.

    Raises:
        ValueError: If a string names nothing in this module, or if
            ``identifier`` is neither None, callable nor a string.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        # Names are looked up among this module's global names.
        resolved = globals().get(identifier)
        if resolved is not None:
            return resolved
    # Unknown name or unsupported identifier type.
    raise ValueError(
        "Could not interpret normalization identifier: " + str(identifier)
    )