import abc
from collections import OrderedDict
from pathlib import Path
from typing import Iterator, Union

import torch


class BaseGenerativeAttack(abc.ABC):
    """Base class for attacks whose perturbation comes from a generator network.

    Subclasses must implement ``set_adv_gen`` (which assigns ``self.adv_gen``,
    a ``torch.nn.Module``) and ``attack``. Calling the instance runs
    ``attack`` and then projects the result into the L-infinity ball of radius
    ``epsilon`` around the clean input, clamped to ``[0, 1]``. This assumes
    inputs are scaled to ``[0, 1]``.
    """

    def __init__(self, device: Union[str, torch.device], epsilon: float = 32 / 255) -> None:
        """Build the generator and put it in eval mode.

        Args:
            device: Device used by ``load_ckpt`` (``map_location`` and the
                final ``.to``). A string is converted to ``torch.device``.
            epsilon: L-infinity radius used by ``__call__`` for projection.
        """
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device
        # Set before building so that subclass ``set_adv_gen`` implementations
        # can read it.
        self.epsilon = epsilon
        self.set_adv_gen()
        self.set_mode('eval')

    @abc.abstractmethod
    def set_adv_gen(self):
        """Create the generator and assign it to ``self.adv_gen``."""
        pass

    def load_ckpt(self, ckpt: Union[str, Path, OrderedDict]) -> None:
        """Load generator weights, then move the generator to ``self.device``.

        Args:
            ckpt: A path (str or Path) to a file written by ``torch.save``, or
                an already-loaded state dict.

        Raises:
            FileNotFoundError: If a path is given and the file does not exist.
        """
        if isinstance(ckpt, str):
            ckpt = Path(ckpt)
        if isinstance(ckpt, Path):
            if not ckpt.exists():
                raise FileNotFoundError(f'File not found: {ckpt}')
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            ckpt = torch.load(ckpt, map_location=self.device)
        self.adv_gen.load_state_dict(ckpt)
        self.adv_gen.to(self.device)

    def save_ckpt(self, ckpt: Union[str, Path]) -> None:
        """Save the generator's state dict, with all tensors on the CPU, to ``ckpt``.

        The live generator is not moved. Calling ``Module.to('cpu')`` would
        move it in place and leave it on the CPU after saving.
        """
        if isinstance(ckpt, str):
            ckpt = Path(ckpt)
        state = self.adv_gen.state_dict()
        cpu_state = OrderedDict((name, tensor.cpu()) for name, tensor in state.items())
        # Keep the version metadata that load_state_dict consults.
        metadata = getattr(state, '_metadata', None)
        if metadata is not None:
            cpu_state._metadata = metadata
        torch.save(cpu_state, ckpt)

    def get_params(self) -> Iterator[torch.nn.Parameter]:
        """Return an iterator over the generator's parameters."""
        return self.adv_gen.parameters()

    def get_model(self) -> torch.nn.Module:
        """Return the underlying generator module."""
        return self.adv_gen

    def set_mode(self, mode: str) -> None:
        """Switch the generator to ``'train'`` or ``'eval'`` mode.

        Raises:
            ValueError: If ``mode`` is neither ``'train'`` nor ``'eval'``.
        """
        if mode not in ('train', 'eval'):
            raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")
        # Module.train(False) is equivalent to Module.eval().
        self.adv_gen.train(mode == 'train')

    @abc.abstractmethod
    def attack(self, *args) -> torch.Tensor:
        """Produce an unprojected adversarial example from the given inputs."""
        pass

    def __call__(self, x_nat: torch.Tensor, *extra_inputs) -> torch.Tensor:
        """Run ``attack`` and project the result.

        The result is confined to ``[x_nat - epsilon, x_nat + epsilon]`` and
        then clamped to ``[0, 1]``.
        """
        x_adv = self.attack(x_nat, *extra_inputs)
        x_adv = torch.min(torch.max(x_adv, x_nat - self.epsilon), x_nat + self.epsilon)
        return torch.clamp(x_adv, 0.0, 1.0)