Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from PIL import Image | |
def center_crop_arr(pil_image, image_size):
    """Resize and center-crop a PIL image to an ``image_size`` x ``image_size`` array.

    The image is first halved repeatedly with a BOX filter while its shorter
    side is at least ``2 * image_size``. It is then resized with BICUBIC so the
    shorter side becomes ``image_size``. Finally it is converted to a NumPy
    array and center-cropped along both spatial axes.

    Args:
        pil_image: Image exposing ``.size`` and ``.resize`` (presumably a
            ``PIL.Image.Image``).
        image_size: Target side length in pixels. Must be positive.

    Returns:
        ``np.array(pil_image)`` cropped to ``image_size`` rows and columns.
        Any trailing (channel) axis is kept as-is.

    Raises:
        ValueError: If ``image_size`` is not positive.
    """
    if image_size <= 0:
        # Image sizes are never negative, so with image_size <= 0 the halving
        # loop condition below could never become false.
        raise ValueError(f"image_size must be positive, got {image_size!r}")

    # Cheap coarse reduction: halve with a box filter while still >= 2x target.
    while min(pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # Final resize so the shorter side matches image_size; round() absorbs
    # floating-point error in the scale factor.
    scale = image_size / min(pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
class DatasetFactory(object):
    """Base holder for train/test datasets plus shared helper methods.

    ``train`` and ``test`` start as ``None``; presumably subclasses assign
    them (TODO confirm against concrete factories). ``data_shape`` raises
    ``NotImplementedError`` here and must be overridden to be usable.
    """

    def __init__(self):
        # Dataset objects for each split; None until assigned.
        self.train = None
        self.test = None

    def get_split(self, split, labeled=False):
        """Return the dataset stored for ``split``.

        Args:
            split: Either ``"train"`` or ``"test"``.
            labeled: Accepted for interface compatibility; unused by this
                base implementation.

        Returns:
            ``self.train`` or ``self.test``. This may be ``None`` if the
            attribute was never assigned.

        Raises:
            ValueError: If ``split`` is neither ``"train"`` nor ``"test"``.
        """
        if split == "train":
            return self.train
        if split == "test":
            return self.test
        raise ValueError(f"unknown split {split!r}; expected 'train' or 'test'")

    def unpreprocess(self, v):
        """Map values from [-1, 1] to [0, 1] and clamp to that range.

        Computes ``0.5 * (v + 1)`` and clamps it to [0, 1]. No axis
        permutation is performed; the input's layout is returned unchanged.
        ``v`` is assumed to be a torch tensor, because ``clamp_`` is called
        on the result. The in-place clamp acts on the new tensor produced
        by the arithmetic, so the caller's ``v`` is not modified.
        """
        v = 0.5 * (v + 1.)
        v.clamp_(0., 1.)
        return v

    def data_shape(self):
        """Return the shape of a single data sample; must be overridden."""
        raise NotImplementedError

    def fid_stat(self):
        """Return FID reference statistics; ``None`` in this base class."""
        return None