koichi12 commited on
Commit
91f1872
·
verified ·
1 Parent(s): c060ea1

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/_no_backend.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/common.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/no_backend.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/sox_io_backend.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/torchaudio/backend/_sox_io_backend.py +294 -0
  7. .venv/lib/python3.11/site-packages/torchaudio/backend/no_backend.py +14 -0
  8. .venv/lib/python3.11/site-packages/torchaudio/compliance/__init__.py +5 -0
  9. .venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/__init__.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/kaldi.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/torchaudio/compliance/kaldi.py +813 -0
  12. .venv/lib/python3.11/site-packages/torchaudio/models/__init__.py +85 -0
  13. .venv/lib/python3.11/site-packages/torchaudio/models/_hdemucs.py +1008 -0
  14. .venv/lib/python3.11/site-packages/torchaudio/models/conformer.py +293 -0
  15. .venv/lib/python3.11/site-packages/torchaudio/models/conv_tasnet.py +330 -0
  16. .venv/lib/python3.11/site-packages/torchaudio/models/decoder/__init__.py +46 -0
  17. .venv/lib/python3.11/site-packages/torchaudio/models/decoder/_ctc_decoder.py +568 -0
  18. .venv/lib/python3.11/site-packages/torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  19. .venv/lib/python3.11/site-packages/torchaudio/models/deepspeech.py +84 -0
  20. .venv/lib/python3.11/site-packages/torchaudio/models/emformer.py +884 -0
  21. .venv/lib/python3.11/site-packages/torchaudio/models/rnnt.py +816 -0
  22. .venv/lib/python3.11/site-packages/torchaudio/models/rnnt_decoder.py +339 -0
  23. .venv/lib/python3.11/site-packages/torchaudio/models/tacotron2.py +1046 -0
  24. .venv/lib/python3.11/site-packages/torchaudio/models/wav2letter.py +72 -0
  25. .venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/model.py +1579 -0
  26. .venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  27. .venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_fairseq.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_huggingface.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  31. .venv/lib/python3.11/site-packages/torchaudio/models/wavernn.py +409 -0
  32. .venv/lib/python3.11/site-packages/torchaudio/prototype/__init__.py +0 -0
  33. .venv/lib/python3.11/site-packages/torchaudio/prototype/__pycache__/__init__.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__init__.py +4 -0
  35. .venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/musan.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/musan.py +67 -0
  38. .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__init__.py +26 -0
  39. .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/__init__.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_dsp.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_rir.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/functional.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_dsp.py +433 -0
  44. .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_rir.py +379 -0
  45. .venv/lib/python3.11/site-packages/torchaudio/prototype/functional/functional.py +190 -0
  46. .venv/lib/python3.11/site-packages/torchaudio/prototype/models/__init__.py +36 -0
  47. .venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/__init__.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_conformer_wav2vec2.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_emformer_hubert.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/conv_emformer.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (347 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/_no_backend.cpython-311.pyc ADDED
Binary file (1.58 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/common.cpython-311.pyc ADDED
Binary file (846 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/no_backend.cpython-311.pyc ADDED
Binary file (861 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/sox_io_backend.cpython-311.pyc ADDED
Binary file (869 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchaudio/backend/_sox_io_backend.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ import torchaudio
6
+ from torchaudio import AudioMetaData
7
+
8
+ sox_ext = torchaudio._extension.lazy_import_sox_ext()
9
+
10
+
11
def info(
    filepath: str,
    format: Optional[str] = None,
) -> AudioMetaData:
    """Fetch signal metadata of an audio file via libsox.

    Args:
        filepath (str):
            Source of audio data. File-like objects are not supported by this backend.

        format (str or None, optional):
            Force the given format instead of relying on detection.
            Useful when libsox cannot infer the format from the header or extension.

    Returns:
        AudioMetaData: Metadata of the given audio.

    Raises:
        RuntimeError: If ``filepath`` is a file-like object.
    """
    if not torch.jit.is_scripting():
        # Reject file-like objects up front; this backend only accepts paths.
        if hasattr(filepath, "read"):
            raise RuntimeError("sox_io backend does not support file-like object.")
        filepath = os.fspath(filepath)
    return AudioMetaData(*sox_ext.get_info(filepath, format))
35
+
36
+
37
def load(
    filepath: str,
    frame_offset: int = 0,
    num_frames: int = -1,
    normalize: bool = True,
    channels_first: bool = True,
    format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
    """Load audio data from a file via libsox.

    Handles every codec the underlying libsox build supports (WAV/AMB PCM
    variants, MP3, FLAC, OGG/VORBIS, OPUS, SPHERE, AMR-NB, ...). Formats that
    libsox does not handle natively require torchaudio to be linked against
    libsox and the matching codec libraries (``libmad``, ``libmp3lame``, etc.).

    With the defaults (``normalize=True``, ``channels_first=True``) the result
    is a ``float32`` Tensor of shape `[channel, time]`.

    Args:
        filepath (path-like object): Source of audio data.
            File-like objects are not supported by this backend.
        frame_offset (int): Number of frames to skip before reading.
        num_frames (int, optional): Maximum number of frames to read.
            ``-1`` reads everything starting from ``frame_offset``. Fewer
            frames may be returned when the file does not contain enough.
        normalize (bool, optional): When ``True``, convert the native sample
            type to ``float32``. Note this is sample-type conversion only,
            not volume normalization. When ``False`` and the input is an
            integer WAV, the native integer dtype is kept (``int32`` for
            32-bit and 24-bit PCM, ``int16`` for 16-bit, ``uint8`` for 8-bit).
            It has no effect on 32-bit floating-point WAV or on other formats
            such as ``flac`` and ``mp3``, which always decode to ``float32``.
            Default: ``True``.
        channels_first (bool, optional): When ``True`` the returned Tensor is
            `[channel, time]`, otherwise `[time, channel]`.
        format (str or None, optional): Override format detection; useful when
            libsox cannot infer the format from the header or extension.

    Returns:
        (torch.Tensor, int): The decoded Tensor and its sample rate.
        Integer dtype only for integer WAV with ``normalize=False``,
        otherwise ``float32``.

    Raises:
        RuntimeError: If ``filepath`` is a file-like object.
    """
    if not torch.jit.is_scripting():
        # Reject file-like objects up front; this backend only accepts paths.
        if hasattr(filepath, "read"):
            raise RuntimeError("sox_io backend does not support file-like object.")
        filepath = os.fspath(filepath)
    return sox_ext.load_audio_file(filepath, frame_offset, num_frames, normalize, channels_first, format)
127
+
128
+
129
def save(
    filepath: str,
    src: torch.Tensor,
    sample_rate: int,
    channels_first: bool = True,
    compression: Optional[float] = None,
    format: Optional[str] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
):
    """Save audio data to a file via libsox.

    Supported containers include ``wav``/``amb`` (float32 / int32 / int24 /
    int16 / uint8 PCM, mu-law, a-law), ``mp3``, ``flac``, ``ogg``/``vorbis``,
    ``sph``, ``amr-nb``, ``gsm`` and ``htk``. Saving to formats libsox does not
    handle natively requires torchaudio to be linked against libsox and the
    matching codec libraries (``libmad``, ``libmp3lame``, etc.).

    Args:
        filepath (path-like object): Path to save the file to.
            File-like objects are not supported by this backend.
        src (torch.Tensor): Audio data to save; must be a 2D tensor.
        sample_rate (int): Sampling rate of ``src``.
        channels_first (bool, optional): If ``True``, ``src`` is interpreted
            as `[channel, time]`, otherwise `[time, channel]`.
        compression (float or None, optional): Compression level for formats
            other than WAV; corresponds to the ``-C`` option of the ``sox``
            command. ``mp3``: bitrate in kbps (e.g. ``128.2``) or negative VBR
            quality (default ``-4.5``); ``flac``: integer ``0``–``8`` (default
            ``8``, highest compression); ``ogg``/``vorbis``: ``-1``–``10``
            (default ``3``). See http://sox.sourceforge.net/soxformat.html.
        format (str or None, optional): Override the audio format normally
            inferred from the file extension. Valid values include ``"wav"``,
            ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``, ``"amb"``,
            ``"flac"``, ``"sph"``, ``"gsm"`` and ``"htk"``.
        encoding (str or None, optional): Sample encoding for formats that
            support it (``"wav"``, ``"amb"``, ``"sph"``): one of ``"PCM_S"``,
            ``"PCM_U"``, ``"PCM_F"``, ``"ULAW"``, ``"ALAW"``. When omitted,
            the default is derived from ``format``, ``bits_per_sample`` and,
            for WAV/AMB, the dtype of ``src`` (``uint8`` -> ``"PCM_U"``,
            ``int16``/``int32`` -> ``"PCM_S"``, ``float32`` -> ``"PCM_F"``).
        bits_per_sample (int or None, optional): Bit depth for ``"wav"``,
            ``"flac"``, ``"sph"`` and ``"amb"``; one of ``8``, ``16``, ``32``
            or ``64``. When omitted, the default is derived from ``format``,
            ``encoding`` and, for WAV/AMB, the dtype of ``src``
            (``uint8`` -> 8, ``int16`` -> 16, ``int32``/``float32`` -> 32;
            ``flac`` defaults to 24).

    Raises:
        RuntimeError: If ``filepath`` is a file-like object.
    """
    if not torch.jit.is_scripting():
        # Reject file-like objects up front; this backend only accepts paths.
        if hasattr(filepath, "write"):
            raise RuntimeError("sox_io backend does not handle file-like object.")
        filepath = os.fspath(filepath)
    sox_ext.save_audio_file(filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample)
.venv/lib/python3.11/site-packages/torchaudio/backend/no_backend.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def __getattr__(name: str):
2
+ import warnings
3
+
4
+ warnings.warn(
5
+ "Torchaudio's I/O functions now support par-call bakcend dispatch. "
6
+ "Importing backend implementation directly is no longer guaranteed to work. "
7
+ "Please use `backend` keyword with load/save/info function, instead of "
8
+ "calling the udnerlying implementation directly.",
9
+ stacklevel=2,
10
+ )
11
+
12
+ from . import _no_backend
13
+
14
+ return getattr(_no_backend, name)
.venv/lib/python3.11/site-packages/torchaudio/compliance/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from . import kaldi
2
+
3
+ __all__ = [
4
+ "kaldi",
5
+ ]
.venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (268 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchaudio/compliance/__pycache__/kaldi.cpython-311.pyc ADDED
Binary file (37.6 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/compliance/kaldi.py ADDED
@@ -0,0 +1,813 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torchaudio
6
+ from torch import Tensor
7
+
8
+ __all__ = [
9
+ "get_mel_banks",
10
+ "inverse_mel_scale",
11
+ "inverse_mel_scale_scalar",
12
+ "mel_scale",
13
+ "mel_scale_scalar",
14
+ "spectrogram",
15
+ "fbank",
16
+ "mfcc",
17
+ "vtln_warp_freq",
18
+ "vtln_warp_mel_freq",
19
+ ]
20
+
21
+ # numeric_limits<float>::epsilon() 1.1920928955078125e-07
22
+ EPSILON = torch.tensor(torch.finfo(torch.float).eps)
23
+ # 1 milliseconds = 0.001 seconds
24
+ MILLISECONDS_TO_SECONDS = 0.001
25
+
26
+ # window types
27
+ HAMMING = "hamming"
28
+ HANNING = "hanning"
29
+ POVEY = "povey"
30
+ RECTANGULAR = "rectangular"
31
+ BLACKMAN = "blackman"
32
+ WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
33
+
34
+
35
def _get_epsilon(device, dtype):
    # Cast the module-level float32 machine epsilon onto the requested
    # device and dtype.
    return EPSILON.to(dtype=dtype, device=device)
37
+
38
+
39
+ def _next_power_of_2(x: int) -> int:
40
+ r"""Returns the smallest power of 2 that is greater than x"""
41
+ return 1 if x == 0 else 2 ** (x - 1).bit_length()
42
+
43
+
44
+ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
45
+ r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
46
+ representing how the window is shifted along the waveform. Each row is a frame.
47
+
48
+ Args:
49
+ waveform (Tensor): Tensor of size ``num_samples``
50
+ window_size (int): Frame length
51
+ window_shift (int): Frame shift
52
+ snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
53
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
54
+ depends only on the frame_shift, and we reflect the data at the ends.
55
+
56
+ Returns:
57
+ Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
58
+ """
59
+ assert waveform.dim() == 1
60
+ num_samples = waveform.size(0)
61
+ strides = (window_shift * waveform.stride(0), waveform.stride(0))
62
+
63
+ if snip_edges:
64
+ if num_samples < window_size:
65
+ return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
66
+ else:
67
+ m = 1 + (num_samples - window_size) // window_shift
68
+ else:
69
+ reversed_waveform = torch.flip(waveform, [0])
70
+ m = (num_samples + (window_shift // 2)) // window_shift
71
+ pad = window_size // 2 - window_shift // 2
72
+ pad_right = reversed_waveform
73
+ if pad > 0:
74
+ # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
75
+ # but we want [2, 1, 0, 0, 1, 2]
76
+ pad_left = reversed_waveform[-pad:]
77
+ waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
78
+ else:
79
+ # pad is negative so we want to trim the waveform at the front
80
+ waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
81
+
82
+ sizes = (m, window_size)
83
+ return waveform.as_strided(sizes, strides)
84
+
85
+
86
def _feature_window_function(
    window_type: str,
    window_size: int,
    blackman_coeff: float,
    device: torch.device,
    dtype: int,
) -> Tensor:
    r"""Return a 1D window tensor of the given type and length.

    Args:
        window_type (str): One of ``WINDOWS``
            (hamming / hanning / povey / rectangular / blackman).
        window_size (int): Number of samples in the window.
        blackman_coeff (float): Coefficient of the generalized Blackman
            window; only used when ``window_type`` is ``"blackman"``.
        device (torch.device): Device of the returned tensor.
        dtype (int): Target dtype of the returned tensor.
            NOTE(review): annotated ``int`` rather than ``torch.dtype`` —
            presumably for TorchScript compatibility; confirm before changing.

    Returns:
        Tensor: Window of shape ``(window_size,)``.

    Raises:
        Exception: If ``window_type`` is not recognized.
    """
    if window_type == HANNING:
        return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
    elif window_type == HAMMING:
        return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
    elif window_type == POVEY:
        # Hanning raised to 0.85: similar shape but reaches zero at the edges.
        return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
    elif window_type == RECTANGULAR:
        return torch.ones(window_size, device=device, dtype=dtype)
    elif window_type == BLACKMAN:
        angular_step = 2 * math.pi / (window_size - 1)
        indices = torch.arange(window_size, device=device, dtype=dtype)
        # torch.blackman_window uses fixed coefficients, so build it manually.
        window = (
            blackman_coeff
            - 0.5 * torch.cos(angular_step * indices)
            + (0.5 - blackman_coeff) * torch.cos(2 * angular_step * indices)
        )
        return window.to(device=device, dtype=dtype)
    else:
        raise Exception("Invalid window type " + window_type)
114
+
115
+
116
+ def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
117
+ r"""Returns the log energy of size (m) for a strided_input (m,*)"""
118
+ device, dtype = strided_input.device, strided_input.dtype
119
+ log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
120
+ if energy_floor == 0.0:
121
+ return log_energy
122
+ return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
123
+
124
+
125
def _get_waveform_and_window_properties(
    waveform: Tensor,
    channel: int,
    sample_frequency: float,
    frame_shift: float,
    frame_length: float,
    round_to_power_of_two: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, int, int, int]:
    r"""Select one channel and derive window geometry (in samples) from ms settings.

    Args:
        waveform (Tensor): Multi-channel signal of shape (c, n).
        channel (int): Channel index; negative values select channel 0.
        sample_frequency (float): Sample rate in Hz; must be positive.
        frame_shift (float): Frame hop in milliseconds.
        frame_length (float): Frame length in milliseconds.
        round_to_power_of_two (bool): Whether the padded window size is rounded
            up to the next power of two.
        preemphasis_coefficient (float): Validated to lie in [0, 1].

    Returns:
        (Tensor, int, int, int): The selected 1D waveform, window shift,
        window size, and padded window size, all in samples.
    """
    channel = max(channel, 0)
    assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
    waveform = waveform[channel, :]  # size (n)
    # Convert millisecond settings to sample counts.
    window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
    window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size

    # Validate the derived geometry before any heavy work.
    assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
        window_size, len(waveform)
    )
    assert 0 < window_shift, "`window_shift` must be greater than 0"
    assert padded_window_size % 2 == 0, (
        "the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
    )
    assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
    assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
    return waveform, window_shift, window_size, padded_window_size
152
+
153
+
154
def _get_window(
    waveform: Tensor,
    padded_window_size: int,
    window_size: int,
    window_shift: int,
    window_type: str,
    blackman_coeff: float,
    snip_edges: bool,
    raw_energy: bool,
    energy_floor: float,
    dither: float,
    remove_dc_offset: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
    r"""Frame a waveform and apply dither, DC removal, preemphasis and windowing.

    Returns:
        (Tensor, Tensor): Frames of shape (m, ``padded_window_size``) and the
        per-frame log energy of shape (m,). With ``raw_energy=True`` the energy
        is measured before preemphasis/windowing, otherwise after.
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    # Frames of shape (m, window_size).
    frames = _get_strided(waveform, window_size, window_shift, snip_edges)

    if dither != 0.0:
        # Additive Gaussian dither scaled by the dither constant.
        frames = frames + torch.randn(frames.shape, device=device, dtype=dtype) * dither

    if remove_dc_offset:
        # Zero-center each frame by subtracting its own mean.
        frames = frames - torch.mean(frames, dim=1).unsqueeze(1)

    if raw_energy:
        # Energy of the raw frames, before preemphasis and windowing.
        signal_log_energy = _get_log_energy(frames, epsilon, energy_floor)

    if preemphasis_coefficient != 0.0:
        # frames[i, j] -= coeff * frames[i, max(0, j - 1)]; replicate-pad keeps
        # the first column unchanged relative to itself.
        shifted = torch.nn.functional.pad(frames.unsqueeze(0), (1, 0), mode="replicate").squeeze(0)
        frames = frames - preemphasis_coefficient * shifted[:, :-1]

    # Apply the window function to every frame.
    window = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(0)
    frames = frames * window

    # Zero-pad columns up to the (power-of-two) FFT size, if needed.
    if padded_window_size != window_size:
        padding_right = padded_window_size - window_size
        frames = torch.nn.functional.pad(frames.unsqueeze(0), (0, padding_right), mode="constant", value=0).squeeze(0)

    if not raw_energy:
        # Energy measured after preemphasis and windowing.
        signal_log_energy = _get_log_energy(frames, epsilon, energy_floor)

    return frames, signal_log_energy
218
+
219
+
220
+ def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
221
+ # subtracts the column mean of the tensor size (m, n) if subtract_mean=True
222
+ # it returns size (m, n)
223
+ if subtract_mean:
224
+ col_means = torch.mean(tensor, dim=0).unsqueeze(0)
225
+ tensor = tensor - col_means
226
+ return tensor
227
+
228
+
229
def spectrogram(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    min_duration: float = 0.0,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a spectrogram from a raw audio signal, matching the input/output
    of Kaldi's ``compute-spectrogram-feats``.

    Args:
        waveform (Tensor): Audio of shape (c, n) where c is in the range [0, 2).
        blackman_coeff (float, optional): Coefficient for the generalized
            Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract
            (-1 -> expect mono, 0 -> left, 1 -> right). (Default: ``-1``)
        dither (float, optional): Dithering constant; 0.0 disables dithering,
            in which case consider setting ``energy_floor`` to e.g. 1.0 or 0.1.
            (Default: ``0.0``)
        energy_floor (float, optional): Absolute floor applied to the zeroth
            output component (total frame energy); other elements are floored
            at float epsilon. (Default: ``1.0``)
        frame_length (float, optional): Frame length in ms. (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in ms. (Default: ``10.0``)
        min_duration (float, optional): Minimum segment duration in seconds;
            shorter signals yield an empty tensor. (Default: ``0.0``)
        preemphasis_coefficient (float, optional): Signal preemphasis
            coefficient. (Default: ``0.97``)
        raw_energy (bool, optional): Compute energy before preemphasis and
            windowing. (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract each frame's mean.
            (Default: ``True``)
        round_to_power_of_two (bool, optional): Zero-pad the FFT input so its
            size is a power of two. (Default: ``True``)
        sample_frequency (float, optional): Sample rate of the waveform in Hz.
            (Default: ``16000.0``)
        snip_edges (bool, optional): If True, emit only frames fully inside the
            signal; if False, mirror the signal at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract each feature column's mean
            (CMS); not the recommended way to do it. (Default: ``False``)
        window_type (str, optional): One of
            'hamming'|'hanning'|'povey'|'rectangular'|'blackman'.
            (Default: ``'povey'``)

    Returns:
        Tensor: Spectrogram of shape (m, ``padded_window_size // 2 + 1``),
        identical to Kaldi's output; m is determined by the framing logic.
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # Signal is shorter than the requested minimum duration.
        return torch.empty(0)

    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # Complex spectrum of each frame; shape (m, padded_window_size // 2 + 1).
    fft = torch.fft.rfft(strided_input)

    # Log power spectrum, floored at epsilon; the zeroth bin carries the
    # per-frame log energy instead of the DC component.
    power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log()
    power_spectrum[:, 0] = signal_log_energy

    return _subtract_column_mean(power_spectrum, subtract_mean)
316
+
317
+
318
def inverse_mel_scale_scalar(mel_freq: float) -> float:
    """Convert a scalar value on Kaldi's mel scale back to frequency in Hz."""
    ratio = mel_freq / 1127.0
    return 700.0 * (math.exp(ratio) - 1.0)
320
+
321
+
322
def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
    """Elementwise inverse of Kaldi's mel scale: mel value -> frequency in Hz."""
    scaled = mel_freq / 1127.0
    return 700.0 * (torch.exp(scaled) - 1.0)
324
+
325
+
326
def mel_scale_scalar(freq: float) -> float:
    """Convert a scalar frequency in Hz to Kaldi's mel scale."""
    warped = 1.0 + freq / 700.0
    return 1127.0 * math.log(warped)
328
+
329
+
330
def mel_scale(freq: Tensor) -> Tensor:
    """Elementwise Kaldi mel scale: frequency in Hz -> mel value."""
    return torch.log(1.0 + freq / 700.0) * 1127.0
332
+
333
+
334
def vtln_warp_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq: float,
    high_freq: float,
    vtln_warp_factor: float,
    freq: Tensor,
) -> Tensor:
    r"""Apply Kaldi's piecewise-linear VTLN frequency warp.

    Unlike HTK's warp, this variant never produces empty mel bins. The warp
    ``F(freq)`` is defined on ``[low_freq, high_freq]`` with:

        F(low_freq) == low_freq and F(high_freq) == high_freq,

    and is continuous, piecewise linear, with two inflection points ``l`` and
    ``h``. Between them, ``F(f) = f / vtln_warp_factor``; the inflection
    points are placed so that

        h = vtln_high_cutoff * min(1, vtln_warp_factor)
        l = vtln_low_cutoff * max(1, vtln_warp_factor)

    Frequencies outside ``[low_freq, high_freq]`` pass through unchanged.

    Args:
        vtln_low_cutoff (float): Lower frequency cutoff for VTLN (must exceed ``low_freq``)
        vtln_high_cutoff (float): Upper frequency cutoff for VTLN (must be below ``high_freq``)
        low_freq (float): Lower frequency cutoff in mel computation
        high_freq (float): Upper frequency cutoff in mel computation
        vtln_warp_factor (float): VTLN warp factor
        freq (Tensor): Frequencies in Hz to warp

    Returns:
        Tensor: ``freq`` after VTLN warping
    """
    assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
    assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
    lower_infl = vtln_low_cutoff * max(1.0, vtln_warp_factor)
    upper_infl = vtln_high_cutoff * min(1.0, vtln_warp_factor)
    center_slope = 1.0 / vtln_warp_factor
    warped_lower = center_slope * lower_infl  # F(l)
    warped_upper = center_slope * upper_infl  # F(h)
    assert lower_infl > low_freq and upper_infl < high_freq
    # Slopes of the outer segments of the 3-piece linear function;
    # the center segment's slope is just ``center_slope``.
    left_slope = (warped_lower - low_freq) / (lower_infl - low_freq)
    right_slope = (high_freq - warped_upper) / (high_freq - upper_infl)

    warped = torch.empty_like(freq)

    # Region masks overlap by construction; later assignments overwrite
    # earlier ones, so the write order below is significant.
    outside = torch.lt(freq, low_freq) | torch.gt(freq, high_freq)  # freq < low_freq || freq > high_freq
    below_l = torch.lt(freq, lower_infl)  # freq < l
    below_h = torch.lt(freq, upper_infl)  # freq < h
    at_or_above_h = torch.ge(freq, upper_infl)  # freq >= h

    warped[at_or_above_h] = high_freq + right_slope * (freq[at_or_above_h] - high_freq)
    warped[below_h] = center_slope * freq[below_h]
    warped[below_l] = low_freq + left_slope * (freq[below_l] - low_freq)
    warped[outside] = freq[outside]

    return warped
407
+
408
+
409
def vtln_warp_mel_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq: float,
    high_freq: float,
    vtln_warp_factor: float,
    mel_freq: Tensor,
) -> Tensor:
    r"""Apply the VTLN warp in the mel domain.

    Converts ``mel_freq`` to Hz, applies :func:`vtln_warp_freq`, and converts
    back to the mel scale.

    Args:
        vtln_low_cutoff (float): Lower frequency cutoff for VTLN
        vtln_high_cutoff (float): Upper frequency cutoff for VTLN
        low_freq (float): Lower frequency cutoff in mel computation
        high_freq (float): Upper frequency cutoff in mel computation
        vtln_warp_factor (float): VTLN warp factor
        mel_freq (Tensor): Given frequency in Mel

    Returns:
        Tensor: ``mel_freq`` after vtln warp
    """
    # Fix: ``low_freq`` previously lacked its ``: float`` annotation, unlike
    # every sibling signature in this module; behavior is unchanged.
    return mel_scale(
        vtln_warp_freq(
            vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
        )
    )
434
+
435
+
436
def get_mel_banks(
    num_bins: int,
    window_length_padded: int,
    sample_freq: float,
    low_freq: float,
    high_freq: float,
    vtln_low: float,
    vtln_high: float,
    vtln_warp_factor: float,
) -> Tuple[Tensor, Tensor]:
    """Compute Kaldi-compatible triangular mel filterbanks.

    Args:
        num_bins (int): Number of triangular mel bins (must be greater than 3).
        window_length_padded (int): Padded window size (must be even), i.e. the FFT length.
        sample_freq (float): Waveform sample frequency.
        low_freq (float): Lowest frequency covered by the banks.
        high_freq (float): Highest frequency; if <= 0, interpreted as an offset from the Nyquist frequency.
        vtln_low (float): Low inflection point of the VTLN piecewise-linear warp.
        vtln_high (float): High inflection point; if negative, offset from the Nyquist frequency.
        vtln_warp_factor (float): VTLN warp factor; 1.0 disables warping.

    Returns:
        (Tensor, Tensor): The tuple consists of ``bins`` (which is
        melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
        center frequencies of bins of size (``num_bins``)).
    """
    # NOTE(review): the check is strictly greater than 3 although the message
    # says "at least 3"; kept as-is to preserve existing behavior.
    assert num_bins > 3, "Must have at least 3 mel bins"
    assert window_length_padded % 2 == 0
    # Fix: use integer division — this is a bin count, and passing a float end
    # value to torch.arange below relies on deprecated float-length behavior.
    num_fft_bins = window_length_padded // 2
    nyquist = 0.5 * sample_freq

    if high_freq <= 0.0:
        high_freq += nyquist

    assert (
        (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
    ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)

    # fft-bin width [think of it as Nyquist-freq / half-window-length]
    fft_bin_width = sample_freq / window_length_padded
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)

    # divide by num_bins+1 in next line because of end-effects where the bins
    # spread out to the sides.
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    if vtln_high < 0.0:
        vtln_high += nyquist

    assert vtln_warp_factor == 1.0 or (
        (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
    ), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
        vtln_low, vtln_high, low_freq, high_freq
    )

    bin = torch.arange(num_bins).unsqueeze(1)
    left_mel = mel_low_freq + bin * mel_freq_delta  # size(num_bins, 1)
    center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta  # size(num_bins, 1)
    right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta  # size(num_bins, 1)

    if vtln_warp_factor != 1.0:
        left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
        center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
        right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)

    center_freqs = inverse_mel_scale(center_mel)  # size (num_bins)
    # size(1, num_fft_bins)
    mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)

    # size (num_bins, num_fft_bins)
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)

    if vtln_warp_factor == 1.0:
        # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
        bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
    else:
        # warping can move the order of left_mel, center_mel, right_mel anywhere
        bins = torch.zeros_like(up_slope)
        up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel)  # left_mel < mel <= center_mel
        down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel)  # center_mel < mel < right_mel
        bins[up_idx] = up_slope[up_idx]
        bins[down_idx] = down_slope[down_idx]

    return bins, center_freqs
512
+
513
+
514
def fbank(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    use_log_fbank: bool = True,
    use_power: bool = True,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
    compute-fbank-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
            (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``)
        use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
        where m is calculated in _get_strided
    """
    device, dtype = waveform.device, waveform.dtype

    # Selects the requested channel and converts frame_shift/frame_length from
    # milliseconds to samples (plus the power-of-two padded FFT size).
    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # signal is too short
        return torch.empty(0, device=device, dtype=dtype)

    # Dithering, DC-offset removal, preemphasis and windowing all happen inside
    # _get_window, in the same order Kaldi applies them.
    # strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # size (m, padded_window_size // 2 + 1)
    spectrum = torch.fft.rfft(strided_input).abs()
    if use_power:
        spectrum = spectrum.pow(2.0)

    # size (num_mel_bins, padded_window_size // 2)
    mel_energies, _ = get_mel_banks(
        num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp
    )
    # get_mel_banks computes on CPU default dtype; move to the waveform's device/dtype.
    mel_energies = mel_energies.to(device=device, dtype=dtype)

    # pad right column with zeros (the Nyquist bin rfft produces but the banks
    # do not cover) and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
    mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)

    # sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
    mel_energies = torch.mm(spectrum, mel_energies.T)
    if use_log_fbank:
        # avoid log of zero (which should be prevented anyway by dithering)
        mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

    # if use_energy then add it as the last column for htk_compat == true else first column
    if use_energy:
        signal_log_energy = signal_log_energy.unsqueeze(1)  # size (m, 1)
        # returns size (m, num_mel_bins + 1)
        if htk_compat:
            mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
        else:
            mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)

    mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
    return mel_energies
646
+
647
+
648
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
    """Return a DCT matrix of shape (num_mel_bins, num_ceps) matching Kaldi's convention."""
    # Full ortho-normal DCT, shape (num_mel_bins, num_mel_bins).
    dct = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
    # Kaldi weights the zeroth cepstral coefficient by sqrt(1 / num_mel_bins).
    # torchaudio's matrix is meant for a right multiply (features @ dct), so
    # that weight lives in the first column here — the first column of Kaldi's
    # left-multiply matrix (dct_matrix * vector).
    dct[:, 0] = math.sqrt(1 / float(num_mel_bins))
    # Keep only the first num_ceps basis vectors.
    return dct[:, :num_ceps]
659
+
660
+
661
def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
    """Return cepstral liftering coefficients of size (num_ceps).

    Coefficients are numbered slightly differently from HTK: index 0 is C0,
    which gets sin(0) and is therefore left unscaled at 1.0.
    """
    indices = torch.arange(num_ceps)
    angle = math.pi * indices / cepstral_lifter
    return 1.0 + 0.5 * cepstral_lifter * torch.sin(angle)
667
+
668
+
669
def mfcc(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    cepstral_lifter: float = 22.0,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    num_ceps: int = 13,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
    compute-mfcc-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
            features (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the MFCC output. (Default: ``False``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``"povey"``)

    Returns:
        Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
        where m is calculated in _get_strided
    """
    assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)

    device, dtype = waveform.device, waveform.dtype

    # The mel_energies should not be squared (use_power=True), not have mean subtracted
    # (subtract_mean=False, it is applied at the end of this function instead), and
    # use log (use_log_fbank=True).
    # size (m, num_mel_bins + use_energy)
    feature = fbank(
        waveform=waveform,
        blackman_coeff=blackman_coeff,
        channel=channel,
        dither=dither,
        energy_floor=energy_floor,
        frame_length=frame_length,
        frame_shift=frame_shift,
        high_freq=high_freq,
        htk_compat=htk_compat,
        low_freq=low_freq,
        min_duration=min_duration,
        num_mel_bins=num_mel_bins,
        preemphasis_coefficient=preemphasis_coefficient,
        raw_energy=raw_energy,
        remove_dc_offset=remove_dc_offset,
        round_to_power_of_two=round_to_power_of_two,
        sample_frequency=sample_frequency,
        snip_edges=snip_edges,
        subtract_mean=False,
        use_energy=use_energy,
        use_log_fbank=True,
        use_power=True,
        vtln_high=vtln_high,
        vtln_low=vtln_low,
        vtln_warp=vtln_warp,
        window_type=window_type,
    )

    if use_energy:
        # Save the energy column (last if htk_compat, first otherwise) and strip it
        # before the DCT; it is re-inserted after liftering below.
        # size (m)
        signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
        # offset is 0 if htk_compat==True else 1
        mel_offset = int(not htk_compat)
        feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]

    # size (num_mel_bins, num_ceps)
    dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)

    # size (m, num_ceps)
    feature = feature.matmul(dct_matrix)

    if cepstral_lifter != 0.0:
        # size (1, num_ceps)
        lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
        feature *= lifter_coeffs.to(device=device, dtype=dtype)

    # if use_energy then replace the last column for htk_compat == true else first column
    if use_energy:
        feature[:, 0] = signal_log_energy

    if htk_compat:
        # HTK convention puts the energy / C0 column last instead of first.
        energy = feature[:, 0].unsqueeze(1)  # size (m, 1)
        feature = feature[:, 1:]  # size (m, num_ceps - 1)
        if not use_energy:
            # scale on C0 (actually removing a scale we previously added that's
            # part of one common definition of the cosine transform.)
            energy *= math.sqrt(2)

        feature = torch.cat((feature, energy), dim=1)

    feature = _subtract_column_mean(feature, subtract_mean)
    return feature
.venv/lib/python3.11/site-packages/torchaudio/models/__init__.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._hdemucs import HDemucs, hdemucs_high, hdemucs_low, hdemucs_medium
2
+ from .conformer import Conformer
3
+ from .conv_tasnet import conv_tasnet_base, ConvTasNet
4
+ from .deepspeech import DeepSpeech
5
+ from .emformer import Emformer
6
+ from .rnnt import emformer_rnnt_base, emformer_rnnt_model, RNNT
7
+ from .rnnt_decoder import Hypothesis, RNNTBeamSearch
8
+ from .squim import (
9
+ squim_objective_base,
10
+ squim_objective_model,
11
+ squim_subjective_base,
12
+ squim_subjective_model,
13
+ SquimObjective,
14
+ SquimSubjective,
15
+ )
16
+ from .tacotron2 import Tacotron2
17
+ from .wav2letter import Wav2Letter
18
+ from .wav2vec2 import (
19
+ hubert_base,
20
+ hubert_large,
21
+ hubert_pretrain_base,
22
+ hubert_pretrain_large,
23
+ hubert_pretrain_model,
24
+ hubert_pretrain_xlarge,
25
+ hubert_xlarge,
26
+ HuBERTPretrainModel,
27
+ wav2vec2_base,
28
+ wav2vec2_large,
29
+ wav2vec2_large_lv60k,
30
+ wav2vec2_model,
31
+ wav2vec2_xlsr_1b,
32
+ wav2vec2_xlsr_2b,
33
+ wav2vec2_xlsr_300m,
34
+ Wav2Vec2Model,
35
+ wavlm_base,
36
+ wavlm_large,
37
+ wavlm_model,
38
+ )
39
+ from .wavernn import WaveRNN
40
+
41
+
42
+ __all__ = [
43
+ "Wav2Letter",
44
+ "WaveRNN",
45
+ "ConvTasNet",
46
+ "conv_tasnet_base",
47
+ "DeepSpeech",
48
+ "Wav2Vec2Model",
49
+ "HuBERTPretrainModel",
50
+ "wavlm_model",
51
+ "wavlm_base",
52
+ "wavlm_large",
53
+ "wav2vec2_model",
54
+ "wav2vec2_base",
55
+ "wav2vec2_large",
56
+ "wav2vec2_large_lv60k",
57
+ "hubert_base",
58
+ "hubert_large",
59
+ "hubert_xlarge",
60
+ "hubert_pretrain_model",
61
+ "hubert_pretrain_base",
62
+ "hubert_pretrain_large",
63
+ "hubert_pretrain_xlarge",
64
+ "wav2vec2_xlsr_300m",
65
+ "wav2vec2_xlsr_1b",
66
+ "wav2vec2_xlsr_2b",
67
+ "Tacotron2",
68
+ "Conformer",
69
+ "Emformer",
70
+ "Hypothesis",
71
+ "RNNT",
72
+ "RNNTBeamSearch",
73
+ "emformer_rnnt_base",
74
+ "emformer_rnnt_model",
75
+ "HDemucs",
76
+ "hdemucs_low",
77
+ "hdemucs_medium",
78
+ "hdemucs_high",
79
+ "squim_objective_base",
80
+ "squim_objective_model",
81
+ "squim_subjective_base",
82
+ "squim_subjective_model",
83
+ "SquimObjective",
84
+ "SquimSubjective",
85
+ ]
.venv/lib/python3.11/site-packages/torchaudio/models/_hdemucs.py ADDED
@@ -0,0 +1,1008 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *****************************************************************************
2
+ # MIT License
3
+ #
4
+ # Copyright (c) Facebook, Inc. and its affiliates.
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ # *****************************************************************************
24
+
25
+
26
+ import math
27
+ import typing as tp
28
+ from typing import Any, Dict, List, Optional
29
+
30
+ import torch
31
+ from torch import nn
32
+ from torch.nn import functional as F
33
+
34
+
35
class _ScaledEmbedding(torch.nn.Module):
    r"""Embedding whose stored weights are divided by ``scale`` while lookups are
    multiplied back by it, which effectively boosts the embedding's learning rate.

    Args:
        num_embeddings (int): number of embeddings
        embedding_dim (int): embedding dimensions
        scale (float, optional): amount to scale learning rate (Default: 10.0)
        smooth (bool, optional): choose to apply smoothing (Default: ``False``)
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth: bool = False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            smoothed = torch.cumsum(self.embedding.weight.data, dim=0)
            # A running sum of n gaussians grows like sqrt(n); divide by that
            # so the variance stays flat across rows.
            norm = torch.arange(1, num_embeddings + 1).sqrt()[:, None]
            self.embedding.weight.data[:] = smoothed / norm
        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self) -> torch.Tensor:
        # Effective (re-scaled) weights, matching what ``forward`` applies.
        return self.embedding.weight * self.scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Look up embeddings and re-apply the learning-rate scale.

        Args:
            x (torch.Tensor): input tensor of shape `(num_embeddings)`

        Returns:
            (Tensor):
                Embedding output of shape `(num_embeddings, embedding_dim)`
        """
        return self.embedding(x) * self.scale
+
72
+
73
class _HEncLayer(torch.nn.Module):

    r"""Encoder layer. This used both by the time and the frequency branch.
    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        kernel_size (int, optional): Kernel size for encoder (Default: 8)
        stride (int, optional): Stride for encoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        empty (bool, optional): used to make a layer with just the first conv. this is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): boolean for whether conv layer is for frequency domain (Default: ``True``)
        norm_type (string, optional): Norm type, either ``group_norm `` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 0)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)
    """

    def __init__(
        self,
        chin: int,
        chout: int,
        kernel_size: int = 8,
        stride: int = 4,
        norm_groups: int = 4,
        empty: bool = False,
        freq: bool = True,
        norm_type: str = "group_norm",
        context: int = 0,
        dconv_kw: Optional[Dict[str, Any]] = None,
        pad: bool = True,
    ):
        super().__init__()
        if dconv_kw is None:
            dconv_kw = {}
        # Normalization factory: identity unless group norm was requested.
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        # Padding of kernel_size // 4 keeps output length == input length / stride.
        pad_val = kernel_size // 4 if pad else 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.pad = pad_val
        if freq:
            # Frequency branch: 2D conv that strides along the frequency axis
            # only (kernel/stride are 1 along time).
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad_val = [pad_val, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad_val)
        self.norm1 = norm_fn(chout)

        if self.empty:
            # "Empty" layer keeps only the strided conv; used right before the
            # time and frequency branches are merged.
            self.rewrite = nn.Identity()
            self.norm2 = nn.Identity()
            self.dconv = nn.Identity()
        else:
            # 1x(1+2*context) conv doubles the channels for the GLU below.
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)
            self.dconv = _DConv(chout, **dconv_kw)

    def forward(self, x: torch.Tensor, inject: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Forward pass for encoding layer.

        Size depends on whether frequency or time

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            inject (torch.Tensor, optional): on last layer, combine frequency and time branches through inject param,
                same shape as x (default: ``None``)

        Returns:
            Tensor
                output tensor after encoder layer of shape `(B, C, F / stride, T)` for frequency
                and shape `(B, C, ceil(T / stride))` for time
        """

        if not self.freq and x.dim() == 4:
            # A time-branch layer receiving a 4D tensor: fold the frequency
            # axis into the channel axis.
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            # Pad the time axis so its length is divisible by the stride.
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            # Merge the time-branch output into the frequency branch.
            if inject.shape[-1] != y.shape[-1]:
                raise ValueError("Injection shapes do not align")
            if inject.dim() == 3 and y.dim() == 4:
                # Broadcast the 3D time tensor over the frequency axis.
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.freq:
            # Run the 1D residual DConv branch independently per frequency bin.
            B, C, Fr, T = y.shape
            y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = self.dconv(y)
        # GLU over the doubled channels brings them back to chout.
        z = self.norm2(self.rewrite(y))
        z = F.glu(z, dim=1)
        return z
181
+
182
+
183
class _HDecLayer(torch.nn.Module):
    r"""Decoder layer. This used both by the time and the frequency branches.
    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        last (bool, optional): whether current layer is final layer (Default: ``False``)
        kernel_size (int, optional): Kernel size for encoder (Default: 8)
        stride (int): Stride for encoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 1)
        empty (bool, optional): used to make a layer with just the first conv. this is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): boolean for whether conv layer is for frequency (Default: ``True``)
        norm_type (str, optional): Norm type, either ``group_norm `` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 1)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)
    """

    def __init__(
        self,
        chin: int,
        chout: int,
        last: bool = False,
        kernel_size: int = 8,
        stride: int = 4,
        norm_groups: int = 1,
        empty: bool = False,
        freq: bool = True,
        norm_type: str = "group_norm",
        context: int = 1,
        dconv_kw: Optional[Dict[str, Any]] = None,
        pad: bool = True,
    ):
        super().__init__()
        if dconv_kw is None:
            dconv_kw = {}
        # Normalization factory: identity unless group norm was requested.
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            # The transposed conv over-produces (kernel_size - stride) samples;
            # they are trimmed symmetrically in forward, hence the parity check.
            if (kernel_size - stride) % 2 != 0:
                raise ValueError("Kernel size and stride do not align")
            pad = (kernel_size - stride) // 2
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            # Frequency branch: 2D transposed conv upsampling the frequency
            # axis only (kernel/stride are 1 along time).
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            # "Empty" layer keeps only the transposed conv; used right after
            # the merged branches are split back into time and frequency.
            self.rewrite = nn.Identity()
            self.norm1 = nn.Identity()
        else:
            # 1x(1+2*context) conv doubles the channels for the GLU below.
            self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            self.norm1 = norm_fn(2 * chin)

    def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor], length):
        r"""Forward pass for decoding layer.

        Size depends on whether frequency or time

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            skip (torch.Tensor, optional): on first layer, separate frequency and time branches using param
                (default: ``None``)
            length (int): Size of tensor for output

        Returns:
            (Tensor, Tensor):
                Tensor
                    output tensor after decoder layer of shape `(B, C, F * stride, T)` for frequency domain except last
                    frequency layer shape is `(B, C, kernel_size, T)`. Shape is `(B, C, stride * T)`
                    for time domain.
                Tensor
                    contains the output just before final transposed convolution, which is used when the
                    freq. and time branch separate. Otherwise, does not matter. Shape is
                    `(B, C, F, T)` for frequency and `(B, C, T)` for time.
        """
        if self.freq and x.dim() == 3:
            # Frequency-branch layer receiving a 3D tensor: unfold the channel
            # axis back into (channels, frequency).
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            # U-Net skip connection from the matching encoder layer.
            x = x + skip
            y = F.glu(self.norm1(self.rewrite(x)), dim=1)
        else:
            y = x
            if skip is not None:
                raise ValueError("Skip must be none when empty is true.")

        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                # Trim the symmetric overshoot introduced by the transposed conv.
                z = z[..., self.pad : -self.pad, :]
        else:
            # Time branch: trim to exactly `length` samples.
            z = z[..., self.pad : self.pad + length]
            if z.shape[-1] != length:
                raise ValueError("Last index of z must be equal to length")
        if not self.last:
            z = F.gelu(z)

        return z, y
299
+
300
+
301
class HDemucs(torch.nn.Module):
    r"""Hybrid Demucs model from
    *Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`.

    See Also:
        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.

    Args:
        sources (List[str]): list of source names. List can contain the following source
            options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
        audio_channels (int, optional): input/output audio channels. (Default: 2)
        channels (int, optional): initial number of hidden channels. (Default: 48)
        growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2)
        nfft (int, optional): number of fft bins. Note that changing this requires careful computation of
            various shape parameters and will not work out of the box for hybrid models. (Default: 4096)
        depth (int, optional): number of layers in encoder and decoder (Default: 6)
        freq_emb (float, optional): add frequency embedding after the first frequency layer if > 0,
            the actual value controls the weight of the embedding. (Default: 0.2)
        emb_scale (int, optional): equivalent to scaling the embedding learning rate (Default: 10)
        emb_smooth (bool, optional): initialize the embedding with a smooth one (with respect to frequencies).
            (Default: ``True``)
        kernel_size (int, optional): kernel_size for encoder and decoder layers. (Default: 8)
        time_stride (int, optional): stride for the final time layer, after the merge. (Default: 2)
        stride (int, optional): stride for encoder and decoder layers. (Default: 4)
        context (int, optional): context for 1x1 conv in the decoder. (Default: 1)
        context_enc (int, optional): context for 1x1 conv in the encoder. (Default: 0)
        norm_starts (int, optional): layer at which group norm starts being used.
            decoder layers are numbered in reverse order. (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        dconv_depth (int, optional): depth of residual DConv branch. (Default: 2)
        dconv_comp (int, optional): compression of DConv branch. (Default: 4)
        dconv_attn (int, optional): adds attention layers in DConv branch starting at this layer. (Default: 4)
        dconv_lstm (int, optional): adds a LSTM layer in DConv branch starting at this layer. (Default: 4)
        dconv_init (float, optional): initial scale for the DConv branch LayerScale. (Default: 1e-4)
    """

    def __init__(
        self,
        sources: List[str],
        audio_channels: int = 2,
        channels: int = 48,
        growth: int = 2,
        nfft: int = 4096,
        depth: int = 6,
        freq_emb: float = 0.2,
        emb_scale: int = 10,
        emb_smooth: bool = True,
        kernel_size: int = 8,
        time_stride: int = 2,
        stride: int = 4,
        context: int = 1,
        context_enc: int = 0,
        norm_starts: int = 4,
        norm_groups: int = 4,
        dconv_depth: int = 2,
        dconv_comp: int = 4,
        dconv_attn: int = 4,
        dconv_lstm: int = 4,
        dconv_init: float = 1e-4,
    ):
        super().__init__()
        self.depth = depth
        self.nfft = nfft
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.channels = channels

        self.hop_length = self.nfft // 4
        # Set by the loop below on the first layer (when freq_emb > 0).
        self.freq_emb = None

        self.freq_encoder = nn.ModuleList()
        self.freq_decoder = nn.ModuleList()

        self.time_encoder = nn.ModuleList()
        self.time_decoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin * 2  # number of channels for the freq branch
        chout = channels
        chout_z = channels
        freqs = self.nfft // 2

        for index in range(self.depth):
            # DConv extras (LSTM / attention) only kick in at deeper layers.
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm_type = "group_norm" if index >= norm_starts else "none"
            # Once freqs collapses to 1 the branch becomes time-only.
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                if freqs != 1:
                    raise ValueError("When freq is false, freqs must be 1.")
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                # Last frequency layer: kernel covers all remaining bins.
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm_type": norm_type,
                "norm_groups": norm_groups,
                "dconv_kw": {
                    "lstm": lstm,
                    "attn": attn,
                    "depth": dconv_depth,
                    "compress": dconv_comp,
                    "init": dconv_init,
                },
            }
            # kwt: variant of kw for the time branch (always 1D, always padded).
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)

            if last_freq:
                # At the merge point both branches must share the channel count.
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = _HEncLayer(chin_z, chout_z, context=context_enc, **kw)
            if freq:
                if last_freq is True and nfft == 2048:
                    # Smaller time conv so that shapes still align for nfft=2048.
                    kwt["stride"] = 2
                    kwt["kernel_size"] = 4
                # The time encoder at the merge layer is "empty" (conv only).
                tenc = _HEncLayer(chin, chout, context=context_enc, empty=last_freq, **kwt)
                self.time_encoder.append(tenc)

            self.freq_encoder.append(enc)
            if index == 0:
                # Decoder outputs one full audio stream per source.
                chin = self.audio_channels * len(self.sources)
                chin_z = chin * 2
            dec = _HDecLayer(chout_z, chin_z, last=index == 0, context=context, **kw_dec)
            if freq:
                tdec = _HDecLayer(chout, chin, empty=last_freq, last=index == 0, context=context, **kwt)
                self.time_decoder.insert(0, tdec)
            self.freq_decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = _ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        _rescale_module(self)

    def _spec(self, x):
        """Compute the (one-sided, trimmed) complex spectrogram of ``x``."""
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        # We re-pad the signal in order to keep the property
        # that the size of the output is exactly the size of the input
        # divided by the stride (here hop_length), when divisible.
        # This is achieved by padding by 1/4th of the kernel size (here nfft).
        # which is not supported by torch.stft.
        # Having all convolution operations follow this convention allow to easily
        # align the time and frequency branches later on.
        if hl != nfft // 4:
            raise ValueError("Hop length must be nfft // 4")
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = self._pad1d(x, pad, pad + le * hl - x.shape[-1], mode="reflect")

        # Drop the Nyquist bin so freqs == nfft // 2.
        z = _spectro(x, nfft, hl)[..., :-1, :]
        if z.shape[-1] != le + 4:
            raise ValueError("Spectrogram's last dimension must be 4 + input size divided by stride")
        # Drop the 2 extra frames introduced on each side by the padding above.
        z = z[..., 2 : 2 + le]
        return z

    def _ispec(self, z, length=None):
        """Invert :meth:`_spec`, producing a waveform of exactly ``length`` samples."""
        hl = self.hop_length
        # Restore the Nyquist bin and the 2 frames trimmed on each side.
        z = F.pad(z, [0, 0, 0, 1])
        z = F.pad(z, [2, 2])
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = _ispectro(z, hl, length=le)
        x = x[..., pad : pad + length]
        return x

    def _pad1d(self, x: torch.Tensor, padding_left: int, padding_right: int, mode: str = "zero", value: float = 0.0):
        """Wrapper around F.pad, in order for reflect padding when num_frames is shorter than max_pad.
        Add extra zero padding around in order for padding to not break."""
        length = x.shape[-1]
        if mode == "reflect":
            max_pad = max(padding_left, padding_right)
            if length <= max_pad:
                # Reflect padding requires input longer than the pad amount.
                x = F.pad(x, (0, max_pad - length + 1))
        return F.pad(x, (padding_left, padding_right), mode, value)

    def _magnitude(self, z):
        # move the complex dimension to the channel one.
        B, C, Fr, T = z.shape
        m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
        m = m.reshape(B, C * 2, Fr, T)
        return m

    def _mask(self, m):
        # `m` is a full spectrogram and `z` is ignored.
        # Inverse of _magnitude, with an extra leading source axis S.
        B, S, C, Fr, T = m.shape
        out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
        out = torch.view_as_complex(out.contiguous())
        return out

    def forward(self, input: torch.Tensor):

        r"""HDemucs forward call

        Args:
            input (torch.Tensor): input mixed tensor of shape `(batch_size, channel, num_frames)`

        Returns:
            Tensor
                output tensor split into sources of shape `(batch_size, num_sources, channel, num_frames)`
        """

        if input.ndim != 3:
            raise ValueError(f"Expected 3D tensor with dimensions (batch, channel, frames). Found: {input.shape}")

        if input.shape[1] != self.audio_channels:
            raise ValueError(
                f"The channel dimension of input Tensor must match `audio_channels` of HDemucs model. "
                f"Found:{input.shape[1]}."
            )

        x = input
        length = x.shape[-1]

        z = self._spec(input)
        mag = self._magnitude(z)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        # Prepare the time branch input.
        xt = input
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths: List[int] = []  # saved lengths to properly remove padding, freq branch.
        lengths_t: List[int] = []  # saved lengths for time branch.

        for idx, encode in enumerate(self.freq_encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.time_encoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.time_encoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        # NOTE(review): xt is re-seeded from the *frequency* tensor's shape here;
        # it is overwritten via `pre` at the branch-split decoder layer before
        # being consumed by a time decoder — confirm against upstream demucs.
        xt = torch.zeros_like(x)
        # initialize everything to zero (signal will go through u-net skips).

        for idx, decode in enumerate(self.freq_decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.
            offset = self.depth - len(self.time_decoder)
            if idx >= offset:
                tdec = self.time_decoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    # Branch split: the single remaining frequency bin becomes
                    # the time-branch signal.
                    if pre.shape[2] != 1:
                        raise ValueError(f"If tdec empty is True, pre shape does not match {pre.shape}")
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # All skip connections and saved lengths must have been consumed.
        if len(saved) != 0:
            raise AssertionError("saved is not empty")
        if len(lengths_t) != 0:
            raise AssertionError("lengths_t is not empty")
        if len(saved_t) != 0:
            raise AssertionError("saved_t is not empty")

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        # Undo the input normalization, per source.
        x = x * std[:, None] + mean[:, None]

        zout = self._mask(x)
        x = self._ispec(zout, length)

        xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        # Final output is the sum of the time- and frequency-branch estimates.
        x = xt + x
        return x
635
+
636
+
637
class _DConv(torch.nn.Module):
    r"""
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.

    Args:
        channels (int): input/output channels for residual branch.
        compress (float, optional): amount of channel compression inside the branch. (default: 4)
        depth (int, optional): number of layers in the residual branch. Each layer has its own
            projection, and potentially LSTM and attention.(default: 2)
        init (float, optional): initial scale for LayerNorm. (default: 1e-4)
        norm_type (bool, optional): Norm type, either ``group_norm `` or ``none`` (Default: ``group_norm``)
        attn (bool, optional): use LocalAttention. (Default: ``False``)
        heads (int, optional): number of heads for the LocalAttention.  (default: 4)
        ndecay (int, optional): number of decay controls in the LocalAttention. (default: 4)
        lstm (bool, optional): use LSTM. (Default: ``False``)
        kernel_size (int, optional): kernel size for the (dilated) convolutions. (default: 3)
    """

    def __init__(
        self,
        channels: int,
        compress: float = 4,
        depth: int = 2,
        init: float = 1e-4,
        norm_type: str = "group_norm",
        attn: bool = False,
        heads: int = 4,
        ndecay: int = 4,
        lstm: bool = False,
        kernel_size: int = 3,
    ):

        super().__init__()
        # An odd kernel is required for the symmetric "same" padding below.
        if kernel_size % 2 == 0:
            raise ValueError("Kernel size should not be divisible by 2")
        self.channels = channels
        self.compress = compress
        # A negative depth means |depth| layers without dilation.
        self.depth = abs(depth)
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        # Bottleneck width of the residual branch.
        hidden = int(channels / compress)

        act = nn.GELU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            # Dilation doubles at each layer (1, 2, 4, ...) when enabled.
            dilation = pow(2, d) if dilate else 1
            padding = dilation * (kernel_size // 2)
            mods = [
                nn.Conv1d(channels, hidden, kernel_size, dilation=dilation, padding=padding),
                norm_fn(hidden),
                act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels),
                nn.GLU(1),
                _LayerScale(channels, init),
            ]
            # Optional extras are inserted in the bottleneck, after the GELU.
            if attn:
                mods.insert(3, _LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, _BLSTM(hidden, layers=2, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        r"""DConv forward call

        Args:
            x (torch.Tensor): input tensor for convolution

        Returns:
            Tensor
                Output after being run through layers.
        """
        # Each layer is a residual branch around the running value.
        for layer in self.layers:
            x = x + layer(x)
        return x
722
+
723
+
724
class _BLSTM(torch.nn.Module):
    r"""
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, input will be splitting in overlapping
    chunks and the LSTM applied separately on each chunk.
    Args:
        dim (int): dimensions at LSTM layer.
        layers (int, optional): number of LSTM layers. (default: 1)
        skip (bool, optional): (default: ``False``)
    """

    def __init__(self, dim, layers: int = 1, skip: bool = False):
        super().__init__()
        # Chunk length used to bound the LSTM sequence length (always set here).
        self.max_steps = 200
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        # Project the bidirectional output (2 * dim) back down to dim.
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""BLSTM forward call

        Args:
            x (torch.Tensor): input tensor for BLSTM shape is `(batch_size, dim, time_steps)`

        Returns:
            Tensor
                Output after being run through bidirectional LSTM. Shape is `(batch_size, dim, time_steps)`
        """
        B, C, T = x.shape
        y = x
        framed = False
        width = 0
        stride = 0
        nframes = 0
        if self.max_steps is not None and T > self.max_steps:
            # Split the sequence into 50%-overlapping chunks so the LSTM never
            # sees more than max_steps timesteps at once.
            width = self.max_steps
            stride = width // 2
            frames = _unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            # Fold frames into the batch dimension for a single LSTM call.
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        # LSTM expects (time, batch, features).
        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            # Stitch chunks back together, discarding stride // 2 samples of
            # overlap on each interior boundary.
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            # Drop any tail padding added by _unfold.
            out = out[..., :T]
            x = out
        if self.skip:
            # Residual connection with the original input.
            x = x + y

        return x
789
+
790
+
791
class _LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).
    Also a failed experiments with trying to provide some frequency based attention.
    """

    def __init__(self, channels: int, heads: int = 4, ndecay: int = 4):
        r"""
        Args:
            channels (int): Size of Conv1d layers.
            heads (int, optional):  (default: 4)
            ndecay (int, optional): (default: 4)
        """
        super(_LocalState, self).__init__()
        if channels % heads != 0:
            raise ValueError("Channels must be divisible by heads.")
        self.heads = heads
        self.ndecay = ndecay
        # 1x1 convs producing per-position content, query and key vectors.
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)

        # Per-query decay rates controlling how fast attention falls off with
        # temporal distance.
        self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
        if ndecay:
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            if self.query_decay.bias is None:
                raise ValueError("bias must not be None.")
            self.query_decay.bias.data[:] = -2
        # NOTE(review): `heads * 0` is a no-op — presumably leftover from the
        # abandoned frequency-attention experiment mentioned in the class docstring.
        self.proj = nn.Conv1d(channels + heads * 0, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""LocalState forward call

        Args:
            x (torch.Tensor): input tensor for LocalState

        Returns:
            Tensor
                Output after being run through LocalState layer.
        """
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        # Standard scaled dot-product normalization.
        dots /= math.sqrt(keys.shape[2])
        if self.ndecay:
            # Add a per-query, learned penalty that decays with |t - s|.
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / math.sqrt(self.ndecay)
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        # Softmax over the key axis (dim=2 is `t`).
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        result = result.reshape(B, -1, T)
        # Residual connection around the attention output.
        return x + self.proj(result)
858
+
859
+
860
+ class _LayerScale(nn.Module):
861
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
862
+ This rescales diagonally residual outputs close to 0 initially, then learnt.
863
+ """
864
+
865
+ def __init__(self, channels: int, init: float = 0):
866
+ r"""
867
+ Args:
868
+ channels (int): Size of rescaling
869
+ init (float, optional): Scale to default to (default: 0)
870
+ """
871
+ super().__init__()
872
+ self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
873
+ self.scale.data[:] = init
874
+
875
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
876
+ r"""LayerScale forward call
877
+
878
+ Args:
879
+ x (torch.Tensor): input tensor for LayerScale
880
+
881
+ Returns:
882
+ Tensor
883
+ Output after rescaling tensor.
884
+ """
885
+ return self.scale[:, None] * x
886
+
887
+
888
+ def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
889
+ """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
890
+ with K the kernel size, by extracting frames with the given stride.
891
+ This will pad the input so that `F = ceil(T / K)`.
892
+ see https://github.com/pytorch/pytorch/issues/60466
893
+ """
894
+ shape = list(a.shape[:-1])
895
+ length = int(a.shape[-1])
896
+ n_frames = math.ceil(length / stride)
897
+ tgt_length = (n_frames - 1) * stride + kernel_size
898
+ a = F.pad(input=a, pad=[0, tgt_length - length])
899
+ strides = [a.stride(dim) for dim in range(a.dim())]
900
+ if strides[-1] != 1:
901
+ raise ValueError("Data should be contiguous.")
902
+ strides = strides[:-1] + [stride, 1]
903
+ shape.append(n_frames)
904
+ shape.append(kernel_size)
905
+ return a.as_strided(shape, strides)
906
+
907
+
908
+ def _rescale_module(module):
909
+ r"""
910
+ Rescales initial weight scale for all models within the module.
911
+ """
912
+ for sub in module.modules():
913
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
914
+ std = sub.weight.std().detach()
915
+ scale = (std / 0.1) ** 0.5
916
+ sub.weight.data /= scale
917
+ if sub.bias is not None:
918
+ sub.bias.data /= scale
919
+
920
+
921
+ def _spectro(x: torch.Tensor, n_fft: int = 512, hop_length: int = 0, pad: int = 0) -> torch.Tensor:
922
+ other = list(x.shape[:-1])
923
+ length = int(x.shape[-1])
924
+ x = x.reshape(-1, length)
925
+ z = torch.stft(
926
+ x,
927
+ n_fft * (1 + pad),
928
+ hop_length,
929
+ window=torch.hann_window(n_fft).to(x),
930
+ win_length=n_fft,
931
+ normalized=True,
932
+ center=True,
933
+ return_complex=True,
934
+ pad_mode="reflect",
935
+ )
936
+ _, freqs, frame = z.shape
937
+ other.extend([freqs, frame])
938
+ return z.view(other)
939
+
940
+
941
+ def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int = 0) -> torch.Tensor:
942
+ other = list(z.shape[:-2])
943
+ freqs = int(z.shape[-2])
944
+ frames = int(z.shape[-1])
945
+
946
+ n_fft = 2 * freqs - 2
947
+ z = z.view(-1, freqs, frames)
948
+ win_length = n_fft // (1 + pad)
949
+ x = torch.istft(
950
+ z,
951
+ n_fft,
952
+ hop_length,
953
+ window=torch.hann_window(win_length).to(z.real),
954
+ win_length=win_length,
955
+ normalized=True,
956
+ length=length,
957
+ center=True,
958
+ )
959
+ _, length = x.shape
960
+ other.append(length)
961
+ return x.view(other)
962
+
963
+
964
def hdemucs_low(sources: List[str]) -> HDemucs:
    """Builds low nfft (1024) version of :class:`HDemucs`, suitable for sample rates around 8 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """
    # Shallower model (depth 5) matching the reduced number of frequency bins.
    model = HDemucs(sources=sources, nfft=1024, depth=5)
    return model
976
+
977
+
978
def hdemucs_medium(sources: List[str]) -> HDemucs:
    r"""Builds medium nfft (2048) version of :class:`HDemucs`, suitable for sample rates of 16-32 kHz.

    .. note::

        Medium HDemucs has not been tested against the original Hybrid Demucs as this nfft and depth configuration is
        not compatible with the original implementation in https://github.com/facebookresearch/demucs

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """
    model = HDemucs(sources=sources, nfft=2048, depth=6)
    return model
995
+
996
+
997
def hdemucs_high(sources: List[str]) -> HDemucs:
    r"""Builds medium nfft (4096) version of :class:`HDemucs`, suitable for sample rates of 44.1-48 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """
    model = HDemucs(sources=sources, nfft=4096, depth=6)
    return model
.venv/lib/python3.11/site-packages/torchaudio/models/conformer.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+
5
+
6
+ __all__ = ["Conformer"]
7
+
8
+
9
+ def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
10
+ batch_size = lengths.shape[0]
11
+ max_length = int(torch.max(lengths).item())
12
+ padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
13
+ batch_size, max_length
14
+ ) >= lengths.unsqueeze(1)
15
+ return padding_mask
16
+
17
+
18
class _ConvolutionModule(torch.nn.Module):
    r"""Conformer convolution module.

    Pointwise conv (channel doubling) -> GLU gate -> depthwise conv -> norm ->
    SiLU -> pointwise conv back to ``input_dim`` -> dropout.

    Args:
        input_dim (int): input dimension.
        num_channels (int): number of depthwise convolution layer input channels.
        depthwise_kernel_size (int): kernel size of depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        bias (bool, optional): indicates whether to add bias term to each convolution layer. (Default: ``False``)
        use_group_norm (bool, optional): use GroupNorm rather than BatchNorm. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim: int,
        num_channels: int,
        depthwise_kernel_size: int,
        dropout: float = 0.0,
        bias: bool = False,
        use_group_norm: bool = False,
    ) -> None:
        super().__init__()
        # (k - 1) // 2 padding preserves sequence length only for odd kernel sizes.
        if (depthwise_kernel_size - 1) % 2 != 0:
            raise ValueError("depthwise_kernel_size must be odd to achieve 'SAME' padding.")
        self.layer_norm = torch.nn.LayerNorm(input_dim)
        self.sequential = torch.nn.Sequential(
            # Pointwise conv doubles the channels so GLU can use half as gates.
            torch.nn.Conv1d(
                input_dim,
                2 * num_channels,
                1,
                stride=1,
                padding=0,
                bias=bias,
            ),
            torch.nn.GLU(dim=1),
            # groups=num_channels makes this a depthwise convolution.
            torch.nn.Conv1d(
                num_channels,
                num_channels,
                depthwise_kernel_size,
                stride=1,
                padding=(depthwise_kernel_size - 1) // 2,
                groups=num_channels,
                bias=bias,
            ),
            # GroupNorm(num_groups=1) normalizes over all channels (LayerNorm-like);
            # it avoids BatchNorm's dependence on batch statistics.
            torch.nn.GroupNorm(num_groups=1, num_channels=num_channels)
            if use_group_norm
            else torch.nn.BatchNorm1d(num_channels),
            torch.nn.SiLU(),
            # Pointwise conv projects back to the model dimension.
            torch.nn.Conv1d(
                num_channels,
                input_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=bias,
            ),
            torch.nn.Dropout(dropout),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input (torch.Tensor): with shape `(B, T, D)`.

        Returns:
            torch.Tensor: output, with shape `(B, T, D)`.
        """
        x = self.layer_norm(input)
        # Conv1d expects (B, D, T); transpose in, transpose back out.
        x = x.transpose(1, 2)
        x = self.sequential(x)
        return x.transpose(1, 2)
89
+
90
+
91
+ class _FeedForwardModule(torch.nn.Module):
92
+ r"""Positionwise feed forward layer.
93
+
94
+ Args:
95
+ input_dim (int): input dimension.
96
+ hidden_dim (int): hidden dimension.
97
+ dropout (float, optional): dropout probability. (Default: 0.0)
98
+ """
99
+
100
+ def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
101
+ super().__init__()
102
+ self.sequential = torch.nn.Sequential(
103
+ torch.nn.LayerNorm(input_dim),
104
+ torch.nn.Linear(input_dim, hidden_dim, bias=True),
105
+ torch.nn.SiLU(),
106
+ torch.nn.Dropout(dropout),
107
+ torch.nn.Linear(hidden_dim, input_dim, bias=True),
108
+ torch.nn.Dropout(dropout),
109
+ )
110
+
111
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
112
+ r"""
113
+ Args:
114
+ input (torch.Tensor): with shape `(*, D)`.
115
+
116
+ Returns:
117
+ torch.Tensor: output, with shape `(*, D)`.
118
+ """
119
+ return self.sequential(input)
120
+
121
+
122
class ConformerLayer(torch.nn.Module):
    r"""Conformer layer that constitutes Conformer.

    Structure: half-step FFN -> (optional conv) -> self-attention -> (conv) ->
    half-step FFN -> final LayerNorm, each sub-block with a residual connection.

    Args:
        input_dim (int): input dimension.
        ffn_dim (int): hidden layer dimension of feedforward network.
        num_attention_heads (int): number of attention heads.
        depthwise_conv_kernel_size (int): kernel size of depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim: int,
        ffn_dim: int,
        num_attention_heads: int,
        depthwise_conv_kernel_size: int,
        dropout: float = 0.0,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ) -> None:
        super().__init__()

        self.ffn1 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)

        self.self_attn_layer_norm = torch.nn.LayerNorm(input_dim)
        self.self_attn = torch.nn.MultiheadAttention(input_dim, num_attention_heads, dropout=dropout)
        self.self_attn_dropout = torch.nn.Dropout(dropout)

        # Convolution module always uses bias; normalization flavor is configurable.
        self.conv_module = _ConvolutionModule(
            input_dim=input_dim,
            num_channels=input_dim,
            depthwise_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            bias=True,
            use_group_norm=use_group_norm,
        )

        self.ffn2 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
        self.final_layer_norm = torch.nn.LayerNorm(input_dim)
        self.convolution_first = convolution_first

    def _apply_convolution(self, input: torch.Tensor) -> torch.Tensor:
        # The conv module expects (B, T, D) but the layer operates in (T, B, D),
        # hence the transposes around the call; residual is added afterwards.
        residual = input
        input = input.transpose(0, 1)
        input = self.conv_module(input)
        input = input.transpose(0, 1)
        input = residual + input
        return input

    def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
        r"""
        Args:
            input (torch.Tensor): input, with shape `(T, B, D)`.
            key_padding_mask (torch.Tensor or None): key padding mask to use in self attention layer.

        Returns:
            torch.Tensor: output, with shape `(T, B, D)`.
        """
        # First feed-forward block with half-step (0.5-scaled) residual.
        residual = input
        x = self.ffn1(input)
        x = x * 0.5 + residual

        if self.convolution_first:
            x = self._apply_convolution(x)

        residual = x
        x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.self_attn_dropout(x)
        x = x + residual

        if not self.convolution_first:
            x = self._apply_convolution(x)

        # Second half-step feed-forward block, then the closing LayerNorm.
        residual = x
        x = self.ffn2(x)
        x = x * 0.5 + residual

        x = self.final_layer_norm(x)
        return x
213
+
214
+
215
class Conformer(torch.nn.Module):
    r"""Conformer architecture introduced in
    *Conformer: Convolution-augmented Transformer for Speech Recognition*
    :cite:`gulati2020conformer`.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Conformer layer.
        ffn_dim (int): hidden layer dimension of feedforward networks.
        num_layers (int): number of Conformer layers to instantiate.
        depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)

    Examples:
        >>> conformer = Conformer(
        >>>     input_dim=80,
        >>>     num_heads=4,
        >>>     ffn_dim=128,
        >>>     num_layers=4,
        >>>     depthwise_conv_kernel_size=31,
        >>> )
        >>> lengths = torch.randint(1, 400, (10,))  # (batch,)
        >>> input = torch.rand(10, int(lengths.max()), input_dim)  # (batch, num_frames, input_dim)
        >>> output = conformer(input, lengths)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        num_layers: int,
        depthwise_conv_kernel_size: int,
        dropout: float = 0.0,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ):
        super().__init__()

        # Stack of identically-configured Conformer layers.
        self.conformer_layers = torch.nn.ModuleList(
            [
                ConformerLayer(
                    input_dim,
                    ffn_dim,
                    num_heads,
                    depthwise_conv_kernel_size,
                    dropout=dropout,
                    use_group_norm=use_group_norm,
                    convolution_first=convolution_first,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input (torch.Tensor): with shape `(B, T, input_dim)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor)
                torch.Tensor
                    output frames, with shape `(B, T, input_dim)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output frames.
        """
        # One shared padding mask is computed once and reused by every layer.
        encoder_padding_mask = _lengths_to_padding_mask(lengths)

        # Layers operate in (T, B, D); convert in and out around the stack.
        x = input.transpose(0, 1)
        for layer in self.conformer_layers:
            x = layer(x, encoder_padding_mask)
        # No subsampling is performed, so lengths pass through unchanged.
        return x.transpose(0, 1), lengths
.venv/lib/python3.11/site-packages/torchaudio/models/conv_tasnet.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implements Conv-TasNet with building blocks of it.
2
+
3
+ Based on https://github.com/naplab/Conv-TasNet/tree/e66d82a8f956a69749ec8a4ae382217faa097c5c
4
+ """
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+
10
+
11
class ConvBlock(torch.nn.Module):
    """1D Convolutional block.

    Pointwise expand -> PReLU -> GroupNorm -> dilated depthwise conv -> PReLU ->
    GroupNorm, then two pointwise heads: a residual output and a skip output.

    Args:
        io_channels (int): The number of input/output channels, <B, Sc>
        hidden_channels (int): The number of channels in the internal layers, <H>.
        kernel_size (int): The convolution kernel size of the middle layer, <P>.
        padding (int): Padding value of the convolution in the middle layer.
        dilation (int, optional): Dilation value of the convolution in the middle layer.
        no_residual (bool, optional): Disable residual block/output.

    Note:
        This implementation corresponds to the "non-causal" setting in the paper.
    """

    def __init__(
        self,
        io_channels: int,
        hidden_channels: int,
        kernel_size: int,
        padding: int,
        dilation: int = 1,
        no_residual: bool = False,
    ):
        super().__init__()

        self.conv_layers = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=io_channels, out_channels=hidden_channels, kernel_size=1),
            torch.nn.PReLU(),
            torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
            # groups=hidden_channels makes this a depthwise (per-channel) convolution.
            torch.nn.Conv1d(
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                kernel_size=kernel_size,
                padding=padding,
                dilation=dilation,
                groups=hidden_channels,
            ),
            torch.nn.PReLU(),
            torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
        )

        # The last block in the TCN has no residual path (no_residual=True).
        self.res_out = (
            None
            if no_residual
            else torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
        )
        self.skip_out = torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)

    def forward(self, input: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        # Returns (residual, skip); residual is None when the block was built
        # with no_residual=True.
        feature = self.conv_layers(input)
        if self.res_out is None:
            residual = None
        else:
            residual = self.res_out(feature)
        skip_out = self.skip_out(feature)
        return residual, skip_out
68
+
69
+
70
class MaskGenerator(torch.nn.Module):
    """TCN (Temporal Convolution Network) Separation Module

    Generates masks for separation.

    Args:
        input_dim (int): Input feature dimension, <N>.
        num_sources (int): The number of sources to separate.
        kernel_size (int): The convolution kernel size of conv blocks, <P>.
        num_feats (int): Input/output feature dimension of conv blocks, <B, Sc>.
        num_hidden (int): Intermediate feature dimension of conv blocks, <H>
        num_layers (int): The number of conv blocks in one stack, <X>.
        num_stacks (int): The number of conv block stacks, <R>.
        msk_activate (str): The activation function of the mask output.

    Note:
        This implementation corresponds to the "non-causal" setting in the paper.
    """

    def __init__(
        self,
        input_dim: int,
        num_sources: int,
        kernel_size: int,
        num_feats: int,
        num_hidden: int,
        num_layers: int,
        num_stacks: int,
        msk_activate: str,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.num_sources = num_sources

        self.input_norm = torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8)
        self.input_conv = torch.nn.Conv1d(in_channels=input_dim, out_channels=num_feats, kernel_size=1)

        self.receptive_field = 0
        self.conv_layers = torch.nn.ModuleList([])
        # R stacks of X blocks; dilation (and matching padding) doubles per layer
        # within a stack, growing the receptive field exponentially.
        for s in range(num_stacks):
            for l in range(num_layers):
                multi = 2**l
                self.conv_layers.append(
                    ConvBlock(
                        io_channels=num_feats,
                        hidden_channels=num_hidden,
                        kernel_size=kernel_size,
                        dilation=multi,
                        padding=multi,
                        # The last ConvBlock does not need residual
                        no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)),
                    )
                )
                self.receptive_field += kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi
        self.output_prelu = torch.nn.PReLU()
        # One mask per source per input feature, produced jointly then reshaped.
        self.output_conv = torch.nn.Conv1d(
            in_channels=num_feats,
            out_channels=input_dim * num_sources,
            kernel_size=1,
        )
        if msk_activate == "sigmoid":
            self.mask_activate = torch.nn.Sigmoid()
        elif msk_activate == "relu":
            self.mask_activate = torch.nn.ReLU()
        else:
            raise ValueError(f"Unsupported activation {msk_activate}")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Generate separation mask.

        Args:
            input (torch.Tensor): 3D Tensor with shape [batch, features, frames]

        Returns:
            Tensor: shape [batch, num_sources, features, frames]
        """
        batch_size = input.shape[0]
        feats = self.input_norm(input)
        feats = self.input_conv(feats)
        # The mask is predicted from the *sum* of all skip connections, while
        # residuals feed the next block.
        output = 0.0
        for layer in self.conv_layers:
            residual, skip = layer(feats)
            if residual is not None:  # the last conv layer does not produce residual
                feats = feats + residual
            output = output + skip
        output = self.output_prelu(output)
        output = self.output_conv(output)
        output = self.mask_activate(output)
        return output.view(batch_size, self.num_sources, self.input_dim, -1)
160
+
161
+
162
class ConvTasNet(torch.nn.Module):
    """Conv-TasNet architecture introduced in
    *Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
    :cite:`Luo_2019`.

    Note:
        This implementation corresponds to the "non-causal" setting in the paper.

    See Also:
        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.

    Args:
        num_sources (int, optional): The number of sources to split.
        enc_kernel_size (int, optional): The convolution kernel size of the encoder/decoder, <L>.
        enc_num_feats (int, optional): The feature dimensions passed to mask generator, <N>.
        msk_kernel_size (int, optional): The convolution kernel size of the mask generator, <P>.
        msk_num_feats (int, optional): The input/output feature dimension of conv block in the mask generator, <B, Sc>.
        msk_num_hidden_feats (int, optional): The internal feature dimension of conv block of the mask generator, <H>.
        msk_num_layers (int, optional): The number of layers in one conv block of the mask generator, <X>.
        msk_num_stacks (int, optional): The number of conv blocks of the mask generator, <R>.
        msk_activate (str, optional): The activation function of the mask output (Default: ``sigmoid``).
    """

    def __init__(
        self,
        num_sources: int = 2,
        # encoder/decoder parameters
        enc_kernel_size: int = 16,
        enc_num_feats: int = 512,
        # mask generator parameters
        msk_kernel_size: int = 3,
        msk_num_feats: int = 128,
        msk_num_hidden_feats: int = 512,
        msk_num_layers: int = 8,
        msk_num_stacks: int = 3,
        msk_activate: str = "sigmoid",
    ):
        super().__init__()

        self.num_sources = num_sources
        self.enc_num_feats = enc_num_feats
        self.enc_kernel_size = enc_kernel_size
        # 50% overlap between adjacent encoder windows.
        self.enc_stride = enc_kernel_size // 2

        # Learned analysis transform: waveform -> feature frames.
        self.encoder = torch.nn.Conv1d(
            in_channels=1,
            out_channels=enc_num_feats,
            kernel_size=enc_kernel_size,
            stride=self.enc_stride,
            padding=self.enc_stride,
            bias=False,
        )
        self.mask_generator = MaskGenerator(
            input_dim=enc_num_feats,
            num_sources=num_sources,
            kernel_size=msk_kernel_size,
            num_feats=msk_num_feats,
            num_hidden=msk_num_hidden_feats,
            num_layers=msk_num_layers,
            num_stacks=msk_num_stacks,
            msk_activate=msk_activate,
        )
        # Learned synthesis transform: masked features -> waveform (mirrors encoder).
        self.decoder = torch.nn.ConvTranspose1d(
            in_channels=enc_num_feats,
            out_channels=1,
            kernel_size=enc_kernel_size,
            stride=self.enc_stride,
            padding=self.enc_stride,
            bias=False,
        )

    def _align_num_frames_with_strides(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Pad input Tensor so that the end of the input tensor corresponds with

        1. (if kernel size is odd) the center of the last convolution kernel
        or 2. (if kernel size is even) the end of the first half of the last convolution kernel

        Assumption:
            The resulting Tensor will be padded with the size of stride (== kernel_width // 2)
            on the both ends in Conv1D

        |<--- k_1 --->|
        |      |            |<-- k_n-1 -->|
        |      |                  |<--- k_n --->|
        |      |                        |      |
        |      v                        v      |
        |<---->|<--- input signal --->|<--->|<---->|
         stride                         PAD  stride

        Args:
            input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames)

        Returns:
            Tensor: Padded Tensor
            int: Number of paddings performed
        """
        batch_size, num_channels, num_frames = input.shape
        is_odd = self.enc_kernel_size % 2
        num_strides = (num_frames - is_odd) // self.enc_stride
        num_remainings = num_frames - (is_odd + num_strides * self.enc_stride)
        if num_remainings == 0:
            return input, 0

        # Zero-pad the tail so the frame count lands exactly on a stride boundary.
        num_paddings = self.enc_stride - num_remainings
        pad = torch.zeros(
            batch_size,
            num_channels,
            num_paddings,
            dtype=input.dtype,
            device=input.device,
        )
        return torch.cat([input, pad], 2), num_paddings

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Perform source separation. Generate audio source waveforms.

        Args:
            input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames]

        Returns:
            Tensor: 3D Tensor with shape [batch, channel==num_sources, frames]
        """
        if input.ndim != 3 or input.shape[1] != 1:
            raise ValueError(f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}")

        # B: batch size
        # L: input frame length
        # L': padded input frame length
        # F: feature dimension
        # M: feature frame length
        # S: number of sources

        padded, num_pads = self._align_num_frames_with_strides(input)  # B, 1, L'
        batch_size, num_padded_frames = padded.shape[0], padded.shape[2]
        feats = self.encoder(padded)  # B, F, M
        # Broadcast the encoded features against one mask per source.
        masked = self.mask_generator(feats) * feats.unsqueeze(1)  # B, S, F, M
        # Fold sources into the batch dimension so one decoder call handles all.
        masked = masked.view(batch_size * self.num_sources, self.enc_num_feats, -1)  # B*S, F, M
        decoded = self.decoder(masked)  # B*S, 1, L'
        output = decoded.view(batch_size, self.num_sources, num_padded_frames)  # B, S, L'
        if num_pads > 0:
            # Drop the tail samples introduced by _align_num_frames_with_strides.
            output = output[..., :-num_pads]  # B, S, L
        return output
305
+
306
+
307
def conv_tasnet_base(num_sources: int = 2) -> ConvTasNet:
    r"""Builds non-causal version of :class:`~torchaudio.models.ConvTasNet`.

    The parameter settings follow the ones with the highest Si-SNR metirc score in the paper,
    except the mask activation function is changed from "sigmoid" to "relu" for performance improvement.

    Args:
        num_sources (int, optional): Number of sources in the output.
            (Default: 2)
    Returns:
        ConvTasNet:
            ConvTasNet model.
    """
    # Best-scoring configuration from the paper (with ReLU mask activation).
    config = {
        "enc_kernel_size": 16,
        "enc_num_feats": 512,
        "msk_kernel_size": 3,
        "msk_num_feats": 128,
        "msk_num_hidden_feats": 512,
        "msk_num_layers": 8,
        "msk_num_stacks": 3,
        "msk_activate": "relu",
    }
    return ConvTasNet(num_sources=num_sources, **config)
.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__init__.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Names provided lazily by the CPU (flashlight-text based) CTC decoder module.
_CTC_DECODERS = [
    "CTCHypothesis",
    "CTCDecoder",
    "CTCDecoderLM",
    "CTCDecoderLMState",
    "ctc_decoder",
    "download_pretrained_files",
]
# Names provided lazily by the CUDA CTC decoder module.
_CUDA_CTC_DECODERS = [
    "CUCTCDecoder",
    "CUCTCHypothesis",
    "cuda_ctc_decoder",
]
14
+
15
+
16
def __getattr__(name: str):
    # PEP 562 module-level lazy loading: the decoder submodules are imported
    # only when one of their public names is first accessed, because they
    # depend on optional components (flashlight-text / CUDA build flags).
    if name in _CTC_DECODERS:
        try:
            from . import _ctc_decoder
        except Exception as err:
            raise RuntimeError(
                "CTC Decoder suit requires flashlight-text package and optionally KenLM. Please install them."
            ) from err

        # Cache the resolved attribute on the module so __getattr__ is hit once per name.
        item = getattr(_ctc_decoder, name)
        globals()[name] = item
        return item
    elif name in _CUDA_CTC_DECODERS:
        try:
            from . import _cuda_ctc_decoder
        # NOTE(review): this branch only catches AttributeError while the CTC
        # branch above catches Exception — an ImportError here would propagate
        # raw. Looks intentional (missing extension surfaces as AttributeError),
        # but worth confirming against the upstream build.
        except AttributeError as err:
            raise RuntimeError(
                "To use CUCTC decoder, please set BUILD_CUDA_CTC_DECODER=1 when building from source."
            ) from err

        item = getattr(_cuda_ctc_decoder, name)
        globals()[name] = item
        return item
    raise AttributeError(f"module {__name__} has no attribute {name}")
40
+
41
+
42
def __dir__():
    # Expose the lazily-resolved public names in deterministic (sorted) order.
    names = list(__all__)
    names.sort()
    return names
44
+
45
+
46
# Public API: union of the CPU and CUDA decoder entry points (resolved lazily in __getattr__).
__all__ = _CTC_DECODERS + _CUDA_CTC_DECODERS
.venv/lib/python3.11/site-packages/torchaudio/models/decoder/_ctc_decoder.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import itertools as it
4
+
5
+ from abc import abstractmethod
6
+ from collections import namedtuple
7
+ from typing import Dict, List, NamedTuple, Optional, Tuple, Union
8
+
9
+ import torch
10
+
11
+ from flashlight.lib.text.decoder import (
12
+ CriterionType as _CriterionType,
13
+ LexiconDecoder as _LexiconDecoder,
14
+ LexiconDecoderOptions as _LexiconDecoderOptions,
15
+ LexiconFreeDecoder as _LexiconFreeDecoder,
16
+ LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
17
+ LM as _LM,
18
+ LMState as _LMState,
19
+ SmearingMode as _SmearingMode,
20
+ Trie as _Trie,
21
+ ZeroLM as _ZeroLM,
22
+ )
23
+ from flashlight.lib.text.dictionary import (
24
+ create_word_dict as _create_word_dict,
25
+ Dictionary as _Dictionary,
26
+ load_words as _load_words,
27
+ )
28
+ from torchaudio.utils import download_asset
29
+
30
+ try:
31
+ from flashlight.lib.text.decoder.kenlm import KenLM as _KenLM
32
+ except Exception:
33
+ try:
34
+ from flashlight.lib.text.decoder import KenLM as _KenLM
35
+ except Exception:
36
+ _KenLM = None
37
+
38
+ __all__ = [
39
+ "CTCHypothesis",
40
+ "CTCDecoder",
41
+ "CTCDecoderLM",
42
+ "CTCDecoderLMState",
43
+ "ctc_decoder",
44
+ "download_pretrained_files",
45
+ ]
46
+
47
# Paths of the assets that make up a downloaded pretrained decoder bundle.
_PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])
48
+
49
+
50
def _construct_trie(tokens_dict, word_dict, lexicon, lm, silence):
    """Build the lexicon trie used by the lexicon-constrained decoder.

    Each word's spellings are inserted as token-index sequences, scored once
    against the LM from its start state; the trie is then max-smeared.
    """
    trie = _Trie(tokens_dict.index_size(), silence)
    initial_state = lm.start(False)

    for word, spellings in lexicon.items():
        word_index = word_dict.get_index(word)
        _, lm_score = lm.score(initial_state, word_index)
        for spelling in spellings:
            token_indices = [tokens_dict.get_index(tok) for tok in spelling]
            trie.insert(token_indices, word_index, lm_score)
    trie.smear(_SmearingMode.MAX)
    return trie
63
+
64
+
65
def _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word):
    """Resolve the word dictionary for the decoder.

    Precedence: an explicit ``lm_dict`` file wins; otherwise the dictionary is
    derived from the lexicon; otherwise (lexicon-free with a file-path LM) each
    token is treated as its own single-token "word", plus the unknown word.

    Returns:
        _Dictionary or None: resolved word dictionary (None when lexicon-free
        with a non-string LM object).
    """
    word_dict = None
    if lm_dict is not None:
        word_dict = _Dictionary(lm_dict)

    if lexicon and word_dict is None:
        word_dict = _create_word_dict(lexicon)
    # Fix: `type(lm) == str` replaced with the idiomatic `isinstance` check
    # (also accepts str subclasses such as those returned by path utilities).
    elif not lexicon and word_dict is None and isinstance(lm, str):
        d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())}
        d[unk_word] = [[unk_word]]
        word_dict = _create_word_dict(d)

    return word_dict
78
+
79
+
80
class CTCHypothesis(NamedTuple):
    r"""Represents hypothesis generated by CTC beam search decoder :class:`CTCDecoder`."""
    tokens: torch.LongTensor
    """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""

    words: List[str]
    """List of predicted words.

    Note:
        This attribute is only applicable if a lexicon is provided to the decoder. If
        decoding without a lexicon, it will be blank. Please refer to :attr:`tokens` and
        :func:`~torchaudio.models.decoder.CTCDecoder.idxs_to_tokens` instead.
    """

    score: float
    """Score corresponding to hypothesis"""

    timesteps: torch.IntTensor
    """Timesteps corresponding to the tokens. Shape `(L, )`, where `L` is the length of the output sequence"""
100
+
101
class CTCDecoderLMState(_LMState):
    """Language model state.

    Thin typed wrapper over the flashlight-text ``LMState``; the actual state
    bookkeeping lives in the C++ base class.
    """

    @property
    def children(self) -> Dict[int, "CTCDecoderLMState"]:
        """Map of indices to LM states"""
        return super().children

    def child(self, usr_index: int) -> "CTCDecoderLMState":
        """Returns child corresponding to usr_index, or creates and returns a new state if input index
        is not found.

        Args:
            usr_index (int): index corresponding to child state

        Returns:
            CTCDecoderLMState: child state corresponding to usr_index
        """
        return super().child(usr_index)

    def compare(self, state: "CTCDecoderLMState") -> int:
        """Compare two language model states.

        Args:
            state (CTCDecoderLMState): LM state to compare against

        Returns:
            int: 0 if the states are the same, -1 if self is less, +1 if self is greater.
        """
        # Documentation-only stub: the comparison is implemented by the C++ base class.
        pass
131
+
132
+
133
class CTCDecoderLM(_LM):
    """Language model base class for creating custom language models to use with the decoder.

    Subclass and implement :meth:`start`, :meth:`score`, and :meth:`finish`;
    the beam search calls these to score partial and completed hypotheses.
    """

    @abstractmethod
    def start(self, start_with_nothing: bool) -> CTCDecoderLMState:
        """Initialize or reset the language model.

        Args:
            start_with_nothing (bool): whether or not to start sentence with sil token.

        Returns:
            CTCDecoderLMState: starting state
        """
        raise NotImplementedError

    @abstractmethod
    def score(self, state: CTCDecoderLMState, usr_token_idx: int) -> Tuple[CTCDecoderLMState, float]:
        """Evaluate the language model based on the current LM state and new word.

        Args:
            state (CTCDecoderLMState): current LM state
            usr_token_idx (int): index of the word

        Returns:
            (CTCDecoderLMState, float)
                CTCDecoderLMState:
                    new LM state
                float:
                    score
        """
        raise NotImplementedError

    @abstractmethod
    def finish(self, state: CTCDecoderLMState) -> Tuple[CTCDecoderLMState, float]:
        """Evaluate end for language model based on current LM state.

        Args:
            state (CTCDecoderLMState): current LM state

        Returns:
            (CTCDecoderLMState, float)
                CTCDecoderLMState:
                    new LM state
                float:
                    score
        """
        raise NotImplementedError
180
+
181
+
182
class CTCDecoder:
    """CTC beam search decoder from *Flashlight* :cite:`kahn2022flashlight`.

    .. devices:: CPU

    Note:
        To build the decoder, please use the factory function :func:`ctc_decoder`.
    """

    def __init__(
        self,
        nbest: int,
        lexicon: Optional[Dict],
        word_dict: _Dictionary,
        tokens_dict: _Dictionary,
        lm: CTCDecoderLM,
        decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
        blank_token: str,
        sil_token: str,
        unk_word: str,
    ) -> None:
        """
        Args:
            nbest (int): number of best decodings to return
            lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder
            word_dict (_Dictionary): dictionary of words
            tokens_dict (_Dictionary): dictionary of tokens
            lm (CTCDecoderLM): language model. If using a lexicon, only word level LMs are currently supported
            decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions):
                parameters used for beam search decoding
            blank_token (str): token corresponding to blank
            sil_token (str): token corresponding to silence
            unk_word (str): word corresponding to unknown
        """

        self.nbest = nbest
        self.word_dict = word_dict
        self.tokens_dict = tokens_dict
        self.blank = self.tokens_dict.get_index(blank_token)
        silence = self.tokens_dict.get_index(sil_token)
        transitions = []

        if lexicon:
            trie = _construct_trie(tokens_dict, word_dict, lexicon, lm, silence)
            # Rebind to the numeric index of the unknown word for the native decoder.
            unk_word = word_dict.get_index(unk_word)
            token_lm = False  # use word level LM

            self.decoder = _LexiconDecoder(
                decoder_options,
                trie,
                lm,
                silence,
                self.blank,
                unk_word,
                transitions,
                token_lm,
            )
        else:
            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)
        # https://github.com/pytorch/audio/issues/3218
        # If lm is passed like rvalue reference, the lm object gets garbage collected,
        # and later call to the lm fails.
        # This ensures that lm object is not deleted as long as the decoder is alive.
        # https://github.com/pybind/pybind11/discussions/4013
        self.lm = lm

    def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        """Collapse consecutive repeats (CTC decoding rule), then drop blank tokens."""
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)
        return torch.LongTensor(list(idxs))

    def _get_timesteps(self, idxs: torch.IntTensor) -> torch.IntTensor:
        """Returns frame numbers corresponding to non-blank tokens."""

        timesteps = []
        for i, idx in enumerate(idxs):
            if idx == self.blank:
                continue
            # Only record the first frame of a run of repeated tokens.
            if i == 0 or idx != idxs[i - 1]:
                timesteps.append(i)
        return torch.IntTensor(timesteps)

    def decode_begin(self):
        """Initialize the internal state of the decoder.

        See :py:meth:`decode_step` for the usage.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        self.decoder.decode_begin()

    def decode_end(self):
        """Finalize the internal state of the decoder.

        See :py:meth:`decode_step` for the usage.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        self.decoder.decode_end()

    def decode_step(self, emissions: torch.FloatTensor):
        """Perform incremental decoding on top of the current internal state.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.

        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(frame, num_tokens)` storing sequences of
                probability distribution over labels; output of acoustic model.

        Example:
            >>> decoder = torchaudio.models.decoder.ctc_decoder(...)
            >>> decoder.decode_begin()
            >>> decoder.decode_step(emission1)
            >>> decoder.decode_step(emission2)
            >>> decoder.decode_end()
            >>> result = decoder.get_final_hypothesis()
        """
        # The native decoder reads the raw buffer, so dtype, device, and layout
        # must be validated here before handing over the pointer.
        if emissions.dtype != torch.float32:
            raise ValueError("emissions must be float32.")

        if not emissions.is_cpu:
            raise RuntimeError("emissions must be a CPU tensor.")

        if not emissions.is_contiguous():
            raise RuntimeError("emissions must be contiguous.")

        if emissions.ndim != 2:
            raise RuntimeError(f"emissions must be 2D. Found {emissions.shape}")

        T, N = emissions.size()
        self.decoder.decode_step(emissions.data_ptr(), T, N)

    def _to_hypo(self, results) -> List[CTCHypothesis]:
        """Convert native decoder results into :class:`CTCHypothesis` objects."""
        return [
            CTCHypothesis(
                tokens=self._get_tokens(result.tokens),
                # Negative word indices mark entries with no word output; skip them.
                words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
                score=result.score,
                timesteps=self._get_timesteps(result.tokens),
            )
            for result in results
        ]

    def get_final_hypothesis(self) -> List[CTCHypothesis]:
        """Get the final hypothesis

        Returns:
            List[CTCHypothesis]:
                List of sorted best hypotheses.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        results = self.decoder.get_all_final_hypothesis()
        return self._to_hypo(results[: self.nbest])

    def __call__(
        self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
    ) -> List[List[CTCHypothesis]]:
        """
        Performs batched offline decoding.

        .. note::

           This method performs offline decoding in one go. To perform incremental decoding,
           please refer to :py:meth:`decode_step`.

        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                probability distribution over labels; output of acoustic model.
            lengths (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the valid length of
                in time axis of the output Tensor in each batch.

        Returns:
            List[List[CTCHypothesis]]:
                List of sorted best hypotheses for each audio sequence in the batch.
        """

        if emissions.dtype != torch.float32:
            raise ValueError("emissions must be float32.")

        if not emissions.is_cpu:
            raise RuntimeError("emissions must be a CPU tensor.")

        if not emissions.is_contiguous():
            raise RuntimeError("emissions must be contiguous.")

        if emissions.ndim != 3:
            raise RuntimeError(f"emissions must be 3D. Found {emissions.shape}")

        if lengths is not None and not lengths.is_cpu:
            raise RuntimeError("lengths must be a CPU tensor.")

        B, T, N = emissions.size()
        if lengths is None:
            # No lengths given: treat every sequence as fully valid.
            lengths = torch.full((B,), T)

        # data_ptr() arithmetic is in bytes; emissions is float32 (4 bytes).
        float_bytes = 4
        hypos = []

        for b in range(B):
            # Advance the raw pointer to the start of batch element b.
            emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, lengths[b], N)
            hypos.append(self._to_hypo(results[: self.nbest]))
        return hypos

    def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
        """
        Map raw token IDs into corresponding tokens

        Args:
            idxs (LongTensor): raw token IDs generated from decoder

        Returns:
            List: tokens corresponding to the input IDs
        """
        return [self.tokens_dict.get_entry(idx.item()) for idx in idxs]
410
+
411
+
412
def ctc_decoder(
    lexicon: Optional[str],
    tokens: Union[str, List[str]],
    lm: Optional[Union[str, CTCDecoderLM]] = None,
    lm_dict: Optional[str] = None,
    nbest: int = 1,
    beam_size: int = 50,
    beam_size_token: Optional[int] = None,
    beam_threshold: float = 50,
    lm_weight: float = 2,
    word_score: float = 0,
    unk_score: float = float("-inf"),
    sil_score: float = 0,
    log_add: bool = False,
    blank_token: str = "-",
    sil_token: str = "|",
    unk_word: str = "<unk>",
) -> CTCDecoder:
    """Builds an instance of :class:`CTCDecoder`.

    Args:
        lexicon (str or None): lexicon file containing the possible words and corresponding spellings.
            Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free
            decoding.
        tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
            format is for tokens mapping to the same index to be on the same line
        lm (str, CTCDecoderLM, or None, optional): either a path containing KenLM language model,
            custom language model of type `CTCDecoderLM`, or `None` if not using a language model
        lm_dict (str or None, optional): file consisting of the dictionary used for the LM, with a word
            per line sorted by LM index. If decoding with a lexicon, entries in lm_dict must also occur
            in the lexicon file. If `None`, dictionary for LM is constructed using the lexicon file.
            (Default: None)
        nbest (int, optional): number of best decodings to return (Default: 1)
        beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
        beam_size_token (int, optional): max number of tokens to consider at each decode step.
            If `None`, it is set to the total number of tokens (Default: None)
        beam_threshold (float, optional): threshold for pruning hypothesis (Default: 50)
        lm_weight (float, optional): weight of language model (Default: 2)
        word_score (float, optional): word insertion score (Default: 0)
        unk_score (float, optional): unknown word insertion score (Default: -inf)
        sil_score (float, optional): silence insertion score (Default: 0)
        log_add (bool, optional): whether or not to use logadd when merging hypotheses (Default: False)
        blank_token (str, optional): token corresponding to blank (Default: "-")
        sil_token (str, optional): token corresponding to silence (Default: "|")
        unk_word (str, optional): word corresponding to unknown (Default: "<unk>")

    Returns:
        CTCDecoder: decoder

    Example
        >>> decoder = ctc_decoder(
        >>>     lexicon="lexicon.txt",
        >>>     tokens="tokens.txt",
        >>>     lm="kenlm.bin",
        >>> )
        >>> results = decoder(emissions) # List of shape (B, nbest) of Hypotheses
    """
    # isinstance (rather than an exact type() comparison) also accepts str subclasses.
    if lm_dict is not None and not isinstance(lm_dict, str):
        raise ValueError("lm_dict must be None or str type.")

    tokens_dict = _Dictionary(tokens)

    # decoder options
    if lexicon:
        lexicon = _load_words(lexicon)
        decoder_options = _LexiconDecoderOptions(
            beam_size=beam_size,
            beam_size_token=beam_size_token or tokens_dict.index_size(),
            beam_threshold=beam_threshold,
            lm_weight=lm_weight,
            word_score=word_score,
            unk_score=unk_score,
            sil_score=sil_score,
            log_add=log_add,
            criterion_type=_CriterionType.CTC,
        )
    else:
        decoder_options = _LexiconFreeDecoderOptions(
            beam_size=beam_size,
            beam_size_token=beam_size_token or tokens_dict.index_size(),
            beam_threshold=beam_threshold,
            lm_weight=lm_weight,
            sil_score=sil_score,
            log_add=log_add,
            criterion_type=_CriterionType.CTC,
        )

    # construct word dict and language model
    word_dict = _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word)

    # A string `lm` is interpreted as a path to a KenLM binary model.
    if isinstance(lm, str):
        if _KenLM is None:
            raise RuntimeError(
                "flashlight-text is installed, but KenLM is not installed. "
                "Please refer to https://github.com/kpu/kenlm#python-module for how to install it."
            )
        lm = _KenLM(lm, word_dict)
    elif lm is None:
        lm = _ZeroLM()

    return CTCDecoder(
        nbest=nbest,
        lexicon=lexicon,
        word_dict=word_dict,
        tokens_dict=tokens_dict,
        lm=lm,
        decoder_options=decoder_options,
        blank_token=blank_token,
        sil_token=sil_token,
        unk_word=unk_word,
    )
523
+
524
+
525
def _get_filenames(model: str) -> _PretrainedFiles:
    """Return the remote asset paths (lexicon, tokens, optional LM) for *model*.

    Raises:
        ValueError: if *model* is not one of the supported pretrained models.
    """
    supported = ("librispeech", "librispeech-3-gram", "librispeech-4-gram")
    if model not in supported:
        raise ValueError(
            f"{model} not supported. Must be one of ['librispeech-3-gram', 'librispeech-4-gram', 'librispeech']"
        )

    base = f"decoder-assets/{model}"
    # The plain "librispeech" bundle ships no language model.
    lm_path = None if model == "librispeech" else f"{base}/lm.bin"
    return _PretrainedFiles(
        lexicon=f"{base}/lexicon.txt",
        tokens=f"{base}/tokens.txt",
        lm=lm_path,
    )
537
+
538
+
539
def download_pretrained_files(model: str) -> _PretrainedFiles:
    """
    Retrieves pretrained data files used for :func:`ctc_decoder`.

    Args:
        model (str): pretrained language model to download.
            Valid values are: ``"librispeech-3-gram"``, ``"librispeech-4-gram"`` and ``"librispeech"``.

    Returns:
        Object with the following attributes

        * ``lm``: path corresponding to downloaded language model,
          or ``None`` if the model is not associated with an lm
        * ``lexicon``: path corresponding to downloaded lexicon file
        * ``tokens``: path corresponding to downloaded tokens file
    """
    remote = _get_filenames(model)

    # Fetch lexicon and tokens first, then the (optional) language model.
    lexicon_path = download_asset(remote.lexicon)
    tokens_path = download_asset(remote.tokens)
    lm_path = download_asset(remote.lm) if remote.lm is not None else None

    return _PretrainedFiles(
        lexicon=lexicon_path,
        tokens=tokens_path,
        lm=lm_path,
    )
.venv/lib/python3.11/site-packages/torchaudio/models/decoder/_cuda_ctc_decoder.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+
5
+ from typing import List, NamedTuple, Union
6
+
7
+ import torch
8
+ import torchaudio
9
+
10
+ torchaudio._extension._load_lib("libctc_prefix_decoder")
11
+ import torchaudio.lib.pybind11_prefixctc as cuctc
12
+
13
+
14
+ __all__ = ["CUCTCHypothesis", "CUCTCDecoder", "cuda_ctc_decoder"]
15
+
16
+
17
+ def _get_vocab_list(vocab_file):
18
+ vocab = []
19
+ with open(vocab_file, "r", encoding="utf-8") as f:
20
+ for line in f:
21
+ line = line.strip().split()
22
+ vocab.append(line[0])
23
+ return vocab
24
+
25
+
26
class CUCTCHypothesis(NamedTuple):
    r"""Represents hypothesis generated by CUCTC beam search decoder :class:`CUCTCDecoder`."""
    tokens: List[int]
    """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""

    words: List[str]
    """List of predicted tokens. Align with modeling unit.
    """

    score: float
    """Score corresponding to hypothesis"""
37
+
38
+
39
# Default probability threshold above which blank-dominated frames are skipped
# during decoding.  NOTE(review): the name carries an upstream "THREASHOLD"
# typo; it is kept because other code may reference it.
_DEFAULT_BLANK_SKIP_THREASHOLD = 0.95
40
+
41
+
42
class CUCTCDecoder:
    """CUDA CTC beam search decoder.

    .. devices:: CUDA

    Note:
        To build the decoder, please use the factory function :func:`cuda_ctc_decoder`.
    """

    def __init__(
        self,
        vocab_list: List[str],
        blank_id: int = 0,
        beam_size: int = 10,
        nbest: int = 1,
        blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
        cuda_stream: torch.cuda.streams.Stream = None,
    ):
        """
        Args:
            blank_id (int): token id corresponding to blank, only support 0 for now. (Default: 0)
            vocab_list (List[str]): list of vocabulary tokens
            beam_size (int, optional): max number of hypos to hold after each decode step (Default: 10)
            nbest (int): number of best decodings to return
            blank_skip_threshold (float):
                skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding.
                (Default: 0.95).
            cuda_stream (torch.cuda.streams.Stream): using assigned cuda stream (Default: using default stream)

        """
        if cuda_stream:
            if not isinstance(cuda_stream, torch.cuda.streams.Stream):
                raise AssertionError("cuda_stream must be torch.cuda.streams.Stream")
        cuda_stream_ = cuda_stream.cuda_stream if cuda_stream else torch.cuda.current_stream().cuda_stream
        # Opaque handle to the native decoder state; released in __del__.
        self.internal_data = cuctc.prefixCTC_alloc(cuda_stream_)
        # Device-side workspace buffer; grown on demand in __call__.
        self.memory = torch.empty(0, dtype=torch.int8, device=torch.device("cuda"))
        if blank_id != 0:
            raise AssertionError("blank_id must be 0")
        self.blank_id = blank_id
        self.vocab_list = vocab_list
        self.space_id = 0
        self.nbest = nbest
        if not (blank_skip_threshold >= 0 and blank_skip_threshold <= 1):
            raise AssertionError("blank_skip_threshold must be between 0 and 1")
        # Stored in the log domain so it can be compared directly with log-probs.
        self.blank_skip_threshold = math.log(blank_skip_threshold)
        self.beam_size = min(beam_size, len(vocab_list))  # beam size must be smaller than vocab size

    def __del__(self):
        # `cuctc` may already have been torn down during interpreter shutdown.
        if cuctc is not None:
            cuctc.prefixCTC_free(self.internal_data)

    def __call__(self, log_prob: torch.Tensor, encoder_out_lens: torch.Tensor):
        """
        Args:
            log_prob (torch.FloatTensor): GPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                probability distribution over labels; log_softmax(output of acoustic model).
            encoder_out_lens (dtype torch.int32): GPU tensor of shape `(batch, )` storing the valid length of
                in time axis of the output Tensor in each batch.

        Returns:
            List[List[CUCTCHypothesis]]:
                List of sorted best hypotheses for each audio sequence in the batch.
        """
        # The native decoder consumes raw device pointers, so dtype, device,
        # and layout must be validated up front.
        if not encoder_out_lens.dtype == torch.int32:
            raise AssertionError("encoder_out_lens must be torch.int32")
        if not log_prob.dtype == torch.float32:
            raise AssertionError("log_prob must be torch.float32")
        if not (log_prob.is_cuda and encoder_out_lens.is_cuda):
            raise AssertionError("inputs must be cuda tensors")
        if not (log_prob.is_contiguous() and encoder_out_lens.is_contiguous()):
            raise AssertionError("input tensors must be contiguous")
        # First pass: the native call reports (via required_size) whether the
        # current workspace buffer is large enough.
        required_size, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
            self.internal_data,
            self.memory.data_ptr(),
            self.memory.size(0),
            log_prob.data_ptr(),
            encoder_out_lens.data_ptr(),
            log_prob.size(),
            log_prob.stride(),
            self.beam_size,
            self.blank_id,
            self.space_id,
            self.blank_skip_threshold,
        )
        if required_size > 0:
            # Grow the workspace and rerun the decode with sufficient memory.
            self.memory = torch.empty(required_size, dtype=torch.int8, device=log_prob.device).contiguous()
            _, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
                self.internal_data,
                self.memory.data_ptr(),
                self.memory.size(0),
                log_prob.data_ptr(),
                encoder_out_lens.data_ptr(),
                log_prob.size(),
                log_prob.stride(),
                self.beam_size,
                self.blank_id,
                self.space_id,
                self.blank_skip_threshold,
            )
        # score_hyps[i][j] is a (score, token_ids) pair; presumably sorted
        # best-first since the first `nbest` entries are taken -- confirm
        # against the native decoder's contract.
        batch_size = len(score_hyps)
        hypos = []
        for i in range(batch_size):
            hypos.append(
                [
                    CUCTCHypothesis(
                        tokens=score_hyps[i][j][1],
                        words=[self.vocab_list[word_id] for word_id in score_hyps[i][j][1]],
                        score=score_hyps[i][j][0],
                    )
                    for j in range(self.nbest)
                ]
            )
        return hypos
155
+
156
+
157
def cuda_ctc_decoder(
    tokens: Union[str, List[str]],
    nbest: int = 1,
    beam_size: int = 10,
    blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
) -> CUCTCDecoder:
    """Builds an instance of :class:`CUCTCDecoder`.

    Args:
        tokens (str or List[str]): File or list containing valid tokens.
            If using a file, the expected format is for tokens mapping to the same index to be on the same line
        nbest (int): The number of best decodings to return
        beam_size (int, optional): The maximum number of hypos to hold after each decode step (Default: 10)
        blank_skip_threshold (float): skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding
            (Default: 0.95).

    Returns:
        CUCTCDecoder: decoder

    Example
        >>> decoder = cuda_ctc_decoder(
        >>>     tokens="tokens.txt",
        >>>     blank_skip_threshold=0.95,
        >>> )
        >>> results = decoder(log_probs, encoder_out_lens) # List of shape (B, nbest) of Hypotheses
    """
    # Accept either a path to a token file or an in-memory token list.
    if isinstance(tokens, str):
        tokens = _get_vocab_list(tokens)

    return CUCTCDecoder(vocab_list=tokens, beam_size=beam_size, nbest=nbest, blank_skip_threshold=blank_skip_threshold)
.venv/lib/python3.11/site-packages/torchaudio/models/deepspeech.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ __all__ = ["DeepSpeech"]
4
+
5
+
6
class FullyConnected(torch.nn.Module):
    """Single linear layer followed by a clipped ReLU and optional dropout.

    Args:
        n_feature: Number of input features
        n_hidden: Internal hidden unit size.
        dropout: Dropout probability applied after the activation (0 disables it).
        relu_max_clip: Upper bound applied to the activation output. (Default: 20)
    """

    def __init__(self, n_feature: int, n_hidden: int, dropout: float, relu_max_clip: int = 20) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
        self.relu_max_clip = relu_max_clip
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.fc(x)
        # Clipped ReLU: rectify, then clamp the result into [0, relu_max_clip].
        out = torch.nn.functional.relu(out)
        out = torch.nn.functional.hardtanh(out, 0, self.relu_max_clip)
        if self.dropout:
            out = torch.nn.functional.dropout(out, self.dropout, self.training)
        return out
+ return x
26
+
27
+
28
class DeepSpeech(torch.nn.Module):
    """DeepSpeech architecture introduced in
    *Deep Speech: Scaling up end-to-end speech recognition* :cite:`hannun2014deep`.

    Args:
        n_feature: Number of input features
        n_hidden: Internal hidden unit size.
        n_class: Number of output classes
    """

    def __init__(
        self,
        n_feature: int,
        n_hidden: int = 2048,
        n_class: int = 40,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.n_hidden = n_hidden
        # Three fully connected layers, one bidirectional RNN, then two more
        # layers projecting down to the class logits.
        self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
        self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
        self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
        self.bi_rnn = torch.nn.RNN(n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True)
        self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
        self.out = torch.nn.Linear(n_hidden, n_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).
        Returns:
            Tensor: Predictor tensor of dimension (batch, time, class).
        """
        # (N, C, T, F) -> (N, C, T, H) through the three feed-forward layers.
        hidden = self.fc3(self.fc2(self.fc1(x)))
        # Drop the channel dimension and move time first for the RNN: (T, N, H).
        hidden = hidden.squeeze(1).transpose(0, 1)
        rnn_out, _ = self.bi_rnn(hidden)
        # The fifth (non-recurrent) layer takes both the forward and backward units as inputs.
        combined = rnn_out[:, :, : self.n_hidden] + rnn_out[:, :, self.n_hidden :]
        # (T, N, H) -> (T, N, n_class).
        logits = self.out(self.fc4(combined))
        # Back to batch-first, then normalize over classes: (N, T, n_class).
        return torch.nn.functional.log_softmax(logits.permute(1, 0, 2), dim=2)
.venv/lib/python3.11/site-packages/torchaudio/models/emformer.py ADDED
@@ -0,0 +1,884 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Tuple
3
+
4
+ import torch
5
+
6
+
7
+ __all__ = ["Emformer"]
8
+
9
+
10
+ def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
11
+ batch_size = lengths.shape[0]
12
+ max_length = int(torch.max(lengths).item())
13
+ padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
14
+ batch_size, max_length
15
+ ) >= lengths.unsqueeze(1)
16
+ return padding_mask
17
+
18
+
19
def _gen_padding_mask(
    utterance: torch.Tensor,
    right_context: torch.Tensor,
    summary: torch.Tensor,
    lengths: torch.Tensor,
    mems: torch.Tensor,
    left_context_key: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    # Build the key-side padding mask for Emformer attention, or return None
    # when batch size is 1 (a single sequence has no padding to mask).
    T = right_context.size(0) + utterance.size(0) + summary.size(0)
    B = right_context.size(1)
    if B == 1:
        padding_mask = None
    else:
        # Right-context frame count = total length minus longest utterance and summary.
        right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
        left_context_blocks_length = left_context_key.size(0) if left_context_key is not None else 0
        # Per-batch key length: valid utterance frames plus memory, right-context,
        # and left-context blocks (the latter are never padded).
        klengths = lengths + mems.size(0) + right_context_blocks_length + left_context_blocks_length
        padding_mask = _lengths_to_padding_mask(lengths=klengths)
    return padding_mask
37
+
38
+
39
+ def _get_activation_module(activation: str) -> torch.nn.Module:
40
+ if activation == "relu":
41
+ return torch.nn.ReLU()
42
+ elif activation == "gelu":
43
+ return torch.nn.GELU()
44
+ elif activation == "silu":
45
+ return torch.nn.SiLU()
46
+ else:
47
+ raise ValueError(f"Unsupported activation {activation}")
48
+
49
+
50
+ def _get_weight_init_gains(weight_init_scale_strategy: Optional[str], num_layers: int) -> List[Optional[float]]:
51
+ if weight_init_scale_strategy is None:
52
+ return [None for _ in range(num_layers)]
53
+ elif weight_init_scale_strategy == "depthwise":
54
+ return [1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)]
55
+ elif weight_init_scale_strategy == "constant":
56
+ return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
57
+ else:
58
+ raise ValueError(f"Unsupported weight_init_scale_strategy value {weight_init_scale_strategy}")
59
+
60
+
61
+ def _gen_attention_mask_block(
62
+ col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device
63
+ ) -> torch.Tensor:
64
+ if len(col_widths) != len(col_mask):
65
+ raise ValueError("Length of col_widths must match that of col_mask")
66
+
67
+ mask_block = [
68
+ torch.ones(num_rows, col_width, device=device)
69
+ if is_ones_col
70
+ else torch.zeros(num_rows, col_width, device=device)
71
+ for col_width, is_ones_col in zip(col_widths, col_mask)
72
+ ]
73
+ return torch.cat(mask_block, dim=1)
74
+
75
+
76
+ class _EmformerAttention(torch.nn.Module):
77
+ r"""Emformer layer attention module.
78
+
79
+ Args:
80
+ input_dim (int): input dimension.
81
+ num_heads (int): number of attention heads in each Emformer layer.
82
+ dropout (float, optional): dropout probability. (Default: 0.0)
83
+ weight_init_gain (float or None, optional): scale factor to apply when initializing
84
+ attention module parameters. (Default: ``None``)
85
+ tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
86
+ negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
87
+ """
88
+
89
    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        weight_init_gain: Optional[float] = None,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        super().__init__()

        if input_dim % num_heads != 0:
            raise ValueError(f"input_dim ({input_dim}) is not a multiple of num_heads ({num_heads}).")

        self.input_dim = input_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.tanh_on_mem = tanh_on_mem
        self.negative_inf = negative_inf

        # Scaled dot-product attention factor: 1 / sqrt(head_dim).
        self.scaling = (self.input_dim // self.num_heads) ** -0.5

        # Keys and values come from a single linear layer; its output is
        # split into two halves by the callers (chunk along the last dim).
        self.emb_to_key_value = torch.nn.Linear(input_dim, 2 * input_dim, bias=True)
        self.emb_to_query = torch.nn.Linear(input_dim, input_dim, bias=True)
        self.out_proj = torch.nn.Linear(input_dim, input_dim, bias=True)

        # NOTE(review): a gain of 0 (falsy) also skips re-initialization here --
        # presumably intentional, since 0 is not a useful Xavier gain.
        if weight_init_gain:
            torch.nn.init.xavier_uniform_(self.emb_to_key_value.weight, gain=weight_init_gain)
            torch.nn.init.xavier_uniform_(self.emb_to_query.weight, gain=weight_init_gain)
118
+
119
    def _gen_key_value(self, input: torch.Tensor, mems: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute key/value over [mems, right context, utterance].

        The trailing ``mems.size(0) + 1`` rows of ``input`` are excluded before
        projection.  NOTE(review): this assumes the excluded tail equals the
        summary block length -- confirm against the call sites.
        """
        T, _, _ = input.shape
        summary_length = mems.size(0) + 1
        right_ctx_utterance_block = input[: T - summary_length]
        mems_right_ctx_utterance_block = torch.cat([mems, right_ctx_utterance_block])
        # Single projection produces both key and value; split along feature dim.
        key, value = self.emb_to_key_value(mems_right_ctx_utterance_block).chunk(chunks=2, dim=2)
        return key, value
126
+
127
    def _gen_attention_probs(
        self,
        attention_weights: torch.Tensor,
        attention_mask: torch.Tensor,
        padding_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Apply attention/padding masks, softmax, and dropout to raw weights.

        Masking and softmax run in float32; the result is cast back to the
        input dtype before dropout.
        """
        attention_weights_float = attention_weights.float()
        attention_weights_float = attention_weights_float.masked_fill(attention_mask.unsqueeze(0), self.negative_inf)
        T = attention_weights.size(1)
        B = attention_weights.size(0) // self.num_heads
        if padding_mask is not None:
            # Expand to (B, num_heads, T, key_len) so the per-batch padding
            # mask broadcasts over heads and query positions.
            attention_weights_float = attention_weights_float.view(B, self.num_heads, T, -1)
            attention_weights_float = attention_weights_float.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf
            )
            attention_weights_float = attention_weights_float.view(B * self.num_heads, T, -1)
        attention_probs = torch.nn.functional.softmax(attention_weights_float, dim=-1).type_as(attention_weights)
        return torch.nn.functional.dropout(attention_probs, p=float(self.dropout), training=self.training)
145
+
146
    def _forward_impl(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: torch.Tensor,
        left_context_key: Optional[torch.Tensor] = None,
        left_context_val: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Core attention computation shared by training and streaming paths.

        Returns ``(output_right_context_utterance, output_mems, key, value)``.
        NOTE(review): key/value are returned presumably so streaming callers
        can cache them as left context for the next chunk -- confirm at call sites.
        """
        B = utterance.size(1)
        T = right_context.size(0) + utterance.size(0) + summary.size(0)

        # Compute query with [right context, utterance, summary].
        query = self.emb_to_query(torch.cat([right_context, utterance, summary]))

        # Compute key and value with [mems, right context, utterance].
        key, value = self.emb_to_key_value(torch.cat([mems, right_context, utterance])).chunk(chunks=2, dim=2)

        if left_context_key is not None and left_context_val is not None:
            right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
            # Splice the cached left-context keys/values between the
            # [mems, right context] prefix and the utterance suffix.
            key = torch.cat(
                [
                    key[: mems.size(0) + right_context_blocks_length],
                    left_context_key,
                    key[mems.size(0) + right_context_blocks_length :],
                ],
            )
            value = torch.cat(
                [
                    value[: mems.size(0) + right_context_blocks_length],
                    left_context_val,
                    value[mems.size(0) + right_context_blocks_length :],
                ],
            )

        # Compute attention weights from query, key, and value.
        # Reshape each to (B * num_heads, seq, head_dim) for batched matmul.
        reshaped_query, reshaped_key, reshaped_value = [
            tensor.contiguous().view(-1, B * self.num_heads, self.input_dim // self.num_heads).transpose(0, 1)
            for tensor in [query, key, value]
        ]
        attention_weights = torch.bmm(reshaped_query * self.scaling, reshaped_key.transpose(1, 2))

        # Compute padding mask.
        padding_mask = _gen_padding_mask(utterance, right_context, summary, lengths, mems, left_context_key)

        # Compute attention probabilities.
        attention_probs = self._gen_attention_probs(attention_weights, attention_mask, padding_mask)

        # Compute attention.
        attention = torch.bmm(attention_probs, reshaped_value)
        if attention.shape != (
            B * self.num_heads,
            T,
            self.input_dim // self.num_heads,
        ):
            raise AssertionError("Computed attention has incorrect dimensions")
        attention = attention.transpose(0, 1).contiguous().view(T, B, self.input_dim)

        # Apply output projection.
        output_right_context_mems = self.out_proj(attention)

        # Split projection output: leading rows are right-context/utterance,
        # trailing `summary_length` rows become the new memory vectors.
        summary_length = summary.size(0)
        output_right_context = output_right_context_mems[: T - summary_length]
        output_mems = output_right_context_mems[T - summary_length :]
        if self.tanh_on_mem:
            output_mems = torch.tanh(output_mems)
        else:
            # Without tanh, clamp the memory values to keep them bounded.
            output_mems = torch.clamp(output_mems, min=-10, max=10)

        return output_right_context, output_mems, key, value
218
+
219
+ def forward(
220
+ self,
221
+ utterance: torch.Tensor,
222
+ lengths: torch.Tensor,
223
+ right_context: torch.Tensor,
224
+ summary: torch.Tensor,
225
+ mems: torch.Tensor,
226
+ attention_mask: torch.Tensor,
227
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
228
+ r"""Forward pass for training.
229
+
230
+ B: batch size;
231
+ D: feature dimension of each frame;
232
+ T: number of utterance frames;
233
+ R: number of right context frames;
234
+ S: number of summary elements;
235
+ M: number of memory elements.
236
+
237
+ Args:
238
+ utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
239
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
240
+ number of valid frames for i-th batch element in ``utterance``.
241
+ right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
242
+ summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
243
+ mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
244
+ attention_mask (torch.Tensor): attention mask for underlying attention module.
245
+
246
+ Returns:
247
+ (Tensor, Tensor):
248
+ Tensor
249
+ output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
250
+ Tensor
251
+ updated memory elements, with shape `(M, B, D)`.
252
+ """
253
+ output, output_mems, _, _ = self._forward_impl(utterance, lengths, right_context, summary, mems, attention_mask)
254
+ return output, output_mems[:-1]
255
+
256
    @torch.jit.export
    def infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        mems: torch.Tensor,
        left_context_key: torch.Tensor,
        left_context_val: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass for inference.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        S: number of summary elements;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
            left_context_key (torch.Tensor): left context attention key computed from preceding invocation.
            left_context_val (torch.Tensor): left context attention value computed from preceding invocation.

        Returns:
            (Tensor, Tensor, Tensor, and Tensor):
                Tensor
                    output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
                Tensor
                    attention key computed for left context and utterance.
                Tensor
                    attention value computed for left context and utterance.
        """
        # Inference mask covers queries [right context, utterance, summary]
        # over keys [mems, right context, left context, utterance].  Only the
        # summary query (last row) is flagged over the memory columns.
        # NOTE(review): True entries appear to mark positions that are masked
        # out (filled with ``negative_inf``) — confirm against ``_gen_attention_probs``.
        query_dim = right_context.size(0) + utterance.size(0) + summary.size(0)
        key_dim = right_context.size(0) + utterance.size(0) + mems.size(0) + left_context_key.size(0)
        attention_mask = torch.zeros(query_dim, key_dim).to(dtype=torch.bool, device=utterance.device)
        attention_mask[-1, : mems.size(0)] = True
        output, output_mems, key, value = self._forward_impl(
            utterance,
            lengths,
            right_context,
            summary,
            mems,
            attention_mask,
            left_context_key=left_context_key,
            left_context_val=left_context_val,
        )
        # Strip the leading [mems, right context] rows from key/value so the
        # caller caches only the [left context, utterance] portion as the
        # next segment's left-context state.
        return (
            output,
            output_mems,
            key[mems.size(0) + right_context.size(0) :],
            value[mems.size(0) + right_context.size(0) :],
        )
317
+
318
+
319
class _EmformerLayer(torch.nn.Module):
    r"""Emformer layer that constitutes Emformer.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads.
        ffn_dim: (int): hidden layer dimension of feedforward network.
        segment_length (int): length of each input segment.
        dropout (float, optional): dropout probability. (Default: 0.0)
        activation (str, optional): activation function to use in feedforward network.
            Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        left_context_length (int, optional): length of left context. (Default: 0)
        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        weight_init_gain (float or None, optional): scale factor to apply when initializing
            attention module parameters. (Default: ``None``)
        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        segment_length: int,
        dropout: float = 0.0,
        activation: str = "relu",
        left_context_length: int = 0,
        max_memory_size: int = 0,
        weight_init_gain: Optional[float] = None,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        super().__init__()

        self.attention = _EmformerAttention(
            input_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            weight_init_gain=weight_init_gain,
            tanh_on_mem=tanh_on_mem,
            negative_inf=negative_inf,
        )
        self.dropout = torch.nn.Dropout(dropout)
        # Segment-wise average pooling over the time axis; used to derive
        # summary elements from utterance frames.
        self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)

        # Position-wise feedforward block (pre-norm style: LayerNorm first).
        activation_module = _get_activation_module(activation)
        self.pos_ff = torch.nn.Sequential(
            torch.nn.LayerNorm(input_dim),
            torch.nn.Linear(input_dim, ffn_dim),
            activation_module,
            torch.nn.Dropout(dropout),
            torch.nn.Linear(ffn_dim, input_dim),
            torch.nn.Dropout(dropout),
        )
        self.layer_norm_input = torch.nn.LayerNorm(input_dim)
        self.layer_norm_output = torch.nn.LayerNorm(input_dim)

        self.left_context_length = left_context_length
        self.segment_length = segment_length
        self.max_memory_size = max_memory_size
        self.input_dim = input_dim

        # Memory bank is active only when a positive capacity is configured.
        self.use_mem = max_memory_size > 0

    def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
        # Fresh streaming state, fully zero-padded to fixed sizes:
        # [0] memory bank, [1] left-context key, [2] left-context value,
        # [3] number of frames processed so far (scalar per batch element).
        empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device)
        left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
        left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
        past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
        return [empty_memory, left_context_key, left_context_val, past_length]

    def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Trim the fixed-size state tensors down to the portion that is
        # actually populated, based on how many frames have been processed.
        past_length = state[3][0][0].item()
        past_left_context_length = min(self.left_context_length, past_length)
        past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
        # Valid entries live at the tail; the head is zero padding.
        pre_mems = state[0][self.max_memory_size - past_mem_length :]
        lc_key = state[1][self.left_context_length - past_left_context_length :]
        lc_val = state[2][self.left_context_length - past_left_context_length :]
        return pre_mems, lc_key, lc_val

    def _pack_state(
        self,
        next_k: torch.Tensor,
        next_v: torch.Tensor,
        update_length: int,
        mems: torch.Tensor,
        state: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        # Roll the caches: append new key/value/memory and keep only the most
        # recent entries up to the configured capacities.
        new_k = torch.cat([state[1], next_k])
        new_v = torch.cat([state[2], next_v])
        state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
        state[1] = new_k[new_k.shape[0] - self.left_context_length :]
        state[2] = new_v[new_v.shape[0] - self.left_context_length :]
        state[3] = state[3] + update_length
        return state

    def _process_attention_output(
        self,
        rc_output: torch.Tensor,
        utterance: torch.Tensor,
        right_context: torch.Tensor,
    ) -> torch.Tensor:
        # Residual connection around attention, then residual feedforward
        # block, then final layer norm.
        result = self.dropout(rc_output) + torch.cat([right_context, utterance])
        result = self.pos_ff(result) + result
        result = self.layer_norm_output(result)
        return result

    def _apply_pre_attention_layer_norm(
        self, utterance: torch.Tensor, right_context: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Normalize [right context, utterance] jointly, then split back into
        # (utterance, right_context) — note the returned order is swapped
        # relative to the concatenation order.
        layer_norm_input = self.layer_norm_input(torch.cat([right_context, utterance]))
        return (
            layer_norm_input[right_context.size(0) :],
            layer_norm_input[: right_context.size(0)],
        )

    def _apply_post_attention_ffn(
        self, rc_output: torch.Tensor, utterance: torch.Tensor, right_context: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Returns (utterance part, right-context part) of the processed output.
        rc_output = self._process_attention_output(rc_output, utterance, right_context)
        return rc_output[right_context.size(0) :], rc_output[: right_context.size(0)]

    def _apply_attention_forward(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Training-time attention: summary elements are derived by pooling the
        # whole (normalized) utterance; the caller supplies the attention mask.
        if attention_mask is None:
            raise ValueError("attention_mask must be not None when for_inference is False")

        if self.use_mem:
            summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
        else:
            summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
        rc_output, next_m = self.attention(
            utterance=utterance,
            lengths=lengths,
            right_context=right_context,
            summary=summary,
            mems=mems,
            attention_mask=attention_mask,
        )
        return rc_output, next_m

    def _apply_attention_infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        mems: torch.Tensor,
        state: Optional[List[torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        # Streaming attention: unpack cached state, attend with at most one
        # summary element (a single segment per call), then repack state.
        if state is None:
            state = self._init_state(utterance.size(1), device=utterance.device)
        pre_mems, lc_key, lc_val = self._unpack_state(state)
        if self.use_mem:
            summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
            summary = summary[:1]
        else:
            summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
        rc_output, next_m, next_k, next_v = self.attention.infer(
            utterance=utterance,
            lengths=lengths,
            right_context=right_context,
            summary=summary,
            mems=pre_mems,
            left_context_key=lc_key,
            left_context_val=lc_val,
        )
        state = self._pack_state(next_k, next_v, utterance.size(0), mems, state)
        return rc_output, next_m, state

    def forward(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
            attention_mask (torch.Tensor): attention mask for underlying attention module.

        Returns:
            (Tensor, Tensor, Tensor):
                Tensor
                    encoded utterance frames, with shape `(T, B, D)`.
                Tensor
                    updated right context frames, with shape `(R, B, D)`.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
        """
        (
            layer_norm_utterance,
            layer_norm_right_context,
        ) = self._apply_pre_attention_layer_norm(utterance, right_context)
        rc_output, output_mems = self._apply_attention_forward(
            layer_norm_utterance,
            lengths,
            layer_norm_right_context,
            mems,
            attention_mask,
        )
        # Residual/FFN uses the *un-normalized* inputs (pre-norm architecture).
        output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
        return output_utterance, output_right_context, output_mems

    @torch.jit.export
    def infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        state: Optional[List[torch.Tensor]],
        mems: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
        r"""Forward pass for inference.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            state (List[torch.Tensor] or None): list of tensors representing layer internal state
                generated in preceding invocation of ``infer``.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.

        Returns:
            (Tensor, Tensor, List[torch.Tensor], Tensor):
                Tensor
                    encoded utterance frames, with shape `(T, B, D)`.
                Tensor
                    updated right context frames, with shape `(R, B, D)`.
                List[Tensor]
                    list of tensors representing layer internal state
                    generated in current invocation of ``infer``.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
        """
        (
            layer_norm_utterance,
            layer_norm_right_context,
        ) = self._apply_pre_attention_layer_norm(utterance, right_context)
        rc_output, output_mems, output_state = self._apply_attention_infer(
            layer_norm_utterance, lengths, layer_norm_right_context, mems, state
        )
        # Residual/FFN uses the *un-normalized* inputs (pre-norm architecture).
        output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
        return output_utterance, output_right_context, output_state, output_mems
589
+
590
+
591
class _EmformerImpl(torch.nn.Module):
    # Base implementation of the Emformer encoder stack.  Owns the stack of
    # Emformer layers plus segment/context bookkeeping, and provides the
    # training-time ``forward`` and streaming ``infer`` entry points.

    def __init__(
        self,
        emformer_layers: torch.nn.ModuleList,
        segment_length: int,
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
    ):
        super().__init__()

        self.use_mem = max_memory_size > 0
        # Segment-wise average pooling over time, used to derive memory elements.
        self.memory_op = torch.nn.AvgPool1d(
            kernel_size=segment_length,
            stride=segment_length,
            ceil_mode=True,
        )
        self.emformer_layers = emformer_layers
        self.left_context_length = left_context_length
        self.right_context_length = right_context_length
        self.segment_length = segment_length
        self.max_memory_size = max_memory_size

    def _gen_right_context(self, input: torch.Tensor) -> torch.Tensor:
        # Hard-copy each segment's right-context frames (the frames that
        # immediately follow it) into one contiguous tensor; the final segment
        # uses the trailing right_context_length frames of the padded input.
        T = input.shape[0]
        num_segs = math.ceil((T - self.right_context_length) / self.segment_length)
        right_context_blocks = []
        for seg_idx in range(num_segs - 1):
            start = (seg_idx + 1) * self.segment_length
            end = start + self.right_context_length
            right_context_blocks.append(input[start:end])
        right_context_blocks.append(input[T - self.right_context_length :])
        return torch.cat(right_context_blocks)

    def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) -> List[int]:
        # Column widths partitioning the attention key dimension for one query
        # segment.  The key layout is [memory, right context, utterance]; each
        # region is split into (before / visible / after) widths so mask blocks
        # can be assembled from ones and zeros.
        num_segs = math.ceil(utterance_length / self.segment_length)
        rc = self.right_context_length
        lc = self.left_context_length
        rc_start = seg_idx * rc
        rc_end = rc_start + rc
        # Query segment may additionally see up to ``lc`` frames of left context.
        seg_start = max(seg_idx * self.segment_length - lc, 0)
        seg_end = min((seg_idx + 1) * self.segment_length, utterance_length)
        rc_length = self.right_context_length * num_segs

        if self.use_mem:
            # Memory visibility is capped at ``max_memory_size`` most recent slots.
            m_start = max(seg_idx - self.max_memory_size, 0)
            mem_length = num_segs - 1
            col_widths = [
                m_start,  # before memory
                seg_idx - m_start,  # memory
                mem_length - seg_idx,  # after memory
                rc_start,  # before right context
                rc,  # right context
                rc_length - rc_end,  # after right context
                seg_start,  # before query segment
                seg_end - seg_start,  # query segment
                utterance_length - seg_end,  # after query segment
            ]
        else:
            col_widths = [
                rc_start,  # before right context
                rc,  # right context
                rc_length - rc_end,  # after right context
                seg_start,  # before query segment
                seg_end - seg_start,  # query segment
                utterance_length - seg_end,  # after query segment
            ]

        return col_widths

    def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor:
        # Build the full training-time attention mask, stacking one row-block
        # per query segment for the right-context queries, utterance queries,
        # and (when memory is enabled) summary queries.
        utterance_length = input.size(0)
        num_segs = math.ceil(utterance_length / self.segment_length)

        rc_mask = []
        query_mask = []
        summary_mask = []

        if self.use_mem:
            num_cols = 9
            # memory, right context, query segment
            rc_q_cols_mask = [idx in [1, 4, 7] for idx in range(num_cols)]
            # right context, query segment
            s_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
            masks_to_concat = [rc_mask, query_mask, summary_mask]
        else:
            num_cols = 6
            # right context, query segment
            rc_q_cols_mask = [idx in [1, 4] for idx in range(num_cols)]
            s_cols_mask = None
            masks_to_concat = [rc_mask, query_mask]

        for seg_idx in range(num_segs):
            col_widths = self._gen_attention_mask_col_widths(seg_idx, utterance_length)

            rc_mask_block = _gen_attention_mask_block(
                col_widths, rc_q_cols_mask, self.right_context_length, input.device
            )
            rc_mask.append(rc_mask_block)

            # Final segment may be shorter than segment_length.
            query_mask_block = _gen_attention_mask_block(
                col_widths,
                rc_q_cols_mask,
                min(
                    self.segment_length,
                    utterance_length - seg_idx * self.segment_length,
                ),
                input.device,
            )
            query_mask.append(query_mask_block)

            if s_cols_mask is not None:
                summary_mask_block = _gen_attention_mask_block(col_widths, s_cols_mask, 1, input.device)
                summary_mask.append(summary_mask_block)

        # NOTE(review): blocks presumably mark *allowed* positions with 1, so
        # the (1 - ...) inversion makes True mean "masked out" — confirm
        # against ``_gen_attention_mask_block``.
        attention_mask = (1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])).to(torch.bool)
        return attention_mask

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for training and non-streaming inference.

        B: batch size;
        T: max number of input frames in batch;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): utterance frames right-padded with right context frames, with
                shape `(B, T + right_context_length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid utterance frames for i-th batch element in ``input``.

        Returns:
            (Tensor, Tensor):
                Tensor
                    output frames, with shape `(B, T, D)`.
                Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output frames.
        """
        input = input.permute(1, 0, 2)  # (B, T, D) -> (T, B, D)
        right_context = self._gen_right_context(input)
        utterance = input[: input.size(0) - self.right_context_length]
        attention_mask = self._gen_attention_mask(utterance)
        # Initial memory bank: pooled segment summaries, excluding the last
        # segment's summary.
        mems = (
            self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1]
            if self.use_mem
            else torch.empty(0).to(dtype=input.dtype, device=input.device)
        )
        output = utterance
        for layer in self.emformer_layers:
            output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask)
        return output.permute(1, 0, 2), lengths

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for streaming inference.

        B: batch size;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): utterance frames right-padded with right context frames, with
                shape `(B, segment_length + right_context_length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.
            states (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing internal state generated in preceding invocation of ``infer``. (Default: ``None``)

        Returns:
            (Tensor, Tensor, List[List[Tensor]]):
                Tensor
                    output frames, with shape `(B, segment_length, D)`.
                Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output frames.
                List[List[Tensor]]
                    output states; list of lists of tensors representing internal state
                    generated in current invocation of ``infer``.
        """
        if input.size(1) != self.segment_length + self.right_context_length:
            raise ValueError(
                "Per configured segment_length and right_context_length"
                f", expected size of {self.segment_length + self.right_context_length} for dimension 1 of input"
                f", but got {input.size(1)}."
            )
        input = input.permute(1, 0, 2)  # (B, T, D) -> (T, B, D)
        right_context_start_idx = input.size(0) - self.right_context_length
        right_context = input[right_context_start_idx:]
        utterance = input[:right_context_start_idx]
        output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
        mems = (
            self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
            if self.use_mem
            else torch.empty(0).to(dtype=input.dtype, device=input.device)
        )
        output = utterance
        output_states: List[List[torch.Tensor]] = []
        for layer_idx, layer in enumerate(self.emformer_layers):
            output, right_context, output_state, mems = layer.infer(
                output,
                output_lengths,
                right_context,
                None if states is None else states[layer_idx],
                mems,
            )
            output_states.append(output_state)

        return output.permute(1, 0, 2), output_lengths, output_states
804
+
805
+
806
class Emformer(_EmformerImpl):
    r"""Emformer architecture introduced in
    *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition*
    :cite:`shi2021emformer`.

    See Also:
        * :func:`~torchaudio.models.emformer_rnnt_model`,
          :func:`~torchaudio.models.emformer_rnnt_base`: factory functions.
        * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipelines with pretrained model.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Emformer layer.
        ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        num_layers (int): number of Emformer layers to instantiate.
        segment_length (int): length of each input segment.
        dropout (float, optional): dropout probability. (Default: 0.0)
        activation (str, optional): activation function used in each layer's
            feedforward network; one of ("relu", "gelu", "silu"). (Default: "relu")
        left_context_length (int, optional): length of left context. (Default: 0)
        right_context_length (int, optional): length of right context. (Default: 0)
        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        weight_init_scale_strategy (str or None, optional): per-layer weight initialization
            scaling strategy; one of ("depthwise", "constant", ``None``). (Default: "depthwise")
        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)

    Examples:
        >>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
        >>> input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
        >>> lengths = torch.randint(1, 200, (128,))  # batch
        >>> output, lengths = emformer(input, lengths)
        >>> input = torch.rand(128, 5, 512)
        >>> lengths = torch.ones(128) * 5
        >>> output, lengths, states = emformer.infer(input, lengths, None)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        num_layers: int,
        segment_length: int,
        dropout: float = 0.0,
        activation: str = "relu",
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
        weight_init_scale_strategy: Optional[str] = "depthwise",
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        # Per-layer gains used to scale attention parameter initialization.
        gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
        layers = []
        for idx in range(num_layers):
            layers.append(
                _EmformerLayer(
                    input_dim,
                    num_heads,
                    ffn_dim,
                    segment_length,
                    dropout=dropout,
                    activation=activation,
                    left_context_length=left_context_length,
                    max_memory_size=max_memory_size,
                    weight_init_gain=gains[idx],
                    tanh_on_mem=tanh_on_mem,
                    negative_inf=negative_inf,
                )
            )
        super().__init__(
            torch.nn.ModuleList(layers),
            segment_length,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=max_memory_size,
        )
.venv/lib/python3.11/site-packages/torchaudio/models/rnnt.py ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Optional, Tuple
3
+
4
+ import torch
5
+ from torchaudio.models import Emformer
6
+
7
+
8
+ __all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]
9
+
10
+
11
+ class _TimeReduction(torch.nn.Module):
12
+ r"""Coalesces frames along time dimension into a
13
+ fewer number of frames with higher feature dimensionality.
14
+
15
+ Args:
16
+ stride (int): number of frames to merge for each output frame.
17
+ """
18
+
19
+ def __init__(self, stride: int) -> None:
20
+ super().__init__()
21
+ self.stride = stride
22
+
23
+ def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
24
+ r"""Forward pass.
25
+
26
+ B: batch size;
27
+ T: maximum input sequence length in batch;
28
+ D: feature dimension of each input sequence frame.
29
+
30
+ Args:
31
+ input (torch.Tensor): input sequences, with shape `(B, T, D)`.
32
+ lengths (torch.Tensor): with shape `(B,)` and i-th element representing
33
+ number of valid frames for i-th batch element in ``input``.
34
+
35
+ Returns:
36
+ (torch.Tensor, torch.Tensor):
37
+ torch.Tensor
38
+ output sequences, with shape
39
+ `(B, T // stride, D * stride)`
40
+ torch.Tensor
41
+ output lengths, with shape `(B,)` and i-th element representing
42
+ number of valid frames for i-th batch element in output sequences.
43
+ """
44
+ B, T, D = input.shape
45
+ num_frames = T - (T % self.stride)
46
+ input = input[:, :num_frames, :]
47
+ lengths = lengths.div(self.stride, rounding_mode="trunc")
48
+ T_max = num_frames // self.stride
49
+
50
+ output = input.reshape(B, T_max, D * self.stride)
51
+ output = output.contiguous()
52
+ return output, lengths
53
+
54
+
55
+ class _CustomLSTM(torch.nn.Module):
56
+ r"""Custom long-short-term memory (LSTM) block that applies layer normalization
57
+ to internal nodes.
58
+
59
+ Args:
60
+ input_dim (int): input dimension.
61
+ hidden_dim (int): hidden dimension.
62
+ layer_norm (bool, optional): if ``True``, enables layer normalization. (Default: ``False``)
63
+ layer_norm_epsilon (float, optional): value of epsilon to use in
64
+ layer normalization layers (Default: 1e-5)
65
+ """
66
+
67
+ def __init__(
68
+ self,
69
+ input_dim: int,
70
+ hidden_dim: int,
71
+ layer_norm: bool = False,
72
+ layer_norm_epsilon: float = 1e-5,
73
+ ) -> None:
74
+ super().__init__()
75
+ self.x2g = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=(not layer_norm))
76
+ self.p2g = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=False)
77
+ if layer_norm:
78
+ self.c_norm = torch.nn.LayerNorm(hidden_dim, eps=layer_norm_epsilon)
79
+ self.g_norm = torch.nn.LayerNorm(4 * hidden_dim, eps=layer_norm_epsilon)
80
+ else:
81
+ self.c_norm = torch.nn.Identity()
82
+ self.g_norm = torch.nn.Identity()
83
+
84
+ self.hidden_dim = hidden_dim
85
+
86
+ def forward(
87
+ self, input: torch.Tensor, state: Optional[List[torch.Tensor]]
88
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
89
+ r"""Forward pass.
90
+
91
+ B: batch size;
92
+ T: maximum sequence length in batch;
93
+ D: feature dimension of each input sequence element.
94
+
95
+ Args:
96
+ input (torch.Tensor): with shape `(T, B, D)`.
97
+ state (List[torch.Tensor] or None): list of tensors
98
+ representing internal state generated in preceding invocation
99
+ of ``forward``.
100
+
101
+ Returns:
102
+ (torch.Tensor, List[torch.Tensor]):
103
+ torch.Tensor
104
+ output, with shape `(T, B, hidden_dim)`.
105
+ List[torch.Tensor]
106
+ list of tensors representing internal state generated
107
+ in current invocation of ``forward``.
108
+ """
109
+ if state is None:
110
+ B = input.size(1)
111
+ h = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
112
+ c = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
113
+ else:
114
+ h, c = state
115
+
116
+ gated_input = self.x2g(input)
117
+ outputs = []
118
+ for gates in gated_input.unbind(0):
119
+ gates = gates + self.p2g(h)
120
+ gates = self.g_norm(gates)
121
+ input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1)
122
+ input_gate = input_gate.sigmoid()
123
+ forget_gate = forget_gate.sigmoid()
124
+ cell_gate = cell_gate.tanh()
125
+ output_gate = output_gate.sigmoid()
126
+ c = forget_gate * c + input_gate * cell_gate
127
+ c = self.c_norm(c)
128
+ h = output_gate * c.tanh()
129
+ outputs.append(h)
130
+
131
+ output = torch.stack(outputs, dim=0)
132
+ state = [h, c]
133
+
134
+ return output, state
135
+
136
+
137
class _Transcriber(ABC):
    """Interface that RNN-T transcription networks (encoders) must implement."""

    @abstractmethod
    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Non-streaming (training) pass over full, padded utterances.
        pass

    @abstractmethod
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        # Streaming pass over one segment, threading internal state between calls.
        pass
150
+
151
+
152
class _EmformerEncoder(torch.nn.Module, _Transcriber):
    r"""Emformer-based recurrent neural network transducer (RNN-T) encoder (transcription network).

    Pipeline: linear projection -> time reduction -> Emformer -> linear projection -> layer norm.

    Args:
        input_dim (int): feature dimension of each input sequence element.
        output_dim (int): feature dimension of each output sequence element.
        segment_length (int): length of input segment expressed as number of frames.
        right_context_length (int): length of right context expressed as number of frames.
        time_reduction_input_dim (int): dimension to scale each element in input sequences to
            prior to applying time reduction block.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_left_context_length (int): length of left context.
        transformer_dropout (float, optional): transformer dropout probability. (Default: 0.0)
        transformer_activation (str, optional): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        transformer_max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        transformer_weight_init_scale_strategy (str, optional): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
        transformer_tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
    """

    def __init__(
        self,
        *,
        input_dim: int,
        output_dim: int,
        segment_length: int,
        right_context_length: int,
        time_reduction_input_dim: int,
        time_reduction_stride: int,
        transformer_num_heads: int,
        transformer_ffn_dim: int,
        transformer_num_layers: int,
        transformer_left_context_length: int,
        transformer_dropout: float = 0.0,
        transformer_activation: str = "relu",
        transformer_max_memory_size: int = 0,
        transformer_weight_init_scale_strategy: str = "depthwise",
        transformer_tanh_on_mem: bool = False,
    ) -> None:
        super().__init__()
        self.input_linear = torch.nn.Linear(
            input_dim,
            time_reduction_input_dim,
            bias=False,
        )
        self.time_reduction = _TimeReduction(time_reduction_stride)
        # Time reduction concatenates `stride` frames, so the transformer sees
        # `time_reduction_input_dim * time_reduction_stride` features per frame.
        transformer_input_dim = time_reduction_input_dim * time_reduction_stride
        self.transformer = Emformer(
            transformer_input_dim,
            transformer_num_heads,
            transformer_ffn_dim,
            transformer_num_layers,
            # Segment and right-context lengths are given in pre-reduction frames;
            # scale them to the reduced frame rate.
            segment_length // time_reduction_stride,
            dropout=transformer_dropout,
            activation=transformer_activation,
            left_context_length=transformer_left_context_length,
            right_context_length=right_context_length // time_reduction_stride,
            max_memory_size=transformer_max_memory_size,
            weight_init_scale_strategy=transformer_weight_init_scale_strategy,
            tanh_on_mem=transformer_tanh_on_mem,
        )
        self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
        self.layer_norm = torch.nn.LayerNorm(output_dim)

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum input sequence length in batch;
        D: feature dimension of each input sequence frame (input_dim).

        Args:
            input (torch.Tensor): input frame sequences right-padded with right context, with
                shape `(B, T + right context length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output input lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output frame sequences.
        """
        input_linear_out = self.input_linear(input)
        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
        transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths)
        output_linear_out = self.output_linear(transformer_out)
        layer_norm_out = self.layer_norm(output_linear_out)
        return layer_norm_out, transformer_lengths

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for inference.

        B: batch size;
        T: maximum input sequence segment length in batch;
        D: feature dimension of each input sequence frame (input_dim).

        Args:
            input (torch.Tensor): input frame sequence segments right-padded with right context, with
                shape `(B, T + right context length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``infer``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output input lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation
                    of ``infer``.
        """
        # Same pipeline as forward(), but routed through Emformer's stateful
        # streaming entry point so attention context carries across segments.
        input_linear_out = self.input_linear(input)
        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
        (
            transformer_out,
            transformer_lengths,
            transformer_states,
        ) = self.transformer.infer(time_reduction_out, time_reduction_lengths, states)
        output_linear_out = self.output_linear(transformer_out)
        layer_norm_out = self.layer_norm(output_linear_out)
        return layer_norm_out, transformer_lengths, transformer_states
294
+
295
+
296
class _Predictor(torch.nn.Module):
    r"""Recurrent neural network transducer (RNN-T) prediction network.

    Embeds target tokens, normalizes, runs them through a stack of custom
    LSTM layers with inter-layer dropout, then projects and normalizes.

    Args:
        num_symbols (int): size of target token lexicon.
        output_dim (int): feature dimension of each output sequence element.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_hidden_dim (int): output dimension of each LSTM layer.
        lstm_layer_norm (bool, optional): if ``True``, enables layer normalization
            for LSTM layers. (Default: ``False``)
        lstm_layer_norm_epsilon (float, optional): value of epsilon to use in
            LSTM layer normalization layers. (Default: 1e-5)
        lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0)
    """

    def __init__(
        self,
        num_symbols: int,
        output_dim: int,
        symbol_embedding_dim: int,
        num_lstm_layers: int,
        lstm_hidden_dim: int,
        lstm_layer_norm: bool = False,
        lstm_layer_norm_epsilon: float = 1e-5,
        lstm_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim)
        self.input_layer_norm = torch.nn.LayerNorm(symbol_embedding_dim)
        lstm_stack = []
        for layer in range(num_lstm_layers):
            # First layer consumes embeddings; subsequent layers consume LSTM output.
            in_dim = symbol_embedding_dim if layer == 0 else lstm_hidden_dim
            lstm_stack.append(
                _CustomLSTM(
                    in_dim,
                    lstm_hidden_dim,
                    layer_norm=lstm_layer_norm,
                    layer_norm_epsilon=lstm_layer_norm_epsilon,
                )
            )
        self.lstm_layers = torch.nn.ModuleList(lstm_stack)
        self.dropout = torch.nn.Dropout(p=lstm_dropout)
        self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim)
        self.output_layer_norm = torch.nn.LayerNorm(output_dim)

        self.lstm_dropout = lstm_dropout

    def forward(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass.

        B: batch size;
        U: maximum sequence length in batch;
        D: feature dimension of each input sequence element.

        Args:
            input (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output encoding sequences, with shape `(B, U, output_dim)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output encoding sequences.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation of ``forward``.
        """
        # LSTMs are time-major: transpose (B, U) -> (U, B) before embedding.
        symbols_tb = input.permute(1, 0)
        features = self.input_layer_norm(self.embedding(symbols_tb))

        hidden = features
        out_states: List[List[torch.Tensor]] = []
        for idx, lstm in enumerate(self.lstm_layers):
            layer_state = None if state is None else state[idx]
            hidden, new_state = lstm(hidden, layer_state)
            hidden = self.dropout(hidden)
            out_states.append(new_state)

        encoded = self.output_layer_norm(self.linear(hidden))
        # Back to batch-major (B, U, output_dim); lengths are unchanged.
        return encoded.permute(1, 0, 2), lengths, out_states
390
+
391
+
392
+ class _Joiner(torch.nn.Module):
393
+ r"""Recurrent neural network transducer (RNN-T) joint network.
394
+
395
+ Args:
396
+ input_dim (int): source and target input dimension.
397
+ output_dim (int): output dimension.
398
+ activation (str, optional): activation function to use in the joiner.
399
+ Must be one of ("relu", "tanh"). (Default: "relu")
400
+
401
+ """
402
+
403
+ def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
404
+ super().__init__()
405
+ self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
406
+ if activation == "relu":
407
+ self.activation = torch.nn.ReLU()
408
+ elif activation == "tanh":
409
+ self.activation = torch.nn.Tanh()
410
+ else:
411
+ raise ValueError(f"Unsupported activation {activation}")
412
+
413
+ def forward(
414
+ self,
415
+ source_encodings: torch.Tensor,
416
+ source_lengths: torch.Tensor,
417
+ target_encodings: torch.Tensor,
418
+ target_lengths: torch.Tensor,
419
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
420
+ r"""Forward pass for training.
421
+
422
+ B: batch size;
423
+ T: maximum source sequence length in batch;
424
+ U: maximum target sequence length in batch;
425
+ D: dimension of each source and target sequence encoding.
426
+
427
+ Args:
428
+ source_encodings (torch.Tensor): source encoding sequences, with
429
+ shape `(B, T, D)`.
430
+ source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
431
+ valid sequence length of i-th batch element in ``source_encodings``.
432
+ target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
433
+ target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
434
+ valid sequence length of i-th batch element in ``target_encodings``.
435
+
436
+ Returns:
437
+ (torch.Tensor, torch.Tensor, torch.Tensor):
438
+ torch.Tensor
439
+ joint network output, with shape `(B, T, U, output_dim)`.
440
+ torch.Tensor
441
+ output source lengths, with shape `(B,)` and i-th element representing
442
+ number of valid elements along dim 1 for i-th batch element in joint network output.
443
+ torch.Tensor
444
+ output target lengths, with shape `(B,)` and i-th element representing
445
+ number of valid elements along dim 2 for i-th batch element in joint network output.
446
+ """
447
+ joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
448
+ activation_out = self.activation(joint_encodings)
449
+ output = self.linear(activation_out)
450
+ return output, source_lengths, target_lengths
451
+
452
+
453
class RNNT(torch.nn.Module):
    r"""torchaudio.models.RNNT()

    Recurrent neural network transducer (RNN-T) model.

    Thin composition of three sub-networks (transcriber, predictor, joiner);
    each public method delegates to one of them. The ``@torch.jit.export``
    methods form the streaming-inference API used by the beam-search decoder.

    Note:
        To build the model, please use one of the factory functions.

    See Also:
        :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pre-trained models.

    Args:
        transcriber (torch.nn.Module): transcription network.
        predictor (torch.nn.Module): prediction network.
        joiner (torch.nn.Module): joint network.
    """

    def __init__(self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner) -> None:
        super().__init__()
        self.transcriber = transcriber
        self.predictor = predictor
        self.joiner = joiner

    def forward(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        predictor_state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: feature dimension of each source sequence element.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``targets``.
            predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing prediction network internal state generated in preceding invocation
                of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    joint network output, with shape
                    `(B, max output source length, max output target length, output_dim (number of target symbols))`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing prediction network internal state generated in current invocation
                    of ``forward``.
        """
        # Encode audio and text independently, then combine in the joiner.
        source_encodings, source_lengths = self.transcriber(
            input=sources,
            lengths=source_lengths,
        )
        target_encodings, target_lengths, predictor_state = self.predictor(
            input=targets,
            lengths=target_lengths,
            state=predictor_state,
        )
        output, source_lengths, target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )

        return (
            output,
            source_lengths,
            target_lengths,
            predictor_state,
        )

    @torch.jit.export
    def transcribe_streaming(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies transcription network to sources in streaming mode.

        B: batch size;
        T: maximum source sequence segment length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequence segments right-padded with right context, with
                shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing transcription network internal state generated in preceding invocation
                of ``transcribe_streaming``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing transcription network internal state generated in current invocation
                    of ``transcribe_streaming``.
        """
        return self.transcriber.infer(sources, source_lengths, state)

    @torch.jit.export
    def transcribe(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Applies transcription network to sources in non-streaming mode.

        B: batch size;
        T: maximum source sequence length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context, with
                shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output frame sequences.
        """
        return self.transcriber(sources, source_lengths)

    @torch.jit.export
    def predict(
        self,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies prediction network to targets.

        B: batch size;
        U: maximum target sequence length in batch;
        D: feature dimension of each target sequence frame.

        Args:
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``targets``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``predict``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with shape `(B, U, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation of ``predict``.
        """
        return self.predictor(input=targets, lengths=target_lengths, state=state)

    @torch.jit.export
    def join(
        self,
        source_encodings: torch.Tensor,
        source_lengths: torch.Tensor,
        target_encodings: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Applies joint network to source and target encodings.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: dimension of each source and target sequence encoding.

        Args:
            source_encodings (torch.Tensor): source encoding sequences, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``source_encodings``.
            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``target_encodings``.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor):
                torch.Tensor
                    joint network output, with shape `(B, T, U, output_dim)`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
        """
        output, source_lengths, target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )
        return output, source_lengths, target_lengths
686
+
687
+
688
def emformer_rnnt_model(
    *,
    input_dim: int,
    encoding_dim: int,
    num_symbols: int,
    segment_length: int,
    right_context_length: int,
    time_reduction_input_dim: int,
    time_reduction_stride: int,
    transformer_num_heads: int,
    transformer_ffn_dim: int,
    transformer_num_layers: int,
    transformer_dropout: float,
    transformer_activation: str,
    transformer_left_context_length: int,
    transformer_max_memory_size: int,
    transformer_weight_init_scale_strategy: str,
    transformer_tanh_on_mem: bool,
    symbol_embedding_dim: int,
    num_lstm_layers: int,
    lstm_layer_norm: bool,
    lstm_layer_norm_epsilon: float,
    lstm_dropout: float,
) -> RNNT:
    r"""Builds Emformer-based :class:`~torchaudio.models.RNNT`.

    Note:
        For non-streaming inference, the expectation is for `transcribe` to be called on input
        sequences right-concatenated with `right_context_length` frames.

        For streaming inference, the expectation is for `transcribe_streaming` to be called
        on input chunks comprising `segment_length` frames right-concatenated with
        `right_context_length` frames.

    Args:
        input_dim (int): dimension of input sequence frames passed to transcription network.
        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
            passed to joint network.
        num_symbols (int): cardinality of set of target tokens.
        segment_length (int): length of input segment expressed as number of frames.
        right_context_length (int): length of right context expressed as number of frames.
        time_reduction_input_dim (int): dimension to scale each element in input sequences to
            prior to applying time reduction block.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_left_context_length (int): length of left context considered by Emformer.
        transformer_dropout (float): Emformer dropout probability.
        transformer_activation (str): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu").
        transformer_max_memory_size (int): maximum number of memory elements to use.
        transformer_weight_init_scale_strategy (str): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``).
        transformer_tanh_on_mem (bool): if ``True``, applies tanh to memory elements.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
        lstm_dropout (float): LSTM dropout probability.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    # Transcription network: audio frames -> encodings of dimension `encoding_dim`.
    transcription_net = _EmformerEncoder(
        input_dim=input_dim,
        output_dim=encoding_dim,
        segment_length=segment_length,
        right_context_length=right_context_length,
        time_reduction_input_dim=time_reduction_input_dim,
        time_reduction_stride=time_reduction_stride,
        transformer_num_heads=transformer_num_heads,
        transformer_ffn_dim=transformer_ffn_dim,
        transformer_num_layers=transformer_num_layers,
        transformer_dropout=transformer_dropout,
        transformer_activation=transformer_activation,
        transformer_left_context_length=transformer_left_context_length,
        transformer_max_memory_size=transformer_max_memory_size,
        transformer_weight_init_scale_strategy=transformer_weight_init_scale_strategy,
        transformer_tanh_on_mem=transformer_tanh_on_mem,
    )
    # Prediction network: target tokens -> encodings of dimension `encoding_dim`.
    # The LSTM hidden size intentionally matches the embedding size.
    prediction_net = _Predictor(
        num_symbols,
        encoding_dim,
        symbol_embedding_dim=symbol_embedding_dim,
        num_lstm_layers=num_lstm_layers,
        lstm_hidden_dim=symbol_embedding_dim,
        lstm_layer_norm=lstm_layer_norm,
        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
        lstm_dropout=lstm_dropout,
    )
    # Joint network: combines both encoding streams into token logits.
    joint_net = _Joiner(encoding_dim, num_symbols)
    return RNNT(transcription_net, prediction_net, joint_net)
782
+
783
+
784
def emformer_rnnt_base(num_symbols: int) -> RNNT:
    r"""Builds basic version of Emformer-based :class:`~torchaudio.models.RNNT`.

    Args:
        num_symbols (int): The size of target token lexicon.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    # Canonical "base" hyperparameter configuration; only the lexicon size varies.
    base_config = {
        "input_dim": 80,
        "encoding_dim": 1024,
        "segment_length": 16,
        "right_context_length": 4,
        "time_reduction_input_dim": 128,
        "time_reduction_stride": 4,
        "transformer_num_heads": 8,
        "transformer_ffn_dim": 2048,
        "transformer_num_layers": 20,
        "transformer_dropout": 0.1,
        "transformer_activation": "gelu",
        "transformer_left_context_length": 30,
        "transformer_max_memory_size": 0,
        "transformer_weight_init_scale_strategy": "depthwise",
        "transformer_tanh_on_mem": True,
        "symbol_embedding_dim": 512,
        "num_lstm_layers": 3,
        "lstm_layer_norm": True,
        "lstm_layer_norm_epsilon": 1e-3,
        "lstm_dropout": 0.3,
    }
    return emformer_rnnt_model(num_symbols=num_symbols, **base_config)
.venv/lib/python3.11/site-packages/torchaudio/models/rnnt_decoder.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Dict, List, Optional, Tuple
2
+
3
+ import torch
4
+ from torchaudio.models import RNNT
5
+
6
+
7
+ __all__ = ["Hypothesis", "RNNTBeamSearch"]
8
+
9
+
10
# A hypothesis is kept as a plain tuple (rather than a class) of
# (token ids, prediction network output, prediction network state, score).
Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]
Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
represented as tuple of (tokens, prediction network output, prediction network state, score).
"""


def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
    """Return the decoded token sequence of ``hypo``."""
    return hypo[0]


def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
    """Return the cached prediction-network output of ``hypo``."""
    return hypo[1]


def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
    """Return the prediction-network internal state of ``hypo``."""
    return hypo[2]


def _get_hypo_score(hypo: Hypothesis) -> float:
    """Return the cumulative log-probability score of ``hypo``."""
    return hypo[3]


def _get_hypo_key(hypo: Hypothesis) -> str:
    """Return a deduplication key: the string form of the token sequence."""
    return str(hypo[0])
34
+
35
+
36
def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
    """Concatenate the prediction-network states of ``hypos`` along the batch dim.

    The first hypothesis is used as a template for the nested state layout;
    component (i, j) of the result is the concatenation of component (i, j)
    across all hypotheses.
    """
    template = _get_hypo_state(hypos[0])
    batched: List[List[torch.Tensor]] = []
    for i in range(len(template)):
        layer_components: List[torch.Tensor] = []
        for j in range(len(template[i])):
            layer_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
        batched.append(layer_components)
    return batched
44
+
45
+
46
def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
    """Extract the state slice for batch element ``idx`` from a batched state.

    ``index_select`` (rather than plain slicing) keeps the batch dimension and
    returns a copy, so the slice is independent of the batched tensors.
    """
    selector = torch.tensor([idx], device=device)
    sliced: List[List[torch.Tensor]] = []
    for layer_states in states:
        sliced.append([component.index_select(0, selector) for component in layer_states])
    return sliced
49
+
50
+
51
def _default_hypo_sort_key(hypo: Hypothesis) -> float:
    """Rank hypotheses by score normalized by token-sequence length.

    The ``+ 1`` smooths the normalization (and guards the division) so longer
    hypotheses are not unfairly penalized.
    """
    num_tokens = len(_get_hypo_tokens(hypo))
    return _get_hypo_score(hypo) / (num_tokens + 1)
53
+
54
+
55
def _compute_updated_scores(
    hypos: List[Hypothesis],
    next_token_probs: torch.Tensor,
    beam_width: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Score every (hypothesis, non-blank token) extension and keep the best.

    Returns the top-``beam_width`` extension scores, the index of the source
    hypothesis for each, and the extending token id.
    """
    base_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
    # The last column of next_token_probs is the blank token; drop it.
    extension_scores = base_scores + next_token_probs[:, :-1]  # [beam_width, num_tokens - 1]
    num_nonblank = extension_scores.shape[1]
    best_scores, flat_idx = extension_scores.reshape(-1).topk(beam_width)
    # Recover (hypothesis, token) coordinates from the flattened index.
    best_hypo_idx = flat_idx.div(num_nonblank, rounding_mode="trunc")
    best_token = flat_idx % num_nonblank
    return best_scores, best_hypo_idx, best_token
66
+
67
+
68
def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
    """Delete, in place, the first entry of ``hypo_list`` whose key matches ``hypo``.

    No-op when no entry matches.
    """
    target_key = _get_hypo_key(hypo)
    for i, candidate in enumerate(hypo_list):
        if _get_hypo_key(candidate) == target_key:
            del hypo_list[i]
            break
73
+
74
+
75
class RNNTBeamSearch(torch.nn.Module):
    r"""Beam search decoder for RNN-T model.

    Implements time-synchronous beam search: for every encoder frame, each
    hypothesis is either terminated with a blank ("B" set) or extended with a
    non-blank token ("A" set), until no candidate extension can beat the
    current beam of blank-terminated hypotheses.

    See Also:
        * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pretrained model.

    Args:
        model (RNNT): RNN-T model to use.
        blank (int): index of blank token in vocabulary.
        temperature (float, optional): temperature to apply to joint network output.
            Larger values yield more uniform samples. (Default: 1.0)
        hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
            for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
            hypothesis score normalized by token sequence length. (Default: None)
        step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
    """

    def __init__(
        self,
        model: RNNT,
        blank: int,
        temperature: float = 1.0,
        hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
        step_max_tokens: int = 100,
    ) -> None:
        super().__init__()
        self.model = model
        self.blank = blank
        self.temperature = temperature

        if hypo_sort_key is None:
            self.hypo_sort_key = _default_hypo_sort_key
        else:
            self.hypo_sort_key = hypo_sort_key

        self.step_max_tokens = step_max_tokens

    def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
        """Build the initial beam: a single hypothesis seeded with the blank token, score 0."""
        token = self.blank
        state = None

        one_tensor = torch.tensor([1], device=device)
        pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
        init_hypo = (
            [token],
            pred_out[0].detach(),
            pred_state,
            0.0,
        )
        return [init_hypo]

    def _gen_next_token_probs(
        self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
    ) -> torch.Tensor:
        """Run the joint network for one encoder frame against every hypothesis.

        Returns per-hypothesis token log-probabilities with shape
        (len(hypos), num_tokens).
        """
        one_tensor = torch.tensor([1], device=device)
        # Reuse the prediction-network outputs cached on each hypothesis.
        predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
        joined_out, _, _ = self.model.join(
            enc_out,
            one_tensor,
            predictor_out,
            torch.tensor([1] * len(hypos), device=device),
        )  # [beam_width, 1, 1, num_tokens]
        # Temperature > 1 flattens the distribution, < 1 sharpens it.
        joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
        return joined_out[:, 0, 0]

    def _gen_b_hypos(
        self,
        b_hypos: List[Hypothesis],
        a_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        key_to_b_hypo: Dict[str, Hypothesis],
    ) -> List[Hypothesis]:
        """Terminate each hypothesis in ``a_hypos`` with blank and merge into ``b_hypos``.

        Hypotheses with identical token sequences are merged by log-sum-exp of
        their scores. The returned list is sorted by ascending score.
        """
        for i in range(len(a_hypos)):
            h_a = a_hypos[i]
            # The last column of next_token_probs corresponds to the blank token.
            append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
            if _get_hypo_key(h_a) in key_to_b_hypo:
                # Same token sequence already present in B: merge the scores.
                h_b = key_to_b_hypo[_get_hypo_key(h_a)]
                _remove_hypo(h_b, b_hypos)
                score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
            else:
                score = float(append_blank_score)
            h_b = (
                _get_hypo_tokens(h_a),
                _get_hypo_predictor_out(h_a),
                _get_hypo_state(h_a),
                score,
            )
            b_hypos.append(h_b)
            key_to_b_hypo[_get_hypo_key(h_b)] = h_b
        # Keep B sorted ascending by score; _gen_a_hypos relies on this ordering.
        _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
        return [b_hypos[idx] for idx in sorted_idx]

    def _gen_a_hypos(
        self,
        a_hypos: List[Hypothesis],
        b_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        t: int,
        beam_width: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Extend hypotheses with non-blank tokens.

        Only extensions that score strictly better than the current
        ``beam_width``-th best blank-terminated hypothesis survive; when none
        do, the returned list is empty and the per-frame loop terminates.
        """
        (
            nonblank_nbest_scores,
            nonblank_nbest_hypo_idx,
            nonblank_nbest_token,
        ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)

        if len(b_hypos) < beam_width:
            b_nbest_score = -float("inf")
        else:
            # b_hypos is sorted ascending, so this is the beam_width-th best score.
            b_nbest_score = _get_hypo_score(b_hypos[-beam_width])

        base_hypos: List[Hypothesis] = []
        new_tokens: List[int] = []
        new_scores: List[float] = []
        for i in range(beam_width):
            score = float(nonblank_nbest_scores[i])
            if score > b_nbest_score:
                a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
                base_hypos.append(a_hypos[a_hypo_idx])
                new_tokens.append(int(nonblank_nbest_token[i]))
                new_scores.append(score)

        if base_hypos:
            new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
        else:
            new_hypos: List[Hypothesis] = []

        return new_hypos

    def _gen_new_hypos(
        self,
        base_hypos: List[Hypothesis],
        tokens: List[int],
        scores: List[float],
        t: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Run the prediction network once for the whole batch of extended hypotheses."""
        tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
        # Batch all hypothesis states so a single predict() call covers the beam.
        states = _batch_state(base_hypos)
        pred_out, _, pred_states = self.model.predict(
            tgt_tokens,
            torch.tensor([1] * len(base_hypos), device=device),
            states,
        )
        new_hypos: List[Hypothesis] = []
        for i, h_a in enumerate(base_hypos):
            new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
            # Un-batch: each new hypothesis gets its own slice of the batched state.
            new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
        return new_hypos

    def _search(
        self,
        enc_out: torch.Tensor,
        hypo: Optional[List[Hypothesis]],
        beam_width: int,
    ) -> List[Hypothesis]:
        """Core time-synchronous beam search over the encoder output frames."""
        n_time_steps = enc_out.shape[1]
        device = enc_out.device

        a_hypos: List[Hypothesis] = []
        b_hypos = self._init_b_hypos(device) if hypo is None else hypo
        for t in range(n_time_steps):
            a_hypos = b_hypos
            b_hypos = torch.jit.annotate(List[Hypothesis], [])
            key_to_b_hypo: Dict[str, Hypothesis] = {}
            symbols_current_t = 0

            while a_hypos:
                next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
                next_token_probs = next_token_probs.cpu()
                b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)

                # Cap non-blank emissions per frame so a frame cannot loop forever.
                if symbols_current_t == self.step_max_tokens:
                    break

                a_hypos = self._gen_a_hypos(
                    a_hypos,
                    b_hypos,
                    next_token_probs,
                    t,
                    beam_width,
                    device,
                )
                if a_hypos:
                    symbols_current_t += 1

            # Prune B down to the beam using the configured sort key.
            _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(beam_width)
            b_hypos = [b_hypos[idx] for idx in sorted_idx]

        return b_hypos

    def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
        r"""Performs beam search for the given input sequence.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.

        Returns:
            List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.

        Raises:
            ValueError: If ``input`` or ``length`` has an unsupported shape.
        """
        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
            raise ValueError("input must be of shape (T, D) or (1, T, D)")
        if input.dim() == 2:
            input = input.unsqueeze(0)

        if length.shape != () and length.shape != (1,):
            raise ValueError("length must be of shape () or (1,)")
        if length.dim() == 0:
            length = length.unsqueeze(0)

        enc_out, _ = self.model.transcribe(input, length)
        return self._search(enc_out, None, beam_width)

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        length: torch.Tensor,
        beam_width: int,
        state: Optional[List[List[torch.Tensor]]] = None,
        hypothesis: Optional[List[Hypothesis]] = None,
    ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
        r"""Performs beam search for the given input sequence in streaming mode.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing transcription network internal state generated in preceding
                invocation. (Default: ``None``)
            hypothesis (List[Hypothesis] or None): hypotheses from preceding invocation to seed
                search with. (Default: ``None``)

        Returns:
            (List[Hypothesis], List[List[torch.Tensor]]):
                List[Hypothesis]
                    top-``beam_width`` hypotheses found by beam search.
                List[List[torch.Tensor]]
                    list of lists of tensors representing transcription network
                    internal state generated in current invocation.

        Raises:
            ValueError: If ``input`` or ``length`` has an unsupported shape.
        """
        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
            raise ValueError("input must be of shape (T, D) or (1, T, D)")
        if input.dim() == 2:
            input = input.unsqueeze(0)

        if length.shape != () and length.shape != (1,):
            raise ValueError("length must be of shape () or (1,)")
        if length.dim() == 0:
            length = length.unsqueeze(0)

        enc_out, _, state = self.model.transcribe_streaming(input, length, state)
        return self._search(enc_out, hypothesis, beam_width), state
.venv/lib/python3.11/site-packages/torchaudio/models/tacotron2.py ADDED
@@ -0,0 +1,1046 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *****************************************************************************
2
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of the NVIDIA CORPORATION nor the
12
+ # names of its contributors may be used to endorse or promote products
13
+ # derived from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
+ # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
+ # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19
+ # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
+ # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
+ # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22
+ # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
+ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ #
26
+ # *****************************************************************************
27
+
28
+ import warnings
29
+ from typing import List, Optional, Tuple, Union
30
+
31
+ import torch
32
+ from torch import nn, Tensor
33
+ from torch.nn import functional as F
34
+
35
+
36
+ __all__ = [
37
+ "Tacotron2",
38
+ ]
39
+
40
+
41
def _get_linear_layer(in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear") -> torch.nn.Linear:
    r"""Linear layer with xavier uniform initialization.

    Args:
        in_dim (int): Size of each input sample.
        out_dim (int): Size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias. (Default: ``True``)
        w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
            for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)

    Returns:
        (torch.nn.Linear): The corresponding linear layer.
    """
    layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
    # Scale the xavier initialization by the gain recommended for the given nonlinearity.
    gain = torch.nn.init.calculate_gain(w_init_gain)
    torch.nn.init.xavier_uniform_(layer.weight, gain=gain)
    return layer
57
+
58
+
59
def _get_conv1d_layer(
    in_channels: int,
    out_channels: int,
    kernel_size: int = 1,
    stride: int = 1,
    padding: Optional[Union[str, int, Tuple[int]]] = None,
    dilation: int = 1,
    bias: bool = True,
    w_init_gain: str = "linear",
) -> torch.nn.Conv1d:
    r"""1D convolution with xavier uniform initialization.

    Args:
        in_channels (int): Number of channels in the input signal.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int, optional): Size of the convolving kernel. Must be odd when
            ``padding`` is ``None`` so that "same" padding can be derived. (Default: ``1``)
        stride (int, optional): Stride of the convolution. (Default: ``1``)
        padding (str, int or tuple, optional): Padding added to both sides of the input.
            If ``None``, computed as ``dilation * (kernel_size - 1) / 2``, which preserves
            the input length for odd kernels and unit stride. (Default: ``None``)
        dilation (int, optional): Spacing between kernel elements. (Default: ``1``)
        bias (bool, optional): If set to ``False``, the convolution will not learn an
            additive bias. (Default: ``True``)
        w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
            for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)

    Returns:
        (torch.nn.Conv1d): The corresponding Conv1D layer.

    Raises:
        ValueError: If ``padding`` is ``None`` and ``kernel_size`` is even.
    """
    if padding is None:
        # "Same" padding is only well-defined for odd kernels.
        if kernel_size % 2 != 1:
            raise ValueError("kernel_size must be odd")
        padding = int(dilation * (kernel_size - 1) / 2)

    conv1d = torch.nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
    )

    torch.nn.init.xavier_uniform_(conv1d.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    return conv1d
103
+
104
+
105
def _get_mask_from_lengths(lengths: Tensor) -> Tensor:
    r"""Build a padding mask from per-sample lengths.

    Position ``(i, j)`` of the returned mask is ``True`` when ``j`` is at or
    beyond the ``i``-th length, i.e. it marks padded positions.

    Args:
        lengths (Tensor): The length of each element in the batch, with shape (n_batch, ).

    Returns:
        mask (Tensor): The binary mask, with shape (n_batch, max of ``lengths``).
    """
    longest = torch.max(lengths).item()
    positions = torch.arange(0, longest, device=lengths.device, dtype=lengths.dtype)
    # 1 where the position is within the sample's valid range, 0 where padded...
    valid = (positions < lengths.unsqueeze(1)).byte()
    # ...then invert: True marks padding.
    return torch.le(valid, 0)
120
+
121
+
122
class _LocationLayer(nn.Module):
    r"""Location layer used in the Attention model.

    Projects the previous and cumulative attention weights into the attention
    hidden space via a 1D convolution followed by a linear layer.

    Args:
        attention_n_filter (int): Number of filters for attention model.
        attention_kernel_size (int): Kernel size for attention model.
        attention_hidden_dim (int): Dimension of attention hidden representation.
    """

    def __init__(
        self,
        attention_n_filter: int,
        attention_kernel_size: int,
        attention_hidden_dim: int,
    ):
        super().__init__()
        # "Same" padding so the time dimension is preserved by the convolution.
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = _get_conv1d_layer(
            2,
            attention_n_filter,
            kernel_size=attention_kernel_size,
            padding=padding,
            bias=False,
            stride=1,
            dilation=1,
        )
        self.location_dense = _get_linear_layer(
            attention_n_filter, attention_hidden_dim, bias=False, w_init_gain="tanh"
        )

    def forward(self, attention_weights_cat: Tensor) -> Tensor:
        r"""Project concatenated attention weights into the attention hidden space.

        Args:
            attention_weights_cat (Tensor): Cumulative and previous attention weights
                with shape (n_batch, 2, max of ``text_lengths``).

        Returns:
            processed_attention (Tensor): Processed attention weights
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
        """
        # (n_batch, attention_n_filter, text_lengths.max())
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        # (n_batch, text_lengths.max(), attention_hidden_dim)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention
169
+
170
+
171
class _Attention(nn.Module):
    r"""Locally sensitive attention model.

    Additive (Bahdanau-style) attention augmented with location features
    computed from the previous and cumulative attention weights.

    Args:
        attention_rnn_dim (int): Number of hidden units for RNN.
        encoder_embedding_dim (int): Number of embedding dimensions in the Encoder.
        attention_hidden_dim (int): Dimension of attention hidden representation.
        attention_location_n_filter (int): Number of filters for Attention model.
        attention_location_kernel_size (int): Kernel size for Attention model.
    """

    def __init__(
        self,
        attention_rnn_dim: int,
        encoder_embedding_dim: int,
        attention_hidden_dim: int,
        attention_location_n_filter: int,
        attention_location_kernel_size: int,
    ) -> None:
        super().__init__()
        self.query_layer = _get_linear_layer(attention_rnn_dim, attention_hidden_dim, bias=False, w_init_gain="tanh")
        self.memory_layer = _get_linear_layer(
            encoder_embedding_dim, attention_hidden_dim, bias=False, w_init_gain="tanh"
        )
        # Projects the combined hidden representation to a scalar energy per time step.
        self.v = _get_linear_layer(attention_hidden_dim, 1, bias=False)
        self.location_layer = _LocationLayer(
            attention_location_n_filter,
            attention_location_kernel_size,
            attention_hidden_dim,
        )
        # Masked (padded) positions get -inf so softmax assigns them zero weight.
        self.score_mask_value = -float("inf")

    def _get_alignment_energies(self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor) -> Tensor:
        r"""Get the alignment vector.

        Args:
            query (Tensor): Decoder output with shape (n_batch, n_mels * n_frames_per_step).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, attention_hidden_dim).
            attention_weights_cat (Tensor): Cumulative and previous attention weights
                with shape (n_batch, 2, max of ``text_lengths``).

        Returns:
            alignment (Tensor): attention weights, it is a tensor with shape (batch, max of ``text_lengths``).
        """

        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        # Additive attention: combine query, location features, and memory,
        # then project to one scalar energy per encoder time step.
        energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_memory))

        alignment = energies.squeeze(2)
        return alignment

    def forward(
        self,
        attention_hidden_state: Tensor,
        memory: Tensor,
        processed_memory: Tensor,
        attention_weights_cat: Tensor,
        mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        r"""Pass the input through the Attention model.

        Args:
            attention_hidden_state (Tensor): Attention rnn last output with shape (n_batch, ``attention_rnn_dim``).
            memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
            attention_weights_cat (Tensor): Previous and cumulative attention weights
                with shape (n_batch, current_num_frames * 2, max of ``text_lengths``).
            mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames).

        Returns:
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
        """
        alignment = self._get_alignment_energies(attention_hidden_state, processed_memory, attention_weights_cat)

        # Suppress padded positions before normalization.
        alignment = alignment.masked_fill(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        # Weighted sum of the encoder outputs: (n_batch, 1, T) x (n_batch, T, E).
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights
256
+
257
+
258
class _Prenet(nn.Module):
    r"""Prenet: a stack of ``len(out_sizes)`` linear layers with ReLU and dropout.

    Args:
        in_dim (int): The size of each input sample.
        out_sizes (list): The output dimension of each linear layer.
    """

    def __init__(self, in_dim: int, out_sizes: List[int]) -> None:
        super().__init__()
        # Each layer consumes the previous layer's output size.
        input_sizes = [in_dim] + out_sizes[:-1]
        self.layers = nn.ModuleList(
            [_get_linear_layer(n_in, n_out, bias=False) for n_in, n_out in zip(input_sizes, out_sizes)]
        )

    def forward(self, x: Tensor) -> Tensor:
        r"""Pass the input through Prenet.

        Args:
            x (Tensor): The input sequence to Prenet with shape (n_batch, in_dim).

        Return:
            x (Tensor): Tensor with shape (n_batch, out_sizes[-1]).
        """
        for layer in self.layers:
            # training=True on purpose: dropout stays active at inference time too.
            x = F.dropout(F.relu(layer(x)), p=0.5, training=True)
        return x
286
+
287
+
288
class _Postnet(nn.Module):
    r"""Postnet: a stack of 1D convolutions that refines the predicted mel spectrogram.

    Args:
        n_mels (int): Number of mel bins.
        postnet_embedding_dim (int): Postnet embedding dimension.
        postnet_kernel_size (int): Postnet kernel size.
        postnet_n_convolution (int): Number of postnet convolutions.
    """

    def __init__(
        self,
        n_mels: int,
        postnet_embedding_dim: int,
        postnet_kernel_size: int,
        postnet_n_convolution: int,
    ):
        super().__init__()
        self.convolutions = nn.ModuleList()

        same_padding = int((postnet_kernel_size - 1) / 2)
        for i in range(postnet_n_convolution):
            is_first = i == 0
            is_last = i == postnet_n_convolution - 1
            # First layer maps mel channels up to the embedding size, the last
            # layer maps back down to mel channels; the final layer is linear
            # (no tanh, see forward), hence the "linear" init gain.
            in_channels = n_mels if is_first else postnet_embedding_dim
            out_channels = n_mels if is_last else postnet_embedding_dim
            init_gain = "linear" if is_last else "tanh"
            num_features = n_mels if is_last else postnet_embedding_dim
            self.convolutions.append(
                nn.Sequential(
                    _get_conv1d_layer(
                        in_channels,
                        out_channels,
                        kernel_size=postnet_kernel_size,
                        stride=1,
                        padding=same_padding,
                        dilation=1,
                        w_init_gain=init_gain,
                    ),
                    nn.BatchNorm1d(num_features),
                )
            )

        self.n_convs = len(self.convolutions)

    def forward(self, x: Tensor) -> Tensor:
        r"""Pass the input through Postnet.

        Args:
            x (Tensor): The input sequence with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).

        Return:
            x (Tensor): Tensor with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
        """
        for i, conv in enumerate(self.convolutions):
            # tanh nonlinearity on every layer except the final (linear) projection.
            if i < self.n_convs - 1:
                x = F.dropout(torch.tanh(conv(x)), 0.5, training=self.training)
            else:
                x = F.dropout(conv(x), 0.5, training=self.training)
        return x
347
+
348
+
349
class _Encoder(nn.Module):
    r"""Encoder Module: a stack of 1D convolutions followed by a bidirectional LSTM.

    Args:
        encoder_embedding_dim (int): Number of embedding dimensions in the encoder.
        encoder_n_convolution (int): Number of convolution layers in the encoder.
        encoder_kernel_size (int): The kernel size in the encoder.

    Examples
        >>> encoder = _Encoder(3, 512, 5)
        >>> input = torch.rand(10, 20, 30)
        >>> output = encoder(input)  # shape: (10, 30, 512)
    """

    def __init__(
        self,
        encoder_embedding_dim: int,
        encoder_n_convolution: int,
        encoder_kernel_size: int,
    ) -> None:
        super().__init__()

        self.convolutions = nn.ModuleList()
        for _ in range(encoder_n_convolution):
            conv_layer = nn.Sequential(
                _get_conv1d_layer(
                    encoder_embedding_dim,
                    encoder_embedding_dim,
                    kernel_size=encoder_kernel_size,
                    stride=1,
                    padding=int((encoder_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="relu",
                ),
                nn.BatchNorm1d(encoder_embedding_dim),
            )
            self.convolutions.append(conv_layer)

        # Bidirectional LSTM with half-size hidden state per direction, so the
        # concatenated output keeps dimension encoder_embedding_dim.
        self.lstm = nn.LSTM(
            encoder_embedding_dim,
            int(encoder_embedding_dim / 2),
            1,
            batch_first=True,
            bidirectional=True,
        )
        # Compact the LSTM weights into contiguous memory once at construction.
        self.lstm.flatten_parameters()

    def forward(self, x: Tensor, input_lengths: Tensor) -> Tensor:
        r"""Pass the input through the Encoder.

        Args:
            x (Tensor): The input sequences with shape (n_batch, encoder_embedding_dim, n_seq).
            input_lengths (Tensor): The length of each input sequence with shape (n_batch, ).

        Return:
            x (Tensor): A tensor with shape (n_batch, n_seq, encoder_embedding_dim).
        """

        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        # (n_batch, encoder_embedding_dim, n_seq) -> (n_batch, n_seq, encoder_embedding_dim)
        x = x.transpose(1, 2)

        # pack_padded_sequence requires lengths on CPU.
        input_lengths = input_lengths.cpu()
        # NOTE(review): enforce_sorted defaults to True here, so input_lengths
        # must be sorted in decreasing order — confirm callers guarantee this.
        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)

        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        return outputs
419
+
420
+
421
class _Decoder(nn.Module):
    r"""Decoder with Attention model.

    Args:
        n_mels (int): number of mel bins
        n_frames_per_step (int): number of frames processed per step, only 1 is supported
        encoder_embedding_dim (int): the number of embedding dimensions in the encoder.
        decoder_rnn_dim (int): number of units in decoder LSTM
        decoder_max_step (int): maximum number of output mel spectrograms
        decoder_dropout (float): dropout probability for decoder LSTM
        decoder_early_stopping (bool): stop decoding when all samples are finished
        attention_rnn_dim (int): number of units in attention LSTM
        attention_hidden_dim (int): dimension of attention hidden representation
        attention_location_n_filter (int): number of filters for attention model
        attention_location_kernel_size (int): kernel size for attention model
        attention_dropout (float): dropout probability for attention LSTM
        prenet_dim (int): number of ReLU units in prenet layers
        gate_threshold (float): probability threshold for stop token
    """

    def __init__(
        self,
        n_mels: int,
        n_frames_per_step: int,
        encoder_embedding_dim: int,
        decoder_rnn_dim: int,
        decoder_max_step: int,
        decoder_dropout: float,
        decoder_early_stopping: bool,
        attention_rnn_dim: int,
        attention_hidden_dim: int,
        attention_location_n_filter: int,
        attention_location_kernel_size: int,
        attention_dropout: float,
        prenet_dim: int,
        gate_threshold: float,
    ) -> None:

        super().__init__()
        self.n_mels = n_mels
        self.n_frames_per_step = n_frames_per_step
        self.encoder_embedding_dim = encoder_embedding_dim
        self.attention_rnn_dim = attention_rnn_dim
        self.decoder_rnn_dim = decoder_rnn_dim
        self.prenet_dim = prenet_dim
        self.decoder_max_step = decoder_max_step
        self.gate_threshold = gate_threshold
        self.attention_dropout = attention_dropout
        self.decoder_dropout = decoder_dropout
        self.decoder_early_stopping = decoder_early_stopping

        # The prenet processes the previous mel frame (teacher-forced in training,
        # self-generated at inference) before it is fed to the attention LSTM.
        self.prenet = _Prenet(n_mels * n_frames_per_step, [prenet_dim, prenet_dim])

        self.attention_rnn = nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim)

        self.attention_layer = _Attention(
            attention_rnn_dim,
            encoder_embedding_dim,
            attention_hidden_dim,
            attention_location_n_filter,
            attention_location_kernel_size,
        )

        self.decoder_rnn = nn.LSTMCell(attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, True)

        # Projects [decoder hidden; attention context] to one mel frame per step.
        self.linear_projection = _get_linear_layer(decoder_rnn_dim + encoder_embedding_dim, n_mels * n_frames_per_step)

        # Scalar logit per step; sigmoid > gate_threshold marks end of utterance.
        self.gate_layer = _get_linear_layer(
            decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid"
        )

    def _get_initial_frame(self, memory: Tensor) -> Tensor:
        r"""Gets all zeros frames to use as the first decoder input.

        Args:
            memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

        Returns:
            decoder_input (Tensor): all zeros frames with shape
                (n_batch, max of ``text_lengths``, ``n_mels * n_frames_per_step``).
        """

        n_batch = memory.size(0)
        dtype = memory.dtype
        device = memory.device
        decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
        return decoder_input

    def _initialize_decoder_states(
        self, memory: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory.

        Args:
            memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

        Returns:
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
        """
        n_batch = memory.size(0)
        max_time = memory.size(1)
        dtype = memory.dtype
        device = memory.device

        attention_hidden = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)
        attention_cell = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)

        decoder_hidden = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)
        decoder_cell = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)

        attention_weights = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
        attention_weights_cum = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
        attention_context = torch.zeros(n_batch, self.encoder_embedding_dim, dtype=dtype, device=device)

        # Pre-compute the memory projection once per utterance; it is reused at
        # every decoding step inside the attention layer.
        processed_memory = self.attention_layer.memory_layer(memory)

        return (
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
            processed_memory,
        )

    def _parse_decoder_inputs(self, decoder_inputs: Tensor) -> Tensor:
        r"""Prepares decoder inputs.

        Args:
            decoder_inputs (Tensor): Inputs used for teacher-forced training, i.e. mel-specs,
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)

        Returns:
            inputs (Tensor): Processed decoder inputs with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``).
        """
        # (n_batch, n_mels, mel_specgram_lengths.max()) -> (n_batch, mel_specgram_lengths.max(), n_mels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        # Group n_frames_per_step frames together along the last dimension.
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1) / self.n_frames_per_step),
            -1,
        )
        # (n_batch, mel_specgram_lengths.max(), n_mels) -> (mel_specgram_lengths.max(), n_batch, n_mels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def _parse_decoder_outputs(
        self, mel_specgram: Tensor, gate_outputs: Tensor, alignments: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Prepares decoder outputs for output

        Args:
            mel_specgram (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``)
            gate_outputs (Tensor): predicted stop token with shape (max of ``mel_specgram_lengths``, n_batch)
            alignments (Tensor): sequence of attention weights from the decoder
                with shape (max of ``mel_specgram_lengths``, n_batch, max of ``text_lengths``)

        Returns:
            mel_specgram (Tensor): mel spectrogram with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)
            gate_outputs (Tensor): predicted stop token with shape (n_batch, max of ``mel_specgram_lengths``)
            alignments (Tensor): sequence of attention weights from the decoder
                with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``)
        """
        # (mel_specgram_lengths.max(), n_batch, text_lengths.max())
        # -> (n_batch, mel_specgram_lengths.max(), text_lengths.max())
        alignments = alignments.transpose(0, 1).contiguous()
        # (mel_specgram_lengths.max(), n_batch) -> (n_batch, mel_specgram_lengths.max())
        gate_outputs = gate_outputs.transpose(0, 1).contiguous()
        # (mel_specgram_lengths.max(), n_batch, n_mels) -> (n_batch, mel_specgram_lengths.max(), n_mels)
        mel_specgram = mel_specgram.transpose(0, 1).contiguous()
        # decouple frames per step
        shape = (mel_specgram.shape[0], -1, self.n_mels)
        mel_specgram = mel_specgram.view(*shape)
        # (n_batch, mel_specgram_lengths.max(), n_mels) -> (n_batch, n_mels, T_out)
        mel_specgram = mel_specgram.transpose(1, 2)

        return mel_specgram, gate_outputs, alignments

    def decode(
        self,
        decoder_input: Tensor,
        attention_hidden: Tensor,
        attention_cell: Tensor,
        decoder_hidden: Tensor,
        decoder_cell: Tensor,
        attention_weights: Tensor,
        attention_weights_cum: Tensor,
        attention_context: Tensor,
        memory: Tensor,
        processed_memory: Tensor,
        mask: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""Decoder step using stored states, attention and memory

        Args:
            decoder_input (Tensor): Output of the Prenet with shape (n_batch, ``prenet_dim``).
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            memory (Tensor): Encoder output with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
            mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames).

        Returns:
            decoder_output: Predicted mel spectrogram for the current frame with shape (n_batch, ``n_mels``).
            gate_prediction (Tensor): Prediction of the stop token with shape (n_batch, ``1``).
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
        """
        # Attention LSTM conditions on [previous frame (prenet'd); previous context].
        cell_input = torch.cat((decoder_input, attention_context), -1)

        attention_hidden, attention_cell = self.attention_rnn(cell_input, (attention_hidden, attention_cell))
        attention_hidden = F.dropout(attention_hidden, self.attention_dropout, self.training)

        # Location-sensitive attention: current and cumulative weights are both inputs.
        attention_weights_cat = torch.cat((attention_weights.unsqueeze(1), attention_weights_cum.unsqueeze(1)), dim=1)
        attention_context, attention_weights = self.attention_layer(
            attention_hidden, memory, processed_memory, attention_weights_cat, mask
        )

        attention_weights_cum += attention_weights
        decoder_input = torch.cat((attention_hidden, attention_context), -1)

        decoder_hidden, decoder_cell = self.decoder_rnn(decoder_input, (decoder_hidden, decoder_cell))
        decoder_hidden = F.dropout(decoder_hidden, self.decoder_dropout, self.training)

        # Mel frame and stop-token logit are both projected from the same features.
        decoder_hidden_attention_context = torch.cat((decoder_hidden, attention_context), dim=1)
        decoder_output = self.linear_projection(decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)

        return (
            decoder_output,
            gate_prediction,
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
        )

    def forward(
        self, memory: Tensor, mel_specgram_truth: Tensor, memory_lengths: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Decoder forward pass for training.

        Args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            mel_specgram_truth (Tensor): Decoder ground-truth mel-specs for teacher forcing
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            memory_lengths (Tensor): Encoder output lengths for attention masking
                (the same as ``text_lengths``) with shape (n_batch, ).

        Returns:
            mel_specgram (Tensor): Predicted mel spectrogram
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            gate_outputs (Tensor): Predicted stop token for each timestep
                with shape (n_batch, max of ``mel_specgram_lengths``).
            alignments (Tensor): Sequence of attention weights from the decoder
                with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
        """

        # Prepend the all-zero go-frame; at step t the decoder is teacher-forced
        # with the ground-truth frame t-1.
        decoder_input = self._get_initial_frame(memory).unsqueeze(0)
        decoder_inputs = self._parse_decoder_inputs(mel_specgram_truth)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        mask = _get_mask_from_lengths(memory_lengths)
        (
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
            processed_memory,
        ) = self._initialize_decoder_states(memory)

        mel_outputs, gate_outputs, alignments = [], [], []
        # One decoding step per target frame (go-frame excluded from the count).
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            (
                mel_output,
                gate_output,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
            ) = self.decode(
                decoder_input,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
                memory,
                processed_memory,
                mask,
            )

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]

        mel_specgram, gate_outputs, alignments = self._parse_decoder_outputs(
            torch.stack(mel_outputs), torch.stack(gate_outputs), torch.stack(alignments)
        )

        return mel_specgram, gate_outputs, alignments

    def _get_go_frame(self, memory: Tensor) -> Tensor:
        """Gets all zeros frames to use as the first decoder input

        args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

        returns:
            decoder_input (Tensor): All zeros frames with shape(n_batch, ``n_mels`` * ``n_frame_per_step``).
        """

        n_batch = memory.size(0)
        dtype = memory.dtype
        device = memory.device
        decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
        return decoder_input

    @torch.jit.export
    def infer(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Decoder inference

        Args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            memory_lengths (Tensor): Encoder output lengths for attention masking
                (the same as ``text_lengths``) with shape (n_batch, ).

        Returns:
            mel_specgram (Tensor): Predicted mel spectrogram
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            mel_specgram_lengths (Tensor): the length of the predicted mel spectrogram (n_batch, ))
            gate_outputs (Tensor): Predicted stop token for each timestep
                with shape (n_batch, max of ``mel_specgram_lengths``).
            alignments (Tensor): Sequence of attention weights from the decoder
                with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
        """
        batch_size, device = memory.size(0), memory.device

        decoder_input = self._get_go_frame(memory)

        mask = _get_mask_from_lengths(memory_lengths)
        (
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
            processed_memory,
        ) = self._initialize_decoder_states(memory)

        mel_specgram_lengths = torch.zeros([batch_size], dtype=torch.int32, device=device)
        finished = torch.zeros([batch_size], dtype=torch.bool, device=device)
        mel_specgrams: List[Tensor] = []
        gate_outputs: List[Tensor] = []
        alignments: List[Tensor] = []
        # Autoregressive loop: each generated frame is fed back as the next input.
        for _ in range(self.decoder_max_step):
            decoder_input = self.prenet(decoder_input)
            (
                mel_specgram,
                gate_output,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
            ) = self.decode(
                decoder_input,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
                memory,
                processed_memory,
                mask,
            )

            mel_specgrams.append(mel_specgram.unsqueeze(0))
            gate_outputs.append(gate_output.transpose(0, 1))
            alignments.append(attention_weights)
            # Only sequences still running get their length incremented.
            mel_specgram_lengths[~finished] += 1

            # A sample is finished once its stop-token probability exceeds the threshold.
            finished |= torch.sigmoid(gate_output.squeeze(1)) > self.gate_threshold
            if self.decoder_early_stopping and torch.all(finished):
                break

            decoder_input = mel_specgram

        if len(mel_specgrams) == self.decoder_max_step:
            warnings.warn(
                "Reached max decoder steps. The generated spectrogram might not cover " "the whole transcript."
            )

        mel_specgrams = torch.cat(mel_specgrams, dim=0)
        gate_outputs = torch.cat(gate_outputs, dim=0)
        alignments = torch.cat(alignments, dim=0)

        mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(mel_specgrams, gate_outputs, alignments)

        return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments
+
869
class Tacotron2(nn.Module):
    r"""Tacotron2 model from *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
    :cite:`shen2018natural` based on the implementation from
    `Nvidia Deep Learning Examples <https://github.com/NVIDIA/DeepLearningExamples/>`_.

    See Also:
        * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.

    Args:
        mask_padding (bool, optional): Use mask padding (Default: ``False``).
        n_mels (int, optional): Number of mel bins (Default: ``80``).
        n_symbol (int, optional): Number of symbols for the input text (Default: ``148``).
        n_frames_per_step (int, optional): Number of frames processed per step, only 1 is supported (Default: ``1``).
        symbol_embedding_dim (int, optional): Input embedding dimension (Default: ``512``).
        encoder_n_convolution (int, optional): Number of encoder convolutions (Default: ``3``).
        encoder_kernel_size (int, optional): Encoder kernel size (Default: ``5``).
        encoder_embedding_dim (int, optional): Encoder embedding dimension (Default: ``512``).
        decoder_rnn_dim (int, optional): Number of units in decoder LSTM (Default: ``1024``).
        decoder_max_step (int, optional): Maximum number of output mel spectrograms (Default: ``2000``).
        decoder_dropout (float, optional): Dropout probability for decoder LSTM (Default: ``0.1``).
        decoder_early_stopping (bool, optional): Stop decoding when all samples are finished (Default: ``True``).
        attention_rnn_dim (int, optional): Number of units in attention LSTM (Default: ``1024``).
        attention_hidden_dim (int, optional): Dimension of attention hidden representation (Default: ``128``).
        attention_location_n_filter (int, optional): Number of filters for attention model (Default: ``32``).
        attention_location_kernel_size (int, optional): Kernel size for attention model (Default: ``31``).
        attention_dropout (float, optional): Dropout probability for attention LSTM (Default: ``0.1``).
        prenet_dim (int, optional): Number of ReLU units in prenet layers (Default: ``256``).
        postnet_n_convolution (int, optional): Number of postnet convolutions (Default: ``5``).
        postnet_kernel_size (int, optional): Postnet kernel size (Default: ``5``).
        postnet_embedding_dim (int, optional): Postnet embedding dimension (Default: ``512``).
        gate_threshold (float, optional): Probability threshold for stop token (Default: ``0.5``).
    """

    def __init__(
        self,
        mask_padding: bool = False,
        n_mels: int = 80,
        n_symbol: int = 148,
        n_frames_per_step: int = 1,
        symbol_embedding_dim: int = 512,
        encoder_embedding_dim: int = 512,
        encoder_n_convolution: int = 3,
        encoder_kernel_size: int = 5,
        decoder_rnn_dim: int = 1024,
        decoder_max_step: int = 2000,
        decoder_dropout: float = 0.1,
        decoder_early_stopping: bool = True,
        attention_rnn_dim: int = 1024,
        attention_hidden_dim: int = 128,
        attention_location_n_filter: int = 32,
        attention_location_kernel_size: int = 31,
        attention_dropout: float = 0.1,
        prenet_dim: int = 256,
        postnet_n_convolution: int = 5,
        postnet_kernel_size: int = 5,
        postnet_embedding_dim: int = 512,
        gate_threshold: float = 0.5,
    ) -> None:
        super().__init__()

        self.mask_padding = mask_padding
        self.n_mels = n_mels
        self.n_frames_per_step = n_frames_per_step
        self.embedding = nn.Embedding(n_symbol, symbol_embedding_dim)
        torch.nn.init.xavier_uniform_(self.embedding.weight)
        self.encoder = _Encoder(encoder_embedding_dim, encoder_n_convolution, encoder_kernel_size)
        self.decoder = _Decoder(
            n_mels,
            n_frames_per_step,
            encoder_embedding_dim,
            decoder_rnn_dim,
            decoder_max_step,
            decoder_dropout,
            decoder_early_stopping,
            attention_rnn_dim,
            attention_hidden_dim,
            attention_location_n_filter,
            attention_location_kernel_size,
            attention_dropout,
            prenet_dim,
            gate_threshold,
        )
        self.postnet = _Postnet(n_mels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolution)

    def forward(
        self,
        tokens: Tensor,
        token_lengths: Tensor,
        mel_specgram: Tensor,
        mel_specgram_lengths: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""Pass the input through the Tacotron2 model. This is in teacher
        forcing mode, which is generally used for training.

        The input ``tokens`` should be padded with zeros to length max of ``token_lengths``.
        The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``.

        Args:
            tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of token_lengths)`.
            token_lengths (Tensor): The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
            mel_specgram (Tensor): The target mel spectrogram
                with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
            mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.

        Returns:
            [Tensor, Tensor, Tensor, Tensor]:
                Tensor
                    Mel spectrogram before Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
                Tensor
                    Mel spectrogram after Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
                Tensor
                    The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`.
                Tensor
                    Sequence of attention weights from the decoder with
                    shape `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
        """

        embedded_inputs = self.embedding(tokens).transpose(1, 2)

        encoder_outputs = self.encoder(embedded_inputs, token_lengths)
        mel_specgram, gate_outputs, alignments = self.decoder(
            encoder_outputs, mel_specgram, memory_lengths=token_lengths
        )

        # Postnet predicts a residual correction that is added to the decoder output.
        mel_specgram_postnet = self.postnet(mel_specgram)
        mel_specgram_postnet = mel_specgram + mel_specgram_postnet

        if self.mask_padding:
            mask = _get_mask_from_lengths(mel_specgram_lengths)
            # Broadcast the (n_batch, T) mask over every mel bin: (n_batch, n_mels, T).
            mask = mask.expand(self.n_mels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            mel_specgram.masked_fill_(mask, 0.0)
            mel_specgram_postnet.masked_fill_(mask, 0.0)
            # Large positive logit forces the stop token to 1 on padded steps.
            gate_outputs.masked_fill_(mask[:, 0, :], 1e3)

        return mel_specgram, mel_specgram_postnet, gate_outputs, alignments

    @torch.jit.export
    def infer(self, tokens: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Using Tacotron2 for inference. The input is a batch of encoded
        sentences (``tokens``) and its corresponding lengths (``lengths``). The
        output is the generated mel spectrograms, its corresponding lengths, and
        the attention weights from the decoder.

        The input `tokens` should be padded with zeros to length max of ``lengths``.

        Args:
            tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`.
            lengths (Tensor or None, optional):
                The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
                If ``None``, it is assumed that the all the tokens are valid. Default: ``None``

        Returns:
            (Tensor, Tensor, Tensor):
                Tensor
                    The predicted mel spectrogram with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
                Tensor
                    The length of the predicted mel spectrogram with shape `(n_batch, )`.
                Tensor
                    Sequence of attention weights from the decoder with shape
                    `(n_batch, max of mel_specgram_lengths, max of lengths)`.
        """
        n_batch, max_length = tokens.shape
        if lengths is None:
            lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype)

        assert lengths is not None  # For TorchScript compiler
        embedded_inputs = self.embedding(tokens).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, lengths)
        mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(encoder_outputs, lengths)

        mel_outputs_postnet = self.postnet(mel_specgram)
        mel_outputs_postnet = mel_specgram + mel_outputs_postnet

        # NOTE(review): unfold/transpose appears to reshape the alignments to
        # (n_batch, T, max of lengths) — verify against decoder.infer's layout.
        alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)

        return mel_outputs_postnet, mel_specgram_lengths, alignments
.venv/lib/python3.11/site-packages/torchaudio/models/wav2letter.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn, Tensor
2
+
3
+ __all__ = [
4
+ "Wav2Letter",
5
+ ]
6
+
7
+
8
class Wav2Letter(nn.Module):
    r"""Wav2Letter model architecture from *Wav2Letter: an End-to-End ConvNet-based Speech
    Recognition System* :cite:`collobert2016wav2letter`.

    See Also:
        * `Training example <https://github.com/pytorch/audio/tree/release/0.12/examples/pipeline_wav2letter>`__

    Args:
        num_classes (int, optional): Number of classes to be classified. (Default: ``40``)
        input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum``
            or ``mfcc`` (Default: ``waveform``).
        num_features (int, optional): Number of input features that the network will receive (Default: ``1``).

    Raises:
        ValueError: If ``input_type`` is not one of ``waveform``, ``power_spectrum`` or ``mfcc``.
    """

    def __init__(self, num_classes: int = 40, input_type: str = "waveform", num_features: int = 1) -> None:
        super().__init__()

        # Fail fast on a bad configuration. Previously an unrecognized input_type
        # silently left ``self.acoustic_model`` undefined, deferring the failure to
        # a cryptic AttributeError at forward() time.
        if input_type not in ("waveform", "power_spectrum", "mfcc"):
            raise ValueError(
                f"Unsupported input_type: {input_type!r}. "
                "Expected 'waveform', 'power_spectrum', or 'mfcc'."
            )

        # The waveform front-end (below) always emits 250 channels; spectral
        # inputs feed the acoustic stack directly with ``num_features`` channels.
        acoustic_num_features = 250 if input_type == "waveform" else num_features

        layers = [
            nn.Conv1d(in_channels=acoustic_num_features, out_channels=250, kernel_size=48, stride=2, padding=23),
            nn.ReLU(inplace=True),
        ]
        # Seven identical conv blocks form the body of the acoustic model.
        for _ in range(7):
            layers += [
                nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
                nn.ReLU(inplace=True),
            ]
        layers += [
            nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        ]
        acoustic_model = nn.Sequential(*layers)

        if input_type == "waveform":
            # Strided conv front-end that embeds the raw waveform into 250 channels.
            waveform_model = nn.Sequential(
                nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45),
                nn.ReLU(inplace=True),
            )
            self.acoustic_model = nn.Sequential(waveform_model, acoustic_model)
        else:
            self.acoustic_model = acoustic_model

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (torch.Tensor): Tensor of dimension (batch_size, num_features, input_length).

        Returns:
            Tensor: Predictor tensor of dimension (batch_size, number_of_classes, input_length).
        """

        x = self.acoustic_model(x)
        # Per-frame log-probabilities over the class dimension (e.g. for CTC loss).
        x = nn.functional.log_softmax(x, dim=1)
        return x
.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/model.py ADDED
@@ -0,0 +1,1579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Module
7
+
8
+ from . import components
9
+
10
+
11
class Wav2Vec2Model(Module):
    """Acoustic model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`.

    Note:
        To build the model, please use one of the factory functions.

    See Also:
        * :class:`torchaudio.pipelines.Wav2Vec2Bundle`: Pretrained models (without fine-tuning)
        * :class:`torchaudio.pipelines.Wav2Vec2ASRBundle`: ASR pipelines with pretrained models.

    Args:
        feature_extractor (torch.nn.Module):
            Feature extractor that extracts feature vectors from raw audio Tensor.
        encoder (torch.nn.Module):
            Encoder that converts the audio features into the sequence of probability
            distribution (in negative log-likelihood) over labels.
        aux (torch.nn.Module or None, optional):
            Auxiliary module. If provided, the output from encoder is passed to this module.
    """  # noqa: E501

    def __init__(
        self,
        feature_extractor: Module,
        encoder: Module,
        aux: Optional[Module] = None,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.encoder = encoder
        self.aux = aux

    @torch.jit.export
    def extract_features(
        self,
        waveforms: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> Tuple[List[Tensor], Optional[Tensor]]:
        """Extract feature vectors from raw waveforms.

        Returns the outputs of the intermediate transformer layers of the encoder.

        Args:
            waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
            lengths (Tensor or None, optional):
                Valid length of each audio in the batch, shape `(batch, )`.
                When given, the corresponding valid output lengths are computed and a
                proper mask is applied in the transformer attention layers.
                If ``None``, the entire waveform is assumed valid.
            num_layers (int or None, optional):
                If given, limit the number of intermediate layers to go through
                (``1`` stops after the first intermediate layer). If ``None``,
                outputs from all the intermediate layers are returned.

        Returns:
            (List[Tensor], Optional[Tensor]):
                List of Tensors
                    Features from the requested layers, each of shape
                    `(batch, time frame, feature dimension)`.
                Tensor or None
                    If ``lengths`` was provided, a Tensor of shape `(batch, )`
                    indicating the valid length in time axis of each feature Tensor.
        """
        features, valid_lengths = self.feature_extractor(waveforms, lengths)
        layer_outputs = self.encoder.extract_features(features, valid_lengths, num_layers)
        return layer_outputs, valid_lengths

    def forward(
        self,
        waveforms: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Compute the sequence of probability distribution over labels.

        Args:
            waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
            lengths (Tensor or None, optional):
                Valid length of each audio in the batch, shape `(batch, )`.
                When given, the corresponding valid output lengths are computed and a
                proper mask is applied in the transformer attention layers.
                If ``None``, all the audio in ``waveforms`` is assumed to have
                valid length. Default: ``None``.

        Returns:
            (Tensor, Optional[Tensor]):
                Tensor
                    The sequences of probability distribution (in logit) over labels.
                    Shape: `(batch, frames, num labels)`.
                Tensor or None
                    If ``lengths`` was provided, a Tensor of shape `(batch, )`
                    indicating the valid length in time axis of the output Tensor.
        """
        features, valid_lengths = self.feature_extractor(waveforms, lengths)
        output = self.encoder(features, valid_lengths)
        if self.aux is not None:
            # Optional fine-tuning head (e.g. linear projection to label logits).
            output = self.aux(output)
        return output, valid_lengths
121
+
122
+
123
class HuBERTPretrainModel(Module):
    """HuBERTPretrainModel()

    HuBERT model used for pretraining in *HuBERT* :cite:`hsu2021hubert`.

    Note:
        To build the model, please use one of the factory functions.

    See Also:
        `HuBERT Pre-training and Fine-tuning Recipes
        <https://github.com/pytorch/audio/tree/main/examples/hubert>`__

    Args:
        wav2vec2 (Wav2Vec2Model):
            Wav2Vec2 encoder that generates the transformer outputs.
        mask_generator (torch.nn.Module):
            Mask generator that generates the mask for masked prediction during the training.
        logit_generator (torch.nn.Module):
            Logit generator that predicts the logits of the masked and unmasked inputs.
        feature_grad_mult (float or None):
            The factor to scale the convolutional feature extraction layer gradients by.
            If ``None``, the gradients of feature extraction layers are not affected.
            The scale factor will not affect the forward pass.

    Raises:
        ValueError: If ``feature_grad_mult`` is neither ``None`` nor strictly between 0 and 1.
    """

    def __init__(
        self,
        wav2vec2: Wav2Vec2Model,
        mask_generator: Module,
        logit_generator: Module,
        feature_grad_mult: Optional[float],
    ):
        super().__init__()
        self.wav2vec2 = wav2vec2
        self.mask_generator = mask_generator
        self.logit_generator = logit_generator
        if feature_grad_mult is not None and not 0.0 < feature_grad_mult < 1.0:
            raise ValueError(
                f"The value of `feature_grad_mult` must be ``None`` or between (0, 1). Found {feature_grad_mult}"
            )
        self.feature_grad_mult = feature_grad_mult

    def forward(
        self,
        waveforms: Tensor,
        labels: Tensor,
        audio_lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        # NOTE(review): return annotation fixed — the method returns three tensors
        # (masked logits, unmasked logits, feature penalty), not (Tensor, Optional[Tensor]).
        """Compute the sequence of probability distribution over labels.

        Args:
            waveforms (Tensor): Audio tensor of dimension `[batch, frames]`.
            labels (Tensor): Label for pre-training. A Tensor of dimension `[batch, frames]`.
            audio_lengths (Tensor or None, optional):
                Indicates the valid length of each audio in the batch.
                Shape: `[batch, ]`.
                When the ``waveforms`` contains audios with different durations,
                by providing ``lengths`` argument, the model will compute
                the corresponding valid output lengths and apply proper mask in
                transformer attention layer.
                If ``None``, it is assumed that all the audio in ``waveforms``
                have valid length. Default: ``None``.

        Returns:
            (Tensor, Tensor, Tensor):
                Tensor
                    The masked sequences of probability distribution (in logit).
                    Shape: `(masked_frames, num labels)`.
                Tensor
                    The unmasked sequence of probability distribution (in logit).
                    Shape: `(unmasked_frames, num labels)`.
                Tensor
                    The feature mean value for additional penalty loss.
                    Shape: `(1,)`.

        Raises:
            ValueError: If the transformer output length does not match that of ``labels``.
        """
        x, lengths = self.wav2vec2.feature_extractor(waveforms, audio_lengths)
        if self.feature_grad_mult is not None and self.feature_grad_mult < 1.0:
            # Scale feature-extractor gradients in the backward pass only;
            # the forward values are unchanged.
            x = components.GradMultiply.apply(x, self.feature_grad_mult)
        # L2 penalty term on the raw features (computed before masking).
        features_pen = x.float().pow(2).mean()
        if lengths is not None:
            padding_mask = components._get_padding_mask(x, lengths)
        else:
            padding_mask = None
        x, attention_mask = self.wav2vec2.encoder._preprocess(x, lengths)
        x, mask = self.mask_generator(x, padding_mask)
        x = self.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
        if x.shape[1] != labels.shape[1]:
            raise ValueError("The length of label must match that of HuBERT model output")
        if padding_mask is not None:
            # Restrict masked/unmasked selections to valid (non-padded) frames.
            mask_m = torch.logical_and(~padding_mask, mask)
            mask_u = torch.logical_and(~padding_mask, ~mask_m)
        else:
            mask_m = mask
            mask_u = ~mask_m

        logit_m, logit_u = self.logit_generator(x, labels, mask_m, mask_u)

        return logit_m, logit_u, features_pen
224
+
225
+
226
def wav2vec2_model(
    extractor_mode: str,
    extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
    extractor_conv_bias: bool,
    encoder_embed_dim: int,
    encoder_projection_dropout: float,
    encoder_pos_conv_kernel: int,
    encoder_pos_conv_groups: int,
    encoder_num_layers: int,
    encoder_num_heads: int,
    encoder_attention_dropout: float,
    encoder_ff_interm_features: int,
    encoder_ff_interm_dropout: float,
    encoder_dropout: float,
    encoder_layer_norm_first: bool,
    encoder_layer_drop: float,
    aux_num_out: Optional[int],
) -> Wav2Vec2Model:
    """Builds custom :class:`~torchaudio.models.Wav2Vec2Model`.

    Note:
        The "feature extractor" below corresponds to
        `ConvFeatureExtractionModel <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L736>`__
        in the original ``fairseq`` implementation, referred to as
        "(convolutional) feature encoder" in the *wav2vec 2.0*
        :cite:`baevski2020wav2vec` paper.
        The "encoder" below corresponds to `TransformerEncoder <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L817>`__,
        referred to as "Transformer" in the paper.

    Args:
        extractor_mode (str): Operation mode of feature extractor.
            Valid values are ``"group_norm"`` (single normalization in the first
            convolution block) or ``"layer_norm"`` (layer normalization in all
            convolution blocks). Corresponds to ``extractor_mode`` from ``fairseq``.
        extractor_conv_layer_config (list of integer tuples or None):
            Configuration of convolution layers in feature extractor,
            i.e. ``[(output_channel, kernel_size, stride), ...]``.
            If ``None``, the default
            ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2`` is used.
            Corresponds to ``conv_feature_layers`` from ``fairseq``.
        extractor_conv_bias (bool): Whether to include bias term to each convolution
            operation. Corresponds to ``conv_bias`` from ``fairseq``.
        encoder_embed_dim (int): The dimension of embedding in encoder.
            Corresponds to ``encoder_embed_dim`` from ``fairseq``.
        encoder_projection_dropout (float): The dropout probability applied after the
            input feature is projected to ``encoder_embed_dim``.
            Corresponds to ``dropout_input`` from ``fairseq``.
        encoder_pos_conv_kernel (int): The kernel size of convolutional positional
            embeddings. Corresponds to ``conv_pos`` from ``fairseq``.
        encoder_pos_conv_groups (int): The number of groups of convolutional positional
            embeddings. Corresponds to ``conv_pos_groups`` from ``fairseq``.
        encoder_num_layers (int): The number of self attention layers in transformer
            block. Corresponds to ``encoder_layers`` from ``fairseq``.
        encoder_num_heads (int): The number of heads in self attention layers.
            Corresponds to ``encoder_attention_heads`` from ``fairseq``.
        encoder_attention_dropout (float): The dropout probability applied after softmax
            in self-attention layer. Corresponds to ``attention_dropout`` from ``fairseq``.
        encoder_ff_interm_features (int): The dimension of hidden features in feed
            forward layer. Corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``.
        encoder_ff_interm_dropout (float): The dropout probability applied in
            feedforward layer. Corresponds to ``activation_dropout`` from ``fairseq``.
        encoder_dropout (float): The dropout probability applied at the end of feed
            forward layer. Corresponds to ``dropout`` from ``fairseq``.
        encoder_layer_norm_first (bool): Controls the order of layer norm in
            transformer layer and each encoder layer.
            If ``True``, layer norm is applied before features are fed to encoder
            layers, and each encoder layer applies its two layer norms before and
            after self attention. If ``False``, layer norm is applied afterwards,
            and each encoder layer applies its two layer norms after self attention,
            before and after feed forward.
            Corresponds to ``layer_norm_first`` from ``fairseq``.
        encoder_layer_drop (float): Probability to drop each encoder layer during
            training. Corresponds to ``layerdrop`` from ``fairseq``.
        aux_num_out (int or None): When provided, attach an extra linear layer on top
            of encoder, which can be used for fine-tuning.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    if extractor_conv_layer_config is None:
        # fairseq default: one 10/5 layer, four 3/2 layers, two 2/2 layers, all 512 channels.
        extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

    feature_extractor = components._get_feature_extractor(
        extractor_mode, extractor_conv_layer_config, extractor_conv_bias
    )
    encoder = components._get_encoder(
        in_features=extractor_conv_layer_config[-1][0],
        embed_dim=encoder_embed_dim,
        dropout_input=encoder_projection_dropout,
        pos_conv_kernel=encoder_pos_conv_kernel,
        pos_conv_groups=encoder_pos_conv_groups,
        num_layers=encoder_num_layers,
        num_heads=encoder_num_heads,
        attention_dropout=encoder_attention_dropout,
        ff_interm_features=encoder_ff_interm_features,
        ff_interm_dropout=encoder_ff_interm_dropout,
        dropout=encoder_dropout,
        layer_norm_first=encoder_layer_norm_first,
        layer_drop=encoder_layer_drop,
    )
    # Optional fine-tuning head.
    aux = (
        torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out)
        if aux_num_out is not None
        else None
    )
    return Wav2Vec2Model(feature_extractor, encoder, aux)
390
+
391
+
392
def wav2vec2_base(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.1,
    encoder_ff_interm_dropout: float = 0.1,
    encoder_dropout: float = 0.1,
    encoder_layer_drop: float = 0.1,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "base" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`

    Args:
        encoder_projection_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float): See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional): See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    # Fixed architecture hyper-parameters of the "base" configuration.
    architecture = dict(
        extractor_mode="group_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_ff_interm_features=3072,
        encoder_layer_norm_first=False,
    )
    return wav2vec2_model(
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
        **architecture,
    )
438
+
439
+
440
def wav2vec2_large(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.1,
    encoder_ff_interm_dropout: float = 0.1,
    encoder_dropout: float = 0.1,
    encoder_layer_drop: float = 0.1,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "large" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`

    Args:
        encoder_projection_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float): See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional): See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    # Fixed architecture hyper-parameters of the "large" configuration.
    architecture = dict(
        extractor_mode="group_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1024,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_ff_interm_features=4096,
        encoder_layer_norm_first=False,
    )
    return wav2vec2_model(
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
        **architecture,
    )
486
+
487
+
488
def wav2vec2_large_lv60k(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.0,
    encoder_ff_interm_dropout: float = 0.1,
    encoder_dropout: float = 0.0,
    encoder_layer_drop: float = 0.1,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "large lv-60k" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`

    Args:
        encoder_projection_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float): See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional): See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    # Fixed architecture hyper-parameters of the "large lv-60k" configuration.
    architecture = dict(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=True,
        encoder_embed_dim=1024,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_ff_interm_features=4096,
        encoder_layer_norm_first=True,
    )
    return wav2vec2_model(
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
        **architecture,
    )
534
+
535
+
536
def hubert_base(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.1,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.1,
    encoder_layer_drop: float = 0.05,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "base" :class:`HuBERT <torchaudio.models.Wav2Vec2Model>` from *HuBERT* :cite:`hsu2021hubert`

    Args:
        encoder_projection_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float): See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional): See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    # Fixed architecture hyper-parameters of the HuBERT "base" configuration.
    architecture = dict(
        extractor_mode="group_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_ff_interm_features=3072,
        encoder_layer_norm_first=False,
    )
    return wav2vec2_model(
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
        **architecture,
    )
582
+
583
+
584
def hubert_large(
    encoder_projection_dropout: float = 0.0,
    encoder_attention_dropout: float = 0.0,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.0,
    encoder_layer_drop: float = 0.0,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "large" :class:`HuBERT <torchaudio.models.Wav2Vec2Model>` from *HuBERT* :cite:`hsu2021hubert`

    Args:
        encoder_projection_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float): See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional): See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    # Fixed architecture hyper-parameters of the HuBERT "large" configuration.
    architecture = dict(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1024,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_ff_interm_features=4096,
        encoder_layer_norm_first=True,
    )
    return wav2vec2_model(
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
        **architecture,
    )
630
+
631
+
632
def hubert_xlarge(
    encoder_projection_dropout: float = 0.0,
    encoder_attention_dropout: float = 0.0,
    encoder_ff_interm_dropout: float = 0.0,
    encoder_dropout: float = 0.0,
    encoder_layer_drop: float = 0.0,
    aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
    """Builds "extra large" :class:`HuBERT <torchaudio.models.Wav2Vec2Model>` from *HuBERT* :cite:`hsu2021hubert`

    Args:
        encoder_projection_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_attention_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_ff_interm_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_dropout (float): See :py:func:`wav2vec2_model`.
        encoder_layer_drop (float): See :py:func:`wav2vec2_model`.
        aux_num_out (int or None, optional): See :py:func:`wav2vec2_model`.

    Returns:
        Wav2Vec2Model:
            The resulting model.
    """  # noqa: E501
    # Fixed architecture hyper-parameters of the HuBERT "extra large" configuration.
    architecture = dict(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=1280,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=48,
        encoder_num_heads=16,
        encoder_ff_interm_features=5120,
        encoder_layer_norm_first=True,
    )
    return wav2vec2_model(
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
        encoder_dropout=encoder_dropout,
        encoder_layer_drop=encoder_layer_drop,
        aux_num_out=aux_num_out,
        **architecture,
    )
678
+
679
+
680
def _init_hubert_pretrain_model(module):
    """Re-initialize the parameters of ``module`` (in place) for HuBERT pre-training.

    Intended to be used with ``model.apply``; modules of unrecognized types are
    left untouched.
    """
    if isinstance(module, components.ConvLayerBlock):
        torch.nn.init.kaiming_normal_(module.conv.weight)
    elif isinstance(module, components.ConvolutionalPositionalEmbedding):
        # normalize the weight to normal distribution.
        std = math.sqrt(4.0 / (module.embed_dim * module.kernel_size))
        torch.nn.init.normal_(module.conv.weight, mean=0.0, std=std)
        torch.nn.init.constant_(module.conv.bias, 0.0)
    elif isinstance(module, components.SelfAttention):
        # normalize the query, key, value, and out_proj parameters in self attention module.
        # NOTE: k/v/q order is kept identical to the reference implementation so that
        # the RNG draws (and hence the resulting weights for a fixed seed) match.
        for projection in (module.k_proj, module.v_proj, module.q_proj):
            torch.nn.init.xavier_uniform_(projection.weight, gain=1 / math.sqrt(2))
        torch.nn.init.xavier_uniform_(module.out_proj.weight)
        torch.nn.init.constant_(module.out_proj.bias, 0.0)
    elif isinstance(module, components.Transformer):
        module.apply(components._init_transformer_params)
699
+
700
+
701
+ def hubert_pretrain_model(
702
+ extractor_mode: str,
703
+ extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
704
+ extractor_conv_bias: bool,
705
+ encoder_embed_dim: int,
706
+ encoder_projection_dropout: float,
707
+ encoder_pos_conv_kernel: int,
708
+ encoder_pos_conv_groups: int,
709
+ encoder_num_layers: int,
710
+ encoder_num_heads: int,
711
+ encoder_attention_dropout: float,
712
+ encoder_ff_interm_features: int,
713
+ encoder_ff_interm_dropout: float,
714
+ encoder_dropout: float,
715
+ encoder_layer_norm_first: bool,
716
+ encoder_layer_drop: float,
717
+ mask_prob: float,
718
+ mask_selection: str,
719
+ mask_other: float,
720
+ mask_length: int,
721
+ no_mask_overlap: bool,
722
+ mask_min_space: int,
723
+ mask_channel_prob: float,
724
+ mask_channel_selection: str,
725
+ mask_channel_other: float,
726
+ mask_channel_length: int,
727
+ no_mask_channel_overlap: bool,
728
+ mask_channel_min_space: int,
729
+ skip_masked: bool,
730
+ skip_nomask: bool,
731
+ num_classes: int,
732
+ final_dim: int,
733
+ feature_grad_mult: Optional[float],
734
+ ) -> HuBERTPretrainModel:
735
+ """Builds custom :class:`HuBERTPretrainModel` for training from scratch
736
+
737
+ Note:
738
+ The "feature extractor" below corresponds to
739
+ `ConvFeatureExtractionModel <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L736>`__
740
+ in the original ``fairseq`` implementation.
741
+ This is referred as "(convolutional) feature encoder" in the *wav2vec 2.0*
742
+ :cite:`baevski2020wav2vec` paper.
743
+
744
+ The "encoder" below corresponds to `TransformerEncoder <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L817>`__,
745
+ and this is referred as "Transformer" in the paper.
746
+
747
+ Args:
748
+ extractor_mode (str): Operation mode of feature extractor.
749
+ Valid values are ``"group_norm"`` or ``"layer_norm"``.
750
+ If ``"group_norm"``, then a single normalization is applied
751
+ in the first convolution block. Otherwise, all the convolution
752
+ blocks will have layer normalization.
753
+
754
+ This option corresponds to ``extractor_mode`` from ``fairseq``.
755
+
756
+ extractor_conv_layer_config (list of integer tuples or None):
757
+ Configuration of convolution layers in feature extractor.
758
+ List of convolution configuration,
759
+ i.e. ``[(output_channel, kernel_size, stride), ...]``
760
+
761
+ If ``None`` is provided, then the following default value is used.
762
+
763
+ .. code-block:: python
764
+
765
+ [
766
+ (512, 10, 5),
767
+ (512, 3, 2),
768
+ (512, 3, 2),
769
+ (512, 3, 2),
770
+ (512, 3, 2),
771
+ (512, 2, 2),
772
+ (512, 2, 2),
773
+ ]
774
+
775
+ This option corresponds to ``conv_feature_layers`` from ``fairseq``.
776
+
777
+ extractor_conv_bias (bool):
778
+ Whether to include bias term to each convolution operation.
779
+
780
+ This option corresponds to ``conv_bias`` from ``fairseq``.
781
+
782
+ encoder_embed_dim (int):
783
+ The dimension of embedding in encoder.
784
+
785
+ This option corresponds to ``encoder_embed_dim`` from ``fairseq``.
786
+
787
+ encoder_projection_dropout (float):
788
+ The dropout probability applied after the input feature is projected
789
+ to ``encoder_embed_dim``.
790
+
791
+ This option corresponds to ``dropout_input`` from ``fairseq``.
792
+
793
+ encoder_pos_conv_kernel (int):
794
+ The kernel size of convolutional positional embeddings.
795
+
796
+ This option corresponds to ``conv_pos`` from ``fairseq``.
797
+
798
+ encoder_pos_conv_groups (int):
799
+ The number of groups of convolutional positional embeddings.
800
+
801
+ This option corresponds to ``conv_pos_groups`` from ``fairseq``.
802
+
803
+ encoder_num_layers (int):
804
+ The number of self attention layers in transformer block.
805
+
806
+ This option corresponds to ``encoder_layers`` from ``fairseq``.
807
+
808
+ encoder_num_heads (int):
809
+ The number of heads in self attention layers.
810
+
811
+ This option corresponds to ``encoder_attention_heads`` from ``fairseq``.
812
+
813
+ encoder_attention_dropout (float):
814
+ The dropout probability applied after softmax in self-attention layer.
815
+
816
+ This option corresponds to ``attention_dropout`` from ``fairseq``.
817
+
818
+ encoder_ff_interm_features (int):
819
+ The dimension of hidden features in feed forward layer.
820
+
821
+ This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``.
822
+
823
+ encoder_ff_interm_dropout (float):
824
+ The dropout probability applied in feedforward layer.
825
+
826
+ This option correspinds to ``activation_dropout`` from ``fairseq``.
827
+
828
+ encoder_dropout (float):
829
+ The dropout probability applied at the end of feed forward layer.
830
+
831
+ This option corresponds to ``dropout`` from ``fairseq``.
832
+
833
+ encoder_layer_norm_first (bool):
834
+ Control the order of layer norm in transformer layer and each encoder layer.
835
+ If True, in transformer layer, layer norm is applied before features are fed
836
+ to encoder layers. In encoder layer, two layer norms are applied before and after
837
+ self attention.
838
+ If False, in transformer layer, layer norm is applied after features are fed
839
+ to encoder layers. In encoder layer, two layer norms are applied after self
840
+ attention, before and after feed forward.
841
+
842
+ This option corresponds to ``layer_norm_first`` from ``fairseq``.
843
+
844
+ encoder_layer_drop (float):
845
+ Probability to drop each encoder layer during training.
846
+
847
+ This option corresponds to ``layerdrop`` from ``fairseq``.
848
+
849
+ mask_prob (float):
850
+ Probability for each token to be chosen as start of the span to be masked. this will be multiplied by
851
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
852
+ However due to overlaps, the actual number will be smaller (unless no_overlap is True).
853
+
854
+ This option corresponds to ``mask_prob`` from ``fairseq``.
855
+
856
+ mask_selection (str):
857
+ How to choose the mask length. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
858
+
859
+ This option corresponds to ``mask_selection`` from ``fairseq``.
860
+
861
+ mask_other (float):
862
+ Secondary mask argument (used for more complex distributions).
863
+
864
+ This option corresponds to ``mask_other`` from ``fairseq``.
865
+
866
+ mask_length (int):
867
+ The lengths of the mask.
868
+
869
+ This option corresponds to ``mask_length`` from ``fairseq``.
870
+
871
+ no_mask_overlap (bool):
872
+ Whether to allow masks to overlap.
873
+
874
+ This option corresponds to ``no_mask_overlap`` from ``fairseq``.
875
+
876
+ mask_min_space (int):
877
+ Minimum space between spans (if no overlap is enabled).
878
+
879
+ This option corresponds to ``mask_min_space`` from ``fairseq``.
880
+
881
+ mask_channel_prob: (float):
882
+ The probability of replacing a feature with 0.
883
+
884
+ This option corresponds to ``mask_channel_prob`` from ``fairseq``.
885
+
886
+ mask_channel_selection (str):
887
+ How to choose the mask length for channel masking. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
888
+
889
+ This option corresponds to ``mask_channel_selection`` from ``fairseq``.
890
+
891
+ mask_channel_other (float):
892
+ Secondary mask argument for channel masking(used for more complex distributions).
893
+
894
+ This option corresponds to ``mask_channel_other`` from ``fairseq``.
895
+
896
+ mask_channel_length (int):
897
+ Minimum space between spans (if no overlap is enabled) for channel masking.
898
+
899
+ This option corresponds to ``mask_channel_length`` from ``fairseq``.
900
+
901
+ no_mask_channel_overlap (bool):
902
+ Whether to allow channel masks to overlap.
903
+
904
+ This option corresponds to ``no_mask_channel_overlap`` from ``fairseq``.
905
+
906
+ mask_channel_min_space (int):
907
+ Minimum space between spans for channel masking(if no overlap is enabled).
908
+
909
+ This option corresponds to ``mask_channel_min_space`` from ``fairseq``.
910
+
911
+ skip_masked (bool):
912
+ If True, skip computing losses over masked frames.
913
+
914
+ This option corresponds to ``skip_masked`` from ``fairseq``.
915
+
916
+ skip_nomask (bool):
917
+ If True, skip computing losses over unmasked frames.
918
+
919
+ This option corresponds to ``skip_nomask`` from ``fairseq``.
920
+
921
+ num_classes (int):
922
+ The number of classes in the labels.
923
+
924
+ final_dim (int):
925
+ Project final representations and targets to `final_dim`.
926
+
927
+ This option corresponds to ``final_dim`` from ``fairseq``.
928
+
929
+ feature_grad_mult (float or None):
930
+ The factor to scale the convolutional feature extraction layer gradients by.
931
+ The scale factor will not affect the forward pass.
932
+
933
+ This option corresponds to ``feature_grad_mult`` from ``fairseq``.
934
+
935
+ Returns:
936
+ HuBERTPretrainModel:
937
+ The resulting model.
938
+ """ # noqa: E501
939
+ if extractor_conv_layer_config is None:
940
+ extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
941
+
942
+ feature_extractor = components._get_feature_extractor(
943
+ extractor_mode, extractor_conv_layer_config, extractor_conv_bias
944
+ )
945
+ encoder = components._get_encoder(
946
+ in_features=extractor_conv_layer_config[-1][0],
947
+ embed_dim=encoder_embed_dim,
948
+ dropout_input=encoder_projection_dropout,
949
+ pos_conv_kernel=encoder_pos_conv_kernel,
950
+ pos_conv_groups=encoder_pos_conv_groups,
951
+ num_layers=encoder_num_layers,
952
+ num_heads=encoder_num_heads,
953
+ attention_dropout=encoder_attention_dropout,
954
+ ff_interm_features=encoder_ff_interm_features,
955
+ ff_interm_dropout=encoder_ff_interm_dropout,
956
+ dropout=encoder_dropout,
957
+ layer_norm_first=encoder_layer_norm_first,
958
+ layer_drop=encoder_layer_drop,
959
+ )
960
+ wav2vec2 = Wav2Vec2Model(feature_extractor, encoder)
961
+ mask_generator = components.MaskGenerator(
962
+ encoder_embed_dim,
963
+ mask_prob,
964
+ mask_selection,
965
+ mask_other,
966
+ mask_length,
967
+ no_mask_overlap,
968
+ mask_min_space,
969
+ mask_channel_prob,
970
+ mask_channel_selection,
971
+ mask_channel_other,
972
+ mask_channel_length,
973
+ no_mask_channel_overlap,
974
+ mask_channel_min_space,
975
+ )
976
+ logit_generator = components.LogitGenerator(
977
+ encoder_embed_dim,
978
+ num_classes,
979
+ final_dim,
980
+ skip_masked,
981
+ skip_nomask,
982
+ )
983
+ model = HuBERTPretrainModel(
984
+ wav2vec2=wav2vec2,
985
+ mask_generator=mask_generator,
986
+ logit_generator=logit_generator,
987
+ feature_grad_mult=feature_grad_mult,
988
+ )
989
+ # initialize the model for pre-training
990
+ model.apply(_init_hubert_pretrain_model)
991
+ return model
992
+
993
+
994
+ def hubert_pretrain_base(
995
+ encoder_projection_dropout: float = 0.1,
996
+ encoder_attention_dropout: float = 0.1,
997
+ encoder_ff_interm_dropout: float = 0.0,
998
+ encoder_dropout: float = 0.1,
999
+ encoder_layer_drop: float = 0.05,
1000
+ mask_prob: float = 0.8,
1001
+ mask_channel_prob: float = 0.0,
1002
+ mask_channel_length: int = 10,
1003
+ feature_grad_mult: Optional[float] = 0.1,
1004
+ num_classes: int = 100,
1005
+ ) -> HuBERTPretrainModel:
1006
+ """Builds "base" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.
1007
+
1008
+ Args:
1009
+ encoder_projection_dropout (float):
1010
+ See :py:func:`hubert_pretrain_model`.
1011
+ encoder_attention_dropout (float):
1012
+ See :py:func:`hubert_pretrain_model`.
1013
+ encoder_ff_interm_dropout (float):
1014
+ See :py:func:`hubert_pretrain_model`.
1015
+ encoder_dropout (float):
1016
+ See :py:func:`hubert_pretrain_model`.
1017
+ encoder_layer_drop (float):
1018
+ See :py:func:`hubert_pretrain_model`.
1019
+ mask_prob (float):
1020
+ See :py:func:`hubert_pretrain_model`.
1021
+ mask_channel_prob (float):
1022
+ See :py:func:`hubert_pretrain_model`.
1023
+ mask_channel_length (int):
1024
+ See :py:func:`hubert_pretrain_model`.
1025
+ feature_grad_mult (float or None):
1026
+ See :py:func:`hubert_pretrain_model`.
1027
+ num_classes (int, optional):
1028
+ See :py:func:`hubert_pretrain_model`.
1029
+
1030
+ Returns:
1031
+ HuBERTPretrainModel:
1032
+ The resulting model.
1033
+ """ # noqa: E501
1034
+ return hubert_pretrain_model(
1035
+ extractor_mode="group_norm",
1036
+ extractor_conv_layer_config=None,
1037
+ extractor_conv_bias=False,
1038
+ encoder_embed_dim=768,
1039
+ encoder_projection_dropout=encoder_projection_dropout,
1040
+ encoder_pos_conv_kernel=128,
1041
+ encoder_pos_conv_groups=16,
1042
+ encoder_num_layers=12,
1043
+ encoder_num_heads=12,
1044
+ encoder_attention_dropout=encoder_attention_dropout,
1045
+ encoder_ff_interm_features=3072,
1046
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
1047
+ encoder_dropout=encoder_dropout,
1048
+ encoder_layer_norm_first=False,
1049
+ encoder_layer_drop=encoder_layer_drop,
1050
+ mask_prob=mask_prob,
1051
+ mask_selection="static",
1052
+ mask_other=0.0,
1053
+ mask_length=10,
1054
+ no_mask_overlap=False,
1055
+ mask_min_space=1,
1056
+ mask_channel_prob=mask_channel_prob,
1057
+ mask_channel_selection="static",
1058
+ mask_channel_other=0.0,
1059
+ mask_channel_length=mask_channel_length,
1060
+ no_mask_channel_overlap=False,
1061
+ mask_channel_min_space=1,
1062
+ skip_masked=False,
1063
+ skip_nomask=False,
1064
+ num_classes=num_classes,
1065
+ final_dim=256,
1066
+ feature_grad_mult=feature_grad_mult,
1067
+ )
1068
+
1069
+
1070
+ def hubert_pretrain_large(
1071
+ encoder_projection_dropout: float = 0.0,
1072
+ encoder_attention_dropout: float = 0.0,
1073
+ encoder_ff_interm_dropout: float = 0.0,
1074
+ encoder_dropout: float = 0.0,
1075
+ encoder_layer_drop: float = 0.0,
1076
+ mask_prob: float = 0.8,
1077
+ mask_channel_prob: float = 0.0,
1078
+ mask_channel_length: int = 10,
1079
+ feature_grad_mult: Optional[float] = None,
1080
+ ) -> HuBERTPretrainModel:
1081
+ """Builds "large" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.
1082
+
1083
+ Args:
1084
+ encoder_projection_dropout (float):
1085
+ See :py:func:`hubert_pretrain_model`.
1086
+ encoder_attention_dropout (float):
1087
+ See :py:func:`hubert_pretrain_model`.
1088
+ encoder_ff_interm_dropout (float):
1089
+ See :py:func:`hubert_pretrain_model`.
1090
+ encoder_dropout (float):
1091
+ See :py:func:`hubert_pretrain_model`.
1092
+ encoder_layer_drop (float):
1093
+ See :py:func:`hubert_pretrain_model`.
1094
+ mask_prob (float):
1095
+ See :py:func:`hubert_pretrain_model`.
1096
+ mask_channel_prob (float):
1097
+ See :py:func:`hubert_pretrain_model`.
1098
+ mask_channel_length (int):
1099
+ See :py:func:`hubert_pretrain_model`.
1100
+ feature_grad_mult (float or None):
1101
+ See :py:func:`hubert_pretrain_model`.
1102
+
1103
+ Returns:
1104
+ HuBERTPretrainModel:
1105
+ The resulting model.
1106
+ """ # noqa: E501
1107
+ return hubert_pretrain_model(
1108
+ extractor_mode="layer_norm",
1109
+ extractor_conv_layer_config=None,
1110
+ extractor_conv_bias=False,
1111
+ encoder_embed_dim=1024,
1112
+ encoder_projection_dropout=encoder_projection_dropout,
1113
+ encoder_pos_conv_kernel=128,
1114
+ encoder_pos_conv_groups=16,
1115
+ encoder_num_layers=24,
1116
+ encoder_num_heads=16,
1117
+ encoder_attention_dropout=encoder_attention_dropout,
1118
+ encoder_ff_interm_features=4096,
1119
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
1120
+ encoder_dropout=encoder_dropout,
1121
+ encoder_layer_norm_first=True,
1122
+ encoder_layer_drop=encoder_layer_drop,
1123
+ mask_prob=mask_prob,
1124
+ mask_selection="static",
1125
+ mask_other=0.0,
1126
+ mask_length=10,
1127
+ no_mask_overlap=False,
1128
+ mask_min_space=1,
1129
+ mask_channel_prob=mask_channel_prob,
1130
+ mask_channel_selection="static",
1131
+ mask_channel_other=0.0,
1132
+ mask_channel_length=mask_channel_length,
1133
+ no_mask_channel_overlap=False,
1134
+ mask_channel_min_space=1,
1135
+ skip_masked=False,
1136
+ skip_nomask=False,
1137
+ num_classes=500,
1138
+ final_dim=768,
1139
+ feature_grad_mult=feature_grad_mult,
1140
+ )
1141
+
1142
+
1143
+ def hubert_pretrain_xlarge(
1144
+ encoder_projection_dropout: float = 0.0,
1145
+ encoder_attention_dropout: float = 0.0,
1146
+ encoder_ff_interm_dropout: float = 0.0,
1147
+ encoder_dropout: float = 0.0,
1148
+ encoder_layer_drop: float = 0.0,
1149
+ mask_prob: float = 0.8,
1150
+ mask_channel_prob: float = 0.0,
1151
+ mask_channel_length: int = 10,
1152
+ feature_grad_mult: Optional[float] = None,
1153
+ ) -> HuBERTPretrainModel:
1154
+ """Builds "extra large" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.
1155
+
1156
+ Args:
1157
+ encoder_projection_dropout (float):
1158
+ See :py:func:`hubert_pretrain_model`.
1159
+ encoder_attention_dropout (float):
1160
+ See :py:func:`hubert_pretrain_model`.
1161
+ encoder_ff_interm_dropout (float):
1162
+ See :py:func:`hubert_pretrain_model`.
1163
+ encoder_dropout (float):
1164
+ See :py:func:`hubert_pretrain_model`.
1165
+ encoder_layer_drop (float):
1166
+ See :py:func:`hubert_pretrain_model`.
1167
+ mask_prob (float):
1168
+ See :py:func:`hubert_pretrain_model`.
1169
+ mask_channel_prob (float):
1170
+ See :py:func:`hubert_pretrain_model`.
1171
+ mask_channel_length (int):
1172
+ See :py:func:`hubert_pretrain_model`.
1173
+ feature_grad_mult (float or None):
1174
+ See :py:func:`hubert_pretrain_model`.
1175
+
1176
+ Returns:
1177
+ HuBERTPretrainModel:
1178
+ The resulting model.
1179
+ """ # noqa: E501
1180
+ return hubert_pretrain_model(
1181
+ extractor_mode="layer_norm",
1182
+ extractor_conv_layer_config=None,
1183
+ extractor_conv_bias=False,
1184
+ encoder_embed_dim=1280,
1185
+ encoder_projection_dropout=encoder_projection_dropout,
1186
+ encoder_pos_conv_kernel=128,
1187
+ encoder_pos_conv_groups=16,
1188
+ encoder_num_layers=48,
1189
+ encoder_num_heads=16,
1190
+ encoder_attention_dropout=encoder_attention_dropout,
1191
+ encoder_ff_interm_features=5120,
1192
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
1193
+ encoder_dropout=encoder_dropout,
1194
+ encoder_layer_norm_first=True,
1195
+ encoder_layer_drop=encoder_layer_drop,
1196
+ mask_prob=mask_prob,
1197
+ mask_selection="static",
1198
+ mask_other=0.0,
1199
+ mask_length=10,
1200
+ no_mask_overlap=False,
1201
+ mask_min_space=1,
1202
+ mask_channel_prob=mask_channel_prob,
1203
+ mask_channel_selection="static",
1204
+ mask_channel_other=0.0,
1205
+ mask_channel_length=mask_channel_length,
1206
+ no_mask_channel_overlap=False,
1207
+ mask_channel_min_space=1,
1208
+ skip_masked=False,
1209
+ skip_nomask=False,
1210
+ num_classes=500,
1211
+ final_dim=1024,
1212
+ feature_grad_mult=feature_grad_mult,
1213
+ )
1214
+
1215
+
1216
+ def wavlm_model(
1217
+ extractor_mode: str,
1218
+ extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
1219
+ extractor_conv_bias: bool,
1220
+ encoder_embed_dim: int,
1221
+ encoder_projection_dropout: float,
1222
+ encoder_pos_conv_kernel: int,
1223
+ encoder_pos_conv_groups: int,
1224
+ encoder_num_layers: int,
1225
+ encoder_num_heads: int,
1226
+ encoder_num_buckets: int,
1227
+ encoder_max_distance: int,
1228
+ encoder_attention_dropout: float,
1229
+ encoder_ff_interm_features: int,
1230
+ encoder_ff_interm_dropout: float,
1231
+ encoder_dropout: float,
1232
+ encoder_layer_norm_first: bool,
1233
+ encoder_layer_drop: float,
1234
+ aux_num_out: Optional[int],
1235
+ ) -> Wav2Vec2Model:
1236
+ """Builds custom WaveLM model :cite:`chen2022wavlm`. The architecture is compatible
1237
+ with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output object is
1238
+ :class:`~torchaudio.models.Wav2Vec2Model`. Most of the arguments have the same meaning
1239
+ as in :py:func:`~torchaudio.models.wav2vec2_model` so please refer there for documentation.
1240
+
1241
+ Args:
1242
+ extractor_mode (str): Operation mode of feature extractor.
1243
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1244
+
1245
+ extractor_conv_layer_config (list of integer tuples or None):
1246
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1247
+
1248
+ extractor_conv_bias (bool):
1249
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1250
+
1251
+ encoder_embed_dim (int):
1252
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1253
+
1254
+ encoder_projection_dropout (float):
1255
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1256
+
1257
+ encoder_pos_conv_kernel (int):
1258
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1259
+
1260
+ encoder_pos_conv_groups (int):
1261
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1262
+
1263
+ encoder_num_layers (int):
1264
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1265
+
1266
+ encoder_num_heads (int):
1267
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1268
+
1269
+ encoder_num_buckets (int):
1270
+ Number of buckets for relative position embedding.
1271
+ encoder_max_distance (int):
1272
+ Maximum distance for relative position embedding.
1273
+
1274
+ encoder_attention_dropout (float):
1275
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1276
+
1277
+ encoder_ff_interm_features (int):
1278
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1279
+
1280
+ encoder_ff_interm_dropout (float):
1281
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1282
+
1283
+ encoder_dropout (float):
1284
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1285
+
1286
+ encoder_layer_norm_first (bool):
1287
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1288
+
1289
+ encoder_layer_drop (float):
1290
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1291
+
1292
+ aux_num_out (int or None):
1293
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1294
+
1295
+ Returns:
1296
+ Wav2Vec2Model:
1297
+ The resulting model.
1298
+ """
1299
+ if extractor_conv_layer_config is None:
1300
+ extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
1301
+
1302
+ feature_extractor = components._get_feature_extractor(
1303
+ extractor_mode, extractor_conv_layer_config, extractor_conv_bias
1304
+ )
1305
+ encoder = components._get_wavlm_encoder(
1306
+ in_features=extractor_conv_layer_config[-1][0],
1307
+ embed_dim=encoder_embed_dim,
1308
+ dropout_input=encoder_projection_dropout,
1309
+ pos_conv_kernel=encoder_pos_conv_kernel,
1310
+ pos_conv_groups=encoder_pos_conv_groups,
1311
+ num_layers=encoder_num_layers,
1312
+ num_heads=encoder_num_heads,
1313
+ num_buckets=encoder_num_buckets,
1314
+ max_distance=encoder_max_distance,
1315
+ attention_dropout=encoder_attention_dropout,
1316
+ ff_interm_features=encoder_ff_interm_features,
1317
+ ff_interm_dropout=encoder_ff_interm_dropout,
1318
+ dropout=encoder_dropout,
1319
+ layer_norm_first=encoder_layer_norm_first,
1320
+ layer_drop=encoder_layer_drop,
1321
+ )
1322
+ aux = None
1323
+ if aux_num_out is not None:
1324
+ aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out)
1325
+ return Wav2Vec2Model(feature_extractor, encoder, aux)
1326
+
1327
+
1328
+ def wavlm_base(
1329
+ encoder_projection_dropout: float = 0.1,
1330
+ encoder_attention_dropout: float = 0.1,
1331
+ encoder_ff_interm_dropout: float = 0.1,
1332
+ encoder_dropout: float = 0.1,
1333
+ encoder_layer_drop: float = 0.1,
1334
+ aux_num_out: Optional[int] = None,
1335
+ ) -> Wav2Vec2Model:
1336
+ """Builds "base" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible
1337
+ with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
1338
+ :class:`~torchaudio.models.Wav2Vec2Model`.
1339
+
1340
+ Args:
1341
+ encoder_projection_dropout (float):
1342
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1343
+ encoder_attention_dropout (float):
1344
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1345
+ encoder_ff_interm_dropout (float):
1346
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1347
+ encoder_dropout (float):
1348
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1349
+ encoder_layer_drop (float):
1350
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1351
+ aux_num_out (int, optional):
1352
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1353
+
1354
+ Returns:
1355
+ Wav2Vec2Model:
1356
+ The resulting model.
1357
+ """
1358
+ return wavlm_model(
1359
+ extractor_mode="group_norm",
1360
+ extractor_conv_layer_config=None,
1361
+ extractor_conv_bias=False,
1362
+ encoder_embed_dim=768,
1363
+ encoder_projection_dropout=encoder_projection_dropout,
1364
+ encoder_pos_conv_kernel=128,
1365
+ encoder_pos_conv_groups=16,
1366
+ encoder_num_layers=12,
1367
+ encoder_num_heads=12,
1368
+ encoder_num_buckets=320,
1369
+ encoder_max_distance=800,
1370
+ encoder_attention_dropout=encoder_attention_dropout,
1371
+ encoder_ff_interm_features=3072,
1372
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
1373
+ encoder_dropout=encoder_dropout,
1374
+ encoder_layer_norm_first=False,
1375
+ encoder_layer_drop=encoder_layer_drop,
1376
+ aux_num_out=aux_num_out,
1377
+ )
1378
+
1379
+
1380
+ def wavlm_large(
1381
+ encoder_projection_dropout: float = 0.1,
1382
+ encoder_attention_dropout: float = 0.1,
1383
+ encoder_ff_interm_dropout: float = 0.0,
1384
+ encoder_dropout: float = 0.1,
1385
+ encoder_layer_drop: float = 0.1,
1386
+ aux_num_out: Optional[int] = None,
1387
+ ) -> Wav2Vec2Model:
1388
+ """Builds "large" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible
1389
+ with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
1390
+ :class:`~torchaudio.models.Wav2Vec2Model`.
1391
+
1392
+ Args:
1393
+ encoder_projection_dropout (float):
1394
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1395
+ encoder_attention_dropout (float):
1396
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1397
+ encoder_ff_interm_dropout (float):
1398
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1399
+ encoder_dropout (float):
1400
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1401
+ encoder_layer_drop (float):
1402
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1403
+ aux_num_out (int, optional):
1404
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1405
+
1406
+ Returns:
1407
+ Wav2Vec2Model:
1408
+ The resulting model.
1409
+ """
1410
+ return wavlm_model(
1411
+ extractor_mode="layer_norm",
1412
+ extractor_conv_layer_config=None,
1413
+ extractor_conv_bias=False,
1414
+ encoder_embed_dim=1024,
1415
+ encoder_projection_dropout=encoder_projection_dropout,
1416
+ encoder_pos_conv_kernel=128,
1417
+ encoder_pos_conv_groups=16,
1418
+ encoder_num_layers=24,
1419
+ encoder_num_heads=16,
1420
+ encoder_num_buckets=320,
1421
+ encoder_max_distance=800,
1422
+ encoder_attention_dropout=encoder_attention_dropout,
1423
+ encoder_ff_interm_features=4096,
1424
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
1425
+ encoder_dropout=encoder_dropout,
1426
+ encoder_layer_norm_first=True,
1427
+ encoder_layer_drop=encoder_layer_drop,
1428
+ aux_num_out=aux_num_out,
1429
+ )
1430
+
1431
+
1432
+ def wav2vec2_xlsr_300m(
1433
+ encoder_projection_dropout: float = 0.0,
1434
+ encoder_attention_dropout: float = 0.0,
1435
+ encoder_ff_interm_dropout: float = 0.0,
1436
+ encoder_dropout: float = 0.0,
1437
+ encoder_layer_drop: float = 0.0,
1438
+ aux_num_out: Optional[int] = None,
1439
+ ) -> Wav2Vec2Model:
1440
+ """Builds XLS-R model :cite:`babu2021xls` with 300 millions of parameters. The architecture is compatible
1441
+ with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
1442
+ :class:`~torchaudio.models.Wav2Vec2Model`.
1443
+
1444
+ Args:
1445
+ encoder_projection_dropout (float):
1446
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1447
+ encoder_attention_dropout (float):
1448
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1449
+ encoder_ff_interm_dropout (float):
1450
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1451
+ encoder_dropout (float):
1452
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1453
+ encoder_layer_drop (float):
1454
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1455
+ aux_num_out (int, optional):
1456
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1457
+
1458
+ Returns:
1459
+ Wav2Vec2Model:
1460
+ The resulting model.
1461
+ """
1462
+ return wav2vec2_model(
1463
+ extractor_mode="layer_norm",
1464
+ extractor_conv_layer_config=None,
1465
+ extractor_conv_bias=True,
1466
+ encoder_embed_dim=1024,
1467
+ encoder_projection_dropout=encoder_projection_dropout,
1468
+ encoder_pos_conv_kernel=128,
1469
+ encoder_pos_conv_groups=16,
1470
+ encoder_num_layers=24,
1471
+ encoder_num_heads=16,
1472
+ encoder_attention_dropout=encoder_attention_dropout,
1473
+ encoder_ff_interm_features=4096,
1474
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
1475
+ encoder_dropout=encoder_dropout,
1476
+ encoder_layer_norm_first=True,
1477
+ encoder_layer_drop=encoder_layer_drop,
1478
+ aux_num_out=aux_num_out,
1479
+ )
1480
+
1481
+
1482
+ def wav2vec2_xlsr_1b(
1483
+ encoder_projection_dropout: float = 0.1,
1484
+ encoder_attention_dropout: float = 0.0,
1485
+ encoder_ff_interm_dropout: float = 0.0,
1486
+ encoder_dropout: float = 0.0,
1487
+ encoder_layer_drop: float = 0.0,
1488
+ aux_num_out: Optional[int] = None,
1489
+ ) -> Wav2Vec2Model:
1490
+ """Builds XLS-R model :cite:`babu2021xls` with 1 billion of parameters. The architecture is compatible
1491
+ with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
1492
+ :class:`~torchaudio.models.Wav2Vec2Model`.
1493
+
1494
+ Args:
1495
+ encoder_projection_dropout (float):
1496
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1497
+ encoder_attention_dropout (float):
1498
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1499
+ encoder_ff_interm_dropout (float):
1500
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1501
+ encoder_dropout (float):
1502
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1503
+ encoder_layer_drop (float):
1504
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1505
+ aux_num_out (int, optional):
1506
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1507
+
1508
+ Returns:
1509
+ Wav2Vec2Model:
1510
+ The resulting model.
1511
+ """
1512
+ return wav2vec2_model(
1513
+ extractor_mode="layer_norm",
1514
+ extractor_conv_layer_config=None,
1515
+ extractor_conv_bias=True,
1516
+ encoder_embed_dim=1280,
1517
+ encoder_projection_dropout=encoder_projection_dropout,
1518
+ encoder_pos_conv_kernel=128,
1519
+ encoder_pos_conv_groups=16,
1520
+ encoder_num_layers=48,
1521
+ encoder_num_heads=16,
1522
+ encoder_attention_dropout=encoder_attention_dropout,
1523
+ encoder_ff_interm_features=5120,
1524
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
1525
+ encoder_dropout=encoder_dropout,
1526
+ encoder_layer_norm_first=True,
1527
+ encoder_layer_drop=encoder_layer_drop,
1528
+ aux_num_out=aux_num_out,
1529
+ )
1530
+
1531
+
1532
+ def wav2vec2_xlsr_2b(
1533
+ encoder_projection_dropout: float = 0.1,
1534
+ encoder_attention_dropout: float = 0.0,
1535
+ encoder_ff_interm_dropout: float = 0.0,
1536
+ encoder_dropout: float = 0.0,
1537
+ encoder_layer_drop: float = 0.0,
1538
+ aux_num_out: Optional[int] = None,
1539
+ ) -> Wav2Vec2Model:
1540
+ """Builds XLS-R model :cite:`babu2021xls` with 2 billions of parameters. The architecture is compatible
1541
+ with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is
1542
+ :class:`~torchaudio.models.Wav2Vec2Model`.
1543
+
1544
+ Args:
1545
+ encoder_projection_dropout (float):
1546
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1547
+ encoder_attention_dropout (float):
1548
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1549
+ encoder_ff_interm_dropout (float):
1550
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1551
+ encoder_dropout (float):
1552
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1553
+ encoder_layer_drop (float):
1554
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1555
+ aux_num_out (int, optional):
1556
+ See :py:func:`~torchaudio.models.wav2vec2_model`.
1557
+
1558
+ Returns:
1559
+ Wav2Vec2Model:
1560
+ The resulting model.
1561
+ """
1562
+ return wav2vec2_model(
1563
+ extractor_mode="layer_norm",
1564
+ extractor_conv_layer_config=None,
1565
+ extractor_conv_bias=True,
1566
+ encoder_embed_dim=1920,
1567
+ encoder_projection_dropout=encoder_projection_dropout,
1568
+ encoder_pos_conv_kernel=128,
1569
+ encoder_pos_conv_groups=16,
1570
+ encoder_num_layers=48,
1571
+ encoder_num_heads=16,
1572
+ encoder_attention_dropout=encoder_attention_dropout,
1573
+ encoder_ff_interm_features=7680,
1574
+ encoder_ff_interm_dropout=encoder_ff_interm_dropout,
1575
+ encoder_dropout=encoder_dropout,
1576
+ encoder_layer_norm_first=True,
1577
+ encoder_layer_drop=encoder_layer_drop,
1578
+ aux_num_out=aux_num_out,
1579
+ )
.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .import_fairseq import import_fairseq_model
2
+ from .import_huggingface import import_huggingface_model
3
+
4
+ __all__ = [
5
+ "import_huggingface_model",
6
+ "import_fairseq_model",
7
+ ]
.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (401 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_fairseq.cpython-311.pyc ADDED
Binary file (12.1 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/__pycache__/import_huggingface.cpython-311.pyc ADDED
Binary file (7.91 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/wavlm_attention.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The MIT License (MIT)
3
+
4
+ Copyright (c) Microsoft Corporation
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
23
+ """
24
+
25
+ import math
26
+ from typing import Optional, Tuple
27
+
28
+ import torch
29
+ from torch import nn, Tensor
30
+
31
+
32
class WavLMSelfAttention(nn.Module):
    """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`.

    Wraps around ``torch.nn.MultiheadAttention``, creating relative position embeddings and passing them to
    multi-headed attention as a mask.
    Source: https://github.com/microsoft/unilm/blob/2d8302f09c99bca2b82e6e868d81d4281cceebc8/wavlm/modules.py#L303-L763

    Args:
        embed_dim (int): Total dimension of the model.
        num_heads (int): The number of heads.
        dropout (float, optional): Dropout probability on attn_output_weights. (Default: ``0.0``)
        bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``)
        has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding.
            Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``)
        num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``)
        max_distance (int, optional): Maximum distance for relative position embedding. (Default: ``128``)
        gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. (Default: ``True``)
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 32,
        max_distance: int = 128,
        gru_rel_pos: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance

        # One learned bias per (bucket, head); only the first encoder layer creates it.
        if has_relative_attention_bias:
            self.rel_attn_embed = nn.Embedding(num_buckets, num_heads)
        else:
            self.rel_attn_embed = None

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.dropout = dropout
        # The wrapped attention module supplies in_proj / out_proj weights; its forward is
        # bypassed below in favor of scaled_dot_product_attention.
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True)

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            # Gating network for the gated relative position bias (Eq. (6) in the WavLM paper).
            self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
            self.gru_rel_pos_const = nn.Parameter(torch.ones(1, num_heads, 1, 1))
        self.has_position_bias = True

    def compute_bias(self, query_length: int, key_length: int) -> Tensor:
        """Compute relative position embeddings for WavLM model.

        Args:
            query_length (int): Query position can take values between 0 and ``query_length - 1``.
            key_length (int): Key position can take values between 0 and ``key_length - 1``.

        Returns:
            Tensor of shape `(num_heads, query_length, key_length)`, relative positions embeddings
        """
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # Shape (query_length, key_length)
        relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
        relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
        values = self.rel_attn_embed(relative_position_bucket)  # Shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1])
        return values

    def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True):
        """Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM
        paper :cite:`chen2022wavlm`.

        Small offsets get one bucket each; larger offsets share logarithmically-spaced buckets
        up to ``max_distance``.

        Args:
            relative_positions (Tensor): Relative offsets between query and key positions,
                of shape ``(query_length, key_length)``.
            bidirectional (bool): If ``True``, values will be filled both above and below the diagonal in the resulting
                matrix. If ``False``, the elements above the diagonal (i.e. with negative relative offsets) will be set
                to zero. (Default ``True``)

        Returns:
            Tensor of shape ``(query_length, key_length)`` filled bucketed values of with relative positions.
        """
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        # Shape (query_length, key_length)
        relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long)

        if bidirectional:
            # Half the buckets are reserved for positive offsets, half for negative ones.
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

        # Offsets below max_exact map one-to-one onto buckets; beyond that, log spacing.
        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        relative_postion_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        # Clamp to the last bucket so offsets past max_distance share one bucket.
        relative_postion_if_large = torch.min(
            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
        return relative_buckets

    def forward(
        self,
        query: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``.
            key_padding_mask (Tensor or None, optional): Mask to exclude keys that are pads, of shape
                `(batch, src_len)`, where padding elements are indicated by 1s. (Default: ``None``)
            attention_mask: Needs to be ``None``. The argument exists for compatibility with
                ``EncoderLayer``. (Default: ``None``)
            position_bias (Tensor or None, optional): Position bias of shape
                ``(batch_size * num_heads, src_len, src_len)``. When used inside WavLM model encoder, will be
                generated in the first layer and then passed from each encoder layer to the next one.
                (Default: ``None``)

        Returns:
            attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``.
            position_bias (Tensor or None): Position bias of shape ``(batch_size * num_heads, src_len, src_len)``.
        """
        bsz, seq_len, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert attention_mask is None

        # First layer: build the relative position bias and broadcast it over the batch.
        # Subsequent layers reuse the bias passed in via ``position_bias``.
        if self.rel_attn_embed is not None and position_bias is None:
            position_bias = self.compute_bias(seq_len, seq_len)
            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1)

        attn_mask_rel_pos: Optional[Tensor] = None
        if position_bias is not None:
            attn_mask_rel_pos = position_bias
            if self.gru_rel_pos:  # Apply gating on relative position bias
                query_layer = query.view(bsz, seq_len, self.num_heads, -1)
                query_layer = query_layer.permute(0, 2, 1, 3)

                # Project each head's query to 8 values, sum in pairs of 4 and squash,
                # yielding two per-position gates (a, b).
                gate_a, gate_b = torch.sigmoid(
                    self.gru_rel_pos_linear(query_layer).view(bsz, self.num_heads, seq_len, 2, 4).sum(-1, keepdim=False)
                ).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
                attn_mask_rel_pos = gate_a_1.view(bsz, self.num_heads, -1, 1) * position_bias

            attn_mask_rel_pos = attn_mask_rel_pos.view((bsz, self.num_heads, seq_len, seq_len))

        if attn_mask_rel_pos is not None and key_padding_mask is not None:
            # Convert the boolean padding mask to a float additive mask (-inf at padded keys)
            # so it can be merged into the attention mask below.
            key_padding_mask = key_padding_mask.view(bsz, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
            key_padding_mask = torch.nn.functional._canonical_mask(
                mask=key_padding_mask,
                mask_name="key_padding_mask",
                other_type=torch.nn.functional._none_or_dtype(attn_mask_rel_pos),
                other_name="",
                target_type=query.dtype,
            )
        if attn_mask_rel_pos is not None and key_padding_mask is not None:
            attn_mask_rel_pos = attn_mask_rel_pos + key_padding_mask
        # Manual in-projection using the wrapped MultiheadAttention's packed QKV weights,
        # followed by the fused attention kernel.
        query_projected = torch.nn.functional.linear(query, self.attention.in_proj_weight, self.attention.in_proj_bias)
        query, key, value = query_projected.chunk(3, -1)
        shape = (bsz, seq_len, self.num_heads, self.head_dim)
        query = query.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
        key = key.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
        value = value.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
        # Attention dropout is only active in training mode.
        dropout = self.dropout if self.training else 0.0
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask_rel_pos,
            dropout_p=dropout,
            is_causal=False,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, -1, self.num_heads * self.head_dim)
        attn_output = self.attention.out_proj(attn_output)
        return attn_output, position_bias
.venv/lib/python3.11/site-packages/torchaudio/models/wavernn.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn, Tensor
7
+
8
+ __all__ = [
9
+ "ResBlock",
10
+ "MelResNet",
11
+ "Stretch2d",
12
+ "UpsampleNetwork",
13
+ "WaveRNN",
14
+ ]
15
+
16
+
17
class ResBlock(nn.Module):
    r"""ResNet block based on *Efficient Neural Audio Synthesis* :cite:`kalchbrenner2018efficient`.

    Two pointwise (kernel size 1) convolutions, each followed by batch normalization with a
    ReLU in between, whose output is added back onto the input (residual connection).

    Args:
        n_freq: the number of bins in a spectrogram. (Default: ``128``)

    Examples
        >>> resblock = ResBlock()
        >>> input = torch.rand(10, 128, 512)  # a random spectrogram
        >>> output = resblock(input)  # shape: (10, 128, 512)
    """

    def __init__(self, n_freq: int = 128) -> None:
        super().__init__()

        layers = [
            nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
            nn.BatchNorm1d(n_freq),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
            nn.BatchNorm1d(n_freq),
        ]
        # Attribute name is kept stable so pretrained state dicts remain loadable.
        self.resblock_model = nn.Sequential(*layers)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the ResBlock layer.

        Args:
            specgram (Tensor): the input sequence to the ResBlock layer (n_batch, n_freq, n_time).

        Return:
            Tensor shape: (n_batch, n_freq, n_time)
        """
        residual = self.resblock_model(specgram)
        return residual + specgram
50
+
51
+
52
class MelResNet(nn.Module):
    r"""MelResNet layer uses a stack of ResBlocks on spectrogram.

    An un-padded entry convolution (so the time axis shrinks by ``kernel_size - 1``) followed
    by ``n_res_block`` residual blocks and a pointwise output convolution.

    Args:
        n_res_block: the number of ResBlock in stack. (Default: ``10``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
        n_output: the number of output dimensions of melresnet. (Default: ``128``)
        kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)

    Examples
        >>> melresnet = MelResNet()
        >>> input = torch.rand(10, 128, 512)  # a random spectrogram
        >>> output = melresnet(input)  # shape: (10, 128, 508)
    """

    def __init__(
        self, n_res_block: int = 10, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128, kernel_size: int = 5
    ) -> None:
        super().__init__()

        residual_stack = [ResBlock(n_hidden) for _ in range(n_res_block)]

        # Attribute name is kept stable so pretrained state dicts remain loadable.
        self.melresnet_model = nn.Sequential(
            nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
            *residual_stack,
            nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1),
        )

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the MelResNet layer.

        Args:
            specgram (Tensor): the input sequence to the MelResNet layer (n_batch, n_freq, n_time).

        Return:
            Tensor shape: (n_batch, n_output, n_time - kernel_size + 1)
        """
        output = self.melresnet_model(specgram)
        return output
93
+
94
+
95
class Stretch2d(nn.Module):
    r"""Upscale the frequency and time dimensions of a spectrogram.

    Nearest-neighbor upsampling: every element is repeated ``freq_scale`` times along the
    frequency axis and ``time_scale`` times along the time axis.

    Args:
        time_scale: the scale factor in time dimension
        freq_scale: the scale factor in frequency dimension

    Examples
        >>> stretch2d = Stretch2d(time_scale=10, freq_scale=5)

        >>> input = torch.rand(10, 100, 512)  # a random spectrogram
        >>> output = stretch2d(input)  # shape: (10, 500, 5120)
    """

    def __init__(self, time_scale: int, freq_scale: int) -> None:
        super().__init__()

        self.freq_scale = freq_scale
        self.time_scale = time_scale

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the Stretch2d layer.

        Args:
            specgram (Tensor): the input sequence to the Stretch2d layer (..., n_freq, n_time).

        Return:
            Tensor shape: (..., n_freq * freq_scale, n_time * time_scale)
        """
        # The two axes are independent, so the order of the repeats does not matter.
        stretched = specgram.repeat_interleave(self.time_scale, dim=-1)
        return stretched.repeat_interleave(self.freq_scale, dim=-2)
126
+
127
+
128
class UpsampleNetwork(nn.Module):
    r"""Upscale the dimensions of a spectrogram.

    Combines a :class:`MelResNet` conditioning branch (stretched to the full sample rate)
    with a learned nearest-neighbor + smoothing-convolution upsampling of the spectrogram.

    Args:
        upsample_scales: the list of upsample scales.
        n_res_block: the number of ResBlock in stack. (Default: ``10``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
        n_output: the number of output dimensions of melresnet. (Default: ``128``)
        kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)

    Examples
        >>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16])
        >>> input = torch.rand(10, 128, 10)  # a random spectrogram
        >>> output = upsamplenetwork(input)  # shape: (10, 128, 1536), (10, 128, 1536)
    """

    def __init__(
        self,
        upsample_scales: List[int],
        n_res_block: int = 10,
        n_freq: int = 128,
        n_hidden: int = 128,
        n_output: int = 128,
        kernel_size: int = 5,
    ) -> None:
        super().__init__()

        total_scale = 1
        for factor in upsample_scales:
            total_scale *= factor
        self.total_scale: int = total_scale

        # Frames lost by MelResNet's valid convolution, expressed in output samples;
        # trimmed symmetrically from the upsampled spectrogram in forward().
        self.indent = (kernel_size - 1) // 2 * total_scale
        self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
        self.resnet_stretch = Stretch2d(total_scale, 1)

        layers = []
        for factor in upsample_scales:
            # Each stage: nearest-neighbor stretch followed by a fixed-init averaging conv
            # that smooths the repeated samples along the time axis.
            smoothing = nn.Conv2d(
                in_channels=1, out_channels=1, kernel_size=(1, factor * 2 + 1), padding=(0, factor), bias=False
            )
            torch.nn.init.constant_(smoothing.weight, 1.0 / (factor * 2 + 1))
            layers.append(Stretch2d(factor, 1))
            layers.append(smoothing)
        self.upsample_layers = nn.Sequential(*layers)

    def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]:
        r"""Pass the input through the UpsampleNetwork layer.

        Args:
            specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time)

        Return:
            Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale),
            (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
        where total_scale is the product of all elements in upsample_scales.
        """
        # Conditioning branch: MelResNet features stretched to one value per output sample.
        aux = self.resnet(specgram).unsqueeze(1)
        aux = self.resnet_stretch(aux)
        aux = aux.squeeze(1)

        # Upsampling branch: add a channel dim for Conv2d, upsample, then trim the borders
        # so both branches cover the same time span.
        upsampled = self.upsample_layers(specgram.unsqueeze(1))
        upsampled = upsampled.squeeze(1)[:, :, self.indent : -self.indent]

        return upsampled, aux
197
+
198
+
199
class WaveRNN(nn.Module):
    r"""WaveRNN model from *Efficient Neural Audio Synthesis* :cite:`wavernn`
    based on the implementation from `fatchord/WaveRNN <https://github.com/fatchord/WaveRNN>`_.

    The original implementation was introduced in *Efficient Neural Audio Synthesis*
    :cite:`kalchbrenner2018efficient`. The input channels of waveform and spectrogram have to be 1.
    The product of `upsample_scales` must equal `hop_length`.

    See Also:
        * `Training example <https://github.com/pytorch/audio/tree/release/0.12/examples/pipeline_wavernn>`__
        * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.

    Args:
        upsample_scales: the list of upsample scales.
        n_classes: the number of output classes.
        hop_length: the number of samples between the starts of consecutive frames.
        n_res_block: the number of ResBlock in stack. (Default: ``10``)
        n_rnn: the dimension of RNN layer. (Default: ``512``)
        n_fc: the dimension of fully connected layer. (Default: ``512``)
        kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
        n_output: the number of output dimensions of melresnet. (Default: ``128``)

    Example
        >>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
        >>> waveform, sample_rate = torchaudio.load(file)
        >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
        >>> specgram = MelSpectrogram(sample_rate)(waveform)  # shape: (n_batch, n_channel, n_freq, n_time)
        >>> output = wavernn(waveform, specgram)
        >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
    """

    def __init__(
        self,
        upsample_scales: List[int],
        n_classes: int,
        hop_length: int,
        n_res_block: int = 10,
        n_rnn: int = 512,
        n_fc: int = 512,
        kernel_size: int = 5,
        n_freq: int = 128,
        n_hidden: int = 128,
        n_output: int = 128,
    ) -> None:
        super().__init__()

        self.kernel_size = kernel_size
        # Spectrogram padding used by infer() to compensate for MelResNet's valid convolution.
        self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2
        self.n_rnn = n_rnn
        # The conditioning features from MelResNet are split into 4 equal parts (a1..a4),
        # one injected at each stage of the network.
        self.n_aux = n_output // 4
        self.hop_length = hop_length
        self.n_classes = n_classes
        # n_classes is expected to be a power of two; n_bits is used to map class labels
        # back to waveform amplitudes in infer().
        self.n_bits: int = int(math.log2(self.n_classes))

        total_scale = 1
        for upsample_scale in upsample_scales:
            total_scale *= upsample_scale
        if total_scale != self.hop_length:
            raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")

        self.upsample = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
        # +1 input feature for the previous waveform sample.
        self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)

        self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
        self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)

        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

        self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
        self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
        self.fc3 = nn.Linear(n_fc, self.n_classes)

    def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
        r"""Pass the input through the WaveRNN model.

        Args:
            waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
            specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time)

        Return:
            Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
        """

        if waveform.size(1) != 1:
            raise ValueError("Require the input channel of waveform is 1")
        if specgram.size(1) != 1:
            raise ValueError("Require the input channel of specgram is 1")
        # remove channel dimension until the end
        waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)

        batch_size = waveform.size(0)
        # Fresh zero hidden states for both GRUs each forward pass (teacher-forced training).
        h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
        h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
        # output of upsample:
        # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
        # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
        specgram, aux = self.upsample(specgram)
        specgram = specgram.transpose(1, 2)
        aux = aux.transpose(1, 2)

        # Split the conditioning features into 4 equal chunks along the feature axis.
        aux_idx = [self.n_aux * i for i in range(5)]
        a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
        a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
        a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
        a4 = aux[:, :, aux_idx[3] : aux_idx[4]]

        # Stage 1: waveform sample + spectrogram frame + a1 -> GRU1 with residual connection.
        x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
        x = self.fc(x)
        res = x
        x, _ = self.rnn1(x, h1)

        # Stage 2: residual + a2 -> GRU2 with residual connection.
        x = x + res
        res = x
        x = torch.cat([x, a2], dim=-1)
        x, _ = self.rnn2(x, h2)

        # Stage 3/4: two fully connected layers, each fed another conditioning chunk.
        x = x + res
        x = torch.cat([x, a3], dim=-1)
        x = self.fc1(x)
        x = self.relu1(x)

        x = torch.cat([x, a4], dim=-1)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)

        # bring back channel dimension
        return x.unsqueeze(1)

    @torch.jit.export
    def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Inference method of WaveRNN.

        Generates a waveform one sample at a time, feeding each sampled value back into the
        network. This function currently only supports multinomial sampling, which assumes the
        network is trained on cross entropy loss.

        Args:
            specgram (Tensor):
                Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
            lengths (Tensor or None, optional):
                Indicates the valid length of each audio in the batch.
                Shape: `(batch, )`.
                When the ``specgram`` contains spectrograms with different durations,
                by providing ``lengths`` argument, the model will compute
                the corresponding valid output lengths.
                If ``None``, it is assumed that all the audio in ``waveforms``
                have valid length. Default: ``None``.

        Returns:
            (Tensor, Optional[Tensor]):
            Tensor
                The inferred waveform of size `(n_batch, 1, n_time)`.
                1 stands for a single channel.
            Tensor or None
                If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                is returned.
                It indicates the valid length in time axis of the output Tensor.
        """

        device = specgram.device
        dtype = specgram.dtype

        # Pad the time axis so the valid convolution in MelResNet keeps all input frames.
        specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad))
        specgram, aux = self.upsample(specgram)
        if lengths is not None:
            lengths = lengths * self.upsample.total_scale

        output: List[Tensor] = []
        b_size, _, seq_len = specgram.size()

        h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
        h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
        # Previous waveform sample, initialized to silence.
        x = torch.zeros((b_size, 1), device=device, dtype=dtype)

        aux_split = [aux[:, self.n_aux * i : self.n_aux * (i + 1), :] for i in range(4)]

        # Autoregressive loop: one output sample per upsampled spectrogram frame.
        for i in range(seq_len):

            m_t = specgram[:, :, i]

            a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]

            x = torch.cat([x, m_t, a1_t], dim=1)
            x = self.fc(x)
            # Hidden states are carried across timesteps; only the new state is kept.
            _, h1 = self.rnn1(x.unsqueeze(1), h1)

            x = x + h1[0]
            inp = torch.cat([x, a2_t], dim=1)
            _, h2 = self.rnn2(inp.unsqueeze(1), h2)

            x = x + h2[0]
            x = torch.cat([x, a3_t], dim=1)
            x = F.relu(self.fc1(x))

            x = torch.cat([x, a4_t], dim=1)
            x = F.relu(self.fc2(x))

            logits = self.fc3(x)

            posterior = F.softmax(logits, dim=1)

            # Sample the next waveform class and feed it back as the next input.
            x = torch.multinomial(posterior, 1).float()
            # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
            x = 2 * x / (2**self.n_bits - 1.0) - 1.0

            output.append(x)

        return torch.stack(output).permute(1, 2, 0), lengths
.venv/lib/python3.11/site-packages/torchaudio/prototype/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torchaudio/prototype/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .musan import Musan
2
+
3
+
4
+ __all__ = ["Musan"]
.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (278 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/__pycache__/musan.cpython-311.pyc ADDED
Binary file (3.72 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/prototype/datasets/musan.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Tuple, Union
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from torchaudio.datasets.utils import _load_waveform
7
+
8
+
9
_SUBSETS = ["music", "noise", "speech"]
_SAMPLE_RATE = 16_000


class Musan(Dataset):
    r"""*MUSAN* :cite:`musan2015` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top-level directory exists.
        subset (str): Subset of the dataset to use. Options: [``"music"``, ``"noise"``, ``"speech"``].
    """

    def __init__(self, root: Union[str, Path], subset: str):
        if subset not in _SUBSETS:
            raise ValueError(f"Invalid subset '{subset}' given. Please provide one of {_SUBSETS}")

        # Audio files sit two levels below root: <root>/<subset>/<source>/<file>.
        base_path = Path(root) / subset
        self._walker = [str(audio_file) for audio_file in base_path.glob("*/*.*")]

    def get_metadata(self, n: int) -> Tuple[str, int, str]:
        r"""Get metadata for the n-th sample in the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): Index of sample to be loaded.

        Returns:
            (str, int, str):
            str
                Path to audio.
            int
                Sample rate.
            str
                File name.
        """
        filepath = self._walker[n]
        return filepath, _SAMPLE_RATE, Path(filepath).name

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str]:
        r"""Return the n-th sample in the dataset.

        Args:
            n (int): Index of sample to be loaded.

        Returns:
            (torch.Tensor, int, str):
            torch.Tensor
                Waveform.
            int
                Sample rate.
            str
                File name.
        """
        filepath, sample_rate, filename = self.get_metadata(n)
        audio_path = Path(filepath)
        waveform = _load_waveform(audio_path.parent, audio_path.name, sample_rate)
        return waveform, sample_rate, filename

    def __len__(self) -> int:
        return len(self._walker)
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._dsp import (
2
+ adsr_envelope,
3
+ exp_sigmoid,
4
+ extend_pitch,
5
+ filter_waveform,
6
+ frequency_impulse_response,
7
+ oscillator_bank,
8
+ sinc_impulse_response,
9
+ )
10
+ from ._rir import ray_tracing, simulate_rir_ism
11
+ from .functional import barkscale_fbanks, chroma_filterbank
12
+
13
+
14
+ __all__ = [
15
+ "adsr_envelope",
16
+ "exp_sigmoid",
17
+ "barkscale_fbanks",
18
+ "chroma_filterbank",
19
+ "extend_pitch",
20
+ "filter_waveform",
21
+ "frequency_impulse_response",
22
+ "oscillator_bank",
23
+ "ray_tracing",
24
+ "sinc_impulse_response",
25
+ "simulate_rir_ism",
26
+ ]
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (779 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_dsp.cpython-311.pyc ADDED
Binary file (20.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/_rir.cpython-311.pyc ADDED
Binary file (21.5 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/__pycache__/functional.cpython-311.pyc ADDED
Binary file (8.67 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_dsp.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import List, Optional, Union
3
+
4
+ import torch
5
+
6
+ from torchaudio.functional import fftconvolve
7
+
8
+
9
def oscillator_bank(
    frequencies: torch.Tensor,
    amplitudes: torch.Tensor,
    sample_rate: float,
    reduction: str = "sum",
    dtype: Optional[torch.dtype] = torch.float64,
) -> torch.Tensor:
    """Synthesize a waveform from per-sample oscillator frequencies and amplitudes.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Note:
        The instantaneous phase is obtained as the cumulative sum of the given
        instantaneous frequencies (``frequencies``), which accumulates roundoff
        error in low-precision dtypes. ``torch.float64`` (the default ``dtype``)
        avoids the resulting artifacts.

        .. image:: https://download.pytorch.org/torchaudio/doc-assets/oscillator_precision.png

    Args:
        frequencies (Tensor): Sample-wise oscillator frequencies (Hz). Shape `(..., time, N)`.
        amplitudes (Tensor): Sample-wise oscillator amplitude. Shape: `(..., time, N)`.
        sample_rate (float): Sample rate
        reduction (str): Reduction to perform.
            Valid values are ``"sum"``, ``"mean"`` or ``"none"``. Default: ``"sum"``
        dtype (torch.dtype or None, optional): The data type on which cumulative sum operation is performed.
            Default: ``torch.float64``. Pass ``None`` to disable the casting.

    Returns:
        Tensor:
            The resulting waveform. Shape `(..., time, N)` when ``reduction`` is
            ``"none"``, otherwise `(..., time)`.
    """
    if frequencies.shape != amplitudes.shape:
        raise ValueError(
            "The shapes of `frequencies` and `amplitudes` must match. "
            f"Found: {frequencies.shape} and {amplitudes.shape} respectively."
        )
    reductions = ["sum", "mean", "none"]
    if reduction not in reductions:
        raise ValueError(f"The value of reduction must be either {reductions}. Found: {reduction}")

    # Oscillators above the Nyquist limit would alias; mute them instead of
    # wrapping, but keep the op differentiable via torch.where.
    above_nyquist = torch.abs(frequencies) >= sample_rate / 2
    if torch.any(above_nyquist):
        warnings.warn(
            "Some frequencies are above nyquist frequency. "
            "Setting the corresponding amplitude to zero. "
            "This might cause numerically unstable gradient."
        )
        amplitudes = torch.where(above_nyquist, 0.0, amplitudes)

    two_pi = 2.0 * torch.pi
    # Angular increment per sample, wrapped into [0, 2*pi).
    angular = frequencies * two_pi / sample_rate % two_pi
    # Integrate in (optionally) higher precision, then cast back.
    phases = torch.cumsum(angular, dim=-2, dtype=dtype)
    if dtype is not None and angular.dtype != dtype:
        phases = phases.to(angular.dtype)

    waveform = amplitudes * torch.sin(phases)
    if reduction == "none":
        return waveform
    return waveform.sum(-1) if reduction == "sum" else waveform.mean(-1)
82
+
83
+
84
def adsr_envelope(
    num_frames: int,
    *,
    attack: float = 0.0,
    hold: float = 0.0,
    decay: float = 0.0,
    sustain: float = 1.0,
    release: float = 0.0,
    n_decay: int = 2,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Generate ADSR Envelope

    .. devices:: CPU CUDA

    Args:
        num_frames (int): The number of output frames.
        attack (float, optional):
            The relative *time* it takes to reach the maximum level from
            the start. (Default: ``0.0``)
        hold (float, optional):
            The relative *time* the maximum level is held before
            it starts to decay. (Default: ``0.0``)
        decay (float, optional):
            The relative *time* it takes to sustain from
            the maximum level. (Default: ``0.0``)
        sustain (float, optional): The relative *level* at which
            the sound should sustain. (Default: ``1.0``)

            .. Note::
                The duration of sustain is derived as `1.0 - (The sum of attack, hold, decay and release)`.

        release (float, optional): The relative *time* it takes for the sound level to
            reach zero after the sustain. (Default: ``0.0``)
        n_decay (int, optional): The degree of polynomial decay. Default: ``2``.
        dtype (torch.dtype, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default
            (see :py:func:`torch.set_default_tensor_type`).
        device (torch.device, optional): the desired device of returned tensor.
            Default: if ``None``, uses the current device for the default tensor type
            (see :py:func:`torch.set_default_tensor_type`).
            device will be the CPU for CPU tensor types and the current CUDA
            device for CUDA tensor types.

    Returns:
        Tensor: ADSR Envelope. Shape: `(num_frames, )`

    Example
        .. image:: https://download.pytorch.org/torchaudio/doc-assets/adsr_examples.png

    """
    if not 0 <= attack <= 1:
        raise ValueError(f"The value of `attack` must be within [0, 1]. Found: {attack}")
    if not 0 <= decay <= 1:
        raise ValueError(f"The value of `decay` must be within [0, 1]. Found: {decay}")
    if not 0 <= sustain <= 1:
        raise ValueError(f"The value of `sustain` must be within [0, 1]. Found: {sustain}")
    if not 0 <= hold <= 1:
        raise ValueError(f"The value of `hold` must be within [0, 1]. Found: {hold}")
    if not 0 <= release <= 1:
        raise ValueError(f"The value of `release` must be within [0, 1]. Found: {release}")
    if attack + decay + release + hold > 1:
        raise ValueError("The sum of `attack`, `hold`, `decay` and `release` must not exceed 1.")

    # Segment lengths are fractions of the span between the first and last frame.
    nframes = num_frames - 1
    num_a = int(nframes * attack)
    num_h = int(nframes * hold)
    num_d = int(nframes * decay)
    num_r = int(nframes * release)

    # Initialize with sustain
    out = torch.full((num_frames,), float(sustain), device=device, dtype=dtype)

    # attack: linear ramp 0 -> 1, written in place into the head of `out`.
    if num_a > 0:
        torch.linspace(0.0, 1.0, num_a + 1, out=out[: num_a + 1])

    # hold
    if num_h > 0:
        out[num_a : num_a + num_h + 1] = 1.0

    # decay
    if num_d > 0:
        # Compute: sustain + (1.0 - sustain) * (linspace[1, 0] ** n_decay)
        # NOTE: `decay` (the parameter) is deliberately rebound to the output
        # slice here; the float value is no longer needed at this point.
        i = num_a + num_h
        decay = out[i : i + num_d + 1]
        torch.linspace(1.0, 0.0, num_d + 1, out=decay)
        decay **= n_decay
        decay *= 1.0 - sustain
        decay += sustain

    # sustain is handled by initialization

    # release: linear ramp sustain -> 0 over the tail of `out`.
    if num_r > 0:
        torch.linspace(sustain, 0, num_r + 1, out=out[-num_r - 1 :])

    return out
183
+
184
+
185
def extend_pitch(
    base: torch.Tensor,
    pattern: Union[int, List[float], torch.Tensor],
):
    """Append multiples of a base time series as additional channels.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Given a series of fundamental frequencies (pitch), this produces the
    harmonic overtones (integer ``pattern``) or arbitrary inharmonic partials
    (list/Tensor ``pattern``) of that series.

    Args:
        base (torch.Tensor):
            Base time series, like fundamental frequencies (Hz). Shape: `(..., time, 1)`.
        pattern (int, list of floats or torch.Tensor):
            If ``int``, the number of pitch series after the operation.
            `pattern - 1` tones are added, so that the resulting Tensor contains
            up to `pattern`-th overtones of the given series.

            If list of float or ``torch.Tensor``, it must be one dimensional,
            representing the custom multiplier of the fundamental frequency.

    Returns:
        Tensor: Oscillator frequencies (Hz). Shape: `(..., time, num_tones)`.

    Example
        >>> f0 = torch.linspace(1, 5, 5).unsqueeze(-1)
        >>> extend_pitch(f0, 3).shape
        torch.Size([5, 3])
    """
    # Normalize `pattern` into a 1D multiplier tensor on base's device/dtype.
    if isinstance(pattern, int):
        factors = torch.linspace(1.0, float(pattern), pattern, device=base.device, dtype=base.dtype)
    elif isinstance(pattern, torch.Tensor):
        factors = pattern
    else:
        factors = torch.tensor(pattern, dtype=base.dtype, device=base.device)
    # Outer product along the last axis: (..., time, 1) @ (1, num_tones).
    return base @ factors.unsqueeze(0)
250
+
251
+
252
def sinc_impulse_response(cutoff: torch.Tensor, window_size: int = 513, high_pass: bool = False):
    """Build Hamming-windowed sinc impulse responses for the given cutoffs.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        cutoff (Tensor): Cutoff frequencies for low-pass sinc filter.
        window_size (int, optional): Size of the Hamming window to apply. Must be odd.
            (Default: 513)
        high_pass (bool, optional):
            If ``True``, convert the resulting filter to high-pass.
            Otherwise low-pass filter is returned. Default: ``False``.

    Returns:
        Tensor: A series of impulse responses. Shape: `(..., window_size)`.
    """
    if window_size % 2 == 0:
        raise ValueError(f"`window_size` must be odd. Given: {window_size}")

    half = window_size // 2
    device, dtype = cutoff.device, cutoff.dtype
    # Symmetric sample grid [-half, ..., half].
    taps = torch.linspace(-half, half, window_size, device=device, dtype=dtype)

    window = torch.hamming_window(window_size, device=device, dtype=dtype, periodic=False)
    kernel = torch.special.sinc(cutoff.unsqueeze(-1) * taps.unsqueeze(0)) * window.unsqueeze(0)
    # Normalize so that the DC gain is (sign-wise) unity.
    kernel = kernel / kernel.sum(dim=-1, keepdim=True).abs()

    if high_pass:
        # High pass IR is obtained by subtracting low_pass IR from delta function.
        # https://courses.engr.illinois.edu/ece401/fa2020/slides/lec10.pdf
        kernel = -kernel
        kernel[..., half] = 1.0 + kernel[..., half]
    return kernel
289
+
290
+
291
def frequency_impulse_response(magnitudes):
    """Create filter from desired frequency response.

    Args:
        magnitudes: The desired frequency responses. Shape: `(..., num_fft_bins)`

    Returns:
        Tensor: Impulse response. Shape `(..., 2 * (num_fft_bins - 1))`
    """
    if magnitudes.min() < 0.0:
        # Negative magnitude does not make sense, but only warn so that
        # autograd keeps working around 0.
        warnings.warn("The input frequency response should not contain negative values.")
    # Inverse real FFT gives the IR; fftshift centers its main lobe.
    impulse = torch.fft.fftshift(torch.fft.irfft(magnitudes), dim=-1)
    taper = torch.hann_window(
        impulse.size(-1), periodic=False, device=magnitudes.device, dtype=magnitudes.dtype
    )
    return impulse * taper.expand_as(impulse)
309
+
310
+
311
+ def _overlap_and_add(waveform, stride):
312
+ num_frames, frame_size = waveform.shape[-2:]
313
+ numel = (num_frames - 1) * stride + frame_size
314
+ buffer = torch.zeros(waveform.shape[:-2] + (numel,), device=waveform.device, dtype=waveform.dtype)
315
+ for i in range(num_frames):
316
+ start = i * stride
317
+ end = start + frame_size
318
+ buffer[..., start:end] += waveform[..., i, :]
319
+ return buffer
320
+
321
+
322
def filter_waveform(waveform: torch.Tensor, kernels: torch.Tensor, delay_compensation: int = -1):
    """Applies filters along time axis of the given waveform.

    This function applies the given filters along time axis in the following manner:

    1. Split the given waveform into chunks. The number of chunks is equal to the number of given filters.
    2. Filter each chunk with corresponding filter.
    3. Place the filtered chunks at the original indices while adding up the overlapping parts.
    4. Crop the resulting waveform so that delay introduced by the filter is removed and its length
       matches that of the input waveform.

    The following figure illustrates this.

    .. image:: https://download.pytorch.org/torchaudio/doc-assets/filter_waveform.png

    .. note::

       If the number of filters is one, then the operation becomes stationary.
       i.e. the same filtering is applied across the time axis.

    Args:
        waveform (Tensor): Shape `(..., time)`.
        kernels (Tensor): Impulse responses.
            Valid inputs are 2D tensor with shape `(num_filters, filter_length)` or
            `(N+1)`-D tensor with shape `(..., num_filters, filter_length)`, where `N` is
            the dimension of waveform.

            In case of 2D input, the same set of filters is used across channels and batches.
            Otherwise, different sets of filters are applied. In this case, the shape of
            the first `N-1` dimensions of filters must match (or be broadcastable to) that of waveform.

        delay_compensation (int): Control how the waveform is cropped after full convolution.
            If the value is zero or positive, it is interpreted as the length of crop at the
            beginning of the waveform. The value cannot be larger than the size of filter kernel.
            Otherwise the initial crop is ``filter_size // 2``.
            When cropping happens, the waveform is also cropped from the end so that the
            length of the resulting waveform matches the input waveform.

    Returns:
        Tensor: `(..., time)`.
    """
    if kernels.ndim not in [2, waveform.ndim + 1]:
        raise ValueError(
            "`kernels` must be 2 or N+1 dimension where "
            f"N is the dimension of waveform. Found: {kernels.ndim} (N={waveform.ndim})"
        )

    num_filters, filter_size = kernels.shape[-2:]
    num_frames = waveform.size(-1)

    if delay_compensation > filter_size:
        raise ValueError(
            "When `delay_compenstation` is provided, it cannot be larger than the size of filters."
            f"Found: delay_compensation={delay_compensation}, filter_size={filter_size}"
        )

    # Transform waveform's time axis into (num_filters x chunk_length) with optional padding
    chunk_length = num_frames // num_filters
    if num_frames % num_filters > 0:
        chunk_length += 1
        num_pad = chunk_length * num_filters - num_frames
        waveform = torch.nn.functional.pad(waveform, [0, num_pad], "constant", 0)
    chunked = waveform.unfold(-1, chunk_length, chunk_length)
    assert chunked.numel() >= waveform.numel()

    # Broadcast kernels so that each chunk pairs with its own filter.
    if waveform.ndim + 1 > kernels.ndim:
        expand_shape = waveform.shape[:-1] + kernels.shape
        kernels = kernels.expand(expand_shape)

    convolved = fftconvolve(chunked, kernels)
    restored = _overlap_and_add(convolved, chunk_length)

    # Trim so that the number of samples matches the input while compensating
    # the filter delay: drop `start` samples at the head, then keep exactly
    # `num_frames` samples.
    # NOTE: using an explicit stop index instead of `restored[..., start:-end]`
    # fixes the case `end == 0` (when `start` equals the total surplus), where
    # the negative-index slice would wrongly produce an empty tensor.
    if delay_compensation >= 0:
        start = delay_compensation
    else:
        start = filter_size // 2
    return restored[..., start : start + num_frames]
405
+
406
+
407
def exp_sigmoid(
    input: torch.Tensor, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7
) -> torch.Tensor:
    """Exponential Sigmoid pointwise nonlinearity.

    Computes ``max_value`` * sigmoid(``input``) ** (log(``exponent``)) + ``threshold``,
    so the output lies in [``threshold``, ``max_value``] and ``exponent``
    controls the slope.

    .. devices:: CPU CUDA

    Args:
        input (Tensor): Input Tensor
        exponent (float, optional): Exponent. Controls the slope of the output
        max_value (float, optional): Maximum value of the output
        threshold (float, optional): Minimum value of the output

    Returns:
        Tensor: Exponential Sigmoid output. Shape: same as input
    """
    # Materialize scalars on the input's device/dtype so the op stays
    # device- and precision-consistent.
    power = torch.log(torch.tensor(exponent, device=input.device, dtype=input.dtype))
    floor = torch.tensor(threshold, device=input.device, dtype=input.dtype)
    return max_value * torch.nn.functional.sigmoid(input).pow(power) + floor
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/_rir.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torchaudio
6
+ from torch import Tensor
7
+
8
+
9
def _compute_image_sources(
    room: torch.Tensor,
    source: torch.Tensor,
    max_order: int,
    absorption: torch.Tensor,
    scatter: Optional[torch.Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """Compute image sources in a shoebox-like room.

    Args:
        room (torch.Tensor): The 1D Tensor to determine the room size. The shape is
            `(D,)`, where ``D`` is 2 if room is a 2D room, or 3 if room is a 3D room.
        source (torch.Tensor): The coordinate of the sound source. Tensor with dimensions
            `(D)`.
        max_order (int): The maximum number of reflections of the source.
        absorption (torch.Tensor): The absorption coefficients of wall materials.
            ``absorption`` is a Tensor with dimensions `(num_band, num_wall)`.
            The shape options are ``[(1, 4), (1, 6), (7, 4), (7, 6)]``.
            ``num_band`` is `1` if the coefficients is the same for all frequencies, or is `7`
            if the coefficients are different to different frequencies. `7` refers to the default number
            of octave bands. (See note in `simulate_rir_ism` method).
            ``num_wall`` is `4` if the room is a 2D room, representing absorption coefficients
            of ``"west"``, ``"east"``, ``"south"``, and ``"north"`` walls, respectively.
            Or it is `6` if the room is a 3D room, representing absorption coefficients
            of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively.
        scatter (torch.Tensor): The scattering coefficients of wall materials.
            The shape of ``scatter`` must match that of ``absorption``. If ``None``, it is not
            used in image source computation. (Default: ``None``)

    Returns:
        (torch.Tensor): The coordinates of all image sources within ``max_order`` number of reflections.
            Tensor with dimensions `(num_image_source, D)`.
        (torch.Tensor): The attenuation of corresponding image sources. Tensor with dimensions
            `(num_band, num_image_source)`.
    """
    # Per-wall transmission factor: the amplitude fraction surviving one bounce
    # (sqrt because `absorption`/`scatter` are energy coefficients).
    if scatter is None:
        tr = torch.sqrt(1 - absorption)
    else:
        tr = torch.sqrt(1 - absorption) * torch.sqrt(1 - scatter)

    # Enumerate per-axis reflection indices in [-max_order, max_order], then
    # keep only combinations whose total reflection count is within max_order.
    ind = torch.arange(-max_order, max_order + 1, device=source.device)
    if room.shape[0] == 2:
        XYZ = torch.meshgrid(ind, ind, indexing="ij")
    else:
        XYZ = torch.meshgrid(ind, ind, ind, indexing="ij")
    XYZ = torch.stack([c.reshape((-1,)) for c in XYZ], dim=-1)
    XYZ = XYZ[XYZ.abs().sum(dim=-1) <= max_order]

    # compute locations of image sources
    # Odd index: the source is mirrored across a wall on that axis;
    # even index: it is translated by a whole number of room lengths.
    d = room[None, :]
    s = source[None, :]
    img_loc = torch.where(XYZ % 2 == 1, d * (XYZ + 1) - s, d * XYZ + s)

    # attenuation
    # exp_lo / exp_hi count how many times each image source bounced off the
    # "low" walls (west/south/floor, even columns of `tr`) and the "high"
    # walls (east/north/ceiling, odd columns), per axis.
    exp_lo = abs(torch.floor((XYZ / 2)))
    exp_hi = abs(torch.floor((XYZ + 1) / 2))
    t_lo = tr[:, ::2].unsqueeze(1).repeat(1, XYZ.shape[0], 1)  # (num_band, left walls)
    t_hi = tr[:, 1::2].unsqueeze(1).repeat(1, XYZ.shape[0], 1)  # (num_band, right walls)
    att = torch.prod((t_lo**exp_lo) * (t_hi**exp_hi), dim=-1)  # (num_band, num_image_source)
    return img_loc, att
69
+
70
+
71
+ def _hann(x: torch.Tensor, T: int):
72
+ """Compute the Hann window where the values are truncated based on window length.
73
+ torch.hann_window can only sample window function at integer points, the method is to sample
74
+ continuous window function at non-integer points.
75
+
76
+ Args:
77
+ x (torch.Tensor): The fractional component of time delay Tensor.
78
+ T (torch.Tensor): The window length of sinc function.
79
+
80
+ Returns:
81
+ (torch.Tensor): The hann window Tensor where values outside
82
+ the sinc window (`T`) is set to zero.
83
+ """
84
+ y = torch.where(
85
+ torch.abs(x) <= T / 2,
86
+ 0.5 * (1 + torch.cos(2 * math.pi * x / T)),
87
+ x.new_zeros(1),
88
+ )
89
+ return y
90
+
91
+
92
def _frac_delay(delay: torch.Tensor, delay_i: torch.Tensor, delay_filter_length: int):
    """Build windowed-sinc kernels realizing fractional sample delays.

    Args:
        delay (torch.Tensor): The time delay Tensor in samples.
        delay_i (torch.Tensor): The integer part of delay.
        delay_filter_length (int): The window length for sinc function.

    Returns:
        (torch.Tensor): The impulse response Tensor for all image sources.
    """
    if delay_filter_length % 2 != 1:
        raise ValueError("The filter length must be odd")

    half = delay_filter_length // 2
    # Tap positions centered on each integer delay, offset by the fractional part.
    taps = torch.arange(-half, half + 1, device=delay.device) + delay_i[..., None]
    offset = taps - delay[..., None]
    return torch.special.sinc(offset) * _hann(offset, 2 * half)
111
+
112
+
113
+ def _adjust_coeff(coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor:
114
+ """Validates and converts absorption or scattering parameters to a tensor with appropriate shape
115
+
116
+ Args:
117
+ coeff (float or torch.Tensor): The absorption coefficients of wall materials.
118
+
119
+ If the dtype is ``float``, the absorption coefficient is identical for all walls and
120
+ all frequencies.
121
+
122
+ If ``absorption`` is a 1D Tensor, the shape must be `(2*dim,)`,
123
+ where the values represent absorption coefficients of ``"west"``, ``"east"``,
124
+ ``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively.
125
+
126
+ If ``absorption`` is a 2D Tensor, the shape must be `(7, 2*dim)`,
127
+ where 7 represents the number of octave bands.
128
+
129
+ Returns:
130
+ (torch.Tensor): The expanded coefficient.
131
+ The shape is `(1, 6)` for single octave band case, and
132
+ `(7, 6)` for multi octave band case.
133
+ """
134
+ num_walls = 6
135
+ if isinstance(coeffs, float):
136
+ if coeffs < 0:
137
+ raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
138
+ return torch.full((1, num_walls), coeffs)
139
+ if isinstance(coeffs, Tensor):
140
+ if torch.any(coeffs < 0):
141
+ raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
142
+ if coeffs.ndim == 1:
143
+ if coeffs.numel() != num_walls:
144
+ raise ValueError(
145
+ f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor. "
146
+ f"Found the shape {coeffs.shape}."
147
+ )
148
+ return coeffs.unsqueeze(0)
149
+ if coeffs.ndim == 2:
150
+ if coeffs.shape[1] != num_walls:
151
+ raise ValueError(
152
+ f"The shape of `{name}` must be (NUM_BANDS, {num_walls}) when it "
153
+ f"is a 2D Tensor. Found: {coeffs.shape}."
154
+ )
155
+ return coeffs
156
+ raise TypeError(f"`{name}` must be float or Tensor.")
157
+
158
+
159
+ def _validate_inputs(
160
+ room: torch.Tensor,
161
+ source: torch.Tensor,
162
+ mic_array: torch.Tensor,
163
+ ):
164
+ """Validate dimensions of input arguments, and normalize different kinds of absorption into the same dimension.
165
+
166
+ Args:
167
+ room (torch.Tensor): The size of the room. width, length (and height)
168
+ source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(dim,)`.
169
+ mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, dim)`.
170
+ """
171
+ if not (room.ndim == 1 and room.numel() == 3):
172
+ raise ValueError(f"`room` must be a 1D Tensor with 3 elements. Found {room.shape}.")
173
+ if not (source.ndim == 1 and source.numel() == 3):
174
+ raise ValueError(f"`source` must be 1D Tensor with 3 elements. Found {source.shape}.")
175
+ if not (mic_array.ndim == 2 and mic_array.shape[1] == 3):
176
+ raise ValueError(f"`mic_array` must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")
177
+
178
+
179
def simulate_rir_ism(
    room: torch.Tensor,
    source: torch.Tensor,
    mic_array: torch.Tensor,
    max_order: int,
    absorption: Union[float, torch.Tensor],
    output_length: Optional[int] = None,
    delay_filter_length: int = 81,
    center_frequency: Optional[torch.Tensor] = None,
    sound_speed: float = 343.0,
    sample_rate: float = 16000.0,
) -> Tensor:
    r"""Compute Room Impulse Response (RIR) based on the *image source method* :cite:`allen1979image`.
    The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.

    .. devices:: CPU

    .. properties:: TorchScript

    Args:
        room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
            three dimensions of the room.
        source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
        mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
        max_order (int): The maximum number of reflections of the source.
        absorption (float or torch.Tensor): The *absorption* :cite:`wiki:Absorption_(acoustics)`
            coefficients of wall materials for sound energy.
            If the dtype is ``float``, the absorption coefficient is identical for all walls and
            all frequencies.
            If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, where the values represent
            absorption coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``,
            and ``"ceiling"``, respectively.
            If ``absorption`` is a 2D Tensor, the shape must be `(7, 6)`, where 7 represents the number of octave bands.
        output_length (int or None, optional): The output length of simulated RIR signal. If ``None``,
            the length is defined as

            .. math::
                \frac{\text{max\_d} \cdot \text{sample\_rate}}{\text{sound\_speed}} + \text{delay\_filter\_length}

            where ``max_d`` is the maximum distance between image sources and microphones.
        delay_filter_length (int, optional): The filter length for computing sinc function. (Default: ``81``)
        center_frequency (torch.Tensor, optional): The center frequencies of octave bands for multi-band walls.
            Only used when ``absorption`` is a 2D Tensor.
        sound_speed (float, optional): The speed of sound. (Default: ``343.0``)
        sample_rate (float, optional): The sample rate of the generated room impulse response signal.
            (Default: ``16000.0``)

    Returns:
        (torch.Tensor): The simulated room impulse response waveform. Tensor with dimensions
            `(channel, rir_length)`.

    Note:
        If ``absorption`` is a 2D Tensor and ``center_frequency`` is set to ``None``, the center frequencies
        of octave bands are fixed to ``[125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0]``.
        Users need to tune the values of ``absorption`` to the corresponding frequencies.
    """
    _validate_inputs(room, source, mic_array)
    absorption = _adjust_coeff(absorption, "absorption")
    img_location, att = _compute_image_sources(room, source, max_order, absorption)

    # compute distances between image sources and microphones
    vec = img_location[:, None, :] - mic_array[None, :, :]
    dist = torch.linalg.norm(vec, dim=-1)  # (image_source, channel)

    # Attenuation scaled by 1/distance (spherical spreading).
    img_src_att = att[..., None] / dist[None, ...]  # (band, image_source, channel)

    # separate delays in integer / frac part
    delay = dist * sample_rate / sound_speed  # distance to delay in samples
    delay_i = torch.ceil(delay)  # integer part

    # compute the shorts IRs corresponding to each image source
    # (fractional part handled by the windowed-sinc from _frac_delay)
    irs = img_src_att[..., None] * _frac_delay(delay, delay_i, delay_filter_length)[None, ...]

    rir_length = int(delay_i.max() + irs.shape[-1])
    # Custom C++ op places each short IR at its integer delay and accumulates.
    rir = torch.ops.torchaudio._simulate_rir(irs, delay_i.type(torch.int32), rir_length)

    # multi-band processing
    if absorption.shape[0] > 1:
        if center_frequency is None:
            center = torch.tensor(
                [125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0], dtype=room.dtype, device=room.device
            )
        else:
            center = center_frequency
        # n_fft is set to 512 by default.
        filters = torch.ops.torchaudio._make_rir_filter(center, sample_rate, n_fft=512)
        rir = torchaudio.functional.fftconvolve(rir, filters.unsqueeze(1).repeat(1, rir.shape[1], 1), mode="same")

    # sum up rir signals of all image sources into one waveform.
    rir = rir.sum(0)

    # Pad or crop to the requested output length, if any.
    if output_length is not None:
        if output_length > rir.shape[-1]:
            rir = torch.nn.functional.pad(rir, (0, output_length - rir.shape[-1]), "constant", 0.0)
        else:
            rir = rir[..., :output_length]

    return rir
277
+
278
+
279
def ray_tracing(
    room: torch.Tensor,
    source: torch.Tensor,
    mic_array: torch.Tensor,
    num_rays: int,
    absorption: Union[float, torch.Tensor] = 0.0,
    scattering: Union[float, torch.Tensor] = 0.0,
    mic_radius: float = 0.5,
    sound_speed: float = 343.0,
    energy_thres: float = 1e-7,
    time_thres: float = 10.0,
    hist_bin_size: float = 0.004,
) -> torch.Tensor:
    r"""Compute energy histogram via ray tracing.

    The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.

    ``num_rays`` rays are casted uniformly in all directions from the source;
    when a ray intersects a wall, it is reflected and part of its energy is absorbed.
    It is also scattered (sent directly to the microphone(s)) according to the ``scattering``
    coefficient.
    When a ray is close to the microphone, its current energy is recorded in the output
    histogram for that given time slot.

    .. devices:: CPU

    .. properties:: TorchScript

    Args:
        room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
            three dimensions of the room.
        source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
        mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
        num_rays (int): The number of rays casted from the source.
        absorption (float or torch.Tensor, optional): The absorption coefficients of wall materials.
            (Default: ``0.0``).
            If the type is ``float``, the absorption coefficient is identical to all walls and
            all frequencies.
            If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, representing absorption
            coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and
            ``"ceiling"``, respectively.
            If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 6)`.
            ``num_bands`` is the number of frequency bands (usually 7).
        scattering(float or torch.Tensor, optional): The scattering coefficients of wall materials. (Default: ``0.0``)
            The shape and type of this parameter is the same as for ``absorption``.
        mic_radius(float, optional): The radius of the microphone in meters. (Default: 0.5)
        sound_speed (float, optional): The speed of sound in meters per second. (Default: ``343.0``)
        energy_thres (float, optional): The energy level below which we stop tracing a ray. (Default: ``1e-7``)
            The initial energy of each ray is ``2 / num_rays``.
        time_thres (float, optional): The maximal duration for which rays are traced. (Unit: seconds) (Default: 10.0)
        hist_bin_size (float, optional): The size of each bin in the output histogram. (Unit: seconds) (Default: 0.004)

    Returns:
        (torch.Tensor): The 3D histogram(s) where the energy of the traced ray is recorded.
            Each bin corresponds to a given time slot.
            The shape is `(channel, num_bands, num_bins)`, where
            ``num_bins = ceil(time_thres / hist_bin_size)``.
            If both ``absorption`` and ``scattering`` are floats, then ``num_bands == 1``.
    """
    if time_thres < hist_bin_size:
        raise ValueError(
            "`time_thres` must be greater than `hist_bin_size`. "
            f"Found: hist_bin_size={hist_bin_size}, time_thres={time_thres}."
        )

    if room.dtype != source.dtype or source.dtype != mic_array.dtype:
        raise ValueError(
            "dtype of `room`, `source` and `mic_array` must match. "
            f"Found: `room` ({room.dtype}), `source` ({source.dtype}) and "
            f"`mic_array` ({mic_array.dtype})"
        )

    _validate_inputs(room, source, mic_array)
    # Normalize both coefficient sets to shape (num_bands, 6) on room's dtype.
    absorption = _adjust_coeff(absorption, "absorption").to(room.dtype)
    scattering = _adjust_coeff(scattering, "scattering").to(room.dtype)

    # Bring absorption and scattering to the same shape
    if absorption.shape[0] == 1 and scattering.shape[0] > 1:
        absorption = absorption.expand(scattering.shape)
    if scattering.shape[0] == 1 and absorption.shape[0] > 1:
        scattering = scattering.expand(absorption.shape)
    if absorption.shape != scattering.shape:
        raise ValueError(
            "`absorption` and `scattering` must be broadcastable to the same number of bands and walls. "
            f"Inferred shapes absorption={absorption.shape} and scattering={scattering.shape}"
        )

    # The actual ray tracing is implemented as a custom torchaudio C++ op.
    histograms = torch.ops.torchaudio.ray_tracing(
        room,
        source,
        mic_array,
        num_rays,
        absorption,
        scattering,
        mic_radius,
        sound_speed,
        energy_thres,
        time_thres,
        hist_bin_size,
    )

    return histograms
.venv/lib/python3.11/site-packages/torchaudio/prototype/functional/functional.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from torchaudio.functional.functional import _create_triangular_filterbank
7
+
8
+
9
+ def _hz_to_bark(freqs: float, bark_scale: str = "traunmuller") -> float:
10
+ r"""Convert Hz to Barks.
11
+
12
+ Args:
13
+ freqs (float): Frequencies in Hz
14
+ bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``)
15
+
16
+ Returns:
17
+ barks (float): Frequency in Barks
18
+ """
19
+
20
+ if bark_scale not in ["schroeder", "traunmuller", "wang"]:
21
+ raise ValueError('bark_scale should be one of "schroeder", "traunmuller" or "wang".')
22
+
23
+ if bark_scale == "wang":
24
+ return 6.0 * math.asinh(freqs / 600.0)
25
+ elif bark_scale == "schroeder":
26
+ return 7.0 * math.asinh(freqs / 650.0)
27
+ # Traunmuller Bark scale
28
+ barks = ((26.81 * freqs) / (1960.0 + freqs)) - 0.53
29
+ # Bark value correction
30
+ if barks < 2:
31
+ barks += 0.15 * (2 - barks)
32
+ elif barks > 20.1:
33
+ barks += 0.22 * (barks - 20.1)
34
+
35
+ return barks
36
+
37
+
38
+ def _bark_to_hz(barks: torch.Tensor, bark_scale: str = "traunmuller") -> torch.Tensor:
39
+ """Convert bark bin numbers to frequencies.
40
+
41
+ Args:
42
+ barks (torch.Tensor): Bark frequencies
43
+ bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``)
44
+
45
+ Returns:
46
+ freqs (torch.Tensor): Barks converted in Hz
47
+ """
48
+
49
+ if bark_scale not in ["schroeder", "traunmuller", "wang"]:
50
+ raise ValueError('bark_scale should be one of "traunmuller", "schroeder" or "wang".')
51
+
52
+ if bark_scale == "wang":
53
+ return 600.0 * torch.sinh(barks / 6.0)
54
+ elif bark_scale == "schroeder":
55
+ return 650.0 * torch.sinh(barks / 7.0)
56
+ # Bark value correction
57
+ if any(barks < 2):
58
+ idx = barks < 2
59
+ barks[idx] = (barks[idx] - 0.3) / 0.85
60
+ elif any(barks > 20.1):
61
+ idx = barks > 20.1
62
+ barks[idx] = (barks[idx] + 4.422) / 1.22
63
+
64
+ # Traunmuller Bark scale
65
+ freqs = 1960 * ((barks + 0.53) / (26.28 - barks))
66
+
67
+ return freqs
68
+
69
+
70
+ def _hz_to_octs(freqs, tuning=0.0, bins_per_octave=12):
71
+ a440 = 440.0 * 2.0 ** (tuning / bins_per_octave)
72
+ return torch.log2(freqs / (a440 / 16))
73
+
74
+
75
def barkscale_fbanks(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_barks: int,
    sample_rate: int,
    bark_scale: str = "traunmuller",
) -> torch.Tensor:
    r"""Create a frequency bin conversion matrix.

    .. devices:: CPU

    .. properties:: TorchScript

    .. image:: https://download.pytorch.org/torchaudio/doc-assets/bark_fbanks.png
        :alt: Visualization of generated filter bank

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_barks (int): Number of bark filterbanks
        sample_rate (int): Sample rate of the audio waveform
        bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``)

    Returns:
        torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_barks``)
        meaning number of frequencies to highlight/apply to x the number of filterbanks.
        Each column is a filterbank so that assuming there is a matrix A of
        size (..., ``n_freqs``), the applied result would be
        ``A * barkscale_fbanks(A.size(-1), ...)``.

    """

    # STFT bin center frequencies, spanning 0 .. Nyquist.
    stft_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # Evenly spaced points on the Bark scale between f_min and f_max, mapped
    # back to Hz; the two extra points give the outer edges of the triangles.
    bark_lo = _hz_to_bark(f_min, bark_scale=bark_scale)
    bark_hi = _hz_to_bark(f_max, bark_scale=bark_scale)
    bark_pts = torch.linspace(bark_lo, bark_hi, n_barks + 2)
    filter_freqs = _bark_to_hz(bark_pts, bark_scale=bark_scale)

    # Triangular filters centered on the interior points.
    fb = _create_triangular_filterbank(stft_freqs, filter_freqs)

    # Warn when a filter covers no STFT bin at all (peak value is zero).
    if (fb.max(dim=0).values == 0.0).any():
        warnings.warn(
            "At least one bark filterbank has all zero values. "
            f"The value for `n_barks` ({n_barks}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb
130
+
131
+
132
def chroma_filterbank(
    sample_rate: int,
    n_freqs: int,
    n_chroma: int,
    *,
    tuning: float = 0.0,
    ctroct: float = 5.0,
    octwidth: Optional[float] = 2.0,
    norm: int = 2,
    base_c: bool = True,
):
    """Create a frequency-to-chroma conversion matrix. Implementation adapted from librosa.

    Args:
        sample_rate (int): Sample rate.
        n_freqs (int): Number of input frequencies.
        n_chroma (int): Number of output chroma.
        tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0)
        ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves. (Default: 5.0)
        octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by, in octaves.
            If ``None``, then disable weighting altogether. (Default: 2.0)
        norm (int, optional): order of norm to normalize filter bank by. (Default: 2)
        base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A. (Default: True)

    Returns:
        torch.Tensor: Chroma filter bank, with shape `(n_freqs, n_chroma)`.
    """
    # Skip redundant upper half of frequency range; drop the DC bin too.
    bin_hz = torch.linspace(0, sample_rate // 2, n_freqs)[1:]

    # Fractional chroma-bin index of every frequency, with a synthetic entry
    # placed 1.5 octaves below the first real bin to stand in for DC.
    chroma_bins = n_chroma * _hz_to_octs(bin_hz, bins_per_octave=n_chroma, tuning=tuning)
    chroma_bins = torch.cat((torch.tensor([chroma_bins[0] - 1.5 * n_chroma]), chroma_bins))

    # Per-frequency Gaussian width: spacing to the next bin, floored at 1.
    widths = torch.cat(
        (
            torch.maximum(chroma_bins[1:] - chroma_bins[:-1], torch.tensor(1.0)),
            torch.tensor([1]),
        )
    )

    # Signed distance from each frequency to each chroma, shape (n_freqs, n_chroma),
    # wrapped into the range [-n_chroma/2, n_chroma/2 - 1].
    half_chroma = round(n_chroma / 2)
    distance = chroma_bins.unsqueeze(1) - torch.arange(0, n_chroma)
    distance = torch.remainder(distance + half_chroma, n_chroma) - half_chroma

    # Gaussian bumps, then per-frequency norm normalization.
    fb = torch.exp(-0.5 * (2 * distance / torch.tile(widths.unsqueeze(1), (1, n_chroma))) ** 2)
    fb = torch.nn.functional.normalize(fb, p=norm, dim=1)

    if octwidth is not None:
        # Down-weight frequencies far (in octaves) from the dominance center `ctroct`.
        fb *= torch.tile(
            torch.exp(-0.5 * (((chroma_bins.unsqueeze(1) / n_chroma - ctroct) / octwidth) ** 2)),
            (1, n_chroma),
        )

    if base_c:
        # Rotate chroma axis so the first bin corresponds to C instead of A.
        fb = torch.roll(fb, -3 * (n_chroma // 12), dims=1)

    return fb
.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Re-export the prototype model factory functions and classes at package level
# so users can write `from torchaudio.prototype.models import <name>` directly.
from ._conformer_wav2vec2 import (
    conformer_wav2vec2_base,
    conformer_wav2vec2_model,
    conformer_wav2vec2_pretrain_base,
    conformer_wav2vec2_pretrain_large,
    conformer_wav2vec2_pretrain_model,
    ConformerWav2Vec2PretrainModel,
)
from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model
from .conv_emformer import ConvEmformer
from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder
from .rnnt import conformer_rnnt_base, conformer_rnnt_biasing, conformer_rnnt_biasing_base, conformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearchBiasing

# Explicit public API of this subpackage (controls `from ... import *` and docs).
__all__ = [
    "conformer_rnnt_base",
    "conformer_rnnt_model",
    "conformer_rnnt_biasing",
    "conformer_rnnt_biasing_base",
    "ConvEmformer",
    "conformer_wav2vec2_model",
    "conformer_wav2vec2_base",
    "conformer_wav2vec2_pretrain_model",
    "conformer_wav2vec2_pretrain_base",
    "conformer_wav2vec2_pretrain_large",
    "ConformerWav2Vec2PretrainModel",
    "emformer_hubert_base",
    "emformer_hubert_model",
    "Hypothesis",
    "RNNTBeamSearchBiasing",
    "HiFiGANVocoder",
    "hifigan_vocoder_v1",
    "hifigan_vocoder_v2",
    "hifigan_vocoder_v3",
    "hifigan_vocoder",
]
.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.35 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_conformer_wav2vec2.cpython-311.pyc ADDED
Binary file (33.4 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/_emformer_hubert.cpython-311.pyc ADDED
Binary file (16.2 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/prototype/models/__pycache__/conv_emformer.cpython-311.pyc ADDED
Binary file (30.2 kB). View file