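# HuggingFace dataset integration for torchtitan: streams a text dataset,
# tokenizes it on the fly, packs tokens into fixed-length training sequences,
# and supports data-parallel sharding and mid-epoch checkpointing.
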
from dataclasses import dataclass
from typing import Any, Callable

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


def _load_c4_dataset(dataset_path: str):
    """Load the C4 dataset (English config) in streaming mode."""
    return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: dict[str, Any]) -> str:
    """Extract the text field from a C4 sample."""
    return sample["text"]


@dataclass
class DatasetConfig:
    path: str
    loader: Callable
    text_processor: Callable


DATASETS = {
    "c4": DatasetConfig(
        path="allenai/c4",
        loader=_load_c4_dataset,
        text_processor=_process_c4_text,
    ),
    "c4_test": DatasetConfig(
        path="tests/assets/c4_test",
        loader=lambda path: load_dataset(path, split="train"),
        text_processor=_process_c4_text,
    ),
}
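
# Supporting another dataset means registering it in DATASETS. A hypothetical
# sketch (the dataset name, config name, and text field below are illustrative,
# not part of this module):
#
#   DATASETS["wikitext"] = DatasetConfig(
#       path="Salesforce/wikitext",
#       loader=lambda path: load_dataset(
#           path, name="wikitext-103-v1", split="train", streaming=True
#       ),
#       text_processor=lambda sample: sample["text"],
#   )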


def _validate_dataset(
    dataset_name: str, dataset_path: str | None = None
) -> tuple[str, Callable, Callable]:
    """Validate the dataset name and resolve its path, loader, and text processor."""
    if dataset_name not in DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    config = DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.text_processor


class HuggingFaceDataset(IterableDataset, Stateful):
    def __init__(
        self,
        dataset_name: str,
        dataset_path: str | None,
        tokenizer: Tokenizer,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Force lowercase for consistent comparison with the DATASETS keys
        dataset_name = dataset_name.lower()

        path, dataset_loader, text_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        # Shard the dataset so each data-parallel rank sees a distinct slice
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._text_processor = text_processor

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_tokens: list[int] = []

    def _get_data_iter(self):
        # For map-style datasets, a sample index equal to the dataset length
        # means the data has been fully consumed
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        # Skip samples already consumed before a checkpoint resume
        it = iter(self._data)
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        # Buffer seq_len + 1 tokens so input and label can be shifted by one
        max_buffer_token_len = 1 + self.seq_len

        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific text processor
                sample_text = self._text_processor(sample)
                sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
                self._all_tokens.extend(sample_tokens)
                self._sample_idx += 1

                # Emit fixed-length sequences while the buffer holds enough tokens
                while len(self._all_tokens) >= max_buffer_token_len:
                    x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
                    # Update the buffer to the remaining tokens
                    self._all_tokens = self._all_tokens[max_buffer_token_len:]
                    input = x[:-1]
                    label = x[1:]
                    yield {"input": input}, label

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]
        self._all_tokens = state_dict["token_buffer"]

    def state_dict(self):
        return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


def build_hf_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size
    seq_len = job_config.training.seq_len

    hf_ds = HuggingFaceDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
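
# A minimal usage sketch, assuming a parsed JobConfig with training.dataset,
# training.dataset_path, training.batch_size, and training.seq_len set (the
# trainer normally wires this up; names here are illustrative):
#
#   dataloader = build_hf_dataloader(
#       dp_world_size=world_size,
#       dp_rank=rank,
#       tokenizer=tokenizer,
#       job_config=job_config,
#   )
#   for input_dict, labels in dataloader:
#       ...  # input_dict["input"] and labels are batched [batch_size, seq_len]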