# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
from functools import cached_property, lru_cache
from pathlib import Path
import jiwer
import pytest
import torch
from omegaconf import DictConfig, open_dict
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.mixins import mixins
from nemo.collections.asr.parts.submodules.ctc_decoding import (
CTCBPEDecoding,
CTCBPEDecodingConfig,
CTCDecoding,
CTCDecodingConfig,
)
from nemo.collections.asr.parts.submodules.ngram_lm.ngram_lm_batched import NGramGPULanguageModel
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.core.utils.cuda_python_utils import skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported
from tests.collections.asr.decoding.test_timestamps import BaseTimestampsTest
@pytest.fixture(scope="module")
def audio_file(test_data_dir):
    """Return the path of the ``an4`` sample wav file located under ``test_data_dir``."""
    relative_wav = "asr/test/an4/wav/cen3-mjwl-b.wav"
    return os.path.join(test_data_dir, relative_wav)
# Name of the pretrained checkpoint that the `ctc_model` fixture passes to `ASRModel.from_pretrained`.
CTC_MODEL = "nvidia/stt_en_conformer_ctc_small"
@pytest.fixture(scope="module")
def kenlm_model_path(tmp_path_factory, test_data_dir):
    """Convert the test ARPA n-gram LM into a ``.nemo`` file and return its path as a string.

    The ARPA file is loaded with ``NGramGPULanguageModel.from_file`` (``vocab_size=1024``)
    and saved into a temporary directory obtained from ``tmp_path_factory``.
    """
    lm_path = Path(test_data_dir) / "asr/kenlm_ngram_lm/parakeet-tdt_ctc-110m-libri-1024.kenlm.tmp.arpa"
    assert lm_path.exists(), f"LM file not found: {lm_path}"

    output_dir = tmp_path_factory.mktemp("lm")
    saved_path = str(output_dir / f"{lm_path.name}.nemo")

    language_model = NGramGPULanguageModel.from_file(lm_path, vocab_size=1024)
    language_model.save_to(saved_path)
    return saved_path
@pytest.fixture(scope="module")
def ctc_model():
    """Load the pretrained ``CTC_MODEL`` checkpoint onto the CPU and put it in eval mode."""
    pretrained = ASRModel.from_pretrained(model_name=CTC_MODEL, map_location="cpu")
    pretrained.eval()
    return pretrained
def char_vocabulary():
    """Return the toy character vocabulary (space, ``a`` through ``f``, and ``.``) as a new list."""
    # Splitting the literal yields one single-character entry per symbol, in order.
    return list(" abcdef.")
@pytest.fixture()
@lru_cache(maxsize=8)
def tmp_tokenizer(test_data_dir):
    """Build the ``an4_wpe_128`` WPE tokenizer through a minimal ``ASRBPEMixin`` subclass.

    The ``lru_cache`` wrapper memoizes the result per ``test_data_dir`` value.
    """

    class _TmpASRBPE(mixins.ASRBPEMixin):
        # Stub: hand the vocab path straight back instead of registering anything.
        def register_artifact(self, _, vocab_path):
            return vocab_path

    tokenizer_dir = os.path.join(test_data_dir, "asr", "tokenizers", "an4_wpe_128")
    tokenizer_cfg = DictConfig({'dir': tokenizer_dir, 'type': 'wpe'})

    holder = _TmpASRBPE()
    holder._setup_tokenizer(tokenizer_cfg)
    return holder.tokenizer
class TestCTCDecoding:
    """Unit tests for character-level and subword-level greedy CTC decoding."""

    # Device grid shared by the batched-decoding tests: the CPU case always runs,
    # the CUDA case is skipped when no GPU is available.
    _DEVICE_PARAMS = [
        torch.device("cpu"),
        pytest.param(
            torch.device("cuda"),
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(),
                reason='CUDA required for test.',
            ),
        ),
    ]

    @staticmethod
    def _assert_frame_alignments(hyp, expected_frames):
        """Check that ``hyp.alignments`` is a tuple whose first two items have ``expected_frames`` entries each."""
        assert hyp.alignments is not None
        assert isinstance(hyp.alignments, tuple)
        assert len(hyp.alignments[0]) == expected_frames
        assert len(hyp.alignments[1]) == expected_frames

    @pytest.mark.unit
    def test_constructor(self):
        """A character decoder can be constructed from the default config."""
        decoder = CTCDecoding(decoding_cfg=CTCDecodingConfig(), vocabulary=char_vocabulary())
        assert decoder is not None

    @pytest.mark.unit
    def test_constructor_subword(self, tmp_tokenizer):
        """A subword decoder can be constructed from the default config."""
        decoder = CTCBPEDecoding(decoding_cfg=CTCBPEDecodingConfig(), tokenizer=tmp_tokenizer)
        assert decoder is not None

    @pytest.mark.unit
    def test_char_decoding_greedy_forward(self):
        """Greedy character decoding of random inputs yields string transcripts."""
        vocab = char_vocabulary()
        decoder = CTCDecoding(decoding_cfg=CTCDecodingConfig(strategy='greedy'), vocabulary=vocab)

        batch_size, num_frames = 4, 20
        num_classes = len(vocab) + 1  # one extra class, presumably the CTC blank
        log_probs = torch.randn(size=(batch_size, num_frames, num_classes))
        lengths = torch.randint(low=1, high=num_frames, size=[batch_size])

        with torch.no_grad():
            hypotheses = decoder.ctc_decoder_predictions_tensor(
                log_probs, lengths, fold_consecutive=True, return_hypotheses=False
            )
            assert all(isinstance(hyp.text, str) for hyp in hypotheses)

    @pytest.mark.unit
    @pytest.mark.parametrize('alignments', [False, True])
    @pytest.mark.parametrize('timestamps', [False, True])
    def test_char_decoding_greedy_forward_hypotheses(self, alignments, timestamps):
        """Greedy character decoding returns well-formed ``Hypothesis`` objects."""
        vocab = char_vocabulary()
        decoder = CTCDecoding(
            decoding_cfg=CTCDecodingConfig(
                strategy='greedy', preserve_alignments=alignments, compute_timestamps=timestamps
            ),
            vocabulary=vocab,
        )

        batch_size, num_frames = 4, 20
        num_classes = len(vocab) + 1  # one extra class, presumably the CTC blank
        log_probs = torch.randn(size=(batch_size, num_frames, num_classes))
        lengths = torch.randint(low=1, high=num_frames, size=[batch_size])

        with torch.no_grad():
            hypotheses = decoder.ctc_decoder_predictions_tensor(
                log_probs, lengths, fold_consecutive=True, return_hypotheses=True
            )

            for hyp, frames in zip(hypotheses, lengths):
                assert isinstance(hyp, Hypothesis)
                assert torch.is_tensor(hyp.y_sequence)
                assert isinstance(hyp.text, str)

                if alignments:
                    self._assert_frame_alignments(hyp, frames)

                if timestamps:
                    BaseTimestampsTest.check_char_timestamps(hyp, decoder)

    @pytest.mark.unit
    def test_subword_decoding_greedy_forward(self, tmp_tokenizer):
        """Greedy subword decoding of random inputs yields string transcripts."""
        decoder = CTCBPEDecoding(decoding_cfg=CTCBPEDecodingConfig(strategy='greedy'), tokenizer=tmp_tokenizer)

        batch_size, num_frames = 4, 20
        num_classes = decoder.tokenizer.tokenizer.vocab_size + 1
        log_probs = torch.randn(size=(batch_size, num_frames, num_classes))
        lengths = torch.randint(low=1, high=num_frames, size=[batch_size])

        with torch.no_grad():
            hypotheses = decoder.ctc_decoder_predictions_tensor(
                log_probs, lengths, fold_consecutive=True, return_hypotheses=False
            )
            assert all(isinstance(hyp.text, str) for hyp in hypotheses)

    @pytest.mark.unit
    @pytest.mark.parametrize('alignments', [False, True])
    @pytest.mark.parametrize('timestamps', [False, True])
    @pytest.mark.pleasefixme
    def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignments, timestamps):
        """Greedy subword decoding returns well-formed ``Hypothesis`` objects."""
        decoder = CTCBPEDecoding(
            decoding_cfg=CTCBPEDecodingConfig(
                strategy='greedy', preserve_alignments=alignments, compute_timestamps=timestamps
            ),
            tokenizer=tmp_tokenizer,
        )

        batch_size, num_frames = 4, 20
        num_classes = decoder.tokenizer.tokenizer.vocab_size + 1
        log_probs = torch.randn(size=(batch_size, num_frames, num_classes))
        lengths = torch.randint(low=1, high=num_frames, size=[batch_size])

        with torch.no_grad():
            hypotheses = decoder.ctc_decoder_predictions_tensor(
                log_probs, lengths, fold_consecutive=True, return_hypotheses=True
            )

            for hyp, frames in zip(hypotheses, lengths):
                assert isinstance(hyp, Hypothesis)
                assert torch.is_tensor(hyp.y_sequence)
                assert isinstance(hyp.text, str)

                if alignments:
                    self._assert_frame_alignments(hyp, frames)

                if timestamps:
                    BaseTimestampsTest.check_subword_timestamps(hyp, decoder)

    @pytest.mark.unit
    @pytest.mark.parametrize('alignments', [False, True])
    @pytest.mark.parametrize('timestamps', [False, True])
    @pytest.mark.parametrize('preserve_frame_confidence', [False, True])
    @pytest.mark.parametrize('length_is_none', [False, True])
    @pytest.mark.parametrize("logprobs_device", _DEVICE_PARAMS)
    @pytest.mark.parametrize("length_device", _DEVICE_PARAMS)
    def test_batched_decoding_logprobs(
        self,
        tmp_tokenizer,
        alignments,
        timestamps,
        preserve_frame_confidence,
        length_is_none,
        logprobs_device,
        length_device,
    ):
        """``greedy`` and ``greedy_batch`` strategies agree when fed the same log-probability tensor."""
        cfg = CTCBPEDecodingConfig(
            strategy='greedy',
            preserve_alignments=alignments,
            compute_timestamps=timestamps,
            confidence_cfg=ConfidenceConfig(preserve_frame_confidence=preserve_frame_confidence),
        )
        unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
        # The same config object is reused with only the strategy switched.
        cfg.strategy = 'greedy_batch'
        batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)

        torch.manual_seed(1)
        batch_size, num_frames = 4, 20
        blank_id = unbatched_decoding.tokenizer.tokenizer.vocab_size
        log_probs = torch.randn(size=(batch_size, num_frames, blank_id + 1), device=logprobs_device)
        # Give the blank index a very high score on the first two frames so that
        # at least a few blanks are always handled.
        log_probs[:, :2, blank_id] = 1000

        lengths = (
            None if length_is_none else torch.randint(low=1, high=num_frames, size=[batch_size], device=length_device)
        )

        with torch.inference_mode():
            hyps = unbatched_decoding.ctc_decoder_predictions_tensor(
                log_probs, lengths, fold_consecutive=True, return_hypotheses=True
            )
            batched_hyps = batched_decoding.ctc_decoder_predictions_tensor(
                log_probs, lengths, fold_consecutive=True, return_hypotheses=True
            )

        assert len(hyps) == len(batched_hyps) == batch_size
        for reference, candidate in zip(hyps, batched_hyps):
            assert torch.abs(reference.score - candidate.score) <= 1e-5
            assert torch.all(reference.y_sequence == candidate.y_sequence)
            if timestamps:
                assert reference.timestamp == candidate.timestamp
            if alignments:
                assert torch.all(reference.alignments[0] == candidate.alignments[0])
                assert torch.all(reference.alignments[1] == candidate.alignments[1])

    @pytest.mark.unit
    @pytest.mark.parametrize('timestamps', [False, True])
    @pytest.mark.parametrize('length_is_none', [False, True])
    @pytest.mark.parametrize("labels_device", _DEVICE_PARAMS)
    @pytest.mark.parametrize("length_device", _DEVICE_PARAMS)
    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none, labels_device, length_device):
        """``greedy`` and ``greedy_batch`` strategies agree when fed the same integer label tensor."""
        cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
        unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
        # The same config object is reused with only the strategy switched.
        cfg.strategy = 'greedy_batch'
        batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)

        torch.manual_seed(1)
        batch_size, num_frames = 4, 20
        blank_id = unbatched_decoding.tokenizer.tokenizer.vocab_size
        labels = torch.randint(blank_id + 1, size=(batch_size, num_frames), device=labels_device)
        # Force the first two frames to blank so that at least a few blanks are always handled.
        labels[:, :2] = blank_id

        lengths = (
            None if length_is_none else torch.randint(low=1, high=num_frames, size=[batch_size], device=length_device)
        )

        with torch.inference_mode():
            hyps = unbatched_decoding.ctc_decoder_predictions_tensor(
                labels, lengths, fold_consecutive=True, return_hypotheses=True
            )
            batched_hyps = batched_decoding.ctc_decoder_predictions_tensor(
                labels, lengths, fold_consecutive=True, return_hypotheses=True
            )

        assert len(hyps) == len(batched_hyps) == batch_size
        for reference, candidate in zip(hyps, batched_hyps):
            assert abs(reference.score - candidate.score) <= 1e-5
            assert torch.all(reference.y_sequence == candidate.y_sequence)
            if timestamps:
                assert reference.timestamp == candidate.timestamp
class TestCTCTimestamps(BaseTimestampsTest):
    """Timestamp tests for CTC decoding, built on top of ``BaseTimestampsTest``."""

    @cached_property
    def decoding_char(self):
        """Character-level CTC decoding created from the default config."""
        return CTCDecoding(decoding_cfg=CTCDecodingConfig(), vocabulary=char_vocabulary())

    @cached_property
    def decoding_subword_wpe(self):
        """Subword CTC decoding with timestamps enabled, using ``self.tmp_tokenizer``."""
        timestamp_cfg = CTCBPEDecodingConfig(compute_timestamps=True)
        return CTCBPEDecoding(decoding_cfg=timestamp_cfg, tokenizer=self.tmp_tokenizer)

    @cached_property
    def decoding_subword_bpe(self):
        """Subword CTC decoding with timestamps enabled, using ``self.bpe_tokenizer``.

        ``bpe_tokenizer`` is not defined in this class; presumably the base class provides it.
        """
        timestamp_cfg = CTCBPEDecodingConfig(compute_timestamps=True)
        return CTCBPEDecoding(decoding_cfg=timestamp_cfg, tokenizer=self.bpe_tokenizer)

    @pytest.mark.unit
    def test_word_offsets_subword_wpe(self, tmp_tokenizer):
        """Expose the fixture tokenizer on the instance, then run the inherited test."""
        # `decoding_subword_wpe` reads the tokenizer from `self.tmp_tokenizer`.
        self.tmp_tokenizer = tmp_tokenizer
        super().test_word_offsets_subword_wpe()

    @pytest.mark.unit
    def test_word_offsets_subword_wpe_other_delimiter(self, tmp_tokenizer):
        """Expose the fixture tokenizer on the instance, then run the inherited test."""
        # `decoding_subword_wpe` reads the tokenizer from `self.tmp_tokenizer`.
        self.tmp_tokenizer = tmp_tokenizer
        super().test_word_offsets_subword_wpe_other_delimiter()
class TestCTCGreedyDecodingWithNGPU_LM:
    """Tests greedy CTC decoding combined with a GPU n-gram language model."""

    @pytest.mark.with_downloads
    @pytest.mark.unit
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test is only GPU-based decoding")
    def test_ctc_decoding_gpulm(
        self,
        audio_file,
        kenlm_model_path,
        ctc_model,
    ):
        """Check that the n-gram LM weight influences greedy decoding as expected.

        With ``ngram_lm_alpha = 0.0`` the transcript and score must match decoding
        without the LM; with ``ngram_lm_alpha = 10.0`` both must differ.

        Args:
            audio_file: path of the wav file to transcribe.
            kenlm_model_path: path of the n-gram LM file set as ``ngram_lm_model``.
            ctc_model: module-scoped pretrained model; its decoding strategy is
                always restored before the test returns.
        """
        device = torch.device("cuda")
        model = ctc_model.to(device)

        gt_hyp = model.transcribe([audio_file], num_workers=None)
        decoding_config = copy.deepcopy(model.cfg.decoding)

        def transcribe_with_lm_weight(alpha):
            # Reconfigure greedy decoding to use the LM with the given weight, then transcribe.
            with open_dict(model.decoding.cfg) as cfg:
                cfg.greedy["ngram_lm_model"] = kenlm_model_path
                cfg.greedy["ngram_lm_alpha"] = alpha
                model.change_decoding_strategy(cfg)
            return model.transcribe([audio_file], num_workers=None)

        try:
            # A zero LM weight must leave the result unchanged.
            lm_hyp = transcribe_with_lm_weight(0.0)
            assert gt_hyp[0].text == lm_hyp[0].text
            assert abs(gt_hyp[0].score - lm_hyp[0].score) <= 1e-3

            # A large LM weight must change both the transcript and the score.
            lm_hyp = transcribe_with_lm_weight(10.0)
            assert gt_hyp[0].text != lm_hyp[0].text
            assert abs(gt_hyp[0].score - lm_hyp[0].score) > 1e-3
        finally:
            # `ctc_model` is shared across the module: restore the original strategy
            # even when an assertion fails, so later tests are not affected.
            model.change_decoding_strategy(decoding_config)
class TestCTCGreedyDecodingCudaGrpahs:
    """
    Tests the CUDA graphs implementations of greedy decoding for CTC models.
    """

    @pytest.mark.with_downloads
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA decoder can run only on CUDA")
    @pytest.mark.parametrize("force_mode", ["no_graphs", "no_while_loops", "full_graph"])
    def test_stated_stateless(self, audio_file, kenlm_model_path, ctc_model, force_mode: str):
        """
        Compare greedy decoding with ``allow_cuda_graphs=False`` against decoding with
        ``allow_cuda_graphs=True`` forced into ``force_mode``:

        1. ``no_graphs``
        2. ``no_while_loops``
        3. ``full_graph`` (skipped when CUDA graph conditional nodes are not supported)

        Transcript lengths, scores (absolute tolerance 1e-2), timestamps and WER are
        compared per hypothesis. The decoding strategy of the module-scoped
        ``ctc_model`` is always restored before the test returns.
        """
        if force_mode == "full_graph":
            skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported()

        device = torch.device("cuda")
        model = ctc_model.to(device)
        decoding_config = copy.deepcopy(model.cfg.decoding)

        try:
            # Reference run: LM-fused greedy decoding with CUDA graphs disabled.
            with open_dict(model.decoding.cfg) as cfg:
                cfg.greedy["ngram_lm_model"] = kenlm_model_path
                cfg.greedy["ngram_lm_alpha"] = 0.2
                cfg.greedy["allow_cuda_graphs"] = False
                model.change_decoding_strategy(cfg)
            actual_hypotheses = model.transcribe([audio_file], num_workers=None)

            # Same configuration, now with CUDA graphs enabled and forced into `force_mode`.
            model.decoding.cfg["greedy"]["allow_cuda_graphs"] = True
            model.change_decoding_strategy(model.decoding.cfg)
            model.decoding.decoding.force_cuda_graphs_mode(mode=force_mode)
            cudagraph_hypotheses = model.transcribe([audio_file], num_workers=None)

            assert len(actual_hypotheses) == len(
                cudagraph_hypotheses
            ), "Both decoders must return the same number of hypotheses"

            for batch_idx, (actual_hyp, cudagraph_hyp) in enumerate(zip(actual_hypotheses, cudagraph_hypotheses)):
                actual_text = actual_hyp.text
                cudagraph_text = cudagraph_hyp.text

                # Print before asserting so the transcripts appear in the captured
                # output of a failing run.
                if actual_text != cudagraph_text:
                    print("Erroneous samples in batch:", batch_idx)
                    print("Original transcript:", actual_text)
                    print("New transcript:", cudagraph_text)

                assert len(actual_text) == len(cudagraph_text)
                assert cudagraph_hyp.score == pytest.approx(
                    actual_hyp.score, abs=1e-2
                ), f"Scores mismatch for batch_idx {batch_idx}"
                assert (
                    cudagraph_hyp.timestamp == actual_hyp.timestamp
                ), f"Timestamps mismatch for batch_idx {batch_idx}"

                wer = jiwer.wer(actual_text, cudagraph_text)
                assert wer <= 1e-3, "Cuda graph greedy decoder should match original decoder implementation."
        finally:
            # `ctc_model` is shared across the module: restore the original strategy
            # even when an assertion fails, so later tests are not affected.
            model.change_decoding_strategy(decoding_config)
|