Yuekai Zhang committed on
Commit
170cf1f
·
1 Parent(s): 96f87e0

add reproducible bug

Browse files
test/test_frame_reducer.py CHANGED
@@ -96,7 +96,7 @@ class FrameReducer(nn.Module):
96
  """
97
  N, T, C = x.size()
98
 
99
- padding_mask = make_pad_mask(x_lens)
100
  non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
101
 
102
  if y_lens is not None:
@@ -188,4 +188,4 @@ if __name__ == "__main__":
188
  avg_time += delta_time
189
  print(x_fr.shape)
190
  print(x_lens_fr)
191
- print(avg_time / test_times)
 
96
  """
97
  N, T, C = x.size()
98
 
99
+ padding_mask = make_pad_mask(x_lens, x.size(1))
100
  non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
101
 
102
  if y_lens is not None:
 
188
  avg_time += delta_time
189
  print(x_fr.shape)
190
  print(x_lens_fr)
191
+ print(avg_time / test_times)
test/test_riva_wfst_decoder.py CHANGED
@@ -6,6 +6,8 @@ from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, Batched
6
  from typing import List
7
  from test_frame_reducer import FrameReducer
8
 
 
 
9
  def remove_duplicates_and_blank(hyp: List[int],
10
  eos: int,
11
  blank_id: int = 0) -> List[int]:
@@ -37,6 +39,7 @@ class RivaWFSTDecoder:
37
  config.online_opts.decoder_opts.blank_penalty = 0.95
38
  config.online_opts.num_post_processing_worker_threads = 16
39
  config.online_opts.num_decoder_copy_threads = 4
 
40
 
41
  #config.online_opts.decoder_opts.ntokens_pre_allocated = 10_000_000
42
 
@@ -77,6 +80,7 @@ class RivaWFSTDecoder:
77
  results = self.decoder.decode_mbr(logits, sequence_lengths_tensor)
78
  total_hyps = []
79
  for sent in results:
 
80
  hyp = [word[0] for word in sent]
81
  hyp_zh = "".join(hyp)
82
  total_hyps.append(hyp_zh)
@@ -94,24 +98,25 @@ def load_word_symbols(path):
94
 
95
  if __name__ == "__main__":
96
  lang_dir = "../output" # TLG.fst, words.txt
97
- data = np.load('./data/input2.npz')
98
  word_id_to_word_str = load_word_symbols(os.path.join(lang_dir, "words.txt"))
99
  char_dict = load_word_symbols('./data/words.txt')
100
 
101
- beam_size = 10
102
- batch_size = 10
103
- counts = 10
104
 
105
  # ctc_log_probs [1,103,4233]
106
  ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])
107
  # ctc_log_probs , [batch_size,T,vocab_size ]
108
  ctc_log_probs = ctc_log_probs.repeat(batch_size,1,1)
109
- encoder_out_lens = torch.from_numpy(data['encoder_out_lens']) # encoder_out_lens single element 103
 
110
  encoder_out_lens = encoder_out_lens.repeat(batch_size) # [batch_size]
111
  ctc_log_probs = ctc_log_probs.contiguous().cuda()
112
  frame_reducer = FrameReducer()
113
 
114
- ctc_log_probs, encoder_out_lens = frame_reducer(ctc_log_probs, encoder_out_lens.cuda(), ctc_log_probs)
115
 
116
  vocab_size = ctc_log_probs.shape[2]
117
  riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size)
@@ -120,9 +125,9 @@ if __name__ == "__main__":
120
  for i in range(counts):
121
  print("ctc_log_probs.shape:", ctc_log_probs.shape)
122
  total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
123
- print('mbr', total_hyps)
124
- # total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
125
- # print('nbest', total_hyps)
126
  decode_end = time.perf_counter() - decode_start
127
  #chunk_size = 32
128
  ctc_log_probs_list, is_first_chunk, is_last_chunk = [], [True] * batch_size, [True] * batch_size
@@ -131,13 +136,13 @@ if __name__ == "__main__":
131
  success = riva_decoder.online_decoder.try_init_corr_id(corr_id)
132
  assert success
133
  for i in range(batch_size):
134
- #ctc_log_probs_list.append(ctc_log_probs[i,:chunk_size,:])
135
  ctc_log_probs_list.append(ctc_log_probs[i,:,:])
136
  channels, partial_hypotheses = \
137
  riva_decoder.online_decoder.decode_batch(corr_ids, ctc_log_probs_list,
138
  is_first_chunk, is_last_chunk)
139
 
140
  for j, ph in enumerate(partial_hypotheses):
141
- print(j, ph.words, ph.score, ph.ilabels)
 
142
 
143
- print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")
 
6
  from typing import List
7
  from test_frame_reducer import FrameReducer
8
 
9
+ USE_FINAL_PROBS=False
10
+
11
  def remove_duplicates_and_blank(hyp: List[int],
12
  eos: int,
13
  blank_id: int = 0) -> List[int]:
 
39
  config.online_opts.decoder_opts.blank_penalty = 0.95
40
  config.online_opts.num_post_processing_worker_threads = 16
41
  config.online_opts.num_decoder_copy_threads = 4
42
+ config.online_opts.use_final_probs = USE_FINAL_PROBS
43
 
44
  #config.online_opts.decoder_opts.ntokens_pre_allocated = 10_000_000
45
 
 
80
  results = self.decoder.decode_mbr(logits, sequence_lengths_tensor)
81
  total_hyps = []
82
  for sent in results:
83
+ #print([word for word in sent])
84
  hyp = [word[0] for word in sent]
85
  hyp_zh = "".join(hyp)
86
  total_hyps.append(hyp_zh)
 
98
 
99
  if __name__ == "__main__":
100
  lang_dir = "../output" # TLG.fst, words.txt
101
+ data = np.load('./data/input3.npz')
102
  word_id_to_word_str = load_word_symbols(os.path.join(lang_dir, "words.txt"))
103
  char_dict = load_word_symbols('./data/words.txt')
104
 
105
+ beam_size = 7
106
+ batch_size = 1
107
+ counts = 1
108
 
109
  # ctc_log_probs [1,103,4233]
110
  ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])
111
  # ctc_log_probs , [batch_size,T,vocab_size ]
112
  ctc_log_probs = ctc_log_probs.repeat(batch_size,1,1)
113
+ encoder_out_lens = torch.from_numpy(data['encoder_out_len']) # encoder_out_lens single element 103
114
+ #encoder_out_lens = torch.from_numpy(data['encoder_out_lens']) # encoder_out_lens single element 103
115
  encoder_out_lens = encoder_out_lens.repeat(batch_size) # [batch_size]
116
  ctc_log_probs = ctc_log_probs.contiguous().cuda()
117
  frame_reducer = FrameReducer()
118
 
119
+ #ctc_log_probs, encoder_out_lens = frame_reducer(ctc_log_probs, encoder_out_lens.cuda(), ctc_log_probs)
120
 
121
  vocab_size = ctc_log_probs.shape[2]
122
  riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size)
 
125
  for i in range(counts):
126
  print("ctc_log_probs.shape:", ctc_log_probs.shape)
127
  total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
128
+ print('mbr', 'use_final_probs:', USE_FINAL_PROBS, total_hyps)
129
+ #total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
130
+ #print('nbest', total_hyps)
131
  decode_end = time.perf_counter() - decode_start
132
  #chunk_size = 32
133
  ctc_log_probs_list, is_first_chunk, is_last_chunk = [], [True] * batch_size, [True] * batch_size
 
136
  success = riva_decoder.online_decoder.try_init_corr_id(corr_id)
137
  assert success
138
  for i in range(batch_size):
 
139
  ctc_log_probs_list.append(ctc_log_probs[i,:,:])
140
  channels, partial_hypotheses = \
141
  riva_decoder.online_decoder.decode_batch(corr_ids, ctc_log_probs_list,
142
  is_first_chunk, is_last_chunk)
143
 
144
  for j, ph in enumerate(partial_hypotheses):
145
+ #print("streaming word ids", ph.words, ph.score)
146
+ pass
147
 
148
+ #print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")