import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy

from utils.acc import accuracy


class proto(nn.Module):
    """Prototypical loss over a batch of per-class embedding groups.

    For every group in the batch, embedding 0 is the query and the mean of
    the remaining embeddings is the prototype. Each query is classified
    against all prototypes in the batch with cross-entropy, using the
    negative squared Euclidean distance as the logit. Query ``i`` is expected
    to match prototype ``i``.
    """

    def __init__(self, **kwargs):
        """Build the loss.

        Args:
            **kwargs: Accepted and ignored. Presumably this lets the class be
                built from a shared config dict; confirm against callers.
        """
        super(proto, self).__init__()

        # Not read anywhere in this module; presumably consulted by external
        # evaluation code. TODO confirm.
        self.test_normalize = False
        self.criterion = torch.nn.CrossEntropyLoss()
        print('Initialised Prototypical Loss')

    def forward(self, x, label=None):
        """Compute the prototypical loss and top-1 accuracy.

        Args:
            x: Tensor of shape ``(N, M, D)``. ``N`` is the number of groups,
                ``M >= 2`` is the number of embeddings per group and ``D`` is
                the embedding size.
            label: Ignored. The targets are implied by batch order (query
                ``i`` belongs to prototype ``i``).

        Returns:
            Tuple ``(nloss, prec1)``: the cross-entropy loss, and the first
            element returned by ``accuracy(..., topk=(1,))``.
        """
        assert x.size()[1] >= 2, 'need at least 2 embeddings per group (query + support)'

        # Prototype = mean of embeddings 1..M-1; query = embedding 0.
        out_anchor = torch.mean(x[:, 1:, :], 1)
        out_positive = x[:, 0, :]
        stepsize = out_anchor.size()[0]

        # logits[i, j] = -||query_i - prototype_j + eps||^2, shape (N, N).
        # This is computed explicitly instead of calling F.pairwise_distance
        # on (N, D, 1) / (1, D, N) tensors. Older PyTorch reduced that call
        # over dim 1 (giving (N, N)), while newer PyTorch reduces over the
        # last dim (giving (N, D)) and silently yields the wrong logits.
        # eps = 1e-6 mirrors pairwise_distance's default.
        diff = out_positive.unsqueeze(1) - out_anchor.unsqueeze(0) + 1e-6
        output = -1 * diff.pow(2).sum(dim=-1)

        # Targets are 0..N-1 (the diagonal of the logit matrix). They are
        # created on the logits' device, so CPU runs work too, and as int64,
        # which CrossEntropyLoss requires for class indices.
        label = torch.arange(stepsize, device=output.device)

        nloss = self.criterion(output, label)
        prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]

        return nloss, prec1


if __name__ == "__main__":
    # Smoke test on CPU: 10 groups, 10 embeddings per group, 512-dim each.
    x = torch.randn(10, 10, 512)
    loss = proto()
    nloss, prec1 = loss(x)
    print(nloss, prec1)