| |
| |
| |
| |
|
|
| from typing import List |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn.functional import avg_pool2d, relu |
|
|
| from backbone import MammothBackbone |
|
|
|
|
| def conv3x3(in_planes: int, out_planes: int, stride: int=1) -> F.conv2d: |
| """ |
| Instantiates a 3x3 convolutional layer with no bias. |
| :param in_planes: number of input channels |
| :param out_planes: number of output channels |
| :param stride: stride of the convolution |
| :return: convolutional layer |
| """ |
| return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, |
| padding=1, bias=False) |
|
|
|
|
| class BasicBlock(nn.Module): |
| """ |
| The basic block of ResNet. |
| """ |
| expansion = 1 |
|
|
| def __init__(self, in_planes: int, planes: int, stride: int=1) -> None: |
| """ |
| Instantiates the basic block of the network. |
| :param in_planes: the number of input channels |
| :param planes: the number of channels (to be possibly expanded) |
| """ |
| super(BasicBlock, self).__init__() |
| self.conv1 = conv3x3(in_planes, planes, stride) |
| self.bn1 = nn.BatchNorm2d(planes) |
| self.conv2 = conv3x3(planes, planes) |
| self.bn2 = nn.BatchNorm2d(planes) |
|
|
| self.shortcut = nn.Sequential() |
| if stride != 1 or in_planes != self.expansion * planes: |
| self.shortcut = nn.Sequential( |
| nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, |
| stride=stride, bias=False), |
| nn.BatchNorm2d(self.expansion * planes) |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """ |
| Compute a forward pass. |
| :param x: input tensor (batch_size, input_size) |
| :return: output tensor (10) |
| """ |
| out = relu(self.bn1(self.conv1(x))) |
| out = self.bn2(self.conv2(out)) |
| out += self.shortcut(x) |
| out = relu(out) |
| return out |
|
|
|
|
| class ResNet(MammothBackbone): |
| """ |
| ResNet network architecture. Designed for complex datasets. |
| """ |
|
|
| def __init__(self, block: BasicBlock, num_blocks: List[int], |
| num_classes: int, nf: int) -> None: |
| """ |
| Instantiates the layers of the network. |
| :param block: the basic ResNet block |
| :param num_blocks: the number of blocks per layer |
| :param num_classes: the number of output classes |
| :param nf: the number of filters |
| """ |
| super(ResNet, self).__init__() |
| self.in_planes = nf |
| self.block = block |
| self.num_classes = num_classes |
| self.nf = nf |
| self.final_d = nf * 8 |
| self.conv1 = conv3x3(3, nf * 1) |
| self.bn1 = nn.BatchNorm2d(nf * 1) |
| self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1) |
| self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2) |
| self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2) |
| self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2) |
| self.linear = nn.Linear(nf * 8 * block.expansion, num_classes) |
| self.net_channels = [nf * 1, nf * 2, nf * 4, nf * 8] |
| |
|
|
| self._features = nn.Sequential(self.conv1, |
| self.bn1, |
| nn.ReLU(), |
| self.layer1, |
| self.layer2, |
| self.layer3, |
| self.layer4 |
| ) |
| self.classifier = self.linear |
|
|
| def _make_layer(self, block: BasicBlock, planes: int, |
| num_blocks: int, stride: int) -> nn.Module: |
| """ |
| Instantiates a ResNet layer. |
| :param block: ResNet basic block |
| :param planes: channels across the network |
| :param num_blocks: number of blocks |
| :param stride: stride |
| :return: ResNet layer |
| """ |
| strides = [stride] + [1] * (num_blocks - 1) |
| layers = [] |
| for stride in strides: |
| layers.append(block(self.in_planes, planes, stride)) |
| self.in_planes = planes * block.expansion |
| return nn.Sequential(*layers) |
|
|
| def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor: |
| """ |
| Compute a forward pass. |
| :param x: input tensor (batch_size, *input_shape) |
| :param returnt: return type (a string among 'out', 'features', 'all') |
| :return: output tensor (output_classes) |
| """ |
|
|
| out = relu(self.bn1(self.conv1(x))) |
| if hasattr(self, 'maxpool'): |
| out = self.maxpool(out) |
| out = self.layer1(out) |
| out = self.layer2(out) |
| out = self.layer3(out) |
| out = self.layer4(out) |
| out = avg_pool2d(out, out.shape[2]) |
| feature = out.view(out.size(0), -1) |
|
|
| if returnt == 'features': |
| return feature |
|
|
| out = self.classifier(feature) |
|
|
| if returnt == 'out': |
| return out |
| elif returnt == 'all': |
| return (out, feature) |
|
|
| raise NotImplementedError("Unknown return type") |
| def forward_all_layers(self, x: torch.Tensor) -> torch.Tensor: |
| feats = [] |
| |
| out = relu(self.bn1(self.conv1(x))) |
| if hasattr(self, 'maxpool'): |
| out = self.maxpool(out) |
| out = self.layer1(out) |
| feats.append(out) |
|
|
| out = self.layer2(out) |
| feats.append(out) |
|
|
| out = self.layer3(out) |
| feats.append(out) |
|
|
| out = self.layer4(out) |
| feats.append(out) |
|
|
| out = avg_pool2d(out, out.shape[2]) |
| feature = out.view(out.size(0), -1) |
| |
| out = self.classifier(feature) |
|
|
| return out, feats |
|
|
|
|
| def resnet18(nclasses: int, nf: int=64) -> ResNet: |
| """ |
| Instantiates a ResNet18 network. |
| :param nclasses: number of output classes |
| :param nf: number of filters |
| :return: ResNet network |
| """ |
| return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf) |
|
|