Yuekai Zhang
commited on
Commit
·
57bf40b
1
Parent(s):
55f6a13
add blank skip
Browse files- test/test_frame_reducer.py +191 -0
- test/test_riva_wfst_decoder.py +13 -6
test/test_frame_reducer.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
|
| 4 |
+
# Zengwei Yao,
|
| 5 |
+
# Wei Kang)
|
| 6 |
+
#
|
| 7 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
|
| 21 |
+
import math
|
| 22 |
+
from typing import Optional, Tuple
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
|
| 28 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
| 29 |
+
"""
|
| 30 |
+
Args:
|
| 31 |
+
lengths:
|
| 32 |
+
A 1-D tensor containing sentence lengths.
|
| 33 |
+
max_len:
|
| 34 |
+
The length of masks.
|
| 35 |
+
Returns:
|
| 36 |
+
Return a 2-D bool tensor, where masked positions
|
| 37 |
+
are filled with `True` and non-masked positions are
|
| 38 |
+
filled with `False`.
|
| 39 |
+
>>> lengths = torch.tensor([1, 3, 2, 5])
|
| 40 |
+
>>> make_pad_mask(lengths)
|
| 41 |
+
tensor([[False, True, True, True, True],
|
| 42 |
+
[False, False, False, True, True],
|
| 43 |
+
[False, False, True, True, True],
|
| 44 |
+
[False, False, False, False, False]])
|
| 45 |
+
"""
|
| 46 |
+
assert lengths.ndim == 1, lengths.ndim
|
| 47 |
+
max_len = max(max_len, lengths.max())
|
| 48 |
+
n = lengths.size(0)
|
| 49 |
+
seq_range = torch.arange(0, max_len, device=lengths.device)
|
| 50 |
+
expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
|
| 51 |
+
|
| 52 |
+
return expaned_lengths >= lengths.unsqueeze(-1)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class FrameReducer(nn.Module):
|
| 57 |
+
"""The encoder output is first used to calculate
|
| 58 |
+
the CTC posterior probability; then for each output frame,
|
| 59 |
+
if its blank posterior is bigger than some thresholds,
|
| 60 |
+
it will be simply discarded from the encoder output.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
def __init__(
|
| 64 |
+
self,
|
| 65 |
+
):
|
| 66 |
+
super().__init__()
|
| 67 |
+
|
| 68 |
+
def forward(
|
| 69 |
+
self,
|
| 70 |
+
x: torch.Tensor,
|
| 71 |
+
x_lens: torch.Tensor,
|
| 72 |
+
ctc_output: torch.Tensor,
|
| 73 |
+
y_lens: Optional[torch.Tensor] = None,
|
| 74 |
+
blank_id: int = 0,
|
| 75 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 76 |
+
"""
|
| 77 |
+
Args:
|
| 78 |
+
x:
|
| 79 |
+
The shared encoder output with shape [N, T, C].
|
| 80 |
+
x_lens:
|
| 81 |
+
A tensor of shape (batch_size,) containing the number of frames in
|
| 82 |
+
`x` before padding.
|
| 83 |
+
ctc_output:
|
| 84 |
+
The CTC output with shape [N, T, vocab_size].
|
| 85 |
+
y_lens:
|
| 86 |
+
A tensor of shape (batch_size,) containing the number of frames in
|
| 87 |
+
`y` before padding.
|
| 88 |
+
blank_id:
|
| 89 |
+
The blank id of ctc_output.
|
| 90 |
+
Returns:
|
| 91 |
+
out:
|
| 92 |
+
The frame reduced encoder output with shape [N, T', C].
|
| 93 |
+
out_lens:
|
| 94 |
+
A tensor of shape (batch_size,) containing the number of frames in
|
| 95 |
+
`out` before padding.
|
| 96 |
+
"""
|
| 97 |
+
N, T, C = x.size()
|
| 98 |
+
|
| 99 |
+
padding_mask = make_pad_mask(x_lens)
|
| 100 |
+
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
| 101 |
+
|
| 102 |
+
if y_lens is not None:
|
| 103 |
+
# Limit the maximum number of reduced frames
|
| 104 |
+
limit_lens = T - y_lens
|
| 105 |
+
max_limit_len = limit_lens.max().int()
|
| 106 |
+
fake_limit_indexes = torch.topk(
|
| 107 |
+
ctc_output[:, :, blank_id], max_limit_len
|
| 108 |
+
).indices
|
| 109 |
+
T = (
|
| 110 |
+
torch.arange(max_limit_len)
|
| 111 |
+
.expand_as(
|
| 112 |
+
fake_limit_indexes,
|
| 113 |
+
)
|
| 114 |
+
.to(device=x.device)
|
| 115 |
+
)
|
| 116 |
+
T = torch.remainder(T, limit_lens.unsqueeze(1))
|
| 117 |
+
limit_indexes = torch.gather(fake_limit_indexes, 1, T)
|
| 118 |
+
limit_mask = torch.full_like(
|
| 119 |
+
non_blank_mask,
|
| 120 |
+
False,
|
| 121 |
+
device=x.device,
|
| 122 |
+
).scatter_(1, limit_indexes, True)
|
| 123 |
+
|
| 124 |
+
non_blank_mask = non_blank_mask | ~limit_mask
|
| 125 |
+
|
| 126 |
+
out_lens = non_blank_mask.sum(dim=1)
|
| 127 |
+
max_len = out_lens.max()
|
| 128 |
+
pad_lens_list = (
|
| 129 |
+
torch.full_like(
|
| 130 |
+
out_lens,
|
| 131 |
+
max_len.item(),
|
| 132 |
+
device=x.device,
|
| 133 |
+
)
|
| 134 |
+
- out_lens
|
| 135 |
+
)
|
| 136 |
+
max_pad_len = pad_lens_list.max()
|
| 137 |
+
|
| 138 |
+
out = F.pad(x, (0, 0, 0, max_pad_len))
|
| 139 |
+
|
| 140 |
+
valid_pad_mask = ~make_pad_mask(pad_lens_list)
|
| 141 |
+
total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
|
| 142 |
+
|
| 143 |
+
out = out[total_valid_mask].reshape(N, -1, C)
|
| 144 |
+
|
| 145 |
+
return out, out_lens
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
if __name__ == "__main__":
|
| 149 |
+
import time
|
| 150 |
+
|
| 151 |
+
test_times = 10000
|
| 152 |
+
device = "cuda:0"
|
| 153 |
+
frame_reducer = FrameReducer()
|
| 154 |
+
|
| 155 |
+
# non zero case
|
| 156 |
+
x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
|
| 157 |
+
x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
|
| 158 |
+
y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
|
| 159 |
+
ctc_output = torch.log(
|
| 160 |
+
torch.randn(15, 498, 500, dtype=torch.float32, device=device),
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
avg_time = 0
|
| 164 |
+
for i in range(test_times):
|
| 165 |
+
torch.cuda.synchronize(device=x.device)
|
| 166 |
+
delta_time = time.time()
|
| 167 |
+
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
|
| 168 |
+
torch.cuda.synchronize(device=x.device)
|
| 169 |
+
delta_time = time.time() - delta_time
|
| 170 |
+
avg_time += delta_time
|
| 171 |
+
print(x_fr.shape)
|
| 172 |
+
print(x_lens_fr)
|
| 173 |
+
print(avg_time / test_times)
|
| 174 |
+
|
| 175 |
+
# all zero case
|
| 176 |
+
x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
|
| 177 |
+
x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
|
| 178 |
+
y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
|
| 179 |
+
ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)
|
| 180 |
+
|
| 181 |
+
avg_time = 0
|
| 182 |
+
for i in range(test_times):
|
| 183 |
+
torch.cuda.synchronize(device=x.device)
|
| 184 |
+
delta_time = time.time()
|
| 185 |
+
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
|
| 186 |
+
torch.cuda.synchronize(device=x.device)
|
| 187 |
+
delta_time = time.time() - delta_time
|
| 188 |
+
avg_time += delta_time
|
| 189 |
+
print(x_fr.shape)
|
| 190 |
+
print(x_lens_fr)
|
| 191 |
+
print(avg_time / test_times)
|
test/test_riva_wfst_decoder.py
CHANGED
|
@@ -4,6 +4,7 @@ import torch
|
|
| 4 |
import os
|
| 5 |
from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedDecoderCudaConfig
|
| 6 |
from typing import List
|
|
|
|
| 7 |
|
| 8 |
def remove_duplicates_and_blank(hyp: List[int],
|
| 9 |
eos: int,
|
|
@@ -28,11 +29,14 @@ class RivaWFSTDecoder:
|
|
| 28 |
config.online_opts.decoder_opts.default_beam = 17.0
|
| 29 |
config.online_opts.decoder_opts.max_active = 7000
|
| 30 |
config.online_opts.determinize_lattice = True
|
| 31 |
-
config.online_opts.max_batch_size =
|
| 32 |
-
config.online_opts.num_channels =
|
| 33 |
config.online_opts.frame_shift_seconds = 0.04
|
| 34 |
config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
|
| 35 |
config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
config.online_opts.lattice_postprocessor_opts.nbest = beam_size
|
| 38 |
|
|
@@ -87,8 +91,8 @@ if __name__ == "__main__":
|
|
| 87 |
char_dict = load_word_symbols('./data/words.txt')
|
| 88 |
|
| 89 |
beam_size = 10
|
| 90 |
-
batch_size =
|
| 91 |
-
counts =
|
| 92 |
|
| 93 |
# ctc_log_probs [1,103,4233]
|
| 94 |
ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])
|
|
@@ -97,6 +101,9 @@ if __name__ == "__main__":
|
|
| 97 |
encoder_out_lens = torch.from_numpy(data['encoder_out_lens']) # encoder_out_lens single element 103
|
| 98 |
encoder_out_lens = encoder_out_lens.repeat(batch_size) # [batch_size]
|
| 99 |
ctc_log_probs = ctc_log_probs.contiguous().cuda()
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
vocab_size = ctc_log_probs.shape[2]
|
| 102 |
riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size)
|
|
@@ -106,7 +113,7 @@ if __name__ == "__main__":
|
|
| 106 |
print("ctc_log_probs.shape:", ctc_log_probs.shape)
|
| 107 |
total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
|
| 108 |
print('mbr', total_hyps)
|
| 109 |
-
total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
|
| 110 |
-
print('nbest', total_hyps)
|
| 111 |
decode_end = time.perf_counter() - decode_start
|
| 112 |
print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")
|
|
|
|
| 4 |
import os
|
| 5 |
from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedDecoderCudaConfig
|
| 6 |
from typing import List
|
| 7 |
+
from test_frame_reducer import FrameReducer
|
| 8 |
|
| 9 |
def remove_duplicates_and_blank(hyp: List[int],
|
| 10 |
eos: int,
|
|
|
|
| 29 |
config.online_opts.decoder_opts.default_beam = 17.0
|
| 30 |
config.online_opts.decoder_opts.max_active = 7000
|
| 31 |
config.online_opts.determinize_lattice = True
|
| 32 |
+
config.online_opts.max_batch_size = 100
|
| 33 |
+
config.online_opts.num_channels = 200
|
| 34 |
config.online_opts.frame_shift_seconds = 0.04
|
| 35 |
config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
|
| 36 |
config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
|
| 37 |
+
config.online_opts.decoder_opts.blank_penalty = 0.95
|
| 38 |
+
config.online_opts.num_post_processing_worker_threads = 16
|
| 39 |
+
config.online_opts.num_decoder_copy_threads = 4
|
| 40 |
|
| 41 |
config.online_opts.lattice_postprocessor_opts.nbest = beam_size
|
| 42 |
|
|
|
|
| 91 |
char_dict = load_word_symbols('./data/words.txt')
|
| 92 |
|
| 93 |
beam_size = 10
|
| 94 |
+
batch_size = 1
|
| 95 |
+
counts = 10
|
| 96 |
|
| 97 |
# ctc_log_probs [1,103,4233]
|
| 98 |
ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])
|
|
|
|
| 101 |
encoder_out_lens = torch.from_numpy(data['encoder_out_lens']) # encoder_out_lens single element 103
|
| 102 |
encoder_out_lens = encoder_out_lens.repeat(batch_size) # [batch_size]
|
| 103 |
ctc_log_probs = ctc_log_probs.contiguous().cuda()
|
| 104 |
+
frame_reducer = FrameReducer()
|
| 105 |
+
|
| 106 |
+
ctc_log_probs, encoder_out_len = frame_reducer(ctc_log_probs, encoder_out_lens.cuda(), ctc_log_probs)
|
| 107 |
|
| 108 |
vocab_size = ctc_log_probs.shape[2]
|
| 109 |
riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size)
|
|
|
|
| 113 |
print("ctc_log_probs.shape:", ctc_log_probs.shape)
|
| 114 |
total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
|
| 115 |
print('mbr', total_hyps)
|
| 116 |
+
# total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
|
| 117 |
+
# print('nbest', total_hyps)
|
| 118 |
decode_end = time.perf_counter() - decode_start
|
| 119 |
print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")
|