| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| This script takes a `bpe.model` and a text file such as |
| ./download/lm/librispeech-lm-norm.txt |
| and writes the LM training data to the archive given by --lm-archive, |
| e.g. data/bpe_500/lm_data.pt. The format is as follows: |
| |
| It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a |
| representation of a dict with the following format: |
| |
| 'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32 |
| containing the BPE representations of each word, indexed by |
| integer word ID. (These integer word IDs are present in |
| 'sentences'). The sentencepiece object can be used to turn the |
| words and BPE units into string form. |
| 'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype |
| torch.int32 containing all the sentences, as word-ids (we don't |
| output the string form of this directly but it can be worked out |
| together with 'words' and the bpe.model). |
| 'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing |
| the number of BPE tokens of each sentence. |
| """ |
|
|
| import argparse |
| import logging |
| from pathlib import Path |
|
|
| import k2 |
| import sentencepiece as spm |
| import torch |
|
|
|
|
def get_args():
    """Parse the command-line options of this script.

    Returns:
      The argparse namespace with the attributes ``bpe_model``, ``lm_data``
      and ``lm_archive``. Each is a string, or ``None`` when the option was
      not given on the command line.
    """
    # (flag, help text) for every option; all of them are plain strings.
    option_help = (
        (
            "--bpe-model",
            "Input BPE model, e.g. data/bpe_500/bpe.model",
        ),
        (
            "--lm-data",
            """Input LM training data as text, e.g.
        download/pb.train.txt""",
        ),
        (
            "--lm-archive",
            """Path to output archive, e.g. data/bpe_500/lm_data.pt;
        look at the source of this script to see the format.""",
        ),
    )

    parser = argparse.ArgumentParser()
    for flag, text in option_help:
        parser.add_argument(flag, type=str, help=text)

    return parser.parse_args()
|
|
|
|
def main():
    """Convert an LM text corpus into a PyTorch archive of word and BPE ids.

    Options are read through ``get_args()``. The function then:

    1. returns early, with a warning, if ``--lm-archive`` already exists;
    2. loads the sentencepiece model given by ``--bpe-model``;
    3. reads ``--lm-data`` line by line, giving every distinct
       whitespace-separated word an integer id in order of first appearance
       and BPE-encoding each distinct word exactly once;
    4. saves a dict with the keys ``words``, ``sentences`` and
       ``sentence_lengths`` to ``--lm-archive`` using ``torch.save``.
    """
    args = get_args()

    # Never overwrite an existing archive; rerunning the script is a no-op.
    if Path(args.lm_archive).exists():
        logging.warning(f"{args.lm_archive} exists - skipping")
        return

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    # word2index maps a word to its integer id. Ids are handed out in order
    # of first appearance, so word2bpe[id] is the encoding of that word.
    word2index = dict()

    word2bpe = []  # one sp.encode() result per distinct word
    sentences = []  # one list of word ids per input line

    # Progress is only reported for inputs whose line count is hard-coded
    # here; the input is recognised by a substring of its path.
    if "librispeech-lm-norm" in args.lm_data:
        num_lines_in_total = 40418261.0
        step = 5000000
    elif "valid" in args.lm_data:
        num_lines_in_total = 5567.0
        step = 3000
    elif "test" in args.lm_data:
        num_lines_in_total = 5559.0
        step = 3000
    else:
        num_lines_in_total = None
        step = None

    # Decode explicitly as UTF-8 so the result does not depend on the
    # locale's default encoding.
    with open(args.lm_data, encoding="utf-8") as f:
        for processed, line in enumerate(f):
            if step and processed % step == 0:
                logging.info(
                    f"Processed number of lines: {processed} "
                    f"({processed/num_lines_in_total*100: .3f}%)"
                )

            line_words = line.split()
            for w in line_words:
                if w not in word2index:
                    w_bpe = sp.encode(w)
                    word2index[w] = len(word2bpe)
                    word2bpe.append(w_bpe)
            sentences.append([word2index[w] for w in line_words])

    logging.info("Constructing ragged tensors")
    words = k2.ragged.RaggedTensor(word2bpe)
    sentences = k2.ragged.RaggedTensor(sentences)

    output = dict(words=words, sentences=sentences)

    num_sentences = sentences.dim0
    logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}")
    sentence_lengths = [0] * num_sentences
    for i in range(num_sentences):
        if step and i % step == 0:
            logging.info(
                f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)"
            )

        word_ids = sentences[i]

        # Indexing `words` with the word ids of one sentence can come back
        # as a k2.RaggedTensor; in that case use its flattened values so
        # that numel() below counts BPE tokens.
        token_ids = words[word_ids]
        if isinstance(token_ids, k2.RaggedTensor):
            token_ids = token_ids.values

        # Number of BPE tokens of sentence i.
        sentence_lengths[i] = token_ids.numel()

    output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32)

    torch.save(output, args.lm_archive)
    logging.info(f"Saved to {args.lm_archive}")
|
|
|
|
if __name__ == "__main__":
    # Set up INFO-level logging with timestamp, level and source location
    # before running the conversion.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
    )
    main()
|
|