| from typing import TypeVar | |
| from .arrow_dataset import Dataset, _split_by_node_map_style_dataset | |
| from .iterable_dataset import IterableDataset, _split_by_node_iterable_dataset | |
| DatasetType = TypeVar("DatasetType", Dataset, IterableDataset) | |
def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:
    """
    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.

    For map-style datasets:

    Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
    To maximize data loading throughput, chunks are made of contiguous data on disk if possible.

    For iterable datasets:

    If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
    then the shards are evenly assigned across the nodes, which is the most optimized.
    Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.

    Args:
        dataset ([`Dataset`] or [`IterableDataset`]):
            The dataset to split by node.
        rank (`int`):
            Rank of the current node.
        world_size (`int`):
            Total number of nodes.

    Returns:
        [`Dataset`] or [`IterableDataset`]: The dataset to be used on the node at rank `rank`.
    """
    # Pick the splitting strategy matching the dataset flavor, then delegate.
    split = (
        _split_by_node_map_style_dataset
        if isinstance(dataset, Dataset)
        else _split_by_node_iterable_dataset
    )
    return split(dataset, rank=rank, world_size=world_size)