# WeNet / ort_common.py
# Hosting-page residue captured with the file: uploaded by inoryQwQ,
# commit 3c50954 ("First commit"), file size 22.8 kB.
import multiprocessing
import os
import tarfile as tf
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import yaml
try:
    from swig_decoders import (map_batch, ctc_beam_search_decoder_batch,
                               TrieVector, PathTrie)
except ImportError:
    # ImportError (the parent of ModuleNotFoundError) is caught on purpose:
    # a compiled extension that is installed but cannot be loaded, or that
    # lacks one of the names above, raises plain ImportError. The pure-Python
    # stand-ins below are used in all of those cases.

    class PathTrie:
        """Empty stand-in for ``swig_decoders.PathTrie``."""

    class TrieVector(list):
        """List-backed stand-in for ``swig_decoders.TrieVector``."""

    def _log_add(*values):
        """Return ``log(sum(exp(v)))`` over ``values``, ignoring ``-inf``.

        Returns ``-inf`` when every argument is ``-inf`` (or none is given).
        """
        finite = [value for value in values if value != -float("inf")]
        if not finite:
            return -float("inf")
        max_value = max(finite)
        # Subtracting the maximum before exponentiating avoids overflow.
        return max_value + np.log(sum(np.exp(value - max_value)
                                      for value in finite))

    def _map_sentence(sent, vocabulary, greedy=False, blank_id=0):
        """Convert a sequence of token ids into a string.

        Args:
            sent: iterable of token ids (anything ``int()`` accepts).
            vocabulary: sequence of token strings indexed by token id.
            greedy: when true, a token equal to the previous token (before
                blank removal) is dropped.
            blank_id: id that is skipped as the blank token.

        Returns:
            The kept pieces joined without a separator. Ids outside
            ``vocabulary`` and pieces of the form ``<...>`` are skipped.
        """
        mapped = []
        prev = None
        for token in sent:
            token = int(token)
            repeated = greedy and token == prev
            prev = token
            if repeated:
                continue
            if token == blank_id or token < 0 or token >= len(vocabulary):
                continue
            piece = vocabulary[token]
            if piece.startswith("<") and piece.endswith(">"):
                continue
            mapped.append(piece)
        return "".join(mapped)

    def map_batch(batch_sents, vocabulary, num_processes, greedy=False,
                  blank_id=0):
        """Map every id sequence in ``batch_sents`` to a string.

        ``num_processes`` is accepted for signature compatibility and
        ignored; the sentences are converted sequentially.
        """
        del num_processes
        return [_map_sentence(sent, vocabulary, greedy, blank_id)
                for sent in batch_sents]

    def _ctc_prefix_beam_search(log_probs_seq, log_probs_idx, beam_size,
                                blank_id):
        """Run CTC prefix beam search over per-frame candidate lists.

        Args:
            log_probs_seq: per frame, the log-probabilities of the candidates.
            log_probs_idx: per frame, the token ids matching ``log_probs_seq``.
            beam_size: number of prefixes kept after every frame.
            blank_id: token id treated as the CTC blank.

        Returns:
            A list of ``(score, prefix)`` pairs sorted by descending score,
            where ``prefix`` is a tuple of token ids and ``score`` is the
            log-sum of its blank-ending and non-blank-ending scores.
        """
        neg_inf = -float("inf")

        def total_score(item):
            return _log_add(item[1][0], item[1][1])

        # prefix -> (score of paths ending in blank, ending in non-blank)
        beam = {(): (0.0, neg_inf)}
        for frame_probs, frame_ids in zip(log_probs_seq, log_probs_idx):
            # Converted once per frame instead of once per (prefix, candidate).
            candidates = [(float(prob), int(token))
                          for prob, token in zip(frame_probs, frame_ids)]
            next_beam = {}
            for prefix, (prob_blank, prob_non_blank) in beam.items():
                last = prefix[-1] if prefix else None
                for prob, token in candidates:
                    if token == blank_id:
                        # Blank keeps the prefix; only its blank score grows.
                        nb_blank, nb_non_blank = next_beam.get(
                            prefix, (neg_inf, neg_inf))
                        next_beam[prefix] = (
                            _log_add(nb_blank, prob_blank + prob,
                                     prob_non_blank + prob),
                            nb_non_blank,
                        )
                        continue
                    new_prefix = prefix + (token, )
                    if token == last:
                        # Repeat without a blank in between: prefix unchanged.
                        nb_blank, nb_non_blank = next_beam.get(
                            prefix, (neg_inf, neg_inf))
                        next_beam[prefix] = (
                            nb_blank,
                            _log_add(nb_non_blank, prob_non_blank + prob),
                        )
                        # Repeat after a blank: the prefix is extended.
                        nb_blank, nb_non_blank = next_beam.get(
                            new_prefix, (neg_inf, neg_inf))
                        next_beam[new_prefix] = (
                            nb_blank,
                            _log_add(nb_non_blank, prob_blank + prob),
                        )
                    else:
                        nb_blank, nb_non_blank = next_beam.get(
                            new_prefix, (neg_inf, neg_inf))
                        next_beam[new_prefix] = (
                            nb_blank,
                            _log_add(nb_non_blank, prob_blank + prob,
                                     prob_non_blank + prob),
                        )
            beam = dict(sorted(next_beam.items(), key=total_score,
                               reverse=True)[:beam_size])
        return [(_log_add(prob_blank, prob_non_blank), prefix)
                for prefix, (prob_blank, prob_non_blank) in sorted(
                    beam.items(), key=total_score, reverse=True)]

    def ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                      batch_log_probs_idx,
                                      batch_root_trie,
                                      batch_start,
                                      beam_size,
                                      num_processes,
                                      blank_id=0,
                                      space_id=-1,
                                      cutoff_prob=0.999,
                                      ext_scorer=None):
        """Decode every utterance of a batch with ``_ctc_prefix_beam_search``.

        Only ``batch_log_probs_seq``, ``batch_log_probs_idx``, ``beam_size``
        and ``blank_id`` are used; the remaining parameters are accepted for
        signature compatibility and ignored.

        Returns:
            One list of ``(score, prefix)`` pairs per utterance.
        """
        del batch_root_trie, batch_start, num_processes, space_id
        del cutoff_prob, ext_scorer
        return [
            _ctc_prefix_beam_search(log_probs_seq, log_probs_idx, beam_size,
                                    blank_id)
            for log_probs_seq, log_probs_idx in zip(batch_log_probs_seq,
                                                    batch_log_probs_idx)
        ]
def load_config(config_path):
    """Parse the YAML file at ``config_path`` and return the loaded object.

    The file is read as UTF-8 so the result does not depend on the
    platform's default text encoding.
    """
    with open(config_path, "r", encoding="utf-8") as fin:
        # NOTE(review): yaml.load with FullLoader is not meant for untrusted
        # input; prefer yaml.safe_load if configs can come from outside.
        return yaml.load(fin, Loader=yaml.FullLoader)
def load_vocab(vocab_path):
    """Read a whitespace-separated ``<token> <id>`` vocabulary file.

    The file is read as UTF-8 so non-ASCII tokens load the same way on every
    platform. Blank lines are skipped.

    Args:
        vocab_path: path of the vocabulary file, one ``token id`` pair per
            line.

    Returns:
        ``(vocabulary, char_dict)`` where ``vocabulary`` lists the tokens in
        file order and ``char_dict`` maps each integer id to its token.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields, or
            if the id field is not an integer.
    """
    vocabulary = []
    char_dict = {}
    with open(vocab_path, "r", encoding="utf-8") as fin:
        for line_no, line in enumerate(fin, start=1):
            arr = line.split()
            if not arr:
                continue
            # Explicit check instead of assert, which is stripped under -O.
            if len(arr) != 2:
                raise ValueError(
                    f"{vocab_path}:{line_no}: expected '<token> <id>', "
                    f"got {line.rstrip()!r}")
            char_dict[int(arr[1])] = arr[0]
            vocabulary.append(arr[0])
    return vocabulary, char_dict
def compute_feats(audio_file: str, sr=16000) -> np.ndarray:
    """Load an audio file and compute 80-bin Kaldi fbank features.

    ``soundfile`` is used when it can be imported, otherwise
    ``torchaudio.load``. Both paths hand ``kaldi.fbank`` samples on the
    int16 scale, so the features do not depend on which reader is installed.

    Args:
        audio_file: path of the audio file to read.
        sr: sample rate the waveform is resampled to before feature
            extraction when the file's own rate differs.

    Returns:
        The fbank features as a NumPy array with a leading axis of size 1
        added in front of the ``kaldi.fbank`` output.
    """
    try:
        import soundfile as sf
    except ImportError:
        sf = None

    if sf is not None:
        # dtype="int16" yields raw int16 sample values; always_2d gives
        # (frames, channels), transposed to channels-first for torchaudio.
        waveform, sample_rate = sf.read(audio_file, dtype="int16",
                                        always_2d=True)
        waveform = torch.from_numpy(waveform.T).to(torch.float)
    else:
        waveform, sample_rate = torchaudio.load(audio_file, normalize=True)
        # normalize=True returns samples in [-1, 1]; rescale to the int16
        # range so this path matches the soundfile path above.
        waveform = waveform.to(torch.float) * (1 << 15)

    if sample_rate != sr:
        waveform = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=sr)(waveform)
    feats = kaldi.fbank(waveform,
                        num_mel_bins=80,
                        frame_length=25,
                        frame_shift=10,
                        energy_floor=0.0,
                        sample_frequency=sr)
    return feats.unsqueeze(0).numpy()
def pad_array_along_axis(array, pad_width, axis, mode="constant", **kwargs):
    """Pad ``array`` at the end of ``axis`` until that axis has ``pad_width``
    entries.

    Args:
        array: NumPy array to pad.
        pad_width: target size of ``axis`` (not the amount of padding).
        axis: index of the axis to grow.
        mode: padding mode forwarded to ``np.pad``.
        **kwargs: extra keyword arguments forwarded to ``np.pad``.

    Returns:
        ``array`` itself when the axis is already at least ``pad_width``
        long, otherwise a padded copy.
    """
    missing = pad_width - array.shape[axis]
    if missing <= 0:
        return array
    # Only the requested axis gets trailing padding; all others get none.
    widths = [(0, 0) for _ in range(array.ndim)]
    widths[axis] = (0, missing)
    return np.pad(array, pad_width=widths, mode=mode, **kwargs)
def ctc_decoding(beam_log_probs,
                 beam_log_probs_idx,
                 encoder_out_lens,
                 vocabulary,
                 mode="ctc_prefix_beam_search"):
    """Decode per-frame candidate scores into text hypotheses.

    Args:
        beam_log_probs: array of candidate log-probabilities; its first axis
            is read as the batch size and its last axis as the beam size.
        beam_log_probs_idx: token ids matching ``beam_log_probs``.
        encoder_out_lens: per-utterance count of frames to decode.
        vocabulary: token strings indexed by token id.
        mode: ``"ctc_greedy_search"``, ``"ctc_prefix_beam_search"`` or
            ``"attention_rescoring"``. Any other value yields two empty
            lists.

    Returns:
        ``(hyps, score_hyps)``. ``hyps`` holds decoded strings for the greedy
        and prefix-beam modes and is empty for ``"attention_rescoring"``.
        ``score_hyps`` holds the raw ``ctc_beam_search_decoder_batch`` output
        for the two beam modes and is empty for greedy search.
    """
    batch_size = beam_log_probs.shape[0]
    beam_size = beam_log_probs.shape[-1]
    num_processes = min(multiprocessing.cpu_count(), batch_size)
    hyps, score_hyps = [], []

    if mode == "ctc_greedy_search":
        # Index 0 of the last axis is used as the per-frame choice
        # (presumably the top-1 candidate -- confirm against the exporter).
        first_ids = beam_log_probs_idx[:, :, 0]
        batch_sents = [
            row[0:encoder_out_lens[b]].tolist()
            for b, row in enumerate(first_ids)
        ]
        hyps = map_batch(batch_sents, vocabulary, num_processes, True, 0)
        return hyps, score_hyps

    if mode not in ("ctc_prefix_beam_search", "attention_rescoring"):
        return hyps, score_hyps

    seq_lists = beam_log_probs.tolist()
    idx_lists = beam_log_probs_idx.tolist()
    frame_counts = encoder_out_lens.tolist()

    batch_log_probs_seq = []
    batch_log_probs_ids = []
    batch_start = []
    batch_root = TrieVector()
    root_dict = {}  # holds a reference to every PathTrie created below
    for i, num_frames in enumerate(frame_counts):
        # Keep only the first num_frames frames of each utterance.
        batch_log_probs_seq.append(seq_lists[i][0:num_frames])
        batch_log_probs_ids.append(idx_lists[i][0:num_frames])
        root_dict[i] = PathTrie()
        batch_root.append(root_dict[i])
        batch_start.append(True)

    score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                               batch_log_probs_ids,
                                               batch_root,
                                               batch_start,
                                               beam_size,
                                               num_processes,
                                               0, -2, 0.99999)
    if mode == "ctc_prefix_beam_search":
        # Take the token ids of the first hypothesis of every utterance.
        best_ids = [cand_hyps[0][1] for cand_hyps in score_hyps]
        hyps = map_batch(best_ids, vocabulary, num_processes, False, 0)
    return hyps, score_hyps
def make_decoder_inputs(encoder_out,
                        encoder_out_lens,
                        beam_log_probs,
                        beam_log_probs_idx,
                        vocabulary,
                        sos,
                        eos,
                        decoder_len):
    """Build the decoder input dict from CTC beam-search hypotheses.

    Runs ``ctc_decoding`` in ``"attention_rescoring"`` mode, pads every
    utterance to ``beam_size`` hypotheses, and packs the hypotheses (forward
    and reversed) into fixed-size arrays wrapped with ``sos``/``eos``.

    Args:
        encoder_out: encoder output array; axis 1 is padded with zeros or
            truncated to ``decoder_len``.
        encoder_out_lens: per-utterance frame counts passed to
            ``ctc_decoding``.
        beam_log_probs: candidate log-probabilities; first axis is the batch
            size, last axis the beam size.
        beam_log_probs_idx: token ids matching ``beam_log_probs``.
        vocabulary: token strings, forwarded to ``ctc_decoding``.
        sos: id written in front of every hypothesis.
        eos: id written after every hypothesis.
        decoder_len: size of the last axis of the hypothesis arrays and of
            axis 1 of the returned ``encoder_out``.

    Returns:
        ``(inputs, all_hyps)`` where ``inputs`` is the dict of decoder input
        arrays and ``all_hyps`` is the flat list of hypothesis id lists,
        ``beam_size`` per utterance in batch order.
    """
    _, score_hyps = ctc_decoding(beam_log_probs, beam_log_probs_idx,
                                 encoder_out_lens, vocabulary,
                                 "attention_rescoring")
    ignore_id = -1
    batch_size = beam_log_probs.shape[0]
    beam_size = beam_log_probs.shape[-1]

    score_rows = []
    all_hyps = []
    for hyps in score_hyps:
        shortfall = beam_size - len(hyps)
        if shortfall > 0:
            # Fill missing beams with a -inf-scored single-token hypothesis.
            hyps += shortfall * [(-float("inf"), (0,))]
        score_rows.append([hyp[0] for hyp in hyps])
        all_hyps.extend(list(hyp[1]) for hyp in hyps)
    ctc_score = np.array(score_rows, dtype=np.float32)

    # Two slots of the last axis are reserved for sos and eos.
    max_len = decoder_len - 2
    padded_shape = (batch_size, beam_size, max_len + 2)
    hyps_pad_sos_eos = np.full(padded_shape, ignore_id, dtype=np.int64)
    r_hyps_pad_sos_eos = np.full(padded_shape, ignore_id, dtype=np.int64)
    hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32)
    for i in range(batch_size):
        for j in range(beam_size):
            cand = all_hyps[i * beam_size + j][:max_len]
            length = len(cand) + 2
            hyps_pad_sos_eos[i, j, 0:length] = [sos] + cand + [eos]
            r_hyps_pad_sos_eos[i, j, 0:length] = [sos] + cand[::-1] + [eos]
            hyps_lens_sos[i, j] = len(cand) + 1

    num_frames = encoder_out.shape[1]
    if decoder_len > num_frames:
        encoder_out = np.pad(encoder_out,
                             [(0, 0), (0, decoder_len - num_frames), (0, 0)],
                             mode="constant",
                             constant_values=0)
    elif decoder_len < num_frames:
        encoder_out = encoder_out[:, :decoder_len, :]

    decoder_inputs = {
        "encoder_out": encoder_out,
        "encoder_out_lens": np.full(batch_size,
                                    fill_value=decoder_len,
                                    dtype=np.int32),
        "hyps_pad_sos_eos": hyps_pad_sos_eos.astype(np.int32),
        "hyps_lens_sos": hyps_lens_sos,
        "r_hyps_pad_sos_eos": r_hyps_pad_sos_eos.astype(np.int32),
        "ctc_score": ctc_score,
    }
    return decoder_inputs, all_hyps
def make_offline_inputs(feats, seq_len):
    """Clip or zero-pad ``feats`` along axis 1 to ``seq_len`` frames.

    Returns:
        A dict with ``"speech"`` (the fixed-length features) and
        ``"speech_lengths"`` (a one-element int32 array holding the frame
        count after clipping and before padding).
    """
    clipped = feats[:, :seq_len, :]
    num_frames = clipped.shape[1]
    lengths = np.array([num_frames], dtype=np.int32)
    if num_frames < seq_len:
        speech = pad_array_along_axis(clipped, pad_width=seq_len, axis=1)
    else:
        speech = clipped
    return {"speech": speech, "speech_lengths": lengths}
def make_online_initial_state(configs,
                              batch_size=1,
                              decoding_chunk_size=16,
                              num_decoding_left_chunks=5):
    """Create zero-filled streaming caches and the chunking parameters.

    Args:
        configs: mapping with an ``"encoder_conf"`` entry providing
            ``output_size``, ``num_blocks``, ``attention_heads`` and,
            optionally, ``cnn_module_kernel``.
        batch_size: leading dimension of every cache array.
        decoding_chunk_size: multiplied by the fixed subsampling factor of 4
            to get the stride, and by ``num_decoding_left_chunks`` to get the
            cache length.
        num_decoding_left_chunks: number of chunks the cache length covers.

    Returns:
        ``(state, params)``: ``state`` holds the ``att_cache``,
        ``cnn_cache``, ``cache_mask`` and ``offset`` arrays; ``params`` holds
        ``batch_size``, ``context``, ``stride`` and ``decoding_window``.
    """
    subsampling = 4
    context = 7
    params = {
        "batch_size": batch_size,
        "context": context,
        "stride": subsampling * decoding_chunk_size,
        "decoding_window": (decoding_chunk_size - 1) * subsampling + context,
    }

    encoder_conf = configs["encoder_conf"]
    output_size = encoder_conf["output_size"]
    num_layers = encoder_conf["num_blocks"]
    # The cnn cache is one shorter than the kernel; a missing kernel entry
    # gives a zero-length cache axis.
    cnn_cache_len = encoder_conf.get("cnn_module_kernel", 1) - 1
    head = encoder_conf["attention_heads"]
    d_k = output_size // head
    cache_len = decoding_chunk_size * num_decoding_left_chunks

    att_shape = (batch_size, num_layers, head, cache_len, d_k * 2)
    cnn_shape = (batch_size, num_layers, output_size, cnn_cache_len)
    state = {
        "att_cache": np.zeros(att_shape, dtype=np.float32),
        "cnn_cache": np.zeros(cnn_shape, dtype=np.float32),
        "cache_mask": np.zeros((batch_size, 1, cache_len), dtype=np.float32),
        "offset": np.zeros((batch_size, 1), dtype=np.int32),
    }
    return state, params
def make_online_encoder_input(feats, cur, params, state):
    """Assemble the streaming-encoder inputs for the chunk starting at ``cur``.

    Args:
        feats: feature array; the chunk is sliced from axis 1.
        cur: index of the first frame of the chunk.
        params: dict providing ``batch_size`` and ``decoding_window``.
        state: dict providing ``offset``, ``att_cache``, ``cnn_cache`` and
            ``cache_mask``, which are passed through unchanged.

    Returns:
        A dict with the float32 chunk (padded along axis 1 up to
        ``decoding_window`` frames when the slice is shorter), an int32
        ``chunk_lens`` array filled with the chunk's axis-1 size, and the
        four state entries.
    """
    batch_size = params["batch_size"]
    window = params["decoding_window"]

    stop = min(cur + window, feats.shape[1])
    chunk = feats[:, cur:stop, :]
    if chunk.shape[1] < window:
        # The final chunk may run past the end of feats; pad it to size.
        chunk = pad_array_along_axis(chunk, pad_width=window, axis=1)
    chunk = chunk.astype(np.float32)

    inputs = {
        "chunk_xs": chunk,
        "chunk_lens": np.full(batch_size,
                              fill_value=chunk.shape[1],
                              dtype=np.int32),
    }
    for key in ("offset", "att_cache", "cnn_cache", "cache_mask"):
        inputs[key] = state[key]
    return inputs
def update_online_state(state, outputs):
    """Copy ``outputs[4:8]`` into ``state`` in place.

    Positions 4, 5, 6 and 7 of ``outputs`` are stored under ``offset``,
    ``att_cache``, ``cnn_cache`` and ``cache_mask`` respectively.
    """
    cache_keys = ("offset", "att_cache", "cnn_cache", "cache_mask")
    for position, key in enumerate(cache_keys, start=4):
        state[key] = outputs[position]
def save_calibration_inputs(calib_data_path, inputs, sample_id):
    """Write each entry of ``inputs`` to
    ``<calib_data_path>/<input_name>/<sample_id>.npy``.

    The per-input directory is created when it does not exist yet.
    """
    file_name = f"{sample_id}.npy"
    for input_name in inputs:
        target_dir = os.path.join(calib_data_path, input_name)
        os.makedirs(target_dir, exist_ok=True)
        np.save(os.path.join(target_dir, file_name), inputs[input_name])
def pack_calibration_dataset(calib_data_path):
    """Archive every sub-directory of ``calib_data_path`` as ``<name>.tar.gz``.

    Entries are visited in sorted order; entries that are not directories
    are left alone. Each archive is written next to its directory and stores
    the directory under its own name.
    """
    for entry in sorted(os.listdir(calib_data_path)):
        source_dir = os.path.join(calib_data_path, entry)
        if os.path.isdir(source_dir):
            archive_path = os.path.join(calib_data_path, entry + ".tar.gz")
            with tf.open(archive_path, "w:gz") as archive:
                archive.add(source_dir, arcname=entry)
class WenetONNXRunner:
    """Run the exported encoder/decoder ONNX models and decode their output.

    The three ONNX Runtime sessions (``encoder_offline.onnx``,
    ``encoder_online.onnx``, ``decoder.onnx`` under ``onnx_dir``) are created
    on first access of the matching property.
    """

    def __init__(self,
                 config_path,
                 vocab_path,
                 onnx_dir="onnx_model",
                 offline_seq_len=1024,
                 decoder_len=32,
                 decoding_chunk_size=16,
                 num_decoding_left_chunks=5,
                 batch_size=1,
                 providers=None):
        """Store the settings and load the config and vocabulary files.

        ``providers`` defaults to ``["CPUExecutionProvider"]`` when it is
        ``None`` or empty.
        """
        self.config_path = config_path
        self.vocab_path = vocab_path
        self.onnx_dir = onnx_dir
        self.offline_seq_len = offline_seq_len
        self.decoder_len = decoder_len
        self.decoding_chunk_size = decoding_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.batch_size = batch_size
        self.providers = providers if providers else ["CPUExecutionProvider"]
        self.configs = load_config(config_path)
        self.vocabulary, self.char_dict = load_vocab(vocab_path)
        # sos and eos share one id: the number of vocabulary entries minus 1.
        shared_id = len(self.char_dict) - 1
        self.eos = shared_id
        self.sos = shared_id
        self._offline_encoder = None
        self._online_encoder = None
        self._decoder = None

    def _lazy_session(self, attr, filename):
        """Return the session cached under ``attr``, creating it if unset."""
        session = getattr(self, attr)
        if session is None:
            session = self._new_session(filename)
            setattr(self, attr, session)
        return session

    @property
    def offline_encoder(self):
        """Session for ``encoder_offline.onnx``, created on first use."""
        return self._lazy_session("_offline_encoder", "encoder_offline.onnx")

    @property
    def online_encoder(self):
        """Session for ``encoder_online.onnx``, created on first use."""
        return self._lazy_session("_online_encoder", "encoder_online.onnx")

    @property
    def decoder(self):
        """Session for ``decoder.onnx``, created on first use."""
        return self._lazy_session("_decoder", "decoder.onnx")

    def _new_session(self, filename):
        """Create an ONNX Runtime session for ``filename`` in ``onnx_dir``."""
        # Imported here so the module loads without onnxruntime installed.
        import onnxruntime as ort
        model_path = os.path.join(self.onnx_dir, filename)
        return ort.InferenceSession(model_path, providers=self.providers)

    def compute_feats(self, audio_file):
        """Return the module-level ``compute_feats`` result for the file."""
        return compute_feats(audio_file)

    def run_offline_encoder(self,
                            feats,
                            calib_data_path=None,
                            sample_id=0):
        """Run the offline encoder on ``feats``.

        When ``calib_data_path`` is truthy the encoder inputs are saved
        there under ``sample_id`` before the model runs.

        Returns:
            A dict naming the five encoder outputs.
        """
        encoder_input = make_offline_inputs(feats, self.offline_seq_len)
        if calib_data_path:
            save_calibration_inputs(calib_data_path, encoder_input, sample_id)
        (encoder_out,
         encoder_out_lens,
         ctc_log_probs,
         beam_log_probs,
         beam_log_probs_idx) = self.offline_encoder.run(None, encoder_input)
        return {
            "encoder_out": encoder_out,
            "encoder_out_lens": encoder_out_lens,
            "ctc_log_probs": ctc_log_probs,
            "beam_log_probs": beam_log_probs,
            "beam_log_probs_idx": beam_log_probs_idx,
        }

    def run_online_encoder(self,
                           feats,
                           calib_data_path=None,
                           sample_prefix=0):
        """Run the streaming encoder chunk by chunk over ``feats``.

        When ``calib_data_path`` is truthy every chunk's inputs are saved
        there under ``"<sample_prefix>_<chunk start frame>"``.

        Returns:
            A dict with the per-chunk outputs concatenated along axis 1,
            ``encoder_out_lens`` filled with the total axis-1 size of the
            encoder output, and ``num_chunks``.
        """
        state, online_params = make_online_initial_state(
            self.configs, self.batch_size, self.decoding_chunk_size,
            self.num_decoding_left_chunks)
        context = online_params["context"]
        stride = online_params["stride"]

        encoder_out = []
        beam_log_probs = []
        beam_log_probs_idx = []
        num_frames = feats.shape[1]
        for cur in range(0, num_frames - context + 1, stride):
            encoder_input = make_online_encoder_input(feats, cur,
                                                      online_params, state)
            if calib_data_path:
                save_calibration_inputs(calib_data_path, encoder_input,
                                        f"{sample_prefix}_{cur}")
            outputs = self.online_encoder.run(None, encoder_input)
            # The first four outputs are results (the chunk lengths are not
            # used); the rest carry the caches for the next chunk.
            chunk_log_probs, chunk_log_probs_idx, chunk_out, _ = outputs[:4]
            update_online_state(state, outputs)
            encoder_out.append(chunk_out)
            beam_log_probs.append(chunk_log_probs)
            beam_log_probs_idx.append(chunk_log_probs_idx.astype(np.int32))

        return {
            "encoder_out": np.concatenate(encoder_out, axis=1),
            "encoder_out_lens": np.full(
                self.batch_size,
                fill_value=sum(out.shape[1] for out in encoder_out),
                dtype=np.int32),
            "beam_log_probs": np.concatenate(beam_log_probs, axis=1),
            "beam_log_probs_idx": np.concatenate(beam_log_probs_idx, axis=1),
            "num_chunks": len(encoder_out),
        }

    def ctc_decode(self, encoder_outputs, mode):
        """Run ``ctc_decoding`` on an encoder output dict with ``mode``."""
        return ctc_decoding(encoder_outputs["beam_log_probs"],
                            encoder_outputs["beam_log_probs_idx"],
                            encoder_outputs["encoder_out_lens"],
                            self.vocabulary, mode)

    def run_decoder(self,
                    encoder_outputs,
                    calib_data_path=None,
                    sample_id=0):
        """Run the decoder model and return the selected hypothesis text.

        When ``calib_data_path`` is truthy the decoder inputs are saved
        there under ``sample_id`` before the model runs.

        Returns:
            The text of the selected hypotheses, joined without a separator.
        """
        decoder_input, all_hyps = make_decoder_inputs(
            encoder_outputs["encoder_out"],
            encoder_outputs["encoder_out_lens"],
            encoder_outputs["beam_log_probs"],
            encoder_outputs["beam_log_probs_idx"],
            self.vocabulary,
            self.sos,
            self.eos,
            self.decoder_len,
        )
        if calib_data_path:
            save_calibration_inputs(calib_data_path, decoder_input, sample_id)
        best_index = self.decoder.run(None, decoder_input)[0].astype(np.int32)
        beam_size = encoder_outputs["beam_log_probs"].shape[-1]
        num_processes = min(multiprocessing.cpu_count(), best_index.shape[0])
        # all_hyps is flat with beam_size entries per utterance; pick the
        # entry the decoder selected inside each utterance's slice.
        best_sents = [
            all_hyps[row * beam_size:(row + 1) * beam_size][idx]
            for row, idx in enumerate(best_index)
        ]
        hyps = map_batch(best_sents, self.vocabulary, num_processes)
        return "".join(hyps)

    def transcribe(self,
                   audio_file,
                   online=False,
                   mode="ctc_prefix_beam_search",
                   calib_data_path=None):
        """Transcribe one audio file and return the text.

        Args:
            audio_file: path of the audio file.
            online: use the streaming encoder instead of the offline one.
            mode: ``"attention_rescoring"`` runs the decoder model; any other
                value is passed to ``ctc_decode``.
            calib_data_path: when truthy, model inputs are saved there under
                sample id 0 and then packed with ``pack_calibration_dataset``.
        """
        feats = self.compute_feats(audio_file)
        if online:
            encoder_outputs = self.run_online_encoder(
                feats, calib_data_path=calib_data_path, sample_prefix=0)
        else:
            encoder_outputs = self.run_offline_encoder(
                feats, calib_data_path=calib_data_path, sample_id=0)

        if mode == "attention_rescoring":
            result = self.run_decoder(encoder_outputs, calib_data_path, 0)
        else:
            hyps, _ = self.ctc_decode(encoder_outputs, mode)
            result = "".join(hyps) if hyps else ""

        if calib_data_path:
            pack_calibration_dataset(calib_data_path)
        return result

    def save_calibration_for_audio(self,
                                   audio_file,
                                   parts,
                                   calib_data_path,
                                   sample_id):
        """Save calibration inputs for the requested model parts.

        Args:
            audio_file: path of the audio file.
            parts: container checked for ``"offline"``, ``"online"`` and
                ``"decoder"``.
            calib_data_path: directory the inputs are saved under.
            sample_id: id used to name the saved samples.

        Returns:
            A dict counting the samples produced per part. The online count
            is the number of chunks.
        """
        counts = {"offline": 0, "online": 0, "decoder": 0}
        feats = self.compute_feats(audio_file)

        offline_outputs = None
        if "offline" in parts or "decoder" in parts:
            # The decoder needs offline encoder outputs, so the encoder runs
            # for it too, but its inputs are saved only when requested.
            offline_outputs = self.run_offline_encoder(
                feats,
                calib_data_path=calib_data_path if "offline" in parts else None,
                sample_id=sample_id)
            if "offline" in parts:
                counts["offline"] += 1

        if "online" in parts:
            online_outputs = self.run_online_encoder(
                feats,
                calib_data_path=calib_data_path,
                sample_prefix=sample_id)
            counts["online"] += online_outputs["num_chunks"]

        if "decoder" in parts:
            self.run_decoder(offline_outputs,
                             calib_data_path=calib_data_path,
                             sample_id=sample_id)
            counts["decoder"] += 1
        return counts