|
|
import os
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
import matplotlib.pyplot as plt
|
|
|
import pandas as pd
|
|
|
from pytorch_grad_cam.grad_cam import GradCAM
|
|
|
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
|
|
from pytorch_grad_cam.utils.image import show_cam_on_image
|
|
|
from torchvision import transforms
|
|
|
import cv2
|
|
|
import torchvision
|
|
|
from pathlib import Path
|
|
|
from PIL import Image
|
|
|
from torchvision.utils import save_image
|
|
|
|
|
|
|
|
|
def save_images(x, x_hat, render_num=64, output_dir='rendered_images', step=0, test=False):
    """Save a grid interleaving original images with their reconstructions.

    Tiles are laid out as ``x[0], x_hat[0], x[1], x_hat[1], ...``. At most
    ``render_num // 2`` pairs are used, or fewer if either batch is shorter.
    The grid is written to ``<output_dir>/<step>.png``.

    Args:
        x: Indexable batch of image tensors, each (C, H, W) as accepted by
            ``torchvision.utils.make_grid``.
        x_hat: Batch paired element-wise with ``x``.
        render_num: Total number of tiles (originals + reconstructions).
        output_dir: Output directory; created if it does not exist.
        step: Used as the output file stem.
        test: If True the grid is saved unchanged. Otherwise it is mapped with
            ``grid * 0.5 + 0.5`` (i.e. [-1, 1] -> [0, 1]) before saving.

    Raises:
        ValueError: If there is no (x, x_hat) pair to render.
    """
    num_pairs = min(int(render_num) // 2, len(x), len(x_hat))
    if num_pairs <= 0:
        raise ValueError(
            f"save_images: nothing to render (render_num={render_num}, "
            f"len(x)={len(x)}, len(x_hat)={len(x_hat)})"
        )
    os.makedirs(output_dir, exist_ok=True)

    # Tiles per row. The count is even so each (x, x_hat) pair shares a row.
    # It is at least 2 because make_grid divides by nrow, and the bare formula
    # gives 0 when render_num < 4.
    nrow = max(2, int(render_num ** 0.5 / 2) * 2)

    tiles = []
    for i in range(num_pairs):
        tiles.extend((x[i], x_hat[i]))

    grid = torchvision.utils.make_grid(tiles, nrow=nrow, padding=2)
    if not test:
        grid = grid * 0.5 + 0.5
    # `grid` is already a single image, so save_image needs no nrow of its own.
    torchvision.utils.save_image(grid, os.path.join(output_dir, f'{step}.png'))
|
|
|
|
|
|
def _load_class_dict(path):
    """Parse a comma-separated class-map file into ``{class_index: class_name}``.

    Each non-blank line is split on commas. Field 1 is parsed as the integer
    class index and field 2 is used as the class name. Any further fields are
    ignored, and blank lines are skipped.
    """
    label_dict = {}
    with open(path, 'r', encoding='utf-8') as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                # Tolerate blank or trailing empty lines instead of raising IndexError.
                continue
            parts = line.split(',')
            label_dict[int(parts[1])] = parts[2]
    return label_dict


def load_imagenet_class_dict(path='attacks/UnivIntruder/utils_/map_clsloc.txt'):
    """Return ``{class_index: class_name}`` for ImageNet.

    Args:
        path: Class-map file. The default is resolved relative to the current
            working directory.
    """
    return _load_class_dict(path)


def load_imagenet100_class_dict(path='attacks/SAE/map_clsloc_imagenet100.txt'):
    """Return ``{class_index: class_name}`` for the ImageNet-100 subset.

    Args:
        path: Class-map file. The default is resolved relative to the current
            working directory.
    """
    return _load_class_dict(path)
|
|
|
|
|
|
|
|
|
# Class-index -> class-name lookups, built once when this module is imported.
# The ImageNet-100 map is loaded first, then the full ImageNet map. By default
# both loaders read their class-map files from paths relative to the current
# working directory, so importing this module from another directory may fail.
imagenet100_classes_dict, imagenet_classes_dict = (
    load_imagenet100_class_dict(),
    load_imagenet_class_dict(),
)
|
|
|
# Maps each CIFAR-100 class index (0-99) to its class name. The dict is built
# with enumerate, so keys follow list order. The same names, in the same
# order, are returned by load_cifar100_classes().
cifar100_classes_dict = dict(enumerate([
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin',
    'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo',
    'computer_keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster',
    'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree',
    'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
    'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray',
    'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper',
    'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
    'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
    'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
    'worm',
]))
|
|
|
|
|
|
def load_cifar100_classes():
    """Return a new list of the 100 CIFAR-100 class names in label order.

    The order matches ``cifar100_classes_dict``. None of the names contains
    whitespace, so they are stored as one whitespace-separated block and split
    on every call. Each caller therefore gets a fresh list it may mutate.
    """
    names = """
        apple aquarium_fish baby bear beaver bed bee beetle
        bicycle bottle bowl boy bridge bus butterfly camel
        can castle caterpillar cattle chair chimpanzee clock
        cloud cockroach couch crab crocodile cup dinosaur dolphin
        elephant flatfish forest fox girl hamster house kangaroo
        computer_keyboard lamp lawn_mower leopard lion lizard lobster
        man maple_tree motorcycle mountain mouse mushroom oak_tree
        orange orchid otter palm_tree pear pickup_truck pine_tree
        plain plate poppy porcupine possum rabbit raccoon ray
        road rocket rose sea seal shark shrew skunk skyscraper
        snail snake spider squirrel streetcar sunflower sweet_pepper
        table tank telephone television tiger tractor train
        trout tulip turtle wardrobe whale willow_tree wolf woman
        worm
    """
    return names.split()
|
|
|
|
|
|
|
|
|
def plot_asr_per_target(asr_matrix, save_path, prefix, args, acc_metric=None):
    """Plot per-target attack success rate (ASR) across tasks and save it as a PNG.

    The figure has two panels. The top panel draws one line per target image
    (one column of ``asr_matrix``) across tasks, optionally overlaid with a
    clean-accuracy curve. The bottom panel is a bar chart of each target
    image's ASR averaged over tasks. Both y-axes are fixed to [0, 1].

    Args:
        asr_matrix: 2-D array-like of shape (num_tasks, num_targets).
        save_path: Output directory; created if it does not exist.
        prefix: File stem; the figure is written to ``<save_path>/<prefix>.png``.
        args: Mapping whose ``"target_class"`` entry is shown in the titles.
        acc_metric: Optional per-task values (one per row of ``asr_matrix``),
            drawn as a dashed red "Clean Accuracy" line.

    Raises:
        ValueError: If ``asr_matrix`` is not 2-D.
    """
    asr_matrix = np.asarray(asr_matrix)
    if asr_matrix.ndim != 2:
        raise ValueError(
            "asr_matrix must be 2-D (num_tasks, num_targets), "
            f"got shape {asr_matrix.shape}"
        )
    num_tasks, num_targets = asr_matrix.shape
    tasks = np.arange(num_tasks)
    indices = np.arange(num_targets)
    avg_asr = asr_matrix.mean(axis=0)
    target_class = args["target_class"]

    fig, (ax_line, ax_bar) = plt.subplots(
        2, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [3, 1]}
    )
    try:
        for i in range(num_targets):
            ax_line.plot(tasks, asr_matrix[:, i], marker='o', label=f'Target Image {i} (ASR)')
        if acc_metric is not None:
            ax_line.plot(tasks, acc_metric, marker='x', label='Clean Accuracy',
                         color='red', linestyle='--')

        ax_line.set_xlabel('Task')
        ax_line.set_ylabel('Attack Success Rate (ASR)' if acc_metric is None else 'ASR / Accuracy')
        ax_line.set_title(f'ASR of each Target Image (Target Class: {target_class}) across Tasks')
        ax_line.set_xticks(tasks)
        ax_line.set_ylim(0, 1)
        ax_line.grid(True)
        # Legend sits outside the axes on the right; tight_layout below
        # reserves the rightmost 15% of the figure for it.
        ax_line.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize='small', ncol=1)

        ax_bar.bar(indices, avg_asr, color='skyblue')
        ax_bar.set_xlabel('Target Image')
        ax_bar.set_ylabel('Average ASR')
        ax_bar.set_title(f'Average ASR per Target Image (Target Class: {target_class})')
        ax_bar.set_xticks(indices)
        ax_bar.set_xticklabels([f'{i}' for i in range(num_targets)], rotation=45, fontsize='small')
        ax_bar.set_ylim(0, 1)
        ax_bar.grid(axis='y')

        fig.tight_layout(rect=[0, 0, 0.85, 1])
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, f"{prefix}.png"), bbox_inches='tight')
    finally:
        # Close this specific figure even if drawing or saving raised, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
def save_batch_images(batch_imgs, logs_eval_name, filename=None, prefix="adv", save_num=2):
    """Write up to ``save_num`` images from a batch as individual PNG files.

    Files are named ``<prefix>_<index>.png``. They are written to
    ``logs_eval_name``, or to the ``logs_eval_name/<filename>`` subdirectory
    when ``filename`` is given. The directory is created if it does not exist.
    """
    path_parts = (logs_eval_name,) if filename is None else (logs_eval_name, f'{filename}')
    out_dir = os.path.join(*path_parts)
    os.makedirs(out_dir, exist_ok=True)

    for idx, tensor in enumerate(batch_imgs):
        # Stop once save_num images have been written.
        if save_num < idx + 1:
            break
        save_image(tensor, os.path.join(out_dir, f"{prefix}_{idx}.png"))
|
|
|
|
|
|
|
|
|
def _resolve_grad_cam_layers(model, cl_methods, layer_name):
    """Return ``[last element of <backbone>.<layer_name>]`` for a continual-learning method.

    Raises:
        ValueError: If ``cl_methods`` has no known backbone attribute.
    """
    if cl_methods in ('icarl', 'finetune', 'wa', 'replay', 'podnet', 'bic'):
        backbone = model.convnet
    elif cl_methods in ('foster', 'der'):
        backbone = model.convnets[0]
    elif cl_methods == 'memo':
        backbone = model.TaskAgnosticExtractor
    else:
        raise ValueError(f"save_grad_cam: unsupported model_name {cl_methods!r}")
    return [backbone.get_submodule(layer_name)[-1]]


def _rgb_to_bgr(img):
    """Reverse the channel axis of an (H, W, 3) array for OpenCV; other shapes pass through."""
    if img.ndim == 3 and img.shape[2] == 3:
        return np.ascontiguousarray(img[..., ::-1])
    return img


def save_grad_cam(args, imgs, labels, model, save_path, prefix, layer_name="layer4",
                  save_num=2, save_raw=False, use_labels=True):
    """Save Grad-CAM overlays for the first ``save_num`` images of a batch.

    The overlay for image ``i`` is written to
    ``<save_path>/<i>_<prefix>_grad_cam.png``. With ``save_raw``, the clipped
    input is also written to ``<i>_<prefix>_grad_cam_raw.png``. The CAM target
    layer is the last element of ``layer_name`` inside the backbone used by
    the method named in ``args['model_name']``.

    Args:
        args: Mapping with a ``'model_name'`` entry that selects the backbone layout.
        imgs: Iterable of (C, H, W) image tensors. Values are clipped to [0, 1]
            before visualisation. Channel order is assumed to be RGB and is
            reversed for OpenCV (assumption — confirm for non-torchvision inputs).
        labels: Indexable class ids, one per image, explained when ``use_labels``.
        model: The network. It is switched to eval mode for the call and its
            previous train/eval mode is restored afterwards.
        save_path: Output directory; created if it does not exist.
        prefix: Inserted into every output file name.
        layer_name: Submodule name inside the backbone (default ``"layer4"``).
        save_num: Maximum number of images to process.
        save_raw: Also save the (clipped) input image.
        use_labels: If True, explain ``labels[i]``. If False, use
            pytorch-grad-cam's default target (the highest-scoring class).

    Raises:
        ValueError: If ``args['model_name']`` is not a supported method.
    """
    os.makedirs(save_path, exist_ok=True)
    # Resolve layers before touching the model so a bad name leaves it unchanged.
    target_layers = _resolve_grad_cam_layers(model, args['model_name'], layer_name)

    was_training = model.training
    model.eval()
    cam = GradCAM(model=model, target_layers=target_layers)
    try:
        for i, img in enumerate(imgs):
            if i + 1 > save_num:
                break

            # Previously a ClassifierOutputTarget was built but never passed to
            # cam(), so the labels were silently ignored.
            targets = [ClassifierOutputTarget(int(labels[i]))] if use_labels else None
            grayscale_cam = cam(input_tensor=img.unsqueeze(0), targets=targets)[0, :]

            # detach(): adversarial inputs may require grad, which .numpy() rejects.
            img_rgb = img.detach().cpu().permute(1, 2, 0).numpy().astype(np.float32)
            # show_cam_on_image rejects values outside [0, 1].
            img_bgr = _rgb_to_bgr(np.clip(img_rgb, 0.0, 1.0))

            # Feed a BGR image with a BGR heatmap (use_rgb=False) so the result
            # is in the channel order cv2.imwrite expects.
            cam_bgr = show_cam_on_image(img_bgr, grayscale_cam, use_rgb=False)
            cv2.imwrite(str(Path(save_path) / f"{i}_{prefix}_grad_cam.png"), cam_bgr)

            if save_raw:
                raw = np.round(img_bgr * 255).astype(np.uint8)
                cv2.imwrite(str(Path(save_path) / f"{i}_{prefix}_grad_cam_raw.png"), raw)
    finally:
        # Remove the forward/backward hooks GradCAM registered on the model
        # and restore the caller's train/eval mode.
        cam.activations_and_grads.release()
        model.train(was_training)
        torch.cuda.empty_cache()

    print(f"Grad-CAM images saved to {save_path}")
|
|
|
|
|
|
|