# RevIN: Reversible Instance Normalization for time-series forecasting models.
| import torch | |
| import torch.nn as nn | |
class RevIN(nn.Module):
    """Reversible Instance Normalization (RevIN).

    In ``"norm"`` mode, per-instance statistics are computed and the input
    is standardized (optionally followed by a learnable per-channel affine
    transform). In ``"denorm"`` mode the transform is inverted using the
    statistics captured by the most recent ``"norm"`` call.

    NOTE(review): statistics reduce over all dims except the first and
    last, and ``subtract_last`` indexes ``x[:, -1, :]`` — this assumes a
    (batch, seq_len, num_features) layout; confirm against callers.
    """

    def __init__(self, num_features: int, eps: float = 1e-5,
                 affine: bool = True, subtract_last: bool = False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, center on the last time step of each
            instance instead of its mean (scaling by stdev still applies)
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Apply (``mode="norm"``) or invert (``mode="denorm"``) the
        normalization; any other mode raises NotImplementedError."""
        if mode == "norm":
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == "denorm":
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # Reduce over every dim except batch (0) and channel (last).
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            # Center on the last time step: (B, 1, C)
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        # Bug fix: stdev is required by _normalize/_denormalize in BOTH
        # centering modes; computing it only in the `else` branch made
        # subtract_last=True fail with AttributeError on self.stdev.
        self.stdev = torch.sqrt(
            torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps
        ).detach()

    def _normalize(self, x):
        # Center, scale, then (optionally) apply the learnable affine map.
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        # Invert _normalize step by step, in reverse order.
        if self.affine:
            x = x - self.affine_bias
            # eps*eps guards against division by a (near-)zero weight.
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x