| | --- |
| | license: mit |
| | --- |
| | |
| | Custom hand-made 3-scale VQVAE trained on a private dataset of about 4k pixel-art images. |
| | Source code for model can be found [here](https://github.com/Kemsekov/kemsekov_torch/tree/main/vqvae). |
| |
|
| |
|
| | It achieved an R² score of 0.987 on image reconstruction after 500 epochs of training on 256x256 image crops. |
| |
|
| | Because I used crops, this model works fine with larger and smaller images as well. |
| |
|
| | The model has the following codebook sizes: |
| | * 512 bottom |
| | * 512 mid |
| | * 256 top |
| |
|
| | This provides enough space for the model to achieve good metrics. |
| |
|
| | Here is a code example showing how to use it. |
| |
|
| |
|
| | ```py |
| | import random |
| | import PIL.Image |
| | from matplotlib import pyplot as plt |
| | import torch |
| | import torchvision.transforms as T |
| | |
| | sample = PIL.Image.open("image.png") # you sample image |
| | sample = T.ToTensor()(sample)[None,:] # add batch dimension |
| | sample = T.RandomCrop((256,256))(sample) # this vqvae works fine with any input image size that is divisible by 8 |
| | |
| | vqvae=torch.jit.load("model_v3.pt") |
| | |
| | # rec, rec_ind is reconstructions |
| | # rec is reconstruction from latent space values z |
| | # rec_ind is reconstruction from model predicted vector indices |
| | # z latent space tensor with 64 channels and 4x smaller than input image |
| | # z_layers is list of latent space tensors at different scales |
| | # z_q_layers is quantized list of latent space tensors |
| | # ind is list of encoded indices of quantized elements in latent space for each scale |
| | |
| | z, z_layers,z_q_layers, ind = vqvae.encode(sample) |
| | rec_ind = vqvae.decode_from_ind(ind).sigmoid() |
| | rec = vqvae.decode(z).sigmoid() |
| | |
| | print("Original image shape",list(sample.shape[1:])) |
| | print("ind shapes",[list(v.shape[1:]) for v in ind]) |
| | |
| | plt.figure(figsize=(18,6)) |
| | plt.subplot(1,3,1) |
| | plt.imshow(T.ToPILImage()(sample[0]).resize((256,256))) |
| | plt.title("original") |
| | plt.axis('off') |
| | |
| | # these two must look the same |
| | plt.subplot(1,3,2) |
| | plt.imshow(T.ToPILImage()(rec[0]).resize((256,256))) |
| | plt.title("reconstruction") |
| | plt.axis('off') |
| | |
| | |
| | plt.subplot(1,3,3) |
| | plt.imshow(T.ToPILImage()(rec_ind[0]).resize((256,256))) |
| | plt.title("reconstruction from ind") |
| | plt.axis('off') |
| | plt.show() |
| | |
| | # this must look like a pile of mess |
| | plt.figure(figsize=(18,6)) |
| | plt.subplot(1,3,1) |
| | plt.imshow(T.ToPILImage()(ind[0]/512).resize((256,256))) |
| | plt.title("ind0") |
| | plt.axis('off') |
| | |
| | plt.subplot(1,3,2) |
| | plt.imshow(T.ToPILImage()(ind[1]/512).resize((256,256))) |
| | plt.title("ind1") |
| | plt.axis('off') |
| | |
| | plt.subplot(1,3,3) |
| | plt.imshow(T.ToPILImage()(ind[2]/256).resize((256,256))) |
| | plt.title("ind2") |
| | plt.axis('off') |
| | plt.show() |
| | |
| | print("latent space render") |
| | for z_ in z_layers: |
| | dims = len(z_[0]) |
| | dims_sqrt = int(dims**0.5) |
| | plt.figure(figsize=(10,10)) |
| | plt.axis('off') |
| | for i in range(dims_sqrt): |
| | for j in range(dims_sqrt): |
| | slice_ind = i*dims_sqrt+j |
| | slice_ind_end = slice_ind+1 |
| | plt.subplot(dims_sqrt,dims_sqrt,slice_ind+1) |
| | plt.imshow(T.ToPILImage()(z_[0][slice_ind:slice_ind_end])) |
| | plt.axis('off') |
| | plt.show() |
| | ``` |
| |
|
| | ``` |
| | Original image shape [3, 256, 256] |
| | ind shapes [[64, 64], [32, 32], [16, 16]] |
| | ``` |
| |
|
| | Here are some examples at 256x256 resolution: |
| |  |
| |  |
| |  |
| |  |
| |
|