File size: 2,424 Bytes
e9fe176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
--
-- Different weight initialization methods
--
-- > model = require('weight-init')(model, 'heuristic')
--
require("nn")


--- Heuristic scale from "Efficient BackProp" (Yann LeCun, 1998).
-- Computes sqrt(1 / (3 * fan_in)).
-- @param fan_in   fan-in value supplied by the caller
-- @param fan_out  unused; accepted so all initializers share one signature
-- @return number  the computed scale
local function w_init_heuristic(fan_in, fan_out)
   local scaled_fan_in = 3 * fan_in
   return math.sqrt(1 / scaled_fan_in)
end


--- Xavier/Glorot scale from "Understanding the difficulty of training deep
-- feedforward neural networks" (Xavier Glorot, 2010).
-- Computes sqrt(2 / (fan_in + fan_out)).
-- @param fan_in   fan-in value supplied by the caller
-- @param fan_out  fan-out value supplied by the caller
-- @return number  the computed scale
local function w_init_xavier(fan_in, fan_out)
   local fan_sum = fan_in + fan_out
   return math.sqrt(2 / fan_sum)
end


--- Fan-in-only Xavier variant (the name suggests it mirrors Caffe's filler),
-- after "Understanding the difficulty of training deep feedforward neural
-- networks" (Xavier Glorot, 2010).
-- Computes sqrt(1 / fan_in).
-- @param fan_in   fan-in value supplied by the caller
-- @param fan_out  unused; accepted so all initializers share one signature
-- @return number  the computed scale
local function w_init_xavier_caffe(fan_in, fan_out)
   local inverse_fan_in = 1 / fan_in
   return math.sqrt(inverse_fan_in)
end


--- Kaiming/He-style scale, after "Delving Deep into Rectifiers: Surpassing
-- Human-Level Performance on ImageNet Classification" (Kaiming He, 2015).
-- Computes sqrt(4 / (fan_in + fan_out)).
-- NOTE(review): the paper's formulation is usually quoted as sqrt(2 / fan_in);
-- this variant averages fan_in and fan_out instead -- confirm it is intentional.
-- @param fan_in   fan-in value supplied by the caller
-- @param fan_out  fan-out value supplied by the caller
-- @return number  the computed scale
local function w_init_kaiming(fan_in, fan_out)
   local fan_sum = fan_in + fan_out
   return math.sqrt(4 / fan_sum)
end


-- fan_in/fan_out for layers exposing nInputPlane, nOutputPlane, kH and kW.
local function fans_spatial(m)
   return m.nInputPlane * m.kH * m.kW, m.nOutputPlane * m.kH * m.kW
end

-- fan_in/fan_out for nn.LateralConvolution: plane counts only (1x1 kernel).
local function fans_lateral(m)
   return m.nInputPlane, m.nOutputPlane
end

-- fan_in/fan_out for per-plane convolutions: kernel area on both sides.
local function fans_kernel_only(m)
   local kernel_area = m.kH * m.kW
   return kernel_area, kernel_area
end

-- fan_in/fan_out read from the weight tensor: size(2) is fan_in, size(1) fan_out.
local function fans_from_weight(m)
   return m.weight:size(2), m.weight:size(1)
end

-- Maps a module's __typename to the rule that yields its (fan_in, fan_out).
local FAN_RULES = {
   ['nn.SpatialConvolution']    = fans_spatial,
   ['nn.SpatialConvolutionMM']  = fans_spatial,
   ['cudnn.SpatialConvolution'] = fans_spatial,
   ['nn.LateralConvolution']    = fans_lateral,
   ['nn.VerticalConvolution']   = fans_kernel_only,
   ['nn.HorizontalConvolution'] = fans_kernel_only,
   ['nn.Linear']                = fans_from_weight,
   ['nn.TemporalConvolution']   = fans_from_weight,
}

--- Re-initialize the weights of a network's direct child modules.
--
-- For each entry of `net.modules` whose `__typename` is a supported
-- convolution/linear type, computes (fan_in, fan_out), passes them to the
-- selected method and hands the result to `m:reset(...)`. Every visited module
-- that has a `bias` field gets it zeroed, whatever its type.
--
-- NOTE(review): only the top level of `net.modules` is visited; layers inside
-- nested containers are left untouched -- confirm whether that is intended.
--
-- @param net  container with a `modules` array; a module without that field
--             is treated as a one-element network
-- @param arg  one of 'heuristic', 'xavier', 'xavier_caffe', 'kaiming'
-- @return     `net`, modified in place
-- Raises an error naming the offending value when `arg` is not recognised.
local function w_init(net, arg)
   -- choose initialization method
   local methods = {
      heuristic    = w_init_heuristic,
      xavier       = w_init_xavier,
      xavier_caffe = w_init_xavier_caffe,
      kaiming      = w_init_kaiming,
   }
   local method = methods[arg]
   if method == nil then
      -- level 2: report the error at the caller, who supplied the bad name
      error("weight-init: unknown initialization method '" .. tostring(arg) ..
            "' (expected 'heuristic', 'xavier', 'xavier_caffe' or 'kaiming')", 2)
   end

   -- a bare module has no `modules` array; initialize it on its own
   local modules = net.modules or { net }

   for i = 1, #modules do
      local m = modules[i]
      local fans = FAN_RULES[m.__typename]
      if fans then
         m:reset(method(fans(m)))
      end

      if m.bias then
         m.bias:zero()
      end
   end
   return net
end

-- Export the initializer itself as the module value, so the result of
-- require(...) can be called directly as shown in the header usage example.
return w_init