import abc
from collections import OrderedDict
from pathlib import Path
from typing import Iterator, Union

import torch


class BaseGenerativeAttack(abc.ABC):
    """Base class for generative adversarial attacks.

    Subclasses must implement ``set_adv_gen``, which builds the generator
    and assigns it to ``self.adv_gen``, and ``attack``, which produces the
    raw adversarial output.
    """

    def __init__(self,
                 device: Union[str, torch.device],
                 epsilon: float = 32 / 255) -> None:
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device
        # Set plain attributes before invoking the subclass hook so that
        # set_adv_gen() can rely on them.
        self.epsilon = epsilon
        self.set_adv_gen()
        self.set_mode('eval')

    @abc.abstractmethod
    def set_adv_gen(self) -> None:
        """Create the adversarial generator and assign it to ``self.adv_gen``."""

    def load_ckpt(self, ckpt: Union[str, Path, OrderedDict]) -> None:
        """Load generator weights from a file path or an in-memory state dict."""
        if isinstance(ckpt, str):
            ckpt = Path(ckpt)
        if isinstance(ckpt, Path):
            if not ckpt.exists():
                raise FileNotFoundError(f'File not found: {ckpt}')
            ckpt = torch.load(ckpt, map_location=self.device)
        self.adv_gen.load_state_dict(ckpt)
        self.adv_gen.to(self.device)

    def save_ckpt(self, ckpt: Union[str, Path]) -> None:
        """Save the generator's state dict to ``ckpt``."""
        if isinstance(ckpt, str):
            ckpt = Path(ckpt)
        # nn.Module.to() moves the model in place, so move it back to the
        # original device after saving the CPU state dict.
        torch.save(self.adv_gen.to('cpu').state_dict(), ckpt)
        self.adv_gen.to(self.device)

    def get_params(self) -> Iterator[torch.nn.Parameter]:
        """Return an iterator over the generator's trainable parameters."""
        return self.adv_gen.parameters()

    def get_model(self) -> torch.nn.Module:
        """Return the underlying generator module."""
        return self.adv_gen

    def set_mode(self, mode: str) -> None:
        """Put the generator in 'train' or 'eval' mode."""
        if mode not in ('train', 'eval'):
            raise ValueError(f'invalid mode: {mode!r}')
        self.adv_gen.train(mode == 'train')

    @abc.abstractmethod
    def attack(self, *args) -> torch.Tensor:
        """Produce raw (unprojected) adversarial examples."""

    def __call__(self, x_nat: torch.Tensor, *extra_inputs) -> torch.Tensor:
        x_adv = self.attack(x_nat, *extra_inputs)
        # Project the perturbation onto the L-infinity ball of radius
        # epsilon around the natural input, then clamp in place to the
        # valid image range [0, 1].
        x_adv = torch.min(torch.max(x_adv, x_nat - self.epsilon),
                          x_nat + self.epsilon)
        x_adv.clamp_(0.0, 1.0)
        return x_adv
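

# --- Minimal usage sketch (illustrative only, not part of the library) ---
# A hypothetical subclass showing how the two abstract hooks fit together:
# set_adv_gen() builds the generator, attack() produces the raw adversarial
# output, and __call__ projects it into the epsilon ball. The single-layer
# ConvNet below is an assumption chosen purely for demonstration.
class _ToyConvAttack(BaseGenerativeAttack):

    def set_adv_gen(self) -> None:
        # Toy generator: one 3x3 convolution over RGB images.
        self.adv_gen = torch.nn.Conv2d(3, 3, kernel_size=3,
                                       padding=1).to(self.device)

    def attack(self, x_nat: torch.Tensor) -> torch.Tensor:
        return self.adv_gen(x_nat.to(self.device))


if __name__ == '__main__':
    attack = _ToyConvAttack(device='cpu', epsilon=8 / 255)
    x = torch.rand(1, 3, 32, 32)   # natural images in [0, 1]
    x_adv = attack(x)              # perturbed, projected, and clamped
    assert (x_adv - x).abs().max() <= 8 / 255 + 1e-6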