File size: 1,366 Bytes
373085f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from torchvision import transforms
from PIL import Image
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid
import torch

def image_process(image, normalize=False):
    """Load an image file and return it as a float tensor.

    The image is opened with PIL, converted to 3-channel RGB, and passed
    through ``transforms.ToTensor()``. Per the torchvision docs, that
    yields a (C, H, W) float tensor scaled to [0, 1].

    Args:
        image: path or file object accepted by ``PIL.Image.open``.
        normalize: if True, additionally apply ``transforms.Normalize`` with
            the mean/std values that were previously commented out here
            (the statistics torchvision documents for its ImageNet-pretrained
            models). Defaults to False, which matches the original behaviour.

    Returns:
        torch.Tensor: the transformed image.
    """
    # ``with`` guarantees a file opened from a path is closed. Without it,
    # PIL may keep the handle open (e.g. for multi-frame formats).
    # ``convert`` returns a new, fully loaded image, so it stays valid after
    # the source image is closed.
    with Image.open(image) as img:
        rgb = img.convert("RGB")

    steps = [transforms.ToTensor()]
    if normalize:
        steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225]))
    return transforms.Compose(steps)(rgb)

def view_activations(result_list, max_channels=3, start_channel=4):
    """Plot one feature-map channel per activation tensor.

    The k-th tensor in ``result_list`` is shown at channel
    ``start_channel + k``, each in its own matplotlib figure.

    Args:
        result_list: iterable of tensors indexed as ``result[0, channel]``.
            Presumably (N, C, H, W) layer outputs; TODO confirm against
            the caller.
        max_channels: currently unused; kept so existing calls still work.
        start_channel: channel displayed for the first tensor. Defaults to 4,
            the value that was previously hard-coded.
    """
    for offset, result in enumerate(result_list):
        channel = start_channel + offset
        # Slice with a range to keep a size-1 channel axis, so make_grid
        # treats it as a single grayscale image.
        # detach()/cpu(): imshow converts via NumPy, which fails on tensors
        # that require grad or live on an accelerator.
        result_tensor = result[0, channel:channel + 1].detach().cpu()
        grid = make_grid(result_tensor, nrow=3, normalize=True, padding=1)
        plt.figure(figsize=(8, 8))
        # make_grid returns (C, H, W); matplotlib expects (H, W, C).
        plt.imshow(grid.permute(1, 2, 0))
        plt.show()

def view_activations_gram(image, model, matrix, max_channels=3, num_rows=4):
    """Display the leading rows of the first matrix in ``matrix``.

    Prints the shape of the selected slice, then shows it as one
    normalised grayscale image in a new matplotlib figure.

    Args:
        image: unused; kept for call compatibility.
        model: unused; kept for call compatibility.
        matrix: tensor indexed as ``matrix[0, :num_rows]``. Presumably a
            batch of Gram matrices (N, C, C); TODO confirm against the caller.
        max_channels: unused; kept for call compatibility.
        num_rows: number of leading rows of ``matrix[0]`` to display.
            Defaults to 4, the value that was previously hard-coded.
    """
    # detach()/cpu(): imshow converts via NumPy, which fails on tensors that
    # require grad or live on an accelerator.
    result_tensor = matrix[0, 0:num_rows].detach().cpu()
    print(result_tensor.shape)
    grid = make_grid(result_tensor, nrow=1, normalize=True, padding=1)
    plt.figure(figsize=(8, 8))
    # make_grid returns (C, H, W); matplotlib expects (H, W, C).
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()

def style_computing(result_list, model, image):
    """Compute and display the Gram matrix of each activation tensor.

    For features F of shape (N, C, H, W), each channel is flattened to a
    row vector of length H*W. The Gram matrix is then
    ``G[n] = F_flat[n] @ F_flat[n].T``, giving a batch of shape (N, C, C)
    that holds the channel-by-channel inner products. Each batch is passed
    to ``view_activations_gram`` for plotting. The values are not scaled
    by C*H*W; the plot normalises them anyway.

    Args:
        result_list: iterable of activation tensors, either (N, C, H, W) or
            unbatched (C, H, W). Presumably floating-point layer outputs;
            TODO confirm against the caller.
        model: forwarded unchanged to ``view_activations_gram``.
        image: forwarded unchanged to ``view_activations_gram``.
    """
    for result in result_list:
        # Give an unbatched (C, H, W) map a batch axis. The previous
        # squeeze(0) approach broke for batch > 1 and for single-channel maps.
        features = result.unsqueeze(0) if result.dim() == 3 else result
        n, c, h, w = features.shape
        with torch.no_grad():  # visualisation only; no autograd graph needed
            # Flatten spatial dims so each channel is a single row vector.
            flat = features.reshape(n, c, h * w)
            # (N, C, HW) @ (N, HW, C) -> (N, C, C) channel-correlation matrices.
            gram = torch.bmm(flat, flat.transpose(1, 2))
        # CPU copy so matplotlib can convert it to a NumPy array.
        view_activations_gram(image, model, gram.cpu())