|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import init as init |
|
|
from torch.nn.modules.utils import _pair, _single |
|
|
import math |
|
|
|
|
|
class ModulatedDeformConv2d(nn.Module):
    """Parameter holder for a modulated deformable 2D convolution.

    The module owns the convolution ``weight`` (and optional ``bias``) and
    records the convolution hyper-parameters as attributes. ``forward`` in
    this class is a stub that performs no computation.

    Args:
        in_channels: Number of input channels. Must be divisible by
            ``groups``.
        out_channels: Number of output channels. Must be divisible by
            ``groups``.
        kernel_size: Kernel size; expanded to a 2-tuple with ``_pair``.
        stride: Stored unchanged on ``self.stride``.
        padding: Stored unchanged on ``self.padding``.
        dilation: Stored unchanged on ``self.dilation``.
        groups: Number of channel groups; the weight has
            ``in_channels // groups`` input channels per filter.
        deform_groups: Stored unchanged on ``self.deform_groups``.
        bias: If true, a learnable bias of length ``out_channels`` is
            created; otherwise ``self.bias`` is registered as ``None``.

    Raises:
        ValueError: If ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deform_groups=1,
                 bias=True):
        super().__init__()

        # Without these checks the floor division below would silently drop
        # channels and build a weight that does not match the requested
        # grouping.
        if in_channels % groups != 0:
            raise ValueError(
                f'in_channels ({in_channels}) must be divisible by '
                f'groups ({groups})')
        if out_channels % groups != 0:
            raise ValueError(
                f'out_channels ({out_channels}) must be divisible by '
                f'groups ({groups})')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        # NOTE: unlike kernel_size, these three are kept exactly as passed
        # (not expanded to pairs).
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deform_groups = deform_groups
        self.with_bias = bias

        # NOTE(review): presumably mirrors attributes that nn.Conv2d exposes
        # so that code inspecting conv layers also works here -- confirm.
        self.transposed = False
        self.output_padding = _single(0)

        # Allocated uninitialised; init_weights() fills the values below.
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups,
                        *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.init_weights()

    def init_weights(self):
        """Reset the parameters in place.

        ``weight`` is drawn from ``U(-stdv, stdv)`` with
        ``stdv = 1 / sqrt(in_channels * kH * kW)``; ``bias`` (if any) is
        zeroed. If the instance has a ``conv_offset`` attribute, its weight
        and (when present) its bias are zeroed too.
        """
        fan_in = self.in_channels
        for k in self.kernel_size:
            fan_in *= k
        stdv = 1. / math.sqrt(fan_in)

        init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            init.zeros_(self.bias)

        # `conv_offset` is not created in this class; it is only reset when
        # something else has attached it to the instance.
        if hasattr(self, 'conv_offset'):
            init.zeros_(self.conv_offset.weight)
            # An offset conv built with bias=False has bias None.
            if self.conv_offset.bias is not None:
                init.zeros_(self.conv_offset.bias)

    def forward(self, x, offset, mask):
        """Stub: ignores ``x``, ``offset`` and ``mask`` and returns ``None``.

        NOTE(review): presumably the real deformable-convolution call is
        meant to be supplied by a subclass or a replaced implementation --
        confirm.
        """
        pass