| import math | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn.modules.utils import _pair | |
| from ..functions.deform_conv import deform_conv, modulated_deform_conv | |
class DeformConv(nn.Module):
    """Deformable 2-D convolution layer whose offsets are supplied by the caller.

    The layer owns only the convolution weight; the sampling offsets are
    passed to ``forward`` alongside the input. A bias term is not supported.

    Args:
        in_channels: Number of input channels; must be divisible by ``groups``.
        out_channels: Number of output channels; must be divisible by
            ``groups``.
        kernel_size: Kernel size, an int or a pair (normalised with ``_pair``).
        stride: Stride, an int or a pair (normalised with ``_pair``).
        padding: Padding, an int or a pair (normalised with ``_pair``).
        dilation: Dilation, an int or a pair (normalised with ``_pair``).
        groups: Number of convolution groups.
        deformable_groups: Number of deformable groups, forwarded unchanged
            to ``deform_conv``.
        bias: Must be falsy; any truthy value trips an assertion.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, deformable_groups=1,
                 bias=False):
        super(DeformConv, self).__init__()

        # Argument checks, in the same order as they have always been made.
        assert not bias
        assert in_channels % groups == 0, \
            'in_channels {} cannot be divisible by groups {}'.format(
                in_channels, groups)
        assert out_channels % groups == 0, \
            'out_channels {} cannot be divisible by groups {}'.format(
                out_channels, groups)

        # Channel / grouping configuration.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.deformable_groups = deformable_groups

        # Spatial hyper-parameters are always stored as pairs.
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)

        # The tensor is allocated uninitialised here and filled in by
        # reset_parameters() below.
        weight_shape = (out_channels, in_channels // self.groups) + tuple(
            self.kernel_size)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))

        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight with U(-b, b), b = 1 / sqrt(in_channels * prod(kernel_size))."""
        # Note: the scale uses the full in_channels, not in_channels // groups.
        scale = self.in_channels
        for extent in self.kernel_size:
            scale *= extent
        bound = 1. / math.sqrt(scale)
        self.weight.data.uniform_(-bound, bound)

    def forward(self, x, offset):
        """Apply ``deform_conv`` to ``x`` using the caller-provided ``offset``.

        The input, the offsets, this module's weight and its stored
        stride / padding / dilation / groups / deformable_groups are handed
        to ``deform_conv`` in that order; its result is returned unchanged.
        """
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)
class DeformConvPack(DeformConv):
    """``DeformConv`` that predicts its own offsets from the input.

    An ordinary ``nn.Conv2d`` (``conv_offset``) is attached to the layer. It
    maps the input to ``deformable_groups * 2 * kernel_h * kernel_w`` channels
    that are used as the offsets for the deformable convolution. Constructor
    arguments are exactly those of ``DeformConv``.
    """

    def __init__(self, *args, **kwargs):
        super(DeformConvPack, self).__init__(*args, **kwargs)

        # The offset branch must mirror the geometry of the main convolution
        # (kernel, stride, padding AND dilation). nn.Conv2d's output size
        # depends on dilation, so leaving it out made the offset map a
        # different spatial size whenever dilation != 1.
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 2 * self.kernel_size[0] *
            self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_offset()

    def init_offset(self):
        """Zero the offset branch so that all predicted offsets start at 0."""
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        """Predict offsets from ``x`` and run ``deform_conv`` with them."""
        offset = self.conv_offset(x)
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)
class ModulatedDeformConv(nn.Module):
    """Modulated deformable 2-D convolution with caller-supplied offsets and mask.

    The layer owns the convolution weight and an optional bias; the sampling
    offsets and the modulation mask are passed to ``forward``.

    Args:
        in_channels: Number of input channels; must be divisible by ``groups``.
        out_channels: Number of output channels; must be divisible by
            ``groups``.
        kernel_size: Kernel size, an int or a pair (normalised with ``_pair``).
        stride: Stride, stored exactly as given.
        padding: Padding, stored exactly as given.
        dilation: Dilation, stored exactly as given.
        groups: Number of convolution groups.
        deformable_groups: Number of deformable groups, forwarded unchanged
            to ``modulated_deform_conv``.
        bias: If truthy, a learnable bias of size ``out_channels`` is created
            (initialised to zero); otherwise ``self.bias`` is ``None``.

    Raises:
        ValueError: If ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConv, self).__init__()

        # Without these checks `in_channels // groups` below silently floors
        # and produces a weight whose shape does not match the input.
        if in_channels % groups != 0:
            raise ValueError(
                'in_channels {} must be divisible by groups {}'.format(
                    in_channels, groups))
        if out_channels % groups != 0:
            raise ValueError(
                'out_channels {} must be divisible by groups {}'.format(
                    out_channels, groups))

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        # NOTE(review): unlike kernel_size these are intentionally NOT passed
        # through _pair -- they are handed to modulated_deform_conv as-is, and
        # whether that function accepts tuples is not visible from here.
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias

        # Allocated uninitialised; reset_parameters() fills the values.
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            # Keep the attribute present (as None) so forward() can always
            # pass self.bias.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the weight uniformly and zero the bias, if there is one.

        The weight is drawn from U(-b, b) with
        ``b = 1 / sqrt(in_channels * prod(kernel_size))``.
        """
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x, offset, mask):
        """Apply ``modulated_deform_conv`` with the given ``offset`` and ``mask``.

        The input, offsets, mask, this module's weight and bias, and the
        stored stride / padding / dilation / groups / deformable_groups are
        passed through in that order; the result is returned unchanged.
        """
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
    """``ModulatedDeformConv`` that predicts its own offsets and mask.

    An ordinary ``nn.Conv2d`` (``conv_offset_mask``) is attached to the layer.
    It maps the input to ``deformable_groups * 3 * kernel_h * kernel_w``
    channels, which ``forward`` splits into offsets and a modulation mask.
    Constructor arguments are exactly those of ``ModulatedDeformConv``.
    """

    def __init__(self, *args, **kwargs):
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)

        # The offset/mask branch must mirror the geometry of the main
        # convolution (kernel, stride, padding AND dilation). nn.Conv2d's
        # output size depends on dilation, so leaving it out made the
        # predicted maps a different spatial size whenever dilation != 1.
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] *
            self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_offset()

    def init_offset(self):
        """Zero the offset/mask branch so its initial output is all zeros."""
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, x):
        """Predict offsets and mask from ``x`` and run the modulated conv."""
        out = self.conv_offset_mask(x)
        # Split the prediction into three equal channel chunks: the first two
        # together form the offsets, the third is the raw mask.
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        # Squash the mask into (0, 1) before it is used for modulation.
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)