File size: 6,200 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from attacks.attack_config import SustainableAttack
from tqdm import tqdm
import logging
import foolbox as fb
from foolbox import PyTorchModel, accuracy
import numpy as np
from utils import factory
from utils.data_manager import  get_dataloader
import os
from utils.plot import plot_asr_per_target, save_grad_cam
import pandas as pd


class BASEAttack(SustainableAttack):
    """Evaluate a standard foolbox attack across a sequence of task checkpoints.

    The attack is selected by ``args['attack_method']``. Adversarial examples
    are crafted once against the task-0 checkpoint and then re-evaluated on
    every later checkpoint; per-task attack success rates (ASR) are plotted
    and written to an Excel file under ``args['logs_eval_name']``.

    NOTE(review): ``self.args``, ``self.target_class``, ``self.data_manager``,
    ``self.device``, ``self.ckpt_paths`` and ``self.preprocessing`` are read
    but never assigned in this class -- presumably set by ``SustainableAttack``;
    confirm.
    """

    # Lazily-constructed foolbox attacks keyed by ``args['attack_method']``.
    # Lambdas defer construction so only the selected attack is instantiated.
    _ATTACK_FACTORIES = {
        'L2FGM': lambda: fb.attacks.FGM(),
        'FGSM': lambda: fb.attacks.FGSM(),
        'MIFGSM': lambda: fb.attacks.MIFGSM(momentum=0.9, steps=10),
        'L1PGD': lambda: fb.attacks.L1PGD(steps=10),
        'L2PGD': lambda: fb.attacks.L2PGD(steps=10),
        'LinfPGD': lambda: fb.attacks.LinfPGD(steps=10),
        'L2DeepFool': lambda: fb.attacks.L2DeepFoolAttack(steps=10),
        'LinfDeepFool': lambda: fb.attacks.LinfDeepFoolAttack(steps=10),
        'BoundaryAttack': lambda: fb.attacks.BoundaryAttack(steps=10),
        'CarliniWagnerL2': lambda: fb.attacks.L2CarliniWagnerAttack(steps=10),
        # NOTE(review): despite the key name this builds a *uniform* repeated
        # additive-noise attack, not a Gaussian one; confirm which was intended.
        'GaussianNoise': lambda: fb.attacks.LinfRepeatedAdditiveUniformNoiseAttack(),
        'UniformNoise': lambda: fb.attacks.LinfAdditiveUniformNoiseAttack(),
    }

    def __init__(self, args, device='cuda'):
        """Build the requested attack and prepare output directories.

        Args:
            args: Config dict; reads ``'attack_method'``, ``'eval'`` and
                ``'logs_eval_name'`` here (more via the parent class).
            device: Device string forwarded to the parent constructor.

        Raises:
            ValueError: If ``args['attack_method']`` is not a supported name.
        """
        super().__init__(args, device)
        attack_name = args['attack_method']
        try:
            make_attack = self._ATTACK_FACTORIES[attack_name]
        except KeyError:
            raise ValueError(f"Unknown attack method: {attack_name}") from None
        self.attack = make_attack()

        self.test_mode = 'BaseAttacks'
        # Perturbation budget in 0-255 pixel units; divided by 255 when the
        # attack runs because the wrapped model uses bounds (0, 1).
        self.epsilon = 16
        self.eval = args['eval']
        self.prefix = f'eps{self.epsilon}'
        self.save_path = os.path.join(self.args['logs_eval_name'])
        self.out_path = os.path.join(self.args['logs_eval_name'], f'gradcam_t{self.target_class}')
        os.makedirs(self.out_path, exist_ok=True)

        self.plot_gradcam = True

    def run_test(self, batch_size=1024, max_batches=1):
        """Load evaluation data and attack the first ``max_batches`` batches.

        In each batch, samples labelled ``self.target_class`` are split off;
        the remaining samples are attacked with ``self.target_class`` as the
        target label.

        Args:
            batch_size: Batch size for the evaluation data loader.
            max_batches: Number of leading batches to attack. The default of 1
                matches the original behaviour of using only the first batch.

        Raises:
            ValueError: If ``max_batches`` is less than 1.
        """
        if max_batches < 1:
            raise ValueError(f"max_batches must be >= 1, got {max_batches}")
        self.loader = get_dataloader(self.data_manager, batch_size=batch_size,
                                     start_class=0, end_class=10,
                                     train=False, shuffle=False, num_workers=0)
        with tqdm(self.loader, total=min(len(self.loader), max_batches),
                  desc=f'Loading Data with Batch Size of {batch_size}') as progress:
            for i, (_, imgs, labels) in enumerate(progress):
                imgs, labels = ep.astensors(imgs.to(self.device), labels.to(self.device))
                target_imgs = imgs[labels == self.target_class]
                target_labels = labels[labels == self.target_class]
                imgs_f = imgs[labels != self.target_class]
                labels_f = labels[labels != self.target_class]
                labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)

                self.attacks(i, imgs_f, labels_f, labels_t_f, target_imgs, target_labels)

                # Stop *after* processing so no extra batch is loaded and discarded.
                if i + 1 >= max_batches:
                    break

    def attacks(self, i_batch, imgs, labels, labels_t, target_imgs=None, target_labels=None):
        """Craft adversarial examples on task 0 and re-evaluate them on later tasks.

        Args:
            i_batch: Batch index; used only in the output file prefix.
            imgs: Input images (eagerpy tensor), expected within ``[0, 1]``.
            labels: Ground-truth labels for ``imgs``.
            labels_t: Target labels used by the targeted criterion and for ASR.
            target_imgs: Images whose label is the target class. Accepted so
                ``run_test`` can pass them; currently unused.
            target_labels: Labels of ``target_imgs``; currently unused.
        """
        num_tasks = 10
        clean_acc_matrix = []
        asr_matrix = np.ones((num_tasks, 1))
        self.model = factory.get_model(self.args["model_name"], self.args)
        for task in range(num_tasks):
            logging.info("***** Starting attack on task [%s]. *****", task)
            # presumably prepares the network for this task before its saved
            # weights are loaded below -- confirm against the model class
            self.model.incremental_train(self.data_manager)
            checkpoint = torch.load(self.ckpt_paths[task], map_location=self.device)
            self.model._network.load_state_dict(checkpoint['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
            # NOTE(review): stock foolbox ``accuracy`` returns a single float;
            # this assumes a variant returning (acc, per-sample mask, ...).
            acc, acc_bool = accuracy(current_model, imgs, labels)[:2]
            clean_acc_matrix.append(acc)
            if task == 0:
                # Keep only samples the task-0 model classifies correctly.
                imgs = imgs[acc_bool]
                labels = labels[acc_bool]
                labels_t = labels_t[acc_bool]
            verify_input_bounds(imgs, current_model)
            if self.target_class is None:
                criterion = fb.criteria.Misclassification(labels)
            else:
                criterion = fb.criteria.TargetedMisclassification(labels_t)

            if task == 0:
                # Adversarial examples are crafted once, against the task-0 model.
                _, adv_clip, asr_bool = self.attack(current_model, imgs, criterion=criterion,
                                                    epsilons=self.epsilon / 255)
                asr_matrix[task] = asr_bool.sum().raw.item() / len(imgs)
            else:
                # Re-score the task-0 adversarial examples on this task's model.
                # NOTE(review): scored against ``labels_t`` even when
                # ``target_class`` is None (untargeted criterion); confirm.
                asr_matrix[task] = accuracy(current_model, adv_clip, labels_t)[0]

            if self.plot_gradcam:
                save_grad_cam(self.args, torch.clip(adv_clip.raw, 0, 1), labels_t.raw,
                              self.model._network, os.path.join(self.out_path, "GradCam"),
                              prefix=f'task{task}', layer_name='stage_3',
                              save_num=100, save_raw=True)

            del criterion, current_model
            torch.cuda.empty_cache()

            self.model.after_task()

        # Save per-task ASR (plot + spreadsheet) for this batch and target class.
        prefix = f'batch{i_batch}_{self.prefix}_tc{self.target_class}'
        plot_asr_per_target(asr_matrix, self.save_path, prefix, self.args, clean_acc_matrix)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(self.save_path, f"{prefix}.xlsx"), index=False)

        del asr_matrix, imgs, labels, labels_t
        torch.cuda.empty_cache()

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Sequence[Union[float, None]],
        **kwargs: Any,
    ) -> Tuple[List[T], List[T], T]:
        """Placeholder stub: the body is ``...`` so calling it returns None.

        The real attack object is ``self.attack``. NOTE(review): presumably
        kept only to satisfy an abstract interface on the parent; confirm.
        """
        ...

    def repeat(self, times: int) -> "BASEAttack":
        """Placeholder stub: the body is ``...`` so calling it returns None."""
        ...