"""Visualize masked-autoencoder (MAE) reconstructions for a single image.

Loads trained ``MAEViT`` weights, runs one image through the model's encoder
and decoder, and shows the original image, the masked input and the
reconstruction side by side. The figure is also saved to disk.
"""

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from maevit import MAEViT

# ImageNet channel statistics: used to normalize the model input and to undo
# that normalization for display, so both directions stay consistent.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMAGE_SIZE = 224


def _load_model(model_path, device):
    """Build the MAEViT architecture, load weights from *model_path*, return it in eval mode."""
    # Hyper-parameters must match the ones used when the checkpoint was trained.
    model = MAEViT(
        image_size=IMAGE_SIZE,
        patch_size=16,
        embed_dim=128,
        encoder_layers=2,
        encoder_heads=4,
        mlp_ratio=2.0,
        mask_ratio=0.75,
        decoder_embed_dim=64,
        decoder_layers=2,
        decoder_heads=4,
        dropout=0.1,
    )
    model.to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()
    return model


def _load_image(img_path, device):
    """Load *img_path* as RGB and return a normalized [1, 3, H, W] tensor on *device*."""
    preprocess = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(IMAGENET_MEAN), std=list(IMAGENET_STD)),
    ])
    # Context manager so the underlying file handle is released promptly.
    with Image.open(img_path) as img:
        rgb = img.convert('RGB')
    return preprocess(rgb).unsqueeze(0).to(device)


def _to_display_image(tensor) -> np.ndarray:
    """Undo ImageNet normalization on a [1, 3, H, W] tensor; return an HxWx3 array in [0, 1]."""
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    img = tensor.squeeze(0).detach().cpu() * std + mean
    return img.permute(1, 2, 0).clamp(0, 1).numpy()


def _plot_and_save(images, titles, figure_name):
    """Plot *images* side by side with *titles*, save to *figure_name*, then show."""
    fig, axes = plt.subplots(1, len(images), figsize=(5 * len(images), 5), squeeze=False)
    for ax, im, title in zip(axes.ravel(), images, titles):
        ax.imshow(im)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()

    out_path = Path(figure_name)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Save BEFORE show(): with interactive backends the figure is closed once
    # the window is dismissed, and a later savefig would write a blank figure.
    fig.savefig(out_path)
    plt.show()
    plt.close(fig)


def visualize(model_path, img_path, figure_name):
    """Show and save an original / masked-input / reconstruction comparison.

    Args:
        model_path: Path to a checkpoint loadable by ``MAEViT.load_state_dict``.
        img_path: Path to the input image; it is resized to IMAGE_SIZE x IMAGE_SIZE.
        figure_name: Output path for the saved figure; parent dirs are created.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = _load_model(model_path, device)
    x = _load_image(img_path, device)  # [1, 3, IMAGE_SIZE, IMAGE_SIZE]

    with torch.no_grad():
        x_enc, mask, ids_restore = model.forward_encoder(x)
        x_rec_patches = model.forward_decoder(x_enc, ids_restore)
        # Drop position 0, assumed to be the CLS token in the decoder output
        # (TODO confirm against maevit.forward_decoder).
        img_rec = model.unpatchify(x_rec_patches[:, 1:, :])

        img_patches = model.patchify(x)  # [1, num_patches, patch_dim]
        # Assumes mask is [1, num_patches] with nonzero = hidden patch — confirm.
        hidden = mask.unsqueeze(-1).to(torch.bool)  # [1, num_patches, 1]
        # Hidden patches are zeroed in *normalized* space, so after
        # denormalization they render as the ImageNet mean color.
        img_masked = model.unpatchify(img_patches.masked_fill(hidden, 0))

    _plot_and_save(
        [_to_display_image(x), _to_display_image(img_masked), _to_display_image(img_rec)],
        ['Original', 'Masked Input', 'Reconstruction'],
        figure_name,
    )


if __name__ == '__main__':
    visualize('MAE1.bin', img_path='guineapig.jpg',
              figure_name='figures/MAE_visualization1.png')