File size: 1,997 Bytes
e9fe176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
--- Gradient Difference Loss (GDL) criterion.
-- Compares the absolute differences between neighbouring elements of the
-- input with those of the target, separately along dimension 3 and
-- dimension 4. The comparison uses an L1 penalty (alpha = 1,
-- nn.AbsCriterion) or an L2 penalty (alpha = 2, nn.MSECriterion), and the
-- two per-dimension losses are summed.
-- NOTE(review): dimensions 3/4 are hard-coded, so this presumably expects a
-- 4D batch such as (N, C, H, W) -- confirm against callers.
local GDLCriterion, parent = torch.class('nn.GDLCriterion', 'nn.Criterion')

--- Construct the criterion.
-- @tparam[opt=1] number alpha penalty exponent; must be 1 (L1) or 2 (L2)
-- @tparam[opt=true] boolean sizeAverage forwarded to both inner criteria;
--   pass `false` to get summed instead of averaged losses
function GDLCriterion:__init(alpha, sizeAverage)

  parent.__init(self)
  self.alpha = alpha or 1
  -- `sizeAverage or true` would turn an explicit `false` into `true`;
  -- only fall back to the default when the argument is actually omitted.
  if sizeAverage == nil then
    sizeAverage = true
  end
  self.sizeAverage = sizeAverage

  assert(self.alpha == 1 or self.alpha == 2, "alpha should be 1 or 2")

  -- Narrow(d, 1, -2) and Narrow(d, 2, -1) select elements 1..n-1 and 2..n
  -- along dimension d (nn.Narrow negative-length convention). Subtracting
  -- the two and taking Abs gives |x[i] - x[i+1]| along that dimension.
  local hNet = nn.Sequential()
  hNet:add(nn.ConcatTable():add(nn.Narrow(3, 1, -2)):add(nn.Narrow(3, 2, -1)))
  hNet:add(nn.CSubTable()):add(nn.Abs())
  local wNet = nn.Sequential()
  wNet:add(nn.ConcatTable():add(nn.Narrow(4, 1, -2)):add(nn.Narrow(4, 2, -1)))
  wNet:add(nn.CSubTable()):add(nn.Abs())

  -- inputNet outputs {dim-3 differences, dim-4 differences}. The target gets
  -- its own copy so the forward buffers of the two passes stay separate.
  self.inputNet = nn.ConcatTable():add(hNet):add(wNet)
  self.targetNet = self.inputNet:clone()

  -- criterion[1] scores the dim-3 branch and criterion[2] the dim-4 branch.
  local Criterion = (self.alpha == 1) and nn.AbsCriterion or nn.MSECriterion
  self.criterion = {
    Criterion(self.sizeAverage),
    Criterion(self.sizeAverage),
  }

end

--- Forward pass: total gradient-difference loss.
-- Runs the difference networks on `input` and `target` and caches their
-- outputs in self.inputNetOutput / self.targetNetOutput, which
-- updateGradInput reuses. Then sums the losses of the two branches.
-- @param input prediction tensor
-- @param target reference tensor with the same number of elements as input
-- @treturn number the loss (also stored in self.output)
function GDLCriterion:updateOutput(input, target)
  assert(input:nElement() == target:nElement(), "input and target size mismatch")

  self.inputNetOutput = self.inputNet:forward(input)
  self.targetNetOutput = self.targetNet:forward(target)
  local predDiffs, refDiffs = self.inputNetOutput, self.targetNetOutput

  -- Branch 1 holds the dim-3 differences, branch 2 the dim-4 differences.
  local lossDim3 = self.criterion[1]:forward(predDiffs[1], refDiffs[1])
  local lossDim4 = self.criterion[2]:forward(predDiffs[2], refDiffs[2])

  self.output = lossDim3 + lossDim4
  return self.output
end

--- Backward pass: gradient of the loss with respect to `input`.
-- Reuses the difference maps cached by updateOutput, so it must be called
-- right after updateOutput with the same input/target pair.
-- @param input prediction tensor (the one just passed to updateOutput)
-- @param target reference tensor (the one just passed to updateOutput)
-- @return gradient w.r.t. input (also stored in self.gradInput)
function GDLCriterion:updateGradInput(input, target)
  assert(input:nElement() == target:nElement(), "input and target size mismatch")

  local predDiffs, refDiffs = self.inputNetOutput, self.targetNetOutput

  -- One gradient per inputNet branch. The parentheses keep exactly one
  -- value from each call, matching plain assignment semantics.
  local branchGrads = {
    (self.criterion[1]:backward(predDiffs[1], refDiffs[1])),
    (self.criterion[2]:backward(predDiffs[2], refDiffs[2])),
  }

  self.gradInput = self.inputNet:backward(input, branchGrads)
  return self.gradInput
end