import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy
from utils.acc import accuracy


class Ge2e(nn.Module):
    """GE2E-style softmax loss over embeddings grouped by speaker.

    Each utterance embedding is compared (cosine similarity) with every
    speaker centroid. For its own speaker, the centroid is computed without
    that utterance (leave-one-out). Similarities are mapped to logits as
    ``w * cos + b``, with learnable ``w`` (kept positive) and ``b``. The
    logits are scored with cross-entropy against the utterance's speaker index.
    """

    def __init__(self, init_w=10.0, init_b=-5.0, **kwargs):
        """Create the loss module.

        Args:
            init_w: Initial value of the learnable similarity scale ``w``.
            init_b: Initial value of the learnable similarity bias ``b``.
            **kwargs: Accepted and ignored (presumably so a shared loss config
                can be passed to every loss class -- confirm with callers).
        """
        super().__init__()
        # Not read in this file; presumably consumed by the training/eval
        # harness -- NOTE(review): confirm.
        self.test_normalize = True
        self.w = nn.Parameter(torch.tensor(init_w))
        self.b = nn.Parameter(torch.tensor(init_b))
        self.criterion = torch.nn.CrossEntropyLoss()
        print('Initialised GE2E')

    def forward(self, x, label=None):
        """Compute the loss and top-1 accuracy for one speaker-grouped batch.

        Args:
            x: Tensor of shape ``(num_speakers, utts_per_speaker, dim)``.
                Index ``j`` along dim 0 is treated as class ``j``. Each entry
                along dim 1 is one utterance embedding of that class.
            label: Ignored. Targets are derived from the position along
                dim 0, so the batch must be grouped by speaker. The parameter
                is kept for interface compatibility.

        Returns:
            ``(nloss, prec1)``: the scalar cross-entropy loss, and element 0
            of ``utils.acc.accuracy(..., topk=(1,))``.

        Raises:
            ValueError: If ``x`` is not 3-D, or has fewer than two utterances
                per speaker (the leave-one-out centroid needs at least one
                other utterance).
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape (speakers, utterances, dim), got {tuple(x.shape)}"
            )
        num_spk, num_utt, _ = x.shape
        if num_utt < 2:
            raise ValueError(
                f"need at least 2 utterances per speaker, got {num_utt}"
            )

        centroids = torch.mean(x, 1)  # (num_spk, dim)
        # Speaker indices on the input's device: used both for the diagonal
        # and as int64 targets, as CrossEntropyLoss requires.
        spk_idx = torch.arange(num_spk, device=x.device)

        per_utt_sims = []
        for ii in range(num_utt):
            utt = x[:, ii, :]  # utterance ii of every speaker, (num_spk, dim)
            # Own-speaker centroid excluding utterance ii, so an embedding is
            # never compared against a centroid that contains itself.
            others = [k for k in range(num_utt) if k != ii]
            exc_centroids = torch.mean(x[:, others, :], 1)  # (num_spk, dim)

            # (num_spk, dim, 1) vs (1, dim, num_spk) broadcast over dim=1
            # gives cos_sim[j, k] = cos(utterance ii of speaker j, centroid k).
            cos_sim = F.cosine_similarity(
                utt.unsqueeze(-1), centroids.unsqueeze(-1).transpose(0, 2)
            )
            # Replace the own-speaker entries with the leave-one-out similarity.
            cos_sim[spk_idx, spk_idx] = F.cosine_similarity(utt, exc_centroids)
            # NOTE(review): flooring similarities at 1e-6 gives zero gradient
            # to every negative similarity. Kept as-is to preserve existing
            # loss values; confirm it is intended.
            per_utt_sims.append(torch.clamp(cos_sim, 1e-6))

        # (num_spk, num_utt, num_spk): similarity of utterance (j, ii) to each centroid.
        cos_sim_matrix = torch.stack(per_utt_sims, dim=1)

        # Keep the scale positive so a higher similarity always means a higher
        # logit. torch.clamp is not in-place, so the clamped value must be used.
        w = torch.clamp(self.w, min=1e-6)
        logits = (cos_sim_matrix * w + self.b).view(-1, num_spk)

        # Row r of `logits` is utterance (r % num_utt) of speaker (r // num_utt).
        target = torch.repeat_interleave(spk_idx, repeats=num_utt, dim=0)
        nloss = self.criterion(logits, target)
        prec1 = accuracy(logits.detach(), target, topk=(1,))[0]
        return nloss, prec1


if __name__ == "__main__":
    # Smoke test: 32 speakers x 10 utterances x 512-dim embeddings.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(32, 10, 512, device=device)
    y = torch.randint(1000, size=(32,), device=device)
    print(x.shape, y.shape)
    loss = Ge2e().to(device)
    nloss, prec1 = loss(x, y)
    print(nloss, prec1)