import matplotlib.pyplot as plt
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms
from torchvision.utils import make_grid

# ImageNet per-channel statistics; applied only by image_process(..., normalize=True).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def image_process(image, normalize=False):
    """Load an image and return it as an RGB tensor via ``transforms.ToTensor()``.

    Args:
        image: Anything ``PIL.Image.open`` accepts (e.g. a file path).
        normalize: If True, also apply ``transforms.Normalize`` with the
            ImageNet mean/std above. Defaults to False, which reproduces the
            previous behaviour (ToTensor only).

    Returns:
        The tensor produced by the transform pipeline.
    """
    # Context manager releases the file handle opened by Image.open;
    # convert() returns a converted copy, so it stays usable after the close.
    with Image.open(image) as img:
        rgb = img.convert("RGB")
    steps = [transforms.ToTensor()]
    if normalize:
        steps.append(transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD))
    return transforms.Compose(steps)(rgb)


def _show_grid(grid):
    """Display an image grid returned by ``make_grid`` in a new 8x8-inch figure."""
    # detach()/cpu() so tensors that require grad or live on a GPU (e.g.
    # activations captured during a forward pass) can be handed to matplotlib;
    # both are no-ops for plain CPU tensors.
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.detach().cpu().permute(1, 2, 0))  # CHW -> HWC for imshow
    plt.show()


def view_activations(result_list, max_channels=3, start_channel=4):
    """Show one channel of each activation tensor, one figure per tensor.

    The k-th tensor in ``result_list`` (k starting at 0) is displayed at
    channel ``start_channel + k``, taken from batch index 0.

    Args:
        result_list: Iterable of tensors indexed as ``t[0, c:c + 1]``;
            presumably (batch, channels, H, W) activations — confirm at caller.
        max_channels: Accepted for call-site compatibility; currently unused.
        start_channel: Channel shown for the first tensor (default 4, the
            previously hard-coded value).
    """
    for channel, result in enumerate(result_list, start=start_channel):
        single_channel = result[0, channel:channel + 1]
        grid = make_grid(single_channel, nrow=3, normalize=True, padding=1)
        _show_grid(grid)


def view_activations_gram(image, model, matrix, max_channels=3, num_rows=4):
    """Print the shape of, and display, the leading rows of ``matrix[0]``.

    Args:
        image: Accepted for call-site compatibility; not used here.
        model: Accepted for call-site compatibility; not used here.
        matrix: Tensor indexed as ``matrix[0, 0:num_rows]``; presumably the
            product computed in ``style_computing`` — confirm.
        max_channels: Accepted for call-site compatibility; currently unused.
        num_rows: Number of leading rows of ``matrix[0]`` to show (default 4,
            the previously hard-coded value).
    """
    rows = matrix[0, 0:num_rows]
    print(rows.shape)
    grid = make_grid(rows, nrow=1, normalize=True, padding=1)
    _show_grid(grid)


def style_computing(result_list, model, image):
    """Compute a batched self-product for each activation and display it.

    For each tensor, a leading size-1 dimension is squeezed away and
    ``torch.bmm(x, x.transpose(1, 2))`` is computed. ``torch.bmm`` requires
    the squeezed tensor to be 3-D, so inputs are presumably (1, C, H, W)
    activations. Each product is passed to ``view_activations_gram``.
    Returns None.

    NOTE(review): for a (C, H, W) tensor this yields C separate (H, H)
    products, not the (C, C) channel Gram matrix usually used for style
    features (``F = x.view(C, H * W); F @ F.T``). Left unchanged — confirm
    which one is intended.
    """
    for result in result_list:
        features = result.squeeze(0)
        matrix = torch.bmm(features, features.transpose(1, 2))
        view_activations_gram(image, model, matrix)