Spaces:
Runtime error
Runtime error
File size: 834 Bytes
875baeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
import torch.nn as nn
import lossfunction.aamsoftmax as aamsoftmax
import lossfunction.angleproto as angleproto
class AamSoftmaxProto(nn.Module):
    """Combined loss: ``AamSoftmax`` classification loss + ``AngleProto`` loss.

    The two project losses are added with equal (unit) weight. ``forward``
    expects every batch row to hold exactly two embeddings of the same class.
    """

    def __init__(self, nOut, nClasses, margin, scale):
        """Build the two component losses.

        Args:
            nOut: forwarded to ``AamSoftmax`` (presumably the embedding size;
                verify against ``lossfunction.aamsoftmax``).
            nClasses: forwarded to ``AamSoftmax`` (presumably the number of
                training classes).
            margin: forwarded to ``AamSoftmax``.
            scale: forwarded to ``AamSoftmax``.
        """
        super().__init__()
        # NOTE(review): not read anywhere in this class; presumably consumed by
        # the training/eval harness to decide whether to L2-normalise
        # embeddings at test time -- confirm with the caller.
        self.test_normalize = True
        self.aamsoftmax = aamsoftmax.AamSoftmax(nOut, nClasses, margin, scale)
        self.angleproto = angleproto.AngleProto()
        print('Initialised AamSoftmaxPrototypical Loss')

    def forward(self, x, label=None):
        """Return ``(aam_loss + proto_loss, prec1)``.

        Args:
            x: tensor of shape ``(batch, 2, D)`` -- two embeddings per row.
            label: one class label per row of ``x`` (length ``batch``).
                Required despite the ``None`` default; the default is kept
                only for signature compatibility.

        Returns:
            Tuple of the summed loss and the second value returned by
            ``AamSoftmax`` (presumably a precision/accuracy figure).

        Raises:
            ValueError: if ``label`` is None or ``x`` is not ``(batch, 2, D)``.
        """
        if label is None:
            raise ValueError("AamSoftmaxProto.forward requires class labels")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        # 3-D is required: flattening below yields batch*2 rows only then, which
        # is what the 2x-repeated labels must line up with.
        if x.dim() != 3 or x.size(1) != 2:
            raise ValueError(
                f"expected x of shape (batch, 2, D), got {tuple(x.shape)}"
            )
        per_row = x.size(1)
        # reshape keeps row-major order (row0/emb0, row0/emb1, row1/emb0, ...),
        # so repeat_interleave gives each flattened embedding its row's label.
        flat = x.reshape(-1, x.size(-1))
        nlossS, prec1 = self.aamsoftmax(flat, label.repeat_interleave(per_row))
        nlossP, _ = self.angleproto(x, None)
        # Unweighted sum. A past debug run recorded lossP ~0.67 vs
        # nlossS ~13.69, so the softmax term dominates the total.
        return nlossS + nlossP, prec1