File size: 6,463 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
from foolbox import PyTorchModel, accuracy
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from torchvision import transforms
from utils.data_manager import DataManager, get_dataloader
import torch
import logging
import eagerpy as ep
from utils.data_manager import load_all_task_models



class SustainableAttack(Attack):
    """Base foolbox ``Attack`` operating on continual-learning (CL) checkpoints.

    Builds a ``DataManager`` plus a training loader over the first classes,
    records the ``.pkl`` checkpoint paths found in ``args['logs_name']``, and
    provides ``to_all`` / ``to_alls`` to keep only the samples that every
    per-task model classifies correctly. ``run_attack``, ``__call__`` and
    ``repeat`` are stubs in this class.
    """

    def __init__(self, args, device='cuda'):
        """Set up data, checkpoint paths and normalization from ``args``.

        Args:
            args: Config dict. Keys read here: ``dataset``, ``shuffle``,
                ``seed``, ``init_cls``, ``increment``, ``attack``,
                ``batch_size``, ``logs_name``, ``target_class``. The keys
                ``target_class_list`` and ``target_class_dict`` are written
                back into this same dict (the caller's dict is mutated).
            device: Torch device on which the filtering masks are built.
        """
        super().__init__()
        self.device = device
        self.args = args

        # Only init the first 10 classes
        self.data_manager = DataManager(
            args["dataset"],
            args["shuffle"],
            args["seed"],
            args["init_cls"],
            args["increment"],
            args["attack"]
        )
        # Classes of the first increment (in the data manager's class order),
        # mapped to their position within that list.
        first_task = self.data_manager._class_order[:self.data_manager._increments[0]]
        self.args['target_class_list'] = first_task
        self.args['target_class_dict'] = {cls: idx for idx, cls in enumerate(first_task)}
        self.img_s = 32 if args["dataset"] == 'cifar100' else 224
        self.batch_size = args['batch_size']
        # NOTE(review): end_class is hard-coded to 10, while the target class
        # list above uses _increments[0]; they agree only when the first
        # increment has 10 classes — confirm.
        self.loader = get_dataloader(self.data_manager, batch_size=self.batch_size,
                                     start_class=0, end_class=10,
                                     train=True, shuffle=True, num_workers=0)

        # NOTE(review): plain lexicographic sort — e.g. "x_10.pkl" sorts
        # before "x_2.pkl" unless task indices are zero-padded; confirm naming.
        ckpts = sorted(f for f in os.listdir(args['logs_name']) if f.endswith('.pkl'))
        self.ckpt_paths = [os.path.join(args['logs_name'], ckpt_file) for ckpt_file in ckpts]
        # Placeholders; never assigned elsewhere in this class.
        self.model = None
        self.model0 = None
        self.attack = None

        self.target_class = args['target_class']
        if args["dataset"] == "cifar100":
            mean, std = [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]
        else:
            # Standard ImageNet channel statistics for every other dataset.
            mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        self.norm = transforms.Normalize(mean=mean, std=std)
        # Handed to foolbox PyTorchModel; axis=-3 assumes channel-first
        # (..., C, H, W) inputs — TODO confirm for all datasets.
        self.preprocessing = dict(mean=mean, std=std, axis=-3)

    def run_attack(self):
        """No-op here; presumably overridden by subclasses — confirm."""
        pass

    def _all_true_mask(self, n):
        """Return an eagerpy boolean tensor of ``n`` True values on ``self.device``."""
        return ep.astensor(torch.ones((n,), dtype=torch.bool, device=self.device))

    def _load_task_models(self):
        """Return the first element of ``load_all_task_models`` (the per-task models)."""
        return load_all_task_models(self.args, self.args['logs_name'], self.data_manager,
                                    batch_size=self.batch_size,
                                    train=True,
                                    load_type='model')[0]

    def _fmodel(self, task_model):
        """Wrap ``task_model._network`` as a foolbox PyTorchModel on [0, 1] inputs."""
        return PyTorchModel(task_model._network, bounds=(0, 1), preprocessing=self.preprocessing)

    def to_alls(self, imgs, labels, labels_t=None,
                target_imgs=None, target_labels=None, return_index=False):
        """Keep source and target samples that every task model classifies correctly.

        NOTE(review): indexes the result of ``accuracy()``; stock foolbox
        returns a scalar, so this assumes a patched version returning
        (accuracy, per-sample correctness, logits) — confirm.

        Args:
            imgs, labels: Source batch to filter.
            labels_t: Filtered alongside ``imgs`` when ``self.target_class``
                is not None.
            target_imgs, target_labels: Target batch to filter. Required in
                practice despite the ``None`` defaults (its length is taken).
            return_index: If True, return only the two boolean masks.

        Returns:
            ``(correct_index, correct_index_t)`` when ``return_index``; otherwise
            ``(imgs, labels, labels_t, target_imgs, target_labels, target_logits)``.
            A group is replaced by ``None`` values when none of its samples
            survive. ``target_logits`` comes from the task-0 model on the target
            batch and is ``None`` if no task model was loaded.
        """
        correct_index = self._all_true_mask(len(imgs))
        correct_index_t = self._all_true_mask(len(target_imgs))
        # Bound up front so an empty model list cannot leave it undefined.
        target_logits = None
        models = self._load_task_models()
        for task in range(len(models)):
            model = self._fmodel(models[task])
            acc_bool = accuracy(model, imgs, labels)[1]
            if task == 0:
                # Only the first task's logits on the target batch are kept.
                acc_bool_t, target_logits = accuracy(model, target_imgs, target_labels)[1:]
            else:
                acc_bool_t = accuracy(model, target_imgs, target_labels)[1]
            correct_index = ep.logical_and(correct_index, acc_bool)
            correct_index_t = ep.logical_and(correct_index_t, acc_bool_t)
            del model, acc_bool, acc_bool_t
        # Release the models before indexing the (possibly large) batches.
        del models

        if correct_index.any():
            imgs = imgs[correct_index]
            labels = labels[correct_index]
            if self.target_class is not None:
                labels_t = labels_t[correct_index]
            logging.info(
                f"Filtering {len(labels)} Correct samples for all CL models.")
        else:
            logging.info("No valid samples found for IMGS, skipping this batch.")
            imgs, labels, labels_t = None, None, None

        if correct_index_t.any():
            target_imgs = target_imgs[correct_index_t]
            target_labels = target_labels[correct_index_t]
            if target_logits is not None:
                target_logits = target_logits[correct_index_t]
            logging.info(
                f"Filtering {len(target_labels)} Target samples for all CL models.")
        else:
            logging.info("No valid samples found for TARGET IMGS, skipping this batch.")
            target_imgs, target_labels, target_logits = None, None, None

        if return_index:
            return correct_index, correct_index_t
        return imgs, labels, labels_t, target_imgs, target_labels, target_logits

    def to_all(self, imgs, labels, return_index=False):
        """Keep the samples that every task model classifies correctly.

        Args:
            imgs, labels: Batch to filter.
            return_index: If True, return only the boolean mask.

        Returns:
            ``correct_index`` when ``return_index``; otherwise ``(imgs, labels)``
            restricted to the correctly classified samples, or
            ``(None, None)`` when none survive.
        """
        # Filtering Correct Samples for All CL Models
        correct_index = self._all_true_mask(len(imgs))
        models = self._load_task_models()
        for task in range(len(models)):
            model = self._fmodel(models[task])
            acc_bool = accuracy(model, imgs, labels)[1]
            correct_index = ep.logical_and(correct_index, acc_bool)
            del model, acc_bool
        # Release the models before indexing the batch.
        del models

        if correct_index.any():
            imgs = imgs[correct_index]
            labels = labels[correct_index]
            logging.info(
                f"Filtering {len(labels)} Correct samples for all CL models.")
        else:
            logging.info("No valid samples found for IMGS, skipping this batch.")
            imgs, labels = None, None

        if return_index:
            return correct_index
        return imgs, labels

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Sequence[Union[float, None]],
        **kwargs: Any,
    ) -> Tuple[List[T], List[T], T]:
        """Stub for foolbox's ``Attack.__call__``; returns None here."""

    def repeat(self, times: int) -> "SustainableAttack":
        """Stub for foolbox's ``Attack.repeat``; returns None here."""