import os
import time
from typing import List

import numpy as np
import torch

from riva.asrlib.decoder.python_decoder import (
    BatchedMappedDecoderCuda,
    BatchedMappedDecoderCudaConfig,
    BatchedMappedOnlineDecoderCuda,
)

from test_frame_reducer import FrameReducer

USE_FINAL_PROBS = False


def remove_duplicates_and_blank(hyp: List[int],
                                eos: int,
                                blank_id: int = 0) -> List[int]:
    """Collapse repeated CTC labels and drop blank and eos tokens."""
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        if hyp[cur] != blank_id and hyp[cur] != eos:
            new_hyp.append(hyp[cur])
        prev = cur
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp
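# For example (illustrative values, not taken from the script's test data):
#   remove_duplicates_and_blank([0, 3, 3, 0, 5, 5, 2], eos=2, blank_id=0)
# collapses the repeats and drops blank/eos, returning [3, 5].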


class RivaWFSTDecoder:
    def __init__(self, vocab_size, tlg_dir, config_dict=None, beam_size=8):
        # config_dict is accepted for external overrides but is not applied
        # here; all decoder options below are hard-coded.
        config = BatchedMappedDecoderCudaConfig()
        config.online_opts.decoder_opts.lattice_beam = beam_size

        config.online_opts.lattice_postprocessor_opts.acoustic_scale = 10.0
        config.n_input_per_chunk = 50
        config.online_opts.decoder_opts.default_beam = 17.0
        config.online_opts.decoder_opts.max_active = 7000
        config.online_opts.determinize_lattice = True
        config.online_opts.max_batch_size = 100
        config.online_opts.num_channels = config.online_opts.max_batch_size * 2
        config.online_opts.frame_shift_seconds = 0.04
        config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
        config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
        config.online_opts.decoder_opts.blank_penalty = 0.95
        config.online_opts.num_post_processing_worker_threads = 16
        config.online_opts.num_decoder_copy_threads = 4
        config.online_opts.use_final_probs = USE_FINAL_PROBS
        config.online_opts.lattice_postprocessor_opts.nbest = beam_size

        # Offline (whole-utterance) and online (streaming) decoders built from
        # the same TLG graph and word symbol table.
        self.decoder = BatchedMappedDecoderCuda(
            config, os.path.join(tlg_dir, "TLG.fst"),
            os.path.join(tlg_dir, "words.txt"), vocab_size
        )
        self.online_decoder = BatchedMappedOnlineDecoderCuda(
            config.online_opts, os.path.join(tlg_dir, "TLG.fst"),
            os.path.join(tlg_dir, "words.txt"), vocab_size
        )

        self.word_id_to_word_str = load_word_symbols(os.path.join(tlg_dir, "words.txt"))
        self.nbest = beam_size
        self.vocab_size = vocab_size

    def decode_nbest(self, logits, length):
        logits = logits.to(torch.float32).contiguous()
        sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
        results = self.decoder.decode_nbest(logits, sequence_lengths_tensor)
        total_hyps = []
        for nbest_sentences in results:
            nbest_list = []
            for sent in nbest_sentences:
                # Lattice ilabels are offset by +1 relative to the network's
                # CTC outputs, so shift back before de-duplication.
                hyp_ids = [label - 1 for label in sent.ilabels]
                new_hyp = remove_duplicates_and_blank(hyp_ids, eos=self.vocab_size - 1, blank_id=0)
                nbest_list.append(new_hyp)
            total_hyps.append(nbest_list)
        return total_hyps

    def decode_mbr(self, logits, length):
        logits = logits.to(torch.float32).contiguous()
        sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
        results = self.decoder.decode_mbr(logits, sequence_lengths_tensor)
        total_hyps = []
        for sent in results:
            # Keep only the word string (first field) of each MBR output entry.
            hyp = [word[0] for word in sent]
            hyp_zh = "".join(hyp)
            total_hyps.append(hyp_zh)
        return total_hyps


def load_word_symbols(path):
    """Load a Kaldi-style symbol table mapping integer ids to word strings."""
    word_id_to_word_str = {}
    with open(path, "rt", encoding="utf-8") as fh:
        for line in fh:
            word_str, word_id = line.rstrip().split()
            word_id_to_word_str[int(word_id)] = word_str
    return word_id_to_word_str
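# Expected symbol-table format: one "<symbol> <integer-id>" pair per line, e.g.
#   <eps> 0
#   你好 1
# (illustrative entries only, not taken from the actual words.txt).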


if __name__ == "__main__":
    lang_dir = "../output"
    data = np.load('./data/input3.npz')
    word_id_to_word_str = load_word_symbols(os.path.join(lang_dir, "words.txt"))
    char_dict = load_word_symbols('./data/words.txt')

    beam_size = 7
    batch_size = 1
    counts = 1

    # Tile the single test utterance to the desired batch size.
    ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])
    ctc_log_probs = ctc_log_probs.repeat(batch_size, 1, 1)
    encoder_out_lens = torch.from_numpy(data['encoder_out_len'])
    encoder_out_lens = encoder_out_lens.repeat(batch_size)
    ctc_log_probs = ctc_log_probs.contiguous().cuda()
    frame_reducer = FrameReducer()  # instantiated but not used in this test

    vocab_size = ctc_log_probs.shape[2]
    # beam_size must be passed by keyword: the third positional parameter of
    # RivaWFSTDecoder is config_dict, not beam_size.
    riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size=beam_size)

    # Offline (whole-utterance) MBR decoding.
    decode_start = time.perf_counter()
    for i in range(counts):
        print("ctc_log_probs.shape:", ctc_log_probs.shape)
        total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
        print('mbr', 'use_final_probs:', USE_FINAL_PROBS, total_hyps)
    decode_end = time.perf_counter() - decode_start
    print('offline decode time: %.3fs' % decode_end)
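    # A hedged n-best example (not part of the original run): decode_nbest
    # returns lists of CTC token ids, joined into text here via char_dict,
    # assuming ./data/words.txt maps CTC unit ids to unit symbols.
    nbest_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
    for utt_nbest in nbest_hyps:
        for hyp_ids in utt_nbest:
            print('nbest', "".join(char_dict.get(i, '<unk>') for i in hyp_ids))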

    # Online (streaming) decoding of the same batch, fed as one full chunk.
    ctc_log_probs_list = []
    is_first_chunk = [True] * batch_size
    is_last_chunk = [True] * batch_size
    corr_ids = list(range(batch_size))
    for corr_id in corr_ids:
        success = riva_decoder.online_decoder.try_init_corr_id(corr_id)
        assert success
    for i in range(batch_size):
        ctc_log_probs_list.append(ctc_log_probs[i, :, :])
    channels, partial_hypotheses = riva_decoder.online_decoder.decode_batch(
        corr_ids, ctc_log_probs_list, is_first_chunk, is_last_chunk)

    for j, ph in enumerate(partial_hypotheses):
        # Per-utterance partial hypotheses could be inspected here.
        pass
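    # A minimal chunked-streaming sketch (an assumption, not from the original
    # script): split one utterance's log-probs along the time axis and feed the
    # pieces through decode_batch with is_first_chunk / is_last_chunk flags,
    # reusing the same API as above. The 50-frame chunk size and the fresh
    # correlation id are arbitrary choices for illustration.
    chunk_frames = 50
    num_frames = int(encoder_out_lens[0])
    stream_corr_id = batch_size  # an unused correlation id for this session
    assert riva_decoder.online_decoder.try_init_corr_id(stream_corr_id)
    for start in range(0, num_frames, chunk_frames):
        end = min(start + chunk_frames, num_frames)
        chunk = ctc_log_probs[0, start:end, :].contiguous()
        _channels, partials = riva_decoder.online_decoder.decode_batch(
            [stream_corr_id], [chunk], [start == 0], [end == num_frames])
        # partials[0] corresponds to this utterance's running partial result.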