Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch.nn as nn | |
| from mmcv.cnn import build_plugin_layer | |
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    With ``stride=1`` the spatial size of the input is preserved.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Convolution stride. Defaults to 1.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` without padding.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Convolution stride. Defaults to 1, which reproduces the
            previous fixed behaviour; exposed for consistency with
            ``conv3x3``.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        padding=0,
        bias=False)
class BasicBlock(nn.Module):
    """Residual block of two convolutions with optional plugin layers.

    Main branch: conv1 -> bn1 -> relu -> conv2 -> bn2. Its output is added
    to the residual (the block input, or ``downsample(input)`` when a
    downsample module is given) and passed through a final ReLU. Plugin
    layers built from ``plugins`` can be inserted before conv1, after conv1
    (after its ReLU), after conv2 (after bn2) and after the shortcut
    addition (after the final ReLU).

    Args:
        inplanes (int): Number of input channels.
        planes (int): Output channels of conv1; the block outputs
            ``planes * expansion`` channels.
        stride (int): Convolution stride. Applied by conv2 when
            ``use_conv1x1`` is True, otherwise by conv1.
        downsample (nn.Module, optional): Applied to the block input to
            produce the residual. Defaults to None (input used as is).
        use_conv1x1 (bool): Use a 1x1 conv for conv1 instead of a 3x3 conv.
        plugins (dict | list[dict], optional): Plugin specs. Each dict needs
            a ``'cfg'`` entry (passed to ``build_plugin_layer``) and a
            ``'position'`` entry, one of ``'before_conv1'``,
            ``'after_conv1'``, ``'after_conv2'``, ``'after_shortcut'``.

    Raises:
        ValueError: If a plugin has a ``'position'`` not listed above.
    """
    expansion = 1
    # Insertion points understood by ``forward``.
    _plugin_positions = ('before_conv1', 'after_conv1', 'after_conv2',
                         'after_shortcut')

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 use_conv1x1=False,
                 plugins=None):
        super(BasicBlock, self).__init__()

        self.with_plugins = False
        if plugins:
            if isinstance(plugins, dict):
                plugins = [plugins]
            # Previously an unrecognised position was silently ignored,
            # which dropped the plugin without any signal.
            unknown = [
                plugin['position'] for plugin in plugins
                if plugin['position'] not in self._plugin_positions
            ]
            if unknown:
                raise ValueError(
                    f'unknown plugin position(s) {unknown}; expected one '
                    f'of {self._plugin_positions}')
            self.with_plugins = True
            # Collect plugin configs per insertion point.
            self.before_conv1_plugin = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'before_conv1'
            ]
            self.after_conv1_plugin = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv1'
            ]
            self.after_conv2_plugin = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv2'
            ]
            self.after_shortcut_plugin = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_shortcut'
            ]

        if use_conv1x1:
            self.conv1 = conv1x1(inplanes, planes)
            self.conv2 = conv3x3(planes, planes * self.expansion, stride)
        else:
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.conv2 = conv3x3(planes, planes * self.expansion)

        self.planes = planes
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm2d(planes * self.expansion)
        self.downsample = downsample
        self.stride = stride

        if self.with_plugins:
            # Each plugin maps C channels to C channels, so it is built with
            # the channel count of the tensor it is applied to.
            self.before_conv1_plugin_names = self.make_block_plugins(
                inplanes, self.before_conv1_plugin)
            self.after_conv1_plugin_names = self.make_block_plugins(
                planes, self.after_conv1_plugin)
            self.after_conv2_plugin_names = self.make_block_plugins(
                planes * self.expansion, self.after_conv2_plugin)
            self.after_shortcut_plugin_names = self.make_block_plugins(
                planes * self.expansion, self.after_shortcut_plugin)

    def make_block_plugins(self, in_channels, plugins):
        """Build plugin layers and register them as submodules.

        Args:
            in_channels (int): Input (and output) channels of each plugin.
            plugins (list[dict]): Plugin configs to build. An optional
                ``'postfix'`` key is forwarded to ``build_plugin_layer``.

        Returns:
            list[str]: Names under which the plugins were registered, in
            the order given.
        """
        assert isinstance(plugins, list)
        plugin_names = []
        for plugin in plugins:
            # Shallow copy so removing 'postfix' does not mutate the caller's
            # config.
            plugin = plugin.copy()
            postfix = plugin.pop('postfix', '')
            name, layer = build_plugin_layer(
                plugin,
                in_channels=in_channels,
                out_channels=in_channels,
                postfix=postfix)
            assert not hasattr(self, name), f'duplicate plugin {name}'
            self.add_module(name, layer)
            plugin_names.append(name)
        return plugin_names

    def forward_plugin(self, x, plugin_names):
        """Apply the named plugins to ``x`` one after another.

        Each plugin receives the previous plugin's output. The original code
        fed every plugin the unmodified ``x``, so only the last plugin had
        any effect.

        Args:
            x: Input passed to the first plugin.
            plugin_names (list[str]): Attribute names of the plugins.

        Returns:
            The output of the last plugin, or ``x`` if the list is empty.
        """
        out = x
        for name in plugin_names:
            out = getattr(self, name)(out)
        return out

    def forward(self, x):
        """Run the residual block (see class docstring for the layout)."""
        if self.with_plugins:
            x = self.forward_plugin(x, self.before_conv1_plugin_names)
        # The residual is taken after the before_conv1 plugins.
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        if self.with_plugins:
            out = self.forward_plugin(out, self.after_conv1_plugin_names)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.with_plugins:
            out = self.forward_plugin(out, self.after_conv2_plugin_names)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        if self.with_plugins:
            out = self.forward_plugin(out, self.after_shortcut_plugin_names)

        return out
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce, 3x3 (strided), 1x1 expand.

    The block outputs ``planes * expansion`` channels. The shortcut is a
    strided 1x1 conv followed by BatchNorm when ``downsample`` is truthy;
    otherwise it is an empty ``nn.Sequential``, which returns its input
    unchanged.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Width of the inner 1x1 and 3x3 convolutions.
        stride (int): Stride of the 3x3 conv and of the projection shortcut.
        downsample (bool): Whether to build the projection shortcut.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=False):
        super().__init__()
        out_planes = planes * self.expansion

        # Main branch. Submodules are registered in a fixed order so that
        # state_dict keys and parameter initialisation order stay the same.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

        # Shortcut: a projection when requested, else an empty (identity)
        # Sequential.
        shortcut_layers = []
        if downsample:
            shortcut_layers = [
                nn.Conv2d(
                    inplanes, out_planes, kernel_size=1, stride=stride,
                    bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.downsample = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        """Return ``relu(main_branch(x) + downsample(x))``."""
        identity = self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y += identity
        return self.relu(y)