|
|
import pytest |
|
|
import torch |
|
|
|
|
|
from anti_kd_backdoor.network.trigger import Trigger |
|
|
|
|
|
|
|
|
@torch.no_grad()
@pytest.mark.parametrize('size', [32, 224])
def test_trigger_init(size: int) -> None:
    """Check that ``Trigger(size)`` records ``size`` and sizes its tensors.

    The ``mask`` attribute is expected to have shape ``(size, size)`` and
    the ``trigger`` pattern to have shape ``(3, size, size)``.
    """
    net = Trigger(size)
    expected_mask = [size, size]
    expected_pattern = [3, *expected_mask]

    assert net.size == size
    assert list(net.mask.shape) == expected_mask
    assert list(net.trigger.shape) == expected_pattern
|
|
|
|
|
|
|
|
@torch.no_grad()
@pytest.mark.parametrize('size', [32, 224])
def test_trigger_forward(size: int) -> None:
    """Check how ``Trigger`` applies its pattern to a batch of images.

    Verifies three things:

    * the output has the same shape as the input batch;
    * with an all-zero ``mask`` the output is exactly the input;
    * with an all-one ``mask`` every sample is exactly ``trigger.trigger``.
    """
    # Build the module before sampling inputs so RNG consumption order
    # matches the original test.
    net = Trigger(size)
    images = torch.rand(10, 3, size, size)

    # Shape must be preserved.
    assert net(images).shape == images.shape

    # Zero mask: the input must come back unchanged.
    net.mask.fill_(0)
    assert torch.equal(net(images), images)

    # Full mask: each sample in the batch must equal the trigger pattern.
    net.mask.fill_(1)
    stamped = net(images)
    assert all(torch.equal(sample, net.trigger) for sample in stamped)
|
|
|