import ast
import json
import logging
import math
import os
import random
import sys
from dataclasses import dataclass
from multiprocessing import Value

import braceexpand
import numpy as np
import pandas as pd
import torch
import torchvision.datasets as datasets
import webdataset as wds
from PIL import Image
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
from webdataset.filters import _shuffle
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


class CsvDataset(Dataset):
    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None):
        logging.debug(f'Loading csv data from {input_filename}.')
        df = pd.read_csv(input_filename, sep=sep)

        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.transforms = transforms
        logging.debug('Done loading data.')

        self.tokenize = tokenizer

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        images = self.transforms(Image.open(str(self.images[idx])))
        texts = self.tokenize([str(self.captions[idx])])[0]
        return images, texts

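# A minimal sketch of the file CsvDataset expects: one row per image/caption pair,
# with columns named by img_key/caption_key and separated by `sep` (tab by default).
# The paths, captions, and column names below are hypothetical:
#
#   filepath <tab> title
#   /data/images/000001.jpg <tab> A dog catching a frisbee.
#
#   dataset = CsvDataset("train.tsv", preprocess, img_key="filepath",
#                        caption_key="title", sep="\t", tokenizer=tokenize)
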
class SharedEpoch:
    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value


@dataclass
class DataInfo:
    dataloader: DataLoader
    sampler: DistributedSampler = None
    shared_epoch: SharedEpoch = None

    def set_epoch(self, epoch):
        if self.shared_epoch is not None:
            self.shared_epoch.set_value(epoch)
        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)


def expand_urls(urls, weights=None):
    if weights is None:
        expanded_urls = wds.shardlists.expand_urls(urls)
        return expanded_urls, None
    if isinstance(urls, str):
        urllist = urls.split("::")
        weights = weights.split('::')
        assert len(weights) == len(urllist), \
            f"Expected the number of data components ({len(urllist)}) and weights ({len(weights)}) to match."
        weights = [float(weight) for weight in weights]
        all_urls, all_weights = [], []
        for url, weight in zip(urllist, weights):
            expanded_url = list(braceexpand.braceexpand(url))
            expanded_weights = [weight for _ in expanded_url]
            all_urls.extend(expanded_url)
            all_weights.extend(expanded_weights)
        return all_urls, all_weights
    else:
        all_urls = list(urls)
        return all_urls, weights

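# A sketch of how expand_urls handles the `::`-separated weighted form (the shard
# paths below are hypothetical):
#
#   expand_urls("a/{00..01}.tar::b/00.tar", weights="1.0::2.0")
#   # -> (['a/00.tar', 'a/01.tar', 'b/00.tar'], [1.0, 1.0, 2.0])
#
# With weights=None the string is expanded by wds.shardlists.expand_urls instead.
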
def get_dataset_size(shards):
    shards_list, _ = expand_urls(shards)
    dir_path = os.path.dirname(shards_list[0])
    sizes_filename = os.path.join(dir_path, 'sizes.json')
    len_filename = os.path.join(dir_path, '__len__')
    if os.path.exists(sizes_filename):
        sizes = json.load(open(sizes_filename, 'r'))
        total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
    elif os.path.exists(len_filename):
        # FIXME this used to be eval(open(...)) but that seemed rather unsafe
        total_size = ast.literal_eval(open(len_filename, 'r').read())
    else:
        total_size = None  # num samples undefined
    num_shards = len(shards_list)
    return total_size, num_shards

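# A minimal sketch of the sidecar files get_dataset_size looks for, assuming the
# shards live in data/ (filenames and counts below are hypothetical):
#
#   data/sizes.json   -> {"00.tar": 10000, "01.tar": 9871}
#   data/__len__      -> 19871
#
# If neither file exists, total_size is None and the sample count must be supplied
# on the command line (e.g. --train-num-samples).
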
def get_imagenet(args, preprocess_fns, split):
    assert split in ["train", "val", "v2"]
    is_train = split == "train"
    preprocess_train, preprocess_val = preprocess_fns

    if split == "v2":
        from imagenetv2_pytorch import ImageNetV2Dataset
        dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
    else:
        if is_train:
            data_path = args.imagenet_train
            preprocess_fn = preprocess_train
        else:
            data_path = args.imagenet_val
            preprocess_fn = preprocess_val
        assert data_path

        dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)

    if is_train:
        # build a subset sampler that randomly keeps k=50 images per class
        idxs = np.zeros(len(dataset.targets))
        target_array = np.array(dataset.targets)
        k = 50
        for c in range(1000):
            m = target_array == c
            n = len(idxs[m])
            arr = np.zeros(n)
            arr[:k] = 1
            np.random.shuffle(arr)
            idxs[m] = arr

        idxs = idxs.astype('int')
        sampler = SubsetRandomSampler(np.where(idxs)[0])
    else:
        sampler = None

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        sampler=sampler,
    )

    return DataInfo(dataloader=dataloader, sampler=sampler)


def count_samples(dataloader):
    os.environ["WDS_EPOCH"] = "0"
    n_elements, n_batches = 0, 0
    for images, texts in dataloader:
        n_batches += 1
        n_elements += len(images)
        assert len(images) == len(texts)
    return n_elements, n_batches


def filter_no_caption_or_no_image(sample):
    has_caption = ('txt' in sample)
    has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample)
    return has_caption and has_image

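# Webdataset samples are dicts keyed by file suffix, so for instance (keys and
# values below are hypothetical):
#
#   {"__key__": "000123", "jpg": <bytes>, "txt": b"a caption"}   -> kept
#   {"__key__": "000124", "json": <bytes>}                       -> dropped
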
def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True


def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME the webdataset version throws if suffix is already in current_sample, but this can
        #  happen in the current LAION-400M dataset if a tar ends with the same prefix as the next
        #  one begins; rare, but possible since prefixes aren't unique across tar files there
        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample


def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    # NOTE: a re-implementation of the webdataset impl, using a group_by_keys that doesn't throw
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples


def pytorch_worker_seed(increment=0):
    """get dataloader worker seed from pytorch"""
    worker_info = get_worker_info()
    if worker_info is not None:
        # favour the seed already created for pytorch dataloader workers if it exists
        seed = worker_info.seed
        if increment:
            # space out seed increments so they can't overlap across workers in different iterations
            seed += increment * max(1, worker_info.num_workers)
        return seed
    # fallback to wds rank-based seed
    return wds.utils.pytorch_worker_seed()

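# Why multiply the increment by num_workers: pytorch assigns consecutive base seeds
# to workers in one iteration (they differ by 1), so stepping each worker's seed by
# num_workers per epoch keeps every (worker, epoch) pair on a distinct seed. A
# hypothetical example with num_workers=4 and base seeds 100..103:
#
#   worker 0: 100, 104, 108, ...   worker 1: 101, 105, 109, ...  (no collisions)
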
_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000


class detshuffle2(wds.PipelineStage):
    def __init__(
        self,
        bufsize=1000,
        initial=100,
        seed=0,
        epoch=-1,
    ):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        rng = random.Random()
        if self.seed < 0:
            # if the seed is negative, use the worker's seed; this differs across all nodes/workers
            seed = pytorch_worker_seed(epoch)
        else:
            # this seed is deterministic AND the same across all nodes/workers in each epoch
            seed = self.seed + epoch
        rng.seed(seed)
        return _shuffle(src, self.bufsize, self.initial, rng)

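# A sketch of what detshuffle2 provides (shard paths hypothetical): with a fixed,
# non-negative seed, every node/worker derives the same rng for a given epoch, so
# the shard order matches everywhere before split_by_node partitions it:
#
#   stage = detshuffle2(bufsize=_SHARD_SHUFFLE_SIZE, initial=_SHARD_SHUFFLE_INITIAL,
#                       seed=0, epoch=SharedEpoch(0))
#   shards = list(stage.run(wds.SimpleShardList("data/{00..09}.tar")))
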
class ResampledShards2(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        weights=None,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls, weights = expand_urls(urls, weights)
        self.urls = urls
        self.weights = weights
        if self.weights is not None:
            assert len(self.urls) == len(self.weights), \
                f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match."
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = worker_seed
        self.deterministic = deterministic
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over the shards."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        if self.deterministic:
            # reset the seed with the epoch if deterministic
            if self.worker_seed is None:
                # pytorch worker seed should be deterministic, being derived from seed + rank + worker id
                seed = pytorch_worker_seed(epoch)
            else:
                seed = self.worker_seed() + epoch
            self.rng.seed(seed)
        for _ in range(self.nshards):
            if self.weights is None:
                yield dict(url=self.rng.choice(self.urls))
            else:
                yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0])

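# A sketch of weighted resampling with ResampledShards2 (paths and weights are
# hypothetical; shards from b/ are drawn twice as often as shards from a/):
#
#   shards = ResampledShards2("a/{00..09}.tar::b/{00..09}.tar", weights="1.0::2.0")
#   it = iter(shards)
#   urls = [next(it)["url"] for _ in range(4)]  # e.g. ['b/03.tar', 'a/07.tar', ...]
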
def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
    input_shards = args.train_data if is_train else args.val_data
    assert input_shards is not None
    resampled = getattr(args, 'dataset_resampled', False) and is_train

    num_shards = None
    if is_train:
        if args.train_num_samples is not None:
            num_samples = args.train_num_samples
        else:
            num_samples, num_shards = get_dataset_size(input_shards)
            if not num_samples:
                raise RuntimeError(
                    'Currently, the number of dataset samples must be specified for the training dataset. '
                    'Please specify it via `--train-num-samples` if no dataset length info is present.')
    else:
        # eval will just exhaust the iterator if the size is not specified
        num_samples = args.val_num_samples or 0

    shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker procs

    if is_train and args.train_data_upsampling_factors is not None:
        assert resampled, "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."

    if resampled:
        pipeline = [ResampledShards2(
            input_shards,
            weights=args.train_data_upsampling_factors,
            deterministic=True,
            epoch=shared_epoch,
        )]
    else:
        pipeline = [wds.SimpleShardList(input_shards)]

    # at this point we have an iterator over all the shards
    if is_train:
        if not resampled:
            pipeline.extend([
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=args.seed,
                    epoch=shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
            ])
        pipeline.extend([
            # at this point, we have an iterator over the shards assigned to each worker at each node
            tarfile_to_samples_nothrow,  # wds.tarfile_to_samples(handler=log_and_continue),
            wds.shuffle(
                bufsize=_SAMPLE_SHUFFLE_SIZE,
                initial=_SAMPLE_SHUFFLE_INITIAL,
            ),
        ])
    else:
        pipeline.extend([
            wds.split_by_worker,
            # at this point, we have an iterator over the shards assigned to each worker
            wds.tarfile_to_samples(handler=log_and_continue),
        ])
    pipeline.extend([
        wds.select(filter_no_caption_or_no_image),
        wds.decode("pilrgb", handler=log_and_continue),
        wds.rename(image="jpg;png;jpeg;webp", text="txt"),
        wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]),
        wds.to_tuple("image", "text"),
        wds.batched(args.batch_size, partial=not is_train)
    ])

    dataset = wds.DataPipeline(*pipeline)

    if is_train:
        if not resampled:
            num_shards = num_shards or len(expand_urls(input_shards)[0])
            assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
        # roll over and repeat a few samples to get the same number of full batches on each node
        round_fn = math.floor if floor else math.ceil
        global_batch_size = args.batch_size * args.world_size
        num_batches = round_fn(num_samples / global_batch_size)
        num_workers = max(1, args.workers)
        num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
        num_batches = num_worker_batches * num_workers
        num_samples = num_batches * global_batch_size
        # each dataloader worker iterates over this many batches per epoch
        dataset = dataset.with_epoch(num_worker_batches)
    else:
        # last batches are partial; eval is done on a single (master) node
        num_batches = math.ceil(num_samples / args.batch_size)

    dataloader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=args.workers > 0,
    )

    # FIXME not clear which approach is better, with_epoch before vs after the dataloader

    # add meta-data to the dataloader instance for convenience
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)

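# A minimal sketch of calling get_wds_dataset directly; the args namespace below is
# hypothetical and only mirrors the CLI flags this function actually reads, and
# preprocess_img/tokenize are assumed to come from the model factory:
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       train_data="data/{00..99}.tar", val_data=None, train_num_samples=1_000_000,
#       dataset_resampled=False, train_data_upsampling_factors=None,
#       batch_size=256, workers=4, world_size=1, seed=0)
#   data = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenize)
#   data.set_epoch(0)  # sync the epoch into worker processes before iterating
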
def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    input_filename = args.train_data if is_train else args.val_data
    assert input_filename
    dataset = CsvDataset(
        input_filename,
        preprocess_fn,
        img_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
        sep=args.csv_separator,
        tokenizer=tokenizer
    )
    num_samples = len(dataset)
    sampler = DistributedSampler(dataset) if args.distributed and is_train else None
    shuffle = is_train and sampler is None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)


class SyntheticDataset(Dataset):

    def __init__(
        self,
        transform=None,
        image_size=(224, 224),
        caption="Dummy caption",
        dataset_size=100,
        tokenizer=None,
    ):
        self.transform = transform
        self.image_size = image_size
        self.caption = caption
        self.image = Image.new('RGB', image_size)
        self.dataset_size = dataset_size

        self.preprocess_txt = lambda text: tokenizer(text)[0]

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        # fall back to the raw PIL image when no transform is given,
        # rather than referencing an unbound local
        image = self.image
        if self.transform is not None:
            image = self.transform(image)
        return image, self.preprocess_txt(self.caption)


def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    image_size = preprocess_fn.transforms[0].size
    dataset = SyntheticDataset(
        transform=preprocess_fn, image_size=image_size, dataset_size=args.train_num_samples, tokenizer=tokenizer)
    num_samples = len(dataset)
    sampler = DistributedSampler(dataset) if args.distributed and is_train else None
    shuffle = is_train and sampler is None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)


def get_dataset_fn(data_path, dataset_type):
    if dataset_type == "webdataset":
        return get_wds_dataset
    elif dataset_type == "csv":
        return get_csv_dataset
    elif dataset_type == "synthetic":
        return get_synthetic_dataset
    elif dataset_type == "auto":
        ext = data_path.split('.')[-1]
        if ext in ['csv', 'tsv']:
            return get_csv_dataset
        elif ext in ['tar']:
            return get_wds_dataset
        else:
            raise ValueError(
                f"Tried to figure out dataset type, but failed for extension {ext}.")
    else:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")


def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
    preprocess_train, preprocess_val = preprocess_fns
    data = {}

    if args.train_data or args.dataset_type == "synthetic":
        data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
            args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)

    if args.val_data:
        data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
            args, preprocess_val, is_train=False, tokenizer=tokenizer)

    if args.imagenet_val is not None:
        data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")

    if args.imagenet_v2 is not None:
        data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")

    return data
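
# Typical top-level usage (a sketch; `args` comes from the training script's
# argument parser, and preprocess_train/preprocess_val/tokenize from the model
# factory, none of which are defined in this module):
#
#   data = get_data(args, (preprocess_train, preprocess_val), epoch=0, tokenizer=tokenize)
#   for epoch in range(args.epochs):
#       data["train"].set_epoch(epoch)  # so shuffling differs per epoch
#       for images, texts in data["train"].dataloader:
#           ...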