File size: 5,708 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os.path
import torch
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from tqdm import tqdm
import pandas as pd
from attacks.attack_config import SustainableAttack
from utils.plot import plot_asr_per_target, save_grad_cam
import logging
from foolbox import PyTorchModel, accuracy
import numpy as np
from utils import factory
from utils.data_manager import get_dataloader
from attacks.CleanSheet.utils_ import Trigger
from attacks.CleanSheet.generate_kd import train


class CleanSheet(SustainableAttack):
    """Sustainable-attack wrapper around the CleanSheet trigger attack.

    ``train_adv`` trains one trigger per target class (0-9) via the
    knowledge-distillation routine in ``attacks.CleanSheet.generate_kd``.
    ``run_test`` loads the latest trigger checkpoint and ``attacks``
    measures the attack success rate (ASR) of that trigger against each
    incremental-learning task, optionally saving Grad-CAM visualizations.

    NOTE(review): several attributes used below (``self.batch_size``,
    ``self.target_class``, ``self.ckpt_paths``, ``self.preprocessing``,
    ``self.data_manager``, ``self.to_all``) are presumably provided by
    ``SustainableAttack`` — confirm against the base class.
    """

    def __init__(self, args, device='cuda'):
        """Store configuration and evaluation defaults.

        Args:
            args: dict-like experiment configuration; must contain 'eval',
                and later 'logs_eval_name' and 'model_name'.
            device: torch device string for models and triggers.
        """
        super().__init__(args, device)
        self.device = device
        self.args = args
        self.surrogate_model = None
        # Batch size used when loading evaluation data in run_test().
        self.eval_batch_size = 128
        # Backward-compatible alias for the original misspelled attribute;
        # remove once confirmed no external code reads it.
        self.eval_batch_szie = self.eval_batch_size

        self.eval = args['eval']
        # Evaluation always runs against the baseline training pipeline.
        self.args['run_baseline'] = True

        test_mode = 'Trigger'
        self.trigger_name = f'{test_mode}'
        self.plot_gradcam = True

    def train_adv(self):
        """Train one trigger per target class (0..9), skipped in eval mode."""
        if self.eval:
            return
        for target in range(10):
            self.args['target_class'] = target
            train(self.args)
            # Free GPU memory between per-class training runs.
            torch.cuda.empty_cache()

    def run_test(self):
        """Load the newest trigger checkpoint and attack one test batch."""
        ckpt_dir = f'{self.args["logs_eval_name"]}/{self.trigger_name}/{self.target_class}'
        pth_name = self.get_max_step_filename(ckpt_dir)
        self.ckpt = f'{ckpt_dir}/{pth_name}.pth'
        # Prefix embeds the checkpoint's step number for output filenames.
        self.prefix = f'{self.trigger_name}_{pth_name.split("_")[-1]}'
        # map_location avoids a device-mismatch failure when the checkpoint
        # was saved on a different device than the current one.
        state = torch.load(self.ckpt, map_location=self.device)
        self.trigger = Trigger(size=32).to(self.device)
        self.trigger.load_state_dict(state)
        self.trigger.eval()

        # Load test data across all 10 classes; only the first batch is used.
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_size,
                                     start_class=0, end_class=10,
                                     train=False, shuffle=True, num_workers=0)
        # NOTE(review): the progress description reads self.batch_size, not
        # self.eval_batch_size — confirm the base class defines it.
        for i, (_, imgs, labels) in enumerate(tqdm(self.loader, total=len(self.loader),
                                                   desc=f'Loading Data with Batch Size of {self.batch_size}) :')):
            if i > 0:
                break

            imgs, labels = ep.astensors(*(imgs.to(self.device), labels.to(self.device)))

            # Split the batch: samples already in the target class vs. the rest.
            target_imgs = imgs[labels == self.target_class]
            target_labels = labels[labels == self.target_class]
            imgs_f = imgs[labels != self.target_class]
            labels_f = labels[labels != self.target_class]
            # Every non-target sample is attacked towards the target class.
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)
            if self.args["model_name"] != 'finetune':
                imgs_f, labels_f = self.to_all(imgs_f, labels_f)
                if imgs_f is None:
                    continue
                # NOTE(review): target_imgs is a boolean-mask slice and can
                # never be None — this guard looks dead; kept for safety.
                if target_imgs is None:
                    continue
                labels_t_f = labels_t_f[:len(imgs_f)]

            self.attacks(i, imgs_f, labels_f, labels_t_f)

    def attacks(self, i_batch, imgs, labels, labels_t):
        """Apply the trained trigger against every incremental task.

        Loads the saved model checkpoint for each of the 10 tasks, applies
        the trigger to *imgs*, records the ASR (fraction classified as
        *labels_t*), then plots the per-task ASR curve and writes it to an
        .xlsx file under the evaluation directory.

        Args:
            i_batch: index of the evaluated data batch (for filenames).
            imgs: eagerpy tensor of non-target images to perturb.
            labels: their true labels (unused beyond bookkeeping).
            labels_t: target labels (all equal to the target class).
        """
        asr_matrix = np.ones((10, 1))
        self.model = factory.get_model(self.args["model_name"], self.args)
        eval_path = os.path.join(self.args["logs_eval_name"], self.trigger_name)
        for task in range(10):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            self.model.incremental_train(self.data_manager)
            self.model._network.load_state_dict(torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            # Run the attack on each target image for this task's model.
            current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
            verify_input_bounds(imgs, current_model)
            logging.info("Eval attack on each target images.")
            advs = self.trigger(imgs.raw)
            # Clip perturbed images back into the valid normalized range.
            advs = ep.astensor(advs).clip(-1, 1)
            asr = accuracy(current_model, ep.astensor(advs), labels_t)[0]
            asr_matrix[task] = asr
            if self.plot_gradcam:
                save_grad_cam(self.args, torch.clip(advs.raw.detach(), -1, 1), labels_t.raw,
                              self.model._network, eval_path + "/GradCam", prefix=f'task{task}',
                              layer_name='stage_3', save_num=100, save_raw=True)
            del advs, current_model
            torch.cuda.empty_cache()

            self.model.after_task()

        # Save per-task ASR: plot plus an .xlsx with the average ASR column.
        prefix = f'batch{i_batch}_{self.prefix}_tc{self.target_class}'
        plot_asr_per_target(asr_matrix, eval_path, prefix, self.args)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(eval_path, f"{prefix}.xlsx"), index=False)

        del asr_matrix, imgs, labels, labels_t
        torch.cuda.empty_cache()

    def get_max_step_filename(self, folder_path):
        """Return the stem of the ``.pth`` file with the largest step number.

        Checkpoints are named ``<anything>_<step>.pth``; the one with the
        highest numeric suffix is selected.

        Args:
            folder_path: directory to scan for ``.pth`` checkpoints.

        Returns:
            The selected filename without its extension.

        Raises:
            FileNotFoundError: if *folder_path* has no ``.pth`` files
                (previously this crashed with a bare IndexError).
        """
        pth_files = [f for f in os.listdir(folder_path) if f.endswith('.pth')]
        if not pth_files:
            raise FileNotFoundError(f"No .pth checkpoints found in {folder_path}")
        latest = max(pth_files, key=lambda f: int(f.split('_')[-1].split('.')[0]))
        return os.path.splitext(latest)[0]

    def __call__(
            self,
            model: Model,
            inputs: T,
            criterion: Any,
            *,
            epsilons: Union[Sequence[Union[float, None]], float, None],
            **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        """Foolbox Attack interface stub — intentionally unimplemented."""
        ...

    def repeat(self, times: int) -> "CleanSheet":
        """Foolbox Attack interface stub — intentionally unimplemented."""
        ...