Spaces:
Sleeping
Sleeping
"""CelebA-HQ dataset loader.

Reads pre-cropped 256x256 images (JPEG or PNG) from a flat directory, resizes
them to the target stage resolution, applies horizontal-flip augmentation, and
normalizes to [-1, 1] (the convention diffusion models work in).
"""
| from __future__ import annotations | |
| import glob | |
| import os | |
| from typing import List, Optional | |
| import torch | |
| from PIL import Image | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
class CelebAHQ(Dataset):
    """Flat-directory image dataset that yields normalized RGB tensors.

    Every regular file directly inside ``root`` whose extension matches
    ``EXTS`` (compared case-insensitively) is one sample. Samples are ordered
    by sorted path, so ``limit`` always selects the same subset.

    Args:
        root: Directory that directly contains the image files.
        image_size: Side length, in pixels, of the square output.
        augment: If true, apply a random horizontal flip (p=0.5) on each access.
        limit: If given, keep only the first ``limit`` files in sorted order.

    Raises:
        FileNotFoundError: If ``root`` is not an existing directory.
        RuntimeError: If ``root`` contains no matching image files.
    """

    EXTS = (".jpg", ".jpeg", ".png")

    def __init__(self, root: str, image_size: int, augment: bool = True,
                 limit: Optional[int] = None):
        self.root = root
        if not os.path.isdir(root):
            raise FileNotFoundError(f"data dir not found: {root}")

        files = self._list_images(root)
        if not files:
            raise RuntimeError(f"no images found in {root}")
        if limit is not None:
            files = files[:limit]
        self.files = files
        self.image_size = image_size

        ops = []
        if augment:
            ops.append(transforms.RandomHorizontalFlip(p=0.5))
        # Resize scales the shorter edge to image_size (bilinear by default,
        # antialiased); the center crop then makes the result square.
        ops.append(transforms.Resize(image_size, antialias=True))
        ops.append(transforms.CenterCrop(image_size))
        ops.append(transforms.ToTensor())  # [0, 1]
        ops.append(transforms.Normalize([0.5] * 3, [0.5] * 3))  # [-1, 1]
        self.transform = transforms.Compose(ops)

    @classmethod
    def _list_images(cls, root: str) -> List[str]:
        """Return the sorted paths of the image files directly inside ``root``.

        The directory is listed directly instead of globbed, so glob
        metacharacters in ``root`` (``[``, ``*``, ``?``) are taken literally,
        and extensions match regardless of case (``.jpg``, ``.JPG``, ``.Jpg``).
        """
        exts = tuple(ext.lower() for ext in cls.EXTS)
        with os.scandir(root) as entries:
            return sorted(
                entry.path
                for entry in entries
                # Skip dotfiles (e.g. "._*" sidecar files), which a "*" glob
                # pattern does not match either.
                if not entry.name.startswith(".")
                and entry.name.lower().endswith(exts)
                and entry.is_file()
            )

    def __len__(self) -> int:
        """Return the number of image files in the dataset."""
        return len(self.files)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load the image at ``idx`` as RGB and return it transformed."""
        path = self.files[idx]
        # convert() returns a fully loaded copy, so the file handle can be
        # released before the transform pipeline runs.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)
def make_dataloader(
    root: str,
    image_size: int,
    batch_size: int,
    num_workers: int = 4,
    augment: bool = True,
    shuffle: bool = True,
    limit: Optional[int] = None,
    pin_memory: bool = False,
    drop_last: bool = True,
) -> DataLoader:
    """Build a ``DataLoader`` over a ``CelebAHQ`` dataset.

    Args:
        root: Directory containing the image files.
        image_size: Side length, in pixels, of the square output images.
        batch_size: Number of images per batch.
        num_workers: Number of loader worker processes; 0 loads in-process.
        augment: Forwarded to ``CelebAHQ`` (random horizontal flip).
        shuffle: Whether to reshuffle the samples every epoch.
        limit: Forwarded to ``CelebAHQ``; keeps only the first ``limit`` files.
        pin_memory: Forwarded to ``DataLoader``.
        drop_last: Drop the final incomplete batch. Defaults to True so every
            batch has exactly ``batch_size`` images; pass False (e.g. for
            evaluation) to see every sample. With True, a dataset smaller
            than ``batch_size`` yields no batches at all.

    Returns:
        The configured ``DataLoader``.
    """
    dataset = CelebAHQ(root=root, image_size=image_size, augment=augment, limit=limit)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        # DataLoader rejects persistent_workers=True when there are no workers.
        persistent_workers=num_workers > 0,
    )
def denormalize(x: torch.Tensor) -> torch.Tensor:
    """Convert values from the model range [-1, 1] to the display range [0, 1].

    Values outside [-1, 1] are clamped first, so the result always lies in
    [0, 1]. A new tensor is returned; ``x`` itself is left unchanged.
    """
    clipped = torch.clamp(x, min=-1.0, max=1.0)
    return (clipped + 1.0) / 2.0
| # --------------------------------------------------------------------------- | |
| # Self-test | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    import sys

    # The default is a hard-coded absolute path; pass a directory as the first
    # command-line argument to run the self-test against a different copy.
    DATA_DIR = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/Volumes/Projects/DDIM_image_Generation/celeba_hq_256"
    )

    # Single-sample check: shape and value range at 64x64.
    ds = CelebAHQ(DATA_DIR, image_size=64, augment=True, limit=8)
    assert len(ds) == 8
    x = ds[0]
    assert x.shape == (3, 64, 64), x.shape
    assert -1.0 <= x.min().item() <= x.max().item() <= 1.0

    # Batch check; num_workers=0 keeps loading in the main process.
    loader = make_dataloader(DATA_DIR, image_size=64, batch_size=4,
                             num_workers=0, limit=8)
    batch = next(iter(loader))
    assert batch.shape == (4, 3, 64, 64), batch.shape
    print(f"dataset ok: {len(CelebAHQ(DATA_DIR, image_size=64, augment=False))} images total")

    # test a 256 sample as well
    ds256 = CelebAHQ(DATA_DIR, image_size=256, augment=False, limit=2)
    assert ds256[0].shape == (3, 256, 256)
    print("dataset.py: all tests passed")