# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batching utilities built on a stateful dataloader.

1. Init the stateful dataloader (tokenize).
2. Add samples to the buffer.
3. Yield batch indexes (micro batch * grad acc):
    a) non pack + non dynamic
    b) non pack + dynamic
    c) pack + non dynamic
    d) pack + dynamic
"""

from collections.abc import Iterator
from typing import Any

from torch.utils.data import default_collate
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

from ...accelerator.interface import DistributedInterface
from ...config import BatchingStrategy
from ...utils import logging
from ...utils.helper import pad_and_truncate
from ...utils.objects import StatefulBuffer
from ...utils.types import BatchInfo, BatchInput, ModelInput, TorchDataset
from .rendering import Renderer


logger = logging.get_logger(__name__)


def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
    """Collate buffered samples into `num_micro_batch` padded micro batches, or return None if short."""
    micro_batch_size = batch_info["micro_batch_size"]
    num_micro_batch = batch_info["num_micro_batch"]
    cutoff_len = batch_info["cutoff_len"]
    batch_size = micro_batch_size * num_micro_batch
    if len(buffer) < batch_size:
        return None

    samples = buffer.get(batch_size)
    batch = []
    for i in range(num_micro_batch):
        micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
        batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))

    return batch
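# Illustrative sketch (comments only, not executed): with the NORMAL strategy and
# batch_info = {"micro_batch_size": 2, "num_micro_batch": 2, "cutoff_len": 8},
# default_collate_fn drains 4 samples [s0, s1, s2, s3] from the buffer and returns
#     [default_collate(pad_and_truncate([s0, s1], 8)),
#      default_collate(pad_and_truncate([s2, s3], 8))]
# i.e. one padded tensor dict per micro batch. If the buffer holds fewer than
# micro_batch_size * num_micro_batch samples, it returns None and the caller ends iteration.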
class BatchGenerator(Iterator):
    def __init__(
        self,
        dataset: TorchDataset,
        renderer: Renderer,
        micro_batch_size: int = 1,
        global_batch_size: int | None = None,
        cutoff_len: int = 2048,
        batching_workers: int = 0,
        batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
        pin_memory: bool = True,
        drop_last: bool = True,
    ) -> None:
        self.dataset = dataset
        self.renderer = renderer
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.cutoff_len = cutoff_len
        self.batching_workers = batching_workers
        self.batching_strategy = batching_strategy
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        # TODO: support length and infinity
        dp_size = DistributedInterface().get_world_size("dp")
        if self.global_batch_size is None:
            self.global_batch_size = dp_size * micro_batch_size
            self.num_micro_batch = 1
        elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
            self.num_micro_batch = global_batch_size // dp_size // micro_batch_size
        else:
            raise ValueError(
                "Global batch size must be divisible by DP size times micro batch size. "
                f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
            )

        if not self.drop_last:
            raise ValueError("Drop last must be True.")

        self._init_data_provider()
        self._is_resuming: bool = False
        self._data_iter = iter(self._data_provider)
        self._buffer = StatefulBuffer()
        self._buffer_tokens: int = 0  # initialized here so state_dict() is valid before the first __iter__()
        self._batch_info: BatchInfo = {
            "micro_batch_size": self.micro_batch_size,
            "num_micro_batch": self.num_micro_batch,
            "cutoff_len": self.cutoff_len,
            "data_iter": self._data_iter,
        }
        logger.info_rank0(
            f"Init unified data loader with global batch size {self.global_batch_size}, "
            f"micro batch size {self.micro_batch_size}, "
            f"num micro batch {self.num_micro_batch}, "
            f"cutoff len {self.cutoff_len}, "
            f"batching workers {self.batching_workers}, "
            f"batching strategy {self.batching_strategy}."
        )

    def _init_data_provider(self) -> None:
        if len(self.dataset) != -1:
            sampler = StatefulDistributedSampler(
                self.dataset,
                num_replicas=DistributedInterface().get_world_size("dp"),
                rank=DistributedInterface().get_rank("dp"),
                shuffle=True,
                seed=0,
                drop_last=self.drop_last,
            )
        else:
            raise NotImplementedError("Iterable dataset is not supported yet.")

        self._data_provider = StatefulDataLoader(
            self.dataset,
            batch_size=self.micro_batch_size * self.num_micro_batch,
            sampler=sampler,
            num_workers=self.batching_workers,
            collate_fn=self.renderer.process_samples,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
        )
        if self.batching_strategy == BatchingStrategy.NORMAL:
            self._length = len(self._data_provider)
        else:
            from ...plugins.trainer_plugins.batching import BatchingPlugin

            self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
            raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")

    def __len__(self) -> int:
        return self._length

    def __iter__(self):
        if not self._is_resuming:
            self._buffer.clear()
            self._buffer_tokens = 0
            self._data_iter = iter(self._data_provider)

        self._is_resuming = False
        return self

    def __next__(self):
        self._fill_buffer()
        batch = self._generate_batch()
        if batch is None:
            raise StopIteration

        return batch

    def _fill_buffer(self) -> None:
        if self.batching_strategy == BatchingStrategy.NORMAL:
            while len(self._buffer) < self.micro_batch_size * self.num_micro_batch:
                try:
                    samples: list[ModelInput] = next(self._data_iter)
                except StopIteration:
                    break

                self._buffer.put(samples)
        else:
            from ...plugins.trainer_plugins.batching import BatchingPlugin

            BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)

    def _generate_batch(self) -> list[BatchInput] | None:
        if self.batching_strategy == BatchingStrategy.NORMAL:
            return default_collate_fn(self._buffer, self._batch_info)
        else:
            from ...plugins.trainer_plugins.batching import BatchingPlugin

            return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)

    def state_dict(self) -> dict[str, Any]:
        return {
            "buffer": self._buffer,
            "buffer_tokens": self._buffer_tokens,
            "data_provider": self._data_provider.state_dict(),
        }

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self._buffer = state["buffer"]
        self._buffer_tokens = state["buffer_tokens"]
        self._data_provider.load_state_dict(state["data_provider"])
        self._is_resuming = True

    def set_epoch(self, epoch: int) -> None:
        if hasattr(self._data_provider.sampler, "set_epoch"):
            self._data_provider.sampler.set_epoch(epoch)
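# Illustrative usage sketch (hypothetical trainer loop, shown as comments): BatchGenerator is
# resumable through the state_dict() / load_state_dict() pair, which wraps torchdata's
# StatefulDataLoader state plus the local sample buffer. A consumer might look like:
#
#     batch_generator.set_epoch(epoch)
#     for batch in batch_generator:             # batch is list[BatchInput], one per micro batch
#         for micro_batch in batch:             # num_micro_batch gradient-accumulation steps
#             ...                               # forward / backward on micro_batch
#         state = batch_generator.state_dict()  # buffer + dataloader state for checkpointing
#
#     batch_generator.load_state_dict(state)    # on restart, resume mid-epoch from the checkpoint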
if __name__ == "__main__":
    """
    python -m llamafactory.v1.core.utils.batching \
        --model llamafactory/tiny-random-qwen2.5 \
        --train_dataset data/v1_sft_demo.yaml \
        --micro_batch_size 2 \
        --global_batch_size 4 \
        --batching_workers 0
    """
    from ...config.arg_parser import get_args
    from ..data_engine import DataEngine
    from ..model_engine import ModelEngine

    model_args, data_args, training_args, _ = get_args()
    data_engine = DataEngine(data_args.train_dataset)
    model_engine = ModelEngine(model_args=model_args)
    batch_generator = BatchGenerator(
        data_engine,
        model_engine.renderer,
        micro_batch_size=training_args.micro_batch_size,
        global_batch_size=training_args.global_batch_size,
        cutoff_len=training_args.cutoff_len,
        batching_workers=training_args.batching_workers,
        batching_strategy=training_args.batching_strategy,
    )
    for batch in batch_generator:
        print(batch)
        print(len(batch))
        print(batch[0]["input_ids"].shape)
        break