| from ...layers import * |
| from ...torch_core import * |
|
|
| __all__ = ['BasicBlock', 'WideResNet', 'wrn_22'] |
|
|
def _bn(ni, init_zero=False):
    "Create a `BatchNorm2d` over `ni` channels; its weight is filled with 0 if `init_zero` else 1, its bias with 0."
    layer = nn.BatchNorm2d(ni)
    # Pick the constant used for the scale (weight) parameter.
    if init_zero:
        scale = 0
    else:
        scale = 1
    layer.weight.data.fill_(scale)
    # The shift (bias) parameter always starts at zero.
    layer.bias.data.zero_()
    return layer
|
|
def bn_relu_conv(ni, nf, ks, stride, init_zero=False):
    "Return a `nn.Sequential` of batchnorm over `ni` channels, an in-place ReLU, then `conv2d(ni, nf, ks, stride)`."
    # Modules are created in the same order they are applied.
    stages = [
        _bn(ni, init_zero=init_zero),
        nn.ReLU(inplace=True),
        conv2d(ni, nf, ks, stride),
    ]
    return nn.Sequential(*stages)
|
|
class BasicBlock(Module):
    "Block to form a wide ResNet: BN-ReLU, two 3x3 convs (optional dropout in between) plus a skip connection."
    def __init__(self, ni, nf, stride, drop_p=0.0):
        # `ni`/`nf`: input/output channel counts; `stride`: stride of the first conv;
        # `drop_p`: dropout probability applied between the two convs (0 disables dropout).
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv2d(ni, nf, 3, stride)
        self.conv2 = bn_relu_conv(nf, nf, 3, 1)
        self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None
        # The skip path must produce the same shape as the main path. `conv1` is built with
        # `stride`, so a projection is needed not only when the channel count changes but also
        # when `stride != 1`; otherwise the in-place add in `forward` would receive mismatched
        # shapes (assuming `conv2d` behaves like a regular strided convolution).
        needs_projection = ni != nf or stride != 1
        self.shortcut = conv2d(ni, nf, 1, stride) if needs_projection else noop

    def forward(self, x):
        # Pre-activation shared by the skip path and the main path.
        x2 = F.relu(self.bn(x), inplace=True)
        r = self.shortcut(x2)
        x = self.conv1(x2)
        # Explicit `None` check rather than relying on module truthiness.
        if self.drop is not None: x = self.drop(x)
        # The main-path output is scaled by 0.2 before being added to the skip path.
        x = self.conv2(x) * 0.2
        return x.add_(r)
|
|
def _make_group(N, ni, nf, block, stride, drop_p):
    "Build a list of `N` blocks: the first maps `ni`->`nf` with `stride`, the others map `nf`->`nf` with stride 1."
    group = []
    for idx in range(N):
        is_first = idx == 0
        # Only the first block of the group changes the channel count and uses `stride`.
        in_ch = ni if is_first else nf
        blk_stride = stride if is_first else 1
        group.append(block(in_ch, nf, blk_stride, drop_p))
    return group
|
|
class WideResNet(Module):
    "Wide ResNet made of `num_groups` groups of `N` blocks each, with channel widths multiplied by `k`."
    def __init__(self, num_groups:int, N:int, num_classes:int, k:int=1, drop_p:float=0.0, start_nf:int=16, n_in_channels:int=3):
        # Channel plan: the stem width first, then one width per group
        # (`start_nf * 2**g * k` for group index `g`).
        widths = [start_nf] + [start_nf*(2**g)*k for g in range(num_groups)]

        # Stem: a 3x3, stride-1 conv from the input channels to the stem width.
        body = [conv2d(n_in_channels, widths[0], 3, 1)]

        # Residual groups: the first group is built with stride 1, every later one with stride 2.
        for g in range(num_groups):
            group_stride = 1 if g == 0 else 2
            body.extend(_make_group(N, widths[g], widths[g+1], BasicBlock, group_stride, drop_p))

        # Head: BN-ReLU, global average pooling, flatten, then a linear layer to `num_classes` outputs.
        out_nf = widths[num_groups]
        body.extend([nn.BatchNorm2d(out_nf), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1),
                     Flatten(), nn.Linear(out_nf, num_classes)])
        self.features = nn.Sequential(*body)

    def forward(self, x):
        "Run `x` through the whole stack stored in `self.features`."
        return self.features(x)
|
|
|
|
def wrn_22():
    "Wide ResNet with 22 layers: 3 groups of 3 blocks, width factor 6, 10 classes, no dropout."
    config = dict(num_groups=3, N=3, num_classes=10, k=6, drop_p=0.)
    return WideResNet(**config)
|
|