import math
import re
import warnings
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np

# A single CJK Unified Ideograph (U+4E00..U+9FFF). The capturing group makes
# ``split`` keep each ideograph as its own list element.
_CJK_CHAR = re.compile(r"([\u4e00-\u9fff])")


def log_add(*args: float) -> float:
    """Return ``log(sum(exp(a) for a in args))`` computed without overflow.

    If every argument is ``-inf`` (or no argument is given) the result is
    ``-inf``.
    """
    if all(a == -float("inf") for a in args):
        return -float("inf")
    a_max = max(args)
    # Factor out the largest term so every exp() argument is <= 0.
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))


def tokenize_by_bpe_model(sp, text: str, upper: bool = True) -> List[str]:
    """Split *text* into tokens, keeping CJK characters whole.

    Each CJK ideograph becomes its own token; every other non-blank run is
    encoded with ``sp.encode_as_pieces``.

    Args:
        sp: a loaded sentencepiece processor (anything exposing
            ``encode_as_pieces``).
        text: the text to tokenize.
        upper: upper-case *text* before splitting.

    Returns:
        String tokens in input order.
    """
    tokens: List[str] = []
    for chunk in _CJK_CHAR.split(text.upper() if upper else text):
        if not chunk.strip():
            continue
        if _CJK_CHAR.fullmatch(chunk) is not None:
            tokens.append(chunk)
        else:
            tokens.extend(sp.encode_as_pieces(chunk))
    return tokens


def tokenize(
    contexts: List[str],
    symbol_table: Dict[str, int],
    bpe_model: Optional[str] = None,
) -> List[List[int]]:
    """Convert context phrases into lists of token ids.

    Args:
        contexts: context phrases; each is stripped before tokenizing.
        symbol_table: token string -> token id.
        bpe_model: optional sentencepiece model path. When given and
            sentencepiece is importable, phrases are tokenized with
            ``tokenize_by_bpe_model``; otherwise they are split into single
            characters (a warning is issued if the model had to be ignored).

    Returns:
        One id list per phrase. Tokens missing from *symbol_table* map to the
        ``"<unk>"`` id when the table has one and are dropped otherwise;
        phrases that end up with no ids are skipped.
    """
    sp = None
    if bpe_model is not None:
        try:
            import sentencepiece as spm
        except ImportError:
            warnings.warn(
                "sentencepiece is not installed; ignoring bpe_model "
                f"{bpe_model!r} and tokenizing contexts by character",
                RuntimeWarning,
            )
        else:
            sp = spm.SentencePieceProcessor()
            sp.load(bpe_model)

    unk = symbol_table.get("<unk>")
    context_ids: List[List[int]] = []
    for context in contexts:
        context = context.strip()
        if sp is not None:
            pieces = tokenize_by_bpe_model(sp, context)
        else:
            # Character fallback: spaces become "▁", the SentencePiece
            # word-boundary marker.
            pieces = list(context.replace(" ", "▁"))
        labels = []
        for piece in pieces:
            if piece in symbol_table:
                labels.append(symbol_table[piece])
            elif unk is not None:
                labels.append(unk)
        if labels:
            context_ids.append(labels)
    return context_ids


@dataclass
class ContextState:
    """A node of the context trie."""

    idx: int  # node id in creation order; the root is 0
    token: int = -1  # token id on the edge into this node (-1 for the root)
    token_score: float = 0.0  # bonus for this single token
    node_score: float = 0.0  # accumulated bonus from the root to this node
    # Bonus granted on reaching this node: node_score if a phrase ends here,
    # plus the output_score of the nearest phrase end on the fail chain.
    output_score: float = 0.0
    is_end: bool = False  # a context phrase ends at this node
    fail: Optional["ContextState"] = None  # Aho-Corasick failure link
    output: Optional["ContextState"] = None  # nearest phrase end via fail links
    next: Dict[int, "ContextState"] = field(default_factory=dict)  # children


class ContextGraph:
    """Aho-Corasick style trie over tokenized context phrases.

    Each token along a phrase earns ``context_score``. When a partial match
    breaks off, the accumulated per-token bonus is taken back (see
    ``forward_one_step`` and ``finalize``).
    """

    def __init__(
        self,
        contexts: List[str],
        symbol_table: Dict[str, int],
        bpe_model: Optional[str] = None,
        context_score: float = 6.0,
    ):
        self.context_score = context_score
        self.num_nodes = 0
        self.root = ContextState(self.num_nodes)
        self.root.fail = self.root
        self.build(tokenize(contexts, symbol_table, bpe_model))

    def build(self, token_ids: List[List[int]]):
        """Insert every token sequence into the trie, then link fail/output arcs.

        Args:
            token_ids: one list of token ids per context phrase.
        """
        for tokens in token_ids:
            node = self.root
            last_idx = len(tokens) - 1
            for idx, token in enumerate(tokens):
                is_end = idx == last_idx
                child = node.next.get(token)
                if child is None:
                    self.num_nodes += 1
                    node_score = node.node_score + self.context_score
                    child = ContextState(
                        idx=self.num_nodes,
                        token=token,
                        token_score=self.context_score,
                        node_score=node_score,
                        output_score=node_score if is_end else 0.0,
                        is_end=is_end,
                    )
                    node.next[token] = child
                elif is_end and not child.is_end:
                    # This phrase is a prefix of one inserted earlier: the node
                    # exists already but was not yet marked as a phrase end.
                    child.is_end = True
                    child.output_score = child.node_score
                node = child
        self.fill_fail_output()

    def fill_fail_output(self):
        """Compute fail and output links breadth-first (Aho-Corasick).

        ``fail`` points to the node of the longest proper suffix of this
        node's path that exists in the trie (the root if none). ``output``
        points to the nearest phrase-end node along the fail chain, and its
        ``output_score`` is added to this node's, so reaching a node also
        rewards phrases that end as a suffix of its path.
        """
        queue = deque()
        for child in self.root.next.values():
            child.fail = self.root
            queue.append(child)
        while queue:
            current = queue.popleft()
            for token, child in current.next.items():
                fail = current.fail
                while fail is not self.root and token not in fail.next:
                    fail = fail.fail
                child.fail = fail.next.get(token, self.root)
                output = child.fail
                while output is not self.root and not output.is_end:
                    output = output.fail
                child.output = output if output.is_end else None
                if child.output is not None:
                    # Fail links point to shallower nodes and BFS finishes
                    # shallower levels first, so child.output's score already
                    # includes its own suffix chain.
                    child.output_score += child.output.output_score
                queue.append(child)

    def forward_one_step(
        self, state: ContextState, token: int
    ) -> Tuple[float, ContextState]:
        """Advance *state* by *token*; return ``(score_delta, new_state)``.

        Fail links are followed until a node with an edge for *token* is found
        or the root is reached. The delta is the change in accumulated
        ``node_score`` plus the new node's ``output_score``.
        """
        node = state
        while node is not self.root and token not in node.next:
            node = node.fail
        node = node.next.get(token, self.root)
        score = node.node_score - state.node_score + node.output_score
        return score, node

    def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
        """Cancel the accumulated bonus of *state* and return to the root."""
        return -state.node_score, self.root


@dataclass
class PrefixScore:
    """Beam-search statistics for one output prefix.

    The suffix ``_s`` refers to paths whose last frame is blank, ``_ns`` to
    paths whose last frame is a non-blank token. All scores are log-domain.
    """

    s: float = float("-inf")  # summed log-prob of blank-ending paths
    ns: float = float("-inf")  # summed log-prob of non-blank-ending paths
    v_s: float = float("-inf")  # best (Viterbi) blank-ending path
    v_ns: float = float("-inf")  # best non-blank-ending path
    context_state: Optional[ContextState] = None  # position in the context graph
    context_score: float = 0.0  # accumulated context-biasing bonus
    cur_token_prob: float = float("-inf")  # log-prob stored with the ns path's last token
    times_s: List[int] = field(default_factory=list)  # token frames, best blank-ending path
    times_ns: List[int] = field(default_factory=list)  # token frames, best ns path
    probs_s: List[float] = field(default_factory=list)  # token log-probs, best blank-ending path
    probs_ns: List[float] = field(default_factory=list)  # token log-probs, best ns path
    has_context: bool = False  # context already assigned for the current frame

    def score(self):
        """Total log-prob of the prefix over both path kinds."""
        return log_add(self.s, self.ns)

    def viterbi_score(self):
        """Log-prob of the single best path for this prefix."""
        return self.v_s if self.v_s > self.v_ns else self.v_ns

    def times(self):
        """Token frame indices along the best path."""
        return self.times_s if self.v_s > self.v_ns else self.times_ns

    def probs(self):
        """Token log-probs along the same best path that ``times`` uses."""
        return self.probs_s if self.v_s > self.v_ns else self.probs_ns

    def total_score(self):
        """Ranking score: acoustic log-prob plus context bonus."""
        return self.score() + self.context_score


class CTCDecoder:
    """CTC prefix beam search with optional context biasing.

    Decoding state (``cur_t`` and ``cur_hyps``) persists across calls to
    ``ctc_prefix_beam_search`` until a call with ``is_last=True`` resets it,
    so a stream can be decoded chunk by chunk.
    """

    def __init__(
        self,
        contexts: Optional[List[str]] = None,
        symbol_table: Optional[Dict[str, int]] = None,
        bpe_model: Optional[str] = None,
        context_score: float = 6.0,
        blank_id: int = 0,
    ):
        self.context_graph = None
        if contexts is not None:
            if symbol_table is None:
                raise ValueError("symbol_table is required when contexts are given")
            self.context_graph = ContextGraph(
                contexts, symbol_table, bpe_model, context_score
            )
        self.blank_id = blank_id
        self.reset()

    def reset(self):
        """Start a new utterance: frame counter 0, a single empty hypothesis."""
        context_root = self.context_graph.root if self.context_graph is not None else None
        self.cur_t = 0
        self.cur_hyps = [
            (tuple(), PrefixScore(s=0.0, v_s=0.0, context_state=context_root))
        ]

    def copy_context(self, prefix_score: PrefixScore, next_score: PrefixScore):
        """Carry the context state over unchanged (prefix did not grow)."""
        if self.context_graph is not None and not next_score.has_context:
            next_score.context_score = prefix_score.context_score
            next_score.context_state = prefix_score.context_state
            next_score.has_context = True

    def update_context(
        self, prefix_score: PrefixScore, next_score: PrefixScore, token: int
    ):
        """Advance the context state by *token* (prefix grew by one token)."""
        if self.context_graph is not None and not next_score.has_context:
            score, state = self.context_graph.forward_one_step(
                prefix_score.context_state, token
            )
            next_score.context_score = prefix_score.context_score + score
            next_score.context_state = state
            next_score.has_context = True

    def backoff_context(self):
        """Remove bonuses of unfinished context matches from every hypothesis."""
        if self.context_graph is None:
            return
        for _, score in self.cur_hyps:
            backoff_score, state = self.context_graph.finalize(score.context_state)
            score.context_score += backoff_score
            score.context_state = state

    @staticmethod
    def topk(logp: np.ndarray, beam_size: int):
        """Return the *beam_size* largest values of *logp* and their indices,
        sorted in descending order."""
        if beam_size >= logp.shape[0]:
            indices = np.argsort(logp)[::-1]
        else:
            candidates = np.argpartition(logp, -beam_size)[-beam_size:]
            indices = candidates[np.argsort(logp[candidates])[::-1]]
        return logp[indices], indices

    def _search_step(self, logp: np.ndarray, beam_size: int, return_probs: bool):
        """Extend ``self.cur_hyps`` by one frame and return the pruned beam."""
        next_hyps: Dict[Tuple[int, ...], PrefixScore] = defaultdict(PrefixScore)
        top_probs, top_indices = self.topk(logp, beam_size)
        for prob, token in zip(top_probs.tolist(), top_indices.tolist()):
            for prefix, prefix_score in self.cur_hyps:
                if token == self.blank_id:
                    # *a + blank -> *a, now ending in blank.
                    next_score = next_hyps[prefix]
                    next_score.s = log_add(next_score.s, prefix_score.score() + prob)
                    next_score.v_s = prefix_score.viterbi_score() + prob
                    next_score.times_s = prefix_score.times().copy()
                    if return_probs:
                        next_score.probs_s = prefix_score.probs().copy()
                    self.copy_context(prefix_score, next_score)
                elif prefix and token == prefix[-1]:
                    # *a + a -> *a: a repeat without a blank in between
                    # collapses, so only the prefix's ns paths extend here.
                    next_score = next_hyps[prefix]
                    next_score.ns = log_add(next_score.ns, prefix_score.ns + prob)
                    if next_score.v_ns < prefix_score.v_ns + prob:
                        next_score.v_ns = prefix_score.v_ns + prob
                        if next_score.cur_token_prob < prob:
                            next_score.cur_token_prob = prob
                            next_score.times_ns = prefix_score.times_ns.copy()
                            next_score.times_ns[-1] = self.cur_t
                            if return_probs:
                                probs = prefix_score.probs_ns.copy()
                                if probs:
                                    probs[-1] = max(probs[-1], prob)
                                next_score.probs_ns = probs
                    self.copy_context(prefix_score, next_score)

                    # *a + blank + a -> *aa: only blank-ending paths can emit
                    # the same token again as a new output token.
                    new_prefix = prefix + (token,)
                    next_score = next_hyps[new_prefix]
                    next_score.ns = log_add(next_score.ns, prefix_score.s + prob)
                    if next_score.v_ns < prefix_score.v_s + prob:
                        next_score.v_ns = prefix_score.v_s + prob
                        next_score.cur_token_prob = prob
                        next_score.times_ns = prefix_score.times_s.copy()
                        next_score.times_ns.append(self.cur_t)
                        if return_probs:
                            next_score.probs_ns = prefix_score.probs_s.copy()
                            next_score.probs_ns.append(prob)
                    self.update_context(prefix_score, next_score, token)
                else:
                    # *a + b -> *ab: a different token always extends.
                    new_prefix = prefix + (token,)
                    next_score = next_hyps[new_prefix]
                    next_score.ns = log_add(next_score.ns, prefix_score.score() + prob)
                    if next_score.v_ns < prefix_score.viterbi_score() + prob:
                        next_score.v_ns = prefix_score.viterbi_score() + prob
                        next_score.cur_token_prob = prob
                        next_score.times_ns = prefix_score.times().copy()
                        next_score.times_ns.append(self.cur_t)
                        if return_probs:
                            next_score.probs_ns = prefix_score.probs().copy()
                            next_score.probs_ns.append(prob)
                    self.update_context(prefix_score, next_score, token)
        return sorted(
            next_hyps.items(), key=lambda item: item[1].total_score(), reverse=True
        )[:beam_size]

    def ctc_prefix_beam_search(
        self,
        ctc_probs: np.ndarray,
        beam_size: int,
        is_last: bool = False,
        return_probs: bool = False,
    ):
        """Run prefix beam search over a chunk of CTC frames.

        Args:
            ctc_probs: iterated row by row; each row is one frame of per-token
                log-probabilities (presumably shape ``(num_frames, vocab)``).
            beam_size: tokens kept per frame and hypotheses kept overall.
            is_last: after this chunk, back off unfinished context matches,
                re-rank the hypotheses and reset the decoder.
            return_probs: also report per-token probabilities. Use the same
                value for every chunk of a stream; chunks decoded with
                ``False`` record no probabilities.

        Returns:
            Dict with ``"tokens"`` (token-id list per hypothesis, best first),
            ``"times"`` (frame of each token, counted from 1 since the last
            reset) and, if *return_probs*, ``"probs"`` (linear probability of
            each token along the same path as ``"times"``).
        """
        for logp in ctc_probs:
            self.cur_t += 1
            self.cur_hyps = self._search_step(logp, beam_size, return_probs)

        cur_hyps = self.cur_hyps
        if is_last:
            self.backoff_context()
            # Backoff changes context scores, so the ranking must be redone.
            cur_hyps = sorted(
                cur_hyps, key=lambda item: item[1].total_score(), reverse=True
            )
            self.reset()

        response = {
            "tokens": [list(prefix) for prefix, _ in cur_hyps],
            "times": [score.times() for _, score in cur_hyps],
        }
        if return_probs:
            response["probs"] = [
                [math.exp(p) for p in score.probs()] for _, score in cur_hyps
            ]
        return response