# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import k2
import torch

from icefall.lexicon import Lexicon


class CtcTrainingGraphCompiler(object):
    def __init__(
        self,
        lexicon: Lexicon,
        device: torch.device,
        oov: str = "<UNK>",
        need_repeat_flag: bool = False,
    ):
"""
Args:
lexicon:
It is built from `data/lang/lexicon.txt`.
device:
The device to use for operations compiling transcripts to FSAs.
oov:
Out of vocabulary word. When a word in the transcript
does not exist in the lexicon, it is replaced with `oov`.
need_repeat_flag:
If True, will add an attribute named `_is_repeat_token_` to ctc_topo
indicating whether this token is a repeat token in ctc graph.
This attribute is needed to implement delay-penalty for phone-based
ctc loss. See https://github.com/k2-fsa/k2/pull/1086 for more
details. Note: The above change MUST be included in k2 to open this
flag.
"""
        L_inv = lexicon.L_inv.to(device)
        assert L_inv.requires_grad is False

        assert oov in lexicon.word_table

        self.L_inv = k2.arc_sort(L_inv)
        self.oov_id = lexicon.word_table[oov]
        self.word_table = lexicon.word_table

        max_token_id = max(lexicon.tokens)
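        # k2.ctc_topo() with modified=False builds the standard CTC
        # topology: blank has token ID 0, and each token has a self-loop
        # so that consecutive repeats collapse to a single symbol.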
        ctc_topo = k2.ctc_topo(max_token_id, modified=False)
        self.ctc_topo = ctc_topo.to(device)

        if need_repeat_flag:
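            # In the standard CTC topology, an arc that re-emits the current
            # token has a non-epsilon label but an epsilon (0) aux_label, so
            # arcs with labels != aux_labels are exactly the repeat arcs.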
            self.ctc_topo._is_repeat_token_ = (
                self.ctc_topo.labels != self.ctc_topo.aux_labels
            )

        self.device = device

    def compile(self, texts: List[str]) -> k2.Fsa:
        """Build decoding graphs by composing ctc_topo with
        given transcripts.

        Args:
          texts:
            A list of strings. Each string contains a sentence for an
            utterance. A sentence consists of space-separated words. An
            example `texts` looks like:

                ['hello icefall', 'CTC training with k2']

        Returns:
          An FsaVec, the composition result of `self.ctc_topo` and the
          transcript FSA.
        """
        transcript_fsa = self.convert_transcript_to_fsa(texts)

        # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
        # is False, so we add epsilon self-loops here.
        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
        fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
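        # Composing the CTC topology with the transcript FSA yields, for
        # each utterance, a graph that accepts exactly the framewise token
        # sequences (with blanks and repeats) that collapse to the transcript.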
        decoding_graph = k2.compose(
            self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
        )
        assert decoding_graph.requires_grad is False

        return decoding_graph

    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
        """Convert a list of texts to a list-of-list of word IDs.

        Args:
          texts:
            It is a list of strings. Each string consists of space-separated
            words. An example containing two strings is given below:

                ['HELLO ICEFALL', 'HELLO k2']

        Returns:
          Return a list-of-list of word IDs.
        """
        word_ids_list = []
        for text in texts:
            word_ids = []
            for word in text.split():
                if word in self.word_table:
                    word_ids.append(self.word_table[word])
                else:
                    word_ids.append(self.oov_id)
            word_ids_list.append(word_ids)

        return word_ids_list

    def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
        """Convert a list of transcript texts to an FsaVec.

        Args:
          texts:
            A list of strings. Each string contains a sentence for an
            utterance. A sentence consists of space-separated words. An
            example `texts` looks like:

                ['hello icefall', 'CTC training with k2']

        Returns:
          Return an FsaVec, whose `shape[0]` equals `len(texts)`.
        """
        word_ids_list = self.texts_to_ids(texts)
        word_fsa = k2.linear_fsa(word_ids_list, self.device)
        word_fsa_with_self_loops = k2.add_epsilon_self_loops(word_fsa)
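        # As in compile(), epsilon self-loops are needed because
        # k2.intersect() is invoked with treat_epsilons_specially=False.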
        fsa = k2.intersect(
            self.L_inv, word_fsa_with_self_loops, treat_epsilons_specially=False
        )
        # fsa has word IDs as labels and token IDs as aux_labels, so
        # we need to invert it.
        ans_fsa = fsa.invert_()
        return k2.arc_sort(ans_fsa)
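

# A minimal usage sketch, not part of the original file. It assumes a lang
# directory prepared by an icefall recipe; the path `data/lang_phone` and
# the transcripts below are hypothetical placeholders.
if __name__ == "__main__":
    lexicon = Lexicon("data/lang_phone")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    compiler = CtcTrainingGraphCompiler(lexicon, device=device)

    # Each string is the transcript of one utterance.
    texts = ["hello icefall", "CTC training with k2"]
    decoding_graph = compiler.compile(texts)

    # compile() returns an FsaVec with one decoding graph per utterance.
    assert decoding_graph.shape[0] == len(texts)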