File size: 2,082 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import unittest

import torch

from gat.models.surrogate import build_surrogate, list_surrogates
from gat.models.surrogate.hooks import feat_col


class TestTVModels(unittest.TestCase):
    """Output-shape smoke tests for every surrogate listed by ``list_surrogates``."""

    # Surrogate ids excluded from the output-shape check.
    # NOTE(review): inception_v3 is excluded in the original test without an
    # explanation; presumably its forward pass does not return a (1, 1000)
    # tensor for this input/mode — confirm before re-enabling.
    _SKIPPED_SURROGATES = frozenset({'inception_v3'})

    @torch.no_grad()
    def test_outputs_shape(self):
        """Each non-skipped surrogate maps a (1, 3, 224, 224) input to shape (1, 1000).

        Every surrogate is checked inside its own ``subTest`` so that a failure
        is reported with the offending ``surrogate_id`` and does not stop the
        remaining surrogates from being checked (when the test runner supports
        subtests; otherwise it behaves like a plain loop).
        """
        inputs = torch.rand(1, 3, 224, 224)
        for surrogate_id in list_surrogates():
            if surrogate_id in self._SKIPPED_SURROGATES:
                continue
            with self.subTest(surrogate_id=surrogate_id):
                model = build_surrogate(surrogate_id, pretrain=False)
                outputs = model(inputs)
                self.assertEqual(outputs.shape, (1, 1000))


class TestFeatCol(unittest.TestCase):
    """Tests for the ``feat_col`` context manager from ``gat.models.surrogate.hooks``."""

    @torch.no_grad()
    def test_feat_col(self):
        """``feat_col`` collects one output per forward pass at the named layer.

        For each case the surrogate is run ``num_forwards`` times inside the
        ``feat_col`` context.  The collector must then report exactly that many
        entries, and each popped entry must have the expected shape.  Every
        case runs in its own ``subTest`` so a failure names the surrogate and
        layer involved.
        """
        # (surrogate_id, feat_layer, input_shape, expected collected shape)
        testcases = [
            ('vgg16', 'features.16', (1, 3, 224, 224), (1, 256, 28, 28)),
            ('vgg19', 'features.18', (1, 3, 224, 224), (1, 256, 28, 28)),
            ('resnet152', 'layer2', (1, 3, 224, 224), (1, 512, 28, 28)),
            ('densenet169', 'features.denseblock2', (1, 3, 224, 224),
             (1, 512, 28, 28)),
        ]
        # Run more than one pass so the test checks that entries accumulate
        # across calls rather than being overwritten.
        num_forwards = 2
        for surrogate_id, feat_layer, input_shape, feat_shape in testcases:
            with self.subTest(surrogate_id=surrogate_id, feat_layer=feat_layer):
                model = build_surrogate(surrogate_id, pretrain=False)
                with feat_col(model, feat_layer) as feat_collector:
                    inputs = torch.rand(input_shape)
                    for _ in range(num_forwards):
                        model(inputs)
                    self.assertEqual(len(feat_collector), num_forwards)
                    for _ in range(num_forwards):
                        self.assertEqual(feat_collector.pop().shape, feat_shape)


if __name__ == '__main__':
    # Executed directly as a script: discover and run the TestCase classes
    # defined in this module.
    unittest.main(module='__main__')