-- Source snapshot: e9fe176 (721 bytes)

-- KL-divergence criterion for a VAE latent distribution with per-element mean
-- and log-variance, measured against a standard normal (see updateOutput).
-- Registered with Torch's class system as 'nn.KLDCriterion', subclassing
-- 'nn.Criterion'. `parent` (the nn.Criterion class) appears unused in this file.
local KLDCriterion, parent = torch.class('nn.KLDCriterion', 'nn.Criterion')

--- Forward pass: KL divergence of N(mean, exp(log_var)) from N(0, 1).
-- Uses the closed form from Appendix B of the VAE paper:
--   -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
-- so the value returned here is the (non-negative) KL itself.
-- Note: `mean` arrives in the criterion's usual "input" slot and `log_var`
-- in the "target" slot; log_var is expected to match mean's size.
-- @param mean tensor of latent means (mu)
-- @param log_var tensor of latent log-variances (log sigma^2)
-- @return number, the summed KL divergence (also stored in self.output)
function KLDCriterion:updateOutput(mean, log_var)
    local terms = torch.exp(log_var)   -- sigma^2
    terms:mul(-1)                      -- -sigma^2
    terms:add(-1, torch.pow(mean, 2))  -- -sigma^2 - mu^2
    terms:add(1)                       -- 1 - mu^2 - sigma^2
    terms:add(log_var)                 -- 1 + log(sigma^2) - mu^2 - sigma^2

    local total = terms:sum()
    self.output = -0.5 * total
    return self.output
end

--- Backward pass: gradients of the KL divergence w.r.t. both inputs.
-- With KL = -0.5 * sum(1 + log_var - mean^2 - exp(log_var)):
--   dKL/d(mean)    = mean
--   dKL/d(log_var) = -0.5 * (1 - exp(log_var)) = 0.5 * (exp(log_var) - 1)
-- @param mean tensor of latent means (mu)
-- @param log_var tensor of latent log-variances (log sigma^2)
-- @return table {grad_mean, grad_log_var} of freshly allocated tensors,
--   also stored in self.gradInput (replacing nn.Criterion's default tensor)
function KLDCriterion:updateGradInput(mean, log_var)
    local grad_mean = mean:clone()
    local grad_log_var = torch.exp(log_var):add(-1):mul(0.5)

    self.gradInput = { grad_mean, grad_log_var }
    return self.gradInput
end