Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy | |
| from utils.acc import accuracy | |
class proto(nn.Module):
    """Prototypical loss over a batch of per-class embedding groups.

    For every row ``i`` of the input, ``x[i, 0, :]`` is used as the query
    ("positive") and the mean of ``x[i, 1:, :]`` as the prototype ("anchor").
    Each query is scored against every prototype by negative squared
    Euclidean distance. The targets are ``0 .. batch-1``, so the correct
    prototype for query ``i`` is column ``i``: training pulls matching
    query/prototype pairs (the diagonal) together.
    """

    def __init__(self, **kwargs):
        # kwargs are accepted but ignored — presumably so the loss can be
        # built from a shared config like other loss modules; confirm with caller.
        super(proto, self).__init__()
        self.test_normalize = False
        self.criterion = torch.nn.CrossEntropyLoss()
        print('Initialised Prototypical Loss')

    @staticmethod
    def _neg_sq_distances(positive, anchor, eps=1e-6):
        """Return a ``(B, B)`` matrix with entry ``[i, j] = -||positive[i] - anchor[j] + eps||**2``.

        This is what
        ``-F.pairwise_distance(positive.unsqueeze(-1), anchor.unsqueeze(-1).transpose(0, 2)) ** 2``
        yields when ``pairwise_distance`` reduces over dim 1. It is written out
        explicitly because PyTorch releases that reduce over the last dim turn
        that expression into a ``(B, D)`` tensor instead. ``eps`` mirrors
        ``pairwise_distance``'s default.

        Args:
            positive: 2-D tensor of queries, one per row.
            anchor: 2-D tensor of prototypes with the same trailing size.
            eps: constant added to each difference (numerical safety).
        """
        diff = positive.unsqueeze(1) - anchor.unsqueeze(0) + eps  # (B_pos, B_anc, D)
        return -diff.pow(2).sum(dim=-1)

    def forward(self, x, label=None):
        """Compute the prototypical loss and accuracy for one batch.

        Args:
            x: tensor indexed as ``x[:, 0, :]`` and ``x[:, 1:, :]``; presumably
                ``(batch, items_per_class, embed_dim)`` — confirm with caller.
                Must have at least 2 entries along dim 1.
            label: ignored. The targets are always ``0 .. batch-1``. The
                parameter is kept for interface compatibility.

        Returns:
            ``(nloss, prec1)``: the cross-entropy loss, and the first element of
            ``accuracy(..., topk=(1,))`` (presumably top-1 precision).
        """
        assert x.size()[1] >= 2

        out_anchor = torch.mean(x[:, 1:, :], 1)  # one prototype per row
        out_positive = x[:, 0, :]                # one query per row
        stepsize = out_anchor.size()[0]

        # logits[i, j] = similarity of query i to prototype j; the diagonal is
        # the correct class.
        output = self._neg_sq_distances(out_positive, out_anchor)

        # Build the targets on the input's device (not a hard-coded .cuda(),
        # which fails on CPU-only hosts). arange yields int64, as
        # CrossEntropyLoss expects.
        label = torch.arange(stepsize, device=x.device)
        nloss = self.criterion(output, label)
        prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
        return nloss, prec1
if __name__ == "__main__":
    # Scratch check: apply F.pairwise_distance to two equally shaped 3-D
    # integer tensors and print the result; the printed shape reveals which
    # dimension is reduced. Inputs are random, so values differ per run.
    shape = (10, 512, 10)
    first, second = (torch.randint(10, shape) for _ in range(2))
    dist = F.pairwise_distance(first, second)
    print(dist)
    print(dist.shape)