File size: 1,935 Bytes
998bb30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import abc
from collections import OrderedDict
from pathlib import Path
from typing import Iterator, Union

import torch
class BaseGenerativeAttack(abc.ABC):
    """Base class for attacks that craft adversarial examples with a generator.

    Subclasses implement :meth:`set_adv_gen`, which must assign the generator
    network to ``self.adv_gen``, and :meth:`attack`, which produces the raw
    adversarial tensor. Calling the instance runs :meth:`attack` and projects
    the result into the L-infinity ball of radius ``epsilon`` around the
    natural input, then into the ``[0, 1]`` range.
    """

    def __init__(self,
                 device: Union[str, torch.device],
                 epsilon: float = 32 / 255) -> None:
        """Create the attack and build its generator in eval mode.

        Args:
            device: Device (or device string such as ``'cpu'``) used as the
                ``map_location`` and target device when loading checkpoints.
            epsilon: L-infinity perturbation budget enforced in ``__call__``.
        """
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device
        # Set epsilon before building the generator so that a subclass's
        # set_adv_gen() can safely read it.
        self.epsilon = epsilon
        self.set_adv_gen()
        self.set_mode('eval')

    @abc.abstractmethod
    def set_adv_gen(self) -> None:
        """Build the generator and assign it to ``self.adv_gen``.

        Called once from ``__init__`` after ``self.device`` and
        ``self.epsilon`` are set. The other methods of this class call
        ``torch.nn.Module`` methods (``state_dict``, ``load_state_dict``,
        ``parameters``, ``train``, ``to``) on ``self.adv_gen``.
        """

    def load_ckpt(self, ckpt: Union[str, Path, OrderedDict]) -> None:
        """Load generator weights from a file or from an in-memory state dict.

        Args:
            ckpt: Path (``str`` or ``Path``) to a file written by
                ``torch.save``, or an already-loaded state dict.

        Raises:
            FileNotFoundError: If ``ckpt`` is a path that does not exist.
        """
        if isinstance(ckpt, str):
            ckpt = Path(ckpt)
        if isinstance(ckpt, Path):
            if not ckpt.exists():
                raise FileNotFoundError(f'File not found: {ckpt}')
            # NOTE: depending on the installed torch version's weights_only
            # default, torch.load may unpickle arbitrary objects; only load
            # checkpoints from trusted sources.
            ckpt = torch.load(ckpt, map_location=self.device)
        self.adv_gen.load_state_dict(ckpt)
        self.adv_gen.to(self.device)

    def save_ckpt(self, ckpt: Union[str, Path]) -> None:
        """Save the generator's state dict with every tensor on the CPU.

        Only copies of the tensors are moved to the CPU; the generator itself
        stays on its current device, so the attack remains usable afterwards.
        (``nn.Module.to`` works in place, so moving the module would have
        silently relocated it.)

        Args:
            ckpt: Destination file path.
        """
        if isinstance(ckpt, str):
            ckpt = Path(ckpt)
        state = self.adv_gen.state_dict()
        cpu_state = OrderedDict(
            (name, value.cpu() if isinstance(value, torch.Tensor) else value)
            for name, value in state.items())
        # Preserve the per-module version metadata that load_state_dict
        # consults for backward-compatible loading.
        metadata = getattr(state, '_metadata', None)
        if metadata is not None:
            cpu_state._metadata = metadata
        torch.save(cpu_state, ckpt)

    def get_params(self) -> Iterator[torch.nn.Parameter]:
        """Return an iterator over the generator's parameters."""
        return self.adv_gen.parameters()

    def get_model(self) -> torch.nn.Module:
        """Return the generator module itself (not a copy)."""
        return self.adv_gen

    def set_mode(self, mode: str) -> None:
        """Put the generator into ``'train'`` or ``'eval'`` mode.

        Raises:
            ValueError: If ``mode`` is neither ``'train'`` nor ``'eval'``.
        """
        # Explicit check instead of assert: asserts are stripped under -O.
        if mode not in ('train', 'eval'):
            raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")
        # Module.train(False) is equivalent to Module.eval().
        self.adv_gen.train(mode == 'train')

    @abc.abstractmethod
    def attack(self, *args) -> torch.Tensor:
        """Produce the raw (unclipped) adversarial tensor.

        ``__call__`` passes the natural input first, followed by any extra
        inputs it received; clipping to the epsilon ball and to ``[0, 1]``
        happens in ``__call__``, not here.
        """

    def __call__(self, x_nat: torch.Tensor, *extra_inputs) -> torch.Tensor:
        """Run the attack and project the result onto the allowed region.

        Args:
            x_nat: Natural (clean) input. Presumably in ``[0, 1]`` given the
                final clamp — confirm against callers.
            *extra_inputs: Forwarded unchanged to :meth:`attack`.

        Returns:
            ``attack(x_nat, *extra_inputs)`` clipped element-wise to
            ``[x_nat - epsilon, x_nat + epsilon]`` and then to ``[0, 1]``.
        """
        x_adv = self.attack(x_nat, *extra_inputs)
        x_adv = torch.min(torch.max(x_adv, x_nat - self.epsilon),
                          x_nat + self.epsilon)
        # Out-of-place clamp avoids mutating a tensor that may be part of an
        # autograd graph (e.g. while training the generator).
        return torch.clamp(x_adv, 0.0, 1.0)
|