""" Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py Original author cavalleria """ import torch import torch.nn as nn from torch.nn import BatchNorm1d from torch.nn import BatchNorm2d from torch.nn import Conv2d from torch.nn import Linear from torch.nn import Module from torch.nn import PReLU from torch.nn import Sequential class Flatten(Module): def forward(self, x): return x.view(x.size(0), -1) class ConvBlock(Module): def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): super(ConvBlock, self).__init__() self.layers = nn.Sequential( Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), BatchNorm2d(num_features=out_c), PReLU(num_parameters=out_c), ) def forward(self, x): return self.layers(x) class LinearBlock(Module): def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): super(LinearBlock, self).__init__() self.layers = nn.Sequential( Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), BatchNorm2d(num_features=out_c) ) def forward(self, x): return self.layers(x) class DepthWise(Module): def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): super(DepthWise, self).__init__() self.residual = residual self.layers = nn.Sequential( ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), ) def forward(self, x): short_cut = None if self.residual: short_cut = x x = self.layers(x) if self.residual: output = short_cut + x else: output = x return output class Residual(Module): def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): super(Residual, self).__init__() modules = [] for _ in range(num_block): modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) self.layers = Sequential(*modules) def forward(self, x): return self.layers(x) class GDC(Module): def __init__(self, embedding_size): super(GDC, self).__init__() self.layers = nn.Sequential( LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), Flatten(), Linear(512, embedding_size, bias=False), BatchNorm1d(embedding_size), ) def forward(self, x): return self.layers(x) class MobileFaceNet(Module): def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2): super(MobileFaceNet, self).__init__() self.scale = scale self.fp16 = fp16 self.layers = nn.ModuleList() self.layers.append(ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1))) if blocks[0] == 1: self.layers.append( ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64) ) else: self.layers.append( Residual( 64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1) ), ) self.layers.extend( [ DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), Residual( 64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1) ), DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), Residual( 128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1) ), DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 
1), groups=512), Residual( 128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1) ), ] ) self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) self.features = GDC(num_features) self._initialize_weights() def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: m.bias.data.zero_() def forward(self, x): with torch.cuda.amp.autocast(self.fp16): for func in self.layers: x = func(x) x = self.conv_sep(x.float() if self.fp16 else x) x = self.features(x) return x def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2): return MobileFaceNet(fp16, num_features, blocks, scale=scale) def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4): return MobileFaceNet(fp16, num_features, blocks, scale=scale)
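

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): builds the small backbone
# and embeds a dummy batch. The network assumes 112x112 aligned face crops;
# four stride-2 stages reduce them to a 7x7 map, which the 7x7 global
# depthwise conv in GDC collapses to 1x1, so other input sizes would break
# the embedding head.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    net = get_mbf(fp16=False, num_features=512)
    net.eval()
    with torch.no_grad():
        faces = torch.randn(2, 3, 112, 112)  # dummy (N, C, H, W) batch
        embeddings = net(faces)
    print(embeddings.shape)  # expected: torch.Size([2, 512])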