| import torch
|
| import torch.nn as nn
|
| import torch.utils.model_zoo as model_zoo
|
| from itertools import chain
|
| import torch.utils.checkpoint as cp
|
| from math import log2
|
| from timm.models.registry import register_model
|
|
|
| from ..utils.utils import _Upsample, SpatialPyramidPooling, SeparableConv2d
|
|
|
# Names exported by a star-import of this module.
__all__ = [
    'ResNet',
    # model constructors
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',
    'BasicBlock',
]
|
|
|
# Checkpoint download locations, keyed by the name of the resnet* constructor that loads them.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)
|
|
|
|
|
def conv3x3(in_planes, out_planes, stride=1, separable=False):
    """Create a bias-free 3x3 convolution with padding 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        separable: if True, build a ``SeparableConv2d`` instead of ``nn.Conv2d``.

    Returns:
        The convolution module.
    """
    if separable:
        layer_cls = SeparableConv2d
    else:
        layer_cls = nn.Conv2d
    return layer_cls(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
|
|
|
|
def _bn_function_factory(conv, norm, relu=None):
    """Return a closure computing ``relu(norm(conv(x)))``, skipping stages that are None.

    Bundling the stages into one callable lets ``do_efficient_fwd`` checkpoint them as a unit.

    Args:
        conv: callable that is always applied first.
        norm: optional callable applied to the conv output (skipped when None).
        relu: optional callable applied last (skipped when None).

    Returns:
        A single-argument callable.
    """
    post_ops = tuple(op for op in (norm, relu) if op is not None)

    def bn_function(x):
        result = conv(x)
        for op in post_ops:
            result = op(result)
        return result

    return bn_function
|
|
|
def do_efficient_fwd(block, x, efficient):
    """Evaluate ``block(x)``, using gradient checkpointing when it can pay off.

    Checkpointing (``torch.utils.checkpoint.checkpoint``) is used only when ``efficient``
    is truthy AND ``x.requires_grad``. ``x.requires_grad`` is not read at all when
    ``efficient`` is falsy.

    Args:
        block: single-argument callable to evaluate.
        x: input passed to ``block``.
        efficient: enable checkpointing.

    Returns:
        Whatever ``block(x)`` (or its checkpointed evaluation) returns.
    """
    checkpointing = efficient and x.requires_grad
    if not checkpointing:
        return block(x)
    return cp.checkpoint(block, x)
|
|
|
|
|
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions (the ResNet-18/34 building block).

    ``forward`` returns ``(activated, summed)``: ``summed`` is branch + shortcut and
    ``activated`` is ``self.relu(summed)``. ``ResNet.forward_resblock`` keeps the second
    element as the stage's skip feature.

    NOTE(review): ``self.relu`` is ``nn.ReLU(inplace=True)``, so applying it to ``summed``
    overwrites that tensor. Both returned values are therefore the same post-ReLU
    tensor. ``Bottleneck`` uses a non-in-place ReLU for the same pattern. Confirm
    whether a pre-activation skip was intended here before changing it, since doing so
    alters the features fed to the upsampling path.

    Args:
        inplanes: input channels.
        planes: output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to ``x`` to build the shortcut.
        efficient: route each conv stage through ``do_efficient_fwd`` (checkpointing).
        use_bn: add BatchNorm after each convolution.
        deleting: unless this is exactly ``False``, the conv branch is replaced by zeros,
            so the block outputs ``relu(shortcut)``.
        separable: use separable 3x3 convolutions.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, efficient=True, use_bn=True, deleting=False,
                 separable=False):
        super(BasicBlock, self).__init__()
        self.use_bn = use_bn
        self.stride = stride
        self.efficient = efficient
        self.deleting = deleting
        # Submodules are assigned in this exact order to keep parameter/state_dict ordering stable.
        self.conv1 = conv3x3(inplanes, planes, stride, separable=separable)
        self.bn1 = nn.BatchNorm2d(planes) if use_bn else None
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, separable=separable)
        self.bn2 = nn.BatchNorm2d(planes) if use_bn else None
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        if self.deleting is not False:
            # "Deleted" block: the residual branch contributes nothing.
            branch = torch.zeros_like(shortcut)
        else:
            stage_a = _bn_function_factory(self.conv1, self.bn1, self.relu)
            stage_b = _bn_function_factory(self.conv2, self.bn2)
            hidden = do_efficient_fwd(stage_a, x, self.efficient)
            branch = do_efficient_fwd(stage_b, hidden, self.efficient)

        summed = branch + shortcut
        # In-place ReLU: after this call `summed` holds the activated values too.
        activated = self.relu(summed)
        return activated, summed
|
|
|
|
|
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block (the ResNet-50/101/152 building block).

    ``forward`` returns ``(relu(out), out)``, where ``out`` is the residual sum before
    the final activation. ``ResNet.forward_resblock`` keeps ``out`` as the stage's
    skip feature.

    Args:
        inplanes: input channels.
        planes: bottleneck width; the block outputs ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to ``x`` to build the shortcut.
        efficient: route each conv stage through ``do_efficient_fwd`` (checkpointing).
        use_bn: add BatchNorm after each convolution.
        separable: use ``SeparableConv2d`` for the 3x3 convolution.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, efficient=True, use_bn=True, separable=False):
        super(Bottleneck, self).__init__()
        self.use_bn = use_bn
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes) if self.use_bn else None
        conv_class = SeparableConv2d if separable else nn.Conv2d
        self.conv2 = conv_class(planes, planes, kernel_size=3, stride=stride,
                                padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes) if self.use_bn else None
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion) if self.use_bn else None
        # Not in-place, so `out` keeps its pre-activation values after the final ReLU
        # and can be returned as the skip feature.
        self.relu = nn.ReLU(inplace=False)
        self.downsample = downsample
        self.stride = stride
        self.efficient = efficient

    def forward(self, x):
        residual = x

        bn_1 = _bn_function_factory(self.conv1, self.bn1, self.relu)
        bn_2 = _bn_function_factory(self.conv2, self.bn2, self.relu)
        # conv3 -> bn3 with NO activation: as in the reference bottleneck whose
        # pretrained weights are loaded, the ReLU belongs after the residual addition.
        # An extra ReLU here would clip the branch to >= 0 before the sum.
        bn_3 = _bn_function_factory(self.conv3, self.bn3)

        out = do_efficient_fwd(bn_1, x, self.efficient)
        out = do_efficient_fwd(bn_2, out, self.efficient)
        out = do_efficient_fwd(bn_3, out, self.efficient)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        relu = self.relu(out)

        return relu, out
|
|
|
|
|
class ResNet(nn.Module):
    """ResNet encoder with a spatial-pyramid-pooling head and a top-down upsampling path.

    The encoder (``conv1`` ... ``layer4``) follows the torchvision ResNet layout, so the
    checkpoints in ``model_urls`` can be loaded into it (with ``strict=False``).

    Each stage's last block returns ``(activated, skip)``. The ``skip`` tensors of
    ``layer1``-``layer3``, plus the SPP output computed from ``layer4``'s skip, form the
    feature list that ``forward_up`` fuses from coarse to fine.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``). It must expose
            ``expansion`` and return ``(activated, skip)`` from ``forward``.
        layers: number of blocks in each of the four stages.
        num_features: channel width of the SPP output and of every upsampling module.
        k_up: kernel size passed to each ``_Upsample`` as ``k``.
        efficient: forwarded to every block (checkpointed forward when inputs require grad).
        use_bn: build BatchNorm layers; when False, ``bn1`` is an identity.
        spp_grids, spp_square_grid, spp_drop_rate: forwarded to ``SpatialPyramidPooling``.
        upsample_skip, upsample_only_skip: forwarded to ``_Upsample`` as ``use_skip`` / ``only_skip``.
        detach_upsample_skips: indices of upsampling skips to detach. 0 is the ``layer3``
            skip, 1 is ``layer2``, 2 is ``layer1``.
        detach_upsample_in: detach the SPP output before it enters the upsampling path.
        target_size: optional ``(h, w)``. When given, fixed sizes ``(h // 2**i, w // 2**i)``
            for i = 2..5 are passed to the upsampling modules and the SPP.
        output_stride: each doubling beyond 4 drops the finest (``layer1``) upsampling module.
        mean, std: per-channel values subtracted from / dividing the input image.
        scale: if not 1, the image is first divided by it.
        separable: use ``SeparableConv2d`` for the 3x3 convolutions inside the residual blocks.
        upsample_separable: forwarded to ``_Upsample`` as ``separable``.
        **kwargs: only ``spp_size`` is read (SPP bottleneck width, defaults to ``num_features``).
    """

    def __init__(self, block, layers, *, num_features=128, k_up=3, efficient=False, use_bn=True,
                 spp_grids=(8, 4, 2, 1), spp_square_grid=False, spp_drop_rate=0.0,
                 upsample_skip=True, upsample_only_skip=False,
                 detach_upsample_skips=(), detach_upsample_in=False,
                 target_size=None, output_stride=4, mean=(73.1584, 82.9090, 72.3924),
                 std=(44.9149, 46.1529, 45.3192), scale=1, separable=False,
                 upsample_separable=False, **kwargs):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.efficient = efficient
        self.use_bn = use_bn
        self.separable = separable
        # Normalization constants as buffers shaped (1, C, 1, 1) so they broadcast over
        # the image and move with the module on .to()/.cuda().
        self.register_buffer('img_mean', torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer('img_std', torch.tensor(std).view(1, -1, 1, 1))
        if scale != 1:
            self.register_buffer('img_scale', torch.tensor(scale).view(1, -1, 1, 1).float())

        self.detach_upsample_in = detach_upsample_in
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        # nn.Identity (instead of a bare lambda) keeps the model picklable; it has no state.
        self.bn1 = nn.BatchNorm2d(64) if self.use_bn else nn.Identity()
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.target_size = target_size
        if self.target_size is not None:
            h, w = target_size
            # Fixed sizes at strides 4, 8, 16 and 32.
            target_sizes = [(h // 2 ** i, w // 2 ** i) for i in range(2, 6)]
        else:
            target_sizes = [None] * 4

        def make_upsample(skip_index, fixed_size):
            # Builds the upsampling module for the stage just created; its skip has
            # self.inplanes channels at this point.
            return _Upsample(num_features, self.inplanes, num_features, use_bn=self.use_bn, k=k_up,
                             use_skip=upsample_skip, only_skip=upsample_only_skip,
                             detach_skip=skip_index in detach_upsample_skips, fixed_size=fixed_size,
                             separable=upsample_separable)

        # Stages and upsampling modules are built interleaved, in the same order as
        # before, so that parameter initialization order is unchanged.
        upsamples = []
        self.layer1 = self._make_layer(block, 64, layers[0])
        upsamples.append(make_upsample(2, target_sizes[0]))
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        upsamples.append(make_upsample(1, target_sizes[1]))
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        upsamples.append(make_upsample(0, target_sizes[2]))
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.fine_tune = [self.conv1, self.maxpool, self.layer1, self.layer2, self.layer3, self.layer4]
        if self.use_bn:
            self.fine_tune += [self.bn1]

        num_levels = 3
        self.spp_size = kwargs.get('spp_size', num_features)
        bt_size = self.spp_size

        level_size = self.spp_size // num_levels

        self.spp = SpatialPyramidPooling(self.inplanes, num_levels, bt_size=bt_size, level_size=level_size,
                                         out_size=num_features, grids=spp_grids, square_grid=spp_square_grid,
                                         bn_momentum=0.01 / 2, use_bn=self.use_bn, drop_rate=spp_drop_rate,
                                         fixed_size=target_sizes[3])
        # Drop the finest upsampling modules when a coarser output stride is requested,
        # then order them coarse-to-fine to match forward_up.
        num_up_remove = max(0, int(log2(output_stride) - 2))
        self.upsample = nn.ModuleList(list(reversed(upsamples[num_up_remove:])))

        self.random_init = [self.spp, self.upsample]

        self.num_features = num_features

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks and update ``self.inplanes``."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection so the shortcut matches the block's output shape.
            layers = [nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False)]
            if self.use_bn:
                layers += [nn.BatchNorm2d(planes * block.expansion)]
            downsample = nn.Sequential(*layers)
        layers = [block(self.inplanes, planes, stride, downsample, efficient=self.efficient, use_bn=self.use_bn,
                        separable=self.separable)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers += [block(self.inplanes, planes, efficient=self.efficient, use_bn=self.use_bn,
                             separable=self.separable)]

        return nn.Sequential(*layers)

    def random_init_params(self):
        """Iterate over parameters of the newly initialized modules (SPP and upsampling)."""
        return chain(*[f.parameters() for f in self.random_init])

    def fine_tune_params(self):
        """Iterate over parameters of the encoder modules listed in ``self.fine_tune``."""
        return chain(*[f.parameters() for f in self.fine_tune])

    def forward_resblock(self, x, layers):
        """Run ``x`` through ``layers``; return the final output and the last skip seen.

        Blocks returning a tuple ``(activated, skip)`` are unpacked. ``skip`` stays None
        if no block returns a tuple.
        """
        skip = None
        for layer in layers:
            x = layer(x)
            if isinstance(x, tuple):
                x, skip = x
        return x, skip

    def forward_down(self, image):
        """Encode ``image`` into the feature list consumed by ``forward_up``.

        Normalization is done out of place, so the caller's tensor is left untouched.
        Integer-typed images are promoted to floating point instead of raising.

        Returns:
            list: ``[skip1, skip2, skip3, spp(skip4)]``, ordered finest to coarsest.
        """
        if hasattr(self, 'img_scale'):
            image = image / self.img_scale
        image = (image - self.img_mean) / self.img_std

        x = self.conv1(image)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        features = []
        for stage in (self.layer1, self.layer2, self.layer3):
            x, skip = self.forward_resblock(x, stage)
            features.append(skip)
        x, skip = self.forward_resblock(x, self.layer4)
        # Call the module (not .forward) so registered hooks run.
        features.append(self.spp(skip))
        return features

    def forward_up(self, features):
        """Fuse ``features`` coarse-to-fine through the upsampling modules.

        Returns:
            tuple: the final upsampled tensor and a dict with the reversed ``features``
            and the list of intermediate ``upsamples``.
        """
        features = features[::-1]

        x = features[0]
        if self.detach_upsample_in:
            x = x.detach()

        upsamples = []
        for skip, up in zip(features[1:], self.upsample):
            x = up(x, skip)
            upsamples += [x]
        return x, {'features': features, 'upsamples': upsamples}

    def forward(self, image):
        return self.forward_up(self.forward_down(image))
|
|
|
@register_model
def resnet18(pretrained=True, **kwargs):
    """Build a ResNet-18 backbone (``BasicBlock`` stages of [2, 2, 2, 2]).

    Args:
        pretrained (bool): if True, load the weights at ``model_urls['resnet18']``. Loading
            uses ``strict=False``, so missing or unexpected keys are ignored rather than raising.
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        ResNet: the constructed model.
    """
    net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet18'])
    net.load_state_dict(weights, strict=False)
    return net
|
|
|
|
|
@register_model
def resnet34(pretrained=True, **kwargs):
    """Build a ResNet-34 backbone (``BasicBlock`` stages of [3, 4, 6, 3]).

    Args:
        pretrained (bool): if True, load the weights at ``model_urls['resnet34']``. Loading
            uses ``strict=False``, so missing or unexpected keys are ignored rather than raising.
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        ResNet: the constructed model.
    """
    net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet34'])
    net.load_state_dict(weights, strict=False)
    return net
|
|
|
@register_model
def resnet50(pretrained=True, **kwargs):
    """Build a ResNet-50 backbone (``Bottleneck`` stages of [3, 4, 6, 3]).

    Args:
        pretrained (bool): if True, load the weights at ``model_urls['resnet50']``. Loading
            uses ``strict=False``, so missing or unexpected keys are ignored rather than raising.
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        ResNet: the constructed model.
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet50'])
    net.load_state_dict(weights, strict=False)
    return net
|
|
|
@register_model
def resnet101(pretrained=True, **kwargs):
    """Build a ResNet-101 backbone (``Bottleneck`` stages of [3, 4, 23, 3]).

    Args:
        pretrained (bool): if True, load the weights at ``model_urls['resnet101']``. Loading
            uses ``strict=False``, so missing or unexpected keys are ignored rather than raising.
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        ResNet: the constructed model.
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet101'])
    net.load_state_dict(weights, strict=False)
    return net
|
|
|
@register_model
def resnet152(pretrained=True, **kwargs):
    """Build a ResNet-152 backbone (``Bottleneck`` stages of [3, 8, 36, 3]).

    Args:
        pretrained (bool): if True, load the weights at ``model_urls['resnet152']``. Loading
            uses ``strict=False``, so missing or unexpected keys are ignored rather than raising.
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        ResNet: the constructed model.
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet152'])
    net.load_state_dict(weights, strict=False)
    return net