Ttius's picture
Upload 192 files
998bb30 verified
import torch
from .base_attack import BaseGenerativeAttack
from .generator.cda import CDAGenerator
class CDAAttack(BaseGenerativeAttack):
    """Generative attack whose adversarial generator is a ``CDAGenerator``.

    Adversarial examples are produced by a single call to the generator;
    this class itself performs no per-sample optimisation.
    """

    def set_adv_gen(self):
        """Construct the adversarial generator and move it to ``self.device``.

        The generator is built with ``CDAGenerator``'s default arguments and
        stored on ``self.adv_gen``.

        NOTE(review): ``self.device`` is not assigned in this class; presumably
        ``BaseGenerativeAttack`` sets it before calling this hook — confirm.
        """
        generator = CDAGenerator()
        self.adv_gen = generator.to(self.device)

    def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
        """Return the generator's output for the natural input ``x_nat``.

        Args:
            x_nat: Natural (clean) input passed straight to ``self.adv_gen``.
            *extra_inputs: Accepted but unused by this attack (presumably kept
                so the signature matches the base-class interface).

        Returns:
            Whatever ``self.adv_gen(x_nat)`` produces (annotated as a tensor).
        """
        adversarial = self.adv_gen(x_nat)
        return adversarial