File size: 834 Bytes
875baeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
import lossfunction.aamsoftmax as aamsoftmax
import lossfunction.angleproto as angleproto


class AamSoftmaxProto(nn.Module):
    """Sum of an AAM-softmax classification loss and an angular prototypical loss.

    ``forward`` expects exactly two utterances per class in each batch. The
    flattened utterances are scored by ``aamsoftmax.AamSoftmax`` against
    per-utterance labels. The un-flattened batch is scored by
    ``angleproto.AngleProto``. The two losses are added with equal weight.
    """

    def __init__(self, nOut, nClasses, margin, scale):
        """Build both sub-losses.

        Args:
            nOut, nClasses, margin, scale: forwarded unchanged, in this order,
                to ``aamsoftmax.AamSoftmax``. Presumably these are the
                embedding size, the number of training classes, the additive
                angular margin and the logit scale — confirm against that class.
        """
        super().__init__()

        # NOTE(review): not read in this module. Presumably consulted by the
        # evaluation code to decide whether embeddings are L2-normalised at
        # test time — confirm against the caller.
        self.test_normalize = True

        self.aamsoftmax = aamsoftmax.AamSoftmax(nOut, nClasses, margin, scale)
        self.angleproto = angleproto.AngleProto()

        print('Initialised AamSoftmaxPrototypical Loss')

    def forward(self, x: torch.Tensor, label: "torch.Tensor | None" = None):
        """Compute the combined loss for a batch of utterance pairs.

        Args:
            x: embeddings of shape ``(batch, 2, dim)``. ``x[i, 0]`` and
                ``x[i, 1]`` are two utterances that share ``label[i]``.
            label: tensor of ``batch`` class labels, one per pair. It is
                required despite the ``None`` default, which is kept only for
                interface compatibility.

        Returns:
            ``(loss, prec1)``. ``loss`` is the AAM-softmax loss plus the
            prototypical loss. ``prec1`` is the second value returned by
            ``AamSoftmax`` (presumably top-1 accuracy — confirm).

        Raises:
            ValueError: if ``x`` is not 3-D with exactly two utterances per
                class, or if ``label`` is ``None``.
        """
        # Use explicit checks instead of ``assert``: asserts vanish under
        # ``python -O``, and a missing label would otherwise surface as an
        # opaque AttributeError on ``None``.
        if x.dim() != 3 or x.size(1) != 2:
            raise ValueError(
                f"expected x of shape (batch, 2, dim), got {tuple(x.shape)}"
            )
        if label is None:
            raise ValueError("AamSoftmaxProto requires class labels")

        n_per_class = x.size(1)
        # Row-major flattening yields [x[0,0], x[0,1], x[1,0], x[1,1], ...].
        # Each label is therefore repeated in place so it stays aligned with
        # its two utterances.
        flat = x.reshape(-1, x.size(-1))
        nlossS, prec1 = self.aamsoftmax(flat, label.repeat_interleave(n_per_class))

        # The prototypical term keeps the pair structure and is called with
        # None for its label argument.
        nlossP, _ = self.angleproto(x, None)
        # Development note: one logged step showed lossP ~ 0.6678 vs
        # nlossS ~ 13.6913, i.e. the softmax term dominated the unweighted sum.

        return nlossS + nlossP, prec1