|
|
import unittest
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from gat.models.attack.aim_attack import AIMAttack
|
|
|
from gat.models.attack.cda_attack import CDAAttack
|
|
|
from gat.models.attack.generator.aim import AIMGenerator
|
|
|
from gat.models.attack.generator.cda import CDAGenerator
|
|
|
from gat.models.attack.loss.logits import ContrastiveLoss
|
|
|
|
|
|
|
|
|
class TestGenerator(unittest.TestCase):
    """Shape checks for the networks in ``gat.models.attack.generator``."""

    # Both generators are expected to return a tensor with the same shape as
    # the single 224x224, 3-channel batch they are fed.
    IMAGE_SHAPE = (1, 3, 224, 224)

    @torch.no_grad()
    def test_cda(self):
        """CDAGenerator maps one image batch to an output of equal shape."""
        model = CDAGenerator()
        produced = model(torch.rand(*self.IMAGE_SHAPE))
        self.assertEqual(produced.shape, self.IMAGE_SHAPE)

    @torch.no_grad()
    def test_aim(self):
        """AIMGenerator takes two image batches and returns one of equal shape."""
        model = AIMGenerator()
        # Draw the two inputs in the same order as before so RNG use matches.
        first = torch.rand(*self.IMAGE_SHAPE)
        second = torch.rand(*self.IMAGE_SHAPE)
        produced = model(first, second)
        self.assertEqual(produced.shape, self.IMAGE_SHAPE)
|
|
|
|
|
|
|
|
|
class TestContrastiveLoss(unittest.TestCase):
    """Tests for ``ContrastiveLoss`` on tiny hand-built 2-D embeddings."""

    def setUp(self):
        self.margin = 1.0
        self.loss_fn = ContrastiveLoss(margin=self.margin)

    def _loss(self, anchors, positives, negatives):
        # The loss is invoked as (anchors, negatives, positives): negatives
        # come BEFORE positives in the call, unlike this helper's signature.
        return self.loss_fn(anchors, negatives, positives)

    @torch.no_grad()
    def test_forward_shape(self):
        """The loss reduces to a 0-dimensional (scalar) tensor."""
        identity = torch.eye(2)
        mixed = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
        result = self._loss(identity, identity.clone(), mixed)
        self.assertEqual(result.dim(), 0)

    @torch.no_grad()
    def test_loss_outputs(self):
        """Positives equal to the anchors, negatives at twice the anchors."""
        identity = torch.eye(2)
        result = self._loss(identity, identity.clone(), 2.0 * torch.eye(2))
        self.assertAlmostEqual(result.item(), 0.25)

    @torch.no_grad()
    def test_non_zero_loss(self):
        """Positives are the anchors with rows swapped; negatives unchanged."""
        identity = torch.eye(2)
        swapped = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
        result = self._loss(identity, swapped, 2.0 * torch.eye(2))
        self.assertAlmostEqual(result.item(), 0.75)
|
|
|
|
|
|
|
|
|
class TestCDAAttack(unittest.TestCase):
    """Checks that ``CDAAttack`` outputs keep the input shape, stay in [0, 1],
    and stay within an L-infinity distance of ``epsilon`` from the input."""

    # Absolute slack for the perturbation-budget check. The difference
    # x_adv - x_nat is computed in float32 (assuming the attack returns
    # float32 like its input), so a correctly projected perturbation can
    # still exceed epsilon by a few ULPs; an exact comparison is flaky.
    ATOL = 1e-6

    def setUp(self):
        # Seed first so the random inputs (and any randomly initialised
        # weights the attack may create) are reproducible across runs.
        torch.manual_seed(0)
        self.device = torch.device('cpu')
        self.epsilon = 16. / 255.
        self.attack = CDAAttack(self.device, self.epsilon)

    @torch.no_grad()
    def test_outputs_shape(self):
        """The adversarial batch has the same shape as the natural batch."""
        x_nat = torch.rand(1, 3, 224, 224)
        x_adv = self.attack(x_nat)
        self.assertEqual(x_adv.shape, (1, 3, 224, 224))

    @torch.no_grad()
    def test_outputs_bound(self):
        """Outputs lie in [0, 1] and within epsilon (L-inf) of the input."""
        x_nat = torch.rand(1, 3, 224, 224)
        x_adv = self.attack(x_nat)
        # Scalar comparisons give readable failure messages; min/max
        # propagate NaN, so a NaN output still fails these checks.
        self.assertGreaterEqual(x_adv.min().item(), 0.0)
        self.assertLessEqual(x_adv.max().item(), 1.0)
        max_perturbation = (x_adv - x_nat).abs().max().item()
        self.assertLessEqual(max_perturbation, self.epsilon + self.ATOL)
|
|
|
|
|
|
|
|
|
class TestAIMAttack(unittest.TestCase):
    """Checks that ``AIMAttack`` outputs keep the input shape, stay in [0, 1],
    and stay within an L-infinity distance of ``epsilon`` from the natural
    input, given a natural batch and a second (guide) batch."""

    # Absolute slack for the perturbation-budget check. The difference
    # x_adv - x_nat is computed in float32 (assuming the attack returns
    # float32 like its input), so a correctly projected perturbation can
    # still exceed epsilon by a few ULPs; an exact comparison is flaky.
    ATOL = 1e-6

    def setUp(self):
        # Seed first so the random inputs (and any randomly initialised
        # weights the attack may create) are reproducible across runs.
        torch.manual_seed(0)
        self.device = torch.device('cpu')
        self.epsilon = 16. / 255.
        self.attack = AIMAttack(self.device, self.epsilon)

    @torch.no_grad()
    def test_outputs_shape(self):
        """The adversarial batch has the same shape as the natural batch."""
        x_nat = torch.rand(1, 3, 224, 224)
        x_guid = torch.rand(1, 3, 224, 224)
        x_adv = self.attack(x_nat, x_guid)
        self.assertEqual(x_adv.shape, (1, 3, 224, 224))

    @torch.no_grad()
    def test_outputs_bound(self):
        """Outputs lie in [0, 1] and within epsilon (L-inf) of ``x_nat``."""
        x_nat = torch.rand(1, 3, 224, 224)
        x_guid = torch.rand(1, 3, 224, 224)
        x_adv = self.attack(x_nat, x_guid)
        # Scalar comparisons give readable failure messages; min/max
        # propagate NaN, so a NaN output still fails these checks.
        self.assertGreaterEqual(x_adv.min().item(), 0.0)
        self.assertLessEqual(x_adv.max().item(), 1.0)
        max_perturbation = (x_adv - x_nat).abs().max().item()
        self.assertLessEqual(max_perturbation, self.epsilon + self.ATOL)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Discover and run every TestCase defined in this module when it is
    # executed directly as a script.
    unittest.main(module="__main__")
|
|
|
|