# (Hugging Face upload metadata: uploaded via huggingface_hub, commit 188f311 verified)
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.nn.parameter import Parameter
from proard.utils import (
get_same_padding,
sub_filter_start_end,
make_divisible,
SEModule,
MyNetwork,
MyConv2d,
)
# Public API of this module: the dynamic (elastic) layer variants.
__all__ = [
    "DynamicSeparableConv2d",
    "DynamicConv2d",
    "DynamicGroupConv2d",
    "DynamicBatchNorm2d",
    "DynamicGroupNorm",
    "DynamicSE",
    "DynamicLinear",
]
# A separable conv consists of a depthwise and a pointwise conv; this class implements the depthwise part.
class DynamicSeparableConv2d(nn.Module):
    """Depthwise convolution with an elastic (run-time selectable) kernel size.

    A single depthwise filter bank of the maximum kernel size is stored.
    Smaller kernels are obtained by center-cropping that bank and, when
    ``KERNEL_TRANSFORM_MODE`` is enabled, passing each crop through a learned
    linear transformation matrix (one per adjacent size pair, e.g. 7to5, 5to3).
    """

    # 1 enables the learned kernel-transform matrices; None uses plain cropping.
    KERNEL_TRANSFORM_MODE = 1  # None or 1

    def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
        super(DynamicSeparableConv2d, self).__init__()
        self.max_in_channels = max_in_channels
        self.kernel_size_list = kernel_size_list  # list of kernel size
        self.stride = stride
        self.dilation = dilation
        # Weight container only: forward() always convolves via F.conv2d with a
        # sliced filter, so this module's own padding/dilation are never used.
        self.conv = nn.Conv2d(
            self.max_in_channels,
            self.max_in_channels,
            max(self.kernel_size_list),
            self.stride,
            groups=self.max_in_channels,
            bias=False,
        )
        self._ks_set = list(set(self.kernel_size_list))
        self._ks_set.sort()  # e.g., [3, 5, 7]
        # define matrices for transforming a (cropped) larger kernel into the
        # next smaller kernel size
        if self.KERNEL_TRANSFORM_MODE is not None:
            # register scaling parameters
            # 7to5_matrix, 5to3_matrix
            scale_params = {}
            for i in range(len(self._ks_set) - 1):
                ks_small = self._ks_set[i]
                ks_larger = self._ks_set[i + 1]
                param_name = "%dto%d" % (ks_larger, ks_small)
                # noinspection PyArgumentList
                scale_params["%s_matrix" % param_name] = Parameter(
                    torch.eye(ks_small ** 2)  # initialized as identity = plain crop
                )
            for name, param in scale_params.items():
                self.register_parameter(name, param)
        self.active_kernel_size = max(self.kernel_size_list)

    def get_active_filter(self, in_channel, kernel_size):
        """Return the depthwise filters for the requested width and kernel size."""
        out_channel = in_channel  # depthwise: one filter per input channel
        max_kernel_size = max(self.kernel_size_list)
        start, end = sub_filter_start_end(max_kernel_size, kernel_size)
        # Plain center crop of the stored max-size kernel.
        # NOTE(review): depthwise weights have size 1 on dim 1, so the
        # `:in_channel` slice there is presumably a no-op kept for symmetry.
        filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
        if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
            # Instead of a single crop, walk down the size ladder (e.g. 7 -> 5
            # -> 3), cropping and linearly transforming at each step.
            start_filter = self.conv.weight[
                :out_channel, :in_channel, :, :
            ]  # start with max kernel
            for i in range(len(self._ks_set) - 1, 0, -1):
                src_ks = self._ks_set[i]
                if src_ks <= kernel_size:
                    break  # reached the requested size
                target_ks = self._ks_set[i - 1]
                start, end = sub_filter_start_end(src_ks, target_ks)
                _input_filter = start_filter[:, :, start:end, start:end]
                _input_filter = _input_filter.contiguous()
                # Flatten each target_ks x target_ks crop into a row vector...
                _input_filter = _input_filter.view(
                    _input_filter.size(0), _input_filter.size(1), -1
                )
                _input_filter = _input_filter.view(-1, _input_filter.size(2))
                # ...and multiply by the learned (ks^2 x ks^2) transform matrix.
                _input_filter = F.linear(
                    _input_filter,
                    self.__getattr__("%dto%d_matrix" % (src_ks, target_ks)),
                )
                _input_filter = _input_filter.view(
                    filters.size(0), filters.size(1), target_ks ** 2
                )
                _input_filter = _input_filter.view(
                    filters.size(0), filters.size(1), target_ks, target_ks
                )
                start_filter = _input_filter
            filters = start_filter
        return filters

    def forward(self, x, kernel_size=None):
        """Depthwise-convolve ``x`` with the active (or explicitly given) kernel size."""
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        in_channel = x.size(1)
        filters = self.get_active_filter(in_channel, kernel_size).contiguous()
        padding = get_same_padding(kernel_size)
        # NOTE(review): self.conv is nn.Conv2d here, so the MyConv2d branch
        # (weight standardization) only fires if that class is swapped in
        # elsewhere — confirm against the rest of the project.
        filters = (
            self.conv.weight_standardization(filters)
            if isinstance(self.conv, MyConv2d)
            else filters
        )
        # groups == in_channel -> depthwise convolution.
        y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, in_channel)
        return y
class DynamicConv2d(nn.Module):
    """Standard (non-grouped) convolution with elastic input and output width.

    Holds one weight tensor sized for the maximum widths; at run time only the
    top-left ``out_channel x in_channel`` slice is used, so every narrower
    sub-network shares weights with the full network.
    """

    def __init__(
        self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1
    ):
        super(DynamicConv2d, self).__init__()
        self.max_in_channels = max_in_channels
        self.max_out_channels = max_out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        # Weight container; forward() always convolves via F.conv2d directly.
        self.conv = nn.Conv2d(
            self.max_in_channels,
            self.max_out_channels,
            self.kernel_size,
            stride=self.stride,
            bias=False,
        )
        self.active_out_channel = self.max_out_channels

    def get_active_filter(self, out_channel, in_channel):
        """Top-left slice of the full kernel bank for the requested widths."""
        return self.conv.weight[:out_channel, :in_channel, :, :]

    def forward(self, x, out_channel=None):
        """Convolve ``x`` using the active (or explicitly given) output width."""
        active_out = self.active_out_channel if out_channel is None else out_channel
        weight = self.get_active_filter(active_out, x.size(1)).contiguous()
        if isinstance(self.conv, MyConv2d):
            # MyConv2d variants standardize their weights before use.
            weight = self.conv.weight_standardization(weight)
        return F.conv2d(
            x,
            weight,
            None,
            self.stride,
            get_same_padding(self.kernel_size),
            self.dilation,
            1,
        )
class DynamicGroupConv2d(nn.Module):
    """Grouped convolution with elastic kernel size and elastic group count.

    The stored weight is built for the smallest group count in ``groups_list``
    (the widest per-group fan-in); larger group counts and smaller kernels are
    realized by re-slicing that weight on the fly.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size_list,
        groups_list,
        stride=1,
        dilation=1,
    ):
        super(DynamicGroupConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size_list = kernel_size_list
        self.groups_list = groups_list
        self.stride = stride
        self.dilation = dilation
        # min(groups_list) -> largest per-group fan-in, so every other group
        # configuration is a sub-slice of this weight tensor.
        self.conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            max(self.kernel_size_list),
            self.stride,
            groups=min(self.groups_list),
            bias=False,
        )
        self.active_kernel_size = max(self.kernel_size_list)
        self.active_groups = min(self.groups_list)

    def get_active_filter(self, kernel_size, groups):
        """Slice the stored weight down to the given kernel size and group count."""
        start, end = sub_filter_start_end(max(self.kernel_size_list), kernel_size)
        cropped = self.conv.weight[:, :, start:end, start:end]
        per_group_in = self.in_channels // groups
        # Number of target groups that fit inside one stored group's fan-in.
        ratio = cropped.size(1) // per_group_in
        pieces = []
        for idx, piece in enumerate(torch.chunk(cropped, groups, dim=0)):
            offset = (idx % ratio) * per_group_in
            pieces.append(piece[:, offset : offset + per_group_in, :, :])
        return torch.cat(pieces, dim=0)

    def forward(self, x, kernel_size=None, groups=None):
        """Run the grouped convolution with the active (or given) configuration."""
        kernel_size = self.active_kernel_size if kernel_size is None else kernel_size
        groups = self.active_groups if groups is None else groups
        weight = self.get_active_filter(kernel_size, groups).contiguous()
        if isinstance(self.conv, MyConv2d):
            # MyConv2d variants standardize their weights before use.
            weight = self.conv.weight_standardization(weight)
        return F.conv2d(
            x,
            weight,
            None,
            self.stride,
            get_same_padding(kernel_size),
            self.dilation,
            groups,
        )
class DynamicBatchNorm2d(nn.Module):
    """BatchNorm over a variable number of leading feature channels.

    One BN module sized for the maximum width is stored. Full-width inputs go
    through it directly; narrower inputs are normalized functionally with the
    leading slices of its running statistics and affine parameters (the sliced
    views update the shared buffers in place during training).
    """

    # When True, always dispatch to the full bn module (used while
    # re-estimating running statistics for a chosen sub-network).
    SET_RUNNING_STATISTICS = False

    def __init__(self, max_feature_dim):
        super(DynamicBatchNorm2d, self).__init__()
        self.max_feature_dim = max_feature_dim
        self.bn = nn.BatchNorm2d(self.max_feature_dim)

    @staticmethod
    def bn_forward(x, bn: nn.BatchNorm2d, feature_dim):
        """Apply ``bn`` to ``x`` using only its first ``feature_dim`` channels."""
        full_width = bn.num_features == feature_dim
        if full_width or DynamicBatchNorm2d.SET_RUNNING_STATISTICS:
            return bn(x)
        # Mirror nn.BatchNorm2d's momentum bookkeeping for the sliced path.
        momentum = 0.0
        if bn.training and bn.track_running_stats:
            if bn.num_batches_tracked is not None:
                bn.num_batches_tracked += 1
            if bn.momentum is None:
                # Cumulative moving average.
                momentum = 1.0 / float(bn.num_batches_tracked)
            else:
                # Exponential moving average.
                momentum = bn.momentum
        return F.batch_norm(
            x,
            bn.running_mean[:feature_dim],
            bn.running_var[:feature_dim],
            bn.weight[:feature_dim],
            bn.bias[:feature_dim],
            bn.training or not bn.track_running_stats,
            momentum,
            bn.eps,
        )

    def forward(self, x):
        """Normalize ``x``, inferring the active width from its channel dim."""
        return self.bn_forward(x, self.bn, x.size(1))
class DynamicGroupNorm(nn.GroupNorm):
    """GroupNorm that adapts its group count to the runtime channel width.

    The group structure is driven by ``channel_per_group``: an input with ``n``
    channels is normalized with ``n // channel_per_group`` groups, using the
    leading slices of the affine parameters. When ``channel_per_group`` is
    None, the fixed ``num_groups`` from the constructor is used instead
    (previously this case crashed with ``TypeError`` on ``n // None``).
    """

    def __init__(
        self, num_groups, num_channels, eps=1e-5, affine=True, channel_per_group=None
    ):
        super(DynamicGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
        # Channels per normalization group; determines the dynamic group count.
        self.channel_per_group = channel_per_group

    def forward(self, x):
        """Group-normalize ``x`` over its first ``x.size(1)`` channels."""
        n_channels = x.size(1)
        if self.channel_per_group is None:
            # Bug fix: fall back to the static group count instead of raising.
            n_groups = self.num_groups
        else:
            n_groups = n_channels // self.channel_per_group
        return F.group_norm(
            x, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps
        )

    @property
    def bn(self):
        # Lets callers treat this layer uniformly with DynamicBatchNorm2d.bn.
        return self
class DynamicSE(SEModule):
    """Squeeze-and-Excitation block whose channel width follows the input.

    The reduce/expand 1x1 convolution weights are sliced out of the max-width
    SE module; when ``groups`` is given, slicing is done per group so grouped
    backbones keep their channel alignment.
    """

    def __init__(self, max_channel):
        super(DynamicSE, self).__init__(max_channel)

    @staticmethod
    def _per_group_slice(tensor, groups, keep, dim):
        # Split into `groups` chunks along `dim`, keep the first `keep`
        # entries of each chunk, and re-concatenate along the same dim.
        chunks = torch.chunk(tensor, groups, dim=dim)
        return torch.cat([c.narrow(dim, 0, keep) for c in chunks], dim=dim)

    def get_active_reduce_weight(self, num_mid, in_channel, groups=None):
        """Reduce-conv weight for the active widths (optionally per group)."""
        weight = self.fc.reduce.weight
        if groups is None or groups == 1:
            return weight[:num_mid, :in_channel, :, :]
        assert in_channel % groups == 0
        return self._per_group_slice(
            weight[:num_mid, :, :, :], groups, in_channel // groups, dim=1
        )

    def get_active_reduce_bias(self, num_mid):
        """Reduce-conv bias slice, or None when the conv has no bias."""
        bias = self.fc.reduce.bias
        return None if bias is None else bias[:num_mid]

    def get_active_expand_weight(self, num_mid, in_channel, groups=None):
        """Expand-conv weight for the active widths (optionally per group)."""
        weight = self.fc.expand.weight
        if groups is None or groups == 1:
            return weight[:in_channel, :num_mid, :, :]
        assert in_channel % groups == 0
        return self._per_group_slice(
            weight[:, :num_mid, :, :], groups, in_channel // groups, dim=0
        )

    def get_active_expand_bias(self, in_channel, groups=None):
        """Expand-conv bias for the active width (optionally per group)."""
        bias = self.fc.expand.bias
        if groups is None or groups == 1:
            return None if bias is None else bias[:in_channel]
        assert in_channel % groups == 0
        return self._per_group_slice(bias, groups, in_channel // groups, dim=0)

    def forward(self, x, groups=None):
        """Scale ``x`` by its squeeze-and-excitation attention weights."""
        in_channel = x.size(1)
        num_mid = make_divisible(
            in_channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE
        )
        # Squeeze: global average pooling over the spatial dims.
        y = x.mean(3, keepdim=True).mean(2, keepdim=True)
        # Excite: reduce -> relu -> expand -> hard sigmoid.
        y = F.conv2d(
            y,
            self.get_active_reduce_weight(
                num_mid, in_channel, groups=groups
            ).contiguous(),
            self.get_active_reduce_bias(num_mid),
            1, 0, 1, 1,
        )
        y = self.fc.relu(y)
        y = F.conv2d(
            y,
            self.get_active_expand_weight(
                num_mid, in_channel, groups=groups
            ).contiguous(),
            self.get_active_expand_bias(in_channel, groups=groups),
            1, 0, 1, 1,
        )
        y = self.fc.h_sigmoid(y)
        return x * y
class DynamicLinear(nn.Module):
    """Fully-connected layer with elastic input and output feature counts.

    A max-size nn.Linear stores the weights; forward() uses only the
    top-left ``out_features x in_features`` slice (and matching bias slice).
    """

    def __init__(self, max_in_features, max_out_features, bias=True):
        super(DynamicLinear, self).__init__()
        self.max_in_features = max_in_features
        self.max_out_features = max_out_features
        self.bias = bias
        self.linear = nn.Linear(self.max_in_features, self.max_out_features, self.bias)
        self.active_out_features = self.max_out_features

    def get_active_weight(self, out_features, in_features):
        """Top-left weight slice for the requested feature counts."""
        return self.linear.weight[:out_features, :in_features]

    def get_active_bias(self, out_features):
        """Bias slice for the requested output width, or None if bias-less."""
        if not self.bias:
            return None
        return self.linear.bias[:out_features]

    def forward(self, x, out_features=None):
        """Apply the linear map with the active (or given) output width."""
        active = self.active_out_features if out_features is None else out_features
        weight = self.get_active_weight(active, x.size(1)).contiguous()
        return F.linear(x, weight, self.get_active_bias(active))