| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from typing import TYPE_CHECKING, Callable, List, Optional, Union |
|
|
| from torch.utils.data import IterableDataset |
| from torchdata.stateful_dataloader import StatefulDataLoader |
| from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler |
|
|
| from ..distributed.parallel_state import get_parallel_state |
| from ..utils import logging |
| from .batching_strategy import TextBatchingStrategy |
| from .data_collator import ( |
| CollatePipeline, |
| DataCollatorWithPacking, |
| DataCollatorWithPadding, |
| DataCollatorWithPositionIDs, |
| MakeMicroBatchCollator, |
| TextSequenceShardCollator, |
| UnpackDataCollator, |
| ) |
| from .dynamic_batching import DynamicBatchSizeDataLoader |
|
|
|
|
| if TYPE_CHECKING: |
| from torch.utils.data import Dataset |
|
|
|
|
# Module-level logger from the project's logging utilities; `build_dataloader`
# uses its `info_rank0` method to report the resolved batching configuration.
logger = logging.get_logger(__name__)
|
|
|
|
class DistributedDataloader(StatefulDataLoader):
    """``StatefulDataLoader`` that can forward an epoch number to its sampler or dataset."""

    dataset: "Dataset"
    sampler: "StatefulDistributedSampler"

    def set_epoch(self, epoch: int) -> None:
        """Hand ``epoch`` to the first component that exposes ``set_epoch``.

        The sampler is consulted first. The dataset is only tried when the
        sampler is ``None`` or lacks a ``set_epoch`` attribute. If neither
        component has one, the call does nothing.
        """
        for candidate in (self.sampler, self.dataset):
            # hasattr(None, ...) is False, so a missing sampler falls through to the dataset.
            if hasattr(candidate, "set_epoch"):
                candidate.set_epoch(epoch)
                return
|
|
|
|
def build_dataloader(
    dataset: "Dataset",
    micro_batch_size: int,
    global_batch_size: int,
    dataloader_batch_size: int,
    max_seq_len: int,
    train_steps: int,
    rmpad: bool = True,
    rmpad_with_pos_ids: bool = False,
    bsz_warmup_ratio: float = 0.02,
    bsz_warmup_init_mbtoken: int = 200,
    dyn_bsz_buffer_size: int = 500,
    dyn_bsz_margin: int = 0,
    collate_fn: Optional[Union[Callable, List[Callable]]] = None,
    num_workers: int = 8,
    drop_last: bool = True,
    pin_memory: bool = True,
    prefetch_factor: Optional[int] = 2,
    seed: int = 0,
) -> "DistributedDataloader":
    """Build the training dataloader for the current data-parallel rank.

    Map-style datasets are sharded across data-parallel ranks with a shuffling
    ``StatefulDistributedSampler``. ``IterableDataset`` instances get no sampler.
    When padding removal is enabled (``rmpad`` or ``rmpad_with_pos_ids``), the
    per-rank loader is wrapped in a ``DynamicBatchSizeDataLoader`` that applies
    token-budget batching. Otherwise, samples are grouped into micro batches by
    ``MakeMicroBatchCollator``.

    Args:
        dataset: Dataset to load from.
        micro_batch_size: Samples per micro batch. Multiplied by ``max_seq_len``
            to form the per-micro-batch token budget.
        global_batch_size: Global batch size across all data-parallel ranks.
            ``global_batch_size // (micro_batch_size * dp_size)`` is passed on as
            ``num_micro_batch``.
        dataloader_batch_size: ``batch_size`` of the underlying per-rank loader.
        max_seq_len: Maximum sequence length used for the token budget.
        train_steps: Used to derive the batch-size warmup steps and as the
            ``length`` of the dynamic-batching loader.
        rmpad: Use the packing collator and dynamic batching.
        rmpad_with_pos_ids: Use the position-id collator and dynamic batching.
            Takes precedence over ``rmpad`` when choosing the default collator.
        bsz_warmup_ratio: Fraction of ``train_steps`` used for batch-size warmup.
            If the result truncates to 0 steps, ``-1`` is passed to the
            batching strategy (presumably "no warmup" — TODO confirm).
        bsz_warmup_init_mbtoken: Forwarded to ``TextBatchingStrategy``.
        dyn_bsz_buffer_size: ``buffer_size`` of ``TextBatchingStrategy``.
        dyn_bsz_margin: Reduces the token budget by ``dyn_bsz_margin * max_seq_len``.
        collate_fn: Custom collator, or a list of collators chained with
            ``CollatePipeline``. ``None`` builds the default pipeline (plus
            sequence-parallel sharding when SP is enabled).
        num_workers: Worker processes for the underlying loader.
        drop_last: Forwarded to both the underlying and the dynamic loader.
        pin_memory: Forwarded to the underlying loader.
        prefetch_factor: Batches prefetched per worker. Ignored (passed as
            ``None``) when ``num_workers == 0``.
        seed: Shuffle seed of the distributed sampler.

    Returns:
        A ``DistributedDataloader``, or a ``DynamicBatchSizeDataLoader`` wrapping
        one when padding removal is enabled.

    Raises:
        ValueError: If ``global_batch_size`` is smaller than
            ``micro_batch_size * dp_size``, which would give zero micro batches.
    """
    parallel_state = get_parallel_state()
    token_micro_bsz = micro_batch_size * max_seq_len
    num_micro_batch = global_batch_size // (micro_batch_size * parallel_state.dp_size)
    if num_micro_batch < 1:
        # Zero micro batches would yield a loader that produces no usable batches.
        raise ValueError(
            f"global_batch_size ({global_batch_size}) must be at least "
            f"micro_batch_size * dp_size ({micro_batch_size} * {parallel_state.dp_size})."
        )
    bsz_warmup_steps = int(train_steps * bsz_warmup_ratio)
    use_rmpad = rmpad or rmpad_with_pos_ids
    logger.info_rank0(
        f"train_steps: {train_steps}, max_seq_len: {max_seq_len}, use_rmpad: {use_rmpad}, "
        f"bsz_warmup_steps: {bsz_warmup_steps}, bsz_warmup_init_mbtoken: {bsz_warmup_init_mbtoken}, "
        f"token_micro_bsz: {token_micro_bsz}, num_micro_batch: {num_micro_batch}, "
        f"micro_batch_size: {micro_batch_size}, global_batch_size: {global_batch_size}, "
        f"dp_size: {parallel_state.dp_size}, sp_size: {parallel_state.sp_size}."
    )

    if collate_fn is None:
        collate_fn_list = []
        if rmpad_with_pos_ids:
            collate_fn_list.append(DataCollatorWithPositionIDs())
        elif rmpad:
            collate_fn_list.append(DataCollatorWithPacking())
        else:
            collate_fn_list.append(DataCollatorWithPadding())

        if parallel_state.sp_enabled:
            collate_fn_list.append(TextSequenceShardCollator(rmpad=rmpad, rmpad_with_pos_ids=rmpad_with_pos_ids))

        collate_fn = CollatePipeline(collate_fn_list)
    elif isinstance(collate_fn, list):
        collate_fn = CollatePipeline(collate_fn)

    if use_rmpad:
        batching_strategy = TextBatchingStrategy(
            token_micro_bsz=token_micro_bsz - dyn_bsz_margin * max_seq_len,
            buffer_size=dyn_bsz_buffer_size,
            bsz_warmup_steps=bsz_warmup_steps if bsz_warmup_steps else -1,
            bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
        )
        # The real collation runs in the dynamic-batching wrapper; the inner loader only unpacks.
        dyn_bsz_collate_fn = collate_fn
        collate_fn = UnpackDataCollator()
    else:
        collate_fn = MakeMicroBatchCollator(num_micro_batch=num_micro_batch, internal_data_collator=collate_fn)

    sampler = None
    if not isinstance(dataset, IterableDataset):
        sampler = StatefulDistributedSampler(
            dataset,
            num_replicas=parallel_state.dp_size,
            rank=parallel_state.dp_rank,
            shuffle=True,
            seed=seed,
        )

    # DataLoader / StatefulDataLoader raise ValueError if prefetch_factor is not None
    # while num_workers == 0, so only forward it when worker processes are used.
    effective_prefetch_factor = prefetch_factor if num_workers > 0 else None

    dataloader = DistributedDataloader(
        dataset,
        batch_size=dataloader_batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=drop_last,
        prefetch_factor=effective_prefetch_factor,
    )
    if use_rmpad:
        dataloader = DynamicBatchSizeDataLoader(
            dataloader,
            batching_strategy=batching_strategy,
            collate_fn=dyn_bsz_collate_fn,
            num_micro_batch=num_micro_batch,
            length=train_steps,
            drop_last=drop_last,
        )

    return dataloader
|
|