| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| """ |
| |
| This script takes as input `lang_dir`, which should contain:: |
| |
| - lang_dir/bbpe.model, |
| - lang_dir/words.txt |
| |
| and generates the following files in the directory `lang_dir`: |
| |
| - lexicon.txt |
| - lexicon_disambig.txt |
| - L.pt |
| - L_disambig.pt |
| - tokens.txt |
| """ |
|
|
| import argparse |
| from pathlib import Path |
| from typing import Dict, List, Tuple |
|
|
| import k2 |
| import sentencepiece as spm |
| import torch |
| from prepare_lang import ( |
| Lexicon, |
| add_disambig_symbols, |
| add_self_loops, |
| write_lexicon, |
| write_mapping, |
| ) |
|
|
| from icefall.byte_utils import byte_encode |
| from icefall.utils import str2bool, tokenize_by_CJK_char |
|
|
|
|
def lexicon_to_fst_no_sil(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Build a k2 transducer from a lexicon.

    Every lexicon entry becomes one path that leaves state 0 and comes back
    to state 0. The input labels along the path are the IDs of the entry's
    tokens; the word ID is emitted on the first arc only and every following
    arc outputs epsilon.

    Args:
      lexicon:
        Iterable of ``(word, pieces)`` pairs, where ``pieces`` is a
        non-empty sequence of token symbols.
      token2id:
        Maps token symbols to integer IDs. ``token2id["<blk>"]`` must be 0.
      word2id:
        Maps words to integer IDs. ``word2id["<eps>"]`` must be 0.
      need_self_loops:
        If True, the arcs are passed through :func:`add_self_loops` with
        ``token2id["#0"]`` as the input label and ``word2id["#0"]`` as the
        output label of the added self-loops.
    Returns:
      A ``k2.Fsa`` (transducer, i.e. created with ``acceptor=False``)
      representing the lexicon.
    """
    # State 0 is the start state and also the state every word path returns to.
    loop_state = 0
    next_state = 1  # next unused state number

    # Both tables must reserve ID 0, which is what `eps` relies on below.
    assert token2id["<blk>"] == 0
    assert word2id["<eps>"] == 0
    eps = 0

    arcs = []
    for word, pieces in lexicon:
        assert len(pieces) > 0, f"{word} has no pronunciations"

        word_id = word2id[word]
        piece_ids = [token2id[piece] for piece in pieces]
        last_pos = len(piece_ids) - 1

        src = loop_state
        for pos, piece_id in enumerate(piece_ids):
            # Only the first arc of a path carries the word label.
            out_label = word_id if pos == 0 else eps
            if pos == last_pos:
                # The last token closes the path back to the loop state.
                dst = loop_state
            else:
                dst = next_state
                next_state += 1
            arcs.append([src, dst, piece_id, out_label, 0])
            src = dst

    if need_self_loops:
        arcs = add_self_loops(
            arcs,
            disambig_token=token2id["#0"],
            disambig_word=word2id["#0"],
        )

    # Append the arc into the final state (labels -1) and the final state
    # itself, which is written as a line holding only the state number.
    final_state = next_state
    arcs.append([loop_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    # Order the lines by source state, then serialize each one as
    # space-separated fields, one line per arc.
    fsa_text = "\n".join(
        " ".join(str(field) for field in arc)
        for arc in sorted(arcs, key=lambda arc: arc[0])
    )

    return k2.Fsa.from_str(fsa_text, acceptor=False)
|
|
|
|
def generate_lexicon(
    model_file: str, words: List[str], oov: str
) -> Tuple[Lexicon, Dict[str, int]]:
    """Create a lexicon for `words` using a sentencepiece model.

    Each word is passed through ``tokenize_by_CJK_char`` and ``byte_encode``
    before it is handed to the sentencepiece model, and the resulting piece
    IDs are converted back to piece strings.

    Args:
      model_file:
        Path to a sentencepiece model.
      words:
        A list of strings representing words.
      oov:
        The out of vocabulary word in lexicon. It is appended as the last
        entry, spelled as ``"▁"`` followed by the model's unknown piece.
    Returns:
      Return a tuple with two elements:
        - The lexicon: a list of ``(word, pieces)`` pairs, one per input
          word in input order, followed by the entry for `oov`.
        - A dict mapping every piece of the model to its integer ID.
    """
    sp = spm.SentencePieceProcessor()
    sp.load(str(model_file))

    # Prepare the text form that the model is asked to encode.
    encode_words = []
    for raw_word in words:
        encode_words.append(byte_encode(tokenize_by_CJK_char(raw_word)))

    # One list of piece IDs per word ...
    ids_per_word: List[List[int]] = sp.encode(encode_words, out_type=int)

    # ... converted back to one list of piece strings per word.
    pieces_per_word: List[List[str]] = list(map(sp.id_to_piece, ids_per_word))

    lexicon = list(zip(words, pieces_per_word))

    unk_piece = sp.id_to_piece(sp.unk_id())
    lexicon.append((oov, ["▁", unk_piece]))

    token2id: Dict[str, int] = {}
    for piece_id in range(sp.vocab_size()):
        token2id[sp.id_to_piece(piece_id)] = piece_id

    return lexicon, token2id
|
|
|
|
def get_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Returns:
      The ``argparse`` namespace with the attributes ``lang_dir``, ``oov``
      and ``debug``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        # There is no sensible default; without this the script would only
        # fail later when the missing value is turned into a path.
        required=True,
        help="""Input and output directory.
        It should contain the bbpe.model and words.txt
        """,
    )

    parser.add_argument(
        "--oov",
        type=str,
        default="<UNK>",
        help="The out of vocabulary word in lexicon.",
    )

    parser.add_argument(
        "--debug",
        type=str2bool,
        default=False,
        help="""True for debugging, which will generate
        a visualization of the lexicon FST.

        Caution: If your lexicon contains hundreds of thousands
        of lines, please set it to False!

        See "test/test_bpe_lexicon.py" for usage.
        """,
    )

    return parser.parse_args()
|
|
|
|
def main():
    """Generate the lexicon files, token table and lexicon FSTs in `lang_dir`.

    Reads ``bbpe.model`` and ``words.txt`` from the directory given by
    ``--lang-dir`` and writes ``tokens.txt``, ``lexicon.txt``,
    ``lexicon_disambig.txt``, ``L.pt`` and ``L_disambig.pt`` into it.
    With ``--debug`` it additionally draws ``L.svg`` and ``L_disambig.svg``.
    """
    args = get_args()
    lang_dir = Path(args.lang_dir)
    model_file = lang_dir / "bbpe.model"

    word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")

    # Special symbols get no model-generated entry in the lexicon; the OOV
    # word is added separately by generate_lexicon().
    words = word_sym_table.symbols
    excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]
    for symbol in excluded:
        try:
            words.remove(symbol)
        except ValueError:
            # The symbol is not in the word list; nothing to drop.
            pass

    lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)

    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    # Give the disambiguation tokens #0 .. #max_disambig the IDs that
    # directly follow the largest existing token ID.
    first_free_id = max(token_sym_table.values()) + 1
    for offset in range(max_disambig + 1):
        disambig = f"#{offset}"
        assert disambig not in token_sym_table
        token_sym_table[disambig] = first_free_id + offset

    for extra_word in ("#0", "<s>", "</s>"):
        word_sym_table.add(extra_word)

    write_mapping(lang_dir / "tokens.txt", token_sym_table)

    write_lexicon(lang_dir / "lexicon.txt", lexicon)
    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

    L = lexicon_to_fst_no_sil(
        lexicon,
        token2id=token_sym_table,
        word2id=word_sym_table,
    )

    L_disambig = lexicon_to_fst_no_sil(
        lexicon_disambig,
        token2id=token_sym_table,
        word2id=word_sym_table,
        need_self_loops=True,
    )
    torch.save(L.as_dict(), lang_dir / "L.pt")
    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

    if not args.debug:
        return

    # Debug only: attach the symbol tables read back from disk so that the
    # drawings show symbols, then render both FSTs.
    labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
    aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")

    drawings = (
        (L, "L.svg", "L.pt"),
        (L_disambig, "L_disambig.svg", "L_disambig.pt"),
    )
    for fsa, svg_name, title in drawings:
        fsa.labels_sym = labels_sym
        fsa.aux_labels_sym = aux_labels_sym
        fsa.draw(f"{lang_dir / svg_name}", title=title)
|
|
|
|
# Run the lexicon preparation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|