# NeMo/examples/nlp/token_classification/data/create_punctuation_capitalization_tarred_dataset.py
| # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import argparse | |
| import multiprocessing as mp | |
| from pathlib import Path | |
| from nemo.collections.nlp.data.token_classification.punctuation_capitalization_tarred_dataset import ( | |
| DEFAULT_CAPIT_LABEL_VOCAB_FILE_NAME, | |
| DEFAULT_PUNCT_LABEL_VOCAB_FILE_NAME, | |
| METADATA_CAPIT_LABEL_VOCAB_KEY, | |
| METADATA_PUNCT_LABEL_VOCAB_KEY, | |
| build_label_ids_from_list_of_labels, | |
| check_labels_for_being_unique_before_building_label_ids, | |
| check_tar_file_prefix, | |
| create_tarred_dataset, | |
| ) | |
"""
A tarred dataset allows training on large amounts of data without storing it all in memory simultaneously. In the case
of the punctuation and capitalization model, a tarred dataset is a directory which contains a metadata file, tar files
with batches, and punct_label_vocab.csv and capit_label_vocab.csv files.

The metadata file is a JSON file with 4 fields: 'num_batches', 'tar_files', 'punct_label_vocab_file',
'capit_label_vocab_file'. 'num_batches' (int) is the total number of batches in the tarred dataset. 'tar_files' is a
list of paths to tar files relative to the directory containing the metadata file. 'punct_label_vocab_file' and
'capit_label_vocab_file' are paths to .csv files containing all unique punctuation and capitalization labels. Each
label in these files is written on a separate line. The first labels in both files are equal and serve for padding and
as neutral labels.

Every tar file contains objects written using `webdataset.TarWriter`. Each object is a dictionary with two items:
'__key__' and 'batch.pyd'. '__key__' is the name of a batch and 'batch.pyd' is a pickled dictionary which contains
'input_ids', 'subtokens_mask', 'punct_labels', 'capit_labels'. 'input_ids' is an array containing ids of source tokens,
'subtokens_mask' is a boolean array marking first tokens in words, 'punct_labels' and 'capit_labels' are arrays with
ids of labels. The metadata file should be passed to the constructor of
`nemo.collections.nlp.data.token_classification.PunctuationCapitalizationTarredDataset` and the instance of
the class will handle iteration and constructing masks and token types for the BERT model.

Example of usage:

python create_punctuation_capitalization_tarred_dataset.py \
  --text <PATH/TO/TEXT/FILE> \
  --labels <PATH/TO/LABELS/FILE> \
  --output_dir <PATH/TO/OUTPUT/DIR> \
  --lines_per_dataset_fragment 10000 \
  --tokens_in_batch 8000 \
  --num_batches_per_tarfile 5 \
  --tokenizer_name char \
  --vocab_file <PATH_TO_CHAR_TOKENIZER_VOCABULARY>
"""
def get_args() -> argparse.Namespace:
    """Parse, normalize, and validate command line arguments.

    Returns:
        an ``argparse.Namespace`` whose path-valued attributes have been user-expanded
        (``~`` resolved) and whose special-token / label / tar-prefix parameters have
        been cross-validated. Invalid combinations terminate the script via
        ``parser.error``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=f"A tarred dataset allows to train on large amounts without storing it all into memory "
        f"simultaneously. In case of punctuation and capitalization model, tarred dataset is a directory which "
        f"contains metadata file, tar files with batches, {DEFAULT_PUNCT_LABEL_VOCAB_FILE_NAME} and "
        f"{DEFAULT_CAPIT_LABEL_VOCAB_FILE_NAME} files. A metadata file is a JSON file with 4 fields: 'num_batches', "
        f"'tar_files', '{METADATA_PUNCT_LABEL_VOCAB_KEY}', '{METADATA_CAPIT_LABEL_VOCAB_KEY}'. 'num_batches' (int) is "
        f"a total number of batches in tarred dataset. 'tar_files' is a list of paths to tar files relative "
        f"to directory containing the metadata file. '{METADATA_PUNCT_LABEL_VOCAB_KEY}' and "
        f"'{METADATA_CAPIT_LABEL_VOCAB_KEY}' are paths to .csv files containing all unique punctuation and "
        f"capitalization labels. Each label in these files is written in a separate line. The first labels in both "
        f"files are equal and serve for padding and as neutral labels. Every tar file contains objects written "
        f"using `webdataset.TarWriter`. Each object is a dictionary with two items: '__key__' and 'batch.pyd'. "
        f"'__key__' is a name of a batch and 'batch.pyd' is a pickled dictionary which contains 'input_ids', "
        f"'subtokens_mask', 'punct_labels', 'capit_labels'. 'input_ids' is an array containing ids of source tokens, "
        f"'subtokens_mask' is a boolean array showing first tokens in words, 'punct_labels' and 'capit_labels' are "
        f"arrays with ids of labels. Metadata file should be passed to constructor of "
        "`nemo.collections.nlp.data.token_classification.PunctuationCapitalizationTarredDataset` and the instance of "
        "the class will handle iteration and constructing masks and token types for BERT model.",
    )
    parser.add_argument(
        "--text",
        "-t",
        help="Path to source lowercased text without punctuation. Number of lines in `--text` file has to be equal "
        "to number of lines in `--labels` file.",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "--audio_file",
        type=Path,
        required=False,
        help="Path to source file which contains paths to audio one path per line. "
        "Number of lines in `--audio_file` has to be equal to number of lines in `--labels` file",
    )
    parser.add_argument(
        "--use_audio",
        required=False,
        action="store_true",
        # FIX: help previously said "If set to `True`" although this is a boolean flag with no value.
        help="If this flag is provided, the script creates a lexical audio dataset which can be used with "
        "`PunctuationCapitalizationLexicalAudioModel`.",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        required=False,
        help="Target sample rate of audios. Can be used for downsampling or upsampling.",
    )
    parser.add_argument(
        "--labels",
        "-L",
        type=Path,
        required=True,
        help="Path to file with labels in the format described here "
        "https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/nlp/punctuation_and_capitalization.html#"
        "nemo-data-format . Number of lines in `--labels` file has to be equal to the number of lines in `--text` "
        "file.",
    )
    parser.add_argument(
        "--output_dir",
        "-o",
        type=Path,
        required=True,
        help="Path to directory where .tar files, metadata file, label id files are stored.",
    )
    parser.add_argument(
        "--max_seq_length",
        "-s",
        type=int,
        default=512,
        help="Maximum number of subtokens in an input sequence. A source sequence which contain too many subtokens are "
        "clipped to `--max_seq_length - 2` subtokens and then [CLS] token is prepended to the clipped sequence and "
        "[SEP] token is appended to the clipped sequence. The clipping is performed via removal of subtokens in the "
        "end of a source sequence.",
    )
    parser.add_argument(
        "--tokens_in_batch",
        "-b",
        type=int,
        default=15000,
        help="Maximum number of tokens in a batch including [CLS], [SEP], [UNK], and [PAD] tokens. Before packing into "
        "batches source sequences are sorted by number of tokens in order to reduce number of pad tokens. So the "
        "number of sequences in a batch may be different.",
    )
    parser.add_argument(
        "--lines_per_dataset_fragment",
        type=int,
        default=10 ** 6,
        help="A number of lines processed by one worker during creation of tarred dataset. A worker tokenizes "
        "`--lines_per_dataset_fragment` lines and keeps in RAM tokenized text labels before packing them into "
        "batches. Reducing `--lines_per_dataset_fragment` leads to reducing of the amount of memory required by this "
        "script.",
    )
    parser.add_argument(
        "--num_batches_per_tarfile",
        type=int,
        default=1000,
        help="A number of batches saved in a tar file. If you increase `--num_batches_per_tarfile`, then there will "
        "be less tar files in the dataset. There cannot be less then `--num_batches_per_tarfile` batches in a tar "
        "file, and all excess batches are removed. Maximum number of discarded batches is "
        "`--num_batches_per_tarfile - 1`.",
    )
    parser.add_argument(
        "--tokenizer_name",
        "-T",
        default="bert-base-uncased",
        help="Name of the tokenizer used for tokenization of source sequences. Possible options are 'sentencepiece', "
        "'word', 'char', HuggingFace tokenizers. For more options see function "
        "`nemo.collections.nlp.modules.common.get_tokenizer`. The tokenizer has to have properties `cls_id`, "
        "`pad_id`, `sep_id`, `unk_id`.",
    )
    parser.add_argument(
        "--tokenizer_model", "-m", type=Path, help="Path to tokenizer model required for 'sentencepiece' tokenizer."
    )
    parser.add_argument(
        "--vocab_file",
        "-v",
        type=Path,
        help="Path to vocabulary file which can be used in 'word', 'char', and HuggingFace tokenizers.",
    )
    parser.add_argument(
        "--merges_file", "-M", type=Path, help="Path to merges file which can be used in HuggingFace tokenizers."
    )
    parser.add_argument(
        "--special_token_names",
        "-n",
        nargs="+",
        help="Names of special tokens which may be passed to constructors of 'char', 'word', 'sentencepiece', and "
        "HuggingFace tokenizers.",
    )
    parser.add_argument(
        "--special_token_values",
        "-V",
        nargs="+",
        help="Values of special tokens which may be passed to constructors of 'char', 'word', 'sentencepiece', and "
        "HuggingFace tokenizers.",
    )
    parser.add_argument(
        "--use_fast_tokenizer", "-f", action="store_true", help="Whether to use fast HuggingFace tokenizer."
    )
    parser.add_argument(
        "--pad_label",
        "-P",
        default='O',
        help="Pad label both for punctuation and capitalization. This label is also is used for marking words which "
        "do not need punctuation and capitalization. It is also a neutral label used for marking words which do "
        "not require punctuation and capitalization.",
    )
    punct = parser.add_mutually_exclusive_group(required=False)
    punct.add_argument(
        "--punct_labels",
        "-p",
        nargs="+",
        help="All punctuation labels EXCEPT PAD LABEL. Punctuation labels are strings separated by spaces. "
        "Alternatively you can use parameter `--punct_label_vocab_file`. If none of parameters `--punct_labels` "
        "and `--punct_label_vocab_file` are provided, then punctuation label ids will be inferred from `--labels` "
        "file.",
    )
    punct.add_argument(
        "--punct_label_vocab_file",
        type=Path,
        help="A path to file with punctuation labels. These labels include pad label. Pad label has to be the first "
        "label in the file. Each label is written on separate line. Alternatively you can use `--punct_labels` "
        "parameter. If none of parameters `--punct_labels` and `--punct_label_vocab_file` are provided, then "
        "punctuation label ids will be inferred from `--labels` file.",
    )
    capit = parser.add_mutually_exclusive_group(required=False)
    capit.add_argument(
        "--capit_labels",
        "-c",
        nargs="+",
        help="All capitalization labels EXCEPT PAD LABEL. Capitalization labels are strings separated by spaces. "
        "Alternatively you can use parameter `--capit_label_vocab_file`. If none of parameters `--capit_labels` "
        "and `--capit_label_vocab_file` are provided, then capitalization label ids will be inferred from `--labels` "
        "file.",
    )
    capit.add_argument(
        "--capit_label_vocab_file",
        type=Path,
        help="A path to file with capitalization labels. These labels include pad label. Pad label has to be the "
        "first label in the file. Each label is written on separate line. Alternatively you can use `--capit_labels` "
        "parameter. If none of parameters `--capit_labels` and `--capit_label_vocab_file` are provided, then "
        "capitalization label ids will be inferred from `--labels` file.",
    )
    parser.add_argument(
        "--tar_file_prefix",
        "-x",
        default="punctuation_capitalization",
        help="A string from which tar file names start. It can contain only characters 'A-Z', 'a-z', '0-9', '_', '-', "
        "'.'.",
    )
    parser.add_argument(
        "--n_jobs",
        "-j",
        type=int,
        default=mp.cpu_count(),
        help="Number of workers for creating tarred dataset. By default it is equal to the number of CPU cores.",
    )
    args = parser.parse_args()
    # Expand `~` in every path-valued argument that was actually provided.
    for name in [
        "text",
        "labels",
        "output_dir",
        "tokenizer_model",
        "vocab_file",
        "merges_file",
        "punct_label_vocab_file",
        "capit_label_vocab_file",
    ]:
        if getattr(args, name) is not None:
            setattr(args, name, getattr(args, name).expanduser())
    # Special token names and values must be provided together, in equal number, with unique names.
    if args.special_token_names is not None or args.special_token_values is not None:
        if args.special_token_names is None:
            parser.error(
                "If you provide parameter `--special_token_values` you have to provide parameter "
                "`--special_token_names`."
            )
        if args.special_token_values is None:
            parser.error(
                "If you provide parameter `--special_token_names` you have to provide parameter "
                "`--special_token_values`."
            )
        if len(args.special_token_names) != len(args.special_token_values):
            parser.error(
                f"Parameters `--special_token_names` and `--special_token_values` have to have equal number of values "
                f"whereas parameter `--special_token_names` has {len(args.special_token_names)} values and parameter "
                f"`--special_token_values` has {len(args.special_token_values)} values."
            )
        if len(set(args.special_token_names)) != len(args.special_token_names):
            for i in range(len(args.special_token_names) - 1):
                if args.special_token_names[i] in args.special_token_names[i + 1 :]:
                    parser.error(
                        f"Values of parameter `--special_token_names` has to be unique. Found duplicate value "
                        f"'{args.special_token_names[i]}'."
                    )
    # FIX: validate punctuation and capitalization label lists independently. Previously the
    # capitalization check was tied to `args.punct_labels is not None`, so `--capit_labels` either
    # escaped validation (when `--punct_labels` was absent) or the check received `None` and failed
    # (when `--punct_labels` was given without `--capit_labels`).
    if args.punct_labels is not None:
        check_labels_for_being_unique_before_building_label_ids(
            args.pad_label, args.punct_labels, '--pad_label', '--punct_labels', parser.error
        )
    if args.capit_labels is not None:
        check_labels_for_being_unique_before_building_label_ids(
            args.pad_label, args.capit_labels, '--pad_label', '--capit_labels', parser.error
        )
    check_tar_file_prefix(args.tar_file_prefix, parser.error, '--tar_file_prefix')
    return args
def main() -> None:
    """Entry point: parse CLI arguments, build label-id mappings if label lists were
    given on the command line, and delegate dataset creation to ``create_tarred_dataset``."""
    args = get_args()
    # Pair up special token names with their values; `get_args` guarantees both lists
    # are present and of equal length whenever either one is given.
    special_tokens = (
        None
        if args.special_token_names is None
        else dict(zip(args.special_token_names, args.special_token_values))
    )
    # Label-to-id mappings are built only from explicit CLI label lists; otherwise they
    # are left as None so `create_tarred_dataset` can fall back to vocab files or
    # inference from the `--labels` file.
    punct_label_ids = (
        None
        if args.punct_labels is None
        else build_label_ids_from_list_of_labels(args.pad_label, args.punct_labels)
    )
    capit_label_ids = (
        None
        if args.capit_labels is None
        else build_label_ids_from_list_of_labels(args.pad_label, args.capit_labels)
    )
    create_tarred_dataset(
        args.text,
        args.labels,
        args.output_dir,
        args.max_seq_length,
        args.tokens_in_batch,
        args.lines_per_dataset_fragment,
        args.num_batches_per_tarfile,
        args.tokenizer_name,
        tokenizer_model=args.tokenizer_model,
        vocab_file=args.vocab_file,
        merges_file=args.merges_file,
        special_tokens=special_tokens,
        use_fast_tokenizer=args.use_fast_tokenizer,
        pad_label=args.pad_label,
        punct_label_ids=punct_label_ids,
        capit_label_ids=capit_label_ids,
        punct_label_vocab_file=args.punct_label_vocab_file,
        capit_label_vocab_file=args.capit_label_vocab_file,
        tar_file_prefix=args.tar_file_prefix,
        n_jobs=args.n_jobs,
        audio_file=args.audio_file,
        sample_rate=args.sample_rate,
        use_audio=args.use_audio,
    )


if __name__ == "__main__":
    main()