|
|
import unittest
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from gat.models.surrogate import build_surrogate, list_surrogates
|
|
|
from gat.models.surrogate.hooks import feat_col
|
|
|
|
|
|
|
|
|
class TestTVModels(unittest.TestCase):
    """Smoke-test the forward pass of the surrogates from ``list_surrogates()``."""

    # Surrogate ids excluded from the output-shape check.
    # NOTE(review): inception_v3 is presumably skipped because it expects a
    # different input resolution and/or returns a non-tensor output when not
    # in eval mode -- confirm and document the real reason.
    _SKIPPED_SURROGATES = frozenset({'inception_v3'})

    @torch.no_grad()
    def test_outputs_shape(self):
        """Each surrogate maps a (1, 3, 224, 224) input to a (1, 1000) output.

        Models are built with ``pretrain=False``. Every surrogate is checked
        in its own ``subTest`` so that a failure reports the offending
        ``surrogate_id`` and does not prevent the remaining surrogates from
        being checked.
        """
        inputs = torch.rand(1, 3, 224, 224)
        for surrogate_id in list_surrogates():
            if surrogate_id in self._SKIPPED_SURROGATES:
                continue
            with self.subTest(surrogate_id=surrogate_id):
                model = build_surrogate(surrogate_id, pretrain=False)
                outputs = model(inputs)
                self.assertEqual(outputs.shape, (1, 1000))
|
|
|
|
|
|
|
|
|
class TestFeatCol(unittest.TestCase):
    """Check that ``feat_col`` collects the output of a named model layer."""

    @torch.no_grad()
    def test_feat_col(self):
        """Each forward pass inside ``feat_col`` yields one feature of the expected shape.

        Every case gives a surrogate id, the layer name handed to
        ``feat_col`` (e.g. ``'features.16'``), the input shape fed to the
        model, and the expected shape of a single collected feature. Cases
        run in separate ``subTest`` contexts so one failing surrogate does
        not hide failures in the others.
        """
        testcases = [{
            'surrogate_id': 'vgg16',
            'feat_layer': 'features.16',
            'input_shape': (1, 3, 224, 224),
            'feat_shape': (1, 256, 28, 28)
        }, {
            'surrogate_id': 'vgg19',
            'feat_layer': 'features.18',
            'input_shape': (1, 3, 224, 224),
            'feat_shape': (1, 256, 28, 28)
        }, {
            'surrogate_id': 'resnet152',
            'feat_layer': 'layer2',
            'input_shape': (1, 3, 224, 224),
            'feat_shape': (1, 512, 28, 28)
        }, {
            'surrogate_id': 'densenet169',
            'feat_layer': 'features.denseblock2',
            'input_shape': (1, 3, 224, 224),
            'feat_shape': (1, 512, 28, 28)
        }]
        # Run the model more than once so the test checks that every forward
        # pass adds its own entry to the collector.
        num_forward_passes = 2

        for testcase in testcases:
            with self.subTest(surrogate_id=testcase['surrogate_id'],
                              feat_layer=testcase['feat_layer']):
                model = build_surrogate(testcase['surrogate_id'], pretrain=False)
                with feat_col(model, testcase['feat_layer']) as feat_collector:
                    inputs = torch.rand(testcase['input_shape'])
                    for _ in range(num_forward_passes):
                        model(inputs)
                    # Assertions stay inside the ``with`` in case feat_col
                    # releases its collected features on exit.
                    # NOTE(review): confirm feat_col's exit behaviour.
                    self.assertEqual(len(feat_collector), num_forward_passes)
                    for _ in range(num_forward_passes):
                        self.assertEqual(feat_collector.pop().shape,
                                         testcase['feat_shape'])
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: hand this module to the unittest runner so every
    # TestCase defined in it is loaded and run.
    unittest.main(module=__name__)
|
|
|
|