|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
|
from typing import List |
|
|
|
|
|
import k2 |
|
|
import torch |
|
|
|
|
|
from icefall.lexicon import Lexicon |
|
|
|
|
|
|
|
|
class CharCtcTrainingGraphCompiler(object):
    """Compile character-level transcripts into CTC training graphs.

    Transcripts are mapped to token IDs with the lexicon's token table
    (:meth:`texts_to_ids`), and the resulting ID sequences are turned into
    CTC graphs with :func:`k2.ctc_graph` (:meth:`compile`).
    """

    # Spaces and tabs are stripped from a transcript before it is split into
    # single-character tokens (only when ``sep == ""``). Compiled once here
    # rather than on every call to ``texts_to_ids``.
    _WHITESPACE_RE = re.compile(r"[ \t]")

    def __init__(
        self,
        lexicon: Lexicon,
        device: torch.device,
        sos_token: str = "<sos/eos>",
        eos_token: str = "<sos/eos>",
        oov: str = "<unk>",
    ):
        """
        Args:
          lexicon:
            It is built from `data/lang_char/lexicon.txt`. Only its
            ``token_table`` is used by this class.
          device:
            The device to use for operations compiling transcripts to FSAs.
          sos_token:
            The start-of-sentence token. Its ID is stored in ``self.sos_id``.
          eos_token:
            The end-of-sentence token. Its ID is stored in ``self.eos_id``.
          oov:
            Out of vocabulary token. When a word(token) in the transcript
            does not exist in the token list, it is replaced with `oov`.

        Raises:
          AssertionError: if ``oov`` is not present in the token table.
        """
        assert (
            oov in lexicon.token_table
        ), f"OOV token {oov!r} is not in the token table"

        self.oov_id = lexicon.token_table[oov]
        self.token_table = lexicon.token_table

        self.device = device

        # The lookups below go through the token table's ``[]``; a missing
        # sos/eos token fails there.
        self.sos_id = self.token_table[sos_token]
        self.eos_id = self.token_table[eos_token]

    def texts_to_ids(self, texts: List[str], sep: str = "") -> List[List[int]]:
        """Convert a list of texts to a list-of-list of token IDs.

        Args:
          texts:
            It is a list of strings.
            An example containing two strings is given below:

                ['你好中国', '北京欢迎您']
          sep:
            The separator between the items of one sequence. Use ``""``
            (no separator) for Chinese, where every character is a token;
            spaces and tabs are removed first. Use ``"/"`` for Chinese
            characters mixed with BPE tokens, or for pinyin tokens.

        Returns:
          Return a list-of-list of token IDs. Tokens missing from the token
          table are mapped to ``self.oov_id``.
        """
        assert sep in ("", "/"), sep

        token_table = self.token_table
        oov_id = self.oov_id

        ids: List[List[int]] = []
        for text in texts:
            if sep == "":
                # Iterating a str yields one token per character.
                tokens = self._WHITESPACE_RE.sub("", text)
            else:
                tokens = text.split(sep)
            ids.append(
                [token_table[tok] if tok in token_table else oov_id for tok in tokens]
            )
        return ids

    def compile(
        self,
        token_ids: List[List[int]],
        modified: bool = False,
    ) -> k2.Fsa:
        """Build a ctc graph from a list-of-list token IDs.

        Args:
          token_ids:
            It is a list-of-list integer IDs, e.g. the output of
            :meth:`texts_to_ids`.
          modified:
            See :func:`k2.ctc_graph` for its meaning.

        Returns:
          Return an FsaVec, which is the result of composing a
          CTC topology with linear FSAs constructed from the given
          token IDs. It is created on ``self.device``.
        """
        graph = k2.ctc_graph(token_ids, modified=modified, device=self.device)
        return graph
|
|
|