diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7431ce3526bc62795bfb1362ecdbd8363a96d236 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/_no_backend.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/_no_backend.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b939234afbf1015bfa817d010d34414beb48ce3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/_no_backend.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..677791945a748f4756972316356cbfdd10b40989 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/no_backend.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/no_backend.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a627b6b665580ce960897032389511237a038ad Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/no_backend.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/sox_io_backend.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/sox_io_backend.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..914b508c6ffaf05ed217f5b6a99855a84e4575df Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/sox_io_backend.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/_sox_io_backend.py b/.venv/lib/python3.11/site-packages/torchaudio/backend/_sox_io_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..6af267b17a48d330c699e72dd3e31bc336a7d3da --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/backend/_sox_io_backend.py @@ -0,0 +1,294 @@ +import os +from typing import Optional, Tuple + +import torch +import torchaudio +from torchaudio import AudioMetaData + +sox_ext = torchaudio._extension.lazy_import_sox_ext() + + +def info( + filepath: str, + format: Optional[str] = None, +) -> AudioMetaData: + """Get signal information of an audio file. + + Args: + filepath (str): + Source of audio data. + + format (str or None, optional): + Override the format detection with the given format. + Providing the argument might help when libsox can not infer the format + from header or extension. + + Returns: + AudioMetaData: Metadata of the given audio. + """ + if not torch.jit.is_scripting(): + if hasattr(filepath, "read"): + raise RuntimeError("sox_io backend does not support file-like object.") + filepath = os.fspath(filepath) + sinfo = sox_ext.get_info(filepath, format) + return AudioMetaData(*sinfo) + + +def load( + filepath: str, + frame_offset: int = 0, + num_frames: int = -1, + normalize: bool = True, + channels_first: bool = True, + format: Optional[str] = None, +) -> Tuple[torch.Tensor, int]: + """Load audio data from file. 
+ + Note: + This function can handle all the codecs that underlying libsox can handle, + however it is tested on the following formats; + + * WAV, AMB + + * 32-bit floating-point + * 32-bit signed integer + * 24-bit signed integer + * 16-bit signed integer + * 8-bit unsigned integer (WAV only) + + * MP3 + * FLAC + * OGG/VORBIS + * OPUS + * SPHERE + * AMR-NB + + To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not + handle natively, your installation of ``torchaudio`` has to be linked to ``libsox`` + and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc. + + By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with + ``float32`` dtype, and the shape of `[channel, time]`. + + .. warning:: + + ``normalize`` argument does not perform volume normalization. + It only converts the sample type to `torch.float32` from the native sample + type. + + When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit + signed integer, 24-bit signed integer, and 8-bit unsigned integer, by providing ``normalize=False``, + this function can return integer Tensor, where the samples are expressed within the whole range + of the corresponding dtype, that is, ``int32`` tensor for 32-bit signed PCM, + ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. Since torch does not + support ``int24`` dtype, 24-bit signed PCM are converted to ``int32`` tensors. + + ``normalize`` argument has no effect on 32-bit floating-point WAV and other formats, such as + ``flac`` and ``mp3``. + + For these formats, this function always returns ``float32`` Tensor with values. + + Args: + filepath (path-like object): Source of audio data. + frame_offset (int): + Number of frames to skip before start reading data. + num_frames (int, optional): + Maximum number of frames to read. ``-1`` reads all the remaining samples, + starting from ``frame_offset``. 
+ This function may return the less number of frames if there is not enough + frames in the given file. + normalize (bool, optional): + When ``True``, this function converts the native sample type to ``float32``. + Default: ``True``. + + If input file is integer WAV, giving ``False`` will change the resulting Tensor type to + integer type. + This argument has no effect for formats other than integer WAV type. + + channels_first (bool, optional): + When True, the returned Tensor has dimension `[channel, time]`. + Otherwise, the returned Tensor's dimension is `[time, channel]`. + format (str or None, optional): + Override the format detection with the given format. + Providing the argument might help when libsox can not infer the format + from header or extension. + + Returns: + (torch.Tensor, int): Resulting Tensor and sample rate. + If the input file has integer wav format and ``normalize=False``, then it has + integer type, else ``float32`` type. If ``channels_first=True``, it has + `[channel, time]` else `[time, channel]`. + """ + if not torch.jit.is_scripting(): + if hasattr(filepath, "read"): + raise RuntimeError("sox_io backend does not support file-like object.") + filepath = os.fspath(filepath) + return sox_ext.load_audio_file(filepath, frame_offset, num_frames, normalize, channels_first, format) + + +def save( + filepath: str, + src: torch.Tensor, + sample_rate: int, + channels_first: bool = True, + compression: Optional[float] = None, + format: Optional[str] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, +): + """Save audio data to file. + + Args: + filepath (path-like object): Path to save file. + src (torch.Tensor): Audio data to save. must be 2D tensor. + sample_rate (int): sampling rate + channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`, + otherwise `[time, channel]`. + compression (float or None, optional): Used for formats other than WAV. 
+ This corresponds to ``-C`` option of ``sox`` command. + + ``"mp3"`` + Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or + VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``. + + ``"flac"`` + Whole number from ``0`` to ``8``. ``8`` is default and highest compression. + + ``"ogg"``, ``"vorbis"`` + Number from ``-1`` to ``10``; ``-1`` is the highest compression + and lowest quality. Default: ``3``. + + See the detail at http://sox.sourceforge.net/soxformat.html. + format (str or None, optional): Override the audio format. + When ``filepath`` argument is path-like object, audio format is infered from + file extension. If file extension is missing or different, you can specify the + correct format with this argument. + + When ``filepath`` argument is file-like object, this argument is required. + + Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``, + ``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``. + + encoding (str or None, optional): Changes the encoding for the supported formats. + This argument is effective only for supported formats, such as ``"wav"``, ``""amb"`` + and ``"sph"``. Valid values are; + + - ``"PCM_S"`` (signed integer Linear PCM) + - ``"PCM_U"`` (unsigned integer Linear PCM) + - ``"PCM_F"`` (floating point PCM) + - ``"ULAW"`` (mu-law) + - ``"ALAW"`` (a-law) + + Default values + If not provided, the default value is picked based on ``format`` and ``bits_per_sample``. + + ``"wav"``, ``"amb"`` + - | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the + | Tensor is used to determine the default value. 
+ + - ``"PCM_U"`` if dtype is ``uint8`` + - ``"PCM_S"`` if dtype is ``int16`` or ``int32`` + - ``"PCM_F"`` if dtype is ``float32`` + + - ``"PCM_U"`` if ``bits_per_sample=8`` + - ``"PCM_S"`` otherwise + + ``"sph"`` format; + - the default value is ``"PCM_S"`` + + bits_per_sample (int or None, optional): Changes the bit depth for the supported formats. + When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the + bit depth. Valid values are ``8``, ``16``, ``32`` and ``64``. + + Default Value; + If not provided, the default values are picked based on ``format`` and ``"encoding"``; + + ``"wav"``, ``"amb"``; + - | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the + | Tensor is used. + + - ``8`` if dtype is ``uint8`` + - ``16`` if dtype is ``int16`` + - ``32`` if dtype is ``int32`` or ``float32`` + + - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"`` + - ``16`` if ``encoding`` is ``"PCM_S"`` + - ``32`` if ``encoding`` is ``"PCM_F"`` + + ``"flac"`` format; + - the default value is ``24`` + + ``"sph"`` format; + - ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided. + - ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"`` + + ``"amb"`` format; + - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"`` + - ``16`` if ``encoding`` is ``"PCM_S"`` or not provided. + - ``32`` if ``encoding`` is ``"PCM_F"`` + + Supported formats/encodings/bit depth/compression are; + + ``"wav"``, ``"amb"`` + - 32-bit floating-point PCM + - 32-bit signed integer PCM + - 24-bit signed integer PCM + - 16-bit signed integer PCM + - 8-bit unsigned integer PCM + - 8-bit mu-law + - 8-bit a-law + + Note: Default encoding/bit depth is determined by the dtype of the input Tensor. + + ``"mp3"`` + Fixed bit rate (such as 128kHz) and variable bit rate compression. + Default: VBR with high quality. 
+ + ``"flac"`` + - 8-bit + - 16-bit + - 24-bit (default) + + ``"ogg"``, ``"vorbis"`` + - Different quality level. Default: approx. 112kbps + + ``"sph"`` + - 8-bit signed integer PCM + - 16-bit signed integer PCM + - 24-bit signed integer PCM + - 32-bit signed integer PCM (default) + - 8-bit mu-law + - 8-bit a-law + - 16-bit a-law + - 24-bit a-law + - 32-bit a-law + + ``"amr-nb"`` + Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s + + ``"gsm"`` + Lossy Speech Compression, CPU intensive. + + ``"htk"`` + Uses a default single-channel 16-bit PCM format. + + Note: + To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``, + ``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has + to be linked to ``libsox`` and corresponding codec libraries such as ``libmad`` + or ``libmp3lame`` etc. + """ + if not torch.jit.is_scripting(): + if hasattr(filepath, "write"): + raise RuntimeError("sox_io backend does not handle file-like object.") + filepath = os.fspath(filepath) + sox_ext.save_audio_file( + filepath, + src, + sample_rate, + channels_first, + compression, + format, + encoding, + bits_per_sample, + ) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/no_backend.py b/.venv/lib/python3.11/site-packages/torchaudio/backend/no_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..2645a86bc80538fa0522f5eb80e304881f30acc7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/backend/no_backend.py @@ -0,0 +1,14 @@ +def __getattr__(name: str): + import warnings + + warnings.warn( + "Torchaudio's I/O functions now support par-call bakcend dispatch. " + "Importing backend implementation directly is no longer guaranteed to work. " + "Please use `backend` keyword with load/save/info function, instead of " + "calling the udnerlying implementation directly.", + stacklevel=2, + ) + + from . 
import _no_backend + + return getattr(_no_backend, name) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/compliance/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/compliance/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65579b4f01ba09695860717f1e6cd90d6e42b631 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/compliance/__init__.py @@ -0,0 +1,5 @@ +from . import kaldi + +__all__ = [ + "kaldi", +] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c052dbfa8390ebfdcb8c2cb0df3e41549031d0a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/kaldi.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/kaldi.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5cad6652f6e7126dd747abccaad1ac7ae1a466d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/kaldi.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/compliance/kaldi.py b/.venv/lib/python3.11/site-packages/torchaudio/compliance/kaldi.py new file mode 100644 index 0000000000000000000000000000000000000000..98358f40b522facc0abdfbaceec45f5887e00e54 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/compliance/kaldi.py @@ -0,0 +1,813 @@ +import math +from typing import Tuple + +import torch +import torchaudio +from torch import Tensor + +__all__ = [ + "get_mel_banks", + "inverse_mel_scale", + "inverse_mel_scale_scalar", + "mel_scale", + "mel_scale_scalar", + "spectrogram", + "fbank", + "mfcc", + "vtln_warp_freq", + 
"vtln_warp_mel_freq", +] + +# numeric_limits::epsilon() 1.1920928955078125e-07 +EPSILON = torch.tensor(torch.finfo(torch.float).eps) +# 1 milliseconds = 0.001 seconds +MILLISECONDS_TO_SECONDS = 0.001 + +# window types +HAMMING = "hamming" +HANNING = "hanning" +POVEY = "povey" +RECTANGULAR = "rectangular" +BLACKMAN = "blackman" +WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN] + + +def _get_epsilon(device, dtype): + return EPSILON.to(device=device, dtype=dtype) + + +def _next_power_of_2(x: int) -> int: + r"""Returns the smallest power of 2 that is greater than x""" + return 1 if x == 0 else 2 ** (x - 1).bit_length() + + +def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor: + r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``) + representing how the window is shifted along the waveform. Each row is a frame. + + Args: + waveform (Tensor): Tensor of size ``num_samples`` + window_size (int): Frame length + window_shift (int): Frame shift + snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. 
+ + Returns: + Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame + """ + assert waveform.dim() == 1 + num_samples = waveform.size(0) + strides = (window_shift * waveform.stride(0), waveform.stride(0)) + + if snip_edges: + if num_samples < window_size: + return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device) + else: + m = 1 + (num_samples - window_size) // window_shift + else: + reversed_waveform = torch.flip(waveform, [0]) + m = (num_samples + (window_shift // 2)) // window_shift + pad = window_size // 2 - window_shift // 2 + pad_right = reversed_waveform + if pad > 0: + # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect' + # but we want [2, 1, 0, 0, 1, 2] + pad_left = reversed_waveform[-pad:] + waveform = torch.cat((pad_left, waveform, pad_right), dim=0) + else: + # pad is negative so we want to trim the waveform at the front + waveform = torch.cat((waveform[-pad:], pad_right), dim=0) + + sizes = (m, window_size) + return waveform.as_strided(sizes, strides) + + +def _feature_window_function( + window_type: str, + window_size: int, + blackman_coeff: float, + device: torch.device, + dtype: int, +) -> Tensor: + r"""Returns a window function with the given type and size""" + if window_type == HANNING: + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype) + elif window_type == HAMMING: + return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype) + elif window_type == POVEY: + # like hanning but goes to zero at edges + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85) + elif window_type == RECTANGULAR: + return torch.ones(window_size, device=device, dtype=dtype) + elif window_type == BLACKMAN: + a = 2 * math.pi / (window_size - 1) + window_function = torch.arange(window_size, device=device, dtype=dtype) + # can't use torch.blackman_window as they use different coefficients + return ( + blackman_coeff + 
- 0.5 * torch.cos(a * window_function) + + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function) + ).to(device=device, dtype=dtype) + else: + raise Exception("Invalid window type " + window_type) + + +def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor: + r"""Returns the log energy of size (m) for a strided_input (m,*)""" + device, dtype = strided_input.device, strided_input.dtype + log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m) + if energy_floor == 0.0: + return log_energy + return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype)) + + +def _get_waveform_and_window_properties( + waveform: Tensor, + channel: int, + sample_frequency: float, + frame_shift: float, + frame_length: float, + round_to_power_of_two: bool, + preemphasis_coefficient: float, +) -> Tuple[Tensor, int, int, int]: + r"""Gets the waveform and window properties""" + channel = max(channel, 0) + assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0)) + waveform = waveform[channel, :] # size (n) + window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS) + window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS) + padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size + + assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format( + window_size, len(waveform) + ) + assert 0 < window_shift, "`window_shift` must be greater than 0" + assert padded_window_size % 2 == 0, ( + "the padded `window_size` must be divisible by two." 
" use `round_to_power_of_two` or change `frame_length`" + ) + assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]" + assert sample_frequency > 0, "`sample_frequency` must be greater than zero" + return waveform, window_shift, window_size, padded_window_size + + +def _get_window( + waveform: Tensor, + padded_window_size: int, + window_size: int, + window_shift: int, + window_type: str, + blackman_coeff: float, + snip_edges: bool, + raw_energy: bool, + energy_floor: float, + dither: float, + remove_dc_offset: bool, + preemphasis_coefficient: float, +) -> Tuple[Tensor, Tensor]: + r"""Gets a window and its log energy + + Returns: + (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m) + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + # size (m, window_size) + strided_input = _get_strided(waveform, window_size, window_shift, snip_edges) + + if dither != 0.0: + rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype) + strided_input = strided_input + rand_gauss * dither + + if remove_dc_offset: + # Subtract each row/frame by its mean + row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1) + strided_input = strided_input - row_means + + if raw_energy: + # Compute the log energy of each row/frame before applying preemphasis and + # window function + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) + + if preemphasis_coefficient != 0.0: + # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j + offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze( + 0 + ) # size (m, window_size + 1) + strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1] + + # Apply window_function to each row/frame + window_function = _feature_window_function(window_type, 
window_size, blackman_coeff, device, dtype).unsqueeze( + 0 + ) # size (1, window_size) + strided_input = strided_input * window_function # size (m, window_size) + + # Pad columns with zero until we reach size (m, padded_window_size) + if padded_window_size != window_size: + padding_right = padded_window_size - window_size + strided_input = torch.nn.functional.pad( + strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0 + ).squeeze(0) + + # Compute energy after window function (not the raw one) + if not raw_energy: + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) + + return strided_input, signal_log_energy + + +def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor: + # subtracts the column mean of the tensor size (m, n) if subtract_mean=True + # it returns size (m, n) + if subtract_mean: + col_means = torch.mean(tensor, dim=0).unsqueeze(0) + tensor = tensor - col_means + return tensor + + +def spectrogram( + waveform: Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + min_duration: float = 0.0, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + window_type: str = POVEY, +) -> Tensor: + r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's + compute-spectrogram-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). 
If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A spectrogram identical to what Kaldi would output. 
The shape is + (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient + ) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return torch.empty(0) + + strided_input, signal_log_energy = _get_window( + waveform, + padded_window_size, + window_size, + window_shift, + window_type, + blackman_coeff, + snip_edges, + raw_energy, + energy_floor, + dither, + remove_dc_offset, + preemphasis_coefficient, + ) + + # size (m, padded_window_size // 2 + 1, 2) + fft = torch.fft.rfft(strided_input) + + # Convert the FFT into a power spectrum + power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1) + power_spectrum[:, 0] = signal_log_energy + + power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean) + return power_spectrum + + +def inverse_mel_scale_scalar(mel_freq: float) -> float: + return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0) + + +def inverse_mel_scale(mel_freq: Tensor) -> Tensor: + return 700.0 * ((mel_freq / 1127.0).exp() - 1.0) + + +def mel_scale_scalar(freq: float) -> float: + return 1127.0 * math.log(1.0 + freq / 700.0) + + +def mel_scale(freq: Tensor) -> Tensor: + return 1127.0 * (1.0 + freq / 700.0).log() + + +def vtln_warp_freq( + vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq: float, + high_freq: float, + vtln_warp_factor: float, + freq: Tensor, +) -> Tensor: + r"""This computes a VTLN warping function that is not the same as HTK's one, + but has similar inputs (this function has the advantage of never producing + empty bins). 
+ + This function computes a warp function F(freq), defined between low_freq + and high_freq inclusive, with the following properties: + F(low_freq) == low_freq + F(high_freq) == high_freq + The function is continuous and piecewise linear with two inflection + points. + The lower inflection point (measured in terms of the unwarped + frequency) is at frequency l, determined as described below. + The higher inflection point is at a frequency h, determined as + described below. + If l <= f <= h, then F(f) = f/vtln_warp_factor. + If the higher inflection point (measured in terms of the unwarped + frequency) is at h, then max(h, F(h)) == vtln_high_cutoff. + Since (by the last point) F(h) == h/vtln_warp_factor, then + max(h, h/vtln_warp_factor) == vtln_high_cutoff, so + h = vtln_high_cutoff / max(1, 1/vtln_warp_factor). + = vtln_high_cutoff * min(1, vtln_warp_factor). + If the lower inflection point (measured in terms of the unwarped + frequency) is at l, then min(l, F(l)) == vtln_low_cutoff + This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor) + = vtln_low_cutoff * max(1, vtln_warp_factor) + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + freq (Tensor): given frequency in Hz + + Returns: + Tensor: Freq after vtln warp + """ + assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq" + assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]" + l = vtln_low_cutoff * max(1.0, vtln_warp_factor) + h = vtln_high_cutoff * min(1.0, vtln_warp_factor) + scale = 1.0 / vtln_warp_factor + Fl = scale * l # F(l) + Fh = scale * h # F(h) + assert l > low_freq and h < high_freq + # slope of left part of the 3-piece linear function + 
scale_left = (Fl - low_freq) / (l - low_freq) + # [slope of center part is just "scale"] + + # slope of right part of the 3-piece linear function + scale_right = (high_freq - Fh) / (high_freq - h) + + res = torch.empty_like(freq) + + outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq + before_l = torch.lt(freq, l) # freq < l + before_h = torch.lt(freq, h) # freq < h + after_h = torch.ge(freq, h) # freq >= h + + # order of operations matter here (since there is overlapping frequency regions) + res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq) + res[before_h] = scale * freq[before_h] + res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq) + res[outside_low_high_freq] = freq[outside_low_high_freq] + + return res + + +def vtln_warp_mel_freq( + vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq, + high_freq: float, + vtln_warp_factor: float, + mel_freq: Tensor, +) -> Tensor: + r""" + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + mel_freq (Tensor): Given frequency in Mel + + Returns: + Tensor: ``mel_freq`` after vtln warp + """ + return mel_scale( + vtln_warp_freq( + vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq) + ) + ) + + +def get_mel_banks( + num_bins: int, + window_length_padded: int, + sample_freq: float, + low_freq: float, + high_freq: float, + vtln_low: float, + vtln_high: float, + vtln_warp_factor: float, +) -> Tuple[Tensor, Tensor]: + """ + Returns: + (Tensor, Tensor): The tuple consists of ``bins`` (which is + melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is + center frequencies of bins of size (``num_bins``)). 
+ """ + assert num_bins > 3, "Must have at least 3 mel bins" + assert window_length_padded % 2 == 0 + num_fft_bins = window_length_padded / 2 + nyquist = 0.5 * sample_freq + + if high_freq <= 0.0: + high_freq += nyquist + + assert ( + (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq) + ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist) + + # fft-bin width [think of it as Nyquist-freq / half-window-length] + fft_bin_width = sample_freq / window_length_padded + mel_low_freq = mel_scale_scalar(low_freq) + mel_high_freq = mel_scale_scalar(high_freq) + + # divide by num_bins+1 in next line because of end-effects where the bins + # spread out to the sides. + mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) + + if vtln_high < 0.0: + vtln_high += nyquist + + assert vtln_warp_factor == 1.0 or ( + (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high) + ), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format( + vtln_low, vtln_high, low_freq, high_freq + ) + + bin = torch.arange(num_bins).unsqueeze(1) + left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1) + center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1) + right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1) + + if vtln_warp_factor != 1.0: + left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel) + center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel) + right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel) + + center_freqs = inverse_mel_scale(center_mel) # size (num_bins) + # size(1, num_fft_bins) + mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0) + + # size (num_bins, num_fft_bins) + up_slope = (mel - 
left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + + if vtln_warp_factor == 1.0: + # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values + bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope)) + else: + # warping can move the order of left_mel, center_mel, right_mel anywhere + bins = torch.zeros_like(up_slope) + up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel + down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel + bins[up_idx] = up_slope[up_idx] + bins[down_idx] = down_slope[down_idx] + + return bins, center_freqs + + +def fbank( + waveform: Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + use_log_fbank: bool = True, + use_power: bool = True, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY, +) -> Tensor: + r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's + compute-fbank-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). 
If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features + (need to change other parameters). (Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. 
If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``) + use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A fbank identical to what Kaldi would output. 
The shape is (m, ``num_mel_bins + use_energy``) + where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient + ) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return torch.empty(0, device=device, dtype=dtype) + + # strided_input, size (m, padded_window_size) and signal_log_energy, size (m) + strided_input, signal_log_energy = _get_window( + waveform, + padded_window_size, + window_size, + window_shift, + window_type, + blackman_coeff, + snip_edges, + raw_energy, + energy_floor, + dither, + remove_dc_offset, + preemphasis_coefficient, + ) + + # size (m, padded_window_size // 2 + 1) + spectrum = torch.fft.rfft(strided_input).abs() + if use_power: + spectrum = spectrum.pow(2.0) + + # size (num_mel_bins, padded_window_size // 2) + mel_energies, _ = get_mel_banks( + num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp + ) + mel_energies = mel_energies.to(device=device, dtype=dtype) + + # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1) + mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0) + + # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins) + mel_energies = torch.mm(spectrum, mel_energies.T) + if use_log_fbank: + # avoid log of zero (which should be prevented anyway by dithering) + mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() + + # if use_energy then add it as the last column for htk_compat == true else first column + if use_energy: + signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1) + # returns size (m, num_mel_bins + 1) + if htk_compat: + mel_energies = torch.cat((mel_energies, 
signal_log_energy), dim=1) + else: + mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1) + + mel_energies = _subtract_column_mean(mel_energies, subtract_mean) + return mel_energies + + +def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor: + # returns a dct matrix of size (num_mel_bins, num_ceps) + # size (num_mel_bins, num_mel_bins) + dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho") + # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins) + # this would be the first column in the dct_matrix for torchaudio as it expects a + # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi + # expects a left multiply e.g. dct_matrix * vector). + dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins)) + dct_matrix = dct_matrix[:, :num_ceps] + return dct_matrix + + +def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor: + # returns size (num_ceps) + # Compute liftering coefficients (scaling on cepstral coeffs) + # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected. 
+ i = torch.arange(num_ceps) + return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter) + + +def mfcc( + waveform: Tensor, + blackman_coeff: float = 0.42, + cepstral_lifter: float = 22.0, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + num_ceps: int = 13, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY, +) -> Tensor: + r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's + compute-mfcc-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). 
(Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible + features (need to change other parameters). (Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. 
(Default: ``False``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``"povey"``) + + Returns: + Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``) + where m is calculated in _get_strided + """ + assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins) + + device, dtype = waveform.device, waveform.dtype + + # The mel_energies should not be squared (use_power=True), not have mean subtracted + # (subtract_mean=False), and use log (use_log_fbank=True). + # size (m, num_mel_bins + use_energy) + feature = fbank( + waveform=waveform, + blackman_coeff=blackman_coeff, + channel=channel, + dither=dither, + energy_floor=energy_floor, + frame_length=frame_length, + frame_shift=frame_shift, + high_freq=high_freq, + htk_compat=htk_compat, + low_freq=low_freq, + min_duration=min_duration, + num_mel_bins=num_mel_bins, + preemphasis_coefficient=preemphasis_coefficient, + raw_energy=raw_energy, + remove_dc_offset=remove_dc_offset, + round_to_power_of_two=round_to_power_of_two, + sample_frequency=sample_frequency, + snip_edges=snip_edges, + subtract_mean=False, + use_energy=use_energy, + use_log_fbank=True, + use_power=True, + vtln_high=vtln_high, + vtln_low=vtln_low, + vtln_warp=vtln_warp, + window_type=window_type, + ) + + if use_energy: + # size (m) + signal_log_energy = feature[:, num_mel_bins if htk_compat else 0] + # offset is 0 if htk_compat==True else 1 + mel_offset = int(not htk_compat) + feature = feature[:, mel_offset : (num_mel_bins + 
mel_offset)] + + # size (num_mel_bins, num_ceps) + dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device) + + # size (m, num_ceps) + feature = feature.matmul(dct_matrix) + + if cepstral_lifter != 0.0: + # size (1, num_ceps) + lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0) + feature *= lifter_coeffs.to(device=device, dtype=dtype) + + # if use_energy then replace the last column for htk_compat == true else first column + if use_energy: + feature[:, 0] = signal_log_energy + + if htk_compat: + energy = feature[:, 0].unsqueeze(1) # size (m, 1) + feature = feature[:, 1:] # size (m, num_ceps - 1) + if not use_energy: + # scale on C0 (actually removing a scale we previously added that's + # part of one common definition of the cosine transform.) + energy *= math.sqrt(2) + + feature = torch.cat((feature, energy), dim=1) + + feature = _subtract_column_mean(feature, subtract_mean) + return feature diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d344400d3b8771d2c1b93ba48def361615a132f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/__init__.py @@ -0,0 +1,85 @@ +from ._hdemucs import HDemucs, hdemucs_high, hdemucs_low, hdemucs_medium +from .conformer import Conformer +from .conv_tasnet import conv_tasnet_base, ConvTasNet +from .deepspeech import DeepSpeech +from .emformer import Emformer +from .rnnt import emformer_rnnt_base, emformer_rnnt_model, RNNT +from .rnnt_decoder import Hypothesis, RNNTBeamSearch +from .squim import ( + squim_objective_base, + squim_objective_model, + squim_subjective_base, + squim_subjective_model, + SquimObjective, + SquimSubjective, +) +from .tacotron2 import Tacotron2 +from .wav2letter import Wav2Letter +from .wav2vec2 import ( + hubert_base, + hubert_large, + hubert_pretrain_base, + 
hubert_pretrain_large, + hubert_pretrain_model, + hubert_pretrain_xlarge, + hubert_xlarge, + HuBERTPretrainModel, + wav2vec2_base, + wav2vec2_large, + wav2vec2_large_lv60k, + wav2vec2_model, + wav2vec2_xlsr_1b, + wav2vec2_xlsr_2b, + wav2vec2_xlsr_300m, + Wav2Vec2Model, + wavlm_base, + wavlm_large, + wavlm_model, +) +from .wavernn import WaveRNN + + +__all__ = [ + "Wav2Letter", + "WaveRNN", + "ConvTasNet", + "conv_tasnet_base", + "DeepSpeech", + "Wav2Vec2Model", + "HuBERTPretrainModel", + "wavlm_model", + "wavlm_base", + "wavlm_large", + "wav2vec2_model", + "wav2vec2_base", + "wav2vec2_large", + "wav2vec2_large_lv60k", + "hubert_base", + "hubert_large", + "hubert_xlarge", + "hubert_pretrain_model", + "hubert_pretrain_base", + "hubert_pretrain_large", + "hubert_pretrain_xlarge", + "wav2vec2_xlsr_300m", + "wav2vec2_xlsr_1b", + "wav2vec2_xlsr_2b", + "Tacotron2", + "Conformer", + "Emformer", + "Hypothesis", + "RNNT", + "RNNTBeamSearch", + "emformer_rnnt_base", + "emformer_rnnt_model", + "HDemucs", + "hdemucs_low", + "hdemucs_medium", + "hdemucs_high", + "squim_objective_base", + "squim_objective_model", + "squim_subjective_base", + "squim_subjective_model", + "SquimObjective", + "SquimSubjective", +] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/_hdemucs.py b/.venv/lib/python3.11/site-packages/torchaudio/models/_hdemucs.py new file mode 100644 index 0000000000000000000000000000000000000000..74a3ebd1d609e67edd09f4356a8cefa305c1fc49 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/_hdemucs.py @@ -0,0 +1,1008 @@ +# ***************************************************************************** +# MIT License +# +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ***************************************************************************** + + +import math +import typing as tp +from typing import Any, Dict, List, Optional + +import torch +from torch import nn +from torch.nn import functional as F + + +class _ScaledEmbedding(torch.nn.Module): + r"""Make continuous embeddings and boost learning rate + + Args: + num_embeddings (int): number of embeddings + embedding_dim (int): embedding dimensions + scale (float, optional): amount to scale learning rate (Default: 10.0) + smooth (bool, optional): choose to apply smoothing (Default: ``False``) + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth: bool = False): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + if smooth: + weight = torch.cumsum(self.embedding.weight.data, dim=0) + # when summing gaussian, scale raises as sqrt(n), so we normalize by that. + weight = weight / torch.arange(1, num_embeddings + 1).sqrt()[:, None] + self.embedding.weight.data[:] = weight + self.embedding.weight.data /= scale + self.scale = scale + + @property + def weight(self) -> torch.Tensor: + return self.embedding.weight * self.scale + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""Forward pass for embedding with scale. + Args: + x (torch.Tensor): input tensor of shape `(num_embeddings)` + + Returns: + (Tensor): + Embedding output of shape `(num_embeddings, embedding_dim)` + """ + out = self.embedding(x) * self.scale + return out + + +class _HEncLayer(torch.nn.Module): + + r"""Encoder layer. This used both by the time and the frequency branch. + Args: + chin (int): number of input channels. + chout (int): number of output channels. + kernel_size (int, optional): Kernel size for encoder (Default: 8) + stride (int, optional): Stride for encoder layer (Default: 4) + norm_groups (int, optional): number of groups for group norm. 
(Default: 4) + empty (bool, optional): used to make a layer with just the first conv. this is used + before merging the time and freq. branches. (Default: ``False``) + freq (bool, optional): boolean for whether conv layer is for frequency domain (Default: ``True``) + norm_type (string, optional): Norm type, either ``group_norm `` or ``none`` (Default: ``group_norm``) + context (int, optional): context size for the 1x1 conv. (Default: 0) + dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``) + pad (bool, optional): true to pad the input. Padding is done so that the output size is + always the input size / stride. (Default: ``True``) + """ + + def __init__( + self, + chin: int, + chout: int, + kernel_size: int = 8, + stride: int = 4, + norm_groups: int = 4, + empty: bool = False, + freq: bool = True, + norm_type: str = "group_norm", + context: int = 0, + dconv_kw: Optional[Dict[str, Any]] = None, + pad: bool = True, + ): + super().__init__() + if dconv_kw is None: + dconv_kw = {} + norm_fn = lambda d: nn.Identity() # noqa + if norm_type == "group_norm": + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + pad_val = kernel_size // 4 if pad else 0 + klass = nn.Conv1d + self.freq = freq + self.kernel_size = kernel_size + self.stride = stride + self.empty = empty + self.pad = pad_val + if freq: + kernel_size = [kernel_size, 1] + stride = [stride, 1] + pad_val = [pad_val, 0] + klass = nn.Conv2d + self.conv = klass(chin, chout, kernel_size, stride, pad_val) + self.norm1 = norm_fn(chout) + + if self.empty: + self.rewrite = nn.Identity() + self.norm2 = nn.Identity() + self.dconv = nn.Identity() + else: + self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context) + self.norm2 = norm_fn(2 * chout) + self.dconv = _DConv(chout, **dconv_kw) + + def forward(self, x: torch.Tensor, inject: Optional[torch.Tensor] = None) -> torch.Tensor: + r"""Forward pass for encoding layer. 
+ + Size depends on whether frequency or time + + Args: + x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape + `(B, C, T)` for time + inject (torch.Tensor, optional): on last layer, combine frequency and time branches through inject param, + same shape as x (default: ``None``) + + Returns: + Tensor + output tensor after encoder layer of shape `(B, C, F / stride, T)` for frequency + and shape `(B, C, ceil(T / stride))` for time + """ + + if not self.freq and x.dim() == 4: + B, C, Fr, T = x.shape + x = x.view(B, -1, T) + + if not self.freq: + le = x.shape[-1] + if not le % self.stride == 0: + x = F.pad(x, (0, self.stride - (le % self.stride))) + y = self.conv(x) + if self.empty: + return y + if inject is not None: + if inject.shape[-1] != y.shape[-1]: + raise ValueError("Injection shapes do not align") + if inject.dim() == 3 and y.dim() == 4: + inject = inject[:, :, None] + y = y + inject + y = F.gelu(self.norm1(y)) + if self.freq: + B, C, Fr, T = y.shape + y = y.permute(0, 2, 1, 3).reshape(-1, C, T) + y = self.dconv(y) + y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) + else: + y = self.dconv(y) + z = self.norm2(self.rewrite(y)) + z = F.glu(z, dim=1) + return z + + +class _HDecLayer(torch.nn.Module): + r"""Decoder layer. This used both by the time and the frequency branches. + Args: + chin (int): number of input channels. + chout (int): number of output channels. + last (bool, optional): whether current layer is final layer (Default: ``False``) + kernel_size (int, optional): Kernel size for encoder (Default: 8) + stride (int): Stride for encoder layer (Default: 4) + norm_groups (int, optional): number of groups for group norm. (Default: 1) + empty (bool, optional): used to make a layer with just the first conv. this is used + before merging the time and freq. branches. 
(Default: ``False``) + freq (bool, optional): boolean for whether conv layer is for frequency (Default: ``True``) + norm_type (str, optional): Norm type, either ``group_norm `` or ``none`` (Default: ``group_norm``) + context (int, optional): context size for the 1x1 conv. (Default: 1) + dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``) + pad (bool, optional): true to pad the input. Padding is done so that the output size is + always the input size / stride. (Default: ``True``) + """ + + def __init__( + self, + chin: int, + chout: int, + last: bool = False, + kernel_size: int = 8, + stride: int = 4, + norm_groups: int = 1, + empty: bool = False, + freq: bool = True, + norm_type: str = "group_norm", + context: int = 1, + dconv_kw: Optional[Dict[str, Any]] = None, + pad: bool = True, + ): + super().__init__() + if dconv_kw is None: + dconv_kw = {} + norm_fn = lambda d: nn.Identity() # noqa + if norm_type == "group_norm": + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + if pad: + if (kernel_size - stride) % 2 != 0: + raise ValueError("Kernel size and stride do not align") + pad = (kernel_size - stride) // 2 + else: + pad = 0 + self.pad = pad + self.last = last + self.freq = freq + self.chin = chin + self.empty = empty + self.stride = stride + self.kernel_size = kernel_size + klass = nn.Conv1d + klass_tr = nn.ConvTranspose1d + if freq: + kernel_size = [kernel_size, 1] + stride = [stride, 1] + klass = nn.Conv2d + klass_tr = nn.ConvTranspose2d + self.conv_tr = klass_tr(chin, chout, kernel_size, stride) + self.norm2 = norm_fn(chout) + if self.empty: + self.rewrite = nn.Identity() + self.norm1 = nn.Identity() + else: + self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context) + self.norm1 = norm_fn(2 * chin) + + def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor], length): + r"""Forward pass for decoding layer. 
+ + Size depends on whether frequency or time + + Args: + x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape + `(B, C, T)` for time + skip (torch.Tensor, optional): on first layer, separate frequency and time branches using param + (default: ``None``) + length (int): Size of tensor for output + + Returns: + (Tensor, Tensor): + Tensor + output tensor after decoder layer of shape `(B, C, F * stride, T)` for frequency domain except last + frequency layer shape is `(B, C, kernel_size, T)`. Shape is `(B, C, stride * T)` + for time domain. + Tensor + contains the output just before final transposed convolution, which is used when the + freq. and time branch separate. Otherwise, does not matter. Shape is + `(B, C, F, T)` for frequency and `(B, C, T)` for time. + """ + if self.freq and x.dim() == 3: + B, C, T = x.shape + x = x.view(B, self.chin, -1, T) + + if not self.empty: + x = x + skip + y = F.glu(self.norm1(self.rewrite(x)), dim=1) + else: + y = x + if skip is not None: + raise ValueError("Skip must be none when empty is true.") + + z = self.norm2(self.conv_tr(y)) + if self.freq: + if self.pad: + z = z[..., self.pad : -self.pad, :] + else: + z = z[..., self.pad : self.pad + length] + if z.shape[-1] != length: + raise ValueError("Last index of z must be equal to length") + if not self.last: + z = F.gelu(z) + + return z, y + + +class HDemucs(torch.nn.Module): + r"""Hybrid Demucs model from + *Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`. + + See Also: + * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models. + + Args: + sources (List[str]): list of source names. List can contain the following source + options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``]. + audio_channels (int, optional): input/output audio channels. (Default: 2) + channels (int, optional): initial number of hidden channels. 
(Default: 48) + growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2) + nfft (int, optional): number of fft bins. Note that changing this requires careful computation of + various shape parameters and will not work out of the box for hybrid models. (Default: 4096) + depth (int, optional): number of layers in encoder and decoder (Default: 6) + freq_emb (float, optional): add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. (Default: 0.2) + emb_scale (int, optional): equivalent to scaling the embedding learning rate (Default: 10) + emb_smooth (bool, optional): initialize the embedding with a smooth one (with respect to frequencies). + (Default: ``True``) + kernel_size (int, optional): kernel_size for encoder and decoder layers. (Default: 8) + time_stride (int, optional): stride for the final time layer, after the merge. (Default: 2) + stride (int, optional): stride for encoder and decoder layers. (Default: 4) + context (int, optional): context for 1x1 conv in the decoder. (Default: 4) + context_enc (int, optional): context for 1x1 conv in the encoder. (Default: 0) + norm_starts (int, optional): layer at which group norm starts being used. + decoder layers are numbered in reverse order. (Default: 4) + norm_groups (int, optional): number of groups for group norm. (Default: 4) + dconv_depth (int, optional): depth of residual DConv branch. (Default: 2) + dconv_comp (int, optional): compression of DConv branch. (Default: 4) + dconv_attn (int, optional): adds attention layers in DConv branch starting at this layer. (Default: 4) + dconv_lstm (int, optional): adds a LSTM layer in DConv branch starting at this layer. (Default: 4) + dconv_init (float, optional): initial scale for the DConv branch LayerScale. 
(Default: 1e-4) + """ + + def __init__( + self, + sources: List[str], + audio_channels: int = 2, + channels: int = 48, + growth: int = 2, + nfft: int = 4096, + depth: int = 6, + freq_emb: float = 0.2, + emb_scale: int = 10, + emb_smooth: bool = True, + kernel_size: int = 8, + time_stride: int = 2, + stride: int = 4, + context: int = 1, + context_enc: int = 0, + norm_starts: int = 4, + norm_groups: int = 4, + dconv_depth: int = 2, + dconv_comp: int = 4, + dconv_attn: int = 4, + dconv_lstm: int = 4, + dconv_init: float = 1e-4, + ): + super().__init__() + self.depth = depth + self.nfft = nfft + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.channels = channels + + self.hop_length = self.nfft // 4 + self.freq_emb = None + + self.freq_encoder = nn.ModuleList() + self.freq_decoder = nn.ModuleList() + + self.time_encoder = nn.ModuleList() + self.time_decoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin * 2 # number of channels for the freq branch + chout = channels + chout_z = channels + freqs = self.nfft // 2 + + for index in range(self.depth): + lstm = index >= dconv_lstm + attn = index >= dconv_attn + norm_type = "group_norm" if index >= norm_starts else "none" + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + if freqs != 1: + raise ValueError("When freq is false, freqs must be 1.") + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + "kernel_size": ker, + "stride": stri, + "freq": freq, + "pad": pad, + "norm_type": norm_type, + "norm_groups": norm_groups, + "dconv_kw": { + "lstm": lstm, + "attn": attn, + "depth": dconv_depth, + "compress": dconv_comp, + "init": dconv_init, + }, + } + kwt = dict(kw) + kwt["freq"] = 0 + kwt["kernel_size"] = kernel_size + kwt["stride"] = stride + kwt["pad"] = True + kw_dec = 
dict(kw) + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = _HEncLayer(chin_z, chout_z, context=context_enc, **kw) + if freq: + if last_freq is True and nfft == 2048: + kwt["stride"] = 2 + kwt["kernel_size"] = 4 + tenc = _HEncLayer(chin, chout, context=context_enc, empty=last_freq, **kwt) + self.time_encoder.append(tenc) + + self.freq_encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin * 2 + dec = _HDecLayer(chout_z, chin_z, last=index == 0, context=context, **kw_dec) + if freq: + tdec = _HDecLayer(chout, chin, empty=last_freq, last=index == 0, context=context, **kwt) + self.time_decoder.insert(0, tdec) + self.freq_decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = _ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale) + self.freq_emb_scale = freq_emb + + _rescale_module(self) + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). + # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. 
+ if hl != nfft // 4: + raise ValueError("Hop length must be nfft // 4") + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + x = self._pad1d(x, pad, pad + le * hl - x.shape[-1], mode="reflect") + + z = _spectro(x, nfft, hl)[..., :-1, :] + if z.shape[-1] != le + 4: + raise ValueError("Spectrogram's last dimension must be 4 + input size divided by stride") + z = z[..., 2 : 2 + le] + return z + + def _ispec(self, z, length=None): + hl = self.hop_length + z = F.pad(z, [0, 0, 0, 1]) + z = F.pad(z, [2, 2]) + pad = hl // 2 * 3 + le = hl * int(math.ceil(length / hl)) + 2 * pad + x = _ispectro(z, hl, length=le) + x = x[..., pad : pad + length] + return x + + def _pad1d(self, x: torch.Tensor, padding_left: int, padding_right: int, mode: str = "zero", value: float = 0.0): + """Wrapper around F.pad, in order for reflect padding when num_frames is shorter than max_pad. + Add extra zero padding around in order for padding to not break.""" + length = x.shape[-1] + if mode == "reflect": + max_pad = max(padding_left, padding_right) + if length <= max_pad: + x = F.pad(x, (0, max_pad - length + 1)) + return F.pad(x, (padding_left, padding_right), mode, value) + + def _magnitude(self, z): + # move the complex dimension to the channel one. + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + return m + + def _mask(self, m): + # `m` is a full spectrogram and `z` is ignored. + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + + def forward(self, input: torch.Tensor): + + r"""HDemucs forward call + + Args: + input (torch.Tensor): input mixed tensor of shape `(batch_size, channel, num_frames)` + + Returns: + Tensor + output tensor split into sources of shape `(batch_size, num_sources, channel, num_frames)` + """ + + if input.ndim != 3: + raise ValueError(f"Expected 3D tensor with dimensions (batch, channel, frames). 
Found: {input.shape}") + + if input.shape[1] != self.audio_channels: + raise ValueError( + f"The channel dimension of input Tensor must match `audio_channels` of HDemucs model. " + f"Found:{input.shape[1]}." + ) + + x = input + length = x.shape[-1] + + z = self._spec(input) + mag = self._magnitude(z) + x = mag + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + # Prepare the time branch input. + xt = input + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths: List[int] = [] # saved lengths to properly remove padding, freq branch. + lengths_t: List[int] = [] # saved lengths for time branch. + + for idx, encode in enumerate(self.freq_encoder): + lengths.append(x.shape[-1]) + inject = None + if idx < len(self.time_encoder): + # we have not yet merged branches. + lengths_t.append(xt.shape[-1]) + tenc = self.time_encoder[idx] + xt = tenc(xt) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. + # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + + x = torch.zeros_like(x) + xt = torch.zeros_like(x) + # initialize everything to zero (signal will go through u-net skips). 
+ + for idx, decode in enumerate(self.freq_decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + offset = self.depth - len(self.time_decoder) + if idx >= offset: + tdec = self.time_decoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + if pre.shape[2] != 1: + raise ValueError(f"If tdec empty is True, pre shape does not match {pre.shape}") + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + + if len(saved) != 0: + raise AssertionError("saved is not empty") + if len(lengths_t) != 0: + raise AssertionError("lengths_t is not empty") + if len(saved_t) != 0: + raise AssertionError("saved_t is not empty") + + S = len(self.sources) + x = x.view(B, S, -1, Fq, T) + x = x * std[:, None] + mean[:, None] + + zout = self._mask(x) + x = self._ispec(zout, length) + + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + return x + + +class _DConv(torch.nn.Module): + r""" + New residual branches in each encoder layer. + This alternates dilated convolutions, potentially with LSTMs and attention. + Also before entering each residual branch, dimension is projected on a smaller subspace, + e.g. of dim `channels // compress`. + + Args: + channels (int): input/output channels for residual branch. + compress (float, optional): amount of channel compression inside the branch. (default: 4) + depth (int, optional): number of layers in the residual branch. Each layer has its own + projection, and potentially LSTM and attention.(default: 2) + init (float, optional): initial scale for LayerNorm. (default: 1e-4) + norm_type (bool, optional): Norm type, either ``group_norm `` or ``none`` (Default: ``group_norm``) + attn (bool, optional): use LocalAttention. 
(Default: ``False``) + heads (int, optional): number of heads for the LocalAttention. (default: 4) + ndecay (int, optional): number of decay controls in the LocalAttention. (default: 4) + lstm (bool, optional): use LSTM. (Default: ``False``) + kernel_size (int, optional): kernel size for the (dilated) convolutions. (default: 3) + """ + + def __init__( + self, + channels: int, + compress: float = 4, + depth: int = 2, + init: float = 1e-4, + norm_type: str = "group_norm", + attn: bool = False, + heads: int = 4, + ndecay: int = 4, + lstm: bool = False, + kernel_size: int = 3, + ): + + super().__init__() + if kernel_size % 2 == 0: + raise ValueError("Kernel size should not be divisible by 2") + self.channels = channels + self.compress = compress + self.depth = abs(depth) + dilate = depth > 0 + + norm_fn: tp.Callable[[int], nn.Module] + norm_fn = lambda d: nn.Identity() # noqa + if norm_type == "group_norm": + norm_fn = lambda d: nn.GroupNorm(1, d) # noqa + + hidden = int(channels / compress) + + act = nn.GELU + + self.layers = nn.ModuleList([]) + for d in range(self.depth): + dilation = pow(2, d) if dilate else 1 + padding = dilation * (kernel_size // 2) + mods = [ + nn.Conv1d(channels, hidden, kernel_size, dilation=dilation, padding=padding), + norm_fn(hidden), + act(), + nn.Conv1d(hidden, 2 * channels, 1), + norm_fn(2 * channels), + nn.GLU(1), + _LayerScale(channels, init), + ] + if attn: + mods.insert(3, _LocalState(hidden, heads=heads, ndecay=ndecay)) + if lstm: + mods.insert(3, _BLSTM(hidden, layers=2, skip=True)) + layer = nn.Sequential(*mods) + self.layers.append(layer) + + def forward(self, x): + r"""DConv forward call + + Args: + x (torch.Tensor): input tensor for convolution + + Returns: + Tensor + Output after being run through layers. + """ + for layer in self.layers: + x = x + layer(x) + return x + + +class _BLSTM(torch.nn.Module): + r""" + BiLSTM with same hidden units as input dim. 
+ If `max_steps` is not None, input will be splitting in overlapping + chunks and the LSTM applied separately on each chunk. + Args: + dim (int): dimensions at LSTM layer. + layers (int, optional): number of LSTM layers. (default: 1) + skip (bool, optional): (default: ``False``) + """ + + def __init__(self, dim, layers: int = 1, skip: bool = False): + super().__init__() + self.max_steps = 200 + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + self.skip = skip + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""BLSTM forward call + + Args: + x (torch.Tensor): input tensor for BLSTM shape is `(batch_size, dim, time_steps)` + + Returns: + Tensor + Output after being run through bidirectional LSTM. Shape is `(batch_size, dim, time_steps)` + """ + B, C, T = x.shape + y = x + framed = False + width = 0 + stride = 0 + nframes = 0 + if self.max_steps is not None and T > self.max_steps: + width = self.max_steps + stride = width // 2 + frames = _unfold(x, width, stride) + nframes = frames.shape[2] + framed = True + x = frames.permute(0, 2, 1, 3).reshape(-1, C, width) + + x = x.permute(2, 0, 1) + + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + if framed: + out = [] + frames = x.reshape(B, -1, C, width) + limit = stride // 2 + for k in range(nframes): + if k == 0: + out.append(frames[:, k, :, :-limit]) + elif k == nframes - 1: + out.append(frames[:, k, :, limit:]) + else: + out.append(frames[:, k, :, limit:-limit]) + out = torch.cat(out, -1) + out = out[..., :T] + x = out + if self.skip: + x = x + y + + return x + + +class _LocalState(nn.Module): + """Local state allows to have attention based only on data (no positional embedding), + but while setting a constraint on the time window (e.g. decaying penalty term). + Also a failed experiments with trying to provide some frequency based attention. 
+ """ + + def __init__(self, channels: int, heads: int = 4, ndecay: int = 4): + r""" + Args: + channels (int): Size of Conv1d layers. + heads (int, optional): (default: 4) + ndecay (int, optional): (default: 4) + """ + super(_LocalState, self).__init__() + if channels % heads != 0: + raise ValueError("Channels must be divisible by heads.") + self.heads = heads + self.ndecay = ndecay + self.content = nn.Conv1d(channels, channels, 1) + self.query = nn.Conv1d(channels, channels, 1) + self.key = nn.Conv1d(channels, channels, 1) + + self.query_decay = nn.Conv1d(channels, heads * ndecay, 1) + if ndecay: + # Initialize decay close to zero (there is a sigmoid), for maximum initial window. + self.query_decay.weight.data *= 0.01 + if self.query_decay.bias is None: + raise ValueError("bias must not be None.") + self.query_decay.bias.data[:] = -2 + self.proj = nn.Conv1d(channels + heads * 0, channels, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""LocalState forward call + + Args: + x (torch.Tensor): input tensor for LocalState + + Returns: + Tensor + Output after being run through LocalState layer. + """ + B, C, T = x.shape + heads = self.heads + indexes = torch.arange(T, device=x.device, dtype=x.dtype) + # left index are keys, right index are queries + delta = indexes[:, None] - indexes[None, :] + + queries = self.query(x).view(B, heads, -1, T) + keys = self.key(x).view(B, heads, -1, T) + # t are keys, s are queries + dots = torch.einsum("bhct,bhcs->bhts", keys, queries) + dots /= math.sqrt(keys.shape[2]) + if self.ndecay: + decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype) + decay_q = self.query_decay(x).view(B, heads, -1, T) + decay_q = torch.sigmoid(decay_q) / 2 + decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / math.sqrt(self.ndecay) + dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q) + + # Kill self reference. 
+ dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100) + weights = torch.softmax(dots, dim=2) + + content = self.content(x).view(B, heads, -1, T) + result = torch.einsum("bhts,bhct->bhcs", weights, content) + result = result.reshape(B, -1, T) + return x + self.proj(result) + + +class _LayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). + This rescales diagonally residual outputs close to 0 initially, then learnt. + """ + + def __init__(self, channels: int, init: float = 0): + r""" + Args: + channels (int): Size of rescaling + init (float, optional): Scale to default to (default: 0) + """ + super().__init__() + self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) + self.scale.data[:] = init + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""LayerScale forward call + + Args: + x (torch.Tensor): input tensor for LayerScale + + Returns: + Tensor + Output after rescaling tensor. + """ + return self.scale[:, None] * x + + +def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor: + """Given input of size [*OT, T], output Tensor of size [*OT, F, K] + with K the kernel size, by extracting frames with the given stride. + This will pad the input so that `F = ceil(T / K)`. + see https://github.com/pytorch/pytorch/issues/60466 + """ + shape = list(a.shape[:-1]) + length = int(a.shape[-1]) + n_frames = math.ceil(length / stride) + tgt_length = (n_frames - 1) * stride + kernel_size + a = F.pad(input=a, pad=[0, tgt_length - length]) + strides = [a.stride(dim) for dim in range(a.dim())] + if strides[-1] != 1: + raise ValueError("Data should be contiguous.") + strides = strides[:-1] + [stride, 1] + shape.append(n_frames) + shape.append(kernel_size) + return a.as_strided(shape, strides) + + +def _rescale_module(module): + r""" + Rescales initial weight scale for all models within the module. 
+ """ + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)): + std = sub.weight.std().detach() + scale = (std / 0.1) ** 0.5 + sub.weight.data /= scale + if sub.bias is not None: + sub.bias.data /= scale + + +def _spectro(x: torch.Tensor, n_fft: int = 512, hop_length: int = 0, pad: int = 0) -> torch.Tensor: + other = list(x.shape[:-1]) + length = int(x.shape[-1]) + x = x.reshape(-1, length) + z = torch.stft( + x, + n_fft * (1 + pad), + hop_length, + window=torch.hann_window(n_fft).to(x), + win_length=n_fft, + normalized=True, + center=True, + return_complex=True, + pad_mode="reflect", + ) + _, freqs, frame = z.shape + other.extend([freqs, frame]) + return z.view(other) + + +def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int = 0) -> torch.Tensor: + other = list(z.shape[:-2]) + freqs = int(z.shape[-2]) + frames = int(z.shape[-1]) + + n_fft = 2 * freqs - 2 + z = z.view(-1, freqs, frames) + win_length = n_fft // (1 + pad) + x = torch.istft( + z, + n_fft, + hop_length, + window=torch.hann_window(win_length).to(z.real), + win_length=win_length, + normalized=True, + length=length, + center=True, + ) + _, length = x.shape + other.append(length) + return x.view(other) + + +def hdemucs_low(sources: List[str]) -> HDemucs: + """Builds low nfft (1024) version of :class:`HDemucs`, suitable for sample rates around 8 kHz. + + Args: + sources (List[str]): See :py:func:`HDemucs`. + + Returns: + HDemucs: + HDemucs model. + """ + + return HDemucs(sources=sources, nfft=1024, depth=5) + + +def hdemucs_medium(sources: List[str]) -> HDemucs: + r"""Builds medium nfft (2048) version of :class:`HDemucs`, suitable for sample rates of 16-32 kHz. + + .. 
note:: + + Medium HDemucs has not been tested against the original Hybrid Demucs as this nfft and depth configuration is + not compatible with the original implementation in https://github.com/facebookresearch/demucs + + Args: + sources (List[str]): See :py:func:`HDemucs`. + + Returns: + HDemucs: + HDemucs model. + """ + + return HDemucs(sources=sources, nfft=2048, depth=6) + + +def hdemucs_high(sources: List[str]) -> HDemucs: + r"""Builds medium nfft (4096) version of :class:`HDemucs`, suitable for sample rates of 44.1-48 kHz. + + Args: + sources (List[str]): See :py:func:`HDemucs`. + + Returns: + HDemucs: + HDemucs model. + """ + + return HDemucs(sources=sources, nfft=4096, depth=6) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/conformer.py b/.venv/lib/python3.11/site-packages/torchaudio/models/conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..3da0d24fac977a65cc97f4b0afae0ab64932d4b2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/conformer.py @@ -0,0 +1,293 @@ +from typing import Optional, Tuple + +import torch + + +__all__ = ["Conformer"] + + +def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor: + batch_size = lengths.shape[0] + max_length = int(torch.max(lengths).item()) + padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand( + batch_size, max_length + ) >= lengths.unsqueeze(1) + return padding_mask + + +class _ConvolutionModule(torch.nn.Module): + r"""Conformer convolution module. + + Args: + input_dim (int): input dimension. + num_channels (int): number of depthwise convolution layer input channels. + depthwise_kernel_size (int): kernel size of depthwise convolution layer. + dropout (float, optional): dropout probability. (Default: 0.0) + bias (bool, optional): indicates whether to add bias term to each convolution layer. (Default: ``False``) + use_group_norm (bool, optional): use GroupNorm rather than BatchNorm. 
(Default: ``False``) + """ + + def __init__( + self, + input_dim: int, + num_channels: int, + depthwise_kernel_size: int, + dropout: float = 0.0, + bias: bool = False, + use_group_norm: bool = False, + ) -> None: + super().__init__() + if (depthwise_kernel_size - 1) % 2 != 0: + raise ValueError("depthwise_kernel_size must be odd to achieve 'SAME' padding.") + self.layer_norm = torch.nn.LayerNorm(input_dim) + self.sequential = torch.nn.Sequential( + torch.nn.Conv1d( + input_dim, + 2 * num_channels, + 1, + stride=1, + padding=0, + bias=bias, + ), + torch.nn.GLU(dim=1), + torch.nn.Conv1d( + num_channels, + num_channels, + depthwise_kernel_size, + stride=1, + padding=(depthwise_kernel_size - 1) // 2, + groups=num_channels, + bias=bias, + ), + torch.nn.GroupNorm(num_groups=1, num_channels=num_channels) + if use_group_norm + else torch.nn.BatchNorm1d(num_channels), + torch.nn.SiLU(), + torch.nn.Conv1d( + num_channels, + input_dim, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ), + torch.nn.Dropout(dropout), + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r""" + Args: + input (torch.Tensor): with shape `(B, T, D)`. + + Returns: + torch.Tensor: output, with shape `(B, T, D)`. + """ + x = self.layer_norm(input) + x = x.transpose(1, 2) + x = self.sequential(x) + return x.transpose(1, 2) + + +class _FeedForwardModule(torch.nn.Module): + r"""Positionwise feed forward layer. + + Args: + input_dim (int): input dimension. + hidden_dim (int): hidden dimension. + dropout (float, optional): dropout probability. 
(Default: 0.0) + """ + + def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.0) -> None: + super().__init__() + self.sequential = torch.nn.Sequential( + torch.nn.LayerNorm(input_dim), + torch.nn.Linear(input_dim, hidden_dim, bias=True), + torch.nn.SiLU(), + torch.nn.Dropout(dropout), + torch.nn.Linear(hidden_dim, input_dim, bias=True), + torch.nn.Dropout(dropout), + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + r""" + Args: + input (torch.Tensor): with shape `(*, D)`. + + Returns: + torch.Tensor: output, with shape `(*, D)`. + """ + return self.sequential(input) + + +class ConformerLayer(torch.nn.Module): + r"""Conformer layer that constitutes Conformer. + + Args: + input_dim (int): input dimension. + ffn_dim (int): hidden layer dimension of feedforward network. + num_attention_heads (int): number of attention heads. + depthwise_conv_kernel_size (int): kernel size of depthwise convolution layer. + dropout (float, optional): dropout probability. (Default: 0.0) + use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d`` + in the convolution module. (Default: ``False``) + convolution_first (bool, optional): apply the convolution module ahead of + the attention module. 
(Default: ``False``) + """ + + def __init__( + self, + input_dim: int, + ffn_dim: int, + num_attention_heads: int, + depthwise_conv_kernel_size: int, + dropout: float = 0.0, + use_group_norm: bool = False, + convolution_first: bool = False, + ) -> None: + super().__init__() + + self.ffn1 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout) + + self.self_attn_layer_norm = torch.nn.LayerNorm(input_dim) + self.self_attn = torch.nn.MultiheadAttention(input_dim, num_attention_heads, dropout=dropout) + self.self_attn_dropout = torch.nn.Dropout(dropout) + + self.conv_module = _ConvolutionModule( + input_dim=input_dim, + num_channels=input_dim, + depthwise_kernel_size=depthwise_conv_kernel_size, + dropout=dropout, + bias=True, + use_group_norm=use_group_norm, + ) + + self.ffn2 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout) + self.final_layer_norm = torch.nn.LayerNorm(input_dim) + self.convolution_first = convolution_first + + def _apply_convolution(self, input: torch.Tensor) -> torch.Tensor: + residual = input + input = input.transpose(0, 1) + input = self.conv_module(input) + input = input.transpose(0, 1) + input = residual + input + return input + + def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor: + r""" + Args: + input (torch.Tensor): input, with shape `(T, B, D)`. + key_padding_mask (torch.Tensor or None): key padding mask to use in self attention layer. + + Returns: + torch.Tensor: output, with shape `(T, B, D)`. 
+ """ + residual = input + x = self.ffn1(input) + x = x * 0.5 + residual + + if self.convolution_first: + x = self._apply_convolution(x) + + residual = x + x = self.self_attn_layer_norm(x) + x, _ = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=key_padding_mask, + need_weights=False, + ) + x = self.self_attn_dropout(x) + x = x + residual + + if not self.convolution_first: + x = self._apply_convolution(x) + + residual = x + x = self.ffn2(x) + x = x * 0.5 + residual + + x = self.final_layer_norm(x) + return x + + +class Conformer(torch.nn.Module): + r"""Conformer architecture introduced in + *Conformer: Convolution-augmented Transformer for Speech Recognition* + :cite:`gulati2020conformer`. + + Args: + input_dim (int): input dimension. + num_heads (int): number of attention heads in each Conformer layer. + ffn_dim (int): hidden layer dimension of feedforward networks. + num_layers (int): number of Conformer layers to instantiate. + depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer. + dropout (float, optional): dropout probability. (Default: 0.0) + use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d`` + in the convolution module. (Default: ``False``) + convolution_first (bool, optional): apply the convolution module ahead of + the attention module. 
(Default: ``False``) + + Examples: + >>> conformer = Conformer( + >>> input_dim=80, + >>> num_heads=4, + >>> ffn_dim=128, + >>> num_layers=4, + >>> depthwise_conv_kernel_size=31, + >>> ) + >>> lengths = torch.randint(1, 400, (10,)) # (batch,) + >>> input = torch.rand(10, int(lengths.max()), input_dim) # (batch, num_frames, input_dim) + >>> output = conformer(input, lengths) + """ + + def __init__( + self, + input_dim: int, + num_heads: int, + ffn_dim: int, + num_layers: int, + depthwise_conv_kernel_size: int, + dropout: float = 0.0, + use_group_norm: bool = False, + convolution_first: bool = False, + ): + super().__init__() + + self.conformer_layers = torch.nn.ModuleList( + [ + ConformerLayer( + input_dim, + ffn_dim, + num_heads, + depthwise_conv_kernel_size, + dropout=dropout, + use_group_norm=use_group_norm, + convolution_first=convolution_first, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + input (torch.Tensor): with shape `(B, T, input_dim)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``input``. + + Returns: + (torch.Tensor, torch.Tensor) + torch.Tensor + output frames, with shape `(B, T, input_dim)` + torch.Tensor + output lengths, with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in output frames. 
+ """ + encoder_padding_mask = _lengths_to_padding_mask(lengths) + + x = input.transpose(0, 1) + for layer in self.conformer_layers: + x = layer(x, encoder_padding_mask) + return x.transpose(0, 1), lengths diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/conv_tasnet.py b/.venv/lib/python3.11/site-packages/torchaudio/models/conv_tasnet.py new file mode 100644 index 0000000000000000000000000000000000000000..770746dd46b34c47736e4607d4344672d0335ef2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/conv_tasnet.py @@ -0,0 +1,330 @@ +"""Implements Conv-TasNet with building blocks of it. + +Based on https://github.com/naplab/Conv-TasNet/tree/e66d82a8f956a69749ec8a4ae382217faa097c5c +""" + +from typing import Optional, Tuple + +import torch + + +class ConvBlock(torch.nn.Module): + """1D Convolutional block. + + Args: + io_channels (int): The number of input/output channels, + hidden_channels (int): The number of channels in the internal layers, . + kernel_size (int): The convolution kernel size of the middle layer,

. + padding (int): Padding value of the convolution in the middle layer. + dilation (int, optional): Dilation value of the convolution in the middle layer. + no_redisual (bool, optional): Disable residual block/output. + + Note: + This implementation corresponds to the "non-causal" setting in the paper. + """ + + def __init__( + self, + io_channels: int, + hidden_channels: int, + kernel_size: int, + padding: int, + dilation: int = 1, + no_residual: bool = False, + ): + super().__init__() + + self.conv_layers = torch.nn.Sequential( + torch.nn.Conv1d(in_channels=io_channels, out_channels=hidden_channels, kernel_size=1), + torch.nn.PReLU(), + torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08), + torch.nn.Conv1d( + in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + groups=hidden_channels, + ), + torch.nn.PReLU(), + torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08), + ) + + self.res_out = ( + None + if no_residual + else torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1) + ) + self.skip_out = torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1) + + def forward(self, input: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + feature = self.conv_layers(input) + if self.res_out is None: + residual = None + else: + residual = self.res_out(feature) + skip_out = self.skip_out(feature) + return residual, skip_out + + +class MaskGenerator(torch.nn.Module): + """TCN (Temporal Convolution Network) Separation Module + + Generates masks for separation. + + Args: + input_dim (int): Input feature dimension, . + num_sources (int): The number of sources to separate. + kernel_size (int): The convolution kernel size of conv blocks,

. + num_featrs (int): Input/output feature dimenstion of conv blocks, . + num_hidden (int): Intermediate feature dimention of conv blocks, + num_layers (int): The number of conv blocks in one stack, . + num_stacks (int): The number of conv block stacks, . + msk_activate (str): The activation function of the mask output. + + Note: + This implementation corresponds to the "non-causal" setting in the paper. + """ + + def __init__( + self, + input_dim: int, + num_sources: int, + kernel_size: int, + num_feats: int, + num_hidden: int, + num_layers: int, + num_stacks: int, + msk_activate: str, + ): + super().__init__() + + self.input_dim = input_dim + self.num_sources = num_sources + + self.input_norm = torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8) + self.input_conv = torch.nn.Conv1d(in_channels=input_dim, out_channels=num_feats, kernel_size=1) + + self.receptive_field = 0 + self.conv_layers = torch.nn.ModuleList([]) + for s in range(num_stacks): + for l in range(num_layers): + multi = 2**l + self.conv_layers.append( + ConvBlock( + io_channels=num_feats, + hidden_channels=num_hidden, + kernel_size=kernel_size, + dilation=multi, + padding=multi, + # The last ConvBlock does not need residual + no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)), + ) + ) + self.receptive_field += kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi + self.output_prelu = torch.nn.PReLU() + self.output_conv = torch.nn.Conv1d( + in_channels=num_feats, + out_channels=input_dim * num_sources, + kernel_size=1, + ) + if msk_activate == "sigmoid": + self.mask_activate = torch.nn.Sigmoid() + elif msk_activate == "relu": + self.mask_activate = torch.nn.ReLU() + else: + raise ValueError(f"Unsupported activation {msk_activate}") + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """Generate separation mask. 
+ + Args: + input (torch.Tensor): 3D Tensor with shape [batch, features, frames] + + Returns: + Tensor: shape [batch, num_sources, features, frames] + """ + batch_size = input.shape[0] + feats = self.input_norm(input) + feats = self.input_conv(feats) + output = 0.0 + for layer in self.conv_layers: + residual, skip = layer(feats) + if residual is not None: # the last conv layer does not produce residual + feats = feats + residual + output = output + skip + output = self.output_prelu(output) + output = self.output_conv(output) + output = self.mask_activate(output) + return output.view(batch_size, self.num_sources, self.input_dim, -1) + + +class ConvTasNet(torch.nn.Module): + """Conv-TasNet architecture introduced in + *Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation* + :cite:`Luo_2019`. + + Note: + This implementation corresponds to the "non-causal" setting in the paper. + + See Also: + * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models. + + Args: + num_sources (int, optional): The number of sources to split. + enc_kernel_size (int, optional): The convolution kernel size of the encoder/decoder, . + enc_num_feats (int, optional): The feature dimensions passed to mask generator, . + msk_kernel_size (int, optional): The convolution kernel size of the mask generator,

. + msk_num_feats (int, optional): The input/output feature dimension of conv block in the mask generator, . + msk_num_hidden_feats (int, optional): The internal feature dimension of conv block of the mask generator, . + msk_num_layers (int, optional): The number of layers in one conv block of the mask generator, . + msk_num_stacks (int, optional): The numbr of conv blocks of the mask generator, . + msk_activate (str, optional): The activation function of the mask output (Default: ``sigmoid``). + """ + + def __init__( + self, + num_sources: int = 2, + # encoder/decoder parameters + enc_kernel_size: int = 16, + enc_num_feats: int = 512, + # mask generator parameters + msk_kernel_size: int = 3, + msk_num_feats: int = 128, + msk_num_hidden_feats: int = 512, + msk_num_layers: int = 8, + msk_num_stacks: int = 3, + msk_activate: str = "sigmoid", + ): + super().__init__() + + self.num_sources = num_sources + self.enc_num_feats = enc_num_feats + self.enc_kernel_size = enc_kernel_size + self.enc_stride = enc_kernel_size // 2 + + self.encoder = torch.nn.Conv1d( + in_channels=1, + out_channels=enc_num_feats, + kernel_size=enc_kernel_size, + stride=self.enc_stride, + padding=self.enc_stride, + bias=False, + ) + self.mask_generator = MaskGenerator( + input_dim=enc_num_feats, + num_sources=num_sources, + kernel_size=msk_kernel_size, + num_feats=msk_num_feats, + num_hidden=msk_num_hidden_feats, + num_layers=msk_num_layers, + num_stacks=msk_num_stacks, + msk_activate=msk_activate, + ) + self.decoder = torch.nn.ConvTranspose1d( + in_channels=enc_num_feats, + out_channels=1, + kernel_size=enc_kernel_size, + stride=self.enc_stride, + padding=self.enc_stride, + bias=False, + ) + + def _align_num_frames_with_strides(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]: + """Pad input Tensor so that the end of the input tensor corresponds with + + 1. (if kernel size is odd) the center of the last convolution kernel + or 2. 
(if kernel size is even) the end of the first half of the last convolution kernel + + Assumption: + The resulting Tensor will be padded with the size of stride (== kernel_width // 2) + on the both ends in Conv1D + + |<--- k_1 --->| + | | |<-- k_n-1 -->| + | | | |<--- k_n --->| + | | | | | + | | | | | + | v v v | + |<---->|<--- input signal --->|<--->|<---->| + stride PAD stride + + Args: + input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames) + + Returns: + Tensor: Padded Tensor + int: Number of paddings performed + """ + batch_size, num_channels, num_frames = input.shape + is_odd = self.enc_kernel_size % 2 + num_strides = (num_frames - is_odd) // self.enc_stride + num_remainings = num_frames - (is_odd + num_strides * self.enc_stride) + if num_remainings == 0: + return input, 0 + + num_paddings = self.enc_stride - num_remainings + pad = torch.zeros( + batch_size, + num_channels, + num_paddings, + dtype=input.dtype, + device=input.device, + ) + return torch.cat([input, pad], 2), num_paddings + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """Perform source separation. Generate audio source waveforms. + + Args: + input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames] + + Returns: + Tensor: 3D Tensor with shape [batch, channel==num_sources, frames] + """ + if input.ndim != 3 or input.shape[1] != 1: + raise ValueError(f"Expected 3D tensor (batch, channel==1, frames). 
Found: {input.shape}") + + # B: batch size + # L: input frame length + # L': padded input frame length + # F: feature dimension + # M: feature frame length + # S: number of sources + + padded, num_pads = self._align_num_frames_with_strides(input) # B, 1, L' + batch_size, num_padded_frames = padded.shape[0], padded.shape[2] + feats = self.encoder(padded) # B, F, M + masked = self.mask_generator(feats) * feats.unsqueeze(1) # B, S, F, M + masked = masked.view(batch_size * self.num_sources, self.enc_num_feats, -1) # B*S, F, M + decoded = self.decoder(masked) # B*S, 1, L' + output = decoded.view(batch_size, self.num_sources, num_padded_frames) # B, S, L' + if num_pads > 0: + output = output[..., :-num_pads] # B, S, L + return output + + +def conv_tasnet_base(num_sources: int = 2) -> ConvTasNet: + r"""Builds non-causal version of :class:`~torchaudio.models.ConvTasNet`. + + The parameter settings follow the ones with the highest Si-SNR metirc score in the paper, + except the mask activation function is changed from "sigmoid" to "relu" for performance improvement. + + Args: + num_sources (int, optional): Number of sources in the output. + (Default: 2) + Returns: + ConvTasNet: + ConvTasNet model. 
+ """ + return ConvTasNet( + num_sources=num_sources, + enc_kernel_size=16, + enc_num_feats=512, + msk_kernel_size=3, + msk_num_feats=128, + msk_num_hidden_feats=512, + msk_num_layers=8, + msk_num_stacks=3, + msk_activate="relu", + ) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2e9b06d52ef7af302a000bb0f572b4c563e12bd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__init__.py @@ -0,0 +1,46 @@ +_CTC_DECODERS = [ + "CTCHypothesis", + "CTCDecoder", + "CTCDecoderLM", + "CTCDecoderLMState", + "ctc_decoder", + "download_pretrained_files", +] +_CUDA_CTC_DECODERS = [ + "CUCTCDecoder", + "CUCTCHypothesis", + "cuda_ctc_decoder", +] + + +def __getattr__(name: str): + if name in _CTC_DECODERS: + try: + from . import _ctc_decoder + except Exception as err: + raise RuntimeError( + "CTC Decoder suit requires flashlight-text package and optionally KenLM. Please install them." + ) from err + + item = getattr(_ctc_decoder, name) + globals()[name] = item + return item + elif name in _CUDA_CTC_DECODERS: + try: + from . import _cuda_ctc_decoder + except AttributeError as err: + raise RuntimeError( + "To use CUCTC decoder, please set BUILD_CUDA_CTC_DECODER=1 when building from source." 
+ ) from err + + item = getattr(_cuda_ctc_decoder, name) + globals()[name] = item + return item + raise AttributeError(f"module {__name__} has no attribute {name}") + + +def __dir__(): + return sorted(__all__) + + +__all__ = _CTC_DECODERS + _CUDA_CTC_DECODERS diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/_ctc_decoder.py b/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/_ctc_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4d45f12f523cc7748e1552ad410557fe9a1f6664 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/_ctc_decoder.py @@ -0,0 +1,568 @@ +from __future__ import annotations + +import itertools as it + +from abc import abstractmethod +from collections import namedtuple +from typing import Dict, List, NamedTuple, Optional, Tuple, Union + +import torch + +from flashlight.lib.text.decoder import ( + CriterionType as _CriterionType, + LexiconDecoder as _LexiconDecoder, + LexiconDecoderOptions as _LexiconDecoderOptions, + LexiconFreeDecoder as _LexiconFreeDecoder, + LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions, + LM as _LM, + LMState as _LMState, + SmearingMode as _SmearingMode, + Trie as _Trie, + ZeroLM as _ZeroLM, +) +from flashlight.lib.text.dictionary import ( + create_word_dict as _create_word_dict, + Dictionary as _Dictionary, + load_words as _load_words, +) +from torchaudio.utils import download_asset + +try: + from flashlight.lib.text.decoder.kenlm import KenLM as _KenLM +except Exception: + try: + from flashlight.lib.text.decoder import KenLM as _KenLM + except Exception: + _KenLM = None + +__all__ = [ + "CTCHypothesis", + "CTCDecoder", + "CTCDecoderLM", + "CTCDecoderLMState", + "ctc_decoder", + "download_pretrained_files", +] + +_PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"]) + + +def _construct_trie(tokens_dict, word_dict, lexicon, lm, silence): + vocab_size = tokens_dict.index_size() + trie = _Trie(vocab_size, 
silence) + start_state = lm.start(False) + + for word, spellings in lexicon.items(): + word_idx = word_dict.get_index(word) + _, score = lm.score(start_state, word_idx) + for spelling in spellings: + spelling_idx = [tokens_dict.get_index(token) for token in spelling] + trie.insert(spelling_idx, word_idx, score) + trie.smear(_SmearingMode.MAX) + return trie + + +def _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word): + word_dict = None + if lm_dict is not None: + word_dict = _Dictionary(lm_dict) + + if lexicon and word_dict is None: + word_dict = _create_word_dict(lexicon) + elif not lexicon and word_dict is None and type(lm) == str: + d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())} + d[unk_word] = [[unk_word]] + word_dict = _create_word_dict(d) + + return word_dict + + +class CTCHypothesis(NamedTuple): + r"""Represents hypothesis generated by CTC beam search decoder :class:`CTCDecoder`.""" + tokens: torch.LongTensor + """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence""" + + words: List[str] + """List of predicted words. + + Note: + This attribute is only applicable if a lexicon is provided to the decoder. If + decoding without a lexicon, it will be blank. Please refer to :attr:`tokens` and + :func:`~torchaudio.models.decoder.CTCDecoder.idxs_to_tokens` instead. + """ + + score: float + """Score corresponding to hypothesis""" + + timesteps: torch.IntTensor + """Timesteps corresponding to the tokens. Shape `(L, )`, where `L` is the length of the output sequence""" + + +class CTCDecoderLMState(_LMState): + """Language model state.""" + + @property + def children(self) -> Dict[int, CTCDecoderLMState]: + """Map of indices to LM states""" + return super().children + + def child(self, usr_index: int) -> CTCDecoderLMState: + """Returns child corresponding to usr_index, or creates and returns a new state if input index + is not found. 
def compare(self, state: CTCDecoderLMState) -> int:
    """Compare two language model states.

    Args:
        state (CTCDecoderLMState): LM state to compare against

    Returns:
        int: 0 if the states are the same, -1 if self is less, +1 if self is greater.
    """
    # Delegate to the base LM state, as ``children`` and ``child`` do. An empty
    # body would shadow the base implementation and return ``None`` instead of
    # the documented integer.
    return super().compare(state)
+ """ + + def __init__( + self, + nbest: int, + lexicon: Optional[Dict], + word_dict: _Dictionary, + tokens_dict: _Dictionary, + lm: CTCDecoderLM, + decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions], + blank_token: str, + sil_token: str, + unk_word: str, + ) -> None: + """ + Args: + nbest (int): number of best decodings to return + lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder + word_dict (_Dictionary): dictionary of words + tokens_dict (_Dictionary): dictionary of tokens + lm (CTCDecoderLM): language model. If using a lexicon, only word level LMs are currently supported + decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions): + parameters used for beam search decoding + blank_token (str): token corresopnding to blank + sil_token (str): token corresponding to silence + unk_word (str): word corresponding to unknown + """ + + self.nbest = nbest + self.word_dict = word_dict + self.tokens_dict = tokens_dict + self.blank = self.tokens_dict.get_index(blank_token) + silence = self.tokens_dict.get_index(sil_token) + transitions = [] + + if lexicon: + trie = _construct_trie(tokens_dict, word_dict, lexicon, lm, silence) + unk_word = word_dict.get_index(unk_word) + token_lm = False # use word level LM + + self.decoder = _LexiconDecoder( + decoder_options, + trie, + lm, + silence, + self.blank, + unk_word, + transitions, + token_lm, + ) + else: + self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions) + # https://github.com/pytorch/audio/issues/3218 + # If lm is passed like rvalue reference, the lm object gets garbage collected, + # and later call to the lm fails. + # This ensures that lm object is not deleted as long as the decoder is alive. 
+ # https://github.com/pybind/pybind11/discussions/4013 + self.lm = lm + + def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor: + idxs = (g[0] for g in it.groupby(idxs)) + idxs = filter(lambda x: x != self.blank, idxs) + return torch.LongTensor(list(idxs)) + + def _get_timesteps(self, idxs: torch.IntTensor) -> torch.IntTensor: + """Returns frame numbers corresponding to non-blank tokens.""" + + timesteps = [] + for i, idx in enumerate(idxs): + if idx == self.blank: + continue + if i == 0 or idx != idxs[i - 1]: + timesteps.append(i) + return torch.IntTensor(timesteps) + + def decode_begin(self): + """Initialize the internal state of the decoder. + + See :py:meth:`decode_step` for the usage. + + .. note:: + + This method is required only when performing online decoding. + It is not necessary when performing batch decoding with :py:meth:`__call__`. + """ + self.decoder.decode_begin() + + def decode_end(self): + """Finalize the internal state of the decoder. + + See :py:meth:`decode_step` for the usage. + + .. note:: + + This method is required only when performing online decoding. + It is not necessary when performing batch decoding with :py:meth:`__call__`. + """ + self.decoder.decode_end() + + def decode_step(self, emissions: torch.FloatTensor): + """Perform incremental decoding on top of the curent internal state. + + .. note:: + + This method is required only when performing online decoding. + It is not necessary when performing batch decoding with :py:meth:`__call__`. + + Args: + emissions (torch.FloatTensor): CPU tensor of shape `(frame, num_tokens)` storing sequences of + probability distribution over labels; output of acoustic model. + + Example: + >>> decoder = torchaudio.models.decoder.ctc_decoder(...) 
+ >>> decoder.decode_begin() + >>> decoder.decode_step(emission1) + >>> decoder.decode_step(emission2) + >>> decoder.decode_end() + >>> result = decoder.get_final_hypothesis() + """ + if emissions.dtype != torch.float32: + raise ValueError("emissions must be float32.") + + if not emissions.is_cpu: + raise RuntimeError("emissions must be a CPU tensor.") + + if not emissions.is_contiguous(): + raise RuntimeError("emissions must be contiguous.") + + if emissions.ndim != 2: + raise RuntimeError(f"emissions must be 2D. Found {emissions.shape}") + + T, N = emissions.size() + self.decoder.decode_step(emissions.data_ptr(), T, N) + + def _to_hypo(self, results) -> List[CTCHypothesis]: + return [ + CTCHypothesis( + tokens=self._get_tokens(result.tokens), + words=[self.word_dict.get_entry(x) for x in result.words if x >= 0], + score=result.score, + timesteps=self._get_timesteps(result.tokens), + ) + for result in results + ] + + def get_final_hypothesis(self) -> List[CTCHypothesis]: + """Get the final hypothesis + + Returns: + List[CTCHypothesis]: + List of sorted best hypotheses. + + .. note:: + + This method is required only when performing online decoding. + It is not necessary when performing batch decoding with :py:meth:`__call__`. + """ + results = self.decoder.get_all_final_hypothesis() + return self._to_hypo(results[: self.nbest]) + + def __call__( + self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None + ) -> List[List[CTCHypothesis]]: + """ + Performs batched offline decoding. + + .. note:: + + This method performs offline decoding in one go. To perform incremental decoding, + please refer to :py:meth:`decode_step`. + + Args: + emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of + probability distribution over labels; output of acoustic model. + lengths (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the valid length of + in time axis of the output Tensor in each batch. 
+ + Returns: + List[List[CTCHypothesis]]: + List of sorted best hypotheses for each audio sequence in the batch. + """ + + if emissions.dtype != torch.float32: + raise ValueError("emissions must be float32.") + + if not emissions.is_cpu: + raise RuntimeError("emissions must be a CPU tensor.") + + if not emissions.is_contiguous(): + raise RuntimeError("emissions must be contiguous.") + + if emissions.ndim != 3: + raise RuntimeError(f"emissions must be 3D. Found {emissions.shape}") + + if lengths is not None and not lengths.is_cpu: + raise RuntimeError("lengths must be a CPU tensor.") + + B, T, N = emissions.size() + if lengths is None: + lengths = torch.full((B,), T) + + float_bytes = 4 + hypos = [] + + for b in range(B): + emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0) + results = self.decoder.decode(emissions_ptr, lengths[b], N) + hypos.append(self._to_hypo(results[: self.nbest])) + return hypos + + def idxs_to_tokens(self, idxs: torch.LongTensor) -> List: + """ + Map raw token IDs into corresponding tokens + + Args: + idxs (LongTensor): raw token IDs generated from decoder + + Returns: + List: tokens corresponding to the input IDs + """ + return [self.tokens_dict.get_entry(idx.item()) for idx in idxs] + + +def ctc_decoder( + lexicon: Optional[str], + tokens: Union[str, List[str]], + lm: Union[str, CTCDecoderLM] = None, + lm_dict: Optional[str] = None, + nbest: int = 1, + beam_size: int = 50, + beam_size_token: Optional[int] = None, + beam_threshold: float = 50, + lm_weight: float = 2, + word_score: float = 0, + unk_score: float = float("-inf"), + sil_score: float = 0, + log_add: bool = False, + blank_token: str = "-", + sil_token: str = "|", + unk_word: str = "", +) -> CTCDecoder: + """Builds an instance of :class:`CTCDecoder`. + + Args: + lexicon (str or None): lexicon file containing the possible words and corresponding spellings. + Each line consists of a word and its space separated spelling. 
If `None`, uses lexicon-free + decoding. + tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected + format is for tokens mapping to the same index to be on the same line + lm (str, CTCDecoderLM, or None, optional): either a path containing KenLM language model, + custom language model of type `CTCDecoderLM`, or `None` if not using a language model + lm_dict (str or None, optional): file consisting of the dictionary used for the LM, with a word + per line sorted by LM index. If decoding with a lexicon, entries in lm_dict must also occur + in the lexicon file. If `None`, dictionary for LM is constructed using the lexicon file. + (Default: None) + nbest (int, optional): number of best decodings to return (Default: 1) + beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50) + beam_size_token (int, optional): max number of tokens to consider at each decode step. + If `None`, it is set to the total number of tokens (Default: None) + beam_threshold (float, optional): threshold for pruning hypothesis (Default: 50) + lm_weight (float, optional): weight of language model (Default: 2) + word_score (float, optional): word insertion score (Default: 0) + unk_score (float, optional): unknown word insertion score (Default: -inf) + sil_score (float, optional): silence insertion score (Default: 0) + log_add (bool, optional): whether or not to use logadd when merging hypotheses (Default: False) + blank_token (str, optional): token corresponding to blank (Default: "-") + sil_token (str, optional): token corresponding to silence (Default: "|") + unk_word (str, optional): word corresponding to unknown (Default: "") + + Returns: + CTCDecoder: decoder + + Example + >>> decoder = ctc_decoder( + >>> lexicon="lexicon.txt", + >>> tokens="tokens.txt", + >>> lm="kenlm.bin", + >>> ) + >>> results = decoder(emissions) # List of shape (B, nbest) of Hypotheses + """ + if lm_dict is not None and type(lm_dict) is not str: + 
raise ValueError("lm_dict must be None or str type.") + + tokens_dict = _Dictionary(tokens) + + # decoder options + if lexicon: + lexicon = _load_words(lexicon) + decoder_options = _LexiconDecoderOptions( + beam_size=beam_size, + beam_size_token=beam_size_token or tokens_dict.index_size(), + beam_threshold=beam_threshold, + lm_weight=lm_weight, + word_score=word_score, + unk_score=unk_score, + sil_score=sil_score, + log_add=log_add, + criterion_type=_CriterionType.CTC, + ) + else: + decoder_options = _LexiconFreeDecoderOptions( + beam_size=beam_size, + beam_size_token=beam_size_token or tokens_dict.index_size(), + beam_threshold=beam_threshold, + lm_weight=lm_weight, + sil_score=sil_score, + log_add=log_add, + criterion_type=_CriterionType.CTC, + ) + + # construct word dict and language model + word_dict = _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word) + + if type(lm) == str: + if _KenLM is None: + raise RuntimeError( + "flashlight-text is installed, but KenLM is not installed. " + "Please refer to https://github.com/kpu/kenlm#python-module for how to install it." + ) + lm = _KenLM(lm, word_dict) + elif lm is None: + lm = _ZeroLM() + + return CTCDecoder( + nbest=nbest, + lexicon=lexicon, + word_dict=word_dict, + tokens_dict=tokens_dict, + lm=lm, + decoder_options=decoder_options, + blank_token=blank_token, + sil_token=sil_token, + unk_word=unk_word, + ) + + +def _get_filenames(model: str) -> _PretrainedFiles: + if model not in ["librispeech", "librispeech-3-gram", "librispeech-4-gram"]: + raise ValueError( + f"{model} not supported. Must be one of ['librispeech-3-gram', 'librispeech-4-gram', 'librispeech']" + ) + + prefix = f"decoder-assets/{model}" + return _PretrainedFiles( + lexicon=f"{prefix}/lexicon.txt", + tokens=f"{prefix}/tokens.txt", + lm=f"{prefix}/lm.bin" if model != "librispeech" else None, + ) + + +def download_pretrained_files(model: str) -> _PretrainedFiles: + """ + Retrieves pretrained data files used for :func:`ctc_decoder`. 
+ + Args: + model (str): pretrained language model to download. + Valid values are: ``"librispeech-3-gram"``, ``"librispeech-4-gram"`` and ``"librispeech"``. + + Returns: + Object with the following attributes + + * ``lm``: path corresponding to downloaded language model, + or ``None`` if the model is not associated with an lm + * ``lexicon``: path corresponding to downloaded lexicon file + * ``tokens``: path corresponding to downloaded tokens file + """ + + files = _get_filenames(model) + lexicon_file = download_asset(files.lexicon) + tokens_file = download_asset(files.tokens) + if files.lm is not None: + lm_file = download_asset(files.lm) + else: + lm_file = None + + return _PretrainedFiles( + lexicon=lexicon_file, + tokens=tokens_file, + lm=lm_file, + ) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/_cuda_ctc_decoder.py b/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/_cuda_ctc_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9b1f509644091e04ea3bdc4301a74c546044f31d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/_cuda_ctc_decoder.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import math + +from typing import List, NamedTuple, Union + +import torch +import torchaudio + +torchaudio._extension._load_lib("libctc_prefix_decoder") +import torchaudio.lib.pybind11_prefixctc as cuctc + + +__all__ = ["CUCTCHypothesis", "CUCTCDecoder", "cuda_ctc_decoder"] + + +def _get_vocab_list(vocab_file): + vocab = [] + with open(vocab_file, "r", encoding="utf-8") as f: + for line in f: + line = line.strip().split() + vocab.append(line[0]) + return vocab + + +class CUCTCHypothesis(NamedTuple): + r"""Represents hypothesis generated by CUCTC beam search decoder :class:`CUCTCDecoder`.""" + tokens: List[int] + """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence""" + + words: List[str] + """List of predicted tokens. 
def __init__(
    self,
    vocab_list: List[str],
    blank_id: int = 0,
    beam_size: int = 10,
    nbest: int = 1,
    blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
    cuda_stream: torch.cuda.streams.Stream = None,
):
    """
    Args:
        vocab_list (List[str]): list of vocabulary tokens
        blank_id (int): token id corresponding to blank, only support 0 for now. (Default: 0)
        beam_size (int, optional): max number of hypos to hold after each decode step;
            capped at the vocabulary size (Default: 10)
        nbest (int): number of best decodings to return
        blank_skip_threshold (float):
            skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding.
            Must be greater than 0 and at most 1. (Default: 0.95).
        cuda_stream (torch.cuda.streams.Stream): using assigned cuda stream (Default: using default stream)

    Raises:
        AssertionError: if ``cuda_stream`` is not a CUDA stream, ``blank_id`` is not 0,
            or ``blank_skip_threshold`` is outside ``(0, 1]``.
    """
    if cuda_stream and not isinstance(cuda_stream, torch.cuda.streams.Stream):
        raise AssertionError("cuda_stream must be torch.cuda.streams.Stream")
    cuda_stream_ = cuda_stream.cuda_stream if cuda_stream else torch.cuda.current_stream().cuda_stream
    self.internal_data = cuctc.prefixCTC_alloc(cuda_stream_)
    self.memory = torch.empty(0, dtype=torch.int8, device=torch.device("cuda"))
    if blank_id != 0:
        raise AssertionError("blank_id must be 0")
    self.blank_id = blank_id
    self.vocab_list = vocab_list
    self.space_id = 0
    self.nbest = nbest
    # The threshold is stored as a logarithm, so 0 must be rejected here:
    # math.log(0) would otherwise fail with an opaque "math domain error".
    if not (0 < blank_skip_threshold <= 1):
        raise AssertionError("blank_skip_threshold must be in the range (0, 1]")
    self.blank_skip_threshold = math.log(blank_skip_threshold)
    self.beam_size = min(beam_size, len(vocab_list))  # beam size must be smaller than vocab size
cuctc.prefixCTC_free(self.internal_data) + + def __call__(self, log_prob: torch.Tensor, encoder_out_lens: torch.Tensor): + """ + Args: + log_prob (torch.FloatTensor): GPU tensor of shape `(batch, frame, num_tokens)` storing sequences of + probability distribution over labels; log_softmax(output of acoustic model). + lengths (dtype torch.int32): GPU tensor of shape `(batch, )` storing the valid length of + in time axis of the output Tensor in each batch. + + Returns: + List[List[CUCTCHypothesis]]: + List of sorted best hypotheses for each audio sequence in the batch. + """ + if not encoder_out_lens.dtype == torch.int32: + raise AssertionError("encoder_out_lens must be torch.int32") + if not log_prob.dtype == torch.float32: + raise AssertionError("log_prob must be torch.float32") + if not (log_prob.is_cuda and encoder_out_lens.is_cuda): + raise AssertionError("inputs must be cuda tensors") + if not (log_prob.is_contiguous() and encoder_out_lens.is_contiguous()): + raise AssertionError("input tensors must be contiguous") + required_size, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2( + self.internal_data, + self.memory.data_ptr(), + self.memory.size(0), + log_prob.data_ptr(), + encoder_out_lens.data_ptr(), + log_prob.size(), + log_prob.stride(), + self.beam_size, + self.blank_id, + self.space_id, + self.blank_skip_threshold, + ) + if required_size > 0: + self.memory = torch.empty(required_size, dtype=torch.int8, device=log_prob.device).contiguous() + _, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2( + self.internal_data, + self.memory.data_ptr(), + self.memory.size(0), + log_prob.data_ptr(), + encoder_out_lens.data_ptr(), + log_prob.size(), + log_prob.stride(), + self.beam_size, + self.blank_id, + self.space_id, + self.blank_skip_threshold, + ) + batch_size = len(score_hyps) + hypos = [] + for i in range(batch_size): + hypos.append( + [ + CUCTCHypothesis( + tokens=score_hyps[i][j][1], + words=[self.vocab_list[word_id] for word_id in 
def cuda_ctc_decoder(
    tokens: Union[str, List[str]],
    nbest: int = 1,
    beam_size: int = 10,
    blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
) -> CUCTCDecoder:
    """Builds an instance of :class:`CUCTCDecoder`.

    The decoder is created with its default blank ID (0), so the blank symbol
    must be the first token.

    Args:
        tokens (str or List[str]): File or list containing valid tokens.
            If using a file, the expected format is for tokens mapping to the same index to be on the same line;
            only the first whitespace-separated field of each line is used.
        nbest (int): The number of best decodings to return (Default: 1)
        beam_size (int, optional): The maximum number of hypos to hold after each decode step (Default: 10)
        blank_skip_threshold (float): skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding
            (Default: 0.95).

    Returns:
        CUCTCDecoder: decoder

    Example
        >>> decoder = cuda_ctc_decoder(
        >>>     tokens="tokens.txt",
        >>>     blank_skip_threshold=0.95,
        >>> )
        >>> results = decoder(log_probs, encoder_out_lens) # List of shape (B, nbest) of Hypotheses
    """
    # A string is treated as the path of a token file; a list is used as is.
    if isinstance(tokens, str):
        tokens = _get_vocab_list(tokens)

    return CUCTCDecoder(vocab_list=tokens, beam_size=beam_size, nbest=nbest, blank_skip_threshold=blank_skip_threshold)
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the linear layer, a ReLU clipped at ``relu_max_clip``, then dropout.

    Dropout is skipped when ``self.dropout`` is falsy (e.g. ``0.0``) and is
    only active in training mode.
    """
    functional = torch.nn.functional
    activated = functional.hardtanh(functional.relu(self.fc(x)), 0, self.relu_max_clip)
    if not self.dropout:
        return activated
    return functional.dropout(activated, self.dropout, self.training)
+ """ + # N x C x T x F + x = self.fc1(x) + # N x C x T x H + x = self.fc2(x) + # N x C x T x H + x = self.fc3(x) + # N x C x T x H + x = x.squeeze(1) + # N x T x H + x = x.transpose(0, 1) + # T x N x H + x, _ = self.bi_rnn(x) + # The fifth (non-recurrent) layer takes both the forward and backward units as inputs + x = x[:, :, : self.n_hidden] + x[:, :, self.n_hidden :] + # T x N x H + x = self.fc4(x) + # T x N x H + x = self.out(x) + # T x N x n_class + x = x.permute(1, 0, 2) + # N x T x n_class + x = torch.nn.functional.log_softmax(x, dim=2) + # N x T x n_class + return x diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/emformer.py b/.venv/lib/python3.11/site-packages/torchaudio/models/emformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9ddd257552ecda94cb55bbc1eed1dae8a5382380 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/emformer.py @@ -0,0 +1,884 @@ +import math +from typing import List, Optional, Tuple + +import torch + + +__all__ = ["Emformer"] + + +def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor: + batch_size = lengths.shape[0] + max_length = int(torch.max(lengths).item()) + padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand( + batch_size, max_length + ) >= lengths.unsqueeze(1) + return padding_mask + + +def _gen_padding_mask( + utterance: torch.Tensor, + right_context: torch.Tensor, + summary: torch.Tensor, + lengths: torch.Tensor, + mems: torch.Tensor, + left_context_key: Optional[torch.Tensor] = None, +) -> Optional[torch.Tensor]: + T = right_context.size(0) + utterance.size(0) + summary.size(0) + B = right_context.size(1) + if B == 1: + padding_mask = None + else: + right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0) + left_context_blocks_length = left_context_key.size(0) if left_context_key is not None else 0 + klengths = lengths + mems.size(0) + right_context_blocks_length + 
left_context_blocks_length + padding_mask = _lengths_to_padding_mask(lengths=klengths) + return padding_mask + + +def _get_activation_module(activation: str) -> torch.nn.Module: + if activation == "relu": + return torch.nn.ReLU() + elif activation == "gelu": + return torch.nn.GELU() + elif activation == "silu": + return torch.nn.SiLU() + else: + raise ValueError(f"Unsupported activation {activation}") + + +def _get_weight_init_gains(weight_init_scale_strategy: Optional[str], num_layers: int) -> List[Optional[float]]: + if weight_init_scale_strategy is None: + return [None for _ in range(num_layers)] + elif weight_init_scale_strategy == "depthwise": + return [1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)] + elif weight_init_scale_strategy == "constant": + return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)] + else: + raise ValueError(f"Unsupported weight_init_scale_strategy value {weight_init_scale_strategy}") + + +def _gen_attention_mask_block( + col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device +) -> torch.Tensor: + if len(col_widths) != len(col_mask): + raise ValueError("Length of col_widths must match that of col_mask") + + mask_block = [ + torch.ones(num_rows, col_width, device=device) + if is_ones_col + else torch.zeros(num_rows, col_width, device=device) + for col_width, is_ones_col in zip(col_widths, col_mask) + ] + return torch.cat(mask_block, dim=1) + + +class _EmformerAttention(torch.nn.Module): + r"""Emformer layer attention module. + + Args: + input_dim (int): input dimension. + num_heads (int): number of attention heads in each Emformer layer. + dropout (float, optional): dropout probability. (Default: 0.0) + weight_init_gain (float or None, optional): scale factor to apply when initializing + attention module parameters. (Default: ``None``) + tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. 
(Default: ``False``) + negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8) + """ + + def __init__( + self, + input_dim: int, + num_heads: int, + dropout: float = 0.0, + weight_init_gain: Optional[float] = None, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + if input_dim % num_heads != 0: + raise ValueError(f"input_dim ({input_dim}) is not a multiple of num_heads ({num_heads}).") + + self.input_dim = input_dim + self.num_heads = num_heads + self.dropout = dropout + self.tanh_on_mem = tanh_on_mem + self.negative_inf = negative_inf + + self.scaling = (self.input_dim // self.num_heads) ** -0.5 + + self.emb_to_key_value = torch.nn.Linear(input_dim, 2 * input_dim, bias=True) + self.emb_to_query = torch.nn.Linear(input_dim, input_dim, bias=True) + self.out_proj = torch.nn.Linear(input_dim, input_dim, bias=True) + + if weight_init_gain: + torch.nn.init.xavier_uniform_(self.emb_to_key_value.weight, gain=weight_init_gain) + torch.nn.init.xavier_uniform_(self.emb_to_query.weight, gain=weight_init_gain) + + def _gen_key_value(self, input: torch.Tensor, mems: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + T, _, _ = input.shape + summary_length = mems.size(0) + 1 + right_ctx_utterance_block = input[: T - summary_length] + mems_right_ctx_utterance_block = torch.cat([mems, right_ctx_utterance_block]) + key, value = self.emb_to_key_value(mems_right_ctx_utterance_block).chunk(chunks=2, dim=2) + return key, value + + def _gen_attention_probs( + self, + attention_weights: torch.Tensor, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + attention_weights_float = attention_weights.float() + attention_weights_float = attention_weights_float.masked_fill(attention_mask.unsqueeze(0), self.negative_inf) + T = attention_weights.size(1) + B = attention_weights.size(0) // self.num_heads + if padding_mask is not None: + attention_weights_float = 
attention_weights_float.view(B, self.num_heads, T, -1) + attention_weights_float = attention_weights_float.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf + ) + attention_weights_float = attention_weights_float.view(B * self.num_heads, T, -1) + attention_probs = torch.nn.functional.softmax(attention_weights_float, dim=-1).type_as(attention_weights) + return torch.nn.functional.dropout(attention_probs, p=float(self.dropout), training=self.training) + + def _forward_impl( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + summary: torch.Tensor, + mems: torch.Tensor, + attention_mask: torch.Tensor, + left_context_key: Optional[torch.Tensor] = None, + left_context_val: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B = utterance.size(1) + T = right_context.size(0) + utterance.size(0) + summary.size(0) + + # Compute query with [right context, utterance, summary]. + query = self.emb_to_query(torch.cat([right_context, utterance, summary])) + + # Compute key and value with [mems, right context, utterance]. + key, value = self.emb_to_key_value(torch.cat([mems, right_context, utterance])).chunk(chunks=2, dim=2) + + if left_context_key is not None and left_context_val is not None: + right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0) + key = torch.cat( + [ + key[: mems.size(0) + right_context_blocks_length], + left_context_key, + key[mems.size(0) + right_context_blocks_length :], + ], + ) + value = torch.cat( + [ + value[: mems.size(0) + right_context_blocks_length], + left_context_val, + value[mems.size(0) + right_context_blocks_length :], + ], + ) + + # Compute attention weights from query, key, and value. 
+ reshaped_query, reshaped_key, reshaped_value = [ + tensor.contiguous().view(-1, B * self.num_heads, self.input_dim // self.num_heads).transpose(0, 1) + for tensor in [query, key, value] + ] + attention_weights = torch.bmm(reshaped_query * self.scaling, reshaped_key.transpose(1, 2)) + + # Compute padding mask. + padding_mask = _gen_padding_mask(utterance, right_context, summary, lengths, mems, left_context_key) + + # Compute attention probabilities. + attention_probs = self._gen_attention_probs(attention_weights, attention_mask, padding_mask) + + # Compute attention. + attention = torch.bmm(attention_probs, reshaped_value) + if attention.shape != ( + B * self.num_heads, + T, + self.input_dim // self.num_heads, + ): + raise AssertionError("Computed attention has incorrect dimensions") + attention = attention.transpose(0, 1).contiguous().view(T, B, self.input_dim) + + # Apply output projection. + output_right_context_mems = self.out_proj(attention) + + summary_length = summary.size(0) + output_right_context = output_right_context_mems[: T - summary_length] + output_mems = output_right_context_mems[T - summary_length :] + if self.tanh_on_mem: + output_mems = torch.tanh(output_mems) + else: + output_mems = torch.clamp(output_mems, min=-10, max=10) + + return output_right_context, output_mems, key, value + + def forward( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + summary: torch.Tensor, + mems: torch.Tensor, + attention_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Forward pass for training. + + B: batch size; + D: feature dimension of each frame; + T: number of utterance frames; + R: number of right context frames; + S: number of summary elements; + M: number of memory elements. + + Args: + utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``utterance``. 
+ right_context (torch.Tensor): right context frames, with shape `(R, B, D)`. + summary (torch.Tensor): summary elements, with shape `(S, B, D)`. + mems (torch.Tensor): memory elements, with shape `(M, B, D)`. + attention_mask (torch.Tensor): attention mask for underlying attention module. + + Returns: + (Tensor, Tensor): + Tensor + output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`. + Tensor + updated memory elements, with shape `(M, B, D)`. + """ + output, output_mems, _, _ = self._forward_impl(utterance, lengths, right_context, summary, mems, attention_mask) + return output, output_mems[:-1] + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + summary: torch.Tensor, + mems: torch.Tensor, + left_context_key: torch.Tensor, + left_context_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Forward pass for inference. + + B: batch size; + D: feature dimension of each frame; + T: number of utterance frames; + R: number of right context frames; + S: number of summary elements; + M: number of memory elements. + + Args: + utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``utterance``. + right_context (torch.Tensor): right context frames, with shape `(R, B, D)`. + summary (torch.Tensor): summary elements, with shape `(S, B, D)`. + mems (torch.Tensor): memory elements, with shape `(M, B, D)`. + left_context_key (torch.Tensor): left context attention key computed from preceding invocation. + left_context_val (torch.Tensor): left context attention value computed from preceding invocation. + + Returns: + (Tensor, Tensor, Tensor, and Tensor): + Tensor + output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`. 
+ Tensor + updated memory elements, with shape `(M, B, D)`. + Tensor + attention key computed for left context and utterance. + Tensor + attention value computed for left context and utterance. + """ + query_dim = right_context.size(0) + utterance.size(0) + summary.size(0) + key_dim = right_context.size(0) + utterance.size(0) + mems.size(0) + left_context_key.size(0) + attention_mask = torch.zeros(query_dim, key_dim).to(dtype=torch.bool, device=utterance.device) + attention_mask[-1, : mems.size(0)] = True + output, output_mems, key, value = self._forward_impl( + utterance, + lengths, + right_context, + summary, + mems, + attention_mask, + left_context_key=left_context_key, + left_context_val=left_context_val, + ) + return ( + output, + output_mems, + key[mems.size(0) + right_context.size(0) :], + value[mems.size(0) + right_context.size(0) :], + ) + + +class _EmformerLayer(torch.nn.Module): + r"""Emformer layer that constitutes Emformer. + + Args: + input_dim (int): input dimension. + num_heads (int): number of attention heads. + ffn_dim: (int): hidden layer dimension of feedforward network. + segment_length (int): length of each input segment. + dropout (float, optional): dropout probability. (Default: 0.0) + activation (str, optional): activation function to use in feedforward network. + Must be one of ("relu", "gelu", "silu"). (Default: "relu") + left_context_length (int, optional): length of left context. (Default: 0) + max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0) + weight_init_gain (float or None, optional): scale factor to apply when initializing + attention module parameters. (Default: ``None``) + tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``) + negative_inf (float, optional): value to use for negative infinity in attention weights. 
(Default: -1e8) + """ + + def __init__( + self, + input_dim: int, + num_heads: int, + ffn_dim: int, + segment_length: int, + dropout: float = 0.0, + activation: str = "relu", + left_context_length: int = 0, + max_memory_size: int = 0, + weight_init_gain: Optional[float] = None, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + self.attention = _EmformerAttention( + input_dim=input_dim, + num_heads=num_heads, + dropout=dropout, + weight_init_gain=weight_init_gain, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + self.dropout = torch.nn.Dropout(dropout) + self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True) + + activation_module = _get_activation_module(activation) + self.pos_ff = torch.nn.Sequential( + torch.nn.LayerNorm(input_dim), + torch.nn.Linear(input_dim, ffn_dim), + activation_module, + torch.nn.Dropout(dropout), + torch.nn.Linear(ffn_dim, input_dim), + torch.nn.Dropout(dropout), + ) + self.layer_norm_input = torch.nn.LayerNorm(input_dim) + self.layer_norm_output = torch.nn.LayerNorm(input_dim) + + self.left_context_length = left_context_length + self.segment_length = segment_length + self.max_memory_size = max_memory_size + self.input_dim = input_dim + + self.use_mem = max_memory_size > 0 + + def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]: + empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device) + left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device) + left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device) + past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device) + return [empty_memory, left_context_key, left_context_val, past_length] + + def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + past_length = 
state[3][0][0].item() + past_left_context_length = min(self.left_context_length, past_length) + past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length)) + pre_mems = state[0][self.max_memory_size - past_mem_length :] + lc_key = state[1][self.left_context_length - past_left_context_length :] + lc_val = state[2][self.left_context_length - past_left_context_length :] + return pre_mems, lc_key, lc_val + + def _pack_state( + self, + next_k: torch.Tensor, + next_v: torch.Tensor, + update_length: int, + mems: torch.Tensor, + state: List[torch.Tensor], + ) -> List[torch.Tensor]: + new_k = torch.cat([state[1], next_k]) + new_v = torch.cat([state[2], next_v]) + state[0] = torch.cat([state[0], mems])[-self.max_memory_size :] + state[1] = new_k[new_k.shape[0] - self.left_context_length :] + state[2] = new_v[new_v.shape[0] - self.left_context_length :] + state[3] = state[3] + update_length + return state + + def _process_attention_output( + self, + rc_output: torch.Tensor, + utterance: torch.Tensor, + right_context: torch.Tensor, + ) -> torch.Tensor: + result = self.dropout(rc_output) + torch.cat([right_context, utterance]) + result = self.pos_ff(result) + result + result = self.layer_norm_output(result) + return result + + def _apply_pre_attention_layer_norm( + self, utterance: torch.Tensor, right_context: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + layer_norm_input = self.layer_norm_input(torch.cat([right_context, utterance])) + return ( + layer_norm_input[right_context.size(0) :], + layer_norm_input[: right_context.size(0)], + ) + + def _apply_post_attention_ffn( + self, rc_output: torch.Tensor, utterance: torch.Tensor, right_context: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + rc_output = self._process_attention_output(rc_output, utterance, right_context) + return rc_output[right_context.size(0) :], rc_output[: right_context.size(0)] + + def _apply_attention_forward( + self, + utterance: torch.Tensor, + lengths: 
torch.Tensor, + right_context: torch.Tensor, + mems: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if attention_mask is None: + raise ValueError("attention_mask must be not None when for_inference is False") + + if self.use_mem: + summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + else: + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + rc_output, next_m = self.attention( + utterance=utterance, + lengths=lengths, + right_context=right_context, + summary=summary, + mems=mems, + attention_mask=attention_mask, + ) + return rc_output, next_m + + def _apply_attention_infer( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + mems: torch.Tensor, + state: Optional[List[torch.Tensor]], + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + if state is None: + state = self._init_state(utterance.size(1), device=utterance.device) + pre_mems, lc_key, lc_val = self._unpack_state(state) + if self.use_mem: + summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + summary = summary[:1] + else: + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + rc_output, next_m, next_k, next_v = self.attention.infer( + utterance=utterance, + lengths=lengths, + right_context=right_context, + summary=summary, + mems=pre_mems, + left_context_key=lc_key, + left_context_val=lc_val, + ) + state = self._pack_state(next_k, next_v, utterance.size(0), mems, state) + return rc_output, next_m, state + + def forward( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + mems: torch.Tensor, + attention_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Forward pass for training. + + B: batch size; + D: feature dimension of each frame; + T: number of utterance frames; + R: number of right context frames; + M: number of memory elements. 
+ + Args: + utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``utterance``. + right_context (torch.Tensor): right context frames, with shape `(R, B, D)`. + mems (torch.Tensor): memory elements, with shape `(M, B, D)`. + attention_mask (torch.Tensor): attention mask for underlying attention module. + + Returns: + (Tensor, Tensor, Tensor): + Tensor + encoded utterance frames, with shape `(T, B, D)`. + Tensor + updated right context frames, with shape `(R, B, D)`. + Tensor + updated memory elements, with shape `(M, B, D)`. + """ + ( + layer_norm_utterance, + layer_norm_right_context, + ) = self._apply_pre_attention_layer_norm(utterance, right_context) + rc_output, output_mems = self._apply_attention_forward( + layer_norm_utterance, + lengths, + layer_norm_right_context, + mems, + attention_mask, + ) + output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context) + return output_utterance, output_right_context, output_mems + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + state: Optional[List[torch.Tensor]], + mems: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]: + r"""Forward pass for inference. + + B: batch size; + D: feature dimension of each frame; + T: number of utterance frames; + R: number of right context frames; + M: number of memory elements. + + Args: + utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``utterance``. + right_context (torch.Tensor): right context frames, with shape `(R, B, D)`. 
+ state (List[torch.Tensor] or None): list of tensors representing layer internal state + generated in preceding invocation of ``infer``. + mems (torch.Tensor): memory elements, with shape `(M, B, D)`. + + Returns: + (Tensor, Tensor, List[torch.Tensor], Tensor): + Tensor + encoded utterance frames, with shape `(T, B, D)`. + Tensor + updated right context frames, with shape `(R, B, D)`. + List[Tensor] + list of tensors representing layer internal state + generated in current invocation of ``infer``. + Tensor + updated memory elements, with shape `(M, B, D)`. + """ + ( + layer_norm_utterance, + layer_norm_right_context, + ) = self._apply_pre_attention_layer_norm(utterance, right_context) + rc_output, output_mems, output_state = self._apply_attention_infer( + layer_norm_utterance, lengths, layer_norm_right_context, mems, state + ) + output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context) + return output_utterance, output_right_context, output_state, output_mems + + +class _EmformerImpl(torch.nn.Module): + def __init__( + self, + emformer_layers: torch.nn.ModuleList, + segment_length: int, + left_context_length: int = 0, + right_context_length: int = 0, + max_memory_size: int = 0, + ): + super().__init__() + + self.use_mem = max_memory_size > 0 + self.memory_op = torch.nn.AvgPool1d( + kernel_size=segment_length, + stride=segment_length, + ceil_mode=True, + ) + self.emformer_layers = emformer_layers + self.left_context_length = left_context_length + self.right_context_length = right_context_length + self.segment_length = segment_length + self.max_memory_size = max_memory_size + + def _gen_right_context(self, input: torch.Tensor) -> torch.Tensor: + T = input.shape[0] + num_segs = math.ceil((T - self.right_context_length) / self.segment_length) + right_context_blocks = [] + for seg_idx in range(num_segs - 1): + start = (seg_idx + 1) * self.segment_length + end = start + self.right_context_length + 
right_context_blocks.append(input[start:end]) + right_context_blocks.append(input[T - self.right_context_length :]) + return torch.cat(right_context_blocks) + + def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) -> List[int]: + num_segs = math.ceil(utterance_length / self.segment_length) + rc = self.right_context_length + lc = self.left_context_length + rc_start = seg_idx * rc + rc_end = rc_start + rc + seg_start = max(seg_idx * self.segment_length - lc, 0) + seg_end = min((seg_idx + 1) * self.segment_length, utterance_length) + rc_length = self.right_context_length * num_segs + + if self.use_mem: + m_start = max(seg_idx - self.max_memory_size, 0) + mem_length = num_segs - 1 + col_widths = [ + m_start, # before memory + seg_idx - m_start, # memory + mem_length - seg_idx, # after memory + rc_start, # before right context + rc, # right context + rc_length - rc_end, # after right context + seg_start, # before query segment + seg_end - seg_start, # query segment + utterance_length - seg_end, # after query segment + ] + else: + col_widths = [ + rc_start, # before right context + rc, # right context + rc_length - rc_end, # after right context + seg_start, # before query segment + seg_end - seg_start, # query segment + utterance_length - seg_end, # after query segment + ] + + return col_widths + + def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor: + utterance_length = input.size(0) + num_segs = math.ceil(utterance_length / self.segment_length) + + rc_mask = [] + query_mask = [] + summary_mask = [] + + if self.use_mem: + num_cols = 9 + # memory, right context, query segment + rc_q_cols_mask = [idx in [1, 4, 7] for idx in range(num_cols)] + # right context, query segment + s_cols_mask = [idx in [4, 7] for idx in range(num_cols)] + masks_to_concat = [rc_mask, query_mask, summary_mask] + else: + num_cols = 6 + # right context, query segment + rc_q_cols_mask = [idx in [1, 4] for idx in range(num_cols)] + s_cols_mask = None + 
masks_to_concat = [rc_mask, query_mask] + + for seg_idx in range(num_segs): + col_widths = self._gen_attention_mask_col_widths(seg_idx, utterance_length) + + rc_mask_block = _gen_attention_mask_block( + col_widths, rc_q_cols_mask, self.right_context_length, input.device + ) + rc_mask.append(rc_mask_block) + + query_mask_block = _gen_attention_mask_block( + col_widths, + rc_q_cols_mask, + min( + self.segment_length, + utterance_length - seg_idx * self.segment_length, + ), + input.device, + ) + query_mask.append(query_mask_block) + + if s_cols_mask is not None: + summary_mask_block = _gen_attention_mask_block(col_widths, s_cols_mask, 1, input.device) + summary_mask.append(summary_mask_block) + + attention_mask = (1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])).to(torch.bool) + return attention_mask + + def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Forward pass for training and non-streaming inference. + + B: batch size; + T: max number of input frames in batch; + D: feature dimension of each frame. + + Args: + input (torch.Tensor): utterance frames right-padded with right context frames, with + shape `(B, T + right_context_length, D)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid utterance frames for i-th batch element in ``input``. + + Returns: + (Tensor, Tensor): + Tensor + output frames, with shape `(B, T, D)`. + Tensor + output lengths, with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in output frames. 
+ """ + input = input.permute(1, 0, 2) + right_context = self._gen_right_context(input) + utterance = input[: input.size(0) - self.right_context_length] + attention_mask = self._gen_attention_mask(utterance) + mems = ( + self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] + if self.use_mem + else torch.empty(0).to(dtype=input.dtype, device=input.device) + ) + output = utterance + for layer in self.emformer_layers: + output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask) + return output.permute(1, 0, 2), lengths + + @torch.jit.export + def infer( + self, + input: torch.Tensor, + lengths: torch.Tensor, + states: Optional[List[List[torch.Tensor]]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + r"""Forward pass for streaming inference. + + B: batch size; + D: feature dimension of each frame. + + Args: + input (torch.Tensor): utterance frames right-padded with right context frames, with + shape `(B, segment_length + right_context_length, D)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``input``. + states (List[List[torch.Tensor]] or None, optional): list of lists of tensors + representing internal state generated in preceding invocation of ``infer``. (Default: ``None``) + + Returns: + (Tensor, Tensor, List[List[Tensor]]): + Tensor + output frames, with shape `(B, segment_length, D)`. + Tensor + output lengths, with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in output frames. + List[List[Tensor]] + output states; list of lists of tensors representing internal state + generated in current invocation of ``infer``. 
+ """ + if input.size(1) != self.segment_length + self.right_context_length: + raise ValueError( + "Per configured segment_length and right_context_length" + f", expected size of {self.segment_length + self.right_context_length} for dimension 1 of input" + f", but got {input.size(1)}." + ) + input = input.permute(1, 0, 2) + right_context_start_idx = input.size(0) - self.right_context_length + right_context = input[right_context_start_idx:] + utterance = input[:right_context_start_idx] + output_lengths = torch.clamp(lengths - self.right_context_length, min=0) + mems = ( + self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + if self.use_mem + else torch.empty(0).to(dtype=input.dtype, device=input.device) + ) + output = utterance + output_states: List[List[torch.Tensor]] = [] + for layer_idx, layer in enumerate(self.emformer_layers): + output, right_context, output_state, mems = layer.infer( + output, + output_lengths, + right_context, + None if states is None else states[layer_idx], + mems, + ) + output_states.append(output_state) + + return output.permute(1, 0, 2), output_lengths, output_states + + +class Emformer(_EmformerImpl): + r"""Emformer architecture introduced in + *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition* + :cite:`shi2021emformer`. + + See Also: + * :func:`~torchaudio.models.emformer_rnnt_model`, + :func:`~torchaudio.models.emformer_rnnt_base`: factory functions. + * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipelines with pretrained model. + + Args: + input_dim (int): input dimension. + num_heads (int): number of attention heads in each Emformer layer. + ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network. + num_layers (int): number of Emformer layers to instantiate. + segment_length (int): length of each input segment. + dropout (float, optional): dropout probability. 
(Default: 0.0) + activation (str, optional): activation function to use in each Emformer layer's + feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu") + left_context_length (int, optional): length of left context. (Default: 0) + right_context_length (int, optional): length of right context. (Default: 0) + max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0) + weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling + strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise") + tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``) + negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8) + + Examples: + >>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1) + >>> input = torch.rand(128, 400, 512) # batch, num_frames, feature_dim + >>> lengths = torch.randint(1, 200, (128,)) # batch + >>> output, lengths = emformer(input, lengths) + >>> input = torch.rand(128, 5, 512) + >>> lengths = torch.ones(128) * 5 + >>> output, lengths, states = emformer.infer(input, lengths, None) + """ + + def __init__( + self, + input_dim: int, + num_heads: int, + ffn_dim: int, + num_layers: int, + segment_length: int, + dropout: float = 0.0, + activation: str = "relu", + left_context_length: int = 0, + right_context_length: int = 0, + max_memory_size: int = 0, + weight_init_scale_strategy: Optional[str] = "depthwise", + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers) + emformer_layers = torch.nn.ModuleList( + [ + _EmformerLayer( + input_dim, + num_heads, + ffn_dim, + segment_length, + dropout=dropout, + activation=activation, + left_context_length=left_context_length, + max_memory_size=max_memory_size, + weight_init_gain=weight_init_gains[layer_idx], + 
tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + for layer_idx in range(num_layers) + ] + ) + super().__init__( + emformer_layers, + segment_length, + left_context_length=left_context_length, + right_context_length=right_context_length, + max_memory_size=max_memory_size, + ) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/rnnt.py b/.venv/lib/python3.11/site-packages/torchaudio/models/rnnt.py new file mode 100644 index 0000000000000000000000000000000000000000..f9dbe22c9fb4a97cf7f8779a953b5bd7b5bbffd9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/rnnt.py @@ -0,0 +1,816 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple + +import torch +from torchaudio.models import Emformer + + +__all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"] + + +class _TimeReduction(torch.nn.Module): + r"""Coalesces frames along time dimension into a + fewer number of frames with higher feature dimensionality. + + Args: + stride (int): number of frames to merge for each output frame. + """ + + def __init__(self, stride: int) -> None: + super().__init__() + self.stride = stride + + def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Forward pass. + + B: batch size; + T: maximum input sequence length in batch; + D: feature dimension of each input sequence frame. + + Args: + input (torch.Tensor): input sequences, with shape `(B, T, D)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``input``. + + Returns: + (torch.Tensor, torch.Tensor): + torch.Tensor + output sequences, with shape + `(B, T // stride, D * stride)` + torch.Tensor + output lengths, with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in output sequences. 
+ """ + B, T, D = input.shape + num_frames = T - (T % self.stride) + input = input[:, :num_frames, :] + lengths = lengths.div(self.stride, rounding_mode="trunc") + T_max = num_frames // self.stride + + output = input.reshape(B, T_max, D * self.stride) + output = output.contiguous() + return output, lengths + + +class _CustomLSTM(torch.nn.Module): + r"""Custom long-short-term memory (LSTM) block that applies layer normalization + to internal nodes. + + Args: + input_dim (int): input dimension. + hidden_dim (int): hidden dimension. + layer_norm (bool, optional): if ``True``, enables layer normalization. (Default: ``False``) + layer_norm_epsilon (float, optional): value of epsilon to use in + layer normalization layers (Default: 1e-5) + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + layer_norm: bool = False, + layer_norm_epsilon: float = 1e-5, + ) -> None: + super().__init__() + self.x2g = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=(not layer_norm)) + self.p2g = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=False) + if layer_norm: + self.c_norm = torch.nn.LayerNorm(hidden_dim, eps=layer_norm_epsilon) + self.g_norm = torch.nn.LayerNorm(4 * hidden_dim, eps=layer_norm_epsilon) + else: + self.c_norm = torch.nn.Identity() + self.g_norm = torch.nn.Identity() + + self.hidden_dim = hidden_dim + + def forward( + self, input: torch.Tensor, state: Optional[List[torch.Tensor]] + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + r"""Forward pass. + + B: batch size; + T: maximum sequence length in batch; + D: feature dimension of each input sequence element. + + Args: + input (torch.Tensor): with shape `(T, B, D)`. + state (List[torch.Tensor] or None): list of tensors + representing internal state generated in preceding invocation + of ``forward``. + + Returns: + (torch.Tensor, List[torch.Tensor]): + torch.Tensor + output, with shape `(T, B, hidden_dim)`. 
+ List[torch.Tensor] + list of tensors representing internal state generated + in current invocation of ``forward``. + """ + if state is None: + B = input.size(1) + h = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype) + c = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype) + else: + h, c = state + + gated_input = self.x2g(input) + outputs = [] + for gates in gated_input.unbind(0): + gates = gates + self.p2g(h) + gates = self.g_norm(gates) + input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1) + input_gate = input_gate.sigmoid() + forget_gate = forget_gate.sigmoid() + cell_gate = cell_gate.tanh() + output_gate = output_gate.sigmoid() + c = forget_gate * c + input_gate * cell_gate + c = self.c_norm(c) + h = output_gate * c.tanh() + outputs.append(h) + + output = torch.stack(outputs, dim=0) + state = [h, c] + + return output, state + + +class _Transcriber(ABC): + @abstractmethod + def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + pass + + @abstractmethod + def infer( + self, + input: torch.Tensor, + lengths: torch.Tensor, + states: Optional[List[List[torch.Tensor]]], + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + pass + + +class _EmformerEncoder(torch.nn.Module, _Transcriber): + r"""Emformer-based recurrent neural network transducer (RNN-T) encoder (transcription network). + + Args: + input_dim (int): feature dimension of each input sequence element. + output_dim (int): feature dimension of each output sequence element. + segment_length (int): length of input segment expressed as number of frames. + right_context_length (int): length of right context expressed as number of frames. + time_reduction_input_dim (int): dimension to scale each element in input sequences to + prior to applying time reduction block. + time_reduction_stride (int): factor by which to reduce length of input sequence. 
+ transformer_num_heads (int): number of attention heads in each Emformer layer. + transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network. + transformer_num_layers (int): number of Emformer layers to instantiate. + transformer_left_context_length (int): length of left context. + transformer_dropout (float, optional): transformer dropout probability. (Default: 0.0) + transformer_activation (str, optional): activation function to use in each Emformer layer's + feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu") + transformer_max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0) + transformer_weight_init_scale_strategy (str, optional): per-layer weight initialization scaling + strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise") + transformer_tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``) + """ + + def __init__( + self, + *, + input_dim: int, + output_dim: int, + segment_length: int, + right_context_length: int, + time_reduction_input_dim: int, + time_reduction_stride: int, + transformer_num_heads: int, + transformer_ffn_dim: int, + transformer_num_layers: int, + transformer_left_context_length: int, + transformer_dropout: float = 0.0, + transformer_activation: str = "relu", + transformer_max_memory_size: int = 0, + transformer_weight_init_scale_strategy: str = "depthwise", + transformer_tanh_on_mem: bool = False, + ) -> None: + super().__init__() + self.input_linear = torch.nn.Linear( + input_dim, + time_reduction_input_dim, + bias=False, + ) + self.time_reduction = _TimeReduction(time_reduction_stride) + transformer_input_dim = time_reduction_input_dim * time_reduction_stride + self.transformer = Emformer( + transformer_input_dim, + transformer_num_heads, + transformer_ffn_dim, + transformer_num_layers, + segment_length // time_reduction_stride, + dropout=transformer_dropout, + 
activation=transformer_activation, + left_context_length=transformer_left_context_length, + right_context_length=right_context_length // time_reduction_stride, + max_memory_size=transformer_max_memory_size, + weight_init_scale_strategy=transformer_weight_init_scale_strategy, + tanh_on_mem=transformer_tanh_on_mem, + ) + self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim) + self.layer_norm = torch.nn.LayerNorm(output_dim) + + def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Forward pass for training. + + B: batch size; + T: maximum input sequence length in batch; + D: feature dimension of each input sequence frame (input_dim). + + Args: + input (torch.Tensor): input frame sequences right-padded with right context, with + shape `(B, T + right context length, D)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``input``. + + Returns: + (torch.Tensor, torch.Tensor): + torch.Tensor + output frame sequences, with + shape `(B, T // time_reduction_stride, output_dim)`. + torch.Tensor + output input lengths, with shape `(B,)` and i-th element representing + number of valid elements for i-th batch element in output frame sequences. + """ + input_linear_out = self.input_linear(input) + time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths) + transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths) + output_linear_out = self.output_linear(transformer_out) + layer_norm_out = self.layer_norm(output_linear_out) + return layer_norm_out, transformer_lengths + + @torch.jit.export + def infer( + self, + input: torch.Tensor, + lengths: torch.Tensor, + states: Optional[List[List[torch.Tensor]]], + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + r"""Forward pass for inference. 
+ + B: batch size; + T: maximum input sequence segment length in batch; + D: feature dimension of each input sequence frame (input_dim). + + Args: + input (torch.Tensor): input frame sequence segments right-padded with right context, with + shape `(B, T + right context length, D)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``input``. + state (List[List[torch.Tensor]] or None): list of lists of tensors + representing internal state generated in preceding invocation + of ``infer``. + + Returns: + (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]): + torch.Tensor + output frame sequences, with + shape `(B, T // time_reduction_stride, output_dim)`. + torch.Tensor + output input lengths, with shape `(B,)` and i-th element representing + number of valid elements for i-th batch element in output. + List[List[torch.Tensor]] + output states; list of lists of tensors + representing internal state generated in current invocation + of ``infer``. + """ + input_linear_out = self.input_linear(input) + time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths) + ( + transformer_out, + transformer_lengths, + transformer_states, + ) = self.transformer.infer(time_reduction_out, time_reduction_lengths, states) + output_linear_out = self.output_linear(transformer_out) + layer_norm_out = self.layer_norm(output_linear_out) + return layer_norm_out, transformer_lengths, transformer_states + + +class _Predictor(torch.nn.Module): + r"""Recurrent neural network transducer (RNN-T) prediction network. + + Args: + num_symbols (int): size of target token lexicon. + output_dim (int): feature dimension of each output sequence element. + symbol_embedding_dim (int): dimension of each target token embedding. + num_lstm_layers (int): number of LSTM layers to instantiate. + lstm_hidden_dim (int): output dimension of each LSTM layer. 
+ lstm_layer_norm (bool, optional): if ``True``, enables layer normalization + for LSTM layers. (Default: ``False``) + lstm_layer_norm_epsilon (float, optional): value of epsilon to use in + LSTM layer normalization layers. (Default: 1e-5) + lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0) + + """ + + def __init__( + self, + num_symbols: int, + output_dim: int, + symbol_embedding_dim: int, + num_lstm_layers: int, + lstm_hidden_dim: int, + lstm_layer_norm: bool = False, + lstm_layer_norm_epsilon: float = 1e-5, + lstm_dropout: float = 0.0, + ) -> None: + super().__init__() + self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim) + self.input_layer_norm = torch.nn.LayerNorm(symbol_embedding_dim) + self.lstm_layers = torch.nn.ModuleList( + [ + _CustomLSTM( + symbol_embedding_dim if idx == 0 else lstm_hidden_dim, + lstm_hidden_dim, + layer_norm=lstm_layer_norm, + layer_norm_epsilon=lstm_layer_norm_epsilon, + ) + for idx in range(num_lstm_layers) + ] + ) + self.dropout = torch.nn.Dropout(p=lstm_dropout) + self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim) + self.output_layer_norm = torch.nn.LayerNorm(output_dim) + + self.lstm_dropout = lstm_dropout + + def forward( + self, + input: torch.Tensor, + lengths: torch.Tensor, + state: Optional[List[List[torch.Tensor]]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + r"""Forward pass. + + B: batch size; + U: maximum sequence length in batch; + D: feature dimension of each input sequence element. + + Args: + input (torch.Tensor): target sequences, with shape `(B, U)` and each element + mapping to a target symbol, i.e. in range `[0, num_symbols)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``input``. + state (List[List[torch.Tensor]] or None, optional): list of lists of tensors + representing internal state generated in preceding invocation + of ``forward``. 
(Default: ``None``) + + Returns: + (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]): + torch.Tensor + output encoding sequences, with shape `(B, U, output_dim)` + torch.Tensor + output lengths, with shape `(B,)` and i-th element representing + number of valid elements for i-th batch element in output encoding sequences. + List[List[torch.Tensor]] + output states; list of lists of tensors + representing internal state generated in current invocation of ``forward``. + """ + input_tb = input.permute(1, 0) + embedding_out = self.embedding(input_tb) + input_layer_norm_out = self.input_layer_norm(embedding_out) + + lstm_out = input_layer_norm_out + state_out: List[List[torch.Tensor]] = [] + for layer_idx, lstm in enumerate(self.lstm_layers): + lstm_out, lstm_state_out = lstm(lstm_out, None if state is None else state[layer_idx]) + lstm_out = self.dropout(lstm_out) + state_out.append(lstm_state_out) + + linear_out = self.linear(lstm_out) + output_layer_norm_out = self.output_layer_norm(linear_out) + return output_layer_norm_out.permute(1, 0, 2), lengths, state_out + + +class _Joiner(torch.nn.Module): + r"""Recurrent neural network transducer (RNN-T) joint network. + + Args: + input_dim (int): source and target input dimension. + output_dim (int): output dimension. + activation (str, optional): activation function to use in the joiner. + Must be one of ("relu", "tanh"). 
(Default: "relu") + + """ + + def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None: + super().__init__() + self.linear = torch.nn.Linear(input_dim, output_dim, bias=True) + if activation == "relu": + self.activation = torch.nn.ReLU() + elif activation == "tanh": + self.activation = torch.nn.Tanh() + else: + raise ValueError(f"Unsupported activation {activation}") + + def forward( + self, + source_encodings: torch.Tensor, + source_lengths: torch.Tensor, + target_encodings: torch.Tensor, + target_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Forward pass for training. + + B: batch size; + T: maximum source sequence length in batch; + U: maximum target sequence length in batch; + D: dimension of each source and target sequence encoding. + + Args: + source_encodings (torch.Tensor): source encoding sequences, with + shape `(B, T, D)`. + source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + valid sequence length of i-th batch element in ``source_encodings``. + target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`. + target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + valid sequence length of i-th batch element in ``target_encodings``. + + Returns: + (torch.Tensor, torch.Tensor, torch.Tensor): + torch.Tensor + joint network output, with shape `(B, T, U, output_dim)`. + torch.Tensor + output source lengths, with shape `(B,)` and i-th element representing + number of valid elements along dim 1 for i-th batch element in joint network output. + torch.Tensor + output target lengths, with shape `(B,)` and i-th element representing + number of valid elements along dim 2 for i-th batch element in joint network output. 
+ """ + joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous() + activation_out = self.activation(joint_encodings) + output = self.linear(activation_out) + return output, source_lengths, target_lengths + + +class RNNT(torch.nn.Module): + r"""torchaudio.models.RNNT() + + Recurrent neural network transducer (RNN-T) model. + + Note: + To build the model, please use one of the factory functions. + + See Also: + :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pre-trained models. + + Args: + transcriber (torch.nn.Module): transcription network. + predictor (torch.nn.Module): prediction network. + joiner (torch.nn.Module): joint network. + """ + + def __init__(self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner) -> None: + super().__init__() + self.transcriber = transcriber + self.predictor = predictor + self.joiner = joiner + + def forward( + self, + sources: torch.Tensor, + source_lengths: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + predictor_state: Optional[List[List[torch.Tensor]]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + r"""Forward pass for training. + + B: batch size; + T: maximum source sequence length in batch; + U: maximum target sequence length in batch; + D: feature dimension of each source sequence element. + + Args: + sources (torch.Tensor): source frame sequences right-padded with right context, with + shape `(B, T, D)`. + source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``sources``. + targets (torch.Tensor): target sequences, with shape `(B, U)` and each element + mapping to a target symbol. + target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``targets``. 
+ predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors + representing prediction network internal state generated in preceding invocation + of ``forward``. (Default: ``None``) + + Returns: + (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]): + torch.Tensor + joint network output, with shape + `(B, max output source length, max output target length, output_dim (number of target symbols))`. + torch.Tensor + output source lengths, with shape `(B,)` and i-th element representing + number of valid elements along dim 1 for i-th batch element in joint network output. + torch.Tensor + output target lengths, with shape `(B,)` and i-th element representing + number of valid elements along dim 2 for i-th batch element in joint network output. + List[List[torch.Tensor]] + output states; list of lists of tensors + representing prediction network internal state generated in current invocation + of ``forward``. + """ + source_encodings, source_lengths = self.transcriber( + input=sources, + lengths=source_lengths, + ) + target_encodings, target_lengths, predictor_state = self.predictor( + input=targets, + lengths=target_lengths, + state=predictor_state, + ) + output, source_lengths, target_lengths = self.joiner( + source_encodings=source_encodings, + source_lengths=source_lengths, + target_encodings=target_encodings, + target_lengths=target_lengths, + ) + + return ( + output, + source_lengths, + target_lengths, + predictor_state, + ) + + @torch.jit.export + def transcribe_streaming( + self, + sources: torch.Tensor, + source_lengths: torch.Tensor, + state: Optional[List[List[torch.Tensor]]], + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + r"""Applies transcription network to sources in streaming mode. + + B: batch size; + T: maximum source sequence segment length in batch; + D: feature dimension of each source sequence frame. 
+ + Args: + sources (torch.Tensor): source frame sequence segments right-padded with right context, with + shape `(B, T + right context length, D)`. + source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``sources``. + state (List[List[torch.Tensor]] or None): list of lists of tensors + representing transcription network internal state generated in preceding invocation + of ``transcribe_streaming``. + + Returns: + (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]): + torch.Tensor + output frame sequences, with + shape `(B, T // time_reduction_stride, output_dim)`. + torch.Tensor + output lengths, with shape `(B,)` and i-th element representing + number of valid elements for i-th batch element in output. + List[List[torch.Tensor]] + output states; list of lists of tensors + representing transcription network internal state generated in current invocation + of ``transcribe_streaming``. + """ + return self.transcriber.infer(sources, source_lengths, state) + + @torch.jit.export + def transcribe( + self, + sources: torch.Tensor, + source_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Applies transcription network to sources in non-streaming mode. + + B: batch size; + T: maximum source sequence length in batch; + D: feature dimension of each source sequence frame. + + Args: + sources (torch.Tensor): source frame sequences right-padded with right context, with + shape `(B, T + right context length, D)`. + source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``sources``. + + Returns: + (torch.Tensor, torch.Tensor): + torch.Tensor + output frame sequences, with + shape `(B, T // time_reduction_stride, output_dim)`. + torch.Tensor + output lengths, with shape `(B,)` and i-th element representing + number of valid elements for i-th batch element in output frame sequences. 
+ """ + return self.transcriber(sources, source_lengths) + + @torch.jit.export + def predict( + self, + targets: torch.Tensor, + target_lengths: torch.Tensor, + state: Optional[List[List[torch.Tensor]]], + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + r"""Applies prediction network to targets. + + B: batch size; + U: maximum target sequence length in batch; + D: feature dimension of each target sequence frame. + + Args: + targets (torch.Tensor): target sequences, with shape `(B, U)` and each element + mapping to a target symbol, i.e. in range `[0, num_symbols)`. + target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``targets``. + state (List[List[torch.Tensor]] or None): list of lists of tensors + representing internal state generated in preceding invocation + of ``predict``. + + Returns: + (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]): + torch.Tensor + output frame sequences, with shape `(B, U, output_dim)`. + torch.Tensor + output lengths, with shape `(B,)` and i-th element representing + number of valid elements for i-th batch element in output. + List[List[torch.Tensor]] + output states; list of lists of tensors + representing internal state generated in current invocation of ``predict``. + """ + return self.predictor(input=targets, lengths=target_lengths, state=state) + + @torch.jit.export + def join( + self, + source_encodings: torch.Tensor, + source_lengths: torch.Tensor, + target_encodings: torch.Tensor, + target_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Applies joint network to source and target encodings. + + B: batch size; + T: maximum source sequence length in batch; + U: maximum target sequence length in batch; + D: dimension of each source and target sequence encoding. + + Args: + source_encodings (torch.Tensor): source encoding sequences, with + shape `(B, T, D)`. 
+ source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + valid sequence length of i-th batch element in ``source_encodings``. + target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`. + target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + valid sequence length of i-th batch element in ``target_encodings``. + + Returns: + (torch.Tensor, torch.Tensor, torch.Tensor): + torch.Tensor + joint network output, with shape `(B, T, U, output_dim)`. + torch.Tensor + output source lengths, with shape `(B,)` and i-th element representing + number of valid elements along dim 1 for i-th batch element in joint network output. + torch.Tensor + output target lengths, with shape `(B,)` and i-th element representing + number of valid elements along dim 2 for i-th batch element in joint network output. + """ + output, source_lengths, target_lengths = self.joiner( + source_encodings=source_encodings, + source_lengths=source_lengths, + target_encodings=target_encodings, + target_lengths=target_lengths, + ) + return output, source_lengths, target_lengths + + +def emformer_rnnt_model( + *, + input_dim: int, + encoding_dim: int, + num_symbols: int, + segment_length: int, + right_context_length: int, + time_reduction_input_dim: int, + time_reduction_stride: int, + transformer_num_heads: int, + transformer_ffn_dim: int, + transformer_num_layers: int, + transformer_dropout: float, + transformer_activation: str, + transformer_left_context_length: int, + transformer_max_memory_size: int, + transformer_weight_init_scale_strategy: str, + transformer_tanh_on_mem: bool, + symbol_embedding_dim: int, + num_lstm_layers: int, + lstm_layer_norm: bool, + lstm_layer_norm_epsilon: float, + lstm_dropout: float, +) -> RNNT: + r"""Builds Emformer-based :class:`~torchaudio.models.RNNT`. 
+ + Note: + For non-streaming inference, the expectation is for `transcribe` to be called on input + sequences right-concatenated with `right_context_length` frames. + + For streaming inference, the expectation is for `transcribe_streaming` to be called + on input chunks comprising `segment_length` frames right-concatenated with `right_context_length` + frames. + + Args: + input_dim (int): dimension of input sequence frames passed to transcription network. + encoding_dim (int): dimension of transcription- and prediction-network-generated encodings + passed to joint network. + num_symbols (int): cardinality of set of target tokens. + segment_length (int): length of input segment expressed as number of frames. + right_context_length (int): length of right context expressed as number of frames. + time_reduction_input_dim (int): dimension to scale each element in input sequences to + prior to applying time reduction block. + time_reduction_stride (int): factor by which to reduce length of input sequence. + transformer_num_heads (int): number of attention heads in each Emformer layer. + transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network. + transformer_num_layers (int): number of Emformer layers to instantiate. + transformer_left_context_length (int): length of left context considered by Emformer. + transformer_dropout (float): Emformer dropout probability. + transformer_activation (str): activation function to use in each Emformer layer's + feedforward network. Must be one of ("relu", "gelu", "silu"). + transformer_max_memory_size (int): maximum number of memory elements to use. + transformer_weight_init_scale_strategy (str): per-layer weight initialization scaling + strategy. Must be one of ("depthwise", "constant", ``None``). + transformer_tanh_on_mem (bool): if ``True``, applies tanh to memory elements. + symbol_embedding_dim (int): dimension of each target token embedding. 
+ num_lstm_layers (int): number of LSTM layers to instantiate. + lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers. + lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers. + lstm_dropout (float): LSTM dropout probability. + + Returns: + RNNT: + Emformer RNN-T model. + """ + encoder = _EmformerEncoder( + input_dim=input_dim, + output_dim=encoding_dim, + segment_length=segment_length, + right_context_length=right_context_length, + time_reduction_input_dim=time_reduction_input_dim, + time_reduction_stride=time_reduction_stride, + transformer_num_heads=transformer_num_heads, + transformer_ffn_dim=transformer_ffn_dim, + transformer_num_layers=transformer_num_layers, + transformer_dropout=transformer_dropout, + transformer_activation=transformer_activation, + transformer_left_context_length=transformer_left_context_length, + transformer_max_memory_size=transformer_max_memory_size, + transformer_weight_init_scale_strategy=transformer_weight_init_scale_strategy, + transformer_tanh_on_mem=transformer_tanh_on_mem, + ) + predictor = _Predictor( + num_symbols, + encoding_dim, + symbol_embedding_dim=symbol_embedding_dim, + num_lstm_layers=num_lstm_layers, + lstm_hidden_dim=symbol_embedding_dim, + lstm_layer_norm=lstm_layer_norm, + lstm_layer_norm_epsilon=lstm_layer_norm_epsilon, + lstm_dropout=lstm_dropout, + ) + joiner = _Joiner(encoding_dim, num_symbols) + return RNNT(encoder, predictor, joiner) + + +def emformer_rnnt_base(num_symbols: int) -> RNNT: + r"""Builds basic version of Emformer-based :class:`~torchaudio.models.RNNT`. + + Args: + num_symbols (int): The size of target token lexicon. + + Returns: + RNNT: + Emformer RNN-T model. 
+ """ + return emformer_rnnt_model( + input_dim=80, + encoding_dim=1024, + num_symbols=num_symbols, + segment_length=16, + right_context_length=4, + time_reduction_input_dim=128, + time_reduction_stride=4, + transformer_num_heads=8, + transformer_ffn_dim=2048, + transformer_num_layers=20, + transformer_dropout=0.1, + transformer_activation="gelu", + transformer_left_context_length=30, + transformer_max_memory_size=0, + transformer_weight_init_scale_strategy="depthwise", + transformer_tanh_on_mem=True, + symbol_embedding_dim=512, + num_lstm_layers=3, + lstm_layer_norm=True, + lstm_layer_norm_epsilon=1e-3, + lstm_dropout=0.3, + ) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/rnnt_decoder.py b/.venv/lib/python3.11/site-packages/torchaudio/models/rnnt_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5a02b2ca907733a8e1ab404d1107bb702e977748 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/rnnt_decoder.py @@ -0,0 +1,339 @@ +from typing import Callable, Dict, List, Optional, Tuple + +import torch +from torchaudio.models import RNNT + + +__all__ = ["Hypothesis", "RNNTBeamSearch"] + + +Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float] +Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder, + represented as tuple of (tokens, prediction network output, prediction network state, score). 
+ """ + + +def _get_hypo_tokens(hypo: Hypothesis) -> List[int]: + return hypo[0] + + +def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor: + return hypo[1] + + +def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]: + return hypo[2] + + +def _get_hypo_score(hypo: Hypothesis) -> float: + return hypo[3] + + +def _get_hypo_key(hypo: Hypothesis) -> str: + return str(hypo[0]) + + +def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]: + states: List[List[torch.Tensor]] = [] + for i in range(len(_get_hypo_state(hypos[0]))): + batched_state_components: List[torch.Tensor] = [] + for j in range(len(_get_hypo_state(hypos[0])[i])): + batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos])) + states.append(batched_state_components) + return states + + +def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]: + idx_tensor = torch.tensor([idx], device=device) + return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states] + + +def _default_hypo_sort_key(hypo: Hypothesis) -> float: + return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1) + + +def _compute_updated_scores( + hypos: List[Hypothesis], + next_token_probs: torch.Tensor, + beam_width: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1) + nonblank_scores = hypo_scores + next_token_probs[:, :-1] # [beam_width, num_tokens - 1] + nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width) + nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc") + nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1] + return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token + + +def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None: + for i, elem in 
class RNNTBeamSearch(torch.nn.Module):
    r"""Beam search decoder for RNN-T model.

    See Also:
        * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pretrained model.

    Args:
        model (RNNT): RNN-T model to use.
        blank (int): index of blank token in vocabulary.
        temperature (float, optional): temperature to apply to joint network output.
            Larger values yield more uniform samples. (Default: 1.0)
        hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
            for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
            hypothesis score normalized by token sequence length. (Default: None)
        step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
    """

    def __init__(
        self,
        model: RNNT,
        blank: int,
        temperature: float = 1.0,
        hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
        step_max_tokens: int = 100,
    ) -> None:
        super().__init__()
        self.model = model
        self.blank = blank
        self.temperature = temperature

        if hypo_sort_key is None:
            self.hypo_sort_key = _default_hypo_sort_key
        else:
            self.hypo_sort_key = hypo_sort_key

        self.step_max_tokens = step_max_tokens

    def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
        """Build the initial hypothesis list: a single hypothesis holding only the blank token.

        The prediction network is run once on the blank token with no prior state;
        its output and state are stored in the hypothesis, whose score is ``0.0``.
        """
        token = self.blank
        state = None

        one_tensor = torch.tensor([1], device=device)
        pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
        init_hypo = (
            [token],
            pred_out[0].detach(),
            pred_state,
            0.0,
        )
        return [init_hypo]

    def _gen_next_token_probs(
        self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
    ) -> torch.Tensor:
        """Return temperature-scaled log-probabilities of the next token, one row per hypothesis.

        ``enc_out`` is joined with the stacked prediction network outputs of ``hypos``.
        """
        one_tensor = torch.tensor([1], device=device)
        predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
        joined_out, _, _ = self.model.join(
            enc_out,
            one_tensor,
            predictor_out,
            torch.tensor([1] * len(hypos), device=device),
        )  # [beam_width, 1, 1, num_tokens]
        joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
        return joined_out[:, 0, 0]

    def _gen_b_hypos(
        self,
        b_hypos: List[Hypothesis],
        a_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        key_to_b_hypo: Dict[str, Hypothesis],
    ) -> List[Hypothesis]:
        """Extend every hypothesis in ``a_hypos`` with blank and merge the results into ``b_hypos``.

        ``b_hypos`` and ``key_to_b_hypo`` are updated in place. The returned list holds
        the same hypotheses sorted by ascending score.
        """
        for i in range(len(a_hypos)):
            h_a = a_hypos[i]
            # The last column of ``next_token_probs`` is used as the blank log-probability.
            append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
            if _get_hypo_key(h_a) in key_to_b_hypo:
                # Same token sequence already present: replace it and combine both scores in log space.
                h_b = key_to_b_hypo[_get_hypo_key(h_a)]
                _remove_hypo(h_b, b_hypos)
                score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
            else:
                score = float(append_blank_score)
            h_b = (
                _get_hypo_tokens(h_a),
                _get_hypo_predictor_out(h_a),
                _get_hypo_state(h_a),
                score,
            )
            b_hypos.append(h_b)
            key_to_b_hypo[_get_hypo_key(h_b)] = h_b
        _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
        return [b_hypos[idx] for idx in sorted_idx]

    def _gen_a_hypos(
        self,
        a_hypos: List[Hypothesis],
        b_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        t: int,
        beam_width: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Expand ``a_hypos`` with non-blank tokens, keeping only expansions that can still enter the beam.

        An expansion is kept when its score exceeds the ``beam_width``-th best score in
        ``b_hypos`` (which is expected to be sorted by ascending score), or unconditionally
        when ``b_hypos`` holds fewer than ``beam_width`` hypotheses.
        """
        (
            nonblank_nbest_scores,
            nonblank_nbest_hypo_idx,
            nonblank_nbest_token,
        ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)

        if len(b_hypos) < beam_width:
            b_nbest_score = -float("inf")
        else:
            b_nbest_score = _get_hypo_score(b_hypos[-beam_width])

        base_hypos: List[Hypothesis] = []
        new_tokens: List[int] = []
        new_scores: List[float] = []
        for i in range(beam_width):
            score = float(nonblank_nbest_scores[i])
            if score > b_nbest_score:
                a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
                base_hypos.append(a_hypos[a_hypo_idx])
                new_tokens.append(int(nonblank_nbest_token[i]))
                new_scores.append(score)

        if base_hypos:
            new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
        else:
            new_hypos: List[Hypothesis] = []

        return new_hypos

    def _gen_new_hypos(
        self,
        base_hypos: List[Hypothesis],
        tokens: List[int],
        scores: List[float],
        t: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Append ``tokens[i]`` to ``base_hypos[i]`` and run the prediction network on the whole batch.

        Each new hypothesis carries its own slice of the batched prediction output and
        state together with ``scores[i]``. ``t`` is accepted but not used here.
        """
        tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
        states = _batch_state(base_hypos)
        pred_out, _, pred_states = self.model.predict(
            tgt_tokens,
            torch.tensor([1] * len(base_hypos), device=device),
            states,
        )
        new_hypos: List[Hypothesis] = []
        for i, h_a in enumerate(base_hypos):
            new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
            new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
        return new_hypos

    def _search(
        self,
        enc_out: torch.Tensor,
        hypo: Optional[List[Hypothesis]],
        beam_width: int,
    ) -> List[Hypothesis]:
        """Run beam search over the time steps of ``enc_out`` (dimension 1).

        Args:
            enc_out (torch.Tensor): encoder output; one search step is run per index of dimension 1.
            hypo (List[Hypothesis] or None): hypotheses to seed the search with. If ``None``,
                the search starts from a single blank-only hypothesis.
            beam_width (int): maximum number of hypotheses kept after each time step.

        Returns:
            List[Hypothesis]: up to ``beam_width`` hypotheses, best first according to ``hypo_sort_key``.
        """
        n_time_steps = enc_out.shape[1]
        device = enc_out.device

        a_hypos: List[Hypothesis] = []
        b_hypos = self._init_b_hypos(device) if hypo is None else hypo
        for t in range(n_time_steps):
            a_hypos = b_hypos
            b_hypos = torch.jit.annotate(List[Hypothesis], [])
            key_to_b_hypo: Dict[str, Hypothesis] = {}
            symbols_current_t = 0

            while a_hypos:
                next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
                next_token_probs = next_token_probs.cpu()
                b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)

                if symbols_current_t == self.step_max_tokens:
                    break

                a_hypos = self._gen_a_hypos(
                    a_hypos,
                    b_hypos,
                    next_token_probs,
                    t,
                    beam_width,
                    device,
                )
                if a_hypos:
                    symbols_current_t += 1

            # ``topk`` raises when asked for more elements than exist, which happens when
            # fewer than ``beam_width`` hypotheses were produced at this step (e.g. with
            # ``step_max_tokens == 0``). Keep every available hypothesis in that case.
            n_keep = min(beam_width, len(b_hypos))
            _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(n_keep)
            b_hypos = [b_hypos[idx] for idx in sorted_idx]

        return b_hypos

    def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
        r"""Performs beam search for the given input sequence.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.

        Returns:
            List[Hypothesis]: up to ``beam_width`` best hypotheses found by beam search.

        Raises:
            ValueError: if ``input`` or ``length`` does not have one of the shapes listed above.
        """
        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
            raise ValueError("input must be of shape (T, D) or (1, T, D)")
        if input.dim() == 2:
            input = input.unsqueeze(0)

        if length.shape != () and length.shape != (1,):
            raise ValueError("length must be of shape () or (1,)")
        if length.dim() == 0:
            length = length.unsqueeze(0)

        enc_out, _ = self.model.transcribe(input, length)
        return self._search(enc_out, None, beam_width)

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        length: torch.Tensor,
        beam_width: int,
        state: Optional[List[List[torch.Tensor]]] = None,
        hypothesis: Optional[List[Hypothesis]] = None,
    ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
        r"""Performs beam search for the given input sequence in streaming mode.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing transcription network internal state generated in preceding
                invocation. (Default: ``None``)
            hypothesis (List[Hypothesis] or None): hypotheses from preceding invocation to seed
                search with. (Default: ``None``)

        Returns:
            (List[Hypothesis], List[List[torch.Tensor]]):
                List[Hypothesis]
                    up to ``beam_width`` best hypotheses found by beam search.
                List[List[torch.Tensor]]
                    list of lists of tensors representing transcription network
                    internal state generated in current invocation.

        Raises:
            ValueError: if ``input`` or ``length`` does not have one of the shapes listed above.
        """
        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
            raise ValueError("input must be of shape (T, D) or (1, T, D)")
        if input.dim() == 2:
            input = input.unsqueeze(0)

        if length.shape != () and length.shape != (1,):
            raise ValueError("length must be of shape () or (1,)")
        if length.dim() == 0:
            length = length.unsqueeze(0)

        enc_out, _, state = self.model.transcribe_streaming(input, length, state)
        return self._search(enc_out, hypothesis, beam_width), state
def _get_linear_layer(in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear") -> torch.nn.Linear:
    r"""Linear layer with xavier uniform initialization.

    Args:
        in_dim (int): Size of each input sample.
        out_dim (int): Size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias. (Default: ``True``)
        w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
            for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)

    Returns:
        (torch.nn.Linear): The corresponding linear layer.
    """
    linear = torch.nn.Linear(in_dim, out_dim, bias=bias)
    torch.nn.init.xavier_uniform_(linear.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
    return linear


def _get_conv1d_layer(
    in_channels: int,
    out_channels: int,
    kernel_size: int = 1,
    stride: int = 1,
    padding: Optional[Union[str, int, Tuple[int]]] = None,
    dilation: int = 1,
    bias: bool = True,
    w_init_gain: str = "linear",
) -> torch.nn.Conv1d:
    r"""1D convolution with xavier uniform initialization.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int, optional): Size of the convolving kernel. (Default: ``1``)
        stride (int, optional): Stride of the convolution. (Default: ``1``)
        padding (str, int or tuple, optional): Padding added to both sides of the input.
            If ``None``, ``dilation * (kernel_size - 1) / 2`` is used, which requires an odd
            ``kernel_size``. (Default: ``None``)
        dilation (int, optional): Spacing between kernel elements. (Default: ``1``)
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias. (Default: ``True``)
        w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
            for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)

    Returns:
        (torch.nn.Conv1d): The corresponding Conv1D layer.

    Raises:
        ValueError: If ``padding`` is ``None`` and ``kernel_size`` is even.
    """
    if padding is None:
        if kernel_size % 2 != 1:
            raise ValueError("kernel_size must be odd")
        # kernel_size is odd here, so the product is even and the floor division is exact.
        padding = dilation * (kernel_size - 1) // 2

    conv1d = torch.nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
    )

    torch.nn.init.xavier_uniform_(conv1d.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    return conv1d


def _get_mask_from_lengths(lengths: Tensor) -> Tensor:
    r"""Returns a boolean padding mask based on ``lengths``. The ``i``-th row and ``j``-th column of the mask
    is ``True`` if ``j`` is greater than or equal to the ``i``-th element of ``lengths``
    (a padded position), and ``False`` otherwise.

    Args:
        lengths (Tensor): The length of each element in the batch, with shape (n_batch, ).

    Returns:
        mask (Tensor): The boolean mask, with shape (n_batch, max of ``lengths``).
    """
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype)
    # Positions at or beyond each sequence's length are padding.
    mask = ids >= lengths.unsqueeze(1)
    return mask


class _LocationLayer(nn.Module):
    r"""Location layer used in the Attention model.

    Args:
        attention_n_filter (int): Number of filters for attention model.
        attention_kernel_size (int): Kernel size for attention model.
        attention_hidden_dim (int): Dimension of attention hidden representation.
    """

    def __init__(
        self,
        attention_n_filter: int,
        attention_kernel_size: int,
        attention_hidden_dim: int,
    ):
        super().__init__()
        padding = int((attention_kernel_size - 1) / 2)
        # Two input channels: one per row of ``attention_weights_cat``.
        self.location_conv = _get_conv1d_layer(
            2,
            attention_n_filter,
            kernel_size=attention_kernel_size,
            padding=padding,
            bias=False,
            stride=1,
            dilation=1,
        )
        self.location_dense = _get_linear_layer(
            attention_n_filter, attention_hidden_dim, bias=False, w_init_gain="tanh"
        )

    def forward(self, attention_weights_cat: Tensor) -> Tensor:
        r"""Location layer used in the Attention model.

        Args:
            attention_weights_cat (Tensor): Cumulative and previous attention weights
                with shape (n_batch, 2, max of ``text_lengths``).

        Returns:
            processed_attention (Tensor): Processed attention weights
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
        """
        # (n_batch, attention_n_filter, text_lengths.max())
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        # (n_batch, text_lengths.max(), attention_hidden_dim)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention
+ encoder_embedding_dim (int): Number of embedding dimensions in the Encoder. + attention_hidden_dim (int): Dimension of attention hidden representation. + attention_location_n_filter (int): Number of filters for Attention model. + attention_location_kernel_size (int): Kernel size for Attention model. + """ + + def __init__( + self, + attention_rnn_dim: int, + encoder_embedding_dim: int, + attention_hidden_dim: int, + attention_location_n_filter: int, + attention_location_kernel_size: int, + ) -> None: + super().__init__() + self.query_layer = _get_linear_layer(attention_rnn_dim, attention_hidden_dim, bias=False, w_init_gain="tanh") + self.memory_layer = _get_linear_layer( + encoder_embedding_dim, attention_hidden_dim, bias=False, w_init_gain="tanh" + ) + self.v = _get_linear_layer(attention_hidden_dim, 1, bias=False) + self.location_layer = _LocationLayer( + attention_location_n_filter, + attention_location_kernel_size, + attention_hidden_dim, + ) + self.score_mask_value = -float("inf") + + def _get_alignment_energies(self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor) -> Tensor: + r"""Get the alignment vector. + + Args: + query (Tensor): Decoder output with shape (n_batch, n_mels * n_frames_per_step). + processed_memory (Tensor): Processed Encoder outputs + with shape (n_batch, max of ``text_lengths``, attention_hidden_dim). + attention_weights_cat (Tensor): Cumulative and previous attention weights + with shape (n_batch, 2, max of ``text_lengths``). + + Returns: + alignment (Tensor): attention weights, it is a tensor with shape (batch, max of ``text_lengths``). 
class _Prenet(nn.Module):
    r"""Prenet Module. It consists of ``len(out_sizes)`` bias-free linear layers,
    each followed by ReLU and dropout.

    Args:
        in_dim (int): The size of each input sample.
        out_sizes (list): The output dimension of each linear layer.
    """

    def __init__(self, in_dim: int, out_sizes: List[int]) -> None:
        super().__init__()
        # Every layer consumes the previous layer's output; the first one consumes ``in_dim``.
        layer_inputs = [in_dim] + out_sizes[:-1]
        stack = []
        for n_in, n_out in zip(layer_inputs, out_sizes):
            stack.append(_get_linear_layer(n_in, n_out, bias=False))
        self.layers = nn.ModuleList(stack)

    def forward(self, x: Tensor) -> Tensor:
        r"""Pass the input through Prenet.

        Args:
            x (Tensor): The input sequence to Prenet with shape (n_batch, in_dim).

        Return:
            x (Tensor): Tensor with shape (n_batch, out_sizes[-1])
        """
        hidden = x
        for layer in self.layers:
            activated = F.relu(layer(hidden))
            # Dropout is always active here (``training=True``), regardless of the module's mode.
            # NOTE(review): presumably intentional, following the Tacotron 2 recipe - confirm.
            hidden = F.dropout(activated, p=0.5, training=True)
        return hidden
class _Encoder(nn.Module):
    r"""Encoder Module: a stack of convolution + batch-norm layers followed by a
    single-layer bidirectional LSTM.

    Args:
        encoder_embedding_dim (int): Number of embedding dimensions in the encoder.
        encoder_n_convolution (int): Number of convolution layers in the encoder.
        encoder_kernel_size (int): The kernel size in the encoder.

    Examples
        >>> encoder = _Encoder(512, 3, 5)
        >>> input = torch.rand(10, 512, 30)
        >>> lengths = torch.full((10,), 30)
        >>> output = encoder(input, lengths)  # shape: (10, 30, 512)
    """

    def __init__(
        self,
        encoder_embedding_dim: int,
        encoder_n_convolution: int,
        encoder_kernel_size: int,
    ) -> None:
        super().__init__()

        same_padding = int((encoder_kernel_size - 1) / 2)
        conv_blocks = []
        for _ in range(encoder_n_convolution):
            conv = _get_conv1d_layer(
                encoder_embedding_dim,
                encoder_embedding_dim,
                kernel_size=encoder_kernel_size,
                stride=1,
                padding=same_padding,
                dilation=1,
                w_init_gain="relu",
            )
            conv_blocks.append(nn.Sequential(conv, nn.BatchNorm1d(encoder_embedding_dim)))
        self.convolutions = nn.ModuleList(conv_blocks)

        # Each direction gets half of the embedding size.
        lstm_hidden_dim = int(encoder_embedding_dim / 2)
        self.lstm = nn.LSTM(
            encoder_embedding_dim,
            lstm_hidden_dim,
            1,
            batch_first=True,
            bidirectional=True,
        )
        self.lstm.flatten_parameters()

    def forward(self, x: Tensor, input_lengths: Tensor) -> Tensor:
        r"""Pass the input through the Encoder.

        Args:
            x (Tensor): The input sequences with shape (n_batch, encoder_embedding_dim, n_seq).
            input_lengths (Tensor): The length of each input sequence with shape (n_batch, ).

        Return:
            x (Tensor): A tensor with shape (n_batch, n_seq, encoder_embedding_dim).
        """
        features = x
        for conv_block in self.convolutions:
            features = F.dropout(F.relu(conv_block(features)), 0.5, self.training)

        # (n_batch, channels, n_seq) -> (n_batch, n_seq, channels) for the batch-first LSTM.
        features = features.transpose(1, 2)

        # Lengths are moved to CPU before packing.
        lengths_cpu = input_lengths.cpu()
        packed = nn.utils.rnn.pack_padded_sequence(features, lengths_cpu, batch_first=True)

        packed_out, _ = self.lstm(packed)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        return outputs
self.decoder_rnn_dim = decoder_rnn_dim + self.prenet_dim = prenet_dim + self.decoder_max_step = decoder_max_step + self.gate_threshold = gate_threshold + self.attention_dropout = attention_dropout + self.decoder_dropout = decoder_dropout + self.decoder_early_stopping = decoder_early_stopping + + self.prenet = _Prenet(n_mels * n_frames_per_step, [prenet_dim, prenet_dim]) + + self.attention_rnn = nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim) + + self.attention_layer = _Attention( + attention_rnn_dim, + encoder_embedding_dim, + attention_hidden_dim, + attention_location_n_filter, + attention_location_kernel_size, + ) + + self.decoder_rnn = nn.LSTMCell(attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, True) + + self.linear_projection = _get_linear_layer(decoder_rnn_dim + encoder_embedding_dim, n_mels * n_frames_per_step) + + self.gate_layer = _get_linear_layer( + decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid" + ) + + def _get_initial_frame(self, memory: Tensor) -> Tensor: + r"""Gets all zeros frames to use as the first decoder input. + + Args: + memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + + Returns: + decoder_input (Tensor): all zeros frames with shape + (n_batch, max of ``text_lengths``, ``n_mels * n_frames_per_step``). + """ + + n_batch = memory.size(0) + dtype = memory.dtype + device = memory.device + decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device) + return decoder_input + + def _initialize_decoder_states( + self, memory: Tensor + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Initializes attention rnn states, decoder rnn states, attention + weights, attention cumulative weights, attention context, stores memory + and stores processed memory. 
+ + Args: + memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + + Returns: + attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). + attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). + decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``). + attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``). + attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``). + processed_memory (Tensor): Processed encoder outputs + with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``). + """ + n_batch = memory.size(0) + max_time = memory.size(1) + dtype = memory.dtype + device = memory.device + + attention_hidden = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device) + attention_cell = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device) + + decoder_hidden = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device) + decoder_cell = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device) + + attention_weights = torch.zeros(n_batch, max_time, dtype=dtype, device=device) + attention_weights_cum = torch.zeros(n_batch, max_time, dtype=dtype, device=device) + attention_context = torch.zeros(n_batch, self.encoder_embedding_dim, dtype=dtype, device=device) + + processed_memory = self.attention_layer.memory_layer(memory) + + return ( + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + processed_memory, + ) + + def 
_parse_decoder_inputs(self, decoder_inputs: Tensor) -> Tensor: + r"""Prepares decoder inputs. + + Args: + decoder_inputs (Tensor): Inputs used for teacher-forced training, i.e. mel-specs, + with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``) + + Returns: + inputs (Tensor): Processed decoder inputs with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``). + """ + # (n_batch, n_mels, mel_specgram_lengths.max()) -> (n_batch, mel_specgram_lengths.max(), n_mels) + decoder_inputs = decoder_inputs.transpose(1, 2) + decoder_inputs = decoder_inputs.view( + decoder_inputs.size(0), + int(decoder_inputs.size(1) / self.n_frames_per_step), + -1, + ) + # (n_batch, mel_specgram_lengths.max(), n_mels) -> (mel_specgram_lengths.max(), n_batch, n_mels) + decoder_inputs = decoder_inputs.transpose(0, 1) + return decoder_inputs + + def _parse_decoder_outputs( + self, mel_specgram: Tensor, gate_outputs: Tensor, alignments: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + r"""Prepares decoder outputs for output + + Args: + mel_specgram (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``) + gate_outputs (Tensor): predicted stop token with shape (max of ``mel_specgram_lengths``, n_batch) + alignments (Tensor): sequence of attention weights from the decoder + with shape (max of ``mel_specgram_lengths``, n_batch, max of ``text_lengths``) + + Returns: + mel_specgram (Tensor): mel spectrogram with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``) + gate_outputs (Tensor): predicted stop token with shape (n_batch, max of ``mel_specgram_lengths``) + alignments (Tensor): sequence of attention weights from the decoder + with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``) + """ + # (mel_specgram_lengths.max(), n_batch, text_lengths.max()) + # -> (n_batch, mel_specgram_lengths.max(), text_lengths.max()) + alignments = alignments.transpose(0, 1).contiguous() + # (mel_specgram_lengths.max(), n_batch) -> 
    def decode(
        self,
        decoder_input: Tensor,
        attention_hidden: Tensor,
        attention_cell: Tensor,
        decoder_hidden: Tensor,
        decoder_cell: Tensor,
        attention_weights: Tensor,
        attention_weights_cum: Tensor,
        attention_context: Tensor,
        memory: Tensor,
        processed_memory: Tensor,
        mask: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""Decoder step using stored states, attention and memory

        Runs the attention LSTM cell, the attention layer and the decoder LSTM cell once,
        then projects the result to a mel frame and a stop-token prediction.

        Note:
            ``attention_weights_cum`` is updated in place, so the tensor passed by the
            caller is modified.

        Args:
            decoder_input (Tensor): Output of the Prenet with shape (n_batch, ``prenet_dim``).
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Cell state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Cell state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            memory (Tensor): Encoder output with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
            mask (Tensor): Binary mask for padded data, forwarded unchanged to the attention layer.
                NOTE(review): previously documented as (n_batch, current_num_frames); since it masks
                attention energies it presumably is (n_batch, max of ``text_lengths``) - confirm.

        Returns:
            decoder_output: Predicted mel spectrogram for the current step
                with shape (n_batch, ``n_mels * n_frames_per_step``).
            gate_prediction (Tensor): Prediction of the stop token with shape (n_batch, ``1``).
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Cell state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Cell state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
        """
        # The attention LSTM sees the prenet output together with the previous context vector.
        cell_input = torch.cat((decoder_input, attention_context), -1)

        attention_hidden, attention_cell = self.attention_rnn(cell_input, (attention_hidden, attention_cell))
        attention_hidden = F.dropout(attention_hidden, self.attention_dropout, self.training)

        # Stack previous and cumulative weights as two channels: (n_batch, 2, max of text_lengths).
        attention_weights_cat = torch.cat((attention_weights.unsqueeze(1), attention_weights_cum.unsqueeze(1)), dim=1)
        attention_context, attention_weights = self.attention_layer(
            attention_hidden, memory, processed_memory, attention_weights_cat, mask
        )

        # In-place accumulation: this also modifies the tensor object the caller passed in.
        attention_weights_cum += attention_weights
        decoder_input = torch.cat((attention_hidden, attention_context), -1)

        decoder_hidden, decoder_cell = self.decoder_rnn(decoder_input, (decoder_hidden, decoder_cell))
        decoder_hidden = F.dropout(decoder_hidden, self.decoder_dropout, self.training)

        # Both output heads read the decoder hidden state concatenated with the context vector.
        decoder_hidden_attention_context = torch.cat((decoder_hidden, attention_context), dim=1)
        decoder_output = self.linear_projection(decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)

        return (
            decoder_output,
            gate_prediction,
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
        )
+ gate_outputs (Tensor): Predicted stop token for each timestep + with shape (n_batch, max of ``mel_specgram_lengths``). + alignments (Tensor): Sequence of attention weights from the decoder + with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``). + """ + + decoder_input = self._get_initial_frame(memory).unsqueeze(0) + decoder_inputs = self._parse_decoder_inputs(mel_specgram_truth) + decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) + decoder_inputs = self.prenet(decoder_inputs) + + mask = _get_mask_from_lengths(memory_lengths) + ( + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + processed_memory, + ) = self._initialize_decoder_states(memory) + + mel_outputs, gate_outputs, alignments = [], [], [] + while len(mel_outputs) < decoder_inputs.size(0) - 1: + decoder_input = decoder_inputs[len(mel_outputs)] + ( + mel_output, + gate_output, + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + ) = self.decode( + decoder_input, + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + memory, + processed_memory, + mask, + ) + + mel_outputs += [mel_output.squeeze(1)] + gate_outputs += [gate_output.squeeze(1)] + alignments += [attention_weights] + + mel_specgram, gate_outputs, alignments = self._parse_decoder_outputs( + torch.stack(mel_outputs), torch.stack(gate_outputs), torch.stack(alignments) + ) + + return mel_specgram, gate_outputs, alignments + + def _get_go_frame(self, memory: Tensor) -> Tensor: + """Gets all zeros frames to use as the first decoder input + + args: + memory (Tensor): Encoder outputs + with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). 
+ + returns: + decoder_input (Tensor): All zeros frames with shape(n_batch, ``n_mels`` * ``n_frame_per_step``). + """ + + n_batch = memory.size(0) + dtype = memory.dtype + device = memory.device + decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device) + return decoder_input + + @torch.jit.export + def infer(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Decoder inference + + Args: + memory (Tensor): Encoder outputs + with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + memory_lengths (Tensor): Encoder output lengths for attention masking + (the same as ``text_lengths``) with shape (n_batch, ). + + Returns: + mel_specgram (Tensor): Predicted mel spectrogram + with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``). + mel_specgram_lengths (Tensor): the length of the predicted mel spectrogram (n_batch, )) + gate_outputs (Tensor): Predicted stop token for each timestep + with shape (n_batch, max of ``mel_specgram_lengths``). + alignments (Tensor): Sequence of attention weights from the decoder + with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``). 
+ """ + batch_size, device = memory.size(0), memory.device + + decoder_input = self._get_go_frame(memory) + + mask = _get_mask_from_lengths(memory_lengths) + ( + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + processed_memory, + ) = self._initialize_decoder_states(memory) + + mel_specgram_lengths = torch.zeros([batch_size], dtype=torch.int32, device=device) + finished = torch.zeros([batch_size], dtype=torch.bool, device=device) + mel_specgrams: List[Tensor] = [] + gate_outputs: List[Tensor] = [] + alignments: List[Tensor] = [] + for _ in range(self.decoder_max_step): + decoder_input = self.prenet(decoder_input) + ( + mel_specgram, + gate_output, + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + ) = self.decode( + decoder_input, + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + memory, + processed_memory, + mask, + ) + + mel_specgrams.append(mel_specgram.unsqueeze(0)) + gate_outputs.append(gate_output.transpose(0, 1)) + alignments.append(attention_weights) + mel_specgram_lengths[~finished] += 1 + + finished |= torch.sigmoid(gate_output.squeeze(1)) > self.gate_threshold + if self.decoder_early_stopping and torch.all(finished): + break + + decoder_input = mel_specgram + + if len(mel_specgrams) == self.decoder_max_step: + warnings.warn( + "Reached max decoder steps. The generated spectrogram might not cover " "the whole transcript." 
+ ) + + mel_specgrams = torch.cat(mel_specgrams, dim=0) + gate_outputs = torch.cat(gate_outputs, dim=0) + alignments = torch.cat(alignments, dim=0) + + mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(mel_specgrams, gate_outputs, alignments) + + return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments + + +class Tacotron2(nn.Module): + r"""Tacotron2 model from *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions* + :cite:`shen2018natural` based on the implementation from + `Nvidia Deep Learning Examples `_. + + See Also: + * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model. + + Args: + mask_padding (bool, optional): Use mask padding (Default: ``False``). + n_mels (int, optional): Number of mel bins (Default: ``80``). + n_symbol (int, optional): Number of symbols for the input text (Default: ``148``). + n_frames_per_step (int, optional): Number of frames processed per step, only 1 is supported (Default: ``1``). + symbol_embedding_dim (int, optional): Input embedding dimension (Default: ``512``). + encoder_n_convolution (int, optional): Number of encoder convolutions (Default: ``3``). + encoder_kernel_size (int, optional): Encoder kernel size (Default: ``5``). + encoder_embedding_dim (int, optional): Encoder embedding dimension (Default: ``512``). + decoder_rnn_dim (int, optional): Number of units in decoder LSTM (Default: ``1024``). + decoder_max_step (int, optional): Maximum number of output mel spectrograms (Default: ``2000``). + decoder_dropout (float, optional): Dropout probability for decoder LSTM (Default: ``0.1``). + decoder_early_stopping (bool, optional): Continue decoding after all samples are finished (Default: ``True``). + attention_rnn_dim (int, optional): Number of units in attention LSTM (Default: ``1024``). + attention_hidden_dim (int, optional): Dimension of attention hidden representation (Default: ``128``). 
+ attention_location_n_filter (int, optional): Number of filters for attention model (Default: ``32``). + attention_location_kernel_size (int, optional): Kernel size for attention model (Default: ``31``). + attention_dropout (float, optional): Dropout probability for attention LSTM (Default: ``0.1``). + prenet_dim (int, optional): Number of ReLU units in prenet layers (Default: ``256``). + postnet_n_convolution (int, optional): Number of postnet convolutions (Default: ``5``). + postnet_kernel_size (int, optional): Postnet kernel size (Default: ``5``). + postnet_embedding_dim (int, optional): Postnet embedding dimension (Default: ``512``). + gate_threshold (float, optional): Probability threshold for stop token (Default: ``0.5``). + """ + + def __init__( + self, + mask_padding: bool = False, + n_mels: int = 80, + n_symbol: int = 148, + n_frames_per_step: int = 1, + symbol_embedding_dim: int = 512, + encoder_embedding_dim: int = 512, + encoder_n_convolution: int = 3, + encoder_kernel_size: int = 5, + decoder_rnn_dim: int = 1024, + decoder_max_step: int = 2000, + decoder_dropout: float = 0.1, + decoder_early_stopping: bool = True, + attention_rnn_dim: int = 1024, + attention_hidden_dim: int = 128, + attention_location_n_filter: int = 32, + attention_location_kernel_size: int = 31, + attention_dropout: float = 0.1, + prenet_dim: int = 256, + postnet_n_convolution: int = 5, + postnet_kernel_size: int = 5, + postnet_embedding_dim: int = 512, + gate_threshold: float = 0.5, + ) -> None: + super().__init__() + + self.mask_padding = mask_padding + self.n_mels = n_mels + self.n_frames_per_step = n_frames_per_step + self.embedding = nn.Embedding(n_symbol, symbol_embedding_dim) + torch.nn.init.xavier_uniform_(self.embedding.weight) + self.encoder = _Encoder(encoder_embedding_dim, encoder_n_convolution, encoder_kernel_size) + self.decoder = _Decoder( + n_mels, + n_frames_per_step, + encoder_embedding_dim, + decoder_rnn_dim, + decoder_max_step, + decoder_dropout, + 
decoder_early_stopping, + attention_rnn_dim, + attention_hidden_dim, + attention_location_n_filter, + attention_location_kernel_size, + attention_dropout, + prenet_dim, + gate_threshold, + ) + self.postnet = _Postnet(n_mels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolution) + + def forward( + self, + tokens: Tensor, + token_lengths: Tensor, + mel_specgram: Tensor, + mel_specgram_lengths: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r"""Pass the input through the Tacotron2 model. This is in teacher + forcing mode, which is generally used for training. + + The input ``tokens`` should be padded with zeros to length max of ``token_lengths``. + The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``. + + Args: + tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of token_lengths)`. + token_lengths (Tensor): The valid length of each sample in ``tokens`` with shape `(n_batch, )`. + mel_specgram (Tensor): The target mel spectrogram + with shape `(n_batch, n_mels, max of mel_specgram_lengths)`. + mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`. + + Returns: + [Tensor, Tensor, Tensor, Tensor]: + Tensor + Mel spectrogram before Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`. + Tensor + Mel spectrogram after Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`. + Tensor + The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`. + Tensor + Sequence of attention weights from the decoder with + shape `(n_batch, max of mel_specgram_lengths, max of token_lengths)`. 
+ """ + + embedded_inputs = self.embedding(tokens).transpose(1, 2) + + encoder_outputs = self.encoder(embedded_inputs, token_lengths) + mel_specgram, gate_outputs, alignments = self.decoder( + encoder_outputs, mel_specgram, memory_lengths=token_lengths + ) + + mel_specgram_postnet = self.postnet(mel_specgram) + mel_specgram_postnet = mel_specgram + mel_specgram_postnet + + if self.mask_padding: + mask = _get_mask_from_lengths(mel_specgram_lengths) + mask = mask.expand(self.n_mels, mask.size(0), mask.size(1)) + mask = mask.permute(1, 0, 2) + + mel_specgram.masked_fill_(mask, 0.0) + mel_specgram_postnet.masked_fill_(mask, 0.0) + gate_outputs.masked_fill_(mask[:, 0, :], 1e3) + + return mel_specgram, mel_specgram_postnet, gate_outputs, alignments + + @torch.jit.export + def infer(self, tokens: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]: + r"""Using Tacotron2 for inference. The input is a batch of encoded + sentences (``tokens``) and its corresponding lengths (``lengths``). The + output is the generated mel spectrograms, its corresponding lengths, and + the attention weights from the decoder. + + The input `tokens` should be padded with zeros to length max of ``lengths``. + + Args: + tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`. + lengths (Tensor or None, optional): + The valid length of each sample in ``tokens`` with shape `(n_batch, )`. + If ``None``, it is assumed that the all the tokens are valid. Default: ``None`` + + Returns: + (Tensor, Tensor, Tensor): + Tensor + The predicted mel spectrogram with shape `(n_batch, n_mels, max of mel_specgram_lengths)`. + Tensor + The length of the predicted mel spectrogram with shape `(n_batch, )`. + Tensor + Sequence of attention weights from the decoder with shape + `(n_batch, max of mel_specgram_lengths, max of lengths)`. 
+ """ + n_batch, max_length = tokens.shape + if lengths is None: + lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype) + + assert lengths is not None # For TorchScript compiler + embedded_inputs = self.embedding(tokens).transpose(1, 2) + encoder_outputs = self.encoder(embedded_inputs, lengths) + mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(encoder_outputs, lengths) + + mel_outputs_postnet = self.postnet(mel_specgram) + mel_outputs_postnet = mel_specgram + mel_outputs_postnet + + alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2) + + return mel_outputs_postnet, mel_specgram_lengths, alignments diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wav2letter.py b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2letter.py new file mode 100644 index 0000000000000000000000000000000000000000..d776131686d1f65982a565088e72e45e7b7c107f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2letter.py @@ -0,0 +1,72 @@ +from torch import nn, Tensor + +__all__ = [ + "Wav2Letter", +] + + +class Wav2Letter(nn.Module): + r"""Wav2Letter model architecture from *Wav2Letter: an End-to-End ConvNet-based Speech + Recognition System* :cite:`collobert2016wav2letter`. + + See Also: + * `Training example `__ + + Args: + num_classes (int, optional): Number of classes to be classified. (Default: ``40``) + input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum`` + or ``mfcc`` (Default: ``waveform``). + num_features (int, optional): Number of input features that the network will receive (Default: ``1``). 
+ """ + + def __init__(self, num_classes: int = 40, input_type: str = "waveform", num_features: int = 1) -> None: + super().__init__() + + acoustic_num_features = 250 if input_type == "waveform" else num_features + acoustic_model = nn.Sequential( + nn.Conv1d(in_channels=acoustic_num_features, out_channels=250, kernel_size=48, stride=2, padding=23), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0), + nn.ReLU(inplace=True), + ) + + if input_type == "waveform": + waveform_model = nn.Sequential( + nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45), + nn.ReLU(inplace=True), + ) + self.acoustic_model = nn.Sequential(waveform_model, acoustic_model) + + if input_type in ["power_spectrum", "mfcc"]: + self.acoustic_model = acoustic_model + + def forward(self, x: Tensor) -> Tensor: + r""" + Args: + x (torch.Tensor): Tensor of dimension (batch_size, num_features, input_length). 
+ + Returns: + Tensor: Predictor tensor of dimension (batch_size, number_of_classes, input_length). + """ + + x = self.acoustic_model(x) + x = nn.functional.log_softmax(x, dim=1) + return x diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/model.py b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/model.py new file mode 100644 index 0000000000000000000000000000000000000000..254122f0eee21906ec50f3d4238a5b3024e74a0a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/model.py @@ -0,0 +1,1579 @@ +import math +from typing import List, Optional, Tuple + +import torch +from torch import Tensor +from torch.nn import Module + +from . import components + + +class Wav2Vec2Model(Module): + """Acoustic model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`. + + Note: + To build the model, please use one of the factory functions. + + See Also: + * :class:`torchaudio.pipelines.Wav2Vec2Bundle`: Pretrained models (without fine-tuning) + * :class:`torchaudio.pipelines.Wav2Vec2ASRBundle`: ASR pipelines with pretrained models. + + Args: + feature_extractor (torch.nn.Module): + Feature extractor that extracts feature vectors from raw audio Tensor. + + encoder (torch.nn.Module): + Encoder that converts the audio features into the sequence of probability + distribution (in negative log-likelihood) over labels. + + aux (torch.nn.Module or None, optional): + Auxiliary module. If provided, the output from encoder is passed to this module. 
+ """ # noqa: E501 + + def __init__( + self, + feature_extractor: Module, + encoder: Module, + aux: Optional[Module] = None, + ): + super().__init__() + self.feature_extractor = feature_extractor + self.encoder = encoder + self.aux = aux + + @torch.jit.export + def extract_features( + self, + waveforms: Tensor, + lengths: Optional[Tensor] = None, + num_layers: Optional[int] = None, + ) -> Tuple[List[Tensor], Optional[Tensor]]: + """Extract feature vectors from raw waveforms + + This returns the list of outputs from the intermediate layers of + transformer block in encoder. + + Args: + waveforms (Tensor): Audio tensor of shape `(batch, frames)`. + lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `(batch, )`. + When the ``waveforms`` contains audios with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths and apply proper mask in + transformer attention layer. + If ``None``, it is assumed that the entire audio waveform + length is valid. + num_layers (int or None, optional): + If given, limit the number of intermediate layers to go through. + Providing `1` will stop the computation after going through one + intermediate layers. If not given, the outputs from all the + intermediate layers are returned. + + Returns: + (List[Tensor], Optional[Tensor]): + List of Tensors + Features from requested layers. + Each Tensor is of shape: `(batch, time frame, feature dimension)` + Tensor or None + If ``lengths`` argument was provided, a Tensor of shape `(batch, )` + is returned. + It indicates the valid length in time axis of each feature Tensor. 
+ """ + x, lengths = self.feature_extractor(waveforms, lengths) + x = self.encoder.extract_features(x, lengths, num_layers) + return x, lengths + + def forward( + self, + waveforms: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Compute the sequence of probability distribution over labels. + + Args: + waveforms (Tensor): Audio tensor of shape `(batch, frames)`. + lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `(batch, )`. + When the ``waveforms`` contains audios with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths and apply proper mask in + transformer attention layer. + If ``None``, it is assumed that all the audio in ``waveforms`` + have valid length. Default: ``None``. + + Returns: + (Tensor, Optional[Tensor]): + Tensor + The sequences of probability distribution (in logit) over labels. + Shape: `(batch, frames, num labels)`. + Tensor or None + If ``lengths`` argument was provided, a Tensor of shape `(batch, )` + is returned. + It indicates the valid length in time axis of the output Tensor. + """ + x, lengths = self.feature_extractor(waveforms, lengths) + x = self.encoder(x, lengths) + if self.aux is not None: + x = self.aux(x) + return x, lengths + + +class HuBERTPretrainModel(Module): + """HuBERTPretrainModel() + + HuBERT model used for pretraining in *HuBERT* :cite:`hsu2021hubert`. + + Note: + To build the model, please use one of the factory functions. + + See Also: + `HuBERT Pre-training and Fine-tuning Recipes + `__ + + Args: + wav2vec2 (Wav2Vec2Model): + Wav2Vec2 encoder that generates the transformer outputs. + + mask_generator (torch.nn.Module): + Mask generator that generates the mask for masked prediction during the training. + + logit_generator (torch.nn.Module): + Logit generator that predicts the logits of the masked and unmasked inputs. 
+ + feature_grad_mult (float or None): + The factor to scale the convolutional feature extraction layer gradients by. + If ``None``, the gradients of feature extraction layers are not affected. + The scale factor will not affect the forward pass. + """ + + def __init__( + self, + wav2vec2: Wav2Vec2Model, + mask_generator: Module, + logit_generator: Module, + feature_grad_mult: Optional[float], + ): + super().__init__() + self.wav2vec2 = wav2vec2 + self.mask_generator = mask_generator + self.logit_generator = logit_generator + if feature_grad_mult is not None and not 0.0 < feature_grad_mult < 1.0: + raise ValueError( + f"The value of `feature_grad_mult` must be ``None``or between (0, 1). Found {feature_grad_mult}" + ) + self.feature_grad_mult = feature_grad_mult + + def forward( + self, + waveforms: Tensor, + labels: Tensor, + audio_lengths: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Compute the sequence of probability distribution over labels. + + Args: + waveforms (Tensor): Audio tensor of dimension `[batch, frames]`. + labels (Tensor): Label for pre-training. A Tensor of dimension `[batch, frames]`. + audio_lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `[batch, ]`. + When the ``waveforms`` contains audios with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths and apply proper mask in + transformer attention layer. + If ``None``, it is assumed that all the audio in ``waveforms`` + have valid length. Default: ``None``. + + Returns: + (Tensor, Tensor, Tensor): + Tensor + The masked sequences of probability distribution (in logit). + Shape: `(masked_frames, num labels)`. + Tensor + The unmasked sequence of probability distribution (in logit). + Shape: `(unmasked_frames, num labels)`. + Tensor + The feature mean value for additional penalty loss. + Shape: `(1,)`. 
+ """ + x, lengths = self.wav2vec2.feature_extractor(waveforms, audio_lengths) + if self.feature_grad_mult is not None and self.feature_grad_mult < 1.0: + x = components.GradMultiply.apply(x, self.feature_grad_mult) + features_pen = x.float().pow(2).mean() + if lengths is not None: + padding_mask = components._get_padding_mask(x, lengths) + else: + padding_mask = None + x, attention_mask = self.wav2vec2.encoder._preprocess(x, lengths) + x, mask = self.mask_generator(x, padding_mask) + x = self.wav2vec2.encoder.transformer(x, attention_mask=attention_mask) + if x.shape[1] != labels.shape[1]: + raise ValueError("The length of label must match that of HuBERT model output") + if padding_mask is not None: + mask_m = torch.logical_and(~padding_mask, mask) + mask_u = torch.logical_and(~padding_mask, ~mask_m) + else: + mask_m = mask + mask_u = ~mask_m + + logit_m, logit_u = self.logit_generator(x, labels, mask_m, mask_u) + + return logit_m, logit_u, features_pen + + +def wav2vec2_model( + extractor_mode: str, + extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], + extractor_conv_bias: bool, + encoder_embed_dim: int, + encoder_projection_dropout: float, + encoder_pos_conv_kernel: int, + encoder_pos_conv_groups: int, + encoder_num_layers: int, + encoder_num_heads: int, + encoder_attention_dropout: float, + encoder_ff_interm_features: int, + encoder_ff_interm_dropout: float, + encoder_dropout: float, + encoder_layer_norm_first: bool, + encoder_layer_drop: float, + aux_num_out: Optional[int], +) -> Wav2Vec2Model: + """Builds custom :class:`~torchaudio.models.Wav2Vec2Model`. + + Note: + The "feature extractor" below corresponds to + `ConvFeatureExtractionModel `__ + in the original ``fairseq`` implementation. + This is referred as "(convolutional) feature encoder" in the *wav2vec 2.0* + :cite:`baevski2020wav2vec` paper. + + The "encoder" below corresponds to `TransformerEncoder `__, + and this is referred as "Transformer" in the paper. 
+ + Args: + extractor_mode (str): Operation mode of feature extractor. + Valid values are ``"group_norm"`` or ``"layer_norm"``. + If ``"group_norm"``, then a single normalization is applied + in the first convolution block. Otherwise, all the convolution + blocks will have layer normalization. + + This option corresponds to ``extractor_mode`` from ``fairseq``. + extractor_conv_layer_config (list of integer tuples or None): + Configuration of convolution layers in feature extractor. + List of convolution configuration, + i.e. ``[(output_channel, kernel_size, stride), ...]`` + + If ``None`` is provided, then the following default value is used. + + .. code-block:: python + + [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ] + + This option corresponds to ``conv_feature_layers`` from ``fairseq``. + + extractor_conv_bias (bool): + Whether to include bias term to each convolution operation. + + This option corresponds to ``conv_bias`` from ``fairseq``. + + encoder_embed_dim (int): + The dimension of embedding in encoder. + + This option corresponds to ``encoder_embed_dim`` from ``fairseq``. + + encoder_projection_dropout (float): + The dropout probability applied after the input feature is projected + to ``encoder_embed_dim``. + + This option corresponds to ``dropout_input`` from ``fairseq``. + + encoder_pos_conv_kernel (int): + The kernel size of convolutional positional embeddings. + + This option corresponds to ``conv_pos`` from ``fairseq``. + + encoder_pos_conv_groups (int): + The number of groups of convolutional positional embeddings. + + This option corresponds to ``conv_pos_groups`` from ``fairseq``. + + encoder_num_layers (int): + The number of self attention layers in transformer block. + + This option corresponds to ``encoder_layers`` from ``fairseq``. + + encoder_num_heads (int): + The number of heads in self attention layers. 
+ + This option corresponds to ``encoder_attention_heads`` from ``fairseq``. + + encoder_attention_dropout (float): + The dropout probability applied after softmax in self-attention layer. + + This option corresponds to ``attention_dropout`` from ``fairseq``. + + encoder_ff_interm_features (int): + The dimension of hidden features in feed forward layer. + + This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``. + + encoder_ff_interm_dropout (float): + The dropout probability applied in feedforward layer. + + This option correspinds to ``activation_dropout`` from ``fairseq``. + + encoder_dropout (float): + The dropout probability applied at the end of feed forward layer. + + This option corresponds to ``dropout`` from ``fairseq``. + + encoder_layer_norm_first (bool): + Control the order of layer norm in transformer layer and each encoder layer. + If True, in transformer layer, layer norm is applied before features are fed + to encoder layers. In encoder layer, two layer norms are applied before and after + self attention. + If False, in transformer layer, layer norm is applied after features are fed + to encoder layers. In encoder layer, two layer norms are applied after self + attention, before and after feed forward. + + This option corresponds to ``layer_norm_first`` from ``fairseq``. + + encoder_layer_drop (float): + Probability to drop each encoder layer during training. + + This option corresponds to ``layerdrop`` from ``fairseq``. + + aux_num_out (int or None): + When provided, attach an extra linear layer on top of encoder, which can be + used for fine-tuning. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ # noqa: E501 + if extractor_conv_layer_config is None: + extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 + + feature_extractor = components._get_feature_extractor( + extractor_mode, extractor_conv_layer_config, extractor_conv_bias + ) + encoder = components._get_encoder( + in_features=extractor_conv_layer_config[-1][0], + embed_dim=encoder_embed_dim, + dropout_input=encoder_projection_dropout, + pos_conv_kernel=encoder_pos_conv_kernel, + pos_conv_groups=encoder_pos_conv_groups, + num_layers=encoder_num_layers, + num_heads=encoder_num_heads, + attention_dropout=encoder_attention_dropout, + ff_interm_features=encoder_ff_interm_features, + ff_interm_dropout=encoder_ff_interm_dropout, + dropout=encoder_dropout, + layer_norm_first=encoder_layer_norm_first, + layer_drop=encoder_layer_drop, + ) + aux = None + if aux_num_out is not None: + aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out) + return Wav2Vec2Model(feature_extractor, encoder, aux) + + +def wav2vec2_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "base" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def wav2vec2_large( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "large" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def wav2vec2_large_lv60k( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "large lv-60k" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=True, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def hubert_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.05, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "base" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def hubert_large( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "large" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def hubert_xlarge( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "extra large" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1280, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=48, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=5120, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def _init_hubert_pretrain_model(module): + if isinstance(module, components.ConvLayerBlock): + torch.nn.init.kaiming_normal_(module.conv.weight) + elif isinstance(module, components.ConvolutionalPositionalEmbedding): + # normalize the weight to normal distribution. + std = math.sqrt(4.0 / (module.embed_dim * module.kernel_size)) + torch.nn.init.normal_(module.conv.weight, mean=0.0, std=std) + torch.nn.init.constant_(module.conv.bias, 0.0) + elif isinstance(module, components.SelfAttention): + # normalize the query, key, value, and out_proj parameters in self attention module. 
+ torch.nn.init.xavier_uniform_(module.k_proj.weight, gain=1 / math.sqrt(2)) + torch.nn.init.xavier_uniform_(module.v_proj.weight, gain=1 / math.sqrt(2)) + torch.nn.init.xavier_uniform_(module.q_proj.weight, gain=1 / math.sqrt(2)) + torch.nn.init.xavier_uniform_(module.out_proj.weight) + torch.nn.init.constant_(module.out_proj.bias, 0.0) + elif isinstance(module, components.Transformer): + module.apply(components._init_transformer_params) + else: + pass + + +def hubert_pretrain_model( + extractor_mode: str, + extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], + extractor_conv_bias: bool, + encoder_embed_dim: int, + encoder_projection_dropout: float, + encoder_pos_conv_kernel: int, + encoder_pos_conv_groups: int, + encoder_num_layers: int, + encoder_num_heads: int, + encoder_attention_dropout: float, + encoder_ff_interm_features: int, + encoder_ff_interm_dropout: float, + encoder_dropout: float, + encoder_layer_norm_first: bool, + encoder_layer_drop: float, + mask_prob: float, + mask_selection: str, + mask_other: float, + mask_length: int, + no_mask_overlap: bool, + mask_min_space: int, + mask_channel_prob: float, + mask_channel_selection: str, + mask_channel_other: float, + mask_channel_length: int, + no_mask_channel_overlap: bool, + mask_channel_min_space: int, + skip_masked: bool, + skip_nomask: bool, + num_classes: int, + final_dim: int, + feature_grad_mult: Optional[float], +) -> HuBERTPretrainModel: + """Builds custom :class:`HuBERTPretrainModel` for training from scratch + + Note: + The "feature extractor" below corresponds to + `ConvFeatureExtractionModel `__ + in the original ``fairseq`` implementation. + This is referred as "(convolutional) feature encoder" in the *wav2vec 2.0* + :cite:`baevski2020wav2vec` paper. + + The "encoder" below corresponds to `TransformerEncoder `__, + and this is referred as "Transformer" in the paper. + + Args: + extractor_mode (str): Operation mode of feature extractor. 
+ Valid values are ``"group_norm"`` or ``"layer_norm"``. + If ``"group_norm"``, then a single normalization is applied + in the first convolution block. Otherwise, all the convolution + blocks will have layer normalization. + + This option corresponds to ``extractor_mode`` from ``fairseq``. + + extractor_conv_layer_config (list of integer tuples or None): + Configuration of convolution layers in feature extractor. + List of convolution configuration, + i.e. ``[(output_channel, kernel_size, stride), ...]`` + + If ``None`` is provided, then the following default value is used. + + .. code-block:: python + + [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ] + + This option corresponds to ``conv_feature_layers`` from ``fairseq``. + + extractor_conv_bias (bool): + Whether to include bias term to each convolution operation. + + This option corresponds to ``conv_bias`` from ``fairseq``. + + encoder_embed_dim (int): + The dimension of embedding in encoder. + + This option corresponds to ``encoder_embed_dim`` from ``fairseq``. + + encoder_projection_dropout (float): + The dropout probability applied after the input feature is projected + to ``encoder_embed_dim``. + + This option corresponds to ``dropout_input`` from ``fairseq``. + + encoder_pos_conv_kernel (int): + The kernel size of convolutional positional embeddings. + + This option corresponds to ``conv_pos`` from ``fairseq``. + + encoder_pos_conv_groups (int): + The number of groups of convolutional positional embeddings. + + This option corresponds to ``conv_pos_groups`` from ``fairseq``. + + encoder_num_layers (int): + The number of self attention layers in transformer block. + + This option corresponds to ``encoder_layers`` from ``fairseq``. + + encoder_num_heads (int): + The number of heads in self attention layers. + + This option corresponds to ``encoder_attention_heads`` from ``fairseq``. 
+ + encoder_attention_dropout (float): + The dropout probability applied after softmax in self-attention layer. + + This option corresponds to ``attention_dropout`` from ``fairseq``. + + encoder_ff_interm_features (int): + The dimension of hidden features in feed forward layer. + + This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``. + + encoder_ff_interm_dropout (float): + The dropout probability applied in feedforward layer. + + This option correspinds to ``activation_dropout`` from ``fairseq``. + + encoder_dropout (float): + The dropout probability applied at the end of feed forward layer. + + This option corresponds to ``dropout`` from ``fairseq``. + + encoder_layer_norm_first (bool): + Control the order of layer norm in transformer layer and each encoder layer. + If True, in transformer layer, layer norm is applied before features are fed + to encoder layers. In encoder layer, two layer norms are applied before and after + self attention. + If False, in transformer layer, layer norm is applied after features are fed + to encoder layers. In encoder layer, two layer norms are applied after self + attention, before and after feed forward. + + This option corresponds to ``layer_norm_first`` from ``fairseq``. + + encoder_layer_drop (float): + Probability to drop each encoder layer during training. + + This option corresponds to ``layerdrop`` from ``fairseq``. + + mask_prob (float): + Probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + However due to overlaps, the actual number will be smaller (unless no_overlap is True). + + This option corresponds to ``mask_prob`` from ``fairseq``. + + mask_selection (str): + How to choose the mask length. Options: [``static``, ``uniform``, ``normal``, ``poisson``]. + + This option corresponds to ``mask_selection`` from ``fairseq``. 
+ + mask_other (float): + Secondary mask argument (used for more complex distributions). + + This option corresponds to ``mask_other`` from ``fairseq``. + + mask_length (int): + The lengths of the mask. + + This option corresponds to ``mask_length`` from ``fairseq``. + + no_mask_overlap (bool): + Whether to allow masks to overlap. + + This option corresponds to ``no_mask_overlap`` from ``fairseq``. + + mask_min_space (int): + Minimum space between spans (if no overlap is enabled). + + This option corresponds to ``mask_min_space`` from ``fairseq``. + + mask_channel_prob: (float): + The probability of replacing a feature with 0. + + This option corresponds to ``mask_channel_prob`` from ``fairseq``. + + mask_channel_selection (str): + How to choose the mask length for channel masking. Options: [``static``, ``uniform``, ``normal``, ``poisson``]. + + This option corresponds to ``mask_channel_selection`` from ``fairseq``. + + mask_channel_other (float): + Secondary mask argument for channel masking(used for more complex distributions). + + This option corresponds to ``mask_channel_other`` from ``fairseq``. + + mask_channel_length (int): + Minimum space between spans (if no overlap is enabled) for channel masking. + + This option corresponds to ``mask_channel_length`` from ``fairseq``. + + no_mask_channel_overlap (bool): + Whether to allow channel masks to overlap. + + This option corresponds to ``no_mask_channel_overlap`` from ``fairseq``. + + mask_channel_min_space (int): + Minimum space between spans for channel masking(if no overlap is enabled). + + This option corresponds to ``mask_channel_min_space`` from ``fairseq``. + + skip_masked (bool): + If True, skip computing losses over masked frames. + + This option corresponds to ``skip_masked`` from ``fairseq``. + + skip_nomask (bool): + If True, skip computing losses over unmasked frames. + + This option corresponds to ``skip_nomask`` from ``fairseq``. + + num_classes (int): + The number of classes in the labels. 
+ + final_dim (int): + Project final representations and targets to `final_dim`. + + This option corresponds to ``final_dim`` from ``fairseq``. + + feature_grad_mult (float or None): + The factor to scale the convolutional feature extraction layer gradients by. + The scale factor will not affect the forward pass. + + This option corresponds to ``feature_grad_mult`` from ``fairseq``. + + Returns: + HuBERTPretrainModel: + The resulting model. + """ # noqa: E501 + if extractor_conv_layer_config is None: + extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 + + feature_extractor = components._get_feature_extractor( + extractor_mode, extractor_conv_layer_config, extractor_conv_bias + ) + encoder = components._get_encoder( + in_features=extractor_conv_layer_config[-1][0], + embed_dim=encoder_embed_dim, + dropout_input=encoder_projection_dropout, + pos_conv_kernel=encoder_pos_conv_kernel, + pos_conv_groups=encoder_pos_conv_groups, + num_layers=encoder_num_layers, + num_heads=encoder_num_heads, + attention_dropout=encoder_attention_dropout, + ff_interm_features=encoder_ff_interm_features, + ff_interm_dropout=encoder_ff_interm_dropout, + dropout=encoder_dropout, + layer_norm_first=encoder_layer_norm_first, + layer_drop=encoder_layer_drop, + ) + wav2vec2 = Wav2Vec2Model(feature_extractor, encoder) + mask_generator = components.MaskGenerator( + encoder_embed_dim, + mask_prob, + mask_selection, + mask_other, + mask_length, + no_mask_overlap, + mask_min_space, + mask_channel_prob, + mask_channel_selection, + mask_channel_other, + mask_channel_length, + no_mask_channel_overlap, + mask_channel_min_space, + ) + logit_generator = components.LogitGenerator( + encoder_embed_dim, + num_classes, + final_dim, + skip_masked, + skip_nomask, + ) + model = HuBERTPretrainModel( + wav2vec2=wav2vec2, + mask_generator=mask_generator, + logit_generator=logit_generator, + feature_grad_mult=feature_grad_mult, + ) + # initialize the model for pre-training + 
model.apply(_init_hubert_pretrain_model) + return model + + +def hubert_pretrain_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.05, + mask_prob: float = 0.8, + mask_channel_prob: float = 0.0, + mask_channel_length: int = 10, + feature_grad_mult: Optional[float] = 0.1, + num_classes: int = 100, +) -> HuBERTPretrainModel: + """Builds "base" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining. + + Args: + encoder_projection_dropout (float): + See :py:func:`hubert_pretrain_model`. + encoder_attention_dropout (float): + See :py:func:`hubert_pretrain_model`. + encoder_ff_interm_dropout (float): + See :py:func:`hubert_pretrain_model`. + encoder_dropout (float): + See :py:func:`hubert_pretrain_model`. + encoder_layer_drop (float): + See :py:func:`hubert_pretrain_model`. + mask_prob (float): + See :py:func:`hubert_pretrain_model`. + mask_channel_prob (float): + See :py:func:`hubert_pretrain_model`. + mask_channel_length (int): + See :py:func:`hubert_pretrain_model`. + feature_grad_mult (float or None): + See :py:func:`hubert_pretrain_model`. + num_classes (int, optional): + See :py:func:`hubert_pretrain_model`. + + Returns: + HuBERTPretrainModel: + The resulting model. 
+ """ # noqa: E501 + return hubert_pretrain_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + mask_prob=mask_prob, + mask_selection="static", + mask_other=0.0, + mask_length=10, + no_mask_overlap=False, + mask_min_space=1, + mask_channel_prob=mask_channel_prob, + mask_channel_selection="static", + mask_channel_other=0.0, + mask_channel_length=mask_channel_length, + no_mask_channel_overlap=False, + mask_channel_min_space=1, + skip_masked=False, + skip_nomask=False, + num_classes=num_classes, + final_dim=256, + feature_grad_mult=feature_grad_mult, + ) + + +def hubert_pretrain_large( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + mask_prob: float = 0.8, + mask_channel_prob: float = 0.0, + mask_channel_length: int = 10, + feature_grad_mult: Optional[float] = None, +) -> HuBERTPretrainModel: + """Builds "large" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining. + + Args: + encoder_projection_dropout (float): + See :py:func:`hubert_pretrain_model`. + encoder_attention_dropout (float): + See :py:func:`hubert_pretrain_model`. + encoder_ff_interm_dropout (float): + See :py:func:`hubert_pretrain_model`. + encoder_dropout (float): + See :py:func:`hubert_pretrain_model`. + encoder_layer_drop (float): + See :py:func:`hubert_pretrain_model`. + mask_prob (float): + See :py:func:`hubert_pretrain_model`. 
+ mask_channel_prob (float): + See :py:func:`hubert_pretrain_model`. + mask_channel_length (int): + See :py:func:`hubert_pretrain_model`. + feature_grad_mult (float or None): + See :py:func:`hubert_pretrain_model`. + + Returns: + HuBERTPretrainModel: + The resulting model. + """ # noqa: E501 + return hubert_pretrain_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + mask_prob=mask_prob, + mask_selection="static", + mask_other=0.0, + mask_length=10, + no_mask_overlap=False, + mask_min_space=1, + mask_channel_prob=mask_channel_prob, + mask_channel_selection="static", + mask_channel_other=0.0, + mask_channel_length=mask_channel_length, + no_mask_channel_overlap=False, + mask_channel_min_space=1, + skip_masked=False, + skip_nomask=False, + num_classes=500, + final_dim=768, + feature_grad_mult=feature_grad_mult, + ) + + +def hubert_pretrain_xlarge( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + mask_prob: float = 0.8, + mask_channel_prob: float = 0.0, + mask_channel_length: int = 10, + feature_grad_mult: Optional[float] = None, +) -> HuBERTPretrainModel: + """Builds "extra large" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining. + + Args: + encoder_projection_dropout (float): + See :py:func:`hubert_pretrain_model`. + encoder_attention_dropout (float): + See :py:func:`hubert_pretrain_model`. 
+ encoder_ff_interm_dropout (float): + See :py:func:`hubert_pretrain_model`. + encoder_dropout (float): + See :py:func:`hubert_pretrain_model`. + encoder_layer_drop (float): + See :py:func:`hubert_pretrain_model`. + mask_prob (float): + See :py:func:`hubert_pretrain_model`. + mask_channel_prob (float): + See :py:func:`hubert_pretrain_model`. + mask_channel_length (int): + See :py:func:`hubert_pretrain_model`. + feature_grad_mult (float or None): + See :py:func:`hubert_pretrain_model`. + + Returns: + HuBERTPretrainModel: + The resulting model. + """ # noqa: E501 + return hubert_pretrain_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1280, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=48, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=5120, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + mask_prob=mask_prob, + mask_selection="static", + mask_other=0.0, + mask_length=10, + no_mask_overlap=False, + mask_min_space=1, + mask_channel_prob=mask_channel_prob, + mask_channel_selection="static", + mask_channel_other=0.0, + mask_channel_length=mask_channel_length, + no_mask_channel_overlap=False, + mask_channel_min_space=1, + skip_masked=False, + skip_nomask=False, + num_classes=500, + final_dim=1024, + feature_grad_mult=feature_grad_mult, + ) + + +def wavlm_model( + extractor_mode: str, + extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], + extractor_conv_bias: bool, + encoder_embed_dim: int, + encoder_projection_dropout: float, + encoder_pos_conv_kernel: int, + encoder_pos_conv_groups: int, + encoder_num_layers: int, + encoder_num_heads: int, + encoder_num_buckets: int, + encoder_max_distance: int, + 
encoder_attention_dropout: float, + encoder_ff_interm_features: int, + encoder_ff_interm_dropout: float, + encoder_dropout: float, + encoder_layer_norm_first: bool, + encoder_layer_drop: float, + aux_num_out: Optional[int], +) -> Wav2Vec2Model: + """Builds custom WaveLM model :cite:`chen2022wavlm`. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output object is + :class:`~torchaudio.models.Wav2Vec2Model`. Most of the arguments have the same meaning + as in :py:func:`~torchaudio.models.wav2vec2_model` so please refer there for documentation. + + Args: + extractor_mode (str): Operation mode of feature extractor. + See :py:func:`~torchaudio.models.wav2vec2_model`. + + extractor_conv_layer_config (list of integer tuples or None): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + extractor_conv_bias (bool): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + encoder_embed_dim (int): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + encoder_projection_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + encoder_pos_conv_kernel (int): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + encoder_pos_conv_groups (int): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + encoder_num_layers (int): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + encoder_num_heads (int): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + encoder_num_buckets (int): + Number of buckets for relative position embedding. + encoder_max_distance (int): + Maximum distance for relative position embedding. + + encoder_attention_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + encoder_ff_interm_features (int): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + encoder_ff_interm_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + encoder_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. 
+ + encoder_layer_norm_first (bool): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + encoder_layer_drop (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + aux_num_out (int or None): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ + if extractor_conv_layer_config is None: + extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 + + feature_extractor = components._get_feature_extractor( + extractor_mode, extractor_conv_layer_config, extractor_conv_bias + ) + encoder = components._get_wavlm_encoder( + in_features=extractor_conv_layer_config[-1][0], + embed_dim=encoder_embed_dim, + dropout_input=encoder_projection_dropout, + pos_conv_kernel=encoder_pos_conv_kernel, + pos_conv_groups=encoder_pos_conv_groups, + num_layers=encoder_num_layers, + num_heads=encoder_num_heads, + num_buckets=encoder_num_buckets, + max_distance=encoder_max_distance, + attention_dropout=encoder_attention_dropout, + ff_interm_features=encoder_ff_interm_features, + ff_interm_dropout=encoder_ff_interm_dropout, + dropout=encoder_dropout, + layer_norm_first=encoder_layer_norm_first, + layer_drop=encoder_layer_drop, + ) + aux = None + if aux_num_out is not None: + aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out) + return Wav2Vec2Model(feature_extractor, encoder, aux) + + +def wavlm_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "base" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is + :class:`~torchaudio.models.Wav2Vec2Model`. + + Args: + encoder_projection_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. 
+ encoder_attention_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + aux_num_out (int, optional): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ + return wavlm_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_num_buckets=320, + encoder_max_distance=800, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def wavlm_large( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "large" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is + :class:`~torchaudio.models.Wav2Vec2Model`. + + Args: + encoder_projection_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. 
+ encoder_layer_drop (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + aux_num_out (int, optional): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ + return wavlm_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_num_buckets=320, + encoder_max_distance=800, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def wav2vec2_xlsr_300m( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds XLS-R model :cite:`babu2021xls` with 300 millions of parameters. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is + :class:`~torchaudio.models.Wav2Vec2Model`. + + Args: + encoder_projection_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + aux_num_out (int, optional): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=True, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def wav2vec2_xlsr_1b( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds XLS-R model :cite:`babu2021xls` with 1 billion of parameters. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is + :class:`~torchaudio.models.Wav2Vec2Model`. + + Args: + encoder_projection_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + aux_num_out (int, optional): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=True, + encoder_embed_dim=1280, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=48, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=5120, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def wav2vec2_xlsr_2b( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds XLS-R model :cite:`babu2021xls` with 2 billions of parameters. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is + :class:`~torchaudio.models.Wav2Vec2Model`. + + Args: + encoder_projection_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_dropout (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`~torchaudio.models.wav2vec2_model`. + aux_num_out (int, optional): + See :py:func:`~torchaudio.models.wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=True, + encoder_embed_dim=1920, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=48, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=7680, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0457b5dd707f7216adc3ea919ba8e257d86f4f71 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__init__.py @@ -0,0 +1,7 @@ +from .import_fairseq import import_fairseq_model +from .import_huggingface import import_huggingface_model + +__all__ = [ + "import_huggingface_model", + "import_fairseq_model", +] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddc540df7cf6c03d578531fb4b65c7050cfc694c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_fairseq.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_fairseq.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..4b1f274464d491a2e5bf5f88363bb34dcacff448 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_fairseq.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_huggingface.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_huggingface.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64a25adf8ac816e3492911a88410a9594fde81d0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_huggingface.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/wavlm_attention.py b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/wavlm_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..fafddfeb958cbcdfdc0a7781b49bc124fff78290 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/wavlm_attention.py @@ -0,0 +1,214 @@ +""" +The MIT License (MIT) + +Copyright (c) Microsoft Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import math +from typing import Optional, Tuple + +import torch +from torch import nn, Tensor + + +class WavLMSelfAttention(nn.Module): + """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`. + Wraps around ``torch.nn.MultiheadAttention``, creating relaive position embeddings and passing them to multi-headed + attention as a mask. + Source: https://github.com/microsoft/unilm/blob/2d8302f09c99bca2b82e6e868d81d4281cceebc8/wavlm/modules.py#L303-L763 + + Args: + embed_dim (int): Total dimension of the model. + num_heads (int): The number of heads. + dropout (float, optional): Dropout probability on attn_output_weights. (Default: to ``0.0``) + bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``) + has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding. + Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``) + num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``) + max_distance (int, optional): Naximum distance for relative position embedding. (Default: ``128``) + gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. 
(Default: ``False``) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + has_relative_attention_bias: bool = False, + num_buckets: int = 32, + max_distance: int = 128, + gru_rel_pos: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + + if has_relative_attention_bias: + self.rel_attn_embed = nn.Embedding(num_buckets, num_heads) + else: + self.rel_attn_embed = None + + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + self.dropout = dropout + self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True) + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8) + self.gru_rel_pos_const = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + self.has_position_bias = True + + def compute_bias(self, query_length: int, key_length: int) -> Tensor: + """Compute relative position embeddings for WavLM model. + Args: + query_length (int): Query position can take values between 0 and ``query_length - 1``. + key_length (int): Key position can take values between 0 and ``key_length - 1``. 
+ Returns: + Tensor of shape `(num_heads, query_length, key_length)`, relative positions embeddings + """ + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # Shape (query_length, key_length) + relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True) + relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device) + values = self.rel_attn_embed(relative_position_bucket) # Shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]) + return values + + def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True): + """Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM + paper :cite:`chen2022wavlm`. + Args: + relative_positions (Tensor): Relative offsets between query and key positions, + of shape ``(query_length, key_length)``. + bidirectional (bool): If ``True``, values will be filled both above and below the diagonal in the resulting + matrix. If ``False``, the elements above the diagonal (i.e. with negative relative offsets) will be set + to zero. (Default ``True``) + Returns: + Tensor of shape ``(query_length, key_length)`` filled bucketed values of with relative positions. 
+ """ + num_buckets = self.num_buckets + max_distance = self.max_distance + # Shape (query_length, key_length) + relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long) + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def forward( + self, + query: Tensor, + key_padding_mask: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``. + key_padding_mask (Tensor or None, optional): Mask to exclude keys that are pads, of shape + `(batch, src_len)`, where padding elements are indicated by 1s. (Default: ``None``) + attn_mask: Needs to be ``None``. The argument exists for compatibility with + ``EncoderLayer``. (Default: ``None``) + position_bias (Tensor or None, optional): Position bias of shape + ``(batch_size * num_heads, src_len, src_len)``. When used inside WavLM model encoder, will be + generated in the first layer and then passed from each encoder layer to the next one. + (Default: ``None``) + Returns: + attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``. 
+ position_bias (Tensor or None): Position bias of shape ``(batch_size * num_heads, src_len, src_len)``. + """ + bsz, seq_len, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert attention_mask is None + + if self.rel_attn_embed is not None and position_bias is None: + position_bias = self.compute_bias(seq_len, seq_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1) + + attn_mask_rel_pos: Optional[Tensor] = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: # Apply gating on relative position bias + query_layer = query.view(bsz, seq_len, self.num_heads, -1) + query_layer = query_layer.permute(0, 2, 1, 3) + + gate_a, gate_b = torch.sigmoid( + self.gru_rel_pos_linear(query_layer).view(bsz, self.num_heads, seq_len, 2, 4).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz, self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((bsz, self.num_heads, seq_len, seq_len)) + + if attn_mask_rel_pos is not None and key_padding_mask is not None: + key_padding_mask = key_padding_mask.view(bsz, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1) + key_padding_mask = torch.nn.functional._canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=torch.nn.functional._none_or_dtype(attn_mask_rel_pos), + other_name="", + target_type=query.dtype, + ) + if attn_mask_rel_pos is not None and key_padding_mask is not None: + attn_mask_rel_pos = attn_mask_rel_pos + key_padding_mask + query_projected = torch.nn.functional.linear(query, self.attention.in_proj_weight, self.attention.in_proj_bias) + query, key, value = query_projected.chunk(3, -1) + shape = (bsz, seq_len, self.num_heads, self.head_dim) + query = query.view(shape).transpose(2, 1) # (batch, num_heads, seq_len, head_dim) + key = key.view(shape).transpose(2, 1) # (batch, num_heads, seq_len, 
head_dim) + value = value.view(shape).transpose(2, 1) # (batch, num_heads, seq_len, head_dim) + dropout = self.dropout if self.training else 0.0 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask_rel_pos, + dropout_p=dropout, + is_causal=False, + ) + attn_output = attn_output.transpose(1, 2).reshape(bsz, -1, self.num_heads * self.head_dim) + attn_output = self.attention.out_proj(attn_output) + return attn_output, position_bias diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wavernn.py b/.venv/lib/python3.11/site-packages/torchaudio/models/wavernn.py new file mode 100644 index 0000000000000000000000000000000000000000..8ae5a3e91675cd9ef7d4614f0daaec50f80dcdee --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/wavernn.py @@ -0,0 +1,409 @@ +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +__all__ = [ + "ResBlock", + "MelResNet", + "Stretch2d", + "UpsampleNetwork", + "WaveRNN", +] + + +class ResBlock(nn.Module): + r"""ResNet block based on *Efficient Neural Audio Synthesis* :cite:`kalchbrenner2018efficient`. + + Args: + n_freq: the number of bins in a spectrogram. (Default: ``128``) + + Examples + >>> resblock = ResBlock() + >>> input = torch.rand(10, 128, 512) # a random spectrogram + >>> output = resblock(input) # shape: (10, 128, 512) + """ + + def __init__(self, n_freq: int = 128) -> None: + super().__init__() + + self.resblock_model = nn.Sequential( + nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), + nn.BatchNorm1d(n_freq), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), + nn.BatchNorm1d(n_freq), + ) + + def forward(self, specgram: Tensor) -> Tensor: + r"""Pass the input through the ResBlock layer. 
+ Args: + specgram (Tensor): the input sequence to the ResBlock layer (n_batch, n_freq, n_time). + + Return: + Tensor shape: (n_batch, n_freq, n_time) + """ + + return self.resblock_model(specgram) + specgram + + +class MelResNet(nn.Module): + r"""MelResNet layer uses a stack of ResBlocks on spectrogram. + + Args: + n_res_block: the number of ResBlock in stack. (Default: ``10``) + n_freq: the number of bins in a spectrogram. (Default: ``128``) + n_hidden: the number of hidden dimensions of resblock. (Default: ``128``) + n_output: the number of output dimensions of melresnet. (Default: ``128``) + kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``) + + Examples + >>> melresnet = MelResNet() + >>> input = torch.rand(10, 128, 512) # a random spectrogram + >>> output = melresnet(input) # shape: (10, 128, 508) + """ + + def __init__( + self, n_res_block: int = 10, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128, kernel_size: int = 5 + ) -> None: + super().__init__() + + ResBlocks = [ResBlock(n_hidden) for _ in range(n_res_block)] + + self.melresnet_model = nn.Sequential( + nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False), + nn.BatchNorm1d(n_hidden), + nn.ReLU(inplace=True), + *ResBlocks, + nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1), + ) + + def forward(self, specgram: Tensor) -> Tensor: + r"""Pass the input through the MelResNet layer. + Args: + specgram (Tensor): the input sequence to the MelResNet layer (n_batch, n_freq, n_time). + + Return: + Tensor shape: (n_batch, n_output, n_time - kernel_size + 1) + """ + + return self.melresnet_model(specgram) + + +class Stretch2d(nn.Module): + r"""Upscale the frequency and time dimensions of a spectrogram. 
+ + Args: + time_scale: the scale factor in time dimension + freq_scale: the scale factor in frequency dimension + + Examples + >>> stretch2d = Stretch2d(time_scale=10, freq_scale=5) + + >>> input = torch.rand(10, 100, 512) # a random spectrogram + >>> output = stretch2d(input) # shape: (10, 500, 5120) + """ + + def __init__(self, time_scale: int, freq_scale: int) -> None: + super().__init__() + + self.freq_scale = freq_scale + self.time_scale = time_scale + + def forward(self, specgram: Tensor) -> Tensor: + r"""Pass the input through the Stretch2d layer. + + Args: + specgram (Tensor): the input sequence to the Stretch2d layer (..., n_freq, n_time). + + Return: + Tensor shape: (..., n_freq * freq_scale, n_time * time_scale) + """ + + return specgram.repeat_interleave(self.freq_scale, -2).repeat_interleave(self.time_scale, -1) + + +class UpsampleNetwork(nn.Module): + r"""Upscale the dimensions of a spectrogram. + + Args: + upsample_scales: the list of upsample scales. + n_res_block: the number of ResBlock in stack. (Default: ``10``) + n_freq: the number of bins in a spectrogram. (Default: ``128``) + n_hidden: the number of hidden dimensions of resblock. (Default: ``128``) + n_output: the number of output dimensions of melresnet. (Default: ``128``) + kernel_size: the number of kernel size in the first Conv1d layer. 
(Default: ``5``) + + Examples + >>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16]) + >>> input = torch.rand(10, 128, 10) # a random spectrogram + >>> output = upsamplenetwork(input) # shape: (10, 128, 1536), (10, 128, 1536) + """ + + def __init__( + self, + upsample_scales: List[int], + n_res_block: int = 10, + n_freq: int = 128, + n_hidden: int = 128, + n_output: int = 128, + kernel_size: int = 5, + ) -> None: + super().__init__() + + total_scale = 1 + for upsample_scale in upsample_scales: + total_scale *= upsample_scale + self.total_scale: int = total_scale + + self.indent = (kernel_size - 1) // 2 * total_scale + self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) + self.resnet_stretch = Stretch2d(total_scale, 1) + + up_layers = [] + for scale in upsample_scales: + stretch = Stretch2d(scale, 1) + conv = nn.Conv2d( + in_channels=1, out_channels=1, kernel_size=(1, scale * 2 + 1), padding=(0, scale), bias=False + ) + torch.nn.init.constant_(conv.weight, 1.0 / (scale * 2 + 1)) + up_layers.append(stretch) + up_layers.append(conv) + self.upsample_layers = nn.Sequential(*up_layers) + + def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]: + r"""Pass the input through the UpsampleNetwork layer. + + Args: + specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time) + + Return: + Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale), + (n_batch, n_output, (n_time - kernel_size + 1) * total_scale) + where total_scale is the product of all elements in upsample_scales. 
+ """ + + resnet_output = self.resnet(specgram).unsqueeze(1) + resnet_output = self.resnet_stretch(resnet_output) + resnet_output = resnet_output.squeeze(1) + + specgram = specgram.unsqueeze(1) + upsampling_output = self.upsample_layers(specgram) + upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent : -self.indent] + + return upsampling_output, resnet_output + + +class WaveRNN(nn.Module): + r"""WaveRNN model from *Efficient Neural Audio Synthesis* :cite:`wavernn` + based on the implementation from `fatchord/WaveRNN `_. + + The original implementation was introduced in *Efficient Neural Audio Synthesis* + :cite:`kalchbrenner2018efficient`. The input channels of waveform and spectrogram have to be 1. + The product of `upsample_scales` must equal `hop_length`. + + See Also: + * `Training example `__ + * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model. + + Args: + upsample_scales: the list of upsample scales. + n_classes: the number of output classes. + hop_length: the number of samples between the starts of consecutive frames. + n_res_block: the number of ResBlock in stack. (Default: ``10``) + n_rnn: the dimension of RNN layer. (Default: ``512``) + n_fc: the dimension of fully connected layer. (Default: ``512``) + kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``) + n_freq: the number of bins in a spectrogram. (Default: ``128``) + n_hidden: the number of hidden dimensions of resblock. (Default: ``128``) + n_output: the number of output dimensions of melresnet. 
(Default: ``128``) + + Example + >>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200) + >>> waveform, sample_rate = torchaudio.load(file) + >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length) + >>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time) + >>> output = wavernn(waveform, specgram) + >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes) + """ + + def __init__( + self, + upsample_scales: List[int], + n_classes: int, + hop_length: int, + n_res_block: int = 10, + n_rnn: int = 512, + n_fc: int = 512, + kernel_size: int = 5, + n_freq: int = 128, + n_hidden: int = 128, + n_output: int = 128, + ) -> None: + super().__init__() + + self.kernel_size = kernel_size + self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2 + self.n_rnn = n_rnn + self.n_aux = n_output // 4 + self.hop_length = hop_length + self.n_classes = n_classes + self.n_bits: int = int(math.log2(self.n_classes)) + + total_scale = 1 + for upsample_scale in upsample_scales: + total_scale *= upsample_scale + if total_scale != self.hop_length: + raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}") + + self.upsample = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size) + self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn) + + self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True) + self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True) + + self.relu1 = nn.ReLU(inplace=True) + self.relu2 = nn.ReLU(inplace=True) + + self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc) + self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc) + self.fc3 = nn.Linear(n_fc, self.n_classes) + + def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor: + r"""Pass the input through the WaveRNN model. 
+ + Args: + waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length) + specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time) + + Return: + Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes) + """ + + if waveform.size(1) != 1: + raise ValueError("Require the input channel of waveform is 1") + if specgram.size(1) != 1: + raise ValueError("Require the input channel of specgram is 1") + # remove channel dimension until the end + waveform, specgram = waveform.squeeze(1), specgram.squeeze(1) + + batch_size = waveform.size(0) + h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device) + h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device) + # output of upsample: + # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale) + # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale) + specgram, aux = self.upsample(specgram) + specgram = specgram.transpose(1, 2) + aux = aux.transpose(1, 2) + + aux_idx = [self.n_aux * i for i in range(5)] + a1 = aux[:, :, aux_idx[0] : aux_idx[1]] + a2 = aux[:, :, aux_idx[1] : aux_idx[2]] + a3 = aux[:, :, aux_idx[2] : aux_idx[3]] + a4 = aux[:, :, aux_idx[3] : aux_idx[4]] + + x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1) + x = self.fc(x) + res = x + x, _ = self.rnn1(x, h1) + + x = x + res + res = x + x = torch.cat([x, a2], dim=-1) + x, _ = self.rnn2(x, h2) + + x = x + res + x = torch.cat([x, a3], dim=-1) + x = self.fc1(x) + x = self.relu1(x) + + x = torch.cat([x, a4], dim=-1) + x = self.fc2(x) + x = self.relu2(x) + x = self.fc3(x) + + # bring back channel dimension + return x.unsqueeze(1) + + @torch.jit.export + def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: + r"""Inference method of WaveRNN. 
+ + This function currently only supports multinomial sampling, which assumes the + network is trained on cross entropy loss. + + Args: + specgram (Tensor): + Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`. + lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `(batch, )`. + When the ``specgram`` contains spectrograms with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths. + If ``None``, it is assumed that all the audio in ``waveforms`` + have valid length. Default: ``None``. + + Returns: + (Tensor, Optional[Tensor]): + Tensor + The inferred waveform of size `(n_batch, 1, n_time)`. + 1 stands for a single channel. + Tensor or None + If ``lengths`` argument was provided, a Tensor of shape `(batch, )` + is returned. + It indicates the valid length in time axis of the output Tensor. + """ + + device = specgram.device + dtype = specgram.dtype + + specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad)) + specgram, aux = self.upsample(specgram) + if lengths is not None: + lengths = lengths * self.upsample.total_scale + + output: List[Tensor] = [] + b_size, _, seq_len = specgram.size() + + h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype) + h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype) + x = torch.zeros((b_size, 1), device=device, dtype=dtype) + + aux_split = [aux[:, self.n_aux * i : self.n_aux * (i + 1), :] for i in range(4)] + + for i in range(seq_len): + + m_t = specgram[:, :, i] + + a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split] + + x = torch.cat([x, m_t, a1_t], dim=1) + x = self.fc(x) + _, h1 = self.rnn1(x.unsqueeze(1), h1) + + x = x + h1[0] + inp = torch.cat([x, a2_t], dim=1) + _, h2 = self.rnn2(inp.unsqueeze(1), h2) + + x = x + h2[0] + x = torch.cat([x, a3_t], dim=1) + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4_t], dim=1) + x = F.relu(self.fc2(x)) + + logits = 
self.fc3(x) + + posterior = F.softmax(logits, dim=1) + + x = torch.multinomial(posterior, 1).float() + # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1] + x = 2 * x / (2**self.n_bits - 1.0) - 1.0 + + output.append(x) + + return torch.stack(output).permute(1, 2, 0), lengths diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38d09a3023000968b7c001bc5a223cd1382678e5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0e4a6194f48027caf10f3dcbbada53719a14d4a2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__init__.py @@ -0,0 +1,4 @@ +from .musan import Musan + + +__all__ = ["Musan"] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7d7df85469ddc957c37aebc9b183f174eaa77c2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/musan.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/musan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0b2a03d01c8b22e238f0e340c1c5f624abb190e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/musan.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/musan.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/musan.py new file mode 100644 index 0000000000000000000000000000000000000000..c4592bb3e4097f51064bfac01467873ba7263ec8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/musan.py @@ -0,0 +1,67 @@ +from pathlib import Path +from typing import Tuple, Union + +import torch +from torch.utils.data import Dataset +from torchaudio.datasets.utils import _load_waveform + + +_SUBSETS = ["music", "noise", "speech"] +_SAMPLE_RATE = 16_000 + + +class Musan(Dataset): + r"""*MUSAN* :cite:`musan2015` dataset. + + Args: + root (str or Path): Root directory where the dataset's top-level directory exists. + subset (str): Subset of the dataset to use. Options: [``"music"``, ``"noise"``, ``"speech"``]. + """ + + def __init__(self, root: Union[str, Path], subset: str): + if subset not in _SUBSETS: + raise ValueError(f"Invalid subset '{subset}' given. Please provide one of {_SUBSETS}") + + subset_path = Path(root) / subset + self._walker = [str(p) for p in subset_path.glob("*/*.*")] + + def get_metadata(self, n: int) -> Tuple[str, int, str]: + r"""Get metadata for the n-th sample in the dataset. Returns filepath instead of waveform, + but otherwise returns the same fields as :py:func:`__getitem__`. + + Args: + n (int): Index of sample to be loaded. + + Returns: + (str, int, str): + str + Path to audio. + int + Sample rate. + str + File name. 
+ """ + audio_path = self._walker[n] + return audio_path, _SAMPLE_RATE, Path(audio_path).name + + def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str]: + r"""Return the n-th sample in the dataset. + + Args: + n (int): Index of sample to be loaded. + + Returns: + (torch.Tensor, int, str): + torch.Tensor + Waveform. + int + Sample rate. + str + File name. + """ + audio_path, sample_rate, filename = self.get_metadata(n) + path = Path(audio_path) + return _load_waveform(path.parent, path.name, sample_rate), sample_rate, filename + + def __len__(self) -> int: + return len(self._walker) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20bc181731eba87faeb77a36e7e1cdce4101f496 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__init__.py @@ -0,0 +1,26 @@ +from ._dsp import ( + adsr_envelope, + exp_sigmoid, + extend_pitch, + filter_waveform, + frequency_impulse_response, + oscillator_bank, + sinc_impulse_response, +) +from ._rir import ray_tracing, simulate_rir_ism +from .functional import barkscale_fbanks, chroma_filterbank + + +__all__ = [ + "adsr_envelope", + "exp_sigmoid", + "barkscale_fbanks", + "chroma_filterbank", + "extend_pitch", + "filter_waveform", + "frequency_impulse_response", + "oscillator_bank", + "ray_tracing", + "sinc_impulse_response", + "simulate_rir_ism", +] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65300f8d861b52dd2d39187ced0c771e768356ef Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/__init__.cpython-311.pyc 
differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_dsp.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_dsp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..125d658dbe82ea5c7b26c70126da3b60dff8e8d6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_dsp.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_rir.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_rir.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13f89ce786f3b06a880d2d3819f462e412e768fa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_rir.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/functional.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/functional.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b4b503d08c1db58d763db4a17fba592242ab499 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/functional.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_dsp.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_dsp.py new file mode 100644 index 0000000000000000000000000000000000000000..72b1a153f57eaec1b464ad42199cf6f6e331ae26 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_dsp.py @@ -0,0 +1,433 @@ +import warnings +from typing import List, Optional, Union + +import torch + +from torchaudio.functional import fftconvolve + + +def oscillator_bank( + frequencies: torch.Tensor, + amplitudes: 
torch.Tensor, + sample_rate: float, + reduction: str = "sum", + dtype: Optional[torch.dtype] = torch.float64, +) -> torch.Tensor: + """Synthesize waveform from the given instantaneous frequencies and amplitudes. + + .. devices:: CPU CUDA + + .. properties:: Autograd TorchScript + + Note: + The phase information of the output waveform is found by taking the cumulative sum + of the given instantaneous frequencies (``frequencies``). + This incurs roundoff error when the data type does not have enough precision. + Using ``torch.float64`` can work around this. + + The following figure shows the difference between ``torch.float32`` and + ``torch.float64`` when generating a sin wave of constant frequency and amplitude + with sample rate 8000 [Hz]. + Notice that ``torch.float32`` version shows artifacts that are not seen in + ``torch.float64`` version. + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/oscillator_precision.png + + Args: + frequencies (Tensor): Sample-wise oscillator frequencies (Hz). Shape `(..., time, N)`. + amplitudes (Tensor): Sample-wise oscillator amplitude. Shape: `(..., time, N)`. + sample_rate (float): Sample rate + reduction (str): Reduction to perform. + Valid values are ``"sum"``, ``"mean"`` or ``"none"``. Default: ``"sum"`` + dtype (torch.dtype or None, optional): The data type on which cumulative sum operation is performed. + Default: ``torch.float64``. Pass ``None`` to disable the casting. + + Returns: + Tensor: + The resulting waveform. + + If ``reduction`` is ``"none"``, then the shape is + `(..., time, N)`, otherwise the shape is `(..., time)`. + """ + if frequencies.shape != amplitudes.shape: + raise ValueError( + "The shapes of `frequencies` and `amplitudes` must match. " + f"Found: {frequencies.shape} and {amplitudes.shape} respectively." + ) + reductions = ["sum", "mean", "none"] + if reduction not in reductions: + raise ValueError(f"The value of reduction must be either {reductions}. 
Found: {reduction}") + + invalid = torch.abs(frequencies) >= sample_rate / 2 + if torch.any(invalid): + warnings.warn( + "Some frequencies are above nyquist frequency. " + "Setting the corresponding amplitude to zero. " + "This might cause numerically unstable gradient." + ) + amplitudes = torch.where(invalid, 0.0, amplitudes) + + pi2 = 2.0 * torch.pi + freqs = frequencies * pi2 / sample_rate % pi2 + phases = torch.cumsum(freqs, dim=-2, dtype=dtype) + if dtype is not None and freqs.dtype != dtype: + phases = phases.to(freqs.dtype) + + waveform = amplitudes * torch.sin(phases) + if reduction == "sum": + return waveform.sum(-1) + if reduction == "mean": + return waveform.mean(-1) + return waveform + + +def adsr_envelope( + num_frames: int, + *, + attack: float = 0.0, + hold: float = 0.0, + decay: float = 0.0, + sustain: float = 1.0, + release: float = 0.0, + n_decay: int = 2, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, +): + """Generate ADSR Envelope + + .. devices:: CPU CUDA + + Args: + num_frames (int): The number of output frames. + attack (float, optional): + The relative *time* it takes to reach the maximum level from + the start. (Default: ``0.0``) + hold (float, optional): + The relative *time* the maximum level is held before + it starts to decay. (Default: ``0.0``) + decay (float, optional): + The relative *time* it takes to sustain from + the maximum level. (Default: ``0.0``) + sustain (float, optional): The relative *level* at which + the sound should sustain. (Default: ``1.0``) + + .. Note:: + The duration of sustain is derived as `1.0 - (The sum of attack, hold, decay and release)`. + + release (float, optional): The relative *time* it takes for the sound level to + reach zero after the sustain. (Default: ``0.0``) + n_decay (int, optional): The degree of polynomial decay. Default: ``2``. + dtype (torch.dtype, optional): the desired data type of returned tensor. 
+ Default: if ``None``, uses a global default + (see :py:func:`torch.set_default_tensor_type`). + device (torch.device, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :py:func:`torch.set_default_tensor_type`). + device will be the CPU for CPU tensor types and the current CUDA + device for CUDA tensor types. + + Returns: + Tensor: ADSR Envelope. Shape: `(num_frames, )` + + Example + .. image:: https://download.pytorch.org/torchaudio/doc-assets/adsr_examples.png + + """ + if not 0 <= attack <= 1: + raise ValueError(f"The value of `attack` must be within [0, 1]. Found: {attack}") + if not 0 <= decay <= 1: + raise ValueError(f"The value of `decay` must be within [0, 1]. Found: {decay}") + if not 0 <= sustain <= 1: + raise ValueError(f"The value of `sustain` must be within [0, 1]. Found: {sustain}") + if not 0 <= hold <= 1: + raise ValueError(f"The value of `hold` must be within [0, 1]. Found: {hold}") + if not 0 <= release <= 1: + raise ValueError(f"The value of `release` must be within [0, 1]. 
Found: {release}") + if attack + decay + release + hold > 1: + raise ValueError("The sum of `attack`, `hold`, `decay` and `release` must not exceed 1.") + + nframes = num_frames - 1 + num_a = int(nframes * attack) + num_h = int(nframes * hold) + num_d = int(nframes * decay) + num_r = int(nframes * release) + + # Initialize with sustain + out = torch.full((num_frames,), float(sustain), device=device, dtype=dtype) + + # attack + if num_a > 0: + torch.linspace(0.0, 1.0, num_a + 1, out=out[: num_a + 1]) + + # hold + if num_h > 0: + out[num_a : num_a + num_h + 1] = 1.0 + + # decay + if num_d > 0: + # Compute: sustain + (1.0 - sustain) * (linspace[1, 0] ** n_decay) + i = num_a + num_h + decay = out[i : i + num_d + 1] + torch.linspace(1.0, 0.0, num_d + 1, out=decay) + decay **= n_decay + decay *= 1.0 - sustain + decay += sustain + + # sustain is handled by initialization + + # release + if num_r > 0: + torch.linspace(sustain, 0, num_r + 1, out=out[-num_r - 1 :]) + + return out + + +def extend_pitch( + base: torch.Tensor, + pattern: Union[int, List[float], torch.Tensor], +): + """Extend the given time series values with multipliers of them. + + .. devices:: CPU CUDA + + .. properties:: Autograd TorchScript + + Given a series of fundamental frequencies (pitch), this function appends + its harmonic overtones or inharmonic partials. + + Args: + base (torch.Tensor): + Base time series, like fundamental frequencies (Hz). Shape: `(..., time, 1)`. + pattern (int, list of floats or torch.Tensor): + If ``int``, the number of pitch series after the operation. + `pattern - 1` tones are added, so that the resulting Tensor contains + up to `pattern`-th overtones of the given series. + + If list of float or ``torch.Tensor``, it must be one dimensional, + representing the custom multiplier of the fundamental frequency. + + Returns: + Tensor: Oscillator frequencies (Hz). Shape: `(..., time, num_tones)`. 
+ + Example + >>> # fundamental frequency + >>> f0 = torch.linspace(1, 5, 5).unsqueeze(-1) + >>> f0 + tensor([[1.], + [2.], + [3.], + [4.], + [5.]]) + >>> # Add harmonic overtones, up to 3rd. + >>> f = extend_pitch(f0, 3) + >>> f.shape + torch.Size([5, 3]) + >>> f + tensor([[ 1., 2., 3.], + [ 2., 4., 6.], + [ 3., 6., 9.], + [ 4., 8., 12.], + [ 5., 10., 15.]]) + >>> # Add custom (inharmonic) partials. + >>> f = extend_pitch(f0, torch.tensor([1, 2.1, 3.3, 4.5])) + >>> f.shape + torch.Size([5, 4]) + >>> f + tensor([[ 1.0000, 2.1000, 3.3000, 4.5000], + [ 2.0000, 4.2000, 6.6000, 9.0000], + [ 3.0000, 6.3000, 9.9000, 13.5000], + [ 4.0000, 8.4000, 13.2000, 18.0000], + [ 5.0000, 10.5000, 16.5000, 22.5000]]) + """ + if isinstance(pattern, torch.Tensor): + mult = pattern + elif isinstance(pattern, int): + mult = torch.linspace(1.0, float(pattern), pattern, device=base.device, dtype=base.dtype) + else: + mult = torch.tensor(pattern, dtype=base.dtype, device=base.device) + h_freq = base @ mult.unsqueeze(0) + return h_freq + + +def sinc_impulse_response(cutoff: torch.Tensor, window_size: int = 513, high_pass: bool = False): + """Create windowed-sinc impulse response for given cutoff frequencies. + + .. devices:: CPU CUDA + + .. properties:: Autograd TorchScript + + Args: + cutoff (Tensor): Cutoff frequencies for low-pass sinc filter. + + window_size (int, optional): Size of the Hamming window to apply. Must be odd. + (Default: 513) + + high_pass (bool, optional): + If ``True``, convert the resulting filter to high-pass. + Otherwise low-pass filter is returned. Default: ``False``. + + Returns: + Tensor: A series of impulse responses. Shape: `(..., window_size)`. + """ + if window_size % 2 == 0: + raise ValueError(f"`window_size` must be odd. 
Given: {window_size}") + + half = window_size // 2 + device, dtype = cutoff.device, cutoff.dtype + idx = torch.linspace(-half, half, window_size, device=device, dtype=dtype) + + filt = torch.special.sinc(cutoff.unsqueeze(-1) * idx.unsqueeze(0)) + filt = filt * torch.hamming_window(window_size, device=device, dtype=dtype, periodic=False).unsqueeze(0) + filt = filt / filt.sum(dim=-1, keepdim=True).abs() + + # High pass IR is obtained by subtracting low_pass IR from delta function. + # https://courses.engr.illinois.edu/ece401/fa2020/slides/lec10.pdf + if high_pass: + filt = -filt + filt[..., half] = 1.0 + filt[..., half] + return filt + + +def frequency_impulse_response(magnitudes): + """Create filter from desired frequency response + + Args: + magnitudes: The desired frequency responses. Shape: `(..., num_fft_bins)` + + Returns: + Tensor: Impulse response. Shape `(..., 2 * (num_fft_bins - 1))` + """ + if magnitudes.min() < 0.0: + # Negative magnitude does not make sense but allowing so that autograd works + # around 0. + # Should we raise error? + warnings.warn("The input frequency response should not contain negative values.") + ir = torch.fft.fftshift(torch.fft.irfft(magnitudes), dim=-1) + device, dtype = magnitudes.device, magnitudes.dtype + window = torch.hann_window(ir.size(-1), periodic=False, device=device, dtype=dtype).expand_as(ir) + return ir * window + + +def _overlap_and_add(waveform, stride): + num_frames, frame_size = waveform.shape[-2:] + numel = (num_frames - 1) * stride + frame_size + buffer = torch.zeros(waveform.shape[:-2] + (numel,), device=waveform.device, dtype=waveform.dtype) + for i in range(num_frames): + start = i * stride + end = start + frame_size + buffer[..., start:end] += waveform[..., i, :] + return buffer + + +def filter_waveform(waveform: torch.Tensor, kernels: torch.Tensor, delay_compensation: int = -1): + """Applies filters along time axis of the given waveform. 
+ + This function applies the given filters along time axis in the following manner: + + 1. Split the given waveform into chunks. The number of chunks is equal to the number of given filters. + 2. Filter each chunk with corresponding filter. + 3. Place the filtered chunks at the original indices while adding up the overlapping parts. + 4. Crop the resulting waveform so that delay introduced by the filter is removed and its length + matches that of the input waveform. + + The following figure illustrates this. + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/filter_waveform.png + + .. note:: + + If the number of filters is one, then the operation becomes stationary. + i.e. the same filtering is applied across the time axis. + + Args: + waveform (Tensor): Shape `(..., time)`. + kernels (Tensor): Impulse responses. + Valid inputs are 2D tensor with shape `(num_filters, filter_length)` or + `(N+1)`-D tensor with shape `(..., num_filters, filter_length)`, where `N` is + the dimension of waveform. + + In case of 2D input, the same set of filters is used across channels and batches. + Otherwise, different sets of filters are applied. In this case, the shape of + the first `N-1` dimensions of filters must match (or be broadcastable to) that of waveform. + + delay_compensation (int): Control how the waveform is cropped after full convolution. + If the value is zero or positive, it is interpreted as the length of crop at the + beginning of the waveform. The value cannot be larger than the size of filter kernel. + Otherwise the initial crop is ``filter_size // 2``. + When cropping happens, the waveform is also cropped from the end so that the + length of the resulting waveform matches the input waveform. + + Returns: + Tensor: `(..., time)`. + """ + if kernels.ndim not in [2, waveform.ndim + 1]: + raise ValueError( + "`kernels` must be 2 or N+1 dimension where " + f"N is the dimension of waveform. 
Found: {kernels.ndim} (N={waveform.ndim})" + ) + + num_filters, filter_size = kernels.shape[-2:] + num_frames = waveform.size(-1) + + if delay_compensation > filter_size: + raise ValueError( + "When `delay_compenstation` is provided, it cannot be larger than the size of filters." + f"Found: delay_compensation={delay_compensation}, filter_size={filter_size}" + ) + + # Transform waveform's time axis into (num_filters x chunk_length) with optional padding + chunk_length = num_frames // num_filters + if num_frames % num_filters > 0: + chunk_length += 1 + num_pad = chunk_length * num_filters - num_frames + waveform = torch.nn.functional.pad(waveform, [0, num_pad], "constant", 0) + chunked = waveform.unfold(-1, chunk_length, chunk_length) + assert chunked.numel() >= waveform.numel() + + # Broadcast kernels + if waveform.ndim + 1 > kernels.ndim: + expand_shape = waveform.shape[:-1] + kernels.shape + kernels = kernels.expand(expand_shape) + + convolved = fftconvolve(chunked, kernels) + restored = _overlap_and_add(convolved, chunk_length) + + # Trim in a way that the number of samples are same as input, + # and the filter delay is compensated + if delay_compensation >= 0: + start = delay_compensation + else: + start = filter_size // 2 + num_crops = restored.size(-1) - num_frames + end = num_crops - start + result = restored[..., start:-end] + return result + + +def exp_sigmoid( + input: torch.Tensor, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7 +) -> torch.Tensor: + """Exponential Sigmoid pointwise nonlinearity. + Implements the equation: + ``max_value`` * sigmoid(``input``) ** (log(``exponent``)) + ``threshold`` + + The output has a range of [``threshold``, ``max_value``]. + ``exponent`` controls the slope of the output. + + .. devices:: CPU CUDA + + Args: + input (Tensor): Input Tensor + exponent (float, optional): Exponent. 
Controls the slope of the output + max_value (float, optional): Maximum value of the output + threshold (float, optional): Minimum value of the output + + Returns: + Tensor: Exponential Sigmoid output. Shape: same as input + + """ + + return max_value * torch.pow( + torch.nn.functional.sigmoid(input), + torch.log(torch.tensor(exponent, device=input.device, dtype=input.dtype)), + ) + torch.tensor(threshold, device=input.device, dtype=input.dtype) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_rir.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_rir.py new file mode 100644 index 0000000000000000000000000000000000000000..0e67a5494d204182d83cc09166064ea9d4355176 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_rir.py @@ -0,0 +1,379 @@ +import math +from typing import Optional, Tuple, Union + +import torch +import torchaudio +from torch import Tensor + + +def _compute_image_sources( + room: torch.Tensor, + source: torch.Tensor, + max_order: int, + absorption: torch.Tensor, + scatter: Optional[torch.Tensor] = None, +) -> Tuple[Tensor, Tensor]: + """Compute image sources in a shoebox-like room. + + Args: + room (torch.Tensor): The 1D Tensor to determine the room size. The shape is + `(D,)`, where ``D`` is 2 if room is a 2D room, or 3 if room is a 3D room. + source (torch.Tensor): The coordinate of the sound source. Tensor with dimensions + `(D)`. + max_order (int): The maximum number of reflections of the source. + absorption (torch.Tensor): The absorption coefficients of wall materials. + ``absorption`` is a Tensor with dimensions `(num_band, num_wall)`. + The shape options are ``[(1, 4), (1, 6), (7, 4), (7, 6)]``. + ``num_band`` is `1` if the coefficients is the same for all frequencies, or is `7` + if the coefficients are different to different frequencies. `7` refers to the default number + of octave bands. (See note in `simulate_rir_ism` method). 
+ ``num_wall`` is `4` if the room is a 2D room, representing absorption coefficients + of ``"west"``, ``"east"``, ``"south"``, and ``"north"`` walls, respectively. + Or it is `6` if the room is a 3D room, representing absorption coefficients + of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively. + scatter (torch.Tensor): The scattering coefficients of wall materials. + The shape of ``scatter`` must match that of ``absorption``. If ``None``, it is not + used in image source computation. (Default: ``None``) + + Returns: + (torch.Tensor): The coordinates of all image sources within ``max_order`` number of reflections. + Tensor with dimensions `(num_image_source, D)`. + (torch.Tensor): The attenuation of corresponding image sources. Tensor with dimensions + `(num_band, num_image_source)`. + """ + if scatter is None: + tr = torch.sqrt(1 - absorption) + else: + tr = torch.sqrt(1 - absorption) * torch.sqrt(1 - scatter) + + ind = torch.arange(-max_order, max_order + 1, device=source.device) + if room.shape[0] == 2: + XYZ = torch.meshgrid(ind, ind, indexing="ij") + else: + XYZ = torch.meshgrid(ind, ind, ind, indexing="ij") + XYZ = torch.stack([c.reshape((-1,)) for c in XYZ], dim=-1) + XYZ = XYZ[XYZ.abs().sum(dim=-1) <= max_order] + + # compute locations of image sources + d = room[None, :] + s = source[None, :] + img_loc = torch.where(XYZ % 2 == 1, d * (XYZ + 1) - s, d * XYZ + s) + + # attenuation + exp_lo = abs(torch.floor((XYZ / 2))) + exp_hi = abs(torch.floor((XYZ + 1) / 2)) + t_lo = tr[:, ::2].unsqueeze(1).repeat(1, XYZ.shape[0], 1) # (num_band, left walls) + t_hi = tr[:, 1::2].unsqueeze(1).repeat(1, XYZ.shape[0], 1) # (num_band, right walls) + att = torch.prod((t_lo**exp_lo) * (t_hi**exp_hi), dim=-1) # (num_band, num_image_source) + return img_loc, att + + +def _hann(x: torch.Tensor, T: int): + """Compute the Hann window where the values are truncated based on window length. 
+ torch.hann_window can only sample window function at integer points, the method is to sample + continuous window function at non-integer points. + + Args: + x (torch.Tensor): The fractional component of time delay Tensor. + T (torch.Tensor): The window length of sinc function. + + Returns: + (torch.Tensor): The hann window Tensor where values outside + the sinc window (`T`) is set to zero. + """ + y = torch.where( + torch.abs(x) <= T / 2, + 0.5 * (1 + torch.cos(2 * math.pi * x / T)), + x.new_zeros(1), + ) + return y + + +def _frac_delay(delay: torch.Tensor, delay_i: torch.Tensor, delay_filter_length: int): + """Compute fractional delay of impulse response signal. + + Args: + delay (torch.Tensor): The time delay Tensor in samples. + delay_i (torch.Tensor): The integer part of delay. + delay_filter_length (int): The window length for sinc function. + + Returns: + (torch.Tensor): The impulse response Tensor for all image sources. + """ + if delay_filter_length % 2 != 1: + raise ValueError("The filter length must be odd") + + pad = delay_filter_length // 2 + n = torch.arange(-pad, pad + 1, device=delay.device) + delay_i[..., None] + delay = delay[..., None] + + return torch.special.sinc(n - delay) * _hann(n - delay, 2 * pad) + + +def _adjust_coeff(coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor: + """Validates and converts absorption or scattering parameters to a tensor with appropriate shape + + Args: + coeff (float or torch.Tensor): The absorption coefficients of wall materials. + + If the dtype is ``float``, the absorption coefficient is identical for all walls and + all frequencies. + + If ``absorption`` is a 1D Tensor, the shape must be `(2*dim,)`, + where the values represent absorption coefficients of ``"west"``, ``"east"``, + ``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively. + + If ``absorption`` is a 2D Tensor, the shape must be `(7, 2*dim)`, + where 7 represents the number of octave bands. 
+ + Returns: + (torch.Tensor): The expanded coefficient. + The shape is `(1, 6)` for single octave band case, and + `(7, 6)` for multi octave band case. + """ + num_walls = 6 + if isinstance(coeffs, float): + if coeffs < 0: + raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}") + return torch.full((1, num_walls), coeffs) + if isinstance(coeffs, Tensor): + if torch.any(coeffs < 0): + raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}") + if coeffs.ndim == 1: + if coeffs.numel() != num_walls: + raise ValueError( + f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor. " + f"Found the shape {coeffs.shape}." + ) + return coeffs.unsqueeze(0) + if coeffs.ndim == 2: + if coeffs.shape[1] != num_walls: + raise ValueError( + f"The shape of `{name}` must be (NUM_BANDS, {num_walls}) when it " + f"is a 2D Tensor. Found: {coeffs.shape}." + ) + return coeffs + raise TypeError(f"`{name}` must be float or Tensor.") + + +def _validate_inputs( + room: torch.Tensor, + source: torch.Tensor, + mic_array: torch.Tensor, +): + """Validate dimensions of input arguments, and normalize different kinds of absorption into the same dimension. + + Args: + room (torch.Tensor): The size of the room. width, length (and height) + source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(dim,)`. + mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, dim)`. + """ + if not (room.ndim == 1 and room.numel() == 3): + raise ValueError(f"`room` must be a 1D Tensor with 3 elements. Found {room.shape}.") + if not (source.ndim == 1 and source.numel() == 3): + raise ValueError(f"`source` must be 1D Tensor with 3 elements. Found {source.shape}.") + if not (mic_array.ndim == 2 and mic_array.shape[1] == 3): + raise ValueError(f"`mic_array` must be a 2D Tensor with shape (num_channels, 3). 
Found {mic_array.shape}.") + + +def simulate_rir_ism( + room: torch.Tensor, + source: torch.Tensor, + mic_array: torch.Tensor, + max_order: int, + absorption: Union[float, torch.Tensor], + output_length: Optional[int] = None, + delay_filter_length: int = 81, + center_frequency: Optional[torch.Tensor] = None, + sound_speed: float = 343.0, + sample_rate: float = 16000.0, +) -> Tensor: + r"""Compute Room Impulse Response (RIR) based on the *image source method* :cite:`allen1979image`. + The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`. + + .. devices:: CPU + + .. properties:: TorchScript + + Args: + room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents + three dimensions of the room. + source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`. + mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`. + max_order (int): The maximum number of reflections of the source. + absorption (float or torch.Tensor): The *absorption* :cite:`wiki:Absorption_(acoustics)` + coefficients of wall materials for sound energy. + If the dtype is ``float``, the absorption coefficient is identical for all walls and + all frequencies. + If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, where the values represent + absorption coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, + and ``"ceiling"``, respectively. + If ``absorption`` is a 2D Tensor, the shape must be `(7, 6)`, where 7 represents the number of octave bands. + output_length (int or None, optional): The output length of simulated RIR signal. If ``None``, + the length is defined as + + .. math:: + \frac{\text{max\_d} \cdot \text{sample\_rate}}{\text{sound\_speed}} + \text{delay\_filter\_length} + + where ``max_d`` is the maximum distance between image sources and microphones. + delay_filter_length (int, optional): The filter length for computing sinc function. 
(Default: ``81``) + center_frequency (torch.Tensor, optional): The center frequencies of octave bands for multi-band walls. + Only used when ``absorption`` is a 2D Tensor. + sound_speed (float, optional): The speed of sound. (Default: ``343.0``) + sample_rate (float, optional): The sample rate of the generated room impulse response signal. + (Default: ``16000.0``) + + Returns: + (torch.Tensor): The simulated room impulse response waveform. Tensor with dimensions + `(channel, rir_length)`. + + Note: + If ``absorption`` is a 2D Tensor and ``center_frequency`` is set to ``None``, the center frequencies + of octave bands are fixed to ``[125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0]``. + Users need to tune the values of ``absorption`` to the corresponding frequencies. + """ + _validate_inputs(room, source, mic_array) + absorption = _adjust_coeff(absorption, "absorption") + img_location, att = _compute_image_sources(room, source, max_order, absorption) + + # compute distances between image sources and microphones + vec = img_location[:, None, :] - mic_array[None, :, :] + dist = torch.linalg.norm(vec, dim=-1) # (image_source, channel) + + img_src_att = att[..., None] / dist[None, ...] # (band, image_source, channel) + + # separate delays in integer / frac part + delay = dist * sample_rate / sound_speed # distance to delay in samples + delay_i = torch.ceil(delay) # integer part + + # compute the shorts IRs corresponding to each image source + irs = img_src_att[..., None] * _frac_delay(delay, delay_i, delay_filter_length)[None, ...] + + rir_length = int(delay_i.max() + irs.shape[-1]) + rir = torch.ops.torchaudio._simulate_rir(irs, delay_i.type(torch.int32), rir_length) + + # multi-band processing + if absorption.shape[0] > 1: + if center_frequency is None: + center = torch.tensor( + [125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0], dtype=room.dtype, device=room.device + ) + else: + center = center_frequency + # n_fft is set to 512 by default. 
def ray_tracing(
    room: torch.Tensor,
    source: torch.Tensor,
    mic_array: torch.Tensor,
    num_rays: int,
    absorption: Union[float, torch.Tensor] = 0.0,
    scattering: Union[float, torch.Tensor] = 0.0,
    mic_radius: float = 0.5,
    sound_speed: float = 343.0,
    energy_thres: float = 1e-7,
    time_thres: float = 10.0,
    hist_bin_size: float = 0.004,
) -> torch.Tensor:
    r"""Compute energy histogram via ray tracing.

    The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.

    ``num_rays`` rays are cast uniformly in all directions from the source.
    Each time a ray hits a wall it is reflected and loses part of its energy to
    absorption; according to the ``scattering`` coefficient part of it is also
    sent directly to the microphone(s). Whenever a ray passes close to a
    microphone, its current energy is accumulated in the histogram bin of the
    corresponding time slot.

    .. devices:: CPU

    .. properties:: TorchScript

    Args:
        room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
            three dimensions of the room.
        source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
        mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
        num_rays (int): Number of rays cast from the source.
        absorption (float or torch.Tensor, optional): The absorption coefficients of wall materials.
            (Default: ``0.0``).
            A ``float`` applies the same coefficient to all walls and all frequencies.
            A 1D Tensor must have shape `(6,)`, giving the coefficients of ``"west"``, ``"east"``,
            ``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively.
            A 2D Tensor must have shape `(num_bands, 6)`, where ``num_bands`` is the number of
            frequency bands (usually 7).
        scattering (float or torch.Tensor, optional): The scattering coefficients of wall materials.
            (Default: ``0.0``) Same shape and type conventions as ``absorption``.
        mic_radius (float, optional): The radius of the microphone in meters. (Default: 0.5)
        sound_speed (float, optional): The speed of sound in meters per second. (Default: ``343.0``)
        energy_thres (float, optional): The energy level below which we stop tracing a ray.
            (Default: ``1e-7``) The initial energy of each ray is ``2 / num_rays``.
        time_thres (float, optional): The maximal duration for which rays are traced.
            (Unit: seconds) (Default: 10.0)
        hist_bin_size (float, optional): The size of each bin in the output histogram.
            (Unit: seconds) (Default: 0.004)

    Returns:
        (torch.Tensor): The 3D histogram(s) where the energy of the traced ray is recorded.
            Each bin corresponds to a given time slot.
            The shape is `(channel, num_bands, num_bins)`, where
            ``num_bins = ceil(time_thres / hist_bin_size)``.
            If both ``absorption`` and ``scattering`` are floats, then ``num_bands == 1``.

    Raises:
        ValueError: If ``time_thres`` is smaller than ``hist_bin_size``, if the dtypes of
            ``room``, ``source`` and ``mic_array`` differ, or if ``absorption`` and
            ``scattering`` cannot be brought to the same shape.
    """
    if time_thres < hist_bin_size:
        raise ValueError(
            "`time_thres` must be greater than `hist_bin_size`. "
            f"Found: hist_bin_size={hist_bin_size}, time_thres={time_thres}."
        )

    dtypes_agree = room.dtype == source.dtype and source.dtype == mic_array.dtype
    if not dtypes_agree:
        raise ValueError(
            "dtype of `room`, `source` and `mic_array` must match. "
            f"Found: `room` ({room.dtype}), `source` ({source.dtype}) and "
            f"`mic_array` ({mic_array.dtype})"
        )

    _validate_inputs(room, source, mic_array)

    # Normalise both coefficient arguments to tensors in the geometry's dtype.
    absorb_coeff = _adjust_coeff(absorption, "absorption").to(room.dtype)
    scatter_coeff = _adjust_coeff(scattering, "scattering").to(room.dtype)

    # A single-band coefficient is stretched to the band count of the other one.
    if absorb_coeff.shape[0] == 1 and scatter_coeff.shape[0] > 1:
        absorb_coeff = absorb_coeff.expand(scatter_coeff.shape)
    elif scatter_coeff.shape[0] == 1 and absorb_coeff.shape[0] > 1:
        scatter_coeff = scatter_coeff.expand(absorb_coeff.shape)

    if absorb_coeff.shape != scatter_coeff.shape:
        raise ValueError(
            "`absorption` and `scattering` must be broadcastable to the same number of bands and walls. "
            f"Inferred shapes absorption={absorb_coeff.shape} and scattering={scatter_coeff.shape}"
        )

    return torch.ops.torchaudio.ray_tracing(
        room,
        source,
        mic_array,
        num_rays,
        absorb_coeff,
        scatter_coeff,
        mic_radius,
        sound_speed,
        energy_thres,
        time_thres,
        hist_bin_size,
    )
(Default: ``traunmuller``) + + Returns: + barks (float): Frequency in Barks + """ + + if bark_scale not in ["schroeder", "traunmuller", "wang"]: + raise ValueError('bark_scale should be one of "schroeder", "traunmuller" or "wang".') + + if bark_scale == "wang": + return 6.0 * math.asinh(freqs / 600.0) + elif bark_scale == "schroeder": + return 7.0 * math.asinh(freqs / 650.0) + # Traunmuller Bark scale + barks = ((26.81 * freqs) / (1960.0 + freqs)) - 0.53 + # Bark value correction + if barks < 2: + barks += 0.15 * (2 - barks) + elif barks > 20.1: + barks += 0.22 * (barks - 20.1) + + return barks + + +def _bark_to_hz(barks: torch.Tensor, bark_scale: str = "traunmuller") -> torch.Tensor: + """Convert bark bin numbers to frequencies. + + Args: + barks (torch.Tensor): Bark frequencies + bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``) + + Returns: + freqs (torch.Tensor): Barks converted in Hz + """ + + if bark_scale not in ["schroeder", "traunmuller", "wang"]: + raise ValueError('bark_scale should be one of "traunmuller", "schroeder" or "wang".') + + if bark_scale == "wang": + return 600.0 * torch.sinh(barks / 6.0) + elif bark_scale == "schroeder": + return 650.0 * torch.sinh(barks / 7.0) + # Bark value correction + if any(barks < 2): + idx = barks < 2 + barks[idx] = (barks[idx] - 0.3) / 0.85 + elif any(barks > 20.1): + idx = barks > 20.1 + barks[idx] = (barks[idx] + 4.422) / 1.22 + + # Traunmuller Bark scale + freqs = 1960 * ((barks + 0.53) / (26.28 - barks)) + + return freqs + + +def _hz_to_octs(freqs, tuning=0.0, bins_per_octave=12): + a440 = 440.0 * 2.0 ** (tuning / bins_per_octave) + return torch.log2(freqs / (a440 / 16)) + + +def barkscale_fbanks( + n_freqs: int, + f_min: float, + f_max: float, + n_barks: int, + sample_rate: int, + bark_scale: str = "traunmuller", +) -> torch.Tensor: + r"""Create a frequency bin conversion matrix. + + .. devices:: CPU + + .. properties:: TorchScript + + .. 
def chroma_filterbank(
    sample_rate: int,
    n_freqs: int,
    n_chroma: int,
    *,
    tuning: float = 0.0,
    ctroct: float = 5.0,
    octwidth: Optional[float] = 2.0,
    norm: int = 2,
    base_c: bool = True,
):
    """Create a frequency-to-chroma conversion matrix. Implementation adapted from librosa.

    Args:
        sample_rate (int): Sample rate.
        n_freqs (int): Number of input frequencies.
        n_chroma (int): Number of output chroma.
        tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0)
        ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves.
            (Default: 5.0)
        octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by,
            in octaves. If ``None``, then disable weighting altogether. (Default: 2.0)
        norm (int, optional): order of norm to normalize filter bank by. (Default: 2)
        base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A.
            (Default: True)

    Returns:
        torch.Tensor: Chroma filter bank, with shape `(n_freqs, n_chroma)`.
    """
    # Leave out the first (0 Hz) bin before converting to octaves; a stand-in
    # for it is prepended below.
    hz = torch.linspace(0, sample_rate // 2, n_freqs)[1:]
    chroma_pos = n_chroma * _hz_to_octs(hz, bins_per_octave=n_chroma, tuning=tuning)

    # Stand-in for the dropped bin: 1.5 octaves (1.5 * n_chroma bins) below the first kept bin.
    chroma_pos = torch.cat((torch.tensor([chroma_pos[0] - 1.5 * n_chroma]), chroma_pos))

    # Width of every bin = distance to the next bin, floored at 1; the last bin gets width 1.
    spacing = torch.maximum(chroma_pos[1:] - chroma_pos[:-1], torch.tensor(1.0))
    widths = torch.cat((spacing, torch.tensor([1])))

    half_span = round(n_chroma / 2)

    # (n_freqs, n_chroma): distance of every frequency bin to every chroma,
    # wrapped into the range [-n_chroma/2, n_chroma/2 - 1].
    dist = chroma_pos.unsqueeze(1) - torch.arange(0, n_chroma)
    dist = torch.remainder(dist + half_span, n_chroma) - half_span

    # Gaussian bump per frequency bin; the (n_freqs, 1) widths broadcast over the chroma axis.
    bank = torch.exp(-0.5 * (2 * dist / widths.unsqueeze(1)) ** 2)
    bank = torch.nn.functional.normalize(bank, p=norm, dim=1)

    if octwidth is not None:
        # Gaussian dominance window centred on octave ``ctroct``, one weight per frequency bin.
        dominance = torch.exp(-0.5 * (((chroma_pos.unsqueeze(1) / n_chroma - ctroct) / octwidth) ** 2))
        bank = bank * dominance

    if base_c:
        bank = torch.roll(bank, -3 * (n_chroma // 12), dims=1)

    return bank
b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__init__.py @@ -0,0 +1,36 @@ +from ._conformer_wav2vec2 import ( + conformer_wav2vec2_base, + conformer_wav2vec2_model, + conformer_wav2vec2_pretrain_base, + conformer_wav2vec2_pretrain_large, + conformer_wav2vec2_pretrain_model, + ConformerWav2Vec2PretrainModel, +) +from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model +from .conv_emformer import ConvEmformer +from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder +from .rnnt import conformer_rnnt_base, conformer_rnnt_biasing, conformer_rnnt_biasing_base, conformer_rnnt_model +from .rnnt_decoder import Hypothesis, RNNTBeamSearchBiasing + +__all__ = [ + "conformer_rnnt_base", + "conformer_rnnt_model", + "conformer_rnnt_biasing", + "conformer_rnnt_biasing_base", + "ConvEmformer", + "conformer_wav2vec2_model", + "conformer_wav2vec2_base", + "conformer_wav2vec2_pretrain_model", + "conformer_wav2vec2_pretrain_base", + "conformer_wav2vec2_pretrain_large", + "ConformerWav2Vec2PretrainModel", + "emformer_hubert_base", + "emformer_hubert_model", + "Hypothesis", + "RNNTBeamSearchBiasing", + "HiFiGANVocoder", + "hifigan_vocoder_v1", + "hifigan_vocoder_v2", + "hifigan_vocoder_v3", + "hifigan_vocoder", +] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7acdda6625fe88b22efe52575a973b4bbfd74af1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_conformer_wav2vec2.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_conformer_wav2vec2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0808e7bf7c6ad95b477f64eea4aba81e37127e4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_conformer_wav2vec2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_emformer_hubert.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_emformer_hubert.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20339f14ce258e0dd34594c66991a18e82f3596b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_emformer_hubert.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/conv_emformer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/conv_emformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba3525403b941b60d350d1282c4c3912551de4cd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/conv_emformer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/hifi_gan.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/hifi_gan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cd85079e1646c793961f654cf56efabd0c52ab0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/hifi_gan.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/rnnt.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/rnnt.cpython-311.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..9cbe31486f849924b2363a17f1ac27d3a6c20c6f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/rnnt.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/rnnt_decoder.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/rnnt_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7431e4c97952bd199a22e459ffa1445830a80a0c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/rnnt_decoder.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/_conformer_wav2vec2.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/_conformer_wav2vec2.py new file mode 100644 index 0000000000000000000000000000000000000000..b1ea86a81c831a8f346dd1290e221ece67be4734 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/_conformer_wav2vec2.py @@ -0,0 +1,794 @@ +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn, Tensor +from torch.nn import Module, ModuleList +from torchaudio.models import Wav2Vec2Model +from torchaudio.models.conformer import ConformerLayer +from torchaudio.models.rnnt import _TimeReduction +from torchaudio.models.wav2vec2 import components + + +def _buffered_arange(max) -> Tensor: + """Compute arange using a buffered tensor across function calls. + Produces same result as torch.arange(end=max). + + Args: + max (int): Ending value for arange. 
+ """ + if not hasattr(_buffered_arange, "buf"): + _buffered_arange.buf = torch.LongTensor() + if max > _buffered_arange.buf.numel(): + _buffered_arange.buf.resize_(max) + torch.arange(max, out=_buffered_arange.buf) + return _buffered_arange.buf[:max] + + +def _sample_negatives(input: Tensor, num_negatives: int, cross_sample_negatives: int) -> Tuple[Tensor, Tensor]: + """Sample negative examples from masked input. + + Args: + input (Tensor): Tensor of dimension `(batch, frame, dim)`. + num_negatives (int): Number of negative examples to sample. + cross_sample_negatives (int): Number of negative examples to cross sample. + + Returns: + (Tensor, Tensor): + Tensor + The negative samples. + Tensor + The indices of the negative samples. + """ + if num_negatives == 0 and cross_sample_negatives == 0: + return ( + torch.zeros(0).to(input.device, input.dtype), + torch.zeros(0).to(input.device, input.dtype), + ) + + B, T, D = input.shape + input = input.view(-1, D) + + cross_high = T * B + high = T + + assert high > 1 + + if num_negatives > 0: + tszs = _buffered_arange(T).unsqueeze(-1).expand(-1, num_negatives).flatten() + + neg_idxs = torch.randint(low=0, high=high - 1, size=(B, num_negatives * T)) + neg_idxs[neg_idxs >= tszs] += 1 + + if cross_sample_negatives > 0: + tszs = _buffered_arange(T).unsqueeze(-1).expand(-1, cross_sample_negatives).flatten() + + cross_neg_idxs = torch.randint(low=0, high=cross_high - 1, size=(B, cross_sample_negatives * T)) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if num_negatives > 0: + neg_idxs = neg_idxs + (torch.arange(B).unsqueeze(1) * high) + else: + neg_idxs = cross_neg_idxs + + if cross_sample_negatives > 0 and num_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = input[neg_idxs.view(-1)] + negs = negs.view(B, T, num_negatives + cross_sample_negatives, D).permute(2, 0, 1, 3) # NxBxCxT + + return negs, neg_idxs + + +class NegativeSampler(Module): + r"""Applies preprocessing to input and then 
computes negative sampling. + + Args: + preprocessor (nn.Module): Transforms input tensor prior to negative sampling. + num_negatives (int): Number of negative examples to sample. + cross_sample_negatives (int): Number of negative examples to cross sample. + """ + + def __init__( + self, + preprocessor: Module, + num_negatives: int, + cross_sample_negatives: int, + ): + super().__init__() + self.preprocessor = preprocessor + self.num_negatives = num_negatives + self.cross_sample_negatives = cross_sample_negatives + + def forward(self, input: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """ + Args: + input (Tensor): Tensor of dimension `(B, T, D)`. + + Returns: + (Tensor, Tensor, Optional[Tensor]): + Tensor + The input tensor after preprocessing, prior to being sampled. + Tensor + The negative samples. + Tensor + The indices of the negative samples. + """ + preprocessed = self.preprocessor(input) + negs, neg_idxs = _sample_negatives(preprocessed, self.num_negatives, self.cross_sample_negatives) + return preprocessed, negs, neg_idxs + + +class FeatureEncoder(Module): + """Feature Encoder class, consisting of time reduction and linear layer. + + Args: + stride (int): Number of frames to merge for the output frame. + input_dim (int): Input dimension of the tensor. + output_dim (int): Output dimension of the tensor. + """ + + def __init__(self, input_dim: int, output_dim: int, stride: int): + super().__init__() + self.time_reduction_layer = _TimeReduction(stride=stride) + self.linear_layer = nn.Linear(input_dim * stride, output_dim) + + def forward( + self, + x: Tensor, + lengths: Optional[Tensor], + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): Feature Tensor representing log Mel Spectrogram output. shape ``(B, T, D)``. + lengths (Tensor or None): + Valid length of each input sample. shape: ``(B, )``. + + Returns: + (Tensor, Optional[Tensor]): + Tensor: output sequence after undergoing time reduction and linear projection. 
class ConformerEncoder(Module):
    """Conformer Encoder class, consisting of feature projection and conformer modules.

    Args:
        feature_projection (nn.Module):
            Projects feature to encoder dimension.
        conformer (nn.ModuleList)
            List of Conformer layers.
    """

    def __init__(
        self,
        feature_projection: Module,
        conformer: ModuleList,
    ):
        super().__init__()
        self.feature_projection = feature_projection
        self.conformer = conformer

    def _preprocess(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Project the features and, when ``lengths`` is given, build the padding mask."""
        projected = self.feature_projection(features)
        padding_mask: Optional[Tensor] = None
        if lengths is not None:
            padding_mask = components._get_padding_mask(projected, lengths)
        return projected, padding_mask

    def _get_intermediate_outputs(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        """Run the conformer layers and collect the output of each one.

        Args:
            x (Tensor): Projected features; the first two dimensions are swapped
                before the layers are applied and swapped back for every collected output.
            mask (Tensor or None, optional): Padding mask handed to every layer.
            num_layers (int or None, optional): If given, stop after this many layers.

        Returns:
            List[Tensor]: One output per executed layer, in layer order.

        Raises:
            ValueError: If ``num_layers`` is given but not within ``[1, len(self.conformer)]``.
        """
        if num_layers is not None and not (0 < num_layers <= len(self.conformer)):
            raise ValueError(f"`num_layers` must be between [1, {len(self.conformer)}]")

        outputs: List[Tensor] = []

        hidden = x.transpose(0, 1)
        for block in self.conformer:
            hidden = block(hidden, mask)
            outputs.append(hidden.transpose(0, 1))
            if num_layers is not None and len(outputs) >= num_layers:
                break
        return outputs

    def forward(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            features (Tensor): Tensor of features of shape ``(B, T, D)``.
            lengths (Tensor or None, optional): Valid length of each input sample. shape: ``(B, )``.

        Returns:
            Tensor: result after applying conformer encoder to features.
        """
        projected, padding_mask = self._preprocess(features, lengths)
        # The layers are applied with the first two dimensions swapped.
        hidden = projected.transpose(0, 1)
        for block in self.conformer:
            hidden = block(hidden, padding_mask)
        return hidden.transpose(0, 1)

    def extract_features(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        """Returns the list of outputs from the intermediate layers of conformer block in the encoder.

        Args:
            features (Tensor): Tensor of features of shape ``(B, T, D)``.
            lengths (Tensor or None, optional): Valid length of each input sample. shape: ``(B, )``.
            num_layers (int or None, optional): If given, only the first ``num_layers``
                layers are run and returned. Must be within ``[1, number of layers]``.

        Returns:
            List[Tensor]:
                Features from requested layers. Each Tensor is of shape: `(batch, time frame, feature dimension)`.
        """
        projected, padding_mask = self._preprocess(features, lengths)
        return self._get_intermediate_outputs(projected, mask=padding_mask, num_layers=num_layers)
def _get_conformer_feature_extractor(
    input_dim: int,
    output_dim: int,
    stride: int,
) -> FeatureEncoder:
    """Construct Feature Extractor

    Args:
        input_dim (int): Input dimension of features.
        output_dim (int): Output dimension after feature extraction.
        stride (int): Stride used in Time Reduction layer of feature extractor.

    Returns:
        FeatureEncoder: The resulting feature extraction.
    """
    return FeatureEncoder(input_dim, output_dim, stride)


def _get_conformer_encoder(
    in_features: int,
    embed_dim: int,
    dropout_input: float,
    num_layers: int,
    num_heads: int,
    ff_interm_features: int,
    dropout: float,
    depthwise_conv_kernel_size: Union[int, List[int]],
    convolution_first: bool,
    use_group_norm: bool,
) -> ConformerEncoder:
    """Construct Conformer Encoder

    Args:
        in_features (int): The number of input features.
        embed_dim (int): The dimension of the embedding in the feature projection.
        dropout_input (float): The dropout probability applied after the input feature
            is projected to ``embed_dim``.
        num_layers (int): Number of Conformer layers in the encoder.
        num_heads (int): Number of heads in each Conformer layer.
        ff_interm_features (int): Hidden layer dimension of the feedforward network in
            each Conformer layer.
        dropout (float): Dropout probability in each Conformer layer.
        depthwise_conv_kernel_size (int or List[int]): List of kernel sizes corresponding
            to each of the Conformer layers. If int is provided, all layers will have the
            same kernel size.
        convolution_first (bool): Whether to apply the convolution module ahead of the
            attention module in each Conformer layer.
        use_group_norm (bool): Whether to use ``GroupNorm`` rather than ``BatchNorm1d`` in
            the convolution module in each Conformer layer.

    Returns:
        ConformerEncoder:
            The resulting conformer encoder module.

    Raises:
        AssertionError: If a list of kernel sizes is given whose length differs
            from ``num_layers``.
    """
    feature_projection = components.FeatureProjection(in_features, embed_dim, dropout_input)

    # A single kernel size is shared by every layer.
    if isinstance(depthwise_conv_kernel_size, int):
        depthwise_conv_kernel_size = [depthwise_conv_kernel_size] * num_layers

    assert len(depthwise_conv_kernel_size) == num_layers, (
        "`depthwise_conv_kernel_size` must provide one kernel size per layer. "
        f"Found: {len(depthwise_conv_kernel_size)} kernel sizes for {num_layers} layers."
    )

    conformer_layers = [
        ConformerLayer(
            input_dim=embed_dim,
            ffn_dim=ff_interm_features,
            num_attention_heads=num_heads,
            depthwise_conv_kernel_size=depthwise_conv_kernel_size[layer_idx],
            dropout=dropout,
            use_group_norm=use_group_norm,
            convolution_first=convolution_first,
        )
        for layer_idx in range(num_layers)
    ]

    return ConformerEncoder(feature_projection, ModuleList(conformer_layers))


def _get_conformer_negativer_sampler(
    input_dim: int,
    output_dim: int,
    num_negatives: int,
    cross_sample_negatives: int,
) -> NegativeSampler:
    """Build custom NegativeSampler module, including linear layer and negative sampling.

    Args:
        input_dim (int): Dimension of input after feature extraction.
        output_dim (int): Dimension of embedding for use in negative sampling. Same as the
            embedding in the feature projection.
        num_negatives (int): Number of negatives to sample.
        cross_sample_negatives (int): Number of cross sampled negatives.

    Returns:
        NegativeSampler:
            The resulting negative sampler module.
    """
    # The linear layer is the sampler's preprocessor: it maps the features to
    # ``output_dim`` before negatives are drawn.
    preprocessor = nn.Linear(input_dim, output_dim)
    return NegativeSampler(preprocessor, num_negatives, cross_sample_negatives)
def conformer_wav2vec2_base(
    extractor_input_dim: int = 64,
    extractor_output_dim: int = 256,
    encoder_projection_dropout: float = 0.0,
) -> Wav2Vec2Model:
    """
    Build Conformer Wav2Vec2 Model with "small" architecture from
    *Conformer-Based Self-Supervised Learning for Non-Speech Audio Tasks* :cite:`9746490`

    Args:
        extractor_input_dim (int, optional): Input dimension of feature extractor. (Default: 64)
        extractor_output_dim (int, optional): Output dimension of feature extractor. (Default: 256)
        encoder_projection_dropout (float, optional):
            Dropout probability applied after feature projection. (Default: 0.0)

    Returns:
        Wav2Vec2Model:
            The resulting wav2vec2 model with a conformer encoder and ``base`` configuration.
    """
    # 12 layers: the first uses a depthwise kernel of 31, the other 11 use 15.
    kernel_sizes = [15] * 12
    kernel_sizes[0] = 31

    base_config = dict(
        extractor_input_dim=extractor_input_dim,
        extractor_output_dim=extractor_output_dim,
        extractor_stride=4,
        encoder_embed_dim=256,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_num_layers=12,
        encoder_num_heads=8,
        encoder_ff_interm_features=1024,
        encoder_depthwise_conv_kernel_size=kernel_sizes,
        encoder_dropout=0.1,
        encoder_convolution_first=True,
        encoder_use_group_norm=True,
    )
    return conformer_wav2vec2_model(**base_config)
+ encoder_ff_interm_features (int): + Hidden layer dimension of the feedforward network in each Conformer layer. + encoder_depthwise_conv_kernel_size (int or List[int]): + List of kernel sizes corresponding to each of the Conformer layers. + If int is provided, all layers will have the same kernel size. + encoder_dropout (float): + Dropout probability in each Conformer layer. + encoder_convolution_first (bool): + Whether to apply the convolution module ahead of the attention module + in each Conformer layer. + encoder_use_group_norm (bool): + Whether to use ``GroupNorm`` rather than ``BatchNorm1d`` in the convolution + module in each Conformer layer. + mask_prob (float): + Probability for each token to be chosen as start of the span to be masked. + mask_selection (str) + How to choose the mask length. Options: [``static``, ``uniform``, ``normal``, ``poisson``]. + mask_other (float): + Secondary mask argument (used for more complex distributions). + mask_length (int): + The lengths of the mask. + no_mask_overlap (bool): + Whether to allow masks to overlap. + mask_min_space (int): + Minimum space between spans (if no overlap is enabled). + mask_channel_prob: (float): + The probability of replacing a feature with 0. + mask_channel_selection (str): + How to choose the mask length for channel masking. + Options: [``static``, ``uniform``, ``normal``, ``poisson``]. + mask_channel_other (float): + Secondary mask argument for channel masking (used for more complex distributions). + mask_channel_length (int): + Minimum space between spans (if no overlap is enabled) for channel masking. + no_mask_channel_overlap (bool): + Whether to allow channel masks to overlap. + mask_channel_min_space (int): + Minimum space between spans for channel masking (if no overlap is enabled). + num_negatives (int): + Number of negatives to sample. + cross_sample_negatives (int): + Number of cross sampled negatives. + + Returns: + ConformerWav2Vec2PretrainModel: + The resulting model. 
+ """ + wav2vec2 = conformer_wav2vec2_model( + extractor_input_dim, + extractor_output_dim, + extractor_stride, + encoder_embed_dim, + encoder_projection_dropout, + encoder_num_layers, + encoder_num_heads, + encoder_ff_interm_features, + encoder_depthwise_conv_kernel_size, + encoder_dropout, + encoder_convolution_first, + encoder_use_group_norm, + ) + + mask_generator = components.MaskGenerator( + extractor_output_dim, + mask_prob, + mask_selection, + mask_other, + mask_length, + no_mask_overlap, + mask_min_space, + mask_channel_prob, + mask_channel_selection, + mask_channel_other, + mask_channel_length, + no_mask_channel_overlap, + mask_channel_min_space, + ) + + negative_sampler = _get_conformer_negativer_sampler( + extractor_output_dim, + encoder_embed_dim, + num_negatives, + cross_sample_negatives, + ) + + return ConformerWav2Vec2PretrainModel( + wav2vec2=wav2vec2, + mask_generator=mask_generator, + negative_sampler=negative_sampler, + ) + + +def conformer_wav2vec2_pretrain_base( + extractor_input_dim: int = 64, + extractor_output_dim: int = 256, + encoder_projection_dropout: float = 0.0, + mask_prob: float = 0.3, + mask_length: int = 3, + num_negatives: int = 100, + cross_sample_negatives: int = 0, +) -> ConformerWav2Vec2PretrainModel: + """Build Conformer Wav2Vec2 Model for pre-training with "small" architecture from + *Conformer-Based Self-Supervised Learning for Non-Speech Audio Tasks* :cite:`9746490` + + Args: + extractor_input_dim (int, optional): Input dimension of the features. (Default: 64) + extractor_output_dim (int, optional): Output dimension after feature extraction. (Default: 256) + encoder_projection_dropout (float, optional): + The dropout probability applied after the input feature is projected to + ``embed_dim``. (Default: 0.0) + mask_prob (float, optional): + Probability for each token to be chosen as start of the span to be masked. (Default: 0.3) + mask_length (int, optional): + The lengths of the mask. 
(Default: 3) + num_negatives (int, optional): + Number of sampled negatives. (Default: 0) + cross_sample_negatives (int, optional): + Number of cross sampled negatives. (Default: 0) + + Returns: + ConformerWav2Vec2PretrainModel: + The resulting model. + """ + return conformer_wav2vec2_pretrain_model( + extractor_input_dim=extractor_input_dim, + extractor_output_dim=extractor_output_dim, + extractor_stride=4, + encoder_embed_dim=256, + encoder_projection_dropout=encoder_projection_dropout, + encoder_num_layers=12, + encoder_num_heads=8, + encoder_ff_interm_features=1024, + encoder_depthwise_conv_kernel_size=[31] + [15] * 11, + encoder_dropout=0.1, + encoder_convolution_first=True, + encoder_use_group_norm=True, + mask_prob=mask_prob, + mask_selection="static", + mask_other=0.0, + mask_length=mask_length, + no_mask_overlap=False, + mask_min_space=0, + mask_channel_prob=0, + mask_channel_selection="static", + mask_channel_other=0, + mask_channel_length=10, + no_mask_channel_overlap=False, + mask_channel_min_space=1, + num_negatives=num_negatives, + cross_sample_negatives=cross_sample_negatives, + ) + + +def conformer_wav2vec2_pretrain_large( + extractor_input_dim: int = 64, + extractor_output_dim: int = 256, + encoder_projection_dropout: float = 0.0, + mask_prob: float = 0.3, + mask_length: int = 3, + num_negatives: int = 100, + cross_sample_negatives: int = 0, +) -> ConformerWav2Vec2PretrainModel: + """Build Conformer Wav2Vec2 Model for pre-training with "large" architecture from + *Conformer-Based Slef-Supervised Learning for Non-Speech Audio Tasks* :cite:`9746490` + + Args: + extractor_input_dim (int, optional): Input dimension of the features. (Default: 64) + extractor_output_dim (int, optional): Output dimension after feature extraction. (Default: 256) + encoder_projection_dropout (float, optional): + The dropout probability applied after the input feature is projected to + ``embed_dim``. 
(Default: 0.0) + mask_prob (float, optional): + Probability for each token to be chosen as start of the span to be masked. (Default: 0.3) + mask_length (int, optional): + The lengths of the mask. (Default: 3) + num_negatives (int, optional): + Number of sampled negatives. (Default: 0) + cross_sample_negatives (int, optional): + Number of cross sampled negatives. (Default: 0) + + Returns: + ConformerWav2Vec2PretrainModel: + The resulting model. + """ + return conformer_wav2vec2_pretrain_model( + extractor_input_dim=extractor_input_dim, + extractor_output_dim=extractor_output_dim, + extractor_stride=4, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_ff_interm_features=1024, + encoder_depthwise_conv_kernel_size=[31] + [15] * 11, + encoder_dropout=0.1, + encoder_convolution_first=True, + encoder_use_group_norm=True, + mask_prob=mask_prob, + mask_selection="static", + mask_other=0.0, + mask_length=mask_length, + no_mask_overlap=False, + mask_min_space=0, + mask_channel_prob=0, + mask_channel_selection="static", + mask_channel_other=0, + mask_channel_length=10, + no_mask_channel_overlap=False, + mask_channel_min_space=1, + num_negatives=num_negatives, + cross_sample_negatives=cross_sample_negatives, + ) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/_emformer_hubert.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/_emformer_hubert.py new file mode 100644 index 0000000000000000000000000000000000000000..872c6ce90191a841d7a1387bf17a1803b1689b83 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/_emformer_hubert.py @@ -0,0 +1,333 @@ +from typing import List, Optional, Tuple + +import torch +from torchaudio.models import Wav2Vec2Model +from torchaudio.models.emformer import Emformer +from torchaudio.models.rnnt import _TimeReduction + + +class FeatureEncoder(torch.nn.Module): + """Extract features from 
log-mel spectrogram input. Consists of linear layer and time reduction layer. + + Args: + input_dim (int): The feature dimension of log-mel spectrogram feature. + output_dim (int): The feature dimension after linear layer. + use_bias (bool): If ``True``, enable bias parameter in the linear layer. + stride (int): Number of frames to merge for the output frame. + """ + + def __init__(self, input_dim: int, output_dim: int, use_bias: bool, stride: int): + super().__init__() + self.linear = torch.nn.Linear(input_dim, output_dim, bias=use_bias) + self.time_reduction = _TimeReduction(stride) + + def forward( + self, input: torch.Tensor, lengths: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + input (torch.Tensor): The log-mel spectrogram input. + Tensor with dimensions `(batch, time, input_dim)`. + lengths (torch.Tensor or None): Valid length of each input sample. + Tensor with dimension `(batch, )`. + + Returns: + (torch.Tensor, torch.Tensor or None): + torch.Tensor + Returned feature Tensor after linear layer and time reduction layer. + Tensor with dimensions `(batch, time // stride, output_dim)`. + torch.Tensor or None + The reduced lengths Tensor. + """ + output = self.linear(input) + if lengths is None: + B, T, _ = input.shape + dummy_lengths = torch.full((B,), T) + output, _ = self.time_reduction(output, dummy_lengths) + else: + output, lengths = self.time_reduction(output, lengths) + return output, lengths + + +class EmformerEncoder(torch.nn.Module): + """Emformer Encoder class for HuBERT pre-training. Consists of emformer module, + linear layer and layer normalization layer. + + Args: + emformer (torch.nn.Module): + :py:class:`torchaudio.models.Emformer` module that consists of a list of emformer layers. + output_linear (torch.nn.Module): + Linear layer after emformer module. + layer_norm (torch.nn.Module): + Apply layer normalization to the output. 
+ """ + + def __init__( + self, + emformer: torch.nn.Module, + output_linear: torch.nn.Module, + layer_norm: torch.nn.Module, + ): + super().__init__() + self.emformer = emformer + self.output_linear = output_linear + self.layer_norm = layer_norm + + def forward( + self, + input: torch.Tensor, + lengths: Optional[torch.Tensor], + ) -> torch.Tensor: + """ + Args: + input (torch.Tensor): The input feature for emformer encoder. + Tensor with dimensions `(batch, time, feature_dim)`. + lengths (torch.Tensor or None): Valid length of each input sample. + Tensor with dimension `(batch, )`. + + Returns: + torch.Tensor: The feature Tensor after emformer encoder. + """ + if lengths is None: + B, T, _ = input.shape + dummy_lengths = torch.full((B,), T) + output, _ = self.emformer(input, dummy_lengths) + else: + output, lengths = self.emformer(input, lengths) + output = self.output_linear(output) + output = self.layer_norm(output) + return output + + def extract_features( + self, + input: torch.Tensor, + lengths: Optional[torch.Tensor], + num_layers: Optional[int] = None, + ) -> List[torch.Tensor]: + """Extract output Tensors of the emformer layers. + + Args: + input (torch.Tensor): The input feature for emformer encoder. + Tensor with dimensions `(batch, time, feature_dim)`. + lengths (torch.Tensor or None): Valid length of each input sample. + Tensor with dimension `(batch, )`. + num_layers (int or None, optional): If not ``None``, returns the first + `num_layers` layers of Tensors as the output, otherwise returns the + Tensors from all emformer layers. + + Returns: + List[torch.Tensor]: + Output Tensors of selected emformer layers. 
+ """ + if num_layers is not None: + if not 0 < num_layers <= len(self.emformer.emformer_layers): + raise ValueError(f"`num_layers` must be between [1, {len(self.emformer.emformer_layers)}]") + + ret: List[torch.Tensor] = [] + + input = input.permute(1, 0, 2) + right_context = self.emformer._gen_right_context(input) + utterance = input[: input.size(0) - self.emformer.right_context_length] + attention_mask = self.emformer._gen_attention_mask(utterance) + mems = ( + self.emformer.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] + if self.emformer.use_mem + else torch.empty(0).to(dtype=input.dtype, device=input.device) + ) + output = utterance + if lengths is None: + B, T, _ = input.shape + lengths = torch.full((B,), T) + for layer in self.emformer.emformer_layers: + output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask) + ret.append(output.permute(1, 0, 2)) + if num_layers is not None and len(ret) >= num_layers: + return ret + return ret + + +def _get_emformer_feature_extractor(input_dim: int, output_dim: int, use_bias: bool, stride: int) -> FeatureEncoder: + """Construct FeatureEncoder for emformer model. + + Args: + input_dim (int): The feature dimension of log-mel spectrogram feature. + output_dim (int): The feature dimension after linear layer. + use_bias (bool): If ``True``, enable bias parameter in the linear layer. + stride (int): Number of frames to merge for the output frame. + + Returns: + FeatureEncoder: The resulting FeatureEncoder module. + """ + return FeatureEncoder(input_dim, output_dim, use_bias, stride) + + +def _get_emformer_encoder( + input_dim: int, + output_dim: int, + num_heads: int, + ffn_dim: int, + num_layers: int, + segment_length: int, + left_context_length: int, + right_context_length: int, + dropout: float, + activation: str, + max_memory_size: int, + weight_init_scale_strategy: Optional[str], + tanh_on_mem: bool, +) -> EmformerEncoder: + """Construct EmformerEncoder for emformer model. 
+ + Args: + input_dim (int): The feature dimension of input Tensor. + output_dim (int): The feature dimension after EmformerEncoder. + num_heads (int): Number of attention heads in each Emformer layer. + ffn_dim: (int): Hidden layer dimension of feedforward network. + num_layers (int): Number of Emformer layers to instantiate. + segment_length (int): Length of each input segment. + left_context_length (int): Length of left context. + right_context_length (int): Length of right context. + dropout (float): Dropout probability. + activation (str): Activation function to use in each Emformer layer's + feedforward network. Must be one of ("relu", "gelu", "silu"). + max_memory_size (int): Maximum number of memory elements to use. + weight_init_scale_strategy (str or None): Per-layer weight initialization scaling + strategy. Must be one of ("depthwise", "constant", ``None``). + tanh_on_mem (bool): If ``True``, applies tanh to memory elements. + + Returns: + EmformerEncoder: The resulting EmformerEncoder module. 
+ """ + emformer = Emformer( + input_dim=input_dim, + num_heads=num_heads, + ffn_dim=ffn_dim, + num_layers=num_layers, + segment_length=segment_length, + left_context_length=left_context_length, + right_context_length=right_context_length, + dropout=dropout, + activation=activation, + max_memory_size=max_memory_size, + weight_init_scale_strategy=weight_init_scale_strategy, + tanh_on_mem=tanh_on_mem, + ) + output_linear = torch.nn.Linear(input_dim, output_dim) + layer_norm = torch.nn.LayerNorm(output_dim) + return EmformerEncoder(emformer, output_linear, layer_norm) + + +def emformer_hubert_model( + extractor_input_dim: int, + extractor_output_dim: int, + extractor_use_bias: bool, + extractor_stride: int, + encoder_input_dim: int, + encoder_output_dim: int, + encoder_num_heads: int, + encoder_ffn_dim: int, + encoder_num_layers: int, + encoder_segment_length: int, + encoder_left_context_length: int, + encoder_right_context_length: int, + encoder_dropout: float, + encoder_activation: str, + encoder_max_memory_size: int, + encoder_weight_init_scale_strategy: Optional[str], + encoder_tanh_on_mem: bool, + aux_num_out: Optional[int], +) -> Wav2Vec2Model: + """Build a custom Emformer HuBERT model. + + Args: + extractor_input_dim (int): The input dimension for feature extractor. + extractor_output_dim (int): The output dimension after feature extractor. + extractor_use_bias (bool): If ``True``, enable bias parameter in the linear layer of feature extractor. + extractor_stride (int): Number of frames to merge for the output frame in feature extractor. + encoder_input_dim (int): The input dimension for Emformer layer. + encoder_output_dim (int): The output dimension after EmformerEncoder. + encoder_num_heads (int): Number of attention heads in each Emformer layer. + encoder_ffn_dim (int): Hidden layer dimension of feedforward network in Emformer. + encoder_num_layers (int): Number of Emformer layers to instantiate. + encoder_segment_length (int): Length of each input segment. 
+ encoder_left_context_length (int): Length of left context. + encoder_right_context_length (int): Length of right context. + encoder_dropout (float): Dropout probability. + encoder_activation (str): Activation function to use in each Emformer layer's + feedforward network. Must be one of ("relu", "gelu", "silu"). + encoder_max_memory_size (int): Maximum number of memory elements to use. + encoder_weight_init_scale_strategy (str or None): Per-layer weight initialization scaling + strategy. Must be one of ("depthwise", "constant", ``None``). + encoder_tanh_on_mem (bool): If ``True``, applies tanh to memory elements. + aux_num_out (int or None): + When provided, attach an extra linear layer on top of encoder, which can be + used for fine-tuning. + + Returns: + Wav2Vec2Model: + The resulting :py:class:`torchaudio.models.Wav2Vec2Model` model + with a :py:class:`torchaudio.models.Emformer` encoder. + """ + feature_extractor = _get_emformer_feature_extractor( + extractor_input_dim, extractor_output_dim, extractor_use_bias, extractor_stride + ) + emformer = _get_emformer_encoder( + encoder_input_dim, + encoder_output_dim, + encoder_num_heads, + encoder_ffn_dim, + encoder_num_layers, + encoder_segment_length, + encoder_left_context_length, + encoder_right_context_length, + encoder_dropout, + encoder_activation, + encoder_max_memory_size, + encoder_weight_init_scale_strategy, + encoder_tanh_on_mem, + ) + aux = None + if aux_num_out is not None: + aux = torch.nn.Linear(in_features=encoder_output_dim, out_features=aux_num_out) + return Wav2Vec2Model(feature_extractor, emformer, aux) + + +def emformer_hubert_base( + extractor_input_dim: int = 80, + extractor_output_dim: int = 128, + encoder_dropout: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Build Emformer HuBERT Model with 20 Emformer layers. + + Args: + extractor_input_dim (int, optional): The input dimension for feature extractor. 
(Default: 80) + extractor_output_dim (int, optional): The output dimension after feature extractor. (Default: 128) + encoder_dropout (float, optional): Dropout probability in Emformer. (Default: 0.1) + aux_num_out (int or None, optional): Output dimension of aux layer for fine-tuning. (Default: ``None``) + + Returns: + Wav2Vec2Model: + The resulting :py:class:`torchaudio.models.Wav2Vec2Model` model + with a :py:class:`torchaudio.models.Emformer` encoder. + """ + return emformer_hubert_model( + extractor_input_dim=extractor_input_dim, + extractor_output_dim=extractor_output_dim, + extractor_use_bias=False, + extractor_stride=4, + encoder_input_dim=512, + encoder_output_dim=1024, + encoder_num_heads=8, + encoder_ffn_dim=2048, + encoder_num_layers=20, + encoder_segment_length=4, + encoder_left_context_length=30, + encoder_right_context_length=1, + encoder_dropout=encoder_dropout, + encoder_activation="gelu", + encoder_max_memory_size=0, + encoder_weight_init_scale_strategy="depthwise", + encoder_tanh_on_mem=True, + aux_num_out=aux_num_out, + ) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/conv_emformer.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/conv_emformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b5495cfcd3a4753cdf17d426a1b2bcc8479ee3fa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/conv_emformer.py @@ -0,0 +1,525 @@ +import math +from typing import List, Optional, Tuple + +import torch +from torchaudio.models.emformer import _EmformerAttention, _EmformerImpl, _get_weight_init_gains + + +def _get_activation_module(activation: str) -> torch.nn.Module: + if activation == "relu": + return torch.nn.ReLU() + elif activation == "gelu": + return torch.nn.GELU() + elif activation == "silu": + return torch.nn.SiLU() + else: + raise ValueError(f"Unsupported activation {activation}") + + +class _ResidualContainer(torch.nn.Module): + def __init__(self, 
module: torch.nn.Module, output_weight: int): + super().__init__() + self.module = module + self.output_weight = output_weight + + def forward(self, input: torch.Tensor): + output = self.module(input) + return output * self.output_weight + input + + +class _ConvolutionModule(torch.nn.Module): + def __init__( + self, + input_dim: int, + segment_length: int, + right_context_length: int, + kernel_size: int, + activation: str = "silu", + dropout: float = 0.0, + ): + super().__init__() + self.input_dim = input_dim + self.segment_length = segment_length + self.right_context_length = right_context_length + self.state_size = kernel_size - 1 + + self.pre_conv = torch.nn.Sequential( + torch.nn.LayerNorm(input_dim), torch.nn.Linear(input_dim, 2 * input_dim, bias=True), torch.nn.GLU() + ) + self.conv = torch.nn.Conv1d( + in_channels=input_dim, + out_channels=input_dim, + kernel_size=kernel_size, + stride=1, + padding=0, + groups=input_dim, + ) + self.post_conv = torch.nn.Sequential( + torch.nn.LayerNorm(input_dim), + _get_activation_module(activation), + torch.nn.Linear(input_dim, input_dim, bias=True), + torch.nn.Dropout(p=dropout), + ) + + def _split_right_context(self, utterance: torch.Tensor, right_context: torch.Tensor) -> torch.Tensor: + T, B, D = right_context.size() + if T % self.right_context_length != 0: + raise ValueError("Tensor length should be divisible by its right context length") + num_segments = T // self.right_context_length + # (num_segments, right context length, B, D) + right_context_segments = right_context.reshape(num_segments, self.right_context_length, B, D) + right_context_segments = right_context_segments.permute(0, 2, 1, 3).reshape( + num_segments * B, self.right_context_length, D + ) + + pad_segments = [] # [(kernel_size - 1, B, D), ...] 
+ for seg_idx in range(num_segments): + end_idx = min(self.state_size + (seg_idx + 1) * self.segment_length, utterance.size(0)) + start_idx = end_idx - self.state_size + pad_segments.append(utterance[start_idx:end_idx, :, :]) + + pad_segments = torch.cat(pad_segments, dim=1).permute(1, 0, 2) # (num_segments * B, kernel_size - 1, D) + return torch.cat([pad_segments, right_context_segments], dim=1).permute(0, 2, 1) + + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: + # (num_segments * B, D, right_context_length) + right_context = right_context.reshape(-1, B, self.input_dim, self.right_context_length) + right_context = right_context.permute(0, 3, 1, 2) + return right_context.reshape(-1, B, self.input_dim) # (right_context_length * num_segments, B, D) + + def forward( + self, utterance: torch.Tensor, right_context: torch.Tensor, state: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input = torch.cat((right_context, utterance)) # input: (T, B, D) + x = self.pre_conv(input) + x_right_context, x_utterance = x[: right_context.size(0), :, :], x[right_context.size(0) :, :, :] + x_utterance = x_utterance.permute(1, 2, 0) # (B, D, T_utterance) + + if state is None: + state = torch.zeros( + input.size(1), + input.size(2), + self.state_size, + device=input.device, + dtype=input.dtype, + ) # (B, D, T) + state_x_utterance = torch.cat([state, x_utterance], dim=2) + + conv_utterance = self.conv(state_x_utterance) # (B, D, T_utterance) + conv_utterance = conv_utterance.permute(2, 0, 1) + + if self.right_context_length > 0: + # (B * num_segments, D, right_context_length + kernel_size - 1) + right_context_block = self._split_right_context(state_x_utterance.permute(2, 0, 1), x_right_context) + conv_right_context_block = self.conv(right_context_block) # (B * num_segments, D, right_context_length) + # (T_right_context, B, D) + conv_right_context = self._merge_right_context(conv_right_context_block, input.size(1)) + y 
= torch.cat([conv_right_context, conv_utterance], dim=0) + else: + y = conv_utterance + + output = self.post_conv(y) + input + new_state = state_x_utterance[:, :, -self.state_size :] + return output[right_context.size(0) :], output[: right_context.size(0)], new_state + + def infer( + self, utterance: torch.Tensor, right_context: torch.Tensor, state: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input = torch.cat((utterance, right_context)) + x = self.pre_conv(input) # (T, B, D) + x = x.permute(1, 2, 0) # (B, D, T) + + if state is None: + state = torch.zeros( + input.size(1), + input.size(2), + self.state_size, + device=input.device, + dtype=input.dtype, + ) # (B, D, T) + state_x = torch.cat([state, x], dim=2) + conv_out = self.conv(state_x) + conv_out = conv_out.permute(2, 0, 1) # T, B, D + output = self.post_conv(conv_out) + input + new_state = state_x[:, :, -self.state_size - right_context.size(0) : -right_context.size(0)] + return output[: utterance.size(0)], output[utterance.size(0) :], new_state + + +class _ConvEmformerLayer(torch.nn.Module): + r"""Convolution-augmented Emformer layer that constitutes ConvEmformer. + + Args: + input_dim (int): input dimension. + num_heads (int): number of attention heads. + ffn_dim: (int): hidden layer dimension of feedforward network. + segment_length (int): length of each input segment. + kernel_size (int): size of kernel to use in convolution module. + dropout (float, optional): dropout probability. (Default: 0.0) + ffn_activation (str, optional): activation function to use in feedforward network. + Must be one of ("relu", "gelu", "silu"). (Default: "relu") + left_context_length (int, optional): length of left context. (Default: 0) + right_context_length (int, optional): length of right context. (Default: 0) + max_memory_size (int, optional): maximum number of memory elements to use. 
(Default: 0) + weight_init_gain (float or None, optional): scale factor to apply when initializing + attention module parameters. (Default: ``None``) + tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``) + negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8) + conv_activation (str, optional): activation function to use in convolution module. + Must be one of ("relu", "gelu", "silu"). (Default: "silu") + """ + + def __init__( + self, + input_dim: int, + num_heads: int, + ffn_dim: int, + segment_length: int, + kernel_size: int, + dropout: float = 0.0, + ffn_activation: str = "relu", + left_context_length: int = 0, + right_context_length: int = 0, + max_memory_size: int = 0, + weight_init_gain: Optional[float] = None, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + conv_activation: str = "silu", + ): + super().__init__() + # TODO: implement talking heads attention. + self.attention = _EmformerAttention( + input_dim=input_dim, + num_heads=num_heads, + dropout=dropout, + weight_init_gain=weight_init_gain, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + self.dropout = torch.nn.Dropout(dropout) + self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True) + + activation_module = _get_activation_module(ffn_activation) + self.ffn0 = _ResidualContainer( + torch.nn.Sequential( + torch.nn.LayerNorm(input_dim), + torch.nn.Linear(input_dim, ffn_dim), + activation_module, + torch.nn.Dropout(dropout), + torch.nn.Linear(ffn_dim, input_dim), + torch.nn.Dropout(dropout), + ), + 0.5, + ) + self.ffn1 = _ResidualContainer( + torch.nn.Sequential( + torch.nn.LayerNorm(input_dim), + torch.nn.Linear(input_dim, ffn_dim), + activation_module, + torch.nn.Dropout(dropout), + torch.nn.Linear(ffn_dim, input_dim), + torch.nn.Dropout(dropout), + ), + 0.5, + ) + self.layer_norm_input = torch.nn.LayerNorm(input_dim) + 
self.layer_norm_output = torch.nn.LayerNorm(input_dim) + + self.conv = _ConvolutionModule( + input_dim=input_dim, + kernel_size=kernel_size, + activation=conv_activation, + dropout=dropout, + segment_length=segment_length, + right_context_length=right_context_length, + ) + + self.left_context_length = left_context_length + self.segment_length = segment_length + self.max_memory_size = max_memory_size + self.input_dim = input_dim + self.kernel_size = kernel_size + self.use_mem = max_memory_size > 0 + + def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]: + empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device) + left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device) + left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device) + past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device) + conv_cache = torch.zeros( + batch_size, + self.input_dim, + self.kernel_size - 1, + device=device, + ) + return [empty_memory, left_context_key, left_context_val, past_length, conv_cache] + + def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + past_length = state[3][0][0].item() + past_left_context_length = min(self.left_context_length, past_length) + past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length)) + pre_mems = state[0][self.max_memory_size - past_mem_length :] + lc_key = state[1][self.left_context_length - past_left_context_length :] + lc_val = state[2][self.left_context_length - past_left_context_length :] + conv_cache = state[4] + return pre_mems, lc_key, lc_val, conv_cache + + def _pack_state( + self, + next_k: torch.Tensor, + next_v: torch.Tensor, + update_length: int, + mems: torch.Tensor, + conv_cache: torch.Tensor, + state: List[torch.Tensor], + ) -> List[torch.Tensor]: + new_k = 
torch.cat([state[1], next_k]) + new_v = torch.cat([state[2], next_v]) + state[0] = torch.cat([state[0], mems])[-self.max_memory_size :] + state[1] = new_k[new_k.shape[0] - self.left_context_length :] + state[2] = new_v[new_v.shape[0] - self.left_context_length :] + state[3] = state[3] + update_length + state[4] = conv_cache + return state + + def _apply_pre_attention( + self, utterance: torch.Tensor, right_context: torch.Tensor, summary: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x = torch.cat([right_context, utterance, summary]) + ffn0_out = self.ffn0(x) + layer_norm_input_out = self.layer_norm_input(ffn0_out) + layer_norm_input_right_context, layer_norm_input_utterance, layer_norm_input_summary = ( + layer_norm_input_out[: right_context.size(0)], + layer_norm_input_out[right_context.size(0) : right_context.size(0) + utterance.size(0)], + layer_norm_input_out[right_context.size(0) + utterance.size(0) :], + ) + return ffn0_out, layer_norm_input_right_context, layer_norm_input_utterance, layer_norm_input_summary + + def _apply_post_attention( + self, + rc_output: torch.Tensor, + ffn0_out: torch.Tensor, + conv_cache: Optional[torch.Tensor], + rc_length: int, + utterance_length: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + result = self.dropout(rc_output) + ffn0_out[: rc_length + utterance_length] + conv_utterance, conv_right_context, conv_cache = self.conv(result[rc_length:], result[:rc_length], conv_cache) + result = torch.cat([conv_right_context, conv_utterance]) + result = self.ffn1(result) + result = self.layer_norm_output(result) + output_utterance, output_right_context = result[rc_length:], result[:rc_length] + return output_utterance, output_right_context, conv_cache + + def forward( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + mems: torch.Tensor, + attention_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Forward pass 
for training. + + B: batch size; + D: feature dimension of each frame; + T: number of utterance frames; + R: number of right context frames; + M: number of memory elements. + + Args: + utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``utterance``. + right_context (torch.Tensor): right context frames, with shape `(R, B, D)`. + mems (torch.Tensor): memory elements, with shape `(M, B, D)`. + attention_mask (torch.Tensor): attention mask for underlying attention module. + + Returns: + (Tensor, Tensor, Tensor): + Tensor + encoded utterance frames, with shape `(T, B, D)`. + Tensor + updated right context frames, with shape `(R, B, D)`. + Tensor + updated memory elements, with shape `(M, B, D)`. + """ + if self.use_mem: + summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + else: + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + + ( + ffn0_out, + layer_norm_input_right_context, + layer_norm_input_utterance, + layer_norm_input_summary, + ) = self._apply_pre_attention(utterance, right_context, summary) + + rc_output, output_mems = self.attention( + utterance=layer_norm_input_utterance, + lengths=lengths, + right_context=layer_norm_input_right_context, + summary=layer_norm_input_summary, + mems=mems, + attention_mask=attention_mask, + ) + + output_utterance, output_right_context, _ = self._apply_post_attention( + rc_output, ffn0_out, None, right_context.size(0), utterance.size(0) + ) + + return output_utterance, output_right_context, output_mems + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + state: Optional[List[torch.Tensor]], + mems: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]: + r"""Forward pass for inference. 
+ + B: batch size; + D: feature dimension of each frame; + T: number of utterance frames; + R: number of right context frames; + M: number of memory elements. + + Args: + utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`. + lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``utterance``. + right_context (torch.Tensor): right context frames, with shape `(R, B, D)`. + state (List[torch.Tensor] or None): list of tensors representing layer internal state + generated in preceding invocation of ``infer``. + mems (torch.Tensor): memory elements, with shape `(M, B, D)`. + + Returns: + (Tensor, Tensor, List[torch.Tensor], Tensor): + Tensor + encoded utterance frames, with shape `(T, B, D)`. + Tensor + updated right context frames, with shape `(R, B, D)`. + List[Tensor] + list of tensors representing layer internal state + generated in current invocation of ``infer``. + Tensor + updated memory elements, with shape `(M, B, D)`. 
+ """ + if self.use_mem: + summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:1] + else: + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + + ( + ffn0_out, + layer_norm_input_right_context, + layer_norm_input_utterance, + layer_norm_input_summary, + ) = self._apply_pre_attention(utterance, right_context, summary) + + if state is None: + state = self._init_state(layer_norm_input_utterance.size(1), device=layer_norm_input_utterance.device) + pre_mems, lc_key, lc_val, conv_cache = self._unpack_state(state) + + rc_output, next_m, next_k, next_v = self.attention.infer( + utterance=layer_norm_input_utterance, + lengths=lengths, + right_context=layer_norm_input_right_context, + summary=layer_norm_input_summary, + mems=pre_mems, + left_context_key=lc_key, + left_context_val=lc_val, + ) + + output_utterance, output_right_context, conv_cache = self._apply_post_attention( + rc_output, ffn0_out, conv_cache, right_context.size(0), utterance.size(0) + ) + output_state = self._pack_state(next_k, next_v, utterance.size(0), mems, conv_cache, state) + return output_utterance, output_right_context, output_state, next_m + + +class ConvEmformer(_EmformerImpl): + r"""Implements the convolution-augmented streaming transformer architecture introduced in + *Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution* + :cite:`9747706`. + + Args: + input_dim (int): input dimension. + num_heads (int): number of attention heads in each ConvEmformer layer. + ffn_dim (int): hidden layer dimension of each ConvEmformer layer's feedforward network. + num_layers (int): number of ConvEmformer layers to instantiate. + segment_length (int): length of each input segment. + kernel_size (int): size of kernel to use in convolution modules. + dropout (float, optional): dropout probability. (Default: 0.0) + ffn_activation (str, optional): activation function to use in feedforward networks. + Must be one of ("relu", "gelu", "silu"). 
(Default: "relu") + left_context_length (int, optional): length of left context. (Default: 0) + right_context_length (int, optional): length of right context. (Default: 0) + max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0) + weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling + strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise") + tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``) + negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8) + conv_activation (str, optional): activation function to use in convolution modules. + Must be one of ("relu", "gelu", "silu"). (Default: "silu") + + Examples: + >>> conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4) + >>> input = torch.rand(10, 200, 80) + >>> lengths = torch.randint(1, 200, (10,)) + >>> output, lengths = conv_emformer(input, lengths) + >>> input = torch.rand(4, 20, 80) + >>> lengths = torch.ones(4) * 20 + >>> output, lengths, states = conv_emformer.infer(input, lengths, None) + """ + + def __init__( + self, + input_dim: int, + num_heads: int, + ffn_dim: int, + num_layers: int, + segment_length: int, + kernel_size: int, + dropout: float = 0.0, + ffn_activation: str = "relu", + left_context_length: int = 0, + right_context_length: int = 0, + max_memory_size: int = 0, + weight_init_scale_strategy: Optional[str] = "depthwise", + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + conv_activation: str = "silu", + ): + weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers) + emformer_layers = torch.nn.ModuleList( + [ + _ConvEmformerLayer( + input_dim, + num_heads, + ffn_dim, + segment_length, + kernel_size, + dropout=dropout, + ffn_activation=ffn_activation, + left_context_length=left_context_length, + right_context_length=right_context_length, + 
max_memory_size=max_memory_size, + weight_init_gain=weight_init_gains[layer_idx], + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + conv_activation=conv_activation, + ) + for layer_idx in range(num_layers) + ] + ) + super().__init__( + emformer_layers, + segment_length, + left_context_length=left_context_length, + right_context_length=right_context_length, + max_memory_size=max_memory_size, + ) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/hifi_gan.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/hifi_gan.py new file mode 100644 index 0000000000000000000000000000000000000000..93d92e1854651358367c6388fa3d916941961367 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/hifi_gan.py @@ -0,0 +1,336 @@ +""" +MIT License + +Copyright (c) 2020 Jungil Kong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d, ConvTranspose1d + + +class HiFiGANVocoder(torch.nn.Module): + """Generator part of *HiFi GAN* :cite:`NEURIPS2020_c5d73680`. + Source: https://github.com/jik876/hifi-gan/blob/4769534d45265d52a904b850da5a622601885777/models.py#L75 + + Note: + To build the model, please use one of the factory functions: :py:func:`hifigan_vocoder`, + :py:func:`hifigan_vocoder_v1`, :py:func:`hifigan_vocoder_v2`, :py:func:`hifigan_vocoder_v3`. + + Args: + in_channels (int): Number of channels in the input features. + upsample_rates (tuple of ``int``): Factors by which each upsampling layer increases the time dimension. + upsample_initial_channel (int): Number of channels in the input feature tensor. + upsample_kernel_sizes (tuple of ``int``): Kernel size for each upsampling layer. + resblock_kernel_sizes (tuple of ``int``): Kernel size for each residual block. + resblock_dilation_sizes (tuple of tuples of ``int``): Dilation sizes for each 1D convolutional layer in each + residual block. For resblock type 1 inner tuples should have length 3, because there are 3 + convolutions in each layer. For resblock type 2 they should have length 2. + resblock_type (int, 1 or 2): Determines whether ``ResBlock1`` or ``ResBlock2`` will be used. + lrelu_slope (float): Slope of leaky ReLUs in activations. 
+ """ + + def __init__( + self, + in_channels: int, + upsample_rates: Tuple[int, ...], + upsample_initial_channel: int, + upsample_kernel_sizes: Tuple[int, ...], + resblock_kernel_sizes: Tuple[int, ...], + resblock_dilation_sizes: Tuple[Tuple[int, ...], ...], + resblock_type: int, + lrelu_slope: float, + ): + super(HiFiGANVocoder, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + resblock = ResBlock1 if resblock_type == 1 else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for (k, d) in zip(resblock_kernel_sizes, resblock_dilation_sizes): + self.resblocks.append(resblock(ch, k, d, lrelu_slope)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3) + self.lrelu_slope = lrelu_slope + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Feature input tensor of shape `(batch_size, num_channels, time_length)`. + + Returns: + Tensor of shape `(batch_size, 1, time_length * upsample_rate)`, where `upsample_rate` is the product + of upsample rates for all layers. 
+ """ + x = self.conv_pre(x) + for i, upsampling_layer in enumerate(self.ups): + x = F.leaky_relu(x, self.lrelu_slope) + x = upsampling_layer(x) + xs = torch.zeros_like(x) + for j in range(self.num_kernels): + res_block: ResBlockInterface = self.resblocks[i * self.num_kernels + j] + xs += res_block.forward(x) + x = xs / self.num_kernels + + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + +@torch.jit.interface +class ResBlockInterface(torch.nn.Module): + """Interface for ResBlock - necessary to make type annotations in ``HiFiGANVocoder.forward`` compatible + with TorchScript + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pass + + +class ResBlock1(torch.nn.Module): + """Residual block of type 1 for HiFiGAN Vocoder :cite:`NEURIPS2020_c5d73680`. + Args: + channels (int): Number of channels in the input features. + kernel_size (int, optional): Kernel size for 1D convolutions. (Default: ``3``) + dilation (tuple of 3 ``int``, optional): Dilations for each 1D convolution. (Default: ``(1, 3, 5)``) + lrelu_slope (float): Slope of leaky ReLUs in activations. 
+ """ + + def __init__( + self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5), lrelu_slope: float = 0.1 + ): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)), + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)), + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)), + ] + ) + self.lrelu_slope = lrelu_slope + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): input of shape ``(batch_size, channels, time_length)``. + Returns: + Tensor of the same shape as input. + """ + for conv1, conv2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, self.lrelu_slope) + xt = conv1(xt) + xt = F.leaky_relu(xt, self.lrelu_slope) + xt = conv2(xt) + x = xt + x + return x + + +class ResBlock2(torch.nn.Module): + """Residual block of type 2 for HiFiGAN Vocoder :cite:`NEURIPS2020_c5d73680`. + Args: + channels (int): Number of channels in the input features. + kernel_size (int, optional): Kernel size for 1D convolutions. (Default: ``3``) + dilation (tuple of 2 ``int``, optional): Dilations for each 1D convolution. (Default: ``(1, 3)``) + lrelu_slope (float): Slope of leaky ReLUs in activations. 
+ """ + + def __init__( + self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3), lrelu_slope: float = 0.1 + ): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ] + ) + self.lrelu_slope = lrelu_slope + + def forward(self, x: torch.Tensor): + """ + Args: + x (Tensor): input of shape ``(batch_size, channels, time_length)``. + Returns: + Tensor of the same shape as input. + """ + for c in self.convs: + xt = F.leaky_relu(x, self.lrelu_slope) + xt = c(xt) + x = xt + x + return x + + +def get_padding(kernel_size, dilation=1): + """Find padding for which 1D convolution preserves the input shape.""" + return int((kernel_size * dilation - dilation) / 2) + + +def hifigan_vocoder( + in_channels: int, + upsample_rates: Tuple[int, ...], + upsample_initial_channel: int, + upsample_kernel_sizes: Tuple[int, ...], + resblock_kernel_sizes: Tuple[int, ...], + resblock_dilation_sizes: Tuple[Tuple[int, ...], ...], + resblock_type: int, + lrelu_slope: float, +) -> HiFiGANVocoder: + r"""Builds HiFi GAN Vocoder :cite:`NEURIPS2020_c5d73680`. + + Args: + in_channels (int): See :py:class:`HiFiGANVocoder`. + upsample_rates (tuple of ``int``): See :py:class:`HiFiGANVocoder`. + upsample_initial_channel (int): See :py:class:`HiFiGANVocoder`. + upsample_kernel_sizes (tuple of ``int``): See :py:class:`HiFiGANVocoder`. + resblock_kernel_sizes (tuple of ``int``): See :py:class:`HiFiGANVocoder`. + resblock_dilation_sizes (tuple of tuples of ``int``): See :py:class:`HiFiGANVocoder`. + resblock_type (int, 1 or 2): See :py:class:`HiFiGANVocoder`. + Returns: + HiFiGANVocoder: generated model. 
+ """ + + return HiFiGANVocoder( + upsample_rates=upsample_rates, + resblock_kernel_sizes=resblock_kernel_sizes, + resblock_dilation_sizes=resblock_dilation_sizes, + resblock_type=resblock_type, + upsample_initial_channel=upsample_initial_channel, + upsample_kernel_sizes=upsample_kernel_sizes, + in_channels=in_channels, + lrelu_slope=lrelu_slope, + ) + + +def hifigan_vocoder_v1() -> HiFiGANVocoder: + r"""Builds HiFiGAN Vocoder with V1 architecture :cite:`NEURIPS2020_c5d73680`. + + Returns: + HiFiGANVocoder: generated model. + """ + return hifigan_vocoder( + upsample_rates=(8, 8, 2, 2), + upsample_kernel_sizes=(16, 16, 4, 4), + upsample_initial_channel=512, + resblock_kernel_sizes=(3, 7, 11), + resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)), + resblock_type=1, + in_channels=80, + lrelu_slope=0.1, + ) + + +def hifigan_vocoder_v2() -> HiFiGANVocoder: + r"""Builds HiFiGAN Vocoder with V2 architecture :cite:`NEURIPS2020_c5d73680`. + + Returns: + HiFiGANVocoder: generated model. + """ + return hifigan_vocoder( + upsample_rates=(8, 8, 2, 2), + upsample_kernel_sizes=(16, 16, 4, 4), + upsample_initial_channel=128, + resblock_kernel_sizes=(3, 7, 11), + resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)), + resblock_type=1, + in_channels=80, + lrelu_slope=0.1, + ) + + +def hifigan_vocoder_v3() -> HiFiGANVocoder: + r"""Builds HiFiGAN Vocoder with V3 architecture :cite:`NEURIPS2020_c5d73680`. + + Returns: + HiFiGANVocoder: generated model. 
+ """ + return hifigan_vocoder( + upsample_rates=(8, 8, 4), + upsample_kernel_sizes=(16, 16, 8), + upsample_initial_channel=256, + resblock_kernel_sizes=(3, 5, 7), + resblock_dilation_sizes=((1, 2), (2, 6), (3, 12)), + resblock_type=2, + in_channels=80, + lrelu_slope=0.1, + ) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/rnnt.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/rnnt.py new file mode 100644 index 0000000000000000000000000000000000000000..aa7e32d5b961a80b2637995a3e3be07eaa5d7165 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/rnnt.py @@ -0,0 +1,711 @@ +import math +from typing import Dict, List, Optional, Tuple + +import torch +from torchaudio.models import Conformer, RNNT +from torchaudio.models.rnnt import _Joiner, _Predictor, _TimeReduction, _Transcriber + + +TrieNode = Tuple[Dict[int, "TrieNode"], int, Optional[Tuple[int, int]]] + + +class _ConformerEncoder(torch.nn.Module, _Transcriber): + def __init__( + self, + *, + input_dim: int, + output_dim: int, + time_reduction_stride: int, + conformer_input_dim: int, + conformer_ffn_dim: int, + conformer_num_layers: int, + conformer_num_heads: int, + conformer_depthwise_conv_kernel_size: int, + conformer_dropout: float, + ) -> None: + super().__init__() + self.time_reduction = _TimeReduction(time_reduction_stride) + self.input_linear = torch.nn.Linear(input_dim * time_reduction_stride, conformer_input_dim) + self.conformer = Conformer( + num_layers=conformer_num_layers, + input_dim=conformer_input_dim, + ffn_dim=conformer_ffn_dim, + num_heads=conformer_num_heads, + depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size, + dropout=conformer_dropout, + use_group_norm=True, + convolution_first=True, + ) + self.output_linear = torch.nn.Linear(conformer_input_dim, output_dim) + self.layer_norm = torch.nn.LayerNorm(output_dim) + + def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, 
torch.Tensor]: + time_reduction_out, time_reduction_lengths = self.time_reduction(input, lengths) + input_linear_out = self.input_linear(time_reduction_out) + x, lengths = self.conformer(input_linear_out, time_reduction_lengths) + output_linear_out = self.output_linear(x) + layer_norm_out = self.layer_norm(output_linear_out) + return layer_norm_out, lengths + + def infer( + self, + input: torch.Tensor, + lengths: torch.Tensor, + states: Optional[List[List[torch.Tensor]]], + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + raise RuntimeError("Conformer does not support streaming inference.") + + +class _JoinerBiasing(torch.nn.Module): + r"""Recurrent neural network transducer (RNN-T) joint network. + + Args: + input_dim (int): source and target input dimension. + output_dim (int): output dimension. + activation (str, optional): activation function to use in the joiner. + Must be one of ("relu", "tanh"). (Default: "relu") + biasing (bool): perform biasing + deepbiasing (bool): perform deep biasing + attndim (int): dimension of the biasing vector hptr + + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + activation: str = "relu", + biasing: bool = False, + deepbiasing: bool = False, + attndim: int = 1, + ) -> None: + super().__init__() + self.linear = torch.nn.Linear(input_dim, output_dim, bias=True) + self.biasing = biasing + self.deepbiasing = deepbiasing + if self.biasing and self.deepbiasing: + self.biasinglinear = torch.nn.Linear(attndim, input_dim, bias=True) + self.attndim = attndim + if activation == "relu": + self.activation = torch.nn.ReLU() + elif activation == "tanh": + self.activation = torch.nn.Tanh() + else: + raise ValueError(f"Unsupported activation {activation}") + + def forward( + self, + source_encodings: torch.Tensor, + source_lengths: torch.Tensor, + target_encodings: torch.Tensor, + target_lengths: torch.Tensor, + hptr: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, 
torch.Tensor]: + r"""Forward pass for training. + + B: batch size; + T: maximum source sequence length in batch; + U: maximum target sequence length in batch; + D: dimension of each source and target sequence encoding. + + Args: + source_encodings (torch.Tensor): source encoding sequences, with + shape `(B, T, D)`. + source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + valid sequence length of i-th batch element in ``source_encodings``. + target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`. + target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + valid sequence length of i-th batch element in ``target_encodings``. + hptr (torch.Tensor): deep biasing vector with shape `(B, T, U, A)`. + + Returns: + (torch.Tensor, torch.Tensor, torch.Tensor): + torch.Tensor + joint network output, with shape `(B, T, U, output_dim)`. + torch.Tensor + output source lengths, with shape `(B,)` and i-th element representing + number of valid elements along dim 1 for i-th batch element in joint network output. + torch.Tensor + output target lengths, with shape `(B,)` and i-th element representing + number of valid elements along dim 2 for i-th batch element in joint network output. + torch.Tensor + joint network second last layer output (i.e. before self.linear), with shape `(B, T, U, D)`. 
+ """ + joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous() + if self.biasing and self.deepbiasing and hptr is not None: + hptr = self.biasinglinear(hptr) + joint_encodings += hptr + elif self.biasing and self.deepbiasing: + # Hack here for unused parameters + joint_encodings += self.biasinglinear(joint_encodings.new_zeros(1, self.attndim)).mean() * 0 + activation_out = self.activation(joint_encodings) + output = self.linear(activation_out) + return output, source_lengths, target_lengths, activation_out + + +class RNNTBiasing(RNNT): + r"""torchaudio.models.RNNT() + + Recurrent neural network transducer (RNN-T) model. + + Note: + To build the model, please use one of the factory functions. + + Args: + transcriber (torch.nn.Module): transcription network. + predictor (torch.nn.Module): prediction network. + joiner (torch.nn.Module): joint network. + attndim (int): TCPGen attention dimension + biasing (bool): If true, use biasing, otherwise use standard RNN-T + deepbiasing (bool): If true, use deep biasing by extracting the biasing vector + embdim (int): dimension of symbol embeddings + jointdim (int): dimension of the joint network joint dimension + charlist (list): The list of word piece tokens in the same order as the output layer + encoutdim (int): dimension of the encoder output vectors + dropout_tcpgen (float): dropout rate for TCPGen + tcpsche (int): The epoch at which TCPGen starts to train + DBaverage (bool): If true, instead of TCPGen, use DBRNNT for biasing + """ + + def __init__( + self, + transcriber: _Transcriber, + predictor: _Predictor, + joiner: _Joiner, + attndim: int, + biasing: bool, + deepbiasing: bool, + embdim: int, + jointdim: int, + charlist: List[str], + encoutdim: int, + dropout_tcpgen: float, + tcpsche: int, + DBaverage: bool, + ) -> None: + super().__init__(transcriber, predictor, joiner) + self.attndim = attndim + self.deepbiasing = deepbiasing + self.jointdim = jointdim + self.embdim 
= embdim + self.encoutdim = encoutdim + self.char_list = charlist or [] + self.blank_idx = self.char_list.index("") + self.nchars = len(self.char_list) + self.DBaverage = DBaverage + self.biasing = biasing + if self.biasing: + if self.deepbiasing and self.DBaverage: + # Deep biasing without TCPGen + self.biasingemb = torch.nn.Linear(self.nchars, self.attndim, bias=False) + else: + # TCPGen parameters + self.ooKBemb = torch.nn.Embedding(1, self.embdim) + self.Qproj_char = torch.nn.Linear(self.embdim, self.attndim) + self.Qproj_acoustic = torch.nn.Linear(self.encoutdim, self.attndim) + self.Kproj = torch.nn.Linear(self.embdim, self.attndim) + self.pointer_gate = torch.nn.Linear(self.attndim + self.jointdim, 1) + self.dropout_tcpgen = torch.nn.Dropout(dropout_tcpgen) + self.tcpsche = tcpsche + + def forward( + self, + sources: torch.Tensor, + source_lengths: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + tries: TrieNode, + current_epoch: int, + predictor_state: Optional[List[List[torch.Tensor]]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]], torch.Tensor, torch.Tensor]: + r"""Forward pass for training. + + B: batch size; + T: maximum source sequence length in batch; + U: maximum target sequence length in batch; + D: feature dimension of each source sequence element. + + Args: + sources (torch.Tensor): source frame sequences right-padded with right context, with + shape `(B, T, D)`. + source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``sources``. + targets (torch.Tensor): target sequences, with shape `(B, U)` and each element + mapping to a target symbol. + target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing + number of valid frames for i-th batch element in ``targets``. 
+ tries (TrieNode): wordpiece prefix trees representing the biasing list to be searched + current_epoch (Int): the current epoch number to determine if TCPGen should be trained + at this epoch + predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors + representing prediction network internal state generated in preceding invocation + of ``forward``. (Default: ``None``) + + Returns: + (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]): + torch.Tensor + joint network output, with shape + `(B, max output source length, max output target length, output_dim (number of target symbols))`. + torch.Tensor + output source lengths, with shape `(B,)` and i-th element representing + number of valid elements along dim 1 for i-th batch element in joint network output. + torch.Tensor + output target lengths, with shape `(B,)` and i-th element representing + number of valid elements along dim 2 for i-th batch element in joint network output. + List[List[torch.Tensor]] + output states; list of lists of tensors + representing prediction network internal state generated in current invocation + of ``forward``. + torch.Tensor + TCPGen distribution, with shape + `(B, max output source length, max output target length, output_dim (number of target symbols))`. + torch.Tensor + Generation probability (or copy probability), with shape + `(B, max output source length, max output target length, 1)`. 
+ """ + source_encodings, source_lengths = self.transcriber( + input=sources, + lengths=source_lengths, + ) + target_encodings, target_lengths, predictor_state = self.predictor( + input=targets, + lengths=target_lengths, + state=predictor_state, + ) + # Forward TCPGen + hptr = None + tcpgen_dist, p_gen = None, None + if self.biasing and current_epoch >= self.tcpsche and tries != []: + ptrdist_mask, p_gen_mask = self.get_tcpgen_step_masks(targets, tries) + hptr, tcpgen_dist = self.forward_tcpgen(targets, ptrdist_mask, source_encodings) + hptr = self.dropout_tcpgen(hptr) + elif self.biasing: + # Hack here to bypass unused parameters + if self.DBaverage and self.deepbiasing: + dummy = self.biasingemb(source_encodings.new_zeros(1, len(self.char_list))).mean() + else: + dummy = source_encodings.new_zeros(1, self.embdim) + dummy = self.Qproj_char(dummy).mean() + dummy += self.Qproj_acoustic(source_encodings.new_zeros(1, source_encodings.size(-1))).mean() + dummy += self.Kproj(source_encodings.new_zeros(1, self.embdim)).mean() + dummy += self.pointer_gate(source_encodings.new_zeros(1, self.attndim + self.jointdim)).mean() + dummy += self.ooKBemb.weight.mean() + dummy = dummy * 0 + source_encodings += dummy + + output, source_lengths, target_lengths, jointer_activation = self.joiner( + source_encodings=source_encodings, + source_lengths=source_lengths, + target_encodings=target_encodings, + target_lengths=target_lengths, + hptr=hptr, + ) + + # Calculate Generation Probability + if self.biasing and hptr is not None and tcpgen_dist is not None: + p_gen = torch.sigmoid(self.pointer_gate(torch.cat((jointer_activation, hptr), dim=-1))) + # avoid collapsing to ooKB token in the first few updates + # if current_epoch == self.tcpsche: + # p_gen = p_gen * 0.1 + p_gen = p_gen.masked_fill(p_gen_mask.bool().unsqueeze(1).unsqueeze(-1), 0) + + return (output, source_lengths, target_lengths, predictor_state, tcpgen_dist, p_gen) + + def get_tcpgen_distribution(self, query, ptrdist_mask): 
+ # Make use of the predictor embedding matrix + keyvalues = torch.cat([self.predictor.embedding.weight.data, self.ooKBemb.weight], dim=0) + keyvalues = self.dropout_tcpgen(self.Kproj(keyvalues)) + # B * T * U * attndim, nbpe * attndim -> B * T * U * nbpe + tcpgendist = torch.einsum("ntuj,ij->ntui", query, keyvalues) + tcpgendist = tcpgendist / math.sqrt(query.size(-1)) + ptrdist_mask = ptrdist_mask.unsqueeze(1).repeat(1, tcpgendist.size(1), 1, 1) + tcpgendist.masked_fill_(ptrdist_mask.bool(), -1e9) + tcpgendist = torch.nn.functional.softmax(tcpgendist, dim=-1) + # B * T * U * nbpe, nbpe * attndim -> B * T * U * attndim + hptr = torch.einsum("ntui,ij->ntuj", tcpgendist[:, :, :, :-1], keyvalues[:-1, :]) + return hptr, tcpgendist + + def forward_tcpgen(self, targets, ptrdist_mask, source_encodings): + tcpgen_dist = None + if self.DBaverage and self.deepbiasing: + hptr = self.biasingemb(1 - ptrdist_mask[:, :, :-1].float()).unsqueeze(1) + else: + query_char = self.predictor.embedding(targets) + query_char = self.Qproj_char(query_char).unsqueeze(1) # B * 1 * U * attndim + query_acoustic = self.Qproj_acoustic(source_encodings).unsqueeze(2) # B * T * 1 * attndim + query = query_char + query_acoustic # B * T * U * attndim + hptr, tcpgen_dist = self.get_tcpgen_distribution(query, ptrdist_mask) + return hptr, tcpgen_dist + + def get_tcpgen_step_masks(self, yseqs, resettrie): + seqlen = len(yseqs[0]) + batch_masks = yseqs.new_ones(len(yseqs), seqlen, len(self.char_list) + 1) + p_gen_masks = [] + for i, yseq in enumerate(yseqs): + new_tree = resettrie + p_gen_mask = [] + for j, vy in enumerate(yseq): + vy = vy.item() + new_tree = new_tree[0] + if vy in [self.blank_idx]: + new_tree = resettrie + p_gen_mask.append(0) + elif self.char_list[vy].endswith("▁"): + if vy in new_tree and new_tree[vy][0] != {}: + new_tree = new_tree[vy] + else: + new_tree = resettrie + p_gen_mask.append(0) + elif vy not in new_tree: + new_tree = [{}] + p_gen_mask.append(1) + else: + new_tree = 
def get_tcpgen_step(self, vy, trie, resettrie):
    """Advance one token ``vy`` through the biasing prefix tree.

    ``trie`` is the current node and ``resettrie`` the node to fall back to;
    a node's first element is a dict mapping token ids to child nodes.
    Returns the node to use after consuming ``vy``.
    """
    children = trie[0]

    # Blank token: start again from the reset node.
    if vy == self.blank_idx:
        return resettrie

    if self.char_list[vy].endswith("▁"):
        # Token text ends with the "▁" marker: keep following the tree only
        # if the child exists and itself has children, otherwise reset.
        if vy in children and children[vy][0] != {}:
            return children[vy]
        return resettrie

    if vy in children:
        return children[vy]
    # Token is not a child of the current node: return an empty node.
    return [{}]
def conformer_rnnt_model(
    *,
    input_dim: int,
    encoding_dim: int,
    time_reduction_stride: int,
    conformer_input_dim: int,
    conformer_ffn_dim: int,
    conformer_num_layers: int,
    conformer_num_heads: int,
    conformer_depthwise_conv_kernel_size: int,
    conformer_dropout: float,
    num_symbols: int,
    symbol_embedding_dim: int,
    num_lstm_layers: int,
    lstm_hidden_dim: int,
    lstm_layer_norm: bool,
    lstm_layer_norm_epsilon: float,
    lstm_dropout: float,
    joiner_activation: str,
) -> RNNT:
    r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model.

    All arguments are keyword-only and required.

    Args:
        input_dim (int): dimension of input sequence frames passed to transcription network.
        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
            passed to joint network.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        conformer_input_dim (int): dimension of Conformer input.
        conformer_ffn_dim (int): hidden layer dimension of each Conformer layer's feedforward network.
        conformer_num_layers (int): number of Conformer layers to instantiate.
        conformer_num_heads (int): number of attention heads in each Conformer layer.
        conformer_depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
        conformer_dropout (float): Conformer dropout probability.
        num_symbols (int): cardinality of set of target tokens.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_hidden_dim (int): output dimension of each LSTM layer.
        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
        lstm_dropout (float): LSTM dropout probability.
        joiner_activation (str): activation function to use in the joiner.
            Must be one of ("relu", "tanh").

    Returns:
        RNNT:
            Conformer RNN-T model.
    """
    # Transcription network: its output dimension matches the predictor's so
    # both encodings can be combined by the joiner.
    encoder = _ConformerEncoder(
        input_dim=input_dim,
        output_dim=encoding_dim,
        time_reduction_stride=time_reduction_stride,
        conformer_input_dim=conformer_input_dim,
        conformer_ffn_dim=conformer_ffn_dim,
        conformer_num_layers=conformer_num_layers,
        conformer_num_heads=conformer_num_heads,
        conformer_depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
        conformer_dropout=conformer_dropout,
    )
    # Prediction network over target tokens.
    predictor = _Predictor(
        num_symbols=num_symbols,
        output_dim=encoding_dim,
        symbol_embedding_dim=symbol_embedding_dim,
        num_lstm_layers=num_lstm_layers,
        lstm_hidden_dim=lstm_hidden_dim,
        lstm_layer_norm=lstm_layer_norm,
        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
        lstm_dropout=lstm_dropout,
    )
    joiner = _Joiner(encoding_dim, num_symbols, activation=joiner_activation)
    return RNNT(encoder, predictor, joiner)
def conformer_rnnt_biasing(
    *,
    input_dim: int,
    encoding_dim: int,
    time_reduction_stride: int,
    conformer_input_dim: int,
    conformer_ffn_dim: int,
    conformer_num_layers: int,
    conformer_num_heads: int,
    conformer_depthwise_conv_kernel_size: int,
    conformer_dropout: float,
    num_symbols: int,
    symbol_embedding_dim: int,
    num_lstm_layers: int,
    lstm_hidden_dim: int,
    lstm_layer_norm: bool,
    lstm_layer_norm_epsilon: float,
    lstm_dropout: float,
    joiner_activation: str,
    attndim: int,
    biasing: bool,
    charlist: List[str],
    deepbiasing: bool,
    tcpsche: int,
    DBaverage: bool,
) -> RNNTBiasing:
    r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model
    with TCPGen-based biasing support.

    All arguments are keyword-only and required.

    Args:
        input_dim (int): dimension of input sequence frames passed to transcription network.
        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
            passed to joint network.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        conformer_input_dim (int): dimension of Conformer input.
        conformer_ffn_dim (int): hidden layer dimension of each Conformer layer's feedforward network.
        conformer_num_layers (int): number of Conformer layers to instantiate.
        conformer_num_heads (int): number of attention heads in each Conformer layer.
        conformer_depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
        conformer_dropout (float): Conformer dropout probability.
        num_symbols (int): cardinality of set of target tokens.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_hidden_dim (int): output dimension of each LSTM layer.
        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
        lstm_dropout (float): LSTM dropout probability.
        joiner_activation (str): activation function to use in the joiner.
            Must be one of ("relu", "tanh").
        attndim (int): TCPGen attention dimension
        biasing (bool): If true, use biasing, otherwise use standard RNN-T
        charlist (list): The list of word piece tokens in the same order as the output layer
        deepbiasing (bool): If true, use deep biasing by extracting the biasing vector
        tcpsche (int): The epoch at which TCPGen starts to train
        DBaverage (bool): If true, instead of TCPGen, use DBRNNT for biasing

    Returns:
        RNNTBiasing:
            Conformer RNN-T model with TCPGen-based biasing support.
    """
    encoder = _ConformerEncoder(
        input_dim=input_dim,
        output_dim=encoding_dim,
        time_reduction_stride=time_reduction_stride,
        conformer_input_dim=conformer_input_dim,
        conformer_ffn_dim=conformer_ffn_dim,
        conformer_num_layers=conformer_num_layers,
        conformer_num_heads=conformer_num_heads,
        conformer_depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
        conformer_dropout=conformer_dropout,
    )
    predictor = _Predictor(
        num_symbols=num_symbols,
        output_dim=encoding_dim,
        symbol_embedding_dim=symbol_embedding_dim,
        num_lstm_layers=num_lstm_layers,
        lstm_hidden_dim=lstm_hidden_dim,
        lstm_layer_norm=lstm_layer_norm,
        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
        lstm_dropout=lstm_dropout,
    )
    joiner = _JoinerBiasing(
        encoding_dim,
        num_symbols,
        activation=joiner_activation,
        deepbiasing=deepbiasing,
        attndim=attndim,
        biasing=biasing,
    )
    # NOTE(review): arguments are positional and ``encoding_dim`` is passed
    # twice; the order must match RNNTBiasing.__init__, which is not visible
    # from here -- confirm before reordering.
    return RNNTBiasing(
        encoder,
        predictor,
        joiner,
        attndim,
        biasing,
        deepbiasing,
        symbol_embedding_dim,
        encoding_dim,
        charlist,
        encoding_dim,
        conformer_dropout,
        tcpsche,
        DBaverage,
    )
+ """ + return conformer_rnnt_biasing( + input_dim=80, + encoding_dim=576, + time_reduction_stride=4, + conformer_input_dim=144, + conformer_ffn_dim=576, + conformer_num_layers=16, + conformer_num_heads=4, + conformer_depthwise_conv_kernel_size=31, + conformer_dropout=0.1, + num_symbols=601, + symbol_embedding_dim=256, + num_lstm_layers=1, + lstm_hidden_dim=320, + lstm_layer_norm=True, + lstm_layer_norm_epsilon=1e-5, + lstm_dropout=0.3, + joiner_activation="tanh", + attndim=256, + biasing=biasing, + charlist=charlist, + deepbiasing=True, + tcpsche=30, + DBaverage=False, + ) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/rnnt_decoder.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/rnnt_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7c4ac661e29da7d104f80541dc0c9919b98ea0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/models/rnnt_decoder.py @@ -0,0 +1,399 @@ +from typing import Callable, Dict, List, Optional, Tuple + +import torch +from torchaudio.models import RNNT +from torchaudio.prototype.models.rnnt import TrieNode + +__all__ = ["Hypothesis", "RNNTBeamSearchBiasing"] + + +Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float, list] +Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder, + represented as tuple of (tokens, prediction network output, prediction network state, score). 
def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
    """Select row ``idx`` along dim 0 from every tensor in a nested state list.

    The nesting of ``states`` is preserved; each returned tensor keeps a
    leading dimension of size 1 because ``index_select`` is used with a
    one-element index tensor created on ``device``.
    """
    selector = torch.tensor([idx], device=device)
    sliced: List[List[torch.Tensor]] = []
    for group in states:
        picked: List[torch.Tensor] = []
        for tensor in group:
            picked.append(tensor.index_select(0, selector))
        sliced.append(picked)
    return sliced
class RNNTBeamSearchBiasing(torch.nn.Module):
    r"""Beam search decoder for RNN-T model with biasing support.

    Args:
        model (RNNT): RNN-T model to use.
        blank (int): index of blank token in vocabulary.
        temperature (float, optional): temperature to apply to joint network output.
            Larger values yield more uniform samples. (Default: 1.0)
        hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
            for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
            hypothesis score normalized by token sequence length. (Default: None)
        step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
        trie (list, optional): the prefix tree for TCPGen biasing
        biasing (bool, optional): If true, do biasing, otherwise use standard RNN-T support
    """

    def __init__(
        self,
        model: RNNT,
        blank: int,
        temperature: float = 1.0,
        hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
        step_max_tokens: int = 100,
        trie: TrieNode = None,
        biasing: bool = False,
    ) -> None:
        super().__init__()
        self.model = model
        self.blank = blank
        self.temperature = temperature
        # Root of the biasing prefix tree; an empty list when none is given.
        self.resettrie = trie or []
        self.dobiasing = biasing

        if hypo_sort_key is None:
            self.hypo_sort_key = _default_hypo_sort_key
        else:
            self.hypo_sort_key = hypo_sort_key

        self.step_max_tokens = step_max_tokens

    def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
        """Build the single starting hypothesis for a search.

        Seeds from the last token and predictor state of ``hypo`` when given,
        otherwise from the blank token with no state. The starting score is
        0.0 and the trie position is the reset node.
        """
        if hypo is not None:
            token = _get_hypo_tokens(hypo)[-1]
            state = _get_hypo_state(hypo)
        else:
            token = self.blank
            state = None

        one_tensor = torch.tensor([1], device=device)
        pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
        init_hypo = ([token], pred_out[0].detach(), pred_state, 0.0, self.resettrie)
        return [init_hypo]

    def _get_trie_mask(self, trie):
        """Return a mask of ones with zeros at the token ids that are children of ``trie``.

        The mask has one more entry than ``self.model.char_list``.
        """
        step_mask = torch.ones(len(self.model.char_list) + 1)
        step_mask[list(trie[0].keys())] = 0
        # step_mask[-1] = 0
        return step_mask

    def _get_generation_prob(self, trie):
        """Return ``True`` when the trie node has no children (used to zero ``p_gen``)."""
        return len(trie[0].keys()) == 0

    def _gen_next_token_probs(
        self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
    ) -> torch.Tensor:
        """Compute next-token log-probabilities for every hypothesis at one time step.

        Without biasing this is a temperature-scaled log-softmax of the joiner
        output. With biasing the joiner distribution is interpolated with the
        TCPGen distribution using the generation probability ``p_gen``.
        """
        one_tensor = torch.tensor([1], device=device)
        predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
        if self.dobiasing:
            # Get valid subset of wordpieces
            trie_masks = torch.stack([self._get_trie_mask(_get_hypo_trie(h)) for h in hypos], dim=0)
            trie_masks = trie_masks.to(enc_out.device).unsqueeze(1)  # beam_width, 1, nchars
            # Determine if there is any paths on the trie
            genprob_masks = torch.tensor([self._get_generation_prob(_get_hypo_trie(h)) for h in hypos])  # beam_width
            genprob_masks = genprob_masks.to(enc_out.device)
            # Forward TCPGen component
            last_tokens = torch.tensor([_get_hypo_tokens(h)[-1] for h in hypos]).unsqueeze(-1).to(enc_out.device)
            hptr, tcpgen_dist = self.model.forward_tcpgen(last_tokens, trie_masks, enc_out)
        else:
            hptr = None
        # hptr sent to joiner, if deepbiasing is True joiner will use it
        joined_out, _, joined_activation = self.model.join(
            enc_out,
            one_tensor,
            predictor_out,
            torch.tensor([1] * len(hypos), device=device),
            hptr=hptr,
        )  # [beam_width, 1, 1, num_tokens]
        if self.dobiasing:
            p_gen = torch.sigmoid(self.model.pointer_gate(torch.cat((joined_activation, hptr), dim=-1)))
            # No trie children for a hypothesis -> generation probability forced to 0.
            p_gen = p_gen.masked_fill(genprob_masks.view(p_gen.size(0), 1, 1, 1), 0)
            model_tu = torch.softmax(joined_out / self.temperature, dim=3)
            # assuming last token is blank
            p_not_null = 1.0 - model_tu[:, :, :, -1:]
            ptr_dist_fact = torch.cat([tcpgen_dist[:, :, :, :-2], tcpgen_dist[:, :, :, -1:]], dim=-1) * p_not_null
            ptr_gen_complement = tcpgen_dist[:, :, :, -1:] * p_gen
            p_partial = ptr_dist_fact[:, :, :, :-1] * p_gen + model_tu[:, :, :, :-1] * (1 - p_gen + ptr_gen_complement)
            p_final = torch.cat([p_partial, model_tu[:, :, :, -1:]], dim=-1)
            joined_out = torch.log(p_final)
        else:
            joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
        return joined_out[:, 0, 0]

    def _gen_b_hypos(
        self,
        b_hypos: List[Hypothesis],
        a_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        key_to_b_hypo: Dict[str, Hypothesis],
    ) -> List[Hypothesis]:
        """Extend each ``a_hypos`` entry with blank and merge it into ``b_hypos``.

        Hypotheses with the same token key are merged by log-add-exp of their
        scores. ``b_hypos`` and ``key_to_b_hypo`` are mutated in place; the
        returned list is ``b_hypos`` sorted by ascending score.
        """
        for i in range(len(a_hypos)):
            h_a = a_hypos[i]
            # Last column of next_token_probs is the blank token.
            append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
            if _get_hypo_key(h_a) in key_to_b_hypo:
                h_b = key_to_b_hypo[_get_hypo_key(h_a)]
                _remove_hypo(h_b, b_hypos)
                score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
            else:
                score = float(append_blank_score)
            h_b = (
                _get_hypo_tokens(h_a),
                _get_hypo_predictor_out(h_a),
                _get_hypo_state(h_a),
                score,
                _get_hypo_trie(h_a),
            )
            b_hypos.append(h_b)
            key_to_b_hypo[_get_hypo_key(h_b)] = h_b
        _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
        return [b_hypos[idx] for idx in sorted_idx]

    def _gen_a_hypos(
        self,
        a_hypos: List[Hypothesis],
        b_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        t: int,
        beam_width: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Expand ``a_hypos`` with the best non-blank tokens.

        Only expansions scoring above the ``beam_width``-th best entry of
        ``b_hypos`` (or any score if ``b_hypos`` is shorter) are kept.
        """
        (
            nonblank_nbest_scores,
            nonblank_nbest_hypo_idx,
            nonblank_nbest_token,
        ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)

        # b_hypos is sorted ascending, so [-beam_width] is the beam_width-th best.
        if len(b_hypos) < beam_width:
            b_nbest_score = -float("inf")
        else:
            b_nbest_score = _get_hypo_score(b_hypos[-beam_width])

        base_hypos: List[Hypothesis] = []
        new_tokens: List[int] = []
        new_scores: List[float] = []
        for i in range(beam_width):
            score = float(nonblank_nbest_scores[i])
            if score > b_nbest_score:
                a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
                base_hypos.append(a_hypos[a_hypo_idx])
                new_tokens.append(int(nonblank_nbest_token[i]))
                new_scores.append(score)

        if base_hypos:
            new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
        else:
            new_hypos: List[Hypothesis] = []

        return new_hypos

    def _gen_new_hypos(
        self,
        base_hypos: List[Hypothesis],
        tokens: List[int],
        scores: List[float],
        t: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Create hypotheses by appending ``tokens[i]`` to ``base_hypos[i]``.

        The predictor is run once on the batched states; with biasing enabled
        each new hypothesis also advances its trie position.
        """
        tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
        states = _batch_state(base_hypos)
        pred_out, _, pred_states = self.model.predict(
            tgt_tokens,
            torch.tensor([1] * len(base_hypos), device=device),
            states,
        )
        new_hypos: List[Hypothesis] = []
        for i, h_a in enumerate(base_hypos):
            new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
            if self.dobiasing:
                new_trie = self.model.get_tcpgen_step(tokens[i], _get_hypo_trie(h_a), self.resettrie)
            else:
                new_trie = self.resettrie
            new_hypos.append(
                (new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i], new_trie)
            )
        return new_hypos

    def _search(
        self,
        enc_out: torch.Tensor,
        hypo: Optional[Hypothesis],
        beam_width: int,
    ) -> List[Hypothesis]:
        """Run beam search over every time step of ``enc_out``.

        For each step, hypotheses are repeatedly extended until no expansion
        survives or ``step_max_tokens`` symbols were emitted, then the best
        ``beam_width`` hypotheses by ``hypo_sort_key`` are kept.
        """
        n_time_steps = enc_out.shape[1]
        device = enc_out.device

        a_hypos: List[Hypothesis] = []
        b_hypos = self._init_b_hypos(hypo, device)
        for t in range(n_time_steps):
            a_hypos = b_hypos
            b_hypos = torch.jit.annotate(List[Hypothesis], [])
            key_to_b_hypo: Dict[str, Hypothesis] = {}
            symbols_current_t = 0

            while a_hypos:
                next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
                next_token_probs = next_token_probs.cpu()
                b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)

                if symbols_current_t == self.step_max_tokens:
                    break

                a_hypos = self._gen_a_hypos(
                    a_hypos,
                    b_hypos,
                    next_token_probs,
                    t,
                    beam_width,
                    device,
                )
                if a_hypos:
                    symbols_current_t += 1

            _, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
            b_hypos = [b_hypos[idx] for idx in sorted_idx]

        return b_hypos

    def forward(
        self,
        input: torch.Tensor,
        length: torch.Tensor,
        beam_width: int,
    ) -> List[Hypothesis]:
        r"""Performs beam search for the given input sequence.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.

        Returns:
            List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.

        Raises:
            ValueError: if ``input`` or ``length`` has an unsupported shape.
        """
        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
            raise ValueError("input must be of shape (T, D) or (1, T, D)")
        if input.dim() == 2:
            input = input.unsqueeze(0)

        if length.shape != () and length.shape != (1,):
            raise ValueError("length must be of shape () or (1,)")
        # Promote a scalar length to shape (1,), exactly as ``infer`` does.
        # (Previously this branch tested ``input`` and could never fire.)
        if length.dim() == 0:
            length = length.unsqueeze(0)

        enc_out, _ = self.model.transcribe(input, length)
        return self._search(enc_out, None, beam_width)

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        length: torch.Tensor,
        beam_width: int,
        state: Optional[List[List[torch.Tensor]]] = None,
        hypothesis: Optional[Hypothesis] = None,
    ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
        r"""Performs beam search for the given input sequence in streaming mode.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing transcription network internal state generated in preceding
                invocation. (Default: ``None``)
            hypothesis (Hypothesis or None): hypothesis from preceding invocation to seed
                search with. (Default: ``None``)

        Returns:
            (List[Hypothesis], List[List[torch.Tensor]]):
                List[Hypothesis]
                    top-``beam_width`` hypotheses found by beam search.
                List[List[torch.Tensor]]
                    list of lists of tensors representing transcription network
                    internal state generated in current invocation.

        Raises:
            ValueError: if ``input`` or ``length`` has an unsupported shape.
        """
        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
            raise ValueError("input must be of shape (T, D) or (1, T, D)")
        if input.dim() == 2:
            input = input.unsqueeze(0)

        if length.shape != () and length.shape != (1,):
            raise ValueError("length must be of shape () or (1,)")
        if length.dim() == 0:
            length = length.unsqueeze(0)

        enc_out, _, state = self.model.transcribe_streaming(input, length, state)
        return self._search(enc_out, hypothesis, beam_width), state
b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee437284cabb73112ae2cdae79e174fce69710cb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/__pycache__/hifigan_pipeline.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/__pycache__/hifigan_pipeline.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55700ea5504dfe79abddaab4ebe3790963404643 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/__pycache__/hifigan_pipeline.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/__pycache__/rnnt_pipeline.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/__pycache__/rnnt_pipeline.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e82ed4ee3e03effc59dbe4498f8d3608025d586 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/__pycache__/rnnt_pipeline.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8cd4774f56a300d099b24f3a9e905224967da522 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/__init__.py @@ -0,0 +1,3 @@ +from ._vggish_pipeline import VGGISH, VGGishBundle + +__all__ = ["VGGISH", "VGGishBundle"] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/__init__.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ab1bd073a4ea8c7a86431fe72f313b017eb4d51 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/_vggish_impl.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/_vggish_impl.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fb933d130eb0b55cf9041ebce85cd9e31274016 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/_vggish_impl.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/_vggish_pipeline.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/_vggish_pipeline.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cade98cb21ec252009af7f1c137e11fc15eb27f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/__pycache__/_vggish_pipeline.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/_vggish_impl.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/_vggish_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb6ea8f59490eab777f2ba699f128d7c7876adc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/_vggish_impl.py @@ -0,0 +1,233 @@ +# Derived from torchvggish (https://github.com/harritaylor/torchvggish). +# Copyright 2017 The TensorFlow Authors All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import math + +import torch + + +_MEL_BREAK_FREQUENCY_HERTZ = 700.0 +_MEL_HIGH_FREQUENCY_Q = 1127.0 + + +_SAMPLE_RATE = 16000 +_STFT_WINDOW_LENGTH_SECONDS = 0.025 +_STFT_HOP_LENGTH_SECONDS = 0.010 +_MEL_MIN_HZ = 125 +_MEL_MAX_HZ = 7500 +_NUM_BANDS = 64 +_LOG_OFFSET = 0.01 +_EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames +_EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. 
def _frame(data, window_length, hop_length):
    """Return a strided view of ``data`` cut into overlapping frames along dim 0.

    The result has shape ``(num_frames, window_length, *data.shape[1:])`` where
    ``num_frames = 1 + floor((data.shape[0] - window_length) / hop_length)``.
    Built with ``torch.as_strided``, so no data is copied.
    """
    source_strides = data.stride()
    frame_count = int(math.floor((data.shape[0] - window_length) / hop_length)) + 1
    out_shape = (frame_count, window_length) + data.shape[1:]
    # Consecutive frames start ``hop_length`` rows apart; inside a frame the
    # original strides apply unchanged.
    out_strides = (source_strides[0] * hop_length,) + source_strides
    return torch.as_strided(data, out_shape, out_strides)
nyquist_hertz = audio_sample_rate / 2.0 + if lower_edge_hertz < 0.0: + raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) + if lower_edge_hertz >= upper_edge_hertz: + raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % (lower_edge_hertz, upper_edge_hertz)) + + if upper_edge_hertz > nyquist_hertz: + raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % (upper_edge_hertz, nyquist_hertz)) + spectrogram_bins_hertz = torch.linspace(0.0, nyquist_hertz, num_spectrogram_bins) + + spectrogram_bins_mel = _hertz_to_mel(spectrogram_bins_hertz) + # The i'th mel band (starting from i=1) has center frequency + # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge + # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in + # the band_edges_mel arrays. + band_edges_mel = torch.linspace( + _hertz_to_mel(torch.tensor(lower_edge_hertz)), + _hertz_to_mel(torch.tensor(upper_edge_hertz)), + num_mel_bins + 2, + ) + # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins + # of spectrogram values. + mel_weights_matrix = torch.empty((num_spectrogram_bins, num_mel_bins)) + for i in range(num_mel_bins): + lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i : i + 3] + # Calculate lower and upper slopes for every spectrogram bin. + # Line segments are linear in the *mel* domain, not hertz. + lower_slope = (spectrogram_bins_mel - lower_edge_mel) / (center_mel - lower_edge_mel) + upper_slope = (upper_edge_mel - spectrogram_bins_mel) / (upper_edge_mel - center_mel) + + # .. then intersect them with each other and zero. + mel_weights_matrix[:, i] = torch.maximum(torch.tensor(0.0), torch.minimum(lower_slope, upper_slope)) + + # HTK excludes the spectrogram DC bin; make sure it always gets a zero + # coefficient. 
+ mel_weights_matrix[0, :] = 0.0 + return mel_weights_matrix + + +def _log_mel_spectrogram( + data, + audio_sample_rate=8000, + log_offset=0.0, + window_length_secs=0.025, + hop_length_secs=0.010, + **kwargs, +): + window_length_samples = int(round(audio_sample_rate * window_length_secs)) + hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) + fft_length = 2 ** int(math.ceil(math.log(window_length_samples) / math.log(2.0))) + + spectrogram = _stft_magnitude( + data, + fft_length=fft_length, + hop_length=hop_length_samples, + window_length=window_length_samples, + ) + mel_spectrogram = torch.matmul( + spectrogram, + _spectrogram_to_mel_matrix( + num_spectrogram_bins=spectrogram.shape[1], + audio_sample_rate=audio_sample_rate, + **kwargs, + ).to(spectrogram), + ) + return torch.log(mel_spectrogram + log_offset) + + +def _waveform_to_examples(data): + # Compute log mel spectrogram features, with shape (n_frame, n_mel) + log_mel = _log_mel_spectrogram( + data, + audio_sample_rate=_SAMPLE_RATE, + log_offset=_LOG_OFFSET, + window_length_secs=_STFT_WINDOW_LENGTH_SECONDS, + hop_length_secs=_STFT_HOP_LENGTH_SECONDS, + num_mel_bins=_NUM_BANDS, + lower_edge_hertz=_MEL_MIN_HZ, + upper_edge_hertz=_MEL_MAX_HZ, + ) + + # Frame features into examples, with shape (n_example, n_frame, n_mel) + features_sample_rate = 1.0 / _STFT_HOP_LENGTH_SECONDS + example_window_length = int(round(_EXAMPLE_WINDOW_SECONDS * features_sample_rate)) + + example_hop_length = int(round(_EXAMPLE_HOP_SECONDS * features_sample_rate)) + log_mel_examples = _frame(log_mel, window_length=example_window_length, hop_length=example_hop_length) + + # (n_example, 1, n_frame, n_mel) + return log_mel_examples.unsqueeze(1) + + +class VGGish(torch.nn.Module): + """Implementation of VGGish model :cite:`45611`.""" + + def __init__(self): + super().__init__() + + self.features_network = _build_features_network() + self.embedding_network = _build_embedding_network() + + def forward(self, input: 
torch.Tensor) -> torch.Tensor: + """ + Args: + input (torch.Tensor): batch of spectrograms, with shape `(n_example, 1, n_frame, 64)`. + + Returns: + torch.Tensor: model output, with shape `(n_example, 128)`. + """ + x = self.features_network(input) + + x = x.permute(0, 2, 3, 1) + x = x.reshape(x.size(0), -1) + + return self.embedding_network(x) + + +class VGGishInputProcessor: + """Converts raw waveforms to batches of examples to use as inputs to VGGish.""" + + def __call__(self, input: torch.Tensor) -> torch.Tensor: + """ + Args: + input (torch.Tensor): waveform, with shape `(T,)`. + sample_rate (int): sample rate of waveform in hertz. + + Returns: + torch.Tensor: batch of examples to pass to VGGish, with shape `(n_example, 1, n_frame, 64)`. + """ + if len(input.shape) != 1: + raise ValueError("input waveform must have dimension of 1.") + return _waveform_to_examples(input) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..7b0f527e73eb61fe4a9fa7d7d86ea467f9ae8a9e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +from typing import Callable, Dict + +import torch +import torchaudio + +from ._vggish_impl import _SAMPLE_RATE, VGGish as _VGGish, VGGishInputProcessor as _VGGishInputProcessor + + +def _get_state_dict(): + path = torchaudio.utils.download_asset("models/vggish.pt") + return torch.load(path) + + +@dataclass +class VGGishBundle: + """VGGish :cite:`45611` inference pipeline ported from + `torchvggish `__ + and `tensorflow-models `__. 
+ + Example: + >>> import torchaudio + >>> from torchaudio.prototype.pipelines import VGGISH + >>> + >>> input_sr = VGGISH.sample_rate + >>> input_proc = VGGISH.get_input_processor() + >>> model = VGGISH.get_model() + >>> + >>> waveform, sr = torchaudio.load( + >>> "Chopin_Ballade_-1_In_G_Minor,_Op._23.mp3", + >>> ) + >>> waveform = waveform.squeeze(0) + >>> waveform = torchaudio.functional.resample(waveform, sr, input_sr) + >>> mono_output = model(input_proc(waveform)) + """ + + class VGGish(_VGGish): + __doc__ = _VGGish.__doc__ + + class VGGishInputProcessor(_VGGishInputProcessor): + __doc__ = _VGGishInputProcessor.__doc__ + + _state_dict_func: Callable[[], Dict] + + @property + def sample_rate(self) -> int: + """Sample rate of input waveform expected by input processor and model. + + :type: int + """ + return _SAMPLE_RATE + + def get_model(self) -> VGGish: + """Constructs pre-trained VGGish model. Downloads and caches weights as necessary. + + Returns: + VGGish: VGGish model with pre-trained weights loaded. + """ + model = self.VGGish() + state_dict = self._state_dict_func() + model.load_state_dict(state_dict) + model.eval() + return model + + def get_input_processor(self) -> VGGishInputProcessor: + """Constructs input processor for VGGish. + + Returns: + VGGishInputProcessor: input processor for VGGish. + """ + return self.VGGishInputProcessor() + + +VGGISH = VGGishBundle(_get_state_dict) +VGGISH.__doc__ = """Pre-trained VGGish :cite:`45611` inference pipeline ported from + `torchvggish `__ + and `tensorflow-models `__. + + Per the `documentation `__ + for the original model, the model is "trained on a large YouTube dataset (a preliminary version of + what later became YouTube-8M)". 
+ """ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/hifigan_pipeline.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/hifigan_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9aa383deb010c872ef5817962396cb122281b382 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/hifigan_pipeline.py @@ -0,0 +1,228 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch.nn import Module +from torchaudio._internal import load_state_dict_from_url + +from torchaudio.prototype.models.hifi_gan import hifigan_vocoder, HiFiGANVocoder +from torchaudio.transforms import MelSpectrogram + + +@dataclass +class HiFiGANVocoderBundle: + """Data class that bundles associated information to use pretrained + :py:class:`~torchaudio.prototype.models.HiFiGANVocoder`. + + This class provides interfaces for instantiating the pretrained model along with + the information necessary to retrieve pretrained weights and additional data + to be used with the model. + + Torchaudio library instantiates objects of this class, each of which represents + a different pretrained model. Client code should access pretrained models via these + instances. + + This bundle can convert mel spectrorgam to waveforms and vice versa. A typical use case would be a flow like + `text -> mel spectrogram -> waveform`, where one can use an external component, e.g. Tacotron2, + to generate mel spectrogram from text. Please see below for the code example. + + Example: Transform synthetic mel spectrogram to audio. 
+ >>> import torch + >>> import torchaudio + >>> # Since HiFiGAN bundle is in prototypes, it needs to be exported explicitly + >>> from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH as bundle + >>> + >>> # Load the HiFiGAN bundle + >>> vocoder = bundle.get_vocoder() + Downloading: "https://download.pytorch.org/torchaudio/models/hifigan_vocoder_v3_ljspeech.pth" + 100%|████████████| 5.59M/5.59M [00:00<00:00, 18.7MB/s] + >>> + >>> # Generate synthetic mel spectrogram + >>> specgram = torch.sin(0.5 * torch.arange(start=0, end=100)).expand(bundle._vocoder_params["in_channels"], 100) + >>> + >>> # Transform mel spectrogram into audio + >>> waveform = vocoder(specgram) + >>> torchaudio.save('sample.wav', waveform, bundle.sample_rate) + + Example: Usage together with Tacotron2, text to audio. + >>> import torch + >>> import torchaudio + >>> # Since HiFiGAN bundle is in prototypes, it needs to be exported explicitly + >>> from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH as bundle_hifigan + >>> + >>> # Load Tacotron2 bundle + >>> bundle_tactron2 = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH + >>> processor = bundle_tactron2.get_text_processor() + >>> tacotron2 = bundle_tactron2.get_tacotron2() + >>> + >>> # Use Tacotron2 to convert text to mel spectrogram + >>> text = "A quick brown fox jumped over a lazy dog" + >>> input, lengths = processor(text) + >>> specgram, lengths, _ = tacotron2.infer(input, lengths) + >>> + >>> # Load HiFiGAN bundle + >>> vocoder = bundle_hifigan.get_vocoder() + Downloading: "https://download.pytorch.org/torchaudio/models/hifigan_vocoder_v3_ljspeech.pth" + 100%|████████████| 5.59M/5.59M [00:03<00:00, 1.55MB/s] + >>> + >>> # Use HiFiGAN to convert mel spectrogram to audio + >>> waveform = vocoder(specgram).squeeze(0) + >>> torchaudio.save('sample.wav', waveform, bundle_hifigan.sample_rate) + """ # noqa: E501 + + _path: str + _vocoder_params: Dict[str, Any] # Vocoder parameters + _mel_params: 
Dict[str, Any] # Mel transformation parameters + _sample_rate: float + + def _get_state_dict(self, dl_kwargs): + url = f"https://download.pytorch.org/torchaudio/models/{self._path}" + dl_kwargs = {} if dl_kwargs is None else dl_kwargs + state_dict = load_state_dict_from_url(url, **dl_kwargs) + return state_dict + + def get_vocoder(self, *, dl_kwargs=None) -> HiFiGANVocoder: + """Construct the HiFiGAN Generator model, which can be used a vocoder, and load the pretrained weight. + + The weight file is downloaded from the internet and cached with + :func:`torch.hub.load_state_dict_from_url` + + Args: + dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. + + Returns: + Variation of :py:class:`~torchaudio.prototype.models.HiFiGANVocoder`. + """ + model = hifigan_vocoder(**self._vocoder_params) + model.load_state_dict(self._get_state_dict(dl_kwargs)) + model.eval() + return model + + def get_mel_transform(self) -> Module: + """Construct an object which transforms waveforms into mel spectrograms.""" + return _HiFiGANMelSpectrogram( + n_mels=self._vocoder_params["in_channels"], + sample_rate=self._sample_rate, + **self._mel_params, + ) + + @property + def sample_rate(self): + """Sample rate of the audio that the model is trained on. + + :type: float + """ + return self._sample_rate + + +class _HiFiGANMelSpectrogram(torch.nn.Module): + """ + Generate mel spectrogram in a way equivalent to the original HiFiGAN implementation: + https://github.com/jik876/hifi-gan/blob/4769534d45265d52a904b850da5a622601885777/meldataset.py#L49-L72 + + This class wraps around :py:class:`torchaudio.transforms.MelSpectrogram`, but performs extra steps to achive + equivalence with the HiFiGAN implementation. + + Args: + hop_size (int): Length of hop between STFT windows. + n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins. + win_length (int): Window size. + f_min (float or None): Minimum frequency. + f_max (float or None): Maximum frequency. 
+ sample_rate (int): Sample rate of audio signal. + n_mels (int): Number of mel filterbanks. + """ + + def __init__( + self, + hop_size: int, + n_fft: int, + win_length: int, + f_min: Optional[float], + f_max: Optional[float], + sample_rate: float, + n_mels: int, + ): + super(_HiFiGANMelSpectrogram, self).__init__() + self.mel_transform = MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_size, + f_min=f_min, + f_max=f_max, + n_mels=n_mels, + normalized=False, + pad=0, + mel_scale="slaney", + norm="slaney", + center=False, + ) + self.sample_rate = sample_rate + self.hop_size = hop_size + self.n_fft = n_fft + self.win_length = win_length + self.f_min = f_min + self.f_max = f_max + self.n_mels = n_mels + self.pad_size = int((n_fft - hop_size) / 2) + + def forward(self, waveform: torch.Tensor) -> torch.Tensor: + """Generate mel spectrogram from a waveform. Should have same sample rate as ``self.sample_rate``. + + Args: + waveform (Tensor): waveform of shape ``(batch_size, time_length)``. 
+ Returns: + Tensor of shape ``(batch_size, n_mel, time_length)`` + """ + ref_waveform = F.pad(waveform.unsqueeze(1), (self.pad_size, self.pad_size), mode="reflect") + ref_waveform = ref_waveform.squeeze(1) + + spectr = (self.mel_transform.spectrogram(ref_waveform) + 1e-9) ** 0.5 + mel_spectrogram = self.mel_transform.mel_scale(spectr) + mel_spectrogram = torch.log(torch.clamp(mel_spectrogram, min=1e-5)) + return mel_spectrogram + + +HIFIGAN_VOCODER_V3_LJSPEECH = HiFiGANVocoderBundle( + "hifigan_vocoder_v3_ljspeech.pth", + _vocoder_params={ + "upsample_rates": (8, 8, 4), + "upsample_kernel_sizes": (16, 16, 8), + "upsample_initial_channel": 256, + "resblock_kernel_sizes": (3, 5, 7), + "resblock_dilation_sizes": ((1, 2), (2, 6), (3, 12)), + "resblock_type": 2, + "in_channels": 80, + "lrelu_slope": 0.1, + }, + _mel_params={ + "hop_size": 256, + "n_fft": 1024, + "win_length": 1024, + "f_min": 0, + "f_max": 8000, + }, + _sample_rate=22050, +) +HIFIGAN_VOCODER_V3_LJSPEECH.__doc__ = """HiFiGAN Vocoder pipeline, trained on *The LJ Speech Dataset* + :cite:`ljspeech17`. + + This pipeine can be used with an external component which generates mel spectrograms from text, for example, + Tacotron2 - see examples in :py:class:`HiFiGANVocoderBundle`. + Although this works with the existing Tacotron2 bundles, for the best results one needs to retrain Tacotron2 + using the same data preprocessing pipeline which was used for training HiFiGAN. In particular, the original + HiFiGAN implementation uses a custom method of generating mel spectrograms from waveforms, different from + :py:class:`torchaudio.transforms.MelSpectrogram`. We reimplemented this transform as + :py:meth:`HiFiGANVocoderBundle.get_mel_transform`, making sure it is equivalent to the original HiFiGAN code `here + `_. + + The underlying vocoder is constructed by + :py:func:`torchaudio.prototype.models.hifigan_vocoder`. 
The weights are converted from the ones published + with the original paper :cite:`NEURIPS2020_c5d73680` under `MIT License + `__. See links to + pre-trained models on `GitHub `__. + + Please refer to :py:class:`HiFiGANVocoderBundle` for usage instructions. + """ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/rnnt_pipeline.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/rnnt_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..c82e2f83a2a4a1dc241e5b1cf15fad0690446d72 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/pipelines/rnnt_pipeline.py @@ -0,0 +1,58 @@ +from functools import partial + +from torchaudio.models import emformer_rnnt_base +from torchaudio.pipelines import RNNTBundle + + +EMFORMER_RNNT_BASE_MUSTC = RNNTBundle( + _rnnt_path="models/emformer_rnnt_base_mustc.pt", + _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=501), + _global_stats_path="pipeline-assets/global_stats_rnnt_mustc.json", + _sp_model_path="pipeline-assets/spm_bpe_500_mustc.model", + _right_padding=4, + _blank=500, + _sample_rate=16000, + _n_fft=400, + _n_mels=80, + _hop_length=160, + _segment_length=16, + _right_context_length=4, +) +EMFORMER_RNNT_BASE_MUSTC.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeline capable of performing both +streaming and non-streaming inference. + +The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base` +and utilizes weights trained on *MuST-C release v2.0* :cite:`CATTONI2021101155` dataset +using training script ``train.py`` +`here `__ +with ``num_symbols=501``. + +Please refer to :py:class:`torchaudio.pipelines.RNNTBundle` for usage instructions. 
+""" + + +EMFORMER_RNNT_BASE_TEDLIUM3 = RNNTBundle( + _rnnt_path="models/emformer_rnnt_base_tedlium3.pt", + _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=501), + _global_stats_path="pipeline-assets/global_stats_rnnt_tedlium3.json", + _sp_model_path="pipeline-assets/spm_bpe_500_tedlium3.model", + _right_padding=4, + _blank=500, + _sample_rate=16000, + _n_fft=400, + _n_mels=80, + _hop_length=160, + _segment_length=16, + _right_context_length=4, +) +EMFORMER_RNNT_BASE_TEDLIUM3.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeline capable of performing both +streaming and non-streaming inference. + +The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base` +and utilizes weights trained on *TED-LIUM Release 3* :cite:`rousseau2012tedlium` dataset +using training script ``train.py`` +`here `__ +with ``num_symbols=501``. + +Please refer to :py:class:`torchaudio.pipelines.RNNTBundle` for usage instructions. +""" diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/transforms/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..457f20e119a0640336ff91eb92ff68dd42fd23f6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/transforms/__init__.py @@ -0,0 +1,9 @@ +from ._transforms import BarkScale, BarkSpectrogram, ChromaScale, ChromaSpectrogram, InverseBarkScale + +__all__ = [ + "BarkScale", + "BarkSpectrogram", + "ChromaScale", + "ChromaSpectrogram", + "InverseBarkScale", +] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/transforms/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/transforms/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3edcb6c0d7ed7229151fcf0bbb871c05fc45e161 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torchaudio/prototype/transforms/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/transforms/__pycache__/_transforms.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/prototype/transforms/__pycache__/_transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e74307a15884654c083e03e55812c61ba2d924d8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/prototype/transforms/__pycache__/_transforms.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/prototype/transforms/_transforms.py b/.venv/lib/python3.11/site-packages/torchaudio/prototype/transforms/_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..9d89cc5339c84d927c5a4d91a014026a9242f675 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/prototype/transforms/_transforms.py @@ -0,0 +1,456 @@ +from typing import Callable, Optional + +import torch +from torchaudio.prototype.functional import barkscale_fbanks, chroma_filterbank +from torchaudio.transforms import Spectrogram + + +class BarkScale(torch.nn.Module): + r"""Turn a normal STFT into a bark frequency STFT with triangular filter banks. + + .. devices:: CPU CUDA + + .. properties:: Autograd TorchScript + + Args: + n_barks (int, optional): Number of bark filterbanks. (Default: ``128``) + sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) + f_min (float, optional): Minimum frequency. (Default: ``0.``) + f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``) + n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``) + norm (str or None, optional): If ``"slaney"``, divide the triangular bark weights by the width of the bark band + (area normalization). 
(Default: ``None``) + bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``) + + Example + >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) + >>> spectrogram_transform = transforms.Spectrogram(n_fft=1024) + >>> spectrogram = spectrogram_transform(waveform) + >>> barkscale_transform = transforms.BarkScale(sample_rate=sample_rate, n_stft=1024 // 2 + 1) + >>> barkscale_spectrogram = barkscale_transform(spectrogram) + + See also: + :py:func:`torchaudio.prototype.functional.barkscale_fbanks` - The function used to + generate the filter banks. + """ + __constants__ = ["n_barks", "sample_rate", "f_min", "f_max"] + + def __init__( + self, + n_barks: int = 128, + sample_rate: int = 16000, + f_min: float = 0.0, + f_max: Optional[float] = None, + n_stft: int = 201, + bark_scale: str = "traunmuller", + ) -> None: + super(BarkScale, self).__init__() + self.n_barks = n_barks + self.sample_rate = sample_rate + self.f_max = f_max if f_max is not None else float(sample_rate // 2) + self.f_min = f_min + self.bark_scale = bark_scale + + if f_min > self.f_max: + raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max)) + + fb = barkscale_fbanks(n_stft, self.f_min, self.f_max, self.n_barks, self.sample_rate, self.bark_scale) + self.register_buffer("fb", fb) + + def forward(self, specgram: torch.Tensor) -> torch.Tensor: + r""" + Args: + specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time). + + Returns: + torch.Tensor: Bark frequency spectrogram of size (..., ``n_barks``, time). + """ + + # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time) + bark_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2) + + return bark_specgram + + +class InverseBarkScale(torch.nn.Module): + r"""Estimate a STFT in normal frequency domain from bark frequency domain. + + .. 
devices:: CPU CUDA + + It minimizes the euclidian norm between the input bark-spectrogram and the product between + the estimated spectrogram and the filter banks using SGD. + + Args: + n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. + n_barks (int, optional): Number of bark filterbanks. (Default: ``128``) + sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) + f_min (float, optional): Minimum frequency. (Default: ``0.``) + f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``) + max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``) + tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``) + tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``) + sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``) + bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. 
(Default: ``traunmuller``) + + Example + >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) + >>> mel_spectrogram_transform = transforms.BarkSpectrogram(sample_rate, n_fft=1024) + >>> mel_spectrogram = bark_spectrogram_transform(waveform) + >>> inverse_barkscale_transform = transforms.InverseBarkScale(n_stft=1024 // 2 + 1) + >>> spectrogram = inverse_barkscale_transform(mel_spectrogram) + """ + __constants__ = [ + "n_stft", + "n_barks", + "sample_rate", + "f_min", + "f_max", + "max_iter", + "tolerance_loss", + "tolerance_change", + "sgdargs", + ] + + def __init__( + self, + n_stft: int, + n_barks: int = 128, + sample_rate: int = 16000, + f_min: float = 0.0, + f_max: Optional[float] = None, + max_iter: int = 100000, + tolerance_loss: float = 1e-5, + tolerance_change: float = 1e-8, + sgdargs: Optional[dict] = None, + bark_scale: str = "traunmuller", + ) -> None: + super(InverseBarkScale, self).__init__() + self.n_barks = n_barks + self.sample_rate = sample_rate + self.f_max = f_max or float(sample_rate // 2) + self.f_min = f_min + self.max_iter = max_iter + self.tolerance_loss = tolerance_loss + self.tolerance_change = tolerance_change + self.sgdargs = sgdargs or {"lr": 0.1, "momentum": 0.9} + + if f_min > self.f_max: + raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max)) + + fb = barkscale_fbanks(n_stft, self.f_min, self.f_max, self.n_barks, self.sample_rate, bark_scale) + self.register_buffer("fb", fb) + + def forward(self, barkspec: torch.Tensor) -> torch.Tensor: + r""" + Args: + barkspec (torch.Tensor): A Bark frequency spectrogram of dimension (..., ``n_barks``, time) + + Returns: + torch.Tensor: Linear scale spectrogram of size (..., freq, time) + """ + # pack batch + shape = barkspec.size() + barkspec = barkspec.view(-1, shape[-2], shape[-1]) + + n_barks, time = shape[-2], shape[-1] + freq, _ = self.fb.size() # (freq, n_mels) + barkspec = barkspec.transpose(-1, -2) + if self.n_barks != n_barks: + raise 
ValueError("Expected an input with {} bark bins. Found: {}".format(self.n_barks, n_barks)) + + specgram = torch.rand( + barkspec.size()[0], time, freq, requires_grad=True, dtype=barkspec.dtype, device=barkspec.device + ) + + optim = torch.optim.SGD([specgram], **self.sgdargs) + + loss = float("inf") + for _ in range(self.max_iter): + optim.zero_grad() + diff = barkspec - specgram.matmul(self.fb) + new_loss = diff.pow(2).sum(axis=-1).mean() + # take sum over bark-frequency then average over other dimensions + # so that loss threshold is applied par unit timeframe + new_loss.backward() + optim.step() + specgram.data = specgram.data.clamp(min=0) + + new_loss = new_loss.item() + if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change: + break + loss = new_loss + + specgram.requires_grad_(False) + specgram = specgram.clamp(min=0).transpose(-1, -2) + + # unpack batch + specgram = specgram.view(shape[:-2] + (freq, time)) + return specgram + + +class BarkSpectrogram(torch.nn.Module): + r"""Create BarkSpectrogram for a raw audio signal. + + .. devices:: CPU CUDA + + .. properties:: Autograd TorchScript + + This is a composition of :py:func:`torchaudio.transforms.Spectrogram` and + and :py:func:`torchaudio.transforms.BarkScale`. + + Sources + * https://www.fon.hum.uva.nl/praat/manual/BarkSpectrogram.html + * Traunmüller, Hartmut. "Analytical Expressions for the Tonotopic Sensory Scale." Journal of the Acoustical + * Society of America. Vol. 88, Issue 1, 1990, pp. 97–100. + * https://ccrma.stanford.edu/courses/120-fall-2003/lecture-5.html + + Args: + sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) + n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``) + win_length (int or None, optional): Window size. (Default: ``n_fft``) + hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) + f_min (float, optional): Minimum frequency. 
(Default: ``0.``) + f_max (float or None, optional): Maximum frequency. (Default: ``None``) + pad (int, optional): Two sided padding of signal. (Default: ``0``) + n_mels (int, optional): Number of mel filterbanks. (Default: ``128``) + window_fn (Callable[..., torch.Tensor], optional): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + power (float, optional): Exponent for the magnitude spectrogram, + (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``) + normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``) + wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``) + center (bool, optional): whether to pad :attr:`waveform` on both sides so + that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. + (Default: ``True``) + pad_mode (string, optional): controls the padding method used when + :attr:`center` is ``True``. (Default: ``"reflect"``) + bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``) + + Example + >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) + >>> transform = transforms.BarkSpectrogram(sample_rate) + >>> bark_specgram = transform(waveform) # (channel, n_barks, time) + + See also: + :py:func:`torchaudio.functional.melscale_fbanks` - The function used to + generate the filter banks. 
+ """ + __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad", "n_barks", "f_min"] + + def __init__( + self, + sample_rate: int = 16000, + n_fft: int = 400, + win_length: Optional[int] = None, + hop_length: Optional[int] = None, + f_min: float = 0.0, + f_max: Optional[float] = None, + pad: int = 0, + n_barks: int = 128, + window_fn: Callable[..., torch.Tensor] = torch.hann_window, + power: float = 2.0, + normalized: bool = False, + wkwargs: Optional[dict] = None, + center: bool = True, + pad_mode: str = "reflect", + bark_scale: str = "traunmuller", + ) -> None: + super(BarkSpectrogram, self).__init__() + + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 2 + self.pad = pad + self.power = power + self.normalized = normalized + self.n_barks = n_barks # number of bark frequency bins + self.f_max = f_max + self.f_min = f_min + self.spectrogram = Spectrogram( + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + pad=self.pad, + window_fn=window_fn, + power=self.power, + normalized=self.normalized, + wkwargs=wkwargs, + center=center, + pad_mode=pad_mode, + onesided=True, + ) + self.bark_scale = BarkScale( + self.n_barks, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1, bark_scale + ) + + def forward(self, waveform: torch.Tensor) -> torch.Tensor: + r""" + Args: + waveform (torch.Tensor): torch.Tensor of audio of dimension (..., time). + + Returns: + torch.Tensor: Bark frequency spectrogram of size (..., ``n_barks``, time). + """ + specgram = self.spectrogram(waveform) + bark_specgram = self.bark_scale(specgram) + return bark_specgram + + +class ChromaScale(torch.nn.Module): + r"""Converts spectrogram to chromagram. + + .. devices:: CPU CUDA + + .. properties:: Autograd + + Args: + sample_rate (int): Sample rate of audio signal. 
+ n_freqs (int): Number of frequency bins in STFT. See ``n_fft`` in :class:`Spectrogram`. + n_chroma (int, optional): Number of chroma. (Default: ``12``) + tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0) + ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves. (Default: 5.0) + octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by, in octaves. + If ``None``, then disable weighting altogether. (Default: 2.0) + norm (int, optional): order of norm to normalize filter bank by. (Default: 2) + base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A. (Default: True) + + Example + >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) + >>> spectrogram_transform = transforms.Spectrogram(n_fft=1024) + >>> spectrogram = spectrogram_transform(waveform) + >>> chroma_transform = transforms.ChromaScale(sample_rate=sample_rate, n_freqs=1024 // 2 + 1) + >>> chroma_spectrogram = chroma_transform(spectrogram) + + See also: + :py:func:`torchaudio.prototype.functional.chroma_filterbank` — function used to + generate the filter bank. + """ + + def __init__( + self, + sample_rate: int, + n_freqs: int, + *, + n_chroma: int = 12, + tuning: float = 0.0, + ctroct: float = 5.0, + octwidth: Optional[float] = 2.0, + norm: int = 2, + base_c: bool = True, + ): + super().__init__() + fb = chroma_filterbank( + sample_rate, n_freqs, n_chroma, tuning=tuning, ctroct=ctroct, octwidth=octwidth, norm=norm, base_c=base_c + ) + self.register_buffer("fb", fb) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + Args: + specgram (torch.Tensor): Spectrogram of dimension (..., ``n_freqs``, time). + + Returns: + torch.Tensor: Chroma spectrogram of size (..., ``n_chroma``, time). 
class ChromaSpectrogram(torch.nn.Module):
    r"""Compute a chromagram directly from an audio waveform.

    .. devices:: CPU CUDA

    .. properties:: Autograd

    Composes :py:func:`torchaudio.transforms.Spectrogram` and
    :py:func:`torchaudio.prototype.transforms.ChromaScale`: the waveform is first
    turned into a one-sided spectrogram, which is then mapped onto chroma bins.

    Args:
        sample_rate (int): Sample rate of audio signal.
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins.
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., torch.Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent applied to the magnitude spectrogram
            (must be > 0), e.g., 2 for power. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): Whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            (Default: ``True``)
        pad_mode (string, optional): Controls the padding method used when
            :attr:`center` is ``True``. (Default: ``"reflect"``)
        n_chroma (int, optional): Number of chroma. (Default: ``12``)
        tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0)
        ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves.
            (Default: 5.0)
        octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by,
            in octaves. If ``None``, then disable weighting altogether. (Default: 2.0)
        norm (int, optional): Order of norm to normalize filter bank by. (Default: 2)
        base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A.
            (Default: True)

    Example
        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
        >>> transform = transforms.ChromaSpectrogram(sample_rate=sample_rate, n_fft=400)
        >>> chromagram = transform(waveform)  # (channel, n_chroma, time)
    """

    def __init__(
        self,
        sample_rate: int,
        n_fft: int,
        *,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        pad: int = 0,
        window_fn: Callable[..., torch.Tensor] = torch.hann_window,
        power: float = 2.0,
        normalized: bool = False,
        wkwargs: Optional[dict] = None,
        center: bool = True,
        pad_mode: str = "reflect",
        n_chroma: int = 12,
        tuning: float = 0.0,
        ctroct: float = 5.0,
        octwidth: Optional[float] = 2.0,
        norm: int = 2,
        base_c: bool = True,
    ):
        super().__init__()

        # STFT stage: always one-sided, which yields n_fft // 2 + 1 frequency bins.
        stft_options = dict(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad=pad,
            window_fn=window_fn,
            power=power,
            normalized=normalized,
            wkwargs=wkwargs,
            center=center,
            pad_mode=pad_mode,
            onesided=True,
        )
        self.spectrogram = Spectrogram(**stft_options)

        # Chroma stage: its frequency-bin count must match the one-sided STFT above.
        n_freqs = n_fft // 2 + 1
        self.chroma_scale = ChromaScale(
            sample_rate,
            n_freqs,
            n_chroma=n_chroma,
            tuning=tuning,
            base_c=base_c,
            ctroct=ctroct,
            octwidth=octwidth,
            norm=norm,
        )

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Chromagram of size (..., ``n_chroma``, time).
        """
        # Spectrogram first, then projection onto chroma bins.
        return self.chroma_scale(self.spectrogram(waveform))