| import torch |
| from models import VQVAE, build_vae_var |
| from dataset.imagenet_dataset import get_train_transforms |
| from PIL import Image |
| from torchvision import transforms |
|
|
|
|
# Compute device for model construction and checkpoint loading.
# Prefer Apple's Metal (MPS) backend, which this script originally hard-coded,
# but fall back to CUDA and then CPU so it does not crash on machines where
# MPS is unavailable. On MPS-capable machines the result is unchanged ('mps').
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Multi-scale patch configuration forwarded to build_vae_var(patch_nums=...).
# NOTE(review): presumably the per-scale token-grid sizes; confirm in models.
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
|
|
# Construct the VAE / VAR model pair in one call.
# The architecture settings must agree with the checkpoint files loaded later
# in this script; with strict=True, a mismatch would likely make
# load_state_dict raise.
vae, var = build_vae_var(
    # These three mirror the VAE checkpoint name 'vae_ch160v4096z32.pth'.
    ch=160,
    V=4096,
    Cvae=32,
    # This mirrors the VAR checkpoint name 'var_d16.pth'.
    depth=16,
    # Remaining model settings, passed through as-is.
    share_quant_resi=4,
    num_classes=1000,
    shared_aln=False,
    patch_nums=patch_nums,
    device=device,
)
# Checkpoint paths. Their names appear to encode the architecture passed to
# build_vae_var (depth 16 for VAR; ch160 / V4096 / Cvae32 for the VAE).
var_ckpt = 'var_d16.pth'
vae_ckpt = 'vae_ch160v4096z32.pth'

# torch.load unpickles its input, so an untrusted or tampered checkpoint could
# execute arbitrary code. weights_only=True restricts unpickling to tensors and
# plain containers. That is enough here, assuming these files are plain state
# dicts, since load_state_dict consumes them directly. It also matches the
# torch >= 2.6 default, so behavior does not depend on the installed version.
# strict=True makes load_state_dict raise on missing or unexpected keys.
var.load_state_dict(
    torch.load(var_ckpt, map_location=device, weights_only=True),
    strict=True,
)
vae.load_state_dict(
    torch.load(vae_ckpt, map_location=device, weights_only=True),
    strict=True,
)