| import torch |
| from diffusers import AutoencoderKL |
| from PIL import Image |
| from torchvision import transforms |
| import numpy as np |
|
|
def center_crop_arr(pil_image, image_size):
    """Resize a PIL image and cut out a centered ``image_size`` square.

    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Args:
        pil_image: Source PIL image.
        image_size: Side length, in pixels, of the square crop to return.
    Returns:
        PIL.Image.Image: The centered square crop.
    """
    # Step 1: halve repeatedly with a box filter while the shorter side is
    # still at least twice the target size.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Step 2: one bicubic resize so the shorter side lands on image_size.
    ratio = image_size / min(*pil_image.size)
    rescaled = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(rescaled, resample=Image.BICUBIC)

    # Step 3: slice the centered window out of the pixel array.
    pixels = np.array(pil_image)
    rows, cols = pixels.shape[0], pixels.shape[1]
    top = (rows - image_size) // 2
    left = (cols - image_size) // 2
    window = pixels[top: top + image_size, left: left + image_size]
    return Image.fromarray(window)
|
|
class Diffusers_AutoencoderKL(AutoencoderKL):
    """``AutoencoderKL`` with image preprocessing and encode/decode helpers.

    Adds a default image size, a torchvision preprocessing pipeline, and
    gradient-free helpers that map images to latents and latents back to
    uint8 image arrays.
    """

    def __init__(self, img_size=256, *args, **kwargs):
        """Build the autoencoder and remember the default image size.

        Args:
            img_size: Default side length used by ``img_transform``.
            *args: Forwarded unchanged to ``AutoencoderKL.__init__``.
            **kwargs: Forwarded unchanged to ``AutoencoderKL.__init__``.
        """
        # NOTE(review): img_size comes first, so positional arguments meant
        # for AutoencoderKL shift by one -- confirm callers pass them by keyword.
        super().__init__(*args, **kwargs)
        self.img_size = img_size

    def img_transform(self, p_hflip=0, img_size=None):
        """Image preprocessing transforms
        Args:
            p_hflip: Probability of horizontal flip
            img_size: Target image size, use default if None
        Returns:
            transforms.Compose: Image transform pipeline
        """
        img_size = img_size if img_size is not None else self.img_size
        img_transforms = [
            # Resize + center crop to a square of side img_size.
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, img_size)),
            transforms.RandomHorizontalFlip(p=p_hflip),
            transforms.ToTensor(),
            # Three-channel normalization with mean 0.5 / std 0.5.
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
        return transforms.Compose(img_transforms)

    def encode_images(self, images):
        """Encode images to latent representations
        Args:
            images: Input image tensor
        Returns:
            torch.Tensor: Mode of the encoded posterior distribution
        """
        with torch.no_grad():
            # Move inputs to wherever the model's weights live rather than
            # forcing the current CUDA device, so CPU and non-default-GPU
            # placements work as well.
            posterior = self.encode(images.to(self.device), return_dict=False)[0]
            return posterior.mode()

    def decode_to_images(self, z):
        """Decode latent representations to images
        Args:
            z: Latent representation tensor
        Returns:
            np.ndarray: Decoded uint8 image array with values in [0, 255],
                with axes permuted by (0, 2, 3, 1) relative to the decoder
                output (presumably channels-last -- confirm decoder layout)
        """
        with torch.no_grad():
            # Follow the model's device instead of hard-coding CUDA.
            images = self.decode(z.to(self.device), return_dict=False)[0]
            # Rescale to pixel range, clamp, then cast to uint8 on the CPU.
            images = (
                torch.clamp(127.5 * images + 128.0, 0, 255)
                .permute(0, 2, 3, 1)
                .to("cpu", dtype=torch.uint8)
                .numpy()
            )
            return images