File size: 390 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch

from .base_attack import BaseGenerativeAttack
from .generator.aim import AIMGenerator


class AIMAttack(BaseGenerativeAttack):
    """Generative attack whose adversarial generator is an ``AIMGenerator``.

    The generator is called with two inputs: the natural input being attacked
    and a guidance input, which callers pass as the first extra positional
    argument to :meth:`attack`.
    """

    def set_adv_gen(self):
        """Create the ``AIMGenerator`` and place it on ``self.device``."""
        # NOTE(review): ``self.device`` is presumably provided by
        # BaseGenerativeAttack — confirm against the base class.
        self.adv_gen = AIMGenerator().to(self.device)

    def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
        """Run the generator on ``x_nat`` conditioned on a guidance input.

        Args:
            x_nat: Natural (clean) input tensor. It is moved to
                ``self.device`` before being fed to the generator.
            *extra_inputs: Extra inputs. The first entry is required and is
                used as the guidance input. Any further entries are ignored.
                NOTE(review): the expected shape/dtype of the guidance input
                is not visible here — confirm against ``AIMGenerator``.

        Returns:
            The output of ``self.adv_gen(x_nat, x_guid)``.

        Raises:
            ValueError: If no guidance input is supplied.
        """
        if not extra_inputs:
            raise ValueError(
                "AIMAttack.attack requires a guidance input as the first "
                "extra positional argument."
            )
        # Put both generator inputs on the generator's device, regardless of
        # where the caller's tensors live; ``.to`` is a no-op when the tensor
        # is already there.
        x_nat = x_nat.to(self.device)
        x_guid = extra_inputs[0].to(self.device)
        return self.adv_gen(x_nat, x_guid)