File size: 3,657 Bytes
b5a0bec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

import logging
from abc import ABC, abstractmethod
from typing import Callable, Dict, Generic, Iterator, Optional, Sequence, TypeVar, Union

import torch
from fairseq2.data.data_pipeline import DataPipeline
from fairseq2.gang import FakeGang, Gang
from fairseq2.typing import DataType

from lcm.datasets.configs import (
    DataLoadingConfig,
    DatasetConfigT,
    create_dataset_config_from_cards,
)
from lcm.datasets.dataloading import (
    build_weighted_pipeline_with_renaming as default_build_fn,
)
from lcm.utils.common import Batched, set_mkl_num_threads

BatchT_co = TypeVar("BatchT_co", bound=Union[Dict, Batched], covariant=True)
logger = logging.getLogger(__name__)


class DataLoader(ABC, Generic[BatchT_co, DatasetConfigT]):
    def __init__(
        self,
        data_config: DataLoadingConfig,
        datasets: Sequence[DatasetConfigT],
        gang: Gang,
        builder_func: Callable[..., DataPipeline] = default_build_fn,
        dtype: DataType = torch.float16,
    ):
        self.data_config = data_config
        self.datasets = list(map(create_dataset_config_from_cards, datasets))
        self.dtype = dtype
        self.gang = gang
        self.builder_func = builder_func

        self._pipeline: Optional[DataPipeline] = None

    @property
    def pipeline(self) -> DataPipeline:
        if self._pipeline is None:
            logger.info(f"R{self.gang.rank} self._pipeline is None, building...")
            gang_rank = self.gang.rank if self.gang else 0
            world_size = self.gang.size if self.gang else 1

            self._pipeline = self.builder_func(
                self.datasets, self.data_config, gang_rank, world_size
            )
        assert self._pipeline, (
            f"Cannot build data pipeline from config {self.data_config}"
        )
        return self._pipeline

    def destroy(self) -> None:
        """Destroy the pipeline to rebuild it with different shuffling"""
        self._pipeline = None
        # Build again and reset it
        logger.info(f"R{self.gang.rank} resetting the pipeline in DataLoader.destroy")
        self.reset()

    def reset(self) -> None:
        """
        Applying reset will result in different shuffling for next iterations,
        since pipeline will use modified generator state from previous one.
        This's suitable side effect for `sharding_in_memory=False` (training) scenario.

        Illustrative example :
        >>> import torch
        >>> from fairseq2.data import read_sequence

        >>> def get_one_epoch_pipeline():
        ...     torch.manual_seed(13)
        ...     return read_sequence(list(range(10))).shuffle(5)

        >>> bb = get_one_epoch_pipeline().and_return()
        >>> list(bb)
        [3, 1, 2, 4, 0, 8, 5, 6, 9, 7]
        >>> bb.reset()
        >>> list(bb)
        [4, 0, 3, 2, 1, 9, 7, 6, 8, 5]
        """
        self.pipeline.reset()

    @abstractmethod
    def iterate_batches(self) -> Iterator[BatchT_co]: ...


class BaseDataLoader(DataLoader[dict, DatasetConfigT]):
    def __init__(
        self,
        data_config: DataLoadingConfig,
        datasets: Sequence[DatasetConfigT],
        dtype: DataType = torch.float16,
        gang: Gang = None,
    ) -> None:
        gang = gang or FakeGang()
        super().__init__(
            data_config=data_config,
            datasets=datasets,
            builder_func=default_build_fn,
            dtype=dtype,
            gang=gang,
        )
        set_mkl_num_threads()

    def iterate_batches(self) -> Iterator[dict]:
        yield from iter(self.pipeline)