import multiprocessing
import os
import tarfile as tf

import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import yaml

try:
    from swig_decoders import (map_batch, ctc_beam_search_decoder_batch,
                               TrieVector, PathTrie)
except ModuleNotFoundError:
    # Pure-Python fallbacks for the compiled ``swig_decoders`` extension.
    # They accept the call signatures used in this file; arguments that only
    # matter to the native implementation are accepted and ignored.

    class PathTrie:
        """Stand-in for the native PathTrie; the fallback decoder ignores it."""

    class TrieVector(list):
        """Stand-in for the native TrieVector; behaves as a plain list."""

    def _log_add(*values):
        """Return ``log(sum(exp(v) for v in values))`` computed stably.

        ``-inf`` entries are dropped; if nothing remains the result is
        ``-inf``.
        """
        values = [value for value in values if value != -float("inf")]
        if not values:
            return -float("inf")
        max_value = max(values)
        # Shift by the maximum before exponentiating to avoid overflow.
        return max_value + np.log(
            sum(np.exp(value - max_value) for value in values))

    def _map_sentence(sent, vocabulary, greedy=False, blank_id=0):
        """Convert one sequence of token ids into a string.

        Args:
            sent: iterable of token ids.
            vocabulary: list mapping token id -> token string.
            greedy: if True, consecutive repeated ids are collapsed before
                mapping (CTC greedy-search semantics).
            blank_id: id that is dropped from the output.

        Ids equal to ``blank_id``, ids outside ``vocabulary`` and tokens
        spelled ``<...>`` are skipped; the remaining pieces are concatenated
        without a separator.
        """
        mapped = []
        prev = None
        for token in sent:
            token = int(token)
            if greedy and token == prev:
                prev = token
                continue
            prev = token
            if token == blank_id or token < 0 or token >= len(vocabulary):
                continue
            piece = vocabulary[token]
            if piece.startswith("<") and piece.endswith(">"):
                continue
            mapped.append(piece)
        return "".join(mapped)

    def map_batch(batch_sents, vocabulary, num_processes, greedy=False,
                  blank_id=0):
        """Map a batch of id sequences to strings (serial fallback).

        ``num_processes`` is accepted for signature compatibility with the
        native implementation and ignored.
        """
        del num_processes
        return [_map_sentence(sent, vocabulary, greedy, blank_id)
                for sent in batch_sents]

    def _ctc_prefix_beam_search(log_probs_seq, log_probs_idx, beam_size,
                                blank_id):
        """Run CTC prefix beam search over one utterance.

        Args:
            log_probs_seq: per-frame sequences of candidate log-probabilities.
            log_probs_idx: per-frame sequences of the token ids matching
                ``log_probs_seq`` element for element.
            beam_size: number of prefixes kept after every frame.
            blank_id: id of the CTC blank token.

        Returns:
            List of ``(log_score, prefix_tuple)`` sorted best first, with at
            most ``beam_size`` entries.
        """
        neg_inf = -float("inf")
        # prefix -> (log P(prefix, ending in blank),
        #            log P(prefix, ending in non-blank))
        beam = {(): (0.0, neg_inf)}
        for frame_probs, frame_ids in zip(log_probs_seq, log_probs_idx):
            next_beam = {}
            for prefix, (prob_blank, prob_non_blank) in beam.items():
                for prob, token in zip(frame_probs, frame_ids):
                    token = int(token)
                    prob = float(prob)
                    next_prob_blank, next_prob_non_blank = next_beam.get(
                        prefix, (neg_inf, neg_inf))
                    if token == blank_id:
                        # A blank leaves the prefix unchanged.
                        next_beam[prefix] = (
                            _log_add(next_prob_blank, prob_blank + prob,
                                     prob_non_blank + prob),
                            next_prob_non_blank,
                        )
                        continue
                    last = prefix[-1] if prefix else None
                    if token == last:
                        # A repeat directly after the same token collapses
                        # into the existing prefix ...
                        next_beam[prefix] = (
                            next_prob_blank,
                            _log_add(next_prob_non_blank,
                                     prob_non_blank + prob),
                        )
                        # ... while a repeat after a blank extends it.
                        new_prefix = prefix + (token, )
                        nb_blank, nb_non_blank = next_beam.get(
                            new_prefix, (neg_inf, neg_inf))
                        next_beam[new_prefix] = (
                            nb_blank,
                            _log_add(nb_non_blank, prob_blank + prob),
                        )
                    else:
                        new_prefix = prefix + (token, )
                        nb_blank, nb_non_blank = next_beam.get(
                            new_prefix, (neg_inf, neg_inf))
                        next_beam[new_prefix] = (
                            nb_blank,
                            _log_add(nb_non_blank, prob_blank + prob,
                                     prob_non_blank + prob),
                        )
            # Prune to the beam_size most probable prefixes.
            beam = dict(sorted(
                next_beam.items(),
                key=lambda item: _log_add(item[1][0], item[1][1]),
                reverse=True)[:beam_size])
        return [(_log_add(prob_blank, prob_non_blank), prefix)
                for prefix, (prob_blank, prob_non_blank) in sorted(
                    beam.items(),
                    key=lambda item: _log_add(item[1][0], item[1][1]),
                    reverse=True)]

    def ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                      batch_log_probs_idx,
                                      batch_root_trie,
                                      batch_start,
                                      beam_size,
                                      num_processes,
                                      blank_id=0,
                                      space_id=-1,
                                      cutoff_prob=0.999,
                                      ext_scorer=None):
        """Batch CTC prefix beam search (serial fallback).

        Only ``beam_size`` and ``blank_id`` affect the result; the tries,
        start flags, process count, space id, probability cutoff and external
        scorer are accepted for signature compatibility and ignored.
        """
        del batch_root_trie, batch_start, num_processes, space_id
        del cutoff_prob, ext_scorer
        return [
            _ctc_prefix_beam_search(log_probs_seq, log_probs_idx, beam_size,
                                    blank_id)
            for log_probs_seq, log_probs_idx in zip(batch_log_probs_seq,
                                                    batch_log_probs_idx)
        ]


def load_config(config_path):
    """Load a YAML configuration file and return the parsed object."""
    # NOTE(review): FullLoader is kept for compatibility with existing
    # configs; prefer yaml.safe_load if configs can come from untrusted
    # sources.
    with open(config_path, "r", encoding="utf-8") as fin:
        return yaml.load(fin, Loader=yaml.FullLoader)


def load_vocab(vocab_path):
    """Read a whitespace-separated ``<token> <id>`` vocabulary file.

    Returns:
        ``(vocabulary, char_dict)``: ``vocabulary`` lists tokens in file
        order (so list index equals id only when the file is sorted by id and
        ids are contiguous from 0); ``char_dict`` maps ``int(id) -> token``.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields.
    """
    vocabulary = []
    char_dict = {}
    with open(vocab_path, "r", encoding="utf-8") as fin:
        for line_no, line in enumerate(fin, start=1):
            arr = line.strip().split()
            if not arr:
                # Tolerate blank lines such as a trailing empty line.
                continue
            if len(arr) != 2:
                raise ValueError(
                    f"{vocab_path}:{line_no}: expected '<token> <id>', "
                    f"got {line.rstrip()!r}")
            char_dict[int(arr[1])] = arr[0]
            vocabulary.append(arr[0])
    return vocabulary, char_dict


def compute_feats(audio_file: str, sr=16000) -> np.ndarray:
    """Compute 80-bin Kaldi-compatible fbank features for an audio file.

    The waveform is brought to int16 sample scale regardless of which audio
    backend loads it (soundfile if importable, torchaudio otherwise),
    resampled to *sr* when needed, and framed with 25 ms windows and a
    10 ms shift.

    Returns:
        Array of shape ``(1, num_frames, 80)``.
    """
    try:
        import soundfile as sf
    except ModuleNotFoundError:
        sf = None

    if sf is not None:
        waveform, sample_rate = sf.read(audio_file, dtype="int16",
                                        always_2d=True)
        # soundfile yields (frames, channels); fbank expects (channels, frames).
        waveform = torch.from_numpy(waveform.T).to(torch.float)
    else:
        waveform, sample_rate = torchaudio.load(audio_file, normalize=True)
        # torchaudio normalizes to [-1, 1]; rescale to the int16 range so both
        # backends produce identical features.
        waveform = waveform.to(torch.float) * (1 << 15)

    if sample_rate != sr:
        waveform = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=sr)(waveform)
    feats = kaldi.fbank(waveform,
                        num_mel_bins=80,
                        frame_length=25,
                        frame_shift=10,
                        energy_floor=0.0,
                        sample_frequency=sr)
    return feats.unsqueeze(0).numpy()


def pad_array_along_axis(array, pad_width, axis, mode="constant", **kwargs):
    """Pad the end of *axis* so that it has length *pad_width*.

    Despite its name, *pad_width* is the target length, not the amount of
    padding. Arrays that are already at least that long are returned
    unchanged (never truncated). Extra keyword arguments go to ``np.pad``.
    """
    if array.shape[axis] >= pad_width:
        return array
    full_pad_width = [(0, 0)] * array.ndim
    full_pad_width[axis] = (0, pad_width - array.shape[axis])
    return np.pad(array, pad_width=full_pad_width, mode=mode, **kwargs)


def ctc_decoding(beam_log_probs,
                 beam_log_probs_idx,
                 encoder_out_lens,
                 vocabulary,
                 mode="ctc_prefix_beam_search"):
    """Decode the encoder's top-k CTC outputs.

    Args:
        beam_log_probs: ``(batch, frames, k)`` top-k log-probabilities.
        beam_log_probs_idx: ``(batch, frames, k)`` token ids matching
            ``beam_log_probs``.
        encoder_out_lens: number of valid frames per batch item.
        vocabulary: list mapping token id -> token string.
        mode: ``"ctc_greedy_search"``, ``"ctc_prefix_beam_search"`` or
            ``"attention_rescoring"``.

    Returns:
        ``(hyps, score_hyps)``. ``hyps`` holds one string per batch item for
        the two search modes and is empty for ``"attention_rescoring"``.
        ``score_hyps`` holds the per-item ``(score, prefix)`` candidate lists
        from beam search and is empty for greedy search.

    Raises:
        ValueError: if *mode* is not one of the modes above.
    """
    if mode not in ("ctc_greedy_search", "ctc_prefix_beam_search",
                    "attention_rescoring"):
        raise ValueError(f"Unsupported CTC decoding mode: {mode!r}")

    beam_size = beam_log_probs.shape[-1]
    batch_size = beam_log_probs.shape[0]
    # Never ask the decoders for zero workers.
    num_processes = max(1, min(multiprocessing.cpu_count(), batch_size))
    hyps = []
    score_hyps = []

    if mode == "ctc_greedy_search":
        # Column 0 of the top-k ids is used as the per-frame best token
        # (assumes the encoder emits its top-k in descending order).
        log_probs_idx = beam_log_probs_idx[:, :, 0]
        batch_sents = [seq[0:encoder_out_lens[idx]].tolist()
                       for idx, seq in enumerate(log_probs_idx)]
        hyps = map_batch(batch_sents, vocabulary, num_processes, True, 0)
        return hyps, score_hyps

    batch_log_probs_seq_list = beam_log_probs.tolist()
    batch_log_probs_idx_list = beam_log_probs_idx.tolist()
    batch_log_probs_seq = []
    batch_log_probs_ids = []
    batch_start = []
    batch_root = TrieVector()
    # Presumably keeps every trie referenced from Python for the duration of
    # the native call.
    root_dict = {}
    for i, num_frames in enumerate(encoder_out_lens.tolist()):
        batch_log_probs_seq.append(batch_log_probs_seq_list[i][0:num_frames])
        batch_log_probs_ids.append(batch_log_probs_idx_list[i][0:num_frames])
        root_dict[i] = PathTrie()
        batch_root.append(root_dict[i])
        batch_start.append(True)

    # Positional args: blank_id=0, space_id=-2, cutoff_prob=0.99999.
    score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                               batch_log_probs_ids,
                                               batch_root, batch_start,
                                               beam_size, num_processes, 0,
                                               -2, 0.99999)
    if mode == "ctc_prefix_beam_search":
        # Keep only the best-scoring prefix of each utterance.
        best_prefixes = [cand_hyps[0][1] for cand_hyps in score_hyps]
        hyps = map_batch(best_prefixes, vocabulary, num_processes, False, 0)
    return hyps, score_hyps


def make_decoder_inputs(encoder_out, encoder_out_lens, beam_log_probs,
                        beam_log_probs_idx, vocabulary, sos, eos,
                        decoder_len):
    """Build the attention-rescoring decoder's inputs from CTC candidates.

    Every batch item gets exactly ``beam_size`` candidates (short beams are
    padded with a ``-inf``-scored dummy). Each candidate is truncated to
    ``decoder_len - 2`` tokens, wrapped in ``sos``/``eos`` both forward and
    reversed, and right-padded with ``-1``.

    Returns:
        ``(inputs, all_hyps)``: ``inputs`` is the decoder feed dict;
        ``all_hyps`` is the flat ``batch * beam_size`` list of candidate id
        lists, in the same order as the decoder's beam axis.

    Raises:
        ValueError: if ``decoder_len < 2`` (no room for ``sos`` and ``eos``).
    """
    if decoder_len < 2:
        raise ValueError(
            f"decoder_len must be at least 2 to hold <sos> and <eos>, "
            f"got {decoder_len}")

    _, score_hyps = ctc_decoding(beam_log_probs, beam_log_probs_idx,
                                 encoder_out_lens, vocabulary,
                                 "attention_rescoring")
    ignore_id = -1
    beam_size = beam_log_probs.shape[-1]
    batch_size = beam_log_probs.shape[0]

    ctc_score, all_hyps = [], []
    for hyps in score_hyps:
        # Copy rather than extend in place: works whether the decoder returns
        # lists or tuples, and leaves its result untouched.
        hyps = list(hyps)
        hyps += (beam_size - len(hyps)) * [(-float("inf"), (0, ))]
        ctc_score.append([hyp[0] for hyp in hyps])
        all_hyps.extend(list(hyp[1]) for hyp in hyps)
    ctc_score = np.array(ctc_score, dtype=np.float32)

    max_len = decoder_len - 2  # leave room for <sos> and <eos>
    hyps_pad_sos_eos = np.full((batch_size, beam_size, decoder_len),
                               ignore_id, dtype=np.int32)
    r_hyps_pad_sos_eos = np.full((batch_size, beam_size, decoder_len),
                                 ignore_id, dtype=np.int32)
    hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32)
    for i in range(batch_size):
        for j in range(beam_size):
            cand = all_hyps[i * beam_size + j][:max_len]
            length = len(cand) + 2
            hyps_pad_sos_eos[i][j][0:length] = [sos] + cand + [eos]
            r_hyps_pad_sos_eos[i][j][0:length] = [sos] + cand[::-1] + [eos]
            hyps_lens_sos[i][j] = len(cand) + 1

    # encoder_out's time axis is zero-padded or truncated to decoder_len.
    # NOTE(review): the same decoder_len bounds both the hypothesis length
    # and encoder_out's time axis, so encoder frames beyond decoder_len are
    # invisible to the rescoring decoder -- confirm this matches the export.
    time_len = encoder_out.shape[1]
    if decoder_len > time_len:
        encoder_out = np.pad(encoder_out,
                             [(0, 0), (0, decoder_len - time_len), (0, 0)],
                             mode="constant",
                             constant_values=0)
    elif decoder_len < time_len:
        encoder_out = encoder_out[:, :decoder_len, :]
    # Report the real number of encoder frames (capped at the fixed width)
    # so padded positions are not treated as valid memory.
    memory_lens = np.minimum(
        np.asarray(encoder_out_lens).reshape(-1), decoder_len).astype(np.int32)

    return {
        "encoder_out": encoder_out,
        "encoder_out_lens": memory_lens,
        "hyps_pad_sos_eos": hyps_pad_sos_eos,
        "hyps_lens_sos": hyps_lens_sos,
        "r_hyps_pad_sos_eos": r_hyps_pad_sos_eos,
        "ctc_score": ctc_score,
    }, all_hyps


def make_offline_inputs(feats, seq_len):
    """Truncate or zero-pad *feats* along axis 1 to exactly *seq_len* frames.

    Returns:
        ``{"speech": feats, "speech_lengths": lengths}`` where ``lengths``
        holds, per batch item, the number of real (unpadded) frames.
    """
    feats = feats[:, :seq_len, :]
    speech_lengths = np.full(feats.shape[0], feats.shape[1], dtype=np.int32)
    if feats.shape[1] < seq_len:
        feats = pad_array_along_axis(feats, pad_width=seq_len, axis=1)
    return {"speech": feats, "speech_lengths": speech_lengths}


def make_online_initial_state(configs,
                              batch_size=1,
                              decoding_chunk_size=16,
                              num_decoding_left_chunks=5):
    """Create zeroed streaming-encoder caches and chunking parameters.

    Cache shapes are derived from ``configs["encoder_conf"]``.

    Returns:
        ``(state, params)``: ``state`` holds ``att_cache``, ``cnn_cache``,
        ``cache_mask`` and ``offset``; ``params`` holds ``batch_size``,
        ``context``, ``stride`` (input frames advanced per chunk) and
        ``decoding_window`` (input frames fed per chunk).
    """
    # Hard-coded frontend constants; presumably a 4x subsampling layer whose
    # receptive field needs 7 input frames -- TODO confirm against the model.
    subsampling = 4
    context = 7
    stride = subsampling * decoding_chunk_size
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    required_cache_size = decoding_chunk_size * num_decoding_left_chunks

    encoder_conf = configs["encoder_conf"]
    output_size = encoder_conf["output_size"]
    num_layers = encoder_conf["num_blocks"]
    # Models without a conv module fall back to a zero-width cnn cache.
    cnn_module_kernel = encoder_conf.get("cnn_module_kernel", 1) - 1
    head = encoder_conf["attention_heads"]
    d_k = encoder_conf["output_size"] // head

    state = {
        "att_cache":
            np.zeros((batch_size, num_layers, head, required_cache_size,
                      d_k * 2),
                     dtype=np.float32),
        "cnn_cache":
            np.zeros((batch_size, num_layers, output_size, cnn_module_kernel),
                     dtype=np.float32),
        "cache_mask":
            np.zeros((batch_size, 1, required_cache_size), dtype=np.float32),
        "offset":
            np.zeros((batch_size, 1), dtype=np.int32),
    }
    params = {
        "batch_size": batch_size,
        "context": context,
        "stride": stride,
        "decoding_window": decoding_window,
    }
    return state, params


def make_online_encoder_input(feats, cur, params, state):
    """Build streaming-encoder inputs for the chunk starting at frame *cur*.

    Takes up to ``decoding_window`` frames from *feats*, zero-pads the tail
    when fewer remain, and bundles them with the current caches from
    *state*. ``chunk_lens`` always reports the padded window length.
    """
    batch_size = params["batch_size"]
    decoding_window = params["decoding_window"]
    end = min(cur + decoding_window, feats.shape[1])
    chunk_xs = feats[:, cur:end, :]
    if chunk_xs.shape[1] < decoding_window:
        chunk_xs = pad_array_along_axis(chunk_xs,
                                        pad_width=decoding_window,
                                        axis=1)
    chunk_xs = chunk_xs.astype(np.float32)
    chunk_lens = np.full(batch_size,
                         fill_value=chunk_xs.shape[1],
                         dtype=np.int32)
    return {
        "chunk_xs": chunk_xs,
        "chunk_lens": chunk_lens,
        "offset": state["offset"],
        "att_cache": state["att_cache"],
        "cnn_cache": state["cnn_cache"],
        "cache_mask": state["cache_mask"],
    }


def update_online_state(state, outputs):
    """Store the streaming encoder's returned caches (outputs 4-7) in *state*."""
    state["offset"] = outputs[4]
    state["att_cache"] = outputs[5]
    state["cnn_cache"] = outputs[6]
    state["cache_mask"] = outputs[7]


def save_calibration_inputs(calib_data_path, inputs, sample_id):
    """Save each input array to ``<calib_data_path>/<name>/<sample_id>.npy``."""
    for input_name, data in inputs.items():
        data_path = os.path.join(calib_data_path, input_name)
        os.makedirs(data_path, exist_ok=True)
        np.save(os.path.join(data_path, f"{sample_id}.npy"), data)
def pack_calibration_dataset(calib_data_path):
    """Archive every per-input calibration directory as ``<name>.tar.gz``.

    Each sub-directory of *calib_data_path* (as written by
    ``save_calibration_inputs``) is packed into a sibling gzip tarball whose
    top-level entry is the directory itself. Plain files, including
    previously written archives, are skipped.
    """
    for input_name in sorted(os.listdir(calib_data_path)):
        data_path = os.path.join(calib_data_path, input_name)
        if not os.path.isdir(data_path):
            continue
        tar_path = os.path.join(calib_data_path, input_name + ".tar.gz")
        with tf.open(tar_path, "w:gz") as tf_file:
            tf_file.add(data_path, arcname=input_name)


class WenetONNXRunner:
    """Run exported ONNX encoder/decoder models on audio files.

    ONNX Runtime sessions are created lazily on first use from ``onnx_dir``
    (``encoder_offline.onnx``, ``encoder_online.onnx``, ``decoder.onnx``).
    The run methods can optionally dump their model inputs as ``.npy`` files
    under ``calib_data_path`` to build a calibration dataset.
    """

    def __init__(self,
                 config_path,
                 vocab_path,
                 onnx_dir="onnx_model",
                 offline_seq_len=1024,
                 decoder_len=32,
                 decoding_chunk_size=16,
                 num_decoding_left_chunks=5,
                 batch_size=1,
                 providers=None):
        self.config_path = config_path
        self.vocab_path = vocab_path
        self.onnx_dir = onnx_dir
        self.offline_seq_len = offline_seq_len
        self.decoder_len = decoder_len
        self.decoding_chunk_size = decoding_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.batch_size = batch_size
        self.providers = providers or ["CPUExecutionProvider"]
        self.configs = load_config(config_path)
        self.vocabulary, self.char_dict = load_vocab(vocab_path)
        # <sos> and <eos> share one id: len(char_dict) - 1 (the last id when
        # ids are contiguous from 0).
        self.eos = self.sos = len(self.char_dict) - 1
        # Sessions are created on first access through the properties below.
        self._offline_encoder = None
        self._online_encoder = None
        self._decoder = None

    @property
    def offline_encoder(self):
        """Lazily created session for ``encoder_offline.onnx``."""
        if self._offline_encoder is None:
            self._offline_encoder = self._new_session("encoder_offline.onnx")
        return self._offline_encoder

    @property
    def online_encoder(self):
        """Lazily created session for ``encoder_online.onnx``."""
        if self._online_encoder is None:
            self._online_encoder = self._new_session("encoder_online.onnx")
        return self._online_encoder

    @property
    def decoder(self):
        """Lazily created session for ``decoder.onnx``."""
        if self._decoder is None:
            self._decoder = self._new_session("decoder.onnx")
        return self._decoder

    def _new_session(self, filename):
        """Open ``onnx_dir/filename`` with the configured providers."""
        import onnxruntime as ort
        return ort.InferenceSession(os.path.join(self.onnx_dir, filename),
                                    providers=self.providers)

    def compute_feats(self, audio_file):
        """Return fbank features for *audio_file* (module ``compute_feats``)."""
        return compute_feats(audio_file)

    def run_offline_encoder(self, feats, calib_data_path=None, sample_id=0):
        """Run the non-streaming encoder on *feats* padded to offline_seq_len.

        Returns:
            dict with ``encoder_out``, ``encoder_out_lens``,
            ``ctc_log_probs``, ``beam_log_probs`` and ``beam_log_probs_idx``
            (the encoder's five outputs, in that order).
        """
        encoder_input = make_offline_inputs(feats, self.offline_seq_len)
        if calib_data_path:
            save_calibration_inputs(calib_data_path, encoder_input, sample_id)
        (encoder_out, encoder_out_lens, ctc_log_probs, beam_log_probs,
         beam_log_probs_idx) = self.offline_encoder.run(None, encoder_input)
        return {
            "encoder_out": encoder_out,
            "encoder_out_lens": encoder_out_lens,
            "ctc_log_probs": ctc_log_probs,
            "beam_log_probs": beam_log_probs,
            "beam_log_probs_idx": beam_log_probs_idx,
        }

    def run_online_encoder(self, feats, calib_data_path=None,
                           sample_prefix=0):
        """Run the streaming encoder chunk by chunk over *feats*.

        Calibration inputs (if requested) are saved per chunk with id
        ``"{sample_prefix}_{start_frame}"``.

        Returns:
            dict with ``encoder_out``, ``beam_log_probs`` and
            ``beam_log_probs_idx`` concatenated along the time axis,
            ``encoder_out_lens`` (total output frames, one entry per batch
            item) and ``num_chunks``.
        """
        state, online_params = make_online_initial_state(
            self.configs, self.batch_size, self.decoding_chunk_size,
            self.num_decoding_left_chunks)
        encoder_out = []
        beam_log_probs = []
        beam_log_probs_idx = []
        num_frames = feats.shape[1]
        # Always process at least one (zero-padded) chunk: for inputs shorter
        # than the context the range would otherwise be empty and
        # np.concatenate below would raise on an empty list.
        stop = max(num_frames - online_params["context"] + 1, 1)
        for cur in range(0, stop, online_params["stride"]):
            encoder_input = make_online_encoder_input(feats, cur,
                                                      online_params, state)
            if calib_data_path:
                save_calibration_inputs(calib_data_path, encoder_input,
                                        f"{sample_prefix}_{cur}")
            outputs = self.online_encoder.run(None, encoder_input)
            # Outputs 0-3 are per-chunk results; 4-7 are the updated caches.
            chunk_log_probs, chunk_log_probs_idx, chunk_out, _ = outputs[:4]
            update_online_state(state, outputs)
            encoder_out.append(chunk_out)
            beam_log_probs.append(chunk_log_probs)
            beam_log_probs_idx.append(chunk_log_probs_idx.astype(np.int32))

        total_frames = sum(out.shape[1] for out in encoder_out)
        return {
            "encoder_out": np.concatenate(encoder_out, axis=1),
            "encoder_out_lens": np.full(self.batch_size,
                                        fill_value=total_frames,
                                        dtype=np.int32),
            "beam_log_probs": np.concatenate(beam_log_probs, axis=1),
            "beam_log_probs_idx": np.concatenate(beam_log_probs_idx, axis=1),
            "num_chunks": len(encoder_out),
        }

    def ctc_decode(self, encoder_outputs, mode):
        """Run ``ctc_decoding`` on an encoder result dict."""
        return ctc_decoding(encoder_outputs["beam_log_probs"],
                            encoder_outputs["beam_log_probs_idx"],
                            encoder_outputs["encoder_out_lens"],
                            self.vocabulary, mode)

    def run_decoder(self, encoder_outputs, calib_data_path=None, sample_id=0):
        """Rescore CTC beam candidates with the attention decoder.

        Returns:
            The selected hypotheses of all batch items joined into one
            string.
        """
        decoder_input, all_hyps = make_decoder_inputs(
            encoder_outputs["encoder_out"],
            encoder_outputs["encoder_out_lens"],
            encoder_outputs["beam_log_probs"],
            encoder_outputs["beam_log_probs_idx"],
            self.vocabulary,
            self.sos,
            self.eos,
            self.decoder_len,
        )
        if calib_data_path:
            save_calibration_inputs(calib_data_path, decoder_input, sample_id)
        # The decoder's first output is used as the chosen beam index per
        # batch item.
        best_index = self.decoder.run(None, decoder_input)[0].astype(np.int32)
        beam_size = encoder_outputs["beam_log_probs"].shape[-1]
        num_processes = min(multiprocessing.cpu_count(), best_index.shape[0])
        # all_hyps is flat (batch * beam_size); pick each item's chosen beam.
        best_sents = [
            all_hyps[i * beam_size:(i + 1) * beam_size][idx]
            for i, idx in enumerate(best_index)
        ]
        hyps = map_batch(best_sents, self.vocabulary, num_processes)
        return "".join(hyps)

    def transcribe(self,
                   audio_file,
                   online=False,
                   mode="ctc_prefix_beam_search",
                   calib_data_path=None):
        """Transcribe *audio_file* with the offline or streaming encoder.

        ``mode == "attention_rescoring"`` runs the decoder; any other mode is
        passed to ``ctc_decode``. When *calib_data_path* is given, inputs are
        saved and each input directory is packed into a ``.tar.gz`` at the
        end.
        """
        feats = self.compute_feats(audio_file)
        if online:
            encoder_outputs = self.run_online_encoder(
                feats, calib_data_path=calib_data_path, sample_prefix=0)
        else:
            encoder_outputs = self.run_offline_encoder(
                feats, calib_data_path=calib_data_path, sample_id=0)

        if mode == "attention_rescoring":
            result = self.run_decoder(encoder_outputs, calib_data_path, 0)
        else:
            hyps, _ = self.ctc_decode(encoder_outputs, mode)
            result = "".join(hyps)

        if calib_data_path:
            pack_calibration_dataset(calib_data_path)
        return result

    def save_calibration_for_audio(self, audio_file, parts, calib_data_path,
                                   sample_id):
        """Dump model inputs for *audio_file* for the requested *parts*.

        *parts* may contain ``"offline"``, ``"online"`` and ``"decoder"``.
        The decoder inputs are derived from the offline encoder outputs.

        Returns:
            dict of how many samples were saved per part (online counts one
            per chunk).
        """
        counts = {"offline": 0, "online": 0, "decoder": 0}
        feats = self.compute_feats(audio_file)

        offline_outputs = None
        if "offline" in parts or "decoder" in parts:
            offline_outputs = self.run_offline_encoder(
                feats,
                calib_data_path=calib_data_path if "offline" in parts else None,
                sample_id=sample_id)
            if "offline" in parts:
                counts["offline"] += 1

        if "online" in parts:
            online_outputs = self.run_online_encoder(
                feats,
                calib_data_path=calib_data_path,
                sample_prefix=sample_id)
            counts["online"] += online_outputs["num_chunks"]

        if "decoder" in parts:
            self.run_decoder(offline_outputs,
                             calib_data_path=calib_data_path,
                             sample_id=sample_id)
            counts["decoder"] += 1

        return counts