| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
|
|
| import logging |
| import os |
| import json |
| import pathlib |
| from os.path import basename |
|
|
| from timm.data import create_loader |
| import torch |
| import torch.utils.data |
| import torch.distributed as dist |
| import torchvision.datasets as datasets |
| from torchvision.io import read_image |
| import torch.distributed as dist |
| from pathlib import Path |
| from yacs.config import CfgNode as CN |
|
|
| from ..LangEncoder import build_tokenizer |
|
|
| from .tsv import TSVImageTextDatasetV2 |
| from .tsv import TSVMeta |
| from .transforms import build_transforms |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def build_dataset(cfg, is_train):
    """Construct the dataset selected by ``cfg['DATASET']['DATASET']``.

    Args:
        cfg: configuration mapping (dict-style access).
        is_train: True for the training split, False for evaluation.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: if the configured dataset name is not supported.
    """
    dataset_kind = cfg['DATASET']['DATASET']
    # Guard clause: only one dataset flavor is currently supported.
    if dataset_kind != 'image_text_pairs_v2':
        raise ValueError(f'Unknown dataset: {dataset_kind}')
    return _build_pairs_dataset_v2(cfg, is_train)
|
|
|
|
def _get_tsv_list(cfg, is_train):
    """Resolve the tsv file names configured for the requested split.

    Reads ``TRAIN_TSV_LIST`` (training) or ``TEST_TSV_LIST`` (evaluation)
    from the DATASET config. An entry ending in ``.list`` is treated as an
    index file whose lines each name one tsv; any other entry is used
    verbatim. Returns an empty list when the relevant key is absent.
    """
    dataset_cfg = cfg['DATASET']
    if is_train and 'TRAIN_TSV_LIST' in dataset_cfg:
        raw_entries = dataset_cfg['TRAIN_TSV_LIST']
    elif 'TEST_TSV_LIST' in dataset_cfg:
        raw_entries = dataset_cfg['TEST_TSV_LIST']
    else:
        raw_entries = []

    tsv_list = []
    for entry in raw_entries:
        if not entry.endswith('.list'):
            tsv_list.append(entry)
            continue
        # Index file: expand it into one tsv name per line.
        with open(entry, 'r') as f:
            tsv_list.extend(line.strip() for line in f)

    logger.info(f'tsv list: {tsv_list}')

    return tsv_list
|
|
|
|
def _get_token_file(cfg):
    """Pick the SAS token file for this node, or ``None`` when unusable.

    When ``DATASET.TOKEN_FILE`` is a list, each node is assigned one entry
    round-robin by its node index (rank // GPUs-per-node); otherwise the
    single configured file is used. The path is joined onto
    ``DATASET.ROOT``. Returns ``None`` when the loader is ``'blobfuse'``
    or the resolved path does not exist on disk.

    NOTE(review): assumes torch.distributed is initialized — confirm all
    callers run under a process group.
    """
    # Guard against CPU-only environments: torch.cuda.device_count() is 0
    # there, which would make the divisions below raise ZeroDivisionError.
    gpus_per_node = max(torch.cuda.device_count(), 1)
    num_nodes = dist.get_world_size() // gpus_per_node
    token_cfg = cfg['DATASET']['TOKEN_FILE']
    if isinstance(token_cfg, list):
        if num_nodes == 1:
            logger.warning('=> Multi token files are provided, but only one node is used for training')
            sas_token_file = token_cfg[0]
        else:
            # Map this rank to its node, then pick a token file round-robin.
            rank = dist.get_rank()
            node_idx = rank // gpus_per_node
            num_token_files = len(token_cfg)
            sas_token_file = token_cfg[node_idx % num_token_files]
    else:
        sas_token_file = token_cfg

    sas_token_file = os.path.join(cfg['DATASET']['ROOT'], sas_token_file)

    # blobfuse mounts do not need a token; a missing file means no token.
    if (
        cfg['DATASET']['LOADER'] == 'blobfuse'
        or not os.path.isfile(sas_token_file)
    ):
        sas_token_file = None

    return sas_token_file
|
|
|
|
def _filter_tsv_by_modality(tsv_filenames, key):
    """Return the tsv paths whose basename is tagged with *key*.

    A file belongs to the modality *key* ('image' or 'text') when its
    basename contains the tag joined with any separator used by the
    dataset naming conventions: ``key-``, ``key_``, ``_key``, ``-key``,
    or the plural ``keys-``.
    """
    tags = (f'{key}-', f'{key}_', f'_{key}', f'-{key}', f'{key}s-')
    return [
        filename
        for filename in tsv_filenames
        if any(tag in basename(filename) for tag in tags)
    ]


def _build_pairs_dataset_v2(cfg, is_train):
    """Build a :class:`TSVImageTextDatasetV2` for image-text pairs.

    Locates paired image/text tsv shards (either from an explicit tsv list
    or by globbing the dataset directory), builds transforms and the
    tokenizer, attaches optional classification metadata, and applies
    coco-caption evaluation overrides.

    Args:
        cfg: configuration mapping (dict-style access).
        is_train: True for the training split, False for evaluation.

    Returns:
        The constructed ``TSVImageTextDatasetV2``.

    Raises:
        ValueError: if ``DATASET.DATA_FORMAT`` is not ``'tsv'``.
    """
    transforms = build_transforms(cfg, is_train)
    logger.info('transforms: {}'.format(transforms))

    dataset_name = cfg['DATASET']['TRAIN_SET'] \
        if is_train else cfg['DATASET']['TEST_SET']

    tokenobj = build_tokenizer(cfg['LANG_ENCODER'])

    if cfg['DATASET']['DATA_FORMAT'] != 'tsv':
        raise ValueError('Only support tsv format for pairs dataset v2')

    tsv_list = _get_tsv_list(cfg, is_train)

    # Prefer the explicitly configured tsv list; otherwise glob every
    # tsv under the dataset directory. Sorting keeps image/text shards
    # aligned by name.
    if len(tsv_list) > 0:
        tsv_filenames = sorted(
            [
                os.path.join(cfg['DATASET']['ROOT'], dataset_name, f)
                for f in tsv_list
            ]
        )
    else:
        dataset_path = os.path.join(cfg['DATASET']['ROOT'], dataset_name)
        tsv_files = Path(dataset_path).glob('**/*.tsv')

        tsv_filenames = sorted(
            [
                str(path)
                for path in tsv_files
            ]
        )

    image_tsv_files = _filter_tsv_by_modality(tsv_filenames, 'image')
    text_tsv_files = _filter_tsv_by_modality(tsv_filenames, 'text')

    logger.info(
        "=> found %d/%d tsv file(s) to load.",
        len(image_tsv_files), len(text_tsv_files)
    )

    # Training always uses a single caption per image; evaluation may
    # request more via NUM_CAPTIONS.
    num_captions = 1 \
        if is_train else cfg['DATASET'].get('NUM_CAPTIONS', 1)
    text_format = cfg['DATASET'].get('TEXT_FORMAT', 'json')

    sas_token_file = _get_token_file(cfg)
    logger.info("=> SAS token path: %s", sas_token_file)

    # Optional per-source classification metadata (source name paired
    # with its class count).
    metas = []
    cfg_data = cfg['DATASET']
    if 'CLASSIFICATION_SETS' in cfg_data and 'NUM_CLASSES' in cfg_data:
        for source, num_classes in zip(cfg_data['CLASSIFICATION_SETS'], cfg_data['NUM_CLASSES']):
            metas.append(
                TSVMeta(
                    source=source,
                    num_classes=num_classes,
                    task='classification'
                )
            )
            logger.info('=> add meta: {}'.format(metas[-1]))

    # coco-caption evaluation overrides: 5 captions, json text, no token.
    if 'coco-caption' in dataset_name:
        logger.info('=> coco caption data is used')
        logger.info('=> update num_captions: 5, text_format: json')
        logger.warning('=> set sas token to None for coco evaluation')
        sas_token_file = None
        num_captions = 5
        text_format = 'json'

    dataset = TSVImageTextDatasetV2(
        image_tsv_files, text_tsv_files,
        transform=transforms,
        tokenize=tokenobj,
        context_length=cfg['LANG_ENCODER']['CONTEXT_LENGTH'],
        num_captions=num_captions,
        text_format=text_format,
        is_train=is_train,
        sas_token_path=sas_token_file,
        metas=metas,
        prompt_engineering=cfg['DATASET'].get('PROMPT_ENGINEERING', True),
        concat_queries=cfg['DATASET'].get('CONCAT_QUERIES', False)
    )

    logger.info(
        "=> %s set size: %d", 'train'
        if is_train else 'val', len(dataset)
    )

    return dataset
|
|
|
|
def build_dataloader(cfg, is_train=True, distributed=False):
    """Build a data loader for the configured dataset.

    Uses the timm prefetching loader when training with
    ``AUG.TIMM_AUG.USE_LOADER``; otherwise builds a plain
    ``torch.utils.data.DataLoader``, with a ``DistributedSampler`` when
    ``distributed`` (or ``ALWAYS_ENABLE_SAMPLER``) is set.

    Args:
        cfg: configuration mapping (dict-style access).
        is_train: True for training (shuffling/drop_last defaults on).
        distributed: True to shard the dataset across ranks.

    Returns:
        An iterable data loader over the built dataset.
    """
    dataset = build_dataset(cfg, is_train)

    use_timm_loader = (
        is_train
        and 'TIMM_AUG' in cfg['AUG']
        and cfg['AUG']['TIMM_AUG']['USE_LOADER']
    )

    if use_timm_loader:
        logger.info('=> use timm loader for training')
        timm_cfg = CN(init_dict=cfg['AUG']['TIMM_AUG'])
        data_loader = create_loader(
            dataset,
            input_size=cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0],
            batch_size=cfg['TRAIN']['BATCH_SIZE_PER_GPU'],
            is_training=True,
            use_prefetcher=True,
            no_aug=False,
            re_prob=timm_cfg.RE_PROB,
            re_mode=timm_cfg.RE_MODE,
            re_count=timm_cfg.RE_COUNT,
            re_split=timm_cfg.RE_SPLIT,
            scale=cfg['AUG']['SCALE'],
            ratio=cfg['AUG']['RATIO'],
            hflip=timm_cfg.HFLIP,
            vflip=timm_cfg.VFLIP,
            color_jitter=timm_cfg.COLOR_JITTER,
            auto_augment=timm_cfg.AUTO_AUGMENT,
            num_aug_splits=0,
            interpolation=cfg['AUG']['INTERPOLATION'],
            mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
            std=cfg['IMAGE_ENCODER']['IMAGE_STD'],
            num_workers=cfg['WORKERS'],
            distributed=distributed,
            collate_fn=None,
            pin_memory=cfg['PIN_MEMORY'],
            use_multi_epochs_loader=True
        )
    else:
        if is_train:
            batch_size_per_gpu = cfg['TRAIN']['BATCH_SIZE_PER_GPU']
            shuffle = cfg['TRAIN'].get('SHUFFLE', True)
        else:
            batch_size_per_gpu = cfg['TEST']['BATCH_SIZE_PER_GPU']
            shuffle = cfg['TEST'].get('SHUFFLE', False)

        if distributed or cfg.get('ALWAYS_ENABLE_SAMPLER', False):
            # The sampler owns the shuffling; DataLoader must not also shuffle.
            sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
            shuffle = False
        else:
            sampler = None

        num_workers = cfg['WORKERS']
        extra_kwargs = {}
        if num_workers > 0:
            # prefetch_factor is only valid with worker processes: passing it
            # alongside num_workers=0 raises ValueError in recent PyTorch.
            extra_kwargs['prefetch_factor'] = cfg.get('PREFETCH_FACTOR', 2)

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size_per_gpu,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=cfg['PIN_MEMORY'],
            sampler=sampler,
            drop_last=is_train,
            **extra_kwargs
        )

    return data_loader
|
|
|
|
|
|
|
|
|
|