File size: 3,645 Bytes
0ca4c93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""CelebA-HQ dataset loader.

Reads pre-cropped 256x256 images (.jpg/.jpeg/.png) from a flat directory,
resizes them to the target stage resolution, optionally applies horizontal
flip augmentation, and normalizes to [-1, 1] (the convention diffusion models
work in).
"""
from __future__ import annotations

import glob
import os
from typing import List, Optional

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class CelebAHQ(Dataset):
    """Image dataset over a flat directory, yielding normalized tensors.

    Every regular file directly inside ``root`` whose extension is listed in
    ``EXTS`` (compared case-insensitively) becomes one sample. Files are
    ordered by path so that indices are stable between runs.

    Each item is the image converted to RGB, optionally flipped horizontally,
    resized and center-cropped to ``image_size`` x ``image_size``, and scaled
    to the range [-1, 1] as a ``(3, image_size, image_size)`` tensor.
    """

    EXTS = (".jpg", ".jpeg", ".png")

    def __init__(self, root: str, image_size: int, augment: bool = True,
                 limit: Optional[int] = None):
        """Index the images in ``root`` and build the preprocessing pipeline.

        Args:
            root: Directory containing the image files (not searched
                recursively).
            image_size: Side length, in pixels, of the square output images.
            augment: When true, flip each image horizontally with
                probability 0.5.
            limit: If given, keep only ``files[:limit]`` of the sorted file
                list (plain slice semantics).

        Raises:
            FileNotFoundError: ``root`` is not an existing directory.
            RuntimeError: ``root`` contains no matching image files.
        """
        self.root = root
        if not os.path.isdir(root):
            raise FileNotFoundError(f"data dir not found: {root}")

        files = self._list_images(root)
        if not files:
            raise RuntimeError(f"no images found in {root}")
        if limit is not None:
            files = files[:limit]
        self.files = files
        self.image_size = image_size

        ops = []
        if augment:
            ops.append(transforms.RandomHorizontalFlip(p=0.5))
        # Resize is left on its default interpolation; antialiasing is
        # requested explicitly because the images are being downsampled.
        ops.append(transforms.Resize(image_size, antialias=True))
        # Crop to a square in case the source image is not square.
        ops.append(transforms.CenterCrop(image_size))
        ops.append(transforms.ToTensor())                  # [0, 1]
        ops.append(transforms.Normalize([0.5] * 3, [0.5] * 3))  # [-1, 1]
        self.transform = transforms.Compose(ops)

    @classmethod
    def _list_images(cls, root: str) -> List[str]:
        """Return the sorted paths of the image files directly inside ``root``.

        The directory is scanned instead of globbed so that a ``root``
        containing glob metacharacters (``[``, ``*``, ``?``) is taken
        literally, and extensions are compared case-insensitively so that
        names such as ``photo.Jpg`` are not silently dropped.
        """
        wanted = tuple(ext.lower() for ext in cls.EXTS)
        found: List[str] = []
        with os.scandir(root) as entries:
            for entry in entries:
                name = entry.name
                # Skip dot-files (e.g. macOS "._*" sidecar files); they are
                # not images even when they carry an image extension.
                if name.startswith("."):
                    continue
                if not name.lower().endswith(wanted):
                    continue
                # A directory named like an image cannot be opened as one.
                if not entry.is_file():
                    continue
                found.append(os.path.join(root, name))
        found.sort()
        return found

    def __len__(self) -> int:
        """Return the number of indexed image files."""
        return len(self.files)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load image ``idx`` from disk and return it as a transformed tensor."""
        path = self.files[idx]
        # The context manager closes the underlying file handle once the
        # transformed tensor has been produced.
        with Image.open(path) as img:
            img = img.convert("RGB")
            return self.transform(img)


def make_dataloader(
    root: str,
    image_size: int,
    batch_size: int,
    num_workers: int = 4,
    augment: bool = True,
    shuffle: bool = True,
    limit: Optional[int] = None,
    pin_memory: bool = False,
) -> DataLoader:
    """Create a ``DataLoader`` over the ``CelebAHQ`` images found in ``root``.

    ``root``, ``image_size``, ``augment`` and ``limit`` are forwarded to the
    dataset; ``batch_size``, ``shuffle``, ``num_workers`` and ``pin_memory``
    are forwarded to the loader. A trailing incomplete batch is always
    dropped, and worker processes are kept alive between epochs whenever
    ``num_workers`` is positive.
    """
    images = CelebAHQ(root, image_size, augment=augment, limit=limit)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,
        # persistent workers are only used when there are worker processes
        persistent_workers=num_workers > 0,
    )
    return DataLoader(images, **loader_options)


def denormalize(x: torch.Tensor) -> torch.Tensor:
    """Convert tensors from the [-1, 1] range to [0, 1] for viewing or saving.

    Out-of-range values are clamped to [-1, 1] before rescaling. A new tensor
    is returned; ``x`` itself is left untouched.
    """
    bounded = torch.clamp(x, min=-1.0, max=1.0)
    shifted = bounded + 1.0
    return shifted / 2.0


# ---------------------------------------------------------------------------
# Self-test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import sys

    # The dataset location may be given as the first command-line argument;
    # without one, the original development path is used.
    DEFAULT_DATA_DIR = "/Volumes/Projects/DDIM_image_Generation/celeba_hq_256"
    DATA_DIR = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_DATA_DIR

    # A single sample: expected shape and value range after normalization.
    # Requires at least 8 images in DATA_DIR.
    ds = CelebAHQ(DATA_DIR, image_size=64, augment=True, limit=8)
    assert len(ds) == 8
    x = ds[0]
    assert x.shape == (3, 64, 64), x.shape
    assert -1.0 <= x.min().item() <= x.max().item() <= 1.0

    # A batch through the loader; num_workers=0 keeps loading in-process.
    loader = make_dataloader(DATA_DIR, image_size=64, batch_size=4,
                             num_workers=0, limit=8)
    batch = next(iter(loader))
    assert batch.shape == (4, 3, 64, 64), batch.shape
    print(f"dataset ok: {len(CelebAHQ(DATA_DIR, image_size=64, augment=False))} images total")

    # test a 256 sample as well
    ds256 = CelebAHQ(DATA_DIR, image_size=256, augment=False, limit=2)
    assert ds256[0].shape == (3, 256, 256)
    print("dataset.py: all tests passed")