SenseVoice / python /ctc_decoder.py
inoryQwQ's picture
Simplify SenseVoiceAx decoder dependencies
5c1c337
Raw
History Blame Contribute Delete
12.1 kB
import math
import re
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import numpy as np
def log_add(*args: float) -> float:
    """Return ``log(sum(exp(a) for a in args))`` in a numerically stable way.

    The largest argument is factored out before exponentiating. If every
    argument is ``-inf`` (or no argument is given) the result is ``-inf``.
    """
    neg_inf = -float("inf")
    if not any(value != neg_inf for value in args):
        return neg_inf
    peak = max(args)
    total = sum(math.exp(value - peak) for value in args)
    return peak + math.log(total)
def tokenize_by_bpe_model(sp, text: str, upper: bool = True) -> List[str]:
    """Split ``text`` into tokens: CJK ideographs alone, everything else via BPE.

    Characters in U+4E00..U+9FFF (CJK Unified Ideographs) become single
    tokens; the text between them is passed to ``sp.encode_as_pieces``.

    Args:
        sp: Object providing ``encode_as_pieces`` (e.g. a SentencePiece
            processor).
        text: Input text.
        upper: If true, upper-case ``text`` before tokenizing.

    Returns:
        Tokens in text order; whitespace-only segments are skipped.
    """
    cjk_char = re.compile(r"([\u4e00-\u9fff])")
    source = text.upper() if upper else text
    pieces: List[str] = []
    for segment in cjk_char.split(source):
        if not segment.strip():
            continue
        if cjk_char.fullmatch(segment):
            pieces.append(segment)
        else:
            pieces.extend(sp.encode_as_pieces(segment))
    return pieces
def tokenize(
    contexts: List[str],
    symbol_table: Dict[str, int],
    bpe_model: Optional[str] = None,
) -> List[List[int]]:
    """Convert context phrases into lists of token ids.

    When ``bpe_model`` is given and ``sentencepiece`` can be imported, each
    phrase is split with ``tokenize_by_bpe_model``. Otherwise every character
    is a token, with spaces mapped to ``"▁"``. If ``bpe_model`` is given but
    ``sentencepiece`` is not installed, a ``RuntimeWarning`` is emitted and
    the character fallback is used.

    Pieces missing from ``symbol_table`` map to its ``"<unk>"`` id when the
    table has one and are dropped otherwise. Phrases that end up with no ids
    are omitted from the result.

    Args:
        contexts: Phrases to tokenize; surrounding whitespace is stripped.
        symbol_table: Token-string to id mapping.
        bpe_model: Optional path of a SentencePiece model.

    Returns:
        One list of token ids per surviving phrase, in input order.
    """
    sp = None
    if bpe_model is not None:
        try:
            import sentencepiece as spm
        except ImportError:
            # Keep the best-effort fallback, but make it visible: silently
            # switching to per-character tokens changes which hotwords match.
            import warnings

            warnings.warn(
                "sentencepiece is not installed; tokenizing contexts per "
                f"character instead of with BPE model {bpe_model!r}",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            sp = spm.SentencePieceProcessor()
            sp.load(bpe_model)
    unk = symbol_table.get("<unk>")
    context_ids: List[List[int]] = []
    for context in contexts:
        context = context.strip()
        if sp is not None:
            pieces = tokenize_by_bpe_model(sp, context)
        else:
            pieces = list(context.replace(" ", "▁"))
        labels = []
        for piece in pieces:
            if piece in symbol_table:
                labels.append(symbol_table[piece])
            elif unk is not None:
                labels.append(unk)
        if labels:
            context_ids.append(labels)
    return context_ids
@dataclass
class ContextState:
    """One node of the context-phrase trie built by ``ContextGraph``.

    ``next`` holds the trie edges; ``fail`` and ``output`` are the
    Aho-Corasick failure and output links filled in after the trie is built.
    """

    # Node id; ContextGraph gives the root 0 and numbers new nodes from 1.
    idx: int
    # Token id on the edge leading into this node (-1 for the root).
    token: int = -1
    # Bonus for the token on the edge into this node.
    token_score: float = 0.0
    # Bonus accumulated along the path from the root to this node.
    node_score: float = 0.0
    # node_score if a phrase ends here (else 0), plus the output_score of
    # the ``output`` node once failure links are filled.
    output_score: float = 0.0
    # True if some context phrase ends at this node.
    is_end: bool = False
    # Failure link: where matching continues when no ``next`` edge fits.
    fail: Optional["ContextState"] = None
    # Nearest phrase-ending node along the failure chain, or None.
    output: Optional["ContextState"] = None
    # Outgoing trie edges keyed by token id.
    next: Dict[int, "ContextState"] = field(default_factory=dict)
class ContextGraph:
    """Aho-Corasick automaton over tokenized context phrases, used for biasing.

    Each token that extends a partial phrase match earns ``context_score``.
    Abandoning a partial match takes the accumulated bonus back, while
    reaching the end of a phrase adds its ``output_score``, which is kept.
    """

    def __init__(
        self,
        contexts: List[str],
        symbol_table: Dict[str, int],
        bpe_model: Optional[str] = None,
        context_score: float = 6.0,
    ):
        """Tokenize ``contexts`` and build the automaton.

        Args:
            contexts: Phrases to bias towards.
            symbol_table: Token-string to id mapping used by ``tokenize``.
            bpe_model: Optional SentencePiece model path used by ``tokenize``.
            context_score: Bonus per matched token.
        """
        self.context_score = context_score
        # Node ids are handed out sequentially; the root is node 0.
        self.num_nodes = 0
        self.root = ContextState(self.num_nodes)
        # The root fails to itself so failure-link walks always stop there.
        self.root.fail = self.root
        self.build(tokenize(contexts, symbol_table, bpe_model))

    def build(self, token_ids: List[List[int]]):
        """Insert each token-id sequence into the trie, then fill the links.

        A phrase that is a prefix of an already inserted phrase (e.g. ``AB``
        after ``ABC``) ends on an existing node. That node is marked as a
        phrase end too, so the result does not depend on insertion order.
        """
        for tokens in token_ids:
            node = self.root
            last_idx = len(tokens) - 1
            for idx, token in enumerate(tokens):
                is_end = idx == last_idx
                child = node.next.get(token)
                if child is None:
                    self.num_nodes += 1
                    node_score = node.node_score + self.context_score
                    child = ContextState(
                        idx=self.num_nodes,
                        token=token,
                        token_score=self.context_score,
                        node_score=node_score,
                        output_score=node_score if is_end else 0.0,
                        is_end=is_end,
                    )
                    node.next[token] = child
                elif is_end and not child.is_end:
                    # The node was created by a longer phrase. Without this,
                    # the shorter phrase would never earn its completion bonus.
                    child.is_end = True
                    child.output_score = child.node_score
                node = child
        self.fill_fail_output()

    def fill_fail_output(self):
        """Compute failure and output links breadth-first.

        Each node's ``output_score`` is increased by that of its output node,
        so entering a node also credits phrases that end there as suffixes.
        """
        queue = deque()
        # Depth-1 nodes always fail back to the root.
        for node in self.root.next.values():
            node.fail = self.root
            queue.append(node)
        while queue:
            current = queue.popleft()
            for token, node in current.next.items():
                # Walk the parent's failure chain until a state can be
                # extended by ``token``, or the root is reached.
                fail = current.fail
                while fail is not self.root and token not in fail.next:
                    fail = fail.fail
                node.fail = fail.next[token] if token in fail.next else self.root
                # Nearest phrase-ending node on the failure chain, if any.
                output = node.fail
                while output is not self.root and not output.is_end:
                    output = output.fail
                node.output = output if output.is_end else None
                if node.output is not None:
                    # The output node is shallower, so BFS already finalized it.
                    node.output_score += node.output.output_score
                queue.append(node)

    def forward_one_step(
        self, state: "ContextState", token: int
    ) -> Tuple[float, "ContextState"]:
        """Advance from ``state`` by ``token``.

        Returns:
            ``(score, next_state)``. ``score`` is the change in accumulated
            bonus: the node-score difference between the new and old states
            (negative when a partial match is abandoned), plus the new
            state's ``output_score``.
        """
        node = state
        while node is not self.root and token not in node.next:
            node = node.fail
        node = node.next[token] if token in node.next else self.root
        score = node.node_score - state.node_score + node.output_score
        return score, node

    def finalize(self, state: "ContextState") -> Tuple[float, "ContextState"]:
        """Cancel the bonus of an unfinished match and return to the root."""
        return -state.node_score, self.root
@dataclass
class PrefixScore:
    """Scores and bookkeeping for one prefix in CTC prefix beam search.

    ``s``/``ns`` accumulate (via ``log_add``) the log-probability of the
    prefix's alignments ending in blank / non-blank. ``v_s``/``v_ns`` track
    the best single-alignment (Viterbi) scores of the same two kinds.
    """

    # Log-probability mass of alignments ending in blank.
    s: float = float("-inf")
    # Log-probability mass of alignments ending in a non-blank token.
    ns: float = float("-inf")
    # Best single-alignment score ending in blank.
    v_s: float = float("-inf")
    # Best single-alignment score ending in a non-blank token.
    v_ns: float = float("-inf")
    # Context-graph state reached by this prefix (None when biasing is off).
    context_state: Optional["ContextState"] = None
    # Accumulated context-biasing bonus.
    context_score: float = 0.0
    # Log-prob associated with the prefix's last token; set by the decoder.
    cur_token_prob: float = float("-inf")
    # Token times along the best blank-ending path.
    times_s: List[int] = field(default_factory=list)
    # Token times along the best non-blank-ending path.
    times_ns: List[int] = field(default_factory=list)
    # Per-token log-probs, maintained only when the decoder is asked for them.
    token_probs: List[float] = field(default_factory=list)
    # Set once the context fields are filled, so later merges keep them.
    has_context: bool = False

    def score(self):
        """Return the total log-probability of the prefix (both endings)."""
        return log_add(self.s, self.ns)

    def viterbi_score(self):
        """Return the best single-alignment score; non-blank wins ties."""
        return max(self.v_ns, self.v_s)

    def times(self):
        """Return token times of the better Viterbi path (non-blank on ties)."""
        if self.v_s > self.v_ns:
            return self.times_s
        return self.times_ns

    def total_score(self):
        """Return the acoustic prefix score plus the context-biasing bonus."""
        acoustic = self.score()
        return acoustic + self.context_score
class CTCDecoder:
    """CTC prefix beam search with optional context (hotword) biasing.

    Frames may be decoded incrementally over several calls to
    ``ctc_prefix_beam_search``. The beam and the frame counter persist
    between calls and are reset after a call made with ``is_last=True``.
    """

    def __init__(
        self,
        contexts: Optional[List[str]] = None,
        symbol_table: Optional[Dict[str, int]] = None,
        bpe_model: Optional[str] = None,
        context_score: float = 6.0,
        blank_id: int = 0,
    ):
        """Create a decoder.

        Args:
            contexts: Phrases to bias decoding towards; ``None`` disables
                biasing.
            symbol_table: Token-string to id mapping used to tokenize
                ``contexts``. Required when ``contexts`` is given.
            bpe_model: Optional SentencePiece model path used to tokenize
                ``contexts``.
            context_score: Per-token bonus handed to ``ContextGraph``.
            blank_id: Id of the CTC blank token.

        Raises:
            ValueError: If ``contexts`` is given without ``symbol_table``.
        """
        self.context_graph = None
        if contexts is not None:
            if symbol_table is None:
                # Fail here with a clear message instead of an AttributeError
                # from inside context tokenization.
                raise ValueError("symbol_table is required when contexts are given")
            self.context_graph = ContextGraph(
                contexts, symbol_table, bpe_model, context_score
            )
        self.blank_id = blank_id
        self.reset()

    def reset(self):
        """Clear the beam and the frame counter to start a new utterance."""
        context_root = (
            self.context_graph.root if self.context_graph is not None else None
        )
        self.cur_t = 0
        # The empty prefix starts with log-probability 0, stored as blank-ending.
        self.cur_hyps = [
            (tuple(), PrefixScore(s=0.0, v_s=0.0, context_state=context_root))
        ]

    def copy_context(self, prefix_score: "PrefixScore", next_score: "PrefixScore"):
        """Give ``next_score`` the context of ``prefix_score`` unchanged.

        Used when the prefix does not grow. Only the first contributor to a
        ``next_score`` sets its context (guarded by ``has_context``).
        """
        if self.context_graph is not None and not next_score.has_context:
            next_score.context_score = prefix_score.context_score
            next_score.context_state = prefix_score.context_state
            next_score.has_context = True

    def update_context(
        self, prefix_score: "PrefixScore", next_score: "PrefixScore", token: int
    ):
        """Advance the context of ``prefix_score`` by ``token`` into ``next_score``.

        Used when the prefix grows by ``token``. As in ``copy_context``, only
        the first contributor to a ``next_score`` sets its context.
        """
        if self.context_graph is not None and not next_score.has_context:
            score, state = self.context_graph.forward_one_step(
                prefix_score.context_state, token
            )
            next_score.context_score = prefix_score.context_score + score
            next_score.context_state = state
            next_score.has_context = True

    def backoff_context(self):
        """Finalize the context of every hypothesis in the beam.

        Adds ``ContextGraph.finalize``'s score, which cancels the bonus of an
        unfinished phrase match, and moves each hypothesis back to the root.
        """
        if self.context_graph is None:
            return
        for _, score in self.cur_hyps:
            backoff_score, state = self.context_graph.finalize(score.context_state)
            score.context_score += backoff_score
            score.context_state = state

    @staticmethod
    def topk(logp: np.ndarray, beam_size: int):
        """Return the ``beam_size`` largest entries of ``logp`` and their indices.

        Both arrays are ordered largest first. If ``beam_size`` is at least
        the length of ``logp``, all entries are returned.
        """
        if beam_size >= logp.shape[0]:
            indices = np.argsort(logp)[::-1]
        else:
            candidates = np.argpartition(logp, -beam_size)[-beam_size:]
            indices = candidates[np.argsort(logp[candidates])[::-1]]
        return logp[indices], indices

    def ctc_prefix_beam_search(
        self,
        ctc_probs: np.ndarray,
        beam_size: int,
        is_last: bool = False,
        return_probs: bool = False,
    ):
        """Run prefix beam search over a chunk of frames.

        Args:
            ctc_probs: Per-frame log-probability vectors, presumably a
                ``(num_frames, vocab_size)`` array.
            beam_size: Number of candidate tokens per frame and of
                hypotheses kept in the beam.
            is_last: If true, finalize context biasing, re-rank the beam,
                and reset the decoder after building the result.
            return_probs: If true, also return per-token probabilities.

        Returns:
            A dict describing each hypothesis in the beam, best first:
            ``"tokens"`` (token ids), ``"times"`` (the frame recorded for each
            token, counted from 1 since the last reset) and, with
            ``return_probs``, ``"probs"`` (``exp`` of the stored per-token
            log-probs).

        Raises:
            ValueError: If ``beam_size`` is smaller than 1. A beam size of
                zero would leave the beam empty and corrupt the decoder state.
        """
        if beam_size < 1:
            raise ValueError(f"beam_size must be at least 1, got {beam_size}")
        for logp in ctc_probs:
            self.cur_t += 1
            next_hyps = defaultdict(PrefixScore)
            top_probs, top_indices = self.topk(logp, beam_size)
            for prob, token in zip(top_probs.tolist(), top_indices.tolist()):
                for prefix, prefix_score in self.cur_hyps:
                    if token == self.blank_id:
                        self._extend_blank(
                            next_hyps, prefix, prefix_score, prob, return_probs
                        )
                    elif prefix and token == prefix[-1]:
                        self._extend_repeat(
                            next_hyps, prefix, prefix_score, token, prob, return_probs
                        )
                    else:
                        self._extend_new(
                            next_hyps, prefix, prefix_score, token, prob, return_probs
                        )
            self.cur_hyps = sorted(
                next_hyps.items(), key=lambda item: item[1].total_score(), reverse=True
            )[:beam_size]
        cur_hyps = self.cur_hyps
        if is_last:
            self.backoff_context()
            if self.context_graph is not None:
                # Backoff lowers hypotheses that ended mid-phrase, so the
                # pre-backoff order is stale: re-rank before returning.
                cur_hyps = sorted(
                    cur_hyps, key=lambda item: item[1].total_score(), reverse=True
                )
            self.reset()
        response = {
            "tokens": [list(prefix) for prefix, _ in cur_hyps],
            "times": [score.times() for _, score in cur_hyps],
        }
        if return_probs:
            response["probs"] = [
                [math.exp(prob) for prob in score.token_probs] for _, score in cur_hyps
            ]
        return response

    def _extend_blank(
        self,
        next_hyps: Dict[Tuple[int, ...], "PrefixScore"],
        prefix: Tuple[int, ...],
        prefix_score: "PrefixScore",
        prob: float,
        return_probs: bool,
    ):
        """Handle a blank frame: the prefix is unchanged and now ends in blank."""
        next_score = next_hyps[prefix]
        next_score.s = log_add(next_score.s, prefix_score.score() + prob)
        next_score.v_s = prefix_score.viterbi_score() + prob
        next_score.times_s = prefix_score.times().copy()
        if return_probs:
            next_score.token_probs = prefix_score.token_probs.copy()
        self.copy_context(prefix_score, next_score)

    def _extend_repeat(
        self,
        next_hyps: Dict[Tuple[int, ...], "PrefixScore"],
        prefix: Tuple[int, ...],
        prefix_score: "PrefixScore",
        token: int,
        prob: float,
        return_probs: bool,
    ):
        """Handle a token equal to the prefix's last token (two CTC cases)."""
        # Case *a + a -> *a: the repeat merges into the last token.
        next_score = next_hyps[prefix]
        next_score.ns = log_add(next_score.ns, prefix_score.ns + prob)
        if next_score.v_ns < prefix_score.v_ns + prob:
            next_score.v_ns = prefix_score.v_ns + prob
            if next_score.cur_token_prob < prob:
                next_score.cur_token_prob = prob
                next_score.times_ns = prefix_score.times_ns.copy()
                next_score.times_ns[-1] = self.cur_t
                if return_probs:
                    next_score.token_probs = prefix_score.token_probs.copy()
                    next_score.token_probs[-1] = max(
                        next_score.token_probs[-1], prob
                    )
        self.copy_context(prefix_score, next_score)
        # Case *a- + a -> *aa: a blank separated them, so a new token starts.
        new_prefix = prefix + (token,)
        next_score = next_hyps[new_prefix]
        next_score.ns = log_add(next_score.ns, prefix_score.s + prob)
        if next_score.v_ns < prefix_score.v_s + prob:
            next_score.v_ns = prefix_score.v_s + prob
            next_score.cur_token_prob = prob
            next_score.times_ns = prefix_score.times_s.copy()
            next_score.times_ns.append(self.cur_t)
            if return_probs:
                next_score.token_probs = prefix_score.token_probs.copy()
                next_score.token_probs.append(prob)
        self.update_context(prefix_score, next_score, token)

    def _extend_new(
        self,
        next_hyps: Dict[Tuple[int, ...], "PrefixScore"],
        prefix: Tuple[int, ...],
        prefix_score: "PrefixScore",
        token: int,
        prob: float,
        return_probs: bool,
    ):
        """Handle a token that differs from the prefix's last token."""
        new_prefix = prefix + (token,)
        next_score = next_hyps[new_prefix]
        # Alignments ending in blank and in non-blank can both emit the token.
        next_score.ns = log_add(next_score.ns, prefix_score.score() + prob)
        best = prefix_score.viterbi_score() + prob
        if next_score.v_ns < best:
            next_score.v_ns = best
            next_score.cur_token_prob = prob
            next_score.times_ns = prefix_score.times().copy()
            next_score.times_ns.append(self.cur_t)
            if return_probs:
                next_score.token_probs = prefix_score.token_probs.copy()
                next_score.token_probs.append(prob)
        self.update_context(prefix_score, next_score, token)