"""Deformable 2-D convolution with a learned modulation mask (torchvision-backed)."""

import torch
import torch.nn as nn
from torchvision.ops import deform_conv2d


def _to_pair(value, name):
    """Normalize an int or a 2-element tuple/list to a 2-tuple.

    Args:
        value: An int, or a tuple/list with exactly two elements.
        name: Argument name used in error messages.

    Returns:
        A tuple ``(a, b)``.

    Raises:
        TypeError: If ``value`` is neither an int nor a tuple/list.
        ValueError: If ``value`` is a tuple/list without exactly 2 elements.
    """
    if isinstance(value, (tuple, list)):
        if len(value) != 2:
            raise ValueError(f"{name} must have 2 elements, got {len(value)}")
        return tuple(value)
    # bool is a subclass of int; reject it explicitly, as the original
    # exact-type check did.
    if isinstance(value, int) and not isinstance(value, bool):
        return (value, value)
    raise TypeError(
        f"{name} must be an int or a tuple of two ints, "
        f"got {type(value).__name__}"
    )


class DeformableConv2d(nn.Module):
    """2-D deformable convolution whose offsets and mask are predicted from the input.

    Two auxiliary convolutions run on the input:

    - one predicts 2 offset values per kernel tap;
    - one predicts 1 modulation logit per kernel tap.

    Both produce these at every output location. ``torchvision.ops.deform_conv2d``
    then applies the weight and bias of ``regular_conv`` at the offset sampling
    positions, scaled by the modulation mask.

    Both auxiliary convolutions are zero-initialized. At construction the
    offsets are therefore 0 and the mask is ``2 * sigmoid(0) = 1``, so the
    layer initially computes the same result as ``regular_conv``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, an int or a 2-element tuple/list.
        stride: Stride, an int or a 2-element tuple/list. Stored as a tuple
            in ``self.stride``.
        padding: Padding, passed unchanged to every convolution.
        bias: Whether ``regular_conv`` (and hence the output) has a bias.

    Raises:
        TypeError / ValueError: If ``kernel_size`` or ``stride`` is malformed.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
    ):
        super().__init__()
        kernel_size = _to_pair(kernel_size, "kernel_size")
        self.stride = _to_pair(stride, "stride")
        self.padding = padding
        taps = kernel_size[0] * kernel_size[1]

        # Creation order (offset, modulator, regular) matches the original
        # implementation so RNG consumption and state_dict keys are unchanged.
        # Offset predictor: 2 values per kernel tap, one per spatial axis.
        self.offset_conv = self._make_aux_conv(in_channels, 2 * taps, kernel_size)
        # Modulation predictor: 1 logit per kernel tap.
        self.modulator_conv = self._make_aux_conv(in_channels, taps, kernel_size)

        # Only its weight/bias are used (by deform_conv2d in forward); the
        # module's own forward is never called.
        self.regular_conv = nn.Conv2d(
            in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=bias,
        )

    def _make_aux_conv(self, in_channels, out_channels, kernel_size):
        """Build a zero-initialized conv whose output grid matches ``regular_conv``'s."""
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=True,
        )
        # Zero init => offsets start at 0 and the mask at 2 * sigmoid(0) = 1.
        nn.init.constant_(conv.weight, 0.0)
        nn.init.constant_(conv.bias, 0.0)
        return conv

    def forward(self, x):
        """Apply the modulated deformable convolution to ``x``.

        Args:
            x: Input tensor; presumably ``(N, in_channels, H, W)`` as required
                by ``nn.Conv2d`` / ``deform_conv2d`` (TODO confirm with callers).

        Returns:
            The ``deform_conv2d`` output, with ``out_channels`` channels.
        """
        offset = self.offset_conv(x)
        # 2 * sigmoid maps logits into (0, 2), with 0 -> 1 (no modulation).
        modulator = 2.0 * torch.sigmoid(self.modulator_conv(x))
        return deform_conv2d(
            input=x,
            offset=offset,
            weight=self.regular_conv.weight,
            bias=self.regular_conv.bias,
            padding=self.padding,
            mask=modulator,
            stride=self.stride,
        )