File size: 3,022 Bytes
4ba0984
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from torch.utils.data import Dataset

from mithridatium.attacks.invisible import (
    apply_invisible_trigger,
    create_random_uap,
    InvisibleBackdoorDataset,
)


def test_apply_invisible_trigger_clamps():
    """Check ``apply_invisible_trigger`` output on an in-range and an overflowing case.

    First a 0.5 perturbation is applied to an all-zero image and the result
    is expected to match the perturbation itself. Then a 0.5 perturbation is
    applied to a 0.8 image and the result is expected to stay inside [0, 1].
    """
    shape = (3, 4, 4)

    # Case 1: zero image, so the result should be the perturbation unchanged.
    blank = torch.zeros(shape)
    half = torch.full(shape, 0.5)
    result = apply_invisible_trigger(blank, half)
    assert torch.allclose(result, half)

    # Case 2: 0.8 + 0.5 exceeds 1.0, so the result must be clamped.
    bright = torch.full(shape, 0.8)
    perturbation = torch.full(shape, 0.5)
    result = apply_invisible_trigger(bright, perturbation)
    highest = result.max()
    assert highest <= 1.0
    lowest = result.min()
    assert lowest >= 0.0


def test_create_random_uap_shapes_and_norms():
    """Check the shape and magnitude of perturbations from ``create_random_uap``.

    With ``p="inf"`` every element must lie within ``xi`` in absolute value
    (plus a small float tolerance). With ``p="2"`` the per-channel L2 norm is
    expected to be approximately ``xi``.
    """
    shape = (3, 32, 32)
    budget = 0.1

    # L-infinity variant: bounded element-wise.
    linf_uap = create_random_uap(shape, xi=budget, p="inf", seed=0)
    assert linf_uap.shape == shape
    assert linf_uap.abs().max() <= budget + 1e-6

    # L2 variant: flatten each channel and compare its norm with the budget.
    l2_uap = create_random_uap(shape, xi=budget, p="2", seed=0)
    assert l2_uap.shape == shape
    channel_norms = l2_uap.view(3, -1).norm(p=2, dim=1)
    expected_norms = torch.full((3,), budget)
    assert torch.allclose(channel_norms, expected_norms, atol=1e-3)


def _make_simple_dataset(num=10, num_classes=3):
    """Build a tiny in-memory dataset of all-zero images for tests.

    Args:
        num: Number of samples the dataset reports through ``len``.
        num_classes: Labels cycle through ``0 .. num_classes - 1``.

    Returns:
        A ``Dataset`` whose item ``i`` is the pair
        ``(torch.zeros((3, 4, 4)), i % num_classes)``.
    """

    class DummyDS(Dataset):
        """Fixed-length dataset of zero images with cycling integer labels."""

        def __len__(self):
            return num

        def __getitem__(self, idx):
            # The class defines no ``__iter__``, so ``for sample in ds`` uses
            # the legacy sequence protocol and only stops on IndexError.
            # Without this bounds check, iteration would never terminate.
            if not -num <= idx < num:
                raise IndexError(
                    f"index {idx} is out of range for dataset of size {num}"
                )
            if idx < 0:
                # Python-style negative indexing, counted from the end.
                idx += num
            # Zero image plus a class label that cycles with the index.
            return torch.zeros((3, 4, 4)), idx % num_classes

    return DummyDS()


def test_invisible_backdoor_dataset_poisoning():
    """Exercise ``InvisibleBackdoorDataset`` in ``train`` and ``test_poison`` modes.

    In train mode the number of poisoned indices must not exceed the number
    of samples whose label differs from the target class. In ``test_poison``
    mode each yielded item is unpacked as (image, original label, target
    label); the original label must differ from the target, and with an
    all-zero perturbation the image must remain all zeros.
    """
    target = 2
    zero_uap = torch.zeros((3, 4, 4))
    base = _make_simple_dataset(num=20, num_classes=5)

    train_view = InvisibleBackdoorDataset(
        base,
        poison_rate=0.5,
        target_class=target,
        uap=zero_uap,
        mode='train',
        seed=42,
    )
    # Upper bound only: poisoned samples can come solely from non-target entries.
    non_target = [i for i in range(len(base)) if base[i][1] != target]
    assert len(train_view.poisoned_indices) <= len(non_target)

    # Attack-success-rate mode: every item carries the trigger and the target label.
    asr_view = InvisibleBackdoorDataset(
        base,
        poison_rate=1.0,
        target_class=target,
        uap=zero_uap,
        mode='test_poison',
    )
    for sample in asr_view:
        img, orig, targ = sample
        assert orig != targ
        assert targ == target
        # The perturbation is zero, so the triggered image equals the clean one.
        assert torch.equal(img, torch.zeros((3, 4, 4)))


def test_poison_loss_weight_effect():
    """Check that up-weighting target-class samples raises the mean loss.

    A two-output linear model with constant weights scores two random
    inputs. The sample labelled with the target class gets a per-sample
    weight of 3.0, and the weighted mean cross-entropy is compared with the
    plain mean cross-entropy.
    """
    cross_entropy = torch.nn.functional.cross_entropy

    # Tiny model with fixed parameters so only the inputs are random.
    model = torch.nn.Linear(4, 2)
    model.weight.data.fill_(0.1)
    model.bias.data.fill_(0.0)

    # Seed right before sampling so the inputs are reproducible.
    torch.manual_seed(0)
    x = torch.randn(2, 4)
    # Second sample plays the "poisoned" role: its label is the target class.
    target_class = 1
    y = torch.tensor([0, 1])

    logits = model(x)

    # Reference: ordinary mean-reduced loss.
    base_loss = cross_entropy(logits, y)

    # Weighted variant: scale target-class samples by 3, then average.
    sample_losses = cross_entropy(logits, y, reduction="none")
    is_target = y == target_class
    weights = torch.ones_like(sample_losses).masked_fill(is_target, 3.0)
    weighted = (sample_losses * weights).mean()

    assert weighted > base_loss, "Weighted loss should exceed unweighted when weight>1"