Spaces:
Runtime error
Runtime error
| import copy | |
| import logging | |
| import inspect | |
| from torch.utils.data import DataLoader | |
| from functools import partial | |
| from typing import Callable, Dict, List, Optional, Union | |
| from mmengine.logging import print_log | |
| from mmengine.dist import get_rank | |
| from mmengine.dataset import worker_init_fn as default_worker_init_fn | |
| from mmengine.utils import digit_version | |
| from mmengine.utils.dl_utils import TORCH_VERSION | |
| from mmengine.runner import FlexibleRunner | |
| from mmengine.registry import ( | |
| DATA_SAMPLERS, | |
| DATASETS, | |
| FUNCTIONS, | |
| ) | |
| from xtuner.registry import BUILDER | |
def clean_concatdataset_fields(cfg):
    """Recursively remove invalid keys (e.g. ``image_size``) from every
    ``ConcatDataset`` config node.

    The config tree is walked through nested dicts (including dict
    subclasses such as ``ConfigDict``), lists and tuples. A dict is treated
    as a ``ConcatDataset`` config when its ``type`` is either:

    * a string whose last dot-separated component is ``ConcatDataset``
      (``'ConcatDataset'`` or a scoped name like ``'mmengine.ConcatDataset'``);
    * an object (typically the class itself) whose ``__name__`` is
      ``ConcatDataset``.

    Cleaning happens in place; tuples themselves are not rebuilt, but the
    dicts they contain are cleaned.

    Args:
        cfg: A config node. Any value that is not a dict/list/tuple is
            returned untouched.

    Returns:
        The same ``cfg`` object that was passed in.
    """
    if isinstance(cfg, dict):
        ds_type = cfg.get('type')
        # `type` may be a registry string or the class object itself.
        if isinstance(ds_type, str):
            type_name = ds_type
        else:
            type_name = getattr(ds_type, '__name__', None)
        if (isinstance(type_name, str)
                and type_name.rsplit('.', 1)[-1] == 'ConcatDataset'):
            # Keys that must not be forwarded to the ConcatDataset itself.
            for key in ('image_size',):
                cfg.pop(key, None)
        # Recurse into children (the dict is not resized during this loop).
        for value in cfg.values():
            clean_concatdataset_fields(value)
    elif isinstance(cfg, (list, tuple)):
        for item in cfg:
            clean_concatdataset_fields(item)
    return cfg
class CustomRunner(FlexibleRunner):
    """``FlexibleRunner`` with a custom :meth:`build_dataloader`.

    The dataloader builder cleans the dataloader config with
    :func:`clean_concatdataset_fields` before building it. A ``collate_fn``
    whose ``type`` is a class is instantiated through xtuner's ``BUILDER``.
    """

    def __init__(
        self,
        **kwargs,
    ):
        # Construction is delegated unchanged to FlexibleRunner.
        super().__init__(**kwargs)

    @staticmethod
    def build_dataloader(
        dataloader: Union[DataLoader, Dict],
        seed: Optional[int] = None,
        diff_rank_seed: bool = False,
    ) -> DataLoader:
        """Build a ``DataLoader`` from a config dict.

        Declared as a ``staticmethod``: the signature takes no ``self``, so
        without the decorator a call through an instance would bind the
        runner to ``dataloader``.

        The method builds, in order:

        - the dataset (via ``DATASETS``; ``full_init()`` is called if the
          dataset has it),
        - the sampler (via ``DATA_SAMPLERS``),
        - an optional batch sampler (via ``DATA_SAMPLERS``),
        - the worker init function and the collate function,
        - the ``DataLoader``, which receives every remaining config key.

        An example of ``dataloader``::

            dataloader = dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=1,
                num_workers=9
            )

        Args:
            dataloader (DataLoader or dict): A DataLoader object or a dict to
                build one from. A DataLoader object is returned as is. A dict
                is deep-copied first, so the caller's config is never mutated.
            seed (int, optional): Random seed. Defaults to None.
            diff_rank_seed (bool): Whether or not to set different seeds on
                different ranks. If True, the seed passed to the sampler is
                set to None, in order to synchronize the seeds used in
                samplers across different ranks. Defaults to False.

        Returns:
            DataLoader: DataLoader built from ``dataloader``.

        Raises:
            KeyError: If ``dataset`` or ``sampler`` is missing from the
                config.
            TypeError: If ``worker_init_fn.type`` is neither a string nor a
                callable, or ``collate_fn`` is neither a dict nor a callable.
        """
        if isinstance(dataloader, DataLoader):
            return dataloader

        # Work on a private copy: the cleaning step and every ``pop`` below
        # mutate the config.
        dataloader_cfg = copy.deepcopy(dataloader)
        clean_concatdataset_fields(dataloader_cfg)

        # build dataset
        dataset_cfg = dataloader_cfg.pop('dataset')
        if isinstance(dataset_cfg, dict):
            dataset = DATASETS.build(dataset_cfg)
            if hasattr(dataset, 'full_init'):
                dataset.full_init()
        else:
            # fallback to raise error in dataloader
            # if `dataset_cfg` is not a valid type
            dataset = dataset_cfg

        # build sampler
        sampler_cfg = dataloader_cfg.pop('sampler')
        if isinstance(sampler_cfg, dict):
            sampler_seed = None if diff_rank_seed else seed
            sampler = DATA_SAMPLERS.build(
                sampler_cfg,
                default_args=dict(dataset=dataset, seed=sampler_seed))
        else:
            # fallback to raise error in dataloader
            # if `sampler_cfg` is not a valid type
            sampler = sampler_cfg

        # build batch sampler
        batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None)
        if batch_sampler_cfg is None:
            batch_sampler = None
        elif isinstance(batch_sampler_cfg, dict):
            batch_sampler = DATA_SAMPLERS.build(
                batch_sampler_cfg,
                default_args=dict(
                    dataset=dataset,
                    sampler=sampler,
                    # batch_size belongs to the batch sampler in this case,
                    # so it must not also reach DataLoader.
                    batch_size=dataloader_cfg.pop('batch_size')))
        else:
            # fallback to raise error in dataloader
            # if `batch_sampler_cfg` is not a valid type
            batch_sampler = batch_sampler_cfg

        # `disable_subprocess_warning` is only consumed by the default worker
        # init function below and is not a DataLoader argument. Pop it
        # unconditionally so it cannot leak into `DataLoader(**dataloader_cfg)`
        # when a custom `worker_init_fn` is configured or `seed` is None.
        disable_subprocess_warning = dataloader_cfg.pop(
            'disable_subprocess_warning', False)

        # build worker init function
        init_fn: Optional[partial]
        if 'worker_init_fn' in dataloader_cfg:
            worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn')
            worker_init_fn_type = worker_init_fn_cfg.pop('type')
            if isinstance(worker_init_fn_type, str):
                worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
            elif callable(worker_init_fn_type):
                worker_init_fn = worker_init_fn_type
            else:
                raise TypeError(
                    'type of worker_init_fn should be string or callable '
                    f'object, but got {type(worker_init_fn_type)}')
            assert callable(worker_init_fn)
            # Remaining keys of the config are bound as keyword arguments.
            init_fn = partial(worker_init_fn,
                              **worker_init_fn_cfg)  # type: ignore
        elif seed is not None:
            assert isinstance(disable_subprocess_warning, bool), (
                'disable_subprocess_warning should be a bool, but got '
                f'{type(disable_subprocess_warning)}')
            init_fn = partial(
                default_worker_init_fn,
                num_workers=dataloader_cfg.get('num_workers'),
                rank=get_rank(),
                seed=seed,
                disable_subprocess_warning=disable_subprocess_warning)
        else:
            init_fn = None

        # `persistent_workers` requires pytorch version >= 1.7
        if ('persistent_workers' in dataloader_cfg
                and digit_version(TORCH_VERSION) < digit_version('1.7.0')):
            print_log(
                '`persistent_workers` is only available when '
                'pytorch version >= 1.7',
                logger='current',
                level=logging.WARNING)
            dataloader_cfg.pop('persistent_workers')

        # The default behavior of `collate_fn` in dataloader is to
        # merge a list of samples to form a mini-batch of Tensor(s).
        # However, in mmengine, if `collate_fn` is not defined in
        # dataloader_cfg, `pseudo_collate` will only convert the list of
        # samples into a dict without stacking the batch tensor.
        collate_fn_cfg = dataloader_cfg.pop('collate_fn',
                                            dict(type='pseudo_collate'))
        if isinstance(collate_fn_cfg, dict):
            collate_fn_type = collate_fn_cfg.pop('type')
            if inspect.isclass(collate_fn_type):
                # A class is instantiated via xtuner's BUILDER with the
                # remaining keys; the built object is used as collate_fn.
                collate_fn_cfg['type'] = collate_fn_type
                collate_fn = BUILDER.build(collate_fn_cfg)
            else:
                # A registry name or a plain callable: bind the remaining
                # keys as keyword arguments.
                if isinstance(collate_fn_type, str):
                    collate_fn = FUNCTIONS.get(collate_fn_type)
                else:
                    collate_fn = collate_fn_type
                collate_fn = partial(collate_fn,
                                     **collate_fn_cfg)  # type: ignore
        elif callable(collate_fn_cfg):
            collate_fn = collate_fn_cfg
        else:
            raise TypeError(
                'collate_fn should be a dict or callable object, but got '
                f'{collate_fn_cfg}')

        data_loader = DataLoader(
            dataset=dataset,
            # DataLoader forbids passing both a sampler and a batch sampler.
            sampler=sampler if batch_sampler is None else None,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn,
            worker_init_fn=init_fn,
            **dataloader_cfg)
        return data_loader