Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import io | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms as T | |
| from torchvision.transforms import functional as F | |
| from torchvision.transforms import InterpolationMode | |
| import datasets | |
| from datasets import load_dataset, load_from_disk | |
| from transformers import CLIPTokenizer | |
@dataclass
class CannyCFG:
    """Tunable parameters for the bilateral-filter + auto-threshold Canny pass.

    Declared as a dataclass so that instances can be built with overrides
    (``CannyCFG(sigma=0.5)``) and used as a ``default_factory`` by
    ``LaionPrepCFG``; without the decorator the class accepts no constructor
    arguments.
    """

    # Half-width of the band around the median grey level: the Canny
    # thresholds are derived as (1 - sigma) * median and (1 + sigma) * median.
    sigma: float = 0.33
    # Forwarded to cv2.bilateralFilter as ``d``.
    d: int = 7
    # Forwarded to cv2.bilateralFilter as ``sigmaColor``.
    sigma_color: float = 50
    # Forwarded to cv2.bilateralFilter as ``sigmaSpace``.
    sigma_space: float = 50
@dataclass
class LaionPrepCFG:
    """Configuration for building and caching the preprocessed LAION dataset.

    Must be a dataclass: ``canny`` is declared with ``field(default_factory=...)``,
    which only yields a ``CannyCFG`` instance when the dataclass machinery
    processes the class. Without the decorator the attribute would be a raw
    ``Field`` object and ``cfg.canny.sigma`` would fail.
    """

    # Hub dataset identifier handed to ``load_dataset``.
    dataset_name: str = 'bhargavsdesai/laion_improved_aesthetics_6.5plus_with_images'
    # Target size passed to ``F.resize``; unpacked as (H, W) for the cache name.
    resolution: tuple[int, int] = (512, 512)
    # ``test_size`` argument of ``train_test_split``.
    val_size: int = 10
    # Seed for the train/validation split.
    val_seed: int = 1
    # Edge-detection parameters; a fresh CannyCFG per config instance.
    canny: CannyCFG = field(default_factory=CannyCFG)
    # Directory under which the preprocessed dataset is saved and reloaded.
    cache_dir: str = './data'
    # ``batch_size`` for the preprocessing ``Dataset.map`` call.
    map_bs: int = 256
    # ``num_proc`` for the preprocessing ``Dataset.map`` call.
    map_np: int = 8
    # NOTE(review): not read by any function visible in this file; presumably
    # intended for the training DataLoader — confirm against callers.
    num_workers: int = 4
def canny_auto_median_bilateral(pil_img: Image.Image, cfg: CannyCFG) -> Image.Image:
    """Return a Canny edge map of ``pil_img`` with median-derived thresholds.

    The image is converted to 8-bit greyscale and smoothed with a bilateral
    filter. The hysteresis thresholds are placed at ``(1 - sigma)`` and
    ``(1 + sigma)`` times the median of the smoothed image, clamped to
    ``[0, 255]``.

    Args:
        pil_img: Source image; any mode accepted by ``convert('L')``.
        cfg: Bilateral-filter and threshold parameters.

    Returns:
        The edge map as a mode ``'L'`` PIL image.
    """
    smoothed = cv2.bilateralFilter(
        np.array(pil_img.convert('L'), dtype=np.uint8),
        d=cfg.d,
        sigmaColor=cfg.sigma_color,
        sigmaSpace=cfg.sigma_space,
    )

    median = float(np.median(smoothed))
    lower = int(max(0, (1.0 - cfg.sigma) * median))
    upper = int(min(255, (1.0 + cfg.sigma) * median))
    # A degenerate band (e.g. an all-black image, median 0) would collapse the
    # two thresholds; nudge the upper one just above the lower one.
    upper = upper if upper > lower else min(255, lower + 1)

    return Image.fromarray(cv2.Canny(smoothed, lower, upper), mode='L')
def pil_to_png_bytes(img: Image.Image, compress_level: int = 1) -> bytes:
    """Encode ``img`` as PNG in memory and return the raw file bytes.

    Args:
        img: Image to encode.
        compress_level: Forwarded to ``Image.save`` as ``compress_level``.

    Returns:
        The complete PNG byte string.
    """
    with io.BytesIO() as stream:
        img.save(stream, format='PNG', compress_level=compress_level)
        # Read the payload before the context manager closes the buffer.
        return stream.getvalue()
def get_image_map(canny_cfg: CannyCFG, resolution: tuple[int, int]):
    """Build the batched ``Dataset.map`` function that resizes and edge-maps images.

    Args:
        canny_cfg: Parameters for ``canny_auto_median_bilateral``.
        resolution: Target size handed to ``F.resize``.

    Returns:
        A function taking a batch dict with an ``'image'`` list and returning
        ``{'image': [...], 'canny': [...]}``, where each entry is a
        ``{'bytes': <png bytes>, 'path': None}`` dict.
    """

    def image_map(batch: dict[str, Any]) -> dict[str, Any]:
        # Best effort: ask OpenCV not to use its own threads in this process;
        # any failure here is deliberately ignored.
        try:
            cv2.setNumThreads(0)
        except Exception:
            pass

        target_size = list(resolution)
        encoded_images = []
        encoded_edges = []
        for raw in batch['image']:
            resized = F.resize(
                raw.convert('RGB'),
                target_size,
                interpolation=InterpolationMode.BICUBIC,
            )
            edge_img = canny_auto_median_bilateral(resized, canny_cfg)  # type: ignore
            encoded_images.append(
                {'bytes': pil_to_png_bytes(resized), 'path': None}  # type: ignore
            )
            encoded_edges.append(
                {'bytes': pil_to_png_bytes(edge_img, compress_level=1), 'path': None}
            )

        return {'image': encoded_images, 'canny': encoded_edges}

    return image_map
def build_prepped_transform():
    """Build the on-the-fly transform that turns stored examples into tensors.

    Returns:
        A function taking a batch dict with ``'image'``, ``'canny'`` and
        ``'text'`` lists and returning ``'pixel_values'``,
        ``'conditioning_pixel_values'`` and ``'texts'`` lists of equal length.
    """
    as_tensor = T.ToTensor()
    centre = T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

    def _convert(image: Image.Image, edge_map: Image.Image, caption: Any):
        """Convert one (image, edge map, caption) example."""
        # RGB image scaled to [0, 1] by ToTensor, then shifted to [-1, 1].
        pixels = centre(as_tensor(image.convert('RGB')))
        # Single-channel edge map in [0, 1], tiled to three channels so it
        # matches conditioning_channels=3.
        conditioning = as_tensor(edge_map.convert('L')).repeat(3, 1, 1)
        if caption is None:
            caption_str = ''
        else:
            caption_str = str(caption)
        return pixels, conditioning, caption_str

    def prepped_transform(ex: dict[str, list]) -> dict[str, list]:
        converted = [
            _convert(image, edge_map, caption)
            for image, edge_map, caption in zip(ex['image'], ex['canny'], ex['text'])
        ]
        return {
            'pixel_values': [pixels for pixels, _, _ in converted],
            'conditioning_pixel_values': [cond for _, cond, _ in converted],
            'texts': [caption for _, _, caption in converted],
        }

    return prepped_transform
def get_train_collate_fn(tokeniser: CLIPTokenizer, max_length: int, no_caption_prob: float):
    """Build the DataLoader collate function for training batches.

    Args:
        tokeniser: Callable tokenizer applied to the list of captions.
        max_length: Truncation length forwarded to the tokenizer.
        no_caption_prob: Per-example probability of replacing the caption with
            the empty string; no captions are dropped when this is ``<= 0``.

    Returns:
        A function mapping a list of examples to a dict with stacked
        ``'pixel_values'`` and ``'conditioning_pixel_values'`` tensors plus
        the tokenizer's ``'input_ids'`` and ``'attention_mask'``.
    """

    def train_collator_fn(batch: list[dict[str, Any]]) -> dict[str, Any]:
        images = torch.stack([example['pixel_values'] for example in batch])
        conditioning = torch.stack(
            [example['conditioning_pixel_values'] for example in batch]
        )
        captions = [example['texts'] for example in batch]

        if no_caption_prob > 0:
            # One independent uniform draw per example decides whether its
            # caption is blanked.
            blanked = (torch.rand(len(captions)) < no_caption_prob).tolist()
            captions = [
                '' if is_blank else caption
                for caption, is_blank in zip(captions, blanked)
            ]

        encoded = tokeniser(
            captions,
            truncation=True,
            padding='longest',
            max_length=max_length,
            return_tensors='pt',
        )

        return {
            'pixel_values': images,
            'conditioning_pixel_values': conditioning,
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
        }

    return train_collator_fn
def get_train_dataloader(train_ds, collate_fn, batch_size: int, num_workers: int=0):
    """Wrap ``train_ds`` in a shuffling ``DataLoader`` with pinned memory.

    Args:
        train_ds: Dataset to iterate.
        collate_fn: Function that merges a list of examples into a batch.
        batch_size: Number of examples per batch.
        num_workers: Loader worker processes; workers are kept alive between
            epochs only when this is positive.

    Returns:
        The configured ``DataLoader``.
    """
    keep_workers_alive = num_workers > 0
    loader_kwargs = {
        'dataset': train_ds,
        'batch_size': batch_size,
        'shuffle': True,
        'num_workers': num_workers,
        'pin_memory': True,
        'persistent_workers': keep_workers_alive,
        'collate_fn': collate_fn,
    }
    return DataLoader(**loader_kwargs)
def _dataset_dirname(cfg: LaionPrepCFG) -> str:
    """Return the cache directory name encoding the preprocessing parameters.

    The name is built from the resolution and the four Canny settings, with
    every ``'.'`` replaced by ``'-'``.

    NOTE(review): ``cfg.dataset_name`` is not part of the name, so two source
    datasets prepared with identical settings map to the same directory.
    """
    height, width = cfg.resolution
    canny = cfg.canny
    pieces = [
        f'laion_r{height}x{width}',
        f'_sigma{canny.sigma}',
        f'_d{canny.d}',
        f'_sc{canny.sigma_color}',
        f'_ss{canny.sigma_space}',
    ]
    return ''.join(pieces).replace('.', '-')
def get_dataset(cfg: LaionPrepCFG):
    """Load the preprocessed dataset from disk, building and saving it if absent.

    When the cache directory for ``cfg`` already exists it is loaded as is.
    Otherwise the ``'train'`` split of ``cfg.dataset_name`` is loaded, every
    image is resized and paired with a Canny edge map via ``Dataset.map``,
    and the result is saved under ``cfg.cache_dir``.

    Args:
        cfg: Dataset, resolution, Canny and mapping settings.

    Returns:
        The preprocessed dataset with decoded ``'image'`` and ``'canny'``
        image columns.
    """
    cache_path = (Path(cfg.cache_dir) / _dataset_dirname(cfg)).resolve()

    if cache_path.exists():
        print(f'[load] {cache_path}')
        return load_from_disk(str(cache_path))

    print(f'[build] {cache_path} (not found, creating now)')
    cache_path.parent.mkdir(parents=True, exist_ok=True)

    dataset = load_dataset(cfg.dataset_name, split='train')
    dataset = dataset.cast_column('image', datasets.Image(decode=True))

    dataset = dataset.map(
        function=get_image_map(cfg.canny, cfg.resolution),
        batched=True,
        batch_size=cfg.map_bs,
        num_proc=cfg.map_np,  # type: ignore
    )

    # The map step emits {'bytes', 'path'} dicts; cast both columns so they
    # are decoded as images on access.
    for column in ('image', 'canny'):
        dataset = dataset.cast_column(column, datasets.Image(decode=True))

    dataset.save_to_disk(str(cache_path))
    print(f'[saved] {cache_path}')
    return dataset
def prepare_laion(cfg: LaionPrepCFG):
    """Return the ``(train, validation)`` datasets for ``cfg``.

    The preprocessed dataset is split with ``cfg.val_size`` examples held out
    using ``cfg.val_seed``. Only the training split gets the tensor transform;
    the validation split is returned untransformed.
    """
    parts = get_dataset(cfg).train_test_split(  # type: ignore
        test_size=cfg.val_size,
        seed=cfg.val_seed,
        shuffle=True,
    )
    val_ds = parts['test']
    train_ds = parts['train'].with_transform(build_prepped_transform())
    return train_ds, val_ds