| import copy |
| import logging |
| import inspect |
|
|
| from torch.utils.data import DataLoader |
| from functools import partial |
| from typing import Callable, Dict, List, Optional, Union |
|
|
| from mmengine.logging import print_log |
| from mmengine.dist import get_rank |
| from mmengine.dataset import worker_init_fn as default_worker_init_fn |
| from mmengine.utils import digit_version |
| from mmengine.utils.dl_utils import TORCH_VERSION |
| from mmengine.runner import FlexibleRunner |
| from mmengine.registry import ( |
| DATA_SAMPLERS, |
| DATASETS, |
| FUNCTIONS, |
| ) |
| from xtuner.registry import BUILDER |
|
|
|
|
def clean_concatdataset_fields(cfg):
    """Recursively strip unsupported keys from ``ConcatDataset`` configs.

    Walks nested dicts and lists in place. Every dict whose ``type`` entry
    designates ``ConcatDataset`` has the keys that are not valid for it
    (currently only ``image_size``) removed.

    ``type`` may be given either as the string ``"ConcatDataset"`` or as a
    class object named ``ConcatDataset``; both spellings are recognised.

    Args:
        cfg: A config node. Dicts and lists are traversed; any other value
            is left untouched.

    Returns:
        The same ``cfg`` object that was passed in (modified in place).
    """
    illegal_keys = ('image_size',)

    if isinstance(cfg, dict):
        cfg_type = cfg.get('type')
        # Configs may reference the dataset by class object instead of by
        # name, so compare against the class name in that case.
        if isinstance(cfg_type, type):
            is_concat = cfg_type.__name__ == 'ConcatDataset'
        else:
            is_concat = cfg_type == 'ConcatDataset'

        if is_concat:
            for key in illegal_keys:
                cfg.pop(key, None)

        # Only values are replaced/mutated below, never the key set, so it
        # is safe to iterate the live view.
        for value in cfg.values():
            clean_concatdataset_fields(value)

    elif isinstance(cfg, list):
        for item in cfg:
            clean_concatdataset_fields(item)

    return cfg
|
|
|
|
|
|
class CustomRunner(FlexibleRunner):
    """``FlexibleRunner`` subclass with a customised ``build_dataloader``.

    Compared with a plain dict-driven dataloader build, this version

    - strips unsupported keys from ``ConcatDataset`` configs via
      ``clean_concatdataset_fields`` before building anything, and
    - accepts a ``collate_fn`` config whose ``type`` is a class, which is
      then instantiated through the xtuner ``BUILDER`` registry.
    """

    def __init__(
        self,
        **kwargs,
    ):
        """Forward all keyword arguments to ``FlexibleRunner``."""
        super().__init__(**kwargs)

    @staticmethod
    def build_dataloader(
        dataloader: Union[DataLoader, Dict],
        seed: Optional[int] = None,
        diff_rank_seed: bool = False,
    ) -> DataLoader:
        """Build dataloader.

        The method builds three components:

        - Dataset
        - Sampler
        - Dataloader

        An example of ``dataloader``::

            dataloader = dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=1,
                num_workers=9
            )

        Args:
            dataloader (DataLoader or dict): A Dataloader object or a dict to
                build Dataloader object. If ``dataloader`` is a Dataloader
                object, just returns itself.
            seed (int, optional): Random seed. Defaults to None.
            diff_rank_seed (bool): Whether or not set different seeds to
                different ranks. If True, the seed passed to sampler is set
                to None, in order to synchronize the seeds used in samplers
                across different ranks. Defaults to False.

        Returns:
            Dataloader: DataLoader built from ``dataloader``.

        Raises:
            TypeError: If ``collate_fn`` is neither a dict nor a callable.
        """
        if isinstance(dataloader, DataLoader):
            return dataloader

        # Work on a private copy: the config is consumed with ``pop`` below.
        dataloader_cfg = copy.deepcopy(dataloader)

        # Drop keys that ConcatDataset configs must not carry.
        clean_concatdataset_fields(dataloader_cfg)

        # ---- dataset ----
        dataset_cfg = dataloader_cfg.pop('dataset')
        if isinstance(dataset_cfg, dict):
            dataset = DATASETS.build(dataset_cfg)
            if hasattr(dataset, 'full_init'):
                dataset.full_init()
        else:
            # Anything that is not a dict is used as the dataset as-is.
            dataset = dataset_cfg

        # ---- sampler ----
        sampler_cfg = dataloader_cfg.pop('sampler')
        if isinstance(sampler_cfg, dict):
            sampler_seed = None if diff_rank_seed else seed
            sampler = DATA_SAMPLERS.build(
                sampler_cfg,
                default_args=dict(dataset=dataset, seed=sampler_seed))
        else:
            # Anything that is not a dict is used as the sampler as-is.
            sampler = sampler_cfg

        # ---- batch sampler ----
        batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None)
        if batch_sampler_cfg is None:
            batch_sampler = None
        elif isinstance(batch_sampler_cfg, dict):
            # ``batch_size`` is handed to the batch sampler and removed from
            # the DataLoader kwargs.
            batch_sampler = DATA_SAMPLERS.build(
                batch_sampler_cfg,
                default_args=dict(
                    dataset=dataset,
                    sampler=sampler,
                    batch_size=dataloader_cfg.pop('batch_size')))
        else:
            # Anything else is used as the batch sampler as-is.
            batch_sampler = batch_sampler_cfg

        # ---- worker init function ----
        # ``disable_subprocess_warning`` is an option of the default worker
        # init function, not a DataLoader argument. Always remove it from
        # the config so it can never leak into ``DataLoader(**dataloader_cfg)``
        # (which would fail on the unknown keyword) when ``seed`` is None or
        # a custom ``worker_init_fn`` is configured.
        disable_subprocess_warning = dataloader_cfg.pop(
            'disable_subprocess_warning', False)
        assert isinstance(disable_subprocess_warning, bool), (
            'disable_subprocess_warning should be a bool, but got '
            f'{type(disable_subprocess_warning)}')

        init_fn: Optional[partial]
        if 'worker_init_fn' in dataloader_cfg:
            worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn')
            worker_init_fn_type = worker_init_fn_cfg.pop('type')
            worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
            assert callable(worker_init_fn)
            # Remaining config entries become keyword arguments.
            init_fn = partial(worker_init_fn,
                              **worker_init_fn_cfg)
        elif seed is not None:
            init_fn = partial(
                default_worker_init_fn,
                num_workers=dataloader_cfg.get('num_workers'),
                rank=get_rank(),
                seed=seed,
                disable_subprocess_warning=disable_subprocess_warning)
        else:
            init_fn = None

        # ---- persistent_workers compatibility ----
        if ('persistent_workers' in dataloader_cfg
                and digit_version(TORCH_VERSION) < digit_version('1.7.0')):
            print_log(
                '`persistent_workers` is only available when '
                'pytorch version >= 1.7',
                logger='current',
                level=logging.WARNING)
            dataloader_cfg.pop('persistent_workers')

        # ---- collate function ----
        # When no ``collate_fn`` is configured, fall back to the function
        # registered as 'pseudo_collate'.
        collate_fn_cfg = dataloader_cfg.pop('collate_fn',
                                            dict(type='pseudo_collate'))
        if isinstance(collate_fn_cfg, dict):
            collate_fn_type = collate_fn_cfg.pop('type')
            if isinstance(collate_fn_type, str):
                collate_fn = FUNCTIONS.get(collate_fn_type)
            elif inspect.isclass(collate_fn_type):
                # Class type: put ``type`` back and let the xtuner BUILDER
                # instantiate it with the remaining config entries.
                collate_fn_cfg['type'] = collate_fn_type
                collate_fn = BUILDER.build(collate_fn_cfg)
            else:
                collate_fn = collate_fn_type
            if not inspect.isclass(collate_fn_type):
                # Function-style collate: bind remaining entries as kwargs.
                collate_fn = partial(collate_fn, **collate_fn_cfg)
        elif callable(collate_fn_cfg):
            collate_fn = collate_fn_cfg
        else:
            raise TypeError(
                'collate_fn should be a dict or callable object, but got '
                f'{collate_fn_cfg}')

        # ``sampler`` and ``batch_sampler`` are mutually exclusive here:
        # the plain sampler is only passed when no batch sampler exists.
        data_loader = DataLoader(
            dataset=dataset,
            sampler=sampler if batch_sampler is None else None,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn,
            worker_init_fn=init_fn,
            **dataloader_cfg)

        return data_loader
|
|