File size: 549 Bytes
875baeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torch.nn as nn
from utils.acc import accuracy


class Softmax(nn.Module):
    """Plain softmax classification head trained with cross-entropy.

    A single linear layer maps ``nOut``-dimensional inputs to ``nClasses``
    logits; ``forward`` returns the cross-entropy loss together with the
    top-1 accuracy computed by ``utils.acc.accuracy``.

    Args:
        nOut: Size of the last dimension of the input features.
        nClasses: Number of target classes (output size of the linear layer).
    """

    def __init__(self, nOut, nClasses):
        super().__init__()

        # Not read anywhere in this class. NOTE(review): presumably consulted
        # by the surrounding trainer to decide whether embeddings are
        # normalised at evaluation time -- confirm against the caller.
        self.test_normalize = True

        # Default reduction='mean', so the loss is averaged over the batch.
        self.criterion = torch.nn.CrossEntropyLoss()
        self.fc = nn.Linear(nOut, nClasses)

        print('Initialised Softmax Loss')

    def forward(self, x, label=None):
        """Project ``x`` to class logits and score them against ``label``.

        Args:
            x: Input features whose last dimension is ``nOut``. Assumes shape
                (batch, nOut) -- TODO confirm against callers.
            label: Target class indices accepted by ``nn.CrossEntropyLoss``.
                Despite the ``None`` default it is required; the default is
                kept only for signature compatibility with existing callers.

        Returns:
            Tuple ``(nloss, prec1)``: the cross-entropy loss and the first
            element of ``accuracy(logits, label, topk=(1,))``.

        Raises:
            TypeError: if ``label`` is ``None``.
        """
        if label is None:
            # Fail early with a clear message instead of the opaque error
            # CrossEntropyLoss raises when its target is missing.
            raise TypeError('Softmax.forward requires a label tensor')

        logits = self.fc(x)
        nloss = self.criterion(logits, label)
        # Accuracy is only a metric; detach so it never joins the autograd graph.
        prec1 = accuracy(logits.detach(), label.detach(), topk=(1,))[0]

        return nloss, prec1