# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

import logging
from abc import ABC, abstractmethod
from typing import Callable, Dict, Generic, Iterator, Optional, Sequence, TypeVar, Union

import torch
from fairseq2.data.data_pipeline import DataPipeline
from fairseq2.gang import FakeGang, Gang
from fairseq2.typing import DataType

from lcm.datasets.configs import (
    DataLoadingConfig,
    DatasetConfigT,
    create_dataset_config_from_cards,
)
from lcm.datasets.dataloading import (
    build_weighted_pipeline_with_renaming as default_build_fn,
)
from lcm.utils.common import Batched, set_mkl_num_threads

BatchT_co = TypeVar("BatchT_co", bound=Union[Dict, Batched], covariant=True)

logger = logging.getLogger(__name__)


class DataLoader(ABC, Generic[BatchT_co, DatasetConfigT]):
    def __init__(
        self,
        data_config: DataLoadingConfig,
        datasets: Sequence[DatasetConfigT],
        gang: Gang,
        builder_func: Callable[..., DataPipeline] = default_build_fn,
        dtype: DataType = torch.float16,
    ):
        self.data_config = data_config
        self.datasets = list(map(create_dataset_config_from_cards, datasets))
        self.dtype = dtype
        self.gang = gang
        self.builder_func = builder_func
        self._pipeline: Optional[DataPipeline] = None
    @property
    def pipeline(self) -> DataPipeline:
        # Build lazily on first access, passing this worker's rank and the
        # gang size so the builder can shard the data across workers.
        if self._pipeline is None:
            logger.info(f"R{self.gang.rank} self._pipeline is None, building...")
            gang_rank = self.gang.rank if self.gang else 0
            world_size = self.gang.size if self.gang else 1
            self._pipeline = self.builder_func(
                self.datasets, self.data_config, gang_rank, world_size
            )
        assert self._pipeline, (
            f"Cannot build data pipeline from config {self.data_config}"
        )
        return self._pipeline
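
    # Note on the `builder_func` contract (inferred from the call above, not
    # documented elsewhere in this file): it takes
    # (datasets, data_config, rank, world_size) and returns a DataPipeline;
    # presumably it uses rank/world_size to shard examples across gang
    # workers, as the default `build_weighted_pipeline_with_renaming` does.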

    def destroy(self) -> None:
        """Destroy the pipeline to rebuild it with different shuffling."""
        self._pipeline = None
        # Build again and reset it
        logger.info(f"R{self.gang.rank} resetting the pipeline in DataLoader.destroy")
        self.reset()

    def reset(self) -> None:
        """
        Resetting the pipeline results in different shuffling on subsequent
        iterations, since the pipeline reuses the (now advanced) generator
        state from the previous pass. This is a desirable side effect in the
        `sharding_in_memory=False` (training) scenario.

        Illustrative example:

        >>> import torch
        >>> from fairseq2.data import read_sequence
        >>> def get_one_epoch_pipeline():
        ...     torch.manual_seed(13)
        ...     return read_sequence(list(range(10))).shuffle(5)
        >>> bb = get_one_epoch_pipeline().and_return()
        >>> list(bb)
        [3, 1, 2, 4, 0, 8, 5, 6, 9, 7]
        >>> bb.reset()
        >>> list(bb)
        [4, 0, 3, 2, 1, 9, 7, 6, 8, 5]
        """
        self.pipeline.reset()

    @abstractmethod
    def iterate_batches(self) -> Iterator[BatchT_co]: ...


class BaseDataLoader(DataLoader[dict, DatasetConfigT]):
    def __init__(
        self,
        data_config: DataLoadingConfig,
        datasets: Sequence[DatasetConfigT],
        dtype: DataType = torch.float16,
        gang: Optional[Gang] = None,
    ) -> None:
        # Fall back to a single-process FakeGang when no gang is provided.
        gang = gang or FakeGang()
        super().__init__(
            data_config=data_config,
            datasets=datasets,
            builder_func=default_build_fn,
            dtype=dtype,
            gang=gang,
        )
        set_mkl_num_threads()

    def iterate_batches(self) -> Iterator[dict]:
        yield from iter(self.pipeline)
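

# Example usage (a minimal sketch, not part of this module; it assumes
# `DataLoadingConfig` is default-constructible and that `my_dataset` is a
# dataset config/card accepted by `create_dataset_config_from_cards`):
#
#     loader = BaseDataLoader(
#         data_config=DataLoadingConfig(),
#         datasets=[my_dataset],
#     )
#     for batch in loader.iterate_batches():
#         ...  # each batch is a dict produced by the pipeline
#     loader.destroy()  # drop the pipeline; the next epoch rebuilds it with new shuffling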