NeMo / tests /collections /asr /test_asr_classification_model.py
camenduru's picture
thanks to NVIDIA ❤
7934b29
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import pytest
import torch
from omegaconf import DictConfig, ListConfig
from nemo.collections.asr.data import audio_to_label
from nemo.collections.asr.models import EncDecClassificationModel, configs
from nemo.utils.config_utils import assert_dataclass_signature_match
@pytest.fixture()
def speech_classification_model():
preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})}
encoder = {
'cls': 'nemo.collections.asr.modules.ConvASREncoder',
'params': {
'feat_in': 64,
'activation': 'relu',
'conv_mask': True,
'jasper': [
{
'filters': 32,
'repeat': 1,
'kernel': [1],
'stride': [1],
'dilation': [1],
'dropout': 0.0,
'residual': False,
'separable': True,
'se': True,
'se_context_size': -1,
}
],
},
}
decoder = {
'cls': 'nemo.collections.asr.modules.ConvASRDecoderClassification',
'params': {'feat_in': 32, 'num_classes': 30,},
}
modelConfig = DictConfig(
{
'preprocessor': DictConfig(preprocessor),
'encoder': DictConfig(encoder),
'decoder': DictConfig(decoder),
'labels': ListConfig(["dummy_cls_{}".format(i + 1) for i in range(30)]),
}
)
model = EncDecClassificationModel(cfg=modelConfig)
return model
class TestEncDecClassificationModel:
@pytest.mark.unit
def test_constructor(self, speech_classification_model):
asr_model = speech_classification_model.train()
conv_cnt = (64 * 32 * 1 + 32) + (64 * 1 * 1 + 32) # separable kernel + bias + pointwise kernel + bias
bn_cnt = (4 * 32) * 2 # 2 * moving averages
dec_cnt = 32 * 30 + 30 # fc + bias
param_count = conv_cnt + bn_cnt + dec_cnt
assert asr_model.num_weights == param_count
# Check to/from config_dict:
confdict = asr_model.to_config_dict()
instance2 = EncDecClassificationModel.from_config_dict(confdict)
assert isinstance(instance2, EncDecClassificationModel)
@pytest.mark.unit
def test_forward(self, speech_classification_model):
asr_model = speech_classification_model.eval()
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0
input_signal = torch.randn(size=(4, 512))
length = torch.randint(low=161, high=500, size=[4])
with torch.no_grad():
# batch size 1
logprobs_instance = []
for i in range(input_signal.size(0)):
logprobs_ins = asr_model.forward(
input_signal=input_signal[i : i + 1], input_signal_length=length[i : i + 1]
)
logprobs_instance.append(logprobs_ins)
logprobs_instance = torch.cat(logprobs_instance, 0)
# batch size 4
logprobs_batch = asr_model.forward(input_signal=input_signal, input_signal_length=length)
assert logprobs_instance.shape == logprobs_batch.shape
diff = torch.mean(torch.abs(logprobs_instance - logprobs_batch))
assert diff <= 1e-6
diff = torch.max(torch.abs(logprobs_instance - logprobs_batch))
assert diff <= 1e-6
@pytest.mark.unit
def test_vocab_change(self, speech_classification_model):
asr_model = speech_classification_model.train()
old_labels = copy.deepcopy(asr_model._cfg.labels)
nw1 = asr_model.num_weights
asr_model.change_labels(new_labels=old_labels)
# No change
assert nw1 == asr_model.num_weights
new_labels = copy.deepcopy(old_labels)
new_labels.append('dummy_cls_31')
new_labels.append('dummy_cls_32')
new_labels.append('dummy_cls_33')
asr_model.change_labels(new_labels=new_labels)
# fully connected + bias
assert asr_model.num_weights == nw1 + 3 * (asr_model.decoder._feat_in + 1)
@pytest.mark.unit
def test_transcription(self, speech_classification_model, test_data_dir):
# Ground truth labels = ["yes", "no"]
audio_filenames = ['an22-flrp-b.wav', 'an90-fbbh-b.wav']
audio_paths = [os.path.join(test_data_dir, "asr", "train", "an4", "wav", fp) for fp in audio_filenames]
model = speech_classification_model.eval()
# Test Top 1 classification transcription
results = model.transcribe(audio_paths, batch_size=2)
assert len(results) == 2
assert results[0].shape == torch.Size([1])
# Test Top 5 classification transcription
model._accuracy.top_k = [5] # set top k to 5 for accuracy calculation
results = model.transcribe(audio_paths, batch_size=2)
assert len(results) == 2
assert results[0].shape == torch.Size([5])
# Test Top 1 and Top 5 classification transcription
model._accuracy.top_k = [1, 5]
results = model.transcribe(audio_paths, batch_size=2)
assert len(results) == 2
assert results[0].shape == torch.Size([2, 1])
assert results[1].shape == torch.Size([2, 5])
assert model._accuracy.top_k == [1, 5]
# Test log probs extraction
model._accuracy.top_k = [1]
results = model.transcribe(audio_paths, batch_size=2, logprobs=True)
assert len(results) == 2
assert results[0].shape == torch.Size([len(model.cfg.labels)])
# Test log probs extraction remains same for any top_k
model._accuracy.top_k = [5]
results = model.transcribe(audio_paths, batch_size=2, logprobs=True)
assert len(results) == 2
assert results[0].shape == torch.Size([len(model.cfg.labels)])
@pytest.mark.unit
def test_EncDecClassificationDatasetConfig_for_AudioToSpeechLabelDataset(self):
# ignore some additional arguments as dataclass is generic
IGNORE_ARGS = [
'is_tarred',
'num_workers',
'batch_size',
'tarred_audio_filepaths',
'shuffle',
'pin_memory',
'drop_last',
'tarred_shard_strategy',
'shuffle_n',
# `featurizer` is supplied at runtime
'featurizer',
# additional ignored arguments
'vad_stream',
'int_values',
'sample_rate',
'normalize_audio',
'augmentor',
'bucketing_batch_size',
'bucketing_strategy',
'bucketing_weights',
]
REMAP_ARGS = {'trim_silence': 'trim'}
result = assert_dataclass_signature_match(
audio_to_label.AudioToSpeechLabelDataset,
configs.EncDecClassificationDatasetConfig,
ignore_args=IGNORE_ARGS,
remap_args=REMAP_ARGS,
)
signatures_match, cls_subset, dataclass_subset = result
assert signatures_match
assert cls_subset is None
assert dataclass_subset is None