|
|
require 'torch' |
|
|
require 'nn' |
|
|
require 'cunn' |
|
|
require 'cudnn' |
|
|
require 'fbnn' |
|
|
require 'fbcunn' |
|
|
require 'fbcode.deeplearning.experimental.yuandong.layers.custom_layers' |
|
|
|
|
|
local pl = require 'pl.import_into'() |
|
|
local tnt = require 'torchnet' |
|
|
local bistro = require 'bistro' |
|
|
local cjson = require 'cjson' |
|
|
|
|
|
|
|
|
|
|
|
local conv_layer, relu_layer, maxpool_layer |
|
|
|
|
|
if cudnn then |
|
|
conv_layer = cudnn.SpatialConvolution |
|
|
relu_layer = cudnn.ReLU |
|
|
maxpool_layer = cudnn.SpatialMaxPooling |
|
|
else |
|
|
conv_layer = nn.SpatialConvolutionMM |
|
|
relu_layer = nn.ReLU |
|
|
maxpool_layer = nn.SpatialMaxPooling |
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
local function conv_layer_size(inputsize, kw, dw, pw) |
|
|
dw = dw or 1 |
|
|
pw = pw or 0 |
|
|
return math.floor((inputsize + pw * 2 - kw) / dw) + 1 |
|
|
end |
|
|
|
|
|
local function spatial_layer_size(layer, inputdim) |
|
|
layer.nip = layer.nip or inputdim[2] |
|
|
layer.nop = layer.nop or layer.nip |
|
|
|
|
|
|
|
|
assert(#inputdim == 4, 'Spatial_layer_size: Input dim must be 4 dimensions') |
|
|
assert(layer.nip == inputdim[2], string.format('Spatial_layer_size: the number of input channel [%d] is not the same as specified [%d]!', inputdim[2], layer.nip)) |
|
|
assert(layer.nop, 'Spatial_layer_size: layer.nop is null!') |
|
|
|
|
|
layer.pw = layer.pw or 0 |
|
|
layer.dw = layer.dw or 1 |
|
|
|
|
|
layer.kh = layer.kh or layer.kw |
|
|
layer.ph = layer.ph or layer.pw |
|
|
layer.dh = layer.dh or layer.dw |
|
|
|
|
|
local outputw = conv_layer_size(inputdim[4], layer.kw, layer.dw, layer.pw) |
|
|
local outputh = conv_layer_size(inputdim[3], layer.kh, layer.dh, layer.ph) |
|
|
|
|
|
return {inputdim[1], layer.nop, outputh, outputw} |
|
|
end |
|
|
|
|
|
local function spec_expand(dim, spec, inputdim) |
|
|
local res = {} |
|
|
local total = 0 |
|
|
for i = 1, #spec do |
|
|
if type(spec[i]) == 'table' then |
|
|
this_res, this_total = spec_expand(dim, spec[i], inputdim) |
|
|
else |
|
|
this_res = pl.tablex.deepcopy(inputdim) |
|
|
this_res[dim] = spec[i] |
|
|
this_total = spec[i] |
|
|
end |
|
|
table.insert(res, this_res) |
|
|
total = total + this_total |
|
|
end |
|
|
return res, total |
|
|
end |
|
|
|
|
|
local function make_network(layer, inputdim) |
|
|
|
|
|
print("++++++++++++++++++++++++++++") |
|
|
print("Layer spec = " .. pl.pretty.write(layer, '', false)) |
|
|
print("Input dim = " .. pl.pretty.write(inputdim, '', false)) |
|
|
print("----------------------------") |
|
|
|
|
|
|
|
|
if not layer.type or layer.type == 'seq' then |
|
|
local seq = nn.Sequential() |
|
|
|
|
|
|
|
|
local curr_inputdim = inputdim |
|
|
local ll |
|
|
for _, l in ipairs(layer) do |
|
|
ll, curr_inputdim = make_network(l, curr_inputdim) |
|
|
seq:add(ll) |
|
|
end |
|
|
return seq, curr_inputdim |
|
|
|
|
|
elseif layer.type == 'parallel' then |
|
|
local para = nn.ParallelTable() |
|
|
|
|
|
local ll |
|
|
local outputdim = {} |
|
|
for idx, l in ipairs(layer) do |
|
|
ll, outputdim[idx] = make_network(l, inputdim[idx]) |
|
|
para:add(ll) |
|
|
end |
|
|
return para, outputdim |
|
|
elseif layer.type == 'join' then |
|
|
assert(type(inputdim) == 'table', 'MakeNetwork::Join: inputdim must be a table!') |
|
|
|
|
|
local outputdim |
|
|
for idx, dim in ipairs(inputdim) do |
|
|
if not outputdim then |
|
|
outputdim = dim |
|
|
else |
|
|
for dim_idx, d_size in ipairs(dim) do |
|
|
if dim_idx == layer.dim then |
|
|
outputdim[dim_idx] = outputdim[dim_idx] + d_size |
|
|
else |
|
|
assert(outputdim[dim_idx] == d_size, 'Join table, dimension ' .. dim_idx .. ' disagree!') |
|
|
end |
|
|
end |
|
|
end |
|
|
end |
|
|
return nn.JoinTable(layer.dim), outputdim |
|
|
|
|
|
elseif layer.type == 'conv' then |
|
|
layer.nip = layer.nip or inputdim[2] |
|
|
layer.dw = layer.dw or 1 |
|
|
|
|
|
layer.kh = layer.kh or layer.kw |
|
|
layer.dh = layer.dh or layer.dw |
|
|
|
|
|
layer.pw = layer.pw or math.floor(layer.kw / 2) |
|
|
layer.ph = layer.ph or math.floor(layer.kh / 2) |
|
|
|
|
|
assert(layer.nip == inputdim[2], string.format('MakeNetwork::Conv: #input channel [%d] must match with specification [%d]!', inputdim[2], layer.nip)); |
|
|
assert(layer.kw) |
|
|
assert(layer.kh) |
|
|
assert(layer.pw) |
|
|
assert(layer.ph) |
|
|
assert(layer.dw) |
|
|
assert(layer.dh) |
|
|
assert(layer.nip) |
|
|
assert(layer.nop) |
|
|
|
|
|
local conv_layer = conv_layer(layer.nip, layer.nop, layer.kw, layer.kh, layer.dw, layer.dh, layer.pw, layer.ph) |
|
|
return conv_layer, spatial_layer_size(layer, inputdim) |
|
|
|
|
|
elseif layer.type == 'relu' then |
|
|
return relu_layer(), inputdim |
|
|
|
|
|
elseif layer.type == 'bn' then |
|
|
assert(#inputdim == 2, 'Error! Input to BatchNormalization must be 2-dimensional.') |
|
|
return nn.BatchNormalization(inputdim[2]), inputdim |
|
|
|
|
|
elseif layer.type == 'spatialbn' then |
|
|
assert(#inputdim == 4, 'Error! Input to SpatialBatchNormalization must be 4-dimensional.') |
|
|
return nn.SpatialBatchNormalization(inputdim[2]), inputdim |
|
|
|
|
|
elseif layer.type == 'thres' then |
|
|
return nn.Threshold(0, 1e-6), inputdim |
|
|
|
|
|
elseif layer.type == 'maxp' then |
|
|
assert(layer.kw and layer.dw, "MakeNetwork:MaxP: kw and dw should not be nil") |
|
|
|
|
|
layer.kh = layer.kh or layer.kw |
|
|
layer.dh = layer.dh or layer.dw |
|
|
|
|
|
return maxpool_layer(layer.kw, layer.kh, layer.dw, layer.dh), spatial_layer_size(layer, inputdim) |
|
|
|
|
|
elseif layer.type == 'maxp1' then |
|
|
assert(layer.kw and layer.dw, "MakeNetwork:MaxP1: kw and dw should not be nil") |
|
|
local outputdim = { inputdim[1], conv_layer_size(inputdim[2], layer.kw, layer.dw), inputdim[3] } |
|
|
return nn.TemporalMaxPooling(layer.kw, layer.dw), outputdim |
|
|
|
|
|
elseif layer.type == 'reshape' then |
|
|
if layer.dir == '4-2' then |
|
|
layer.wi = layer.wi or inputdim[4] |
|
|
layer.nip = layer.nip or inputdim[2] |
|
|
layer.nop = layer.nop or inputdim[2]*inputdim[3]*inputdim[4] |
|
|
elseif layer.dir == '3-2' then |
|
|
|
|
|
layer.wi = layer.wi or inputdim[2] |
|
|
layer.nip = layer.nip or inputdim[3] |
|
|
layer.nop = layer.nop or inputdim[2]*inputdim[3] |
|
|
end |
|
|
|
|
|
if layer.wi then |
|
|
|
|
|
if layer.dir == '4-2' then |
|
|
layer.hi = layer.hi or inputdim[3] |
|
|
|
|
|
assert(#inputdim == 4, 'MakeNetwork::Reshape4-2: Input dim must be 4 dimensions') |
|
|
local outputsize = { inputdim[1], layer.nip*layer.wi*layer.hi } |
|
|
assert(outputsize[2] == inputdim[2]*inputdim[3]*inputdim[4], 'MakeNetwork::Reshape4-2: Input dim must match with specified dimensions') |
|
|
assert(outputsize[2] > 0, |
|
|
string.format("MakeNetwork::Reshape4-2: outputsize[2] = %d, (nip, wi, hi) = (%d, %d, %d)", |
|
|
outputsize[2], layer.nip, layer.wi, layer.hi)) |
|
|
return nn.View(outputsize[2]), outputsize |
|
|
elseif layer.dir == '3-2' then |
|
|
assert(#inputdim == 3, 'MakeNetwork::Reshape3-2: Input dim must be 3 dimensions') |
|
|
local outputsize = { inputdim[1], layer.nip*layer.wi } |
|
|
assert(outputsize[2] == inputdim[2]*inputdim[3], 'MakeNetwork::Reshape3-2: Input dim must match with specified dimensions') |
|
|
assert(outputsize[2] > 0, |
|
|
string.format("MakeNetwork::Reshape3-2: outputsize[2] = %d, (nip, wi) = (%d, %d)", |
|
|
outputsize[2], layer.nip, layer.wi)) |
|
|
return nn.View(outputsize[2]), outputsize |
|
|
end |
|
|
elseif layer.wo then |
|
|
|
|
|
assert(#inputdim == 2, 'MakeNetwork::Reshape2-4: Input dim must be 2 dimensions') |
|
|
layer.nip = layer.nip or inputdim[2] |
|
|
layer.ho = layer.ho or layer.wo |
|
|
layer.nop = layer.nop or inputdim[2] / (layer.wo * layer.ho) |
|
|
|
|
|
local outputsize = { inputdim[1], layer.nop, layer.ho, layer.wo } |
|
|
assert(outputsize[2]*outputsize[3]*outputsize[4] == inputdim[2], 'MakeNetwork::Reshape2-4: Input dim must match with specified dimensions') |
|
|
return nn.View(layer.nop, layer.ho, layer.wo), outputsize |
|
|
end |
|
|
elseif layer.type == 'fc' then |
|
|
assert(#inputdim == 2, 'MakeNetwork::FC: Input dim must be 2 dimensions') |
|
|
layer.nip = layer.nip or inputdim[2] |
|
|
assert(layer.nip == inputdim[2], string.format('MakeNetwork::FC: the number of input channel [%d] is not the same as specified [%d]!', inputdim[2], layer.nip)) |
|
|
return nn.Linear(layer.nip, layer.nop), { inputdim[1], layer.nop } |
|
|
|
|
|
elseif layer.type == 'conv1' then |
|
|
|
|
|
|
|
|
|
|
|
assert(layer.kw, 'MakeNetwork:Conv1: kw must be specified') |
|
|
assert(#inputdim == 3, 'MakeNetwork:Conv1: Input dim must be 3 dimensions') |
|
|
assert(layer.nop, 'MakeNetwork:Conv1: nop must be specified') |
|
|
|
|
|
layer.nip = layer.nip or inputdim[3] |
|
|
layer.dw = layer.dw or 1 |
|
|
|
|
|
assert(layer.nip == inputdim[3], string.format('MakeNetwork::Conv1: the number of input channels [%d] is not the same as specified [%d]!', inputdim[3], layer.nip)) |
|
|
local outputdim = {inputdim[1], conv_layer_size(inputdim[2], layer.kw, layer.dw), layer.nop} |
|
|
return nn.TemporalConvolution(layer.nip, layer.nop, layer.kw, layer.dw), outputdim |
|
|
|
|
|
elseif layer.type == 'usample' then |
|
|
assert(#inputdim == 4, 'MakeNetwork::USample: Input dim must be 4 dimensions') |
|
|
layer.wi = layer.wi or inputdim[3] |
|
|
assert(layer.wi == inputdim[3], string.format('MakeNetwork::USample: Input height [%d] much match with specification [%d].', inputdim[3], layer.wi)) |
|
|
assert(layer.wi == inputdim[4], string.format('MakeNetwork::USample: Input width [%d] much match with specification [%d].', inputdim[4], layer.wi)) |
|
|
return nn.SpatialUpSamplingNearest(layer.wo / layer.wi), { inputdim[1], inputdim[2], layer.wo, layer.wo } |
|
|
elseif layer.type == 'recursive-split' then |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
outputdim, total_use = spec_expand(layer.dim, layer.spec, inputdim) |
|
|
assert(total_use == inputdim[layer.dim], string.format("MakeNetwork::RecursiveSplitTable: Total usage specified by layer.spec [%d] is not the same as the inputdim[%d] (which is %d)", total_use, layer.dim, inputdim[layer.dim])) |
|
|
|
|
|
return nn.RecursiveSplitTable(layer.dim - 1, #inputdim - 1, layer.spec), outputdim |
|
|
elseif layer.type == 'addtable' then |
|
|
assert(type(inputdim) == 'table', "MakeNetwork::addtable: Inputdim must be a table.") |
|
|
assert(inputdim[1], "MakeNetwork::addtable: Inputdim must not be empty.") |
|
|
|
|
|
for i = 2, #inputdim do |
|
|
assert(#inputdim[i] == #inputdim[1], string.format("MakeNetwork::addtable: Each entry of input dim must be of the same length, yet #input[%d] = %d while #inputdim[1] = %d", i, #inputdim[i], #inputdim[1])) |
|
|
for j = 1, #inputdim[i] do |
|
|
assert(inputdim[i][j] == inputdim[1][j], string.format("MakeNetwork::addtable: Each entry of inputdim must be of same size. Now inputdim[%d][%d] = %d while inputdim[1][%d] = %d", i, j, inputdim[i][j], j, inputdim[1][j])) |
|
|
end |
|
|
end |
|
|
|
|
|
return nn.CAddTable(), inputdim[1] |
|
|
elseif layer.type == 'dropout' then |
|
|
return nn.Dropout(layer.ratio), inputdim |
|
|
elseif layer.type == 'logsoftmax' then |
|
|
return nn.LogSoftMax(), inputdim |
|
|
else |
|
|
error("Unknown layer type " .. layer.type); |
|
|
end |
|
|
end |
|
|
|
|
|
local function merge_tables(tbls) |
|
|
if #tbls == 0 then return {} end |
|
|
local res = tbls[1] |
|
|
for i = 2,#tbls do |
|
|
for j = 1, #tbls[i] do |
|
|
table.insert(res, tbls[i][j]) |
|
|
end |
|
|
end |
|
|
|
|
|
return res |
|
|
end |
|
|
|
|
|
local nnutils = { |
|
|
make_network = make_network, |
|
|
spatial_layer_size = spatial_layer_size, |
|
|
merge_tables = merge_tables |
|
|
} |
|
|
|
|
|
|
|
|
local g_nn_dbg = false |
|
|
|
|
|
function nnutils.dbg_set() |
|
|
g_nn_dbg = true |
|
|
end |
|
|
|
|
|
function nnutils.dbg_clear() |
|
|
g_nn_dbg = false |
|
|
end |
|
|
|
|
|
function nnutils.dprint(s, ...) |
|
|
if g_nn_dbg then |
|
|
local p = {...} |
|
|
if #p == 0 then print(s) |
|
|
else print(string.format(s, unpack(p))) |
|
|
end |
|
|
end |
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function nnutils.in_bistro() |
|
|
local cwd = io.popen('pwd'):read("*all") |
|
|
local bistro_prefix = '/gfsai-bistro' |
|
|
return string.sub(cwd, 1, #bistro_prefix) == bistro_prefix |
|
|
end |
|
|
|
|
|
function nnutils.json_stats(t) |
|
|
return 'json_stats: ' .. cjson.encode(t) |
|
|
end |
|
|
|
|
|
function nnutils.deepcopy(obj) |
|
|
file = torch.MemoryFile() |
|
|
file:writeObject(obj) |
|
|
file:seek(1) |
|
|
return file:readObject() |
|
|
end |
|
|
|
|
|
function nnutils.add_if_nonexist(t1, t2) |
|
|
for k, v in pairs(t2) do |
|
|
if not t1[k] then t1[k] = v end |
|
|
end |
|
|
return t1 |
|
|
end |
|
|
|
|
|
function nnutils.get_first_available(t, keys) |
|
|
for _, k in ipairs(keys) do |
|
|
if t[k] then return k, t[k] end |
|
|
end |
|
|
return nil, nil |
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
function nnutils.pick_layers(model, name) |
|
|
local layers = {} |
|
|
for i = 1, #model.modules do |
|
|
local m = model.modules[i] |
|
|
if torch.typename(m):match(name) then |
|
|
table.insert(layers, m) |
|
|
end |
|
|
end |
|
|
return layers |
|
|
end |
|
|
|
|
|
function nnutils.operate_layers(model, name, func) |
|
|
for i = 1, #model.modules do |
|
|
local m = model.modules[i] |
|
|
if torch.typename(m):match(name) then |
|
|
func(m) |
|
|
end |
|
|
end |
|
|
end |
|
|
|
|
|
function nnutils.add_regular_hooks(rack) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rack:addHook('forward', function() collectgarbage() end) |
|
|
if rack.hooks.update then |
|
|
rack:addHook('update', function() cutorch.synchronize() end) |
|
|
end |
|
|
|
|
|
local tntexp = require 'fbcode.deeplearning.experimental.yuandong.torchnet.init' |
|
|
|
|
|
|
|
|
tntexp.TimeMeter{ rack = rack, label = "time", perbatch = true } |
|
|
end |
|
|
|
|
|
function nnutils.set_output_json(log, trainer, config) |
|
|
local epoch = 1 |
|
|
local entries_to_log = {'trainloss', 'testloss', 'train top@1', 'train top@5', 'test top@1', 'test top@5' } |
|
|
|
|
|
trainer:addHook( |
|
|
'end-epoch', |
|
|
function() |
|
|
|
|
|
local perf_table = { epoch = epoch } |
|
|
for _, entry in ipairs(entries_to_log) do |
|
|
if log.key2idx[entry] then |
|
|
perf_table[entry] = log:get(entry) |
|
|
end |
|
|
end |
|
|
|
|
|
|
|
|
perf_table.timestamp = os.clock() |
|
|
|
|
|
|
|
|
perf_table.neg_loss = -log:get("testloss") |
|
|
|
|
|
bistro.log(pl.tablex.merge(perf_table, config, true)) |
|
|
epoch = epoch + 1 |
|
|
end |
|
|
) |
|
|
end |
|
|
|
|
|
function nnutils.add_save_on_trainer(log, net, trainer, saveto) |
|
|
|
|
|
log:column('saved') |
|
|
local netsav = net:clone('weight', 'bias', 'running_mean', 'running_std') |
|
|
local minerr = math.huge |
|
|
trainer:addHook( |
|
|
'end-epoch', |
|
|
function() |
|
|
local valid_col = nnutils.get_first_available(log.key2idx, { 'testloss', 'trainloss' } ) |
|
|
local z = log:get(valid_col) |
|
|
if z and z < minerr then |
|
|
if pl.path.isdir(saveto) then |
|
|
savefile = string.format('%s/model.bin', saveto) |
|
|
else |
|
|
savefile = saveto |
|
|
end |
|
|
local f = torch.DiskFile(savefile, 'w') |
|
|
f:binary() |
|
|
f:writeObject(netsav) |
|
|
f:close() |
|
|
minerr = z |
|
|
log:set('saved', '*') |
|
|
else |
|
|
log:set('saved', '') |
|
|
end |
|
|
end |
|
|
) |
|
|
end |
|
|
|
|
|
function nnutils.add_logging(trainer, log) |
|
|
trainer:addHook( |
|
|
'end-epoch', |
|
|
function() |
|
|
log:print{} |
|
|
log:print{stdout = true, labels = true, separator = ' | '} |
|
|
end |
|
|
) |
|
|
end |
|
|
|
|
|
function nnutils.reload_if(config, model_name, config_name) |
|
|
|
|
|
if config.evalOnly and config.reload == '' then |
|
|
error("evalOnly only works if there is a model to be reloaded!") |
|
|
end |
|
|
|
|
|
local net |
|
|
if config.reload ~='' then |
|
|
require 'nn' |
|
|
require 'cutorch' |
|
|
require 'cunn' |
|
|
require 'cudnn' |
|
|
print(string.format('| reloading experiment %s', config.reload)) |
|
|
local f = torch.DiskFile(string.format('%s/%s', config.reload, model_name)) |
|
|
f:binary() |
|
|
net = f:readObject() |
|
|
f:close() |
|
|
|
|
|
if config_name then |
|
|
local oldconfig_file = string.format('%s/%s', config.reload, config_name) |
|
|
if pl.path.exists(oldconfig_file) then |
|
|
local oldconfig = torch.load(oldconfig_file) |
|
|
oldconfig.subdir = nil |
|
|
tnt.utils.table.merge(oldconfig, config) |
|
|
oldconfig.reload = nil |
|
|
oldconfig.save = config.reload |
|
|
config = oldconfig |
|
|
end |
|
|
end |
|
|
end |
|
|
return net, config |
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
function nnutils.parse_to_idx(s, isfilename, word2index, per_char) |
|
|
local word_indices = {} |
|
|
|
|
|
local sep = per_char and "." or "[^%s]+" |
|
|
|
|
|
local content |
|
|
if isfilename then |
|
|
local f = torch.DiskFile(s) |
|
|
content = f:readString('*a') |
|
|
print(string.format("size of content = %d", string.len(content))) |
|
|
f:close() |
|
|
else |
|
|
content = s |
|
|
end |
|
|
|
|
|
for token in string.gmatch(content, sep) do |
|
|
local index = word2index[token] |
|
|
if index ~= nil then |
|
|
table.insert(word_indices, index) |
|
|
end |
|
|
end |
|
|
print(string.format("Number of tokens = %d", #word_indices)) |
|
|
|
|
|
return word_indices |
|
|
end |
|
|
|
|
|
function nnutils.split_seq_into_batch(data, batch_size, seq_length) |
|
|
|
|
|
local len = data:size(1) |
|
|
local xdata, ydata |
|
|
|
|
|
if len % (batch_size * seq_length) ~= 0 then |
|
|
print('cutting off end of data so that the batches/sequences divide evenly') |
|
|
xdata = data:sub(1, batch_size * seq_length * math.floor(len / (batch_size * seq_length))) |
|
|
else |
|
|
xdata = data |
|
|
end |
|
|
|
|
|
local ydata = xdata:clone() |
|
|
ydata:sub(1,-2):copy(xdata:sub(2,-1)) |
|
|
ydata[-1] = xdata[1] |
|
|
|
|
|
local x_batches = xdata:view(batch_size, -1):split(seq_length, 2) |
|
|
local y_batches = ydata:view(batch_size, -1):split(seq_length, 2) |
|
|
assert(#x_batches == #y_batches) |
|
|
|
|
|
return x_batches, y_batches |
|
|
end |
|
|
|
|
|
|
|
|
function nnutils.get_config(default_config) |
|
|
local config = bistro.get_params(nnutils.add_if_nonexist(default_config, |
|
|
{ lr = 0.05, gpu = 1, seed = 1111, nthread = 8, evalOnly = false, |
|
|
train_batch = 128, test_batch = 128, max_epoch = 100, reload = '', save = './'})) |
|
|
|
|
|
|
|
|
if config.evalOnly and config.reload == '' then |
|
|
error("evalOnly only works if there is a model to be reloaded!") |
|
|
end |
|
|
|
|
|
local net |
|
|
if config.reload ~= '' then |
|
|
require 'nn' |
|
|
require 'cutorch' |
|
|
require 'cunn' |
|
|
require 'cudnn' |
|
|
print(string.format('| reloading experiment %s', config.reload)) |
|
|
if pl.path.isdir(config.reload) then |
|
|
loadfilename = string.format('%s/model.bin', config.reload) |
|
|
else |
|
|
loadfilename = config.reload |
|
|
end |
|
|
local f = torch.DiskFile(loadfilename) |
|
|
f:binary() |
|
|
net = f:readObject() |
|
|
f:close() |
|
|
|
|
|
local oldconfig_file = string.format('%s/config.bin', config.reload) |
|
|
if pl.path.exists(oldconfig_file) then |
|
|
local oldconfig = torch.load(oldconfig_file) |
|
|
oldconfig.subdir = nil |
|
|
tnt.utils.table.merge(oldconfig, config) |
|
|
oldconfig.reload = nil |
|
|
oldconfig.save = config.reload |
|
|
config = oldconfig |
|
|
end |
|
|
|
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
print(pl.pretty.write(config)) |
|
|
|
|
|
return config, net |
|
|
end |
|
|
|
|
|
|
|
|
function nnutils.torchnet_custom_merge() |
|
|
local transform = require 'torchnet.transform' |
|
|
local utils = require 'torchnet.utils' |
|
|
return transform.tableapply( |
|
|
function (field) |
|
|
if type(field) == 'table' and field[1] then |
|
|
if type(field[1]) == 'number' then |
|
|
return torch.Tensor(field) |
|
|
elseif torch.typename(field[1]) and torch.typename(field[1]):match('Tensor') then |
|
|
return utils.table.mergetensor(field) |
|
|
end |
|
|
end |
|
|
return field |
|
|
end) |
|
|
end |
|
|
|
|
|
local function old_wrap_dataset(dataset_closure, nthread, nbatch) |
|
|
local tntexp = require('fbcode.deeplearning.experimental.yuandong.torchnet.init') |
|
|
if nthread == 0 then |
|
|
|
|
|
return tntexp.CudaDataset(dataset_closure()) |
|
|
else |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return tntexp.CudaDataset(tntexp.ParallelDataset{ |
|
|
nthread = nthread, |
|
|
closure = dataset_closure |
|
|
}) |
|
|
end |
|
|
end |
|
|
|
|
|
function nnutils.run_old_torchnet(train, test, net, crit_type, config) |
|
|
local tntexp = require('fbcode.deeplearning.experimental.yuandong.torchnet.init') |
|
|
|
|
|
net = net:cuda() |
|
|
|
|
|
local crit |
|
|
if crit_type == "classification" then |
|
|
crit = nn.ClassNLLCriterion() |
|
|
elseif crit_type == "reconstruction" then |
|
|
crit = nn.MSECriterion() |
|
|
end |
|
|
crit = crit:cuda() |
|
|
|
|
|
config.logger_filename = config.logger_filename or string.format('%s/log', config.save) |
|
|
config.model_filename = config.model_filename or string.format("%s/model.bin", config.save) |
|
|
|
|
|
local log = tntexp.Logger{ filename = config.logger_filename } |
|
|
|
|
|
local trainer = tntexp.SGDTrainer(log) |
|
|
nnutils.add_regular_hooks(trainer, crit) |
|
|
|
|
|
tntexp.AverageValueMeter{ rack = trainer, eval = function() return crit.output end, label = "trainloss" } |
|
|
if crit_type == "classification" then |
|
|
tntexp.ClassErrorMeter{ rack = trainer, eval = function() return net.output end, topk = {5,1}, label = "train" } |
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if test then |
|
|
local tester = tntexp.SGDTester(log) |
|
|
nnutils.add_regular_hooks(tester, crit) |
|
|
|
|
|
tntexp.AverageValueMeter{ rack = tester, eval = function() return crit.output end, label = "testloss" } |
|
|
if crit_type == "classification" then |
|
|
tntexp.ClassErrorMeter{ rack = tester, eval = function() return net.output end, topk = {5,1}, label = "test" } |
|
|
end |
|
|
|
|
|
|
|
|
tester:test{ |
|
|
network = net, |
|
|
dataset = old_wrap_dataset(test, config.nthread, config.test_batch), |
|
|
rack = trainer |
|
|
} |
|
|
end |
|
|
|
|
|
|
|
|
nnutils.add_save_on_trainer(log, net, trainer, config.model_filename) |
|
|
|
|
|
|
|
|
|
|
|
trainer:addHook( |
|
|
'end-epoch', |
|
|
function() |
|
|
log:print{} |
|
|
log:print{stdout = true, labels = true, separator = ' | '} |
|
|
end |
|
|
) |
|
|
|
|
|
|
|
|
log:header{} |
|
|
|
|
|
trainer:train{ |
|
|
network = net, |
|
|
criterion = crit, |
|
|
dataset = old_wrap_dataset(train, config.nthread, config.train_batch), |
|
|
lr = config.lr, |
|
|
maxepoch = config.max_epoch |
|
|
} |
|
|
return log |
|
|
end |
|
|
|
|
|
|
|
|
local function wrap_dataset(dataset_closure, nthread) |
|
|
local tnt = require('torchnet') |
|
|
if nthread == 0 then |
|
|
|
|
|
return tnt.DatasetSampler(dataset_closure()) |
|
|
else |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
local dataset = tnt.ParallelDatasetSampler{ |
|
|
nthread = nthread, |
|
|
closure = dataset_closure |
|
|
} |
|
|
|
|
|
return dataset |
|
|
end |
|
|
end |
|
|
|
|
|
function nnutils.run_torchnet(train, test, net, crit_type, config) |
|
|
local tnt = require('torchnet') |
|
|
|
|
|
nnutils.dprint("Put network to cuda") |
|
|
net = net:cuda() |
|
|
|
|
|
local crit |
|
|
if crit_type == "classification" then |
|
|
crit = nn.ClassNLLCriterion() |
|
|
elseif crit_type == "reconstruction" then |
|
|
crit = nn.MSECriterion() |
|
|
end |
|
|
|
|
|
nnutils.dprint("Put crit to cuda") |
|
|
crit = crit:cuda() |
|
|
|
|
|
config.logger_filename = config.logger_filename or string.format('%s/log', config.save) |
|
|
config.model_filename = config.model_filename or string.format("%s/model.bin", config.save) |
|
|
|
|
|
local log = tnt.Logger{ filename = config.logger_filename } |
|
|
|
|
|
local engine = tnt.SGDEngine() |
|
|
|
|
|
|
|
|
local timer = tnt.TimeMeter{ per = true } |
|
|
|
|
|
|
|
|
local trainloss = tnt.AverageValueMeter() |
|
|
local testloss = tnt.AverageValueMeter() |
|
|
local trainerr = tnt.ClassErrorMeter{ topk = {5,1} } |
|
|
local testerr = tnt.ClassErrorMeter{ topk = {5,1} } |
|
|
local saved = false |
|
|
|
|
|
local log_terms = { |
|
|
timer = function () return timer:value()*1000 end, |
|
|
trainloss = function () return trainloss:value() end, |
|
|
|
|
|
trainerr1 = function () return trainerr:value(1) end, |
|
|
testerr1 = function () return testerr:value(1) end, |
|
|
trainerr5 = function () return trainerr:value(5) end, |
|
|
testerr5 = function () return testerr:value(5) end, |
|
|
saved = function () return saved and '*' or ' ' end |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
local netsav = net:clone('weight', 'bias') |
|
|
local minerr = math.huge |
|
|
|
|
|
local train_wrapper = wrap_dataset(train, config.nthread) |
|
|
local test_wrapper = wrap_dataset(test, config.nthread) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for event, state in |
|
|
engine:train{ network = net, criterion = crit, sampler = train_wrapper, |
|
|
lr = config.lr, maxepoch = config.max_epoch } do |
|
|
if event == 'start-epoch' then |
|
|
|
|
|
trainloss:reset() |
|
|
trainerr:reset() |
|
|
timer:reset() |
|
|
timer:resume() |
|
|
elseif event == 'update' then |
|
|
|
|
|
|
|
|
trainloss:add(state.criterion.output) |
|
|
trainerr:add(state.network.output, state.sample.target) |
|
|
cutorch.synchronize() |
|
|
collectgarbage() |
|
|
timer:inc() |
|
|
elseif event == 'end-epoch' then |
|
|
|
|
|
|
|
|
timer:stop() |
|
|
|
|
|
|
|
|
for event, state in engine:test{ network = net, sampler = test_wrapper } do |
|
|
if event == 'start-epoch' then |
|
|
testerr:reset() |
|
|
|
|
|
elseif event == 'forward' then |
|
|
collectgarbage() |
|
|
|
|
|
testerr:add(state.network.output, state.sample.target) |
|
|
end |
|
|
end |
|
|
|
|
|
|
|
|
local z = testerr:value(1) |
|
|
if z < minerr then |
|
|
local f = torch.DiskFile(config.model_filename, 'w') |
|
|
f:binary() |
|
|
f:writeObject(netsav) |
|
|
f:close() |
|
|
minerr = z |
|
|
saved = true |
|
|
else |
|
|
saved = false |
|
|
end |
|
|
|
|
|
local messages = { string.format(" epoch: %d", state.epoch) } |
|
|
for k, v in pairs(log_terms) do |
|
|
local value = v() |
|
|
if type(value) == 'number' then |
|
|
if value == math.ceil(value) then |
|
|
value = string.format("%d", value) |
|
|
else |
|
|
value = string.format("%.2f", value) |
|
|
end |
|
|
end |
|
|
table.insert(messages, string.format("%s: %s", k, value)) |
|
|
end |
|
|
log:print(table.concat(messages, " | ")) |
|
|
end |
|
|
end |
|
|
return log_terms |
|
|
|
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
local function merge_layer(bn, linear) |
|
|
if bn == nil or linear == nil then return end |
|
|
|
|
|
local bn_matched = (torch.type(bn) == "nn.SpatialBatchNormalization" or torch.type(bn) == "nn.BatchNormalization") |
|
|
local linear_matched = (torch.type(linear) == "cudnn.SpatialConvolution" or torch.type(linear) == "nn.Linear") |
|
|
|
|
|
assert(not bn_matched or linear_matched, "Find BatchNormalization layer but linear layer is missing!") |
|
|
if (not bn_matched) or (not linear_matched) then return end |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
local device_id = bn.running_mean:getDevice() |
|
|
cutorch.withDevice(device_id, |
|
|
function () |
|
|
local scale = bn.running_std:clone() |
|
|
local shift = bn.running_mean:clone() |
|
|
|
|
|
scale:cmul(bn.weight) |
|
|
shift:cmul(scale):mul(-1.0):add(bn.bias) |
|
|
|
|
|
for i = 1, linear.weight:size(1) do |
|
|
linear.weight[i]:mul(scale[i]) |
|
|
end |
|
|
linear.bias:cmul(scale):add(shift) |
|
|
end) |
|
|
|
|
|
return linear |
|
|
end |
|
|
|
|
|
local function recursive_merge_layer(model) |
|
|
local prev_mod |
|
|
for i = 1, #model.modules do |
|
|
local mod = model.modules[i] |
|
|
if mod.modules then |
|
|
recursive_merge_layer(mod) |
|
|
else |
|
|
merge_layer(mod, prev_mod) |
|
|
end |
|
|
prev_mod = mod |
|
|
end |
|
|
end |
|
|
|
|
|
local function rebuild_layers_except(old_model, except_mod_names) |
|
|
if old_model.modules then |
|
|
local new_model = old_model:clone() |
|
|
new_model.modules = {} |
|
|
|
|
|
for i = 1, #old_model.modules do |
|
|
local layer = rebuild_layers_except(old_model.modules[i], except_mod_names) |
|
|
if layer ~= nil then new_model:add(layer) end |
|
|
end |
|
|
return new_model |
|
|
else |
|
|
if pl.tablex.find(except_mod_names, torch.type(old_model)) == nil then |
|
|
return old_model |
|
|
end |
|
|
end |
|
|
end |
|
|
|
|
|
function nnutils.remove_batchnorm(model) |
|
|
|
|
|
recursive_merge_layer(model) |
|
|
|
|
|
return rebuild_layers_except(model, { "nn.BatchNormalization", "nn.SpatialBatchNormalization" }) |
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
local function remove_data_parallel(model) |
|
|
local res = {} |
|
|
for i, m in ipairs(model.modules) do |
|
|
if torch.typename(m) == 'nn.DataParallel' then |
|
|
table.insert(res, remove_data_parallel(m.modules[1])) |
|
|
else |
|
|
table.insert(res, m:clone()) |
|
|
end |
|
|
end |
|
|
|
|
|
local new_model = model:clone() |
|
|
new_model.modules = res |
|
|
return new_model |
|
|
end |
|
|
|
|
|
nnutils.remove_data_parallel = remove_data_parallel |
|
|
|
|
|
|
|
|
|
|
|
local function load_list(f) |
|
|
if type(f) ~= 'string' then return f end |
|
|
|
|
|
local ext = pl.path.extension(f) |
|
|
if ext == '.t7' then |
|
|
return torch.load(f) |
|
|
elseif ext == '.lst' then |
|
|
return (require 'torchnet.utils.sys').loadlist(f, true) |
|
|
end |
|
|
end |
|
|
|
|
|
function nnutils.classconverter(sourcefile, targetfile, name_converter) |
|
|
local src = load_list(sourcefile) |
|
|
local dst = load_list(targetfile) |
|
|
|
|
|
local dst_inv = {} |
|
|
for i, v in ipairs(dst) do |
|
|
dst_inv[v] = i |
|
|
end |
|
|
|
|
|
return function (srcidx) |
|
|
if type(srcidx) == 'number' then |
|
|
return dst_inv[name_converter(src[srcidx])] |
|
|
elseif type(srcidx) == 'table' then |
|
|
local res = {} |
|
|
for _, i in ipairs(srcidx) do |
|
|
table.insert(res, dst_inv[name_converter(src[i])]) |
|
|
end |
|
|
return res |
|
|
elseif torch.typename(srcidx) == 'torch.DoubleTensor' then |
|
|
local res = torch.DoubleTensor(srcidx:size()) |
|
|
for i = 1, res:nElement() do |
|
|
res[i] = dst_inv[name_converter(src[i])] |
|
|
end |
|
|
return res |
|
|
end |
|
|
end |
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function nnutils.predict_compare(model, s) |
|
|
model:evaluate() |
|
|
local data_cuda = s.input:cuda() |
|
|
local res = model:forward(data_cuda) |
|
|
local max_value, max_indices = torch.max(res:float(), 2) |
|
|
|
|
|
local accuracy = s.target:long():eq(max_indices:long()):sum() / s.input:size(1) |
|
|
return max_indices, accuracy |
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
function nnutils.predict_top(model, s, topk) |
|
|
model:evaluate() |
|
|
local data_cuda = s.input:cuda() |
|
|
local output = model:forward(data_cuda) |
|
|
|
|
|
local sum = {} |
|
|
local maxk = 0 |
|
|
for _,k in ipairs(topk) do |
|
|
sum[k] = 0 |
|
|
maxk = math.max(maxk, k) |
|
|
end |
|
|
local _, pred = output:double():sort(2, true) |
|
|
local no = output:size(1) |
|
|
for i=1,no do |
|
|
local predi = pred[i] |
|
|
local targi = s.target[i] |
|
|
local minik = math.huge |
|
|
for k=1,maxk do |
|
|
if predi[k] == targi then |
|
|
minik = k |
|
|
break |
|
|
end |
|
|
end |
|
|
for _,k in ipairs(topk) do |
|
|
if minik > k then |
|
|
sum[k] = sum[k]+1 |
|
|
end |
|
|
end |
|
|
end |
|
|
|
|
|
for _, k in ipairs(topk) do |
|
|
sum[k] = sum[k] / no |
|
|
end |
|
|
|
|
|
return sum |
|
|
end |
|
|
|
|
|
|
|
|
|
|
|
function nnutils.compare_network_skip_bn_upto_sign(net1, net2, layername) |
|
|
local i = 1 |
|
|
local j = 1 |
|
|
|
|
|
print(string.format("#net1.modules = %d", #net1.modules)) |
|
|
print(string.format("#net2.modules = %d", #net2.modules)) |
|
|
|
|
|
while true do |
|
|
print("i = ", i) |
|
|
print("j = ", j) |
|
|
|
|
|
while i <= #net1.modules do |
|
|
if torch.type(net1.modules[i]) ~= layername then i = i + 1 else break end |
|
|
end |
|
|
if i > #net1.modules then break end |
|
|
|
|
|
while j <= #net2.modules do |
|
|
if torch.type(net2.modules[j]) ~= layername then j = j + 1 else break end |
|
|
end |
|
|
if j > #net2.modules then break end |
|
|
|
|
|
|
|
|
local w1 = net1.modules[i]:parameters() |
|
|
local w2 = net2.modules[j]:parameters() |
|
|
if #w1 ~= #w2 then |
|
|
print("compare_network_skip_bn_upto_sign: Dimension mismatch.") |
|
|
print(string.format("net1.modules[%d] = ", i)) |
|
|
print(#w1) |
|
|
print(string.format("net2.modules[%d] = ", j)) |
|
|
print(#w2) |
|
|
error("") |
|
|
end |
|
|
|
|
|
for k = 1, #w1 do |
|
|
local s1 = w1[k]:storage() |
|
|
local s2 = w2[k]:storage() |
|
|
|
|
|
for l = 1, s1:size() do |
|
|
if math.abs(math.abs(s1[l]) - math.abs(s2[l])) > 1e-4 then |
|
|
error(string.format("Weight net1.modules[%d][%d][%d] (= %f) is different from net2.modules[%d][%d][%d] (= %f)", i, k, l, s1[l], j, k, l, s2[l])) |
|
|
end |
|
|
end |
|
|
end |
|
|
|
|
|
i = i + 1 |
|
|
j = j + 1 |
|
|
end |
|
|
end |
|
|
|
|
|
return nnutils |
|
|
|