| import os |
| import numpy as np |
| import albumentations |
| from torch.utils.data import Dataset |
|
|
| from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex |
|
|
|
|
class FacesBase(Dataset):
    """Base dataset that serves items from ``self.data``.

    Subclasses are expected to assign ``self.data`` (an indexable, sized
    collection of example mappings) and ``self.keys`` (an iterable of keys to
    keep, or ``None`` to return examples untouched).
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Both are filled in by concrete subclasses.
        self.data = None
        self.keys = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        """Return example ``i``, restricted to ``self.keys`` when they are set.

        With ``self.keys is None`` the underlying example object itself is
        returned (no copy); otherwise a new dict holding only the requested
        keys, in the order given by ``self.keys``.
        """
        item = self.data[i]
        if self.keys is None:
            return item
        return {key: item[key] for key in self.keys}
|
|
|
|
class CelebAHQTrain(FacesBase):
    """CelebA-HQ training split, served through ``NumpyPaths``.

    Relative file paths are read from ``data/celebahqtrain.txt`` (one per
    line) and resolved against ``data/celebahq``.

    Args:
        size: forwarded unchanged as ``size`` to ``NumpyPaths``.
        keys: iterable of example keys to keep, or ``None`` to keep all
            (applied by ``FacesBase.__getitem__``).
    """

    def __init__(self, size, keys=None):
        super().__init__()
        root = "data/celebahq"
        # Explicit encoding: the default text encoding depends on the locale.
        with open("data/celebahqtrain.txt", "r", encoding="utf-8") as f:
            # Ignore blank lines; os.path.join(root, "") would otherwise
            # point at the root directory itself and fail when loaded.
            relpaths = [line for line in f.read().splitlines() if line.strip()]
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
        self.keys = keys
|
|
|
|
class CelebAHQValidation(FacesBase):
    """CelebA-HQ validation split, served through ``NumpyPaths``.

    Relative file paths are read from ``data/celebahqvalidation.txt`` (one
    per line) and resolved against ``data/celebahq``.

    Args:
        size: forwarded unchanged as ``size`` to ``NumpyPaths``.
        keys: iterable of example keys to keep, or ``None`` to keep all
            (applied by ``FacesBase.__getitem__``).
    """

    def __init__(self, size, keys=None):
        super().__init__()
        root = "data/celebahq"
        # Explicit encoding: the default text encoding depends on the locale.
        with open("data/celebahqvalidation.txt", "r", encoding="utf-8") as f:
            # Ignore blank lines; os.path.join(root, "") would otherwise
            # point at the root directory itself and fail when loaded.
            relpaths = [line for line in f.read().splitlines() if line.strip()]
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
        self.keys = keys
|
|
|
|
class FFHQTrain(FacesBase):
    """FFHQ training split, served through ``ImagePaths``.

    Relative file paths are read from ``data/ffhqtrain.txt`` (one per line)
    and resolved against ``data/ffhq``.

    Args:
        size: forwarded unchanged as ``size`` to ``ImagePaths``.
        keys: iterable of example keys to keep, or ``None`` to keep all
            (applied by ``FacesBase.__getitem__``).
    """

    def __init__(self, size, keys=None):
        super().__init__()
        root = "data/ffhq"
        # Explicit encoding: the default text encoding depends on the locale.
        with open("data/ffhqtrain.txt", "r", encoding="utf-8") as f:
            # Ignore blank lines; os.path.join(root, "") would otherwise
            # point at the root directory itself and fail when loaded.
            relpaths = [line for line in f.read().splitlines() if line.strip()]
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
        self.keys = keys
|
|
|
|
class FFHQValidation(FacesBase):
    """FFHQ validation split, served through ``ImagePaths``.

    Relative file paths are read from ``data/ffhqvalidation.txt`` (one per
    line) and resolved against ``data/ffhq``.

    Args:
        size: forwarded unchanged as ``size`` to ``ImagePaths``.
        keys: iterable of example keys to keep, or ``None`` to keep all
            (applied by ``FacesBase.__getitem__``).
    """

    def __init__(self, size, keys=None):
        super().__init__()
        root = "data/ffhq"
        # Explicit encoding: the default text encoding depends on the locale.
        with open("data/ffhqvalidation.txt", "r", encoding="utf-8") as f:
            # Ignore blank lines; os.path.join(root, "") would otherwise
            # point at the root directory itself and fail when loaded.
            relpaths = [line for line in f.read().splitlines() if line.strip()]
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
        self.keys = keys
|
|
|
|
class FacesHQTrain(Dataset):
    """Training set concatenating CelebA-HQ and FFHQ, with optional random crop.

    Each item is the example dict produced by the underlying face dataset,
    plus a ``"class"`` entry holding the second value returned by
    ``ConcatDatasetWithIndex`` (presumably the index of the source dataset,
    0 = CelebA-HQ, 1 = FFHQ -- confirm against ``taming.data.base``).

    Args:
        size: forwarded to both face datasets.
        keys: forwarded to both face datasets.
        crop_size: when given, ``"image"`` is randomly cropped to a
            ``crop_size`` x ``crop_size`` window.
        coord: when True (and ``crop_size`` is given), also emit ``"coord"``,
            a per-pixel map of the row-major linear index divided by
            ``h * w`` (values in [0, 1)), cropped with the same window.
    """

    def __init__(self, size, keys=None, crop_size=None, coord=False):
        celebahq = CelebAHQTrain(size=size, keys=keys)
        ffhq = FFHQTrain(size=size, keys=keys)
        self.data = ConcatDatasetWithIndex([celebahq, ffhq])
        self.coord = coord
        if crop_size is None:
            # No cropper attribute at all: __getitem__ checks with hasattr.
            return
        crop = albumentations.RandomCrop(height=crop_size, width=crop_size)
        if coord:
            # Treat "coord" as an extra image target so it gets the very same
            # crop window as "image".
            crop = albumentations.Compose(
                [crop], additional_targets={"coord": "image"}
            )
        self.cropper = crop

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example, source = self.data[i]
        if hasattr(self, "cropper"):
            self._crop(example)
        example["class"] = source
        return example

    def _crop(self, example):
        """Crop ``example["image"]`` in place; add a cropped ``"coord"`` map if enabled."""
        image = example["image"]
        if not self.coord:
            example["image"] = self.cropper(image=image)["image"]
            return
        height, width, _ = image.shape
        n_pixels = height * width
        grid = np.arange(n_pixels).reshape(height, width, 1) / n_pixels
        cropped = self.cropper(image=image, coord=grid)
        example["image"] = cropped["image"]
        example["coord"] = cropped["coord"]
|
|
|
|
class FacesHQValidation(Dataset):
    """Validation set concatenating CelebA-HQ and FFHQ, with optional center crop.

    Each item is the example dict produced by the underlying face dataset,
    plus a ``"class"`` entry holding the second value returned by
    ``ConcatDatasetWithIndex`` (presumably the index of the source dataset,
    0 = CelebA-HQ, 1 = FFHQ -- confirm against ``taming.data.base``).

    Args:
        size: forwarded to both face datasets.
        keys: forwarded to both face datasets.
        crop_size: when given, ``"image"`` is center-cropped to a
            ``crop_size`` x ``crop_size`` window.
        coord: when True (and ``crop_size`` is given), also emit ``"coord"``,
            a per-pixel map of the row-major linear index divided by
            ``h * w`` (values in [0, 1)), cropped with the same window.
    """

    def __init__(self, size, keys=None, crop_size=None, coord=False):
        celebahq = CelebAHQValidation(size=size, keys=keys)
        ffhq = FFHQValidation(size=size, keys=keys)
        self.data = ConcatDatasetWithIndex([celebahq, ffhq])
        self.coord = coord
        if crop_size is None:
            # No cropper attribute at all: __getitem__ checks with hasattr.
            return
        crop = albumentations.CenterCrop(height=crop_size, width=crop_size)
        if coord:
            # Treat "coord" as an extra image target so it gets the very same
            # crop window as "image".
            crop = albumentations.Compose(
                [crop], additional_targets={"coord": "image"}
            )
        self.cropper = crop

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example, source = self.data[i]
        if hasattr(self, "cropper"):
            self._crop(example)
        example["class"] = source
        return example

    def _crop(self, example):
        """Crop ``example["image"]`` in place; add a cropped ``"coord"`` map if enabled."""
        image = example["image"]
        if not self.coord:
            example["image"] = self.cropper(image=image)["image"]
            return
        height, width, _ = image.shape
        n_pixels = height * width
        grid = np.arange(n_pixels).reshape(height, width, 1) / n_pixels
        cropped = self.cropper(image=image, coord=grid)
        example["image"] = cropped["image"]
        example["coord"] = cropped["coord"]
|
|