|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import itertools |
|
|
import math |
|
|
from typing import Iterable, List, Optional, Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec |
|
|
from nemo.core.classes import NeuralModule, typecheck |
|
|
from nemo.core.neural_types import LengthsType, LogprobsType, NeuralType, PredictionsType |
|
|
|
|
|
|
|
|
class _TokensWrapper:
    """Uniform token/id lookup over either a ``TokenizerSpec`` or a bare vocabulary list.

    When no tokenizer is given, lookups fall back to a locally built
    token -> index map over ``vocabulary``. A CTC blank symbol is implicitly
    appended one position past the end of the vocabulary.
    """

    def __init__(self, vocabulary: List[str], tokenizer: TokenizerSpec):
        self.vocabulary = vocabulary
        self.tokenizer = tokenizer
        # Reverse lookup is only needed when no tokenizer was supplied.
        if tokenizer is None:
            self.reverse_map = {tok: idx for idx, tok in enumerate(vocabulary)}

    @property
    def vocab(self):
        """The raw vocabulary list (without the implicit blank)."""
        return self.vocabulary

    @property
    def vocab_size(self):
        """Vocabulary size including the extra CTC blank symbol."""
        return len(self.vocabulary) + 1

    @property
    def blank(self):
        """Id of the CTC blank: one past the last vocabulary entry."""
        return len(self.vocabulary)

    @property
    def unk_id(self):
        """Id of the unknown token, or -1 when none can be determined.

        Preference order: the tokenizer's own ``unk_id`` (if present and set),
        then a literal ``'<unk>'`` entry in the vocabulary.
        """
        tok = self.tokenizer
        if tok is not None and getattr(tok, 'unk_id', None) is not None:
            return tok.unk_id
        if '<unk>' in self.vocabulary:
            return self.token_to_id('<unk>')
        return -1

    def token_to_id(self, token: str):
        """Map a token to its integer id; the blank sentinel maps to -1."""
        # NOTE(review): `token` is a str while `blank` is an int, so this guard
        # only fires if a caller passes the blank id itself — kept as-is.
        if token == self.blank:
            return -1
        if self.tokenizer is not None:
            return self.tokenizer.token_to_id(token)
        return self.reverse_map[token]
|
|
|
|
|
|
|
|
class FlashLightKenLMBeamSearchDecoder(NeuralModule):
    '''
    CTC beam-search decoder backed by the flashlight text bindings and a
    KenLM language model. Supports lexicon-constrained decoding (when
    ``lexicon_path`` is given) and lexicon-free decoding with a unit LM.

    The neural-type port declarations below are intentionally disabled;
    they are kept verbatim for reference:

    @property
    def input_types(self):
        """Returns definitions of module input ports.
        """
        return {
            "log_probs": NeuralType(('B', 'T', 'D'), LogprobsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports.
        """
        return {"hypos": NeuralType(('B'), PredictionsType())}
    '''

    def __init__(
        self,
        lm_path: str,
        vocabulary: List[str],
        tokenizer: Optional[TokenizerSpec] = None,
        lexicon_path: Optional[str] = None,
        beam_size: int = 32,
        beam_size_token: int = 32,
        beam_threshold: float = 25.0,
        lm_weight: float = 2.0,
        word_score: float = -1.0,
        unk_weight: float = -math.inf,
        sil_weight: float = 0.0,
        unit_lm: bool = False,
    ):
        """Build the flashlight decoder.

        Args:
            lm_path: path to a KenLM binary/ARPA language model.
            vocabulary: model output vocabulary (CTC blank is appended after it).
            tokenizer: optional tokenizer used to map lexicon spellings to ids.
            lexicon_path: optional flashlight lexicon file; when None,
                lexicon-free decoding is used (requires ``unit_lm=True``).
            beam_size / beam_size_token / beam_threshold: beam-search pruning.
            lm_weight, word_score, unk_weight, sil_weight: decoder score knobs.
            unit_lm: True when the LM operates over vocabulary units directly.

        Raises:
            ModuleNotFoundError: if the flashlight text bindings are absent.
        """
        try:
            from flashlight.lib.text.decoder import (
                LM,
                CriterionType,
                KenLM,
                LexiconDecoder,
                LexiconDecoderOptions,
                SmearingMode,
                Trie,
            )
            from flashlight.lib.text.dictionary import create_word_dict, load_words
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "FlashLightKenLMBeamSearchDecoder requires the installation of flashlight python bindings "
                "from https://github.com/flashlight/text. Please follow the build instructions there."
            )

        super().__init__()

        self.criterion_type = CriterionType.CTC
        self.tokenizer_wrapper = _TokensWrapper(vocabulary, tokenizer)
        # vocab_size includes the implicit CTC blank appended after the vocabulary.
        self.vocab_size = self.tokenizer_wrapper.vocab_size
        self.blank = self.tokenizer_wrapper.blank
        # The unk id doubles as the decoder's "silence" token id (-1 when unknown).
        self.silence = self.tokenizer_wrapper.unk_id
        self.unit_lm = unit_lm

        if lexicon_path is not None:
            # --- Lexicon-constrained decoding ---
            self.lexicon = load_words(lexicon_path)
            self.word_dict = create_word_dict(self.lexicon)
            self.unk_word = self.word_dict.get_index("<unk>")

            self.lm = KenLM(lm_path, self.word_dict)
            self.trie = Trie(self.vocab_size, self.silence)

            # Pre-score every lexicon word with the LM and insert each of its
            # spellings (token-id sequences) into the trie.
            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                word_idx = self.word_dict.get_index(word)
                _, score = self.lm.score(start_state, word_idx)
                for spelling in spellings:
                    spelling_idxs = [self.tokenizer_wrapper.token_to_id(token) for token in spelling]
                    # Skip spellings the tokenizer cannot fully represent.
                    if self.tokenizer_wrapper.unk_id in spelling_idxs:
                        print(f'tokenizer has unknown id for word[ {word} ] {spelling} {spelling_idxs}', flush=True)
                        continue
                    self.trie.insert(spelling_idxs, word_idx, score)
            # Propagate the best child score up the trie for tighter pruning.
            self.trie.smear(SmearingMode.MAX)

            self.decoder_opts = LexiconDecoderOptions(
                beam_size=beam_size,
                beam_size_token=int(beam_size_token),
                beam_threshold=beam_threshold,
                lm_weight=lm_weight,
                word_score=word_score,
                unk_score=unk_weight,
                sil_score=sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )

            self.decoder = LexiconDecoder(
                self.decoder_opts, self.trie, self.lm, self.silence, self.blank, self.unk_word, [], self.unit_lm,
            )
        else:
            # --- Lexicon-free decoding ---
            assert self.unit_lm, "lexicon free decoding can only be done with a unit language model"
            from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions

            # Build a trivial word dict where every unit "spells" itself,
            # adding '<unk>' if the vocabulary lacks it.
            d = {
                w: [[w]]
                for w in self.tokenizer_wrapper.vocab + ([] if '<unk>' in self.tokenizer_wrapper.vocab else ['<unk>'])
            }
            self.word_dict = create_word_dict(d)
            self.lm = KenLM(lm_path, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=beam_size,
                beam_size_token=int(beam_size_token),
                beam_threshold=beam_threshold,
                lm_weight=lm_weight,
                sil_score=sil_weight,
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(self.decoder_opts, self.lm, self.silence, self.blank, [])

    def _get_tokens(self, idxs: List[int]):
        """Normalize tokens by handling CTC blank, ASG replabels, etc.

        Collapses repeated ids, then drops blank/silence ids, returning the
        surviving ids as a LongTensor.
        """
        # Collapse runs of identical ids (CTC repetition removal).
        idxs = (g[0] for g in itertools.groupby(idxs))
        idxs = filter(lambda x: x != self.blank and x != self.silence, idxs)
        return torch.LongTensor(list(idxs))

    def _get_timesteps(self, token_idxs: List[int]):
        """Returns frame numbers corresponding to every non-blank token.

        Parameters
        ----------
        token_idxs : List[int]
            IDs of decoded tokens.

        Returns
        -------
        List[int]
            Frame numbers corresponding to every non-blank token.
        """
        timesteps = []
        for i, token_idx in enumerate(token_idxs):
            if token_idx == self.blank:
                continue
            # Record only the first frame of each run of a repeated token.
            if i == 0 or token_idx != token_idxs[i - 1]:
                timesteps.append(i)
        return timesteps

    @torch.no_grad()
    def forward(self, log_probs: Union[np.ndarray, torch.Tensor]):
        """Beam-search decode a batch of CTC log-probabilities.

        Args:
            log_probs: (B, T, V) or (T, V) array/tensor of log-probabilities.

        Returns:
            A list of length B; each entry is a list of hypothesis dicts with
            keys 'tokens', 'score', 'timesteps', and 'words'.
        """
        if isinstance(log_probs, np.ndarray):
            log_probs = torch.from_numpy(log_probs)
        if log_probs.dim() == 2:
            # Single utterance: add a batch dimension.
            log_probs = log_probs.unsqueeze(0)

        # The flashlight decoder reads raw float32 memory via data_ptr(), so
        # cast every input to float32 on CPU. (Previously only numpy inputs
        # were cast, and the pointer arithmetic hard-coded a 4-byte element
        # size, so a float64/half torch tensor was decoded from wrong
        # addresses.)
        emissions = log_probs.to(dtype=torch.float32, device='cpu').contiguous()

        B, T, N = emissions.size()
        hypos = []

        for b in range(B):
            # Byte offset of batch item b within the contiguous buffer;
            # element_size() is 4 for the float32 tensor built above.
            emissions_ptr = emissions.data_ptr() + emissions.element_size() * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            hypos.append(
                [
                    {
                        "tokens": self._get_tokens(result.tokens),
                        "score": result.score,
                        "timesteps": self._get_timesteps(result.tokens),
                        # Negative word ids are padding/invalid entries.
                        "words": [self.word_dict.get_entry(x) for x in result.words if x >= 0],
                    }
                    for result in results
                ]
            )

        return hypos
|
|
|