File size: 17,720 Bytes
7934b29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 |
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import torch
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
get_indexed_dataset_,
get_samples_mapping,
)
from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import TextMemMapDataset
from nemo.core.classes import Dataset
from nemo.utils import logging
# Public dataset classes exported by ``from <module> import *``.
# IndexedSequenceToSequenceDataset is intentionally omitted: it is an abstract base.
__all__ = [
    'SequenceToSequenceDataset',
    'TextMemmapSequenceToSequenceDataset',
    'BinarizedMemmapSequenceToSequenceDataset',
]
class SequenceToSequenceDataset(Dataset):
    """Sequence to Sequence Dataset in memory.

    Reads two line-aligned text files, tokenizes every line when the dataset is
    constructed and keeps the token ids in ``self.examples``. ``__init__`` calls the
    ``_check_files_exist`` and ``_get_examples`` hooks, which subclasses override to
    change how data is located and loaded.
    """

    def __init__(
        self,
        src_file_name: str,
        tgt_file_name: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        add_bos_to_input: bool = True,
        add_eos_to_input: bool = True,
        replace_bos_with_pad: bool = False,
    ):
        """
        src_file_name: Path to the source text file, one example per line.
        tgt_file_name: Path to the target text file, line-aligned with src_file_name.
        src_tokenizer: Tokenizer for the source side (instance of a TokenizerSpec subclass).
        tgt_tokenizer: Tokenizer for the target side (instance of a TokenizerSpec subclass).
        max_src_seq_length: Maximum source length including special tokens; longer examples are truncated from the front. Must be > 0.
        max_tgt_seq_length: Maximum target length including special tokens; longer examples are truncated from the front. Must be > 0.
        add_bos_to_input: Prepend a BOS id to the source.
        add_eos_to_input: Append an EOS id to the source.
        replace_bos_with_pad: Wherever a BOS id would be inserted (source and target), insert the PAD id instead.
        """
        super().__init__()
        self.src_file_name = src_file_name
        self.tgt_file_name = tgt_file_name
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_src_seq_length = max_src_seq_length
        self.max_tgt_seq_length = max_tgt_seq_length
        self.add_bos_to_input = add_bos_to_input
        self.add_eos_to_input = add_eos_to_input
        self.replace_bos_with_pad = replace_bos_with_pad
        assert self.max_src_seq_length > 0
        assert self.max_tgt_seq_length > 0
        self._check_files_exist()
        self._get_examples()

    def _check_files_exist(self):
        """Raise FileNotFoundError if either the source or the target file is missing."""
        if not os.path.exists(self.src_file_name):
            raise FileNotFoundError(f"Source file {self.src_file_name} not found")
        if not os.path.exists(self.tgt_file_name):
            raise FileNotFoundError(f"Target file {self.tgt_file_name} not found")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return one example split for teacher forcing.

        ``text_enc`` is the full source, ``text_dec`` is the target without its last
        token and ``labels`` is the target without its first token (shifted by one).
        """
        example = self.examples[idx]
        text_enc = example['src']
        text_dec = example['tgt'][:-1]
        labels = example['tgt'][1:]
        return {'text_enc': text_enc, 'text_dec': text_dec, 'labels': labels}

    def _get_examples(self):
        """Tokenize both files line by line into ``self.examples``.

        Each entry is ``{'src': [ids], 'tgt': [ids]}``. The target always gets a leading
        BOS (or PAD) and a trailing EOS; the source gets them according to
        ``add_bos_to_input`` / ``add_eos_to_input``. Lines are paired with ``zip``, so
        extra lines in the longer file are ignored.
        """
        self.examples = []
        with open(self.src_file_name, encoding='utf8') as f_src, open(self.tgt_file_name, encoding='utf8') as f_tgt:
            for i, (src, tgt) in enumerate(zip(f_src, f_tgt)):
                if i % 10000 == 0 and i != 0:
                    logging.info(f"Read {i} lines from {self.src_file_name} & {self.tgt_file_name}")
                src = self.src_tokenizer.text_to_ids(src.strip())
                if self.add_bos_to_input:
                    src = [self.src_tokenizer.pad_id if self.replace_bos_with_pad else self.src_tokenizer.bos_id] + src
                if self.add_eos_to_input:
                    src = src + [self.src_tokenizer.eos_id]
                tgt = (
                    [self.tgt_tokenizer.pad_id if self.replace_bos_with_pad else self.tgt_tokenizer.bos_id]
                    + self.tgt_tokenizer.text_to_ids(tgt.strip())
                    + [self.tgt_tokenizer.eos_id]
                )
                # Truncate to max sequence length by keeping the last (max_len - 1) ids.
                # The start index is computed explicitly: the previous ``x[-max_len + 1:]``
                # became ``x[0:]`` (no truncation at all) when max_len == 1.
                if len(src) > self.max_src_seq_length:
                    src = src[len(src) - self.max_src_seq_length + 1 :]
                if len(tgt) > self.max_tgt_seq_length:
                    tgt = tgt[len(tgt) - self.max_tgt_seq_length + 1 :]
                self.examples.append({'src': src, 'tgt': tgt})
        logging.info(f'Dataset Length : {len(self.examples)}')

    @staticmethod
    def _as_list(ids):
        """Convert a numpy array of ids to a Python list; return anything else unchanged."""
        return ids.tolist() if isinstance(ids, np.ndarray) else ids

    def collate_fn(self, batch):
        """Pad a list of ``__getitem__`` outputs into a batch of LongTensors.

        Encoder inputs are right-padded with ``src_tokenizer.pad_id``; decoder inputs
        and labels with ``tgt_tokenizer.pad_id``. ``loss_mask`` is 1 on real label
        positions and 0 on padding. ``enc_mask`` / ``dec_mask`` are 1 wherever the
        token differs from the pad id.

        NOTE: the masks compare against pad_id, so a PAD inserted in place of BOS
        (``replace_bos_with_pad``) is also masked out as padding.
        """
        # Convert per element (not based on the first item) so empty and mixed batches work.
        text_enc = [self._as_list(item['text_enc']) for item in batch]
        text_dec = [self._as_list(item['text_dec']) for item in batch]
        labels = [self._as_list(item['labels']) for item in batch]

        max_dec_input_length = max((len(item) for item in text_dec), default=0)
        max_enc_input_length = max((len(item) for item in text_enc), default=0)
        max_label_length = max((len(item) for item in labels), default=0)

        loss_mask = [([1] * len(item)) + ([0] * (max_label_length - len(item))) for item in labels]
        text_enc = [item + [self.src_tokenizer.pad_id] * (max_enc_input_length - len(item)) for item in text_enc]
        text_dec = [item + [self.tgt_tokenizer.pad_id] * (max_dec_input_length - len(item)) for item in text_dec]
        labels = [item + [self.tgt_tokenizer.pad_id] * (max_label_length - len(item)) for item in labels]

        text_enc = torch.LongTensor(text_enc)
        text_dec = torch.LongTensor(text_dec)
        labels = torch.LongTensor(labels)
        loss_mask = torch.LongTensor(loss_mask)

        enc_mask = (text_enc != self.src_tokenizer.pad_id).long()
        dec_mask = (text_dec != self.tgt_tokenizer.pad_id).long()

        return {
            'text_enc': text_enc,
            'text_dec': text_dec,
            'labels': labels,
            'loss_mask': loss_mask,
            'enc_mask': enc_mask,
            'dec_mask': dec_mask,
        }
class IndexedSequenceToSequenceDataset(SequenceToSequenceDataset):
    """Abstract class for TextMemmapSequenceToSequenceDataset and BinarizedMemmapSequenceToSequenceDataset.
    This class is not meant to be used standalone and just as an abstract class for the two subclasses.

    Subclasses implement ``_get_examples`` so that it sets ``self.src_indexed_dataset`` and
    ``self.tgt_indexed_dataset`` (index-aligned) and ``self.samples_mapping``
    (typically via ``_build_samples_mapping``).
    """

    def __init__(
        self,
        src_file_name: str,
        tgt_file_name: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        seed: int = 1234,
        add_bos_to_enc: bool = True,
        add_eos_to_enc: bool = True,
        max_num_samples: int = None,
        prepend_id: int = None,
    ):
        """
        src_file_name: Path to a single source file on disk. This is either the path to a raw text file or the prefix to the processed src_file_name.bin/idx files.
        tgt_file_name: Path to a single target file on disk. This is either the path to a raw text file or the prefix to the processed tgt_file_name.bin/idx files.
        src_tokenizer: Tokenizer for the source dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        tgt_tokenizer: Tokenizer for the target dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        max_src_seq_length: Maximum length of the source sequences. Lines above this length will be truncated.
        max_tgt_seq_length: Maximum length of the target sequences. Lines above this length will be truncated.
        seed: Random seed for data shuffling.
        add_bos_to_enc: Add BOS token to the encoder input.
        add_eos_to_enc: Add EOS token to the encoder input.
        max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded.
        prepend_id: If not None, prepend this id to the encoder input.
        """
        # ``super().__init__`` runs ``_get_examples``, which in turn calls
        # ``_build_samples_mapping``; that reads ``self.seed`` / ``self.max_num_samples``,
        # so every attribute of this class is set before the parent constructor runs.
        self.seed = seed
        self.max_num_samples = max_num_samples
        self.add_bos_to_enc = add_bos_to_enc
        self.add_eos_to_enc = add_eos_to_enc
        self.prepend_id = prepend_id
        super().__init__(
            src_file_name=src_file_name,
            tgt_file_name=tgt_file_name,
            src_tokenizer=src_tokenizer,
            tgt_tokenizer=tgt_tokenizer,
            max_src_seq_length=max_src_seq_length,
            max_tgt_seq_length=max_tgt_seq_length,
        )
        logging.info(f'Desired number of samples : {self.max_num_samples}')
        logging.info(f'Source Dataset Length : {len(self.src_indexed_dataset)}')
        logging.info(f'Target Dataset Length : {len(self.tgt_indexed_dataset)}')

    def __len__(self):
        """Return ``max_num_samples`` when set (may oversample), otherwise the source dataset length."""
        if self.max_num_samples is None:
            return len(self.src_indexed_dataset)
        return self.max_num_samples

    def _get_sample(self, idx):
        """Return the raw ``(src, tgt)`` token sequences for dataset index ``idx``.

        When a samples mapping exists, ``idx`` is first translated through it into an
        index of the underlying indexed datasets.
        """
        # numpy integer scalars (any width/signedness) are unwrapped to plain ints.
        if isinstance(idx, np.integer):
            idx = idx.item()
        if self.samples_mapping is not None:
            assert idx < len(self.samples_mapping)
            idx, _, _ = self.samples_mapping[idx]
            if isinstance(idx, np.integer):
                idx = idx.item()
        assert idx < len(self.src_indexed_dataset)
        src = self.src_indexed_dataset[idx]
        tgt = self.tgt_indexed_dataset[idx]
        return src, tgt

    def __getitem__(self, idx):
        """Build encoder input, decoder input and labels for sample ``idx``.

        Encoder input layout: ``[prepend_id] [bos] src... [eos]`` (each optional),
        truncated so the total length does not exceed ``max_src_seq_length``.
        Decoder input is ``[bos] + tgt`` and labels are ``tgt + [eos]``, with ``tgt``
        truncated to ``max_tgt_seq_length - 2`` ids.
        """
        src, tgt = self._get_sample(idx)

        # Number of special tokens that will be added around the source ids.
        offset = 0
        if self.add_bos_to_enc:
            offset += 1
        if self.add_eos_to_enc:
            offset += 1
        if self.prepend_id is not None:
            offset += 1

        # Clamp at 0: a negative bound would slice from the end instead of truncating.
        max_src_tokens = max(self.max_src_seq_length - offset, 0)
        if len(src) > max_src_tokens:
            src = src[:max_src_tokens]
        if self.add_bos_to_enc:
            src = np.concatenate([[self.src_tokenizer.bos_id], src])
        if self.prepend_id is not None:
            src = np.concatenate([[self.prepend_id], src])
        if self.add_eos_to_enc:
            src = np.concatenate([src, [self.src_tokenizer.eos_id]])

        max_tgt_tokens = max(self.max_tgt_seq_length - 2, 0)
        if len(tgt) > max_tgt_tokens:
            tgt = tgt[:max_tgt_tokens]
        text_dec = np.concatenate([[self.tgt_tokenizer.bos_id], tgt])
        labels = np.concatenate([tgt, [self.tgt_tokenizer.eos_id]])
        return {'text_enc': src, 'text_dec': text_dec, 'labels': labels}

    def _build_samples_mapping(self):
        """Set ``self.samples_mapping`` via ``get_samples_mapping`` when ``max_num_samples`` is set, else to None.

        Raises:
            ValueError: if ``max_num_samples`` is set but the source and target max lengths differ.
        """
        if self.max_num_samples is not None:
            # This means max src and max tgt sequence length need to be the same
            if self.max_src_seq_length != self.max_tgt_seq_length:
                raise ValueError(
                    f"max_src_seq_length ({self.max_src_seq_length}) != max_tgt_seq_length ({self.max_tgt_seq_length}). This is needed for max_samples based training for now."
                )
            self.samples_mapping = get_samples_mapping(
                indexed_dataset=self.src_indexed_dataset,
                data_prefix=self.src_file_name,
                num_epochs=None,
                max_num_samples=self.max_num_samples,
                # NOTE(review): presumably "- 2" reserves room for BOS/EOS — confirm.
                max_seq_length=self.max_src_seq_length - 2,
                short_seq_prob=0,
                seed=self.seed,
                name=self.src_file_name.split('/')[-1],
                binary_head=False,
            )
        else:
            self.samples_mapping = None
class TextMemmapSequenceToSequenceDataset(IndexedSequenceToSequenceDataset):
    """Memory-mapped text sequence to sequence dataset. Operates on raw text files and tokenizes the text on-the-fly."""

    def __init__(
        self,
        src_file_name: str,
        tgt_file_name: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        seed: int = 1234,
        max_num_samples: int = None,
        add_bos_to_enc: bool = True,
        add_eos_to_enc: bool = True,
        prepend_id: int = None,
    ):
        """
        src_file_name: Raw text file with one source sentence per line.
        tgt_file_name: Raw text file with one target sentence per line, aligned line-by-line with src_file_name.
        src_tokenizer: TokenizerSpec subclass instance (ex: YTTM, SentencePiece) used for the source side.
        tgt_tokenizer: TokenizerSpec subclass instance (ex: YTTM, SentencePiece) used for the target side.
        max_src_seq_length: Source sequences longer than this are truncated.
        max_tgt_seq_length: Target sequences longer than this are truncated.
        seed: Seed used when shuffling / building the samples mapping.
        max_num_samples: Number of samples to expose; may exceed the dataset length to oversample. None means use every line once.
        add_bos_to_enc: Whether the encoder input starts with a BOS token.
        add_eos_to_enc: Whether the encoder input ends with an EOS token.
        prepend_id: Optional id placed at the very start of the encoder input.
        """
        # Needed before the parent constructor, which builds the samples mapping.
        self.seed, self.max_num_samples = seed, max_num_samples
        super().__init__(
            src_file_name=src_file_name,
            tgt_file_name=tgt_file_name,
            src_tokenizer=src_tokenizer,
            tgt_tokenizer=tgt_tokenizer,
            max_src_seq_length=max_src_seq_length,
            max_tgt_seq_length=max_tgt_seq_length,
            seed=seed,
            add_bos_to_enc=add_bos_to_enc,
            add_eos_to_enc=add_eos_to_enc,
            max_num_samples=max_num_samples,
            prepend_id=prepend_id,
        )

    def _get_examples(self):
        """Wrap the source and target text files in TextMemMapDataset objects, then build the samples mapping."""
        file_and_tokenizer_pairs = (
            (self.src_file_name, self.src_tokenizer),
            (self.tgt_file_name, self.tgt_tokenizer),
        )
        # Source is constructed first, then target, matching the pair order above.
        self.src_indexed_dataset, self.tgt_indexed_dataset = [
            TextMemMapDataset(dataset_paths=[path], tokenizer=tokenizer, header_lines=0)
            for path, tokenizer in file_and_tokenizer_pairs
        ]
        n_src, n_tgt = len(self.src_indexed_dataset), len(self.tgt_indexed_dataset)
        assert n_src == n_tgt, "src and tgt has different number of lines"
        self._build_samples_mapping()
class BinarizedMemmapSequenceToSequenceDataset(IndexedSequenceToSequenceDataset):
    """Memory-mapped sequence to sequence dataset. Operates on pre-tokenized binarized data files."""

    def __init__(
        self,
        src_dataset_prefix: str,
        tgt_dataset_prefix: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        seed: int = 1234,
        max_num_samples: int = None,
        add_bos_to_enc: bool = True,
        add_eos_to_enc: bool = True,
        prepend_id: int = None,
    ):
        """
        src_dataset_prefix: Path to the *prefix* of a single source bin/idx pair on disk. Requires src_dataset_prefix.bin and src_dataset_prefix.idx to exist.
        tgt_dataset_prefix: Path to the *prefix* of a single target bin/idx pair on disk, aligned with the source. Requires tgt_dataset_prefix.bin and tgt_dataset_prefix.idx to exist.
        src_tokenizer: Tokenizer for the source dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        tgt_tokenizer: Tokenizer for the target dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
        max_src_seq_length: Maximum length of the source sequences. Lines above this length will be truncated.
        max_tgt_seq_length: Maximum length of the target sequences. Lines above this length will be truncated.
        seed: Random seed for data shuffling.
        max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded.
        add_bos_to_enc: Add BOS token to the encoder input.
        add_eos_to_enc: Add EOS token to the encoder input.
        prepend_id: If not None, prepend this id to the encoder input.
        """
        # The parent constructor calls _check_files_exist/_get_examples, which read these.
        self.src_dataset_prefix = src_dataset_prefix
        self.tgt_dataset_prefix = tgt_dataset_prefix
        self.seed = seed
        self.max_num_samples = max_num_samples
        super().__init__(
            src_file_name=src_dataset_prefix,
            tgt_file_name=tgt_dataset_prefix,
            src_tokenizer=src_tokenizer,
            tgt_tokenizer=tgt_tokenizer,
            max_src_seq_length=max_src_seq_length,
            max_tgt_seq_length=max_tgt_seq_length,
            seed=seed,
            max_num_samples=max_num_samples,
            add_bos_to_enc=add_bos_to_enc,
            add_eos_to_enc=add_eos_to_enc,
            prepend_id=prepend_id,
        )

    def _check_files_exist(self):
        """Raise FileNotFoundError unless both ``.bin`` and ``.idx`` exist for the source and target prefixes."""
        for prefix in (self.src_dataset_prefix, self.tgt_dataset_prefix):
            if not (os.path.exists(prefix + ".bin") and os.path.exists(prefix + ".idx")):
                raise FileNotFoundError(f"{prefix}.bin or {prefix}.idx not found")

    def _get_examples(self):
        """Open both binarized datasets in ``mmap`` mode and build the samples mapping."""
        self.src_indexed_dataset = self._get_indexed_dataset(
            self.src_dataset_prefix, data_impl='mmap', skip_warmup=True
        )
        self.tgt_indexed_dataset = self._get_indexed_dataset(
            self.tgt_dataset_prefix, data_impl='mmap', skip_warmup=True
        )
        assert len(self.src_indexed_dataset) == len(self.tgt_indexed_dataset), (
            f"src and tgt have different number of examples "
            f"({len(self.src_indexed_dataset)} vs {len(self.tgt_indexed_dataset)})"
        )
        self._build_samples_mapping()

    def _get_indexed_dataset(self, data_prefix, data_impl, skip_warmup):
        """Thin wrapper around ``get_indexed_dataset_`` (kept as a method so subclasses can override loading)."""
        indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
        return indexed_dataset
|