"""Tests for the invisible (UAP-based) backdoor helpers in mithridatium."""

import torch
from torch.utils.data import Dataset

from mithridatium.attacks.invisible import (
    InvisibleBackdoorDataset,
    apply_invisible_trigger,
    create_random_uap,
)


def test_apply_invisible_trigger_clamps():
    """Assert the trigger is added to the image and the result stays in [0, 1]."""
    # On an all-zero image the output is expected to equal the perturbation.
    x = torch.zeros((3, 4, 4))
    uap = torch.ones((3, 4, 4)) * 0.5
    out = apply_invisible_trigger(x, uap)
    assert torch.allclose(out, uap)

    # 0.8 + 0.5 exceeds 1.0, so the result should be clamped.
    x = torch.full((3, 4, 4), 0.8)
    uap = torch.full((3, 4, 4), 0.5)
    out = apply_invisible_trigger(x, uap)
    assert out.max() <= 1.0
    assert out.min() >= 0.0


def test_create_random_uap_shapes_and_norms():
    """Assert shape and norm bounds of random UAPs for p="inf" and p="2"."""
    uap_inf = create_random_uap((3, 32, 32), xi=0.1, p="inf", seed=0)
    assert uap_inf.shape == (3, 32, 32)
    assert uap_inf.abs().max() <= 0.1 + 1e-6

    uap2 = create_random_uap((3, 32, 32), xi=0.1, p="2", seed=0)
    assert uap2.shape == (3, 32, 32)
    # L2 norm per channel should be ~0.1. reshape (rather than view) keeps the
    # check working even if the returned tensor is not contiguous.
    norm = uap2.reshape(3, -1).norm(p=2, dim=1)
    assert torch.allclose(norm, torch.tensor([0.1, 0.1, 0.1]), atol=1e-3)


def _make_simple_dataset(num=10, num_classes=3):
    """Build a tiny map-style dataset of zero images with cycling labels.

    Args:
        num: Number of samples the dataset reports via ``len()``.
        num_classes: Labels cycle through ``0 .. num_classes - 1``.

    Returns:
        A ``Dataset`` whose item ``idx`` is ``(zeros(3, 4, 4), idx % num_classes)``.
    """

    class DummyDS(Dataset):
        def __len__(self):
            return num

        def __getitem__(self, idx):
            # Plain ``for`` iteration over a map-style dataset only stops on
            # IndexError, so out-of-range indices must raise instead of
            # silently producing more samples.
            if not 0 <= idx < num:
                raise IndexError(idx)
            # Zero image and a class label that cycles with the index.
            return torch.zeros((3, 4, 4)), idx % num_classes

    return DummyDS()


def test_invisible_backdoor_dataset_poisoning():
    """Check poisoned-index count in train mode and triples in test_poison mode."""
    ds = _make_simple_dataset(num=20, num_classes=5)
    uap = torch.zeros((3, 4, 4))
    target = 2

    inv = InvisibleBackdoorDataset(
        ds, poison_rate=0.5, target_class=target, uap=uap, mode='train', seed=42
    )
    # Number of poisoned indices is ~0.5 * 20 but target-class entries are
    # skipped, so it can never exceed the number of non-target samples.
    total_non_target = sum(1 for i in range(len(ds)) if ds[i][1] != target)
    assert len(inv.poisoned_indices) <= total_non_target

    # ASR mode returns (image, original label, target label) triples and
    # always applies the trigger.
    inv_test = InvisibleBackdoorDataset(
        ds, poison_rate=1.0, target_class=target, uap=uap, mode='test_poison'
    )
    for img, orig, targ in inv_test:
        assert orig != targ
        assert targ == target
        # The UAP is all zeros, so the image must come back unchanged.
        assert torch.equal(img, torch.zeros((3, 4, 4)))


def test_poison_loss_weight_effect():
    """Up-weighting target-class samples must raise the mean cross-entropy."""
    # Small model with known outputs: constant weights, zero bias.
    model = torch.nn.Linear(4, 2)
    torch.manual_seed(0)  # fixes the random inputs drawn below
    model.weight.data.fill_(0.1)
    model.bias.data.fill_(0.0)

    # Two samples where the second is "poisoned" (label == target_class).
    x = torch.randn(2, 4)
    y = torch.tensor([0, 1])
    target_class = 1

    # Unweighted baseline loss.
    outputs = model(x)
    base_loss = torch.nn.functional.cross_entropy(outputs, y)

    # Weighted loss: target-class samples count three times.
    per_sample = torch.nn.functional.cross_entropy(outputs, y, reduction="none")
    mask = y == target_class
    weights = torch.ones_like(per_sample)
    weights[mask] = 3.0
    weighted = (per_sample * weights).mean()

    assert weighted > base_loss, "Weighted loss should exceed unweighted when weight>1"