| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
|
|
| import pytest |
| from omegaconf import DictConfig |
|
|
| from nemo.collections.asr.models.classification_models import EncDecRegressionModel |
|
|
|
|
@pytest.fixture()
def speech_regression_model():
    """Build a minimal EncDecRegressionModel for unit testing.

    The config wires together a mel-spectrogram preprocessor, a single-block
    ConvASREncoder (32 separable filters with squeeze-and-excite), and a
    ConvASRDecoderClassification head with a single logit output, configured
    as a regression task (``is_regression_task=True``, no labels).

    Returns:
        EncDecRegressionModel: a freshly constructed (untrained) model.
    """
    # NOTE: legacy 'cls'/'params' config style is kept intentionally —
    # the model's from_config_dict round-trip (exercised in tests) expects it.
    preprocessor = {
        'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
        'params': {},  # was dict({}) — redundant wrapper removed
    }
    encoder = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in': 64,
            'activation': 'relu',
            'conv_mask': True,
            'jasper': [
                {
                    'filters': 32,
                    'repeat': 1,
                    'kernel': [1],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.0,
                    'residual': False,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                }
            ],
        },
    }

    decoder = {
        'cls': 'nemo.collections.asr.modules.conv_asr.ConvASRDecoderClassification',
        'params': {'feat_in': 32, 'return_logits': True, 'num_classes': 1},
    }

    model_config = DictConfig(
        {
            'preprocessor': DictConfig(preprocessor),
            'encoder': DictConfig(encoder),
            'decoder': DictConfig(decoder),
            'labels': None,
            'is_regression_task': True,
        }
    )
    return EncDecRegressionModel(cfg=model_config)
|
|
|
|
class TestEncDecRegressionModel:
    """Unit tests covering construction and transcription of the regression model."""

    @pytest.mark.unit
    def test_constructor(self, speech_regression_model):
        """Check the parameter count and config round-trip of a fresh model."""
        model = speech_regression_model.train()

        # Hand-computed parameter budget for the fixture's tiny architecture:
        # depthwise + pointwise conv weights/biases, two batch-norm layers,
        # and the single-output linear regression head.
        conv_params = (64 * 32 * 1 + 32) + (64 * 1 * 1 + 32)
        bn_params = (4 * 32) * 2
        head_params = 32 * 1 + 1
        expected_total = conv_params + bn_params + head_params
        assert model.num_weights == expected_total

        # The model must survive a serialize/deserialize round-trip via its config.
        round_trip_cfg = model.to_config_dict()
        rebuilt = EncDecRegressionModel.from_config_dict(round_trip_cfg)
        assert isinstance(rebuilt, EncDecRegressionModel)

    @pytest.mark.unit
    def test_transcription(self, speech_regression_model, test_data_dir):
        """Transcribe two AN4 wav files and expect one result per input."""
        wav_dir = os.path.join(test_data_dir, "asr", "train", "an4", "wav")
        audio_paths = [
            os.path.join(wav_dir, name)
            for name in ['an22-flrp-b.wav', 'an90-fbbh-b.wav']
        ]

        model = speech_regression_model.eval()

        predictions = model.transcribe(audio_paths, batch_size=2)
        assert len(predictions) == 2
|