# Original location: SAE/attacks/AIM/tests/test_models/test_attack.py
# Uploaded by Ttius ("Upload 192 files", commit 998bb30, verified).
import unittest
import torch
from gat.models.attack.aim_attack import AIMAttack
from gat.models.attack.cda_attack import CDAAttack
from gat.models.attack.generator.aim import AIMGenerator
from gat.models.attack.generator.cda import CDAGenerator
from gat.models.attack.loss.logits import ContrastiveLoss
class TestGenerator(unittest.TestCase):
    """Smoke tests: each generator must return a tensor shaped like its input."""

    # Shape of the random input tensor(s) used by every test in this class.
    INPUT_SHAPE = (1, 3, 224, 224)

    @torch.no_grad()
    def test_cda(self):
        """CDAGenerator maps one input tensor to an output of identical shape."""
        model = CDAGenerator()
        result = model(torch.rand(*self.INPUT_SHAPE))
        self.assertEqual(result.shape, self.INPUT_SHAPE)

    @torch.no_grad()
    def test_aim(self):
        """AIMGenerator takes two input tensors and returns one of the same shape."""
        model = AIMGenerator()
        first = torch.rand(*self.INPUT_SHAPE)
        second = torch.rand(*self.INPUT_SHAPE)
        result = model(first, second)
        self.assertEqual(result.shape, self.INPUT_SHAPE)
class TestContrastiveLoss(unittest.TestCase):
    """Checks ContrastiveLoss output shape and values on small hand-built 2-D inputs."""

    def setUp(self):
        self.margin = 1.0
        self.loss_fn = ContrastiveLoss(margin=self.margin)

    def _loss(self, anchors, positives, negatives):
        """Evaluate the loss; ContrastiveLoss is called as (anchors, negatives, positives)."""
        return self.loss_fn(anchors, negatives, positives)

    @torch.no_grad()
    def test_forward_shape(self):
        """The loss is reduced to a 0-d (scalar) tensor."""
        negatives = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
        loss = self._loss(torch.eye(2), torch.eye(2), negatives)
        self.assertEqual(loss.shape, torch.Size([]))

    @torch.no_grad()
    def test_loss_outputs(self):
        """Positives equal the anchors and negatives are the anchors scaled by 2: loss is 0.25."""
        loss = self._loss(torch.eye(2), torch.eye(2), 2.0 * torch.eye(2))
        self.assertAlmostEqual(loss.item(), 0.25)

    @torch.no_grad()
    def test_non_zero_loss(self):
        """Positives are the anchors with rows swapped, negatives as above: loss is 0.75."""
        swapped = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
        loss = self._loss(torch.eye(2), swapped, 2.0 * torch.eye(2))
        self.assertAlmostEqual(loss.item(), 0.75)
class TestCDAAttack(unittest.TestCase):
    """Checks CDAAttack output shape, value range, and L-infinity perturbation budget."""

    # Slack for the epsilon-ball check. `epsilon` is a Python float, while the
    # tensors are (presumably) float32, so a perturbation clipped to exactly
    # epsilon can overshoot by a few ULPs once `x_adv - x_nat` is recomputed.
    ATOL = 1e-6

    def setUp(self):
        # Fixed seed so a failing random input can be reproduced.
        torch.manual_seed(0)
        self.device = torch.device('cpu')
        self.epsilon = 16. / 255.
        self.attack = CDAAttack(self.device, self.epsilon)

    @torch.no_grad()
    def test_outputs_shape(self):
        """The adversarial output keeps the input shape."""
        x_nat = torch.rand(1, 3, 224, 224)
        x_adv = self.attack(x_nat)
        self.assertEqual(x_adv.shape, (1, 3, 224, 224))

    @torch.no_grad()
    def test_outputs_bound(self):
        """Outputs stay in [0, 1] and within epsilon (L-inf) of the clean input."""
        x_nat = torch.rand(1, 3, 224, 224)
        x_adv = self.attack(x_nat)
        # torch.min/max propagate NaN, and NaN fails both comparisons below.
        self.assertGreaterEqual(x_adv.min().item(), 0.0)
        self.assertLessEqual(x_adv.max().item(), 1.0)
        max_perturbation = (x_adv - x_nat).abs().max().item()
        self.assertLessEqual(max_perturbation, self.epsilon + self.ATOL)
class TestAIMAttack(unittest.TestCase):
    """Checks AIMAttack output shape, value range, and L-infinity perturbation budget."""

    # Slack for the epsilon-ball check. `epsilon` is a Python float, while the
    # tensors are (presumably) float32, so a perturbation clipped to exactly
    # epsilon can overshoot by a few ULPs once `x_adv - x_nat` is recomputed.
    ATOL = 1e-6

    def setUp(self):
        # Fixed seed so a failing random input can be reproduced.
        torch.manual_seed(0)
        self.device = torch.device('cpu')
        self.epsilon = 16. / 255.
        self.attack = AIMAttack(self.device, self.epsilon)

    @torch.no_grad()
    def test_outputs_shape(self):
        """The adversarial output keeps the shape of the clean input."""
        x_nat = torch.rand(1, 3, 224, 224)
        # Second input to the attack; presumably a guide sample — confirm against AIMAttack.
        x_guid = torch.rand(1, 3, 224, 224)
        x_adv = self.attack(x_nat, x_guid)
        self.assertEqual(x_adv.shape, (1, 3, 224, 224))

    @torch.no_grad()
    def test_outputs_bound(self):
        """Outputs stay in [0, 1] and within epsilon (L-inf) of the clean input."""
        x_nat = torch.rand(1, 3, 224, 224)
        x_guid = torch.rand(1, 3, 224, 224)
        x_adv = self.attack(x_nat, x_guid)
        # torch.min/max propagate NaN, and NaN fails both comparisons below.
        self.assertGreaterEqual(x_adv.min().item(), 0.0)
        self.assertLessEqual(x_adv.max().item(), 1.0)
        max_perturbation = (x_adv - x_nat).abs().max().item()
        self.assertLessEqual(max_perturbation, self.epsilon + self.ATOL)
if __name__ == '__main__':
    # Executed directly (not imported): discover and run every TestCase in this module.
    unittest.main()