# SAE / utils / plot.py
# (Hugging Face upload by Ttius: "Upload 192 files", commit 998bb30, verified)
import os
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
from pytorch_grad_cam.grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import transforms
import cv2
import torchvision
from pathlib import Path
from PIL import Image
from torchvision.utils import save_image
def save_images(x, x_hat, render_num=64, output_dir='rendered_images', step=0, test=False):
    """Save one PNG grid that interleaves originals with their reconstructions.

    Tiles are laid out as ``x[0], x_hat[0], x[1], x_hat[1], ...`` so each
    original sits next to its reconstruction. The grid is written to
    ``<output_dir>/<step>.png``.

    Args:
        x: indexable batch of original images (e.g. a tensor of shape
            (B, C, H, W) -- assumed; anything ``make_grid`` accepts per item).
        x_hat: indexable batch of reconstructions matching ``x``.
        render_num: total number of tiles requested (half originals, half
            reconstructions). It is capped by how many pairs are available.
        output_dir: destination directory, created if missing.
        step: used as the file stem.
        test: if True the grid is saved unchanged. Otherwise it is mapped with
            ``grid * 0.5 + 0.5`` (i.e. from [-1, 1] to [0, 1]).

    Raises:
        ValueError: if not even one (original, reconstruction) pair can be
            rendered (``render_num < 2`` or an empty batch).
    """
    # Never index past the end of a short (e.g. final) batch.
    num_pairs = min(int(render_num / 2), len(x), len(x_hat))
    if num_pairs < 1:
        raise ValueError(
            f"save_images: nothing to render (render_num={render_num}, "
            f"len(x)={len(x)}, len(x_hat)={len(x_hat)})"
        )
    os.makedirs(output_dir, exist_ok=True)
    # Images per grid row. Keep it even so pairs are never split across rows,
    # and at least 2, because make_grid divides by nrow (0 would crash).
    nrow = max(int(render_num ** 0.5 / 2) * 2, 2)
    tiles = []
    for i in range(num_pairs):
        tiles.extend((x[i], x_hat[i]))
    grid = torchvision.utils.make_grid(tiles, nrow=nrow, padding=2)
    if not test:
        grid = grid * 0.5 + 0.5
    torchvision.utils.save_image(grid, os.path.join(output_dir, str(step) + '.png'), nrow=nrow)
def _load_class_dict(path):
    """Parse a comma-separated class-map file into ``{class_index: class_name}``.

    Each non-blank line is stripped and split on commas. Field 1 is parsed as
    the integer class index and field 2 is used as the class name; field 0
    and any further fields are ignored. Blank lines are skipped.
    """
    label_dict = {}
    with open(path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines instead of IndexError
            parts = line.split(',')
            label_dict[int(parts[1])] = parts[2]
    return label_dict


def load_imagenet_class_dict(path='attacks/UnivIntruder/utils_/map_clsloc.txt'):
    """Return the ImageNet ``{class_index: class_name}`` map read from *path*.

    The default path is relative, so it resolves against the current working
    directory.
    """
    return _load_class_dict(path)


def load_imagenet100_class_dict(path='attacks/SAE/map_clsloc_imagenet100.txt'):
    """Return the ImageNet-100 ``{class_index: class_name}`` map read from *path*.

    The default path is relative, so it resolves against the current working
    directory.
    """
    return _load_class_dict(path)
# Class-index -> name lookups, read once when this module is imported.
# Both loaders default to cwd-relative paths under 'attacks/', so the import
# raises FileNotFoundError unless the working directory contains them.
imagenet100_classes_dict, imagenet_classes_dict = (
    load_imagenet100_class_dict(),
    load_imagenet_class_dict(),
)
# CIFAR-100 fine-label index -> class name (indices 0..99 in label order).
cifar100_classes_dict = dict(enumerate((
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed',
    'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge',
    'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar',
    'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach',
    'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin',
    'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'computer_keyboard', 'lamp', 'lawn_mower',
    'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree',
    'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree',
    'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck',
    'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum',
    'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
    'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail',
    'snake', 'spider', 'squirrel', 'streetcar', 'sunflower',
    'sweet_pepper', 'table', 'tank', 'telephone', 'television',
    'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle',
    'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm',
)))
# CIFAR-100 fine-label names, in label-index order.
_CIFAR100_CLASS_NAMES = (
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin',
    'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo',
    'computer_keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster',
    'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree',
    'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
    'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray',
    'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper',
    'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
    'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
    'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
    'worm',
)


def load_cifar100_classes():
    """Return the 100 CIFAR-100 class names as a new list, indexed by label."""
    # Copy so callers can mutate their list without affecting later calls.
    return list(_CIFAR100_CLASS_NAMES)
def plot_asr_per_target(asr_matrix, save_path, prefix, args, acc_metric=None):
    """Plot per-target ASR curves across tasks plus a bar chart of their means.

    The top panel draws one line per target image (column of ``asr_matrix``)
    over the task axis. It optionally adds a dashed red clean-accuracy line.
    The bottom panel shows the column-wise mean ASR per target image. The
    figure is written to ``<save_path>/<prefix>.png``.

    Args:
        asr_matrix: 2-D array-like of shape (num_tasks, num_targets). Lists and
            other array-likes are accepted via ``np.asarray``.
        save_path: output directory, created if missing.
        prefix: file stem of the saved PNG.
        args: mapping providing ``args["target_class"]`` (used in titles only).
        acc_metric: optional per-task sequence plotted on the top panel.
    """
    asr_matrix = np.asarray(asr_matrix)
    num_tasks, num_targets = asr_matrix.shape
    tasks = np.arange(num_tasks)
    indices = np.arange(num_targets)
    avg_asr = asr_matrix.mean(axis=0)
    target_class = args["target_class"]

    fig, (ax_line, ax_bar) = plt.subplots(2, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [3, 1]})
    try:
        for i in range(num_targets):
            ax_line.plot(tasks, asr_matrix[:, i], marker='o', label=f'Target Image {i} (ASR)')
        if acc_metric is not None:
            ax_line.plot(tasks, acc_metric, marker='x', label='Clean Accuracy', color='red', linestyle='--')
        ax_line.set_xlabel('Task')
        ax_line.set_ylabel('Attack Success Rate (ASR)' if acc_metric is None else 'ASR / Accuracy')
        ax_line.set_title(f'ASR of each Target Image (Target Class: {target_class}) across Tasks')
        ax_line.set_xticks(tasks)
        ax_line.set_ylim(0, 1)
        ax_line.grid(True)
        # The legend sits outside the axes on the right; the 0.85 layout rect
        # below reserves room for it.
        ax_line.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize='small', ncol=1)

        ax_bar.bar(indices, avg_asr, color='skyblue')
        ax_bar.set_xlabel('Target Image')
        ax_bar.set_ylabel('Average ASR')
        ax_bar.set_title(f'Average ASR per Target Image (Target Class: {target_class})')
        ax_bar.set_xticks(indices)
        ax_bar.set_xticklabels([f'{i}' for i in range(num_targets)], rotation=45, fontsize='small')
        ax_bar.set_ylim(0, 1)
        ax_bar.grid(axis='y')

        fig.tight_layout(rect=[0, 0, 0.85, 1])
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, f"{prefix}.png"), bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)
def save_batch_images(batch_imgs, logs_eval_name, filename=None, prefix="adv", save_num=2):
    """Write at most ``save_num`` images from ``batch_imgs`` as ``{prefix}_{i}.png``.

    Files go to ``logs_eval_name/filename`` when ``filename`` is given and
    directly into ``logs_eval_name`` otherwise. The directory is created if
    missing. Each element is handed unchanged to torchvision's ``save_image``.
    """
    folder = (
        os.path.join(logs_eval_name)
        if filename is None
        else os.path.join(logs_eval_name, f'{filename}')
    )
    os.makedirs(folder, exist_ok=True)
    for index, tensor in enumerate(batch_imgs):
        if save_num < index + 1:
            break  # only the first ``save_num`` images are written
        save_image(tensor, os.path.join(folder, f"{prefix}_{index}.png"))
def _grad_cam_target_layers(model_name, model, layer_name):
    """Return ``[last child of <backbone>.<layer_name>]`` for a supported model name."""
    if model_name in ('icarl', 'finetune', 'wa', 'replay', 'podnet', 'bic'):
        backbone = model.convnet
    elif model_name in ('foster', 'der'):
        # NOTE(review): only the first backbone (convnets[0]) is inspected.
        backbone = model.convnets[0]
    elif model_name == 'memo':
        backbone = model.TaskAgnosticExtractor
    else:
        raise ValueError(f"save_grad_cam: unsupported model_name {model_name!r}")
    return [backbone.get_submodule(layer_name)[-1]]


def _to_bgr_unit_float(img):
    """Convert a CHW image tensor to a contiguous HWC float32 array in [0, 1], RGB -> BGR."""
    # detach() so tensors that require grad can be converted to numpy.
    arr = img.detach().cpu().permute(1, 2, 0).numpy().astype(np.float32)
    # show_cam_on_image rejects values above 1, so clip before rendering.
    arr = np.clip(arr, 0.0, 1.0)
    if arr.shape[-1] == 3:
        # cv2's colormap heatmap and cv2.imwrite are BGR, so flip channel order.
        arr = arr[..., ::-1]
    return np.ascontiguousarray(arr)


def save_grad_cam(args, imgs, labels, model, save_path, prefix, layer_name="layer4", save_num=2,
                  save_raw=False, use_labels=False):
    """Render Grad-CAM overlays for the first ``save_num`` images and save them as PNGs.

    Args:
        args: mapping whose ``'model_name'`` selects which backbone attribute of
            ``model`` contains ``layer_name``.
        imgs: iterable of CHW image tensors. Values are clipped to [0, 1]
            before rendering.
            NOTE(review): channels are assumed to be RGB (as with the
            torchvision ``save_image`` calls elsewhere in this file). They are
            converted to BGR for cv2.
        labels: per-image class indices, used only when ``use_labels`` is True.
        model: network to explain. It is switched to eval mode and left there.
        save_path: output directory, created if missing.
        prefix: inserted into filenames: ``{i}_{prefix}_grad_cam.png``.
        layer_name: backbone sub-module whose last child is the CAM target layer.
        save_num: maximum number of images to process.
        save_raw: also write the clipped input as ``{i}_{prefix}_grad_cam_raw.png``.
        use_labels: if True, explain ``labels[i]``. Otherwise (default, same as
            before) no target is passed, and pytorch-grad-cam explains the
            highest-scoring class.

    Raises:
        ValueError: if ``args['model_name']`` is not a supported model.
    """
    os.makedirs(save_path, exist_ok=True)
    model.eval()
    target_layers = _grad_cam_target_layers(args['model_name'], model, layer_name)
    cam = GradCAM(model=model, target_layers=target_layers)

    for i, img in enumerate(imgs):
        if i + 1 > save_num:
            break
        targets = [ClassifierOutputTarget(labels[i])] if use_labels else None
        grayscale_cam = cam(img.unsqueeze(0), targets)[0, :]
        img_bgr = _to_bgr_unit_float(img)
        cam_img = show_cam_on_image(img_bgr, grayscale_cam, use_rgb=False)
        cv2.imwrite(str(Path(save_path) / f"{i}_{prefix}_grad_cam.png"), cam_img)
        if save_raw:
            raw = np.round(img_bgr * 255).astype(np.uint8)
            cv2.imwrite(str(Path(save_path) / f"{i}_{prefix}_grad_cam_raw.png"), raw)
        # Clear GPU cache between images.
        torch.cuda.empty_cache()
    print(f"Grad-CAM images saved to {save_path}")