| | """Batching utils supports stateful dataloader. |
| | |
| | 1. Init stateful dataloader (tokenize) |
| | 2. Add to buffer |
| | 3. Yield batch indexes (micro batch * grad acc) |
| | a) non pack + non dynamic |
| | b) non pack + dynamic |
| | c) pack + non dynamic |
| | d) pack + dynamic |
| | """ |

from collections.abc import Iterator
from typing import Any

from torch.utils.data import default_collate
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

from ...accelerator.interface import DistributedInterface
from ...config import BatchingStrategy
from ...utils import logging
from ...utils.helper import pad_and_truncate
from ...utils.objects import StatefulBuffer
from ...utils.types import BatchInfo, BatchInput, ModelInput, TorchDataset
from .rendering import Renderer


logger = logging.get_logger(__name__)


def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
    """Collate buffered samples into ``num_micro_batch`` padded micro batches.

    Returns ``None`` if the buffer does not yet hold a full batch
    (``micro_batch_size * num_micro_batch`` samples).
    """
    micro_batch_size = batch_info["micro_batch_size"]
    num_micro_batch = batch_info["num_micro_batch"]
    cutoff_len = batch_info["cutoff_len"]
    batch_size = micro_batch_size * num_micro_batch
    if len(buffer) < batch_size:
        return None

    samples = buffer.get(batch_size)
    batch = []
    for i in range(num_micro_batch):
        micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
        batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))

    return batch
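

# Illustrative example (not executed): with micro_batch_size=2 and num_micro_batch=4,
# default_collate_fn waits until the buffer holds 8 samples, then returns a list of
# 4 collated micro batches of 2 samples each, padded/truncated to cutoff_len.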


class BatchGenerator(Iterator):
    def __init__(
        self,
        dataset: TorchDataset,
        renderer: Renderer,
        micro_batch_size: int = 1,
        global_batch_size: int | None = None,
        cutoff_len: int = 2048,
        batching_workers: int = 0,
        batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
        pin_memory: bool = True,
        drop_last: bool = True,
    ) -> None:
        self.dataset = dataset
        self.renderer = renderer

        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.cutoff_len = cutoff_len
        self.batching_workers = batching_workers
        self.batching_strategy = batching_strategy
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        dp_size = DistributedInterface().get_world_size("dp")

        if self.global_batch_size is None:
            self.global_batch_size = dp_size * micro_batch_size
            self.num_micro_batch = 1
        elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
            self.num_micro_batch = self.global_batch_size // dp_size // micro_batch_size
        else:
            raise ValueError(
                "Global batch size must be divisible by dp size * micro batch size. "
                f"Got {self.global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
            )
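
        # Worked example (illustrative): global_batch_size=64, dp_size=4, micro_batch_size=2
        # -> each rank accumulates num_micro_batch = 64 // 4 // 2 = 8 micro batches per step.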

        if not self.drop_last:
            raise ValueError("`drop_last` must be True.")

        self._init_data_provider()

        self._is_resuming: bool = False
        self._data_iter = iter(self._data_provider)
        self._buffer = StatefulBuffer()
        self._buffer_tokens = 0  # initialized here so `state_dict` works before the first epoch

        self._batch_info: BatchInfo = {
            "micro_batch_size": self.micro_batch_size,
            "num_micro_batch": self.num_micro_batch,
            "cutoff_len": self.cutoff_len,
            "data_iter": self._data_iter,
        }
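
        # Note: `data_iter` is carried inside `_batch_info` so that non-NORMAL batching
        # plugins can pull additional samples while filling the buffer (see `_fill_buffer`).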

        logger.info_rank0(
            f"Init unified data loader with global batch size {self.global_batch_size}, "
            f"micro batch size {self.micro_batch_size}, "
            f"num micro batch {self.num_micro_batch}, "
            f"cutoff len {self.cutoff_len}, "
            f"batching workers {self.batching_workers}, "
            f"batching strategy {self.batching_strategy}."
        )

    def _init_data_provider(self) -> None:
        # Map-style datasets report their true length; iterable datasets use -1 as a sentinel.
        if len(self.dataset) != -1:
            sampler = StatefulDistributedSampler(
                self.dataset,
                num_replicas=DistributedInterface().get_world_size("dp"),
                rank=DistributedInterface().get_rank("dp"),
                shuffle=True,
                seed=0,
                drop_last=self.drop_last,
            )
        else:
            raise NotImplementedError("Iterable dataset is not supported yet.")

        self._data_provider = StatefulDataLoader(
            self.dataset,
            batch_size=self.micro_batch_size * self.num_micro_batch,
            sampler=sampler,
            num_workers=self.batching_workers,
            collate_fn=self.renderer.process_samples,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
        )
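        # `collate_fn=self.renderer.process_samples` performs rendering/tokenization in the
        # dataloader workers (step 1 in the module docstring), so the loader yields
        # `ModelInput` samples rather than raw dataset rows.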
        if self.batching_strategy == BatchingStrategy.NORMAL:
            self._length = len(self._data_provider)
        else:
            from ...plugins.trainer_plugins.batching import BatchingPlugin

            self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)

    def __len__(self) -> int:
        return self._length

    def __iter__(self):
        # When resuming from a checkpoint, keep the restored buffer instead of clearing it.
        if not self._is_resuming:
            self._buffer.clear()
            self._buffer_tokens = 0

        self._data_iter = iter(self._data_provider)
        self._batch_info["data_iter"] = self._data_iter  # keep plugins pointed at the fresh iterator
        self._is_resuming = False
        return self

    def __next__(self):
        self._fill_buffer()
        batch = self._generate_batch()
        if batch is None:
            raise StopIteration

        return batch

    def _fill_buffer(self) -> None:
        if self.batching_strategy == BatchingStrategy.NORMAL:
            while len(self._buffer) < self.micro_batch_size * self.num_micro_batch:
                try:
                    samples: list[ModelInput] = next(self._data_iter)
                except StopIteration:
                    break

                self._buffer.put(samples)
        else:
            from ...plugins.trainer_plugins.batching import BatchingPlugin

            BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)

    def _generate_batch(self) -> list[BatchInput] | None:
        if self.batching_strategy == BatchingStrategy.NORMAL:
            return default_collate_fn(self._buffer, self._batch_info)
        else:
            from ...plugins.trainer_plugins.batching import BatchingPlugin

            return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)

    def state_dict(self) -> dict[str, Any]:
        return {
            "buffer": self._buffer,
            "buffer_tokens": self._buffer_tokens,
            "data_provider": self._data_provider.state_dict(),
        }

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self._buffer = state["buffer"]
        self._buffer_tokens = state["buffer_tokens"]
        self._data_provider.load_state_dict(state["data_provider"])
        self._is_resuming = True

    def set_epoch(self, epoch: int) -> None:
        if hasattr(self._data_provider.sampler, "set_epoch"):
            self._data_provider.sampler.set_epoch(epoch)
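

# Illustrative checkpoint/resume round trip (assumes `make_generator()` builds two
# identically-configured BatchGenerator instances; not part of this module):
#
#     gen = make_generator()
#     it = iter(gen)
#     first = next(it)
#     state = gen.state_dict()          # buffer + dataloader position
#
#     resumed = make_generator()
#     resumed.load_state_dict(state)    # sets _is_resuming, so __iter__ keeps the buffer
#     second = next(iter(resumed))      # continues after `first` rather than restarting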

if __name__ == "__main__":
    """
    python -m llamafactory.v1.core.utils.batching \
        --model llamafactory/tiny-random-qwen2.5 \
        --train_dataset data/v1_sft_demo.yaml \
        --micro_batch_size 2 \
        --global_batch_size 4 \
        --batching_workers 0
    """
    from ...config.arg_parser import get_args
    from ..data_engine import DataEngine
    from ..model_engine import ModelEngine

    model_args, data_args, training_args, _ = get_args()
    data_engine = DataEngine(data_args.train_dataset)
    model_engine = ModelEngine(model_args=model_args)
    batch_generator = BatchGenerator(
        data_engine,
        model_engine.renderer,
        micro_batch_size=training_args.micro_batch_size,
        global_batch_size=training_args.global_batch_size,
        cutoff_len=training_args.cutoff_len,
        batching_workers=training_args.batching_workers,
        batching_strategy=training_args.batching_strategy,
    )
    for batch in batch_generator:
        print(batch)
        print(len(batch))
        print(batch[0]["input_ids"].shape)
        break
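    # Expected (illustrative): `batch` is a list of num_micro_batch collated micro
    # batches; `batch[0]["input_ids"]` has shape (micro_batch_size, seq_len) with
    # seq_len capped at cutoff_len by `pad_and_truncate`.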