File size: 332 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

from .base_attack import BaseGenerativeAttack
from .generator.cda import CDAGenerator


class CDAAttack(BaseGenerativeAttack):
    """Generative attack whose adversarial generator is a ``CDAGenerator``.

    The adversarial example is whatever the generator produces for the
    natural input.
    """

    def set_adv_gen(self):
        """Instantiate the CDA generator and move it to ``self.device``.

        The result is stored on ``self.adv_gen``.
        """
        generator = CDAGenerator()
        self.adv_gen = generator.to(self.device)

    def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
        """Run the generator on ``x_nat`` and return its output.

        Args:
            x_nat: The natural (clean) input passed straight to the
                generator. NOTE(review): expected shape/range is set by
                ``CDAGenerator``; confirm there.
            *extra_inputs: Accepted but unused here. Presumably this keeps
                the signature compatible with the base class; verify.

        Returns:
            The tensor produced by ``self.adv_gen(x_nat)``.
        """
        adversarial = self.adv_gen(x_nat)
        return adversarial