|
|
|
|
|
local transform = require 'torchnet.transform' |
|
|
|
|
|
local stats = {} |
|
|
|
|
|
--- Collect, depth-first and in declaration order, every sub-module whose
-- torch type name is a key of `layer_names`.
-- A matching module is recorded and not descended into; a non-matching
-- module that owns a `modules` list is searched recursively.
-- Note that `model` itself is never tested, only its descendants.
-- @param model        container exposing a `modules` array
-- @param layer_names  set-like table: type name -> truthy value
-- @return array of the matching modules
local function get_layer(model, layer_names)
    local found = {}

    local function visit(container)
        local children = container.modules
        for idx = 1, #children do
            local child = children[idx]
            if layer_names[torch.typename(child)] then
                found[#found + 1] = child
            elseif child.modules then
                visit(child)
            end
        end
    end

    visit(model)
    return found
end
|
|
|
|
|
--- Return every ReLU module found under `model`.
-- Both the `nn.ReLU` and `cudnn.ReLU` type names are accepted; the search
-- itself is delegated to `get_layer`.
-- @param model  container exposing a `modules` array
-- @return array of ReLU modules in traversal order
function stats.get_relus(model)
    return get_layer(model, { ["nn.ReLU"] = true, ["cudnn.ReLU"] = true })
end
|
|
|
|
|
--- List the `nn.Linear` modules that are direct children of `model`,
-- each paired with whatever feeds it.
-- Only the top-level `model.modules` list is scanned (no recursion).
-- @param model  container exposing a `modules` array
-- @return array of `{previous, linear}` pairs, where `previous` is the
--         module just before the linear layer, or the string 'input'
--         when the linear layer is the first module
function stats.get_fcs(model)
    local children = model.modules
    local pairs_found = {}
    local previous = 'input'

    for idx = 1, #children do
        local current = children[idx]
        if torch.typename(current) == 'nn.Linear' then
            pairs_found[#pairs_found + 1] = { previous, current }
        end
        -- Whatever we just looked at feeds the next module in the list.
        previous = current
    end

    return pairs_found
end
|
|
|
|
|
--- Rotate the two leading dimensions of every tensor in the list to the end.
-- Each entry `stats[j]` is replaced by `stats[j]:permute(3, ..., n, 1, 2)`,
-- where n is that entry's number of dimensions. The list is modified in
-- place; nothing is returned.
-- @param stats  array of tensor-like objects providing `nDimension` and `permute`
local function permute_stats(stats)
    for idx = 1, #stats do
        local tensor = stats[idx]
        local rank = tensor:nDimension()

        -- Trailing dimensions first, then the original dimensions 1 and 2.
        local order = {}
        for d = 3, rank do
            order[#order + 1] = d
        end
        order[#order + 1] = 1
        order[#order + 1] = 2

        stats[idx] = tensor:permute(unpack(order))
    end
end
|
|
|
|
|
--- Build a nested table ("cell array") with the given extents.
-- For dims = {a, b, ...} the result `t` satisfies: `t[i][j]...` is an empty
-- table for every 1 <= i <= a, 1 <= j <= b, ...  With an empty `dims` the
-- result is a single empty table.
-- @param dims  array of extents, one per nesting level
-- @return the nested table
function stats.create_cell(dims)
    local depth_count = #dims

    local function build(level)
        local node = {}
        if level > depth_count then
            return node
        end

        local extent = dims[level]
        -- Index 1 is always materialised on every level, even when the
        -- requested extent is smaller than 1.
        if not (extent >= 1) then
            extent = 1
        end

        for k = 1, extent do
            node[k] = build(level + 1)
        end
        return node
    end

    return build(1)
end
|
|
|
|
|
--- Run `model` over a data set in mini-batches, hand every batch to a
-- callback and measure top-1 accuracy.
-- The model is switched to evaluation mode first. For each batch the slice
-- is moved to the GPU with `:cuda()`, forwarded, and `func` is called with
-- the (un-moved) batch data, the float labels and the model output.
-- @param model       module providing `evaluate` and `forward`
-- @param all_data    tensor whose first dimension indexes samples
-- @param all_labels  tensor of labels, sliced with `sub` along the samples
-- @param nBatch      batch size; defaults to 128 when nil
-- @param func        callback `func(batch_data, batch_labels, output)`
-- @return fraction of samples whose arg-max over dimension 2 of the output
--         equals the label
local function evaluate_through(model, all_data, all_labels, nBatch, func)
    model:evaluate()
    local nTrain = all_data:size(1)
    nBatch = nBatch or 128
    local accuracy = 0.0

    for first = 1, nTrain, nBatch do
        -- Clamp the final batch so a sample count that is not a multiple
        -- of nBatch does not index past the end of the data.
        local last = math.min(first + nBatch - 1, nTrain)

        local data = all_data[{{first, last}, {}}]
        local res = model:forward(data:cuda())
        local gt_labels = all_labels:sub(first, last):float()

        func(data, gt_labels, res)

        -- Keep the arg-max results local; they used to leak as globals.
        -- Assumes `res` is (batch, nClass) scores -- TODO confirm.
        local _, best_idx = torch.max(res, 2)
        accuracy = accuracy + best_idx:float():eq(gt_labels):sum()
    end

    return accuracy / nTrain
end
|
|
|
|
|
--- Drive a model over a torchnet dataset in random order, feeding a
-- statistics collector and measuring top-1 accuracy.
--
-- `collector` is a table of optional hooks:
--   * `starter(model)`            called once before iterating
--   * `collector(input, target)`  called per sample batch with the CUDA
--                                 input and the sample's target
--   * `finalizer(model)`          called once after iterating
--   * `returner()`                its results are appended to the return
--   * `needbackprop`              when truthy, the model is put in training
--                                 mode and a ClassNLLCriterion backward pass
--                                 is run for every batch before collecting
--
-- @param model      module providing forward/backward
-- @param dataset    torchnet dataset whose samples expose `input`/`target`
-- @param collector  hook table described above
-- @param maxload    stop once at least this many samples were processed;
--                   nil means "process the whole dataset"
-- @return accuracy over the processed samples, followed by whatever
--         `collector.returner()` returns (if present)
function stats.torchnet_evaluator(model, dataset, collector, maxload)
    local n = 0
    local perm = transform.randperm(dataset:size())
    local accuracy = 0.0

    local needbackprop = collector.needbackprop
    local crit = needbackprop and nn.ClassNLLCriterion():cuda()
    if needbackprop then model:training() else model:evaluate() end

    if collector.starter then collector.starter(model) end

    for sample in dataset:iterator{perm = perm} do
        local input_cuda = sample.input:cuda()
        local res = model:forward(input_cuda)

        if needbackprop then
            local target_cuda = torch.squeeze(sample.target):cuda()

            crit:forward(res, target_cuda)
            -- Clear stale gradients so each batch is measured in isolation.
            model:zeroGradParameters()
            crit:backward(res, target_cuda)
            model:backward(input_cuda, crit.gradInput)
        end

        if collector.collector then collector.collector(input_cuda, sample.target) end

        -- Locals on purpose: these used to leak into the global table.
        -- Assumes `res` is (batch, nClass) scores -- TODO confirm.
        local _, best_idx = torch.max(res, 2)
        accuracy = accuracy + best_idx:double():eq(sample.target):sum()
        n = n + input_cuda:size(1)

        -- A nil maxload means no cap (comparing against nil would raise).
        if maxload and n >= maxload then break end
    end

    if collector.finalizer then collector.finalizer(model) end

    if collector.returner then
        return accuracy / n, collector.returner()
    end
    return accuracy / n
end
|
|
|
|
|
--- Build a collector that counts, per class and per ReLU unit, how often
-- the unit is inactive or active.
--
-- For ReLU layer j the accumulator `node_stats[j]` is created on first use
-- as zeros of size (nClass, 2, <size of one sample's output>):
--   slot 1 counts samples with output <  1e-4,
--   slot 2 counts samples with output >= 1e-4.
-- The finalizer moves the (nClass, 2) dimensions to the end, so the
-- collector is meant for a single evaluation run.
--
-- @param nClass  number of classes (targets index the first dimension)
-- @return hook table with `starter`, `collector`, `finalizer` and
--         `returner` (which yields `node_stats, relu_layers`)
function stats.node_stats_collector(nClass)
    local node_stats = {}
    local relu_layers = {}

    local collector = function (batch_input, batch_target)
        local n_batch = batch_input:size(1)
        for j = 1, #relu_layers do
            -- The output is only read below (ge/lt return new mask
            -- tensors), so no defensive copy of the activation is needed.
            local output = relu_layers[j].output

            if not node_stats[j] then
                local dim = torch.totable(output[1]:size())
                node_stats[j] = torch.zeros(nClass, 2, unpack(dim))
            end

            local high = output:ge(1e-4):double()
            local low = output:lt(1e-4):double()

            for k = 1, n_batch do
                local t = batch_target[k]
                -- presumably torch.squeeze collapses a one-element tensor
                -- to a plain number here -- verify against torch docs
                if type(t) ~= 'number' then t = torch.squeeze(t) end

                local s = node_stats[j][t]
                if not s then
                    -- Format the message only on failure; level 0 keeps the
                    -- text identical to what assert() would have raised.
                    error(string.format("stats.node_stats_collector out of bound. j = %d, batch_target[%d] = %d", j, k, t), 0)
                end
                s[1]:add(low[k])
                s[2]:add(high[k])
            end
        end
    end

    return {
        returner = function () return node_stats, relu_layers end,
        starter = function (model) relu_layers = stats.get_relus(model) end,
        collector = collector,
        finalizer = function () permute_stats(node_stats) end
    }
end
|
|
|
|
|
--- Build a collector that counts, per class and per ReLU unit, the sign of
-- the gradient flowing into the unit.
--
-- For ReLU layer j the accumulator `node_stats[j]` is created on first use
-- as zeros of size (nClass, 2, <size of one sample's gradInput>):
--   slot 1 counts samples with gradInput < -1e-4,
--   slot 2 counts samples with gradInput >  1e-4.
-- The returned table sets `needbackprop = true`, so the evaluator is asked
-- to run a backward pass before each collection. The finalizer moves the
-- (nClass, 2) dimensions to the end, so the collector is meant for a single
-- evaluation run.
--
-- @param nClass  number of classes (targets index the first dimension)
-- @return hook table with `starter`, `collector`, `finalizer`, `returner`
--         (which yields `node_stats, relu_layers`) and `needbackprop`
function stats.node_grad_corr_collector(nClass)
    local node_stats = {}
    local relu_layers = {}

    local collector = function (batch_input, batch_target)
        local n_batch = batch_input:size(1)
        for j = 1, #relu_layers do
            -- The gradient is only read below (gt/lt return new mask
            -- tensors), so no defensive copy is needed.
            local gradInput = relu_layers[j].gradInput

            if not node_stats[j] then
                local dim = torch.totable(gradInput[1]:size())
                node_stats[j] = torch.zeros(nClass, 2, unpack(dim))
            end

            local pos = gradInput:gt(1e-4):double()
            local neg = gradInput:lt(-1e-4):double()

            for k = 1, n_batch do
                local t = batch_target[k]
                -- presumably torch.squeeze collapses a one-element tensor
                -- to a plain number here -- verify against torch docs
                if type(t) ~= 'number' then t = torch.squeeze(t) end

                local s = node_stats[j][t]
                if not s then
                    -- Format the message only on failure; level 0 keeps the
                    -- text identical to what assert() would have raised.
                    error(string.format("stats.node_grad_corr_collector out of bound. j = %d, batch_target[%d] = %d", j, k, t), 0)
                end
                s[1]:add(neg[k])
                s[2]:add(pos[k])
            end
        end
    end

    return {
        returner = function () return node_stats, relu_layers end,
        starter = function (model) relu_layers = stats.get_relus(model) end,
        collector = collector,
        needbackprop = true,
        finalizer = function () permute_stats(node_stats) end
    }
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
--- Build a collector that records, for every image and every ReLU channel,
-- an aggregated response, together with the inputs and targets seen.
--
-- For ReLU layer j a FloatTensor `node_stats[j]` of size (nImage, nChannel)
-- is allocated on first use; row r holds the responses of the r-th image
-- processed. When a single sample's output has 3 dimensions, dimensions 3
-- and 4 of the batch output are reduced according to `collection_type`
-- (presumably the spatial dimensions of a (batch, channel, H, W) map --
-- TODO confirm); otherwise the output is stored as is. The finalizer
-- transposes each matrix to (nChannel, nImage).
--
-- @param nImage           capacity: total number of images to be recorded
-- @param collection_type  'sum' or 'max'; only consulted for outputs whose
--                         per-sample dimensionality is 3
-- @return hook table with `starter`, `collector`, `finalizer` and `returner`
--         (yielding `node_stats, relu_layers, batch_inputs, batch_targets`)
function stats.node_resp_image_collector(nImage, collection_type)
    local node_stats = {}
    local relu_layers = {}
    local batch_inputs = {}
    local batch_targets = {}
    -- Number of images already written. A running count keeps the row
    -- offset right even when batches differ in size (e.g. a short final
    -- batch); multiplying the batch count by the current batch size does not.
    local n_seen = 0

    local collector = function (batch_input, batch_target)
        local n_batch = batch_input:size(1)
        local baseaddr = n_seen

        for j = 1, #relu_layers do
            local output = relu_layers[j].output
            local outputAgg
            if output[1]:nDimension() == 3 then
                if collection_type == 'sum' then
                    outputAgg = output:sum(3):sum(4)
                elseif collection_type == 'max' then
                    -- max returns (values, indices); only the values matter.
                    local agg1 = output:max(3)
                    outputAgg = (agg1:max(4))
                else
                    error(string.format("collection_type = %s is not defined.", collection_type))
                end
            else
                -- Only read by the copy below, so no clone is required.
                outputAgg = output
            end

            local nChannel = output[1]:size(1)
            if not node_stats[j] then
                node_stats[j] = torch.FloatTensor(nImage, nChannel):zero()
            end

            assert(baseaddr + 1 <= nImage, string.format("baseaddr + 1 = %d is out of bound (%d)", baseaddr + 1, nImage))
            assert(baseaddr + n_batch <= nImage, string.format("baseaddr + n_batch = %d is out of bound (%d)", baseaddr + n_batch, nImage))

            node_stats[j]:sub(baseaddr + 1, baseaddr + n_batch):copy(outputAgg)
        end

        table.insert(batch_inputs, batch_input:float())
        table.insert(batch_targets, batch_target:float())
        n_seen = n_seen + n_batch
    end

    return {
        returner = function () return node_stats, relu_layers, batch_inputs, batch_targets end,
        starter = function (model) relu_layers = stats.get_relus(model) end,
        collector = collector,
        finalizer = function ()
            for i = 1, #node_stats do
                node_stats[i] = node_stats[i]:transpose(1, 2)
            end
        end
    }
end
|
|
|
|
|
--- Build a collector that records, for every image and every ReLU channel,
-- the (summed) gradient flowing into the unit, together with the inputs
-- and targets seen.
--
-- For ReLU layer j a FloatTensor `node_stats[j]` of size (nImage, nChannel)
-- is allocated on first use; row r holds the values of the r-th image
-- processed. When a single sample's gradInput has 3 dimensions, dimensions
-- 3 and 4 of the batch gradient are summed (presumably the spatial
-- dimensions of a (batch, channel, H, W) map -- TODO confirm); otherwise
-- the gradient is stored as is. The returned table sets
-- `needbackprop = true`. The finalizer transposes each matrix to
-- (nChannel, nImage).
--
-- @param nImage  capacity: total number of images to be recorded
-- @return hook table with `starter`, `collector`, `finalizer`, `returner`
--         (yielding `node_stats, relu_layers, batch_inputs, batch_targets`)
--         and `needbackprop`
function stats.node_grad_image_collector(nImage)
    local node_stats = {}
    local relu_layers = {}
    local batch_inputs = {}
    local batch_targets = {}
    -- Number of images already written. A running count keeps the row
    -- offset right even when batches differ in size (e.g. a short final
    -- batch); multiplying the batch count by the current batch size does not.
    local n_seen = 0

    local collector = function (batch_input, batch_target)
        local n_batch = batch_input:size(1)
        local baseaddr = n_seen

        for j = 1, #relu_layers do
            local gradInput = relu_layers[j].gradInput
            local gradInputAgg
            if gradInput[1]:nDimension() == 3 then
                gradInputAgg = gradInput:sum(3):sum(4)
            else
                -- Only read by the copy below, so no clone is required.
                gradInputAgg = gradInput
            end

            local nChannel = gradInput[1]:size(1)
            if not node_stats[j] then
                node_stats[j] = torch.FloatTensor(nImage, nChannel):zero()
            end

            assert(baseaddr + 1 <= nImage, string.format("baseaddr + 1 = %d is out of bound (%d)", baseaddr + 1, nImage))
            assert(baseaddr + n_batch <= nImage, string.format("baseaddr + n_batch = %d is out of bound (%d)", baseaddr + n_batch, nImage))

            node_stats[j]:sub(baseaddr + 1, baseaddr + n_batch):copy(gradInputAgg)
        end

        table.insert(batch_inputs, batch_input:float())
        table.insert(batch_targets, batch_target:float())
        n_seen = n_seen + n_batch
    end

    return {
        returner = function () return node_stats, relu_layers, batch_inputs, batch_targets end,
        starter = function (model) relu_layers = stats.get_relus(model) end,
        collector = collector,
        needbackprop = true,
        finalizer = function ()
            for i = 1, #node_stats do
                node_stats[i] = node_stats[i]:transpose(1, 2)
            end
        end
    }
end
|
|
|
|
|
--- Gather, per class, joint on/off statistics between the inputs and
-- outputs of every top-level `nn.Linear` layer of `model`.
--
-- The data set is pushed through the model batch by batch (via
-- `evaluate_through`). For linear layer j and every sample of class c, four
-- outer products of 0/1 masks are accumulated into an
-- (nClass, 4, weight:size(1), weight:size(2)) tensor, using the threshold
-- 1e-4 to separate "low" (< 1e-4) from "high" (>= 1e-4):
--   slot 1: output low  x input low
--   slot 2: output low  x input high
--   slot 3: output high x input low
--   slot 4: output high x input high
-- Before returning, each tensor is permuted with (3, 4, 2, 1), i.e. to
-- (weight:size(1), weight:size(2), 4, nClass).
--
-- @param model       model whose direct children are inspected
-- @param all_data    tensor of samples (first dimension indexes samples)
-- @param all_labels  tensor of class labels
-- @param nBatch      batch size forwarded to `evaluate_through`
-- @param nClass      number of classes
-- @return weight_stats (array of `{stats = tensor}`), the list of
--         `{previous, linear}` layer pairs, and the accuracy reported by
--         `evaluate_through`
function stats.weight_stats(model, all_data, all_labels, nBatch, nClass)
    local fc_layers = stats.get_fcs(model)

    local weight_stats = {}
    for j = 1, #fc_layers do
        weight_stats[j] = {}
    end

    -- Per-batch callback: update the accumulators of every linear layer.
    local function accumulate(batch_input, batch_target)
        local batch_size = batch_input:size(1)

        for j = 1, #fc_layers do
            local feeder = fc_layers[j][1]
            local linear = fc_layers[j][2]

            -- The first layer is fed by the raw batch, any other layer by
            -- the output of the module in front of it.
            local prev_input
            if feeder == 'input' then
                prev_input = batch_input
            else
                prev_input = feeder.output
            end

            local entry = weight_stats[j]
            if not entry.stats then
                local n_out = linear.weight:size(1)
                local n_in = linear.weight:size(2)
                entry.stats = torch.zeros(nClass, 4, n_out, n_in)
            end

            local in_high = prev_input:ge(1e-4):double()
            local in_low = prev_input:lt(1e-4):double()
            local out_high = linear.output:ge(1e-4):double()
            local out_low = linear.output:lt(1e-4):double()

            for k = 1, batch_size do
                local per_class = entry.stats[batch_target[k]]
                per_class[1]:add(torch.ger(out_low[k], in_low[k]))
                per_class[2]:add(torch.ger(out_low[k], in_high[k]))
                per_class[3]:add(torch.ger(out_high[k], in_low[k]))
                per_class[4]:add(torch.ger(out_high[k], in_high[k]))
            end
        end
    end

    local accuracy = evaluate_through(model, all_data, all_labels, nBatch, accumulate)

    for j = 1, #fc_layers do
        local entry = weight_stats[j]
        entry.stats = entry.stats:permute(3, 4, 2, 1)
    end

    return weight_stats, fc_layers, accuracy
end
|
|
|
|
|
return stats |