from torch import nn

from .ChanNorm import ChanNorm


# NOTE(review): this module defines ``PreNorm`` (and repeats these imports)
# twice; the later definition shadows this one at import time. Consider
# deleting the duplicate.
class PreNorm(nn.Module):
    """Normalize the input with ``ChanNorm``, then apply the wrapped ``fn``.

    Args:
        dim: Passed unchanged to ``ChanNorm(dim)``; presumably the channel
            count of the input tensor — TODO confirm against ``ChanNorm``.
        fn: Callable (typically an ``nn.Module``) applied to the normalized
            input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x, **kwargs):
        """Return ``self.fn(self.norm(x), **kwargs)``.

        Extra keyword arguments are forwarded untouched to ``fn`` so wrapped
        modules that accept options can be used; a call with only ``x``
        behaves exactly as the single-argument version did.
        """
        return self.fn(self.norm(x), **kwargs)
from torch import nn

from .ChanNorm import ChanNorm


# NOTE(review): this is a verbatim repeat of an earlier ``PreNorm``
# definition (and its imports) in this module; being later, this one is
# the definition callers actually get. Consider deleting the duplicate.
class PreNorm(nn.Module):
    """Normalize the input with ``ChanNorm``, then apply the wrapped ``fn``.

    Args:
        dim: Passed unchanged to ``ChanNorm(dim)``; presumably the channel
            count of the input tensor — TODO confirm against ``ChanNorm``.
        fn: Callable (typically an ``nn.Module``) applied to the normalized
            input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x, **kwargs):
        """Return ``self.fn(self.norm(x), **kwargs)``.

        Extra keyword arguments are forwarded untouched to ``fn`` so wrapped
        modules that accept options can be used; a call with only ``x``
        behaves exactly as the single-argument version did.
        """
        return self.fn(self.norm(x), **kwargs)