rawanessam's picture
Upload 79 files
e9fe176 verified
-- Registers nn.GDLCriterion as a torch class deriving from nn.Criterion;
-- `parent` is the base class handle used to chain the constructor.
-- NOTE(review): "GDL" presumably stands for gradient difference loss -- confirm.
local GDLCriterion, parent = torch.class('nn.GDLCriterion', 'nn.Criterion')
--- Construct the criterion.
-- Builds two small networks (one applied to the input, one to the target) that
-- each produce a table of two tensors: the absolute difference between
-- neighbouring slices along dimension 3 and along dimension 4. The two
-- resulting maps are later compared with a pair of element-wise criteria.
-- NOTE(review): dimensions 3 and 4 are presumably height and width of a
-- (batch, channel, height, width) tensor -- confirm against callers.
-- @param alpha        1 selects nn.AbsCriterion, 2 selects nn.MSECriterion
--                     (defaults to 1; any other value raises an assertion error)
-- @param sizeAverage  forwarded to the wrapped criteria; defaults to true
--                     only when omitted (nil), so an explicit false is honoured
function GDLCriterion:__init(alpha, sizeAverage)
   parent.__init(self)
   self.alpha = alpha or 1
   -- `sizeAverage or true` would silently turn an explicit `false` into `true`,
   -- so only fall back to the default when the argument is really absent.
   if sizeAverage == nil then
      self.sizeAverage = true
   else
      self.sizeAverage = sizeAverage
   end
   assert(self.alpha == 1 or self.alpha == 2, "alpha should be 1 or 2")

   -- |x[1 .. end-1] - x[2 .. end]| along `dim`; nn.Narrow accepts negative
   -- lengths that are interpreted relative to the end of the dimension.
   local function adjacentAbsDiff(dim)
      local net = nn.Sequential()
      net:add(nn.ConcatTable():add(nn.Narrow(dim, 1, -2)):add(nn.Narrow(dim, 2, -1)))
      net:add(nn.CSubTable()):add(nn.Abs())
      return net
   end

   local hNet = adjacentAbsDiff(3)
   local wNet = adjacentAbsDiff(4)
   self.inputNet = nn.ConcatTable():add(hNet):add(wNet)
   -- The target gets its own copy so the two forward passes do not share buffers.
   self.targetNet = self.inputNet:clone()

   -- One criterion per difference map (index 1: dimension 3, index 2: dimension 4).
   self.criterion = {}
   if self.alpha == 1 then
      self.criterion[1] = nn.AbsCriterion(self.sizeAverage)
      self.criterion[2] = nn.AbsCriterion(self.sizeAverage)
   else
      self.criterion[1] = nn.MSECriterion(self.sizeAverage)
      self.criterion[2] = nn.MSECriterion(self.sizeAverage)
   end
end
--- Forward pass: compute the loss for an input/target pair.
-- Runs both tensors through their difference networks, caches the resulting
-- tables on `self` (they are reused by updateGradInput), and sums the losses
-- of the two wrapped criteria.
-- @param input   tensor fed to self.inputNet
-- @param target  tensor fed to self.targetNet; must have as many elements as input
-- @return the summed loss, also stored in self.output
function GDLCriterion:updateOutput(input, target)
   local sameCount = input:nElement() == target:nElement()
   assert(sameCount, "input and target size mismatch")

   local inputDiffs = self.inputNet:forward(input)
   self.inputNetOutput = inputDiffs
   local targetDiffs = self.targetNet:forward(target)
   self.targetNetOutput = targetDiffs

   -- Loss on the first difference map, then on the second one.
   local firstLoss = self.criterion[1]:forward(inputDiffs[1], targetDiffs[1])
   self.output = firstLoss
   local secondLoss = self.criterion[2]:forward(inputDiffs[2], targetDiffs[2])
   self.output = firstLoss + secondLoss
   return self.output
end
--- Backward pass: gradient of the loss with respect to `input`.
-- Relies on the difference tables cached by updateOutput, so updateOutput must
-- have been called with the same input/target pair immediately beforehand.
-- @param input   the tensor previously passed to updateOutput
-- @param target  the matching target; must have as many elements as input
-- @return the gradient produced by self.inputNet, also stored in self.gradInput
function GDLCriterion:updateGradInput(input, target)
   local sameCount = input:nElement() == target:nElement()
   assert(sameCount, "input and target size mismatch")

   local inputDiffs = self.inputNetOutput
   local targetDiffs = self.targetNetOutput

   -- Gradient of each criterion with respect to its difference map.
   local diffGrads = {}
   for i = 1, 2 do
      diffGrads[i] = self.criterion[i]:backward(inputDiffs[i], targetDiffs[i])
   end

   -- Push those gradients back through the difference network.
   local grad = self.inputNet:backward(input, diffGrads)
   self.gradInput = grad
   return self.gradInput
end