File size: 6,039 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os.path
import torch
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from tqdm import tqdm
import pandas as pd
from attacks.attack_config import SustainableAttack
from utils.plot import save_grad_cam, plot_asr_per_target
import logging
from foolbox import PyTorchModel, accuracy
import numpy as np
from utils import factory
from utils.data_manager import DataManager, get_dataloader
from attacks.UnivIntruder.att import train
from attacks.UnivIntruder.loss import UniversalPerturbation


class UnivIntruder(SustainableAttack):
    """Evaluate a universal targeted perturbation against per-task model checkpoints.

    The perturbation is either trained via ``attacks.UnivIntruder.att.train`` or
    loaded from the newest checkpoint under
    ``<logs_eval_name>/<adv_name>/ckpts``.  It is then applied to one batch of
    test images and, for every incremental task, the fraction of perturbed
    images classified as ``target_class`` (the attack success rate, ASR) is
    recorded, plotted and written to an Excel sheet.

    Attributes such as ``data_manager``, ``ckpt_paths``, ``preprocessing`` and
    ``to_all`` are not set here; they are presumably provided by
    ``SustainableAttack`` -- TODO confirm against the base class.
    """

    def __init__(self, args, device='cuda'):
        """Store the attack configuration and create the output directory.

        Args:
            args: Mapping of run options.  This class reads ``target_class``,
                ``dataset``, ``logs_eval_name`` and ``eval`` here, and
                ``model_name`` later on.
            device: Torch device on which models and tensors are placed.
        """
        super().__init__(args, device)
        self.device = device
        self.args = args
        self.surrogate_model = None  # no surrogate model is attached at construction time

        self.target_class = args['target_class']
        # Perturbation budget; divided by 255 before it is handed to
        # UniversalPerturbation in run_test().
        self.epsilon = 32
        self.image_size = 32 if self.args['dataset'] == 'cifar100' else 224
        # NOTE: the attribute name is misspelled ("szie"); it is kept as-is
        # because code outside this class may already read it.
        self.eval_batch_szie = 128

        self.adv_name = f'adv_eps{self.epsilon}_tc{self.target_class}'
        self.out_path = f'{self.args["logs_eval_name"]}/{self.adv_name}'
        os.makedirs(self.out_path, exist_ok=True)

        self.eval = args['eval']
        self.plot_gradcam = True
        self.ckpt_num = None

    def train_adv(self):
        """Train the universal perturbation unless running in evaluation-only mode."""
        if not self.eval:
            train(self.args)

    def run_test(self):
        """Load the newest perturbation checkpoint and attack the first test batch.

        Sets ``self.ckpt``, ``self.prefix``, ``self.adv`` and ``self.loader`` as
        side effects, then forwards every non-target-class sample of the first
        batch to :meth:`attacks`.

        Raises:
            FileNotFoundError: If no usable ``.pth`` checkpoint exists.
        """
        ckpt_dir = f'{self.args["logs_eval_name"]}/{self.adv_name}/ckpts'
        pth_name = self.get_max_step_filename(ckpt_dir)

        self.ckpt = f'{ckpt_dir}/{pth_name}.pth'
        self.prefix = f'{self.adv_name}_{pth_name.split("_")[-1]}'
        # map_location keeps a checkpoint saved on another device loadable
        # here, consistent with how the task checkpoints are loaded in attacks().
        initial_perturbation = torch.load(self.ckpt, map_location=self.device)
        self.adv = UniversalPerturbation((3, self.image_size, self.image_size), self.epsilon / 255,
                                         initialization=initial_perturbation, device=self.device)
        self.adv.eval()

        # Load batch data
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_szie,
                                     start_class=0, end_class=10,
                                     train=False, shuffle=False, num_workers=0)
        # The description reports the batch size the loader was actually built with.
        progress = tqdm(self.loader, total=len(self.loader),
                        desc=f'Loading Data with Batch Size of {self.eval_batch_szie} :')
        for i, (_, imgs, labels) in enumerate(progress):
            # Only the first batch is evaluated.
            if i > 0:
                break

            imgs, labels = ep.astensors(*(imgs.to(self.device), labels.to(self.device)))

            # Samples that already belong to the target class are excluded, so a
            # prediction of target_class can only come from the perturbation.
            keep = labels != self.target_class
            imgs_f = imgs[keep]
            labels_f = labels[keep]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)

            self.attacks(i, imgs_f, labels_f, labels_t_f)

    def attacks(self, i_batch, imgs, labels, labels_t):
        """Measure clean accuracy and attack success rate on every task checkpoint.

        Args:
            i_batch: Index of the batch, used only in output file names.
            imgs: Clean input images (eagerpy tensor).
            labels: Ground-truth labels for ``imgs``.
            labels_t: Target labels, one per image, all equal to the target class.

        Side effects:
            Rebuilds ``self.model``, optionally writes Grad-CAM images, saves an
            ASR plot and ``<prefix>.xlsx`` under ``<logs_eval_name>/<adv_name>``.
        """
        if self.args["model_name"] != 'finetune':
            imgs_f, labels_f = self.to_all(imgs, labels)
            # to_all may return fewer samples; keep the targets aligned with them.
            labels_t = labels_t[:len(imgs_f)]
        else:
            imgs_f, labels_f = imgs, labels

        clean_acc_matrix = []
        asr_matrix = np.ones((self.data_manager.nb_tasks, 1))
        self.model = factory.get_model(self.args["model_name"], self.args)
        eval_path = os.path.join(self.args["logs_eval_name"], self.adv_name)

        for task in range(self.data_manager.nb_tasks):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            self.model.incremental_train(self.data_manager)
            self.model._network.load_state_dict(
                torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            # Wrap the current task's network for foolbox.
            current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
            verify_input_bounds(imgs_f, current_model)

            # Evaluate the model performance on clean data.
            # NOTE(review): this uses the unfiltered imgs/labels rather than
            # imgs_f/labels_f -- confirm that is intended when to_all() is applied.
            acc = accuracy(current_model, imgs, labels)[0]
            logging.info("Clean accuracy on task {}: {}%".format(task, acc * 100))
            clean_acc_matrix.append(acc)

            # ASR: share of perturbed images classified as the target class.
            advs = self.adv(imgs_f.raw)
            asr = accuracy(current_model, ep.astensor(advs), labels_t)[0]
            asr_matrix[task] = asr
            if self.plot_gradcam:
                save_grad_cam(self.args, torch.clip(advs.detach(), 0, 1), labels_t.raw,
                              self.model._network, self.out_path + "/GradCam", prefix=f'task{task}',
                              layer_name='stage_3', save_num=100, save_raw=True)

            # Drop per-task tensors before the next checkpoint is loaded.
            del advs, current_model
            torch.cuda.empty_cache()

            self.model.after_task()

        # Save the per-task ASR (plot + spreadsheet) for this batch.
        prefix = f'batch{i_batch}_{self.prefix}'
        plot_asr_per_target(asr_matrix, eval_path, prefix, self.args, clean_acc_matrix)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(eval_path, f"{prefix}.xlsx"), index=False)

        del asr_matrix, imgs, labels, labels_t, imgs_f, labels_f
        torch.cuda.empty_cache()

    def get_max_step_filename(self, folder_path):
        """Return the stem of the ``.pth`` file with the highest step number.

        The step is the integer between the last ``_`` of the file name and the
        following ``.`` (``"adv_1500.pth"`` -> ``1500``).  Files ending in
        ``.pth`` whose suffix is not an integer are ignored.

        Args:
            folder_path: Directory to scan (not recursive).

        Returns:
            The winning file name without its extension.

        Raises:
            FileNotFoundError: If the directory is missing or holds no ``.pth``
                file with a numeric step suffix.
        """
        best_file = None
        best_step = None
        for fname in os.listdir(folder_path):
            if not fname.endswith('.pth'):
                continue
            try:
                step = int(fname.split('_')[-1].split('.')[0])
            except ValueError:
                # Not a step-numbered checkpoint (e.g. "best.pth"); skip it.
                continue
            # Strict comparison keeps the first file seen when steps tie.
            if best_step is None or step > best_step:
                best_file, best_step = fname, step

        if best_file is None:
            raise FileNotFoundError(
                f"No step-numbered '.pth' checkpoint found in '{folder_path}'")
        return os.path.splitext(best_file)[0]

    def __call__(
            self,
            model: Model,
            inputs: T,
            criterion: Any,
            *,
            epsilons: Union[Sequence[Union[float, None]], float, None],
            **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        """Placeholder for the foolbox-style attack call signature.

        Not implemented: the body is empty and the call returns ``None``.
        Evaluation is driven through :meth:`run_test` instead.
        """
        ...

    def repeat(self, times: int) -> "UnivIntruder":
        """Placeholder for the foolbox-style ``repeat`` hook.

        Not implemented: the body is empty and the call returns ``None``.
        """
        ...