import torch

from .base_attack import BaseGenerativeAttack
from .generator.cda import CDAGenerator


class CDAAttack(BaseGenerativeAttack):
    """Generative attack whose generator network is a ``CDAGenerator``.

    The class only supplies the generator construction and the forward
    attack call; everything else is inherited from
    ``BaseGenerativeAttack``.
    """

    def set_adv_gen(self):
        """Create the generator and store it on ``self.adv_gen``.

        A ``CDAGenerator`` is built with its default arguments and moved to
        ``self.device`` (expected to be provided by the base class).
        """
        generator = CDAGenerator()
        self.adv_gen = generator.to(self.device)

    def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
        """Pass ``x_nat`` through the generator and return its output.

        Args:
            x_nat: Input forwarded unchanged to ``self.adv_gen``
                (presumably a batch of clean images -- confirm with caller).
            *extra_inputs: Accepted for signature compatibility; not used.

        Returns:
            Whatever ``self.adv_gen`` produces for ``x_nat``.
        """
        generated = self.adv_gen(x_nat)
        return generated