| import os |
| import sys |
| import json |
| import argparse |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), |
| os.path.pardir))) |
|
|
| from megatron.data import indexed_dataset |
|
|
|
|
def main(args):
    """Merge every indexed dataset found in ``args.input`` into one dataset.

    The directory is scanned (non-recursively) for regular files. Files are
    grouped by their name without extension, and each group is treated as one
    dataset prefix. Prefixes are merged in sorted order into a builder created
    for ``args.output_prefix + '.bin'`` and finalized with
    ``args.output_prefix + '.idx'``.

    Args:
        args: Namespace with ``input`` (directory to scan) and
            ``output_prefix`` (output path without suffix).

    Raises:
        AssertionError: If a regular file has no companion file (``.bin`` for
            an ``.idx`` file, ``.idx`` for a file with any other extension).
        ValueError: If ``args.input`` contains no regular files, so there is
            nothing to merge.
    """
    # Collect the unique dataset prefixes present in the input directory.
    prefixes = set()
    for basename in os.listdir(args.input):
        prefix, ext = os.path.splitext(basename)

        # The companion file of an already-validated prefix needs no re-check.
        if prefix in prefixes:
            continue

        # Sub-directories (and anything else that is not a file) are ignored.
        if not os.path.isfile(os.path.join(args.input, basename)):
            continue

        # NOTE: any extension other than '.idx' is treated like '.bin', so a
        # stray file without a matching '.idx' also fails this check.
        ext_pair = '.bin' if ext == '.idx' else '.idx'
        assert os.path.isfile(os.path.join(args.input, prefix) + ext_pair), \
            f'ERROR: {ext_pair} file not provided for {os.path.join(args.input, prefix)}'

        prefixes.add(prefix)

    # Without this guard no builder would ever be created and the final
    # finalize() call would fail with an opaque AttributeError on None.
    if not prefixes:
        raise ValueError(f'ERROR: no dataset files found in {args.input}')

    ordered_prefixes = sorted(prefixes)

    # Open the first dataset only to decide which builder class (and, for the
    # mmap variant, which dtype) the merged output should use.
    dataset = indexed_dataset.make_dataset(
        os.path.join(args.input, ordered_prefixes[0]), 'infer')

    if isinstance(dataset, indexed_dataset.MMapIndexedDataset):
        builder = indexed_dataset.MMapIndexedDatasetBuilder(
            args.output_prefix + '.bin', dtype=dataset._index.dtype)
    else:
        builder = indexed_dataset.IndexedDatasetBuilder(args.output_prefix + '.bin')

    # The probe dataset is no longer needed once the builder exists
    # (presumably this releases its file handles / mapping before merging).
    del dataset

    for prefix in ordered_prefixes:
        builder.merge_file_(os.path.join(args.input, prefix))

    builder.finalize(args.output_prefix + '.idx')
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse the arguments, validate both paths,
    # then run the merge.
    parser = argparse.ArgumentParser()

    group = parser.add_argument_group(title='input data')
    group.add_argument('--input', type=str, required=True,
                       help='Path to directory containing all document files to merge')

    group = parser.add_argument_group(title='output data')
    group.add_argument('--output-prefix', type=str, required=True,
                       help='Path to binary output file without suffix')

    args = parser.parse_args()

    assert os.path.isdir(args.input), \
        f'ERROR: {args.input} is not a directory or does not exist'

    # A bare prefix such as "merged" has an empty dirname, and
    # os.path.isdir('') is False; treat that case as the current directory
    # so that writing next to the working directory is accepted.
    output_dir = os.path.dirname(args.output_prefix) or os.path.curdir
    assert os.path.isdir(output_dir), \
        f'ERROR: {output_dir} is not a directory or does not exist'

    main(args)
|
|
|
|