Spaces:
Sleeping
Sleeping
"""This file contains the definition of data loader using webdataset.

This file may have been modified by Bytedance Ltd. and/or its affiliates ("Bytedance's Modifications").
All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.

Reference:
    https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
    https://github.com/huggingface/open-muse/blob/main/training/data.py
"""
import json
import linecache
import math
from typing import List, Sequence, Text, Union

import torch
import webdataset as wds
from torch.utils.data import Dataset, default_collate
from torchvision import transforms
def filter_keys(key_set):
    """Return a callable that keeps only the entries of a dict whose key is in *key_set*."""
    def _select(sample):
        filtered = {}
        for key, value in sample.items():
            if key in key_set:
                filtered[key] = value
        return filtered
    return _select
class ImageTransform:
    """Factory for the torchvision train/eval image transform pipelines.

    The train pipeline resizes the shorter edge, crops (randomly or centrally),
    optionally flips, converts to tensor and normalizes.  The eval pipeline
    always resizes to ``crop_size`` and center-crops so results can be compared
    against reference numbers on ImageNet etc.
    """

    def __init__(self,
                 resize_shorter_edge: int = 256,
                 crop_size: int = 256,
                 random_crop: bool = True,
                 random_flip: bool = True,
                 normalize_mean: Sequence[float] = (0., 0., 0.),
                 normalize_std: Sequence[float] = (1., 1., 1.)):
        """Initializes the image transforms with the given augmentation parameters.

        Args:
            resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
            crop_size: An integer, the size to crop the input image to.
            random_crop: A boolean, whether to use random crop augmentation during training.
            random_flip: A boolean, whether to use random flipping augmentation during training.
            normalize_mean: A sequence of float, the normalization mean used to normalize the image tensor.
            normalize_std: A sequence of float, the normalization std used to normalize the image tensor.
        """
        # Bicubic interpolation is used for all resizing (train and eval).
        interpolation = transforms.InterpolationMode.BICUBIC

        train_transform = [
            transforms.Resize(resize_shorter_edge, interpolation=interpolation, antialias=True),
            transforms.RandomCrop(crop_size) if random_crop else transforms.CenterCrop(crop_size),
        ]
        if random_flip:
            train_transform.append(transforms.RandomHorizontalFlip())
        train_transform.append(transforms.ToTensor())
        # normalize_mean = (0, 0, 0) and normalize_std = (1, 1, 1) will normalize images into [0, 1],
        # normalize_mean = (0.5, 0.5, 0.5) and normalize_std = (0.5, 0.5, 0.5) will normalize images into [-1, 1].
        train_transform.append(transforms.Normalize(normalize_mean, normalize_std))
        self.train_transform = transforms.Compose(train_transform)

        self.eval_transform = transforms.Compose(
            [
                # Note that we always resize to crop_size during eval to ensure the results
                # can be compared against reference numbers on ImageNet etc.
                transforms.Resize(crop_size, interpolation=interpolation, antialias=True),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor(),
                transforms.Normalize(normalize_mean, normalize_std),
            ]
        )
        print(f"self.train_transform: {self.train_transform}")
        print(f"self.eval_transform: {self.eval_transform}")
class SimpleImageDataset:
    """Builds webdataset-based train/eval datasets and dataloaders for images.

    Training reads resampled shards with shuffling and fixed-size batches;
    evaluation reads shards sequentially, allowing a partial final batch.
    """

    def __init__(
        self,
        train_shards_path: Union[Text, List[Text]],
        eval_shards_path: Union[Text, List[Text]],
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers_per_gpu: int = 12,
        resize_shorter_edge: int = 256,
        crop_size: int = 256,
        random_crop=True,
        random_flip=True,
        normalize_mean: Sequence[float] = (0., 0., 0.),
        normalize_std: Sequence[float] = (1., 1., 1.),
    ):
        """Initializes the WebDatasetReader class.

        Args:
            train_shards_path: A string or list of string, path to the training data shards in webdataset format.
            eval_shards_path: A string or list of string, path to the evaluation data shards in webdataset format.
            num_train_examples: An integer, total number of training examples.
            per_gpu_batch_size: An integer, number of examples per GPU batch.
            global_batch_size: An integer, total number of examples in a batch across all GPUs.
            num_workers_per_gpu: An integer, number of workers per GPU.
            resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
            crop_size: An integer, the size to crop the input image to.
            random_crop: A boolean, whether to use random crop augmentation during training.
            random_flip: A boolean, whether to use random flipping augmentation during training.
            normalize_mean: A sequence of float, the normalization mean used to normalize the image tensor.
            normalize_std: A sequence of float, the normalization std used to normalize the image tensor.
        """
        transform = ImageTransform(
            resize_shorter_edge, crop_size, random_crop, random_flip,
            normalize_mean, normalize_std)

        # Shared decode/rename/filter stages; train and eval differ only in the
        # image transform applied at the end.
        def _processing_pipeline(image_transform):
            return [
                wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])),
                wds.rename(
                    image="jpg;png;jpeg;webp",
                    class_id="cls",
                    handler=wds.warn_and_continue,
                ),
                wds.map(filter_keys(set(["image", "class_id", "filename"]))),
                wds.map_dict(
                    image=image_transform,
                    class_id=lambda x: int(x),
                    handler=wds.warn_and_continue,
                ),
            ]

        # Create train dataset and loader.
        train_pipeline = [
            wds.ResampledShards(train_shards_path),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(bufsize=5000, initial=1000),
            *_processing_pipeline(transform.train_transform),
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]
        # Each worker iterates over the complete dataset; with_epoch caps one
        # epoch at num_worker_batches batches per worker, so one global epoch
        # spans num_batches batches / num_samples samples.
        num_worker_batches = math.ceil(
            num_train_examples / (global_batch_size * num_workers_per_gpu))
        num_batches = num_worker_batches * num_workers_per_gpu
        num_samples = num_batches * global_batch_size
        self._train_dataset = wds.DataPipeline(*train_pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=True,
        )
        # Add meta-data to dataloader instance for convenience.
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples

        # Create eval dataset and loader.
        eval_pipeline = [
            wds.SimpleShardList(eval_shards_path),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            *_processing_pipeline(transform.eval_transform),
            wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*eval_pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=True,
        )

    def train_dataset(self):
        """Returns the underlying training DataPipeline."""
        return self._train_dataset

    def train_dataloader(self):
        """Returns the training WebLoader (carries .num_batches / .num_samples)."""
        return self._train_dataloader

    def eval_dataset(self):
        """Returns the underlying evaluation DataPipeline."""
        return self._eval_dataset

    def eval_dataloader(self):
        """Returns the evaluation WebLoader."""
        return self._eval_dataloader
class PretoeknizedDataSetJSONL(Dataset):
    """Dataset over a pre-tokenized JSONL file.

    Each line is a JSON object with an integer "class_id" and a "tokens" list.
    Lines are fetched lazily via ``linecache`` so the whole file is never held
    in memory at once.
    """

    def __init__(self, data_path):
        """Initializes the dataset.

        Args:
            data_path: Path to the ``.jsonl`` file, one JSON record per line.
        """
        super().__init__()
        self.jsonl_file = data_path
        # Count lines with the handle properly closed (the original version
        # leaked the file object returned by open()).
        with open(self.jsonl_file) as f:
            self.num_lines = sum(1 for _ in f)
        # Ensure linecache's view of the file is up to date.
        linecache.checkcache(self.jsonl_file)
        print("Number of data:", self.num_lines)

    def __len__(self):
        return self.num_lines

    def __getitem__(self, idx):
        # linecache.getline is 1-indexed.
        line = linecache.getline(self.jsonl_file, idx + 1).strip()
        data = json.loads(line)
        return torch.tensor(data["class_id"]), torch.tensor(data["tokens"])