#! /usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
import lossfunction.softmax as softmax
import lossfunction.angleproto as angleproto


class SoftmaxProto(nn.Module):
    """Combined loss: a softmax classification term plus an angular-prototypical term.

    Both sub-losses come from project modules (``lossfunction.softmax.Softmax``
    and ``lossfunction.angleproto.AngleProto``). Their exact formulation lives
    there and is not visible here.
    """

    # Number of utterances per speaker that ``forward`` groups embeddings into
    # (the original code hard-coded this as the literal 2, "2 is nPerSpeaker").
    N_PER_SPEAKER = 2

    def __init__(self, nOut, nClasses):
        """Build the two sub-losses.

        Args:
            nOut: first argument forwarded to ``softmax.Softmax``
                (presumably the embedding size -- confirm in lossfunction.softmax).
            nClasses: second argument forwarded to ``softmax.Softmax``
                (presumably the number of speaker classes).
        """
        super().__init__()

        # Flag read elsewhere; presumably tells the trainer to normalise
        # embeddings at evaluation time -- TODO confirm against the caller.
        self.test_normalize = True

        self.softmax = softmax.Softmax(nOut, nClasses)
        self.angleproto = angleproto.AngleProto()

        print('Initialised SoftmaxPrototypical Loss')

    def forward(self, x, label=None):
        """Return the sum of the softmax loss and the prototypical loss.

        Args:
            x: embeddings. If ``x.size(1)`` already equals ``N_PER_SPEAKER``,
                ``x`` is used as-is (expected ``(n_groups, N_PER_SPEAKER, dim)``).
                Otherwise consecutive rows are regrouped with
                ``reshape(-1, N_PER_SPEAKER, dim)``, so the element count must
                be divisible accordingly.
            label: one label per group. Each label is repeated
                ``N_PER_SPEAKER`` times so that every utterance in a group gets
                its group's label for the softmax term. Required despite the
                ``None`` default, which is kept only for signature compatibility.

        Returns:
            ``nlossS + nlossP``. The softmax accuracy (``prec1``) is computed
            by the softmax head but intentionally not returned.

        Raises:
            ValueError: if ``label`` is None or ``x`` has fewer than 2 dimensions.
        """
        if label is None:
            # Previously this surfaced as an opaque AttributeError on
            # ``None.repeat_interleave``.
            raise ValueError('SoftmaxProto.forward requires `label` for its softmax term')
        if x.dim() < 2:
            raise ValueError(
                'SoftmaxProto.forward expects x with at least 2 dims, got shape %s'
                % (tuple(x.shape),)
            )

        n_per = self.N_PER_SPEAKER
        if x.size(1) != n_per:
            # Regroup consecutive embeddings into groups of n_per. After this,
            # x.size(1) == n_per is guaranteed (reshape raises otherwise), so no
            # further check is needed. A trailing .squeeze(1) was dropped: it
            # was a no-op because dim 1 has size n_per (2), not 1.
            x = x.reshape(-1, n_per, x.size(-1))

        # Softmax term: flatten groups back to one row per utterance; rows of
        # the same group are adjacent, matching repeat_interleave's ordering.
        nlossS, _prec1 = self.softmax(x.reshape(-1, x.size(-1)), label.repeat_interleave(n_per))

        # Prototypical term works on the grouped tensor; it takes no labels here.
        nlossP, _ = self.angleproto(x, None)

        # NOTE: an earlier debug run recorded lossP ~ 0.67 vs lossS ~ 13.69, so
        # the softmax term may dominate the sum -- consider weighting if needed.
        return nlossS + nlossP