File size: 35,678 Bytes
7934b29 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import torch
from packaging import version
from nemo.collections.asr.parts.numba.spec_augment import SpecAugmentNumba, spec_augment_launch_heuristics
from nemo.collections.asr.parts.preprocessing.features import (
FilterbankFeatures,
FilterbankFeaturesTA,
make_seq_mask_like,
)
from nemo.collections.asr.parts.submodules.spectr_augment import SpecAugment, SpecCutout
from nemo.core.classes import Exportable, NeuralModule, typecheck
from nemo.core.neural_types import (
AudioSignal,
LengthsType,
MelSpectrogramType,
MFCCSpectrogramType,
NeuralType,
SpectrogramType,
)
from nemo.core.utils import numba_utils
from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__
from nemo.utils import logging
try:
import torchaudio
import torchaudio.functional
import torchaudio.transforms
TORCHAUDIO_VERSION = version.parse(torchaudio.__version__)
TORCHAUDIO_VERSION_MIN = version.parse('0.5')
HAVE_TORCHAUDIO = True
except ModuleNotFoundError:
HAVE_TORCHAUDIO = False
# Names exported by ``from <this module> import *``.
# Note: the abstract base class ``AudioPreprocessor`` is not listed here.
__all__ = [
    'AudioToMelSpectrogramPreprocessor',
    'AudioToSpectrogram',
    'SpectrogramToAudio',
    'AudioToMFCCPreprocessor',
    'SpectrogramAugmentation',
    'MaskedPatchAugmentation',
    'CropOrPadSpectrogramAugmentation',
]
class AudioPreprocessor(NeuralModule, ABC):
    """
    Abstract base for neural modules that turn raw audio signals into features.

    Concrete subclasses implement :meth:`get_features`; :meth:`forward` calls it
    under ``typecheck()`` with gradient tracking disabled.

    Args:
        win_length: analysis window length, stored as ``self.win_length``.
        hop_length: analysis hop length, stored as ``self.hop_length``.
    """

    def __init__(self, win_length, hop_length):
        super().__init__()
        self.win_length = win_length
        self.hop_length = hop_length

        # Supported window names -> torch window factories.
        # Both the string 'ones' and None select torch.ones (no tapering).
        window_factories = (
            ('hann', torch.hann_window),
            ('hamming', torch.hamming_window),
            ('blackman', torch.blackman_window),
            ('bartlett', torch.bartlett_window),
            ('ones', torch.ones),
            (None, torch.ones),
        )
        self.torch_windows = dict(window_factories)

    @typecheck()
    @torch.no_grad()
    def forward(self, input_signal, length):
        """Run :meth:`get_features` and return ``(processed_signal, processed_length)``."""
        signal_features, feature_lengths = self.get_features(input_signal, length)
        return signal_features, feature_lengths

    @abstractmethod
    def get_features(self, input_signal, length):
        """Compute features for ``input_signal`` given the per-example ``length``.

        Called by :meth:`forward`; every concrete subclass must override it.
        """
class AudioToMelSpectrogramPreprocessor(AudioPreprocessor, Exportable):
    """Featurizer module that converts wavs to mel spectrograms.

    Args:
        sample_rate (int): Sample rate of the input audio data.
            Defaults to 16000
        window_size (float): Size of window for fft in seconds
            Defaults to 0.02
        window_stride (float): Stride of window for fft in seconds
            Defaults to 0.01
        n_window_size (int): Size of window for fft in samples
            Defaults to None. Use one of window_size or n_window_size.
        n_window_stride (int): Stride of window for fft in samples
            Defaults to None. Use one of window_stride or n_window_stride.
        window (str): Windowing function for fft. can be one of ['hann',
            'hamming', 'blackman', 'bartlett']
            Defaults to "hann"
        normalize (str): Can be one of ['per_feature', 'all_features']; all
            other options disable feature normalization. 'all_features'
            normalizes the entire spectrogram to be mean 0 with std 1.
            'per_feature' normalizes per channel / freq instead.
            Defaults to "per_feature"
        n_fft (int): Length of FT window. If None, it uses the smallest power
            of 2 that is larger than n_window_size.
            Defaults to None
        preemph (float): Amount of pre emphasis to add to audio. Can be
            disabled by passing None.
            Defaults to 0.97
        features (int): Number of mel spectrogram freq bins to output.
            Defaults to 64
        lowfreq (int): Lower bound on mel basis in Hz.
            Defaults to 0
        highfreq (int): Upper bound on mel basis in Hz.
            Defaults to None
        log (bool): Log features.
            Defaults to True
        log_zero_guard_type(str): Need to avoid taking the log of zero. There
            are two options: "add" or "clamp".
            Defaults to "add".
        log_zero_guard_value(float, or str): Add or clamp requires the number
            to add with or clamp to. log_zero_guard_value can either be a float
            or "tiny" or "eps". torch.finfo is used if "tiny" or "eps" is
            passed.
            Defaults to 2**-24.
        dither (float): Amount of white-noise dithering.
            Defaults to 1e-5
        pad_to (int): Ensures that the output size of the time dimension is
            a multiple of pad_to.
            Defaults to 16
        frame_splicing (int): Defaults to 1
        exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length
            // hop_length. Defaults to False.
        pad_value (float): The value that shorter mels are padded with.
            Defaults to 0
        mag_power (float): The power that the linear spectrogram is raised to
            prior to multiplication with mel basis.
            Defaults to 2 for a power spec
        rng : Random number generator
        nb_augmentation_prob (float) : Probability with which narrowband augmentation would be applied to
            samples in the batch.
            Defaults to 0.0
        nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation.
            Defaults to 4000
        use_torchaudio: Whether to use the `torchaudio` implementation
            (FilterbankFeaturesTA instead of FilterbankFeatures).
        mel_norm: Normalization used for mel filterbank weights.
            Defaults to 'slaney' (area normalization)
        stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints.
        stft_conv: Deprecated argument, kept for compatibility with older checkpoints.

    Raises:
        ValueError: if both window_size and n_window_size, or both
            window_stride and n_window_stride, are given.
    """

    def save_to(self, save_path: str):
        # Intentionally a no-op for this module.
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        # Intentionally a no-op for this module.
        pass

    @property
    def input_types(self):
        """Returns definitions of module input ports.
        """
        return {
            "input_signal": NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)),
            "length": NeuralType(
                tuple('B'), LengthsType()
            ),  # Please note that length should be in samples not seconds.
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports.

        processed_signal:
            0: AxisType(BatchTag)
            1: AxisType(MelSpectrogramSignalTag)
            2: AxisType(ProcessedTimeTag)
        processed_length:
            0: AxisType(BatchTag)
        """
        return {
            "processed_signal": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            "processed_length": NeuralType(tuple('B'), LengthsType()),
        }

    def __init__(
        self,
        sample_rate=16000,
        window_size=0.02,
        window_stride=0.01,
        n_window_size=None,
        n_window_stride=None,
        window="hann",
        normalize="per_feature",
        n_fft=None,
        preemph=0.97,
        features=64,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2 ** -24,
        dither=1e-5,
        pad_to=16,
        frame_splicing=1,
        exact_pad=False,
        pad_value=0,
        mag_power=2.0,
        rng=None,
        nb_augmentation_prob=0.0,
        nb_max_freq=4000,
        use_torchaudio: bool = False,
        mel_norm="slaney",
        stft_exact_pad=False,  # Deprecated arguments; kept for config compatibility
        stft_conv=False,  # Deprecated arguments; kept for config compatibility
    ):
        # Module init must run first so that ``{self}`` in the error messages below can be rendered.
        super().__init__(n_window_size, n_window_stride)

        self._sample_rate = sample_rate
        if window_size and n_window_size:
            raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
        if window_stride and n_window_stride:
            raise ValueError(
                f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
            )
        # Convert seconds-based window parameters into sample counts.
        if window_size:
            n_window_size = int(window_size * self._sample_rate)
        if window_stride:
            n_window_stride = int(window_stride * self._sample_rate)

        # The base class was initialised before the sizes above were resolved, so with the
        # default seconds-based arguments it stored None; keep its attributes in sync.
        self.win_length = n_window_size
        self.hop_length = n_window_stride

        # Given the long and similar argument list, point to the class and instantiate it by reference
        featurizer_class = FilterbankFeaturesTA if use_torchaudio else FilterbankFeatures
        self.featurizer = featurizer_class(
            sample_rate=self._sample_rate,
            n_window_size=n_window_size,
            n_window_stride=n_window_stride,
            window=window,
            normalize=normalize,
            n_fft=n_fft,
            preemph=preemph,
            nfilt=features,
            lowfreq=lowfreq,
            highfreq=highfreq,
            log=log,
            log_zero_guard_type=log_zero_guard_type,
            log_zero_guard_value=log_zero_guard_value,
            dither=dither,
            pad_to=pad_to,
            frame_splicing=frame_splicing,
            exact_pad=exact_pad,
            pad_value=pad_value,
            mag_power=mag_power,
            rng=rng,
            nb_augmentation_prob=nb_augmentation_prob,
            nb_max_freq=nb_max_freq,
            mel_norm=mel_norm,
            stft_exact_pad=stft_exact_pad,  # Deprecated arguments; kept for config compatibility
            stft_conv=stft_conv,  # Deprecated arguments; kept for config compatibility
        )

    def input_example(self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200):
        """Generate a random example batch ``(signals, lengths)``.

        ``signals`` has shape ``(batch_size, max_length)`` with values uniform in [-1, 1);
        every entry of ``lengths`` lies in ``[min_length, max_length]`` and ``lengths[0]``
        equals ``max_length``.

        Args:
            max_batch: exclusive upper bound on the random batch size.
            max_dim: exclusive upper bound on the random signal length.
            min_length: lower bound on the signal length and on every sampled length.
        """
        batch_size = torch.randint(low=1, high=max_batch, size=[1]).item()
        max_length = torch.randint(low=min_length, high=max_dim, size=[1]).item()
        signals = torch.rand(size=[batch_size, max_length]) * 2 - 1
        # Lengths must not exceed the number of samples actually present in ``signals``.
        lengths = torch.randint(low=min_length, high=max_length + 1, size=[batch_size])
        lengths[0] = max_length
        return signals, lengths

    def get_features(self, input_signal, length):
        """Delegate feature extraction to the configured filterbank featurizer."""
        return self.featurizer(input_signal, length)

    @property
    def filter_banks(self):
        """Filter banks exposed by the underlying featurizer."""
        return self.featurizer.filter_banks
class AudioToMFCCPreprocessor(AudioPreprocessor):
    """Preprocessor that converts wavs to MFCCs.
    Uses torchaudio.transforms.MFCC.

    Args:
        sample_rate: The sample rate of the audio.
            Defaults to 16000.
        window_size: Size of window for fft in seconds. Used to calculate the
            win_length arg for mel spectrogram.
            Defaults to 0.02
        window_stride: Stride of window for fft in seconds. Used to calculate
            the hop_length arg for mel spect.
            Defaults to 0.01
        n_window_size: Size of window for fft in samples
            Defaults to None. Use one of window_size or n_window_size.
        n_window_stride: Stride of window for fft in samples
            Defaults to None. Use one of window_stride or n_window_stride.
        window: Windowing function for fft. can be one of ['hann',
            'hamming', 'blackman', 'bartlett', 'ones'] or None; 'ones' and
            None both use torch.ones (no tapering).
            Defaults to 'hann'
        n_fft: Length of FT window. If None, it uses the smallest power of 2
            that is larger than n_window_size.
            Defaults to None
        lowfreq (int): Lower bound on mel basis in Hz.
            Defaults to 0
        highfreq (int): Upper bound on mel basis in Hz.
            Defaults to None
        n_mels: Number of mel filterbanks.
            Defaults to 64
        n_mfcc: Number of coefficients to retain
            Defaults to 64
        dct_type: Type of discrete cosine transform to use
        norm: Type of norm to use
        log: Whether to use log-mel spectrograms instead of db-scaled.
            Defaults to True.

    Raises:
        ModuleNotFoundError: if torchaudio could not be imported.
        ValueError: if both window_size and n_window_size (or both window_stride
            and n_window_stride) are given, or if ``window`` is not supported.
    """

    @property
    def input_types(self):
        """Returns definitions of module input ports.
        """
        return {
            "input_signal": NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)),
            "length": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports.
        """
        return {
            "processed_signal": NeuralType(('B', 'D', 'T'), MFCCSpectrogramType()),
            "processed_length": NeuralType(tuple('B'), LengthsType()),
        }

    def save_to(self, save_path: str):
        # Intentionally a no-op for this module.
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        # Intentionally a no-op for this module.
        pass

    def __init__(
        self,
        sample_rate=16000,
        window_size=0.02,
        window_stride=0.01,
        n_window_size=None,
        n_window_stride=None,
        window='hann',
        n_fft=None,
        lowfreq=0.0,
        highfreq=None,
        n_mels=64,
        n_mfcc=64,
        dct_type=2,
        norm='ortho',
        log=True,
    ):
        self._sample_rate = sample_rate
        if not HAVE_TORCHAUDIO:
            logging.error('Could not import torchaudio. Some features might not work.')

            raise ModuleNotFoundError(
                "torchaudio is not installed but is necessary for "
                "AudioToMFCCPreprocessor. We recommend you try "
                "building it from source for the PyTorch version you have."
            )

        # Module.__init__ has not run yet, so ``{self}`` (nn.Module.__repr__ reads
        # self._modules) would raise AttributeError; use the class name instead.
        cls_name = type(self).__name__
        if window_size and n_window_size:
            raise ValueError(
                f"{cls_name} received both window_size and " f"n_window_size. Only one should be specified."
            )
        if window_stride and n_window_stride:
            raise ValueError(
                f"{cls_name} received both window_stride and " f"n_window_stride. Only one should be specified."
            )

        # Get win_length (n_window_size) and hop_length (n_window_stride)
        if window_size:
            n_window_size = int(window_size * self._sample_rate)
        if window_stride:
            n_window_stride = int(window_stride * self._sample_rate)

        super().__init__(n_window_size, n_window_stride)

        mel_kwargs = {}

        mel_kwargs['f_min'] = lowfreq
        mel_kwargs['f_max'] = highfreq
        mel_kwargs['n_mels'] = n_mels

        mel_kwargs['n_fft'] = n_fft or 2 ** math.ceil(math.log2(n_window_size))

        mel_kwargs['win_length'] = n_window_size
        mel_kwargs['hop_length'] = n_window_stride

        # Set window_fn. None defaults to torch.ones; unknown names map to None here.
        window_fn = self.torch_windows.get(window, None)
        if window_fn is None:
            raise ValueError(
                f"Window argument for AudioProcessor is invalid: {window}. "
                f"For no window function, use 'ones' or None."
            )
        mel_kwargs['window_fn'] = window_fn

        # Use torchaudio's implementation of MFCCs as featurizer
        self.featurizer = torchaudio.transforms.MFCC(
            sample_rate=self._sample_rate,
            n_mfcc=n_mfcc,
            dct_type=dct_type,
            norm=norm,
            log_mels=log,
            melkwargs=mel_kwargs,
        )

    def get_features(self, input_signal, length):
        """Return ``(mfcc_features, seq_len)`` with ``seq_len = ceil(length / hop_length)``."""
        features = self.featurizer(input_signal)
        # NOTE(review): a centered STFT usually yields floor(T / hop) + 1 frames; confirm this
        # ceil-based length matches the featurizer's actual frame count.
        seq_len = torch.ceil(length.to(torch.float32) / self.hop_length).to(dtype=torch.long)
        return features, seq_len
class SpectrogramAugmentation(NeuralModule):
    """
    Performs time and freq cuts in one of two ways.
    SpecAugment zeroes out vertical and horizontal sections as described in
    SpecAugment (https://arxiv.org/abs/1904.08779). Arguments for use with
    SpecAugment are `freq_masks`, `time_masks`, `freq_width`, and `time_width`.
    SpecCutout zeroes out rectangles as described in Cutout
    (https://arxiv.org/abs/1708.04552). Arguments for use with Cutout are
    `rect_masks`, `rect_freq`, and `rect_time`.

    Args:
        freq_masks (int): how many frequency segments should be cut.
            Defaults to 0.
        time_masks (int): how many time segments should be cut
            Defaults to 0.
        freq_width (int): maximum number of frequencies to be cut in one
            segment.
            Defaults to 10.
        time_width (int): maximum number of time steps to be cut in one
            segment
            Defaults to 10.
        rect_masks (int): how many rectangular masks should be cut
            Defaults to 0.
        rect_time (int): maximum size of cut rectangles along the time
            dimension
            Defaults to 5.
        rect_freq (int): maximum size of cut rectangles along the frequency
            dimension
            Defaults to 20.
        rng: random number generator forwarded to SpecCutout / SpecAugment /
            SpecAugmentNumba. Defaults to None.
        mask_value (float): value forwarded to SpecAugment / SpecAugmentNumba
            for masked regions. Defaults to 0.0.
        use_numba_spec_augment (bool): if True and Numba CUDA is supported,
            a SpecAugmentNumba kernel is also built and preferred in forward()
            when its launch heuristics allow. Defaults to True.
    """

    @property
    def input_types(self):
        """Returns definitions of module input types
        """
        return {
            "input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "length": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output types
        """
        return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())}

    def __init__(
        self,
        freq_masks=0,
        time_masks=0,
        freq_width=10,
        time_width=10,
        rect_masks=0,
        rect_time=5,
        rect_freq=20,
        rng=None,
        mask_value=0.0,
        use_numba_spec_augment: bool = True,
    ):
        super().__init__()

        # Cutout stage: identity when no rectangles are requested.
        if rect_masks > 0:
            self.spec_cutout = SpecCutout(rect_masks=rect_masks, rect_time=rect_time, rect_freq=rect_freq, rng=rng,)
        else:
            self.spec_cutout = lambda input_spec: input_spec

        # SpecAugment stage: identity when no time/frequency masks are requested.
        if freq_masks + time_masks > 0:
            self.spec_augment = SpecAugment(
                freq_masks=freq_masks,
                time_masks=time_masks,
                freq_width=freq_width,
                time_width=time_width,
                rng=rng,
                mask_value=mask_value,
            )
        else:
            self.spec_augment = lambda input_spec, length: input_spec

        # Check if numba is supported, and use a Numba kernel if it is
        if use_numba_spec_augment and numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__):
            logging.info('Numba CUDA SpecAugment kernel is being used')
            self.spec_augment_numba = SpecAugmentNumba(
                freq_masks=freq_masks,
                time_masks=time_masks,
                freq_width=freq_width,
                time_width=time_width,
                rng=rng,
                mask_value=mask_value,
            )
        else:
            self.spec_augment_numba = None

    @typecheck()
    def forward(self, input_spec, length):
        """Apply cutout, then SpecAugment (Numba kernel when available and applicable)."""
        augmented_spec = self.spec_cutout(input_spec=input_spec)

        # To run the Numba kernel, correct numba version is required as well as
        # tensor must be on GPU and length must be provided
        if self.spec_augment_numba is not None and spec_augment_launch_heuristics(augmented_spec, length):
            augmented_spec = self.spec_augment_numba(input_spec=augmented_spec, length=length)
        else:
            augmented_spec = self.spec_augment(input_spec=augmented_spec, length=length)
        return augmented_spec
class MaskedPatchAugmentation(NeuralModule):
    """
    Zeroes out fixed size time patches of the spectrogram.
    All samples in batch are guaranteed to have the same amount of masked time steps.
    Optionally also performs frequency masking in the same way as SpecAugment.

    Args:
        patch_size (int): number of time steps in one patch.
            Defaults to 48.
        mask_patches (float): how many patches should be masked in each sample.
            if >= 1., interpreted as number of patches (after converting to int)
            if <1., interpreted as a fraction of the shortest sequence length in
            the batch to be masked (number of patches is rounded up)
            Defaults to 10.
        freq_masks (int): how many frequency segments should be cut.
            Defaults to 0.
        freq_width (int): maximum number of frequencies to be cut in a segment.
            Defaults to 0.

    Raises:
        ValueError: if ``mask_patches`` is negative.
    """

    @property
    def input_types(self):
        """Returns definitions of module input types
        """
        return {
            "input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "length": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output types
        """
        return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())}

    def __init__(
        self, patch_size: int = 48, mask_patches: float = 10.0, freq_masks: int = 0, freq_width: int = 0,
    ):
        super().__init__()
        self.patch_size = patch_size
        if mask_patches >= 1:
            self.mask_patches = int(mask_patches)
        elif mask_patches >= 0:
            self._mask_fraction = mask_patches
            self.mask_patches = None
        else:
            raise ValueError('mask_patches cannot be negative')

        if freq_masks > 0:
            self.spec_augment = SpecAugment(freq_masks=freq_masks, time_masks=0, freq_width=freq_width, time_width=0,)
        else:
            self.spec_augment = None

    @typecheck()
    def forward(self, input_spec, length):
        """Mask the same number of whole time patches in every sample.

        NOTE: time masking writes into ``input_spec`` in place.
        """
        augmented_spec = input_spec

        min_len = int(torch.min(length))

        if self.mask_patches is None:
            # masking specified as fraction; round the patch count up
            len_fraction = int(min_len * self._mask_fraction)
            mask_patches = len_fraction // self.patch_size + int(len_fraction % self.patch_size != 0)
        else:
            mask_patches = self.mask_patches

        # Never request more patches than the shortest sample can hold.
        if min_len < self.patch_size * mask_patches:
            mask_patches = min_len // self.patch_size

        for idx in range(input_spec.shape[0]):
            cur_len = int(length[idx])
            # Every full patch [mp * ps, (mp + 1) * ps) lying inside the valid length is a candidate.
            # cur_len >= min_len guarantees at least ``mask_patches`` candidates.
            patches = range(cur_len // self.patch_size)
            masked_patches = random.sample(patches, mask_patches)

            for mp in masked_patches:
                augmented_spec[idx, :, mp * self.patch_size : (mp + 1) * self.patch_size] = 0.0

        if self.spec_augment is not None:
            augmented_spec = self.spec_augment(input_spec=augmented_spec, length=length)

        return augmented_spec
class CropOrPadSpectrogramAugmentation(NeuralModule):
    """
    Pad or Crop the incoming Spectrogram to a certain shape.

    Args:
        audio_length (int): the final number of timesteps that is required.
            The signal will be either padded or cropped temporally to this
            size.
    """

    def __init__(self, audio_length):
        super().__init__()
        self.audio_length = audio_length

    @typecheck()
    @torch.no_grad()
    def forward(self, input_signal, length):
        """Randomly crop (per example) or symmetrically zero-pad the last axis to ``audio_length``.

        Returns the resized spectrogram and a length tensor filled with ``audio_length``.
        """
        image = input_signal
        num_images = image.shape[0]

        audio_length = self.audio_length
        image_len = image.shape[-1]

        if image_len > audio_length:
            # Crop long signal: an independent random start offset for every example.
            offsets = torch.randint(low=0, high=image_len - audio_length + 1, size=[num_images])
            image = torch.cat(
                [image[idx : idx + 1, :, start : start + audio_length] for idx, start in enumerate(offsets)],
                dim=0,
            )
        else:
            # Symmetrically pad short signal with zeros; an odd remainder goes to the right.
            pad_total = audio_length - image_len
            pad_left = pad_total // 2
            pad_right = pad_total - pad_left
            image = torch.nn.functional.pad(image, [pad_left, pad_right], mode="constant", value=0)

        # Replace dynamic length sequences with static number of timesteps
        length = torch.full_like(length, audio_length)

        return image, length

    @property
    def input_types(self):
        """Returns definitions of module input ports.
        """
        return {
            "input_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "length": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports.
        """
        return {
            "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "processed_length": NeuralType(tuple('B'), LengthsType()),
        }

    def save_to(self, save_path: str):
        # Intentionally a no-op for this module.
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        # Intentionally a no-op for this module.
        pass
class AudioToSpectrogram(NeuralModule):
    """Transform a batch of input multi-channel signals into a batch of
    STFT-based spectrograms.

    Args:
        fft_length: length of FFT
        hop_length: length of hops/shifts of the sliding window
        power: exponent for magnitude spectrogram. Default `None` will
               return a complex-valued spectrogram

    Raises:
        ModuleNotFoundError: if torchaudio could not be imported.
        ValueError: if ``fft_length`` is odd.
    """

    def __init__(self, fft_length: int, hop_length: int, power: Optional[float] = None):
        if not HAVE_TORCHAUDIO:
            logging.error('Could not import torchaudio. Some features might not work.')

            raise ModuleNotFoundError(
                f"torchaudio is not installed but is necessary to instantiate a {self.__class__.__name__}"
            )

        super().__init__()

        # For now, assume FFT length is divisible by two
        if fft_length % 2 != 0:
            raise ValueError(f'fft_length = {fft_length} must be divisible by 2')

        self.stft = torchaudio.transforms.Spectrogram(
            n_fft=fft_length, hop_length=hop_length, power=power, pad_mode='constant'
        )

        # number of subbands
        self.F = fft_length // 2 + 1

    @property
    def num_subbands(self) -> int:
        """Number of frequency subbands, ``fft_length // 2 + 1``."""
        return self.F

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert a batch of C-channel input signals
        into a batch of complex-valued spectrograms.

        Args:
            input: Time-domain input signal with C channels, shape (B, C, T)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Output spectrogram with F subbands and N time frames, shape (B, C, F, N)
            and output length with shape (B,).
        """
        B, T = input.size(0), input.size(-1)
        input = input.view(B, -1, T)

        # STFT output (B, C, F, N); computed in float32 with autocast disabled
        with torch.cuda.amp.autocast(enabled=False):
            output = self.stft(input.float())

        if input_length is not None:
            # Mask padded frames
            output_length = self.get_output_length(input_length=input_length)

            length_mask: torch.Tensor = make_seq_mask_like(
                lengths=output_length, like=output, time_dim=-1, valid_ones=False
            )
            output = output.masked_fill(length_mask, 0.0)
        else:
            # Assume all frames are valid for all examples in the batch
            output_length = output.size(-1) * torch.ones(B, device=output.device).long()

        return output, output_length

    def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
        """Get length of valid frames for the output.

        Computed as ``floor(input_length / hop_length) + 1``.

        Args:
            input_length: number of valid samples, shape (B,)

        Returns:
            Number of valid frames, shape (B,)
        """
        output_length = input_length.div(self.stft.hop_length, rounding_mode='floor').add(1).long()
        return output_length
class SpectrogramToAudio(NeuralModule):
    """Transform a batch of input multi-channel spectrograms into a batch of
    time-domain multi-channel signals.

    Args:
        fft_length: length of FFT
        hop_length: length of hops/shifts of the sliding window

    Raises:
        ModuleNotFoundError: if torchaudio could not be imported.
        ValueError: if ``fft_length`` is odd.
    """

    def __init__(self, fft_length: int, hop_length: int):
        if not HAVE_TORCHAUDIO:
            logging.error('Could not import torchaudio. Some features might not work.')

            raise ModuleNotFoundError(
                f"torchaudio is not installed but is necessary to instantiate a {self.__class__.__name__}"
            )

        super().__init__()

        # For now, assume FFT length is divisible by two
        if fft_length % 2 != 0:
            raise ValueError(f'fft_length = {fft_length} must be divisible by 2')

        self.istft = torchaudio.transforms.InverseSpectrogram(
            n_fft=fft_length, hop_length=hop_length, pad_mode='constant'
        )

        # number of subbands
        self.F = fft_length // 2 + 1

    @property
    def num_subbands(self) -> int:
        """Number of frequency subbands, ``fft_length // 2 + 1``."""
        return self.F

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert input complex-valued spectrogram to a time-domain
        signal. Multi-channel IO is supported.

        Args:
            input: Input spectrogram for C channels, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Time-domain signal with T time-domain samples and C channels, (B, C, T)
            and output length with shape (B,).
        """
        B, F, N = input.size(0), input.size(-2), input.size(-1)
        assert F == self.F, f'Number of subbands F={F} not matching self.F={self.F}'
        input = input.view(B, -1, F, N)

        # iSTFT output (B, C, T); computed in complex64 with autocast disabled
        with torch.cuda.amp.autocast(enabled=False):
            output = self.istft(input.cfloat())

        if input_length is not None:
            # Mask padded samples
            output_length = self.get_output_length(input_length=input_length)

            length_mask: torch.Tensor = make_seq_mask_like(
                lengths=output_length, like=output, time_dim=-1, valid_ones=False
            )
            output = output.masked_fill(length_mask, 0.0)
        else:
            # Assume all frames are valid for all examples in the batch
            output_length = output.size(-1) * torch.ones(B, device=output.device).long()

        return output, output_length

    def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
        """Get length of valid samples for the output.

        Computed as ``(input_length - 1) * hop_length``.

        Args:
            input_length: number of valid frames, shape (B,)

        Returns:
            Number of valid samples, shape (B,)
        """
        output_length = input_length.sub(1).mul(self.istft.hop_length).long()
        return output_length
@dataclass
class AudioToMelSpectrogramPreprocessorConfig:
    """Structured config mirroring the ``AudioToMelSpectrogramPreprocessor`` constructor arguments.

    ``_target_`` holds the dotted path of that class (presumably consumed by a
    Hydra-style ``instantiate`` call — confirm against callers).
    """

    _target_: str = "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor"
    sample_rate: int = 16000
    window_size: float = 0.02
    window_stride: float = 0.01
    n_window_size: Optional[int] = None
    n_window_stride: Optional[int] = None
    window: str = "hann"
    normalize: str = "per_feature"
    n_fft: Optional[int] = None
    preemph: float = 0.97
    features: int = 64
    lowfreq: int = 0
    highfreq: Optional[int] = None
    log: bool = True
    log_zero_guard_type: str = "add"
    # NOTE(review): the module also accepts the strings "tiny"/"eps" here; a float-typed
    # structured config may reject them — confirm before relying on string values.
    log_zero_guard_value: float = 2 ** -24
    dither: float = 1e-5
    pad_to: int = 16
    frame_splicing: int = 1
    exact_pad: bool = False
    pad_value: int = 0
    mag_power: float = 2.0
    rng: Optional[str] = None
    nb_augmentation_prob: float = 0.0
    nb_max_freq: int = 4000
    use_torchaudio: bool = False
    mel_norm: str = "slaney"
    stft_exact_pad: bool = False  # Deprecated argument, kept for compatibility with older checkpoints.
    stft_conv: bool = False  # Deprecated argument, kept for compatibility with older checkpoints.
@dataclass
class AudioToMFCCPreprocessorConfig:
    """Structured config mirroring the ``AudioToMFCCPreprocessor`` constructor arguments.

    ``_target_`` holds the dotted path of that class.
    """

    _target_: str = 'nemo.collections.asr.modules.AudioToMFCCPreprocessor'
    sample_rate: int = 16000
    window_size: float = 0.02
    window_stride: float = 0.01
    n_window_size: Optional[int] = None
    n_window_stride: Optional[int] = None
    window: str = 'hann'
    n_fft: Optional[int] = None
    lowfreq: Optional[float] = 0.0
    highfreq: Optional[float] = None
    n_mels: int = 64
    n_mfcc: int = 64
    dct_type: int = 2
    norm: str = 'ortho'
    log: bool = True
@dataclass
class SpectrogramAugmentationConfig:
    """Structured config for ``SpectrogramAugmentation``.

    Note: several defaults here (freq_width, time_width, rect_time, rect_freq = 0)
    differ from the ``SpectrogramAugmentation.__init__`` defaults (10, 10, 5, 20).
    With the default mask counts of 0 both stages are disabled, so the widths
    have no effect unless masks are enabled.
    """

    _target_: str = "nemo.collections.asr.modules.SpectrogramAugmentation"
    freq_masks: int = 0
    time_masks: int = 0
    freq_width: int = 0
    time_width: Optional[Any] = 0  # typed as Optional[Any] rather than int
    rect_masks: int = 0
    rect_time: int = 0
    rect_freq: int = 0
    mask_value: float = 0
    rng: Optional[Any] = None  # random.Random() type
    use_numba_spec_augment: bool = True
@dataclass
class CropOrPadSpectrogramAugmentationConfig:
    """Structured config for ``CropOrPadSpectrogramAugmentation``; ``audio_length`` is required."""

    audio_length: int
    _target_: str = "nemo.collections.asr.modules.CropOrPadSpectrogramAugmentation"
@dataclass
class MaskedPatchAugmentationConfig:
    """Structured config mirroring the ``MaskedPatchAugmentation`` constructor defaults."""

    patch_size: int = 48
    # >= 1: number of patches; < 1: fraction of the shortest length in the batch.
    mask_patches: float = 10.0
    freq_masks: int = 0
    freq_width: int = 0
    _target_: str = "nemo.collections.asr.modules.MaskedPatchAugmentation"
|