EmpathNetMAE / ModelVisualizer.py
prekshyam's picture
Added the Visualizer, Transformer, MAE, and a test image
4283e0c verified
raw
history blame
2.71 kB
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from maevit import MAEViT
def visualize(model_path, img_path, figure_name) -> None:
    """Plot an image, its masked version, and the MAE reconstruction.

    Builds an ``MAEViT`` with the hyperparameters hard-coded below, loads the
    state dict stored at ``model_path``, runs the encoder and decoder on the
    image at ``img_path``, saves a three-panel figure (original, masked input,
    reconstruction) to ``figure_name`` and then displays it.

    Args:
        model_path: Path to a checkpoint containing the model's state dict.
        img_path: Path to the input image. It is converted to RGB and resized
            to 224x224 before being fed to the model.
        figure_name: Path the figure is written to. Missing parent
            directories are created.
    """
    # Channel statistics shared by the forward and the inverse normalization,
    # so the two transforms cannot drift apart.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    # NOTE(review): presumably these hyperparameters mirror the configuration
    # the checkpoint was trained with -- confirm against the training script.
    model = MAEViT(
        image_size=224,
        patch_size=16,
        embed_dim=128,
        encoder_layers=2,
        encoder_heads=4,
        mlp_ratio=2.0,
        mask_ratio=0.75,
        decoder_embed_dim=64,
        decoder_layers=2,
        decoder_heads=4,
        dropout=0.1,
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()

    to_tensor = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])
    # Open inside a context manager so the underlying file handle is released;
    # convert() returns an independent, fully loaded image.
    with Image.open(img_path) as img_file:
        img = img_file.convert('RGB')
    x = to_tensor(img).unsqueeze(0).to(device)  # [1,3,224,224]

    with torch.no_grad():
        x_enc, mask, ids_restore = model.forward_encoder(x)
        x_rec_patches = model.forward_decoder(x_enc, ids_restore)
        # Drop the first token (CLS) before folding patches back into an image.
        img_rec = model.unpatchify(x_rec_patches[:, 1:, :])  # [1,3,224,224]

        img_patches = model.patchify(x)  # [1, num_patches, patch_dim]
        # Add a trailing axis so the per-patch flag broadcasts over patch_dim.
        patch_mask = mask.unsqueeze(-1).to(torch.bool)  # [1, num_patches, 1]
        # masked_fill is out-of-place, so img_patches itself is left intact.
        # Zero is applied in normalized space, so flagged patches are rendered
        # in the per-channel mean colour after de-normalization, not black.
        masked_patches = img_patches.masked_fill(patch_mask, 0)
        img_masked = model.unpatchify(masked_patches)  # [1,3,224,224]

    # Inverse of Normalize(mean, std): (x - (-m/s)) / (1/s) == x * s + m.
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    )

    def to_img(tensor):
        """Turn a normalized [1,3,H,W] tensor into an HxWx3 array in [0, 1]."""
        chw = inv_normalize(tensor.squeeze(0).cpu())
        return chw.permute(1, 2, 0).clamp(0, 1).numpy()

    panels = [
        ('Original', to_img(x)),
        ('Masked Input', to_img(img_masked)),
        ('Reconstruction', to_img(img_rec)),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (title, im) in zip(axes, panels):
        ax.imshow(im)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()

    # Save BEFORE showing: once plt.show() returns the figure may already be
    # closed, and a savefig issued afterwards writes an empty canvas.
    Path(figure_name).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(figure_name)
    plt.show()
# Render the demo figure for the bundled checkpoint and test image.
visualize(
    model_path='MAE1.bin',
    img_path='guineapig.jpg',
    figure_name='figures/MAE_visualization1.png',
)