| import torch | |
| from torch import nn | |
class ChanNorm(nn.Module):
    """Normalize an input across its channel dimension (dim 1).

    Every slice along dim 1 is shifted by its mean and divided by the square
    root of its population variance (no Bessel correction) plus ``eps``. A
    learnable per-channel scale ``g`` and shift ``b`` are then applied.

    Args:
        dim: Number of channels; sizes the learnable ``g`` and ``b``.
        eps: Added to the variance before the square root to avoid division
            by zero.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Shaped (1, dim, 1, 1) so they broadcast against an input with the
        # channel axis at dim 1 and two trailing axes (presumably NCHW — confirm
        # with callers).
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        """Return ``x`` normalized over dim 1, scaled by ``g`` and shifted by ``b``."""
        # One call yields both statistics instead of reducing over x twice;
        # unbiased=False gives the population variance, as before.
        var, mean = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b