| import math |
| import re |
| from collections import defaultdict, deque |
| from dataclasses import dataclass, field |
| from typing import Dict, List, Optional, Tuple |
|
|
| import numpy as np |
|
|
|
|
def log_add(*args: float) -> float:
    """Return ``log(exp(a_1) + ... + exp(a_n))`` for log-domain inputs.

    The largest argument is factored out before exponentiating so that the
    intermediate ``exp`` calls cannot overflow.

    Args:
        *args: Values in the log domain.

    Returns:
        The log of the summed exponentials; ``-inf`` when no arguments are
        given or every argument is ``-inf``.
    """
    neg_inf = float("-inf")
    # Every term is log(0) (or there are no terms): the sum is zero.
    if not any(value != neg_inf for value in args):
        return neg_inf
    peak = max(args)
    shifted_sum = sum(math.exp(value - peak) for value in args)
    return peak + math.log(shifted_sum)
|
|
|
|
def tokenize_by_bpe_model(sp, text: str, upper: bool = True) -> List[str]:
    """Tokenize text that may mix CJK characters with other script.

    Characters in the range U+4E00..U+9FFF each become a token of their own;
    every other non-blank run of text is handed to ``sp.encode_as_pieces``.

    Args:
        sp: Object exposing ``encode_as_pieces(str)`` that returns an
            iterable of pieces (presumably a SentencePiece processor).
        text: Text to tokenize.
        upper: When true, ``text`` is upper-cased before splitting.

    Returns:
        The tokens in their order of appearance.
    """
    cjk_char = re.compile(r"([\u4e00-\u9fff])")
    if upper:
        text = text.upper()

    pieces: List[str] = []
    # The capturing group keeps each matched character in the split output.
    for segment in cjk_char.split(text):
        if not segment.strip():
            # Empty strings between adjacent matches and whitespace-only runs.
            continue
        if cjk_char.fullmatch(segment):
            pieces.append(segment)
            continue
        pieces.extend(sp.encode_as_pieces(segment))
    return pieces
|
|
|
|
def tokenize(
    contexts: List[str],
    symbol_table: Dict[str, int],
    bpe_model: Optional[str] = None,
) -> List[List[int]]:
    """Convert context phrases into lists of symbol ids.

    When ``bpe_model`` is given and ``sentencepiece`` can be imported, each
    phrase is split with ``tokenize_by_bpe_model``. Otherwise each phrase is
    split into single characters with spaces replaced by ``"▁"``.

    Args:
        contexts: Context phrases; each one is stripped before tokenizing.
        symbol_table: Mapping from token string to id. If it contains
            ``"<unk>"``, unknown tokens map to that id; otherwise unknown
            tokens are dropped.
        bpe_model: Path passed to ``SentencePieceProcessor.load``, or
            ``None`` for character-level tokenization.

    Returns:
        One id list per phrase that produced at least one id. Phrases that
        produce no ids are omitted, so the result may be shorter than
        ``contexts``.
    """
    sp = None
    if bpe_model is not None:
        try:
            import sentencepiece as spm
        except ImportError:
            # Keep the best-effort fallback, but make it visible: character
            # tokens usually do not match a BPE symbol table.
            import warnings

            warnings.warn(
                "sentencepiece is not installed; ignoring "
                f"bpe_model={bpe_model!r} and falling back to "
                "character-level tokenization",
                stacklevel=2,
            )
        else:
            sp = spm.SentencePieceProcessor()
            sp.load(bpe_model)

    unk = symbol_table.get("<unk>")
    context_ids = []
    for context in contexts:
        context = context.strip()
        if sp is not None:
            pieces = tokenize_by_bpe_model(sp, context)
        else:
            pieces = list(context.replace(" ", "▁"))

        labels = []
        for piece in pieces:
            if piece in symbol_table:
                labels.append(symbol_table[piece])
            elif unk is not None:
                labels.append(unk)
        # A phrase with no usable ids would add nothing to the graph.
        if labels:
            context_ids.append(labels)
    return context_ids
|
|
|
|
@dataclass
class ContextState:
    """One node of the context-biasing trie.

    The link fields (``fail``, ``output``, ``next``) are left out of the
    generated ``__eq__`` and ``__repr__``. The root's ``fail`` link points at
    the root itself, so comparing through the links recurses without end when
    two distinct roots are compared, and printing through them dumps the
    whole graph. Equality therefore covers the scalar fields only.
    """

    # Node number; the graph builder assigns them from a running counter.
    idx: int
    # Token id on the arc entering this node; the root keeps the default -1.
    token: int = -1
    # Bonus carried by the arc entering this node.
    token_score: float = 0.0
    # Sum of the arc bonuses on the path from the root to this node.
    node_score: float = 0.0
    # Bonus granted on arrival at this node for phrases that end here.
    output_score: float = 0.0
    # True when a context phrase ends at this node.
    is_end: bool = False
    # Node to fall back to when the next token has no arc out of this node.
    fail: Optional["ContextState"] = field(default=None, repr=False, compare=False)
    # Nearest phrase-end node reachable through fail links, if any.
    output: Optional["ContextState"] = field(
        default=None, repr=False, compare=False
    )
    # Outgoing arcs: token id -> child node.
    next: Dict[int, "ContextState"] = field(
        default_factory=dict, repr=False, compare=False
    )
|
|
|
|
class ContextGraph:
    """Trie of context phrases with failure and output links.

    Each arc carries a bonus of ``context_score``. A walk through the graph
    earns the bonus of every matched arc, gives the accumulated bonus back
    when the match is abandoned through a failure link, and earns an extra
    ``output_score`` on reaching a node where a phrase ends.
    """

    def __init__(
        self,
        contexts: List[str],
        symbol_table: Dict[str, int],
        bpe_model: Optional[str] = None,
        context_score: float = 6.0,
    ):
        """Tokenize ``contexts`` and build the graph.

        Args:
            contexts: Context phrases to bias towards.
            symbol_table: Mapping from token string to id.
            bpe_model: Optional model path forwarded to ``tokenize``.
            context_score: Bonus attached to every arc.
        """
        self.context_score = context_score
        self.num_nodes = 0
        self.root = ContextState(self.num_nodes)
        # The root falls back to itself so failure walks always terminate.
        self.root.fail = self.root
        self.build(tokenize(contexts, symbol_table, bpe_model))

    def build(self, token_ids: List[List[int]]):
        """Insert token sequences into the trie and refresh all links.

        Args:
            token_ids: One list of token ids per phrase; empty lists are
                ignored.
        """
        for tokens in token_ids:
            node = self.root
            last = len(tokens) - 1
            for idx, token in enumerate(tokens):
                child = node.next.get(token)
                if child is None:
                    self.num_nodes += 1
                    child = ContextState(
                        idx=self.num_nodes,
                        token=token,
                        token_score=self.context_score,
                        node_score=node.node_score + self.context_score,
                    )
                    node.next[token] = child
                if idx == last:
                    # Mark the end even when the node already exists, so a
                    # phrase that is a prefix of an earlier phrase still
                    # counts as a complete match.
                    child.is_end = True
                node = child
        self.fill_fail_output()

    def fill_fail_output(self):
        """Compute ``fail``, ``output`` and ``output_score`` for every node.

        Nodes are visited breadth-first. Failure links always lead to a
        shallower node, so every node consulted here has been finalized
        earlier in the same pass. ``output_score`` is recomputed from
        scratch, which keeps repeated calls from accumulating bonuses.
        """
        queue = deque()
        for node in self.root.next.values():
            node.fail = self.root
            node.output = None
            node.output_score = node.node_score if node.is_end else 0.0
            queue.append(node)

        while queue:
            current = queue.popleft()
            for token, node in current.next.items():
                # Longest proper suffix of this path that is also a path
                # from the root.
                fail = current.fail
                while fail is not self.root and token not in fail.next:
                    fail = fail.fail
                node.fail = fail.next[token] if token in fail.next else self.root

                # Nearest phrase end along the failure chain, if any.
                output = node.fail
                while output is not self.root and not output.is_end:
                    output = output.fail
                node.output = output if output.is_end else None

                own = node.node_score if node.is_end else 0.0
                inherited = (
                    node.output.output_score if node.output is not None else 0.0
                )
                node.output_score = own + inherited
                queue.append(node)

    def forward_one_step(
        self, state: "ContextState", token: int
    ) -> Tuple[float, "ContextState"]:
        """Advance from ``state`` by one token.

        Args:
            state: Current node.
            token: Token id being consumed.

        Returns:
            ``(score, node)``, where ``node`` is the state after the token.
            ``score`` is the change in accumulated arc bonus
            (``node.node_score - state.node_score``) plus
            ``node.output_score``.
        """
        node = state
        while node is not self.root and token not in node.next:
            node = node.fail
        node = node.next[token] if token in node.next else self.root
        score = node.node_score - state.node_score + node.output_score
        return score, node

    def finalize(self, state: "ContextState") -> Tuple[float, "ContextState"]:
        """Cancel the arc bonus accumulated on the way to ``state``.

        Returns:
            ``(-state.node_score, root)``.
        """
        return -state.node_score, self.root
|
|
|
|
@dataclass
class PrefixScore:
    """Scores and bookkeeping kept for one prefix during beam search."""

    # Log score of paths for this prefix that end in blank.
    s: float = float("-inf")
    # Log score of paths for this prefix that end in a non-blank token.
    ns: float = float("-inf")
    # Best single-path counterparts of `s` and `ns`.
    v_s: float = float("-inf")
    v_ns: float = float("-inf")
    # Position in the context graph and the bias collected so far.
    context_state: Optional["ContextState"] = None
    context_score: float = 0.0
    # Log-probability recorded for the most recently emitted token.
    cur_token_prob: float = float("-inf")
    # Frame indices of the tokens on the best blank / non-blank ending path.
    times_s: List[int] = field(default_factory=list)
    times_ns: List[int] = field(default_factory=list)
    # Per-token log-probabilities, filled when the caller asks for them.
    token_probs: List[float] = field(default_factory=list)
    # Set once context information has been assigned to this entry.
    has_context: bool = False

    def score(self):
        """Return the combined log score of both path endings."""
        return log_add(self.s, self.ns)

    def viterbi_score(self):
        """Return the larger of ``v_s`` and ``v_ns`` (``v_ns`` on a tie)."""
        return max(self.v_ns, self.v_s)

    def times(self):
        """Return the time list that belongs to the better best-path score."""
        if self.v_s > self.v_ns:
            return self.times_s
        return self.times_ns

    def total_score(self):
        """Return ``score()`` plus the context bias."""
        return self.context_score + self.score()
|
|
|
|
class CTCDecoder:
    """Chunk-by-chunk CTC prefix beam search with optional context biasing.

    The decoder keeps its beam between calls to ``ctc_prefix_beam_search``,
    so consecutive chunks of one utterance can be fed in turn. Passing
    ``is_last=True`` finalizes the utterance and resets the decoder.
    """

    def __init__(
        self,
        contexts: Optional[List[str]] = None,
        symbol_table: Optional[Dict[str, int]] = None,
        bpe_model: Optional[str] = None,
        context_score: float = 6.0,
        blank_id: int = 0,
    ):
        """Create a decoder.

        Args:
            contexts: Context phrases to bias towards; ``None`` disables
                biasing.
            symbol_table: Mapping from token string to id, used to tokenize
                ``contexts``.
            bpe_model: Optional model path forwarded to ``ContextGraph``.
            context_score: Per-token bonus forwarded to ``ContextGraph``.
            blank_id: Id of the CTC blank symbol.
        """
        self.context_graph = None
        if contexts is not None:
            self.context_graph = ContextGraph(
                contexts, symbol_table, bpe_model, context_score
            )
        self.blank_id = blank_id
        self.reset()

    def reset(self):
        """Restart decoding with a single empty-prefix hypothesis."""
        context_root = (
            self.context_graph.root if self.context_graph is not None else None
        )
        self.cur_t = 0
        self.cur_hyps = [
            (tuple(), PrefixScore(s=0.0, v_s=0.0, context_state=context_root))
        ]

    def copy_context(self, prefix_score: "PrefixScore", next_score: "PrefixScore"):
        """Carry context state over unchanged (the prefix did not grow).

        Does nothing when biasing is disabled or ``next_score`` already has
        context information from an earlier contribution in this frame.
        """
        if self.context_graph is not None and not next_score.has_context:
            next_score.context_score = prefix_score.context_score
            next_score.context_state = prefix_score.context_state
            next_score.has_context = True

    def update_context(
        self, prefix_score: "PrefixScore", next_score: "PrefixScore", token: int
    ):
        """Advance the context state by ``token`` (the prefix grew by one).

        Does nothing when biasing is disabled or ``next_score`` already has
        context information from an earlier contribution in this frame.
        """
        if self.context_graph is not None and not next_score.has_context:
            score, state = self.context_graph.forward_one_step(
                prefix_score.context_state, token
            )
            next_score.context_score = prefix_score.context_score + score
            next_score.context_state = state
            next_score.has_context = True

    def backoff_context(self):
        """Finalize the context state of every current hypothesis in place.

        Adds the score returned by ``ContextGraph.finalize`` to each
        hypothesis and moves its state to the one returned alongside it.
        """
        if self.context_graph is None:
            return
        for _, score in self.cur_hyps:
            backoff_score, state = self.context_graph.finalize(score.context_state)
            score.context_score += backoff_score
            score.context_state = state

    @staticmethod
    def topk(logp: np.ndarray, beam_size: int):
        """Return the ``beam_size`` largest entries of a 1-D array.

        Args:
            logp: One-dimensional score array.
            beam_size: Number of entries wanted; if it is at least
                ``len(logp)`` the whole array is returned.

        Returns:
            ``(values, indices)``, both sorted by descending value.
        """
        if beam_size >= logp.shape[0]:
            indices = np.argsort(logp)[::-1]
        else:
            # Partial selection first, then order only the selected entries.
            candidates = np.argpartition(logp, -beam_size)[-beam_size:]
            indices = candidates[np.argsort(logp[candidates])[::-1]]
        return logp[indices], indices

    def ctc_prefix_beam_search(
        self,
        ctc_probs: np.ndarray,
        beam_size: int,
        is_last: bool = False,
        return_probs: bool = False,
    ):
        """Consume a chunk of frames and return the current n-best list.

        Args:
            ctc_probs: Per-frame scores, iterated one 1-D row at a time.
                Rows are treated as log-probabilities (they are added
                together and exponentiated for the ``"probs"`` output).
            beam_size: Number of tokens expanded per frame and number of
                hypotheses kept after each frame. Must be positive.
            is_last: When true, context scores are finalized, the
                hypotheses are re-ranked, and the decoder is reset.
            return_probs: When true, per-token probabilities are tracked
                and returned. They cover only tokens emitted while the flag
                was set.

        Returns:
            A dict with ``"tokens"`` (one id list per hypothesis, best
            first) and ``"times"`` (frame index per token, counted from 1
            since the last reset). ``"probs"`` is added when
            ``return_probs`` is true.

        Raises:
            ValueError: If ``beam_size`` is not positive.
        """
        if beam_size <= 0:
            raise ValueError(f"beam_size must be positive, got {beam_size}")

        for logp in ctc_probs:
            self.cur_t += 1
            next_hyps = defaultdict(PrefixScore)
            top_probs, top_indices = self.topk(logp, beam_size)
            for prob, token in zip(top_probs.tolist(), top_indices.tolist()):
                for prefix, prefix_score in self.cur_hyps:
                    last = prefix[-1] if prefix else None
                    if token == self.blank_id:
                        # Blank: the prefix is unchanged and now ends in blank.
                        next_score = next_hyps[prefix]
                        next_score.s = log_add(
                            next_score.s, prefix_score.score() + prob
                        )
                        next_score.v_s = prefix_score.viterbi_score() + prob
                        next_score.times_s = prefix_score.times().copy()
                        if return_probs:
                            next_score.token_probs = prefix_score.token_probs.copy()
                        self.copy_context(prefix_score, next_score)
                    elif token == last:
                        # Repeat without a blank in between: same prefix.
                        next_score = next_hyps[prefix]
                        next_score.ns = log_add(
                            next_score.ns, prefix_score.ns + prob
                        )
                        if next_score.v_ns < prefix_score.v_ns + prob:
                            next_score.v_ns = prefix_score.v_ns + prob
                            if next_score.cur_token_prob < prob:
                                next_score.cur_token_prob = prob
                                next_score.times_ns = prefix_score.times_ns.copy()
                                next_score.times_ns[-1] = self.cur_t
                        if return_probs:
                            next_score.token_probs = prefix_score.token_probs.copy()
                            # Empty when earlier chunks ran without
                            # return_probs; there is nothing to update then.
                            if next_score.token_probs:
                                next_score.token_probs[-1] = max(
                                    next_score.token_probs[-1], prob
                                )
                        self.copy_context(prefix_score, next_score)

                        # Repeat after a blank: the prefix grows by one token.
                        new_prefix = prefix + (token,)
                        next_score = next_hyps[new_prefix]
                        next_score.ns = log_add(
                            next_score.ns, prefix_score.s + prob
                        )
                        if next_score.v_ns < prefix_score.v_s + prob:
                            next_score.v_ns = prefix_score.v_s + prob
                            next_score.cur_token_prob = prob
                            next_score.times_ns = prefix_score.times_s.copy()
                            next_score.times_ns.append(self.cur_t)
                        if return_probs:
                            next_score.token_probs = prefix_score.token_probs.copy()
                            next_score.token_probs.append(prob)
                        self.update_context(prefix_score, next_score, token)
                    else:
                        # A different token: the prefix grows by one token.
                        new_prefix = prefix + (token,)
                        next_score = next_hyps[new_prefix]
                        next_score.ns = log_add(
                            next_score.ns, prefix_score.score() + prob
                        )
                        if next_score.v_ns < prefix_score.viterbi_score() + prob:
                            next_score.v_ns = prefix_score.viterbi_score() + prob
                            next_score.cur_token_prob = prob
                            next_score.times_ns = prefix_score.times().copy()
                            next_score.times_ns.append(self.cur_t)
                        if return_probs:
                            next_score.token_probs = prefix_score.token_probs.copy()
                            next_score.token_probs.append(prob)
                        self.update_context(prefix_score, next_score, token)

            self.cur_hyps = sorted(
                next_hyps.items(),
                key=lambda item: item[1].total_score(),
                reverse=True,
            )[:beam_size]

        if is_last:
            self.backoff_context()
            # Backing off changes total scores, so rank again before
            # reporting. The sort is stable: without a context graph the
            # order is unchanged.
            cur_hyps = sorted(
                self.cur_hyps,
                key=lambda item: item[1].total_score(),
                reverse=True,
            )
            self.reset()
        else:
            cur_hyps = self.cur_hyps

        response = {
            "tokens": [list(prefix) for prefix, _ in cur_hyps],
            "times": [score.times() for _, score in cur_hyps],
        }
        if return_probs:
            response["probs"] = [
                [math.exp(prob) for prob in score.token_probs]
                for _, score in cur_hyps
            ]
        return response
|
|