File size: 9,935 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import lru_cache

import pytest
import torch
from omegaconf import DictConfig

from nemo.collections.asr.metrics.rnnt_wer import RNNTDecoding, RNNTDecodingConfig
from nemo.collections.asr.metrics.rnnt_wer_bpe import RNNTBPEDecoding, RNNTBPEDecodingConfig
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.modules import RNNTDecoder, RNNTJoint
from nemo.collections.asr.parts.mixins import mixins
from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode
from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.core.utils import numba_utils
from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__

# True when the Numba-backed RNNT loss can run in this environment: either the
# CPU or the CUDA Numba backend must satisfy the minimum supported version.
# Used below to skip decoding tests that require the compiled loss.
NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cpu_is_supported(
    __NUMBA_MINIMUM_VERSION__
) or numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__)


def char_vocabulary():
    """Return the fixed toy character vocabulary used by char-level decoding tests."""
    return list(' abcdef')


@pytest.fixture()
@lru_cache(maxsize=8)
def tmp_tokenizer(test_data_dir):
    """Build (and cache) a WPE tokenizer from the shared an4 test assets."""
    tokenizer_dir = os.path.join(test_data_dir, "asr", "tokenizers", "an4_wpe_128")
    cfg = DictConfig({'dir': tokenizer_dir, 'type': 'wpe'})

    class _TmpASRBPE(mixins.ASRBPEMixin):
        # The real mixin registers the vocab file as a model artifact; the test
        # only needs the path echoed back.
        def register_artifact(self, _, vocab_path):
            return vocab_path

    holder = _TmpASRBPE()
    holder._setup_tokenizer(cfg)
    return holder.tokenizer


@lru_cache(maxsize=2)
def get_rnnt_decoder(vocab_size, decoder_output_size=4):
    """Return a frozen RNNT prediction network with deterministic initial weights."""
    torch.manual_seed(0)  # same seed every call -> reproducible weights across tests
    decoder = RNNTDecoder(
        prednet={'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1},
        vocab_size=vocab_size,
    )
    decoder.freeze()
    return decoder


@lru_cache(maxsize=2)
def get_rnnt_joint(vocab_size, vocabulary=None, encoder_output_size=4, decoder_output_size=4, joint_output_shape=4):
    """Return a frozen RNNT joint network with deterministic initial weights."""
    net_cfg = dict(
        encoder_hidden=encoder_output_size,
        pred_hidden=decoder_output_size,
        joint_hidden=joint_output_shape,
        activation='relu',
    )
    torch.manual_seed(0)  # same seed every call -> reproducible weights across tests
    joint = RNNTJoint(net_cfg, vocab_size, vocabulary=vocabulary)
    joint.freeze()
    return joint


@lru_cache(maxsize=1)
def get_model_encoder_output(data_dir, model_name):
    """Load a pretrained ASR model and run its encoder over one an4 test utterance.

    Returns ``(model, encoded, encoded_len)``. Cached so the (slow) model download
    and forward pass happen at most once per session.
    """
    # Import inside function to avoid issues with dependencies
    import librosa

    wav_path = os.path.join(data_dir, 'asr', 'test', 'an4', 'wav', 'cen3-fjlp-b.wav')

    with torch.no_grad():
        model = ASRModel.from_pretrained(model_name, map_location='cpu')  # type: ASRModel
        # Disable dither / padding so the encoder output is reproducible.
        model.preprocessor.featurizer.dither = 0.0
        model.preprocessor.featurizer.pad_to = 0

        samples, _sr = librosa.load(path=wav_path, sr=16000, mono=True)

        signal = torch.tensor(samples, dtype=torch.float32).unsqueeze(0)
        signal_len = torch.tensor([len(samples)], dtype=torch.int32)

        encoded, encoded_len = model(input_signal=signal, input_signal_length=signal_len)

    return model, encoded, encoded_len


def decode_text_from_greedy_hypotheses(hyps, decoding):
    """Decode a batch of greedy hypotheses via the model's decoding wrapper."""
    return decoding.decode_hypothesis(hyps)  # type: List[str]


def decode_text_from_nbest_hypotheses(hyps, decoding):
    """Decode n-best hypotheses for each sample in a batch.

    Returns a pair: the best decoded hypothesis per sample, and the full decoded
    n-best list per sample.
    """
    best_per_sample = []
    nbest_per_sample = []

    for sample in hyps:  # type: rnnt_utils.NBestHypotheses
        # Decode every candidate for this sample; index 0 is the best one.
        decoded = decoding.decode_hypothesis(sample.n_best_hypotheses)  # type: List[str]
        best_per_sample.append(decoded[0])
        nbest_per_sample.append(decoded)

    return best_per_sample, nbest_per_sample


class TestRNNTDecoding:
    """Unit tests for RNNT / RNNT-BPE decoding construction and alignment preservation."""

    @pytest.mark.unit
    def test_constructor(self):
        """RNNTDecoding builds from a default config with toy char-level networks."""
        cfg = RNNTDecodingConfig()
        vocab = char_vocabulary()
        decoder = get_rnnt_decoder(vocab_size=len(vocab))
        joint = get_rnnt_joint(vocab_size=len(vocab))
        decoding = RNNTDecoding(decoding_cfg=cfg, decoder=decoder, joint=joint, vocabulary=vocab)
        assert decoding is not None

    @pytest.mark.unit
    def test_constructor_subword(self, tmp_tokenizer):
        """RNNTBPEDecoding builds from a default config with a WPE tokenizer fixture."""
        cfg = RNNTBPEDecodingConfig()
        vocab = tmp_tokenizer.vocab
        decoder = get_rnnt_decoder(vocab_size=len(vocab))
        joint = get_rnnt_joint(vocab_size=len(vocab))
        decoding = RNNTBPEDecoding(decoding_cfg=cfg, decoder=decoder, joint=joint, tokenizer=tmp_tokenizer)
        assert decoding is not None

    @pytest.mark.skipif(
        not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
    )
    @pytest.mark.with_downloads
    @pytest.mark.unit
    def test_greedy_decoding_preserve_alignments(self, test_data_dir):
        """Greedy RNNT inference with preserve_alignments=True yields per-timestep
        (log-prob, label) tensor pairs on the resulting hypothesis."""
        model, encoded, encoded_len = get_model_encoder_output(test_data_dir, 'stt_en_conformer_transducer_small')

        # NOTE(review): blank index is taken as the last class — presumably matches
        # RNNTJoint's class layout; confirm against the joint implementation.
        beam = greedy_decode.GreedyRNNTInfer(
            model.decoder,
            model.joint,
            blank_index=model.joint.num_classes_with_blank - 1,
            max_symbols_per_step=5,
            preserve_alignments=True,
        )

        enc_out = encoded
        enc_len = encoded_len

        with torch.no_grad():
            hyps = beam(encoder_output=enc_out, encoded_lengths=enc_len)[0]  # type: rnnt_utils.Hypothesis
            hyp = decode_text_from_greedy_hypotheses(hyps, model.decoding)
            hyp = hyp[0]

            assert hyp.alignments is not None

            # Use the following commented print statements to check
            # the alignment of other algorithms compared to the default
            print("Text", hyp.text)
            for t in range(len(hyp.alignments)):
                t_u = []
                for u in range(len(hyp.alignments[t])):
                    # Each alignment entry must be a pair of tensors (log-prob, label).
                    logp, label = hyp.alignments[t][u]
                    assert torch.is_tensor(logp)
                    assert torch.is_tensor(label)

                    t_u.append(int(label))

                print(f"Tokens at timestep {t} = {t_u}")
            print()

    @pytest.mark.skipif(
        not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
    )
    @pytest.mark.with_downloads
    @pytest.mark.unit
    @pytest.mark.parametrize(
        "beam_config",
        [
            {"search_type": "greedy"},
            {"search_type": "default", "beam_size": 2,},
            {"search_type": "alsd", "alsd_max_target_len": 0.5, "beam_size": 2,},
            {"search_type": "tsd", "tsd_max_sym_exp_per_step": 3, "beam_size": 2,},
            {"search_type": "maes", "maes_num_steps": 2, "maes_expansion_beta": 2, "beam_size": 2},
            {"search_type": "maes", "maes_num_steps": 3, "maes_expansion_beta": 1, "beam_size": 2},
        ],
    )
    def test_beam_decoding_preserve_alignments(self, test_data_dir, beam_config):
        """Each beam-search variant with preserve_alignments=True produces alignments
        whose length tracks the encoded length, with blank only at the end of a
        timestep's token list."""
        beam_size = beam_config.pop("beam_size", 1)
        model, encoded, encoded_len = get_model_encoder_output(test_data_dir, 'stt_en_conformer_transducer_small')
        beam = beam_decode.BeamRNNTInfer(
            model.decoder,
            model.joint,
            beam_size=beam_size,
            return_best_hypothesis=False,
            preserve_alignments=True,
            **beam_config,
        )

        enc_out = encoded
        enc_len = encoded_len
        blank_id = torch.tensor(model.joint.num_classes_with_blank - 1, dtype=torch.int32)

        with torch.no_grad():
            hyps = beam(encoder_output=enc_out, encoded_lengths=enc_len)[0]  # type: rnnt_utils.Hypothesis
            hyp, all_hyps = decode_text_from_nbest_hypotheses(hyps, model.decoding)
            hyp = hyp[0]  # best hypothesis
            all_hyps = all_hyps[0]

            assert hyp.alignments is not None

            if beam_config['search_type'] == 'alsd':
                # ALSD caps target length as a fraction of the encoded length.
                assert len(all_hyps) <= int(beam_config['alsd_max_target_len'] * float(enc_len[0]))

            print("Beam search algorithm :", beam_config['search_type'])
            # Use the following commented print statements to check
            # the alignment of other algorithms compared to the default
            for idx, hyp_ in enumerate(all_hyps):  # type: (int, rnnt_utils.Hypothesis)
                print("Hyp index", idx + 1, "text :", hyp_.text)

                # Alignment length (T) must match audio length (T)
                assert abs(len(hyp_.alignments) - enc_len[0]) <= 1

                for t in range(len(hyp_.alignments)):
                    t_u = []
                    for u in range(len(hyp_.alignments[t])):
                        # Each alignment entry must be a pair of tensors (log-prob, label).
                        logp, label = hyp_.alignments[t][u]
                        assert torch.is_tensor(logp)
                        assert torch.is_tensor(label)

                        t_u.append(int(label))

                    # Blank token must be the last token in the current
                    if len(t_u) > 1:
                        assert t_u[-1] == blank_id

                        # No blank token should be present in the current timestep other than at the end
                        for token in t_u[:-1]:
                            assert token != blank_id

                    print(f"Tokens at timestep {t} = {t_u}")
                print()

                # Every hypothesis must carry at least one emission timestep.
                assert len(hyp_.timestep) > 0
                print("Timesteps", hyp_.timestep)
                print()