| import torch | |
| from .base_attack import BaseGenerativeAttack | |
| from .generator.aim import AIMGenerator | |
class AIMAttack(BaseGenerativeAttack):
    """Generative attack whose adversarial generator is an ``AIMGenerator``.

    The generator is called with the natural input plus one additional input
    (named ``x_guid`` here -- presumably a guidance image; TODO confirm
    against callers) and its output is returned as the attack result.
    """

    def set_adv_gen(self):
        """Construct the ``AIMGenerator`` and move it to ``self.device``.

        The result is stored on ``self.adv_gen``, which ``attack`` calls.
        """
        self.adv_gen = AIMGenerator().to(self.device)

    def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
        """Run the generator on ``x_nat`` and the first extra input.

        Args:
            x_nat: Natural input, passed unchanged as the generator's first
                argument. NOTE(review): unlike ``x_guid`` it is not moved to
                ``self.device`` here -- presumably the caller or base class
                already did so; confirm.
            *extra_inputs: Additional inputs. Only the first one is used: it
                is moved to ``self.device`` and passed as the generator's
                second argument. Any further entries are ignored.

        Returns:
            The value returned by ``self.adv_gen(x_nat, x_guid)``.

        Raises:
            ValueError: If no extra input is supplied.
        """
        if not extra_inputs:
            # Fail with an actionable message instead of the opaque
            # IndexError that ``extra_inputs[0]`` would raise.
            raise ValueError(
                "AIMAttack.attack requires at least one extra input (x_guid)"
            )
        x_guid = extra_inputs[0].to(self.device)
        return self.adv_gen(x_nat, x_guid)