File size: 1,447 Bytes
ee3e701 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
# Preprocessing pipeline applied before encoding an image:
# 1) resize the shorter side to 256 px (bilinear interpolation),
# 2) center-crop to a 256x256 square,
# 3) convert the PIL image to a float tensor with values in [0, 1].
encode_transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])
def convert_decode_to_pil(rec_image):
    """Convert a batch of decoded image tensors into a list of PIL images.

    Args:
        rec_image: float tensor of shape (N, C, H, W); values are expected
            to be (approximately) in [0, 1] — assumes this layout, TODO confirm
            against the decoder's output.

    Returns:
        List of ``PIL.Image``, one per batch element.
    """
    # Map [0, 1] -> [-1, 1], clamp anything out of range, then map back.
    # Net effect: values are clamped to [0, 1] before the uint8 conversion.
    scaled = 2.0 * rec_image - 1.0
    scaled = torch.clamp(scaled, -1.0, 1.0)
    scaled = (scaled + 1.0) / 2.0
    scaled = scaled * 255.0
    # NCHW -> NHWC, detach from autograd, move to CPU, cast for PIL.
    arrays = scaled.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
    return [Image.fromarray(frame) for frame in arrays]
def patchify(imgs, p):
    """Split square images into flattened non-overlapping patches.

    Args:
        imgs: tensor of shape (N, C, H, W); requires H == W and H % p == 0.
        p: patch side length in pixels.

    Returns:
        Tensor of shape (N, L, p*p*C) where L = (H // p) ** 2; each row is
        one patch flattened in (row, col, channel) order.
    """
    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
    batch = imgs.shape[0]
    channels = imgs.shape[1]
    grid = imgs.shape[2] // p  # number of patches along each spatial axis
    # (N, C, grid, p, grid, p): split H and W into (patch index, offset) pairs.
    tiles = imgs.reshape(shape=(batch, channels, grid, p, grid, p))
    # Reorder so each patch's pixels (and channels) are contiguous at the end.
    tiles = torch.einsum('nchpwq->nhwpqc', tiles)
    return tiles.reshape(shape=(batch, grid * grid, p ** 2 * channels))
def unpatchify(x, p):
    """Reassemble flattened patches back into images (inverse of patchify).

    Args:
        x: tensor of shape (N, L, p*p*C); L must be a perfect square.
        p: patch side length in pixels.

    Returns:
        imgs: tensor of shape (N, C, H, W) with H = W = sqrt(L) * p.
    """
    h = w = int(x.shape[1] ** .5)
    assert h * w == x.shape[1]
    x = x.reshape(shape=(x.shape[0], h, w, p, p, -1))
    x = torch.einsum('nhwpqc->nchpwq', x)
    # Fix: width must be w * p (the original used h * p for both dimensions,
    # which only worked because h == w is forced above).
    imgs = x.reshape(shape=(x.shape[0], -1, h * p, w * p))
    return imgs