# jon-kyl's picture
# add everything
# 49e7d2b
import io
import os
import queue
import threading
import jax
import numpy as np
import PIL.Image
import tqdm
from typing import Callable, Generator, Iterator
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from functools import partial
from datasets import load_dataset, IterableDatasetDict, Image as HFImage
from diffusers import FlaxAutoencoderKL
from PIL import PngImagePlugin
from quantization import optimized_for_sdxl, QuantizationType
# Raise Pillow's cap on decoded PNG text chunks to 100 MiB (the value is in
# bytes, not MB). Presumably some source images carry oversized metadata
# chunks that would otherwise be rejected on open — TODO confirm.
PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024**2
def load_encoder_decoder() -> tuple[
    Callable[[jax.Array], jax.Array], Callable[[jax.Array], jax.Array]
]:
    """Load the SDXL VAE and return ``(encode, decode)`` closures over its weights.

    Weights come from ``FlaxAutoencoderKL.from_pretrained`` using the
    ``stabilityai/stable-diffusion-xl-base-1.0`` repo, ``vae`` subfolder.

    Returns:
        A 2-tuple ``(encode, decode)``:

        * ``encode(x)`` transposes ``x`` with ``(0, 3, 1, 2)`` (channels-last to
          channels-first for the NHWC image batches this file produces), takes
          the mean of the VAE posterior and multiplies it by
          ``vae.config.scaling_factor``. The layout of the returned latents is
          whatever ``vae.encode`` emits (believed channels-last in diffusers'
          Flax implementation — confirm).
        * ``decode(x)`` divides by the scaling factor, transposes with
          ``(0, 3, 1, 2)``, runs ``vae.decode`` and transposes the sample with
          ``(0, 2, 3, 1)`` so the result is channels-last, assuming
          ``vae.decode`` returns channels-first — confirm.
    """
    vae, params = FlaxAutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae")

    def encode(x):
        x = x.transpose((0, 3, 1, 2))
        # Deterministic encoding: use the posterior mean rather than a sample.
        x = vae.apply({"params": params}, x, method=vae.encode).latent_dist.mean
        x = x * vae.config.scaling_factor
        return x

    def decode(x):
        # Undo the scaling applied by ``encode``.
        x = x / vae.config.scaling_factor
        x = x.transpose((0, 3, 1, 2))
        x = vae.apply({"params": params}, x, method=vae.decode).sample
        x = x.transpose((0, 2, 3, 1))
        return x

    return encode, decode
def load_encoder():
    """Load the VAE and return only its encode function (decode is discarded)."""
    encode_fn, _unused_decode = load_encoder_decoder()
    return encode_fn
def load_decoder():
    """Load the VAE and return only its decode function (encode is discarded)."""
    _unused_encode, decode_fn = load_encoder_decoder()
    return decode_fn
def stream_imagenet_raw() -> IterableDatasetDict:
    """Stream all splits of ImageNet-1k from the Hugging Face Hub.

    The dataset revision is pinned to a fixed commit. The ``image`` column is
    cast with ``decode=False`` so each example keeps its raw encoded bytes and
    path (read later via ``item["image"]["bytes"]`` / ``["path"]``) instead
    of a decoded image.
    """
    dataset_kwargs = dict(
        path="ILSVRC/imagenet-1k",
        revision="4603483700ee984ea9debe3ddbfdeae86f6489eb",
        streaming=True,
        trust_remote_code=True,
    )
    splits = load_dataset(**dataset_kwargs)
    undecoded_image = HFImage(decode=False)
    return splits.cast_column("image", undecoded_image)
# One processed batch: (array of images or codes, labels, names). The label
# and name tuples are variable-length, one entry per example in the batch,
# hence ``tuple[X, ...]`` rather than the single-element ``tuple[X]``.
BatchType = tuple[np.ndarray, tuple[int, ...], tuple[str, ...]]
def get_name(item: dict):
    """Return the stem of ``item["image"]["path"]`` (no directory, no extension)."""
    file_name = os.path.basename(item["image"]["path"])
    stem, _extension = os.path.splitext(file_name)
    return stem
def process_imagenet_batch(
    batch: list[dict],
    size: int,
    resampling: PIL.Image.Resampling = PIL.Image.Resampling.LANCZOS,
) -> BatchType:
    """Decode, center-crop, resize and normalize a batch of raw examples.

    Each example must provide ``item["label"]``, encoded image bytes at
    ``item["image"]["bytes"]`` and a path at ``item["image"]["path"]`` (used
    by ``get_name``).

    Args:
        batch: Raw examples to process.
        size: Side length, in pixels, of the square output images.
        resampling: Pillow filter used for the final resize.

    Returns:
        ``(images, labels, names)``. ``images`` is a float32 array of shape
        ``(len(batch), size, size, 3)`` scaled from ``[0, 255]`` to
        ``[-1, 1]``; ``labels`` and ``names`` are tuples aligned with it. An
        empty ``batch`` yields an empty image array and two empty tuples
        instead of failing on the tuple unpack.
    """

    def process_example(item: dict) -> tuple[np.ndarray, int, str]:
        label = item["label"]
        # Close the decoded source as soon as the RGB copy exists.
        with PIL.Image.open(io.BytesIO(item["image"]["bytes"])) as source:
            image = source.convert("RGB")
        width, height = image.size
        # Largest centred square crop, then resize to (size, size).
        short_edge = min(width, height)
        left = (width - short_edge) // 2
        upper = (height - short_edge) // 2
        image = image.crop((left, upper, left + short_edge, upper + short_edge))
        image = image.resize((size, size), resample=resampling)
        return np.asarray(image), label, get_name(item)

    if not batch:
        return np.empty((0, size, size, 3), dtype=np.float32), (), ()

    images, labels, names = zip(*map(process_example, batch))
    # uint8 [0, 255] -> float32 [-1, 1].
    images = np.array(images, dtype=np.float32) / 127.5 - 1
    return images, labels, names
def parallel_process_dataset(
    dataset_iterator: Iterator[dict],
    batch_processor: Callable[[list[dict]], BatchType],
    batch_size: int,
    max_workers: int,
    prefetch_factor: int = 2,
    filter_fn: Callable[[dict], bool] = lambda _: True,
) -> Generator[BatchType, None, None]:
    """Batch a stream of raw examples and process the batches in worker processes.

    A background thread pulls items from ``dataset_iterator``, drops those for
    which ``filter_fn`` returns False, groups the rest into batches of
    ``batch_size`` (the last batch may be smaller) and submits each batch to
    ``batch_processor`` on a ``ProcessPoolExecutor``. Results are yielded in
    roughly completion order, not submission order.

    At most ``max_workers * prefetch_factor`` batches are in flight, and at most
    that many finished results are buffered before the filler thread waits for
    the consumer.

    A batch whose processing raises is reported with ``print`` and dropped.
    An exception raised by ``dataset_iterator`` or ``filter_fn`` stops the
    filler. Results already queued are still yielded; the exception is then
    re-raised to the consumer, and batches still in flight at that point are
    discarded.

    Args:
        dataset_iterator: Source of raw examples.
        batch_processor: Picklable callable applied to each batch in a worker
            process.
        batch_size: Maximum number of examples per batch.
        max_workers: Number of worker processes.
        prefetch_factor: In-flight batches allowed per worker.
        filter_fn: Predicate deciding whether an example is kept.

    Raises:
        ValueError: If ``batch_size`` or ``max_workers * prefetch_factor`` is
            below 1. Either value would otherwise make the filler loop forever.
    """
    from concurrent.futures import FIRST_COMPLETED, wait

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    queue_size = max_workers * prefetch_factor
    if queue_size < 1:
        raise ValueError(
            "max_workers * prefetch_factor must be >= 1, "
            f"got {max_workers} * {prefetch_factor}"
        )
    results_queue = queue.Queue(maxsize=queue_size)
    # Unique end-of-stream marker. The filler always enqueues it last, from a
    # ``finally``. The consumer can therefore just block on ``get()``. Checking
    # an "is loading done?" flag against an empty queue instead can leave the
    # consumer blocked forever when no final result is ever enqueued (e.g.
    # every item filtered out).
    end_of_stream = object()
    # Exception that stopped the filler thread, if any (re-raised below).
    filler_errors = []

    def queue_filler():
        current_batch = []
        dataset_exhausted = False
        pending_futures = set()
        counter = 0
        try:
            # Outer loop: keep going while jobs are running or input remains.
            while pending_futures or not dataset_exhausted:
                # Middle loop: submit batches until the in-flight cap is hit.
                while len(pending_futures) < queue_size and not dataset_exhausted:
                    # Inner loop: grow the batch from the iterator.
                    while len(current_batch) < batch_size:
                        try:
                            item = next(dataset_iterator)
                        except StopIteration:
                            dataset_exhausted = True
                            break
                        counter += 1
                        if filter_fn(item):
                            current_batch.append(item)
                        else:
                            print("skipping item", counter)
                    if not current_batch:
                        break
                    pending_futures.add(
                        executor.submit(batch_processor, current_batch)
                    )
                    current_batch = []
                if not pending_futures:
                    # Nothing running and (necessarily) no input left.
                    continue
                # Block until at least one job finishes rather than polling
                # ``done()`` in a tight loop, which would pin a CPU core.
                done_futures, pending_futures = wait(
                    pending_futures, return_when=FIRST_COMPLETED
                )
                for future in done_futures:
                    try:
                        results_queue.put(future.result())
                    except Exception as e:
                        # Best effort: report and drop the failed batch.
                        print(f"Error processing batch: {e}")
        except BaseException as exc:
            filler_errors.append(exc)
        finally:
            results_queue.put(end_of_stream)

    # NOTE(review): if the pool forks after JAX has been initialised in this
    # process, JAX warns that fork is unsafe; an explicit ``mp_context`` may be
    # worth considering — confirm.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Filling runs on a thread; batch processing still runs in the pool.
        thread = threading.Thread(target=queue_filler, daemon=True)
        thread.start()
        while True:
            batch = results_queue.get()
            if batch is end_of_stream:
                break
            yield batch
        thread.join()
    if filler_errors:
        raise filler_errors[0]
def gen_codes(
    dataset_iterator: Iterator[dict],
    batch_processor: Callable[[list[dict]], BatchType],
    encoder: Callable[[jax.Array], jax.Array],
    batch_size: int,
    num_workers: int,
    quantization: QuantizationType = optimized_for_sdxl,
    filter_fn: Callable[[dict], bool] = lambda _: True,
) -> Generator[BatchType, None, None]:
    """Encode and quantize processed batches drawn from a stream of raw examples.

    Raw examples from ``dataset_iterator`` are filtered, batched and processed
    in parallel by ``parallel_process_dataset``. Each resulting image batch is
    passed through ``encoder`` and then ``quantization.quantize``, both inside
    one ``jax.jit``-compiled function.

    Args:
        dataset_iterator: Source of raw (undecoded) examples.
        batch_processor: Turns a list of raw examples into
            ``(images, labels, names)``; runs in worker processes.
        encoder: Maps an image batch to latents.
        batch_size: Examples per batch.
        num_workers: Worker processes used for ``batch_processor``.
        quantization: Object whose ``quantize`` method is applied to latents.
        filter_fn: Predicate deciding whether a raw example is processed.

    Yields:
        ``(codes, labels, names)``, with ``codes`` converted to a NumPy array.
    """

    @jax.jit
    def encode_and_quantize(x):
        return quantization.quantize(encoder(x))

    for images, labels, names in parallel_process_dataset(
        dataset_iterator,
        batch_processor=batch_processor,
        batch_size=batch_size,
        max_workers=num_workers,
        filter_fn=filter_fn,
    ):
        # NOTE: a smaller final batch has a new input shape, so jit compiles
        # encode_and_quantize once more for it.
        codes = np.array(encode_and_quantize(images))
        yield codes, labels, names
def get_save_path(label: int, output_dir: str, name: str) -> str:
    """Return ``<output_dir>/<label folder>/<name>.png``.

    The label folder is the label zero-padded to three digits, or ``UNK`` for
    negative labels.
    """
    if label < 0:
        folder_name = "UNK"
    else:
        folder_name = f"{label:03d}"
    return os.path.join(output_dir, folder_name, f"{name}.png")
def save_image(
    image: np.ndarray,
    label: int,
    output_dir: str,
    name: str,
    overwrite: bool = False,
) -> str:
    """Save ``image`` at ``get_save_path(label, output_dir, name)`` and return that path.

    The destination directory is created if needed. An existing file is left
    untouched unless ``overwrite`` is true.

    The image is first written to a temporary sibling file and then moved into
    place with ``os.replace``. An interrupted save therefore never leaves a
    truncated file at the final path. This matters because resumed runs skip
    any output path that already exists.
    """
    img_path = get_save_path(label, output_dir, name)
    os.makedirs(os.path.dirname(img_path), exist_ok=True)
    if overwrite or not os.path.exists(img_path):
        # The temp name does not end in the image extension, so pass the
        # format Pillow would have inferred from the final path explicitly.
        _, ext = os.path.splitext(img_path)
        image_format = PIL.Image.registered_extensions()[ext.lower()]
        # Unique per process and thread, so concurrent writers never collide.
        tmp_path = f"{img_path}.{os.getpid()}-{threading.get_ident()}.tmp"
        try:
            PIL.Image.fromarray(image).save(tmp_path, format=image_format)
            os.replace(tmp_path, img_path)
        except BaseException:
            # Don't leave a partial temp file behind.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    return img_path
def save_images_async(
    loader: Generator[BatchType, None, None],
    output_dir: str,
    max_workers: int,
) -> list[str]:
    """Save every ``(image, label, name)`` from ``loader`` using a thread pool.

    Each batch yielded by ``loader`` is unzipped into per-example triples, and
    each triple is submitted to ``save_image`` on a ``ThreadPoolExecutor``.

    If any save fails, no further batches are pulled from ``loader``. Once the
    already-submitted saves finish, the first recorded failure is re-raised.

    Args:
        loader: Yields ``(images, labels, names)`` batches.
        output_dir: Root directory passed to ``save_image``.
        max_workers: Number of saver threads.

    Returns:
        The saved file paths, in submission order.
    """
    # NOTE(review): one Future per image is kept until the end, so for very
    # large splits this list (and the returned paths) can grow large.
    futures = []
    errors = []  # exceptions from failed saves, in the order they were seen
    error_seen = threading.Event()

    def signal_error(future):
        try:
            future.result()
        except Exception as exc:
            errors.append(exc)
            error_seen.set()

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for batch in loader:
            for image, label, name in zip(*batch):
                future = executor.submit(
                    save_image,
                    image=image,
                    label=label,
                    output_dir=output_dir,
                    name=name,
                )
                future.add_done_callback(signal_error)
                futures.append(future)
                if error_seen.is_set():
                    break
            # Also stop pulling batches. Otherwise a failure would let the
            # rest of the split be encoded before the error is reported.
            if error_seen.is_set():
                break
    if errors:
        raise errors[0]
    return [f.result() for f in futures]
def mock_encoder(x: jax.Array) -> jax.Array:
    """Cheap stand-in for the VAE encoder, for pipeline testing without weights.

    Average-pools each non-overlapping 8x8 spatial block and emits channels
    ``(0, 1, 2, 0)``, scaled by 3. This presumably mimics the real encoder's
    8x downsampling and 4-channel latents — confirm. Assumes a channels-last
    batch ``(b, h, w, c)`` with ``h`` and ``w`` divisible by 8 and ``c >= 3``;
    the reshape and channel indexing below require this.
    """
    batch, height, width, channels = x.shape
    blocks = x.reshape(batch, height // 8, 8, width // 8, 8, channels)
    pooled = blocks.mean(axis=(2, 4))
    # Duplicate channel 0 as a fourth channel.
    latent = pooled[..., (0, 1, 2, 0)]
    return latent * 3
def should_process(item: dict, output_dir: str) -> bool:
    """Return True when the output file for ``item`` under ``output_dir`` does not exist yet."""
    target = get_save_path(item["label"], output_dir, get_name(item))
    return not os.path.exists(target)
def main(
    *,
    batch_size: int,
    image_size: int,
    num_workers: int,
    save_threads: int,
    output_dir: str,
    mock_encoding: bool,
) -> None:
    """Encode every streamed ImageNet split and save the quantized codes as PNGs.

    For each split, codes are written under ``<output_dir>/<split>/``.
    Examples whose output file already exists are skipped, so an interrupted
    run can be resumed. With ``mock_encoding`` the VAE is not loaded and
    ``mock_encoder`` is used instead.
    """
    imagenet = stream_imagenet_raw()
    encoder = mock_encoder if mock_encoding else load_encoder()
    batch_processor = partial(process_imagenet_batch, size=image_size)
    for split_name in imagenet:
        split_dir = os.path.join(output_dir, split_name)
        codes_stream = gen_codes(
            dataset_iterator=iter(imagenet[split_name]),
            batch_processor=batch_processor,
            encoder=encoder,
            batch_size=batch_size,
            num_workers=num_workers,
            # Skip examples whose output already exists (resumable runs).
            filter_fn=partial(should_process, output_dir=split_dir),
        )
        save_images_async(
            tqdm.tqdm(codes_stream, desc=split_name),
            output_dir=split_dir,
            max_workers=save_threads,
        )
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # Integer options, registered in this order so --help lists them the same way.
    for flag, default in (
        ("--batch_size", 96),
        ("--image_size", 256),
        ("--num_workers", 6),
        ("--save_threads", 4),
    ):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--output_dir", type=str, default="./data")
    parser.add_argument("--mock_encoding", action="store_true", default=False)
    main(**vars(parser.parse_args()))