| import torch |
| import torch.nn as nn |
| import math |
| import torch.nn.functional as F |
| from torch.nn.parameter import Parameter |
|
|
class ArcMarginProduct(nn.Module):
    r"""Implementation of large-margin arc distance (ArcFace).

    Produces ``s * cos(theta + m)`` logits for the target class and
    ``s * cos(theta)`` for all other classes, where ``theta`` is the angle
    between the L2-normalized input feature and the class weight vector.

    Args:
        in_features: size of each input sample (embedding dimension).
        out_features: size of each output sample (number of classes).
        s: norm (scale) applied to the output logits.
        m: additive angular margin.
        easy_margin: if True, apply the margin only where cos(theta) > 0.
        ls_eps: label-smoothing epsilon applied to the one-hot targets.
    """
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        # Precompute constants for cos(theta + m) = cos*cos_m - sin*sin_m.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Threshold past which theta + m would exceed pi; beyond it a linear
        # penalty (cosine - mm) keeps the target logit monotonic in theta.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return (B, out_features) margin-adjusted, scaled logits.

        Args:
            input: (B, in_features) float embeddings.
            label: (B,) integer class indices.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Clamp to [0, 1] to avoid NaN from sqrt of a tiny negative value
        # when |cosine| is numerically just above 1.
        sine = torch.sqrt(torch.clamp(1.0 - torch.pow(cosine, 2), 0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # BUGFIX: allocate on the inputs' device instead of hard-coding
        # 'cuda', which crashed on CPU-only runs.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features

        # Margin only on the target class; plain cosine elsewhere.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output
| |
|
|
def l2_norm(input, axis = 1):
    """Return ``input`` L2-normalized along ``axis`` (shape preserved)."""
    magnitude = input.norm(2, axis, True)
    return input / magnitude
class ElasticArcFace(nn.Module):
    """ElasticFace-Arc style head: ArcFace whose additive angular margin is
    drawn per sample from N(m, std^2) instead of being fixed. With
    ``plus=True`` the sampled margins are sorted so that samples with a
    higher target-class cosine receive the larger margins.

    Args:
        in_features: embedding dimension of the inputs.
        out_features: number of classes.
        s: scale applied to the output logits.
        m: mean of the sampled angular margin.
        std: standard deviation of the sampled margin.
        plus: enable the sorted-margin ("plus") variant.
        k: unused; accepted only for signature compatibility with callers.
    """
    def __init__(self, in_features, out_features, s=64.0, m=0.50,std=0.0125,plus=False, k=None):
        super(ElasticArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # Class-center matrix stored as (in_features, out_features) so the
        # forward pass is a plain matmul after column-normalization.
        self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features))
        nn.init.normal_(self.kernel, std=0.01)
        self.std=std
        self.plus=plus
    def forward(self, embbedings, label):
        """Return scaled cos(theta + margin) logits.

        Args:
            embbedings: (B, in_features) float embeddings.
            label: (B,) integer class indices; entries equal to -1 are
                skipped (no margin applied to those rows).
        """
        embbedings = l2_norm(embbedings, axis=1)
        kernel_norm = l2_norm(self.kernel, axis=0)
        # Cosine similarity between each embedding and each class center.
        cos_theta = torch.mm(embbedings, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # numerical safety for acos_ below
        # Rows that actually have a label (-1 marks "ignore").
        index = torch.where(label != -1)[0]
        m_hot = torch.zeros(index.size()[0], cos_theta.size()[1], device=cos_theta.device)
        # One random margin ~ N(m, std^2) per labeled row.
        margin = torch.normal(mean=self.m, std=self.std, size=label[index, None].size(), device=cos_theta.device)
        if self.plus:
            with torch.no_grad():
                # Rank labeled samples by their target-class cosine and give
                # the largest margins to the closest (easiest) samples.
                # NOTE(review): indexes with the full label vector, not
                # label[index] — assumes no -1 labels reach this branch; verify.
                distmat = cos_theta[index, label.view(-1)].detach().clone()
                _, idicate_cosie = torch.sort(distmat, dim=0, descending=True)
                margin, _ = torch.sort(margin, dim=0)
            m_hot.scatter_(1, label[index, None], margin[idicate_cosie])
        else:
            m_hot.scatter_(1, label[index, None], margin)
        # theta = acos(cos); add the margin only at target entries, then map
        # back through cos and scale. All three ops mutate cos_theta in place.
        cos_theta.acos_()
        cos_theta[index] += m_hot
        cos_theta.cos_().mul_(self.s)
        return cos_theta
| |
| |
|
|
class ArcMarginProduct_subcenter(nn.Module):
    """Sub-center ArcFace head.

    Keeps ``k`` weight vectors ("sub-centers") per class and returns, for
    each class, the largest cosine similarity between the input feature and
    any of that class's sub-centers.
    """
    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        # One row per sub-center: class c owns rows [c*k, (c+1)*k).
        self.weight = nn.Parameter(torch.FloatTensor(out_features*k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)]."""
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)

    def forward(self, features):
        # Cosine similarity against every sub-center ...
        sims = F.linear(F.normalize(features), F.normalize(self.weight))
        # ... regrouped per class, then reduced over the k sub-centers.
        sims = sims.view(-1, self.out_features, self.k)
        best, _ = sims.max(dim=2)
        return best
| |
class ArcFaceLossAdaptiveMargin(nn.Module):
    """ArcFace logit adjustment with a per-class adaptive margin.

    Each class ``c`` has its own margin ``margins[c]``; the margin applied
    to a sample is the margin of that sample's ground-truth class.

    Args:
        margins: per-class margins, length ``out_dim``.
        out_dim: number of classes.
        s: scale applied to the output logits.
    """
    def __init__(self, margins, out_dim, s):
        super().__init__()
        self.s = s
        # Buffer (not Parameter): moves with the module across devices but
        # is not trained.
        self.register_buffer('margins', torch.tensor(margins))
        self.out_dim = out_dim

    def forward(self, logits, labels):
        """Apply per-sample margins to cosine logits and scale by s.

        Args:
            logits: (B, out_dim) cosine similarities.
            labels: (B,) integer ground-truth class indices.
        Returns:
            (B, out_dim) margin-adjusted logits, scaled by ``s``.
        """
        # Per-sample margin: the margin of each sample's target class.
        ms = self.margins[labels]
        cos_m = torch.cos(ms)
        sin_m = torch.sin(ms)
        th = torch.cos(math.pi - ms)
        mm = torch.sin(math.pi - ms) * ms
        labels = F.one_hot(labels, self.out_dim).float()
        cosine = logits
        # BUGFIX: clamp to >= 0 so sqrt cannot produce NaN when the cosine
        # is numerically just outside [-1, 1] (logits are not clamped by
        # the upstream sub-center head).
        sine = torch.sqrt(torch.clamp(1.0 - cosine * cosine, min=0))
        # cos(theta + m); linear fallback where theta + m would exceed pi.
        phi = cosine * cos_m.view(-1,1) - sine * sin_m.view(-1,1)
        phi = torch.where(cosine > th.view(-1,1), phi, cosine - mm.view(-1,1))
        # Margin only on the target class; plain cosine elsewhere.
        output = (labels * phi) + ((1.0 - labels) * cosine)
        output *= self.s
        return output
| |
class ArcFaceSubCenterDynamic(nn.Module):
    """Sub-center ArcFace classifier with per-class adaptive margins.

    Pipeline: features -> max-over-sub-center cosine logits
    (ArcMarginProduct_subcenter) -> margin-adjusted, scaled logits
    (ArcFaceLossAdaptiveMargin).
    """

    def __init__(self, embedding_dim, output_classes, margins, s, k=2):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.output_classes = output_classes
        self.margins = margins
        self.s = s
        # Cosine logits: best similarity over the k sub-centers per class.
        self.wmetric_classify = ArcMarginProduct_subcenter(
            self.embedding_dim, self.output_classes, k=k
        )
        # Adds each sample's target-class margin and scales by s.
        self.warcface_margin = ArcFaceLossAdaptiveMargin(
            margins=self.margins, out_dim=self.output_classes, s=self.s
        )

    def forward(self, features, labels):
        cosine_logits = self.wmetric_classify(features.float())
        return self.warcface_margin(cosine_logits, labels)