# ProArd / attacks/fgsm.py
# Uploaded by smi08 via huggingface_hub ("Upload folder using huggingface_hub"), commit 7771996 (verified).
import torch
import torch.nn as nn
from .base import Attack, LabelMixin
from .utils import ctx_noparamgrad_and_eval
from .utils import batch_multiply
from .utils import clamp ,normalize_by_pnorm
from utils.distributed import DistributedMetric
from tqdm import tqdm
from torchpack import distributed as dist
from utils import accuracy
from typing import Dict
class FGSMAttack(Attack, LabelMixin):
    """
    Single-step Fast Gradient Sign Method (Goodfellow et al., 2014).

    The input is shifted by ``eps`` along the sign of the gradient of the loss
    with respect to the input, then clipped to ``[clip_min, clip_max]``.

    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function; when None, a summed
            ``nn.CrossEntropyLoss`` is used.
        eps (float): attack step size.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): whether this is a targeted attack.
    """

    def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False):
        super().__init__(predict, loss_fn, clip_min, clip_max)
        self.targeted = targeted
        self.eps = eps
        if self.loss_fn is None:
            # Default objective: cross-entropy summed over the batch.
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")

    def perturb(self, x, y=None):
        """
        Build adversarial counterparts of ``(x, y)`` with a single sign step of size eps.

        Arguments:
            x (torch.Tensor): input tensor.
            y (torch.Tensor): label tensor.
                - if None and self.targeted=False, compute y as predicted labels.
                - if self.targeted=True, then y must be the targeted labels.
        Returns:
            A pair ``(x_adv, delta)`` of detached tensors: the clipped
            adversarial inputs and the perturbation ``x_adv - x`` actually
            applied (after clipping).

        Note: ``backward()`` also accumulates into ``.grad`` of any model
        parameter that requires grad; wrap calls in a no-param-grad context
        (e.g. ``ctx_noparamgrad_and_eval``) to avoid that.
        """
        # NOTE(review): requires_grad_ is in-place; this assumes
        # _verify_and_process_inputs hands back a copy of the caller's x
        # (advertorch's replicate_input does) — confirm in .base.
        x, y = self._verify_and_process_inputs(x, y)
        x.requires_grad_()
        objective = self.loss_fn(self.predict(x), y)
        # Targeted: differentiate the negated loss so the sign step lowers the
        # loss w.r.t. the target labels instead of raising it.
        (-objective if self.targeted else objective).backward()
        step = batch_multiply(self.eps, x.grad.detach().sign())
        adv = clamp(x + step, self.clip_min, self.clip_max)
        return adv.detach(), (adv - x).detach()


# Public alias: FGSM is the L-infinity fast gradient attack.
LinfFastGradientAttack = FGSMAttack
class FGMAttack(Attack, LabelMixin):
    """
    One step fast gradient method. Perturbs the input with the normalized
    gradient (not the gradient sign) of the loss w.r.t. the input.

    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function; when None, a summed
            ``nn.CrossEntropyLoss`` is used.
        eps (float): attack step size.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): indicate if this is a targeted attack.
    """

    def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False):
        super().__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        self.targeted = targeted
        if self.loss_fn is None:
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")

    def perturb(self, x, y=None):
        """
        Given examples (x, y), returns their adversarial counterparts with an attack length of eps.

        Arguments:
            x (torch.Tensor): input tensor.
            y (torch.Tensor): label tensor.
                - if None and self.targeted=False, compute y as predicted labels.
                - if self.targeted=True, then y must be the targeted labels.
        Returns:
            torch.Tensor containing perturbed inputs (detached, clipped).
            torch.Tensor containing the perturbation actually applied
            (``x_adv - x`` after clipping, detached).
        """
        x, y = self._verify_and_process_inputs(x, y)
        xadv = x.requires_grad_()
        outputs = self.predict(xadv)
        loss = self.loss_fn(outputs, y)
        if self.targeted:
            # Negate so the step below lowers the loss w.r.t. the target labels.
            loss = -loss
        loss.backward()
        # Normalize the input gradient before scaling by eps. NOTE(review):
        # the norm used is normalize_by_pnorm's default (presumably p=2, per the
        # L2FastGradientAttack alias) — confirm in .utils.
        grad = normalize_by_pnorm(xadv.grad)
        xadv = xadv + batch_multiply(self.eps, grad)
        xadv = clamp(xadv, self.clip_min, self.clip_max)
        radv = xadv - x
        return xadv.detach(), radv.detach()

    def eval_fgsm(self, data_loader_dict: Dict) -> Dict:
        """
        Evaluate ``self.predict`` on clean and adversarial validation batches.

        For every ``(images, labels)`` batch of ``data_loader_dict["val"]``,
        this does two things:

        - It computes the clean cross-entropy loss and top-1/top-5 accuracy.
        - It builds adversarial inputs with :meth:`perturb` and computes the
          same metrics on them.

        Tensors are moved to CUDA, so a GPU is required. The method calls
        ``self.predict.eval()`` and leaves the model in eval mode.

        Arguments:
            data_loader_dict (Dict): mapping containing a ``"val"`` iterable of
                ``(images, labels)`` batches.
        Returns:
            Dict with the clean averages ``val_top1``, ``val_top5``,
            ``val_loss`` and the adversarial averages ``val_advtop1``,
            ``val_advtop5``, ``val_advloss`` (Python numbers via ``.item()``).
        """
        val_loader = data_loader_dict["val"]
        test_criterion = nn.CrossEntropyLoss().cuda()
        val_loss = DistributedMetric()
        val_top1 = DistributedMetric()
        val_top5 = DistributedMetric()
        val_advloss = DistributedMetric()
        val_advtop1 = DistributedMetric()
        val_advtop5 = DistributedMetric()
        self.predict.eval()
        with tqdm(
            total=len(val_loader),
            desc="Eval",
            disable=not dist.is_master(),
        ) as t:
            for images, labels in val_loader:
                images, labels = images.cuda(), labels.cuda()
                batch_size = images.shape[0]

                # Clean pass. Metrics need no autograd graph. Without no_grad
                # the graph (and its stored activations) would stay referenced
                # by `output`/`loss` through the attack below, inflating peak
                # memory for nothing.
                with torch.no_grad():
                    output = self.predict(images)
                    loss = test_criterion(output, labels)
                    acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                val_loss.update(loss, batch_size)
                val_top5.update(acc5[0], batch_size)
                val_top1.update(acc1[0], batch_size)

                # The attack needs the input gradient, so perturb() must run
                # with grad enabled. ctx_noparamgrad_and_eval (from .utils) is
                # expected to keep parameter gradients off meanwhile.
                with ctx_noparamgrad_and_eval(self.predict):
                    images_adv, _ = self.perturb(images, labels)
                    with torch.no_grad():
                        output_adv = self.predict(images_adv)
                        loss_adv = test_criterion(output_adv, labels)
                        acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
                val_advloss.update(loss_adv, batch_size)
                val_advtop1.update(acc1_adv[0], batch_size)
                val_advtop5.update(acc5_adv[0], batch_size)

                t.set_postfix(
                    {
                        "loss": val_loss.avg.item(),
                        "top1": val_top1.avg.item(),
                        "top5": val_top5.avg.item(),
                        "adv_loss": val_advloss.avg.item(),
                        "adv_top1": val_advtop1.avg.item(),
                        "adv_top5": val_advtop5.avg.item(),
                        "#samples": val_top1.count.item(),
                        "batch_size": batch_size,
                        "img_size": images.shape[2],
                    }
                )
                t.update()
        val_results = {
            "val_top1": val_top1.avg.item(),
            "val_top5": val_top5.avg.item(),
            "val_loss": val_loss.avg.item(),
            "val_advtop1": val_advtop1.avg.item(),
            "val_advtop5": val_advtop5.avg.item(),
            "val_advloss": val_advloss.avg.item(),
        }
        return val_results


# Public alias: FGM with a normalized gradient step is the L2 fast gradient attack.
L2FastGradientAttack = FGMAttack