|
|
-- Register the class 'nn.GDLCriterion' with 'nn.Criterion' as its parent.
-- `GDLCriterion` is the table the methods below are attached to; `parent`
-- is used by the constructor to invoke the base-class initializer.
local GDLCriterion, parent = torch.class("nn.GDLCriterion", "nn.Criterion")
|
|
|
|
|
--- Construct the criterion.
--
-- Builds two small networks that each take the absolute difference between
-- two shifted slices of a tensor (one along dimension 3, one along
-- dimension 4), plus a pair of element-wise criteria that compare those
-- differences for the input and for the target.
--
-- @param alpha        1 or 2 (default 1). 1 selects nn.AbsCriterion, 2 selects
--                     nn.MSECriterion for comparing the differences.
-- @param sizeAverage  boolean forwarded to the underlying criteria
--                     (default true). An explicit `false` is honoured.
function GDLCriterion:__init(alpha, sizeAverage)
   parent.__init(self)

   self.alpha = alpha or 1

   -- `sizeAverage or true` would always be true, silently discarding an
   -- explicit `false`; only fall back to the default when the argument is
   -- actually absent.
   if sizeAverage == nil then
      self.sizeAverage = true
   else
      self.sizeAverage = sizeAverage
   end

   assert(self.alpha == 1 or self.alpha == 2, "alpha should be 1 or 2")

   -- |slice_a - slice_b| along `dim`, where slice_a = Narrow(dim, 1, -2) and
   -- slice_b = Narrow(dim, 2, -1). With nn.Narrow's negative-length
   -- convention these should be "all but the last" and "all but the first"
   -- entries, i.e. neighbouring-element differences -- confirm against the
   -- installed nn version.
   local function absDiffAlong(dim)
      local shifted = nn.ConcatTable()
      shifted:add(nn.Narrow(dim, 1, -2))
      shifted:add(nn.Narrow(dim, 2, -1))

      local net = nn.Sequential()
      net:add(shifted)
      net:add(nn.CSubTable())
      net:add(nn.Abs())
      return net
   end

   -- Dimensions 3 and 4 are hard-coded, so inputs need at least 4 dimensions
   -- (presumably batch x channel x height x width -- verify against callers).
   local hNet = absDiffAlong(3)
   local wNet = absDiffAlong(4)

   -- The same structure is applied to the prediction and to the target; the
   -- target side gets its own clone so the two do not share output buffers.
   self.inputNet = nn.ConcatTable():add(hNet):add(wNet)
   self.targetNet = self.inputNet:clone()

   -- One criterion per direction: [1] for the dim-3 differences, [2] for
   -- the dim-4 differences.
   self.criterion = {}
   if self.alpha == 1 then
      self.criterion[1] = nn.AbsCriterion(self.sizeAverage)
      self.criterion[2] = nn.AbsCriterion(self.sizeAverage)
   else
      self.criterion[1] = nn.MSECriterion(self.sizeAverage)
      self.criterion[2] = nn.MSECriterion(self.sizeAverage)
   end
end
|
|
|
|
|
--- Compute the loss for `input` against `target`.
--
-- Runs both tensors through their difference networks, evaluates the two
-- per-direction criteria on the results and stores the sum in `self.output`.
-- The network outputs are kept on `self.inputNetOutput` and
-- `self.targetNetOutput`; updateGradInput reads them back.
--
-- @param input   prediction tensor
-- @param target  tensor with the same number of elements as `input`
-- @return the summed loss (also stored in self.output)
function GDLCriterion:updateOutput(input, target)
   local sameCount = input:nElement() == target:nElement()
   assert(sameCount, "input and target size mismatch")

   -- Differences for the prediction first, then for the target.
   local inputDiffs = self.inputNet:forward(input)
   self.inputNetOutput = inputDiffs
   local targetDiffs = self.targetNet:forward(target)
   self.targetNetOutput = targetDiffs

   -- Accumulate the loss of each direction (index 1, then index 2).
   local total = 0
   for i = 1, 2 do
      total = total + self.criterion[i]:forward(inputDiffs[i], targetDiffs[i])
   end
   self.output = total

   return self.output
end
|
|
|
|
|
|
|
|
--- Compute the gradient of the loss with respect to `input`.
--
-- Uses the network outputs cached on `self.inputNetOutput` and
-- `self.targetNetOutput`, so updateOutput must have been called beforehand
-- with the same arguments.
--
-- @param input   prediction tensor
-- @param target  tensor with the same number of elements as `input`
-- @return the gradient w.r.t. `input` (also stored in self.gradInput)
function GDLCriterion:updateGradInput(input, target)
   local sameCount = input:nElement() == target:nElement()
   assert(sameCount, "input and target size mismatch")

   local inputDiffs, targetDiffs = self.inputNetOutput, self.targetNetOutput

   -- Gradient of each per-direction criterion w.r.t. its difference map.
   local gradH = self.criterion[1]:backward(inputDiffs[1], targetDiffs[1])
   local gradW = self.criterion[2]:backward(inputDiffs[2], targetDiffs[2])

   -- Push both gradients back through the input-side difference network.
   self.gradInput = self.inputNet:backward(input, { gradH, gradW })

   return self.gradInput
end
|
|
|