# aishell1_tlg_essentials/test/test_riva_wfst_decoder.py
# Author: Yuekai Zhang
# Commit: add reproducible bug (170cf1f)
import numpy as np
import time
import torch
import os
from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedOnlineDecoderCuda, BatchedMappedDecoderCudaConfig
from typing import List
from test_frame_reducer import FrameReducer
USE_FINAL_PROBS=False
def remove_duplicates_and_blank(hyp: List[int],
                                eos: int,
                                blank_id: int = 0) -> List[int]:
    """Apply CTC-style collapse to a label sequence.

    Consecutive runs of identical tokens are collapsed to a single token,
    and blank / eos tokens are dropped from the result.

    Args:
        hyp: raw frame-level label ids.
        eos: end-of-sequence token id to discard.
        blank_id: CTC blank token id to discard (default 0).

    Returns:
        Collapsed label sequence without blank or eos tokens.
    """
    collapsed: List[int] = []
    prev_token = None  # sentinel: never equals an int label on the first frame
    for token in hyp:
        if token == prev_token:
            continue  # skip repeats within a run
        prev_token = token
        if token != blank_id and token != eos:
            collapsed.append(token)
    return collapsed
class RivaWFSTDecoder:
    """Batched CUDA WFST decoder (TLG graph) built on Riva's asrlib decoders.

    Wraps both an offline batched decoder (decode_nbest / decode_mbr) and an
    online streaming decoder, created from one shared configuration.
    """

    def __init__(self, vocab_size, tlg_dir, config_dict=None, beam_size=8):
        """Build offline and online decoders from TLG.fst / words.txt in tlg_dir.

        NOTE(review): config_dict is accepted but never read — presumably it
        was meant to override the defaults below; confirm before relying on it.
        """
        cfg = BatchedMappedDecoderCudaConfig()
        # search / pruning options
        cfg.online_opts.decoder_opts.lattice_beam = beam_size
        cfg.online_opts.decoder_opts.default_beam = 17.0
        cfg.online_opts.decoder_opts.max_active = 7000
        cfg.online_opts.decoder_opts.blank_penalty = 0.95
        cfg.online_opts.determinize_lattice = True
        #cfg.online_opts.decoder_opts.ntokens_pre_allocated = 10_000_000
        # batching / threading options
        cfg.n_input_per_chunk = 50
        cfg.online_opts.max_batch_size = 100
        # num_channels derived from max_batch_size, so it must be set after it
        cfg.online_opts.num_channels = cfg.online_opts.max_batch_size * 2
        cfg.online_opts.num_post_processing_worker_threads = 16
        cfg.online_opts.num_decoder_copy_threads = 4
        cfg.online_opts.frame_shift_seconds = 0.04
        cfg.online_opts.use_final_probs = USE_FINAL_PROBS
        # lattice post-processing options
        cfg.online_opts.lattice_postprocessor_opts.acoustic_scale = 10.0
        cfg.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
        cfg.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
        cfg.online_opts.lattice_postprocessor_opts.nbest = beam_size

        fst_path = os.path.join(tlg_dir, "TLG.fst")
        words_path = os.path.join(tlg_dir, "words.txt")
        self.decoder = BatchedMappedDecoderCuda(cfg, fst_path, words_path, vocab_size)
        self.online_decoder = BatchedMappedOnlineDecoderCuda(
            cfg.online_opts, fst_path, words_path, vocab_size)
        self.word_id_to_word_str = load_word_symbols(words_path)
        self.nbest = beam_size
        self.vocab_size = vocab_size

    def decode_nbest(self, logits, length):
        """Offline n-best decode.

        Returns, per utterance, a list of n-best token-id hypotheses after
        CTC collapse.
        """
        logits = logits.to(torch.float32).contiguous()
        lengths = length.to(torch.long).to('cpu').contiguous()
        batch_results = self.decoder.decode_nbest(logits, lengths)
        # The fst decoder emits label ids shifted by +1; undo that shift,
        # then collapse repeats and drop blank/eos.
        return [
            [
                remove_duplicates_and_blank(
                    [label - 1 for label in cand.ilabels],
                    eos=self.vocab_size - 1,
                    blank_id=0,
                )
                for cand in nbest
            ]
            for nbest in batch_results
        ]

    def decode_mbr(self, logits, length):
        """Offline MBR decode; returns one concatenated word string per utterance."""
        logits = logits.to(torch.float32).contiguous()
        lengths = length.to(torch.long).to('cpu').contiguous()
        batch_results = self.decoder.decode_mbr(logits, lengths)
        # each sentence is a sequence of word tuples; keep only the word text
        return ["".join(entry[0] for entry in sent) for sent in batch_results]
def load_word_symbols(path):
    """Load an OpenFst-style symbol table.

    Each line of the file is expected to hold exactly two whitespace-separated
    fields: ``<word> <integer-id>``.

    Args:
        path: path to a words.txt symbol table.

    Returns:
        dict mapping int word id -> word string.
    """
    with open(path, "rt") as fh:
        pairs = (line.rstrip().split() for line in fh)
        return {int(word_id): word_str for word_str, word_id in pairs}
if __name__ == "__main__":
    lang_dir = "../output"  # directory containing TLG.fst, words.txt
    data = np.load('./data/input3.npz')
    word_id_to_word_str = load_word_symbols(os.path.join(lang_dir, "words.txt"))
    char_dict = load_word_symbols('./data/words.txt')
    beam_size = 7
    batch_size = 1
    counts = 1
    # ctc_log_probs from the fixture, e.g. shape [1, 103, 4233]
    ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])
    # tile to [batch_size, T, vocab_size]
    ctc_log_probs = ctc_log_probs.repeat(batch_size, 1, 1)
    encoder_out_lens = torch.from_numpy(data['encoder_out_len'])  # encoder_out_lens single element 103
    #encoder_out_lens = torch.from_numpy(data['encoder_out_lens'])  # encoder_out_lens single element 103
    encoder_out_lens = encoder_out_lens.repeat(batch_size)  # [batch_size]
    ctc_log_probs = ctc_log_probs.contiguous().cuda()
    frame_reducer = FrameReducer()
    #ctc_log_probs, encoder_out_lens = frame_reducer(ctc_log_probs, encoder_out_lens.cuda(), ctc_log_probs)
    vocab_size = ctc_log_probs.shape[2]
    # BUG FIX: beam_size was previously passed positionally as the third
    # argument, landing in the unused config_dict parameter and silently
    # leaving the decoder at its default beam_size of 8. Pass it by keyword.
    riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size=beam_size)

    # --- offline MBR decoding benchmark ---
    decode_start = time.perf_counter()
    for i in range(counts):
        print("ctc_log_probs.shape:", ctc_log_probs.shape)
        total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
        print('mbr', 'use_final_probs:', USE_FINAL_PROBS, total_hyps)
        #total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
        #print('nbest', total_hyps)
    decode_end = time.perf_counter() - decode_start

    # --- online (streaming) decoding: feed each utterance as one full chunk ---
    #chunk_size = 32
    ctc_log_probs_list, is_first_chunk, is_last_chunk = [], [True] * batch_size, [True] * batch_size
    corr_ids = list(range(batch_size))
    for corr_id in corr_ids:
        success = riva_decoder.online_decoder.try_init_corr_id(corr_id)
        assert success
    for i in range(batch_size):
        ctc_log_probs_list.append(ctc_log_probs[i, :, :])
    channels, partial_hypotheses = \
        riva_decoder.online_decoder.decode_batch(corr_ids, ctc_log_probs_list,
                                                 is_first_chunk, is_last_chunk)
    for j, ph in enumerate(partial_hypotheses):
        #print("streaming word ids", ph.words, ph.score)
        pass
    #print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")