File size: 390 Bytes
998bb30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import torch
from .base_attack import BaseGenerativeAttack
from .generator.aim import AIMGenerator
class AIMAttack(BaseGenerativeAttack):
    """Generative attack whose perturbation comes from an ``AIMGenerator``.

    The generator is called with the input and a guidance input, and its
    output is returned directly as the attack result.
    """

    def set_adv_gen(self):
        """Create an ``AIMGenerator`` (default args), move it to ``self.device``
        and store it as ``self.adv_gen``.

        NOTE(review): presumably invoked as a setup hook by
        ``BaseGenerativeAttack`` — confirm against the base class.
        """
        self.adv_gen = AIMGenerator().to(self.device)

    def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
        """Run the generator on ``x_nat`` together with a guidance input.

        Args:
            x_nat: Input passed unchanged as the generator's first argument.
                It is not moved to ``self.device`` here; assumes the caller or
                base class already placed it there — TODO confirm.
            *extra_inputs: Additional inputs. Only the first one is used, as
                the guidance input ``x_guid``; any further entries are ignored.

        Returns:
            The result of ``self.adv_gen(x_nat, x_guid)``.

        Raises:
            IndexError: If no guidance input is supplied.
        """
        if not extra_inputs:
            # Same exception type the bare ``extra_inputs[0]`` used to raise,
            # but with a message that says what is actually missing.
            raise IndexError(
                "AIMAttack.attack requires a guidance input (x_guid) as its "
                "first extra argument; none was given."
            )
        x_guid = extra_inputs[0].to(self.device)
        return self.adv_gen(x_nat, x_guid)
|