from __future__ import annotations import io from dataclasses import dataclass, field from pathlib import Path from typing import Any import cv2 import numpy as np from PIL import Image import torch from torch.utils.data import DataLoader from torchvision import transforms as T from torchvision.transforms import functional as F from torchvision.transforms import InterpolationMode import datasets from datasets import load_dataset, load_from_disk from transformers import CLIPTokenizer @dataclass class CannyCFG: sigma: float = 0.33 d: int = 7 sigma_color: float = 50 sigma_space: float = 50 @dataclass class LaionPrepCFG: dataset_name: str = 'bhargavsdesai/laion_improved_aesthetics_6.5plus_with_images' resolution: tuple[int, int] = (512, 512) val_size: int = 10 val_seed: int = 1 canny: CannyCFG = field(default_factory=CannyCFG) cache_dir: str = './data' map_bs: int = 256 map_np: int = 8 num_workers: int = 4 def canny_auto_median_bilateral(pil_img: Image.Image, cfg: CannyCFG) -> Image.Image: gray = np.array(pil_img.convert('L'), dtype=np.uint8) gray_bilat = cv2.bilateralFilter( gray, d=cfg.d, sigmaColor=cfg.sigma_color, sigmaSpace=cfg.sigma_space ) v = float(np.median(gray_bilat)) low = int(max(0, (1.0 - cfg.sigma) * v)) high = int(min(255, (1.0 + cfg.sigma) * v)) if high <= low: high = min(255, low + 1) edges = cv2.Canny(gray_bilat, low, high) return Image.fromarray(edges, mode='L') def pil_to_png_bytes(img: Image.Image, compress_level: int = 1) -> bytes: buf = io.BytesIO() img.save(buf, format='PNG', compress_level=compress_level) return buf.getvalue() def get_image_map(canny_cfg: CannyCFG, resolution: tuple[int, int]): def image_map(batch: dict[str, Any]) -> dict[str, Any]: try: cv2.setNumThreads(0) except Exception: pass out_img = [] out_canny = [] for img in batch['image']: img = img.convert('RGB') img = F.resize(img, list(resolution), interpolation=InterpolationMode.BICUBIC) canny = canny_auto_median_bilateral(img, canny_cfg) # type: ignore out_img.append({'bytes': pil_to_png_bytes(img), 'path': None}) # type: ignore out_canny.append({'bytes': pil_to_png_bytes(canny, compress_level=1), 'path': None}) return {'image': out_img, 'canny': out_canny} return image_map def build_prepped_transform(): to_tensor = T.ToTensor() norm = T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) def _one(img: Image.Image, cond: Image.Image, text: Any): img = img.convert('RGB') cond = cond.convert('L') img_t = norm(to_tensor(img)) # [3,H,W] in [-1,1] cond_t = to_tensor(cond) # [1,H,W] in [0,1] cond_t = cond_t.repeat(3, 1, 1) # [3,H,W] to match conditioning_channels=3 text = '' if text is None else str(text) return img_t, cond_t, text def prepped_transform(ex: dict[str, list]) -> dict[str, list]: imgs = ex['image'] conds = ex['canny'] texts = ex['text'] px_list = [] cond_list = [] text_list = [] for img, cond, t in zip(imgs, conds, texts): px, cv, tt = _one(img, cond, t) px_list.append(px) cond_list.append(cv) text_list.append(tt) return { 'pixel_values': px_list, 'conditioning_pixel_values': cond_list, 'texts': text_list, } return prepped_transform def get_train_collate_fn(tokeniser: CLIPTokenizer, max_length: int, no_caption_prob: float): def train_collator_fn(batch: list[dict[str, Any]]) -> dict[str, Any]: pixel_values = torch.stack([b['pixel_values'] for b in batch]) conditioning_pixel_values = torch.stack([b['conditioning_pixel_values'] for b in batch]) texts = [b['texts'] for b in batch] if no_caption_prob > 0: drop = torch.rand(len(texts)) < no_caption_prob texts = [('' if d else t) for t, d in zip(texts, drop.tolist())] toks = tokeniser( texts, truncation=True, padding='longest', max_length=max_length, return_tensors='pt', ) return { 'pixel_values': pixel_values, 'conditioning_pixel_values': conditioning_pixel_values, 'input_ids': toks['input_ids'], 'attention_mask': toks['attention_mask'], } return train_collator_fn def get_train_dataloader(train_ds, collate_fn, batch_size: int, num_workers: int=0): return DataLoader( dataset=train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers > 0), collate_fn=collate_fn, ) def _dataset_dirname(cfg: LaionPrepCFG) -> str: H, W = cfg.resolution c = cfg.canny name = ( f'laion_r{H}x{W}' f'_sigma{c.sigma}_d{c.d}_sc{c.sigma_color}_ss{c.sigma_space}' ) return name.replace('.', '-') def get_dataset(cfg: LaionPrepCFG): ds_dir = _dataset_dirname(cfg) path = (Path(cfg.cache_dir) / ds_dir).resolve() if path.exists(): print(f'[load] {path}') return load_from_disk(str(path)) print(f'[build] {path} (not found, creating now)') path.parent.mkdir(parents=True, exist_ok=True) ds = load_dataset(cfg.dataset_name, split='train') ds = ds.cast_column('image', datasets.Image(decode=True)) image_map = get_image_map(cfg.canny, cfg.resolution) ds = ds.map( function=image_map, batched=True, batch_size=cfg.map_bs, num_proc=cfg.map_np, # type: ignore ) ds = ds.cast_column('image', datasets.Image(decode=True)) ds = ds.cast_column('canny', datasets.Image(decode=True)) ds.save_to_disk(str(path)) print(f'[saved] {path}') return ds def prepare_laion(cfg: LaionPrepCFG): ds = get_dataset(cfg) split = ds.train_test_split(test_size=cfg.val_size, seed=cfg.val_seed, shuffle=True) # type: ignore train_ds, val_ds = split['train'], split['test'] train_ds = train_ds.with_transform(build_prepped_transform()) return train_ds, val_ds