# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import torch
import torch.nn as nn
from collections import OrderedDict

from proard.utils.layers import (
    MBConvLayer,
    ConvLayer,
    IdentityLayer,
    set_layer_from_config,
)
from proard.utils.layers import ResNetBottleneckBlock, LinearLayer
from proard.utils import (
    MyModule,
    val2list,
    get_net_device,
    build_activation,
    make_divisible,
    SEModule,
    MyNetwork,
)
from .dynamic_op import (
    DynamicSeparableConv2d,
    DynamicConv2d,
    DynamicBatchNorm2d,
    DynamicSE,
    DynamicGroupNorm,
)
from .dynamic_op import DynamicLinear

__all__ = [
    "adjust_bn_according_to_idx",
    "copy_bn",
    "DynamicMBConvLayer",
    "DynamicConvLayer",
    "DynamicLinearLayer",
    "DynamicResNetBottleneckBlock",
]


def adjust_bn_according_to_idx(bn, idx):
    # Permute the per-channel parameters (and running statistics, for batch
    # norm) of a normalization layer to follow the channel ordering `idx`.
    bn.weight.data = torch.index_select(bn.weight.data, 0, idx)
    bn.bias.data = torch.index_select(bn.bias.data, 0, idx)
    if type(bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
        bn.running_mean.data = torch.index_select(bn.running_mean.data, 0, idx)
        bn.running_var.data = torch.index_select(bn.running_var.data, 0, idx)


def copy_bn(target_bn, src_bn):
    # Copy the first `feature_dim` channels of normalization parameters from a
    # (larger) dynamic norm layer into a (smaller) static one.
    feature_dim = (
        target_bn.num_channels
        if isinstance(target_bn, nn.GroupNorm)
        else target_bn.num_features
    )

    target_bn.weight.data.copy_(src_bn.weight.data[:feature_dim])
    target_bn.bias.data.copy_(src_bn.bias.data[:feature_dim])
    if type(src_bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
        target_bn.running_mean.data.copy_(src_bn.running_mean.data[:feature_dim])
        target_bn.running_var.data.copy_(src_bn.running_var.data[:feature_dim])


class DynamicLinearLayer(MyModule):
    def __init__(self, in_features_list, out_features, bias=True, dropout_rate=0):
        super(DynamicLinearLayer, self).__init__()

        self.in_features_list = in_features_list
        self.out_features = out_features
        self.bias = bias
        self.dropout_rate = dropout_rate

        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
        else:
            self.dropout = None
        self.linear = DynamicLinear(
            max_in_features=max(self.in_features_list),
            max_out_features=self.out_features,
            bias=self.bias,
        )

    def forward(self, x):
        if self.dropout is not None:
            x = self.dropout(x)
        return self.linear(x)

    @property
    def module_str(self):
        return "DyLinear(%d, %d)" % (max(self.in_features_list), self.out_features)

    @property
    def config(self):
        return {
            "name": DynamicLinearLayer.__name__,
            "in_features_list": self.in_features_list,
            "out_features": self.out_features,
            "bias": self.bias,
            "dropout_rate": self.dropout_rate,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicLinearLayer(**config)

    def get_active_subnet(self, in_features, preserve_weight=True):
        # Materialize a static LinearLayer for the chosen input width.
        sub_layer = LinearLayer(
            in_features, self.out_features, self.bias, dropout_rate=self.dropout_rate
        )
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        sub_layer.linear.weight.data.copy_(
            self.linear.get_active_weight(self.out_features, in_features).data
        )
        if self.bias:
            sub_layer.linear.bias.data.copy_(
                self.linear.get_active_bias(self.out_features).data
            )
        return sub_layer

    def get_active_subnet_config(self, in_features):
        return {
            "name": LinearLayer.__name__,
            "in_features": in_features,
            "out_features": self.out_features,
            "bias": self.bias,
            "dropout_rate": self.dropout_rate,
        }
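
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library API): extracting a static
# sub-layer from a dynamic one and checking that the two agree. Feature sizes
# are made up; this assumes the underlying DynamicLinear slices its weight to
# the width of the incoming tensor, which is what forward() above relies on.
# ---------------------------------------------------------------------------
def _example_linear_subnet_extraction():
    layer = DynamicLinearLayer(in_features_list=[128, 256], out_features=10)
    x = torch.randn(4, 128)  # narrow input: only the first 128 weight columns are used
    y_dynamic = layer(x)

    sub = layer.get_active_subnet(in_features=128)  # static LinearLayer
    y_static = sub(x)
    assert torch.allclose(y_dynamic, y_static, atol=1e-6)
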
class DynamicMBConvLayer(MyModule):
    def __init__(
        self,
        in_channel_list,
        out_channel_list,
        kernel_size_list=3,
        expand_ratio_list=6,
        stride=1,
        act_func="relu6",
        use_se=False,
    ):
        super(DynamicMBConvLayer, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list

        self.kernel_size_list = val2list(kernel_size_list)
        self.expand_ratio_list = val2list(expand_ratio_list)

        self.stride = stride
        self.act_func = act_func
        self.use_se = use_se

        # build modules
        max_middle_channel = make_divisible(
            round(max(self.in_channel_list) * max(self.expand_ratio_list)),
            MyNetwork.CHANNEL_DIVISIBLE,
        )
        if max(self.expand_ratio_list) == 1:
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "conv",
                            DynamicConv2d(
                                max(self.in_channel_list), max_middle_channel
                            ),
                        ),
                        ("bn", DynamicBatchNorm2d(max_middle_channel)),
                        ("act", build_activation(self.act_func)),
                    ]
                )
            )

        self.depth_conv = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        DynamicSeparableConv2d(
                            max_middle_channel, self.kernel_size_list, self.stride
                        ),
                    ),
                    ("bn", DynamicBatchNorm2d(max_middle_channel)),
                    ("act", build_activation(self.act_func)),
                ]
            )
        )
        if self.use_se:
            self.depth_conv.add_module("se", DynamicSE(max_middle_channel))

        self.point_linear = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        DynamicConv2d(max_middle_channel, max(self.out_channel_list)),
                    ),
                    ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
                ]
            )
        )

        self.active_kernel_size = max(self.kernel_size_list)
        self.active_expand_ratio = max(self.expand_ratio_list)
        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        in_channel = x.size(1)

        if self.inverted_bottleneck is not None:
            self.inverted_bottleneck.conv.active_out_channel = make_divisible(
                round(in_channel * self.active_expand_ratio),
                MyNetwork.CHANNEL_DIVISIBLE,
            )

        self.depth_conv.conv.active_kernel_size = self.active_kernel_size
        self.point_linear.conv.active_out_channel = self.active_out_channel

        if self.inverted_bottleneck is not None:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_linear(x)
        return x

    @property
    def module_str(self):
        if self.use_se:
            return "SE(O%d, E%.1f, K%d)" % (
                self.active_out_channel,
                self.active_expand_ratio,
                self.active_kernel_size,
            )
        else:
            return "(O%d, E%.1f, K%d)" % (
                self.active_out_channel,
                self.active_expand_ratio,
                self.active_kernel_size,
            )

    @property
    def config(self):
        return {
            "name": DynamicMBConvLayer.__name__,
            "in_channel_list": self.in_channel_list,
            "out_channel_list": self.out_channel_list,
            "kernel_size_list": self.kernel_size_list,
            "expand_ratio_list": self.expand_ratio_list,
            "stride": self.stride,
            "act_func": self.act_func,
            "use_se": self.use_se,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicMBConvLayer(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    def active_middle_channel(self, in_channel):
        return make_divisible(
            round(in_channel * self.active_expand_ratio), MyNetwork.CHANNEL_DIVISIBLE
        )

    ############################################################################################
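
    # Worked example for active_middle_channel (illustrative; assumes
    # MyNetwork.CHANNEL_DIVISIBLE == 8, as in the original OFA code):
    #   in_channel=24, active_expand_ratio=6 -> make_divisible(144, 8) = 144
    #   in_channel=28, active_expand_ratio=3 -> make_divisible(84, 8) = 88
    # make_divisible rounds to the nearest multiple of the divisor and bumps
    # the result up one step if rounding would lose more than 10% of the width.
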
    def get_active_subnet(self, in_channel, preserve_weight=True):
        # build the new layer
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        middle_channel = self.active_middle_channel(in_channel)
        # copy weight from current layer
        if sub_layer.inverted_bottleneck is not None:
            sub_layer.inverted_bottleneck.conv.weight.data.copy_(
                self.inverted_bottleneck.conv.get_active_filter(
                    middle_channel, in_channel
                ).data,
            )
            copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn)

        sub_layer.depth_conv.conv.weight.data.copy_(
            self.depth_conv.conv.get_active_filter(
                middle_channel, self.active_kernel_size
            ).data
        )
        copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn)

        if self.use_se:
            se_mid = make_divisible(
                middle_channel // SEModule.REDUCTION,
                divisor=MyNetwork.CHANNEL_DIVISIBLE,
            )
            sub_layer.depth_conv.se.fc.reduce.weight.data.copy_(
                self.depth_conv.se.get_active_reduce_weight(
                    se_mid, middle_channel
                ).data
            )
            sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(
                self.depth_conv.se.get_active_reduce_bias(se_mid).data
            )

            sub_layer.depth_conv.se.fc.expand.weight.data.copy_(
                self.depth_conv.se.get_active_expand_weight(
                    se_mid, middle_channel
                ).data
            )
            sub_layer.depth_conv.se.fc.expand.bias.data.copy_(
                self.depth_conv.se.get_active_expand_bias(middle_channel).data
            )

        sub_layer.point_linear.conv.weight.data.copy_(
            self.point_linear.conv.get_active_filter(
                self.active_out_channel, middle_channel
            ).data
        )
        copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn)

        return sub_layer

    def get_active_subnet_config(self, in_channel):
        return {
            "name": MBConvLayer.__name__,
            "in_channels": in_channel,
            "out_channels": self.active_out_channel,
            "kernel_size": self.active_kernel_size,
            "stride": self.stride,
            "expand_ratio": self.active_expand_ratio,
            "mid_channels": self.active_middle_channel(in_channel),
            "act_func": self.act_func,
            "use_se": self.use_se,
        }

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        # Rank the middle (expanded) channels by the L1 norm of the pointwise
        # projection that consumes them, then permute every tensor that indexes
        # the middle dimension so the most important channels come first.
        importance = torch.sum(
            torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3)
        )
        if isinstance(self.depth_conv.bn, DynamicGroupNorm):
            # Channels in the same group must stay together, so share one
            # importance score per group.
            channel_per_group = self.depth_conv.bn.channel_per_group
            importance_chunks = torch.split(importance, channel_per_group)
            for chunk in importance_chunks:
                chunk.data.fill_(torch.mean(chunk))
            importance = torch.cat(importance_chunks, dim=0)
        if expand_ratio_stage > 0:
            # Bias the ranking so channels already inside a smaller
            # expand-ratio slice keep precedence over the rest.
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width_list = [
                make_divisible(
                    round(max(self.in_channel_list) * expand),
                    MyNetwork.CHANNEL_DIVISIBLE,
                )
                for expand in sorted_expand_list
            ]

            right = len(importance)
            base = -len(target_width_list) * 1e5
            for i in range(expand_ratio_stage + 1):
                left = target_width_list[i]
                importance[left:right] += base
                base += 1e5
                right = left

        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
        self.point_linear.conv.conv.weight.data = torch.index_select(
            self.point_linear.conv.conv.weight.data, 1, sorted_idx
        )
        adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx)
        self.depth_conv.conv.conv.weight.data = torch.index_select(
            self.depth_conv.conv.conv.weight.data, 0, sorted_idx
        )

        if self.use_se:
            # se expand: output dim 0 reorganize
            se_expand = self.depth_conv.se.fc.expand
            se_expand.weight.data = torch.index_select(
                se_expand.weight.data, 0, sorted_idx
            )
            se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx)
            # se reduce: input dim 1 reorganize
            se_reduce = self.depth_conv.se.fc.reduce
            se_reduce.weight.data = torch.index_select(
                se_reduce.weight.data, 1, sorted_idx
            )
            # middle weight reorganize
            se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3))
            se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True)

            se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx)
            se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx)
            se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx)

        if self.inverted_bottleneck is not None:
            adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx)
            self.inverted_bottleneck.conv.conv.weight.data = torch.index_select(
                self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx
            )
            return None
        else:
            return sorted_idx
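
# ---------------------------------------------------------------------------
# Minimal sketch of the channel-sorting mechanics used in
# re_organize_middle_weights above (standalone, plain torch; shapes are made
# up). Middle channels are ranked by the L1 norm of the 1x1 projection that
# consumes them; every tensor whose dimension indexes the middle channels is
# then permuted with the same index (in the real block this also covers the
# inverted-bottleneck conv and the BN statistics, so the block's function is
# preserved).
# ---------------------------------------------------------------------------
def _example_channel_importance_sort():
    point_w = torch.randn(32, 96, 1, 1)  # pointwise conv: 96 middle -> 32 out
    depth_w = torch.randn(96, 1, 3, 3)   # depthwise conv: one filter per middle channel

    # One importance score per middle channel.
    importance = torch.sum(torch.abs(point_w), dim=(0, 2, 3))
    _, sorted_idx = torch.sort(importance, dim=0, descending=True)

    point_w = torch.index_select(point_w, 1, sorted_idx)  # input dim of the projection
    depth_w = torch.index_select(depth_w, 0, sorted_idx)  # output dim of the depthwise conv
    # A smaller expand ratio can now simply keep the first k channels, e.g.
    # point_w[:, :48] and depth_w[:48], dropping the least important ones.
    return point_w, depth_w
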
class DynamicConvLayer(MyModule):
    def __init__(
        self,
        in_channel_list,
        out_channel_list,
        kernel_size=3,
        stride=1,
        dilation=1,
        use_bn=True,
        act_func="relu6",
    ):
        super(DynamicConvLayer, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.use_bn = use_bn
        self.act_func = act_func

        self.conv = DynamicConv2d(
            max_in_channels=max(self.in_channel_list),
            max_out_channels=max(self.out_channel_list),
            kernel_size=self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
        )
        if self.use_bn:
            self.bn = DynamicBatchNorm2d(max(self.out_channel_list))
        self.act = build_activation(self.act_func)

        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        self.conv.active_out_channel = self.active_out_channel

        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        x = self.act(x)
        return x

    @property
    def module_str(self):
        return "DyConv(O%d, K%d, S%d)" % (
            self.active_out_channel,
            self.kernel_size,
            self.stride,
        )

    @property
    def config(self):
        return {
            "name": DynamicConvLayer.__name__,
            "in_channel_list": self.in_channel_list,
            "out_channel_list": self.out_channel_list,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "dilation": self.dilation,
            "use_bn": self.use_bn,
            "act_func": self.act_func,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicConvLayer(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    ############################################################################################

    def get_active_subnet(self, in_channel, preserve_weight=True):
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))

        if not preserve_weight:
            return sub_layer

        sub_layer.conv.weight.data.copy_(
            self.conv.get_active_filter(self.active_out_channel, in_channel).data
        )
        if self.use_bn:
            copy_bn(sub_layer.bn, self.bn.bn)

        return sub_layer

    def get_active_subnet_config(self, in_channel):
        return {
            "name": ConvLayer.__name__,
            "in_channels": in_channel,
            "out_channels": self.active_out_channel,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "dilation": self.dilation,
            "use_bn": self.use_bn,
            "act_func": self.act_func,
        }
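
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library API): elastic width with
# DynamicConvLayer. Setting `active_out_channel` controls how many filters of
# the shared weight tensor the next forward pass uses; the channel counts
# below are made up for the example.
# ---------------------------------------------------------------------------
def _example_elastic_width():
    layer = DynamicConvLayer(in_channel_list=[3], out_channel_list=[32, 64])
    x = torch.randn(1, 3, 32, 32)

    layer.active_out_channel = 64
    assert layer(x).size(1) == 64  # full width

    layer.active_out_channel = 32
    assert layer(x).size(1) == 32  # first 32 filters of the same weight tensor
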
class DynamicResNetBottleneckBlock(MyModule):
    def __init__(
        self,
        in_channel_list,
        out_channel_list,
        expand_ratio_list=0.25,
        kernel_size=3,
        stride=1,
        act_func="relu",
        downsample_mode="avgpool_conv",
    ):
        super(DynamicResNetBottleneckBlock, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list
        self.expand_ratio_list = val2list(expand_ratio_list)

        self.kernel_size = kernel_size
        self.stride = stride
        self.act_func = act_func
        self.downsample_mode = downsample_mode

        # build modules
        max_middle_channel = make_divisible(
            round(max(self.out_channel_list) * max(self.expand_ratio_list)),
            MyNetwork.CHANNEL_DIVISIBLE,
        )

        self.conv1 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        DynamicConv2d(max(self.in_channel_list), max_middle_channel),
                    ),
                    ("bn", DynamicBatchNorm2d(max_middle_channel)),
                    ("act", build_activation(self.act_func, inplace=True)),
                ]
            )
        )

        self.conv2 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        DynamicConv2d(
                            max_middle_channel, max_middle_channel, kernel_size, stride
                        ),
                    ),
                    ("bn", DynamicBatchNorm2d(max_middle_channel)),
                    ("act", build_activation(self.act_func, inplace=True)),
                ]
            )
        )

        self.conv3 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        DynamicConv2d(max_middle_channel, max(self.out_channel_list)),
                    ),
                    ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
                ]
            )
        )

        if self.stride == 1 and self.in_channel_list == self.out_channel_list:
            self.downsample = IdentityLayer(
                max(self.in_channel_list), max(self.out_channel_list)
            )
        elif self.downsample_mode == "conv":
            self.downsample = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "conv",
                            DynamicConv2d(
                                max(self.in_channel_list),
                                max(self.out_channel_list),
                                stride=stride,
                            ),
                        ),
                        ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
                    ]
                )
            )
        elif self.downsample_mode == "avgpool_conv":
            self.downsample = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "avg_pool",
                            nn.AvgPool2d(
                                kernel_size=stride,
                                stride=stride,
                                padding=0,
                                ceil_mode=True,
                            ),
                        ),
                        (
                            "conv",
                            DynamicConv2d(
                                max(self.in_channel_list), max(self.out_channel_list)
                            ),
                        ),
                        ("bn", DynamicBatchNorm2d(max(self.out_channel_list))),
                    ]
                )
            )
        else:
            raise NotImplementedError

        self.final_act = build_activation(self.act_func, inplace=True)

        self.active_expand_ratio = max(self.expand_ratio_list)
        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        feature_dim = self.active_middle_channels

        self.conv1.conv.active_out_channel = feature_dim
        self.conv2.conv.active_out_channel = feature_dim
        self.conv3.conv.active_out_channel = self.active_out_channel
        if not isinstance(self.downsample, IdentityLayer):
            self.downsample.conv.active_out_channel = self.active_out_channel

        residual = self.downsample(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        x = x + residual
        x = self.final_act(x)
        return x

    @property
    def module_str(self):
        return "(%s, %s)" % (
            "%dx%d_BottleneckConv_in->%d->%d_S%d"
            % (
                self.kernel_size,
                self.kernel_size,
                self.active_middle_channels,
                self.active_out_channel,
                self.stride,
            ),
            "Identity"
            if isinstance(self.downsample, IdentityLayer)
            else self.downsample_mode,
        )

    @property
    def config(self):
        return {
            "name": DynamicResNetBottleneckBlock.__name__,
            "in_channel_list": self.in_channel_list,
            "out_channel_list": self.out_channel_list,
            "expand_ratio_list": self.expand_ratio_list,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "act_func": self.act_func,
            "downsample_mode": self.downsample_mode,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicResNetBottleneckBlock(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    @property
    def active_middle_channels(self):
        feature_dim = round(self.active_out_channel * self.active_expand_ratio)
        feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
        return feature_dim

    ############################################################################################
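
    # Note that, unlike DynamicMBConvLayer (whose middle width follows the
    # input channels), the bottleneck's middle width follows the active
    # *output* channels. Worked example (illustrative; assumes
    # MyNetwork.CHANNEL_DIVISIBLE == 8): active_out_channel=128 with
    # active_expand_ratio=0.25 -> make_divisible(round(32), 8) = 32.
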
    def get_active_subnet(self, in_channel, preserve_weight=True):
        # build the new layer
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        # copy weight from current layer
        sub_layer.conv1.conv.weight.data.copy_(
            self.conv1.conv.get_active_filter(
                self.active_middle_channels, in_channel
            ).data
        )
        copy_bn(sub_layer.conv1.bn, self.conv1.bn.bn)

        sub_layer.conv2.conv.weight.data.copy_(
            self.conv2.conv.get_active_filter(
                self.active_middle_channels, self.active_middle_channels
            ).data
        )
        copy_bn(sub_layer.conv2.bn, self.conv2.bn.bn)

        sub_layer.conv3.conv.weight.data.copy_(
            self.conv3.conv.get_active_filter(
                self.active_out_channel, self.active_middle_channels
            ).data
        )
        copy_bn(sub_layer.conv3.bn, self.conv3.bn.bn)

        if not isinstance(self.downsample, IdentityLayer):
            sub_layer.downsample.conv.weight.data.copy_(
                self.downsample.conv.get_active_filter(
                    self.active_out_channel, in_channel
                ).data
            )
            copy_bn(sub_layer.downsample.bn, self.downsample.bn.bn)

        return sub_layer

    def get_active_subnet_config(self, in_channel):
        return {
            "name": ResNetBottleneckBlock.__name__,
            "in_channels": in_channel,
            "out_channels": self.active_out_channel,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "expand_ratio": self.active_expand_ratio,
            "mid_channels": self.active_middle_channels,
            "act_func": self.act_func,
            "groups": 1,
            "downsample_mode": self.downsample_mode,
        }

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        # conv3 -> conv2: rank conv2's output channels by how strongly conv3
        # uses them, then permute conv2, its norm layer, and conv3 consistently.
        importance = torch.sum(
            torch.abs(self.conv3.conv.conv.weight.data), dim=(0, 2, 3)
        )
        if isinstance(self.conv2.bn, DynamicGroupNorm):
            channel_per_group = self.conv2.bn.channel_per_group
            importance_chunks = torch.split(importance, channel_per_group)
            for chunk in importance_chunks:
                chunk.data.fill_(torch.mean(chunk))
            importance = torch.cat(importance_chunks, dim=0)
        if expand_ratio_stage > 0:
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width_list = [
                make_divisible(
                    round(max(self.out_channel_list) * expand),
                    MyNetwork.CHANNEL_DIVISIBLE,
                )
                for expand in sorted_expand_list
            ]
            right = len(importance)
            base = -len(target_width_list) * 1e5
            for i in range(expand_ratio_stage + 1):
                left = target_width_list[i]
                importance[left:right] += base
                base += 1e5
                right = left

        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
        self.conv3.conv.conv.weight.data = torch.index_select(
            self.conv3.conv.conv.weight.data, 1, sorted_idx
        )
        adjust_bn_according_to_idx(self.conv2.bn.bn, sorted_idx)
        self.conv2.conv.conv.weight.data = torch.index_select(
            self.conv2.conv.conv.weight.data, 0, sorted_idx
        )

        # conv2 -> conv1: repeat the same procedure one stage earlier.
        importance = torch.sum(
            torch.abs(self.conv2.conv.conv.weight.data), dim=(0, 2, 3)
        )
        if isinstance(self.conv1.bn, DynamicGroupNorm):
            channel_per_group = self.conv1.bn.channel_per_group
            importance_chunks = torch.split(importance, channel_per_group)
            for chunk in importance_chunks:
                chunk.data.fill_(torch.mean(chunk))
            importance = torch.cat(importance_chunks, dim=0)
        if expand_ratio_stage > 0:
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width_list = [
                make_divisible(
                    round(max(self.out_channel_list) * expand),
                    MyNetwork.CHANNEL_DIVISIBLE,
                )
                for expand in sorted_expand_list
            ]
            right = len(importance)
            base = -len(target_width_list) * 1e5
            for i in range(expand_ratio_stage + 1):
                left = target_width_list[i]
                importance[left:right] += base
                base += 1e5
                right = left

        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
        self.conv2.conv.conv.weight.data = torch.index_select(
            self.conv2.conv.conv.weight.data, 1, sorted_idx
        )
        adjust_bn_according_to_idx(self.conv1.bn.bn, sorted_idx)
        self.conv1.conv.conv.weight.data = torch.index_select(
            self.conv1.conv.conv.weight.data, 0, sorted_idx
        )

        return None
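
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library API): specializing a dynamic
# bottleneck block and checking that the extracted static copy agrees with it.
# The channel configuration is made up; eval() keeps batch norm on (identical)
# running statistics so the comparison is deterministic. This assumes the
# static ResNetBottleneckBlock mirrors the structure built above, which is
# what get_active_subnet relies on.
# ---------------------------------------------------------------------------
def _example_bottleneck_specialization():
    block = DynamicResNetBottleneckBlock(
        in_channel_list=[64],
        out_channel_list=[128, 256],
        expand_ratio_list=[0.2, 0.25],
    )
    block.active_out_channel = 128
    block.active_expand_ratio = 0.25
    block.eval()

    x = torch.randn(2, 64, 56, 56)
    y_dynamic = block(x)

    sub = block.get_active_subnet(in_channel=64)  # static ResNetBottleneckBlock
    sub.eval()
    y_static = sub(x)
    assert torch.allclose(y_dynamic, y_static, atol=1e-5)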