import os
import time
from typing import List

import numpy as np
import torch
from riva.asrlib.decoder.python_decoder import (
    BatchedMappedDecoderCuda,
    BatchedMappedDecoderCudaConfig,
    BatchedMappedOnlineDecoderCuda,
)

from test_frame_reducer import FrameReducer

# Whether the online (streaming) decoder should use final probabilities
# when a stream is flushed; forwarded into config.online_opts below.
USE_FINAL_PROBS = False


def remove_duplicates_and_blank(hyp: List[int], eos: int, blank_id: int = 0) -> List[int]:
    """CTC-style post-processing of a raw label sequence.

    Collapses each run of identical consecutive labels to a single token and
    drops every blank and eos label from the output.

    Args:
        hyp: raw label ids produced by the decoder.
        eos: label id treated as end-of-sequence; removed from the output.
        blank_id: CTC blank label id; removed from the output (default 0).

    Returns:
        The de-duplicated label sequence without blanks or eos.
    """
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        if hyp[cur] != blank_id and hyp[cur] != eos:
            new_hyp.append(hyp[cur])
        prev = cur
        # Advance past the entire run of labels identical to hyp[prev].
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp


class RivaWFSTDecoder:
    """Wrapper around NVIDIA Riva's CUDA WFST (TLG) decoders.

    Builds an offline batched decoder and an online (streaming) decoder from
    the same configuration and the TLG.fst / words.txt found in ``tlg_dir``.
    """

    def __init__(self, vocab_size, tlg_dir, config_dict=None, beam_size=8):
        """Construct the offline and online decoders.

        Args:
            vocab_size: size of the acoustic model's output vocabulary.
            tlg_dir: directory containing ``TLG.fst`` and ``words.txt``.
            config_dict: currently unused; kept for interface compatibility.
                NOTE(review): callers must pass ``beam_size`` by keyword or
                it will be bound to this parameter instead.
            beam_size: lattice beam width; also used as the n-best size.
        """
        config = BatchedMappedDecoderCudaConfig()
        config.online_opts.decoder_opts.lattice_beam = beam_size
        config.online_opts.lattice_postprocessor_opts.acoustic_scale = 10.0  # noqa
        config.n_input_per_chunk = 50
        config.online_opts.decoder_opts.default_beam = 17.0
        config.online_opts.decoder_opts.max_active = 7000
        config.online_opts.determinize_lattice = True
        config.online_opts.max_batch_size = 100
        config.online_opts.num_channels = config.online_opts.max_batch_size * 2
        config.online_opts.frame_shift_seconds = 0.04
        config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
        config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
        config.online_opts.decoder_opts.blank_penalty = 0.95
        config.online_opts.num_post_processing_worker_threads = 16
        config.online_opts.num_decoder_copy_threads = 4
        config.online_opts.use_final_probs = USE_FINAL_PROBS
        # config.online_opts.decoder_opts.ntokens_pre_allocated = 10_000_000
        config.online_opts.lattice_postprocessor_opts.nbest = beam_size

        fst_path = os.path.join(tlg_dir, "TLG.fst")
        words_path = os.path.join(tlg_dir, "words.txt")
        self.decoder = BatchedMappedDecoderCuda(
            config, fst_path, words_path, vocab_size
        )
        self.online_decoder = BatchedMappedOnlineDecoderCuda(
            config.online_opts, fst_path, words_path, vocab_size
        )
        self.word_id_to_word_str = load_word_symbols(words_path)
        self.nbest = beam_size
        self.vocab_size = vocab_size

    def decode_nbest(self, logits, length):
        """Offline n-best decoding.

        Args:
            logits: float tensor of per-frame log-probabilities
                (assumed shape [batch, T, vocab_size] — TODO confirm).
            length: per-utterance frame counts; moved to CPU int64.

        Returns:
            For each utterance, a list of n-best label-id hypotheses after
            CTC de-duplication and blank/eos removal.
        """
        logits = logits.to(torch.float32).contiguous()
        sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
        results = self.decoder.decode_nbest(logits, sequence_lengths_tensor)
        total_hyps = []
        for nbest_sentences in results:
            nbest_list = []
            for sent in nbest_sentences:
                # subtract 1 to get the label id, since fst decoder adds 1 to the label id
                hyp_ids = [label - 1 for label in sent.ilabels]
                new_hyp = remove_duplicates_and_blank(
                    hyp_ids, eos=self.vocab_size - 1, blank_id=0
                )
                nbest_list.append(new_hyp)
            total_hyps.append(nbest_list)
        return total_hyps

    def decode_mbr(self, logits, length):
        """Offline MBR decoding.

        Args:
            logits: float tensor of per-frame log-probabilities.
            length: per-utterance frame counts; moved to CPU int64.

        Returns:
            One string per utterance, formed by joining the word entries
            with no separator (suitable for Chinese output).
        """
        logits = logits.to(torch.float32).contiguous()
        sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
        results = self.decoder.decode_mbr(logits, sequence_lengths_tensor)
        total_hyps = []
        for sent in results:
            # print([word for word in sent])
            # word[0] is assumed to be the word string — verify against the
            # riva decoder's MBR result layout.
            hyp = [word[0] for word in sent]
            hyp_zh = "".join(hyp)
            total_hyps.append(hyp_zh)
        return total_hyps


def load_word_symbols(path):
    """Read a Kaldi-style words.txt symbol table.

    Each line holds ``<word> <id>``; returns a dict mapping int id -> word.
    """
    word_id_to_word_str = {}
    with open(path, "rt") as fh:
        for line in fh:
            word_str, word_id = line.rstrip().split()
            word_id_to_word_str[int(word_id)] = word_str
    return word_id_to_word_str


if __name__ == "__main__":
    lang_dir = "../output"  # TLG.fst, words.txt
    data = np.load('./data/input3.npz')
    word_id_to_word_str = load_word_symbols(os.path.join(lang_dir, "words.txt"))
    char_dict = load_word_symbols('./data/words.txt')
    beam_size = 7
    batch_size = 1
    counts = 1

    # ctc_log_probs [1,103,4233]
    ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])  # [batch_size, T, vocab_size]
    ctc_log_probs = ctc_log_probs.repeat(batch_size, 1, 1)
    encoder_out_lens = torch.from_numpy(data['encoder_out_len'])  # single element 103
    # encoder_out_lens = torch.from_numpy(data['encoder_out_lens'])  # single element 103
    encoder_out_lens = encoder_out_lens.repeat(batch_size)  # [batch_size]
    ctc_log_probs = ctc_log_probs.contiguous().cuda()

    frame_reducer = FrameReducer()
    # ctc_log_probs, encoder_out_lens = frame_reducer(ctc_log_probs, encoder_out_lens.cuda(), ctc_log_probs)

    vocab_size = ctc_log_probs.shape[2]
    # BUG FIX: beam_size must be passed by keyword — positionally it would be
    # bound to the unused config_dict parameter and the default beam of 8
    # would silently be used instead.
    riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size=beam_size)

    decode_start = time.perf_counter()
    for i in range(counts):
        print("ctc_log_probs.shape:", ctc_log_probs.shape)
        total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
        print('mbr', 'use_final_probs:', USE_FINAL_PROBS, total_hyps)
        # total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
        # print('nbest', total_hyps)
    decode_end = time.perf_counter() - decode_start

    # Streaming path: feed each utterance as a single first-and-last chunk.
    # chunk_size = 32
    ctc_log_probs_list = []
    is_first_chunk = [True] * batch_size
    is_last_chunk = [True] * batch_size
    corr_ids = list(range(batch_size))
    for corr_id in corr_ids:
        success = riva_decoder.online_decoder.try_init_corr_id(corr_id)
        assert success
    for i in range(batch_size):
        ctc_log_probs_list.append(ctc_log_probs[i, :, :])
    channels, partial_hypotheses = \
        riva_decoder.online_decoder.decode_batch(
            corr_ids, ctc_log_probs_list, is_first_chunk, is_last_chunk)
    for j, ph in enumerate(partial_hypotheses):
        # print("streaming word ids", ph.words, ph.score)
        pass
    # print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")