import torch

from .base_attack import BaseGenerativeAttack
from .generator.aim import AIMGenerator


class AIMAttack(BaseGenerativeAttack):
    """Generative attack whose adversarial generator is an ``AIMGenerator``.

    The generator is called with two tensors: the input ``x_nat`` and a
    guidance tensor taken from the first extra positional input of
    :meth:`attack`.
    """

    def set_adv_gen(self):
        """Create the ``AIMGenerator`` and move it to ``self.device``.

        The result is stored on ``self.adv_gen``.

        NOTE(review): ``self.device`` is presumably set by
        ``BaseGenerativeAttack`` before this is called; confirm the call order.
        """
        self.adv_gen = AIMGenerator().to(self.device)

    def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
        """Run the generator on ``x_nat`` with a guidance tensor.

        Args:
            x_nat: Input passed to the generator unchanged. Presumably this is
                the clean/natural sample. It is NOT moved to ``self.device``
                here; TODO confirm the caller or base class handles that.
            *extra_inputs: Extra positional inputs. Only the first is used, as
                the guidance tensor, and it is moved to ``self.device``.
                Any further entries are ignored.

        Returns:
            The output of ``self.adv_gen(x_nat, x_guid)``.

        Raises:
            IndexError: If no guidance tensor is supplied in ``extra_inputs``.
        """
        if not extra_inputs:
            # Same exception type that ``extra_inputs[0]`` raised before, so
            # existing handlers still work, but with a message that names
            # what is actually missing.
            raise IndexError(
                "AIMAttack.attack requires a guidance tensor as its first extra "
                "positional input, e.g. attack(x_nat, x_guid)"
            )
        x_guid = extra_inputs[0].to(self.device)
        return self.adv_gen(x_nat, x_guid)