from collections import defaultdict
from typing import List, NamedTuple

import torch
from pyctcdecode import build_ctcdecoder

from hw_asr.base.base_text_encoder import BaseTextEncoder
from .char_text_encoder import CharTextEncoder


class Hypothesis(NamedTuple):
    text: str
    prob: float


class CTCCharTextEncoder(CharTextEncoder):
    EMPTY_TOK = "^"

    def __init__(self, alphabet: List[str] = None, kenlm_model_path: str = None,
                 unigrams_path: str = None):
        super().__init__(alphabet)
        vocab = [self.EMPTY_TOK] + list(self.alphabet)
        self.ind2char = dict(enumerate(vocab))
        self.char2ind = {v: k for k, v in self.ind2char.items()}
        # Build the pyctcdecode decoder only when a KenLM model is supplied;
        # keep it None otherwise so ctc_lm_beam_search can assert on it.
        self.decoder = None
        if kenlm_model_path is not None:
            unigrams = None
            if unigrams_path is not None:
                with open(unigrams_path) as f:
                    unigrams = [line.strip() for line in f]
            # pyctcdecode represents the CTC blank as the empty string "".
            self.decoder = build_ctcdecoder(
                labels=[""] + list(self.alphabet),
                kenlm_model_path=kenlm_model_path,
                unigrams=unigrams,
            )

    def ctc_decode(self, inds: List[int]) -> str:
        # Standard CTC collapse: merge repeated characters, then drop blanks.
        result = []
        last_char = self.EMPTY_TOK
        for ind in inds:
            cur_char = self.ind2char[ind]
            if cur_char != self.EMPTY_TOK and last_char != cur_char:
                result.append(cur_char)
            last_char = cur_char
        return ''.join(result)
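
    # Worked example (hypothetical 3-token vocab {0: "^", 1: "a", 2: "b"}):
    # ctc_decode([0, 1, 1, 0, 1, 2]) -> "aab". The blank between the two "a"
    # runs preserves the second "a", while the repeated index 1 inside a
    # single run is merged away.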

    def ctc_beam_search(self, probs: torch.Tensor, beam_size: int) -> str:
        """
        Performs vanilla CTC beam search over a (time, vocab) matrix of
        per-frame character probabilities and returns the text of the most
        probable hypothesis.
        """
        assert len(probs.shape) == 2
        char_length, voc_size = probs.shape
        assert voc_size == len(self.ind2char)

        def extend_and_merge(frame, state):
            # Extend every (prefix, last_char) hypothesis with every character
            # of the current frame; hypotheses that collapse to the same key
            # have their probabilities summed.
            new_state = defaultdict(float)
            for next_char_index, next_char_proba in enumerate(frame):
                for (pref, last_char), pref_proba in state.items():
                    next_char = self.ind2char[next_char_index]
                    if next_char == last_char:
                        # Repeated character (or repeated blank): prefix unchanged.
                        new_pref = pref
                    else:
                        if next_char != self.EMPTY_TOK:
                            new_pref = pref + next_char
                        else:
                            new_pref = pref
                        last_char = next_char
                    new_state[(new_pref, last_char)] += pref_proba * next_char_proba
            return new_state

        def truncate(state, beam_size):
            # Keep only the beam_size most probable hypotheses.
            state_list = list(state.items())
            state_list.sort(key=lambda x: -x[1])
            return dict(state_list[:beam_size])

        state = {('', self.EMPTY_TOK): 1.0}
        for frame in probs:
            state = extend_and_merge(frame, state)
            state = truncate(state, beam_size)
        state_list = list(state.items())
        state_list.sort(key=lambda x: -x[1])
        # Keys are (prefix, last_char) pairs; return the prefix of the best one.
        return state_list[0][0][0]
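
    # One-step illustration of extend_and_merge (hypothetical 2-token vocab
    # {"^", "a"} with frame probabilities {^: 0.4, a: 0.6}), starting from the
    # initial state {("", "^"): 1.0}:
    #   extending with "^" keeps the prefix  -> ("", "^")  gets 0.4
    #   extending with "a" appends it        -> ("a", "a") gets 0.6
    # Keying hypotheses on (prefix, last_char) rather than prefix alone matters
    # because "a" after a blank and "a" after another "a" must be extended
    # differently on the next frame.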

    def ctc_lm_beam_search(self, logits: torch.Tensor) -> str:
        assert self.decoder is not None, "ctc_lm_beam_search requires kenlm_model_path"
        # pyctcdecode operates on numpy arrays, so detach the tensor first.
        return self.decoder.decode(logits.cpu().detach().numpy(), beam_width=500).lower()
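

# Minimal usage sketch, not part of the training pipeline: the toy alphabet,
# the random "probabilities", and the assumption that CharTextEncoder accepts
# a character list are all illustrative.
if __name__ == "__main__":
    encoder = CTCCharTextEncoder(alphabet=list("ab "))
    # Fake (time, vocab) probability matrix in place of real model output.
    fake_probs = torch.softmax(torch.randn(10, len(encoder.ind2char)), dim=-1)
    print(encoder.ctc_beam_search(fake_probs, beam_size=5))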