import torch
import torch.nn.functional as F

# some parts of this code are adapted from
# https://github.com/M4xim4l/InNOutRobustness/blob/main/utils/adversarial_attacks/utils.py

# Accepted spellings for each supported norm, shared by both helpers so they
# always agree. float('inf') is accepted alongside the string aliases so that
# callers may pass the numeric value of the L-inf norm directly.
_LINF_NORMS = ('inf', 'linf', 'Linf', float('inf'))
_L2_NORMS = (2, 2.0, 'l2', 'L2', '2')


def project_perturbation(perturbation, eps, norm):
    """Project ``perturbation`` back into the ``eps``-ball of the given norm.

    Args:
        perturbation: Perturbation tensor. For the L2 case, each slice along
            dim 0 is projected independently (presumably dim 0 is the batch
            dimension -- confirm against callers).
        eps: Radius of the ball.
        norm: One of the L-inf aliases (``'inf'``, ``'linf'``, ``'Linf'``,
            ``float('inf')``) or L2 aliases (``2``, ``2.0``, ``'l2'``,
            ``'L2'``, ``'2'``).

    Returns:
        A new tensor. For L-inf, every element is clamped to ``[-eps, eps]``.
        For L2, each dim-0 slice whose L2 norm exceeds ``eps`` is rescaled by
        ``torch.renorm`` so its norm is at most ``eps``. Other slices are
        left unchanged.

    Raises:
        NotImplementedError: If ``norm`` is not a supported spelling.
    """
    if norm in _LINF_NORMS:
        return torch.clamp(perturbation, -eps, eps)
    if norm in _L2_NORMS:
        return torch.renorm(perturbation, p=2, dim=0, maxnorm=eps)
    raise NotImplementedError(f'Norm {norm} not supported')


def normalize_grad(grad, p):
    """Normalize ``grad`` into a steepest-ascent direction for the given norm.

    Args:
        grad: Gradient tensor whose first dimension indexes samples.
        p: Norm spelling, using the same aliases as ``project_perturbation``.

    Returns:
        For L-inf, ``grad.sign()``. For L2, ``grad`` flattened per sample
        (dim 0), scaled to unit L2 norm, and reshaped back to ``grad.shape``.
        ``F.normalize`` guards against division by zero for all-zero rows.

    Raises:
        NotImplementedError: If ``p`` is not a supported spelling. Previously
            this case silently returned ``None``.
    """
    if p in _LINF_NORMS:
        return grad.sign()
    if p in _L2_NORMS:
        if grad.numel() == 0:
            # reshape(0, -1) is ambiguous and raises; nothing to normalize.
            return grad.clone()
        bs = grad.shape[0]
        # reshape (not view) so non-contiguous gradients, e.g. from transposed
        # or channels-last inputs, are handled instead of raising.
        grad_flat = grad.reshape(bs, -1)
        grad_normalized = F.normalize(grad_flat, p=2, dim=1)
        return grad_normalized.reshape(grad.shape)
    raise NotImplementedError(f'Norm {p} not supported')