Ttius's picture
Upload 192 files
998bb30 verified
import torch
from .base_attack import BaseGenerativeAttack
from .generator.aim import AIMGenerator
class AIMAttack(BaseGenerativeAttack):
    """Generative attack driven by an ``AIMGenerator``.

    The generator takes two inputs: the input passed to :meth:`attack` as
    ``x_nat`` and a guidance input that callers supply as the first extra
    positional argument.
    """

    def set_adv_gen(self):
        """Create the ``AIMGenerator`` and place it on ``self.device``.

        The generator is built with its default constructor arguments and
        stored on ``self.adv_gen``.

        NOTE(review): presumably invoked by ``BaseGenerativeAttack`` during
        setup; confirm against the base class.
        """
        self.adv_gen = AIMGenerator().to(self.device)

    def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
        """Run the generator on ``x_nat`` guided by ``extra_inputs[0]``.

        Args:
            x_nat: Input passed straight to the generator. It is NOT moved
                here. Presumably the caller has already placed it on
                ``self.device``; verify against the caller.
            *extra_inputs: Additional inputs. Only the first one (the
                guidance input) is used; it is moved to ``self.device``
                before the generator call. Any further items are ignored.

        Returns:
            The result of ``self.adv_gen(x_nat, x_guid)``, annotated as a
            ``torch.Tensor``.

        Raises:
            IndexError: If no guidance input is supplied.
        """
        if not extra_inputs:
            # Keep the exception type the bare ``extra_inputs[0]`` lookup
            # raised, but say what is actually missing.
            raise IndexError(
                "AIMAttack.attack requires a guidance input as the first "
                "extra positional argument"
            )
        x_guid = extra_inputs[0].to(self.device)
        return self.adv_gen(x_nat, x_guid)