SAE / attack.py
Ttius's picture
Upload 192 files
998bb30 verified
import json
import argparse
import sys
import logging
import copy
import torch
import os
from trainCIL import _set_random, _set_device, print_args
from attacks.AIM.AIMAttack import AIM
from attacks.BASEAttack import BASEAttack
from attacks.Gaker.GAKERAttack import Gaker
from attacks.CGNC.CGNCAttack import CGNC
from attacks.CleanSheet.CleanSheetAttack import CleanSheet
from attacks.UnivIntruder.UnivIntruderAttack import UnivIntruder
from attacks.SAE.SAEAttack import SAE
def main():
    """Command-line entry point.

    Parses the CLI options, merges in the settings from the JSON file named
    by ``--config`` and hands the combined dict to ``evaluate``.
    """
    cli_namespace = setup_parser().parse_args()
    json_settings = load_json(cli_namespace.config)
    # Work on a plain dict from here on instead of the argparse Namespace.
    settings = vars(cli_namespace)
    # dict.update: a key present in the JSON file replaces the CLI value
    # stored under the same name.
    settings.update(json_settings)
    evaluate(settings)
def evaluate(args):
    """Call ``_evaluate`` once per entry of ``args["seed"]``.

    Args:
        args: Settings dict. ``args["seed"]`` must be iterable; ``args`` is
            modified in place (``seed`` ends up holding the last seed used).
    """
    # Snapshot both values up front: "seed" is overwritten with a single
    # value on every pass, and "device" is re-assigned from the snapshot
    # before each run (presumably because _set_device rewrites it -- confirm
    # against trainCIL).
    all_seeds = copy.deepcopy(args["seed"])
    original_device = copy.deepcopy(args["device"])
    for current_seed in all_seeds:
        args["seed"], args["device"] = current_seed, original_device
        _evaluate(args)
def _evaluate(args):
# For attacks
args["attack"] = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'
init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]
logs_name = "logs/{}/{}/init_cls_{}/per_classes_{}/{}".format(args["model_name"], args["dataset"], init_cls,
args['increment'], args["convnet_type"])
if not os.path.exists(logs_name):
raise "Model is not trained yet."
logs_eval_name = "logs/{}/{}/init_cls_{}/per_classes_{}/{}/eval/{}".format(args["model_name"], args["dataset"], init_cls,
args['increment'], args["convnet_type"], args['attack_method'])
if not os.path.exists(logs_eval_name):
os.makedirs(logs_eval_name)
logfilename = "logs/{}/{}/init_cls_{}/per_classes_{}/{}/eval/{}/adv_log".format(
args["model_name"],
args["dataset"],
init_cls,
args["increment"],
args["convnet_type"],
args['attack_method']
)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(filename)s] => %(message)s",
handlers=[
logging.FileHandler(filename=logfilename + ".log"),
logging.StreamHandler(sys.stdout),
],
)
args['epsilons'] = [0.01, 0.015, 0.03, 0.06, 0.1, 0.2]
args['logs_name'] = logs_name
args['logs_eval_name'] = logs_eval_name
_set_random()
_set_device(args)
print_args(args)
# Init the attack
if args['attack_method'] in NEW_ATTACKS:
adv = init_new_attack(args, device=device)
if args['attack_method'] == 'AIM' or args['attack_method'] == 'Gaker' or args['attack_method'] == 'CGNC':
adv.train_generator()
elif args['attack_method'] == 'SAE' or args['attack_method'] == 'UnivIntruder' or args['attack_method'] == 'CleanSheet':
adv.train_adv()
else:
adv = init_foolbox_attack(args, device=device)
# Conduct attack
adv.run_test()
# Attack names that init_new_attack maps to a dedicated attack class.
NEW_ATTACKS = [
    'AIM',
    'Gaker',
    'CGNC',
    'CleanSheet',
    'UnivIntruder',
    'SAE',
]

# Names of the other supported attacks (presumably the ones served through
# BASEAttack -- confirm). This list is not referenced by the code visible in
# this file.
FOOLBOX_ATTACKS = [
    'L2FGM',
    'FGSM',
    'MIFGSM',
    'L1PGD',
    'L2PGD',
    'LinfPGD',
    'L2DeepFool',
    'LinfDeepFool',
    'BoundaryAttack',
    'CarliniWagnerL2',
    'GaussianNoise',
    'UniformNoise',
]
def init_new_attack(args, device='cuda', **kwargs):
    """Instantiate the attack class named by ``args['attack_method']``.

    Args:
        args: Settings dict; must contain ``attack_method``. It is forwarded
            unchanged to the attack constructor.
        device: Device string forwarded to the attack constructor.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        A new instance of the selected attack class.

    Raises:
        ValueError: If ``attack_method`` is not one of the names handled here.
    """
    attack_name = args['attack_method']
    if attack_name == 'AIM':
        return AIM(args=args, device=device)
    if attack_name == 'Gaker':
        return Gaker(args=args, device=device)
    if attack_name == 'CGNC':
        return CGNC(args=args, device=device)
    if attack_name == 'CleanSheet':
        return CleanSheet(args=args, device=device)
    if attack_name == 'UnivIntruder':
        return UnivIntruder(args=args, device=device)
    if attack_name == 'SAE':
        return SAE(args=args, device=device)
    raise ValueError(f"Unknown attack method: {attack_name}")
def init_foolbox_attack(args, device='cuda', **kwargs):
    """Create the ``BASEAttack`` object for the given settings.

    Args:
        args: Settings dict forwarded to ``BASEAttack``.
        device: Device string forwarded to ``BASEAttack``.
        **kwargs: Accepted but not used.

    Returns:
        The newly constructed ``BASEAttack`` instance.
    """
    return BASEAttack(args=args, device=device)
def load_json(settings_path):
    """Read a JSON settings file and return the decoded object.

    Args:
        settings_path: Path of the JSON file.

    Returns:
        Whatever ``json.load`` produces for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode as UTF-8 explicitly rather than with the locale's default
    # encoding, so non-ASCII values load the same way on every platform.
    with open(settings_path, encoding="utf-8") as data_file:
        return json.load(data_file)
def setup_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``--config``,
        ``--batch_size``, ``--attack_method``, ``--target_class`` and
        ``--eval``.
    """
    cli = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorithms.')
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--config', {'type': str, 'default': 'exps/finetune.json',
                      'help': 'Json file of settings.'}),
        ('--batch_size', {'type': int, 'default': 128,
                          'help': 'set the batch size.'}),
        ('--attack_method', {'type': str, 'default': 'AIM',
                             'help': 'set the attack method, e.g., LinfPGD, MIFGSM, AIM, Gaker, CGNC, CleanSheet, UnivIntruder, SAE.'}),
        # NOTE(review): the help text mentions None, but with type=int and
        # default=0 the command line cannot produce None -- confirm intent.
        ('--target_class', {'type': int, 'default': 0,
                            'help': 'the target class, None indicates untargeted attack.'}),
        ('--eval', {'action': 'store_true', 'help': 'evaluation only'}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli
# Run the attack evaluation only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()