"""Smoke tests for the surrogate model registry and the feature-collection hook."""
import unittest

import torch

from gat.models.surrogate import build_surrogate, list_surrogates
from gat.models.surrogate.hooks import feat_col

# Surrogates excluded from the logits-shape test. The original test skipped
# inception_v3; presumably its forward output or expected input resolution
# differs from the other models — TODO confirm the exact reason.
_SHAPE_TEST_SKIP = frozenset({'inception_v3'})


class TestTVModels(unittest.TestCase):
    """Check that every registered surrogate maps a 224x224 image to (1, 1000) outputs."""

    @torch.no_grad()
    def test_outputs_shape(self):
        inputs = torch.rand(1, 3, 224, 224)
        for surrogate_id in list_surrogates():
            if surrogate_id in _SHAPE_TEST_SKIP:
                continue
            # subTest names the failing surrogate and keeps checking the rest
            # instead of aborting the whole loop on the first failure.
            with self.subTest(surrogate_id=surrogate_id):
                model = build_surrogate(surrogate_id, pretrain=False)
                outputs = model(inputs)
                self.assertEqual(outputs.shape, (1, 1000))


class TestFeatCol(unittest.TestCase):
    """Check that ``feat_col`` records an intermediate feature per forward pass."""

    @torch.no_grad()
    def test_feat_col(self):
        testcases = [
            {
                'surrogate_id': 'vgg16',
                'feat_layer': 'features.16',
                'input_shape': (1, 3, 224, 224),
                'feat_shape': (1, 256, 28, 28),
            },
            {
                'surrogate_id': 'vgg19',
                'feat_layer': 'features.18',
                'input_shape': (1, 3, 224, 224),
                'feat_shape': (1, 256, 28, 28),
            },
            {
                'surrogate_id': 'resnet152',
                'feat_layer': 'layer2',
                'input_shape': (1, 3, 224, 224),
                'feat_shape': (1, 512, 28, 28),
            },
            {
                'surrogate_id': 'densenet169',
                'feat_layer': 'features.denseblock2',
                'input_shape': (1, 3, 224, 224),
                'feat_shape': (1, 512, 28, 28),
            },
        ]
        for testcase in testcases:
            with self.subTest(surrogate_id=testcase['surrogate_id'],
                              feat_layer=testcase['feat_layer']):
                model = build_surrogate(testcase['surrogate_id'], pretrain=False)
                inputs = torch.rand(testcase['input_shape'])
                with feat_col(model, testcase['feat_layer']) as feat_collector:
                    # Two forward passes should leave exactly two collected features.
                    model(inputs)
                    model(inputs)
                    # Assertions stay inside the context: the collector's
                    # contents after exit are not guaranteed here.
                    self.assertEqual(len(feat_collector), 2)
                    self.assertEqual(feat_collector.pop().shape, testcase['feat_shape'])
                    self.assertEqual(feat_collector.pop().shape, testcase['feat_shape'])


if __name__ == '__main__':
    unittest.main()