rawanessam's picture
Upload 79 files
e9fe176 verified
--- KL-divergence criterion (see the VAE paper, Appendix B), registered with
-- Torch's class system as 'nn.KLDCriterion', a subclass of 'nn.Criterion'.
-- NOTE(review): `parent` is not referenced anywhere in the visible code.
local KLDCriterion, parent = torch.class('nn.KLDCriterion', 'nn.Criterion')
--- Forward pass: KL divergence term of the VAE objective.
-- Appendix B of the VAE paper gives
--   -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
-- so the value returned here is -0.5 times that sum (i.e. the KL itself).
-- `mean` arrives in nn.Criterion's usual `input` slot and `log_var` in its
-- `target` slot.
-- @param mean tensor of means (mu)
-- @param log_var tensor of log-variances (log(sigma^2)); assumed to have the
--   same number of elements as `mean` -- TODO confirm with callers
-- @treturn number the summed divergence, also cached in self.output
function KLDCriterion:updateOutput(mean, log_var)
   -- Build the per-element terms in a fresh tensor, in the same order as
   -- before: -sigma^2, then - mu^2, then + 1, then + log(sigma^2).
   local summed = torch.sum(
      torch.exp(log_var)                -- sigma^2
         :mul(-1)                       -- -sigma^2
         :add(-1, torch.pow(mean, 2))   -- ... - mu^2
         :add(1)                        -- ... + 1
         :add(log_var)                  -- ... + log(sigma^2)
   )
   self.output = -0.5 * summed
   return self.output
end
--- Backward pass: gradient of the divergence with respect to both inputs.
--   dKL/dmean    = mean
--   dKL/dlog_var = -0.5 * (1 - exp(log_var))
-- Fresh tensors and a fresh table are allocated on every call.
-- @param mean tensor of means (mu)
-- @param log_var tensor of log-variances (log(sigma^2))
-- @treturn table {grad_mean, grad_log_var}, also stored in self.gradInput
function KLDCriterion:updateGradInput(mean, log_var)
   local grads = {}
   self.gradInput = grads

   grads[1] = mean:clone()

   local grad_log_var = torch.exp(log_var)  -- exp(log_var)
   grad_log_var:mul(-1)                     -- -exp(log_var)
   grad_log_var:add(1)                      -- 1 - exp(log_var)
   grad_log_var:mul(-0.5)                   -- -0.5 * (1 - exp(log_var))
   grads[2] = grad_log_var

   return self.gradInput
end