File size: 1,580 Bytes
875baeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy
from utils.acc import accuracy

class proto(nn.Module):
    """Prototypical loss over a batch of grouped embeddings.

    For every group ``i`` of ``x``, item 0 is used as the query and the mean of
    the remaining items as the prototype. Each query is scored against every
    prototype by negative squared Euclidean distance. The resulting ``(N, N)``
    score matrix goes through cross-entropy with target ``i`` for row ``i``, so
    query ``i`` is pushed towards its own prototype and away from the others.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored; presumably so a
        # shared config can be passed to every loss class (TODO confirm).
        super().__init__()

        self.test_normalize = False

        self.criterion = torch.nn.CrossEntropyLoss()

        print('Initialised Prototypical Loss')

    def forward(self, x, label=None):
        """Compute the prototypical loss and top-1 accuracy for one batch.

        Args:
            x: Rank-3 tensor indexed as ``x[group, item, feature]`` with at
                least two items per group (one query plus at least one support
                item). Presumably ``(speakers, utterances, embed_dim)``; TODO
                confirm against the caller.
            label: Ignored. The targets are always ``0..N-1`` because group
                ``i``'s query must match group ``i``'s prototype.

        Returns:
            Tuple ``(nloss, prec1)``: the cross-entropy loss, and element 0 of
            what ``accuracy(..., topk=(1,))`` returns.

        Raises:
            AssertionError: if ``x`` has fewer than two items per group.
        """
        assert x.size()[1] >= 2, "proto loss needs at least 2 items per group"

        # Prototype: mean of items 1..K-1. Query: item 0 of each group.
        out_anchor = torch.mean(x[:, 1:, :], 1)
        out_positive = x[:, 0, :]
        stepsize = out_anchor.size()[0]

        # output[i, j] = -||out_positive[i] - out_anchor[j]||^2, shape (N, N),
        # so row i should peak on the diagonal. This is computed explicitly
        # instead of via F.pairwise_distance on (N, D, 1) / (1, D, N) inputs.
        # Older PyTorch reduced that call over dim 1 (giving N x N), whereas
        # current releases reduce over the last dim and silently give (N, D).
        # No sqrt is taken, so pairwise_distance's eps guard is not needed.
        diff = out_positive.unsqueeze(1) - out_anchor.unsqueeze(0)
        output = -diff.pow(2).sum(dim=-1)

        # Row i's correct class is column i. The targets are created on x's
        # device, so CPU and any GPU both work. torch.arange yields int64, which
        # CrossEntropyLoss requires for class-index targets.
        label = torch.arange(stepsize, device=x.device)
        nloss = self.criterion(output, label)
        prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]

        return nloss, prec1


if __name__ == "__main__":
    # Ad-hoc check: print what F.pairwise_distance returns for two random
    # integer tensors of the same 3-D shape, and the shape of that result.
    demo_shape = (10, 512, 10)
    lhs = torch.randint(10, demo_shape)
    rhs = torch.randint(10, demo_shape)
    dist = F.pairwise_distance(lhs, rhs)
    for item in (dist, dist.shape):
        print(item)