Spaces:
Running
Running
| from torch import nn | |
| from torchvision.models.resnet import ResNet, Bottleneck | |
| class ResNetEncoder(ResNet): | |
| """ | |
| ResNetEncoder inherits from torchvision's official ResNet. It is modified to | |
| use dilation on the last block to maintain output stride 16, and deleted the | |
| global average pooling layer and the fully connected layer that was originally | |
| used for classification. The forward method additionally returns the feature | |
| maps at all resolutions for decoder's use. | |
| """ | |
| layers = { | |
| 'resnet50': [3, 4, 6, 3], | |
| 'resnet101': [3, 4, 23, 3], | |
| } | |
| def __init__(self, in_channels, variant='resnet101', norm_layer=None): | |
| super().__init__( | |
| block=Bottleneck, | |
| layers=self.layers[variant], | |
| replace_stride_with_dilation=[False, False, True], | |
| norm_layer=norm_layer) | |
| # Replace first conv layer if in_channels doesn't match. | |
| if in_channels != 3: | |
| self.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False) | |
| # Delete fully-connected layer | |
| del self.avgpool | |
| del self.fc | |
| def forward(self, x): | |
| x0 = x # 1/1 | |
| x = self.conv1(x) | |
| x = self.bn1(x) | |
| x = self.relu(x) | |
| x1 = x # 1/2 | |
| x = self.maxpool(x) | |
| x = self.layer1(x) | |
| x2 = x # 1/4 | |
| x = self.layer2(x) | |
| x3 = x # 1/8 | |
| x = self.layer3(x) | |
| x = self.layer4(x) | |
| x4 = x # 1/16 | |
| return x4, x3, x2, x1, x0 | |