| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from annotator.uniformer.mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init |
| from annotator.uniformer.mmcv.ops.deform_conv import deform_conv2d |
| from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version |
|
|
|
|
@CONV_LAYERS.register_module(name='SAC')
class SAConv2d(ConvAWS2d):
    """SAC (Switchable Atrous Convolution).

    This is an implementation of SAC in DetectoRS
    (https://arxiv.org/pdf/2006.02334.pdf).

    The forward pass runs the same convolution twice -- once with the
    layer's own padding/dilation and ``weight``, once with both padding and
    dilation multiplied by 3 and ``weight + weight_diff`` -- and blends the
    two results with a per-position ``switch`` map. A global-context term is
    added to the input before, and to the output after, the blend.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        use_deform: If ``True``, replace convolution with deformable
            convolution. Default: ``False``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 use_deform=False):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.use_deform = use_deform
        # 1x1 conv producing the single-channel blending map; it shares the
        # main convolution's stride.
        self.switch = nn.Conv2d(
            self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
        # Additive weight delta used only by the large-dilation branch.
        # Allocated uninitialised here and zeroed in ``init_weights``.
        self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
        # 1x1 convs applied to globally average-pooled features.
        self.pre_context = nn.Conv2d(
            self.in_channels, self.in_channels, kernel_size=1, bias=True)
        self.post_context = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=1, bias=True)
        if self.use_deform:
            # One offset predictor per branch (small / large dilation).
            # NOTE(review): 18 output channels looks like 2 * 3 * 3, i.e. it
            # presumably assumes a 3x3 main kernel -- confirm before using
            # ``use_deform`` with another kernel size.
            self.offset_s = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
            self.offset_l = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
        self.init_weights()

    def init_weights(self):
        """Initialise the SAC-specific sub-modules with constants.

        ``weight_diff`` is set to zero. The remaining modules go through
        ``constant_init`` (presumably filling the weight with the given
        value and the bias with ``bias`` -- verify against its definition).
        """
        constant_init(self.switch, 0, bias=1)
        self.weight_diff.data.zero_()
        constant_init(self.pre_context, 0)
        constant_init(self.post_context, 0)
        if self.use_deform:
            constant_init(self.offset_s, 0)
            constant_init(self.offset_l, 0)

    def _sac_conv(self, x, weight, zero_bias, offset=None):
        """Run one convolution branch using the current padding/dilation.

        Args:
            x (Tensor): Input feature map.
            weight (Tensor): Convolution weight for this branch.
            zero_bias (Tensor): All-zero bias handed to ``_conv_forward`` on
                torch >= 1.8.0, where the call passes a bias explicitly.
            offset (Tensor, optional): Sampling offsets; only read when
                ``self.use_deform`` is true.

        Returns:
            Tensor: Output of the (deformable) convolution.
        """
        if self.use_deform:
            return deform_conv2d(x, offset, weight, self.stride, self.padding,
                                 self.dilation, self.groups, 1)
        if (TORCH_VERSION == 'parrots'
                or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
            return super().conv2d_forward(x, weight)
        if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
            # NOTE(review): an explicit zero bias is passed here, so
            # ``self.bias`` is not added on this path; the older-torch calls
            # pass no bias and defer to the parent -- confirm the paths agree.
            return super()._conv_forward(x, weight, zero_bias)
        return super()._conv_forward(x, weight)

    def forward(self, x):
        """Apply switchable atrous convolution to ``x``.

        Args:
            x (Tensor): Input feature map. The reflect padding and 2-D
                pooling below imply a 4-D input -- presumably (N, C, H, W).

        Returns:
            Tensor: Blend of the two convolution branches plus the
            post-context term.
        """
        # pre-context: add a 1x1-conv projection of the global average.
        avg_x = F.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x

        # switch: 5x5 average pooling (reflect-padded by 2 on each side so
        # the spatial size is kept) followed by the 1x1 switch conv.
        avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
        avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)

        # sac: small-dilation branch with the layer's own settings.
        weight = self._get_weight(self.weight)
        zero_bias = torch.zeros(
            self.out_channels, device=weight.device, dtype=weight.dtype)
        offset = self.offset_s(avg_x) if self.use_deform else None
        out_s = self._sac_conv(x, weight, zero_bias, offset)

        # Large-dilation branch: the parent's conv helpers read
        # ``self.padding`` / ``self.dilation``, so they are overridden
        # temporarily. ``finally`` guarantees the module is restored even
        # if this branch raises, so later calls do not see tripled values.
        ori_p = self.padding
        ori_d = self.dilation
        try:
            self.padding = tuple(3 * p for p in ori_p)
            self.dilation = tuple(3 * d for d in ori_d)
            weight = weight + self.weight_diff
            offset = self.offset_l(avg_x) if self.use_deform else None
            out_l = self._sac_conv(x, weight, zero_bias, offset)
        finally:
            self.padding = ori_p
            self.dilation = ori_d

        out = switch * out_s + (1 - switch) * out_l

        # post-context: same global-average trick on the output.
        avg_x = F.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return out
|
|