# xiaoxuezi's picture
# 2
# 875baeb
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy
from utils.acc import accuracy
class proto(nn.Module):
    """Prototypical loss over a batch of embedding groups.

    For each group ``i`` along dim 0 of the input, ``x[i, 0, :]`` is used as
    the query and the mean of ``x[i, 1:, :]`` as the prototype. Every query
    is scored against every prototype by negative squared Euclidean
    distance, and cross-entropy drives query ``i`` towards prototype ``i``
    (the diagonal of the score matrix).
    """

    def __init__(self, **kwargs):
        # ``kwargs`` is accepted and ignored; this loss has no options.
        super().__init__()
        self.test_normalize = False
        self.criterion = torch.nn.CrossEntropyLoss()
        print('Initialised Prototypical Loss')

    def forward(self, x, label=None):
        """Compute the prototypical loss and top-1 accuracy.

        Args:
            x: 3-D tensor indexed as ``x[group, sample, feature]``. Sample 0
                of each group is the query; samples ``1:`` are averaged into
                that group's prototype. At least two samples per group are
                required.
            label: ignored. The targets are always ``0..N-1`` (query ``i``
                must match prototype ``i``); the argument is kept only for
                interface compatibility.

        Returns:
            ``(nloss, prec1)``: the cross-entropy loss and the first entry
            returned by ``accuracy(..., topk=(1,))``.
        """
        assert x.size()[1] >= 2

        out_anchor = torch.mean(x[:, 1:, :], 1)   # (N, D) prototypes
        out_positive = x[:, 0, :]                 # (N, D) queries
        stepsize = out_anchor.size()[0]

        # Score matrix output[i, j] = -||query_i - prototype_j||^2, shape (N, N).
        # Matching query/prototype pairs sit on the diagonal and should score
        # highest. This is computed explicitly instead of with
        # F.pairwise_distance on (N, D, 1) vs (1, D, N) tensors: current
        # PyTorch reduces that call over the last dimension and returns an
        # (N, D) tensor rather than the (N, N) matrix. Squaring directly also
        # avoids a sqrt followed by a square.
        diff = out_positive.unsqueeze(1) - out_anchor.unsqueeze(0)   # (N, N, D)
        output = -diff.pow(2).sum(dim=-1)

        # Target for query i is prototype i. Create it on the scores' device
        # rather than hard-coding CUDA so the loss also runs on CPU.
        label = torch.arange(stepsize, device=output.device)

        nloss = self.criterion(output, label)
        prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
        return nloss, prec1
if __name__ == "__main__":
    # Scratch probe: apply F.pairwise_distance to two random 3-D integer
    # tensors and print the result and its shape, to see which dimension
    # the distance is reduced over.
    probe_shape = (10, 512, 10)
    lhs = torch.randint(10, probe_shape)
    rhs = torch.randint(10, probe_shape)
    dist = F.pairwise_distance(lhs, rhs)
    print(dist)
    print(dist.shape)