|
|
import torch.nn as nn |
|
|
|
|
|
from proard.utils.layers import ( |
|
|
set_layer_from_config, |
|
|
ConvLayer, |
|
|
IdentityLayer, |
|
|
LinearLayer, |
|
|
) |
|
|
from proard.utils.layers import ResNetBottleneckBlock, ResidualBlock |
|
|
from proard.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d |
|
|
|
|
|
# Public API of this module: the two generic backbones and their four concrete variants.
__all__ = ["ResNets", "ResNet50", "ResNet50D","ResNets_Cifar","ResNet50_Cifar", "ResNet50D_Cifar"]
|
|
|
|
|
|
|
|
class ResNets(MyNetwork):
    """ResNet backbone for ImageNet-sized inputs.

    Forward pipeline: input stem -> 3x3/stride-2 max pooling -> sequence of
    residual bottleneck blocks -> global average pooling -> classifier.
    """

    # Base (minimum) number of blocks per stage; subclasses add ``depth_param``
    # on top of these when expanding depth.
    BASE_DEPTH_LIST = [2, 2, 4, 2]
    # Output width of each of the four stages, before any width multiplier.
    STAGE_WIDTH_LIST = [256, 512, 1024, 2048]

    def __init__(self, input_stem, blocks, classifier):
        """
        Args:
            input_stem: iterable of layers applied before max pooling.
            blocks: iterable of residual bottleneck blocks.
            classifier: final layer mapping pooled features to logits.
        """
        super(ResNets, self).__init__()

        self.input_stem = nn.ModuleList(input_stem)
        self.max_pooling = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
        )
        self.blocks = nn.ModuleList(blocks)
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
        self.classifier = classifier

    def forward(self, x):
        for layer in self.input_stem:
            x = layer(x)
        x = self.max_pooling(x)
        for block in self.blocks:
            x = block(x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        """Human-readable, multi-line summary of the whole network."""
        _str = ""
        for layer in self.input_stem:
            _str += layer.module_str + "\n"
        _str += "max_pooling(ks=3, stride=2)\n"
        for block in self.blocks:
            _str += block.module_str + "\n"
        _str += self.global_avg_pool.__repr__() + "\n"
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        """Serializable config dict; round-trips through ``build_from_config``."""
        return {
            "name": ResNets.__name__,
            "bn": self.get_bn_param(),
            "input_stem": [layer.config for layer in self.input_stem],
            "blocks": [block.config for block in self.blocks],
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        """Reconstruct a ``ResNets`` instance from a config dict (see ``config``)."""
        classifier = set_layer_from_config(config["classifier"])

        input_stem = [
            set_layer_from_config(layer_config)
            for layer_config in config["input_stem"]
        ]
        blocks = [
            set_layer_from_config(block_config)
            for block_config in config["blocks"]
        ]

        net = ResNets(input_stem, blocks, classifier)
        if "bn" in config:
            net.set_bn_param(**config["bn"])
        else:
            # Fallback BN hyper-parameters for configs that predate the "bn" key.
            net.set_bn_param(momentum=0.1, eps=1e-5)

        return net

    def zero_last_gamma(self):
        """Zero-init the last BN gamma of bottleneck blocks with identity
        shortcuts (a common trick to stabilize early training)."""
        for m in self.modules():
            if isinstance(m, ResNetBottleneckBlock) and isinstance(
                m.downsample, IdentityLayer
            ):
                m.conv3.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        """Group block indices into stages: a new group starts at every block
        whose shortcut is not an identity (i.e. a down-sampling block)."""
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks):
            if (
                not isinstance(block.downsample, IdentityLayer)
                and len(block_index_list) > 0
            ):
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    def load_state_dict(self, state_dict, **kwargs):
        # Bug fix: forward **kwargs (e.g. strict=False) to nn.Module instead of
        # silently dropping them, so callers' intent is honored.
        super(ResNets, self).load_state_dict(state_dict, **kwargs)
|
|
|
|
|
|
|
|
class ResNet50(ResNets):
    """Standard ResNet-50: 7x7/stride-2 stem, conv-downsample shortcuts.

    Args:
        n_classes: number of output classes.
        width_mult: channel-width multiplier applied to stem and stage widths.
        bn_param: (momentum, eps) for batch normalization.
        dropout_rate: dropout before the final linear layer.
        expand_ratio: bottleneck expansion ratio (defaults to 0.25).
        depth_param: extra blocks added to each stage's base depth; ``None``
            keeps the classic [3, 4, 6, 3] layout.
    """

    def __init__(
        self,
        n_classes=1000,
        width_mult=1.0,
        bn_param=(0.1, 1e-5),
        dropout_rate=0,
        expand_ratio=None,
        depth_param=None,
    ):
        if expand_ratio is None:
            expand_ratio = 0.25

        # Stem and per-stage widths, scaled and rounded to the channel divisor.
        in_ch = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        widths = [
            make_divisible(w * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for w in ResNets.STAGE_WIDTH_LIST
        ]

        # Classic ResNet-50 depths, or base depths uniformly extended.
        if depth_param is None:
            depths = [3, 4, 6, 3]
        else:
            depths = [d + depth_param for d in ResNets.BASE_DEPTH_LIST]

        strides = [1, 2, 2, 2]

        stem = [
            ConvLayer(
                3,
                in_ch,
                kernel_size=7,
                stride=2,
                use_bn=True,
                act_func="relu",
                ops_order="weight_bn_act",
            )
        ]

        block_list = []
        for n_blocks, width, stage_stride in zip(depths, widths, strides):
            for idx in range(n_blocks):
                block_list.append(
                    ResNetBottleneckBlock(
                        in_ch,
                        width,
                        kernel_size=3,
                        # Only the first block of each stage down-samples.
                        stride=stage_stride if idx == 0 else 1,
                        expand_ratio=expand_ratio,
                        act_func="relu",
                        downsample_mode="conv",
                    )
                )
                in_ch = width

        head = LinearLayer(in_ch, n_classes, dropout_rate=dropout_rate)

        super(ResNet50, self).__init__(stem, block_list, head)

        self.set_bn_param(*bn_param)
|
|
|
|
|
|
|
|
class ResNet50D(ResNets):
    """ResNet-50-D variant: deep three-conv 3x3 stem (middle conv wrapped in a
    residual block) and average-pool down-sampling in the shortcut paths.

    Args:
        n_classes: number of output classes.
        width_mult: channel-width multiplier applied to stem and stage widths.
        bn_param: (momentum, eps) for batch normalization.
        dropout_rate: dropout before the final linear layer.
        expand_ratio: bottleneck expansion ratio (defaults to 0.25).
        depth_param: extra blocks added to each stage's base depth; ``None``
            keeps the classic [3, 4, 6, 3] layout.
    """

    def __init__(
        self,
        n_classes=1000,
        width_mult=1.0,
        bn_param=(0.1, 1e-5),
        dropout_rate=0,
        expand_ratio=None,
        depth_param=None,
    ):
        if expand_ratio is None:
            expand_ratio = 0.25

        in_ch = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        # The deep stem runs at half the stem width until its last conv.
        stem_ch = make_divisible(in_ch // 2, MyNetwork.CHANNEL_DIVISIBLE)
        widths = [
            make_divisible(w * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for w in ResNets.STAGE_WIDTH_LIST
        ]

        if depth_param is None:
            depths = [3, 4, 6, 3]
        else:
            depths = [d + depth_param for d in ResNets.BASE_DEPTH_LIST]

        strides = [1, 2, 2, 2]

        stem = [
            ConvLayer(3, stem_ch, 3, stride=2, use_bn=True, act_func="relu"),
            ResidualBlock(
                ConvLayer(
                    stem_ch,
                    stem_ch,
                    3,
                    stride=1,
                    use_bn=True,
                    act_func="relu",
                ),
                IdentityLayer(stem_ch, stem_ch),
            ),
            ConvLayer(
                stem_ch,
                in_ch,
                3,
                stride=1,
                use_bn=True,
                act_func="relu",
            ),
        ]

        block_list = []
        for n_blocks, width, stage_stride in zip(depths, widths, strides):
            for idx in range(n_blocks):
                block_list.append(
                    ResNetBottleneckBlock(
                        in_ch,
                        width,
                        kernel_size=3,
                        # Only the first block of each stage down-samples.
                        stride=stage_stride if idx == 0 else 1,
                        expand_ratio=expand_ratio,
                        act_func="relu",
                        downsample_mode="avgpool_conv",
                    )
                )
                in_ch = width

        head = LinearLayer(in_ch, n_classes, dropout_rate=dropout_rate)

        super(ResNet50D, self).__init__(stem, block_list, head)

        self.set_bn_param(*bn_param)
|
|
|
|
|
|
|
|
|
|
|
class ResNets_Cifar(MyNetwork):
    """ResNet backbone for CIFAR-sized inputs.

    Identical to ``ResNets`` except the stem max-pooling stage is omitted,
    preserving spatial resolution for small (32x32) images.
    """

    # Base (minimum) number of blocks per stage; subclasses add ``depth_param``
    # on top of these when expanding depth.
    BASE_DEPTH_LIST = [2, 2, 4, 2]
    # Output width of each of the four stages, before any width multiplier.
    STAGE_WIDTH_LIST = [256, 512, 1024, 2048]

    def __init__(self, input_stem, blocks, classifier):
        """
        Args:
            input_stem: iterable of layers applied first (no max pooling here).
            blocks: iterable of residual bottleneck blocks.
            classifier: final layer mapping pooled features to logits.
        """
        super(ResNets_Cifar, self).__init__()

        self.input_stem = nn.ModuleList(input_stem)
        self.blocks = nn.ModuleList(blocks)
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
        self.classifier = classifier

    def forward(self, x):
        for layer in self.input_stem:
            x = layer(x)
        for block in self.blocks:
            x = block(x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        """Human-readable, multi-line summary of the whole network."""
        _str = ""
        for layer in self.input_stem:
            _str += layer.module_str + "\n"

        for block in self.blocks:
            _str += block.module_str + "\n"
        _str += self.global_avg_pool.__repr__() + "\n"
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        """Serializable config dict; round-trips through ``build_from_config``."""
        return {
            "name": ResNets_Cifar.__name__,
            "bn": self.get_bn_param(),
            "input_stem": [layer.config for layer in self.input_stem],
            "blocks": [block.config for block in self.blocks],
            "classifier": self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        """Reconstruct a ``ResNets_Cifar`` instance from a config dict."""
        classifier = set_layer_from_config(config["classifier"])

        input_stem = [
            set_layer_from_config(layer_config)
            for layer_config in config["input_stem"]
        ]
        blocks = [
            set_layer_from_config(block_config)
            for block_config in config["blocks"]
        ]

        # Bug fix: previously built a ResNets (which max-pools after the stem),
        # producing the wrong architecture for CIFAR configs.
        net = ResNets_Cifar(input_stem, blocks, classifier)
        if "bn" in config:
            net.set_bn_param(**config["bn"])
        else:
            # Fallback BN hyper-parameters for configs that predate the "bn" key.
            net.set_bn_param(momentum=0.1, eps=1e-5)

        return net

    def zero_last_gamma(self):
        """Zero-init the last BN gamma of bottleneck blocks with identity
        shortcuts (a common trick to stabilize early training)."""
        for m in self.modules():
            if isinstance(m, ResNetBottleneckBlock) and isinstance(
                m.downsample, IdentityLayer
            ):
                m.conv3.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        """Group block indices into stages: a new group starts at every block
        whose shortcut is not an identity (i.e. a down-sampling block)."""
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks):
            if (
                not isinstance(block.downsample, IdentityLayer)
                and len(block_index_list) > 0
            ):
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    def load_state_dict(self, state_dict, **kwargs):
        # Bug fix: forward **kwargs (e.g. strict=False) to nn.Module instead of
        # silently dropping them, so callers' intent is honored.
        super(ResNets_Cifar, self).load_state_dict(state_dict, **kwargs)
|
|
|
|
|
|
|
|
class ResNet50_Cifar(ResNets_Cifar):
    """ResNet-50 adapted for CIFAR: 3x3/stride-1 stem, conv-downsample shortcuts.

    Args:
        n_classes: number of output classes (CIFAR-10 default).
        width_mult: channel-width multiplier applied to stem and stage widths.
        bn_param: (momentum, eps) for batch normalization.
        dropout_rate: dropout before the final linear layer.
        expand_ratio: bottleneck expansion ratio (defaults to 0.25).
        depth_param: extra blocks added to each stage's base depth; ``None``
            keeps the classic [3, 4, 6, 3] layout.
    """

    def __init__(
        self,
        n_classes=10,
        width_mult=1.0,
        bn_param=(0.1, 1e-5),
        dropout_rate=0,
        expand_ratio=None,
        depth_param=None,
    ):
        if expand_ratio is None:
            expand_ratio = 0.25

        # Stem and per-stage widths, scaled and rounded to the channel divisor.
        in_ch = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        widths = [
            make_divisible(w * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for w in ResNets_Cifar.STAGE_WIDTH_LIST
        ]

        # Classic ResNet-50 depths, or base depths uniformly extended.
        if depth_param is None:
            depths = [3, 4, 6, 3]
        else:
            depths = [d + depth_param for d in ResNets_Cifar.BASE_DEPTH_LIST]

        strides = [1, 2, 2, 2]

        # Small-image stem: 3x3 conv without spatial down-sampling.
        stem = [
            ConvLayer(
                3,
                in_ch,
                kernel_size=3,
                stride=1,
                use_bn=True,
                act_func="relu",
                ops_order="weight_bn_act",
            )
        ]

        block_list = []
        for n_blocks, width, stage_stride in zip(depths, widths, strides):
            for idx in range(n_blocks):
                block_list.append(
                    ResNetBottleneckBlock(
                        in_ch,
                        width,
                        kernel_size=3,
                        # Only the first block of each stage down-samples.
                        stride=stage_stride if idx == 0 else 1,
                        expand_ratio=expand_ratio,
                        act_func="relu",
                        downsample_mode="conv",
                    )
                )
                in_ch = width

        head = LinearLayer(in_ch, n_classes, dropout_rate=dropout_rate)

        super(ResNet50_Cifar, self).__init__(stem, block_list, head)

        self.set_bn_param(*bn_param)
|
|
|
|
|
|
|
|
class ResNet50D_Cifar(ResNets_Cifar):
    """ResNet-50-D adapted for CIFAR: deep three-conv 3x3 stem (stride 1) and
    average-pool down-sampling in the shortcut paths.

    Args:
        n_classes: number of output classes (CIFAR-10 default).
        width_mult: channel-width multiplier applied to stem and stage widths.
        bn_param: (momentum, eps) for batch normalization.
        dropout_rate: dropout before the final linear layer.
        expand_ratio: bottleneck expansion ratio (defaults to 0.25).
        depth_param: extra blocks added to each stage's base depth; ``None``
            keeps the classic [3, 4, 6, 3] layout.
    """

    def __init__(
        self,
        n_classes=10,
        width_mult=1.0,
        bn_param=(0.1, 1e-5),
        dropout_rate=0,
        expand_ratio=None,
        depth_param=None,
    ):
        expand_ratio = 0.25 if expand_ratio is None else expand_ratio

        input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        # The deep stem runs at half the stem width until its last conv.
        mid_input_channel = make_divisible(
            input_channel // 2, MyNetwork.CHANNEL_DIVISIBLE
        )
        # Consistency fix: reference this class's own base (ResNets_Cifar)
        # rather than ResNets; values are identical today, but this removes a
        # latent divergence hazard between the ImageNet and CIFAR hierarchies.
        stage_width_list = ResNets_Cifar.STAGE_WIDTH_LIST.copy()
        for i, width in enumerate(stage_width_list):
            stage_width_list[i] = make_divisible(
                width * width_mult, MyNetwork.CHANNEL_DIVISIBLE
            )

        depth_list = [3, 4, 6, 3]
        if depth_param is not None:
            for i, depth in enumerate(ResNets_Cifar.BASE_DEPTH_LIST):
                depth_list[i] = depth + depth_param

        stride_list = [1, 2, 2, 2]

        # Small-image deep stem: three 3x3 convs, no spatial down-sampling; the
        # middle conv is wrapped in a residual block.
        input_stem = [
            ConvLayer(3, mid_input_channel, 3, stride=1, use_bn=True, act_func="relu"),
            ResidualBlock(
                ConvLayer(
                    mid_input_channel,
                    mid_input_channel,
                    3,
                    stride=1,
                    use_bn=True,
                    act_func="relu",
                ),
                IdentityLayer(mid_input_channel, mid_input_channel),
            ),
            ConvLayer(
                mid_input_channel,
                input_channel,
                3,
                stride=1,
                use_bn=True,
                act_func="relu",
            ),
        ]

        blocks = []
        for d, width, s in zip(depth_list, stage_width_list, stride_list):
            for i in range(d):
                # Only the first block of each stage down-samples.
                stride = s if i == 0 else 1
                bottleneck_block = ResNetBottleneckBlock(
                    input_channel,
                    width,
                    kernel_size=3,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    act_func="relu",
                    downsample_mode="avgpool_conv",
                )
                blocks.append(bottleneck_block)
                input_channel = width

        classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)

        super(ResNet50D_Cifar, self).__init__(input_stem, blocks, classifier)

        self.set_bn_param(*bn_param)
|
|
if __name__ == "__main__":
    # Smoke test: a CIFAR-sized batch must flow through the network end to end.
    import torch

    resnet = ResNet50_Cifar()
    x = torch.randn((1, 3, 32, 32))
    out = resnet(x)
    # Improvement: verify the output instead of discarding it — one logit
    # vector per image, 10 classes by default.
    if tuple(out.shape) != (1, 10):
        raise RuntimeError(f"unexpected output shape: {tuple(out.shape)}")
    print("OK:", tuple(out.shape))
|
|
|
|
|
|