# NoobNovel's picture
# DDIM face generation — full project
# 0ca4c93
"""CelebA-HQ dataset loader.
Reads pre-cropped 256x256 images (JPG/JPEG/PNG) from a flat directory,
optionally applies horizontal flip augmentation, resizes to the target stage
resolution, and normalizes to [-1, 1] (the convention diffusion models work in).
"""
from __future__ import annotations
import glob
import os
from typing import List, Optional
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class CelebAHQ(Dataset):
    """Flat-directory image dataset that yields normalized RGB tensors.

    Every regular, non-hidden file directly inside ``root`` whose extension is
    listed in ``EXTS`` (compared case-insensitively) is one sample. Samples are
    ordered by sorted path, so an index always refers to the same file.

    Each item is produced by: optional random horizontal flip -> ``Resize`` to
    ``image_size`` -> ``CenterCrop`` to ``image_size`` -> ``ToTensor`` ->
    ``Normalize`` with mean 0.5 and std 0.5 per channel.

    Args:
        root: Directory holding the image files (not searched recursively).
        image_size: Target side length passed to ``Resize`` and ``CenterCrop``.
        augment: When true, flip each image horizontally with probability 0.5.
        limit: If given, keep only the first ``limit`` files after sorting.

    Raises:
        FileNotFoundError: ``root`` is not an existing directory.
        ValueError: ``limit`` is given but is smaller than 1.
        RuntimeError: ``root`` contains no file with a supported extension.
    """

    EXTS = (".jpg", ".jpeg", ".png")

    def __init__(self, root: str, image_size: int, augment: bool = True,
                 limit: Optional[int] = None):
        self.root = root
        if not os.path.isdir(root):
            raise FileNotFoundError(f"data dir not found: {root}")
        # A zero limit would give an empty dataset and a negative one would
        # silently drop files from the *end* of the list; reject both early.
        if limit is not None and limit < 1:
            raise ValueError(f"limit must be a positive integer, got {limit}")

        files = self._list_images(root)
        if not files:
            raise RuntimeError(f"no images found in {root}")
        if limit is not None:
            files = files[:limit]
        self.files = files
        self.image_size = image_size

        ops = []
        if augment:
            ops.append(transforms.RandomHorizontalFlip(p=0.5))
        # Resize is left on its default interpolation mode; antialias=True
        # low-pass filters first so heavy downsampling does not alias.
        ops.append(transforms.Resize(image_size, antialias=True))
        ops.append(transforms.CenterCrop(image_size))
        ops.append(transforms.ToTensor())  # [0, 1]
        ops.append(transforms.Normalize([0.5] * 3, [0.5] * 3))  # [-1, 1]
        self.transform = transforms.Compose(ops)

    @classmethod
    def _list_images(cls, root: str) -> List[str]:
        """Return the sorted paths of the supported image files in ``root``.

        The directory is listed directly instead of globbed, so glob
        metacharacters in ``root`` (``[``, ``]``, ``*``, ``?``) are taken
        literally and mixed-case extensions such as ``.Jpg`` are accepted.
        Names starting with a dot are skipped, as a ``*`` glob would skip
        them, and directories are ignored.
        """
        with os.scandir(root) as entries:
            return sorted(
                entry.path
                for entry in entries
                if not entry.name.startswith(".")
                and entry.name.lower().endswith(cls.EXTS)
                and entry.is_file()
            )

    def __len__(self) -> int:
        """Return the number of image files in the dataset."""
        return len(self.files)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load the image at ``idx`` as RGB and return the transformed tensor."""
        path = self.files[idx]
        # convert() returns a new in-memory image, so it is safe to use after
        # the context manager has closed the underlying file.
        with Image.open(path) as img:
            img = img.convert("RGB")
        return self.transform(img)
def make_dataloader(
    root: str,
    image_size: int,
    batch_size: int,
    num_workers: int = 4,
    augment: bool = True,
    shuffle: bool = True,
    limit: Optional[int] = None,
    pin_memory: bool = False,
    drop_last: bool = True,
) -> DataLoader:
    """Build a ``DataLoader`` over a :class:`CelebAHQ` dataset.

    Args:
        root: Image directory, forwarded to ``CelebAHQ``.
        image_size: Target resolution, forwarded to ``CelebAHQ``.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes; 0 loads in-process.
        augment: Forwarded to ``CelebAHQ``.
        shuffle: Whether the loader reshuffles the samples.
        limit: Forwarded to ``CelebAHQ``.
        pin_memory: Forwarded to ``DataLoader``.
        drop_last: Drop the final incomplete batch. Defaults to true; pass
            false when every sample must be visited.

    Returns:
        The configured ``DataLoader``.

    Raises:
        ValueError: ``drop_last`` is true and the dataset has fewer samples
            than ``batch_size``, so the loader would yield no batches.
    """
    dataset = CelebAHQ(root=root, image_size=image_size, augment=augment, limit=limit)
    # With drop_last, a dataset smaller than one batch yields nothing at all;
    # fail loudly rather than let a training loop silently do no work.
    if drop_last and len(dataset) < batch_size:
        raise ValueError(
            f"dataset has {len(dataset)} images, fewer than batch_size={batch_size}; "
            "with drop_last=True the loader would yield no batches"
        )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        # persistent_workers may only be enabled when worker processes exist.
        persistent_workers=num_workers > 0,
    )
def denormalize(x: torch.Tensor) -> torch.Tensor:
    """Clamp a tensor to [-1, 1] and rescale it to [0, 1] for viewing/saving."""
    clipped = torch.clamp(x, min=-1.0, max=1.0)
    return clipped.add(1.0).div(2.0)
# ---------------------------------------------------------------------------
# Self-test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import sys

    # Data directory: first command-line argument when given, otherwise the
    # original hard-coded location, so running with no arguments is unchanged.
    DATA_DIR = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/Volumes/Projects/DDIM_image_Generation/celeba_hq_256"
    )

    # Small augmented dataset: length, per-sample shape and value range.
    ds = CelebAHQ(DATA_DIR, image_size=64, augment=True, limit=8)
    assert len(ds) == 8
    x = ds[0]
    assert x.shape == (3, 64, 64), x.shape
    assert -1.0 <= x.min().item() <= x.max().item() <= 1.0

    # Loader in-process (num_workers=0): one batch has the expected shape.
    loader = make_dataloader(DATA_DIR, image_size=64, batch_size=4,
                             num_workers=0, limit=8)
    batch = next(iter(loader))
    assert batch.shape == (4, 3, 64, 64), batch.shape
    print(f"dataset ok: {len(CelebAHQ(DATA_DIR, image_size=64, augment=False))} images total")

    # test a 256 sample as well
    ds256 = CelebAHQ(DATA_DIR, image_size=256, augment=False, limit=2)
    assert ds256[0].shape == (3, 256, 256)
    print("dataset.py: all tests passed")