aw1app's picture
Deploy real interactive model
c6cb161
"""
Preprocess and load datasets for training.
"""
import functools
import io
import json
import math
import re
import random
import numpy as np
import torch
import torchvision
import webdataset as wds
from PIL import Image
import base64
from scipy.optimize import linear_sum_assignment
from data_utils import *
# Raise Pillow's pixel-count ceiling (its decompression-bomb guard) so that
# very large images can still be opened.
Image.MAX_IMAGE_PIXELS = 1000000000
# Channel dimension used for the zero-padding image tensors in preprocess_interleaved.
N_CHANNELS = 3
# Size threshold for the MMC4 image filter: preprocess_interleaved skips an
# image when `len(rawbytes) // 1000 <= MIN_KB`.
MIN_KB = 10
# `bufsize` / `initial` passed to detshuffle2 (shard-level shuffle).
_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
# `bufsize` / `initial` passed to wds.shuffle (sample-level shuffle).
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000
try:
import horovod.torch as hvd
except ImportError:
hvd = None
def preprocess_image(sample, image_processor):
    """
    Turn a sequence of images into one batched tensor for training.

    Every element of ``sample`` is run through ``image_processor``; the
    results are concatenated along a new leading dimension. A random
    horizontal flip (p=0.5) is then applied to the concatenated batch in a
    single call.
    """
    processed = []
    for img in sample:
        processed.append(image_processor(img).unsqueeze(0))
    batch = torch.cat(processed, dim=0)
    flip = torchvision.transforms.RandomHorizontalFlip(p=0.5)
    return flip(batch)
def filter_no_caption_or_no_image(sample):
    """
    Keep only LAION samples that have both a caption and an image.

    Returns True when ``sample`` contains a "txt" key and at least one of
    the image keys "png", "jpg" or "jpeg"; False otherwise.
    """
    if "txt" not in sample:
        return False
    return any(ext in sample for ext in ("png", "jpg", "jpeg"))
def preprocess_laion_text(sample, tokenizer, max_tokens=32):
    """
    Tokenize a batch of LAION captions.

    Each caption is stripped, prefixed with ``<image>`` and suffixed with
    ``<|endofchunk|>`` plus the tokenizer's EOS token. The batch is padded on
    the right to its longest member and truncated to ``max_tokens`` tokens
    (32 by default).

    Returns the tokenizer's ``input_ids`` and ``attention_mask``.
    """
    tokenizer.padding_side = "right"
    captions = []
    for caption in sample:
        captions.append(
            f"<image>{caption.strip()}<|endofchunk|>{tokenizer.eos_token}"
        )
    encoded = tokenizer(
        captions,
        max_length=max_tokens,
        padding="longest",
        truncation="only_first",
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    return input_ids, attention_mask
def preprocess_gpt_interleaved(
    info, tokenizer, clip_processor, min_num_images, max_num_images, max_tokens=256
):
    """
    Preprocess a ChatGPT-generated image-text sequence.

    Args:
        info: decoded sample dict. Reads ``info["example"]`` (text containing
            ``_!_IMAGE<k>_!_`` placeholders) and ``info["image_map"]``
            (placeholder -> dict holding a ``"base64_image"`` string).
        tokenizer: callable tokenizer exposing ``eos_token``,
            ``padding_side``, ``additional_special_tokens`` and
            ``additional_special_tokens_ids``; ``<image>`` must be one of its
            additional special tokens.
        clip_processor: per-image processor forwarded to ``preprocess_image``.
        min_num_images: minimum number of ``<image>`` tokens that must survive
            tokenization for the sample to be kept.
        max_num_images: number of image slots in the output; extra images are
            dropped and missing ones are zero-padded.
        max_tokens: length the text is padded / truncated to.

    Returns:
        ``(images_tensors, (input_ids, attention_mask))`` where
        ``images_tensors`` has ``max_num_images`` entries along dim 0.

    Raises:
        ValueError: if the sample has no images, or fewer than
            ``min_num_images`` ``<image>`` tokens remain after truncation.
    """
    eoc = "<|endofchunk|>"
    text = re.sub(r"_!_IMAGE\d+_!_", "<|endofchunk|><image>", info["example"])

    # convert images from base64 to PIL; keys are looked up as IMAGE1..IMAGEn
    image_map = info["image_map"]
    images = []
    for image_key in range(1, len(image_map) + 1):
        image_base64 = image_map[f"_!_IMAGE{image_key}_!_"]["base64_image"]
        rawbytes = base64.b64decode(image_base64)
        images.append(Image.open(io.BytesIO(rawbytes)).convert("RGB"))
    if not images:
        # fail with the same explicit error the MMC4 path uses instead of an
        # opaque error from concatenating an empty list of tensors
        raise ValueError("No images in sample")

    # preprocess and pad images
    images_tensors = preprocess_image(images, clip_processor)[:max_num_images]
    num_padding = max_num_images - len(images_tensors)
    if num_padding > 0:
        # pad with the shape the processor actually produced rather than a
        # hard-coded 3x224x224
        zero_padding = torch.zeros(
            (num_padding, *images_tensors.shape[1:]), dtype=torch.float
        )
        images_tensors = torch.cat((images_tensors, zero_padding), dim=0)

    # preprocess and tokenize text
    text = text.replace(eoc, "", 1)  # but remove first eoc
    # whitespace cleanup
    text = (
        text.replace(" <|endofchunk|>", "<|endofchunk|>")
        .replace("<image> ", "<image>")
        .replace(" <image>", "<image>")
    )

    # keep at most max_num_images <image> tokens: cut right before the first
    # surplus one so the text matches the image tensors kept above
    image_starts = [m.start() for m in re.finditer("<image>", text)]
    if len(image_starts) > max_num_images:
        text = text[: image_starts[max_num_images]]
        # the cut leaves the eoc that preceded the dropped <image>; remove it
        # so the terminator appended below is not duplicated
        if text.endswith(eoc):
            text = text[: -len(eoc)]

    text = f"{text}<|endofchunk|>{tokenizer.eos_token}"
    tokenizer.padding_side = "right"
    text_tensor = tokenizer(
        text,
        max_length=max_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # reject sequences with too few images after truncation
    image_token_id = tokenizer.additional_special_tokens_ids[
        tokenizer.additional_special_tokens.index("<image>")
    ]
    num_images = torch.count_nonzero(text_tensor["input_ids"] == image_token_id)
    if num_images < min_num_images:
        raise ValueError(f"Fewer than {min_num_images} images in sample")

    return (images_tensors, (text_tensor["input_ids"], text_tensor["attention_mask"]))
def preprocess_interleaved(
    sample,
    tokenizer,
    clip_processor,
    sim_threshold,
    min_num_images,
    max_num_images,
    max_tokens=256,
):
    """
    Preprocess an interleaved image-text sequence, either by calling
    preprocess_gpt_interleaved (if the sequence is ChatGPT-generated) or by
    preprocessing in this function (if the sequence is from MMC4).

    Args:
        sample: tuple whose first element is the sample's JSON payload.
        tokenizer: callable tokenizer exposing ``eos_token``,
            ``padding_side``, ``additional_special_tokens`` and
            ``additional_special_tokens_ids``; ``<image>`` must be one of its
            additional special tokens.
        clip_processor: per-image processor forwarded to ``preprocess_image``.
        sim_threshold: minimum image-sentence similarity for an assigned
            image to be kept.
        min_num_images: minimum number of ``<image>`` tokens that must survive
            tokenization for the sample to be kept.
        max_num_images: number of image slots in the output; extra images are
            dropped and missing ones are zero-padded.
        max_tokens: length the text is padded / truncated to.

    Returns:
        ``(images_tensors, (input_ids, attention_mask))``.

    Raises:
        ValueError: if no usable image remains, too few ``<image>`` tokens
            survive truncation, a single-image sample is randomly dropped, or
            the only ``<image>`` token is the last token.
    """
    info = json.loads(sample[0])
    if "is_gpt" in info:
        return preprocess_gpt_interleaved(
            info, tokenizer, clip_processor, min_num_images, max_num_images, max_tokens
        )

    sentences = info["text_list"]
    sim_matrix = info["similarity_matrix"]

    # load images first to find which ones are valid
    valid_images, valid_image_indices = [], []
    for i, sample_image in enumerate(info["image_info"]):
        if "image_base64" not in sample_image:
            continue
        rawbytes = base64.b64decode(sample_image["image_base64"])
        # drop small images: anything whose size in whole thousands of bytes
        # is at most MIN_KB is skipped
        if len(rawbytes) // 1000 <= MIN_KB:
            continue
        valid_images.append(Image.open(io.BytesIO(rawbytes)).convert("RGB"))
        valid_image_indices.append(i)

    if len(valid_image_indices) == 0:
        raise ValueError("No images in sample")

    sim_matrix = np.array(sim_matrix)  # of shape images x sentences
    sim_matrix = sim_matrix[valid_image_indices]

    # negate the similarities to turn them into costs
    cost_matrix = -sim_matrix
    # find one-to-one assignments
    image_indices, sentence_indices = linear_sum_assignment(cost_matrix)

    # collect (sentence index, image) pairs that clear the similarity threshold
    matches = []
    for i, sim_ix in zip(image_indices, sentence_indices):
        sim_score = sim_matrix[i][sim_ix]
        if sim_score < sim_threshold:
            continue
        matches.append((sim_ix, valid_images[i]))

    if len(matches) == 0:
        raise ValueError("No images in sample")

    # Order images by where their sentence appears in the text. The <image>
    # tokens are emitted in sentence order below, so the k-th image tensor
    # must belong to the k-th <image> token; this also makes the
    # max_num_images cut keep the images that come first in the text.
    matches.sort(key=lambda match: match[0])
    sentence_ixs = [sentence_ix for sentence_ix, _ in matches]
    images = [image for _, image in matches]

    # preprocess and pad images
    images_tensors = preprocess_image(images, clip_processor)
    num_kept = min(len(images_tensors), max_num_images)
    images_tensors = images_tensors[:num_kept]
    sentence_ixs = sentence_ixs[:num_kept]

    if len(images_tensors) < max_num_images:
        zero_padding = torch.zeros(
            (
                max_num_images - len(images_tensors),
                N_CHANNELS,
                images_tensors[0].shape[1],
                images_tensors[0].shape[2],
            ),
            dtype=torch.float,
        )
        images_tensors = torch.cat((images_tensors, zero_padding), dim=0)

    # preprocess and tokenize text
    # add in <image> and <eoc> tokens
    for ix in sentence_ixs:
        sentences[ix] = f"<|endofchunk|><image>{sentences[ix]}"
    text = " ".join(sentences)
    text = text.replace("<|endofchunk|>", "", 1)  # but remove first eoc
    # whitespace cleanup
    text = (
        text.replace(" <|endofchunk|>", "<|endofchunk|>")
        .replace("<image> ", "<image>")
        .replace(" <image>", "<image>")
    )
    text = f"{text}<|endofchunk|>{tokenizer.eos_token}"
    tokenizer.padding_side = "right"
    text_tensor = tokenizer(
        text,
        max_length=max_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # reject sequences with too few images (after truncation)
    image_token_id = tokenizer.additional_special_tokens_ids[
        tokenizer.additional_special_tokens.index("<image>")
    ]
    num_images = torch.count_nonzero(text_tensor["input_ids"] == image_token_id)
    if num_images < min_num_images:
        raise ValueError(f"Fewer than {min_num_images} images in sample")
    elif (
        num_images == 1 and random.random() <= 0.5
    ):  # 50% chance of keeping single image samples
        raise ValueError("Only one image in sample")

    # avoid the situation where there's one <image> token and it's at the end
    if num_images == 1 and text_tensor["input_ids"][:, -1] == image_token_id:
        raise ValueError(
            "Only one image at the end of sample, so labels will all be -100"
        )

    return (
        images_tensors,
        (text_tensor["input_ids"], text_tensor["attention_mask"]),
    )
def get_mmc4_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """
    Initialize webdataset for MMC4 / ChatGPT sequences

    Args:
        args: run configuration. Reads ``mmc4_shards``,
            ``train_num_samples_mmc4``, ``mmc4_textsim_threshold``,
            ``mmc4_min_num_images``, ``mmc4_max_num_images``,
            ``batch_size_mmc4``, ``seed``, ``workers``, ``world_size`` and,
            optionally, ``dataset_resampled``.
        image_processor: per-image processor passed to preprocess_interleaved.
        tokenizer: tokenizer passed to preprocess_interleaved.
        epoch: initial value of the shared epoch counter.
        floor: round batch counts down instead of up.

    Returns:
        A ``DataInfo`` wrapping the dataloader and the shared epoch; the
        dataloader additionally carries ``num_batches`` and ``num_samples``.

    Raises:
        RuntimeError: if ``args.train_num_samples_mmc4`` is not set.
    """
    input_shards = args.mmc4_shards
    assert input_shards is not None

    resampled = getattr(args, "dataset_resampled", False)

    # Only the shard count is taken from the shard spec; the number of samples
    # per epoch always comes from args.
    _, num_shards = get_dataset_size(input_shards)
    num_samples = args.train_num_samples_mmc4
    if not num_samples:
        raise RuntimeError(
            "Currently, number of dataset samples must be specified for training dataset. "
            "Please specify via `train_num_samples_mmc4`."
        )

    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)
    if resampled:
        pipeline = [
            ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)
        ]
    else:
        pipeline = [wds.SimpleShardList(input_shards)]

    preprocess_fn = functools.partial(
        preprocess_interleaved,
        clip_processor=image_processor,
        tokenizer=tokenizer,
        sim_threshold=args.mmc4_textsim_threshold,
        min_num_images=args.mmc4_min_num_images,
        max_num_images=args.mmc4_max_num_images,
    )

    # at this point we have an iterator over all the shards
    if not resampled:
        pipeline.extend(
            [
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=args.seed,
                    epoch=shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
            ]
        )

    pipeline.extend(
        [
            # at this point, we have an iterator over the shards assigned to each worker at each node
            tarfile_to_samples_nothrow,
            wds.shuffle(
                bufsize=_SAMPLE_SHUFFLE_SIZE,
                initial=_SAMPLE_SHUFFLE_INITIAL,
            ),
        ]
    )

    pipeline.extend(
        [
            wds.to_tuple("json", handler=log_and_continue),
            wds.map(preprocess_fn, handler=log_and_continue),
            wds.batched(args.batch_size_mmc4, partial=False),
        ]
    )

    dataset = wds.DataPipeline(*pipeline)
    if not resampled:
        assert (
            num_shards >= args.workers * args.world_size
        ), "number of shards must be >= total workers"

    # roll over and repeat a few samples to get same number of full batches on each node
    round_fn = math.floor if floor else math.ceil
    global_batch_size = args.batch_size_mmc4 * args.world_size
    num_batches = round_fn(num_samples / global_batch_size)
    num_workers = max(1, args.workers)
    num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
    num_batches = num_worker_batches * num_workers
    num_samples = num_batches * global_batch_size

    # each worker is iterating over this
    dataset = dataset.with_epoch(num_worker_batches)

    dataloader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        # persistent workers are rejected by the torch DataLoader when
        # num_workers == 0, so only request them when workers exist
        persistent_workers=args.workers > 0,
    )

    # add meta-data to dataloader instance for convenience
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
def get_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
    """
    Initialize webdataset for LAION data

    Args:
        args: run configuration. Reads ``laion_shards``,
            ``train_num_samples_laion``, ``batch_size_laion``, ``seed``,
            ``workers``, ``world_size`` and, optionally,
            ``dataset_resampled``.
        image_processor: per-image processor passed to preprocess_image.
        tokenizer: tokenizer passed to preprocess_laion_text.
        epoch: initial value of the shared epoch counter.
        floor: round batch counts down instead of up.

    Returns:
        A ``DataInfo`` wrapping the dataloader and the shared epoch; the
        dataloader additionally carries ``num_batches`` and ``num_samples``.

    Raises:
        RuntimeError: if ``args.train_num_samples_laion`` is not set.
    """
    input_shards = args.laion_shards
    assert input_shards is not None

    resampled = getattr(args, "dataset_resampled", False)

    # Only the shard count is taken from the shard spec; the number of samples
    # per epoch always comes from args.
    _, num_shards = get_dataset_size(input_shards)
    num_samples = args.train_num_samples_laion
    if not num_samples:
        raise RuntimeError(
            "Currently, number of dataset samples must be specified for training dataset. "
            "Please specify via `train_num_samples_laion`."
        )

    # create a shared epoch store to sync epoch to dataloader worker proc
    shared_epoch = SharedEpoch(epoch=epoch)
    if resampled:
        pipeline = [
            ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)
        ]
    else:
        pipeline = [wds.SimpleShardList(input_shards)]

    # create two preprocess functions that take in the passed in image_processor and tokenizer
    preprocess_image_fn = functools.partial(
        preprocess_image, image_processor=image_processor
    )
    preprocess_text_fn = functools.partial(preprocess_laion_text, tokenizer=tokenizer)

    # at this point we have an iterator over all the shards
    if not resampled:
        pipeline.extend(
            [
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=args.seed,
                    epoch=shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
            ]
        )

    pipeline.extend(
        [
            # at this point, we have an iterator over the shards assigned to each worker at each node
            tarfile_to_samples_nothrow,
            wds.shuffle(
                bufsize=_SAMPLE_SHUFFLE_SIZE,
                initial=_SAMPLE_SHUFFLE_INITIAL,
            ),
        ]
    )

    pipeline.extend(
        [
            wds.select(filter_no_caption_or_no_image),
            wds.decode("pilrgb", handler=log_and_continue),
            wds.to_tuple("jpg;png;jpeg", "txt", handler=log_and_continue),
            # batch first: both preprocess functions operate on whole batches
            wds.batched(args.batch_size_laion, partial=False),
            wds.map_tuple(
                preprocess_image_fn, preprocess_text_fn, handler=log_and_continue
            ),
        ]
    )

    dataset = wds.DataPipeline(*pipeline)
    if not resampled:
        assert (
            num_shards >= args.workers * args.world_size
        ), "number of shards must be >= total workers"

    # roll over and repeat a few samples to get same number of full batches on each node
    round_fn = math.floor if floor else math.ceil
    global_batch_size = args.batch_size_laion * args.world_size
    num_batches = round_fn(num_samples / global_batch_size)
    num_workers = max(1, args.workers)
    num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
    num_batches = num_worker_batches * num_workers
    num_samples = num_batches * global_batch_size

    # each worker is iterating over this
    dataset = dataset.with_epoch(num_worker_batches)

    dataloader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        # persistent workers are rejected by the torch DataLoader when
        # num_workers == 0, so only request them when workers exist
        persistent_workers=args.workers > 0,
    )

    # add meta-data to dataloader instance for convenience
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
def get_dataset_fn(dataset_type):
    """
    Look up the dataset constructor for a dataset type.

    "image_text" maps to get_laion_dataset and "mmc4" to get_mmc4_dataset;
    any other value raises ValueError.
    """
    if dataset_type == "mmc4":
        return get_mmc4_dataset
    if dataset_type == "image_text":
        return get_laion_dataset
    raise ValueError(f"Unsupported dataset type: {dataset_type}")
def get_data(args, image_processor, tokenizer, dataset_type, epoch=0):
    """
    Build the webdataset for ``dataset_type``.

    Resolves the constructor through get_dataset_fn and calls it with the
    given args, image processor, tokenizer and starting epoch.
    """
    build_dataset = get_dataset_fn(dataset_type)
    return build_dataset(
        args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer
    )