import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from proard.utils import (
    get_same_padding,
    sub_filter_start_end,
    make_divisible,
    SEModule,
    MyNetwork,
    MyConv2d,
)

__all__ = [
    "DynamicSeparableConv2d",
    "DynamicConv2d",
    "DynamicGroupConv2d",
    "DynamicBatchNorm2d",
    "DynamicGroupNorm",
    "DynamicSE",
    "DynamicLinear",
]


class DynamicSeparableConv2d(nn.Module):
    KERNEL_TRANSFORM_MODE = 1  # None or 1

    def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
        super(DynamicSeparableConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.kernel_size_list = kernel_size_list
        self.stride = stride
        self.dilation = dilation

        # Depthwise convolution holding the weights of the largest kernel;
        # smaller active kernels are sliced out of its center at runtime.
        self.conv = nn.Conv2d(
            self.max_in_channels,
            self.max_in_channels,
            max(self.kernel_size_list),
            self.stride,
            groups=self.max_in_channels,
            bias=False,
        )

        self._ks_set = list(set(self.kernel_size_list))
        self._ks_set.sort()  # e.g., [3, 5, 7]
        if self.KERNEL_TRANSFORM_MODE is not None:
            # Register one learnable (k_small^2 x k_small^2) transformation
            # matrix per adjacent kernel pair (e.g., 7to5_matrix, 5to3_matrix),
            # initialized to identity.
            scale_params = {}
            for i in range(len(self._ks_set) - 1):
                ks_small = self._ks_set[i]
                ks_larger = self._ks_set[i + 1]
                param_name = "%dto%d" % (ks_larger, ks_small)
                scale_params["%s_matrix" % param_name] = Parameter(
                    torch.eye(ks_small ** 2)
                )
            for name, param in scale_params.items():
                self.register_parameter(name, param)

        self.active_kernel_size = max(self.kernel_size_list)

    def get_active_filter(self, in_channel, kernel_size):
        out_channel = in_channel
        max_kernel_size = max(self.kernel_size_list)

        start, end = sub_filter_start_end(max_kernel_size, kernel_size)
        filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
        if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
            # Progressively shrink the full-size kernel: at each step, crop the
            # center sub-filter and pass it through the learned linear transform.
            start_filter = self.conv.weight[:out_channel, :in_channel, :, :]
            for i in range(len(self._ks_set) - 1, 0, -1):
                src_ks = self._ks_set[i]
                if src_ks <= kernel_size:
                    break
                target_ks = self._ks_set[i - 1]
                start, end = sub_filter_start_end(src_ks, target_ks)
                _input_filter = start_filter[:, :, start:end, start:end]
                _input_filter = _input_filter.contiguous()
                _input_filter = _input_filter.view(
                    _input_filter.size(0), _input_filter.size(1), -1
                )
                _input_filter = _input_filter.view(-1, _input_filter.size(2))
                _input_filter = F.linear(
                    _input_filter,
                    self.__getattr__("%dto%d_matrix" % (src_ks, target_ks)),
                )
                _input_filter = _input_filter.view(
                    filters.size(0), filters.size(1), target_ks ** 2
                )
                _input_filter = _input_filter.view(
                    filters.size(0), filters.size(1), target_ks, target_ks
                )
                start_filter = _input_filter
            filters = start_filter
        return filters

    def forward(self, x, kernel_size=None):
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        in_channel = x.size(1)

        filters = self.get_active_filter(in_channel, kernel_size).contiguous()

        padding = get_same_padding(kernel_size)
        filters = (
            self.conv.weight_standardization(filters)
            if isinstance(self.conv, MyConv2d)
            else filters
        )
        y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, in_channel)
        return y
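
# Minimal usage sketch for DynamicSeparableConv2d (illustrative sizes, not from
# this repo): a single 7x7 depthwise weight serves 3x3/5x5/7x7 sub-kernels.
#
#     conv = DynamicSeparableConv2d(max_in_channels=64, kernel_size_list=[3, 5, 7])
#     x = torch.randn(2, 48, 32, 32)   # any in_channel <= 64; it is read from x
#     out = conv(x, kernel_size=5)     # center 5x5 crop, mapped by 7to5_matrix
#     assert out.shape == x.shape      # stride=1 with "same" padding

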
class DynamicConv2d(nn.Module):
    def __init__(
        self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1
    ):
        super(DynamicConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.max_out_channels = max_out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.max_in_channels,
            self.max_out_channels,
            self.kernel_size,
            stride=self.stride,
            bias=False,
        )

        self.active_out_channel = self.max_out_channels

    def get_active_filter(self, out_channel, in_channel):
        # Elastic width: take the leading slice of the full weight tensor.
        return self.conv.weight[:out_channel, :in_channel, :, :]

    def forward(self, x, out_channel=None):
        if out_channel is None:
            out_channel = self.active_out_channel
        in_channel = x.size(1)
        filters = self.get_active_filter(out_channel, in_channel).contiguous()

        padding = get_same_padding(self.kernel_size)
        filters = (
            self.conv.weight_standardization(filters)
            if isinstance(self.conv, MyConv2d)
            else filters
        )
        y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, 1)
        return y
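
# Usage sketch (hypothetical sizes): one max-sized pointwise conv serves every
# narrower width by slicing its weight.
#
#     conv = DynamicConv2d(max_in_channels=64, max_out_channels=128)
#     conv.active_out_channel = 96              # select a sub-network width
#     out = conv(torch.randn(2, 40, 16, 16))    # in_channel inferred from input
#     assert out.shape == (2, 96, 16, 16)

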
class DynamicGroupConv2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size_list,
        groups_list,
        stride=1,
        dilation=1,
    ):
        super(DynamicGroupConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size_list = kernel_size_list
        self.groups_list = groups_list
        self.stride = stride
        self.dilation = dilation

        # Weights are stored for the smallest group count (the largest
        # per-group filter); larger group counts crop from these weights.
        self.conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            max(self.kernel_size_list),
            self.stride,
            groups=min(self.groups_list),
            bias=False,
        )

        self.active_kernel_size = max(self.kernel_size_list)
        self.active_groups = min(self.groups_list)

    def get_active_filter(self, kernel_size, groups):
        start, end = sub_filter_start_end(max(self.kernel_size_list), kernel_size)
        filters = self.conv.weight[:, :, start:end, start:end]

        # Split the output channels into `groups` chunks and crop, for each
        # chunk, the slice of input channels that its group should see.
        sub_filters = torch.chunk(filters, groups, dim=0)
        sub_in_channels = self.in_channels // groups
        sub_ratio = filters.size(1) // sub_in_channels

        filter_crops = []
        for i, sub_filter in enumerate(sub_filters):
            part_id = i % sub_ratio
            start = part_id * sub_in_channels
            filter_crops.append(sub_filter[:, start : start + sub_in_channels, :, :])
        filters = torch.cat(filter_crops, dim=0)
        return filters

    def forward(self, x, kernel_size=None, groups=None):
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        if groups is None:
            groups = self.active_groups

        filters = self.get_active_filter(kernel_size, groups).contiguous()
        padding = get_same_padding(kernel_size)
        filters = (
            self.conv.weight_standardization(filters)
            if isinstance(self.conv, MyConv2d)
            else filters
        )
        y = F.conv2d(
            x,
            filters,
            None,
            self.stride,
            padding,
            self.dilation,
            groups,
        )
        return y
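
# Usage sketch (hypothetical sizes): the stored weight covers the smallest
# group count; requesting more groups crops per-group input slices.
#
#     conv = DynamicGroupConv2d(
#         in_channels=64, out_channels=64, kernel_size_list=[3, 5], groups_list=[2, 4]
#     )
#     out = conv(torch.randn(2, 64, 28, 28), kernel_size=3, groups=4)
#     assert out.shape == (2, 64, 28, 28)

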
class DynamicBatchNorm2d(nn.Module):
    SET_RUNNING_STATISTICS = False

    def __init__(self, max_feature_dim):
        super(DynamicBatchNorm2d, self).__init__()

        self.max_feature_dim = max_feature_dim
        self.bn = nn.BatchNorm2d(self.max_feature_dim)

    @staticmethod
    def bn_forward(x, bn: nn.BatchNorm2d, feature_dim):
        if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS:
            return bn(x)
        else:
            # Replicate nn.BatchNorm2d's momentum handling, then normalize with
            # the first `feature_dim` entries of the stats and affine params.
            exponential_average_factor = 0.0

            if bn.training and bn.track_running_stats:
                if bn.num_batches_tracked is not None:
                    bn.num_batches_tracked += 1
                    if bn.momentum is None:  # use cumulative moving average
                        exponential_average_factor = 1.0 / float(
                            bn.num_batches_tracked
                        )
                    else:  # use exponential moving average
                        exponential_average_factor = bn.momentum
            return F.batch_norm(
                x,
                bn.running_mean[:feature_dim],
                bn.running_var[:feature_dim],
                bn.weight[:feature_dim],
                bn.bias[:feature_dim],
                bn.training or not bn.track_running_stats,
                exponential_average_factor,
                bn.eps,
            )

    def forward(self, x):
        feature_dim = x.size(1)
        y = self.bn_forward(x, self.bn, feature_dim)
        return y
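
# Usage sketch: one BN layer sized for the widest setting normalizes any
# narrower input by slicing its running stats and affine parameters.
#
#     bn = DynamicBatchNorm2d(max_feature_dim=64)
#     out = bn(torch.randn(2, 48, 14, 14))   # uses bn.bn buffers [:48]
#     assert out.shape == (2, 48, 14, 14)

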
class DynamicGroupNorm(nn.GroupNorm):
    def __init__(
        self, num_groups, num_channels, eps=1e-5, affine=True, channel_per_group=None
    ):
        super(DynamicGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
        self.channel_per_group = channel_per_group

    def forward(self, x):
        # The group count scales with the active width so that
        # channel_per_group stays constant.
        n_channels = x.size(1)
        n_groups = n_channels // self.channel_per_group
        return F.group_norm(
            x, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps
        )

    @property
    def bn(self):
        # Compatibility alias so this layer can stand in where a
        # DynamicBatchNorm2d's `.bn` attribute is expected.
        return self
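
# Usage sketch (illustrative sizes):
#
#     gn = DynamicGroupNorm(num_groups=8, num_channels=64, channel_per_group=8)
#     out = gn(torch.randn(2, 32, 14, 14))   # 32 channels -> 4 groups of 8
#     assert out.shape == (2, 32, 14, 14)

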
class DynamicSE(SEModule):
    def __init__(self, max_channel):
        super(DynamicSE, self).__init__(max_channel)

    def get_active_reduce_weight(self, num_mid, in_channel, groups=None):
        if groups is None or groups == 1:
            return self.fc.reduce.weight[:num_mid, :in_channel, :, :]
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_filters = torch.chunk(
                self.fc.reduce.weight[:num_mid, :, :, :], groups, dim=1
            )
            return torch.cat(
                [sub_filter[:, :sub_in_channels, :, :] for sub_filter in sub_filters],
                dim=1,
            )

    def get_active_reduce_bias(self, num_mid):
        return (
            self.fc.reduce.bias[:num_mid] if self.fc.reduce.bias is not None else None
        )

    def get_active_expand_weight(self, num_mid, in_channel, groups=None):
        if groups is None or groups == 1:
            return self.fc.expand.weight[:in_channel, :num_mid, :, :]
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_filters = torch.chunk(
                self.fc.expand.weight[:, :num_mid, :, :], groups, dim=0
            )
            return torch.cat(
                [sub_filter[:sub_in_channels, :, :, :] for sub_filter in sub_filters],
                dim=0,
            )

    def get_active_expand_bias(self, in_channel, groups=None):
        if groups is None or groups == 1:
            return (
                self.fc.expand.bias[:in_channel]
                if self.fc.expand.bias is not None
                else None
            )
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_bias_list = torch.chunk(self.fc.expand.bias, groups, dim=0)
            return torch.cat(
                [sub_bias[:sub_in_channels] for sub_bias in sub_bias_list], dim=0
            )

    def forward(self, x, groups=None):
        in_channel = x.size(1)
        num_mid = make_divisible(
            in_channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE
        )

        # Squeeze: global average pooling over the spatial dimensions.
        y = x.mean(3, keepdim=True).mean(2, keepdim=True)

        # Excite: reduce -> ReLU -> expand -> h-sigmoid, with the 1x1 conv
        # weights sliced to the active channel count.
        reduce_filter = self.get_active_reduce_weight(
            num_mid, in_channel, groups=groups
        ).contiguous()
        reduce_bias = self.get_active_reduce_bias(num_mid)
        y = F.conv2d(y, reduce_filter, reduce_bias, 1, 0, 1, 1)

        y = self.fc.relu(y)

        expand_filter = self.get_active_expand_weight(
            num_mid, in_channel, groups=groups
        ).contiguous()
        expand_bias = self.get_active_expand_bias(in_channel, groups=groups)
        y = F.conv2d(y, expand_filter, expand_bias, 1, 0, 1, 1)

        y = self.fc.h_sigmoid(y)

        return x * y
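
# Usage sketch (assumes SEModule's default reduction ratio of 4 and
# MyNetwork.CHANNEL_DIVISIBLE, both defined in proard.utils):
#
#     se = DynamicSE(max_channel=64)
#     out = se(torch.randn(2, 32, 14, 14))   # num_mid = make_divisible(32 // 4)
#     assert out.shape == (2, 32, 14, 14)

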
class DynamicLinear(nn.Module):
    def __init__(self, max_in_features, max_out_features, bias=True):
        super(DynamicLinear, self).__init__()

        self.max_in_features = max_in_features
        self.max_out_features = max_out_features
        self.bias = bias

        self.linear = nn.Linear(self.max_in_features, self.max_out_features, self.bias)

        self.active_out_features = self.max_out_features

    def get_active_weight(self, out_features, in_features):
        return self.linear.weight[:out_features, :in_features]

    def get_active_bias(self, out_features):
        return self.linear.bias[:out_features] if self.bias else None

    def forward(self, x, out_features=None):
        if out_features is None:
            out_features = self.active_out_features

        in_features = x.size(1)
        weight = self.get_active_weight(out_features, in_features).contiguous()
        bias = self.get_active_bias(out_features)
        y = F.linear(x, weight, bias)
        return y
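
# Usage sketch (illustrative sizes): the classifier head shrinks with the
# active feature width.
#
#     fc = DynamicLinear(max_in_features=1536, max_out_features=1000)
#     out = fc(torch.randn(2, 1024))   # in_features inferred from the input
#     assert out.shape == (2, 1000)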