# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import torch

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
    get_indexed_dataset_,
    get_samples_mapping,
)
from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import TextMemMapDataset
from nemo.core.classes import Dataset
from nemo.utils import logging

__all__ = ['SequenceToSequenceDataset', 'TextMemmapSequenceToSequenceDataset', 'BinarizedMemmapSequenceToSequenceDataset']


class SequenceToSequenceDataset(Dataset):
    """Sequence to Sequence Dataset in memory."""

    def __init__(
        self,
        src_file_name: str,
        tgt_file_name: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        add_bos_to_input: bool = True,
        add_eos_to_input: bool = True,
        replace_bos_with_pad: bool = False,
    ):
        super().__init__()
        self.src_file_name = src_file_name
        self.tgt_file_name = tgt_file_name
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_src_seq_length = max_src_seq_length
        self.max_tgt_seq_length = max_tgt_seq_length
        self.add_bos_to_input = add_bos_to_input
        self.add_eos_to_input = add_eos_to_input
        self.replace_bos_with_pad = replace_bos_with_pad
        assert self.max_src_seq_length > 0
        assert self.max_tgt_seq_length > 0
        self._check_files_exist()
        self._get_examples()

    def _check_files_exist(self):
        if not os.path.exists(self.src_file_name):
            raise FileNotFoundError(f"Source file {self.src_file_name} not found")
        if not os.path.exists(self.tgt_file_name):
            raise FileNotFoundError(f"Target file {self.tgt_file_name} not found")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        text_enc = example['src']
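        # Shift the target by one token for teacher forcing: the decoder input
        # keeps BOS and drops the final token, while the labels drop BOS, so that
        # position i of the labels is the token the decoder should predict at step i.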
        text_dec = example['tgt'][:-1]
        labels = example['tgt'][1:]
        return {'text_enc': text_enc, 'text_dec': text_dec, 'labels': labels}

    def _get_examples(self):
        self.examples = []
        with open(self.src_file_name, encoding='utf8') as f_src, open(self.tgt_file_name, encoding='utf8') as f_tgt:
            for i, (src, tgt) in enumerate(zip(f_src, f_tgt)):
                if i % 10000 == 0 and i != 0:
                    logging.info(f"Read {i} lines from {self.src_file_name} & {self.tgt_file_name}")
                src = self.src_tokenizer.text_to_ids(src.strip())
                if self.add_bos_to_input:
                    src = [self.src_tokenizer.pad_id if self.replace_bos_with_pad else self.src_tokenizer.bos_id] + src
                if self.add_eos_to_input:
                    src = src + [self.src_tokenizer.eos_id]
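                # The target is always wrapped with BOS and EOS; replace_bos_with_pad
                # swaps BOS for PAD, matching models (e.g. T5-style) whose decoder
                # start token is PAD rather than BOS.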
                tgt = (
                    [self.tgt_tokenizer.pad_id if self.replace_bos_with_pad else self.tgt_tokenizer.bos_id]
                    + self.tgt_tokenizer.text_to_ids(tgt.strip())
                    + [self.tgt_tokenizer.eos_id]
                )
                # Truncate from the left, keeping at most (max_seq_length - 1) of the final tokens.
                if len(src) > self.max_src_seq_length:
                    src = src[-self.max_src_seq_length + 1 :]
                if len(tgt) > self.max_tgt_seq_length:
                    tgt = tgt[-self.max_tgt_seq_length + 1 :]
                self.examples.append({'src': src, 'tgt': tgt})
        logging.info(f'Dataset Length : {len(self.examples)}')

    def collate_fn(self, batch):
        text_enc = [item['text_enc'] for item in batch]
        text_dec = [item['text_dec'] for item in batch]
        labels = [item['labels'] for item in batch]
        if isinstance(text_enc[0], np.ndarray):
            text_enc = [x.tolist() for x in text_enc]
        if isinstance(text_dec[0], np.ndarray):
            text_dec = [x.tolist() for x in text_dec]
        if isinstance(labels[0], np.ndarray):
            labels = [x.tolist() for x in labels]
        max_dec_input_length = max([len(item) for item in text_dec]) if text_dec else 0
        max_enc_input_length = max([len(item) for item in text_enc]) if text_enc else 0
        max_label_length = max([len(item) for item in labels]) if labels else 0
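        # Build the loss mask before padding: 1 for real label tokens, 0 for the
        # padding positions, so padded tokens do not contribute to the loss.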
        loss_mask = [([1] * (len(item))) + ([0] * (max_label_length - len(item))) for item in labels]
        text_enc = [item + [self.src_tokenizer.pad_id] * (max_enc_input_length - len(item)) for item in text_enc]
        text_dec = [item + [self.tgt_tokenizer.pad_id] * (max_dec_input_length - len(item)) for item in text_dec]
        labels = [item + [self.tgt_tokenizer.pad_id] * (max_label_length - len(item)) for item in labels]
        text_enc = torch.LongTensor(text_enc)
        text_dec = torch.LongTensor(text_dec)
        labels = torch.LongTensor(labels)
        loss_mask = torch.LongTensor(loss_mask)
        enc_mask = (text_enc != self.src_tokenizer.pad_id).long()
        dec_mask = (text_dec != self.tgt_tokenizer.pad_id).long()
        return {
            'text_enc': text_enc,
            'text_dec': text_dec,
            'labels': labels,
            'loss_mask': loss_mask,
            'enc_mask': enc_mask,
            'dec_mask': dec_mask,
        }


class IndexedSequenceToSequenceDataset(SequenceToSequenceDataset):
    """Abstract base class for TextMemmapSequenceToSequenceDataset and BinarizedMemmapSequenceToSequenceDataset.

    It is not meant to be used standalone, only as a base class for the two subclasses.
    """

    def __init__(
        self,
        src_file_name: str,
        tgt_file_name: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        seed: int = 1234,
        add_bos_to_enc: bool = True,
        add_eos_to_enc: bool = True,
        max_num_samples: int = None,
        prepend_id: int = None,
    ):
        """
        src_file_name: Path to a single source file on disk. This is either the path to a raw text file or the prefix to the processed src_file_name.bin/idx files.
        tgt_file_name: Path to a single target file on disk. This is either the path to a raw text file or the prefix to the processed tgt_file_name.bin/idx files.
        src_tokenizer: Tokenizer for the source dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        tgt_tokenizer: Tokenizer for the target dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        max_src_seq_length: Maximum length of the source sequences. Lines above this length will be truncated.
        max_tgt_seq_length: Maximum length of the target sequences. Lines above this length will be truncated.
        seed: Random seed for data shuffling.
        add_bos_to_enc: Add BOS token to the encoder input.
        add_eos_to_enc: Add EOS token to the encoder input.
        max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded.
        prepend_id: If not None, prepend this id to the encoder input.
        """
        super().__init__(
            src_file_name=src_file_name,
            tgt_file_name=tgt_file_name,
            src_tokenizer=src_tokenizer,
            tgt_tokenizer=tgt_tokenizer,
            max_src_seq_length=max_src_seq_length,
            max_tgt_seq_length=max_tgt_seq_length,
        )
        self.seed = seed
        self.max_num_samples = max_num_samples
        self.add_bos_to_enc = add_bos_to_enc
        self.add_eos_to_enc = add_eos_to_enc
        self.prepend_id = prepend_id

        logging.info(f'Desired number of samples : {self.max_num_samples}')
        logging.info(f'Source Dataset Length : {len(self.src_indexed_dataset)}')
        logging.info(f'Target Dataset Length : {len(self.tgt_indexed_dataset)}')

    def __len__(self):
        if self.max_num_samples is None:
            return len(self.src_indexed_dataset)
        else:
            return self.max_num_samples

    def _get_sample(self, idx):
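        # idx can arrive as a numpy integer (e.g. from the samples mapping);
        # normalize it to a Python int before using it as an index.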
        if isinstance(idx, np.int64):
            idx = idx.item()

        if self.samples_mapping is not None:
            assert idx < len(self.samples_mapping)
            idx, _, _ = self.samples_mapping[idx]
            if isinstance(idx, np.uint32):
                idx = idx.item()

        assert idx < len(self.src_indexed_dataset)
        src = self.src_indexed_dataset[idx]
        tgt = self.tgt_indexed_dataset[idx]
        return src, tgt

    def __getitem__(self, idx):
        src, tgt = self._get_sample(idx)
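        # Budget for the special tokens added below (BOS, EOS, prepend_id) so that
        # the final encoder sequence still fits within max_src_seq_length.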
        offset = 0
        if self.add_bos_to_enc:
            offset += 1
        if self.add_eos_to_enc:
            offset += 1
        if self.prepend_id is not None:
            offset += 1

        if len(src) > self.max_src_seq_length - offset:
            src = src[: self.max_src_seq_length - offset]

        if self.add_bos_to_enc:
            src = np.concatenate([[self.src_tokenizer.bos_id], src])

        if self.prepend_id is not None:
            src = np.concatenate([[self.prepend_id], src])

        if self.add_eos_to_enc:
            src = np.concatenate([src, [self.src_tokenizer.eos_id]])
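        # Reserve two target positions: one for the decoder-input BOS and one for
        # the label EOS appended below.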
        if len(tgt) > self.max_tgt_seq_length - 2:
            tgt = tgt[: self.max_tgt_seq_length - 2]

        text_dec = np.concatenate([[self.tgt_tokenizer.bos_id], tgt])
        labels = np.concatenate([tgt, [self.tgt_tokenizer.eos_id]])
        return {'text_enc': src, 'text_dec': text_dec, 'labels': labels}

    def _build_samples_mapping(self):
        if self.max_num_samples is not None:
            # Samples-mapping-based training currently requires equal max src and tgt lengths.
            if self.max_src_seq_length != self.max_tgt_seq_length:
                raise ValueError(
                    f"max_src_seq_length ({self.max_src_seq_length}) != max_tgt_seq_length ({self.max_tgt_seq_length}). This is needed for max_samples based training for now."
                )
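            # Megatron-style samples mapping over the source dataset; max_seq_length
            # is reduced by 2, presumably to leave room for the BOS/EOS added in __getitem__.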
            self.samples_mapping = get_samples_mapping(
                indexed_dataset=self.src_indexed_dataset,
                data_prefix=self.src_file_name,
                num_epochs=None,
                max_num_samples=self.max_num_samples,
                max_seq_length=self.max_src_seq_length - 2,
                short_seq_prob=0,
                seed=self.seed,
                name=self.src_file_name.split('/')[-1],
                binary_head=False,
            )
        else:
            self.samples_mapping = None


class TextMemmapSequenceToSequenceDataset(IndexedSequenceToSequenceDataset):
    """Memory-mapped sequence to sequence dataset. Operates on raw text files and tokenizes the text on-the-fly."""

    def __init__(
        self,
        src_file_name: str,
        tgt_file_name: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        seed: int = 1234,
        max_num_samples: int = None,
        add_bos_to_enc: bool = True,
        add_eos_to_enc: bool = True,
        prepend_id: int = None,
    ):
        """
        src_file_name: Path to a single source file on disk. The file should contain one sentence per line and be raw text.
        tgt_file_name: Path to a single target file on disk. The file should contain one sentence per line aligned with src_file_name and be raw text.
        src_tokenizer: Tokenizer for the source dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        tgt_tokenizer: Tokenizer for the target dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        max_src_seq_length: Maximum length of the source sequences. Lines above this length will be truncated.
        max_tgt_seq_length: Maximum length of the target sequences. Lines above this length will be truncated.
        seed: Random seed for data shuffling.
        max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded.
        add_bos_to_enc: Add BOS token to the encoder input.
        add_eos_to_enc: Add EOS token to the encoder input.
        prepend_id: If not None, prepend this id to the encoder input.
        """
        self.seed = seed
        self.max_num_samples = max_num_samples
        super().__init__(
            src_file_name=src_file_name,
            tgt_file_name=tgt_file_name,
            src_tokenizer=src_tokenizer,
            tgt_tokenizer=tgt_tokenizer,
            max_src_seq_length=max_src_seq_length,
            max_tgt_seq_length=max_tgt_seq_length,
            seed=seed,
            max_num_samples=max_num_samples,
            add_bos_to_enc=add_bos_to_enc,
            add_eos_to_enc=add_eos_to_enc,
            prepend_id=prepend_id,
        )

    def _get_examples(self):
        self.src_indexed_dataset = TextMemMapDataset(
            dataset_paths=[self.src_file_name], tokenizer=self.src_tokenizer, header_lines=0
        )
        self.tgt_indexed_dataset = TextMemMapDataset(
            dataset_paths=[self.tgt_file_name], tokenizer=self.tgt_tokenizer, header_lines=0
        )

        assert len(self.src_indexed_dataset) == len(
            self.tgt_indexed_dataset
        ), "src and tgt have different numbers of lines"
        self._build_samples_mapping()


class BinarizedMemmapSequenceToSequenceDataset(IndexedSequenceToSequenceDataset):
    """Memory-mapped sequence to sequence dataset. Operates on pre-tokenized, binarized data files."""

    def __init__(
        self,
        src_dataset_prefix: str,
        tgt_dataset_prefix: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        seed: int = 1234,
        max_num_samples: int = None,
        add_bos_to_enc: bool = True,
        add_eos_to_enc: bool = True,
        prepend_id: int = None,
    ):
        """
        src_dataset_prefix: Path to the *prefix* of a single source bin/idx file on disk. This requires the existence of src_dataset_prefix.bin and src_dataset_prefix.idx.
        tgt_dataset_prefix: Path to the *prefix* of a single target bin/idx file on disk, aligned with the source. This requires the existence of tgt_dataset_prefix.bin and tgt_dataset_prefix.idx.
        src_tokenizer: Tokenizer for the source dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        tgt_tokenizer: Tokenizer for the target dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        max_src_seq_length: Maximum length of the source sequences. Lines above this length will be truncated.
        max_tgt_seq_length: Maximum length of the target sequences. Lines above this length will be truncated.
        seed: Random seed for data shuffling.
        max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded.
        add_bos_to_enc: Add BOS token to the encoder input.
        add_eos_to_enc: Add EOS token to the encoder input.
        prepend_id: If not None, prepend this id to the encoder input.
        """
        self.src_dataset_prefix = src_dataset_prefix
        self.tgt_dataset_prefix = tgt_dataset_prefix
        self.seed = seed
        self.max_num_samples = max_num_samples
        super().__init__(
            src_file_name=src_dataset_prefix,
            tgt_file_name=tgt_dataset_prefix,
            src_tokenizer=src_tokenizer,
            tgt_tokenizer=tgt_tokenizer,
            max_src_seq_length=max_src_seq_length,
            max_tgt_seq_length=max_tgt_seq_length,
            seed=seed,
            max_num_samples=max_num_samples,
            add_bos_to_enc=add_bos_to_enc,
            add_eos_to_enc=add_eos_to_enc,
            prepend_id=prepend_id,
        )

    def _check_files_exist(self):
        if not os.path.exists(self.src_dataset_prefix + ".bin") or not os.path.exists(
            self.src_dataset_prefix + ".idx"
        ):
            raise FileNotFoundError(f"{self.src_dataset_prefix}.bin or {self.src_dataset_prefix}.idx not found")
        if not os.path.exists(self.tgt_dataset_prefix + ".bin") or not os.path.exists(
            self.tgt_dataset_prefix + ".idx"
        ):
            raise FileNotFoundError(f"{self.tgt_dataset_prefix}.bin or {self.tgt_dataset_prefix}.idx not found")

    def _get_examples(self):
        self.src_indexed_dataset = self._get_indexed_dataset(
            self.src_dataset_prefix, data_impl='mmap', skip_warmup=True
        )
        self.tgt_indexed_dataset = self._get_indexed_dataset(
            self.tgt_dataset_prefix, data_impl='mmap', skip_warmup=True
        )
        assert len(self.src_indexed_dataset) == len(self.tgt_indexed_dataset)
        self._build_samples_mapping()

    def _get_indexed_dataset(self, data_prefix, data_impl, skip_warmup):
        indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
        return indexed_dataset
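

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the NeMo API): builds an
    # in-memory SequenceToSequenceDataset and batches it with the dataset's own
    # collate_fn. The data file paths and the SentencePiece model path below are
    # placeholders; any TokenizerSpec implementation with bos/eos/pad ids works.
    from torch.utils.data import DataLoader

    from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

    tokenizer = SentencePieceTokenizer(model_path='tokenizer.model')  # placeholder model path
    dataset = SequenceToSequenceDataset(
        src_file_name='train.src.txt',  # placeholder: one source sentence per line
        tgt_file_name='train.tgt.txt',  # placeholder: aligned target sentences
        src_tokenizer=tokenizer,
        tgt_tokenizer=tokenizer,
        max_src_seq_length=512,
        max_tgt_seq_length=512,
    )
    loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=dataset.collate_fn)
    batch = next(iter(loader))
    # Each value is a padded LongTensor: text_enc, text_dec, labels, loss_mask,
    # enc_mask, dec_mask.
    print({k: tuple(v.shape) for k, v in batch.items()})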