"""NumPy-only feature extraction, CTC decoding and axengine model runners.

The module bundles the pre/post-processing needed to drive the encoder and
decoder ``.axmodel`` files through :class:`WenetAXRunner`.
"""

import multiprocessing
import wave

import numpy as np
import yaml


def _log_add(*values):
    """Return ``log(sum(exp(v)))`` over *values*, ignoring ``-inf`` entries.

    Returns ``-inf`` when every value is ``-inf`` (or no value is given).
    """
    values = [value for value in values if value != -float("inf")]
    if not values:
        return -float("inf")
    # Subtract the maximum before exponentiating for numerical stability.
    max_value = max(values)
    return max_value + np.log(
        sum(np.exp(value - max_value) for value in values))


def _map_sentence(sent, vocabulary, greedy=False, blank_id=0):
    """Convert a sequence of token ids into a string.

    Args:
        sent: Iterable of token ids.
        vocabulary: Sequence mapping a token id to its text piece.
        greedy: When true, consecutive duplicate ids are collapsed first
            (CTC greedy-path post-processing).
        blank_id: Id of the CTC blank token, which is always dropped.

    Returns:
        The concatenated pieces. Out-of-range ids and pieces of the form
        ``<...>`` are skipped.
    """
    mapped = []
    prev = None
    for token in sent:
        token = int(token)
        if greedy and token == prev:
            prev = token
            continue
        prev = token
        if token == blank_id or token < 0 or token >= len(vocabulary):
            continue
        piece = vocabulary[token]
        if piece.startswith("<") and piece.endswith(">"):
            continue
        mapped.append(piece)
    return "".join(mapped)


def map_batch(batch_sents, vocabulary, num_processes, greedy=False,
              blank_id=0):
    """Map every id sequence in *batch_sents* to text.

    ``num_processes`` is accepted for interface compatibility and ignored;
    the work is done sequentially in the calling process.
    """
    del num_processes
    return [
        _map_sentence(sent, vocabulary, greedy, blank_id)
        for sent in batch_sents
    ]


def _ctc_prefix_beam_search(log_probs_seq, log_probs_idx, beam_size,
                            blank_id):
    """Run CTC prefix beam search over one utterance.

    Args:
        log_probs_seq: Per-frame lists of candidate log-probabilities.
        log_probs_idx: Per-frame lists of the token ids matching
            ``log_probs_seq``.
        beam_size: Maximum number of prefixes kept after each frame.
        blank_id: Id of the CTC blank token.

    Returns:
        A list of ``(log_score, prefix_tuple)`` sorted by descending score.
    """
    neg_inf = -float("inf")

    def total_score(item):
        return _log_add(item[1][0], item[1][1])

    # prefix -> (log P(prefix ends in blank), log P(prefix ends in non-blank))
    beam = {(): (0.0, neg_inf)}
    for frame_probs, frame_ids in zip(log_probs_seq, log_probs_idx):
        frame_probs = np.asarray(frame_probs, dtype=np.float32)
        # Renormalise over the candidates that are actually present.
        frame_probs = frame_probs - _log_add(*frame_probs.tolist())
        next_beam = {}
        for prefix, (prob_blank, prob_non_blank) in beam.items():
            for prob, token in zip(frame_probs, frame_ids):
                token = int(token)
                prob = float(prob)
                next_prob_blank, next_prob_non_blank = next_beam.get(
                    prefix, (neg_inf, neg_inf))
                if token == blank_id:
                    # Blank keeps the prefix unchanged.
                    next_beam[prefix] = (
                        _log_add(next_prob_blank, prob_blank + prob,
                                 prob_non_blank + prob),
                        next_prob_non_blank,
                    )
                    continue
                last = prefix[-1] if prefix else None
                if token == last:
                    # Repeat without an intervening blank: same prefix.
                    next_beam[prefix] = (
                        next_prob_blank,
                        _log_add(next_prob_non_blank, prob_non_blank + prob),
                    )
                    # Repeat after a blank: the prefix grows by one token.
                    new_prefix = prefix + (token, )
                    nb_blank, nb_non_blank = next_beam.get(
                        new_prefix, (neg_inf, neg_inf))
                    next_beam[new_prefix] = (
                        nb_blank,
                        _log_add(nb_non_blank, prob_blank + prob),
                    )
                else:
                    new_prefix = prefix + (token, )
                    nb_blank, nb_non_blank = next_beam.get(
                        new_prefix, (neg_inf, neg_inf))
                    next_beam[new_prefix] = (
                        nb_blank,
                        _log_add(nb_non_blank, prob_blank + prob,
                                 prob_non_blank + prob),
                    )
        beam = dict(
            sorted(next_beam.items(), key=total_score,
                   reverse=True)[:beam_size])
    return [(_log_add(prob_blank, prob_non_blank), prefix)
            for prefix, (prob_blank, prob_non_blank) in sorted(
                beam.items(), key=total_score, reverse=True)]


def ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                  batch_log_probs_idx,
                                  batch_root_trie,
                                  batch_start,
                                  beam_size,
                                  num_processes,
                                  blank_id=0,
                                  space_id=-1,
                                  cutoff_prob=0.999,
                                  ext_scorer=None):
    """Run prefix beam search on every utterance of a batch.

    Only ``batch_log_probs_seq``, ``batch_log_probs_idx``, ``beam_size`` and
    ``blank_id`` are used; the remaining parameters are accepted and
    discarded.

    Returns:
        One list of ``(log_score, prefix_tuple)`` per utterance.
    """
    del batch_root_trie, batch_start, num_processes, space_id
    del cutoff_prob, ext_scorer
    return [
        _ctc_prefix_beam_search(log_probs_seq, log_probs_idx, beam_size,
                                blank_id)
        for log_probs_seq, log_probs_idx in zip(batch_log_probs_seq,
                                                batch_log_probs_idx)
    ]


def load_config(config_path):
    """Load the YAML configuration stored at *config_path*."""
    # Explicit UTF-8 so parsing does not depend on the locale encoding.
    with open(config_path, "r", encoding="utf-8") as fin:
        # NOTE(review): FullLoader is not meant for untrusted input; prefer
        # yaml.safe_load if the config may come from an external source.
        return yaml.load(fin, Loader=yaml.FullLoader)


def load_vocab(vocab_path):
    """Read a ``<token> <id>`` per-line vocabulary file.

    Returns:
        ``(vocabulary, char_dict)`` where ``vocabulary`` lists tokens in file
        order and ``char_dict`` maps the integer id to its token.
    """
    vocabulary = []
    char_dict = {}
    # Explicit UTF-8: vocabularies routinely contain non-ASCII tokens.
    with open(vocab_path, "r", encoding="utf-8") as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            char_dict[int(arr[1])] = arr[0]
            vocabulary.append(arr[0])
    return vocabulary, char_dict


def load_wav(audio_file):
    """Read a PCM WAV file into a mono float32 waveform.

    Returns:
        ``(waveform, sample_rate)``. Multi-channel audio is averaged across
        channels. 8-bit samples are re-centred around zero; wider samples
        keep their integer magnitude.

    Raises:
        ValueError: If the sample width is not 1, 2 or 4 bytes.
    """
    with wave.open(audio_file, "rb") as wav_file:
        sample_rate = wav_file.getframerate()
        num_channels = wav_file.getnchannels()
        sample_width = wav_file.getsampwidth()
        frames = wav_file.readframes(wav_file.getnframes())

    if sample_width == 1:
        # 8-bit WAV is unsigned; shift to a signed range.
        waveform = np.frombuffer(frames, dtype=np.uint8).astype(np.float32)
        waveform -= 128.0
    elif sample_width == 2:
        waveform = np.frombuffer(frames, dtype="<i2").astype(np.float32)
    elif sample_width == 4:
        # NOTE(review): the original text for this region was lost; 32-bit
        # samples are kept unscaled, mirroring the 8/16-bit branches -
        # confirm against the intended feature scaling.
        waveform = np.frombuffer(frames, dtype="<i4").astype(np.float32)
    else:
        raise ValueError(
            "unsupported WAV sample width: {} bytes".format(sample_width))

    if num_channels > 1:
        waveform = waveform.reshape(-1, num_channels).mean(axis=1)
    return waveform, sample_rate


def resample_linear(waveform, orig_sr, target_sr):
    """Resample a 1-D waveform with linear interpolation.

    The input is returned unchanged when the rates already match or when
    the resampled signal would contain at most one sample.
    """
    if orig_sr == target_sr:
        return waveform
    duration = waveform.shape[0] / float(orig_sr)
    target_len = int(round(duration * target_sr))
    if target_len <= 1:
        return waveform
    src_pos = np.linspace(0, waveform.shape[0] - 1, target_len)
    return np.interp(src_pos, np.arange(waveform.shape[0]),
                     waveform).astype(np.float32)


def hz_to_mel(freq):
    """Convert a frequency in Hz to the mel scale."""
    return 1127.0 * np.log1p(freq / 700.0)


def mel_to_hz(mel):
    """Convert a mel-scale value back to Hz (inverse of ``hz_to_mel``)."""
    return 700.0 * np.expm1(mel / 1127.0)


def mel_filterbank(num_mel_bins, n_fft, sample_rate):
    """Build triangular mel filters between 20 Hz and the Nyquist frequency.

    Returns:
        A float32 array of shape ``(num_mel_bins, n_fft // 2 + 1)``.
    """
    low_mel = hz_to_mel(20.0)
    high_mel = hz_to_mel(sample_rate / 2.0)
    mel_points = np.linspace(low_mel, high_mel, num_mel_bins + 2)
    hz_points = mel_to_hz(mel_points)
    bins = np.floor((n_fft + 1) * hz_points / sample_rate).astype(np.int32)
    fbanks = np.zeros((num_mel_bins, n_fft // 2 + 1), dtype=np.float32)
    for i in range(num_mel_bins):
        left, center, right = bins[i], bins[i + 1], bins[i + 2]
        if center > left:
            # Rising edge of the triangle.
            fbanks[i, left:center] = (
                np.arange(left, center) - left) / float(center - left)
        if right > center:
            # Falling edge of the triangle.
            fbanks[i, center:right] = (
                right - np.arange(center, right)) / float(right - center)
    return fbanks


def numpy_fbank(waveform,
                sample_rate=16000,
                num_mel_bins=80,
                frame_length=25,
                frame_shift=10):
    """Compute log-mel filterbank features for a 1-D waveform.

    Args:
        waveform: 1-D array of samples; the in-place windowing below
            expects a floating dtype.
        sample_rate: Sampling rate in Hz.
        num_mel_bins: Number of mel filters.
        frame_length: Frame length in milliseconds.
        frame_shift: Frame shift in milliseconds.

    Returns:
        A float32 array of shape ``(num_frames, num_mel_bins)``.
    """
    frame_size = int(round(sample_rate * frame_length / 1000.0))
    frame_step = int(round(sample_rate * frame_shift / 1000.0))
    if waveform.shape[0] < frame_size:
        # Zero-pad so that at least one full frame exists.
        waveform = np.pad(waveform, (0, frame_size - waveform.shape[0]))
    num_frames = 1 + (waveform.shape[0] - frame_size) // frame_step
    # Overlapping frame view; copied so the window can be applied in place.
    frames = np.lib.stride_tricks.as_strided(
        waveform,
        shape=(num_frames, frame_size),
        strides=(waveform.strides[0] * frame_step, waveform.strides[0]),
    ).copy()
    frames *= np.hamming(frame_size).astype(np.float32)
    # Smallest power of two that is >= frame_size.
    n_fft = 1
    while n_fft < frame_size:
        n_fft <<= 1
    power = np.abs(np.fft.rfft(frames, n=n_fft))**2
    fbanks = mel_filterbank(num_mel_bins, n_fft, sample_rate)
    # Floor the energies so the logarithm stays finite.
    mel_energies = np.maximum(np.dot(power, fbanks.T),
                              np.finfo(np.float32).eps)
    return np.log(mel_energies).astype(np.float32)


def compute_feats(audio_file, sr=16000):
    """Load *audio_file*, resample to *sr* and return ``(1, T, 80)`` fbanks."""
    waveform, sample_rate = load_wav(audio_file)
    waveform = resample_linear(waveform.astype(np.float32), sample_rate, sr)
    return numpy_fbank(waveform, sample_rate=sr).reshape(1, -1, 80)


def pad_array_along_axis(array, pad_width, axis, mode="constant", **kwargs):
    """Pad *array* at the end of *axis* up to length *pad_width*.

    The array is returned unchanged when it is already at least that long.
    Extra keyword arguments are forwarded to ``np.pad``.
    """
    if array.shape[axis] >= pad_width:
        return array
    full_pad_width = [(0, 0)] * array.ndim
    full_pad_width[axis] = (0, pad_width - array.shape[axis])
    return np.pad(array, pad_width=full_pad_width, mode=mode, **kwargs)


def numpy_topk(array, k, axis=-1, largest=True):
    """Return the *k* largest (or smallest) entries along *axis*.

    Returns:
        ``(values, indices)`` sorted in descending order when ``largest`` is
        true and ascending order otherwise.
    """
    if largest:
        partitioned_indices = np.argpartition(array, -k, axis=axis)
        topk_indices = np.take(partitioned_indices, range(-k, 0), axis=axis)
    else:
        # Partition at k - 1 so that k == array.shape[axis] stays in bounds;
        # the first k entries are the k smallest either way.
        partitioned_indices = np.argpartition(array, k - 1, axis=axis)
        topk_indices = np.take(partitioned_indices, range(0, k), axis=axis)
    topk_values = np.take_along_axis(array, topk_indices, axis=axis)
    sorted_indices_in_topk = np.argsort(topk_values, axis=axis)
    if largest:
        sorted_indices_in_topk = np.flip(sorted_indices_in_topk, axis=axis)
    sorted_topk_values = np.take_along_axis(topk_values,
                                            sorted_indices_in_topk,
                                            axis=axis)
    sorted_topk_indices = np.take_along_axis(topk_indices,
                                             sorted_indices_in_topk,
                                             axis=axis)
    return sorted_topk_values, sorted_topk_indices


def ctc_decoding(beam_log_probs,
                 beam_log_probs_idx,
                 encoder_out_lens,
                 vocabulary,
                 mode="ctc_prefix_beam_search"):
    """Decode per-frame top-k CTC outputs.

    Args:
        beam_log_probs: Array indexed ``[batch, frame, candidate]`` holding
            log-probabilities.
        beam_log_probs_idx: Token ids matching ``beam_log_probs``.
        encoder_out_lens: Number of valid frames per batch item.
        vocabulary: Sequence mapping a token id to its text piece.
        mode: ``"ctc_greedy_search"``, ``"ctc_prefix_beam_search"`` or
            ``"attention_rescoring"``.

    Returns:
        ``(hyps, score_hyps)``. ``hyps`` holds decoded strings (empty for
        ``"attention_rescoring"``); ``score_hyps`` holds the n-best lists of
        the beam search (empty for greedy search). Any other mode yields two
        empty lists.
    """
    beam_size = beam_log_probs.shape[-1]
    batch_size = beam_log_probs.shape[0]
    num_processes = min(multiprocessing.cpu_count(), batch_size)
    hyps = []
    score_hyps = []
    if mode == "ctc_greedy_search":
        if beam_size == 1:
            log_probs_idx = beam_log_probs_idx.squeeze(-1)
        else:
            # Candidate 0 is used as the best one per frame.
            log_probs_idx = beam_log_probs_idx[:, :, 0]
        batch_sents = []
        for idx, seq in enumerate(log_probs_idx):
            batch_sents.append(seq[0:encoder_out_lens[idx]].tolist())
        hyps = map_batch(batch_sents, vocabulary, num_processes, True, 0)
    elif mode in ("ctc_prefix_beam_search", "attention_rescoring"):
        batch_log_probs_seq_list = beam_log_probs.tolist()
        batch_log_probs_idx_list = beam_log_probs_idx.tolist()
        batch_len_list = encoder_out_lens.tolist()
        batch_log_probs_seq = []
        batch_log_probs_ids = []
        batch_start = []
        batch_root = []
        for i, num_sent in enumerate(batch_len_list):
            # Drop frames beyond the valid length of each utterance.
            batch_log_probs_seq.append(batch_log_probs_seq_list[i][0:num_sent])
            batch_log_probs_ids.append(batch_log_probs_idx_list[i][0:num_sent])
            batch_root.append(None)
            batch_start.append(True)
        score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                                   batch_log_probs_ids,
                                                   batch_root, batch_start,
                                                   beam_size, num_processes,
                                                   0, -2, 0.99999)
        if mode == "ctc_prefix_beam_search":
            for cand_hyps in score_hyps:
                hyps.append(cand_hyps[0][1])
            hyps = map_batch(hyps, vocabulary, num_processes, False, 0)
    return hyps, score_hyps


def has_higher_scored_collapsed_repeat(hyp, kept_hyps):
    """Tell whether removing one adjacent duplicate from *hyp* gives a kept hyp.

    ``kept_hyps`` is the set of hypotheses already accepted (the caller
    iterates in descending score order).
    """
    for i in range(1, len(hyp)):
        if hyp[i] != hyp[i - 1]:
            continue
        collapsed = hyp[:i] + hyp[i + 1:]
        if collapsed in kept_hyps:
            return True
    return False


def make_decoder_inputs(encoder_out, encoder_out_lens, beam_log_probs,
                        beam_log_probs_idx, vocabulary, sos, eos,
                        decoder_len):
    """Build the decoder model's input feed from encoder outputs.

    Runs CTC prefix beam search, filters and pads the n-best list to the
    beam size, and packs the hypotheses (forward and reversed) with
    ``sos``/``eos`` markers into fixed-size arrays.

    Returns:
        ``(feed, all_hyps)`` where ``feed`` is the dict of decoder inputs and
        ``all_hyps`` is the flat list of hypotheses (``beam_size`` per batch
        item, untruncated).
    """
    _, score_hyps = ctc_decoding(beam_log_probs, beam_log_probs_idx,
                                 encoder_out_lens, vocabulary,
                                 "attention_rescoring")
    ignore_id = -1
    beam_size = beam_log_probs.shape[-1]
    batch_size = beam_log_probs.shape[0]
    ctc_score, all_hyps = [], []
    for hyps in score_hyps:
        filtered_hyps = []
        kept_hyps = set()
        for score, hyp in hyps:
            hyp = tuple(hyp)
            if has_higher_scored_collapsed_repeat(hyp, kept_hyps):
                continue
            filtered_hyps.append((score, hyp))
            kept_hyps.add(hyp)
            if len(filtered_hyps) == beam_size:
                break
        hyps = filtered_hyps
        cur_len = len(hyps)
        if len(hyps) < beam_size:
            # Fill missing beams with a dummy hypothesis that can never win.
            hyps += (beam_size - cur_len) * [(-float("inf"), (0, ))]
        cur_ctc_score = []
        for hyp in hyps:
            cur_ctc_score.append(hyp[0])
            all_hyps.append(list(hyp[1]))
        ctc_score.append(cur_ctc_score)
    ctc_score = np.array(ctc_score, dtype=np.float32)

    # Two slots of decoder_len are reserved for sos and eos.
    max_len = decoder_len - 2
    hyps_pad_sos_eos = np.ones(
        (batch_size, beam_size, max_len + 2), dtype=np.int64) * ignore_id
    r_hyps_pad_sos_eos = np.ones(
        (batch_size, beam_size, max_len + 2), dtype=np.int64) * ignore_id
    hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32)
    k = 0
    for i in range(batch_size):
        for j in range(beam_size):
            cand = all_hyps[k][:max_len]
            length = len(cand) + 2
            hyps_pad_sos_eos[i][j][0:length] = [sos] + cand + [eos]
            r_hyps_pad_sos_eos[i][j][0:length] = [sos] + cand[::-1] + [eos]
            hyps_lens_sos[i][j] = len(cand) + 1
            k += 1

    # Force the time axis of encoder_out to exactly decoder_len frames.
    if decoder_len > encoder_out.shape[1]:
        encoder_out = np.pad(encoder_out,
                             [(0, 0),
                              (0, decoder_len - encoder_out.shape[1]),
                              (0, 0)],
                             mode="constant",
                             constant_values=0)
    elif decoder_len < encoder_out.shape[1]:
        encoder_out = encoder_out[:, :decoder_len, :]
    return {
        "encoder_out": encoder_out,
        "encoder_out_lens": np.full(batch_size,
                                    fill_value=decoder_len,
                                    dtype=np.int32),
        "hyps_pad_sos_eos": hyps_pad_sos_eos.astype(np.int32),
        "hyps_lens_sos": hyps_lens_sos,
        "r_hyps_pad_sos_eos": r_hyps_pad_sos_eos.astype(np.int32),
        "ctc_score": ctc_score,
    }, all_hyps


def make_offline_inputs(feats, seq_len):
    """Crop or zero-pad *feats* to *seq_len* frames along axis 1.

    Returns:
        A dict with ``"speech"`` (the fixed-length features) and
        ``"speech_lengths"`` (the frame count before padding).
    """
    feats = feats[:, :seq_len, :]
    speech_lengths = np.array([feats.shape[1]], dtype=np.int32)
    if feats.shape[1] < seq_len:
        feats = pad_array_along_axis(feats, pad_width=seq_len, axis=1)
    return {"speech": feats, "speech_lengths": speech_lengths}


def make_online_initial_state(configs,
                              batch_size=1,
                              decoding_chunk_size=16,
                              num_decoding_left_chunks=5):
    """Create zeroed streaming caches and the chunking parameters.

    Args:
        configs: Parsed configuration; ``configs["encoder_conf"]`` supplies
            ``output_size``, ``num_blocks``, ``attention_heads`` and
            optionally ``cnn_module_kernel``.
        batch_size: Leading dimension of every cache.
        decoding_chunk_size: Chunk size used to derive stride and window.
        num_decoding_left_chunks: Number of left chunks kept in the cache.

    Returns:
        ``(state, params)``: the cache arrays and a dict with
        ``batch_size``, ``context``, ``stride`` and ``decoding_window``.
    """
    # Fixed front-end geometry used by this script.
    subsampling = 4
    context = 7
    stride = subsampling * decoding_chunk_size
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    required_cache_size = decoding_chunk_size * num_decoding_left_chunks

    encoder_conf = configs["encoder_conf"]
    output_size = encoder_conf["output_size"]
    num_layers = encoder_conf["num_blocks"]
    cnn_module_kernel = encoder_conf.get("cnn_module_kernel", 1) - 1
    head = encoder_conf["attention_heads"]
    d_k = encoder_conf["output_size"] // head

    state = {
        "att_cache": np.zeros(
            (batch_size, num_layers, head, required_cache_size, d_k * 2),
            dtype=np.float32),
        "cnn_cache": np.zeros(
            (batch_size, num_layers, output_size, cnn_module_kernel),
            dtype=np.float32),
        "cache_mask": np.zeros((batch_size, 1, required_cache_size),
                               dtype=np.float32),
        "offset": np.zeros((batch_size, 1), dtype=np.int32),
    }
    params = {
        "batch_size": batch_size,
        "context": context,
        "stride": stride,
        "decoding_window": decoding_window,
    }
    return state, params


def make_online_encoder_input(feats, cur, params, state):
    """Assemble the input feed for one streaming encoder step.

    The chunk starts at frame *cur* and is zero-padded to
    ``params["decoding_window"]`` frames when the features run out.
    """
    batch_size = params["batch_size"]
    decoding_window = params["decoding_window"]
    end = min(cur + decoding_window, feats.shape[1])
    chunk_xs = feats[:, cur:end, :]
    if chunk_xs.shape[1] < decoding_window:
        chunk_xs = pad_array_along_axis(chunk_xs,
                                        pad_width=decoding_window,
                                        axis=1)
    chunk_xs = chunk_xs.astype(np.float32)
    chunk_lens = np.full(batch_size,
                         fill_value=chunk_xs.shape[1],
                         dtype=np.int32)
    return {
        "chunk_xs": chunk_xs,
        "chunk_lens": chunk_lens,
        "offset": state["offset"],
        "att_cache": state["att_cache"],
        "cnn_cache": state["cnn_cache"],
        "cache_mask": state["cache_mask"],
    }


def output_value(outputs, name):
    """Return ``outputs[name]``, falling back to ``outputs["r_" + name]``.

    Raises:
        KeyError: If neither key is present.
    """
    if name in outputs:
        return outputs[name]
    r_name = "r_" + name
    if r_name in outputs:
        return outputs[r_name]
    raise KeyError(name)


def update_online_state(state, outputs):
    """Copy the caches produced by an encoder step back into *state*."""
    state["offset"] = output_value(outputs, "offset")
    state["att_cache"] = output_value(outputs, "att_cache")
    state["cnn_cache"] = output_value(outputs, "cnn_cache")
    state["cache_mask"] = output_value(outputs, "cache_mask")


class AxModel:
    """Thin wrapper around an ``axengine`` inference session."""

    def __init__(self, path, provider="AxEngineExecutionProvider"):
        # Imported lazily so the module can be used without axengine.
        from axengine import InferenceSession
        self.session = InferenceSession(path, providers=[provider])
        self.output_names = [
            item.name for item in self.session.get_outputs()
        ]

    def run(self, input_feed):
        """Run the model and return its outputs keyed by output name."""
        output_values = self.session.run(self.output_names, input_feed)
        return dict(zip(self.output_names, output_values))


class WenetAXRunner:
    """End-to-end speech recognition runner over axengine models.

    Models are loaded on first use. ``transcribe`` is the main entry point.
    """

    def __init__(
            self,
            config_path,
            vocab_path,
            encoder_offline_path="axmodel/encoder_offline/encoder_offline.axmodel",
            encoder_online_path="axmodel/encoder_online/encoder_online.axmodel",
            decoder_path="axmodel/decoder/decoder.axmodel",
            offline_seq_len=1024,
            decoder_len=32,
            decoding_chunk_size=16,
            num_decoding_left_chunks=5,
            batch_size=1,
            provider="AxEngineExecutionProvider"):
        self.config_path = config_path
        self.vocab_path = vocab_path
        self.encoder_offline_path = encoder_offline_path
        self.encoder_online_path = encoder_online_path
        self.decoder_path = decoder_path
        self.offline_seq_len = offline_seq_len
        self.decoder_len = decoder_len
        self.decoding_chunk_size = decoding_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.batch_size = batch_size
        self.provider = provider
        self.configs = load_config(config_path)
        self.vocabulary, self.char_dict = load_vocab(vocab_path)
        # sos and eos share the last vocabulary id.
        self.eos = self.sos = len(self.char_dict) - 1
        self._offline_encoder = None
        self._online_encoder = None
        self._decoder = None

    @property
    def offline_encoder(self):
        """Offline encoder model, created on first access."""
        if self._offline_encoder is None:
            self._offline_encoder = AxModel(self.encoder_offline_path,
                                            self.provider)
        return self._offline_encoder

    @property
    def online_encoder(self):
        """Streaming encoder model, created on first access."""
        if self._online_encoder is None:
            self._online_encoder = AxModel(self.encoder_online_path,
                                           self.provider)
        return self._online_encoder

    @property
    def decoder(self):
        """Rescoring decoder model, created on first access."""
        if self._decoder is None:
            self._decoder = AxModel(self.decoder_path, self.provider)
        return self._decoder

    def compute_feats(self, audio_file):
        """Return fbank features for *audio_file* (module-level helper)."""
        return compute_feats(audio_file)

    def run_offline_encoder(self, feats):
        """Run the offline encoder and collect top-10 CTC candidates."""
        encoder_input = make_offline_inputs(feats, self.offline_seq_len)
        speech_lengths = encoder_input["speech_lengths"]
        outputs = self.offline_encoder.run(encoder_input)
        encoder_out_lens = outputs["encoder_out_lens"].astype(np.int32)
        # Recompute the valid output length from the unpadded input length
        # by applying the [2::2] slicing twice to a mask of ones.
        encoder_out_lens[0] = np.ones(
            [speech_lengths[0]], dtype=np.int32)[2::2][2::2].sum()
        beam_log_probs, beam_log_probs_idx = numpy_topk(
            outputs["ctc_log_probs"], k=10)
        return {
            "encoder_out": outputs["encoder_out"],
            "encoder_out_lens": encoder_out_lens,
            "ctc_log_probs": outputs["ctc_log_probs"],
            "beam_log_probs": beam_log_probs,
            "beam_log_probs_idx": beam_log_probs_idx,
        }

    def run_online_encoder(self, feats):
        """Run the streaming encoder chunk by chunk over *feats*.

        Raises:
            ValueError: If *feats* is too short to form a single chunk.
        """
        state, online_params = make_online_initial_state(
            self.configs, self.batch_size, self.decoding_chunk_size,
            self.num_decoding_left_chunks)
        encoder_out = []
        beam_log_probs = []
        beam_log_probs_idx = []
        num_frames = feats.shape[1]
        for cur in range(0, num_frames - online_params["context"] + 1,
                         online_params["stride"]):
            encoder_input = make_online_encoder_input(feats, cur,
                                                      online_params, state)
            outputs = self.online_encoder.run(encoder_input)
            update_online_state(state, outputs)
            encoder_out.append(outputs["chunk_out"])
            beam_log_probs.append(outputs["log_probs"])
            beam_log_probs_idx.append(outputs["log_probs_idx"].astype(
                np.int32))
        if not encoder_out:
            # np.concatenate would otherwise fail with an opaque message.
            raise ValueError(
                "audio is too short for streaming decoding: {} frames, "
                "need at least {}".format(num_frames,
                                          online_params["context"]))
        return {
            "encoder_out": np.concatenate(encoder_out, axis=1),
            "encoder_out_lens": np.full(
                self.batch_size,
                fill_value=sum(out.shape[1] for out in encoder_out),
                dtype=np.int32),
            "beam_log_probs": np.concatenate(beam_log_probs, axis=1),
            "beam_log_probs_idx": np.concatenate(beam_log_probs_idx, axis=1),
        }

    def ctc_decode(self, encoder_outputs, mode):
        """Decode *encoder_outputs* with ``ctc_decoding`` in the given mode."""
        return ctc_decoding(encoder_outputs["beam_log_probs"],
                            encoder_outputs["beam_log_probs_idx"],
                            encoder_outputs["encoder_out_lens"],
                            self.vocabulary, mode)

    def run_decoder(self, encoder_outputs):
        """Rescore the CTC n-best list with the decoder and return the text."""
        decoder_input, all_hyps = make_decoder_inputs(
            encoder_outputs["encoder_out"],
            encoder_outputs["encoder_out_lens"],
            encoder_outputs["beam_log_probs"],
            encoder_outputs["beam_log_probs_idx"],
            self.vocabulary,
            self.sos,
            self.eos,
            self.decoder_len,
        )
        best_index = self.decoder.run(decoder_input)["best_index"].astype(
            np.int32)
        beam_size = encoder_outputs["beam_log_probs"].shape[-1]
        num_processes = min(multiprocessing.cpu_count(), best_index.shape[0])
        best_sents = []
        k = 0
        for idx in best_index:
            # all_hyps is flat: beam_size consecutive entries per batch item.
            best_sents.append(all_hyps[k:k + beam_size][idx])
            k += beam_size
        hyps = map_batch(best_sents, self.vocabulary, num_processes)
        return "".join(hyps)

    def transcribe(self,
                   audio_file,
                   online=False,
                   mode="ctc_prefix_beam_search"):
        """Transcribe *audio_file* and return the recognised text.

        Args:
            audio_file: Path to a WAV file.
            online: Use the streaming encoder instead of the offline one.
            mode: ``"attention_rescoring"`` runs the decoder; any other
                value is passed to ``ctc_decode``.
        """
        feats = self.compute_feats(audio_file)
        if online:
            encoder_outputs = self.run_online_encoder(feats)
        else:
            encoder_outputs = self.run_offline_encoder(feats)
        if mode == "attention_rescoring":
            return self.run_decoder(encoder_outputs)
        hyps, _ = self.ctc_decode(encoder_outputs, mode)
        return "".join(hyps) if hyps else ""