"""Unit tests for the GAT attack generators, contrastive loss and attacks."""
import unittest

import torch

from gat.models.attack.aim_attack import AIMAttack
from gat.models.attack.cda_attack import CDAAttack
from gat.models.attack.generator.aim import AIMGenerator
from gat.models.attack.generator.cda import CDAGenerator
from gat.models.attack.loss.logits import ContrastiveLoss

# Shape used by every generator/attack test: a batch of one 3-channel
# 224x224 image. Inputs are drawn with torch.rand, i.e. uniform in [0, 1).
IMAGE_SHAPE = (1, 3, 224, 224)

# Fixed RNG seed so the random inputs (and any torch-RNG-based weight
# initialisation that happens after seeding) are identical on every run.
SEED = 0

# Slack for the L-infinity budget check. `x_adv - x_nat` is computed in
# float32, so a perturbation clipped to exactly epsilon can come back a few
# ULPs larger after the add/subtract round-trip; without slack the check can
# fail spuriously on an otherwise correct attack.
BOUND_ATOL = 1e-6


def _assert_valid_adversarial(test, x_adv, x_nat, epsilon, atol=BOUND_ATOL):
    """Assert that ``x_adv`` is a valid image within ``epsilon`` of ``x_nat``.

    Checks that every value of ``x_adv`` lies in [0, 1] and that the
    L-infinity distance between ``x_adv`` and ``x_nat`` is at most
    ``epsilon`` (plus ``atol`` for floating-point rounding). Scalar
    ``assert*`` calls are used so a failure reports the offending value
    rather than just ``False is not true``.

    Args:
        test: the ``unittest.TestCase`` whose assertion methods are used.
        x_adv: adversarial output tensor.
        x_nat: clean input tensor the attack was applied to.
        epsilon: perturbation budget the attack was constructed with.
        atol: extra tolerance added to ``epsilon``.
    """
    test.assertGreaterEqual(x_adv.min().item(), 0.0)
    test.assertLessEqual(x_adv.max().item(), 1.0)
    test.assertLessEqual((x_adv - x_nat).abs().max().item(), epsilon + atol)


class TestGenerator(unittest.TestCase):
    """Smoke tests: each generator returns an output of the input's shape."""

    def setUp(self):
        torch.manual_seed(SEED)

    @torch.no_grad()
    def test_cda(self):
        """CDAGenerator maps a single image to an output of the same shape."""
        generator = CDAGenerator()
        inputs = torch.rand(IMAGE_SHAPE)
        outputs = generator(inputs)
        self.assertEqual(outputs.shape, IMAGE_SHAPE)

    @torch.no_grad()
    def test_aim(self):
        """AIMGenerator takes two images and returns one of the same shape."""
        generator = AIMGenerator()
        inputs = (torch.rand(IMAGE_SHAPE), torch.rand(IMAGE_SHAPE))
        outputs = generator(*inputs)
        self.assertEqual(outputs.shape, IMAGE_SHAPE)


class TestContrastiveLoss(unittest.TestCase):
    """Tests for ContrastiveLoss on small hand-written 2-D embeddings.

    NOTE(review): the loss is called as ``loss_fn(anchors, negatives,
    positives)`` and the expected values (0.25, 0.75) depend on the exact
    formula in ContrastiveLoss with margin=1.0 -- confirm both the argument
    order and the values against ContrastiveLoss.forward.
    """

    def setUp(self):
        self.margin = 1.0
        self.loss_fn = ContrastiveLoss(margin=self.margin)

    @torch.no_grad()
    def test_forward_shape(self):
        """The loss over a batch of two embeddings reduces to a scalar."""
        anchors = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
        positives = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
        negatives = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
        loss = self.loss_fn(anchors, negatives, positives)
        self.assertEqual(loss.shape, torch.Size([]))

    @torch.no_grad()
    def test_loss_outputs(self):
        """Positives equal the anchors; negatives are the anchors scaled by 2."""
        anchors = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
        positives = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
        negatives = torch.tensor([[2.0, 0.0], [0.0, 2.0]])
        loss = self.loss_fn(anchors, negatives, positives)
        self.assertAlmostEqual(loss.item(), 0.25)

    @torch.no_grad()
    def test_non_zero_loss(self):
        """Positives are the anchors with rows swapped, i.e. mismatched pairs."""
        anchors = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
        positives = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
        negatives = torch.tensor([[2.0, 0.0], [0.0, 2.0]])
        loss = self.loss_fn(anchors, negatives, positives)
        self.assertAlmostEqual(loss.item(), 0.75)


class TestCDAAttack(unittest.TestCase):
    """Shape and perturbation-budget tests for CDAAttack on CPU."""

    def setUp(self):
        # Seed before building the attack so any random initialisation it
        # performs via torch's RNG is reproducible as well.
        torch.manual_seed(SEED)
        self.device = torch.device('cpu')
        self.epsilon = 16. / 255.
        self.attack = CDAAttack(self.device, self.epsilon)

    @torch.no_grad()
    def test_outputs_shape(self):
        """The adversarial image has the same shape as the clean input."""
        x_nat = torch.rand(IMAGE_SHAPE)
        x_adv = self.attack(x_nat)
        self.assertEqual(x_adv.shape, IMAGE_SHAPE)

    @torch.no_grad()
    def test_outputs_bound(self):
        """Output stays in [0, 1] and within epsilon of the clean input."""
        x_nat = torch.rand(IMAGE_SHAPE)
        x_adv = self.attack(x_nat)
        _assert_valid_adversarial(self, x_adv, x_nat, self.epsilon)


class TestAIMAttack(unittest.TestCase):
    """Shape and perturbation-budget tests for AIMAttack on CPU."""

    def setUp(self):
        # Seed before building the attack so any random initialisation it
        # performs via torch's RNG is reproducible as well.
        torch.manual_seed(SEED)
        self.device = torch.device('cpu')
        self.epsilon = 16. / 255.
        self.attack = AIMAttack(self.device, self.epsilon)

    @torch.no_grad()
    def test_outputs_shape(self):
        """The adversarial image has the same shape as the clean input."""
        x_nat = torch.rand(IMAGE_SHAPE)
        x_guid = torch.rand(IMAGE_SHAPE)
        x_adv = self.attack(x_nat, x_guid)
        self.assertEqual(x_adv.shape, IMAGE_SHAPE)

    @torch.no_grad()
    def test_outputs_bound(self):
        """Output stays in [0, 1] and within epsilon of the clean input."""
        x_nat = torch.rand(IMAGE_SHAPE)
        x_guid = torch.rand(IMAGE_SHAPE)
        x_adv = self.attack(x_nat, x_guid)
        _assert_valid_adversarial(self, x_adv, x_nat, self.epsilon)


if __name__ == '__main__':
    unittest.main()