Yuekai Zhang committed on
Commit
55f6a13
·
1 Parent(s): 0c5ca4b

update nbest

Browse files
test/data/input2.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7831ad64f1fb28bf1adb5e197d43b00aebb430ac171c4a08db891321220fd90
3
+ size 2083188
test/test_riva_wfst_decoder.py CHANGED
@@ -3,48 +3,88 @@ import time
3
  import torch
4
  import os
5
  from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedDecoderCudaConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class RivaWFSTDecoder:
8
- def __init__(self, vocab_size, tlg_dir, beam_size=8.0):
9
  config = BatchedMappedDecoderCudaConfig()
10
- config.online_opts.lattice_postprocessor_opts.acoustic_scale=10.0
 
 
11
  config.n_input_per_chunk = 50
12
  config.online_opts.decoder_opts.default_beam = 17.0
13
- config.online_opts.decoder_opts.lattice_beam = beam_size
14
  config.online_opts.decoder_opts.max_active = 7000
15
  config.online_opts.determinize_lattice = True
16
  config.online_opts.max_batch_size = 800
17
-
18
  config.online_opts.num_channels = 800
19
  config.online_opts.frame_shift_seconds = 0.04
20
-
21
  config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
22
  config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
23
 
 
 
24
  self.decoder = BatchedMappedDecoderCuda(
25
- config, os.path.join(tlg_dir, "TLG.fst"), os.path.join(tlg_dir, "words.txt"), vocab_size
 
26
  )
 
 
 
27
 
28
- def decode(self, logits, length):
29
- padded_sequence = logits.contiguous()
30
  sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
31
- results = self.decoder.decode(padded_sequence, sequence_lengths_tensor)
32
- return results
 
 
 
 
 
 
 
 
 
33
 
34
- def get_nbest_list(self, results, nbest=1):
35
- assert nbest == 1, "Only support nbest=1 for now"
 
 
36
  total_hyps = []
37
  for sent in results:
38
  hyp = [word[0] for word in sent]
39
  hyp_zh = "".join(hyp)
40
- nbest_list = [hyp_zh] # TODO: add real nbest
41
- total_hyps.append(nbest_list)
42
  return total_hyps
43
-
 
 
 
 
 
 
 
 
 
44
 
45
  if __name__ == "__main__":
46
- lang_dir = "../data/lang_test" # TLG.fst, words.txt
47
- data = np.load('./data/input.npz')
 
 
48
 
49
  beam_size = 10
50
  batch_size = 50
@@ -64,8 +104,9 @@ if __name__ == "__main__":
64
  decode_start = time.perf_counter()
65
  for i in range(counts):
66
  print("ctc_log_probs.shape:", ctc_log_probs.shape)
67
- results = riva_decoder.decode(ctc_log_probs, encoder_out_lens)
68
- total_hyps = riva_decoder.get_nbest_list(results)
69
- print(total_hyps)
 
70
  decode_end = time.perf_counter() - decode_start
71
  print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")
 
3
  import torch
4
  import os
5
  from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedDecoderCudaConfig
6
+ from typing import List
7
+
8
def remove_duplicates_and_blank(hyp: List[int],
                                eos: int,
                                blank_id: int = 0) -> List[int]:
    """Apply CTC-style cleanup to a decoded token-id sequence.

    Of every run of identical consecutive ids, at most one copy is kept,
    and any id equal to ``blank_id`` or ``eos`` is dropped entirely.

    Args:
        hyp: raw token ids emitted by the decoder.
        eos: id of the end-of-sentence token to strip.
        blank_id: id of the CTC blank token to strip (default 0).

    Returns:
        The collapsed, filtered list of token ids.
    """
    cleaned: List[int] = []
    previous = None
    for token in hyp:
        # keep only the first token of each run, unless it is blank/eos
        if token != previous and token != blank_id and token != eos:
            cleaned.append(token)
        previous = token
    return cleaned
20
 
21
class RivaWFSTDecoder:
    """Wrapper around Riva's batched CUDA WFST decoder for CTC outputs.

    Builds a ``BatchedMappedDecoderCuda`` from a TLG directory (expects
    "TLG.fst" and "words.txt" inside ``tlg_dir``) and exposes n-best and
    MBR decoding over batched logits.
    """

    def __init__(self, vocab_size, tlg_dir, config_dict=None, beam_size=8):
        """Create the underlying CUDA decoder.

        Args:
            vocab_size: size of the acoustic model's output vocabulary.
            tlg_dir: directory holding "TLG.fst" and "words.txt".
            config_dict: NOTE(review): currently ignored — every option
                below is hard-coded; confirm whether callers rely on it.
            beam_size: used both as the lattice beam and as the size of
                the n-best list returned by :meth:`decode_nbest`.
        """
        opts = BatchedMappedDecoderCudaConfig()
        opts.n_input_per_chunk = 50
        opts.online_opts.decoder_opts.lattice_beam = beam_size
        opts.online_opts.decoder_opts.default_beam = 17.0
        opts.online_opts.decoder_opts.max_active = 7000
        opts.online_opts.determinize_lattice = True
        opts.online_opts.max_batch_size = 800
        opts.online_opts.num_channels = 800
        opts.online_opts.frame_shift_seconds = 0.04
        opts.online_opts.lattice_postprocessor_opts.acoustic_scale = 10.0  # noqa
        opts.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
        opts.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
        # beam_size doubles as the n-best list length
        opts.online_opts.lattice_postprocessor_opts.nbest = beam_size

        self.decoder = BatchedMappedDecoderCuda(
            opts, os.path.join(tlg_dir, "TLG.fst"),
            os.path.join(tlg_dir, "words.txt"), vocab_size
        )
        self.word_id_to_word_str = load_word_symbols(os.path.join(tlg_dir, "words.txt"))
        self.nbest = beam_size
        self.vocab_size = vocab_size

    def decode_nbest(self, logits, length):
        """Decode a batch and return per-utterance n-best id sequences.

        Args:
            logits: batched acoustic scores; cast to float32 here.
            length: per-utterance frame counts; moved to CPU int64.

        Returns:
            list (per utterance) of lists (per hypothesis) of token ids,
            with blanks/eos removed and repeats collapsed.
        """
        acts = logits.to(torch.float32).contiguous()
        lens = length.to(torch.long).to('cpu').contiguous()
        batch_results = self.decoder.decode_nbest(acts, lens)
        total_hyps = []
        for utt_nbest in batch_results:
            # subtract 1 to get the label id, since fst decoder adds 1
            # to the label id
            cleaned = [
                remove_duplicates_and_blank(
                    [label - 1 for label in path.ilabels],
                    eos=self.vocab_size - 1, blank_id=0)
                for path in utt_nbest
            ]
            total_hyps.append(cleaned)
        return total_hyps

    def decode_mbr(self, logits, length):
        """Decode a batch with MBR and return one string per utterance.

        The word pieces of each utterance are concatenated without a
        separator (Chinese-style output).
        """
        acts = logits.to(torch.float32).contiguous()
        lens = length.to(torch.long).to('cpu').contiguous()
        batch_results = self.decoder.decode_mbr(acts, lens)
        # each word entry is a tuple whose first element is the word
        # string; keep only that
        return ["".join(word[0] for word in sent) for sent in batch_results]
+
73
+
74
+
75
def load_word_symbols(path):
    """Load an OpenFst-style symbol table into an id -> symbol dict.

    Each non-empty line is expected to hold two whitespace-separated
    fields: ``<word_str> <word_id>``.

    Args:
        path: path to a "words.txt" symbol table file.

    Returns:
        dict mapping int word id to its word string.

    Raises:
        ValueError: if a non-empty line does not have exactly two fields
            or its id field is not an integer.
    """
    word_id_to_word_str = {}
    # read as UTF-8 explicitly: the symbol table holds non-ASCII
    # (Chinese) entries and the platform default encoding may differ
    with open(path, "rt", encoding="utf-8") as fh:
        for line in fh:
            fields = line.split()
            if not fields:
                continue  # tolerate blank / whitespace-only lines
            word_str, word_id = fields
            word_id_to_word_str[int(word_id)] = word_str
    return word_id_to_word_str
82
 
83
  if __name__ == "__main__":
84
+ lang_dir = "../output" # TLG.fst, words.txt
85
+ data = np.load('./data/input2.npz')
86
+ word_id_to_word_str = load_word_symbols(os.path.join(lang_dir, "words.txt"))
87
+ char_dict = load_word_symbols('./data/words.txt')
88
 
89
  beam_size = 10
90
  batch_size = 50
 
104
  decode_start = time.perf_counter()
105
  for i in range(counts):
106
  print("ctc_log_probs.shape:", ctc_log_probs.shape)
107
+ total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
108
+ print('mbr', total_hyps)
109
+ total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
110
+ print('nbest', total_hyps)
111
  decode_end = time.perf_counter() - decode_start
112
  print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")