-- FloorPlanTransformation/util/lua/compute_stats.lua
-- Compute network statistics (ReLU activations/gradients, weight co-activations).
require 'nn'    -- nn.ClassNLLCriterion is used by the backprop-based collectors
require 'cunn'  -- models and criteria are run on the GPU via :cuda()
local transform = require 'torchnet.transform'
local stats = {}
-- Recursively collect every module of `model` whose type name appears as a key
-- in `layer_names`, descending into containers.
local function get_layer(model, layer_names)
    local res = {}
    -- local inv_table = {}
    -- for _, l in pairs(layer_names) do inv_table[l] = true end
    for i = 1, #model.modules do
        local m = model.modules[i]
        if layer_names[torch.typename(m)] then
            table.insert(res, m)
        elseif m.modules then
            -- Recurse into containers such as nn.Sequential / nn.Concat.
            local prev = get_layer(m, layer_names)
            for _, mm in ipairs(prev) do
                table.insert(res, mm)
            end
        end
    end
    return res
end
-- Return all ReLU layers (nn or cudnn) of the model, in forward order.
function stats.get_relus(model)
    local layer_names = { ["nn.ReLU"] = true, ["cudnn.ReLU"] = true }
    return get_layer(model, layer_names)
end
-- Return all fully-connected layers as pairs { previous_module_or_'input', nn.Linear }.
function stats.get_fcs(model)
    local fc_layers = {}
    for i = 1, #model.modules do
        local m = model.modules[i]
        if torch.typename(m) == 'nn.Linear' then
            local prev_layer = (i == 1 and 'input' or model.modules[i - 1])
            table.insert(fc_layers, { prev_layer, model.modules[i] })
        end
    end
    return fc_layers
end
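--[[ A minimal usage sketch (assuming `model` is an nn container such as
     nn.Sequential; the name is hypothetical and not defined in this file):

     local relus = stats.get_relus(model)  -- all nn.ReLU / cudnn.ReLU modules
     local fcs   = stats.get_fcs(model)    -- { {prev_module_or_'input', nn.Linear}, ... }
     print(#relus, #fcs)
]]--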
-- Move the (class, low/high) leading dimensions of each collected tensor to the
-- end: (nClass, 2, d1, ..., dk) -> (d1, ..., dk, nClass, 2).
local function permute_stats(stat_list)
    for j = 1, #stat_list do
        local perm = {}
        for rr = 3, stat_list[j]:nDimension() do table.insert(perm, rr) end
        table.insert(perm, 1); table.insert(perm, 2)
        stat_list[j] = stat_list[j]:permute(unpack(perm))
    end
end
-- Create a nested table of empty tables with the given dimensions, e.g.
-- dims = {2, 3} yields t[i][j] = {} for i = 1..2, j = 1..3.
function stats.create_cell(dims)
    local counter = {}
    for i = 1, #dims do
        counter[i] = 1
    end
    local done = false
    local t = {}
    while not done do
        -- Create the nested tables addressed by the current counter.
        local currt = t
        for i = 1, #dims do
            if not currt[counter[i]] then currt[counter[i]] = {} end
            currt = currt[counter[i]]
        end
        -- Advance the counter, odometer-style (last dimension moves fastest).
        done = true
        for i = #dims, 1, -1 do
            counter[i] = counter[i] + 1
            if counter[i] <= dims[i] then
                done = false
                break
            end
            counter[i] = 1
        end
    end
    return t
end
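--[[ A minimal usage sketch for stats.create_cell (the dimensions and the
     inserted value are illustrative):

     local cell = stats.create_cell({2, 3})
     -- cell[i][j] is an empty table for i = 1..2, j = 1..3, ready to collect entries:
     table.insert(cell[2][3], "some entry")
]]--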
-- Run the model over all_data in mini-batches, calling func(data, gt_labels, output)
-- on every batch, and return the overall classification accuracy.
local function evaluate_through(model, all_data, all_labels, nBatch, func)
    model:evaluate()
    local nTrain = all_data:size(1)
    nBatch = nBatch or 128
    local accuracy = 0.0
    for i = 1, nTrain, nBatch do
        -- Clamp the final batch so a partial batch does not index out of range.
        local last = math.min(i + nBatch - 1, nTrain)
        local data = all_data[{{i, last}, {}}]
        local res = model:forward(data:cuda())
        local gt_labels = all_labels:sub(i, last):float()
        func(data, gt_labels, res)
        local best_score, best_idx = torch.max(res, 2)
        accuracy = accuracy + best_idx:float():eq(gt_labels):sum()
    end
    return accuracy / nTrain
end
-- Evaluate the model on a torchnet dataset and feed every batch to `collector`,
-- a table of optional callbacks: starter(model), collector(input, target),
-- finalizer(model), returner(), plus a needbackprop flag.  Stops after roughly
-- `maxload` samples and returns (accuracy, collector.returner() results).
function stats.torchnet_evaluator(model, dataset, collector, maxload)
    local n = 0
    local perm = transform.randperm(dataset:size())
    local accuracy = 0.0
    local crit = collector.needbackprop and nn.ClassNLLCriterion():cuda()
    if collector.needbackprop then model:training() else model:evaluate() end
    if collector.starter then collector.starter(model) end
    for sample in dataset:iterator{perm = perm} do
        local input_cuda = sample.input:cuda()
        local res = model:forward(input_cuda)
        if collector.needbackprop then
            -- Also backprop against the ground-truth labels so gradInput is populated.
            local target_cuda = torch.squeeze(sample.target):cuda()
            crit:forward(res, target_cuda)
            model:zeroGradParameters()
            crit:backward(res, target_cuda)
            model:backward(input_cuda, crit.gradInput)
        end
        if collector.collector then collector.collector(input_cuda, sample.target) end
        local best_score, best_idx = torch.max(res, 2)
        accuracy = accuracy + best_idx:double():eq(sample.target):sum()
        n = n + input_cuda:size(1)
        if maxload and n >= maxload then break end
    end
    if collector.finalizer then collector.finalizer(model) end
    return accuracy / n, collector.returner()
end
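--[[ A minimal usage sketch with a hand-rolled collector spec (all names below,
     e.g. `model` and `test_set`, are hypothetical and not defined in this file):

     local n_batches = 0
     local spec = {
         starter   = function (m) n_batches = 0 end,
         collector = function (input, target) n_batches = n_batches + 1 end,
         returner  = function () return n_batches end,
     }
     local acc, seen_batches = stats.torchnet_evaluator(model, test_set, spec, 10000)
]]--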
-- Count, per class, how often each ReLU unit is inactive (< 1e-4) or active
-- (>= 1e-4).  node_stats[layer] is accumulated as a (nClass, 2, unit dims...)
-- tensor and permuted to (unit dims..., nClass, 2) by the finalizer.
function stats.node_stats_collector(nClass)
    local node_stats = {}
    local relu_layers = {}
    local collector = function (batch_input, batch_target)
        -- Gather per-layer activation statistics for this batch.
        local n_batch = batch_input:size(1)
        for j = 1, #relu_layers do
            local output = relu_layers[j].output:clone()
            local dim = torch.totable(output[1]:size())
            if not node_stats[j] then
                node_stats[j] = torch.zeros(nClass, 2, unpack(dim))
            end
            local high = output:ge(1e-4):double()
            local low = output:lt(1e-4):double()
            for k = 1, n_batch do
                local t = batch_target[k]
                if type(t) ~= 'number' then t = torch.squeeze(t) end
                local s = node_stats[j][t]
                assert(s, string.format("stats.node_stats_collector out of bound. j = %d, batch_target[%d] = %d", j, k, t))
                s[1]:add(low[k])
                s[2]:add(high[k])
            end
        end
    end
    -- local accuracy = evaluate_through(model, all_data, all_labels, nBatch, collector)
    -- return node_stats, relu_layers, accuracy
    return {
        returner = function () return node_stats, relu_layers end,
        starter = function (model) relu_layers = stats.get_relus(model) end,
        collector = collector,
        finalizer = function () permute_stats(node_stats) end
    }
end
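--[[ A minimal usage sketch (a 10-class problem is assumed; `model` and
     `test_set` are hypothetical names, not defined in this file):

     local spec = stats.node_stats_collector(10)
     local acc, node_stats, relu_layers =
         stats.torchnet_evaluator(model, test_set, spec, 10000)
     -- After the finalizer, node_stats[j] has shape (unit dims..., 10, 2):
     -- per-class counts of how often each unit of ReLU layer j was inactive (1)
     -- or active (2).
]]--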
-- Count, per class, how often the gradient flowing into each ReLU unit is
-- negative (< -1e-4) or positive (> 1e-4).  Requires backprop, hence
-- needbackprop = true; the output layout matches stats.node_stats_collector.
function stats.node_grad_corr_collector(nClass)
    local node_stats = {}
    local relu_layers = {}
    local collector = function (batch_input, batch_target)
        -- Gather per-layer gradient statistics for this batch.
        local n_batch = batch_input:size(1)
        for j = 1, #relu_layers do
            local gradInput = relu_layers[j].gradInput:clone()
            local dim = torch.totable(gradInput[1]:size())
            if not node_stats[j] then
                node_stats[j] = torch.zeros(nClass, 2, unpack(dim))
            end
            local pos = gradInput:gt(1e-4):double()
            local neg = gradInput:lt(-1e-4):double()
            for k = 1, n_batch do
                local t = batch_target[k]
                if type(t) ~= 'number' then t = torch.squeeze(t) end
                local s = node_stats[j][t]
                assert(s, string.format("stats.node_grad_corr_collector out of bound. j = %d, batch_target[%d] = %d", j, k, t))
                s[1]:add(neg[k])
                s[2]:add(pos[k])
            end
        end
    end
    -- local accuracy = evaluate_through(model, all_data, all_labels, nBatch, collector)
    -- return node_stats, relu_layers, accuracy
    return {
        returner = function () return node_stats, relu_layers end,
        starter = function (model) relu_layers = stats.get_relus(model) end,
        collector = collector,
        needbackprop = true,
        finalizer = function () permute_stats(node_stats) end
    }
end
--[[
local s = node_stats[j][l][t]
if bConv then
local gv = gradInput[k][l]:view(-1)
for kk = 1, gv:size(1) do
if gv[kk] < -1e-4 then
table.insert(s[1], img_counter)
elseif gv[kk] > 1e-4 then
table.insert(s[2], img_counter)
end
end
else
local gv = gradInput[k][l]
if gv < -1e-4 then
table.insert(s[1], img_counter)
elseif gv > 1e-4 then
table.insert(s[2], img_counter)
end
end
]]--
-- Record, for every image, the aggregated response of each ReLU channel.
-- Convolutional outputs are reduced spatially by 'sum' or 'max' (collection_type);
-- fully-connected outputs are kept as-is.
function stats.node_resp_image_collector(nImage, collection_type)
    local node_stats = {}
    local relu_layers = {}
    local batch_inputs = {}
    local batch_targets = {}
    local collector = function (batch_input, batch_target)
        -- Gather per-layer responses for this batch.
        local n_batch = batch_input:size(1)
        -- Images seen in previous batches; assumes all batches have the same size.
        local baseaddr = #batch_inputs * batch_input:size(1)
        for j = 1, #relu_layers do
            local output = relu_layers[j].output
            local outputAgg
            if output[1]:nDimension() == 3 then
                -- Convolutional layer: reduce the spatial dimensions.
                if collection_type == 'sum' then
                    outputAgg = output:sum(3):sum(4)
                elseif collection_type == 'max' then
                    local agg1, _ = output:max(3)
                    outputAgg, _ = agg1:max(4)
                else
                    error(string.format("collection_type = %s is not defined.", collection_type))
                end
            else
                outputAgg = output:clone()
            end
            -- Lazily allocate the (nImage, nChannel) response matrix for this layer.
            local nChannel = output[1]:size(1)
            if not node_stats[j] then
                node_stats[j] = torch.FloatTensor(nImage, nChannel):zero()
            end
            assert(baseaddr + 1 <= nImage, string.format("baseaddr + 1 = %d is out of bound (%d)", baseaddr + 1, nImage))
            assert(baseaddr + n_batch <= nImage, string.format("baseaddr + n_batch = %d is out of bound (%d)", baseaddr + n_batch, nImage))
            node_stats[j]:sub(baseaddr + 1, baseaddr + n_batch):copy(outputAgg)
        end
        table.insert(batch_inputs, batch_input:float())
        table.insert(batch_targets, batch_target:float())
    end
    -- local accuracy = evaluate_through(model, all_data, all_labels, nBatch, collector)
    -- return node_stats, relu_layers, accuracy
    return {
        returner = function () return node_stats, relu_layers, batch_inputs, batch_targets end,
        starter = function (model) relu_layers = stats.get_relus(model) end,
        collector = collector,
        finalizer = function ()
            -- Flip each response matrix to (nChannel, nImage).
            for i = 1, #node_stats do
                node_stats[i] = node_stats[i]:transpose(1, 2)
            end
        end
    }
end
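--[[ A minimal usage sketch (hypothetical `model` / `test_set`; nImage must be at
     least the number of images the evaluator will actually process):

     local spec = stats.node_resp_image_collector(10000, 'max')
     local acc, node_stats, relu_layers, inputs, targets =
         stats.torchnet_evaluator(model, test_set, spec, 10000)
     -- node_stats[j] is (nChannel, nImage): the spatial-max response of every
     -- channel of ReLU layer j on every image seen.
]]--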
-- Record, for every image, the spatially-summed gradient flowing into each ReLU
-- channel.  Requires backprop (needbackprop = true); the output layout matches
-- stats.node_resp_image_collector.
function stats.node_grad_image_collector(nImage)
    local node_stats = {}
    local relu_layers = {}
    local batch_inputs = {}
    local batch_targets = {}
    local collector = function (batch_input, batch_target)
        -- Gather per-layer gradients for this batch.
        local n_batch = batch_input:size(1)
        -- Images seen in previous batches; assumes all batches have the same size.
        local baseaddr = #batch_inputs * batch_input:size(1)
        for j = 1, #relu_layers do
            local gradInput = relu_layers[j].gradInput
            local gradInputAgg
            if gradInput[1]:nDimension() == 3 then
                -- Convolutional layer: sum the gradient over the spatial dimensions.
                gradInputAgg = gradInput:sum(3):sum(4)
            else
                gradInputAgg = gradInput:clone()
            end
            -- Lazily allocate the (nImage, nChannel) gradient matrix for this layer.
            local nChannel = gradInput[1]:size(1)
            if not node_stats[j] then
                node_stats[j] = torch.FloatTensor(nImage, nChannel):zero()
            end
            assert(baseaddr + 1 <= nImage, string.format("baseaddr + 1 = %d is out of bound (%d)", baseaddr + 1, nImage))
            assert(baseaddr + n_batch <= nImage, string.format("baseaddr + n_batch = %d is out of bound (%d)", baseaddr + n_batch, nImage))
            node_stats[j]:sub(baseaddr + 1, baseaddr + n_batch):copy(gradInputAgg)
        end
        table.insert(batch_inputs, batch_input:float())
        table.insert(batch_targets, batch_target:float())
    end
    -- local accuracy = evaluate_through(model, all_data, all_labels, nBatch, collector)
    -- return node_stats, relu_layers, accuracy
    return {
        returner = function () return node_stats, relu_layers, batch_inputs, batch_targets end,
        starter = function (model) relu_layers = stats.get_relus(model) end,
        collector = collector,
        needbackprop = true,
        finalizer = function ()
            -- Flip each gradient matrix to (nChannel, nImage).
            for i = 1, #node_stats do
                node_stats[i] = node_stats[i]:transpose(1, 2)
            end
        end
    }
end
-- For every fully-connected layer, count per class how often each (output, input)
-- weight sees the four on/off combinations of its input and output activations.
function stats.weight_stats(model, all_data, all_labels, nBatch, nClass)
    local fc_layers = stats.get_fcs(model)
    local weight_stats = {}
    for j = 1, #fc_layers do weight_stats[j] = {} end
    local collector = function (batch_input, batch_target)
        local nbatch = batch_input:size(1)
        for j = 1, #fc_layers do
            -- Get the layer's input and output, and accumulate their cross statistics.
            local prev_input = fc_layers[j][1] == 'input' and batch_input or fc_layers[j][1].output
            local curr = fc_layers[j][2]
            if not weight_stats[j].stats then
                weight_stats[j].stats = torch.zeros(nClass, 4, curr.weight:size(1), curr.weight:size(2))
            end
            local high_i = prev_input:ge(1e-4):double()
            local low_i = prev_input:lt(1e-4):double()
            local high_o = curr.output:ge(1e-4):double()
            local low_o = curr.output:lt(1e-4):double()
            for k = 1, nbatch do
                local s = weight_stats[j].stats[batch_target[k]]
                s[1]:add(torch.ger(low_o[k], low_i[k]))    -- output off, input off
                s[2]:add(torch.ger(low_o[k], high_i[k]))   -- output off, input on
                s[3]:add(torch.ger(high_o[k], low_i[k]))   -- output on,  input off
                s[4]:add(torch.ger(high_o[k], high_i[k]))  -- output on,  input on
            end
        end
    end
    local accuracy = evaluate_through(model, all_data, all_labels, nBatch, collector)
    -- Permute each stats tensor from (nClass, 4, nOut, nIn) to (nOut, nIn, 4, nClass).
    for j = 1, #fc_layers do
        weight_stats[j].stats = weight_stats[j].stats:permute(3, 4, 2, 1)
    end
    return weight_stats, fc_layers, accuracy
end
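--[[ A minimal usage sketch (hypothetical `model`, `images`, `labels`; a batch
     size of 128 and 10 classes are assumed for illustration):

     local w_stats, fc_layers, acc = stats.weight_stats(model, images, labels, 128, 10)
     -- w_stats[j].stats is (nOut, nIn, 4, 10): per-class counts of the four
     -- input/output on-off combinations for every weight of the j-th linear layer.
]]--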
return stats