import os

import numpy as np
import torch

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
    get_indexed_dataset_,
    get_samples_mapping,
)
from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import TextMemMapDataset
from nemo.core.classes import Dataset
from nemo.utils import logging

__all__ = [
    'SequenceToSequenceDataset',
    'TextMemmapSequenceToSequenceDataset',
    'BinarizedMemmapSequenceToSequenceDataset',
]


class SequenceToSequenceDataset(Dataset):
    """Sequence to Sequence Dataset in memory."""

    def __init__(
        self,
        src_file_name: str,
        tgt_file_name: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        add_bos_to_input: bool = True,
        add_eos_to_input: bool = True,
        replace_bos_with_pad: bool = False,
    ):
        super().__init__()
        self.src_file_name = src_file_name
        self.tgt_file_name = tgt_file_name
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_src_seq_length = max_src_seq_length
        self.max_tgt_seq_length = max_tgt_seq_length
        self.add_bos_to_input = add_bos_to_input
        self.add_eos_to_input = add_eos_to_input
        self.replace_bos_with_pad = replace_bos_with_pad
        assert self.max_src_seq_length > 0
        assert self.max_tgt_seq_length > 0
        self._check_files_exist()
        self._get_examples()

    def _check_files_exist(self):
        if not os.path.exists(self.src_file_name):
            raise FileNotFoundError(f"Source file {self.src_file_name} not found")
        if not os.path.exists(self.tgt_file_name):
            raise FileNotFoundError(f"Target file {self.tgt_file_name} not found")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        text_enc = example['src']
        text_dec = example['tgt'][:-1]
        labels = example['tgt'][1:]
        return {'text_enc': text_enc, 'text_dec': text_dec, 'labels': labels}

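    # Illustrative sketch of the one-token shift above (token ids are made up).
    # Given example['tgt'] == [bos_id, 11, 12, eos_id]:
    #   text_dec = example['tgt'][:-1]  # -> [bos_id, 11, 12]  decoder input, EOS dropped
    #   labels   = example['tgt'][1:]   # -> [11, 12, eos_id]  targets, BOS dropped
    # i.e., the decoder is trained to predict each target token from the
    # previous ones (standard teacher forcing).
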
    def _get_examples(self):
        self.examples = []
        with open(self.src_file_name, encoding='utf8') as f_src, open(self.tgt_file_name, encoding='utf8') as f_tgt:
            for i, (src, tgt) in enumerate(zip(f_src, f_tgt)):
                if i % 10000 == 0 and i != 0:
                    logging.info(f"Read {i} lines from {self.src_file_name} & {self.tgt_file_name}")
                src = self.src_tokenizer.text_to_ids(src.strip())
                if self.add_bos_to_input:
                    src = [self.src_tokenizer.pad_id if self.replace_bos_with_pad else self.src_tokenizer.bos_id] + src
                if self.add_eos_to_input:
                    src = src + [self.src_tokenizer.eos_id]

                tgt = (
                    [self.tgt_tokenizer.pad_id if self.replace_bos_with_pad else self.tgt_tokenizer.bos_id]
                    + self.tgt_tokenizer.text_to_ids(tgt.strip())
                    + [self.tgt_tokenizer.eos_id]
                )

                # Truncate over-long sequences, keeping the last (max_len - 1) tokens.
                if len(src) > self.max_src_seq_length:
                    src = src[-self.max_src_seq_length + 1 :]
                if len(tgt) > self.max_tgt_seq_length:
                    tgt = tgt[-self.max_tgt_seq_length + 1 :]
                self.examples.append({'src': src, 'tgt': tgt})

        logging.info(f'Dataset Length : {len(self.examples)}')

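    # Illustrative sketch of the example construction above (hypothetical
    # tokenizer output): if text_to_ids("hello world") == [5, 6], then with the
    # default flags
    #   src = [bos_id, 5, 6, eos_id]
    #   tgt = [bos_id, <target ids>, eos_id]
    # and __getitem__ later derives text_dec/labels from tgt via the one-token
    # shift. Note that truncation here keeps the *tail* of an over-long line.
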
    def collate_fn(self, batch):
        text_enc = [item['text_enc'] for item in batch]
        text_dec = [item['text_dec'] for item in batch]
        labels = [item['labels'] for item in batch]

        if isinstance(text_enc[0], np.ndarray):
            text_enc = [x.tolist() for x in text_enc]

        if isinstance(text_dec[0], np.ndarray):
            text_dec = [x.tolist() for x in text_dec]

        if isinstance(labels[0], np.ndarray):
            labels = [x.tolist() for x in labels]

        max_dec_input_length = max([len(item) for item in text_dec]) if text_dec else 0
        max_enc_input_length = max([len(item) for item in text_enc]) if text_enc else 0
        max_label_length = max([len(item) for item in labels]) if labels else 0

        loss_mask = [([1] * (len(item))) + ([0] * (max_label_length - len(item))) for item in labels]
        text_enc = [item + [self.src_tokenizer.pad_id] * (max_enc_input_length - len(item)) for item in text_enc]
        text_dec = [item + [self.tgt_tokenizer.pad_id] * (max_dec_input_length - len(item)) for item in text_dec]
        labels = [item + [self.tgt_tokenizer.pad_id] * (max_label_length - len(item)) for item in labels]

        text_enc = torch.LongTensor(text_enc)
        text_dec = torch.LongTensor(text_dec)
        labels = torch.LongTensor(labels)
        loss_mask = torch.LongTensor(loss_mask)

        enc_mask = (text_enc != self.src_tokenizer.pad_id).long()
        dec_mask = (text_dec != self.tgt_tokenizer.pad_id).long()

        return {
            'text_enc': text_enc,
            'text_dec': text_dec,
            'labels': labels,
            'loss_mask': loss_mask,
            'enc_mask': enc_mask,
            'dec_mask': dec_mask,
        }

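    # Illustrative sketch of the padding above (made-up ids, pad_id == 0): a
    # batch with labels [[11, 12], [21]] collates to
    #   labels    = [[11, 12], [21, 0]]
    #   loss_mask = [[1, 1],  [1, 0]]
    # so padded positions are excluded from the loss; enc_mask and dec_mask are
    # built the same way by comparing the inputs against pad_id.
    #
    # Minimal usage sketch (hypothetical file paths; `tokenizer` is any
    # TokenizerSpec implementation exposing bos_id/eos_id/pad_id):
    #
    #   from torch.utils.data import DataLoader
    #   ds = SequenceToSequenceDataset(
    #       src_file_name='train.src',
    #       tgt_file_name='train.tgt',
    #       src_tokenizer=tokenizer,
    #       tgt_tokenizer=tokenizer,
    #       max_src_seq_length=512,
    #       max_tgt_seq_length=512,
    #   )
    #   loader = DataLoader(ds, batch_size=32, collate_fn=ds.collate_fn)

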
class IndexedSequenceToSequenceDataset(SequenceToSequenceDataset):
    """Abstract base class for TextMemmapSequenceToSequenceDataset and BinarizedMemmapSequenceToSequenceDataset.

    This class is not meant to be used standalone; it only factors out logic shared by the two subclasses.
    """

    def __init__(
        self,
        src_file_name: str,
        tgt_file_name: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        seed: int = 1234,
        add_bos_to_enc: bool = True,
        add_eos_to_enc: bool = True,
        max_num_samples: int = None,
        prepend_id: int = None,
    ):
        """
        src_file_name: Path to a single source file on disk. This is either the path to a raw text file or the prefix to the processed src_file_name.bin/idx files.
        tgt_file_name: Path to a single target file on disk. This is either the path to a raw text file or the prefix to the processed tgt_file_name.bin/idx files.
        src_tokenizer: Tokenizer for the source dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        tgt_tokenizer: Tokenizer for the target dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        max_src_seq_length: Maximum length of the source sequences. Lines above this length will be truncated.
        max_tgt_seq_length: Maximum length of the target sequences. Lines above this length will be truncated.
        seed: Random seed for data shuffling.
        add_bos_to_enc: Add BOS token to the encoder input.
        add_eos_to_enc: Add EOS token to the encoder input.
        max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded.
        prepend_id: If not None, prepend this id to the encoder input.
        """
        super().__init__(
            src_file_name=src_file_name,
            tgt_file_name=tgt_file_name,
            src_tokenizer=src_tokenizer,
            tgt_tokenizer=tgt_tokenizer,
            max_src_seq_length=max_src_seq_length,
            max_tgt_seq_length=max_tgt_seq_length,
        )
        self.seed = seed
        self.max_num_samples = max_num_samples
        self.add_bos_to_enc = add_bos_to_enc
        self.add_eos_to_enc = add_eos_to_enc
        self.prepend_id = prepend_id

        logging.info(f'Desired number of samples : {self.max_num_samples}')
        logging.info(f'Source Dataset Length : {len(self.src_indexed_dataset)}')
        logging.info(f'Target Dataset Length : {len(self.tgt_indexed_dataset)}')

    def __len__(self):
        if self.max_num_samples is None:
            return len(self.src_indexed_dataset)
        else:
            return self.max_num_samples

    def _get_sample(self, idx):
        if isinstance(idx, np.int64):
            idx = idx.item()

        if self.samples_mapping is not None:
            assert idx < len(self.samples_mapping)
            idx, _, _ = self.samples_mapping[idx]
            if isinstance(idx, np.uint32):
                idx = idx.item()

        assert idx < len(self.src_indexed_dataset)
        src = self.src_indexed_dataset[idx]
        tgt = self.tgt_indexed_dataset[idx]

        return src, tgt

    def __getitem__(self, idx):
        src, tgt = self._get_sample(idx)
        offset = 0
        if self.add_bos_to_enc:
            offset += 1
        if self.add_eos_to_enc:
            offset += 1
        if self.prepend_id is not None:
            offset += 1

        if len(src) > self.max_src_seq_length - offset:
            src = src[: self.max_src_seq_length - offset]

        if self.add_bos_to_enc:
            src = np.concatenate([[self.src_tokenizer.bos_id], src])

        if self.prepend_id is not None:
            src = np.concatenate([[self.prepend_id], src])

        if self.add_eos_to_enc:
            src = np.concatenate([src, [self.src_tokenizer.eos_id]])

        if len(tgt) > self.max_tgt_seq_length - 2:
            tgt = tgt[: self.max_tgt_seq_length - 2]

        text_dec = np.concatenate([[self.tgt_tokenizer.bos_id], tgt])
        labels = np.concatenate([tgt, [self.tgt_tokenizer.eos_id]])

        return {'text_enc': src, 'text_dec': text_dec, 'labels': labels}

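    # Illustrative sketch of the offset arithmetic above: with
    # add_bos_to_enc=True, add_eos_to_enc=True, prepend_id=None the offset is 2,
    # so a raw sample is first cut to at most max_src_seq_length - 2 tokens, and
    # the final encoder input [bos_id] + src + [eos_id] is at most exactly
    # max_src_seq_length long. The target reserves 2 positions for BOS/EOS for
    # the same reason.
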
    def _build_samples_mapping(self):
        if self.max_num_samples is not None:
            if self.max_src_seq_length != self.max_tgt_seq_length:
                raise ValueError(
                    f"max_src_seq_length ({self.max_src_seq_length}) != max_tgt_seq_length ({self.max_tgt_seq_length}). This is needed for max_samples based training for now."
                )

            self.samples_mapping = get_samples_mapping(
                indexed_dataset=self.src_indexed_dataset,
                data_prefix=self.src_file_name,
                num_epochs=None,
                max_num_samples=self.max_num_samples,
                max_seq_length=self.max_src_seq_length - 2,
                short_seq_prob=0,
                seed=self.seed,
                name=self.src_file_name.split('/')[-1],
                binary_head=False,
            )
        else:
            self.samples_mapping = None

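    # Illustrative sketch of oversampling via the samples mapping: for a corpus
    # of 1,000 aligned pairs with max_num_samples=3000, get_samples_mapping
    # covers roughly three shuffled passes over the data (seeded by `seed`),
    # __len__ returns 3000, and _get_sample translates each mapped index back
    # to a corpus row. With max_num_samples=None the data is traversed exactly
    # once and no mapping is built.

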
class TextMemmapSequenceToSequenceDataset(IndexedSequenceToSequenceDataset):
    """Memory-mapped text sequence to sequence dataset. Operates on raw text files and tokenizes the text on-the-fly."""

    def __init__(
        self,
        src_file_name: str,
        tgt_file_name: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        seed: int = 1234,
        max_num_samples: int = None,
        add_bos_to_enc: bool = True,
        add_eos_to_enc: bool = True,
        prepend_id: int = None,
    ):
        """
        src_file_name: Path to a single source file on disk. The file should contain one sentence per line and be raw text.
        tgt_file_name: Path to a single target file on disk. The file should contain one sentence per line aligned with src_file_name and be raw text.
        src_tokenizer: Tokenizer for the source dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        tgt_tokenizer: Tokenizer for the target dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        max_src_seq_length: Maximum length of the source sequences. Lines above this length will be truncated.
        max_tgt_seq_length: Maximum length of the target sequences. Lines above this length will be truncated.
        seed: Random seed for data shuffling.
        max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded.
        add_bos_to_enc: Add BOS token to the encoder input.
        add_eos_to_enc: Add EOS token to the encoder input.
        prepend_id: If not None, prepend this id to the encoder input.
        """
        # seed and max_num_samples must be set before calling super().__init__()
        # because it triggers _get_examples() -> _build_samples_mapping(), which
        # reads both attributes.
        self.seed = seed
        self.max_num_samples = max_num_samples
        super().__init__(
            src_file_name=src_file_name,
            tgt_file_name=tgt_file_name,
            src_tokenizer=src_tokenizer,
            tgt_tokenizer=tgt_tokenizer,
            max_src_seq_length=max_src_seq_length,
            max_tgt_seq_length=max_tgt_seq_length,
            seed=seed,
            max_num_samples=max_num_samples,
            add_bos_to_enc=add_bos_to_enc,
            add_eos_to_enc=add_eos_to_enc,
            prepend_id=prepend_id,
        )

    def _get_examples(self):
        self.src_indexed_dataset = TextMemMapDataset(
            dataset_paths=[self.src_file_name], tokenizer=self.src_tokenizer, header_lines=0
        )
        self.tgt_indexed_dataset = TextMemMapDataset(
            dataset_paths=[self.tgt_file_name], tokenizer=self.tgt_tokenizer, header_lines=0
        )

        assert len(self.src_indexed_dataset) == len(
            self.tgt_indexed_dataset
        ), "src and tgt have different numbers of lines"
        self._build_samples_mapping()

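    # Minimal usage sketch (hypothetical file paths; `tokenizer` is any
    # TokenizerSpec implementation):
    #
    #   ds = TextMemmapSequenceToSequenceDataset(
    #       src_file_name='train.src',
    #       tgt_file_name='train.tgt',
    #       src_tokenizer=tokenizer,
    #       tgt_tokenizer=tokenizer,
    #       max_src_seq_length=512,
    #       max_tgt_seq_length=512,
    #       max_num_samples=1_000_000,  # oversample if the corpus is smaller
    #   )
    #
    # Unlike SequenceToSequenceDataset, lines are memory-mapped and tokenized
    # lazily in __getitem__, so startup cost and host memory stay low.

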
class BinarizedMemmapSequenceToSequenceDataset(IndexedSequenceToSequenceDataset):
    """Memory-mapped sequence to sequence dataset. Operates on pre-tokenized, binarized data files."""

    def __init__(
        self,
        src_dataset_prefix: str,
        tgt_dataset_prefix: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        seed: int = 1234,
        max_num_samples: int = None,
        add_bos_to_enc: bool = True,
        add_eos_to_enc: bool = True,
        prepend_id: int = None,
    ):
        """
        src_dataset_prefix: Path to the *prefix* of a single source bin/idx file on disk. This requires that src_dataset_prefix.bin and src_dataset_prefix.idx exist.
        tgt_dataset_prefix: Path to the *prefix* of a single target bin/idx file on disk, aligned with the source. This requires that tgt_dataset_prefix.bin and tgt_dataset_prefix.idx exist.
        src_tokenizer: Tokenizer for the source dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        tgt_tokenizer: Tokenizer for the target dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        max_src_seq_length: Maximum length of the source sequences. Lines above this length will be truncated.
        max_tgt_seq_length: Maximum length of the target sequences. Lines above this length will be truncated.
        seed: Random seed for data shuffling.
        max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded.
        add_bos_to_enc: Add BOS token to the encoder input.
        add_eos_to_enc: Add EOS token to the encoder input.
        prepend_id: If not None, prepend this id to the encoder input.
        """
        # The dataset prefixes, seed, and max_num_samples must be set before
        # calling super().__init__() because it triggers _check_files_exist()
        # and _get_examples(), which read these attributes.
        self.src_dataset_prefix = src_dataset_prefix
        self.tgt_dataset_prefix = tgt_dataset_prefix
        self.seed = seed
        self.max_num_samples = max_num_samples
        super().__init__(
            src_file_name=src_dataset_prefix,
            tgt_file_name=tgt_dataset_prefix,
            src_tokenizer=src_tokenizer,
            tgt_tokenizer=tgt_tokenizer,
            max_src_seq_length=max_src_seq_length,
            max_tgt_seq_length=max_tgt_seq_length,
            seed=seed,
            max_num_samples=max_num_samples,
            add_bos_to_enc=add_bos_to_enc,
            add_eos_to_enc=add_eos_to_enc,
            prepend_id=prepend_id,
        )

    def _check_files_exist(self):
        if not os.path.exists(self.src_dataset_prefix + ".bin") or not os.path.exists(
            self.src_dataset_prefix + ".idx"
        ):
            raise FileNotFoundError(f"{self.src_dataset_prefix}.bin or {self.src_dataset_prefix}.idx not found")
        if not os.path.exists(self.tgt_dataset_prefix + ".bin") or not os.path.exists(
            self.tgt_dataset_prefix + ".idx"
        ):
            raise FileNotFoundError(f"{self.tgt_dataset_prefix}.bin or {self.tgt_dataset_prefix}.idx not found")

    def _get_examples(self):
        self.src_indexed_dataset = self._get_indexed_dataset(
            self.src_dataset_prefix, data_impl='mmap', skip_warmup=True
        )
        self.tgt_indexed_dataset = self._get_indexed_dataset(
            self.tgt_dataset_prefix, data_impl='mmap', skip_warmup=True
        )
        assert len(self.src_indexed_dataset) == len(
            self.tgt_indexed_dataset
        ), "src and tgt have different numbers of samples"
        self._build_samples_mapping()

    def _get_indexed_dataset(self, data_prefix, data_impl, skip_warmup):
        indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
        return indexed_dataset
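
    # Minimal usage sketch for the binarized variant (hypothetical prefixes;
    # the .bin/.idx pairs are produced by a Megatron-style preprocessing step):
    #
    #   ds = BinarizedMemmapSequenceToSequenceDataset(
    #       src_dataset_prefix='/data/train_src',  # expects /data/train_src.bin/.idx
    #       tgt_dataset_prefix='/data/train_tgt',  # expects /data/train_tgt.bin/.idx
    #       src_tokenizer=tokenizer,
    #       tgt_tokenizer=tokenizer,
    #       max_src_seq_length=512,
    #       max_tgt_seq_length=512,
    #   )
    #
    # Samples come back as numpy id arrays; collate_fn converts them to padded
    # torch.LongTensor batches.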