"""Helpers for saving image grids, class-label tables, ASR plots and Grad-CAM overlays."""

import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from pytorch_grad_cam.grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
from torchvision.utils import save_image


def save_images(x, x_hat, render_num=64, output_dir='rendered_images', step=0, test=False):
    """Save an interleaved grid of ``x[i]`` / ``x_hat[i]`` pairs as ``<step>.png``.

    Args:
        x: Indexable batch of image tensors (first element of each pair).
        x_hat: Indexable batch of image tensors (second element of each pair).
        render_num: Total number of tiles requested; ``render_num / 2`` pairs
            are drawn. Clamped to the number of images actually available.
        output_dir: Directory the PNG is written to (created if missing).
        step: Used as the file stem of the output image.
        test: When False the grid is mapped through ``grid * 0.5 + 0.5`` before
            saving (presumably undoing a [-1, 1] normalisation -- confirm with
            the caller); when True the grid is saved unchanged.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Even number of tiles per row so each image sits next to its counterpart.
    # Clamp to 2: the raw formula yields 0 for render_num < 4, which make_grid
    # cannot handle.
    tiles_per_row = max(2, int(render_num ** 0.5 / 2) * 2)

    # Never index past the end of a batch smaller than render_num / 2.
    num_pairs = min(int(render_num / 2), len(x), len(x_hat))

    img_lst = []
    for i in range(num_pairs):
        img_lst.append(x[i])
        img_lst.append(x_hat[i])

    grid = torchvision.utils.make_grid(img_lst, nrow=tiles_per_row, padding=2)
    if not test:
        grid = grid * 0.5 + 0.5
    torchvision.utils.save_image(
        grid, os.path.join(output_dir, str(step) + '.png'), nrow=tiles_per_row
    )


def _load_class_dict(path):
    """Parse a comma-separated mapping file into ``{int(field 1): field 2}``.

    Each non-blank line is stripped and split on ``,``; blank lines are skipped.
    """
    label_dict = {}
    with open(path, 'r') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue
            parts = line.split(',')
            label_dict[int(parts[1])] = parts[2]
    return label_dict


def load_imagenet_class_dict():
    """Return the index -> name mapping read from the ImageNet ``map_clsloc.txt`` file."""
    return _load_class_dict('attacks/UnivIntruder/utils_/map_clsloc.txt')


def load_imagenet100_class_dict():
    """Return the index -> name mapping read from ``map_clsloc_imagenet100.txt``."""
    return _load_class_dict('attacks/SAE/map_clsloc_imagenet100.txt')


# NOTE: both tables are read from disk at import time, relative to the current
# working directory.
imagenet100_classes_dict = load_imagenet100_class_dict()
imagenet_classes_dict = load_imagenet_class_dict()

# Single source of truth for the CIFAR-100 class names, ordered by label index.
_CIFAR100_CLASSES = (
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver',
    'bed', 'bee', 'beetle', 'bicycle', 'bottle',
    'bowl', 'boy', 'bridge', 'bus', 'butterfly',
    'camel', 'can', 'castle', 'caterpillar', 'cattle',
    'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach',
    'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox',
    'girl', 'hamster', 'house', 'kangaroo', 'computer_keyboard',
    'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard',
    'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
    'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid',
    'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
    'plain', 'plate', 'poppy', 'porcupine', 'possum',
    'rabbit', 'raccoon', 'ray', 'road', 'rocket',
    'rose', 'sea', 'seal', 'shark', 'shrew',
    'skunk', 'skyscraper', 'snail', 'snake', 'spider',
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
    'tank', 'telephone', 'television', 'tiger', 'tractor',
    'train', 'trout', 'tulip', 'turtle', 'wardrobe',
    'whale', 'willow_tree', 'wolf', 'woman', 'worm',
)

cifar100_classes_dict = dict(enumerate(_CIFAR100_CLASSES))


def load_cifar100_classes():
    """Return a new list of the 100 CIFAR-100 class names in label order."""
    return list(_CIFAR100_CLASSES)


def plot_asr_per_target(asr_matrix, save_path, prefix, args, acc_metric=None):
    """Plot per-target ASR curves plus a bar chart of their means and save as PNG.

    Args:
        asr_matrix: 2-D array of shape ``(num_tasks, num_targets)``; column ``i``
            is drawn as one line across tasks.
        save_path: Output directory (created if missing).
        prefix: File stem; the figure is written to ``<save_path>/<prefix>.png``.
        args: Mapping providing ``args["target_class"]`` for the titles.
        acc_metric: Optional per-task sequence drawn as a dashed red
            'Clean Accuracy' line on the upper axes.
    """
    num_tasks, num_targets = asr_matrix.shape
    tasks = np.arange(num_tasks)
    avg_asr = asr_matrix.mean(axis=0)

    fig, (ax_line, ax_bar) = plt.subplots(
        2, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [3, 1]}
    )

    # Upper panel: one curve per target image.
    for i in range(num_targets):
        ax_line.plot(tasks, asr_matrix[:, i], marker='o', label=f'Target Image {i} (ASR)')
    if acc_metric is not None:
        ax_line.plot(tasks, acc_metric, marker='x', label='Clean Accuracy',
                     color='red', linestyle='--')
    ax_line.set_xlabel('Task')
    ax_line.set_ylabel('Attack Success Rate (ASR)' if acc_metric is None else 'ASR / Accuracy')
    ax_line.set_title(f'ASR of each Target Image (Target Class: {args["target_class"]}) across Tasks')
    ax_line.set_xticks(tasks)
    ax_line.set_ylim(0, 1)
    ax_line.grid(True)
    ax_line.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize='small', ncol=1)

    # Lower panel: mean ASR over tasks for each target image.
    indices = np.arange(num_targets)
    ax_bar.bar(indices, avg_asr, color='skyblue')
    ax_bar.set_xlabel('Target Image')
    ax_bar.set_ylabel('Average ASR')
    ax_bar.set_title(f'Average ASR per Target Image (Target Class: {args["target_class"]})')
    ax_bar.set_xticks(indices)
    ax_bar.set_xticklabels([f'{i}' for i in range(num_targets)], rotation=45, fontsize='small')
    ax_bar.set_ylim(0, 1)
    ax_bar.grid(axis='y')

    # Leave the right 15% of the canvas free for the legend placed outside the axes.
    fig.tight_layout(rect=[0, 0, 0.85, 1])
    os.makedirs(save_path, exist_ok=True)
    fig.savefig(os.path.join(save_path, f"{prefix}.png"), bbox_inches='tight')
    # Close this figure explicitly instead of whatever pyplot considers current.
    plt.close(fig)


def save_batch_images(batch_imgs, logs_eval_name, filename=None, prefix="adv", save_num=2):
    """Write the first ``save_num`` images of a batch as ``<prefix>_<i>.png``.

    Args:
        batch_imgs: Iterable of image tensors accepted by ``save_image``.
        logs_eval_name: Base output directory.
        filename: Optional sub-directory name under ``logs_eval_name``.
        prefix: File-name prefix for each saved image.
        save_num: Maximum number of images to write.
    """
    if filename is not None:
        target_folder = os.path.join(logs_eval_name, f'{filename}')
    else:
        target_folder = os.path.join(logs_eval_name)
    os.makedirs(target_folder, exist_ok=True)

    for i, img_tensor in enumerate(batch_imgs):
        if i >= save_num:
            break
        img_name = f"{prefix}_{i}.png"
        save_image(img_tensor, os.path.join(target_folder, img_name))


# Methods whose model exposes a single backbone as ``model.convnet``.
_SINGLE_CONVNET_METHODS = ('icarl', 'finetune', 'wa', 'replay', 'podnet', 'bic')
# Methods whose model exposes a list of backbones as ``model.convnets``.
_MULTI_CONVNET_METHODS = ('foster', 'der')


def _grad_cam_target_layers(model, cl_method, layer_name):
    """Return the Grad-CAM target layer list for the given continual-learning method."""
    if cl_method in _SINGLE_CONVNET_METHODS:
        backbone = model.convnet
    elif cl_method in _MULTI_CONVNET_METHODS:
        backbone = model.convnets[0]
    elif cl_method == 'memo':
        backbone = model.TaskAgnosticExtractor
    else:
        # Previously this fell through and failed later with an unbound local.
        raise ValueError(f"Unsupported model_name for Grad-CAM: {cl_method!r}")
    # The last child of the named stage is used as the CAM layer.
    return [backbone.get_submodule(layer_name)[-1]]


def save_grad_cam(args, imgs, labels, model, save_path, prefix, layer_name="layer4", save_num=2, save_raw=False):
    """Save Grad-CAM overlays for the first ``save_num`` images.

    Args:
        args: Mapping providing ``args['model_name']``, which selects where the
            backbone lives on ``model``.
        imgs: Iterable of image tensors; each is permuted ``(1, 2, 0)`` before
            being handed to NumPy, i.e. treated as channel-first.
        labels: Currently unused. No explicit target is passed to the CAM call,
            so pytorch-grad-cam falls back to its default target selection.
        model: Model to explain; switched to eval mode by this function.
        save_path: Output directory (created if missing).
        prefix: Inserted into each output file name.
        layer_name: Name of the backbone sub-module whose last child is used as
            the CAM layer.
        save_num: Maximum number of images to process.
        save_raw: When True, also write the clipped input image next to the overlay.

    Raises:
        ValueError: If ``args['model_name']`` is not a supported method.
    """
    os.makedirs(save_path, exist_ok=True)
    model.eval()

    target_layers = _grad_cam_target_layers(model, args['model_name'], layer_name)

    # Context manager so the hooks GradCAM registers on the model are released.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        for i, img in enumerate(imgs):
            if i >= save_num:
                break

            grayscale_cam = cam(img.unsqueeze(0))
            grayscale_cam = grayscale_cam[0, :]

            # detach() so tensors that still carry autograd history convert cleanly.
            img_np = np.float32(np.array(img.detach().cpu().permute(1, 2, 0)))

            # NOTE(review): use_rgb=False plus cv2.imwrite implies BGR channel
            # order; if `imgs` are RGB the saved colours are swapped -- confirm.
            cam_imgs = show_cam_on_image(img_np, grayscale_cam, use_rgb=False)

            output_path = Path(save_path) / f"{i}_{prefix}_grad_cam.png"
            cv2.imwrite(str(output_path), cam_imgs)

            if save_raw:
                output_path = Path(save_path) / f"{i}_{prefix}_grad_cam_raw.png"
                cv2.imwrite(str(output_path), np.clip(img_np, 0.0, 1.0) * 255)

            # Release cached GPU memory between images.
            torch.cuda.empty_cache()

    print(f"Grad-CAM images saved to {save_path}")