Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/_no_backend.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/common.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/no_backend.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/sox_io_backend.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/backend/_sox_io_backend.py +294 -0
- .venv/lib/python3.11/site-packages/torchaudio/backend/no_backend.py +14 -0
- .venv/lib/python3.11/site-packages/torchaudio/compliance/__init__.py +5 -0
- .venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/kaldi.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/compliance/kaldi.py +813 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/__init__.py +85 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/_hdemucs.py +1008 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/conformer.py +293 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/conv_tasnet.py +330 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/decoder/__init__.py +46 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/decoder/_ctc_decoder.py +568 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/deepspeech.py +84 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/emformer.py +884 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/rnnt.py +816 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/rnnt_decoder.py +339 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/tacotron2.py +1046 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/wav2letter.py +72 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/model.py +1579 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__init__.py +7 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_fairseq.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_huggingface.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
- .venv/lib/python3.11/site-packages/torchaudio/models/wavernn.py +409 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__init__.py +4 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/musan.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/musan.py +67 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__init__.py +26 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_dsp.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_rir.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/functional.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_dsp.py +433 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_rir.py +379 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/functional.py +190 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/models/__init__.py +36 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_conformer_wav2vec2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_emformer_hubert.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/conv_emformer.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (347 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/_no_backend.cpython-311.pyc
ADDED
|
Binary file (1.58 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/common.cpython-311.pyc
ADDED
|
Binary file (846 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/no_backend.cpython-311.pyc
ADDED
|
Binary file (861 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/sox_io_backend.cpython-311.pyc
ADDED
|
Binary file (869 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/backend/_sox_io_backend.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torchaudio
|
| 6 |
+
from torchaudio import AudioMetaData
|
| 7 |
+
|
| 8 |
+
sox_ext = torchaudio._extension.lazy_import_sox_ext()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def info(
|
| 12 |
+
filepath: str,
|
| 13 |
+
format: Optional[str] = None,
|
| 14 |
+
) -> AudioMetaData:
|
| 15 |
+
"""Get signal information of an audio file.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
filepath (str):
|
| 19 |
+
Source of audio data.
|
| 20 |
+
|
| 21 |
+
format (str or None, optional):
|
| 22 |
+
Override the format detection with the given format.
|
| 23 |
+
Providing the argument might help when libsox can not infer the format
|
| 24 |
+
from header or extension.
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
AudioMetaData: Metadata of the given audio.
|
| 28 |
+
"""
|
| 29 |
+
if not torch.jit.is_scripting():
|
| 30 |
+
if hasattr(filepath, "read"):
|
| 31 |
+
raise RuntimeError("sox_io backend does not support file-like object.")
|
| 32 |
+
filepath = os.fspath(filepath)
|
| 33 |
+
sinfo = sox_ext.get_info(filepath, format)
|
| 34 |
+
return AudioMetaData(*sinfo)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def load(
|
| 38 |
+
filepath: str,
|
| 39 |
+
frame_offset: int = 0,
|
| 40 |
+
num_frames: int = -1,
|
| 41 |
+
normalize: bool = True,
|
| 42 |
+
channels_first: bool = True,
|
| 43 |
+
format: Optional[str] = None,
|
| 44 |
+
) -> Tuple[torch.Tensor, int]:
|
| 45 |
+
"""Load audio data from file.
|
| 46 |
+
|
| 47 |
+
Note:
|
| 48 |
+
This function can handle all the codecs that underlying libsox can handle,
|
| 49 |
+
however it is tested on the following formats;
|
| 50 |
+
|
| 51 |
+
* WAV, AMB
|
| 52 |
+
|
| 53 |
+
* 32-bit floating-point
|
| 54 |
+
* 32-bit signed integer
|
| 55 |
+
* 24-bit signed integer
|
| 56 |
+
* 16-bit signed integer
|
| 57 |
+
* 8-bit unsigned integer (WAV only)
|
| 58 |
+
|
| 59 |
+
* MP3
|
| 60 |
+
* FLAC
|
| 61 |
+
* OGG/VORBIS
|
| 62 |
+
* OPUS
|
| 63 |
+
* SPHERE
|
| 64 |
+
* AMR-NB
|
| 65 |
+
|
| 66 |
+
To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not
|
| 67 |
+
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
|
| 68 |
+
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
|
| 69 |
+
|
| 70 |
+
By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
|
| 71 |
+
``float32`` dtype, and the shape of `[channel, time]`.
|
| 72 |
+
|
| 73 |
+
.. warning::
|
| 74 |
+
|
| 75 |
+
``normalize`` argument does not perform volume normalization.
|
| 76 |
+
It only converts the sample type to `torch.float32` from the native sample
|
| 77 |
+
type.
|
| 78 |
+
|
| 79 |
+
When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
|
| 80 |
+
signed integer, 24-bit signed integer, and 8-bit unsigned integer, by providing ``normalize=False``,
|
| 81 |
+
this function can return integer Tensor, where the samples are expressed within the whole range
|
| 82 |
+
of the corresponding dtype, that is, ``int32`` tensor for 32-bit signed PCM,
|
| 83 |
+
``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. Since torch does not
|
| 84 |
+
support ``int24`` dtype, 24-bit signed PCM are converted to ``int32`` tensors.
|
| 85 |
+
|
| 86 |
+
``normalize`` argument has no effect on 32-bit floating-point WAV and other formats, such as
|
| 87 |
+
``flac`` and ``mp3``.
|
| 88 |
+
|
| 89 |
+
For these formats, this function always returns ``float32`` Tensor with values.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
filepath (path-like object): Source of audio data.
|
| 93 |
+
frame_offset (int):
|
| 94 |
+
Number of frames to skip before start reading data.
|
| 95 |
+
num_frames (int, optional):
|
| 96 |
+
Maximum number of frames to read. ``-1`` reads all the remaining samples,
|
| 97 |
+
starting from ``frame_offset``.
|
| 98 |
+
This function may return the less number of frames if there is not enough
|
| 99 |
+
frames in the given file.
|
| 100 |
+
normalize (bool, optional):
|
| 101 |
+
When ``True``, this function converts the native sample type to ``float32``.
|
| 102 |
+
Default: ``True``.
|
| 103 |
+
|
| 104 |
+
If input file is integer WAV, giving ``False`` will change the resulting Tensor type to
|
| 105 |
+
integer type.
|
| 106 |
+
This argument has no effect for formats other than integer WAV type.
|
| 107 |
+
|
| 108 |
+
channels_first (bool, optional):
|
| 109 |
+
When True, the returned Tensor has dimension `[channel, time]`.
|
| 110 |
+
Otherwise, the returned Tensor's dimension is `[time, channel]`.
|
| 111 |
+
format (str or None, optional):
|
| 112 |
+
Override the format detection with the given format.
|
| 113 |
+
Providing the argument might help when libsox can not infer the format
|
| 114 |
+
from header or extension.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
(torch.Tensor, int): Resulting Tensor and sample rate.
|
| 118 |
+
If the input file has integer wav format and ``normalize=False``, then it has
|
| 119 |
+
integer type, else ``float32`` type. If ``channels_first=True``, it has
|
| 120 |
+
`[channel, time]` else `[time, channel]`.
|
| 121 |
+
"""
|
| 122 |
+
if not torch.jit.is_scripting():
|
| 123 |
+
if hasattr(filepath, "read"):
|
| 124 |
+
raise RuntimeError("sox_io backend does not support file-like object.")
|
| 125 |
+
filepath = os.fspath(filepath)
|
| 126 |
+
return sox_ext.load_audio_file(filepath, frame_offset, num_frames, normalize, channels_first, format)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def save(
|
| 130 |
+
filepath: str,
|
| 131 |
+
src: torch.Tensor,
|
| 132 |
+
sample_rate: int,
|
| 133 |
+
channels_first: bool = True,
|
| 134 |
+
compression: Optional[float] = None,
|
| 135 |
+
format: Optional[str] = None,
|
| 136 |
+
encoding: Optional[str] = None,
|
| 137 |
+
bits_per_sample: Optional[int] = None,
|
| 138 |
+
):
|
| 139 |
+
"""Save audio data to file.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
filepath (path-like object): Path to save file.
|
| 143 |
+
src (torch.Tensor): Audio data to save. must be 2D tensor.
|
| 144 |
+
sample_rate (int): sampling rate
|
| 145 |
+
channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
|
| 146 |
+
otherwise `[time, channel]`.
|
| 147 |
+
compression (float or None, optional): Used for formats other than WAV.
|
| 148 |
+
This corresponds to ``-C`` option of ``sox`` command.
|
| 149 |
+
|
| 150 |
+
``"mp3"``
|
| 151 |
+
Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
|
| 152 |
+
VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
|
| 153 |
+
|
| 154 |
+
``"flac"``
|
| 155 |
+
Whole number from ``0`` to ``8``. ``8`` is default and highest compression.
|
| 156 |
+
|
| 157 |
+
``"ogg"``, ``"vorbis"``
|
| 158 |
+
Number from ``-1`` to ``10``; ``-1`` is the highest compression
|
| 159 |
+
and lowest quality. Default: ``3``.
|
| 160 |
+
|
| 161 |
+
See the detail at http://sox.sourceforge.net/soxformat.html.
|
| 162 |
+
format (str or None, optional): Override the audio format.
|
| 163 |
+
When ``filepath`` argument is path-like object, audio format is infered from
|
| 164 |
+
file extension. If file extension is missing or different, you can specify the
|
| 165 |
+
correct format with this argument.
|
| 166 |
+
|
| 167 |
+
When ``filepath`` argument is file-like object, this argument is required.
|
| 168 |
+
|
| 169 |
+
Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
|
| 170 |
+
``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``.
|
| 171 |
+
|
| 172 |
+
encoding (str or None, optional): Changes the encoding for the supported formats.
|
| 173 |
+
This argument is effective only for supported formats, such as ``"wav"``, ``""amb"``
|
| 174 |
+
and ``"sph"``. Valid values are;
|
| 175 |
+
|
| 176 |
+
- ``"PCM_S"`` (signed integer Linear PCM)
|
| 177 |
+
- ``"PCM_U"`` (unsigned integer Linear PCM)
|
| 178 |
+
- ``"PCM_F"`` (floating point PCM)
|
| 179 |
+
- ``"ULAW"`` (mu-law)
|
| 180 |
+
- ``"ALAW"`` (a-law)
|
| 181 |
+
|
| 182 |
+
Default values
|
| 183 |
+
If not provided, the default value is picked based on ``format`` and ``bits_per_sample``.
|
| 184 |
+
|
| 185 |
+
``"wav"``, ``"amb"``
|
| 186 |
+
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
|
| 187 |
+
| Tensor is used to determine the default value.
|
| 188 |
+
|
| 189 |
+
- ``"PCM_U"`` if dtype is ``uint8``
|
| 190 |
+
- ``"PCM_S"`` if dtype is ``int16`` or ``int32``
|
| 191 |
+
- ``"PCM_F"`` if dtype is ``float32``
|
| 192 |
+
|
| 193 |
+
- ``"PCM_U"`` if ``bits_per_sample=8``
|
| 194 |
+
- ``"PCM_S"`` otherwise
|
| 195 |
+
|
| 196 |
+
``"sph"`` format;
|
| 197 |
+
- the default value is ``"PCM_S"``
|
| 198 |
+
|
| 199 |
+
bits_per_sample (int or None, optional): Changes the bit depth for the supported formats.
|
| 200 |
+
When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the
|
| 201 |
+
bit depth. Valid values are ``8``, ``16``, ``32`` and ``64``.
|
| 202 |
+
|
| 203 |
+
Default Value;
|
| 204 |
+
If not provided, the default values are picked based on ``format`` and ``"encoding"``;
|
| 205 |
+
|
| 206 |
+
``"wav"``, ``"amb"``;
|
| 207 |
+
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
|
| 208 |
+
| Tensor is used.
|
| 209 |
+
|
| 210 |
+
- ``8`` if dtype is ``uint8``
|
| 211 |
+
- ``16`` if dtype is ``int16``
|
| 212 |
+
- ``32`` if dtype is ``int32`` or ``float32``
|
| 213 |
+
|
| 214 |
+
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
|
| 215 |
+
- ``16`` if ``encoding`` is ``"PCM_S"``
|
| 216 |
+
- ``32`` if ``encoding`` is ``"PCM_F"``
|
| 217 |
+
|
| 218 |
+
``"flac"`` format;
|
| 219 |
+
- the default value is ``24``
|
| 220 |
+
|
| 221 |
+
``"sph"`` format;
|
| 222 |
+
- ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided.
|
| 223 |
+
- ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"``
|
| 224 |
+
|
| 225 |
+
``"amb"`` format;
|
| 226 |
+
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
|
| 227 |
+
- ``16`` if ``encoding`` is ``"PCM_S"`` or not provided.
|
| 228 |
+
- ``32`` if ``encoding`` is ``"PCM_F"``
|
| 229 |
+
|
| 230 |
+
Supported formats/encodings/bit depth/compression are;
|
| 231 |
+
|
| 232 |
+
``"wav"``, ``"amb"``
|
| 233 |
+
- 32-bit floating-point PCM
|
| 234 |
+
- 32-bit signed integer PCM
|
| 235 |
+
- 24-bit signed integer PCM
|
| 236 |
+
- 16-bit signed integer PCM
|
| 237 |
+
- 8-bit unsigned integer PCM
|
| 238 |
+
- 8-bit mu-law
|
| 239 |
+
- 8-bit a-law
|
| 240 |
+
|
| 241 |
+
Note: Default encoding/bit depth is determined by the dtype of the input Tensor.
|
| 242 |
+
|
| 243 |
+
``"mp3"``
|
| 244 |
+
Fixed bit rate (such as 128kHz) and variable bit rate compression.
|
| 245 |
+
Default: VBR with high quality.
|
| 246 |
+
|
| 247 |
+
``"flac"``
|
| 248 |
+
- 8-bit
|
| 249 |
+
- 16-bit
|
| 250 |
+
- 24-bit (default)
|
| 251 |
+
|
| 252 |
+
``"ogg"``, ``"vorbis"``
|
| 253 |
+
- Different quality level. Default: approx. 112kbps
|
| 254 |
+
|
| 255 |
+
``"sph"``
|
| 256 |
+
- 8-bit signed integer PCM
|
| 257 |
+
- 16-bit signed integer PCM
|
| 258 |
+
- 24-bit signed integer PCM
|
| 259 |
+
- 32-bit signed integer PCM (default)
|
| 260 |
+
- 8-bit mu-law
|
| 261 |
+
- 8-bit a-law
|
| 262 |
+
- 16-bit a-law
|
| 263 |
+
- 24-bit a-law
|
| 264 |
+
- 32-bit a-law
|
| 265 |
+
|
| 266 |
+
``"amr-nb"``
|
| 267 |
+
Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s
|
| 268 |
+
|
| 269 |
+
``"gsm"``
|
| 270 |
+
Lossy Speech Compression, CPU intensive.
|
| 271 |
+
|
| 272 |
+
``"htk"``
|
| 273 |
+
Uses a default single-channel 16-bit PCM format.
|
| 274 |
+
|
| 275 |
+
Note:
|
| 276 |
+
To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``,
|
| 277 |
+
``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has
|
| 278 |
+
to be linked to ``libsox`` and corresponding codec libraries such as ``libmad``
|
| 279 |
+
or ``libmp3lame`` etc.
|
| 280 |
+
"""
|
| 281 |
+
if not torch.jit.is_scripting():
|
| 282 |
+
if hasattr(filepath, "write"):
|
| 283 |
+
raise RuntimeError("sox_io backend does not handle file-like object.")
|
| 284 |
+
filepath = os.fspath(filepath)
|
| 285 |
+
sox_ext.save_audio_file(
|
| 286 |
+
filepath,
|
| 287 |
+
src,
|
| 288 |
+
sample_rate,
|
| 289 |
+
channels_first,
|
| 290 |
+
compression,
|
| 291 |
+
format,
|
| 292 |
+
encoding,
|
| 293 |
+
bits_per_sample,
|
| 294 |
+
)
|
.venv/lib/python3.11/site-packages/torchaudio/backend/no_backend.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def __getattr__(name: str):
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
warnings.warn(
|
| 5 |
+
"Torchaudio's I/O functions now support par-call bakcend dispatch. "
|
| 6 |
+
"Importing backend implementation directly is no longer guaranteed to work. "
|
| 7 |
+
"Please use `backend` keyword with load/save/info function, instead of "
|
| 8 |
+
"calling the udnerlying implementation directly.",
|
| 9 |
+
stacklevel=2,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
from . import _no_backend
|
| 13 |
+
|
| 14 |
+
return getattr(_no_backend, name)
|
.venv/lib/python3.11/site-packages/torchaudio/compliance/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import kaldi
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"kaldi",
|
| 5 |
+
]
|
.venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (268 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/kaldi.cpython-311.pyc
ADDED
|
Binary file (37.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/compliance/kaldi.py
ADDED
|
@@ -0,0 +1,813 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torchaudio
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"get_mel_banks",
|
| 10 |
+
"inverse_mel_scale",
|
| 11 |
+
"inverse_mel_scale_scalar",
|
| 12 |
+
"mel_scale",
|
| 13 |
+
"mel_scale_scalar",
|
| 14 |
+
"spectrogram",
|
| 15 |
+
"fbank",
|
| 16 |
+
"mfcc",
|
| 17 |
+
"vtln_warp_freq",
|
| 18 |
+
"vtln_warp_mel_freq",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
# numeric_limits<float>::epsilon() 1.1920928955078125e-07
|
| 22 |
+
EPSILON = torch.tensor(torch.finfo(torch.float).eps)
|
| 23 |
+
# 1 milliseconds = 0.001 seconds
|
| 24 |
+
MILLISECONDS_TO_SECONDS = 0.001
|
| 25 |
+
|
| 26 |
+
# window types
|
| 27 |
+
HAMMING = "hamming"
|
| 28 |
+
HANNING = "hanning"
|
| 29 |
+
POVEY = "povey"
|
| 30 |
+
RECTANGULAR = "rectangular"
|
| 31 |
+
BLACKMAN = "blackman"
|
| 32 |
+
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _get_epsilon(device, dtype):
|
| 36 |
+
return EPSILON.to(device=device, dtype=dtype)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _next_power_of_2(x: int) -> int:
|
| 40 |
+
r"""Returns the smallest power of 2 that is greater than x"""
|
| 41 |
+
return 1 if x == 0 else 2 ** (x - 1).bit_length()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
|
| 45 |
+
r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
|
| 46 |
+
representing how the window is shifted along the waveform. Each row is a frame.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
waveform (Tensor): Tensor of size ``num_samples``
|
| 50 |
+
window_size (int): Frame length
|
| 51 |
+
window_shift (int): Frame shift
|
| 52 |
+
snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
|
| 53 |
+
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
| 54 |
+
depends only on the frame_shift, and we reflect the data at the ends.
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
|
| 58 |
+
"""
|
| 59 |
+
assert waveform.dim() == 1
|
| 60 |
+
num_samples = waveform.size(0)
|
| 61 |
+
strides = (window_shift * waveform.stride(0), waveform.stride(0))
|
| 62 |
+
|
| 63 |
+
if snip_edges:
|
| 64 |
+
if num_samples < window_size:
|
| 65 |
+
return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
|
| 66 |
+
else:
|
| 67 |
+
m = 1 + (num_samples - window_size) // window_shift
|
| 68 |
+
else:
|
| 69 |
+
reversed_waveform = torch.flip(waveform, [0])
|
| 70 |
+
m = (num_samples + (window_shift // 2)) // window_shift
|
| 71 |
+
pad = window_size // 2 - window_shift // 2
|
| 72 |
+
pad_right = reversed_waveform
|
| 73 |
+
if pad > 0:
|
| 74 |
+
# torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
|
| 75 |
+
# but we want [2, 1, 0, 0, 1, 2]
|
| 76 |
+
pad_left = reversed_waveform[-pad:]
|
| 77 |
+
waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
|
| 78 |
+
else:
|
| 79 |
+
# pad is negative so we want to trim the waveform at the front
|
| 80 |
+
waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
|
| 81 |
+
|
| 82 |
+
sizes = (m, window_size)
|
| 83 |
+
return waveform.as_strided(sizes, strides)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _feature_window_function(
|
| 87 |
+
window_type: str,
|
| 88 |
+
window_size: int,
|
| 89 |
+
blackman_coeff: float,
|
| 90 |
+
device: torch.device,
|
| 91 |
+
dtype: int,
|
| 92 |
+
) -> Tensor:
|
| 93 |
+
r"""Returns a window function with the given type and size"""
|
| 94 |
+
if window_type == HANNING:
|
| 95 |
+
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
|
| 96 |
+
elif window_type == HAMMING:
|
| 97 |
+
return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
|
| 98 |
+
elif window_type == POVEY:
|
| 99 |
+
# like hanning but goes to zero at edges
|
| 100 |
+
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
|
| 101 |
+
elif window_type == RECTANGULAR:
|
| 102 |
+
return torch.ones(window_size, device=device, dtype=dtype)
|
| 103 |
+
elif window_type == BLACKMAN:
|
| 104 |
+
a = 2 * math.pi / (window_size - 1)
|
| 105 |
+
window_function = torch.arange(window_size, device=device, dtype=dtype)
|
| 106 |
+
# can't use torch.blackman_window as they use different coefficients
|
| 107 |
+
return (
|
| 108 |
+
blackman_coeff
|
| 109 |
+
- 0.5 * torch.cos(a * window_function)
|
| 110 |
+
+ (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
|
| 111 |
+
).to(device=device, dtype=dtype)
|
| 112 |
+
else:
|
| 113 |
+
raise Exception("Invalid window type " + window_type)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
|
| 117 |
+
r"""Returns the log energy of size (m) for a strided_input (m,*)"""
|
| 118 |
+
device, dtype = strided_input.device, strided_input.dtype
|
| 119 |
+
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
|
| 120 |
+
if energy_floor == 0.0:
|
| 121 |
+
return log_energy
|
| 122 |
+
return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _get_waveform_and_window_properties(
    waveform: Tensor,
    channel: int,
    sample_frequency: float,
    frame_shift: float,
    frame_length: float,
    round_to_power_of_two: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, int, int, int]:
    r"""Select a single channel and derive the framing parameters in samples.

    Returns:
        (Tensor, int, int, int): single-channel waveform of size (n), window shift,
        window size and padded window size.
    """
    channel = max(channel, 0)
    assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
    waveform = waveform[channel, :]  # size (n)
    # frame_shift / frame_length are given in milliseconds; convert to sample counts.
    window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
    window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    if round_to_power_of_two:
        padded_window_size = _next_power_of_2(window_size)
    else:
        padded_window_size = window_size

    # Validate the derived parameters before any feature computation happens.
    assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
        window_size, len(waveform)
    )
    assert 0 < window_shift, "`window_shift` must be greater than 0"
    assert padded_window_size % 2 == 0, (
        "the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
    )
    assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
    assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
    return waveform, window_shift, window_size, padded_window_size
def _get_window(
    waveform: Tensor,
    padded_window_size: int,
    window_size: int,
    window_shift: int,
    window_type: str,
    blackman_coeff: float,
    snip_edges: bool,
    raw_energy: bool,
    energy_floor: float,
    dither: float,
    remove_dc_offset: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
    r"""Frame the waveform, apply pre-processing and a window function, and compute log energy.

    Returns:
        (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and
        signal_log_energy of size (m)
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    # size (m, window_size)
    frames = _get_strided(waveform, window_size, window_shift, snip_edges)

    # Add Gaussian dither (scaled by `dither`) to break up exact zeros.
    if dither != 0.0:
        noise = torch.randn(frames.shape, device=device, dtype=dtype)
        frames = frames + noise * dither

    # Remove the DC offset by subtracting each frame's mean.
    if remove_dc_offset:
        frames = frames - torch.mean(frames, dim=1).unsqueeze(1)  # mean is size (m, 1)

    if raw_energy:
        # Log energy measured before preemphasis and windowing.
        log_energy = _get_log_energy(frames, epsilon, energy_floor)  # size (m)

    if preemphasis_coefficient != 0.0:
        # frames[i, j] -= preemphasis_coefficient * frames[i, max(0, j - 1)] for all i, j.
        # Left-pad with the first sample (replicate) so j == 0 subtracts itself.
        shifted = torch.nn.functional.pad(
            frames.unsqueeze(0), (1, 0), mode="replicate"
        ).squeeze(0)  # size (m, window_size + 1)
        frames = frames - preemphasis_coefficient * shifted[:, :-1]

    # Multiply every frame by the window function.
    window = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype)
    frames = frames * window.unsqueeze(0)  # size (m, window_size)

    # Zero-pad each frame on the right up to padded_window_size.
    if padded_window_size != window_size:
        pad_amount = padded_window_size - window_size
        frames = torch.nn.functional.pad(
            frames.unsqueeze(0), (0, pad_amount), mode="constant", value=0
        ).squeeze(0)

    # Log energy measured after windowing (only when raw_energy is False).
    if not raw_energy:
        log_energy = _get_log_energy(frames, epsilon, energy_floor)  # size (m)

    return frames, log_energy
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
    r"""Optionally subtract the per-column mean from ``tensor`` of size (m, n); returns size (m, n)."""
    if not subtract_mean:
        return tensor
    return tensor - torch.mean(tensor, dim=0).unsqueeze(0)
def spectrogram(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    min_duration: float = 0.0,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
    compute-spectrogram-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A spectrogram identical to what Kaldi would output. The shape is
        (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # signal is too short; return an empty tensor on the same device/dtype as the
        # input (previously device and dtype were dropped here, unlike in `fbank`)
        return torch.empty(0, device=device, dtype=dtype)

    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # size (m, padded_window_size // 2 + 1)
    fft = torch.fft.rfft(strided_input)

    # Convert the FFT into a power spectrum
    power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log()  # size (m, padded_window_size // 2 + 1)
    # Kaldi replaces the DC bin with the per-frame log energy.
    power_spectrum[:, 0] = signal_log_energy

    power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
    return power_spectrum
def inverse_mel_scale_scalar(mel_freq: float) -> float:
    r"""Convert a scalar mel-scale value back to linear frequency in Hz."""
    ratio = math.exp(mel_freq / 1127.0) - 1.0
    return 700.0 * ratio
def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
    r"""Elementwise inverse mel scale: mel -> Hz."""
    return ((mel_freq / 1127.0).exp() - 1.0) * 700.0
def mel_scale_scalar(freq: float) -> float:
    r"""Map a scalar frequency in Hz to the mel scale: 1127 * ln(1 + f / 700)."""
    warped = math.log(1.0 + freq / 700.0)
    return 1127.0 * warped
def mel_scale(freq: Tensor) -> Tensor:
    r"""Elementwise mel scale: Hz -> mel."""
    return (1.0 + freq / 700.0).log() * 1127.0
def vtln_warp_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq: float,
    high_freq: float,
    vtln_warp_factor: float,
    freq: Tensor,
) -> Tensor:
    r"""Apply Kaldi's piecewise-linear VTLN warp F to ``freq``.

    F is defined on [low_freq, high_freq] with F(low_freq) == low_freq and
    F(high_freq) == high_freq; it is continuous and piecewise linear with two
    inflection points (measured on the unwarped axis):
        l = vtln_low_cutoff * max(1, vtln_warp_factor)
        h = vtln_high_cutoff * min(1, vtln_warp_factor)
    and F(f) = f / vtln_warp_factor for l <= f <= h. Frequencies outside
    [low_freq, high_freq] pass through unchanged. This is not the same as HTK's
    warp but has similar inputs and never produces empty bins.

    Args:
        vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
        vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
        low_freq (float): Lower frequency cutoffs in mel computation
        high_freq (float): Upper frequency cutoffs in mel computation
        vtln_warp_factor (float): Vtln warp factor
        freq (Tensor): given frequency in Hz

    Returns:
        Tensor: Freq after vtln warp
    """
    assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
    assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
    lo = vtln_low_cutoff * max(1.0, vtln_warp_factor)
    hi = vtln_high_cutoff * min(1.0, vtln_warp_factor)
    scale = 1.0 / vtln_warp_factor
    f_lo = scale * lo  # F(lo)
    f_hi = scale * hi  # F(hi)
    assert lo > low_freq and hi < high_freq
    # Slopes of the outer linear segments; the middle segment's slope is `scale`.
    scale_left = (f_lo - low_freq) / (lo - low_freq)
    scale_right = (high_freq - f_hi) / (high_freq - hi)

    # Candidate values for each of the three linear segments.
    left_segment = low_freq + scale_left * (freq - low_freq)
    middle_segment = scale * freq
    right_segment = high_freq + scale_right * (freq - high_freq)

    # Nested selects, highest precedence applied last: pass-through outside
    # [low_freq, high_freq], then the left segment, then middle vs right.
    # This reproduces the overwrite order of the masked-assignment formulation.
    res = torch.where(freq < hi, middle_segment, right_segment)
    res = torch.where(freq < lo, left_segment, res)
    res = torch.where((freq < low_freq) | (freq > high_freq), freq, res)
    return res
def vtln_warp_mel_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq,
    high_freq: float,
    vtln_warp_factor: float,
    mel_freq: Tensor,
) -> Tensor:
    r"""Apply the VTLN warp to frequencies given on the mel scale.

    Converts mel -> Hz, warps with :func:`vtln_warp_freq`, and converts back.

    Args:
        vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
        vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
        low_freq (float): Lower frequency cutoffs in mel computation
        high_freq (float): Upper frequency cutoffs in mel computation
        vtln_warp_factor (float): Vtln warp factor
        mel_freq (Tensor): Given frequency in Mel

    Returns:
        Tensor: ``mel_freq`` after vtln warp
    """
    warped_hz = vtln_warp_freq(
        vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
    )
    return mel_scale(warped_hz)
def get_mel_banks(
    num_bins: int,
    window_length_padded: int,
    sample_freq: float,
    low_freq: float,
    high_freq: float,
    vtln_low: float,
    vtln_high: float,
    vtln_warp_factor: float,
) -> Tuple[Tensor, Tensor]:
    """Create Kaldi-compatible triangular mel filterbanks.

    Returns:
        (Tensor, Tensor): The tuple consists of ``bins`` (which is
        melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
        center frequencies of bins of size (``num_bins``)).
    """
    assert num_bins > 3, "Must have at least 3 mel bins"
    assert window_length_padded % 2 == 0
    # Integer division: num_fft_bins is a count used as the length of torch.arange
    # (the previous `/ 2` produced a float; exact here because the padded length is even).
    num_fft_bins = window_length_padded // 2
    nyquist = 0.5 * sample_freq

    # high_freq <= 0 means "offset from Nyquist".
    if high_freq <= 0.0:
        high_freq += nyquist

    assert (
        (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
    ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)

    # fft-bin width [think of it as Nyquist-freq / half-window-length]
    fft_bin_width = sample_freq / window_length_padded
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)

    # divide by num_bins+1 in next line because of end-effects where the bins
    # spread out to the sides.
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    # vtln_high < 0 means "offset from Nyquist".
    if vtln_high < 0.0:
        vtln_high += nyquist

    assert vtln_warp_factor == 1.0 or (
        (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
    ), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
        vtln_low, vtln_high, low_freq, high_freq
    )

    # `bin_idx` renamed from `bin` to avoid shadowing the builtin.
    bin_idx = torch.arange(num_bins).unsqueeze(1)
    left_mel = mel_low_freq + bin_idx * mel_freq_delta  # size(num_bins, 1)
    center_mel = mel_low_freq + (bin_idx + 1.0) * mel_freq_delta  # size(num_bins, 1)
    right_mel = mel_low_freq + (bin_idx + 2.0) * mel_freq_delta  # size(num_bins, 1)

    if vtln_warp_factor != 1.0:
        left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
        center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
        right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)

    center_freqs = inverse_mel_scale(center_mel)  # size (num_bins)
    # size(1, num_fft_bins)
    mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)

    # size (num_bins, num_fft_bins)
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)

    if vtln_warp_factor == 1.0:
        # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
        bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
    else:
        # warping can move the order of left_mel, center_mel, right_mel anywhere,
        # so each slope is masked to its own region explicitly
        bins = torch.zeros_like(up_slope)
        up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel)  # left_mel < mel <= center_mel
        down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel)  # center_mel < mel < right_mel
        bins[up_idx] = up_slope[up_idx]
        bins[down_idx] = down_slope[down_idx]

    return bins, center_freqs
def fbank(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    use_log_fbank: bool = True,
    use_power: bool = True,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
    compute-fbank-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
            (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``)
        use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
        where m is calculated in _get_strided
    """
    device, dtype = waveform.device, waveform.dtype

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # signal is too short
        return torch.empty(0, device=device, dtype=dtype)

    # strided_input: size (m, padded_window_size); signal_log_energy: size (m)
    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # magnitude (or power) spectrum, size (m, padded_window_size // 2 + 1)
    spectrum = torch.fft.rfft(strided_input).abs()
    if use_power:
        spectrum = spectrum.pow(2.0)

    # triangular mel filterbank, size (num_mel_bins, padded_window_size // 2)
    mel_energies, _ = get_mel_banks(
        num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp
    )
    mel_energies = mel_energies.to(device=device, dtype=dtype)

    # pad a zero column on the right so the filterbank also covers the Nyquist bin,
    # size (num_mel_bins, padded_window_size // 2 + 1)
    mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)

    # apply the mel filterbanks to the spectrum, size (m, num_mel_bins)
    mel_energies = torch.mm(spectrum, mel_energies.T)
    if use_log_fbank:
        # avoid log of zero (which should be prevented anyway by dithering)
        mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

    # if use_energy, the log-energy column goes last for htk_compat, otherwise first
    if use_energy:
        signal_log_energy = signal_log_energy.unsqueeze(1)  # size (m, 1)
        if htk_compat:
            # returns size (m, num_mel_bins + 1)
            mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
        else:
            mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)

    return _subtract_column_mean(mel_energies, subtract_mean)
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
    r"""Return a DCT matrix of size (num_mel_bins, num_ceps) matching Kaldi's convention."""
    # size (num_mel_bins, num_mel_bins)
    dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
    # Kaldi expects the first cepstral coefficient to be a weighted sum with factor
    # sqrt(1/num_mel_bins). Since torchaudio right-multiplies (vector @ matrix), that
    # weighting lives in the first *column* here (Kaldi left-multiplies, so for it
    # this is the first row of its dct_matrix).
    dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
    return dct_matrix[:, :num_ceps]
def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
    r"""Return liftering coefficients of size (num_ceps) to scale cepstral coefficients.

    Coefficients are numbered slightly differently from HTK: the zeroth index is C0,
    which is not affected (its coefficient is exactly 1.0).
    """
    lift = torch.sin(math.pi * torch.arange(num_ceps) / cepstral_lifter)
    return 1.0 + 0.5 * cepstral_lifter * lift
def mfcc(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    cepstral_lifter: float = 22.0,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    num_ceps: int = 13,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
    compute-mfcc-feats.

    The MFCC is computed as the DCT of the log-mel filterbank (``fbank``) output,
    optionally liftered and with the signal log-energy spliced in.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
            features (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``"povey"``)

    Returns:
        Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
        where m is calculated in _get_strided
    """
    assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)

    device, dtype = waveform.device, waveform.dtype

    # The mel_energies should not be squared (use_power=True), not have mean subtracted
    # (subtract_mean=False), and use log (use_log_fbank=True).
    # size (m, num_mel_bins + use_energy)
    feature = fbank(
        waveform=waveform,
        blackman_coeff=blackman_coeff,
        channel=channel,
        dither=dither,
        energy_floor=energy_floor,
        frame_length=frame_length,
        frame_shift=frame_shift,
        high_freq=high_freq,
        htk_compat=htk_compat,
        low_freq=low_freq,
        min_duration=min_duration,
        num_mel_bins=num_mel_bins,
        preemphasis_coefficient=preemphasis_coefficient,
        raw_energy=raw_energy,
        remove_dc_offset=remove_dc_offset,
        round_to_power_of_two=round_to_power_of_two,
        sample_frequency=sample_frequency,
        snip_edges=snip_edges,
        subtract_mean=False,
        use_energy=use_energy,
        use_log_fbank=True,
        use_power=True,
        vtln_high=vtln_high,
        vtln_low=vtln_low,
        vtln_warp=vtln_warp,
        window_type=window_type,
    )

    if use_energy:
        # fbank puts the energy in the last column when htk_compat else the first.
        # size (m)
        signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
        # offset is 0 if htk_compat==True else 1: strip the energy column so only
        # the num_mel_bins log-mel columns go through the DCT.
        mel_offset = int(not htk_compat)
        feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]

    # size (num_mel_bins, num_ceps)
    dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)

    # Project log-mel energies onto the (truncated) DCT basis.
    # size (m, num_ceps)
    feature = feature.matmul(dct_matrix)

    if cepstral_lifter != 0.0:
        # size (1, num_ceps)
        lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
        feature *= lifter_coeffs.to(device=device, dtype=dtype)

    # if use_energy then replace the last column for htk_compat == true else first column
    if use_energy:
        feature[:, 0] = signal_log_energy

    if htk_compat:
        # HTK convention keeps the energy/C0 column last rather than first.
        energy = feature[:, 0].unsqueeze(1)  # size (m, 1)
        feature = feature[:, 1:]  # size (m, num_ceps - 1)
        if not use_energy:
            # scale on C0 (actually removing a scale we previously added that's
            # part of one common definition of the cosine transform.)
            energy *= math.sqrt(2)

        feature = torch.cat((feature, energy), dim=1)

    feature = _subtract_column_mean(feature, subtract_mean)
    return feature
|
.venv/lib/python3.11/site-packages/torchaudio/models/__init__.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._hdemucs import HDemucs, hdemucs_high, hdemucs_low, hdemucs_medium
|
| 2 |
+
from .conformer import Conformer
|
| 3 |
+
from .conv_tasnet import conv_tasnet_base, ConvTasNet
|
| 4 |
+
from .deepspeech import DeepSpeech
|
| 5 |
+
from .emformer import Emformer
|
| 6 |
+
from .rnnt import emformer_rnnt_base, emformer_rnnt_model, RNNT
|
| 7 |
+
from .rnnt_decoder import Hypothesis, RNNTBeamSearch
|
| 8 |
+
from .squim import (
|
| 9 |
+
squim_objective_base,
|
| 10 |
+
squim_objective_model,
|
| 11 |
+
squim_subjective_base,
|
| 12 |
+
squim_subjective_model,
|
| 13 |
+
SquimObjective,
|
| 14 |
+
SquimSubjective,
|
| 15 |
+
)
|
| 16 |
+
from .tacotron2 import Tacotron2
|
| 17 |
+
from .wav2letter import Wav2Letter
|
| 18 |
+
from .wav2vec2 import (
|
| 19 |
+
hubert_base,
|
| 20 |
+
hubert_large,
|
| 21 |
+
hubert_pretrain_base,
|
| 22 |
+
hubert_pretrain_large,
|
| 23 |
+
hubert_pretrain_model,
|
| 24 |
+
hubert_pretrain_xlarge,
|
| 25 |
+
hubert_xlarge,
|
| 26 |
+
HuBERTPretrainModel,
|
| 27 |
+
wav2vec2_base,
|
| 28 |
+
wav2vec2_large,
|
| 29 |
+
wav2vec2_large_lv60k,
|
| 30 |
+
wav2vec2_model,
|
| 31 |
+
wav2vec2_xlsr_1b,
|
| 32 |
+
wav2vec2_xlsr_2b,
|
| 33 |
+
wav2vec2_xlsr_300m,
|
| 34 |
+
Wav2Vec2Model,
|
| 35 |
+
wavlm_base,
|
| 36 |
+
wavlm_large,
|
| 37 |
+
wavlm_model,
|
| 38 |
+
)
|
| 39 |
+
from .wavernn import WaveRNN
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
__all__ = [
|
| 43 |
+
"Wav2Letter",
|
| 44 |
+
"WaveRNN",
|
| 45 |
+
"ConvTasNet",
|
| 46 |
+
"conv_tasnet_base",
|
| 47 |
+
"DeepSpeech",
|
| 48 |
+
"Wav2Vec2Model",
|
| 49 |
+
"HuBERTPretrainModel",
|
| 50 |
+
"wavlm_model",
|
| 51 |
+
"wavlm_base",
|
| 52 |
+
"wavlm_large",
|
| 53 |
+
"wav2vec2_model",
|
| 54 |
+
"wav2vec2_base",
|
| 55 |
+
"wav2vec2_large",
|
| 56 |
+
"wav2vec2_large_lv60k",
|
| 57 |
+
"hubert_base",
|
| 58 |
+
"hubert_large",
|
| 59 |
+
"hubert_xlarge",
|
| 60 |
+
"hubert_pretrain_model",
|
| 61 |
+
"hubert_pretrain_base",
|
| 62 |
+
"hubert_pretrain_large",
|
| 63 |
+
"hubert_pretrain_xlarge",
|
| 64 |
+
"wav2vec2_xlsr_300m",
|
| 65 |
+
"wav2vec2_xlsr_1b",
|
| 66 |
+
"wav2vec2_xlsr_2b",
|
| 67 |
+
"Tacotron2",
|
| 68 |
+
"Conformer",
|
| 69 |
+
"Emformer",
|
| 70 |
+
"Hypothesis",
|
| 71 |
+
"RNNT",
|
| 72 |
+
"RNNTBeamSearch",
|
| 73 |
+
"emformer_rnnt_base",
|
| 74 |
+
"emformer_rnnt_model",
|
| 75 |
+
"HDemucs",
|
| 76 |
+
"hdemucs_low",
|
| 77 |
+
"hdemucs_medium",
|
| 78 |
+
"hdemucs_high",
|
| 79 |
+
"squim_objective_base",
|
| 80 |
+
"squim_objective_model",
|
| 81 |
+
"squim_subjective_base",
|
| 82 |
+
"squim_subjective_model",
|
| 83 |
+
"SquimObjective",
|
| 84 |
+
"SquimSubjective",
|
| 85 |
+
]
|
.venv/lib/python3.11/site-packages/torchaudio/models/_hdemucs.py
ADDED
|
@@ -0,0 +1,1008 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# *****************************************************************************
|
| 2 |
+
# MIT License
|
| 3 |
+
#
|
| 4 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 5 |
+
#
|
| 6 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 7 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 8 |
+
# in the Software without restriction, including without limitation the rights
|
| 9 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 10 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 11 |
+
# furnished to do so, subject to the following conditions:
|
| 12 |
+
#
|
| 13 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 14 |
+
# copies or substantial portions of the Software.
|
| 15 |
+
#
|
| 16 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 17 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 18 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 19 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 20 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 21 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 22 |
+
# SOFTWARE.
|
| 23 |
+
# *****************************************************************************
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
import math
|
| 27 |
+
import typing as tp
|
| 28 |
+
from typing import Any, Dict, List, Optional
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
from torch import nn
|
| 32 |
+
from torch.nn import functional as F
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class _ScaledEmbedding(torch.nn.Module):
|
| 36 |
+
r"""Make continuous embeddings and boost learning rate
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
num_embeddings (int): number of embeddings
|
| 40 |
+
embedding_dim (int): embedding dimensions
|
| 41 |
+
scale (float, optional): amount to scale learning rate (Default: 10.0)
|
| 42 |
+
smooth (bool, optional): choose to apply smoothing (Default: ``False``)
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth: bool = False):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
|
| 48 |
+
if smooth:
|
| 49 |
+
weight = torch.cumsum(self.embedding.weight.data, dim=0)
|
| 50 |
+
# when summing gaussian, scale raises as sqrt(n), so we normalize by that.
|
| 51 |
+
weight = weight / torch.arange(1, num_embeddings + 1).sqrt()[:, None]
|
| 52 |
+
self.embedding.weight.data[:] = weight
|
| 53 |
+
self.embedding.weight.data /= scale
|
| 54 |
+
self.scale = scale
|
| 55 |
+
|
| 56 |
+
@property
|
| 57 |
+
def weight(self) -> torch.Tensor:
|
| 58 |
+
return self.embedding.weight * self.scale
|
| 59 |
+
|
| 60 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 61 |
+
r"""Forward pass for embedding with scale.
|
| 62 |
+
Args:
|
| 63 |
+
x (torch.Tensor): input tensor of shape `(num_embeddings)`
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
(Tensor):
|
| 67 |
+
Embedding output of shape `(num_embeddings, embedding_dim)`
|
| 68 |
+
"""
|
| 69 |
+
out = self.embedding(x) * self.scale
|
| 70 |
+
return out
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class _HEncLayer(torch.nn.Module):

    r"""Encoder layer. This is used both by the time and the frequency branch.

    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        kernel_size (int, optional): Kernel size for encoder (Default: 8)
        stride (int, optional): Stride for encoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        empty (bool, optional): used to make a layer with just the first conv. this is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): boolean for whether conv layer is for frequency domain (Default: ``True``)
        norm_type (string, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 0)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)
    """

    def __init__(
        self,
        chin: int,
        chout: int,
        kernel_size: int = 8,
        stride: int = 4,
        norm_groups: int = 4,
        empty: bool = False,
        freq: bool = True,
        norm_type: str = "group_norm",
        context: int = 0,
        dconv_kw: Optional[Dict[str, Any]] = None,
        pad: bool = True,
    ):
        super().__init__()
        if dconv_kw is None:
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        # With the default kernel_size == 2 * stride, kernel_size // 4 pads half a
        # stride on each side, making the output length input / stride (see class doc).
        pad_val = kernel_size // 4 if pad else 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.pad = pad_val
        if freq:
            # Frequency branch: convolve along the frequency axis of (B, C, F, T),
            # leaving the time axis untouched (kernel/stride of 1 on that axis).
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad_val = [pad_val, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad_val)
        self.norm1 = norm_fn(chout)

        if self.empty:
            # "Empty" layer keeps only the first conv; used just before the time
            # and frequency branches are merged.
            self.rewrite = nn.Identity()
            self.norm2 = nn.Identity()
            self.dconv = nn.Identity()
        else:
            # Produces 2 * chout channels so the GLU in forward() halves it back.
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)
            self.dconv = _DConv(chout, **dconv_kw)

    def forward(self, x: torch.Tensor, inject: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Forward pass for encoding layer.

        Size depends on whether frequency or time

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            inject (torch.Tensor, optional): on last layer, combine frequency and time branches through inject param,
                same shape as x (default: ``None``)

        Returns:
            Tensor
                output tensor after encoder layer of shape `(B, C, F / stride, T)` for frequency
                and shape `(B, C, ceil(T / stride))` for time
        """

        if not self.freq and x.dim() == 4:
            # Time branch received a 4D (B, C, F, T) tensor: fold frequency into channels.
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            # Right-pad so the length divides evenly by the stride.
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            # Merge the other branch's signal into this one (used on the last layer,
            # where the time and frequency branches are combined).
            if inject.shape[-1] != y.shape[-1]:
                raise ValueError("Injection shapes do not align")
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.freq:
            # Run the 1D DConv over time independently for every frequency bin:
            # fold F into the batch dimension, apply, then restore the layout.
            B, C, Fr, T = y.shape
            y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = self.dconv(y)
        z = self.norm2(self.rewrite(y))
        # GLU halves the 2 * chout channels from `rewrite` back down to chout.
        z = F.glu(z, dim=1)
        return z
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class _HDecLayer(torch.nn.Module):
|
| 184 |
+
r"""Decoder layer. This used both by the time and the frequency branches.
|
| 185 |
+
Args:
|
| 186 |
+
chin (int): number of input channels.
|
| 187 |
+
chout (int): number of output channels.
|
| 188 |
+
last (bool, optional): whether current layer is final layer (Default: ``False``)
|
| 189 |
+
kernel_size (int, optional): Kernel size for encoder (Default: 8)
|
| 190 |
+
stride (int): Stride for encoder layer (Default: 4)
|
| 191 |
+
norm_groups (int, optional): number of groups for group norm. (Default: 1)
|
| 192 |
+
empty (bool, optional): used to make a layer with just the first conv. this is used
|
| 193 |
+
before merging the time and freq. branches. (Default: ``False``)
|
| 194 |
+
freq (bool, optional): boolean for whether conv layer is for frequency (Default: ``True``)
|
| 195 |
+
norm_type (str, optional): Norm type, either ``group_norm `` or ``none`` (Default: ``group_norm``)
|
| 196 |
+
context (int, optional): context size for the 1x1 conv. (Default: 1)
|
| 197 |
+
dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
|
| 198 |
+
pad (bool, optional): true to pad the input. Padding is done so that the output size is
|
| 199 |
+
always the input size / stride. (Default: ``True``)
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
def __init__(
|
| 203 |
+
self,
|
| 204 |
+
chin: int,
|
| 205 |
+
chout: int,
|
| 206 |
+
last: bool = False,
|
| 207 |
+
kernel_size: int = 8,
|
| 208 |
+
stride: int = 4,
|
| 209 |
+
norm_groups: int = 1,
|
| 210 |
+
empty: bool = False,
|
| 211 |
+
freq: bool = True,
|
| 212 |
+
norm_type: str = "group_norm",
|
| 213 |
+
context: int = 1,
|
| 214 |
+
dconv_kw: Optional[Dict[str, Any]] = None,
|
| 215 |
+
pad: bool = True,
|
| 216 |
+
):
|
| 217 |
+
super().__init__()
|
| 218 |
+
if dconv_kw is None:
|
| 219 |
+
dconv_kw = {}
|
| 220 |
+
norm_fn = lambda d: nn.Identity() # noqa
|
| 221 |
+
if norm_type == "group_norm":
|
| 222 |
+
norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
|
| 223 |
+
if pad:
|
| 224 |
+
if (kernel_size - stride) % 2 != 0:
|
| 225 |
+
raise ValueError("Kernel size and stride do not align")
|
| 226 |
+
pad = (kernel_size - stride) // 2
|
| 227 |
+
else:
|
| 228 |
+
pad = 0
|
| 229 |
+
self.pad = pad
|
| 230 |
+
self.last = last
|
| 231 |
+
self.freq = freq
|
| 232 |
+
self.chin = chin
|
| 233 |
+
self.empty = empty
|
| 234 |
+
self.stride = stride
|
| 235 |
+
self.kernel_size = kernel_size
|
| 236 |
+
klass = nn.Conv1d
|
| 237 |
+
klass_tr = nn.ConvTranspose1d
|
| 238 |
+
if freq:
|
| 239 |
+
kernel_size = [kernel_size, 1]
|
| 240 |
+
stride = [stride, 1]
|
| 241 |
+
klass = nn.Conv2d
|
| 242 |
+
klass_tr = nn.ConvTranspose2d
|
| 243 |
+
self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
|
| 244 |
+
self.norm2 = norm_fn(chout)
|
| 245 |
+
if self.empty:
|
| 246 |
+
self.rewrite = nn.Identity()
|
| 247 |
+
self.norm1 = nn.Identity()
|
| 248 |
+
else:
|
| 249 |
+
self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
|
| 250 |
+
self.norm1 = norm_fn(2 * chin)
|
| 251 |
+
|
| 252 |
+
def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor], length):
|
| 253 |
+
r"""Forward pass for decoding layer.
|
| 254 |
+
|
| 255 |
+
Size depends on whether frequency or time
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
|
| 259 |
+
`(B, C, T)` for time
|
| 260 |
+
skip (torch.Tensor, optional): on first layer, separate frequency and time branches using param
|
| 261 |
+
(default: ``None``)
|
| 262 |
+
length (int): Size of tensor for output
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
(Tensor, Tensor):
|
| 266 |
+
Tensor
|
| 267 |
+
output tensor after decoder layer of shape `(B, C, F * stride, T)` for frequency domain except last
|
| 268 |
+
frequency layer shape is `(B, C, kernel_size, T)`. Shape is `(B, C, stride * T)`
|
| 269 |
+
for time domain.
|
| 270 |
+
Tensor
|
| 271 |
+
contains the output just before final transposed convolution, which is used when the
|
| 272 |
+
freq. and time branch separate. Otherwise, does not matter. Shape is
|
| 273 |
+
`(B, C, F, T)` for frequency and `(B, C, T)` for time.
|
| 274 |
+
"""
|
| 275 |
+
if self.freq and x.dim() == 3:
|
| 276 |
+
B, C, T = x.shape
|
| 277 |
+
x = x.view(B, self.chin, -1, T)
|
| 278 |
+
|
| 279 |
+
if not self.empty:
|
| 280 |
+
x = x + skip
|
| 281 |
+
y = F.glu(self.norm1(self.rewrite(x)), dim=1)
|
| 282 |
+
else:
|
| 283 |
+
y = x
|
| 284 |
+
if skip is not None:
|
| 285 |
+
raise ValueError("Skip must be none when empty is true.")
|
| 286 |
+
|
| 287 |
+
z = self.norm2(self.conv_tr(y))
|
| 288 |
+
if self.freq:
|
| 289 |
+
if self.pad:
|
| 290 |
+
z = z[..., self.pad : -self.pad, :]
|
| 291 |
+
else:
|
| 292 |
+
z = z[..., self.pad : self.pad + length]
|
| 293 |
+
if z.shape[-1] != length:
|
| 294 |
+
raise ValueError("Last index of z must be equal to length")
|
| 295 |
+
if not self.last:
|
| 296 |
+
z = F.gelu(z)
|
| 297 |
+
|
| 298 |
+
return z, y
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class HDemucs(torch.nn.Module):
    r"""Hybrid Demucs model from
    *Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`.

    See Also:
        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.

    Args:
        sources (List[str]): list of source names. List can contain the following source
            options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
        audio_channels (int, optional): input/output audio channels. (Default: 2)
        channels (int, optional): initial number of hidden channels. (Default: 48)
        growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2)
        nfft (int, optional): number of fft bins. Note that changing this requires careful computation of
            various shape parameters and will not work out of the box for hybrid models. (Default: 4096)
        depth (int, optional): number of layers in encoder and decoder (Default: 6)
        freq_emb (float, optional): add frequency embedding after the first frequency layer if > 0,
            the actual value controls the weight of the embedding. (Default: 0.2)
        emb_scale (int, optional): equivalent to scaling the embedding learning rate (Default: 10)
        emb_smooth (bool, optional): initialize the embedding with a smooth one (with respect to frequencies).
            (Default: ``True``)
        kernel_size (int, optional): kernel_size for encoder and decoder layers. (Default: 8)
        time_stride (int, optional): stride for the final time layer, after the merge. (Default: 2)
        stride (int, optional): stride for encoder and decoder layers. (Default: 4)
        context (int, optional): context for 1x1 conv in the decoder. (Default: 1)
        context_enc (int, optional): context for 1x1 conv in the encoder. (Default: 0)
        norm_starts (int, optional): layer at which group norm starts being used.
            decoder layers are numbered in reverse order. (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        dconv_depth (int, optional): depth of residual DConv branch. (Default: 2)
        dconv_comp (int, optional): compression of DConv branch. (Default: 4)
        dconv_attn (int, optional): adds attention layers in DConv branch starting at this layer. (Default: 4)
        dconv_lstm (int, optional): adds a LSTM layer in DConv branch starting at this layer. (Default: 4)
        dconv_init (float, optional): initial scale for the DConv branch LayerScale. (Default: 1e-4)
    """

    def __init__(
        self,
        sources: List[str],
        audio_channels: int = 2,
        channels: int = 48,
        growth: int = 2,
        nfft: int = 4096,
        depth: int = 6,
        freq_emb: float = 0.2,
        emb_scale: int = 10,
        emb_smooth: bool = True,
        kernel_size: int = 8,
        time_stride: int = 2,
        stride: int = 4,
        context: int = 1,
        context_enc: int = 0,
        norm_starts: int = 4,
        norm_groups: int = 4,
        dconv_depth: int = 2,
        dconv_comp: int = 4,
        dconv_attn: int = 4,
        dconv_lstm: int = 4,
        dconv_init: float = 1e-4,
    ):
        super().__init__()
        self.depth = depth
        self.nfft = nfft
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.channels = channels

        self.hop_length = self.nfft // 4
        self.freq_emb = None

        # Parallel U-Nets: a spectrogram (frequency) branch and a waveform (time) branch.
        self.freq_encoder = nn.ModuleList()
        self.freq_decoder = nn.ModuleList()

        self.time_encoder = nn.ModuleList()
        self.time_decoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin * 2  # number of channels for the freq branch
        chout = channels
        chout_z = channels
        freqs = self.nfft // 2

        for index in range(self.depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm_type = "group_norm" if index >= norm_starts else "none"
            # Once the frequency axis has collapsed to 1, layers operate on time only.
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                if freqs != 1:
                    raise ValueError("When freq is false, freqs must be 1.")
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                # Last frequency layer: consume all remaining frequency bins at once.
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm_type": norm_type,
                "norm_groups": norm_groups,
                "dconv_kw": {
                    "lstm": lstm,
                    "attn": attn,
                    "depth": dconv_depth,
                    "compress": dconv_comp,
                    "init": dconv_init,
                },
            }
            # kwt: same settings but for the time branch (1D, always padded).
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)

            if last_freq:
                # Align channel counts so the two branches can be merged.
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = _HEncLayer(chin_z, chout_z, context=context_enc, **kw)
            if freq:
                if last_freq is True and nfft == 2048:
                    # Special-cased stride/kernel so the branch shapes line up for nfft=2048.
                    kwt["stride"] = 2
                    kwt["kernel_size"] = 4
                # Time encoder exists only while the branches are separate; the last one
                # is "empty" (a single conv used for the merge).
                tenc = _HEncLayer(chin, chout, context=context_enc, empty=last_freq, **kwt)
                self.time_encoder.append(tenc)

            self.freq_encoder.append(enc)
            if index == 0:
                # Decoder outputs one set of channels per separated source.
                chin = self.audio_channels * len(self.sources)
                chin_z = chin * 2
            dec = _HDecLayer(chout_z, chin_z, last=index == 0, context=context, **kw_dec)
            if freq:
                tdec = _HDecLayer(chout, chin, empty=last_freq, last=index == 0, context=context, **kwt)
                self.time_decoder.insert(0, tdec)
            self.freq_decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                # Learnt embedding of the frequency index, injected after the first layer.
                self.freq_emb = _ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        _rescale_module(self)

    def _spec(self, x):
        """Compute the complex spectrogram of ``x`` trimmed so that the number of
        frames is exactly ``ceil(num_samples / hop_length)``."""
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        # We re-pad the signal in order to keep the property
        # that the size of the output is exactly the size of the input
        # divided by the stride (here hop_length), when divisible.
        # This is achieved by padding by 1/4th of the kernel size (here nfft).
        # which is not supported by torch.stft.
        # Having all convolution operations follow this convention allow to easily
        # align the time and frequency branches later on.
        if hl != nfft // 4:
            raise ValueError("Hop length must be nfft // 4")
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = self._pad1d(x, pad, pad + le * hl - x.shape[-1], mode="reflect")

        # Drop the Nyquist bin so the frequency axis has nfft // 2 entries.
        z = _spectro(x, nfft, hl)[..., :-1, :]
        if z.shape[-1] != le + 4:
            raise ValueError("Spectrogram's last dimension must be 4 + input size divided by stride")
        # Remove the two extra frames contributed by the padding on each side.
        z = z[..., 2 : 2 + le]
        return z

    def _ispec(self, z, length=None):
        """Invert :meth:`_spec`: restore the trimmed bin/frames, run the inverse STFT
        and cut the waveform back to ``length`` samples."""
        hl = self.hop_length
        z = F.pad(z, [0, 0, 0, 1])  # restore the Nyquist frequency bin
        z = F.pad(z, [2, 2])  # restore the two frames trimmed on each side in _spec
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = _ispectro(z, hl, length=le)
        x = x[..., pad : pad + length]
        return x

    def _pad1d(self, x: torch.Tensor, padding_left: int, padding_right: int, mode: str = "zero", value: float = 0.0):
        """Wrapper around F.pad, in order for reflect padding when num_frames is shorter than max_pad.
        Add extra zero padding around in order for padding to not break."""
        length = x.shape[-1]
        if mode == "reflect":
            # Reflect padding requires the input to be longer than the pad amount.
            max_pad = max(padding_left, padding_right)
            if length <= max_pad:
                x = F.pad(x, (0, max_pad - length + 1))
        return F.pad(x, (padding_left, padding_right), mode, value)

    def _magnitude(self, z):
        """Fold the real/imag parts of a complex spectrogram `(B, C, F, T)` into the
        channel axis, returning a real tensor `(B, 2 * C, F, T)`."""
        # move the complex dimension to the channel one.
        B, C, Fr, T = z.shape
        m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
        m = m.reshape(B, C * 2, Fr, T)
        return m

    def _mask(self, m):
        """Inverse of :meth:`_magnitude` per source: fold channel pairs of the real
        tensor `(B, S, C, F, T)` back into a complex spectrogram."""
        # `m` is a full spectrogram and `z` is ignored.
        B, S, C, Fr, T = m.shape
        out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
        out = torch.view_as_complex(out.contiguous())
        return out

    def forward(self, input: torch.Tensor):

        r"""HDemucs forward call

        Args:
            input (torch.Tensor): input mixed tensor of shape `(batch_size, channel, num_frames)`

        Returns:
            Tensor
                output tensor split into sources of shape `(batch_size, num_sources, channel, num_frames)`
        """

        if input.ndim != 3:
            raise ValueError(f"Expected 3D tensor with dimensions (batch, channel, frames). Found: {input.shape}")

        if input.shape[1] != self.audio_channels:
            raise ValueError(
                f"The channel dimension of input Tensor must match `audio_channels` of HDemucs model. "
                f"Found:{input.shape[1]}."
            )

        x = input
        length = x.shape[-1]

        z = self._spec(input)
        mag = self._magnitude(z)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        # Prepare the time branch input.
        xt = input
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths: List[int] = []  # saved lengths to properly remove padding, freq branch.
        lengths_t: List[int] = []  # saved lengths for time branch.

        for idx, encode in enumerate(self.freq_encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.time_encoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.time_encoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        xt = torch.zeros_like(x)
        # initialize everything to zero (signal will go through u-net skips).

        for idx, decode in enumerate(self.freq_decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.
            offset = self.depth - len(self.time_decoder)
            if idx >= offset:
                # The time branch re-splits from the frequency branch at `offset`.
                tdec = self.time_decoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    if pre.shape[2] != 1:
                        raise ValueError(f"If tdec empty is True, pre shape does not match {pre.shape}")
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        if len(saved) != 0:
            raise AssertionError("saved is not empty")
        if len(lengths_t) != 0:
            raise AssertionError("lengths_t is not empty")
        if len(saved_t) != 0:
            raise AssertionError("saved_t is not empty")

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        # Undo the input normalization before resynthesis.
        x = x * std[:, None] + mean[:, None]

        zout = self._mask(x)
        x = self._ispec(zout, length)

        xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        # The final estimate is the sum of the time-domain and freq.-domain branches.
        x = xt + x
        return x
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
class _DConv(torch.nn.Module):
    r"""Stack of residual branches used in each encoder layer.

    Each branch alternates dilated convolutions, optionally followed by an LSTM and
    local attention. Before entering a branch, the signal is projected onto a smaller
    subspace of dimension ``channels // compress``.

    Args:
        channels (int): input/output channels for residual branch.
        compress (float, optional): amount of channel compression inside the branch. (default: 4)
        depth (int, optional): number of layers in the residual branch. Each layer has its own
            projection, and potentially LSTM and attention.(default: 2)
        init (float, optional): initial scale for LayerNorm. (default: 1e-4)
        norm_type (bool, optional): Norm type, either ``group_norm `` or ``none`` (Default: ``group_norm``)
        attn (bool, optional): use LocalAttention. (Default: ``False``)
        heads (int, optional): number of heads for the LocalAttention. (default: 4)
        ndecay (int, optional): number of decay controls in the LocalAttention. (default: 4)
        lstm (bool, optional): use LSTM. (Default: ``False``)
        kernel_size (int, optional): kernel size for the (dilated) convolutions. (default: 3)
    """

    def __init__(
        self,
        channels: int,
        compress: float = 4,
        depth: int = 2,
        init: float = 1e-4,
        norm_type: str = "group_norm",
        attn: bool = False,
        heads: int = 4,
        ndecay: int = 4,
        lstm: bool = False,
        kernel_size: int = 3,
    ):

        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError("Kernel size should not be divisible by 2")
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        # A negative depth disables dilation while keeping |depth| layers.
        use_dilation = depth > 0

        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa
        else:
            norm_fn = lambda d: nn.Identity()  # noqa

        hidden = int(channels / compress)

        self.layers = nn.ModuleList()
        for layer_idx in range(self.depth):
            dilation = 2 ** layer_idx if use_dilation else 1
            branch = [
                nn.Conv1d(
                    channels,
                    hidden,
                    kernel_size,
                    dilation=dilation,
                    padding=dilation * (kernel_size // 2),
                ),
                norm_fn(hidden),
                nn.GELU(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels),
                nn.GLU(1),
                _LayerScale(channels, init),
            ]
            # Optional blocks slot in right after the activation; inserting the LSTM
            # second places it before the attention block when both are enabled.
            if attn:
                branch.insert(3, _LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                branch.insert(3, _BLSTM(hidden, layers=2, skip=True))
            self.layers.append(nn.Sequential(*branch))

    def forward(self, x):
        r"""Apply every residual branch in sequence: ``x <- x + branch(x)``.

        Args:
            x (torch.Tensor): input tensor for convolution

        Returns:
            Tensor
                Output after being run through layers.
        """
        for branch in self.layers:
            x = x + branch(x)
        return x
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
class _BLSTM(torch.nn.Module):
|
| 725 |
+
r"""
|
| 726 |
+
BiLSTM with same hidden units as input dim.
|
| 727 |
+
If `max_steps` is not None, input will be splitting in overlapping
|
| 728 |
+
chunks and the LSTM applied separately on each chunk.
|
| 729 |
+
Args:
|
| 730 |
+
dim (int): dimensions at LSTM layer.
|
| 731 |
+
layers (int, optional): number of LSTM layers. (default: 1)
|
| 732 |
+
skip (bool, optional): (default: ``False``)
|
| 733 |
+
"""
|
| 734 |
+
|
| 735 |
+
def __init__(self, dim, layers: int = 1, skip: bool = False):
|
| 736 |
+
super().__init__()
|
| 737 |
+
self.max_steps = 200
|
| 738 |
+
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
|
| 739 |
+
self.linear = nn.Linear(2 * dim, dim)
|
| 740 |
+
self.skip = skip
|
| 741 |
+
|
| 742 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 743 |
+
r"""BLSTM forward call
|
| 744 |
+
|
| 745 |
+
Args:
|
| 746 |
+
x (torch.Tensor): input tensor for BLSTM shape is `(batch_size, dim, time_steps)`
|
| 747 |
+
|
| 748 |
+
Returns:
|
| 749 |
+
Tensor
|
| 750 |
+
Output after being run through bidirectional LSTM. Shape is `(batch_size, dim, time_steps)`
|
| 751 |
+
"""
|
| 752 |
+
B, C, T = x.shape
|
| 753 |
+
y = x
|
| 754 |
+
framed = False
|
| 755 |
+
width = 0
|
| 756 |
+
stride = 0
|
| 757 |
+
nframes = 0
|
| 758 |
+
if self.max_steps is not None and T > self.max_steps:
|
| 759 |
+
width = self.max_steps
|
| 760 |
+
stride = width // 2
|
| 761 |
+
frames = _unfold(x, width, stride)
|
| 762 |
+
nframes = frames.shape[2]
|
| 763 |
+
framed = True
|
| 764 |
+
x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
|
| 765 |
+
|
| 766 |
+
x = x.permute(2, 0, 1)
|
| 767 |
+
|
| 768 |
+
x = self.lstm(x)[0]
|
| 769 |
+
x = self.linear(x)
|
| 770 |
+
x = x.permute(1, 2, 0)
|
| 771 |
+
if framed:
|
| 772 |
+
out = []
|
| 773 |
+
frames = x.reshape(B, -1, C, width)
|
| 774 |
+
limit = stride // 2
|
| 775 |
+
for k in range(nframes):
|
| 776 |
+
if k == 0:
|
| 777 |
+
out.append(frames[:, k, :, :-limit])
|
| 778 |
+
elif k == nframes - 1:
|
| 779 |
+
out.append(frames[:, k, :, limit:])
|
| 780 |
+
else:
|
| 781 |
+
out.append(frames[:, k, :, limit:-limit])
|
| 782 |
+
out = torch.cat(out, -1)
|
| 783 |
+
out = out[..., :T]
|
| 784 |
+
x = out
|
| 785 |
+
if self.skip:
|
| 786 |
+
x = x + y
|
| 787 |
+
|
| 788 |
+
return x
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
class _LocalState(nn.Module):
|
| 792 |
+
"""Local state allows to have attention based only on data (no positional embedding),
|
| 793 |
+
but while setting a constraint on the time window (e.g. decaying penalty term).
|
| 794 |
+
Also a failed experiments with trying to provide some frequency based attention.
|
| 795 |
+
"""
|
| 796 |
+
|
| 797 |
+
def __init__(self, channels: int, heads: int = 4, ndecay: int = 4):
|
| 798 |
+
r"""
|
| 799 |
+
Args:
|
| 800 |
+
channels (int): Size of Conv1d layers.
|
| 801 |
+
heads (int, optional): (default: 4)
|
| 802 |
+
ndecay (int, optional): (default: 4)
|
| 803 |
+
"""
|
| 804 |
+
super(_LocalState, self).__init__()
|
| 805 |
+
if channels % heads != 0:
|
| 806 |
+
raise ValueError("Channels must be divisible by heads.")
|
| 807 |
+
self.heads = heads
|
| 808 |
+
self.ndecay = ndecay
|
| 809 |
+
self.content = nn.Conv1d(channels, channels, 1)
|
| 810 |
+
self.query = nn.Conv1d(channels, channels, 1)
|
| 811 |
+
self.key = nn.Conv1d(channels, channels, 1)
|
| 812 |
+
|
| 813 |
+
self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
|
| 814 |
+
if ndecay:
|
| 815 |
+
# Initialize decay close to zero (there is a sigmoid), for maximum initial window.
|
| 816 |
+
self.query_decay.weight.data *= 0.01
|
| 817 |
+
if self.query_decay.bias is None:
|
| 818 |
+
raise ValueError("bias must not be None.")
|
| 819 |
+
self.query_decay.bias.data[:] = -2
|
| 820 |
+
self.proj = nn.Conv1d(channels + heads * 0, channels, 1)
|
| 821 |
+
|
| 822 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 823 |
+
r"""LocalState forward call
|
| 824 |
+
|
| 825 |
+
Args:
|
| 826 |
+
x (torch.Tensor): input tensor for LocalState
|
| 827 |
+
|
| 828 |
+
Returns:
|
| 829 |
+
Tensor
|
| 830 |
+
Output after being run through LocalState layer.
|
| 831 |
+
"""
|
| 832 |
+
B, C, T = x.shape
|
| 833 |
+
heads = self.heads
|
| 834 |
+
indexes = torch.arange(T, device=x.device, dtype=x.dtype)
|
| 835 |
+
# left index are keys, right index are queries
|
| 836 |
+
delta = indexes[:, None] - indexes[None, :]
|
| 837 |
+
|
| 838 |
+
queries = self.query(x).view(B, heads, -1, T)
|
| 839 |
+
keys = self.key(x).view(B, heads, -1, T)
|
| 840 |
+
# t are keys, s are queries
|
| 841 |
+
dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
|
| 842 |
+
dots /= math.sqrt(keys.shape[2])
|
| 843 |
+
if self.ndecay:
|
| 844 |
+
decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
|
| 845 |
+
decay_q = self.query_decay(x).view(B, heads, -1, T)
|
| 846 |
+
decay_q = torch.sigmoid(decay_q) / 2
|
| 847 |
+
decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / math.sqrt(self.ndecay)
|
| 848 |
+
dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
|
| 849 |
+
|
| 850 |
+
# Kill self reference.
|
| 851 |
+
dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
|
| 852 |
+
weights = torch.softmax(dots, dim=2)
|
| 853 |
+
|
| 854 |
+
content = self.content(x).view(B, heads, -1, T)
|
| 855 |
+
result = torch.einsum("bhts,bhct->bhcs", weights, content)
|
| 856 |
+
result = result.reshape(B, -1, T)
|
| 857 |
+
return x + self.proj(result)
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
class _LayerScale(nn.Module):
|
| 861 |
+
"""Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
|
| 862 |
+
This rescales diagonally residual outputs close to 0 initially, then learnt.
|
| 863 |
+
"""
|
| 864 |
+
|
| 865 |
+
def __init__(self, channels: int, init: float = 0):
|
| 866 |
+
r"""
|
| 867 |
+
Args:
|
| 868 |
+
channels (int): Size of rescaling
|
| 869 |
+
init (float, optional): Scale to default to (default: 0)
|
| 870 |
+
"""
|
| 871 |
+
super().__init__()
|
| 872 |
+
self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
|
| 873 |
+
self.scale.data[:] = init
|
| 874 |
+
|
| 875 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 876 |
+
r"""LayerScale forward call
|
| 877 |
+
|
| 878 |
+
Args:
|
| 879 |
+
x (torch.Tensor): input tensor for LayerScale
|
| 880 |
+
|
| 881 |
+
Returns:
|
| 882 |
+
Tensor
|
| 883 |
+
Output after rescaling tensor.
|
| 884 |
+
"""
|
| 885 |
+
return self.scale[:, None] * x
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
|
| 889 |
+
"""Given input of size [*OT, T], output Tensor of size [*OT, F, K]
|
| 890 |
+
with K the kernel size, by extracting frames with the given stride.
|
| 891 |
+
This will pad the input so that `F = ceil(T / K)`.
|
| 892 |
+
see https://github.com/pytorch/pytorch/issues/60466
|
| 893 |
+
"""
|
| 894 |
+
shape = list(a.shape[:-1])
|
| 895 |
+
length = int(a.shape[-1])
|
| 896 |
+
n_frames = math.ceil(length / stride)
|
| 897 |
+
tgt_length = (n_frames - 1) * stride + kernel_size
|
| 898 |
+
a = F.pad(input=a, pad=[0, tgt_length - length])
|
| 899 |
+
strides = [a.stride(dim) for dim in range(a.dim())]
|
| 900 |
+
if strides[-1] != 1:
|
| 901 |
+
raise ValueError("Data should be contiguous.")
|
| 902 |
+
strides = strides[:-1] + [stride, 1]
|
| 903 |
+
shape.append(n_frames)
|
| 904 |
+
shape.append(kernel_size)
|
| 905 |
+
return a.as_strided(shape, strides)
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
def _rescale_module(module):
|
| 909 |
+
r"""
|
| 910 |
+
Rescales initial weight scale for all models within the module.
|
| 911 |
+
"""
|
| 912 |
+
for sub in module.modules():
|
| 913 |
+
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
|
| 914 |
+
std = sub.weight.std().detach()
|
| 915 |
+
scale = (std / 0.1) ** 0.5
|
| 916 |
+
sub.weight.data /= scale
|
| 917 |
+
if sub.bias is not None:
|
| 918 |
+
sub.bias.data /= scale
|
| 919 |
+
|
| 920 |
+
|
| 921 |
+
def _spectro(x: torch.Tensor, n_fft: int = 512, hop_length: int = 0, pad: int = 0) -> torch.Tensor:
|
| 922 |
+
other = list(x.shape[:-1])
|
| 923 |
+
length = int(x.shape[-1])
|
| 924 |
+
x = x.reshape(-1, length)
|
| 925 |
+
z = torch.stft(
|
| 926 |
+
x,
|
| 927 |
+
n_fft * (1 + pad),
|
| 928 |
+
hop_length,
|
| 929 |
+
window=torch.hann_window(n_fft).to(x),
|
| 930 |
+
win_length=n_fft,
|
| 931 |
+
normalized=True,
|
| 932 |
+
center=True,
|
| 933 |
+
return_complex=True,
|
| 934 |
+
pad_mode="reflect",
|
| 935 |
+
)
|
| 936 |
+
_, freqs, frame = z.shape
|
| 937 |
+
other.extend([freqs, frame])
|
| 938 |
+
return z.view(other)
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int = 0) -> torch.Tensor:
|
| 942 |
+
other = list(z.shape[:-2])
|
| 943 |
+
freqs = int(z.shape[-2])
|
| 944 |
+
frames = int(z.shape[-1])
|
| 945 |
+
|
| 946 |
+
n_fft = 2 * freqs - 2
|
| 947 |
+
z = z.view(-1, freqs, frames)
|
| 948 |
+
win_length = n_fft // (1 + pad)
|
| 949 |
+
x = torch.istft(
|
| 950 |
+
z,
|
| 951 |
+
n_fft,
|
| 952 |
+
hop_length,
|
| 953 |
+
window=torch.hann_window(win_length).to(z.real),
|
| 954 |
+
win_length=win_length,
|
| 955 |
+
normalized=True,
|
| 956 |
+
length=length,
|
| 957 |
+
center=True,
|
| 958 |
+
)
|
| 959 |
+
_, length = x.shape
|
| 960 |
+
other.append(length)
|
| 961 |
+
return x.view(other)
|
| 962 |
+
|
| 963 |
+
|
| 964 |
+
def hdemucs_low(sources: List[str]) -> HDemucs:
    """Build the low-nfft (1024) variant of :class:`HDemucs`, suitable for sample
    rates around 8 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """
    model = HDemucs(sources=sources, nfft=1024, depth=5)
    return model
|
| 976 |
+
|
| 977 |
+
|
| 978 |
+
def hdemucs_medium(sources: List[str]) -> HDemucs:
    r"""Build the medium-nfft (2048) variant of :class:`HDemucs`, suitable for
    sample rates of 16-32 kHz.

    .. note::

        Medium HDemucs has not been tested against the original Hybrid Demucs as this nfft and depth configuration is
        not compatible with the original implementation in https://github.com/facebookresearch/demucs

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """
    model = HDemucs(sources=sources, nfft=2048, depth=6)
    return model
|
| 995 |
+
|
| 996 |
+
|
| 997 |
+
def hdemucs_high(sources: List[str]) -> HDemucs:
    r"""Build the high-nfft (4096) variant of :class:`HDemucs`, suitable for
    sample rates of 44.1-48 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """
    model = HDemucs(sources=sources, nfft=4096, depth=6)
    return model
|
.venv/lib/python3.11/site-packages/torchaudio/models/conformer.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
__all__ = ["Conformer"]
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
|
| 10 |
+
batch_size = lengths.shape[0]
|
| 11 |
+
max_length = int(torch.max(lengths).item())
|
| 12 |
+
padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
|
| 13 |
+
batch_size, max_length
|
| 14 |
+
) >= lengths.unsqueeze(1)
|
| 15 |
+
return padding_mask
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class _ConvolutionModule(torch.nn.Module):
|
| 19 |
+
r"""Conformer convolution module.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
input_dim (int): input dimension.
|
| 23 |
+
num_channels (int): number of depthwise convolution layer input channels.
|
| 24 |
+
depthwise_kernel_size (int): kernel size of depthwise convolution layer.
|
| 25 |
+
dropout (float, optional): dropout probability. (Default: 0.0)
|
| 26 |
+
bias (bool, optional): indicates whether to add bias term to each convolution layer. (Default: ``False``)
|
| 27 |
+
use_group_norm (bool, optional): use GroupNorm rather than BatchNorm. (Default: ``False``)
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
input_dim: int,
|
| 33 |
+
num_channels: int,
|
| 34 |
+
depthwise_kernel_size: int,
|
| 35 |
+
dropout: float = 0.0,
|
| 36 |
+
bias: bool = False,
|
| 37 |
+
use_group_norm: bool = False,
|
| 38 |
+
) -> None:
|
| 39 |
+
super().__init__()
|
| 40 |
+
if (depthwise_kernel_size - 1) % 2 != 0:
|
| 41 |
+
raise ValueError("depthwise_kernel_size must be odd to achieve 'SAME' padding.")
|
| 42 |
+
self.layer_norm = torch.nn.LayerNorm(input_dim)
|
| 43 |
+
self.sequential = torch.nn.Sequential(
|
| 44 |
+
torch.nn.Conv1d(
|
| 45 |
+
input_dim,
|
| 46 |
+
2 * num_channels,
|
| 47 |
+
1,
|
| 48 |
+
stride=1,
|
| 49 |
+
padding=0,
|
| 50 |
+
bias=bias,
|
| 51 |
+
),
|
| 52 |
+
torch.nn.GLU(dim=1),
|
| 53 |
+
torch.nn.Conv1d(
|
| 54 |
+
num_channels,
|
| 55 |
+
num_channels,
|
| 56 |
+
depthwise_kernel_size,
|
| 57 |
+
stride=1,
|
| 58 |
+
padding=(depthwise_kernel_size - 1) // 2,
|
| 59 |
+
groups=num_channels,
|
| 60 |
+
bias=bias,
|
| 61 |
+
),
|
| 62 |
+
torch.nn.GroupNorm(num_groups=1, num_channels=num_channels)
|
| 63 |
+
if use_group_norm
|
| 64 |
+
else torch.nn.BatchNorm1d(num_channels),
|
| 65 |
+
torch.nn.SiLU(),
|
| 66 |
+
torch.nn.Conv1d(
|
| 67 |
+
num_channels,
|
| 68 |
+
input_dim,
|
| 69 |
+
kernel_size=1,
|
| 70 |
+
stride=1,
|
| 71 |
+
padding=0,
|
| 72 |
+
bias=bias,
|
| 73 |
+
),
|
| 74 |
+
torch.nn.Dropout(dropout),
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 78 |
+
r"""
|
| 79 |
+
Args:
|
| 80 |
+
input (torch.Tensor): with shape `(B, T, D)`.
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
torch.Tensor: output, with shape `(B, T, D)`.
|
| 84 |
+
"""
|
| 85 |
+
x = self.layer_norm(input)
|
| 86 |
+
x = x.transpose(1, 2)
|
| 87 |
+
x = self.sequential(x)
|
| 88 |
+
return x.transpose(1, 2)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class _FeedForwardModule(torch.nn.Module):
|
| 92 |
+
r"""Positionwise feed forward layer.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
input_dim (int): input dimension.
|
| 96 |
+
hidden_dim (int): hidden dimension.
|
| 97 |
+
dropout (float, optional): dropout probability. (Default: 0.0)
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.sequential = torch.nn.Sequential(
|
| 103 |
+
torch.nn.LayerNorm(input_dim),
|
| 104 |
+
torch.nn.Linear(input_dim, hidden_dim, bias=True),
|
| 105 |
+
torch.nn.SiLU(),
|
| 106 |
+
torch.nn.Dropout(dropout),
|
| 107 |
+
torch.nn.Linear(hidden_dim, input_dim, bias=True),
|
| 108 |
+
torch.nn.Dropout(dropout),
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 112 |
+
r"""
|
| 113 |
+
Args:
|
| 114 |
+
input (torch.Tensor): with shape `(*, D)`.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
torch.Tensor: output, with shape `(*, D)`.
|
| 118 |
+
"""
|
| 119 |
+
return self.sequential(input)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class ConformerLayer(torch.nn.Module):
    r"""One Conformer block: half-step FFN, self-attention, convolution, half-step FFN.

    Args:
        input_dim (int): input dimension.
        ffn_dim (int): hidden layer dimension of feedforward network.
        num_attention_heads (int): number of attention heads.
        depthwise_conv_kernel_size (int): kernel size of depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim: int,
        ffn_dim: int,
        num_attention_heads: int,
        depthwise_conv_kernel_size: int,
        dropout: float = 0.0,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ) -> None:
        super().__init__()

        self.ffn1 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)

        self.self_attn_layer_norm = torch.nn.LayerNorm(input_dim)
        self.self_attn = torch.nn.MultiheadAttention(input_dim, num_attention_heads, dropout=dropout)
        self.self_attn_dropout = torch.nn.Dropout(dropout)

        self.conv_module = _ConvolutionModule(
            input_dim=input_dim,
            num_channels=input_dim,
            depthwise_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            bias=True,
            use_group_norm=use_group_norm,
        )

        self.ffn2 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
        self.final_layer_norm = torch.nn.LayerNorm(input_dim)
        self.convolution_first = convolution_first

    def _apply_convolution(self, input: torch.Tensor) -> torch.Tensor:
        # Conv module wants (B, T, D); input here is (T, B, D). Residual add after.
        conv_out = self.conv_module(input.transpose(0, 1)).transpose(0, 1)
        return input + conv_out

    def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
        r"""
        Args:
            input (torch.Tensor): input, with shape `(T, B, D)`.
            key_padding_mask (torch.Tensor or None): key padding mask to use in self attention layer.

        Returns:
            torch.Tensor: output, with shape `(T, B, D)`.
        """
        # First Macaron-style half-step feed-forward (scaled by 0.5).
        x = input + 0.5 * self.ffn1(input)

        if self.convolution_first:
            x = self._apply_convolution(x)

        # Pre-norm multi-head self-attention with residual connection.
        attn_in = self.self_attn_layer_norm(x)
        attn_out, _ = self.self_attn(
            query=attn_in,
            key=attn_in,
            value=attn_in,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = x + self.self_attn_dropout(attn_out)

        if not self.convolution_first:
            x = self._apply_convolution(x)

        # Second half-step feed-forward, then the final layer norm.
        x = x + 0.5 * self.ffn2(x)
        return self.final_layer_norm(x)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class Conformer(torch.nn.Module):
    r"""Conformer architecture introduced in
    *Conformer: Convolution-augmented Transformer for Speech Recognition*
    :cite:`gulati2020conformer`.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Conformer layer.
        ffn_dim (int): hidden layer dimension of feedforward networks.
        num_layers (int): number of Conformer layers to instantiate.
        depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)

    Examples:
        >>> conformer = Conformer(
        >>>     input_dim=80,
        >>>     num_heads=4,
        >>>     ffn_dim=128,
        >>>     num_layers=4,
        >>>     depthwise_conv_kernel_size=31,
        >>> )
        >>> lengths = torch.randint(1, 400, (10,))  # (batch,)
        >>> input = torch.rand(10, int(lengths.max()), input_dim)  # (batch, num_frames, input_dim)
        >>> output = conformer(input, lengths)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        num_layers: int,
        depthwise_conv_kernel_size: int,
        dropout: float = 0.0,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ):
        super().__init__()

        layers = [
            ConformerLayer(
                input_dim,
                ffn_dim,
                num_heads,
                depthwise_conv_kernel_size,
                dropout=dropout,
                use_group_norm=use_group_norm,
                convolution_first=convolution_first,
            )
            for _ in range(num_layers)
        ]
        self.conformer_layers = torch.nn.ModuleList(layers)

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input (torch.Tensor): with shape `(B, T, input_dim)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor)
                torch.Tensor
                    output frames, with shape `(B, T, input_dim)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output frames.
        """
        # Mask out padded frames so attention ignores them.
        encoder_padding_mask = _lengths_to_padding_mask(lengths)

        # Layers operate in (T, B, D); convert in and out.
        frames_first = input.transpose(0, 1)
        for layer in self.conformer_layers:
            frames_first = layer(frames_first, encoder_padding_mask)
        # The stack does not subsample, so lengths pass through unchanged.
        return frames_first.transpose(0, 1), lengths
|
.venv/lib/python3.11/site-packages/torchaudio/models/conv_tasnet.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Implements Conv-TasNet with building blocks of it.
|
| 2 |
+
|
| 3 |
+
Based on https://github.com/naplab/Conv-TasNet/tree/e66d82a8f956a69749ec8a4ae382217faa097c5c
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ConvBlock(torch.nn.Module):
    """1D convolutional block of the Conv-TasNet mask generator.

    Pointwise conv -> PReLU -> GroupNorm -> depthwise conv -> PReLU -> GroupNorm,
    followed by two pointwise projections producing the residual and skip outputs.

    Args:
        io_channels (int): The number of input/output channels, <B, Sc>
        hidden_channels (int): The number of channels in the internal layers, <H>.
        kernel_size (int): The convolution kernel size of the middle layer, <P>.
        padding (int): Padding value of the convolution in the middle layer.
        dilation (int, optional): Dilation value of the convolution in the middle layer.
        no_residual (bool, optional): Disable residual block/output.

    Note:
        This implementation corresponds to the "non-causal" setting in the paper.
    """

    def __init__(
        self,
        io_channels: int,
        hidden_channels: int,
        kernel_size: int,
        padding: int,
        dilation: int = 1,
        no_residual: bool = False,
    ):
        super().__init__()

        self.conv_layers = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=io_channels, out_channels=hidden_channels, kernel_size=1),
            torch.nn.PReLU(),
            torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
            # Depthwise (groups == channels) dilated convolution.
            torch.nn.Conv1d(
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                kernel_size=kernel_size,
                padding=padding,
                dilation=dilation,
                groups=hidden_channels,
            ),
            torch.nn.PReLU(),
            torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
        )

        # The final block in a TCN stack has no residual path.
        if no_residual:
            self.res_out = None
        else:
            self.res_out = torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
        self.skip_out = torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)

    def forward(self, input: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        feature = self.conv_layers(input)
        residual = None if self.res_out is None else self.res_out(feature)
        return residual, self.skip_out(feature)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class MaskGenerator(torch.nn.Module):
    """TCN (Temporal Convolution Network) separation module.

    Generates one mask per source over the encoder feature map.

    Args:
        input_dim (int): Input feature dimension, <N>.
        num_sources (int): The number of sources to separate.
        kernel_size (int): The convolution kernel size of conv blocks, <P>.
        num_feats (int): Input/output feature dimension of conv blocks, <B, Sc>.
        num_hidden (int): Intermediate feature dimension of conv blocks, <H>.
        num_layers (int): The number of conv blocks in one stack, <X>.
        num_stacks (int): The number of conv block stacks, <R>.
        msk_activate (str): The activation function of the mask output.

    Note:
        This implementation corresponds to the "non-causal" setting in the paper.
    """

    def __init__(
        self,
        input_dim: int,
        num_sources: int,
        kernel_size: int,
        num_feats: int,
        num_hidden: int,
        num_layers: int,
        num_stacks: int,
        msk_activate: str,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.num_sources = num_sources

        self.input_norm = torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8)
        self.input_conv = torch.nn.Conv1d(in_channels=input_dim, out_channels=num_feats, kernel_size=1)

        self.receptive_field = 0
        self.conv_layers = torch.nn.ModuleList([])
        for stack_idx in range(num_stacks):
            for layer_idx in range(num_layers):
                # Dilation doubles with each layer within a stack.
                dilation = 2**layer_idx
                is_final_block = stack_idx == num_stacks - 1 and layer_idx == num_layers - 1
                self.conv_layers.append(
                    ConvBlock(
                        io_channels=num_feats,
                        hidden_channels=num_hidden,
                        kernel_size=kernel_size,
                        dilation=dilation,
                        padding=dilation,
                        # The last ConvBlock does not need residual
                        no_residual=is_final_block,
                    )
                )
                if stack_idx == 0 and layer_idx == 0:
                    self.receptive_field += kernel_size
                else:
                    self.receptive_field += (kernel_size - 1) * dilation
        self.output_prelu = torch.nn.PReLU()
        self.output_conv = torch.nn.Conv1d(
            in_channels=num_feats,
            out_channels=input_dim * num_sources,
            kernel_size=1,
        )
        if msk_activate == "sigmoid":
            self.mask_activate = torch.nn.Sigmoid()
        elif msk_activate == "relu":
            self.mask_activate = torch.nn.ReLU()
        else:
            raise ValueError(f"Unsupported activation {msk_activate}")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Generate separation mask.

        Args:
            input (torch.Tensor): 3D Tensor with shape [batch, features, frames]

        Returns:
            Tensor: shape [batch, num_sources, features, frames]
        """
        num_batches = input.shape[0]
        feats = self.input_conv(self.input_norm(input))
        accumulated = 0.0
        for block in self.conv_layers:
            residual, skip = block(feats)
            # The last conv block produces no residual output.
            if residual is not None:
                feats = feats + residual
            accumulated = accumulated + skip
        masks = self.mask_activate(self.output_conv(self.output_prelu(accumulated)))
        return masks.view(num_batches, self.num_sources, self.input_dim, -1)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class ConvTasNet(torch.nn.Module):
    """Conv-TasNet architecture introduced in
    *Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
    :cite:`Luo_2019`.

    Note:
        This implementation corresponds to the "non-causal" setting in the paper.

    See Also:
        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.

    Args:
        num_sources (int, optional): The number of sources to split.
        enc_kernel_size (int, optional): The convolution kernel size of the encoder/decoder, <L>.
        enc_num_feats (int, optional): The feature dimensions passed to mask generator, <N>.
        msk_kernel_size (int, optional): The convolution kernel size of the mask generator, <P>.
        msk_num_feats (int, optional): The input/output feature dimension of conv block in the mask generator, <B, Sc>.
        msk_num_hidden_feats (int, optional): The internal feature dimension of conv block of the mask generator, <H>.
        msk_num_layers (int, optional): The number of layers in one conv block of the mask generator, <X>.
        msk_num_stacks (int, optional): The number of conv blocks of the mask generator, <R>.
        msk_activate (str, optional): The activation function of the mask output (Default: ``sigmoid``).
    """

    def __init__(
        self,
        num_sources: int = 2,
        # encoder/decoder parameters
        enc_kernel_size: int = 16,
        enc_num_feats: int = 512,
        # mask generator parameters
        msk_kernel_size: int = 3,
        msk_num_feats: int = 128,
        msk_num_hidden_feats: int = 512,
        msk_num_layers: int = 8,
        msk_num_stacks: int = 3,
        msk_activate: str = "sigmoid",
    ):
        super().__init__()

        self.num_sources = num_sources
        self.enc_num_feats = enc_num_feats
        self.enc_kernel_size = enc_kernel_size
        # Stride is half the kernel, i.e. 50% overlap between encoder windows.
        self.enc_stride = enc_kernel_size // 2

        # Learned analysis filterbank: waveform (B, 1, L') -> feature map (B, F, M).
        self.encoder = torch.nn.Conv1d(
            in_channels=1,
            out_channels=enc_num_feats,
            kernel_size=enc_kernel_size,
            stride=self.enc_stride,
            padding=self.enc_stride,
            bias=False,
        )
        # TCN that estimates one mask per source over the encoder features.
        self.mask_generator = MaskGenerator(
            input_dim=enc_num_feats,
            num_sources=num_sources,
            kernel_size=msk_kernel_size,
            num_feats=msk_num_feats,
            num_hidden=msk_num_hidden_feats,
            num_layers=msk_num_layers,
            num_stacks=msk_num_stacks,
            msk_activate=msk_activate,
        )
        # Learned synthesis filterbank; mirrors the encoder geometry so that the
        # output length matches the padded input length.
        self.decoder = torch.nn.ConvTranspose1d(
            in_channels=enc_num_feats,
            out_channels=1,
            kernel_size=enc_kernel_size,
            stride=self.enc_stride,
            padding=self.enc_stride,
            bias=False,
        )

    def _align_num_frames_with_strides(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Pad input Tensor so that the end of the input tensor corresponds with

        1. (if kernel size is odd) the center of the last convolution kernel
        or 2. (if kernel size is even) the end of the first half of the last convolution kernel

        Assumption:
            The resulting Tensor will be padded with the size of stride (== kernel_width // 2)
            on the both ends in Conv1D

        |<--- k_1 --->|
        |      |            |<-- k_n-1 -->|
        |      |                  |  |<--- k_n --->|
        |      |                  |         |      |
        |      |                  |         |      |
        |      v                  v         v      |
        |<---->|<--- input signal --->|<--->|<---->|
         stride                         PAD  stride

        Args:
            input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames)

        Returns:
            Tensor: Padded Tensor
            int: Number of paddings performed
        """
        batch_size, num_channels, num_frames = input.shape
        is_odd = self.enc_kernel_size % 2
        # Number of full stride steps that fit; the odd-kernel case reserves one
        # frame so the last kernel is centered on a real sample.
        num_strides = (num_frames - is_odd) // self.enc_stride
        num_remainings = num_frames - (is_odd + num_strides * self.enc_stride)
        if num_remainings == 0:
            # Already aligned; nothing to pad.
            return input, 0

        # Pad with zeros up to the next stride boundary.
        num_paddings = self.enc_stride - num_remainings
        pad = torch.zeros(
            batch_size,
            num_channels,
            num_paddings,
            dtype=input.dtype,
            device=input.device,
        )
        return torch.cat([input, pad], 2), num_paddings

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Perform source separation. Generate audio source waveforms.

        Args:
            input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames]

        Returns:
            Tensor: 3D Tensor with shape [batch, channel==num_sources, frames]
        """
        if input.ndim != 3 or input.shape[1] != 1:
            raise ValueError(f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}")

        # B: batch size
        # L: input frame length
        # L': padded input frame length
        # F: feature dimension
        # M: feature frame length
        # S: number of sources

        padded, num_pads = self._align_num_frames_with_strides(input)  # B, 1, L'
        batch_size, num_padded_frames = padded.shape[0], padded.shape[2]
        feats = self.encoder(padded)  # B, F, M
        # Broadcast the encoder features across the source dimension of the masks.
        masked = self.mask_generator(feats) * feats.unsqueeze(1)  # B, S, F, M
        # Fold sources into the batch dimension so one decoder pass handles all.
        masked = masked.view(batch_size * self.num_sources, self.enc_num_feats, -1)  # B*S, F, M
        decoded = self.decoder(masked)  # B*S, 1, L'
        output = decoded.view(batch_size, self.num_sources, num_padded_frames)  # B, S, L'
        if num_pads > 0:
            # Drop the alignment padding so the output matches the input length.
            output = output[..., :-num_pads]  # B, S, L
        return output
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def conv_tasnet_base(num_sources: int = 2) -> ConvTasNet:
    r"""Builds non-causal version of :class:`~torchaudio.models.ConvTasNet`.

    The parameter settings follow the ones with the highest Si-SNR metric score in the paper,
    except the mask activation function is changed from "sigmoid" to "relu" for performance improvement.

    Args:
        num_sources (int, optional): Number of sources in the output.
            (Default: 2)
    Returns:
        ConvTasNet:
            ConvTasNet model.
    """
    config = {
        "enc_kernel_size": 16,
        "enc_num_feats": 512,
        "msk_kernel_size": 3,
        "msk_num_feats": 128,
        "msk_num_hidden_feats": 512,
        "msk_num_layers": 8,
        "msk_num_stacks": 3,
        "msk_activate": "relu",
    }
    return ConvTasNet(num_sources=num_sources, **config)
|
.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__init__.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Names served lazily from ``._ctc_decoder`` (requires flashlight-text, optionally KenLM).
_CTC_DECODERS = [
    "CTCHypothesis",
    "CTCDecoder",
    "CTCDecoderLM",
    "CTCDecoderLMState",
    "ctc_decoder",
    "download_pretrained_files",
]
# Names served lazily from ``._cuda_ctc_decoder`` (requires the CUDA CTC decoder build).
_CUDA_CTC_DECODERS = [
    "CUCTCDecoder",
    "CUCTCHypothesis",
    "cuda_ctc_decoder",
]


def __getattr__(name: str):
    """Lazily import the decoder submodules on first attribute access (PEP 562).

    Keeps ``import torchaudio.models.decoder`` cheap and free of hard optional
    dependencies: the heavy submodule is imported only when one of its names is
    actually requested, then cached in module globals.
    """
    if name in _CTC_DECODERS:
        try:
            from . import _ctc_decoder
        except Exception as err:
            raise RuntimeError(
                "CTC Decoder suit requires flashlight-text package and optionally KenLM. Please install them."
            ) from err

        item = getattr(_ctc_decoder, name)
        # Cache on the module so subsequent accesses bypass __getattr__.
        globals()[name] = item
        return item
    elif name in _CUDA_CTC_DECODERS:
        try:
            from . import _cuda_ctc_decoder
        # NOTE(review): only AttributeError is caught here — presumably the failure
        # mode when the native extension was built without CUCTC support; confirm
        # that other import failures are intended to propagate unchanged.
        except AttributeError as err:
            raise RuntimeError(
                "To use CUCTC decoder, please set BUILD_CUDA_CTC_DECODER=1 when building from source."
            ) from err

        item = getattr(_cuda_ctc_decoder, name)
        globals()[name] = item
        return item
    raise AttributeError(f"module {__name__} has no attribute {name}")


def __dir__():
    # Advertise the lazy names to dir()/autocomplete even before they are imported.
    return sorted(__all__)


__all__ = _CTC_DECODERS + _CUDA_CTC_DECODERS
|
.venv/lib/python3.11/site-packages/torchaudio/models/decoder/_ctc_decoder.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import itertools as it
|
| 4 |
+
|
| 5 |
+
from abc import abstractmethod
|
| 6 |
+
from collections import namedtuple
|
| 7 |
+
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from flashlight.lib.text.decoder import (
|
| 12 |
+
CriterionType as _CriterionType,
|
| 13 |
+
LexiconDecoder as _LexiconDecoder,
|
| 14 |
+
LexiconDecoderOptions as _LexiconDecoderOptions,
|
| 15 |
+
LexiconFreeDecoder as _LexiconFreeDecoder,
|
| 16 |
+
LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
|
| 17 |
+
LM as _LM,
|
| 18 |
+
LMState as _LMState,
|
| 19 |
+
SmearingMode as _SmearingMode,
|
| 20 |
+
Trie as _Trie,
|
| 21 |
+
ZeroLM as _ZeroLM,
|
| 22 |
+
)
|
| 23 |
+
from flashlight.lib.text.dictionary import (
|
| 24 |
+
create_word_dict as _create_word_dict,
|
| 25 |
+
Dictionary as _Dictionary,
|
| 26 |
+
load_words as _load_words,
|
| 27 |
+
)
|
| 28 |
+
from torchaudio.utils import download_asset
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from flashlight.lib.text.decoder.kenlm import KenLM as _KenLM
|
| 32 |
+
except Exception:
|
| 33 |
+
try:
|
| 34 |
+
from flashlight.lib.text.decoder import KenLM as _KenLM
|
| 35 |
+
except Exception:
|
| 36 |
+
_KenLM = None
|
| 37 |
+
|
| 38 |
+
__all__ = [
|
| 39 |
+
"CTCHypothesis",
|
| 40 |
+
"CTCDecoder",
|
| 41 |
+
"CTCDecoderLM",
|
| 42 |
+
"CTCDecoderLMState",
|
| 43 |
+
"ctc_decoder",
|
| 44 |
+
"download_pretrained_files",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
_PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _construct_trie(tokens_dict, word_dict, lexicon, lm, silence):
    """Build the lexicon trie used by the lexicon-constrained decoder.

    Every spelling of every lexicon word is inserted into the trie together
    with the LM score of that word taken from the LM start state; the trie is
    then max-smeared so the decoder can prune with upper-bound scores.
    """
    trie = _Trie(tokens_dict.index_size(), silence)
    initial_state = lm.start(False)

    for word, spellings in lexicon.items():
        word_index = word_dict.get_index(word)
        # Score each word once from the LM start state.
        _, lm_score = lm.score(initial_state, word_index)
        for spelling in spellings:
            token_indices = [tokens_dict.get_index(token) for token in spelling]
            trie.insert(token_indices, word_index, lm_score)

    trie.smear(_SmearingMode.MAX)
    return trie
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word):
    """Construct the word-level dictionary used by the decoder and LM.

    Resolution order:
      1. An explicit ``lm_dict`` file always wins.
      2. Otherwise, with a lexicon, the lexicon's words are used.
      3. Otherwise (lexicon-free decoding with a KenLM file path), a
         pseudo-lexicon mapping each token (plus the unknown word) to itself
         is built so KenLM has a vocabulary to index.
    """
    word_dict = None
    if lm_dict is not None:
        word_dict = _Dictionary(lm_dict)

    if lexicon and word_dict is None:
        word_dict = _create_word_dict(lexicon)
    elif not lexicon and word_dict is None and isinstance(lm, str):
        # isinstance instead of `type(lm) == str`: also accepts str subclasses
        # (e.g. pathlib-derived or interned wrappers) and is the idiomatic check.
        d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())}
        d[unk_word] = [[unk_word]]
        word_dict = _create_word_dict(d)

    return word_dict
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class CTCHypothesis(NamedTuple):
    r"""Represents hypothesis generated by CTC beam search decoder :class:`CTCDecoder`."""
    tokens: torch.LongTensor
    """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""

    words: List[str]
    """List of predicted words.

    Note:
        This attribute is only applicable if a lexicon is provided to the decoder. If
        decoding without a lexicon, it will be blank. Please refer to :attr:`tokens` and
        :func:`~torchaudio.models.decoder.CTCDecoder.idxs_to_tokens` instead.
    """

    score: float
    """Score corresponding to hypothesis"""

    timesteps: torch.IntTensor
    """Timesteps corresponding to the tokens. Shape `(L, )`, where `L` is the length of the output sequence"""
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class CTCDecoderLMState(_LMState):
    """Language model state.

    Thin Python-facing wrapper over the flashlight-text ``LMState`` pybind11
    binding; all storage and tree bookkeeping live in the C++ base class.
    """

    @property
    def children(self) -> Dict[int, CTCDecoderLMState]:
        """Map of indices to LM states"""
        return super().children

    def child(self, usr_index: int) -> CTCDecoderLMState:
        """Returns child corresponding to usr_index, or creates and returns a new state if input index
        is not found.

        Args:
            usr_index (int): index corresponding to child state

        Returns:
            CTCDecoderLMState: child state corresponding to usr_index
        """
        return super().child(usr_index)

    def compare(self, state: CTCDecoderLMState) -> CTCDecoderLMState:
        """Compare two language model states.

        Args:
            state (CTCDecoderLMState): LM state to compare against

        Returns:
            int: 0 if the states are the same, -1 if self is less, +1 if self is greater.
        """
        # NOTE(review): no Python implementation here — presumably the pybind11
        # base class provides the actual comparison; confirm against the
        # flashlight-text bindings before relying on this method.
        pass
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class CTCDecoderLM(_LM):
    """Language model base class for creating custom language models to use with the decoder.

    Subclasses must implement :meth:`start`, :meth:`score`, and :meth:`finish`;
    the decoder calls these during beam search to rescore hypotheses.
    """

    @abstractmethod
    def start(self, start_with_nothing: bool) -> CTCDecoderLMState:
        """Initialize or reset the language model.

        Args:
            start_with_nothing (bool): whether or not to start sentence with sil token.

        Returns:
            CTCDecoderLMState: starting state
        """
        raise NotImplementedError

    @abstractmethod
    def score(self, state: CTCDecoderLMState, usr_token_idx: int) -> Tuple[CTCDecoderLMState, float]:
        """Evaluate the language model based on the current LM state and new word.

        Args:
            state (CTCDecoderLMState): current LM state
            usr_token_idx (int): index of the word

        Returns:
            (CTCDecoderLMState, float)
            CTCDecoderLMState:
                new LM state
            float:
                score
        """
        raise NotImplementedError

    @abstractmethod
    def finish(self, state: CTCDecoderLMState) -> Tuple[CTCDecoderLMState, float]:
        """Evaluate end for language model based on current LM state.

        Args:
            state (CTCDecoderLMState): current LM state

        Returns:
            (CTCDecoderLMState, float)
            CTCDecoderLMState:
                new LM state
            float:
                score
        """
        raise NotImplementedError
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class CTCDecoder:
    """CTC beam search decoder from *Flashlight* :cite:`kahn2022flashlight`.

    .. devices:: CPU

    Note:
        To build the decoder, please use the factory function :func:`ctc_decoder`.
    """

    def __init__(
        self,
        nbest: int,
        lexicon: Optional[Dict],
        word_dict: _Dictionary,
        tokens_dict: _Dictionary,
        lm: CTCDecoderLM,
        decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
        blank_token: str,
        sil_token: str,
        unk_word: str,
    ) -> None:
        """
        Args:
            nbest (int): number of best decodings to return
            lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder
            word_dict (_Dictionary): dictionary of words
            tokens_dict (_Dictionary): dictionary of tokens
            lm (CTCDecoderLM): language model. If using a lexicon, only word level LMs are currently supported
            decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions):
                parameters used for beam search decoding
            blank_token (str): token corresponding to blank
            sil_token (str): token corresponding to silence
            unk_word (str): word corresponding to unknown
        """

        self.nbest = nbest
        self.word_dict = word_dict
        self.tokens_dict = tokens_dict
        self.blank = self.tokens_dict.get_index(blank_token)
        silence = self.tokens_dict.get_index(sil_token)
        transitions = []

        if lexicon:
            # Lexicon-constrained search: hypotheses are restricted to spellings in the trie.
            trie = _construct_trie(tokens_dict, word_dict, lexicon, lm, silence)
            unk_word = word_dict.get_index(unk_word)
            token_lm = False  # use word level LM

            self.decoder = _LexiconDecoder(
                decoder_options,
                trie,
                lm,
                silence,
                self.blank,
                unk_word,
                transitions,
                token_lm,
            )
        else:
            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)
        # https://github.com/pytorch/audio/issues/3218
        # If lm is passed like rvalue reference, the lm object gets garbage collected,
        # and later call to the lm fails.
        # This ensures that lm object is not deleted as long as the decoder is alive.
        # https://github.com/pybind/pybind11/discussions/4013
        self.lm = lm

    def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        """Apply standard CTC collapse: merge repeated tokens, then drop blanks."""
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)
        return torch.LongTensor(list(idxs))

    def _get_timesteps(self, idxs: torch.IntTensor) -> torch.IntTensor:
        """Returns frame numbers corresponding to non-blank tokens."""

        timesteps = []
        for i, idx in enumerate(idxs):
            if idx == self.blank:
                continue
            # Record the first frame of each run of a repeated token.
            if i == 0 or idx != idxs[i - 1]:
                timesteps.append(i)
        return torch.IntTensor(timesteps)

    def decode_begin(self):
        """Initialize the internal state of the decoder.

        See :py:meth:`decode_step` for the usage.

        .. note::

            This method is required only when performing online decoding.
            It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        self.decoder.decode_begin()

    def decode_end(self):
        """Finalize the internal state of the decoder.

        See :py:meth:`decode_step` for the usage.

        .. note::

            This method is required only when performing online decoding.
            It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        self.decoder.decode_end()

    def decode_step(self, emissions: torch.FloatTensor):
        """Perform incremental decoding on top of the current internal state.

        .. note::

            This method is required only when performing online decoding.
            It is not necessary when performing batch decoding with :py:meth:`__call__`.

        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(frame, num_tokens)` storing sequences of
                probability distribution over labels; output of acoustic model.

        Example:
            >>> decoder = torchaudio.models.decoder.ctc_decoder(...)
            >>> decoder.decode_begin()
            >>> decoder.decode_step(emission1)
            >>> decoder.decode_step(emission2)
            >>> decoder.decode_end()
            >>> result = decoder.get_final_hypothesis()
        """
        # The C++ decoder reads the raw buffer, so dtype/device/layout must be
        # validated up front — a wrong tensor here would be silently misread.
        if emissions.dtype != torch.float32:
            raise ValueError("emissions must be float32.")

        if not emissions.is_cpu:
            raise RuntimeError("emissions must be a CPU tensor.")

        if not emissions.is_contiguous():
            raise RuntimeError("emissions must be contiguous.")

        if emissions.ndim != 2:
            raise RuntimeError(f"emissions must be 2D. Found {emissions.shape}")

        T, N = emissions.size()
        self.decoder.decode_step(emissions.data_ptr(), T, N)

    def _to_hypo(self, results) -> List[CTCHypothesis]:
        """Convert native decoder results into :class:`CTCHypothesis` tuples."""
        return [
            CTCHypothesis(
                tokens=self._get_tokens(result.tokens),
                # Negative word indices are sentinels emitted by the native
                # decoder (no word boundary) and are filtered out.
                words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
                score=result.score,
                timesteps=self._get_timesteps(result.tokens),
            )
            for result in results
        ]

    def get_final_hypothesis(self) -> List[CTCHypothesis]:
        """Get the final hypothesis

        Returns:
            List[CTCHypothesis]:
                List of sorted best hypotheses.

        .. note::

            This method is required only when performing online decoding.
            It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        results = self.decoder.get_all_final_hypothesis()
        return self._to_hypo(results[: self.nbest])

    def __call__(
        self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
    ) -> List[List[CTCHypothesis]]:
        """
        Performs batched offline decoding.

        .. note::

            This method performs offline decoding in one go. To perform incremental decoding,
            please refer to :py:meth:`decode_step`.

        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                probability distribution over labels; output of acoustic model.
            lengths (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the valid length of
                in time axis of the output Tensor in each batch.

        Returns:
            List[List[CTCHypothesis]]:
                List of sorted best hypotheses for each audio sequence in the batch.
        """

        if emissions.dtype != torch.float32:
            raise ValueError("emissions must be float32.")

        if not emissions.is_cpu:
            raise RuntimeError("emissions must be a CPU tensor.")

        if not emissions.is_contiguous():
            raise RuntimeError("emissions must be contiguous.")

        if emissions.ndim != 3:
            raise RuntimeError(f"emissions must be 3D. Found {emissions.shape}")

        if lengths is not None and not lengths.is_cpu:
            raise RuntimeError("lengths must be a CPU tensor.")

        B, T, N = emissions.size()
        if lengths is None:
            # No lengths given: treat every sequence as full length T.
            lengths = torch.full((B,), T)

        float_bytes = 4
        hypos = []

        for b in range(B):
            # Manual pointer arithmetic into the float32 buffer: advance by
            # b rows of stride(0) elements, 4 bytes each.
            emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, lengths[b], N)
            hypos.append(self._to_hypo(results[: self.nbest]))
        return hypos

    def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
        """
        Map raw token IDs into corresponding tokens

        Args:
            idxs (LongTensor): raw token IDs generated from decoder

        Returns:
            List: tokens corresponding to the input IDs
        """
        return [self.tokens_dict.get_entry(idx.item()) for idx in idxs]
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def ctc_decoder(
    lexicon: Optional[str],
    tokens: Union[str, List[str]],
    lm: Optional[Union[str, CTCDecoderLM]] = None,
    lm_dict: Optional[str] = None,
    nbest: int = 1,
    beam_size: int = 50,
    beam_size_token: Optional[int] = None,
    beam_threshold: float = 50,
    lm_weight: float = 2,
    word_score: float = 0,
    unk_score: float = float("-inf"),
    sil_score: float = 0,
    log_add: bool = False,
    blank_token: str = "-",
    sil_token: str = "|",
    unk_word: str = "<unk>",
) -> CTCDecoder:
    """Builds an instance of :class:`CTCDecoder`.

    Args:
        lexicon (str or None): lexicon file containing the possible words and corresponding spellings.
            Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free
            decoding.
        tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
            format is for tokens mapping to the same index to be on the same line
        lm (str, CTCDecoderLM, or None, optional): either a path containing KenLM language model,
            custom language model of type `CTCDecoderLM`, or `None` if not using a language model
        lm_dict (str or None, optional): file consisting of the dictionary used for the LM, with a word
            per line sorted by LM index. If decoding with a lexicon, entries in lm_dict must also occur
            in the lexicon file. If `None`, dictionary for LM is constructed using the lexicon file.
            (Default: None)
        nbest (int, optional): number of best decodings to return (Default: 1)
        beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
        beam_size_token (int, optional): max number of tokens to consider at each decode step.
            If `None`, it is set to the total number of tokens (Default: None)
        beam_threshold (float, optional): threshold for pruning hypothesis (Default: 50)
        lm_weight (float, optional): weight of language model (Default: 2)
        word_score (float, optional): word insertion score (Default: 0)
        unk_score (float, optional): unknown word insertion score (Default: -inf)
        sil_score (float, optional): silence insertion score (Default: 0)
        log_add (bool, optional): whether or not to use logadd when merging hypotheses (Default: False)
        blank_token (str, optional): token corresponding to blank (Default: "-")
        sil_token (str, optional): token corresponding to silence (Default: "|")
        unk_word (str, optional): word corresponding to unknown (Default: "<unk>")

    Returns:
        CTCDecoder: decoder

    Example
        >>> decoder = ctc_decoder(
        >>>     lexicon="lexicon.txt",
        >>>     tokens="tokens.txt",
        >>>     lm="kenlm.bin",
        >>> )
        >>> results = decoder(emissions)  # List of shape (B, nbest) of Hypotheses
    """
    # isinstance instead of `type(...) is/== str`: idiomatic and accepts str subclasses.
    if lm_dict is not None and not isinstance(lm_dict, str):
        raise ValueError("lm_dict must be None or str type.")

    tokens_dict = _Dictionary(tokens)

    # decoder options
    if lexicon:
        lexicon = _load_words(lexicon)
        decoder_options = _LexiconDecoderOptions(
            beam_size=beam_size,
            beam_size_token=beam_size_token or tokens_dict.index_size(),
            beam_threshold=beam_threshold,
            lm_weight=lm_weight,
            word_score=word_score,
            unk_score=unk_score,
            sil_score=sil_score,
            log_add=log_add,
            criterion_type=_CriterionType.CTC,
        )
    else:
        decoder_options = _LexiconFreeDecoderOptions(
            beam_size=beam_size,
            beam_size_token=beam_size_token or tokens_dict.index_size(),
            beam_threshold=beam_threshold,
            lm_weight=lm_weight,
            sil_score=sil_score,
            log_add=log_add,
            criterion_type=_CriterionType.CTC,
        )

    # construct word dict and language model
    word_dict = _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word)

    if isinstance(lm, str):
        if _KenLM is None:
            raise RuntimeError(
                "flashlight-text is installed, but KenLM is not installed. "
                "Please refer to https://github.com/kpu/kenlm#python-module for how to install it."
            )
        lm = _KenLM(lm, word_dict)
    elif lm is None:
        lm = _ZeroLM()

    return CTCDecoder(
        nbest=nbest,
        lexicon=lexicon,
        word_dict=word_dict,
        tokens_dict=tokens_dict,
        lm=lm,
        decoder_options=decoder_options,
        blank_token=blank_token,
        sil_token=sil_token,
        unk_word=unk_word,
    )
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
def _get_filenames(model: str) -> _PretrainedFiles:
    """Return the relative asset paths for a supported pretrained model name.

    The plain ``"librispeech"`` bundle ships no language model, so its ``lm``
    field is ``None``.
    """
    supported = ("librispeech", "librispeech-3-gram", "librispeech-4-gram")
    if model not in supported:
        raise ValueError(
            f"{model} not supported. Must be one of ['librispeech-3-gram', 'librispeech-4-gram', 'librispeech']"
        )

    prefix = f"decoder-assets/{model}"
    lm_path = None if model == "librispeech" else f"{prefix}/lm.bin"
    return _PretrainedFiles(
        lexicon=f"{prefix}/lexicon.txt",
        tokens=f"{prefix}/tokens.txt",
        lm=lm_path,
    )
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def download_pretrained_files(model: str) -> _PretrainedFiles:
    """
    Retrieves pretrained data files used for :func:`ctc_decoder`.

    Args:
        model (str): pretrained language model to download.
            Valid values are: ``"librispeech-3-gram"``, ``"librispeech-4-gram"`` and ``"librispeech"``.

    Returns:
        Object with the following attributes

        * ``lm``: path corresponding to downloaded language model,
          or ``None`` if the model is not associated with an lm
        * ``lexicon``: path corresponding to downloaded lexicon file
        * ``tokens``: path corresponding to downloaded tokens file
    """
    files = _get_filenames(model)

    # The LM asset is optional; only fetch it when the bundle defines one.
    lm_path = download_asset(files.lm) if files.lm is not None else None
    return _PretrainedFiles(
        lexicon=download_asset(files.lexicon),
        tokens=download_asset(files.tokens),
        lm=lm_path,
    )
|
.venv/lib/python3.11/site-packages/torchaudio/models/decoder/_cuda_ctc_decoder.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
from typing import List, NamedTuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio
|
| 9 |
+
|
| 10 |
+
torchaudio._extension._load_lib("libctc_prefix_decoder")
|
| 11 |
+
import torchaudio.lib.pybind11_prefixctc as cuctc
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
__all__ = ["CUCTCHypothesis", "CUCTCDecoder", "cuda_ctc_decoder"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _get_vocab_list(vocab_file):
|
| 18 |
+
vocab = []
|
| 19 |
+
with open(vocab_file, "r", encoding="utf-8") as f:
|
| 20 |
+
for line in f:
|
| 21 |
+
line = line.strip().split()
|
| 22 |
+
vocab.append(line[0])
|
| 23 |
+
return vocab
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CUCTCHypothesis(NamedTuple):
    r"""Represents hypothesis generated by CUCTC beam search decoder :class:`CUCTCDecoder`."""
    tokens: List[int]
    """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""

    words: List[str]
    """List of predicted tokens. Aligned with the modeling unit.
    """

    score: float
    """Score corresponding to hypothesis"""
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
_DEFAULT_BLANK_SKIP_THREASHOLD = 0.95
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class CUCTCDecoder:
    """CUDA CTC beam search decoder.

    .. devices:: CUDA

    Note:
        To build the decoder, please use the factory function :func:`cuda_ctc_decoder`.
    """

    def __init__(
        self,
        vocab_list: List[str],
        blank_id: int = 0,
        beam_size: int = 10,
        nbest: int = 1,
        blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
        cuda_stream: torch.cuda.streams.Stream = None,
    ):
        """
        Args:
            blank_id (int): token id corresponding to blank, only support 0 for now. (Default: 0)
            vocab_list (List[str]): list of vocabulary tokens
            beam_size (int, optional): max number of hypos to hold after each decode step (Default: 10)
            nbest (int): number of best decodings to return
            blank_skip_threshold (float):
                skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding.
                (Default: 0.95).
            cuda_stream (torch.cuda.streams.Stream): using assigned cuda stream (Default: using default stream)

        """
        if cuda_stream:
            if not isinstance(cuda_stream, torch.cuda.streams.Stream):
                raise AssertionError("cuda_stream must be torch.cuda.streams.Stream")
        cuda_stream_ = cuda_stream.cuda_stream if cuda_stream else torch.cuda.current_stream().cuda_stream
        # Opaque handle to the native decoder state; released in __del__.
        self.internal_data = cuctc.prefixCTC_alloc(cuda_stream_)
        # GPU workspace buffer; grown on demand in __call__ (see required_size).
        self.memory = torch.empty(0, dtype=torch.int8, device=torch.device("cuda"))
        if blank_id != 0:
            raise AssertionError("blank_id must be 0")
        self.blank_id = blank_id
        self.vocab_list = vocab_list
        self.space_id = 0
        self.nbest = nbest
        if not (blank_skip_threshold >= 0 and blank_skip_threshold <= 1):
            raise AssertionError("blank_skip_threshold must be between 0 and 1")
        # Stored in log-space to compare directly against log-probabilities.
        # NOTE(review): blank_skip_threshold == 0 would make math.log raise — confirm intended.
        self.blank_skip_threshold = math.log(blank_skip_threshold)
        self.beam_size = min(beam_size, len(vocab_list))  # beam size must be smaller than vocab size

    def __del__(self):
        # Guard against cuctc having been torn down during interpreter shutdown.
        if cuctc is not None:
            cuctc.prefixCTC_free(self.internal_data)

    def __call__(self, log_prob: torch.Tensor, encoder_out_lens: torch.Tensor):
        """
        Args:
            log_prob (torch.FloatTensor): GPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                probability distribution over labels; log_softmax(output of acoustic model).
            lengths (dtype torch.int32): GPU tensor of shape `(batch, )` storing the valid length of
                in time axis of the output Tensor in each batch.

        Returns:
            List[List[CUCTCHypothesis]]:
                List of sorted best hypotheses for each audio sequence in the batch.
        """
        # The native kernel reads raw device pointers, so dtype/device/layout
        # must be validated before handing over data_ptr().
        if not encoder_out_lens.dtype == torch.int32:
            raise AssertionError("encoder_out_lens must be torch.int32")
        if not log_prob.dtype == torch.float32:
            raise AssertionError("log_prob must be torch.float32")
        if not (log_prob.is_cuda and encoder_out_lens.is_cuda):
            raise AssertionError("inputs must be cuda tensors")
        if not (log_prob.is_contiguous() and encoder_out_lens.is_contiguous()):
            raise AssertionError("input tensors must be contiguous")
        # First pass: may return a required workspace size larger than what we
        # currently hold; in that case reallocate and run a second pass.
        required_size, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
            self.internal_data,
            self.memory.data_ptr(),
            self.memory.size(0),
            log_prob.data_ptr(),
            encoder_out_lens.data_ptr(),
            log_prob.size(),
            log_prob.stride(),
            self.beam_size,
            self.blank_id,
            self.space_id,
            self.blank_skip_threshold,
        )
        if required_size > 0:
            self.memory = torch.empty(required_size, dtype=torch.int8, device=log_prob.device).contiguous()
            _, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
                self.internal_data,
                self.memory.data_ptr(),
                self.memory.size(0),
                log_prob.data_ptr(),
                encoder_out_lens.data_ptr(),
                log_prob.size(),
                log_prob.stride(),
                self.beam_size,
                self.blank_id,
                self.space_id,
                self.blank_skip_threshold,
            )
        # score_hyps[i][j] is (score, token_id_list) for hypothesis j of batch item i.
        batch_size = len(score_hyps)
        hypos = []
        for i in range(batch_size):
            hypos.append(
                [
                    CUCTCHypothesis(
                        tokens=score_hyps[i][j][1],
                        words=[self.vocab_list[word_id] for word_id in score_hyps[i][j][1]],
                        score=score_hyps[i][j][0],
                    )
                    for j in range(self.nbest)
                ]
            )
        return hypos
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def cuda_ctc_decoder(
    tokens: Union[str, List[str]],
    nbest: int = 1,
    beam_size: int = 10,
    blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
) -> CUCTCDecoder:
    """Builds an instance of :class:`CUCTCDecoder`.

    The blank token is assumed to have ID 0 (the only value :class:`CUCTCDecoder`
    supports).

    Args:
        tokens (str or List[str]): File or list containing valid tokens.
            If using a file, the expected format is for tokens mapping to the same index to be on the same line
        nbest (int): The number of best decodings to return
        beam_size (int, optional): The maximum number of hypos to hold after each decode step (Default: 10)
        blank_skip_threshold (float): skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding
            (Default: 0.95).

    Returns:
        CUCTCDecoder: decoder

    Example
        >>> decoder = cuda_ctc_decoder(
        >>>     tokens="tokens.txt",
        >>>     blank_skip_threshold=0.95,
        >>> )
        >>> results = decoder(log_probs, encoder_out_lens)  # List of shape (B, nbest) of Hypotheses
    """
    # isinstance instead of `type(tokens) == str`: idiomatic, handles str subclasses.
    if isinstance(tokens, str):
        tokens = _get_vocab_list(tokens)

    return CUCTCDecoder(vocab_list=tokens, beam_size=beam_size, nbest=nbest, blank_skip_threshold=blank_skip_threshold)
|
.venv/lib/python3.11/site-packages/torchaudio/models/deepspeech.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
__all__ = ["DeepSpeech"]
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class FullyConnected(torch.nn.Module):
    """Fully-connected layer with clipped-ReLU activation and optional dropout.

    Args:
        n_feature: Number of input features
        n_hidden: Internal hidden unit size.
        dropout: Dropout probability; 0 disables dropout.
        relu_max_clip: Upper clipping bound of the activation. (Default: 20)
    """

    def __init__(self, n_feature: int, n_hidden: int, dropout: float, relu_max_clip: int = 20) -> None:
        super(FullyConnected, self).__init__()
        self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
        self.relu_max_clip = relu_max_clip
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply linear layer, clipped-ReLU, and (if configured) dropout.

        Args:
            x (torch.Tensor): input of shape `(..., n_feature)`.

        Returns:
            Tensor: output of shape `(..., n_hidden)`, values in `[0, relu_max_clip]`.
        """
        x = self.fc(x)
        # hardtanh(x, 0, clip) == min(relu(x), clip), so the previous separate
        # relu call was redundant; one op gives the identical clipped-ReLU.
        x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip)
        if self.dropout:
            x = torch.nn.functional.dropout(x, self.dropout, self.training)
        return x
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DeepSpeech(torch.nn.Module):
    """DeepSpeech architecture introduced in
    *Deep Speech: Scaling up end-to-end speech recognition* :cite:`hannun2014deep`.

    Args:
        n_feature: Number of input features
        n_hidden: Internal hidden unit size.
        n_class: Number of output classes
        dropout: dropout probability used in every fully-connected block.
    """

    def __init__(
        self,
        n_feature: int,
        n_hidden: int = 2048,
        n_class: int = 40,
        dropout: float = 0.0,
    ) -> None:
        super(DeepSpeech, self).__init__()
        self.n_hidden = n_hidden
        # Three feed-forward blocks, one bidirectional ReLU RNN, one more
        # feed-forward block, then the output projection.
        self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
        self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
        self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
        self.bi_rnn = torch.nn.RNN(n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True)
        self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
        self.out = torch.nn.Linear(n_hidden, n_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).
        Returns:
            Tensor: Predictor tensor of dimension (batch, time, class).
        """
        # (N, C, T, F): the three fully-connected blocks act on the feature axis.
        hidden = self.fc3(self.fc2(self.fc1(x)))
        # Drop the singleton channel axis, then move time first for the RNN: (T, N, H).
        hidden = hidden.squeeze(1).transpose(0, 1)
        rnn_out, _ = self.bi_rnn(hidden)
        # The fifth (non-recurrent) layer takes both the forward and backward units
        # as inputs: sum the two directions back down to H features.
        rnn_out = rnn_out[:, :, : self.n_hidden] + rnn_out[:, :, self.n_hidden :]
        logits = self.out(self.fc4(rnn_out))
        # Back to batch-first (N, T, n_class) and log-normalize each frame.
        return torch.nn.functional.log_softmax(logits.permute(1, 0, 2), dim=2)
|
.venv/lib/python3.11/site-packages/torchaudio/models/emformer.py
ADDED
|
@@ -0,0 +1,884 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
__all__ = ["Emformer"]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
|
| 11 |
+
batch_size = lengths.shape[0]
|
| 12 |
+
max_length = int(torch.max(lengths).item())
|
| 13 |
+
padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
|
| 14 |
+
batch_size, max_length
|
| 15 |
+
) >= lengths.unsqueeze(1)
|
| 16 |
+
return padding_mask
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _gen_padding_mask(
    utterance: torch.Tensor,
    right_context: torch.Tensor,
    summary: torch.Tensor,
    lengths: torch.Tensor,
    mems: torch.Tensor,
    left_context_key: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    """Build the key padding mask for Emformer attention, or ``None`` for batch size 1."""
    batch_size = right_context.size(1)
    if batch_size == 1:
        # A single sequence carries no padding, so no mask is needed.
        return None
    total_query = right_context.size(0) + utterance.size(0) + summary.size(0)
    rc_blocks_len = total_query - torch.max(lengths).int() - summary.size(0)
    lc_blocks_len = 0 if left_context_key is None else left_context_key.size(0)
    # Effective key length per sequence: valid frames plus memory, right-context,
    # and left-context blocks.
    key_lengths = lengths + mems.size(0) + rc_blocks_len + lc_blocks_len
    return _lengths_to_padding_mask(lengths=key_lengths)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _get_activation_module(activation: str) -> torch.nn.Module:
|
| 40 |
+
if activation == "relu":
|
| 41 |
+
return torch.nn.ReLU()
|
| 42 |
+
elif activation == "gelu":
|
| 43 |
+
return torch.nn.GELU()
|
| 44 |
+
elif activation == "silu":
|
| 45 |
+
return torch.nn.SiLU()
|
| 46 |
+
else:
|
| 47 |
+
raise ValueError(f"Unsupported activation {activation}")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _get_weight_init_gains(weight_init_scale_strategy: Optional[str], num_layers: int) -> List[Optional[float]]:
|
| 51 |
+
if weight_init_scale_strategy is None:
|
| 52 |
+
return [None for _ in range(num_layers)]
|
| 53 |
+
elif weight_init_scale_strategy == "depthwise":
|
| 54 |
+
return [1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)]
|
| 55 |
+
elif weight_init_scale_strategy == "constant":
|
| 56 |
+
return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
|
| 57 |
+
else:
|
| 58 |
+
raise ValueError(f"Unsupported weight_init_scale_strategy value {weight_init_scale_strategy}")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _gen_attention_mask_block(
|
| 62 |
+
col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device
|
| 63 |
+
) -> torch.Tensor:
|
| 64 |
+
if len(col_widths) != len(col_mask):
|
| 65 |
+
raise ValueError("Length of col_widths must match that of col_mask")
|
| 66 |
+
|
| 67 |
+
mask_block = [
|
| 68 |
+
torch.ones(num_rows, col_width, device=device)
|
| 69 |
+
if is_ones_col
|
| 70 |
+
else torch.zeros(num_rows, col_width, device=device)
|
| 71 |
+
for col_width, is_ones_col in zip(col_widths, col_mask)
|
| 72 |
+
]
|
| 73 |
+
return torch.cat(mask_block, dim=1)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class _EmformerAttention(torch.nn.Module):
    r"""Emformer layer attention module.

    Computes scaled dot-product attention over the concatenation of
    [right context, utterance, summary] queries against
    [mems, right context, utterance] keys/values (plus cached left-context
    keys/values during streaming inference).

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Emformer layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        weight_init_gain (float or None, optional): scale factor to apply when initializing
            attention module parameters. (Default: ``None``)
        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        weight_init_gain: Optional[float] = None,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        super().__init__()

        if input_dim % num_heads != 0:
            raise ValueError(f"input_dim ({input_dim}) is not a multiple of num_heads ({num_heads}).")

        self.input_dim = input_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.tanh_on_mem = tanh_on_mem
        self.negative_inf = negative_inf

        # Standard scaled dot-product factor: 1 / sqrt(head_dim).
        self.scaling = (self.input_dim // self.num_heads) ** -0.5

        # Keys and values come from one fused projection, chunked in two later.
        self.emb_to_key_value = torch.nn.Linear(input_dim, 2 * input_dim, bias=True)
        self.emb_to_query = torch.nn.Linear(input_dim, input_dim, bias=True)
        self.out_proj = torch.nn.Linear(input_dim, input_dim, bias=True)

        if weight_init_gain:
            torch.nn.init.xavier_uniform_(self.emb_to_key_value.weight, gain=weight_init_gain)
            torch.nn.init.xavier_uniform_(self.emb_to_query.weight, gain=weight_init_gain)

    def _gen_key_value(self, input: torch.Tensor, mems: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Strip the trailing mems.size(0) + 1 rows from ``input``, prepend the
        # memory elements, and project the result into attention keys and values.
        # NOTE(review): not exercised by the visible code paths in this file —
        # confirm caller expectations before changing.
        T, _, _ = input.shape
        summary_length = mems.size(0) + 1
        right_ctx_utterance_block = input[: T - summary_length]
        mems_right_ctx_utterance_block = torch.cat([mems, right_ctx_utterance_block])
        key, value = self.emb_to_key_value(mems_right_ctx_utterance_block).chunk(chunks=2, dim=2)
        return key, value

    def _gen_attention_probs(
        self,
        attention_weights: torch.Tensor,
        attention_mask: torch.Tensor,
        padding_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Mask, softmax, and dropout the raw attention weights.
        # Softmax runs in float32 for numerical stability, then is cast back.
        attention_weights_float = attention_weights.float()
        attention_weights_float = attention_weights_float.masked_fill(attention_mask.unsqueeze(0), self.negative_inf)
        T = attention_weights.size(1)
        B = attention_weights.size(0) // self.num_heads
        if padding_mask is not None:
            # Reshape to (B, heads, T, key_len) so the per-batch padding mask broadcasts.
            attention_weights_float = attention_weights_float.view(B, self.num_heads, T, -1)
            attention_weights_float = attention_weights_float.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf
            )
            attention_weights_float = attention_weights_float.view(B * self.num_heads, T, -1)
        attention_probs = torch.nn.functional.softmax(attention_weights_float, dim=-1).type_as(attention_weights)
        return torch.nn.functional.dropout(attention_probs, p=float(self.dropout), training=self.training)

    def _forward_impl(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: torch.Tensor,
        left_context_key: Optional[torch.Tensor] = None,
        left_context_val: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Shared implementation for training (forward) and streaming (infer).
        B = utterance.size(1)
        T = right_context.size(0) + utterance.size(0) + summary.size(0)

        # Compute query with [right context, utterance, summary].
        query = self.emb_to_query(torch.cat([right_context, utterance, summary]))

        # Compute key and value with [mems, right context, utterance].
        key, value = self.emb_to_key_value(torch.cat([mems, right_context, utterance])).chunk(chunks=2, dim=2)

        if left_context_key is not None and left_context_val is not None:
            # Streaming path: splice the cached left-context keys/values between
            # the [mems, right context] rows and the utterance rows.
            right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
            key = torch.cat(
                [
                    key[: mems.size(0) + right_context_blocks_length],
                    left_context_key,
                    key[mems.size(0) + right_context_blocks_length :],
                ],
            )
            value = torch.cat(
                [
                    value[: mems.size(0) + right_context_blocks_length],
                    left_context_val,
                    value[mems.size(0) + right_context_blocks_length :],
                ],
            )

        # Compute attention weights from query, key, and value.
        # Fold the head dimension into the batch dimension: (B * heads, len, head_dim).
        reshaped_query, reshaped_key, reshaped_value = [
            tensor.contiguous().view(-1, B * self.num_heads, self.input_dim // self.num_heads).transpose(0, 1)
            for tensor in [query, key, value]
        ]
        attention_weights = torch.bmm(reshaped_query * self.scaling, reshaped_key.transpose(1, 2))

        # Compute padding mask.
        padding_mask = _gen_padding_mask(utterance, right_context, summary, lengths, mems, left_context_key)

        # Compute attention probabilities.
        attention_probs = self._gen_attention_probs(attention_weights, attention_mask, padding_mask)

        # Compute attention.
        attention = torch.bmm(attention_probs, reshaped_value)
        if attention.shape != (
            B * self.num_heads,
            T,
            self.input_dim // self.num_heads,
        ):
            raise AssertionError("Computed attention has incorrect dimensions")
        attention = attention.transpose(0, 1).contiguous().view(T, B, self.input_dim)

        # Apply output projection.
        output_right_context_mems = self.out_proj(attention)

        # Split the projected output back into [right context + utterance] and
        # the summary-derived memory rows.
        summary_length = summary.size(0)
        output_right_context = output_right_context_mems[: T - summary_length]
        output_mems = output_right_context_mems[T - summary_length :]
        if self.tanh_on_mem:
            output_mems = torch.tanh(output_mems)
        else:
            # Clamp keeps memory activations bounded when tanh is disabled.
            output_mems = torch.clamp(output_mems, min=-10, max=10)

        return output_right_context, output_mems, key, value

    def forward(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        S: number of summary elements;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
            attention_mask (torch.Tensor): attention mask for underlying attention module.

        Returns:
            (Tensor, Tensor):
                Tensor
                    output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
        """
        # Drop the last memory element before returning.
        output, output_mems, _, _ = self._forward_impl(utterance, lengths, right_context, summary, mems, attention_mask)
        return output, output_mems[:-1]

    @torch.jit.export
    def infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        mems: torch.Tensor,
        left_context_key: torch.Tensor,
        left_context_val: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass for inference.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        S: number of summary elements;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
            left_context_key (torch.Tensor): left context attention key computed from preceding invocation.
            left_context_val (torch.Tensor): left context attention value computed from preceding invocation.

        Returns:
            (Tensor, Tensor, Tensor, and Tensor):
                Tensor
                    output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
                Tensor
                    attention key computed for left context and utterance.
                Tensor
                    attention value computed for left context and utterance.
        """
        query_dim = right_context.size(0) + utterance.size(0) + summary.size(0)
        key_dim = right_context.size(0) + utterance.size(0) + mems.size(0) + left_context_key.size(0)
        attention_mask = torch.zeros(query_dim, key_dim).to(dtype=torch.bool, device=utterance.device)
        # Block the summary query (last row per _forward_impl's query ordering)
        # from attending to the memory positions.
        attention_mask[-1, : mems.size(0)] = True
        output, output_mems, key, value = self._forward_impl(
            utterance,
            lengths,
            right_context,
            summary,
            mems,
            attention_mask,
            left_context_key=left_context_key,
            left_context_val=left_context_val,
        )
        # Return only the key/value rows past [mems, right context] — these cover
        # left context and utterance and are cached for the next streaming step.
        return (
            output,
            output_mems,
            key[mems.size(0) + right_context.size(0) :],
            value[mems.size(0) + right_context.size(0) :],
        )
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class _EmformerLayer(torch.nn.Module):
|
| 320 |
+
r"""Emformer layer that constitutes Emformer.
|
| 321 |
+
|
| 322 |
+
Args:
|
| 323 |
+
input_dim (int): input dimension.
|
| 324 |
+
num_heads (int): number of attention heads.
|
| 325 |
+
ffn_dim: (int): hidden layer dimension of feedforward network.
|
| 326 |
+
segment_length (int): length of each input segment.
|
| 327 |
+
dropout (float, optional): dropout probability. (Default: 0.0)
|
| 328 |
+
activation (str, optional): activation function to use in feedforward network.
|
| 329 |
+
Must be one of ("relu", "gelu", "silu"). (Default: "relu")
|
| 330 |
+
left_context_length (int, optional): length of left context. (Default: 0)
|
| 331 |
+
max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
|
| 332 |
+
weight_init_gain (float or None, optional): scale factor to apply when initializing
|
| 333 |
+
attention module parameters. (Default: ``None``)
|
| 334 |
+
tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
|
| 335 |
+
negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
|
| 336 |
+
"""
|
| 337 |
+
|
| 338 |
+
    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        segment_length: int,
        dropout: float = 0.0,
        activation: str = "relu",
        left_context_length: int = 0,
        max_memory_size: int = 0,
        weight_init_gain: Optional[float] = None,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        super().__init__()

        self.attention = _EmformerAttention(
            input_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            weight_init_gain=weight_init_gain,
            tanh_on_mem=tanh_on_mem,
            negative_inf=negative_inf,
        )
        self.dropout = torch.nn.Dropout(dropout)
        # Average-pools each segment of the utterance into one summary vector
        # (ceil_mode so a trailing partial segment still yields a vector).
        self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)

        activation_module = _get_activation_module(activation)
        # Pre-norm position-wise feed-forward block: LN -> Linear -> act -> dropout -> Linear -> dropout.
        self.pos_ff = torch.nn.Sequential(
            torch.nn.LayerNorm(input_dim),
            torch.nn.Linear(input_dim, ffn_dim),
            activation_module,
            torch.nn.Dropout(dropout),
            torch.nn.Linear(ffn_dim, input_dim),
            torch.nn.Dropout(dropout),
        )
        self.layer_norm_input = torch.nn.LayerNorm(input_dim)
        self.layer_norm_output = torch.nn.LayerNorm(input_dim)

        self.left_context_length = left_context_length
        self.segment_length = segment_length
        self.max_memory_size = max_memory_size
        self.input_dim = input_dim

        # Memory bank is only active when a positive capacity is configured.
        self.use_mem = max_memory_size > 0
|
| 383 |
+
|
| 384 |
+
def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
|
| 385 |
+
empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device)
|
| 386 |
+
left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
|
| 387 |
+
left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
|
| 388 |
+
past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
|
| 389 |
+
return [empty_memory, left_context_key, left_context_val, past_length]
|
| 390 |
+
|
| 391 |
+
def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 392 |
+
past_length = state[3][0][0].item()
|
| 393 |
+
past_left_context_length = min(self.left_context_length, past_length)
|
| 394 |
+
past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
|
| 395 |
+
pre_mems = state[0][self.max_memory_size - past_mem_length :]
|
| 396 |
+
lc_key = state[1][self.left_context_length - past_left_context_length :]
|
| 397 |
+
lc_val = state[2][self.left_context_length - past_left_context_length :]
|
| 398 |
+
return pre_mems, lc_key, lc_val
|
| 399 |
+
|
| 400 |
+
def _pack_state(
|
| 401 |
+
self,
|
| 402 |
+
next_k: torch.Tensor,
|
| 403 |
+
next_v: torch.Tensor,
|
| 404 |
+
update_length: int,
|
| 405 |
+
mems: torch.Tensor,
|
| 406 |
+
state: List[torch.Tensor],
|
| 407 |
+
) -> List[torch.Tensor]:
|
| 408 |
+
new_k = torch.cat([state[1], next_k])
|
| 409 |
+
new_v = torch.cat([state[2], next_v])
|
| 410 |
+
state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
|
| 411 |
+
state[1] = new_k[new_k.shape[0] - self.left_context_length :]
|
| 412 |
+
state[2] = new_v[new_v.shape[0] - self.left_context_length :]
|
| 413 |
+
state[3] = state[3] + update_length
|
| 414 |
+
return state
|
| 415 |
+
|
| 416 |
+
def _process_attention_output(
|
| 417 |
+
self,
|
| 418 |
+
rc_output: torch.Tensor,
|
| 419 |
+
utterance: torch.Tensor,
|
| 420 |
+
right_context: torch.Tensor,
|
| 421 |
+
) -> torch.Tensor:
|
| 422 |
+
result = self.dropout(rc_output) + torch.cat([right_context, utterance])
|
| 423 |
+
result = self.pos_ff(result) + result
|
| 424 |
+
result = self.layer_norm_output(result)
|
| 425 |
+
return result
|
| 426 |
+
|
| 427 |
+
def _apply_pre_attention_layer_norm(
|
| 428 |
+
self, utterance: torch.Tensor, right_context: torch.Tensor
|
| 429 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 430 |
+
layer_norm_input = self.layer_norm_input(torch.cat([right_context, utterance]))
|
| 431 |
+
return (
|
| 432 |
+
layer_norm_input[right_context.size(0) :],
|
| 433 |
+
layer_norm_input[: right_context.size(0)],
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
def _apply_post_attention_ffn(
|
| 437 |
+
self, rc_output: torch.Tensor, utterance: torch.Tensor, right_context: torch.Tensor
|
| 438 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 439 |
+
rc_output = self._process_attention_output(rc_output, utterance, right_context)
|
| 440 |
+
return rc_output[right_context.size(0) :], rc_output[: right_context.size(0)]
|
| 441 |
+
|
| 442 |
+
def _apply_attention_forward(
|
| 443 |
+
self,
|
| 444 |
+
utterance: torch.Tensor,
|
| 445 |
+
lengths: torch.Tensor,
|
| 446 |
+
right_context: torch.Tensor,
|
| 447 |
+
mems: torch.Tensor,
|
| 448 |
+
attention_mask: Optional[torch.Tensor],
|
| 449 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 450 |
+
if attention_mask is None:
|
| 451 |
+
raise ValueError("attention_mask must be not None when for_inference is False")
|
| 452 |
+
|
| 453 |
+
if self.use_mem:
|
| 454 |
+
summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
|
| 455 |
+
else:
|
| 456 |
+
summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
|
| 457 |
+
rc_output, next_m = self.attention(
|
| 458 |
+
utterance=utterance,
|
| 459 |
+
lengths=lengths,
|
| 460 |
+
right_context=right_context,
|
| 461 |
+
summary=summary,
|
| 462 |
+
mems=mems,
|
| 463 |
+
attention_mask=attention_mask,
|
| 464 |
+
)
|
| 465 |
+
return rc_output, next_m
|
| 466 |
+
|
| 467 |
+
    def _apply_attention_infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        mems: torch.Tensor,
        state: Optional[List[torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        # Streaming attention step: consume, use, and update the per-layer cache ``state``.
        if state is None:
            # First chunk: start from an all-zero cache.
            state = self._init_state(utterance.size(1), device=utterance.device)
        pre_mems, lc_key, lc_val = self._unpack_state(state)
        if self.use_mem:
            summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
            # Only the first summary vector is used per inference step.
            summary = summary[:1]
        else:
            summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
        rc_output, next_m, next_k, next_v = self.attention.infer(
            utterance=utterance,
            lengths=lengths,
            right_context=right_context,
            summary=summary,
            mems=pre_mems,
            left_context_key=lc_key,
            left_context_val=lc_val,
        )
        # Cache the new keys/values and memory, advancing the frame counter by the chunk size.
        state = self._pack_state(next_k, next_v, utterance.size(0), mems, state)
        return rc_output, next_m, state
|
| 494 |
+
|
| 495 |
+
def forward(
    self,
    utterance: torch.Tensor,
    lengths: torch.Tensor,
    right_context: torch.Tensor,
    mems: torch.Tensor,
    attention_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Forward pass for training.

    B: batch size;
    D: feature dimension of each frame;
    T: number of utterance frames;
    R: number of right context frames;
    M: number of memory elements.

    Args:
        utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
        lengths (torch.Tensor): with shape `(B,)` and i-th element representing
            number of valid frames for i-th batch element in ``utterance``.
        right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
        mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
        attention_mask (torch.Tensor): attention mask for underlying attention module.

    Returns:
        (Tensor, Tensor, Tensor):
            Tensor
                encoded utterance frames, with shape `(T, B, D)`.
            Tensor
                updated right context frames, with shape `(R, B, D)`.
            Tensor
                updated memory elements, with shape `(M, B, D)`.
    """
    # Pre-norm transformer layout: normalize before attention, then apply
    # attention and the feedforward block with residual connections.
    (
        layer_norm_utterance,
        layer_norm_right_context,
    ) = self._apply_pre_attention_layer_norm(utterance, right_context)
    rc_output, output_mems = self._apply_attention_forward(
        layer_norm_utterance,
        lengths,
        layer_norm_right_context,
        mems,
        attention_mask,
    )
    output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
    return output_utterance, output_right_context, output_mems
|
| 541 |
+
|
| 542 |
+
@torch.jit.export
def infer(
    self,
    utterance: torch.Tensor,
    lengths: torch.Tensor,
    right_context: torch.Tensor,
    state: Optional[List[torch.Tensor]],
    mems: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
    r"""Forward pass for inference.

    B: batch size;
    D: feature dimension of each frame;
    T: number of utterance frames;
    R: number of right context frames;
    M: number of memory elements.

    Args:
        utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
        lengths (torch.Tensor): with shape `(B,)` and i-th element representing
            number of valid frames for i-th batch element in ``utterance``.
        right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
        state (List[torch.Tensor] or None): list of tensors representing layer internal state
            generated in preceding invocation of ``infer``.
        mems (torch.Tensor): memory elements, with shape `(M, B, D)`.

    Returns:
        (Tensor, Tensor, List[torch.Tensor], Tensor):
            Tensor
                encoded utterance frames, with shape `(T, B, D)`.
            Tensor
                updated right context frames, with shape `(R, B, D)`.
            List[Tensor]
                list of tensors representing layer internal state
                generated in current invocation of ``infer``.
            Tensor
                updated memory elements, with shape `(M, B, D)`.
    """
    (
        layer_norm_utterance,
        layer_norm_right_context,
    ) = self._apply_pre_attention_layer_norm(utterance, right_context)
    rc_output, output_mems, output_state = self._apply_attention_infer(
        layer_norm_utterance, lengths, layer_norm_right_context, mems, state
    )
    output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
    return output_utterance, output_right_context, output_state, output_mems
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
class _EmformerImpl(torch.nn.Module):
|
| 592 |
+
def __init__(
|
| 593 |
+
self,
|
| 594 |
+
emformer_layers: torch.nn.ModuleList,
|
| 595 |
+
segment_length: int,
|
| 596 |
+
left_context_length: int = 0,
|
| 597 |
+
right_context_length: int = 0,
|
| 598 |
+
max_memory_size: int = 0,
|
| 599 |
+
):
|
| 600 |
+
super().__init__()
|
| 601 |
+
|
| 602 |
+
self.use_mem = max_memory_size > 0
|
| 603 |
+
self.memory_op = torch.nn.AvgPool1d(
|
| 604 |
+
kernel_size=segment_length,
|
| 605 |
+
stride=segment_length,
|
| 606 |
+
ceil_mode=True,
|
| 607 |
+
)
|
| 608 |
+
self.emformer_layers = emformer_layers
|
| 609 |
+
self.left_context_length = left_context_length
|
| 610 |
+
self.right_context_length = right_context_length
|
| 611 |
+
self.segment_length = segment_length
|
| 612 |
+
self.max_memory_size = max_memory_size
|
| 613 |
+
|
| 614 |
+
def _gen_right_context(self, input: torch.Tensor) -> torch.Tensor:
|
| 615 |
+
T = input.shape[0]
|
| 616 |
+
num_segs = math.ceil((T - self.right_context_length) / self.segment_length)
|
| 617 |
+
right_context_blocks = []
|
| 618 |
+
for seg_idx in range(num_segs - 1):
|
| 619 |
+
start = (seg_idx + 1) * self.segment_length
|
| 620 |
+
end = start + self.right_context_length
|
| 621 |
+
right_context_blocks.append(input[start:end])
|
| 622 |
+
right_context_blocks.append(input[T - self.right_context_length :])
|
| 623 |
+
return torch.cat(right_context_blocks)
|
| 624 |
+
|
| 625 |
+
def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) -> List[int]:
|
| 626 |
+
num_segs = math.ceil(utterance_length / self.segment_length)
|
| 627 |
+
rc = self.right_context_length
|
| 628 |
+
lc = self.left_context_length
|
| 629 |
+
rc_start = seg_idx * rc
|
| 630 |
+
rc_end = rc_start + rc
|
| 631 |
+
seg_start = max(seg_idx * self.segment_length - lc, 0)
|
| 632 |
+
seg_end = min((seg_idx + 1) * self.segment_length, utterance_length)
|
| 633 |
+
rc_length = self.right_context_length * num_segs
|
| 634 |
+
|
| 635 |
+
if self.use_mem:
|
| 636 |
+
m_start = max(seg_idx - self.max_memory_size, 0)
|
| 637 |
+
mem_length = num_segs - 1
|
| 638 |
+
col_widths = [
|
| 639 |
+
m_start, # before memory
|
| 640 |
+
seg_idx - m_start, # memory
|
| 641 |
+
mem_length - seg_idx, # after memory
|
| 642 |
+
rc_start, # before right context
|
| 643 |
+
rc, # right context
|
| 644 |
+
rc_length - rc_end, # after right context
|
| 645 |
+
seg_start, # before query segment
|
| 646 |
+
seg_end - seg_start, # query segment
|
| 647 |
+
utterance_length - seg_end, # after query segment
|
| 648 |
+
]
|
| 649 |
+
else:
|
| 650 |
+
col_widths = [
|
| 651 |
+
rc_start, # before right context
|
| 652 |
+
rc, # right context
|
| 653 |
+
rc_length - rc_end, # after right context
|
| 654 |
+
seg_start, # before query segment
|
| 655 |
+
seg_end - seg_start, # query segment
|
| 656 |
+
utterance_length - seg_end, # after query segment
|
| 657 |
+
]
|
| 658 |
+
|
| 659 |
+
return col_widths
|
| 660 |
+
|
| 661 |
+
def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor:
|
| 662 |
+
utterance_length = input.size(0)
|
| 663 |
+
num_segs = math.ceil(utterance_length / self.segment_length)
|
| 664 |
+
|
| 665 |
+
rc_mask = []
|
| 666 |
+
query_mask = []
|
| 667 |
+
summary_mask = []
|
| 668 |
+
|
| 669 |
+
if self.use_mem:
|
| 670 |
+
num_cols = 9
|
| 671 |
+
# memory, right context, query segment
|
| 672 |
+
rc_q_cols_mask = [idx in [1, 4, 7] for idx in range(num_cols)]
|
| 673 |
+
# right context, query segment
|
| 674 |
+
s_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
|
| 675 |
+
masks_to_concat = [rc_mask, query_mask, summary_mask]
|
| 676 |
+
else:
|
| 677 |
+
num_cols = 6
|
| 678 |
+
# right context, query segment
|
| 679 |
+
rc_q_cols_mask = [idx in [1, 4] for idx in range(num_cols)]
|
| 680 |
+
s_cols_mask = None
|
| 681 |
+
masks_to_concat = [rc_mask, query_mask]
|
| 682 |
+
|
| 683 |
+
for seg_idx in range(num_segs):
|
| 684 |
+
col_widths = self._gen_attention_mask_col_widths(seg_idx, utterance_length)
|
| 685 |
+
|
| 686 |
+
rc_mask_block = _gen_attention_mask_block(
|
| 687 |
+
col_widths, rc_q_cols_mask, self.right_context_length, input.device
|
| 688 |
+
)
|
| 689 |
+
rc_mask.append(rc_mask_block)
|
| 690 |
+
|
| 691 |
+
query_mask_block = _gen_attention_mask_block(
|
| 692 |
+
col_widths,
|
| 693 |
+
rc_q_cols_mask,
|
| 694 |
+
min(
|
| 695 |
+
self.segment_length,
|
| 696 |
+
utterance_length - seg_idx * self.segment_length,
|
| 697 |
+
),
|
| 698 |
+
input.device,
|
| 699 |
+
)
|
| 700 |
+
query_mask.append(query_mask_block)
|
| 701 |
+
|
| 702 |
+
if s_cols_mask is not None:
|
| 703 |
+
summary_mask_block = _gen_attention_mask_block(col_widths, s_cols_mask, 1, input.device)
|
| 704 |
+
summary_mask.append(summary_mask_block)
|
| 705 |
+
|
| 706 |
+
attention_mask = (1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])).to(torch.bool)
|
| 707 |
+
return attention_mask
|
| 708 |
+
|
| 709 |
+
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 710 |
+
r"""Forward pass for training and non-streaming inference.
|
| 711 |
+
|
| 712 |
+
B: batch size;
|
| 713 |
+
T: max number of input frames in batch;
|
| 714 |
+
D: feature dimension of each frame.
|
| 715 |
+
|
| 716 |
+
Args:
|
| 717 |
+
input (torch.Tensor): utterance frames right-padded with right context frames, with
|
| 718 |
+
shape `(B, T + right_context_length, D)`.
|
| 719 |
+
lengths (torch.Tensor): with shape `(B,)` and i-th element representing
|
| 720 |
+
number of valid utterance frames for i-th batch element in ``input``.
|
| 721 |
+
|
| 722 |
+
Returns:
|
| 723 |
+
(Tensor, Tensor):
|
| 724 |
+
Tensor
|
| 725 |
+
output frames, with shape `(B, T, D)`.
|
| 726 |
+
Tensor
|
| 727 |
+
output lengths, with shape `(B,)` and i-th element representing
|
| 728 |
+
number of valid frames for i-th batch element in output frames.
|
| 729 |
+
"""
|
| 730 |
+
input = input.permute(1, 0, 2)
|
| 731 |
+
right_context = self._gen_right_context(input)
|
| 732 |
+
utterance = input[: input.size(0) - self.right_context_length]
|
| 733 |
+
attention_mask = self._gen_attention_mask(utterance)
|
| 734 |
+
mems = (
|
| 735 |
+
self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1]
|
| 736 |
+
if self.use_mem
|
| 737 |
+
else torch.empty(0).to(dtype=input.dtype, device=input.device)
|
| 738 |
+
)
|
| 739 |
+
output = utterance
|
| 740 |
+
for layer in self.emformer_layers:
|
| 741 |
+
output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask)
|
| 742 |
+
return output.permute(1, 0, 2), lengths
|
| 743 |
+
|
| 744 |
+
@torch.jit.export
|
| 745 |
+
def infer(
|
| 746 |
+
self,
|
| 747 |
+
input: torch.Tensor,
|
| 748 |
+
lengths: torch.Tensor,
|
| 749 |
+
states: Optional[List[List[torch.Tensor]]] = None,
|
| 750 |
+
) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
|
| 751 |
+
r"""Forward pass for streaming inference.
|
| 752 |
+
|
| 753 |
+
B: batch size;
|
| 754 |
+
D: feature dimension of each frame.
|
| 755 |
+
|
| 756 |
+
Args:
|
| 757 |
+
input (torch.Tensor): utterance frames right-padded with right context frames, with
|
| 758 |
+
shape `(B, segment_length + right_context_length, D)`.
|
| 759 |
+
lengths (torch.Tensor): with shape `(B,)` and i-th element representing
|
| 760 |
+
number of valid frames for i-th batch element in ``input``.
|
| 761 |
+
states (List[List[torch.Tensor]] or None, optional): list of lists of tensors
|
| 762 |
+
representing internal state generated in preceding invocation of ``infer``. (Default: ``None``)
|
| 763 |
+
|
| 764 |
+
Returns:
|
| 765 |
+
(Tensor, Tensor, List[List[Tensor]]):
|
| 766 |
+
Tensor
|
| 767 |
+
output frames, with shape `(B, segment_length, D)`.
|
| 768 |
+
Tensor
|
| 769 |
+
output lengths, with shape `(B,)` and i-th element representing
|
| 770 |
+
number of valid frames for i-th batch element in output frames.
|
| 771 |
+
List[List[Tensor]]
|
| 772 |
+
output states; list of lists of tensors representing internal state
|
| 773 |
+
generated in current invocation of ``infer``.
|
| 774 |
+
"""
|
| 775 |
+
if input.size(1) != self.segment_length + self.right_context_length:
|
| 776 |
+
raise ValueError(
|
| 777 |
+
"Per configured segment_length and right_context_length"
|
| 778 |
+
f", expected size of {self.segment_length + self.right_context_length} for dimension 1 of input"
|
| 779 |
+
f", but got {input.size(1)}."
|
| 780 |
+
)
|
| 781 |
+
input = input.permute(1, 0, 2)
|
| 782 |
+
right_context_start_idx = input.size(0) - self.right_context_length
|
| 783 |
+
right_context = input[right_context_start_idx:]
|
| 784 |
+
utterance = input[:right_context_start_idx]
|
| 785 |
+
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
|
| 786 |
+
mems = (
|
| 787 |
+
self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
|
| 788 |
+
if self.use_mem
|
| 789 |
+
else torch.empty(0).to(dtype=input.dtype, device=input.device)
|
| 790 |
+
)
|
| 791 |
+
output = utterance
|
| 792 |
+
output_states: List[List[torch.Tensor]] = []
|
| 793 |
+
for layer_idx, layer in enumerate(self.emformer_layers):
|
| 794 |
+
output, right_context, output_state, mems = layer.infer(
|
| 795 |
+
output,
|
| 796 |
+
output_lengths,
|
| 797 |
+
right_context,
|
| 798 |
+
None if states is None else states[layer_idx],
|
| 799 |
+
mems,
|
| 800 |
+
)
|
| 801 |
+
output_states.append(output_state)
|
| 802 |
+
|
| 803 |
+
return output.permute(1, 0, 2), output_lengths, output_states
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
class Emformer(_EmformerImpl):
    r"""Emformer architecture introduced in
    *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition*
    :cite:`shi2021emformer`.

    See Also:
        * :func:`~torchaudio.models.emformer_rnnt_model`,
          :func:`~torchaudio.models.emformer_rnnt_base`: factory functions.
        * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipelines with pretrained model.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Emformer layer.
        ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        num_layers (int): number of Emformer layers to instantiate.
        segment_length (int): length of each input segment.
        dropout (float, optional): dropout probability. (Default: 0.0)
        activation (str, optional): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        left_context_length (int, optional): length of left context. (Default: 0)
        right_context_length (int, optional): length of right context. (Default: 0)
        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)

    Examples:
        >>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
        >>> input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
        >>> lengths = torch.randint(1, 200, (128,))  # batch
        >>> output, lengths = emformer(input, lengths)
        >>> input = torch.rand(128, 5, 512)
        >>> lengths = torch.ones(128) * 5
        >>> output, lengths, states = emformer.infer(input, lengths, None)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        num_layers: int,
        segment_length: int,
        dropout: float = 0.0,
        activation: str = "relu",
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
        weight_init_scale_strategy: Optional[str] = "depthwise",
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        # Per-layer initialization gains (e.g. depthwise scaling) are
        # computed once and passed to each layer by index.
        weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
        emformer_layers = torch.nn.ModuleList(
            [
                _EmformerLayer(
                    input_dim,
                    num_heads,
                    ffn_dim,
                    segment_length,
                    dropout=dropout,
                    activation=activation,
                    left_context_length=left_context_length,
                    max_memory_size=max_memory_size,
                    weight_init_gain=weight_init_gains[layer_idx],
                    tanh_on_mem=tanh_on_mem,
                    negative_inf=negative_inf,
                )
                for layer_idx in range(num_layers)
            ]
        )
        super().__init__(
            emformer_layers,
            segment_length,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=max_memory_size,
        )
|
.venv/lib/python3.11/site-packages/torchaudio/models/rnnt.py
ADDED
|
@@ -0,0 +1,816 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torchaudio.models import Emformer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
__all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class _TimeReduction(torch.nn.Module):
|
| 12 |
+
r"""Coalesces frames along time dimension into a
|
| 13 |
+
fewer number of frames with higher feature dimensionality.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
stride (int): number of frames to merge for each output frame.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, stride: int) -> None:
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.stride = stride
|
| 22 |
+
|
| 23 |
+
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 24 |
+
r"""Forward pass.
|
| 25 |
+
|
| 26 |
+
B: batch size;
|
| 27 |
+
T: maximum input sequence length in batch;
|
| 28 |
+
D: feature dimension of each input sequence frame.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
input (torch.Tensor): input sequences, with shape `(B, T, D)`.
|
| 32 |
+
lengths (torch.Tensor): with shape `(B,)` and i-th element representing
|
| 33 |
+
number of valid frames for i-th batch element in ``input``.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
(torch.Tensor, torch.Tensor):
|
| 37 |
+
torch.Tensor
|
| 38 |
+
output sequences, with shape
|
| 39 |
+
`(B, T // stride, D * stride)`
|
| 40 |
+
torch.Tensor
|
| 41 |
+
output lengths, with shape `(B,)` and i-th element representing
|
| 42 |
+
number of valid frames for i-th batch element in output sequences.
|
| 43 |
+
"""
|
| 44 |
+
B, T, D = input.shape
|
| 45 |
+
num_frames = T - (T % self.stride)
|
| 46 |
+
input = input[:, :num_frames, :]
|
| 47 |
+
lengths = lengths.div(self.stride, rounding_mode="trunc")
|
| 48 |
+
T_max = num_frames // self.stride
|
| 49 |
+
|
| 50 |
+
output = input.reshape(B, T_max, D * self.stride)
|
| 51 |
+
output = output.contiguous()
|
| 52 |
+
return output, lengths
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class _CustomLSTM(torch.nn.Module):
|
| 56 |
+
r"""Custom long-short-term memory (LSTM) block that applies layer normalization
|
| 57 |
+
to internal nodes.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
input_dim (int): input dimension.
|
| 61 |
+
hidden_dim (int): hidden dimension.
|
| 62 |
+
layer_norm (bool, optional): if ``True``, enables layer normalization. (Default: ``False``)
|
| 63 |
+
layer_norm_epsilon (float, optional): value of epsilon to use in
|
| 64 |
+
layer normalization layers (Default: 1e-5)
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
input_dim: int,
|
| 70 |
+
hidden_dim: int,
|
| 71 |
+
layer_norm: bool = False,
|
| 72 |
+
layer_norm_epsilon: float = 1e-5,
|
| 73 |
+
) -> None:
|
| 74 |
+
super().__init__()
|
| 75 |
+
self.x2g = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=(not layer_norm))
|
| 76 |
+
self.p2g = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=False)
|
| 77 |
+
if layer_norm:
|
| 78 |
+
self.c_norm = torch.nn.LayerNorm(hidden_dim, eps=layer_norm_epsilon)
|
| 79 |
+
self.g_norm = torch.nn.LayerNorm(4 * hidden_dim, eps=layer_norm_epsilon)
|
| 80 |
+
else:
|
| 81 |
+
self.c_norm = torch.nn.Identity()
|
| 82 |
+
self.g_norm = torch.nn.Identity()
|
| 83 |
+
|
| 84 |
+
self.hidden_dim = hidden_dim
|
| 85 |
+
|
| 86 |
+
def forward(
|
| 87 |
+
self, input: torch.Tensor, state: Optional[List[torch.Tensor]]
|
| 88 |
+
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
| 89 |
+
r"""Forward pass.
|
| 90 |
+
|
| 91 |
+
B: batch size;
|
| 92 |
+
T: maximum sequence length in batch;
|
| 93 |
+
D: feature dimension of each input sequence element.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
input (torch.Tensor): with shape `(T, B, D)`.
|
| 97 |
+
state (List[torch.Tensor] or None): list of tensors
|
| 98 |
+
representing internal state generated in preceding invocation
|
| 99 |
+
of ``forward``.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
(torch.Tensor, List[torch.Tensor]):
|
| 103 |
+
torch.Tensor
|
| 104 |
+
output, with shape `(T, B, hidden_dim)`.
|
| 105 |
+
List[torch.Tensor]
|
| 106 |
+
list of tensors representing internal state generated
|
| 107 |
+
in current invocation of ``forward``.
|
| 108 |
+
"""
|
| 109 |
+
if state is None:
|
| 110 |
+
B = input.size(1)
|
| 111 |
+
h = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
|
| 112 |
+
c = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
|
| 113 |
+
else:
|
| 114 |
+
h, c = state
|
| 115 |
+
|
| 116 |
+
gated_input = self.x2g(input)
|
| 117 |
+
outputs = []
|
| 118 |
+
for gates in gated_input.unbind(0):
|
| 119 |
+
gates = gates + self.p2g(h)
|
| 120 |
+
gates = self.g_norm(gates)
|
| 121 |
+
input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1)
|
| 122 |
+
input_gate = input_gate.sigmoid()
|
| 123 |
+
forget_gate = forget_gate.sigmoid()
|
| 124 |
+
cell_gate = cell_gate.tanh()
|
| 125 |
+
output_gate = output_gate.sigmoid()
|
| 126 |
+
c = forget_gate * c + input_gate * cell_gate
|
| 127 |
+
c = self.c_norm(c)
|
| 128 |
+
h = output_gate * c.tanh()
|
| 129 |
+
outputs.append(h)
|
| 130 |
+
|
| 131 |
+
output = torch.stack(outputs, dim=0)
|
| 132 |
+
state = [h, c]
|
| 133 |
+
|
| 134 |
+
return output, state
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class _Transcriber(ABC):
|
| 138 |
+
@abstractmethod
|
| 139 |
+
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 140 |
+
pass
|
| 141 |
+
|
| 142 |
+
@abstractmethod
|
| 143 |
+
def infer(
|
| 144 |
+
self,
|
| 145 |
+
input: torch.Tensor,
|
| 146 |
+
lengths: torch.Tensor,
|
| 147 |
+
states: Optional[List[List[torch.Tensor]]],
|
| 148 |
+
) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
|
| 149 |
+
pass
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class _EmformerEncoder(torch.nn.Module, _Transcriber):
|
| 153 |
+
r"""Emformer-based recurrent neural network transducer (RNN-T) encoder (transcription network).
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
input_dim (int): feature dimension of each input sequence element.
|
| 157 |
+
output_dim (int): feature dimension of each output sequence element.
|
| 158 |
+
segment_length (int): length of input segment expressed as number of frames.
|
| 159 |
+
right_context_length (int): length of right context expressed as number of frames.
|
| 160 |
+
time_reduction_input_dim (int): dimension to scale each element in input sequences to
|
| 161 |
+
prior to applying time reduction block.
|
| 162 |
+
time_reduction_stride (int): factor by which to reduce length of input sequence.
|
| 163 |
+
transformer_num_heads (int): number of attention heads in each Emformer layer.
|
| 164 |
+
transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
|
| 165 |
+
transformer_num_layers (int): number of Emformer layers to instantiate.
|
| 166 |
+
transformer_left_context_length (int): length of left context.
|
| 167 |
+
transformer_dropout (float, optional): transformer dropout probability. (Default: 0.0)
|
| 168 |
+
transformer_activation (str, optional): activation function to use in each Emformer layer's
|
| 169 |
+
feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
|
| 170 |
+
transformer_max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
|
| 171 |
+
transformer_weight_init_scale_strategy (str, optional): per-layer weight initialization scaling
|
| 172 |
+
strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
|
| 173 |
+
transformer_tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
def __init__(
|
| 177 |
+
self,
|
| 178 |
+
*,
|
| 179 |
+
input_dim: int,
|
| 180 |
+
output_dim: int,
|
| 181 |
+
segment_length: int,
|
| 182 |
+
right_context_length: int,
|
| 183 |
+
time_reduction_input_dim: int,
|
| 184 |
+
time_reduction_stride: int,
|
| 185 |
+
transformer_num_heads: int,
|
| 186 |
+
transformer_ffn_dim: int,
|
| 187 |
+
transformer_num_layers: int,
|
| 188 |
+
transformer_left_context_length: int,
|
| 189 |
+
transformer_dropout: float = 0.0,
|
| 190 |
+
transformer_activation: str = "relu",
|
| 191 |
+
transformer_max_memory_size: int = 0,
|
| 192 |
+
transformer_weight_init_scale_strategy: str = "depthwise",
|
| 193 |
+
transformer_tanh_on_mem: bool = False,
|
| 194 |
+
) -> None:
|
| 195 |
+
super().__init__()
|
| 196 |
+
self.input_linear = torch.nn.Linear(
|
| 197 |
+
input_dim,
|
| 198 |
+
time_reduction_input_dim,
|
| 199 |
+
bias=False,
|
| 200 |
+
)
|
| 201 |
+
self.time_reduction = _TimeReduction(time_reduction_stride)
|
| 202 |
+
transformer_input_dim = time_reduction_input_dim * time_reduction_stride
|
| 203 |
+
self.transformer = Emformer(
|
| 204 |
+
transformer_input_dim,
|
| 205 |
+
transformer_num_heads,
|
| 206 |
+
transformer_ffn_dim,
|
| 207 |
+
transformer_num_layers,
|
| 208 |
+
segment_length // time_reduction_stride,
|
| 209 |
+
dropout=transformer_dropout,
|
| 210 |
+
activation=transformer_activation,
|
| 211 |
+
left_context_length=transformer_left_context_length,
|
| 212 |
+
right_context_length=right_context_length // time_reduction_stride,
|
| 213 |
+
max_memory_size=transformer_max_memory_size,
|
| 214 |
+
weight_init_scale_strategy=transformer_weight_init_scale_strategy,
|
| 215 |
+
tanh_on_mem=transformer_tanh_on_mem,
|
| 216 |
+
)
|
| 217 |
+
self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
|
| 218 |
+
self.layer_norm = torch.nn.LayerNorm(output_dim)
|
| 219 |
+
|
| 220 |
+
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 221 |
+
r"""Forward pass for training.
|
| 222 |
+
|
| 223 |
+
B: batch size;
|
| 224 |
+
T: maximum input sequence length in batch;
|
| 225 |
+
D: feature dimension of each input sequence frame (input_dim).
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
input (torch.Tensor): input frame sequences right-padded with right context, with
|
| 229 |
+
shape `(B, T + right context length, D)`.
|
| 230 |
+
lengths (torch.Tensor): with shape `(B,)` and i-th element representing
|
| 231 |
+
number of valid frames for i-th batch element in ``input``.
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
(torch.Tensor, torch.Tensor):
|
| 235 |
+
torch.Tensor
|
| 236 |
+
output frame sequences, with
|
| 237 |
+
shape `(B, T // time_reduction_stride, output_dim)`.
|
| 238 |
+
torch.Tensor
|
| 239 |
+
output input lengths, with shape `(B,)` and i-th element representing
|
| 240 |
+
number of valid elements for i-th batch element in output frame sequences.
|
| 241 |
+
"""
|
| 242 |
+
input_linear_out = self.input_linear(input)
|
| 243 |
+
time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
|
| 244 |
+
transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths)
|
| 245 |
+
output_linear_out = self.output_linear(transformer_out)
|
| 246 |
+
layer_norm_out = self.layer_norm(output_linear_out)
|
| 247 |
+
return layer_norm_out, transformer_lengths
|
| 248 |
+
|
| 249 |
+
@torch.jit.export
|
| 250 |
+
def infer(
|
| 251 |
+
self,
|
| 252 |
+
input: torch.Tensor,
|
| 253 |
+
lengths: torch.Tensor,
|
| 254 |
+
states: Optional[List[List[torch.Tensor]]],
|
| 255 |
+
) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
|
| 256 |
+
r"""Forward pass for inference.
|
| 257 |
+
|
| 258 |
+
B: batch size;
|
| 259 |
+
T: maximum input sequence segment length in batch;
|
| 260 |
+
D: feature dimension of each input sequence frame (input_dim).
|
| 261 |
+
|
| 262 |
+
Args:
|
| 263 |
+
input (torch.Tensor): input frame sequence segments right-padded with right context, with
|
| 264 |
+
shape `(B, T + right context length, D)`.
|
| 265 |
+
lengths (torch.Tensor): with shape `(B,)` and i-th element representing
|
| 266 |
+
number of valid frames for i-th batch element in ``input``.
|
| 267 |
+
state (List[List[torch.Tensor]] or None): list of lists of tensors
|
| 268 |
+
representing internal state generated in preceding invocation
|
| 269 |
+
of ``infer``.
|
| 270 |
+
|
| 271 |
+
Returns:
|
| 272 |
+
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
|
| 273 |
+
torch.Tensor
|
| 274 |
+
output frame sequences, with
|
| 275 |
+
shape `(B, T // time_reduction_stride, output_dim)`.
|
| 276 |
+
torch.Tensor
|
| 277 |
+
output input lengths, with shape `(B,)` and i-th element representing
|
| 278 |
+
number of valid elements for i-th batch element in output.
|
| 279 |
+
List[List[torch.Tensor]]
|
| 280 |
+
output states; list of lists of tensors
|
| 281 |
+
representing internal state generated in current invocation
|
| 282 |
+
of ``infer``.
|
| 283 |
+
"""
|
| 284 |
+
input_linear_out = self.input_linear(input)
|
| 285 |
+
time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
|
| 286 |
+
(
|
| 287 |
+
transformer_out,
|
| 288 |
+
transformer_lengths,
|
| 289 |
+
transformer_states,
|
| 290 |
+
) = self.transformer.infer(time_reduction_out, time_reduction_lengths, states)
|
| 291 |
+
output_linear_out = self.output_linear(transformer_out)
|
| 292 |
+
layer_norm_out = self.layer_norm(output_linear_out)
|
| 293 |
+
return layer_norm_out, transformer_lengths, transformer_states
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class _Predictor(torch.nn.Module):
    r"""Recurrent neural network transducer (RNN-T) prediction network.

    Args:
        num_symbols (int): size of target token lexicon.
        output_dim (int): feature dimension of each output sequence element.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_hidden_dim (int): output dimension of each LSTM layer.
        lstm_layer_norm (bool, optional): if ``True``, enables layer normalization
            for LSTM layers. (Default: ``False``)
        lstm_layer_norm_epsilon (float, optional): value of epsilon to use in
            LSTM layer normalization layers. (Default: 1e-5)
        lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0)
    """

    def __init__(
        self,
        num_symbols: int,
        output_dim: int,
        symbol_embedding_dim: int,
        num_lstm_layers: int,
        lstm_hidden_dim: int,
        lstm_layer_norm: bool = False,
        lstm_layer_norm_epsilon: float = 1e-5,
        lstm_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim)
        self.input_layer_norm = torch.nn.LayerNorm(symbol_embedding_dim)
        # First layer consumes embeddings; subsequent layers consume LSTM outputs.
        lstm_stack = []
        for layer in range(num_lstm_layers):
            in_dim = symbol_embedding_dim if layer == 0 else lstm_hidden_dim
            lstm_stack.append(
                _CustomLSTM(
                    in_dim,
                    lstm_hidden_dim,
                    layer_norm=lstm_layer_norm,
                    layer_norm_epsilon=lstm_layer_norm_epsilon,
                )
            )
        self.lstm_layers = torch.nn.ModuleList(lstm_stack)
        self.dropout = torch.nn.Dropout(p=lstm_dropout)
        self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim)
        self.output_layer_norm = torch.nn.LayerNorm(output_dim)

        self.lstm_dropout = lstm_dropout

    def forward(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass.

        B: batch size;
        U: maximum sequence length in batch;
        D: feature dimension of each input sequence element.

        Args:
            input (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing internal state generated in a preceding invocation
                of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output encoding sequences, with shape `(B, U, output_dim)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in the output.
                List[List[torch.Tensor]]
                    output states to pass to the next invocation of ``forward``.
        """
        # LSTMs run time-major; transpose to (U, B) here and back at the end.
        tokens_time_major = input.permute(1, 0)
        hidden = self.input_layer_norm(self.embedding(tokens_time_major))

        state_out: List[List[torch.Tensor]] = []
        for layer_idx, lstm in enumerate(self.lstm_layers):
            layer_state = None if state is None else state[layer_idx]
            hidden, new_layer_state = lstm(hidden, layer_state)
            hidden = self.dropout(hidden)
            state_out.append(new_layer_state)

        projected = self.output_layer_norm(self.linear(hidden))
        return projected.permute(1, 0, 2), lengths, state_out
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class _Joiner(torch.nn.Module):
|
| 393 |
+
r"""Recurrent neural network transducer (RNN-T) joint network.
|
| 394 |
+
|
| 395 |
+
Args:
|
| 396 |
+
input_dim (int): source and target input dimension.
|
| 397 |
+
output_dim (int): output dimension.
|
| 398 |
+
activation (str, optional): activation function to use in the joiner.
|
| 399 |
+
Must be one of ("relu", "tanh"). (Default: "relu")
|
| 400 |
+
|
| 401 |
+
"""
|
| 402 |
+
|
| 403 |
+
def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
|
| 404 |
+
super().__init__()
|
| 405 |
+
self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
|
| 406 |
+
if activation == "relu":
|
| 407 |
+
self.activation = torch.nn.ReLU()
|
| 408 |
+
elif activation == "tanh":
|
| 409 |
+
self.activation = torch.nn.Tanh()
|
| 410 |
+
else:
|
| 411 |
+
raise ValueError(f"Unsupported activation {activation}")
|
| 412 |
+
|
| 413 |
+
def forward(
|
| 414 |
+
self,
|
| 415 |
+
source_encodings: torch.Tensor,
|
| 416 |
+
source_lengths: torch.Tensor,
|
| 417 |
+
target_encodings: torch.Tensor,
|
| 418 |
+
target_lengths: torch.Tensor,
|
| 419 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 420 |
+
r"""Forward pass for training.
|
| 421 |
+
|
| 422 |
+
B: batch size;
|
| 423 |
+
T: maximum source sequence length in batch;
|
| 424 |
+
U: maximum target sequence length in batch;
|
| 425 |
+
D: dimension of each source and target sequence encoding.
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
source_encodings (torch.Tensor): source encoding sequences, with
|
| 429 |
+
shape `(B, T, D)`.
|
| 430 |
+
source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
|
| 431 |
+
valid sequence length of i-th batch element in ``source_encodings``.
|
| 432 |
+
target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
|
| 433 |
+
target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
|
| 434 |
+
valid sequence length of i-th batch element in ``target_encodings``.
|
| 435 |
+
|
| 436 |
+
Returns:
|
| 437 |
+
(torch.Tensor, torch.Tensor, torch.Tensor):
|
| 438 |
+
torch.Tensor
|
| 439 |
+
joint network output, with shape `(B, T, U, output_dim)`.
|
| 440 |
+
torch.Tensor
|
| 441 |
+
output source lengths, with shape `(B,)` and i-th element representing
|
| 442 |
+
number of valid elements along dim 1 for i-th batch element in joint network output.
|
| 443 |
+
torch.Tensor
|
| 444 |
+
output target lengths, with shape `(B,)` and i-th element representing
|
| 445 |
+
number of valid elements along dim 2 for i-th batch element in joint network output.
|
| 446 |
+
"""
|
| 447 |
+
joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
|
| 448 |
+
activation_out = self.activation(joint_encodings)
|
| 449 |
+
output = self.linear(activation_out)
|
| 450 |
+
return output, source_lengths, target_lengths
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
class RNNT(torch.nn.Module):
    r"""torchaudio.models.RNNT()

    Recurrent neural network transducer (RNN-T) model.

    Note:
        To build the model, please use one of the factory functions.

    See Also:
        :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pre-trained models.

    Args:
        transcriber (torch.nn.Module): transcription network.
        predictor (torch.nn.Module): prediction network.
        joiner (torch.nn.Module): joint network.
    """

    def __init__(self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner) -> None:
        super().__init__()
        self.transcriber = transcriber
        self.predictor = predictor
        self.joiner = joiner

    def forward(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        predictor_state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: feature dimension of each source sequence element.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context,
                with shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``targets``.
            predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of
                tensors representing prediction network internal state generated in a
                preceding invocation of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    joint network output, with shape
                    `(B, max output source length, max output target length, output_dim)`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 of the joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 of the joint network output.
                List[List[torch.Tensor]]
                    output prediction network states to pass to the next invocation.
        """
        # Run the two encoders independently, then fuse their outputs in the joiner.
        source_encodings, enc_source_lengths = self.transcriber(
            input=sources,
            lengths=source_lengths,
        )
        target_encodings, enc_target_lengths, next_predictor_state = self.predictor(
            input=targets,
            lengths=target_lengths,
            state=predictor_state,
        )
        joint_output, out_source_lengths, out_target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=enc_source_lengths,
            target_encodings=target_encodings,
            target_lengths=enc_target_lengths,
        )
        return joint_output, out_source_lengths, out_target_lengths, next_predictor_state

    @torch.jit.export
    def transcribe_streaming(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies transcription network to sources in streaming mode.

        B: batch size;
        T: maximum source sequence segment length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequence segments right-padded with right
                context, with shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors representing
                transcription network internal state generated in a preceding invocation
                of ``transcribe_streaming``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with shape
                    `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in the output.
                List[List[torch.Tensor]]
                    output transcription network states to pass to the next invocation.
        """
        return self.transcriber.infer(sources, source_lengths, state)

    @torch.jit.export
    def transcribe(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Applies transcription network to sources in non-streaming mode.

        B: batch size;
        T: maximum source sequence length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context,
                with shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with shape
                    `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in the output.
        """
        return self.transcriber(sources, source_lengths)

    @torch.jit.export
    def predict(
        self,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies prediction network to targets.

        B: batch size;
        U: maximum target sequence length in batch;
        D: feature dimension of each target sequence frame.

        Args:
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``targets``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors representing
                internal state generated in a preceding invocation of ``predict``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with shape `(B, U, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in the output.
                List[List[torch.Tensor]]
                    output states to pass to the next invocation of ``predict``.
        """
        return self.predictor(input=targets, lengths=target_lengths, state=state)

    @torch.jit.export
    def join(
        self,
        source_encodings: torch.Tensor,
        source_lengths: torch.Tensor,
        target_encodings: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Applies joint network to source and target encodings.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: dimension of each source and target sequence encoding.

        Args:
            source_encodings (torch.Tensor): source encoding sequences, with shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``source_encodings``.
            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``target_encodings``.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor):
                torch.Tensor
                    joint network output, with shape `(B, T, U, output_dim)`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 of the joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 of the joint network output.
        """
        joint_output, out_source_lengths, out_target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )
        return joint_output, out_source_lengths, out_target_lengths
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
def emformer_rnnt_model(
    *,
    input_dim: int,
    encoding_dim: int,
    num_symbols: int,
    segment_length: int,
    right_context_length: int,
    time_reduction_input_dim: int,
    time_reduction_stride: int,
    transformer_num_heads: int,
    transformer_ffn_dim: int,
    transformer_num_layers: int,
    transformer_dropout: float,
    transformer_activation: str,
    transformer_left_context_length: int,
    transformer_max_memory_size: int,
    transformer_weight_init_scale_strategy: str,
    transformer_tanh_on_mem: bool,
    symbol_embedding_dim: int,
    num_lstm_layers: int,
    lstm_layer_norm: bool,
    lstm_layer_norm_epsilon: float,
    lstm_dropout: float,
) -> RNNT:
    r"""Builds Emformer-based :class:`~torchaudio.models.RNNT`.

    Note:
        For non-streaming inference, the expectation is for `transcribe` to be called on input
        sequences right-concatenated with `right_context_length` frames.

        For streaming inference, the expectation is for `transcribe_streaming` to be called
        on input chunks comprising `segment_length` frames right-concatenated with
        `right_context_length` frames.

    Args:
        input_dim (int): dimension of input sequence frames passed to transcription network.
        encoding_dim (int): dimension of transcription- and prediction-network-generated
            encodings passed to joint network.
        num_symbols (int): cardinality of set of target tokens.
        segment_length (int): length of input segment expressed as number of frames.
        right_context_length (int): length of right context expressed as number of frames.
        time_reduction_input_dim (int): dimension to scale each element in input sequences to
            prior to applying time reduction block.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's
            feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_dropout (float): Emformer dropout probability.
        transformer_activation (str): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu").
        transformer_left_context_length (int): length of left context considered by Emformer.
        transformer_max_memory_size (int): maximum number of memory elements to use.
        transformer_weight_init_scale_strategy (str): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``).
        transformer_tanh_on_mem (bool): if ``True``, applies tanh to memory elements.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer
            normalization layers.
        lstm_dropout (float): LSTM dropout probability.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    # Transcription network: Emformer encoder over time-reduced input features.
    transcription_net = _EmformerEncoder(
        input_dim=input_dim,
        output_dim=encoding_dim,
        segment_length=segment_length,
        right_context_length=right_context_length,
        time_reduction_input_dim=time_reduction_input_dim,
        time_reduction_stride=time_reduction_stride,
        transformer_num_heads=transformer_num_heads,
        transformer_ffn_dim=transformer_ffn_dim,
        transformer_num_layers=transformer_num_layers,
        transformer_dropout=transformer_dropout,
        transformer_activation=transformer_activation,
        transformer_left_context_length=transformer_left_context_length,
        transformer_max_memory_size=transformer_max_memory_size,
        transformer_weight_init_scale_strategy=transformer_weight_init_scale_strategy,
        transformer_tanh_on_mem=transformer_tanh_on_mem,
    )
    # Prediction network: embedding + LSTM stack; hidden dim matches the embedding dim.
    prediction_net = _Predictor(
        num_symbols,
        encoding_dim,
        symbol_embedding_dim=symbol_embedding_dim,
        num_lstm_layers=num_lstm_layers,
        lstm_hidden_dim=symbol_embedding_dim,
        lstm_layer_norm=lstm_layer_norm,
        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
        lstm_dropout=lstm_dropout,
    )
    # Joint network fuses the two encodings into per-(frame, token) logits.
    joint_net = _Joiner(encoding_dim, num_symbols)
    return RNNT(transcription_net, prediction_net, joint_net)
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
def emformer_rnnt_base(num_symbols: int) -> RNNT:
    r"""Builds basic version of Emformer-based :class:`~torchaudio.models.RNNT`.

    Args:
        num_symbols (int): The size of target token lexicon.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    # Canonical hyperparameters for the base Emformer RNN-T configuration.
    base_config = {
        "input_dim": 80,
        "encoding_dim": 1024,
        "segment_length": 16,
        "right_context_length": 4,
        "time_reduction_input_dim": 128,
        "time_reduction_stride": 4,
        "transformer_num_heads": 8,
        "transformer_ffn_dim": 2048,
        "transformer_num_layers": 20,
        "transformer_dropout": 0.1,
        "transformer_activation": "gelu",
        "transformer_left_context_length": 30,
        "transformer_max_memory_size": 0,
        "transformer_weight_init_scale_strategy": "depthwise",
        "transformer_tanh_on_mem": True,
        "symbol_embedding_dim": 512,
        "num_lstm_layers": 3,
        "lstm_layer_norm": True,
        "lstm_layer_norm_epsilon": 1e-3,
        "lstm_dropout": 0.3,
    }
    return emformer_rnnt_model(num_symbols=num_symbols, **base_config)
|
.venv/lib/python3.11/site-packages/torchaudio/models/rnnt_decoder.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Dict, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torchaudio.models import RNNT
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
__all__ = ["Hypothesis", "RNNTBeamSearch"]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]
|
| 11 |
+
Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
|
| 12 |
+
represented as tuple of (tokens, prediction network output, prediction network state, score).
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
    # Token ids occupy slot 0 of the hypothesis tuple.
    tokens, _, _, _ = hypo
    return tokens
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
    # Prediction network output occupies slot 1 of the hypothesis tuple.
    _, predictor_out, _, _ = hypo
    return predictor_out
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
    # Prediction network state occupies slot 2 of the hypothesis tuple.
    _, _, state, _ = hypo
    return state
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _get_hypo_score(hypo: Hypothesis) -> float:
    # Accumulated log-probability score occupies slot 3 of the hypothesis tuple.
    _, _, _, score = hypo
    return score
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _get_hypo_key(hypo: Hypothesis) -> str:
    # Dedup key: the string form of the token list uniquely identifies a hypothesis.
    tokens = hypo[0]
    return str(tokens)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
|
| 37 |
+
states: List[List[torch.Tensor]] = []
|
| 38 |
+
for i in range(len(_get_hypo_state(hypos[0]))):
|
| 39 |
+
batched_state_components: List[torch.Tensor] = []
|
| 40 |
+
for j in range(len(_get_hypo_state(hypos[0])[i])):
|
| 41 |
+
batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
|
| 42 |
+
states.append(batched_state_components)
|
| 43 |
+
return states
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
|
| 47 |
+
idx_tensor = torch.tensor([idx], device=device)
|
| 48 |
+
return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _default_hypo_sort_key(hypo: Hypothesis) -> float:
|
| 52 |
+
return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _compute_updated_scores(
|
| 56 |
+
hypos: List[Hypothesis],
|
| 57 |
+
next_token_probs: torch.Tensor,
|
| 58 |
+
beam_width: int,
|
| 59 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 60 |
+
hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
|
| 61 |
+
nonblank_scores = hypo_scores + next_token_probs[:, :-1] # [beam_width, num_tokens - 1]
|
| 62 |
+
nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
|
| 63 |
+
nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
|
| 64 |
+
nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
|
| 65 |
+
return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
|
| 69 |
+
for i, elem in enumerate(hypo_list):
|
| 70 |
+
if _get_hypo_key(hypo) == _get_hypo_key(elem):
|
| 71 |
+
del hypo_list[i]
|
| 72 |
+
break
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class RNNTBeamSearch(torch.nn.Module):
    r"""Beam search decoder for RNN-T model.

    See Also:
        * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pretrained model.

    Args:
        model (RNNT): RNN-T model to use.
        blank (int): index of blank token in vocabulary.
        temperature (float, optional): temperature to apply to joint network output.
            Larger values yield more uniform samples. (Default: 1.0)
        hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
            for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
            hypothesis score normalized by token sequence length. (Default: None)
        step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
    """

    def __init__(
        self,
        model: RNNT,
        blank: int,
        temperature: float = 1.0,
        hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
        step_max_tokens: int = 100,
    ) -> None:
        super().__init__()
        self.model = model
        self.blank = blank
        self.temperature = temperature

        if hypo_sort_key is None:
            self.hypo_sort_key = _default_hypo_sort_key
        else:
            self.hypo_sort_key = hypo_sort_key

        self.step_max_tokens = step_max_tokens

    def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
        """Seed the search with a single hypothesis: the blank token, score 0 (log-prob 1)."""
        token = self.blank
        state = None

        one_tensor = torch.tensor([1], device=device)
        pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
        init_hypo = (
            [token],
            pred_out[0].detach(),
            pred_state,
            0.0,
        )
        return [init_hypo]

    def _gen_next_token_probs(
        self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
    ) -> torch.Tensor:
        """Run the joiner once over all hypotheses; return per-hypothesis token log-probs."""
        one_tensor = torch.tensor([1], device=device)
        # Batch the cached predictor outputs so a single joiner call scores the whole beam.
        predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
        joined_out, _, _ = self.model.join(
            enc_out,
            one_tensor,
            predictor_out,
            torch.tensor([1] * len(hypos), device=device),
        )  # [beam_width, 1, 1, num_tokens]
        # Temperature > 1 flattens the distribution; log-probs let hypothesis
        # scores accumulate by simple addition.
        joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
        return joined_out[:, 0, 0]

    def _gen_b_hypos(
        self,
        b_hypos: List[Hypothesis],
        a_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        key_to_b_hypo: Dict[str, Hypothesis],
    ) -> List[Hypothesis]:
        """Extend each active hypothesis with blank, merging duplicates by log-sum-exp."""
        for i in range(len(a_hypos)):
            h_a = a_hypos[i]
            # Blank log-prob is the last column of the joiner output.
            append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
            if _get_hypo_key(h_a) in key_to_b_hypo:
                # Same token sequence already reached via another path: combine
                # the probabilities (log-sum-exp) instead of keeping duplicates.
                h_b = key_to_b_hypo[_get_hypo_key(h_a)]
                _remove_hypo(h_b, b_hypos)
                score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
            else:
                score = float(append_blank_score)
            h_b = (
                _get_hypo_tokens(h_a),
                _get_hypo_predictor_out(h_a),
                _get_hypo_state(h_a),
                score,
            )
            b_hypos.append(h_b)
            key_to_b_hypo[_get_hypo_key(h_b)] = h_b
        # Keep b_hypos sorted by score ascending; callers rely on the best
        # hypotheses being at the end of the list.
        _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
        return [b_hypos[idx] for idx in sorted_idx]

    def _gen_a_hypos(
        self,
        a_hypos: List[Hypothesis],
        b_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        t: int,
        beam_width: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Expand hypotheses with non-blank tokens, pruning against the blank-terminated beam."""
        (
            nonblank_nbest_scores,
            nonblank_nbest_hypo_idx,
            nonblank_nbest_token,
        ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)

        if len(b_hypos) < beam_width:
            b_nbest_score = -float("inf")
        else:
            # b_hypos is sorted ascending, so [-beam_width] is the worst score
            # still inside the beam; candidates must beat it to survive.
            b_nbest_score = _get_hypo_score(b_hypos[-beam_width])

        base_hypos: List[Hypothesis] = []
        new_tokens: List[int] = []
        new_scores: List[float] = []
        for i in range(beam_width):
            score = float(nonblank_nbest_scores[i])
            if score > b_nbest_score:
                a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
                base_hypos.append(a_hypos[a_hypo_idx])
                new_tokens.append(int(nonblank_nbest_token[i]))
                new_scores.append(score)

        if base_hypos:
            new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
        else:
            new_hypos: List[Hypothesis] = []

        return new_hypos

    def _gen_new_hypos(
        self,
        base_hypos: List[Hypothesis],
        tokens: List[int],
        scores: List[float],
        t: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Run the predictor once over all extensions and build the resulting hypotheses."""
        # NOTE(review): ``t`` is unused in this method — presumably kept for
        # interface symmetry with _gen_a_hypos; confirm before removing.
        tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
        # Batch all base states so one predictor call serves every new hypothesis.
        states = _batch_state(base_hypos)
        pred_out, _, pred_states = self.model.predict(
            tgt_tokens,
            torch.tensor([1] * len(base_hypos), device=device),
            states,
        )
        new_hypos: List[Hypothesis] = []
        for i, h_a in enumerate(base_hypos):
            new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
            new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
        return new_hypos

    def _search(
        self,
        enc_out: torch.Tensor,
        hypo: Optional[List[Hypothesis]],
        beam_width: int,
    ) -> List[Hypothesis]:
        """Core beam search over the encoder output frames.

        Maintains two hypothesis sets per time step: ``a_hypos`` (still being
        extended with non-blank tokens within the current frame) and ``b_hypos``
        (ended the frame with a blank and may advance to the next frame).
        """
        n_time_steps = enc_out.shape[1]
        device = enc_out.device

        a_hypos: List[Hypothesis] = []
        b_hypos = self._init_b_hypos(device) if hypo is None else hypo
        for t in range(n_time_steps):
            a_hypos = b_hypos
            b_hypos = torch.jit.annotate(List[Hypothesis], [])
            key_to_b_hypo: Dict[str, Hypothesis] = {}
            symbols_current_t = 0

            while a_hypos:
                next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
                next_token_probs = next_token_probs.cpu()
                b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)

                # Cap the number of symbols emitted per frame to bound runtime.
                if symbols_current_t == self.step_max_tokens:
                    break

                a_hypos = self._gen_a_hypos(
                    a_hypos,
                    b_hypos,
                    next_token_probs,
                    t,
                    beam_width,
                    device,
                )
                if a_hypos:
                    symbols_current_t += 1

            # Keep only the top-``beam_width`` blank-terminated hypotheses,
            # ranked by the user-provided (or default length-normalized) key.
            _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(beam_width)
            b_hypos = [b_hypos[idx] for idx in sorted_idx]

        return b_hypos

    def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
        r"""Performs beam search for the given input sequence.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.

        Raises:
            ValueError: If ``input`` or ``length`` has an unsupported shape.

        Returns:
            List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
        """
        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
            raise ValueError("input must be of shape (T, D) or (1, T, D)")
        if input.dim() == 2:
            input = input.unsqueeze(0)

        if length.shape != () and length.shape != (1,):
            raise ValueError("length must be of shape () or (1,)")
        if length.dim() == 0:
            length = length.unsqueeze(0)

        enc_out, _ = self.model.transcribe(input, length)
        return self._search(enc_out, None, beam_width)

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        length: torch.Tensor,
        beam_width: int,
        state: Optional[List[List[torch.Tensor]]] = None,
        hypothesis: Optional[List[Hypothesis]] = None,
    ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
        r"""Performs beam search for the given input sequence in streaming mode.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing transcription network internal state generated in preceding
                invocation. (Default: ``None``)
            hypothesis (List[Hypothesis] or None): hypotheses from preceding invocation to seed
                search with. (Default: ``None``)

        Raises:
            ValueError: If ``input`` or ``length`` has an unsupported shape.

        Returns:
            (List[Hypothesis], List[List[torch.Tensor]]):
                List[Hypothesis]
                    top-``beam_width`` hypotheses found by beam search.
                List[List[torch.Tensor]]
                    list of lists of tensors representing transcription network
                    internal state generated in current invocation.
        """
        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
            raise ValueError("input must be of shape (T, D) or (1, T, D)")
        if input.dim() == 2:
            input = input.unsqueeze(0)

        if length.shape != () and length.shape != (1,):
            raise ValueError("length must be of shape () or (1,)")
        if length.dim() == 0:
            length = length.unsqueeze(0)

        enc_out, _, state = self.model.transcribe_streaming(input, length, state)
        return self._search(enc_out, hypothesis, beam_width), state
|
.venv/lib/python3.11/site-packages/torchaudio/models/tacotron2.py
ADDED
|
@@ -0,0 +1,1046 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# *****************************************************************************
|
| 2 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Redistribution and use in source and binary forms, with or without
|
| 5 |
+
# modification, are permitted provided that the following conditions are met:
|
| 6 |
+
# * Redistributions of source code must retain the above copyright
|
| 7 |
+
# notice, this list of conditions and the following disclaimer.
|
| 8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
| 9 |
+
# notice, this list of conditions and the following disclaimer in the
|
| 10 |
+
# documentation and/or other materials provided with the distribution.
|
| 11 |
+
# * Neither the name of the NVIDIA CORPORATION nor the
|
| 12 |
+
# names of its contributors may be used to endorse or promote products
|
| 13 |
+
# derived from this software without specific prior written permission.
|
| 14 |
+
#
|
| 15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
| 16 |
+
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
| 17 |
+
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 18 |
+
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
| 19 |
+
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
| 20 |
+
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
| 21 |
+
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
| 22 |
+
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 23 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
| 24 |
+
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 25 |
+
#
|
| 26 |
+
# *****************************************************************************
|
| 27 |
+
|
| 28 |
+
import warnings
|
| 29 |
+
from typing import List, Optional, Tuple, Union
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
from torch import nn, Tensor
|
| 33 |
+
from torch.nn import functional as F
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
__all__ = [
|
| 37 |
+
"Tacotron2",
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _get_linear_layer(in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear") -> torch.nn.Linear:
|
| 42 |
+
r"""Linear layer with xavier uniform initialization.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
in_dim (int): Size of each input sample.
|
| 46 |
+
out_dim (int): Size of each output sample.
|
| 47 |
+
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias. (Default: ``True``)
|
| 48 |
+
w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
|
| 49 |
+
for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
(torch.nn.Linear): The corresponding linear layer.
|
| 53 |
+
"""
|
| 54 |
+
linear = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
| 55 |
+
torch.nn.init.xavier_uniform_(linear.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
|
| 56 |
+
return linear
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _get_conv1d_layer(
    in_channels: int,
    out_channels: int,
    kernel_size: int = 1,
    stride: int = 1,
    padding: Optional[Union[str, int, Tuple[int]]] = None,
    dilation: int = 1,
    bias: bool = True,
    w_init_gain: str = "linear",
) -> torch.nn.Conv1d:
    r"""1D convolution with xavier uniform initialization.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int, optional): Size of the convolving kernel; must be odd
            when ``padding`` is ``None`` so that "same" padding can be derived. (Default: ``1``)
        stride (int, optional): Stride of the convolution. (Default: ``1``)
        padding (str, int or tuple, optional): Padding added to both sides of the input.
            If ``None``, computed as ``dilation * (kernel_size - 1) / 2``. (Default: ``None``)
        dilation (int, optional): Spacing between kernel elements. (Default: ``1``)
        bias (bool, optional): If ``True``, adds a learnable bias to the output. (Default: ``True``)
        w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
            for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)

    Raises:
        ValueError: If ``padding`` is ``None`` and ``kernel_size`` is even.

    Returns:
        (torch.nn.Conv1d): The corresponding Conv1D layer.
    """
    if padding is None:
        # "Same" padding only exists for odd kernels.
        if kernel_size % 2 != 1:
            raise ValueError("kernel_size must be odd")
        padding = int(dilation * (kernel_size - 1) / 2)

    conv1d = torch.nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
    )

    torch.nn.init.xavier_uniform_(conv1d.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    return conv1d
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _get_mask_from_lengths(lengths: Tensor) -> Tensor:
|
| 106 |
+
r"""Returns a binary mask based on ``lengths``. The ``i``-th row and ``j``-th column of the mask
|
| 107 |
+
is ``1`` if ``j`` is smaller than ``i``-th element of ``lengths.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
lengths (Tensor): The length of each element in the batch, with shape (n_batch, ).
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
mask (Tensor): The binary mask, with shape (n_batch, max of ``lengths``).
|
| 114 |
+
"""
|
| 115 |
+
max_len = torch.max(lengths).item()
|
| 116 |
+
ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype)
|
| 117 |
+
mask = (ids < lengths.unsqueeze(1)).byte()
|
| 118 |
+
mask = torch.le(mask, 0)
|
| 119 |
+
return mask
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class _LocationLayer(nn.Module):
    r"""Location layer used in the Attention model.

    Projects the concatenated previous/cumulative attention weights through a
    1D convolution followed by a linear layer.

    Args:
        attention_n_filter (int): Number of filters for attention model.
        attention_kernel_size (int): Kernel size for attention model.
        attention_hidden_dim (int): Dimension of attention hidden representation.
    """

    def __init__(
        self,
        attention_n_filter: int,
        attention_kernel_size: int,
        attention_hidden_dim: int,
    ):
        super().__init__()
        # "Same" padding so the sequence length is preserved by the convolution.
        same_padding = (attention_kernel_size - 1) // 2
        self.location_conv = _get_conv1d_layer(
            2,
            attention_n_filter,
            kernel_size=attention_kernel_size,
            padding=same_padding,
            bias=False,
            stride=1,
            dilation=1,
        )
        self.location_dense = _get_linear_layer(
            attention_n_filter, attention_hidden_dim, bias=False, w_init_gain="tanh"
        )

    def forward(self, attention_weights_cat: Tensor) -> Tensor:
        r"""Project attention weights into the attention hidden space.

        Args:
            attention_weights_cat (Tensor): Cumulative and previous attention weights
                with shape (n_batch, 2, max of ``text_lengths``).

        Returns:
            processed_attention (Tensor): Processed attention weights
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
        """
        # (n_batch, attention_n_filter, text_lengths.max()) -> transpose for the
        # linear layer, which acts on the last dimension.
        conv_features = self.location_conv(attention_weights_cat).transpose(1, 2)
        return self.location_dense(conv_features)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class _Attention(nn.Module):
    r"""Locally sensitive attention model.

    Args:
        attention_rnn_dim (int): Number of hidden units for RNN.
        encoder_embedding_dim (int): Number of embedding dimensions in the Encoder.
        attention_hidden_dim (int): Dimension of attention hidden representation.
        attention_location_n_filter (int): Number of filters for Attention model.
        attention_location_kernel_size (int): Kernel size for Attention model.
    """

    def __init__(
        self,
        attention_rnn_dim: int,
        encoder_embedding_dim: int,
        attention_hidden_dim: int,
        attention_location_n_filter: int,
        attention_location_kernel_size: int,
    ) -> None:
        super().__init__()
        self.query_layer = _get_linear_layer(attention_rnn_dim, attention_hidden_dim, bias=False, w_init_gain="tanh")
        self.memory_layer = _get_linear_layer(
            encoder_embedding_dim, attention_hidden_dim, bias=False, w_init_gain="tanh"
        )
        self.v = _get_linear_layer(attention_hidden_dim, 1, bias=False)
        self.location_layer = _LocationLayer(
            attention_location_n_filter,
            attention_location_kernel_size,
            attention_hidden_dim,
        )
        # Fill value for padded positions: -inf before softmax yields weight 0.
        self.score_mask_value = -float("inf")

    def _get_alignment_energies(self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor) -> Tensor:
        r"""Get the alignment vector.

        Args:
            query (Tensor): Decoder output with shape (n_batch, n_mels * n_frames_per_step).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, attention_hidden_dim).
            attention_weights_cat (Tensor): Cumulative and previous attention weights
                with shape (n_batch, 2, max of ``text_lengths``).

        Returns:
            alignment (Tensor): attention weights, it is a tensor with shape (batch, max of ``text_lengths``).
        """

        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        # Additive scoring: combine query, location and memory features, then
        # project to a scalar energy per encoder position.
        energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_memory))

        alignment = energies.squeeze(2)
        return alignment

    def forward(
        self,
        attention_hidden_state: Tensor,
        memory: Tensor,
        processed_memory: Tensor,
        attention_weights_cat: Tensor,
        mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        r"""Pass the input through the Attention model.

        Args:
            attention_hidden_state (Tensor): Attention rnn last output with shape (n_batch, ``attention_rnn_dim``).
            memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
            attention_weights_cat (Tensor): Previous and cumulative attention weights
                with shape (n_batch, current_num_frames * 2, max of ``text_lengths``).
            mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames).

        Returns:
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
        """
        alignment = self._get_alignment_energies(attention_hidden_state, processed_memory, attention_weights_cat)

        # Mask padded positions before the softmax so they receive zero weight.
        alignment = alignment.masked_fill(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        # Weighted sum of encoder outputs: (n_batch, 1, T) x (n_batch, T, E).
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class _Prenet(nn.Module):
    r"""Prenet module: a stack of ``len(out_sizes)`` linear layers with ReLU and dropout.

    Dropout is applied with ``training=True``, so it stays active even when the
    module is in eval mode.

    Args:
        in_dim (int): The size of each input sample.
        out_sizes (list): The output dimension of each linear layer.
    """

    def __init__(self, in_dim: int, out_sizes: List[int]) -> None:
        super().__init__()
        # Each layer's input size is the previous layer's output size.
        input_sizes = [in_dim] + out_sizes[:-1]
        self.layers = nn.ModuleList(
            [_get_linear_layer(n_in, n_out, bias=False) for (n_in, n_out) in zip(input_sizes, out_sizes)]
        )

    def forward(self, x: Tensor) -> Tensor:
        r"""Pass the input through the Prenet.

        Args:
            x (Tensor): The input sequence to the Prenet with shape (n_batch, in_dim).

        Return:
            x (Tensor): Tensor with shape (n_batch, out_sizes[-1]).
        """

        for dense in self.layers:
            x = F.dropout(F.relu(dense(x)), p=0.5, training=True)
        return x
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class _Postnet(nn.Module):
    r"""Postnet Module.

    Args:
        n_mels (int): Number of mel bins.
        postnet_embedding_dim (int): Postnet embedding dimension.
        postnet_kernel_size (int): Postnet kernel size.
        postnet_n_convolution (int): Number of postnet convolutions.
    """

    def __init__(
        self,
        n_mels: int,
        postnet_embedding_dim: int,
        postnet_kernel_size: int,
        postnet_n_convolution: int,
    ):
        super().__init__()
        self.convolutions = nn.ModuleList()

        last_idx = postnet_n_convolution - 1
        for i in range(postnet_n_convolution):
            is_first = i == 0
            is_last = i == last_idx
            # First layer maps from mel bins, last layer maps back to mel bins;
            # everything in between stays at the embedding width.
            in_channels = n_mels if is_first else postnet_embedding_dim
            out_channels = n_mels if is_last else postnet_embedding_dim
            init_gain = "linear" if is_last else "tanh"
            num_features = n_mels if is_last else postnet_embedding_dim
            conv = _get_conv1d_layer(
                in_channels,
                out_channels,
                kernel_size=postnet_kernel_size,
                stride=1,
                padding=int((postnet_kernel_size - 1) / 2),
                dilation=1,
                w_init_gain=init_gain,
            )
            self.convolutions.append(nn.Sequential(conv, nn.BatchNorm1d(num_features)))

        self.n_convs = len(self.convolutions)

    def forward(self, x: Tensor) -> Tensor:
        r"""Pass the input through Postnet.

        Args:
            x (Tensor): The input sequence with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).

        Return:
            x (Tensor): Tensor with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
        """
        last_idx = self.n_convs - 1
        for i, conv in enumerate(self.convolutions):
            y = conv(x)
            # tanh on every layer except the final (linear-gain) one.
            if i != last_idx:
                y = torch.tanh(y)
            x = F.dropout(y, 0.5, training=self.training)

        return x
class _Encoder(nn.Module):
    r"""Encoder Module.

    Args:
        encoder_embedding_dim (int): Number of embedding dimensions in the encoder.
        encoder_n_convolution (int): Number of convolution layers in the encoder.
        encoder_kernel_size (int): The kernel size in the encoder.

    Examples
        >>> encoder = _Encoder(3, 512, 5)
        >>> input = torch.rand(10, 20, 30)
        >>> output = encoder(input)  # shape: (10, 30, 512)
    """

    def __init__(
        self,
        encoder_embedding_dim: int,
        encoder_n_convolution: int,
        encoder_kernel_size: int,
    ) -> None:
        super().__init__()

        # Stack of conv + batch-norm blocks, all at the embedding width.
        conv_blocks = []
        for _ in range(encoder_n_convolution):
            conv_blocks.append(
                nn.Sequential(
                    _get_conv1d_layer(
                        encoder_embedding_dim,
                        encoder_embedding_dim,
                        kernel_size=encoder_kernel_size,
                        stride=1,
                        padding=int((encoder_kernel_size - 1) / 2),
                        dilation=1,
                        w_init_gain="relu",
                    ),
                    nn.BatchNorm1d(encoder_embedding_dim),
                )
            )
        self.convolutions = nn.ModuleList(conv_blocks)

        # Bidirectional LSTM: half the width per direction so outputs concatenate
        # back to ``encoder_embedding_dim``.
        self.lstm = nn.LSTM(
            encoder_embedding_dim,
            int(encoder_embedding_dim / 2),
            1,
            batch_first=True,
            bidirectional=True,
        )
        self.lstm.flatten_parameters()

    def forward(self, x: Tensor, input_lengths: Tensor) -> Tensor:
        r"""Pass the input through the Encoder.

        Args:
            x (Tensor): The input sequences with shape (n_batch, encoder_embedding_dim, n_seq).
            input_lengths (Tensor): The length of each input sequence with shape (n_batch, ).

        Return:
            x (Tensor): A tensor with shape (n_batch, n_seq, encoder_embedding_dim).
        """
        for conv_block in self.convolutions:
            x = F.dropout(F.relu(conv_block(x)), 0.5, self.training)

        # (n_batch, channels, n_seq) -> (n_batch, n_seq, channels) for the LSTM.
        x = x.transpose(1, 2)

        # pack_padded_sequence requires lengths on CPU.
        input_lengths = input_lengths.cpu()
        packed = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)

        lstm_out, _ = self.lstm(packed)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)

        return outputs
class _Decoder(nn.Module):
    r"""Decoder with Attention model.

    Args:
        n_mels (int): number of mel bins
        n_frames_per_step (int): number of frames processed per step, only 1 is supported
        encoder_embedding_dim (int): the number of embedding dimensions in the encoder.
        decoder_rnn_dim (int): number of units in decoder LSTM
        decoder_max_step (int): maximum number of output mel spectrograms
        decoder_dropout (float): dropout probability for decoder LSTM
        decoder_early_stopping (bool): stop decoding when all samples are finished
        attention_rnn_dim (int): number of units in attention LSTM
        attention_hidden_dim (int): dimension of attention hidden representation
        attention_location_n_filter (int): number of filters for attention model
        attention_location_kernel_size (int): kernel size for attention model
        attention_dropout (float): dropout probability for attention LSTM
        prenet_dim (int): number of ReLU units in prenet layers
        gate_threshold (float): probability threshold for stop token
    """

    def __init__(
        self,
        n_mels: int,
        n_frames_per_step: int,
        encoder_embedding_dim: int,
        decoder_rnn_dim: int,
        decoder_max_step: int,
        decoder_dropout: float,
        decoder_early_stopping: bool,
        attention_rnn_dim: int,
        attention_hidden_dim: int,
        attention_location_n_filter: int,
        attention_location_kernel_size: int,
        attention_dropout: float,
        prenet_dim: int,
        gate_threshold: float,
    ) -> None:

        super().__init__()
        self.n_mels = n_mels
        self.n_frames_per_step = n_frames_per_step
        self.encoder_embedding_dim = encoder_embedding_dim
        self.attention_rnn_dim = attention_rnn_dim
        self.decoder_rnn_dim = decoder_rnn_dim
        self.prenet_dim = prenet_dim
        self.decoder_max_step = decoder_max_step
        self.gate_threshold = gate_threshold
        self.attention_dropout = attention_dropout
        self.decoder_dropout = decoder_dropout
        self.decoder_early_stopping = decoder_early_stopping

        # Prenet processes the previous mel frame before it is fed back in.
        self.prenet = _Prenet(n_mels * n_frames_per_step, [prenet_dim, prenet_dim])

        self.attention_rnn = nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim)

        self.attention_layer = _Attention(
            attention_rnn_dim,
            encoder_embedding_dim,
            attention_hidden_dim,
            attention_location_n_filter,
            attention_location_kernel_size,
        )

        self.decoder_rnn = nn.LSTMCell(attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, True)

        # Projects (decoder state ++ attention context) to one mel frame per step.
        self.linear_projection = _get_linear_layer(decoder_rnn_dim + encoder_embedding_dim, n_mels * n_frames_per_step)

        # Scalar stop-token logit per step.
        self.gate_layer = _get_linear_layer(
            decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid"
        )

    def _get_initial_frame(self, memory: Tensor) -> Tensor:
        r"""Gets all zeros frames to use as the first decoder input.

        Args:
            memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

        Returns:
            decoder_input (Tensor): all zeros frames with shape
                (n_batch, max of ``text_lengths``, ``n_mels * n_frames_per_step``).
        """

        n_batch = memory.size(0)
        dtype = memory.dtype
        device = memory.device
        decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
        return decoder_input

    def _initialize_decoder_states(
        self, memory: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory.

        Args:
            memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

        Returns:
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
        """
        n_batch = memory.size(0)
        max_time = memory.size(1)
        dtype = memory.dtype
        device = memory.device

        # All recurrent states start at zero, matching memory's dtype/device.
        attention_hidden = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)
        attention_cell = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)

        decoder_hidden = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)
        decoder_cell = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)

        attention_weights = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
        attention_weights_cum = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
        attention_context = torch.zeros(n_batch, self.encoder_embedding_dim, dtype=dtype, device=device)

        # Memory projection is computed once here and reused at every decode step.
        processed_memory = self.attention_layer.memory_layer(memory)

        return (
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
            processed_memory,
        )

    def _parse_decoder_inputs(self, decoder_inputs: Tensor) -> Tensor:
        r"""Prepares decoder inputs.

        Args:
            decoder_inputs (Tensor): Inputs used for teacher-forced training, i.e. mel-specs,
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)

        Returns:
            inputs (Tensor): Processed decoder inputs with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``).
        """
        # (n_batch, n_mels, mel_specgram_lengths.max()) -> (n_batch, mel_specgram_lengths.max(), n_mels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        # Group frames by step (no-op when n_frames_per_step == 1).
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1) / self.n_frames_per_step),
            -1,
        )
        # (n_batch, mel_specgram_lengths.max(), n_mels) -> (mel_specgram_lengths.max(), n_batch, n_mels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def _parse_decoder_outputs(
        self, mel_specgram: Tensor, gate_outputs: Tensor, alignments: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Prepares decoder outputs for output

        Args:
            mel_specgram (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``)
            gate_outputs (Tensor): predicted stop token with shape (max of ``mel_specgram_lengths``, n_batch)
            alignments (Tensor): sequence of attention weights from the decoder
                with shape (max of ``mel_specgram_lengths``, n_batch, max of ``text_lengths``)

        Returns:
            mel_specgram (Tensor): mel spectrogram with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)
            gate_outputs (Tensor): predicted stop token with shape (n_batch, max of ``mel_specgram_lengths``)
            alignments (Tensor): sequence of attention weights from the decoder
                with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``)
        """
        # (mel_specgram_lengths.max(), n_batch, text_lengths.max())
        # -> (n_batch, mel_specgram_lengths.max(), text_lengths.max())
        alignments = alignments.transpose(0, 1).contiguous()
        # (mel_specgram_lengths.max(), n_batch) -> (n_batch, mel_specgram_lengths.max())
        gate_outputs = gate_outputs.transpose(0, 1).contiguous()
        # (mel_specgram_lengths.max(), n_batch, n_mels) -> (n_batch, mel_specgram_lengths.max(), n_mels)
        mel_specgram = mel_specgram.transpose(0, 1).contiguous()
        # decouple frames per step
        shape = (mel_specgram.shape[0], -1, self.n_mels)
        mel_specgram = mel_specgram.view(*shape)
        # (n_batch, mel_specgram_lengths.max(), n_mels) -> (n_batch, n_mels, T_out)
        mel_specgram = mel_specgram.transpose(1, 2)

        return mel_specgram, gate_outputs, alignments

    def decode(
        self,
        decoder_input: Tensor,
        attention_hidden: Tensor,
        attention_cell: Tensor,
        decoder_hidden: Tensor,
        decoder_cell: Tensor,
        attention_weights: Tensor,
        attention_weights_cum: Tensor,
        attention_context: Tensor,
        memory: Tensor,
        processed_memory: Tensor,
        mask: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""Decoder step using stored states, attention and memory

        Args:
            decoder_input (Tensor): Output of the Prenet with shape (n_batch, ``prenet_dim``).
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            memory (Tensor): Encoder output with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
            mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames).

        Returns:
            decoder_output: Predicted mel spectrogram for the current frame with shape (n_batch, ``n_mels``).
            gate_prediction (Tensor): Prediction of the stop token with shape (n_batch, ``1``).
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
        """
        cell_input = torch.cat((decoder_input, attention_context), -1)

        attention_hidden, attention_cell = self.attention_rnn(cell_input, (attention_hidden, attention_cell))
        attention_hidden = F.dropout(attention_hidden, self.attention_dropout, self.training)

        # Stack previous and cumulative weights as the 2-channel location feature.
        attention_weights_cat = torch.cat((attention_weights.unsqueeze(1), attention_weights_cum.unsqueeze(1)), dim=1)
        attention_context, attention_weights = self.attention_layer(
            attention_hidden, memory, processed_memory, attention_weights_cat, mask
        )

        attention_weights_cum += attention_weights
        decoder_input = torch.cat((attention_hidden, attention_context), -1)

        decoder_hidden, decoder_cell = self.decoder_rnn(decoder_input, (decoder_hidden, decoder_cell))
        decoder_hidden = F.dropout(decoder_hidden, self.decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat((decoder_hidden, attention_context), dim=1)
        decoder_output = self.linear_projection(decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)

        return (
            decoder_output,
            gate_prediction,
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
        )

    def forward(
        self, memory: Tensor, mel_specgram_truth: Tensor, memory_lengths: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Decoder forward pass for training.

        Args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            mel_specgram_truth (Tensor): Decoder ground-truth mel-specs for teacher forcing
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            memory_lengths (Tensor): Encoder output lengths for attention masking
                (the same as ``text_lengths``) with shape (n_batch, ).

        Returns:
            mel_specgram (Tensor): Predicted mel spectrogram
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            gate_outputs (Tensor): Predicted stop token for each timestep
                with shape (n_batch, max of ``mel_specgram_lengths``).
            alignments (Tensor): Sequence of attention weights from the decoder
                with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
        """

        # Prepend the all-zeros go frame, then run the whole sequence through the prenet.
        decoder_input = self._get_initial_frame(memory).unsqueeze(0)
        decoder_inputs = self._parse_decoder_inputs(mel_specgram_truth)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        mask = _get_mask_from_lengths(memory_lengths)
        (
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
            processed_memory,
        ) = self._initialize_decoder_states(memory)

        # Teacher forcing: each step consumes the ground-truth previous frame.
        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            (
                mel_output,
                gate_output,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
            ) = self.decode(
                decoder_input,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
                memory,
                processed_memory,
                mask,
            )

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]

        mel_specgram, gate_outputs, alignments = self._parse_decoder_outputs(
            torch.stack(mel_outputs), torch.stack(gate_outputs), torch.stack(alignments)
        )

        return mel_specgram, gate_outputs, alignments

    def _get_go_frame(self, memory: Tensor) -> Tensor:
        """Gets all zeros frames to use as the first decoder input

        args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

        returns:
            decoder_input (Tensor): All zeros frames with shape(n_batch, ``n_mels`` * ``n_frame_per_step``).
        """

        n_batch = memory.size(0)
        dtype = memory.dtype
        device = memory.device
        decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
        return decoder_input

    @torch.jit.export
    def infer(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Decoder inference

        Args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            memory_lengths (Tensor): Encoder output lengths for attention masking
                (the same as ``text_lengths``) with shape (n_batch, ).

        Returns:
            mel_specgram (Tensor): Predicted mel spectrogram
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            mel_specgram_lengths (Tensor): the length of the predicted mel spectrogram (n_batch, ))
            gate_outputs (Tensor): Predicted stop token for each timestep
                with shape (n_batch, max of ``mel_specgram_lengths``).
            alignments (Tensor): Sequence of attention weights from the decoder
                with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
        """
        batch_size, device = memory.size(0), memory.device

        decoder_input = self._get_go_frame(memory)

        mask = _get_mask_from_lengths(memory_lengths)
        (
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
            processed_memory,
        ) = self._initialize_decoder_states(memory)

        mel_specgram_lengths = torch.zeros([batch_size], dtype=torch.int32, device=device)
        finished = torch.zeros([batch_size], dtype=torch.bool, device=device)
        mel_specgrams: List[Tensor] = []
        gate_outputs: List[Tensor] = []
        alignments: List[Tensor] = []
        # Autoregressive loop: each predicted frame is fed back as the next input.
        for _ in range(self.decoder_max_step):
            decoder_input = self.prenet(decoder_input)
            (
                mel_specgram,
                gate_output,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
            ) = self.decode(
                decoder_input,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
                memory,
                processed_memory,
                mask,
            )

            mel_specgrams.append(mel_specgram.unsqueeze(0))
            gate_outputs.append(gate_output.transpose(0, 1))
            alignments.append(attention_weights)
            # Only samples that have not emitted a stop token keep growing.
            mel_specgram_lengths[~finished] += 1

            finished |= torch.sigmoid(gate_output.squeeze(1)) > self.gate_threshold
            if self.decoder_early_stopping and torch.all(finished):
                break

            decoder_input = mel_specgram

        if len(mel_specgrams) == self.decoder_max_step:
            warnings.warn(
                "Reached max decoder steps. The generated spectrogram might not cover " "the whole transcript."
            )

        mel_specgrams = torch.cat(mel_specgrams, dim=0)
        gate_outputs = torch.cat(gate_outputs, dim=0)
        alignments = torch.cat(alignments, dim=0)

        mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(mel_specgrams, gate_outputs, alignments)

        return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments
class Tacotron2(nn.Module):
|
| 870 |
+
r"""Tacotron2 model from *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
|
| 871 |
+
:cite:`shen2018natural` based on the implementation from
|
| 872 |
+
`Nvidia Deep Learning Examples <https://github.com/NVIDIA/DeepLearningExamples/>`_.
|
| 873 |
+
|
| 874 |
+
See Also:
|
| 875 |
+
* :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.
|
| 876 |
+
|
| 877 |
+
Args:
|
| 878 |
+
mask_padding (bool, optional): Use mask padding (Default: ``False``).
|
| 879 |
+
n_mels (int, optional): Number of mel bins (Default: ``80``).
|
| 880 |
+
n_symbol (int, optional): Number of symbols for the input text (Default: ``148``).
|
| 881 |
+
n_frames_per_step (int, optional): Number of frames processed per step, only 1 is supported (Default: ``1``).
|
| 882 |
+
symbol_embedding_dim (int, optional): Input embedding dimension (Default: ``512``).
|
| 883 |
+
encoder_n_convolution (int, optional): Number of encoder convolutions (Default: ``3``).
|
| 884 |
+
encoder_kernel_size (int, optional): Encoder kernel size (Default: ``5``).
|
| 885 |
+
encoder_embedding_dim (int, optional): Encoder embedding dimension (Default: ``512``).
|
| 886 |
+
decoder_rnn_dim (int, optional): Number of units in decoder LSTM (Default: ``1024``).
|
| 887 |
+
decoder_max_step (int, optional): Maximum number of output mel spectrograms (Default: ``2000``).
|
| 888 |
+
decoder_dropout (float, optional): Dropout probability for decoder LSTM (Default: ``0.1``).
|
| 889 |
+
decoder_early_stopping (bool, optional): Continue decoding after all samples are finished (Default: ``True``).
|
| 890 |
+
attention_rnn_dim (int, optional): Number of units in attention LSTM (Default: ``1024``).
|
| 891 |
+
attention_hidden_dim (int, optional): Dimension of attention hidden representation (Default: ``128``).
|
| 892 |
+
attention_location_n_filter (int, optional): Number of filters for attention model (Default: ``32``).
|
| 893 |
+
attention_location_kernel_size (int, optional): Kernel size for attention model (Default: ``31``).
|
| 894 |
+
attention_dropout (float, optional): Dropout probability for attention LSTM (Default: ``0.1``).
|
| 895 |
+
prenet_dim (int, optional): Number of ReLU units in prenet layers (Default: ``256``).
|
| 896 |
+
postnet_n_convolution (int, optional): Number of postnet convolutions (Default: ``5``).
|
| 897 |
+
postnet_kernel_size (int, optional): Postnet kernel size (Default: ``5``).
|
| 898 |
+
postnet_embedding_dim (int, optional): Postnet embedding dimension (Default: ``512``).
|
| 899 |
+
gate_threshold (float, optional): Probability threshold for stop token (Default: ``0.5``).
|
| 900 |
+
"""
|
| 901 |
+
|
| 902 |
+
def __init__(
|
| 903 |
+
self,
|
| 904 |
+
mask_padding: bool = False,
|
| 905 |
+
n_mels: int = 80,
|
| 906 |
+
n_symbol: int = 148,
|
| 907 |
+
n_frames_per_step: int = 1,
|
| 908 |
+
symbol_embedding_dim: int = 512,
|
| 909 |
+
encoder_embedding_dim: int = 512,
|
| 910 |
+
encoder_n_convolution: int = 3,
|
| 911 |
+
encoder_kernel_size: int = 5,
|
| 912 |
+
decoder_rnn_dim: int = 1024,
|
| 913 |
+
decoder_max_step: int = 2000,
|
| 914 |
+
decoder_dropout: float = 0.1,
|
| 915 |
+
decoder_early_stopping: bool = True,
|
| 916 |
+
attention_rnn_dim: int = 1024,
|
| 917 |
+
attention_hidden_dim: int = 128,
|
| 918 |
+
attention_location_n_filter: int = 32,
|
| 919 |
+
attention_location_kernel_size: int = 31,
|
| 920 |
+
attention_dropout: float = 0.1,
|
| 921 |
+
prenet_dim: int = 256,
|
| 922 |
+
postnet_n_convolution: int = 5,
|
| 923 |
+
postnet_kernel_size: int = 5,
|
| 924 |
+
postnet_embedding_dim: int = 512,
|
| 925 |
+
gate_threshold: float = 0.5,
|
| 926 |
+
) -> None:
|
| 927 |
+
super().__init__()
|
| 928 |
+
|
| 929 |
+
self.mask_padding = mask_padding
|
| 930 |
+
self.n_mels = n_mels
|
| 931 |
+
self.n_frames_per_step = n_frames_per_step
|
| 932 |
+
self.embedding = nn.Embedding(n_symbol, symbol_embedding_dim)
|
| 933 |
+
torch.nn.init.xavier_uniform_(self.embedding.weight)
|
| 934 |
+
self.encoder = _Encoder(encoder_embedding_dim, encoder_n_convolution, encoder_kernel_size)
|
| 935 |
+
self.decoder = _Decoder(
|
| 936 |
+
n_mels,
|
| 937 |
+
n_frames_per_step,
|
| 938 |
+
encoder_embedding_dim,
|
| 939 |
+
decoder_rnn_dim,
|
| 940 |
+
decoder_max_step,
|
| 941 |
+
decoder_dropout,
|
| 942 |
+
decoder_early_stopping,
|
| 943 |
+
attention_rnn_dim,
|
| 944 |
+
attention_hidden_dim,
|
| 945 |
+
attention_location_n_filter,
|
| 946 |
+
attention_location_kernel_size,
|
| 947 |
+
attention_dropout,
|
| 948 |
+
prenet_dim,
|
| 949 |
+
gate_threshold,
|
| 950 |
+
)
|
| 951 |
+
self.postnet = _Postnet(n_mels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolution)
|
| 952 |
+
|
| 953 |
+
def forward(
    self,
    tokens: Tensor,
    token_lengths: Tensor,
    mel_specgram: Tensor,
    mel_specgram_lengths: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    r"""Pass the input through the Tacotron2 model. This is in teacher
    forcing mode, which is generally used for training.

    The input ``tokens`` should be padded with zeros to length max of ``token_lengths``.
    The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``.

    Args:
        tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of token_lengths)`.
        token_lengths (Tensor): The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
        mel_specgram (Tensor): The target mel spectrogram
            with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
        mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.

    Returns:
        [Tensor, Tensor, Tensor, Tensor]:
            Tensor
                Mel spectrogram before Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
            Tensor
                Mel spectrogram after Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
            Tensor
                The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`.
            Tensor
                Sequence of attention weights from the decoder with
                shape `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
    """

    # nn.Embedding yields (n_batch, time, embed); transpose to the
    # (n_batch, embed, time) channel-first layout the conv-based encoder takes.
    embedded_inputs = self.embedding(tokens).transpose(1, 2)

    encoder_outputs = self.encoder(embedded_inputs, token_lengths)
    # Teacher forcing: the decoder consumes the ground-truth mel frames.
    mel_specgram, gate_outputs, alignments = self.decoder(
        encoder_outputs, mel_specgram, memory_lengths=token_lengths
    )

    # Postnet predicts a residual refinement that is added to the decoder output.
    mel_specgram_postnet = self.postnet(mel_specgram)
    mel_specgram_postnet = mel_specgram + mel_specgram_postnet

    if self.mask_padding:
        mask = _get_mask_from_lengths(mel_specgram_lengths)
        # Broadcast the (n_batch, time) mask across all mel channels, then
        # reorder to (n_batch, n_mels, time) to match the spectrogram layout.
        mask = mask.expand(self.n_mels, mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)

        # In-place: zero out padded frames in both spectrograms; set gate
        # logits at padded steps to a large value so they read as "stop".
        mel_specgram.masked_fill_(mask, 0.0)
        mel_specgram_postnet.masked_fill_(mask, 0.0)
        gate_outputs.masked_fill_(mask[:, 0, :], 1e3)

    return mel_specgram, mel_specgram_postnet, gate_outputs, alignments
|
| 1006 |
+
|
| 1007 |
+
@torch.jit.export
def infer(self, tokens: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Using Tacotron2 for inference. The input is a batch of encoded
    sentences (``tokens``) and its corresponding lengths (``lengths``). The
    output is the generated mel spectrograms, its corresponding lengths, and
    the attention weights from the decoder.

    The input `tokens` should be padded with zeros to length max of ``lengths``.

    Args:
        tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`.
        lengths (Tensor or None, optional):
            The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
            If ``None``, it is assumed that the all the tokens are valid. Default: ``None``

    Returns:
        (Tensor, Tensor, Tensor):
            Tensor
                The predicted mel spectrogram with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
            Tensor
                The length of the predicted mel spectrogram with shape `(n_batch, )`.
            Tensor
                Sequence of attention weights from the decoder with shape
                `(n_batch, max of mel_specgram_lengths, max of lengths)`.
    """
    n_batch, max_length = tokens.shape
    if lengths is None:
        # No lengths given: treat every token in the batch as valid.
        lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype)

    assert lengths is not None  # For TorchScript compiler
    # Same channel-first embedding layout as in forward().
    embedded_inputs = self.embedding(tokens).transpose(1, 2)
    encoder_outputs = self.encoder(embedded_inputs, lengths)
    # Autoregressive decoding (no teacher forcing); third return value unused here.
    mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(encoder_outputs, lengths)

    # Postnet residual refinement, as in forward().
    mel_outputs_postnet = self.postnet(mel_specgram)
    mel_outputs_postnet = mel_specgram + mel_outputs_postnet

    # Regroup the decoder's stacked alignments into per-sample attention maps,
    # i.e. shape (n_batch, time, max of lengths).
    alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)

    return mel_outputs_postnet, mel_specgram_lengths, alignments
|
.venv/lib/python3.11/site-packages/torchaudio/models/wav2letter.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn, Tensor
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"Wav2Letter",
|
| 5 |
+
]
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Wav2Letter(nn.Module):
    r"""Wav2Letter model architecture from *Wav2Letter: an End-to-End ConvNet-based Speech
    Recognition System* :cite:`collobert2016wav2letter`.

    See Also:
        * `Training example <https://github.com/pytorch/audio/tree/release/0.12/examples/pipeline_wav2letter>`__

    Args:
        num_classes (int, optional): Number of classes to be classified. (Default: ``40``)
        input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum``
            or ``mfcc`` (Default: ``waveform``).
        num_features (int, optional): Number of input features that the network will receive (Default: ``1``).

    Raises:
        ValueError: If ``input_type`` is not one of ``"waveform"``, ``"power_spectrum"`` or ``"mfcc"``.
    """

    def __init__(self, num_classes: int = 40, input_type: str = "waveform", num_features: int = 1) -> None:
        super().__init__()

        # Fail fast on an unsupported input type. Without this check the
        # model would be constructed without ``acoustic_model`` and only
        # fail later with an opaque AttributeError in forward().
        if input_type not in ("waveform", "power_spectrum", "mfcc"):
            raise ValueError(
                f"Expected `input_type` to be one of 'waveform', 'power_spectrum' or 'mfcc'. Found: {input_type}"
            )

        # For raw waveforms, a front-end conv (added below) first maps the
        # signal to 250 channels, so the acoustic stack always sees 250.
        acoustic_num_features = 250 if input_type == "waveform" else num_features
        acoustic_model = nn.Sequential(
            nn.Conv1d(in_channels=acoustic_num_features, out_channels=250, kernel_size=48, stride=2, padding=23),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )

        if input_type == "waveform":
            # Strided front-end conv downsamples the raw waveform before the
            # shared acoustic stack. Module nesting is kept identical to the
            # original so existing state_dicts still load.
            waveform_model = nn.Sequential(
                nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45),
                nn.ReLU(inplace=True),
            )
            self.acoustic_model = nn.Sequential(waveform_model, acoustic_model)
        else:
            # "power_spectrum" and "mfcc" feed the acoustic stack directly.
            self.acoustic_model = acoustic_model

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (torch.Tensor): Tensor of dimension (batch_size, num_features, input_length).

        Returns:
            Tensor: Predictor tensor of dimension (batch_size, number_of_classes, input_length).
        """

        x = self.acoustic_model(x)
        # Log-probabilities over classes, per time step (suitable for CTC loss).
        x = nn.functional.log_softmax(x, dim=1)
        return x
|
.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/model.py
ADDED
|
@@ -0,0 +1,1579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from torch.nn import Module
|
| 7 |
+
|
| 8 |
+
from . import components
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Wav2Vec2Model(Module):
    """Acoustic model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`.

    Note:
        To build the model, please use one of the factory functions.

    See Also:
        * :class:`torchaudio.pipelines.Wav2Vec2Bundle`: Pretrained models (without fine-tuning)
        * :class:`torchaudio.pipelines.Wav2Vec2ASRBundle`: ASR pipelines with pretrained models.

    Args:
        feature_extractor (torch.nn.Module):
            Feature extractor that extracts feature vectors from raw audio Tensor.

        encoder (torch.nn.Module):
            Encoder that converts the audio features into the sequence of probability
            distribution (in negative log-likelihood) over labels.

        aux (torch.nn.Module or None, optional):
            Auxiliary module. If provided, the output from encoder is passed to this module.
    """

    def __init__(
        self,
        feature_extractor: Module,
        encoder: Module,
        aux: Optional[Module] = None,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.encoder = encoder
        self.aux = aux

    @torch.jit.export
    def extract_features(
        self,
        waveforms: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> Tuple[List[Tensor], Optional[Tensor]]:
        """Extract feature vectors from raw waveforms

        Returns the outputs of the intermediate transformer layers of the encoder.

        Args:
            waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
            lengths (Tensor or None, optional):
                Valid length of each audio in the batch, shape `(batch, )`.
                Supply this when the batch mixes audios of different durations
                so that the proper attention mask is applied; ``None`` means
                every waveform is fully valid.
            num_layers (int or None, optional):
                Upper bound on how many intermediate layers to run; ``1`` stops
                after the first. ``None`` returns every layer's output.

        Returns:
            (List[Tensor], Optional[Tensor]):
                List of Tensors
                    Features from the requested layers, each of shape
                    `(batch, time frame, feature dimension)`.
                Tensor or None
                    Valid length (time axis) of each feature Tensor, shape
                    `(batch, )`; only returned when ``lengths`` was given.
        """
        features, feat_lengths = self.feature_extractor(waveforms, lengths)
        layer_outputs = self.encoder.extract_features(features, feat_lengths, num_layers)
        return layer_outputs, feat_lengths

    def forward(
        self,
        waveforms: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Compute the sequence of probability distribution over labels.

        Args:
            waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
            lengths (Tensor or None, optional):
                Valid length of each audio in the batch, shape `(batch, )`.
                Supply this when the batch mixes audios of different durations
                so that the proper attention mask is applied; ``None`` means
                every waveform is fully valid. Default: ``None``.

        Returns:
            (Tensor, Optional[Tensor]):
                Tensor
                    Sequences of probability distribution (in logit) over labels,
                    shape `(batch, frames, num labels)`.
                Tensor or None
                    Valid length (time axis) of the output, shape `(batch, )`;
                    only returned when ``lengths`` was given.
        """
        features, feat_lengths = self.feature_extractor(waveforms, lengths)
        encoded = self.encoder(features, feat_lengths)
        # The auxiliary head (e.g. a fine-tuning linear layer) is optional.
        output = encoded if self.aux is None else self.aux(encoded)
        return output, feat_lengths
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class HuBERTPretrainModel(Module):
    """HuBERTPretrainModel()

    HuBERT model used for pretraining in *HuBERT* :cite:`hsu2021hubert`.

    Note:
        To build the model, please use one of the factory functions.

    See Also:
        `HuBERT Pre-training and Fine-tuning Recipes
        <https://github.com/pytorch/audio/tree/main/examples/hubert>`__

    Args:
        wav2vec2 (Wav2Vec2Model):
            Wav2Vec2 encoder that generates the transformer outputs.

        mask_generator (torch.nn.Module):
            Mask generator that generates the mask for masked prediction during the training.

        logit_generator (torch.nn.Module):
            Logit generator that predicts the logits of the masked and unmasked inputs.

        feature_grad_mult (float or None):
            The factor to scale the convolutional feature extraction layer gradients by.
            If ``None``, the gradients of feature extraction layers are not affected.
            The scale factor will not affect the forward pass.
    """

    def __init__(
        self,
        wav2vec2: Wav2Vec2Model,
        mask_generator: Module,
        logit_generator: Module,
        feature_grad_mult: Optional[float],
    ):
        super().__init__()
        self.wav2vec2 = wav2vec2
        self.mask_generator = mask_generator
        self.logit_generator = logit_generator
        # Gradient scaling only makes sense strictly inside (0, 1);
        # ``None`` disables scaling entirely.
        if feature_grad_mult is not None and not 0.0 < feature_grad_mult < 1.0:
            raise ValueError(
                f"The value of `feature_grad_mult` must be ``None`` or between (0, 1). Found {feature_grad_mult}"
            )
        self.feature_grad_mult = feature_grad_mult

    def forward(
        self,
        waveforms: Tensor,
        labels: Tensor,
        audio_lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute the sequence of probability distribution over labels.

        Args:
            waveforms (Tensor): Audio tensor of dimension `[batch, frames]`.
            labels (Tensor): Label for pre-training. A Tensor of dimension `[batch, frames]`.
            audio_lengths (Tensor or None, optional):
                Indicates the valid length of each audio in the batch.
                Shape: `[batch, ]`.
                When the ``waveforms`` contains audios with different durations,
                by providing ``lengths`` argument, the model will compute
                the corresponding valid output lengths and apply proper mask in
                transformer attention layer.
                If ``None``, it is assumed that all the audio in ``waveforms``
                have valid length. Default: ``None``.

        Returns:
            (Tensor, Tensor, Tensor):
                Tensor
                    The masked sequences of probability distribution (in logit).
                    Shape: `(masked_frames, num labels)`.
                Tensor
                    The unmasked sequence of probability distribution (in logit).
                    Shape: `(unmasked_frames, num labels)`.
                Tensor
                    The feature mean value for additional penalty loss.
                    Shape: `(1,)`.
        """
        x, lengths = self.wav2vec2.feature_extractor(waveforms, audio_lengths)
        if self.feature_grad_mult is not None and self.feature_grad_mult < 1.0:
            # GradMultiply scales gradients in backward only; forward is identity.
            x = components.GradMultiply.apply(x, self.feature_grad_mult)
        # L2 penalty term on the extracted features (HuBERT feature penalty loss).
        features_pen = x.float().pow(2).mean()
        if lengths is not None:
            padding_mask = components._get_padding_mask(x, lengths)
        else:
            padding_mask = None
        x, attention_mask = self.wav2vec2.encoder._preprocess(x, lengths)
        # Randomly mask spans of frames; ``mask`` marks which frames were masked.
        x, mask = self.mask_generator(x, padding_mask)
        x = self.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
        if x.shape[1] != labels.shape[1]:
            raise ValueError("The length of label must match that of HuBERT model output")
        if padding_mask is not None:
            # Exclude padded frames from both the masked and unmasked selections.
            mask_m = torch.logical_and(~padding_mask, mask)
            mask_u = torch.logical_and(~padding_mask, ~mask_m)
        else:
            mask_m = mask
            mask_u = ~mask_m

        logit_m, logit_u = self.logit_generator(x, labels, mask_m, mask_u)

        return logit_m, logit_u, features_pen
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def wav2vec2_model(
    extractor_mode: str,
    extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
    extractor_conv_bias: bool,
    encoder_embed_dim: int,
    encoder_projection_dropout: float,
    encoder_pos_conv_kernel: int,
    encoder_pos_conv_groups: int,
    encoder_num_layers: int,
    encoder_num_heads: int,
    encoder_attention_dropout: float,
    encoder_ff_interm_features: int,
    encoder_ff_interm_dropout: float,
    encoder_dropout: float,
    encoder_layer_norm_first: bool,
    encoder_layer_drop: float,
    aux_num_out: Optional[int],
) -> Wav2Vec2Model:
    """Builds custom :class:`~torchaudio.models.Wav2Vec2Model`.

    Note:
        The "feature extractor" below corresponds to
        `ConvFeatureExtractionModel <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L736>`__
        in the original ``fairseq`` implementation.
        This is referred as "(convolutional) feature encoder" in the *wav2vec 2.0*
        :cite:`baevski2020wav2vec` paper.

        The "encoder" below corresponds to `TransformerEncoder <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L817>`__,
        and this is referred as "Transformer" in the paper.

    Args:
        extractor_mode (str): Operation mode of feature extractor.
            Valid values are ``"group_norm"`` or ``"layer_norm"``.
            If ``"group_norm"``, then a single normalization is applied
            in the first convolution block. Otherwise, all the convolution
            blocks will have layer normalization.

            This option corresponds to ``extractor_mode`` from ``fairseq``.
        extractor_conv_layer_config (list of integer tuples or None):
            Configuration of convolution layers in feature extractor.
            List of convolution configuration,
            i.e. ``[(output_channel, kernel_size, stride), ...]``

            If ``None`` is provided, then the following default value is used.

            .. code-block:: python

                [
                    (512, 10, 5),
                    (512, 3, 2),
                    (512, 3, 2),
                    (512, 3, 2),
                    (512, 3, 2),
                    (512, 2, 2),
                    (512, 2, 2),
                ]

            This option corresponds to ``conv_feature_layers`` from ``fairseq``.

        extractor_conv_bias (bool):
            Whether to include bias term to each convolution operation.

            This option corresponds to ``conv_bias`` from ``fairseq``.

        encoder_embed_dim (int):
            The dimension of embedding in encoder.

            This option corresponds to ``encoder_embed_dim`` from ``fairseq``.

        encoder_projection_dropout (float):
            The dropout probability applied after the input feature is projected
            to ``encoder_embed_dim``.

            This option corresponds to ``dropout_input`` from ``fairseq``.

        encoder_pos_conv_kernel (int):
            The kernel size of convolutional positional embeddings.

            This option corresponds to ``conv_pos`` from ``fairseq``.

        encoder_pos_conv_groups (int):
            The number of groups of convolutional positional embeddings.

            This option corresponds to ``conv_pos_groups`` from ``fairseq``.

        encoder_num_layers (int):
            The number of self attention layers in transformer block.

            This option corresponds to ``encoder_layers`` from ``fairseq``.

        encoder_num_heads (int):
            The number of heads in self attention layers.

            This option corresponds to ``encoder_attention_heads`` from ``fairseq``.

        encoder_attention_dropout (float):
            The dropout probability applied after softmax in self-attention layer.

            This option corresponds to ``attention_dropout`` from ``fairseq``.

        encoder_ff_interm_features (int):
            The dimension of hidden features in feed forward layer.

            This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``.

        encoder_ff_interm_dropout (float):
            The dropout probability applied in feedforward layer.

            This option corresponds to ``activation_dropout`` from ``fairseq``.

        encoder_dropout (float):
            The dropout probability applied at the end of feed forward layer.

            This option corresponds to ``dropout`` from ``fairseq``.

        encoder_layer_norm_first (bool):
            Control the order of layer norm in transformer layer and each encoder layer.
            If True, in transformer layer, layer norm is applied before features are fed
            to encoder layers. In encoder layer, two layer norms are applied before and after
            self attention.
            If False, in transformer layer, layer norm is applied after features are fed
            to encoder layers. In encoder layer, two layer norms are applied after self
            attention, before and after feed forward.

            This option corresponds to ``layer_norm_first`` from ``fairseq``.

        encoder_layer_drop (float):
            Probability to drop each encoder layer during training.

            This option corresponds to ``layerdrop`` from ``fairseq``.

        aux_num_out (int or None):
            When provided, attach an extra linear layer on top of encoder, which can be
            used for fine-tuning.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    if extractor_conv_layer_config is None:
        # Default conv stack from the wav2vec 2.0 paper: overall stride 320
        # (i.e. one output frame per 20 ms of 16 kHz audio).
        extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

    feature_extractor = components._get_feature_extractor(
        extractor_mode, extractor_conv_layer_config, extractor_conv_bias
    )
    # The encoder's input projection takes the channel count of the last conv layer.
    encoder = components._get_encoder(
        in_features=extractor_conv_layer_config[-1][0],
        embed_dim=encoder_embed_dim,
        dropout_input=encoder_projection_dropout,
        pos_conv_kernel=encoder_pos_conv_kernel,
        pos_conv_groups=encoder_pos_conv_groups,
        num_layers=encoder_num_layers,
        num_heads=encoder_num_heads,
        attention_dropout=encoder_attention_dropout,
        ff_interm_features=encoder_ff_interm_features,
        ff_interm_dropout=encoder_ff_interm_dropout,
        dropout=encoder_dropout,
        layer_norm_first=encoder_layer_norm_first,
        layer_drop=encoder_layer_drop,
    )
    aux = None
    if aux_num_out is not None:
        # Optional fine-tuning head mapping encoder output to label logits.
        aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out)
    return Wav2Vec2Model(feature_extractor, encoder, aux)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def wav2vec2_base(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.1,
    encoder_ff_interm_dropout: float = 0.1,
    encoder_dropout: float = 0.1,
    encoder_layer_drop: float = 0.1,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "base" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`

    Args:
        encoder_projection_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional):
            See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    # "base" configuration: 12 layers / 12 heads / 768-dim embedding / 3072-dim FFN,
    # group-norm feature extractor without conv bias, post-norm transformer layers.
    return wav2vec2_model(
        extractor_mode="group_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=3072,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=False,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
    )
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def wav2vec2_large(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.1,
    encoder_ff_interm_dropout: float = 0.1,
    encoder_dropout: float = 0.1,
    encoder_layer_drop: float = 0.1,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "large" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`

    Args:
        encoder_projection_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional):
            See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    # "large" configuration: 24 layers / 16 heads / 1024-dim embedding / 4096-dim FFN,
    # group-norm feature extractor without conv bias, post-norm transformer layers.
    return wav2vec2_model(
        extractor_mode="group_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1024,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=False,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
    )
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def wav2vec2_large_lv60k(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.0,
    encoder_ff_interm_dropout: float = 0.1,
    encoder_dropout: float = 0.0,
    encoder_layer_drop: float = 0.1,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "large lv-60k" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`

    Args:
        encoder_projection_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional):
            See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    # "large lv-60k" differs from "large" by using per-block layer norm with conv
    # bias in the feature extractor, and pre-norm (layer_norm_first) transformer layers.
    return wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=True,
        encoder_embed_dim=1024,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=True,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
    )
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def hubert_base(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.1,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.1,
    encoder_layer_drop: float = 0.05,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "base" :class:`HuBERT <torchaudio.models.Wav2Vec2Model>` from *HuBERT* :cite:`hsu2021hubert`

    Args:
        encoder_projection_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional):
            See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    # HuBERT "base" shares the wav2vec 2.0 "base" architecture; only the default
    # dropout/layer-drop values differ.
    return wav2vec2_model(
        extractor_mode="group_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=3072,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=False,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
    )
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
def hubert_large(
    encoder_projection_dropout: float = 0.0,
    encoder_attention_dropout: float = 0.0,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.0,
    encoder_layer_drop: float = 0.0,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "large" :class:`HuBERT <torchaudio.models.Wav2Vec2Model>` from *HuBERT* :cite:`hsu2021hubert`

    Args:
        encoder_projection_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional):
            See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    # HuBERT "large": layer-norm feature extractor (no conv bias) and pre-norm
    # transformer; all dropouts default to 0.0.
    return wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1024,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=True,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
    )
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def hubert_xlarge(
    encoder_projection_dropout: float = 0.0,
    encoder_attention_dropout: float = 0.0,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.0,
    encoder_layer_drop: float = 0.0,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "extra large" :class:`HuBERT <torchaudio.models.Wav2Vec2Model>` from *HuBERT* :cite:`hsu2021hubert`

    Args:
        encoder_projection_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional):
            See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    # HuBERT "extra large": 48 layers / 1280-dim embedding / 5120-dim FFN,
    # otherwise configured like hubert_large.
    return wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1280,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=48,
        encoder_num_heads=16,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=5120,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=True,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
    )
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
def _init_hubert_pretrain_model(module):
    """Initialize sub-module weights for HuBERT pre-training.

    Intended to be passed to ``torch.nn.Module.apply`` so it visits every
    sub-module of the model. Modules not matched below are left untouched.
    """
    if isinstance(module, components.ConvLayerBlock):
        # Feature-extractor convolutions use Kaiming (He) initialization.
        torch.nn.init.kaiming_normal_(module.conv.weight)
    elif isinstance(module, components.ConvolutionalPositionalEmbedding):
        # normalize the weight to normal distribution.
        std = math.sqrt(4.0 / (module.embed_dim * module.kernel_size))
        torch.nn.init.normal_(module.conv.weight, mean=0.0, std=std)
        torch.nn.init.constant_(module.conv.bias, 0.0)
    elif isinstance(module, components.SelfAttention):
        # normalize the query, key, value, and out_proj parameters in self attention module.
        torch.nn.init.xavier_uniform_(module.k_proj.weight, gain=1 / math.sqrt(2))
        torch.nn.init.xavier_uniform_(module.v_proj.weight, gain=1 / math.sqrt(2))
        torch.nn.init.xavier_uniform_(module.q_proj.weight, gain=1 / math.sqrt(2))
        torch.nn.init.xavier_uniform_(module.out_proj.weight)
        torch.nn.init.constant_(module.out_proj.bias, 0.0)
    elif isinstance(module, components.Transformer):
        module.apply(components._init_transformer_params)
    else:
        pass
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def hubert_pretrain_model(
    extractor_mode: str,
    extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
    extractor_conv_bias: bool,
    encoder_embed_dim: int,
    encoder_projection_dropout: float,
    encoder_pos_conv_kernel: int,
    encoder_pos_conv_groups: int,
    encoder_num_layers: int,
    encoder_num_heads: int,
    encoder_attention_dropout: float,
    encoder_ff_interm_features: int,
    encoder_ff_interm_dropout: float,
    encoder_dropout: float,
    encoder_layer_norm_first: bool,
    encoder_layer_drop: float,
    mask_prob: float,
    mask_selection: str,
    mask_other: float,
    mask_length: int,
    no_mask_overlap: bool,
    mask_min_space: int,
    mask_channel_prob: float,
    mask_channel_selection: str,
    mask_channel_other: float,
    mask_channel_length: int,
    no_mask_channel_overlap: bool,
    mask_channel_min_space: int,
    skip_masked: bool,
    skip_nomask: bool,
    num_classes: int,
    final_dim: int,
    feature_grad_mult: Optional[float],
) -> HuBERTPretrainModel:
    """Builds custom :class:`HuBERTPretrainModel` for training from scratch

    Note:
        The "feature extractor" below corresponds to
        `ConvFeatureExtractionModel <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L736>`__
        in the original ``fairseq`` implementation.
        This is referred as "(convolutional) feature encoder" in the *wav2vec 2.0*
        :cite:`baevski2020wav2vec` paper.

        The "encoder" below corresponds to `TransformerEncoder <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L817>`__,
        and this is referred as "Transformer" in the paper.

    Args:
        extractor_mode (str): Operation mode of feature extractor.
            Valid values are ``"group_norm"`` or ``"layer_norm"``.
            If ``"group_norm"``, then a single normalization is applied
            in the first convolution block. Otherwise, all the convolution
            blocks will have layer normalization.

            This option corresponds to ``extractor_mode`` from ``fairseq``.

        extractor_conv_layer_config (list of integer tuples or None):
            Configuration of convolution layers in feature extractor.
            List of convolution configuration,
            i.e. ``[(output_channel, kernel_size, stride), ...]``

            If ``None`` is provided, then the following default value is used.

            .. code-block:: python

               [
                 (512, 10, 5),
                 (512, 3, 2),
                 (512, 3, 2),
                 (512, 3, 2),
                 (512, 3, 2),
                 (512, 2, 2),
                 (512, 2, 2),
               ]

            This option corresponds to ``conv_feature_layers`` from ``fairseq``.

        extractor_conv_bias (bool):
            Whether to include bias term to each convolution operation.

            This option corresponds to ``conv_bias`` from ``fairseq``.

        encoder_embed_dim (int):
            The dimension of embedding in encoder.

            This option corresponds to ``encoder_embed_dim`` from ``fairseq``.

        encoder_projection_dropout (float):
            The dropout probability applied after the input feature is projected
            to ``encoder_embed_dim``.

            This option corresponds to ``dropout_input`` from ``fairseq``.

        encoder_pos_conv_kernel (int):
            The kernel size of convolutional positional embeddings.

            This option corresponds to ``conv_pos`` from ``fairseq``.

        encoder_pos_conv_groups (int):
            The number of groups of convolutional positional embeddings.

            This option corresponds to ``conv_pos_groups`` from ``fairseq``.

        encoder_num_layers (int):
            The number of self attention layers in transformer block.

            This option corresponds to ``encoder_layers`` from ``fairseq``.

        encoder_num_heads (int):
            The number of heads in self attention layers.

            This option corresponds to ``encoder_attention_heads`` from ``fairseq``.

        encoder_attention_dropout (float):
            The dropout probability applied after softmax in self-attention layer.

            This option corresponds to ``attention_dropout`` from ``fairseq``.

        encoder_ff_interm_features (int):
            The dimension of hidden features in feed forward layer.

            This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``.

        encoder_ff_interm_dropout (float):
            The dropout probability applied in feedforward layer.

            This option corresponds to ``activation_dropout`` from ``fairseq``.

        encoder_dropout (float):
            The dropout probability applied at the end of feed forward layer.

            This option corresponds to ``dropout`` from ``fairseq``.

        encoder_layer_norm_first (bool):
            Control the order of layer norm in transformer layer and each encoder layer.
            If True, in transformer layer, layer norm is applied before features are fed
            to encoder layers. In encoder layer, two layer norms are applied before and after
            self attention.
            If False, in transformer layer, layer norm is applied after features are fed
            to encoder layers. In encoder layer, two layer norms are applied after self
            attention, before and after feed forward.

            This option corresponds to ``layer_norm_first`` from ``fairseq``.

        encoder_layer_drop (float):
            Probability to drop each encoder layer during training.

            This option corresponds to ``layerdrop`` from ``fairseq``.

        mask_prob (float):
            Probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            However due to overlaps, the actual number will be smaller (unless no_overlap is True).

            This option corresponds to ``mask_prob`` from ``fairseq``.

        mask_selection (str):
            How to choose the mask length. Options: [``static``, ``uniform``, ``normal``, ``poisson``].

            This option corresponds to ``mask_selection`` from ``fairseq``.

        mask_other (float):
            Secondary mask argument (used for more complex distributions).

            This option corresponds to ``mask_other`` from ``fairseq``.

        mask_length (int):
            The lengths of the mask.

            This option corresponds to ``mask_length`` from ``fairseq``.

        no_mask_overlap (bool):
            Whether to allow masks to overlap.

            This option corresponds to ``no_mask_overlap`` from ``fairseq``.

        mask_min_space (int):
            Minimum space between spans (if no overlap is enabled).

            This option corresponds to ``mask_min_space`` from ``fairseq``.

        mask_channel_prob: (float):
            The probability of replacing a feature with 0.

            This option corresponds to ``mask_channel_prob`` from ``fairseq``.

        mask_channel_selection (str):
            How to choose the mask length for channel masking. Options: [``static``, ``uniform``, ``normal``, ``poisson``].

            This option corresponds to ``mask_channel_selection`` from ``fairseq``.

        mask_channel_other (float):
            Secondary mask argument for channel masking (used for more complex distributions).

            This option corresponds to ``mask_channel_other`` from ``fairseq``.

        mask_channel_length (int):
            The lengths of the channel mask.

            This option corresponds to ``mask_channel_length`` from ``fairseq``.

        no_mask_channel_overlap (bool):
            Whether to allow channel masks to overlap.

            This option corresponds to ``no_mask_channel_overlap`` from ``fairseq``.

        mask_channel_min_space (int):
            Minimum space between spans for channel masking (if no overlap is enabled).

            This option corresponds to ``mask_channel_min_space`` from ``fairseq``.

        skip_masked (bool):
            If True, skip computing losses over masked frames.

            This option corresponds to ``skip_masked`` from ``fairseq``.

        skip_nomask (bool):
            If True, skip computing losses over unmasked frames.

            This option corresponds to ``skip_nomask`` from ``fairseq``.

        num_classes (int):
            The number of classes in the labels.

        final_dim (int):
            Project final representations and targets to `final_dim`.

            This option corresponds to ``final_dim`` from ``fairseq``.

        feature_grad_mult (float or None):
            The factor to scale the convolutional feature extraction layer gradients by.
            The scale factor will not affect the forward pass.

            This option corresponds to ``feature_grad_mult`` from ``fairseq``.

    Returns:
        HuBERTPretrainModel:
            The resulting model.
    """  # noqa: E501
    if extractor_conv_layer_config is None:
        # Default wav2vec 2.0 / HuBERT feature-extractor stack (total stride 320).
        extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

    feature_extractor = components._get_feature_extractor(
        extractor_mode, extractor_conv_layer_config, extractor_conv_bias
    )
    encoder = components._get_encoder(
        in_features=extractor_conv_layer_config[-1][0],
        embed_dim=encoder_embed_dim,
        dropout_input=encoder_projection_dropout,
        pos_conv_kernel=encoder_pos_conv_kernel,
        pos_conv_groups=encoder_pos_conv_groups,
        num_layers=encoder_num_layers,
        num_heads=encoder_num_heads,
        attention_dropout=encoder_attention_dropout,
        ff_interm_features=encoder_ff_interm_features,
        ff_interm_dropout=encoder_ff_interm_dropout,
        dropout=encoder_dropout,
        layer_norm_first=encoder_layer_norm_first,
        layer_drop=encoder_layer_drop,
    )
    wav2vec2 = Wav2Vec2Model(feature_extractor, encoder)
    mask_generator = components.MaskGenerator(
        encoder_embed_dim,
        mask_prob,
        mask_selection,
        mask_other,
        mask_length,
        no_mask_overlap,
        mask_min_space,
        mask_channel_prob,
        mask_channel_selection,
        mask_channel_other,
        mask_channel_length,
        no_mask_channel_overlap,
        mask_channel_min_space,
    )
    logit_generator = components.LogitGenerator(
        encoder_embed_dim,
        num_classes,
        final_dim,
        skip_masked,
        skip_nomask,
    )
    model = HuBERTPretrainModel(
        wav2vec2=wav2vec2,
        mask_generator=mask_generator,
        logit_generator=logit_generator,
        feature_grad_mult=feature_grad_mult,
    )
    # initialize the model for pre-training
    model.apply(_init_hubert_pretrain_model)
    return model
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
def hubert_pretrain_base(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.1,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.1,
    encoder_layer_drop: float = 0.05,
    mask_prob: float = 0.8,
    mask_channel_prob: float = 0.0,
    mask_channel_length: int = 10,
    feature_grad_mult: Optional[float] = 0.1,
    num_classes: int = 100,
) -> HuBERTPretrainModel:
    """Builds "base" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.

    Args:
        encoder_projection_dropout (float):
            See :py:func:`hubert_pretrain_model`.
        encoder_attention_dropout (float):
            See :py:func:`hubert_pretrain_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`hubert_pretrain_model`.
        encoder_dropout (float):
            See :py:func:`hubert_pretrain_model`.
        encoder_layer_drop (float):
            See :py:func:`hubert_pretrain_model`.
        mask_prob (float):
            See :py:func:`hubert_pretrain_model`.
        mask_channel_prob (float):
            See :py:func:`hubert_pretrain_model`.
        mask_channel_length (int):
            See :py:func:`hubert_pretrain_model`.
        feature_grad_mult (float or None):
            See :py:func:`hubert_pretrain_model`.
        num_classes (int, optional):
            See :py:func:`hubert_pretrain_model`.

    Returns:
        HuBERTPretrainModel:
            The resulting model.
    """  # noqa: E501
    # "base" pre-training setup: wav2vec 2.0 "base" encoder, static span masking,
    # 256-dim final projection.
    return hubert_pretrain_model(
        extractor_mode="group_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=3072,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=False,
        encoder_layer_drop=encoder_layer_drop,
        mask_prob=mask_prob,
        mask_selection="static",
        mask_other=0.0,
        mask_length=10,
        no_mask_overlap=False,
        mask_min_space=1,
        mask_channel_prob=mask_channel_prob,
        mask_channel_selection="static",
        mask_channel_other=0.0,
        mask_channel_length=mask_channel_length,
        no_mask_channel_overlap=False,
        mask_channel_min_space=1,
        skip_masked=False,
        skip_nomask=False,
        num_classes=num_classes,
        final_dim=256,
        feature_grad_mult=feature_grad_mult,
    )
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
def hubert_pretrain_large(
    encoder_projection_dropout: float = 0.0,
    encoder_attention_dropout: float = 0.0,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.0,
    encoder_layer_drop: float = 0.0,
    mask_prob: float = 0.8,
    mask_channel_prob: float = 0.0,
    mask_channel_length: int = 10,
    feature_grad_mult: Optional[float] = None,
) -> HuBERTPretrainModel:
    """Builds "large" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.

    Args:
        encoder_projection_dropout (float):
            See :py:func:`hubert_pretrain_model`.
        encoder_attention_dropout (float):
            See :py:func:`hubert_pretrain_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`hubert_pretrain_model`.
        encoder_dropout (float):
            See :py:func:`hubert_pretrain_model`.
        encoder_layer_drop (float):
            See :py:func:`hubert_pretrain_model`.
        mask_prob (float):
            See :py:func:`hubert_pretrain_model`.
        mask_channel_prob (float):
            See :py:func:`hubert_pretrain_model`.
        mask_channel_length (int):
            See :py:func:`hubert_pretrain_model`.
        feature_grad_mult (float or None):
            See :py:func:`hubert_pretrain_model`.

    Returns:
        HuBERTPretrainModel:
            The resulting model.
    """  # noqa: E501
    # "large" pre-training setup: pre-norm 24-layer encoder, 500 target classes,
    # 768-dim final projection.
    return hubert_pretrain_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1024,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_norm_first=True,
        encoder_layer_drop=encoder_layer_drop,
        mask_prob=mask_prob,
        mask_selection="static",
        mask_other=0.0,
        mask_length=10,
        no_mask_overlap=False,
        mask_min_space=1,
        mask_channel_prob=mask_channel_prob,
        mask_channel_selection="static",
        mask_channel_other=0.0,
        mask_channel_length=mask_channel_length,
        no_mask_channel_overlap=False,
        mask_channel_min_space=1,
        skip_masked=False,
        skip_nomask=False,
        num_classes=500,
        final_dim=768,
        feature_grad_mult=feature_grad_mult,
    )
|
| 1141 |
+
|
| 1142 |
+
|
| 1143 |
+
def hubert_pretrain_xlarge(
    encoder_projection_dropout: float = 0.0,
    encoder_attention_dropout: float = 0.0,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.0,
    encoder_layer_drop: float = 0.0,
    mask_prob: float = 0.8,
    mask_channel_prob: float = 0.0,
    mask_channel_length: int = 10,
    feature_grad_mult: Optional[float] = None,
) -> HuBERTPretrainModel:
    """Builds "extra large" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.

    Args:
        encoder_projection_dropout (float):
            See :py:func:`hubert_pretrain_model`.
        encoder_attention_dropout (float):
            See :py:func:`hubert_pretrain_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`hubert_pretrain_model`.
        encoder_dropout (float):
            See :py:func:`hubert_pretrain_model`.
        encoder_layer_drop (float):
            See :py:func:`hubert_pretrain_model`.
        mask_prob (float):
            See :py:func:`hubert_pretrain_model`.
        mask_channel_prob (float):
            See :py:func:`hubert_pretrain_model`.
        mask_channel_length (int):
            See :py:func:`hubert_pretrain_model`.
        feature_grad_mult (float or None):
            See :py:func:`hubert_pretrain_model`.

    Returns:
        HuBERTPretrainModel:
            The resulting model.
    """  # noqa: E501
    # Architecture constants of the "xlarge" configuration: 48 transformer
    # layers, 1280-dim embeddings, 5120-dim feed-forward, 1024-dim final
    # projection. Only regularization/masking knobs are user-tunable.
    architecture = dict(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1280,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=48,
        encoder_num_heads=16,
        encoder_ff_interm_features=5120,
        encoder_layer_norm_first=True,
        mask_selection="static",
        mask_other=0.0,
        mask_length=10,
        no_mask_overlap=False,
        mask_min_space=1,
        mask_channel_selection="static",
        mask_channel_other=0.0,
        no_mask_channel_overlap=False,
        mask_channel_min_space=1,
        skip_masked=False,
        skip_nomask=False,
        num_classes=500,
        final_dim=1024,
    )
    return hubert_pretrain_model(
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_drop=encoder_layer_drop,
        mask_prob=mask_prob,
        mask_channel_prob=mask_channel_prob,
        mask_channel_length=mask_channel_length,
        feature_grad_mult=feature_grad_mult,
        **architecture,
    )
|
| 1214 |
+
|
| 1215 |
+
|
| 1216 |
+
def wavlm_model(
    extractor_mode: str,
    extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
    extractor_conv_bias: bool,
    encoder_embed_dim: int,
    encoder_projection_dropout: float,
    encoder_pos_conv_kernel: int,
    encoder_pos_conv_groups: int,
    encoder_num_layers: int,
    encoder_num_heads: int,
    encoder_num_buckets: int,
    encoder_max_distance: int,
    encoder_attention_dropout: float,
    encoder_ff_interm_features: int,
    encoder_ff_interm_dropout: float,
    encoder_dropout: float,
    encoder_layer_norm_first: bool,
    encoder_layer_drop: float,
    aux_num_out: Optional[int],
) -> Wav2Vec2Model:
    """Builds custom WavLM model :cite:`chen2022wavlm`.

    The architecture is compatible with Wav2Vec2 model :cite:`baevski2020wav2vec`,
    so the output object is :class:`~torchaudio.models.Wav2Vec2Model`. Most of the
    arguments have the same meaning as in :py:func:`~torchaudio.models.wav2vec2_model`,
    so please refer there for documentation.

    Args:
        extractor_mode (str): Operation mode of feature extractor.
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        extractor_conv_layer_config (list of integer tuples or None):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        extractor_conv_bias (bool):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_embed_dim (int):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_projection_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_pos_conv_kernel (int):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_pos_conv_groups (int):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_num_layers (int):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_num_heads (int):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_num_buckets (int):
            Number of buckets for relative position embedding.
        encoder_max_distance (int):
            Maximum distance for relative position embedding.
        encoder_attention_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_ff_interm_features (int):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_layer_norm_first (bool):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        aux_num_out (int or None):
            See :py:func:`~torchaudio.models.wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """
    if extractor_conv_layer_config is None:
        # Standard wav2vec2-style extractor: 7 conv layers of 512 channels,
        # overall stride 320 samples (~20 ms per frame at 16 kHz).
        extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

    extractor = components._get_feature_extractor(
        extractor_mode, extractor_conv_layer_config, extractor_conv_bias
    )
    # WavLM encoder differs from the wav2vec2 one by the gated relative
    # position embedding, parameterized by num_buckets / max_distance.
    transformer = components._get_wavlm_encoder(
        in_features=extractor_conv_layer_config[-1][0],
        embed_dim=encoder_embed_dim,
        dropout_input=encoder_projection_dropout,
        pos_conv_kernel=encoder_pos_conv_kernel,
        pos_conv_groups=encoder_pos_conv_groups,
        num_layers=encoder_num_layers,
        num_heads=encoder_num_heads,
        num_buckets=encoder_num_buckets,
        max_distance=encoder_max_distance,
        attention_dropout=encoder_attention_dropout,
        ff_interm_features=encoder_ff_interm_features,
        ff_interm_dropout=encoder_ff_interm_dropout,
        dropout=encoder_dropout,
        layer_norm_first=encoder_layer_norm_first,
        layer_drop=encoder_layer_drop,
    )
    # Optional linear read-out head (e.g. for fine-tuning with CTC).
    aux = None if aux_num_out is None else torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out)
    return Wav2Vec2Model(extractor, transformer, aux)
|
| 1326 |
+
|
| 1327 |
+
|
| 1328 |
+
def wavlm_base(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.1,
    encoder_ff_interm_dropout: float = 0.1,
    encoder_dropout: float = 0.1,
    encoder_layer_drop: float = 0.1,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "base" WavLM model :cite:`chen2022wavlm`.

    The architecture is compatible with Wav2Vec2 model :cite:`baevski2020wav2vec`,
    so the output class is :class:`~torchaudio.models.Wav2Vec2Model`.

    Args:
        encoder_projection_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        aux_num_out (int, optional):
            See :py:func:`~torchaudio.models.wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """
    # Architecture constants of the "base" configuration: 12 layers,
    # 768-dim embeddings, 12 heads, 3072-dim feed-forward, post-norm.
    architecture = dict(
        extractor_mode="group_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_num_buckets=320,
        encoder_max_distance=800,
        encoder_ff_interm_features=3072,
        encoder_layer_norm_first=False,
    )
    return wavlm_model(
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
        **architecture,
    )
|
| 1378 |
+
|
| 1379 |
+
|
| 1380 |
+
def wavlm_large(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.1,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.1,
    encoder_layer_drop: float = 0.1,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "large" WavLM model :cite:`chen2022wavlm`.

    The architecture is compatible with Wav2Vec2 model :cite:`baevski2020wav2vec`,
    so the output class is :class:`~torchaudio.models.Wav2Vec2Model`.

    Args:
        encoder_projection_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        aux_num_out (int, optional):
            See :py:func:`~torchaudio.models.wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """
    # Architecture constants of the "large" configuration: 24 layers,
    # 1024-dim embeddings, 16 heads, 4096-dim feed-forward, pre-norm.
    architecture = dict(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1024,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_num_buckets=320,
        encoder_max_distance=800,
        encoder_ff_interm_features=4096,
        encoder_layer_norm_first=True,
    )
    return wavlm_model(
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
        **architecture,
    )
|
| 1430 |
+
|
| 1431 |
+
|
| 1432 |
+
def wav2vec2_xlsr_300m(
    encoder_projection_dropout: float = 0.0,
    encoder_attention_dropout: float = 0.0,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.0,
    encoder_layer_drop: float = 0.0,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds XLS-R model :cite:`babu2021xls` with 300 millions of parameters.

    The architecture is compatible with Wav2Vec2 model :cite:`baevski2020wav2vec`,
    so the output class is :class:`~torchaudio.models.Wav2Vec2Model`.

    Args:
        encoder_projection_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        aux_num_out (int, optional):
            See :py:func:`~torchaudio.models.wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """
    # Architecture constants of the XLS-R 300M configuration: 24 layers,
    # 1024-dim embeddings, 16 heads, 4096-dim feed-forward, pre-norm.
    architecture = dict(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=True,
        encoder_embed_dim=1024,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_ff_interm_features=4096,
        encoder_layer_norm_first=True,
    )
    return wav2vec2_model(
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
        **architecture,
    )
|
| 1480 |
+
|
| 1481 |
+
|
| 1482 |
+
def wav2vec2_xlsr_1b(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.0,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.0,
    encoder_layer_drop: float = 0.0,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds XLS-R model :cite:`babu2021xls` with 1 billion of parameters.

    The architecture is compatible with Wav2Vec2 model :cite:`baevski2020wav2vec`,
    so the output class is :class:`~torchaudio.models.Wav2Vec2Model`.

    Args:
        encoder_projection_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        aux_num_out (int, optional):
            See :py:func:`~torchaudio.models.wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """
    # Architecture constants of the XLS-R 1B configuration: 48 layers,
    # 1280-dim embeddings, 16 heads, 5120-dim feed-forward, pre-norm.
    architecture = dict(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=True,
        encoder_embed_dim=1280,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=48,
        encoder_num_heads=16,
        encoder_ff_interm_features=5120,
        encoder_layer_norm_first=True,
    )
    return wav2vec2_model(
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
        **architecture,
    )
|
| 1530 |
+
|
| 1531 |
+
|
| 1532 |
+
def wav2vec2_xlsr_2b(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.0,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.0,
    encoder_layer_drop: float = 0.0,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds XLS-R model :cite:`babu2021xls` with 2 billions of parameters.

    The architecture is compatible with Wav2Vec2 model :cite:`baevski2020wav2vec`,
    so the output class is :class:`~torchaudio.models.Wav2Vec2Model`.

    Args:
        encoder_projection_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_attention_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_ff_interm_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_dropout (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        encoder_layer_drop (float):
            See :py:func:`~torchaudio.models.wav2vec2_model`.
        aux_num_out (int, optional):
            See :py:func:`~torchaudio.models.wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """
    # Architecture constants of the XLS-R 2B configuration: 48 layers,
    # 1920-dim embeddings, 16 heads, 7680-dim feed-forward, pre-norm.
    architecture = dict(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=True,
        encoder_embed_dim=1920,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=48,
        encoder_num_heads=16,
        encoder_ff_interm_features=7680,
        encoder_layer_norm_first=True,
    )
    return wav2vec2_model(
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
        **architecture,
    )
|
.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .import_fairseq import import_fairseq_model
|
| 2 |
+
from .import_huggingface import import_huggingface_model
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"import_huggingface_model",
|
| 6 |
+
"import_fairseq_model",
|
| 7 |
+
]
|
.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (401 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_fairseq.cpython-311.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_huggingface.cpython-311.pyc
ADDED
|
Binary file (7.91 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/wavlm_attention.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The MIT License (MIT)
|
| 3 |
+
|
| 4 |
+
Copyright (c) Microsoft Corporation
|
| 5 |
+
|
| 6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 7 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 8 |
+
in the Software without restriction, including without limitation the rights
|
| 9 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 10 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 11 |
+
furnished to do so, subject to the following conditions:
|
| 12 |
+
|
| 13 |
+
The above copyright notice and this permission notice shall be included in all
|
| 14 |
+
copies or substantial portions of the Software.
|
| 15 |
+
|
| 16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 22 |
+
SOFTWARE.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import math
|
| 26 |
+
from typing import Optional, Tuple
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
from torch import nn, Tensor
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class WavLMSelfAttention(nn.Module):
    """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`.
    Wraps around ``torch.nn.MultiheadAttention``, creating relative position embeddings and passing them to multi-headed
    attention as a mask.
    Source: https://github.com/microsoft/unilm/blob/2d8302f09c99bca2b82e6e868d81d4281cceebc8/wavlm/modules.py#L303-L763

    Args:
        embed_dim (int): Total dimension of the model.
        num_heads (int): The number of heads.
        dropout (float, optional): Dropout probability on attn_output_weights. (Default: to ``0.0``)
        bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``)
        has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding.
            Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``)
        num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``)
        max_distance (int, optional): Maximum distance for relative position embedding. (Default: ``128``)
        gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. (Default: ``False``)
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 32,
        max_distance: int = 128,
        gru_rel_pos: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance

        # One learned bias value per (bucket, head); only the first encoder
        # layer creates it — later layers reuse the bias passed via ``forward``.
        if has_relative_attention_bias:
            self.rel_attn_embed = nn.Embedding(num_buckets, num_heads)
        else:
            self.rel_attn_embed = None

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.dropout = dropout
        # Used only for its in/out projection parameters; the attention itself
        # is computed manually in ``forward`` (see scaled_dot_product_attention).
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True)

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            # Gating network for the gated relative position bias (WavLM eq. 4).
            self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
            self.gru_rel_pos_const = nn.Parameter(torch.ones(1, num_heads, 1, 1))
        self.has_position_bias = True

    def compute_bias(self, query_length: int, key_length: int) -> Tensor:
        """Compute relative position embeddings for WavLM model.
        Args:
            query_length (int): Query position can take values between 0 and ``query_length - 1``.
            key_length (int): Key position can take values between 0 and ``key_length - 1``.
        Returns:
            Tensor of shape `(num_heads, query_length, key_length)`, relative positions embeddings
        """
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # Shape (query_length, key_length)
        relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
        relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
        values = self.rel_attn_embed(relative_position_bucket)  # Shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1])
        return values

    def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True):
        """Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM
        paper :cite:`chen2022wavlm`.
        Args:
            relative_positions (Tensor): Relative offsets between query and key positions,
                of shape ``(query_length, key_length)``.
            bidirectional (bool): If ``True``, values will be filled both above and below the diagonal in the resulting
                matrix. If ``False``, the elements above the diagonal (i.e. with negative relative offsets) will be set
                to zero. (Default ``True``)
        Returns:
            Tensor of shape ``(query_length, key_length)`` filled bucketed values of with relative positions.
        """
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        # Shape (query_length, key_length)
        relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long)

        if bidirectional:
            # Half the buckets cover positive offsets, half negative ones.
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

        # Offsets below ``max_exact`` get one bucket each; larger offsets are
        # binned logarithmically up to ``max_distance``.
        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        relative_postion_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        # Clamp to the last bucket so offsets beyond max_distance share it.
        relative_postion_if_large = torch.min(
            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
        return relative_buckets

    def forward(
        self,
        query: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``.
            key_padding_mask (Tensor or None, optional): Mask to exclude keys that are pads, of shape
                `(batch, src_len)`, where padding elements are indicated by 1s. (Default: ``None``)
            attention_mask: Needs to be ``None``. The argument exists for compatibility with
                ``EncoderLayer``. (Default: ``None``)
            position_bias (Tensor or None, optional): Position bias of shape
                ``(batch_size * num_heads, src_len, src_len)``. When used inside WavLM model encoder, will be
                generated in the first layer and then passed from each encoder layer to the next one.
                (Default: ``None``)
        Returns:
            attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``.
            position_bias (Tensor or None): Position bias of shape ``(batch_size * num_heads, src_len, src_len)``.
        """
        bsz, seq_len, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert attention_mask is None

        # First layer only: build the relative position bias and broadcast it
        # over the batch; subsequent layers receive it via ``position_bias``.
        if self.rel_attn_embed is not None and position_bias is None:
            position_bias = self.compute_bias(seq_len, seq_len)
            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1)

        attn_mask_rel_pos: Optional[Tensor] = None
        if position_bias is not None:
            attn_mask_rel_pos = position_bias
            if self.gru_rel_pos:  # Apply gating on relative position bias
                query_layer = query.view(bsz, seq_len, self.num_heads, -1)
                query_layer = query_layer.permute(0, 2, 1, 3)

                # Gate values derived from the (ungated) query: project each
                # head to 8 dims, fold into two 4-dim groups, and sum each
                # group to get the two gates (WavLM gated relative position).
                gate_a, gate_b = torch.sigmoid(
                    self.gru_rel_pos_linear(query_layer).view(bsz, self.num_heads, seq_len, 2, 4).sum(-1, keepdim=False)
                ).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
                attn_mask_rel_pos = gate_a_1.view(bsz, self.num_heads, -1, 1) * position_bias

            attn_mask_rel_pos = attn_mask_rel_pos.view((bsz, self.num_heads, seq_len, seq_len))

        if attn_mask_rel_pos is not None and key_padding_mask is not None:
            key_padding_mask = key_padding_mask.view(bsz, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
            # NOTE: relies on private torch.nn.functional helpers to convert the
            # boolean padding mask into an additive float mask matching the
            # attention mask's dtype — may break across torch versions.
            key_padding_mask = torch.nn.functional._canonical_mask(
                mask=key_padding_mask,
                mask_name="key_padding_mask",
                other_type=torch.nn.functional._none_or_dtype(attn_mask_rel_pos),
                other_name="",
                target_type=query.dtype,
            )
        if attn_mask_rel_pos is not None and key_padding_mask is not None:
            # Fold the (now additive) padding mask into the position-bias mask.
            attn_mask_rel_pos = attn_mask_rel_pos + key_padding_mask
        # Manual attention using the MultiheadAttention module's parameters:
        # single fused in-projection, then split into Q/K/V.
        query_projected = torch.nn.functional.linear(query, self.attention.in_proj_weight, self.attention.in_proj_bias)
        query, key, value = query_projected.chunk(3, -1)
        shape = (bsz, seq_len, self.num_heads, self.head_dim)
        query = query.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
        key = key.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
        value = value.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
        # Attention dropout is applied only in training mode.
        dropout = self.dropout if self.training else 0.0
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask_rel_pos,
            dropout_p=dropout,
            is_causal=False,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, -1, self.num_heads * self.head_dim)
        attn_output = self.attention.out_proj(attn_output)
        return attn_output, position_bias
|
.venv/lib/python3.11/site-packages/torchaudio/models/wavernn.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import nn, Tensor
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"ResBlock",
|
| 10 |
+
"MelResNet",
|
| 11 |
+
"Stretch2d",
|
| 12 |
+
"UpsampleNetwork",
|
| 13 |
+
"WaveRNN",
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ResBlock(nn.Module):
    r"""ResNet block based on *Efficient Neural Audio Synthesis* :cite:`kalchbrenner2018efficient`.

    Two 1x1 convolutions, each followed by batch normalization (the first also
    by a ReLU), whose output is added back onto the input (residual connection).

    Args:
        n_freq: the number of bins in a spectrogram. (Default: ``128``)

    Examples
        >>> resblock = ResBlock()
        >>> input = torch.rand(10, 128, 512)  # a random spectrogram
        >>> output = resblock(input)  # shape: (10, 128, 512)
    """

    def __init__(self, n_freq: int = 128) -> None:
        super().__init__()

        # 1x1 convolutions keep the time dimension unchanged, so the residual
        # addition in ``forward`` is shape-compatible.
        layers = [
            nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
            nn.BatchNorm1d(n_freq),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
            nn.BatchNorm1d(n_freq),
        ]
        self.resblock_model = nn.Sequential(*layers)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the ResBlock layer.

        Args:
            specgram (Tensor): the input sequence to the ResBlock layer (n_batch, n_freq, n_time).

        Return:
            Tensor shape: (n_batch, n_freq, n_time)
        """
        transformed = self.resblock_model(specgram)
        return transformed + specgram
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class MelResNet(nn.Module):
    r"""MelResNet layer uses a stack of ResBlocks on spectrogram.

    Args:
        n_res_block: the number of ResBlock in stack. (Default: ``10``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
        n_output: the number of output dimensions of melresnet. (Default: ``128``)
        kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)

    Examples
        >>> melresnet = MelResNet()
        >>> input = torch.rand(10, 128, 512)  # a random spectrogram
        >>> output = melresnet(input)  # shape: (10, 128, 508)
    """

    def __init__(
        self, n_res_block: int = 10, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128, kernel_size: int = 5
    ) -> None:
        super().__init__()

        # Front conv (the only layer with kernel_size > 1) shrinks the time axis
        # by kernel_size - 1; the ResBlock stack and the 1x1 tail preserve it.
        modules: List[nn.Module] = [
            nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
        ]
        modules.extend(ResBlock(n_hidden) for _ in range(n_res_block))
        modules.append(nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1))
        self.melresnet_model = nn.Sequential(*modules)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the MelResNet layer.

        Args:
            specgram (Tensor): the input sequence to the MelResNet layer (n_batch, n_freq, n_time).

        Return:
            Tensor shape: (n_batch, n_output, n_time - kernel_size + 1)
        """
        return self.melresnet_model(specgram)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Stretch2d(nn.Module):
    r"""Upscale the frequency and time dimensions of a spectrogram.

    Each element is repeated ``freq_scale`` times along the frequency axis and
    ``time_scale`` times along the time axis (nearest-neighbor upsampling).

    Args:
        time_scale: the scale factor in time dimension
        freq_scale: the scale factor in frequency dimension

    Examples
        >>> stretch2d = Stretch2d(time_scale=10, freq_scale=5)

        >>> input = torch.rand(10, 100, 512)  # a random spectrogram
        >>> output = stretch2d(input)  # shape: (10, 500, 5120)
    """

    def __init__(self, time_scale: int, freq_scale: int) -> None:
        super().__init__()

        self.freq_scale = freq_scale
        self.time_scale = time_scale

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the Stretch2d layer.

        Args:
            specgram (Tensor): the input sequence to the Stretch2d layer (..., n_freq, n_time).

        Return:
            Tensor shape: (..., n_freq * freq_scale, n_time * time_scale)
        """
        stretched = specgram.repeat_interleave(self.freq_scale, -2)
        stretched = stretched.repeat_interleave(self.time_scale, -1)
        return stretched
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class UpsampleNetwork(nn.Module):
    r"""Upscale the dimensions of a spectrogram.

    Produces two time-aligned feature streams: the spectrogram upsampled with
    learned smoothing convolutions, and MelResNet conditioning features
    stretched to the same rate.

    Args:
        upsample_scales: the list of upsample scales.
        n_res_block: the number of ResBlock in stack. (Default: ``10``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
        n_output: the number of output dimensions of melresnet. (Default: ``128``)
        kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)

    Examples
        >>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16])
        >>> input = torch.rand(10, 128, 10)  # a random spectrogram
        >>> output = upsamplenetwork(input)  # shape: (10, 128, 1536), (10, 128, 1536)
    """

    def __init__(
        self,
        upsample_scales: List[int],
        n_res_block: int = 10,
        n_freq: int = 128,
        n_hidden: int = 128,
        n_output: int = 128,
        kernel_size: int = 5,
    ) -> None:
        super().__init__()

        total_scale = 1
        for factor in upsample_scales:
            total_scale *= factor
        self.total_scale: int = total_scale

        # Frames trimmed from each end of the upsampled spectrogram so it
        # aligns with the MelResNet output (which loses kernel_size - 1 frames).
        self.indent = (kernel_size - 1) // 2 * total_scale
        self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
        self.resnet_stretch = Stretch2d(total_scale, 1)

        up_layers: List[nn.Module] = []
        for scale in upsample_scales:
            width = scale * 2 + 1
            conv = nn.Conv2d(
                in_channels=1, out_channels=1, kernel_size=(1, width), padding=(0, scale), bias=False
            )
            # Initialize as a moving-average (smoothing) filter over the
            # nearest-neighbor-stretched frames.
            torch.nn.init.constant_(conv.weight, 1.0 / width)
            up_layers.append(Stretch2d(scale, 1))
            up_layers.append(conv)
        self.upsample_layers = nn.Sequential(*up_layers)

    def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]:
        r"""Pass the input through the UpsampleNetwork layer.

        Args:
            specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time)

        Return:
            Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale),
            (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
        where total_scale is the product of all elements in upsample_scales.
        """
        # Conditioning stream: MelResNet features stretched along time.
        resnet_output = self.resnet(specgram).unsqueeze(1)
        resnet_output = self.resnet_stretch(resnet_output).squeeze(1)

        # Spectrogram stream: stretch + smooth, then trim the border frames.
        upsampling_output = self.upsample_layers(specgram.unsqueeze(1)).squeeze(1)
        upsampling_output = upsampling_output[:, :, self.indent : -self.indent]

        return upsampling_output, resnet_output
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class WaveRNN(nn.Module):
    r"""WaveRNN model from *Efficient Neural Audio Synthesis* :cite:`wavernn`
    based on the implementation from `fatchord/WaveRNN <https://github.com/fatchord/WaveRNN>`_.

    The original implementation was introduced in *Efficient Neural Audio Synthesis*
    :cite:`kalchbrenner2018efficient`. The input channels of waveform and spectrogram have to be 1.
    The product of `upsample_scales` must equal `hop_length`.

    See Also:
        * `Training example <https://github.com/pytorch/audio/tree/release/0.12/examples/pipeline_wavernn>`__
        * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.

    Args:
        upsample_scales: the list of upsample scales.
        n_classes: the number of output classes.
        hop_length: the number of samples between the starts of consecutive frames.
        n_res_block: the number of ResBlock in stack. (Default: ``10``)
        n_rnn: the dimension of RNN layer. (Default: ``512``)
        n_fc: the dimension of fully connected layer. (Default: ``512``)
        kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
        n_output: the number of output dimensions of melresnet. (Default: ``128``)

    Example
        >>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
        >>> waveform, sample_rate = torchaudio.load(file)
        >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
        >>> specgram = MelSpectrogram(sample_rate)(waveform)  # shape: (n_batch, n_channel, n_freq, n_time)
        >>> output = wavernn(waveform, specgram)
        >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
    """

    def __init__(
        self,
        upsample_scales: List[int],
        n_classes: int,
        hop_length: int,
        n_res_block: int = 10,
        n_rnn: int = 512,
        n_fc: int = 512,
        kernel_size: int = 5,
        n_freq: int = 128,
        n_hidden: int = 128,
        n_output: int = 128,
    ) -> None:
        super().__init__()

        self.kernel_size = kernel_size
        # Padding (in spectrogram frames) applied to the time axis in ``infer``;
        # compensates for the frames consumed by MelResNet's first conv.
        self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2
        self.n_rnn = n_rnn
        # The conditioning features from the upsample network are split into
        # 4 equal chunks, one injected at each stage of the network.
        self.n_aux = n_output // 4
        self.hop_length = hop_length
        self.n_classes = n_classes
        # Bits per sample used by ``infer`` to map class labels to [-1, 1];
        # assumes n_classes is a power of two — TODO confirm for non-power-of-two values.
        self.n_bits: int = int(math.log2(self.n_classes))

        total_scale = 1
        for upsample_scale in upsample_scales:
            total_scale *= upsample_scale
        if total_scale != self.hop_length:
            raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")

        self.upsample = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
        # Input features: previous sample (1) + spectrogram frame (n_freq) + first aux chunk (n_aux).
        self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)

        self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
        self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)

        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

        self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
        self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
        self.fc3 = nn.Linear(n_fc, self.n_classes)

    def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
        r"""Pass the input through the WaveRNN model.

        Args:
            waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
            specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time)

        Return:
            Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
        """

        if waveform.size(1) != 1:
            raise ValueError("Require the input channel of waveform is 1")
        if specgram.size(1) != 1:
            raise ValueError("Require the input channel of specgram is 1")
        # remove channel dimension until the end
        waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)

        batch_size = waveform.size(0)
        # Fresh zero hidden states for both GRUs each call (teacher-forced training).
        h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
        h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
        # output of upsample:
        # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
        # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
        specgram, aux = self.upsample(specgram)
        specgram = specgram.transpose(1, 2)
        aux = aux.transpose(1, 2)

        # Split the aux features into 4 equal chunks along the feature axis.
        aux_idx = [self.n_aux * i for i in range(5)]
        a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
        a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
        a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
        a4 = aux[:, :, aux_idx[3] : aux_idx[4]]

        # Stage 1: combine previous sample, spectrogram frame and aux chunk 1.
        x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
        x = self.fc(x)
        res = x
        x, _ = self.rnn1(x, h1)

        # Residual connection around rnn1, then aux chunk 2 into rnn2.
        x = x + res
        res = x
        x = torch.cat([x, a2], dim=-1)
        x, _ = self.rnn2(x, h2)

        # Residual connection around rnn2, then the two FC stages with aux chunks 3 and 4.
        x = x + res
        x = torch.cat([x, a3], dim=-1)
        x = self.fc1(x)
        x = self.relu1(x)

        x = torch.cat([x, a4], dim=-1)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)

        # bring back channel dimension
        return x.unsqueeze(1)

    @torch.jit.export
    def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Inference method of WaveRNN.

        This function currently only supports multinomial sampling, which assumes the
        network is trained on cross entropy loss.

        Args:
            specgram (Tensor):
                Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
            lengths (Tensor or None, optional):
                Indicates the valid length of each audio in the batch.
                Shape: `(batch, )`.
                When the ``specgram`` contains spectrograms with different durations,
                by providing ``lengths`` argument, the model will compute
                the corresponding valid output lengths.
                If ``None``, it is assumed that all the audio in ``waveforms``
                have valid length. Default: ``None``.

        Returns:
            (Tensor, Optional[Tensor]):
            Tensor
                The inferred waveform of size `(n_batch, 1, n_time)`.
                1 stands for a single channel.
            Tensor or None
                If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                is returned.
                It indicates the valid length in time axis of the output Tensor.
        """

        device = specgram.device
        dtype = specgram.dtype

        # Pad the time axis so the MelResNet conv does not shorten the output.
        specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad))
        specgram, aux = self.upsample(specgram)
        if lengths is not None:
            # One output sample per upsampled spectrogram frame.
            lengths = lengths * self.upsample.total_scale

        output: List[Tensor] = []
        b_size, _, seq_len = specgram.size()

        h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
        h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
        # x holds the previously generated sample; starts at silence (0).
        x = torch.zeros((b_size, 1), device=device, dtype=dtype)

        aux_split = [aux[:, self.n_aux * i : self.n_aux * (i + 1), :] for i in range(4)]

        # Autoregressive loop: one output sample per time step, feeding each
        # sampled value back in as the next step's input.
        for i in range(seq_len):

            m_t = specgram[:, :, i]

            a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]

            x = torch.cat([x, m_t, a1_t], dim=1)
            x = self.fc(x)
            _, h1 = self.rnn1(x.unsqueeze(1), h1)

            # Residual connections mirror ``forward``, but here the GRU hidden
            # states are carried across time steps instead of whole sequences.
            x = x + h1[0]
            inp = torch.cat([x, a2_t], dim=1)
            _, h2 = self.rnn2(inp.unsqueeze(1), h2)

            x = x + h2[0]
            x = torch.cat([x, a3_t], dim=1)
            x = F.relu(self.fc1(x))

            x = torch.cat([x, a4_t], dim=1)
            x = F.relu(self.fc2(x))

            logits = self.fc3(x)

            posterior = F.softmax(logits, dim=1)

            # Sample a class label from the categorical distribution.
            x = torch.multinomial(posterior, 1).float()
            # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
            x = 2 * x / (2**self.n_bits - 1.0) - 1.0

            output.append(x)

        # (seq_len, n_batch, 1) -> (n_batch, 1, seq_len)
        return torch.stack(output).permute(1, 2, 0), lengths
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (193 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .musan import Musan
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
__all__ = ["Musan"]
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (278 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/musan.cpython-311.pyc
ADDED
|
Binary file (3.72 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/musan.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
from torchaudio.datasets.utils import _load_waveform
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
_SUBSETS = ["music", "noise", "speech"]
_SAMPLE_RATE = 16_000


class Musan(Dataset):
    r"""*MUSAN* :cite:`musan2015` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top-level directory exists.
        subset (str): Subset of the dataset to use. Options: [``"music"``, ``"noise"``, ``"speech"``].
    """

    def __init__(self, root: Union[str, Path], subset: str):
        if subset not in _SUBSETS:
            raise ValueError(f"Invalid subset '{subset}' given. Please provide one of {_SUBSETS}")

        # Audio files live two levels below the subset directory.
        self._walker = [str(candidate) for candidate in (Path(root) / subset).glob("*/*.*")]

    def get_metadata(self, n: int) -> Tuple[str, int, str]:
        r"""Get metadata for the n-th sample in the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): Index of sample to be loaded.

        Returns:
            (str, int, str):
            str
                Path to audio.
            int
                Sample rate.
            str
                File name.
        """
        filepath = self._walker[n]
        return filepath, _SAMPLE_RATE, Path(filepath).name

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str]:
        r"""Return the n-th sample in the dataset.

        Args:
            n (int): Index of sample to be loaded.

        Returns:
            (torch.Tensor, int, str):
            torch.Tensor
                Waveform.
            int
                Sample rate.
            str
                File name.
        """
        audio_path, sample_rate, filename = self.get_metadata(n)
        resolved = Path(audio_path)
        waveform = _load_waveform(resolved.parent, resolved.name, sample_rate)
        return waveform, sample_rate, filename

    def __len__(self) -> int:
        return len(self._walker)
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._dsp import (
|
| 2 |
+
adsr_envelope,
|
| 3 |
+
exp_sigmoid,
|
| 4 |
+
extend_pitch,
|
| 5 |
+
filter_waveform,
|
| 6 |
+
frequency_impulse_response,
|
| 7 |
+
oscillator_bank,
|
| 8 |
+
sinc_impulse_response,
|
| 9 |
+
)
|
| 10 |
+
from ._rir import ray_tracing, simulate_rir_ism
|
| 11 |
+
from .functional import barkscale_fbanks, chroma_filterbank
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"adsr_envelope",
|
| 16 |
+
"exp_sigmoid",
|
| 17 |
+
"barkscale_fbanks",
|
| 18 |
+
"chroma_filterbank",
|
| 19 |
+
"extend_pitch",
|
| 20 |
+
"filter_waveform",
|
| 21 |
+
"frequency_impulse_response",
|
| 22 |
+
"oscillator_bank",
|
| 23 |
+
"ray_tracing",
|
| 24 |
+
"sinc_impulse_response",
|
| 25 |
+
"simulate_rir_ism",
|
| 26 |
+
]
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (779 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_dsp.cpython-311.pyc
ADDED
|
Binary file (20.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_rir.cpython-311.pyc
ADDED
|
Binary file (21.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/functional.cpython-311.pyc
ADDED
|
Binary file (8.67 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_dsp.py
ADDED
|
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import List, Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from torchaudio.functional import fftconvolve
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def oscillator_bank(
    frequencies: torch.Tensor,
    amplitudes: torch.Tensor,
    sample_rate: float,
    reduction: str = "sum",
    dtype: Optional[torch.dtype] = torch.float64,
) -> torch.Tensor:
    """Synthesize waveform from the given instantaneous frequencies and amplitudes.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Note:
        The phase information of the output waveform is found by taking the cumulative sum
        of the given instantaneous frequencies (``frequencies``).
        This incurs roundoff error when the data type does not have enough precision.
        Using ``torch.float64`` can work around this.

        The following figure shows the difference between ``torch.float32`` and
        ``torch.float64`` when generating a sin wave of constant frequency and amplitude
        with sample rate 8000 [Hz].
        Notice that ``torch.float32`` version shows artifacts that are not seen in
        ``torch.float64`` version.

    .. image:: https://download.pytorch.org/torchaudio/doc-assets/oscillator_precision.png

    Args:
        frequencies (Tensor): Sample-wise oscillator frequencies (Hz). Shape `(..., time, N)`.
        amplitudes (Tensor): Sample-wise oscillator amplitude. Shape: `(..., time, N)`.
        sample_rate (float): Sample rate
        reduction (str): Reduction to perform.
            Valid values are ``"sum"``, ``"mean"`` or ``"none"``. Default: ``"sum"``
        dtype (torch.dtype or None, optional): The data type on which cumulative sum operation is performed.
            Default: ``torch.float64``. Pass ``None`` to disable the casting.

    Returns:
        Tensor:
            The resulting waveform.

            If ``reduction`` is ``"none"``, then the shape is
            `(..., time, N)`, otherwise the shape is `(..., time)`.
    """
    if frequencies.shape != amplitudes.shape:
        raise ValueError(
            "The shapes of `frequencies` and `amplitudes` must match. "
            f"Found: {frequencies.shape} and {amplitudes.shape} respectively."
        )
    valid_reductions = ["sum", "mean", "none"]
    if reduction not in valid_reductions:
        raise ValueError(f"The value of reduction must be either {valid_reductions}. Found: {reduction}")

    # Oscillators at or above the Nyquist rate would alias; silence them.
    over_nyquist = torch.abs(frequencies) >= sample_rate / 2
    if torch.any(over_nyquist):
        warnings.warn(
            "Some frequencies are above nyquist frequency. "
            "Setting the corresponding amplitude to zero. "
            "This might cause numerically unstable gradient."
        )
        amplitudes = torch.where(over_nyquist, 0.0, amplitudes)

    # Per-sample angular increments, wrapped into [0, 2*pi) before accumulation.
    two_pi = 2.0 * torch.pi
    angular_steps = frequencies * two_pi / sample_rate % two_pi
    # Accumulate phase in the (typically higher-precision) ``dtype``, then
    # cast back so the output dtype matches the inputs.
    phase = torch.cumsum(angular_steps, dim=-2, dtype=dtype)
    if dtype is not None and angular_steps.dtype != dtype:
        phase = phase.to(angular_steps.dtype)

    waveform = amplitudes * torch.sin(phase)
    if reduction == "none":
        return waveform
    return waveform.sum(-1) if reduction == "sum" else waveform.mean(-1)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def adsr_envelope(
    num_frames: int,
    *,
    attack: float = 0.0,
    hold: float = 0.0,
    decay: float = 0.0,
    sustain: float = 1.0,
    release: float = 0.0,
    n_decay: int = 2,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
):
    """Generate ADSR Envelope

    .. devices:: CPU CUDA

    Args:
        num_frames (int): The number of output frames.
        attack (float, optional):
            The relative *time* it takes to reach the maximum level from
            the start. (Default: ``0.0``)
        hold (float, optional):
            The relative *time* the maximum level is held before
            it starts to decay. (Default: ``0.0``)
        decay (float, optional):
            The relative *time* it takes to sustain from
            the maximum level. (Default: ``0.0``)
        sustain (float, optional): The relative *level* at which
            the sound should sustain. (Default: ``1.0``)

            .. Note::
                The duration of sustain is derived as `1.0 - (The sum of attack, hold, decay and release)`.

        release (float, optional): The relative *time* it takes for the sound level to
            reach zero after the sustain. (Default: ``0.0``)
        n_decay (int, optional): The degree of polynomial decay. Default: ``2``.
        dtype (torch.dtype, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default
            (see :py:func:`torch.set_default_tensor_type`).
        device (torch.device, optional): the desired device of returned tensor.
            Default: if ``None``, uses the current device for the default tensor type
            (see :py:func:`torch.set_default_tensor_type`).
            device will be the CPU for CPU tensor types and the current CUDA
            device for CUDA tensor types.

    Returns:
        Tensor: ADSR Envelope. Shape: `(num_frames, )`

    Example
        .. image:: https://download.pytorch.org/torchaudio/doc-assets/adsr_examples.png

    """
    # All stage parameters are relative durations/levels, so each must lie in [0, 1].
    for param_name, param_value in [
        ("attack", attack),
        ("decay", decay),
        ("sustain", sustain),
        ("hold", hold),
        ("release", release),
    ]:
        if not 0 <= param_value <= 1:
            raise ValueError(f"The value of `{param_name}` must be within [0, 1]. Found: {param_value}")
    if attack + decay + release + hold > 1:
        raise ValueError("The sum of `attack`, `hold`, `decay` and `release` must not exceed 1.")

    # Stage lengths are measured in intervals (num_frames - 1), so that the
    # envelope endpoints land exactly on the first/last frame.
    num_intervals = num_frames - 1
    a_len = int(num_intervals * attack)
    h_len = int(num_intervals * hold)
    d_len = int(num_intervals * decay)
    r_len = int(num_intervals * release)

    # Start from the sustain level everywhere; the other stages overwrite their spans.
    envelope = torch.full((num_frames,), float(sustain), device=device, dtype=dtype)

    # Attack: linear ramp 0 -> 1, written in place into the head of the buffer.
    if a_len > 0:
        torch.linspace(0.0, 1.0, a_len + 1, out=envelope[: a_len + 1])

    # Hold: stay at the maximum level.
    if h_len > 0:
        envelope[a_len : a_len + h_len + 1] = 1.0

    # Decay: polynomial fall from 1 to `sustain`, computed in place on a view as
    # sustain + (1 - sustain) * linspace(1, 0) ** n_decay.
    if d_len > 0:
        offset = a_len + h_len
        segment = envelope[offset : offset + d_len + 1]
        torch.linspace(1.0, 0.0, d_len + 1, out=segment)
        segment.pow_(n_decay).mul_(1.0 - sustain).add_(sustain)

    # Sustain needs no work: the buffer was initialized at that level.

    # Release: linear ramp sustain -> 0 over the tail of the buffer.
    if r_len > 0:
        torch.linspace(sustain, 0, r_len + 1, out=envelope[-r_len - 1 :])

    return envelope
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def extend_pitch(
    base: torch.Tensor,
    pattern: Union[int, List[float], torch.Tensor],
):
    """Extend the given time series values with multipliers of them.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Given a series of fundamental frequencies (pitch), this function appends
    its harmonic overtones or inharmonic partials.

    Args:
        base (torch.Tensor):
            Base time series, like fundamental frequencies (Hz). Shape: `(..., time, 1)`.
        pattern (int, list of floats or torch.Tensor):
            If ``int``, the number of pitch series after the operation.
            `pattern - 1` tones are added, so that the resulting Tensor contains
            up to `pattern`-th overtones of the given series.

            If list of float or ``torch.Tensor``, it must be one dimensional,
            representing the custom multiplier of the fundamental frequency.

    Returns:
        Tensor: Oscillator frequencies (Hz). Shape: `(..., time, num_tones)`.

    Example
        >>> f0 = torch.linspace(1, 5, 5).unsqueeze(-1)
        >>> # Add harmonic overtones, up to 3rd.
        >>> extend_pitch(f0, 3)
        tensor([[ 1.,  2.,  3.],
                [ 2.,  4.,  6.],
                [ 3.,  6.,  9.],
                [ 4.,  8., 12.],
                [ 5., 10., 15.]])
        >>> # Add custom (inharmonic) partials.
        >>> extend_pitch(f0, torch.tensor([1, 2.1, 3.3, 4.5]))
        tensor([[ 1.0000,  2.1000,  3.3000,  4.5000],
                [ 2.0000,  4.2000,  6.6000,  9.0000],
                [ 3.0000,  6.3000,  9.9000, 13.5000],
                [ 4.0000,  8.4000, 13.2000, 18.0000],
                [ 5.0000, 10.5000, 16.5000, 22.5000]])
    """
    # Normalize `pattern` into a 1D multiplier tensor.
    if isinstance(pattern, int):
        # Harmonic overtones: multipliers 1, 2, ..., pattern.
        factors = torch.linspace(1.0, float(pattern), pattern, device=base.device, dtype=base.dtype)
    elif isinstance(pattern, torch.Tensor):
        # Use the caller-provided multipliers as-is (dtype/device untouched).
        factors = pattern
    else:
        factors = torch.tensor(pattern, dtype=base.dtype, device=base.device)
    # (..., time, 1) @ (1, num_tones) -> (..., time, num_tones)
    return base @ factors.unsqueeze(0)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def sinc_impulse_response(cutoff: torch.Tensor, window_size: int = 513, high_pass: bool = False):
    """Create windowed-sinc impulse response for given cutoff frequencies.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        cutoff (Tensor): Cutoff frequencies for low-pass sinc filter.

        window_size (int, optional): Size of the Hamming window to apply. Must be odd.
            (Default: 513)

        high_pass (bool, optional):
            If ``True``, convert the resulting filter to high-pass.
            Otherwise low-pass filter is returned. Default: ``False``.

    Returns:
        Tensor: A series of impulse responses. Shape: `(..., window_size)`.
    """
    # An odd window keeps the filter symmetric around a single center tap.
    if window_size % 2 == 0:
        raise ValueError(f"`window_size` must be odd. Given: {window_size}")

    half = window_size // 2
    device, dtype = cutoff.device, cutoff.dtype

    # Sample positions centered at zero: -half, ..., 0, ..., half.
    taps = torch.linspace(-half, half, window_size, device=device, dtype=dtype)
    window = torch.hamming_window(window_size, device=device, dtype=dtype, periodic=False)

    # Windowed sinc, one row per cutoff, then normalize so the DC gain is one.
    ir = torch.special.sinc(cutoff.unsqueeze(-1) * taps.unsqueeze(0)) * window.unsqueeze(0)
    ir = ir / ir.sum(dim=-1, keepdim=True).abs()

    # High-pass IR is the delta function minus the low-pass IR.
    # https://courses.engr.illinois.edu/ece401/fa2020/slides/lec10.pdf
    if high_pass:
        ir = -ir
        ir[..., half] += 1.0
    return ir
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def frequency_impulse_response(magnitudes):
    """Create filter from desired frequency response

    Args:
        magnitudes: The desired frequency responses. Shape: `(..., num_fft_bins)`

    Returns:
        Tensor: Impulse response. Shape `(..., 2 * (num_fft_bins - 1))`
    """
    device, dtype = magnitudes.device, magnitudes.dtype
    # A negative magnitude is not physically meaningful, but we only warn
    # (instead of raising) so that autograd keeps working around zero.
    if magnitudes.min() < 0.0:
        warnings.warn("The input frequency response should not contain negative values.")
    # Inverse real FFT gives the time-domain filter; fftshift centers its peak.
    ir = torch.fft.fftshift(torch.fft.irfft(magnitudes), dim=-1)
    # Taper with a Hann window to reduce ringing from the sharp frequency spec.
    window = torch.hann_window(ir.size(-1), periodic=False, device=device, dtype=dtype).expand_as(ir)
    return ir * window
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def _overlap_and_add(waveform, stride):
|
| 312 |
+
num_frames, frame_size = waveform.shape[-2:]
|
| 313 |
+
numel = (num_frames - 1) * stride + frame_size
|
| 314 |
+
buffer = torch.zeros(waveform.shape[:-2] + (numel,), device=waveform.device, dtype=waveform.dtype)
|
| 315 |
+
for i in range(num_frames):
|
| 316 |
+
start = i * stride
|
| 317 |
+
end = start + frame_size
|
| 318 |
+
buffer[..., start:end] += waveform[..., i, :]
|
| 319 |
+
return buffer
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def filter_waveform(waveform: torch.Tensor, kernels: torch.Tensor, delay_compensation: int = -1):
    """Applies filters along time axis of the given waveform.

    This function applies the given filters along time axis in the following manner:

    1. Split the given waveform into chunks. The number of chunks is equal to the number of given filters.
    2. Filter each chunk with corresponding filter.
    3. Place the filtered chunks at the original indices while adding up the overlapping parts.
    4. Crop the resulting waveform so that delay introduced by the filter is removed and its length
       matches that of the input waveform.

    The following figure illustrates this.

    .. image:: https://download.pytorch.org/torchaudio/doc-assets/filter_waveform.png

    .. note::

       If the number of filters is one, then the operation becomes stationary.
       i.e. the same filtering is applied across the time axis.

    Args:
        waveform (Tensor): Shape `(..., time)`.
        kernels (Tensor): Impulse responses.
            Valid inputs are 2D tensor with shape `(num_filters, filter_length)` or
            `(N+1)`-D tensor with shape `(..., num_filters, filter_length)`, where `N` is
            the dimension of waveform.

            In case of 2D input, the same set of filters is used across channels and batches.
            Otherwise, different sets of filters are applied. In this case, the shape of
            the first `N-1` dimensions of filters must match (or be broadcastable to) that of waveform.

        delay_compensation (int): Control how the waveform is cropped after full convolution.
            If the value is zero or positive, it is interpreted as the length of crop at the
            beginning of the waveform. The value cannot be larger than the size of filter kernel.
            Otherwise the initial crop is ``filter_size // 2``.
            When cropping happens, the waveform is also cropped from the end so that the
            length of the resulting waveform matches the input waveform.

    Returns:
        Tensor: `(..., time)`.

    Raises:
        ValueError: If ``kernels`` has an unsupported number of dimensions, or
            ``delay_compensation`` exceeds the filter size.
    """
    if kernels.ndim not in [2, waveform.ndim + 1]:
        raise ValueError(
            "`kernels` must be 2 or N+1 dimension where "
            f"N is the dimension of waveform. Found: {kernels.ndim} (N={waveform.ndim})"
        )

    num_filters, filter_size = kernels.shape[-2:]
    num_frames = waveform.size(-1)

    if delay_compensation > filter_size:
        raise ValueError(
            "When `delay_compensation` is provided, it cannot be larger than the size of filters. "
            f"Found: delay_compensation={delay_compensation}, filter_size={filter_size}"
        )

    # Transform waveform's time axis into (num_filters x chunk_length) with optional padding
    chunk_length = num_frames // num_filters
    if num_frames % num_filters > 0:
        chunk_length += 1
        num_pad = chunk_length * num_filters - num_frames
        waveform = torch.nn.functional.pad(waveform, [0, num_pad], "constant", 0)
    chunked = waveform.unfold(-1, chunk_length, chunk_length)
    assert chunked.numel() >= waveform.numel()

    # Broadcast 2D kernels so the same filter bank covers all leading dimensions.
    if waveform.ndim + 1 > kernels.ndim:
        expand_shape = waveform.shape[:-1] + kernels.shape
        kernels = kernels.expand(expand_shape)

    convolved = fftconvolve(chunked, kernels)
    restored = _overlap_and_add(convolved, chunk_length)

    # Trim so that the number of samples matches the input and the filter delay
    # is compensated.
    if delay_compensation >= 0:
        start = delay_compensation
    else:
        start = filter_size // 2
    # Use an explicit stop index instead of a negative one: with the previous
    # `restored[..., start:-end]` form, `end == 0` (e.g. filter_size == 1 with
    # no padding) produced an empty tensor because `-0 == 0`.
    return restored[..., start : start + num_frames]
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def exp_sigmoid(
    input: torch.Tensor, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7
) -> torch.Tensor:
    """Exponential Sigmoid pointwise nonlinearity.

    Implements the equation:
    ``max_value`` * sigmoid(``input``) ** (log(``exponent``)) + ``threshold``

    The output has a range of [``threshold``, ``max_value``].
    ``exponent`` controls the slope of the output.

    .. devices:: CPU CUDA

    Args:
        input (Tensor): Input Tensor
        exponent (float, optional): Exponent. Controls the slope of the output
        max_value (float, optional): Maximum value of the output
        threshold (float, optional): Minimum value of the output

    Returns:
        Tensor: Exponential Sigmoid output. Shape: same as input
    """
    # Materialize the scalars on the input's device/dtype so the arithmetic
    # stays on-device and in the right precision.
    log_exponent = torch.log(torch.tensor(exponent, device=input.device, dtype=input.dtype))
    floor = torch.tensor(threshold, device=input.device, dtype=input.dtype)
    return max_value * torch.sigmoid(input) ** log_exponent + floor
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_rir.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torchaudio
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _compute_image_sources(
|
| 10 |
+
room: torch.Tensor,
|
| 11 |
+
source: torch.Tensor,
|
| 12 |
+
max_order: int,
|
| 13 |
+
absorption: torch.Tensor,
|
| 14 |
+
scatter: Optional[torch.Tensor] = None,
|
| 15 |
+
) -> Tuple[Tensor, Tensor]:
|
| 16 |
+
"""Compute image sources in a shoebox-like room.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
room (torch.Tensor): The 1D Tensor to determine the room size. The shape is
|
| 20 |
+
`(D,)`, where ``D`` is 2 if room is a 2D room, or 3 if room is a 3D room.
|
| 21 |
+
source (torch.Tensor): The coordinate of the sound source. Tensor with dimensions
|
| 22 |
+
`(D)`.
|
| 23 |
+
max_order (int): The maximum number of reflections of the source.
|
| 24 |
+
absorption (torch.Tensor): The absorption coefficients of wall materials.
|
| 25 |
+
``absorption`` is a Tensor with dimensions `(num_band, num_wall)`.
|
| 26 |
+
The shape options are ``[(1, 4), (1, 6), (7, 4), (7, 6)]``.
|
| 27 |
+
``num_band`` is `1` if the coefficients is the same for all frequencies, or is `7`
|
| 28 |
+
if the coefficients are different to different frequencies. `7` refers to the default number
|
| 29 |
+
of octave bands. (See note in `simulate_rir_ism` method).
|
| 30 |
+
``num_wall`` is `4` if the room is a 2D room, representing absorption coefficients
|
| 31 |
+
of ``"west"``, ``"east"``, ``"south"``, and ``"north"`` walls, respectively.
|
| 32 |
+
Or it is `6` if the room is a 3D room, representing absorption coefficients
|
| 33 |
+
of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively.
|
| 34 |
+
scatter (torch.Tensor): The scattering coefficients of wall materials.
|
| 35 |
+
The shape of ``scatter`` must match that of ``absorption``. If ``None``, it is not
|
| 36 |
+
used in image source computation. (Default: ``None``)
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
(torch.Tensor): The coordinates of all image sources within ``max_order`` number of reflections.
|
| 40 |
+
Tensor with dimensions `(num_image_source, D)`.
|
| 41 |
+
(torch.Tensor): The attenuation of corresponding image sources. Tensor with dimensions
|
| 42 |
+
`(num_band, num_image_source)`.
|
| 43 |
+
"""
|
| 44 |
+
if scatter is None:
|
| 45 |
+
tr = torch.sqrt(1 - absorption)
|
| 46 |
+
else:
|
| 47 |
+
tr = torch.sqrt(1 - absorption) * torch.sqrt(1 - scatter)
|
| 48 |
+
|
| 49 |
+
ind = torch.arange(-max_order, max_order + 1, device=source.device)
|
| 50 |
+
if room.shape[0] == 2:
|
| 51 |
+
XYZ = torch.meshgrid(ind, ind, indexing="ij")
|
| 52 |
+
else:
|
| 53 |
+
XYZ = torch.meshgrid(ind, ind, ind, indexing="ij")
|
| 54 |
+
XYZ = torch.stack([c.reshape((-1,)) for c in XYZ], dim=-1)
|
| 55 |
+
XYZ = XYZ[XYZ.abs().sum(dim=-1) <= max_order]
|
| 56 |
+
|
| 57 |
+
# compute locations of image sources
|
| 58 |
+
d = room[None, :]
|
| 59 |
+
s = source[None, :]
|
| 60 |
+
img_loc = torch.where(XYZ % 2 == 1, d * (XYZ + 1) - s, d * XYZ + s)
|
| 61 |
+
|
| 62 |
+
# attenuation
|
| 63 |
+
exp_lo = abs(torch.floor((XYZ / 2)))
|
| 64 |
+
exp_hi = abs(torch.floor((XYZ + 1) / 2))
|
| 65 |
+
t_lo = tr[:, ::2].unsqueeze(1).repeat(1, XYZ.shape[0], 1) # (num_band, left walls)
|
| 66 |
+
t_hi = tr[:, 1::2].unsqueeze(1).repeat(1, XYZ.shape[0], 1) # (num_band, right walls)
|
| 67 |
+
att = torch.prod((t_lo**exp_lo) * (t_hi**exp_hi), dim=-1) # (num_band, num_image_source)
|
| 68 |
+
return img_loc, att
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _hann(x: torch.Tensor, T: int):
|
| 72 |
+
"""Compute the Hann window where the values are truncated based on window length.
|
| 73 |
+
torch.hann_window can only sample window function at integer points, the method is to sample
|
| 74 |
+
continuous window function at non-integer points.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
x (torch.Tensor): The fractional component of time delay Tensor.
|
| 78 |
+
T (torch.Tensor): The window length of sinc function.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
(torch.Tensor): The hann window Tensor where values outside
|
| 82 |
+
the sinc window (`T`) is set to zero.
|
| 83 |
+
"""
|
| 84 |
+
y = torch.where(
|
| 85 |
+
torch.abs(x) <= T / 2,
|
| 86 |
+
0.5 * (1 + torch.cos(2 * math.pi * x / T)),
|
| 87 |
+
x.new_zeros(1),
|
| 88 |
+
)
|
| 89 |
+
return y
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _frac_delay(delay: torch.Tensor, delay_i: torch.Tensor, delay_filter_length: int):
    """Compute fractional delay of impulse response signal.

    Args:
        delay (torch.Tensor): The time delay Tensor in samples.
        delay_i (torch.Tensor): The integer part of delay.
        delay_filter_length (int): The window length for sinc function.

    Returns:
        (torch.Tensor): The impulse response Tensor for all image sources.
    """
    # A symmetric (odd-length) filter is required so it centers on the delay.
    if delay_filter_length % 2 != 1:
        raise ValueError("The filter length must be odd")

    half = delay_filter_length // 2
    taps = torch.arange(-half, half + 1, device=delay.device) + delay_i[..., None]
    # Offset of each tap from the true (fractional) delay.
    offset = taps - delay[..., None]
    # Windowed-sinc interpolation kernel.
    return torch.special.sinc(offset) * _hann(offset, 2 * half)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _adjust_coeff(coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor:
|
| 114 |
+
"""Validates and converts absorption or scattering parameters to a tensor with appropriate shape
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
coeff (float or torch.Tensor): The absorption coefficients of wall materials.
|
| 118 |
+
|
| 119 |
+
If the dtype is ``float``, the absorption coefficient is identical for all walls and
|
| 120 |
+
all frequencies.
|
| 121 |
+
|
| 122 |
+
If ``absorption`` is a 1D Tensor, the shape must be `(2*dim,)`,
|
| 123 |
+
where the values represent absorption coefficients of ``"west"``, ``"east"``,
|
| 124 |
+
``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively.
|
| 125 |
+
|
| 126 |
+
If ``absorption`` is a 2D Tensor, the shape must be `(7, 2*dim)`,
|
| 127 |
+
where 7 represents the number of octave bands.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
(torch.Tensor): The expanded coefficient.
|
| 131 |
+
The shape is `(1, 6)` for single octave band case, and
|
| 132 |
+
`(7, 6)` for multi octave band case.
|
| 133 |
+
"""
|
| 134 |
+
num_walls = 6
|
| 135 |
+
if isinstance(coeffs, float):
|
| 136 |
+
if coeffs < 0:
|
| 137 |
+
raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
|
| 138 |
+
return torch.full((1, num_walls), coeffs)
|
| 139 |
+
if isinstance(coeffs, Tensor):
|
| 140 |
+
if torch.any(coeffs < 0):
|
| 141 |
+
raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
|
| 142 |
+
if coeffs.ndim == 1:
|
| 143 |
+
if coeffs.numel() != num_walls:
|
| 144 |
+
raise ValueError(
|
| 145 |
+
f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor. "
|
| 146 |
+
f"Found the shape {coeffs.shape}."
|
| 147 |
+
)
|
| 148 |
+
return coeffs.unsqueeze(0)
|
| 149 |
+
if coeffs.ndim == 2:
|
| 150 |
+
if coeffs.shape[1] != num_walls:
|
| 151 |
+
raise ValueError(
|
| 152 |
+
f"The shape of `{name}` must be (NUM_BANDS, {num_walls}) when it "
|
| 153 |
+
f"is a 2D Tensor. Found: {coeffs.shape}."
|
| 154 |
+
)
|
| 155 |
+
return coeffs
|
| 156 |
+
raise TypeError(f"`{name}` must be float or Tensor.")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _validate_inputs(
|
| 160 |
+
room: torch.Tensor,
|
| 161 |
+
source: torch.Tensor,
|
| 162 |
+
mic_array: torch.Tensor,
|
| 163 |
+
):
|
| 164 |
+
"""Validate dimensions of input arguments, and normalize different kinds of absorption into the same dimension.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
room (torch.Tensor): The size of the room. width, length (and height)
|
| 168 |
+
source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(dim,)`.
|
| 169 |
+
mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, dim)`.
|
| 170 |
+
"""
|
| 171 |
+
if not (room.ndim == 1 and room.numel() == 3):
|
| 172 |
+
raise ValueError(f"`room` must be a 1D Tensor with 3 elements. Found {room.shape}.")
|
| 173 |
+
if not (source.ndim == 1 and source.numel() == 3):
|
| 174 |
+
raise ValueError(f"`source` must be 1D Tensor with 3 elements. Found {source.shape}.")
|
| 175 |
+
if not (mic_array.ndim == 2 and mic_array.shape[1] == 3):
|
| 176 |
+
raise ValueError(f"`mic_array` must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def simulate_rir_ism(
    room: torch.Tensor,
    source: torch.Tensor,
    mic_array: torch.Tensor,
    max_order: int,
    absorption: Union[float, torch.Tensor],
    output_length: Optional[int] = None,
    delay_filter_length: int = 81,
    center_frequency: Optional[torch.Tensor] = None,
    sound_speed: float = 343.0,
    sample_rate: float = 16000.0,
) -> Tensor:
    r"""Compute Room Impulse Response (RIR) based on the *image source method* :cite:`allen1979image`.
    The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.

    .. devices:: CPU

    .. properties:: TorchScript

    Args:
        room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
            three dimensions of the room.
        source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
        mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
        max_order (int): The maximum number of reflections of the source.
        absorption (float or torch.Tensor): The *absorption* :cite:`wiki:Absorption_(acoustics)`
            coefficients of wall materials for sound energy.
            If the dtype is ``float``, the absorption coefficient is identical for all walls and
            all frequencies.
            If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, where the values represent
            absorption coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``,
            and ``"ceiling"``, respectively.
            If ``absorption`` is a 2D Tensor, the shape must be `(7, 6)`, where 7 represents the number of octave bands.
        output_length (int or None, optional): The output length of simulated RIR signal. If ``None``,
            the length is defined as

            .. math::
                \frac{\text{max\_d} \cdot \text{sample\_rate}}{\text{sound\_speed}} + \text{delay\_filter\_length}

            where ``max_d`` is the maximum distance between image sources and microphones.
        delay_filter_length (int, optional): The filter length for computing sinc function. (Default: ``81``)
        center_frequency (torch.Tensor, optional): The center frequencies of octave bands for multi-band walls.
            Only used when ``absorption`` is a 2D Tensor.
        sound_speed (float, optional): The speed of sound. (Default: ``343.0``)
        sample_rate (float, optional): The sample rate of the generated room impulse response signal.
            (Default: ``16000.0``)

    Returns:
        (torch.Tensor): The simulated room impulse response waveform. Tensor with dimensions
            `(channel, rir_length)`.

    Note:
        If ``absorption`` is a 2D Tensor and ``center_frequency`` is set to ``None``, the center frequencies
        of octave bands are fixed to ``[125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0]``.
        Users need to tune the values of ``absorption`` to the corresponding frequencies.
    """
    # Check shapes, then normalize `absorption` to a (num_band, 6) tensor.
    _validate_inputs(room, source, mic_array)
    absorption = _adjust_coeff(absorption, "absorption")
    # Mirror the source across the walls up to `max_order` reflections;
    # `att` carries the per-band wall attenuation of each image source.
    img_location, att = _compute_image_sources(room, source, max_order, absorption)

    # compute distances between image sources and microphones
    vec = img_location[:, None, :] - mic_array[None, :, :]
    dist = torch.linalg.norm(vec, dim=-1)  # (image_source, channel)

    # Wall attenuation combined with 1/distance spherical spreading loss.
    img_src_att = att[..., None] / dist[None, ...]  # (band, image_source, channel)

    # separate delays in integer / frac part
    delay = dist * sample_rate / sound_speed  # distance to delay in samples
    delay_i = torch.ceil(delay)  # integer part

    # Compute the short IRs corresponding to each image source: a windowed-sinc
    # fractional-delay filter scaled by the source's attenuation.
    irs = img_src_att[..., None] * _frac_delay(delay, delay_i, delay_filter_length)[None, ...]

    rir_length = int(delay_i.max() + irs.shape[-1])
    # Custom C++ op; presumably it places each short IR at its integer delay and
    # sums them into a buffer of `rir_length` samples — confirm against the op's source.
    rir = torch.ops.torchaudio._simulate_rir(irs, delay_i.type(torch.int32), rir_length)

    # multi-band processing: band-filter each band's RIR before summing.
    if absorption.shape[0] > 1:
        if center_frequency is None:
            # Default octave-band centers; must match the layout of `absorption` rows.
            center = torch.tensor(
                [125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0], dtype=room.dtype, device=room.device
            )
        else:
            center = center_frequency
        # n_fft is set to 512 by default.
        filters = torch.ops.torchaudio._make_rir_filter(center, sample_rate, n_fft=512)
        rir = torchaudio.functional.fftconvolve(rir, filters.unsqueeze(1).repeat(1, rir.shape[1], 1), mode="same")

    # sum up rir signals of all image sources into one waveform.
    rir = rir.sum(0)

    # Pad with zeros or truncate so the output has exactly `output_length` samples.
    if output_length is not None:
        if output_length > rir.shape[-1]:
            rir = torch.nn.functional.pad(rir, (0, output_length - rir.shape[-1]), "constant", 0.0)
        else:
            rir = rir[..., :output_length]

    return rir
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def ray_tracing(
    room: torch.Tensor,
    source: torch.Tensor,
    mic_array: torch.Tensor,
    num_rays: int,
    absorption: Union[float, torch.Tensor] = 0.0,
    scattering: Union[float, torch.Tensor] = 0.0,
    mic_radius: float = 0.5,
    sound_speed: float = 343.0,
    energy_thres: float = 1e-7,
    time_thres: float = 10.0,
    hist_bin_size: float = 0.004,
) -> torch.Tensor:
    r"""Compute energy histogram via ray tracing.

    The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.

    ``num_rays`` rays are casted uniformly in all directions from the source;
    when a ray intersects a wall, it is reflected and part of its energy is absorbed.
    It is also scattered (sent directly to the microphone(s)) according to the
    ``scattering`` coefficient.
    When a ray is close to the microphone, its current energy is recorded in the
    output histogram for that given time slot.

    .. devices:: CPU

    .. properties:: TorchScript

    Args:
        room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
            three dimensions of the room.
        source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
        mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
        num_rays (int): The number of rays casted from the source.
        absorption (float or torch.Tensor, optional): The absorption coefficients of wall materials.
            (Default: ``0.0``).
            If the type is ``float``, the absorption coefficient is identical to all walls and
            all frequencies.
            If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, representing absorption
            coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and
            ``"ceiling"``, respectively.
            If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 6)`.
            ``num_bands`` is the number of frequency bands (usually 7).
        scattering (float or torch.Tensor, optional): The scattering coefficients of wall materials.
            (Default: ``0.0``) The shape and type of this parameter is the same as for ``absorption``.
        mic_radius (float, optional): The radius of the microphone in meters. (Default: 0.5)
        sound_speed (float, optional): The speed of sound in meters per second. (Default: ``343.0``)
        energy_thres (float, optional): The energy level below which we stop tracing a ray. (Default: ``1e-7``)
            The initial energy of each ray is ``2 / num_rays``.
        time_thres (float, optional): The maximal duration for which rays are traced. (Unit: seconds) (Default: 10.0)
        hist_bin_size (float, optional): The size of each bin in the output histogram. (Unit: seconds) (Default: 0.004)

    Returns:
        (torch.Tensor): The 3D histogram(s) where the energy of the traced ray is recorded.
            Each bin corresponds to a given time slot.
            The shape is `(channel, num_bands, num_bins)`, where
            ``num_bins = ceil(time_thres / hist_bin_size)``.
            If both ``absorption`` and ``scattering`` are floats, then ``num_bands == 1``.
    """
    # The histogram time axis must be able to hold at least one bin.
    if time_thres < hist_bin_size:
        raise ValueError(
            "`time_thres` must be greater than `hist_bin_size`. "
            f"Found: hist_bin_size={hist_bin_size}, time_thres={time_thres}."
        )

    # The native op requires a single dtype across all geometry tensors.
    if not (room.dtype == source.dtype == mic_array.dtype):
        raise ValueError(
            "dtype of `room`, `source` and `mic_array` must match. "
            f"Found: `room` ({room.dtype}), `source` ({source.dtype}) and "
            f"`mic_array` ({mic_array.dtype})"
        )

    _validate_inputs(room, source, mic_array)

    # Normalize scalar/1D/2D coefficient inputs to `(num_bands, 6)` tensors.
    absorption = _adjust_coeff(absorption, "absorption").to(room.dtype)
    scattering = _adjust_coeff(scattering, "scattering").to(room.dtype)

    # Bring both coefficient tensors to a common number of frequency bands.
    # The two conditions are mutually exclusive, so `elif` is safe here.
    if absorption.shape[0] == 1 and scattering.shape[0] > 1:
        absorption = absorption.expand(scattering.shape)
    elif scattering.shape[0] == 1 and absorption.shape[0] > 1:
        scattering = scattering.expand(absorption.shape)
    if absorption.shape != scattering.shape:
        raise ValueError(
            "`absorption` and `scattering` must be broadcastable to the same number of bands and walls. "
            f"Inferred shapes absorption={absorption.shape} and scattering={scattering.shape}"
        )

    # All heavy lifting happens in the C++ extension op.
    return torch.ops.torchaudio.ray_tracing(
        room,
        source,
        mic_array,
        num_rays,
        absorption,
        scattering,
        mic_radius,
        sound_speed,
        energy_thres,
        time_thres,
        hist_bin_size,
    )
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/functional.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import warnings
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torchaudio.functional.functional import _create_triangular_filterbank
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _hz_to_bark(freqs: float, bark_scale: str = "traunmuller") -> float:
|
| 10 |
+
r"""Convert Hz to Barks.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
freqs (float): Frequencies in Hz
|
| 14 |
+
bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``)
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
barks (float): Frequency in Barks
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
if bark_scale not in ["schroeder", "traunmuller", "wang"]:
|
| 21 |
+
raise ValueError('bark_scale should be one of "schroeder", "traunmuller" or "wang".')
|
| 22 |
+
|
| 23 |
+
if bark_scale == "wang":
|
| 24 |
+
return 6.0 * math.asinh(freqs / 600.0)
|
| 25 |
+
elif bark_scale == "schroeder":
|
| 26 |
+
return 7.0 * math.asinh(freqs / 650.0)
|
| 27 |
+
# Traunmuller Bark scale
|
| 28 |
+
barks = ((26.81 * freqs) / (1960.0 + freqs)) - 0.53
|
| 29 |
+
# Bark value correction
|
| 30 |
+
if barks < 2:
|
| 31 |
+
barks += 0.15 * (2 - barks)
|
| 32 |
+
elif barks > 20.1:
|
| 33 |
+
barks += 0.22 * (barks - 20.1)
|
| 34 |
+
|
| 35 |
+
return barks
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _bark_to_hz(barks: torch.Tensor, bark_scale: str = "traunmuller") -> torch.Tensor:
|
| 39 |
+
"""Convert bark bin numbers to frequencies.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
barks (torch.Tensor): Bark frequencies
|
| 43 |
+
bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``)
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
freqs (torch.Tensor): Barks converted in Hz
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
if bark_scale not in ["schroeder", "traunmuller", "wang"]:
|
| 50 |
+
raise ValueError('bark_scale should be one of "traunmuller", "schroeder" or "wang".')
|
| 51 |
+
|
| 52 |
+
if bark_scale == "wang":
|
| 53 |
+
return 600.0 * torch.sinh(barks / 6.0)
|
| 54 |
+
elif bark_scale == "schroeder":
|
| 55 |
+
return 650.0 * torch.sinh(barks / 7.0)
|
| 56 |
+
# Bark value correction
|
| 57 |
+
if any(barks < 2):
|
| 58 |
+
idx = barks < 2
|
| 59 |
+
barks[idx] = (barks[idx] - 0.3) / 0.85
|
| 60 |
+
elif any(barks > 20.1):
|
| 61 |
+
idx = barks > 20.1
|
| 62 |
+
barks[idx] = (barks[idx] + 4.422) / 1.22
|
| 63 |
+
|
| 64 |
+
# Traunmuller Bark scale
|
| 65 |
+
freqs = 1960 * ((barks + 0.53) / (26.28 - barks))
|
| 66 |
+
|
| 67 |
+
return freqs
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _hz_to_octs(freqs, tuning=0.0, bins_per_octave=12):
|
| 71 |
+
a440 = 440.0 * 2.0 ** (tuning / bins_per_octave)
|
| 72 |
+
return torch.log2(freqs / (a440 / 16))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def barkscale_fbanks(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_barks: int,
    sample_rate: int,
    bark_scale: str = "traunmuller",
) -> torch.Tensor:
    r"""Create a frequency bin conversion matrix.

    .. devices:: CPU

    .. properties:: TorchScript

    .. image:: https://download.pytorch.org/torchaudio/doc-assets/bark_fbanks.png
        :alt: Visualization of generated filter bank

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_barks (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``)

    Returns:
        torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_barks``)
        meaning number of frequencies to highlight/apply to x the number of filterbanks.
        Each column is a filterbank so that assuming there is a matrix A of
        size (..., ``n_freqs``), the applied result would be
        ``A * barkscale_fbanks(A.size(-1), ...)``.
    """
    # Frequency bin centers, linearly spaced from 0 Hz up to Nyquist.
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # Filter edge points: uniformly spaced on the Bark scale between
    # f_min and f_max, then mapped back to Hz.
    bark_min = _hz_to_bark(f_min, bark_scale=bark_scale)
    bark_max = _hz_to_bark(f_max, bark_scale=bark_scale)
    bark_pts = torch.linspace(bark_min, bark_max, n_barks + 2)
    freq_pts = _bark_to_hz(bark_pts, bark_scale=bark_scale)

    # Build one triangular filter per adjacent triple of edge points.
    fb = _create_triangular_filterbank(all_freqs, freq_pts)

    # A filter that never rises above zero means the frequency grid is too
    # coarse for the requested number of bands — warn rather than fail.
    if (fb.max(dim=0).values == 0.0).any():
        warnings.warn(
            "At least one bark filterbank has all zero values. "
            f"The value for `n_barks` ({n_barks}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def chroma_filterbank(
    sample_rate: int,
    n_freqs: int,
    n_chroma: int,
    *,
    tuning: float = 0.0,
    ctroct: float = 5.0,
    octwidth: Optional[float] = 2.0,
    norm: int = 2,
    base_c: bool = True,
):
    """Create a frequency-to-chroma conversion matrix. Implementation adapted from librosa.

    Args:
        sample_rate (int): Sample rate.
        n_freqs (int): Number of input frequencies.
        n_chroma (int): Number of output chroma.
        tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0)
        ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves. (Default: 5.0)
        octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by, in octaves.
            If ``None``, then disable weighting altogether. (Default: 2.0)
        norm (int, optional): order of norm to normalize filter bank by. (Default: 2)
        base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A. (Default: True)

    Returns:
        torch.Tensor: Chroma filter bank, with shape `(n_freqs, n_chroma)`.
    """
    # Skip redundant upper half of frequency range (drop the DC bin).
    freqs = torch.linspace(0, sample_rate // 2, n_freqs)[1:]
    bins = n_chroma * _hz_to_octs(freqs, bins_per_octave=n_chroma, tuning=tuning)
    # Prepend a synthetic bin 1.5 octaves below the first to anchor the low end.
    bins = torch.cat((torch.tensor([bins[0] - 1.5 * n_chroma]), bins))
    # Per-bin widths, floored at one chroma bin; the last bin reuses width 1.
    widths = torch.cat(
        (
            torch.maximum(bins[1:] - bins[:-1], torch.tensor(1.0)),
            torch.tensor([1]),
        )
    )

    # Signed distance from each frequency bin to each chroma class: (n_freqs, n_chroma).
    distance = bins.unsqueeze(1) - torch.arange(0, n_chroma)

    half_chroma = round(n_chroma / 2)

    # Wrap distances into [-n_chroma/2, n_chroma/2 - 1].
    distance = torch.remainder(distance + half_chroma, n_chroma) - half_chroma

    # Gaussian bump per (frequency, chroma) pair; widths broadcast over columns.
    fb = torch.exp(-0.5 * (2 * distance / widths.unsqueeze(1)) ** 2)
    fb = torch.nn.functional.normalize(fb, p=norm, dim=1)

    # Optional octave-dominance weighting centered at `ctroct` octaves.
    if octwidth is not None:
        weight = torch.exp(-0.5 * (((bins.unsqueeze(1) / n_chroma - ctroct) / octwidth) ** 2))
        fb = fb * weight

    # Rotate so the first chroma bin corresponds to C rather than A.
    if base_c:
        fb = torch.roll(fb, -3 * (n_chroma // 12), dims=1)

    return fb
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._conformer_wav2vec2 import (
|
| 2 |
+
conformer_wav2vec2_base,
|
| 3 |
+
conformer_wav2vec2_model,
|
| 4 |
+
conformer_wav2vec2_pretrain_base,
|
| 5 |
+
conformer_wav2vec2_pretrain_large,
|
| 6 |
+
conformer_wav2vec2_pretrain_model,
|
| 7 |
+
ConformerWav2Vec2PretrainModel,
|
| 8 |
+
)
|
| 9 |
+
from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model
|
| 10 |
+
from .conv_emformer import ConvEmformer
|
| 11 |
+
from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder
|
| 12 |
+
from .rnnt import conformer_rnnt_base, conformer_rnnt_biasing, conformer_rnnt_biasing_base, conformer_rnnt_model
|
| 13 |
+
from .rnnt_decoder import Hypothesis, RNNTBeamSearchBiasing
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"conformer_rnnt_base",
|
| 17 |
+
"conformer_rnnt_model",
|
| 18 |
+
"conformer_rnnt_biasing",
|
| 19 |
+
"conformer_rnnt_biasing_base",
|
| 20 |
+
"ConvEmformer",
|
| 21 |
+
"conformer_wav2vec2_model",
|
| 22 |
+
"conformer_wav2vec2_base",
|
| 23 |
+
"conformer_wav2vec2_pretrain_model",
|
| 24 |
+
"conformer_wav2vec2_pretrain_base",
|
| 25 |
+
"conformer_wav2vec2_pretrain_large",
|
| 26 |
+
"ConformerWav2Vec2PretrainModel",
|
| 27 |
+
"emformer_hubert_base",
|
| 28 |
+
"emformer_hubert_model",
|
| 29 |
+
"Hypothesis",
|
| 30 |
+
"RNNTBeamSearchBiasing",
|
| 31 |
+
"HiFiGANVocoder",
|
| 32 |
+
"hifigan_vocoder_v1",
|
| 33 |
+
"hifigan_vocoder_v2",
|
| 34 |
+
"hifigan_vocoder_v3",
|
| 35 |
+
"hifigan_vocoder",
|
| 36 |
+
]
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.35 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_conformer_wav2vec2.cpython-311.pyc
ADDED
|
Binary file (33.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_emformer_hubert.cpython-311.pyc
ADDED
|
Binary file (16.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/conv_emformer.cpython-311.pyc
ADDED
|
Binary file (30.2 kB). View file
|
|
|