from typing import Optional

import numpy as np
import torch

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import get_samples_mapping
from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import JSONLMemMapDataset
from nemo.core.classes import Dataset


class T0Dataset(Dataset):
    def __init__(
        self,
        file_path: str,
        tokenizer: TokenizerSpec,
        max_src_seq_length: int = 512,
        max_tgt_seq_length: int = 512,
        replace_bos_with_pad: bool = False,
        add_bos_to_input: bool = False,
        add_eos_to_input: bool = False,
        max_num_samples: Optional[int] = None,
        seed: int = 1234,
    ):
| """ |
| src_file_name: Path to a JSONL T0 dataset file. |
| tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece). |
| max_src_seq_length: Maximum length of the source sequences. Lines above this length will be truncated. |
| max_tgt_seq_length: Maximum length of the target sequences. Lines above this length will be truncated. |
| replace_bos_with_pad: Whether the decoder starts with a pad token. This is needed for Google's T5 models that may be converted from HF. |
| add_bos_to_input: Whether to add the bos_id to the input sequence. |
| add_eos_to_input: Whether to add the eos_id to the input sequence. |
| seed: Random seed for data shuffling. |
| max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded. |
| """ |
        self.tokenizer = tokenizer
        self.file_path = file_path
        self.max_src_seq_length = max_src_seq_length
        self.max_tgt_seq_length = max_tgt_seq_length
        self.replace_bos_with_pad = replace_bos_with_pad
        self.add_bos_to_input = add_bos_to_input
        self.add_eos_to_input = add_eos_to_input
        self.max_num_samples = max_num_samples
        self.seed = seed

        self.indexed_dataset = JSONLMemMapDataset(dataset_paths=[file_path], tokenizer=None, header_lines=0)
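
        # Build the (optional) oversampling index; `samples_mapping` stays None
        # when `max_num_samples` is None and __getitem__ indexes the file directly.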
        self._build_samples_mapping()

    def _build_samples_mapping(self):
        if self.max_num_samples is not None:
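            # get_samples_mapping accepts a single max_seq_length, so the source
            # and target limits must agree when oversampling is requested.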
            if self.max_src_seq_length != self.max_tgt_seq_length:
                raise ValueError(
                    f"max_src_seq_length ({self.max_src_seq_length}) != max_tgt_seq_length "
                    f"({self.max_tgt_seq_length}). These must match when using max_num_samples-based training for now."
                )

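            # Megatron-style shuffled index over (possibly repeated) examples;
            # `max_seq_length - 2` mirrors the usual bos/eos reservation.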
            self.samples_mapping = get_samples_mapping(
                indexed_dataset=self.indexed_dataset,
                data_prefix=self.file_path,
                num_epochs=None,
                max_num_samples=self.max_num_samples,
                max_seq_length=self.max_src_seq_length - 2,
                short_seq_prob=0,
                seed=self.seed,
                name=self.file_path.split('/')[-1],
                binary_head=False,
            )
        else:
            self.samples_mapping = None

    def __len__(self):
        if self.max_num_samples is None:
            return len(self.indexed_dataset)
        else:
            return len(self.samples_mapping)

    def __getitem__(self, idx):
        if isinstance(idx, np.int64):
            idx = idx.item()
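
        # With oversampling enabled, map the global sample index back to an
        # index into the underlying JSONL file.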
        if self.samples_mapping is not None:
            assert idx < len(self.samples_mapping)
            idx, _, _ = self.samples_mapping[idx]
            if isinstance(idx, np.uint32):
                idx = idx.item()

        assert idx < len(self.indexed_dataset)
        example = self.indexed_dataset[idx]
        return self._process_example(example)

    def _process_example(self, example):
        """
        Process a single example from the dataset into IDs and other T0-related metadata.
        """
        tokenized_input = self.tokenizer.text_to_ids(example['input'])
        tokenized_output = self.tokenizer.text_to_ids(example['output'])
        offset = 0
        if self.add_bos_to_input:
            offset += 1
        if self.add_eos_to_input:
            offset += 1
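
        # Truncate the source so that the optional bos/eos tokens still fit within
        # max_src_seq_length; the target always reserves two slots for bos and eos.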
        if len(tokenized_input) > self.max_src_seq_length - offset:
            tokenized_input = tokenized_input[: self.max_src_seq_length - offset]

        if len(tokenized_output) > self.max_tgt_seq_length - 2:
            tokenized_output = tokenized_output[: self.max_tgt_seq_length - 2]

        bos_id = self.tokenizer.pad_id if self.replace_bos_with_pad else self.tokenizer.bos_id
        if self.add_bos_to_input:
            tokenized_input = [bos_id] + tokenized_input
        if self.add_eos_to_input:
            tokenized_input = tokenized_input + [self.tokenizer.eos_id]
        target = [bos_id] + tokenized_output + [self.tokenizer.eos_id]
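
        # Teacher forcing: the decoder input is the target without its last
        # token, and the labels are the target without its first token.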
        processed_example = {
            'text_enc': tokenized_input,
            'text_dec': target[:-1],
            'labels': target[1:],
        }

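        # 'chunked_idx' records how the prompted input interleaves original text
        # with template text, as comma-separated "<type>-<start>-<end>" character
        # spans into example['input'], e.g. "original_text-0-47, template-47-58".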
        if 'chunked_idx' in example:
            original = ""
            template = ""
            for item in example['chunked_idx'].split(', '):
                item = item.split('-')
                if item[0] == "original_text":
                    original += example['input'][int(item[1]) : int(item[2])]
                elif item[0] == "template":
                    template += example['input'][int(item[1]) : int(item[2])]
                else:
                    raise ValueError(f"Unknown chunk type: {item[0]}")

            additional_args = {
                'original': self.tokenizer.text_to_ids(original),
                'template': self.tokenizer.text_to_ids(template),
                'prompt': self.tokenizer.text_to_ids(example['prompt']),
            }
            processed_example.update(additional_args)

        if 'choices' in example:
            processed_example['choices'] = [self.tokenizer.text_to_ids(choice) for choice in example['choices']]

        if 'task_name_with_prompt' in example:
            processed_example['task_name_with_prompt'] = self.tokenizer.text_to_ids(
                example['task_name_with_prompt']
            )

        return processed_example

    def _maybe_cast_to_list(self, x):
        if isinstance(x, np.ndarray):
            return [item.tolist() for item in x]
        return x

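    # Pad every sequence in the batch to the longest one with pad_id so the
    # batch can be stacked into a rectangular tensor.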
    def _collate_item(self, item):
        item = self._maybe_cast_to_list(item)
        max_length = max([len(x) for x in item]) if item else 0
        item = [x + [self.tokenizer.pad_id] * (max_length - len(x)) for x in item]
        return item

    def collate_fn(self, batch):
        enc_query = [item['text_enc'] for item in batch]
        dec_input = [item['text_dec'] for item in batch]
        labels = [item['labels'] for item in batch]
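
        # loss_mask marks real label positions with 1 and padding with 0;
        # dec_input and labels have identical per-example lengths, so dec_input's
        # padded width sizes the mask before labels themselves are padded.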
        enc_query = torch.LongTensor(self._collate_item(enc_query))
        dec_input = torch.LongTensor(self._collate_item(dec_input))
        loss_mask = torch.LongTensor(
            [([1] * (len(item))) + ([0] * (dec_input.size(1) - len(item))) for item in labels]
        )
        labels = torch.LongTensor(self._collate_item(labels))

        enc_mask = (enc_query != self.tokenizer.pad_id).long()
        dec_mask = (dec_input != self.tokenizer.pad_id).long()

        processed_example = {
            'text_enc': enc_query,
            'text_dec': dec_input,
            'labels': labels,
            'loss_mask': loss_mask,
            'enc_mask': enc_mask,
            'dec_mask': dec_mask,
        }

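        # Optional T0 metadata: pad and tensorize whichever extra fields are
        # present, checking batch[0] on the assumption that batches are homogeneous.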
        if 'original' in batch[0]:
            original = self._collate_item([item['original'] for item in batch])
            processed_example['original'] = torch.LongTensor(original)

        if 'template' in batch[0]:
            template = self._collate_item([item['template'] for item in batch])
            processed_example['template'] = torch.LongTensor(template)

        if 'prompt' in batch[0]:
            prompt = self._collate_item([item['prompt'] for item in batch])
            processed_example['prompt'] = torch.LongTensor(prompt)

        if 'task_name_with_prompt' in batch[0]:
            task_name_with_prompt = self._collate_item([item['task_name_with_prompt'] for item in batch])
            processed_example['task_name_with_prompt'] = torch.LongTensor(task_name_with_prompt)

        return processed_example