"""
Preprocess and load datasets for training.
"""

import base64
import functools
import io
import json
import math
import random
import re

import numpy as np
import torch
import torchvision
import webdataset as wds
from PIL import Image
from scipy.optimize import linear_sum_assignment

from data_utils import *

Image.MAX_IMAGE_PIXELS = 1000000000
N_CHANNELS = 3
MIN_KB = 10
_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000

# Special tokens used to lay out interleaved image/text sequences.
# NOTE(review): the image-token literal was missing from this file (the
# whitespace cleanup was `.replace(" ", "")` and the special-token lookup was
# `.index("")`). "<image>" is restored here on the assumption that it is the
# token registered in the tokenizer's additional_special_tokens -- confirm.
_IMAGE_TOKEN = "<image>"
_EOC_TOKEN = "<|endofchunk|>"

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def preprocess_image(sample, image_processor):
    """
    Convert a sequence of images to a single stacked tensor for training.

    Each element of `sample` is passed through `image_processor`, the results
    are concatenated along a new leading dimension, and a random horizontal
    flip (p=0.5) is applied to the stacked tensor as a whole, so all images of
    one call are flipped together or not at all.
    """
    processed = [image_processor(s).unsqueeze(0) for s in sample]
    stacked = torch.cat(processed, dim=0)
    return torchvision.transforms.RandomHorizontalFlip(p=0.5)(stacked)


def filter_no_caption_or_no_image(sample):
    """
    Filter out LAION samples with no caption or no image.

    Returns True when the sample has a "txt" entry and at least one of the
    "png" / "jpg" / "jpeg" entries.
    """
    return ("txt" in sample) and (
        "png" in sample or "jpg" in sample or "jpeg" in sample
    )


def preprocess_laion_text(sample, tokenizer, max_tokens=32):
    """
    Preprocess text for LAION.

    Captions are stripped, suffixed with the end-of-chunk marker and the
    tokenizer's EOS token, truncated to `max_tokens` (32 by default) and
    padded to the longest caption in the batch.

    Side effect: sets `tokenizer.padding_side` to "right".

    Returns:
        (input_ids, attention_mask) tensors.
    """
    tokenizer.padding_side = "right"
    # NOTE(review): the interleaved preprocessors below mark every image with
    # an image token; captions here carry none. Confirm whether a leading
    # image token is expected for image-caption pairs as well.
    sample = [
        (f"{s.strip()}<|endofchunk|>{tokenizer.eos_token}") for s in sample
    ]
    text = tokenizer(
        sample,
        max_length=max_tokens,
        padding="longest",
        truncation="only_first",
        return_tensors="pt",
    )
    return text["input_ids"], text["attention_mask"]


def _image_token_id(tokenizer):
    """Return the vocabulary id of the image token among the tokenizer's additional special tokens."""
    return tokenizer.additional_special_tokens_ids[
        tokenizer.additional_special_tokens.index(_IMAGE_TOKEN)
    ]


def _truncate_and_pad_images(images_tensors, max_num_images):
    """
    Keep at most `max_num_images` images and zero-pad up to exactly that many.

    Returns the resulting tensor and the number of real (non-padding) images
    that were kept.
    """
    num_kept = min(len(images_tensors), max_num_images)
    images_tensors = images_tensors[:num_kept]
    if len(images_tensors) < max_num_images:
        # pad with all-zero images that match the processed image size
        zero_padding = torch.zeros(
            (
                max_num_images - len(images_tensors),
                N_CHANNELS,
                images_tensors.shape[2],
                images_tensors.shape[3],
            ),
            dtype=torch.float,
        )
        images_tensors = torch.cat((images_tensors, zero_padding), dim=0)
    return images_tensors, num_kept


def _clean_interleaved_text(text):
    """Drop the first end-of-chunk marker and remove spaces adjacent to the special tokens."""
    text = text.replace(_EOC_TOKEN, "", 1)  # but remove first eoc
    # whitespace cleanup
    return (
        text.replace(f" {_EOC_TOKEN}", _EOC_TOKEN)
        .replace(f"{_IMAGE_TOKEN} ", _IMAGE_TOKEN)
        .replace(f" {_IMAGE_TOKEN}", _IMAGE_TOKEN)
    )


def _tokenize_interleaved(text, tokenizer, max_tokens):
    """
    Append the closing end-of-chunk and EOS tokens and tokenize to a fixed length.

    Side effect: sets `tokenizer.padding_side` to "right".
    """
    text = f"{text}{_EOC_TOKEN}{tokenizer.eos_token}"
    tokenizer.padding_side = "right"
    return tokenizer(
        text,
        max_length=max_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )


def preprocess_gpt_interleaved(
    info, tokenizer, clip_processor, min_num_images, max_num_images, max_tokens=256
):
    """
    Preprocess a ChatGPT-generated image-text sequence.

    Args:
        info: decoded sample dict; reads `info["example"]` (text containing
            `_!_IMAGE<k>_!_` placeholders) and `info["image_map"]` (placeholder
            -> dict with a "base64_image" entry).
        tokenizer: tokenizer that has the image token registered as an
            additional special token.
        clip_processor: callable that turns a PIL image into a tensor.
        min_num_images: minimum number of image tokens that must survive
            tokenization.
        max_num_images: number of image slots in the returned image tensor.
        max_tokens: fixed token length of the returned text tensors.

    Returns:
        (images_tensors, (input_ids, attention_mask))

    Raises:
        ValueError: if fewer than `min_num_images` image tokens remain after
            truncation.
    """
    text = info["example"]
    # every image placeholder closes the previous chunk and opens a new image
    text = re.sub(r"_!_IMAGE\d+_!_", f"{_EOC_TOKEN}{_IMAGE_TOKEN}", text)

    # convert images from base64 to PIL
    images = []
    for image_key in range(1, len(info["image_map"]) + 1):
        image_base64 = info["image_map"][f"_!_IMAGE{image_key}_!_"]["base64_image"]
        rawbytes = base64.b64decode(image_base64)
        images.append(Image.open(io.BytesIO(rawbytes)).convert("RGB"))

    # preprocess and pad images
    images_tensors = preprocess_image(images, clip_processor)
    images_tensors, _ = _truncate_and_pad_images(images_tensors, max_num_images)

    # preprocess and tokenize text
    text = _clean_interleaved_text(text)

    # cut the text when it references more images than there are image slots
    indices = [m.start() for m in re.finditer(re.escape(_IMAGE_TOKEN), text)]
    if len(indices) > max_num_images:
        start_index = indices[max_num_images - 1]
        text = text[:start_index]

    text_tensor = _tokenize_interleaved(text, tokenizer, max_tokens)

    # reject sequences with too few images after truncation
    num_images = torch.count_nonzero(
        text_tensor["input_ids"] == _image_token_id(tokenizer)
    )
    if num_images < min_num_images:
        raise ValueError(f"Fewer than {min_num_images} images in sample")

    return (images_tensors, (text_tensor["input_ids"], text_tensor["attention_mask"]))


def preprocess_interleaved(
    sample,
    tokenizer,
    clip_processor,
    sim_threshold,
    min_num_images,
    max_num_images,
    max_tokens=256,
):
    """
    Preprocess an interleaved image-text sequence, either by calling
    preprocess_gpt_interleaved (if the sequence is ChatGPT-generated) or by
    preprocessing in this function (if the sequence is from MMC4).

    Args:
        sample: tuple whose first element is the JSON-encoded sample.
        tokenizer: tokenizer that has the image token registered as an
            additional special token.
        clip_processor: callable that turns a PIL image into a tensor.
        sim_threshold: minimum image-sentence similarity for an image to be
            kept.
        min_num_images: minimum number of image tokens that must survive
            tokenization.
        max_num_images: number of image slots in the returned image tensor.
        max_tokens: fixed token length of the returned text tensors.

    Returns:
        (images_tensors, (input_ids, attention_mask))

    Raises:
        ValueError: if the sample has no usable image, has too few image
            tokens after truncation, is a single-image sample dropped by the
            random 50% filter, or has its only image token in the last
            position.
    """
    info = json.loads(sample[0])
    if "is_gpt" in info:
        return preprocess_gpt_interleaved(
            info, tokenizer, clip_processor, min_num_images, max_num_images, max_tokens
        )

    sentences = info["text_list"]
    sim_matrix = info["similarity_matrix"]

    # load images first to find which ones are valid
    valid_images, valid_image_indices = [], []
    for i, sample_image in enumerate(info["image_info"]):
        if "image_base64" not in sample_image:
            continue
        image_base64 = sample_image["image_base64"]
        rawbytes = base64.b64decode(image_base64)

        # size filter: skip images whose size in whole KB (bytes // 1000) is
        # MIN_KB or less, i.e. anything below (MIN_KB + 1) * 1000 bytes
        if len(rawbytes) // 1000 <= MIN_KB:
            continue

        image = Image.open(io.BytesIO(rawbytes)).convert("RGB")
        valid_images.append(image)
        valid_image_indices.append(i)

    if len(valid_image_indices) == 0:
        raise ValueError("No images in sample")

    sim_matrix = np.array(sim_matrix)  # of shape images x sentences
    sim_matrix = sim_matrix[valid_image_indices]

    # negate the similarities to turn them into costs
    cost_matrix = -sim_matrix
    # find one-to-one assignments between images and sentences
    image_indices, sentence_indices = linear_sum_assignment(cost_matrix)

    images, sentence_ixs = [], []
    for i, sim_ix in zip(image_indices, sentence_indices):
        sim_score = sim_matrix[i][sim_ix]
        if sim_score < sim_threshold:
            continue
        images.append(valid_images[i])
        sentence_ixs.append(sim_ix)

    if len(images) == 0:
        raise ValueError("No images in sample")

    # preprocess and pad images
    images_tensors = preprocess_image(images, clip_processor)
    images_tensors, num_kept = _truncate_and_pad_images(
        images_tensors, max_num_images
    )
    sentence_ixs = sentence_ixs[:num_kept]

    # preprocess and tokenize text
    # add in <image> and <|endofchunk|> tokens in front of each matched sentence
    for ix in sentence_ixs:
        sentences[ix] = f"{_EOC_TOKEN}{_IMAGE_TOKEN}{sentences[ix]}"
    text = " ".join(sentences)
    text = _clean_interleaved_text(text)
    text_tensor = _tokenize_interleaved(text, tokenizer, max_tokens)

    # reject sequences with too few images (after truncation)
    image_token_id = _image_token_id(tokenizer)
    num_images = torch.count_nonzero(text_tensor["input_ids"] == image_token_id)
    if num_images < min_num_images:
        raise ValueError(f"Fewer than {min_num_images} images in sample")
    elif (
        num_images == 1 and random.random() <= 0.5
    ):  # 50% chance of keeping single image samples
        raise ValueError("Only one image in sample")

    # avoid the situation where there's one <image> token and it's at the end
    if num_images == 1 and text_tensor["input_ids"][:, -1] == image_token_id:
        raise ValueError(
            "Only one image at the end of sample, so labels will all be -100"
        )

    return (
        images_tensors,
        (text_tensor["input_ids"], text_tensor["attention_mask"]),
    )


def _resolve_dataset_size(input_shards, configured_num_samples):
    """
    Return (num_samples, num_shards) for a training dataset.

    Only the shard count is taken from get_dataset_size; the sample count it
    reports is discarded and the configured value is always used.

    Raises:
        RuntimeError: if no sample count is configured.
    """
    _, num_shards = get_dataset_size(input_shards)
    num_samples = configured_num_samples
    if not num_samples:
        raise RuntimeError(
            "Currently, number of dataset samples must be specified for training dataset. "
            "Please specify via `--train-num-samples` if no dataset length info present."
        )
    return num_samples, num_shards


def _shard_stages(args, input_shards, resampled, shared_epoch):
    """Build the pipeline stages that yield the shards assigned to each worker."""
    if resampled:
        return [
            ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)
        ]
    # at this point we have an iterator over all the shards
    return [
        wds.SimpleShardList(input_shards),
        detshuffle2(
            bufsize=_SHARD_SHUFFLE_SIZE,
            initial=_SHARD_SHUFFLE_INITIAL,
            seed=args.seed,
            epoch=shared_epoch,
        ),
        wds.split_by_node,
        wds.split_by_worker,
    ]


def _sample_stages():
    """Build the pipeline stages that expand shards into shuffled samples."""
    return [
        # at this point, we have an iterator over the shards assigned to each worker at each node
        tarfile_to_samples_nothrow,
        wds.shuffle(
            bufsize=_SAMPLE_SHUFFLE_SIZE,
            initial=_SAMPLE_SHUFFLE_INITIAL,
        ),
    ]


def _build_data_info(
    pipeline, args, batch_size, num_samples, num_shards, resampled, shared_epoch, floor
):
    """
    Wrap pipeline stages in a DataPipeline / WebLoader and return a DataInfo.

    The per-worker epoch length is rounded (floor or ceil, per `floor`) so
    that every dataloader worker on every node yields the same number of full
    batches; `num_batches` and `num_samples` are attached to the dataloader.
    """
    dataset = wds.DataPipeline(*pipeline)
    if not resampled:
        assert (
            num_shards >= args.workers * args.world_size
        ), "number of shards must be >= total workers"

    # roll over and repeat a few samples to get same number of full batches on each node
    round_fn = math.floor if floor else math.ceil
    global_batch_size = batch_size * args.world_size
    num_batches = round_fn(num_samples / global_batch_size)
    num_workers = max(1, args.workers)
    num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
    num_batches = num_worker_batches * num_workers
    num_samples = num_batches * global_batch_size

    # each worker is iterating over this
    dataset = dataset.with_epoch(num_worker_batches)
    dataloader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=True,
    )

    # add meta-data to dataloader instance for convenience
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)


def get_mmc4_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """
    Initialize webdataset for MMC4 / ChatGPT sequences.

    Reads `mmc4_shards`, `train_num_samples_mmc4`, `mmc4_textsim_threshold`,
    `mmc4_min_num_images`, `mmc4_max_num_images`, `batch_size_mmc4`, `seed`,
    `workers`, `world_size` and (optionally) `dataset_resampled` from `args`.

    Returns:
        DataInfo holding the dataloader and the shared epoch counter.
    """
    input_shards = args.mmc4_shards
    assert input_shards is not None
    resampled = getattr(args, "dataset_resampled", False)

    num_samples, num_shards = _resolve_dataset_size(
        input_shards, args.train_num_samples_mmc4
    )

    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)

    preprocess_fn = functools.partial(
        preprocess_interleaved,
        clip_processor=image_processor,
        tokenizer=tokenizer,
        sim_threshold=args.mmc4_textsim_threshold,
        min_num_images=args.mmc4_min_num_images,
        max_num_images=args.mmc4_max_num_images,
    )

    pipeline = _shard_stages(args, input_shards, resampled, shared_epoch)
    pipeline.extend(_sample_stages())
    pipeline.extend(
        [
            wds.to_tuple("json", handler=log_and_continue),
            wds.map(preprocess_fn, handler=log_and_continue),
            wds.batched(args.batch_size_mmc4, partial=False),
        ]
    )

    return _build_data_info(
        pipeline,
        args,
        args.batch_size_mmc4,
        num_samples,
        num_shards,
        resampled,
        shared_epoch,
        floor,
    )


def get_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """
    Initialize webdataset for LAION data.

    Reads `laion_shards`, `train_num_samples_laion`, `batch_size_laion`,
    `seed`, `workers`, `world_size` and (optionally) `dataset_resampled` from
    `args`.

    Returns:
        DataInfo holding the dataloader and the shared epoch counter.
    """
    input_shards = args.laion_shards
    assert input_shards is not None
    resampled = getattr(args, "dataset_resampled", False)

    num_samples, num_shards = _resolve_dataset_size(
        input_shards, args.train_num_samples_laion
    )

    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)

    # create two preprocess functions that take in the passed in image_processor and tokenizer
    preprocess_image_fn = functools.partial(
        preprocess_image, image_processor=image_processor
    )
    preprocess_text_fn = functools.partial(preprocess_laion_text, tokenizer=tokenizer)

    pipeline = _shard_stages(args, input_shards, resampled, shared_epoch)
    pipeline.extend(_sample_stages())
    pipeline.extend(
        [
            wds.select(filter_no_caption_or_no_image),
            wds.decode("pilrgb", handler=log_and_continue),
            wds.to_tuple("jpg;png;jpeg", "txt", handler=log_and_continue),
            # batch first: both preprocess functions operate on whole batches
            wds.batched(args.batch_size_laion, partial=False),
            wds.map_tuple(
                preprocess_image_fn, preprocess_text_fn, handler=log_and_continue
            ),
        ]
    )

    return _build_data_info(
        pipeline,
        args,
        args.batch_size_laion,
        num_samples,
        num_shards,
        resampled,
        shared_epoch,
        floor,
    )


def get_dataset_fn(dataset_type):
    """
    Helper function to get the dataset function based on the dataset type.

    Raises:
        ValueError: if `dataset_type` is neither "image_text" nor "mmc4".
    """
    if dataset_type == "image_text":
        return get_laion_dataset
    elif dataset_type == "mmc4":
        return get_mmc4_dataset
    else:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")


def get_data(args, image_processor, tokenizer, dataset_type, epoch=0):
    """
    Interface for getting the webdatasets.
    """
    return get_dataset_fn(dataset_type)(
        args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer
    )