import torch.nn as nn


class ResNormLayer(nn.Module):
    """Residual block: two Linear -> ReLU -> LayerNorm stages plus a skip connection."""

    def __init__(self, linear_size):
        super().__init__()
        self.l_size = linear_size
        self.nonlin1 = nn.ReLU(inplace=True)
        self.nonlin2 = nn.ReLU(inplace=True)
        self.norm_fn1 = nn.LayerNorm(self.l_size)
        self.norm_fn2 = nn.LayerNorm(self.l_size)
        self.w1 = nn.Linear(self.l_size, self.l_size)
        self.w2 = nn.Linear(self.l_size, self.l_size)

    def forward(self, x):
        # First Linear -> ReLU -> LayerNorm stage.
        y = self.w1(x)
        y = self.nonlin1(y)
        y = self.norm_fn1(y)
        # Second Linear -> ReLU -> LayerNorm stage.
        y = self.w2(y)
        y = self.nonlin2(y)
        y = self.norm_fn2(y)
        # Residual connection: every layer maps l_size -> l_size,
        # so x and y have matching shapes.
        out = x + y
        return out
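
# A minimal usage sketch: the batch size (4) and feature width (128) below are
# illustrative assumptions, not values from the original source. It simply
# checks that the block runs and that the residual path preserves the shape.
if __name__ == "__main__":
    import torch

    layer = ResNormLayer(128)
    x = torch.randn(4, 128)
    out = layer(x)
    assert out.shape == x.shape  # residual block leaves the shape unchanged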