import os
from foolbox import PyTorchModel, accuracy
from foolbox.attacks.base import *
from foolbox.attacks.gradient_descent_base import *
from torchvision import transforms
from utils.data_manager import DataManager, get_dataloader
import torch
import logging
import eagerpy as ep
from utils.data_manager import load_all_task_models
class SustainableAttack(Attack):
    """Base class for attacks on a sequence of continual-learning (CL) models.

    Loads one checkpoint per CL task from ``args['logs_name']`` and provides
    helpers (:meth:`to_all` / :meth:`to_alls`) that filter a batch down to the
    samples every task model classifies correctly, so the attack only works on
    universally-correct inputs.
    """

    def __init__(self, args, device='cuda'):
        """Initialize data management, dataloader, checkpoints and normalization.

        Parameters
        ----------
        args : dict
            Experiment configuration. Read keys: ``dataset``, ``shuffle``,
            ``seed``, ``init_cls``, ``increment``, ``attack``, ``batch_size``,
            ``logs_name``, ``target_class``. Mutated in place: adds
            ``target_class_list`` and ``target_class_dict``.
        device : str
            Torch device on which the boolean filter masks are created.
        """
        super().__init__()
        self.device = device
        self.args = args
        # Only init the first 10 classes
        self.data_manager = DataManager(
            args["dataset"],
            args["shuffle"],
            args["seed"],
            args["init_cls"],
            args["increment"],
            args["attack"]
        )
        # The attack targets only the classes belonging to the first task.
        self.args['target_class_list'] = self.data_manager._class_order[:self.data_manager._increments[0]]
        self.args['target_class_dict'] = dict(zip(self.args['target_class_list'], range(len(self.args['target_class_list']))))
        # Image side length: CIFAR-100 is 32x32, otherwise assume ImageNet-style 224.
        self.img_s = 32 if args["dataset"] == 'cifar100' else 224
        self.batch_size = args['batch_size']
        self.loader = get_dataloader(self.data_manager, batch_size=self.batch_size,
                                     start_class=0, end_class=10,
                                     train=True, shuffle=True, num_workers=0)
        # One .pkl checkpoint per CL task; sorted so task order is deterministic.
        ckpts = sorted([f for f in os.listdir(args['logs_name']) if f.endswith('.pkl')])
        self.ckpt_paths = [os.path.join(args['logs_name'], ckpt_file) for ckpt_file in ckpts]
        self.model = None
        self.model0 = None
        self.attack = None
        self.target_class = args['target_class']
        # Dataset-specific normalization. `self.preprocessing` is the foolbox
        # equivalent applied inside PyTorchModel (axis=-3 is the channel dim).
        if args["dataset"] == "cifar100":
            self.norm = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408],
                                             std=[0.2675, 0.2565, 0.2761])
            self.preprocessing = dict(mean=[0.5071, 0.4867, 0.4408],
                                      std=[0.2675, 0.2565, 0.2761], axis=-3)
        else:
            self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            self.preprocessing = dict(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225], axis=-3)

    def run_attack(self):
        """Run the attack. Subclasses are expected to override this."""
        pass

    def _true_mask(self, n):
        """Return an all-True eagerpy boolean mask of length ``n`` on ``self.device``."""
        # torch.ones with a bool dtype is already all-True, so no extra fill is needed.
        return ep.astensors(torch.ones((n,), dtype=bool, device=self.device))[0]

    def to_alls(self, imgs, labels, labels_t=None,
                target_imgs=None, target_labels=None, return_index=False):
        """Filter both a source batch and a target batch to samples that every
        task model classifies correctly.

        Parameters
        ----------
        imgs, labels : ep tensors
            Source batch to be attacked.
        labels_t : ep tensor, optional
            Per-sample target labels; filtered only when ``self.target_class``
            is set.
        target_imgs, target_labels : ep tensors
            Target-class batch; its task-0 logits are also collected.
        return_index : bool
            If True, return the two boolean masks instead of filtered data.

        Returns
        -------
        tuple
            ``(imgs, labels, labels_t, target_imgs, target_labels,
            target_logits)``; entries become ``None`` when no sample survives
            filtering. With ``return_index=True`` returns
            ``(correct_index, correct_index_t)`` instead.
        """
        correct_index = self._true_mask(len(imgs))
        correct_index_t = self._true_mask(len(target_imgs))
        # Guard: if no checkpoints were found, the loop never binds target_logits.
        target_logits = None
        models = load_all_task_models(self.args, self.args['logs_name'], self.data_manager,
                                      batch_size=self.batch_size,
                                      train=True,
                                      load_type='model')[0]
        for task in range(len(models)):
            model = PyTorchModel(models[task]._network, bounds=(0, 1), preprocessing=self.preprocessing)
            acc_bool = accuracy(model, imgs, labels)[1]
            if task == 0:
                # Keep the very first model's logits on the target batch for later use.
                acc_bool_t, target_logits = accuracy(model, target_imgs, target_labels)[1:]
            else:
                acc_bool_t = accuracy(model, target_imgs, target_labels)[1]
            # A sample survives only if every task model so far got it right.
            correct_index = ep.logical_and(correct_index, acc_bool)
            correct_index_t = ep.logical_and(correct_index_t, acc_bool_t)
            del model, acc_bool, acc_bool_t  # free GPU memory between tasks
        if correct_index.any():
            imgs = imgs[correct_index]
            labels = labels[correct_index]
            if self.target_class is not None:
                labels_t = labels_t[correct_index]
            logging.info(
                f"Filtering {len(labels)} Correct samples for all CL models.")
        else:
            # NOTE: was `print` — switched to logging for consistency with to_all.
            logging.info("No valid samples found for IMGS, skipping this batch.")
            imgs, labels, labels_t = None, None, None
        if correct_index_t.any():
            target_imgs = target_imgs[correct_index_t]
            target_labels = target_labels[correct_index_t]
            target_logits = target_logits[correct_index_t]
            logging.info(
                f"Filtering {len(target_labels)} Target samples for all CL models.")
        else:
            logging.info("No valid samples found for TARGET IMGS, skipping this batch.")
            target_imgs, target_labels, target_logits = None, None, None
        if return_index:
            return correct_index, correct_index_t
        del models, correct_index, correct_index_t
        return imgs, labels, labels_t, target_imgs, target_labels, target_logits

    def to_all(self, imgs, labels, return_index=False):
        """Filter a batch to the samples every task model classifies correctly.

        Parameters
        ----------
        imgs, labels : ep tensors
            Batch to filter.
        return_index : bool
            If True, return the boolean survivor mask instead of the data.

        Returns
        -------
        tuple or ep tensor
            ``(imgs, labels)`` (both ``None`` when nothing survives), or the
            mask when ``return_index=True``.
        """
        # Filtering Correct Samples for All CL Models
        correct_index = self._true_mask(len(imgs))
        models = load_all_task_models(self.args, self.args['logs_name'], self.data_manager,
                                      batch_size=self.batch_size,
                                      train=True,
                                      load_type='model')[0]
        for task in range(len(models)):
            model = PyTorchModel(models[task]._network, bounds=(0, 1), preprocessing=self.preprocessing)
            acc_bool = accuracy(model, imgs, labels)[1]
            correct_index = ep.logical_and(correct_index, acc_bool)
            del model, acc_bool  # free GPU memory between tasks
        if correct_index.any():
            imgs = imgs[correct_index]
            labels = labels[correct_index]
            logging.info(
                f"Filtering {len(labels)} Correct samples for all CL models.")
        else:
            logging.info("No valid samples found for IMGS, skipping this batch.")
            imgs, labels = None, None
        if return_index:
            return correct_index
        del models, correct_index
        return imgs, labels

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Sequence[Union[float, None]],
        **kwargs: Any,
    ) -> Tuple[List[T], List[T], T]:
        """Foolbox Attack entry point; intentionally left unimplemented here."""
        ...

    def repeat(self, times: int) -> "SustainableAttack":
        """Foolbox Attack API hook; intentionally left unimplemented here."""
        ...