# SAE/attacks/AIM/tests/test_models/test_surrogate.py
# Uploaded by Ttius in commit 998bb30 ("Upload 192 files").
import unittest
import torch
from gat.models.surrogate import build_surrogate, list_surrogates
from gat.models.surrogate.hooks import feat_col
class TestTVModels(unittest.TestCase):
    """Smoke tests that run every registered surrogate on a random input."""

    # Surrogate ids excluded from the output-shape check. inception_v3 was
    # already excluded by the original test; the reason is not visible here
    # (NOTE(review): possibly its output type or expected input size — confirm).
    SKIPPED_SURROGATES = ('inception_v3',)

    @torch.no_grad()
    def test_outputs_shape(self):
        """Each non-skipped surrogate maps a (1, 3, 224, 224) input to (1, 1000)."""
        inputs = torch.rand(1, 3, 224, 224)
        for surrogate_id in list_surrogates():
            if surrogate_id in self.SKIPPED_SURROGATES:
                continue
            # subTest reports each failing surrogate separately instead of
            # aborting the whole loop at the first failure.
            with self.subTest(surrogate_id=surrogate_id):
                model = build_surrogate(surrogate_id, pretrain=False)
                outputs = model(inputs)
                self.assertEqual(outputs.shape, (1, 1000))
class TestFeatCol(unittest.TestCase):
    """Tests for the ``feat_col`` intermediate-feature collector."""

    @torch.no_grad()
    def test_feat_col(self):
        """Two forward passes under ``feat_col`` collect two features of the expected shape."""
        # Each case names a surrogate, the layer to collect from, the input
        # shape fed to the model, and the expected shape of each collected feature.
        testcases = (
            {
                'surrogate_id': 'vgg16',
                'feat_layer': 'features.16',
                'input_shape': (1, 3, 224, 224),
                'feat_shape': (1, 256, 28, 28),
            },
            {
                'surrogate_id': 'vgg19',
                'feat_layer': 'features.18',
                'input_shape': (1, 3, 224, 224),
                'feat_shape': (1, 256, 28, 28),
            },
            {
                'surrogate_id': 'resnet152',
                'feat_layer': 'layer2',
                'input_shape': (1, 3, 224, 224),
                'feat_shape': (1, 512, 28, 28),
            },
            {
                'surrogate_id': 'densenet169',
                'feat_layer': 'features.denseblock2',
                'input_shape': (1, 3, 224, 224),
                'feat_shape': (1, 512, 28, 28),
            },
        )
        num_passes = 2
        for testcase in testcases:
            # subTest labels failures with the case and keeps checking later cases.
            with self.subTest(surrogate_id=testcase['surrogate_id'],
                              feat_layer=testcase['feat_layer']):
                model = build_surrogate(testcase['surrogate_id'],
                                        pretrain=False)
                inputs = torch.rand(testcase['input_shape'])
                with feat_col(model, testcase['feat_layer']) as collector:
                    for _ in range(num_passes):
                        model(inputs)
                    # Assert while the collector context is still open so the
                    # check does not depend on what happens at context exit.
                    self.assertEqual(len(collector), num_passes)
                    for _ in range(num_passes):
                        self.assertEqual(collector.pop().shape,
                                         testcase['feat_shape'])
if __name__ == '__main__':
    # Running the file directly (rather than through a test runner)
    # executes every TestCase defined above.
    unittest.main()