import numpy as np
from PIL import Image


def center_crop_arr(pil_image, image_size):
    """Resize a PIL image so its short side equals ``image_size``, then center-crop.

    The image is first repeatedly halved with a BOX filter while its short side
    is at least twice ``image_size``. It is then resized with BICUBIC so the
    short side matches ``image_size``. Finally, the central
    ``image_size x image_size`` region is cut out of the array.

    Args:
        pil_image: A ``PIL.Image.Image``.
        image_size: Target side length in pixels; must be positive.

    Returns:
        A numpy array from ``np.array(pil_image)``, cropped to ``image_size``
        along its first two axes. Any trailing channel axis is kept as-is.

    Raises:
        ValueError: If ``image_size`` is not positive. Without this check the
            halving loop below never stops shrinking the image.
    """
    if image_size <= 0:
        raise ValueError(f"image_size must be positive, got {image_size!r}")

    # Cheap BOX-filter halving first; a single large BICUBIC downscale would
    # be slower.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # Scale so the shorter side becomes exactly image_size (rounded).
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]


class DatasetFactory:
    """Base container for a train/test dataset pair.

    Subclasses are expected to assign ``self.train`` / ``self.test`` and to
    override ``data_shape`` (and ``fid_stat`` where applicable).
    """

    def __init__(self):
        self.train = None
        self.test = None

    def get_split(self, split, labeled=False):
        """Return the dataset stored for ``split``.

        Args:
            split: Either ``"train"`` or ``"test"``.
            labeled: Not used by this base implementation; presumably kept so
                subclasses can share the signature (TODO confirm).

        Returns:
            ``self.train`` or ``self.test``.

        Raises:
            ValueError: If ``split`` is not ``"train"`` or ``"test"``.
        """
        if split == "train":
            dataset = self.train
        elif split == "test":
            dataset = self.test
        else:
            raise ValueError(
                f"unknown split {split!r}; expected 'train' or 'test'"
            )
        return dataset

    def unpreprocess(self, v):
        """Map values from [-1, 1] to [0, 1] and clamp to that range.

        No reshaping happens here; ``v`` keeps its layout (assumed B C H W —
        TODO confirm with callers). ``clamp_`` acts on the new tensor created
        by the arithmetic, so the caller's ``v`` is not modified.
        """
        v = 0.5 * (v + 1.)
        v.clamp_(0., 1.)
        return v

    @property
    def data_shape(self):
        """Shape of a single data sample; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def fid_stat(self):
        """Path/reference to FID statistics, or ``None`` if not available."""
        return None