import datetime
import json
import multiprocessing as mp
import os
import pickle
import time
from functools import partial

import numpy as np
import torch

from nemo.core import Dataset
from nemo.utils import logging

__all__ = ['TextMemMapDataset', 'CSVMemMapDataset', 'JSONLMemMapDataset', 'build_index_files']
__idx_version__ = '0.2'
__idx_suffix__ = 'idx'


def _build_index_from_memdata(fn, newline_int):
    """
    Build index of delimiter positions between samples in a memmap.
    Can be provided externally.

    Returns a 1D array of ints.
    """
    # use memmap to read the file as raw bytes
    mdata = np.memmap(fn, dtype=np.uint8, mode='r')
    # find all delimiter (e.g., newline) positions
    midx = np.where(mdata == newline_int)[0]
    midx_dtype = midx.dtype
    # make sure to account for all data
    midx = midx.tolist()
    # add a final entry in case there is no delimiter at the end of the file
    if (len(midx) == 0) or (midx[-1] + 1 != len(mdata)):
        midx = midx + [len(mdata) + 1]

    # remove empty lines from the end of the file
    while len(midx) > 1 and (midx[-1] - midx[-2]) < 2:
        midx.pop(-1)
    midx = np.asarray(midx, dtype=midx_dtype)

    # release the memmap
    mdata._mmap.close()
    del mdata

    return midx
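

# Illustrative sketch (a toy example, not part of the API): with newline_int=10
# (i.e. b'\n'), the bytes b"ab\ncd\n" yield midx == [2, 5]; sample k then spans
# the half-open range mdata[midx[k - 1] + 1 : midx[k]] (and mdata[0 : midx[0]]
# for the first sample).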


class TextMemMapDataset(Dataset):
    """
    Allow per-line lazy access to multiple text files using numpy memmap.
    """

    def __init__(
        self,
        dataset_paths,
        newline_int=10,
        header_lines=0,
        workers=None,
        tokenizer=None,
        sort_dataset_paths=True,
        build_index_fn=_build_index_from_memdata,
    ):
        """
        build_index_fn - a callable build_index_fn(fn, newline_int) -> midx [np.array]
            that returns the index of newlines in a file fn.
            Must be pickleable (to be used in multiprocessing.Pool.map).
        """
        super().__init__()
        self.mdata_midx_list = []

        # make a single string into a list
        if isinstance(dataset_paths, str):
            dataset_paths = [dataset_paths]

        if len(dataset_paths) < 1:
            raise ValueError("files_list must contain at least one file name")

        self._newline_int = newline_int
        # skip the first N lines of every file (e.g., a CSV header)
        self._header_lines = header_lines
        self._files_list = dataset_paths
        self._worker = workers
        self.tokenizer = tokenizer
        self._sort_dataset_paths = sort_dataset_paths

        if sort_dataset_paths:
            self._files_list = sorted(self._files_list)

        logging.info("Building data files")
        # build index files on rank 0 only, then synchronize
        is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()

        if not is_distributed or (is_distributed and torch.distributed.get_rank() == 0):
            build_index_files(dataset_paths, newline_int, workers=self._worker, build_index_fn=build_index_fn)

        if is_distributed:
            torch.distributed.barrier()

        logging.info("Loading data files")
        start_time = time.time()
        mdata_midx_list = [self.load_file(fn) for fn in self._files_list]
        logging.info(
            f'Time loading {len(mdata_midx_list)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}'
        )

        logging.info("Computing global indices")
        midx_bins = np.cumsum([(len(midx) - header_lines) for _, midx in mdata_midx_list])

        self.midx_bins = midx_bins
        self.mdata_midx_list = mdata_midx_list

        # the dataset size is the last cumulative bin edge
        self._size = self.midx_bins[-1]

    def __del__(self):
        if self.mdata_midx_list:
            for mdata, midx in self.mdata_midx_list:
                mdata._mmap.close()

    def __len__(self):
        return self._size

    def __getitem__(self, idx):
        """
        Return a string from the binary memmap.
        """
        if (idx >= len(self)) or (idx < 0):
            raise IndexError(f"Index {idx} is out of dataset range with {len(self)} samples")

        # identify the file that contains the sample
        file_id = np.digitize(idx, self.midx_bins, right=False)
        base_idx = self.midx_bins[file_id - 1] if file_id > 0 else 0
        file_idx = idx - base_idx + self._header_lines
        mdata, midx = self.mdata_midx_list[file_id]
        # compute the start (i) and end (j) byte positions of the sample
        if file_idx == 0:
            i = 0
            j = midx[0]
        else:
            i = midx[file_idx - 1] + 1
            j = midx[file_idx]
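        # midx stores delimiter positions, so the sample is the half-open byte
        # span mdata[i:j]; the `+ 1` above skips the delimiter that terminates
        # the previous sample.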

        # fetch sample from the memmap
        try:
            sample = self._fetch_sample_from_memmap(mdata, i, j)
        except Exception as e:
            logging.error(f"Error while fetching sample from memmap: {e}")
            logging.error(f"file_id: {file_id}, file_idx: {file_idx}, i: {i}, j: {j}")
            raise e

        # parse raw text (e.g., tokenize)
        try:
            data = self._build_data_from_text(sample)
        except Exception as e:
            logging.error(
                f"Error while building data from text, possible issue with the expected sample format (see offending sample below): {e}"
            )
            logging.error(f"sample: {sample}, file_id: {file_id}, file_idx: {file_idx}, i: {i}, j: {j}")
            raise e

        return data

    def _fetch_sample_from_memmap(self, mdata, i, j):
        """Fetches the text sample. Can be overridden by child classes to support loading of partial samples and alternative decode methods."""
        # fetch the sample bytes from the memmap and decode as utf-8
        text = mdata[i:j].tobytes().decode("utf-8")

        return text

    def _build_data_from_text(self, text):
        """Allows child classes to modify the parsing of raw text, prior to tokenization."""
        # tokenize text if a tokenizer is available
        if self.tokenizer is not None:
            data = self.tokenizer.text_to_ids(text)
        else:
            data = text

        return data

    def load_file(self, fn):
        """
        Loads a text file as np.uint8.

        Returns:
            mdata - memory map of np.uint8
            midx - indices pointing to the end-of-line (or end-of-file) positions
        """
        logging.info(f"Loading {fn}")
        idx_fn = f"{fn}.{__idx_suffix__}"

        # create data map
        mdata = np.memmap(fn, dtype=np.uint8, mode='r')

        if _index_file_exists(idx_fn):
            # load index file into a memory map
            midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode='r')
            # test for header
            if len(midx) < self._header_lines:
                raise RuntimeError(f"Missing header, expected {self._header_lines} header lines")

            # load meta info
            with open(idx_fn + ".info", 'rb') as fp:
                idx_info_dict = pickle.load(fp)
            # test for mismatch in the expected newline_int
            if 'newline_int' in idx_info_dict:
                newline_int = idx_info_dict['newline_int']
                if self._newline_int != newline_int:
                    logging.warning(
                        f"Mismatch in newline_int, expected = {self._newline_int} but loaded {newline_int}"
                    )

            # test for version mismatch (useful to force recreation of index files)
            idx_version = idx_info_dict.get('version', '0.0')
            if __idx_version__ != idx_version:
                raise RuntimeError(
                    f"Version mismatch: Please delete existing '.{__idx_suffix__}' files. Expected version = {__idx_version__}, but file version = {idx_version}. File path = {idx_fn}"
                )
        else:
            raise ValueError(
                f'Memory Map for {fn} is not found, missing one or more of files: {idx_fn}.{{npy,info}}'
            )

        return (mdata, midx)


class CSVMemMapDataset(TextMemMapDataset):
    """
    Allow per-line lazy access to multiple CSV files using numpy memmap.
    """

    def __init__(
        self,
        dataset_paths,
        newline_int=10,
        header_lines=1,
        workers=None,
        tokenizer=None,
        sort_dataset_paths=True,
        data_col=1,
        data_sep=',',
    ):
        super().__init__(
            dataset_paths=dataset_paths,
            newline_int=newline_int,
            header_lines=header_lines,
            workers=workers,
            tokenizer=tokenizer,
            sort_dataset_paths=sort_dataset_paths,
        )
        self._data_col = data_col
        self._data_sep = data_sep

    def _build_data_from_text(self, text):
        """Return a single CSV field from text."""
        # extract the requested column, then defer to the parent (tokenization)
        text = text.split(self._data_sep)[self._data_col]
        return super()._build_data_from_text(text)
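

# Usage sketch (hedged; 'data.csv' is a hypothetical path). Fields are split
# with a plain str.split, so quoted separators are not handled:
#   ds = CSVMemMapDataset(dataset_paths=['data.csv'], data_col=1, data_sep=',')
#   ds[0]  # second column of the first non-header row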


class JSONLMemMapDataset(TextMemMapDataset):
    """
    Memory-mapped iteration over a JSONL file.
    """

    def __init__(
        self, dataset_paths, newline_int=10, header_lines=0, workers=None, tokenizer=None, sort_dataset_paths=True,
    ):
        # JSONL files have no header row, so header_lines defaults to 0
        super().__init__(
            dataset_paths=dataset_paths,
            newline_int=newline_int,
            header_lines=header_lines,
            workers=workers,
            tokenizer=tokenizer,
            sort_dataset_paths=sort_dataset_paths,
        )

    def _build_data_from_text(self, text):
        """Return a dictionary of data based on a single JSON line."""
        return json.loads(text)
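

# Usage sketch (hedged; 'data.jsonl' is a hypothetical path):
#   ds = JSONLMemMapDataset(dataset_paths=['data.jsonl'])
#   ds[0]  # dict parsed from the first JSON line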


def _index_file_exists(idx_fn):
    """Helper function to test if an index file exists."""
    return os.path.exists(idx_fn + ".npy") and os.path.exists(idx_fn + ".info")


def _build_memmap_index_files(newline_int, build_index_fn, fn):
    """Helper function to build an index file for a single data file."""
    idx_fn = f"{fn}.{__idx_suffix__}"

    # skip files that already have an index
    if _index_file_exists(idx_fn):
        return False
    else:
        logging.info(f"Building index for fn = {fn}")
        # find all delimiter positions
        midx = build_index_fn(fn, newline_int)
        # validate the returned index
        midx = np.asarray(midx)
        if not np.issubdtype(midx.dtype, np.integer):
            raise TypeError(f"midx must be an integer array, but got type = {midx.dtype}")

        # metadata stored alongside the index
        data = dict(newline_int=newline_int, version=__idx_version__)

        # save index and metadata files
        logging.info(f"Saving idx file = {idx_fn}.npy")
        np.save(idx_fn + ".npy", midx, allow_pickle=True)
        logging.info(f"Saving metadata file = {idx_fn}.info")
        with open(idx_fn + ".info", "wb") as fp:
            pickle.dump(data, fp)

        return True


def build_index_files(dataset_paths, newline_int, workers=None, build_index_fn=_build_index_from_memdata):
    """Auxiliary method to build multiple index files."""
    if len(dataset_paths) < 1:
        raise ValueError("files_list must contain at least one file name")

    if workers is None:
        workers = max(1, os.cpu_count() // 2)

    logging.info(f"Processing {len(dataset_paths)} data files using {workers} workers")
    # build index files in parallel
    start_time = time.time()
    with mp.Pool(workers) as p:
        build_status = p.map(partial(_build_memmap_index_files, newline_int, build_index_fn), dataset_paths)

    logging.info(
        f'Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}'
    )
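

# Usage sketch (hedged; 'train.txt' is a hypothetical path). Index files can be
# pre-built once (e.g. before launching training); otherwise
# TextMemMapDataset.__init__ builds them lazily on rank 0:
#   build_index_files(['train.txt'], newline_int=10, workers=4)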