| import torch | |
| from torchvision import transforms | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| import matplotlib.pyplot as plt | |
| import PIL | |
| import io | |
| from PIL import Image | |
| import numpy as np | |
| import random | |
# Module-level configuration shared by the helpers below.

# Image -> tensor conversion, exposed for importers of this module.
# (The original assigned this twice; one definition is enough.)
transform = transforms.ToTensor()

# Forwarded unchanged to the GradCAM call in image_to_array.
# Per the pytorch-grad-cam docs, None selects the highest-scoring category.
targets = None

device = torch.device("cpu")

# Per-channel statistics. The values look like the usual CIFAR-10 mean/std,
# but nothing in the visible code reads them.
# NOTE(review): inv_normalize below inverts Normalize(mean=0.50, std=0.23),
# not these exact statistics -- confirm which normalization the model's
# input pipeline really applies before "fixing" either one.
mu = [0.49139968, 0.48215841, 0.44653091]
std = [0.24703223, 0.24348513, 0.26158784]

# Inverse of Normalize(mean=m, std=s) is Normalize(mean=-m/s, std=1/s):
#   (x_norm - (-m/s)) / (1/s) == x_norm * s + m
_DENORM_MEAN = 0.50
_DENORM_STD = 0.23
inv_normalize = transforms.Normalize(
    mean=[-_DENORM_MEAN / _DENORM_STD] * 3,
    std=[1 / _DENORM_STD] * 3,
)

# Class index -> human-readable label.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
def get_examples():
    """Build one randomized example row per entry in ``classes``.

    Returns:
        A list of ``[image_filename, top, transparency]`` lists, in the
        same order as ``classes``:

        * ``image_filename`` is ``"<class name>.jpg"``;
        * ``top`` is a random integer in ``[2, 9]`` (both ends inclusive);
        * ``transparency`` is randomly one of ``0.6``, ``0.7`` or ``0.8``.

    The number of rows follows ``len(classes)`` instead of a hard-coded 10,
    so adding or removing a class can no longer cause an ``IndexError`` or
    silently discard random draws.
    """
    count = len(classes)
    image_names = [f'{label}.jpg' for label in classes]
    # Draw every "top" value before any transparency value: this keeps the
    # order of calls into `random` stable, so seeded runs stay reproducible.
    tops = [random.randint(2, 9) for _ in range(count)]
    transparencies = [random.choice([0.6, 0.7, 0.8]) for _ in range(count)]
    return [list(row) for row in zip(image_names, tops, transparencies)]
def image_to_array(input_img, model, layer_val, transparency=0.6):
    """Render a Grad-CAM overlay for one sample and return it as an array.

    Args:
        input_img: Indexable sample. ``input_img[0]`` is the network input
            tensor; ``input_img[1]`` and ``input_img[2]`` are the correct
            and predicted class indices, each supporting ``.item()``.
            NOTE(review): the squeeze(0) + three-axis permute below imply
            the input tensor is (1, C, H, W) -- confirm against the caller.
        model: Network to explain. Must expose ``res_block2.conv`` as an
            indexable sequence of layers.
        layer_val: Position of the Grad-CAM target layer counted from the
            END of ``model.res_block2.conv`` (the layer used is
            ``conv[-layer_val]``; note that 0 therefore selects the first
            layer, since ``-0 == 0``).
        transparency: Forwarded as ``image_weight`` to
            ``show_cam_on_image``; weight of the original image in the
            blend.

    Returns:
        NumPy array decoded from the PNG rendering of the matplotlib figure
        (overlay plus a "Correct / Output" title).
    """
    input_tensor = input_img[0]
    target_layer = model.res_block2.conv[-layer_val]

    # The context manager releases the hooks GradCAM registers on the model
    # as soon as the map is computed, instead of waiting for collection.
    with GradCAM(model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

    # Undo the input normalization and go from CHW tensor to HWC ndarray.
    # detach() keeps .numpy() working if the tensor is tracking gradients.
    chw = inv_normalize(input_tensor.squeeze(0))
    rgb_img = chw.detach().cpu().permute(1, 2, 0).numpy()
    # show_cam_on_image rejects images whose maximum exceeds 1, and the
    # inverse normalization is only approximate, so clamp to [0, 1].
    rgb_img = np.clip(rgb_img, 0.0, 1.0)

    visualization = show_cam_on_image(
        rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency
    )

    correct_label = classes[input_img[1].item()]
    output_label = classes[input_img[2].item()]

    # Use a dedicated figure and always close it: drawing on pyplot's
    # implicit current figure stacks a new image onto the same axes on
    # every call and never frees it.
    fig, ax = plt.subplots()
    try:
        ax.imshow(visualization)
        ax.set_title(r"Correct: " + correct_label + '\n' + 'Output: ' + output_label)
        with io.BytesIO() as buffer:
            fig.savefig(buffer, format="png")
            buffer.seek(0)
            # Convert while the buffer is still open: PIL decodes lazily.
            return np.asarray(Image.open(buffer))
    finally:
        plt.close(fig)