File size: 9,934 Bytes
b5a0bec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

import gc
import logging
from copy import deepcopy
from functools import partial
from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple

import pyarrow.compute as pc
import torch
from fairseq2.data.data_pipeline import DataPipeline, read_sequence
from fairseq2.data.text import TextTokenizer
from fairseq2.gang import FakeGang, Gang
from fairseq2.models.sequence import SequenceBatch
from fairseq2.nn.padding import pad_seqs
from fairseq2.typing import DataType
from fairseq2.utils.state import Stateful
from sonar.models.sonar_text import load_sonar_tokenizer

from lcm.datasets.base import DataLoader
from lcm.datasets.batch import LCMInput
from lcm.datasets.configs import (
    ColumnsNames,
    DataLoadingConfig,
    ParquetDatasetConfig,
    ParquetDatasetLimitOptions,
    SonarDecoderConfig,
)
from lcm.datasets.utils import move_eos_to_the_end
from lcm.utils.common import set_mkl_num_threads

logger = logging.getLogger(__name__)


def truncate_sequence(tokens: torch.Tensor, max_len: int = 512) -> torch.Tensor:
    """Clip ``tokens`` to at most ``max_len`` entries along its first dimension.

    When ``tokens`` is already short enough, the very same object is returned;
    otherwise a leading slice ``tokens[:max_len]`` is returned.
    """
    return tokens if len(tokens) <= max_len else tokens[:max_len]


class LCMDataLoader(DataLoader[LCMInput, ParquetDatasetConfig], Stateful):
    """Convert batches of the underlying data pipeline into ``LCMInput`` objects.

    Every batch yielded by ``self.pipeline`` (provided by the ``DataLoader``
    base class) goes through ``_tokenize_batch``: the source/target embedding
    columns are moved to the gang's device and cast to ``self.dtype``, and,
    when ``use_decoder_backprop`` is enabled, the flattened source sentences
    are tokenized for the SONAR decoder.
    """

    def __init__(
        self,
        data_config: DataLoadingConfig,
        datasets: Sequence[ParquetDatasetConfig],
        dtype: DataType = torch.float16,
        use_decoder_backprop: bool = False,
        max_subword_length: int = 64,
        gang: Optional[Gang] = None,
        sonar_decoder_config: Optional[SonarDecoderConfig] = None,
    ) -> None:
        """
        Args:
            data_config: data loading options, forwarded to the base class.
            datasets: dataset configs, forwarded to the base class.
            dtype: dtype the embedding tensors are cast to.
            use_decoder_backprop: if True, also build subword tokens for the
                SONAR decoder; this needs ``sonar_decoder_config``.
            max_subword_length: maximum number of subword tokens kept per sentence.
            gang: process gang; a ``FakeGang`` is used when None.
            sonar_decoder_config: tokenizer name and language for the SONAR decoder.
        """
        gang = gang or FakeGang()

        super().__init__(
            data_config=data_config,
            datasets=datasets,
            dtype=dtype,
            gang=gang,
        )
        set_mkl_num_threads()

        self.use_decoder_backprop = use_decoder_backprop
        # Stays None unless `setup_sonar_decoder_tokenizer` loads an encoder.
        self.sonar_tokenizer: Optional[TextTokenizer] = None
        self.max_subword_length = max_subword_length
        if sonar_decoder_config is not None:
            self.setup_sonar_decoder_tokenizer(config=sonar_decoder_config)
        # Lazily built and cached by `iterate_dummy_batches`.
        self._dummy_example: Optional[LCMInput] = None

    def setup_sonar_decoder_tokenizer(
        self,
        config: SonarDecoderConfig,
    ) -> None:
        """Load the SONAR tokenizer and its target-side encoder.

        Only done when ``use_decoder_backprop`` is enabled; otherwise
        ``self.sonar_tokenizer`` is reset to None and nothing is loaded.
        """
        if self.use_decoder_backprop:
            # The tokenizer (its vocab_info provides the EOS/PAD ids later on)
            self.tokenizer = load_sonar_tokenizer(config.tokenizer, progress=False)
            # Target text encoder
            self.sonar_tokenizer = self.tokenizer.create_encoder(
                task="translation",
                lang=config.lang,
                mode="target",
                device=self.gang.device,
            )
        else:
            self.sonar_tokenizer = None

    def _prepare_subword_tokens(
        self, batch: Dict[str, Any]
    ) -> Tuple[Optional[SequenceBatch], Optional[SequenceBatch]]:
        """
        Given a batch of paragraphs/documents,
        prepare a batch of sentences (flattened) tokenized at the subword-level
        to feed to the SONAR decoder (a standard token-level decoder).

        Args:
            batch: attributes of a batch from the dataset.
                    A batch is M documents/paragraphs each spanning
                    a variable number of sentences {N_1, ..., N_M}.

            E.g., {'text_sentences': [[sent^1_1, ...sent^1_{N_1}],
                                        ...[sent^M_1, ... sent^M_{N_M}],
                  'text_sentences_sonar_emb': [X^1 in (N_1, D), ... X^M in (N_M, D)]}
                  where D is the sonar embedding dimension.
        Returns:
            ``(None, None)`` when ``use_decoder_backprop`` is False. Otherwise
            ``(prefix_batch, target_batch)``:
            - prefix_batch: tokenized sentences (subword-level) in
              (\\sum_i=1^M N_i, max_len), where max_len is
              min(self.max_subword_length, max length of the sentences in the batch);
            - target_batch: prefix_batch passed through ``move_eos_to_the_end``.
        Raises:
            RuntimeError: if ``use_decoder_backprop`` is True but no SONAR
                tokenizer was set up (no ``sonar_decoder_config`` was given).
        """

        if not self.use_decoder_backprop:
            return None, None

        if self.sonar_tokenizer is None:
            # Without this check the pipeline below would be given `None` as a
            # map function and `self.tokenizer` would not exist.
            raise RuntimeError(
                "use_decoder_backprop=True requires a SONAR tokenizer: pass "
                "`sonar_decoder_config` or call `setup_sonar_decoder_tokenizer`."
            )

        # flatten the sentences from different documents/paragraphs
        flattened_source_text = (
            pc.list_flatten(batch[ColumnsNames.source_text_column.value])
            .to_pandas()
            .values
        )

        pipeline: DataPipeline = (
            read_sequence(flattened_source_text)
            .map(
                [
                    self.sonar_tokenizer,  # type: ignore
                    partial(truncate_sequence, max_len=self.max_subword_length),
                ],
                num_parallel_calls=int(max(8 * self.data_config.num_parallel_calls, 1)),
            )
            .and_return(max_num_warnings=4)
        )

        tokens_seqs, tokens_padding_mask = pad_seqs(list(pipeline))  # type: ignore
        prefix_batch = SequenceBatch(tokens_seqs, tokens_padding_mask)
        # TODO: instead of moving the EOS around, make the tokenizer append at the tokenization.
        target_batch = move_eos_to_the_end(
            prefix_batch,
            eos_token_id=self.tokenizer.vocab_info.eos_idx,
            pad_token_id=self.tokenizer.vocab_info.pad_idx,
        )

        return prefix_batch, target_batch

    def _tokenize_batch(self, batch: Dict[str, Any]) -> LCMInput:
        """
        Given a batch of documents,
        prepare a batch of input features for the LCM.
        This step simply fetches the right columns for source/target & source text
        and converts torch NestedTensors to lists of tensors.

        Args:
            batch: attributes of a batch from the dataset.
                    A batch is M documents each spanning
                    a variable number of sentences {N_1, ..., N_M}.

            E.g., {'text_sentences': [[sent^1_1, ...sent^1_{N_1}],
                                        ...[sent^M_1, ... sent^M_{N_M}],
                  'text_sentences_sonar_emb': [X^1 in (N_1, D), ... X^M in (N_M, D)]}
                  where D is the sonar embedding dimension.
        Returns:
            LCMInput(
            source: SONAR embeddings of the source text
                i.e [X^1 in (N_1, D), ... X^M in (N_M, D)]
            target: embeddings from the target column, or None when the batch
                has no such column (e.g. unsupervised data)
            tokens / target_tokens: tokenized flattened sentences for the
                SONAR decoder (see `_prepare_subword_tokens`), or None
            name: the dataset-name column of the batch
            batch: the raw batch itself
            )
        """

        # Prepare sentence-wise subword tokens if needed:
        tokens, target_tokens = self._prepare_subword_tokens(batch)

        # Load target embeddings if requested and propagate all other embeddings
        possible_emb_columns = {
            "source": ColumnsNames.source_column,
            "target": ColumnsNames.target_column,
        }

        outputs = {
            "tokens": tokens,
            "target_tokens": target_tokens,
            "name": batch[ColumnsNames.dataset_name.value],
            "batch": batch,
        }
        device = self.gang.device
        for key, col in possible_emb_columns.items():
            col_name = col.value
            if col_name not in batch:
                outputs[key] = None
                continue
            # None of the current keys contains "_length"; the check keeps any
            # length-style key (if one is added) as int64 and un-reshaped.
            is_length = "_length" in key
            dtype = torch.int64 if is_length else self.dtype
            embs = [x.to(device=device, dtype=dtype) for x in batch[col_name]]
            # Special case when some embeddings are not shaped as (T, D) e.g., XLMC's answer columns.
            # `embs and ...` avoids an IndexError on an empty column.
            if embs and embs[0].dim() == 1 and not is_length:
                embs = [t.unsqueeze(0) for t in embs]
            outputs[key] = embs
        assert outputs["source"] is not None, (
            "LCMDataLoader requires `source` sequences to be present in batches"
        )
        return LCMInput(**outputs)

    def iterate_batches(self) -> Iterator[LCMInput]:
        """Yield one ``LCMInput`` per batch of ``self.pipeline``."""
        yield from map(self._tokenize_batch, self.pipeline)

    def iterate_dummy_batches(self) -> Iterator[LCMInput]:
        """
        Simulate data that follows the structure of self.pipeline by endlessly
        yielding the same element.
        It should only be used for fast forward passes (to avoid uneven sharding
        in multi-GPU training).
        """
        if self._dummy_example is None:
            # patching the params to get less data with less cost
            limited_datasets = deepcopy(self.datasets)
            for ds_conf in limited_datasets:
                assert isinstance(ds_conf, ParquetDatasetConfig)
                ds_conf.limit = ParquetDatasetLimitOptions(nb_fragments=1)

            # Copy the true data config and reduce the batch size.
            # When wrapping data, we want to also wrap the dummy batches
            # to not exceed model max_length
            dummy_dataloading_config = deepcopy(self.data_config)
            dummy_dataloading_config.batch_size = 1

            # NOTE(review): the trailing `0, 1` presumably mean rank 0 of a
            # world of size 1 — confirm against `builder_func`.
            self._dummy_example = self._tokenize_batch(
                next(
                    iter(
                        self.builder_func(
                            limited_datasets, dummy_dataloading_config, 0, 1
                        )
                    )
                )
            )
        gc.collect()

        while True:
            yield self._dummy_example

    def state_dict(self) -> Dict[str, Any]:
        """Return the (non-strict) state of the underlying data pipeline."""
        logger.info("Getting the data pipeline state ...")
        state = self.pipeline.state_dict(strict=False)
        return state

    def load_state_dict(self, state_dict: Optional[Mapping[str, Any]]) -> None:
        """Restore the data pipeline from ``state_dict``.

        A None state only logs a warning (retro-compatibility). The state is also
        ignored when ``data_config.ignore_checkpointed_pipeline`` is set. A
        ``ValueError`` raised by the pipeline is logged and swallowed.
        """
        if state_dict is None:
            # retro-compatibility
            logger.warning("Attempt to restore a dataloader %s with empty state", self)
            return

        assert self.pipeline is not None
        if self.data_config.ignore_checkpointed_pipeline:
            logger.warning("Ignoring existing dataloader state")
            return

        # Only a prefix of the state is logged to keep log lines bounded.
        state_preview = str(state_dict)[:400]
        try:
            self.pipeline.load_state_dict(state_dict)
        except ValueError:
            logger.warning("Failed to load dataloader state: %s", state_preview)
        else:
            logger.info("Reloaded datapipeline state: %s", state_preview)