# speaker_verification/lossfunction/AdditiveAngularMargin.py
# xiaoxuezi's picture
# 2
# 875baeb
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy
import math
from utils.acc import accuracy
class AdditiveAngularMargin(nn.Module):
    """Additive angular margin softmax loss head.

    Input features and a learnable class-weight matrix are both
    L2-normalised, so their product is a matrix of cosines. For the target
    class of each sample the cosine ``cos(theta)`` is replaced by
    ``cos(theta + margin)``; all cosines are then multiplied by ``scale`` and
    fed to a log-softmax / negative log-likelihood loss.

    Args:
        feature_dim: Size of the input feature vectors (rows of ``self.w``).
        n_classes: Number of classes (columns of ``self.w``).
        margin: Angular margin ``m`` in radians added to the target angle.
        scale: Factor the cosine logits are multiplied by before softmax.
        easy_margin: If true, the margin is applied only where the cosine is
            positive and the plain cosine is used elsewhere. If false, the
            margin is applied where ``cosine > cos(pi - margin)`` and
            ``cosine - sin(pi - margin) * margin`` is used elsewhere.
    """

    # Floor for 1 - cos^2 before the square root. Rounding can push |cos|
    # marginally above 1 (NaN from sqrt of a negative), and sqrt has an
    # infinite derivative at exactly 0; a small positive floor avoids both.
    _SIN_SQ_FLOOR = 1e-7

    def __init__(self,
                 feature_dim=256,
                 n_classes=1000,
                 margin=0.2,
                 scale=30,
                 easy_margin=False):
        super().__init__()
        self.margin = margin
        self.scale = scale
        self.easy_margin = easy_margin
        self.n_classes = n_classes

        # Class weights, one column per class; float32 as in the legacy
        # torch.FloatTensor constructor this replaces.
        self.w = nn.Parameter(
            torch.empty(feature_dim, n_classes, dtype=torch.float32))
        nn.init.xavier_normal_(self.w)

        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(self.margin)
        self.sin_m = math.sin(self.margin)
        # Cosine threshold and linear penalty used when easy_margin is off.
        self.th = math.cos(math.pi - self.margin)
        self.mm = math.sin(math.pi - self.margin) * self.margin

        self.nll_loss = nn.NLLLoss()
        # Not read anywhere in this class; presumably consumed by external
        # evaluation code -- TODO confirm against callers.
        self.test_normalize = True

    def _margin_logits(self, features, targets):
        """Return scaled cosine logits with the margin applied to the target class."""
        features = F.normalize(features, p=2, dim=1, eps=1e-8)
        class_dirs = F.normalize(self.w, p=2, dim=0, eps=1e-8)
        cosine = features @ class_dirs

        sin_sq = torch.clamp(1.0 - torch.square(cosine),
                             min=self._SIN_SQ_FLOOR)
        sine = torch.sqrt(sin_sq)
        with_margin = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)

        if self.easy_margin:
            with_margin = torch.where(cosine > 0, with_margin, cosine)
        else:
            with_margin = torch.where(cosine > self.th,
                                      with_margin,
                                      cosine - self.mm)

        # Select by boolean mask rather than one-hot arithmetic so values at
        # non-target positions are never multiplied into the result.
        is_target = F.one_hot(targets, self.n_classes).bool()
        return self.scale * torch.where(is_target, with_margin, cosine)

    def forward(self, logits, targets):
        """Compute the margin loss and the top-1 accuracy metric.

        Args:
            logits: Input feature tensor; it is normalised along ``dim=1`` and
                multiplied with the ``(feature_dim, n_classes)`` weights, so
                it is expected to be ``(batch, feature_dim)``.
            targets: Integer class indices, one per row of ``logits``.

        Returns:
            ``(nloss, prec1)``: the NLL loss over the log-softmax of the
            margin logits, and element 0 of what the project's ``accuracy``
            helper returns for ``topk=(1,)`` (presumably top-1 accuracy).
        """
        scaled = self._margin_logits(logits, targets)
        pred = F.log_softmax(scaled, dim=-1)
        nloss = self.nll_loss(pred, targets)
        prec1 = accuracy(pred.detach(), targets.detach(), topk=(1,))[0]
        return nloss, prec1