xmutly's picture
Upload 294 files
e1aaaac verified
raw
history blame contribute delete
901 Bytes
import torch
import torch.nn.functional as F
# some parts of this code are adapted from
# https://github.com/M4xim4l/InNOutRobustness/blob/main/utils/adversarial_attacks/utils.py
def project_perturbation(perturbation, eps, norm):
    """Project a perturbation tensor back into the norm ball of radius ``eps``.

    Args:
        perturbation: Tensor holding the perturbation(s). For the L2 case
            every slice along dim 0 is projected independently
            (presumably dim 0 is the batch dimension -- confirm with callers).
        eps: Radius of the norm ball.
        norm: Norm identifier. ``'inf'``, ``'linf'``, ``'Linf'`` or
            ``float('inf')`` select the L-infinity ball; ``2``, ``2.0``,
            ``'l2'``, ``'L2'`` or ``'2'`` select the L2 ball.

    Returns:
        A tensor with the same shape as ``perturbation``.

    Raises:
        NotImplementedError: If ``norm`` is not one of the supported values.
    """
    if norm in ('inf', 'linf', 'Linf', float('inf')):
        # L-inf projection is an element-wise clamp to [-eps, eps].
        return torch.clamp(perturbation, -eps, eps)

    if norm in (2, 2.0, 'l2', 'L2', '2'):
        # renorm rescales each sub-tensor along dim 0 whose L2 norm exceeds
        # eps; sub-tensors already inside the ball are left untouched.
        return torch.renorm(perturbation, p=2, dim=0, maxnorm=eps)

    raise NotImplementedError(f'Norm {norm} not supported')
def normalize_grad(grad, p):
    """Normalize a gradient tensor for a steepest-ascent step under norm ``p``.

    Args:
        grad: Gradient tensor. For the L2 case each entry along dim 0 is
            flattened and normalized independently (presumably dim 0 is the
            batch dimension -- confirm with callers).
        p: Norm identifier. ``'inf'``, ``'linf'``, ``'Linf'`` or
            ``float('inf')`` return the element-wise sign of ``grad``;
            ``2``, ``2.0``, ``'l2'``, ``'L2'`` or ``'2'`` return ``grad``
            scaled to unit L2 norm per dim-0 entry.

    Returns:
        A tensor with the same shape as ``grad``.

    Raises:
        NotImplementedError: If ``p`` is not one of the supported values.
    """
    if p in ('inf', 'linf', 'Linf', float('inf')):
        return grad.sign()

    if p in (2, 2.0, 'l2', 'L2', '2'):
        bs = grad.shape[0]
        # reshape (not view) so non-contiguous gradients, e.g. permuted or
        # channels_last tensors, are handled instead of raising.
        grad_flat = grad.reshape(bs, -1)
        # F.normalize divides by max(norm, eps), so an all-zero row stays zero.
        grad_normalized = F.normalize(grad_flat, p=2, dim=1)
        return grad_normalized.view_as(grad)

    # Fail loudly instead of implicitly returning None for an unknown norm.
    raise NotImplementedError(f'Norm {p} not supported')