-- (upload-page artifact, kept as comments so the chunk parses)
-- rawanessam's picture
-- Upload 79 files
-- e9fe176 verified
require 'torch'
require 'nn'
require 'cunn'
require 'cudnn'
require 'fbnn'
require 'fbcunn'
require 'fbcode.deeplearning.experimental.yuandong.layers.custom_layers'
local pl = require 'pl.import_into'()
local tnt = require 'torchnet'
local bistro = require 'bistro'
local cjson = require 'cjson'
-- Layer factories shared by the network builder below:
-- prefer the cudnn implementations when cudnn is loaded, otherwise fall
-- back to the plain nn equivalents.
local conv_layer, relu_layer, maxpool_layer
local backend = cudnn
if backend then
  conv_layer = backend.SpatialConvolution
  relu_layer = backend.ReLU
  maxpool_layer = backend.SpatialMaxPooling
else
  conv_layer = nn.SpatialConvolutionMM
  relu_layer = nn.ReLU
  maxpool_layer = nn.SpatialMaxPooling
end
-- Some local utility that make networks.
-- make_network, it take layer specification + inputdim as input, return the actual layer and output dim.
--- Output extent of a convolution/pooling along one axis.
-- @param inputsize input extent
-- @param kw kernel size
-- @param dw stride (default 1)
-- @param pw zero padding applied to each side (default 0)
-- @return number of output positions
local function conv_layer_size(inputsize, kw, dw, pw)
  local stride = dw or 1
  local pad = pw or 0
  return math.floor((inputsize + 2 * pad - kw) / stride) + 1
end
--- Fill in default spatial fields on `layer` (mutates it in place) and
-- compute the layer's output size.
-- Width-axis fields (kw/pw/dw) act as defaults for the height axis.
-- @param layer spec table; kw is required, kh/pw/ph/dw/dh/nip/nop optional
-- @param inputdim {batch, channels, height, width}
-- @return {batch, layer.nop, out_height, out_width}
local function spatial_layer_size(layer, inputdim)
  layer.nip = layer.nip or inputdim[2]
  layer.nop = layer.nop or layer.nip
  assert(#inputdim == 4, 'Spatial_layer_size: Input dim must be 4 dimensions')
  assert(layer.nip == inputdim[2], string.format('Spatial_layer_size: the number of input channel [%d] is not the same as specified [%d]!', inputdim[2], layer.nip))
  assert(layer.nop, 'Spatial_layer_size: layer.nop is null!')
  -- Width-axis defaults first, then mirror them onto the height axis.
  layer.pw = layer.pw or 0
  layer.dw = layer.dw or 1
  layer.kh = layer.kh or layer.kw
  layer.ph = layer.ph or layer.pw
  layer.dh = layer.dh or layer.dw
  local out_w = conv_layer_size(inputdim[4], layer.kw, layer.dw, layer.pw)
  local out_h = conv_layer_size(inputdim[3], layer.kh, layer.dh, layer.ph)
  return {inputdim[1], layer.nop, out_h, out_w}
end
--- Expand a (possibly nested) split specification into per-slice dims.
-- For each entry of `spec`, produce a copy of `inputdim` whose `dim`-th
-- entry is replaced by the slice size; table entries recurse, mirroring
-- the nesting of `spec`.
-- @return list of dims (nested like spec), total units consumed along dim
local function spec_expand(dim, spec, inputdim)
  local res = {}
  local total = 0
  for i = 1, #spec do
    -- Bug fix: this_res / this_total were accidental globals originally.
    local this_res, this_total
    if type(spec[i]) == 'table' then
      this_res, this_total = spec_expand(dim, spec[i], inputdim)
    else
      this_res = pl.tablex.deepcopy(inputdim)
      this_res[dim] = spec[i]
      this_total = spec[i]
    end
    table.insert(res, this_res)
    total = total + this_total
  end
  return res, total
end
--- Recursively build an nn module tree from a declarative specification.
-- `layer.type` selects the module; containers ('seq', 'parallel') recurse
-- over their numeric entries. `inputdim` is the incoming size (e.g.
-- {batch, channels, h, w}), or a table of such dims for multi-input layers.
-- Many branches fill missing fields of `layer` in place with defaults.
-- @return module, outputdim
local function make_network(layer, inputdim)
  -- Always trace the spec and incoming size (useful when debugging specs).
  print("++++++++++++++++++++++++++++")
  print("Layer spec = " .. pl.pretty.write(layer, '', false))
  print("Input dim = " .. pl.pretty.write(inputdim, '', false))
  print("----------------------------")
  if not layer.type or layer.type == 'seq' then
    -- Sequential container: thread the dim through the children.
    local seq = nn.Sequential()
    local curr_inputdim = inputdim
    local ll
    for _, l in ipairs(layer) do
      ll, curr_inputdim = make_network(l, curr_inputdim)
      seq:add(ll)
    end
    return seq, curr_inputdim
  elseif layer.type == 'parallel' then
    -- Parallel container: child i consumes inputdim[i].
    local para = nn.ParallelTable()
    local ll
    local outputdim = {}
    for idx, l in ipairs(layer) do
      ll, outputdim[idx] = make_network(l, inputdim[idx])
      para:add(ll)
    end
    return para, outputdim
  elseif layer.type == 'join' then
    assert(type(inputdim) == 'table', 'MakeNetwork::Join: inputdim must be a table!')
    -- Concatenate along layer.dim; every other dimension must agree.
    local outputdim
    for idx, dim in ipairs(inputdim) do
      if not outputdim then
        outputdim = dim
      else
        for dim_idx, d_size in ipairs(dim) do
          if dim_idx == layer.dim then
            outputdim[dim_idx] = outputdim[dim_idx] + d_size
          else
            assert(outputdim[dim_idx] == d_size, 'Join table, dimension ' .. dim_idx .. ' disagree!')
          end
        end
      end
    end
    return nn.JoinTable(layer.dim), outputdim
  elseif layer.type == 'conv' then
    -- 2-D convolution; padding defaults to floor(k/2) ("same"-style).
    layer.nip = layer.nip or inputdim[2]
    layer.dw = layer.dw or 1
    layer.kh = layer.kh or layer.kw
    layer.dh = layer.dh or layer.dw
    layer.pw = layer.pw or math.floor(layer.kw / 2)
    layer.ph = layer.ph or math.floor(layer.kh / 2)
    assert(layer.nip == inputdim[2], string.format('MakeNetwork::Conv: #input channel [%d] must match with specification [%d]!', inputdim[2], layer.nip));
    assert(layer.kw)
    assert(layer.kh)
    assert(layer.pw)
    assert(layer.ph)
    assert(layer.dw)
    assert(layer.dh)
    assert(layer.nip)
    assert(layer.nop)
    -- Idiom fix: do not shadow the module-level `conv_layer` factory.
    local conv = conv_layer(layer.nip, layer.nop, layer.kw, layer.kh, layer.dw, layer.dh, layer.pw, layer.ph)
    return conv, spatial_layer_size(layer, inputdim)
  elseif layer.type == 'relu' then
    return relu_layer(), inputdim
  elseif layer.type == 'bn' then
    assert(#inputdim == 2, 'Error! Input to BatchNormalization must be 2-dimensional.')
    return nn.BatchNormalization(inputdim[2]), inputdim
  elseif layer.type == 'spatialbn' then
    assert(#inputdim == 4, 'Error! Input to SpatialBatchNormalization must be 4-dimensional.')
    return nn.SpatialBatchNormalization(inputdim[2]), inputdim
  elseif layer.type == 'thres' then
    return nn.Threshold(0, 1e-6), inputdim
  elseif layer.type == 'maxp' then
    assert(layer.kw and layer.dw, "MakeNetwork:MaxP: kw and dw should not be nil")
    layer.kh = layer.kh or layer.kw
    layer.dh = layer.dh or layer.dw
    return maxpool_layer(layer.kw, layer.kh, layer.dw, layer.dh), spatial_layer_size(layer, inputdim)
  elseif layer.type == 'maxp1' then
    -- Temporal (1-D) max pooling: inputdim = {batch, length, channels}.
    assert(layer.kw and layer.dw, "MakeNetwork:MaxP1: kw and dw should not be nil")
    local outputdim = { inputdim[1], conv_layer_size(inputdim[2], layer.kw, layer.dw), inputdim[3] }
    return nn.TemporalMaxPooling(layer.kw, layer.dw), outputdim
  elseif layer.type == 'reshape' then
    -- Derive defaults from the reshape direction.
    if layer.dir == '4-2' then
      layer.wi = layer.wi or inputdim[4]
      layer.nip = layer.nip or inputdim[2]
      layer.nop = layer.nop or inputdim[2]*inputdim[3]*inputdim[4]
    elseif layer.dir == '3-2' then
      -- For temporal 1D network, inputdim[2] is the length and inputdim[3] is the number of channels.
      layer.wi = layer.wi or inputdim[2]
      layer.nip = layer.nip or inputdim[3]
      layer.nop = layer.nop or inputdim[2]*inputdim[3]
    end
    if layer.wi then
      -- Reshape from image to vector.
      if layer.dir == '4-2' then
        layer.hi = layer.hi or inputdim[3]
        assert(#inputdim == 4, 'MakeNetwork::Reshape4-2: Input dim must be 4 dimensions')
        local outputsize = { inputdim[1], layer.nip*layer.wi*layer.hi }
        assert(outputsize[2] == inputdim[2]*inputdim[3]*inputdim[4], 'MakeNetwork::Reshape4-2: Input dim must match with specified dimensions')
        assert(outputsize[2] > 0,
          string.format("MakeNetwork::Reshape4-2: outputsize[2] = %d, (nip, wi, hi) = (%d, %d, %d)",
            outputsize[2], layer.nip, layer.wi, layer.hi))
        return nn.View(outputsize[2]), outputsize
      elseif layer.dir == '3-2' then
        assert(#inputdim == 3, 'MakeNetwork::Reshape3-2: Input dim must be 3 dimensions')
        local outputsize = { inputdim[1], layer.nip*layer.wi }
        assert(outputsize[2] == inputdim[2]*inputdim[3], 'MakeNetwork::Reshape3-2: Input dim must match with specified dimensions')
        assert(outputsize[2] > 0,
          string.format("MakeNetwork::Reshape3-2: outputsize[2] = %d, (nip, wi) = (%d, %d)",
            outputsize[2], layer.nip, layer.wi))
        return nn.View(outputsize[2]), outputsize
      end
    elseif layer.wo then
      -- Reshape from vector to image.
      assert(#inputdim == 2, 'MakeNetwork::Reshape2-4: Input dim must be 2 dimensions')
      layer.nip = layer.nip or inputdim[2]
      layer.ho = layer.ho or layer.wo
      layer.nop = layer.nop or inputdim[2] / (layer.wo * layer.ho)
      local outputsize = { inputdim[1], layer.nop, layer.ho, layer.wo }
      assert(outputsize[2]*outputsize[3]*outputsize[4] == inputdim[2], 'MakeNetwork::Reshape2-4: Input dim must match with specified dimensions')
      return nn.View(layer.nop, layer.ho, layer.wo), outputsize
    end
    -- Robustness fix: the original fell through and returned nil here,
    -- making the caller (seq:add) fail with a confusing error.
    error("MakeNetwork::Reshape: specification needs dir='4-2'/'3-2', wi, or wo")
  elseif layer.type == 'fc' then
    assert(#inputdim == 2, 'MakeNetwork::FC: Input dim must be 2 dimensions')
    layer.nip = layer.nip or inputdim[2]
    assert(layer.nip == inputdim[2], string.format('MakeNetwork::FC: the number of input channel [%d] is not the same as specified [%d]!', inputdim[2], layer.nip))
    return nn.Linear(layer.nip, layer.nop), { inputdim[1], layer.nop }
  elseif layer.type == 'conv1' then
    -- Temporal (1-D) convolution.
    -- inputdim[1]: batchsize, inputdim[2]: input length, inputdim[3]: nip.
    assert(layer.kw, 'MakeNetwork:Conv1: kw must be specified')
    assert(#inputdim == 3, 'MakeNetwork:Conv1: Input dim must be 3 dimensions')
    assert(layer.nop, 'MakeNetwork:Conv1: nop must be specified')
    layer.nip = layer.nip or inputdim[3]
    layer.dw = layer.dw or 1
    assert(layer.nip == inputdim[3], string.format('MakeNetwork::Conv1: the number of input channels [%d] is not the same as specified [%d]!', inputdim[3], layer.nip))
    local outputdim = {inputdim[1], conv_layer_size(inputdim[2], layer.kw, layer.dw), layer.nop}
    return nn.TemporalConvolution(layer.nip, layer.nop, layer.kw, layer.dw), outputdim
  elseif layer.type == 'usample' then
    -- Nearest-neighbor upsampling; requires square input (height == width).
    assert(#inputdim == 4, 'MakeNetwork::USample: Input dim must be 4 dimensions')
    layer.wi = layer.wi or inputdim[3]
    -- Typo fix in messages: "much match" -> "must match".
    assert(layer.wi == inputdim[3], string.format('MakeNetwork::USample: Input height [%d] must match with specification [%d].', inputdim[3], layer.wi))
    assert(layer.wi == inputdim[4], string.format('MakeNetwork::USample: Input width [%d] must match with specification [%d].', inputdim[4], layer.wi))
    return nn.SpatialUpSamplingNearest(layer.wo / layer.wi), { inputdim[1], inputdim[2], layer.wo, layer.wo }
  elseif layer.type == 'recursive-split' then
    -- Bug fix: outputdim / total_use were accidental globals originally.
    local outputdim, total_use = spec_expand(layer.dim, layer.spec, inputdim)
    assert(total_use == inputdim[layer.dim], string.format("MakeNetwork::RecursiveSplitTable: Total usage specified by layer.spec [%d] is not the same as the inputdim[%d] (which is %d)", total_use, layer.dim, inputdim[layer.dim]))
    return nn.RecursiveSplitTable(layer.dim - 1, #inputdim - 1, layer.spec), outputdim
  elseif layer.type == 'addtable' then
    -- Elementwise sum of several inputs that must all share one shape.
    assert(type(inputdim) == 'table', "MakeNetwork::addtable: Inputdim must be a table.")
    assert(inputdim[1], "MakeNetwork::addtable: Inputdim must not be empty.")
    for i = 2, #inputdim do
      assert(#inputdim[i] == #inputdim[1], string.format("MakeNetwork::addtable: Each entry of input dim must be of the same length, yet #input[%d] = %d while #inputdim[1] = %d", i, #inputdim[i], #inputdim[1]))
      for j = 1, #inputdim[i] do
        assert(inputdim[i][j] == inputdim[1][j], string.format("MakeNetwork::addtable: Each entry of inputdim must be of same size. Now inputdim[%d][%d] = %d while inputdim[1][%d] = %d", i, j, inputdim[i][j], j, inputdim[1][j]))
      end
    end
    return nn.CAddTable(), inputdim[1]
  elseif layer.type == 'dropout' then
    return nn.Dropout(layer.ratio), inputdim
  elseif layer.type == 'logsoftmax' then
    return nn.LogSoftMax(), inputdim
  else
    error("Unknown layer type " .. layer.type);
  end
end
--- Concatenate a list of array-like tables into one fresh array.
-- Bug fix: the original aliased tbls[1] as the result and appended into
-- it, mutating the caller's first table; this version builds a new table.
-- @param tbls list of sequences
-- @return a new sequence containing all elements in order
local function merge_tables(tbls)
  local res = {}
  for i = 1, #tbls do
    local t = tbls[i]
    for j = 1, #t do
      res[#res + 1] = t[j]
    end
  end
  return res
end
-- Public module table exporting the builders defined above.
local nnutils = {
  make_network = make_network,
  spatial_layer_size = spatial_layer_size,
  merge_tables = merge_tables,
}

-- Debug switch consulted by nnutils.dprint below.
local g_nn_dbg = false

--- Turn debug printing on.
nnutils.dbg_set = function()
  g_nn_dbg = true
end

--- Turn debug printing off.
nnutils.dbg_clear = function()
  g_nn_dbg = false
end
--- Print `s` only when debugging is enabled (see dbg_set/dbg_clear).
-- With extra arguments, `s` is treated as a string.format pattern.
function nnutils.dprint(s, ...)
  if g_nn_dbg then
    local args = {...}
    if #args > 0 then
      print(string.format(s, unpack(args)))
    else
      print(s)
    end
  end
end
-- local debug_mapping = {}
-- function nnutils.dbg_set_mapping(key, value)
-- debug_mapping[key] = value
-- end
-- function nnutils.
-- Get to know whether we are in a bistro run or in a local run, by checking the name of local directory..
--- Heuristic: are we running under bistro? True when the current working
-- directory lives under the bistro gfs prefix.
-- @return boolean
function nnutils.in_bistro()
  local pipe = io.popen('pwd')
  local cwd = pipe:read("*all")
  pipe:close() -- bug fix: the original leaked this process handle
  local bistro_prefix = '/gfsai-bistro'
  return string.sub(cwd, 1, #bistro_prefix) == bistro_prefix
end
--- Encode `t` as JSON prefixed with the 'json_stats: ' marker that log
-- scrapers look for.
function nnutils.json_stats(t)
  local encoded = cjson.encode(t)
  return 'json_stats: ' .. encoded
end
--- Deep-copy any torch-serializable object by round-tripping it through
-- an in-memory file.
-- @return a structurally independent clone of obj
function nnutils.deepcopy(obj)
  -- Bug fix: `file` was an accidental global; also close it when done.
  local file = torch.MemoryFile() -- creates a file in memory
  file:writeObject(obj)           -- writes the object into file
  file:seek(1)                    -- rewind to the beginning
  local copy = file:readObject()  -- read back a clone of the object
  file:close()
  return copy
end
--- Copy entries of t2 into t1 for keys t1 does not define, in place.
-- Bug fix: uses an explicit nil check so that a legitimate `false` stored
-- in t1 is preserved (the original `not t1[k]` treated false as missing).
-- @return t1 (mutated)
function nnutils.add_if_nonexist(t1, t2)
  for k, v in pairs(t2) do
    if t1[k] == nil then t1[k] = v end
  end
  return t1
end
--- Return the first key in `keys` whose value in `t` is truthy, along
-- with that value; nil, nil when none match.
function nnutils.get_first_available(t, keys)
  for i = 1, #keys do
    local key = keys[i]
    local value = t[key]
    if value then
      return key, value
    end
  end
  return nil, nil
end
---------- Layer-wise operation -----------------
--- Collect the direct submodules of `model` whose torch type name matches
-- the Lua pattern `name`.
-- @return array of matching modules (possibly empty)
function nnutils.pick_layers(model, name)
  local matched = {}
  for _, mod in ipairs(model.modules) do
    if torch.typename(mod):match(name) then
      matched[#matched + 1] = mod
    end
  end
  return matched
end
--- Apply `func` to every direct submodule of `model` whose torch type
-- name matches the Lua pattern `name`.
function nnutils.operate_layers(model, name, func)
  for _, mod in ipairs(model.modules) do
    if torch.typename(mod):match(name) then
      func(mod)
    end
  end
end
--- Install hooks every trainer/tester should have: garbage collection on
-- 'forward', a cutorch sync on 'update' (when the rack supports that hook,
-- so per-batch timings are accurate), and a per-batch timer meter.
-- Trainer hooks: start, start-epoch, sample, forward, backward, update,
-- end-epoch, end. Tester hooks: start, start-epoch, sample, forward,
-- end-epoch, end.
function nnutils.add_regular_hooks(rack)
  rack:addHook('forward', function() collectgarbage() end)
  if rack.hooks.update then
    rack:addHook('update', function() cutorch.synchronize() end)
  end
  local tntexp = require 'fbcode.deeplearning.experimental.yuandong.torchnet.init'
  -- time one iteration
  tntexp.TimeMeter{ rack = rack, label = "time", perbatch = true }
end
--- Hook the trainer so that, at the end of each epoch, selected log
-- columns (plus epoch number, a timestamp and the negated test loss) are
-- emitted as a bistro log record merged with `config`.
function nnutils.set_output_json(log, trainer, config)
local epoch = 1
local entries_to_log = {'trainloss', 'testloss', 'train top@1', 'train top@5', 'test top@1', 'test top@5' }
trainer:addHook(
'end-epoch',
function()
-- Save a few things to json
local perf_table = { epoch = epoch }
for _, entry in ipairs(entries_to_log) do
-- Only include columns the logger actually tracks.
if log.key2idx[entry] then
perf_table[entry] = log:get(entry)
end
end
-- Entries for current time stamp
-- NOTE(review): os.clock() is CPU time, not wall-clock time — confirm
-- that is the intended "timestamp".
perf_table.timestamp = os.clock()
-- Special entry for whetlab (must be the last one)
perf_table.neg_loss = -log:get("testloss")
bistro.log(pl.tablex.merge(perf_table, config, true))
epoch = epoch + 1
end
)
end
--- Add an end-of-epoch hook that snapshots a stateless clone of `net` to
-- `saveto` (a file, or a directory receiving model.bin) whenever the
-- tracked loss improves. Also adds a 'saved' log column showing '*' on
-- epochs where a snapshot was written.
function nnutils.add_save_on_trainer(log, net, trainer, saveto)
  -- customized model save
  log:column('saved')
  -- The clone shares only parameters/statistics, so the saved model
  -- carries no gradient or optimizer state.
  local netsav = net:clone('weight', 'bias', 'running_mean', 'running_std')
  local minerr = math.huge
  trainer:addHook(
    'end-epoch',
    function()
      -- Prefer the test loss; fall back to the training loss.
      local valid_col = nnutils.get_first_available(log.key2idx, { 'testloss', 'trainloss' } )
      local z = log:get(valid_col)
      if z and z < minerr then
        -- Bug fix: savefile was an accidental global in the original.
        local savefile
        if pl.path.isdir(saveto) then
          savefile = string.format('%s/model.bin', saveto)
        else
          savefile = saveto
        end
        local f = torch.DiskFile(savefile, 'w')
        f:binary()
        f:writeObject(netsav)
        f:close()
        minerr = z
        log:set('saved', '*')
      else
        log:set('saved', '')
      end
    end
  )
end
--- After each epoch, flush the log to its file and echo a labeled,
-- pipe-separated summary to stdout.
function nnutils.add_logging(trainer, log)
  local function emit()
    log:print{}
    log:print{stdout = true, labels = true, separator = ' | '}
  end
  trainer:addHook('end-epoch', emit)
end
--- Optionally reload a saved model (and its saved config) from the
-- directory named by config.reload ('' means no reload).
-- @param config table with `reload` and `evalOnly` fields
-- @param model_name filename of the serialized model inside the reload dir
-- @param config_name optional filename of the serialized config
-- @return net (nil when not reloading), possibly-replaced config
function nnutils.reload_if(config, model_name, config_name)
-- reload?
if config.evalOnly and config.reload == '' then
error("evalOnly only works if there is a model to be reloaded!")
end
local net
if config.reload ~='' then
-- Ensure the classes referenced by the serialized model can deserialize.
require 'nn'
require 'cutorch'
require 'cunn'
require 'cudnn'
print(string.format('| reloading experiment %s', config.reload))
local f = torch.DiskFile(string.format('%s/%s', config.reload, model_name))
f:binary()
net = f:readObject()
f:close()
if config_name then
local oldconfig_file = string.format('%s/%s', config.reload, config_name)
if pl.path.exists(oldconfig_file) then
local oldconfig = torch.load(oldconfig_file)
-- Current command-line options win over the stored ones; the stored
-- reload/subdir fields are cleared and save is pointed at the reload dir.
oldconfig.subdir = nil
tnt.utils.table.merge(oldconfig, config)
oldconfig.reload = nil
oldconfig.save = config.reload
config = oldconfig
end
end
end
return net, config
end
--------------------------- Parse text to word table index ----------------------------
--- Tokenize text (or a file's contents) and map each token through
-- `word2index`, silently dropping tokens that have no index.
-- @param s the text itself, or a filename when `isfilename` is true
-- @param isfilename boolean
-- @param word2index map token -> integer index
-- @param per_char when true, tokenize one character at a time
-- @return array of integer indices
function nnutils.parse_to_idx(s, isfilename, word2index, per_char)
local word_indices = {}
-- Check the words
-- "." matches any single character (including whitespace);
-- "[^%s]+" matches whitespace-separated words.
local sep = per_char and "." or "[^%s]+"
local content
if isfilename then
local f = torch.DiskFile(s)
content = f:readString('*a') -- NOTE: this reads the whole file at once
print(string.format("size of content = %d", string.len(content)))
f:close()
else
content = s
end
for token in string.gmatch(content, sep) do
local index = word2index[token]
-- Unknown tokens are skipped without warning.
if index ~= nil then
table.insert(word_indices, index)
end
end
print(string.format("Number of tokens = %d", #word_indices))
return word_indices
end
--- Split a 1-D token tensor into (input, target) batch lists for
-- truncated-BPTT language modeling. Targets are the inputs shifted one
-- step left; the final target wraps around to the first token.
-- Data that does not divide evenly into batch_size*seq_length chunks is
-- truncated (with a notice).
-- @return x_batches, y_batches: parallel lists of (batch_size x seq_length) views
function nnutils.split_seq_into_batch(data, batch_size, seq_length)
  -- Cut them into batches.
  local len = data:size(1)
  local chunk = batch_size * seq_length
  -- Idiom fix: the original declared `ydata` twice (shadowed local).
  local xdata
  if len % chunk ~= 0 then
    print('cutting off end of data so that the batches/sequences divide evenly')
    xdata = data:sub(1, chunk * math.floor(len / chunk))
  else
    xdata = data
  end
  local ydata = xdata:clone()
  ydata:sub(1,-2):copy(xdata:sub(2,-1))
  ydata[-1] = xdata[1]
  local x_batches = xdata:view(batch_size, -1):split(seq_length, 2) -- #rows = #batches
  local y_batches = ydata:view(batch_size, -1):split(seq_length, 2) -- #rows = #batches
  assert(#x_batches == #y_batches)
  return x_batches, y_batches
end
--------------------------- Deal with configurations ----------------------------------
--- Gather the run configuration from bistro, merged with defaults, and
-- (when config.reload is set) reload a previously-saved model and config.
-- NOTE(review): this largely duplicates nnutils.reload_if — consider
-- unifying the two code paths.
-- @param default_config table of caller defaults (mutated by the merge)
-- @return config, net (net is nil unless a model was reloaded)
function nnutils.get_config(default_config)
  local config = bistro.get_params(nnutils.add_if_nonexist(default_config,
    { lr = 0.05, gpu = 1, seed = 1111, nthread = 8, evalOnly = false,
      train_batch = 128, test_batch = 128, max_epoch = 100, reload = '', save = './'}))
  -- reload?
  if config.evalOnly and config.reload == '' then
    error("evalOnly only works if there is a model to be reloaded!")
  end
  local net
  if config.reload ~= '' then
    -- Ensure the classes referenced by the serialized model can deserialize.
    require 'nn'
    require 'cutorch'
    require 'cunn'
    require 'cudnn'
    print(string.format('| reloading experiment %s', config.reload))
    -- config.reload may be a directory (containing model.bin) or a file.
    -- Bug fix: loadfilename was an accidental global in the original.
    local loadfilename
    if pl.path.isdir(config.reload) then
      loadfilename = string.format('%s/model.bin', config.reload)
    else
      loadfilename = config.reload
    end
    local f = torch.DiskFile(loadfilename)
    f:binary()
    net = f:readObject()
    f:close()
    local oldconfig_file = string.format('%s/config.bin', config.reload)
    if pl.path.exists(oldconfig_file) then
      local oldconfig = torch.load(oldconfig_file)
      -- Current options take precedence over the stored ones.
      oldconfig.subdir = nil
      tnt.utils.table.merge(oldconfig, config)
      oldconfig.reload = nil
      oldconfig.save = config.reload
      config = oldconfig
    end
  end
  -- execute lua code in command line with config as environment
  -- tnt.utils.sys.cmdline(arg, config)
  print(pl.pretty.write(config))
  return config, net
end
-------------------------- Simple Torchnet framework ----------------------------------
--- Build a torchnet merge transform: lists of numbers become a Tensor,
-- lists of Tensors are stacked into one tensor, and anything else is
-- passed through unchanged.
function nnutils.torchnet_custom_merge()
  local transform = require 'torchnet.transform'
  local utils = require 'torchnet.utils'
  local function merge_field(field)
    if type(field) == 'table' and field[1] then
      local first = field[1]
      if type(first) == 'number' then
        return torch.Tensor(field)
      end
      local tname = torch.typename(first)
      if tname and tname:match('Tensor') then
        return utils.table.mergetensor(field)
      end
    end
    return field
  end
  return transform.tableapply(merge_field)
end
--- Wrap a dataset closure for the old torchnet pipeline: run it across
-- `nthread` worker threads (0 = in the calling process) and move samples
-- to CUDA.
-- `nbatch` is currently unused; kept for call-site compatibility.
local function old_wrap_dataset(dataset_closure, nthread, nbatch)
  local tntexp = require('fbcode.deeplearning.experimental.yuandong.torchnet.init')
  if nthread == 0 then
    return tntexp.CudaDataset(dataset_closure())
  end
  return tntexp.CudaDataset(tntexp.ParallelDataset{
    nthread = nthread,
    closure = dataset_closure
  })
end
--- Train (and optionally evaluate) `net` with the old torchnet pipeline:
-- sets up the criterion, meters, periodic logging and best-model saving,
-- then runs the SGD trainer for config.max_epoch epochs.
-- @param train dataset closure producing the training dataset
-- @param test dataset closure for test data (nil/false skips testing)
-- @param net nn module (moved to CUDA here)
-- @param crit_type 'classification' (ClassNLL) or 'reconstruction' (MSE)
-- @param config table: save, nthread, train_batch, test_batch, lr, max_epoch
-- @return the logger object
function nnutils.run_old_torchnet(train, test, net, crit_type, config)
local tntexp = require('fbcode.deeplearning.experimental.yuandong.torchnet.init')
net = net:cuda()
local crit
if crit_type == "classification" then
crit = nn.ClassNLLCriterion()
elseif crit_type == "reconstruction" then
crit = nn.MSECriterion()
end
crit = crit:cuda()
config.logger_filename = config.logger_filename or string.format('%s/log', config.save)
config.model_filename = config.model_filename or string.format("%s/model.bin", config.save)
local log = tntexp.Logger{ filename = config.logger_filename }
local trainer = tntexp.SGDTrainer(log)
-- NOTE(review): add_regular_hooks only declares a (rack) parameter; the
-- extra crit argument here is ignored — confirm intended.
nnutils.add_regular_hooks(trainer, crit)
-- check the average criterion value
tntexp.AverageValueMeter{ rack = trainer, eval = function() return crit.output end, label = "trainloss" }
if crit_type == "classification" then
tntexp.ClassErrorMeter{ rack = trainer, eval = function() return net.output end, topk = {5,1}, label = "train" }
end
-- tester
if test then
local tester = tntexp.SGDTester(log)
nnutils.add_regular_hooks(tester, crit)
tntexp.AverageValueMeter{ rack = tester, eval = function() return crit.output end, label = "testloss" }
if crit_type == "classification" then
tntexp.ClassErrorMeter{ rack = tester, eval = function() return net.output end, topk = {5,1}, label = "test" }
end
-- we hook it to the trainer
tester:test{
network = net,
dataset = old_wrap_dataset(test, config.nthread, config.test_batch),
rack = trainer
}
end
-- customized model save
nnutils.add_save_on_trainer(log, net, trainer, config.model_filename)
-- log
-- Note that this has to be put in the last, otherwise since the statistics are not fully collected, it will error.
trainer:addHook(
'end-epoch',
function()
log:print{}
log:print{stdout = true, labels = true, separator = ' | '}
end
)
-- go
log:header{}
trainer:train{
network = net,
criterion = crit,
dataset = old_wrap_dataset(train, config.nthread, config.train_batch),
lr = config.lr,
maxepoch = config.max_epoch
}
return log
end
--- Wrap a dataset closure for the new torchnet pipeline: a plain sampler
-- when nthread == 0, otherwise a parallel sampler over worker threads.
local function wrap_dataset(dataset_closure, nthread)
  local tnt = require('torchnet')
  if nthread == 0 then
    return tnt.DatasetSampler(dataset_closure())
  end
  return tnt.ParallelDatasetSampler{
    nthread = nthread,
    closure = dataset_closure
  }
end
--- Train and evaluate `net` with the event-driven torchnet engine.
-- Drives engine:train, resetting meters at 'start-epoch', accumulating
-- loss/error at 'update', and at 'end-epoch' running a test pass, saving
-- the best model so far, and printing a one-line summary.
-- @param train dataset closure for training data
-- @param test dataset closure for test data
-- @param net nn module (moved to CUDA here)
-- @param crit_type 'classification' (ClassNLL) or 'reconstruction' (MSE)
-- @param config table: save, nthread, lr, max_epoch (+ optional filenames)
-- @return log_terms: map of column name -> thunk returning latest value
function nnutils.run_torchnet(train, test, net, crit_type, config)
local tnt = require('torchnet')
nnutils.dprint("Put network to cuda")
net = net:cuda()
local crit
if crit_type == "classification" then
crit = nn.ClassNLLCriterion()
elseif crit_type == "reconstruction" then
crit = nn.MSECriterion()
end
nnutils.dprint("Put crit to cuda")
crit = crit:cuda()
config.logger_filename = config.logger_filename or string.format('%s/log', config.save)
config.model_filename = config.model_filename or string.format("%s/model.bin", config.save)
local log = tnt.Logger{ filename = config.logger_filename }
local engine = tnt.SGDEngine()
-- time one iteration
local timer = tnt.TimeMeter{ per = true }
-- check the average criterion value
local trainloss = tnt.AverageValueMeter()
local testloss = tnt.AverageValueMeter()
local trainerr = tnt.ClassErrorMeter{ topk = {5,1} }
local testerr = tnt.ClassErrorMeter{ topk = {5,1} }
local saved = false
-- Column name -> thunk evaluated when the epoch summary line is built.
local log_terms = {
timer = function () return timer:value()*1000 end,
trainloss = function () return trainloss:value() end,
-- testloss = function () return testloss:value() end,
trainerr1 = function () return trainerr:value(1) end,
testerr1 = function () return testerr:value(1) end,
trainerr5 = function () return trainerr:value(5) end,
testerr5 = function () return testerr:value(5) end,
saved = function () return saved and '*' or ' ' end
}
-- customized model save
-- we save a stateless model
-- NOTE(review): unlike add_save_on_trainer, this clone omits
-- running_mean/running_std — confirm bn statistics are meant to be dropped.
local netsav = net:clone('weight', 'bias')
local minerr = math.huge
local train_wrapper = wrap_dataset(train, config.nthread)
local test_wrapper = wrap_dataset(test, config.nthread)
for event, state in
engine:train{ network = net, criterion = crit, sampler = train_wrapper,
lr = config.lr, maxepoch = config.max_epoch } do
if event == 'start-epoch' then
-- Fresh meters and timer at the top of every epoch.
trainloss:reset()
trainerr:reset()
timer:reset()
timer:resume()
elseif event == 'update' then
-- Accumulate training statistics per batch.
trainloss:add(state.criterion.output)
trainerr:add(state.network.output, state.sample.target)
cutorch.synchronize()
collectgarbage()
timer:inc()
elseif event == 'end-epoch' then
timer:stop()
-- test
for event, state in engine:test{ network = net, sampler = test_wrapper } do
if event == 'start-epoch' then
testerr:reset()
-- testloss:reset()
elseif event == 'forward' then
collectgarbage()
-- testloss:add(state.criterion.output)
testerr:add(state.network.output, state.sample.target)
end
end
-- save if better than ever
local z = testerr:value(1)
if z < minerr then
local f = torch.DiskFile(config.model_filename, 'w')
f:binary()
f:writeObject(netsav)
f:close()
minerr = z
saved = true
else
saved = false
end
-- spit out log
local messages = { string.format(" epoch: %d", state.epoch) }
for k, v in pairs(log_terms) do
local value = v()
if type(value) == 'number' then
-- Integers print without decimals; other numbers with two places.
if value == math.ceil(value) then
value = string.format("%d", value)
else
value = string.format("%.2f", value)
end
end
table.insert(messages, string.format("%s: %s", k, value))
end
log:print(table.concat(messages, " | "))
end
end
return log_terms
-- return log
end
--------------------------- Remove batch normalization ---------------------------------
--- Fold a BatchNormalization layer `bn` into the linear layer `linear`
-- (convolution or fully-connected) that precedes it, in place, so the
-- linear layer alone computes linear-then-bn and the bn layer can be
-- dropped afterwards.
-- No-op when either argument is nil or the types do not match; errors
-- when a bn layer is found without a preceding linear layer.
-- @return the mutated linear layer (or nil when nothing was merged)
local function merge_layer(bn, linear)
if bn == nil or linear == nil then return end
local bn_matched = (torch.type(bn) == "nn.SpatialBatchNormalization" or torch.type(bn) == "nn.BatchNormalization")
local linear_matched = (torch.type(linear) == "cudnn.SpatialConvolution" or torch.type(linear) == "nn.Linear")
assert(not bn_matched or linear_matched, "Find BatchNormalization layer but linear layer is missing!")
if (not bn_matched) or (not linear_matched) then return end
--[[
Shapes involved:
linear.weight = channelo * channeli * kh * kw
linear.bias = channelo
bn.weight = channelo
bn.bias = channelo
bn.running_mean = channelo
bn.running_std = channelo
--]]
-- Note that running_std is the inverse of std.
-- Do the arithmetic on the GPU holding the bn statistics.
local device_id = bn.running_mean:getDevice()
cutorch.withDevice(device_id,
function ()
-- bn(x) = weight * (x - mean) * running_std + bias
--       = scale * x + shift, with scale = weight * running_std and
--         shift = bias - scale * mean; fold both into linear.
local scale = bn.running_std:clone()
local shift = bn.running_mean:clone()
scale:cmul(bn.weight)
shift:cmul(scale):mul(-1.0):add(bn.bias)
for i = 1, linear.weight:size(1) do
linear.weight[i]:mul(scale[i])
end
linear.bias:cmul(scale):add(shift)
end)
return linear
end
--- Walk the module tree and fold every BatchNormalization layer into the
-- linear/conv layer that immediately precedes it (see merge_layer).
local function recursive_merge_layer(model)
  local previous
  for i = 1, #model.modules do
    local current = model.modules[i]
    if current.modules then
      -- Container: recurse into its children.
      recursive_merge_layer(current)
    else
      -- Leaf: try folding it (as a bn layer) into its predecessor.
      merge_layer(current, previous)
    end
    previous = current
  end
end
--- Rebuild the module tree, dropping every leaf whose torch type appears
-- in `except_mod_names`. Containers are cloned with filtered children;
-- an excluded leaf yields nil.
local function rebuild_layers_except(old_model, except_mod_names)
  if not old_model.modules then
    -- Leaf: keep it unless its type is on the exclusion list.
    if pl.tablex.find(except_mod_names, torch.type(old_model)) == nil then
      return old_model
    end
    return nil
  end
  local new_model = old_model:clone()
  new_model.modules = {}
  for i = 1, #old_model.modules do
    local child = rebuild_layers_except(old_model.modules[i], except_mod_names)
    if child ~= nil then new_model:add(child) end
  end
  return new_model
end
--- Fold batch-normalization statistics into the preceding linear layers
-- (mutating `model`), then return a rebuilt copy with the now-redundant
-- bn layers removed.
function nnutils.remove_batchnorm(model)
  recursive_merge_layer(model)
  local bn_types = { "nn.BatchNormalization", "nn.SpatialBatchNormalization" }
  return rebuild_layers_except(model, bn_types)
end
---------------------------- Remove all data parallel parallel ---------------------
--- Return a clone of `model` in which every nn.DataParallel wrapper is
-- replaced by a (recursively cleaned) copy of its first replica.
local function remove_data_parallel(model)
  local rebuilt = {}
  for _, mod in ipairs(model.modules) do
    if torch.typename(mod) == 'nn.DataParallel' then
      rebuilt[#rebuilt + 1] = remove_data_parallel(mod.modules[1])
    else
      rebuilt[#rebuilt + 1] = mod:clone()
    end
  end
  local new_model = model:clone()
  new_model.modules = rebuilt
  return new_model
end
nnutils.remove_data_parallel = remove_data_parallel
---------------------------- Convert between different permutations of classes ---------------------
--- Load a class list from `f`: tables pass through unchanged, '.t7' files
-- go through torch.load, '.lst' files through torchnet's loadlist.
-- Returns nil for unrecognized string extensions.
local function load_list(f)
  if type(f) ~= 'string' then return f end
  local ext = pl.path.extension(f)
  if ext == '.t7' then
    return torch.load(f)
  end
  if ext == '.lst' then
    return (require 'torchnet.utils.sys').loadlist(f, true)
  end
end
--- Build a converter from the class ordering of `sourcefile` to that of
-- `targetfile`; `name_converter` maps a source class name to its target
-- class name. The returned function accepts a single index, a table of
-- indices, or a torch.DoubleTensor of indices.
function nnutils.classconverter(sourcefile, targetfile, name_converter)
  local src = load_list(sourcefile)
  local dst = load_list(targetfile)
  -- Invert the target list: class name -> target index.
  local dst_inv = {}
  for i, v in ipairs(dst) do
    dst_inv[v] = i
  end
  return function (srcidx)
    if type(srcidx) == 'number' then
      return dst_inv[name_converter(src[srcidx])]
    elseif type(srcidx) == 'table' then
      local res = {}
      for _, i in ipairs(srcidx) do
        table.insert(res, dst_inv[name_converter(src[i])])
      end
      return res
    elseif torch.typename(srcidx) == 'torch.DoubleTensor' then
      local res = torch.DoubleTensor(srcidx:size())
      for i = 1, res:nElement() do
        -- Bug fix: the original looked up src[i] (the loop counter)
        -- instead of the index stored in the tensor.
        res[i] = dst_inv[name_converter(src[srcidx[i]])]
      end
      return res
    end
  end
end
------------------ Compute the prediction ------------------
--- Top one accuracy
--- Top-1 prediction on a batch: forward s.input through the model and
-- compare the argmax class against s.target.
-- @return max_indices (argmax per row), accuracy in [0, 1]
function nnutils.predict_compare(model, s)
  model:evaluate()
  local inputs = s.input:cuda()
  local scores = model:forward(inputs)
  local _, max_indices = torch.max(scores:float(), 2)
  local hits = s.target:long():eq(max_indices:long()):sum()
  local accuracy = hits / s.input:size(1)
  return max_indices, accuracy
end
-- Predict the top k error
-- s.input is the data, s.target is the label, topk = { 1, 3, 5} e.g.
--- Compute top-k error rates on a batch.
-- s.input is the data, s.target the labels; topk lists the k values to
-- evaluate (e.g. {1, 3, 5}).
-- @return table mapping each k in topk to the fraction of samples whose
--         target class is NOT among the k highest-scoring predictions
function nnutils.predict_top(model, s, topk)
model:evaluate()
local data_cuda = s.input:cuda()
local output = model:forward(data_cuda)
local sum = {}
local maxk = 0
for _,k in ipairs(topk) do
sum[k] = 0
maxk = math.max(maxk, k)
end
-- Sort class scores descending; pred[i] lists class indices best-first.
local _, pred = output:double():sort(2, true)
local no = output:size(1)
for i=1,no do
local predi = pred[i]
local targi = s.target[i]
-- Rank of the true class among the top maxk predictions (inf if absent).
local minik = math.huge
for k=1,maxk do
if predi[k] == targi then
minik = k
break
end
end
-- Sample i counts as an error for k when its true class ranks below k.
for _,k in ipairs(topk) do
if minik > k then
sum[k] = sum[k]+1
end
end
end
-- Normalize error counts into rates.
for _, k in ipairs(topk) do
sum[k] = sum[k] / no
end
return sum
end
------------------------------ Misc ---------------------------------------
-- Check layers with given name in net1 and net2, and skip any difference caused by bn sign.
--- Walk net1 and net2 in lockstep over direct children whose torch type
-- equals `layername` and verify their parameters agree up to sign
-- (|w1| == |w2| within 1e-4) — folding batch norm can flip weight signs.
-- Errors on the first mismatch; prints progress and returns nothing on
-- success.
function nnutils.compare_network_skip_bn_upto_sign(net1, net2, layername)
local i = 1
local j = 1
print(string.format("#net1.modules = %d", #net1.modules))
print(string.format("#net2.modules = %d", #net2.modules))
while true do
print("i = ", i)
print("j = ", j)
-- Advance each cursor to the next layer of the requested type.
while i <= #net1.modules do
if torch.type(net1.modules[i]) ~= layername then i = i + 1 else break end
end
if i > #net1.modules then break end
while j <= #net2.modules do
if torch.type(net2.modules[j]) ~= layername then j = j + 1 else break end
end
if j > #net2.modules then break end
-- Compare their parameters
local w1 = net1.modules[i]:parameters()
local w2 = net2.modules[j]:parameters()
if #w1 ~= #w2 then
print("compare_network_skip_bn_upto_sign: Dimension mismatch.")
print(string.format("net1.modules[%d] = ", i))
print(#w1)
print(string.format("net2.modules[%d] = ", j))
print(#w2)
error("")
end
-- Compare raw storages elementwise on absolute values.
for k = 1, #w1 do
local s1 = w1[k]:storage()
local s2 = w2[k]:storage()
for l = 1, s1:size() do
if math.abs(math.abs(s1[l]) - math.abs(s2[l])) > 1e-4 then
error(string.format("Weight net1.modules[%d][%d][%d] (= %f) is different from net2.modules[%d][%d][%d] (= %f)", i, k, l, s1[l], j, k, l, s2[l]))
end
end
end
i = i + 1
j = j + 1
end
end
-- Export the module table.
return nnutils