Yuekai Zhang
commited on
Commit
·
55f6a13
1
Parent(s):
0c5ca4b
update nbest
Browse files- test/data/input2.npz +3 -0
- test/test_riva_wfst_decoder.py +61 -20
test/data/input2.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a7831ad64f1fb28bf1adb5e197d43b00aebb430ac171c4a08db891321220fd90
|
| 3 |
+
size 2083188
|
test/test_riva_wfst_decoder.py
CHANGED
|
@@ -3,48 +3,88 @@ import time
|
|
| 3 |
import torch
|
| 4 |
import os
|
| 5 |
from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedDecoderCudaConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
class RivaWFSTDecoder:
|
| 8 |
-
def __init__(self, vocab_size, tlg_dir, beam_size=8
|
| 9 |
config = BatchedMappedDecoderCudaConfig()
|
| 10 |
-
config.online_opts.
|
|
|
|
|
|
|
| 11 |
config.n_input_per_chunk = 50
|
| 12 |
config.online_opts.decoder_opts.default_beam = 17.0
|
| 13 |
-
config.online_opts.decoder_opts.lattice_beam = beam_size
|
| 14 |
config.online_opts.decoder_opts.max_active = 7000
|
| 15 |
config.online_opts.determinize_lattice = True
|
| 16 |
config.online_opts.max_batch_size = 800
|
| 17 |
-
|
| 18 |
config.online_opts.num_channels = 800
|
| 19 |
config.online_opts.frame_shift_seconds = 0.04
|
| 20 |
-
|
| 21 |
config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
|
| 22 |
config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
|
| 23 |
|
|
|
|
|
|
|
| 24 |
self.decoder = BatchedMappedDecoderCuda(
|
| 25 |
-
config, os.path.join(tlg_dir, "TLG.fst"),
|
|
|
|
| 26 |
)
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
def
|
| 29 |
-
|
| 30 |
sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
|
| 31 |
-
results = self.decoder.
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
def
|
| 35 |
-
|
|
|
|
|
|
|
| 36 |
total_hyps = []
|
| 37 |
for sent in results:
|
| 38 |
hyp = [word[0] for word in sent]
|
| 39 |
hyp_zh = "".join(hyp)
|
| 40 |
-
|
| 41 |
-
total_hyps.append(nbest_list)
|
| 42 |
return total_hyps
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
if __name__ == "__main__":
|
| 46 |
-
lang_dir = "../
|
| 47 |
-
data = np.load('./data/
|
|
|
|
|
|
|
| 48 |
|
| 49 |
beam_size = 10
|
| 50 |
batch_size = 50
|
|
@@ -64,8 +104,9 @@ if __name__ == "__main__":
|
|
| 64 |
decode_start = time.perf_counter()
|
| 65 |
for i in range(counts):
|
| 66 |
print("ctc_log_probs.shape:", ctc_log_probs.shape)
|
| 67 |
-
|
| 68 |
-
total_hyps
|
| 69 |
-
|
|
|
|
| 70 |
decode_end = time.perf_counter() - decode_start
|
| 71 |
print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")
|
|
|
|
| 3 |
import torch
|
| 4 |
import os
|
| 5 |
from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedDecoderCudaConfig
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
def remove_duplicates_and_blank(hyp: List[int],
                                eos: int,
                                blank_id: int = 0) -> List[int]:
    """Post-process a raw CTC label sequence.

    Collapses each run of repeated labels to a single occurrence and drops
    every blank and eos label, per standard CTC decoding.

    Args:
        hyp: raw label-id sequence from the decoder.
        eos: end-of-sequence label id to discard.
        blank_id: CTC blank label id to discard (default 0).

    Returns:
        The collapsed, filtered label sequence.
    """
    cleaned: List[int] = []
    previous = None  # last raw label seen, used to collapse repeats
    for label in hyp:
        if label != previous and label != blank_id and label != eos:
            cleaned.append(label)
        previous = label
    return cleaned
|
| 20 |
|
| 21 |
class RivaWFSTDecoder:
    """Batched CUDA WFST decoder over CTC log-probabilities.

    Wraps Riva's ``BatchedMappedDecoderCuda`` built from a ``TLG.fst`` /
    ``words.txt`` pair and exposes MBR decoding (one word string per
    utterance) and n-best decoding (token-id hypotheses per utterance).
    """

    def __init__(self, vocab_size, tlg_dir, config_dict=None, beam_size=8):
        """Build the CUDA decoder.

        Args:
            vocab_size: size of the CTC output vocabulary; the last id
                (``vocab_size - 1``) is treated as eos when post-processing
                n-best hypotheses.
            tlg_dir: directory containing ``TLG.fst`` and ``words.txt``.
            config_dict: accepted for interface compatibility but currently
                unused -- every decoder option below is hard-coded.
                NOTE(review): either apply these overrides or drop the
                parameter.
            beam_size: used both as the decoder lattice beam and as the
                number of n-best hypotheses requested.
        """
        config = BatchedMappedDecoderCudaConfig()
        config.online_opts.decoder_opts.lattice_beam = beam_size
        config.online_opts.lattice_postprocessor_opts.acoustic_scale = 10.0  # noqa
        config.n_input_per_chunk = 50
        config.online_opts.decoder_opts.default_beam = 17.0
        config.online_opts.decoder_opts.max_active = 7000
        config.online_opts.determinize_lattice = True
        config.online_opts.max_batch_size = 800
        config.online_opts.num_channels = 800
        config.online_opts.frame_shift_seconds = 0.04
        config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
        config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
        config.online_opts.lattice_postprocessor_opts.nbest = beam_size

        self.decoder = BatchedMappedDecoderCuda(
            config, os.path.join(tlg_dir, "TLG.fst"),
            os.path.join(tlg_dir, "words.txt"), vocab_size
        )
        self.word_id_to_word_str = load_word_symbols(os.path.join(tlg_dir, "words.txt"))
        self.nbest = beam_size
        self.vocab_size = vocab_size

    def _prepare_inputs(self, logits, length):
        """Cast inputs to what the decoder expects: contiguous float32
        activations and contiguous int64 sequence lengths on the CPU."""
        logits = logits.to(torch.float32).contiguous()
        sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
        return logits, sequence_lengths_tensor

    def decode_nbest(self, logits, length):
        """Decode a batch and return n-best token-id hypotheses.

        Args:
            logits: batched CTC log-probabilities (torch tensor).
            length: per-utterance frame counts (torch tensor).

        Returns:
            A list (per utterance) of lists (per hypothesis) of label ids,
            with CTC repeats, blanks and the eos label removed.
        """
        logits, sequence_lengths_tensor = self._prepare_inputs(logits, length)
        results = self.decoder.decode_nbest(logits, sequence_lengths_tensor)
        total_hyps = []
        for nbest_sentences in results:
            nbest_list = []
            for sent in nbest_sentences:
                # subtract 1 to get the label id, since fst decoder adds 1 to the label id
                hyp_ids = [label - 1 for label in sent.ilabels]
                new_hyp = remove_duplicates_and_blank(hyp_ids, eos=self.vocab_size - 1, blank_id=0)
                nbest_list.append(new_hyp)
            total_hyps.append(nbest_list)
        return total_hyps

    def decode_mbr(self, logits, length):
        """Decode a batch with MBR and return one string per utterance.

        Words are joined without separators (Chinese-style output, per the
        ``hyp_zh`` naming).
        """
        logits, sequence_lengths_tensor = self._prepare_inputs(logits, length)
        results = self.decoder.decode_mbr(logits, sequence_lengths_tensor)
        total_hyps = []
        for sent in results:
            # Each item of `sent` is a (word, ...) tuple; keep the word only.
            hyp = [word[0] for word in sent]
            hyp_zh = "".join(hyp)
            total_hyps.append(hyp_zh)
        return total_hyps
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def load_word_symbols(path):
    """Load a Kaldi/OpenFst-style symbol table mapping id -> word string.

    Each non-empty line of *path* must be ``<word> <integer-id>``.

    Args:
        path: path to a ``words.txt`` symbol table.

    Returns:
        dict mapping ``int`` symbol id to its word string.

    Raises:
        ValueError: if a non-empty line does not have exactly two fields
            or its id field is not an integer.
    """
    word_id_to_word_str = {}
    # Explicit UTF-8: symbol tables here contain CJK characters, and the
    # locale default encoding is not reliable across platforms.
    with open(path, "rt", encoding="utf-8") as fh:
        for line in fh:
            fields = line.split()
            if not fields:
                continue  # tolerate blank/trailing lines
            word_str, word_id = fields
            word_id_to_word_str[int(word_id)] = word_str
    return word_id_to_word_str
|
| 82 |
|
| 83 |
if __name__ == "__main__":
    # Directory holding the decoding graph and its word symbol table.
    lang_dir = "../output" # TLG.fst, words.txt
    # Pre-computed CTC log-probabilities saved as an .npz fixture
    # (test/data/input2.npz, tracked via git-lfs).
    data = np.load('./data/input2.npz')
    # Word-level symbols for the decoding-graph output labels.
    word_id_to_word_str = load_word_symbols(os.path.join(lang_dir, "words.txt"))
    # Second symbol table loaded from the test data dir -- presumably the
    # CTC modeling units (characters); TODO confirm against the model.
    char_dict = load_word_symbols('./data/words.txt')

    beam_size = 10
    batch_size = 50
|
|
|
|
| 104 |
    # Time `counts` repetitions of decoding the same batch; `counts`,
    # `riva_decoder`, `ctc_log_probs` and `encoder_out_lens` are set up
    # earlier in this script (outside this hunk).
    decode_start = time.perf_counter()
    for i in range(counts):
        print("ctc_log_probs.shape:", ctc_log_probs.shape)
        # MBR path: one joined word string per utterance.
        total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
        print('mbr', total_hyps)
        # N-best path: per utterance, a list of token-id hypotheses.
        total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
        print('nbest', total_hyps)
    decode_end = time.perf_counter() - decode_start
    print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")
|