| import multiprocessing |
| import os |
| import tarfile as tf |
|
|
| import numpy as np |
| import torch |
| import torchaudio |
| import torchaudio.compliance.kaldi as kaldi |
| import yaml |
|
|
try:
    from swig_decoders import (map_batch, ctc_beam_search_decoder_batch,
                               TrieVector, PathTrie)
except ModuleNotFoundError:
    # The native swig_decoders package is not installed. Define pure-Python
    # stand-ins with the same names and call signatures so the rest of this
    # module can run unchanged.
    _NEG_INF = -float("inf")

    class PathTrie:
        """Stateless placeholder for the native prefix-trie root object."""

    class TrieVector(list):
        """Plain-list placeholder for the native vector of trie roots."""

    def _log_add(*values):
        """Return log(sum(exp(v))) over ``values``, ignoring ``-inf`` terms.

        Returns ``-inf`` when every term is ``-inf`` (or none are given).
        """
        values = [value for value in values if value != _NEG_INF]
        if not values:
            return _NEG_INF
        # Subtract the maximum before exponentiating for numerical stability.
        max_value = max(values)
        return max_value + np.log(sum(np.exp(value - max_value)
                                      for value in values))

    def _map_sentence(sent, vocabulary, greedy=False, blank_id=0):
        """Turn one sequence of token ids into a string.

        With ``greedy=True`` consecutive repeats are collapsed first. Blank
        ids, ids outside ``vocabulary`` and pieces of the form ``<...>`` are
        dropped; the remaining pieces are concatenated without a separator.
        """
        mapped = []
        prev = None
        for token in sent:
            token = int(token)
            repeated = greedy and token == prev
            prev = token
            if repeated:
                continue
            if token == blank_id or token < 0 or token >= len(vocabulary):
                continue
            piece = vocabulary[token]
            if piece.startswith("<") and piece.endswith(">"):
                continue
            mapped.append(piece)
        return "".join(mapped)

    def map_batch(batch_sents, vocabulary, num_processes, greedy=False,
                  blank_id=0):
        """Map every id sequence in ``batch_sents`` to a string.

        ``num_processes`` is accepted for signature compatibility with the
        native implementation and ignored; the work is done sequentially.
        """
        del num_processes
        return [_map_sentence(sent, vocabulary, greedy, blank_id)
                for sent in batch_sents]

    def _beam_score(item):
        """Total log-probability of a ``(prefix, (p_blank, p_non_blank))``."""
        return _log_add(item[1][0], item[1][1])

    def _ctc_prefix_beam_search(log_probs_seq, log_probs_idx, beam_size,
                                blank_id):
        """Run CTC prefix beam search over one utterance.

        Args:
            log_probs_seq: per-frame lists of candidate log-probabilities.
            log_probs_idx: per-frame lists of the matching token ids.
            beam_size: number of prefixes kept after every frame.
            blank_id: id of the CTC blank token.

        Returns:
            A list of ``(log_prob, prefix_tuple)`` pairs, best first.
        """
        # prefix -> (log-prob of paths ending in blank,
        #            log-prob of paths ending in a non-blank token)
        beam = {(): (0.0, _NEG_INF)}
        for frame_probs, frame_ids in zip(log_probs_seq, log_probs_idx):
            # Convert the frame's candidates once, not once per prefix.
            candidates = [(float(prob), int(token))
                          for prob, token in zip(frame_probs, frame_ids)]
            next_beam = {}
            for prefix, (prob_blank, prob_non_blank) in beam.items():
                last = prefix[-1] if prefix else None
                for prob, token in candidates:
                    if token == blank_id:
                        # Blank keeps the prefix and moves mass to "blank".
                        next_prob_blank, next_prob_non_blank = next_beam.get(
                            prefix, (_NEG_INF, _NEG_INF))
                        next_beam[prefix] = (
                            _log_add(next_prob_blank, prob_blank + prob,
                                     prob_non_blank + prob),
                            next_prob_non_blank,
                        )
                        continue

                    if token == last:
                        # A repeat without an intervening blank collapses
                        # into the same prefix ...
                        next_prob_blank, next_prob_non_blank = next_beam.get(
                            prefix, (_NEG_INF, _NEG_INF))
                        next_beam[prefix] = (
                            next_prob_blank,
                            _log_add(next_prob_non_blank,
                                     prob_non_blank + prob),
                        )
                        # ... and only paths that ended in blank may extend.
                        extension_terms = (prob_blank + prob, )
                    else:
                        extension_terms = (prob_blank + prob,
                                           prob_non_blank + prob)

                    new_prefix = prefix + (token, )
                    nb_blank, nb_non_blank = next_beam.get(
                        new_prefix, (_NEG_INF, _NEG_INF))
                    next_beam[new_prefix] = (
                        nb_blank,
                        _log_add(nb_non_blank, *extension_terms),
                    )
            beam = dict(sorted(next_beam.items(), key=_beam_score,
                               reverse=True)[:beam_size])
        # ``beam`` is already ordered best-first (dicts keep insertion order),
        # so no second sort is needed here.
        return [(_log_add(prob_blank, prob_non_blank), prefix)
                for prefix, (prob_blank, prob_non_blank) in beam.items()]

    def ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                      batch_log_probs_idx,
                                      batch_root_trie,
                                      batch_start,
                                      beam_size,
                                      num_processes,
                                      blank_id=0,
                                      space_id=-1,
                                      cutoff_prob=0.999,
                                      ext_scorer=None):
        """Decode every utterance of a batch with CTC prefix beam search.

        Only the log-probabilities, their ids, ``beam_size`` and ``blank_id``
        are used. The remaining parameters exist for signature compatibility
        with the native decoder and are ignored.

        Returns:
            One best-first list of ``(log_prob, prefix_tuple)`` per utterance.
        """
        del batch_root_trie, batch_start, num_processes, space_id
        del cutoff_prob, ext_scorer
        return [
            _ctc_prefix_beam_search(log_probs_seq, log_probs_idx, beam_size,
                                    blank_id)
            for log_probs_seq, log_probs_idx in zip(batch_log_probs_seq,
                                                    batch_log_probs_idx)
        ]
|
|
|
|
def load_config(config_path):
    """Read the YAML file at ``config_path`` and return the parsed object.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.
    """
    # NOTE(review): FullLoader can construct more than plain data types; only
    # point this at trusted configuration files (or switch to yaml.safe_load
    # if the configs never need the extra tags).
    with open(config_path, "r", encoding="utf-8") as fin:
        return yaml.load(fin, Loader=yaml.FullLoader)
|
|
|
|
def load_vocab(vocab_path):
    """Load a ``<token> <id>`` vocabulary file.

    Each non-blank line must hold exactly two whitespace-separated fields:
    the token followed by its integer id. The file is read as UTF-8.

    Returns:
        ``(vocabulary, char_dict)`` where ``vocabulary`` lists the tokens in
        file order and ``char_dict`` maps each integer id to its token.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields, or
            its second field is not an integer.
    """
    vocabulary = []
    char_dict = {}
    with open(vocab_path, "r", encoding="utf-8") as fin:
        for line_no, line in enumerate(fin, start=1):
            arr = line.split()
            if not arr:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            if len(arr) != 2:
                raise ValueError(
                    f"{vocab_path}:{line_no}: expected '<token> <id>', "
                    f"got {line.rstrip()!r}")
            token, index = arr
            char_dict[int(index)] = token
            vocabulary.append(token)
    return vocabulary, char_dict
|
|
|
|
def compute_feats(audio_file: str, sr=16000) -> np.ndarray:
    """Load ``audio_file`` and return 80-bin Kaldi-style fbank features.

    The audio is read with ``soundfile`` when that package is installed and
    with ``torchaudio.load`` otherwise. Both paths produce a float waveform on
    the 16-bit integer scale, so the features do not depend on which backend
    happened to be available. Audio whose sample rate differs from ``sr`` is
    resampled to ``sr`` before feature extraction.

    Args:
        audio_file: path of the audio file to read.
        sr: sample rate the features are computed at.

    Returns:
        The fbank features as a NumPy array with a leading axis of size 1
        added in front of the ``kaldi.fbank`` output.
    """
    # Only the import is guarded: a failure while reading the file should
    # surface to the caller instead of silently switching backends.
    try:
        import soundfile as sf
    except ModuleNotFoundError:
        sf = None

    if sf is not None:
        samples, sample_rate = sf.read(audio_file, dtype="int16",
                                       always_2d=True)
        # soundfile returns (frames, channels); torch code wants the reverse.
        waveform = torch.from_numpy(samples.T).to(torch.float)
    else:
        waveform, sample_rate = torchaudio.load(audio_file, normalize=True)
        # normalize=True yields samples in [-1, 1]; rescale to the int16
        # range so this path matches the soundfile path above.
        waveform = waveform.to(torch.float) * (1 << 15)

    if sample_rate != sr:
        waveform = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=sr)(waveform)
    feats = kaldi.fbank(waveform,
                        num_mel_bins=80,
                        frame_length=25,
                        frame_shift=10,
                        energy_floor=0.0,
                        sample_frequency=sr)
    return feats.unsqueeze(0).numpy()
|
|
|
|
def pad_array_along_axis(array, pad_width, axis, mode="constant", **kwargs):
    """Pad ``array`` at the end of ``axis`` until it has ``pad_width`` entries.

    The input is returned unchanged (same object) when it is already at least
    ``pad_width`` long along ``axis``. ``mode`` and any extra keyword
    arguments are forwarded to ``np.pad``.
    """
    missing = pad_width - array.shape[axis]
    if missing <= 0:
        return array
    # Pad nothing on every axis except the requested one, and only at its end.
    per_axis = [(0, 0) for _ in range(array.ndim)]
    per_axis[axis] = (0, missing)
    return np.pad(array, pad_width=per_axis, mode=mode, **kwargs)
|
|
|
|
def ctc_decoding(beam_log_probs,
                 beam_log_probs_idx,
                 encoder_out_lens,
                 vocabulary,
                 mode="ctc_prefix_beam_search"):
    """Decode encoder top-k CTC outputs into text and/or scored hypotheses.

    Args:
        beam_log_probs: array of candidate log-probabilities; its first axis
            is the batch and its last axis the beam.
        beam_log_probs_idx: token ids matching ``beam_log_probs``.
        encoder_out_lens: number of valid frames per utterance.
        vocabulary: list mapping token id to text piece.
        mode: ``"ctc_greedy_search"``, ``"ctc_prefix_beam_search"`` or
            ``"attention_rescoring"``.

    Returns:
        ``(hyps, score_hyps)``. Greedy search fills only ``hyps`` (strings);
        prefix beam search fills both; attention rescoring fills only
        ``score_hyps``. Any other mode yields two empty lists.
    """
    beam_size = beam_log_probs.shape[-1]
    num_processes = min(multiprocessing.cpu_count(), beam_log_probs.shape[0])

    if mode == "ctc_greedy_search":
        # Keep only the first candidate of every frame, cut to valid length.
        top1_ids = beam_log_probs_idx[:, :, 0]
        sentences = [row[0:encoder_out_lens[n]].tolist()
                     for n, row in enumerate(top1_ids)]
        return map_batch(sentences, vocabulary, num_processes, True, 0), []

    if mode not in ("ctc_prefix_beam_search", "attention_rescoring"):
        return [], []

    all_probs = beam_log_probs.tolist()
    all_ids = beam_log_probs_idx.tolist()
    trimmed_probs = []
    trimmed_ids = []
    starts = []
    roots = TrieVector()
    # Hold a Python reference to every root for the duration of the call.
    roots_by_utt = {}
    for n, valid_frames in enumerate(encoder_out_lens.tolist()):
        trimmed_probs.append(all_probs[n][0:valid_frames])
        trimmed_ids.append(all_ids[n][0:valid_frames])
        roots_by_utt[n] = PathTrie()
        roots.append(roots_by_utt[n])
        starts.append(True)

    score_hyps = ctc_beam_search_decoder_batch(trimmed_probs,
                                               trimmed_ids,
                                               roots,
                                               starts,
                                               beam_size,
                                               num_processes,
                                               0, -2, 0.99999)
    if mode != "ctc_prefix_beam_search":
        return [], score_hyps

    # The first entry of each utterance's list is used as its result.
    best_ids = [candidates[0][1] for candidates in score_hyps]
    return map_batch(best_ids, vocabulary, num_processes, False, 0), score_hyps
|
|
|
|
def make_decoder_inputs(encoder_out,
                        encoder_out_lens,
                        beam_log_probs,
                        beam_log_probs_idx,
                        vocabulary,
                        sos,
                        eos,
                        decoder_len):
    """Assemble the rescoring decoder's input dict from encoder outputs.

    Runs ``ctc_decoding`` in ``"attention_rescoring"`` mode, fills every
    utterance up to ``beam_size`` hypotheses (filler entries score ``-inf``
    and hold the single token ``0``), and lays the hypotheses out in
    fixed-size arrays of length ``decoder_len`` wrapped in ``sos``/``eos``.
    Unused positions hold ``-1``. ``encoder_out`` is zero-padded or truncated
    along axis 1 to exactly ``decoder_len`` frames.

    Returns:
        ``(inputs, all_hyps)`` where ``inputs`` is the dict of arrays and
        ``all_hyps`` is the flat list of (untruncated) hypothesis id lists,
        ``beam_size`` consecutive entries per utterance.
    """
    _unused_text, score_hyps = ctc_decoding(beam_log_probs,
                                            beam_log_probs_idx,
                                            encoder_out_lens, vocabulary,
                                            "attention_rescoring")
    pad_token = -1
    batch_size = beam_log_probs.shape[0]
    beam_size = beam_log_probs.shape[-1]

    all_hyps = []
    score_rows = []
    for candidates in score_hyps:
        shortfall = beam_size - len(candidates)
        if shortfall > 0:
            # Extend in place so every utterance contributes a full beam.
            candidates += [(-float("inf"), (0,))] * shortfall
        score_rows.append([candidate[0] for candidate in candidates])
        all_hyps.extend(list(candidate[1]) for candidate in candidates)
    ctc_score = np.array(score_rows, dtype=np.float32)

    # Two slots of every row are reserved for the sos and eos markers.
    token_budget = decoder_len - 2
    padded_shape = (batch_size, beam_size, token_budget + 2)
    forward = np.full(padded_shape, pad_token, dtype=np.int64)
    backward = np.full(padded_shape, pad_token, dtype=np.int64)
    lengths = np.ones((batch_size, beam_size), dtype=np.int32)
    for flat in range(batch_size * beam_size):
        row, col = divmod(flat, beam_size)
        tokens = all_hyps[flat][:token_budget]
        span = len(tokens) + 2
        forward[row, col, :span] = [sos] + tokens + [eos]
        backward[row, col, :span] = [sos] + tokens[::-1] + [eos]
        # The stored length counts sos plus the tokens, not eos.
        lengths[row, col] = span - 1

    frames = encoder_out.shape[1]
    if frames < decoder_len:
        encoder_out = np.pad(encoder_out,
                             [(0, 0), (0, decoder_len - frames), (0, 0)],
                             mode="constant",
                             constant_values=0)
    elif frames > decoder_len:
        encoder_out = encoder_out[:, :decoder_len, :]

    decoder_inputs = {
        "encoder_out": encoder_out,
        "encoder_out_lens": np.full(batch_size,
                                    fill_value=decoder_len,
                                    dtype=np.int32),
        "hyps_pad_sos_eos": forward.astype(np.int32),
        "hyps_lens_sos": lengths,
        "r_hyps_pad_sos_eos": backward.astype(np.int32),
        "ctc_score": ctc_score,
    }
    return decoder_inputs, all_hyps
|
|
|
|
def make_offline_inputs(feats, seq_len):
    """Fit ``feats`` to exactly ``seq_len`` frames for the offline encoder.

    Frames beyond ``seq_len`` on axis 1 are dropped; shorter inputs are padded
    with ``pad_array_along_axis``. ``speech_lengths`` holds the frame count
    after truncation and before padding, as a one-element int32 array.
    """
    clipped = feats[:, :seq_len, :]
    valid_frames = clipped.shape[1]
    lengths = np.array([valid_frames], dtype=np.int32)
    if valid_frames < seq_len:
        clipped = pad_array_along_axis(clipped, pad_width=seq_len, axis=1)
    return {"speech": clipped, "speech_lengths": lengths}
|
|
|
|
def make_online_initial_state(configs,
                              batch_size=1,
                              decoding_chunk_size=16,
                              num_decoding_left_chunks=5):
    """Create zeroed streaming caches and the chunking parameters.

    Cache shapes are derived from ``configs["encoder_conf"]`` (``output_size``,
    ``num_blocks``, ``attention_heads`` and the optional
    ``cnn_module_kernel``). A subsampling factor of 4 and a context of 7
    frames are hard-coded.

    Returns:
        ``(state, params)``: ``state`` holds the ``att_cache``, ``cnn_cache``,
        ``cache_mask`` and ``offset`` arrays; ``params`` holds ``batch_size``,
        ``context``, ``stride`` and ``decoding_window``.
    """
    subsampling, context = 4, 7
    cache_frames = decoding_chunk_size * num_decoding_left_chunks

    encoder_conf = configs["encoder_conf"]
    output_size = encoder_conf["output_size"]
    num_layers = encoder_conf["num_blocks"]
    # The convolution cache is one frame shorter than the kernel.
    cnn_cache_frames = encoder_conf.get("cnn_module_kernel", 1) - 1
    head = encoder_conf["attention_heads"]
    d_k = output_size // head

    att_shape = (batch_size, num_layers, head, cache_frames, d_k * 2)
    cnn_shape = (batch_size, num_layers, output_size, cnn_cache_frames)
    mask_shape = (batch_size, 1, cache_frames)

    state = {
        "att_cache": np.zeros(att_shape, dtype=np.float32),
        "cnn_cache": np.zeros(cnn_shape, dtype=np.float32),
        "cache_mask": np.zeros(mask_shape, dtype=np.float32),
        "offset": np.zeros((batch_size, 1), dtype=np.int32),
    }
    params = {
        "batch_size": batch_size,
        "context": context,
        "stride": subsampling * decoding_chunk_size,
        "decoding_window": (decoding_chunk_size - 1) * subsampling + context,
    }
    return state, params
|
|
|
|
def make_online_encoder_input(feats, cur, params, state):
    """Build the input dict for one streaming-encoder step.

    Takes ``params["decoding_window"]`` frames of ``feats`` starting at frame
    ``cur`` (padded with ``pad_array_along_axis`` when fewer remain), casts
    them to float32 and bundles them with the current cache arrays from
    ``state``. ``chunk_lens`` repeats the chunk's frame count
    ``params["batch_size"]`` times.
    """
    batch_size = params["batch_size"]
    window = params["decoding_window"]

    stop = min(cur + window, feats.shape[1])
    chunk = feats[:, cur:stop, :]
    if chunk.shape[1] < window:
        chunk = pad_array_along_axis(chunk, pad_width=window, axis=1)
    chunk = chunk.astype(np.float32)

    encoder_input = {
        "chunk_xs": chunk,
        "chunk_lens": np.full(batch_size,
                              fill_value=chunk.shape[1],
                              dtype=np.int32),
    }
    # Cache tensors are passed through untouched, in this fixed order.
    for key in ("offset", "att_cache", "cnn_cache", "cache_mask"):
        encoder_input[key] = state[key]
    return encoder_input
|
|
|
|
def update_online_state(state, outputs):
    """Copy the cache entries of one encoder step into ``state`` in place.

    ``outputs[4]`` through ``outputs[7]`` are stored under ``"offset"``,
    ``"att_cache"``, ``"cnn_cache"`` and ``"cache_mask"`` respectively.
    """
    cache_keys = ("offset", "att_cache", "cnn_cache", "cache_mask")
    for position, key in enumerate(cache_keys, start=4):
        state[key] = outputs[position]
|
|
|
|
def save_calibration_inputs(calib_data_path, inputs, sample_id):
    """Write each array of ``inputs`` to ``<calib_data_path>/<name>/<id>.npy``.

    One sub-directory per input name is created on demand; an existing file
    for the same ``sample_id`` is overwritten by ``np.save``.
    """
    file_name = f"{sample_id}.npy"
    for name, array in inputs.items():
        target_dir = os.path.join(calib_data_path, name)
        os.makedirs(target_dir, exist_ok=True)
        np.save(os.path.join(target_dir, file_name), array)
|
|
|
|
def pack_calibration_dataset(calib_data_path):
    """Archive every sub-directory of ``calib_data_path`` as ``<name>.tar.gz``.

    Sub-directories are processed in sorted name order and non-directory
    entries are ignored. Each archive is written next to its directory and
    stores the directory under its own name.
    """
    directory_names = [
        entry for entry in sorted(os.listdir(calib_data_path))
        if os.path.isdir(os.path.join(calib_data_path, entry))
    ]
    for name in directory_names:
        archive_path = os.path.join(calib_data_path, name + ".tar.gz")
        with tf.open(archive_path, "w:gz") as archive:
            archive.add(os.path.join(calib_data_path, name), arcname=name)
|
|
|
|
class WenetONNXRunner:
    """Drive the offline encoder, online encoder and decoder ONNX models.

    The three ONNX Runtime sessions are created lazily, the first time the
    ``offline_encoder``, ``online_encoder`` or ``decoder`` property is read,
    from ``encoder_offline.onnx``, ``encoder_online.onnx`` and
    ``decoder.onnx`` inside ``onnx_dir``.
    """

    def __init__(self,
                 config_path,
                 vocab_path,
                 onnx_dir="onnx_model",
                 offline_seq_len=1024,
                 decoder_len=32,
                 decoding_chunk_size=16,
                 num_decoding_left_chunks=5,
                 batch_size=1,
                 providers=None):
        """Load the config and vocabulary; no ONNX session is opened yet.

        Args:
            config_path: YAML model configuration, read with ``load_config``.
            vocab_path: vocabulary file, read with ``load_vocab``.
            onnx_dir: directory containing the three ``.onnx`` files.
            offline_seq_len: frame count the offline encoder input is
                truncated/padded to.
            decoder_len: fixed sequence length used for decoder inputs.
            decoding_chunk_size: streaming chunk size.
            num_decoding_left_chunks: number of cached left chunks.
            batch_size: batch size used for the streaming state.
            providers: ONNX Runtime execution providers; defaults to
                ``["CPUExecutionProvider"]`` when omitted or empty.
        """
        self.config_path = config_path
        self.vocab_path = vocab_path
        self.onnx_dir = onnx_dir
        self.offline_seq_len = offline_seq_len
        self.decoder_len = decoder_len
        self.decoding_chunk_size = decoding_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.batch_size = batch_size
        self.providers = providers or ["CPUExecutionProvider"]

        self.configs = load_config(config_path)
        self.vocabulary, self.char_dict = load_vocab(vocab_path)
        # sos and eos share one id: the vocabulary size minus one.
        self.eos = self.sos = len(self.char_dict) - 1

        self._offline_encoder = None
        self._online_encoder = None
        self._decoder = None

    @property
    def offline_encoder(self):
        """Session for ``encoder_offline.onnx``, created on first access."""
        if self._offline_encoder is None:
            self._offline_encoder = self._new_session("encoder_offline.onnx")
        return self._offline_encoder

    @property
    def online_encoder(self):
        """Session for ``encoder_online.onnx``, created on first access."""
        if self._online_encoder is None:
            self._online_encoder = self._new_session("encoder_online.onnx")
        return self._online_encoder

    @property
    def decoder(self):
        """Session for ``decoder.onnx``, created on first access."""
        if self._decoder is None:
            self._decoder = self._new_session("decoder.onnx")
        return self._decoder

    def _new_session(self, filename):
        """Open ``filename`` from ``onnx_dir`` as an ONNX Runtime session."""
        # Imported here so the module can be loaded without onnxruntime.
        import onnxruntime as ort
        return ort.InferenceSession(os.path.join(self.onnx_dir, filename),
                                    providers=self.providers)

    def compute_feats(self, audio_file):
        """Return the module-level ``compute_feats`` result for the file."""
        return compute_feats(audio_file)

    def run_offline_encoder(self,
                            feats,
                            calib_data_path=None,
                            sample_id=0):
        """Run the offline encoder on ``feats`` and name its five outputs.

        When ``calib_data_path`` is given, the encoder inputs are also saved
        there under ``sample_id``.
        """
        encoder_input = make_offline_inputs(feats, self.offline_seq_len)
        if calib_data_path:
            save_calibration_inputs(calib_data_path, encoder_input, sample_id)
        (encoder_out, encoder_out_lens, ctc_log_probs, beam_log_probs,
         beam_log_probs_idx) = self.offline_encoder.run(None, encoder_input)
        return {
            "encoder_out": encoder_out,
            "encoder_out_lens": encoder_out_lens,
            "ctc_log_probs": ctc_log_probs,
            "beam_log_probs": beam_log_probs,
            "beam_log_probs_idx": beam_log_probs_idx,
        }

    def run_online_encoder(self,
                           feats,
                           calib_data_path=None,
                           sample_prefix=0):
        """Run the streaming encoder chunk by chunk and join the results.

        Chunks start every ``stride`` frames while at least ``context`` frames
        remain. When ``calib_data_path`` is given, each chunk's inputs are
        saved under the id ``"<sample_prefix>_<start frame>"``.

        Returns:
            A dict with the concatenated ``encoder_out``, ``beam_log_probs``
            and ``beam_log_probs_idx``, the total ``encoder_out_lens`` and
            the number of processed chunks (``num_chunks``).

        Raises:
            ValueError: if ``feats`` has fewer frames than one context
                window, so that not a single chunk could be processed.
        """
        state, online_params = make_online_initial_state(
            self.configs, self.batch_size, self.decoding_chunk_size,
            self.num_decoding_left_chunks)
        context = online_params["context"]
        num_frames = feats.shape[1]
        if num_frames < context:
            # Without this check the loop below runs zero times and the
            # concatenation fails with an unhelpful "need at least one
            # array to concatenate" error.
            raise ValueError(
                f"input has {num_frames} feature frames but online decoding "
                f"needs at least {context}")

        encoder_out = []
        beam_log_probs = []
        beam_log_probs_idx = []
        for cur in range(0, num_frames - context + 1,
                         online_params["stride"]):
            encoder_input = make_online_encoder_input(feats, cur,
                                                      online_params, state)
            if calib_data_path:
                save_calibration_inputs(calib_data_path, encoder_input,
                                        f"{sample_prefix}_{cur}")
            outputs = self.online_encoder.run(None, encoder_input)
            # outputs[3] (chunk lengths) is not needed here.
            chunk_log_probs, chunk_log_probs_idx, chunk_out = outputs[:3]
            # outputs[4:8] carry the caches for the next chunk.
            update_online_state(state, outputs)
            encoder_out.append(chunk_out)
            beam_log_probs.append(chunk_log_probs)
            beam_log_probs_idx.append(chunk_log_probs_idx.astype(np.int32))

        total_out_frames = sum(out.shape[1] for out in encoder_out)
        return {
            "encoder_out": np.concatenate(encoder_out, axis=1),
            "encoder_out_lens": np.full(self.batch_size,
                                        fill_value=total_out_frames,
                                        dtype=np.int32),
            "beam_log_probs": np.concatenate(beam_log_probs, axis=1),
            "beam_log_probs_idx": np.concatenate(beam_log_probs_idx, axis=1),
            "num_chunks": len(encoder_out),
        }

    def ctc_decode(self, encoder_outputs, mode):
        """Call ``ctc_decoding`` on an encoder-output dict with ``mode``."""
        return ctc_decoding(encoder_outputs["beam_log_probs"],
                            encoder_outputs["beam_log_probs_idx"],
                            encoder_outputs["encoder_out_lens"],
                            self.vocabulary, mode)

    def run_decoder(self,
                    encoder_outputs,
                    calib_data_path=None,
                    sample_id=0):
        """Rescore the CTC hypotheses with the decoder and return the text.

        The decoder's first output is read as one chosen beam index per
        utterance; the selected hypotheses are mapped to text and joined.
        When ``calib_data_path`` is given, the decoder inputs are saved there
        under ``sample_id``.
        """
        decoder_input, all_hyps = make_decoder_inputs(
            encoder_outputs["encoder_out"],
            encoder_outputs["encoder_out_lens"],
            encoder_outputs["beam_log_probs"],
            encoder_outputs["beam_log_probs_idx"],
            self.vocabulary,
            self.sos,
            self.eos,
            self.decoder_len,
        )
        if calib_data_path:
            save_calibration_inputs(calib_data_path, decoder_input, sample_id)

        best_index = self.decoder.run(None, decoder_input)[0].astype(np.int32)
        beam_size = encoder_outputs["beam_log_probs"].shape[-1]
        num_processes = min(multiprocessing.cpu_count(), best_index.shape[0])
        best_sents = []
        # all_hyps is flat: beam_size consecutive hypotheses per utterance.
        k = 0
        for idx in best_index:
            best_sents.append(all_hyps[k:k + beam_size][idx])
            k += beam_size
        hyps = map_batch(best_sents, self.vocabulary, num_processes)
        return "".join(hyps)

    def transcribe(self,
                   audio_file,
                   online=False,
                   mode="ctc_prefix_beam_search",
                   calib_data_path=None):
        """Transcribe ``audio_file`` and return the recognized text.

        Args:
            audio_file: path of the audio file.
            online: use the streaming encoder instead of the offline one.
            mode: ``"attention_rescoring"`` runs the decoder; any other value
                is passed to ``ctc_decode``.
            calib_data_path: when given, model inputs are saved there and the
                directory is packed with ``pack_calibration_dataset``.
        """
        feats = self.compute_feats(audio_file)
        if online:
            encoder_outputs = self.run_online_encoder(
                feats, calib_data_path=calib_data_path, sample_prefix=0)
        else:
            encoder_outputs = self.run_offline_encoder(
                feats, calib_data_path=calib_data_path, sample_id=0)

        if mode == "attention_rescoring":
            result = self.run_decoder(encoder_outputs, calib_data_path, 0)
        else:
            hyps, _ = self.ctc_decode(encoder_outputs, mode)
            result = "".join(hyps) if hyps else ""

        if calib_data_path:
            pack_calibration_dataset(calib_data_path)
        return result

    def save_calibration_for_audio(self,
                                   audio_file,
                                   parts,
                                   calib_data_path,
                                   sample_id):
        """Save calibration inputs of the requested models for one audio file.

        ``parts`` selects the models via membership tests for ``"offline"``,
        ``"online"`` and ``"decoder"``. The decoder inputs are always derived
        from the offline encoder's outputs, so the offline encoder also runs
        (without saving its inputs) when only ``"decoder"`` is requested.

        Returns:
            A dict counting the samples saved per model; for ``"online"`` one
            sample is counted per processed chunk.
        """
        counts = {"offline": 0, "online": 0, "decoder": 0}
        feats = self.compute_feats(audio_file)
        offline_outputs = None
        if "offline" in parts or "decoder" in parts:
            offline_outputs = self.run_offline_encoder(
                feats,
                calib_data_path=calib_data_path if "offline" in parts else None,
                sample_id=sample_id)
            if "offline" in parts:
                counts["offline"] += 1
        if "online" in parts:
            online_outputs = self.run_online_encoder(
                feats,
                calib_data_path=calib_data_path,
                sample_prefix=sample_id)
            counts["online"] += online_outputs["num_chunks"]
        if "decoder" in parts:
            self.run_decoder(offline_outputs,
                             calib_data_path=calib_data_path,
                             sample_id=sample_id)
            counts["decoder"] += 1
        return counts
|
|