Upload 192 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +113 -3
- attack.py +140 -0
- attacks/AIM/AIMAttack.py +226 -0
- attacks/AIM/examples/aim_attack.py +237 -0
- attacks/AIM/examples/ens-gen.py +523 -0
- attacks/AIM/examples/workdirs/1MsK/args.txt +11 -0
- attacks/AIM/examples/workdirs/CGKQ/args.txt +11 -0
- attacks/AIM/examples/workdirs/Fvu1/args.txt +11 -0
- attacks/AIM/examples/workdirs/Qonx/args.txt +11 -0
- attacks/AIM/examples/workdirs/fnNs/args.txt +11 -0
- attacks/AIM/examples/workdirs/jtMb/args.txt +11 -0
- attacks/AIM/examples/workdirs/krhX/args.txt +11 -0
- attacks/AIM/setup.py +47 -0
- attacks/AIM/src/gat/__init__.py +1 -0
- attacks/AIM/src/gat/datasets/__init__.py +5 -0
- attacks/AIM/src/gat/datasets/builder.py +11 -0
- attacks/AIM/src/gat/datasets/cub.py +6 -0
- attacks/AIM/src/gat/datasets/env.py +6 -0
- attacks/AIM/src/gat/datasets/imagenet.py +36 -0
- attacks/AIM/src/gat/datasets/transforms.py +55 -0
- attacks/AIM/src/gat/models/__init__.py +0 -0
- attacks/AIM/src/gat/models/attack/__init__.py +5 -0
- attacks/AIM/src/gat/models/attack/aim_attack.py +14 -0
- attacks/AIM/src/gat/models/attack/base_attack.py +60 -0
- attacks/AIM/src/gat/models/attack/cda_attack.py +13 -0
- attacks/AIM/src/gat/models/attack/generator/__init__.py +0 -0
- attacks/AIM/src/gat/models/attack/generator/aim.py +179 -0
- attacks/AIM/src/gat/models/attack/generator/cda.py +146 -0
- attacks/AIM/src/gat/models/attack/loss/__init__.py +0 -0
- attacks/AIM/src/gat/models/attack/loss/logits.py +27 -0
- attacks/AIM/src/gat/models/attack/optim/__init__.py +3 -0
- attacks/AIM/src/gat/models/attack/optim/sam.py +100 -0
- attacks/AIM/src/gat/models/surrogate/__init__.py +10 -0
- attacks/AIM/src/gat/models/surrogate/builder.py +11 -0
- attacks/AIM/src/gat/models/surrogate/hooks.py +59 -0
- attacks/AIM/src/gat/models/surrogate/tv.py +206 -0
- attacks/AIM/src/gat/runtime/__init__.py +4 -0
- attacks/AIM/src/gat/runtime/api/__init__.py +0 -0
- attacks/AIM/src/gat/runtime/api/aim_attack.py +58 -0
- attacks/AIM/src/gat/runtime/factory.py +97 -0
- attacks/AIM/src/gat/runtime/meter.py +27 -0
- attacks/AIM/src/gat/runtime/utils.py +18 -0
- attacks/AIM/tests/__init__.py +0 -0
- attacks/AIM/tests/test_datasets/__init__.py +0 -0
- attacks/AIM/tests/test_datasets/test_datasets.py +33 -0
- attacks/AIM/tests/test_datasets/test_transforms.py +80 -0
- attacks/AIM/tests/test_models/__init__.py +0 -0
- attacks/AIM/tests/test_models/test_attack.py +110 -0
- attacks/AIM/tests/test_models/test_surrogate.py +59 -0
- attacks/AIM/tests/test_runtime/__init__.py +0 -0
README.md
CHANGED
|
@@ -1,3 +1,113 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SAE: Sustainable Adversarial Example Evaluation Framework for Class-Incremental Learning
|
| 2 |
+
|
| 3 |
+
## 🌟 Overview
|
| 4 |
+
|
| 5 |
+
_**News**: This work has been accepted as a poster by [AAAI 2026](https://aaai.org/conference/aaai/aaai-26/)._
|
| 6 |
+
|
| 7 |
+
**SAE (Sustainable Adversarial Example)** is a *universal adversarial attack framework* targeting **Class-Incremental Learning (CIL)**. This repository provides a comprehensive pipeline for both CIL training and benchmarking multiple attack methods, including our proposed SAE approach.
|
| 8 |
+
|
| 9 |
+
The project integrates with [PyCIL: A Python Toolbox for Class-Incremental Learning](https://github.com/LAMDA-CL/PyCIL) for CIL model training. It also supports benchmarking several attack baselines alongside SAE, enabling fair and reproducible evaluations of adversarial robustness across CIL methods.
|
| 10 |
+
|
| 11 |
+
If you are interested in our work, please refer to:
|
| 12 |
+
```
|
| 13 |
+
@inproceedings{liu2026SAE,
|
| 14 |
+
title={Improving Sustainability of Adversarial Examples in Class-Incremental Learning},
|
| 15 |
+
author={Taifeng Liu, Xinjing Liu, Liangqiu Dong, Yang Liu, Yilong Yang, Zhuo Ma},
|
| 16 |
+
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
|
| 17 |
+
year={2026}
|
| 18 |
+
}
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
## ⚙️ Environment Setup
|
| 24 |
+
|
| 25 |
+
### Project Layout
|
| 26 |
+
|
| 27 |
+
* `attacks/`: Implementations of all attack baselines including **MIFGSM**, **Gaker**, **AIM**, **CGNC**, **CleanSheet**, **UnivIntruder**, and **SAE**.
|
| 28 |
+
* `convs/`: Backbone definitions for CIL models and CLIP model, including `resnet32`, `resnet50`, `cosine_resnet32`, and `cosine_resnet50`.
|
| 29 |
+
* **Pre-trained CLIP** is downloaded from [HuggingFace](https://huggingface.co/laion/CLIP-ViT-B-32-laion2B-s34B-b79K).
|
| 30 |
+
* `datasets/`: Dataset management for CIFAR-100 (32x32) and ImageNet-100 (224x224).
|
| 31 |
+
* **CIFAR-100** is automatically downloaded at runtime.
|
| 32 |
+
* **ImageNet-100** should be manually extracted from ImageNet-1K using `create_imagenet100_from_imagenet.py`, which parses `train.txt` and `eval.txt` to extract relevant classes.
|
| 33 |
+
* `exps/`: JSON configuration files for various CIL training methods.
|
| 34 |
+
* `logs/`: Stores trained CIL model checkpoints and evaluation results of attacks.
|
| 35 |
+
* `models/`: Contains implementations of **9** CIL algorithms: **BiC**, **DER**, **Finetune**, **Foster**, **iCaRL**, **MEMO**, **PodNet**, **Replay**, **WA**.
|
| 36 |
+
* `scripts/`: Shell scripts for CIL training and attack benchmarking.
|
| 37 |
+
* `utils/`: Utility functions for augmentation, logging, dataset processing, visualization, etc.
|
| 38 |
+
* `attack.py`: Main entry point to run adversarial attacks.
|
| 39 |
+
* `trainCIL.py`: Main entry point to train CIL models.
|
| 40 |
+
|
| 41 |
+
### Dependency Requirements
|
| 42 |
+
|
| 43 |
+
* **OS**: Ubuntu 22.04
|
| 44 |
+
* **Python**: 3.12
|
| 45 |
+
* **PyTorch**: ≥ 2.1
|
| 46 |
+
* **GPU**: NVIDIA RTX 4090 (24GB VRAM)
|
| 47 |
+
|
| 48 |
+
To set up the environment:
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
conda env create -f environment.yml
|
| 52 |
+
conda activate SAE
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
---
|
| 56 |
+
|
| 57 |
+
## 🚀 Result Reproduction
|
| 58 |
+
|
| 59 |
+
### Step 1: CIL Training
|
| 60 |
+
|
| 61 |
+
First, train the target CIL models by following the instructions below.
|
| 62 |
+
|
| 63 |
+
Example: training a single CIL method (e.g., **iCaRL**) on **CIFAR-100**:
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
python trainCIL.py --config exps/icarl.json
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
Example: training a single CIL method (e.g., **iCaRL**) on **ImageNet-100**:
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
python trainCIL.py --config exps/icarl-imagenet100.json
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
We also provide scripts to train all 9 CIL methods on **CIFAR-100**:
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
./scripts/trainCIL-CIFAR100.sh
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
Train all 9 CIL methods on **ImageNet-100**:
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
./scripts/trainCIL-ImageNet100.sh
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
All model checkpoints will be saved under the `logs/` directory, organized by method and dataset.
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
---
|
| 91 |
+
|
| 92 |
+
### Step 2: Adversarial Attack Benchmarking
|
| 93 |
+
|
| 94 |
+
_Note: assume that you have already prepared the dataset, the CIL model, and the CLIP model._
|
| 95 |
+
|
| 96 |
+
To launch an **SAE** attack targeting class **0** on a CIL model trained on **CIFAR-100**:
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
python attack.py --config exps/icarl.json --attack_method SAE --target_class 0
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
To test a different attack baseline, simply change the `--attack_method` argument.
|
| 103 |
+
|
| 104 |
+
You can reproduce the overall evaluation results by running the scripts below (see **Table 1** in our paper).
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
./scripts/attacks-CIFAR100.sh
|
| 108 |
+
./scripts/attacks-ImageNet100.sh
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
#### Benchmark Results
|
| 112 |
+
|
| 113 |
+
All benchmark results are available in the `appendix.pdf` of Supplementary Material.
|
attack.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import argparse
|
| 3 |
+
import sys
|
| 4 |
+
import logging
|
| 5 |
+
import copy
|
| 6 |
+
import torch
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
from trainCIL import _set_random, _set_device, print_args
|
| 10 |
+
from attacks.AIM.AIMAttack import AIM
|
| 11 |
+
from attacks.BASEAttack import BASEAttack
|
| 12 |
+
from attacks.Gaker.GAKERAttack import Gaker
|
| 13 |
+
from attacks.CGNC.CGNCAttack import CGNC
|
| 14 |
+
from attacks.CleanSheet.CleanSheetAttack import CleanSheet
|
| 15 |
+
from attacks.UnivIntruder.UnivIntruderAttack import UnivIntruder
|
| 16 |
+
from attacks.SAE.SAEAttack import SAE
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def main():
    """Entry point: merge CLI options with the JSON config and run the evaluation."""
    cli_ns = setup_parser().parse_args()
    json_params = load_json(cli_ns.config)
    # vars() exposes the argparse Namespace as a mutable dict so the JSON
    # settings can be layered on top of the CLI defaults.
    merged = vars(cli_ns)
    merged.update(json_params)
    evaluate(merged)
|
| 27 |
+
|
| 28 |
+
def evaluate(args):
    """Run one evaluation pass per configured random seed.

    The seed list and device are copied up front because ``args`` is mutated
    in place on every iteration.
    """
    all_seeds = copy.deepcopy(args["seed"])
    dev = copy.deepcopy(args["device"])
    for current_seed in all_seeds:
        args["seed"] = current_seed
        args["device"] = dev
        _evaluate(args)
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _evaluate(args):
    """Run a single attack evaluation for the current seed.

    Expects ``args`` to hold the merged JSON/CLI configuration, including
    ``model_name``, ``dataset``, ``init_cls``, ``increment``, ``convnet_type``
    and ``attack_method``.  Locates the trained CIL checkpoints under
    ``logs/`` (raises FileNotFoundError if training has not been run), sets up
    file + stdout logging, instantiates the requested attack, trains its
    generator/trigger when needed, and finally runs the evaluation.
    """
    # Signal downstream components that we are in attack (not training) mode.
    args["attack"] = True

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # PyCIL convention: init_cls == increment is recorded as 0 in the log path.
    init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]
    logs_name = "logs/{}/{}/init_cls_{}/per_classes_{}/{}".format(
        args["model_name"], args["dataset"], init_cls,
        args['increment'], args["convnet_type"])
    if not os.path.exists(logs_name):
        # BUG FIX: the original raised a plain string, which is itself a
        # TypeError in Python 3 ("exceptions must derive from BaseException").
        raise FileNotFoundError("Model is not trained yet.")

    # Derive the eval directory from logs_name instead of repeating the
    # format string; create it idempotently.
    logs_eval_name = os.path.join(logs_name, "eval", args['attack_method'])
    os.makedirs(logs_eval_name, exist_ok=True)

    logfilename = os.path.join(logs_eval_name, "adv_log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + ".log"),
            logging.StreamHandler(sys.stdout),
        ],
    )

    # Perturbation budgets evaluated by the foolbox-based attacks.
    args['epsilons'] = [0.01, 0.015, 0.03, 0.06, 0.1, 0.2]
    args['logs_name'] = logs_name
    args['logs_eval_name'] = logs_eval_name
    _set_random()
    _set_device(args)
    print_args(args)

    # Instantiate the attack.  The repo-local attacks need an extra training
    # phase (a generator or a universal trigger) before evaluation.
    if args['attack_method'] in NEW_ATTACKS:
        adv = init_new_attack(args, device=device)
        if args['attack_method'] in ('AIM', 'Gaker', 'CGNC'):
            adv.train_generator()
        elif args['attack_method'] in ('SAE', 'UnivIntruder', 'CleanSheet'):
            adv.train_adv()
    else:
        adv = init_foolbox_attack(args, device=device)

    # Conduct the attack across all incremental tasks.
    adv.run_test()
| 90 |
+
|
| 91 |
+
# Attacks implemented in this repository; each provides its own training
# phase (generator or trigger) before evaluation.
NEW_ATTACKS = ['AIM', 'Gaker', 'CGNC', 'CleanSheet', 'UnivIntruder', 'SAE']
# Attack baselines delegated to the foolbox library via BASEAttack.
FOOLBOX_ATTACKS = ['L2FGM', 'FGSM', 'MIFGSM',
                   'L1PGD', 'L2PGD', 'LinfPGD',
                   'L2DeepFool', 'LinfDeepFool', 'BoundaryAttack',
                   'CarliniWagnerL2', 'GaussianNoise', 'UniformNoise']
| 96 |
+
|
| 97 |
+
def init_new_attack(args, device='cuda', **kwargs):
    """Instantiate one of the repository-local attacks by name.

    Args:
        args: merged configuration dict; ``args['attack_method']`` selects
            the attack class.
        device: torch device string passed through to the attack.
        **kwargs: accepted for backward compatibility (the original accepted
            and ignored a ``models`` keyword).

    Returns:
        The constructed attack object.

    Raises:
        ValueError: if ``args['attack_method']`` is not a known attack.
    """
    attack_name = args['attack_method']
    # Dispatch table instead of a long if/elif chain; all classes share the
    # (args, device) constructor signature.
    attack_classes = {
        'AIM': AIM,
        'Gaker': Gaker,
        'CGNC': CGNC,
        'CleanSheet': CleanSheet,
        'UnivIntruder': UnivIntruder,
        'SAE': SAE,
    }
    try:
        attack_cls = attack_classes[attack_name]
    except KeyError:
        raise ValueError(f"Unknown attack method: {attack_name}") from None
    return attack_cls(args=args, device=device)
| 117 |
+
|
| 118 |
+
def init_foolbox_attack(args, device='cuda', **kwargs):
    """Build the foolbox-backed baseline attack wrapper for ``args``."""
    return BASEAttack(args=args, device=device)
|
| 121 |
+
|
| 122 |
+
def load_json(settings_path):
    """Load a JSON settings file and return its contents as a dict."""
    with open(settings_path) as fp:
        return json.load(fp)
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def setup_parser():
    """Construct the CLI parser for the attack entry point.

    Returns the parser unparsed so the caller controls when arguments are
    read (and can merge them with a JSON config afterwards).
    """
    cli = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorithms.')
    # Path to the JSON config whose keys override/extend these CLI defaults.
    cli.add_argument('--config', type=str, default='exps/finetune.json', help='Json file of settings.')
    cli.add_argument('--batch_size', type=int, default=128, help='set the batch size.')
    cli.add_argument('--attack_method', type=str, default='AIM', help='set the attack method, e.g., LinfPGD, MIFGSM, AIM, Gaker, CGNC, CleanSheet, UnivIntruder, SAE.')
    cli.add_argument('--target_class', type=int, default=0, help='the target class, None indicates untargeted attack.')
    cli.add_argument('--eval', action='store_true', help='evaluation only')
    return cli
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# Script entry point: parse CLI args, merge the JSON config, run the attack.
if __name__ == '__main__':
    main()
|
attacks/AIM/AIMAttack.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
import torch
|
| 3 |
+
from foolbox.attacks.base import *
|
| 4 |
+
from foolbox.attacks.gradient_descent_base import *
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from attacks.AIM.src.gat.models.attack import AIMAttack, ContrastiveLoss
|
| 7 |
+
from attacks.AIM.src.gat.models.surrogate import midlayer_dict, register_collecter, register_collecter_cl
|
| 8 |
+
from attacks.attack_config import SustainableAttack
|
| 9 |
+
from utils.plot import plot_asr_per_target, save_grad_cam
|
| 10 |
+
import logging
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import foolbox as fb
|
| 13 |
+
from foolbox import PyTorchModel
|
| 14 |
+
import numpy as np
|
| 15 |
+
from utils import factory
|
| 16 |
+
from utils.data_manager import get_dataloader
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AIM(SustainableAttack):
    """AIM attack wrapper: trains an image-to-image adversarial generator
    against a surrogate model, then evaluates its targeted attack success
    rate (ASR) on every incremental task of a CIL model.

    The generator maps (clean image, target-class image) pairs to adversarial
    images; training pulls adversarial logits/features toward the target
    class and away from the natural class.
    """

    def __init__(self, args, device='cuda'):
        super().__init__(args, device)
        self.device = device
        self.args = args
        # Set lazily in train_generator (either a CIL checkpoint or a
        # pretrained CIFAR-100 ResNet-32 from torch.hub).
        self.surrogate_model = None

        self.adv_generator = AIMAttack(device=device)
        self.adv_generator.set_mode('train')
        # Adam hyperparameters for generator training.
        self.lr = 0.001
        self.betas = (0.5, 0.999)
        self.num_epoch = 100
        self.optim = torch.optim.Adam(self.adv_generator.get_params(), lr=self.lr, betas=self.betas)
        # Contrastive loss with temperature/margin 0.2 (see ContrastiveLoss).
        self.contrastive_loss = ContrastiveLoss(0.2)
        self.sim_loss = torch.nn.functional.cosine_similarity
        # NOTE(review): attribute name contains a typo ("szie"); kept as-is
        # because run_test reads it.
        self.eval_batch_szie = 128

        # Surrogate selection: the '_cl' suffix switches train_generator to
        # the CIL-checkpoint branch.
        self.surrogate_model_name = 'resnet32_cl'
        self.layer = midlayer_dict[self.surrogate_model_name]
        # Filename prefix encoding surrogate, hooked layer and target class.
        self.prefix = (f'adv_generator_{self.surrogate_model_name}'
                       f'_{self.layer}'
                       f'_tclass{self.target_class}')
        self.save_path = os.path.join(self.args['logs_eval_name'], f'target{str(self.target_class)}')
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

        # Optional Grad-CAM dumps during evaluation (expensive; off by default).
        self.plot_gradcam = False

    def train_generator(self):
        """Train (or load) the adversarial generator against the surrogate.

        Skips training entirely when a checkpoint for this prefix already
        exists under ``save_path``.  A forward hook collects mid-layer
        features from the surrogate into ``self.feat_collecter``; each
        forward pass must be followed by exactly one ``pop()``.
        """
        if 'cl' in self.surrogate_model_name:
            # Surrogate = the first-task checkpoint of the CIL model itself.
            s_model = factory.get_model(self.args["model_name"], self.args)
            s_model.incremental_train(self.data_manager)
            s_model._network.load_state_dict(
                torch.load(self.ckpt_paths[0], map_location=self.device)['model_state_dict'])
            s_model._network.to(self.device)
            s_model._network.eval()
            self.surrogate_model = s_model._network
            # Drop the wrapper; only the bare network is needed.
            del s_model
            torch.cuda.empty_cache()
            self.feat_collecter = []
            self.feat_collecter_handler, self.feat_collecter = register_collecter_cl(self.surrogate_model,
                                                                                     self.layer,
                                                                                     self.feat_collecter,
                                                                                     self.args["model_name"])
        else:
            # Surrogate = an off-the-shelf pretrained CIFAR-100 ResNet-32.
            self.surrogate_model = torch.hub.load("chenyaofo/pytorch-cifar-models", 'cifar100_resnet32', pretrained=True)
            self.surrogate_model.to(self.device)
            self.surrogate_model.eval()
            self.feat_collecter = []
            self.feat_collecter_handler, self.feat_collecter = register_collecter(self.surrogate_model,
                                                                                  self.layer,
                                                                                  self.feat_collecter)
        self.file_path = os.path.join(self.save_path, f'{self.prefix}.pth')
        if os.path.exists(self.file_path):
            # Reuse a previously trained generator for this target class.
            self.adv_generator.load_ckpt(self.file_path)
            self.adv_generator.set_mode('eval')
        else:
            # Collect a pool of target-class images from the first task
            # (classes 0-9) to condition the generator on.
            loaders = get_dataloader(self.data_manager, batch_size=self.batch_size,
                                     start_class=0, end_class=10,
                                     train=True, shuffle=True, num_workers=0)

            target_images = []
            target_labels = []
            for data in loaders:
                _, image_batch, label_batch = data
                mask = label_batch == self.target_class
                selected_images = image_batch[mask]
                selected_labels = label_batch[mask]
                target_images.append(selected_images)
                target_labels.append(selected_labels)
            del loaders
            target_images = torch.cat(target_images, dim=0).to(self.device)
            target_labels = torch.cat(target_labels, dim=0).to(self.device)
            # Keep at most one batch of target images, wrapped as eagerpy tensors.
            target_images, target_labels = ep.astensors(*(target_images[:self.batch_size], target_labels[:self.batch_size]))

            total_loss = []
            for epoch in range(1, self.num_epoch + 1):
                # NOTE(review): ``self.loader`` is not created in this class;
                # presumably set by the SustainableAttack base — verify.
                laoder_tqdm = tqdm(self.loader, total=len(self.loader), desc=f'Epoch {epoch}')
                loss_np = 0
                for i, (_, x, y) in enumerate(laoder_tqdm):
                    # Keep only non-target-class samples as attack sources.
                    x_f = x[y != self.target_class].to(self.device)
                    del x, y
                    if len(x_f) > len(target_images):
                        x_f = x_f[:len(target_images)].to(self.device)
                    else:
                        # NOTE(review): this reassignment shrinks the target
                        # pool for all later iterations — confirm intended.
                        random_indices = torch.randperm(len(target_images))[:len(x_f)].to(self.device)
                        target_images = target_images[random_indices]

                    x_adv = self.adv_generator(x_f, target_images.raw.to(self.device))

                    # Each surrogate forward pushes one feature tensor onto
                    # the hook's collector; pop immediately to pair them up.
                    logits_nat = self.surrogate_model(self.norm(x_f))
                    feat_nat = self.feat_collecter.pop()
                    logits_tar = self.surrogate_model(self.norm(target_images.raw))
                    feat_tar = self.feat_collecter.pop()
                    logits_adv = self.surrogate_model(self.norm(x_adv))
                    feat_adv = self.feat_collecter.pop()

                    # Pull adv logits toward the target class (contrastive),
                    # and in feature space: stay close to natural features is
                    # penalized (+), closeness to target features rewarded (-).
                    loss = (self.contrastive_loss(logits_adv, logits_nat, logits_tar) +
                            self.sim_loss(feat_nat, feat_adv) -
                            self.sim_loss(feat_tar, feat_adv)).mean()
                    loss_np = loss_np + loss.item()

                    self.optim.zero_grad()
                    loss.backward()
                    self.optim.step()
                    # Free per-batch tensors eagerly to keep VRAM usage low.
                    del x_f, x_adv, logits_nat, logits_adv, logits_tar, feat_nat, feat_tar, feat_adv
                    torch.cuda.empty_cache()
                total_loss.append(loss_np/(i+1))
                logging.info(f'Epoch {epoch} loss: {loss_np/(len(self.loader))}')
            logging.info(f'Total loss: {total_loss}')
            # Detach the feature hook and persist the trained generator.
            self.feat_collecter_handler.remove()
            self.adv_generator.save_ckpt(self.file_path)


    def run_test(self):
        """Evaluate the trained generator on the first test batch.

        Builds a pool of target-class test images, then runs ``attacks`` on
        the first batch only (the loop breaks after i == 0).
        """
        self.adv_generator.set_mode('eval')
        self.adv_generator.adv_gen.to(self.device)
        self.loader = get_dataloader(self.data_manager, batch_size=self.eval_batch_szie,
                                     start_class=0, end_class=10,
                                     train=False, shuffle=False, num_workers=0)
        # Gather all test images of the target class for conditioning.
        target_images = []
        target_labels = []
        for data in self.loader:
            _, image_batch, label_batch = data
            mask = label_batch == self.target_class
            selected_images = image_batch[mask]
            selected_labels = label_batch[mask]
            target_images.append(selected_images)
            target_labels.append(selected_labels)
        target_imgs = torch.cat(target_images, dim=0).to(self.device)
        target_labels = torch.cat(target_labels, dim=0).to(self.device)
        target_imgs, target_labels = ep.astensors(*(target_imgs, target_labels))
        for i, (_, imgs, labels) in enumerate(tqdm(self.loader, total=len(self.loader),
                                                   desc=f'Loading Data with Batch Size of {self.batch_size}) :')):
            # Only the first batch is evaluated.
            if i > 0:
                break

            imgs, labels = ep.astensors(*(imgs.to(self.device), labels.to(self.device)))

            # Sources are all non-target-class samples; labels_t_f is the
            # target label broadcast for the targeted criterion.
            imgs_f = imgs[labels != self.target_class]
            labels_f = labels[labels != self.target_class]
            labels_t_f = ep.full_like(labels_f, fill_value=self.target_class)

            # Evaluate against the first 20 target images only.
            self.attacks(i, imgs_f, labels_f, labels_t_f, target_imgs[:20], target_labels[:20])


    def attacks(self, i_batch, imgs, labels, labels_t, target_imgs=None, target_labels=None):
        """Measure targeted ASR of the generator on all 10 incremental tasks.

        For each task checkpoint and each conditioning target image, generates
        adversarial examples for ``imgs`` and records the fraction judged
        adversarial by foolbox.  Results (averaged over target images) are
        plotted and exported to an Excel file under ``save_path``.
        """
        # Rows = tasks, columns = conditioning target images.
        asr_matrix = np.ones((10, len(target_imgs)))
        self.model = factory.get_model(self.args["model_name"], self.args)
        for task in range(10):
            logging.info("***** Starting attack on task [{}]. *****".format(task))
            # Advance the victim model one task and load its checkpoint.
            self.model.incremental_train(self.data_manager)
            self.model._network.load_state_dict(torch.load(self.ckpt_paths[task], map_location=self.device)['model_state_dict'])
            self.model._network.to(self.device)
            self.model._network.eval()

            # Run the attack against each target image.
            criterion = fb.criteria.Misclassification(
                labels) if self.target_class is None else fb.criteria.TargetedMisclassification(
                labels_t)
            current_model = PyTorchModel(self.model._network, bounds=(0, 1), preprocessing=self.preprocessing)
            verify_input_bounds(imgs, current_model)
            criterion = get_criterion(criterion)
            is_adversarial = get_is_adversarial(criterion, current_model)

            logging.info("Eval attack on each target images.")
            for i, target_image in enumerate(target_imgs):
                # Condition the whole source batch on one target image.
                advs = ep.astensor(self.adv_generator(imgs.raw.to(self.device), target_image.raw.repeat(len(imgs), 1, 1, 1).to(self.device)))
                is_adv = is_adversarial(advs)[0]
                asr_matrix[task, i] = (is_adv.bool().sum().raw.item() / len(imgs))
                if self.plot_gradcam:
                    save_grad_cam(self.args, torch.clip(advs.raw.detach(),0,1), labels_t.raw,
                                  self.model._network, self.save_path + "/GradCam" + f"targetimg{i}", prefix=f'task{task}',
                                  layer_name='stage_3', save_num=100, save_raw=True)

                del advs, is_adv, target_image
                torch.cuda.empty_cache()

            del criterion, current_model, is_adversarial
            torch.cuda.empty_cache()

            self.model.after_task()

        # Average the ASR over all target images and persist plot + spreadsheet.
        asr_matrix = np.mean(asr_matrix, axis=1, keepdims=True)
        prefix = f'batch{i_batch}_{self.prefix}'
        plot_asr_per_target(asr_matrix, self.save_path, prefix, self.args)
        df = pd.DataFrame(asr_matrix, columns=['ASR'])
        df.to_excel(os.path.join(self.save_path, f"{prefix}.xlsx"), index=False)

        del asr_matrix, imgs, labels, labels_t, target_imgs
        torch.cuda.empty_cache()

    # The two stubs below satisfy foolbox's attack interface (this class is
    # driven through run_test, not through foolbox's __call__ protocol).
    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Union[Sequence[Union[float, None]], float, None],
        **kwargs: Any,
    ) -> Union[Tuple[List[T], List[T], T], Tuple[T, T, T]]:
        ...

    def repeat(self, times: int) -> "AIM":
        ...
attacks/AIM/examples/aim_attack.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Usage:
|
| 3 |
+
# ./examples/aim_attack.py -h
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from pprint import pformat
|
| 8 |
+
from typing import List, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from attacks.GAT.src.gat.datasets import build_dataset, list_datasets
|
| 14 |
+
from attacks.GAT.src.gat.datasets.transforms import norm
|
| 15 |
+
from attacks.GAT.src.gat.models.attack import AIMAttack, ContrastiveLoss
|
| 16 |
+
from attacks.GAT.src.gat.models.surrogate import (build_surrogate, list_surrogates,
|
| 17 |
+
midlayer_dict, register_collecter)
|
| 18 |
+
from attacks.GAT.src.gat.runtime import AverageMeter, calc_cls_accuracy, fix_random, randid
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def parse_args():
    """Parse CLI arguments for the example AIM script and prepare the run.

    Side effects (in this order): creates ``workdir/<expid>``, seeds all RNGs
    via ``fix_random``, and writes the resolved arguments to
    ``workdir/args.txt``.  Returns the argparse namespace with ``workdir``,
    ``device`` and ``ckpt`` already resolved to their final types.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    # randid(4) generates a fresh 4-char run id when none is given
    # (project helper — semantics assumed, verify in gat.runtime).
    parser.add_argument('--expid', type=str, default=randid(4))
    parser.add_argument('--workdir', type=str, default='workdirs')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--tar-classes', type=int, default=24)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--dataset',
                        type=str,
                        default='imagenet',
                        choices=list_datasets())
    parser.add_argument('--data-root',
                        type=str,
                        default=Path(__file__).parent / '../data' / 'in_1k')
    # Two sub-commands: 'train' (with its own hyperparameters) and 'evaluate'.
    sub_parsers = parser.add_subparsers(dest='command')
    train_parser = sub_parsers.add_parser('train')
    train_parser.add_argument('--surrogate-id',
                              type=str,
                              default='resnet152',
                              choices=list_surrogates())
    train_parser.add_argument('--num-epoch', type=int, default=10)
    train_parser.add_argument('--lr', type=float, default=0.0002)
    train_parser.add_argument('--betas',
                              type=float,
                              nargs=2,
                              default=(0.5, 0.999))
    sub_parsers.add_parser('evaluate')
    args = parser.parse_args()

    # Resolve the per-run working directory and target device.
    args.workdir = Path(args.workdir) / args.expid
    args.workdir.mkdir(parents=True, exist_ok=True)
    args.device = torch.device(args.device)

    # Generator checkpoint location for this run.
    args.ckpt = args.workdir / 'model.pth'

    # Seed before anything stochastic happens downstream.
    fix_random(args.seed)

    # Record the fully-resolved arguments for reproducibility.
    with open(args.workdir / 'args.txt', 'w') as f:
        f.write(pformat(vars(args)))

    return args
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def init_loader(
        dataset: str,
        data_root: Union[str, Path],
        tar_classes: Union[int, List[int]],
        batch_size: int = 16,
        command: str = 'train'
) -> 'tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]':
    """Build the natural-image loader and the target-class loader.

    Fix: the original annotation claimed ``List[DataLoader]`` but the function
    returns a 2-tuple ``(train_loader, target_loader)``; the annotation is now
    a lazily-evaluated string so no new ``typing`` import is required.

    Args:
        dataset: dataset key understood by ``build_dataset``.
        data_root: filesystem root of the dataset.
        tar_classes: class id(s) used to filter the target split.
        batch_size: batch size for both loaders.
        command: ``'train'`` selects the training split for the natural loader.

    Returns:
        ``(train_loader, target_loader)``. The target loader samples with
        replacement and is sized to ``len(train_ds)`` so the two loaders can
        be zipped for exactly one pass over the natural data.
    """
    train_ds = build_dataset(dataset,
                             data_root=data_root,
                             is_train=(command == 'train'))
    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    # Target samples come only from `tar_classes`; sampling with replacement
    # keeps this loader in lock-step with the (longer) natural loader.
    target_ds = build_dataset(dataset,
                              data_root=data_root,
                              is_train=True,
                              filter_class=tar_classes)
    target_loader = torch.utils.data.DataLoader(
        target_ds,
        batch_size=batch_size,
        sampler=torch.utils.data.RandomSampler(target_ds,
                                               replacement=True,
                                               num_samples=len(train_ds)),
        num_workers=4,
        pin_memory=True,
    )
    return train_loader, target_loader
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def train(
        surrogate_id: str,
        dataset: str,
        data_root: Union[str, Path],
        tar_classes: Union[int, List[int]] = 24,
        num_epoch: int = 10,
        batch_size: int = 16,
        lr: float = 0.0002,
        betas: Union[float, List[float]] = (0.5, 0.999),
        device: Union[str, torch.device] = torch.device('cuda'),
        command: str = 'train',
        workdir: Union[str,
                       Path] = Path(__file__).parents[1] / 'workdirs') -> None:
    """Train the AIM attack generator against a frozen surrogate classifier.

    The surrogate stays in eval mode and is never updated; only the attack's
    parameters are optimized. The loss pushes adversarial logits/features away
    from the natural sample and toward the target-class sample. The trained
    attack is saved to ``workdir / 'model.pth'``.
    """
    # Paired loaders: natural samples and same-length stream of target-class samples.
    train_loader, target_loader = init_loader(dataset, data_root, tar_classes,
                                              batch_size, command)
    normalizer = norm(dataset, _callable=True)

    # Frozen pretrained surrogate; a forward hook collects mid-layer features
    # (one feature pushed per forward pass, retrieved below with .pop()).
    surrogate = build_surrogate(surrogate_id, pretrain=True).to(device)
    surrogate.eval()
    feat_collecter_handler, feat_collecter = register_collecter(
        surrogate, midlayer_dict[surrogate_id])

    attack = AIMAttack(device=device)
    attack.set_mode('train')
    optim = torch.optim.Adam(attack.get_params(), lr=lr, betas=betas)

    # ContrastiveLoss(0.2): project-defined; 0.2 is presumably a margin/temperature
    # — TODO confirm against gat.models.attack.loss.
    contrastive_loss = ContrastiveLoss(0.2)
    sim_loss = torch.nn.functional.cosine_similarity

    for epoch in range(1, num_epoch + 1):
        attack.set_mode('train')
        enumerator = enumerate(zip(train_loader, target_loader))
        enumerator = tqdm(enumerator,
                          total=len(train_loader),
                          desc=f'Epoch {epoch}')
        for batch_idx, ((x_nat, y_nat), (x_tar, y_tar)) in enumerator:
            # Skip batches where any natural label equals its target label:
            # such pairs give no signal for a targeted attack.
            if torch.any(y_nat == y_tar):
                continue
            x_nat, x_tar = x_nat.to(device), x_tar.to(device)
            y_nat, y_tar = y_nat.to(device), y_tar.to(device)
            # Generate adversarial examples conditioned on the target images.
            x_adv = attack(x_nat, x_tar)

            # Each surrogate forward pushes one feature map onto the collector;
            # pop immediately so features stay paired with their logits.
            logits_nat = surrogate(normalizer(x_nat))
            feat_nat = feat_collecter.pop()
            logits_tar = surrogate(normalizer(x_tar))
            feat_tar = feat_collecter.pop()
            logits_adv = surrogate(normalizer(x_adv))
            feat_adv = feat_collecter.pop()

            # Minimizing this moves adv features away from natural features
            # (+sim(nat, adv)) and toward target features (-sim(tar, adv)).
            loss = (contrastive_loss(logits_adv, logits_nat, logits_tar) +
                    sim_loss(feat_nat, feat_adv) -
                    sim_loss(feat_tar, feat_adv)).mean()

            optim.zero_grad()
            loss.backward()
            optim.step()

    # Detach the forward hook before saving so no collector leaks past training.
    feat_collecter_handler.remove()

    attack.save_ckpt(workdir / 'model.pth')
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@torch.no_grad()
def evaluate(
        ckpt: Union[str, Path],
        dataset: str,
        data_root: Union[str, Path],
        tar_classes: Union[int, List[int]] = 24,
        batch_size: int = 16,
        device: Union[str, torch.device] = torch.device('cuda'),
        command: str = 'train',
        workdir: Union[str,
                       Path] = Path(__file__).parents[1] / 'workdirs') -> None:
    """Evaluate a trained AIM attack checkpoint against every known surrogate.

    For each model, records clean accuracy ('acc', vs. natural labels) and
    targeted attack success rate ('asr', adversarial predictions matching the
    target labels). Results are printed and written to ``workdir/results.json``.
    """
    # init dataloader (natural stream + paired target-class stream)
    eval_loader, target_loader = init_loader(dataset, data_root, tar_classes,
                                             batch_size, command)
    normalizer = norm(dataset, _callable=True)
    # init attack method from the trained checkpoint
    attack = AIMAttack(device=device)
    attack.load_ckpt(ckpt)
    attack.set_mode('eval')
    # init evaluate models: one pretrained classifier per registered surrogate id
    models = {
        surrogate_id: build_surrogate(surrogate_id, pretrain=True).to(device)
        for surrogate_id in list_surrogates()
    }
    for surrogate_id in models.keys():
        models[surrogate_id].eval()
    # Two running meters per model: [0] clean accuracy, [1] targeted success.
    model_meters = {
        surrogate_id: [AverageMeter() for _ in range(2)]
        for surrogate_id in models.keys()
    }
    # evaluate
    enumerator = enumerate(zip(eval_loader, target_loader))
    enumerator = tqdm(enumerator, total=len(eval_loader), desc='Eval')
    for batch_idx, ((x_nat, y_nat), (x_tar, y_tar)) in enumerator:
        x_nat, y_nat = x_nat.to(device), y_nat.to(device)
        x_tar, y_tar = x_tar.to(device), y_tar.to(device)
        x_adv = attack(x_nat, x_tar)
        for surrogate_id, model in models.items():
            logits_nat = model(normalizer(x_nat))
            logits_adv = model(normalizer(x_adv))
            # collect metrics: asr measures agreement with the *target* labels,
            # i.e. targeted attack success, not plain misclassification.
            acc = calc_cls_accuracy(logits_nat, y_nat)
            asr = calc_cls_accuracy(logits_adv, y_tar)
            model_meters[surrogate_id][0].update(acc[0].item(), x_nat.size(0))
            model_meters[surrogate_id][1].update(asr[0].item(), x_nat.size(0))
    # print result
    results = {
        surrogate_id: {
            'acc': meters[0].avg,
            'asr': meters[1].avg
        }
        for surrogate_id, meters in model_meters.items()
    }
    print(pformat(results))
    with open(workdir / 'results.json', 'w') as f:
        json.dump(results, f)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def main() -> None:
    """CLI entry point: dispatch to ``train`` or ``evaluate``.

    Fix: the original unconditionally overwrote the parsed command and the
    training hyper-parameters with hard-coded debug values, which made the
    ``evaluate`` subcommand (and every training CLI flag) inert. The defaults
    are now applied only when no subcommand was given, preserving the old
    "bare invocation trains with defaults" behavior while restoring the CLI.
    """
    args = parse_args()
    if args.command is None:
        # No subcommand on the command line: the train subparser never ran,
        # so its attributes are missing — fill in the historical defaults.
        args.command = 'train'
        args.surrogate_id = 'resnet152'
        args.num_epoch = 10
        args.lr = 0.0002
        args.betas = (0.5, 0.999)
    if args.command == 'train':
        train(args.surrogate_id, args.dataset, args.data_root,
              args.tar_classes, args.num_epoch, args.batch_size, args.lr,
              args.betas, args.device, args.command, args.workdir)
    elif args.command == 'evaluate':
        evaluate(args.ckpt, args.dataset, args.data_root, args.tar_classes,
                 args.batch_size, args.device, args.command, args.workdir)
    else:
        raise NotImplementedError
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# Script entry point: run only when executed directly, not when imported.
if __name__ == '__main__':
    main()
|
attacks/AIM/examples/ens-gen.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Usage
|
| 4 |
+
-----
|
| 5 |
+
./examples/ens-gen.py -d -v -s 0 \
|
| 6 |
+
--dataset imagenet -b 16 --eps 8 --workdir "workdirs" \
|
| 7 |
+
--device "cuda:0" \
|
| 8 |
+
train --n-ep 1 \
|
| 9 |
+
--surrogate-model-ids vgg19 inception_v3 resnet152 densenet169 \
|
| 10 |
+
--lr 0.0002 --beta 0.5 0.999 \
|
| 11 |
+
--use-logit-loss --use-logit-weights --use-logit-softmax-weights
|
| 12 |
+
"""
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from pprint import pformat
|
| 17 |
+
from typing import List, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torchvision
|
| 21 |
+
from torch.nn import functional as F
|
| 22 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
from gat.datasets import build_dataset, list_datasets
|
| 26 |
+
from gat.datasets.transforms import norm
|
| 27 |
+
from gat.models.attack import CDAAttack
|
| 28 |
+
from gat.models.attack.optim import (SAM, disable_running_stats,
|
| 29 |
+
enable_running_stats)
|
| 30 |
+
from gat.models.surrogate import (build_surrogate, feat_col, list_surrogates,
|
| 31 |
+
midlayer_dict)
|
| 32 |
+
from gat.runtime import AverageMeter, calc_cls_accuracy, fix_random, randid
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class CLIParser:
    """Namespace of static helpers that build and post-process the ens-gen CLI.

    ``init_*`` methods register arguments (global options plus one group per
    subcommand: train / evaluate / evaluate-pgd); ``post_*`` methods validate
    and normalize the parsed namespace. ``parse_args`` wires them together.
    """

    @staticmethod
    def init_basic_parser(p: argparse.ArgumentParser):
        """Register options shared by all subcommands on parser *p*."""
        g_basic = p.add_argument_group('Basic Settings')
        g_basic.add_argument('-v',
                             '--verbose',
                             action='store_true',
                             default=False)
        # --dev redirects output into a throwaway 'workdirs-dev' tree.
        g_basic.add_argument('-d', '--dev', action='store_true', default=False)
        g_basic.add_argument('-s', '--seed', type=int, default=0)
        g_basic.add_argument('--expid', type=str, default=randid(4))
        g_basic.add_argument('--device', type=str, default='cuda')

        g_path = p.add_argument_group('Path Settings')
        g_path.add_argument('--workdir', type=str, default='workdirs')
        g_path.add_argument('--data-root',
                            type=str,
                            default=Path(__file__).parent / '../data' /
                            'in_1k')

        g_ds = p.add_argument_group('Dataset Settings')
        g_ds.add_argument('--dataset',
                          type=str,
                          default='imagenet',
                          choices=list_datasets())
        g_ds.add_argument('-b', '--batch-size', type=int, default=16)

        g_at_basic = p.add_argument_group('General Attack Settings')
        # Epsilon is given in integer pixel units (of 255) and converted to
        # [0, 1] scale later in post_basic_parser.
        g_at_basic.add_argument('--eps',
                                '--epsilon',
                                dest='epsilon',
                                type=int,
                                default=8,
                                choices=[1, 2, 4, 8, 16])

    @staticmethod
    def post_basic_parser(args: argparse.Namespace):
        """Resolve paths/device, rescale epsilon, seed RNGs, dump the config."""
        if args.dev:
            args.workdir = args.workdir.replace('workdirs', 'workdirs-dev')
        args.workdir = Path(args.workdir) / args.expid
        args.workdir.mkdir(parents=True, exist_ok=True)
        args.device = torch.device(args.device)

        args.ckpt = args.workdir / 'model.pth'
        args.tf_logger = SummaryWriter(args.workdir / 'tf_log')

        # Convert integer pixel budgets to [0, 1] image scale.
        args.epsilon /= 255.0
        if args.command == 'evaluate-pgd':
            args.alpha /= 255.0

        fix_random(args.seed)

        if args.verbose:
            print(pformat(vars(args)))
        with open(args.workdir / f'args-{args.command}.txt', 'w') as f:
            f.write(pformat(vars(args)))

    @staticmethod
    def init_train_parser(p: argparse.ArgumentParser):
        """Register options for the 'train' subcommand."""
        g_at = p.add_argument_group('Attack Settings')
        g_at.add_argument('--sur-ids',
                          '--surrogate-model-ids',
                          dest='surrogate_model_ids',
                          type=str,
                          default=['resnet152'],
                          nargs='+',
                          choices=list_surrogates())
        g_at.add_argument('--n-ep',
                          '--num-epoch',
                          dest='num_epoch',
                          type=int,
                          default=10)

        g_optim = p.add_argument_group('Optimization Settings')
        g_optim.add_argument('--use-sam', action='store_true', default=False)
        g_optim.add_argument('--lr', type=float, default=0.0002)
        g_optim.add_argument('--betas',
                             type=float,
                             nargs=2,
                             default=(0.5, 0.999))

        # Mutually-constrained loss flags; combinations are validated in
        # post_train_parser.
        g_loss = p.add_argument_group('Loss Func Settings')
        g_loss.add_argument('--use-logit-loss',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-kl',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-weights',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-softmax-weights',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-feat-loss',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-feat-attn',
                            action='store_true',
                            default=False)

    @staticmethod
    def post_train_parser(args: argparse.Namespace):
        """Validate loss-flag combinations (only when the command is 'train')."""
        if args.command == 'train':
            # Exactly one of the two loss families must be selected.
            assert args.use_logit_loss ^ args.use_feat_loss
            if args.use_logit_kl:
                assert not args.use_feat_loss
            if args.use_feat_attn:
                assert not args.use_logit_loss
            if args.use_logit_weights:
                assert args.use_logit_loss
            if args.use_logit_softmax_weights:
                assert args.use_logit_loss

    @staticmethod
    def init_evaluate_parser(p: argparse.ArgumentParser):
        # 'evaluate' takes no extra options beyond the basic ones.
        pass

    @staticmethod
    def post_evaluate_parser(args: argparse.Namespace):
        # Nothing to validate for 'evaluate'.
        pass

    @staticmethod
    def init_evaluate_pgd_parser(p: argparse.ArgumentParser):
        """Register options for the 'evaluate-pgd' baseline subcommand."""
        g_at = p.add_argument_group('Attack Settings')
        g_at.add_argument('--surrogate-model-ids',
                          type=str,
                          default=['resnet152'],
                          nargs='+',
                          choices=list_surrogates())

        g_optim = p.add_argument_group('Optimization Settings')
        g_optim.add_argument('--num-step', type=int, default=100)
        # PGD step size in pixel units; rescaled in post_basic_parser.
        g_optim.add_argument('--alpha',
                            type=int,
                            default=2,
                            choices=[1, 2, 4, 8, 16])

        g_loss = p.add_argument_group('Loss Func Settings')
        g_loss.add_argument('--use-loss-avg',
                            action='store_true',
                            default=False)
        g_loss.add_argument('--use-logit-avg',
                            action='store_true',
                            default=False)

    @staticmethod
    def post_evaluate_pgd_parser(args: argparse.Namespace):
        """Require exactly one ensembling mode for 'evaluate-pgd'."""
        if args.command == 'evaluate-pgd':
            assert args.use_loss_avg ^ args.use_logit_avg

    @staticmethod
    def parse_args():
        """Build the full parser, parse argv, run all post-processing hooks."""
        p = argparse.ArgumentParser()
        CLIParser.init_basic_parser(p)
        sub_p = p.add_subparsers(dest='command')

        CLIParser.init_train_parser(sub_p.add_parser('train'))
        CLIParser.init_evaluate_parser(sub_p.add_parser('evaluate'))
        CLIParser.init_evaluate_pgd_parser(sub_p.add_parser('evaluate-pgd'))
        args = p.parse_args()
        # Subcommand-specific validation first (each guards on args.command),
        # then the shared normalization, which also writes args-<command>.txt.
        CLIParser.post_train_parser(args)
        CLIParser.post_evaluate_parser(args)
        CLIParser.post_evaluate_pgd_parser(args)

        CLIParser.post_basic_parser(args)

        return args
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def init_loader(
        dataset: str,
        data_root: Union[str, Path],
        num_epoch: int = 1,
        batch_size: int = 16,
        command: str = 'train'
) -> 'tuple[torch.utils.data.DataLoader, object]':
    """Build a dataloader and the dataset's normalization transform.

    Fix: the original annotation claimed ``List[DataLoader]`` but the function
    returns ``(dataloader, normalizer)``; the annotation is now a lazily
    evaluated string, so no new ``typing`` import is required.

    Args:
        dataset: dataset key understood by ``build_dataset``.
        data_root: filesystem root of the dataset.
        num_epoch: the sampler draws ``len(ds) * num_epoch`` samples with
            replacement, folding multi-epoch training into a single pass.
        batch_size: loader batch size.
        command: ``'train'`` selects the training split.

    Returns:
        ``(dataloader, normalizer)`` where *normalizer* is the callable
        returned by ``norm(dataset, _callable=True)``.
    """
    ds = build_dataset(dataset,
                       data_root=data_root,
                       is_train=(command == 'train'))
    dataloader = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        sampler=torch.utils.data.RandomSampler(ds,
                                               replacement=True,
                                               num_samples=len(ds) *
                                               num_epoch),
        num_workers=4,
        pin_memory=True,
    )
    normalizer = norm(dataset, _callable=True)
    return dataloader, normalizer
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def init_models(model_ids: Union[str, List[str]],
                device: Union[str, torch.device] = torch.device('cuda')):
    """Build pretrained surrogate models on *device*, all switched to eval mode.

    Accepts a single model id or a list of ids; always returns a list.
    """
    if isinstance(model_ids, str):
        model_ids = [model_ids]
    models = []
    for model_id in model_ids:
        model = build_surrogate(model_id, pretrain=True).to(device)
        model.eval()
        models.append(model)
    return models
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def calc_loss(x_nat: torch.Tensor,
|
| 242 |
+
y_nat: torch.Tensor,
|
| 243 |
+
x_adv: torch.Tensor,
|
| 244 |
+
feat_collecter: List,
|
| 245 |
+
surrogate_models: List[torch.nn.Module],
|
| 246 |
+
normalizer: torchvision.transforms.Compose,
|
| 247 |
+
use_logit_loss: bool,
|
| 248 |
+
use_logit_kl: bool,
|
| 249 |
+
use_logit_weights: bool,
|
| 250 |
+
use_logit_softmax_weights: bool,
|
| 251 |
+
use_feat_loss: bool,
|
| 252 |
+
use_feat_attn: bool,
|
| 253 |
+
device: Union[str, torch.device] = torch.device('cuda')):
|
| 254 |
+
loss_sur = []
|
| 255 |
+
for surrogate_model in surrogate_models:
|
| 256 |
+
logit_nat = surrogate_model(normalizer(x_nat))
|
| 257 |
+
feat_nat = feat_collecter.pop()
|
| 258 |
+
logit_adv = surrogate_model(normalizer(x_adv))
|
| 259 |
+
feat_adv = feat_collecter.pop()
|
| 260 |
+
if use_logit_loss:
|
| 261 |
+
if use_logit_kl:
|
| 262 |
+
loss_sur.append(-(F.kl_div(F.log_softmax(logit_adv, dim=1),
|
| 263 |
+
F.softmax(logit_nat, dim=1)) +
|
| 264 |
+
F.kl_div(F.log_softmax(logit_nat, dim=1),
|
| 265 |
+
F.softmax(logit_adv, dim=1))))
|
| 266 |
+
else:
|
| 267 |
+
loss_sur.append(-(F.cross_entropy(logit_adv, y_nat).mean()))
|
| 268 |
+
elif use_feat_loss:
|
| 269 |
+
if use_feat_attn:
|
| 270 |
+
attn = torch.abs(torch.mean(feat_nat, dim=1, keepdim=True))
|
| 271 |
+
else:
|
| 272 |
+
attn = torch.ones_like(feat_nat)
|
| 273 |
+
loss_sur.append(1 + F.cosine_similarity(attn * feat_nat, attn *
|
| 274 |
+
feat_adv).mean())
|
| 275 |
+
else:
|
| 276 |
+
raise NotImplementedError
|
| 277 |
+
loss_sur = torch.stack(loss_sur)
|
| 278 |
+
if use_logit_weights:
|
| 279 |
+
if use_logit_softmax_weights:
|
| 280 |
+
loss_weights = torch.nn.functional.softmax(loss_sur)
|
| 281 |
+
else:
|
| 282 |
+
loss_weights = torch.nn.functional.softmin(loss_sur)
|
| 283 |
+
loss_all = torch.sum(loss_weights * loss_sur)
|
| 284 |
+
else:
|
| 285 |
+
loss_all = loss_sur.mean()
|
| 286 |
+
|
| 287 |
+
return loss_all
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def train(surrogate_model_ids: Union[str, List[str]],
          epsilon: float = 16.0 / 255.0,
          num_epoch: int = 10,
          dataset: str = 'imagenet',
          batch_size: int = 16,
          use_sam: bool = False,
          lr: float = 0.0002,
          betas: Union[float, List[float]] = (0.5, 0.999),
          use_logit_loss: bool = False,
          use_logit_kl: bool = False,
          use_logit_weights: bool = False,
          use_logit_softmax_weights: bool = False,
          use_feat_loss: bool = False,
          use_feat_attn: bool = False,
          device: Union[str, torch.device] = torch.device('cuda'),
          workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
          data_root: Union[str,
                           Path] = Path(__file__).parent / '../data' / 'in_1k',
          tf_logger: SummaryWriter = None) -> None:
    """
    Train the attack model with the given surrogate models.

    The generator (CDAAttack) is optimized against an ensemble of frozen
    pretrained surrogates using the loss configured by the ``use_*`` flags
    (see ``calc_loss``). Optionally uses SAM (two forward/backward passes per
    step). The trained attack is saved to ``workdir / 'model.pth'``.
    """
    # The loader itself spans num_epoch passes (sampling with replacement),
    # so the loop below is a single pass over it.
    loader, normalizer = init_loader(dataset, data_root, num_epoch, batch_size,
                                     'train')
    surrogate_models = init_models(surrogate_model_ids, device)

    attack = CDAAttack(device=device, epsilon=epsilon)
    attack.set_mode('train')
    if use_sam:
        # SAM wraps the base optimizer; requires the two-step update below.
        optim = SAM(attack.get_params(), torch.optim.Adam, lr=lr, betas=betas)
    else:
        optim = torch.optim.Adam(attack.get_params(), lr=lr, betas=betas)

    # feat_col hooks each surrogate's configured mid-layer; every forward
    # pushes one feature that calc_loss retrieves with .pop().
    with feat_col(surrogate_models,
                  [midlayer_dict[_]
                   for _ in surrogate_model_ids]) as feat_collecter:
        attack.set_mode('train')
        enumerator = tqdm(enumerate(loader), total=len(loader), desc='')
        for step, (x_nat, y_nat) in enumerator:
            x_nat, y_nat = x_nat.to(device), y_nat.to(device)

            if use_sam:
                # SAM step 1: compute gradient at the current weights and
                # perturb toward the sharpness-ascent point.
                enable_running_stats(attack.get_model())
                loss_v = calc_loss(x_nat, y_nat, attack(x_nat), feat_collecter,
                                   surrogate_models, normalizer,
                                   use_logit_loss, use_logit_kl,
                                   use_logit_weights,
                                   use_logit_softmax_weights, use_feat_loss,
                                   use_feat_attn, device)
                loss_v.backward()
                optim.first_step(zero_grad=True)
                # SAM step 2: re-evaluate at the perturbed weights (batch-norm
                # running stats frozen) and take the actual descent step.
                disable_running_stats(attack.get_model())
                calc_loss(x_nat, y_nat, attack(x_nat), feat_collecter,
                          surrogate_models, normalizer, use_logit_loss,
                          use_logit_kl, use_logit_weights,
                          use_logit_softmax_weights, use_feat_loss,
                          use_feat_attn, device).backward()
                optim.second_step(zero_grad=True)
            else:
                # Plain Adam step on the ensemble loss.
                x_adv = attack(x_nat)
                loss_v = calc_loss(x_nat, y_nat, x_adv, feat_collecter,
                                   surrogate_models, normalizer,
                                   use_logit_loss, use_logit_kl,
                                   use_logit_weights,
                                   use_logit_softmax_weights, use_feat_loss,
                                   use_feat_attn, device)
                optim.zero_grad()
                loss_v.backward()
                optim.step()

            if tf_logger:
                # With SAM, loss_v is the first-pass loss.
                tf_logger.add_scalar('loss', loss_v.item(), step)
                tf_logger.add_scalar('lr', optim.param_groups[0]['lr'], step)

    attack.save_ckpt(workdir / 'model.pth')
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
@torch.no_grad()
def evaluate(
    ckpt: Union[str, Path],
    epsilon: float = 16.0 / 255.0,
    dataset: str = 'imagenet',
    batch_size: int = 16,
    device: Union[str, torch.device] = torch.device('cuda'),
    workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
    data_root: Union[str, Path] = Path(__file__).parent / '../data' / 'in_1k',
) -> None:
    """
    Evaluate the attack model with the given surrogate models

    Loads a trained CDAAttack checkpoint and measures, for every registered
    target model, accuracy on natural images ('nat_acc') and accuracy on
    adversarial images ('adv_acc' — lower means a stronger attack). Results
    are printed and written to ``workdir/results.json``.
    """
    loader, normalizer = init_loader(dataset, data_root, 1, batch_size,
                                     'evaluate')
    # One pretrained target classifier per registered surrogate id.
    target_models = {
        k: v
        for k, v in zip(list_surrogates(),
                        init_models(list_surrogates(), device))
    }
    # Two meters per model: [0] natural accuracy, [1] adversarial accuracy.
    target_acc_meters = {
        target_model_id: [AverageMeter() for _ in range(2)]
        for target_model_id in target_models.keys()
    }
    # init attack method
    attack = CDAAttack(device=device, epsilon=epsilon)
    attack.load_ckpt(ckpt)
    attack.set_mode('eval')
    # evaluate
    enumerator = tqdm(enumerate(loader), total=len(loader), desc='Eval')
    for step, (x_nat, y_nat) in enumerator:
        x_nat, y_nat = x_nat.to(device), y_nat.to(device)
        x_adv = attack(x_nat)
        for target_model_id, target_model in target_models.items():
            logit_nat = target_model(normalizer(x_nat))
            logit_adv = target_model(normalizer(x_adv))
            # collect metrics — both against the true labels; 'asr' here is
            # really adversarial accuracy (untargeted setting).
            target_acc = calc_cls_accuracy(logit_nat, y_nat)
            target_asr = calc_cls_accuracy(logit_adv, y_nat)
            target_acc_meters[target_model_id][0].update(
                target_acc[0].item(), x_nat.size(0))
            target_acc_meters[target_model_id][1].update(
                target_asr[0].item(), x_nat.size(0))
    results = {
        target_model_id: {
            'nat_acc': target_acc_meter[0].avg,
            'adv_acc': target_acc_meter[1].avg
        }
        for target_model_id, target_acc_meter in target_acc_meters.items()
    }
    print(pformat(results))
    with open(workdir / 'results.json', 'w') as f:
        json.dump(results, f)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def evaluate_pgd(
    surrogate_model_ids: Union[str, List[str]],
    epsilon: float = 16.0 / 255.0,
    num_step: int = 1000,
    alpha: float = 2.0 / 255.0,
    dataset: str = 'imagenet',
    batch_size: int = 16,
    use_loss_avg: bool = False,
    use_logit_avg: bool = False,
    device: Union[str, torch.device] = torch.device('cuda'),
    workdir: Union[str, Path] = Path(__file__).parents[1] / 'workdirs',
    data_root: Union[str, Path] = Path(__file__).parent / '../data' / 'in_1k',
):
    """Transfer-attack baseline: ensemble PGD on the surrogates, evaluated on
    every registered target model.

    The inner loop runs L-inf PGD with either loss averaging or logit
    averaging across the surrogate ensemble (exactly one must be enabled).
    Results ('nat_acc' / 'adv_acc' per target model) are printed and written
    to ``workdir/results-pgd.json``.

    Fix: the fallback branch raised ``NotADirectoryError`` — an obvious typo;
    it now raises ``NotImplementedError`` consistently with the rest of the
    file.
    """
    loader, normalizer = init_loader(dataset, data_root, 1, batch_size,
                                     'evaluate')
    surrogate_models = init_models(surrogate_model_ids, device)
    target_models = {
        k: v
        for k, v in zip(list_surrogates(),
                        init_models(list_surrogates(), device))
    }
    # Two meters per model: [0] natural accuracy, [1] adversarial accuracy.
    target_acc_meters = {
        target_model_id: [AverageMeter() for _ in range(2)]
        for target_model_id in target_models.keys()
    }
    # evaluate
    enumerator = tqdm(enumerate(loader), total=len(loader), desc='')
    for step, (x_nat, y_nat) in enumerator:
        x_nat, y_nat = x_nat.to(device), y_nat.to(device)
        # attack: PGD starting from the clean batch; x_nat is re-bound to the
        # current iterate, x_nat_ori keeps the clean pixels for projection.
        x_nat_ori = x_nat.data
        for _ in range(num_step):
            x_nat.requires_grad = True
            # NOTE(review): the surrogates are fed *un-normalized* pixels here,
            # while evaluation below applies `normalizer` — confirm this
            # mismatch is intentional.
            if use_loss_avg:
                # Average the CE losses of all surrogates.
                loss_all = 0.0
                for surrogate_model in surrogate_models:
                    logit = surrogate_model(x_nat)
                    surrogate_model.zero_grad()
                    loss_all += F.cross_entropy(logit, y_nat)
            elif use_logit_avg:
                # Average the raw logits, then take a single CE.
                logit = torch.stack([
                    surrogate_model(x_nat)
                    for surrogate_model in surrogate_models
                ]).mean(dim=0)
                loss_all = F.cross_entropy(logit, y_nat)
            else:
                # Was `NotADirectoryError` — corrected typo.
                raise NotImplementedError
            loss_all.backward()
            # FGSM step, then project back into the epsilon ball and [0, 1].
            x_adv_ = x_nat + alpha * x_nat.grad.sign()
            eta = torch.clamp(x_adv_ - x_nat_ori, min=-epsilon, max=epsilon)
            x_nat = torch.clamp(x_nat_ori + eta, min=0.0, max=1.0).detach_()
        x_adv = x_nat
        x_nat = x_nat_ori
        # eval
        with torch.no_grad():
            for target_model_id, target_model in target_models.items():
                logit_nat = target_model(normalizer(x_nat))
                logit_adv = target_model(normalizer(x_adv))
                # collect
                target_acc_ = calc_cls_accuracy(logit_nat, y_nat)
                target_asr_ = calc_cls_accuracy(logit_adv, y_nat)
                target_acc_meters[target_model_id][0].update(
                    target_acc_[0].item(), x_nat.size(0))
                target_acc_meters[target_model_id][1].update(
                    target_asr_[0].item(), x_nat.size(0))
    results = {
        target_model_id: {
            'nat_acc': target_acc_meter[0].avg,
            'adv_acc': target_acc_meter[1].avg
        }
        for target_model_id, target_acc_meter in target_acc_meters.items()
    }
    print(pformat(results))
    with open(workdir / 'results-pgd.json', 'w') as f:
        json.dump(results, f)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def main() -> None:
    """CLI entry point: parse arguments and run the selected subcommand."""
    args = CLIParser.parse_args()
    command = args.command
    if command == 'train':
        train(args.surrogate_model_ids,
              epsilon=args.epsilon,
              num_epoch=args.num_epoch,
              dataset=args.dataset,
              batch_size=args.batch_size,
              use_sam=args.use_sam,
              lr=args.lr,
              betas=args.betas,
              use_logit_loss=args.use_logit_loss,
              use_logit_kl=args.use_logit_kl,
              use_logit_weights=args.use_logit_weights,
              use_logit_softmax_weights=args.use_logit_softmax_weights,
              use_feat_loss=args.use_feat_loss,
              use_feat_attn=args.use_feat_attn,
              device=args.device,
              workdir=args.workdir,
              data_root=args.data_root,
              tf_logger=args.tf_logger)
    elif command == 'evaluate':
        evaluate(args.ckpt,
                 epsilon=args.epsilon,
                 dataset=args.dataset,
                 batch_size=args.batch_size,
                 device=args.device,
                 workdir=args.workdir,
                 data_root=args.data_root)
    elif command == 'evaluate-pgd':
        evaluate_pgd(args.surrogate_model_ids,
                     epsilon=args.epsilon,
                     num_step=args.num_step,
                     alpha=args.alpha,
                     dataset=args.dataset,
                     batch_size=args.batch_size,
                     use_loss_avg=args.use_loss_avg,
                     use_logit_avg=args.use_logit_avg,
                     device=args.device,
                     workdir=args.workdir,
                     data_root=args.data_root)
    else:
        raise NotImplementedError
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
# Script entry point: run only when executed directly, not when imported.
if __name__ == '__main__':
    main()
|
attacks/AIM/examples/workdirs/1MsK/args.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{'batch_size': 16,
|
| 2 |
+
'ckpt': WindowsPath('workdirs/1MsK/model.pth'),
|
| 3 |
+
'command': None,
|
| 4 |
+
'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
|
| 5 |
+
'dataset': 'imagenet',
|
| 6 |
+
'device': device(type='cuda'),
|
| 7 |
+
'expid': '1MsK',
|
| 8 |
+
'seed': 0,
|
| 9 |
+
'tar_classes': 24,
|
| 10 |
+
'verbose': False,
|
| 11 |
+
'workdir': WindowsPath('workdirs/1MsK')}
|
attacks/AIM/examples/workdirs/CGKQ/args.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{'batch_size': 16,
|
| 2 |
+
'ckpt': WindowsPath('workdirs/CGKQ/model.pth'),
|
| 3 |
+
'command': None,
|
| 4 |
+
'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
|
| 5 |
+
'dataset': 'imagenet',
|
| 6 |
+
'device': device(type='cuda'),
|
| 7 |
+
'expid': 'CGKQ',
|
| 8 |
+
'seed': 0,
|
| 9 |
+
'tar_classes': 24,
|
| 10 |
+
'verbose': False,
|
| 11 |
+
'workdir': WindowsPath('workdirs/CGKQ')}
|
attacks/AIM/examples/workdirs/Fvu1/args.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{'batch_size': 16,
|
| 2 |
+
'ckpt': WindowsPath('workdirs/Fvu1/model.pth'),
|
| 3 |
+
'command': None,
|
| 4 |
+
'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
|
| 5 |
+
'dataset': 'imagenet',
|
| 6 |
+
'device': device(type='cuda'),
|
| 7 |
+
'expid': 'Fvu1',
|
| 8 |
+
'seed': 0,
|
| 9 |
+
'tar_classes': 24,
|
| 10 |
+
'verbose': False,
|
| 11 |
+
'workdir': WindowsPath('workdirs/Fvu1')}
|
attacks/AIM/examples/workdirs/Qonx/args.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{'batch_size': 16,
|
| 2 |
+
'ckpt': WindowsPath('workdirs/Qonx/model.pth'),
|
| 3 |
+
'command': None,
|
| 4 |
+
'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
|
| 5 |
+
'dataset': 'imagenet',
|
| 6 |
+
'device': device(type='cuda'),
|
| 7 |
+
'expid': 'Qonx',
|
| 8 |
+
'seed': 0,
|
| 9 |
+
'tar_classes': 24,
|
| 10 |
+
'verbose': False,
|
| 11 |
+
'workdir': WindowsPath('workdirs/Qonx')}
|
attacks/AIM/examples/workdirs/fnNs/args.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{'batch_size': 16,
|
| 2 |
+
'ckpt': WindowsPath('workdirs/fnNs/model.pth'),
|
| 3 |
+
'command': None,
|
| 4 |
+
'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
|
| 5 |
+
'dataset': 'imagenet',
|
| 6 |
+
'device': device(type='cuda'),
|
| 7 |
+
'expid': 'fnNs',
|
| 8 |
+
'seed': 0,
|
| 9 |
+
'tar_classes': 24,
|
| 10 |
+
'verbose': False,
|
| 11 |
+
'workdir': WindowsPath('workdirs/fnNs')}
|
attacks/AIM/examples/workdirs/jtMb/args.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{'batch_size': 16,
|
| 2 |
+
'ckpt': WindowsPath('workdirs/jtMb/model.pth'),
|
| 3 |
+
'command': None,
|
| 4 |
+
'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
|
| 5 |
+
'dataset': 'imagenet',
|
| 6 |
+
'device': device(type='cuda'),
|
| 7 |
+
'expid': 'jtMb',
|
| 8 |
+
'seed': 0,
|
| 9 |
+
'tar_classes': 24,
|
| 10 |
+
'verbose': False,
|
| 11 |
+
'workdir': WindowsPath('workdirs/jtMb')}
|
attacks/AIM/examples/workdirs/krhX/args.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{'batch_size': 16,
|
| 2 |
+
'ckpt': WindowsPath('workdirs/krhX/model.pth'),
|
| 3 |
+
'command': None,
|
| 4 |
+
'data_root': WindowsPath('D:/Sharing/Programs/3-Durable-Adv/PyCIL-master/attacks/GAT/examples/../data/in_1k'),
|
| 5 |
+
'dataset': 'imagenet',
|
| 6 |
+
'device': device(type='cuda'),
|
| 7 |
+
'expid': 'krhX',
|
| 8 |
+
'seed': 0,
|
| 9 |
+
'tar_classes': 24,
|
| 10 |
+
'verbose': False,
|
| 11 |
+
'workdir': WindowsPath('workdirs/krhX')}
|
attacks/AIM/setup.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Union
|
| 5 |
+
|
| 6 |
+
from setuptools import find_packages, setup
|
| 7 |
+
|
| 8 |
+
# Make the in-tree src/ package importable so the version string can be
# read below without requiring a prior install.
try:
    sys.path.append(str(Path(__file__).parent / 'src'))
    import gat
except ImportError:
    # NOTE(review): the message names "aim_attack" but the import that
    # failed is "gat" — confirm which package name users should install.
    print('Please install aim_attack package')
    sys.exit(1)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def read(path: Union[str, Path]) -> str:
    """Return the full text content of *path*.

    The file is decoded as UTF-8 explicitly; the original relied on the
    platform's locale encoding, which breaks reading README.md on systems
    whose default codec is not UTF-8 (e.g. cp1252 on Windows).
    """
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Package metadata. The version is sourced from the in-tree ``gat``
# package imported above so it is defined in exactly one place.
setup(name='pygat',
      version=gat.VERSION,
      description='GAT: Generative Attack Toolbox',
      long_description=read('README.md'),
      long_description_content_type='text/markdown',
      author='Terry Li',
      url='https://terrytengli.com/GAT/',
      classifiers=[
          'Programming Language :: Python :: 3',
          'License :: OSI Approved :: MIT License'
      ],
      python_requires='>=3.10',
      # src-layout: packages live under ./src, mapped to the root namespace.
      packages=find_packages('src'),
      package_dir={'': 'src'},
      entry_points={
          'console_scripts': [
              'aim-api=gat.runtime.api.aim_attack:main',
          ],
      },
      install_requires=[
          'tqdm', 'tabulate', 'torch', 'torchvision', 'tensorboard', 'jupyter'
      ],
      extras_require={
          # 'cli' adds the HTTP-serving stack; 'ens-gen' is currently a
          # placeholder with no extra dependencies.
          'cli': ['python-multipart', 'fastapi', 'uvicorn'],
          'ens-gen': []
      },
      include_package_data=True)
|
attacks/AIM/src/gat/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
VERSION = '202411.2'
|
attacks/AIM/src/gat/datasets/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .builder import build_dataset, list_datasets
|
| 2 |
+
from .cub import cub
|
| 3 |
+
from .imagenet import imagenet
|
| 4 |
+
|
| 5 |
+
__all__ = ['build_dataset', 'list_datasets', 'imagenet', 'cub']
|
attacks/AIM/src/gat/datasets/builder.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..runtime.factory import Registry
|
| 2 |
+
|
| 3 |
+
DATASET_REGISTRY = Registry('DATASET')
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def build_dataset(_type: str, *args, **kwargs) -> object:
    """Instantiate the dataset registered under *_type*.

    All positional and keyword arguments are forwarded unchanged to the
    registered dataset factory.
    """
    factory = DATASET_REGISTRY.get(_type)
    return factory(*args, **kwargs)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def list_datasets():
    """Return the names of all registered datasets.

    NOTE(review): this reaches into the registry's private ``_obj_map``;
    confirm whether ``Registry`` exposes a public listing API instead.
    """
    return [*DATASET_REGISTRY._obj_map]
|
attacks/AIM/src/gat/datasets/cub.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .builder import DATASET_REGISTRY
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@DATASET_REGISTRY.register()
def cub():
    """Placeholder loader for the CUB-200 dataset (not yet implemented).

    Raises:
        NotImplementedError: Always; the loader is a stub.
    """
    # Was ``raise NotADirectoryError`` — almost certainly a typo for
    # NotImplementedError, since no filesystem directory is involved here.
    raise NotImplementedError
|
attacks/AIM/src/gat/datasets/env.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CIFAR10_DEFAULT_MEAN = (0.4914, 0.4822, 0.4465)
|
| 2 |
+
CIFAR10_DEFAULT_STD = (0.2470, 0.2435, 0.2616)
|
| 3 |
+
CIFAR100_DEFAULT_MEAN = (0.5071, 0.4865, 0.4409)
|
| 4 |
+
CIFAR100_DEFAULT_STD = (0.2673, 0.2564, 0.2762)
|
| 5 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
| 6 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
attacks/AIM/src/gat/datasets/imagenet.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import List, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torchvision
|
| 6 |
+
|
| 7 |
+
from .builder import DATASET_REGISTRY
|
| 8 |
+
from .transforms import resize_256_224, to_color, to_ts
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@DATASET_REGISTRY.register()
def imagenet(
    data_root: Union[str, Path],
    is_train: bool = True,
    filter_class: Union[int, List[int], None] = None,
) -> torch.utils.data.Dataset:
    """Build an ImageNet ``ImageFolder`` dataset.

    Args:
        data_root: Directory containing ``train``/``val`` sub-folders.
        is_train: Select the ``train`` split when True, else ``val``.
        filter_class: Optional class index or list of indices; when given,
            only samples of those classes are kept.

    Returns:
        A torchvision ``ImageFolder`` with resize-256/center-crop-224,
        RGB conversion and to-tensor transforms applied.
    """
    if isinstance(data_root, str):
        data_root = Path(data_root)
    data_root = data_root / ('train' if is_train else 'val')

    _transforms = resize_256_224() + to_color() + to_ts()

    _ds = torchvision.datasets.ImageFolder(
        data_root,
        transform=torchvision.transforms.Compose(_transforms),
    )

    if isinstance(filter_class, int):
        filter_class = [filter_class]
    if filter_class:
        keep = set(filter_class)  # O(1) membership per sample
        _ds.samples = [s for s in _ds.samples if s[1] in keep]
        # Keep ImageFolder's mirror attributes consistent with the filtered
        # samples; the original updated only ``samples``, leaving
        # ``targets`` and ``imgs`` stale.
        _ds.targets = [s[1] for s in _ds.samples]
        _ds.imgs = _ds.samples

    return _ds
|
attacks/AIM/src/gat/datasets/transforms.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torchvision
|
| 4 |
+
|
| 5 |
+
from . import env
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def resize_256_224() -> List:
    """Standard ImageNet eval pipeline: shorter side to 256, center-crop 224."""
    resize = torchvision.transforms.Resize(size=256)
    crop = torchvision.transforms.CenterCrop(size=(224, 224))
    return [resize, crop]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def resize_512_448() -> List:
    """High-resolution variant: shorter side to 512, center-crop 448."""
    resize = torchvision.transforms.Resize(size=512)
    crop = torchvision.transforms.CenterCrop(size=(448, 448))
    return [resize, crop]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def resize_224() -> List:
    """Resize the shorter image side to 224 pixels (no crop)."""
    return [
        torchvision.transforms.Resize(size=224),
    ]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def hflip(p: float = 0.5) -> List:
    """Random horizontal flip with probability *p*.

    Raises:
        ValueError: If *p* is not a valid probability in [0, 1].
    """
    # Validate with an exception rather than ``assert`` so the check
    # survives running under ``python -O``.
    if not 0 <= p <= 1:
        raise ValueError(f'p must be in [0, 1], got {p}')
    return [torchvision.transforms.RandomHorizontalFlip(p)]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def to_ts() -> List:
    """PIL image -> float tensor scaled into [0, 1]."""
    to_tensor = torchvision.transforms.ToTensor()
    return [to_tensor]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def to_pil() -> List:
    """Tensor/ndarray -> PIL image."""
    to_image = torchvision.transforms.ToPILImage()
    return [to_image]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def to_color() -> List:
    """Force three-channel RGB; images already in RGB pass through unchanged."""

    def _ensure_rgb(img):
        return img if img.mode == 'RGB' else img.convert('RGB')

    return [torchvision.transforms.Lambda(_ensure_rgb)]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def norm(dataset: str = 'IMAGENET', _callable: bool = False) -> List:
    """Per-dataset normalization transform using mean/std constants from ``env``.

    Args:
        dataset: Dataset name (case-insensitive) matching the
            ``<NAME>_DEFAULT_MEAN``/``_STD`` constants in ``env``.
        _callable: When True, return a composed callable instead of a list.

    Raises:
        ValueError: If no normalization constants exist for *dataset*.
    """
    dataset = dataset.upper()
    try:
        mean_std = (
            getattr(env, dataset + '_DEFAULT_MEAN'),
            getattr(env, dataset + '_DEFAULT_STD'),
        )
    except AttributeError:
        # Surface a clear error instead of an opaque AttributeError about
        # a synthesized constant name.
        raise ValueError(f'No normalization constants for dataset: {dataset!r}')
    transforms = [torchvision.transforms.Normalize(*mean_std)]
    if _callable:
        return torchvision.transforms.Compose(transforms)
    return transforms
attacks/AIM/src/gat/models/__init__.py
ADDED
|
File without changes
|
attacks/AIM/src/gat/models/attack/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .aim_attack import AIMAttack
|
| 2 |
+
from .cda_attack import CDAAttack
|
| 3 |
+
from .loss.logits import ContrastiveLoss
|
| 4 |
+
|
| 5 |
+
__all__ = ['AIMAttack', 'CDAAttack', 'ContrastiveLoss']
|
attacks/AIM/src/gat/models/attack/aim_attack.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from .base_attack import BaseGenerativeAttack
|
| 4 |
+
from .generator.aim import AIMGenerator
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class AIMAttack(BaseGenerativeAttack):
    """Generative attack driven by the style-conditioned AIM generator.

    ``attack`` expects one extra input: a guidance image batch that the
    generator uses for style injection.
    """

    def set_adv_gen(self):
        """Instantiate the AIM generator on the configured device."""
        generator = AIMGenerator()
        self.adv_gen = generator.to(self.device)

    def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
        """Craft adversarial images for *x_nat*, guided by the first extra input."""
        guidance = extra_inputs[0]
        return self.adv_gen(x_nat, guidance.to(self.device))
|
attacks/AIM/src/gat/models/attack/base_attack.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BaseGenerativeAttack(abc.ABC):
    """Base class for attacks that craft adversarial examples with a generator.

    Subclasses implement :meth:`set_adv_gen` (build ``self.adv_gen``) and
    :meth:`attack` (produce raw adversarial images). Calling the instance
    runs the attack and projects the result into the L-inf epsilon ball
    around the natural input, clamped to the valid [0, 1] image range.
    """

    def __init__(self,
                 device: Union[str, torch.device],
                 epsilon: float = 32 / 255) -> None:
        """
        Args:
            device: Device the generator runs on (name or ``torch.device``).
            epsilon: L-inf perturbation budget, in [0, 1] image scale.
        """
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device
        self.set_adv_gen()
        self.set_mode('eval')
        self.epsilon = epsilon

    @abc.abstractmethod
    def set_adv_gen(self):
        """Create ``self.adv_gen`` (a ``torch.nn.Module``) on ``self.device``."""

    def load_ckpt(self, ckpt: Union[str, Path, OrderedDict]) -> None:
        """Load generator weights from a path or an in-memory state dict.

        Raises:
            FileNotFoundError: If a path is given and does not exist.
        """
        if isinstance(ckpt, str):
            ckpt = Path(ckpt)
        if isinstance(ckpt, Path):
            if not ckpt.exists():
                raise FileNotFoundError(f'File not found: {ckpt}')
            ckpt = torch.load(ckpt, map_location=self.device)
        self.adv_gen.load_state_dict(ckpt)
        self.adv_gen.to(self.device)

    def save_ckpt(self, ckpt: Union[str, Path]) -> None:
        """Save the generator's state dict to *ckpt* (weights saved on CPU)."""
        if isinstance(ckpt, str):
            ckpt = Path(ckpt)
        # ``Module.to`` moves the module IN PLACE, so move to CPU for a
        # portable checkpoint and move back afterwards. The original left
        # the generator stranded on the CPU after every save.
        self.adv_gen.to('cpu')
        torch.save(self.adv_gen.state_dict(), ckpt)
        self.adv_gen.to(self.device)

    def get_params(self) -> torch.nn.Parameter:
        """Return the generator's trainable parameters (for an optimizer)."""
        return self.adv_gen.parameters()

    def get_model(self) -> torch.nn.Module:
        """Return the underlying generator module."""
        return self.adv_gen

    def set_mode(self, mode: str) -> None:
        """Switch the generator between 'train' and 'eval' modes.

        Raises:
            ValueError: If *mode* is neither 'train' nor 'eval'.
        """
        # Raise instead of ``assert`` so validation survives ``python -O``.
        if mode not in ('train', 'eval'):
            raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")
        self.adv_gen.train() if mode == 'train' else self.adv_gen.eval()

    @abc.abstractmethod
    def attack(self, *args) -> torch.Tensor:
        """Produce raw (un-projected) adversarial images."""

    def __call__(self, x_nat: torch.Tensor, *extra_inputs) -> torch.Tensor:
        """Run the attack and project into the epsilon ball around *x_nat*."""
        x_adv = self.attack(x_nat, *extra_inputs)
        # Clip to the L-inf ball, then to the valid image range.
        x_adv = torch.min(torch.max(x_adv, x_nat - self.epsilon),
                          x_nat + self.epsilon)
        torch.clamp_(x_adv, 0.0, 1.0)
        return x_adv
|
attacks/AIM/src/gat/models/attack/cda_attack.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from .base_attack import BaseGenerativeAttack
|
| 4 |
+
from .generator.cda import CDAGenerator
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CDAAttack(BaseGenerativeAttack):
    """Generative attack using the unconditional CDA generator."""

    def set_adv_gen(self):
        """Build the CDA generator on the configured device."""
        self.adv_gen = CDAGenerator().to(self.device)

    def attack(self, x_nat, *extra_inputs) -> torch.Tensor:
        """Map natural images directly to adversarial images; extras are ignored."""
        x_adv = self.adv_gen(x_nat)
        return x_adv
|
attacks/AIM/src/gat/models/attack/generator/__init__.py
ADDED
|
File without changes
|
attacks/AIM/src/gat/models/attack/generator/aim.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class EnhancedBN(nn.Module):
    """BatchNorm followed by spatial style injection (SPADE-like).

    The style image is projected through a small conv head into per-pixel
    ``gamma``/``beta`` maps that modulate the normalized activations.
    """

    def __init__(self, nc: int, sty_nc: int = 3, sty_nhidden: int = 128):
        super(EnhancedBN, self).__init__()

        def _conv3x3(in_ch: int, out_ch: int) -> nn.Conv2d:
            # Shared 3x3 / stride-1 / pad-1 conv used by all three heads.
            return nn.Conv2d(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=3,
                padding=1,
                stride=1,
            )

        # Attribute names and creation order are preserved so state_dict
        # keys and parameter ordering stay identical.
        self.bn = nn.BatchNorm2d(nc)
        self.mapping = _conv3x3(sty_nc, sty_nhidden)
        self.gamma = _conv3x3(sty_nhidden, nc)
        self.beta = _conv3x3(sty_nhidden, nc)
        self.init_weight()

    def init_weight(self):
        """Kaiming-init the conv weights (biases keep their defaults)."""
        for conv in (self.mapping, self.gamma, self.beta):
            nn.init.kaiming_normal_(conv.weight)

    def forward(self, base, sty):
        bn = self.bn(base)
        # Match the style map to the feature resolution before projecting.
        sty_resized = torch.nn.functional.interpolate(
            sty, size=bn.size()[2:], mode='bilinear'
        )
        actv = torch.nn.functional.relu(self.mapping(sty_resized))
        # Style injection: scale and shift the normalized features.
        return bn * (1 + self.gamma(actv)) + self.beta(actv)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ResidualBlock(nn.Module):
    """Residual block with style-conditioned normalization.

    Each of the two conv paths is followed by an ``EnhancedBN`` so the
    style signal ``sty`` reaches every stage of the generator.
    """

    def __init__(self, num_filters):
        super(ResidualBlock, self).__init__()
        # Reflection-padded 3x3 conv; normalization is kept separate
        # (``bn1``) because EnhancedBN needs the style input at call time.
        self.block1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(
                in_channels=num_filters,
                out_channels=num_filters,
                kernel_size=3,
                stride=1,
                padding=0,
                bias=False,
            ),
        )
        self.bn1 = EnhancedBN(num_filters)
        self.block2 = nn.Sequential(
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.ReflectionPad2d(1),
            nn.Conv2d(
                in_channels=num_filters,
                out_channels=num_filters,
                kernel_size=3,
                stride=1,
                padding=0,
                bias=False,
            ),
        )
        self.bn2 = EnhancedBN(num_filters)

    def forward(self, x, sty):
        # Two conv + style-BN stages, then an additive skip connection.
        residual = self.block1(x)
        residual = self.bn1(residual, sty)
        residual = self.block2(residual)
        residual = self.bn2(residual, sty)
        return x + residual
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Base channel width of the generator.
ngf = 64


class ResNetGenerator(nn.Module):
    """Style-conditioned ResNet generator (AIM).

    Same topology as the stock CDA generator — encoder, six residual
    blocks, transposed-conv decoder — but every normalization layer is an
    ``EnhancedBN`` that injects the style input ``sty`` at each stage.
    Output is mapped into the [0, 1] image range via ``(tanh + 1) / 2``.
    """

    def __init__(self):
        super(ResNetGenerator, self).__init__()
        # NOTE: construction order fixes parameter iteration order (and
        # thus checkpoint/optimizer compatibility) — keep it stable.
        self.block1 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, ngf, kernel_size=7, padding=0, bias=False),
        )
        self.bn1 = EnhancedBN(ngf)
        # Input size = 3, n, n
        self.block2 = nn.Sequential(
            nn.Conv2d(
                ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False
            ),
        )
        self.bn2 = EnhancedBN(ngf * 2)
        # Input size = 3, n/2, n/2
        self.block3 = nn.Sequential(
            nn.Conv2d(
                ngf * 2,
                ngf * 4,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            ),
        )
        self.bn3 = EnhancedBN(ngf * 4)
        # Input size = 3, n/4, n/4
        # Residual Blocks: 6
        self.resblock1 = ResidualBlock(ngf * 4)
        self.resblock2 = ResidualBlock(ngf * 4)
        self.resblock3 = ResidualBlock(ngf * 4)
        self.resblock4 = ResidualBlock(ngf * 4)
        self.resblock5 = ResidualBlock(ngf * 4)
        self.resblock6 = ResidualBlock(ngf * 4)
        # Input size = 3, n/4, n/4
        self.upsampl1 = nn.ConvTranspose2d(
            ngf * 4,
            ngf * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            bias=False,
        )
        self.ubn1 = EnhancedBN(ngf * 2)
        # Input size = 3, n/2, n/2
        self.upsampl2 = nn.ConvTranspose2d(
            ngf * 2,
            ngf,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            bias=False,
        )
        self.ubn2 = EnhancedBN(ngf)
        # Input size = 3, n, n
        self.blockf = nn.Sequential(
            nn.ReflectionPad2d(3), nn.Conv2d(ngf, 3, kernel_size=7, padding=0)
        )

    def forward(self, input, sty):
        """Generate an image from *input*, style-conditioned on *sty*."""
        x = self.block1(input)
        x = self.bn1(x, sty)
        x = torch.nn.functional.relu(x)
        x = self.block2(x)
        x = self.bn2(x, sty)
        x = torch.nn.functional.relu(x)
        x = self.block3(x)
        x = self.bn3(x, sty)
        x = torch.nn.functional.relu(x)
        # =============================
        x = self.resblock1(x, sty)
        x = self.resblock2(x, sty)
        x = self.resblock3(x, sty)
        x = self.resblock4(x, sty)
        x = self.resblock5(x, sty)
        x = self.resblock6(x, sty)
        # =============================
        x = self.upsampl1(x)
        x = self.ubn1(x, sty)
        x = torch.nn.functional.relu(x)
        x = self.upsampl2(x)
        x = self.ubn2(x, sty)
        x = torch.nn.functional.relu(x)
        x = self.blockf(x)
        # Map tanh output (-1, 1) into the valid image range (0, 1).
        return (torch.tanh(x) + 1) / 2


# Public alias used by AIMAttack.
AIMGenerator = ResNetGenerator
|
attacks/AIM/src/gat/models/attack/generator/cda.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code is copied from
|
| 2 |
+
# github.com/Alibaba-AAIG/Beyond-ImageNet-Attack:generator.py@863b7
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ResidualBlock(nn.Module):
    """Standard residual block: two reflection-padded 3x3 convs with
    BatchNorm, ReLU and dropout, plus an additive skip connection."""

    def __init__(self, num_filters):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(
                in_channels=num_filters,
                out_channels=num_filters,
                kernel_size=3,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(num_filters),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.ReflectionPad2d(1),
            nn.Conv2d(
                in_channels=num_filters,
                out_channels=num_filters,
                kernel_size=3,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(num_filters),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Base channel width of the generator.
ngf = 64


class ResNetGenerator(nn.Module):
    """
    Unconditional ResNet generator: encoder (two stride-2 convs), six
    residual blocks, transposed-conv decoder; output is mapped into
    [0, 1] via ``(tanh + 1) / 2``.

    https://github.com/Alibaba-AAIG/Beyond-ImageNet-Attack/blob/863b758ee4f4a6d3d4e7777c5f94f457fa449f73/generator.py#L14

    Test Case:
    >>> netG = ResNetGenerator()
    >>> test_sample = torch.rand(1, 3, 224, 224)
    >>> print("Generator output:", netG(test_sample).size())
    >>> print(
    >>>     "Generator parameters:",
    >>>     sum(p.numel() for p in netG.parameters() if p.requires_grad),
    >>> )
    """

    def __init__(self, inception=False):
        super(ResNetGenerator, self).__init__()
        # When True, the output is trimmed by one row/column via
        # ``self.crop`` (for Inception-style odd input sizes).
        self.inception = inception
        self.block1 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, ngf, kernel_size=7, padding=0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
        )
        # output: (ngf) x (n) x (n)
        self.block2 = nn.Sequential(
            nn.Conv2d(
                ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
        )
        # output: (ngf*2) x (n/2) x (n/2)
        self.block3 = nn.Sequential(
            nn.Conv2d(
                ngf * 2,
                ngf * 4,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
        )
        # output: (ngf*4) x (n/4) x (n/4)
        self.resblock1 = ResidualBlock(ngf * 4)
        self.resblock2 = ResidualBlock(ngf * 4)
        self.resblock3 = ResidualBlock(ngf * 4)
        self.resblock4 = ResidualBlock(ngf * 4)
        self.resblock5 = ResidualBlock(ngf * 4)
        self.resblock6 = ResidualBlock(ngf * 4)
        # output: (ngf*4) x (n/4) x (n/4)
        self.upsampl1 = nn.Sequential(
            nn.ConvTranspose2d(
                ngf * 4,
                ngf * 2,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
        )
        # output: (ngf*2) x (n/2) x (n/2)
        self.upsampl2 = nn.Sequential(
            nn.ConvTranspose2d(
                ngf * 2,
                ngf,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
        )
        # output: (ngf) x (n) x (n)
        self.blockf = nn.Sequential(
            nn.ReflectionPad2d(3), nn.Conv2d(ngf, 3, kernel_size=7, padding=0)
        )
        # Negative padding trims one column (right) and one row (top).
        self.crop = nn.ConstantPad2d((0, -1, -1, 0), 0)

    def forward(self, input):
        x = self.block1(input)
        x = self.block2(x)
        x = self.block3(x)
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.resblock3(x)
        x = self.resblock4(x)
        x = self.resblock5(x)
        x = self.resblock6(x)
        x = self.upsampl1(x)
        x = self.upsampl2(x)
        x = self.blockf(x)
        if self.inception:
            x = self.crop(x)
        # Map tanh output (-1, 1) into the valid image range (0, 1).
        return (torch.tanh(x) + 1) / 2


# Public alias used by CDAAttack.
CDAGenerator = ResNetGenerator
|
attacks/AIM/src/gat/models/attack/loss/__init__.py
ADDED
|
File without changes
|
attacks/AIM/src/gat/models/attack/loss/logits.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss on L2-normalized embeddings.

    Adapted from: (OnlineContrastiveLoss)
    https://github.com/adambielski/siamese-triplet/blob/master/losses.py
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    Positive pairs are pulled together via squared distance; negative
    pairs are pushed apart until their distance exceeds ``margin``.
    """

    def __init__(self, margin):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, anchors, negatives, positives):
        # Project every embedding onto the unit sphere so distances are
        # comparable across batches.
        def _unit(t):
            return t / t.norm(dim=-1, keepdim=True)

        anchors = _unit(anchors)
        negatives = _unit(negatives)
        positives = _unit(positives)

        pos_term = (anchors - positives).pow(2).sum(1)
        neg_dist = (anchors - negatives).pow(2).sum(1).sqrt()
        neg_term = torch.nn.functional.relu(self.margin - neg_dist).pow(2)

        return (0.5 * torch.cat([pos_term, neg_term], dim=0)).mean()
|
attacks/AIM/src/gat/models/attack/optim/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .sam import SAM, disable_running_stats, enable_running_stats
|
| 2 |
+
|
| 3 |
+
__all__ = ['SAM', 'enable_running_stats', 'disable_running_stats']
|
attacks/AIM/src/gat/models/attack/optim/sam.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code is copied from
|
| 2 |
+
# github.com/davda54/sam/sam.py@3c3afdb
|
| 3 |
+
import torch
|
| 4 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    Copied from github.com/davda54/sam (sam.py@3c3afdb). ``step`` takes a
    closure that re-evaluates the loss, because the two-phase
    ascend-then-descend update needs a second forward/backward pass.
    """

    def __init__(self,
                 params,
                 base_optimizer,
                 rho=0.05,
                 adaptive=False,
                 **kwargs):
        # Args:
        #   params: parameters to optimize, as for any torch optimizer.
        #   base_optimizer: optimizer *class* (e.g. ``torch.optim.SGD``)
        #       that performs the actual descent step.
        #   rho: radius of the ascent (perturbation) step.
        #   adaptive: use Adaptive SAM (scale the ascent by |w| per element).
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        # Share param_groups with the base optimizer so hyper-parameter
        # edits on either object are visible to both.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # Ascend to the (approximate) worst-case point "w + e(w)" within
        # radius rho, saving the original weights for second_step.
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # 1e-12 guards against division by a zero gradient norm.
            scale = group['rho'] / (grad_norm + 1e-12)

            for p in group['params']:
                if p.grad is None:
                    continue
                self.state[p]['old_p'] = p.data.clone()
                e_w = (torch.pow(p, 2)
                       if group['adaptive'] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # Restore the original weights, then apply the base optimizer's
        # descent step using the gradients computed at "w + e(w)".
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.data = self.state[p][
                    'old_p']  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        # Full SAM step: ascend, re-run forward/backward via the closure,
        # then descend from the restored weights.
        assert closure is not None, \
            'Sharpness Aware Minimization requires closure, ' \
            'but it was not provided'
        closure = torch.enable_grad()(
            closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        # Global L2 norm over all gradients (element-wise |w|-scaled when
        # adaptive), gathered onto one device.
        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]['params'][0].device
        norm = torch.norm(torch.stack([
            ((torch.abs(p) if group['adaptive'] else 1.0) *
             p.grad).norm(p=2).to(shared_device) for group in self.param_groups
            for p in group['params'] if p.grad is not None
        ]),
                          p=2)
        return norm

    def load_state_dict(self, state_dict):
        # Re-link the base optimizer's param groups after loading, since
        # ``Optimizer.load_state_dict`` rebuilds ``self.param_groups``.
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def disable_running_stats(model):
    """Freeze BatchNorm running statistics across *model*.

    Stashes each batch-norm layer's momentum in ``backup_momentum`` and
    sets the live momentum to 0 so forward passes no longer update the
    running mean/var.  Undo with ``enable_running_stats``.
    """

    def _freeze_bn(module):
        # Only batch-norm layers carry running statistics.
        if isinstance(module, _BatchNorm):
            module.backup_momentum = module.momentum
            module.momentum = 0

    model.apply(_freeze_bn)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def enable_running_stats(model):
    """Restore BatchNorm momenta previously saved by disable_running_stats.

    Layers that were never frozen (no ``backup_momentum`` attribute) are
    left untouched.
    """

    def _restore_bn(module):
        has_backup = isinstance(module, _BatchNorm) and hasattr(
            module, 'backup_momentum')
        if has_backup:
            module.momentum = module.backup_momentum

    model.apply(_restore_bn)
|
attacks/AIM/src/gat/models/surrogate/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .builder import build_surrogate, list_surrogates
|
| 2 |
+
from .hooks import feat_col, midlayer_dict, register_collecter, register_collecter_cl
|
| 3 |
+
from .tv import (densenet121, densenet169, inception_v3, resnet50, resnet152,
|
| 4 |
+
swin_b, vgg16, vgg19, vit_b_16, vit_b_32)
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
'build_surrogate', 'list_surrogates', 'inception_v3', 'vgg16', 'vgg19',
|
| 8 |
+
'resnet50', 'resnet152', 'densenet121', 'densenet169', 'vit_b_16',
|
| 9 |
+
'vit_b_32', 'swin_b', 'midlayer_dict', 'register_collecter', 'feat_col', 'register_collecter_cl'
|
| 10 |
+
]
|
attacks/AIM/src/gat/models/surrogate/builder.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ...runtime.factory import Registry
|
| 2 |
+
|
| 3 |
+
SURROGATE_REGISTRY = Registry('SURROGATE')
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def build_surrogate(_type: str, *args, **kwargs) -> object:
    """Instantiate the surrogate model registered under *_type*.

    Extra positional/keyword arguments are forwarded to the registered
    factory.  Raises KeyError (via the registry) for unknown names.
    """
    factory = SURROGATE_REGISTRY.get(_type)
    return factory(*args, **kwargs)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def list_surrogates():
    """Return the names of all registered surrogate factories."""
    return [name for name in SURROGATE_REGISTRY._obj_map]
|
attacks/AIM/src/gat/models/surrogate/hooks.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import contextmanager
|
| 2 |
+
from typing import List, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
midlayer_dict = {
|
| 7 |
+
'vgg16': 'features.16',
|
| 8 |
+
'vgg19': 'features.18',
|
| 9 |
+
'resnet152': 'layer1',
|
| 10 |
+
'densenet169': 'features.denseblock2',
|
| 11 |
+
'inception_v3': 'Mixed_6c',
|
| 12 |
+
'resnet50': 'conv1', # conv1, layer1, layer2, layer3, layer4
|
| 13 |
+
'resnet50_cl': 'layer4',
|
| 14 |
+
'resnet32': 'conv_1_3x3', # conv_1_3x3, stage_1, stage_2, stage_3
|
| 15 |
+
'resnet32_cl': 'conv_1_3x3',
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def register_collecter(m: torch.nn.Module, layer: str, feat_collecter: List):
    """Hook *layer* of *m* so its forward outputs accumulate in
    *feat_collecter*.

    Returns the hook handle (call ``.remove()`` to detach) together with
    the same collector list.
    """

    def _collect(_module, _inputs, output):
        feat_collecter.append(output)

    target = m.get_submodule(layer)
    handle = target.register_forward_hook(_collect)
    return handle, feat_collecter
|
| 26 |
+
|
| 27 |
+
def register_collecter_cl(m: torch.nn.Module, layer: str,
                          feat_collecter: List, cl_methods: str):
    """Hook *layer* of the backbone inside a continual-learning model.

    Different CL methods wrap their backbone differently, so the hook
    target depends on *cl_methods*.

    Args:
        m: the continual-learning model wrapper.
        layer: dotted submodule path inside the backbone.
        feat_collecter: list that forward outputs are appended to.
        cl_methods: one of 'icarl', 'finetune', 'wa', 'replay', 'podnet',
            'bic', 'foster', 'der', 'memo'.

    Returns:
        (hook handle, feat_collecter).

    Raises:
        ValueError: for an unrecognized *cl_methods* (the original fell
            through and crashed with UnboundLocalError instead).
    """

    def _hook(_m, _i, o):
        feat_collecter.append(o)

    if cl_methods in ('icarl', 'finetune', 'wa', 'replay', 'podnet', 'bic'):
        # These methods expose a single backbone as ``model.convnet``.
        backbone = m.convnet
    elif cl_methods in ('foster', 'der'):
        # These keep a list of backbones; hook the first one.
        backbone = m.convnets[0]
    elif cl_methods == 'memo':
        backbone = m.TaskAgnosticExtractor
    else:
        raise ValueError(f'Unknown continual-learning method: {cl_methods!r}')

    _handler = backbone.get_submodule(layer).register_forward_hook(_hook)
    return _handler, feat_collecter
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@contextmanager
def feat_col(m: Union[torch.nn.Module, List[torch.nn.Module]],
             layer: Union[str, List[str]]):
    """Context manager that collects forward activations.

    Accepts a single module + layer name or parallel lists of both, and
    yields one shared list that receives every hooked layer's output.
    On exit the hooks are removed and the list is emptied.
    """
    models = [m] if isinstance(m, torch.nn.Module) else m
    layers = [layer] if isinstance(layer, str) else layer
    assert len(models) == len(layers)

    collected = []

    def _grab(_module, _inputs, output):
        collected.append(output)

    handles = [
        _model.get_submodule(_layer).register_forward_hook(_grab)
        for _model, _layer in zip(models, layers)
    ]
    yield collected
    # Detach all hooks and drop the collected tensors.
    for handle in handles:
        handle.remove()
    collected.clear()
|
attacks/AIM/src/gat/models/surrogate/tv.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision import models as tv_models
|
| 2 |
+
|
| 3 |
+
from .builder import SURROGATE_REGISTRY
|
| 4 |
+
|
| 5 |
+
weights_dict = dict(
|
| 6 |
+
alexnet=tv_models.AlexNet_Weights.IMAGENET1K_V1,
|
| 7 |
+
googlenet=tv_models.GoogLeNet_Weights.IMAGENET1K_V1,
|
| 8 |
+
vgg16=tv_models.VGG16_Weights.IMAGENET1K_V1,
|
| 9 |
+
vgg19=tv_models.VGG19_Weights.IMAGENET1K_V1,
|
| 10 |
+
inception_v3=tv_models.Inception_V3_Weights.IMAGENET1K_V1,
|
| 11 |
+
resnet18=tv_models.ResNet18_Weights.IMAGENET1K_V1,
|
| 12 |
+
resnet34=tv_models.ResNet34_Weights.IMAGENET1K_V1,
|
| 13 |
+
resnet50=tv_models.ResNet50_Weights.IMAGENET1K_V1,
|
| 14 |
+
resnet152=tv_models.ResNet152_Weights.IMAGENET1K_V1,
|
| 15 |
+
wide_resnet50_2=tv_models.Wide_ResNet50_2_Weights.IMAGENET1K_V1,
|
| 16 |
+
wide_resnet101_2=tv_models.Wide_ResNet101_2_Weights.IMAGENET1K_V1,
|
| 17 |
+
densenet121=tv_models.DenseNet121_Weights.IMAGENET1K_V1,
|
| 18 |
+
densenet169=tv_models.DenseNet169_Weights.IMAGENET1K_V1,
|
| 19 |
+
mobilenet_v2=tv_models.MobileNet_V2_Weights.IMAGENET1K_V1,
|
| 20 |
+
mobilenet_v3_small=tv_models.MobileNet_V3_Small_Weights.IMAGENET1K_V1,
|
| 21 |
+
mobilenet_v3_large=tv_models.MobileNet_V3_Large_Weights.IMAGENET1K_V1,
|
| 22 |
+
squeezenet1_0=tv_models.SqueezeNet1_0_Weights.IMAGENET1K_V1,
|
| 23 |
+
squeezenet1_1=tv_models.SqueezeNet1_1_Weights.IMAGENET1K_V1,
|
| 24 |
+
shufflenet_v2_x0_5=tv_models.ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
|
| 25 |
+
shufflenet_v2_x1_0=tv_models.ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1,
|
| 26 |
+
shufflenet_v2_x1_5=tv_models.ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1,
|
| 27 |
+
shufflenet_v2_x2_0=tv_models.ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1,
|
| 28 |
+
efficientnet_b0=tv_models.EfficientNet_B0_Weights.IMAGENET1K_V1,
|
| 29 |
+
efficientnet_v2_s=tv_models.EfficientNet_V2_S_Weights.IMAGENET1K_V1,
|
| 30 |
+
efficientnet_v2_m=tv_models.EfficientNet_V2_M_Weights.IMAGENET1K_V1,
|
| 31 |
+
efficientnet_v2_l=tv_models.EfficientNet_V2_L_Weights.IMAGENET1K_V1,
|
| 32 |
+
vit_b_16=tv_models.ViT_B_16_Weights.IMAGENET1K_V1,
|
| 33 |
+
vit_b_32=tv_models.ViT_B_32_Weights.IMAGENET1K_V1,
|
| 34 |
+
swin_b=tv_models.Swin_B_Weights.IMAGENET1K_V1)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@SURROGATE_REGISTRY.register()
def alexnet(pretrain=True):
    """AlexNet surrogate; ImageNet-1k V1 weights when *pretrain* is True."""
    weights = weights_dict['alexnet'] if pretrain else None
    return tv_models.alexnet(weights=weights)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@SURROGATE_REGISTRY.register()
|
| 44 |
+
def googlenet(pretrain=True):
|
| 45 |
+
return tv_models.googlenet(
|
| 46 |
+
weights=weights_dict['googlenet'] if pretrain else None)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@SURROGATE_REGISTRY.register()
|
| 50 |
+
def vgg16(pretrain=True):
|
| 51 |
+
return tv_models.vgg16(weights=weights_dict['vgg16'] if pretrain else None)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@SURROGATE_REGISTRY.register()
|
| 55 |
+
def vgg19(pretrain=True):
|
| 56 |
+
return tv_models.vgg19(weights=weights_dict['vgg19'] if pretrain else None)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@SURROGATE_REGISTRY.register()
|
| 60 |
+
def inception_v3(pretrain=True):
|
| 61 |
+
return tv_models.inception_v3(
|
| 62 |
+
weights=weights_dict['inception_v3'] if pretrain else None)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@SURROGATE_REGISTRY.register()
|
| 66 |
+
def resnet18(pretrain=True):
|
| 67 |
+
return tv_models.resnet18(
|
| 68 |
+
weights=weights_dict['resnet18'] if pretrain else None)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@SURROGATE_REGISTRY.register()
|
| 72 |
+
def resnet34(pretrain=True):
|
| 73 |
+
return tv_models.resnet34(
|
| 74 |
+
weights=weights_dict['resnet34'] if pretrain else None)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@SURROGATE_REGISTRY.register()
|
| 78 |
+
def resnet50(pretrain=True):
|
| 79 |
+
return tv_models.resnet50(
|
| 80 |
+
weights=weights_dict['resnet50'] if pretrain else None)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@SURROGATE_REGISTRY.register()
|
| 84 |
+
def resnet152(pretrain=True):
|
| 85 |
+
return tv_models.resnet152(
|
| 86 |
+
weights=weights_dict['resnet152'] if pretrain else None)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@SURROGATE_REGISTRY.register()
|
| 90 |
+
def wide_resnet50_2(pretrain=True):
|
| 91 |
+
return tv_models.wide_resnet50_2(
|
| 92 |
+
weights=weights_dict['wide_resnet50_2'] if pretrain else None)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@SURROGATE_REGISTRY.register()
|
| 96 |
+
def wide_resnet101_2(pretrain=True):
|
| 97 |
+
return tv_models.wide_resnet101_2(
|
| 98 |
+
weights=weights_dict['wide_resnet101_2'] if pretrain else None)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@SURROGATE_REGISTRY.register()
|
| 102 |
+
def densenet121(pretrain=True):
|
| 103 |
+
return tv_models.densenet121(
|
| 104 |
+
weights=weights_dict['densenet121'] if pretrain else None)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@SURROGATE_REGISTRY.register()
|
| 108 |
+
def densenet169(pretrain=True):
|
| 109 |
+
return tv_models.densenet169(
|
| 110 |
+
weights=weights_dict['densenet169'] if pretrain else None)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@SURROGATE_REGISTRY.register()
|
| 114 |
+
def mobilenet_v2(pretrain=True):
|
| 115 |
+
return tv_models.mobilenet_v2(
|
| 116 |
+
weights=weights_dict['mobilenet_v2'] if pretrain else None)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@SURROGATE_REGISTRY.register()
|
| 120 |
+
def mobilenet_v3_small(pretrain=True):
|
| 121 |
+
return tv_models.mobilenet_v3_small(
|
| 122 |
+
weights=weights_dict['mobilenet_v3_small'] if pretrain else None)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@SURROGATE_REGISTRY.register()
|
| 126 |
+
def mobilenet_v3_large(pretrain=True):
|
| 127 |
+
return tv_models.mobilenet_v3_large(
|
| 128 |
+
weights=weights_dict['mobilenet_v3_large'] if pretrain else None)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@SURROGATE_REGISTRY.register()
|
| 132 |
+
def squeezenet1_0(pretrain=True):
|
| 133 |
+
return tv_models.squeezenet1_0(
|
| 134 |
+
weights=weights_dict['squeezenet1_0'] if pretrain else None)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@SURROGATE_REGISTRY.register()
|
| 138 |
+
def squeezenet1_1(pretrain=True):
|
| 139 |
+
return tv_models.squeezenet1_1(
|
| 140 |
+
weights=weights_dict['squeezenet1_1'] if pretrain else None)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@SURROGATE_REGISTRY.register()
|
| 144 |
+
def shufflenet_v2_x0_5(pretrain=True):
|
| 145 |
+
return tv_models.shufflenet_v2_x0_5(
|
| 146 |
+
weights=weights_dict['shufflenet_v2_x0_5'] if pretrain else None)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@SURROGATE_REGISTRY.register()
|
| 150 |
+
def shufflenet_v2_x1_0(pretrain=True):
|
| 151 |
+
return tv_models.shufflenet_v2_x1_0(
|
| 152 |
+
weights=weights_dict['shufflenet_v2_x1_0'] if pretrain else None)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@SURROGATE_REGISTRY.register()
|
| 156 |
+
def shufflenet_v2_x1_5(pretrain=True):
|
| 157 |
+
return tv_models.shufflenet_v2_x1_5(
|
| 158 |
+
weights=weights_dict['shufflenet_v2_x1_5'] if pretrain else None)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@SURROGATE_REGISTRY.register()
|
| 162 |
+
def shufflenet_v2_x2_0(pretrain=True):
|
| 163 |
+
return tv_models.shufflenet_v2_x2_0(
|
| 164 |
+
weights=weights_dict['shufflenet_v2_x2_0'] if pretrain else None)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@SURROGATE_REGISTRY.register()
|
| 168 |
+
def efficientnet_b0(pretrain=True):
|
| 169 |
+
return tv_models.efficientnet_b0(
|
| 170 |
+
weights=weights_dict['efficientnet_b0'] if pretrain else None)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
@SURROGATE_REGISTRY.register()
|
| 174 |
+
def efficientnet_v2_s(pretrain=True):
|
| 175 |
+
return tv_models.efficientnet_v2_s(
|
| 176 |
+
weights=weights_dict['efficientnet_v2_s'] if pretrain else None)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@SURROGATE_REGISTRY.register()
|
| 180 |
+
def efficientnet_v2_m(pretrain=True):
|
| 181 |
+
return tv_models.efficientnet_v2_m(
|
| 182 |
+
weights=weights_dict['efficientnet_v2_m'] if pretrain else None)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@SURROGATE_REGISTRY.register()
|
| 186 |
+
def efficientnet_v2_l(pretrain=True):
|
| 187 |
+
return tv_models.efficientnet_v2_l(
|
| 188 |
+
weights=weights_dict['efficientnet_v2_l'] if pretrain else None)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
@SURROGATE_REGISTRY.register()
|
| 192 |
+
def vit_b_16(pretrain=True):
|
| 193 |
+
return tv_models.vit_b_16(
|
| 194 |
+
weights=weights_dict['vit_b_16'] if pretrain else None)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@SURROGATE_REGISTRY.register()
|
| 198 |
+
def vit_b_32(pretrain=True):
|
| 199 |
+
return tv_models.vit_b_32(
|
| 200 |
+
weights=weights_dict['vit_b_32'] if pretrain else None)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@SURROGATE_REGISTRY.register()
def swin_b(pretrain=True):
    """Swin-B surrogate; ImageNet-1k V1 weights when *pretrain* is True."""
    weights = weights_dict['swin_b'] if pretrain else None
    return tv_models.swin_b(weights=weights)
|
attacks/AIM/src/gat/runtime/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .meter import AverageMeter, calc_cls_accuracy
|
| 2 |
+
from .utils import fix_random, randid
|
| 3 |
+
|
| 4 |
+
__all__ = ['fix_random', 'randid', 'AverageMeter', 'calc_cls_accuracy']
|
attacks/AIM/src/gat/runtime/api/__init__.py
ADDED
|
File without changes
|
attacks/AIM/src/gat/runtime/api/aim_attack.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import io
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Union
|
| 6 |
+
|
| 7 |
+
import torchvision
|
| 8 |
+
import uvicorn
|
| 9 |
+
from fastapi import FastAPI, File, UploadFile
|
| 10 |
+
from fastapi.responses import StreamingResponse
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
from ...datasets.transforms import resize_256_224, to_color, to_pil, to_ts
|
| 14 |
+
from ...models.attack import AIMAttack
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def init_attack(ckpt: Union[str, Path] = None):
    """Create a CPU AIMAttack in eval mode plus its input preprocessing.

    Args:
        ckpt: optional checkpoint path loaded into the attack.

    Returns:
        (attack, attack_preproc): the attack object and a transform
        pipeline built from resize_256_224 + to_color + to_ts —
        presumably resize/crop to 224, force 3-channel, tensor
        conversion; confirm against gat.datasets.transforms.
    """
    attack = AIMAttack(device='cpu')
    if ckpt:
        attack.load_ckpt(ckpt)
    attack.set_mode('eval')
    attack_preproc = torchvision.transforms.Compose(resize_256_224() +
                                                    to_color() + to_ts())
    return attack, attack_preproc
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
    """Serve the AIM attack as a FastAPI endpoint.

    The endpoint accepts two uploaded images (the natural image and a
    guide image), runs the attack, and streams the adversarial result
    back as a PNG.
    """
    parser = argparse.ArgumentParser(description='AIM Attack API')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='host')
    parser.add_argument('--port', type=int, default=8000, help='port')
    parser.add_argument('--ckpt',
                        type=str,
                        default=None,
                        help='path to the checkpoint')
    args = parser.parse_args()

    attack, attack_preproc = init_attack(args.ckpt)
    app = FastAPI()

    # NOTE(review): file uploads on a GET route are unconventional —
    # most HTTP clients cannot send a multipart body with GET; consider
    # @app.post instead.
    @app.get('/attack/aim/')
    async def aim_attack(x_nat: UploadFile = File(...),
                         x_guid: UploadFile = File(...)):
        io_x_nat = await x_nat.read()
        io_x_guid = await x_guid.read()
        pil_x_nat = Image.open(io.BytesIO(io_x_nat))
        pil_x_guid = Image.open(io.BytesIO(io_x_guid))
        # unsqueeze(0) adds the batch dimension the attack expects.
        ts_x_nat = attack_preproc(pil_x_nat).unsqueeze(0)
        ts_x_guid = attack_preproc(pil_x_guid).unsqueeze(0)
        ts_x_adv = attack(ts_x_nat, ts_x_guid)
        pil_x_adv = torchvision.transforms.Compose(to_pil())(ts_x_adv[0])

        # Encode the adversarial image in-memory and stream it back.
        img_byte_array = io.BytesIO()
        pil_x_adv.save(img_byte_array, format='PNG')
        img_byte_array.seek(0)

        return StreamingResponse(img_byte_array, media_type='image/png')

    uvicorn.run(app, host=args.host, port=args.port)
|
attacks/AIM/src/gat/runtime/factory.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code is copied from
|
| 2 |
+
# github.com/facebookresearch/fvcore:common/registry.py@242366
|
| 3 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 4 |
+
|
| 5 |
+
# pyre-strict
|
| 6 |
+
# pyre-ignore-all-errors[2,3]
|
| 7 |
+
from typing import Any, Dict, Iterable, Iterator, Tuple
|
| 8 |
+
|
| 9 |
+
from tabulate import tabulate
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Registry(Iterable[Tuple[str, Any]]):
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name (str): the name of this registry
        """
        self._name: str = name
        self._obj_map: Dict[str, Any] = {}

    def _do_register(self, name: str, obj: Any) -> None:
        # Duplicate registrations indicate a programming error; fail loudly.
        assert (
            name not in self._obj_map
        ), "An object named '{}' was already registered in '{}' registry!".format(  # noqa: E501
            name, self._name
        )
        self._obj_map[name] = obj

    def register(self, obj: Any = None) -> Any:
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not. See docstring of this class
        for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class: Any) -> Any:
                name = func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj)
        # Fix: return the object so a direct call can be chained or used in
        # an assignment (the original returned None on this path).
        return obj

    def get(self, name: str) -> Any:
        """Return the object registered under *name*; raise KeyError if absent."""
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(
                "No object named '{}' found in '{}' registry!".format(
                    name, self._name
                )
            )
        return ret

    def __contains__(self, name: str) -> bool:
        return name in self._obj_map

    def __repr__(self) -> str:
        table_headers = ['Names', 'Objects']
        table = tabulate(
            self._obj_map.items(), headers=table_headers, tablefmt='fancy_grid'
        )
        return 'Registry of {}:\n'.format(self._name) + table

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        return iter(self._obj_map.items())

    # pyre-fixme[4]: Attribute must be annotated.
    __str__ = __repr__
|
attacks/AIM/src/gat/runtime/meter.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class AverageMeter:
    """Running average tracker for scalar metrics.

    Attributes:
        val:   most recently recorded value.
        avg:   weighted running mean of all values since reset().
        sum:   weighted sum of values.
        count: total weight seen.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record *val* observed with weight *n* and refresh the mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def calc_cls_accuracy(output, target, topk=(1,)):
    """Top-k classification accuracy in percent.

    Args:
        output: (batch, num_classes) logits/scores tensor.
        target: (batch,) ground-truth class indices.
        topk: tuple of k values; each k is capped at num_classes.

    Returns:
        List of scalar tensors, one accuracy percentage per k in *topk*.
    """
    num_classes = output.size()[1]
    maxk = min(max(topk), num_classes)
    batch_size = target.size(0)

    # (maxk, batch) matrix of hit/miss flags against the top predictions.
    pred = output.topk(maxk, 1, True, True)[1].t()
    hits = pred.eq(target.reshape(1, -1).expand_as(pred))

    accs = []
    for k in topk:
        n_correct = hits[:min(k, maxk)].reshape(-1).float().sum(0)
        accs.append(n_correct * 100.0 / batch_size)
    return accs
|
attacks/AIM/src/gat/runtime/utils.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def fix_random(seed: int = 0) -> None:
    """Seed all RNGs used in this package and force deterministic cuDNN.

    Fix: also seeds the stdlib ``random`` module — the original omitted
    it even though ``randid`` in this file draws from it, so runs were
    not fully reproducible.  The CUDA seeding calls are safe no-ops on
    CPU-only machines.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels at the cost of autotuned performance.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def randid(k: int = 4):
    """Return a random alphanumeric identifier of length *k*.

    Uses the global ``random`` RNG (not ``secrets``) — fine for working
    directory names, not for security tokens.
    """
    alphabet = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
    return ''.join(random.choices(alphabet, k=k))
|
attacks/AIM/tests/__init__.py
ADDED
|
File without changes
|
attacks/AIM/tests/test_datasets/__init__.py
ADDED
|
File without changes
|
attacks/AIM/tests/test_datasets/test_datasets.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import unittest
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from gat.datasets import build_dataset
|
| 6 |
+
|
| 7 |
+
# ImageNet-1k location: override with the DATA_ROOT environment variable,
# otherwise default to <repo>/data/in_1k relative to this test file.
# (Removed a dead, machine-specific hard-coded path that this assignment
# immediately overwrote.)
in1k_data_root = os.environ.get('DATA_ROOT',
                                Path(__file__).parents[2] / 'data' / 'in_1k')
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestImageNet(unittest.TestCase):
    """Smoke tests for the 'imagenet' dataset builder.

    Requires the ImageNet-1k data to be present at *in1k_data_root*.
    """

    def test_in1k(self):
        # 1,281,167 is the official ImageNet-1k train-split size.
        ds = build_dataset(
            'imagenet',
            data_root=in1k_data_root,
            is_train=True,
        )
        self.assertEqual(len(ds), 1281167)

    def test_in1k_filter(self):
        # Restricting the val split to a single class yields 50 images.
        ds = build_dataset(
            'imagenet',
            data_root=in1k_data_root,
            is_train=False,
            filter_class=0,
        )
        self.assertEqual(len(ds), 50)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if __name__ == '__main__':
|
| 33 |
+
unittest.main()
|
attacks/AIM/tests/test_datasets/test_transforms.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from torchvision.transforms import Compose
|
| 6 |
+
|
| 7 |
+
from gat.datasets.transforms import (hflip, norm, resize_224, resize_256_224,
|
| 8 |
+
resize_512_448, to_color, to_pil, to_ts)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TestResize(unittest.TestCase):
    """Output-shape checks for the resize transform factories."""

    def test_resize_256_224(self):
        inputs = torch.rand(3, 256, 256)
        outputs = Compose(resize_256_224())(inputs)
        self.assertEqual(outputs.shape, (3, 224, 224))

    def test_resize_512_448(self):
        inputs = torch.rand(3, 512, 512)
        outputs = Compose(resize_512_448())(inputs)
        self.assertEqual(outputs.shape, (3, 448, 448))

    def test_resize_224(self):
        # Already-224 input keeps its shape.
        inputs = torch.rand(3, 224, 224)
        outputs = Compose(resize_224())(inputs)
        self.assertEqual(outputs.shape, (3, 224, 224))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TestAug(unittest.TestCase):
    """Parameter validation for the horizontal-flip factory."""

    def test_probability_invalid(self):
        # Probabilities outside [0, 1] must be rejected.
        with self.assertRaises(AssertionError):
            hflip(-0.1)

    def test_probability_valid(self):
        for valid_p in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
            result = hflip(valid_p)
            # hflip returns a list; its first transform carries p.
            self.assertEqual(result[0].p, valid_p)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TestTypeTransforms(unittest.TestCase):
    """Checks for tensor/PIL and colour-mode conversions."""

    def test_to_ts(self):
        pil_image = Image.new('RGB', (224, 224))
        outputs = Compose(to_ts())(pil_image)
        self.assertIsInstance(outputs, torch.Tensor)

    def test_to_pil(self):
        inputs = torch.rand(3, 224, 224)
        outputs = Compose(to_pil())(inputs)
        self.assertIsInstance(outputs, Image.Image)

    def test_to_color_grayscale(self):
        # Grayscale ('L') input must be promoted to RGB.
        inputs = Image.new('L', (224, 224))
        outputs = Compose(to_color())(inputs)
        self.assertEqual(outputs.mode, 'RGB')

    def test_to_color_rgb(self):
        # RGB input passes through still as RGB.
        inputs = Image.new('RGB', (224, 224))
        outputs = Compose(to_color())(inputs)
        self.assertEqual(outputs.mode, 'RGB')
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class TestNorm(unittest.TestCase):
    """Checks the normalisation factory's mean/std constants."""

    def test_default(self):
        # Default dataset is ImageNet.
        self.assertEqual(norm()[0].mean, (0.485, 0.456, 0.406))
        self.assertEqual(norm()[0].std, (0.229, 0.224, 0.225))

    def test_imagenet(self):
        self.assertEqual(norm('IMAGENET')[0].mean, (0.485, 0.456, 0.406))
        self.assertEqual(norm('IMAGENET')[0].std, (0.229, 0.224, 0.225))

    def test_invalid_ds(self):
        # Unknown dataset names raise AttributeError — presumably the
        # name is resolved via attribute lookup; confirm in transforms.py.
        with self.assertRaises(AttributeError):
            norm('imagenets')
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
if __name__ == '__main__':
|
| 80 |
+
unittest.main()
|
attacks/AIM/tests/test_models/__init__.py
ADDED
|
File without changes
|
attacks/AIM/tests/test_models/test_attack.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from gat.models.attack.aim_attack import AIMAttack
|
| 6 |
+
from gat.models.attack.cda_attack import CDAAttack
|
| 7 |
+
from gat.models.attack.generator.aim import AIMGenerator
|
| 8 |
+
from gat.models.attack.generator.cda import CDAGenerator
|
| 9 |
+
from gat.models.attack.loss.logits import ContrastiveLoss
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestGenerator(unittest.TestCase):
    """Output-shape smoke tests for the attack generators."""

    @torch.no_grad()
    def test_cda(self):
        generator = CDAGenerator()
        inputs = torch.rand(1, 3, 224, 224)
        outputs = generator(inputs)
        self.assertEqual(outputs.shape, (1, 3, 224, 224))

    @torch.no_grad()
    def test_aim(self):
        # AIM takes a (natural image, guide image) pair.
        generator = AIMGenerator()
        inputs = (torch.rand(1, 3, 224, 224), torch.rand(1, 3, 224, 224))
        outputs = generator(*inputs)
        self.assertEqual(outputs.shape, (1, 3, 224, 224))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TestContrastiveLoss(unittest.TestCase):
    """Shape and value checks for ContrastiveLoss with margin 1.0."""

    def setUp(self):
        self.margin = 1.0
        self.loss_fn = ContrastiveLoss(margin=self.margin)

    @torch.no_grad()
    def test_forward_shape(self):
        # The loss must reduce to a scalar tensor.
        anchors = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
        positives = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
        negatives = torch.tensor([[0.0, 0.0], [1.0, 1.0]])

        loss = self.loss_fn(anchors, negatives, positives)
        self.assertEqual(loss.shape, torch.Size([]))

    @torch.no_grad()
    def test_loss_outputs(self):
        # Identical anchor/positive pairs; negatives at distance 1.
        anchors = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
        positives = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
        negatives = torch.tensor([[2.0, 0.0], [0.0, 2.0]])

        loss = self.loss_fn(anchors, negatives, positives)
        self.assertAlmostEqual(loss.item(), 0.25)

    @torch.no_grad()
    def test_non_zero_loss(self):
        # Swapped positives increase the anchor-positive distance.
        anchors = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
        positives = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
        negatives = torch.tensor([[2.0, 0.0], [0.0, 2.0]])

        loss = self.loss_fn(anchors, negatives, positives)
        self.assertAlmostEqual(loss.item(), 0.75)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class TestCDAAttack(unittest.TestCase):
    """CDAAttack should emit valid images inside an L-inf ball around the input."""

    def setUp(self):
        self.device = torch.device('cpu')
        # standard 16/255 L-inf budget
        self.epsilon = 16. / 255.
        self.attack = CDAAttack(self.device, self.epsilon)

    @torch.no_grad()
    def test_outputs_shape(self):
        """Adversarial output keeps the clean input's shape."""
        clean = torch.rand(1, 3, 224, 224)
        adv = self.attack(clean)
        self.assertEqual(adv.shape, (1, 3, 224, 224))

    @torch.no_grad()
    def test_outputs_bound(self):
        """Output stays in [0, 1] and within epsilon of the clean input."""
        clean = torch.rand(1, 3, 224, 224)
        adv = self.attack(clean)
        self.assertGreaterEqual(adv.min().item(), 0.0)
        self.assertLessEqual(adv.max().item(), 1.0)
        self.assertLessEqual((adv - clean).abs().max().item(), self.epsilon)
| 84 |
+
|
| 85 |
+
class TestAIMAttack(unittest.TestCase):
    """AIMAttack (guided variant) should emit valid, epsilon-bounded images."""

    def setUp(self):
        self.device = torch.device('cpu')
        # standard 16/255 L-inf budget
        self.epsilon = 16. / 255.
        self.attack = AIMAttack(self.device, self.epsilon)

    @torch.no_grad()
    def test_outputs_shape(self):
        """Adversarial output keeps the clean input's shape."""
        clean = torch.rand(1, 3, 224, 224)
        guide = torch.rand(1, 3, 224, 224)
        adv = self.attack(clean, guide)
        self.assertEqual(adv.shape, (1, 3, 224, 224))

    @torch.no_grad()
    def test_outputs_bound(self):
        """Output stays in [0, 1] and within epsilon of the clean input."""
        clean = torch.rand(1, 3, 224, 224)
        guide = torch.rand(1, 3, 224, 224)
        adv = self.attack(clean, guide)
        self.assertGreaterEqual(adv.min().item(), 0.0)
        self.assertLessEqual(adv.max().item(), 1.0)
        self.assertLessEqual((adv - clean).abs().max().item(), self.epsilon)
| 108 |
+
|
| 109 |
+
# Allow running this test module directly (e.g. `python test_attack.py`).
if __name__ == '__main__':
    unittest.main()
|
attacks/AIM/tests/test_models/test_surrogate.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from gat.models.surrogate import build_surrogate, list_surrogates
|
| 6 |
+
from gat.models.surrogate.hooks import feat_col
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TestTVModels(unittest.TestCase):
    """Every registered surrogate should emit a (1, 1000) ImageNet logit batch."""

    @torch.no_grad()
    def test_outputs_shape(self):
        batch = torch.rand(1, 3, 224, 224)
        # inception_v3 is excluded — presumably it needs a different input
        # resolution than 224x224; confirm against the surrogate builder.
        skipped = {'inception_v3'}
        for name in list_surrogates():
            if name in skipped:
                continue
            logits = build_surrogate(name, pretrain=False)(batch)
            self.assertEqual(logits.shape, (1, 1000))
| 19 |
+
|
| 20 |
+
class TestFeatCol(unittest.TestCase):
    """feat_col should record one intermediate feature map per forward pass."""

    # (surrogate id, layer to hook, input shape, expected feature-map shape)
    CASES = [
        ('vgg16', 'features.16', (1, 3, 224, 224), (1, 256, 28, 28)),
        ('vgg19', 'features.18', (1, 3, 224, 224), (1, 256, 28, 28)),
        ('resnet152', 'layer2', (1, 3, 224, 224), (1, 512, 28, 28)),
        ('densenet169', 'features.denseblock2', (1, 3, 224, 224),
         (1, 512, 28, 28)),
    ]

    @torch.no_grad()
    def test_feat_col(self):
        for name, layer, in_shape, feat_shape in self.CASES:
            model = build_surrogate(name, pretrain=False)
            with feat_col(model, layer) as collected:
                batch = torch.rand(in_shape)
                # two forwards -> two captured feature maps
                model(batch)
                model(batch)
                self.assertEqual(len(collected), 2)
                self.assertEqual(collected.pop().shape, feat_shape)
                self.assertEqual(collected.pop().shape, feat_shape)
| 57 |
+
|
| 58 |
+
# Allow running this test module directly (e.g. `python test_surrogate.py`).
if __name__ == '__main__':
    unittest.main()
|
attacks/AIM/tests/test_runtime/__init__.py
ADDED
|
File without changes
|