|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
require("nn") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
-- "Heuristic" scheme: the scale depends on the input fan only.
--   fan_in  : input fan supplied by the caller
--   fan_out : accepted so all schemes share one signature; not used here
-- Returns sqrt(1 / (3 * fan_in)).
local function w_init_heuristic(fan_in, fan_out)
   local denominator = 3 * fan_in
   local squared_scale = 1 / denominator
   return math.sqrt(squared_scale)
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
-- "Xavier" scheme: the scale depends on the sum of both fans.
--   fan_in  : input fan supplied by the caller
--   fan_out : output fan supplied by the caller
-- Returns sqrt(2 / (fan_in + fan_out)).
local function w_init_xavier(fan_in, fan_out)
   local fan_total = fan_in + fan_out
   local squared_scale = 2 / fan_total
   return math.sqrt(squared_scale)
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
-- "Xavier (Caffe)" scheme: the scale depends on the input fan only.
--   fan_in  : input fan supplied by the caller
--   fan_out : accepted so all schemes share one signature; not used here
-- Returns sqrt(1 / fan_in).
local function w_init_xavier_caffe(fan_in, fan_out)
   local squared_scale = 1 / fan_in
   return math.sqrt(squared_scale)
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
-- "Kaiming" scheme as implemented here: the scale depends on both fans.
--   fan_in  : input fan supplied by the caller
--   fan_out : output fan supplied by the caller
-- Returns sqrt(4 / (fan_in + fan_out)); when fan_in == fan_out this is
-- the same value as sqrt(2 / fan_in).
local function w_init_kaiming(fan_in, fan_out)
   local fan_total = fan_in + fan_out
   local squared_scale = 4 / fan_total
   return math.sqrt(squared_scale)
end
|
|
|
|
|
|
|
|
-- Fan computations shared by several layer types.
-- Each returns (fan_in, fan_out) for module `m`.

-- Full 2-D convolutions: planes times kernel area.
local function w_init_fans_conv(m)
   return m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW
end

-- Lateral convolution: 1x1 kernel, so only the plane counts matter.
local function w_init_fans_lateral(m)
   return m.nInputPlane*1*1, m.nOutputPlane*1*1
end

-- Vertical / horizontal convolutions: kernel area only, same for both fans.
local function w_init_fans_kernel(m)
   return 1*m.kH*m.kW, 1*m.kH*m.kW
end

-- Layers sized by their weight matrix: columns are fan_in, rows are fan_out.
local function w_init_fans_weight(m)
   return m.weight:size(2), m.weight:size(1)
end

-- Maps a module's __typename to the function computing its fans.
-- Module types that are not listed are never reset.
local w_init_fans_by_type = {
   ['nn.SpatialConvolution']    = w_init_fans_conv,
   ['nn.SpatialConvolutionMM']  = w_init_fans_conv,
   ['cudnn.SpatialConvolution'] = w_init_fans_conv,
   ['nn.LateralConvolution']    = w_init_fans_lateral,
   ['nn.VerticalConvolution']   = w_init_fans_kernel,
   ['nn.HorizontalConvolution'] = w_init_fans_kernel,
   ['nn.Linear']                = w_init_fans_weight,
   ['nn.TemporalConvolution']   = w_init_fans_weight,
}

-- Re-initialise the modules listed directly in `net.modules`.
--
--   net : object exposing a `modules` array. Only the entries of that
--         array are visited; nested containers are not descended into.
--   arg : scheme name, one of 'heuristic', 'xavier', 'xavier_caffe',
--         'kaiming'.
--
-- For every module whose __typename is a supported layer type, calls
-- m:reset(scale), where scale comes from the chosen scheme applied to the
-- module's (fan_in, fan_out). Every visited module that has a `bias` field
-- gets it zeroed, whether or not its type is supported.
--
-- Returns `net` (modified in place).
-- Raises an error naming the offending value when `arg` is not a known
-- scheme.
local function w_init(net, arg)
   -- scheme name -> function(fan_in, fan_out) producing the reset() argument
   local schemes = {
      heuristic    = w_init_heuristic,
      xavier       = w_init_xavier,
      xavier_caffe = w_init_xavier_caffe,
      kaiming      = w_init_kaiming,
   }

   local method = schemes[arg]
   if method == nil then
      -- Level 2 blames the caller, which is where the bad name came from.
      error(string.format(
         "w_init: unknown initialisation method '%s' " ..
         "(expected 'heuristic', 'xavier', 'xavier_caffe' or 'kaiming')",
         tostring(arg)), 2)
   end

   for i = 1, #net.modules do
      local m = net.modules[i]

      local fans = w_init_fans_by_type[m.__typename]
      if fans then
         -- fans(m) is the last argument, so both of its results are forwarded.
         m:reset(method(fans(m)))
      end

      if m.bias then
         m.bias:zero()
      end
   end

   return net
end
|
|
|
|
|
return w_init |
|
|
|