import os
import time
from typing import List

import numpy as np
import torch
from riva.asrlib.decoder.python_decoder import (
    BatchedMappedDecoderCuda,
    BatchedMappedDecoderCudaConfig,
    BatchedMappedOnlineDecoderCuda,
)

from test_frame_reducer import FrameReducer

USE_FINAL_PROBS = False

def remove_duplicates_and_blank(hyp: List[int],
                                eos: int,
                                blank_id: int = 0) -> List[int]:
    """Collapse repeated CTC labels and drop blank/eos tokens."""
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        if hyp[cur] != blank_id and hyp[cur] != eos:
            new_hyp.append(hyp[cur])
        prev = cur
        # Skip over the entire run of identical labels.
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp
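
# Quick sanity check (assuming blank id 0 and eos 4232, as used below):
#   remove_duplicates_and_blank([0, 5, 5, 0, 5, 7, 4232], eos=4232)
#   -> [5, 5, 7]   (repeats collapse, but the blank between them keeps both 5s)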

class RivaWFSTDecoder:
    def __init__(self, vocab_size, tlg_dir, config_dict=None, beam_size=8):
        # config_dict is currently unused; all options are set explicitly.
        config = BatchedMappedDecoderCudaConfig()
        config.online_opts.decoder_opts.lattice_beam = beam_size
        config.online_opts.lattice_postprocessor_opts.acoustic_scale = 10.0
        config.n_input_per_chunk = 50
        config.online_opts.decoder_opts.default_beam = 17.0
        config.online_opts.decoder_opts.max_active = 7000
        config.online_opts.determinize_lattice = True
        config.online_opts.max_batch_size = 100
        config.online_opts.num_channels = config.online_opts.max_batch_size * 2
        config.online_opts.frame_shift_seconds = 0.04
        config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
        config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
        config.online_opts.decoder_opts.blank_penalty = 0.95
        config.online_opts.num_post_processing_worker_threads = 16
        config.online_opts.num_decoder_copy_threads = 4
        config.online_opts.use_final_probs = USE_FINAL_PROBS
        # config.online_opts.decoder_opts.ntokens_pre_allocated = 10_000_000
        config.online_opts.lattice_postprocessor_opts.nbest = beam_size
        self.decoder = BatchedMappedDecoderCuda(
            config, os.path.join(tlg_dir, "TLG.fst"),
            os.path.join(tlg_dir, "words.txt"), vocab_size
        )
        self.online_decoder = BatchedMappedOnlineDecoderCuda(
            config.online_opts, os.path.join(tlg_dir, "TLG.fst"),
            os.path.join(tlg_dir, "words.txt"), vocab_size
        )
        self.word_id_to_word_str = load_word_symbols(
            os.path.join(tlg_dir, "words.txt"))
        self.nbest = beam_size
        self.vocab_size = vocab_size
    def decode_nbest(self, logits, length):
        logits = logits.to(torch.float32).contiguous()
        sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
        results = self.decoder.decode_nbest(logits, sequence_lengths_tensor)
        total_hyps = []
        for nbest_sentences in results:
            nbest_list = []
            for sent in nbest_sentences:
                # Subtract 1 to recover the label id, since the FST decoder
                # adds 1 to every input label.
                hyp_ids = [label - 1 for label in sent.ilabels]
                new_hyp = remove_duplicates_and_blank(
                    hyp_ids, eos=self.vocab_size - 1, blank_id=0)
                nbest_list.append(new_hyp)
            total_hyps.append(nbest_list)
        return total_hyps
    def decode_mbr(self, logits, length):
        logits = logits.to(torch.float32).contiguous()
        sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
        results = self.decoder.decode_mbr(logits, sequence_lengths_tensor)
        total_hyps = []
        for sent in results:
            # print([word for word in sent])
            # Each entry carries the word string first; join without spaces
            # for Chinese output.
            hyp = [word[0] for word in sent]
            hyp_zh = "".join(hyp)
            total_hyps.append(hyp_zh)
        return total_hyps

def load_word_symbols(path):
    """Read a Kaldi-style words.txt symbol table into an id -> word dict."""
    word_id_to_word_str = {}
    with open(path, "rt") as fh:
        for line in fh:
            word_str, word_id = line.rstrip().split()
            word_id_to_word_str[int(word_id)] = word_str
    return word_id_to_word_str
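
# words.txt is one "<symbol> <id>" pair per line; the exact entries below are
# only illustrative and depend on how the lang dir was built, e.g.:
#   <eps> 0
#   你好 1
#   世界 2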

if __name__ == "__main__":
    lang_dir = "../output"  # must contain TLG.fst and words.txt
    data = np.load('./data/input3.npz')
    word_id_to_word_str = load_word_symbols(os.path.join(lang_dir, "words.txt"))
    char_dict = load_word_symbols('./data/words.txt')
    beam_size = 7
    batch_size = 1
    counts = 1

    # ctc_log_probs: [1, 103, 4233], repeated to [batch_size, T, vocab_size].
    ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])
    ctc_log_probs = ctc_log_probs.repeat(batch_size, 1, 1)
    encoder_out_lens = torch.from_numpy(data['encoder_out_len'])  # single element: 103
    encoder_out_lens = encoder_out_lens.repeat(batch_size)  # [batch_size]
    ctc_log_probs = ctc_log_probs.contiguous().cuda()

    frame_reducer = FrameReducer()
    # ctc_log_probs, encoder_out_lens = frame_reducer(ctc_log_probs, encoder_out_lens.cuda(), ctc_log_probs)
    vocab_size = ctc_log_probs.shape[2]
    # Pass beam_size by keyword: the third positional slot is config_dict.
    riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size=beam_size)

    decode_start = time.perf_counter()
    for i in range(counts):
        print("ctc_log_probs.shape:", ctc_log_probs.shape)
        total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
        print('mbr', 'use_final_probs:', USE_FINAL_PROBS, total_hyps)
        # total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
        # print('nbest', total_hyps)
    decode_end = time.perf_counter() - decode_start
    # Streaming decode: here each utterance is fed as a single chunk that is
    # both the first and the last chunk of its correlation id.
    ctc_log_probs_list = []
    is_first_chunk = [True] * batch_size
    is_last_chunk = [True] * batch_size
    corr_ids = list(range(batch_size))
    for corr_id in corr_ids:
        success = riva_decoder.online_decoder.try_init_corr_id(corr_id)
        assert success
    for i in range(batch_size):
        ctc_log_probs_list.append(ctc_log_probs[i, :, :])
    channels, partial_hypotheses = riva_decoder.online_decoder.decode_batch(
        corr_ids, ctc_log_probs_list, is_first_chunk, is_last_chunk)
    for j, ph in enumerate(partial_hypotheses):
        # print("streaming word ids", ph.words, ph.score)
        pass
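
    # A chunked variant of the streaming call above (a sketch, not verified:
    # chunk_size is arbitrary, and fresh correlation ids are assumed to be
    # needed once the previous ones have received their last chunk).
    chunk_size = 32
    stream_ids = [cid + batch_size for cid in corr_ids]
    for corr_id in stream_ids:
        assert riva_decoder.online_decoder.try_init_corr_id(corr_id)
    num_frames = ctc_log_probs.shape[1]
    for start in range(0, num_frames, chunk_size):
        end = min(start + chunk_size, num_frames)
        chunk_list = [ctc_log_probs[i, start:end, :] for i in range(batch_size)]
        first = [start == 0] * batch_size
        last = [end == num_frames] * batch_size
        _, partial_hyps = riva_decoder.online_decoder.decode_batch(
            stream_ids, chunk_list, first, last)
        # partial_hyps[j].words / .score hold the running hypothesis per stream.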
    # print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")