| """ | |
| CelebFaces Attributes (CelebA) Dataset | |
| https://www.kaggle.com/datasets/jessicali9530/celeba-dataset | |
| """ | |
| import os | |
| import torch | |
| from PIL import Image | |
| from torch.utils.data import DataLoader, Dataset | |
| from torchvision import transforms | |
class CelebADataset(Dataset):
    """Unlabelled CelebA image dataset yielding preprocessed image tensors.

    Every entry returned by ``os.listdir(root)`` is treated as an image file.
    Each image is converted to RGB, center-cropped to a ``crop_size`` square,
    resized to ``img_shape`` and converted to a tensor with
    ``transforms.ToTensor``.

    Args:
        root: Directory containing the image files.
        img_shape: Target size passed to ``transforms.Resize``.
        crop_size: Size passed to ``transforms.CenterCrop`` before resizing.
            Defaults to 168, the value previously hard-coded here.
    """

    def __init__(self, root, img_shape=(64, 64), crop_size=168) -> None:
        super().__init__()
        self.root = root
        self.img_shape = img_shape
        self.crop_size = crop_size
        # Sorted so the index -> file mapping is deterministic.
        self.filenames = sorted(os.listdir(root))
        # Built once here rather than on every __getitem__ call.
        self.transform = transforms.Compose([
            transforms.CenterCrop(crop_size),
            transforms.Resize(img_shape),
            transforms.ToTensor(),
        ])

    def __len__(self) -> int:
        """Return the number of entries found in ``root``."""
        return len(self.filenames)

    def __getitem__(self, index: int) -> torch.Tensor:
        """Load the file at ``index`` and return it as a preprocessed tensor."""
        path = os.path.join(self.root, self.filenames[index])
        # The context manager closes the file handle deterministically;
        # convert() loads the pixels and returns a separate image object,
        # so it remains usable after the file is closed.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)
def get_dataloader(root='data/celebA/img_align_celeba', *, batch_size=16,
                   shuffle=True, num_workers=0, **kwargs):
    """Build a ``DataLoader`` over a ``CelebADataset`` rooted at ``root``.

    Args:
        root: Directory containing the CelebA image files.
        batch_size: Number of samples per batch (previously fixed at 16).
        shuffle: Whether the loader shuffles the data (previously always True).
        num_workers: Number of loader worker processes; 0 (the ``DataLoader``
            default) loads data in the main process.
        **kwargs: Forwarded unchanged to ``CelebADataset``
            (e.g. ``img_shape``).

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset.
    """
    dataset = CelebADataset(root, **kwargs)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )