import logging
from pathlib import Path
from typing import List, Tuple, Union

import k2
import torch

from icefall.lexicon import UniqLexicon


class MmiTrainingGraphCompiler(object):
    def __init__(
        self,
        lang_dir: Path,
        uniq_filename: str = "uniq_lexicon.txt",
        device: Union[str, torch.device] = "cpu",
        oov: str = "<UNK>",
        sos_id: int = 1,
        eos_id: int = 1,
    ):
        """
        Args:
          lang_dir:
            Path to the lang directory. It is expected to contain the
            following files::

                - tokens.txt
                - words.txt
                - P.fst.txt

            The above files are generated by the script `prepare.sh`. You
            should have run it before running the training code.
          uniq_filename:
            Name of the lexicon file in which every word has exactly one
            pronunciation. We assume this file is inside the given `lang_dir`.
          device:
            The device to use. It can be either a CPU or a CUDA device.
          oov:
            Out-of-vocabulary word. When a word in a transcript does not
            exist in the lexicon, it is replaced with `oov`.
          sos_id:
            ID of the start-of-sentence (SOS) symbol.
          eos_id:
            ID of the end-of-sentence (EOS) symbol.
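
        Example (an illustrative sketch; `data/lang_bpe_500` is a
        hypothetical lang directory produced by `prepare.sh`)::

            compiler = MmiTrainingGraphCompiler(
                lang_dir=Path("data/lang_bpe_500"),
                device="cuda:0",
            )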
        """
        self.lang_dir = Path(lang_dir)
        self.lexicon = UniqLexicon(lang_dir, uniq_filename=uniq_filename)
        self.device = torch.device(device)

        self.L_inv = self.lexicon.L_inv.to(self.device)

        self.oov_id = self.lexicon.word_table[oov]
        self.sos_id = sos_id
        self.eos_id = eos_id

        self.build_ctc_topo_P()

    def build_ctc_topo_P(self):
        """Built ctc_topo_P, the composition result of
        ctc_topo and P, where P is a pre-trained bigram
        word piece LM.
        """
        # Note: there is no need to save a pre-compiled P and ctc_topo
        # as it is very fast to generate them.
        logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}")
        with open(self.lang_dir / "P.fst.txt") as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label 0 (i.e., <eps>).
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        first_token_disambig_id = self.lexicon.token_table["#0"]

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels
        # CAUTION: The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        labels = P.labels.clone()
        labels[labels >= first_token_disambig_id] = 0
        P.labels = labels

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        P = P.to(self.device)
        # Add epsilon self-loops to P because we want the
        # following operation "k2.intersect" to run on GPU.
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        max_token_id = max(self.lexicon.tokens)
        logging.info(
            f"Building ctc_topo (modified=False). max_token_id: {max_token_id}"
        )
        ctc_topo = k2.ctc_topo(max_token_id, modified=False, device=self.device)

        ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())

        logging.info("Building ctc_topo_P")
        ctc_topo_P = k2.intersect(
            ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False
        ).invert()

        self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
        logging.info(f"ctc_topo_P num_arcs: {self.ctc_topo_P.num_arcs}")

    def compile(
        self, texts: List[str], replicate_den: bool = True
    ) -> Tuple[k2.Fsa, k2.Fsa]:
        """Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces. An example `texts` is given below::

                ["Hello icefall", "LF-MMI training with icefall using k2"]

          replicate_den:
            If True, the returned den_graph is replicated to match the
            number of FSAs in the returned num_graph; if False, the
            returned den_graph contains only a single FSA.

        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graph` is the denominator graph. It is an FsaVec with
              the same shape as `num_graph` if `replicate_den` is True;
              otherwise, it is an FsaVec containing only a single FSA.
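
        Example (an illustrative sketch; it assumes `compiler` is an
        instance of this class built from an existing lang directory)::

            num, den = compiler.compile(["HELLO WORLD", "HI"])
            # num contains one FSA per transcript, so num.shape[0] == 2.
            # den is ctc_topo_P replicated twice since replicate_den
            # defaults to True.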
        """
        transcript_fsa = self.build_transcript_fsa(texts)

        # remove word IDs from transcript_fsa since they are not needed
        del transcript_fsa.aux_labels
        # NOTE: You can comment out the above statement
        # if you want to run test/test_mmi_graph_compiler.py

        transcript_fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            transcript_fsa
        )

        transcript_fsa_with_self_loops = k2.arc_sort(transcript_fsa_with_self_loops)

        num = k2.compose(
            self.ctc_topo_P,
            transcript_fsa_with_self_loops,
            treat_epsilons_specially=False,
        )

        # CAUTION: Due to the presence of P,
        # the resulting `num` may not be connected
        num = k2.connect(num)

        num = k2.arc_sort(num)

        ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
        if replicate_den:
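            # Select index 0 of the single-FSA vector `len(texts)` times
            # so that the denominator has one copy per utterance,
            # matching the shape of `num`.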
            indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device)
            den = k2.index_fsa(ctc_topo_P_vec, indexes)
        else:
            den = ctc_topo_P_vec

        return num, den

    def build_transcript_fsa(self, texts: List[str]) -> k2.Fsa:
        """Convert transcripts to an FsaVec with the help of a lexicon
        and word symbol table.

        Args:
          texts:
            Each element is a transcript containing words separated by space(s).
            For instance, it may be 'HELLO icefall', which contains
            two words.

        Returns:
          Return an FST (FsaVec) corresponding to the transcripts.
          Its `labels` are token IDs and its `aux_labels` are word IDs.
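
        Example (illustrative; the actual arcs depend on the lexicon)::

            fsa = compiler.build_transcript_fsa(['HELLO icefall'])
            # fsa.shape[0] == 1; fsa.labels holds token IDs and
            # fsa.aux_labels holds the corresponding word IDs.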
        """
        word_ids_list = []
        for text in texts:
            word_ids = []
            for word in text.split():
                if word in self.lexicon.word_table:
                    word_ids.append(self.lexicon.word_table[word])
                else:
                    word_ids.append(self.oov_id)
            word_ids_list.append(word_ids)

        fsa = k2.linear_fsa(word_ids_list, self.device)
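        # The epsilon self-loops below let epsilon arcs in L_inv match in
        # the intersection, where epsilons are treated as ordinary symbols
        # (treat_epsilons_specially=False).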
        fsa = k2.add_epsilon_self_loops(fsa)

        # The reason to use `invert_()` at the end is as follows:
        #
        # (1) The `labels` of L_inv is word IDs and `aux_labels` is token IDs
        # (2) `fsa.labels` is word IDs
        # (3) after intersection, the `labels` is still word IDs
        # (4) after `invert_()`, the `labels` is token IDs
        #     and `aux_labels` is word IDs
        transcript_fsa = k2.intersect(
            self.L_inv, fsa, treat_epsilons_specially=False
        ).invert_()
        transcript_fsa = k2.arc_sort(transcript_fsa)
        return transcript_fsa

    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
        """Convert a list of texts to a list-of-list of piece IDs.

        Args:
          texts:
            It is a list of strings. Each string consists of
            space-separated words. An example containing two strings
            is given below::

                ['HELLO ICEFALL', 'HELLO k2']

            We assume it contains no OOVs. Otherwise, an exception
            will be raised.
        Returns:
          Return a list-of-list of token IDs.
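
        Example (illustrative; the actual IDs depend on the lexicon)::

            ids = compiler.texts_to_ids(['HELLO ICEFALL'])
            # e.g., [[22, 8, 43]] -- hypothetical token IDs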
        """
        return self.lexicon.texts_to_token_ids(texts).tolist()