| import multiprocessing |
| import wave |
|
|
| import numpy as np |
| import yaml |
|
|
|
|
| def _log_add(*values): |
| values = [value for value in values if value != -float("inf")] |
| if not values: |
| return -float("inf") |
| max_value = max(values) |
| return max_value + np.log(sum(np.exp(value - max_value) |
| for value in values)) |
|
|
|
|
| def _map_sentence(sent, vocabulary, greedy=False, blank_id=0): |
| mapped = [] |
| prev = None |
| for token in sent: |
| token = int(token) |
| if greedy and token == prev: |
| prev = token |
| continue |
| prev = token |
| if token == blank_id or token < 0 or token >= len(vocabulary): |
| continue |
| piece = vocabulary[token] |
| if piece.startswith("<") and piece.endswith(">"): |
| continue |
| mapped.append(piece) |
| return "".join(mapped) |
|
|
|
|
def map_batch(batch_sents, vocabulary, num_processes, greedy=False,
              blank_id=0):
    """Map each token-id sequence in ``batch_sents`` to text.

    ``num_processes`` is accepted for signature compatibility only; the
    sentences are mapped serially, in order, with ``_map_sentence``.

    Returns:
        A list with one string per input sequence.
    """
    del num_processes  # unused: mapping runs in the calling process
    texts = []
    for sentence in batch_sents:
        texts.append(_map_sentence(sentence, vocabulary, greedy, blank_id))
    return texts
|
|
|
|
def _ctc_prefix_beam_search(log_probs_seq, log_probs_idx, beam_size,
                            blank_id):
    """CTC prefix beam search over per-frame candidate lists.

    Args:
        log_probs_seq: one entry per frame; each entry is a sequence of
            candidate log-probabilities for that frame.
        log_probs_idx: one entry per frame; each entry holds the token ids
            matching ``log_probs_seq`` element-wise.
        beam_size: maximum number of prefixes kept after each frame.
        blank_id: token id treated as the CTC blank.

    Returns:
        A list of ``(log_score, prefix)`` pairs, best first, where ``prefix``
        is a tuple of token ids and ``log_score`` is the log-sum of the
        prefix's "ends in blank" and "ends in non-blank" scores.
    """
    # prefix -> (log P(prefix, ends in blank), log P(prefix, ends in non-blank))
    beam = {(): (0.0, -float("inf"))}
    for frame_probs, frame_ids in zip(log_probs_seq, log_probs_idx):
        frame_probs = np.asarray(frame_probs, dtype=np.float32)
        # Renormalise over the candidates present in this frame. Every path
        # consumes exactly one candidate per frame, so this shifts all path
        # scores by the same per-frame constant.
        frame_probs = frame_probs - _log_add(*frame_probs.tolist())
        next_beam = {}
        for prefix, (prob_blank, prob_non_blank) in beam.items():
            for prob, token in zip(frame_probs, frame_ids):
                token = int(token)
                prob = float(prob)
                next_prob_blank, next_prob_non_blank = next_beam.get(
                    prefix, (-float("inf"), -float("inf")))
                if token == blank_id:
                    # Blank: prefix text is unchanged and now ends in blank.
                    next_beam[prefix] = (
                        _log_add(next_prob_blank, prob_blank + prob,
                                 prob_non_blank + prob),
                        next_prob_non_blank,
                    )
                    continue

                last = prefix[-1] if prefix else None
                if token == last:
                    # Repeat without an intervening blank merges into the
                    # same prefix...
                    next_beam[prefix] = (
                        next_prob_blank,
                        _log_add(next_prob_non_blank, prob_non_blank + prob),
                    )
                    # ...while a repeat after a blank extends the prefix.
                    new_prefix = prefix + (token, )
                    nb_blank, nb_non_blank = next_beam.get(
                        new_prefix, (-float("inf"), -float("inf")))
                    next_beam[new_prefix] = (
                        nb_blank,
                        _log_add(nb_non_blank, prob_blank + prob),
                    )
                else:
                    # A different token always extends the prefix.
                    new_prefix = prefix + (token, )
                    nb_blank, nb_non_blank = next_beam.get(
                        new_prefix, (-float("inf"), -float("inf")))
                    next_beam[new_prefix] = (
                        nb_blank,
                        _log_add(nb_non_blank, prob_blank + prob,
                                 prob_non_blank + prob),
                    )
        # Keep the beam_size best prefixes by total score. sorted() is stable,
        # so ties keep next_beam's insertion order.
        beam = dict(sorted(
            next_beam.items(),
            key=lambda item: _log_add(item[1][0], item[1][1]),
            reverse=True)[:beam_size])
    return [(_log_add(prob_blank, prob_non_blank), prefix)
            for prefix, (prob_blank, prob_non_blank) in sorted(
                beam.items(),
                key=lambda item: _log_add(item[1][0], item[1][1]),
                reverse=True)]
|
|
|
|
def ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                  batch_log_probs_idx,
                                  batch_root_trie,
                                  batch_start,
                                  beam_size,
                                  num_processes,
                                  blank_id=0,
                                  space_id=-1,
                                  cutoff_prob=0.999,
                                  ext_scorer=None):
    """Run ``_ctc_prefix_beam_search`` on every utterance of a batch.

    The trie, start flags, process count, space id, cutoff probability and
    external scorer are accepted for signature compatibility and ignored.

    Returns:
        One list of ``(log_score, prefix)`` pairs (best first) per utterance.
    """
    del batch_root_trie, batch_start, num_processes, space_id
    del cutoff_prob, ext_scorer
    results = []
    for frame_probs, frame_ids in zip(batch_log_probs_seq,
                                      batch_log_probs_idx):
        results.append(
            _ctc_prefix_beam_search(frame_probs, frame_ids, beam_size,
                                    blank_id))
    return results
|
|
|
|
def load_config(config_path):
    """Parse the YAML configuration file at ``config_path`` and return it.

    NOTE(review): this uses ``yaml.FullLoader``; if configs may come from
    untrusted sources, ``yaml.safe_load`` would be the safer choice.
    """
    with open(config_path, "r") as config_file:
        configs = yaml.load(config_file, Loader=yaml.FullLoader)
    return configs
|
|
|
|
def load_vocab(vocab_path):
    """Load a unit table with one ``<token> <id>`` pair per line.

    Blank lines are ignored.

    Args:
        vocab_path: path to the UTF-8 encoded unit table.

    Returns:
        ``(vocabulary, char_dict)`` where ``vocabulary`` lists the tokens in
        file order and ``char_dict`` maps each integer id to its token.
        ``vocabulary[i]`` is the token for id ``i`` only when the file lists
        ids 0..N-1 in order (assumed by the decoding helpers).

    Raises:
        ValueError: if a non-blank line does not contain exactly two fields,
            or the id field is not an integer.
    """
    vocabulary = []
    char_dict = {}
    # Unit tables commonly contain non-ASCII tokens (e.g. CJK characters), so
    # decode as UTF-8 instead of relying on the platform's locale encoding.
    with open(vocab_path, "r", encoding="utf-8") as fin:
        for line_number, line in enumerate(fin, start=1):
            fields = line.split()
            if not fields:
                continue
            if len(fields) != 2:
                raise ValueError(
                    f"{vocab_path}:{line_number}: expected '<token> <id>', "
                    f"got {line.rstrip()!r}")
            token, token_id = fields
            char_dict[int(token_id)] = token
            vocabulary.append(token)
    return vocabulary, char_dict
|
|
|
|
def load_wav(audio_file):
    """Read a PCM WAV file into a mono float32 waveform at 16-bit scale.

    Samples of every supported width (8, 16, 24 and 32 bit) are rescaled to
    the signed 16-bit amplitude range, so the same recording yields the same
    waveform (and features) regardless of bit depth. 16-bit samples are
    returned unscaled. Multi-channel audio is averaged to mono.

    Args:
        audio_file: path or binary file object accepted by ``wave.open``.

    Returns:
        ``(waveform, sample_rate)``: 1-D float32 array and the rate in Hz.

    Raises:
        ValueError: for an unsupported sample width.
    """
    with wave.open(audio_file, "rb") as wav_file:
        sample_rate = wav_file.getframerate()
        num_channels = wav_file.getnchannels()
        sample_width = wav_file.getsampwidth()
        frames = wav_file.readframes(wav_file.getnframes())

    if sample_width == 1:
        # 8-bit PCM is unsigned with 128 as the zero level; x256 -> 16-bit.
        waveform = np.frombuffer(frames, dtype=np.uint8).astype(np.float32)
        waveform = (waveform - 128.0) * 256.0
    elif sample_width == 2:
        waveform = np.frombuffer(frames, dtype="<i2").astype(np.float32)
    elif sample_width == 3:
        # Little-endian signed 24-bit: assemble, sign-extend, then /256.
        raw = np.frombuffer(frames, dtype=np.uint8).reshape(-1, 3)
        raw = raw.astype(np.int32)
        samples = raw[:, 0] | (raw[:, 1] << 8) | (raw[:, 2] << 16)
        samples = np.where(samples >= 1 << 23, samples - (1 << 24), samples)
        waveform = (samples / 256.0).astype(np.float32)
    elif sample_width == 4:
        # Divide in float64 to avoid float32 rounding of large int32 values.
        waveform = (np.frombuffer(frames, dtype="<i4") /
                    65536.0).astype(np.float32)
    else:
        raise ValueError(f"Unsupported wav sample width: {sample_width}")

    if num_channels > 1:
        waveform = waveform.reshape(-1, num_channels).mean(axis=1)
    return waveform, sample_rate
|
|
|
|
def resample_linear(waveform, orig_sr, target_sr):
    """Resample a 1-D waveform by linear interpolation.

    The input is returned unchanged (same object) when the rates match or
    when the resampled length would be at most one sample. Otherwise the
    first and last input samples are mapped onto the first and last output
    samples, and the result is float32. No anti-aliasing filter is applied.
    """
    if orig_sr == target_sr:
        return waveform
    num_samples = waveform.shape[0]
    seconds = num_samples / float(orig_sr)
    new_length = int(round(seconds * target_sr))
    if new_length <= 1:
        return waveform
    query_points = np.linspace(0, num_samples - 1, new_length)
    sample_points = np.arange(num_samples)
    resampled = np.interp(query_points, sample_points, waveform)
    return resampled.astype(np.float32)
|
|
|
|
def hz_to_mel(freq):
    """Convert frequency in Hz to mel: ``1127 * ln(1 + f / 700)``.

    Accepts scalars or NumPy arrays.
    """
    ratio = freq / 700.0
    return np.log1p(ratio) * 1127.0
|
|
|
|
def mel_to_hz(mel):
    """Convert mel back to Hz: ``700 * (exp(m / 1127) - 1)``.

    Inverse of ``hz_to_mel``; accepts scalars or NumPy arrays.
    """
    scaled = mel / 1127.0
    return np.expm1(scaled) * 700.0
|
|
|
|
def mel_filterbank(num_mel_bins, n_fft, sample_rate, low_freq=20.0,
                   high_freq=0.0):
    """Build Kaldi-style triangular mel filters for a real FFT.

    Each filter is a triangle on the mel axis (``hz_to_mel``), with equally
    spaced edges between ``low_freq`` and ``high_freq``. Each filter is
    evaluated at the centre frequency of every FFT bin, as Kaldi's
    ``MelBanks`` / ``torchaudio.compliance.kaldi.get_mel_banks`` do, instead
    of being rasterised onto quantised bin indices. The Nyquist bin gets zero
    weight, as in Kaldi.

    Args:
        num_mel_bins: number of filters.
        n_fft: FFT size; the result has ``n_fft // 2 + 1`` columns so it
            lines up with ``np.fft.rfft(..., n=n_fft)``.
        sample_rate: sampling rate in Hz.
        low_freq: lower edge of the first filter in Hz.
        high_freq: upper edge of the last filter in Hz. Values <= 0 are
            offsets from the Nyquist frequency (Kaldi convention), so the
            default 0.0 means Nyquist.

    Returns:
        float32 array of shape ``(num_mel_bins, n_fft // 2 + 1)``.
    """
    nyquist = 0.5 * sample_rate
    if high_freq <= 0.0:
        high_freq += nyquist
    mel_low = hz_to_mel(low_freq)
    mel_high = hz_to_mel(high_freq)
    mel_delta = (mel_high - mel_low) / (num_mel_bins + 1)

    # Filter edges on the mel axis, one row per filter.
    filter_index = np.arange(num_mel_bins, dtype=np.float64)[:, np.newaxis]
    left_mel = mel_low + filter_index * mel_delta
    center_mel = mel_low + (filter_index + 1.0) * mel_delta
    right_mel = mel_low + (filter_index + 2.0) * mel_delta

    # Mel value of each FFT bin's centre frequency, excluding Nyquist.
    num_fft_bins = n_fft // 2
    bin_hz = np.arange(num_fft_bins, dtype=np.float64) * (sample_rate /
                                                          float(n_fft))
    bin_mel = hz_to_mel(bin_hz)[np.newaxis, :]

    up_slope = (bin_mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - bin_mel) / (right_mel - center_mel)
    fbanks = np.zeros((num_mel_bins, n_fft // 2 + 1), dtype=np.float32)
    fbanks[:, :num_fft_bins] = np.maximum(0.0, np.minimum(up_slope,
                                                          down_slope))
    return fbanks
|
|
|
|
def numpy_fbank(waveform,
                sample_rate=16000,
                num_mel_bins=80,
                frame_length=25,
                frame_shift=10,
                preemphasis_coefficient=0.97,
                remove_dc_offset=True):
    """Compute Kaldi-style log mel filterbank features with NumPy.

    Intended to match the default pipeline of Kaldi ``compute-fbank-feats``
    and ``torchaudio.compliance.kaldi.fbank`` with no dither and
    ``snip_edges=True``:
      1. per-frame DC-offset removal;
      2. pre-emphasis;
      3. Povey window (Hann ** 0.85);
      4. zero-padding to a power-of-two FFT;
      5. power spectrum;
      6. mel filterbank;
      7. natural log floored at float32 eps.
    Unlike Kaldi, input shorter than one frame is zero-padded to a single
    frame instead of producing no frames.

    Args:
        waveform: 1-D array of samples (assumed 16-bit amplitude scale).
        sample_rate: sampling rate in Hz.
        num_mel_bins: number of mel filters.
        frame_length: frame length in milliseconds.
        frame_shift: hop between frames in milliseconds.
        preemphasis_coefficient: pre-emphasis factor; 0.0 disables it.
        remove_dc_offset: subtract each frame's mean before pre-emphasis.

    Returns:
        float32 array of shape ``(num_frames, num_mel_bins)``.
    """
    # Work in float32 so the in-place frame operations below accept any
    # numeric input (integer arrays used to fail on `frames *= window`).
    waveform = np.asarray(waveform, dtype=np.float32)
    frame_size = int(round(sample_rate * frame_length / 1000.0))
    frame_step = int(round(sample_rate * frame_shift / 1000.0))
    if waveform.shape[0] < frame_size:
        waveform = np.pad(waveform, (0, frame_size - waveform.shape[0]))
    num_frames = 1 + (waveform.shape[0] - frame_size) // frame_step
    frames = np.lib.stride_tricks.as_strided(
        waveform,
        shape=(num_frames, frame_size),
        strides=(waveform.strides[0] * frame_step, waveform.strides[0]),
    ).copy()

    if remove_dc_offset:
        frames -= frames.mean(axis=1, keepdims=True)
    if preemphasis_coefficient != 0.0:
        # Kaldi replicates the first sample as its own predecessor.
        previous = np.concatenate([frames[:, :1], frames[:, :-1]], axis=1)
        frames -= preemphasis_coefficient * previous
    povey_window = np.power(np.hanning(frame_size), 0.85).astype(np.float32)
    frames *= povey_window

    n_fft = 1
    while n_fft < frame_size:
        n_fft <<= 1
    power = np.abs(np.fft.rfft(frames, n=n_fft))**2
    fbanks = mel_filterbank(num_mel_bins, n_fft, sample_rate)
    mel_energies = np.maximum(np.dot(power, fbanks.T),
                              np.finfo(np.float32).eps)
    return np.log(mel_energies).astype(np.float32)
|
|
|
|
def compute_feats(audio_file, sr=16000, num_mel_bins=80):
    """Load a WAV file and return log-mel features with a batch axis.

    The audio is loaded with ``load_wav``, linearly resampled to ``sr`` and
    passed through ``numpy_fbank``.

    Args:
        audio_file: path or file object accepted by ``load_wav``.
        sr: target sampling rate in Hz.
        num_mel_bins: number of mel bins per frame.

    Returns:
        float32 array of shape ``(1, num_frames, num_mel_bins)``.
    """
    waveform, sample_rate = load_wav(audio_file)
    waveform = resample_linear(waveform.astype(np.float32), sample_rate, sr)
    feats = numpy_fbank(waveform, sample_rate=sr, num_mel_bins=num_mel_bins)
    # Add the batch axis without re-stating the feature width.
    return feats[np.newaxis, :, :]
|
|
|
|
def pad_array_along_axis(array, pad_width, axis, mode="constant", **kwargs):
    """Pad ``array`` at the end of ``axis`` up to length ``pad_width``.

    Returns ``array`` itself when it is already at least that long; otherwise
    returns ``np.pad`` output (``mode`` and ``**kwargs`` are forwarded).
    """
    current_length = array.shape[axis]
    if current_length >= pad_width:
        return array
    widths = [(0, 0) for _ in range(array.ndim)]
    widths[axis] = (0, pad_width - current_length)
    return np.pad(array, pad_width=widths, mode=mode, **kwargs)
|
|
|
|
def numpy_topk(array, k, axis=-1, largest=True):
    """Return the ``k`` largest (or smallest) entries along ``axis``, sorted.

    Args:
        array: array-like input.
        k: number of entries to return; ``0 <= k <= array.shape[axis]``.
        axis: axis to select along.
        largest: select the largest values (descending order) when true,
            otherwise the smallest (ascending order).

    Returns:
        ``(values, indices)``, both with ``array``'s shape except that
        ``axis`` has length ``k``.
    """
    array = np.asarray(array)
    if largest:
        partitioned_indices = np.argpartition(array, -k, axis=axis)
        topk_indices = np.take(partitioned_indices, range(-k, 0), axis=axis)
    else:
        # kth = k - 1 puts the k smallest entries in the first k slots and,
        # unlike kth = k, stays in bounds when k equals the axis length.
        partitioned_indices = np.argpartition(array, k - 1, axis=axis)
        topk_indices = np.take(partitioned_indices, range(0, k), axis=axis)

    topk_values = np.take_along_axis(array, topk_indices, axis=axis)
    order = np.argsort(topk_values, axis=axis)
    if largest:
        order = np.flip(order, axis=axis)
    sorted_topk_values = np.take_along_axis(topk_values, order, axis=axis)
    sorted_topk_indices = np.take_along_axis(topk_indices, order, axis=axis)
    return sorted_topk_values, sorted_topk_indices
|
|
|
|
def ctc_decoding(beam_log_probs,
                 beam_log_probs_idx,
                 encoder_out_lens,
                 vocabulary,
                 mode="ctc_prefix_beam_search"):
    """Decode per-frame top-k CTC outputs.

    Args:
        beam_log_probs: array of candidate log-probabilities indexed as
            ``[batch, frame, candidate]``.
        beam_log_probs_idx: token ids matching ``beam_log_probs``.
        encoder_out_lens: number of valid frames per utterance.
        vocabulary: list mapping token id -> text piece.
        mode: ``"ctc_greedy_search"``, ``"ctc_prefix_beam_search"`` or
            ``"attention_rescoring"``.

    Returns:
        ``(hyps, score_hyps)``:
        - ``hyps``: one decoded string per utterance, filled for the greedy
          and prefix-beam modes.
        - ``score_hyps``: per-utterance lists of ``(log_score, token_tuple)``
          from prefix beam search, filled for the prefix-beam and
          attention-rescoring modes.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes (previously
            an unknown mode silently produced empty results).
    """
    if mode not in ("ctc_greedy_search", "ctc_prefix_beam_search",
                    "attention_rescoring"):
        raise ValueError(f"Unsupported CTC decoding mode: {mode!r}")

    beam_size = beam_log_probs.shape[-1]
    batch_size = beam_log_probs.shape[0]
    num_processes = min(multiprocessing.cpu_count(), batch_size)
    hyps = []
    score_hyps = []

    if mode == "ctc_greedy_search":
        # Candidate 0 is taken as the frame's best token, which assumes the
        # candidates are sorted best-first (numpy_topk sorts them that way).
        best_ids = beam_log_probs_idx[:, :, 0]
        batch_sents = [seq[0:encoder_out_lens[idx]].tolist()
                       for idx, seq in enumerate(best_ids)]
        hyps = map_batch(batch_sents, vocabulary, num_processes, True, 0)
        return hyps, score_hyps

    # Prefix beam search, trimmed to each utterance's valid frames.
    batch_log_probs_seq = []
    batch_log_probs_ids = []
    for probs_row, ids_row, num_valid in zip(beam_log_probs.tolist(),
                                             beam_log_probs_idx.tolist(),
                                             encoder_out_lens.tolist()):
        batch_log_probs_seq.append(probs_row[0:num_valid])
        batch_log_probs_ids.append(ids_row[0:num_valid])
    num_utts = len(batch_log_probs_seq)
    score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                               batch_log_probs_ids,
                                               [None] * num_utts,
                                               [True] * num_utts,
                                               beam_size,
                                               num_processes,
                                               0, -2, 0.99999)
    if mode == "ctc_prefix_beam_search":
        best_prefixes = [cand_hyps[0][1] for cand_hyps in score_hyps]
        hyps = map_batch(best_prefixes, vocabulary, num_processes, False, 0)
    return hyps, score_hyps
|
|
|
|
def has_higher_scored_collapsed_repeat(hyp, kept_hyps):
    """Tell whether removing one of ``hyp``'s repeated tokens hits ``kept_hyps``.

    For every position whose token equals its predecessor, that token is
    dropped and the shortened tuple is looked up in ``kept_hyps``. ``hyp``
    must be a tuple and ``kept_hyps`` a container of tuples.
    """
    return any(
        hyp[:pos] + hyp[pos + 1:] in kept_hyps
        for pos in range(1, len(hyp))
        if hyp[pos] == hyp[pos - 1])
|
|
|
|
def make_decoder_inputs(encoder_out,
                        encoder_out_lens,
                        beam_log_probs,
                        beam_log_probs_idx,
                        vocabulary,
                        sos,
                        eos,
                        decoder_len):
    """Build attention-rescoring decoder inputs from CTC prefix beam search.

    Steps:
      1. Run ``ctc_decoding`` in ``"attention_rescoring"`` mode.
      2. Drop hypotheses for which ``has_higher_scored_collapsed_repeat``
         finds an already-kept variant.
      3. Keep at most ``beam_size`` hypotheses per utterance, padding short
         lists with ``(-inf, (0,))`` placeholders.
      4. Truncate token sequences to ``decoder_len - 2`` ids and write them as
         ``[sos] + ids + [eos]`` (and with the ids reversed for the
         right-to-left input) into -1 padded arrays.
      5. Zero-pad or truncate ``encoder_out`` along axis 1 to ``decoder_len``
         frames.

    NOTE(review): ``encoder_out_lens`` in the result is ``decoder_len`` for
    every utterance, not the true encoder length, so the decoder also sees
    padded or truncated frames as valid. Presumably this matches a
    fixed-shape compiled decoder; confirm against the model export.

    Returns:
        ``(decoder_input, all_hyps)`` where ``decoder_input`` is the feed
        dict for the decoder model and ``all_hyps`` is the flat list
        (utterance-major, ``beam_size`` per utterance) of untruncated token
        lists.
    """
    _, score_hyps = ctc_decoding(beam_log_probs, beam_log_probs_idx,
                                 encoder_out_lens, vocabulary,
                                 "attention_rescoring")
    ignore_id = -1
    beam_size = beam_log_probs.shape[-1]
    batch_size = beam_log_probs.shape[0]

    ctc_rows = []
    all_hyps = []
    for candidates in score_hyps:
        kept = []
        seen = set()
        for score, tokens in candidates:
            tokens = tuple(tokens)
            # Candidates arrive best-first, so an already-kept collapsed
            # variant outranks this one.
            if has_higher_scored_collapsed_repeat(tokens, seen):
                continue
            kept.append((score, tokens))
            seen.add(tokens)
            if len(kept) == beam_size:
                break
        shortfall = beam_size - len(kept)
        if shortfall > 0:
            kept.extend([(-float("inf"), (0,))] * shortfall)
        ctc_rows.append([entry_score for entry_score, _ in kept])
        all_hyps.extend([list(entry_tokens) for _, entry_tokens in kept])
    ctc_score = np.array(ctc_rows, dtype=np.float32)

    max_len = decoder_len - 2
    token_shape = (batch_size, beam_size, max_len + 2)
    hyps_pad_sos_eos = np.full(token_shape, ignore_id, dtype=np.int64)
    r_hyps_pad_sos_eos = np.full(token_shape, ignore_id, dtype=np.int64)
    hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32)
    for flat_index in range(batch_size * beam_size):
        row, col = divmod(flat_index, beam_size)
        tokens = all_hyps[flat_index][:max_len]
        end = len(tokens) + 2
        hyps_pad_sos_eos[row, col, 0:end] = [sos] + tokens + [eos]
        r_hyps_pad_sos_eos[row, col, 0:end] = [sos] + tokens[::-1] + [eos]
        hyps_lens_sos[row, col] = len(tokens) + 1

    num_frames = encoder_out.shape[1]
    if decoder_len > num_frames:
        encoder_out = np.pad(encoder_out,
                             [(0, 0), (0, decoder_len - num_frames), (0, 0)],
                             mode="constant",
                             constant_values=0)
    elif decoder_len < num_frames:
        encoder_out = encoder_out[:, :decoder_len, :]

    decoder_input = {
        "encoder_out": encoder_out,
        "encoder_out_lens": np.full(batch_size,
                                    fill_value=decoder_len,
                                    dtype=np.int32),
        "hyps_pad_sos_eos": hyps_pad_sos_eos.astype(np.int32),
        "hyps_lens_sos": hyps_lens_sos,
        "r_hyps_pad_sos_eos": r_hyps_pad_sos_eos.astype(np.int32),
        "ctc_score": ctc_score,
    }
    return decoder_input, all_hyps
|
|
|
|
def make_offline_inputs(feats, seq_len):
    """Build the fixed-length offline encoder feed.

    Features are cast to float32, truncated to ``seq_len`` frames along
    axis 1 (frames beyond ``seq_len`` are dropped) and zero-padded up to
    ``seq_len`` when shorter.

    Args:
        feats: features shaped ``(batch, frames, dims)``.
        seq_len: frame count the compiled encoder expects.

    Returns:
        ``{"speech": float32 (batch, seq_len, dims),
        "speech_lengths": int32 (batch,)}``, where every length is the number
        of real (unpadded) frames.
    """
    feats = np.asarray(feats, dtype=np.float32)[:, :seq_len, :]
    valid_frames = feats.shape[1]
    # One length per batch entry (previously always shape (1,)).
    speech_lengths = np.full(feats.shape[0], fill_value=valid_frames,
                             dtype=np.int32)
    if valid_frames < seq_len:
        feats = pad_array_along_axis(feats, pad_width=seq_len, axis=1)
    return {"speech": feats, "speech_lengths": speech_lengths}
|
|
|
|
def make_online_initial_state(configs,
                              batch_size=1,
                              decoding_chunk_size=16,
                              num_decoding_left_chunks=5):
    """Create zeroed caches and chunking parameters for streaming encoding.

    Args:
        configs: parsed model config. ``configs["encoder_conf"]`` must
            provide ``output_size``, ``num_blocks`` and ``attention_heads``.
            ``cnn_module_kernel`` (default 1) and ``input_layer`` (default
            ``"conv2d"``) are optional.
        batch_size: leading dimension of every state array.
        decoding_chunk_size: encoder output frames per chunk.
        num_decoding_left_chunks: number of past chunks kept in the
            attention cache.

    Returns:
        ``(state, params)``:
        - ``state``: zero arrays ``att_cache``, ``cnn_cache``,
          ``cache_mask`` and ``offset``.
        - ``params``: ``batch_size``, ``context``, ``stride`` (input frames
          between chunks) and ``decoding_window`` (input frames per chunk).
    """
    encoder_conf = configs["encoder_conf"]
    # (subsampling rate, right_context + 1) of WeNet's Conv2dSubsampling
    # input layers. Unknown layers fall back to the conv2d values, which
    # were previously hard-coded for every model.
    subsampling_by_layer = {
        "conv2d": (4, 7),
        "conv2d6": (6, 11),
        "conv2d8": (8, 15),
    }
    subsampling, context = subsampling_by_layer.get(
        encoder_conf.get("input_layer", "conv2d"), (4, 7))
    stride = subsampling * decoding_chunk_size
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    required_cache_size = decoding_chunk_size * num_decoding_left_chunks

    output_size = encoder_conf["output_size"]
    num_layers = encoder_conf["num_blocks"]
    cnn_module_kernel = encoder_conf.get("cnn_module_kernel", 1) - 1
    head = encoder_conf["attention_heads"]
    d_k = output_size // head

    state = {
        # Last axis holds keys and values concatenated, hence d_k * 2.
        "att_cache": np.zeros((batch_size, num_layers, head,
                               required_cache_size, d_k * 2),
                              dtype=np.float32),
        "cnn_cache": np.zeros((batch_size, num_layers, output_size,
                               cnn_module_kernel),
                              dtype=np.float32),
        "cache_mask": np.zeros((batch_size, 1, required_cache_size),
                               dtype=np.float32),
        "offset": np.zeros((batch_size, 1), dtype=np.int32),
    }
    params = {
        "batch_size": batch_size,
        "context": context,
        "stride": stride,
        "decoding_window": decoding_window,
    }
    return state, params
|
|
|
|
def make_online_encoder_input(feats, cur, params, state):
    """Build the online encoder feed for the chunk starting at frame ``cur``.

    Takes ``params["decoding_window"]`` frames from ``feats`` along axis 1,
    zero-pads a short final chunk to that length and casts it to float32.
    The current caches are passed through from ``state``.

    NOTE(review): ``chunk_lens`` reports the padded window length, so padded
    frames are presented as valid; confirm whether the model uses it for
    masking.
    """
    batch = params["batch_size"]
    window = params["decoding_window"]
    stop = min(cur + window, feats.shape[1])
    chunk = feats[:, cur:stop, :]
    if chunk.shape[1] < window:
        chunk = pad_array_along_axis(chunk, pad_width=window, axis=1)
    chunk = chunk.astype(np.float32)
    lengths = np.full(batch, fill_value=chunk.shape[1], dtype=np.int32)
    encoder_feed = {"chunk_xs": chunk, "chunk_lens": lengths}
    for cache_key in ("offset", "att_cache", "cnn_cache", "cache_mask"):
        encoder_feed[cache_key] = state[cache_key]
    return encoder_feed
|
|
|
|
def output_value(outputs, name):
    """Return ``outputs[name]``, falling back to ``outputs["r_" + name]``.

    Raises:
        KeyError: (with ``name``) when neither key is present.
    """
    for candidate in (name, "r_" + name):
        if candidate in outputs:
            return outputs[candidate]
    raise KeyError(name)
|
|
|
|
def update_online_state(state, outputs):
    """Copy the next-chunk caches from encoder ``outputs`` into ``state``.

    ``state`` is updated in place for ``offset``, ``att_cache``,
    ``cnn_cache`` and ``cache_mask``. Each value is looked up via
    ``output_value``, so an ``r_``-prefixed output name is also accepted.
    """
    for cache_key in ("offset", "att_cache", "cnn_cache", "cache_mask"):
        state[cache_key] = output_value(outputs, cache_key)
|
|
|
|
class AxModel:
    """Wraps an ``axengine.InferenceSession`` and returns outputs by name."""

    def __init__(self, path, provider="AxEngineExecutionProvider"):
        """Load the model at ``path`` with the given execution provider."""
        # Imported here rather than at module level, so importing this module
        # does not require axengine.
        from axengine import InferenceSession

        self.session = InferenceSession(path, providers=[provider])
        self.output_names = [
            output.name for output in self.session.get_outputs()
        ]

    def run(self, input_feed):
        """Run the model on ``input_feed``; return ``{output_name: value}``."""
        values = self.session.run(self.output_names, input_feed)
        return {name: value for name, value in zip(self.output_names, values)}
|
|
|
|
class WenetAXRunner:
    """WeNet speech recognizer running its encoder/decoder on AXEngine.

    Log-mel features are computed with NumPy. They are encoded either by an
    offline model (whole utterance, fixed ``offline_seq_len`` frames) or by
    a chunk-wise online model with caches. Decoding uses CTC greedy or
    prefix-beam search, optionally followed by attention rescoring with the
    decoder model. Models are loaded on first use.
    """

    def __init__(self,
                 config_path,
                 vocab_path,
                 encoder_offline_path="axmodel/encoder_offline/encoder_offline.axmodel",
                 encoder_online_path="axmodel/encoder_online/encoder_online.axmodel",
                 decoder_path="axmodel/decoder/decoder.axmodel",
                 offline_seq_len=1024,
                 decoder_len=32,
                 decoding_chunk_size=16,
                 num_decoding_left_chunks=5,
                 batch_size=1,
                 provider="AxEngineExecutionProvider",
                 beam_size=10):
        """Read the config and vocabulary; model paths are loaded lazily.

        Args:
            config_path: YAML model config (see ``load_config``).
            vocab_path: unit table (see ``load_vocab``).
            encoder_offline_path: offline encoder model path.
            encoder_online_path: online encoder model path.
            decoder_path: attention decoder model path.
            offline_seq_len: feature frames fed to the offline encoder.
            decoder_len: token/frame length expected by the decoder model.
            decoding_chunk_size: encoder frames per online chunk.
            num_decoding_left_chunks: online attention-cache length in chunks.
            batch_size: batch size used for the online state.
            provider: AXEngine execution provider name.
            beam_size: CTC candidates kept per frame from the offline
                encoder's ``ctc_log_probs`` (previously fixed at 10).
                Presumably must match the beam the decoder model was
                compiled for when using attention rescoring; confirm.
        """
        self.config_path = config_path
        self.vocab_path = vocab_path
        self.encoder_offline_path = encoder_offline_path
        self.encoder_online_path = encoder_online_path
        self.decoder_path = decoder_path
        self.offline_seq_len = offline_seq_len
        self.decoder_len = decoder_len
        self.decoding_chunk_size = decoding_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.batch_size = batch_size
        self.provider = provider
        self.beam_size = beam_size

        self.configs = load_config(config_path)
        self.vocabulary, self.char_dict = load_vocab(vocab_path)
        # sos and eos share the last id of the unit table.
        self.eos = self.sos = len(self.char_dict) - 1

        self._offline_encoder = None
        self._online_encoder = None
        self._decoder = None

    @property
    def offline_encoder(self):
        """Offline encoder model, created on first access."""
        if self._offline_encoder is None:
            self._offline_encoder = AxModel(self.encoder_offline_path,
                                            self.provider)
        return self._offline_encoder

    @property
    def online_encoder(self):
        """Online (chunk-wise) encoder model, created on first access."""
        if self._online_encoder is None:
            self._online_encoder = AxModel(self.encoder_online_path,
                                           self.provider)
        return self._online_encoder

    @property
    def decoder(self):
        """Attention-rescoring decoder model, created on first access."""
        if self._decoder is None:
            self._decoder = AxModel(self.decoder_path, self.provider)
        return self._decoder

    @staticmethod
    def _subsampled_length(num_frames):
        """Return the frame count after two kernel-3, stride-2 reductions.

        Computes ``((T - 1) // 2 - 1) // 2`` clamped at 0, for a scalar or
        array ``T``. Presumably mirrors the encoder's 4x conv2d subsampling;
        confirm for other input layers.
        """
        num_frames = np.asarray(num_frames, dtype=np.int64)
        return np.maximum(((num_frames - 1) // 2 - 1) // 2, 0)

    def compute_feats(self, audio_file):
        """Return ``(1, frames, 80)`` log-mel features for ``audio_file``."""
        return compute_feats(audio_file)

    def run_offline_encoder(self, feats):
        """Encode a whole utterance with the offline encoder.

        Returns:
            Dict with ``encoder_out``, ``encoder_out_lens``,
            ``ctc_log_probs`` and the per-frame top-``beam_size`` CTC
            candidates ``beam_log_probs`` / ``beam_log_probs_idx``.
        """
        encoder_input = make_offline_inputs(feats, self.offline_seq_len)
        speech_lengths = encoder_input["speech_lengths"]
        outputs = self.offline_encoder.run(encoder_input)
        encoder_out_lens = outputs["encoder_out_lens"].astype(np.int32)
        # The model's reported length is overridden with one derived from the
        # unpadded frame count (batch entry 0 only). Presumably the compiled
        # model's length output is not reliable; confirm.
        encoder_out_lens[0] = self._subsampled_length(speech_lengths[0])
        beam_log_probs, beam_log_probs_idx = numpy_topk(
            outputs["ctc_log_probs"], k=self.beam_size)
        return {
            "encoder_out": outputs["encoder_out"],
            "encoder_out_lens": encoder_out_lens,
            "ctc_log_probs": outputs["ctc_log_probs"],
            "beam_log_probs": beam_log_probs,
            "beam_log_probs_idx": beam_log_probs_idx,
        }

    def run_online_encoder(self, feats):
        """Encode ``feats`` chunk by chunk with the cached online encoder.

        Returns:
            Dict with the concatenated ``encoder_out``, ``encoder_out_lens``,
            ``beam_log_probs`` and ``beam_log_probs_idx``. Input too short for
            a single chunk yields zero-frame arrays instead of an error.
        """
        state, online_params = make_online_initial_state(
            self.configs, self.batch_size, self.decoding_chunk_size,
            self.num_decoding_left_chunks)
        encoder_out = []
        beam_log_probs = []
        beam_log_probs_idx = []
        num_frames = feats.shape[1]

        for cur in range(0, num_frames - online_params["context"] + 1,
                         online_params["stride"]):
            encoder_input = make_online_encoder_input(feats, cur,
                                                      online_params, state)
            outputs = self.online_encoder.run(encoder_input)
            update_online_state(state, outputs)
            encoder_out.append(outputs["chunk_out"])
            beam_log_probs.append(outputs["log_probs"])
            beam_log_probs_idx.append(outputs["log_probs_idx"].astype(np.int32))

        if not encoder_out:
            # Fewer frames than one chunk's context: np.concatenate([]) would
            # raise, so return empty (zero-frame) outputs instead.
            output_size = self.configs["encoder_conf"]["output_size"]
            return {
                "encoder_out": np.zeros((self.batch_size, 0, output_size),
                                        dtype=np.float32),
                "encoder_out_lens": np.zeros(self.batch_size, dtype=np.int32),
                "beam_log_probs": np.zeros((self.batch_size, 0,
                                            self.beam_size),
                                           dtype=np.float32),
                "beam_log_probs_idx": np.zeros((self.batch_size, 0,
                                                self.beam_size),
                                               dtype=np.int32),
            }

        return {
            "encoder_out": np.concatenate(encoder_out, axis=1),
            "encoder_out_lens": np.full(self.batch_size,
                                        fill_value=sum(
                                            out.shape[1]
                                            for out in encoder_out),
                                        dtype=np.int32),
            "beam_log_probs": np.concatenate(beam_log_probs, axis=1),
            "beam_log_probs_idx": np.concatenate(beam_log_probs_idx, axis=1),
        }

    def ctc_decode(self, encoder_outputs, mode):
        """Run ``ctc_decoding`` on encoder outputs; return ``(hyps, score_hyps)``."""
        return ctc_decoding(encoder_outputs["beam_log_probs"],
                            encoder_outputs["beam_log_probs_idx"],
                            encoder_outputs["encoder_out_lens"],
                            self.vocabulary, mode)

    def run_decoder(self, encoder_outputs):
        """Rescore CTC hypotheses with the attention decoder.

        Returns:
            The decoder-selected hypothesis of every utterance as text,
            concatenated into one string.
        """
        decoder_input, all_hyps = make_decoder_inputs(
            encoder_outputs["encoder_out"],
            encoder_outputs["encoder_out_lens"],
            encoder_outputs["beam_log_probs"],
            encoder_outputs["beam_log_probs_idx"],
            self.vocabulary,
            self.sos,
            self.eos,
            self.decoder_len,
        )
        # Flatten so each element is one utterance's scalar index, even if
        # the model returns e.g. (batch, 1).
        best_index = self.decoder.run(decoder_input)["best_index"].astype(
            np.int32).reshape(-1)
        beam_size = encoder_outputs["beam_log_probs"].shape[-1]
        num_processes = min(multiprocessing.cpu_count(), best_index.shape[0])
        # all_hyps holds beam_size hypotheses per utterance, utterance-major.
        best_sents = [
            all_hyps[row * beam_size:(row + 1) * beam_size][int(idx)]
            for row, idx in enumerate(best_index)
        ]
        hyps = map_batch(best_sents, self.vocabulary, num_processes)
        return "".join(hyps)

    def transcribe(self,
                   audio_file,
                   online=False,
                   mode="ctc_prefix_beam_search"):
        """Transcribe ``audio_file`` and return the recognized text.

        Args:
            audio_file: WAV path or file object.
            online: use the chunk-wise online encoder instead of the offline
                one.
            mode: ``"ctc_greedy_search"``, ``"ctc_prefix_beam_search"`` or
                ``"attention_rescoring"``.
        """
        feats = self.compute_feats(audio_file)
        if online:
            encoder_outputs = self.run_online_encoder(feats)
        else:
            encoder_outputs = self.run_offline_encoder(feats)

        if mode == "attention_rescoring":
            return self.run_decoder(encoder_outputs)

        hyps, _ = self.ctc_decode(encoder_outputs, mode)
        return "".join(hyps)
|
|