ICGenAIShare07 commited on
Commit
bd51c5a
·
verified ·
1 Parent(s): 67ea22c

Upload prepare_laion.py

Browse files
Files changed (1) hide show
  1. prepare_laion.py +220 -0
prepare_laion.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image
11
+
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+ from torchvision import transforms as T
15
+ from torchvision.transforms import functional as F
16
+ from torchvision.transforms import InterpolationMode
17
+
18
+ import datasets
19
+ from datasets import load_dataset, load_from_disk
20
+ from transformers import CLIPTokenizer
21
+
22
@dataclass
class CannyCFG:
    """Parameters for auto-thresholded Canny edge extraction.

    ``sigma`` controls the width of the hysteresis band around the median
    intensity; ``d``/``sigma_color``/``sigma_space`` are passed straight to
    ``cv2.bilateralFilter`` for pre-smoothing.
    """

    sigma: float = 0.33        # fraction of the median used for low/high thresholds
    d: int = 7                 # bilateral filter pixel neighbourhood diameter
    sigma_color: float = 50.0  # bilateral filter colour-space sigma (was int 50; field is float)
    sigma_space: float = 50.0  # bilateral filter coordinate-space sigma (was int 50; field is float)
28
+
29
+
30
@dataclass
class LaionPrepCFG:
    """Configuration for downloading and preprocessing the LAION subset."""

    # Hugging Face dataset repo to pull raw images/captions from.
    dataset_name: str = 'bhargavsdesai/laion_improved_aesthetics_6.5plus_with_images'
    # Target (H, W) every image is bicubic-resized to before caching.
    resolution: tuple[int, int] = (512, 512)

    # Held-out validation split: example count and split seed (see prepare_laion).
    val_size: int = 10
    val_seed: int = 1

    # Canny/bilateral parameters; they are baked into the cache directory name.
    canny: CannyCFG = field(default_factory=CannyCFG)
    # Root directory under which the prepared dataset is saved/loaded.
    cache_dir: str = './data'

    # datasets.map() batch size and worker-process count for the build step.
    map_bs: int = 256
    map_np: int = 8

    # NOTE(review): not referenced anywhere in this file — presumably consumed
    # by the training script's DataLoader; confirm against callers.
    num_workers: int = 4
45
+
46
def canny_auto_median_bilateral(pil_img: Image.Image, cfg: CannyCFG) -> Image.Image:
    """Extract Canny edges with thresholds derived from the median intensity.

    The image is converted to 8-bit grayscale, smoothed with a bilateral
    filter, and thresholded at (1 - sigma) / (1 + sigma) times the median of
    the smoothed pixels. Returns a single-channel ('L') edge image.
    """
    grayscale = np.asarray(pil_img.convert('L'), dtype=np.uint8)

    smoothed = cv2.bilateralFilter(
        grayscale,
        d=cfg.d,
        sigmaColor=cfg.sigma_color,
        sigmaSpace=cfg.sigma_space,
    )

    median = float(np.median(smoothed))
    lo_threshold = int(max(0, (1.0 - cfg.sigma) * median))
    hi_threshold = int(min(255, (1.0 + cfg.sigma) * median))
    # Guarantee a non-degenerate hysteresis band even for flat/black images.
    if hi_threshold <= lo_threshold:
        hi_threshold = min(255, lo_threshold + 1)

    return Image.fromarray(cv2.Canny(smoothed, lo_threshold, hi_threshold), mode='L')
62
+
63
+
64
def pil_to_png_bytes(img: Image.Image, compress_level: int = 1) -> bytes:
    """Encode *img* as PNG and return the raw bytes.

    ``compress_level`` defaults to 1 (fast, lightly compressed) since these
    bytes are written into a dataset cache where build speed matters more
    than size.
    """
    with io.BytesIO() as buffer:
        img.save(buffer, format='PNG', compress_level=compress_level)
        return buffer.getvalue()
68
+
69
+
70
def get_image_map(canny_cfg: CannyCFG, resolution: tuple[int, int]):
    """Build a ``datasets.map``-compatible batched function that resizes each
    image to ``resolution`` and adds a matching Canny edge map.

    Both columns are emitted as encoded PNG bytes dicts so they can later be
    re-cast to decodable ``datasets.Image`` features.
    """

    def image_map(batch: dict[str, Any]) -> dict[str, Any]:
        # OpenCV spawns its own threads; disable them so they don't fight
        # with the datasets.map worker processes.
        try:
            cv2.setNumThreads(0)
        except Exception:
            pass

        images = []
        edge_maps = []

        for raw in batch['image']:
            rgb = F.resize(
                raw.convert('RGB'),
                list(resolution),
                interpolation=InterpolationMode.BICUBIC,
            )
            edges = canny_auto_median_bilateral(rgb, canny_cfg)  # type: ignore

            images.append({'bytes': pil_to_png_bytes(rgb), 'path': None})  # type: ignore
            edge_maps.append({'bytes': pil_to_png_bytes(edges, compress_level=1), 'path': None})

        return {'image': images, 'canny': edge_maps}

    return image_map
92
+
93
def build_prepped_transform():
    """Return a ``with_transform`` callable that converts decoded examples
    into model-ready tensors.

    Each example yields:
      * ``pixel_values``: RGB tensor in [-1, 1], shape [3, H, W]
      * ``conditioning_pixel_values``: canny map in [0, 1], replicated to
        3 channels (shape [3, H, W]) to match conditioning_channels=3
      * ``texts``: caption string ('' when the caption is None)
    """
    to_tensor = T.ToTensor()
    normalize = T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

    def _prep_one(img: Image.Image, cond: Image.Image, caption: Any):
        # Normalized RGB target in [-1, 1].
        pixels = normalize(to_tensor(img.convert('RGB')))
        # Single-channel canny in [0, 1], tiled to 3 channels.
        conditioning = to_tensor(cond.convert('L')).repeat(3, 1, 1)
        return pixels, conditioning, '' if caption is None else str(caption)

    def prepped_transform(ex: dict[str, list]) -> dict[str, list]:
        pixel_list = []
        conditioning_list = []
        caption_list = []

        for img, cond, caption in zip(ex['image'], ex['canny'], ex['text']):
            pixels, conditioning, text = _prep_one(img, cond, caption)
            pixel_list.append(pixels)
            conditioning_list.append(conditioning)
            caption_list.append(text)

        return {
            'pixel_values': pixel_list,
            'conditioning_pixel_values': conditioning_list,
            'texts': caption_list,
        }

    return prepped_transform
130
+
131
+
132
def get_train_collate_fn(tokeniser: CLIPTokenizer, max_length: int, no_caption_prob: float):
    """Create a DataLoader collate_fn that stacks image tensors and
    tokenizes captions.

    With probability ``no_caption_prob`` each caption is independently
    replaced by the empty string (classifier-free-guidance style caption
    dropout).
    """

    def train_collator_fn(batch: list[dict[str, Any]]) -> dict[str, Any]:
        pixel_values = torch.stack([example['pixel_values'] for example in batch])
        conditioning_pixel_values = torch.stack(
            [example['conditioning_pixel_values'] for example in batch]
        )
        texts = [example['texts'] for example in batch]

        if no_caption_prob > 0:
            drop_mask = torch.rand(len(texts)) < no_caption_prob
            texts = [
                '' if dropped else caption
                for caption, dropped in zip(texts, drop_mask.tolist())
            ]

        toks = tokeniser(
            texts,
            truncation=True,
            padding='longest',
            max_length=max_length,
            return_tensors='pt',
        )

        return {
            'pixel_values': pixel_values,
            'conditioning_pixel_values': conditioning_pixel_values,
            'input_ids': toks['input_ids'],
            'attention_mask': toks['attention_mask'],
        }

    return train_collator_fn
158
+
159
+
160
def get_train_dataloader(train_ds, collate_fn, batch_size: int, num_workers: int = 0):
    """Wrap ``train_ds`` in a shuffling DataLoader for training.

    Worker processes are kept persistent only when at least one worker is
    requested; pinned memory speeds up host-to-device transfer.
    """
    keep_workers_alive = num_workers > 0
    return DataLoader(
        dataset=train_ds,
        collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=True,
        persistent_workers=keep_workers_alive,
    )
170
+
171
+ def _dataset_dirname(cfg: LaionPrepCFG) -> str:
172
+ H, W = cfg.resolution
173
+ c = cfg.canny
174
+ name = (
175
+ f'laion_r{H}x{W}'
176
+ f'_sigma{c.sigma}_d{c.d}_sc{c.sigma_color}_ss{c.sigma_space}'
177
+ )
178
+ return name.replace('.', '-')
179
+
180
+
181
def get_dataset(cfg: LaionPrepCFG):
    """Load the prepared LAION dataset from the local cache, building it
    from the upstream Hugging Face dataset on a cache miss.

    Building resizes every image, computes its Canny map, and stores both
    as re-decodable PNG image columns before saving to disk.
    """
    target = (Path(cfg.cache_dir) / _dataset_dirname(cfg)).resolve()

    # Fast path: a previous run already materialised this configuration.
    if target.exists():
        print(f'[load] {target}')
        return load_from_disk(str(target))

    print(f'[build] {target} (not found, creating now)')
    target.parent.mkdir(parents=True, exist_ok=True)

    raw = load_dataset(cfg.dataset_name, split='train')
    raw = raw.cast_column('image', datasets.Image(decode=True))

    processed = raw.map(
        function=get_image_map(cfg.canny, cfg.resolution),
        batched=True,
        batch_size=cfg.map_bs,
        num_proc=cfg.map_np,  # type: ignore
    )

    # Re-declare both PNG-bytes columns as decodable image features.
    processed = processed.cast_column('image', datasets.Image(decode=True))
    processed = processed.cast_column('canny', datasets.Image(decode=True))

    processed.save_to_disk(str(target))
    print(f'[saved] {target}')
    return processed
209
+
210
+
211
def prepare_laion(cfg: LaionPrepCFG):
    """Return ``(train_ds, val_ds)``: a deterministic train/val split of the
    prepared dataset, with the tensor transform attached to the training
    side only.
    """
    full = get_dataset(cfg)

    # Seeded shuffle keeps the held-out examples stable across runs.
    split = full.train_test_split(test_size=cfg.val_size, seed=cfg.val_seed, shuffle=True)  # type: ignore
    train_part = split['train'].with_transform(build_prepped_transform())

    return train_part, split['test']
220
+