File size: 1,997 Bytes
e9fe176 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
-- nn.GDLCriterion: compares the absolute differences between neighbouring
-- elements of input and target along dimensions 3 and 4 (presumably the
-- image "gradient difference loss"; assumes batch x channel x height x width
-- tensors -- TODO confirm). `parent` is the nn.Criterion base class.
local GDLCriterion, parent = torch.class('nn.GDLCriterion', 'nn.Criterion')
--- Construct the criterion.
-- Input and target are each passed through the same parameter-free network,
-- which produces two maps: |x[1..n-1] - x[2..n]| taken along dimension 3 and
-- along dimension 4. The matching maps of input and target are then compared
-- with nn.AbsCriterion (alpha == 1) or nn.MSECriterion (alpha == 2).
-- @tparam[opt=1] number alpha penalty exponent; must be 1 or 2
-- @tparam[opt=true] boolean sizeAverage average (true) or sum (false) the
--   per-element losses; forwarded to the two inner criteria
function GDLCriterion:__init(alpha, sizeAverage)
   parent.__init(self)
   self.alpha = alpha or 1
   -- `sizeAverage or true` would always be true and silently drop an explicit
   -- `false`; only fall back to the default when the argument is omitted.
   if sizeAverage == nil then
      self.sizeAverage = true
   else
      self.sizeAverage = sizeAverage
   end
   assert(self.alpha == 1 or self.alpha == 2, "alpha should be 1 or 2")

   -- Builds |x[1..n-1] - x[2..n]| along dimension `dim`. nn.Narrow treats a
   -- negative length as relative to the end of the dimension, so both slices
   -- presumably have length n-1 (see nn.Narrow docs).
   local function absDiffNet(dim)
      local net = nn.Sequential()
      net:add(nn.ConcatTable():add(nn.Narrow(dim, 1, -2)):add(nn.Narrow(dim, 2, -1)))
      net:add(nn.CSubTable()):add(nn.Abs())
      return net
   end

   local hNet = absDiffNet(3)
   local wNet = absDiffNet(4)
   self.inputNet = nn.ConcatTable():add(hNet):add(wNet)
   -- Separate copy so that forwarding the target does not overwrite the
   -- activations of inputNet, which updateGradInput backpropagates through.
   self.targetNet = self.inputNet:clone()

   -- One criterion per difference map (index 1: dim 3, index 2: dim 4).
   local Criterion = (self.alpha == 1) and nn.AbsCriterion or nn.MSECriterion
   self.criterion = {
      Criterion(self.sizeAverage),
      Criterion(self.sizeAverage),
   }
end
--- Compute the loss between input and target.
-- Forwards both tensors through their difference networks, caches the
-- resulting maps (updateGradInput reuses them) and returns the sum of the
-- two inner criterion values.
-- @param input prediction tensor
-- @param target tensor with the same number of elements as input
-- @treturn number summed loss over both difference directions
function GDLCriterion:updateOutput(input, target)
   assert( input:nElement() == target:nElement(),
      "input and target size mismatch")
   self.inputNetOutput = self.inputNet:forward(input)
   self.targetNetOutput = self.targetNet:forward(target)
   local total = 0
   for i = 1, 2 do
      local crit = self.criterion[i]
      -- Propagate the user-settable sizeAverage field, so that changing it
      -- after construction (e.g. `gdl.sizeAverage = false`) takes effect.
      if self.sizeAverage ~= nil then
         crit.sizeAverage = self.sizeAverage
      end
      total = total + crit:forward(self.inputNetOutput[i], self.targetNetOutput[i])
   end
   self.output = total
   return self.output
end
--- Gradient of the loss with respect to input.
-- Reuses the difference maps cached by updateOutput, so updateOutput must
-- have been called right before with the same input/target pair.
-- @param input tensor previously given to updateOutput
-- @param target tensor previously given to updateOutput
-- @return gradient w.r.t. input, as returned by self.inputNet:backward
function GDLCriterion:updateGradInput(input, target)
   assert( input:nElement() == target:nElement(),
      "input and target size mismatch")
   local predDiffs = self.inputNetOutput
   local targetDiffs = self.targetNetOutput
   -- gradient of each inner criterion w.r.t. its input difference map
   local gradDiffs = {}
   for branch = 1, 2 do
      local crit = self.criterion[branch]
      gradDiffs[branch] = crit:backward(predDiffs[branch], targetDiffs[branch])
   end
   self.gradInput = self.inputNet:backward(input, gradDiffs)
   return self.gradInput
end
|