File size: 13,319 Bytes
e9fe176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
-- Compute network statistics (per-node activation/gradient collectors for Torch models).
local transform = require 'torchnet.transform'

local stats = {}

-- Recursively collect every module in `model` whose torch typename is a key
-- in the set `layer_names` (typename string -> true). A container whose own
-- typename matches is returned itself and NOT descended into; otherwise its
-- children are searched depth-first, so results come back in model order.
-- @param model       an nn container exposing a `.modules` array
-- @param layer_names set of typenames to match
-- @return array of matching module references
local function get_layer(model, layer_names)
    local res = {}

    for i = 1, #model.modules do
        local m = model.modules[i]

        if layer_names[torch.typename(m)] then
            res[#res + 1] = m
        elseif m.modules then
            -- Recurse into sub-containers and splice their matches in order.
            for _, mm in ipairs(get_layer(m, layer_names)) do
                res[#res + 1] = mm
            end
        end
    end
    return res
end

-- Return every ReLU module in `model` (both the plain nn and the cudnn
-- flavour), searching nested containers recursively.
function stats.get_relus(model)
    local relu_types = {}
    relu_types["nn.ReLU"] = true
    relu_types["cudnn.ReLU"] = true
    return get_layer(model, relu_types)
end

-- Scan the top level of `model` for nn.Linear modules, pairing each one
-- with the module directly preceding it (or the string 'input' when the
-- Linear layer is first). Returns an array of {previous, linear} pairs.
function stats.get_fcs(model)
    local found = {}
    local mods = model.modules
    for idx = 1, #mods do
        if torch.typename(mods[idx]) == 'nn.Linear' then
            local before
            if idx == 1 then
                before = 'input'
            else
                before = mods[idx - 1]
            end
            found[#found + 1] = { before, mods[idx] }
        end
    end
    return found
end

-- In place, rotate the first two dimensions of every tensor in `stats` to
-- the back: (d1, d2, d3, ..., dk) becomes (d3, ..., dk, d1, d2).
local function permute_stats(stats)
    for idx = 1, #stats do
        local order = {}
        local ndim = stats[idx]:nDimension()
        for d = 3, ndim do
            order[#order + 1] = d
        end
        order[#order + 1] = 1
        order[#order + 1] = 2

        stats[idx] = stats[idx]:permute(unpack(order))
    end
end

-- Build a #dims-deep nested table of cells: the chain t[i1][i2]...[ik] is
-- materialized (as an initially empty table at each level) for every index
-- tuple enumerated odometer-style, with the last dimension varying fastest.
function stats.create_cell(dims)
    local ndim = #dims
    local idx = {}
    for d = 1, ndim do
        idx[d] = 1
    end

    local root = {}
    local finished = false
    repeat
        -- Materialize the chain of tables for the current index tuple.
        local node = root
        for d = 1, ndim do
            local k = idx[d]
            if node[k] == nil then node[k] = {} end
            node = node[k]
        end

        -- Odometer increment: bump the last dimension, carrying leftward;
        -- when every digit wraps we have enumerated all tuples.
        finished = true
        for d = ndim, 1, -1 do
            idx[d] = idx[d] + 1
            if idx[d] <= dims[d] then
                finished = false
                break
            end
            idx[d] = 1
        end
    until finished

    return root
end

-- Run `model` (in evaluate mode) over `all_data` in mini-batches, invoking
-- `func(data, gt_labels, output)` on every batch, and return the top-1
-- accuracy over the whole set.
-- @param model      network with :evaluate()/:forward(); inputs are moved to GPU
-- @param all_data   tensor whose first dimension indexes samples
-- @param all_labels tensor of ground-truth class indices, one per sample
-- @param nBatch     mini-batch size (default 128)
-- @param func       per-batch callback: (cpu batch, float labels, net output)
-- @return fraction of samples whose argmax prediction equals the label
local function evaluate_through(model, all_data, all_labels, nBatch, func)
    model:evaluate()
    local nTrain = all_data:size(1)
    nBatch = nBatch or 128
    local accuracy = 0.0

    -- Statistics
    for i = 1, nTrain, nBatch do
        -- BUG FIX: clamp the batch end so a final partial batch stays in
        -- bounds (previously indexed past nTrain when nBatch did not divide it).
        local last = math.min(i + nBatch - 1, nTrain)
        local data = all_data[{{i, last}, {}}]
        local res = model:forward(data:cuda())
        local gt_labels = all_labels:sub(i, last):float()

        func(data, gt_labels, res)

        -- BUG FIX: these were accidental globals; declare them local.
        local _, best_idx = torch.max(res, 2)
        accuracy = accuracy + best_idx:float():eq(gt_labels):sum()
    end
    return accuracy / nTrain
end

-- Drive `model` over a (shuffled) torchnet dataset while feeding a collector
-- object. The collector may provide: `starter(model)` called once before
-- iteration, `collector(input, target)` called per batch, `finalizer(model)`
-- called once after, `returner()` whose results are returned alongside the
-- accuracy, and `needbackprop` (truthy to also backprop an NLL loss so
-- gradInput fields are populated; puts the model in training mode).
-- @param maxload stop after at least this many samples (nil = whole dataset)
-- @return top-1 accuracy over the samples seen, followed by returner() results
function stats.torchnet_evaluator(model, dataset, collector, maxload)
    local n = 0
    local perm = transform.randperm(dataset:size())
    local accuracy = 0.0

    -- The criterion is only needed when the collector inspects gradients.
    local crit = collector.needbackprop and nn.ClassNLLCriterion():cuda()
    if collector.needbackprop then model:training() else model:evaluate() end

    if collector.starter then collector.starter(model) end

    for sample in dataset:iterator{perm = perm} do
        local input_cuda = sample.input:cuda()
        local res = model:forward(input_cuda)

        if collector.needbackprop then
            -- Also backprop with ground truth data so each module's
            -- gradInput reflects the true-label loss.
            local target_cuda = torch.squeeze(sample.target):cuda()

            crit:forward(res, target_cuda)
            model:zeroGradParameters()
            crit:backward(res, target_cuda)
            model:backward(input_cuda, crit.gradInput)
        end

        if collector.collector then collector.collector(input_cuda, sample.target) end

        -- BUG FIX: best_score/best_idx were accidental globals; make them local.
        local _, best_idx = torch.max(res, 2)
        accuracy = accuracy + best_idx:double():eq(sample.target):sum()
        n = n + input_cuda:size(1)
        -- BUG FIX: guard against nil maxload (previously `n >= nil` raised).
        if maxload and n >= maxload then break end
    end

    if collector.finalizer then collector.finalizer(model) end

    return accuracy / n, collector.returner()
end

-- Build a collector (for stats.torchnet_evaluator) that counts, per class
-- and per ReLU node, how often the node's activation is "off" (< 1e-4) vs
-- "on" (>= 1e-4).
-- Layout while collecting: node_stats[layer] is a (nClass, 2, node-dims...)
-- tensor where index 1 of the second dim counts low activations and index 2
-- counts high ones. The finalizer permutes each tensor to
-- (node-dims..., nClass, 2) via permute_stats.
-- returner() yields (node_stats, relu_layers).
function stats.node_stats_collector(nClass)
    -- layer index -> per-class low/high activation counts.
    local node_stats = {}
    local relu_layers = {}

    local collector = function (batch_input, batch_target)
        -- Accumulate activation sign statistics for every ReLU layer.
        local n_batch = batch_input:size(1)
        for j = 1, #relu_layers do
            local output = relu_layers[j].output:clone()
            -- Per-sample output shape; lazily sizes the stats tensor on
            -- first sight of this layer's output.
            local dim = torch.totable(output[1]:size())
            if not node_stats[j] then
                node_stats[j] = torch.zeros(nClass, 2, unpack(dim))
            end

            -- 0/1 masks: activation above / below the 1e-4 threshold.
            local high = output:ge(1e-4):double()
            local low = output:lt(1e-4):double()

            for k = 1, n_batch do
                local t = batch_target[k]
                -- Targets may arrive as 1-element tensors; reduce to a number.
                if type(t) ~= 'number' then t = torch.squeeze(t) end
                local s = node_stats[j][t]
                assert(s, string.format("stats.node_stats_collector out of bound. j = %d, batch_target[%d] = %d", j, k, t))
                s[1]:add(low[k])
                s[2]:add(high[k])
            end
        end
    end

    return {
        returner = function () return node_stats, relu_layers end,
        starter = function (model) relu_layers = stats.get_relus(model) end,
        collector = collector,
        finalizer = function () permute_stats(node_stats) end
    }
end

-- Build a collector that counts, per class and per ReLU node, how often the
-- node's backpropagated gradient (gradInput) is negative (< -1e-4) vs
-- positive (> 1e-4); near-zero gradients are counted in neither bin.
-- Sets needbackprop so torchnet_evaluator populates gradInput fields.
-- Layout while collecting: node_stats[layer] is (nClass, 2, node-dims...)
-- with second-dim index 1 = negative count, 2 = positive count; the
-- finalizer permutes each tensor to (node-dims..., nClass, 2).
-- returner() yields (node_stats, relu_layers).
function stats.node_grad_corr_collector(nClass)
    local node_stats = {}
    local relu_layers = {}

    local collector = function (batch_input, batch_target)
        -- Accumulate gradient sign statistics for every ReLU layer.
        local n_batch = batch_input:size(1)
        for j = 1, #relu_layers do
            local gradInput = relu_layers[j].gradInput:clone()

            -- Per-sample gradient shape; lazily sizes the stats tensor.
            local dim = torch.totable(gradInput[1]:size())
            if not node_stats[j] then
                node_stats[j] = torch.zeros(nClass, 2, unpack(dim))
            end

            -- 0/1 masks for clearly positive / clearly negative gradients.
            local pos = gradInput:gt(1e-4):double()
            local neg = gradInput:lt(-1e-4):double()

            for k = 1, n_batch do
                local t = batch_target[k]
                -- Targets may arrive as 1-element tensors; reduce to a number.
                if type(t) ~= 'number' then t = torch.squeeze(t) end
                local s = node_stats[j][t]
                assert(s, string.format("stats.node_grad_corr_collector out of bound. j = %d, batch_target[%d] = %d", j, k, t))
                s[1]:add(neg[k])
                s[2]:add(pos[k])
            end
        end
    end

    return {
        returner = function () return node_stats, relu_layers end,
        starter = function (model) relu_layers = stats.get_relus(model) end,
        collector = collector,
        needbackprop = true,
        finalizer = function () permute_stats(node_stats) end
    }
end

--[[
local s = node_stats[j][l][t]
if bConv then
    local gv = gradInput[k][l]:view(-1)
    for kk = 1, gv:size(1) do
        if gv[kk] < -1e-4 then
            table.insert(s[1], img_counter)
        elseif gv[kk] > 1e-4 then
            table.insert(s[2], img_counter)
        end
    end
else
    local gv = gradInput[k][l]
    if gv < -1e-4 then
        table.insert(s[1], img_counter)
    elseif gv > 1e-4 then
        table.insert(s[2], img_counter)
    end
end
]]--

-- Build a collector recording, per image and per ReLU channel, an aggregate
-- response: spatial 'sum' or 'max' over each channel for conv-style (3-D
-- per-sample) outputs, or the raw activation for flat outputs.
-- @param nImage total number of images that will be seen (sizes the buffers)
-- @param collection_type 'sum' or 'max'; anything else raises an error
-- returner() yields (node_stats, relu_layers, batch_inputs, batch_targets),
-- where node_stats[layer] is (nChannel, nImage) after the finalizer's
-- transpose, and batch_inputs/batch_targets keep float copies of every batch.
function stats.node_resp_image_collector(nImage, collection_type)
    local node_stats = {}
    local relu_layers = {}
    local batch_inputs = {}
    local batch_targets = {}

    local collector = function (batch_input, batch_target)
        -- Row offset for this batch in the per-image buffers.
        -- NOTE(review): assumes every batch has the same size as the current
        -- one — a smaller final batch would misplace rows; confirm upstream
        -- batching never yields partial batches.
        local n_batch = batch_input:size(1)
        local baseaddr = #batch_inputs * batch_input:size(1)

        for j = 1, #relu_layers do
            local output = relu_layers[j].output
            local outputAgg
            if output[1]:nDimension() == 3 then
                -- Conv feature map: collapse the two spatial dims per channel.
                if collection_type == 'sum' then
                    outputAgg = output:sum(3):sum(4)
                elseif collection_type == 'max' then
                    local agg1, _ = output:max(3)
                    outputAgg, _ = agg1:max(4)
                else
                    error(string.format("collection_type = %s is not defined.", collection_type))
                end
            else
                outputAgg = output:clone()
            end

            -- Lazily allocate the (nImage, nChannel) response buffer.
            local nChannel = output[1]:size(1)
            if not node_stats[j] then
                node_stats[j] = torch.FloatTensor(nImage, nChannel):zero()
            end

            assert(baseaddr + 1 <= nImage, string.format("baseaddr + 1 = %d is out of bound (%d)", baseaddr + 1, nImage))
            assert(baseaddr + n_batch <= nImage, string.format("baseaddr + n_batch = %d is out of bound (%d)", baseaddr + n_batch, nImage))

            node_stats[j]:sub(baseaddr + 1, baseaddr + n_batch):copy(outputAgg)
        end
        table.insert(batch_inputs, batch_input:float())
        table.insert(batch_targets, batch_target:float())
    end

    return {
        returner = function () return node_stats, relu_layers, batch_inputs, batch_targets end,
        starter = function (model) relu_layers = stats.get_relus(model) end,
        collector = collector,
        finalizer = function ()
            -- Flip each buffer to (nChannel, nImage) for consumers.
            for i = 1, #node_stats do
                node_stats[i] = node_stats[i]:transpose(1, 2)
            end
        end
    }
end

-- Build a collector recording, per image and per ReLU channel, the spatially
-- summed backpropagated gradient (gradInput); flat layers store the raw
-- gradient. Sets needbackprop so gradients are populated by the evaluator.
-- @param nImage total number of images that will be seen (sizes the buffers)
-- returner() yields (node_stats, relu_layers, batch_inputs, batch_targets),
-- where node_stats[layer] is (nChannel, nImage) after the finalizer's
-- transpose, and batch_inputs/batch_targets keep float copies of every batch.
function stats.node_grad_image_collector(nImage)
    local node_stats = {}
    local relu_layers = {}
    local batch_inputs = {}
    local batch_targets = {}

    local collector = function (batch_input, batch_target)
        -- Row offset for this batch in the per-image buffers.
        -- NOTE(review): assumes a constant batch size — a smaller final
        -- batch would misplace rows; confirm upstream batching.
        local n_batch = batch_input:size(1)
        local baseaddr = #batch_inputs * batch_input:size(1)

        for j = 1, #relu_layers do
            local gradInput = relu_layers[j].gradInput
            local gradInputAgg
            if gradInput[1]:nDimension() == 3 then
                -- Conv feature map: collapse the two spatial dims per channel.
                gradInputAgg = gradInput:sum(3):sum(4)
            else
                gradInputAgg = gradInput:clone()
            end

            -- Lazily allocate the (nImage, nChannel) gradient buffer.
            local nChannel = gradInput[1]:size(1)
            if not node_stats[j] then
                node_stats[j] = torch.FloatTensor(nImage, nChannel):zero()
            end

            assert(baseaddr + 1 <= nImage, string.format("baseaddr + 1 = %d is out of bound (%d)", baseaddr + 1, nImage))
            assert(baseaddr + n_batch <= nImage, string.format("baseaddr + n_batch = %d is out of bound (%d)", baseaddr + n_batch, nImage))

            node_stats[j]:sub(baseaddr + 1, baseaddr + n_batch):copy(gradInputAgg)
        end
        table.insert(batch_inputs, batch_input:float())
        table.insert(batch_targets, batch_target:float())
    end

    return {
        returner = function () return node_stats, relu_layers, batch_inputs, batch_targets end,
        starter = function (model) relu_layers = stats.get_relus(model) end,
        collector = collector,
        needbackprop = true,
        finalizer = function ()
            -- Flip each buffer to (nChannel, nImage) for consumers.
            for i = 1, #node_stats do
                node_stats[i] = node_stats[i]:transpose(1, 2)
            end
        end
    }
end

-- For every nn.Linear layer, accumulate per-class co-activation counts over
-- (output node, input node) pairs while evaluating `model` on the dataset:
-- bin 1 = (out low, in low), 2 = (out low, in high), 3 = (out high, in low),
-- 4 = (out high, in high), where low/high is thresholded at 1e-4.
-- @return weight_stats array (weight_stats[j].stats permuted to
--         (nOut, nIn, 4, nClass)), the fc_layers pairs, and the accuracy
--         reported by evaluate_through.
-- NOTE(review): assumes batch_target[k] is a plain number usable as an index
-- (no tensor squeeze here, unlike the node collectors) — confirm callers.
function stats.weight_stats(model, all_data, all_labels, nBatch, nClass)
    local fc_layers = stats.get_fcs(model)

    -- One stats slot per Linear layer; the tensor is allocated lazily.
    local weight_stats = {}
    for j = 1, #fc_layers do weight_stats[j] = {} end

    local collector = function (batch_input, batch_target)
        -- Accumulate input/output co-activation counts per Linear layer.
        local nbatch = batch_input:size(1)
        for j = 1, #fc_layers do
            -- The layer's input is either the raw batch (first layer) or the
            -- previous module's output captured during forward.
            local prev_input = fc_layers[j][1] == 'input' and batch_input or fc_layers[j][1].output
            local curr = fc_layers[j][2]

            if not weight_stats[j].stats then
                weight_stats[j].stats = torch.zeros(nClass, 4, curr.weight:size(1), curr.weight:size(2))
            end

            -- 0/1 masks for input and output activity at the 1e-4 threshold.
            local high_i = prev_input:ge(1e-4):double()
            local low_i = prev_input:lt(1e-4):double()
            local high_o = curr.output:ge(1e-4):double()
            local low_o = curr.output:lt(1e-4):double()

            for k = 1, nbatch do
                local s = weight_stats[j].stats[batch_target[k]]
                -- Outer products count every (output, input) activity pairing.
                s[1]:add(torch.ger(low_o[k], low_i[k]))
                s[2]:add(torch.ger(low_o[k], high_i[k]))
                s[3]:add(torch.ger(high_o[k], low_i[k]))
                s[4]:add(torch.ger(high_o[k], high_i[k]))
            end
        end
    end

    local accuracy = evaluate_through(model, all_data, all_labels, nBatch, collector)
    -- Reorder to (nOut, nIn, bin, class) for downstream consumption.
    for j = 1, #fc_layers do
        weight_stats[j].stats = weight_stats[j].stats:permute(3, 4, 2, 1)
    end
    return weight_stats, fc_layers, accuracy
end

return stats