from functools import partial

import numpy as np
import torch
from diffusers import AutoencoderKL
from PIL import Image
from torchvision import transforms


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Repeatedly halves the image with a BOX filter while its short side is at
    least ``2 * image_size``. It then resizes with BICUBIC so the short side
    equals ``image_size`` and crops the central ``image_size x image_size``
    square.

    Args:
        pil_image: Input PIL image.
        image_size: Side length (pixels) of the square output.

    Returns:
        PIL.Image.Image: The center-cropped square image.
    """
    # Halve with a BOX filter while the image is still at least 2x the target.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # Final bicubic resize so the short side matches image_size.
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    # PIL .size is (W, H); the numpy array is indexed (H, W[, C]).
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])


def _to_rgb(pil_image):
    """Return *pil_image* in RGB mode, converting only when it is not already RGB.

    Grayscale / RGBA / palette images would otherwise reach the 3-channel
    ``Normalize`` in ``img_transform`` with the wrong channel count. Defined at
    module level (not as a lambda) so the transform pipeline stays picklable.
    """
    return pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")


class Diffusers_AutoencoderKL(AutoencoderKL):
    """diffusers ``AutoencoderKL`` with image pre/post-processing helpers.

    Args:
        img_size: Default square resolution used by ``img_transform``.
        *args, **kwargs: Forwarded unchanged to ``AutoencoderKL.__init__``.
    """

    def __init__(self, img_size=256, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.img_size = img_size

    def img_transform(self, p_hflip=0, img_size=None):
        """Image preprocessing transforms

        Pipeline: convert to RGB if needed -> ADM center crop -> random
        horizontal flip -> ToTensor ([0, 1]) -> Normalize to [-1, 1].

        Args:
            p_hflip: Probability of horizontal flip
            img_size: Target image size, use default if None

        Returns:
            transforms.Compose: Image transform pipeline
        """
        img_size = img_size if img_size is not None else self.img_size
        img_transforms = [
            transforms.Lambda(_to_rgb),
            # partial over a module-level function (rather than a local lambda)
            # keeps the pipeline picklable for spawn-based DataLoader workers.
            transforms.Lambda(partial(center_crop_arr, image_size=img_size)),
            transforms.RandomHorizontalFlip(p=p_hflip),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
        return transforms.Compose(img_transforms)

    def encode_images(self, images):
        """Encode images to latent representations

        Args:
            images: Input image tensor. Presumably (N, 3, H, W) in [-1, 1] as
                produced by ``img_transform`` -- TODO confirm with callers.

        Returns:
            torch.Tensor: ``mode()`` of the encoder posterior, on the model's
            device. NOTE: no ``scaling_factor`` is applied here.
        """
        # Follow the model's own device/dtype instead of hard-coding `.cuda()`:
        # works on CPU, on non-default GPUs, and with fp16/bf16 weights.
        images = images.to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            posterior = self.encode(images, return_dict=False)[0]
            return posterior.mode()

    def decode_to_images(self, z):
        """Decode latent representations to images

        Args:
            z: Latent representation tensor

        Returns:
            np.ndarray: Decoded uint8 images, channels-last (N, H, W, C), on CPU.
        """
        # Same device/dtype handling as encode_images (no hard-coded `.cuda()`).
        z = z.to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            images = self.decode(z, return_dict=False)[0]
            # Assumes decoder output is roughly in [-1, 1] (inverse of the
            # Normalize in img_transform). 127.5 * x + 128.0 is
            # (x + 1) * 127.5 + 0.5; the extra 0.5 makes the truncating uint8
            # cast round to nearest.
            images = torch.clamp(127.5 * images + 128.0, 0, 255)
            images = images.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
        return images