Ttius commited on
Commit
998bb30
·
verified ·
1 Parent(s): 633b046

Upload 192 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +113 -3
  2. attack.py +140 -0
  3. attacks/AIM/AIMAttack.py +226 -0
  4. attacks/AIM/examples/aim_attack.py +237 -0
  5. attacks/AIM/examples/ens-gen.py +523 -0
  6. attacks/AIM/examples/workdirs/1MsK/args.txt +11 -0
  7. attacks/AIM/examples/workdirs/CGKQ/args.txt +11 -0
  8. attacks/AIM/examples/workdirs/Fvu1/args.txt +11 -0
  9. attacks/AIM/examples/workdirs/Qonx/args.txt +11 -0
  10. attacks/AIM/examples/workdirs/fnNs/args.txt +11 -0
  11. attacks/AIM/examples/workdirs/jtMb/args.txt +11 -0
  12. attacks/AIM/examples/workdirs/krhX/args.txt +11 -0
  13. attacks/AIM/setup.py +47 -0
  14. attacks/AIM/src/gat/__init__.py +1 -0
  15. attacks/AIM/src/gat/datasets/__init__.py +5 -0
  16. attacks/AIM/src/gat/datasets/builder.py +11 -0
  17. attacks/AIM/src/gat/datasets/cub.py +6 -0
  18. attacks/AIM/src/gat/datasets/env.py +6 -0
  19. attacks/AIM/src/gat/datasets/imagenet.py +36 -0
  20. attacks/AIM/src/gat/datasets/transforms.py +55 -0
  21. attacks/AIM/src/gat/models/__init__.py +0 -0
  22. attacks/AIM/src/gat/models/attack/__init__.py +5 -0
  23. attacks/AIM/src/gat/models/attack/aim_attack.py +14 -0
  24. attacks/AIM/src/gat/models/attack/base_attack.py +60 -0
  25. attacks/AIM/src/gat/models/attack/cda_attack.py +13 -0
  26. attacks/AIM/src/gat/models/attack/generator/__init__.py +0 -0
  27. attacks/AIM/src/gat/models/attack/generator/aim.py +179 -0
  28. attacks/AIM/src/gat/models/attack/generator/cda.py +146 -0
  29. attacks/AIM/src/gat/models/attack/loss/__init__.py +0 -0
  30. attacks/AIM/src/gat/models/attack/loss/logits.py +27 -0
  31. attacks/AIM/src/gat/models/attack/optim/__init__.py +3 -0
  32. attacks/AIM/src/gat/models/attack/optim/sam.py +100 -0
  33. attacks/AIM/src/gat/models/surrogate/__init__.py +10 -0
  34. attacks/AIM/src/gat/models/surrogate/builder.py +11 -0
  35. attacks/AIM/src/gat/models/surrogate/hooks.py +59 -0
  36. attacks/AIM/src/gat/models/surrogate/tv.py +206 -0
  37. attacks/AIM/src/gat/runtime/__init__.py +4 -0
  38. attacks/AIM/src/gat/runtime/api/__init__.py +0 -0
  39. attacks/AIM/src/gat/runtime/api/aim_attack.py +58 -0
  40. attacks/AIM/src/gat/runtime/factory.py +97 -0
  41. attacks/AIM/src/gat/runtime/meter.py +27 -0
  42. attacks/AIM/src/gat/runtime/utils.py +18 -0
  43. attacks/AIM/tests/__init__.py +0 -0
  44. attacks/AIM/tests/test_datasets/__init__.py +0 -0
  45. attacks/AIM/tests/test_datasets/test_datasets.py +33 -0
  46. attacks/AIM/tests/test_datasets/test_transforms.py +80 -0
  47. attacks/AIM/tests/test_models/__init__.py +0 -0
  48. attacks/AIM/tests/test_models/test_attack.py +110 -0
  49. attacks/AIM/tests/test_models/test_surrogate.py +59 -0
  50. attacks/AIM/tests/test_runtime/__init__.py +0 -0
README.md CHANGED
@@ -1,3 +1,113 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SAE: Sustainable Adversarial Example Evaluation Framework for Class-Incremental Learning
2
+
3
+ ## 🌟 Overview
4
+
5
+ _**News**: This work has been accepted as a poster by [AAAI 2026](https://aaai.org/conference/aaai/aaai-26/)._
6
+
7
+ **SAE (Sustainable Adversarial Example)** is a *universal adversarial attack framework* targeting **Class-Incremental Learning (CIL)**. This repository provides a comprehensive pipeline for both CIL training and benchmarking multiple attack methods, including our proposed SAE approach.
8
+
9
+ The project integrates with [PyCIL: A Python Toolbox for Class-Incremental Learning](https://github.com/LAMDA-CL/PyCIL) for CIL model training. It also supports benchmarking several attack baselines alongside SAE, enabling fair and reproducible evaluations of adversarial robustness across CIL methods.
10
+
11
+ If you are interested in our work, please refer to:
12
+ ```
13
+ @inproceedings{liu2026SAE,
14
+ title={Improving Sustainability of Adversarial Examples in Class-Incremental Learning},
15
+ author={Taifeng Liu, Xinjing Liu, Liangqiu Dong, Yang Liu, Yilong Yang, Zhuo Ma},
16
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
17
+ year={2026}
18
+ }
19
+ ```
20
+
21
+ ---
22
+
23
+ ## ⚙️ Environment Setup
24
+
25
+ ### Project Layout
26
+
27
+ * `attacks/`: Implementations of all attack baselines including **MIFGSM**, **Gaker**, **AIM**, **CGNC**, **CleanSheet**, **UnivIntruder**, and **SAE**.
28
+ * `convs/`: Backbone definitions for CIL models and CLIP model, including `resnet32`, `resnet50`, `cosine_resnet32`, and `cosine_resnet50`.
29
+ * **Pre-trained CLIP** is downloaded from [HuggingFace](https://huggingface.co/laion/CLIP-ViT-B-32-laion2B-s34B-b79K).
30
+ * `datasets/`: Dataset management for CIFAR-100 (32x32) and ImageNet-100 (224x224).
31
+ * **CIFAR-100** is automatically downloaded at runtime.
32
+ * **ImageNet-100** should be manually extracted from ImageNet-1K using `create_imagenet100_from_imagenet.py`, which parses `train.txt` and `eval.txt` to extract relevant classes.
33
+ * `exps/`: JSON configuration files for various CIL training methods.
34
+ * `logs/`: Stores trained CIL model checkpoints and evaluation results of attacks.
35
+ * `models/`: Contains implementations of **9** CIL algorithms: **BiC**, **DER**, **Finetune**, **Foster**, **iCaRL**, **MEMO**, **PodNet**, **Replay**, **WA**.
36
+ * `scripts/`: Shell scripts for CIL training and attack benchmarking.
37
+ * `utils/`: Utility functions for augmentation, logging, dataset processing, visualization, etc.
38
+ * `attack.py`: Main entry point to run adversarial attacks.
39
+ * `trainCIL.py`: Main entry point to train CIL models.
40
+
41
+ ### Dependency Requirements
42
+
43
+ * **OS**: Ubuntu 22.04
44
+ * **Python**: 3.12
45
+ * **PyTorch**: ≥ 2.1
46
+ * **GPU**: NVIDIA RTX 4090 (24GB VRAM)
47
+
48
+ To set up the environment:
49
+
50
+ ```bash
51
+ conda env create -f environment.yml
52
+ conda activate SAE
53
+ ```
54
+
55
+ ---
56
+
57
+ ## 🚀 Result Reproduction
58
+
59
+ ### Step 1: CIL Training
60
+
61
+ First, you need to train the target CIL models follow the instruction below.
62
+
63
+ Example: training a single CIL method (e.g., **iCaRL**) on **CIFAR-100**:
64
+
65
+ ```bash
66
+ python trainCIL.py --config exps/icarl.json
67
+ ```
68
+
69
+ Example: training a single CIL method (e.g., **iCaRL**) on **ImageNet-100**:
70
+
71
+ ```bash
72
+ python trainCIL.py --config exps/icarl-imagenet100.json
73
+ ```
74
+
75
+ We also provide scripts to train all 9 CIL methods on **CIFAR-100**:
76
+
77
+ ```bash
78
+ ./scripts/trainCIL-CIFAR100.sh
79
+ ```
80
+
81
+ Train all 9 CIL methods on **ImageNet-100**:
82
+
83
+ ```bash
84
+ ./scripts/trainCIL-ImageNet100.sh
85
+ ```
86
+
87
+ All model checkpoints will be saved under the `logs/` directory, organized by method and dataset.
88
+
89
+
90
+ ---
91
+
92
+ ### Step 2: Adversarial Attack Benchmarking
93
+
94
+ _Note: assume that you have already prepared the dataset, the CIL model, and the CLIP model._
95
+
96
+ To launch an **SAE** attack targeting class **0** on a CIL model trained on **CIFAR-100**:
97
+
98
+ ```bash
99
+ python attack.py --config exps/icarl.json --attack_method SAE --target_class 0
100
+ ```
101
+
102
+ To test a different attack baseline, simply change the `--attack_method` argument.
103
+
104
+ You can reproduce the overall evaluation results by runing the below scripts (referring to the **Table 1** in our paper).
105
+
106
+ ```bash
107
+ ./scripts/attacks-CIFAR100.sh
108
+ ./scripts/attacks-ImageNet100.sh
109
+ ```
110
+
111
+ #### Benchmark Results
112
+
113
+ All benchmark results are available in the `appendix.pdf` of Supplementary Material.
attack.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+ import sys
4
+ import logging
5
+ import copy
6
+ import torch
7
+ import os
8
+
9
+ from trainCIL import _set_random, _set_device, print_args
10
+ from attacks.AIM.AIMAttack import AIM
11
+ from attacks.BASEAttack import BASEAttack
12
+ from attacks.Gaker.GAKERAttack import Gaker
13
+ from attacks.CGNC.CGNCAttack import CGNC
14
+ from attacks.CleanSheet.CleanSheetAttack import CleanSheet
15
+ from attacks.UnivIntruder.UnivIntruderAttack import UnivIntruder
16
+ from attacks.SAE.SAEAttack import SAE
17
+
18
+
19
+ def main():
20
+ args = setup_parser().parse_args()
21
+ param = load_json(args.config)
22
+ args = vars(args) # Converting argparse Namespace to a dict.
23
+ args.update(param) # Add parameters from json
24
+
25
+ evaluate(args)
26
+
27
+
28
+ def evaluate(args):
29
+ seed_list = copy.deepcopy(args["seed"])
30
+ device = copy.deepcopy(args["device"])
31
+
32
+ for seed in seed_list:
33
+ args["seed"] = seed
34
+ args["device"] = device
35
+ _evaluate(args)
36
+
37
+
38
+ def _evaluate(args):
39
+ # For attacks
40
+ args["attack"] = True
41
+
42
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
43
+ init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]
44
+ logs_name = "logs/{}/{}/init_cls_{}/per_classes_{}/{}".format(args["model_name"], args["dataset"], init_cls,
45
+ args['increment'], args["convnet_type"])
46
+ if not os.path.exists(logs_name):
47
+ raise "Model is not trained yet."
48
+ logs_eval_name = "logs/{}/{}/init_cls_{}/per_classes_{}/{}/eval/{}".format(args["model_name"], args["dataset"], init_cls,
49
+ args['increment'], args["convnet_type"], args['attack_method'])
50
+ if not os.path.exists(logs_eval_name):
51
+ os.makedirs(logs_eval_name)
52
+
53
+ logfilename = "logs/{}/{}/init_cls_{}/per_classes_{}/{}/eval/{}/adv_log".format(
54
+ args["model_name"],
55
+ args["dataset"],
56
+ init_cls,
57
+ args["increment"],
58
+ args["convnet_type"],
59
+ args['attack_method']
60
+ )
61
+ logging.basicConfig(
62
+ level=logging.INFO,
63
+ format="%(asctime)s [%(filename)s] => %(message)s",
64
+ handlers=[
65
+ logging.FileHandler(filename=logfilename + ".log"),
66
+ logging.StreamHandler(sys.stdout),
67
+ ],
68
+ )
69
+
70
+ args['epsilons'] = [0.01, 0.015, 0.03, 0.06, 0.1, 0.2]
71
+ args['logs_name'] = logs_name
72
+ args['logs_eval_name'] = logs_eval_name
73
+ _set_random()
74
+ _set_device(args)
75
+ print_args(args)
76
+
77
+ # Init the attack
78
+ if args['attack_method'] in NEW_ATTACKS:
79
+ adv = init_new_attack(args, device=device)
80
+ if args['attack_method'] == 'AIM' or args['attack_method'] == 'Gaker' or args['attack_method'] == 'CGNC':
81
+ adv.train_generator()
82
+ elif args['attack_method'] == 'SAE' or args['attack_method'] == 'UnivIntruder' or args['attack_method'] == 'CleanSheet':
83
+ adv.train_adv()
84
+ else:
85
+ adv = init_foolbox_attack(args, device=device)
86
+
87
+ # Conduct attack
88
+ adv.run_test()
89
+
90
+
91
+ NEW_ATTACKS = ['AIM', 'Gaker', 'CGNC', 'CleanSheet', 'UnivIntruder', 'SAE']
92
+ FOOLBOX_ATTACKS = ['L2FGM', 'FGSM', 'MIFGSM',
93
+ 'L1PGD', 'L2PGD', 'LinfPGD',
94
+ 'L2DeepFool', 'LinfDeepFool', 'BoundaryAttack',
95
+ 'CarliniWagnerL2', 'GaussianNoise', 'UniformNoise']
96
+
97
+ def init_new_attack(args, device='cuda', **kwargs):
98
+ attack_name = args['attack_method']
99
+ models = kwargs.get('models', None)
100
+
101
+ if attack_name == 'AIM':
102
+ adv = AIM(args=args, device=device)
103
+ elif attack_name == 'Gaker':
104
+ adv = Gaker(args=args, device=device)
105
+ elif attack_name == 'CGNC':
106
+ adv = CGNC(args=args, device=device)
107
+ elif attack_name == 'CleanSheet':
108
+ adv = CleanSheet(args=args, device=device)
109
+ elif attack_name == 'UnivIntruder':
110
+ adv = UnivIntruder(args=args, device=device)
111
+ elif attack_name == 'SAE':
112
+ adv = SAE(args=args, device=device)
113
+ else:
114
+ raise ValueError(f"Unknown attack method: {attack_name}")
115
+
116
+ return adv
117
+
118
+ def init_foolbox_attack(args, device='cuda', **kwargs):
119
+ attack = BASEAttack(args=args, device=device)
120
+ return attack
121
+
122
+ def load_json(settings_path):
123
+ with open(settings_path) as data_file:
124
+ param = json.load(data_file)
125
+
126
+ return param
127
+
128
+
129
+ def setup_parser():
130
+ parser = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorithms.')
131
+ parser.add_argument('--config', type=str, default='exps/finetune.json', help='Json file of settings.')
132
+ parser.add_argument('--batch_size', type=int, default=128, help='set the batch size.')
133
+ parser.add_argument('--attack_method', type=str, default='AIM', help='set the attack method, e.g., LinfPGD, MIFGSM, AIM, Gaker, CGNC, CleanSheet, UnivIntruder, SAE.')
134
+ parser.add_argument('--target_class', type=int, default=0, help='the target class, None indicates untargeted attack.')
135
+ parser.add_argument('--eval', action='store_true', help='evaluation only')
136
+ return parser
137
+
138
+
139
+ if __name__ == '__main__':
140
+ main()
attacks/AIM/AIMAttack.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import torch
3
+ from foolbox.attacks.base import *
4
+ from foolbox.attacks.gradient_descent_base import *
5
+ from tqdm import tqdm
6
+ from attacks.AIM.src.gat.models.attack import AIMAttack, ContrastiveLoss
7
+ from attacks.AIM.src.gat.models.surrogate import midlayer_dict, register_collecter, register_collecter_cl
8
+ from attacks.attack_config import SustainableAttack
9
+ from utils.plot import plot_asr_per_target, save_grad_cam
10
+ import logging
11
+ import pandas as pd
12
+ import foolbox as fb
13
+ from foolbox import PyTorchModel
14
+ import numpy as np
15
+ from utils import factory
16
+ from utils.data_manager import get_dataloader
17
+
18
+
19
+ class AIM(SustainableAttack):
20
+ def __init__(self, args, device='cuda'):
21
+ super().__init__(args, device)
22
+ self.device = device
23
+ self.args = args
24
+ self.surrogate_model = None
25
+
26
+ self.adv_generator = AIMAttack(device=device)
27
+ self.adv_generator.set_mode('train')
28
+ self.lr = 0.001
29
+ self.betas = (0.5, 0.999)
30
+ self.num_epoch = 100
31
+ self.optim = torch.optim.Adam(self.adv_generator.get_params(), lr=self.lr, betas=self.betas)
32
+ self.contrastive_loss = ContrastiveLoss(0.2)
33
+ self.sim_loss = torch.nn.functional.cosine_similarity
34
+ self.eval_batch_szie = 128
35
+
36
+ self.surrogate_model_name = 'resnet32_cl'
37
+ self.layer = midlayer_dict[self.surrogate_model_name]
38
+ self.prefix = (f'adv_generator_{self.surrogate_model_name}'
39
+ f'_{self.layer}'
40
+ f'_tclass{self.target_class}')
41
+ self.save_path = os.path.join(self.args['logs_eval_name'], f'target{str(self.target_class)}')
42
+ if not os.path.exists(self.save_path):
43
+ os.makedirs(self.save_path)
44
+
45
+ self.plot_gradcam = False
46
+
47
+ def train_generator(self):
48
+ if 'cl' in self.surrogate_model_name:
49
+ s_model = factory.get_model(self.args["model_name"], self.args)
50
+ s_model.incremental_train(self.data_manager)
51
+ s_model._network.load_state_dict(
52
+ torch.load(self.ckpt_paths[0], map_location=self.device)['model_state_dict'])
53
+ s_model._network.to(self.device)
54
+ s_model._network.eval()
55
+ self.surrogate_model = s_model._network
56
+ del s_model
57
+ torch.cuda.empty_cache()
58
+ self.feat_collecter = []
59
+ self.feat_collecter_handler, self.feat_collecter = register_collecter_cl(self.surrogate_model,
60
+ self.layer,
61
+ self.feat_collecter,
62
+ self.args["model_name"])
63
+ else:
64
+ self.surrogate_model = torch.hub.load("chenyaofo/pytorch-cifar-models", 'cifar100_resnet32', pretrained=True)
65
+ self.surrogate_model.to(self.device)
66
+ self.surrogate_model.eval()
67
+ self.feat_collecter = []
68
+ self.feat_collecter_handler, self.feat_collecter = register_collecter(self.surrogate_model,
69
+ self.layer,
70
+ self.feat_collecter)
71
+ self.file_path = os.path.join(self.save_path, f'{self.prefix}.pth')
72
+ if os.path.exists(self.file_path):
73
+ self.adv_generator.load_ckpt(self.file_path)
74
+ self.adv_generator.set_mode('eval')
75
+ else:
76
+ loaders = get_dataloader(self.data_manager, batch_size=self.batch_size,
77
+ start_class=0, end_class=10,
78
+ train=True, shuffle=True, num_workers=0)
79
+
80
+ target_images = []
81
+ target_labels = []
82
+ for data in loaders:
83
+ _, image_batch, label_batch = data
84
+ mask = label_batch == self.target_class
85
+ selected_images = image_batch[mask]
86
+ selected_labels = label_batch[mask]
87
+ target_images.append(selected_images)
88
+ target_labels.append(selected_labels)
89
+ del loaders
90
+ target_images = torch.cat(target_images, dim=0).to(self.device)
91
+ target_labels = torch.cat(target_labels, dim=0).to(self.device)
92
+ target_images, target_labels = ep.astensors(*(target_images[:self.batch_size], target_labels[:self.batch_size]))
93
+
94
+ total_loss = []
95
+ for epoch in range(1, self.num_epoch + 1):
96
+ laoder_tqdm = tqdm(self.loader, total=len(self.loader), desc=f'Epoch {epoch}')
97
+ loss_np = 0
98
+ for i, (_, x, y) in enumerate(laoder_tqdm):
99
+ x_f = x[y != self.target_class].to(self.device)
100
+ del x, y
101
+ if len(x_f) > len(target_images):
102
+ x_f = x_f[:len(target_images)].to(self.device)
103
+ else:
104
+ random_indices = torch.randperm(len(target_images))[:len(x_f)].to(self.device)
105
+ target_images = target_images[random_indices]
106
+
107
+ x_adv = self.adv_generator(x_f, target_images.raw.to(self.device))
108
+
109
+ logits_nat = self.surrogate_model(self.norm(x_f))
110
+ feat_nat = self.feat_collecter.pop()
111
+ logits_tar = self.surrogate_model(self.norm(target_images.raw))
112
+ feat_tar = self.feat_collecter.pop()
113
+ logits_adv = self.surrogate_model(self.norm(x_adv))
114
+ feat_adv = self.feat_collecter.pop()
115
+
116
+ loss = (self.contrastive_loss(logits_adv, logits_nat, logits_tar) +
117
+ self.sim_loss(feat_nat, feat_adv) -
118
+ self.sim_loss(feat_tar, feat_adv)).mean()
119
+ # print(loss.item())
120
+ loss_np = loss_np + loss.item()
121
+
122
+ self.optim.zero_grad()
123
+ loss.backward()
124
+ self.optim.step()
125
+ del x_f, x_adv, logits_nat, logits_adv, logits_tar, feat_nat, feat_tar, feat_adv
126
+ torch.cuda.empty_cache()
127
+ total_loss.append(loss_np/(i+1))
128
+ logging.info(f'Epoch {epoch} loss: {loss_np/(len(self.loader))}')
129
+ logging.info(f'Total loss: {total_loss}')
130
+ self.feat_collecter_handler.remove()
131
+ self.adv_generator.save_ckpt(self.file_path)
132
+
133
+
134
+ def run_test(self):
135
+ # Load Batch Data
136
+ self.adv_generator.set_mode('eval')
137
+ self.adv_generator.adv_gen.to(self.device)
138
+ self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_szie,
139
+ start_class=0, end_class=10,
140
+ train=False, shuffle=False, num_workers=0)
141
+ target_images = []
142
+ target_labels = []
143
+ for data in self.loader:
144
+ _, image_batch, label_batch = data
145
+ mask = label_batch == self.target_class
146
+ selected_images = image_batch[mask]
147
+ selected_labels = label_batch[mask]
148
+ target_images.append(selected_images)
149
+ target_labels.append(selected_labels)
150
+ target_imgs = torch.cat(target_images, dim=0).to(self.device)
151
+ target_labels = torch.cat(target_labels, dim=0).to(self.device)
152
+ target_imgs, target_labels = ep.astensors(*(target_imgs, target_labels))
153
+ for i, (_, imgs, labels) in enumerate(tqdm(self.loader, total=len(self.loader),
154
+ desc=f'Loading Data with Batch Size of {self.batch_size}) :')):
155
+ if i > 0:
156
+ break
157
+
158
+ imgs, labels = ep.astensors(*(imgs.to(self.device), labels.to(self.device)))
159
+
160
+ imgs_f = imgs[labels != self.target_class]
161
+ labels_f = labels[labels != self.target_class]
162
+ labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)
163
+
164
+ self.attacks(i, imgs_f, labels_f, labels_t_f, target_imgs[:20], target_labels[:20])
165
+
166
+
167
+ def attacks(self, i_batch, imgs, labels, labels_t, target_imgs=None, target_labels=None):
168
+ asr_matrix = np.ones((10, len(target_imgs)))
169
+ self.model = factory.get_model(self.args["model_name"], self.args)
170
+ for task in range(10):
171
+ logging.info("***** Starting attack on task [{}]. *****".format(task))
172
+ self.model.incremental_train(self.data_manager)
173
+ self.model._network.load_state_dict(torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
174
+ self.model._network.to(self.device)
175
+ self.model._network.eval()
176
+
177
+ # Run attack on ecah target image
178
+ criterion = fb.criteria.Misclassification(
179
+ labels) if self.target_class is None else fb.criteria.TargetedMisclassification(
180
+ labels_t)
181
+ current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
182
+ verify_input_bounds(imgs, current_model)
183
+ criterion = get_criterion(criterion)
184
+ is_adversarial = get_is_adversarial(criterion, current_model)
185
+
186
+ logging.info("Eval attack on each target images.")
187
+ for i, target_image in enumerate(target_imgs):
188
+ advs = ep.astensor(self.adv_generator(imgs.raw.to(self.device), target_image.raw.repeat(len(imgs), 1, 1, 1).to(self.device)))
189
+ is_adv = is_adversarial(advs)[0]
190
+ asr_matrix[task, i] = (is_adv.bool().sum().raw.item() / len(imgs))
191
+ if self.plot_gradcam:
192
+ save_grad_cam(self.args, torch.clip(advs.raw.detach(),0,1), labels_t.raw,
193
+ self.model._network, self.save_path + "/GradCam" + f"targetimg{i}", prefix=f'task{task}',
194
+ layer_name='stage_3', save_num=100, save_raw=True)
195
+
196
+ del advs, is_adv, target_image
197
+ torch.cuda.empty_cache()
198
+
199
+ del criterion, current_model, is_adversarial
200
+ torch.cuda.empty_cache()
201
+
202
+ self.model.after_task()
203
+
204
+ # Save all target images info: everage asr,
205
+ asr_matrix = np.mean(asr_matrix, axis=1, keepdims=True)
206
+ prefix = f'batch{i_batch}_{self.prefix}'
207
+ plot_asr_per_target(asr_matrix, self.save_path, prefix, self.args)
208
+ df = pd.DataFrame(asr_matrix, columns=['ASR'])
209
+ df.to_excel(os.path.join(self.save_path, f"{prefix}.xlsx"), index=False)
210
+
211
+ del asr_matrix, imgs, labels, labels_t, target_imgs
212
+ torch.cuda.empty_cache()
213
+
214
+ def __call__(
215
+ self,
216
+ model: Model,
217
+ inputs: T,
218
+ criterion: Any,
219
+ *,
220
+ epsilons: Union[Sequence[Union[float, None]], float, None],
221
+ **kwargs: Any,
222
+ ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
223
+ ...
224
+
225
+ def repeat(self, times: int) -> "AIM":
226
+ ...
attacks/AIM/examples/aim_attack.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Usage:
3
+ # ./examples/aim_attack.py -h
4
+ import argparse
5
+ import json
6
+ from pathlib import Path
7
+ from pprint import pformat
8
+ from typing import List, Union
9
+
10
+ import torch
11
+ from tqdm import tqdm
12
+
13
+ from attacks.GAT.src.gat.datasets import build_dataset, list_datasets
14
+ from attacks.GAT.src.gat.datasets.transforms import norm
15
+ from attacks.GAT.src.gat.models.attack import AIMAttack, ContrastiveLoss
16
+ from attacks.GAT.src.gat.models.surrogate import (build_surrogate, list_surrogates,
17
+ midlayer_dict, register_collecter)
18
+ from attacks.GAT.src.gat.runtime import AverageMeter, calc_cls_accuracy, fix_random, randid
19
+
20
+
21
+ def parse_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument('-v', '--verbose', action='store_true')
24
+ parser.add_argument('--seed', type=int, default=0)
25
+ parser.add_argument('--expid', type=str, default=randid(4))
26
+ parser.add_argument('--workdir', type=str, default='workdirs')
27
+ parser.add_argument('--device', type=str, default='cuda')
28
+ parser.add_argument('--tar-classes', type=int, default=24)
29
+ parser.add_argument('--batch-size', type=int, default=16)
30
+ parser.add_argument('--dataset',
31
+ type=str,
32
+ default='imagenet',
33
+ choices=list_datasets())
34
+ parser.add_argument('--data-root',
35
+ type=str,
36
+ default=Path(__file__).parent / '../data' / 'in_1k')
37
+ sub_parsers = parser.add_subparsers(dest='command')
38
+ train_parser = sub_parsers.add_parser('train')
39
+ train_parser.add_argument('--surrogate-id',
40
+ type=str,
41
+ default='resnet152',
42
+ choices=list_surrogates())
43
+ train_parser.add_argument('--num-epoch', type=int, default=10)
44
+ train_parser.add_argument('--lr', type=float, default=0.0002)
45
+ train_parser.add_argument('--betas',
46
+ type=float,
47
+ nargs=2,
48
+ default=(0.5, 0.999))
49
+ sub_parsers.add_parser('evaluate')
50
+ args = parser.parse_args()
51
+
52
+ args.workdir = Path(args.workdir) / args.expid
53
+ args.workdir.mkdir(parents=True, exist_ok=True)
54
+ args.device = torch.device(args.device)
55
+
56
+ args.ckpt = args.workdir / 'model.pth'
57
+
58
+ fix_random(args.seed)
59
+
60
+ with open(args.workdir / 'args.txt', 'w') as f:
61
+ f.write(pformat(vars(args)))
62
+
63
+ return args
64
+
65
+
66
+ def init_loader(dataset: str,
67
+ data_root: Union[str, Path],
68
+ tar_classes: Union[int, List[int]],
69
+ batch_size: int = 16,
70
+ command: str = 'train') -> List[torch.utils.data.DataLoader]:
71
+ train_ds = build_dataset(dataset,
72
+ data_root=data_root,
73
+ is_train=(command == 'train'))
74
+ train_loader = torch.utils.data.DataLoader(
75
+ train_ds,
76
+ batch_size=batch_size,
77
+ shuffle=True,
78
+ num_workers=4,
79
+ pin_memory=True,
80
+ )
81
+ target_ds = build_dataset(dataset,
82
+ data_root=data_root,
83
+ is_train=True,
84
+ filter_class=tar_classes)
85
+ target_loader = torch.utils.data.DataLoader(
86
+ target_ds,
87
+ batch_size=batch_size,
88
+ sampler=torch.utils.data.RandomSampler(target_ds,
89
+ replacement=True,
90
+ num_samples=len(train_ds)),
91
+ num_workers=4,
92
+ pin_memory=True,
93
+ )
94
+ return train_loader, target_loader
95
+
96
+
97
+ def train(
98
+ surrogate_id: str,
99
+ dataset: str,
100
+ data_root: Union[str, Path],
101
+ tar_classes: Union[int, List[int]] = 24,
102
+ num_epoch: int = 10,
103
+ batch_size: int = 16,
104
+ lr: float = 0.0002,
105
+ betas: Union[float, List[float]] = (0.5, 0.999),
106
+ device: Union[str, torch.device] = torch.device('cuda'),
107
+ command: str = 'train',
108
+ workdir: Union[str,
109
+ Path] = Path(__file__).parents[1] / 'workdirs') -> None:
110
+
111
+ train_loader, target_loader = init_loader(dataset, data_root, tar_classes,
112
+ batch_size, command)
113
+ normalizer = norm(dataset, _callable=True)
114
+
115
+ surrogate = build_surrogate(surrogate_id, pretrain=True).to(device)
116
+ surrogate.eval()
117
+ feat_collecter_handler, feat_collecter = register_collecter(
118
+ surrogate, midlayer_dict[surrogate_id])
119
+
120
+ attack = AIMAttack(device=device)
121
+ attack.set_mode('train')
122
+ optim = torch.optim.Adam(attack.get_params(), lr=lr, betas=betas)
123
+
124
+ contrastive_loss = ContrastiveLoss(0.2)
125
+ sim_loss = torch.nn.functional.cosine_similarity
126
+
127
+ for epoch in range(1, num_epoch + 1):
128
+ attack.set_mode('train')
129
+ enumerator = enumerate(zip(train_loader, target_loader))
130
+ enumerator = tqdm(enumerator,
131
+ total=len(train_loader),
132
+ desc=f'Epoch {epoch}')
133
+ for batch_idx, ((x_nat, y_nat), (x_tar, y_tar)) in enumerator:
134
+ if torch.any(y_nat == y_tar):
135
+ continue
136
+ x_nat, x_tar = x_nat.to(device), x_tar.to(device)
137
+ y_nat, y_tar = y_nat.to(device), y_tar.to(device)
138
+ x_adv = attack(x_nat, x_tar)
139
+
140
+ logits_nat = surrogate(normalizer(x_nat))
141
+ feat_nat = feat_collecter.pop()
142
+ logits_tar = surrogate(normalizer(x_tar))
143
+ feat_tar = feat_collecter.pop()
144
+ logits_adv = surrogate(normalizer(x_adv))
145
+ feat_adv = feat_collecter.pop()
146
+
147
+ loss = (contrastive_loss(logits_adv, logits_nat, logits_tar) +
148
+ sim_loss(feat_nat, feat_adv) -
149
+ sim_loss(feat_tar, feat_adv)).mean()
150
+
151
+ optim.zero_grad()
152
+ loss.backward()
153
+ optim.step()
154
+
155
+ feat_collecter_handler.remove()
156
+
157
+ attack.save_ckpt(workdir / 'model.pth')
158
+
159
+
160
+ @torch.no_grad()
161
+ def evaluate(
162
+ ckpt: Union[str, Path],
163
+ dataset: str,
164
+ data_root: Union[str, Path],
165
+ tar_classes: Union[int, List[int]] = 24,
166
+ batch_size: int = 16,
167
+ device: Union[str, torch.device] = torch.device('cuda'),
168
+ command: str = 'train',
169
+ workdir: Union[str,
170
+ Path] = Path(__file__).parents[1] / 'workdirs') -> None:
171
+ # init dataloader
172
+ eval_loader, target_loader = init_loader(dataset, data_root, tar_classes,
173
+ batch_size, command)
174
+ normalizer = norm(dataset, _callable=True)
175
+ # init attack method
176
+ attack = AIMAttack(device=device)
177
+ attack.load_ckpt(ckpt)
178
+ attack.set_mode('eval')
179
+ # init evaluate models
180
+ models = {
181
+ surrogate_id: build_surrogate(surrogate_id, pretrain=True).to(device)
182
+ for surrogate_id in list_surrogates()
183
+ }
184
+ for surrogate_id in models.keys():
185
+ models[surrogate_id].eval()
186
+ model_meters = {
187
+ surrogate_id: [AverageMeter() for _ in range(2)]
188
+ for surrogate_id in models.keys()
189
+ }
190
+ # evaluate
191
+ enumerator = enumerate(zip(eval_loader, target_loader))
192
+ enumerator = tqdm(enumerator, total=len(eval_loader), desc='Eval')
193
+ for batch_idx, ((x_nat, y_nat), (x_tar, y_tar)) in enumerator:
194
+ x_nat, y_nat = x_nat.to(device), y_nat.to(device)
195
+ x_tar, y_tar = x_tar.to(device), y_tar.to(device)
196
+ x_adv = attack(x_nat, x_tar)
197
+ for surrogate_id, model in models.items():
198
+ logits_nat = model(normalizer(x_nat))
199
+ logits_adv = model(normalizer(x_adv))
200
+ # collect metrics
201
+ acc = calc_cls_accuracy(logits_nat, y_nat)
202
+ asr = calc_cls_accuracy(logits_adv, y_tar)
203
+ model_meters[surrogate_id][0].update(acc[0].item(), x_nat.size(0))
204
+ model_meters[surrogate_id][1].update(asr[0].item(), x_nat.size(0))
205
+ # print result
206
+ results = {
207
+ surrogate_id: {
208
+ 'acc': meters[0].avg,
209
+ 'asr': meters[1].avg
210
+ }
211
+ for surrogate_id, meters in model_meters.items()
212
+ }
213
+ print(pformat(results))
214
+ with open(workdir / 'results.json', 'w') as f:
215
+ json.dump(results, f)
216
+
217
+
218
+ def main() -> None:
219
+ args = parse_args()
220
+ args.command = 'train'
221
+ args.surrogate_id = 'resnet152'
222
+ args.num_epoch = 10
223
+ args.lr = 0.0002
224
+ args.betas = (0.5, 0.999)
225
+ if args.command == 'train':
226
+ train(args.surrogate_id, args.dataset, args.data_root,
227
+ args.tar_classes, args.num_epoch, args.batch_size, args.lr,
228
+ args.betas, args.device, args.command, args.workdir)
229
+ elif args.command == 'evaluate':
230
+ evaluate(args.ckpt, args.dataset, args.data_root, args.tar_classes,
231
+ args.batch_size, args.device, args.command, args.workdir)
232
+ else:
233
+ raise NotImplementedError
234
+
235
+
236
+ if __name__ == '__main__':
237
+ main()
attacks/AIM/examples/ens-gen.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Usage
4
+ -----
5
+ ./examples/ens-gen.py -d -v -s 0 \
6
+ --dataset imagenet -b 16 --eps 8 --workdir "workdirs" \
7
+ --device "cuda:0" \
8
+ train --n-ep 1 \
9
+ --surrogate-model-ids vgg19 inception_v3 resnet152 densenet169 \
10
+ --lr 0.0002 --beta 0.5 0.999 \
11
+ --use-logit-loss --use-logit-weights --use-logit-softmax-weights
12
+ """
13
+ import argparse
14
+ import json
15
+ from pathlib import Path
16
+ from pprint import pformat
17
+ from typing import List, Union
18
+
19
+ import torch
20
+ import torchvision
21
+ from torch.nn import functional as F
22
+ from torch.utils.tensorboard import SummaryWriter
23
+ from tqdm import tqdm
24
+
25
+ from gat.datasets import build_dataset, list_datasets
26
+ from gat.datasets.transforms import norm
27
+ from gat.models.attack import CDAAttack
28
+ from gat.models.attack.optim import (SAM, disable_running_stats,
29
+ enable_running_stats)
30
+ from gat.models.surrogate import (build_surrogate, feat_col, list_surrogates,
31
+ midlayer_dict)
32
+ from gat.runtime import AverageMeter, calc_cls_accuracy, fix_random, randid
33
+
34
+
35
+ class CLIParser:
36
+
37
+ @staticmethod
38
+ def init_basic_parser(p: argparse.ArgumentParser):
39
+ g_basic = p.add_argument_group('Basic Settings')
40
+ g_basic.add_argument('-v',
41
+ '--verbose',
42
+ action='store_true',
43
+ default=False)
44
+ g_basic.add_argument('-d', '--dev', action='store_true', default=False)
45
+ g_basic.add_argument('-s', '--seed', type=int, default=0)
46
+ g_basic.add_argument('--expid', type=str, default=randid(4))
47
+ g_basic.add_argument('--device', type=str, default='cuda')
48
+
49
+ g_path = p.add_argument_group('Path Settings')
50
+ g_path.add_argument('--workdir', type=str, default='workdirs')
51
+ g_path.add_argument('--data-root',
52
+ type=str,
53
+ default=Path(__file__).parent / '../data' /
54
+ 'in_1k')
55
+
56
+ g_ds = p.add_argument_group('Dataset Settings')
57
+ g_ds.add_argument('--dataset',
58
+ type=str,
59
+ default='imagenet',
60
+ choices=list_datasets())
61
+ g_ds.add_argument('-b', '--batch-size', type=int, default=16)
62
+
63
+ g_at_basic = p.add_argument_group('General Attack Settings')
64
+ g_at_basic.add_argument('--eps',
65
+ '--epsilon',
66
+ dest='epsilon',
67
+ type=int,
68
+ default=8,
69
+ choices=[1, 2, 4, 8, 16])
70
+
71
+ @staticmethod
72
+ def post_basic_parser(args: argparse.Namespace):
73
+ if args.dev:
74
+ args.workdir = args.workdir.replace('workdirs', 'workdirs-dev')
75
+ args.workdir = Path(args.workdir) / args.expid
76
+ args.workdir.mkdir(parents=True, exist_ok=True)
77
+ args.device = torch.device(args.device)
78
+
79
+ args.ckpt = args.workdir / 'model.pth'
80
+ args.tf_logger = SummaryWriter(args.workdir / 'tf_log')
81
+
82
+ args.epsilon /= 255.0
83
+ if args.command == 'evaluate-pgd':
84
+ args.alpha /= 255.0
85
+
86
+ fix_random(args.seed)
87
+
88
+ if args.verbose:
89
+ print(pformat(vars(args)))
90
+ with open(args.workdir / f'args-{args.command}.txt', 'w') as f:
91
+ f.write(pformat(vars(args)))
92
+
93
+ @staticmethod
94
+ def init_train_parser(p: argparse.ArgumentParser):
95
+ g_at = p.add_argument_group('Attack Settings')
96
+ g_at.add_argument('--sur-ids',
97
+ '--surrogate-model-ids',
98
+ dest='surrogate_model_ids',
99
+ type=str,
100
+ default=['resnet152'],
101
+ nargs='+',
102
+ choices=list_surrogates())
103
+ g_at.add_argument('--n-ep',
104
+ '--num-epoch',
105
+ dest='num_epoch',
106
+ type=int,
107
+ default=10)
108
+
109
+ g_optim = p.add_argument_group('Optimization Settings')
110
+ g_optim.add_argument('--use-sam', action='store_true', default=False)
111
+ g_optim.add_argument('--lr', type=float, default=0.0002)
112
+ g_optim.add_argument('--betas',
113
+ type=float,
114
+ nargs=2,
115
+ default=(0.5, 0.999))
116
+
117
+ g_loss = p.add_argument_group('Loss Func Settings')
118
+ g_loss.add_argument('--use-logit-loss',
119
+ action='store_true',
120
+ default=False)
121
+ g_loss.add_argument('--use-logit-kl',
122
+ action='store_true',
123
+ default=False)
124
+ g_loss.add_argument('--use-logit-weights',
125
+ action='store_true',
126
+ default=False)
127
+ g_loss.add_argument('--use-logit-softmax-weights',
128
+ action='store_true',
129
+ default=False)
130
+ g_loss.add_argument('--use-feat-loss',
131
+ action='store_true',
132
+ default=False)
133
+ g_loss.add_argument('--use-feat-attn',
134
+ action='store_true',
135
+ default=False)
136
+
137
+ @staticmethod
138
+ def post_train_parser(args: argparse.Namespace):
139
+ if args.command == 'train':
140
+ assert args.use_logit_loss ^ args.use_feat_loss
141
+ if args.use_logit_kl:
142
+ assert not args.use_feat_loss
143
+ if args.use_feat_attn:
144
+ assert not args.use_logit_loss
145
+ if args.use_logit_weights:
146
+ assert args.use_logit_loss
147
+ if args.use_logit_softmax_weights:
148
+ assert args.use_logit_loss
149
+
150
+ @staticmethod
151
+ def init_evaluate_parser(p: argparse.ArgumentParser):
152
+ pass
153
+
154
+ @staticmethod
155
+ def post_evaluate_parser(args: argparse.Namespace):
156
+ pass
157
+
158
+ @staticmethod
159
+ def init_evaluate_pgd_parser(p: argparse.ArgumentParser):
160
+ g_at = p.add_argument_group('Attack Settings')
161
+ g_at.add_argument('--surrogate-model-ids',
162
+ type=str,
163
+ default=['resnet152'],
164
+ nargs='+',
165
+ choices=list_surrogates())
166
+
167
+ g_optim = p.add_argument_group('Optimization Settings')
168
+ g_optim.add_argument('--num-step', type=int, default=100)
169
+ g_optim.add_argument('--alpha',
170
+ type=int,
171
+ default=2,
172
+ choices=[1, 2, 4, 8, 16])
173
+
174
+ g_loss = p.add_argument_group('Loss Func Settings')
175
+ g_loss.add_argument('--use-loss-avg',
176
+ action='store_true',
177
+ default=False)
178
+ g_loss.add_argument('--use-logit-avg',
179
+ action='store_true',
180
+ default=False)
181
+
182
+ @staticmethod
183
+ def post_evaluate_pgd_parser(args: argparse.Namespace):
184
+ if args.command == 'evaluate-pgd':
185
+ assert args.use_loss_avg ^ args.use_logit_avg
186
+
187
+ @staticmethod
188
+ def parse_args():
189
+ p = argparse.ArgumentParser()
190
+ CLIParser.init_basic_parser(p)
191
+ sub_p = p.add_subparsers(dest='command')
192
+
193
+ CLIParser.init_train_parser(sub_p.add_parser('train'))
194
+ CLIParser.init_evaluate_parser(sub_p.add_parser('evaluate'))
195
+ CLIParser.init_evaluate_pgd_parser(sub_p.add_parser('evaluate-pgd'))
196
+ args = p.parse_args()
197
+ CLIParser.post_train_parser(args)
198
+ CLIParser.post_evaluate_parser(args)
199
+ CLIParser.post_evaluate_pgd_parser(args)
200
+
201
+ CLIParser.post_basic_parser(args)
202
+
203
+ return args
204
+
205
+
206
+ def init_loader(dataset: str,
207
+ data_root: Union[str, Path],
208
+ num_epoch: int = 1,
209
+ batch_size: int = 16,
210
+ command: str = 'train') -> List[torch.utils.data.DataLoader]:
211
+ ds = build_dataset(dataset,
212
+ data_root=data_root,
213
+ is_train=(command == 'train'))
214
+ dataloader = torch.utils.data.DataLoader(
215
+ ds,
216
+ batch_size=batch_size,
217
+ sampler=torch.utils.data.RandomSampler(ds,
218
+ replacement=True,
219
+ num_samples=len(ds) *
220
+ num_epoch),
221
+ num_workers=4,
222
+ pin_memory=True,
223
+ )
224
+ normalizer = norm(dataset, _callable=True)
225
+ return dataloader, normalizer
226
+
227
+
228
+ def init_models(model_ids: Union[str, List[str]],
229
+ device: Union[str, torch.device] = torch.device('cuda')):
230
+ if isinstance(model_ids, str):
231
+ model_ids = [model_ids]
232
+ models = [
233
+ build_surrogate(_surrogate_id, pretrain=True).to(device)
234
+ for _surrogate_id in model_ids
235
+ ]
236
+ for _ in models:
237
+ _.eval()
238
+ return models
239
+
240
+
241
+ def calc_loss(x_nat: torch.Tensor,
242
+ y_nat: torch.Tensor,
243
+ x_adv: torch.Tensor,
244
+ feat_collecter: List,
245
+ surrogate_models: List[torch.nn.Module],
246
+ normalizer: torchvision.transforms.Compose,
247
+ use_logit_loss: bool,
248
+ use_logit_kl: bool,
249
+ use_logit_weights: bool,
250
+ use_logit_softmax_weights: bool,
251
+ use_feat_loss: bool,
252
+ use_feat_attn: bool,
253
+ device: Union[str, torch.device] = torch.device('cuda')):
254
+ loss_sur = []
255
+ for surrogate_model in surrogate_models:
256
+ logit_nat = surrogate_model(normalizer(x_nat))
257
+ feat_nat = feat_collecter.pop()
258
+ logit_adv = surrogate_model(normalizer(x_adv))
259
+ feat_adv = feat_collecter.pop()
260
+ if use_logit_loss:
261
+ if use_logit_kl:
262
+ loss_sur.append(-(F.kl_div(F.log_softmax(logit_adv, dim=1),
263
+ F.softmax(logit_nat, dim=1)) +
264
+ F.kl_div(F.log_softmax(logit_nat, dim=1),
265
+ F.softmax(logit_adv, dim=1))))
266
+ else:
267
+ loss_sur.append(-(F.cross_entropy(logit_adv, y_nat).mean()))
268
+ elif use_feat_loss:
269
+ if use_feat_attn:
270
+ attn = torch.abs(torch.mean(feat_nat, dim=1, keepdim=True))
271
+ else:
272
+ attn = torch.ones_like(feat_nat)
273
+ loss_sur.append(1 + F.cosine_similarity(attn * feat_nat, attn *
274
+ feat_adv).mean())
275
+ else:
276
+ raise NotImplementedError
277
+ loss_sur = torch.stack(loss_sur)
278
+ if use_logit_weights:
279
+ if use_logit_softmax_weights:
280
+ loss_weights = torch.nn.functional.softmax(loss_sur)
281
+ else:
282
+ loss_weights = torch.nn.functional.softmin(loss_sur)
283
+ loss_all = torch.sum(loss_weights * loss_sur)
284
+ else:
285
+ loss_all = loss_sur.mean()
286
+
287
+ return loss_all
288
+
289
+
290
+ def train(surrogate_model_ids: Union[str, List[str]],
291
+ epsilon: float = 16.0 / 255.0,
292
+ num_epoch: int = 10,
293
+ dataset: str = 'imagenet',
294
+ batch_size: int = 16,
295
+ use_sam: bool = False,
296
+ lr: float = 0.0002,
297
+ betas: Union[float, List[float]] = (0.5, 0.999),
298
+ use_logit_loss: bool = False,
299
+ use_logit_kl: bool = False,
300
+ use_logit_weights: bool = False,
301
+ use_logit_softmax_weights: bool = False,
302
+ use_feat_loss: bool = False,
303
+ use_feat_attn: bool = False,
304
+ device: Union[str, torch.device] = torch.device('cuda'),
305
+ workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
306
+ data_root: Union[str,
307
+ Path] = Path(__file__).parent / '../data' / 'in_1k',
308
+ tf_logger: SummaryWriter = None) -> None:
309
+ """
310
+ Train the attack model with the given surrogate models.
311
+ """
312
+ loader, normalizer = init_loader(dataset, data_root, num_epoch, batch_size,
313
+ 'train')
314
+ surrogate_models = init_models(surrogate_model_ids, device)
315
+
316
+ attack = CDAAttack(device=device, epsilon=epsilon)
317
+ attack.set_mode('train')
318
+ if use_sam:
319
+ optim = SAM(attack.get_params(), torch.optim.Adam, lr=lr, betas=betas)
320
+ else:
321
+ optim = torch.optim.Adam(attack.get_params(), lr=lr, betas=betas)
322
+
323
+ with feat_col(surrogate_models,
324
+ [midlayer_dict[_]
325
+ for _ in surrogate_model_ids]) as feat_collecter:
326
+ attack.set_mode('train')
327
+ enumerator = tqdm(enumerate(loader), total=len(loader), desc='')
328
+ for step, (x_nat, y_nat) in enumerator:
329
+ x_nat, y_nat = x_nat.to(device), y_nat.to(device)
330
+
331
+ if use_sam:
332
+ # 1
333
+ enable_running_stats(attack.get_model())
334
+ loss_v = calc_loss(x_nat, y_nat, attack(x_nat), feat_collecter,
335
+ surrogate_models, normalizer,
336
+ use_logit_loss, use_logit_kl,
337
+ use_logit_weights,
338
+ use_logit_softmax_weights, use_feat_loss,
339
+ use_feat_attn, device)
340
+ loss_v.backward()
341
+ optim.first_step(zero_grad=True)
342
+ # 2
343
+ disable_running_stats(attack.get_model())
344
+ calc_loss(x_nat, y_nat, attack(x_nat), feat_collecter,
345
+ surrogate_models, normalizer, use_logit_loss,
346
+ use_logit_kl, use_logit_weights,
347
+ use_logit_softmax_weights, use_feat_loss,
348
+ use_feat_attn, device).backward()
349
+ optim.second_step(zero_grad=True)
350
+ else:
351
+ x_adv = attack(x_nat)
352
+ loss_v = calc_loss(x_nat, y_nat, x_adv, feat_collecter,
353
+ surrogate_models, normalizer,
354
+ use_logit_loss, use_logit_kl,
355
+ use_logit_weights,
356
+ use_logit_softmax_weights, use_feat_loss,
357
+ use_feat_attn, device)
358
+ optim.zero_grad()
359
+ loss_v.backward()
360
+ optim.step()
361
+
362
+ if tf_logger:
363
+ tf_logger.add_scalar('loss', loss_v.item(), step)
364
+ tf_logger.add_scalar('lr', optim.param_groups[0]['lr'], step)
365
+
366
+ attack.save_ckpt(workdir / 'model.pth')
367
+
368
+
369
+ @torch.no_grad()
370
+ def evaluate(
371
+ ckpt: Union[str, Path],
372
+ epsilon: float = 16.0 / 255.0,
373
+ dataset: str = 'imagenet',
374
+ batch_size: int = 16,
375
+ device: Union[str, torch.device] = torch.device('cuda'),
376
+ workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
377
+ data_root: Union[str, Path] = Path(__file__).parent / '../data' / 'in_1k',
378
+ ) -> None:
379
+ """
380
+ Evaluate the attack model with the given surrogate models
381
+ """
382
+ loader, normalizer = init_loader(dataset, data_root, 1, batch_size,
383
+ 'evaluate')
384
+ target_models = {
385
+ k: v
386
+ for k, v in zip(list_surrogates(),
387
+ init_models(list_surrogates(), device))
388
+ }
389
+ target_acc_meters = {
390
+ target_model_id: [AverageMeter() for _ in range(2)]
391
+ for target_model_id in target_models.keys()
392
+ }
393
+ # init attack method
394
+ attack = CDAAttack(device=device, epsilon=epsilon)
395
+ attack.load_ckpt(ckpt)
396
+ attack.set_mode('eval')
397
+ # evaluate
398
+ enumerator = tqdm(enumerate(loader), total=len(loader), desc='Eval')
399
+ for step, (x_nat, y_nat) in enumerator:
400
+ x_nat, y_nat = x_nat.to(device), y_nat.to(device)
401
+ x_adv = attack(x_nat)
402
+ for target_model_id, target_model in target_models.items():
403
+ logit_nat = target_model(normalizer(x_nat))
404
+ logit_adv = target_model(normalizer(x_adv))
405
+ # collect metrics
406
+ target_acc = calc_cls_accuracy(logit_nat, y_nat)
407
+ target_asr = calc_cls_accuracy(logit_adv, y_nat)
408
+ target_acc_meters[target_model_id][0].update(
409
+ target_acc[0].item(), x_nat.size(0))
410
+ target_acc_meters[target_model_id][1].update(
411
+ target_asr[0].item(), x_nat.size(0))
412
+ results = {
413
+ target_model_id: {
414
+ 'nat_acc': target_acc_meter[0].avg,
415
+ 'adv_acc': target_acc_meter[1].avg
416
+ }
417
+ for target_model_id, target_acc_meter in target_acc_meters.items()
418
+ }
419
+ print(pformat(results))
420
+ with open(workdir / 'results.json', 'w') as f:
421
+ json.dump(results, f)
422
+
423
+
424
+ def evaluate_pgd(
425
+ surrogate_model_ids: Union[str, List[str]],
426
+ epsilon: float = 16.0 / 255.0,
427
+ num_step: int = 1000,
428
+ alpha: float = 2.0 / 255.0,
429
+ dataset: str = 'imagenet',
430
+ batch_size: int = 16,
431
+ use_loss_avg: bool = False,
432
+ use_logit_avg: bool = False,
433
+ device: Union[str, torch.device] = torch.device('cuda'),
434
+ workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
435
+ data_root: Union[str, Path] = Path(__file__).parent / '../data' / 'in_1k',
436
+ ):
437
+ loader, normalizer = init_loader(dataset, data_root, 1, batch_size,
438
+ 'evaluate')
439
+ surrogate_models = init_models(surrogate_model_ids, device)
440
+ target_models = {
441
+ k: v
442
+ for k, v in zip(list_surrogates(),
443
+ init_models(list_surrogates(), device))
444
+ }
445
+ target_acc_meters = {
446
+ target_model_id: [AverageMeter() for _ in range(2)]
447
+ for target_model_id in target_models.keys()
448
+ }
449
+ # evaluate
450
+ enumerator = tqdm(enumerate(loader), total=len(loader), desc='')
451
+ for step, (x_nat, y_nat) in enumerator:
452
+ x_nat, y_nat = x_nat.to(device), y_nat.to(device)
453
+ # attack
454
+ x_nat_ori = x_nat.data
455
+ for _ in range(num_step):
456
+ x_nat.requires_grad = True
457
+ if use_loss_avg:
458
+ loss_all = 0.0
459
+ for surrogate_model in surrogate_models:
460
+ logit = surrogate_model(x_nat)
461
+ surrogate_model.zero_grad()
462
+ loss_all += F.cross_entropy(logit, y_nat)
463
+ elif use_logit_avg:
464
+ logit = torch.stack([
465
+ surrogate_model(x_nat)
466
+ for surrogate_model in surrogate_models
467
+ ]).mean(dim=0)
468
+ loss_all = F.cross_entropy(logit, y_nat)
469
+ else:
470
+ raise NotADirectoryError
471
+ loss_all.backward()
472
+ x_adv_ = x_nat + alpha * x_nat.grad.sign()
473
+ eta = torch.clamp(x_adv_ - x_nat_ori, min=-epsilon, max=epsilon)
474
+ x_nat = torch.clamp(x_nat_ori + eta, min=0.0, max=1.0).detach_()
475
+ x_adv = x_nat
476
+ x_nat = x_nat_ori
477
+ # eval
478
+ with torch.no_grad():
479
+ for target_model_id, target_model in target_models.items():
480
+ logit_nat = target_model(normalizer(x_nat))
481
+ logit_adv = target_model(normalizer(x_adv))
482
+ # collect
483
+ target_acc_ = calc_cls_accuracy(logit_nat, y_nat)
484
+ target_asr_ = calc_cls_accuracy(logit_adv, y_nat)
485
+ target_acc_meters[target_model_id][0].update(
486
+ target_acc_[0].item(), x_nat.size(0))
487
+ target_acc_meters[target_model_id][1].update(
488
+ target_asr_[0].item(), x_nat.size(0))
489
+ results = {
490
+ target_model_id: {
491
+ 'nat_acc': target_acc_meter[0].avg,
492
+ 'adv_acc': target_acc_meter[1].avg
493
+ }
494
+ for target_model_id, target_acc_meter in target_acc_meters.items()
495
+ }
496
+ print(pformat(results))
497
+ with open(workdir / 'results-pgd.json', 'w') as f:
498
+ json.dump(results, f)
499
+
500
+
501
+ def main() -> None:
502
+ args = CLIParser.parse_args()
503
+ if args.command == 'train':
504
+ train(args.surrogate_model_ids, args.epsilon, args.num_epoch,
505
+ args.dataset, args.batch_size, args.use_sam, args.lr, args.betas,
506
+ args.use_logit_loss, args.use_logit_kl, args.use_logit_weights,
507
+ args.use_logit_softmax_weights, args.use_feat_loss,
508
+ args.use_feat_attn, args.device, args.workdir, args.data_root,
509
+ args.tf_logger)
510
+ elif args.command == 'evaluate':
511
+ evaluate(args.ckpt, args.epsilon, args.dataset, args.batch_size,
512
+ args.device, args.workdir, args.data_root)
513
+ elif args.command == 'evaluate-pgd':
514
+ evaluate_pgd(args.surrogate_model_ids, args.epsilon, args.num_step,
515
+ args.alpha, args.dataset, args.batch_size,
516
+ args.use_loss_avg, args.use_logit_avg, args.device,
517
+ args.workdir, args.data_root)
518
+ else:
519
+ raise NotImplementedError
520
+
521
+
522
+ if __name__ == '__main__':
523
+ main()
attacks/AIM/examples/workdirs/1MsK/args.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {'batch_size': 16,
2
+ 'ckpt': WindowsPath('workdirs/1MsK/model.pth'),
3
+ 'command': None,
4
+ 'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
5
+ 'dataset': 'imagenet',
6
+ 'device': device(type='cuda'),
7
+ 'expid': '1MsK',
8
+ 'seed': 0,
9
+ 'tar_classes': 24,
10
+ 'verbose': False,
11
+ 'workdir': WindowsPath('workdirs/1MsK')}
attacks/AIM/examples/workdirs/CGKQ/args.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {'batch_size': 16,
2
+ 'ckpt': WindowsPath('workdirs/CGKQ/model.pth'),
3
+ 'command': None,
4
+ 'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
5
+ 'dataset': 'imagenet',
6
+ 'device': device(type='cuda'),
7
+ 'expid': 'CGKQ',
8
+ 'seed': 0,
9
+ 'tar_classes': 24,
10
+ 'verbose': False,
11
+ 'workdir': WindowsPath('workdirs/CGKQ')}
attacks/AIM/examples/workdirs/Fvu1/args.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {'batch_size': 16,
2
+ 'ckpt': WindowsPath('workdirs/Fvu1/model.pth'),
3
+ 'command': None,
4
+ 'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
5
+ 'dataset': 'imagenet',
6
+ 'device': device(type='cuda'),
7
+ 'expid': 'Fvu1',
8
+ 'seed': 0,
9
+ 'tar_classes': 24,
10
+ 'verbose': False,
11
+ 'workdir': WindowsPath('workdirs/Fvu1')}
attacks/AIM/examples/workdirs/Qonx/args.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {'batch_size': 16,
2
+ 'ckpt': WindowsPath('workdirs/Qonx/model.pth'),
3
+ 'command': None,
4
+ 'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
5
+ 'dataset': 'imagenet',
6
+ 'device': device(type='cuda'),
7
+ 'expid': 'Qonx',
8
+ 'seed': 0,
9
+ 'tar_classes': 24,
10
+ 'verbose': False,
11
+ 'workdir': WindowsPath('workdirs/Qonx')}
attacks/AIM/examples/workdirs/fnNs/args.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {'batch_size': 16,
2
+ 'ckpt': WindowsPath('workdirs/fnNs/model.pth'),
3
+ 'command': None,
4
+ 'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
5
+ 'dataset': 'imagenet',
6
+ 'device': device(type='cuda'),
7
+ 'expid': 'fnNs',
8
+ 'seed': 0,
9
+ 'tar_classes': 24,
10
+ 'verbose': False,
11
+ 'workdir': WindowsPath('workdirs/fnNs')}
attacks/AIM/examples/workdirs/jtMb/args.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {'batch_size': 16,
2
+ 'ckpt': WindowsPath('workdirs/jtMb/model.pth'),
3
+ 'command': None,
4
+ 'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
5
+ 'dataset': 'imagenet',
6
+ 'device': device(type='cuda'),
7
+ 'expid': 'jtMb',
8
+ 'seed': 0,
9
+ 'tar_classes': 24,
10
+ 'verbose': False,
11
+ 'workdir': WindowsPath('workdirs/jtMb')}
attacks/AIM/examples/workdirs/krhX/args.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {'batch_size': 16,
2
+ 'ckpt': WindowsPath('workdirs/krhX/model.pth'),
3
+ 'command': None,
4
+ 'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
5
+ 'dataset': 'imagenet',
6
+ 'device': device(type='cuda'),
7
+ 'expid': 'krhX',
8
+ 'seed': 0,
9
+ 'tar_classes': 24,
10
+ 'verbose': False,
11
+ 'workdir': WindowsPath('workdirs/krhX')}
attacks/AIM/setup.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ from setuptools import find_packages, setup
7
+
8
+ try:
9
+ sys.path.append(str(Path(__file__).parent / 'src'))
10
+ import gat
11
+ except ImportError:
12
+ print('Please install aim_attack package')
13
+ sys.exit(1)
14
+
15
+
16
+ def read(path: Union[str, Path]) -> str:
17
+ with open(path, 'r') as f:
18
+ return f.read()
19
+
20
+
21
+ setup(name='pygat',
22
+ version=gat.VERSION,
23
+ description='GAT: Generative Attack Toolbox',
24
+ long_description=read('README.md'),
25
+ long_description_content_type='text/markdown',
26
+ author='Terry Li',
27
+ url='https://terrytengli.com/GAT/',
28
+ classifiers=[
29
+ 'Programming Language :: Python :: 3',
30
+ 'License :: OSI Approved :: MIT License'
31
+ ],
32
+ python_requires='>=3.10',
33
+ packages=find_packages('src'),
34
+ package_dir={'': 'src'},
35
+ entry_points={
36
+ 'console_scripts': [
37
+ 'aim-api=gat.runtime.api.aim_attack:main',
38
+ ],
39
+ },
40
+ install_requires=[
41
+ 'tqdm', 'tabulate', 'torch', 'torchvision', 'tensorboard', 'jupyter'
42
+ ],
43
+ extras_require={
44
+ 'cli': ['python-multipart', 'fastapi', 'uvicorn'],
45
+ 'ens-gen': []
46
+ },
47
+ include_package_data=True)
attacks/AIM/src/gat/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ VERSION = '202411.2'
attacks/AIM/src/gat/datasets/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .builder import build_dataset, list_datasets
2
+ from .cub import cub
3
+ from .imagenet import imagenet
4
+
5
+ __all__ = ['build_dataset', 'list_datasets', 'imagenet', 'cub']
attacks/AIM/src/gat/datasets/builder.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..runtime.factory import Registry
2
+
3
+ DATASET_REGISTRY = Registry('DATASET')
4
+
5
+
6
+ def build_dataset(_type: str, *args, **kwargs) -> object:
7
+ return DATASET_REGISTRY.get(_type)(*args, **kwargs)
8
+
9
+
10
+ def list_datasets():
11
+ return list(DATASET_REGISTRY._obj_map.keys())
attacks/AIM/src/gat/datasets/cub.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .builder import DATASET_REGISTRY
2
+
3
+
4
+ @DATASET_REGISTRY.register()
5
+ def cub():
6
+ raise NotADirectoryError
attacks/AIM/src/gat/datasets/env.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ CIFAR10_DEFAULT_MEAN = (0.4914, 0.4822, 0.4465)
2
+ CIFAR10_DEFAULT_STD = (0.2470, 0.2435, 0.2616)
3
+ CIFAR100_DEFAULT_MEAN = (0.5071, 0.4865, 0.4409)
4
+ CIFAR100_DEFAULT_STD = (0.2673, 0.2564, 0.2762)
5
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
6
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
attacks/AIM/src/gat/datasets/imagenet.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import List, Union
3
+
4
+ import torch
5
+ import torchvision
6
+
7
+ from .builder import DATASET_REGISTRY
8
+ from .transforms import resize_256_224, to_color, to_ts
9
+
10
+
11
+ @DATASET_REGISTRY.register()
12
+ def imagenet(
13
+ data_root: Union[str, Path],
14
+ is_train: bool = True,
15
+ filter_class: Union[int, List[int]] = None,
16
+ ) -> torch.utils.data.Dataset:
17
+ if isinstance(data_root, str):
18
+ data_root = Path(data_root)
19
+ if is_train:
20
+ data_root = data_root / 'train'
21
+ else:
22
+ data_root = data_root / 'val'
23
+
24
+ _transforms = resize_256_224() + to_color() + to_ts()
25
+
26
+ _ds = torchvision.datasets.ImageFolder(
27
+ data_root,
28
+ transform=torchvision.transforms.Compose(_transforms),
29
+ )
30
+
31
+ if isinstance(filter_class, int):
32
+ filter_class = [filter_class]
33
+ if filter_class:
34
+ _ds.samples = list(filter(lambda x: x[1] in filter_class, _ds.samples))
35
+
36
+ return _ds
attacks/AIM/src/gat/datasets/transforms.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torchvision
4
+
5
+ from . import env
6
+
7
+
8
+ def resize_256_224() -> List:
9
+ return [
10
+ torchvision.transforms.Resize(size=256),
11
+ torchvision.transforms.CenterCrop(size=(224, 224)),
12
+ ]
13
+
14
+
15
+ def resize_512_448() -> List:
16
+ return [
17
+ torchvision.transforms.Resize(size=512),
18
+ torchvision.transforms.CenterCrop(size=(448, 448)),
19
+ ]
20
+
21
+
22
+ def resize_224() -> List:
23
+ return [torchvision.transforms.Resize(size=224)]
24
+
25
+
26
+ def hflip(p: float = 0.5) -> List:
27
+ assert 0 <= p <= 1
28
+ return [torchvision.transforms.RandomHorizontalFlip(p)]
29
+
30
+
31
+ def to_ts() -> List:
32
+ return [torchvision.transforms.ToTensor()]
33
+
34
+
35
+ def to_pil() -> List:
36
+ return [torchvision.transforms.ToPILImage()]
37
+
38
+
39
+ def to_color() -> List:
40
+ return [
41
+ torchvision.transforms.Lambda(lambda x: x.convert('RGB')
42
+ if x.mode != 'RGB' else x)
43
+ ]
44
+
45
+
46
+ def norm(dataset: str = 'IMAGENET', _callable: bool = False) -> List:
47
+ dataset = dataset.upper()
48
+ mean_std = (
49
+ getattr(env, dataset + '_DEFAULT_MEAN'),
50
+ getattr(env, dataset + '_DEFAULT_STD'),
51
+ )
52
+ transforms = [torchvision.transforms.Normalize(*mean_std)]
53
+ if _callable:
54
+ return torchvision.transforms.Compose(transforms)
55
+ return transforms
attacks/AIM/src/gat/models/__init__.py ADDED
File without changes
attacks/AIM/src/gat/models/attack/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .aim_attack import AIMAttack
2
+ from .cda_attack import CDAAttack
3
+ from .loss.logits import ContrastiveLoss
4
+
5
+ __all__ = ['AIMAttack', 'CDAAttack', 'ContrastiveLoss']
attacks/AIM/src/gat/models/attack/aim_attack.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from .base_attack import BaseGenerativeAttack
4
+ from .generator.aim import AIMGenerator
5
+
6
+
7
+ class AIMAttack(BaseGenerativeAttack):
8
+
9
+ def set_adv_gen(self):
10
+ self.adv_gen = AIMGenerator().to(self.device)
11
+
12
+ def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
13
+ x_guid = extra_inputs[0].to(self.device)
14
+ return self.adv_gen(x_nat, x_guid)
attacks/AIM/src/gat/models/attack/base_attack.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from collections import OrderedDict
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+
9
+ class BaseGenerativeAttack(abc.ABC):
10
+
11
+ def __init__(self,
12
+ device: Union[str, torch.device],
13
+ epsilon: float = 32 / 255) -> None:
14
+ if isinstance(device, str):
15
+ device = torch.device(device)
16
+ self.device = device
17
+ self.set_adv_gen()
18
+ self.set_mode('eval')
19
+ self.epsilon = epsilon
20
+
21
+ @abc.abstractmethod
22
+ def set_adv_gen(self):
23
+ pass
24
+
25
+ def load_ckpt(self, ckpt: Union[str, Path, OrderedDict]) -> None:
26
+ if isinstance(ckpt, str):
27
+ ckpt = Path(ckpt)
28
+ if isinstance(ckpt, Path):
29
+ if not ckpt.exists():
30
+ raise FileNotFoundError(f'File not found: {ckpt}')
31
+ ckpt = torch.load(ckpt, map_location=self.device)
32
+ self.adv_gen.load_state_dict(ckpt)
33
+ self.adv_gen.to(self.device)
34
+
35
+ def save_ckpt(self, ckpt: Union[str, Path]) -> None:
36
+ if isinstance(ckpt, str):
37
+ ckpt = Path(ckpt)
38
+ _adv_gen_cpu = self.adv_gen.to('cpu')
39
+ torch.save(_adv_gen_cpu.state_dict(), ckpt)
40
+
41
+ def get_params(self) -> torch.nn.Parameter:
42
+ return self.adv_gen.parameters()
43
+
44
+ def get_model(self) -> torch.nn.Module:
45
+ return self.adv_gen
46
+
47
+ def set_mode(self, mode: str) -> None:
48
+ assert mode in ['train', 'eval']
49
+ self.adv_gen.train() if mode == 'train' else self.adv_gen.eval()
50
+
51
+ @abc.abstractmethod
52
+ def attack(self, *args) -> torch.Tensor:
53
+ pass
54
+
55
+ def __call__(self, x_nat: torch.Tensor, *extra_inputs) -> torch.Tensor:
56
+ x_adv = self.attack(x_nat, *extra_inputs)
57
+ x_adv = torch.min(torch.max(x_adv, x_nat - self.epsilon),
58
+ x_nat + self.epsilon)
59
+ torch.clamp_(x_adv, 0.0, 1.0)
60
+ return x_adv
attacks/AIM/src/gat/models/attack/cda_attack.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from .base_attack import BaseGenerativeAttack
4
+ from .generator.cda import CDAGenerator
5
+
6
+
7
+ class CDAAttack(BaseGenerativeAttack):
8
+
9
+ def set_adv_gen(self):
10
+ self.adv_gen = CDAGenerator().to(self.device)
11
+
12
+ def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
13
+ return self.adv_gen(x_nat)
attacks/AIM/src/gat/models/attack/generator/__init__.py ADDED
File without changes
attacks/AIM/src/gat/models/attack/generator/aim.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class EnhancedBN(nn.Module):
6
+ def __init__(self, nc: int, sty_nc: int = 3, sty_nhidden: int = 128):
7
+ super(EnhancedBN, self).__init__()
8
+ self.bn = nn.BatchNorm2d(nc)
9
+ self.mapping = nn.Conv2d(
10
+ in_channels=sty_nc,
11
+ out_channels=sty_nhidden,
12
+ kernel_size=3,
13
+ padding=1,
14
+ stride=1,
15
+ )
16
+ self.gamma = nn.Conv2d(
17
+ in_channels=sty_nhidden,
18
+ out_channels=nc,
19
+ kernel_size=3,
20
+ padding=1,
21
+ stride=1,
22
+ )
23
+ self.beta = nn.Conv2d(
24
+ in_channels=sty_nhidden,
25
+ out_channels=nc,
26
+ kernel_size=3,
27
+ padding=1,
28
+ stride=1,
29
+ )
30
+ self.init_weight()
31
+
32
+ def init_weight(self):
33
+ nn.init.kaiming_normal_(self.mapping.weight)
34
+ nn.init.kaiming_normal_(self.gamma.weight)
35
+ nn.init.kaiming_normal_(self.beta.weight)
36
+
37
+ def forward(self, base, sty):
38
+ bn = self.bn(base)
39
+ sty_resized = torch.nn.functional.interpolate(
40
+ sty, size=bn.size()[2:], mode='bilinear'
41
+ )
42
+ actv = torch.nn.functional.relu(self.mapping(sty_resized))
43
+ # style injection
44
+ bn = bn * (1 + self.gamma(actv)) + self.beta(actv)
45
+ return bn
46
+
47
+
48
+ class ResidualBlock(nn.Module):
49
+ def __init__(self, num_filters):
50
+ super(ResidualBlock, self).__init__()
51
+ self.block1 = nn.Sequential(
52
+ nn.ReflectionPad2d(1),
53
+ nn.Conv2d(
54
+ in_channels=num_filters,
55
+ out_channels=num_filters,
56
+ kernel_size=3,
57
+ stride=1,
58
+ padding=0,
59
+ bias=False,
60
+ ),
61
+ )
62
+ self.bn1 = EnhancedBN(num_filters)
63
+ self.block2 = nn.Sequential(
64
+ nn.ReLU(True),
65
+ nn.Dropout(0.5),
66
+ nn.ReflectionPad2d(1),
67
+ nn.Conv2d(
68
+ in_channels=num_filters,
69
+ out_channels=num_filters,
70
+ kernel_size=3,
71
+ stride=1,
72
+ padding=0,
73
+ bias=False,
74
+ ),
75
+ )
76
+ self.bn2 = EnhancedBN(num_filters)
77
+
78
+ def forward(self, x, sty):
79
+ residual = self.block1(x)
80
+ residual = self.bn1(residual, sty)
81
+ residual = self.block2(residual)
82
+ residual = self.bn2(residual, sty)
83
+ return x + residual
84
+
85
+
86
+ ngf = 64
87
+
88
+
89
+ class ResNetGenerator(nn.Module):
90
+ def __init__(self):
91
+ super(ResNetGenerator, self).__init__()
92
+ self.block1 = nn.Sequential(
93
+ nn.ReflectionPad2d(3),
94
+ nn.Conv2d(3, ngf, kernel_size=7, padding=0, bias=False),
95
+ )
96
+ self.bn1 = EnhancedBN(ngf)
97
+ # Input size = 3, n, n
98
+ self.block2 = nn.Sequential(
99
+ nn.Conv2d(
100
+ ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False
101
+ ),
102
+ )
103
+ self.bn2 = EnhancedBN(ngf * 2)
104
+ # Input size = 3, n/2, n/2
105
+ self.block3 = nn.Sequential(
106
+ nn.Conv2d(
107
+ ngf * 2,
108
+ ngf * 4,
109
+ kernel_size=3,
110
+ stride=2,
111
+ padding=1,
112
+ bias=False,
113
+ ),
114
+ )
115
+ self.bn3 = EnhancedBN(ngf * 4)
116
+ # Input size = 3, n/4, n/4
117
+ # Residual Blocks: 6
118
+ self.resblock1 = ResidualBlock(ngf * 4)
119
+ self.resblock2 = ResidualBlock(ngf * 4)
120
+ self.resblock3 = ResidualBlock(ngf * 4)
121
+ self.resblock4 = ResidualBlock(ngf * 4)
122
+ self.resblock5 = ResidualBlock(ngf * 4)
123
+ self.resblock6 = ResidualBlock(ngf * 4)
124
+ # Input size = 3, n/4, n/4
125
+ self.upsampl1 = nn.ConvTranspose2d(
126
+ ngf * 4,
127
+ ngf * 2,
128
+ kernel_size=3,
129
+ stride=2,
130
+ padding=1,
131
+ output_padding=1,
132
+ bias=False,
133
+ )
134
+ self.ubn1 = EnhancedBN(ngf * 2)
135
+ # Input size = 3, n/2, n/2
136
+ self.upsampl2 = nn.ConvTranspose2d(
137
+ ngf * 2,
138
+ ngf,
139
+ kernel_size=3,
140
+ stride=2,
141
+ padding=1,
142
+ output_padding=1,
143
+ bias=False,
144
+ )
145
+ self.ubn2 = EnhancedBN(ngf)
146
+ # Input size = 3, n, n
147
+ self.blockf = nn.Sequential(
148
+ nn.ReflectionPad2d(3), nn.Conv2d(ngf, 3, kernel_size=7, padding=0)
149
+ )
150
+
151
+ def forward(self, input, sty):
152
+ x = self.block1(input)
153
+ x = self.bn1(x, sty)
154
+ x = torch.nn.functional.relu(x)
155
+ x = self.block2(x)
156
+ x = self.bn2(x, sty)
157
+ x = torch.nn.functional.relu(x)
158
+ x = self.block3(x)
159
+ x = self.bn3(x, sty)
160
+ x = torch.nn.functional.relu(x)
161
+ # =============================
162
+ x = self.resblock1(x, sty)
163
+ x = self.resblock2(x, sty)
164
+ x = self.resblock3(x, sty)
165
+ x = self.resblock4(x, sty)
166
+ x = self.resblock5(x, sty)
167
+ x = self.resblock6(x, sty)
168
+ # =============================
169
+ x = self.upsampl1(x)
170
+ x = self.ubn1(x, sty)
171
+ x = torch.nn.functional.relu(x)
172
+ x = self.upsampl2(x)
173
+ x = self.ubn2(x, sty)
174
+ x = torch.nn.functional.relu(x)
175
+ x = self.blockf(x)
176
+ return (torch.tanh(x) + 1) / 2
177
+
178
+
179
+ AIMGenerator = ResNetGenerator
attacks/AIM/src/gat/models/attack/generator/cda.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code is copied from
2
+ # github.com/Alibaba-AAIG/Beyond-ImageNet-Attack:generator.py@863b7
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class ResidualBlock(nn.Module):
8
+ def __init__(self, num_filters):
9
+ super(ResidualBlock, self).__init__()
10
+ self.block = nn.Sequential(
11
+ nn.ReflectionPad2d(1),
12
+ nn.Conv2d(
13
+ in_channels=num_filters,
14
+ out_channels=num_filters,
15
+ kernel_size=3,
16
+ stride=1,
17
+ padding=0,
18
+ bias=False,
19
+ ),
20
+ nn.BatchNorm2d(num_filters),
21
+ nn.ReLU(True),
22
+ nn.Dropout(0.5),
23
+ nn.ReflectionPad2d(1),
24
+ nn.Conv2d(
25
+ in_channels=num_filters,
26
+ out_channels=num_filters,
27
+ kernel_size=3,
28
+ stride=1,
29
+ padding=0,
30
+ bias=False,
31
+ ),
32
+ nn.BatchNorm2d(num_filters),
33
+ )
34
+
35
+ def forward(self, x):
36
+ residual = self.block(x)
37
+ return x + residual
38
+
39
+
40
+ ngf = 64
41
+
42
+
43
+ class ResNetGenerator(nn.Module):
44
+ """
45
+ https://github.com/Alibaba-AAIG/Beyond-ImageNet-Attack/blob/863b758ee4f4a6d3d4e7777c5f94f457fa449f73/generator.py#L14
46
+
47
+ Test Case:
48
+ >>> netG = ResNetGenerator()
49
+ >>> test_sample = torch.rand(1, 3, 224, 224)
50
+ >>> print("Generator output:", netG(test_sample).size())
51
+ >>> print(
52
+ >>> "Generator parameters:",
53
+ >>> sum(p.numel() for p in netG.parameters() if p.requires_grad),
54
+ >>> )
55
+ """
56
+
57
+ def __init__(self, inception=False):
58
+ super(ResNetGenerator, self).__init__()
59
+ self.inception = inception
60
+ self.block1 = nn.Sequential(
61
+ nn.ReflectionPad2d(3),
62
+ nn.Conv2d(3, ngf, kernel_size=7, padding=0, bias=False),
63
+ nn.BatchNorm2d(ngf),
64
+ nn.ReLU(True),
65
+ )
66
+ # output: (ngf) x (n) x (n)
67
+ self.block2 = nn.Sequential(
68
+ nn.Conv2d(
69
+ ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False
70
+ ),
71
+ nn.BatchNorm2d(ngf * 2),
72
+ nn.ReLU(True),
73
+ )
74
+ # output: (ngf*2) x (n/2) x (n/2)
75
+ self.block3 = nn.Sequential(
76
+ nn.Conv2d(
77
+ ngf * 2,
78
+ ngf * 4,
79
+ kernel_size=3,
80
+ stride=2,
81
+ padding=1,
82
+ bias=False,
83
+ ),
84
+ nn.BatchNorm2d(ngf * 4),
85
+ nn.ReLU(True),
86
+ )
87
+ # output: (ngf*4) x (n/4) x (n/4)
88
+ self.resblock1 = ResidualBlock(ngf * 4)
89
+ self.resblock2 = ResidualBlock(ngf * 4)
90
+ self.resblock3 = ResidualBlock(ngf * 4)
91
+ self.resblock4 = ResidualBlock(ngf * 4)
92
+ self.resblock5 = ResidualBlock(ngf * 4)
93
+ self.resblock6 = ResidualBlock(ngf * 4)
94
+ # output: (ngf*4) x (n/4) x (n/4)
95
+ self.upsampl1 = nn.Sequential(
96
+ nn.ConvTranspose2d(
97
+ ngf * 4,
98
+ ngf * 2,
99
+ kernel_size=3,
100
+ stride=2,
101
+ padding=1,
102
+ output_padding=1,
103
+ bias=False,
104
+ ),
105
+ nn.BatchNorm2d(ngf * 2),
106
+ nn.ReLU(True),
107
+ )
108
+ # output: (ngf*2) x (n/2) x (n/2)
109
+ self.upsampl2 = nn.Sequential(
110
+ nn.ConvTranspose2d(
111
+ ngf * 2,
112
+ ngf,
113
+ kernel_size=3,
114
+ stride=2,
115
+ padding=1,
116
+ output_padding=1,
117
+ bias=False,
118
+ ),
119
+ nn.BatchNorm2d(ngf),
120
+ nn.ReLU(True),
121
+ )
122
+ # output: (ngf) x (n) x (n)
123
+ self.blockf = nn.Sequential(
124
+ nn.ReflectionPad2d(3), nn.Conv2d(ngf, 3, kernel_size=7, padding=0)
125
+ )
126
+ self.crop = nn.ConstantPad2d((0, -1, -1, 0), 0)
127
+
128
+ def forward(self, input):
129
+ x = self.block1(input)
130
+ x = self.block2(x)
131
+ x = self.block3(x)
132
+ x = self.resblock1(x)
133
+ x = self.resblock2(x)
134
+ x = self.resblock3(x)
135
+ x = self.resblock4(x)
136
+ x = self.resblock5(x)
137
+ x = self.resblock6(x)
138
+ x = self.upsampl1(x)
139
+ x = self.upsampl2(x)
140
+ x = self.blockf(x)
141
+ if self.inception:
142
+ x = self.crop(x)
143
+ return (torch.tanh(x) + 1) / 2
144
+
145
+
146
+ CDAGenerator = ResNetGenerator
attacks/AIM/src/gat/models/attack/loss/__init__.py ADDED
File without changes
attacks/AIM/src/gat/models/attack/loss/logits.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class ContrastiveLoss(torch.nn.Module):
5
+ """
6
+ Contrastive loss
7
+ Adapted from: (OnlineContrastiveLoss)
8
+ https://github.com/adambielski/siamese-triplet/blob/master/losses.py
9
+ Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
10
+ """
11
+
12
+ def __init__(self, margin):
13
+ super(ContrastiveLoss, self).__init__()
14
+ self.margin = margin
15
+
16
+ def forward(self, anchors, negatives, positives):
17
+ anchors = anchors / anchors.norm(dim=-1, keepdim=True)
18
+ negatives = negatives / negatives.norm(dim=-1, keepdim=True)
19
+ positives = positives / positives.norm(dim=-1, keepdim=True)
20
+
21
+ positive_loss = (anchors - positives).pow(2).sum(1)
22
+ negative_loss = torch.nn.functional.relu(
23
+ self.margin - (anchors - negatives).pow(2).sum(1).sqrt()).pow(2)
24
+
25
+ loss = 0.5 * torch.cat([positive_loss, negative_loss], dim=0)
26
+
27
+ return loss.mean()
attacks/AIM/src/gat/models/attack/optim/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .sam import SAM, disable_running_stats, enable_running_stats
2
+
3
+ __all__ = ['SAM', 'enable_running_stats', 'disable_running_stats']
attacks/AIM/src/gat/models/attack/optim/sam.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code is copied from
2
+ # github.com/davda54/sam/sam.py@3c3afdb
3
+ import torch
4
+ from torch.nn.modules.batchnorm import _BatchNorm
5
+
6
+
7
+ class SAM(torch.optim.Optimizer):
8
+
9
+ def __init__(self,
10
+ params,
11
+ base_optimizer,
12
+ rho=0.05,
13
+ adaptive=False,
14
+ **kwargs):
15
+ assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
16
+
17
+ defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
18
+ super(SAM, self).__init__(params, defaults)
19
+
20
+ self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
21
+ self.param_groups = self.base_optimizer.param_groups
22
+ self.defaults.update(self.base_optimizer.defaults)
23
+
24
+ @torch.no_grad()
25
+ def first_step(self, zero_grad=False):
26
+ grad_norm = self._grad_norm()
27
+ for group in self.param_groups:
28
+ scale = group['rho'] / (grad_norm + 1e-12)
29
+
30
+ for p in group['params']:
31
+ if p.grad is None:
32
+ continue
33
+ self.state[p]['old_p'] = p.data.clone()
34
+ e_w = (torch.pow(p, 2)
35
+ if group['adaptive'] else 1.0) * p.grad * scale.to(p)
36
+ p.add_(e_w) # climb to the local maximum "w + e(w)"
37
+
38
+ if zero_grad:
39
+ self.zero_grad()
40
+
41
+ @torch.no_grad()
42
+ def second_step(self, zero_grad=False):
43
+ for group in self.param_groups:
44
+ for p in group['params']:
45
+ if p.grad is None:
46
+ continue
47
+ p.data = self.state[p][
48
+ 'old_p'] # get back to "w" from "w + e(w)"
49
+
50
+ self.base_optimizer.step() # do the actual "sharpness-aware" update
51
+
52
+ if zero_grad:
53
+ self.zero_grad()
54
+
55
+ @torch.no_grad()
56
+ def step(self, closure=None):
57
+ assert closure is not None, \
58
+ 'Sharpness Aware Minimization requires closure, ' \
59
+ 'but it was not provided'
60
+ closure = torch.enable_grad()(
61
+ closure) # the closure should do a full forward-backward pass
62
+
63
+ self.first_step(zero_grad=True)
64
+ closure()
65
+ self.second_step()
66
+
67
+ def _grad_norm(self):
68
+ # put everything on the same device, in case of model parallelism
69
+ shared_device = self.param_groups[0]['params'][0].device
70
+ norm = torch.norm(torch.stack([
71
+ ((torch.abs(p) if group['adaptive'] else 1.0) *
72
+ p.grad).norm(p=2).to(shared_device) for group in self.param_groups
73
+ for p in group['params'] if p.grad is not None
74
+ ]),
75
+ p=2)
76
+ return norm
77
+
78
+ def load_state_dict(self, state_dict):
79
+ super().load_state_dict(state_dict)
80
+ self.base_optimizer.param_groups = self.param_groups
81
+
82
+
83
+ def disable_running_stats(model):
84
+
85
+ def _disable(module):
86
+ if isinstance(module, _BatchNorm):
87
+ module.backup_momentum = module.momentum
88
+ module.momentum = 0
89
+
90
+ model.apply(_disable)
91
+
92
+
93
+ def enable_running_stats(model):
94
+
95
+ def _enable(module):
96
+ if isinstance(module, _BatchNorm) and hasattr(module,
97
+ 'backup_momentum'):
98
+ module.momentum = module.backup_momentum
99
+
100
+ model.apply(_enable)
attacks/AIM/src/gat/models/surrogate/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .builder import build_surrogate, list_surrogates
2
+ from .hooks import feat_col, midlayer_dict, register_collecter, register_collecter_cl
3
+ from .tv import (densenet121, densenet169, inception_v3, resnet50, resnet152,
4
+ swin_b, vgg16, vgg19, vit_b_16, vit_b_32)
5
+
6
+ __all__ = [
7
+ 'build_surrogate', 'list_surrogates', 'inception_v3', 'vgg16', 'vgg19',
8
+ 'resnet50', 'resnet152', 'densenet121', 'densenet169', 'vit_b_16',
9
+ 'vit_b_32', 'swin_b', 'midlayer_dict', 'register_collecter', 'feat_col', 'register_collecter_cl'
10
+ ]
attacks/AIM/src/gat/models/surrogate/builder.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ...runtime.factory import Registry
2
+
3
+ SURROGATE_REGISTRY = Registry('SURROGATE')
4
+
5
+
6
+ def build_surrogate(_type: str, *args, **kwargs) -> object:
7
+ return SURROGATE_REGISTRY.get(_type)(*args, **kwargs)
8
+
9
+
10
+ def list_surrogates():
11
+ return list(SURROGATE_REGISTRY._obj_map.keys())
attacks/AIM/src/gat/models/surrogate/hooks.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ from typing import List, Union
3
+
4
+ import torch
5
+
6
+ midlayer_dict = {
7
+ 'vgg16': 'features.16',
8
+ 'vgg19': 'features.18',
9
+ 'resnet152': 'layer1',
10
+ 'densenet169': 'features.denseblock2',
11
+ 'inception_v3': 'Mixed_6c',
12
+ 'resnet50': 'conv1', # conv1, layer1, layer2, layer3, layer4
13
+ 'resnet50_cl': 'layer4',
14
+ 'resnet32': 'conv_1_3x3', # conv_1_3x3, stage_1, stage_2, stage_3
15
+ 'resnet32_cl': 'conv_1_3x3',
16
+ }
17
+
18
+
19
+ def register_collecter(m: torch.nn.Module, layer: str, feat_collecter: List):
20
+
21
+ def _hook(m, i, o):
22
+ feat_collecter.append(o)
23
+
24
+ _handler = m.get_submodule(layer).register_forward_hook(_hook) #m.convnets[0].get_submodule(layer).register_forward_hook(_hook)
25
+ return _handler, feat_collecter
26
+
27
+ def register_collecter_cl(m: torch.nn.Module, layer: str, feat_collecter: List, cl_methods: str):
28
+
29
+ def _hook(m, i, o):
30
+ feat_collecter.append(o)
31
+
32
+ if cl_methods == 'icarl' or cl_methods == 'finetune' or cl_methods == 'wa' or cl_methods == 'replay' or cl_methods == 'podnet' or cl_methods == 'bic':
33
+ _handler = m.convnet.get_submodule(layer).register_forward_hook(_hook) # For ewc, icarl methods
34
+ elif cl_methods == 'foster' or cl_methods == 'der':
35
+ _handler = m.convnets[0].get_submodule(layer).register_forward_hook(_hook) # For foster and der methods
36
+ elif cl_methods == 'memo':
37
+ _handler = m.TaskAgnosticExtractor.get_submodule(layer).register_forward_hook(_hook) # For foster and der methods
38
+ return _handler, feat_collecter
39
+
40
+
41
+ @contextmanager
42
+ def feat_col(m: Union[torch.nn.Module, List[torch.nn.Module]],
43
+ layer: Union[str, List[str]]):
44
+ if isinstance(m, torch.nn.Module):
45
+ m = [m]
46
+ if isinstance(layer, str):
47
+ layer = [layer]
48
+ assert len(m) == len(layer)
49
+ handlers = []
50
+ feat_collecter = []
51
+ for _m, _layer in zip(m, layer):
52
+ handler, feat_collecter = register_collecter(_m, _layer,
53
+ feat_collecter)
54
+ handlers.append(handler)
55
+ yield feat_collecter
56
+ for handler in handlers:
57
+ handler.remove()
58
+ feat_collecter.clear()
59
+ del feat_collecter
attacks/AIM/src/gat/models/surrogate/tv.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import models as tv_models
2
+
3
+ from .builder import SURROGATE_REGISTRY
4
+
5
+ weights_dict = dict(
6
+ alexnet=tv_models.AlexNet_Weights.IMAGENET1K_V1,
7
+ googlenet=tv_models.GoogLeNet_Weights.IMAGENET1K_V1,
8
+ vgg16=tv_models.VGG16_Weights.IMAGENET1K_V1,
9
+ vgg19=tv_models.VGG19_Weights.IMAGENET1K_V1,
10
+ inception_v3=tv_models.Inception_V3_Weights.IMAGENET1K_V1,
11
+ resnet18=tv_models.ResNet18_Weights.IMAGENET1K_V1,
12
+ resnet34=tv_models.ResNet34_Weights.IMAGENET1K_V1,
13
+ resnet50=tv_models.ResNet50_Weights.IMAGENET1K_V1,
14
+ resnet152=tv_models.ResNet152_Weights.IMAGENET1K_V1,
15
+ wide_resnet50_2=tv_models.Wide_ResNet50_2_Weights.IMAGENET1K_V1,
16
+ wide_resnet101_2=tv_models.Wide_ResNet101_2_Weights.IMAGENET1K_V1,
17
+ densenet121=tv_models.DenseNet121_Weights.IMAGENET1K_V1,
18
+ densenet169=tv_models.DenseNet169_Weights.IMAGENET1K_V1,
19
+ mobilenet_v2=tv_models.MobileNet_V2_Weights.IMAGENET1K_V1,
20
+ mobilenet_v3_small=tv_models.MobileNet_V3_Small_Weights.IMAGENET1K_V1,
21
+ mobilenet_v3_large=tv_models.MobileNet_V3_Large_Weights.IMAGENET1K_V1,
22
+ squeezenet1_0=tv_models.SqueezeNet1_0_Weights.IMAGENET1K_V1,
23
+ squeezenet1_1=tv_models.SqueezeNet1_1_Weights.IMAGENET1K_V1,
24
+ shufflenet_v2_x0_5=tv_models.ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
25
+ shufflenet_v2_x1_0=tv_models.ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1,
26
+ shufflenet_v2_x1_5=tv_models.ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1,
27
+ shufflenet_v2_x2_0=tv_models.ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1,
28
+ efficientnet_b0=tv_models.EfficientNet_B0_Weights.IMAGENET1K_V1,
29
+ efficientnet_v2_s=tv_models.EfficientNet_V2_S_Weights.IMAGENET1K_V1,
30
+ efficientnet_v2_m=tv_models.EfficientNet_V2_M_Weights.IMAGENET1K_V1,
31
+ efficientnet_v2_l=tv_models.EfficientNet_V2_L_Weights.IMAGENET1K_V1,
32
+ vit_b_16=tv_models.ViT_B_16_Weights.IMAGENET1K_V1,
33
+ vit_b_32=tv_models.ViT_B_32_Weights.IMAGENET1K_V1,
34
+ swin_b=tv_models.Swin_B_Weights.IMAGENET1K_V1)
35
+
36
+
37
+ @SURROGATE_REGISTRY.register()
38
+ def alexnet(pretrain=True):
39
+ return tv_models.alexnet(
40
+ weights=weights_dict['alexnet'] if pretrain else None)
41
+
42
+
43
+ @SURROGATE_REGISTRY.register()
44
+ def googlenet(pretrain=True):
45
+ return tv_models.googlenet(
46
+ weights=weights_dict['googlenet'] if pretrain else None)
47
+
48
+
49
+ @SURROGATE_REGISTRY.register()
50
+ def vgg16(pretrain=True):
51
+ return tv_models.vgg16(weights=weights_dict['vgg16'] if pretrain else None)
52
+
53
+
54
+ @SURROGATE_REGISTRY.register()
55
+ def vgg19(pretrain=True):
56
+ return tv_models.vgg19(weights=weights_dict['vgg19'] if pretrain else None)
57
+
58
+
59
+ @SURROGATE_REGISTRY.register()
60
+ def inception_v3(pretrain=True):
61
+ return tv_models.inception_v3(
62
+ weights=weights_dict['inception_v3'] if pretrain else None)
63
+
64
+
65
+ @SURROGATE_REGISTRY.register()
66
+ def resnet18(pretrain=True):
67
+ return tv_models.resnet18(
68
+ weights=weights_dict['resnet18'] if pretrain else None)
69
+
70
+
71
+ @SURROGATE_REGISTRY.register()
72
+ def resnet34(pretrain=True):
73
+ return tv_models.resnet34(
74
+ weights=weights_dict['resnet34'] if pretrain else None)
75
+
76
+
77
+ @SURROGATE_REGISTRY.register()
78
+ def resnet50(pretrain=True):
79
+ return tv_models.resnet50(
80
+ weights=weights_dict['resnet50'] if pretrain else None)
81
+
82
+
83
+ @SURROGATE_REGISTRY.register()
84
+ def resnet152(pretrain=True):
85
+ return tv_models.resnet152(
86
+ weights=weights_dict['resnet152'] if pretrain else None)
87
+
88
+
89
+ @SURROGATE_REGISTRY.register()
90
+ def wide_resnet50_2(pretrain=True):
91
+ return tv_models.wide_resnet50_2(
92
+ weights=weights_dict['wide_resnet50_2'] if pretrain else None)
93
+
94
+
95
+ @SURROGATE_REGISTRY.register()
96
+ def wide_resnet101_2(pretrain=True):
97
+ return tv_models.wide_resnet101_2(
98
+ weights=weights_dict['wide_resnet101_2'] if pretrain else None)
99
+
100
+
101
+ @SURROGATE_REGISTRY.register()
102
+ def densenet121(pretrain=True):
103
+ return tv_models.densenet121(
104
+ weights=weights_dict['densenet121'] if pretrain else None)
105
+
106
+
107
+ @SURROGATE_REGISTRY.register()
108
+ def densenet169(pretrain=True):
109
+ return tv_models.densenet169(
110
+ weights=weights_dict['densenet169'] if pretrain else None)
111
+
112
+
113
+ @SURROGATE_REGISTRY.register()
114
+ def mobilenet_v2(pretrain=True):
115
+ return tv_models.mobilenet_v2(
116
+ weights=weights_dict['mobilenet_v2'] if pretrain else None)
117
+
118
+
119
+ @SURROGATE_REGISTRY.register()
120
+ def mobilenet_v3_small(pretrain=True):
121
+ return tv_models.mobilenet_v3_small(
122
+ weights=weights_dict['mobilenet_v3_small'] if pretrain else None)
123
+
124
+
125
+ @SURROGATE_REGISTRY.register()
126
+ def mobilenet_v3_large(pretrain=True):
127
+ return tv_models.mobilenet_v3_large(
128
+ weights=weights_dict['mobilenet_v3_large'] if pretrain else None)
129
+
130
+
131
+ @SURROGATE_REGISTRY.register()
132
+ def squeezenet1_0(pretrain=True):
133
+ return tv_models.squeezenet1_0(
134
+ weights=weights_dict['squeezenet1_0'] if pretrain else None)
135
+
136
+
137
+ @SURROGATE_REGISTRY.register()
138
+ def squeezenet1_1(pretrain=True):
139
+ return tv_models.squeezenet1_1(
140
+ weights=weights_dict['squeezenet1_1'] if pretrain else None)
141
+
142
+
143
+ @SURROGATE_REGISTRY.register()
144
+ def shufflenet_v2_x0_5(pretrain=True):
145
+ return tv_models.shufflenet_v2_x0_5(
146
+ weights=weights_dict['shufflenet_v2_x0_5'] if pretrain else None)
147
+
148
+
149
+ @SURROGATE_REGISTRY.register()
150
+ def shufflenet_v2_x1_0(pretrain=True):
151
+ return tv_models.shufflenet_v2_x1_0(
152
+ weights=weights_dict['shufflenet_v2_x1_0'] if pretrain else None)
153
+
154
+
155
+ @SURROGATE_REGISTRY.register()
156
+ def shufflenet_v2_x1_5(pretrain=True):
157
+ return tv_models.shufflenet_v2_x1_5(
158
+ weights=weights_dict['shufflenet_v2_x1_5'] if pretrain else None)
159
+
160
+
161
+ @SURROGATE_REGISTRY.register()
162
+ def shufflenet_v2_x2_0(pretrain=True):
163
+ return tv_models.shufflenet_v2_x2_0(
164
+ weights=weights_dict['shufflenet_v2_x2_0'] if pretrain else None)
165
+
166
+
167
+ @SURROGATE_REGISTRY.register()
168
+ def efficientnet_b0(pretrain=True):
169
+ return tv_models.efficientnet_b0(
170
+ weights=weights_dict['efficientnet_b0'] if pretrain else None)
171
+
172
+
173
+ @SURROGATE_REGISTRY.register()
174
+ def efficientnet_v2_s(pretrain=True):
175
+ return tv_models.efficientnet_v2_s(
176
+ weights=weights_dict['efficientnet_v2_s'] if pretrain else None)
177
+
178
+
179
+ @SURROGATE_REGISTRY.register()
180
+ def efficientnet_v2_m(pretrain=True):
181
+ return tv_models.efficientnet_v2_m(
182
+ weights=weights_dict['efficientnet_v2_m'] if pretrain else None)
183
+
184
+
185
+ @SURROGATE_REGISTRY.register()
186
+ def efficientnet_v2_l(pretrain=True):
187
+ return tv_models.efficientnet_v2_l(
188
+ weights=weights_dict['efficientnet_v2_l'] if pretrain else None)
189
+
190
+
191
+ @SURROGATE_REGISTRY.register()
192
+ def vit_b_16(pretrain=True):
193
+ return tv_models.vit_b_16(
194
+ weights=weights_dict['vit_b_16'] if pretrain else None)
195
+
196
+
197
+ @SURROGATE_REGISTRY.register()
198
+ def vit_b_32(pretrain=True):
199
+ return tv_models.vit_b_32(
200
+ weights=weights_dict['vit_b_32'] if pretrain else None)
201
+
202
+
203
+ @SURROGATE_REGISTRY.register()
204
+ def swin_b(pretrain=True):
205
+ return tv_models.swin_b(
206
+ weights=weights_dict['swin_b'] if pretrain else None)
attacks/AIM/src/gat/runtime/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .meter import AverageMeter, calc_cls_accuracy
2
+ from .utils import fix_random, randid
3
+
4
+ __all__ = ['fix_random', 'randid', 'AverageMeter', 'calc_cls_accuracy']
attacks/AIM/src/gat/runtime/api/__init__.py ADDED
File without changes
attacks/AIM/src/gat/runtime/api/aim_attack.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import io
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import torchvision
8
+ import uvicorn
9
+ from fastapi import FastAPI, File, UploadFile
10
+ from fastapi.responses import StreamingResponse
11
+ from PIL import Image
12
+
13
+ from ...datasets.transforms import resize_256_224, to_color, to_pil, to_ts
14
+ from ...models.attack import AIMAttack
15
+
16
+
17
+ def init_attack(ckpt: Union[str, Path] = None):
18
+ attack = AIMAttack(device='cpu')
19
+ if ckpt:
20
+ attack.load_ckpt(ckpt)
21
+ attack.set_mode('eval')
22
+ attack_preproc = torchvision.transforms.Compose(resize_256_224() +
23
+ to_color() + to_ts())
24
+ return attack, attack_preproc
25
+
26
+
27
+ def main():
28
+ parser = argparse.ArgumentParser(description='AIM Attack API')
29
+ parser.add_argument('--host', type=str, default='0.0.0.0', help='host')
30
+ parser.add_argument('--port', type=int, default=8000, help='port')
31
+ parser.add_argument('--ckpt',
32
+ type=str,
33
+ default=None,
34
+ help='path to the checkpoint')
35
+ args = parser.parse_args()
36
+
37
+ attack, attack_preproc = init_attack(args.ckpt)
38
+ app = FastAPI()
39
+
40
+ @app.get('/attack/aim/')
41
+ async def aim_attack(x_nat: UploadFile = File(...),
42
+ x_guid: UploadFile = File(...)):
43
+ io_x_nat = await x_nat.read()
44
+ io_x_guid = await x_guid.read()
45
+ pil_x_nat = Image.open(io.BytesIO(io_x_nat))
46
+ pil_x_guid = Image.open(io.BytesIO(io_x_guid))
47
+ ts_x_nat = attack_preproc(pil_x_nat).unsqueeze(0)
48
+ ts_x_guid = attack_preproc(pil_x_guid).unsqueeze(0)
49
+ ts_x_adv = attack(ts_x_nat, ts_x_guid)
50
+ pil_x_adv = torchvision.transforms.Compose(to_pil())(ts_x_adv[0])
51
+
52
+ img_byte_array = io.BytesIO()
53
+ pil_x_adv.save(img_byte_array, format='PNG')
54
+ img_byte_array.seek(0)
55
+
56
+ return StreamingResponse(img_byte_array, media_type='image/png')
57
+
58
+ uvicorn.run(app, host=args.host, port=args.port)
attacks/AIM/src/gat/runtime/factory.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code is copied from
2
+ # github.com/facebookresearch/fvcore:common/registry.py@242366
3
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
4
+
5
+ # pyre-strict
6
+ # pyre-ignore-all-errors[2,3]
7
+ from typing import Any, Dict, Iterable, Iterator, Tuple
8
+
9
+ from tabulate import tabulate
10
+
11
+
12
+ class Registry(Iterable[Tuple[str, Any]]):
13
+ """
14
+ The registry that provides name -> object mapping, to support third-party
15
+ users' custom modules.
16
+
17
+ To create a registry (e.g. a backbone registry):
18
+
19
+ .. code-block:: python
20
+
21
+ BACKBONE_REGISTRY = Registry('BACKBONE')
22
+
23
+ To register an object:
24
+
25
+ .. code-block:: python
26
+
27
+ @BACKBONE_REGISTRY.register()
28
+ class MyBackbone():
29
+ ...
30
+
31
+ Or:
32
+
33
+ .. code-block:: python
34
+
35
+ BACKBONE_REGISTRY.register(MyBackbone)
36
+ """
37
+
38
+ def __init__(self, name: str) -> None:
39
+ """
40
+ Args:
41
+ name (str): the name of this registry
42
+ """
43
+ self._name: str = name
44
+ self._obj_map: Dict[str, Any] = {}
45
+
46
+ def _do_register(self, name: str, obj: Any) -> None:
47
+ assert (
48
+ name not in self._obj_map
49
+ ), "An object named '{}' was already registered in '{}' registry!".format( # noqa: E501
50
+ name, self._name
51
+ )
52
+ self._obj_map[name] = obj
53
+
54
+ def register(self, obj: Any = None) -> Any:
55
+ """
56
+ Register the given object under the the name `obj.__name__`.
57
+ Can be used as either a decorator or not. See docstring of this class
58
+ for usage.
59
+ """
60
+ if obj is None:
61
+ # used as a decorator
62
+ def deco(func_or_class: Any) -> Any:
63
+ name = func_or_class.__name__
64
+ self._do_register(name, func_or_class)
65
+ return func_or_class
66
+
67
+ return deco
68
+
69
+ # used as a function call
70
+ name = obj.__name__
71
+ self._do_register(name, obj)
72
+
73
+ def get(self, name: str) -> Any:
74
+ ret = self._obj_map.get(name)
75
+ if ret is None:
76
+ raise KeyError(
77
+ "No object named '{}' found in '{}' registry!".format(
78
+ name, self._name
79
+ )
80
+ )
81
+ return ret
82
+
83
+ def __contains__(self, name: str) -> bool:
84
+ return name in self._obj_map
85
+
86
+ def __repr__(self) -> str:
87
+ table_headers = ['Names', 'Objects']
88
+ table = tabulate(
89
+ self._obj_map.items(), headers=table_headers, tablefmt='fancy_grid'
90
+ )
91
+ return 'Registry of {}:\n'.format(self._name) + table
92
+
93
+ def __iter__(self) -> Iterator[Tuple[str, Any]]:
94
+ return iter(self._obj_map.items())
95
+
96
+ # pyre-fixme[4]: Attribute must be annotated.
97
+ __str__ = __repr__
attacks/AIM/src/gat/runtime/meter.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class AverageMeter:
2
+ def __init__(self):
3
+ self.reset()
4
+
5
+ def reset(self):
6
+ self.val = 0
7
+ self.avg = 0
8
+ self.sum = 0
9
+ self.count = 0
10
+
11
+ def update(self, val, n=1):
12
+ self.val = val
13
+ self.sum += val * n
14
+ self.count += n
15
+ self.avg = self.sum / self.count
16
+
17
+
18
+ def calc_cls_accuracy(output, target, topk=(1,)):
19
+ maxk = min(max(topk), output.size()[1])
20
+ batch_size = target.size(0)
21
+ _, pred = output.topk(maxk, 1, True, True)
22
+ pred = pred.t()
23
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
24
+ return [
25
+ correct[: min(k, maxk)].reshape(-1).float().sum(0) * 100.0 / batch_size
26
+ for k in topk
27
+ ]
attacks/AIM/src/gat/runtime/utils.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def fix_random(seed: int = 0) -> None:
8
+ np.random.seed(seed)
9
+ torch.manual_seed(seed)
10
+ torch.cuda.manual_seed(seed)
11
+ torch.cuda.manual_seed_all(seed)
12
+ torch.backends.cudnn.deterministic = True
13
+ torch.backends.cudnn.benchmark = False
14
+
15
+
16
+ def randid(k: int = 4):
17
+ charset = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
18
+ return ''.join(random.choices(charset, k=k))
attacks/AIM/tests/__init__.py ADDED
File without changes
attacks/AIM/tests/test_datasets/__init__.py ADDED
File without changes
attacks/AIM/tests/test_datasets/test_datasets.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import unittest
3
+ from pathlib import Path
4
+
5
+ from gat.datasets import build_dataset
6
+
7
+ in1k_data_root = '/root/workspace/proj/transfer-at/data/in_1k'
8
+ in1k_data_root = os.environ.get('DATA_ROOT',
9
+ Path(__file__).parents[2] / 'data' / 'in_1k')
10
+
11
+
12
+ class TestImageNet(unittest.TestCase):
13
+
14
+ def test_in1k(self):
15
+ ds = build_dataset(
16
+ 'imagenet',
17
+ data_root=in1k_data_root,
18
+ is_train=True,
19
+ )
20
+ self.assertEqual(len(ds), 1281167)
21
+
22
+ def test_in1k_filter(self):
23
+ ds = build_dataset(
24
+ 'imagenet',
25
+ data_root=in1k_data_root,
26
+ is_train=False,
27
+ filter_class=0,
28
+ )
29
+ self.assertEqual(len(ds), 50)
30
+
31
+
32
+ if __name__ == '__main__':
33
+ unittest.main()
attacks/AIM/tests/test_datasets/test_transforms.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+ import torch
4
+ from PIL import Image
5
+ from torchvision.transforms import Compose
6
+
7
+ from gat.datasets.transforms import (hflip, norm, resize_224, resize_256_224,
8
+ resize_512_448, to_color, to_pil, to_ts)
9
+
10
+
11
+ class TestResize(unittest.TestCase):
12
+
13
+ def test_resize_256_224(self):
14
+ inputs = torch.rand(3, 256, 256)
15
+ outputs = Compose(resize_256_224())(inputs)
16
+ self.assertEqual(outputs.shape, (3, 224, 224))
17
+
18
+ def test_resize_512_448(self):
19
+ inputs = torch.rand(3, 512, 512)
20
+ outputs = Compose(resize_512_448())(inputs)
21
+ self.assertEqual(outputs.shape, (3, 448, 448))
22
+
23
+ def test_resize_224(self):
24
+ inputs = torch.rand(3, 224, 224)
25
+ outputs = Compose(resize_224())(inputs)
26
+ self.assertEqual(outputs.shape, (3, 224, 224))
27
+
28
+
29
+ class TestAug(unittest.TestCase):
30
+
31
+ def test_probability_invalid(self):
32
+ with self.assertRaises(AssertionError):
33
+ hflip(-0.1)
34
+
35
+ def test_probability_valid(self):
36
+ for valid_p in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
37
+ result = hflip(valid_p)
38
+ self.assertEqual(result[0].p, valid_p)
39
+
40
+
41
+ class TestTypeTransforms(unittest.TestCase):
42
+
43
+ def test_to_ts(self):
44
+ pil_image = Image.new('RGB', (224, 224))
45
+ outputs = Compose(to_ts())(pil_image)
46
+ self.assertIsInstance(outputs, torch.Tensor)
47
+
48
+ def test_to_pil(self):
49
+ inputs = torch.rand(3, 224, 224)
50
+ outputs = Compose(to_pil())(inputs)
51
+ self.assertIsInstance(outputs, Image.Image)
52
+
53
+ def test_to_color_grayscale(self):
54
+ inputs = Image.new('L', (224, 224))
55
+ outputs = Compose(to_color())(inputs)
56
+ self.assertEqual(outputs.mode, 'RGB')
57
+
58
+ def test_to_color_rgb(self):
59
+ inputs = Image.new('RGB', (224, 224))
60
+ outputs = Compose(to_color())(inputs)
61
+ self.assertEqual(outputs.mode, 'RGB')
62
+
63
+
64
+ class TestNorm(unittest.TestCase):
65
+
66
+ def test_default(self):
67
+ self.assertEqual(norm()[0].mean, (0.485, 0.456, 0.406))
68
+ self.assertEqual(norm()[0].std, (0.229, 0.224, 0.225))
69
+
70
+ def test_imagenet(self):
71
+ self.assertEqual(norm('IMAGENET')[0].mean, (0.485, 0.456, 0.406))
72
+ self.assertEqual(norm('IMAGENET')[0].std, (0.229, 0.224, 0.225))
73
+
74
+ def test_invalid_ds(self):
75
+ with self.assertRaises(AttributeError):
76
+ norm('imagenets')
77
+
78
+
79
+ if __name__ == '__main__':
80
+ unittest.main()
attacks/AIM/tests/test_models/__init__.py ADDED
File without changes
attacks/AIM/tests/test_models/test_attack.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+ import torch
4
+
5
+ from gat.models.attack.aim_attack import AIMAttack
6
+ from gat.models.attack.cda_attack import CDAAttack
7
+ from gat.models.attack.generator.aim import AIMGenerator
8
+ from gat.models.attack.generator.cda import CDAGenerator
9
+ from gat.models.attack.loss.logits import ContrastiveLoss
10
+
11
+
12
+ class TestGenerator(unittest.TestCase):
13
+
14
+ @torch.no_grad()
15
+ def test_cda(self):
16
+ generator = CDAGenerator()
17
+ inputs = torch.rand(1, 3, 224, 224)
18
+ outputs = generator(inputs)
19
+ self.assertEqual(outputs.shape, (1, 3, 224, 224))
20
+
21
+ @torch.no_grad()
22
+ def test_aim(self):
23
+ generator = AIMGenerator()
24
+ inputs = (torch.rand(1, 3, 224, 224), torch.rand(1, 3, 224, 224))
25
+ outputs = generator(*inputs)
26
+ self.assertEqual(outputs.shape, (1, 3, 224, 224))
27
+
28
+
29
+ class TestContrastiveLoss(unittest.TestCase):
30
+
31
+ def setUp(self):
32
+ self.margin = 1.0
33
+ self.loss_fn = ContrastiveLoss(margin=self.margin)
34
+
35
+ @torch.no_grad()
36
+ def test_forward_shape(self):
37
+ anchors = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
38
+ positives = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
39
+ negatives = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
40
+
41
+ loss = self.loss_fn(anchors, negatives, positives)
42
+ self.assertEqual(loss.shape, torch.Size([]))
43
+
44
+ @torch.no_grad()
45
+ def test_loss_outputs(self):
46
+ anchors = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
47
+ positives = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
48
+ negatives = torch.tensor([[2.0, 0.0], [0.0, 2.0]])
49
+
50
+ loss = self.loss_fn(anchors, negatives, positives)
51
+ self.assertAlmostEqual(loss.item(), 0.25)
52
+
53
+ @torch.no_grad()
54
+ def test_non_zero_loss(self):
55
+ anchors = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
56
+ positives = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
57
+ negatives = torch.tensor([[2.0, 0.0], [0.0, 2.0]])
58
+
59
+ loss = self.loss_fn(anchors, negatives, positives)
60
+ self.assertAlmostEqual(loss.item(), 0.75)
61
+
62
+
63
+ class TestCDAAttack(unittest.TestCase):
64
+
65
+ def setUp(self):
66
+ self.device = torch.device('cpu')
67
+ self.epsilon = 16. / 255.
68
+ self.attack = CDAAttack(self.device, self.epsilon)
69
+
70
+ @torch.no_grad()
71
+ def test_outputs_shape(self):
72
+ x_nat = torch.rand(1, 3, 224, 224)
73
+ x_adv = self.attack(x_nat)
74
+ self.assertEqual(x_adv.shape, (1, 3, 224, 224))
75
+
76
+ @torch.no_grad()
77
+ def test_outputs_bound(self):
78
+ x_nat = torch.rand(1, 3, 224, 224)
79
+ x_adv = self.attack(x_nat)
80
+ self.assertTrue((x_adv >= 0.0).all())
81
+ self.assertTrue((x_adv <= 1.0).all())
82
+ self.assertTrue((x_adv - x_nat).abs().max() <= self.epsilon)
83
+
84
+
85
+ class TestAIMAttack(unittest.TestCase):
86
+
87
+ def setUp(self):
88
+ self.device = torch.device('cpu')
89
+ self.epsilon = 16. / 255.
90
+ self.attack = AIMAttack(self.device, self.epsilon)
91
+
92
+ @torch.no_grad()
93
+ def test_outputs_shape(self):
94
+ x_nat = torch.rand(1, 3, 224, 224)
95
+ x_guid = torch.rand(1, 3, 224, 224)
96
+ x_adv = self.attack(x_nat, x_guid)
97
+ self.assertEqual(x_adv.shape, (1, 3, 224, 224))
98
+
99
+ @torch.no_grad()
100
+ def test_outputs_bound(self):
101
+ x_nat = torch.rand(1, 3, 224, 224)
102
+ x_guid = torch.rand(1, 3, 224, 224)
103
+ x_adv = self.attack(x_nat, x_guid)
104
+ self.assertTrue((x_adv >= 0.0).all())
105
+ self.assertTrue((x_adv <= 1.0).all())
106
+ self.assertTrue((x_adv - x_nat).abs().max() <= self.epsilon)
107
+
108
+
109
+ if __name__ == '__main__':
110
+ unittest.main()
attacks/AIM/tests/test_models/test_surrogate.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+ import torch
4
+
5
+ from gat.models.surrogate import build_surrogate, list_surrogates
6
+ from gat.models.surrogate.hooks import feat_col
7
+
8
+
9
+ class TestTVModels(unittest.TestCase):
10
+
11
+ @torch.no_grad()
12
+ def test_outputs_shape(self):
13
+ inputs = torch.rand(1, 3, 224, 224)
14
+ for surrogate_id in list_surrogates():
15
+ if surrogate_id not in ['inception_v3']:
16
+ outputs = build_surrogate(surrogate_id, pretrain=False)(inputs)
17
+ self.assertEqual(outputs.shape, (1, 1000))
18
+
19
+
20
+ class TestFeatCol(unittest.TestCase):
21
+
22
+ @torch.no_grad()
23
+ def test_feat_col(self):
24
+ testcases = [{
25
+ 'surrogate_id': 'vgg16',
26
+ 'feat_layer': 'features.16',
27
+ 'input_shape': (1, 3, 224, 224),
28
+ 'feat_shape': (1, 256, 28, 28)
29
+ }, {
30
+ 'surrogate_id': 'vgg19',
31
+ 'feat_layer': 'features.18',
32
+ 'input_shape': (1, 3, 224, 224),
33
+ 'feat_shape': (1, 256, 28, 28)
34
+ }, {
35
+ 'surrogate_id': 'resnet152',
36
+ 'feat_layer': 'layer2',
37
+ 'input_shape': (1, 3, 224, 224),
38
+ 'feat_shape': (1, 512, 28, 28)
39
+ }, {
40
+ 'surrogate_id': 'densenet169',
41
+ 'feat_layer': 'features.denseblock2',
42
+ 'input_shape': (1, 3, 224, 224),
43
+ 'feat_shape': (1, 512, 28, 28)
44
+ }]
45
+ for testcase in testcases:
46
+ model = build_surrogate(testcase['surrogate_id'], pretrain=False)
47
+ with feat_col(model, testcase['feat_layer']) as _feat_collecter:
48
+ inputs = torch.rand(testcase['input_shape'])
49
+ model(inputs)
50
+ model(inputs)
51
+ self.assertEqual(len(_feat_collecter), 2)
52
+ self.assertEqual(_feat_collecter.pop().shape,
53
+ testcase['feat_shape'])
54
+ self.assertEqual(_feat_collecter.pop().shape,
55
+ testcase['feat_shape'])
56
+
57
+
58
+ if __name__ == '__main__':
59
+ unittest.main()
attacks/AIM/tests/test_runtime/__init__.py ADDED
File without changes