File size: 332 Bytes
998bb30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch
from .base_attack import BaseGenerativeAttack
from .generator.cda import CDAGenerator
class CDAAttack(BaseGenerativeAttack):
    """Generative attack whose adversarial inputs come from a ``CDAGenerator``.

    NOTE(review): the generator output is returned unchanged. This class
    applies no epsilon-ball projection or value clamping. Confirm whether
    ``BaseGenerativeAttack`` or ``CDAGenerator`` enforces a perturbation
    bound.
    """

    def set_adv_gen(self):
        """Build the adversarial generator and move it onto ``self.device``.

        ``self.device`` is not assigned in this class. Presumably it is
        provided by ``BaseGenerativeAttack``; verify.
        """
        generator = CDAGenerator()
        self.adv_gen = generator.to(self.device)

    def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
        """Feed the clean batch ``x_nat`` through ``self.adv_gen``.

        Args:
            x_nat: clean inputs, passed directly to the generator.
            *extra_inputs: accepted but unused. Presumably this keeps the
                signature compatible with the base-class ``attack``; confirm.

        Returns:
            The result of ``self.adv_gen(x_nat)``, annotated as a tensor.
        """
        generated = self.adv_gen(x_nat)
        return generated
|