| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| import timm |
|
|
| from transformers import PreTrainedModel |
|
|
| from .heads import ArcMarginProduct, ElasticArcFace, ArcFaceSubCenterDynamic |
| from .configuration_miewid import MiewIdNetConfig |
|
|
def weights_init_kaiming(m):
    """Kaiming-initialize Linear/Conv/BatchNorm modules in place.

    Intended for use with ``module.apply(weights_init_kaiming)``. Modules of
    any other type are left untouched.
    """
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        # Fix: Linear layers built with bias=False have m.bias is None; the
        # Conv branch already guarded for this, the Linear branch did not.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        # BatchNorm with affine=False has no learnable weight/bias to set.
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
|
|
|
|
def weights_init_classifier(m):
    """Initialize classifier Linear layers: N(0, 0.001) weights, zero bias.

    Intended for use with ``module.apply(weights_init_classifier)``.
    """
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        # Fix: `if m.bias:` evaluates tensor truthiness, which raises a
        # RuntimeError ("Boolean value of Tensor ... is ambiguous") for any
        # bias with more than one element; test against None instead.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
|
|
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over spatial dims with learnable exponent."""

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        # Learnable pooling exponent, initialized to `p`.
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        # Clamp keeps pow well-defined for zeros/negatives; average over the
        # full spatial extent, then undo the power.
        height, width = x.size(-2), x.size(-1)
        powered = x.clamp(min=eps).pow(p)
        pooled = F.avg_pool2d(powered, (height, width))
        return pooled.pow(1. / p)

    def __repr__(self):
        return '{}(p={:.4f}, eps={})'.format(
            self.__class__.__name__, self.p.data.tolist()[0], self.eps)
|
|
class MiewIdNet(PreTrainedModel):
    """Re-identification model: timm backbone + GeM pooling + BatchNorm neck,
    topped by a margin-based (arcface-style) classification head.

    ``forward`` returns embeddings only; ``extract_logits`` additionally
    applies the margin head (label-conditioned for the arcface variants).
    """
    config_class = MiewIdNetConfig

    def __init__(self, config):
        """Build backbone, pooling and loss head from a ``MiewIdNetConfig``.

        Only ``n_classes``, ``model_name``, ``loss_module`` and ``k`` are read
        from the config; the remaining hyper-parameters are fixed below.
        """
        super(MiewIdNet, self).__init__(config)
        print('Building Model Backbone for {} model'.format(config.model_name))
        print('config.model_name', config.model_name)

        n_classes=config.n_classes
        model_name=config.model_name
        use_fc=False  # hard-coded off: the fc/dropout branch below never runs
        fc_dim=512
        dropout=0.0
        loss_module=config.loss_module
        s=30.0  # arcface scale
        margin=0.50  # arcface margin
        ls_eps=0.0  # unused in this constructor
        theta_zero=0.785  # unused in this constructor
        pretrained=True
        # NOTE(review): margins is read from config.k (same value as k below),
        # not from a dedicated config.margins field — confirm this is intended.
        margins=config.k
        k=config.k  # sub-center count for arcface_subcenter_dynamic

        print('model_name', model_name)

        # num_classes=0 makes timm return pooled features (no classifier head).
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        # NOTE(review): hard-coded embedding width; assumes the configured
        # backbone produces 2152-dim pooled features — confirm for other
        # model_name values (timm exposes backbone.num_features for this).
        final_in_features = 2152

        print('final_in_features', final_in_features)

        # Swap the backbone's default pooling for learnable GeM pooling.
        self.backbone.global_pool = GeM()

        # BatchNorm "neck" applied to the pooled embedding.
        self.bn = nn.BatchNorm1d(final_in_features)
        self.use_fc = use_fc
        if use_fc:  # dead branch while use_fc is hard-coded False above
            self.dropout = nn.Dropout(p=dropout)
            self.bn = nn.BatchNorm1d(fc_dim)
            self.bn.bias.requires_grad_(False)
            self.fc = nn.Linear(final_in_features, n_classes, bias = False)
            self.bn.apply(weights_init_kaiming)
            self.fc.apply(weights_init_classifier)
            final_in_features = fc_dim

        self.loss_module = loss_module
        if loss_module == 'arcface':
            # 'arcface' maps to the Elastic variant of the head.
            self.final = ElasticArcFace(final_in_features, n_classes,
                                        s=s, m=margin)
        elif loss_module == 'arcface_subcenter_dynamic':
            if margins is None:
                # Default per-class margin when none are configured.
                margins = [0.3] * n_classes
            print(final_in_features, n_classes)
            self.final = ArcFaceSubCenterDynamic(
                embedding_dim=final_in_features,
                output_classes=n_classes,
                margins=margins,
                s=s,
                k=k )
        else:
            # Plain linear classifier fallback for any other loss_module.
            self.final = nn.Linear(final_in_features, n_classes)

    def _init_params(self):
        # NOTE(review): only valid when use_fc was True (self.fc exists), and
        # self.fc is built with bias=False, so constant_ on fc.bias would fail
        # if this were called — appears to be unused/legacy.
        nn.init.xavier_normal_(self.fc.weight)
        nn.init.constant_(self.fc.bias, 0)
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def forward(self, x, label=None):
        """Return embeddings for a batch of images; `label` is ignored here."""
        feature = self.extract_feat(x)
        return feature

    def extract_feat(self, x):
        """Backbone + GeM pooling + BatchNorm neck -> (batch, features) tensor."""
        batch_size = x.shape[0]
        x = self.backbone(x).view(batch_size, -1)

        x = self.bn(x)
        if self.use_fc:
            # NOTE(review): x1 is computed but never returned, and self.bn here
            # is the fc_dim BatchNorm reassigned in __init__ (width mismatch
            # with x) — dead code while use_fc is hard-coded False.
            x1 = self.dropout(x)
            x1 = self.bn(x1)
            x1 = self.fc(x1)

        return x

    def extract_logits(self, x, label=None):
        """Return classification logits; arcface-style heads require `label`."""
        feature = self.extract_feat(x)
        # Margin heads need the ground-truth label (assert strips under -O).
        assert label is not None
        if self.loss_module in ('arcface', 'arcface_subcenter_dynamic'):
            logits = self.final(feature, label)
        else:
            logits = self.final(feature)

        return logits