|
|
require 'nn' |
|
|
|
|
|
-- nn.GaussianCriterion: an nn.Criterion whose loss is the summed element-wise
-- Gaussian negative log-likelihood of `target`, with the Gaussian parameters
-- supplied as input = {mean, logVariance}.
-- `parent` is the base-class table returned by torch.class (not used in the
-- code visible here; kept per the usual torch.class convention).
local GaussianCriterion, parent = torch.class('nn.GaussianCriterion', 'nn.Criterion')
|
|
|
|
|
--- Forward pass: total Gaussian negative log-likelihood of `target`.
-- Per element the loss is
--   0.5 * log(2*pi) + 0.5 * logVar + 0.5 * (target - mean)^2 / exp(logVar)
-- and the result is the sum over all elements.
-- @param input  table {mean, logVar}; assumes both tensors match `target`'s
--               size (not checked here — TODO confirm with callers)
-- @param target tensor of observed values
-- @return number, also stored in self.output
function GaussianCriterion:updateOutput(input, target)
   local mean, logVar = input[1], input[2]

   -- Constant and log-variance part of the per-element loss.
   local halfLogTwoPi = 0.5 * math.log(2 * math.pi)
   local perElement = torch.mul(logVar, 0.5):add(halfLogTwoPi)

   -- Squared residual scaled by the variance, then halved.
   local variance = torch.exp(logVar)
   local scaledSq = torch.add(target, -1, mean):pow(2):cdiv(variance):mul(0.5)
   perElement:add(scaledSq)

   self.output = torch.sum(perElement)
   return self.output
end
|
|
|
|
|
--- Backward pass: gradient of the Gaussian NLL w.r.t. {mean, logVar}.
-- With r = target - mean and the per-element loss
--   0.5 * log(2*pi) + 0.5 * logVar + 0.5 * r^2 * exp(-logVar):
--   dL/dmean   = -r * exp(-logVar)
--   dL/dlogVar = 0.5 - 0.5 * r^2 * exp(-logVar)
-- @param input  table {mean, logVar}; assumes both tensors match `target`'s
--               size (not checked here — TODO confirm with callers)
-- @param target tensor of observed values
-- @return table {gradMean, gradLogVar}, also stored in self.gradInput
function GaussianCriterion:updateGradInput(input, target)
   local mean, logVar = input[1], input[2]

   -- Shared terms, computed once instead of per gradient.
   local invVariance = torch.exp(-logVar)
   local residual = torch.add(target, -1, mean)

   local gradMean = torch.cmul(invVariance, residual):mul(-1)

   -- The residual term of the loss carries a factor 0.5, so its derivative
   -- w.r.t. logVar is -0.5 * r^2 * exp(-logVar). The previous code used -1,
   -- which doubled this term relative to updateOutput.
   local gradLogVar = torch.cmul(invVariance, torch.pow(residual, 2)):mul(-0.5):add(0.5)

   self.gradInput = { gradMean, gradLogVar }
   return self.gradInput
end