LexaLCM_Pre0/lcm/datasets/dataloader.py
Lexa
Converted .pt files to safetensors, then (dirtily) patched fairseq to enable loading of safetensor files
b5a0bec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
import gc
import logging
from copy import deepcopy
from functools import partial
from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple
import pyarrow.compute as pc
import torch
from fairseq2.data.data_pipeline import DataPipeline, read_sequence
from fairseq2.data.text import TextTokenizer
from fairseq2.gang import FakeGang, Gang
from fairseq2.models.sequence import SequenceBatch
from fairseq2.nn.padding import pad_seqs
from fairseq2.typing import DataType
from fairseq2.utils.state import Stateful
from sonar.models.sonar_text import load_sonar_tokenizer
from lcm.datasets.base import DataLoader
from lcm.datasets.batch import LCMInput
from lcm.datasets.configs import (
ColumnsNames,
DataLoadingConfig,
ParquetDatasetConfig,
ParquetDatasetLimitOptions,
SonarDecoderConfig,
)
from lcm.datasets.utils import move_eos_to_the_end
from lcm.utils.common import set_mkl_num_threads
logger = logging.getLogger(__name__)
def truncate_sequence(tokens: torch.Tensor, max_len: int = 512) -> torch.Tensor:
    """Clip a 1-D token tensor to at most ``max_len`` elements.

    Tensors already within the limit are returned unchanged (same object);
    longer ones are sliced to their first ``max_len`` tokens.
    """
    return tokens if len(tokens) <= max_len else tokens[:max_len]
class LCMDataLoader(DataLoader[LCMInput, ParquetDatasetConfig], Stateful):
    """
    Data loader yielding `LCMInput` batches for LCM training/evaluation.

    On top of the base `DataLoader` pipeline it can optionally prepare
    flattened, subword-level token batches for a SONAR decoder (when
    ``use_decoder_backprop`` is enabled) and implements the `Stateful`
    protocol so the underlying data pipeline position can be checkpointed
    and restored.
    """

    def __init__(
        self,
        data_config: DataLoadingConfig,
        datasets: Sequence[ParquetDatasetConfig],
        dtype: DataType = torch.float16,
        use_decoder_backprop: bool = False,
        max_subword_length: int = 64,
        gang: Optional[Gang] = None,
        sonar_decoder_config: Optional[SonarDecoderConfig] = None,
    ) -> None:
        """
        Args:
            data_config: global data-loading options shared across datasets.
            datasets: parquet dataset configurations to read from.
            dtype: dtype the SONAR embeddings are cast to (default fp16).
            use_decoder_backprop: if True, also tokenize the source text at
                the subword level so a SONAR decoder can be trained through.
            max_subword_length: per-sentence cap on subword tokens.
            gang: distributed gang; a single-process `FakeGang` is used
                when None so the loader also works outside distributed runs.
            sonar_decoder_config: tokenizer config for the SONAR decoder;
                only consulted when `use_decoder_backprop` is True.
        """
        gang = gang or FakeGang()
        super().__init__(
            data_config=data_config,
            datasets=datasets,
            dtype=dtype,
            gang=gang,
        )
        set_mkl_num_threads()
        self.use_decoder_backprop = use_decoder_backprop
        # NOTE(review): `setup_sonar_decoder_tokenizer` stores a token
        # *encoder* here (result of `create_encoder`), so the
        # `TextTokenizer` annotation looks inexact — confirm against the
        # fairseq2 tokenizer API.
        self.sonar_tokenizer: Optional[TextTokenizer] = None
        self.max_subword_length = max_subword_length
        if sonar_decoder_config is not None:
            self.setup_sonar_decoder_tokenizer(config=sonar_decoder_config)
        # Lazily-built example cached and reused by `iterate_dummy_batches`.
        self._dummy_example: Optional[LCMInput] = None

    def setup_sonar_decoder_tokenizer(
        self,
        config: SonarDecoderConfig,
    ) -> None:
        """
        Load the SONAR tokenizer and build the target-side text encoder.

        No-op (clears `sonar_tokenizer`) when decoder backprop is disabled.
        """
        if self.use_decoder_backprop:
            # The tokenizer
            self.tokenizer = load_sonar_tokenizer(config.tokenizer, progress=False)
            # Target text encoder
            self.sonar_tokenizer = self.tokenizer.create_encoder(
                task="translation",
                lang=config.lang,
                mode="target",
                device=self.gang.device,
            )
        else:
            self.sonar_tokenizer = None

    def _prepare_subword_tokens(
        self, batch: Dict[str, Any]
    ) -> Tuple[Optional[SequenceBatch], Optional[SequenceBatch]]:
        """
        Given a batch of paragraphs/documents,
        prepare a batch of sentences (flattened) tokenized at the subword-level
        to feed to the SONAR decoder (a standard token-level decoder)

        Args:
            batch: attributes of a batch from the dataset.
                A batch is M documents/paragraphs each spanning
                a variable number of sentences {N_1, ..., N_M}.
                E.g., {'text_sentences': [[sent^1_1, ...sent^1_{N_1}],
                                          ...[sent^M_1, ... sent^M_{N_M}],
                      'text_sentences_sonar_emb': [X^1 in (N_1, D), ... X^M in (N_M, D)]}
                where D is the sonar embedding dimension.

        Returns:
            (prefix batch, target batch) of tokenized sentences
            (subword-level) shaped (sum_{i=1}^M N_i, max_len), where
            max_len is min(self.max_subword_length, max length of the
            sentences in the batch); (None, None) when decoder backprop
            is disabled.
        """
        if not self.use_decoder_backprop:
            return None, None
        # flatten the sentences from different documents/paragraphs
        flattened_source_text = (
            pc.list_flatten(batch[ColumnsNames.source_text_column.value])
            .to_pandas()
            .values
        )
        # Tokenize each sentence, then truncate to `max_subword_length`,
        # with parallelism scaled off the configured pipeline parallelism.
        pipeline: DataPipeline = (
            read_sequence(flattened_source_text)
            .map(
                [
                    self.sonar_tokenizer,  # type: ignore
                    partial(truncate_sequence, max_len=self.max_subword_length),
                ],
                num_parallel_calls=int(max(8 * self.data_config.num_parallel_calls, 1)),
            )
            .and_return(max_num_warnings=4)
        )
        # Pad the variable-length token sequences into a single batch.
        tokens_seqs, tokens_padding_mask = pad_seqs(list(pipeline))  # type: ignore
        prefix_batch = SequenceBatch(tokens_seqs, tokens_padding_mask)
        # TODO: instead of moving the EOS around, make the tokenizer append at the tokenization.
        target_batch = move_eos_to_the_end(
            prefix_batch,
            eos_token_id=self.tokenizer.vocab_info.eos_idx,
            pad_token_id=self.tokenizer.vocab_info.pad_idx,
        )
        return prefix_batch, target_batch

    def _tokenize_batch(self, batch: Dict[str, Any]) -> LCMInput:
        """
        Given a batch of documents,
        prepare a batch of input features for the LCM

        This step is to simply fetch the right column for source/target & source text
        and convert torch NestedTensors to list of tensors

        Args:
            batch: attributes of a batch from the dataset.
                A batch is M documents each spanning
                a variable number of sentences {N_1, ..., N_M}.
                E.g., {'text_sentences': [[sent^1_1, ...sent^1_{N_1}],
                                          ...[sent^M_1, ... sent^M_{N_M}],
                      'text_sentences_sonar_emb': [X^1 in (N_1, D), ... X^M in (N_M, D)]}
                where D is the sonar embedding dimension.

        Returns:
            LCMInput(
                source: SONAR embeddings of the source text
                    i.e [X^1 in (N_1, D), ... X^M in (N_M, D)]
                target: If supervised data: SONAR embeddings of the source text
                tokens: Tokenized flattened sentences for the SONAR decoder (see `_prepare_subword_tokens`)
            )
        """
        # Prepare sentence-wise subword tokens if needed:
        tokens, target_tokens = self._prepare_subword_tokens(batch)
        # Load target embeddings if requested and to propagate all other embeddings
        possible_emb_columns = {
            "source": ColumnsNames.source_column,
            "target": ColumnsNames.target_column,
        }
        outputs: Dict[str, Any] = {
            "tokens": tokens,
            "target_tokens": target_tokens,
            "name": batch[ColumnsNames.dataset_name.value],
            "batch": batch,
        }
        for key, col in possible_emb_columns.items():
            col_name = col.value
            if col_name in batch:
                # NOTE(review): no key in `possible_emb_columns` contains
                # "_length", so both "_length" guards below are currently
                # always-true — presumably kept for future length columns;
                # confirm before removing.
                dtype = self.dtype if "_length" not in key else torch.int64
                # Move embeddings to the gang device and cast to the target dtype.
                embs = [x.to(self.gang.device).to(dtype) for x in batch[col_name]]
                # Special case when some embeddings are not shaped as (T, D) e.g., XLMC's answer columns
                if embs[0].dim() == 1 and "_length" not in key:
                    embs = [t.unsqueeze(0) for t in embs]
            else:
                embs = None
            outputs[key] = embs
        assert outputs["source"] is not None, (
            "LCMDataLoader requires `source` sequences to be present in batches"
        )
        return LCMInput(**outputs)

    def iterate_batches(self) -> Iterator[LCMInput]:
        """Yield `LCMInput` batches produced from the underlying pipeline."""
        yield from map(self._tokenize_batch, self.pipeline)

    def iterate_dummy_batches(self) -> Iterator[LCMInput]:
        """
        it's needed to simulate the data that follows the structure of self.pipeline (by always returning the same element).
        It can be used only for fast forward pass (to avoid uneven sharding multi-gpus training).
        """
        if self._dummy_example is None:
            # patching the params to get less data with less cost
            limited_datasets = deepcopy(self.datasets)
            for ds_conf in limited_datasets:
                assert isinstance(ds_conf, ParquetDatasetConfig)
                ds_conf.limit = ParquetDatasetLimitOptions(nb_fragments=1)
            # Copy the true data config and reduce the batch size.
            # When wrapping data, we want to also wrap the dummy batches
            # to not exceed model max_length
            dummy_dataloading_config = deepcopy(self.data_config)
            dummy_dataloading_config.batch_size = 1
            # Build a minimal pipeline (rank 0, world size 1 — confirm the
            # builder_func signature) and cache its first batch.
            self._dummy_example = self._tokenize_batch(
                next(
                    iter(
                        self.builder_func(
                            limited_datasets, dummy_dataloading_config, 0, 1
                        )
                    )
                )
            )
            # Eagerly release the temporary pipeline/dataset objects.
            gc.collect()
        while True:
            yield self._dummy_example

    def state_dict(self) -> Dict[str, Any]:
        """Return the checkpointable state of the underlying data pipeline."""
        logger.info("Getting the data pipeline state ...")
        state = self.pipeline.state_dict(strict=False)
        return state

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """
        Restore the data pipeline position from `state_dict`.

        Restoration is skipped when `ignore_checkpointed_pipeline` is set;
        an incompatible state logs a warning instead of raising.
        """
        if state_dict is not None:
            assert self.pipeline is not None
            if self.data_config.ignore_checkpointed_pipeline:
                logger.warning("Ignoring existing dataloader state")
            else:
                try:
                    self.pipeline.load_state_dict(state_dict)
                    logger.info(f"Reloaded datapipeline state: {str(state_dict)[:400]}")
                except ValueError:
                    logger.warning(
                        f"Failed to load dataloader state: {str(state_dict)[:400]}"
                    )
        else:
            # retro-compatibility
            logger.warning(f"Attempt to restore a dataloader {self} with empty state")