from torch import nn

from .ChanNorm import ChanNorm


class PreNorm(nn.Module):
    """Apply ``ChanNorm`` to the input before handing it to a wrapped callable.

    Args:
        dim: Forwarded unchanged to ``ChanNorm``. Presumably the number of
            channels being normalized -- confirm against ``ChanNorm``.
        fn: Callable (typically an ``nn.Module``) invoked on the normalized
            input. Stored as ``self.fn``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Keep the original assignment order (fn, then norm) so submodule
        # registration order -- and therefore state_dict key order -- is
        # unchanged.
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x, *args, **kwargs):
        """Normalize ``x`` and return ``fn`` applied to the result.

        Args:
            x: Input passed to ``self.norm``.
            *args: Extra positional arguments passed through to ``fn``.
            **kwargs: Extra keyword arguments passed through to ``fn``.

        Returns:
            Whatever ``fn`` returns for the normalized input.
        """
        normed = self.norm(x)
        # Only ``x`` is normalized; any additional arguments reach ``fn``
        # untouched so wrapped callables with richer signatures still work.
        return self.fn(normed, *args, **kwargs)