import torch
import numpy as np
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from PIL import Image
from maevit import MAEViT


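def visualize(model_path, img_path, figure_name):
    """Show an image, its masked version and the MAE reconstruction side by side,
    and save the figure to figure_name."""
    # 1. Build the model and load the trained weights. The hyperparameters must
    #    match the checkpoint's architecture for load_state_dict to succeed.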
    model = MAEViT(
        image_size=224,
        patch_size=16,
        embed_dim=128,
        encoder_layers=2,
        encoder_heads=4,
        mlp_ratio=2.0,
        mask_ratio=0.75,
        decoder_embed_dim=64,
        decoder_layers=2,
        decoder_heads=4,
        dropout=0.1
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()

    # 2. Preprocessing: resize to 224x224 and normalize with the ImageNet mean/std.
    to_tensor = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std =[0.229, 0.224, 0.225]
        )
    ])

    # 3. Load the image as RGB and add a batch dimension.
    img = Image.open(img_path).convert('RGB')
    x = to_tensor(img).unsqueeze(0).to(device)  # [1,3,224,224]

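    # 4. Forward pass: the encoder masks a mask_ratio fraction of the patches and
    #    encodes only the visible ones; the decoder uses ids_restore to put the
    #    tokens back in their original order and predicts every patch.
    with torch.no_grad():
        x_enc, mask, ids_restore = model.forward_encoder(x)
        x_rec_patches = model.forward_decoder(x_enc, ids_restore)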

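    # 5. Convert the predicted patches back into an image. The decoder output
    #    includes the CLS token at position 0, so it is dropped first.
    img_rec = model.unpatchify(x_rec_patches[:, 1:, :])  # [1, 3, 224, 224]
    img_patches = model.patchify(x)                      # [1, num_patches, patch_dim]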

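    # 6. Build the masked input by zeroing out the patches hidden from the
    #    encoder (mask == 1 marks a removed patch). Zero in normalized space
    #    is displayed as the ImageNet mean color after denormalization.
    mask = mask.unsqueeze(-1).to(torch.bool)             # [1, num_patches, 1]
    masked_patches = img_patches.masked_fill(mask, 0)    # out-of-place, img_patches is unchanged
    img_masked = model.unpatchify(masked_patches)        # [1, 3, 224, 224]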

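    # 7. Undo the normalization for display. Normalize computes (x - mean) / std,
    #    so mean' = -m / s and std' = 1 / s give back x * s + m.
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))],
        std=[1 / s for s in (0.229, 0.224, 0.225)]
    )

    def to_img(tensor):
        """Convert a normalized [1, 3, H, W] tensor to an [H, W, 3] array in [0, 1]."""
        img = tensor.squeeze(0).cpu()
        img = inv_normalize(img)
        img = img.permute(1, 2, 0).clamp(0, 1).numpy()
        return img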

    orig_np     = to_img(x)
    masked_np   = to_img(img_masked)
    recon_np    = to_img(img_rec)

    # 8. Plot
    fig, axes = plt.subplots(1, 3, figsize=(15,5))
    for ax, im, title in zip(axes,
                             [orig_np, masked_np, recon_np],
                             ['Original', 'Masked Input', 'Reconstruction']):
        ax.imshow(im)
        ax.set_title(title)
        ax.axis('off')
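    plt.tight_layout()
    # Save before plt.show(): once the window is closed the figure is released,
    # and a later savefig would write an empty image.
    plt.savefig(figure_name)
    plt.show()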

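# Optional helper, a minimal sketch that is not part of the original script and
# is not called above. MAE-style figures often add a fourth panel that pastes the
# original visible patches over the reconstruction, so only the masked patches
# show model output. The name paste_visible_patches is hypothetical. It assumes
# patch tensors shaped [B, num_patches, patch_dim] and the raw [B, num_patches]
# mask returned by forward_encoder, with mask == 1 marking a removed patch (the
# same convention visualize() relies on when it zeroes out masked patches).
def paste_visible_patches(img_patches, rec_patches, mask):
    """Keep original patches where mask == 0 and predicted patches where mask == 1."""
    mask = mask.unsqueeze(-1).to(img_patches.dtype)  # [B, num_patches, 1], broadcasts over patch_dim
    return img_patches * (1 - mask) + rec_patches * mask


# Possible usage inside visualize(), right after img_patches = model.patchify(x)
# and before mask is converted to bool:
#     pasted = paste_visible_patches(img_patches, x_rec_patches[:, 1:, :], mask)
#     img_pasted = model.unpatchify(pasted)
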
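if __name__ == '__main__':
    # The figures/ directory must already exist; savefig does not create it.
    visualize('MAE1.bin', img_path='guineapig.jpg', figure_name='figures/MAE_visualization1.png')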