Mithridatium / tests /test_invisible_trigger.py
Gustavo Lucca
Invisible watermark training script and tests updated
4ba0984
import torch
from torch.utils.data import Dataset
from mithridatium.attacks.invisible import (
apply_invisible_trigger,
create_random_uap,
InvisibleBackdoorDataset,
)
def test_apply_invisible_trigger_clamps():
    """Check apply_invisible_trigger on a blank image and on an overshooting one.

    A zero image combined with a 0.5 trigger is expected to come back equal to
    the trigger; a 0.8 image combined with a 0.5 trigger must stay in [0, 1].
    """
    shape = (3, 4, 4)

    # Case 1: blank image, so the output should match the trigger itself.
    blank = torch.zeros(shape)
    half_trigger = torch.full(shape, 0.5)
    triggered = apply_invisible_trigger(blank, half_trigger)
    assert torch.allclose(triggered, half_trigger)

    # Case 2: values that would exceed 1.0 should be clamped.
    bright = torch.full(shape, 0.8)
    fresh_trigger = torch.full(shape, 0.5)
    triggered = apply_invisible_trigger(bright, fresh_trigger)
    upper, lower = triggered.max(), triggered.min()
    assert upper <= 1.0
    assert lower >= 0.0
def test_create_random_uap_shapes_and_norms():
    """Check output shape and norm budget of create_random_uap for p="inf" and p="2"."""
    shape = (3, 32, 32)

    # p="inf": no element may exceed xi (plus a small float tolerance).
    linf_uap = create_random_uap(shape, xi=0.1, p="inf", seed=0)
    assert linf_uap.shape == shape
    largest_abs = linf_uap.abs().max()
    assert largest_abs <= 0.1 + 1e-6

    # p="2": the L2 norm of each flattened channel should be ~0.1.
    l2_uap = create_random_uap(shape, xi=0.1, p="2", seed=0)
    assert l2_uap.shape == shape
    per_channel = l2_uap.view(3, -1)
    channel_norms = per_channel.norm(p=2, dim=1)
    expected_norms = torch.tensor([0.1, 0.1, 0.1])
    assert torch.allclose(channel_norms, expected_norms, atol=1e-3)
def _make_simple_dataset(num=10, num_classes=3):
    """Build a tiny in-memory dataset of all-zero images for tests.

    Args:
        num: Number of samples reported by ``len()``.
        num_classes: Labels cycle through ``0 .. num_classes - 1``.

    Returns:
        A ``Dataset`` whose item ``idx`` is the pair
        ``(torch.zeros((3, 4, 4)), idx % num_classes)``.

    Indexing at or past ``num`` raises ``IndexError``. Negative indices are
    not range-checked and keep returning ``idx % num_classes`` as before.
    """

    class DummyDS(Dataset):
        def __len__(self):
            return num

        def __getitem__(self, idx):
            # Without this bound, iteration that falls back to __getitem__
            # (``for item in ds`` / ``list(ds)``) never sees IndexError and
            # therefore never terminates.
            if idx >= num:
                raise IndexError(
                    f"index {idx} out of range for dataset of size {num}"
                )
            # Return a zero image and a class label that cycles with idx.
            return torch.zeros((3, 4, 4)), idx % num_classes

    return DummyDS()
def test_invisible_backdoor_dataset_poisoning():
    """Exercise InvisibleBackdoorDataset in 'train' and 'test_poison' modes.

    Uses a 20-sample, 5-class dummy dataset and an all-zero UAP, so a
    triggered image is expected to be identical to the clean zero image.
    """
    target = 2
    ds = _make_simple_dataset(num=20, num_classes=5)
    uap = torch.zeros((3, 4, 4))

    inv = InvisibleBackdoorDataset(
        ds,
        poison_rate=0.5,
        target_class=target,
        uap=uap,
        mode='train',
        seed=42,
    )

    # The poisoned set can never be larger than the number of samples whose
    # label differs from the target class.
    non_target_labels = [
        ds[i][1] for i in range(len(ds)) if ds[i][1] != target
    ]
    total_non_target = len(non_target_labels)
    assert len(inv.poisoned_indices) <= total_non_target

    # 'test_poison' mode yields (image, original_label, target_label) triples.
    inv_test = InvisibleBackdoorDataset(
        ds,
        poison_rate=1.0,
        target_class=target,
        uap=uap,
        mode='test_poison',
    )
    clean_image = torch.zeros((3, 4, 4))
    for image, original_label, target_label in inv_test:
        assert original_label != target_label
        assert target_label == target
        # The UAP is zero, so the image must be unchanged.
        assert torch.equal(image, clean_image)
def test_poison_loss_weight_effect():
    """Up-weighting the target-class sample must raise the mean cross-entropy."""
    cross_entropy = torch.nn.functional.cross_entropy

    # Small linear model; parameters are overwritten with constants below.
    model = torch.nn.Linear(4, 2)
    # Seed after construction so the random inputs are reproducible.
    torch.manual_seed(0)
    torch.nn.init.constant_(model.weight, 0.1)
    torch.nn.init.constant_(model.bias, 0.0)

    # Two samples; the second carries the target label and plays the
    # "poisoned" role.
    target_class = 1
    inputs = torch.randn(2, 4)
    labels = torch.tensor([0, 1])

    logits = model(inputs)
    base_loss = cross_entropy(logits, labels)

    # Per-sample losses, with target-class samples weighted 3x and others 1x.
    sample_losses = cross_entropy(logits, labels, reduction="none")
    is_target = labels == target_class
    weights = torch.where(
        is_target,
        torch.full_like(sample_losses, 3.0),
        torch.ones_like(sample_losses),
    )
    weighted = (sample_losses * weights).mean()

    assert weighted > base_loss, "Weighted loss should exceed unweighted when weight>1"