Yuekai Zhang
committed on
Commit
·
170cf1f
1
Parent(s):
96f87e0
add reproducible bug
Browse files- test/test_frame_reducer.py +2 -2
- test/test_riva_wfst_decoder.py +17 -12
test/test_frame_reducer.py
CHANGED
|
@@ -96,7 +96,7 @@ class FrameReducer(nn.Module):
|
|
| 96 |
"""
|
| 97 |
N, T, C = x.size()
|
| 98 |
|
| 99 |
-
padding_mask = make_pad_mask(x_lens)
|
| 100 |
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
| 101 |
|
| 102 |
if y_lens is not None:
|
|
@@ -188,4 +188,4 @@ if __name__ == "__main__":
|
|
| 188 |
avg_time += delta_time
|
| 189 |
print(x_fr.shape)
|
| 190 |
print(x_lens_fr)
|
| 191 |
-
print(avg_time / test_times)
|
|
|
|
| 96 |
"""
|
| 97 |
N, T, C = x.size()
|
| 98 |
|
| 99 |
+
padding_mask = make_pad_mask(x_lens, x.size(1))
|
| 100 |
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
| 101 |
|
| 102 |
if y_lens is not None:
|
|
|
|
| 188 |
avg_time += delta_time
|
| 189 |
print(x_fr.shape)
|
| 190 |
print(x_lens_fr)
|
| 191 |
+
print(avg_time / test_times)
|
test/test_riva_wfst_decoder.py
CHANGED
|
@@ -6,6 +6,8 @@ from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, Batched
|
|
| 6 |
from typing import List
|
| 7 |
from test_frame_reducer import FrameReducer
|
| 8 |
|
|
|
|
|
|
|
| 9 |
def remove_duplicates_and_blank(hyp: List[int],
|
| 10 |
eos: int,
|
| 11 |
blank_id: int = 0) -> List[int]:
|
|
@@ -37,6 +39,7 @@ class RivaWFSTDecoder:
|
|
| 37 |
config.online_opts.decoder_opts.blank_penalty = 0.95
|
| 38 |
config.online_opts.num_post_processing_worker_threads = 16
|
| 39 |
config.online_opts.num_decoder_copy_threads = 4
|
|
|
|
| 40 |
|
| 41 |
#config.online_opts.decoder_opts.ntokens_pre_allocated = 10_000_000
|
| 42 |
|
|
@@ -77,6 +80,7 @@ class RivaWFSTDecoder:
|
|
| 77 |
results = self.decoder.decode_mbr(logits, sequence_lengths_tensor)
|
| 78 |
total_hyps = []
|
| 79 |
for sent in results:
|
|
|
|
| 80 |
hyp = [word[0] for word in sent]
|
| 81 |
hyp_zh = "".join(hyp)
|
| 82 |
total_hyps.append(hyp_zh)
|
|
@@ -94,24 +98,25 @@ def load_word_symbols(path):
|
|
| 94 |
|
| 95 |
if __name__ == "__main__":
|
| 96 |
lang_dir = "../output" # TLG.fst, words.txt
|
| 97 |
-
data = np.load('./data/
|
| 98 |
word_id_to_word_str = load_word_symbols(os.path.join(lang_dir, "words.txt"))
|
| 99 |
char_dict = load_word_symbols('./data/words.txt')
|
| 100 |
|
| 101 |
-
beam_size =
|
| 102 |
-
batch_size =
|
| 103 |
-
counts =
|
| 104 |
|
| 105 |
# ctc_log_probs [1,103,4233]
|
| 106 |
ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])
|
| 107 |
# ctc_log_probs , [batch_size,T,vocab_size ]
|
| 108 |
ctc_log_probs = ctc_log_probs.repeat(batch_size,1,1)
|
| 109 |
-
encoder_out_lens = torch.from_numpy(data['
|
|
|
|
| 110 |
encoder_out_lens = encoder_out_lens.repeat(batch_size) # [batch_size]
|
| 111 |
ctc_log_probs = ctc_log_probs.contiguous().cuda()
|
| 112 |
frame_reducer = FrameReducer()
|
| 113 |
|
| 114 |
-
ctc_log_probs, encoder_out_lens = frame_reducer(ctc_log_probs, encoder_out_lens.cuda(), ctc_log_probs)
|
| 115 |
|
| 116 |
vocab_size = ctc_log_probs.shape[2]
|
| 117 |
riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size)
|
|
@@ -120,9 +125,9 @@ if __name__ == "__main__":
|
|
| 120 |
for i in range(counts):
|
| 121 |
print("ctc_log_probs.shape:", ctc_log_probs.shape)
|
| 122 |
total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
|
| 123 |
-
print('mbr', total_hyps)
|
| 124 |
-
#
|
| 125 |
-
#
|
| 126 |
decode_end = time.perf_counter() - decode_start
|
| 127 |
#chunk_size = 32
|
| 128 |
ctc_log_probs_list, is_first_chunk, is_last_chunk = [], [True] * batch_size, [True] * batch_size
|
|
@@ -131,13 +136,13 @@ if __name__ == "__main__":
|
|
| 131 |
success = riva_decoder.online_decoder.try_init_corr_id(corr_id)
|
| 132 |
assert success
|
| 133 |
for i in range(batch_size):
|
| 134 |
-
#ctc_log_probs_list.append(ctc_log_probs[i,:chunk_size,:])
|
| 135 |
ctc_log_probs_list.append(ctc_log_probs[i,:,:])
|
| 136 |
channels, partial_hypotheses = \
|
| 137 |
riva_decoder.online_decoder.decode_batch(corr_ids, ctc_log_probs_list,
|
| 138 |
is_first_chunk, is_last_chunk)
|
| 139 |
|
| 140 |
for j, ph in enumerate(partial_hypotheses):
|
| 141 |
-
print(
|
|
|
|
| 142 |
|
| 143 |
-
print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")
|
|
|
|
| 6 |
from typing import List
|
| 7 |
from test_frame_reducer import FrameReducer
|
| 8 |
|
| 9 |
+
USE_FINAL_PROBS=False
|
| 10 |
+
|
| 11 |
def remove_duplicates_and_blank(hyp: List[int],
|
| 12 |
eos: int,
|
| 13 |
blank_id: int = 0) -> List[int]:
|
|
|
|
| 39 |
config.online_opts.decoder_opts.blank_penalty = 0.95
|
| 40 |
config.online_opts.num_post_processing_worker_threads = 16
|
| 41 |
config.online_opts.num_decoder_copy_threads = 4
|
| 42 |
+
config.online_opts.use_final_probs = USE_FINAL_PROBS
|
| 43 |
|
| 44 |
#config.online_opts.decoder_opts.ntokens_pre_allocated = 10_000_000
|
| 45 |
|
|
|
|
| 80 |
results = self.decoder.decode_mbr(logits, sequence_lengths_tensor)
|
| 81 |
total_hyps = []
|
| 82 |
for sent in results:
|
| 83 |
+
#print([word for word in sent])
|
| 84 |
hyp = [word[0] for word in sent]
|
| 85 |
hyp_zh = "".join(hyp)
|
| 86 |
total_hyps.append(hyp_zh)
|
|
|
|
| 98 |
|
| 99 |
if __name__ == "__main__":
|
| 100 |
lang_dir = "../output" # TLG.fst, words.txt
|
| 101 |
+
data = np.load('./data/input3.npz')
|
| 102 |
word_id_to_word_str = load_word_symbols(os.path.join(lang_dir, "words.txt"))
|
| 103 |
char_dict = load_word_symbols('./data/words.txt')
|
| 104 |
|
| 105 |
+
beam_size = 7
|
| 106 |
+
batch_size = 1
|
| 107 |
+
counts = 1
|
| 108 |
|
| 109 |
# ctc_log_probs [1,103,4233]
|
| 110 |
ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])
|
| 111 |
# ctc_log_probs , [batch_size,T,vocab_size ]
|
| 112 |
ctc_log_probs = ctc_log_probs.repeat(batch_size,1,1)
|
| 113 |
+
encoder_out_lens = torch.from_numpy(data['encoder_out_len']) # encoder_out_lens single element 103
|
| 114 |
+
#encoder_out_lens = torch.from_numpy(data['encoder_out_lens']) # encoder_out_lens single element 103
|
| 115 |
encoder_out_lens = encoder_out_lens.repeat(batch_size) # [batch_size]
|
| 116 |
ctc_log_probs = ctc_log_probs.contiguous().cuda()
|
| 117 |
frame_reducer = FrameReducer()
|
| 118 |
|
| 119 |
+
#ctc_log_probs, encoder_out_lens = frame_reducer(ctc_log_probs, encoder_out_lens.cuda(), ctc_log_probs)
|
| 120 |
|
| 121 |
vocab_size = ctc_log_probs.shape[2]
|
| 122 |
riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size)
|
|
|
|
| 125 |
for i in range(counts):
|
| 126 |
print("ctc_log_probs.shape:", ctc_log_probs.shape)
|
| 127 |
total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
|
| 128 |
+
print('mbr', 'use_final_probs:', USE_FINAL_PROBS, total_hyps)
|
| 129 |
+
#total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
|
| 130 |
+
#print('nbest', total_hyps)
|
| 131 |
decode_end = time.perf_counter() - decode_start
|
| 132 |
#chunk_size = 32
|
| 133 |
ctc_log_probs_list, is_first_chunk, is_last_chunk = [], [True] * batch_size, [True] * batch_size
|
|
|
|
| 136 |
success = riva_decoder.online_decoder.try_init_corr_id(corr_id)
|
| 137 |
assert success
|
| 138 |
for i in range(batch_size):
|
|
|
|
| 139 |
ctc_log_probs_list.append(ctc_log_probs[i,:,:])
|
| 140 |
channels, partial_hypotheses = \
|
| 141 |
riva_decoder.online_decoder.decode_batch(corr_ids, ctc_log_probs_list,
|
| 142 |
is_first_chunk, is_last_chunk)
|
| 143 |
|
| 144 |
for j, ph in enumerate(partial_hypotheses):
|
| 145 |
+
#print("streaming word ids", ph.words, ph.score)
|
| 146 |
+
pass
|
| 147 |
|
| 148 |
+
#print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")
|