"""Evaluation entry point: run adversarial attacks against trained
continual-learning models.

Merges CLI arguments with a JSON config, sets up per-attack logging
directories under the trained model's ``logs/`` tree, instantiates the
requested attack, and executes it once per seed.
"""
import argparse
import copy
import json
import logging
import os
import sys

import torch

from trainCIL import _set_random, _set_device, print_args
from attacks.AIM.AIMAttack import AIM
from attacks.BASEAttack import BASEAttack
from attacks.Gaker.GAKERAttack import Gaker
from attacks.CGNC.CGNCAttack import CGNC
from attacks.CleanSheet.CleanSheetAttack import CleanSheet
from attacks.UnivIntruder.UnivIntruderAttack import UnivIntruder
from attacks.SAE.SAEAttack import SAE

# Attacks implemented in this repository; each needs a training phase
# (generator or perturbation) before evaluation.
NEW_ATTACKS = ['AIM', 'Gaker', 'CGNC', 'CleanSheet', 'UnivIntruder', 'SAE']
# Attacks delegated to the foolbox-backed BASEAttack wrapper.
FOOLBOX_ATTACKS = ['L2FGM', 'FGSM', 'MIFGSM', 'L1PGD', 'L2PGD', 'LinfPGD',
                   'L2DeepFool', 'LinfDeepFool', 'BoundaryAttack',
                   'CarliniWagnerL2', 'GaussianNoise', 'UniformNoise']

# Dispatch table: attack name -> attack class. Used by init_new_attack.
_ATTACK_CLASSES = {
    'AIM': AIM,
    'Gaker': Gaker,
    'CGNC': CGNC,
    'CleanSheet': CleanSheet,
    'UnivIntruder': UnivIntruder,
    'SAE': SAE,
}

# Sub-groups of NEW_ATTACKS that share a training entry point.
_GENERATOR_ATTACKS = ('AIM', 'Gaker', 'CGNC')        # call adv.train_generator()
_PERTURBATION_ATTACKS = ('SAE', 'UnivIntruder', 'CleanSheet')  # call adv.train_adv()


def main():
    """Parse CLI args, overlay the JSON config, and run the evaluation."""
    args = setup_parser().parse_args()
    param = load_json(args.config)
    args = vars(args)  # Converting argparse Namespace to a dict.
    args.update(param)  # Add parameters from json
    evaluate(args)


def evaluate(args):
    """Run ``_evaluate`` once per seed listed in ``args['seed']``.

    ``seed``/``device`` are deep-copied up front because ``_evaluate``
    mutates ``args`` in place (it overwrites both keys per iteration).
    """
    seed_list = copy.deepcopy(args["seed"])
    device = copy.deepcopy(args["device"])
    for seed in seed_list:
        args["seed"] = seed
        args["device"] = device
        _evaluate(args)


def _evaluate(args):
    """Evaluate a single seed: configure logging, build the attack, run it.

    Raises:
        FileNotFoundError: if the model's ``logs/`` directory does not
            exist (i.e. the model has not been trained yet).
    """
    # For attacks
    args["attack"] = True
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # When init_cls == increment the training code labels the run "0".
    init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]
    logs_name = "logs/{}/{}/init_cls_{}/per_classes_{}/{}".format(
        args["model_name"], args["dataset"], init_cls,
        args['increment'], args["convnet_type"])
    if not os.path.exists(logs_name):
        # BUG FIX: the original `raise "..."` raised a str, which is a
        # TypeError in Python 3; raise a real exception with the message.
        raise FileNotFoundError("Model is not trained yet.")

    # Per-attack evaluation directory lives under the trained model's tree.
    logs_eval_name = "{}/eval/{}".format(logs_name, args['attack_method'])
    os.makedirs(logs_eval_name, exist_ok=True)  # atomic; no exists() race
    logfilename = "{}/adv_log".format(logs_eval_name)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + ".log"),
            logging.StreamHandler(sys.stdout),
        ],
    )

    # Perturbation budgets swept by the attack's run_test().
    args['epsilons'] = [0.01, 0.015, 0.03, 0.06, 0.1, 0.2]
    args['logs_name'] = logs_name
    args['logs_eval_name'] = logs_eval_name
    _set_random()
    _set_device(args)
    print_args(args)

    # Init the attack: repo-native attacks need a training phase first;
    # everything else is handed to the foolbox wrapper.
    method = args['attack_method']
    if method in NEW_ATTACKS:
        adv = init_new_attack(args, device=device)
        if method in _GENERATOR_ATTACKS:
            adv.train_generator()
        elif method in _PERTURBATION_ATTACKS:
            adv.train_adv()
    else:
        adv = init_foolbox_attack(args, device=device)

    # Conduct attack
    adv.run_test()


def init_new_attack(args, device='cuda', **kwargs):
    """Instantiate one of the repo-native attacks named in NEW_ATTACKS.

    Args:
        args: merged config dict; ``args['attack_method']`` selects the class.
        device: torch device string passed through to the attack.
        **kwargs: accepted for interface compatibility; currently unused.

    Returns:
        The constructed attack object.

    Raises:
        ValueError: if ``args['attack_method']`` is not a known attack.
    """
    attack_name = args['attack_method']
    try:
        attack_cls = _ATTACK_CLASSES[attack_name]
    except KeyError:
        raise ValueError(f"Unknown attack method: {attack_name}") from None
    return attack_cls(args=args, device=device)


def init_foolbox_attack(args, device='cuda', **kwargs):
    """Wrap a foolbox attack (see FOOLBOX_ATTACKS) via BASEAttack."""
    attack = BASEAttack(args=args, device=device)
    return attack


def load_json(settings_path):
    """Load and return the JSON config at ``settings_path`` as a dict."""
    with open(settings_path) as data_file:
        param = json.load(data_file)
    return param


def setup_parser():
    """Build the CLI argument parser for the evaluation script."""
    parser = argparse.ArgumentParser(
        description='Reproduce of multiple continual learning algorithms.')
    parser.add_argument('--config', type=str, default='exps/finetune.json',
                        help='Json file of settings.')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='set the batch size.')
    parser.add_argument('--attack_method', type=str, default='AIM',
                        help='set the attack method, e.g., LinfPGD, MIFGSM, AIM, Gaker, CGNC, CleanSheet, UnivIntruder, SAE.')
    parser.add_argument('--target_class', type=int, default=0,
                        help='the target class, None indicates untargeted attack.')
    parser.add_argument('--eval', action='store_true',
                        help='evaluation only')
    return parser


if __name__ == '__main__':
    main()