Yuekai Zhang committed on
Commit
96f87e0
·
1 Parent(s): c92992e

add streaming support

Browse files
Files changed (1) hide show
  1. test/test_riva_wfst_decoder.py +27 -3
test/test_riva_wfst_decoder.py CHANGED
@@ -2,7 +2,7 @@ import numpy as np
2
  import time
3
  import torch
4
  import os
5
- from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedDecoderCudaConfig
6
  from typing import List
7
  from test_frame_reducer import FrameReducer
8
 
@@ -30,7 +30,7 @@ class RivaWFSTDecoder:
30
  config.online_opts.decoder_opts.max_active = 7000
31
  config.online_opts.determinize_lattice = True
32
  config.online_opts.max_batch_size = 100
33
- config.online_opts.num_channels = 200
34
  config.online_opts.frame_shift_seconds = 0.04
35
  config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
36
  config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
@@ -38,12 +38,20 @@ class RivaWFSTDecoder:
38
  config.online_opts.num_post_processing_worker_threads = 16
39
  config.online_opts.num_decoder_copy_threads = 4
40
 
 
 
41
  config.online_opts.lattice_postprocessor_opts.nbest = beam_size
42
 
43
  self.decoder = BatchedMappedDecoderCuda(
44
  config, os.path.join(tlg_dir, "TLG.fst"),
45
  os.path.join(tlg_dir, "words.txt"), vocab_size
46
  )
 
 
 
 
 
 
47
  self.word_id_to_word_str = load_word_symbols(os.path.join(tlg_dir, "words.txt"))
48
  self.nbest = beam_size
49
  self.vocab_size = vocab_size
@@ -91,7 +99,7 @@ if __name__ == "__main__":
91
  char_dict = load_word_symbols('./data/words.txt')
92
 
93
  beam_size = 10
94
- batch_size = 1
95
  counts = 10
96
 
97
  # ctc_log_probs [1,103,4233]
@@ -116,4 +124,20 @@ if __name__ == "__main__":
116
  # total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
117
  # print('nbest', total_hyps)
118
  decode_end = time.perf_counter() - decode_start
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")
 
2
  import time
3
  import torch
4
  import os
5
+ from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedOnlineDecoderCuda, BatchedMappedDecoderCudaConfig
6
  from typing import List
7
  from test_frame_reducer import FrameReducer
8
 
 
30
  config.online_opts.decoder_opts.max_active = 7000
31
  config.online_opts.determinize_lattice = True
32
  config.online_opts.max_batch_size = 100
33
+ config.online_opts.num_channels = config.online_opts.max_batch_size * 2
34
  config.online_opts.frame_shift_seconds = 0.04
35
  config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
36
  config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
 
38
  config.online_opts.num_post_processing_worker_threads = 16
39
  config.online_opts.num_decoder_copy_threads = 4
40
 
41
+ #config.online_opts.decoder_opts.ntokens_pre_allocated = 10_000_000
42
+
43
  config.online_opts.lattice_postprocessor_opts.nbest = beam_size
44
 
45
  self.decoder = BatchedMappedDecoderCuda(
46
  config, os.path.join(tlg_dir, "TLG.fst"),
47
  os.path.join(tlg_dir, "words.txt"), vocab_size
48
  )
49
+
50
+ self.online_decoder = BatchedMappedOnlineDecoderCuda(
51
+ config.online_opts, os.path.join(tlg_dir, "TLG.fst"),
52
+ os.path.join(tlg_dir, "words.txt"), vocab_size
53
+ )
54
+
55
  self.word_id_to_word_str = load_word_symbols(os.path.join(tlg_dir, "words.txt"))
56
  self.nbest = beam_size
57
  self.vocab_size = vocab_size
 
99
  char_dict = load_word_symbols('./data/words.txt')
100
 
101
  beam_size = 10
102
+ batch_size = 10
103
  counts = 10
104
 
105
  # ctc_log_probs [1,103,4233]
 
124
  # total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
125
  # print('nbest', total_hyps)
126
  decode_end = time.perf_counter() - decode_start
127
+ #chunk_size = 32
128
+ ctc_log_probs_list, is_first_chunk, is_last_chunk = [], [True] * batch_size, [True] * batch_size
129
+ corr_ids = list(range(batch_size))
130
+ for corr_id in corr_ids:
131
+ success = riva_decoder.online_decoder.try_init_corr_id(corr_id)
132
+ assert success
133
+ for i in range(batch_size):
134
+ #ctc_log_probs_list.append(ctc_log_probs[i,:chunk_size,:])
135
+ ctc_log_probs_list.append(ctc_log_probs[i,:,:])
136
+ channels, partial_hypotheses = \
137
+ riva_decoder.online_decoder.decode_batch(corr_ids, ctc_log_probs_list,
138
+ is_first_chunk, is_last_chunk)
139
+
140
+ for j, ph in enumerate(partial_hypotheses):
141
+ print(j, ph.words, ph.score, ph.ilabels)
142
+
143
  print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")