File size: 1,705 Bytes
e9fe176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
local CrossConvolve, parent = torch.class('nn.CrossConvolve', 'nn.Module')

--- Construct a module that convolves an image with a kernel supplied at
-- forward time (the second element of the input table) instead of a learned
-- weight.
-- Internally it wraps `SpatialZeroPadding` followed by a single-output-plane
-- `SpatialConvolution` whose bias is zeroed here.
-- @param nInputPlane number of planes in the image input (defaults to 1)
-- @param kW kernel width
-- @param kH kernel height (defaults to kW)
function CrossConvolve:__init(nInputPlane, kW, kH)
   parent.__init(self)

   local width, height = kW, kH or kW
   local planes = nInputPlane or 1

   self.nInputPlane = planes
   self.kW, self.kH = width, height

   -- Pad by floor(k/2) on both sides. For odd kernels this keeps the spatial
   -- size unchanged. NOTE(review): even kernels give one extra row/column.
   self.padW = math.floor(width / 2)
   self.padH = math.floor(height / 2)

   local pad = nn.SpatialZeroPadding(self.padW, self.padW, self.padH, self.padH)
   local conv = nn.SpatialConvolution(planes, 1, width, height)
   -- The filter weights are overwritten on every forward pass; the bias is
   -- never refreshed, so zero it once here.
   conv.bias:zero()

   self.convolver = nn.Sequential():add(pad):add(conv)
   self.gradInput = {}
end

--- Forward pass: convolve `input[1]` with the kernel given in `input[2]`.
-- The kernel is copied into every input-plane slice of the inner
-- SpatialConvolution's weight. The same 2-D filter is therefore applied to
-- each plane, and the responses are summed into one output plane.
-- @param input table `{image, kernel}`; `kernel` must be a 2-D kH x kW tensor
-- @return the inner convolver's output tensor (aliased, not copied)
function CrossConvolve:updateOutput(input)
   local image, kernel = input[1], input[2]

   -- Validate rank and shape up front so a bad kernel fails with a readable
   -- message instead of a bare "assertion failed!" or an index error.
   if kernel:dim() ~= 2 or kernel:size(1) ~= self.kH or kernel:size(2) ~= self.kW then
      error(string.format('CrossConvolve: expected a %dx%d kernel, got size [%s]',
         self.kH, self.kW, table.concat(kernel:size():totable(), 'x')))
   end

   -- Broadcast the shared kernel into every input plane of the filter bank.
   local filters = self.convolver.modules[2].weight[1]
   for plane = 1, self.nInputPlane do
      filters[plane]:copy(kernel)
   end

   self.output = self.convolver:updateOutput(image)
   return self.output
end

--- Backward pass: gradients w.r.t. both the image and the kernel.
-- @param input table `{image, kernel}` as given to updateOutput
-- @param gradOutput gradient w.r.t. this module's output
-- @return `self.gradInput`, a table `{gradImage, gradKernel}`
function CrossConvolve:updateGradInput(input, gradOutput)
   local grads = self.gradInput

   -- (Re)allocate one zeroed gradient buffer per input tensor.
   for idx = 1, #input do
      local x = input[idx]
      grads[idx] = grads[idx] or x.new()
      grads[idx]:resizeAs(x):zero()
   end

   local convolver = self.convolver
   -- Start from zero so accGradParameters below yields only this call's
   -- contribution rather than an accumulated sum.
   convolver:zeroGradParameters()

   -- Image gradient: backprop through convolution and zero padding.
   grads[1]:add(convolver:updateGradInput(input[1], gradOutput))

   -- Kernel gradient: the same kernel was copied into every input plane,
   -- so its gradient is the sum of the per-plane weight gradients.
   convolver:accGradParameters(input[1], gradOutput)
   local gradWeight = convolver.modules[2].gradWeight[1]
   local gradKernel = grads[2]
   for plane = 1, self.nInputPlane do
      gradKernel:add(gradWeight[plane])
   end

   return grads
end

--- Release intermediate buffers of the wrapped network and of this module.
-- @return the value returned by `nn.Module.clearState` for this module
function CrossConvolve:clearState()
   local inner = self.convolver
   inner:clearState()
   return parent.clearState(self)
end