| |
| |
| |
| |
|
|
| import numpy as np |
| import torch |
| from fairseq.data import FairseqDataset, plasma_utils |
| from fairseq.data.indexed_dataset import best_fitting_int_dtype |
| from typing import Tuple |
|
|
|
|
class TokenBlockDataset(FairseqDataset):
    """Break a Dataset of tokens into blocks.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to break into blocks
        sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
        block_size (int): maximum block size (ignored in 'eos' break mode)
        pad (int): padding symbol index (used to left-pad ``past_target``)
        eos (int): end-of-sentence symbol index (used to left-pad ``source``)
        break_mode (str, optional): Mode used for breaking tokens. Values can
            be one of:
            - 'none': break tokens into equally sized blocks (up to block_size)
            - 'complete': break tokens into blocks (up to block_size) such that
              blocks contains complete sentences, although block_size may be
              exceeded if some sentences exceed block_size
            - 'complete_doc': similar to 'complete' mode, but do not
              cross document boundaries
            - 'eos': each block contains one sentence (block_size is ignored)
        include_targets (bool, optional): return next tokens as targets
            (default: False).
        document_sep_len (int, optional): document separator size (required for
            'complete_doc' break mode). Typically 1 if the sentences have eos
            and 0 otherwise.
        use_plasma_view (bool, optional): store the index arrays in an Arrow
            Plasma store so they are shared across dataloader workers instead
            of copied per process (default: False).
        split_path (str, optional): split identifier used to key the plasma
            objects; required when ``use_plasma_view`` is True.
        plasma_path (str, optional): path to the plasma store socket.
    """

    def __init__(
        self,
        dataset,
        sizes,
        block_size,
        pad,
        eos,
        break_mode=None,
        include_targets=False,
        document_sep_len=1,
        use_plasma_view=False,
        split_path=None,
        plasma_path=None,
    ):
        super().__init__()
        self.dataset = dataset
        self.pad = pad
        self.eos = eos
        self.include_targets = include_targets

        assert len(dataset) > 0
        assert len(dataset) == len(sizes)

        _sizes, block_to_dataset_index, slice_indices = self._build_slice_indices(
            sizes, break_mode, document_sep_len, block_size
        )
        if use_plasma_view:
            # Key the shared plasma objects by every parameter that determines
            # the array contents, so distinct configurations never collide.
            plasma_id = (block_size, document_sep_len, str(break_mode), len(dataset))
            self._slice_indices = plasma_utils.PlasmaView(
                slice_indices, split_path, (plasma_id, 0), plasma_path=plasma_path
            )
            self._sizes = plasma_utils.PlasmaView(
                _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path
            )
            self._block_to_dataset_index = plasma_utils.PlasmaView(
                block_to_dataset_index, split_path, (plasma_id, 2), plasma_path=plasma_path,
            )
        else:
            # PlasmaArray keeps the data in shared memory after workers fork.
            self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
            self._sizes = plasma_utils.PlasmaArray(_sizes)
            self._block_to_dataset_index = plasma_utils.PlasmaArray(
                block_to_dataset_index
            )

    @staticmethod
    def _build_slice_indices(
        sizes, break_mode, document_sep_len, block_size
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Use token_block_utils_fast to build arrays for indexing into self.dataset.

        Returns:
            Tuple of ``(_sizes, block_to_dataset_index, slice_indices)`` where
            ``slice_indices[i] = (start, end)`` are token offsets of block ``i``
            in the flattened token stream, ``_sizes[i]`` is its length, and
            ``block_to_dataset_index[i] = (start_ds_idx, start_offset,
            end_ds_idx)`` locates the block within the underlying dataset.
        """
        try:
            from fairseq.data.token_block_utils_fast import (
                _get_slice_indices_fast,
                _get_block_to_dataset_index_fast,
            )
        except ImportError:
            raise ImportError(
                "Please build Cython components with: `pip install --editable .` "
                "or `python setup.py build_ext --inplace`"
            )

        # The Cython helpers require a contiguous int64 numpy array.
        if isinstance(sizes, list):
            sizes = np.array(sizes, dtype=np.int64)
        else:
            if torch.is_tensor(sizes):
                sizes = sizes.numpy()
            sizes = sizes.astype(np.int64)

        break_mode = break_mode if break_mode is not None else "none"

        # block_size is ignored in 'eos' mode, but the fast helper still
        # expects an integer, so substitute 0 when the caller passed None.
        if break_mode == "eos" and block_size is None:
            block_size = 0

        slice_indices = _get_slice_indices_fast(
            sizes, str(break_mode), block_size, document_sep_len
        )
        _sizes = slice_indices[:, 1] - slice_indices[:, 0]

        # Build the mapping from block index to underlying-dataset indices.
        if break_mode == "eos":
            # In 'eos' mode each block is exactly one sentence, so the mapping
            # is trivial and far cheaper to build directly than via the
            # generic Cython helper.
            block_to_dataset_index = np.stack(
                [
                    np.arange(len(sizes)),  # starting index in dataset
                    # BUGFIX: np.compat.long was removed in NumPy >= 2.0; it
                    # resolved to the platform C long — use np.int64, matching
                    # the dtype of the np.arange columns on supported builds.
                    np.zeros(len(sizes), dtype=np.int64),  # offset within start
                    np.arange(len(sizes)),  # ending index in dataset
                ],
                1,
            )
        else:
            block_to_dataset_index = _get_block_to_dataset_index_fast(
                sizes, slice_indices,
            )

        # Downcast to compact dtypes to reduce the shared-memory footprint.
        # NOTE(review): in 'complete' modes a block may exceed block_size, so
        # uint16 could in principle overflow for block_size < 65535 with a
        # single very long sentence — behavior kept as-is; confirm upstream.
        size_dtype = np.uint16 if block_size < 65535 else np.uint32
        num_tokens = slice_indices[-1].max()
        slice_indices_dtype = best_fitting_int_dtype(num_tokens)
        slice_indices = slice_indices.astype(slice_indices_dtype)
        _sizes = _sizes.astype(size_dtype)
        block_to_dataset_index = block_to_dataset_index.astype(slice_indices_dtype)
        return _sizes, block_to_dataset_index, slice_indices

    @property
    def slice_indices(self):
        # (num_blocks, 2) array of [start, end) token offsets per block.
        return self._slice_indices.array

    @property
    def sizes(self):
        # Length of each block in tokens.
        return self._sizes.array

    @property
    def block_to_dataset_index(self):
        # (num_blocks, 3) array of (start_ds_idx, start_offset, end_ds_idx).
        return self._block_to_dataset_index.array

    def attr(self, attr: str, index: int):
        # Delegate attribute lookup to the first underlying sentence of the
        # block (blocks never cross documents in 'complete_doc' mode).
        start_ds_idx, _, _ = self.block_to_dataset_index[index]
        return self.dataset.attr(attr, start_ds_idx)

    def __getitem__(self, index):
        start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]

        # Concatenate every underlying sentence the block overlaps, then
        # slice out the block itself.
        buffer = torch.cat(
            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
        )
        slice_s, slice_e = self.slice_indices[index]
        length = slice_e - slice_s
        s, e = start_offset, start_offset + length
        item = buffer[s:e]

        if self.include_targets:
            # *target* is the original sentence (item);
            # *source* is shifted right by 1, left-padded with eos at the
            #   stream start;
            # *past_target* is shifted right by 2, left-padded with
            #   [pad, eos] / [eos] as needed.
            if s == 0:
                source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]])
                past_target = torch.cat(
                    [item.new([self.pad, self.eos]), buffer[0 : e - 2]]
                )
            else:
                source = buffer[s - 1 : e - 1]
                if s == 1:
                    past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]])
                else:
                    past_target = buffer[s - 2 : e - 2]

            return source, item, past_target

        return item

    def __len__(self):
        return len(self.slice_indices)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        # Prefetch every underlying dataset item any requested block touches.
        self.dataset.prefetch(
            {
                ds_idx
                for index in indices
                for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
                for ds_idx in range(start_ds_idx, end_ds_idx + 1)
            }
        )
|
|