| import torch | |
| from .base_attack import BaseGenerativeAttack | |
| from .generator.cda import CDAGenerator | |
class CDAAttack(BaseGenerativeAttack):
    """Generative attack whose adversarial generator is a ``CDAGenerator``.

    The generator is built by :meth:`set_adv_gen` and applied directly to the
    input batch by :meth:`attack`.
    """

    def set_adv_gen(self) -> None:
        """Instantiate a ``CDAGenerator`` and store it on ``self.device``.

        The result is assigned to ``self.adv_gen``. ``self.device`` is
        expected to be set elsewhere (presumably by the base class — confirm).
        """
        generator = CDAGenerator()
        self.adv_gen = generator.to(self.device)

    def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
        """Return the generator's output for the batch ``x_nat``.

        Args:
            x_nat: Input batch passed unchanged to ``self.adv_gen``
                (presumably clean/natural images — verify against callers).
            *extra_inputs: Accepted but not used by this attack; presumably
                kept so the signature matches other attacks — confirm.

        Returns:
            Whatever ``self.adv_gen(x_nat)`` produces, returned as-is.
        """
        # NOTE(review): no clamping or epsilon projection is applied here;
        # confirm that BaseGenerativeAttack or the caller bounds the output.
        adversarial = self.adv_gen(x_nat)
        return adversarial