|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def project_perturbation(perturbation, eps, norm):
    """Project ``perturbation`` onto the ``norm``-ball of radius ``eps``.

    Args:
        perturbation: Tensor to project. In the L2 case dim 0 is treated as
            the batch dimension: every sub-tensor ``perturbation[i]`` is
            projected independently.
        eps: Radius of the ball.
        norm: ``'inf'``, ``'linf'``, ``'Linf'`` or ``float('inf')`` (also
            ``math.inf`` / ``numpy.inf``) for the L-infinity ball;
            ``2``, ``2.0``, ``'l2'``, ``'L2'`` or ``'2'`` for the L2 ball.

    Returns:
        A new tensor holding the projected perturbation.

    Raises:
        NotImplementedError: if ``norm`` is not one of the supported values.
    """
    if norm in ('inf', 'linf', 'Linf', float('inf')):
        # Elementwise clipping to [-eps, eps].
        return torch.clamp(perturbation, -eps, eps)

    if norm in (2, 2.0, 'l2', 'L2', '2'):
        # renorm along dim 0 rescales each perturbation[i] whose L2 norm
        # exceeds eps down to (approximately) eps; the others are unchanged.
        return torch.renorm(perturbation, p=2, dim=0, maxnorm=eps)

    raise NotImplementedError(f'Norm {norm} not supported')
|
|
|
|
|
|
|
|
def normalize_grad(grad, p):
    """Normalize ``grad`` according to the ``p``-norm.

    Args:
        grad: Gradient tensor; dim 0 is treated as the batch dimension.
        p: ``'inf'``, ``'linf'``, ``'Linf'`` or ``float('inf')`` (also
            ``math.inf`` / ``numpy.inf``) to return the elementwise sign of
            the gradient; ``2``, ``2.0``, ``'l2'``, ``'L2'`` or ``'2'`` to
            scale each flattened sample ``grad[i]`` to unit L2 norm.

    Returns:
        Tensor with the same shape as ``grad``.

    Raises:
        NotImplementedError: if ``p`` is not one of the supported values.
    """
    if p in ('inf', 'linf', 'Linf', float('inf')):
        return grad.sign()

    if p in (2, 2.0, 'l2', 'L2', '2'):
        bs = grad.shape[0]
        # reshape (not view) so non-contiguous gradients, e.g. produced by
        # transpose/permute, are accepted as well.
        grad_flat = grad.reshape(bs, -1)
        # F.normalize divides by max(norm, eps), so an all-zero sample stays
        # zero instead of producing NaNs.
        grad_normalized = F.normalize(grad_flat, p=2, dim=1)
        return grad_normalized.reshape_as(grad)

    # Previously fell through and returned None; raise like
    # project_perturbation does for unsupported norms.
    raise NotImplementedError(f'Norm {p} not supported')
|
|
|