| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import pickle |
| | from abc import ABC, abstractmethod |
| | from collections.abc import Callable |
| | from typing import Any |
| |
|
| | from torch.distributed.checkpoint.stateful import Stateful |
| | from torch.utils.data import IterableDataset |
| | from torchdata.stateful_dataloader import StatefulDataLoader |
| | from torchtitan.tools.logging import logger |
| |
|
| |
|
| | class BaseDataLoader(Stateful, ABC): |
| | """Base class for all dataloaders. |
| | |
| | This is used to enforce that all dataloaders have the methods defined in ``Stateful``, |
| | ``state_dict()`` and ``load_state_dict()``. |
| | """ |
| |
|
| | @abstractmethod |
| | def __iter__(self): |
| | ... |
| |
|
| |
|
class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
    """Dataloader that is aware of distributed data parallelism.

    This dataloader is used to load data in a distributed data parallel fashion. It also
    utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
    methods such as ``__iter__``.

    Args:
        dataset (IterableDataset): The dataset to iterate over.
        dp_rank: Data parallelism rank for this dataloader.
        dp_world_size: The world size of the data parallelism.
        batch_size: The batch size to use for each iteration.
        collate_fn: Optional function to collate samples in a batch.
    """

    dp_rank: int
    dp_world_size: int
    batch_size: int

    def __init__(
        self,
        dataset: IterableDataset,
        dp_rank: int,
        dp_world_size: int,
        batch_size: int,
        collate_fn: Callable | None = None,
    ):
        self.dp_world_size = dp_world_size
        self.dp_rank = dp_rank
        self.batch_size = batch_size
        super().__init__(dataset, batch_size, collate_fn=collate_fn)
        # Per-rank key under which this rank's loader state is stored in
        # the (shared) checkpoint state dict.
        self._rank_id = f"dp_rank_{dp_rank}"

    def state_dict(self) -> dict[str, Any]:
        """Return this rank's checkpointable state.

        The inner ``StatefulDataLoader`` state is pickled so the checkpoint
        layer treats it as an opaque blob; ``world_size`` is stored alongside
        it so resharding can be detected (and rejected) on load.
        """
        return {
            self._rank_id: pickle.dumps(super().state_dict()),
            "world_size": self.dp_world_size,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore this rank's dataloader state from ``state_dict``.

        An empty ``state_dict`` is valid (nothing to restore). A missing
        per-rank key is logged and skipped rather than raised, so a partial
        checkpoint degrades to a fresh start for this rank.
        """
        if not state_dict:
            return

        if self._rank_id not in state_dict:
            logger.warning(
                f"DataLoader state is empty for dp rank {self.dp_rank}, "
                # BUG FIX: this line was missing the ``f`` prefix, so the
                # placeholder was logged literally instead of interpolated.
                f"expected key {self._rank_id}"
            )
            return

        assert self.dp_world_size == state_dict["world_size"], (
            "dp_degree is inconsistent before and after checkpoint, "
            "dataloader resharding is not supported yet."
        )
        # NOTE: pickle.loads on checkpoint contents — safe only because
        # checkpoints are produced by this code and assumed trusted; never
        # load checkpoints from untrusted sources.
        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
| |
|