| import PIL | |
| import torch | |
| import requests | |
| import torchvision | |
| from math import ceil | |
| from io import BytesIO | |
| import torchvision.transforms.functional as F | |
| def download_image(url): | |
| return PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB") | |
| def resize_image(image, size=768): | |
| tensor_image = F.to_tensor(image) | |
| resized_image = F.resize(tensor_image, size, antialias=True) | |
| return resized_image | |
| def downscale_images(images, factor=3/4): | |
| scaled_height, scaled_width = int(((images.size(-2)*factor)//32)*32), int(((images.size(-1)*factor)//32)*32) | |
| scaled_image = torchvision.transforms.functional.resize(images, (scaled_height, scaled_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST) | |
| return scaled_image | |
| def show_images(images, rows=None, cols=None, return_images=False, **kwargs): | |
| if images.size(1) == 1: | |
| images = images.repeat(1, 3, 1, 1) | |
| elif images.size(1) > 3: | |
| images = images[:, :3] | |
| if rows is None: | |
| rows = 1 | |
| if cols is None: | |
| cols = images.size(0) // rows | |
| _, _, h, w = images.shape | |
| grid = PIL.Image.new('RGB', size=(cols * w, rows * h)) | |
| for i, img in enumerate(images): | |
| img = torchvision.transforms.functional.to_pil_image(img.clamp(0, 1)) | |
| grid.paste(img, box=(i % cols * w, i // cols * h)) | |
| if return_images: | |
| return grid | |
| def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0): | |
| resolution_multiple = 42.67 | |
| latent_height = ceil(height / compression_factor_b) | |
| latent_width = ceil(width / compression_factor_b) | |
| stage_c_latent_shape = (batch_size, 16, latent_height, latent_width) | |
| latent_height = ceil(height / compression_factor_a) | |
| latent_width = ceil(width / compression_factor_a) | |
| stage_b_latent_shape = (batch_size, 4, latent_height, latent_width) | |
| return stage_c_latent_shape, stage_b_latent_shape | |