Spaces:
Sleeping
Sleeping
File size: 9,453 Bytes
eeef81e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 | import warnings
from collections import OrderedDict
from distutils.version import LooseVersion
from functools import partial
from inspect import isclass
from typing import Callable, Optional, Dict, Union
import numpy as np
import torch
import tqdm
from torch import Tensor, nn
from torch.nn import functional as F
from adv_lib.distances.lp_norms import l0_distances, l1_distances, l2_distances, linf_distances
from adv_lib.utils import ForwardCounter, BackwardCounter, predict_inputs
def generate_random_targets(labels: Tensor, num_classes: int) -> Tensor:
    """
    Draws, for every label, one random target class that differs from that label.

    A random score is sampled for every (sample, class) pair, the score of the original class is set to zero and the
    class with the highest remaining score is selected as the target.

    Parameters
    ----------
    labels : Tensor
        Original labels. The generated targets are different from these labels.
    num_classes : int
        Number of classes the random targets are drawn from.

    Returns
    -------
    targets : Tensor
        One random target per label. Has the same shape as labels.

    """
    num_samples = len(labels)
    scores = torch.rand(num_samples, num_classes, dtype=torch.float, device=labels.device)

    # Zero the score of the original class so that it cannot win the argmax below.
    original_class_column = labels.unsqueeze(-1)
    scores.scatter_(1, original_class_column, 0)

    return scores.argmax(dim=1)
def get_all_targets(labels: Tensor, num_classes: int) -> Tensor:
    """
    Generates all possible targets that are different from the original labels.

    Parameters
    ----------
    labels : Tensor
        Original labels, one class index per sample. Generated targets will be different from labels.
    num_classes : int
        Total number of classes; valid class indices are 0, ..., num_classes - 1.

    Returns
    -------
    targets : Tensor
        All targets for each label, in increasing class order. Long tensor located on the CPU with
        shape: (len(labels), num_classes - 1).

    Raises
    ------
    ValueError
        If labels does not hold exactly one class index per sample, or if a label is outside [0, num_classes).

    """
    # A single device-to-host transfer instead of one synchronisation per sample.
    flat_labels = labels.detach().cpu().reshape(-1)
    if flat_labels.numel() != len(labels):
        raise ValueError('labels must hold exactly one class index per sample.')
    if bool(((flat_labels < 0) | (flat_labels >= num_classes)).any()):
        raise ValueError('labels must be in the range [0, num_classes).')

    # Row i is 0, ..., num_classes - 2 where every entry >= labels[i] is shifted up by one: this enumerates all
    # classes except labels[i], in increasing order.
    candidates = torch.arange(num_classes - 1, dtype=torch.long).unsqueeze(0)
    skip_original = (candidates >= flat_labels.unsqueeze(1)).long()
    return candidates + skip_original
def run_attack(model: nn.Module,
               inputs: Tensor,
               labels: Tensor,
               attack: Callable,
               targets: Optional[Tensor] = None,
               batch_size: Optional[int] = None) -> dict:
    """
    Runs an attack on the samples that are not already adversarial and records its cost.

    Samples for which the model prediction already satisfies the adversarial criterion (prediction different from
    the label for untargeted attacks, equal to the target for targeted attacks) are left untouched. The attack is
    run on the remaining samples, batch by batch, while the elapsed time and the number of forward and backward
    calls are measured.

    Parameters
    ----------
    model : nn.Module
        Model to attack. The device of its first parameter is used for all computations.
    inputs : Tensor
        Samples to attack.
    labels : Tensor
        Labels of the samples.
    attack : Callable
        Attack, called as `attack(model, inputs, labels, targeted=targeted)` for each batch. It must return the
        adversarial examples of the batch as a Tensor. If it is a `functools.partial` with a `callback` keyword,
        `callback.reset_windows()` is called after each batch.
    targets : Tensor, optional
        Targets of the attack. If provided, the attack is targeted and is given the targets instead of the labels.
    batch_size : int, optional
        Number of samples per batch. All the samples are processed in one batch if not provided.

    Returns
    -------
    data : dict
        Dictionary with the keys 'inputs', 'labels', 'targets' (None for untargeted attacks), 'adv_inputs' (copy
        of inputs in which the attacked samples are replaced by the attack outputs), 'time' (seconds spent in the
        attack), 'num_forwards' and 'num_backwards' (per-sample call counts averaged over the batches).

    """
    import time  # function-scope import: only needed for the timing fallback on non-CUDA devices

    device = next(model.parameters()).device
    targeted = targets is not None
    adv_labels = targets if targeted else labels
    batch_size = batch_size or len(inputs)

    # run attack only on non already adversarial samples
    already_adv = []
    with torch.no_grad():  # only predictions are needed here, no autograd graph
        for inputs_chunk, label_chunk in zip(inputs.split(batch_size), adv_labels.split(batch_size)):
            label_chunk_d = label_chunk.to(device)
            preds = model(inputs_chunk.to(device)).argmax(1)
            is_adv = (preds == label_chunk_d) if targeted else (preds != label_chunk_d)
            already_adv.append(is_adv.cpu())
    not_adv = ~torch.cat(already_adv, 0)

    # CUDA events can only time work on a CUDA device; a wall clock is used otherwise.
    use_cuda_events = device.type == 'cuda'
    if use_cuda_events:
        start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)

    # Selecting no sample still splits into a single empty chunk: drop it so that the attack is never run on an
    # empty batch (the per-sample averages below divide by the batch length).
    chunks = [[chunk for chunk in tensor.split(batch_size) if len(chunk) > 0]
              for tensor in [inputs[not_adv], adv_labels[not_adv]]]

    forward_counter, backward_counter = ForwardCounter(), BackwardCounter()
    average_forwards, average_backwards = [], []  # number of forward and backward calls per sample
    advs_chunks = []
    total_time = 0
    hook_handles = []
    try:
        hook_handles.append(model.register_forward_pre_hook(forward_counter))
        # register_full_backward_hook is only available on recent torch versions (1.8 and later).
        if hasattr(model, 'register_full_backward_hook'):
            hook_handles.append(model.register_full_backward_hook(backward_counter))
        else:
            hook_handles.append(model.register_backward_hook(backward_counter))

        for inputs_chunk, label_chunk in tqdm.tqdm(zip(*chunks), ncols=80, total=len(chunks[0])):
            batch_chunk_d, label_chunk_d = [tensor.clone().to(device) for tensor in [inputs_chunk, label_chunk]]

            if use_cuda_events:
                start.record()
                advs_chunk_d = attack(model, batch_chunk_d, label_chunk_d, targeted=targeted)
                end.record()
                torch.cuda.synchronize()
                total_time += start.elapsed_time(end) / 1000  # times for cuda Events are in milliseconds
            else:
                tic = time.perf_counter()
                advs_chunk_d = attack(model, batch_chunk_d, label_chunk_d, targeted=targeted)
                total_time += time.perf_counter() - tic

            # performance monitoring
            average_forwards.append(forward_counter.num_samples_called / len(batch_chunk_d))
            average_backwards.append(backward_counter.num_samples_called / len(batch_chunk_d))
            forward_counter.reset()
            backward_counter.reset()

            advs_chunks.append(advs_chunk_d.cpu())

            if isinstance(attack, partial) and (callback := attack.keywords.get('callback')) is not None:
                callback.reset_windows()
    finally:
        # Always detach the counting hooks so that they do not pile up on the model across calls.
        for handle in hook_handles:
            handle.remove()

    adv_inputs = inputs.clone()
    if advs_chunks:
        adv_inputs[not_adv] = torch.cat(advs_chunks, 0)

    num_batches = len(average_forwards)
    data = {
        'inputs': inputs,
        'labels': labels,
        'targets': adv_labels if targeted else None,
        'adv_inputs': adv_inputs,
        'time': total_time,
        'num_forwards': sum(average_forwards) / num_batches if num_batches else 0,
        'num_backwards': sum(average_backwards) / num_batches if num_batches else 0,
    }
    return data
# Distance functions evaluated by default, keyed by the name under which each distance is reported.
# The insertion order is the order in which the distances are computed and printed.
_default_metrics = OrderedDict(
    linf=linf_distances,
    l0=l0_distances,
    l1=l1_distances,
    l2=l2_distances,
)
def compute_attack_metrics(model: nn.Module,
                           attack_data: Dict[str, Union[Tensor, float]],
                           batch_size: Optional[int] = None,
                           metrics: Dict[str, Callable] = _default_metrics) -> Dict[str, Union[Tensor, float]]:
    """
    Computes the success, confidence and distance metrics of an attack from the data collected while running it.

    Parameters
    ----------
    model : nn.Module
        Model that was attacked. The device of its first parameter is used for all computations.
    attack_data : dict
        Dictionary with the keys 'inputs', 'labels', 'targets' (None for untargeted attacks), 'adv_inputs', 'time',
        'num_forwards' and 'num_backwards'. The tensors it holds are not modified.
    batch_size : int, optional
        Number of samples per batch. All the samples are processed in one batch if not provided.
    metrics : dict
        Distances to compute, keyed by name. Each value is called as `metric(adv_inputs, inputs)` and must return a
        Tensor. A value that is a `functools.partial` wrapping a class is first instantiated and moved to the
        device of the model.

    Returns
    -------
    data : dict
        Dictionary with the keys 'time', 'num_forwards', 'num_backwards', 'targeted', 'preds', 'adv_preds',
        'accuracy_orig', 'success', 'probs_orig', 'probs_adv', 'logit_diff_adv', 'nll', 'nll_adv' and 'distances'
        (dictionary of per-sample distances keyed by metric name). For targeted attacks, the probabilities, logit
        differences and negative log-likelihoods are taken with respect to the targets instead of the labels.

    """
    inputs, labels, targets, adv_inputs = map(attack_data.get, ['inputs', 'labels', 'targets', 'adv_inputs'])
    if adv_inputs.min() < 0 or adv_inputs.max() > 1:
        warnings.warn('Values of produced adversarials are not in the [0, 1] range -> Clipping to [0, 1].')
        # Out-of-place clamp: the tensor stored in attack_data belongs to the caller and is left untouched.
        adv_inputs = adv_inputs.clamp(min=0, max=1)

    # All predictions are gathered on the CPU below, so the labels they are compared with must be there too.
    labels = labels.cpu()
    if targets is not None:
        targets = targets.cpu()

    device = next(model.parameters()).device
    batch_size = batch_size or len(inputs)

    # Metrics given as partial(SomeClass, ...) are instantiated and moved to the device; others are used as is.
    metric_funcs = {name: metric().to(device) if isinstance(metric, partial) and isclass(metric.func) else metric
                    for name, metric in metrics.items()}
    distances = {name: [] for name in metric_funcs}

    # Six accumulators, filled in the order in which they are unpacked after the loop:
    # clean logits, probabilities and predictions, followed by the same three for the adversarial inputs.
    all_predictions = [[] for _ in range(6)]

    for inputs_chunk, adv_chunk in zip(inputs.split(batch_size), adv_inputs.split(batch_size)):
        inputs_chunk, adv_chunk = inputs_chunk.to(device), adv_chunk.to(device)
        clean_preds = predict_inputs(model, inputs_chunk)
        adv_preds = predict_inputs(model, adv_chunk)
        for accumulator, output in zip(all_predictions, [*clean_preds, *adv_preds]):
            accumulator.append(output.cpu())
        for name, metric_func in metric_funcs.items():
            distances[name].append(metric_func(adv_chunk, inputs_chunk).detach().cpu())

    logits, probs, preds, logits_adv, probs_adv, preds_adv = [torch.cat(outputs) for outputs in all_predictions]
    distances = {name: torch.cat(values, 0) for name, values in distances.items()}

    # The clean accuracy is always measured against the true labels, before they are replaced by the targets.
    accuracy_orig = (preds == labels).float().mean().item()
    if targets is not None:
        success = (preds_adv == targets)
        labels = targets
    else:
        success = (preds_adv != labels)

    label_column = labels.unsqueeze(1)
    prob_orig = probs.gather(1, label_column).squeeze(1)
    prob_adv = probs_adv.gather(1, label_column).squeeze(1)

    # Subtracting +inf at the label position removes that class from the max over the other classes.
    labels_infhot = torch.zeros_like(logits_adv).scatter_(1, label_column, float('inf'))
    real = logits_adv.gather(1, label_column).squeeze(1)
    other = (logits_adv - labels_infhot).max(1).values
    diff_vs_max_adv = (real - other)

    nll = F.cross_entropy(logits, labels, reduction='none')
    nll_adv = F.cross_entropy(logits_adv, labels, reduction='none')

    data = {
        'time': attack_data['time'],
        'num_forwards': attack_data['num_forwards'],
        'num_backwards': attack_data['num_backwards'],
        'targeted': targets is not None,
        'preds': preds,
        'adv_preds': preds_adv,
        'accuracy_orig': accuracy_orig,
        'success': success,
        'probs_orig': prob_orig,
        'probs_adv': prob_adv,
        'logit_diff_adv': diff_vs_max_adv,
        'nll': nll,
        'nll_adv': nll_adv,
        'distances': distances,
    }
    return data
def print_metrics(metrics: dict) -> None:
    """
    Prints a summary of the attack metrics to the standard output.

    Parameters
    ----------
    metrics : dict
        Dictionary with the keys 'accuracy_orig', 'time', 'num_forwards', 'num_backwards', 'success', 'distances',
        'targeted', 'logit_diff_adv' and 'nll_adv'. The entries converted with `.numpy()` ('success',
        'logit_diff_adv', 'nll_adv' and the values of 'distances') must be CPU tensors.

    """
    # To print arrays with less precision. The options are scoped to this function so that the global numpy
    # print settings of the caller are left as they were.
    with np.printoptions(formatter={'float': '{:0.3f}'.format}, threshold=16, edgeitems=3, linewidth=120):
        print('Original accuracy: {:.2%}'.format(metrics['accuracy_orig']))
        print('Attack done in: {:.2f}s with {:.4g} forwards and {:.4g} backwards.'.format(
            metrics['time'], metrics['num_forwards'], metrics['num_backwards']))

        success = metrics['success'].numpy()
        success_rate = success.mean()
        fail = bool(success_rate != 1)  # per-sample details are only printed when some attacks failed
        print('Attack success: {:.2%}'.format(success_rate) + fail * ' - {}'.format(success))

        for distance, values in metrics['distances'].items():
            data = values.numpy()
            line = '{}: {} - Average: {:.3f} - Median: {:.3f}'.format(distance, data, data.mean(), np.median(data))
            if fail:
                # Averaging an empty selection emits a numpy warning: report nan directly when nothing succeeded.
                avg_success = data[success].mean() if success.any() else float('nan')
                line += ' | Avg over success: {:.3f}'.format(avg_success)
            print(line)

        attack_type = 'targets' if metrics['targeted'] else 'correct'
        logit_diff = metrics['logit_diff_adv'].numpy()
        print('Logit({} class) - max_Logit(other classes): {} - Average: {:.2f}'.format(
            attack_type, logit_diff, logit_diff.mean()))
        print('NLL of target/pred class: {:.3f}'.format(metrics['nll_adv'].numpy().mean()))
|