File size: 923 Bytes
e9fe176 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
require 'nn'
-- Gaussian negative log-likelihood criterion, registered as nn.GaussianCriterion
-- and derived from nn.Criterion.
-- The input is a table of two tensors: input[1] = mu, input[2] = log(sigma^2).
local GaussianCriterion, parent = torch.class('nn.GaussianCriterion', 'nn.Criterion')
--- Negative log-likelihood of `target` under a diagonal Gaussian.
-- Per element the loss is
--   0.5 * log(sigma^2) + 0.5 * log(2*pi) + 0.5 * (x - mu)^2 / sigma^2
-- and the result is the sum over all elements.
-- @param input  table: input[1] = mu, input[2] = log(sigma^2)
-- @param target tensor x the Gaussian is evaluated at
-- @return the summed loss (also stored in self.output)
function GaussianCriterion:updateOutput(input, target)
    local mu, logVar = input[1], input[2]

    -- Normalisation part: 0.5 * log(sigma^2) + 0.5 * log(2*pi)
    local nll = torch.mul(logVar, 0.5):add(0.5 * math.log(2 * math.pi))

    -- Quadratic part: 0.5 * (x - mu)^2 / sigma^2
    local residual = torch.add(target, -1, mu)
    local quadratic = residual:pow(2):cdiv(torch.exp(logVar)):mul(0.5)

    nll:add(quadratic)

    self.output = torch.sum(nll)
    return self.output
end
--- Gradient of the Gaussian negative log-likelihood w.r.t. both inputs.
-- With L = 0.5 * log(sigma^2) + 0.5 * log(2*pi) + 0.5 * (x - mu)^2 / sigma^2:
--   dL/dmu           = -(x - mu) / sigma^2
--   dL/dlog(sigma^2) = 0.5 - 0.5 * (x - mu)^2 / sigma^2
-- @param input  table: input[1] = mu, input[2] = log(sigma^2)
-- @param target tensor x the Gaussian is evaluated at
-- @return table {dL/dmu, dL/dlog(sigma^2)} (also stored in self.gradInput)
function GaussianCriterion:updateGradInput(input, target)
    local mu, logVar = input[1], input[2]

    -- Shared terms: (x - mu) and 1 / sigma^2 = exp(-log(sigma^2))
    local diff = torch.add(target, -1, mu)
    local invVar = torch.exp(-logVar)

    self.gradInput = {}

    -- dL/dmu = -(x - mu) / sigma^2 (torch.cmul allocates, so diff stays intact)
    self.gradInput[1] = torch.cmul(invVar, diff):mul(-1)

    -- dL/dlog(sigma^2) = 0.5 - 0.5 * (x - mu)^2 / sigma^2
    -- The quadratic term carries the same 0.5 factor as in the loss.
    -- diff is consumed in place here; it is not needed afterwards.
    self.gradInput[2] = diff:pow(2):cmul(invVar):mul(-0.5):add(0.5)

    return self.gradInput
end