File size: 4,354 Bytes
d38bce3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
"""
Universal adversarial perturbation attack for PyTorch image classifiers.

Reference: https://github.com/ferjad/Universal_Adversarial_Perturbation_pytorch
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
"""
from deeprobust.image.attack import deepfool
import collections
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data_utils
import math
from PIL import Image
import torchvision.models as models
import sys
import random
import time
from tqdm import tqdm
def zero_gradients(x):
    """Detach and zero the ``.grad`` of a tensor or of every tensor in an iterable.

    Parameters
    ----------
    x :
        a ``torch.Tensor``, or a (possibly nested) iterable of tensors.
        Tensors without a gradient, strings, bytes and non-iterable
        objects are left untouched.
    """
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            # Detach first so the in-place zeroing is not recorded by autograd.
            x.grad.detach_()
            x.grad.zero_()
    elif isinstance(x, (str, bytes)):
        # Text carries no tensors, and a one-character string iterates to
        # itself, so recursing into it (e.g. into the keys of a dict of
        # tensors) would never terminate.
        return
    elif isinstance(x, collections.abc.Iterable):
        for elem in x:
            zero_gradients(elem)
def get_model(model, device):
    """Load a pretrained torchvision classifier in eval mode on ``device``.

    Parameters
    ----------
    model :
        architecture name; ``'vgg16'`` and ``'resnet18'`` are supported
    device :
        device the network is moved to

    Returns
    -------
    the pretrained network, switched to eval mode and moved to ``device``.

    Raises
    ------
    ValueError
        if ``model`` is not one of the supported architecture names.
    """
    # NOTE(review): newer torchvision releases prefer the ``weights=`` argument
    # over ``pretrained=True`` - confirm against the pinned torchvision version.
    if model == 'vgg16':
        net = models.vgg16(pretrained=True)
    elif model == 'resnet18':
        net = models.resnet18(pretrained=True)
    else:
        # Fail with a clear message instead of an unbound-local error below.
        raise ValueError(
            f"Unsupported model {model!r}; expected 'vgg16' or 'resnet18'"
        )
    net.eval()
    return net.to(device)
def data_input_init(xi):
    """Build the image preprocessing pipeline and its normalisation constants.

    Parameters
    ----------
    xi :
        accepted for call compatibility; not used by this function

    Returns
    -------
    a ``(mean, std, transform)`` tuple: the per-channel mean and std lists
    (presumably the usual ImageNet statistics) and a transform that resizes
    to 256, centre-crops to 224, converts to a tensor and normalises with
    those constants.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    transform = transforms.Compose(steps)
    return (mean, std, transform)
def proj_lp(v, xi, p):
    """Project ``v`` onto the l_p ball of radius ``xi`` centred at the origin.

    Parameters
    ----------
    v :
        perturbation tensor
    xi :
        radius of the ball
    p :
        norm order; ``np.inf`` clamps every entry to ``[-xi, xi]``, any
        other value rescales ``v`` when its p-norm exceeds ``xi``

    Returns
    -------
    the projected tensor.
    """
    if p == np.inf:
        return torch.clamp(v, -xi, xi)

    # The small constant keeps the ratio finite for an all-zero perturbation.
    ratio = xi / (torch.norm(v, p) + 0.00001)
    # Only ever shrink: a vector already inside the ball is left unchanged.
    return v * min(1, ratio)
def get_fooling_rate(data_list, v, model, device):
    """Measure how often adding ``v`` changes the model's top-1 prediction.

    Parameters
    ----------
    data_list :
        sequence of image files, each passed to ``PIL.Image.open``
    v :
        perturbation added to every preprocessed image
    model :
        classifier under attack
    device :
        device each image is moved to before inference

    Returns
    -------
    a ``(fooling_rate, model)`` tuple: the fraction of images whose predicted
    label changes (``0.0`` for an empty ``data_list``) and the same model
    with ``requires_grad`` disabled on all of its parameters.
    """
    tf = data_input_init(0)[2]
    num_images = len(data_list)
    fooled = 0.0

    # Only predictions are compared here, so no autograd graph is needed.
    with torch.no_grad():
        for name in tqdm(data_list):
            # Close the file promptly; force three channels so grayscale or
            # RGBA inputs still match the 3-channel normalisation.
            with Image.open(name) as raw:
                image = tf(raw.convert('RGB'))
            image = image.unsqueeze(0)
            image = image.to(device)
            _, pred = torch.max(model(image), 1)
            _, adv_pred = torch.max(model(image + v), 1)
            if pred.item() != adv_pred.item():
                fooled += 1

    # Compute the fooling rate (guarding against an empty image list).
    fooling_rate = fooled / num_images if num_images else 0.0
    print('Fooling Rate = ', fooling_rate)
    for param in model.parameters():
        param.requires_grad = False
    return fooling_rate, model
def _loader_fooling_rate(dataloader, v, model, device):
    """Return the fraction of samples in ``dataloader`` whose top-1 label flips under ``v``."""
    fooled = 0
    total = 0
    with torch.no_grad():
        for img, _ in dataloader:
            img = img.to(device)
            _, pred = torch.max(model(img), 1)
            _, adv_pred = torch.max(model(img + v), 1)
            fooled += int((pred != adv_pred).sum())
            total += pred.numel()
    return fooled / total if total else 0.0


def universal_adversarial_perturbation(dataloader, model, device, xi=10, delta=0.2, max_iter_uni=10, p=np.inf,
                                       num_classes=10, overshoot=0.02, max_iter_df=10, t_p=0.2, data_list=None):
    """universal_adversarial_perturbation.

    Builds a single perturbation ``v`` by sweeping over the data: whenever a
    sample is not yet fooled by ``v``, DeepFool is run on the perturbed sample
    and its result is added to ``v``, which is then projected back onto the
    l_p ball of radius ``xi``.

    Parameters
    ----------
    dataloader :
        iterable of ``(image, label)`` pairs; the per-sample comparison of
        predictions requires one image per batch (batch size 1)
    model :
        target model
    device :
        device the perturbation and every batch are placed on
    xi :
        controls the l_p magnitude of the perturbation (default = 10)
    delta :
        controls the desired fooling rate; the search stops once the fooling
        rate reaches ``1 - delta`` (default = 0.2, i.e. 80% fooling rate)
    max_iter_uni :
        maximum number of passes over ``dataloader`` (default = 10)
    p :
        norm to be used (default = np.inf)
    num_classes :
        num_classes forwarded to DeepFool (default = 10)
    overshoot :
        to prevent vanishing updates (default = 0.02)
    max_iter_df :
        maximum number of iterations for deepfool (default = 10)
    t_p :
        truth percentage; accepted for backward compatibility and not used
        by this implementation (default = 0.2)
    data_list :
        optional sequence of image files used to measure the fooling rate
        after every pass via ``get_fooling_rate``; when omitted the fooling
        rate is measured on ``dataloader`` itself (default = None)

    Returns
    -------
    the universal perturbation, a tensor of shape (1, 3, 224, 224) on
    ``device`` that does not track gradients.
    """
    v = torch.zeros(1, 3, 224, 224).to(device)
    fooling_rate = 0.0
    itr = 0

    while fooling_rate < 1 - delta and itr < max_iter_uni:
        # Iterate over the dataset and compute the perturbation incrementally
        for i, (img, _) in enumerate(dataloader):
            img = img.to(device)
            # Predictions only decide whether the sample is already fooled.
            with torch.no_grad():
                _, pred = torch.max(model(img), 1)
                _, adv_pred = torch.max(model(img + v), 1)

            if pred == adv_pred:
                # NOTE(review): the DeepFool constructor, the keyword names
                # accepted by ``generate`` and the ``getpurb`` accessor are
                # used as written in this file - confirm them against the
                # installed deeprobust release.
                perturb = deepfool(model, device)
                _ = perturb.generate(img + v, num_classes=num_classes, overshoot=overshoot, max_iter=max_iter_df)
                dr, df_iters = perturb.getpurb()
                # Only accept the step when DeepFool converged before its cap.
                if df_iters < max_iter_df - 1:
                    # Match v's dtype so a float64 array cannot promote v.
                    v = v + torch.from_numpy(dr).to(device=device, dtype=v.dtype)
                    v = proj_lp(v, xi, p)

            if i % 10 == 0:
                print('Norm of v: ' + str(torch.norm(v).detach().cpu().numpy()))

        if data_list is not None:
            fooling_rate, model = get_fooling_rate(data_list, v, model, device)
        else:
            fooling_rate = _loader_fooling_rate(dataloader, v, model, device)
        itr = itr + 1

    return v
|