File size: 4,354 Bytes
d38bce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
https://github.com/ferjad/Universal_Adversarial_Perturbation_pytorch
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>

"""
from deeprobust.image.attack import deepfool
import collections
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data_utils
import math
from PIL import Image
import torchvision.models as models
import sys
import random
import time
from tqdm import tqdm

def zero_gradients(x):
    """Reset accumulated gradients to zero, in place.

    Parameters
    ----------
    x :
        a tensor, or an iterable (possibly nested) of tensors. A tensor
        whose ``.grad`` is ``None`` is left alone; any object that is
        neither a tensor nor an iterable is ignored.
    """
    if isinstance(x, torch.Tensor):
        grad = x.grad
        if grad is None:
            return
        # Cut the gradient out of any autograd graph before clearing it.
        grad.detach_()
        grad.zero_()
        return

    if not isinstance(x, collections.abc.Iterable):
        return

    # Containers are handled recursively, one element at a time.
    for item in x:
        zero_gradients(item)

def get_model(model,device):
    """Build a pretrained torchvision classifier, in eval mode, on ``device``.

    Parameters
    ----------
    model :
        architecture name; ``'vgg16'`` or ``'resnet18'``
    device :
        device the network is moved to (passed to ``net.to``)

    Returns
    -------
        the network after ``eval()`` and ``to(device)``.

    Raises
    ------
    ValueError
        if ``model`` is not one of the supported names.
    """
    # NOTE: newer torchvision releases deprecate ``pretrained=True`` in favour
    # of ``weights=...``; the older keyword is kept so existing environments
    # keep working.
    if model == 'vgg16':
        net = models.vgg16(pretrained=True)
    elif model == 'resnet18':
        net = models.resnet18(pretrained=True)
    else:
        # Previously an unknown name fell through and failed later with an
        # UnboundLocalError on ``net``; fail early with a clear message.
        raise ValueError(
            "Unsupported model {!r}; expected 'vgg16' or 'resnet18'".format(model)
        )

    net.eval()
    return net.to(device)

def data_input_init(xi):
    """Return the normalisation statistics and the image preprocessing pipeline.

    Parameters
    ----------
    xi :
        accepted for interface compatibility; not used by this function.

    Returns
    -------
        ``(mean, std, transform)`` where ``transform`` resizes to 256,
        centre-crops to 224, converts to a tensor and normalises each
        channel with ``mean`` / ``std``.
    """
    # Per-channel statistics; the values match the commonly used ImageNet ones.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = transforms.Compose(steps)

    return (channel_mean, channel_std, pipeline)

def proj_lp(v, xi, p):
    """Project ``v`` onto the l_p ball of radius ``xi`` centred at the origin.

    Parameters
    ----------
    v :
        tensor to project
    xi :
        radius of the ball
    p :
        norm order; ``np.inf`` clamps every entry to ``[-xi, xi]``, any
        other value rescales ``v`` when its p-norm exceeds ``xi``.

    Returns
    -------
        the projected tensor.
    """
    if p == np.inf:
        # l_inf projection is an element-wise clamp.
        return torch.clamp(v, -xi, xi)

    # The small constant guards against division by zero for an all-zero v.
    shrink = xi / (torch.norm(v, p) + 0.00001)
    # Never scale up: vectors already inside the ball are left unchanged.
    return v * min(1, shrink)

def get_fooling_rate(data_list,v,model, device):
    """Measure how often adding ``v`` changes the model's top-1 prediction.

    Parameters
    ----------
    data_list :
        sized collection of image paths (anything ``PIL.Image.open`` accepts)
    v :
        perturbation added to each preprocessed image; it must broadcast
        against the ``(1, 3, 224, 224)`` tensor produced by the preprocessing
        pipeline of ``data_input_init``.
    model :
        target model
    device :
        device the images are moved to before the forward pass

    Returns
    -------
        ``(fooling_rate, model)``: the fraction of images whose prediction
        changed (``0.0`` for an empty ``data_list``) and the same model, whose
        parameters have had ``requires_grad`` set to ``False``.
    """
    # Fixed: the transform used to be bound to ``f`` while the loop called the
    # undefined name ``tf``.
    tf = data_input_init(0)[2]
    num_images = len(data_list)

    fooled = 0.0

    # Only the predicted labels are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for name in tqdm(data_list):
            # The context manager closes the file handle; converting to RGB
            # keeps the 3-channel normalisation valid for grayscale/RGBA files.
            with Image.open(name) as raw:
                image = tf(raw.convert('RGB'))
            image = image.unsqueeze(0).to(device)
            _, pred = torch.max(model(image), 1)
            _, adv_pred = torch.max(model(image + v), 1)
            if pred.item() != adv_pred.item():
                fooled += 1

    # Compute the fooling rate (guarding against an empty list)
    fooling_rate = fooled / num_images if num_images else 0.0
    print('Fooling Rate = ', fooling_rate)

    # Side effect kept from the original implementation: freeze the model.
    for param in model.parameters():
        param.requires_grad = False

    return fooling_rate,model

def _dataloader_fooling_rate(dataloader, v, model, device):
    """Return the fraction of samples whose top-1 prediction changes when ``v`` is added."""
    fooled = 0
    total = 0
    # Only predicted labels are compared, so no gradients are required.
    with torch.no_grad():
        for img, _ in dataloader:
            img = img.to(device)
            _, pred = torch.max(model(img), 1)
            _, adv_pred = torch.max(model(img + v), 1)
            fooled += int((pred != adv_pred).sum().item())
            total += pred.numel()
    # An empty dataloader yields a rate of 0.0 instead of a division by zero.
    return fooled / total if total else 0.0


def universal_adversarial_perturbation(dataloader, model, device, xi=10, delta=0.2, max_iter_uni = 10, p=np.inf,
                                       num_classes=10, overshoot=0.02, max_iter_df=10,t_p = 0.2):
    """universal_adversarial_perturbation.

    Parameters
    ----------
    dataloader :
        iterable yielding ``(img, label)`` pairs; ``img`` must be addable to
        a ``(1, 3, 224, 224)`` tensor.
        NOTE(review): the per-batch check ``pred == adv_pred`` is used as a
        boolean, which assumes one sample per batch - confirm the batch size.
    model :
        target model
    device :
        device
    xi :
        controls the l_p magnitude of the perturbation
    delta :
        controls the desired fooling rate (default = 80% fooling rate)
    max_iter_uni :
        maximum number of passes over the dataloader (default = 10)
    p :
        norm to be used (default = np.inf)
    num_classes :
        num_classes (default = 10)
    overshoot :
        to prevent vanishing updates (default = 0.02)
    max_iter_df :
        maximum number of iterations for deepfool (default = 10)
    t_p :
        truth percentage, for how many flipped labels in a batch. (default = 0.2)
        Currently unused; kept for interface compatibility.

    Returns
    -------
        the universal perturbation matrix.
    """
    v = torch.zeros(1,3,224,224).to(device)
    v.requires_grad_()

    fooling_rate = 0.0
    itr = 0

    while fooling_rate < 1-delta and itr < max_iter_uni:

        # Iterate over the dataset and compute the perturbation incrementally

        for i,(img, _) in enumerate(dataloader):
            # Keep the batch on the same device as the model and ``v``.
            img = img.to(device)

            # Label comparison only - no autograd graph needed here.
            with torch.no_grad():
                _, pred = torch.max(model(img),1)
                _, adv_pred = torch.max(model(img+v),1)

            if(pred == adv_pred):
                # The current perturbation does not fool this sample yet:
                # ask DeepFool for the extra step that would.
                # NOTE(review): confirm that the imported ``deepfool`` object
                # is callable as ``deepfool(model, device)`` and exposes
                # ``generate`` / ``getpurb`` with these keyword names.
                perturb = deepfool(model, device)
                # Fixed: the keyword/value used to be the undefined ``num_classed``.
                _ = perturb.generate(img+v, num_classes = num_classes, overshoot = overshoot, max_iter = max_iter_df)
                dr, n_iter = perturb.getpurb()
                # Only accept the step when DeepFool finished before its
                # iteration budget ran out.
                if(n_iter<max_iter_df-1):
                    v = v + torch.from_numpy(dr).to(device)
                    v = proj_lp(v,xi,p)

            # Fixed: the progress check used the undefined name ``k``.
            if(i%10==0):
                print('Norm of v: '+str(torch.norm(v).detach().cpu().numpy()))

        # Fixed: the original passed an undefined ``data_list`` to
        # ``get_fooling_rate``; the rate is now measured on the dataloader.
        fooling_rate = _dataloader_fooling_rate(dataloader, v, model, device)
        print('Fooling Rate = ', fooling_rate)
        itr = itr + 1

    return v