File size: 706 Bytes
f3b11f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

class SimpleLossCompute:
    """Compute the loss for a batch and, when an optimizer is given, train on it.

    The generator is applied to ``x``. Its output is flattened to 2-D (all
    leading dimensions merged, last dimension kept) and ``y`` is flattened to
    1-D before both are passed to ``loss_function``.
    """

    def __init__(self, generator, loss_function, opt):
        """Store the components used by :meth:`__call__`.

        Args:
            generator: Callable applied to ``x`` before the loss is computed.
            loss_function: Callable taking ``(scores_2d, targets_1d)`` and
                returning a loss tensor. The tensor must be a scalar when
                ``opt`` is given, since ``backward()`` is called without
                arguments.
            opt: ``None`` to only evaluate the loss (no backward pass, no
                parameter update). Otherwise it is one of two things:

                - a plain optimizer with ``step()``/``zero_grad()``;
                - a wrapper whose ``step()`` updates the parameters and which
                  holds the underlying optimizer in an ``optimizer``
                  attribute.
        """
        self.generator = generator
        self.loss_function = loss_function
        self.opt = opt

    def __call__(self, x, y, norm):
        """Compute the loss on one batch, training on it if an optimizer is set.

        Args:
            x: Input passed to ``self.generator``.
            y: Targets. They are flattened to 1-D to line up with the
                flattened generator output.
            norm: Normalizer. The loss is divided by it before
                backpropagation, so gradients are those of ``loss / norm``.
                Presumably a token or example count; confirm with the caller.

        Returns:
            The normalized loss multiplied back by ``norm`` (i.e. the
            un-normalized loss value), detached from the autograd graph.
        """
        x = self.generator(x)

        loss = self.loss_function(x.contiguous().view(-1, x.size(-1)),
                                  y.contiguous().view(-1)) / norm

        if self.opt is not None:
            loss.backward()
            self.opt.step()
            # A wrapper keeps the real optimizer in ``.optimizer``. A plain
            # optimizer has no such attribute, so it is zeroed directly.
            getattr(self.opt, "optimizer", self.opt).zero_grad()

        # ``detach()`` is the supported replacement for ``.data``. It gives
        # the same values and is likewise excluded from autograd tracking.
        return loss.detach() * norm