| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import copy |
| | import os |
| |
|
| | import pytest |
| | import torch |
| | from omegaconf import DictConfig, ListConfig |
| |
|
| | from nemo.collections.asr.data import audio_to_label |
| | from nemo.collections.asr.models import EncDecClassificationModel, configs |
| | from nemo.utils.config_utils import assert_dataclass_signature_match |
| |
|
| |
|
@pytest.fixture()
def speech_classification_model():
    """Return an ``EncDecClassificationModel`` built from a small in-code config.

    The config names a mel-spectrogram preprocessor (no explicit params), a
    ``ConvASREncoder`` with a single 32-filter block, and a classification
    decoder over 30 classes labelled ``dummy_cls_1`` .. ``dummy_cls_30``.
    """
    # The only encoder block; its 32 filters match the decoder's ``feat_in``.
    encoder_block = {
        'filters': 32,
        'repeat': 1,
        'kernel': [1],
        'stride': [1],
        'dilation': [1],
        'dropout': 0.0,
        'residual': False,
        'separable': True,
        'se': True,
        'se_context_size': -1,
    }

    preprocessor_cfg = {
        'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
        'params': {},
    }
    encoder_cfg = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in': 64,
            'activation': 'relu',
            'conv_mask': True,
            'jasper': [encoder_block],
        },
    }
    decoder_cfg = {
        'cls': 'nemo.collections.asr.modules.ConvASRDecoderClassification',
        'params': {'feat_in': 32, 'num_classes': 30},
    }

    model_cfg = DictConfig(
        {
            'preprocessor': DictConfig(preprocessor_cfg),
            'encoder': DictConfig(encoder_cfg),
            'decoder': DictConfig(decoder_cfg),
            # One label per decoder class, numbered from 1.
            'labels': ListConfig(["dummy_cls_{}".format(index) for index in range(1, 31)]),
        }
    )
    return EncDecClassificationModel(cfg=model_cfg)
| |
|
| |
|
class TestEncDecClassificationModel:
    """Unit tests for ``EncDecClassificationModel`` using the small fixture model."""

    @pytest.mark.unit
    def test_constructor(self, speech_classification_model):
        """Check the parameter count and a config-dict round trip."""
        model = speech_classification_model.train()

        # Expected parameter budget, split the same way as the model parts.
        conv_params = sum((64 * 32 * 1 + 32, 64 * 1 * 1 + 32))
        bn_params = 2 * (4 * 32)
        decoder_params = 32 * 30 + 30
        assert model.num_weights == conv_params + bn_params + decoder_params

        # The exported config must be enough to build another instance.
        rebuilt = EncDecClassificationModel.from_config_dict(model.to_config_dict())
        assert isinstance(rebuilt, EncDecClassificationModel)

    @pytest.mark.unit
    def test_forward(self, speech_classification_model):
        """Per-sample forward passes must agree with one batched forward pass."""
        model = speech_classification_model.eval()

        featurizer = model.preprocessor.featurizer
        featurizer.dither = 0.0
        featurizer.pad_to = 0

        input_signal = torch.randn(size=(4, 512))
        length = torch.randint(low=161, high=500, size=[4])

        with torch.no_grad():
            # Feed the samples one at a time and stack the outputs ...
            per_sample = torch.cat(
                [
                    model.forward(
                        input_signal=input_signal[idx : idx + 1], input_signal_length=length[idx : idx + 1]
                    )
                    for idx in range(input_signal.size(0))
                ],
                0,
            )

            # ... then feed the whole batch in a single call.
            batched = model.forward(input_signal=input_signal, input_signal_length=length)

        assert per_sample.shape == batched.shape
        abs_error = torch.abs(per_sample - batched)
        assert torch.mean(abs_error) <= 1e-6
        assert torch.max(abs_error) <= 1e-6

    @pytest.mark.unit
    def test_vocab_change(self, speech_classification_model):
        """``change_labels`` keeps or grows the weight count as labels change."""
        model = speech_classification_model.train()

        original_labels = copy.deepcopy(model._cfg.labels)
        baseline_weights = model.num_weights

        # Re-applying the same labels must leave the weight count untouched.
        model.change_labels(new_labels=original_labels)
        assert baseline_weights == model.num_weights

        extra_labels = ('dummy_cls_31', 'dummy_cls_32', 'dummy_cls_33')
        extended_labels = copy.deepcopy(original_labels)
        for label in extra_labels:
            extended_labels.append(label)
        model.change_labels(new_labels=extended_labels)

        # Each added class is expected to cost ``_feat_in`` weights plus one.
        assert model.num_weights == baseline_weights + len(extra_labels) * (model.decoder._feat_in + 1)

    @pytest.mark.unit
    def test_transcription(self, speech_classification_model, test_data_dir):
        """Check ``transcribe`` output shapes for several ``top_k`` settings."""
        wav_dir = os.path.join(test_data_dir, "asr", "train", "an4", "wav")
        audio_paths = [os.path.join(wav_dir, name) for name in ['an22-flrp-b.wav', 'an90-fbbh-b.wav']]

        model = speech_classification_model.eval()

        # Fixture model as constructed: one prediction per file.
        predictions = model.transcribe(audio_paths, batch_size=2)
        assert len(predictions) == 2
        assert predictions[0].shape == torch.Size([1])

        # A single k of 5: five predictions per file.
        model._accuracy.top_k = [5]
        predictions = model.transcribe(audio_paths, batch_size=2)
        assert len(predictions) == 2
        assert predictions[0].shape == torch.Size([5])

        # Two k values: one result tensor per k, each covering both files.
        model._accuracy.top_k = [1, 5]
        predictions = model.transcribe(audio_paths, batch_size=2)
        assert len(predictions) == 2
        assert predictions[0].shape == torch.Size([2, 1])
        assert predictions[1].shape == torch.Size([2, 5])
        assert model._accuracy.top_k == [1, 5]

        # With logprobs=True the per-file output has one entry per label,
        # whichever top_k is configured.
        for top_k in ([1], [5]):
            model._accuracy.top_k = top_k
            logprobs = model.transcribe(audio_paths, batch_size=2, logprobs=True)
            assert len(logprobs) == 2
            assert logprobs[0].shape == torch.Size([len(model.cfg.labels)])

    @pytest.mark.unit
    def test_EncDecClassificationDatasetConfig_for_AudioToSpeechLabelDataset(self):
        """Dataset config dataclass must match ``AudioToSpeechLabelDataset``'s signature."""
        # Argument names excluded from the signature comparison.
        ignored_args = [
            'is_tarred',
            'num_workers',
            'batch_size',
            'tarred_audio_filepaths',
            'shuffle',
            'pin_memory',
            'drop_last',
            'tarred_shard_strategy',
            'shuffle_n',
            'featurizer',
            'vad_stream',
            'int_values',
            'sample_rate',
            'normalize_audio',
            'augmentor',
            'bucketing_batch_size',
            'bucketing_strategy',
            'bucketing_weights',
        ]

        # Names that differ between the two sides of the comparison.
        remapped_args = {'trim_silence': 'trim'}

        signatures_match, cls_subset, dataclass_subset = assert_dataclass_signature_match(
            audio_to_label.AudioToSpeechLabelDataset,
            configs.EncDecClassificationDatasetConfig,
            ignore_args=ignored_args,
            remap_args=remapped_args,
        )

        assert signatures_match
        assert cls_subset is None
        assert dataclass_subset is None
| |
|