WeNet / ax_common.py
yangrongzhao
Add provider, Update README, add RTF
78224f0
Raw
History Blame Contribute Delete
24.6 kB
import multiprocessing
import wave
import numpy as np
import yaml
def _log_add(*values):
values = [value for value in values if value != -float("inf")]
if not values:
return -float("inf")
max_value = max(values)
return max_value + np.log(sum(np.exp(value - max_value)
for value in values))
def _map_sentence(sent, vocabulary, greedy=False, blank_id=0):
mapped = []
prev = None
for token in sent:
token = int(token)
if greedy and token == prev:
prev = token
continue
prev = token
if token == blank_id or token < 0 or token >= len(vocabulary):
continue
piece = vocabulary[token]
if piece.startswith("<") and piece.endswith(">"):
continue
mapped.append(piece)
return "".join(mapped)
def map_batch(batch_sents, vocabulary, num_processes, greedy=False,
              blank_id=0):
    """Convert every id sequence in ``batch_sents`` to text.

    Returns a list with one string per input sequence, each produced by
    ``_map_sentence`` with the given ``greedy`` and ``blank_id`` settings.
    """
    # ``num_processes`` is accepted but not used; sentences are mapped
    # one after another.
    del num_processes
    texts = []
    for token_ids in batch_sents:
        texts.append(_map_sentence(token_ids, vocabulary, greedy, blank_id))
    return texts
def _ctc_prefix_beam_search(log_probs_seq, log_probs_idx, beam_size,
                            blank_id):
    """Run CTC prefix beam search over per-frame candidate tokens.

    Args:
        log_probs_seq: per frame, the log-probabilities of the candidates.
        log_probs_idx: per frame, the token ids matching ``log_probs_seq``.
        beam_size: number of prefixes kept after each frame.
        blank_id: token id treated as the CTC blank.

    Returns:
        A list of ``(score, prefix)`` pairs sorted by descending score,
        where ``prefix`` is a tuple of token ids and ``score`` is the
        log-sum of its blank-ending and non-blank-ending scores.
    """
    neg_inf = -float("inf")
    unseen = (neg_inf, neg_inf)

    def total_score(entry):
        # ``entry`` is ``(prefix, (score_blank, score_non_blank))``.
        return _log_add(entry[1][0], entry[1][1])

    # Each prefix carries (score ending in blank, score ending in non-blank).
    hyps = {(): (0.0, neg_inf)}
    for step_scores, step_tokens in zip(log_probs_seq, log_probs_idx):
        step_scores = np.asarray(step_scores, dtype=np.float32)
        # Renormalise over the candidates that are present for this frame.
        step_scores = step_scores - _log_add(*step_scores.tolist())
        candidates = {}
        for prefix, (p_blank, p_non_blank) in hyps.items():
            tail = prefix[-1] if prefix else None
            for score, tok in zip(step_scores, step_tokens):
                tok = int(tok)
                score = float(score)
                cur_blank, cur_non_blank = candidates.get(prefix, unseen)
                if tok == blank_id:
                    # A blank keeps the prefix and moves mass to "blank".
                    candidates[prefix] = (
                        _log_add(cur_blank, p_blank + score,
                                 p_non_blank + score),
                        cur_non_blank,
                    )
                    continue
                extended = prefix + (tok, )
                if tok == tail:
                    # Repeat without a separating blank: prefix unchanged.
                    candidates[prefix] = (
                        cur_blank,
                        _log_add(cur_non_blank, p_non_blank + score),
                    )
                    # Repeat after a blank: the prefix grows by one token.
                    ext_blank, ext_non_blank = candidates.get(extended,
                                                              unseen)
                    candidates[extended] = (
                        ext_blank,
                        _log_add(ext_non_blank, p_blank + score),
                    )
                else:
                    ext_blank, ext_non_blank = candidates.get(extended,
                                                              unseen)
                    candidates[extended] = (
                        ext_blank,
                        _log_add(ext_non_blank, p_blank + score,
                                 p_non_blank + score),
                    )
        ranked = sorted(candidates.items(), key=total_score, reverse=True)
        hyps = dict(ranked[:beam_size])

    final_order = sorted(hyps.items(), key=total_score, reverse=True)
    return [(_log_add(p_blank, p_non_blank), prefix)
            for prefix, (p_blank, p_non_blank) in final_order]
def ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                  batch_log_probs_idx,
                                  batch_root_trie,
                                  batch_start,
                                  beam_size,
                                  num_processes,
                                  blank_id=0,
                                  space_id=-1,
                                  cutoff_prob=0.999,
                                  ext_scorer=None):
    """Run ``_ctc_prefix_beam_search`` on every utterance of a batch.

    Only ``batch_log_probs_seq``, ``batch_log_probs_idx``, ``beam_size``
    and ``blank_id`` influence the result; the other parameters are
    accepted and discarded.

    Returns:
        One list of ``(score, prefix)`` pairs per utterance.
    """
    # These arguments are part of the signature but unused here.
    del batch_root_trie, batch_start, num_processes
    del space_id, cutoff_prob, ext_scorer
    results = []
    for seq_scores, seq_tokens in zip(batch_log_probs_seq,
                                      batch_log_probs_idx):
        results.append(
            _ctc_prefix_beam_search(seq_scores, seq_tokens, beam_size,
                                    blank_id))
    return results
def load_config(config_path):
    """Parse the YAML file at ``config_path`` and return its content.

    The file is decoded as UTF-8 explicitly so that the result does not
    depend on the platform's default text encoding.
    """
    # NOTE(review): FullLoader is kept for compatibility; if the config can
    # come from an untrusted source, consider yaml.safe_load instead.
    with open(config_path, "r", encoding="utf-8") as fin:
        return yaml.load(fin, Loader=yaml.FullLoader)
def load_vocab(vocab_path):
    """Load a ``<token> <id>`` per-line vocabulary file.

    The file is read as UTF-8 explicitly so that non-ASCII tokens load the
    same way regardless of the platform's default text encoding.

    Returns:
        ``(vocabulary, char_dict)`` where ``vocabulary`` lists the tokens in
        file order and ``char_dict`` maps each integer id to its token.

    Raises:
        AssertionError: if a line does not contain exactly two
            whitespace-separated fields.
    """
    vocabulary = []
    char_dict = {}
    with open(vocab_path, "r", encoding="utf-8") as fin:
        for line_no, line in enumerate(fin, start=1):
            arr = line.strip().split()
            assert len(arr) == 2, (
                f"{vocab_path}:{line_no}: expected '<token> <id>', "
                f"got {line!r}")
            token, token_id = arr
            char_dict[int(token_id)] = token
            # NOTE: list position, not the id column, indexes ``vocabulary``.
            vocabulary.append(token)
    return vocabulary, char_dict
def load_wav(audio_file):
    """Read a PCM WAV file into a mono float32 waveform.

    16-bit samples keep their native integer range.  8-bit and 32-bit
    samples are rescaled to that same range so that the amplitude (and any
    feature computed from it) does not depend on the file's bit depth.
    Multi-channel audio is averaged across channels.

    Returns:
        ``(waveform, sample_rate)`` with ``waveform`` a 1-D float32 array.

    Raises:
        ValueError: if the sample width is not 1, 2 or 4 bytes.
    """
    with wave.open(audio_file, "rb") as wav_file:
        sample_rate = wav_file.getframerate()
        num_channels = wav_file.getnchannels()
        sample_width = wav_file.getsampwidth()
        frames = wav_file.readframes(wav_file.getnframes())

    if sample_width == 1:
        # 8-bit WAV data is unsigned: centre it, then scale 2**7 -> 2**15.
        waveform = np.frombuffer(frames, dtype=np.uint8).astype(np.float32)
        waveform -= 128.0
        waveform *= 256.0
    elif sample_width == 2:
        waveform = np.frombuffer(frames, dtype="<i2").astype(np.float32)
    elif sample_width == 4:
        # Scale the 32-bit range 2**31 down to the 16-bit range 2**15.
        waveform = np.frombuffer(frames, dtype="<i4").astype(np.float32)
        waveform /= 65536.0
    else:
        raise ValueError(f"Unsupported wav sample width: {sample_width}")

    if num_channels > 1:
        # Samples are interleaved per frame; average the channels.
        waveform = waveform.reshape(-1, num_channels).mean(axis=1)
    return waveform, sample_rate
def resample_linear(waveform, orig_sr, target_sr):
    """Resample a 1-D waveform by linear interpolation.

    The input is returned unchanged when the two rates are equal or when
    the computed output length would be at most one sample.  Otherwise the
    output has ``round(len / orig_sr * target_sr)`` float32 samples taken
    at evenly spaced positions between the first and last input sample.
    """
    if orig_sr == target_sr:
        return waveform
    num_src = waveform.shape[0]
    num_dst = int(round(num_src / float(orig_sr) * target_sr))
    if num_dst <= 1:
        return waveform
    src_index = np.arange(num_src)
    query = np.linspace(0, num_src - 1, num_dst)
    resampled = np.interp(query, src_index, waveform)
    return resampled.astype(np.float32)
def hz_to_mel(freq):
    """Map a frequency in Hz to mels using ``1127 * ln(1 + freq / 700)``."""
    ratio = freq / 700.0
    return np.log1p(ratio) * 1127.0
def mel_to_hz(mel):
    """Map a mel value to Hz using ``700 * (exp(mel / 1127) - 1)``."""
    exponent = mel / 1127.0
    return np.expm1(exponent) * 700.0
def mel_filterbank(num_mel_bins, n_fft, sample_rate):
    """Build triangular mel filters over the bins of a real FFT.

    Filter edges are spaced evenly on the mel scale between 20 Hz and half
    the sample rate, then floored to FFT bin indices.

    Returns:
        A float32 array of shape ``(num_mel_bins, n_fft // 2 + 1)``.  A
        filter whose floored edges coincide keeps an all-zero slope on
        that side.
    """
    mel_lo = hz_to_mel(20.0)
    mel_hi = hz_to_mel(sample_rate / 2.0)
    edges_hz = mel_to_hz(np.linspace(mel_lo, mel_hi, num_mel_bins + 2))
    edge_bins = np.floor((n_fft + 1) * edges_hz / sample_rate).astype(
        np.int32)
    weights = np.zeros((num_mel_bins, n_fft // 2 + 1), dtype=np.float32)
    # Filter ``row`` uses three consecutive edges: start, peak and end.
    triples = zip(edge_bins[:-2], edge_bins[1:-1], edge_bins[2:])
    for row, (lo, mid, hi) in enumerate(triples):
        if mid > lo:
            # Rising slope from 0 at ``lo`` towards 1 at ``mid``.
            weights[row, lo:mid] = (np.arange(lo, mid) - lo) / float(mid - lo)
        if hi > mid:
            # Falling slope from 1 at ``mid`` towards 0 at ``hi``.
            weights[row, mid:hi] = (hi - np.arange(mid, hi)) / float(hi - mid)
    return weights
def numpy_fbank(waveform,
                sample_rate=16000,
                num_mel_bins=80,
                frame_length=25,
                frame_shift=10):
    """Compute log mel filterbank energies of a 1-D waveform.

    Args:
        waveform: 1-D array of samples.
        sample_rate: sampling rate in Hz.
        num_mel_bins: number of mel filters.
        frame_length: analysis window length in milliseconds.
        frame_shift: hop between windows in milliseconds.

    Returns:
        A float32 array of shape ``(num_frames, num_mel_bins)``.  Input
        shorter than one window is zero-padded to a single window.
    """
    win_len = int(round(sample_rate * frame_length / 1000.0))
    hop_len = int(round(sample_rate * frame_shift / 1000.0))
    shortfall = win_len - waveform.shape[0]
    if shortfall > 0:
        waveform = np.pad(waveform, (0, shortfall))
    frame_count = 1 + (waveform.shape[0] - win_len) // hop_len

    # Gather overlapping windows through an index matrix; the result is a
    # fresh array, so scaling it in place leaves ``waveform`` untouched.
    starts = hop_len * np.arange(frame_count)
    sample_idx = starts[:, np.newaxis] + np.arange(win_len)[np.newaxis, :]
    windowed = waveform[sample_idx]
    windowed *= np.hamming(win_len).astype(np.float32)

    # Smallest power of two that is >= the window length.
    fft_size = 1
    while fft_size < win_len:
        fft_size *= 2
    spectrum_power = np.abs(np.fft.rfft(windowed, n=fft_size))**2
    filters = mel_filterbank(num_mel_bins, fft_size, sample_rate)
    # Floor the energies so the logarithm stays finite.
    floor = np.finfo(np.float32).eps
    mel_energy = np.maximum(np.dot(spectrum_power, filters.T), floor)
    return np.log(mel_energy).astype(np.float32)
def compute_feats(audio_file, sr=16000):
    """Load a WAV file and return its fbank features shaped ``(1, T, 80)``.

    The audio is read with ``load_wav``, resampled to ``sr`` with
    ``resample_linear`` and passed to ``numpy_fbank`` with that function's
    default settings.
    """
    samples, native_sr = load_wav(audio_file)
    samples = resample_linear(samples.astype(np.float32), native_sr, sr)
    fbank = numpy_fbank(samples, sample_rate=sr)
    # Add a leading batch axis; 80 matches numpy_fbank's default bin count.
    return fbank.reshape(1, -1, 80)
def pad_array_along_axis(array, pad_width, axis, mode="constant", **kwargs):
    """Pad ``array`` at the end of ``axis`` up to length ``pad_width``.

    The array is returned as-is when that axis is already at least
    ``pad_width`` long.  ``mode`` and any extra keyword arguments are
    forwarded to ``np.pad``.
    """
    deficit = pad_width - array.shape[axis]
    if deficit <= 0:
        return array
    # No padding on any axis except the trailing side of ``axis``.
    spec = [(0, 0) for _ in range(array.ndim)]
    spec[axis] = (0, deficit)
    return np.pad(array, pad_width=spec, mode=mode, **kwargs)
def numpy_topk(array, k, axis=-1, largest=True):
    """Return the ``k`` largest (or smallest) entries along ``axis``.

    Args:
        array: input array.
        k: number of entries to select along ``axis``.
        axis: axis to select along.
        largest: select the largest entries when true, else the smallest.

    Returns:
        ``(values, indices)``; both have length ``k`` along ``axis``.
        Values are in descending order when ``largest`` is true and in
        ascending order otherwise; ``indices`` are the matching positions
        in ``array``.
    """
    if largest:
        partitioned_indices = np.argpartition(array, -k, axis=axis)
        topk_indices = np.take(partitioned_indices, range(-k, 0), axis=axis)
    else:
        # Partition around position k - 1 (the last of the k smallest).
        # Using k itself is out of bounds when k equals the axis length.
        partitioned_indices = np.argpartition(array, k - 1, axis=axis)
        topk_indices = np.take(partitioned_indices, range(0, k), axis=axis)
    topk_values = np.take_along_axis(array, topk_indices, axis=axis)

    # argpartition leaves the selected entries unordered; sort them now.
    sorted_indices_in_topk = np.argsort(topk_values, axis=axis)
    if largest:
        sorted_indices_in_topk = np.flip(sorted_indices_in_topk, axis=axis)
    sorted_topk_values = np.take_along_axis(topk_values,
                                            sorted_indices_in_topk,
                                            axis=axis)
    sorted_topk_indices = np.take_along_axis(topk_indices,
                                             sorted_indices_in_topk,
                                             axis=axis)
    return sorted_topk_values, sorted_topk_indices
def ctc_decoding(beam_log_probs,
                 beam_log_probs_idx,
                 encoder_out_lens,
                 vocabulary,
                 mode="ctc_prefix_beam_search"):
    """Decode CTC candidates into text and/or scored hypotheses.

    Args:
        beam_log_probs: candidate log-probabilities; the first axis is the
            batch and the last axis the candidates per frame.
        beam_log_probs_idx: token ids matching ``beam_log_probs``.
        encoder_out_lens: number of frames to use for each utterance.
        vocabulary: sequence mapping token ids to text pieces.
        mode: ``"ctc_greedy_search"``, ``"ctc_prefix_beam_search"`` or
            ``"attention_rescoring"``.

    Returns:
        ``(hyps, score_hyps)``.  ``hyps`` holds one decoded string per
        utterance for the greedy and prefix-beam modes and is empty
        otherwise.  ``score_hyps`` holds the beam search output for the
        prefix-beam and rescoring modes and is empty otherwise.  Any other
        ``mode`` yields two empty lists.
    """
    beam_size = beam_log_probs.shape[-1]
    batch_size = beam_log_probs.shape[0]
    num_processes = min(multiprocessing.cpu_count(), batch_size)

    if mode == "ctc_greedy_search":
        # Candidate 0 of every frame is taken as the greedy choice.
        if beam_size == 1:
            best_ids = beam_log_probs_idx.squeeze(-1)
        else:
            best_ids = beam_log_probs_idx[:, :, 0]
        batch_sents = [
            row[0:encoder_out_lens[n]].tolist()
            for n, row in enumerate(best_ids)
        ]
        greedy_hyps = map_batch(batch_sents, vocabulary, num_processes, True,
                                0)
        return greedy_hyps, []

    if mode not in ("ctc_prefix_beam_search", "attention_rescoring"):
        return [], []

    all_scores = beam_log_probs.tolist()
    all_ids = beam_log_probs_idx.tolist()
    valid_lens = encoder_out_lens.tolist()
    batch_log_probs_seq = []
    batch_log_probs_ids = []
    for n, num_frames in enumerate(valid_lens):
        # Keep only the valid frames of each utterance.
        batch_log_probs_seq.append(all_scores[n][0:num_frames])
        batch_log_probs_ids.append(all_ids[n][0:num_frames])
    batch_root = [None] * len(valid_lens)
    batch_start = [True] * len(valid_lens)
    score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                               batch_log_probs_ids,
                                               batch_root,
                                               batch_start,
                                               beam_size,
                                               num_processes,
                                               blank_id=0,
                                               space_id=-2,
                                               cutoff_prob=0.99999)
    if mode != "ctc_prefix_beam_search":
        return [], score_hyps

    # The first candidate of each utterance is the best-scoring prefix.
    best_prefixes = [cand_hyps[0][1] for cand_hyps in score_hyps]
    hyps = map_batch(best_prefixes, vocabulary, num_processes, False, 0)
    return hyps, score_hyps
def has_higher_scored_collapsed_repeat(hyp, kept_hyps):
    """Tell whether removing one adjacent duplicate yields a kept hypothesis.

    For every position whose token equals the token before it, the
    sequence without that position is looked up in ``kept_hyps``.

    Returns:
        True as soon as one such shortened sequence is found, else False.
    """
    return any(
        hyp[pos] == hyp[pos - 1]
        and (hyp[:pos] + hyp[pos + 1:]) in kept_hyps
        for pos in range(1, len(hyp)))
def make_decoder_inputs(encoder_out,
                        encoder_out_lens,
                        beam_log_probs,
                        beam_log_probs_idx,
                        vocabulary,
                        sos,
                        eos,
                        decoder_len):
    """Assemble the rescoring decoder's inputs from CTC beam search output.

    Hypotheses come from ``ctc_decoding`` in ``"attention_rescoring"``
    mode.  Per utterance, a hypothesis is skipped when dropping one of its
    adjacent duplicate tokens gives a hypothesis already kept; at most
    ``beam_size`` are kept and missing slots are filled with ``(0,)``
    scored ``-inf``.

    Returns:
        ``(inputs, all_hyps)``.  ``inputs`` is a dict with the keys
        ``encoder_out`` (padded or cut to ``decoder_len`` frames),
        ``encoder_out_lens`` (``decoder_len`` for every utterance),
        ``hyps_pad_sos_eos``, ``hyps_lens_sos``, ``r_hyps_pad_sos_eos``
        and ``ctc_score``.  ``all_hyps`` is the flat list of kept
        hypotheses (untruncated token lists), ``beam_size`` per utterance.
    """
    _, score_hyps = ctc_decoding(beam_log_probs, beam_log_probs_idx,
                                 encoder_out_lens, vocabulary,
                                 "attention_rescoring")
    ignore_id = -1
    batch_size = beam_log_probs.shape[0]
    beam_size = beam_log_probs.shape[-1]
    filler = (-float("inf"), (0,))

    ctc_score = []
    all_hyps = []
    for candidates in score_hyps:
        selected = []
        seen = set()
        for score, tokens in candidates:
            tokens = tuple(tokens)
            if has_higher_scored_collapsed_repeat(tokens, seen):
                continue
            selected.append((score, tokens))
            seen.add(tokens)
            if len(selected) == beam_size:
                break
        missing = beam_size - len(selected)
        if missing > 0:
            # Fill unused beam slots so every utterance has beam_size rows.
            selected.extend([filler] * missing)
        ctc_score.append([score for score, _ in selected])
        all_hyps.extend(list(tokens) for _, tokens in selected)
    ctc_score = np.array(ctc_score, dtype=np.float32)

    # Two positions of every row are reserved for sos and eos.
    max_len = decoder_len - 2
    padded_shape = (batch_size, beam_size, max_len + 2)
    hyps_pad_sos_eos = np.full(padded_shape, ignore_id, dtype=np.int64)
    r_hyps_pad_sos_eos = np.full(padded_shape, ignore_id, dtype=np.int64)
    hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32)
    for b in range(batch_size):
        for n in range(beam_size):
            cand = all_hyps[b * beam_size + n][:max_len]
            span = len(cand) + 2
            hyps_pad_sos_eos[b, n, :span] = [sos] + cand + [eos]
            # Same tokens in reverse order, still framed by sos ... eos.
            r_hyps_pad_sos_eos[b, n, :span] = [sos] + cand[::-1] + [eos]
            hyps_lens_sos[b, n] = len(cand) + 1

    num_frames = encoder_out.shape[1]
    if decoder_len > num_frames:
        encoder_out = np.pad(encoder_out,
                             [(0, 0), (0, decoder_len - num_frames), (0, 0)],
                             mode="constant",
                             constant_values=0)
    elif decoder_len < num_frames:
        encoder_out = encoder_out[:, :decoder_len, :]

    decoder_inputs = {
        "encoder_out": encoder_out,
        "encoder_out_lens": np.full(batch_size,
                                    fill_value=decoder_len,
                                    dtype=np.int32),
        "hyps_pad_sos_eos": hyps_pad_sos_eos.astype(np.int32),
        "hyps_lens_sos": hyps_lens_sos,
        "r_hyps_pad_sos_eos": r_hyps_pad_sos_eos.astype(np.int32),
        "ctc_score": ctc_score,
    }
    return decoder_inputs, all_hyps
def make_offline_inputs(feats, seq_len):
    """Fit ``feats`` to exactly ``seq_len`` frames for the offline encoder.

    Frames beyond ``seq_len`` on axis 1 are dropped; shorter input is
    padded at the end with ``pad_array_along_axis``.

    Returns:
        A dict with ``"speech"`` (the fitted features) and
        ``"speech_lengths"`` (a one-element int32 array holding the frame
        count before padding).
    """
    clipped = feats[:, :seq_len, :]
    valid_frames = clipped.shape[1]
    speech_lengths = np.array([valid_frames], dtype=np.int32)
    speech = clipped
    if valid_frames < seq_len:
        speech = pad_array_along_axis(clipped, pad_width=seq_len, axis=1)
    return {"speech": speech, "speech_lengths": speech_lengths}
def make_online_initial_state(configs,
                              batch_size=1,
                              decoding_chunk_size=16,
                              num_decoding_left_chunks=5):
    """Create the zero-filled caches and chunking parameters for streaming.

    Args:
        configs: mapping whose ``"encoder_conf"`` entry provides
            ``output_size``, ``num_blocks``, ``attention_heads`` and
            optionally ``cnn_module_kernel``.
        batch_size: leading dimension of every cache.
        decoding_chunk_size: chunk size used to derive stride and window.
        num_decoding_left_chunks: number of chunks the attention cache
            spans.

    Returns:
        ``(state, params)``.  ``state`` holds ``att_cache``, ``cnn_cache``,
        ``cache_mask`` (float32 zeros) and ``offset`` (int32 zeros).
        ``params`` holds ``batch_size``, ``context``, ``stride`` and
        ``decoding_window``.
    """
    # NOTE(review): fixed values, presumably matching the exported encoder's
    # front-end subsampling - confirm against the model.
    subsampling = 4
    context = 7

    encoder_conf = configs["encoder_conf"]
    output_size = encoder_conf["output_size"]
    num_layers = encoder_conf["num_blocks"]
    cnn_cache_len = encoder_conf.get("cnn_module_kernel", 1) - 1
    head = encoder_conf["attention_heads"]
    d_k = output_size // head
    cache_len = decoding_chunk_size * num_decoding_left_chunks

    def zeros(*shape, dtype=np.float32):
        return np.zeros(shape, dtype=dtype)

    state = {
        "att_cache": zeros(batch_size, num_layers, head, cache_len, d_k * 2),
        "cnn_cache": zeros(batch_size, num_layers, output_size,
                           cnn_cache_len),
        "cache_mask": zeros(batch_size, 1, cache_len),
        "offset": zeros(batch_size, 1, dtype=np.int32),
    }
    params = {
        "batch_size": batch_size,
        "context": context,
        "stride": subsampling * decoding_chunk_size,
        "decoding_window": (decoding_chunk_size - 1) * subsampling + context,
    }
    return state, params
def make_online_encoder_input(feats, cur, params, state):
    """Build the streaming encoder's input for the chunk starting at ``cur``.

    A window of ``params["decoding_window"]`` frames is cut from ``feats``
    on axis 1, padded at the end when the features run out, and cast to
    float32.  The cache entries are taken from ``state`` unchanged.

    Returns:
        A dict with ``chunk_xs``, ``chunk_lens`` (the window length after
        padding, repeated ``params["batch_size"]`` times), ``offset``,
        ``att_cache``, ``cnn_cache`` and ``cache_mask``.
    """
    batch_size = params["batch_size"]
    window = params["decoding_window"]
    stop = min(cur + window, feats.shape[1])
    chunk = feats[:, cur:stop, :]
    if chunk.shape[1] < window:
        chunk = pad_array_along_axis(chunk, pad_width=window, axis=1)
    chunk = chunk.astype(np.float32)

    encoder_input = {
        "chunk_xs": chunk,
        "chunk_lens": np.full(batch_size,
                              fill_value=chunk.shape[1],
                              dtype=np.int32),
    }
    for key in ("offset", "att_cache", "cnn_cache", "cache_mask"):
        encoder_input[key] = state[key]
    return encoder_input
def output_value(outputs, name):
    """Look up ``name`` in ``outputs``, falling back to ``"r_" + name``.

    Raises:
        KeyError: with ``name`` as argument when neither key is present.
    """
    if name not in outputs:
        fallback = "r_" + name
        if fallback not in outputs:
            raise KeyError(name)
        return outputs[fallback]
    return outputs[name]
def update_online_state(state, outputs):
    """Copy the cache entries of ``outputs`` into ``state`` in place.

    The entries ``offset``, ``att_cache``, ``cnn_cache`` and ``cache_mask``
    are fetched with ``output_value`` and stored under the same keys.
    """
    for key in ("offset", "att_cache", "cnn_cache", "cache_mask"):
        state[key] = output_value(outputs, key)
class AxModel:
    """Wrapper around an ``axengine`` inference session.

    The session's output names are recorded at construction so that
    ``run`` can return results keyed by name.
    """

    def __init__(self, path, provider="AxEngineExecutionProvider"):
        # Imported here so that importing this module does not by itself
        # require ``axengine``.
        from axengine import InferenceSession

        self.session = InferenceSession(path, providers=[provider])
        self.output_names = [
            output.name for output in self.session.get_outputs()
        ]

    def run(self, input_feed):
        """Run the session on ``input_feed``; return ``{name: value}``."""
        values = self.session.run(self.output_names, input_feed)
        return {name: value for name, value in zip(self.output_names, values)}
class WenetAXRunner:
    """End-to-end speech recognition on top of ``AxModel`` sessions.

    The configuration and vocabulary are loaded at construction.  The three
    models (offline encoder, online encoder, decoder) are created on first
    access through the matching property.
    """

    def __init__(self,
                 config_path,
                 vocab_path,
                 encoder_offline_path="axmodel/encoder_offline/encoder_offline.axmodel",
                 encoder_online_path="axmodel/encoder_online/encoder_online.axmodel",
                 decoder_path="axmodel/decoder/decoder.axmodel",
                 offline_seq_len=1024,
                 decoder_len=32,
                 decoding_chunk_size=16,
                 num_decoding_left_chunks=5,
                 batch_size=1,
                 provider="AxEngineExecutionProvider"):
        # Settings are stored exactly as given.
        self.config_path = config_path
        self.vocab_path = vocab_path
        self.encoder_offline_path = encoder_offline_path
        self.encoder_online_path = encoder_online_path
        self.decoder_path = decoder_path
        self.offline_seq_len = offline_seq_len
        self.decoder_len = decoder_len
        self.decoding_chunk_size = decoding_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.batch_size = batch_size
        self.provider = provider

        self.configs = load_config(config_path)
        self.vocabulary, self.char_dict = load_vocab(vocab_path)
        # sos and eos share one id: the vocabulary size minus one.
        self.eos = self.sos = len(self.char_dict) - 1

        # Model sessions are created lazily by the properties below.
        self._offline_encoder = None
        self._online_encoder = None
        self._decoder = None

    @property
    def offline_encoder(self):
        """The offline encoder model, created on first access."""
        model = self._offline_encoder
        if model is None:
            model = AxModel(self.encoder_offline_path, self.provider)
            self._offline_encoder = model
        return model

    @property
    def online_encoder(self):
        """The streaming encoder model, created on first access."""
        model = self._online_encoder
        if model is None:
            model = AxModel(self.encoder_online_path, self.provider)
            self._online_encoder = model
        return model

    @property
    def decoder(self):
        """The rescoring decoder model, created on first access."""
        model = self._decoder
        if model is None:
            model = AxModel(self.decoder_path, self.provider)
            self._decoder = model
        return model

    def compute_feats(self, audio_file):
        """Return the features of ``audio_file`` via module ``compute_feats``."""
        return compute_feats(audio_file)

    def run_offline_encoder(self, feats):
        """Encode a whole utterance with the offline encoder.

        Returns:
            A dict with ``encoder_out``, ``encoder_out_lens``,
            ``ctc_log_probs`` and the top-10 candidates per frame as
            ``beam_log_probs`` / ``beam_log_probs_idx``.
        """
        encoder_input = make_offline_inputs(feats, self.offline_seq_len)
        num_input_frames = encoder_input["speech_lengths"][0]
        outputs = self.offline_encoder.run(encoder_input)

        encoder_out_lens = outputs["encoder_out_lens"].astype(np.int32)
        # Count the frames that survive two successive ``[2::2]`` slices.
        # NOTE(review): presumably mirrors the encoder's subsampling of the
        # unpadded input - confirm.  Only entry 0 is overwritten.
        frame_marks = np.ones([num_input_frames], dtype=np.int32)
        encoder_out_lens[0] = frame_marks[2::2][2::2].sum()

        top_scores, top_ids = numpy_topk(outputs["ctc_log_probs"], k=10)
        return {
            "encoder_out": outputs["encoder_out"],
            "encoder_out_lens": encoder_out_lens,
            "ctc_log_probs": outputs["ctc_log_probs"],
            "beam_log_probs": top_scores,
            "beam_log_probs_idx": top_ids,
        }

    def run_online_encoder(self, feats):
        """Encode an utterance chunk by chunk with the streaming encoder.

        Returns:
            A dict with the per-chunk ``encoder_out``, ``beam_log_probs``
            and ``beam_log_probs_idx`` concatenated on axis 1, plus
            ``encoder_out_lens`` set to the total number of output frames.
        """
        state, chunk_params = make_online_initial_state(
            self.configs, self.batch_size, self.decoding_chunk_size,
            self.num_decoding_left_chunks)
        chunk_outs = []
        chunk_scores = []
        chunk_ids = []
        last_start = feats.shape[1] - chunk_params["context"]
        for start in range(0, last_start + 1, chunk_params["stride"]):
            encoder_input = make_online_encoder_input(feats, start,
                                                      chunk_params, state)
            outputs = self.online_encoder.run(encoder_input)
            # Carry the caches over to the next chunk.
            update_online_state(state, outputs)
            chunk_outs.append(outputs["chunk_out"])
            chunk_scores.append(outputs["log_probs"])
            chunk_ids.append(outputs["log_probs_idx"].astype(np.int32))

        total_frames = sum(out.shape[1] for out in chunk_outs)
        return {
            "encoder_out": np.concatenate(chunk_outs, axis=1),
            "encoder_out_lens": np.full(self.batch_size,
                                        fill_value=total_frames,
                                        dtype=np.int32),
            "beam_log_probs": np.concatenate(chunk_scores, axis=1),
            "beam_log_probs_idx": np.concatenate(chunk_ids, axis=1),
        }

    def ctc_decode(self, encoder_outputs, mode):
        """Run ``ctc_decoding`` on encoder outputs with this vocabulary."""
        return ctc_decoding(encoder_outputs["beam_log_probs"],
                            encoder_outputs["beam_log_probs_idx"],
                            encoder_outputs["encoder_out_lens"],
                            self.vocabulary, mode)

    def run_decoder(self, encoder_outputs):
        """Rescore the CTC hypotheses with the decoder; return the text."""
        decoder_input, all_hyps = make_decoder_inputs(
            encoder_outputs["encoder_out"],
            encoder_outputs["encoder_out_lens"],
            encoder_outputs["beam_log_probs"],
            encoder_outputs["beam_log_probs_idx"],
            self.vocabulary,
            self.sos,
            self.eos,
            self.decoder_len,
        )
        decoder_outputs = self.decoder.run(decoder_input)
        best_index = decoder_outputs["best_index"].astype(np.int32)
        beam_size = encoder_outputs["beam_log_probs"].shape[-1]
        num_processes = min(multiprocessing.cpu_count(), best_index.shape[0])

        # ``all_hyps`` is flat with ``beam_size`` entries per utterance;
        # pick the selected entry inside each utterance's slice.
        best_sents = [
            all_hyps[pos * beam_size:(pos + 1) * beam_size][choice]
            for pos, choice in enumerate(best_index)
        ]
        texts = map_batch(best_sents, self.vocabulary, num_processes)
        return "".join(texts)

    def transcribe(self,
                   audio_file,
                   online=False,
                   mode="ctc_prefix_beam_search"):
        """Transcribe ``audio_file`` and return the recognised text.

        Args:
            audio_file: passed to ``compute_feats``.
            online: use the streaming encoder when true, else the offline
                encoder.
            mode: ``"attention_rescoring"`` runs the decoder; any other
                value is forwarded to ``ctc_decode``.
        """
        feats = self.compute_feats(audio_file)
        encode = self.run_online_encoder if online else self.run_offline_encoder
        encoder_outputs = encode(feats)
        if mode == "attention_rescoring":
            return self.run_decoder(encoder_outputs)
        hyps, _ = self.ctc_decode(encoder_outputs, mode)
        if not hyps:
            return ""
        return "".join(hyps)