import torch
import torch.nn as nn
from .base import Attack, LabelMixin
from .utils import ctx_noparamgrad_and_eval
from .utils import batch_multiply
from .utils import clamp, normalize_by_pnorm
from utils.distributed import DistributedMetric
from tqdm import tqdm
from torchpack import distributed as dist
from utils import accuracy
from typing import Dict
class FGSMAttack(Attack, LabelMixin):
    """
    One-step fast gradient sign method (Goodfellow et al., 2014).

    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        eps (float): attack step size.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): indicate if this is a targeted attack.
    """

    def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False):
        super(FGSMAttack, self).__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        self.targeted = targeted
        if self.loss_fn is None:
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")

    def perturb(self, x, y=None):
        """
        Given examples (x, y), returns their adversarial counterparts,
        perturbed by at most eps in the L-inf norm.

        Arguments:
            x (torch.Tensor): input tensor.
            y (torch.Tensor): label tensor.
                - if None and self.targeted=False, compute y as predicted labels.
                - if self.targeted=True, then y must be the targeted labels.

        Returns:
            torch.Tensor containing the perturbed inputs.
            torch.Tensor containing the perturbation.
        """
        x, y = self._verify_and_process_inputs(x, y)
        xadv = x.requires_grad_()
        outputs = self.predict(xadv)

        loss = self.loss_fn(outputs, y)
        if self.targeted:
            loss = -loss
        loss.backward()
        grad_sign = xadv.grad.detach().sign()

        # single step of size eps along the sign of the input gradient,
        # then clip back into the valid input range
        xadv = xadv + batch_multiply(self.eps, grad_sign)
        xadv = clamp(xadv, self.clip_min, self.clip_max)
        radv = xadv - x
        return xadv.detach(), radv.detach()
LinfFastGradientAttack = FGSMAttack
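
# Illustrative usage sketch (not part of the original module): drives the L-inf
# one-step attack with a throwaway classifier and random data. All names inside
# this helper (the toy model, inputs, and eps value) are hypothetical.
def _example_fgsm_usage():
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # toy classifier
    images = torch.rand(4, 3, 32, 32)  # inputs already scaled to [0, 1]
    labels = torch.randint(0, 10, (4,))
    attack = LinfFastGradientAttack(model, eps=8 / 255)
    # labels may be omitted for an untargeted attack; predicted labels are used instead
    x_adv, delta = attack.perturb(images, labels)
    # the perturbation stays within the eps ball and inputs stay in [clip_min, clip_max]
    assert delta.abs().max() <= 8 / 255 + 1e-6
    return x_adv
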
class FGMAttack(Attack, LabelMixin):
    """
    One-step fast gradient method. Perturbs the input with the (L2-normalized)
    gradient of the loss w.r.t. the input, rather than its sign.

    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        eps (float): attack step size.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): indicate if this is a targeted attack.
    """

    def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False):
        super(FGMAttack, self).__init__(
            predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        self.targeted = targeted
        if self.loss_fn is None:
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")

    def perturb(self, x, y=None):
        """
        Given examples (x, y), returns their adversarial counterparts,
        perturbed by at most eps in the L2 norm.

        Arguments:
            x (torch.Tensor): input tensor.
            y (torch.Tensor): label tensor.
                - if None and self.targeted=False, compute y as predicted labels.
                - if self.targeted=True, then y must be the targeted labels.

        Returns:
            torch.Tensor containing the perturbed inputs.
            torch.Tensor containing the perturbation.
        """
        x, y = self._verify_and_process_inputs(x, y)
        xadv = x.requires_grad_()
        outputs = self.predict(xadv)

        loss = self.loss_fn(outputs, y)
        if self.targeted:
            loss = -loss
        loss.backward()
        # normalize the raw gradient to unit L2 norm per sample
        grad = normalize_by_pnorm(xadv.grad)

        xadv = xadv + batch_multiply(self.eps, grad)
        xadv = clamp(xadv, self.clip_min, self.clip_max)
        radv = xadv - x
        return xadv.detach(), radv.detach()
    def eval_fgsm(self, data_loader_dict: Dict) -> Dict:
        """
        Evaluate clean and adversarial accuracy of self.predict on the
        validation split of data_loader_dict, using this attack's perturb().
        """
        test_criterion = nn.CrossEntropyLoss().cuda()

        val_loss = DistributedMetric()
        val_top1 = DistributedMetric()
        val_top5 = DistributedMetric()
        val_advloss = DistributedMetric()
        val_advtop1 = DistributedMetric()
        val_advtop5 = DistributedMetric()

        self.predict.eval()
        with tqdm(
            total=len(data_loader_dict["val"]),
            desc="Eval",
            disable=not dist.is_master(),
        ) as t:
            for images, labels in data_loader_dict["val"]:
                images, labels = images.cuda(), labels.cuda()
                # clean accuracy and loss
                output = self.predict(images)
                loss = test_criterion(output, labels)
                val_loss.update(loss, images.shape[0])
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                val_top5.update(acc5[0], images.shape[0])
                val_top1.update(acc1[0], images.shape[0])
                # adversarial accuracy and loss under the one-step attack
                with ctx_noparamgrad_and_eval(self.predict):
                    images_adv, _ = self.perturb(images, labels)
                output_adv = self.predict(images_adv)
                loss_adv = test_criterion(output_adv, labels)
                val_advloss.update(loss_adv, images.shape[0])
                acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
                val_advtop1.update(acc1_adv[0], images.shape[0])
                val_advtop5.update(acc5_adv[0], images.shape[0])

                t.set_postfix(
                    {
                        "loss": val_loss.avg.item(),
                        "top1": val_top1.avg.item(),
                        "top5": val_top5.avg.item(),
                        "adv_loss": val_advloss.avg.item(),
                        "adv_top1": val_advtop1.avg.item(),
                        "adv_top5": val_advtop5.avg.item(),
                        "#samples": val_top1.count.item(),
                        "batch_size": images.shape[0],
                        "img_size": images.shape[2],
                    }
                )
                t.update()

        val_results = {
            "val_top1": val_top1.avg.item(),
            "val_top5": val_top5.avg.item(),
            "val_loss": val_loss.avg.item(),
            "val_advtop1": val_advtop1.avg.item(),
            "val_advtop5": val_advtop5.avg.item(),
            "val_advloss": val_advloss.avg.item(),
        }
        return val_results
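
# Illustrative usage sketch (not part of the original module): the L2 variant
# scales the normalized gradient, so the per-sample perturbation norm is at
# most eps. The toy model and data below are hypothetical.
def _example_fgm_usage():
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # toy classifier
    images = torch.rand(4, 3, 32, 32)
    labels = torch.randint(0, 10, (4,))
    attack = FGMAttack(model, eps=0.5)
    x_adv, delta = attack.perturb(images, labels)
    # per-sample L2 norm of the perturbation never exceeds eps (clipping can only shrink it)
    assert delta.flatten(1).norm(dim=1).max() <= 0.5 + 1e-5
    return x_adv
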
L2FastGradientAttack = FGMAttack
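
# Illustrative sketch of driving eval_fgsm (hypothetical, not part of the original
# module): assumes a CUDA device, an initialized torchpack distributed context,
# and an ImageNet-style validation DataLoader. All names below are placeholders.
def _example_eval_fgsm(model, val_loader):
    attack = L2FastGradientAttack(model, eps=0.5)
    # eval_fgsm expects a dict holding the validation loader under the "val" key
    results = attack.eval_fgsm({"val": val_loader})
    # the returned dict holds clean and adversarial loss / top-1 / top-5 metrics
    return results["val_top1"], results["val_advtop1"]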