Spaces:
Running
on
L40S
Running
on
L40S
| # The implementation is adopted from TFace, made publicly available under the Apache-2.0 license at | |
| # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py | |
| from collections import namedtuple | |
| from torch.nn import BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, PReLU, Sequential | |
| from .common import Flatten, SEModule, initialize_weights | |
| class BasicBlockIR(Module): | |
| """ BasicBlock for IRNet | |
| """ | |
| def __init__(self, in_channel, depth, stride): | |
| super(BasicBlockIR, self).__init__() | |
| if in_channel == depth: | |
| self.shortcut_layer = MaxPool2d(1, stride) | |
| else: | |
| self.shortcut_layer = Sequential( | |
| Conv2d(in_channel, depth, (1, 1), stride, bias=False), | |
| BatchNorm2d(depth)) | |
| self.res_layer = Sequential( | |
| BatchNorm2d(in_channel), | |
| Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), | |
| BatchNorm2d(depth), PReLU(depth), | |
| Conv2d(depth, depth, (3, 3), stride, 1, bias=False), | |
| BatchNorm2d(depth)) | |
| def forward(self, x): | |
| shortcut = self.shortcut_layer(x) | |
| res = self.res_layer(x) | |
| return res + shortcut | |
class BottleneckIR(Module):
    """Bottleneck residual unit of the IR network.

    The residual branch squeezes the input to ``depth // 4`` channels with a
    1x1 convolution, applies a 3x3 convolution at that width, and expands to
    ``depth`` channels with a strided 1x1 convolution. Batch norm precedes the
    first convolution and follows every convolution; PReLU follows the first
    two post-convolution batch norms.

    The shortcut is a kernel-1 max-pool (pure subsampling by ``stride``) when
    ``in_channel == depth``, otherwise a strided 1x1 convolution plus batch
    norm.

    Args:
        in_channel: number of channels of the incoming feature map.
        depth: number of channels produced by the unit.
        stride: stride of the shortcut and of the final 1x1 convolution.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__()
        squeezed = depth // 4

        # The shortcut is built and registered before the residual branch.
        if in_channel != depth:
            projection = [
                Conv2d(in_channel, depth, kernel_size=(1, 1), stride=stride,
                       bias=False),
                BatchNorm2d(depth),
            ]
            self.shortcut_layer = Sequential(*projection)
        else:
            self.shortcut_layer = MaxPool2d(kernel_size=1, stride=stride)

        residual_ops = [
            BatchNorm2d(in_channel),
            # 1x1 squeeze.
            Conv2d(in_channel, squeezed, kernel_size=(1, 1), stride=(1, 1),
                   padding=0, bias=False),
            BatchNorm2d(squeezed),
            PReLU(squeezed),
            # 3x3 convolution at the reduced width.
            Conv2d(squeezed, squeezed, kernel_size=(3, 3), stride=(1, 1),
                   padding=1, bias=False),
            BatchNorm2d(squeezed),
            PReLU(squeezed),
            # 1x1 expansion; the only convolution that applies the stride.
            Conv2d(squeezed, depth, kernel_size=(1, 1), stride=stride,
                   padding=0, bias=False),
            BatchNorm2d(depth),
        ]
        self.res_layer = Sequential(*residual_ops)

    def forward(self, x):
        """Return ``res_layer(x) + shortcut_layer(x)``."""
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
class BasicBlockIRSE(BasicBlockIR):
    """``BasicBlockIR`` with an ``SEModule`` appended to the residual branch.

    Args:
        in_channel: forwarded to ``BasicBlockIR``.
        depth: forwarded to ``BasicBlockIR``; also the first ``SEModule``
            argument.
        stride: forwarded to ``BasicBlockIR``.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__(in_channel, depth, stride)
        # NOTE(review): 16 is presumably the SE channel-reduction ratio;
        # confirm against SEModule in .common.
        squeeze_excite = SEModule(depth, 16)
        # Registered last in res_layer, so it runs after the final batch norm.
        self.res_layer.add_module('se_block', squeeze_excite)
class BottleneckIRSE(BottleneckIR):
    """``BottleneckIR`` with an ``SEModule`` appended to the residual branch.

    Args:
        in_channel: forwarded to ``BottleneckIR``.
        depth: forwarded to ``BottleneckIR``; also the first ``SEModule``
            argument.
        stride: forwarded to ``BottleneckIR``.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__(in_channel, depth, stride)
        # NOTE(review): 16 is presumably the SE channel-reduction ratio;
        # confirm against SEModule in .common.
        squeeze_excite = SEModule(depth, 16)
        # Registered last in res_layer, so it runs after the final batch norm.
        self.res_layer.add_module('se_block', squeeze_excite)
class Bottleneck(namedtuple('Block', 'in_channel depth stride')):
    """Configuration record for a single residual unit.

    Fields, in order: ``in_channel`` (input channels), ``depth`` (output
    channels) and ``stride``.
    """
def get_block(in_channel, depth, num_units, stride=2):
    """Describe one network stage as a list of ``Bottleneck`` records.

    The first record maps ``in_channel`` to ``depth`` with the given
    ``stride``; it is followed by ``num_units - 1`` records that keep
    ``depth`` channels and use stride 1.

    Args:
        in_channel: channels entering the stage.
        depth: channels produced by every unit of the stage.
        num_units: total number of units in the stage.
        stride: stride of the first unit (default 2).

    Returns:
        A list of ``Bottleneck`` records, one per unit.
    """
    stage = [Bottleneck(in_channel, depth, stride)]
    stage.extend(Bottleneck(depth, depth, 1) for _ in range(num_units - 1))
    return stage
def get_blocks(num_layers):
    """Return the four-stage layout for a backbone of the given depth.

    Args:
        num_layers: one of 18, 34, 50, 100, 152 or 200.

    Returns:
        A list with four entries, one per stage, each being the list of
        ``Bottleneck`` records produced by ``get_block`` for that stage.

    Raises:
        ValueError: if ``num_layers`` is not a supported depth. (Previously
            this fell through every branch and failed with an
            ``UnboundLocalError`` on ``blocks``.)
    """
    # (in_channel, depth, num_units) for each of the four stages.
    stage_specs = {
        18: ((64, 64, 2), (64, 128, 2), (128, 256, 2), (256, 512, 2)),
        34: ((64, 64, 3), (64, 128, 4), (128, 256, 6), (256, 512, 3)),
        50: ((64, 64, 3), (64, 128, 4), (128, 256, 14), (256, 512, 3)),
        100: ((64, 64, 3), (64, 128, 13), (128, 256, 30), (256, 512, 3)),
        152: ((64, 256, 3), (256, 512, 8), (512, 1024, 36),
              (1024, 2048, 3)),
        200: ((64, 256, 3), (256, 512, 24), (512, 1024, 36),
              (1024, 2048, 3)),
    }
    if num_layers not in stage_specs:
        raise ValueError(
            'num_layers should be one of {}, got {!r}'.format(
                sorted(stage_specs), num_layers))
    return [
        get_block(in_channel=in_channel, depth=depth, num_units=num_units)
        for in_channel, depth, num_units in stage_specs[num_layers]
    ]
class Backbone(Module):
    """IR / IR-SE backbone mapping an image batch to 512-dimensional features.

    The network is ``input_layer`` (3x3 conv to 64 channels + BN + PReLU),
    ``body`` (all residual units laid out by ``get_blocks``) and
    ``output_layer`` (BN -> Dropout -> Flatten -> Linear to 512 -> BN).
    """

    def __init__(self, input_size, num_layers, mode='ir'):
        """ Args:
            input_size: input_size of backbone; only ``input_size[0]`` is
                inspected and it must be 112 or 224
            num_layers: num_layers of backbone (18, 34, 50, 100, 152 or 200)
            mode: support ir or ir_se
        """
        super().__init__()
        assert input_size[0] in [112, 224], \
            'input_size should be [112, 112] or [224, 224]'
        # The message used to omit 200 even though 200 is accepted.
        assert num_layers in [18, 34, 50, 100, 152, 200], \
            'num_layers should be 18, 34, 50, 100, 152 or 200'
        assert mode in ['ir', 'ir_se'], \
            'mode should be ir or ir_se'

        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
            PReLU(64))
        blocks = get_blocks(num_layers)

        # Depths up to 100 use basic units ending at 512 channels; deeper
        # networks use bottleneck units ending at 2048 channels.
        if num_layers <= 100:
            unit_by_mode = {'ir': BasicBlockIR, 'ir_se': BasicBlockIRSE}
            output_channel = 512
        else:
            unit_by_mode = {'ir': BottleneckIR, 'ir_se': BottleneckIRSE}
            output_channel = 2048
        unit_module = unit_by_mode[mode]

        # Each of the four stages starts with a stride-2 unit (get_block's
        # default), so the final feature map is input/16: 112 -> 7, 224 -> 14.
        spatial = 7 if input_size[0] == 112 else 14
        self.output_layer = Sequential(
            BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
            Linear(output_channel * spatial * spatial, 512),
            BatchNorm1d(512, affine=False))

        modules = []
        # Index (within self.body) of the last unit of every stage,
        # e.g. [2, 15, 45, 48] out of 49 units for IR101.
        mid_layer_indices = []
        last_index = -1
        for block in blocks:
            last_index += len(block)
            mid_layer_indices.append(last_index)
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel, bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)
        self.mid_layer_indices = mid_layer_indices[-4:]

        initialize_weights(self.modules())

    def device(self):
        """Return the device of the first parameter (call as a method)."""
        return next(self.parameters()).device

    def dtype(self):
        """Return the dtype of the first parameter (call as a method)."""
        return next(self.parameters()).dtype

    def forward(self, x, return_mid_feats=False):
        """Run the backbone.

        Args:
            x: input batch with 3 channels; presumably sized to match the
                ``input_size`` given at construction -- confirm with callers.
            return_mid_feats: when true, also collect the output of the last
                unit of each stage.

        Returns:
            The ``output_layer`` result, or a tuple ``(output, out_feats)``
            when ``return_mid_feats`` is true, where ``out_feats`` lists the
            feature maps at ``self.mid_layer_indices``.
        """
        x = self.input_layer(x)
        if not return_mid_feats:
            return self.output_layer(self.body(x))

        out_feats = []
        for idx, module in enumerate(self.body):
            x = module(x)
            if idx in self.mid_layer_indices:
                out_feats.append(x)
        return self.output_layer(x), out_feats
def IR_18(input_size):
    """Build a ``Backbone`` with ``num_layers=18`` in ``'ir'`` mode.

    Args:
        input_size: forwarded unchanged to ``Backbone``.

    Returns:
        The newly constructed ``Backbone``.
    """
    return Backbone(input_size, num_layers=18, mode='ir')
def IR_34(input_size):
    """Build a ``Backbone`` with ``num_layers=34`` in ``'ir'`` mode.

    Args:
        input_size: forwarded unchanged to ``Backbone``.

    Returns:
        The newly constructed ``Backbone``.
    """
    return Backbone(input_size, num_layers=34, mode='ir')
def IR_50(input_size):
    """Build a ``Backbone`` with ``num_layers=50`` in ``'ir'`` mode.

    Args:
        input_size: forwarded unchanged to ``Backbone``.

    Returns:
        The newly constructed ``Backbone``.
    """
    return Backbone(input_size, num_layers=50, mode='ir')
def IR_101(input_size):
    """Build the IR-101 ``Backbone`` in ``'ir'`` mode.

    The layout is registered under ``num_layers=100``, which is the value
    passed to ``Backbone``.

    Args:
        input_size: forwarded unchanged to ``Backbone``.

    Returns:
        The newly constructed ``Backbone``.
    """
    return Backbone(input_size, num_layers=100, mode='ir')
def IR_152(input_size):
    """Build a ``Backbone`` with ``num_layers=152`` in ``'ir'`` mode.

    Args:
        input_size: forwarded unchanged to ``Backbone``.

    Returns:
        The newly constructed ``Backbone``.
    """
    return Backbone(input_size, num_layers=152, mode='ir')
def IR_200(input_size):
    """Build a ``Backbone`` with ``num_layers=200`` in ``'ir'`` mode.

    Args:
        input_size: forwarded unchanged to ``Backbone``.

    Returns:
        The newly constructed ``Backbone``.
    """
    return Backbone(input_size, num_layers=200, mode='ir')
def IR_SE_50(input_size):
    """Build a ``Backbone`` with ``num_layers=50`` in ``'ir_se'`` mode.

    Args:
        input_size: forwarded unchanged to ``Backbone``.

    Returns:
        The newly constructed ``Backbone``.
    """
    return Backbone(input_size, num_layers=50, mode='ir_se')
def IR_SE_101(input_size):
    """Build the IR-SE-101 ``Backbone`` in ``'ir_se'`` mode.

    The layout is registered under ``num_layers=100``, which is the value
    passed to ``Backbone``.

    Args:
        input_size: forwarded unchanged to ``Backbone``.

    Returns:
        The newly constructed ``Backbone``.
    """
    return Backbone(input_size, num_layers=100, mode='ir_se')
def IR_SE_152(input_size):
    """Build a ``Backbone`` with ``num_layers=152`` in ``'ir_se'`` mode.

    Args:
        input_size: forwarded unchanged to ``Backbone``.

    Returns:
        The newly constructed ``Backbone``.
    """
    return Backbone(input_size, num_layers=152, mode='ir_se')
def IR_SE_200(input_size):
    """Build a ``Backbone`` with ``num_layers=200`` in ``'ir_se'`` mode.

    Args:
        input_size: forwarded unchanged to ``Backbone``.

    Returns:
        The newly constructed ``Backbone``.
    """
    return Backbone(input_size, num_layers=200, mode='ir_se')