| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import copy |
| import glob |
| import os |
| from pathlib import Path |
|
|
| import jiwer |
| import pytest |
| import torch |
| from omegaconf import open_dict |
| from tqdm import tqdm |
|
|
| from nemo.collections.asr.models import ASRModel |
| from nemo.collections.asr.models.ctc_models import EncDecCTCModel |
| from nemo.collections.asr.parts.submodules.ctc_beam_decoding import BeamBatchedCTCInfer |
| from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel |
| from nemo.collections.asr.parts.submodules.rnnt_beam_decoding import BeamBatchedRNNTInfer |
| from nemo.collections.asr.parts.submodules.tdt_beam_decoding import BeamBatchedTDTInfer |
| from nemo.collections.asr.parts.utils import rnnt_utils |
| from nemo.core.utils import numba_utils |
| from nemo.core.utils.cuda_python_utils import skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported |
| from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ |
| from tests.collections.asr.decoding.utils import load_audio |
|
|
| RNNT_MODEL = "stt_en_conformer_transducer_small" |
| CTC_MODEL = "nvidia/stt_en_conformer_ctc_small" |
| TDT_MODEL = "nvidia/stt_en_fastconformer_tdt_large" |
| MAX_SAMPLES = 10 |
|
|
| DEVICES = [torch.device("cpu")] |
|
|
| if torch.cuda.is_available(): |
| DEVICES.append(torch.device('cuda')) |
|
|
| NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cpu_is_supported( |
| __NUMBA_MINIMUM_VERSION__ |
| ) or numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__) |
|
|
|
|
| |
| @pytest.fixture(scope="module") |
| def test_audio_filenames(test_data_dir): |
| return tuple(glob.glob(os.path.join(test_data_dir, "asr", "test", "an4", "wav", "*.wav"))) |
|
|
|
|
| |
| @pytest.fixture(scope="module") |
| def rnnt_model(): |
| model = ASRModel.from_pretrained(model_name=RNNT_MODEL, map_location="cpu") |
| model.eval() |
| return model |
|
|
|
|
| @pytest.fixture(scope="module") |
| def tdt_model(): |
| model = ASRModel.from_pretrained(model_name=TDT_MODEL, map_location="cpu") |
| model.eval() |
| return model |
|
|
|
|
| @pytest.fixture(scope="module") |
| def ctc_model(): |
| model = ASRModel.from_pretrained(model_name=CTC_MODEL, map_location="cpu") |
| model.eval() |
| return model |
|
|
|
|
| |
| @pytest.fixture(scope="module") |
| def get_rnnt_encoder_output(rnnt_model, test_audio_filenames): |
| encoder_output, encoded_lengths = get_transducer_model_encoder_output( |
| test_audio_filenames, MAX_SAMPLES, rnnt_model |
| ) |
| return encoder_output, encoded_lengths |
|
|
|
|
| @pytest.fixture(scope="module") |
| def get_tdt_encoder_output(tdt_model, test_audio_filenames): |
| encoder_output, encoded_lengths = get_transducer_model_encoder_output(test_audio_filenames, MAX_SAMPLES, tdt_model) |
| return encoder_output, encoded_lengths |
|
|
|
|
| @pytest.fixture(scope="module") |
| def get_ctc_output(ctc_model, test_audio_filenames): |
| encoder_output, encoded_lengths = get_ctc_model_output(test_audio_filenames, MAX_SAMPLES, ctc_model) |
| return encoder_output, encoded_lengths |
|
|
|
|
| @pytest.fixture(scope="module") |
| def kenlm_model_path(tmp_path_factory, test_data_dir): |
| lm_path = Path(test_data_dir) / "asr/kenlm_ngram_lm/parakeet-tdt_ctc-110m-libri-1024.kenlm.tmp.arpa" |
| assert os.path.exists(lm_path), f"LM file not found: {lm_path}" |
| lm_nemo_path = tmp_path_factory.mktemp("lm") / f"{lm_path.name}.nemo" |
| NGramGPULanguageModel.from_file(lm_path, vocab_size=1024).save_to(f"{lm_nemo_path}") |
| return f"{lm_nemo_path}" |
|
|
|
|
| def get_transducer_model_encoder_output( |
| test_audio_filenames, |
| num_samples: int, |
| model: ASRModel, |
| device: torch.device = torch.device("cpu"), |
| dtype: torch.dtype = torch.float32, |
| ): |
| audio_filepaths = test_audio_filenames[:num_samples] |
|
|
| with torch.no_grad(): |
| model.preprocessor.featurizer.dither = 0.0 |
| model.preprocessor.featurizer.pad_to = 0 |
| model.eval() |
|
|
| all_inputs, all_lengths = [], [] |
| for audio_file in tqdm(audio_filepaths, desc="Loading audio files"): |
| audio_tensor, _ = load_audio(audio_file) |
| all_inputs.append(audio_tensor) |
| all_lengths.append(torch.tensor(audio_tensor.shape[0], dtype=torch.int64)) |
|
|
| input_batch = torch.nn.utils.rnn.pad_sequence(all_inputs, batch_first=True).to(device=device, dtype=dtype) |
| length_batch = torch.tensor(all_lengths, dtype=torch.int64).to(device) |
|
|
| encoded_outputs, encoded_length = model(input_signal=input_batch, input_signal_length=length_batch) |
|
|
| return encoded_outputs, encoded_length |
|
|
|
|
| def get_ctc_model_output( |
| test_audio_filenames, |
| num_samples: int, |
| model: ASRModel, |
| device: torch.device = torch.device("cpu"), |
| dtype: torch.dtype = torch.float32, |
| ): |
| audio_filepaths = test_audio_filenames[:num_samples] |
|
|
| with torch.no_grad(): |
| model.preprocessor.featurizer.dither = 0.0 |
| model.preprocessor.featurizer.pad_to = 0 |
| model.eval() |
|
|
| all_inputs, all_lengths = [], [] |
| for audio_file in tqdm(audio_filepaths, desc="Loading audio files"): |
| audio_tensor, _ = load_audio(audio_file) |
| all_inputs.append(audio_tensor) |
| all_lengths.append(torch.tensor(audio_tensor.shape[0], dtype=torch.int64)) |
|
|
| input_batch = torch.nn.utils.rnn.pad_sequence(all_inputs, batch_first=True).to(device=device, dtype=dtype) |
| length_batch = torch.tensor(all_lengths, dtype=torch.int64).to(device) |
|
|
| log_probs, encoded_length, _ = model(input_signal=input_batch, input_signal_length=length_batch) |
|
|
| return log_probs, encoded_length |
|
|
|
|
| def print_unit_test_info(strategy, batch_size, beam_size, allow_cuda_graphs, device): |
| print( |
| f"""Beam search algorithm: {strategy}, |
| Batch size: {batch_size}, |
| Beam size: {beam_size}, |
| Cuda Graphs: {allow_cuda_graphs}, |
| Decoding device: {device} |
| """ |
| ) |
|
|
|
|
| def check_res_best_hyps(num_samples, hyps): |
| assert type(hyps) == list |
| assert type(hyps[0]) == rnnt_utils.Hypothesis |
|
|
| assert len(hyps) == num_samples |
|
|
| assert all( |
| [ |
| hasattr(hyps[hyp_idx], "y_sequence") |
| and hasattr(hyps[hyp_idx], "score") |
| and hasattr(hyps[hyp_idx], "timestamp") |
| for hyp_idx in range(num_samples) |
| ] |
| ) |
|
|
|
|
| def print_res_best_hyps(hyps): |
| for hyp_idx, hyp in enumerate(hyps): |
| print("Sample: ", hyp_idx) |
| print("Decoded text: ", hyp.text) |
| print("Score: ", hyp.score) |
| print("Transcript", hyp.y_sequence) |
| print("Timesteps", hyp.timestamp) |
| print() |
|
|
|
|
| def check_res_nbest_hyps(num_samples, batch_nbest_hyps): |
| assert type(batch_nbest_hyps) == list |
| assert type(batch_nbest_hyps[0]) == rnnt_utils.NBestHypotheses |
|
|
| assert len(batch_nbest_hyps) == num_samples |
|
|
| for idx in range(num_samples): |
| assert all( |
| [ |
| hasattr(batch_nbest_hyps[idx].n_best_hypotheses[hyp_idx], "y_sequence") |
| and hasattr(batch_nbest_hyps[idx].n_best_hypotheses[hyp_idx], "score") |
| and hasattr(batch_nbest_hyps[idx].n_best_hypotheses[hyp_idx], "timestamp") |
| for hyp_idx in range(len(batch_nbest_hyps[idx].n_best_hypotheses)) |
| ] |
| ) |
|
|
| assert all( |
| [ |
| len(batch_nbest_hyps[idx].n_best_hypotheses[hyp_idx].y_sequence) > 0 |
| and len(batch_nbest_hyps[idx].n_best_hypotheses[hyp_idx].timestamp) > 0 |
| for hyp_idx in range(len(batch_nbest_hyps[idx].n_best_hypotheses)) |
| ] |
| ) |
|
|
|
|
| def print_res_nbest_hyps(batch_nbest_hyps): |
| for batch_idx, nbest_hyps in enumerate(batch_nbest_hyps): |
| print(f"Batch idx: {batch_idx}") |
| for idx, hyp in enumerate(nbest_hyps): |
| print(f"Hyp index: {idx + 1}") |
| print("Text: ", hyp.text) |
| print("Score: ", hyp.score) |
| print("Transcripts: ", hyp.y_sequence) |
| print("Timesteps: ", hyp.timestamp) |
| print() |
|
|
|
|
| def decode_text_from_hypotheses(hyps, model): |
| if isinstance(model, EncDecCTCModel): |
| return model.decoding.decode_hypothesis(hyps, fold_consecutive=False) |
| else: |
| return model.decoding.decode_hypothesis(hyps) |
|
|
|
|
| def decode_text_from_nbest_hypotheses(hyps, model): |
| if isinstance(model, EncDecCTCModel): |
| return [ |
| model.decoding.decode_hypothesis(nbest_hyp.n_best_hypotheses, fold_consecutive=False) for nbest_hyp in hyps |
| ] |
| else: |
| return [model.decoding.decode_hypothesis(nbest_hyp.n_best_hypotheses) for nbest_hyp in hyps] |
|
|
|
|
| class TestRNNTDecoding: |
| @pytest.mark.skipif( |
| not NUMBA_RNNT_LOSS_AVAILABLE, |
| reason='RNNTLoss has not been compiled with appropriate numba version.', |
| ) |
| @pytest.mark.with_downloads |
| @pytest.mark.unit |
| @pytest.mark.parametrize( |
| "beam_config", |
| [ |
| {"search_type": "malsd_batch", "allow_cuda_graphs": False}, |
| {"search_type": "malsd_batch", "allow_cuda_graphs": True}, |
| {"search_type": "maes_batch", "allow_cuda_graphs": False}, |
| ], |
| ) |
| @pytest.mark.parametrize("beam_size", [4]) |
| @pytest.mark.parametrize("batch_size", [4, 16]) |
| @pytest.mark.parametrize("device", DEVICES) |
| def test_rnnt_beam_decoding_return_best_hypothesis( |
| self, test_audio_filenames, rnnt_model, get_rnnt_encoder_output, beam_config, device, batch_size, beam_size |
| ): |
| num_samples = min(batch_size, len(test_audio_filenames)) |
| model = rnnt_model.to(device) |
| encoder_output, encoded_lengths = get_rnnt_encoder_output |
| encoder_output, encoded_lengths = encoder_output[:num_samples].to(device), encoded_lengths[:num_samples].to( |
| device |
| ) |
|
|
| vocab_size = model.tokenizer.vocab_size |
| decoding = BeamBatchedRNNTInfer( |
| model.decoder, |
| model.joint, |
| blank_index=vocab_size, |
| beam_size=beam_size, |
| score_norm=True, |
| return_best_hypothesis=True, |
| **beam_config, |
| ) |
|
|
| print_unit_test_info( |
| strategy=beam_config['search_type'], |
| batch_size=batch_size, |
| beam_size=beam_size, |
| allow_cuda_graphs=beam_config.get('allow_cuda_graphs', True), |
| device=device, |
| ) |
|
|
| with torch.no_grad(): |
| hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0] |
|
|
| check_res_best_hyps(num_samples, hyps) |
| hyps = decode_text_from_hypotheses(hyps, model) |
| print_res_best_hyps(hyps) |
|
|
| @pytest.mark.skipif( |
| not NUMBA_RNNT_LOSS_AVAILABLE, |
| reason='RNNTLoss has not been compiled with appropriate numba version.', |
| ) |
| @pytest.mark.with_downloads |
| @pytest.mark.unit |
| @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test is only GPU-based decoding") |
| @pytest.mark.parametrize( |
| "beam_config", |
| [ |
| {"search_type": "malsd_batch", "allow_cuda_graphs": False}, |
| {"search_type": "malsd_batch", "allow_cuda_graphs": True}, |
| {"search_type": "maes_batch", "allow_cuda_graphs": False}, |
| ], |
| ) |
| @pytest.mark.parametrize("beam_size", [4]) |
| @pytest.mark.parametrize("batch_size", [4]) |
| def test_rnnt_beam_decoding_return_nbest( |
| self, test_audio_filenames, rnnt_model, get_rnnt_encoder_output, beam_config, device, beam_size, batch_size |
| ): |
| device = torch.device("cuda") |
| num_samples = min(batch_size, len(test_audio_filenames)) |
| model = rnnt_model.to(device) |
| encoder_output, encoded_lengths = get_rnnt_encoder_output |
| encoder_output, encoded_lengths = encoder_output[:num_samples].to(device), encoded_lengths[:num_samples].to( |
| device |
| ) |
|
|
| vocab_size = model.tokenizer.vocab_size |
| decoding = BeamBatchedRNNTInfer( |
| model.decoder, |
| model.joint, |
| blank_index=vocab_size, |
| beam_size=beam_size, |
| score_norm=True, |
| return_best_hypothesis=False, |
| **beam_config, |
| ) |
|
|
| print_unit_test_info( |
| strategy=beam_config['search_type'], |
| batch_size=batch_size, |
| beam_size=beam_size, |
| allow_cuda_graphs=beam_config.get('allow_cuda_graphs', True), |
| device=device, |
| ) |
|
|
| with torch.no_grad(): |
| batch_nbest_hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0] |
|
|
| check_res_nbest_hyps(num_samples, batch_nbest_hyps) |
| batch_nbest_hyps = decode_text_from_nbest_hypotheses(batch_nbest_hyps, model) |
| print_res_nbest_hyps(batch_nbest_hyps) |
|
|
| @pytest.mark.skipif( |
| not NUMBA_RNNT_LOSS_AVAILABLE, |
| reason='RNNTLoss has not been compiled with appropriate numba version.', |
| ) |
| @pytest.mark.with_downloads |
| @pytest.mark.unit |
| @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test is only GPU-based decoding") |
| @pytest.mark.parametrize( |
| "beam_config", |
| [ |
| {"search_type": "malsd_batch", "allow_cuda_graphs": False}, |
| {"search_type": "maes_batch", "allow_cuda_graphs": False}, |
| {"search_type": "malsd_batch", "allow_cuda_graphs": True}, |
| ], |
| ) |
| @pytest.mark.parametrize("batch_size", [4]) |
| @pytest.mark.parametrize("beam_size", [4]) |
| @pytest.mark.parametrize("pruning_mode", ["late", "early"]) |
| @pytest.mark.parametrize("blank_lm_score_mode", ["no_score", "lm_weighted_full"]) |
| def test_rnnt_beam_decoding_kenlm( |
| self, |
| kenlm_model_path, |
| test_audio_filenames, |
| rnnt_model, |
| get_rnnt_encoder_output, |
| beam_config, |
| device, |
| batch_size, |
| beam_size, |
| pruning_mode, |
| blank_lm_score_mode, |
| ): |
| device = torch.device("cuda") |
|
|
| num_samples = min(batch_size, len(test_audio_filenames)) |
| model = rnnt_model.to(device) |
| encoder_output, encoded_lengths = get_rnnt_encoder_output |
| encoder_output, encoded_lengths = encoder_output[:num_samples].to(device), encoded_lengths[:num_samples].to( |
| device |
| ) |
|
|
| vocab_size = model.tokenizer.vocab_size |
|
|
| fusion_models = [NGramGPULanguageModel.from_file(lm_path=kenlm_model_path, vocab_size=vocab_size)] |
| fusion_models_alpha = [0.3] |
|
|
| decoding = BeamBatchedRNNTInfer( |
| model.decoder, |
| model.joint, |
| blank_index=vocab_size, |
| beam_size=beam_size, |
| score_norm=True, |
| return_best_hypothesis=True, |
| pruning_mode=pruning_mode, |
| blank_lm_score_mode=blank_lm_score_mode, |
| fusion_models=fusion_models, |
| fusion_models_alpha=fusion_models_alpha, |
| **beam_config, |
| ) |
|
|
| print_unit_test_info( |
| strategy=beam_config['search_type'], |
| batch_size=batch_size, |
| beam_size=beam_size, |
| allow_cuda_graphs=beam_config.get('allow_cuda_graphs', True), |
| device=device, |
| ) |
|
|
| with torch.no_grad(): |
| hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0] |
|
|
| check_res_best_hyps(num_samples, hyps) |
| hyps = decode_text_from_hypotheses(hyps, model) |
| print_res_best_hyps(hyps) |
|
|
|
|
| class TestTDTDecoding: |
| @pytest.mark.skipif( |
| not NUMBA_RNNT_LOSS_AVAILABLE, |
| reason='RNNTLoss has not been compiled with appropriate numba version.', |
| ) |
| @pytest.mark.with_downloads |
| @pytest.mark.unit |
| @pytest.mark.parametrize( |
| "beam_config", |
| [ |
| {"search_type": "malsd_batch", "allow_cuda_graphs": False}, |
| {"search_type": "malsd_batch", "allow_cuda_graphs": True}, |
| ], |
| ) |
| @pytest.mark.parametrize("beam_size", [4]) |
| @pytest.mark.parametrize("batch_size", [4, 16]) |
| @pytest.mark.parametrize("device", DEVICES) |
| def test_tdt_beam_decoding_return_best_hypothesis( |
| self, test_audio_filenames, tdt_model, get_tdt_encoder_output, beam_config, device, batch_size, beam_size |
| ): |
| num_samples = min(batch_size, len(test_audio_filenames)) |
| model = tdt_model.to(device) |
| encoder_output, encoded_lengths = get_tdt_encoder_output |
| encoder_output, encoded_lengths = encoder_output[:num_samples].to(device), encoded_lengths[:num_samples].to( |
| device |
| ) |
|
|
| model_config = model.to_config_dict() |
| durations = list(model_config["model_defaults"]["tdt_durations"]) |
|
|
| vocab_size = model.tokenizer.vocab_size |
| decoding = BeamBatchedTDTInfer( |
| model.decoder, |
| model.joint, |
| blank_index=vocab_size, |
| durations=durations, |
| beam_size=beam_size, |
| score_norm=True, |
| return_best_hypothesis=True, |
| **beam_config, |
| ) |
|
|
| print_unit_test_info( |
| strategy=beam_config['search_type'], |
| batch_size=batch_size, |
| beam_size=beam_size, |
| allow_cuda_graphs=beam_config.get('allow_cuda_graphs', True), |
| device=device, |
| ) |
|
|
| with torch.no_grad(): |
| hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0] |
|
|
| check_res_best_hyps(num_samples, hyps) |
| hyps = decode_text_from_hypotheses(hyps, model) |
| print_res_best_hyps(hyps) |
|
|
| @pytest.mark.skipif( |
| not NUMBA_RNNT_LOSS_AVAILABLE, |
| reason='RNNTLoss has not been compiled with appropriate numba version.', |
| ) |
| @pytest.mark.with_downloads |
| @pytest.mark.unit |
| @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test is only GPU-based decoding") |
| @pytest.mark.parametrize( |
| "beam_config", |
| [ |
| {"search_type": "malsd_batch", "allow_cuda_graphs": False}, |
| {"search_type": "malsd_batch", "allow_cuda_graphs": True}, |
| ], |
| ) |
| @pytest.mark.parametrize("beam_size", [4]) |
| @pytest.mark.parametrize("batch_size", [4]) |
| def test_tdt_beam_decoding_return_nbest( |
| self, test_audio_filenames, tdt_model, get_tdt_encoder_output, beam_config, device, beam_size, batch_size |
| ): |
| device = torch.device("cuda") |
| num_samples = min(batch_size, len(test_audio_filenames)) |
| model = tdt_model.to(device) |
| encoder_output, encoded_lengths = get_tdt_encoder_output |
| encoder_output, encoded_lengths = encoder_output[:num_samples].to(device), encoded_lengths[:num_samples].to( |
| device |
| ) |
|
|
| model_config = model.to_config_dict() |
| durations = list(model_config["model_defaults"]["tdt_durations"]) |
|
|
| vocab_size = model.tokenizer.vocab_size |
| decoding = BeamBatchedTDTInfer( |
| model.decoder, |
| model.joint, |
| blank_index=vocab_size, |
| durations=durations, |
| beam_size=beam_size, |
| score_norm=True, |
| return_best_hypothesis=False, |
| **beam_config, |
| ) |
|
|
| print_unit_test_info( |
| strategy=beam_config['search_type'], |
| batch_size=batch_size, |
| beam_size=beam_size, |
| allow_cuda_graphs=beam_config.get('allow_cuda_graphs', True), |
| device=device, |
| ) |
|
|
| with torch.no_grad(): |
| batch_nbest_hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0] |
|
|
| check_res_nbest_hyps(num_samples, batch_nbest_hyps) |
| batch_nbest_hyps = decode_text_from_nbest_hypotheses(batch_nbest_hyps, model) |
| print_res_nbest_hyps(batch_nbest_hyps) |
|
|
| @pytest.mark.skipif( |
| not NUMBA_RNNT_LOSS_AVAILABLE, |
| reason='RNNTLoss has not been compiled with appropriate numba version.', |
| ) |
| @pytest.mark.with_downloads |
| @pytest.mark.unit |
| @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test is only GPU-based decoding") |
| @pytest.mark.parametrize( |
| "beam_config", |
| [ |
| {"search_type": "malsd_batch", "allow_cuda_graphs": False}, |
| {"search_type": "malsd_batch", "allow_cuda_graphs": True}, |
| ], |
| ) |
| @pytest.mark.parametrize("batch_size", [4]) |
| @pytest.mark.parametrize("beam_size", [4]) |
| @pytest.mark.parametrize("pruning_mode", ["late", "early"]) |
| @pytest.mark.parametrize("blank_lm_score_mode", ["lm_weighted_full", "no_score"]) |
| def test_tdt_beam_decoding_kenlm( |
| self, |
| kenlm_model_path, |
| test_audio_filenames, |
| tdt_model, |
| get_tdt_encoder_output, |
| beam_config, |
| device, |
| batch_size, |
| beam_size, |
| pruning_mode, |
| blank_lm_score_mode, |
| ): |
| device = torch.device("cuda") |
|
|
| num_samples = min(batch_size, len(test_audio_filenames)) |
| model = tdt_model.to(device) |
| encoder_output, encoded_lengths = get_tdt_encoder_output |
| encoder_output, encoded_lengths = encoder_output[:num_samples].to(device), encoded_lengths[:num_samples].to( |
| device |
| ) |
|
|
| model_config = model.to_config_dict() |
| durations = list(model_config["model_defaults"]["tdt_durations"]) |
|
|
| vocab_size = model.tokenizer.vocab_size |
|
|
| fusion_models = [NGramGPULanguageModel.from_file(lm_path=kenlm_model_path, vocab_size=vocab_size)] |
| fusion_models_alpha = [0.3] |
|
|
| decoding = BeamBatchedTDTInfer( |
| model.decoder, |
| model.joint, |
| blank_index=vocab_size, |
| durations=durations, |
| beam_size=beam_size, |
| score_norm=True, |
| return_best_hypothesis=True, |
| pruning_mode=pruning_mode, |
| blank_lm_score_mode=blank_lm_score_mode, |
| fusion_models=fusion_models, |
| fusion_models_alpha=fusion_models_alpha, |
| **beam_config, |
| ) |
|
|
| print_unit_test_info( |
| strategy=beam_config['search_type'], |
| batch_size=batch_size, |
| beam_size=beam_size, |
| allow_cuda_graphs=beam_config.get('allow_cuda_graphs', True), |
| device=device, |
| ) |
|
|
| with torch.no_grad(): |
| hyps = decoding(encoder_output=encoder_output, encoded_lengths=encoded_lengths)[0] |
|
|
| check_res_best_hyps(num_samples, hyps) |
| hyps = decode_text_from_hypotheses(hyps, model) |
| print_res_best_hyps(hyps) |
|
|
|
|
| class TestTransducerCudaGraphBeamDecoding: |
| """ |
| Tests CudaGraphs implementations from Transducer models (RNN-T and TDT) |
| """ |
|
|
| @pytest.mark.with_downloads |
| @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA decoder can run only on CUDA") |
| @pytest.mark.parametrize("force_mode", ["no_graphs", "no_while_loops", "full_graph"]) |
| @pytest.mark.parametrize("model_type", ["rnnt", "tdt"]) |
| def test_stated_stateless(self, test_audio_filenames, rnnt_model, tdt_model, model_type, force_mode: str): |
| """ |
| Compares pure Pytorch and with three modes of statefull implementations for double floating point precision. |
| 1. Pure pytorch, but statefull implementation: no_graphs |
| 2. With CudaGrpahs: no_while_loops and full_graph. |
| """ |
| if force_mode == "full_graph": |
| skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported() |
|
|
| batch_size = 16 |
| device = torch.device("cuda") |
| model = rnnt_model.to(device) if model_type == "rnnt" else tdt_model.to(device) |
| decoding_config = copy.deepcopy(model.cfg.decoding) |
|
|
| with open_dict(decoding_config): |
| decoding_config["strategy"] = "malsd_batch" |
| decoding_config["beam"]["beam_size"] = 4 |
| decoding_config["beam"]["return_best_hypothesis"] = False |
| decoding_config["beam"]["allow_cuda_graphs"] = False |
|
|
| model.change_decoding_strategy(decoding_config) |
|
|
| actual_hypotheses = model.transcribe(test_audio_filenames, batch_size=batch_size, num_workers=None) |
| actual_transcripts = [[hyp.text for hyp in actual_beam] for actual_beam in actual_hypotheses] |
| actual_scores = [[hyp.score for hyp in actual_beam] for actual_beam in actual_hypotheses] |
| actual_timestamps = [[hyp.timestamp for hyp in actual_beam] for actual_beam in actual_hypotheses] |
|
|
| |
| decoding_config["beam"]["allow_cuda_graphs"] = True |
| model.change_decoding_strategy(decoding_config) |
| model.decoding.decoding._decoding_computer.force_cuda_graphs_mode(mode=force_mode) |
|
|
| cudagraph_hypotheses = model.transcribe(test_audio_filenames, batch_size=batch_size, num_workers=None) |
| cudagraph_transcripts = [[hyp.text for hyp in cudagraphs_beam] for cudagraphs_beam in cudagraph_hypotheses] |
| cudagraph_scores = [[hyp.score for hyp in cudagraph_beam] for cudagraph_beam in cudagraph_hypotheses] |
| cudagraph_timestamps = [[hyp.timestamp for hyp in cudagraph_beam] for cudagraph_beam in cudagraph_hypotheses] |
|
|
| for batch_idx in range(min(batch_size, len(test_audio_filenames))): |
| assert len(actual_transcripts[batch_idx]) == len(cudagraph_transcripts[batch_idx]) |
| assert cudagraph_scores[batch_idx] == pytest.approx( |
| actual_scores[batch_idx], abs=1e-2 |
| ), f"Scores mismatch for batch_idx {batch_idx}" |
| assert ( |
| cudagraph_timestamps[batch_idx] == actual_timestamps[batch_idx] |
| ), f"Timestamps mismatch for batch_idx {batch_idx}" |
|
|
| wer = jiwer.wer(actual_transcripts[batch_idx], cudagraph_transcripts[batch_idx]) |
|
|
| assert wer <= 1e-3, "Cuda graph greedy decoder should match original decoder implementation." |
|
|
| for actual, fast in zip(actual_transcripts[batch_idx], cudagraph_transcripts[batch_idx]): |
| if actual != fast: |
| print("Erroneous samples in batch:", batch_idx) |
| print("Original transcript:", actual) |
| print("New transcript:", fast) |
|
|
| @pytest.mark.with_downloads |
| @pytest.mark.skipif( |
| not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()), reason="CUDA decoder can run only on CUDA" |
| ) |
| @pytest.mark.parametrize("model_type", ["rnnt", "tdt"]) |
| def test_stated_stateless_bf16(self, test_audio_filenames, rnnt_model, tdt_model, model_type): |
| """ |
| Checks that we are able to run without errors all decodings in bfloat16. |
| Computational errors accumulate, so just checking if algorithms run without errors |
| """ |
| batch_size = 16 |
| device = torch.device("cuda") |
| model = rnnt_model.to(device) if model_type == "rnnt" else tdt_model.to(device) |
| decoding_config = copy.deepcopy(model.cfg.decoding) |
|
|
| |
| with open_dict(decoding_config): |
| decoding_config["strategy"] = "malsd_batch" |
| decoding_config["beam"]["beam_size"] = 4 |
| decoding_config["beam"]["return_best_hypothesis"] = False |
| decoding_config["beam"]["allow_cuda_graphs"] = False |
|
|
| model.change_decoding_strategy(decoding_config) |
|
|
| with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True): |
| model.transcribe(test_audio_filenames, batch_size=batch_size, num_workers=None) |
|
|
| modes = ["no_graphs", "no_while_loops", "full_graph"] |
| for force_mode in modes: |
| if force_mode == "full_graph": |
| skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported() |
|
|
| |
| decoding_config["beam"]["allow_cuda_graphs"] = True |
| model.change_decoding_strategy(decoding_config) |
| model.decoding.decoding._decoding_computer.force_cuda_graphs_mode(mode=force_mode) |
|
|
| with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True): |
| model.transcribe(test_audio_filenames, batch_size=batch_size, num_workers=None) |
|
|
|
|
| class TestCTCDecoding: |
| @pytest.mark.with_downloads |
| @pytest.mark.unit |
| @pytest.mark.parametrize( |
| "beam_config", |
| [ |
| {"allow_cuda_graphs": False}, |
| {"allow_cuda_graphs": True}, |
| ], |
| ) |
| @pytest.mark.parametrize("beam_size", [4]) |
| @pytest.mark.parametrize("batch_size", [4, 16]) |
| @pytest.mark.parametrize("device", DEVICES) |
| def test_ctc_beam_decoding_return_best_hypothesis( |
| self, test_audio_filenames, ctc_model, get_ctc_output, beam_config, device, batch_size, beam_size |
| ): |
| num_samples = min(batch_size, len(test_audio_filenames)) |
| model = ctc_model.to(device) |
| log_probs, encoded_lengths = get_ctc_output |
| log_probs, encoded_lengths = log_probs[:num_samples].to(device), encoded_lengths[:num_samples].to(device) |
|
|
| vocab_size = model.tokenizer.vocab_size |
| decoding = BeamBatchedCTCInfer( |
| blank_index=vocab_size, |
| beam_size=beam_size, |
| return_best_hypothesis=True, |
| **beam_config, |
| ) |
|
|
| print_unit_test_info( |
| strategy="beam_batch", |
| batch_size=batch_size, |
| beam_size=beam_size, |
| allow_cuda_graphs=beam_config.get('allow_cuda_graphs', True), |
| device=device, |
| ) |
|
|
| with torch.no_grad(): |
| hyps = decoding(decoder_output=log_probs, decoder_lengths=encoded_lengths)[0] |
|
|
| check_res_best_hyps(num_samples, hyps) |
| hyps = decode_text_from_hypotheses(hyps, model) |
| print_res_best_hyps(hyps) |
|
|
| @pytest.mark.with_downloads |
| @pytest.mark.unit |
| @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test is only GPU-based decoding") |
| @pytest.mark.parametrize( |
| "beam_config", |
| [ |
| {"allow_cuda_graphs": False}, |
| {"allow_cuda_graphs": True}, |
| ], |
| ) |
| @pytest.mark.parametrize("beam_size", [4]) |
| @pytest.mark.parametrize("batch_size", [4]) |
| def test_ctc_beam_decoding_return_nbest( |
| self, test_audio_filenames, ctc_model, get_ctc_output, beam_config, device, beam_size, batch_size |
| ): |
| device = torch.device("cuda") |
| num_samples = min(batch_size, len(test_audio_filenames)) |
| model = ctc_model.to(device) |
| log_probs, encoded_lengths = get_ctc_output |
| log_probs, encoded_lengths = log_probs[:num_samples].to(device), encoded_lengths[:num_samples].to(device) |
|
|
| vocab_size = model.tokenizer.vocab_size |
| decoding = BeamBatchedCTCInfer( |
| blank_index=vocab_size, |
| beam_size=beam_size, |
| return_best_hypothesis=False, |
| **beam_config, |
| ) |
|
|
| print_unit_test_info( |
| strategy="beam_batch", |
| batch_size=batch_size, |
| beam_size=beam_size, |
| allow_cuda_graphs=beam_config.get('allow_cuda_graphs', True), |
| device=device, |
| ) |
|
|
| with torch.no_grad(): |
| batch_nbest_hyps = decoding(decoder_output=log_probs, decoder_lengths=encoded_lengths)[0] |
|
|
| check_res_nbest_hyps(num_samples, batch_nbest_hyps) |
| batch_nbest_hyps = decode_text_from_nbest_hypotheses(batch_nbest_hyps, model) |
| print_res_nbest_hyps(batch_nbest_hyps) |
|
|
| @pytest.mark.with_downloads |
| @pytest.mark.unit |
| @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test is only GPU-based decoding") |
| @pytest.mark.parametrize( |
| "beam_config", |
| [ |
| {"allow_cuda_graphs": False, "ngram_lm_alpha": 0.3, "beam_beta": 1.0}, |
| {"allow_cuda_graphs": False, "ngram_lm_alpha": 0.3, "beam_beta": 1.0}, |
| ], |
| ) |
| @pytest.mark.parametrize("batch_size", [4]) |
| @pytest.mark.parametrize("beam_size", [4]) |
| def test_ctc_beam_decoding_kenlm( |
| self, |
| kenlm_model_path, |
| test_audio_filenames, |
| ctc_model, |
| get_ctc_output, |
| beam_config, |
| device, |
| batch_size, |
| beam_size, |
| ): |
| device = torch.device("cuda") |
| beam_config["ngram_lm_model"] = kenlm_model_path |
|
|
| num_samples = min(batch_size, len(test_audio_filenames)) |
| model = ctc_model.to(device) |
| decoder_output, decoder_lengths = get_ctc_output |
| decoder_output, decoder_lengths = decoder_output[:num_samples].to(device), decoder_lengths[:num_samples].to( |
| device |
| ) |
|
|
| vocab_size = model.tokenizer.vocab_size |
| decoding = BeamBatchedCTCInfer( |
| blank_index=vocab_size, |
| beam_size=beam_size, |
| return_best_hypothesis=True, |
| **beam_config, |
| ) |
|
|
| print_unit_test_info( |
| strategy="beam_batch", |
| batch_size=batch_size, |
| beam_size=beam_size, |
| allow_cuda_graphs=beam_config.get('allow_cuda_graphs', True), |
| device=device, |
| ) |
|
|
| with torch.no_grad(): |
| hyps = decoding(decoder_output=decoder_output, decoder_lengths=decoder_lengths)[0] |
|
|
| check_res_best_hyps(num_samples, hyps) |
| hyps = decode_text_from_hypotheses(hyps, model) |
| print_res_best_hyps(hyps) |
|
|