|
|
import axengine as axe |
|
|
import numpy as np |
|
|
import librosa |
|
|
from frontend import WavFrontend |
|
|
import os |
|
|
import time |
|
|
from typing import List, Union, Optional |
|
|
from asr_decoder import CTCDecoder |
|
|
from tokenizer import SentencepiecesTokenizer |
|
|
from online_fbank import OnlineFbank |
|
|
import torch |
|
|
|
|
|
|
|
|
def sequence_mask(lengths, maxlen=None, dtype=np.float32):
    """Build a padding mask of shape (1, len(lengths), width) from lengths.

    Position ``j`` of row ``i`` is 1 when ``j < lengths[i]`` and 0 otherwise.
    When ``maxlen`` is smaller than ``lengths[0]``, the mask is right-padded
    with zeros up to ``lengths[0]`` so its width always covers the first
    sequence's length (the model expects a fixed-width mask even for a
    truncated window).

    Args:
        lengths: 1-D integer array of sequence lengths.
        maxlen: Mask width before padding; defaults to ``max(lengths)``.
        dtype: Output dtype (the model consumes float32 masks).

    Returns:
        np.ndarray of shape (1, len(lengths), max(maxlen, lengths[0])).
    """
    if maxlen is None:
        maxlen = np.max(lengths)

    # Broadcast positions 0..maxlen-1 against the per-sequence lengths.
    row_vector = np.arange(0, maxlen, 1)
    matrix = np.expand_dims(lengths, axis=-1)
    mask = row_vector < matrix

    # NOTE(review): the padding below only looks at lengths[0], i.e. it
    # assumes an effective batch of one — which is how this module calls it.
    if mask.shape[-1] < lengths[0]:
        # Fix: honor the requested dtype instead of hard-coding float32
        # (identical result for the default dtype after the final astype).
        pad = np.zeros((mask.shape[0], lengths[0] - mask.shape[-1]), dtype=dtype)
        mask = np.concatenate([mask, pad], axis=-1)

    # Prepend the leading batch axis expected by the model input.
    return mask.astype(dtype)[None, ...]
|
|
|
|
|
|
|
|
def unique_consecutive_np(arr):
    """Collapse runs of consecutive duplicates, like torch.unique_consecutive.

    Args:
        arr: 1-D numpy array.

    Returns:
        Array containing the first element of each run of equal values,
        in order. The input dtype is preserved (fix: the previous version
        returned a float64 array for empty input).
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        return np.array([], dtype=arr.dtype)
    if arr.size == 1:
        return arr.copy()

    # An element starts a new run iff it differs from its predecessor;
    # the first element always starts a run.
    keep = np.concatenate(([True], np.diff(arr) != 0))
    return arr[keep]
|
|
|
|
|
|
|
|
class SenseVoiceAx:
    """SenseVoice axmodel runner.

    Wraps an axengine InferenceSession together with the fbank frontend,
    sentencepiece tokenizer and CTC decoder needed to turn raw audio into
    text, in either one-shot (`infer`) or streaming (`stream_infer`) mode.
    """

    def __init__(
        self,
        model_path: str,
        max_len: int = 256,
        beam_size: int = 3,
        language: str = "auto",
        hot_words: Optional[List[str]] = None,
        use_itn: bool = True,
        streaming: bool = False,
        providers=None,
    ):
        """
        Initialize SenseVoiceAx

        Args:
            model_path: Path of axmodel
            max_len: Fixed shape of input of axmodel
            beam_size: Max number of hypos to hold after each decode step
            language: Support auto, zh(Chinese), en(English), yue(Cantonese), ja(Japanese), ko(Korean)
            hot_words: Words that may fail to recognize,
                special words/phrases (aka hotwords) like rare words, personalized information etc.
            use_itn: Allow Invert Text Normalization if True,
                ITN converts ASR model output into its written form to improve text readability,
                For example, the ITN module replaces “one hundred and twenty-three dollars” transcribed by an ASR model with “$123.”
            streaming: Processes audio in small segments or "chunks" sequentially and outputs text on the fly.
                Use stream_infer method if streaming is true otherwise infer.
            providers: Execution providers for axengine; defaults to
                ["AxEngineExecutionProvider"]. (Default moved out of the
                signature to avoid a mutable default argument.)
        """
        if providers is None:
            providers = ["AxEngineExecutionProvider"]

        # Companion assets are expected one directory above the model file.
        model_path_root = os.path.dirname(model_path)
        emb_path = os.path.join(model_path_root, "../embeddings.npy")
        cmvn_file = os.path.join(model_path_root, "../am.mvn")
        bpe_model = os.path.join(
            model_path_root, "../chn_jpn_yue_eng_ko_spectok.bpe.model"
        )
        # Streaming and non-streaming graphs were exported with different
        # position-encoding tables.
        if streaming:
            self.position_encoding = np.load(
                os.path.join(model_path_root, "../pe_streaming.npy")
            )
        else:
            self.position_encoding = np.load(
                os.path.join(model_path_root, "../pe_nonstream.npy")
            )

        self.streaming = streaming
        self.tokenizer = SentencepiecesTokenizer(bpemodel=bpe_model)

        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            fs=16000,
            window="hamming",
            n_mels=80,
            frame_length=25,
            frame_shift=10,
            lfr_m=7,
            lfr_n=6,
        )
        self.model = axe.InferenceSession(model_path, providers=providers)
        self.sample_rate = 16000
        self.blank_id = 0  # CTC blank token id
        self.max_len = max_len  # fixed input length of the compiled model
        self.padding = 16  # frames of left-context overlap between chunks
        self.input_size = 560  # LFR feature dim (80 mel bins * lfr_m=7)

        # Special-token tables used to build/interpret the query embeddings.
        self.lid_dict = {
            "auto": 0,
            "zh": 3,
            "en": 4,
            "yue": 7,
            "ja": 11,
            "ko": 12,
            "nospeech": 13,
        }
        self.lid_int_dict = {
            24884: 3,
            24885: 4,
            24888: 7,
            24892: 11,
            24896: 12,
            24992: 13,
        }
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}
        self.emo_dict = {
            "unk": 25009,
            "happy": 25001,
            "sad": 25002,
            "angry": 25003,
            "neutral": 25004,
        }

        self.load_embeddings(emb_path, language, use_itn)
        self.language = language

        # Hotword-biased beam search only when both a beam width and hot
        # words are supplied; otherwise fall back to greedy decoding.
        if beam_size > 1 and hot_words is not None:
            self.beam_size = beam_size
            symbol_table = {}
            for i in range(self.tokenizer.get_vocab_size()):
                symbol_table[self.tokenizer.decode(i)] = i
            self.decoder = CTCDecoder(hot_words, symbol_table, bpe_model)
        else:
            self.beam_size = 1
            self.decoder = CTCDecoder()

        if streaming:
            # Rolling feature cache; the newest frame is appended at the
            # bottom, and `padding` old frames provide left context.
            self.cur_idx = -1
            self.chunk_size = max_len - self.padding
            self.caches_shape = (max_len, self.input_size)
            self.caches = np.zeros(self.caches_shape, dtype=np.float32)
            self.zeros = np.zeros((1, self.input_size), dtype=np.float32)
            self.neg_mean, self.inv_stddev = (
                self.frontend.cmvn[0, :],
                self.frontend.cmvn[1, :],
            )

            self.fbank = OnlineFbank(window_type="hamming")
            # The streaming cache is always exactly max_len frames, so the
            # mask never changes and can be precomputed once.
            self.masks = sequence_mask(
                np.array([self.max_len], dtype=np.int32),
                maxlen=self.max_len,
                dtype=np.float32,
            )

    @property
    def language_options(self):
        """Supported language keys accepted by `infer`/`stream_infer`."""
        return list(self.lid_dict.keys())

    @property
    def textnorm_options(self):
        """Supported text-normalization keys."""
        return list(self.textnorm_dict.keys())

    def load_embeddings(self, emb_path, language, use_itn):
        """Load the query-embedding dict and assemble the input prefix query."""
        self.embeddings = np.load(emb_path, allow_pickle=True).item()
        self.language_query = self.embeddings[language]
        self.textnorm_query = (
            self.embeddings["withitn"] if use_itn else self.embeddings["woitn"]
        )
        self.event_emo_query = self.embeddings["event_emo"]
        # Prefix query prepended to every feature chunk fed to the model.
        self.input_query = np.concatenate(
            (self.textnorm_query, self.language_query, self.event_emo_query), axis=1
        )
        self.query_num = self.input_query.shape[1]

    def choose_language(self, language):
        """Switch the language query and rebuild the input prefix query."""
        self.language_query = self.embeddings[language]
        self.input_query = np.concatenate(
            (self.textnorm_query, self.language_query, self.event_emo_query), axis=1
        )
        self.language = language

    def load_data(self, filepath: str) -> np.ndarray:
        """Load an audio file resampled to the model rate, as a 1-D array."""
        waveform, _ = librosa.load(filepath, sr=self.sample_rate)
        return waveform.flatten()

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        """Zero-pad each (T, D) feature to max_feat_len and stack as float32."""

        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, "constant", constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def preprocess(self, waveform):
        """Compute LFR+CMVN fbank features for one waveform.

        Returns:
            (feats, feats_len): batched float32 features and int32 lengths.
        """
        feats, feats_len = [], []
        for wf in [waveform]:
            speech, _ = self.frontend.fbank(wf)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    def postprocess(self, ctc_logits, encoder_out_lens):
        """Greedy CTC decode: argmax, collapse repeats, strip blanks.

        The first 4 output frames are skipped — presumably they correspond
        to the prepended query tokens; confirm against the model export.
        """
        x = ctc_logits[0, 4 : encoder_out_lens[0], :]
        yseq = np.argmax(x, axis=-1)
        yseq = unique_consecutive_np(yseq)
        mask = yseq != self.blank_id
        token_int = yseq[mask].tolist()
        return token_int

    def infer_waveform(self, waveform: np.ndarray, language="auto"):
        """Run non-streaming recognition on a raw waveform.

        Long inputs are split into windows of (max_len - query_num) frames,
        each prefixed with the query embeddings and zero-padded to the
        model's fixed input length.
        """
        if language != self.language:
            self.choose_language(language)

        feat, feat_len = self.preprocess(waveform)

        slice_len = self.max_len - self.query_num
        slice_num = int(np.ceil(feat.shape[1] / slice_len))

        asr_res = []
        for i in range(slice_num):
            if i == 0:
                sub_feat = feat[:, i * slice_len : (i + 1) * slice_len, :]
            else:
                # Later slices re-read `padding` frames for left context.
                sub_feat = feat[
                    :,
                    i * slice_len - self.padding : (i + 1) * slice_len - self.padding,
                    :,
                ]

            # Prepend the textnorm/language/event-emo query embeddings.
            sub_feat = np.concatenate([self.input_query, sub_feat], axis=1)
            real_len = sub_feat.shape[1]
            if real_len < self.max_len:
                # Zero-pad the tail up to the model's fixed input length.
                sub_feat = np.concatenate(
                    [
                        sub_feat,
                        np.zeros(
                            (1, self.max_len - real_len, sub_feat.shape[-1]),
                            dtype=np.float32,
                        ),
                    ],
                    axis=1,
                )

            # Mask: ones for the real_len valid frames, zeros for padding.
            masks = sequence_mask(
                np.array([self.max_len], dtype=np.int32),
                maxlen=real_len,
                dtype=np.float32,
            )

            outputs = self.model.run(
                None,
                {
                    "speech": sub_feat,
                    "masks": masks,
                    "position_encoding": self.position_encoding,
                },
            )
            ctc_logits, encoder_out_lens = outputs

            token_int = self.postprocess(ctc_logits, encoder_out_lens)

            if self.tokenizer is not None:
                asr_res.append(self.tokenizer.tokens2text(token_int))
            else:
                asr_res.append(token_int)

        return asr_res

    def infer(
        self, filepath_or_data: Union[np.ndarray, str], language="auto", print_rtf=False
    ):
        """One-shot recognition of a file path or waveform; returns the text.

        Args:
            filepath_or_data: Audio file path or a raw waveform array.
            language: Language key, see `language_options`.
            print_rtf: If True, print the real-time factor and latency.
        """
        assert not self.streaming, "This method is for non-streaming model"

        if isinstance(filepath_or_data, str):
            waveform = self.load_data(filepath_or_data)
        else:
            waveform = filepath_or_data

        total_time = waveform.shape[-1] / self.sample_rate

        start = time.time()
        asr_res = self.infer_waveform(waveform, language)
        latency = time.time() - start

        if print_rtf:
            rtf = latency / total_time
            print(f"RTF: {rtf} Latency: {latency}s Total length: {total_time}s")
        return "".join(asr_res)

    def decode(self, times, tokens):
        """Convert decoder steps to millisecond timestamps; decode tokens.

        Steps are scaled by 60 ms, matching the 10 ms frame shift times the
        LFR factor lfr_n=6 configured on the frontend.
        """
        times_ms = []
        for step, token in zip(times, tokens):
            # Skip tokens that decode to pure whitespace (no visible output).
            if len(self.tokenizer.decode(token).strip()) == 0:
                continue
            times_ms.append(step * 60)
        return times_ms, self.tokenizer.decode(tokens)

    def reset(self):
        """Reset all streaming state so a new utterance can start."""
        self.cur_idx = -1
        self.decoder.reset()
        self.fbank = OnlineFbank(window_type="hamming")
        # Fix: keep the cache float32, matching __init__ (was float64 here).
        self.caches = np.zeros(self.caches_shape, dtype=np.float32)

    def get_size(self):
        """Number of not-yet-decoded frames in the current chunk.

        Returns 0 until the left-context padding has been filled, otherwise
        a value in (0, chunk_size]; exactly chunk_size means a full chunk.
        """
        effective_size = self.cur_idx + 1 - self.padding
        if effective_size <= 0:
            return 0
        return effective_size % self.chunk_size or self.chunk_size

    def stream_infer(self, audio, is_last, language="auto"):
        """Streaming recognition; yields {"timestamps", "text"} per chunk.

        Args:
            audio: Raw audio samples for this call.
            is_last: True when this is the final piece of the utterance.
            language: Language key, see `language_options`.
        """
        assert self.streaming, "This method is for streaming model"

        if language != self.language:
            self.choose_language(language)

        self.fbank.accept_waveform(audio, is_last)
        features = self.fbank.get_lfr_frames(
            neg_mean=self.neg_mean, inv_stddev=self.inv_stddev
        )

        # Final flush with no pending frames: push one zero frame so the
        # decoder still gets a pass with the last-frame flag set.
        if is_last and len(features) == 0:
            features = self.zeros

        for idx, feature in enumerate(features):
            # Fix: use a per-frame flag instead of overwriting `is_last` —
            # the old code cleared it on the first iteration, so the true
            # final frame could never be marked last for the decoder.
            frame_is_last = is_last and idx == features.shape[0] - 1
            self.caches = np.roll(self.caches, -1, axis=0)
            self.caches[-1, :] = feature
            self.cur_idx += 1
            cur_size = self.get_size()
            if cur_size != self.chunk_size and not frame_is_last:
                continue

            speech = self.caches[None, ...]
            outputs = self.model.run(
                None,
                {
                    "speech": speech,
                    "masks": self.masks,
                    "position_encoding": self.position_encoding,
                },
            )
            ctc_logits, encoder_out_lens = outputs
            # Drop the 4 leading query outputs before CTC decoding
            # (same offset as postprocess; confirm against model export).
            probs = ctc_logits[0, 4 : encoder_out_lens[0]]
            probs = torch.from_numpy(probs)

            if cur_size != self.chunk_size:
                # Partial final chunk: keep only the newly arrived frames.
                probs = probs[self.chunk_size - cur_size :]
            if not frame_is_last:
                probs = probs[: self.chunk_size]
            if self.beam_size > 1:
                res = self.decoder.ctc_prefix_beam_search(
                    probs, beam_size=self.beam_size, is_last=frame_is_last
                )
                times_ms, text = self.decode(res["times"][0], res["tokens"][0])
            else:
                res = self.decoder.ctc_greedy_search(probs, is_last=frame_is_last)
                times_ms, text = self.decode(res["times"], res["tokens"])
            yield {"timestamps": times_ms, "text": text}
|
|
|