import warnings from collections import OrderedDict from distutils.version import LooseVersion from functools import partial from inspect import isclass from typing import Callable, Optional, Dict, Union import numpy as np import torch import tqdm from torch import Tensor, nn from torch.nn import functional as F from adv_lib.distances.lp_norms import l0_distances, l1_distances, l2_distances, linf_distances from adv_lib.utils import ForwardCounter, BackwardCounter, predict_inputs def generate_random_targets(labels: Tensor, num_classes: int) -> Tensor: """ Generates one random target in (num_classes - 1) possibilities for each label that is different from the original label. Parameters ---------- labels: Tensor Original labels. Generated targets will be different from labels. num_classes: int Number of classes to generate the random targets from. Returns ------- targets: Tensor Random target for each label. Has the same shape as labels. """ random = torch.rand(len(labels), num_classes, device=labels.device, dtype=torch.float) random.scatter_(1, labels.unsqueeze(-1), 0) return random.argmax(1) def get_all_targets(labels: Tensor, num_classes: int): """ Generates all possible targets that are different from the original labels. Parameters ---------- labels: Tensor Original labels. Generated targets will be different from labels. num_classes: int Number of classes to generate the random targets from. Returns ------- targets: Tensor Random targets for each label. shape: (len(labels), num_classes - 1). """ all_possible_targets = torch.zeros(len(labels), num_classes - 1, dtype=torch.long) all_classes = set(range(num_classes)) for i in range(len(labels)): this_label = labels[i].item() other_labels = list(all_classes.difference({this_label})) all_possible_targets[i] = torch.tensor(other_labels) return all_possible_targets def run_attack(model: nn.Module, inputs: Tensor, labels: Tensor, attack: Callable, targets: Optional[Tensor] = None, batch_size: Optional[int] = None) -> dict: device = next(model.parameters()).device to_device = lambda tensor: tensor.to(device) targeted, adv_labels = False, labels if targets is not None: targeted, adv_labels = True, targets batch_size = batch_size or len(inputs) # run attack only on non already adversarial samples already_adv = [] chunks = [tensor.split(batch_size) for tensor in [inputs, adv_labels]] for (inputs_chunk, label_chunk) in zip(*chunks): batch_chunk_d, label_chunk_d = [to_device(tensor) for tensor in [inputs_chunk, label_chunk]] preds = model(batch_chunk_d).argmax(1) is_adv = (preds == label_chunk_d) if targeted else (preds != label_chunk_d) already_adv.append(is_adv.cpu()) not_adv = ~torch.cat(already_adv, 0) start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) forward_counter, backward_counter = ForwardCounter(), BackwardCounter() model.register_forward_pre_hook(forward_counter) if LooseVersion(torch.__version__) >= LooseVersion('1.8'): model.register_full_backward_hook(backward_counter) else: model.register_backward_hook(backward_counter) average_forwards, average_backwards = [], [] # number of forward and backward calls per sample advs_chunks = [] chunks = [tensor.split(batch_size) for tensor in [inputs[not_adv], adv_labels[not_adv]]] total_time = 0 for (inputs_chunk, label_chunk) in tqdm.tqdm(zip(*chunks), ncols=80, total=len(chunks[0])): batch_chunk_d, label_chunk_d = [to_device(tensor.clone()) for tensor in [inputs_chunk, label_chunk]] start.record() advs_chunk_d = attack(model, batch_chunk_d, label_chunk_d, targeted=targeted) # performance monitoring end.record() torch.cuda.synchronize() total_time += (start.elapsed_time(end)) / 1000 # times for cuda Events are in milliseconds average_forwards.append(forward_counter.num_samples_called / len(batch_chunk_d)) average_backwards.append(backward_counter.num_samples_called / len(batch_chunk_d)) forward_counter.reset(), backward_counter.reset() advs_chunks.append(advs_chunk_d.cpu()) if isinstance(attack, partial) and (callback := attack.keywords.get('callback')) is not None: callback.reset_windows() adv_inputs = inputs.clone() adv_inputs[not_adv] = torch.cat(advs_chunks, 0) data = { 'inputs': inputs, 'labels': labels, 'targets': adv_labels if targeted else None, 'adv_inputs': adv_inputs, 'time': total_time, 'num_forwards': sum(average_forwards) / len(chunks[0]), 'num_backwards': sum(average_backwards) / len(chunks[0]), } return data _default_metrics = OrderedDict([ ('linf', linf_distances), ('l0', l0_distances), ('l1', l1_distances), ('l2', l2_distances), ]) def compute_attack_metrics(model: nn.Module, attack_data: Dict[str, Union[Tensor, float]], batch_size: Optional[int] = None, metrics: Dict[str, Callable] = _default_metrics) -> Dict[str, Union[Tensor, float]]: inputs, labels, targets, adv_inputs = map(attack_data.get, ['inputs', 'labels', 'targets', 'adv_inputs']) if adv_inputs.min() < 0 or adv_inputs.max() > 1: warnings.warn('Values of produced adversarials are not in the [0, 1] range -> Clipping to [0, 1].') adv_inputs.clamp_(min=0, max=1) device = next(model.parameters()).device to_device = lambda tensor: tensor.to(device) batch_size = batch_size or len(inputs) chunks = [tensor.split(batch_size) for tensor in [inputs, labels, adv_inputs]] all_predictions = [[] for _ in range(6)] distances = {k: [] for k in metrics.keys()} metrics = {k: v().to(device) if (isclass(v.func) if isinstance(v, partial) else False) else v for k, v in metrics.items()} append = lambda list, data: list.append(data.cpu()) for inputs_chunk, labels_chunk, adv_chunk in zip(*chunks): inputs_chunk, adv_chunk = map(to_device, [inputs_chunk, adv_chunk]) clean_preds, adv_preds = [predict_inputs(model, chunk.to(device)) for chunk in [inputs_chunk, adv_chunk]] list(map(append, all_predictions, [*clean_preds, *adv_preds])) for metric, metric_func in metrics.items(): distances[metric].append(metric_func(adv_chunk, inputs_chunk).detach().cpu()) logits, probs, preds, logits_adv, probs_adv, preds_adv = [torch.cat(l) for l in all_predictions] for metric in metrics.keys(): distances[metric] = torch.cat(distances[metric], 0) accuracy_orig = (preds == labels).float().mean().item() if targets is not None: success = (preds_adv == targets) labels = targets else: success = (preds_adv != labels) prob_orig = probs.gather(1, labels.unsqueeze(1)).squeeze(1) prob_adv = probs_adv.gather(1, labels.unsqueeze(1)).squeeze(1) labels_infhot = torch.zeros_like(logits_adv).scatter_(1, labels.unsqueeze(1), float('inf')) real = logits_adv.gather(1, labels.unsqueeze(1)).squeeze(1) other = (logits_adv - labels_infhot).max(1).values diff_vs_max_adv = (real - other) nll = F.cross_entropy(logits, labels, reduction='none') nll_adv = F.cross_entropy(logits_adv, labels, reduction='none') data = { 'time': attack_data['time'], 'num_forwards': attack_data['num_forwards'], 'num_backwards': attack_data['num_backwards'], 'targeted': targets is not None, 'preds': preds, 'adv_preds': preds_adv, 'accuracy_orig': accuracy_orig, 'success': success, 'probs_orig': prob_orig, 'probs_adv': prob_adv, 'logit_diff_adv': diff_vs_max_adv, 'nll': nll, 'nll_adv': nll_adv, 'distances': distances, } return data def print_metrics(metrics: dict) -> None: np.set_printoptions(formatter={'float': '{:0.3f}'.format}, threshold=16, edgeitems=3, linewidth=120) # To print arrays with less precision print('Original accuracy: {:.2%}'.format(metrics['accuracy_orig'])) print('Attack done in: {:.2f}s with {:.4g} forwards and {:.4g} backwards.'.format( metrics['time'], metrics['num_forwards'], metrics['num_backwards'])) success = metrics['success'].numpy() fail = bool(success.mean() != 1) print('Attack success: {:.2%}'.format(success.mean()) + fail * ' - {}'.format(success)) for distance, values in metrics['distances'].items(): data = values.numpy() print('{}: {} - Average: {:.3f} - Median: {:.3f}'.format(distance, data, data.mean(), np.median(data)) + fail * ' | Avg over success: {:.3f}'.format(data[success].mean())) attack_type = 'targets' if metrics['targeted'] else 'correct' print('Logit({} class) - max_Logit(other classes): {} - Average: {:.2f}'.format( attack_type, metrics['logit_diff_adv'].numpy(), metrics['logit_diff_adv'].numpy().mean())) print('NLL of target/pred class: {:.3f}'.format(metrics['nll_adv'].numpy().mean()))