Yuekai Zhang commited on
Commit
57bf40b
·
1 Parent(s): 55f6a13

add blank skip

Browse files
test/test_frame_reducer.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
4
+ # Zengwei Yao,
5
+ # Wei Kang)
6
+ #
7
+ # See ../../../../LICENSE for clarification regarding multiple authors
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import math
22
+ from typing import Optional, Tuple
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+
28
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
29
+ """
30
+ Args:
31
+ lengths:
32
+ A 1-D tensor containing sentence lengths.
33
+ max_len:
34
+ The length of masks.
35
+ Returns:
36
+ Return a 2-D bool tensor, where masked positions
37
+ are filled with `True` and non-masked positions are
38
+ filled with `False`.
39
+ >>> lengths = torch.tensor([1, 3, 2, 5])
40
+ >>> make_pad_mask(lengths)
41
+ tensor([[False, True, True, True, True],
42
+ [False, False, False, True, True],
43
+ [False, False, True, True, True],
44
+ [False, False, False, False, False]])
45
+ """
46
+ assert lengths.ndim == 1, lengths.ndim
47
+ max_len = max(max_len, lengths.max())
48
+ n = lengths.size(0)
49
+ seq_range = torch.arange(0, max_len, device=lengths.device)
50
+ expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
51
+
52
+ return expaned_lengths >= lengths.unsqueeze(-1)
53
+
54
+
55
+
56
+ class FrameReducer(nn.Module):
57
+ """The encoder output is first used to calculate
58
+ the CTC posterior probability; then for each output frame,
59
+ if its blank posterior is bigger than some thresholds,
60
+ it will be simply discarded from the encoder output.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ ):
66
+ super().__init__()
67
+
68
+ def forward(
69
+ self,
70
+ x: torch.Tensor,
71
+ x_lens: torch.Tensor,
72
+ ctc_output: torch.Tensor,
73
+ y_lens: Optional[torch.Tensor] = None,
74
+ blank_id: int = 0,
75
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
76
+ """
77
+ Args:
78
+ x:
79
+ The shared encoder output with shape [N, T, C].
80
+ x_lens:
81
+ A tensor of shape (batch_size,) containing the number of frames in
82
+ `x` before padding.
83
+ ctc_output:
84
+ The CTC output with shape [N, T, vocab_size].
85
+ y_lens:
86
+ A tensor of shape (batch_size,) containing the number of frames in
87
+ `y` before padding.
88
+ blank_id:
89
+ The blank id of ctc_output.
90
+ Returns:
91
+ out:
92
+ The frame reduced encoder output with shape [N, T', C].
93
+ out_lens:
94
+ A tensor of shape (batch_size,) containing the number of frames in
95
+ `out` before padding.
96
+ """
97
+ N, T, C = x.size()
98
+
99
+ padding_mask = make_pad_mask(x_lens)
100
+ non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
101
+
102
+ if y_lens is not None:
103
+ # Limit the maximum number of reduced frames
104
+ limit_lens = T - y_lens
105
+ max_limit_len = limit_lens.max().int()
106
+ fake_limit_indexes = torch.topk(
107
+ ctc_output[:, :, blank_id], max_limit_len
108
+ ).indices
109
+ T = (
110
+ torch.arange(max_limit_len)
111
+ .expand_as(
112
+ fake_limit_indexes,
113
+ )
114
+ .to(device=x.device)
115
+ )
116
+ T = torch.remainder(T, limit_lens.unsqueeze(1))
117
+ limit_indexes = torch.gather(fake_limit_indexes, 1, T)
118
+ limit_mask = torch.full_like(
119
+ non_blank_mask,
120
+ False,
121
+ device=x.device,
122
+ ).scatter_(1, limit_indexes, True)
123
+
124
+ non_blank_mask = non_blank_mask | ~limit_mask
125
+
126
+ out_lens = non_blank_mask.sum(dim=1)
127
+ max_len = out_lens.max()
128
+ pad_lens_list = (
129
+ torch.full_like(
130
+ out_lens,
131
+ max_len.item(),
132
+ device=x.device,
133
+ )
134
+ - out_lens
135
+ )
136
+ max_pad_len = pad_lens_list.max()
137
+
138
+ out = F.pad(x, (0, 0, 0, max_pad_len))
139
+
140
+ valid_pad_mask = ~make_pad_mask(pad_lens_list)
141
+ total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
142
+
143
+ out = out[total_valid_mask].reshape(N, -1, C)
144
+
145
+ return out, out_lens
146
+
147
+
148
+ if __name__ == "__main__":
149
+ import time
150
+
151
+ test_times = 10000
152
+ device = "cuda:0"
153
+ frame_reducer = FrameReducer()
154
+
155
+ # non zero case
156
+ x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
157
+ x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
158
+ y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
159
+ ctc_output = torch.log(
160
+ torch.randn(15, 498, 500, dtype=torch.float32, device=device),
161
+ )
162
+
163
+ avg_time = 0
164
+ for i in range(test_times):
165
+ torch.cuda.synchronize(device=x.device)
166
+ delta_time = time.time()
167
+ x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
168
+ torch.cuda.synchronize(device=x.device)
169
+ delta_time = time.time() - delta_time
170
+ avg_time += delta_time
171
+ print(x_fr.shape)
172
+ print(x_lens_fr)
173
+ print(avg_time / test_times)
174
+
175
+ # all zero case
176
+ x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
177
+ x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
178
+ y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
179
+ ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)
180
+
181
+ avg_time = 0
182
+ for i in range(test_times):
183
+ torch.cuda.synchronize(device=x.device)
184
+ delta_time = time.time()
185
+ x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
186
+ torch.cuda.synchronize(device=x.device)
187
+ delta_time = time.time() - delta_time
188
+ avg_time += delta_time
189
+ print(x_fr.shape)
190
+ print(x_lens_fr)
191
+ print(avg_time / test_times)
test/test_riva_wfst_decoder.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  import os
5
  from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedDecoderCudaConfig
6
  from typing import List
 
7
 
8
  def remove_duplicates_and_blank(hyp: List[int],
9
  eos: int,
@@ -28,11 +29,14 @@ class RivaWFSTDecoder:
28
  config.online_opts.decoder_opts.default_beam = 17.0
29
  config.online_opts.decoder_opts.max_active = 7000
30
  config.online_opts.determinize_lattice = True
31
- config.online_opts.max_batch_size = 800
32
- config.online_opts.num_channels = 800
33
  config.online_opts.frame_shift_seconds = 0.04
34
  config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
35
  config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
 
 
 
36
 
37
  config.online_opts.lattice_postprocessor_opts.nbest = beam_size
38
 
@@ -87,8 +91,8 @@ if __name__ == "__main__":
87
  char_dict = load_word_symbols('./data/words.txt')
88
 
89
  beam_size = 10
90
- batch_size = 50
91
- counts = 1
92
 
93
  # ctc_log_probs [1,103,4233]
94
  ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])
@@ -97,6 +101,9 @@ if __name__ == "__main__":
97
  encoder_out_lens = torch.from_numpy(data['encoder_out_lens']) # encoder_out_lens single element 103
98
  encoder_out_lens = encoder_out_lens.repeat(batch_size) # [batch_size]
99
  ctc_log_probs = ctc_log_probs.contiguous().cuda()
 
 
 
100
 
101
  vocab_size = ctc_log_probs.shape[2]
102
  riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size)
@@ -106,7 +113,7 @@ if __name__ == "__main__":
106
  print("ctc_log_probs.shape:", ctc_log_probs.shape)
107
  total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
108
  print('mbr', total_hyps)
109
- total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
110
- print('nbest', total_hyps)
111
  decode_end = time.perf_counter() - decode_start
112
  print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")
 
4
  import os
5
  from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedDecoderCudaConfig
6
  from typing import List
7
+ from test_frame_reducer import FrameReducer
8
 
9
  def remove_duplicates_and_blank(hyp: List[int],
10
  eos: int,
 
29
  config.online_opts.decoder_opts.default_beam = 17.0
30
  config.online_opts.decoder_opts.max_active = 7000
31
  config.online_opts.determinize_lattice = True
32
+ config.online_opts.max_batch_size = 100
33
+ config.online_opts.num_channels = 200
34
  config.online_opts.frame_shift_seconds = 0.04
35
  config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
36
  config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
37
+ config.online_opts.decoder_opts.blank_penalty = 0.95
38
+ config.online_opts.num_post_processing_worker_threads = 16
39
+ config.online_opts.num_decoder_copy_threads = 4
40
 
41
  config.online_opts.lattice_postprocessor_opts.nbest = beam_size
42
 
 
91
  char_dict = load_word_symbols('./data/words.txt')
92
 
93
  beam_size = 10
94
+ batch_size = 1
95
+ counts = 10
96
 
97
  # ctc_log_probs [1,103,4233]
98
  ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])
 
101
  encoder_out_lens = torch.from_numpy(data['encoder_out_lens']) # encoder_out_lens single element 103
102
  encoder_out_lens = encoder_out_lens.repeat(batch_size) # [batch_size]
103
  ctc_log_probs = ctc_log_probs.contiguous().cuda()
104
+ frame_reducer = FrameReducer()
105
+
106
+ ctc_log_probs, encoder_out_len = frame_reducer(ctc_log_probs, encoder_out_lens.cuda(), ctc_log_probs)
107
 
108
  vocab_size = ctc_log_probs.shape[2]
109
  riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size)
 
113
  print("ctc_log_probs.shape:", ctc_log_probs.shape)
114
  total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
115
  print('mbr', total_hyps)
116
+ # total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
117
+ # print('nbest', total_hyps)
118
  decode_end = time.perf_counter() - decode_start
119
  print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")