prekshyam commited on
Commit
e059298
·
verified ·
1 Parent(s): b189f29

Not needed; deleted visualizer file

Browse files
Files changed (1) hide show
  1. ModelVisualizer.py +0 -89
ModelVisualizer.py DELETED
@@ -1,89 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import torchvision.transforms as transforms
4
-
5
- import matplotlib.pyplot as plt
6
- from PIL import Image
7
- from models.maevit import MAEViT
8
-
9
-
10
- def visualize(model_path, img_path, figure_name):
11
-
12
- model = MAEViT(
13
- image_size=224,
14
- patch_size=16,
15
- embed_dim=128,
16
- encoder_layers=2,
17
- encoder_heads=4,
18
- mlp_ratio=2.0,
19
- mask_ratio=0.75,
20
- decoder_embed_dim=64,
21
- decoder_layers=2,
22
- decoder_heads=4,
23
- dropout=0.1
24
- )
25
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26
- model.to(device)
27
- checkpoint = torch.load(model_path, map_location=device)
28
- model.load_state_dict(checkpoint)
29
- model.eval()
30
-
31
-
32
- to_tensor = transforms.Compose([
33
- transforms.Resize((224, 224)),
34
- transforms.ToTensor(),
35
- transforms.Normalize(
36
- mean=[0.485, 0.456, 0.406],
37
- std =[0.229, 0.224, 0.225]
38
- )
39
- ])
40
-
41
-
42
- img = Image.open(img_path).convert('RGB')
43
- x = to_tensor(img).unsqueeze(0).to(device) # [1,3,224,224]
44
-
45
-
46
- with torch.no_grad():
47
-
48
- x_enc, mask, ids_restore = model.forward_encoder(x)
49
-
50
- x_rec_patches = model.forward_decoder(x_enc, ids_restore)
51
-
52
-
53
- img_rec = model.unpatchify(x_rec_patches[:, 1:, :]) # exclude CLS # [1,3,224,224]
54
- img_patches = model.patchify(x) # [1, num_patches, patch_dim]
55
-
56
- masked_patches = img_patches.clone()
57
- mask = mask.unsqueeze(-1).to(torch.bool) # [1, num_patches, 1]
58
- # masked_patches[mask] = 0
59
- masked_patches = masked_patches.masked_fill(mask, 0)
60
-
61
- img_masked = model.unpatchify(masked_patches) # [1,3,224,224]
62
-
63
- inv_normalize = transforms.Normalize(
64
- mean=[-m/s for m, s in zip((0.485,0.456,0.406),(0.229,0.224,0.225))],
65
- std =[1/s for s in (0.229,0.224,0.225)]
66
- )
67
- def to_img(tensor):
68
- img = tensor.squeeze(0).cpu()
69
- img = inv_normalize(img)
70
- img = img.permute(1,2,0).clamp(0,1).numpy()
71
- return img
72
-
73
- orig_np = to_img(x)
74
- masked_np = to_img(img_masked)
75
- recon_np = to_img(img_rec)
76
-
77
- # 8. Plot
78
- fig, axes = plt.subplots(1, 3, figsize=(15,5))
79
- for ax, im, title in zip(axes,
80
- [orig_np, masked_np, recon_np],
81
- ['Original', 'Masked Input', 'Reconstruction']):
82
- ax.imshow(im)
83
- ax.set_title(title)
84
- ax.axis('off')
85
- plt.tight_layout()
86
- plt.show()
87
- plt.savefig(figure_name)
88
-
89
- visualize('MAE1.bin', img_path='guineapig.jpg', figure_name='figures/MAE_visualization1.png')