Spaces:
Sleeping
Sleeping
File size: 706 Bytes
f3b11f9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
class SimpleLossCompute:
    """Compute the generator loss and, when an optimizer is given, train on it.

    Args:
        generator: Callable applied to the model output before the loss.
        loss_function: Callable taking the flattened generator output and the
            flattened targets and returning a loss tensor.
        opt: Optimizer wrapper exposing ``step()`` and an ``optimizer``
            attribute that provides ``zero_grad()``. Pass ``None`` to skip
            the backward pass and the parameter update.
    """

    def __init__(self, generator, loss_function, opt):
        self.generator = generator
        self.loss_function = loss_function
        self.opt = opt

    def __call__(self, x, y, norm):
        """Return the un-normalized loss for one batch, stepping ``opt`` if set.

        Args:
            x: Model output; passed through ``generator`` first.
            y: Targets; flattened to one dimension before the loss call.
            norm: Divisor applied to the loss before back-propagation
                (presumably the batch token count -- confirm with the caller).

        Returns:
            The loss, detached from the autograd graph and multiplied back
            by ``norm``.
        """
        x = self.generator(x)
        # Collapse all leading dimensions: predictions become (N, last_dim)
        # and targets become (N,).
        loss = self.loss_function(
            x.contiguous().view(-1, x.size(-1)),
            y.contiguous().view(-1),
        ) / norm
        if self.opt is not None:
            loss.backward()
            self.opt.step()
            # Clear gradients so they do not accumulate into the next batch.
            self.opt.optimizer.zero_grad()
        # detach() is the supported replacement for the legacy ``.data``
        # accessor; the values returned are the same.
        return loss.detach() * norm
|