|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pytest |
|
|
from examples.asr.transcribe_speech import TranscriptionConfig |
|
|
from omegaconf import OmegaConf |
|
|
|
|
|
from nemo.collections.asr.parts.utils.transcribe_utils import prepare_audio_data, setup_model |
|
|
|
|
|
# Manifest path assigned to ``cfg.dataset_manifest`` in ``get_rnnt_alignments``.
TEST_DATA_PATH = "/home/TestData/an4_dataset/an4_val.json"
# Passed as ``pretrained_name`` when building the ``TranscriptionConfig``.
PRETRAINED_MODEL_NAME = "stt_en_conformer_transducer_small"
|
|
|
|
|
|
|
|
def get_rnnt_alignments(strategy: str):
    """Transcribe a few test files and sanity-check the returned alignments.

    Builds a ``TranscriptionConfig`` for ``PRETRAINED_MODEL_NAME`` with
    ``preserve_alignments`` and ``preserve_frame_confidence`` enabled, sets the
    requested decoding strategy, and transcribes at most the first 10 items of
    the first value returned by ``prepare_audio_data`` for ``TEST_DATA_PATH``.

    Args:
        strategy: value assigned to ``cfg.rnnt_decoding.strategy``; the test in
            this module passes ``"greedy"`` and ``"greedy_batch"``.

    Returns:
        The first element of the ``model.transcribe(...)`` result, requested
        with ``return_hypotheses=True``.

    Raises:
        AssertionError: if, for any returned hypothesis, ``alignments`` and
            ``frame_confidence`` differ in length (overall or per entry), an
            alignment entry is empty, any prediction but the last one in an
            entry equals ``model.decoder.blank_idx``, or the last one does not.
    """
    cfg = OmegaConf.structured(TranscriptionConfig(pretrained_name=PRETRAINED_MODEL_NAME))
    cfg.rnnt_decoding.confidence_cfg.preserve_frame_confidence = True
    cfg.rnnt_decoding.preserve_alignments = True
    cfg.rnnt_decoding.strategy = strategy
    cfg.dataset_manifest = TEST_DATA_PATH
    # Only a small slice of the data is transcribed (at most 10 items).
    filepaths = prepare_audio_data(cfg)[0][:10]

    # NOTE(review): map_location="cuda" presumably requires a CUDA device - confirm for CI.
    model = setup_model(cfg, map_location="cuda")[0]
    model.change_decoding_strategy(cfg.rnnt_decoding)

    transcriptions = model.transcribe(
        paths2audio_files=filepaths,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        return_hypotheses=True,
        channel_selector=cfg.channel_selector,
    )[0]

    blank_idx = model.decoder.blank_idx
    for transcription in transcriptions:
        alignments = transcription.alignments
        frame_confidences = transcription.frame_confidence
        # zip() silently stops at the shorter input, so check the lengths first;
        # otherwise extra or missing entries on either side would go unnoticed.
        assert len(alignments) == len(frame_confidences)
        for align_elem, frame_confidence in zip(alignments, frame_confidences):
            # every alignment entry needs a confidence value per prediction
            assert len(align_elem) == len(frame_confidence)
            # no empty alignment entries
            assert len(align_elem) > 0
            last_idx = len(align_elem) - 1
            for idx, pred in enumerate(align_elem):
                label = pred[1].item()
                if idx < last_idx:
                    # all predictions except the last one must differ from blank_idx
                    assert label != blank_idx
                else:
                    # the last prediction of an entry must equal blank_idx
                    assert label == blank_idx
    return transcriptions
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def cleanup_local_folder():
    """No-op override of the global fixture.

    Defined here so that the global fixture is not applied for this test;
    otherwise there will be errors in the CI on GitHub.
    """
    return None
|
|
|
|
|
|
|
|
|
|
|
def test_rnnt_alignments():
    """Check that ``greedy_batch`` decoding reproduces the ``greedy`` alignments.

    ``get_rnnt_alignments("greedy")`` serves as the reference. For every pair
    of hypotheses the alignments must have the same number of entries, each
    entry must have the same number of predictions, and the second field of
    every prediction (``pred[1].item()``) must be equal. The first field of
    each prediction is not compared here.
    """
    ref_transcriptions = get_rnnt_alignments("greedy")
    transcriptions = get_rnnt_alignments("greedy_batch")

    assert len(ref_transcriptions) == len(transcriptions)
    for ref_transcription, transcription in zip(ref_transcriptions, transcriptions):
        # zip() stops at the shorter input; without this check a hypothesis with
        # missing or extra alignment entries would still pass the comparison.
        assert len(ref_transcription.alignments) == len(transcription.alignments)
        for ref_align_elem, align_elem in zip(ref_transcription.alignments, transcription.alignments):
            assert len(ref_align_elem) == len(align_elem)
            for ref_pred, pred in zip(ref_align_elem, align_elem):
                assert ref_pred[1].item() == pred[1].item()
|
|
|