diff --git a/.gitattributes b/.gitattributes
index 98ff606500992d94c2f3bc5d0001f221f00df3a0..d004eddbe7e056f34c44cea3f7ddae3736ab8049 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -295,3 +295,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
.venv/bin/py-spy filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/_cffi_backend.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/jsonschema/tests/__pycache__/test_validators.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/functional.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d89adc41cd9f6d07eb634aafa30a46e320cd4e4
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/const_vs_enum.cpython-311.pyc b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/const_vs_enum.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2149c4becc6f404e8c4da54fe283de4ce388046e
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/const_vs_enum.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/contains.cpython-311.pyc b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/contains.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..968c85585f6284fa3177ec171f9ec2e9718495a4
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/contains.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/issue232.cpython-311.pyc b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/issue232.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7698b2933a7155de1d2fd8224071a89506e9c520
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/issue232.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/json_schema_test_suite.cpython-311.pyc b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/json_schema_test_suite.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3cdaceefc2eb3936cf1af562d263f87c9c59dd0b
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/json_schema_test_suite.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/nested_schemas.cpython-311.pyc b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/nested_schemas.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd0218e16d6b82d9496618f9b0ce1c59c0ab2f07
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/nested_schemas.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/subcomponents.cpython-311.pyc b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/subcomponents.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..204eeacc0577b0d7a27b1540e048f22f68861498
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/subcomponents.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/unused_registry.cpython-311.pyc b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/unused_registry.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8140a856205afe10b951ad9b560e622827764035
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/unused_registry.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/useless_applicator_schemas.cpython-311.pyc b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/useless_applicator_schemas.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ac47c100dd7006c6797455f9a5bf5582c7b89be
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/useless_applicator_schemas.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/useless_keywords.cpython-311.pyc b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/useless_keywords.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a759251276176df00c06bc575ba6b4a95fec259
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/useless_keywords.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/validator_creation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/validator_creation.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70fadd4046cd1f5cec7ab21c6f680f37dd746943
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/__pycache__/validator_creation.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/const_vs_enum.py b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/const_vs_enum.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6fecd10f6d8b845c675be9c19e9b504b08d30b9
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/const_vs_enum.py
@@ -0,0 +1,30 @@
+"""
+A benchmark for comparing equivalent validation of `const` and `enum`.
+"""
+
+from pyperf import Runner
+
+from jsonschema import Draft202012Validator
+
+value = [37] * 100
+const_schema = {"const": list(value)}
+enum_schema = {"enum": [list(value)]}
+
+valid = list(value)
+invalid = [*valid, 73]
+
+const = Draft202012Validator(const_schema)
+enum = Draft202012Validator(enum_schema)
+
+assert const.is_valid(valid)
+assert enum.is_valid(valid)
+assert not const.is_valid(invalid)
+assert not enum.is_valid(invalid)
+
+
+if __name__ == "__main__":
+ runner = Runner()
+ runner.bench_func("const valid", lambda: const.is_valid(valid))
+ runner.bench_func("const invalid", lambda: const.is_valid(invalid))
+ runner.bench_func("enum valid", lambda: enum.is_valid(valid))
+ runner.bench_func("enum invalid", lambda: enum.is_valid(invalid))
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/contains.py b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/contains.py
new file mode 100644
index 0000000000000000000000000000000000000000..739cd044cceb807b4029dca9447e954214a24809
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/contains.py
@@ -0,0 +1,28 @@
+"""
+A benchmark for validation of the `contains` keyword.
+"""
+
+from pyperf import Runner
+
+from jsonschema import Draft202012Validator
+
+schema = {
+ "type": "array",
+ "contains": {"const": 37},
+}
+validator = Draft202012Validator(schema)
+
+size = 1000
+beginning = [37] + [0] * (size - 1)
+middle = [0] * (size // 2) + [37] + [0] * (size // 2)
+end = [0] * (size - 1) + [37]
+invalid = [0] * size
+
+
+if __name__ == "__main__":
+ runner = Runner()
+ runner.bench_func("baseline", lambda: validator.is_valid([]))
+ runner.bench_func("beginning", lambda: validator.is_valid(beginning))
+ runner.bench_func("middle", lambda: validator.is_valid(middle))
+ runner.bench_func("end", lambda: validator.is_valid(end))
+ runner.bench_func("invalid", lambda: validator.is_valid(invalid))
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/json_schema_test_suite.py b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/json_schema_test_suite.py
new file mode 100644
index 0000000000000000000000000000000000000000..905fb6a3b88faf56e3288f7eb5053172f97abe8b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/json_schema_test_suite.py
@@ -0,0 +1,12 @@
+"""
+A performance benchmark using the official test suite.
+
+This benchmarks jsonschema using every valid example in the
+JSON-Schema-Test-Suite. It will take some time to complete.
+"""
+from pyperf import Runner
+
+from jsonschema.tests._suite import Suite
+
+if __name__ == "__main__":
+ Suite().benchmark(runner=Runner())
diff --git a/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/validator_creation.py b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/validator_creation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4baeb3a31641a027496732a6f10e200346551209
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/jsonschema/benchmarks/validator_creation.py
@@ -0,0 +1,14 @@
+from pyperf import Runner
+
+from jsonschema import Draft202012Validator
+
+schema = {
+ "type": "array",
+ "minLength": 1,
+ "maxLength": 1,
+ "items": {"type": "integer"},
+}
+
+
+if __name__ == "__main__":
+ Runner().bench_func("validator creation", Draft202012Validator, schema)
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..90b411eae1f27ff5dc8adf4917c1ea8725cd2c9a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/__init__.py
@@ -0,0 +1,53 @@
+# Initialize extension and backend first
+from . import _extension # noqa # usort: skip
+from ._backend import ( # noqa # usort: skip
+ AudioMetaData,
+ get_audio_backend,
+ info,
+ list_audio_backends,
+ load,
+ save,
+ set_audio_backend,
+)
+
+from . import ( # noqa: F401
+ compliance,
+ datasets,
+ functional,
+ io,
+ kaldi_io,
+ models,
+ pipelines,
+ sox_effects,
+ transforms,
+ utils,
+)
+
+# For BC
+from . import backend # noqa # usort: skip
+
+try:
+ from .version import __version__, git_version # noqa: F401
+except ImportError:
+ pass
+
+
+__all__ = [
+ "AudioMetaData",
+ "load",
+ "info",
+ "save",
+ "io",
+ "compliance",
+ "datasets",
+ "functional",
+ "models",
+ "pipelines",
+ "kaldi_io",
+ "utils",
+ "sox_effects",
+ "transforms",
+ "list_audio_backends",
+ "get_audio_backend",
+ "set_audio_backend",
+]
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8d0511f85e8cdbd3ef955e9d038916c559608a2
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/__pycache__/kaldi_io.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/__pycache__/kaldi_io.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c90b5b872a02eff71552363c8392f5a46e20dff
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/__pycache__/kaldi_io.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/__pycache__/version.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/__pycache__/version.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8a0a9cad78fe8634d1ae59988b5ea268af5b47d
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/__pycache__/version.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..27337013ff12edbb9a6b18608a555e5c33031499
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__init__.py
@@ -0,0 +1,61 @@
+from typing import List, Optional
+
+from torchaudio._internal.module_utils import deprecated
+
+from . import utils
+from .common import AudioMetaData
+
+__all__ = [
+ "AudioMetaData",
+ "load",
+ "info",
+ "save",
+ "list_audio_backends",
+ "get_audio_backend",
+ "set_audio_backend",
+]
+
+
+info = utils.get_info_func()
+load = utils.get_load_func()
+save = utils.get_save_func()
+
+
+def list_audio_backends() -> List[str]:
+ """List available backends
+
+ Returns:
+ list of str: The list of available backends.
+
+    The possible values are ``"ffmpeg"``, ``"sox"`` and ``"soundfile"``.
+ """
+
+ return list(utils.get_available_backends().keys())
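+
+# Usage sketch (illustrative, not part of the upstream file): pick a backend
+# explicitly only when it is actually available.
+#
+#   import torchaudio
+#   backend = "ffmpeg" if "ffmpeg" in torchaudio.list_audio_backends() else None
+#   waveform, sample_rate = torchaudio.load("clip.wav", backend=backend)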
+
+
+# Temporary until global backend is removed
+@deprecated("With dispatcher enabled, this function is no-op. You can remove the function call.")
+def get_audio_backend() -> Optional[str]:
+ """Get the name of the current global backend
+
+ Returns:
+ str or None:
+            If dispatcher mode is enabled, returns ``None``; otherwise,
+            the name of the current backend, or ``None`` if no backend is set.
+ """
+ return None
+
+
+# Temporary until global backend is removed
+@deprecated("With dispatcher enabled, this function is no-op. You can remove the function call.")
+def set_audio_backend(backend: Optional[str]): # noqa
+ """Set the global backend.
+
+ This is a no-op when dispatcher mode is enabled.
+
+ Args:
+ backend (str or None): Name of the backend.
+ One of ``"sox_io"`` or ``"soundfile"`` based on availability
+ of the system. If ``None`` is provided the current backend is unassigned.
+ """
+ pass
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7495341e51e05574c8de6e3cfa4b6f0c4807452a
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/backend.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/backend.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d035c9a9900a38948b83d9f0d4ab6b6592845092
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/backend.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/common.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f1be477a83c09fb96987fa3986446d767b6dcc4
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/common.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/ffmpeg.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/ffmpeg.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41e56993c72fe45d86a1612cac61cf95af4a3f3f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/ffmpeg.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/soundfile.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/soundfile.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98591841949d441fd4e25bbcf7375a086fd4f21e
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/soundfile.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/soundfile_backend.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/soundfile_backend.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba65c4c313f4b5dc6bf12315fcc0ef43d2057d0a
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/soundfile_backend.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/sox.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/sox.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..894f3b9033ebedfd016b03fbab0f0befce01cdc9
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/sox.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95ce0d6a7228e5a176cd5bd0152852f1bec8d8c0
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/_backend/__pycache__/utils.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/backend.py b/.venv/lib/python3.11/site-packages/torchaudio/_backend/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..579340962c42a2210d5d7a5a41a1886b5fb62045
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/_backend/backend.py
@@ -0,0 +1,53 @@
+import os
+from abc import ABC, abstractmethod
+from typing import BinaryIO, Optional, Tuple, Union
+
+from torch import Tensor
+from torchaudio.io import CodecConfig
+
+from .common import AudioMetaData
+
+
+class Backend(ABC):
+ @staticmethod
+ @abstractmethod
+ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def load(
+ uri: Union[BinaryIO, str, os.PathLike],
+ frame_offset: int = 0,
+ num_frames: int = -1,
+ normalize: bool = True,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+ buffer_size: int = 4096,
+ ) -> Tuple[Tensor, int]:
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def save(
+ uri: Union[BinaryIO, str, os.PathLike],
+ src: Tensor,
+ sample_rate: int,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+ encoding: Optional[str] = None,
+ bits_per_sample: Optional[int] = None,
+ buffer_size: int = 4096,
+ compression: Optional[Union[CodecConfig, float, int]] = None,
+ ) -> None:
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def can_decode(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str]) -> bool:
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def can_encode(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str]) -> bool:
+ raise NotImplementedError
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/common.py b/.venv/lib/python3.11/site-packages/torchaudio/_backend/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..804b18d461b93d4a371e02d9cde902b59aba3111
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/_backend/common.py
@@ -0,0 +1,52 @@
+class AudioMetaData:
+ """AudioMetaData()
+
+ Return type of ``torchaudio.info`` function.
+
+ :ivar int sample_rate: Sample rate
+ :ivar int num_frames: The number of frames
+ :ivar int num_channels: The number of channels
+ :ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
+ or when it cannot be accurately inferred.
+    :ivar str encoding: Audio encoding.
+        The encoding can take one of the following values:
+
+ * ``PCM_S``: Signed integer linear PCM
+ * ``PCM_U``: Unsigned integer linear PCM
+ * ``PCM_F``: Floating point linear PCM
+ * ``FLAC``: Flac, Free Lossless Audio Codec
+ * ``ULAW``: Mu-law
+ * ``ALAW``: A-law
+        * ``MP3``: MP3, MPEG-1 Audio Layer III
+ * ``VORBIS``: OGG Vorbis
+ * ``AMR_WB``: Adaptive Multi-Rate Wideband
+ * ``AMR_NB``: Adaptive Multi-Rate Narrowband
+ * ``OPUS``: Opus
+ * ``HTK``: Single channel 16-bit PCM
+        * ``UNKNOWN``: None of the above
+ """
+
+ def __init__(
+ self,
+ sample_rate: int,
+ num_frames: int,
+ num_channels: int,
+ bits_per_sample: int,
+ encoding: str,
+ ):
+ self.sample_rate = sample_rate
+ self.num_frames = num_frames
+ self.num_channels = num_channels
+ self.bits_per_sample = bits_per_sample
+ self.encoding = encoding
+
+ def __str__(self):
+ return (
+ f"AudioMetaData("
+ f"sample_rate={self.sample_rate}, "
+ f"num_frames={self.num_frames}, "
+ f"num_channels={self.num_channels}, "
+ f"bits_per_sample={self.bits_per_sample}, "
+ f"encoding={self.encoding}"
+ f")"
+ )
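+
+
+# Illustrative example (not part of the upstream file): for a one-second,
+# 16 kHz, mono, 16-bit PCM WAV file, ``torchaudio.info`` would return
+#   AudioMetaData(sample_rate=16000, num_frames=16000, num_channels=1,
+#                 bits_per_sample=16, encoding=PCM_S)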
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/ffmpeg.py b/.venv/lib/python3.11/site-packages/torchaudio/_backend/ffmpeg.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca8374ea07de9c9c06615f773d62b1ca910efb95
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/_backend/ffmpeg.py
@@ -0,0 +1,334 @@
+import os
+import re
+import sys
+from typing import BinaryIO, Optional, Tuple, Union
+
+import torch
+import torchaudio
+
+from .backend import Backend
+from .common import AudioMetaData
+
+InputType = Union[BinaryIO, str, os.PathLike]
+
+
+def info_audio(
+ src: InputType,
+ format: Optional[str],
+ buffer_size: int = 4096,
+) -> AudioMetaData:
+ s = torchaudio.io.StreamReader(src, format, None, buffer_size)
+ sinfo = s.get_src_stream_info(s.default_audio_stream)
+ if sinfo.num_frames == 0:
+ waveform = _load_audio(s)
+ num_frames = waveform.size(1)
+ else:
+ num_frames = sinfo.num_frames
+ return AudioMetaData(
+ int(sinfo.sample_rate),
+ num_frames,
+ sinfo.num_channels,
+ sinfo.bits_per_sample,
+ sinfo.codec.upper(),
+ )
+
+
+def _get_load_filter(
+ frame_offset: int = 0,
+ num_frames: int = -1,
+ convert: bool = True,
+) -> Optional[str]:
+ if frame_offset < 0:
+ raise RuntimeError("Invalid argument: frame_offset must be non-negative. Found: {}".format(frame_offset))
+ if num_frames == 0 or num_frames < -1:
+ raise RuntimeError("Invalid argument: num_frames must be -1 or greater than 0. Found: {}".format(num_frames))
+
+ # All default values -> no filter
+ if frame_offset == 0 and num_frames == -1 and not convert:
+ return None
+ # Only convert
+ aformat = "aformat=sample_fmts=fltp"
+ if frame_offset == 0 and num_frames == -1 and convert:
+ return aformat
+ # At least one of frame_offset or num_frames has non-default value
+ if num_frames > 0:
+ atrim = "atrim=start_sample={}:end_sample={}".format(frame_offset, frame_offset + num_frames)
+ else:
+ atrim = "atrim=start_sample={}".format(frame_offset)
+ if not convert:
+ return atrim
+ return "{},{}".format(atrim, aformat)
+
+
+def _load_audio(
+ s: "torchaudio.io.StreamReader",
+ filter: Optional[str] = None,
+ channels_first: bool = True,
+) -> torch.Tensor:
+ s.add_audio_stream(-1, -1, filter_desc=filter)
+ s.process_all_packets()
+ chunk = s.pop_chunks()[0]
+ if chunk is None:
+ raise RuntimeError("Failed to decode audio.")
+ waveform = chunk._elem
+ return waveform.T if channels_first else waveform
+
+
+def load_audio(
+ src: InputType,
+ frame_offset: int = 0,
+ num_frames: int = -1,
+ convert: bool = True,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+ buffer_size: int = 4096,
+) -> Tuple[torch.Tensor, int]:
+ if hasattr(src, "read") and format == "vorbis":
+ format = "ogg"
+ s = torchaudio.io.StreamReader(src, format, None, buffer_size)
+ sample_rate = int(s.get_src_stream_info(s.default_audio_stream).sample_rate)
+ filter = _get_load_filter(frame_offset, num_frames, convert)
+ waveform = _load_audio(s, filter, channels_first)
+ return waveform, sample_rate
+
+
+def _get_sample_format(dtype: torch.dtype) -> str:
+ dtype_to_format = {
+ torch.uint8: "u8",
+ torch.int16: "s16",
+ torch.int32: "s32",
+ torch.int64: "s64",
+ torch.float32: "flt",
+ torch.float64: "dbl",
+ }
+ format = dtype_to_format.get(dtype)
+ if format is None:
+ raise ValueError(f"No format found for dtype {dtype}; dtype must be one of {list(dtype_to_format.keys())}.")
+ return format
+
+
+def _native_endianness() -> str:
+ if sys.byteorder == "little":
+ return "le"
+ else:
+ return "be"
+
+
+def _get_encoder_for_wav(encoding: str, bits_per_sample: int) -> str:
+ if bits_per_sample not in {None, 8, 16, 24, 32, 64}:
+ raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.")
+ endianness = _native_endianness()
+ if not encoding:
+ if not bits_per_sample:
+ # default to PCM S16
+ return f"pcm_s16{endianness}"
+ if bits_per_sample == 8:
+ return "pcm_u8"
+ return f"pcm_s{bits_per_sample}{endianness}"
+ if encoding == "PCM_S":
+ if not bits_per_sample:
+ bits_per_sample = 16
+ if bits_per_sample == 8:
+ raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.")
+ return f"pcm_s{bits_per_sample}{endianness}"
+ if encoding == "PCM_U":
+ if bits_per_sample in (None, 8):
+ return "pcm_u8"
+ raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.")
+ if encoding == "PCM_F":
+ if not bits_per_sample:
+ bits_per_sample = 32
+ if bits_per_sample in (32, 64):
+ return f"pcm_f{bits_per_sample}{endianness}"
+ raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.")
+ if encoding == "ULAW":
+ if bits_per_sample in (None, 8):
+ return "pcm_mulaw"
+ raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.")
+ if encoding == "ALAW":
+ if bits_per_sample in (None, 8):
+ return "pcm_alaw"
+ raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.")
+ raise ValueError(f"WAV encoding {encoding} is not supported.")
+
+
+def _get_flac_sample_fmt(bps):
+ if bps is None or bps == 16:
+ return "s16"
+ if bps == 24:
+ return "s32"
+ raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).")
+
+
+def _parse_save_args(
+ ext: Optional[str],
+ format: Optional[str],
+ encoding: Optional[str],
+ bps: Optional[int],
+):
+    # torchaudio's save function accepts the following, which do not map
+    # one-to-one to FFmpeg.
+ #
+ # - format: audio format
+ # - bits_per_sample: encoder sample format
+ # - encoding: such as PCM_U8.
+ #
+    # In FFmpeg, the format is specified with the following three (among others):
+ #
+ # - muxer: could be audio format or container format.
+ # the one we passed to the constructor of StreamWriter
+ # - encoder: the audio encoder used to encode audio
+ # - encoder sample format: the format used by encoder to encode audio.
+ #
+ # If encoder sample format is different from source sample format, StreamWriter
+ # will insert a filter automatically.
+ #
+ def _type(spec):
+ # either format is exactly the specified one
+ # or extension matches to the spec AND there is no format override.
+ return format == spec or (format is None and ext == spec)
+
+ if _type("wav") or _type("amb"):
+        # wav is special because it supports different encodings through different
+        # encoders; each encoder supports only one encoder sample format.
+ #
+        # The amb format is a special case originating from libsox.
+        # It is basically the WAV format, with a slight modification.
+        # https://github.com/chirlu/sox/commit/4a4ea33edbca5972a1ed8933cc3512c7302fa67a#diff-39171191a858add9df87f5f210a34a776ac2c026842ae6db6ce97f5e68836795
+        # The distinct extension lets decoders recognize the file as ambisonic.
+        # https://www.ambisonia.com/Members/mleese/file-format-for-b-format/
+        # FFmpeg does not recognize amb, because the content is simply the WAV format.
+ muxer = "wav"
+ encoder = _get_encoder_for_wav(encoding, bps)
+ sample_fmt = None
+ elif _type("vorbis"):
+        # FFmpeg does not recognize the vorbis extension, while libsox used to.
+        # For the sake of backward compatibility (and simplicity),
+        # we support the case where users want to do save("foo.vorbis").
+ muxer = "ogg"
+ encoder = "vorbis"
+ sample_fmt = None
+ else:
+ muxer = format
+ encoder = None
+ sample_fmt = None
+ if _type("flac"):
+ sample_fmt = _get_flac_sample_fmt(bps)
+ if _type("ogg"):
+ sample_fmt = _get_flac_sample_fmt(bps)
+ return muxer, encoder, sample_fmt
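+
+# Worked example (illustrative): saving to "out.flac" with bits_per_sample=24 and
+# no explicit format yields muxer=None (StreamWriter then infers it from the
+# path), encoder=None, and encoder sample format "s32".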
+
+
+def save_audio(
+ uri: InputType,
+ src: torch.Tensor,
+ sample_rate: int,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+ encoding: Optional[str] = None,
+ bits_per_sample: Optional[int] = None,
+ buffer_size: int = 4096,
+ compression: Optional[torchaudio.io.CodecConfig] = None,
+) -> None:
+ ext = None
+ if hasattr(uri, "write"):
+ if format is None:
+ raise RuntimeError("'format' is required when saving to file object.")
+ else:
+ uri = os.path.normpath(uri)
+ if tokens := str(uri).split(".")[1:]:
+ ext = tokens[-1].lower()
+
+ muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)
+
+ if channels_first:
+ src = src.T
+
+ s = torchaudio.io.StreamWriter(uri, format=muxer, buffer_size=buffer_size)
+ s.add_audio_stream(
+ sample_rate,
+ num_channels=src.size(-1),
+ format=_get_sample_format(src.dtype),
+ encoder=encoder,
+ encoder_format=enc_fmt,
+ codec_config=compression,
+ )
+ with s.open():
+ s.write_audio_chunk(0, src)
+
+
+def _map_encoding(encoding: str) -> str:
+ for dst in ["PCM_S", "PCM_U", "PCM_F"]:
+ if dst in encoding:
+ return dst
+ if encoding == "PCM_MULAW":
+ return "ULAW"
+ elif encoding == "PCM_ALAW":
+ return "ALAW"
+ return encoding
+
+
+def _get_bits_per_sample(encoding: str, bits_per_sample: int) -> int:
+ if m := re.search(r"PCM_\w(\d+)\w*", encoding):
+ return int(m.group(1))
+ elif encoding in ["PCM_ALAW", "PCM_MULAW"]:
+ return 8
+ return bits_per_sample
+
+
+class FFmpegBackend(Backend):
+ @staticmethod
+ def info(uri: InputType, format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
+ metadata = info_audio(uri, format, buffer_size)
+ metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
+ metadata.encoding = _map_encoding(metadata.encoding)
+ return metadata
+
+ @staticmethod
+ def load(
+ uri: InputType,
+ frame_offset: int = 0,
+ num_frames: int = -1,
+ normalize: bool = True,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+ buffer_size: int = 4096,
+ ) -> Tuple[torch.Tensor, int]:
+        return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format, buffer_size)
+
+ @staticmethod
+ def save(
+ uri: InputType,
+ src: torch.Tensor,
+ sample_rate: int,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+ encoding: Optional[str] = None,
+ bits_per_sample: Optional[int] = None,
+ buffer_size: int = 4096,
+ compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
+ ) -> None:
+ if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))):
+ raise ValueError(
+ "FFmpeg backend expects non-`None` value for argument `compression` to be of ",
+ f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}",
+ )
+ save_audio(
+ uri,
+ src,
+ sample_rate,
+ channels_first,
+ format,
+ encoding,
+ bits_per_sample,
+ buffer_size,
+ compression,
+ )
+
+ @staticmethod
+ def can_decode(uri: InputType, format: Optional[str]) -> bool:
+ return True
+
+ @staticmethod
+ def can_encode(uri: InputType, format: Optional[str]) -> bool:
+ return True
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/soundfile.py b/.venv/lib/python3.11/site-packages/torchaudio/_backend/soundfile.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4be1f70999db86e4ff70b6b703d5784a891a84c
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/_backend/soundfile.py
@@ -0,0 +1,54 @@
+import os
+from typing import BinaryIO, Optional, Tuple, Union
+
+import torch
+from torchaudio.io import CodecConfig
+
+from . import soundfile_backend
+from .backend import Backend
+from .common import AudioMetaData
+
+
+class SoundfileBackend(Backend):
+ @staticmethod
+ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
+ return soundfile_backend.info(uri, format)
+
+ @staticmethod
+ def load(
+ uri: Union[BinaryIO, str, os.PathLike],
+ frame_offset: int = 0,
+ num_frames: int = -1,
+ normalize: bool = True,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+ buffer_size: int = 4096,
+ ) -> Tuple[torch.Tensor, int]:
+ return soundfile_backend.load(uri, frame_offset, num_frames, normalize, channels_first, format)
+
+ @staticmethod
+ def save(
+ uri: Union[BinaryIO, str, os.PathLike],
+ src: torch.Tensor,
+ sample_rate: int,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+ encoding: Optional[str] = None,
+ bits_per_sample: Optional[int] = None,
+ buffer_size: int = 4096,
+ compression: Optional[Union[CodecConfig, float, int]] = None,
+ ) -> None:
+ if compression:
+ raise ValueError("soundfile backend does not support argument `compression`.")
+
+ soundfile_backend.save(
+ uri, src, sample_rate, channels_first, format=format, encoding=encoding, bits_per_sample=bits_per_sample
+ )
+
+ @staticmethod
+ def can_decode(uri, format) -> bool:
+ return True
+
+ @staticmethod
+ def can_encode(uri, format) -> bool:
+ return True
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/soundfile_backend.py b/.venv/lib/python3.11/site-packages/torchaudio/_backend/soundfile_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e7b0b13cd9adf2106ef4d6885f89341822add13
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/_backend/soundfile_backend.py
@@ -0,0 +1,457 @@
+"""The new soundfile backend which will become default in 0.8.0 onward"""
+import warnings
+from typing import Optional, Tuple
+
+import torch
+from torchaudio._internal import module_utils as _mod_utils
+
+from .common import AudioMetaData
+
+
+_IS_SOUNDFILE_AVAILABLE = False
+
+# TODO: import soundfile only when it is used.
+if _mod_utils.is_module_available("soundfile"):
+ try:
+ import soundfile
+
+ _requires_soundfile = _mod_utils.no_op
+ _IS_SOUNDFILE_AVAILABLE = True
+ except Exception:
+ _requires_soundfile = _mod_utils.fail_with_message(
+ "requires soundfile, but we failed to import it. Please check the installation of soundfile."
+ )
+else:
+ _requires_soundfile = _mod_utils.fail_with_message(
+ "requires soundfile, but it is not installed. Please install soundfile."
+ )
+
+
+# Mapping from soundfile subtype to number of bits per sample.
+# This is mostly heuristic, and the value is set to 0 when it is irrelevant
+# (lossy formats) or when it can't be inferred.
+# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
+# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
+# the default seems to be 8 bits but it can be compressed further to 4 bits.
+# The dict is inspired by
+# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
+_SUBTYPE_TO_BITS_PER_SAMPLE = {
+ "PCM_S8": 8, # Signed 8 bit data
+ "PCM_16": 16, # Signed 16 bit data
+ "PCM_24": 24, # Signed 24 bit data
+ "PCM_32": 32, # Signed 32 bit data
+ "PCM_U8": 8, # Unsigned 8 bit data (WAV and RAW only)
+ "FLOAT": 32, # 32 bit float data
+ "DOUBLE": 64, # 64 bit float data
+ "ULAW": 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
+ "ALAW": 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
+ "IMA_ADPCM": 0, # IMA ADPCM.
+ "MS_ADPCM": 0, # Microsoft ADPCM.
+ "GSM610": 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
+ "VOX_ADPCM": 0, # OKI / Dialogix ADPCM
+ "G721_32": 0, # 32kbs G721 ADPCM encoding.
+ "G723_24": 0, # 24kbs G723 ADPCM encoding.
+ "G723_40": 0, # 40kbs G723 ADPCM encoding.
+ "DWVW_12": 12, # 12 bit Delta Width Variable Word encoding.
+ "DWVW_16": 16, # 16 bit Delta Width Variable Word encoding.
+ "DWVW_24": 24, # 24 bit Delta Width Variable Word encoding.
+ "DWVW_N": 0, # N bit Delta Width Variable Word encoding.
+ "DPCM_8": 8, # 8 bit differential PCM (XI only)
+ "DPCM_16": 16, # 16 bit differential PCM (XI only)
+ "VORBIS": 0, # Xiph Vorbis encoding. (lossy)
+ "ALAC_16": 16, # Apple Lossless Audio Codec (16 bit).
+ "ALAC_20": 20, # Apple Lossless Audio Codec (20 bit).
+ "ALAC_24": 24, # Apple Lossless Audio Codec (24 bit).
+ "ALAC_32": 32, # Apple Lossless Audio Codec (32 bit).
+}
+
+
+def _get_bit_depth(subtype):
+ if subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
+ warnings.warn(
+ f"The {subtype} subtype is unknown to TorchAudio. As a result, the bits_per_sample "
+ "attribute will be set to 0. If you are seeing this warning, please "
+ "report by opening an issue on github (after checking for existing/closed ones). "
+ "You may otherwise ignore this warning."
+ )
+ return _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype, 0)
+
+
+_SUBTYPE_TO_ENCODING = {
+ "PCM_S8": "PCM_S",
+ "PCM_16": "PCM_S",
+ "PCM_24": "PCM_S",
+ "PCM_32": "PCM_S",
+ "PCM_U8": "PCM_U",
+ "FLOAT": "PCM_F",
+ "DOUBLE": "PCM_F",
+ "ULAW": "ULAW",
+ "ALAW": "ALAW",
+ "VORBIS": "VORBIS",
+}
+
+
+def _get_encoding(format: str, subtype: str):
+ if format == "FLAC":
+ return "FLAC"
+ return _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN")
+
+
+@_requires_soundfile
+def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
+ """Get signal information of an audio file.
+
+ Note:
+ ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
+        ``pathlib.Path`` objects as well. This is for consistency with the ``"sox_io"`` backend,
+        which has a restriction on type annotations due to TorchScript compiler compatibility.
+
+ Args:
+ filepath (path-like object or file-like object):
+ Source of audio data.
+ format (str or None, optional):
+ Not used. PySoundFile does not accept format hint.
+
+ Returns:
+ AudioMetaData: meta data of the given audio.
+
+ """
+ sinfo = soundfile.info(filepath)
+ return AudioMetaData(
+ sinfo.samplerate,
+ sinfo.frames,
+ sinfo.channels,
+ bits_per_sample=_get_bit_depth(sinfo.subtype),
+ encoding=_get_encoding(sinfo.format, sinfo.subtype),
+ )
+
+
+_SUBTYPE2DTYPE = {
+ "PCM_S8": "int8",
+ "PCM_U8": "uint8",
+ "PCM_16": "int16",
+ "PCM_32": "int32",
+ "FLOAT": "float32",
+ "DOUBLE": "float64",
+}
+
+
+@_requires_soundfile
+def load(
+ filepath: str,
+ frame_offset: int = 0,
+ num_frames: int = -1,
+ normalize: bool = True,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+) -> Tuple[torch.Tensor, int]:
+ """Load audio data from file.
+
+ Note:
+ The formats this function can handle depend on the soundfile installation.
+        This function is tested on the following formats:
+
+ * WAV
+
+ * 32-bit floating-point
+ * 32-bit signed integer
+ * 16-bit signed integer
+ * 8-bit unsigned integer
+
+ * FLAC
+ * OGG/VORBIS
+ * SPHERE
+
+    By default (``normalize=True``, ``channels_first=True``), this function returns a Tensor with
+    ``float32`` dtype and a shape of `[channel, time]`.
+
+ .. warning::
+
+ ``normalize`` argument does not perform volume normalization.
+ It only converts the sample type to `torch.float32` from the native sample
+ type.
+
+        When the input format is WAV with an integer type, such as 32-bit signed integer, 16-bit
+        signed integer, 24-bit signed integer, or 8-bit unsigned integer, providing ``normalize=False``
+        makes this function return an integer Tensor, where the samples are expressed within the whole range
+        of the corresponding dtype, that is, an ``int32`` tensor for 32-bit signed PCM,
+        ``int16`` for 16-bit signed PCM, and ``uint8`` for 8-bit unsigned PCM. Since torch does not
+        support an ``int24`` dtype, 24-bit signed PCM is converted to ``int32`` tensors.
+
+ ``normalize`` argument has no effect on 32-bit floating-point WAV and other formats, such as
+ ``flac`` and ``mp3``.
+
+        For these formats, this function always returns a ``float32`` Tensor with values normalized to ``[-1.0, 1.0]``.
+
+ Note:
+ ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
+        ``pathlib.Path`` objects as well. This is for consistency with the ``"sox_io"`` backend,
+        which has a restriction on type annotations due to TorchScript compiler compatibility.
+
+ Args:
+ filepath (path-like object or file-like object):
+ Source of audio data.
+ frame_offset (int, optional):
+ Number of frames to skip before start reading data.
+ num_frames (int, optional):
+ Maximum number of frames to read. ``-1`` reads all the remaining samples,
+ starting from ``frame_offset``.
+            This function may return fewer frames if there are not enough
+            frames in the given file.
+ normalize (bool, optional):
+ When ``True``, this function converts the native sample type to ``float32``.
+ Default: ``True``.
+
+            If the input file is an integer WAV, giving ``False`` will change the resulting Tensor type to
+            an integer type.
+            This argument has no effect on formats other than integer WAV.
+
+ channels_first (bool, optional):
+ When True, the returned Tensor has dimension `[channel, time]`.
+ Otherwise, the returned Tensor's dimension is `[time, channel]`.
+ format (str or None, optional):
+ Not used. PySoundFile does not accept format hint.
+
+ Returns:
+ (torch.Tensor, int): Resulting Tensor and sample rate.
+ If the input file has integer wav format and normalization is off, then it has
+ integer type, else ``float32`` type. If ``channels_first=True``, it has
+ `[channel, time]` else `[time, channel]`.
+ """
+ with soundfile.SoundFile(filepath, "r") as file_:
+ if file_.format != "WAV" or normalize:
+ dtype = "float32"
+ elif file_.subtype not in _SUBTYPE2DTYPE:
+ raise ValueError(f"Unsupported subtype: {file_.subtype}")
+ else:
+ dtype = _SUBTYPE2DTYPE[file_.subtype]
+
+ frames = file_._prepare_read(frame_offset, None, num_frames)
+ waveform = file_.read(frames, dtype, always_2d=True)
+ sample_rate = file_.samplerate
+
+ waveform = torch.from_numpy(waveform)
+ if channels_first:
+ waveform = waveform.t()
+ return waveform, sample_rate
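+
+# Usage sketch (illustrative, not part of the upstream file): reading a 16-bit
+# PCM WAV with normalization off yields an int16 tensor of shape [channel, time].
+#
+#   waveform, sample_rate = load("speech.wav", normalize=False)
+#   assert waveform.dtype == torch.int16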
+
+
+def _get_subtype_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int):
+ if not encoding:
+ if not bits_per_sample:
+ subtype = {
+ torch.uint8: "PCM_U8",
+ torch.int16: "PCM_16",
+ torch.int32: "PCM_32",
+ torch.float32: "FLOAT",
+ torch.float64: "DOUBLE",
+ }.get(dtype)
+ if not subtype:
+ raise ValueError(f"Unsupported dtype for wav: {dtype}")
+ return subtype
+ if bits_per_sample == 8:
+ return "PCM_U8"
+ return f"PCM_{bits_per_sample}"
+ if encoding == "PCM_S":
+ if not bits_per_sample:
+ return "PCM_32"
+ if bits_per_sample == 8:
+ raise ValueError("wav does not support 8-bit signed PCM encoding.")
+ return f"PCM_{bits_per_sample}"
+ if encoding == "PCM_U":
+ if bits_per_sample in (None, 8):
+ return "PCM_U8"
+ raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
+ if encoding == "PCM_F":
+ if bits_per_sample in (None, 32):
+ return "FLOAT"
+ if bits_per_sample == 64:
+ return "DOUBLE"
+ raise ValueError("wav only supports 32/64-bit float PCM encoding.")
+ if encoding == "ULAW":
+ if bits_per_sample in (None, 8):
+ return "ULAW"
+ raise ValueError("wav only supports 8-bit mu-law encoding.")
+ if encoding == "ALAW":
+ if bits_per_sample in (None, 8):
+ return "ALAW"
+ raise ValueError("wav only supports 8-bit a-law encoding.")
+ raise ValueError(f"wav does not support {encoding}.")
+
+
+def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
+ if encoding in (None, "PCM_S"):
+ return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32"
+ if encoding in ("PCM_U", "PCM_F"):
+ raise ValueError(f"sph does not support {encoding} encoding.")
+ if encoding == "ULAW":
+ if bits_per_sample in (None, 8):
+ return "ULAW"
+ raise ValueError("sph only supports 8-bit for mu-law encoding.")
+ if encoding == "ALAW":
+ return "ALAW"
+ raise ValueError(f"sph does not support {encoding}.")
+
+
+def _get_subtype(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int):
+ if format == "wav":
+ return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
+ if format == "flac":
+ if encoding:
+ raise ValueError("flac does not support encoding.")
+ if not bits_per_sample:
+ return "PCM_16"
+ if bits_per_sample > 24:
+ raise ValueError("flac does not support bits_per_sample > 24.")
+ return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
+ if format in ("ogg", "vorbis"):
+ if bits_per_sample:
+ raise ValueError("ogg/vorbis does not support bits_per_sample.")
+ if encoding is None or encoding == "vorbis":
+ return "VORBIS"
+ if encoding == "opus":
+ return "OPUS"
+ raise ValueError(f"Unexpected encoding: {encoding}")
+ if format == "mp3":
+ return "MPEG_LAYER_III"
+ if format == "sph":
+ return _get_subtype_for_sphere(encoding, bits_per_sample)
+ if format in ("nis", "nist"):
+ return "PCM_16"
+ raise ValueError(f"Unsupported format: {format}")
+
+
+@_requires_soundfile
+def save(
+ filepath: str,
+ src: torch.Tensor,
+ sample_rate: int,
+ channels_first: bool = True,
+ compression: Optional[float] = None,
+ format: Optional[str] = None,
+ encoding: Optional[str] = None,
+ bits_per_sample: Optional[int] = None,
+):
+ """Save audio data to file.
+
+ Note:
+ The formats this function can handle depend on the soundfile installation.
+        This function is tested on the following formats:
+
+ * WAV
+
+ * 32-bit floating-point
+ * 32-bit signed integer
+ * 16-bit signed integer
+ * 8-bit unsigned integer
+
+ * FLAC
+ * OGG/VORBIS
+ * SPHERE
+
+ Note:
+ ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
+        ``pathlib.Path`` objects as well. This is for consistency with the ``"sox_io"`` backend,
+        which has a restriction on type annotations due to TorchScript compiler compatibility.
+
+ Args:
+ filepath (str or pathlib.Path): Path to audio file.
+        src (torch.Tensor): Audio data to save. Must be a 2D tensor.
+        sample_rate (int): Sampling rate.
+ channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
+ otherwise `[time, channel]`.
+        compression (float or None, optional): Not used.
+            It is here only for interface compatibility with the "sox_io" backend.
+ format (str or None, optional): Override the audio format.
+ When ``filepath`` argument is path-like object, audio format is
+ inferred from file extension. If the file extension is missing or
+ different, you can specify the correct format with this argument.
+
+ When ``filepath`` argument is file-like object,
+ this argument is required.
+
+ Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
+ ``"flac"`` and ``"sph"``.
+ encoding (str or None, optional): Changes the encoding for supported formats.
+            This argument is effective only for supported formats, such as
+            ``"wav"``, ``"flac"`` and ``"sph"``. Valid values are:
+
+ - ``"PCM_S"`` (signed integer Linear PCM)
+ - ``"PCM_U"`` (unsigned integer Linear PCM)
+ - ``"PCM_F"`` (floating point PCM)
+ - ``"ULAW"`` (mu-law)
+ - ``"ALAW"`` (a-law)
+
+ bits_per_sample (int or None, optional): Changes the bit depth for the
+ supported formats.
+ When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
+ you can change the bit depth.
+ Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
+
+ Supported formats/encodings/bit depth/compression are:
+
+ ``"wav"``
+ - 32-bit floating-point PCM
+ - 32-bit signed integer PCM
+ - 24-bit signed integer PCM
+ - 16-bit signed integer PCM
+ - 8-bit unsigned integer PCM
+ - 8-bit mu-law
+ - 8-bit a-law
+
+ Note:
+ Default encoding/bit depth is determined by the dtype of
+ the input Tensor.
+
+ ``"flac"``
+ - 8-bit
+ - 16-bit (default)
+ - 24-bit
+
+ ``"ogg"``, ``"vorbis"``
+ - Doesn't accept changing configuration.
+
+ ``"sph"``
+ - 8-bit signed integer PCM
+ - 16-bit signed integer PCM
+ - 24-bit signed integer PCM
+ - 32-bit signed integer PCM (default)
+ - 8-bit mu-law
+ - 8-bit a-law
+ - 16-bit a-law
+ - 24-bit a-law
+ - 32-bit a-law
+
+ """
+ if src.ndim != 2:
+ raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
+ if compression is not None:
+ warnings.warn(
+ '`save` function of "soundfile" backend does not support "compression" parameter. '
+ "The argument is silently ignored."
+ )
+ if hasattr(filepath, "write"):
+ if format is None:
+ raise RuntimeError("`format` is required when saving to file object.")
+ ext = format.lower()
+ else:
+ ext = str(filepath).split(".")[-1].lower()
+
+ if bits_per_sample not in (None, 8, 16, 24, 32, 64):
+ raise ValueError("Invalid bits_per_sample.")
+ if bits_per_sample == 24:
+ warnings.warn(
+ "Saving audio with 24 bits per sample might warp samples near -1. "
+ "Using 16 bits per sample might be able to avoid this."
+ )
+ subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)
+
+    # sph is an extension used in TED-LIUM, but soundfile does not recognize it as the NIST format,
+    # so we extend the extension list manually here.
+ if ext in ["nis", "nist", "sph"] and format is None:
+ format = "NIST"
+
+ if channels_first:
+ src = src.t()
+
+ soundfile.write(file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format)
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/sox.py b/.venv/lib/python3.11/site-packages/torchaudio/_backend/sox.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfcd8a4f8beadcd80787a354d47219f9edfb98e8
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/_backend/sox.py
@@ -0,0 +1,91 @@
+import os
+from typing import BinaryIO, Optional, Tuple, Union
+
+import torch
+import torchaudio
+
+from .backend import Backend
+from .common import AudioMetaData
+
+sox_ext = torchaudio._extension.lazy_import_sox_ext()
+
+
+class SoXBackend(Backend):
+ @staticmethod
+ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
+ if hasattr(uri, "read"):
+ raise ValueError(
+ "SoX backend does not support reading from file-like objects. ",
+ "Please use an alternative backend that does support reading from file-like objects, e.g. FFmpeg.",
+ )
+ else:
+ sinfo = sox_ext.get_info(uri, format)
+ if sinfo:
+ return AudioMetaData(*sinfo)
+ else:
+ raise RuntimeError(f"Failed to fetch metadata for {uri}.")
+
+ @staticmethod
+ def load(
+ uri: Union[BinaryIO, str, os.PathLike],
+ frame_offset: int = 0,
+ num_frames: int = -1,
+ normalize: bool = True,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+ buffer_size: int = 4096,
+ ) -> Tuple[torch.Tensor, int]:
+ if hasattr(uri, "read"):
+ raise ValueError(
+ "SoX backend does not support loading from file-like objects. ",
+ "Please use an alternative backend that does support loading from file-like objects, e.g. FFmpeg.",
+ )
+ else:
+ ret = sox_ext.load_audio_file(uri, frame_offset, num_frames, normalize, channels_first, format)
+ if not ret:
+ raise RuntimeError(f"Failed to load audio from {uri}.")
+ return ret
+
+ @staticmethod
+ def save(
+ uri: Union[BinaryIO, str, os.PathLike],
+ src: torch.Tensor,
+ sample_rate: int,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+ encoding: Optional[str] = None,
+ bits_per_sample: Optional[int] = None,
+ buffer_size: int = 4096,
+ compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
+ ) -> None:
+ if not isinstance(compression, (float, int, type(None))):
+ raise ValueError(
+ "SoX backend expects non-`None` value for argument `compression` to be of ",
+ f"type `float` or `int`, but received value of type {type(compression)}",
+ )
+ if hasattr(uri, "write"):
+ raise ValueError(
+ "SoX backend does not support writing to file-like objects. ",
+ "Please use an alternative backend that does support writing to file-like objects, e.g. FFmpeg.",
+ )
+ else:
+ sox_ext.save_audio_file(
+ uri,
+ src,
+ sample_rate,
+ channels_first,
+ compression,
+ format,
+ encoding,
+ bits_per_sample,
+ )
+
+ @staticmethod
+ def can_decode(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str]) -> bool:
+ # i.e. not a file-like object.
+ return not hasattr(uri, "read")
+
+ @staticmethod
+ def can_encode(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str]) -> bool:
+ # i.e. not a file-like object.
+ return not hasattr(uri, "write")
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_backend/utils.py b/.venv/lib/python3.11/site-packages/torchaudio/_backend/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cde6b1927d2b35e34d4ab0e99124fde146e724d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/_backend/utils.py
@@ -0,0 +1,317 @@
+import os
+from functools import lru_cache
+from typing import BinaryIO, Dict, Optional, Tuple, Type, Union
+
+import torch
+
+from torchaudio._extension import lazy_import_sox_ext
+from torchaudio.io import CodecConfig
+from torio._extension import lazy_import_ffmpeg_ext
+
+from . import soundfile_backend
+
+from .backend import Backend
+from .common import AudioMetaData
+from .ffmpeg import FFmpegBackend
+from .soundfile import SoundfileBackend
+from .sox import SoXBackend
+
+
+@lru_cache(None)
+def get_available_backends() -> Dict[str, Type[Backend]]:
+ backend_specs: Dict[str, Type[Backend]] = {}
+ if lazy_import_ffmpeg_ext().is_available():
+ backend_specs["ffmpeg"] = FFmpegBackend
+ if lazy_import_sox_ext().is_available():
+ backend_specs["sox"] = SoXBackend
+ if soundfile_backend._IS_SOUNDFILE_AVAILABLE:
+ backend_specs["soundfile"] = SoundfileBackend
+ return backend_specs
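+
+# Note (illustrative): the insertion order above doubles as dispatch priority; the
+# dispatchers below try ffmpeg first, then sox, then soundfile, picking the first
+# backend whose can_decode/can_encode returns True.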
+
+
+def get_backend(backend_name, backends) -> Backend:
+ if backend := backends.get(backend_name):
+ return backend
+ else:
+ raise ValueError(
+ f"Unsupported backend '{backend_name}' specified; ",
+ f"please select one of {list(backends.keys())} instead.",
+ )
+
+
+def get_info_func():
+ backends = get_available_backends()
+
+ def dispatcher(
+ uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], backend_name: Optional[str]
+ ) -> Backend:
+ if backend_name is not None:
+ return get_backend(backend_name, backends)
+
+ for backend in backends.values():
+ if backend.can_decode(uri, format):
+ return backend
+ raise RuntimeError(f"Couldn't find appropriate backend to handle uri {uri} and format {format}.")
+
+ def info(
+ uri: Union[BinaryIO, str, os.PathLike],
+ format: Optional[str] = None,
+ buffer_size: int = 4096,
+ backend: Optional[str] = None,
+ ) -> AudioMetaData:
+ """Get signal information of an audio file.
+
+ Note:
+ When the input type is file-like object, this function cannot
+ get the correct length (``num_samples``) for certain formats,
+ such as ``vorbis``.
+ In this case, the value of ``num_samples`` is ``0``.
+
+ Args:
+ uri (path-like object or file-like object):
+ Source of audio data. The following types are accepted:
+
+ * ``path-like``: File path or URL.
+ * ``file-like``: Object with ``read(size: int) -> bytes`` method,
+ which returns byte string of at most ``size`` length.
+
+ format (str or None, optional):
+                If not ``None``, interpreted as a hint that may allow the backend to override the detected format.
+ (Default: ``None``)
+
+ buffer_size (int, optional):
+ Size of buffer to use when processing file-like objects, in bytes. (Default: ``4096``)
+
+ backend (str or None, optional):
+ I/O backend to use.
+                If ``None``, the function selects a backend based on the input and the available backends.
+ Otherwise, must be one of [``"ffmpeg"``, ``"sox"``, ``"soundfile"``],
+ with the corresponding backend available.
+ (Default: ``None``)
+
+ .. seealso::
+ :ref:`backend`
+
+ Returns:
+ AudioMetaData
+ """
+ backend = dispatcher(uri, format, backend)
+ return backend.info(uri, format, buffer_size)
+
+ return info
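+
+# Usage sketch (illustrative, not part of the upstream file): the closure
+# returned above backs ``torchaudio.info``.
+#
+#   meta = torchaudio.info("clip.wav")
+#   print(meta.sample_rate, meta.num_frames, meta.num_channels)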
+
+
+def get_load_func():
+ backends = get_available_backends()
+
+ def dispatcher(
+ uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], backend_name: Optional[str]
+ ) -> Backend:
+ if backend_name is not None:
+ return get_backend(backend_name, backends)
+
+ for backend in backends.values():
+ if backend.can_decode(uri, format):
+ return backend
+ raise RuntimeError(f"Couldn't find appropriate backend to handle uri {uri} and format {format}.")
+
+ def load(
+ uri: Union[BinaryIO, str, os.PathLike],
+ frame_offset: int = 0,
+ num_frames: int = -1,
+ normalize: bool = True,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+ buffer_size: int = 4096,
+ backend: Optional[str] = None,
+ ) -> Tuple[torch.Tensor, int]:
+ """Load audio data from source.
+
+        By default (``normalize=True``, ``channels_first=True``), this function returns a Tensor with
+        ``float32`` dtype and a shape of `[channel, time]`.
+
+ Note:
+ The formats this function can handle depend on the availability of backends.
+ Please use the following functions to fetch the supported formats.
+
+ - FFmpeg: :py:func:`torchaudio.utils.ffmpeg_utils.get_audio_decoders`
+ - Sox: :py:func:`torchaudio.utils.sox_utils.list_read_formats`
+ - SoundFile: Refer to the official SoundFile documentation.
+
+ .. warning::
+
+ The ``normalize`` argument does not perform volume normalization.
+ It only converts the sample type to ``torch.float32`` from the native sample type.
+
+ When the input format is WAV with an integer type, such as 32-bit signed integer, 16-bit
+ signed integer, 24-bit signed integer, or 8-bit unsigned integer, providing ``normalize=False``
+ makes this function return an integer Tensor, where the samples are expressed within the whole range
+ of the corresponding dtype, that is, an ``int32`` tensor for 32-bit signed PCM,
+ ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. Since torch does not
+ support the ``int24`` dtype, 24-bit signed PCM is converted to ``int32`` tensors.
+
+ The ``normalize`` argument has no effect on 32-bit floating-point WAV and other formats, such as
+ ``flac`` and ``mp3``.
+ For these formats, this function always returns a ``float32`` Tensor with values
+ normalized to ``[-1.0, 1.0]``.
+
+
+ Args:
+ uri (path-like object or file-like object):
+ Source of audio data.
+ frame_offset (int, optional):
+ Number of frames to skip before start reading data.
+ num_frames (int, optional):
+ Maximum number of frames to read. ``-1`` reads all the remaining samples,
+ starting from ``frame_offset``.
+ This function may return fewer frames if there are not enough
+ frames in the given file.
+ normalize (bool, optional):
+ When ``True``, this function converts the native sample type to ``float32``.
+ Default: ``True``.
+
+ If the input file is an integer WAV, giving ``False`` changes the resulting Tensor type to
+ an integer type.
+ This argument has no effect on formats other than integer WAV.
+
+ channels_first (bool, optional):
+ When True, the returned Tensor has dimension `[channel, time]`.
+ Otherwise, the returned Tensor's dimension is `[time, channel]`.
+
+ format (str or None, optional):
+ If not ``None``, interpreted as a hint that may allow the backend to override the detected format.
+ (Default: ``None``)
+
+ buffer_size (int, optional):
+ Size of buffer to use when processing file-like objects, in bytes. (Default: ``4096``)
+
+ backend (str or None, optional):
+ I/O backend to use.
+ If ``None``, function selects backend given input and available backends.
+ Otherwise, must be one of [``"ffmpeg"``, ``"sox"``, ``"soundfile"``],
+ with the corresponding backend being available. (Default: ``None``)
+
+ .. seealso::
+ :ref:`backend`
+
+ Returns:
+ (torch.Tensor, int): Resulting Tensor and sample rate.
+ If the input file has integer WAV format and normalization is off, then it has
+ integer type, else ``float32`` type. If ``channels_first=True``, it has
+ `[channel, time]` else `[time, channel]`.
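+
+ Example (an illustrative sketch; assumes ``"foo.wav"`` exists):
+ >>> waveform, sample_rate = torchaudio.load("foo.wav")  # float32, `[channel, time]`
+ >>> raw, sample_rate = torchaudio.load("foo.wav", normalize=False)  # integer Tensor for integer WAV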
+ """
+ backend = dispatcher(uri, format, backend)
+ return backend.load(uri, frame_offset, num_frames, normalize, channels_first, format, buffer_size)
+
+ return load
+
+
+def get_save_func():
+ backends = get_available_backends()
+
+ def dispatcher(
+ uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], backend_name: Optional[str]
+ ) -> Backend:
+ if backend_name is not None:
+ return get_backend(backend_name, backends)
+
+ for backend in backends.values():
+ if backend.can_encode(uri, format):
+ return backend
+ raise RuntimeError(f"Couldn't find appropriate backend to handle uri {uri} and format {format}.")
+
+ def save(
+ uri: Union[BinaryIO, str, os.PathLike],
+ src: torch.Tensor,
+ sample_rate: int,
+ channels_first: bool = True,
+ format: Optional[str] = None,
+ encoding: Optional[str] = None,
+ bits_per_sample: Optional[int] = None,
+ buffer_size: int = 4096,
+ backend: Optional[str] = None,
+ compression: Optional[Union[CodecConfig, float, int]] = None,
+ ):
+ """Save audio data to file.
+
+ Note:
+ The formats this function can handle depend on the availability of backends.
+ Please use the following functions to fetch the supported formats.
+
+ - FFmpeg: :py:func:`torchaudio.utils.ffmpeg_utils.get_audio_encoders`
+ - Sox: :py:func:`torchaudio.utils.sox_utils.list_write_formats`
+ - SoundFile: Refer to the official SoundFile documentation.
+
+ Args:
+ uri (str or pathlib.Path): Path to audio file.
+ src (torch.Tensor): Audio data to save. Must be a 2D tensor.
+ sample_rate (int): sampling rate of the audio data.
+ channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
+ otherwise `[time, channel]`.
+ format (str or None, optional): Override the audio format.
+ When ``uri`` argument is path-like object, audio format is
+ inferred from file extension. If the file extension is missing or
+ different, you can specify the correct format with this argument.
+
+ When ``uri`` argument is file-like object,
+ this argument is required.
+
+ Valid values are ``"wav"``, ``"ogg"``, and ``"flac"``.
+ encoding (str or None, optional): Changes the encoding for supported formats.
+ This argument is effective only for supported formats, i.e.
+ ``"wav"`` and ``"flac"``. Valid values are
+
+ - ``"PCM_S"`` (signed integer Linear PCM)
+ - ``"PCM_U"`` (unsigned integer Linear PCM)
+ - ``"PCM_F"`` (floating point PCM)
+ - ``"ULAW"`` (mu-law)
+ - ``"ALAW"`` (a-law)
+
+ bits_per_sample (int or None, optional): Changes the bit depth for the
+ supported formats.
+ When ``format`` is one of ``"wav"`` and ``"flac"``,
+ you can change the bit depth.
+ Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
+
+ buffer_size (int, optional):
+ Size of buffer to use when processing file-like objects, in bytes. (Default: ``4096``)
+
+ backend (str or None, optional):
+ I/O backend to use.
+ If ``None``, function selects backend given input and available backends.
+ Otherwise, must be one of [``"ffmpeg"``, ``"sox"``, ``"soundfile"``],
+ with the corresponding backend being available.
+ (Default: ``None``)
+
+ .. seealso::
+ :ref:`backend`
+
+ compression (CodecConfig, float, int, or None, optional):
+ Compression configuration to apply.
+
+ If the selected backend is FFmpeg, an instance of :py:class:`CodecConfig` must be provided.
+
+ Otherwise, if the selected backend is SoX, a float or int value corresponding to option ``-C`` of the
+ ``sox`` command line interface must be provided. For instance:
+
+ ``"mp3"``
+ Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
+ VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
+
+ ``"flac"``
+ Whole number from ``0`` to ``8``. ``8`` is default and highest compression.
+
+ ``"ogg"``, ``"vorbis"``
+ Number from ``-1`` to ``10``; ``-1`` is the highest compression
+ and lowest quality. Default: ``3``.
+
+ Refer to http://sox.sourceforge.net/soxformat.html for more details.
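+
+ Example (an illustrative sketch; writes one second of silence as 16-bit PCM):
+ >>> silence = torch.zeros(1, 16000)
+ >>> torchaudio.save("out.wav", silence, 16000, encoding="PCM_S", bits_per_sample=16)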
+
+ """
+ backend = dispatcher(uri, format, backend)
+ return backend.save(
+ uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size, compression
+ )
+
+ return save
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..84df7e7d697616076d549dc4163b55cd34335a25
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/backend/__init__.py
@@ -0,0 +1,8 @@
+# NOTE:
+# The entire `torchaudio.backend` module is deprecated.
+# New things should be added to `torchaudio._backend`.
+# Only things related to backward compatibility should be placed here.
+
+from . import common, no_backend, soundfile_backend, sox_io_backend # noqa
+
+__all__ = []
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/_sox_io_backend.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/_sox_io_backend.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7628c492841d3e11873ace25115a6d579fa7b991
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/_sox_io_backend.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/soundfile_backend.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/soundfile_backend.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..682db31829a6e4b4cb6e06deac44394960d24753
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/backend/__pycache__/soundfile_backend.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/_no_backend.py b/.venv/lib/python3.11/site-packages/torchaudio/backend/_no_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcbb2ad84aefcf33b181b686ee1105e532a8661d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/backend/_no_backend.py
@@ -0,0 +1,25 @@
+from pathlib import Path
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+from torchaudio import AudioMetaData
+
+
+def load(
+ filepath: Union[str, Path],
+ out: Optional[Tensor] = None,
+ normalization: Union[bool, float, Callable] = True,
+ channels_first: bool = True,
+ num_frames: int = 0,
+ offset: int = 0,
+ filetype: Optional[str] = None,
+) -> Tuple[Tensor, int]:
+ raise RuntimeError("No audio I/O backend is available.")
+
+
+def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
+ raise RuntimeError("No audio I/O backend is available.")
+
+
+def info(filepath: str) -> AudioMetaData:
+ raise RuntimeError("No audio I/O backend is available.")
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/common.py b/.venv/lib/python3.11/site-packages/torchaudio/backend/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f736bf4017c952c850dcbb3cc5fe1fe14f2715f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/backend/common.py
@@ -0,0 +1,13 @@
+def __getattr__(name: str):
+ if name == "AudioMetaData":
+ import warnings
+
+ warnings.warn(
+ "`torchaudio.backend.common.AudioMetaData` has been moved to "
+ "`torchaudio.AudioMetaData`. Please update the import path.",
+ stacklevel=2,
+ )
+ from torchaudio import AudioMetaData
+
+ return AudioMetaData
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
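+
+
+# A sketch of how the module-level ``__getattr__`` above behaves (PEP 562).
+# It only runs for attributes not found through normal lookup, so the
+# deprecated name keeps resolving while emitting a warning:
+#
+# >>> from torchaudio.backend import common
+# >>> common.AudioMetaData  # warns, then returns torchaudio.AudioMetaData
+# >>> common.no_such_name   # raises AttributeError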
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/soundfile_backend.py b/.venv/lib/python3.11/site-packages/torchaudio/backend/soundfile_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e81db372a12800a869f4a48291a77739c4f07e6
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/backend/soundfile_backend.py
@@ -0,0 +1,14 @@
+def __getattr__(name: str):
+ import warnings
+
+ warnings.warn(
+ "Torchaudio's I/O functions now support par-call bakcend dispatch. "
+ "Importing backend implementation directly is no longer guaranteed to work. "
+ "Please use `backend` keyword with load/save/info function, instead of "
+ "calling the udnerlying implementation directly.",
+ stacklevel=2,
+ )
+
+ from torchaudio._backend import soundfile_backend
+
+ return getattr(soundfile_backend, name)
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/backend/sox_io_backend.py b/.venv/lib/python3.11/site-packages/torchaudio/backend/sox_io_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..a361ab87a5dba694e247f7f2205c0f2a25b19686
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/backend/sox_io_backend.py
@@ -0,0 +1,14 @@
+def __getattr__(name: str):
+ import warnings
+
+ warnings.warn(
+ "Torchaudio's I/O functions now support par-call bakcend dispatch. "
+ "Importing backend implementation directly is no longer guaranteed to work. "
+ "Please use `backend` keyword with load/save/info function, instead of "
+ "calling the udnerlying implementation directly.",
+ stacklevel=2,
+ )
+
+ from . import _sox_io_backend
+
+ return getattr(_sox_io_backend, name)
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/functional/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b866977c67c9cbcb6098c39ea5c26ed7b19d5979
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/functional/__init__.py
@@ -0,0 +1,127 @@
+from ._alignment import forced_align, merge_tokens, TokenSpan
+from .filtering import (
+ allpass_biquad,
+ band_biquad,
+ bandpass_biquad,
+ bandreject_biquad,
+ bass_biquad,
+ biquad,
+ contrast,
+ dcshift,
+ deemph_biquad,
+ dither,
+ equalizer_biquad,
+ filtfilt,
+ flanger,
+ gain,
+ highpass_biquad,
+ lfilter,
+ lowpass_biquad,
+ overdrive,
+ phaser,
+ riaa_biquad,
+ treble_biquad,
+ vad,
+)
+from .functional import (
+ add_noise,
+ amplitude_to_DB,
+ apply_beamforming,
+ apply_codec,
+ compute_deltas,
+ convolve,
+ create_dct,
+ DB_to_amplitude,
+ deemphasis,
+ detect_pitch_frequency,
+ edit_distance,
+ fftconvolve,
+ frechet_distance,
+ griffinlim,
+ inverse_spectrogram,
+ linear_fbanks,
+ loudness,
+ mask_along_axis,
+ mask_along_axis_iid,
+ melscale_fbanks,
+ mu_law_decoding,
+ mu_law_encoding,
+ mvdr_weights_rtf,
+ mvdr_weights_souden,
+ phase_vocoder,
+ pitch_shift,
+ preemphasis,
+ psd,
+ resample,
+ rnnt_loss,
+ rtf_evd,
+ rtf_power,
+ sliding_window_cmn,
+ spectral_centroid,
+ spectrogram,
+ speed,
+)
+
+__all__ = [
+ "amplitude_to_DB",
+ "compute_deltas",
+ "create_dct",
+ "melscale_fbanks",
+ "linear_fbanks",
+ "DB_to_amplitude",
+ "loudness",
+ "detect_pitch_frequency",
+ "griffinlim",
+ "mask_along_axis",
+ "mask_along_axis_iid",
+ "mu_law_encoding",
+ "mu_law_decoding",
+ "phase_vocoder",
+ "sliding_window_cmn",
+ "spectrogram",
+ "inverse_spectrogram",
+ "spectral_centroid",
+ "allpass_biquad",
+ "band_biquad",
+ "bandpass_biquad",
+ "bandreject_biquad",
+ "bass_biquad",
+ "biquad",
+ "contrast",
+ "dither",
+ "dcshift",
+ "deemph_biquad",
+ "equalizer_biquad",
+ "filtfilt",
+ "flanger",
+ "forced_align",
+ "merge_tokens",
+ "TokenSpan",
+ "gain",
+ "highpass_biquad",
+ "lfilter",
+ "lowpass_biquad",
+ "overdrive",
+ "phaser",
+ "riaa_biquad",
+ "treble_biquad",
+ "vad",
+ "apply_codec",
+ "resample",
+ "edit_distance",
+ "pitch_shift",
+ "rnnt_loss",
+ "psd",
+ "mvdr_weights_souden",
+ "mvdr_weights_rtf",
+ "rtf_evd",
+ "rtf_power",
+ "apply_beamforming",
+ "fftconvolve",
+ "convolve",
+ "add_noise",
+ "speed",
+ "preemphasis",
+ "deemphasis",
+ "frechet_distance",
+]
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34141f910dfcf3862e376398d8af7f60dd789c3b
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/_alignment.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/_alignment.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5eb6e606901eb0d6ae3ca6ca935b6715cf14fc61
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/_alignment.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/filtering.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/filtering.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3ff6ad915a8918eab299f98e9a45a63b2beb454
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/filtering.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/functional.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/functional.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..875bb4734f86c4db1a534e24ce77f28b9da0ee1c
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/functional/__pycache__/functional.cpython-311.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71e5719c3daaa09433b5ece2431df353ef399f7678bc6bee1f1ebff9b16f9c13
+size 115834
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/functional/_alignment.py b/.venv/lib/python3.11/site-packages/torchaudio/functional/_alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..70d1e995e41b8e6817b2dcc25f6a2d0ec3d83752
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/functional/_alignment.py
@@ -0,0 +1,128 @@
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import torch
+from torch import Tensor
+from torchaudio._extension import fail_if_no_align
+
+__all__ = []
+
+
+@fail_if_no_align
+def forced_align(
+ log_probs: Tensor,
+ targets: Tensor,
+ input_lengths: Optional[Tensor] = None,
+ target_lengths: Optional[Tensor] = None,
+ blank: int = 0,
+) -> Tuple[Tensor, Tensor]:
+ r"""Align a CTC label sequence to an emission.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ Args:
+ log_probs (Tensor): log probability of CTC emission output.
+ Tensor of shape `(B, T, C)`, where `B` is the batch size, `T` is the input length,
+ and `C` is the number of characters in the alphabet, including blank.
+ targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
+ where `L` is the target length.
+ input_lengths (Tensor or None, optional):
+ Lengths of the inputs (each must be <= `T`). 1-D Tensor of shape `(B,)`.
+ target_lengths (Tensor or None, optional):
+ Lengths of the targets. 1-D Tensor of shape `(B,)`.
+ blank (int, optional): The index of the blank symbol in CTC emission. (Default: 0)
+
+ Returns:
+ Tuple(Tensor, Tensor):
+ Tensor: Label for each time step in the alignment path computed using forced alignment.
+
+ Tensor: Log probability scores of the labels for each time step.
+
+ Note:
+ The sequence length of `log_probs` must satisfy:
+
+ .. math::
+ L_{\text{log\_probs}} \ge L_{\text{label}} + N_{\text{repeat}}
+
+ where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens.
+ For example, in the string `"aabbc"`, the number of repeats is `2`.
+
+ Note:
+ The current version only supports ``batch_size==1``.
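+
+ Example (an illustrative sketch; requires torchaudio built with the alignment extension):
+ >>> log_probs = torch.randn(1, 100, 32).log_softmax(-1)  # (B, T, C)
+ >>> targets = torch.randint(1, 32, (1, 10))  # (B, L); must not contain the blank index
+ >>> paths, scores = forced_align(log_probs, targets)  # each of shape (B, T)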
+ """
+ if blank in targets:
+ raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
+ if torch.max(targets) >= log_probs.shape[-1]:
+ raise ValueError("targets values must be less than the CTC dimension")
+
+ if input_lengths is None:
+ batch_size, length = log_probs.size(0), log_probs.size(1)
+ input_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=log_probs.device)
+ if target_lengths is None:
+ batch_size, length = targets.size(0), targets.size(1)
+ target_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=targets.device)
+
+ # For TorchScript compatibility
+ assert input_lengths is not None
+ assert target_lengths is not None
+
+ paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
+ return paths, scores
+
+
+@dataclass
+class TokenSpan:
+ """TokenSpan()
+ Token with time stamps and score. Returned by :py:func:`merge_tokens`.
+ """
+
+ token: int
+ """The token"""
+ start: int
+ """The start time (inclusive) in emission time axis."""
+ end: int
+ """The end time (exclusive) in emission time axis."""
+ score: float
+ """The score of the this token."""
+
+ def __len__(self) -> int:
+ """Returns the time span"""
+ return self.end - self.start
+
+
+def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSpan]:
+ """Removes repeated tokens and blank tokens from the given CTC token sequence.
+
+ Args:
+ tokens (Tensor): Alignment tokens (unbatched) returned from :py:func:`forced_align`.
+ Shape: `(time, )`.
+ scores (Tensor): Alignment scores (unbatched) returned from :py:func:`forced_align`.
+ Shape: `(time, )`. When computing the token-level score, the given score is averaged
+ across the corresponding time span.
+
+ Returns:
+ list of TokenSpan
+
+ Example:
+ >>> aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths)
+ >>> token_spans = merge_tokens(aligned_tokens[0], scores[0])
+ """
+ if tokens.ndim != 1 or scores.ndim != 1:
+ raise ValueError("`tokens` and `scores` must be 1D Tensor.")
+ if len(tokens) != len(scores):
+ raise ValueError("`tokens` and `scores` must be the same length.")
+
+ diff = torch.diff(
+ tokens, prepend=torch.tensor([-1], device=tokens.device), append=torch.tensor([-1], device=tokens.device)
+ )
+ changes_wo_blank = torch.nonzero((diff != 0)).squeeze().tolist()
+ tokens = tokens.tolist()
+ spans = [
+ TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
+ for start, end in zip(changes_wo_blank[:-1], changes_wo_blank[1:])
+ if (token := tokens[start]) != blank
+ ]
+ return spans
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/functional/filtering.py b/.venv/lib/python3.11/site-packages/torchaudio/functional/filtering.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4140447efaf13edac5dce72a71d52a52686d286
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/functional/filtering.py
@@ -0,0 +1,1669 @@
+import math
+import warnings
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE
+
+
+def _dB2Linear(x: float) -> float:
+ return math.exp(x * math.log(10) / 20.0)
+
+
+def _generate_wave_table(
+ wave_type: str,
+ data_type: str,
+ table_size: int,
+ min: float,
+ max: float,
+ phase: float,
+ device: torch.device,
+) -> Tensor:
+ r"""A helper function for phaser. Generates a table with given parameters.
+
+ Args:
+ wave_type (str): ``SINE`` or ``TRIANGLE``
+ data_type (str): desired data type (``INT`` or ``FLOAT``)
+ table_size (int): desired table size
+ min (float): desired min value
+ max (float): desired max value
+ phase (float): desired phase
+ device (torch.device): Torch device on which table must be generated
+ Returns:
+ Tensor: A 1D tensor with wave table values
+ """
+
+ phase_offset = int(phase / math.pi / 2 * table_size + 0.5)
+
+ t = torch.arange(table_size, device=device, dtype=torch.int32)
+
+ point = (t + phase_offset) % table_size
+
+ d = torch.zeros_like(point, device=device, dtype=torch.float64)
+
+ if wave_type == "SINE":
+ d = (torch.sin(point.to(torch.float64) / table_size * 2 * math.pi) + 1) / 2
+ elif wave_type == "TRIANGLE":
+ d = point.to(torch.float64) * 2 / table_size
+ value = torch.div(4 * point, table_size, rounding_mode="floor")
+ d[value == 0] = d[value == 0] + 0.5
+ d[value == 1] = 1.5 - d[value == 1]
+ d[value == 2] = 1.5 - d[value == 2]
+ d[value == 3] = d[value == 3] - 1.5
+
+ d = d * (max - min) + min
+
+ if data_type == "INT":
+ mask = d < 0
+ d[mask] = d[mask] - 0.5
+ d[~mask] = d[~mask] + 0.5
+ d = d.to(torch.int32)
+ elif data_type == "FLOAT":
+ d = d.to(torch.float32)
+
+ return d
+
+
+def allpass_biquad(waveform: Tensor, sample_rate: int, central_freq: float, Q: float = 0.707) -> Tensor:
+ r"""Design two-pole all-pass filter. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform(torch.Tensor): audio waveform of dimension of `(..., time)`
+ sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+ central_freq (float or torch.Tensor): central frequency (in Hz)
+ Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
+ """
+ dtype = waveform.dtype
+ device = waveform.device
+ central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device)
+ Q = torch.as_tensor(Q, dtype=dtype, device=device)
+
+ w0 = 2 * math.pi * central_freq / sample_rate
+
+ alpha = torch.sin(w0) / 2 / Q
+
+ b0 = 1 - alpha
+ b1 = -2 * torch.cos(w0)
+ b2 = 1 + alpha
+ a0 = 1 + alpha
+ a1 = -2 * torch.cos(w0)
+ a2 = 1 - alpha
+ return biquad(waveform, b0, b1, b2, a0, a1, a2)
+
+
+def band_biquad(
+ waveform: Tensor,
+ sample_rate: int,
+ central_freq: float,
+ Q: float = 0.707,
+ noise: bool = False,
+) -> Tensor:
+ r"""Design two-pole band filter. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+ central_freq (float or torch.Tensor): central frequency (in Hz)
+ Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``).
+ noise (bool, optional): If ``True``, uses the alternate mode for un-pitched audio (e.g. percussion).
+ If ``False``, uses mode oriented to pitched audio, i.e. voice, singing,
+ or instrumental music (Default: ``False``).
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
+ """
+ dtype = waveform.dtype
+ device = waveform.device
+ central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device)
+ Q = torch.as_tensor(Q, dtype=dtype, device=device)
+
+ w0 = 2 * math.pi * central_freq / sample_rate
+ bw_Hz = central_freq / Q
+
+ a0 = 1.0
+ a2 = torch.exp(-2 * math.pi * bw_Hz / sample_rate)
+ a1 = -4 * a2 / (1 + a2) * torch.cos(w0)
+
+ b0 = torch.sqrt(1 - a1 * a1 / (4 * a2)) * (1 - a2)
+
+ if noise:
+ mult = torch.sqrt(((1 + a2) * (1 + a2) - a1 * a1) * (1 - a2) / (1 + a2)) / b0
+ b0 = mult * b0
+
+ b1 = 0.0
+ b2 = 0.0
+
+ return biquad(waveform, b0, b1, b2, a0, a1, a2)
+
+
+def bandpass_biquad(
+ waveform: Tensor,
+ sample_rate: int,
+ central_freq: float,
+ Q: float = 0.707,
+ const_skirt_gain: bool = False,
+) -> Tensor:
+ r"""Design two-pole band-pass filter. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+ central_freq (float or torch.Tensor): central frequency (in Hz)
+ Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
+ const_skirt_gain (bool, optional) : If ``True``, uses a constant skirt gain (peak gain = Q).
+ If ``False``, uses a constant 0dB peak gain. (Default: ``False``)
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
+ """
+ dtype = waveform.dtype
+ device = waveform.device
+ central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device)
+ Q = torch.as_tensor(Q, dtype=dtype, device=device)
+
+ w0 = 2 * math.pi * central_freq / sample_rate
+ alpha = torch.sin(w0) / 2 / Q
+
+ temp = torch.sin(w0) / 2 if const_skirt_gain else alpha
+ b0 = temp
+ b1 = 0.0
+ b2 = -temp
+ a0 = 1 + alpha
+ a1 = -2 * torch.cos(w0)
+ a2 = 1 - alpha
+ return biquad(waveform, b0, b1, b2, a0, a1, a2)
+
+
+def bandreject_biquad(waveform: Tensor, sample_rate: int, central_freq: float, Q: float = 0.707) -> Tensor:
+ r"""Design two-pole band-reject filter. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+ central_freq (float or torch.Tensor): central frequency (in Hz)
+ Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
+ """
+ dtype = waveform.dtype
+ device = waveform.device
+ central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device)
+ Q = torch.as_tensor(Q, dtype=dtype, device=device)
+
+ w0 = 2 * math.pi * central_freq / sample_rate
+ alpha = torch.sin(w0) / 2 / Q
+
+ b0 = 1.0
+ b1 = -2 * torch.cos(w0)
+ b2 = 1.0
+ a0 = 1 + alpha
+ a1 = -2 * torch.cos(w0)
+ a2 = 1 - alpha
+ return biquad(waveform, b0, b1, b2, a0, a1, a2)
+
+
+def bass_biquad(
+ waveform: Tensor,
+ sample_rate: int,
+ gain: float,
+ central_freq: float = 100,
+ Q: float = 0.707,
+) -> Tensor:
+ r"""Design a bass tone-control effect. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+ gain (float or torch.Tensor): desired gain at the boost (or attenuation) in dB.
+ central_freq (float or torch.Tensor, optional): central frequency (in Hz). (Default: ``100``)
+ Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``).
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
+ """
+ dtype = waveform.dtype
+ device = waveform.device
+ central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device)
+ Q = torch.as_tensor(Q, dtype=dtype, device=device)
+ gain = torch.as_tensor(gain, dtype=dtype, device=device)
+
+ w0 = 2 * math.pi * central_freq / sample_rate
+ alpha = torch.sin(w0) / 2 / Q
+ A = torch.exp(gain / 40 * math.log(10))
+
+ temp1 = 2 * torch.sqrt(A) * alpha
+ temp2 = (A - 1) * torch.cos(w0)
+ temp3 = (A + 1) * torch.cos(w0)
+
+ b0 = A * ((A + 1) - temp2 + temp1)
+ b1 = 2 * A * ((A - 1) - temp3)
+ b2 = A * ((A + 1) - temp2 - temp1)
+ a0 = (A + 1) + temp2 + temp1
+ a1 = -2 * ((A - 1) + temp3)
+ a2 = (A + 1) + temp2 - temp1
+
+ return biquad(waveform, b0 / a0, b1 / a0, b2 / a0, a0 / a0, a1 / a0, a2 / a0)
+
+
+def biquad(waveform: Tensor, b0: float, b1: float, b2: float, a0: float, a1: float, a2: float) -> Tensor:
+ r"""Perform a biquad filter of input tensor. Initial conditions set to 0.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ b0 (float or torch.Tensor): numerator coefficient of current input, x[n]
+ b1 (float or torch.Tensor): numerator coefficient of input one time step ago x[n-1]
+ b2 (float or torch.Tensor): numerator coefficient of input two time steps ago x[n-2]
+ a0 (float or torch.Tensor): denominator coefficient of current output y[n], typically 1
+ a1 (float or torch.Tensor): denominator coefficient of output one time step ago y[n-1]
+ a2 (float or torch.Tensor): denominator coefficient of output two time steps ago y[n-2]
+
+ Returns:
+ Tensor: Waveform with dimension of `(..., time)`
+
+ Reference:
+ - https://en.wikipedia.org/wiki/Digital_biquad_filter
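+
+ Example (illustrative; with ``b0 = a0 = 1`` and the remaining coefficients zero, the filter is an identity):
+ >>> x = torch.rand(2, 100) * 2 - 1
+ >>> y = biquad(x, b0=1.0, b1=0.0, b2=0.0, a0=1.0, a1=0.0, a2=0.0)
+ >>> torch.allclose(x, y)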
+ """
+
+ device = waveform.device
+ dtype = waveform.dtype
+
+ b0 = torch.as_tensor(b0, dtype=dtype, device=device).view(1)
+ b1 = torch.as_tensor(b1, dtype=dtype, device=device).view(1)
+ b2 = torch.as_tensor(b2, dtype=dtype, device=device).view(1)
+ a0 = torch.as_tensor(a0, dtype=dtype, device=device).view(1)
+ a1 = torch.as_tensor(a1, dtype=dtype, device=device).view(1)
+ a2 = torch.as_tensor(a2, dtype=dtype, device=device).view(1)
+
+ output_waveform = lfilter(
+ waveform,
+ torch.cat([a0, a1, a2]),
+ torch.cat([b0, b1, b2]),
+ )
+ return output_waveform
+
+
+def contrast(waveform: Tensor, enhancement_amount: float = 75.0) -> Tensor:
+ r"""Apply contrast effect. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Comparable with compression, this effect modifies an audio signal to make it sound louder.
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ enhancement_amount (float, optional): controls the amount of the enhancement.
+ Allowed range of values for ``enhancement_amount``: 0 to 100.
+ Note that ``enhancement_amount = 0`` still gives a significant contrast enhancement.
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ """
+
+ if not 0 <= enhancement_amount <= 100:
+ raise ValueError("Allowed range of values for enhancement_amount : 0-100")
+
+ contrast = enhancement_amount / 750.0
+
+ temp1 = waveform * (math.pi / 2)
+ temp2 = contrast * torch.sin(temp1 * 4)
+ output_waveform = torch.sin(temp1 + temp2)
+
+ return output_waveform
+
+
+def dcshift(waveform: Tensor, shift: float, limiter_gain: Optional[float] = None) -> Tensor:
+ r"""Apply a DC shift to the audio. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ This can be useful to remove a DC offset
+ (caused perhaps by a hardware problem in the recording chain) from the audio
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ shift (float): indicates the amount to shift the audio.
+ Allowed range of values for ``shift``: -2.0 to +2.0.
+ limiter_gain (float or None, optional): used only on peaks to prevent clipping.
+ It should have a value much less than 1 (e.g. 0.05 or 0.02).
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ """
+ output_waveform = waveform
+ limiter_threshold = 0.0
+
+ if limiter_gain is not None:
+ limiter_threshold = 1.0 - (abs(shift) - limiter_gain)
+
+ # Note:
+ # the following index-based update breaks auto-grad support
+ if limiter_gain is not None and shift > 0:
+ mask = waveform > limiter_threshold
+ temp = (waveform[mask] - limiter_threshold) * limiter_gain / (1 - limiter_threshold)
+ output_waveform[mask] = (temp + limiter_threshold + shift).clamp(max=limiter_threshold)
+ output_waveform[~mask] = (waveform[~mask] + shift).clamp(min=-1, max=1)
+ elif limiter_gain is not None and shift < 0:
+ mask = waveform < -limiter_threshold
+ temp = (waveform[mask] + limiter_threshold) * limiter_gain / (1 - limiter_threshold)
+ output_waveform[mask] = (temp - limiter_threshold + shift).clamp(min=-limiter_threshold)
+ output_waveform[~mask] = (waveform[~mask] + shift).clamp(min=-1, max=1)
+ else:
+ output_waveform = (waveform + shift).clamp(min=-1, max=1)
+
+ return output_waveform
+
+
+def deemph_biquad(waveform: Tensor, sample_rate: int) -> Tensor:
+ r"""Apply ISO 908 CD de-emphasis (shelving) IIR filter. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ sample_rate (int): sampling rate of the waveform. Allowed sample rates are ``44100`` and ``48000``.
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
+ """
+
+ if sample_rate == 44100:
+ central_freq = 5283
+ width_slope = 0.4845
+ gain = -9.477
+ elif sample_rate == 48000:
+ central_freq = 5356
+ width_slope = 0.479
+ gain = -9.62
+ else:
+ raise ValueError("Sample rate must be 44100 (audio-CD) or 48000 (DAT)")
+
+ w0 = 2 * math.pi * central_freq / sample_rate
+ A = math.exp(gain / 40.0 * math.log(10))
+ alpha = math.sin(w0) / 2 * math.sqrt((A + 1 / A) * (1 / width_slope - 1) + 2)
+
+ temp1 = 2 * math.sqrt(A) * alpha
+ temp2 = (A - 1) * math.cos(w0)
+ temp3 = (A + 1) * math.cos(w0)
+
+ b0 = A * ((A + 1) + temp2 + temp1)
+ b1 = -2 * A * ((A - 1) + temp3)
+ b2 = A * ((A + 1) + temp2 - temp1)
+ a0 = (A + 1) - temp2 + temp1
+ a1 = 2 * ((A - 1) - temp3)
+ a2 = (A + 1) - temp2 - temp1
+
+ return biquad(waveform, b0, b1, b2, a0, a1, a2)
+
+
+def _add_noise_shaping(dithered_waveform: Tensor, waveform: Tensor) -> Tensor:
+ r"""Noise shaping is calculated by error:
+ error[n] = dithered[n] - original[n]
+ noise_shaped_waveform[n] = dithered[n] + error[n-1]
+ """
+ wf_shape = waveform.size()
+ waveform = waveform.reshape(-1, wf_shape[-1])
+
+ dithered_shape = dithered_waveform.size()
+ dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1])
+
+ error = dithered_waveform - waveform
+
+ # add error[n-1] to dithered_waveform[n], so offset the error by 1 index
+ zeros = torch.zeros(1, dtype=error.dtype, device=error.device)
+ for index in range(error.size()[0]):
+ err = error[index]
+ error_offset = torch.cat((zeros, err))
+ error[index] = error_offset[: waveform.size()[1]]
+
+ noise_shaped = dithered_waveform + error
+ return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:])
+
+
+def _apply_probability_distribution(waveform: Tensor, density_function: str = "TPDF") -> Tensor:
+ r"""Apply a probability distribution function on a waveform.
+
+ Triangular probability density function (TPDF) dither noise has a
+ triangular distribution; values in the center of the range have a higher
+ probability of occurring.
+
+ Rectangular probability density function (RPDF) dither noise has a
+ uniform distribution; any value in the specified range has the same
+ probability of occurring.
+
+ Gaussian probability density function (GPDF) has a normal distribution.
+ The relationship of probabilities of results follows a bell-shaped,
+ or Gaussian curve, typical of dither generated by analog sources.
+ Args:
+ waveform (Tensor): Tensor of audio of dimension (..., time)
+ density_function (str, optional): The density function of a
+ continuous random variable (Default: ``"TPDF"``)
+ Options: Triangular Probability Density Function - `TPDF`
+ Rectangular Probability Density Function - `RPDF`
+ Gaussian Probability Density Function - `GPDF`
+ Returns:
+ Tensor: waveform dithered with TPDF
+ """
+
+ # pack batch
+ shape = waveform.size()
+ waveform = waveform.reshape(-1, shape[-1])
+
+ channel_size = waveform.size()[0] - 1
+ time_size = waveform.size()[-1] - 1
+
+ random_channel = int(torch.randint(channel_size, [1]).item()) if channel_size > 0 else 0
+ random_time = int(torch.randint(time_size, [1]).item()) if time_size > 0 else 0
+
+ number_of_bits = 16
+ up_scaling = 2 ** (number_of_bits - 1) - 2
+ signal_scaled = waveform * up_scaling
+ down_scaling = 2 ** (number_of_bits - 1)
+
+ signal_scaled_dis = waveform
+ if density_function == "RPDF":
+ RPDF = waveform[random_channel][random_time] - 0.5
+
+ signal_scaled_dis = signal_scaled + RPDF
+ elif density_function == "GPDF":
+ # TODO Replace by distribution code once
+ # https://github.com/pytorch/pytorch/issues/29843 is resolved
+ # gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample()
+
+ num_rand_variables = 6
+
+ gaussian = waveform[random_channel][random_time]
+ for ws in num_rand_variables * [time_size]:
+ rand_chan = int(torch.randint(channel_size, [1]).item())
+ gaussian += waveform[rand_chan][int(torch.randint(ws, [1]).item())]
+
+ signal_scaled_dis = signal_scaled + gaussian
+ else:
+ # dtype needed for https://github.com/pytorch/pytorch/issues/32358
+ TPDF = torch.bartlett_window(time_size + 1, dtype=signal_scaled.dtype, device=signal_scaled.device)
+ TPDF = TPDF.repeat((channel_size + 1), 1)
+ signal_scaled_dis = signal_scaled + TPDF
+
+ quantised_signal_scaled = torch.round(signal_scaled_dis)
+ quantised_signal = quantised_signal_scaled / down_scaling
+
+ # unpack batch
+ return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])
+
+
+def dither(waveform: Tensor, density_function: str = "TPDF", noise_shaping: bool = False) -> Tensor:
+ r"""Apply dither
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ Dither increases the perceived dynamic range of audio stored at a
+ particular bit-depth by eliminating nonlinear truncation distortion
+ (i.e. adding minimally perceived noise to mask distortion caused by quantization).
+
+ Args:
+ waveform (Tensor): Tensor of audio of dimension (..., time)
+ density_function (str, optional):
+ The density function of a continuous random variable. One of
+ ``"TPDF"`` (Triangular Probability Density Function),
+ ``"RPDF"`` (Rectangular Probability Density Function) or
+ ``"GPDF"`` (Gaussian Probability Density Function) (Default: ``"TPDF"``).
+ noise_shaping (bool, optional): a filtering process that shapes the spectral
+ energy of quantisation error (Default: ``False``)
+
+ Returns:
+ Tensor: waveform dithered
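+
+ Example (an illustrative sketch; re-quantizes to 16-bit levels with triangular dither noise):
+ >>> x = torch.rand(1, 1000) * 2 - 1
+ >>> dithered = dither(x, density_function="TPDF", noise_shaping=True)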
+ """
+ dithered = _apply_probability_distribution(waveform, density_function=density_function)
+
+ if noise_shaping:
+ return _add_noise_shaping(dithered, waveform)
+ else:
+ return dithered
+
+
+def equalizer_biquad(
+ waveform: Tensor,
+ sample_rate: int,
+ center_freq: float,
+ gain: float,
+ Q: float = 0.707,
+) -> Tensor:
+ r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+ center_freq (float): filter's central frequency
+ gain (float or torch.Tensor): desired gain at the boost (or attenuation) in dB
+ Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+ """
+ dtype = waveform.dtype
+ device = waveform.device
+ center_freq = torch.as_tensor(center_freq, dtype=dtype, device=device)
+ Q = torch.as_tensor(Q, dtype=dtype, device=device)
+ gain = torch.as_tensor(gain, dtype=dtype, device=device)
+
+ w0 = 2 * math.pi * center_freq / sample_rate
+ A = torch.exp(gain / 40.0 * math.log(10))
+ alpha = torch.sin(w0) / 2 / Q
+
+ b0 = 1 + alpha * A
+ b1 = -2 * torch.cos(w0)
+ b2 = 1 - alpha * A
+ a0 = 1 + alpha / A
+ a1 = -2 * torch.cos(w0)
+ a2 = 1 - alpha / A
+ return biquad(waveform, b0, b1, b2, a0, a1, a2)
+
+
+def filtfilt(
+ waveform: Tensor,
+ a_coeffs: Tensor,
+ b_coeffs: Tensor,
+ clamp: bool = True,
+) -> Tensor:
+ r"""Apply an IIR filter forward and backward to a waveform.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Inspired by https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.filtfilt.html
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
+ a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either
+ 1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
+ Lower delay coefficients are first, e.g. ``[a0, a1, a2, ...]``.
+ Must be same size as b_coeffs (pad with 0's as necessary).
+ b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either
+ 1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
+ Lower delay coefficients are first, e.g. ``[b0, b1, b2, ...]``.
+ Must be same size as a_coeffs (pad with 0's as necessary).
+ clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
+
+ Returns:
+ Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
+ are 2D Tensors, or `(..., time)` otherwise.
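+
+ Example (an illustrative sketch; a first-order IIR smoother applied forward and backward for zero phase):
+ >>> x = torch.rand(1, 1000) * 2 - 1
+ >>> a_coeffs = torch.tensor([1.0, -0.5])
+ >>> b_coeffs = torch.tensor([0.5, 0.0])
+ >>> y = filtfilt(x, a_coeffs, b_coeffs)  # same shape as ``x``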
+ """
+ forward_filtered = lfilter(waveform, a_coeffs, b_coeffs, clamp=False, batching=True)
+ backward_filtered = lfilter(
+ forward_filtered.flip(-1),
+ a_coeffs,
+ b_coeffs,
+ clamp=clamp,
+ batching=True,
+ ).flip(-1)
+ return backward_filtered
+
+
+def flanger(
+ waveform: Tensor,
+ sample_rate: int,
+ delay: float = 0.0,
+ depth: float = 2.0,
+ regen: float = 0.0,
+ width: float = 71.0,
+ speed: float = 0.5,
+ phase: float = 25.0,
+ modulation: str = "sinusoidal",
+ interpolation: str = "linear",
+) -> Tensor:
+ r"""Apply a flanger effect to the audio. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., channel, time)` .
+ Max 4 channels allowed
+ sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+ delay (float, optional): desired delay in milliseconds (ms).
+ Allowed range of values is 0 to 30.
+ depth (float, optional): desired delay depth in milliseconds (ms).
+ Allowed range of values is 0 to 10.
+ regen (float, optional): desired regen (feedback gain) in dB.
+ Allowed range of values is -95 to 95.
+ width (float, optional): desired width (delay gain) in dB.
+ Allowed range of values is 0 to 100.
+ speed (float, optional): modulation speed in Hz.
+ Allowed range of values is 0.1 to 10.
+ phase (float, optional): percentage phase-shift for multi-channel.
+ Allowed range of values is 0 to 100.
+ modulation (str, optional): Use either "sinusoidal" or "triangular" modulation. (Default: ``sinusoidal``)
+ interpolation (str, optional): Use either "linear" or "quadratic" for delay-line interpolation.
+ (Default: ``linear``)
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., channel, time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+
+ - Scott Lehman, `Effects Explained`_,
+
+ .. _Effects Explained:
+ https://web.archive.org/web/20051125072557/http://www.harmony-central.com/Effects/effects-explained.html
+ """
+
+ if modulation not in ("sinusoidal", "triangular"):
+ raise ValueError('Only "sinusoidal" or "triangular" modulation allowed')
+
+ if interpolation not in ("linear", "quadratic"):
+ raise ValueError('Only "linear" or "quadratic" interpolation allowed')
+
+ actual_shape = waveform.shape
+ device, dtype = waveform.device, waveform.dtype
+
+ if actual_shape[-2] > 4:
+ raise ValueError("Max 4 channels allowed")
+
+ # convert to 3D (batch, channels, time)
+ waveform = waveform.view(-1, actual_shape[-2], actual_shape[-1])
+
+ # Scaling
+ feedback_gain = regen / 100
+ delay_gain = width / 100
+ channel_phase = phase / 100
+ delay_min = delay / 1000
+ delay_depth = depth / 1000
+
+ n_channels = waveform.shape[-2]
+
+ if modulation == "sinusoidal":
+ wave_type = "SINE"
+ else:
+ wave_type = "TRIANGLE"
+
+ # Balance output:
+ in_gain = 1.0 / (1 + delay_gain)
+ delay_gain = delay_gain / (1 + delay_gain)
+
+ # Balance feedback loop:
+ delay_gain = delay_gain * (1 - abs(feedback_gain))
+
+ delay_buf_length = int((delay_min + delay_depth) * sample_rate + 0.5)
+ delay_buf_length = delay_buf_length + 2
+
+ delay_bufs = torch.zeros(waveform.shape[0], n_channels, delay_buf_length, dtype=dtype, device=device)
+ delay_last = torch.zeros(waveform.shape[0], n_channels, dtype=dtype, device=device)
+
+ lfo_length = int(sample_rate / speed)
+
+ table_min = math.floor(delay_min * sample_rate + 0.5)
+ table_max = delay_buf_length - 2.0
+
+ lfo = _generate_wave_table(
+ wave_type=wave_type,
+ data_type="FLOAT",
+ table_size=lfo_length,
+ min=float(table_min),
+ max=float(table_max),
+ phase=3 * math.pi / 2,
+ device=device,
+ )
+
+ output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device)
+
+ delay_buf_pos = 0
+ lfo_pos = 0
+ channel_idxs = torch.arange(0, n_channels, device=device)
+
+ for i in range(waveform.shape[-1]):
+
+ delay_buf_pos = (delay_buf_pos + delay_buf_length - 1) % delay_buf_length
+
+ cur_channel_phase = (channel_idxs * lfo_length * channel_phase + 0.5).to(torch.int64)
+ delay_tensor = lfo[(lfo_pos + cur_channel_phase) % lfo_length]
+ frac_delay = torch.frac(delay_tensor)
+ delay_tensor = torch.floor(delay_tensor)
+
+ int_delay = delay_tensor.to(torch.int64)
+
+ temp = waveform[:, :, i]
+
+ delay_bufs[:, :, delay_buf_pos] = temp + delay_last * feedback_gain
+
+ delayed_0 = delay_bufs[:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length]
+
+ int_delay = int_delay + 1
+
+ delayed_1 = delay_bufs[:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length]
+
+ int_delay = int_delay + 1
+
+ if interpolation == "linear":
+ delayed = delayed_0 + (delayed_1 - delayed_0) * frac_delay
+ else:
+ delayed_2 = delay_bufs[:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length]
+
+ int_delay = int_delay + 1
+
+ delayed_2 = delayed_2 - delayed_0
+ delayed_1 = delayed_1 - delayed_0
+ a = delayed_2 * 0.5 - delayed_1
+ b = delayed_1 * 2 - delayed_2 * 0.5
+
+ delayed = delayed_0 + (a * frac_delay + b) * frac_delay
+
+ delay_last = delayed
+ output_waveform[:, :, i] = waveform[:, :, i] * in_gain + delayed * delay_gain
+
+ lfo_pos = (lfo_pos + 1) % lfo_length
+
+ return output_waveform.clamp(min=-1, max=1).view(actual_shape)
+
+
+def gain(waveform: Tensor, gain_db: float = 1.0) -> Tensor:
+ r"""Apply amplification or attenuation to the whole waveform.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): Tensor of audio of dimension (..., time).
+ gain_db (float, optional): Gain adjustment in decibels (dB) (Default: ``1.0``).
+
+ Returns:
+ Tensor: the whole waveform amplified by gain_db.
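+
+ Example (illustrative; about ``6.02`` dB doubles the amplitude, since ``10 ** (6.02 / 20)`` is roughly ``2``):
+ >>> x = 0.1 * torch.ones(1, 4)
+ >>> gain(x, 6.02)  # each sample becomes roughly 0.2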
+ """
+ if gain_db == 0:
+ return waveform
+
+ ratio = 10 ** (gain_db / 20)
+
+ return waveform * ratio
+
+
+def highpass_biquad(waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707) -> Tensor:
+ r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+ cutoff_freq (float or torch.Tensor): filter cutoff frequency
+ Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
+
+ Returns:
+ Tensor: Waveform dimension of `(..., time)`
+ """
+ dtype = waveform.dtype
+ device = waveform.device
+ cutoff_freq = torch.as_tensor(cutoff_freq, dtype=dtype, device=device)
+ Q = torch.as_tensor(Q, dtype=dtype, device=device)
+
+ w0 = 2 * math.pi * cutoff_freq / sample_rate
+ alpha = torch.sin(w0) / 2.0 / Q
+
+ b0 = (1 + torch.cos(w0)) / 2
+ b1 = -1 - torch.cos(w0)
+ b2 = b0
+ a0 = 1 + alpha
+ a1 = -2 * torch.cos(w0)
+ a2 = 1 - alpha
+ return biquad(waveform, b0, b1, b2, a0, a1, a2)
+
+
+def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor):
+ n_order = a_coeffs_flipped.size(1)
+ a_coeffs_flipped = a_coeffs_flipped.unsqueeze(2)
+ for i_sample, o0 in enumerate(input_signal_windows.permute(2, 0, 1)):
+ windowed_output_signal = padded_output_waveform[:, :, i_sample : i_sample + n_order]
+ o0 -= (windowed_output_signal.transpose(0, 1) @ a_coeffs_flipped)[..., 0].t()
+ padded_output_waveform[:, :, i_sample + n_order - 1] = o0
+
+
+if _IS_TORCHAUDIO_EXT_AVAILABLE:
+ _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
+else:
+ _lfilter_core_cpu_loop = _lfilter_core_generic_loop
+
+
+def _lfilter_core(
+ waveform: Tensor,
+ a_coeffs: Tensor,
+ b_coeffs: Tensor,
+) -> Tensor:
+
+ if a_coeffs.size() != b_coeffs.size():
+ raise ValueError(
+ "Expected coeffs to be the same size."
+ f"Found a_coeffs size: {a_coeffs.size()}, b_coeffs size: {b_coeffs.size()}"
+ )
+ if waveform.ndim != 3:
+ raise ValueError(f"Expected waveform to be 3 dimensional. Found: {waveform.ndim}")
+ if not (waveform.device == a_coeffs.device == b_coeffs.device):
+ raise ValueError(
+ "Expected waveform and coeffs to be on the same device."
+ f"Found: waveform device:{waveform.device}, a_coeffs device: {a_coeffs.device}, "
+ f"b_coeffs device: {b_coeffs.device}"
+ )
+
+ n_batch, n_channel, n_sample = waveform.size()
+ n_order = a_coeffs.size(1)
+ if n_order <= 0:
+ raise ValueError(f"Expected n_order to be positive. Found: {n_order}")
+
+ # Pad the input and create output
+
+ padded_waveform = torch.nn.functional.pad(waveform, [n_order - 1, 0])
+ padded_output_waveform = torch.zeros_like(padded_waveform)
+
+ # Set up the coefficients matrix
+ # Flip coefficients' order
+ a_coeffs_flipped = a_coeffs.flip(1)
+ b_coeffs_flipped = b_coeffs.flip(1)
+
+ # calculate windowed_input_signal in parallel using convolution
+ input_signal_windows = torch.nn.functional.conv1d(padded_waveform, b_coeffs_flipped.unsqueeze(1), groups=n_channel)
+
+ input_signal_windows.div_(a_coeffs[:, :1])
+ a_coeffs_flipped.div_(a_coeffs[:, :1])
+
+ if (
+ input_signal_windows.device == torch.device("cpu")
+ and a_coeffs_flipped.device == torch.device("cpu")
+ and padded_output_waveform.device == torch.device("cpu")
+ ):
+ _lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
+ else:
+ _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
+
+ output = padded_output_waveform[:, :, n_order - 1 :]
+ return output
+
+
+if _IS_TORCHAUDIO_EXT_AVAILABLE:
+ _lfilter = torch.ops.torchaudio._lfilter
+else:
+ _lfilter = _lfilter_core
+
+
+def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
+ r"""Perform an IIR filter by evaluating difference equation, using differentiable implementation
+ developed independently by *Yu et al.* :cite:`ismir_YuF23` and *Forgione et al.* :cite:`forgione2021dynonet`.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Note:
+ To avoid numerical problems, small filter order is preferred.
+ Using double precision could also minimize numerical precision errors.
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
+ a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either
+ 1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
+ Lower delay coefficients are first, e.g. ``[a0, a1, a2, ...]``.
+ Must be same size as b_coeffs (pad with 0's as necessary).
+ b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either
+ 1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
+ Lower delay coefficients are first, e.g. ``[b0, b1, b2, ...]``.
+ Must be same size as a_coeffs (pad with 0's as necessary).
+ clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
+ batching (bool, optional): Effective only when coefficients are 2D. If ``True``, then the waveform should be at
+ least 2D, and the size of the second-to-last axis should equal ``num_filters``.
+ The output can be expressed as ``output[..., i, :] = lfilter(waveform[..., i, :],
+ a_coeffs[i], b_coeffs[i], clamp=clamp, batching=False)``. (Default: ``True``)
+
+ Returns:
+ Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
+ are 2D Tensors, or `(..., time)` otherwise.
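+
+ Example (an illustrative sketch; a single-pole smoothing filter):
+ >>> x = torch.rand(2, 1000) * 2 - 1
+ >>> a_coeffs = torch.tensor([1.0, -0.95])
+ >>> b_coeffs = torch.tensor([0.05, 0.0])
+ >>> y = lfilter(x, a_coeffs, b_coeffs)  # shape (2, 1000)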
+ """
+ if a_coeffs.size() != b_coeffs.size():
+ raise ValueError(
+ "Expected coeffs to be the same size. "
+ f"Found: a_coeffs size: {a_coeffs.size()}, b_coeffs size: {b_coeffs.size()}"
+ )
+ if a_coeffs.ndim > 2:
+ raise ValueError(f"Expected coeffs to be 1D or 2D. Found: {a_coeffs.ndim}")
+
+ if a_coeffs.ndim > 1:
+ if batching:
+ if waveform.ndim <= 0:
+ raise ValueError("Expected waveform to have a positive number of dimensions." f"Found: {waveform.ndim}")
+ if waveform.shape[-2] != a_coeffs.shape[0]:
+ raise ValueError(
+ "Expected number of batches in waveform and coeffs to be the same."
+ f"Found: coeffs batches: {a_coeffs.shape[0]}, waveform batches: {waveform.shape[-2]}"
+ )
+ else:
+ waveform = torch.stack([waveform] * a_coeffs.shape[0], -2)
+ else:
+ a_coeffs = a_coeffs.unsqueeze(0)
+ b_coeffs = b_coeffs.unsqueeze(0)
+
+ # pack batch
+ shape = waveform.size()
+ waveform = waveform.reshape(-1, a_coeffs.shape[0], shape[-1])
+ output = _lfilter(waveform, a_coeffs, b_coeffs)
+
+ if clamp:
+ output = torch.clamp(output, min=-1.0, max=1.0)
+
+ # unpack batch
+ output = output.reshape(shape[:-1] + output.shape[-1:])
+
+ return output
+
+
+def lowpass_biquad(waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707) -> Tensor:
+ r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
+ sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+ cutoff_freq (float or torch.Tensor): filter cutoff frequency
+ Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
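+
+ Example (an illustrative sketch; attenuate content above 1 kHz):
+ >>> x = torch.rand(1, 44100) * 2 - 1
+ >>> filtered = lowpass_biquad(x, sample_rate=44100, cutoff_freq=1000.0)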
+ """
+ dtype = waveform.dtype
+ device = waveform.device
+ cutoff_freq = torch.as_tensor(cutoff_freq, dtype=dtype, device=device)
+ Q = torch.as_tensor(Q, dtype=dtype, device=device)
+
+ w0 = 2 * math.pi * cutoff_freq / sample_rate
+ alpha = torch.sin(w0) / 2 / Q
+
+ b0 = (1 - torch.cos(w0)) / 2
+ b1 = 1 - torch.cos(w0)
+ b2 = b0
+ a0 = 1 + alpha
+ a1 = -2 * torch.cos(w0)
+ a2 = 1 - alpha
+ return biquad(waveform, b0, b1, b2, a0, a1, a2)
+
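+# A brief usage sketch (illustrative; the cutoff value is an arbitrary
+# example): attenuating content above ~3 kHz in a 44.1 kHz signal.
+#
+# >>> waveform = torch.rand(1, 44100) * 2 - 1
+# >>> filtered = lowpass_biquad(waveform, sample_rate=44100, cutoff_freq=3000.0)
+# >>> filtered.shape
+# torch.Size([1, 44100])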
+
+def _overdrive_core_loop_generic(
+ waveform: Tensor, temp: Tensor, last_in: Tensor, last_out: Tensor, output_waveform: Tensor
+):
+ for i in range(waveform.shape[-1]):
+ last_out = temp[:, i] - last_in + 0.995 * last_out
+ last_in = temp[:, i]
+ output_waveform[:, i] = waveform[:, i] * 0.5 + last_out * 0.75
+
+
+if _IS_TORCHAUDIO_EXT_AVAILABLE:
+ _overdrive_core_loop_cpu = torch.ops.torchaudio._overdrive_core_loop
+else:
+ _overdrive_core_loop_cpu = _overdrive_core_loop_generic
+
+
+def overdrive(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor:
+ r"""Apply a overdrive effect to the audio. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+    This effect applies a non-linear distortion to the audio signal.
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+        gain (float, optional): desired gain at the boost (or attenuation) in dB.
+            Allowed range of values is 0 to 100.
+        colour (float, optional): controls the amount of even harmonic content in the over-driven output.
+            Allowed range of values is 0 to 100.
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ """
+ actual_shape = waveform.shape
+ device, dtype = waveform.device, waveform.dtype
+
+ # convert to 2D (..,time)
+ waveform = waveform.view(-1, actual_shape[-1])
+
+ gain = _dB2Linear(gain)
+ colour = colour / 200
+ last_in = torch.zeros(waveform.shape[:-1], dtype=dtype, device=device)
+ last_out = torch.zeros(waveform.shape[:-1], dtype=dtype, device=device)
+
+ temp = waveform * gain + colour
+
+ mask1 = temp < -1
+ temp[mask1] = torch.tensor(-2.0 / 3.0, dtype=dtype, device=device)
+ # Wrapping the constant with Tensor is required for Torchscript
+
+ mask2 = temp > 1
+ temp[mask2] = torch.tensor(2.0 / 3.0, dtype=dtype, device=device)
+
+ mask3 = ~mask1 & ~mask2
+ temp[mask3] = temp[mask3] - (temp[mask3] ** 3) * (1.0 / 3)
+
+ output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device)
+
+    # Use the optimized loop function when the tensor is on a CPU device
+ if device == torch.device("cpu"):
+ _overdrive_core_loop_cpu(waveform, temp, last_in, last_out, output_waveform)
+ else:
+ _overdrive_core_loop_generic(waveform, temp, last_in, last_out, output_waveform)
+
+ return output_waveform.clamp(min=-1, max=1).view(actual_shape)
+
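+# A brief usage sketch (illustrative; the gain and colour values are arbitrary
+# examples within the documented 0-100 range): driving a quiet 440 Hz tone.
+#
+# >>> t = torch.arange(8000) / 8000
+# >>> tone = 0.1 * torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)
+# >>> driven = overdrive(tone, gain=30.0, colour=40.0)
+# >>> driven.shape
+# torch.Size([1, 8000])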
+
+def phaser(
+ waveform: Tensor,
+ sample_rate: int,
+ gain_in: float = 0.4,
+ gain_out: float = 0.74,
+ delay_ms: float = 3.0,
+ decay: float = 0.4,
+ mod_speed: float = 0.5,
+ sinusoidal: bool = True,
+) -> Tensor:
+ r"""Apply a phasing effect to the audio. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+        gain_in (float, optional): desired input gain at the boost (or attenuation) in dB.
+            Allowed range of values is 0 to 1.
+        gain_out (float, optional): desired output gain at the boost (or attenuation) in dB.
+            Allowed range of values is 0 to 1e9.
+        delay_ms (float, optional): desired delay in milliseconds.
+            Allowed range of values is 0 to 5.0.
+        decay (float, optional): desired decay relative to gain-in.
+            Allowed range of values is 0 to 0.99.
+        mod_speed (float, optional): modulation speed in Hz.
+            Allowed range of values is 0.1 to 2.
+ sinusoidal (bool, optional): If ``True``, uses sinusoidal modulation (preferable for multiple instruments)
+ If ``False``, uses triangular modulation (gives single instruments a sharper phasing effect)
+ (Default: ``True``)
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ - Scott Lehman, `Effects Explained`_.
+
+ .. _Effects Explained:
+ https://web.archive.org/web/20051125072557/http://www.harmony-central.com/Effects/effects-explained.html
+ """
+ actual_shape = waveform.shape
+ device, dtype = waveform.device, waveform.dtype
+
+ # convert to 2D (channels,time)
+ waveform = waveform.view(-1, actual_shape[-1])
+
+ delay_buf_len = int((delay_ms * 0.001 * sample_rate) + 0.5)
+ delay_buf = torch.zeros(waveform.shape[0], delay_buf_len, dtype=dtype, device=device)
+
+ mod_buf_len = int(sample_rate / mod_speed + 0.5)
+
+ if sinusoidal:
+ wave_type = "SINE"
+ else:
+ wave_type = "TRIANGLE"
+
+ mod_buf = _generate_wave_table(
+ wave_type=wave_type,
+ data_type="INT",
+ table_size=mod_buf_len,
+ min=1.0,
+ max=float(delay_buf_len),
+ phase=math.pi / 2,
+ device=device,
+ )
+
+ delay_pos = 0
+ mod_pos = 0
+
+ output_waveform_pre_gain_list = []
+ waveform = waveform * gain_in
+ delay_buf = delay_buf * decay
+ waveform_list = [waveform[:, i] for i in range(waveform.size(1))]
+ delay_buf_list = [delay_buf[:, i] for i in range(delay_buf.size(1))]
+ mod_buf_list = [mod_buf[i] for i in range(mod_buf.size(0))]
+
+ for i in range(waveform.shape[-1]):
+ idx = int((delay_pos + mod_buf_list[mod_pos]) % delay_buf_len)
+ mod_pos = (mod_pos + 1) % mod_buf_len
+ delay_pos = (delay_pos + 1) % delay_buf_len
+ temp = (waveform_list[i]) + (delay_buf_list[idx])
+ delay_buf_list[delay_pos] = temp * decay
+ output_waveform_pre_gain_list.append(temp)
+
+ output_waveform = torch.stack(output_waveform_pre_gain_list, dim=1).to(dtype=dtype, device=device)
+ output_waveform.mul_(gain_out)
+
+ return output_waveform.clamp(min=-1, max=1).view(actual_shape)
+
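+# A brief usage sketch (illustrative; the parameter values are arbitrary
+# examples within the documented ranges): sinusoidal modulation is documented
+# above as preferable for material with multiple instruments.
+#
+# >>> waveform = torch.rand(1, 16000) * 2 - 1
+# >>> effected = phaser(waveform, sample_rate=16000, mod_speed=0.8, sinusoidal=True)
+# >>> effected.shape
+# torch.Size([1, 16000])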
+
+def riaa_biquad(waveform: Tensor, sample_rate: int) -> Tensor:
+ r"""Apply RIAA vinyl playback equalization. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz).
+            Allowed sample rates in Hz: ``44100``, ``48000``, ``88200``, ``96000``
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
+ """
+
+ if sample_rate == 44100:
+ zeros = [-0.2014898, 0.9233820]
+ poles = [0.7083149, 0.9924091]
+
+ elif sample_rate == 48000:
+ zeros = [-0.1766069, 0.9321590]
+ poles = [0.7396325, 0.9931330]
+
+ elif sample_rate == 88200:
+ zeros = [-0.1168735, 0.9648312]
+ poles = [0.8590646, 0.9964002]
+
+ elif sample_rate == 96000:
+ zeros = [-0.1141486, 0.9676817]
+ poles = [0.8699137, 0.9966946]
+
+ else:
+ raise ValueError("Sample rate must be 44.1k, 48k, 88.2k, or 96k")
+
+ # polynomial coefficients with roots zeros[0] and zeros[1]
+ b0 = 1.0
+ b1 = -(zeros[0] + zeros[1])
+ b2 = zeros[0] * zeros[1]
+
+ # polynomial coefficients with roots poles[0] and poles[1]
+ a0 = 1.0
+ a1 = -(poles[0] + poles[1])
+ a2 = poles[0] * poles[1]
+
+ # Normalize to 0dB at 1kHz
+ y = 2 * math.pi * 1000 / sample_rate
+ b_re = b0 + b1 * math.cos(-y) + b2 * math.cos(-2 * y)
+ a_re = a0 + a1 * math.cos(-y) + a2 * math.cos(-2 * y)
+ b_im = b1 * math.sin(-y) + b2 * math.sin(-2 * y)
+ a_im = a1 * math.sin(-y) + a2 * math.sin(-2 * y)
+ g = 1 / math.sqrt((b_re**2 + b_im**2) / (a_re**2 + a_im**2))
+
+ b0 *= g
+ b1 *= g
+ b2 *= g
+
+ return biquad(waveform, b0, b1, b2, a0, a1, a2)
+
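+# Since the coefficients above are normalized to 0 dB at 1 kHz, a steady 1 kHz
+# tone should come through with roughly unchanged amplitude; a rough check
+# along those lines (illustrative only):
+#
+# >>> sr = 44100
+# >>> tone = 0.5 * torch.sin(2 * math.pi * 1000.0 * torch.arange(sr) / sr).unsqueeze(0)
+# >>> out = riaa_biquad(tone, sr)
+# >>> out[:, sr // 10 :].abs().max()  # close to 0.5 once the transient settles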
+
+def treble_biquad(
+ waveform: Tensor,
+ sample_rate: int,
+ gain: float,
+ central_freq: float = 3000,
+ Q: float = 0.707,
+) -> Tensor:
+ r"""Design a treble tone-control effect. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): audio waveform of dimension of `(..., time)`
+ sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+ gain (float or torch.Tensor): desired gain at the boost (or attenuation) in dB.
+ central_freq (float or torch.Tensor, optional): central frequency (in Hz). (Default: ``3000``)
+ Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``).
+
+ Returns:
+ Tensor: Waveform of dimension of `(..., time)`
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
+ """
+ dtype = waveform.dtype
+ device = waveform.device
+ central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device)
+ Q = torch.as_tensor(Q, dtype=dtype, device=device)
+ gain = torch.as_tensor(gain, dtype=dtype, device=device)
+
+ w0 = 2 * math.pi * central_freq / sample_rate
+ alpha = torch.sin(w0) / 2 / Q
+ A = torch.exp(gain / 40 * math.log(10))
+
+ temp1 = 2 * torch.sqrt(A) * alpha
+ temp2 = (A - 1) * torch.cos(w0)
+ temp3 = (A + 1) * torch.cos(w0)
+
+ b0 = A * ((A + 1) + temp2 + temp1)
+ b1 = -2 * A * ((A - 1) + temp3)
+ b2 = A * ((A + 1) + temp2 - temp1)
+ a0 = (A + 1) - temp2 + temp1
+ a1 = 2 * ((A - 1) - temp3)
+ a2 = (A + 1) - temp2 - temp1
+
+ return biquad(waveform, b0, b1, b2, a0, a1, a2)
+
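+# A brief usage sketch (illustrative): since A = 10 ** (gain / 40) above,
+# gain=6.0 yields roughly a +6 dB shelf above the central frequency.
+#
+# >>> waveform = torch.rand(1, 44100) * 2 - 1
+# >>> boosted = treble_biquad(waveform, sample_rate=44100, gain=6.0)
+# >>> boosted.shape
+# torch.Size([1, 44100])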
+
+def _measure(
+ measure_len_ws: int,
+ samples: Tensor,
+ spectrum: Tensor,
+ noise_spectrum: Tensor,
+ spectrum_window: Tensor,
+ spectrum_start: int,
+ spectrum_end: int,
+ cepstrum_window: Tensor,
+ cepstrum_start: int,
+ cepstrum_end: int,
+ noise_reduction_amount: float,
+ measure_smooth_time_mult: float,
+ noise_up_time_mult: Tensor,
+ noise_down_time_mult: Tensor,
+ boot_count: int,
+) -> float:
+ device = samples.device
+
+ if spectrum.size(-1) != noise_spectrum.size(-1):
+ raise ValueError(
+ "Expected spectrum size to match noise spectrum size in final dimension."
+ f"Found: spectrum size: {spectrum.size()}, noise_spectrum size: {noise_spectrum.size()}"
+ )
+
+ dft_len_ws = spectrum.size()[-1]
+
+ dftBuf = torch.zeros(dft_len_ws, device=device)
+
+ dftBuf[:measure_len_ws] = samples * spectrum_window[:measure_len_ws]
+
+ # lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
+ _dftBuf = torch.fft.rfft(dftBuf)
+
+ mult: float = boot_count / (1.0 + boot_count) if boot_count >= 0 else measure_smooth_time_mult
+
+ _d = _dftBuf[spectrum_start:spectrum_end].abs()
+ spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult))
+ _d = spectrum[spectrum_start:spectrum_end] ** 2
+
+ _zeros = torch.zeros(spectrum_end - spectrum_start, device=device)
+ _mult = (
+ _zeros
+ if boot_count >= 0
+ else torch.where(
+ _d > noise_spectrum[spectrum_start:spectrum_end],
+ noise_up_time_mult, # if
+ noise_down_time_mult, # else,
+ )
+ )
+
+ noise_spectrum[spectrum_start:spectrum_end].mul_(_mult).add_(_d * (1 - _mult))
+ _d = torch.sqrt(
+ torch.max(
+ _zeros,
+ _d - noise_reduction_amount * noise_spectrum[spectrum_start:spectrum_end],
+ ),
+ )
+
+ _cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1, device=device)
+ _cepstrum_Buf[spectrum_start:spectrum_end] = _d * cepstrum_window
+ _cepstrum_Buf[spectrum_end : dft_len_ws >> 1].zero_()
+
+ # lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
+ _cepstrum_Buf = torch.fft.rfft(_cepstrum_Buf)
+
+ result: float = float(torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2)))
+ result = math.log(result / (cepstrum_end - cepstrum_start)) if result > 0 else -math.inf
+ return max(0, 21 + result)
+
+
+def vad(
+ waveform: Tensor,
+ sample_rate: int,
+ trigger_level: float = 7.0,
+ trigger_time: float = 0.25,
+ search_time: float = 1.0,
+ allowed_gap: float = 0.25,
+ pre_trigger_time: float = 0.0,
+ # Fine-tuning parameters
+ boot_time: float = 0.35,
+ noise_up_time: float = 0.1,
+ noise_down_time: float = 0.01,
+ noise_reduction_amount: float = 1.35,
+ measure_freq: float = 20.0,
+ measure_duration: Optional[float] = None,
+ measure_smooth_time: float = 0.4,
+ hp_filter_freq: float = 50.0,
+ lp_filter_freq: float = 6000.0,
+ hp_lifter_freq: float = 150.0,
+ lp_lifter_freq: float = 2000.0,
+) -> Tensor:
+ r"""Voice Activity Detector. Similar to SoX implementation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
+ The algorithm currently uses a simple cepstral power measurement to detect voice,
+ so may be fooled by other things, especially music.
+
+ The effect can trim only from the front of the audio,
+ so in order to trim from the back, the reverse effect must also be used.
+
+ Args:
+ waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)`
+ Tensor of shape `(channels, time)` is treated as a multi-channel recording
+ of the same event and the resulting output will be trimmed to the earliest
+ voice activity in any channel.
+ sample_rate (int): Sample rate of audio signal.
+ trigger_level (float, optional): The measurement level used to trigger activity detection.
+            This may need to be changed depending on the noise level, signal level,
+ and other characteristics of the input audio. (Default: 7.0)
+ trigger_time (float, optional): The time constant (in seconds)
+ used to help ignore short bursts of sound. (Default: 0.25)
+ search_time (float, optional): The amount of audio (in seconds)
+ to search for quieter/shorter bursts of audio to include prior
+ to the detected trigger point. (Default: 1.0)
+ allowed_gap (float, optional): The allowed gap (in seconds) between
+ quieter/shorter bursts of audio to include prior
+ to the detected trigger point. (Default: 0.25)
+ pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
+ before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
+        boot_time (float, optional): The algorithm (internally) uses adaptive noise
+            estimation/reduction in order to detect the start of the wanted audio.
+            This option sets the time for the initial noise estimate. (Default: 0.35)
+        noise_up_time (float, optional): Time constant used by the adaptive noise estimator
+            for when the noise level is increasing. (Default: 0.1)
+        noise_down_time (float, optional): Time constant used by the adaptive noise estimator
+            for when the noise level is decreasing. (Default: 0.01)
+        noise_reduction_amount (float, optional): Amount of noise reduction to use in
+            the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
+        measure_freq (float, optional): Frequency of the algorithm's
+            processing/measurements. (Default: 20.0)
+        measure_duration (float, optional): Measurement duration.
+            (Default: Twice the measurement period; i.e. with overlap.)
+        measure_smooth_time (float, optional): Time constant used to smooth
+            spectral measurements. (Default: 0.4)
+        hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
+            at the input to the detector algorithm. (Default: 50.0)
+        lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
+            at the input to the detector algorithm. (Default: 6000.0)
+        hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
+            in the detector algorithm. (Default: 150.0)
+        lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
+            in the detector algorithm. (Default: 2000.0)
+
+ Returns:
+ Tensor: Tensor of audio of dimension `(..., time)`.
+
+ Reference:
+ - http://sox.sourceforge.net/sox.html
+ """
+ device = waveform.device
+
+ if waveform.ndim > 2:
+ warnings.warn(
+ "Expected input tensor dimension of 1 for single channel"
+ f" or 2 for multi-channel. Got {waveform.ndim} instead. "
+ "Batch semantics is not supported. "
+ "Please refer to https://github.com/pytorch/audio/issues/1348"
+ " and https://github.com/pytorch/audio/issues/1468."
+ )
+
+ measure_duration: float = 2.0 / measure_freq if measure_duration is None else measure_duration
+
+ measure_len_ws = int(sample_rate * measure_duration + 0.5)
+ measure_len_ns = measure_len_ws
+ # for (dft_len_ws = 16; dft_len_ws < measure_len_ws; dft_len_ws <<= 1);
+ dft_len_ws = 16
+ while dft_len_ws < measure_len_ws:
+ dft_len_ws *= 2
+
+ measure_period_ns = int(sample_rate / measure_freq + 0.5)
+ measures_len = math.ceil(search_time * measure_freq)
+ search_pre_trigger_len_ns = measures_len * measure_period_ns
+ gap_len = int(allowed_gap * measure_freq + 0.5)
+
+ fixed_pre_trigger_len_ns = int(pre_trigger_time * sample_rate + 0.5)
+ samplesLen_ns = fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns
+
+ spectrum_window = torch.zeros(measure_len_ws, device=device)
+ for i in range(measure_len_ws):
+ # sox.h:741 define SOX_SAMPLE_MIN (sox_sample_t)SOX_INT_MIN(32)
+ spectrum_window[i] = 2.0 / math.sqrt(float(measure_len_ws))
+ # lsx_apply_hann(spectrum_window, (int)measure_len_ws);
+ spectrum_window *= torch.hann_window(measure_len_ws, device=device, dtype=torch.float)
+
+ spectrum_start: int = int(hp_filter_freq / sample_rate * dft_len_ws + 0.5)
+ spectrum_start: int = max(spectrum_start, 1)
+ spectrum_end: int = int(lp_filter_freq / sample_rate * dft_len_ws + 0.5)
+ spectrum_end: int = min(spectrum_end, dft_len_ws // 2)
+
+ cepstrum_window = torch.zeros(spectrum_end - spectrum_start, device=device)
+ for i in range(spectrum_end - spectrum_start):
+ cepstrum_window[i] = 2.0 / math.sqrt(float(spectrum_end) - spectrum_start)
+ # lsx_apply_hann(cepstrum_window,(int)(spectrum_end - spectrum_start));
+ cepstrum_window *= torch.hann_window(spectrum_end - spectrum_start, device=device, dtype=torch.float)
+
+ cepstrum_start = math.ceil(sample_rate * 0.5 / lp_lifter_freq)
+ cepstrum_end = math.floor(sample_rate * 0.5 / hp_lifter_freq)
+ cepstrum_end = min(cepstrum_end, dft_len_ws // 4)
+
+ if cepstrum_end <= cepstrum_start:
+ raise ValueError(
+ "Expected cepstrum_start to be smaller than cepstrum_end."
+ f"Found: cepstrum_start: {cepstrum_start}, cepstrum_end: {cepstrum_end}."
+ )
+
+ noise_up_time_mult = torch.tensor(math.exp(-1.0 / (noise_up_time * measure_freq)), device=device)
+ noise_down_time_mult = torch.tensor(math.exp(-1.0 / (noise_down_time * measure_freq)), device=device)
+ measure_smooth_time_mult = math.exp(-1.0 / (measure_smooth_time * measure_freq))
+ trigger_meas_time_mult = math.exp(-1.0 / (trigger_time * measure_freq))
+
+ boot_count_max = int(boot_time * measure_freq - 0.5)
+ boot_count = measures_index = flushedLen_ns = 0
+
+ # pack batch
+ shape = waveform.size()
+ waveform = waveform.view(-1, shape[-1])
+
+ n_channels, ilen = waveform.size()
+
+ mean_meas = torch.zeros(n_channels, device=device)
+ spectrum = torch.zeros(n_channels, dft_len_ws, device=device)
+ noise_spectrum = torch.zeros(n_channels, dft_len_ws, device=device)
+ measures = torch.zeros(n_channels, measures_len, device=device)
+
+ has_triggered: bool = False
+ num_measures_to_flush: int = 0
+
+ pos = 0
+ for pos in range(measure_len_ns, ilen, measure_period_ns):
+ for i in range(n_channels):
+ meas: float = _measure(
+ measure_len_ws=measure_len_ws,
+ samples=waveform[i, pos - measure_len_ws : pos],
+ spectrum=spectrum[i],
+ noise_spectrum=noise_spectrum[i],
+ spectrum_window=spectrum_window,
+ spectrum_start=spectrum_start,
+ spectrum_end=spectrum_end,
+ cepstrum_window=cepstrum_window,
+ cepstrum_start=cepstrum_start,
+ cepstrum_end=cepstrum_end,
+ noise_reduction_amount=noise_reduction_amount,
+ measure_smooth_time_mult=measure_smooth_time_mult,
+ noise_up_time_mult=noise_up_time_mult,
+ noise_down_time_mult=noise_down_time_mult,
+ boot_count=boot_count,
+ )
+ measures[i, measures_index] = meas
+ mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1.0 - trigger_meas_time_mult)
+
+ has_triggered = has_triggered or (mean_meas[i] >= trigger_level)
+ if has_triggered:
+ n: int = measures_len
+ k: int = measures_index
+ jTrigger: int = n
+ jZero: int = n
+ j: int = 0
+
+ for j in range(n):
+ if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len):
+ jZero = jTrigger = j
+ elif (measures[i, k] == 0) and (jTrigger >= jZero):
+ jZero = j
+ k = (k + n - 1) % n
+ j = min(j, jZero)
+ # num_measures_to_flush = range_limit(j, num_measures_to_flush, n);
+ num_measures_to_flush = min(max(num_measures_to_flush, j), n)
+ # end if has_triggered
+ # end for channel
+ measures_index += 1
+ measures_index = measures_index % measures_len
+ if boot_count >= 0:
+ boot_count = -1 if boot_count == boot_count_max else boot_count + 1
+
+ if has_triggered:
+ flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
+ break
+ # end for window
+ if not has_triggered:
+ return waveform[..., :0].view(shape[:-1] + torch.Size([0]))
+
+ res = waveform[:, pos - samplesLen_ns + flushedLen_ns :]
+ # unpack batch
+ return res.view(shape[:-1] + res.shape[-1:])
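+
+# A brief usage sketch (illustrative; the synthetic "speech" is just noise
+# between stretches of silence): as noted in the docstring, ``vad`` trims only
+# from the front, so trailing silence is removed by reversing the signal,
+# applying ``vad`` again, and reversing back.
+#
+# >>> silence = torch.zeros(1, 8000)
+# >>> speech = torch.cat([silence, torch.rand(1, 16000) * 2 - 1, silence], dim=-1)
+# >>> front_trimmed = vad(speech, sample_rate=16000)
+# >>> both_trimmed = torch.flip(vad(torch.flip(front_trimmed, [-1]), sample_rate=16000), [-1])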
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/functional/functional.py b/.venv/lib/python3.11/site-packages/torchaudio/functional/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..af34e707e552d6eae403278743b406988e16e9db
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/functional/functional.py
@@ -0,0 +1,2535 @@
+# -*- coding: utf-8 -*-
+
+import math
+import tempfile
+import warnings
+from collections.abc import Sequence
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torchaudio
+from torch import Tensor
+from torchaudio._internal.module_utils import deprecated
+
+from .filtering import highpass_biquad, treble_biquad
+
+__all__ = [
+ "spectrogram",
+ "inverse_spectrogram",
+ "griffinlim",
+ "amplitude_to_DB",
+ "DB_to_amplitude",
+ "compute_deltas",
+ "melscale_fbanks",
+ "linear_fbanks",
+ "create_dct",
+ "compute_deltas",
+ "detect_pitch_frequency",
+ "DB_to_amplitude",
+ "mu_law_encoding",
+ "mu_law_decoding",
+ "phase_vocoder",
+ "mask_along_axis",
+ "mask_along_axis_iid",
+ "sliding_window_cmn",
+ "spectral_centroid",
+ "apply_codec",
+ "resample",
+ "edit_distance",
+ "loudness",
+ "pitch_shift",
+ "rnnt_loss",
+ "psd",
+ "mvdr_weights_souden",
+ "mvdr_weights_rtf",
+ "rtf_evd",
+ "rtf_power",
+ "apply_beamforming",
+ "fftconvolve",
+ "convolve",
+ "add_noise",
+ "speed",
+ "preemphasis",
+ "deemphasis",
+]
+
+
+def spectrogram(
+ waveform: Tensor,
+ pad: int,
+ window: Tensor,
+ n_fft: int,
+ hop_length: int,
+ win_length: int,
+ power: Optional[float],
+ normalized: Union[bool, str],
+ center: bool = True,
+ pad_mode: str = "reflect",
+ onesided: bool = True,
+ return_complex: Optional[bool] = None,
+) -> Tensor:
+ r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
+ The spectrogram can be either magnitude-only or complex.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (Tensor): Tensor of audio of dimension `(..., time)`
+ pad (int): Two sided padding of signal
+ window (Tensor): Window tensor that is applied/multiplied to each frame/window
+ n_fft (int): Size of FFT
+ hop_length (int): Length of hop between STFT windows
+ win_length (int): Window size
+ power (float or None): Exponent for the magnitude spectrogram,
+ (must be > 0) e.g., 1 for magnitude, 2 for power, etc.
+ If None, then the complex spectrum is returned instead.
+ normalized (bool or str): Whether to normalize by magnitude after stft. If input is str, choices are
+ ``"window"`` and ``"frame_length"``, if specific normalization type is desirable. ``True`` maps to
+ ``"window"``. When normalized on ``"window"``, waveform is normalized upon the window's L2 energy. If
+ normalized on ``"frame_length"``, waveform is normalized by dividing by
+ :math:`(\text{frame\_length})^{0.5}`.
+ center (bool, optional): whether to pad :attr:`waveform` on both sides so
+ that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
+ Default: ``True``
+ pad_mode (string, optional): controls the padding method used when
+ :attr:`center` is ``True``. Default: ``"reflect"``
+ onesided (bool, optional): controls whether to return half of results to
+ avoid redundancy. Default: ``True``
+ return_complex (bool, optional):
+ Deprecated and not used.
+
+ Returns:
+ Tensor: Dimension `(..., freq, time)`, freq is
+ ``n_fft // 2 + 1`` and ``n_fft`` is the number of
+ Fourier bins, and time is the number of window hops (n_frame).
+ """
+ if return_complex is not None:
+ warnings.warn(
+ "`return_complex` argument is now deprecated and is not effective."
+ "`torchaudio.functional.spectrogram(power=None)` always returns a tensor with "
+ "complex dtype. Please remove the argument in the function call."
+ )
+
+ if pad > 0:
+ # TODO add "with torch.no_grad():" back when JIT supports it
+ waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")
+
+ frame_length_norm, window_norm = _get_spec_norms(normalized)
+
+ # pack batch
+ shape = waveform.size()
+ waveform = waveform.reshape(-1, shape[-1])
+
+ # default values are consistent with librosa.core.spectrum._spectrogram
+ spec_f = torch.stft(
+ input=waveform,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ center=center,
+ pad_mode=pad_mode,
+ normalized=frame_length_norm,
+ onesided=onesided,
+ return_complex=True,
+ )
+
+ # unpack batch
+ spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])
+
+ if window_norm:
+ spec_f /= window.pow(2.0).sum().sqrt()
+ if power is not None:
+ if power == 1.0:
+ return spec_f.abs()
+ return spec_f.abs().pow(power)
+ return spec_f
+
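+# A brief usage sketch (illustrative; the sizes are arbitrary examples): a
+# power spectrogram of one second of 16 kHz audio with a Hann window. With
+# center=True (the default), the number of frames is 1 + len // hop_length.
+#
+# >>> n_fft, hop = 400, 200
+# >>> window = torch.hann_window(n_fft)
+# >>> waveform = torch.rand(1, 16000) * 2 - 1
+# >>> spec = spectrogram(waveform, pad=0, window=window, n_fft=n_fft,
+# ...                    hop_length=hop, win_length=n_fft, power=2.0, normalized=False)
+# >>> spec.shape  # (channel, n_fft // 2 + 1, 1 + 16000 // hop)
+# torch.Size([1, 201, 81])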
+
+def inverse_spectrogram(
+ spectrogram: Tensor,
+ length: Optional[int],
+ pad: int,
+ window: Tensor,
+ n_fft: int,
+ hop_length: int,
+ win_length: int,
+ normalized: Union[bool, str],
+ center: bool = True,
+ pad_mode: str = "reflect",
+ onesided: bool = True,
+) -> Tensor:
+ r"""Create an inverse spectrogram or a batch of inverse spectrograms from the provided
+ complex-valued spectrogram.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ spectrogram (Tensor): Complex tensor of audio of dimension (..., freq, time).
+ length (int or None): The output length of the waveform.
+ pad (int): Two sided padding of signal. It is only effective when ``length`` is provided.
+ window (Tensor): Window tensor that is applied/multiplied to each frame/window
+ n_fft (int): Size of FFT
+ hop_length (int): Length of hop between STFT windows
+ win_length (int): Window size
+ normalized (bool or str): Whether the stft output was normalized by magnitude. If input is str, choices are
+ ``"window"`` and ``"frame_length"``, dependent on normalization mode. ``True`` maps to
+ ``"window"``.
+ center (bool, optional): whether the waveform was padded on both sides so
+ that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
+ Default: ``True``
+ pad_mode (string, optional): controls the padding method used when
+ :attr:`center` is ``True``. This parameter is provided for compatibility with the
+ spectrogram function and is not used. Default: ``"reflect"``
+ onesided (bool, optional): controls whether spectrogram was done in onesided mode.
+ Default: ``True``
+
+ Returns:
+ Tensor: Dimension `(..., time)`. Least squares estimation of the original signal.
+ """
+
+ frame_length_norm, window_norm = _get_spec_norms(normalized)
+
+ if not spectrogram.is_complex():
+ raise ValueError("Expected `spectrogram` to be complex dtype.")
+
+ if window_norm:
+ spectrogram = spectrogram * window.pow(2.0).sum().sqrt()
+
+ # pack batch
+ shape = spectrogram.size()
+ spectrogram = spectrogram.reshape(-1, shape[-2], shape[-1])
+
+ # default values are consistent with librosa.core.spectrum._spectrogram
+ waveform = torch.istft(
+ input=spectrogram,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ center=center,
+ normalized=frame_length_norm,
+ onesided=onesided,
+ length=length + 2 * pad if length is not None else None,
+ return_complex=False,
+ )
+
+ if length is not None and pad > 0:
+ # remove padding from front and back
+ waveform = waveform[:, pad:-pad]
+
+ # unpack batch
+ waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
+
+ return waveform
+
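+# A brief round-trip sketch (illustrative): inverting a complex spectrogram
+# (computed with power=None) should closely reconstruct the original waveform.
+#
+# >>> n_fft, hop = 400, 200
+# >>> window = torch.hann_window(n_fft)
+# >>> waveform = torch.rand(1, 16000) * 2 - 1
+# >>> complex_spec = spectrogram(waveform, pad=0, window=window, n_fft=n_fft,
+# ...                            hop_length=hop, win_length=n_fft, power=None, normalized=False)
+# >>> restored = inverse_spectrogram(complex_spec, length=16000, pad=0, window=window,
+# ...                                n_fft=n_fft, hop_length=hop, win_length=n_fft, normalized=False)
+# >>> restored.shape
+# torch.Size([1, 16000])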
+
+def _get_spec_norms(normalized: Union[str, bool]):
+ frame_length_norm, window_norm = False, False
+ if torch.jit.isinstance(normalized, str):
+ if normalized not in ["frame_length", "window"]:
+ raise ValueError("Invalid normalized parameter: {}".format(normalized))
+ if normalized == "frame_length":
+ frame_length_norm = True
+ elif normalized == "window":
+ window_norm = True
+ elif torch.jit.isinstance(normalized, bool):
+ if normalized:
+ window_norm = True
+ else:
+ raise TypeError("Input type not supported")
+ return frame_length_norm, window_norm
+
+
+def _get_complex_dtype(real_dtype: torch.dtype):
+ if real_dtype == torch.double:
+ return torch.cdouble
+ if real_dtype == torch.float:
+ return torch.cfloat
+ if real_dtype == torch.half:
+ return torch.complex32
+ raise ValueError(f"Unexpected dtype {real_dtype}")
+
+
+def griffinlim(
+ specgram: Tensor,
+ window: Tensor,
+ n_fft: int,
+ hop_length: int,
+ win_length: int,
+ power: float,
+ n_iter: int,
+ momentum: float,
+ length: Optional[int],
+ rand_init: bool,
+) -> Tensor:
+ r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Implementation ported from
+ *librosa* :cite:`brian_mcfee-proc-scipy-2015`, *A fast Griffin-Lim algorithm* :cite:`6701851`
+ and *Signal estimation from modified short-time Fourier transform* :cite:`1172092`.
+
+ Args:
+ specgram (Tensor): A magnitude-only STFT spectrogram of dimension `(..., freq, frames)`
+ where freq is ``n_fft // 2 + 1``.
+ window (Tensor): Window tensor that is applied/multiplied to each frame/window
+ n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
+        hop_length (int): Length of hop between STFT windows. (Default: ``win_length // 2``)
+ win_length (int): Window size. (Default: ``n_fft``)
+ power (float): Exponent for the magnitude spectrogram,
+ (must be > 0) e.g., 1 for magnitude, 2 for power, etc.
+        n_iter (int): Number of iterations for the phase recovery process.
+ momentum (float): The momentum parameter for fast Griffin-Lim.
+ Setting this to 0 recovers the original Griffin-Lim method.
+ Values near 1 can lead to faster convergence, but above 1 may not converge.
+ length (int or None): Array length of the expected output.
+        rand_init (bool): Initializes phase randomly if True, and to zero otherwise.
+
+ Returns:
+ Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
+ """
+ if not 0 <= momentum < 1:
+ raise ValueError("momentum must be in range [0, 1). Found: {}".format(momentum))
+
+ momentum = momentum / (1 + momentum)
+
+ # pack batch
+ shape = specgram.size()
+ specgram = specgram.reshape([-1] + list(shape[-2:]))
+
+ specgram = specgram.pow(1 / power)
+
+ # initialize the phase
+ if rand_init:
+ angles = torch.rand(specgram.size(), dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
+ else:
+ angles = torch.full(specgram.size(), 1, dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
+
+ # And initialize the previous iterate to 0
+ tprev = torch.tensor(0.0, dtype=specgram.dtype, device=specgram.device)
+ for _ in range(n_iter):
+ # Invert with our current estimate of the phases
+ inverse = torch.istft(
+ specgram * angles, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=length
+ )
+
+ # Rebuild the spectrogram
+ rebuilt = torch.stft(
+ input=inverse,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ center=True,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+
+ # Update our phase estimates
+ angles = rebuilt
+ if momentum:
+ angles = angles - tprev.mul_(momentum)
+ angles = angles.div(angles.abs().add(1e-16))
+
+ # Store the previous iterate
+ tprev = rebuilt
+
+ # Return the final phase estimates
+ waveform = torch.istft(
+ specgram * angles, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=length
+ )
+
+ # unpack batch
+ waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
+
+ return waveform
+
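+# A brief usage sketch (illustrative; n_iter and momentum are typical choices,
+# not prescriptions): recovering a waveform from a magnitude-only spectrogram.
+#
+# >>> n_fft, hop = 400, 200
+# >>> window = torch.hann_window(n_fft)
+# >>> magnitude = spectrogram(torch.rand(1, 16000) * 2 - 1, pad=0, window=window,
+# ...                         n_fft=n_fft, hop_length=hop, win_length=n_fft,
+# ...                         power=1.0, normalized=False)
+# >>> recovered = griffinlim(magnitude, window, n_fft, hop, n_fft, power=1.0,
+# ...                        n_iter=32, momentum=0.99, length=16000, rand_init=True)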
+
+def amplitude_to_DB(
+ x: Tensor, multiplier: float, amin: float, db_multiplier: float, top_db: Optional[float] = None
+) -> Tensor:
+ r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ The output of each tensor in a batch depends on the maximum value of that tensor,
+ and so may return different values for an audio clip split into snippets vs. a full clip.
+
+ Args:
+
+ x (Tensor): Input spectrogram(s) before being converted to decibel scale.
+ The expected shapes are ``(freq, time)``, ``(channel, freq, time)`` or
+ ``(..., batch, channel, freq, time)``.
+
+ .. note::
+
+ When ``top_db`` is specified, cut-off values are computed for each audio
+ in the batch. Therefore if the input shape is 4D (or larger), different
+ cut-off values are used for audio data in the batch.
+ If the input shape is 2D or 3D, a single cutoff value is used.
+
+ multiplier (float): Use 10. for power and 20. for amplitude
+ amin (float): Number to clamp ``x``
+ db_multiplier (float): Log10(max(reference value and amin))
+ top_db (float or None, optional): Minimum negative cut-off in decibels. A reasonable number
+ is 80. (Default: ``None``)
+
+ Returns:
+ Tensor: Output tensor in decibel scale
+ """
+ x_db = multiplier * torch.log10(torch.clamp(x, min=amin))
+ x_db -= multiplier * db_multiplier
+
+ if top_db is not None:
+ # Expand batch
+ shape = x_db.size()
+ packed_channels = shape[-3] if x_db.dim() > 2 else 1
+ x_db = x_db.reshape(-1, packed_channels, shape[-2], shape[-1])
+
+ x_db = torch.max(x_db, (x_db.amax(dim=(-3, -2, -1)) - top_db).view(-1, 1, 1, 1))
+
+ # Repack batch
+ x_db = x_db.reshape(shape)
+
+ return x_db
+
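+# A brief usage sketch (illustrative; the tensor is a stand-in for a power
+# spectrogram): per the docstring, multiplier is 10.0 for power inputs (20.0
+# for amplitude), and db_multiplier = log10(max(reference, amin)) is 0.0 for a
+# reference of 1.0.
+#
+# >>> power_spec = torch.rand(1, 201, 81)
+# >>> db = amplitude_to_DB(power_spec, multiplier=10.0, amin=1e-10,
+# ...                      db_multiplier=0.0, top_db=80.0)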
+
+def DB_to_amplitude(x: Tensor, ref: float, power: float) -> Tensor:
+ r"""Turn a tensor from the decibel scale to the power/amplitude scale.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ Args:
+ x (Tensor): Input tensor before being converted to power/amplitude scale.
+ ref (float): Reference which the output will be scaled by.
+ power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.
+
+ Returns:
+ Tensor: Output tensor in power/amplitude scale.
+ """
+ return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)
+
+
+def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
+ r"""Convert Hz to Mels.
+
+ Args:
+ freqs (float): Frequencies in Hz
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
+
+ Returns:
+ mels (float): Frequency in Mels
+ """
+
+ if mel_scale not in ["slaney", "htk"]:
+ raise ValueError('mel_scale should be one of "htk" or "slaney".')
+
+ if mel_scale == "htk":
+ return 2595.0 * math.log10(1.0 + (freq / 700.0))
+
+ # Fill in the linear part
+ f_min = 0.0
+ f_sp = 200.0 / 3
+
+ mels = (freq - f_min) / f_sp
+
+ # Fill in the log-scale part
+ min_log_hz = 1000.0
+ min_log_mel = (min_log_hz - f_min) / f_sp
+ logstep = math.log(6.4) / 27.0
+
+ if freq >= min_log_hz:
+ mels = min_log_mel + math.log(freq / min_log_hz) / logstep
+
+ return mels
+
+
+def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
+ """Convert mel bin numbers to frequencies.
+
+ Args:
+ mels (Tensor): Mel frequencies
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
+
+ Returns:
+ freqs (Tensor): Mels converted in Hz
+ """
+
+ if mel_scale not in ["slaney", "htk"]:
+ raise ValueError('mel_scale should be one of "htk" or "slaney".')
+
+ if mel_scale == "htk":
+ return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
+
+ # Fill in the linear scale
+ f_min = 0.0
+ f_sp = 200.0 / 3
+ freqs = f_min + f_sp * mels
+
+ # And now the nonlinear scale
+ min_log_hz = 1000.0
+ min_log_mel = (min_log_hz - f_min) / f_sp
+ logstep = math.log(6.4) / 27.0
+
+ log_t = mels >= min_log_mel
+ freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
+
+ return freqs
+
+
+def _create_triangular_filterbank(
+ all_freqs: Tensor,
+ f_pts: Tensor,
+) -> Tensor:
+ """Create a triangular filter bank.
+
+ Args:
+ all_freqs (Tensor): STFT freq points of size (`n_freqs`).
+ f_pts (Tensor): Filter mid points of size (`n_filter`).
+
+ Returns:
+ fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
+ """
+ # Adopted from Librosa
+ # calculate the difference between each filter mid point and each stft freq point in hertz
+ f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1)
+ slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2)
+ # create overlapping triangles
+ zero = torch.zeros(1)
+ down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter)
+ up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter)
+ fb = torch.max(zero, torch.min(down_slopes, up_slopes))
+
+ return fb
+
+
+def melscale_fbanks(
+ n_freqs: int,
+ f_min: float,
+ f_max: float,
+ n_mels: int,
+ sample_rate: int,
+ norm: Optional[str] = None,
+ mel_scale: str = "htk",
+) -> Tensor:
+ r"""Create a frequency bin conversion matrix.
+
+ .. devices:: CPU
+
+ .. properties:: TorchScript
+
+ Note:
+        For the sake of numerical compatibility with librosa, not all of the coefficients
+        in the resulting filter bank have a magnitude of 1.
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png
+ :alt: Visualization of generated filter bank
+
+ Args:
+ n_freqs (int): Number of frequencies to highlight/apply
+ f_min (float): Minimum frequency (Hz)
+ f_max (float): Maximum frequency (Hz)
+ n_mels (int): Number of mel filterbanks
+ sample_rate (int): Sample rate of the audio waveform
+ norm (str or None, optional): If "slaney", divide the triangular mel weights by the width of the mel band
+ (area normalization). (Default: ``None``)
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
+
+ Returns:
+ Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
+ meaning number of frequencies to highlight/apply to x the number of filterbanks.
+ Each column is a filterbank so that assuming there is a matrix A of
+ size (..., ``n_freqs``), the applied result would be
+ ``A @ melscale_fbanks(A.size(-1), ...)``.
+
+ """
+
+ if norm is not None and norm != "slaney":
+ raise ValueError('norm must be one of None or "slaney"')
+
+ # freq bins
+ all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
+
+ # calculate mel freq bins
+ m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
+ m_max = _hz_to_mel(f_max, mel_scale=mel_scale)
+
+ m_pts = torch.linspace(m_min, m_max, n_mels + 2)
+ f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)
+
+ # create filterbank
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+ if norm is not None and norm == "slaney":
+ # Slaney-style mel is scaled to be approx constant energy per channel
+ enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
+ fb *= enorm.unsqueeze(0)
+
+ if (fb.max(dim=0).values == 0.0).any():
+ warnings.warn(
+ "At least one mel filterbank has all zero values. "
+ f"The value for `n_mels` ({n_mels}) may be set too high. "
+ f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
+ )
+
+ return fb
+
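+# A brief usage sketch (illustrative; the sizes are arbitrary examples):
+# converting a 201-bin spectrogram (n_fft = 400) to 40 mel bands, following
+# the ``A @ melscale_fbanks(...)`` recipe from the docstring.
+#
+# >>> fb = melscale_fbanks(n_freqs=201, f_min=0.0, f_max=8000.0, n_mels=40, sample_rate=16000)
+# >>> spec = torch.rand(1, 201, 81)  # (channel, freq, time)
+# >>> mel_spec = (spec.transpose(-1, -2) @ fb).transpose(-1, -2)  # (channel, n_mels, time)
+# >>> mel_spec.shape
+# torch.Size([1, 40, 81])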
+
+def linear_fbanks(
+ n_freqs: int,
+ f_min: float,
+ f_max: float,
+ n_filter: int,
+ sample_rate: int,
+) -> Tensor:
+ r"""Creates a linear triangular filterbank.
+
+ .. devices:: CPU
+
+ .. properties:: TorchScript
+
+ Note:
+        For the sake of numerical compatibility with librosa, not all of the coefficients
+        in the resulting filter bank have a magnitude of 1.
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/lin_fbanks.png
+ :alt: Visualization of generated filter bank
+
+ Args:
+ n_freqs (int): Number of frequencies to highlight/apply
+ f_min (float): Minimum frequency (Hz)
+ f_max (float): Maximum frequency (Hz)
+ n_filter (int): Number of (linear) triangular filter
+ sample_rate (int): Sample rate of the audio waveform
+
+ Returns:
+ Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_filter``)
+ meaning number of frequencies to highlight/apply to x the number of filterbanks.
+ Each column is a filterbank so that assuming there is a matrix A of
+ size (..., ``n_freqs``), the applied result would be
+        ``A @ linear_fbanks(A.size(-1), ...)``.
+ """
+ # freq bins
+ all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
+
+ # filter mid-points
+ f_pts = torch.linspace(f_min, f_max, n_filter + 2)
+
+ # create filterbank
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+ return fb
+
+
+def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> Tensor:
+ r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
+ normalized depending on norm.
+
+ .. devices:: CPU
+
+ .. properties:: TorchScript
+
+ Args:
+ n_mfcc (int): Number of mfc coefficients to retain
+ n_mels (int): Number of mel filterbanks
+ norm (str or None): Norm to use (either "ortho" or None)
+
+ Returns:
+        Tensor: The transformation matrix of size (``n_mels``, ``n_mfcc``), to be
+        right-multiplied to row-wise data of size (..., ``n_mels``).
+ """
+
+ if norm is not None and norm != "ortho":
+ raise ValueError('norm must be either "ortho" or None')
+
+ # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
+ n = torch.arange(float(n_mels))
+ k = torch.arange(float(n_mfcc)).unsqueeze(1)
+ dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k) # size (n_mfcc, n_mels)
+
+ if norm is None:
+ dct *= 2.0
+ else:
+ dct[0] *= 1.0 / math.sqrt(2.0)
+ dct *= math.sqrt(2.0 / float(n_mels))
+ return dct.t()
+
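+# A brief usage sketch (illustrative; ``mel_db`` is a hypothetical stand-in
+# for log-mel features): projecting 40 mel bands down to 13 cepstral
+# coefficients by right-multiplying row-wise data with the returned matrix.
+#
+# >>> dct_mat = create_dct(n_mfcc=13, n_mels=40, norm="ortho")
+# >>> mel_db = torch.rand(1, 81, 40)  # (channel, time, n_mels)
+# >>> mfcc = mel_db @ dct_mat  # (channel, time, n_mfcc)
+# >>> mfcc.shape
+# torch.Size([1, 81, 13])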
+
+def mu_law_encoding(x: Tensor, quantization_channels: int) -> Tensor:
+ r"""Encode signal based on mu-law companding.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ For more info see the
+    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
+
+ This algorithm expects the signal has been scaled to between -1 and 1 and
+ returns a signal encoded with values from 0 to quantization_channels - 1.
+
+ Args:
+ x (Tensor): Input tensor
+ quantization_channels (int): Number of channels
+
+ Returns:
+ Tensor: Input after mu-law encoding
+ """
+ mu = quantization_channels - 1.0
+ if not x.is_floating_point():
+ warnings.warn(
+ "The input Tensor must be of floating type. \
+ This will be an error in the v0.12 release."
+ )
+ x = x.to(torch.float)
+ mu = torch.tensor(mu, dtype=x.dtype)
+ x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
+ x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)
+ return x_mu
+
+
+def mu_law_decoding(x_mu: Tensor, quantization_channels: int) -> Tensor:
+ r"""Decode mu-law encoded signal.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ For more info see the
+    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
+
+ This expects an input with values between 0 and quantization_channels - 1
+ and returns a signal scaled between -1 and 1.
+
+ Args:
+ x_mu (Tensor): Input tensor
+ quantization_channels (int): Number of channels
+
+ Returns:
+ Tensor: Input after mu-law decoding
+ """
+ mu = quantization_channels - 1.0
+ if not x_mu.is_floating_point():
+ x_mu = x_mu.to(torch.float)
+ mu = torch.tensor(mu, dtype=x_mu.dtype)
+ x = ((x_mu) / mu) * 2 - 1.0
+ x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
+ return x
+
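+# A brief round-trip sketch (illustrative): encoding then decoding introduces
+# only a small quantization error (on the order of a few hundredths with 256
+# channels, largest near full scale).
+#
+# >>> x = torch.rand(1, 1000) * 2 - 1  # scaled to [-1, 1] as required
+# >>> enc = mu_law_encoding(x, quantization_channels=256)  # int64 in [0, 255]
+# >>> dec = mu_law_decoding(enc, quantization_channels=256)
+# >>> (x - dec).abs().max()  # small quantization error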
+
+def phase_vocoder(complex_specgrams: Tensor, rate: float, phase_advance: Tensor) -> Tensor:
+ r"""Given a STFT tensor, speed up in time without modifying pitch by a factor of ``rate``.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ complex_specgrams (Tensor):
+ A tensor of dimension `(..., freq, num_frame)` with complex dtype.
+ rate (float): Speed-up factor
+ phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)`
+
+ Returns:
+ Tensor:
+ Stretched spectrogram. The resulting tensor is of the same dtype as the input
+ spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
+
+ Example
+ >>> freq, hop_length = 1025, 512
+ >>> # (channel, freq, time)
+ >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
+ >>> rate = 1.3 # Speed up by 30%
+ >>> phase_advance = torch.linspace(
+ >>> 0, math.pi * hop_length, freq)[..., None]
+ >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
+ >>> x.shape # with 231 == ceil(300 / 1.3)
+ torch.Size([2, 1025, 231])
+ """
+ if rate == 1.0:
+ return complex_specgrams
+
+ # pack batch
+ shape = complex_specgrams.size()
+ complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))
+
+ # Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32
+ # Note torch.real is a view so it does not incur any memory copy.
+ real_dtype = torch.real(complex_specgrams).dtype
+ time_steps = torch.arange(0, complex_specgrams.size(-1), rate, device=complex_specgrams.device, dtype=real_dtype)
+
+ alphas = time_steps % 1.0
+ phase_0 = complex_specgrams[..., :1].angle()
+
+ # Time Padding
+ complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2])
+
+ # (new_bins, freq, 2)
+ complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long())
+ complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long())
+
+ angle_0 = complex_specgrams_0.angle()
+ angle_1 = complex_specgrams_1.angle()
+
+ norm_0 = complex_specgrams_0.abs()
+ norm_1 = complex_specgrams_1.abs()
+
+ phase = angle_1 - angle_0 - phase_advance
+ phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))
+
+ # Compute Phase Accum
+ phase = phase + phase_advance
+ phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)
+ phase_acc = torch.cumsum(phase, -1)
+
+ mag = alphas * norm_1 + (1 - alphas) * norm_0
+
+ complex_specgrams_stretch = torch.polar(mag, phase_acc)
+
+ # unpack batch
+ complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])
+ return complex_specgrams_stretch
+
+
+def _get_mask_param(mask_param: int, p: float, axis_length: int) -> int:
+ if p == 1.0:
+ return mask_param
+ else:
+ return min(mask_param, int(axis_length * p))
+
+
+def mask_along_axis_iid(
+ specgrams: Tensor,
+ mask_param: int,
+ mask_value: float,
+ axis: int,
+ p: float = 1.0,
+) -> Tensor:
+ r"""Apply a mask along ``axis``.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Mask will be applied from indices ``[v_0, v_0 + v)``,
+ where ``v`` is sampled from ``uniform(0, max_v)`` and
+ ``v_0`` from ``uniform(0, specgrams.size(axis) - v)``,
+ with ``max_v = mask_param`` when ``p = 1.0`` and
+ ``max_v = min(mask_param, floor(specgrams.size(axis) * p))`` otherwise.
+
+ Args:
+ specgrams (Tensor): Real spectrograms `(..., freq, time)`, with at least 3 dimensions.
+ mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
+ mask_value (float): Value to assign to the masked columns
+        axis (int): Axis to apply masking on, which should be one of the last two dimensions.
+ p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
+
+ Returns:
+        Tensor: Masked spectrograms with the same dimensions as input specgrams Tensor
+ """
+
+ dim = specgrams.dim()
+
+ if dim < 3:
+ raise ValueError(f"Spectrogram must have at least three dimensions ({dim} given).")
+
+ if axis not in [dim - 2, dim - 1]:
+ raise ValueError(
+ f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
+ )
+
+ if not 0.0 <= p <= 1.0:
+ raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
+
+ mask_param = _get_mask_param(mask_param, p, specgrams.shape[axis])
+ if mask_param < 1:
+ return specgrams
+
+ device = specgrams.device
+ dtype = specgrams.dtype
+
+ value = torch.rand(specgrams.shape[: (dim - 2)], device=device, dtype=dtype) * mask_param
+ min_value = torch.rand(specgrams.shape[: (dim - 2)], device=device, dtype=dtype) * (specgrams.size(axis) - value)
+
+ # Create broadcastable mask
+ mask_start = min_value.long()[..., None, None]
+ mask_end = (min_value.long() + value.long())[..., None, None]
+ mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype)
+
+ # Per batch example masking
+ specgrams = specgrams.transpose(axis, -1)
+ specgrams = specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value)
+ specgrams = specgrams.transpose(axis, -1)
+
+ return specgrams
+
+
+def mask_along_axis(
+ specgram: Tensor,
+ mask_param: int,
+ mask_value: float,
+ axis: int,
+ p: float = 1.0,
+) -> Tensor:
+ r"""Apply a mask along ``axis``.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Mask will be applied from indices ``[v_0, v_0 + v)``,
+ where ``v`` is sampled from ``uniform(0, max_v)`` and
+ ``v_0`` from ``uniform(0, specgram.size(axis) - v)``, with
+ ``max_v = mask_param`` when ``p = 1.0`` and
+ ``max_v = min(mask_param, floor(specgram.size(axis) * p))``
+ otherwise.
+ All examples will have the same mask interval.
+
+ Args:
+ specgram (Tensor): Real spectrograms `(..., freq, time)`, with at least 2 dimensions.
+ mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
+ mask_value (float): Value to assign to the masked columns
+        axis (int): Axis to apply masking on, which should be one of the last two dimensions.
+ p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
+
+ Returns:
+ Tensor: Masked spectrograms with the same dimensions as input specgram Tensor
+ """
+ dim = specgram.dim()
+
+ if dim < 2:
+ raise ValueError(f"Spectrogram must have at least two dimensions (time and frequency) ({dim} given).")
+
+ if axis not in [dim - 2, dim - 1]:
+ raise ValueError(
+ f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
+ )
+
+ if not 0.0 <= p <= 1.0:
+ raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
+
+ mask_param = _get_mask_param(mask_param, p, specgram.shape[axis])
+ if mask_param < 1:
+ return specgram
+
+ # pack batch
+ shape = specgram.size()
+ specgram = specgram.reshape([-1] + list(shape[-2:]))
+ # After packing, specgram is a 3D tensor, and the axis corresponding to the to-be-masked dimension
+ # is now (axis - dim + 3), e.g. a tensor of shape (10, 2, 50, 10, 2) becomes a tensor of shape (1000, 10, 2).
+ value = torch.rand(1) * mask_param
+ min_value = torch.rand(1) * (specgram.size(axis - dim + 3) - value)
+
+ mask_start = (min_value.long()).squeeze()
+ mask_end = (min_value.long() + value.long()).squeeze()
+ mask = torch.arange(0, specgram.shape[axis - dim + 3], device=specgram.device, dtype=specgram.dtype)
+ mask = (mask >= mask_start) & (mask < mask_end)
+ # unsqueeze the mask if the axis is frequency
+ if axis == dim - 2:
+ mask = mask.unsqueeze(-1)
+
+ if mask_end - mask_start >= mask_param:
+ raise ValueError("Number of columns to be masked should be less than mask_param")
+
+ specgram = specgram.masked_fill(mask, mask_value)
+
+ # unpack batch
+ specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
+
+ return specgram
+
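+# A brief usage sketch (illustrative; parameter values are arbitrary
+# examples): SpecAugment-style masking on a batch of spectrograms.
+# ``mask_along_axis`` shares one mask interval across the batch, while
+# ``mask_along_axis_iid`` samples a mask per example.
+#
+# >>> batch = torch.rand(4, 201, 100)  # (batch, freq, time)
+# >>> time_masked = mask_along_axis(batch, mask_param=20, mask_value=0.0, axis=2)
+# >>> freq_masked = mask_along_axis_iid(batch, mask_param=15, mask_value=0.0, axis=1)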
+
+def compute_deltas(specgram: Tensor, win_length: int = 5, mode: str = "replicate") -> Tensor:
+ r"""Compute delta coefficients of a tensor, usually a spectrogram:
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ .. math::
+ d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N}} n^2}
+
+ where :math:`d_t` is the deltas at time :math:`t`,
+ :math:`c_t` is the spectrogram coeffcients at time :math:`t`,
+ :math:`N` is ``(win_length-1)//2``.
+
+ Args:
+ specgram (Tensor): Tensor of audio of dimension `(..., freq, time)`
+ win_length (int, optional): The window length used for computing delta (Default: ``5``)
+ mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)
+
+ Returns:
+ Tensor: Tensor of deltas of dimension `(..., freq, time)`
+
+ Example
+ >>> specgram = torch.randn(1, 40, 1000)
+ >>> delta = compute_deltas(specgram)
+ >>> delta2 = compute_deltas(delta)
+ """
+ device = specgram.device
+ dtype = specgram.dtype
+
+ # pack batch
+ shape = specgram.size()
+ specgram = specgram.reshape(1, -1, shape[-1])
+
+ if win_length < 3:
+ raise ValueError(f"Window length should be greater than or equal to 3. Found win_length {win_length}")
+
+ n = (win_length - 1) // 2
+
+ # twice sum of integer squared
+ denom = n * (n + 1) * (2 * n + 1) / 3
+
+ specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)
+
+ kernel = torch.arange(-n, n + 1, 1, device=device, dtype=dtype).repeat(specgram.shape[1], 1, 1)
+
+ output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom
+
+ # unpack batch
+ output = output.reshape(shape)
+
+ return output
+
+
+def _compute_nccf(waveform: Tensor, sample_rate: int, frame_time: float, freq_low: int) -> Tensor:
+ r"""
+ Compute Normalized Cross-Correlation Function (NCCF).
+
+ .. math::
+ \phi_i(m) = \frac{\sum_{n=b_i}^{b_i + N-1} w(n) w(m+n)}{\sqrt{E(b_i) E(m+b_i)}},
+
+ where
+ :math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`,
+ :math:`w` is the waveform,
+ :math:`N` is the length of a frame,
+ :math:`b_i` is the beginning of frame :math:`i`,
+ :math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`.
+ """
+
+ EPSILON = 10 ** (-9)
+
+ # Number of lags to check
+ lags = int(math.ceil(sample_rate / freq_low))
+
+ frame_size = int(math.ceil(sample_rate * frame_time))
+
+ waveform_length = waveform.size()[-1]
+ num_of_frames = int(math.ceil(waveform_length / frame_size))
+
+ p = lags + num_of_frames * frame_size - waveform_length
+ waveform = torch.nn.functional.pad(waveform, (0, p))
+
+ # Compute lags
+ output_lag = []
+ for lag in range(1, lags + 1):
+ s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
+ s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
+
+ output_frames = (
+ (s1 * s2).sum(-1)
+ / (EPSILON + torch.linalg.vector_norm(s1, ord=2, dim=-1)).pow(2)
+ / (EPSILON + torch.linalg.vector_norm(s2, ord=2, dim=-1)).pow(2)
+ )
+
+ output_lag.append(output_frames.unsqueeze(-1))
+
+ nccf = torch.cat(output_lag, -1)
+
+ return nccf
+
+
+def _combine_max(a: Tuple[Tensor, Tensor], b: Tuple[Tensor, Tensor], thresh: float = 0.99) -> Tuple[Tensor, Tensor]:
+ """
+ Take value from first if bigger than a multiplicative factor of the second, elementwise.
+ """
+ mask = a[0] > thresh * b[0]
+ values = mask * a[0] + ~mask * b[0]
+ indices = mask * a[1] + ~mask * b[1]
+ return values, indices
+
+
+def _find_max_per_frame(nccf: Tensor, sample_rate: int, freq_high: int) -> Tensor:
+ r"""
+ For each frame, take the highest value of NCCF,
+ apply centered median smoothing, and convert to frequency.
+
+    Note: If a peak within the first half of the lag range is nearly as large
+    as the overall maximum, the smaller lag is preferred.
+ """
+
+ lag_min = int(math.ceil(sample_rate / freq_high))
+
+    # Prefer the smallest lag whose peak is close enough to the global maximum
+
+ best = torch.max(nccf[..., lag_min:], -1)
+
+ half_size = nccf.shape[-1] // 2
+ half = torch.max(nccf[..., lag_min:half_size], -1)
+
+ best = _combine_max(half, best)
+ indices = best[1]
+
+ # Add back minimal lag
+ indices += lag_min
+ # Add 1 empirical calibration offset
+ indices += 1
+
+ return indices
+
+
+def _median_smoothing(indices: Tensor, win_length: int) -> Tensor:
+ r"""
+ Apply median smoothing to the 1D tensor over the given window.
+ """
+
+ # Centered windowed
+ pad_length = (win_length - 1) // 2
+
+ # "replicate" padding in any dimension
+ indices = torch.nn.functional.pad(indices, (pad_length, 0), mode="constant", value=0.0)
+
+ indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1)
+ roll = indices.unfold(-1, win_length, 1)
+
+ values, _ = torch.median(roll, -1)
+ return values
+
+
+def detect_pitch_frequency(
+ waveform: Tensor,
+ sample_rate: int,
+ frame_time: float = 10 ** (-2),
+ win_length: int = 30,
+ freq_low: int = 85,
+ freq_high: int = 3400,
+) -> Tensor:
+ r"""Detect pitch frequency.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ It is implemented using normalized cross-correlation function and median smoothing.
+
+ Args:
+        waveform (Tensor): Tensor of audio of dimension `(..., time)`
+ sample_rate (int): The sample rate of the waveform (Hz)
+ frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
+ win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
+ freq_low (int, optional): Lowest frequency that can be detected (Hz) (Default: ``85``).
+ freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).
+
+ Returns:
+ Tensor: Tensor of freq of dimension `(..., frame)`
+ """
+ # pack batch
+ shape = list(waveform.size())
+ waveform = waveform.reshape([-1] + shape[-1:])
+
+ nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
+ indices = _find_max_per_frame(nccf, sample_rate, freq_high)
+ indices = _median_smoothing(indices, win_length)
+
+ # Convert indices to frequency
+ EPSILON = 10 ** (-9)
+ freq = sample_rate / (EPSILON + indices.to(torch.float))
+
+ # unpack batch
+ freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))
+
+ return freq
+
+
+def sliding_window_cmn(
+ specgram: Tensor,
+ cmn_window: int = 600,
+ min_cmn_window: int = 100,
+ center: bool = False,
+ norm_vars: bool = False,
+) -> Tensor:
+ r"""
+ Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ Args:
+ specgram (Tensor): Tensor of spectrogram of dimension `(..., time, freq)`
+ cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
+ min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
+ Only applicable if center == false, ignored if center==true (int, default = 100)
+ center (bool, optional): If true, use a window centered on the current frame
+ (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
+ norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
+
+ Returns:
+ Tensor: Tensor matching input shape `(..., time, freq)`
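+
+ Example
+ >>> # Illustrative sketch only: the spectrogram shape below is an arbitrary choice.
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> specgram = torch.randn(1, 1000, 40) # (channel, time, freq)
+ >>> cmn = F.sliding_window_cmn(specgram, cmn_window=600) # same shape as input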
+ """
+ input_shape = specgram.shape
+ num_frames, num_feats = input_shape[-2:]
+ specgram = specgram.view(-1, num_frames, num_feats)
+ num_channels = specgram.shape[0]
+
+ dtype = specgram.dtype
+ device = specgram.device
+ last_window_start = last_window_end = -1
+ cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
+ cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
+ cmn_specgram = torch.zeros(num_channels, num_frames, num_feats, dtype=dtype, device=device)
+ for t in range(num_frames):
+ window_start = 0
+ window_end = 0
+ if center:
+ window_start = t - cmn_window // 2
+ window_end = window_start + cmn_window
+ else:
+ window_start = t - cmn_window
+ window_end = t + 1
+ if window_start < 0:
+ window_end -= window_start
+ window_start = 0
+ if not center:
+ if window_end > t:
+ window_end = max(t + 1, min_cmn_window)
+ if window_end > num_frames:
+ window_start -= window_end - num_frames
+ window_end = num_frames
+ if window_start < 0:
+ window_start = 0
+ if last_window_start == -1:
+ # First frame: window_start is always 0 here, so this slices [0, window_end).
+ input_part = specgram[:, window_start : window_end - window_start, :]
+ cur_sum += torch.sum(input_part, 1)
+ if norm_vars:
+ cur_sumsq += torch.cumsum(input_part**2, 1)[:, -1, :]
+ else:
+ if window_start > last_window_start:
+ frame_to_remove = specgram[:, last_window_start, :]
+ cur_sum -= frame_to_remove
+ if norm_vars:
+ cur_sumsq -= frame_to_remove**2
+ if window_end > last_window_end:
+ frame_to_add = specgram[:, last_window_end, :]
+ cur_sum += frame_to_add
+ if norm_vars:
+ cur_sumsq += frame_to_add**2
+ window_frames = window_end - window_start
+ last_window_start = window_start
+ last_window_end = window_end
+ cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames
+ if norm_vars:
+ if window_frames == 1:
+ cmn_specgram[:, t, :] = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
+ else:
+ variance = cur_sumsq
+ variance = variance / window_frames
+ variance -= (cur_sum**2) / (window_frames**2)
+ variance = torch.pow(variance, -0.5)
+ cmn_specgram[:, t, :] *= variance
+
+ cmn_specgram = cmn_specgram.view(input_shape[:-2] + (num_frames, num_feats))
+ if len(input_shape) == 2:
+ cmn_specgram = cmn_specgram.squeeze(0)
+ return cmn_specgram
+
+
+def spectral_centroid(
+ waveform: Tensor,
+ sample_rate: int,
+ pad: int,
+ window: Tensor,
+ n_fft: int,
+ hop_length: int,
+ win_length: int,
+) -> Tensor:
+ r"""Compute the spectral centroid for each channel along the time axis.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ The spectral centroid is defined as the weighted average of the
+ frequency values, weighted by their magnitude.
+
+ Args:
+ waveform (Tensor): Tensor of audio of dimension `(..., time)`
+ sample_rate (int): Sample rate of the audio waveform
+ pad (int): Two sided padding of signal
+ window (Tensor): Window tensor that is applied/multiplied to each frame/window
+ n_fft (int): Size of FFT
+ hop_length (int): Length of hop between STFT windows
+ win_length (int): Window size
+
+ Returns:
+ Tensor: Dimension `(..., time)`
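+
+ Example
+ >>> # Illustrative sketch only: all STFT parameters below are arbitrary choices.
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> waveform = torch.randn(1, 16000)
+ >>> n_fft = 1024
+ >>> sc = F.spectral_centroid(waveform, 16000, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft, hop_length=512, win_length=n_fft)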
+ """
+ specgram = spectrogram(
+ waveform,
+ pad=pad,
+ window=window,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ power=1.0,
+ normalized=False,
+ )
+ freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2, device=specgram.device).reshape((-1, 1))
+ freq_dim = -2
+ return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
+
+
+@deprecated("Please migrate to :py:class:`torchaudio.io.AudioEffector`.", remove=False)
+def apply_codec(
+ waveform: Tensor,
+ sample_rate: int,
+ format: str,
+ channels_first: bool = True,
+ compression: Optional[float] = None,
+ encoding: Optional[str] = None,
+ bits_per_sample: Optional[int] = None,
+) -> Tensor:
+ r"""
+ Apply codecs as a form of augmentation.
+
+ .. devices:: CPU
+
+ Args:
+ waveform (Tensor): Audio data. Must be 2 dimensional. See also ``channels_first``.
+ sample_rate (int): Sample rate of the audio waveform.
+ format (str): File format.
+ channels_first (bool, optional):
+ When True, both the input and output Tensor have dimension `(channel, time)`.
+ Otherwise, they have dimension `(time, channel)`.
+ compression (float or None, optional): Used for formats other than WAV.
+ For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
+ encoding (str or None, optional): Changes the encoding for the supported formats.
+ For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
+ bits_per_sample (int or None, optional): Changes the bit depth for the supported formats.
+ For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
+
+ Returns:
+ Tensor: Resulting Tensor.
+ If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
+ """
+ from torchaudio.backend import _sox_io_backend
+
+ with tempfile.NamedTemporaryFile() as f:
+ _sox_io_backend.save(
+ f.name, waveform, sample_rate, channels_first, compression, format, encoding, bits_per_sample
+ )
+ augmented, sr = _sox_io_backend.load(f.name, channels_first=channels_first, format=format)
+ if sr != sample_rate:
+ augmented = resample(augmented, sr, sample_rate)
+ return augmented
+
+
+_CPU = torch.device("cpu")
+
+
+def _get_sinc_resample_kernel(
+ orig_freq: int,
+ new_freq: int,
+ gcd: int,
+ lowpass_filter_width: int = 6,
+ rolloff: float = 0.99,
+ resampling_method: str = "sinc_interp_hann",
+ beta: Optional[float] = None,
+ device: torch.device = _CPU,
+ dtype: Optional[torch.dtype] = None,
+):
+ if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
+ raise Exception(
+ "Frequencies must be of integer type to ensure quality resampling computation. "
+ "To work around this, manually convert both frequencies to integer values "
+ "that maintain their resampling rate ratio before passing them into the function. "
+ "Example: To downsample a 44100 hz waveform by a factor of 8, use "
+ "`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5`. "
+ "For more information, please refer to https://github.com/pytorch/audio/issues/1487."
+ )
+
+ if resampling_method in ["sinc_interpolation", "kaiser_window"]:
+ method_map = {
+ "sinc_interpolation": "sinc_interp_hann",
+ "kaiser_window": "sinc_interp_kaiser",
+ }
+ warnings.warn(
+ f'"{resampling_method}" resampling method name is being deprecated and replaced by '
+ f'"{method_map[resampling_method]}" in the next release. '
+ "The default behavior remains unchanged.",
+ stacklevel=3,
+ )
+ elif resampling_method not in ["sinc_interp_hann", "sinc_interp_kaiser"]:
+ raise ValueError("Invalid resampling method: {}".format(resampling_method))
+
+ orig_freq = int(orig_freq) // gcd
+ new_freq = int(new_freq) // gcd
+
+ if lowpass_filter_width <= 0:
+ raise ValueError("Low pass filter width should be positive.")
+ base_freq = min(orig_freq, new_freq)
+ # This will perform antialiasing filtering by removing the highest frequencies.
+ # At first I thought I only needed this when downsampling, but when upsampling
+ # you will get edge artifacts without this, as the edge is equivalent to zero padding,
+ # which will add high freq artifacts.
+ base_freq *= rolloff
+
+ # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
+ # using the sinc interpolation formula:
+ # x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
+ # We can then sample the function x(t) with a different sample rate:
+ # y[j] = x(j / new_freq)
+ # or,
+ # y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
+
+ # We see here that y[j] is the convolution of x[i] with a specific filter, for which
+ # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
+ # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
+ # Indeed:
+ # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
+ # = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
+ # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
+ # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
+ # This will explain the F.conv1d after, with a stride of orig_freq.
+ width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
+ # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
+ # they will have a lot of almost zero values to the left or to the right...
+ # There is probably a way to evaluate those filters more efficiently, but this is kept for
+ # future work.
+ idx_dtype = dtype if dtype is not None else torch.float64
+
+ idx = torch.arange(-width, width + orig_freq, dtype=idx_dtype, device=device)[None, None] / orig_freq
+
+ t = torch.arange(0, -new_freq, -1, dtype=idx_dtype, device=device)[:, None, None] / new_freq + idx
+ t *= base_freq
+ t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
+
+ # we do not use built in torch windows here as we need to evaluate the window
+ # at specific positions, not over a regular grid.
+ if resampling_method == "sinc_interp_hann":
+ window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
+ else:
+ # sinc_interp_kaiser
+ if beta is None:
+ beta = 14.769656459379492
+ beta_tensor = torch.tensor(float(beta))
+ window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
+
+ t *= math.pi
+
+ scale = base_freq / orig_freq
+ kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
+ kernels *= window * scale
+
+ if dtype is None:
+ kernels = kernels.to(dtype=torch.float32)
+
+ return kernels, width
+
+
+def _apply_sinc_resample_kernel(
+ waveform: Tensor,
+ orig_freq: int,
+ new_freq: int,
+ gcd: int,
+ kernel: Tensor,
+ width: int,
+):
+ if not waveform.is_floating_point():
+ raise TypeError(f"Expected floating point type for waveform tensor, but received {waveform.dtype}.")
+
+ orig_freq = int(orig_freq) // gcd
+ new_freq = int(new_freq) // gcd
+
+ # pack batch
+ shape = waveform.size()
+ waveform = waveform.view(-1, shape[-1])
+
+ num_wavs, length = waveform.shape
+ waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
+ resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
+ resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
+ target_length = torch.ceil(torch.as_tensor(new_freq * length / orig_freq)).long()
+ resampled = resampled[..., :target_length]
+
+ # unpack batch
+ resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
+ return resampled
+
+
+def resample(
+ waveform: Tensor,
+ orig_freq: int,
+ new_freq: int,
+ lowpass_filter_width: int = 6,
+ rolloff: float = 0.99,
+ resampling_method: str = "sinc_interp_hann",
+ beta: Optional[float] = None,
+) -> Tensor:
+ r"""Resamples the waveform at the new frequency using bandlimited interpolation. :cite:`RESAMPLE`.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Note:
+ ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
+ more efficient computation if resampling multiple waveforms with the same resampling parameters.
+
+ Args:
+ waveform (Tensor): The input signal of dimension `(..., time)`
+ orig_freq (int): The original frequency of the signal
+ new_freq (int): The desired frequency
+ lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
+ but less efficient. (Default: ``6``)
+ rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
+ Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
+ resampling_method (str, optional): The resampling method to use.
+ Options: [``"sinc_interp_hann"``, ``"sinc_interp_kaiser"``] (Default: ``"sinc_interp_hann"``)
+ beta (float or None, optional): The shape parameter used for kaiser window.
+
+ Returns:
+ Tensor: The waveform at the new frequency of dimension `(..., time).`
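+
+ Example
+ >>> # Illustrative sketch only: downsample one second of noise from 48 kHz to 16 kHz.
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> waveform = torch.randn(2, 48000)
+ >>> F.resample(waveform, orig_freq=48000, new_freq=16000).shape
+ torch.Size([2, 16000])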
+ """
+
+ if orig_freq <= 0.0 or new_freq <= 0.0:
+ raise ValueError("Original frequency and desired frequecy should be positive")
+
+ if orig_freq == new_freq:
+ return waveform
+
+ gcd = math.gcd(int(orig_freq), int(new_freq))
+
+ kernel, width = _get_sinc_resample_kernel(
+ orig_freq,
+ new_freq,
+ gcd,
+ lowpass_filter_width,
+ rolloff,
+ resampling_method,
+ beta,
+ waveform.device,
+ waveform.dtype,
+ )
+ resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
+ return resampled
+
+
+@torch.jit.unused
+def edit_distance(seq1: Sequence, seq2: Sequence) -> int:
+ """
+ Calculate the word level edit (Levenshtein) distance between two sequences.
+
+ .. devices:: CPU
+
+ The function computes an edit distance allowing deletion, insertion and
+ substitution. The result is an integer.
+
+ For most applications, the two input sequences should be the same type. If
+ two strings are given, the output is the edit distance between the two
+ strings (character edit distance). If two lists of strings are given, the
+ output is the edit distance between sentences (word edit distance). Users
+ may want to normalize the output by the length of the reference sequence.
+
+ Args:
+ seq1 (Sequence): the first sequence to compare.
+ seq2 (Sequence): the second sequence to compare.
+ Returns:
+ int: The distance between the first and second sequences.
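+
+ Example
+ >>> # Character edit distance for strings, word edit distance for lists of words.
+ >>> import torchaudio.functional as F
+ >>> F.edit_distance("kitten", "sitting")
+ 3
+ >>> F.edit_distance(["hello", "world"], ["hello", "there", "world"])
+ 1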
+ """
+ len_sent2 = len(seq2)
+ dold = list(range(len_sent2 + 1))
+ dnew = [0 for _ in range(len_sent2 + 1)]
+
+ for i in range(1, len(seq1) + 1):
+ dnew[0] = i
+ for j in range(1, len_sent2 + 1):
+ if seq1[i - 1] == seq2[j - 1]:
+ dnew[j] = dold[j - 1]
+ else:
+ substitution = dold[j - 1] + 1
+ insertion = dnew[j - 1] + 1
+ deletion = dold[j] + 1
+ dnew[j] = min(substitution, insertion, deletion)
+
+ dnew, dold = dold, dnew
+
+ return int(dold[-1])
+
+
+def loudness(waveform: Tensor, sample_rate: int):
+ r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ Args:
+ waveform(torch.Tensor): audio waveform of dimension `(..., channels, time)`
+ sample_rate (int): sampling rate of the waveform
+
+ Returns:
+ Tensor: loudness estimates (LKFS)
+
+ Reference:
+ - https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
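+
+ Example
+ >>> # Illustrative sketch only: one second of random stereo noise (arbitrary content).
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> waveform = torch.rand(2, 48000) - 0.5
+ >>> lkfs = F.loudness(waveform, 48000) # scalar loudness estimate in LKFS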
+ """
+
+ if waveform.size(-2) > 5:
+ raise ValueError("Only up to 5 channels are supported.")
+
+ gate_duration = 0.4
+ overlap = 0.75
+ gamma_abs = -70.0
+ kweight_bias = -0.691
+ gate_samples = int(round(gate_duration * sample_rate))
+ step = int(round(gate_samples * (1 - overlap)))
+
+ # Apply K-weighting
+ waveform = treble_biquad(waveform, sample_rate, 4.0, 1500.0, 1 / math.sqrt(2))
+ waveform = highpass_biquad(waveform, sample_rate, 38.0, 0.5)
+
+ # Compute the energy for each block
+ energy = torch.square(waveform).unfold(-1, gate_samples, step)
+ energy = torch.mean(energy, dim=-1)
+
+ # Compute channel-weighted summation
+ g = torch.tensor([1.0, 1.0, 1.0, 1.41, 1.41], dtype=waveform.dtype, device=waveform.device)
+ g = g[: energy.size(-2)]
+
+ energy_weighted = torch.sum(g.unsqueeze(-1) * energy, dim=-2)
+ loudness = kweight_bias + 10 * torch.log10(energy_weighted)
+
+ # Apply absolute gating of the blocks
+ gated_blocks = loudness > gamma_abs
+ gated_blocks = gated_blocks.unsqueeze(-2)
+
+ energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1)
+ energy_weighted = torch.sum(g * energy_filtered, dim=-1)
+ gamma_rel = kweight_bias + 10 * torch.log10(energy_weighted) - 10
+
+ # Apply relative gating of the blocks
+ gated_blocks = torch.logical_and(gated_blocks.squeeze(-2), loudness > gamma_rel.unsqueeze(-1))
+ gated_blocks = gated_blocks.unsqueeze(-2)
+
+ energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1)
+ energy_weighted = torch.sum(g * energy_filtered, dim=-1)
+ LKFS = kweight_bias + 10 * torch.log10(energy_weighted)
+ return LKFS
+
+
+def pitch_shift(
+ waveform: Tensor,
+ sample_rate: int,
+ n_steps: int,
+ bins_per_octave: int = 12,
+ n_fft: int = 512,
+ win_length: Optional[int] = None,
+ hop_length: Optional[int] = None,
+ window: Optional[Tensor] = None,
+) -> Tensor:
+ """
+ Shift the pitch of a waveform by ``n_steps`` steps.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ Args:
+ waveform (Tensor): The input waveform of shape `(..., time)`.
+ sample_rate (int): Sample rate of `waveform`.
+ n_steps (int): The number of steps to shift `waveform`, where each step
+ equals ``1 / bins_per_octave`` of an octave.
+ bins_per_octave (int, optional): The number of steps per octave (Default: ``12``).
+ n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``).
+ win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``).
+ hop_length (int or None, optional): Length of hop between STFT windows. If None, then
+ ``win_length // 4`` is used (Default: ``None``).
+ window (Tensor or None, optional): Window tensor that is applied/multiplied to each frame/window.
+ If None, then ``torch.hann_window(win_length)`` is used (Default: ``None``).
+
+ Returns:
+ Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
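+
+ Example
+ >>> # Illustrative sketch only: shift up one octave; shapes and rate are arbitrary.
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> waveform = torch.randn(1, 16000)
+ >>> F.pitch_shift(waveform, 16000, n_steps=12).shape # length is preserved
+ torch.Size([1, 16000])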
+ """
+ waveform_stretch = _stretch_waveform(
+ waveform,
+ n_steps,
+ bins_per_octave,
+ n_fft,
+ win_length,
+ hop_length,
+ window,
+ )
+ rate = 2.0 ** (-float(n_steps) / bins_per_octave)
+ waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
+
+ return _fix_waveform_shape(waveform_shift, waveform.size())
+
+
+def _stretch_waveform(
+ waveform: Tensor,
+ n_steps: int,
+ bins_per_octave: int = 12,
+ n_fft: int = 512,
+ win_length: Optional[int] = None,
+ hop_length: Optional[int] = None,
+ window: Optional[Tensor] = None,
+) -> Tensor:
+ """
+ Pitch shift helper function to preprocess and stretch waveform before resampling step.
+
+ Args:
+ See pitch_shift arg descriptions.
+
+ Returns:
+ Tensor: The preprocessed waveform stretched prior to resampling.
+ """
+ if hop_length is None:
+ hop_length = n_fft // 4
+ if win_length is None:
+ win_length = n_fft
+ if window is None:
+ window = torch.hann_window(window_length=win_length, device=waveform.device)
+
+ # pack batch
+ shape = waveform.size()
+ waveform = waveform.reshape(-1, shape[-1])
+
+ ori_len = shape[-1]
+ rate = 2.0 ** (-float(n_steps) / bins_per_octave)
+ spec_f = torch.stft(
+ input=waveform,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ center=True,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ phase_advance = torch.linspace(0, math.pi * hop_length, spec_f.shape[-2], device=spec_f.device)[..., None]
+ spec_stretch = phase_vocoder(spec_f, rate, phase_advance)
+ len_stretch = int(round(ori_len / rate))
+ waveform_stretch = torch.istft(
+ spec_stretch, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=len_stretch
+ )
+ return waveform_stretch
+
+
+def _fix_waveform_shape(
+ waveform_shift: Tensor,
+ shape: List[int],
+) -> Tensor:
+ """
+ PitchShift helper function to process after resampling step to fix the shape back.
+
+ Args:
+ waveform_shift(Tensor): The waveform after stretch and resample
+ shape (List[int]): The shape of initial waveform
+
+ Returns:
+ Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
+ """
+ ori_len = shape[-1]
+ shift_len = waveform_shift.size()[-1]
+ if shift_len > ori_len:
+ waveform_shift = waveform_shift[..., :ori_len]
+ else:
+ waveform_shift = torch.nn.functional.pad(waveform_shift, [0, ori_len - shift_len])
+
+ # unpack batch
+ waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:])
+ return waveform_shift
+
+
+def rnnt_loss(
+ logits: Tensor,
+ targets: Tensor,
+ logit_lengths: Tensor,
+ target_lengths: Tensor,
+ blank: int = -1,
+ clamp: float = -1,
+ reduction: str = "mean",
+ fused_log_softmax: bool = True,
+):
+ """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
+ :cite:`graves2012sequence`.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ The RNN Transducer loss extends the CTC loss by defining a distribution over output
+ sequences of all lengths, and by jointly modelling both input-output and output-output
+ dependencies.
+
+ Args:
+ logits (Tensor): Tensor of dimension `(batch, max seq length, max target length + 1, class)`
+ containing output from joiner
+ targets (Tensor): Tensor of dimension `(batch, max target length)` containing zero-padded targets
+ logit_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of each sequence from encoder
+ target_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of targets for each sequence
+ blank (int, optional): blank label (Default: ``-1``)
+ clamp (float, optional): clamp for gradients (Default: ``-1``)
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``"none"`` | ``"mean"`` | ``"sum"``. (Default: ``"mean"``)
+ fused_log_softmax (bool): set to False if calling log_softmax outside of loss (Default: ``True``)
+ Returns:
+ Tensor: Loss with the reduction option applied. If ``reduction`` is ``"none"``, then size `(batch)`,
+ otherwise scalar.
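+
+ Example
+ >>> # Illustrative sketch only: random joiner output for batch=2, T=10, U=5, 20 classes.
+ >>> # Assumes torchaudio was built with the RNNT loss kernel.
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> logits = torch.randn(2, 10, 6, 20, requires_grad=True)
+ >>> targets = torch.randint(0, 19, (2, 5), dtype=torch.int32)
+ >>> logit_lengths = torch.full((2,), 10, dtype=torch.int32)
+ >>> target_lengths = torch.full((2,), 5, dtype=torch.int32)
+ >>> loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths)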
+ """
+ if reduction not in ["none", "mean", "sum"]:
+ raise ValueError('reduction should be one of "none", "mean", or "sum"')
+
+ if blank < 0: # reinterpret blank index if blank < 0.
+ blank = logits.shape[-1] + blank
+
+ costs, _ = torch.ops.torchaudio.rnnt_loss(
+ logits=logits,
+ targets=targets,
+ logit_lengths=logit_lengths,
+ target_lengths=target_lengths,
+ blank=blank,
+ clamp=clamp,
+ fused_log_softmax=fused_log_softmax,
+ )
+
+ if reduction == "mean":
+ return costs.mean()
+ elif reduction == "sum":
+ return costs.sum()
+
+ return costs
+
+
+def psd(
+ specgram: Tensor,
+ mask: Optional[Tensor] = None,
+ normalize: bool = True,
+ eps: float = 1e-10,
+) -> Tensor:
+ """Compute cross-channel power spectral density (PSD) matrix.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ specgram (torch.Tensor): Multi-channel complex-valued spectrum.
+ Tensor with dimensions `(..., channel, freq, time)`.
+ mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
+ Tensor with dimensions `(..., freq, time)`. (Default: ``None``)
+ normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
+ eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-10``)
+
+ Returns:
+ torch.Tensor: The complex-valued PSD matrix of the input spectrum.
+ Tensor with dimensions `(..., freq, channel, channel)`
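+
+ Example
+ >>> # Illustrative sketch only: 4 channels, 257 bins, 100 frames (arbitrary sizes).
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> specgram = torch.randn(4, 257, 100, dtype=torch.cfloat)
+ >>> F.psd(specgram).shape
+ torch.Size([257, 4, 4])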
+ """
+ specgram = specgram.transpose(-3, -2) # shape (freq, channel, time)
+ # outer product:
+ # (..., ch_1, time) x (..., ch_2, time) -> (..., time, ch_1, ch_2)
+ psd = torch.einsum("...ct,...et->...tce", [specgram, specgram.conj()])
+
+ if mask is not None:
+ if mask.shape[:-1] != specgram.shape[:-2] or mask.shape[-1] != specgram.shape[-1]:
+ raise ValueError(
+ "The dimensions of mask except the channel dimension should be the same as specgram."
+ f"Found {mask.shape} for mask and {specgram.shape} for specgram."
+ )
+ # Normalized mask along time dimension:
+ if normalize:
+ mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
+
+ psd = psd * mask[..., None, None]
+
+ psd = psd.sum(dim=-3)
+ return psd
+
+
+def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
+ r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
+
+ Args:
+ input (torch.Tensor): Tensor with dimensions `(..., channel, channel)`.
+ dim1 (int, optional): The first dimension of the diagonal matrix.
+ (Default: ``-1``)
+ dim2 (int, optional): The second dimension of the diagonal matrix.
+ (Default: ``-2``)
+
+ Returns:
+ Tensor: The trace of the input Tensor.
+ """
+ if input.ndim < 2:
+ raise ValueError("The dimension of the tensor must be at least 2.")
+ if input.shape[dim1] != input.shape[dim2]:
+ raise ValueError("The size of ``dim1`` and ``dim2`` must be the same.")
+ input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
+ return input.sum(dim=-1)
+
+
+def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
+ """Perform Tikhonov regularization (only modifying real part).
+
+ Args:
+ mat (torch.Tensor): Input matrix with dimensions `(..., channel, channel)`.
+ reg (float, optional): Regularization factor. (Default: ``1e-7``)
+ eps (float, optional): Value to avoid the correlation matrix is all-zero. (Default: ``1e-8``)
+
+ Returns:
+ Tensor: Regularized matrix with dimensions `(..., channel, channel)`.
+ """
+ # Add eps
+ C = mat.size(-1)
+ eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
+ epsilon = _compute_mat_trace(mat).real[..., None, None] * reg
+ # in case that correlation_matrix is all-zero
+ epsilon = epsilon + eps
+ mat = mat + epsilon * eye[..., :, :]
+ return mat
+
+
+def _assert_psd_matrices(psd_s: torch.Tensor, psd_n: torch.Tensor) -> None:
+ """Assertion checks of the PSD matrices of target speech and noise.
+
+ Args:
+ psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ """
+ if psd_s.ndim < 3 or psd_n.ndim < 3:
+ raise ValueError(
+ "Expected at least 3D Tensor (..., freq, channel, channel) for psd_s and psd_n. "
+ f"Found {psd_s.shape} for psd_s and {psd_n.shape} for psd_n."
+ )
+ if not (psd_s.is_complex() and psd_n.is_complex()):
+ raise TypeError(
+ "The type of psd_s and psd_n must be ``torch.cfloat`` or ``torch.cdouble``. "
+ f"Found {psd_s.dtype} for psd_s and {psd_n.dtype} for psd_n."
+ )
+ if psd_s.shape != psd_n.shape:
+ raise ValueError(
+ f"The dimensions of psd_s and psd_n should be the same. Found {psd_s.shape} and {psd_n.shape}."
+ )
+ if psd_s.shape[-1] != psd_s.shape[-2]:
+ raise ValueError(f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}.")
+
+
+def mvdr_weights_souden(
+ psd_s: Tensor,
+ psd_n: Tensor,
+ reference_channel: Union[int, Tensor],
+ diagonal_loading: bool = True,
+ diag_eps: float = 1e-7,
+ eps: float = 1e-8,
+) -> Tensor:
+ r"""Compute the Minimum Variance Distortionless Response (*MVDR* :cite:`capon1969high`) beamforming weights
+ by the method proposed by *Souden et, al.* :cite:`souden2009optimal`.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Given the power spectral density (PSD) matrix of target speech :math:`\bf{\Phi}_{\textbf{SS}}`,
+ the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
+ reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight matrix
+ :math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:
+
+ .. math::
+ \textbf{w}_{\text{MVDR}}(f) =
+ \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
+ {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
+
+ Args:
+ psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ reference_channel (int or torch.Tensor): Specifies the reference channel.
+ If the dtype is ``int``, it represents the reference channel index.
+ If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
+ is one-hot.
+ diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+ (Default: ``True``)
+ diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+ It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+ eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+ (Default: ``1e-8``)
+
+ Returns:
+ torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
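+
+ Example
+ >>> # Illustrative sketch only: PSD matrices estimated from random spectra (arbitrary sizes).
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> psd_s = F.psd(torch.randn(4, 257, 100, dtype=torch.cfloat))
+ >>> psd_n = F.psd(torch.randn(4, 257, 100, dtype=torch.cfloat))
+ >>> F.mvdr_weights_souden(psd_s, psd_n, reference_channel=0).shape
+ torch.Size([257, 4])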
+ """
+ _assert_psd_matrices(psd_s, psd_n)
+
+ if diagonal_loading:
+ psd_n = _tik_reg(psd_n, reg=diag_eps)
+ numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
+ # ws: (..., C, C) / (...,) -> (..., C, C)
+ ws = numerator / (_compute_mat_trace(numerator)[..., None, None] + eps)
+ if torch.jit.isinstance(reference_channel, int):
+ beamform_weights = ws[..., :, reference_channel]
+ elif torch.jit.isinstance(reference_channel, Tensor):
+ reference_channel = reference_channel.to(psd_n.dtype)
+ # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
+ beamform_weights = torch.einsum("...c,...c->...", [ws, reference_channel[..., None, None, :]])
+ else:
+ raise TypeError(f'Expected "int" or "Tensor" for reference_channel. Found: {type(reference_channel)}.')
+
+ return beamform_weights
+
+
+def mvdr_weights_rtf(
+ rtf: Tensor,
+ psd_n: Tensor,
+ reference_channel: Optional[Union[int, Tensor]] = None,
+ diagonal_loading: bool = True,
+ diag_eps: float = 1e-7,
+ eps: float = 1e-8,
+) -> Tensor:
+ r"""Compute the Minimum Variance Distortionless Response (*MVDR* :cite:`capon1969high`) beamforming weights
+ based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Given the relative transfer function (RTF) matrix or the steering vector of target speech :math:`\bm{v}`,
+ the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
+ reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight matrix
+ :math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:
+
+ .. math::
+ \textbf{w}_{\text{MVDR}}(f) =
+ \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
+ {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
+
+ where :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
+
+ Args:
+ rtf (torch.Tensor): The complex-valued RTF vector of target speech.
+ Tensor with dimensions `(..., freq, channel)`.
+ psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ reference_channel (int or torch.Tensor or None, optional): Specifies the reference channel.
+ If the dtype is ``int``, it represents the reference channel index.
+ If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
+ is one-hot. If ``None``, the weights are not normalized against a reference channel.
+ (Default: ``None``)
+ diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+ (Default: ``True``)
+ diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+ It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+ eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+ (Default: ``1e-8``)
+
+ Returns:
+ torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
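+
+ Example
+ >>> # Illustrative sketch only: an EVD-based RTF estimate from random spectra (arbitrary sizes).
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> rtf = F.rtf_evd(F.psd(torch.randn(4, 257, 100, dtype=torch.cfloat)))
+ >>> psd_n = F.psd(torch.randn(4, 257, 100, dtype=torch.cfloat))
+ >>> F.mvdr_weights_rtf(rtf, psd_n, reference_channel=0).shape
+ torch.Size([257, 4])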
+ """
+ if rtf.ndim < 2:
+ raise ValueError(f"Expected at least 2D Tensor (..., freq, channel) for rtf. Found {rtf.shape}.")
+ if psd_n.ndim < 3:
+ raise ValueError(f"Expected at least 3D Tensor (..., freq, channel, channel) for psd_n. Found {psd_n.shape}.")
+ if not (rtf.is_complex() and psd_n.is_complex()):
+ raise TypeError(
+ "The type of rtf and psd_n must be ``torch.cfloat`` or ``torch.cdouble``. "
+ f"Found {rtf.dtype} for rtf and {psd_n.dtype} for psd_n."
+ )
+ if rtf.shape != psd_n.shape[:-1]:
+ raise ValueError(
+ "The dimensions of rtf and the dimensions withou the last dimension of psd_n should be the same. "
+ f"Found {rtf.shape} for rtf and {psd_n.shape} for psd_n."
+ )
+ if psd_n.shape[-1] != psd_n.shape[-2]:
+ raise ValueError(f"The last two dimensions of psd_n should be the same. Found {psd_n.shape}.")
+
+ if diagonal_loading:
+ psd_n = _tik_reg(psd_n, reg=diag_eps)
+ # numerator = psd_n.inv() @ stv
+ numerator = torch.linalg.solve(psd_n, rtf.unsqueeze(-1)).squeeze(-1) # (..., freq, channel)
+ # denominator = stv^H @ psd_n.inv() @ stv
+ denominator = torch.einsum("...d,...d->...", [rtf.conj(), numerator])
+ beamform_weights = numerator / (denominator.real.unsqueeze(-1) + eps)
+ # normalize the numerator
+ if reference_channel is not None:
+ if torch.jit.isinstance(reference_channel, int):
+ scale = rtf[..., reference_channel].conj()
+ elif torch.jit.isinstance(reference_channel, Tensor):
+ reference_channel = reference_channel.to(psd_n.dtype)
+ scale = torch.einsum("...c,...c->...", [rtf.conj(), reference_channel[..., None, :]])
+ else:
+ raise TypeError(f'Expected "int" or "Tensor" for reference_channel. Found: {type(reference_channel)}.')
+
+ beamform_weights = beamform_weights * scale[..., None]
+
+ return beamform_weights
+
+
+def rtf_evd(psd_s: Tensor) -> Tensor:
+ r"""Estimate the relative transfer function (RTF) or the steering vector by eigenvalue decomposition.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: TorchScript
+
+ Args:
+ psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+ Tensor of dimension `(..., freq, channel, channel)`
+
+ Returns:
+ Tensor: The estimated complex-valued RTF of target speech.
+ Tensor of dimension `(..., freq, channel)`
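+
+ Example
+ >>> # Illustrative sketch only: estimate an RTF from a random speech PSD (arbitrary sizes).
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> psd_s = F.psd(torch.randn(4, 257, 100, dtype=torch.cfloat))
+ >>> F.rtf_evd(psd_s).shape
+ torch.Size([257, 4])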
+ """
+ if not psd_s.is_complex():
+ raise TypeError(f"The type of psd_s must be ``torch.cfloat`` or ``torch.cdouble``. Found {psd_s.dtype}.")
+ if psd_s.shape[-1] != psd_s.shape[-2]:
+ raise ValueError(f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}.")
+ _, v = torch.linalg.eigh(psd_s) # v is sorted along with eigenvalues in ascending order
+ rtf = v[..., -1] # choose the eigenvector with max eigenvalue
+ return rtf
+
+
+def rtf_power(
+ psd_s: Tensor,
+ psd_n: Tensor,
+ reference_channel: Union[int, Tensor],
+ n_iter: int = 3,
+ diagonal_loading: bool = True,
+ diag_eps: float = 1e-7,
+) -> Tensor:
+ r"""Estimate the relative transfer function (RTF) or the steering vector by the power method.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+ Tensor with dimensions `(..., freq, channel, channel)`.
+ reference_channel (int or torch.Tensor): Specifies the reference channel.
+ If the dtype is ``int``, it represents the reference channel index.
+ If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
+ is one-hot.
+ n_iter (int, optional): The number of power iterations. (Default: ``3``)
+ diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+ (Default: ``True``)
+ diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+ It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+
+ Returns:
+ torch.Tensor: The estimated complex-valued RTF of target speech.
+ Tensor of dimension `(..., freq, channel)`.
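+
+ Example
+ >>> # Illustrative sketch only: three power iterations on random PSD estimates (arbitrary sizes).
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> psd_s = F.psd(torch.randn(4, 257, 100, dtype=torch.cfloat))
+ >>> psd_n = F.psd(torch.randn(4, 257, 100, dtype=torch.cfloat))
+ >>> F.rtf_power(psd_s, psd_n, reference_channel=0, n_iter=3).shape
+ torch.Size([257, 4])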
+ """
+ _assert_psd_matrices(psd_s, psd_n)
+ if n_iter <= 0:
+ raise ValueError("The number of iteration must be greater than 0.")
+
+ # Apply diagonal loading to psd_n to improve robustness.
+ if diagonal_loading:
+ psd_n = _tik_reg(psd_n, reg=diag_eps)
+ # phi is regarded as the first iteration
+ phi = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
+ if torch.jit.isinstance(reference_channel, int):
+ rtf = phi[..., reference_channel]
+ elif torch.jit.isinstance(reference_channel, Tensor):
+ reference_channel = reference_channel.to(psd_n.dtype)
+ rtf = torch.einsum("...c,...c->...", [phi, reference_channel[..., None, None, :]])
+ else:
+ raise TypeError(f'Expected "int" or "Tensor" for reference_channel. Found: {type(reference_channel)}.')
+ rtf = rtf.unsqueeze(-1) # (..., freq, channel, 1)
+ if n_iter >= 2:
+ # The number of iterations in the for loop is `n_iter - 2`
+ # because the `phi` above and `torch.matmul(psd_s, rtf)` are regarded as
+ # two iterations.
+ for _ in range(n_iter - 2):
+ rtf = torch.matmul(phi, rtf)
+ rtf = torch.matmul(psd_s, rtf)
+ else:
+ # with a single iteration, the rtf is psd_s[..., reference_channel],
+ # i.e., psd_n @ phi @ u
+ rtf = torch.matmul(psd_n, rtf)
+ return rtf.squeeze(-1)
+
+
+def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
+ r"""Apply the beamforming weight to the multi-channel noisy spectrum to obtain the single-channel enhanced spectrum.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ .. math::
+ \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
+
+ where :math:`\textbf{w}_{\text{bf}}(f)` is the beamforming weight for the :math:`f`-th frequency bin,
+ :math:`\textbf{Y}` is the multi-channel spectrum for the :math:`f`-th frequency bin.
+
+ Args:
+ beamform_weights (Tensor): The complex-valued beamforming weight matrix.
+ Tensor of dimension `(..., freq, channel)`
+ specgram (Tensor): The multi-channel complex-valued noisy spectrum.
+ Tensor of dimension `(..., channel, freq, time)`
+
+ Returns:
+ Tensor: The single-channel complex-valued enhanced spectrum.
+ Tensor of dimension `(..., freq, time)`
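+
+ Example
+ >>> # Illustrative sketch only: random weights and spectrum (arbitrary sizes); in
+ >>> # practice the weights would come from e.g. mvdr_weights_souden.
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> weights = torch.randn(257, 4, dtype=torch.cfloat)
+ >>> specgram = torch.randn(4, 257, 100, dtype=torch.cfloat)
+ >>> F.apply_beamforming(weights, specgram).shape
+ torch.Size([257, 100])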
+ """
+ if beamform_weights.shape[:-2] != specgram.shape[:-3]:
+ raise ValueError(
+ "The dimensions except the last two dimensions of beamform_weights should be the same "
+ "as the dimensions except the last three dimensions of specgram. "
+ f"Found {beamform_weights.shape} for beamform_weights and {specgram.shape} for specgram."
+ )
+
+ if not (beamform_weights.is_complex() and specgram.is_complex()):
+ raise TypeError(
+ "The type of beamform_weights and specgram must be ``torch.cfloat`` or ``torch.cdouble``. "
+ f"Found {beamform_weights.dtype} for beamform_weights and {specgram.dtype} for specgram."
+ )
+
+ # (..., freq, channel) x (..., channel, freq, time) -> (..., freq, time)
+ specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_weights.conj(), specgram])
+ return specgram_enhanced
+
+
+def _check_shape_compatible(x: torch.Tensor, y: torch.Tensor) -> None:
+ if x.ndim != y.ndim:
+ raise ValueError(f"The operands must be the same dimension (got {x.ndim} and {y.ndim}).")
+
+ for i in range(x.ndim - 1):
+ xi = x.size(i)
+ yi = y.size(i)
+ if xi == yi or xi == 1 or yi == 1:
+ continue
+ raise ValueError(f"Leading dimensions of x and y are not broadcastable (got {x.shape} and {y.shape}).")
+
+
+def _check_convolve_mode(mode: str) -> None:
+ valid_convolve_modes = ["full", "valid", "same"]
+ if mode not in valid_convolve_modes:
+ raise ValueError(f"Unrecognized mode value '{mode}'. Please specify one of {valid_convolve_modes}.")
+
+
+def _apply_convolve_mode(conv_result: torch.Tensor, x_length: int, y_length: int, mode: str) -> torch.Tensor:
+ valid_convolve_modes = ["full", "valid", "same"]
+ if mode == "full":
+ return conv_result
+ elif mode == "valid":
+ target_length = max(x_length, y_length) - min(x_length, y_length) + 1
+ start_idx = (conv_result.size(-1) - target_length) // 2
+ return conv_result[..., start_idx : start_idx + target_length]
+ elif mode == "same":
+ start_idx = (conv_result.size(-1) - x_length) // 2
+ return conv_result[..., start_idx : start_idx + x_length]
+ else:
+ raise ValueError(f"Unrecognized mode value '{mode}'. Please specify one of {valid_convolve_modes}.")
+
+
+def fftconvolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tensor:
+ r"""
+ Convolves inputs along their last dimension using FFT. For inputs with large last dimensions, this function
+ is generally much faster than :meth:`convolve`.
+ Note that, in contrast to :meth:`torch.nn.functional.conv1d`, which actually applies the valid cross-correlation
+ operator, this function applies the true `convolution`_ operator.
+ Also note that this function can only output float tensors (int tensor inputs will be cast to float).
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ x (torch.Tensor): First convolution operand, with shape `(..., N)`.
+ y (torch.Tensor): Second convolution operand, with shape `(..., M)`
+ (leading dimensions must be broadcast-able with those of ``x``).
+ mode (str, optional): Must be one of ("full", "valid", "same").
+
+ * "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default)
+ * "valid": Returns the segment of the full convolution result corresponding to where
+ the two inputs overlap completely, with shape `(..., max(N, M) - min(N, M) + 1)`.
+ * "same": Returns the center segment of the full convolution result, with shape `(..., N)`.
+
+ Returns:
+ torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
+ the leading dimensions match those of ``x`` and `L` is dictated by ``mode``.
+
+ .. _convolution:
+ https://en.wikipedia.org/wiki/Convolution
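+
+ Example
+ >>> # Illustrative sketch only: full convolution of two random signals (arbitrary lengths).
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> x, y = torch.randn(1, 1024), torch.randn(1, 128)
+ >>> F.fftconvolve(x, y).shape # N + M - 1
+ torch.Size([1, 1151])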
+ """
+ _check_shape_compatible(x, y)
+ _check_convolve_mode(mode)
+
+ n = x.size(-1) + y.size(-1) - 1
+ fresult = torch.fft.rfft(x, n=n) * torch.fft.rfft(y, n=n)
+ result = torch.fft.irfft(fresult, n=n)
+ return _apply_convolve_mode(result, x.size(-1), y.size(-1), mode)
+
+
+def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tensor:
+ r"""
+ Convolves inputs along their last dimension using the direct method.
+ Note that, in contrast to :meth:`torch.nn.functional.conv1d`, which actually applies the valid cross-correlation
+ operator, this function applies the true `convolution`_ operator.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ x (torch.Tensor): First convolution operand, with shape `(..., N)`.
+ y (torch.Tensor): Second convolution operand, with shape `(..., M)`
+ (leading dimensions must be broadcast-able with those of ``x``).
+ mode (str, optional): Must be one of ("full", "valid", "same").
+
+ * "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default)
+ * "valid": Returns the segment of the full convolution result corresponding to where
+ the two inputs overlap completely, with shape `(..., max(N, M) - min(N, M) + 1)`.
+ * "same": Returns the center segment of the full convolution result, with shape `(..., N)`.
+
+ Returns:
+ torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
+ the leading dimensions match those of ``x`` and `L` is dictated by ``mode``.
+
+ .. _convolution:
+ https://en.wikipedia.org/wiki/Convolution
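+
+ Example
+ >>> # Illustrative sketch only: "same" mode keeps the length of the first operand.
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> x, y = torch.randn(1, 1000), torch.randn(1, 15)
+ >>> F.convolve(x, y, mode="same").shape
+ torch.Size([1, 1000])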
+ """
+ _check_shape_compatible(x, y)
+ _check_convolve_mode(mode)
+
+ x_size, y_size = x.size(-1), y.size(-1)
+
+ if x.size(-1) < y.size(-1):
+ x, y = y, x
+
+ if x.shape[:-1] != y.shape[:-1]:
+ new_shape = [max(i, j) for i, j in zip(x.shape[:-1], y.shape[:-1])]
+ x = x.broadcast_to(new_shape + [x.shape[-1]])
+ y = y.broadcast_to(new_shape + [y.shape[-1]])
+
+ num_signals = torch.tensor(x.shape[:-1]).prod()
+ reshaped_x = x.reshape((int(num_signals), x.size(-1)))
+ reshaped_y = y.reshape((int(num_signals), y.size(-1)))
+ output = torch.nn.functional.conv1d(
+ input=reshaped_x,
+ weight=reshaped_y.flip(-1).unsqueeze(1),
+ stride=1,
+ groups=reshaped_x.size(0),
+ padding=reshaped_y.size(-1) - 1,
+ )
+ output_shape = x.shape[:-1] + (-1,)
+ result = output.reshape(output_shape)
+ return _apply_convolve_mode(result, x_size, y_size, mode)
+
+
+def add_noise(
+ waveform: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor, lengths: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+ r"""Scales and adds noise to waveform per signal-to-noise ratio.
+
+ Specifically, for each pair of waveform vector :math:`x \in \mathbb{R}^L` and noise vector
+ :math:`n \in \mathbb{R}^L`, the function computes output :math:`y` as
+
+ .. math::
+ y = x + a n \, \text{,}
+
+ where
+
+ .. math::
+ a = \sqrt{ \frac{ ||x||_{2}^{2} }{ ||n||_{2}^{2} } \cdot 10^{-\frac{\text{SNR}}{10}} } \, \text{,}
+
+ with :math:`\text{SNR}` being the desired signal-to-noise ratio between :math:`x` and :math:`n`, in dB.
+
+ Note that this function broadcasts singleton leading dimensions in its inputs in a manner that is
+ consistent with the above formulae and PyTorch's broadcasting semantics.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (torch.Tensor): Input waveform, with shape `(..., L)`.
+ noise (torch.Tensor): Noise, with shape `(..., L)` (same shape as ``waveform``).
+ snr (torch.Tensor): Signal-to-noise ratios in dB, with shape `(...,)`.
+ lengths (torch.Tensor or None, optional): Valid lengths of signals in ``waveform`` and ``noise``, with shape
+ `(...,)` (leading dimensions must match those of ``waveform``). If ``None``, all elements in ``waveform``
+ and ``noise`` are treated as valid. (Default: ``None``)
+
+ Returns:
+ torch.Tensor: Result of scaling and adding ``noise`` to ``waveform``, with shape `(..., L)`
+ (same shape as ``waveform``).
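+
+ Example
+ >>> # Illustrative sketch only: mix random noise into a random signal at 10 dB SNR.
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> waveform = torch.randn(1, 16000)
+ >>> noise = torch.randn(1, 16000)
+ >>> snr = torch.tensor([10.0])
+ >>> noisy = F.add_noise(waveform, noise, snr)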
+ """
+
+ if not (waveform.ndim - 1 == noise.ndim - 1 == snr.ndim and (lengths is None or lengths.ndim == snr.ndim)):
+ raise ValueError("Input leading dimensions don't match.")
+
+ L = waveform.size(-1)
+
+ if L != noise.size(-1):
+ raise ValueError(f"Length dimensions of waveform and noise don't match (got {L} and {noise.size(-1)}).")
+
+ # compute scale
+ if lengths is not None:
+ mask = torch.arange(0, L, device=lengths.device).expand(waveform.shape) < lengths.unsqueeze(
+ -1
+ ) # (*, L) < (*, 1) = (*, L)
+ masked_waveform = waveform * mask
+ masked_noise = noise * mask
+ else:
+ masked_waveform = waveform
+ masked_noise = noise
+
+ energy_signal = torch.linalg.vector_norm(masked_waveform, ord=2, dim=-1) ** 2 # (*,)
+ energy_noise = torch.linalg.vector_norm(masked_noise, ord=2, dim=-1) ** 2 # (*,)
+ original_snr_db = 10 * (torch.log10(energy_signal) - torch.log10(energy_noise))
+ scale = 10 ** ((original_snr_db - snr) / 20.0) # (*,)
+
+ # scale noise
+ scaled_noise = scale.unsqueeze(-1) * noise # (*, 1) * (*, L) = (*, L)
+
+ return waveform + scaled_noise # (*, L)
+
+
+def speed(
+ waveform: torch.Tensor, orig_freq: int, factor: float, lengths: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ r"""Adjusts waveform speed.
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (torch.Tensor): Input signals, with shape `(..., time)`.
+ orig_freq (int): Original frequency of the signals in ``waveform``.
+ factor (float): Factor by which to adjust speed of input. Values greater than 1.0
+ compress ``waveform`` in time, whereas values less than 1.0 stretch ``waveform`` in time.
+ lengths (torch.Tensor or None, optional): Valid lengths of signals in ``waveform``, with shape `(...)`.
+ If ``None``, all elements in ``waveform`` are treated as valid. (Default: ``None``)
+
+ Returns:
+ (torch.Tensor, torch.Tensor or None):
+ torch.Tensor
+ Speed-adjusted waveform, with shape `(..., new_time).`
+ torch.Tensor or None
+ If ``lengths`` is not ``None``, valid lengths of signals in speed-adjusted waveform,
+ with shape `(...)`; otherwise, ``None``.
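+
+ Example
+ >>> # Illustrative sketch only: play a random one-second signal 10% faster.
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> waveform = torch.randn(1, 16000)
+ >>> faster, _ = F.speed(waveform, orig_freq=16000, factor=1.1) # ~10% shorter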
+ """
+
+ source_sample_rate = int(factor * orig_freq)
+ target_sample_rate = int(orig_freq)
+
+ gcd = math.gcd(source_sample_rate, target_sample_rate)
+ source_sample_rate = source_sample_rate // gcd
+ target_sample_rate = target_sample_rate // gcd
+
+ if lengths is None:
+ out_lengths = None
+ else:
+ out_lengths = torch.ceil(lengths * target_sample_rate / source_sample_rate).to(lengths.dtype)
+
+ return resample(waveform, source_sample_rate, target_sample_rate), out_lengths
+
+
+def preemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
+ r"""Pre-emphasizes a waveform along its last dimension, i.e.
+ for each signal :math:`x` in ``waveform``, computes
+ output :math:`y` as
+
+ .. math::
+ y[i] = x[i] - \text{coeff} \cdot x[i - 1]
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (torch.Tensor): Waveform, with shape `(..., N)`.
+ coeff (float, optional): Pre-emphasis coefficient. Typically between 0.0 and 1.0.
+ (Default: 0.97)
+
+ Returns:
+ torch.Tensor: Pre-emphasized waveform, with shape `(..., N)`.
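+
+ Example
+ >>> # With a constant input, each later sample shrinks to (1 - coeff).
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> waveform = torch.tensor([1.0, 1.0, 1.0])
+ >>> F.preemphasis(waveform, coeff=0.97)
+ tensor([1.0000, 0.0300, 0.0300])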
+ """
+ waveform = waveform.clone()
+ waveform[..., 1:] -= coeff * waveform[..., :-1]
+ return waveform
+
+
+def deemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
+ r"""De-emphasizes a waveform along its last dimension.
+ Inverse of :meth:`preemphasis`. Concretely, for each signal
+ :math:`x` in ``waveform``, computes output :math:`y` as
+
+ .. math::
+ y[i] = x[i] + \text{coeff} \cdot y[i - 1]
+
+ .. devices:: CPU CUDA
+
+ .. properties:: Autograd TorchScript
+
+ Args:
+ waveform (torch.Tensor): Waveform, with shape `(..., N)`.
+ coeff (float, optional): De-emphasis coefficient. Typically between 0.0 and 1.0.
+ (Default: 0.97)
+
+ Returns:
+ torch.Tensor: De-emphasized waveform, with shape `(..., N)`.
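+
+ Example
+ >>> # De-emphasis inverts pre-emphasis (up to floating-point error).
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> waveform = torch.randn(1, 100)
+ >>> roundtrip = F.deemphasis(F.preemphasis(waveform)) # ~= waveform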
+ """
+ a_coeffs = torch.tensor([1.0, -coeff], dtype=waveform.dtype, device=waveform.device)
+ b_coeffs = torch.tensor([1.0, 0.0], dtype=waveform.dtype, device=waveform.device)
+ return torchaudio.functional.lfilter(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
+
+
+def frechet_distance(mu_x, sigma_x, mu_y, sigma_y):
+ r"""Computes the Fréchet distance between two multivariate normal distributions :cite:`dowson1982frechet`.
+
+ Concretely, for multivariate Gaussians :math:`X(\mu_X, \Sigma_X)`
+ and :math:`Y(\mu_Y, \Sigma_Y)`, the function computes and returns :math:`F` as
+
+ .. math::
+ F(X, Y) = || \mu_X - \mu_Y ||_2^2
+ + \text{Tr}\left( \Sigma_X + \Sigma_Y - 2 \sqrt{\Sigma_X \Sigma_Y} \right)
+
+ Args:
+ mu_x (torch.Tensor): mean :math:`\mu_X` of multivariate Gaussian :math:`X`, with shape `(N,)`.
+ sigma_x (torch.Tensor): covariance matrix :math:`\Sigma_X` of :math:`X`, with shape `(N, N)`.
+ mu_y (torch.Tensor): mean :math:`\mu_Y` of multivariate Gaussian :math:`Y`, with shape `(N,)`.
+ sigma_y (torch.Tensor): covariance matrix :math:`\Sigma_Y` of :math:`Y`, with shape `(N, N)`.
+
+ Returns:
+ torch.Tensor: the Fréchet distance between :math:`X` and :math:`Y`.
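+
+ Example
+ >>> # Illustrative sketch only: two isotropic Gaussians in R^8 (arbitrary parameters).
+ >>> import torch
+ >>> import torchaudio.functional as F
+ >>> mu_x, mu_y = torch.zeros(8), torch.ones(8)
+ >>> sigma_x, sigma_y = torch.eye(8), 2 * torch.eye(8)
+ >>> fd = F.frechet_distance(mu_x, sigma_x, mu_y, sigma_y) # ~= 9.37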
+ """
+ if len(mu_x.size()) != 1:
+ raise ValueError(f"Input mu_x must be one-dimensional; got dimension {len(mu_x.size())}.")
+ if len(sigma_x.size()) != 2:
+ raise ValueError(f"Input sigma_x must be two-dimensional; got dimension {len(sigma_x.size())}.")
+ if not (sigma_x.size(0) == sigma_x.size(1) == mu_x.size(0)):
+ raise ValueError("Each of sigma_x's dimensions must match mu_x's size.")
+ if mu_x.size() != mu_y.size():
+ raise ValueError(f"Inputs mu_x and mu_y must have the same shape; got {mu_x.size()} and {mu_y.size()}.")
+ if sigma_x.size() != sigma_y.size():
+ raise ValueError(
+ f"Inputs sigma_x and sigma_y must have the same shape; got {sigma_x.size()} and {sigma_y.size()}."
+ )
+
+ a = (mu_x - mu_y).square().sum()
+ b = sigma_x.trace() + sigma_y.trace()
+ c = torch.linalg.eigvals(sigma_x @ sigma_y).sqrt().real.sum()
+ return a + b - 2 * c
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/kaldi_io.py b/.venv/lib/python3.11/site-packages/torchaudio/kaldi_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d372429dcfd23fdcbe8cd0f38abef1086d8a5eb
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/kaldi_io.py
@@ -0,0 +1,144 @@
+# To use this file, the dependency (https://github.com/vesis84/kaldi-io-for-python)
+# needs to be installed. This is a light wrapper around kaldi_io that returns
+# torch.Tensors.
+from typing import Any, Callable, Iterable, Tuple
+
+import torch
+from torch import Tensor
+from torchaudio._internal import module_utils as _mod_utils
+
+if _mod_utils.is_module_available("numpy"):
+ import numpy as np
+
+
+__all__ = [
+ "read_vec_int_ark",
+ "read_vec_flt_scp",
+ "read_vec_flt_ark",
+ "read_mat_scp",
+ "read_mat_ark",
+]
+
+
+def _convert_method_output_to_tensor(
+ file_or_fd: Any, fn: Callable, convert_contiguous: bool = False
+) -> Iterable[Tuple[str, Tensor]]:
+ r"""Takes a method invokes it. The output is converted to a tensor.
+
+ Args:
+ file_or_fd (str/FileDescriptor): File name or file descriptor
+ fn (Callable): Function that has the signature (file name/descriptor) and converts it to
+ Iterable[Tuple[str, Tensor]].
+ convert_contiguous (bool, optional): Determines whether the array should be converted into a
+ contiguous layout. (Default: ``False``)
+
+ Returns:
+ Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is vec/mat
+ """
+ for key, np_arr in fn(file_or_fd):
+ if convert_contiguous:
+ np_arr = np.ascontiguousarray(np_arr)
+ yield key, torch.from_numpy(np_arr)
+
+
+@_mod_utils.requires_module("kaldi_io", "numpy")
+def read_vec_int_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
+ r"""Create generator of (key,vector) tuples, which reads from the ark file/stream.
+
+ Args:
+ file_or_fd (str/FileDescriptor): ark, gzipped ark, pipe or opened file descriptor
+
+ Returns:
+ Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is the vector read from file
+
+ Example
+ >>> # read ark to a 'dictionary'
+ >>> d = { u:d for u,d in torchaudio.kaldi_io.read_vec_int_ark(file) }
+ """
+
+ import kaldi_io
+
+ # Requires convert_contiguous to be True because elements from int32 vector are
+ # stored in tuples: (sizeof(int32), value), so strides are (5,) instead of (4,), which will throw an error
+ # in from_numpy as it expects strides to be a multiple of 4 (int32).
+ return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_int_ark, convert_contiguous=True)
+
+
+@_mod_utils.requires_module("kaldi_io", "numpy")
+def read_vec_flt_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
+ r"""Create generator of (key,vector) tuples, read according to Kaldi scp.
+
+ Args:
+ file_or_fd (str/FileDescriptor): scp, gzipped scp, pipe or opened file descriptor
+
+ Returns:
+ Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is the vector read from file
+
+ Example
+ >>> # read scp to a 'dictionary'
+ >>> d = { u:d for u,d in torchaudio.kaldi_io.read_vec_flt_scp(file) }
+ """
+
+ import kaldi_io
+
+ return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_flt_scp)
+
+
+@_mod_utils.requires_module("kaldi_io", "numpy")
+def read_vec_flt_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
+ r"""Create generator of (key,vector) tuples, which reads from the ark file/stream.
+
+ Args:
+ file_or_fd (str/FileDescriptor): ark, gzipped ark, pipe or opened file descriptor
+
+ Returns:
+ Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is the vector read from file
+
+ Example
+ >>> # read ark to a 'dictionary'
+ >>> d = { u:d for u,d in torchaudio.kaldi_io.read_vec_flt_ark(file) }
+ """
+
+ import kaldi_io
+
+ return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_flt_ark)
+
+
+@_mod_utils.requires_module("kaldi_io", "numpy")
+def read_mat_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
+ r"""Create generator of (key,matrix) tuples, read according to Kaldi scp.
+
+ Args:
+ file_or_fd (str/FileDescriptor): scp, gzipped scp, pipe or opened file descriptor
+
+ Returns:
+ Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is the matrix read from file
+
+ Example
+ >>> # read scp to a 'dictionary'
+ >>> d = { u:d for u,d in torchaudio.kaldi_io.read_mat_scp(file) }
+ """
+
+ import kaldi_io
+
+ return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_mat_scp)
+
+
+@_mod_utils.requires_module("kaldi_io", "numpy")
+def read_mat_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
+ r"""Create generator of (key,matrix) tuples, which reads from the ark file/stream.
+
+ Args:
+ file_or_fd (str/FileDescriptor): ark, gzipped ark, pipe or opened file descriptor
+
+ Returns:
+ Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is the matrix read from file
+
+ Example
+ >>> # read ark to a 'dictionary'
+ >>> d = { u:d for u,d in torchaudio.kaldi_io.read_mat_ark(file) }
+ """
+
+ import kaldi_io
+
+ return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_mat_ark)
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..efec1f3521e760803e095efb71f164ed268896f1
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__init__.py
@@ -0,0 +1,102 @@
+from ._source_separation_pipeline import (
+ CONVTASNET_BASE_LIBRI2MIX,
+ HDEMUCS_HIGH_MUSDB,
+ HDEMUCS_HIGH_MUSDB_PLUS,
+ SourceSeparationBundle,
+)
+from ._squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle
+from ._tts import (
+ TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
+ TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
+ TACOTRON2_WAVERNN_CHAR_LJSPEECH,
+ TACOTRON2_WAVERNN_PHONE_LJSPEECH,
+ Tacotron2TTSBundle,
+)
+from ._wav2vec2.impl import (
+ HUBERT_ASR_LARGE,
+ HUBERT_ASR_XLARGE,
+ HUBERT_BASE,
+ HUBERT_LARGE,
+ HUBERT_XLARGE,
+ MMS_FA,
+ VOXPOPULI_ASR_BASE_10K_DE,
+ VOXPOPULI_ASR_BASE_10K_EN,
+ VOXPOPULI_ASR_BASE_10K_ES,
+ VOXPOPULI_ASR_BASE_10K_FR,
+ VOXPOPULI_ASR_BASE_10K_IT,
+ WAV2VEC2_ASR_BASE_100H,
+ WAV2VEC2_ASR_BASE_10M,
+ WAV2VEC2_ASR_BASE_960H,
+ WAV2VEC2_ASR_LARGE_100H,
+ WAV2VEC2_ASR_LARGE_10M,
+ WAV2VEC2_ASR_LARGE_960H,
+ WAV2VEC2_ASR_LARGE_LV60K_100H,
+ WAV2VEC2_ASR_LARGE_LV60K_10M,
+ WAV2VEC2_ASR_LARGE_LV60K_960H,
+ WAV2VEC2_BASE,
+ WAV2VEC2_LARGE,
+ WAV2VEC2_LARGE_LV60K,
+ WAV2VEC2_XLSR53,
+ WAV2VEC2_XLSR_1B,
+ WAV2VEC2_XLSR_2B,
+ WAV2VEC2_XLSR_300M,
+ Wav2Vec2ASRBundle,
+ Wav2Vec2Bundle,
+ Wav2Vec2FABundle,
+ WAVLM_BASE,
+ WAVLM_BASE_PLUS,
+ WAVLM_LARGE,
+)
+from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
+
+
+__all__ = [
+ "Wav2Vec2Bundle",
+ "Wav2Vec2ASRBundle",
+ "Wav2Vec2FABundle",
+ "WAV2VEC2_BASE",
+ "WAV2VEC2_LARGE",
+ "WAV2VEC2_LARGE_LV60K",
+ "WAV2VEC2_ASR_BASE_10M",
+ "WAV2VEC2_ASR_BASE_100H",
+ "WAV2VEC2_ASR_BASE_960H",
+ "WAV2VEC2_ASR_LARGE_10M",
+ "WAV2VEC2_ASR_LARGE_100H",
+ "WAV2VEC2_ASR_LARGE_960H",
+ "WAV2VEC2_ASR_LARGE_LV60K_10M",
+ "WAV2VEC2_ASR_LARGE_LV60K_100H",
+ "WAV2VEC2_ASR_LARGE_LV60K_960H",
+ "WAV2VEC2_XLSR53",
+ "WAV2VEC2_XLSR_300M",
+ "WAV2VEC2_XLSR_1B",
+ "WAV2VEC2_XLSR_2B",
+ "VOXPOPULI_ASR_BASE_10K_EN",
+ "VOXPOPULI_ASR_BASE_10K_ES",
+ "VOXPOPULI_ASR_BASE_10K_DE",
+ "VOXPOPULI_ASR_BASE_10K_FR",
+ "VOXPOPULI_ASR_BASE_10K_IT",
+ "HUBERT_BASE",
+ "HUBERT_LARGE",
+ "HUBERT_XLARGE",
+ "HUBERT_ASR_LARGE",
+ "HUBERT_ASR_XLARGE",
+ "MMS_FA",
+ "WAVLM_BASE",
+ "WAVLM_BASE_PLUS",
+ "WAVLM_LARGE",
+ "Tacotron2TTSBundle",
+ "TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH",
+ "TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
+ "TACOTRON2_WAVERNN_CHAR_LJSPEECH",
+ "TACOTRON2_WAVERNN_PHONE_LJSPEECH",
+ "RNNTBundle",
+ "EMFORMER_RNNT_BASE_LIBRISPEECH",
+ "SourceSeparationBundle",
+ "CONVTASNET_BASE_LIBRI2MIX",
+ "HDEMUCS_HIGH_MUSDB_PLUS",
+ "HDEMUCS_HIGH_MUSDB",
+ "SQUIM_OBJECTIVE",
+ "SQUIM_SUBJECTIVE",
+ "SquimObjectiveBundle",
+ "SquimSubjectiveBundle",
+]
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ab43fa5f988fb96d82cec03edae9e6bd789bdce
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__pycache__/_source_separation_pipeline.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__pycache__/_source_separation_pipeline.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5a2c6b245dbd2bfe446d3e77ae9ce597235dc94
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__pycache__/_source_separation_pipeline.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__pycache__/_squim_pipeline.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__pycache__/_squim_pipeline.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5db23f266ed2ed6055eb35830f7c89ec2908eb4
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__pycache__/_squim_pipeline.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__pycache__/rnnt_pipeline.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__pycache__/rnnt_pipeline.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6727d90e450636686d5df9788d1ebce90c6b25d
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/__pycache__/rnnt_pipeline.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_source_separation_pipeline.py b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_source_separation_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae92e21831307f91450b32a73563c0011e455753
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_source_separation_pipeline.py
@@ -0,0 +1,109 @@
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable
+
+import torch
+import torchaudio
+
+from torchaudio.models import conv_tasnet_base, hdemucs_high
+
+
+@dataclass
+class SourceSeparationBundle:
+ """Dataclass that bundles components for performing source separation.
+
+ Example
+ >>> import torchaudio
+ >>> from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX
+ >>> import torch
+ >>>
+ >>> # Build the separation model.
+ >>> model = CONVTASNET_BASE_LIBRI2MIX.get_model()
+ 100%|███████████████████████████████|19.1M/19.1M [00:04<00:00, 4.93MB/s]
+ >>>
+ >>> # Instantiate the test set of Libri2Mix dataset.
+ >>> dataset = torchaudio.datasets.LibriMix("/home/datasets/", subset="test")
+ >>>
+ >>> # Apply source separation on mixture audio.
+ >>> for i, data in enumerate(dataset):
+ >>>     sample_rate, mixture, clean_sources = data
+ >>>     # Make sure the shape of the input suits the model requirement.
+ >>>     mixture = mixture.reshape(1, 1, -1)
+ >>>     estimated_sources = model(mixture)
+ >>>     score = si_snr_pit(estimated_sources, clean_sources)  # for demonstration
+ >>>     print(f"Si-SNR score is : {score}.")
+ >>>     break
+ Si-SNR score is : 16.24.
+ >>>
+ """
+
+ _model_path: str
+ _model_factory_func: Callable[[], torch.nn.Module]
+ _sample_rate: int
+
+ @property
+ def sample_rate(self) -> int:
+ """Sample rate of the audio that the model is trained on.
+
+ :type: int
+ """
+ return self._sample_rate
+
+ def get_model(self) -> torch.nn.Module:
+ """Construct the model and load the pretrained weight."""
+ model = self._model_factory_func()
+ path = torchaudio.utils.download_asset(self._model_path)
+ state_dict = torch.load(path)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+
+CONVTASNET_BASE_LIBRI2MIX = SourceSeparationBundle(
+ _model_path="models/conv_tasnet_base_libri2mix.pt",
+ _model_factory_func=partial(conv_tasnet_base, num_sources=2),
+ _sample_rate=8000,
+)
+CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained Source Separation pipeline with *ConvTasNet*
+:cite:`Luo_2019` trained on *Libri2Mix dataset* :cite:`cosentino2020librimix`.
+
+The source separation model is constructed by :func:`~torchaudio.models.conv_tasnet_base`
+and is trained using the training script ``lightning_train.py``
+`here `__
+with default arguments.
+
+Please refer to :class:`SourceSeparationBundle` for usage instructions.
+"""
+
+
+HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle(
+ _model_path="models/hdemucs_high_trained.pt",
+ _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
+ _sample_rate=44100,
+)
+HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained music source separation pipeline with
+*Hybrid Demucs* :cite:`defossez2021hybrid` trained on both training and test sets of
+MUSDB-HQ :cite:`MUSDB18HQ` and an additional 150 extra songs from an internal database
+that was specifically produced for Meta.
+
+The model is constructed by :func:`~torchaudio.models.hdemucs_high`.
+
+Training was performed in the original HDemucs repository `here `__.
+
+Please refer to :class:`SourceSeparationBundle` for usage instructions.
+"""
+
+
+HDEMUCS_HIGH_MUSDB = SourceSeparationBundle(
+ _model_path="models/hdemucs_high_musdbhq_only.pt",
+ _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
+ _sample_rate=44100,
+)
+HDEMUCS_HIGH_MUSDB.__doc__ = """Pre-trained music source separation pipeline with
+*Hybrid Demucs* :cite:`defossez2021hybrid` trained on the training set of MUSDB-HQ :cite:`MUSDB18HQ`.
+
+The model is constructed by :func:`~torchaudio.models.hdemucs_high`.
+Training was performed in the original HDemucs repository `here `__.
+
+Please refer to :class:`SourceSeparationBundle` for usage instructions.
+"""
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_squim_pipeline.py b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_squim_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c70db4aef70397d33dcb3d3b28131221cef52c8
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_squim_pipeline.py
@@ -0,0 +1,156 @@
+from dataclasses import dataclass
+
+import torch
+import torchaudio
+
+from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
+
+
+@dataclass
+class SquimObjectiveBundle:
+ """Data class that bundles associated information to use pretrained
+ :py:class:`~torchaudio.models.SquimObjective` model.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ Torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+ This bundle can estimate objective metric scores for speech enhancement, such as STOI, PESQ, Si-SDR.
+ A typical use case would be a flow like `waveform -> list of scores`. Please see below for the code example.
+
+ Example: Estimate the objective metric scores for the input waveform.
+ >>> import torch
+ >>> import torchaudio
+ >>> from torchaudio.pipelines import SQUIM_OBJECTIVE as bundle
+ >>>
+ >>> # Load the SquimObjective bundle
+ >>> model = bundle.get_model()
+ Downloading: "https://download.pytorch.org/torchaudio/models/squim_objective_dns2020.pth"
+ 100%|████████████| 28.2M/28.2M [00:03<00:00, 9.24MB/s]
+ >>>
+ >>> # Resample audio to the expected sampling rate
+ >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+ >>>
+ >>> # Estimate objective metric scores
+ >>> scores = model(waveform)
+ >>> print(f"STOI: {scores[0].item()}, PESQ: {scores[1].item()}, SI-SDR: {scores[2].item()}.")
+ """ # noqa: E501
+
+ _path: str
+ _sample_rate: float
+
+ def get_model(self) -> SquimObjective:
+ """Construct the SquimObjective model, and load the pretrained weight.
+
+ Returns:
+ Variation of :py:class:`~torchaudio.models.SquimObjective`.
+ """
+ model = squim_objective_base()
+ path = torchaudio.utils.download_asset(f"models/{self._path}")
+ state_dict = torch.load(path, weights_only=True)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ @property
+ def sample_rate(self):
+ """Sample rate of the audio that the model is trained on.
+
+ :type: float
+ """
+ return self._sample_rate
+
+
+SQUIM_OBJECTIVE = SquimObjectiveBundle(
+ "squim_objective_dns2020.pth",
+ _sample_rate=16000,
+)
+SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
+ :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
+
+ The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
+ The weights are under `Creative Commons Attribution 4.0 International License
+ `__.
+
+ Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
+ """
+
+
+@dataclass
+class SquimSubjectiveBundle:
+ """Data class that bundles associated information to use pretrained
+ :py:class:`~torchaudio.models.SquimSubjective` model.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ Torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+ This bundle can estimate subjective metric scores for speech enhancement, such as MOS.
+ A typical use case would be a flow like `waveform -> score`. Please see below for the code example.
+
+ Example: Estimate the subjective metric scores for the input waveform.
+ >>> import torch
+ >>> import torchaudio
+ >>> from torchaudio.pipelines import SQUIM_SUBJECTIVE as bundle
+ >>>
+ >>> # Load the SquimSubjective bundle
+ >>> model = bundle.get_model()
+ Downloading: "https://download.pytorch.org/torchaudio/models/squim_subjective_bvcc_daps.pth"
+ 100%|████████████| 360M/360M [00:09<00:00, 41.1MB/s]
+ >>>
+ >>> # Resample audio to the expected sampling rate
+ >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+ >>> # Use a clean reference (doesn't need to be the reference for the waveform) as the second input
+ >>> reference = torchaudio.functional.resample(reference, sample_rate, bundle.sample_rate)
+ >>>
+ >>> # Estimate subjective metric scores
+ >>> score = model(waveform, reference)
+ >>> print(f"MOS: {score}.")
+ """ # noqa: E501
+
+ _path: str
+ _sample_rate: float
+
+ def get_model(self) -> SquimSubjective:
+ """Construct the SquimSubjective model, and load the pretrained weight.
+ Returns:
+ Variation of :py:class:`~torchaudio.models.SquimObjective`.
+ """
+ model = squim_subjective_base()
+ path = torchaudio.utils.download_asset(f"models/{self._path}")
+ state_dict = torch.load(path, weights_only=True)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ @property
+ def sample_rate(self):
+ """Sample rate of the audio that the model is trained on.
+
+ :type: float
+ """
+ return self._sample_rate
+
+
+SQUIM_SUBJECTIVE = SquimSubjectiveBundle(
+ "squim_subjective_bvcc_daps.pth",
+ _sample_rate=16000,
+)
+SQUIM_SUBJECTIVE.__doc__ = """SquimSubjective pipeline trained
+ as described in :cite:`manocha2022speech` and :cite:`kumar2023torchaudio`
+ on the *BVCC* :cite:`cooper2021voices` and *DAPS* :cite:`mysore2014can` datasets.
+
+ The underlying model is constructed by :py:func:`torchaudio.models.squim_subjective_base`.
+ The weights are under `Creative Commons Attribution Non Commercial 4.0 International
+ `__.
+
+ Please refer to :py:class:`SquimSubjectiveBundle` for usage instructions.
+ """
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..02851f596ceb281acc75c4d6a1aaf17eeee4a809
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__init__.py
@@ -0,0 +1,16 @@
+from .impl import (
+ TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
+ TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
+ TACOTRON2_WAVERNN_CHAR_LJSPEECH,
+ TACOTRON2_WAVERNN_PHONE_LJSPEECH,
+)
+from .interface import Tacotron2TTSBundle
+
+
+__all__ = [
+ "Tacotron2TTSBundle",
+ "TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH",
+ "TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
+ "TACOTRON2_WAVERNN_CHAR_LJSPEECH",
+ "TACOTRON2_WAVERNN_PHONE_LJSPEECH",
+]
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd3e99e61e6aced84dcbfbce4a71eda7dfcdd9b5
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__pycache__/impl.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__pycache__/impl.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74ef3b5ed86d8ccb640a30a1ecbe2ff6448163b3
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__pycache__/impl.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__pycache__/interface.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__pycache__/interface.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db50caa54f407e4938b17ab899db597859527133
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__pycache__/interface.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96af402fe30646e1415151bf37d208e5bae5bec7
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/__pycache__/utils.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/impl.py b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/impl.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8542286242dcbb2036fff49c1d0e11fbbf9258b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/impl.py
@@ -0,0 +1,385 @@
+import re
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torchaudio._internal import load_state_dict_from_url
+from torchaudio.functional import mu_law_decoding
+from torchaudio.models import Tacotron2, WaveRNN
+from torchaudio.transforms import GriffinLim, InverseMelScale
+
+from . import utils
+from .interface import Tacotron2TTSBundle
+
+__all__ = []
+
+_BASE_URL = "https://download.pytorch.org/torchaudio/models"
+
+
+################################################################################
+# Pipeline implementation - Text Processor
+################################################################################
+
+
+class _EnglishCharProcessor(Tacotron2TTSBundle.TextProcessor):
+ def __init__(self):
+ super().__init__()
+ self._tokens = utils._get_chars()
+ self._mapping = {s: i for i, s in enumerate(self._tokens)}
+
+ @property
+ def tokens(self):
+ return self._tokens
+
+ def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]:
+ if isinstance(texts, str):
+ texts = [texts]
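+ # Lower-case each text and drop any character that is not in the token set.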
+ indices = [[self._mapping[c] for c in t.lower() if c in self._mapping] for t in texts]
+ return utils._to_tensor(indices)
+
+
+class _EnglishPhoneProcessor(Tacotron2TTSBundle.TextProcessor):
+ def __init__(self, *, dl_kwargs=None):
+ super().__init__()
+ self._tokens = utils._get_phones()
+ self._mapping = {p: i for i, p in enumerate(self._tokens)}
+ self._phonemizer = utils._load_phonemizer("en_us_cmudict_forward.pt", dl_kwargs=dl_kwargs)
+ self._pattern = r"(\[[A-Z]+?\]|[_!'(),.:;? -])"
+
+ @property
+ def tokens(self):
+ return self._tokens
+
+ def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]:
+ if isinstance(texts, str):
+ texts = [texts]
+
+ indices = []
+ for phones in self._phonemizer(texts, lang="en_us"):
+ # '[F][UW][B][AA][R]!' -> ['F', 'UW', 'B', 'AA', 'R', '!']
+ ret = [re.sub(r"[\[\]]", "", r) for r in re.findall(self._pattern, phones)]
+ indices.append([self._mapping[p] for p in ret])
+ return utils._to_tensor(indices)
+
+
+################################################################################
+# Pipeline implementation - Vocoder
+################################################################################
+
+
+class _WaveRNNVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
+ def __init__(self, model: WaveRNN, min_level_db: Optional[float] = -100):
+ super().__init__()
+ self._sample_rate = 22050
+ self._model = model
+ self._min_level_db = min_level_db
+
+ @property
+ def sample_rate(self):
+ return self._sample_rate
+
+ def forward(self, mel_spec, lengths=None):
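+ # Undo the log, convert amplitude to dB, then min-level normalize to [0, 1]
+ # before running WaveRNN inference.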
+ mel_spec = torch.exp(mel_spec)
+ mel_spec = 20 * torch.log10(torch.clamp(mel_spec, min=1e-5))
+ if self._min_level_db is not None:
+ mel_spec = (self._min_level_db - mel_spec) / self._min_level_db
+ mel_spec = torch.clamp(mel_spec, min=0, max=1)
+ waveform, lengths = self._model.infer(mel_spec, lengths)
+ waveform = utils._unnormalize_waveform(waveform, self._model.n_bits)
+ waveform = mu_law_decoding(waveform, self._model.n_classes)
+ waveform = waveform.squeeze(1)
+ return waveform, lengths
+
+
+class _GriffinLimVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
+ def __init__(self):
+ super().__init__()
+ self._sample_rate = 22050
+ self._inv_mel = InverseMelScale(
+ n_stft=(1024 // 2 + 1),
+ n_mels=80,
+ sample_rate=self.sample_rate,
+ f_min=0.0,
+ f_max=8000.0,
+ mel_scale="slaney",
+ norm="slaney",
+ )
+ self._griffin_lim = GriffinLim(
+ n_fft=1024,
+ power=1,
+ hop_length=256,
+ win_length=1024,
+ )
+
+ @property
+ def sample_rate(self):
+ return self._sample_rate
+
+ def forward(self, mel_spec, lengths=None):
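+ # Undo the log scale, invert the mel filterbank, then run Griffin-Lim to
+ # estimate phase and reconstruct the waveform.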
+ mel_spec = torch.exp(mel_spec)
+ mel_spec = mel_spec.clone().detach().requires_grad_(True)
+ spec = self._inv_mel(mel_spec)
+ spec = spec.detach().requires_grad_(False)
+ waveforms = self._griffin_lim(spec)
+ return waveforms, lengths
+
+
+################################################################################
+# Bundle classes mixins
+################################################################################
+
+
+class _CharMixin:
+ def get_text_processor(self) -> Tacotron2TTSBundle.TextProcessor:
+ return _EnglishCharProcessor()
+
+
+class _PhoneMixin:
+ def get_text_processor(self, *, dl_kwargs=None) -> Tacotron2TTSBundle.TextProcessor:
+ return _EnglishPhoneProcessor(dl_kwargs=dl_kwargs)
+
+
+@dataclass
+class _Tacotron2Mixin:
+ _tacotron2_path: str
+ _tacotron2_params: Dict[str, Any]
+
+ def get_tacotron2(self, *, dl_kwargs=None) -> Tacotron2:
+ model = Tacotron2(**self._tacotron2_params)
+ url = f"{_BASE_URL}/{self._tacotron2_path}"
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+
+@dataclass
+class _WaveRNNMixin:
+ _wavernn_path: Optional[str]
+ _wavernn_params: Optional[Dict[str, Any]]
+
+ def get_vocoder(self, *, dl_kwargs=None):
+ wavernn = self._get_wavernn(dl_kwargs=dl_kwargs)
+ return _WaveRNNVocoder(wavernn)
+
+ def _get_wavernn(self, *, dl_kwargs=None):
+ model = WaveRNN(**self._wavernn_params)
+ url = f"{_BASE_URL}/{self._wavernn_path}"
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+
+class _GriffinLimMixin:
+ def get_vocoder(self, **_):
+ return _GriffinLimVocoder()
+
+
+################################################################################
+# Bundle classes
+################################################################################
+
+
+@dataclass
+class _Tacotron2WaveRNNCharBundle(_WaveRNNMixin, _Tacotron2Mixin, _CharMixin, Tacotron2TTSBundle):
+ pass
+
+
+@dataclass
+class _Tacotron2WaveRNNPhoneBundle(_WaveRNNMixin, _Tacotron2Mixin, _PhoneMixin, Tacotron2TTSBundle):
+ pass
+
+
+@dataclass
+class _Tacotron2GriffinLimCharBundle(_GriffinLimMixin, _Tacotron2Mixin, _CharMixin, Tacotron2TTSBundle):
+ pass
+
+
+@dataclass
+class _Tacotron2GriffinLimPhoneBundle(_GriffinLimMixin, _Tacotron2Mixin, _PhoneMixin, Tacotron2TTSBundle):
+ pass
+
+
+################################################################################
+# Instantiate bundle objects
+################################################################################
+
+
+TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH = _Tacotron2GriffinLimCharBundle(
+ _tacotron2_path="tacotron2_english_characters_1500_epochs_ljspeech.pth",
+ _tacotron2_params=utils._get_taco_params(n_symbols=38),
+)
+TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.__doc__ = """Character-based TTS pipeline with :py:class:`~torchaudio.models.Tacotron2` trained on *LJSpeech* :cite:`ljspeech17` for 1,500 epochs, and
+:py:class:`~torchaudio.transforms.GriffinLim` as vocoder.
+
+The text processor encodes the input texts character-by-character.
+
+You can find the training script `here `__.
+The default parameters were used.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.png
+ :alt: Spectrogram generated by Tacotron2
+
+
+Example - "The examination and testimony of the experts enabled the Commission to conclude that five shots may have been fired,"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH_v2.png
+ :alt: Spectrogram generated by Tacotron2
+
+""" # noqa: E501
+
+TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH = _Tacotron2GriffinLimPhoneBundle(
+ _tacotron2_path="tacotron2_english_phonemes_1500_epochs_ljspeech.pth",
+ _tacotron2_params=utils._get_taco_params(n_symbols=96),
+)
+TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.__doc__ = """Phoneme-based TTS pipeline with :py:class:`~torchaudio.models.Tacotron2` trained on *LJSpeech* :cite:`ljspeech17` for 1,500 epochs and
+:py:class:`~torchaudio.transforms.GriffinLim` as vocoder.
+
+The text processor encodes the input texts based on phonemes.
+It uses `DeepPhonemizer <https://github.com/as-ideas/DeepPhonemizer>`__ to convert
+graphemes to phonemes.
+The model (*en_us_cmudict_forward*) was trained on
+`CMUDict `__.
+
+You can find the training script `here `__.
+The text processor is set to *"english_phonemes"*.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.png
+ :alt: Spectrogram generated by Tacotron2
+
+
+Example - "The examination and testimony of the experts enabled the Commission to conclude that five shots may have been fired,"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH_v2.png
+ :alt: Spectrogram generated by Tacotron2
+
+""" # noqa: E501
+
+TACOTRON2_WAVERNN_CHAR_LJSPEECH = _Tacotron2WaveRNNCharBundle(
+ _tacotron2_path="tacotron2_english_characters_1500_epochs_wavernn_ljspeech.pth",
+ _tacotron2_params=utils._get_taco_params(n_symbols=38),
+ _wavernn_path="wavernn_10k_epochs_8bits_ljspeech.pth",
+ _wavernn_params=utils._get_wrnn_params(),
+)
+TACOTRON2_WAVERNN_CHAR_LJSPEECH.__doc__ = """Character-based TTS pipeline with :py:class:`~torchaudio.models.Tacotron2` trained on *LJSpeech* :cite:`ljspeech17` for 1,500 epochs and :py:class:`~torchaudio.models.WaveRNN` vocoder trained on 8-bit depth waveforms of *LJSpeech* :cite:`ljspeech17` for 10,000 epochs.
+
+The text processor encodes the input texts character-by-character.
+
+You can find the training script for Tacotron2 `here `__.
+The following parameters were used: ``win_length=1100``, ``hop_length=275``, ``n_fft=2048``,
+``mel_fmin=40``, and ``mel_fmax=11025``.
+
+You can find the training script for WaveRNN `here `__.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH.png
+ :alt: Spectrogram generated by Tacotron2
+
+
+Example - "The examination and testimony of the experts enabled the Commission to conclude that five shots may have been fired,"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH_v2.png
+ :alt: Spectrogram generated by Tacotron2
+
+""" # noqa: E501
+
+TACOTRON2_WAVERNN_PHONE_LJSPEECH = _Tacotron2WaveRNNPhoneBundle(
+ _tacotron2_path="tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth",
+ _tacotron2_params=utils._get_taco_params(n_symbols=96),
+ _wavernn_path="wavernn_10k_epochs_8bits_ljspeech.pth",
+ _wavernn_params=utils._get_wrnn_params(),
+)
+TACOTRON2_WAVERNN_PHONE_LJSPEECH.__doc__ = """Phoneme-based TTS pipeline with :py:class:`~torchaudio.models.Tacotron2` trained on *LJSpeech* :cite:`ljspeech17` for 1,500 epochs, and
+:py:class:`~torchaudio.models.WaveRNN` vocoder trained on 8-bit depth waveforms of *LJSpeech* :cite:`ljspeech17` for 10,000 epochs.
+
+The text processor encodes the input texts based on phonemes.
+It uses `DeepPhonemizer <https://github.com/as-ideas/DeepPhonemizer>`__ to convert
+graphemes to phonemes.
+The model (*en_us_cmudict_forward*) was trained on
+`CMUDict `__.
+
+You can find the training script for Tacotron2 `here `__.
+The following parameters were used: ``win_length=1100``, ``hop_length=275``, ``n_fft=2048``,
+``mel_fmin=40``, and ``mel_fmax=11025``.
+
+You can find the training script for WaveRNN `here `__.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH.png
+ :alt: Spectrogram generated by Tacotron2
+
+
+Example - "The examination and testimony of the experts enabled the Commission to conclude that five shots may have been fired,"
+
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH_v2.png
+ :alt: Spectrogram generated by Tacotron2
+
+""" # noqa: E501
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/interface.py b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..564f236bc7c239d17dc82db04c350a9ccc618841
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/interface.py
@@ -0,0 +1,255 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional, Tuple, Union
+
+from torch import Tensor
+from torchaudio.models import Tacotron2
+
+
+class _TextProcessor(ABC):
+ @property
+ @abstractmethod
+ def tokens(self):
+ """The tokens that the each value in the processed tensor represent.
+
+ :type: List[str]
+ """
+
+ @abstractmethod
+ def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]:
+ """Encode the given (batch of) texts into numerical tensors
+
+ Args:
+ texts (str or list of str): The input texts.
+
+ Returns:
+ (Tensor, Tensor):
+ Tensor:
+ The encoded texts. Shape: `(batch, max length)`
+ Tensor:
+ The valid length of each sample in the batch. Shape: `(batch, )`.
+ """
+
+
+class _Vocoder(ABC):
+ @property
+ @abstractmethod
+ def sample_rate(self):
+ """The sample rate of the resulting waveform
+
+ :type: float
+ """
+
+ @abstractmethod
+ def __call__(self, specgrams: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+ """Generate waveform from the given input, such as spectrogram
+
+ Args:
+ specgrams (Tensor):
+ The input spectrogram. Shape: `(batch, frequency bins, time)`.
+ The expected shape depends on the implementation.
+ lengths (Tensor, or None, optional):
+ The valid length of each sample in the batch. Shape: `(batch, )`.
+ (Default: `None`)
+
+ Returns:
+ (Tensor, Optional[Tensor]):
+ Tensor:
+ The generated waveform. Shape: `(batch, max length)`
+ Tensor or None:
+ The valid length of each sample in the batch. Shape: `(batch, )`.
+ """
+
+
+class Tacotron2TTSBundle(ABC):
+ """Data class that bundles associated information to use pretrained Tacotron2 and vocoder.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ Torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+ Please see below for the usage and the available values.
+
+ Example - Character-based TTS pipeline with Tacotron2 and WaveRNN
+ >>> import torchaudio
+ >>>
+ >>> text = "Hello, T T S !"
+ >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
+ >>>
+ >>> # Build processor, Tacotron2 and WaveRNN model
+ >>> processor = bundle.get_text_processor()
+ >>> tacotron2 = bundle.get_tacotron2()
+ Downloading:
+ 100%|███████████████████████████████| 107M/107M [00:01<00:00, 87.9MB/s]
+ >>> vocoder = bundle.get_vocoder()
+ Downloading:
+ 100%|███████████████████████████████| 16.7M/16.7M [00:00<00:00, 78.1MB/s]
+ >>>
+ >>> # Encode text
+ >>> input, lengths = processor(text)
+ >>>
+ >>> # Generate (mel-scale) spectrogram
+ >>> specgram, lengths, _ = tacotron2.infer(input, lengths)
+ >>>
+ >>> # Convert spectrogram to waveform
+ >>> waveforms, lengths = vocoder(specgram, lengths)
+ >>>
+ >>> torchaudio.save('hello-tts.wav', waveforms, vocoder.sample_rate)
+
+ Example - Phoneme-based TTS pipeline with Tacotron2 and WaveRNN
+ >>>
+ >>> # Note:
+ >>> # This bundle uses pre-trained DeepPhonemizer as
+ >>> # the text pre-processor.
+ >>> # Please install deep-phonemizer.
+ >>> # See https://github.com/as-ideas/DeepPhonemizer
+ >>> # The pretrained weight is automatically downloaded.
+ >>>
+ >>> import torchaudio
+ >>>
+ >>> text = "Hello, TTS!"
+ >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
+ >>>
+ >>> # Build processor, Tacotron2 and WaveRNN model
+ >>> processor = bundle.get_text_processor()
+ Downloading:
+ 100%|███████████████████████████████| 63.6M/63.6M [00:04<00:00, 15.3MB/s]
+ >>> tacotron2 = bundle.get_tacotron2()
+ Downloading:
+ 100%|███████████████████████████████| 107M/107M [00:01<00:00, 87.9MB/s]
+ >>> vocoder = bundle.get_vocoder()
+ Downloading:
+ 100%|███████████████████████████████| 16.7M/16.7M [00:00<00:00, 78.1MB/s]
+ >>>
+ >>> # Encode text
+ >>> input, lengths = processor(text)
+ >>>
+ >>> # Generate (mel-scale) spectrogram
+ >>> specgram, lengths, _ = tacotron2.infer(input, lengths)
+ >>>
+ >>> # Convert spectrogram to waveform
+ >>> waveforms, lengths = vocoder(specgram, lengths)
+ >>>
+ >>> torchaudio.save('hello-tts.wav', waveforms, vocoder.sample_rate)
+ """
+
+ # Using the inner class so that these interfaces are not directly exposed on
+ # `torchaudio.pipelines`, but still listed in documentation.
+ # The thing is, text processing and vocoder are generic, and we do not know what kinds of
+ # new text processors and vocoders will be added in the future, so we want to make these
+ # interfaces specific to this Tacotron2TTS pipeline.
+
+ class TextProcessor(_TextProcessor):
+ """Interface of the text processing part of Tacotron2TTS pipeline
+
+ See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_text_processor` for the usage.
+ """
+
+ class Vocoder(_Vocoder):
+ """Interface of the vocoder part of Tacotron2TTS pipeline
+
+ See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_vocoder` for the usage.
+ """
+
+ @abstractmethod
+ def get_text_processor(self, *, dl_kwargs=None) -> TextProcessor:
+ """Create a text processor
+
+ For character-based pipeline, this processor splits the input text by character.
+ For phoneme-based pipeline, this processor converts the input text (grapheme) to
+ phonemes.
+
+ If a pre-trained weight file is necessary,
+ :func:`torch.hub.download_url_to_file` is used to download it.
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments):
+ Passed to :func:`torch.hub.download_url_to_file`.
+
+ Returns:
+ TextProcessor:
+ A callable which takes a string or a list of strings as input and
+ returns Tensor of encoded texts and Tensor of valid lengths.
+ The object also has a ``tokens`` property, which allows recovering the
+ tokenized form.
+
+ Example - Character-based
+ >>> text = [
+ >>> "Hello World!",
+ >>> "Text-to-speech!",
+ >>> ]
+ >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
+ >>> processor = bundle.get_text_processor()
+ >>> input, lengths = processor(text)
+ >>>
+ >>> print(input)
+ tensor([[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2, 0, 0, 0],
+ [31, 16, 35, 31, 1, 31, 26, 1, 30, 27, 16, 16, 14, 19, 2]],
+ dtype=torch.int32)
+ >>>
+ >>> print(lengths)
+ tensor([12, 15], dtype=torch.int32)
+ >>>
+ >>> print([processor.tokens[i] for i in input[0, :lengths[0]]])
+ ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!']
+ >>> print([processor.tokens[i] for i in input[1, :lengths[1]]])
+ ['t', 'e', 'x', 't', '-', 't', 'o', '-', 's', 'p', 'e', 'e', 'c', 'h', '!']
+
+ Example - Phoneme-based
+ >>> text = [
+ >>> "Hello, T T S !",
+ >>> "Text-to-speech!",
+ >>> ]
+ >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
+ >>> processor = bundle.get_text_processor()
+ Downloading:
+ 100%|███████████████████████████████| 63.6M/63.6M [00:04<00:00, 15.3MB/s]
+ >>> input, lengths = processor(text)
+ >>>
+ >>> print(input)
+ tensor([[54, 20, 65, 69, 11, 92, 44, 65, 38, 2, 0, 0, 0, 0],
+ [81, 40, 64, 79, 81, 1, 81, 20, 1, 79, 77, 59, 37, 2]],
+ dtype=torch.int32)
+ >>>
+ >>> print(lengths)
+ tensor([10, 14], dtype=torch.int32)
+ >>>
+ >>> print([processor.tokens[i] for i in input[0]])
+ ['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!', '_', '_', '_', '_']
+ >>> print([processor.tokens[i] for i in input[1]])
+ ['T', 'EH', 'K', 'S', 'T', '-', 'T', 'AH', '-', 'S', 'P', 'IY', 'CH', '!']
+ """
+
+ @abstractmethod
+ def get_vocoder(self, *, dl_kwargs=None) -> Vocoder:
+ """Create a vocoder module, based off of either WaveRNN or GriffinLim.
+
+ If a pre-trained weight file is necessary,
+ :func:`torch.hub.load_state_dict_from_url` is used to download it.
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments):
+ Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Vocoder:
+ A vocoder module, which takes spectrogram Tensor and an optional
+ length Tensor, then returns resulting waveform Tensor and an optional
+ length Tensor.
+ """
+
+ @abstractmethod
+ def get_tacotron2(self, *, dl_kwargs=None) -> Tacotron2:
+ """Create a Tacotron2 model with pre-trained weight.
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments):
+ Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Tacotron2:
+ The resulting model.
+ """
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/utils.py b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef3ecb31335ae0cf9950d040fa11c21fb3403b25
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_tts/utils.py
@@ -0,0 +1,228 @@
+import logging
+import os
+
+import torch
+from torchaudio._internal import download_url_to_file, module_utils as _mod_utils
+
+
+def _get_chars():
+ return (
+ "_",
+ "-",
+ "!",
+ "'",
+ "(",
+ ")",
+ ",",
+ ".",
+ ":",
+ ";",
+ "?",
+ " ",
+ "a",
+ "b",
+ "c",
+ "d",
+ "e",
+ "f",
+ "g",
+ "h",
+ "i",
+ "j",
+ "k",
+ "l",
+ "m",
+ "n",
+ "o",
+ "p",
+ "q",
+ "r",
+ "s",
+ "t",
+ "u",
+ "v",
+ "w",
+ "x",
+ "y",
+ "z",
+ )
+
+
+def _get_phones():
+ return (
+ "_",
+ "-",
+ "!",
+ "'",
+ "(",
+ ")",
+ ",",
+ ".",
+ ":",
+ ";",
+ "?",
+ " ",
+ "AA",
+ "AA0",
+ "AA1",
+ "AA2",
+ "AE",
+ "AE0",
+ "AE1",
+ "AE2",
+ "AH",
+ "AH0",
+ "AH1",
+ "AH2",
+ "AO",
+ "AO0",
+ "AO1",
+ "AO2",
+ "AW",
+ "AW0",
+ "AW1",
+ "AW2",
+ "AY",
+ "AY0",
+ "AY1",
+ "AY2",
+ "B",
+ "CH",
+ "D",
+ "DH",
+ "EH",
+ "EH0",
+ "EH1",
+ "EH2",
+ "ER",
+ "ER0",
+ "ER1",
+ "ER2",
+ "EY",
+ "EY0",
+ "EY1",
+ "EY2",
+ "F",
+ "G",
+ "HH",
+ "IH",
+ "IH0",
+ "IH1",
+ "IH2",
+ "IY",
+ "IY0",
+ "IY1",
+ "IY2",
+ "JH",
+ "K",
+ "L",
+ "M",
+ "N",
+ "NG",
+ "OW",
+ "OW0",
+ "OW1",
+ "OW2",
+ "OY",
+ "OY0",
+ "OY1",
+ "OY2",
+ "P",
+ "R",
+ "S",
+ "SH",
+ "T",
+ "TH",
+ "UH",
+ "UH0",
+ "UH1",
+ "UH2",
+ "UW",
+ "UW0",
+ "UW1",
+ "UW2",
+ "V",
+ "W",
+ "Y",
+ "Z",
+ "ZH",
+ )
+
+
+def _to_tensor(indices):
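+ # Pads variable-length index sequences into a (batch, max_len) tensor, e.g.
+ # [[1, 2, 3], [4]] -> (tensor([[1, 2, 3], [4, 0, 0]]), tensor([3, 1], dtype=torch.int32)).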
+ lengths = torch.tensor([len(i) for i in indices], dtype=torch.int32)
+ values = [torch.tensor(i) for i in indices]
+ values = torch.nn.utils.rnn.pad_sequence(values, batch_first=True)
+ return values, lengths
+
+
+def _load_phonemizer(file, dl_kwargs):
+ if not _mod_utils.is_module_available("dp"):
+ raise RuntimeError("DeepPhonemizer is not installed. Please install it.")
+
+ from dp.phonemizer import Phonemizer
+
+ # By default, dp issues DEBUG level log.
+ logger = logging.getLogger("dp")
+ orig_level = logger.level
+ logger.setLevel(logging.INFO)
+ try:
+ url = f"https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/{file}"
+ directory = os.path.join(torch.hub.get_dir(), "checkpoints")
+ os.makedirs(directory, exist_ok=True)
+ path = os.path.join(directory, file)
+ if not os.path.exists(path):
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ download_url_to_file(url, path, **dl_kwargs)
+ return Phonemizer.from_checkpoint(path)
+ finally:
+ logger.setLevel(orig_level)
+
+
+def _unnormalize_waveform(waveform: torch.Tensor, bits: int) -> torch.Tensor:
+ r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]"""
+ waveform = torch.clamp(waveform, -1, 1)
+ waveform = (waveform + 1.0) * (2**bits - 1) / 2
+ return torch.clamp(waveform, 0, 2**bits - 1).int()
+
+
+def _get_taco_params(n_symbols):
+ return {
+ "mask_padding": False,
+ "n_mels": 80,
+ "n_frames_per_step": 1,
+ "symbol_embedding_dim": 512,
+ "encoder_embedding_dim": 512,
+ "encoder_n_convolution": 3,
+ "encoder_kernel_size": 5,
+ "decoder_rnn_dim": 1024,
+ "decoder_max_step": 2000,
+ "decoder_dropout": 0.1,
+ "decoder_early_stopping": True,
+ "attention_rnn_dim": 1024,
+ "attention_hidden_dim": 128,
+ "attention_location_n_filter": 32,
+ "attention_location_kernel_size": 31,
+ "attention_dropout": 0.1,
+ "prenet_dim": 256,
+ "postnet_n_convolution": 5,
+ "postnet_kernel_size": 5,
+ "postnet_embedding_dim": 512,
+ "gate_threshold": 0.5,
+ "n_symbol": n_symbols,
+ }
+
+
+def _get_wrnn_params():
+ return {
+ "upsample_scales": [5, 5, 11],
+ "n_classes": 2**8, # n_bits = 8
+ "hop_length": 275,
+ "n_res_block": 10,
+ "n_rnn": 512,
+ "n_fc": 512,
+ "kernel_size": 5,
+ "n_freq": 80,
+ "n_hidden": 128,
+ "n_output": 128,
+ }
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebe5b3f5477de1c5296ddc954e06316304dbb894
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/aligner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/aligner.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e9505e42214ad07b12549538a651119a1c252a9
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/aligner.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/impl.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/impl.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0242ac5dba9a1905c6467198fbed62b0e271ac7e
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/impl.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8535b0ede89659dbf991dfd970e219864f2ba806
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/__pycache__/utils.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/aligner.py b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/aligner.py
new file mode 100644
index 0000000000000000000000000000000000000000..3655d5bae88181796d6d889013b4438d0ea014b3
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/aligner.py
@@ -0,0 +1,87 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List
+
+import torch
+import torchaudio.functional as F
+from torch import Tensor
+from torchaudio.functional import TokenSpan
+
+
+class ITokenizer(ABC):
+ @abstractmethod
+ def __call__(self, transcript: List[str]) -> List[List[int]]:
+ """Tokenize the given transcript (list of words)
+
+ .. note::
+
+ The transcript must be normalized.
+
+ Args:
+ transcript (list of str): Transcript (list of words).
+
+ Returns:
+ (list of list of int): List of token sequences
+ """
+
+
+class Tokenizer(ITokenizer):
+ def __init__(self, dictionary: Dict[str, int]):
+ self.dictionary = dictionary
+
+ def __call__(self, transcript: List[str]) -> List[List[int]]:
+ return [[self.dictionary[c] for c in word] for word in transcript]
+
+
+def _align_emission_and_tokens(emission: Tensor, tokens: List[int], blank: int = 0):
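+ # Run CTC forced alignment on a single utterance: add a batch dimension,
+ # align, then strip the batch dimension and convert log-probs back to probabilities.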
+ device = emission.device
+ emission = emission.unsqueeze(0)
+ targets = torch.tensor([tokens], dtype=torch.int32, device=device)
+
+ aligned_tokens, scores = F.forced_align(emission, targets, blank=blank)
+
+ scores = scores.exp() # convert back to probability
+ aligned_tokens, scores = aligned_tokens[0], scores[0] # remove batch dimension
+ return aligned_tokens, scores
+
+
+class IAligner(ABC):
+ @abstractmethod
+ def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
+ """Generate list of time-stamped token sequences
+
+ Args:
+ emission (Tensor): Sequence of token probability distributions in log-domain.
+ Shape: `(time, tokens)`.
+ tokens (list of integer sequence): Tokenized transcript.
+ Output from :py:class:`torchaudio.pipelines.Wav2Vec2FABundle.Tokenizer`.
+
+ Returns:
+ (list of TokenSpan sequence): Tokens with time stamps and scores.
+ """
+
+
+def _unflatten(list_, lengths):
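+ # Splits a flat list into consecutive chunks of the given lengths,
+ # e.g. _unflatten(["a", "b", "c"], [2, 1]) -> [["a", "b"], ["c"]].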
+ assert len(list_) == sum(lengths)
+ i = 0
+ ret = []
+ for l in lengths:
+ ret.append(list_[i : i + l])
+ i += l
+ return ret
+
+
+def _flatten(nested_list):
+ return [item for list_ in nested_list for item in list_]
+
+
+class Aligner(IAligner):
+ def __init__(self, blank):
+ self.blank = blank
+
+ def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
+ if emission.ndim != 2:
+ raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")
+
+ aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens), self.blank)
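+ # Merge per-frame tokens into per-token spans, then regroup the spans per word.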
+ spans = F.merge_tokens(aligned_tokens, scores)
+ return _unflatten(spans, [len(ts) for ts in tokens])
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/impl.py b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/impl.py
new file mode 100644
index 0000000000000000000000000000000000000000..d60fa8adb94e92e1a479fe94e09f521d0fe50056
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/impl.py
@@ -0,0 +1,1699 @@
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple
+
+from torch.nn import Module
+
+from . import aligner, utils
+
+
+__all__ = [] # type: ignore
+
+
+@dataclass
+class Wav2Vec2Bundle:
+ """Data class that bundles associated information to use pretrained :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ Torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+ Please see below for the usage and the available values.
+
+ Example - Feature Extraction
+ >>> import torchaudio
+ >>>
+ >>> bundle = torchaudio.pipelines.HUBERT_BASE
+ >>>
+ >>> # Build the model and load pretrained weight.
+ >>> model = bundle.get_model()
+ Downloading:
+ 100%|███████████████████████████████| 360M/360M [00:06<00:00, 60.6MB/s]
+ >>>
+ >>> # Resample audio to the expected sampling rate
+ >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+ >>>
+ >>> # Extract acoustic features
+ >>> features, _ = model.extract_features(waveform)
+ """ # noqa: E501
+
+ _path: str
+ _params: Dict[str, Any]
+ _sample_rate: float
+ _normalize_waveform: bool
+ _model_type: str
+
+ @property
+ def sample_rate(self) -> float:
+ """Sample rate of the audio that the model is trained on.
+
+ :type: float
+ """
+ return self._sample_rate
+
+ def _get_state_dict(self, dl_kwargs):
+ # Note: This method is overridden in ASR bundle
+ return utils._get_state_dict(self._path, dl_kwargs)
+
+ def get_model(self, *, dl_kwargs=None) -> Module:
+ """Construct the model and load the pretrained weight.
+
+ The weight file is downloaded from the internet and cached with
+ :func:`torch.hub.load_state_dict_from_url`
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+ For the models listed below, an additional layer normalization is performed on the input:
+
+ - WAV2VEC2_LARGE_LV60K
+ - WAV2VEC2_ASR_LARGE_LV60K_10M
+ - WAV2VEC2_ASR_LARGE_LV60K_100H
+ - WAV2VEC2_ASR_LARGE_LV60K_960H
+ - WAV2VEC2_XLSR53
+ - WAV2VEC2_XLSR_300M
+ - WAV2VEC2_XLSR_1B
+ - WAV2VEC2_XLSR_2B
+ - HUBERT_LARGE
+ - HUBERT_XLARGE
+ - HUBERT_ASR_LARGE
+ - HUBERT_ASR_XLARGE
+ - WAVLM_LARGE
+
+ For all other models, a :py:class:`~torchaudio.models.Wav2Vec2Model` instance is returned.
+ """
+ model = utils._get_model(self._model_type, self._params)
+ state_dict = self._get_state_dict(dl_kwargs)
+ model.load_state_dict(state_dict)
+ if self._normalize_waveform:
+ model = utils._extend_model(model, normalize_waveform=True)
+ model.eval()
+ return model
+
+
+@dataclass
+class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
+ """Data class that bundles associated information to use pretrained
+ :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ Torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+ Please see below for the usage and the available values.
+
+ Example - ASR
+ >>> import torchaudio
+ >>>
+ >>> bundle = torchaudio.pipelines.HUBERT_ASR_LARGE
+ >>>
+ >>> # Build the model and load pretrained weight.
+ >>> model = bundle.get_model()
+ Downloading:
+ 100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
+ >>>
+ >>> # Check the corresponding labels of the output.
+ >>> labels = bundle.get_labels()
+ >>> print(labels)
+ ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+ >>>
+ >>> # Resample audio to the expected sampling rate
+ >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+ >>>
+ >>> # Infer the label probability distribution
+ >>> emissions, _ = model(waveform)
+ >>>
+ >>> # Pass emission to decoder
+ >>> # `ctc_decode` is for illustration purpose only
+ >>> transcripts = ctc_decode(emissions, labels)
+ """ # noqa: E501
+
+ _labels: Tuple[str, ...]
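+ # Rows of the aux (output) layer that are dropped when loading the checkpoint
+ # (see utils._get_state_dict); assumed to correspond to fairseq dictionary
+ # entries that are unused as CTC labels.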
+ _remove_aux_axis: Tuple[int, ...] = (1, 2, 3)
+
+ def get_labels(
+ self,
+ *,
+ blank: str = "-",
+ ) -> Tuple[str, ...]:
+ """The output class labels.
+
+ The first is the blank token, and it is customizable.
+
+ Args:
+ blank (str, optional): Blank token. (default: ``'-'``)
+
+ Returns:
+ Tuple[str, ...]:
+ For models fine-tuned on ASR, returns the tuple of strings representing
+ the output class labels.
+
+ Example
+ >>> from torchaudio.pipelines import HUBERT_ASR_LARGE as bundle
+ >>> bundle.get_labels()
+ ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+ """ # noqa: E501
+ return (blank, *self._labels)
+
+ def _get_state_dict(self, dl_kwargs):
+ return utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis)
+
+
+WAV2VEC2_BASE = Wav2Vec2Bundle(
+ _path="wav2vec2_fairseq_base_ls960.pth",
+ _params={
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_BASE.__doc__ = """Wav2vec 2.0 model ("base" architecture),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), not fine-tuned.
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
+ _path="wav2vec2_fairseq_base_ls960_asr_ll10m.pth",
+ _params={
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_BASE_10M.__doc__ = """Wav2vec 2.0 model ("base" architecture with an extra linear module),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and
+fine-tuned for ASR on 10 minutes of transcribed audio from *Libri-Light* dataset
+:cite:`librilight` ("train-10min" subset).
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_base_ls960_asr_ls100.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+
+WAV2VEC2_ASR_BASE_100H.__doc__ = """Wav2vec 2.0 model ("base" architecture with an extra linear module),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and
+fine-tuned for ASR on 100 hours of transcribed audio from "train-clean-100" subset.
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_base_ls960_asr_ls960.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_BASE_960H.__doc__ = """Wav2vec 2.0 model ("base" architecture with an extra linear module),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and
+fine-tuned for ASR on the same audio with the corresponding transcripts.
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_LARGE = Wav2Vec2Bundle(
+ "wav2vec2_fairseq_large_ls960.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.2,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_LARGE.__doc__ = """Wav2vec 2.0 model ("large" architecture),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), not fine-tuned.
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_large_ls960_asr_ll10m.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.2,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_LARGE_10M.__doc__ = """Wav2vec 2.0 model ("large" architecture with an extra linear module),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and
+fine-tuned for ASR on 10 minutes of transcribed audio from *Libri-Light* dataset
+:cite:`librilight` ("train-10min" subset).
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_large_ls960_asr_ls100.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.2,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_LARGE_100H.__doc__ = """Wav2vec 2.0 model ("large" architecture with an extra linear module),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and
+fine-tuned for ASR on 100 hours of transcribed audio from
+the same dataset ("train-clean-100" subset).
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_large_ls960_asr_ls960.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.2,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_LARGE_960H.__doc__ = """Wav2vec 2.0 model ("large" architecture with an extra linear module),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and
+fine-tuned for ASR on the same audio with the corresponding transcripts.
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle(
+ "wav2vec2_fairseq_large_lv60k.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_LARGE_LV60K.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`,
+not fine-tuned.
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_large_lv60k_asr_ll10m.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_LARGE_LV60K_10M.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture with an extra linear module),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`, and
+fine-tuned for ASR on 10 minutes of transcribed audio from the same dataset ("train-10min" subset).
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_large_lv60k_asr_ls100.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_LARGE_LV60K_100H.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture with an extra linear module),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`, and
+fine-tuned for ASR on 100 hours of transcribed audio from
+*LibriSpeech* dataset :cite:`7178964` ("train-clean-100" subset).
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
+ "wav2vec2_fairseq_large_lv60k_asr_ls960.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_ASR_LARGE_LV60K_960H.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture with an extra linear module),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* :cite:`librilight` dataset, and
+fine-tuned for ASR on 960 hours of transcribed audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500").
+
+Originally published by the authors of *wav2vec 2.0* :cite:`baevski2020wav2vec` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+WAV2VEC2_XLSR53 = Wav2Vec2Bundle(
+ "wav2vec2_fairseq_large_xlsr53.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+WAV2VEC2_XLSR53.__doc__ = """Wav2vec 2.0 model ("large" architecture),
+pre-trained on 56,000 hours of unlabeled audio from multiple datasets (
+*Multilingual LibriSpeech* :cite:`Pratap_2020`,
+*CommonVoice* :cite:`ardila2020common` and
+*BABEL* :cite:`Gales2014SpeechRA`),
+not fine-tuned.
+
+Originally published by the authors of
+*Unsupervised Cross-lingual Representation Learning for Speech Recognition*
+:cite:`conneau2020unsupervised` under MIT License and redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+HUBERT_BASE = Wav2Vec2Bundle(
+ "hubert_fairseq_base_ls960.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+HUBERT_BASE.__doc__ = """HuBERT model ("base" architecture),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), not fine-tuned.
+
+Originally published by the authors of *HuBERT* :cite:`hsu2021hubert` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+HUBERT_LARGE = Wav2Vec2Bundle(
+ "hubert_fairseq_large_ll60k.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+HUBERT_LARGE.__doc__ = """HuBERT model ("large" architecture),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`,
+not fine-tuned.
+
+Originally published by the authors of *HuBERT* :cite:`hsu2021hubert` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+HUBERT_XLARGE = Wav2Vec2Bundle(
+ "hubert_fairseq_xlarge_ll60k.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1280,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 48,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 5120,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+HUBERT_XLARGE.__doc__ = """HuBERT model ("extra large" architecture),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`,
+not fine-tuned.
+
+Originally published by the authors of *HuBERT* :cite:`hsu2021hubert` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
+ "hubert_fairseq_large_ll60k_asr_ls960.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+HUBERT_ASR_LARGE.__doc__ = """HuBERT model ("large" architecture),
+pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`, and
+fine-tuned for ASR on 960 hours of transcribed audio from *LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500").
+
+Originally published by the authors of *HuBERT* :cite:`hsu2021hubert` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
+ "hubert_fairseq_xlarge_ll60k_asr_ls960.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1280,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 48,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 5120,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 29,
+ },
+ _labels=utils._get_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+HUBERT_ASR_XLARGE.__doc__ = """HuBERT model ("extra large" architecture),
+pre-trained on 60,000 hours of unlabeled audio from
+*Libri-Light* dataset :cite:`librilight`, and
+fine-tuned for ASR on 960 hours of transcribed audio from
+*LibriSpeech* dataset :cite:`7178964`
+(the combination of "train-clean-100", "train-clean-360", and "train-other-500").
+
+Originally published by the authors of *HuBERT* :cite:`hsu2021hubert` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+
+VOXPOPULI_ASR_BASE_10K_DE = Wav2Vec2ASRBundle(
+ "wav2vec2_voxpopuli_base_10k_asr_de.pt",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 32,
+ },
+ _labels=utils._get_de_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _remove_aux_axis=(1, 2, 3, 35),
+ _model_type="Wav2Vec2",
+)
+VOXPOPULI_ASR_BASE_10K_DE.__doc__ = """Wav2vec 2.0 model ("base" architecture),
+pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
+("10k" subset, consisting of 23 languages), and
+fine-tuned for ASR on 282 hours of transcribed audio from "de" subset.
+
+Originally published by the authors of *VoxPopuli* :cite:`voxpopuli` under CC BY-NC 4.0 and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+
+VOXPOPULI_ASR_BASE_10K_EN = Wav2Vec2ASRBundle(
+ "wav2vec2_voxpopuli_base_10k_asr_en.pt",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 28,
+ },
+ _labels=utils._get_vp_en_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _remove_aux_axis=(1, 2, 3, 31),
+ _model_type="Wav2Vec2",
+)
+VOXPOPULI_ASR_BASE_10K_EN.__doc__ = """Wav2vec 2.0 model ("base" architecture),
+pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
+("10k" subset, consisting of 23 languages), and
+fine-tuned for ASR on 543 hours of transcribed audio from "en" subset.
+
+Originally published by the authors of *VoxPopuli* :cite:`voxpopuli` under CC BY-NC 4.0 and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+
+VOXPOPULI_ASR_BASE_10K_ES = Wav2Vec2ASRBundle(
+ "wav2vec2_voxpopuli_base_10k_asr_es.pt",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 35,
+ },
+ _labels=utils._get_es_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _remove_aux_axis=(1, 2, 3, 35),
+ _model_type="Wav2Vec2",
+)
+VOXPOPULI_ASR_BASE_10K_ES.__doc__ = """Wav2vec 2.0 model ("base" architecture),
+pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
+("10k" subset, consisting of 23 languages), and
+fine-tuned for ASR on 166 hours of transcribed audio from "es" subset.
+
+Originally published by the authors of *VoxPopuli* :cite:`voxpopuli` under CC BY-NC 4.0 and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+VOXPOPULI_ASR_BASE_10K_FR = Wav2Vec2ASRBundle(
+ "wav2vec2_voxpopuli_base_10k_asr_fr.pt",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 43,
+ },
+ _labels=utils._get_fr_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _model_type="Wav2Vec2",
+)
+VOXPOPULI_ASR_BASE_10K_FR.__doc__ = """Wav2vec 2.0 model ("base" architecture),
+pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
+("10k" subset, consisting of 23 languages), and
+fine-tuned for ASR on 211 hours of transcribed audio from "fr" subset.
+
+Originally published by the authors of *VoxPopuli* :cite:`voxpopuli` under CC BY-NC 4.0 and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+
+VOXPOPULI_ASR_BASE_10K_IT = Wav2Vec2ASRBundle(
+ "wav2vec2_voxpopuli_base_10k_asr_it.pt",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 37,
+ },
+ _labels=utils._get_it_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=False,
+ _remove_aux_axis=(1, 2, 3),
+ _model_type="Wav2Vec2",
+)
+VOXPOPULI_ASR_BASE_10K_IT.__doc__ = """Wav2vec 2.0 model ("base" architecture),
+pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
+("10k" subset, consisting of 23 languages), and
+fine-tuned for ASR on 91 hours of transcribed audio from "it" subset.
+
+Originally published by the authors of *VoxPopuli* :cite:`voxpopuli` under CC BY-NC 4.0 and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage.
+""" # noqa: E501
+
+
+WAVLM_BASE = Wav2Vec2Bundle(
+ "wavlm_base.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_max_distance": 800,
+ "encoder_num_buckets": 320,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": None,
+ },
+ _model_type="WavLM",
+ _sample_rate=16000,
+ _normalize_waveform=False,
+)
+WAVLM_BASE.__doc__ = """WavLM Base model ("base" architecture),
+pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`, not fine-tuned.
+
+Originally published by the authors of *WavLM* :cite:`chen2022wavlm` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+
+WAVLM_BASE_PLUS = Wav2Vec2Bundle(
+ "wavlm_base_plus.pth",
+ {
+ "extractor_mode": "group_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 768,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 12,
+ "encoder_num_heads": 12,
+ "encoder_max_distance": 800,
+ "encoder_num_buckets": 320,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 3072,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": False,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": None,
+ },
+ _model_type="WavLM",
+ _sample_rate=16000,
+ _normalize_waveform=False,
+)
+WAVLM_BASE_PLUS.__doc__ = """WavLM Base+ model ("base" architecture),
+pre-trained on 60,000 hours of audio from the *Libri-Light* dataset :cite:`librilight`,
+10,000 hours of audio from *GigaSpeech* :cite:`GigaSpeech2021`, and
+24,000 hours of audio from *VoxPopuli* :cite:`voxpopuli`, not fine-tuned.
+
+Originally published by the authors of *WavLM* :cite:`chen2022wavlm` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+
+WAVLM_LARGE = Wav2Vec2Bundle(
+ "wavlm_large.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": False,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_max_distance": 800,
+ "encoder_num_buckets": 320,
+ "encoder_attention_dropout": 0.1,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.1,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.05,
+ "aux_num_out": None,
+ },
+ _model_type="WavLM",
+ _sample_rate=16000,
+ _normalize_waveform=True,
+)
+WAVLM_LARGE.__doc__ = """WavLM Large model ("large" architecture),
+pre-trained on 60,000 hours of audio from the *Libri-Light* dataset :cite:`librilight`,
+10,000 hours of audio from *GigaSpeech* :cite:`GigaSpeech2021`, and
+24,000 hours of audio from *VoxPopuli* :cite:`voxpopuli`, not fine-tuned.
+
+Originally published by the authors of *WavLM* :cite:`chen2022wavlm` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage.
+""" # noqa: E501
+
+
+WAV2VEC2_XLSR_300M = Wav2Vec2Bundle(
+ "wav2vec2_xlsr_300m.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _model_type="Wav2Vec2",
+ _sample_rate=16000,
+ _normalize_waveform=True,
+)
+WAV2VEC2_XLSR_300M.__doc__ = """XLS-R model with 300 million parameters,
+pre-trained on 436,000 hours of unlabeled audio from multiple datasets (
+*Multilingual LibriSpeech* :cite:`Pratap_2020`,
+*CommonVoice* :cite:`ardila2020common`,
+*VoxLingua107* :cite:`valk2021voxlingua107`,
+*BABEL* :cite:`Gales2014SpeechRA`, and
+*VoxPopuli* :cite:`voxpopuli`) in 128 languages,
+not fine-tuned.
+
+Originally published by the authors of *XLS-R* :cite:`babu2021xls` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details.
+""" # noqa: E501
+
+
+WAV2VEC2_XLSR_1B = Wav2Vec2Bundle(
+ "wav2vec2_xlsr_1b.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1280,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 48,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 5120,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _model_type="Wav2Vec2",
+ _sample_rate=16000,
+ _normalize_waveform=True,
+)
+WAV2VEC2_XLSR_1B.__doc__ = """XLS-R model with 1 billion parameters,
+pre-trained on 436,000 hours of unlabeled audio from multiple datasets (
+*Multilingual LibriSpeech* :cite:`Pratap_2020`,
+*CommonVoice* :cite:`ardila2020common`,
+*VoxLingua107* :cite:`valk2021voxlingua107`,
+*BABEL* :cite:`Gales2014SpeechRA`, and
+*VoxPopuli* :cite:`voxpopuli`) in 128 languages,
+not fine-tuned.
+
+Originally published by the authors of *XLS-R* :cite:`babu2021xls` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details.
+""" # noqa: E501
+
+WAV2VEC2_XLSR_2B = Wav2Vec2Bundle(
+ "wav2vec2_xlsr_2b.pth",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1920,
+ "encoder_projection_dropout": 0.1,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 48,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 7680,
+ "encoder_ff_interm_dropout": 0.0,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.0,
+ "aux_num_out": None,
+ },
+ _model_type="Wav2Vec2",
+ _sample_rate=16000,
+ _normalize_waveform=True,
+)
+WAV2VEC2_XLSR_2B.__doc__ = """XLS-R model with 2 billion parameters,
+pre-trained on 436,000 hours of unlabeled audio from multiple datasets (
+*Multilingual LibriSpeech* :cite:`Pratap_2020`,
+*CommonVoice* :cite:`ardila2020common`,
+*VoxLingua107* :cite:`valk2021voxlingua107`,
+*BABEL* :cite:`Gales2014SpeechRA`, and
+*VoxPopuli* :cite:`voxpopuli`) in 128 languages,
+not fine-tuned.
+
+Originally published by the authors of *XLS-R* :cite:`babu2021xls` under MIT License and
+redistributed with the same license.
+[`License `__,
+`Source `__]
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details.
+""" # noqa: E501
+
+
+@dataclass
+class Wav2Vec2FABundle(Wav2Vec2ASRBundle):
+ """Data class that bundles associated information to use pretrained :py:class:`~torchaudio.models.Wav2Vec2Model` for forced alignment.
+
+ This class provides interfaces for instantiating the pretrained model along with
+ the information necessary to retrieve pretrained weights and additional data
+ to be used with the model.
+
+ The torchaudio library instantiates objects of this class, each of which represents
+ a different pretrained model. Client code should access pretrained models via these
+ instances.
+
+ Please see below for the usage and the available values.
+
+ Example - Forced Alignment
+ >>> import torchaudio
+ >>>
+ >>> bundle = torchaudio.pipelines.MMS_FA
+ >>>
+ >>> # Build the model and load pretrained weight.
+ >>> model = bundle.get_model()
+ Downloading:
+ 100%|███████████████████████████████| 1.18G/1.18G [00:05<00:00, 216MB/s]
+ >>>
+ >>> # Resample audio to the expected sampling rate
+ >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+ >>>
+ >>> # Estimate the probability of token distribution
+ >>> emission, _ = model(waveform)
+ >>>
+ >>> # Generate frame-wise alignment
+ >>> alignment, scores = torchaudio.functional.forced_align(
+ >>> emission, targets, input_lengths, target_lengths, blank=0)
+ >>>
+ """ # noqa: E501
+
+ class Tokenizer(aligner.ITokenizer):
+ """Interface of the tokenizer"""
+
+ class Aligner(aligner.IAligner):
+ """Interface of the aligner"""
+
+ def get_labels(self, star: Optional[str] = "*", blank: str = "-") -> Tuple[str, ...]:
+ """Get the labels corresponding to the feature dimension of emission.
+
+ The first label is the blank token, and it is customizable.
+
+ Args:
+ star (str or None, optional): Change or disable the star token. (default: ``"*"``)
+ blank (str, optional): Change the blank token. (default: ``'-'``)
+
+ Returns:
+ Tuple[str, ...]:
+ For models fine-tuned on ASR, returns the tuple of strings representing
+ the output class labels.
+
+ Example
+ >>> from torchaudio.pipelines import MMS_FA as bundle
+ >>> bundle.get_labels()
+ ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x', '*')
+ >>> bundle.get_labels(star=None)
+ ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x')
+ """ # noqa: E501
+ labels = super().get_labels(blank=blank)
+ return labels if star is None else (*labels, star)
+
+ def get_model(self, with_star: bool = True, *, dl_kwargs=None) -> Module:
+ """Construct the model and load the pretrained weight.
+
+ The weight file is downloaded from the internet and cached with
+ :func:`torch.hub.load_state_dict_from_url`
+
+ Args:
+ with_star (bool, optional): If enabled, the last dimension of the output layer is
+ extended by one, which corresponds to the `star` token.
+ dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+ .. note::
+
+ The model created with this method returns probabilities in the log domain
+ (i.e. :py:func:`torch.nn.functional.log_softmax` is applied), whereas
+ the other Wav2Vec2 models return logits.
+ """
+ model = utils._get_model(self._model_type, self._params)
+ state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis)
+ model.load_state_dict(state_dict)
+ model = utils._extend_model(
+ model, normalize_waveform=self._normalize_waveform, apply_log_softmax=True, append_star=with_star
+ )
+ model.eval()
+ return model
+
+ def get_dict(self, star: Optional[str] = "*", blank: str = "-") -> Dict[str, int]:
+ """Get the mapping from token to index (in emission feature dim)
+
+ Args:
+ star (str or None, optional): Change or disable the star token. (default: ``"*"``)
+ blank (str, optional): Change the blank token. (default: ``'-'``)
+
+ Returns:
+ Dict[str, int]:
+ A dictionary mapping each token to its index in the feature
+ dimension of the emission.
+
+ Example
+ >>> from torchaudio.pipelines import MMS_FA as bundle
+ >>> bundle.get_dict()
+ {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27, '*': 28}
+ >>> bundle.get_dict(star=None)
+ {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27}
+ """ # noqa: E501
+ return {k: i for i, k in enumerate(self.get_labels(star=star, blank=blank))}
+
+ def get_tokenizer(self) -> Tokenizer:
+ """Instantiate a Tokenizer.
+
+ Returns:
+ Tokenizer
+ """
+ return aligner.Tokenizer(self.get_dict())
+
+ def get_aligner(self) -> Aligner:
+ """Instantiate an Aligner.
+
+ Returns:
+ Aligner
+ """
+ return aligner.Aligner(blank=0)
+
+
+MMS_FA = Wav2Vec2FABundle(
+ "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt",
+ {
+ "extractor_mode": "layer_norm",
+ "extractor_conv_layer_config": [
+ (512, 10, 5),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 3, 2),
+ (512, 2, 2),
+ (512, 2, 2),
+ ],
+ "extractor_conv_bias": True,
+ "encoder_embed_dim": 1024,
+ "encoder_projection_dropout": 0.0,
+ "encoder_pos_conv_kernel": 128,
+ "encoder_pos_conv_groups": 16,
+ "encoder_num_layers": 24,
+ "encoder_num_heads": 16,
+ "encoder_attention_dropout": 0.0,
+ "encoder_ff_interm_features": 4096,
+ "encoder_ff_interm_dropout": 0.1,
+ "encoder_dropout": 0.0,
+ "encoder_layer_norm_first": True,
+ "encoder_layer_drop": 0.1,
+ "aux_num_out": 28,
+ },
+ _labels=utils._get_mms_labels(),
+ _sample_rate=16000,
+ _normalize_waveform=True,
+ _model_type="Wav2Vec2",
+)
+MMS_FA.__doc__ = """
+Trained on 31K hours of data in 1,130 languages from *Scaling Speech Technology to 1,000+ Languages* :cite:`pratap2023scaling`.
+
+Published by the authors of *Scaling Speech Technology to 1,000+ Languages* :cite:`pratap2023scaling` under [`CC-BY-NC 4.0 License `__].
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2FABundle` for usage details.
+
+.. note::
+
+ Unlike other Wav2Vec2 bundles, this model does not have a token for word boundary (like `|`). This makes the post-processing of alignments slightly different.
+""" # noqa: E501
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/utils.py b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e690e8103c7a47a01d719e746e6c98a9c7f6c8db
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/_wav2vec2/utils.py
@@ -0,0 +1,346 @@
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from torchaudio._internal import load_state_dict_from_url
+from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
+
+
+def _get_model(type_, params):
+ factories = {
+ "Wav2Vec2": wav2vec2_model,
+ "WavLM": wavlm_model,
+ }
+ if type_ not in factories:
+ raise ValueError(f"Supported model types are {tuple(factories.keys())}. Found: {type_}")
+ factory = factories[type_]
+ return factory(**params)
+
+
+class _Wav2Vec2Model(nn.Module):
+ """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+ This optionally applies layer normalization to the input waveform, log-softmax to the
+ output, and appends a "star" dimension to the emission.
+ """
+
+ def __init__(self, model: Wav2Vec2Model, normalize_waveform: bool, apply_log_softmax: bool, append_star: bool):
+ super().__init__()
+ self.model = model
+ self.normalize_waveform = normalize_waveform
+ self.apply_log_softmax = apply_log_softmax
+ self.append_star = append_star
+
+ def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+ if self.normalize_waveform:
+ waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+ output, output_lengths = self.model(waveforms, lengths)
+ if self.apply_log_softmax:
+ output = torch.nn.functional.log_softmax(output, dim=-1)
+ if self.append_star:
+ # Append one zero-valued "star" channel per batch element and frame.
+ star_dim = torch.zeros((output.size(0), output.size(1), 1), dtype=output.dtype, device=output.device)
+ output = torch.cat((output, star_dim), dim=-1)
+ return output, output_lengths
+
+ @torch.jit.export
+ def extract_features(
+ self,
+ waveforms: Tensor,
+ lengths: Optional[Tensor] = None,
+ num_layers: Optional[int] = None,
+ ) -> Tuple[List[Tensor], Optional[Tensor]]:
+ if self.normalize_waveform:
+ waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+ return self.model.extract_features(waveforms, lengths, num_layers)
+
+
+def _extend_model(module, normalize_waveform, apply_log_softmax=False, append_star=False):
+ """Add extra transformations to the model"""
+ return _Wav2Vec2Model(module, normalize_waveform, apply_log_softmax, append_star)
+
+
+def _remove_aux_axes(state_dict, axes):
+ # Remove the seemingly unnecessary axes
+ # For the ASR task, pretrained weights originating from fairseq have unrelated dimensions at indices 1, 2, 3.
+ # They come from fairseq's Dictionary implementation, which was intended for NLP tasks
+ # and is not used during ASR training.
+ # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
+ # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
+ #
+ # Also, some pretrained weights originating from voxpopuli have an extra dimension that is
+ # almost never used and that appears to be a mistake:
+ # the label `1` shows up in the training dataset of German (1 out of 16M),
+ # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M).
+ for key in ["aux.weight", "aux.bias"]:
+ mat = state_dict[key]
+ state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])
+
+
+def _get_state_dict(url, dl_kwargs, remove_axes=None):
+ if not url.startswith("https"):
+ url = f"https://download.pytorch.org/torchaudio/models/{url}"
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
+ if remove_axes:
+ _remove_aux_axes(state_dict, remove_axes)
+ return state_dict
+
+
+def _get_en_labels():
+ return (
+ "|",
+ "E",
+ "T",
+ "A",
+ "O",
+ "N",
+ "I",
+ "H",
+ "S",
+ "R",
+ "D",
+ "L",
+ "U",
+ "M",
+ "W",
+ "C",
+ "F",
+ "G",
+ "Y",
+ "P",
+ "B",
+ "V",
+ "K",
+ "'",
+ "X",
+ "J",
+ "Q",
+ "Z",
+ )
+
+
+def _get_de_labels():
+ return (
+ "|",
+ "e",
+ "n",
+ "i",
+ "r",
+ "s",
+ "t",
+ "a",
+ "d",
+ "h",
+ "u",
+ "l",
+ "g",
+ "c",
+ "m",
+ "o",
+ "b",
+ "w",
+ "f",
+ "k",
+ "z",
+ "p",
+ "v",
+ "ü",
+ "ä",
+ "ö",
+ "j",
+ "ß",
+ "y",
+ "x",
+ "q",
+ )
+
+
+def _get_vp_en_labels():
+ return (
+ "|",
+ "e",
+ "t",
+ "o",
+ "i",
+ "a",
+ "n",
+ "s",
+ "r",
+ "h",
+ "l",
+ "d",
+ "c",
+ "u",
+ "m",
+ "p",
+ "f",
+ "g",
+ "w",
+ "y",
+ "b",
+ "v",
+ "k",
+ "x",
+ "j",
+ "q",
+ "z",
+ )
+
+
+def _get_es_labels():
+ return (
+ "|",
+ "e",
+ "a",
+ "o",
+ "s",
+ "n",
+ "r",
+ "i",
+ "l",
+ "d",
+ "c",
+ "t",
+ "u",
+ "p",
+ "m",
+ "b",
+ "q",
+ "y",
+ "g",
+ "v",
+ "h",
+ "ó",
+ "f",
+ "í",
+ "á",
+ "j",
+ "z",
+ "ñ",
+ "é",
+ "x",
+ "ú",
+ "k",
+ "w",
+ "ü",
+ )
+
+
+def _get_fr_labels():
+ return (
+ "|",
+ "e",
+ "s",
+ "n",
+ "i",
+ "t",
+ "r",
+ "a",
+ "o",
+ "u",
+ "l",
+ "d",
+ "c",
+ "p",
+ "m",
+ "é",
+ "v",
+ "q",
+ "f",
+ "g",
+ "b",
+ "h",
+ "x",
+ "à",
+ "j",
+ "è",
+ "y",
+ "ê",
+ "z",
+ "ô",
+ "k",
+ "ç",
+ "œ",
+ "û",
+ "ù",
+ "î",
+ "â",
+ "w",
+ "ï",
+ "ë",
+ "ü",
+ "æ",
+ )
+
+
+def _get_it_labels():
+ return (
+ "|",
+ "e",
+ "i",
+ "a",
+ "o",
+ "n",
+ "t",
+ "r",
+ "l",
+ "s",
+ "c",
+ "d",
+ "u",
+ "p",
+ "m",
+ "g",
+ "v",
+ "h",
+ "z",
+ "f",
+ "b",
+ "q",
+ "à",
+ "è",
+ "ù",
+ "é",
+ "ò",
+ "ì",
+ "k",
+ "y",
+ "x",
+ "w",
+ "j",
+ "ó",
+ "í",
+ "ï",
+ )
+
+
+def _get_mms_labels():
+ return (
+ "a",
+ "i",
+ "e",
+ "n",
+ "o",
+ "u",
+ "t",
+ "s",
+ "r",
+ "m",
+ "k",
+ "l",
+ "d",
+ "g",
+ "h",
+ "y",
+ "b",
+ "p",
+ "w",
+ "c",
+ "v",
+ "j",
+ "z",
+ "f",
+ "'",
+ "q",
+ "x",
+ )
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/pipelines/rnnt_pipeline.py b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/rnnt_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..11b5a479f3785241a00313a85ead1405b2f673cd
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/pipelines/rnnt_pipeline.py
@@ -0,0 +1,380 @@
+import json
+import math
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, List, Tuple
+
+import torch
+import torchaudio
+from torchaudio._internal import module_utils
+from torchaudio.models import emformer_rnnt_base, RNNT, RNNTBeamSearch
+
+
+__all__ = []
+
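+# ``_gain`` works out to ``torch.iinfo(torch.int16).max ** 2``: the decibel value is
+# ``2 * 20 * log10(M)``, and ``10 ** (0.05 * dB)`` therefore equals ``M ** 2``. It is
+# used as a linear scale factor during feature extraction.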
+_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
+_gain = pow(10, 0.05 * _decibel)
+
+
+def _piecewise_linear_log(x):
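+ """Piecewise-linear log: ``log(x)`` for ``x > e``, ``x / e`` otherwise. Modifies ``x`` in place."""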
+ x[x > math.e] = torch.log(x[x > math.e])
+ x[x <= math.e] = x[x <= math.e] / math.e
+ return x
+
+
+class _FunctionalModule(torch.nn.Module):
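+ """Wraps a callable as an ``nn.Module`` so it can be composed with other modules (e.g. in ``torch.nn.Sequential``)."""
+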
+ def __init__(self, functional):
+ super().__init__()
+ self.functional = functional
+
+ def forward(self, input):
+ return self.functional(input)
+
+
+class _GlobalStatsNormalization(torch.nn.Module):
+ def __init__(self, global_stats_path):
+ super().__init__()
+
+ with open(global_stats_path) as f:
+ blob = json.loads(f.read())
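+ # The stats file is expected to be a JSON object of the form {"mean": [...], "invstddev": [...]}.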
+
+ self.register_buffer("mean", torch.tensor(blob["mean"]))
+ self.register_buffer("invstddev", torch.tensor(blob["invstddev"]))
+
+ def forward(self, input):
+ return (input - self.mean) * self.invstddev
+
+
+class _FeatureExtractor(ABC):
+ @abstractmethod
+ def __call__(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Generates features and length output from the given input tensor.
+
+ Args:
+ input (torch.Tensor): input tensor.
+
+ Returns:
+ (torch.Tensor, torch.Tensor):
+ torch.Tensor:
+ Features, with shape `(length, *)`.
+ torch.Tensor:
+ Length, with shape `(1,)`.
+ """
+
+
+class _TokenProcessor(ABC):
+ @abstractmethod
+ def __call__(self, tokens: List[int], **kwargs) -> str:
+ """Decodes given list of tokens to text sequence.
+
+ Args:
+ tokens (List[int]): list of tokens to decode.
+
+ Returns:
+ str:
+ Decoded text sequence.
+ """
+
+
+class _ModuleFeatureExtractor(torch.nn.Module, _FeatureExtractor):
+ """``torch.nn.Module``-based feature extraction pipeline.
+
+ Args:
+ pipeline (torch.nn.Module): module that implements feature extraction logic.
+ """
+
+ def __init__(self, pipeline: torch.nn.Module) -> None:
+ super().__init__()
+ self.pipeline = pipeline
+
+ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Generates features and length output from the given input tensor.
+
+ Args:
+ input (torch.Tensor): input tensor.
+
+ Returns:
+ (torch.Tensor, torch.Tensor):
+ torch.Tensor:
+ Features, with shape `(length, *)`.
+ torch.Tensor:
+ Length, with shape `(1,)`.
+ """
+ features = self.pipeline(input)
+ length = torch.tensor([features.shape[0]])
+ return features, length
+
+
+class _SentencePieceTokenProcessor(_TokenProcessor):
+ """SentencePiece-model-based token processor.
+
+ Args:
+ sp_model_path (str): path to SentencePiece model.
+ """
+
+ def __init__(self, sp_model_path: str) -> None:
+ if not module_utils.is_module_available("sentencepiece"):
+ raise RuntimeError("SentencePiece is not available. Please install it.")
+
+ import sentencepiece as spm
+
+ self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
+ self.post_process_remove_list = {
+ self.sp_model.unk_id(),
+ self.sp_model.eos_id(),
+ self.sp_model.pad_id(),
+ }
+
+ def __call__(self, tokens: List[int], lstrip: bool = True) -> str:
+ """Decodes given list of tokens to text sequence.
+
+ Args:
+ tokens (List[int]): list of tokens to decode.
+ lstrip (bool, optional): if ``True``, returns text sequence with leading whitespace
+ removed. (Default: ``True``).
+
+ Returns:
+ str:
+ Decoded text sequence.
+ """
+ filtered_hypo_tokens = [
+ token_index for token_index in tokens[1:] if token_index not in self.post_process_remove_list
+ ]
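+ # "\u2581" is the SentencePiece meta symbol marking word boundaries; map it back to spaces.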
+ output_string = "".join(self.sp_model.id_to_piece(filtered_hypo_tokens)).replace("\u2581", " ")
+
+ if lstrip:
+ return output_string.lstrip()
+ else:
+ return output_string
+
+
+@dataclass
+class RNNTBundle:
+ """Dataclass that bundles components for performing automatic speech recognition (ASR, speech-to-text)
+ inference with an RNN-T model.
+
+ More specifically, the class provides methods that produce the featurization pipeline,
+ a decoder wrapping the specified RNN-T model, and an output token post-processor, which together
+ constitute a complete end-to-end ASR inference pipeline that produces a text sequence
+ given a raw waveform.
+
+ It supports both non-streaming (full-context) and streaming inference.
+
+ Users should not directly instantiate objects of this class; rather, users should use the
+ instances (representing pre-trained models) that exist within the module,
+ e.g. :data:`torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH`.
+
+ Example
+ >>> import torchaudio
+ >>> from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
+ >>> import torch
+ >>>
+ >>> # Non-streaming inference.
+ >>> # Build feature extractor, decoder with RNN-T model, and token processor.
+ >>> feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()
+ 100%|███████████████████████████████| 3.81k/3.81k [00:00<00:00, 4.22MB/s]
+ >>> decoder = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()
+ Downloading: "https://download.pytorch.org/torchaudio/models/emformer_rnnt_base_librispeech.pt"
+ 100%|███████████████████████████████| 293M/293M [00:07<00:00, 42.1MB/s]
+ >>> token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()
+ 100%|███████████████████████████████| 295k/295k [00:00<00:00, 25.4MB/s]
+ >>>
+ >>> # Instantiate LibriSpeech dataset; retrieve waveform for first sample.
+ >>> dataset = torchaudio.datasets.LIBRISPEECH("/home/librispeech", url="test-clean")
+ >>> waveform = next(iter(dataset))[0].squeeze()
+ >>>
+ >>> with torch.no_grad():
+ >>> # Produce mel-scale spectrogram features.
+ >>> features, length = feature_extractor(waveform)
+ >>>
+ >>> # Generate top-10 hypotheses.
+ >>> hypotheses = decoder(features, length, 10)
+ >>>
+ >>> # For top hypothesis, convert predicted tokens to text.
+ >>> text = token_processor(hypotheses[0][0])
+ >>> print(text)
+ he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to [...]
+ >>>
+ >>>
+ >>> # Streaming inference.
+ >>> hop_length = EMFORMER_RNNT_BASE_LIBRISPEECH.hop_length
+ >>> num_samples_segment = EMFORMER_RNNT_BASE_LIBRISPEECH.segment_length * hop_length
+ >>> num_samples_segment_right_context = (
+ >>> num_samples_segment + EMFORMER_RNNT_BASE_LIBRISPEECH.right_context_length * hop_length
+ >>> )
+ >>>
+ >>> # Build streaming inference feature extractor.
+ >>> streaming_feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_streaming_feature_extractor()
+ >>>
+ >>> # Process same waveform as before, this time sequentially across overlapping segments
+ >>> # to simulate streaming inference. Note the usage of ``streaming_feature_extractor`` and ``decoder.infer``.
+ >>> state, hypothesis = None, None
+ >>> for idx in range(0, len(waveform), num_samples_segment):
+ >>> segment = waveform[idx: idx + num_samples_segment_right_context]
+ >>> segment = torch.nn.functional.pad(segment, (0, num_samples_segment_right_context - len(segment)))
+ >>> with torch.no_grad():
+ >>> features, length = streaming_feature_extractor(segment)
+ >>> hypotheses, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
+ >>> hypothesis = hypotheses[0]
+ >>> transcript = token_processor(hypothesis[0])
+ >>> if transcript:
+ >>> print(transcript, end=" ", flush=True)
+ he hoped there would be stew for dinner turn ips and car rots and bru 'd oes and fat mut ton pieces to [...]
+ """
+
+ class FeatureExtractor(_FeatureExtractor):
+ """Interface of the feature extraction part of RNN-T pipeline"""
+
+ class TokenProcessor(_TokenProcessor):
+ """Interface of the token processor part of RNN-T pipeline"""
+
+ _rnnt_path: str
+ _rnnt_factory_func: Callable[[], RNNT]
+ _global_stats_path: str
+ _sp_model_path: str
+ _right_padding: int
+ _blank: int
+ _sample_rate: int
+ _n_fft: int
+ _n_mels: int
+ _hop_length: int
+ _segment_length: int
+ _right_context_length: int
+
+ def _get_model(self) -> RNNT:
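+ # Instantiate the architecture from its factory function, download the
+ # pretrained checkpoint into torchaudio's local asset cache, load the
+ # weights, and put the model in eval mode for inference.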
+ model = self._rnnt_factory_func()
+ path = torchaudio.utils.download_asset(self._rnnt_path)
+ state_dict = torch.load(path)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ @property
+ def sample_rate(self) -> int:
+ """Sample rate (in cycles per second) of input waveforms.
+
+ :type: int
+ """
+ return self._sample_rate
+
+ @property
+ def n_fft(self) -> int:
+ """Size of FFT window to use.
+
+ :type: int
+ """
+ return self._n_fft
+
+ @property
+ def n_mels(self) -> int:
+ """Number of mel spectrogram features to extract from input waveforms.
+
+ :type: int
+ """
+ return self._n_mels
+
+ @property
+ def hop_length(self) -> int:
+ """Number of samples between successive frames in input expected by model.
+
+ :type: int
+ """
+ return self._hop_length
+
+ @property
+ def segment_length(self) -> int:
+ """Number of frames in segment in input expected by model.
+
+ :type: int
+ """
+ return self._segment_length
+
+ @property
+ def right_context_length(self) -> int:
+ """Number of frames in right contextual block in input expected by model.
+
+ :type: int
+ """
+ return self._right_context_length
+
+ def get_decoder(self) -> RNNTBeamSearch:
+ """Constructs RNN-T decoder.
+
+ Returns:
+ RNNTBeamSearch
+ """
+ model = self._get_model()
+ return RNNTBeamSearch(model, self._blank)
+
+ def get_feature_extractor(self) -> FeatureExtractor:
+ """Constructs feature extractor for non-streaming (full-context) ASR.
+
+ Returns:
+ FeatureExtractor
+ """
+ local_path = torchaudio.utils.download_asset(self._global_stats_path)
+ return _ModuleFeatureExtractor(
+ torch.nn.Sequential(
+ torchaudio.transforms.MelSpectrogram(
+ sample_rate=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, hop_length=self.hop_length
+ ),
+ _FunctionalModule(lambda x: x.transpose(1, 0)),
+ _FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
+ _GlobalStatsNormalization(local_path),
+ _FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 0, 0, self._right_padding))),
+ )
+ )
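+
+ # Reading the pipeline above end to end: for an unbatched waveform of shape
+ # ``(T,)``, ``MelSpectrogram`` yields ``(n_mels, num_frames)``; the transpose
+ # flips this to ``(num_frames, n_mels)``; ``_piecewise_linear_log`` compresses
+ # the gain-scaled magnitudes; ``_GlobalStatsNormalization`` standardizes each
+ # feature dimension using the downloaded dataset statistics; and the final pad
+ # appends ``_right_padding`` frames along the time axis, presumably so the tail
+ # of the utterance carries the trailing context the full-context model expects.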
+
+ def get_streaming_feature_extractor(self) -> FeatureExtractor:
+ """Constructs feature extractor for streaming (simultaneous) ASR.
+
+ Returns:
+ FeatureExtractor
+ """
+ local_path = torchaudio.utils.download_asset(self._global_stats_path)
+ return _ModuleFeatureExtractor(
+ torch.nn.Sequential(
+ torchaudio.transforms.MelSpectrogram(
+ sample_rate=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, hop_length=self.hop_length
+ ),
+ _FunctionalModule(lambda x: x.transpose(1, 0)),
+ _FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
+ _GlobalStatsNormalization(local_path),
+ )
+ )
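+
+ # Identical to the non-streaming extractor except that no trailing padding is
+ # applied: in the streaming loop (see the class docstring), each segment passed
+ # in already includes the right-context samples.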
+
+ def get_token_processor(self) -> TokenProcessor:
+ """Constructs token processor.
+
+ Returns:
+ TokenProcessor
+ """
+ local_path = torchaudio.utils.download_asset(self._sp_model_path)
+ return _SentencePieceTokenProcessor(local_path)
+
+
+EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
+ _rnnt_path="models/emformer_rnnt_base_librispeech.pt",
+ _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=4097),
+ _global_stats_path="pipeline-assets/global_stats_rnnt_librispeech.json",
+ _sp_model_path="pipeline-assets/spm_bpe_4096_librispeech.model",
+ _right_padding=4,
+ _blank=4096,
+ _sample_rate=16000,
+ _n_fft=400,
+ _n_mels=80,
+ _hop_length=160,
+ _segment_length=16,
+ _right_context_length=4,
+)
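+
+# Illustrative arithmetic (not part of the pipeline): the streaming geometry
+# implied by the configuration above at the model's 16 kHz sample rate.
+#
+# samples_per_segment = 16 * 160    # segment_length * hop_length = 2560 samples (160 ms)
+# right_context_samples = 4 * 160   # right_context_length * hop_length = 640 samples (40 ms)
+# analysis_window_seconds = 400 / 16000    # n_fft of 400 samples -> a 25 ms STFT window
+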
+EMFORMER_RNNT_BASE_LIBRISPEECH.__doc__ = """ASR pipeline based on Emformer-RNNT,
+pretrained on the *LibriSpeech* dataset :cite:`7178964`,
+capable of performing both streaming and non-streaming inference.
+
+The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
+and utilizes weights trained on LibriSpeech using the training script ``train.py``
+`here `__ with default arguments.
+
+Please refer to :py:class:`RNNTBundle` for usage instructions.
+"""
diff --git a/.venv/lib/python3.11/site-packages/torchaudio/version.py b/.venv/lib/python3.11/site-packages/torchaudio/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dd05cf81ab3c1b8e8b20f9b9ccdd335acebf71a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchaudio/version.py
@@ -0,0 +1,2 @@
+__version__ = '2.5.1+cu124'
+git_version = '1661daf10599ca8889f092ec37814fabbe202bb0'