diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_internal/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/_internal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..363e94f13bb5059ab6888af2fb60314699f1ab1e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/_internal/__init__.py @@ -0,0 +1,10 @@ +try: + from .fb import download_url_to_file, load_state_dict_from_url +except ImportError: + from torch.hub import download_url_to_file, load_state_dict_from_url + + +__all__ = [ + "load_state_dict_from_url", + "download_url_to_file", +] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_internal/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/_internal/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d892c50e0d8acd29ccf3099f7e505722aec2c71 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/_internal/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_internal/__pycache__/module_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/_internal/__pycache__/module_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d35e0c41871d36d977cfb1b7ade9f665428dbfd9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/_internal/__pycache__/module_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/_internal/module_utils.py b/.venv/lib/python3.11/site-packages/torchaudio/_internal/module_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc484104bfd86460c57b2f773146ea1a14e47984 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/_internal/module_utils.py @@ -0,0 +1,113 @@ +import importlib.util +import os +import warnings +from functools import wraps +from typing import Optional + + +def 
eval_env(var, default): + """Check if environment varable has True-y value""" + if var not in os.environ: + return default + + val = os.environ.get(var, "0") + trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"] + falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"] + if val in trues: + return True + if val not in falses: + # fmt: off + raise RuntimeError( + f"Unexpected environment variable value `{var}={val}`. " + f"Expected one of {trues + falses}") + # fmt: on + return False + + +def is_module_available(*modules: str) -> bool: + r"""Returns if a top-level module with :attr:`name` exists *without** + importing it. This is generally safer than try-catch block around a + `import X`. It avoids third party libraries breaking assumptions of some of + our tests, e.g., setting multiprocessing start method when imported + (see librosa/#747, torchvision/#544). + """ + return all(importlib.util.find_spec(m) is not None for m in modules) + + +def requires_module(*modules: str): + """Decorate function to give error message if invoked without required optional modules. + + This decorator is to give better error message to users rather + than raising ``NameError: name 'module' is not defined`` at random places. + """ + missing = [m for m in modules if not is_module_available(m)] + + if not missing: + # fall through. If all the modules are available, no need to decorate + def decorator(func): + return func + + else: + req = f"module: {missing[0]}" if len(missing) == 1 else f"modules: {missing}" + + def decorator(func): + @wraps(func) + def wrapped(*args, **kwargs): + raise RuntimeError(f"{func.__module__}.{func.__name__} requires {req}") + + return wrapped + + return decorator + + +def deprecated(direction: str, version: Optional[str] = None, remove: bool = False): + """Decorator to add deprecation message + + Args: + direction (str): Migration steps to be given to users. 
+ version (str or int): The version when the object will be removed + remove (bool): If enabled, append future removal message. + """ + + def decorator(func): + @wraps(func) + def wrapped(*args, **kwargs): + message = f"{func.__module__}.{func.__name__} has been deprecated. {direction}" + if remove: + message += f' It will be removed from {"future" if version is None else version} release. ' + warnings.warn(message, stacklevel=2) + return func(*args, **kwargs) + + message = "This function has been deprecated. " + if remove: + message += f'It will be removed from {"future" if version is None else version} release. ' + + wrapped.__doc__ = f"""DEPRECATED: {func.__doc__} + + .. warning:: + + {message} + {direction} + """ + + return wrapped + + return decorator + + +def fail_with_message(message): + """Generate decorator to give users message about missing TorchAudio extension.""" + + def decorator(func): + @wraps(func) + def wrapped(*args, **kwargs): + raise RuntimeError(f"{func.__module__}.{func.__name__} {message}") + + return wrapped + + return decorator + + +def no_op(func): + """Op-op decorator. 
Used in place of fail_with_message when a functionality that requires extension works fine.""" + return func diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..609cb14fdcc38c48270acd5803f4bfe603c39e71 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/__init__.py @@ -0,0 +1,47 @@ +from .cmuarctic import CMUARCTIC +from .cmudict import CMUDict +from .commonvoice import COMMONVOICE +from .dr_vctk import DR_VCTK +from .fluentcommands import FluentSpeechCommands +from .gtzan import GTZAN +from .iemocap import IEMOCAP +from .librilight_limited import LibriLightLimited +from .librimix import LibriMix +from .librispeech import LIBRISPEECH +from .librispeech_biasing import LibriSpeechBiasing +from .libritts import LIBRITTS +from .ljspeech import LJSPEECH +from .musdb_hq import MUSDB_HQ +from .quesst14 import QUESST14 +from .snips import Snips +from .speechcommands import SPEECHCOMMANDS +from .tedlium import TEDLIUM +from .vctk import VCTK_092 +from .voxceleb1 import VoxCeleb1Identification, VoxCeleb1Verification +from .yesno import YESNO + + +__all__ = [ + "COMMONVOICE", + "LIBRISPEECH", + "LibriSpeechBiasing", + "LibriLightLimited", + "SPEECHCOMMANDS", + "VCTK_092", + "DR_VCTK", + "YESNO", + "LJSPEECH", + "GTZAN", + "CMUARCTIC", + "CMUDict", + "LibriMix", + "LIBRITTS", + "TEDLIUM", + "QUESST14", + "MUSDB_HQ", + "FluentSpeechCommands", + "VoxCeleb1Identification", + "VoxCeleb1Verification", + "IEMOCAP", + "Snips", +] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librilight_limited.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librilight_limited.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2eeca3b7fd4f608ae24ca1148a9c5240a53a4a6 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librilight_limited.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librimix.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librimix.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48a5894e8a59ae0ad652b42c57ee20195263d549 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librimix.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librispeech_biasing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librispeech_biasing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27430a6dbb46f021fa8210f0e3278b35fd97d46d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librispeech_biasing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/libritts.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/libritts.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75584f9adb475c51e1211c8e60902e68355b8162 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/libritts.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/cmuarctic.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/cmuarctic.py new file mode 100644 index 0000000000000000000000000000000000000000..96f498f00f04a2f1b6d7d3e33510eafcb9ffe6bc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/cmuarctic.py @@ -0,0 +1,157 @@ +import csv +import os +from pathlib import Path +from typing import Tuple, Union + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from 
torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_tar + +URL = "aew" +FOLDER_IN_ARCHIVE = "ARCTIC" +_CHECKSUMS = { + "http://festvox.org/cmu_arctic/packed/cmu_us_aew_arctic.tar.bz2": "645cb33c0f0b2ce41384fdd8d3db2c3f5fc15c1e688baeb74d2e08cab18ab406", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_ahw_arctic.tar.bz2": "024664adeb892809d646a3efd043625b46b5bfa3e6189b3500b2d0d59dfab06c", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_aup_arctic.tar.bz2": "2c55bc3050caa996758869126ad10cf42e1441212111db034b3a45189c18b6fc", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_awb_arctic.tar.bz2": "d74a950c9739a65f7bfc4dfa6187f2730fa03de5b8eb3f2da97a51b74df64d3c", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_axb_arctic.tar.bz2": "dd65c3d2907d1ee52f86e44f578319159e60f4bf722a9142be01161d84e330ff", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_bdl_arctic.tar.bz2": "26b91aaf48b2799b2956792b4632c2f926cd0542f402b5452d5adecb60942904", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_clb_arctic.tar.bz2": "3f16dc3f3b97955ea22623efb33b444341013fc660677b2e170efdcc959fa7c6", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_eey_arctic.tar.bz2": "8a0ee4e5acbd4b2f61a4fb947c1730ab3adcc9dc50b195981d99391d29928e8a", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_fem_arctic.tar.bz2": "3fcff629412b57233589cdb058f730594a62c4f3a75c20de14afe06621ef45e2", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_gka_arctic.tar.bz2": "dc82e7967cbd5eddbed33074b0699128dbd4482b41711916d58103707e38c67f", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_jmk_arctic.tar.bz2": "3a37c0e1dfc91e734fdbc88b562d9e2ebca621772402cdc693bbc9b09b211d73", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_ksp_arctic.tar.bz2": "8029cafce8296f9bed3022c44ef1e7953332b6bf6943c14b929f468122532717", # noqa: E501 + 
"http://festvox.org/cmu_arctic/packed/cmu_us_ljm_arctic.tar.bz2": "b23993765cbf2b9e7bbc3c85b6c56eaf292ac81ee4bb887b638a24d104f921a0", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_lnh_arctic.tar.bz2": "4faf34d71aa7112813252fb20c5433e2fdd9a9de55a00701ffcbf05f24a5991a", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_rms_arctic.tar.bz2": "c6dc11235629c58441c071a7ba8a2d067903dfefbaabc4056d87da35b72ecda4", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_rxr_arctic.tar.bz2": "1fa4271c393e5998d200e56c102ff46fcfea169aaa2148ad9e9469616fbfdd9b", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_slp_arctic.tar.bz2": "54345ed55e45c23d419e9a823eef427f1cc93c83a710735ec667d068c916abf1", # noqa: E501 + "http://festvox.org/cmu_arctic/packed/cmu_us_slt_arctic.tar.bz2": "7c173297916acf3cc7fcab2713be4c60b27312316765a90934651d367226b4ea", # noqa: E501 +} + + +def load_cmuarctic_item(line: str, path: str, folder_audio: str, ext_audio: str) -> Tuple[Tensor, int, str, str]: + + utterance_id, transcript = line[0].strip().split(" ", 2)[1:] + + # Remove space, double quote, and single parenthesis from transcript + transcript = transcript[1:-3] + + file_audio = os.path.join(path, folder_audio, utterance_id + ext_audio) + + # Load audio + waveform, sample_rate = torchaudio.load(file_audio) + + return (waveform, sample_rate, transcript, utterance_id.split("_")[1]) + + +class CMUARCTIC(Dataset): + """*CMU ARCTIC* :cite:`Kominek03cmuarctic` dataset. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): + The URL to download the dataset from or the type of the dataset to download. + (default: ``"aew"``) + Allowed type values are ``"aew"``, ``"ahw"``, ``"aup"``, ``"awb"``, ``"axb"``, ``"bdl"``, + ``"clb"``, ``"eey"``, ``"fem"``, ``"gka"``, ``"jmk"``, ``"ksp"``, ``"ljm"``, ``"lnh"``, + ``"rms"``, ``"rxr"``, ``"slp"`` or ``"slt"``. 
+ folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"ARCTIC"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + """ + + _file_text = "txt.done.data" + _folder_text = "etc" + _ext_audio = ".wav" + _folder_audio = "wav" + + def __init__( + self, root: Union[str, Path], url: str = URL, folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False + ) -> None: + + if url in [ + "aew", + "ahw", + "aup", + "awb", + "axb", + "bdl", + "clb", + "eey", + "fem", + "gka", + "jmk", + "ksp", + "ljm", + "lnh", + "rms", + "rxr", + "slp", + "slt", + ]: + + url = "cmu_us_" + url + "_arctic" + ext_archive = ".tar.bz2" + base_url = "http://www.festvox.org/cmu_arctic/packed/" + + url = os.path.join(base_url, url + ext_archive) + + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + + basename = os.path.basename(url) + root = os.path.join(root, folder_in_archive) + if not os.path.isdir(root): + os.mkdir(root) + archive = os.path.join(root, basename) + + basename = basename.split(".")[0] + + self._path = os.path.join(root, basename) + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(url, None) + download_url_to_file(url, archive, hash_prefix=checksum) + _extract_tar(archive) + else: + if not os.path.exists(self._path): + raise RuntimeError( + f"The path {self._path} doesn't exist. " + "Please check the ``root`` path or set `download=True` to download it" + ) + self._text = os.path.join(self._path, self._folder_text, self._file_text) + + with open(self._text, "r") as text: + walker = csv.reader(text, delimiter="\n") + self._walker = list(walker) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Transcript + str: + Utterance ID + """ + line = self._walker[n] + return load_cmuarctic_item(line, self._path, self._folder_audio, self._ext_audio) + + def __len__(self) -> int: + return len(self._walker) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/cmudict.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/cmudict.py new file mode 100644 index 0000000000000000000000000000000000000000..d1038f48badde6f5db589691c5aee2ddf1f1d5de --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/cmudict.py @@ -0,0 +1,186 @@ +import os +import re +from pathlib import Path +from typing import Iterable, List, Tuple, Union + +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file + + +_CHECKSUMS = { + "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b": "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4", # noqa: E501 + "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols": "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027", # noqa: E501 +} +_PUNCTUATIONS = { + "!EXCLAMATION-POINT", + '"CLOSE-QUOTE', + '"DOUBLE-QUOTE', + '"END-OF-QUOTE', + '"END-QUOTE', + '"IN-QUOTES', + '"QUOTE', + '"UNQUOTE', + "#HASH-MARK", + "#POUND-SIGN", + "#SHARP-SIGN", + "%PERCENT", + "&ERSAND", + "'END-INNER-QUOTE", + "'END-QUOTE", + "'INNER-QUOTE", + "'QUOTE", + "'SINGLE-QUOTE", + "(BEGIN-PARENS", + "(IN-PARENTHESES", + "(LEFT-PAREN", + "(OPEN-PARENTHESES", + "(PAREN", + "(PARENS", + "(PARENTHESES", + ")CLOSE-PAREN", + ")CLOSE-PARENTHESES", + ")END-PAREN", + ")END-PARENS", + ")END-PARENTHESES", + ")END-THE-PAREN", + ")PAREN", + ")PARENS", + ")RIGHT-PAREN", + ")UN-PARENTHESES", + "+PLUS", + ",COMMA", + "--DASH", + "-DASH", + "-HYPHEN", + "...ELLIPSIS", + ".DECIMAL", + ".DOT", + ".FULL-STOP", + 
".PERIOD", + ".POINT", + "/SLASH", + ":COLON", + ";SEMI-COLON", + ";SEMI-COLON(1)", + "?QUESTION-MARK", + "{BRACE", + "{LEFT-BRACE", + "{OPEN-BRACE", + "}CLOSE-BRACE", + "}RIGHT-BRACE", +} + + +def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[str]: + _alt_re = re.compile(r"\([0-9]+\)") + cmudict: List[Tuple[str, List[str]]] = [] + for line in lines: + if not line or line.startswith(";;;"): # ignore comments + continue + + word, phones = line.strip().split(" ") + if word in _PUNCTUATIONS: + if exclude_punctuations: + continue + # !EXCLAMATION-POINT -> ! + # --DASH -> -- + # ...ELLIPSIS -> ... + if word.startswith("..."): + word = "..." + elif word.startswith("--"): + word = "--" + else: + word = word[0] + + # if a word have multiple pronunciations, there will be (number) appended to it + # for example, DATAPOINTS and DATAPOINTS(1), + # the regular expression `_alt_re` removes the '(1)' and change the word DATAPOINTS(1) to DATAPOINTS + word = re.sub(_alt_re, "", word) + phones = phones.split(" ") + cmudict.append((word, phones)) + + return cmudict + + +class CMUDict(Dataset): + """*CMU Pronouncing Dictionary* :cite:`cmudict` (CMUDict) dataset. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + exclude_punctuations (bool, optional): + When enabled, exclude the pronounciation of punctuations, such as + `!EXCLAMATION-POINT` and `#HASH-MARK`. + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + url (str, optional): + The URL to download the dictionary from. + (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``) + url_symbols (str, optional): + The URL to download the list of symbols from. 
+ (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``) + """ + + def __init__( + self, + root: Union[str, Path], + exclude_punctuations: bool = True, + *, + download: bool = False, + url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b", + url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols", + ) -> None: + + self.exclude_punctuations = exclude_punctuations + + self._root_path = Path(root) + if not os.path.isdir(self._root_path): + raise RuntimeError(f"The root directory does not exist; {root}") + + dict_file = self._root_path / os.path.basename(url) + symbol_file = self._root_path / os.path.basename(url_symbols) + if not os.path.exists(dict_file): + if not download: + raise RuntimeError( + "The dictionary file is not found in the following location. " + f"Set `download=True` to download it. {dict_file}" + ) + checksum = _CHECKSUMS.get(url, None) + download_url_to_file(url, dict_file, checksum) + if not os.path.exists(symbol_file): + if not download: + raise RuntimeError( + "The symbol file is not found in the following location. " + f"Set `download=True` to download it. {symbol_file}" + ) + checksum = _CHECKSUMS.get(url_symbols, None) + download_url_to_file(url_symbols, symbol_file, checksum) + + with open(symbol_file, "r") as text: + self._symbols = [line.strip() for line in text.readlines()] + + with open(dict_file, "r", encoding="latin-1") as text: + self._dictionary = _parse_dictionary(text.readlines(), exclude_punctuations=self.exclude_punctuations) + + def __getitem__(self, n: int) -> Tuple[str, List[str]]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded. 
+ + Returns: + Tuple of a word and its phonemes + + str: + Word + List[str]: + Phonemes + """ + return self._dictionary[n] + + def __len__(self) -> int: + return len(self._dictionary) + + @property + def symbols(self) -> List[str]: + """list[str]: A list of phonemes symbols, such as ``"AA"``, ``"AE"``, ``"AH"``.""" + return self._symbols.copy() diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/commonvoice.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/commonvoice.py new file mode 100644 index 0000000000000000000000000000000000000000..db0e035c6116487a87efcffaeea31a19212be458 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/commonvoice.py @@ -0,0 +1,86 @@ +import csv +import os +from pathlib import Path +from typing import Dict, List, Tuple, Union + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset + + +def load_commonvoice_item( + line: List[str], header: List[str], path: str, folder_audio: str, ext_audio: str +) -> Tuple[Tensor, int, Dict[str, str]]: + # Each line as the following data: + # client_id, path, sentence, up_votes, down_votes, age, gender, accent + + if header[1] != "path": + raise ValueError(f"expect `header[1]` to be 'path', but got {header[1]}") + fileid = line[1] + filename = os.path.join(path, folder_audio, fileid) + if not filename.endswith(ext_audio): + filename += ext_audio + waveform, sample_rate = torchaudio.load(filename) + + dic = dict(zip(header, line)) + + return waveform, sample_rate, dic + + +class COMMONVOICE(Dataset): + """*CommonVoice* :cite:`ardila2020common` dataset. + + Args: + root (str or Path): Path to the directory where the dataset is located. + (Where the ``tsv`` file is present.) + tsv (str, optional): + The name of the tsv file used to construct the metadata, such as + ``"train.tsv"``, ``"test.tsv"``, ``"dev.tsv"``, ``"invalidated.tsv"``, + ``"validated.tsv"`` and ``"other.tsv"``. 
(default: ``"train.tsv"``) + """ + + _ext_txt = ".txt" + _ext_audio = ".mp3" + _folder_audio = "clips" + + def __init__(self, root: Union[str, Path], tsv: str = "train.tsv") -> None: + + # Get string representation of 'root' in case Path object is passed + self._path = os.fspath(root) + self._tsv = os.path.join(self._path, tsv) + + with open(self._tsv, "r") as tsv_: + walker = csv.reader(tsv_, delimiter="\t") + self._header = next(walker) + self._walker = list(walker) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, Dict[str, str]]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + Dict[str, str]: + Dictionary containing the following items from the corresponding TSV file; + + * ``"client_id"`` + * ``"path"`` + * ``"sentence"`` + * ``"up_votes"`` + * ``"down_votes"`` + * ``"age"`` + * ``"gender"`` + * ``"accent"`` + """ + line = self._walker[n] + return load_commonvoice_item(line, self._header, self._path, self._folder_audio, self._ext_audio) + + def __len__(self) -> int: + return len(self._walker) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/dr_vctk.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/dr_vctk.py new file mode 100644 index 0000000000000000000000000000000000000000..a634b968949480738eefef926d25b05846f0b67d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/dr_vctk.py @@ -0,0 +1,121 @@ +from pathlib import Path +from typing import Dict, Tuple, Union + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_zip + + +_URL = "https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip" +_CHECKSUM = "781f12f4406ed36ed27ae3bce55da47ba176e2d8bae67319e389e07b2c9bd769" +_SUPPORTED_SUBSETS = {"train", "test"} + + +class 
DR_VCTK(Dataset): + """*Device Recorded VCTK (Small subset version)* :cite:`Sarfjoo2018DeviceRV` dataset. + + Args: + root (str or Path): Root directory where the dataset's top level directory is found. + subset (str): The subset to use. Can be one of ``"train"`` and ``"test"``. (default: ``"train"``). + download (bool): + Whether to download the dataset if it is not found at root path. (default: ``False``). + url (str): The URL to download the dataset from. + (default: ``"https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"``) + """ + + def __init__( + self, + root: Union[str, Path], + subset: str = "train", + *, + download: bool = False, + url: str = _URL, + ) -> None: + if subset not in _SUPPORTED_SUBSETS: + raise RuntimeError( + f"The subset '{subset}' does not match any of the supported subsets: {_SUPPORTED_SUBSETS}" + ) + + root = Path(root).expanduser() + archive = root / "DR-VCTK.zip" + + self._subset = subset + self._path = root / "DR-VCTK" / "DR-VCTK" + self._clean_audio_dir = self._path / f"clean_{self._subset}set_wav_16k" + self._noisy_audio_dir = self._path / f"device-recorded_{self._subset}set_wav_16k" + self._config_filepath = self._path / "configurations" / f"{self._subset}_ch_log.txt" + + if not self._path.is_dir(): + if not archive.is_file(): + if not download: + raise RuntimeError("Dataset not found. 
Please use `download=True` to download it.") + download_url_to_file(url, archive, hash_prefix=_CHECKSUM) + _extract_zip(archive, root) + + self._config = self._load_config(self._config_filepath) + self._filename_list = sorted(self._config) + + def _load_config(self, filepath: str) -> Dict[str, Tuple[str, int]]: + # Skip header + skip_rows = 2 if self._subset == "train" else 1 + + config = {} + with open(filepath) as f: + for i, line in enumerate(f): + if i < skip_rows or not line: + continue + filename, source, channel_id = line.strip().split("\t") + config[filename] = (source, int(channel_id)) + return config + + def _load_dr_vctk_item(self, filename: str) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]: + speaker_id, utterance_id = filename.split(".")[0].split("_") + source, channel_id = self._config[filename] + file_clean_audio = self._clean_audio_dir / filename + file_noisy_audio = self._noisy_audio_dir / filename + waveform_clean, sample_rate_clean = torchaudio.load(file_clean_audio) + waveform_noisy, sample_rate_noisy = torchaudio.load(file_noisy_audio) + return ( + waveform_clean, + sample_rate_clean, + waveform_noisy, + sample_rate_noisy, + speaker_id, + utterance_id, + source, + channel_id, + ) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Clean waveform + int: + Sample rate of the clean waveform + Tensor: + Noisy waveform + int: + Sample rate of the noisy waveform + str: + Speaker ID + str: + Utterance ID + str: + Source + int: + Channel ID + """ + filename = self._filename_list[n] + return self._load_dr_vctk_item(filename) + + def __len__(self) -> int: + return len(self._filename_list) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/fluentcommands.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/fluentcommands.py new file mode 100644 index 0000000000000000000000000000000000000000..5cdee398d6e31a5e622321d1f73177606d9c8640 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/fluentcommands.py @@ -0,0 +1,108 @@ +import csv +import os +from pathlib import Path +from typing import Tuple, Union + +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio.datasets.utils import _load_waveform + +SAMPLE_RATE = 16000 + + +class FluentSpeechCommands(Dataset): + """*Fluent Speech Commands* :cite:`fluent` dataset + + Args: + root (str of Path): Path to the directory where the dataset is found. + subset (str, optional): subset of the dataset to use. + Options: [``"train"``, ``"valid"``, ``"test"``]. 
+ (Default: ``"train"``) + """ + + def __init__(self, root: Union[str, Path], subset: str = "train"): + if subset not in ["train", "valid", "test"]: + raise ValueError("`subset` must be one of ['train', 'valid', 'test']") + + root = os.fspath(root) + self._path = os.path.join(root, "fluent_speech_commands_dataset") + + if not os.path.isdir(self._path): + raise RuntimeError("Dataset not found.") + + subset_path = os.path.join(self._path, "data", f"{subset}_data.csv") + with open(subset_path) as subset_csv: + subset_reader = csv.reader(subset_csv) + data = list(subset_reader) + + self.header = data[0] + self.data = data[1:] + + def get_metadata(self, n: int) -> Tuple[str, int, str, int, str, str, str, str]: + """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform, + but otherwise returns the same fields as :py:func:`__getitem__`. + + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + str: + Path to audio + int: + Sample rate + str: + File name + int: + Speaker ID + str: + Transcription + str: + Action + str: + Object + str: + Location + """ + sample = self.data[n] + + file_name = sample[self.header.index("path")].split("/")[-1] + file_name = file_name.split(".")[0] + speaker_id, transcription, action, obj, location = sample[2:] + file_path = os.path.join("wavs", "speakers", speaker_id, f"{file_name}.wav") + + return file_path, SAMPLE_RATE, file_name, speaker_id, transcription, action, obj, location + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, str, str, str, str]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + File name + int: + Speaker ID + str: + Transcription + str: + Action + str: + Object + str: + Location + """ + metadata = self.get_metadata(n) + waveform = _load_waveform(self._path, metadata[0], metadata[1]) + return (waveform,) + metadata[1:] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/gtzan.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/gtzan.py new file mode 100644 index 0000000000000000000000000000000000000000..347e7e71831770f42d7fdaf0b3c63a09409f659d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/gtzan.py @@ -0,0 +1,1118 @@ +import os +from pathlib import Path +from typing import Optional, Tuple, Union + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_tar + +# The following lists prefixed with `filtered_` provide a filtered split +# that: +# +# a. Mitigate a known issue with GTZAN (duplication) +# +# b. Provide a standard split for testing it against other +# methods (e.g. the one in jordipons/sklearn-audio-transfer-learning). +# +# Those are used when GTZAN is initialised with the `filtered` keyword. +# The split was taken from (github) jordipons/sklearn-audio-transfer-learning. 
+ +gtzan_genres = [ + "blues", + "classical", + "country", + "disco", + "hiphop", + "jazz", + "metal", + "pop", + "reggae", + "rock", +] + +filtered_test = [ + "blues.00012", + "blues.00013", + "blues.00014", + "blues.00015", + "blues.00016", + "blues.00017", + "blues.00018", + "blues.00019", + "blues.00020", + "blues.00021", + "blues.00022", + "blues.00023", + "blues.00024", + "blues.00025", + "blues.00026", + "blues.00027", + "blues.00028", + "blues.00061", + "blues.00062", + "blues.00063", + "blues.00064", + "blues.00065", + "blues.00066", + "blues.00067", + "blues.00068", + "blues.00069", + "blues.00070", + "blues.00071", + "blues.00072", + "blues.00098", + "blues.00099", + "classical.00011", + "classical.00012", + "classical.00013", + "classical.00014", + "classical.00015", + "classical.00016", + "classical.00017", + "classical.00018", + "classical.00019", + "classical.00020", + "classical.00021", + "classical.00022", + "classical.00023", + "classical.00024", + "classical.00025", + "classical.00026", + "classical.00027", + "classical.00028", + "classical.00029", + "classical.00034", + "classical.00035", + "classical.00036", + "classical.00037", + "classical.00038", + "classical.00039", + "classical.00040", + "classical.00041", + "classical.00049", + "classical.00077", + "classical.00078", + "classical.00079", + "country.00030", + "country.00031", + "country.00032", + "country.00033", + "country.00034", + "country.00035", + "country.00036", + "country.00037", + "country.00038", + "country.00039", + "country.00040", + "country.00043", + "country.00044", + "country.00046", + "country.00047", + "country.00048", + "country.00050", + "country.00051", + "country.00053", + "country.00054", + "country.00055", + "country.00056", + "country.00057", + "country.00058", + "country.00059", + "country.00060", + "country.00061", + "country.00062", + "country.00063", + "country.00064", + "disco.00001", + "disco.00021", + "disco.00058", + "disco.00062", + "disco.00063", + 
"disco.00064", + "disco.00065", + "disco.00066", + "disco.00069", + "disco.00076", + "disco.00077", + "disco.00078", + "disco.00079", + "disco.00080", + "disco.00081", + "disco.00082", + "disco.00083", + "disco.00084", + "disco.00085", + "disco.00086", + "disco.00087", + "disco.00088", + "disco.00091", + "disco.00092", + "disco.00093", + "disco.00094", + "disco.00096", + "disco.00097", + "disco.00099", + "hiphop.00000", + "hiphop.00026", + "hiphop.00027", + "hiphop.00030", + "hiphop.00040", + "hiphop.00043", + "hiphop.00044", + "hiphop.00045", + "hiphop.00051", + "hiphop.00052", + "hiphop.00053", + "hiphop.00054", + "hiphop.00062", + "hiphop.00063", + "hiphop.00064", + "hiphop.00065", + "hiphop.00066", + "hiphop.00067", + "hiphop.00068", + "hiphop.00069", + "hiphop.00070", + "hiphop.00071", + "hiphop.00072", + "hiphop.00073", + "hiphop.00074", + "hiphop.00075", + "hiphop.00099", + "jazz.00073", + "jazz.00074", + "jazz.00075", + "jazz.00076", + "jazz.00077", + "jazz.00078", + "jazz.00079", + "jazz.00080", + "jazz.00081", + "jazz.00082", + "jazz.00083", + "jazz.00084", + "jazz.00085", + "jazz.00086", + "jazz.00087", + "jazz.00088", + "jazz.00089", + "jazz.00090", + "jazz.00091", + "jazz.00092", + "jazz.00093", + "jazz.00094", + "jazz.00095", + "jazz.00096", + "jazz.00097", + "jazz.00098", + "jazz.00099", + "metal.00012", + "metal.00013", + "metal.00014", + "metal.00015", + "metal.00022", + "metal.00023", + "metal.00025", + "metal.00026", + "metal.00027", + "metal.00028", + "metal.00029", + "metal.00030", + "metal.00031", + "metal.00032", + "metal.00033", + "metal.00038", + "metal.00039", + "metal.00067", + "metal.00070", + "metal.00073", + "metal.00074", + "metal.00075", + "metal.00078", + "metal.00083", + "metal.00085", + "metal.00087", + "metal.00088", + "pop.00000", + "pop.00001", + "pop.00013", + "pop.00014", + "pop.00043", + "pop.00063", + "pop.00064", + "pop.00065", + "pop.00066", + "pop.00069", + "pop.00070", + "pop.00071", + "pop.00072", + "pop.00073", + 
"pop.00074", + "pop.00075", + "pop.00076", + "pop.00077", + "pop.00078", + "pop.00079", + "pop.00082", + "pop.00088", + "pop.00089", + "pop.00090", + "pop.00091", + "pop.00092", + "pop.00093", + "pop.00094", + "pop.00095", + "pop.00096", + "reggae.00034", + "reggae.00035", + "reggae.00036", + "reggae.00037", + "reggae.00038", + "reggae.00039", + "reggae.00040", + "reggae.00046", + "reggae.00047", + "reggae.00048", + "reggae.00052", + "reggae.00053", + "reggae.00064", + "reggae.00065", + "reggae.00066", + "reggae.00067", + "reggae.00068", + "reggae.00071", + "reggae.00079", + "reggae.00082", + "reggae.00083", + "reggae.00084", + "reggae.00087", + "reggae.00088", + "reggae.00089", + "reggae.00090", + "rock.00010", + "rock.00011", + "rock.00012", + "rock.00013", + "rock.00014", + "rock.00015", + "rock.00027", + "rock.00028", + "rock.00029", + "rock.00030", + "rock.00031", + "rock.00032", + "rock.00033", + "rock.00034", + "rock.00035", + "rock.00036", + "rock.00037", + "rock.00039", + "rock.00040", + "rock.00041", + "rock.00042", + "rock.00043", + "rock.00044", + "rock.00045", + "rock.00046", + "rock.00047", + "rock.00048", + "rock.00086", + "rock.00087", + "rock.00088", + "rock.00089", + "rock.00090", +] + +filtered_train = [ + "blues.00029", + "blues.00030", + "blues.00031", + "blues.00032", + "blues.00033", + "blues.00034", + "blues.00035", + "blues.00036", + "blues.00037", + "blues.00038", + "blues.00039", + "blues.00040", + "blues.00041", + "blues.00042", + "blues.00043", + "blues.00044", + "blues.00045", + "blues.00046", + "blues.00047", + "blues.00048", + "blues.00049", + "blues.00073", + "blues.00074", + "blues.00075", + "blues.00076", + "blues.00077", + "blues.00078", + "blues.00079", + "blues.00080", + "blues.00081", + "blues.00082", + "blues.00083", + "blues.00084", + "blues.00085", + "blues.00086", + "blues.00087", + "blues.00088", + "blues.00089", + "blues.00090", + "blues.00091", + "blues.00092", + "blues.00093", + "blues.00094", + "blues.00095", + 
"blues.00096", + "blues.00097", + "classical.00030", + "classical.00031", + "classical.00032", + "classical.00033", + "classical.00043", + "classical.00044", + "classical.00045", + "classical.00046", + "classical.00047", + "classical.00048", + "classical.00050", + "classical.00051", + "classical.00052", + "classical.00053", + "classical.00054", + "classical.00055", + "classical.00056", + "classical.00057", + "classical.00058", + "classical.00059", + "classical.00060", + "classical.00061", + "classical.00062", + "classical.00063", + "classical.00064", + "classical.00065", + "classical.00066", + "classical.00067", + "classical.00080", + "classical.00081", + "classical.00082", + "classical.00083", + "classical.00084", + "classical.00085", + "classical.00086", + "classical.00087", + "classical.00088", + "classical.00089", + "classical.00090", + "classical.00091", + "classical.00092", + "classical.00093", + "classical.00094", + "classical.00095", + "classical.00096", + "classical.00097", + "classical.00098", + "classical.00099", + "country.00019", + "country.00020", + "country.00021", + "country.00022", + "country.00023", + "country.00024", + "country.00025", + "country.00026", + "country.00028", + "country.00029", + "country.00065", + "country.00066", + "country.00067", + "country.00068", + "country.00069", + "country.00070", + "country.00071", + "country.00072", + "country.00073", + "country.00074", + "country.00075", + "country.00076", + "country.00077", + "country.00078", + "country.00079", + "country.00080", + "country.00081", + "country.00082", + "country.00083", + "country.00084", + "country.00085", + "country.00086", + "country.00087", + "country.00088", + "country.00089", + "country.00090", + "country.00091", + "country.00092", + "country.00093", + "country.00094", + "country.00095", + "country.00096", + "country.00097", + "country.00098", + "country.00099", + "disco.00005", + "disco.00015", + "disco.00016", + "disco.00017", + "disco.00018", + "disco.00019", + 
"disco.00020", + "disco.00022", + "disco.00023", + "disco.00024", + "disco.00025", + "disco.00026", + "disco.00027", + "disco.00028", + "disco.00029", + "disco.00030", + "disco.00031", + "disco.00032", + "disco.00033", + "disco.00034", + "disco.00035", + "disco.00036", + "disco.00037", + "disco.00039", + "disco.00040", + "disco.00041", + "disco.00042", + "disco.00043", + "disco.00044", + "disco.00045", + "disco.00047", + "disco.00049", + "disco.00053", + "disco.00054", + "disco.00056", + "disco.00057", + "disco.00059", + "disco.00061", + "disco.00070", + "disco.00073", + "disco.00074", + "disco.00089", + "hiphop.00002", + "hiphop.00003", + "hiphop.00004", + "hiphop.00005", + "hiphop.00006", + "hiphop.00007", + "hiphop.00008", + "hiphop.00009", + "hiphop.00010", + "hiphop.00011", + "hiphop.00012", + "hiphop.00013", + "hiphop.00014", + "hiphop.00015", + "hiphop.00016", + "hiphop.00017", + "hiphop.00018", + "hiphop.00019", + "hiphop.00020", + "hiphop.00021", + "hiphop.00022", + "hiphop.00023", + "hiphop.00024", + "hiphop.00025", + "hiphop.00028", + "hiphop.00029", + "hiphop.00031", + "hiphop.00032", + "hiphop.00033", + "hiphop.00034", + "hiphop.00035", + "hiphop.00036", + "hiphop.00037", + "hiphop.00038", + "hiphop.00041", + "hiphop.00042", + "hiphop.00055", + "hiphop.00056", + "hiphop.00057", + "hiphop.00058", + "hiphop.00059", + "hiphop.00060", + "hiphop.00061", + "hiphop.00077", + "hiphop.00078", + "hiphop.00079", + "hiphop.00080", + "jazz.00000", + "jazz.00001", + "jazz.00011", + "jazz.00012", + "jazz.00013", + "jazz.00014", + "jazz.00015", + "jazz.00016", + "jazz.00017", + "jazz.00018", + "jazz.00019", + "jazz.00020", + "jazz.00021", + "jazz.00022", + "jazz.00023", + "jazz.00024", + "jazz.00041", + "jazz.00047", + "jazz.00048", + "jazz.00049", + "jazz.00050", + "jazz.00051", + "jazz.00052", + "jazz.00053", + "jazz.00054", + "jazz.00055", + "jazz.00056", + "jazz.00057", + "jazz.00058", + "jazz.00059", + "jazz.00060", + "jazz.00061", + "jazz.00062", + "jazz.00063", 
+ "jazz.00064", + "jazz.00065", + "jazz.00066", + "jazz.00067", + "jazz.00068", + "jazz.00069", + "jazz.00070", + "jazz.00071", + "jazz.00072", + "metal.00002", + "metal.00003", + "metal.00005", + "metal.00021", + "metal.00024", + "metal.00035", + "metal.00046", + "metal.00047", + "metal.00048", + "metal.00049", + "metal.00050", + "metal.00051", + "metal.00052", + "metal.00053", + "metal.00054", + "metal.00055", + "metal.00056", + "metal.00057", + "metal.00059", + "metal.00060", + "metal.00061", + "metal.00062", + "metal.00063", + "metal.00064", + "metal.00065", + "metal.00066", + "metal.00069", + "metal.00071", + "metal.00072", + "metal.00079", + "metal.00080", + "metal.00084", + "metal.00086", + "metal.00089", + "metal.00090", + "metal.00091", + "metal.00092", + "metal.00093", + "metal.00094", + "metal.00095", + "metal.00096", + "metal.00097", + "metal.00098", + "metal.00099", + "pop.00002", + "pop.00003", + "pop.00004", + "pop.00005", + "pop.00006", + "pop.00007", + "pop.00008", + "pop.00009", + "pop.00011", + "pop.00012", + "pop.00016", + "pop.00017", + "pop.00018", + "pop.00019", + "pop.00020", + "pop.00023", + "pop.00024", + "pop.00025", + "pop.00026", + "pop.00027", + "pop.00028", + "pop.00029", + "pop.00031", + "pop.00032", + "pop.00033", + "pop.00034", + "pop.00035", + "pop.00036", + "pop.00038", + "pop.00039", + "pop.00040", + "pop.00041", + "pop.00042", + "pop.00044", + "pop.00046", + "pop.00049", + "pop.00050", + "pop.00080", + "pop.00097", + "pop.00098", + "pop.00099", + "reggae.00000", + "reggae.00001", + "reggae.00002", + "reggae.00004", + "reggae.00006", + "reggae.00009", + "reggae.00011", + "reggae.00012", + "reggae.00014", + "reggae.00015", + "reggae.00016", + "reggae.00017", + "reggae.00018", + "reggae.00019", + "reggae.00020", + "reggae.00021", + "reggae.00022", + "reggae.00023", + "reggae.00024", + "reggae.00025", + "reggae.00026", + "reggae.00027", + "reggae.00028", + "reggae.00029", + "reggae.00030", + "reggae.00031", + "reggae.00032", + 
"reggae.00042", + "reggae.00043", + "reggae.00044", + "reggae.00045", + "reggae.00049", + "reggae.00050", + "reggae.00051", + "reggae.00054", + "reggae.00055", + "reggae.00056", + "reggae.00057", + "reggae.00058", + "reggae.00059", + "reggae.00060", + "reggae.00063", + "reggae.00069", + "rock.00000", + "rock.00001", + "rock.00002", + "rock.00003", + "rock.00004", + "rock.00005", + "rock.00006", + "rock.00007", + "rock.00008", + "rock.00009", + "rock.00016", + "rock.00017", + "rock.00018", + "rock.00019", + "rock.00020", + "rock.00021", + "rock.00022", + "rock.00023", + "rock.00024", + "rock.00025", + "rock.00026", + "rock.00057", + "rock.00058", + "rock.00059", + "rock.00060", + "rock.00061", + "rock.00062", + "rock.00063", + "rock.00064", + "rock.00065", + "rock.00066", + "rock.00067", + "rock.00068", + "rock.00069", + "rock.00070", + "rock.00091", + "rock.00092", + "rock.00093", + "rock.00094", + "rock.00095", + "rock.00096", + "rock.00097", + "rock.00098", + "rock.00099", +] + +filtered_valid = [ + "blues.00000", + "blues.00001", + "blues.00002", + "blues.00003", + "blues.00004", + "blues.00005", + "blues.00006", + "blues.00007", + "blues.00008", + "blues.00009", + "blues.00010", + "blues.00011", + "blues.00050", + "blues.00051", + "blues.00052", + "blues.00053", + "blues.00054", + "blues.00055", + "blues.00056", + "blues.00057", + "blues.00058", + "blues.00059", + "blues.00060", + "classical.00000", + "classical.00001", + "classical.00002", + "classical.00003", + "classical.00004", + "classical.00005", + "classical.00006", + "classical.00007", + "classical.00008", + "classical.00009", + "classical.00010", + "classical.00068", + "classical.00069", + "classical.00070", + "classical.00071", + "classical.00072", + "classical.00073", + "classical.00074", + "classical.00075", + "classical.00076", + "country.00000", + "country.00001", + "country.00002", + "country.00003", + "country.00004", + "country.00005", + "country.00006", + "country.00007", + "country.00009", + 
"country.00010", + "country.00011", + "country.00012", + "country.00013", + "country.00014", + "country.00015", + "country.00016", + "country.00017", + "country.00018", + "country.00027", + "country.00041", + "country.00042", + "country.00045", + "country.00049", + "disco.00000", + "disco.00002", + "disco.00003", + "disco.00004", + "disco.00006", + "disco.00007", + "disco.00008", + "disco.00009", + "disco.00010", + "disco.00011", + "disco.00012", + "disco.00013", + "disco.00014", + "disco.00046", + "disco.00048", + "disco.00052", + "disco.00067", + "disco.00068", + "disco.00072", + "disco.00075", + "disco.00090", + "disco.00095", + "hiphop.00081", + "hiphop.00082", + "hiphop.00083", + "hiphop.00084", + "hiphop.00085", + "hiphop.00086", + "hiphop.00087", + "hiphop.00088", + "hiphop.00089", + "hiphop.00090", + "hiphop.00091", + "hiphop.00092", + "hiphop.00093", + "hiphop.00094", + "hiphop.00095", + "hiphop.00096", + "hiphop.00097", + "hiphop.00098", + "jazz.00002", + "jazz.00003", + "jazz.00004", + "jazz.00005", + "jazz.00006", + "jazz.00007", + "jazz.00008", + "jazz.00009", + "jazz.00010", + "jazz.00025", + "jazz.00026", + "jazz.00027", + "jazz.00028", + "jazz.00029", + "jazz.00030", + "jazz.00031", + "jazz.00032", + "metal.00000", + "metal.00001", + "metal.00006", + "metal.00007", + "metal.00008", + "metal.00009", + "metal.00010", + "metal.00011", + "metal.00016", + "metal.00017", + "metal.00018", + "metal.00019", + "metal.00020", + "metal.00036", + "metal.00037", + "metal.00068", + "metal.00076", + "metal.00077", + "metal.00081", + "metal.00082", + "pop.00010", + "pop.00053", + "pop.00055", + "pop.00058", + "pop.00059", + "pop.00060", + "pop.00061", + "pop.00062", + "pop.00081", + "pop.00083", + "pop.00084", + "pop.00085", + "pop.00086", + "reggae.00061", + "reggae.00062", + "reggae.00070", + "reggae.00072", + "reggae.00074", + "reggae.00076", + "reggae.00077", + "reggae.00078", + "reggae.00085", + "reggae.00092", + "reggae.00093", + "reggae.00094", + 
"reggae.00095", + "reggae.00096", + "reggae.00097", + "reggae.00098", + "reggae.00099", + "rock.00038", + "rock.00049", + "rock.00050", + "rock.00051", + "rock.00052", + "rock.00053", + "rock.00054", + "rock.00055", + "rock.00056", + "rock.00071", + "rock.00072", + "rock.00073", + "rock.00074", + "rock.00075", + "rock.00076", + "rock.00077", + "rock.00078", + "rock.00079", + "rock.00080", + "rock.00081", + "rock.00082", + "rock.00083", + "rock.00084", + "rock.00085", +] + + +URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz" +FOLDER_IN_ARCHIVE = "genres" +_CHECKSUMS = { + "http://opihi.cs.uvic.ca/sound/genres.tar.gz": "24347e0223d2ba798e0a558c4c172d9d4a19c00bb7963fe055d183dadb4ef2c6" +} + + +def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]: + """ + Loads a file from the dataset and returns the raw waveform + as a Torch Tensor, its sample rate as an integer, and its + genre as a string. + """ + # Filenames are of the form label.id, e.g. blues.00078 + label, _ = fileid.split(".") + + # Read wav + file_audio = os.path.join(path, label, fileid + ext_audio) + waveform, sample_rate = torchaudio.load(file_audio) + + return waveform, sample_rate, label + + +class GTZAN(Dataset): + """*GTZAN* :cite:`tzanetakis_essl_cook_2001` dataset. + + Note: + Please see http://marsyas.info/downloads/datasets.html if you are planning to use + this dataset to publish results. + + Note: + As of October 2022, the download link is not currently working. Setting ``download=True`` + in GTZAN dataset will result in a URL connection error. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): The URL to download the dataset from. + (default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``) + folder_in_archive (str, optional): The top-level directory of the dataset. + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). 
+ subset (str or None, optional): Which subset of the dataset to use. + One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``. + If ``None``, the entire dataset is used. (default: ``None``). + """ + + _ext_audio = ".wav" + + def __init__( + self, + root: Union[str, Path], + url: str = URL, + folder_in_archive: str = FOLDER_IN_ARCHIVE, + download: bool = False, + subset: Optional[str] = None, + ) -> None: + + # super(GTZAN, self).__init__() + + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + + self.root = root + self.url = url + self.folder_in_archive = folder_in_archive + self.download = download + self.subset = subset + + if subset is not None and subset not in ["training", "validation", "testing"]: + raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].") + + archive = os.path.basename(url) + archive = os.path.join(root, archive) + self._path = os.path.join(root, folder_in_archive) + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(url, None) + download_url_to_file(url, archive, hash_prefix=checksum) + _extract_tar(archive) + + if not os.path.isdir(self._path): + raise RuntimeError("Dataset not found. Please use `download=True` to download it.") + + if self.subset is None: + # Check every subdirectory under dataset root + # which has the same name as the genres in + # GTZAN (e.g. `root_dir'/blues/, `root_dir'/rock, etc.) + # This lets users remove or move around song files, + # useful when e.g. they want to use only some of the files + # in a genre or want to label other files with a different + # genre. 
+ self._walker = [] + + root = os.path.expanduser(self._path) + + for directory in gtzan_genres: + fulldir = os.path.join(root, directory) + + if not os.path.exists(fulldir): + continue + + songs_in_genre = os.listdir(fulldir) + songs_in_genre.sort() + for fname in songs_in_genre: + name, ext = os.path.splitext(fname) + if ext.lower() == ".wav" and "." in name: + # Check whether the file is of the form + # `gtzan_genre`.`5 digit number`.wav + genre, num = name.split(".") + if genre in gtzan_genres and len(num) == 5 and num.isdigit(): + self._walker.append(name) + else: + if self.subset == "training": + self._walker = filtered_train + elif self.subset == "validation": + self._walker = filtered_valid + elif self.subset == "testing": + self._walker = filtered_test + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Label + """ + fileid = self._walker[n] + item = load_gtzan_item(fileid, self._path, self._ext_audio) + waveform, sample_rate, label = item + return waveform, sample_rate, label + + def __len__(self) -> int: + return len(self._walker) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/iemocap.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/iemocap.py new file mode 100644 index 0000000000000000000000000000000000000000..224300a84f5ec3ae217f030783c825fc3db56c8a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/iemocap.py @@ -0,0 +1,147 @@ +import os +import re +from pathlib import Path +from typing import Optional, Tuple, Union + +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio.datasets.utils import _load_waveform + + +_SAMPLE_RATE = 16000 + + +def _get_wavs_paths(data_dir): + wav_dir = data_dir / "sentences" / "wav" + wav_paths = sorted(str(p) for p in 
wav_dir.glob("*/*.wav")) + relative_paths = [] + for wav_path in wav_paths: + start = wav_path.find("Session") + wav_path = wav_path[start:] + relative_paths.append(wav_path) + return relative_paths + + +class IEMOCAP(Dataset): + """*IEMOCAP* :cite:`iemocap` dataset. + + Args: + root (str or Path): Root directory where the dataset's top level directory is found + sessions (Tuple[int]): Tuple of sessions (1-5) to use. (Default: ``(1, 2, 3, 4, 5)``) + utterance_type (str or None, optional): Which type(s) of utterances to include in the dataset. + Options: ("scripted", "improvised", ``None``). If ``None``, both scripted and improvised + data are used. + """ + + def __init__( + self, + root: Union[str, Path], + sessions: Tuple[str] = (1, 2, 3, 4, 5), + utterance_type: Optional[str] = None, + ): + root = Path(root) + self._path = root / "IEMOCAP" + + if not os.path.isdir(self._path): + raise RuntimeError("Dataset not found.") + + if utterance_type not in ["scripted", "improvised", None]: + raise ValueError("utterance_type must be one of ['scripted', 'improvised', or None]") + + all_data = [] + self.data = [] + self.mapping = {} + + for session in sessions: + session_name = f"Session{session}" + session_dir = self._path / session_name + + # get wav paths + wav_paths = _get_wavs_paths(session_dir) + for wav_path in wav_paths: + wav_stem = str(Path(wav_path).stem) + all_data.append(wav_stem) + + # add labels + label_dir = session_dir / "dialog" / "EmoEvaluation" + query = "*.txt" + if utterance_type == "scripted": + query = "*script*.txt" + elif utterance_type == "improvised": + query = "*impro*.txt" + label_paths = label_dir.glob(query) + + for label_path in label_paths: + with open(label_path, "r") as f: + for line in f: + if not line.startswith("["): + continue + line = re.split("[\t\n]", line) + wav_stem = line[1] + label = line[2] + if wav_stem not in all_data: + continue + if label not in ["neu", "hap", "ang", "sad", "exc", "fru"]: + continue + self.mapping[wav_stem] 
= {} + self.mapping[wav_stem]["label"] = label + + for wav_path in wav_paths: + wav_stem = str(Path(wav_path).stem) + if wav_stem in self.mapping: + self.data.append(wav_stem) + self.mapping[wav_stem]["path"] = wav_path + + def get_metadata(self, n: int) -> Tuple[str, int, str, str, str]: + """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform, + but otherwise returns the same fields as :py:meth:`__getitem__`. + + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + str: + Path to audio + int: + Sample rate + str: + File name + str: + Label (one of ``"neu"``, ``"hap"``, ``"ang"``, ``"sad"``, ``"exc"``, ``"fru"``) + str: + Speaker + """ + wav_stem = self.data[n] + wav_path = self.mapping[wav_stem]["path"] + label = self.mapping[wav_stem]["label"] + speaker = wav_stem.split("_")[0] + return (wav_path, _SAMPLE_RATE, wav_stem, label, speaker) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + File name + str: + Label (one of ``"neu"``, ``"hap"``, ``"ang"``, ``"sad"``, ``"exc"``, ``"fru"``) + str: + Speaker + """ + metadata = self.get_metadata(n) + waveform = _load_waveform(self._path, metadata[0], metadata[1]) + return (waveform,) + metadata[1:] + + def __len__(self): + return len(self.data) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/librilight_limited.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/librilight_limited.py new file mode 100644 index 0000000000000000000000000000000000000000..f0cb3100f7c4ad2e488c20bdfaac3833e0a136dd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/librilight_limited.py @@ -0,0 +1,111 @@ +import os +from pathlib import Path +from typing import List, Tuple, Union + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.librispeech import _get_librispeech_metadata +from torchaudio.datasets.utils import _extract_tar + + +_ARCHIVE_NAME = "librispeech_finetuning" +_URL = "https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz" +_CHECKSUM = "5d1efdc777b548194d7e09ba89126e2188026df9fd57aa57eb14408d2b2342af" +_SUBSET_MAP = {"10min": ["1h/0"], "1h": ["1h/*"], "10h": ["1h/*", "9h"]} + + +def _get_fileids_paths(path: Path, folders: List[str], _ext_audio: str) -> List[Tuple[str, str]]: + """Get the file names and the corresponding file paths without `speaker_id` + and `chapter_id` directories. + The format of path is like: + {root}/{_ARCHIVE_NAME}/1h/[0-5]/[clean, other] or + {root}/{_ARCHIVE_NAME}/9h/[clean, other] + + Args: + path (Path): Root path to the dataset. + folders (List[str]): Folders that contain the desired audio files. + _ext_audio (str): Extension of audio files. 
+ + Returns: + List[Tuple[str, str]]: + List of tuples where the first element is the relative path to the audio file. + The format of relative path is like: + 1h/[0-5]/[clean, other] or 9h/[clean, other] + The second element is the file name without audio extension. + """ + + path = Path(path) + files_paths = [] + for folder in folders: + paths = [p.relative_to(path) for p in path.glob(f"{folder}/*/*/*/*{_ext_audio}")] + files_paths += [(str(p.parent.parent.parent), str(p.stem)) for p in paths] # get subset folder and file name + files_paths.sort(key=lambda x: x[0] + x[1]) + return files_paths + + +class LibriLightLimited(Dataset): + """Subset of Libri-light :cite:`librilight` dataset, + which was used in HuBERT :cite:`hsu2021hubert` for supervised fine-tuning. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + subset (str, optional): The subset to use. Options: [``"10min"``, ``"1h"``, ``"10h"``] + (Default: ``"10min"``). + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + """ + + _ext_txt = ".trans.txt" + _ext_audio = ".flac" + + def __init__( + self, + root: Union[str, Path], + subset: str = "10min", + download: bool = False, + ) -> None: + if subset not in _SUBSET_MAP: + raise ValueError(f"`subset` must be one of {_SUBSET_MAP.keys()}. Found: {subset}") + folders = _SUBSET_MAP[subset] + + root = os.fspath(root) + self._path = os.path.join(root, _ARCHIVE_NAME) + archive = os.path.join(root, f"{_ARCHIVE_NAME}.tgz") + if not os.path.isdir(self._path): + if not download: + raise RuntimeError("Dataset not found. 
Please use `download=True` to download") + if not os.path.isfile(archive): + download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM) + _extract_tar(archive) + self._fileids_paths = _get_fileids_paths(self._path, folders, self._ext_audio) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Transcript + int: + Speaker ID + int: + Chapter ID + int: + Utterance ID + """ + file_path, fileid = self._fileids_paths[n] + metadata = _get_librispeech_metadata(fileid, self._path, file_path, self._ext_audio, self._ext_txt) + waveform, _ = torchaudio.load(os.path.join(self._path, metadata[0])) + return (waveform,) + metadata[1:] + + def __len__(self) -> int: + return len(self._fileids_paths) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/librimix.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/librimix.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6c6f18600ab35f037dda11f9f5bc32c8a5cbf5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/librimix.py @@ -0,0 +1,133 @@ +import os +from pathlib import Path +from typing import List, Tuple, Union + +import torch +from torch.utils.data import Dataset +from torchaudio.datasets.utils import _load_waveform + +_TASKS_TO_MIXTURE = { + "sep_clean": "mix_clean", + "enh_single": "mix_single", + "enh_both": "mix_both", + "sep_noisy": "mix_both", +} + + +class LibriMix(Dataset): + r"""*LibriMix* :cite:`cosentino2020librimix` dataset. + + Args: + root (str or Path): The path where the directory ``Libri2Mix`` or + ``Libri3Mix`` is stored. Not the path of those directories. + subset (str, optional): The subset to use. Options: [``"train-360"``, ``"train-100"``, + ``"dev"``, and ``"test"``] (Default: ``"train-360"``). 
+ num_speakers (int, optional): The number of speakers, which determines the directories + to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect + N source audios. (Default: 2) + sample_rate (int, optional): Sample rate of audio files. The ``sample_rate`` determines + which subdirectory the audio are fetched. If any of the audio has a different sample + rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000) + task (str, optional): The task of LibriMix. + Options: [``"enh_single"``, ``"enh_both"``, ``"sep_clean"``, ``"sep_noisy"``] + (Default: ``"sep_clean"``) + mode (str, optional): The mode when creating the mixture. If set to ``"min"``, the lengths of mixture + and sources are the minimum length of all sources. If set to ``"max"``, the lengths of mixture and + sources are zero padded to the maximum length of all sources. + Options: [``"min"``, ``"max"``] + (Default: ``"min"``) + + Note: + The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix + """ + + def __init__( + self, + root: Union[str, Path], + subset: str = "train-360", + num_speakers: int = 2, + sample_rate: int = 8000, + task: str = "sep_clean", + mode: str = "min", + ): + self.root = Path(root) / f"Libri{num_speakers}Mix" + if not os.path.exists(self.root): + raise RuntimeError( + f"The path {self.root} doesn't exist. " + "Please check the ``root`` path and ``num_speakers`` or download the dataset manually." + ) + if mode not in ["max", "min"]: + raise ValueError(f'Expect ``mode`` to be one in ["min", "max"]. Found {mode}.') + if sample_rate == 8000: + mix_dir = self.root / "wav8k" / mode / subset + elif sample_rate == 16000: + mix_dir = self.root / "wav16k" / mode / subset + else: + raise ValueError(f"Unsupported sample rate. 
Found {sample_rate}.") + self.sample_rate = sample_rate + self.task = task + + self.mix_dir = mix_dir / _TASKS_TO_MIXTURE[task] + if task == "enh_both": + self.src_dirs = [(mix_dir / "mix_clean")] + else: + self.src_dirs = [(mix_dir / f"s{i+1}") for i in range(num_speakers)] + + self.files = [p.name for p in self.mix_dir.glob("*.wav")] + self.files.sort() + + def _load_sample(self, key) -> Tuple[int, torch.Tensor, List[torch.Tensor]]: + metadata = self.get_metadata(key) + mixed = _load_waveform(self.root, metadata[1], metadata[0]) + srcs = [] + for i, path_ in enumerate(metadata[2]): + src = _load_waveform(self.root, path_, metadata[0]) + if mixed.shape != src.shape: + raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}") + srcs.append(src) + return self.sample_rate, mixed, srcs + + def get_metadata(self, key: int) -> Tuple[int, str, List[str]]: + """Get metadata for the n-th sample from the dataset. + + Args: + key (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + int: + Sample rate + str: + Path to mixed audio + List of str: + List of paths to source audios + """ + filename = self.files[key] + mixed_path = os.path.relpath(self.mix_dir / filename, self.root) + srcs_paths = [] + for dir_ in self.src_dirs: + src = os.path.relpath(dir_ / filename, self.root) + srcs_paths.append(src) + return self.sample_rate, mixed_path, srcs_paths + + def __len__(self) -> int: + return len(self.files) + + def __getitem__(self, key: int) -> Tuple[int, torch.Tensor, List[torch.Tensor]]: + """Load the n-th sample from the dataset. 
+ + Args: + key (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + int: + Sample rate + Tensor: + Mixture waveform + List of Tensors: + List of source waveforms + """ + return self._load_sample(key) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/librispeech.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/librispeech.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf05dbecb5cce24c91e3bbcf232935e1f6d8cd9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/librispeech.py @@ -0,0 +1,174 @@ +import os +from pathlib import Path +from typing import Tuple, Union + +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_tar, _load_waveform + +URL = "train-clean-100" +FOLDER_IN_ARCHIVE = "LibriSpeech" +SAMPLE_RATE = 16000 +_DATA_SUBSETS = [ + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", +] +_CHECKSUMS = { + "http://www.openslr.org/resources/12/dev-clean.tar.gz": "76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3", # noqa: E501 + "http://www.openslr.org/resources/12/dev-other.tar.gz": "12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365", # noqa: E501 + "http://www.openslr.org/resources/12/test-clean.tar.gz": "39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23", # noqa: E501 + "http://www.openslr.org/resources/12/test-other.tar.gz": "d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29", # noqa: E501 + "http://www.openslr.org/resources/12/train-clean-100.tar.gz": "d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2", # noqa: E501 + "http://www.openslr.org/resources/12/train-clean-360.tar.gz": "146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf", # noqa: E501 + 
"http://www.openslr.org/resources/12/train-other-500.tar.gz": "ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2", # noqa: E501 +} + + +def _download_librispeech(root, url): + base_url = "http://www.openslr.org/resources/12/" + ext_archive = ".tar.gz" + + filename = url + ext_archive + archive = os.path.join(root, filename) + download_url = os.path.join(base_url, filename) + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(download_url, None) + download_url_to_file(download_url, archive, hash_prefix=checksum) + _extract_tar(archive) + + +def _get_librispeech_metadata( + fileid: str, root: str, folder: str, ext_audio: str, ext_txt: str +) -> Tuple[str, int, str, int, int, int]: + speaker_id, chapter_id, utterance_id = fileid.split("-") + + # Get audio path and sample rate + fileid_audio = f"{speaker_id}-{chapter_id}-{utterance_id}" + filepath = os.path.join(folder, speaker_id, chapter_id, f"{fileid_audio}{ext_audio}") + + # Load text + file_text = f"{speaker_id}-{chapter_id}{ext_txt}" + file_text = os.path.join(root, folder, speaker_id, chapter_id, file_text) + with open(file_text) as ft: + for line in ft: + fileid_text, transcript = line.strip().split(" ", 1) + if fileid_audio == fileid_text: + break + else: + # Translation not found + raise FileNotFoundError(f"Translation not found for {fileid_audio}") + + return ( + filepath, + SAMPLE_RATE, + transcript, + int(speaker_id), + int(chapter_id), + int(utterance_id), + ) + + +class LIBRISPEECH(Dataset): + """*LibriSpeech* :cite:`7178964` dataset. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): The URL to download the dataset from, + or the type of the dataset to dowload. + Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``, + ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and + ``"train-other-500"``. 
(default: ``"train-clean-100"``) + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"LibriSpeech"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + """ + + _ext_txt = ".trans.txt" + _ext_audio = ".flac" + + def __init__( + self, + root: Union[str, Path], + url: str = URL, + folder_in_archive: str = FOLDER_IN_ARCHIVE, + download: bool = False, + ) -> None: + self._url = url + if url not in _DATA_SUBSETS: + raise ValueError(f"Invalid url '{url}' given; please provide one of {_DATA_SUBSETS}.") + + root = os.fspath(root) + self._archive = os.path.join(root, folder_in_archive) + self._path = os.path.join(root, folder_in_archive, url) + + if not os.path.isdir(self._path): + if download: + _download_librispeech(root, url) + else: + raise RuntimeError( + f"Dataset not found at {self._path}. Please set `download=True` to download the dataset." + ) + + self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio)) + + def get_metadata(self, n: int) -> Tuple[str, int, str, int, int, int]: + """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform, + but otherwise returns the same fields as :py:func:`__getitem__`. + + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + str: + Path to audio + int: + Sample rate + str: + Transcript + int: + Speaker ID + int: + Chapter ID + int: + Utterance ID + """ + fileid = self._walker[n] + return _get_librispeech_metadata(fileid, self._archive, self._url, self._ext_audio, self._ext_txt) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Transcript + int: + Speaker ID + int: + Chapter ID + int: + Utterance ID + """ + metadata = self.get_metadata(n) + waveform = _load_waveform(self._archive, metadata[0], metadata[1]) + return (waveform,) + metadata[1:] + + def __len__(self) -> int: + return len(self._walker) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/librispeech_biasing.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/librispeech_biasing.py new file mode 100644 index 0000000000000000000000000000000000000000..bd518cf2b69094728f8693fe2cb8a2a535bd7d3c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/librispeech_biasing.py @@ -0,0 +1,189 @@ +import os +from pathlib import Path +from typing import List, Tuple, Union + +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_tar, _load_waveform + +URL = "train-clean-100" +FOLDER_IN_ARCHIVE = "LibriSpeech" +SAMPLE_RATE = 16000 +_DATA_SUBSETS = [ + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", +] +_CHECKSUMS = { + "http://www.openslr.org/resources/12/dev-clean.tar.gz": "76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3", # noqa: E501 + "http://www.openslr.org/resources/12/dev-other.tar.gz": "12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365", # noqa: E501 + "http://www.openslr.org/resources/12/test-clean.tar.gz": "39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23", # noqa: E501 + "http://www.openslr.org/resources/12/test-other.tar.gz": "d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29", # noqa: E501 + "http://www.openslr.org/resources/12/train-clean-100.tar.gz": 
"d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2", # noqa: E501 + "http://www.openslr.org/resources/12/train-clean-360.tar.gz": "146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf", # noqa: E501 + "http://www.openslr.org/resources/12/train-other-500.tar.gz": "ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2", # noqa: E501 +} + + +def _download_librispeech(root, url): + base_url = "http://www.openslr.org/resources/12/" + ext_archive = ".tar.gz" + + filename = url + ext_archive + archive = os.path.join(root, filename) + download_url = os.path.join(base_url, filename) + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(download_url, None) + download_url_to_file(download_url, archive, hash_prefix=checksum) + _extract_tar(archive) + + +def _get_librispeech_metadata( + fileid: str, root: str, folder: str, ext_audio: str, ext_txt: str, blist: List[str] +) -> Tuple[str, int, str, int, int, int]: + blist = blist or [] + speaker_id, chapter_id, utterance_id = fileid.split("-") + + # Get audio path and sample rate + fileid_audio = f"{speaker_id}-{chapter_id}-{utterance_id}" + filepath = os.path.join(folder, speaker_id, chapter_id, f"{fileid_audio}{ext_audio}") + + # Load text + file_text = f"{speaker_id}-{chapter_id}{ext_txt}" + file_text = os.path.join(root, folder, speaker_id, chapter_id, file_text) + uttblist = [] + with open(file_text) as ft: + for line in ft: + fileid_text, transcript = line.strip().split(" ", 1) + if fileid_audio == fileid_text: + # get utterance biasing list + for word in transcript.split(): + if word in blist and word not in uttblist: + uttblist.append(word) + break + else: + # Translation not found + raise FileNotFoundError(f"Translation not found for {fileid_audio}") + + return ( + filepath, + SAMPLE_RATE, + transcript, + int(speaker_id), + int(chapter_id), + int(utterance_id), + uttblist, + ) + + +class LibriSpeechBiasing(Dataset): + """*LibriSpeech* :cite:`7178964` dataset with prefix-tree 
construction and biasing support. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): The URL to download the dataset from, + or the type of the dataset to dowload. + Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``, + ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and + ``"train-other-500"``. (default: ``"train-clean-100"``) + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"LibriSpeech"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + blist (list, optional): + The list of biasing words (default: ``[]``). + """ + + _ext_txt = ".trans.txt" + _ext_audio = ".flac" + + def __init__( + self, + root: Union[str, Path], + url: str = URL, + folder_in_archive: str = FOLDER_IN_ARCHIVE, + download: bool = False, + blist: List[str] = None, + ) -> None: + self._url = url + if url not in _DATA_SUBSETS: + raise ValueError(f"Invalid url '{url}' given; please provide one of {_DATA_SUBSETS}.") + + root = os.fspath(root) + self._archive = os.path.join(root, folder_in_archive) + self._path = os.path.join(root, folder_in_archive, url) + + if not os.path.isdir(self._path): + if download: + _download_librispeech(root, url) + else: + raise RuntimeError( + f"Dataset not found at {self._path}. Please set `download=True` to download the dataset." + ) + + self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio)) + self.blist = blist + + def get_metadata(self, n: int) -> Tuple[str, int, str, int, int, int]: + """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform, + but otherwise returns the same fields as :py:func:`__getitem__`. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + str: + Path to audio + int: + Sample rate + str: + Transcript + int: + Speaker ID + int: + Chapter ID + int: + Utterance ID + list: + List of biasing words in the utterance + """ + fileid = self._walker[n] + return _get_librispeech_metadata(fileid, self._archive, self._url, self._ext_audio, self._ext_txt, self.blist) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Transcript + int: + Speaker ID + int: + Chapter ID + int: + Utterance ID + list: + List of biasing words in the utterance + """ + metadata = self.get_metadata(n) + waveform = _load_waveform(self._archive, metadata[0], metadata[1]) + return (waveform,) + metadata[1:] + + def __len__(self) -> int: + return len(self._walker) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/libritts.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/libritts.py new file mode 100644 index 0000000000000000000000000000000000000000..829ce9572920c31ec7a4b393379f779a7df14ea9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/libritts.py @@ -0,0 +1,168 @@ +import os +from pathlib import Path +from typing import Tuple, Union + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_tar + +URL = "train-clean-100" +FOLDER_IN_ARCHIVE = "LibriTTS" +_CHECKSUMS = { + "http://www.openslr.org/resources/60/dev-clean.tar.gz": "da0864e1bd26debed35da8a869dd5c04dfc27682921936de7cff9c8a254dbe1a", # noqa: E501 + "http://www.openslr.org/resources/60/dev-other.tar.gz": "d413eda26f3a152ac7c9cf3658ef85504dfb1b625296e5fa83727f5186cca79c", # 
noqa: E501 + "http://www.openslr.org/resources/60/test-clean.tar.gz": "234ea5b25859102a87024a4b9b86641f5b5aaaf1197335c95090cde04fe9a4f5", # noqa: E501 + "http://www.openslr.org/resources/60/test-other.tar.gz": "33a5342094f3bba7ccc2e0500b9e72d558f72eb99328ac8debe1d9080402f10d", # noqa: E501 + "http://www.openslr.org/resources/60/train-clean-100.tar.gz": "c5608bf1ef74bb621935382b8399c5cdd51cd3ee47cec51f00f885a64c6c7f6b", # noqa: E501 + "http://www.openslr.org/resources/60/train-clean-360.tar.gz": "ce7cff44dcac46009d18379f37ef36551123a1dc4e5c8e4eb73ae57260de4886", # noqa: E501 + "http://www.openslr.org/resources/60/train-other-500.tar.gz": "e35f7e34deeb2e2bdfe4403d88c8fdd5fbf64865cae41f027a185a6965f0a5df", # noqa: E501 +} + + +def load_libritts_item( + fileid: str, + path: str, + ext_audio: str, + ext_original_txt: str, + ext_normalized_txt: str, +) -> Tuple[Tensor, int, str, str, int, int, str]: + speaker_id, chapter_id, segment_id, utterance_id = fileid.split("_") + utterance_id = fileid + + normalized_text = utterance_id + ext_normalized_txt + normalized_text = os.path.join(path, speaker_id, chapter_id, normalized_text) + + original_text = utterance_id + ext_original_txt + original_text = os.path.join(path, speaker_id, chapter_id, original_text) + + file_audio = utterance_id + ext_audio + file_audio = os.path.join(path, speaker_id, chapter_id, file_audio) + + # Load audio + waveform, sample_rate = torchaudio.load(file_audio) + + # Load original text + with open(original_text) as ft: + original_text = ft.readline() + + # Load normalized text + with open(normalized_text, "r") as ft: + normalized_text = ft.readline() + + return ( + waveform, + sample_rate, + original_text, + normalized_text, + int(speaker_id), + int(chapter_id), + utterance_id, + ) + + +class LIBRITTS(Dataset): + """*LibriTTS* :cite:`Zen2019LibriTTSAC` dataset. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. 
+ url (str, optional): The URL to download the dataset from, + or the type of the dataset to dowload. + Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``, + ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and + ``"train-other-500"``. (default: ``"train-clean-100"``) + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"LibriTTS"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + """ + + _ext_original_txt = ".original.txt" + _ext_normalized_txt = ".normalized.txt" + _ext_audio = ".wav" + + def __init__( + self, + root: Union[str, Path], + url: str = URL, + folder_in_archive: str = FOLDER_IN_ARCHIVE, + download: bool = False, + ) -> None: + + if url in [ + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ]: + + ext_archive = ".tar.gz" + base_url = "http://www.openslr.org/resources/60/" + + url = os.path.join(base_url, url + ext_archive) + + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + + basename = os.path.basename(url) + archive = os.path.join(root, basename) + + basename = basename.split(".")[0] + folder_in_archive = os.path.join(folder_in_archive, basename) + + self._path = os.path.join(root, folder_in_archive) + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(url, None) + download_url_to_file(url, archive, hash_prefix=checksum) + _extract_tar(archive) + else: + if not os.path.exists(self._path): + raise RuntimeError( + f"The path {self._path} doesn't exist. 
" + "Please check the ``root`` path or set `download=True` to download it" + ) + + self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio)) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int, int, str]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Original text + str: + Normalized text + int: + Speaker ID + int: + Chapter ID + str: + Utterance ID + """ + fileid = self._walker[n] + return load_libritts_item( + fileid, + self._path, + self._ext_audio, + self._ext_original_txt, + self._ext_normalized_txt, + ) + + def __len__(self) -> int: + return len(self._walker) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/ljspeech.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/ljspeech.py new file mode 100644 index 0000000000000000000000000000000000000000..9cdaeeb0f3e67a29fc57e9d0e9ed3056d98c24df --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/ljspeech.py @@ -0,0 +1,107 @@ +import csv +import os +from pathlib import Path +from typing import Tuple, Union + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_tar + + +_RELEASE_CONFIGS = { + "release1": { + "folder_in_archive": "wavs", + "url": "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2", + "checksum": "be1a30453f28eb8dd26af4101ae40cbf2c50413b1bb21936cbcdc6fae3de8aa5", + } +} + + +class LJSPEECH(Dataset): + """*LJSpeech-1.1* :cite:`ljspeech17` dataset. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): The URL to download the dataset from. 
+ (default: ``"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"``) + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"wavs"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + """ + + def __init__( + self, + root: Union[str, Path], + url: str = _RELEASE_CONFIGS["release1"]["url"], + folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"], + download: bool = False, + ) -> None: + + self._parse_filesystem(root, url, folder_in_archive, download) + + def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None: + root = Path(root) + + basename = os.path.basename(url) + archive = root / basename + + basename = Path(basename.split(".tar.bz2")[0]) + folder_in_archive = basename / folder_in_archive + + self._path = root / folder_in_archive + self._metadata_path = root / basename / "metadata.csv" + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _RELEASE_CONFIGS["release1"]["checksum"] + download_url_to_file(url, archive, hash_prefix=checksum) + _extract_tar(archive) + else: + if not os.path.exists(self._path): + raise RuntimeError( + f"The path {self._path} doesn't exist. " + "Please check the ``root`` path or set `download=True` to download it" + ) + + with open(self._metadata_path, "r", newline="") as metadata: + flist = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE) + self._flist = list(flist) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Transcript + str: + Normalized Transcript + """ + line = self._flist[n] + fileid, transcript, normalized_transcript = line + fileid_audio = self._path / (fileid + ".wav") + + # Load audio + waveform, sample_rate = torchaudio.load(fileid_audio) + + return ( + waveform, + sample_rate, + transcript, + normalized_transcript, + ) + + def __len__(self) -> int: + return len(self._flist) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/musdb_hq.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/musdb_hq.py new file mode 100644 index 0000000000000000000000000000000000000000..dd4bc9f340f3fde076ea31a683a7b41b7b3741d7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/musdb_hq.py @@ -0,0 +1,139 @@ +import os +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import torch +import torchaudio +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_zip + +_URL = "https://zenodo.org/record/3338373/files/musdb18hq.zip" +_CHECKSUM = "baac80d0483c61d74b2e5f3be75fa557eec52898339e6aa45c1fa48833c5d21d" +_EXT = ".wav" +_SAMPLE_RATE = 44100 +_VALIDATION_SET = [ + "Actions - One Minute Smile", + "Clara Berry And Wooldog - Waltz For My Victims", + "Johnny Lokke - Promises & Lies", + "Patrick Talbot - A Reason To Leave", + "Triviul - Angelsaint", + "Alexander Ross - Goodbye Bolero", + "Fergessen - Nos Palpitants", + "Leaf - Summerghost", + "Skelpolu - Human Mistakes", + "Young Griffo - Pennies", + "ANiMAL - Rockshow", + "James May - On The Line", + "Meaxic - Take A Step", + "Traffic Experiment - Sirens", +] + + +class MUSDB_HQ(Dataset): + """*MUSDB_HQ* :cite:`MUSDB18HQ` dataset. 
+ + Args: + root (str or Path): Root directory where the dataset's top level directory is found + subset (str): Subset of the dataset to use. Options: [``"train"``, ``"test"``]. + sources (List[str] or None, optional): Sources extract data from. + List can contain the following options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``]. + If ``None``, dataset consists of tracks except mixture. + (default: ``None``) + split (str or None, optional): Whether to split training set into train and validation set. + If ``None``, no splitting occurs. If ``train`` or ``validation``, returns respective set. + (default: ``None``) + download (bool, optional): Whether to download the dataset if it is not found at root path. + (default: ``False``) + """ + + def __init__( + self, + root: Union[str, Path], + subset: str, + sources: Optional[List[str]] = None, + split: Optional[str] = None, + download: bool = False, + ) -> None: + self.sources = ["bass", "drums", "other", "vocals"] if not sources else sources + self.split = split + + basename = os.path.basename(_URL) + archive = os.path.join(root, basename) + basename = basename.rsplit(".", 2)[0] + + if subset not in ["test", "train"]: + raise ValueError("`subset` must be one of ['test', 'train']") + if self.split is not None and self.split not in ["train", "validation"]: + raise ValueError("`split` must be one of ['train', 'validation']") + base_path = os.path.join(root, basename) + self._path = os.path.join(base_path, subset) + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + if not download: + raise RuntimeError("Dataset not found. 
Please use `download=True` to download") + download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM) + os.makedirs(base_path, exist_ok=True) + _extract_zip(archive, base_path) + + self.names = self._collect_songs() + + def _get_track(self, name, source): + return Path(self._path) / name / f"{source}{_EXT}" + + def _load_sample(self, n: int) -> Tuple[torch.Tensor, int, int, str]: + name = self.names[n] + wavs = [] + num_frames = None + for source in self.sources: + track = self._get_track(name, source) + wav, sr = torchaudio.load(str(track)) + if sr != _SAMPLE_RATE: + raise ValueError(f"expected sample rate {_SAMPLE_RATE}, but got {sr}") + if num_frames is None: + num_frames = wav.shape[-1] + else: + if wav.shape[-1] != num_frames: + raise ValueError("num_frames do not match across sources") + wavs.append(wav) + + stacked = torch.stack(wavs) + + return stacked, _SAMPLE_RATE, num_frames, name + + def _collect_songs(self): + if self.split == "validation": + return _VALIDATION_SET + path = Path(self._path) + names = [] + for root, folders, _ in os.walk(path, followlinks=True): + root = Path(root) + if root.name.startswith(".") or folders or root == path: + continue + name = str(root.relative_to(path)) + if self.split and name in _VALIDATION_SET: + continue + names.append(name) + return sorted(names) + + def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, int, str]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + int: + Num frames + str: + Track name + """ + return self._load_sample(n) + + def __len__(self) -> int: + return len(self.names) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/quesst14.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/quesst14.py new file mode 100644 index 0000000000000000000000000000000000000000..064423c4494850f2ad8f43fb00a956be21fcb95e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/quesst14.py @@ -0,0 +1,136 @@ +import os +import re +from pathlib import Path +from typing import Optional, Tuple, Union + +import torch +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_tar, _load_waveform + + +URL = "https://speech.fit.vutbr.cz/files/quesst14Database.tgz" +SAMPLE_RATE = 8000 +_CHECKSUM = "4f869e06bc066bbe9c5dde31dbd3909a0870d70291110ebbb38878dcbc2fc5e4" +_LANGUAGES = [ + "albanian", + "basque", + "czech", + "nnenglish", + "romanian", + "slovak", +] + + +class QUESST14(Dataset): + """*QUESST14* :cite:`Mir2015QUESST2014EQ` dataset. + + Args: + root (str or Path): Root directory where the dataset's top level directory is found + subset (str): Subset of the dataset to use. Options: [``"docs"``, ``"dev"``, ``"eval"``]. + language (str or None, optional): Language to get dataset for. + Options: [``None``, ``albanian``, ``basque``, ``czech``, ``nnenglish``, ``romanian``, ``slovak``]. + If ``None``, dataset consists of all languages. (default: ``"nnenglish"``) + download (bool, optional): Whether to download the dataset if it is not found at root path. 
+ (default: ``False``) + """ + + def __init__( + self, + root: Union[str, Path], + subset: str, + language: Optional[str] = "nnenglish", + download: bool = False, + ) -> None: + if subset not in ["docs", "dev", "eval"]: + raise ValueError("`subset` must be one of ['docs', 'dev', 'eval']") + + if language is not None and language not in _LANGUAGES: + raise ValueError(f"`language` must be None or one of {str(_LANGUAGES)}") + + # Get string representation of 'root' + root = os.fspath(root) + + basename = os.path.basename(URL) + archive = os.path.join(root, basename) + + basename = basename.rsplit(".", 2)[0] + self._path = os.path.join(root, basename) + + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + if not download: + raise RuntimeError("Dataset not found. Please use `download=True` to download") + download_url_to_file(URL, archive, hash_prefix=_CHECKSUM) + _extract_tar(archive, root) + + if subset == "docs": + self.data = filter_audio_paths(self._path, language, "language_key_utterances.lst") + elif subset == "dev": + self.data = filter_audio_paths(self._path, language, "language_key_dev.lst") + elif subset == "eval": + self.data = filter_audio_paths(self._path, language, "language_key_eval.lst") + + def get_metadata(self, n: int) -> Tuple[str, int, str]: + """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform, + but otherwise returns the same fields as :py:func:`__getitem__`. + + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + str: + Path to audio + int: + Sample rate + str: + File name + """ + audio_path = self.data[n] + relpath = os.path.relpath(audio_path, self._path) + return relpath, SAMPLE_RATE, audio_path.with_suffix("").name + + def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + File name + """ + metadata = self.get_metadata(n) + waveform = _load_waveform(self._path, metadata[0], metadata[1]) + return (waveform,) + metadata[1:] + + def __len__(self) -> int: + return len(self.data) + + +def filter_audio_paths( + path: str, + language: str, + lst_name: str, +): + """Extract audio paths for the given language.""" + audio_paths = [] + + path = Path(path) + with open(path / "scoring" / lst_name) as f: + for line in f: + audio_path, lang = line.strip().split() + if language is not None and lang != language: + continue + audio_path = re.sub(r"^.*?\/", "", audio_path) + audio_paths.append(path / audio_path) + + return audio_paths diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/snips.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/snips.py new file mode 100644 index 0000000000000000000000000000000000000000..6b15d677f7fa1f9c1baccad7625a6fa14c73d70f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/snips.py @@ -0,0 +1,157 @@ +import os +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import torch +from torch.utils.data import Dataset +from torchaudio.datasets.utils import _load_waveform + + +_SAMPLE_RATE = 16000 +_SPEAKERS = [ + "Aditi", + "Amy", + "Brian", + "Emma", + "Geraint", + "Ivy", + "Joanna", + "Joey", + "Justin", + "Kendra", + "Kimberly", + "Matthew", + "Nicole", + "Raveena", + "Russell", + "Salli", +] + + +def _load_labels(file: Path, subset: str): + """Load transcirpt, iob, and intent labels for all utterances. + + Args: + file (Path): The path to the label file. + subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``]. 
+ + Returns: + Dictionary of labels, where the key is the filename of the audio, + and the label is a Tuple of transcript, Inside–outside–beginning (IOB) label, and intention label. + """ + labels = {} + with open(file, "r") as f: + for line in f: + line = line.strip().split(" ") + index = line[0] + trans, iob_intent = " ".join(line[1:]).split("\t") + trans = " ".join(trans.split(" ")[1:-1]) + iob = " ".join(iob_intent.split(" ")[1:-1]) + intent = iob_intent.split(" ")[-1] + if subset in index: + labels[index] = (trans, iob, intent) + return labels + + +class Snips(Dataset): + """*Snips* :cite:`coucke2018snips` dataset. + + Args: + root (str or Path): Root directory where the dataset's top level directory is found. + subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``]. + speakers (List[str] or None, optional): The speaker list to include in the dataset. If ``None``, + include all speakers in the subset. (Default: ``None``) + audio_format (str, optional): The extension of the audios. Options: [``"mp3"``, ``"wav"``]. 
+ (Default: ``"mp3"``) + """ + + _trans_file = "all.iob.snips.txt" + + def __init__( + self, + root: Union[str, Path], + subset: str, + speakers: Optional[List[str]] = None, + audio_format: str = "mp3", + ) -> None: + if subset not in ["train", "valid", "test"]: + raise ValueError('`subset` must be one of ["train", "valid", "test"].') + if audio_format not in ["mp3", "wav"]: + raise ValueError('`audio_format` must be one of ["mp3", "wav].') + + root = Path(root) + self._path = root / "SNIPS" + self.audio_path = self._path / subset + if speakers is None: + speakers = _SPEAKERS + + if not os.path.isdir(self._path): + raise RuntimeError("Dataset not found.") + + self.audio_paths = self.audio_path.glob(f"*.{audio_format}") + self.data = [] + for audio_path in sorted(self.audio_paths): + audio_name = str(audio_path.name) + speaker = audio_name.split("-")[0] + if speaker in speakers: + self.data.append(audio_path) + transcript_path = self._path / self._trans_file + self.labels = _load_labels(transcript_path, subset) + + def get_metadata(self, n: int) -> Tuple[str, int, str, str, str]: + """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform, + but otherwise returns the same fields as :py:func:`__getitem__`. + + Args: + n (int): The index of the sample to be loaded. + + Returns: + Tuple of the following items: + + str: + Path to audio + int: + Sample rate + str: + File name + str: + Transcription of audio + str: + Inside–outside–beginning (IOB) label of transcription + str: + Intention label of the audio. + """ + audio_path = self.data[n] + relpath = os.path.relpath(audio_path, self._path) + file_name = audio_path.with_suffix("").name + transcript, iob, intent = self.labels[file_name] + return relpath, _SAMPLE_RATE, file_name, transcript, iob, intent + + def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items: + + Tensor: + Waveform + int: + Sample rate + str: + File name + str: + Transcription of audio + str: + Inside–outside–beginning (IOB) label of transcription + str: + Intention label of the audio. + """ + metadata = self.get_metadata(n) + waveform = _load_waveform(self._path, metadata[0], metadata[1]) + return (waveform,) + metadata[1:] + + def __len__(self) -> int: + return len(self.data) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/speechcommands.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/speechcommands.py new file mode 100644 index 0000000000000000000000000000000000000000..1945fc75c18b474404b733e43d50156f3c3d6652 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/speechcommands.py @@ -0,0 +1,183 @@ +import os +from pathlib import Path +from typing import Optional, Tuple, Union + +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_tar, _load_waveform + +FOLDER_IN_ARCHIVE = "SpeechCommands" +URL = "speech_commands_v0.02" +HASH_DIVIDER = "_nohash_" +EXCEPT_FOLDER = "_background_noise_" +SAMPLE_RATE = 16000 +_CHECKSUMS = { + "http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz": "743935421bb51cccdb6bdd152e04c5c70274e935c82119ad7faeec31780d811d", # noqa: E501 + "http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz": "af14739ee7dc311471de98f5f9d2c9191b18aedfe957f4a6ff791c709868ff58", # noqa: E501 +} + + +def _load_list(root, *filenames): + output = [] + for filename in filenames: + filepath = os.path.join(root, filename) + with open(filepath) as fileobj: + output += [os.path.normpath(os.path.join(root, line.strip())) for line in fileobj] + return output + + +def _get_speechcommands_metadata(filepath: str, path: str) -> Tuple[str, int, str, str, int]: + relpath = 
os.path.relpath(filepath, path) + reldir, filename = os.path.split(relpath) + _, label = os.path.split(reldir) + # Besides the officially supported split method for datasets defined by "validation_list.txt" + # and "testing_list.txt" over "speech_commands_v0.0x.tar.gz" archives, an alternative split + # method referred to in paragraph 2-3 of Section 7.1, references 13 and 14 of the original + # paper, and the checksums file from the tensorflow_datasets package [1] is also supported. + # Some filenames in those "speech_commands_test_set_v0.0x.tar.gz" archives have the form + # "xxx.wav.wav", so file extensions twice needs to be stripped twice. + # [1] https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/url_checksums/speech_commands.txt + speaker, _ = os.path.splitext(filename) + speaker, _ = os.path.splitext(speaker) + + speaker_id, utterance_number = speaker.split(HASH_DIVIDER) + utterance_number = int(utterance_number) + + return relpath, SAMPLE_RATE, label, speaker_id, utterance_number + + +class SPEECHCOMMANDS(Dataset): + """*Speech Commands* :cite:`speechcommandsv2` dataset. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): The URL to download the dataset from, + or the type of the dataset to dowload. + Allowed type values are ``"speech_commands_v0.01"`` and ``"speech_commands_v0.02"`` + (default: ``"speech_commands_v0.02"``) + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"SpeechCommands"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + subset (str or None, optional): + Select a subset of the dataset [None, "training", "validation", "testing"]. None means + the whole dataset. "validation" and "testing" are defined in "validation_list.txt" and + "testing_list.txt", respectively, and "training" is the rest. 
Details for the files + "validation_list.txt" and "testing_list.txt" are explained in the README of the dataset + and in the introduction of Section 7 of the original paper and its reference 12. The + original paper can be found `here `_. (Default: ``None``) + """ + + def __init__( + self, + root: Union[str, Path], + url: str = URL, + folder_in_archive: str = FOLDER_IN_ARCHIVE, + download: bool = False, + subset: Optional[str] = None, + ) -> None: + + if subset is not None and subset not in ["training", "validation", "testing"]: + raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].") + + if url in [ + "speech_commands_v0.01", + "speech_commands_v0.02", + ]: + base_url = "http://download.tensorflow.org/data/" + ext_archive = ".tar.gz" + + url = os.path.join(base_url, url + ext_archive) + + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + self._archive = os.path.join(root, folder_in_archive) + + basename = os.path.basename(url) + archive = os.path.join(root, basename) + + basename = basename.rsplit(".", 2)[0] + folder_in_archive = os.path.join(folder_in_archive, basename) + + self._path = os.path.join(root, folder_in_archive) + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(url, None) + download_url_to_file(url, archive, hash_prefix=checksum) + _extract_tar(archive, self._path) + else: + if not os.path.exists(self._path): + raise RuntimeError( + f"The path {self._path} doesn't exist. 
" + "Please check the ``root`` path or set `download=True` to download it" + ) + + if subset == "validation": + self._walker = _load_list(self._path, "validation_list.txt") + elif subset == "testing": + self._walker = _load_list(self._path, "testing_list.txt") + elif subset == "training": + excludes = set(_load_list(self._path, "validation_list.txt", "testing_list.txt")) + walker = sorted(str(p) for p in Path(self._path).glob("*/*.wav")) + self._walker = [ + w + for w in walker + if HASH_DIVIDER in w and EXCEPT_FOLDER not in w and os.path.normpath(w) not in excludes + ] + else: + walker = sorted(str(p) for p in Path(self._path).glob("*/*.wav")) + self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w] + + def get_metadata(self, n: int) -> Tuple[str, int, str, str, int]: + """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform, + but otherwise returns the same fields as :py:func:`__getitem__`. + + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + str: + Path to the audio + int: + Sample rate + str: + Label + str: + Speaker ID + int: + Utterance number + """ + fileid = self._walker[n] + return _get_speechcommands_metadata(fileid, self._archive) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Label + str: + Speaker ID + int: + Utterance number + """ + metadata = self.get_metadata(n) + waveform = _load_waveform(self._archive, metadata[0], metadata[1]) + return (waveform,) + metadata[1:] + + def __len__(self) -> int: + return len(self._walker) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/tedlium.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/tedlium.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7d22195a772d18770f6db3253d83672743c81c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/tedlium.py @@ -0,0 +1,218 @@ +import os +from pathlib import Path +from typing import Tuple, Union + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_tar + + +_RELEASE_CONFIGS = { + "release1": { + "folder_in_archive": "TEDLIUM_release1", + "url": "http://www.openslr.org/resources/7/TEDLIUM_release1.tar.gz", + "checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27", + "data_path": "", + "subset": "train", + "supported_subsets": ["train", "test", "dev"], + "dict": "TEDLIUM.150K.dic", + }, + "release2": { + "folder_in_archive": "TEDLIUM_release2", + "url": "http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz", + "checksum": "93281b5fcaaae5c88671c9d000b443cb3c7ea3499ad12010b3934ca41a7b9c58", + "data_path": "", + "subset": "train", + "supported_subsets": ["train", "test", "dev"], + "dict": "TEDLIUM.152k.dic", + }, + "release3": { + "folder_in_archive": "TEDLIUM_release-3", + "url": "http://www.openslr.org/resources/51/TEDLIUM_release-3.tgz", + "checksum": "ad1e454d14d1ad550bc2564c462d87c7a7ec83d4dc2b9210f22ab4973b9eccdb", + "data_path": "data/", + "subset": "train", + 
"supported_subsets": ["train", "test", "dev"], + "dict": "TEDLIUM.152k.dic", + }, +} + + +class TEDLIUM(Dataset): + """*Tedlium* :cite:`rousseau2012tedlium` dataset (releases 1,2 and 3). + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + release (str, optional): Release version. + Allowed values are ``"release1"``, ``"release2"`` or ``"release3"``. + (default: ``"release1"``). + subset (str, optional): The subset of dataset to use. Valid options are ``"train"``, ``"dev"``, + and ``"test"``. Defaults to ``"train"``. + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + audio_ext (str, optional): extension for audio file (default: ``".sph"``) + """ + + def __init__( + self, + root: Union[str, Path], + release: str = "release1", + subset: str = "train", + download: bool = False, + audio_ext: str = ".sph", + ) -> None: + self._ext_audio = audio_ext + if release in _RELEASE_CONFIGS.keys(): + folder_in_archive = _RELEASE_CONFIGS[release]["folder_in_archive"] + url = _RELEASE_CONFIGS[release]["url"] + subset = subset if subset else _RELEASE_CONFIGS[release]["subset"] + else: + # Raise warning + raise RuntimeError( + "The release {} does not match any of the supported tedlium releases{} ".format( + release, + _RELEASE_CONFIGS.keys(), + ) + ) + if subset not in _RELEASE_CONFIGS[release]["supported_subsets"]: + # Raise warning + raise RuntimeError( + "The subset {} does not match any of the supported tedlium subsets{} ".format( + subset, + _RELEASE_CONFIGS[release]["supported_subsets"], + ) + ) + + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + + basename = os.path.basename(url) + archive = os.path.join(root, basename) + + basename = basename.split(".")[0] + + if release == "release3": + if subset == "train": + self._path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["data_path"]) + else: + 
self._path = os.path.join(root, folder_in_archive, "legacy", subset) + else: + self._path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["data_path"], subset) + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _RELEASE_CONFIGS[release]["checksum"] + download_url_to_file(url, archive, hash_prefix=checksum) + _extract_tar(archive) + else: + if not os.path.exists(self._path): + raise RuntimeError( + f"The path {self._path} doesn't exist. " + "Please check the ``root`` path or set `download=True` to download it" + ) + + # Create list for all samples + self._filelist = [] + stm_path = os.path.join(self._path, "stm") + for file in sorted(os.listdir(stm_path)): + if file.endswith(".stm"): + stm_path = os.path.join(self._path, "stm", file) + with open(stm_path) as f: + l = len(f.readlines()) + file = file.replace(".stm", "") + self._filelist.extend((file, line) for line in range(l)) + # Create dict path for later read + self._dict_path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["dict"]) + self._phoneme_dict = None + + def _load_tedlium_item(self, fileid: str, line: int, path: str) -> Tuple[Tensor, int, str, int, int, int]: + """Loads a TEDLIUM dataset sample given a file name and corresponding sentence name. 
+ + Args: + fileid (str): File id to identify both text and audio files corresponding to the sample + line (int): Line identifier for the sample inside the text file + path (str): Dataset root path + + Returns: + (Tensor, int, str, int, int, int): + ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)`` + """ + transcript_path = os.path.join(path, "stm", fileid) + with open(transcript_path + ".stm") as f: + transcript = f.readlines()[line] + talk_id, _, speaker_id, start_time, end_time, identifier, transcript = transcript.split(" ", 6) + + wave_path = os.path.join(path, "sph", fileid) + waveform, sample_rate = self._load_audio(wave_path + self._ext_audio, start_time=start_time, end_time=end_time) + + return (waveform, sample_rate, transcript, talk_id, speaker_id, identifier) + + def _load_audio(self, path: str, start_time: float, end_time: float, sample_rate: int = 16000) -> [Tensor, int]: + """Default load function used in TEDLIUM dataset, you can overwrite this function to customize functionality + and load individual sentences from a full ted audio talk file. + + Args: + path (str): Path to audio file + start_time (int): Time in seconds where the sample sentence stars + end_time (int): Time in seconds where the sample sentence finishes + sample_rate (float, optional): Sampling rate + + Returns: + [Tensor, int]: Audio tensor representation and sample rate + """ + start_time = int(float(start_time) * sample_rate) + end_time = int(float(end_time) * sample_rate) + + kwargs = {"frame_offset": start_time, "num_frames": end_time - start_time} + + return torchaudio.load(path, **kwargs) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Transcript + int: + Talk ID + int: + Speaker ID + int: + Identifier + """ + fileid, line = self._filelist[n] + return self._load_tedlium_item(fileid, line, self._path) + + def __len__(self) -> int: + """TEDLIUM dataset custom function overwritting len default behaviour. + + Returns: + int: TEDLIUM dataset length + """ + return len(self._filelist) + + @property + def phoneme_dict(self): + """dict[str, tuple[str]]: Phonemes. Mapping from word to tuple of phonemes. + Note that some words have empty phonemes. + """ + # Read phoneme dictionary + if not self._phoneme_dict: + self._phoneme_dict = {} + with open(self._dict_path, "r", encoding="utf-8") as f: + for line in f.readlines(): + content = line.strip().split() + self._phoneme_dict[content[0]] = tuple(content[1:]) # content[1:] can be empty list + return self._phoneme_dict.copy() diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/utils.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4599f83aae535d5c4126b5d0bab4ed325f494f3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/utils.py @@ -0,0 +1,54 @@ +import logging +import os +import tarfile +import zipfile +from typing import Any, List, Optional + +import torchaudio + +_LG = logging.getLogger(__name__) + + +def _extract_tar(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]: + if to_path is None: + to_path = os.path.dirname(from_path) + with tarfile.open(from_path, "r") as tar: + files = [] + for file_ in tar: # type: Any + file_path = os.path.join(to_path, file_.name) + if file_.isfile(): + files.append(file_path) + if os.path.exists(file_path): + _LG.info("%s already extracted.", file_path) + if not overwrite: + continue + tar.extract(file_, 
to_path) + return files + + +def _extract_zip(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]: + if to_path is None: + to_path = os.path.dirname(from_path) + + with zipfile.ZipFile(from_path, "r") as zfile: + files = zfile.namelist() + for file_ in files: + file_path = os.path.join(to_path, file_) + if os.path.exists(file_path): + _LG.info("%s already extracted.", file_path) + if not overwrite: + continue + zfile.extract(file_, to_path) + return files + + +def _load_waveform( + root: str, + filename: str, + exp_sample_rate: int, +): + path = os.path.join(root, filename) + waveform, sample_rate = torchaudio.load(path) + if exp_sample_rate != sample_rate: + raise ValueError(f"sample rate should be {exp_sample_rate}, but got {sample_rate}") + return waveform diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/vctk.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/vctk.py new file mode 100644 index 0000000000000000000000000000000000000000..3195b9b4276b643e934baadc26c872fc690383df --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/vctk.py @@ -0,0 +1,143 @@ +import os +from typing import Tuple + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_zip + +URL = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip" +_CHECKSUMS = { + "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip": "f96258be9fdc2cbff6559541aae7ea4f59df3fcaf5cf963aae5ca647357e359c" # noqa: E501 +} + + +SampleType = Tuple[Tensor, int, str, str, str] + + +class VCTK_092(Dataset): + """*VCTK 0.92* :cite:`yamagishi2019vctk` dataset + + Args: + root (str): Root directory where the dataset's top level directory is found. + mic_id (str, optional): Microphone ID. Either ``"mic1"`` or ``"mic2"``. 
(default: ``"mic2"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + url (str, optional): The URL to download the dataset from. + (default: ``"https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"``) + audio_ext (str, optional): Custom audio extension if dataset is converted to non-default audio format. + + Note: + * All the speeches from speaker ``p315`` will be skipped due to the lack of the corresponding text files. + * All the speeches from ``p280`` will be skipped for ``mic_id="mic2"`` due to the lack of the audio files. + * Some of the speeches from speaker ``p362`` will be skipped due to the lack of the audio files. + * See Also: https://datashare.is.ed.ac.uk/handle/10283/3443 + """ + + def __init__( + self, + root: str, + mic_id: str = "mic2", + download: bool = False, + url: str = URL, + audio_ext=".flac", + ): + if mic_id not in ["mic1", "mic2"]: + raise RuntimeError(f'`mic_id` has to be either "mic1" or "mic2". Found: {mic_id}') + + archive = os.path.join(root, "VCTK-Corpus-0.92.zip") + + self._path = os.path.join(root, "VCTK-Corpus-0.92") + self._txt_dir = os.path.join(self._path, "txt") + self._audio_dir = os.path.join(self._path, "wav48_silence_trimmed") + self._mic_id = mic_id + self._audio_ext = audio_ext + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(url, None) + download_url_to_file(url, archive, hash_prefix=checksum) + _extract_zip(archive, self._path) + + if not os.path.isdir(self._path): + raise RuntimeError("Dataset not found. Please use `download=True` to download it.") + + # Extracting speaker IDs from the folder structure + self._speaker_ids = sorted(os.listdir(self._txt_dir)) + self._sample_ids = [] + + """ + Due to some insufficient data complexity in the 0.92 version of this dataset, + we start traversing the audio folder structure in accordance with the text folder. 
+        As some of the audio files are missing for either ``mic_1`` or ``mic_2`` even though
+        the corresponding text is present, we first check for the existence of the audio file
+        before adding it to the ``sample_ids`` list.
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Transcript + str: + Speaker ID + std: + Utterance ID + """ + speaker_id, utterance_id = self._sample_ids[n] + return self._load_sample(speaker_id, utterance_id, self._mic_id) + + def __len__(self) -> int: + return len(self._sample_ids) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/voxceleb1.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/voxceleb1.py new file mode 100644 index 0000000000000000000000000000000000000000..5112fff0898a88adb1d2c33acf9bdd905ca883f3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/voxceleb1.py @@ -0,0 +1,309 @@ +import os +from pathlib import Path +from typing import List, Tuple, Union + +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_zip, _load_waveform + + +SAMPLE_RATE = 16000 +_ARCHIVE_CONFIGS = { + "dev": { + "archive_name": "vox1_dev_wav.zip", + "urls": [ + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad", + ], + "checksums": [ + "21ec6ca843659ebc2fdbe04b530baa4f191ad4b0971912672d92c158f32226a0", + "311d21e0c8cbf33573a4fce6c80e5a279d80736274b381c394319fc557159a04", + "92b64465f2b2a3dc0e4196ae8dd6828cbe9ddd1f089419a11e4cbfe2e1750df0", + "00e6190c770b27f27d2a3dd26ee15596b17066b715ac111906861a7d09a211a5", + ], + }, + "test": { + "archive_name": "vox1_test_wav.zip", + "url": "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip", + "checksum": "8de57f347fe22b2c24526e9f444f689ecf5096fc2a92018cf420ff6b5b15eaea", + }, +} 
+_IDEN_SPLIT_URL = "https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/iden_split.txt" +_VERI_TEST_URL = "https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test.txt" + + +def _download_extract_wavs(root: str): + for archive in ["dev", "test"]: + archive_name = _ARCHIVE_CONFIGS[archive]["archive_name"] + archive_path = os.path.join(root, archive_name) + # The zip file of dev data is splited to 4 chunks. + # Download and combine them into one file before extraction. + if archive == "dev": + urls = _ARCHIVE_CONFIGS[archive]["urls"] + checksums = _ARCHIVE_CONFIGS[archive]["checksums"] + with open(archive_path, "wb") as f: + for url, checksum in zip(urls, checksums): + file_path = os.path.join(root, os.path.basename(url)) + download_url_to_file(url, file_path, hash_prefix=checksum) + with open(file_path, "rb") as f_split: + f.write(f_split.read()) + else: + url = _ARCHIVE_CONFIGS[archive]["url"] + checksum = _ARCHIVE_CONFIGS[archive]["checksum"] + download_url_to_file(url, archive_path, hash_prefix=checksum) + _extract_zip(archive_path) + + +def _get_flist(root: str, file_path: str, subset: str) -> List[str]: + f_list = [] + if subset == "train": + index = 1 + elif subset == "dev": + index = 2 + else: + index = 3 + with open(file_path, "r") as f: + for line in f: + id, path = line.split() + if int(id) == index: + f_list.append(path) + return sorted(f_list) + + +def _get_paired_flist(root: str, veri_test_path: str): + f_list = [] + with open(veri_test_path, "r") as f: + for line in f: + label, path1, path2 = line.split() + f_list.append((label, path1, path2)) + return f_list + + +def _get_file_id(file_path: str, _ext_audio: str): + speaker_id, youtube_id, utterance_id = file_path.split("/")[-3:] + utterance_id = utterance_id.replace(_ext_audio, "") + file_id = "-".join([speaker_id, youtube_id, utterance_id]) + return file_id + + +class VoxCeleb1(Dataset): + """*VoxCeleb1* :cite:`nagrani2017voxceleb` dataset. 
+ + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + download (bool, optional): + Whether to download the dataset if it is not found at root path. (Default: ``False``). + """ + + _ext_audio = ".wav" + + def __init__(self, root: Union[str, Path], download: bool = False) -> None: + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + self._path = os.path.join(root, "wav") + if not os.path.isdir(self._path): + if not download: + raise RuntimeError( + f"Dataset not found at {self._path}. Please set `download=True` to download the dataset." + ) + _download_extract_wavs(root) + + def get_metadata(self, n: int): + raise NotImplementedError + + def __getitem__(self, n: int): + raise NotImplementedError + + def __len__(self) -> int: + raise NotImplementedError + + +class VoxCeleb1Identification(VoxCeleb1): + """*VoxCeleb1* :cite:`nagrani2017voxceleb` dataset for speaker identification task. + + Each data sample contains the waveform, sample rate, speaker id, and the file id. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + subset (str, optional): Subset of the dataset to use. Options: ["train", "dev", "test"]. (Default: ``"train"``) + meta_url (str, optional): The url of meta file that contains the list of subset labels and file paths. + The format of each row is ``subset file_path". For example: ``1 id10006/nLEBBc9oIFs/00003.wav``. + ``1``, ``2``, ``3`` mean ``train``, ``dev``, and ``test`` subest, respectively. + (Default: ``"https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/iden_split.txt"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (Default: ``False``). 
+ + Note: + The file structure of `VoxCeleb1Identification` dataset is as follows: + + └─ root/ + + └─ wav/ + + └─ speaker_id folders + + Users who pre-downloaded the ``"vox1_dev_wav.zip"`` and ``"vox1_test_wav.zip"`` files need to move + the extracted files into the same ``root`` directory. + """ + + def __init__( + self, root: Union[str, Path], subset: str = "train", meta_url: str = _IDEN_SPLIT_URL, download: bool = False + ) -> None: + super().__init__(root, download) + if subset not in ["train", "dev", "test"]: + raise ValueError("`subset` must be one of ['train', 'dev', 'test']") + # download the iden_split.txt to get the train, dev, test lists. + meta_list_path = os.path.join(root, os.path.basename(meta_url)) + if not os.path.exists(meta_list_path): + download_url_to_file(meta_url, meta_list_path) + self._flist = _get_flist(self._path, meta_list_path, subset) + + def get_metadata(self, n: int) -> Tuple[str, int, int, str]: + """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform, + but otherwise returns the same fields as :py:func:`__getitem__`. + + Args: + n (int): The index of the sample + + Returns: + Tuple of the following items; + + str: + Path to audio + int: + Sample rate + int: + Speaker ID + str: + File ID + """ + file_path = self._flist[n] + file_id = _get_file_id(file_path, self._ext_audio) + speaker_id = file_id.split("-")[0] + speaker_id = int(speaker_id[3:]) + return file_path, SAMPLE_RATE, speaker_id, file_id + + def __getitem__(self, n: int) -> Tuple[Tensor, int, int, str]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + int: + Speaker ID + str: + File ID + """ + metadata = self.get_metadata(n) + waveform = _load_waveform(self._path, metadata[0], metadata[1]) + return (waveform,) + metadata[1:] + + def __len__(self) -> int: + return len(self._flist) + + +class VoxCeleb1Verification(VoxCeleb1): + """*VoxCeleb1* :cite:`nagrani2017voxceleb` dataset for speaker verification task. + + Each data sample contains a pair of waveforms, sample rate, the label indicating if they are + from the same speaker, and the file ids. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + meta_url (str, optional): The url of meta file that contains a list of utterance pairs + and the corresponding labels. The format of each row is ``label file_path1 file_path2". + For example: ``1 id10270/x6uYqmx31kE/00001.wav id10270/8jEAjG6SegY/00008.wav``. + ``1`` means the two utterances are from the same speaker, ``0`` means not. + (Default: ``"https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test.txt"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (Default: ``False``). + + Note: + The file structure of `VoxCeleb1Verification` dataset is as follows: + + └─ root/ + + └─ wav/ + + └─ speaker_id folders + + Users who pre-downloaded the ``"vox1_dev_wav.zip"`` and ``"vox1_test_wav.zip"`` files need to move + the extracted files into the same ``root`` directory. + """ + + def __init__(self, root: Union[str, Path], meta_url: str = _VERI_TEST_URL, download: bool = False) -> None: + super().__init__(root, download) + # download the veri_test.txt to get the list of training pairs and labels. 
+ meta_list_path = os.path.join(root, os.path.basename(meta_url)) + if not os.path.exists(meta_list_path): + download_url_to_file(meta_url, meta_list_path) + self._flist = _get_paired_flist(self._path, meta_list_path) + + def get_metadata(self, n: int) -> Tuple[str, str, int, int, str, str]: + """Get metadata for the n-th sample from the dataset. Returns filepaths instead of waveforms, + but otherwise returns the same fields as :py:func:`__getitem__`. + + Args: + n (int): The index of the sample + + Returns: + Tuple of the following items; + + str: + Path to audio file of speaker 1 + str: + Path to audio file of speaker 2 + int: + Sample rate + int: + Label + str: + File ID of speaker 1 + str: + File ID of speaker 2 + """ + label, file_path_spk1, file_path_spk2 = self._flist[n] + label = int(label) + file_id_spk1 = _get_file_id(file_path_spk1, self._ext_audio) + file_id_spk2 = _get_file_id(file_path_spk2, self._ext_audio) + return file_path_spk1, file_path_spk2, SAMPLE_RATE, label, file_id_spk1, file_id_spk2 + + def __getitem__(self, n: int) -> Tuple[Tensor, Tensor, int, int, str, str]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded. 
+ + Returns: + Tuple of the following items; + + Tensor: + Waveform of speaker 1 + Tensor: + Waveform of speaker 2 + int: + Sample rate + int: + Label + str: + File ID of speaker 1 + str: + File ID of speaker 2 + """ + metadata = self.get_metadata(n) + waveform_spk1 = _load_waveform(self._path, metadata[0], metadata[2]) + waveform_spk2 = _load_waveform(self._path, metadata[1], metadata[2]) + return (waveform_spk1, waveform_spk2) + metadata[2:] + + def __len__(self) -> int: + return len(self._flist) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/datasets/yesno.py b/.venv/lib/python3.11/site-packages/torchaudio/datasets/yesno.py new file mode 100644 index 0000000000000000000000000000000000000000..baad08f1593a49af5f95658e8d4b67be6d3deeb9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/datasets/yesno.py @@ -0,0 +1,89 @@ +import os +from pathlib import Path +from typing import List, Tuple, Union + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio._internal import download_url_to_file +from torchaudio.datasets.utils import _extract_tar + + +_RELEASE_CONFIGS = { + "release1": { + "folder_in_archive": "waves_yesno", + "url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz", + "checksum": "c3f49e0cca421f96b75b41640749167b52118f232498667ca7a5f9416aef8e73", + } +} + + +class YESNO(Dataset): + """*YesNo* :cite:`YesNo` dataset. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): The URL to download the dataset from. + (default: ``"http://www.openslr.org/resources/1/waves_yesno.tar.gz"``) + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"waves_yesno"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). 
+ """ + + def __init__( + self, + root: Union[str, Path], + url: str = _RELEASE_CONFIGS["release1"]["url"], + folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"], + download: bool = False, + ) -> None: + + self._parse_filesystem(root, url, folder_in_archive, download) + + def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None: + root = Path(root) + archive = os.path.basename(url) + archive = root / archive + + self._path = root / folder_in_archive + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _RELEASE_CONFIGS["release1"]["checksum"] + download_url_to_file(url, archive, hash_prefix=checksum) + _extract_tar(archive) + + if not os.path.isdir(self._path): + raise RuntimeError("Dataset not found. Please use `download=True` to download it.") + + self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*.wav")) + + def _load_item(self, fileid: str, path: str): + labels = [int(c) for c in fileid.split("_")] + file_audio = os.path.join(path, fileid + ".wav") + waveform, sample_rate = torchaudio.load(file_audio) + return waveform, sample_rate, labels + + def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]: + """Load the n-th sample from the dataset. 
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + List[int]: + labels + """ + fileid = self._walker[n] + item = self._load_item(fileid, self._path) + return item + + def __len__(self) -> int: + return len(self._walker) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/io/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5486424dd21a0f1aeb8ea2cead737309fb7f9a7a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/io/__init__.py @@ -0,0 +1,13 @@ +from torio.io import CodecConfig, StreamingMediaDecoder as StreamReader, StreamingMediaEncoder as StreamWriter + +from ._effector import AudioEffector +from ._playback import play_audio + + +__all__ = [ + "AudioEffector", + "StreamReader", + "StreamWriter", + "CodecConfig", + "play_audio", +] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31dd8fcdf3fc65eada72a8e6e5fb73c706f1c0f2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/_effector.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/_effector.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3946c04e3c456f2a7029706301a6f3ae842f24db Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/_effector.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/_playback.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/_playback.cpython-311.pyc new file 
mode 100644 index 0000000000000000000000000000000000000000..915a85ed3089cc2f603f07ddf7df294b54ac1759 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/_playback.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/io/_effector.py b/.venv/lib/python3.11/site-packages/torchaudio/io/_effector.py new file mode 100644 index 0000000000000000000000000000000000000000..74255684c8fa75789e88fc224bcdac12aa1b29cf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/io/_effector.py @@ -0,0 +1,347 @@ +import io +from typing import Iterator, List, Optional + +import torch +from torch import Tensor + +from torio.io._streaming_media_decoder import _get_afilter_desc, StreamingMediaDecoder as StreamReader +from torio.io._streaming_media_encoder import CodecConfig, StreamingMediaEncoder as StreamWriter + + +class _StreamingIOBuffer: + """Streaming Bytes IO buffer. Data are dropped when read.""" + + def __init__(self): + self._buffer: List(bytes) = [] + + def write(self, b: bytes): + if b: + self._buffer.append(b) + return len(b) + + def pop(self, n): + """Pop the oldest byte string. It does not necessary return the requested amount""" + if not self._buffer: + return b"" + if len(self._buffer[0]) <= n: + return self._buffer.pop(0) + ret = self._buffer[0][:n] + self._buffer[0] = self._buffer[0][n:] + return ret + + +def _get_sample_fmt(dtype: torch.dtype): + types = { + torch.uint8: "u8", + torch.int16: "s16", + torch.int32: "s32", + torch.float32: "flt", + torch.float64: "dbl", + } + if dtype not in types: + raise ValueError(f"Unsupported dtype is provided {dtype}. 
Supported dtypes are: {types.keys()}") + return types[dtype] + + +class _AudioStreamingEncoder: + """Given a waveform, encode on-demand and return bytes""" + + def __init__( + self, + src: Tensor, + sample_rate: int, + effect: str, + muxer: str, + encoder: Optional[str], + codec_config: Optional[CodecConfig], + frames_per_chunk: int, + ): + self.src = src + self.buffer = _StreamingIOBuffer() + self.writer = StreamWriter(self.buffer, format=muxer) + self.writer.add_audio_stream( + num_channels=src.size(1), + sample_rate=sample_rate, + format=_get_sample_fmt(src.dtype), + encoder=encoder, + filter_desc=effect, + codec_config=codec_config, + ) + self.writer.open() + self.fpc = frames_per_chunk + + # index on the input tensor (along time-axis) + # we use -1 to indicate that we finished iterating the tensor and + # the writer is closed. + self.i_iter = 0 + + def read(self, n): + while not self.buffer._buffer and self.i_iter >= 0: + self.writer.write_audio_chunk(0, self.src[self.i_iter : self.i_iter + self.fpc]) + self.i_iter += self.fpc + if self.i_iter >= self.src.size(0): + self.writer.flush() + self.writer.close() + self.i_iter = -1 + return self.buffer.pop(n) + + +def _encode( + src: Tensor, + sample_rate: int, + effect: str, + muxer: str, + encoder: Optional[str], + codec_config: Optional[CodecConfig], +): + buffer = io.BytesIO() + writer = StreamWriter(buffer, format=muxer) + writer.add_audio_stream( + num_channels=src.size(1), + sample_rate=sample_rate, + format=_get_sample_fmt(src.dtype), + encoder=encoder, + filter_desc=effect, + codec_config=codec_config, + ) + with writer.open(): + writer.write_audio_chunk(0, src) + buffer.seek(0) + return buffer + + +def _get_muxer(dtype: torch.dtype): + # TODO: check if this works in Windows. + types = { + torch.uint8: "u8", + torch.int16: "s16le", + torch.int32: "s32le", + torch.float32: "f32le", + torch.float64: "f64le", + } + if dtype not in types: + raise ValueError(f"Unsupported dtype is provided {dtype}. 
Supported dtypes are: {types.keys()}") + return types[dtype] + + +class AudioEffector: + """Apply various filters and/or codecs to waveforms. + + .. versionadded:: 2.1 + + Args: + effect (str or None, optional): Filter expressions or ``None`` to apply no filter. + See https://ffmpeg.org/ffmpeg-filters.html#Audio-Filters for the + details of filter syntax. + + format (str or None, optional): When provided, encode the audio into the + corresponding format. Default: ``None``. + + encoder (str or None, optional): When provided, override the encoder used + by the ``format``. Default: ``None``. + + codec_config (CodecConfig or None, optional): When provided, configure the encoding codec. + Should be provided in conjunction with ``format`` option. + + pad_end (bool, optional): When enabled, and if the waveform becomes shorter after applying + effects/codec, then pad the end with silence. + + Example - Basic usage + To use ``AudioEffector``, first instantiate it with a set of + ``effect`` and ``format``. + + >>> # instantiate the effector + >>> effector = AudioEffector(effect=..., format=...) + + Then, use :py:meth:`~AudioEffector.apply` or :py:meth:`~AudioEffector.stream` + method to apply them. + + >>> # Apply the effect to the whole waveform + >>> applied = effector.apply(waveform, sample_rate) + + >>> # Apply the effect chunk-by-chunk + >>> for chunk in effector.stream(waveform, sample_rate): + >>> ... + + Example - Applying effects + Please refer to + https://ffmpeg.org/ffmpeg-filters.html#Filtergraph-description + for the overview of filter description, and + https://ffmpeg.org/ffmpeg-filters.html#toc-Audio-Filters + for the list of available filters. 
+ + Tempo - https://ffmpeg.org/ffmpeg-filters.html#atempo + + >>> AudioEffector(effect="atempo=1.5") + + Echo - https://ffmpeg.org/ffmpeg-filters.html#aecho + + >>> AudioEffector(effect="aecho=0.8:0.88:60:0.4") + + Flanger - https://ffmpeg.org/ffmpeg-filters.html#flanger + + >>> AudioEffector(effect="aflanger") + + Vibrato - https://ffmpeg.org/ffmpeg-filters.html#vibrato + + >>> AudioEffector(effect="vibrato") + + Tremolo - https://ffmpeg.org/ffmpeg-filters.html#tremolo + + >>> AudioEffector(effect="vibrato") + + You can also apply multiple effects at once. + + >>> AudioEffector(effect="") + + Example - Applying codec + One can apply codec using ``format`` argument. ``format`` can be + audio format or container format. If the container format supports + multiple encoders, you can specify it with ``encoder`` argument. + + Wav format + (no compression is applied but samples are converted to + 16-bit signed integer) + + >>> AudioEffector(format="wav") + + Ogg format with default encoder + + >>> AudioEffector(format="ogg") + + Ogg format with vorbis + + >>> AudioEffector(format="ogg", encoder="vorbis") + + Ogg format with opus + + >>> AudioEffector(format="ogg", encoder="opus") + + Webm format with opus + + >>> AudioEffector(format="webm", encoder="opus") + + Example - Applying codec with configuration + Reference: https://trac.ffmpeg.org/wiki/Encode/MP3 + + MP3 with default config + + >>> AudioEffector(format="mp3") + + MP3 with variable bitrate + + >>> AudioEffector(format="mp3", codec_config=CodecConfig(qscale=5)) + + MP3 with constant bitrate + + >>> AudioEffector(format="mp3", codec_config=CodecConfig(bit_rate=32_000)) + """ + + def __init__( + self, + effect: Optional[str] = None, + format: Optional[str] = None, + *, + encoder: Optional[str] = None, + codec_config: Optional[CodecConfig] = None, + pad_end: bool = True, + ): + if format is None: + if encoder is not None or codec_config is not None: + raise ValueError("`encoder` and/or `condec_config` opions are 
provided without `format` option.") + self.effect = effect + self.format = format + self.encoder = encoder + self.codec_config = codec_config + self.pad_end = pad_end + + def _get_reader(self, waveform, sample_rate, output_sample_rate, frames_per_chunk=None): + num_frames, num_channels = waveform.shape + + if self.format is not None: + muxer = self.format + encoder = self.encoder + option = {} + # Some formats are headerless, so need to provide these infomation. + if self.format == "mulaw": + option = {"sample_rate": f"{sample_rate}", "channels": f"{num_channels}"} + + else: # PCM + muxer = _get_muxer(waveform.dtype) + encoder = None + option = {"sample_rate": f"{sample_rate}", "channels": f"{num_channels}"} + + if frames_per_chunk is None: + src = _encode(waveform, sample_rate, self.effect, muxer, encoder, self.codec_config) + else: + src = _AudioStreamingEncoder( + waveform, sample_rate, self.effect, muxer, encoder, self.codec_config, frames_per_chunk + ) + + output_sr = sample_rate if output_sample_rate is None else output_sample_rate + filter_desc = _get_afilter_desc(output_sr, _get_sample_fmt(waveform.dtype), num_channels) + if self.pad_end: + filter_desc = f"{filter_desc},apad=whole_len={num_frames}" + + reader = StreamReader(src, format=muxer, option=option) + reader.add_audio_stream(frames_per_chunk or -1, -1, filter_desc=filter_desc) + return reader + + def apply(self, waveform: Tensor, sample_rate: int, output_sample_rate: Optional[int] = None) -> Tensor: + """Apply the effect and/or codecs to the whole tensor. + + Args: + waveform (Tensor): The input waveform. Shape: ``(time, channel)`` + sample_rate (int): Sample rate of the input waveform. + output_sample_rate (int or None, optional): Output sample rate. + If provided, override the output sample rate. + Otherwise, the resulting tensor is resampled to have + the same sample rate as the input. + Default: ``None``. + + Returns: + Tensor: + Resulting Tensor. Shape: ``(time, channel)``. 
The number of frames + could be different from that of the input. + """ + if waveform.ndim != 2: + raise ValueError(f"Expected the input waveform to be 2D. Found: {waveform.ndim}") + + if waveform.numel() == 0: + return waveform + + reader = self._get_reader(waveform, sample_rate, output_sample_rate) + reader.process_all_packets() + (applied,) = reader.pop_chunks() + return Tensor(applied) + + def stream( + self, waveform: Tensor, sample_rate: int, frames_per_chunk: int, output_sample_rate: Optional[int] = None + ) -> Iterator[Tensor]: + """Apply the effect and/or codecs to the given tensor chunk by chunk. + + Args: + waveform (Tensor): The input waveform. Shape: ``(time, channel)`` + sample_rate (int): Sample rate of the waveform. + frames_per_chunk (int): The number of frames to return at a time. + output_sample_rate (int or None, optional): Output sample rate. + If provided, override the output sample rate. + Otherwise, the resulting tensor is resampled to have + the same sample rate as the input. + Default: ``None``. + + Returns: + Iterator[Tensor]: + Series of processed chunks. Shape: ``(time, channel)``, where the + the number of frames matches ``frames_per_chunk`` except the + last chunk, which could be shorter. + """ + if waveform.ndim != 2: + raise ValueError(f"Expected the input waveform to be 2D. 
Found: {waveform.ndim}") + + if waveform.numel() == 0: + return waveform + + reader = self._get_reader(waveform, sample_rate, output_sample_rate, frames_per_chunk) + for (applied,) in reader.stream(): + yield Tensor(applied) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/io/_playback.py b/.venv/lib/python3.11/site-packages/torchaudio/io/_playback.py new file mode 100644 index 0000000000000000000000000000000000000000..7183ee3ba8cad8e842d066f6aaa9067687b9476b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/io/_playback.py @@ -0,0 +1,72 @@ +import warnings +from sys import platform +from typing import Optional + +import torch +import torchaudio + +dict_format = { + torch.uint8: "u8", + torch.int16: "s16", + torch.int32: "s32", + torch.int64: "s64", + torch.float32: "flt", + torch.float64: "dbl", +} + + +def play_audio( + waveform: torch.Tensor, + sample_rate: Optional[float], + device: Optional[str] = None, +) -> None: + """Plays audio through specified or available output device. + + .. warning:: + This function is currently only supported on MacOS, and requires + libavdevice (FFmpeg) with ``audiotoolbox`` output device. + + .. note:: + This function can play up to two audio channels. + + Args: + waveform: Tensor containing the audio to play. + Expected shape: `(time, num_channels)`. + sample_rate: Sample rate of the audio to play. + device: Output device to use. If None, the default device is used. + """ + + if platform == "darwin": + device = device or "audiotoolbox" + path = "-" + else: + raise ValueError(f"This function only supports MacOS, but current OS is {platform}") + + available_devices = list(torchaudio.utils.ffmpeg_utils.get_output_devices().keys()) + if device not in available_devices: + raise ValueError(f"Device {device} is not available. Available devices are: {available_devices}") + + if waveform.dtype not in dict_format: + raise ValueError(f"Unsupported type {waveform.dtype}. 
The list of supported types is: {dict_format.keys()}") + format = dict_format[waveform.dtype] + + if waveform.ndim != 2: + raise ValueError(f"Expected 2D tensor with shape `(time, num_channels)`, got {waveform.ndim}D tensor instead") + + time, num_channels = waveform.size() + if num_channels > 2: + warnings.warn( + f"Expected up to 2 channels, got {num_channels} channels instead. " + "Only the first 2 channels will be played.", + stacklevel=2, + ) + + # Write to speaker device + s = torchaudio.io.StreamWriter(dst=path, format=device) + s.add_audio_stream(sample_rate, num_channels, format=format) + + # write audio to the device + block_size = 256 + with s.open(): + for i in range(0, time, block_size): + s.write_audio_chunk(0, waveform[i : i + block_size, :]) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/lib/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torchaudio/lib/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/lib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8a4ee91fa22facf609dffeb60501989f836da04 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/lib/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..810ab85bfc64ff0444483161deac455703e773e0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/_hdemucs.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/_hdemucs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a20e340904e6fd3885a7746d52fed98965abbd75 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/_hdemucs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/conformer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/conformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce68bddc2bbb70a90cd09d7b854c6eb7b68f193f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/conformer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/conv_tasnet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/conv_tasnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d2fb0a37828d0554669badc7cc186fa531fab95 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/conv_tasnet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/deepspeech.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/deepspeech.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28ea5b1e562ec9e0af1aabfa18e125eaec0d284d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/deepspeech.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/emformer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/emformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13e8cf28bfc61fbd0f0d364e7ddc11c2707255c6 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/emformer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/rnnt.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/rnnt.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0df7e2926011c239f6d200810c2fd42e0ec62582 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/rnnt.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/rnnt_decoder.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/rnnt_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5fa5073b94125f226b43dbfea7ffbcb465d09ca Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/rnnt_decoder.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/tacotron2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/tacotron2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63a070f42bf6d3d984e5ff867e93195f06a25ca3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/tacotron2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/wav2letter.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/wav2letter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e72322ea2620b884864cedc01e84427c7124c956 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/wav2letter.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/wavernn.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/wavernn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32a9180a03ba1142dd8fcb75e294ae88ceda736a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/wavernn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e2dcbb35aad9379924deb2467bbf68ac7eaac8c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__pycache__/_ctc_decoder.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__pycache__/_ctc_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f677881dcfcda959ef253ec37295ef29c941306a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__pycache__/_ctc_decoder.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__pycache__/_cuda_ctc_decoder.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__pycache__/_cuda_ctc_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e2a0deb048494aead4dc319ee944a8c493ed145 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/decoder/__pycache__/_cuda_ctc_decoder.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/squim/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/models/squim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..092d6eb8e36e2329c78d21bf609a8458818995e6 
--- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/squim/__init__.py @@ -0,0 +1,11 @@ +from .objective import squim_objective_base, squim_objective_model, SquimObjective +from .subjective import squim_subjective_base, squim_subjective_model, SquimSubjective + +__all__ = [ + "squim_objective_base", + "squim_objective_model", + "squim_subjective_base", + "squim_subjective_model", + "SquimObjective", + "SquimSubjective", +] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/squim/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/squim/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a049af0d390f0db219ac34e7b97d3a76c684d0a3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/squim/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/squim/__pycache__/objective.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/squim/__pycache__/objective.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..227a6dd7e182df2a7eaad987f8eee54c44fd284c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/squim/__pycache__/objective.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/squim/__pycache__/subjective.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/squim/__pycache__/subjective.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cd5541ebee4865c8cd2a49302d09fb257c82305 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/squim/__pycache__/subjective.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/squim/objective.py b/.venv/lib/python3.11/site-packages/torchaudio/models/squim/objective.py new file mode 100644 index 
0000000000000000000000000000000000000000..d0a0671a4ec3828fa5dbe1613c7378fbc55d34dd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/squim/objective.py @@ -0,0 +1,326 @@ +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def transform_wb_pesq_range(x: float) -> float: + """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined + for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric + defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score". + + Args: + x (float): Narrow-band PESQ score. + + Returns: + (float): Wide-band PESQ score. + """ + return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224)) + + +PESQRange: Tuple[float, float] = ( + 1.0, # P.862.2 uses a different input filter than P.862, and the lower bound of + # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound. + # We are using 1.0 as a reasonable approximation. + transform_wb_pesq_range(4.5), +) + + +class RangeSigmoid(nn.Module): + def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None: + super(RangeSigmoid, self).__init__() + assert isinstance(val_range, tuple) and len(val_range) == 2 + self.val_range: Tuple[float, float] = val_range + self.sigmoid: nn.modules.Module = nn.Sigmoid() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0] + return out + + +class Encoder(nn.Module): + """Encoder module that transform 1D waveform to 2D representations. + + Args: + feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512) + win_len (int, optional): kernel size in the Conv1D layer. 
(Default: 32) + """ + + def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None: + super(Encoder, self).__init__() + + self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply waveforms to convolutional layer and ReLU layer. + + Args: + x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`. + + Returns: + (torch,Tensor): Feature Tensor with dimensions `(batch, channel, frame)`. + """ + out = x.unsqueeze(dim=1) + out = F.relu(self.conv1d(out)) + return out + + +class SingleRNN(nn.Module): + def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None: + super(SingleRNN, self).__init__() + + self.rnn_type = rnn_type + self.input_size = input_size + self.hidden_size = hidden_size + + self.rnn: nn.modules.Module = getattr(nn, rnn_type)( + input_size, + hidden_size, + 1, + dropout=dropout, + batch_first=True, + bidirectional=True, + ) + + self.proj = nn.Linear(hidden_size * 2, input_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # input shape: batch, seq, dim + out, _ = self.rnn(x) + out = self.proj(out) + return out + + +class DPRNN(nn.Module): + """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`. + + Args: + feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64) + hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128) + num_blocks (int, optional): Number of DPRNN layers. (Default: 6) + rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM") + d_model (int, optional): The number of expected features in the input. (Default: 256) + chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100) + chunk_stride (int, optional): Stride of chunk input for DPRNN. 
# --- torchaudio/models/squim/objective.py (continued) ---

class DPRNN(nn.Module):
    # NOTE(review): the `class DPRNN` statement and the head of its docstring
    # begin above this chunk; the header is reconstructed here from the
    # `super(DPRNN, self).__init__()` call below — confirm against the full file.
    # Dual-Path RNN: alternates intra-chunk ("row") and inter-chunk ("col")
    # RNN passes over a chunked time axis.

    def __init__(
        self,
        feat_dim: int = 64,
        hidden_dim: int = 128,
        num_blocks: int = 6,
        rnn_type: str = "LSTM",
        d_model: int = 256,
        chunk_size: int = 100,
        chunk_stride: int = 50,
    ) -> None:
        super(DPRNN, self).__init__()

        self.num_blocks = num_blocks

        # One row (intra-chunk) and one column (inter-chunk) RNN + GroupNorm
        # pair per block; all operate on `feat_dim` channels.
        self.row_rnn = nn.ModuleList([])
        self.col_rnn = nn.ModuleList([])
        self.row_norm = nn.ModuleList([])
        self.col_norm = nn.ModuleList([])
        for _ in range(num_blocks):
            self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
            self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
        # 1x1 conv projects feat_dim -> d_model after the DPRNN blocks.
        self.conv = nn.Sequential(
            nn.Conv2d(feat_dim, d_model, 1),
            nn.PReLU(),
        )
        self.chunk_size = chunk_size
        self.chunk_stride = chunk_stride

    def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Pad the time axis so it can be split into overlapping chunks.

        Args:
            x (torch.Tensor): Input of shape `(batch, feat_dim, time)`.

        Returns:
            (torch.Tensor, int): Padded tensor and the amount of trailing
            padding ``rest`` (removed again in :meth:`merging`).
        """
        # input shape: (B, N, T)
        seq_len = x.shape[-1]

        # Trailing pad needed so that (stride + T) is a multiple of chunk_size.
        rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
        out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])

        return out, rest

    def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Split the padded sequence into 50%-overlapping chunks.

        Returns a tensor of shape `(batch, feat_dim, chunk_size, num_chunks)`
        together with the padding amount ``rest``.
        """
        out, rest = self.pad_chunk(x)
        batch_size, feat_dim, seq_len = out.shape

        # Two interleaved chunkings offset by chunk_stride give the overlap.
        segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
        segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
        out = torch.cat([segments1, segments2], dim=3)
        out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()

        return out, rest

    def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
        """Inverse of :meth:`chunking`: overlap-add chunks back to a sequence.

        Args:
            x (torch.Tensor): Chunked tensor `(batch, dim, chunk_size, num_chunks)`.
            rest (int): Trailing padding added by :meth:`pad_chunk` to strip.

        Returns:
            (torch.Tensor): Sequence tensor of shape `(batch, dim, time)`.
        """
        batch_size, dim, _, _ = x.shape
        out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
        # Undo the two offset chunkings and sum the overlapping halves.
        out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
        out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
        out = out1 + out2
        if rest > 0:
            out = out[:, :, :-rest]
        out = out.contiguous()
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the dual-path blocks over the chunked input.

        Args:
            x (torch.Tensor): Input of shape `(batch, feat_dim, time)`.

        Returns:
            (torch.Tensor): Output of shape `(batch, time, d_model)`.
        """
        x, rest = self.chunking(x)
        batch_size, _, dim1, dim2 = x.shape
        out = x
        for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
            # Intra-chunk pass: RNN along dim1 for every chunk, residual add.
            row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous()
            row_out = row_rnn(row_in)
            row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
            row_out = row_norm(row_out)
            out = out + row_out

            # Inter-chunk pass: RNN along dim2 across chunks, residual add.
            col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous()
            col_out = col_rnn(col_in)
            col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
            col_out = col_norm(col_out)
            out = out + col_out
        out = self.conv(out)
        out = self.merging(out, rest)
        out = out.transpose(1, 2).contiguous()
        return out


class AutoPool(nn.Module):
    """Adaptive pooling with a learnable temperature ``alpha``.

    Computes a softmax-weighted sum over ``pool_dim``; ``alpha`` scales the
    logits, so the layer can learn to interpolate between mean pooling
    (small alpha) and max pooling (large alpha).

    Args:
        pool_dim (int): Dimension to pool over. (Default: 1)
    """

    def __init__(self, pool_dim: int = 1) -> None:
        super(AutoPool, self).__init__()
        self.pool_dim: int = pool_dim
        self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
        self.register_parameter("alpha", nn.Parameter(torch.ones(1)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.softmax(torch.mul(x, self.alpha))
        out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
        return out


class SquimObjective(nn.Module):
    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
    for speech enhancement (e.g., STOI, PESQ, and SI-SDR).

    Args:
        encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
        dprnn (torch.nn.Module): DPRNN module to model sequential feature.
        branches (torch.nn.ModuleList): Transformer branches in which each branch estimates one objective metric score.
    """

    def __init__(
        self,
        encoder: nn.Module,
        dprnn: nn.Module,
        branches: nn.ModuleList,
    ):
        super(SquimObjective, self).__init__()
        self.encoder = encoder
        self.dprnn = dprnn
        self.branches = branches

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Args:
            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.

        Returns:
            List(torch.Tensor): List of score Tensors. Each Tensor is with dimension `(batch,)`.
        """
        if x.ndim != 2:
            raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
        # Normalize each waveform by its RMS (scaled by 20) before encoding.
        x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
        out = self.encoder(x)
        out = self.dprnn(out)
        scores = []
        for branch in self.branches:
            scores.append(branch(out).squeeze(dim=1))
        return scores


def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
    """Create branch module after DPRNN model for predicting metric score.

    Args:
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        metric (str): The metric name to predict.

    Returns:
        (nn.Module): Returned module to predict corresponding metric score.
    """
    layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
    layer2 = AutoPool()
    if metric == "stoi":
        # STOI is bounded; squash the prediction with the default sigmoid range.
        layer3 = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.PReLU(),
            nn.Linear(d_model, 1),
            RangeSigmoid(),
        )
    elif metric == "pesq":
        # PESQ has its own fixed score range.
        layer3 = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.PReLU(),
            nn.Linear(d_model, 1),
            RangeSigmoid(val_range=PESQRange),
        )
    else:
        # SI-SDR (and any other metric) is unbounded: plain linear head.
        layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1))
    return nn.Sequential(layer1, layer2, layer3)


def squim_objective_model(
    feat_dim: int,
    win_len: int,
    d_model: int,
    nhead: int,
    hidden_dim: int,
    num_blocks: int,
    rnn_type: str,
    chunk_size: int,
    chunk_stride: Optional[int] = None,
) -> SquimObjective:
    """Build a custom :class:`torchaudio.prototype.models.SquimObjective` model.

    Args:
        feat_dim (int, optional): The feature dimension after Encoder module.
        win_len (int): Kernel size in the Encoder module.
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
        num_blocks (int): Number of DPRNN layers.
        rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
        chunk_size (int): Chunk size of input for DPRNN.
        chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
            Defaults to half of ``chunk_size`` (50% overlap) when ``None``.
    """
    if chunk_stride is None:
        chunk_stride = chunk_size // 2
    encoder = Encoder(feat_dim, win_len)
    dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride)
    branches = nn.ModuleList(
        [
            _create_branch(d_model, nhead, "stoi"),
            _create_branch(d_model, nhead, "pesq"),
            _create_branch(d_model, nhead, "sisdr"),
        ]
    )
    return SquimObjective(encoder, dprnn, branches)


def squim_objective_base() -> SquimObjective:
    """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
    return squim_objective_model(
        feat_dim=256,
        win_len=64,
        d_model=256,
        nhead=4,
        hidden_dim=256,
        num_blocks=2,
        rnn_type="LSTM",
        chunk_size=71,
    )


# --- torchaudio/models/squim/subjective.py (new file in this patch) ---

from typing import Tuple

import torch
import torch.nn as nn
import torchaudio


class AttPool(nn.Module):
    """Attention-Pooling module that estimates the attention score.

    Args:
        input_dim (int): Input feature dimension.
        att_dim (int): Attention Tensor dimension.
    """

    def __init__(self, input_dim: int, att_dim: int):
        super(AttPool, self).__init__()

        self.linear1 = nn.Linear(input_dim, 1)
        self.linear2 = nn.Linear(input_dim, att_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and pooling.

        Args:
            x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.

        Returns:
            (torch.Tensor): Attention score with dimensions `(batch, att_dim)`.
        """

        att = self.linear1(x)  # (batch, time, 1)
        att = att.transpose(2, 1)  # (batch, 1, time)
        att = nn.functional.softmax(att, dim=2)
        x = torch.matmul(att, x).squeeze(1)  # (batch, input_dim)
        x = self.linear2(x)  # (batch, att_dim)
        return x


class Predictor(nn.Module):
    """Prediction module that applies pooling and attention, then predicts subjective metric scores.

    Args:
        input_dim (int): Input feature dimension.
        att_dim (int): Attention Tensor dimension.
    """

    def __init__(self, input_dim: int, att_dim: int):
        super(Predictor, self).__init__()
        self.att_pool_layer = AttPool(input_dim, att_dim)
        self.att_dim = att_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict subjective evaluation metric score.

        Args:
            x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.

        Returns:
            (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.
        """
        x = self.att_pool_layer(x)
        x = nn.functional.softmax(x, dim=1)
        # Expected score over `att_dim` bins spaced evenly in [0, 4].
        B = torch.linspace(0, 4, steps=self.att_dim, device=x.device)
        x = (x * B).sum(dim=1)
        return x


class SquimSubjective(nn.Module):
    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **subjective** metric scores
    for speech enhancement (e.g., Mean Opinion Score (MOS)). The model is adopted from *NORESQA-MOS*
    :cite:`manocha2022speech` which predicts MOS scores given the input speech and a non-matching reference.

    Args:
        ssl_model (torch.nn.Module): The self-supervised learning model for feature extraction.
        projector (torch.nn.Module): Projection layer that projects SSL feature to a lower dimension.
        predictor (torch.nn.Module): Predict the subjective scores.
    """

    def __init__(self, ssl_model: nn.Module, projector: nn.Module, predictor: nn.Module):
        super(SquimSubjective, self).__init__()
        self.ssl_model = ssl_model
        self.projector = projector
        self.predictor = predictor

    def _align_shapes(self, waveform: torch.Tensor, reference: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Cut or pad the reference Tensor to make it aligned with waveform Tensor.

        Args:
            waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.
            reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.

        Returns:
            (torch.Tensor, torch.Tensor): The aligned waveform and reference Tensors
                with same dimensions `(batch, time)`.
        """
        T_waveform = waveform.shape[-1]
        T_reference = reference.shape[-1]
        if T_reference < T_waveform:
            # Tile the reference until it is at least as long as the waveform.
            num_padding = T_waveform // T_reference + 1
            reference = torch.cat([reference for _ in range(num_padding)], dim=1)
        return waveform, reference[:, :T_waveform]

    def forward(self, waveform: torch.Tensor, reference: torch.Tensor):
        """Predict subjective evaluation metric score.

        Args:
            waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.
            reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.

        Returns:
            (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.
        """
        waveform, reference = self._align_shapes(waveform, reference)
        # Use the last layer of SSL features for both signals, project, concat.
        waveform = self.projector(self.ssl_model.extract_features(waveform)[0][-1])
        reference = self.projector(self.ssl_model.extract_features(reference)[0][-1])
        concat = torch.cat((reference, waveform), dim=2)
        score_diff = self.predictor(concat)  # Score difference compared to the reference
        return 5 - score_diff


def squim_subjective_model(
    ssl_type: str,
    feat_dim: int,
    proj_dim: int,
    att_dim: int,
) -> SquimSubjective:
    """Build a custom :class:`torchaudio.prototype.models.SquimSubjective` model.

    Args:
        ssl_type (str): Type of self-supervised learning (SSL) models.
            Must be one of ["wav2vec2_base", "wav2vec2_large"].
        feat_dim (int): Feature dimension of the SSL feature representation.
        proj_dim (int): Output dimension of projection layer.
        att_dim (int): Dimension of attention scores.
    """
    ssl_model = getattr(torchaudio.models, ssl_type)()
    projector = nn.Linear(feat_dim, proj_dim)
    # Predictor sees reference and waveform features concatenated -> 2x proj_dim.
    predictor = Predictor(proj_dim * 2, att_dim)
    return SquimSubjective(ssl_model, projector, predictor)


def squim_subjective_base() -> SquimSubjective:
    """Build :class:`torchaudio.prototype.models.SquimSubjective` model with default arguments."""
    return squim_subjective_model(
        ssl_type="wav2vec2_base",
        feat_dim=768,
        proj_dim=32,
        att_dim=5,
    )
import utils +from .model import ( + hubert_base, + hubert_large, + hubert_pretrain_base, + hubert_pretrain_large, + hubert_pretrain_model, + hubert_pretrain_xlarge, + hubert_xlarge, + HuBERTPretrainModel, + wav2vec2_base, + wav2vec2_large, + wav2vec2_large_lv60k, + wav2vec2_model, + wav2vec2_xlsr_1b, + wav2vec2_xlsr_2b, + wav2vec2_xlsr_300m, + Wav2Vec2Model, + wavlm_base, + wavlm_large, + wavlm_model, +) + +__all__ = [ + "Wav2Vec2Model", + "HuBERTPretrainModel", + "wavlm_model", + "wavlm_base", + "wavlm_large", + "wav2vec2_model", + "wav2vec2_base", + "wav2vec2_large", + "wav2vec2_large_lv60k", + "hubert_base", + "hubert_large", + "hubert_xlarge", + "hubert_pretrain_model", + "hubert_pretrain_base", + "hubert_pretrain_large", + "hubert_pretrain_xlarge", + "utils", + "wav2vec2_xlsr_300m", + "wav2vec2_xlsr_1b", + "wav2vec2_xlsr_2b", +] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d469efd351b4cf1cddb4a3e44fe3ff7055e9262 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/__pycache__/components.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/__pycache__/components.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0fd969d6c1d2442ffb309e09f82d6d671878ac2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/__pycache__/components.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/__pycache__/model.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/__pycache__/model.cpython-311.pyc new file mode 100644 
# --- torchaudio/models/wav2vec2/components.py (new file in this patch) ---

import logging
from typing import List, Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn import Module, Parameter

from .wavlm_attention import WavLMSelfAttention

_LG = logging.getLogger(__name__)


def _init_transformer_params(module):
    """
    Initialize the weights of Transformer module in Wav2Vec2/HuBERT.

    If the module is ``nn.Linear``, normalize the weight with mean 0 and standard deviation 0.02.
    If ``bias`` is set to ``True`` in the module, set ``bias`` to 0.

    If the module is ``nn.Embedding``, normalize the weight with mean 0 and standard deviation 0.02.
    If ``padding_idx`` is not None, set the weight of padding to 0.

    Note:
        This method corresponds to ``init_bert_params`` in the original
        ``fairseq`` implementation.
    """

    def normal_(data):
        # Sample on CPU then copy back, so initialization is device-independent.
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()


class LayerNorm(nn.LayerNorm):
    """Layer norm with transpose: normalizes over the channel dimension of a
    ``[batch, channel, frame]`` tensor by transposing before/after."""

    def forward(self, input: Tensor) -> Tensor:
        x = input.transpose(-2, -1)
        x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.transpose(-2, -1)
        return x


class ConvLayerBlock(Module):
    """Convolution unit of FeatureExtractor: Conv1d -> optional norm -> GELU."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        bias: bool,
        layer_norm: Optional[Module],
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.layer_norm = layer_norm
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
        )

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
            length (Tensor or None, optional): Shape ``[batch, ]``.
        Returns:
            Tensor: Shape ``[batch, out_channels, out_frames]``.
            Optional[Tensor]: Shape ``[batch, ]``.
        """
        x = self.conv(x)
        if self.layer_norm is not None:
            x = self.layer_norm(x)
        x = nn.functional.gelu(x)

        if length is not None:
            # Standard conv output-length formula (no padding/dilation here).
            length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1
            # When input length is 0, the resulting length can be negative. So fix it here.
            length = torch.max(torch.zeros_like(length), length)
        return x, length


class FeatureExtractor(Module):
    """Extract features from audio

    Args:
        conv_layers (nn.ModuleList):
            convolution layers
    """

    def __init__(
        self,
        conv_layers: nn.ModuleList,
    ):
        super().__init__()
        self.conv_layers = conv_layers

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor):
                Input Tensor representing a batch of audio,
                shape: ``[batch, time]``.
            length (Tensor or None, optional):
                Valid length of each input sample. shape: ``[batch, ]``.

        Returns:
            Tensor:
                The resulting feature, shape: ``[batch, frame, feature]``
            Optional[Tensor]:
                Valid length of each output sample. shape: ``[batch, ]``.
        """
        if x.ndim != 2:
            raise ValueError(f"Expected the input Tensor to be 2D (batch, time). Found: {list(x.shape)}")

        x = x.unsqueeze(1)  # (batch, channel==1, frame)
        for layer in self.conv_layers:
            x, length = layer(x, length)  # (batch, feature, frame)
        x = x.transpose(1, 2)  # (batch, frame, feature)
        return x, length


class FeatureProjection(Module):
    """Layer that connects FeatureExtractor and Encoder

    Projects features to encoder dimension.

    Args:
        in_features (int): Input feature dim.
        out_features (int): Output feature dim.
        dropout (float): Dropout probability.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dropout: float,
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(in_features)
        self.projection = nn.Linear(
            in_features,
            out_features,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor):
                Feature Tensor. shape: ``[batch, frame, in_feature]``
        Returns:
            Tensor: Projected features. ``[batch, frame, out_feature]``.
        """
        x = self.layer_norm(x)
        x = self.projection(x)
        x = self.dropout(x)
        return x


class ConvolutionalPositionalEmbedding(Module):
    """Positional embedding which is placed at the beginning of Transformer.

    Args:
        embed_dim (int): Feature dimension of the input Tensor.
        kernel_size (int): The number of frames to be use.
        groups (int): The number of groups in feature dimensions.
    """

    def __init__(
        self,
        embed_dim: int,
        kernel_size: int,
        groups: int,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=groups,
        )

        self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
        # Even kernels with symmetric padding yield one extra frame; trim it.
        self.num_remove: int = 1 if kernel_size % 2 == 0 else 0

    def __prepare_scriptable__(self):
        # TorchScript cannot handle the weight_norm parametrization hook;
        # bake the normalized weight in before scripting.
        if self.conv.__class__.__name__ == "ParametrizedConv1d":
            _LG.warning("Removing weight_norm from %s", self.__class__.__name__)
            torch.nn.utils.parametrize.remove_parametrizations(self.conv, "weight")
        return self

    def forward(self, x):
        """
        Args:
            x (Tensor): shape ``[batch, frame, feature]``.

        Returns:
            Tensor: The resulting feature. Shape ``[batch, frame, feature]``.
        """
        x = x.transpose(-2, -1)
        x = self.conv(x)
        if self.num_remove > 0:
            x = x[..., : -self.num_remove]
        x = torch.nn.functional.gelu(x)
        x = x.transpose(-2, -1)
        return x


class SelfAttention(Module):
    """Multihead Self Attention module

    Args:
        embed_dim (int): Total dimension of the model.
        num_heads (int): The number of heads.
        dropout (float, optional):
            Dropout probability on attn_output_weights. Default: ``0.0``
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        head_dim = embed_dim // num_heads
        if head_dim * num_heads != embed_dim:
            raise ValueError(f"`embed_dim ({embed_dim})` is not divisible by `num_heads ({num_heads})`")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = head_dim

        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``.
            attention_mask (Tensor or ``None``, optional):
                shape: ``[batch_size, 1, sequence_length, sequence_length]``
            position_bias: Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`.
            key_padding_mask (Tensor or ``None``): Not used. Only for the compatibility with
                :py:class:`WavLMSelfAttention`.
        Returns:
            (Tensor, ``None``): The resulting attention output and ``None`` (necessary for compatibility
                with :py:class:`WavLMSelfAttention`).
                Attention output shape: ``[batch, sequence_length, embed_dim]``.
        """
        if x.ndim != 3 or x.shape[2] != self.embed_dim:
            raise ValueError(
                f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}."
            )
        batch_size, length, embed_dim = x.size()
        if attention_mask is not None:
            shape_ = (batch_size, 1, length, length)
            if attention_mask.size() != shape_:
                raise ValueError(f"The expected attention mask shape is {shape_}. " f"Found {attention_mask.size()}.")

        shape = (batch_size, length, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
        k = self.k_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
        v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
        # Dropout only applies in training mode.
        dropout = self.dropout if self.training else 0.0
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
        )
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        output = self.out_proj(attn_output)
        return output, None  # Necessary for compatibility with WavLMSelfAttention


class FeedForward(Module):
    """Layer that follows attention layer in encoder layer."""

    def __init__(
        self,
        io_features: int,
        intermediate_features: int,
        intermediate_dropout: float,
        output_dropout: float,
    ):
        super().__init__()
        self.intermediate_dense = nn.Linear(io_features, intermediate_features)
        self.intermediate_dropout = nn.Dropout(intermediate_dropout)
        self.output_dense = nn.Linear(intermediate_features, io_features)
        self.output_dropout = nn.Dropout(output_dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        Returns:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        """
        x = self.intermediate_dense(x)
        x = torch.nn.functional.gelu(x)
        x = self.intermediate_dropout(x)

        x = self.output_dense(x)
        x = self.output_dropout(x)
        return x


class EncoderLayer(Module):
    """A layer unit in encoder. Combines multihead self attention and feed forward."""

    def __init__(
        self,
        attention: Module,
        dropout: float,
        layer_norm_first: bool,
        feed_forward: Module,
    ):
        super().__init__()
        self.attention = attention
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(attention.embed_dim)
        self.layer_norm_first = layer_norm_first
        self.feed_forward = feed_forward
        self.final_layer_norm = nn.LayerNorm(attention.embed_dim)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``.
            attention_mask (Tensor or ``None``, optional): attention mask
                of shape ``(batch, 1, sequence_length, sequence_length)``. (Default: ``None``)
            position_bias (Tensor or ``None``, optional): position bias of shape
                ``(batch_size * num_heads, src_len, src_len)``.
                Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``)
            key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``.
                Only used for WavLM model, ignored otherwise. (Default: ``None``)
        Returns:
            (x, position_bias): Shapes are the same as in the input. Position bias is only relevant for WavLM model,
            ``None`` otherwise.
        """
        residual = x

        # Pre-norm vs post-norm layout is controlled by layer_norm_first.
        if self.layer_norm_first:
            x = self.layer_norm(x)

        x, position_bias = self.attention(
            x, attention_mask=attention_mask, position_bias=position_bias, key_padding_mask=key_padding_mask
        )

        x = self.dropout(x)
        x = residual + x

        if self.layer_norm_first:
            x = x + self.feed_forward(self.final_layer_norm(x))
        else:
            x = self.layer_norm(x)
            x = self.final_layer_norm(x + self.feed_forward(x))
        return x, position_bias


class Transformer(Module):
    # Stack of EncoderLayers with convolutional positional embedding and
    # (training-time) LayerDrop.
    def __init__(
        self,
        pos_conv_embed: Module,
        dropout: float,
        layers: Module,
        layer_norm_first: bool,
        layer_drop: float,
    ):
        super().__init__()
        self.pos_conv_embed = pos_conv_embed
        self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim)
        self.layer_norm_first = layer_norm_first
        self.layer_drop = layer_drop
        self.dropout = nn.Dropout(dropout)
        self.layers = layers

    def _preprocess(self, x: Tensor):
        # Add convolutional positional embedding, then (optionally) pre-norm.
        x = x + self.pos_conv_embed(x)

        if self.layer_norm_first:
            x = self.layer_norm(x)

        x = self.dropout(x)
        return x

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
    ) -> Tensor:
        x = self._preprocess(x)
        for layer in self.layers:
            # LayerDrop: randomly skip layers during training only.
            if not (self.training and torch.rand(1).item() <= self.layer_drop):
                x, position_bias = layer(x, attention_mask, position_bias=position_bias)

        if not self.layer_norm_first:
            x = self.layer_norm(x)
        return x

    def get_intermediate_outputs(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        # Returns the output of each layer (up to num_layers if given).
        # Note: no LayerDrop here — all layers always run.
        if num_layers is not None:
            if not 0 < num_layers <= len(self.layers):
                raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")

        ret: List[Tensor] = []
        position_bias = None
        x = self._preprocess(x)
        for layer in self.layers:
            x, position_bias = layer(x, attention_mask, position_bias=position_bias)
            ret.append(x)
            if num_layers is not None and len(ret) >= num_layers:
                return ret
        return ret


class Encoder(Module):
    # FeatureProjection followed by the Transformer stack; builds the
    # padding-based attention mask from per-sample lengths.
    def __init__(
        self,
        feature_projection: Module,
        transformer: Module,
    ):
        super().__init__()
        self.feature_projection = feature_projection
        self.transformer = transformer

    def _preprocess(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        x = self.feature_projection(features)

        mask: Optional[Tensor] = None
        if lengths is not None:
            batch_size, max_len, _ = x.shape
            # create mask for padded elements and zero-out them
            mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
            x[mask] = 0.0
            # extend the mask to attention shape and set weight
            mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype)
            mask = mask.expand(batch_size, 1, max_len, max_len)
        return x, mask

    def forward(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tensor:
        x, mask = self._preprocess(features, lengths)
        x = self.transformer(x, attention_mask=mask)
        return x

    def extract_features(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        x, masks = self._preprocess(features, lengths)
        return self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers)


################################################################################
def _get_feature_extractor(
    norm_mode: str,
    shapes: List[Tuple[int, int, int]],
    bias: bool,
) -> FeatureExtractor:
    """
    Args:
        norm_mode (str):
            Either "group_norm" or "layer_norm".
            If "group_norm", then a single normalization is applied
            in the first convolution block. Otherwise, all the convolution
            blocks will have layer normalization.
            This option corresponds to "extractor_mode" from fairseq.
            Expected values are "group_norm" for Base arch, and
            "layer_norm" for Large arch.
        shapes (list of tuple of int):
            Configuration of convolution layers. List of convolution configuration,
            i.e. ``[(output_channel, kernel_size, stride), ...]``
            This option corresponds to "conv_feature_layers" from fairseq.
            Expected values are
            ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2``
            for all the architectures.
        bias (bool):
            Whether to include bias term to each convolution operation.
            This option corresponds to "conv_bias" from fairseq.
            Expected values are False for Base arch, and True for Large arch.

    See Also:
        Original fairseq implementation and parameter definitions in
        ``fairseq/models/wav2vec/wav2vec2.py`` (commit 425c36e) and the
        ``wav2vec2_large_librivox.yaml`` pretraining config.
    """
    if norm_mode not in ["group_norm", "layer_norm"]:
        raise ValueError("Invalid norm mode")
    blocks = []
    in_channels = 1
    for i, (out_channels, kernel_size, stride) in enumerate(shapes):
        normalization = None
        if norm_mode == "group_norm" and i == 0:
            # num_groups == num_channels: per-channel (instance-style) norm.
            normalization = nn.GroupNorm(
                num_groups=out_channels,
                num_channels=out_channels,
                affine=True,
            )
        elif norm_mode == "layer_norm":
            normalization = LayerNorm(
                normalized_shape=out_channels,
                elementwise_affine=True,
            )
        blocks.append(
            ConvLayerBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                bias=bias,
                layer_norm=normalization,
            )
        )
        in_channels = out_channels
    return FeatureExtractor(nn.ModuleList(blocks))


def _get_encoder(
    in_features: int,
    embed_dim: int,
    dropout_input: float,
    pos_conv_kernel: int,
    pos_conv_groups: int,
    num_layers: int,
    num_heads: int,
    attention_dropout: float,
    ff_interm_features: int,
    ff_interm_dropout: float,
    dropout: float,
    layer_norm_first: bool,
    layer_drop: float,
) -> Encoder:
    """Build the wav2vec2/HuBERT :class:`Encoder` from fairseq-style hyperparameters.

    NOTE(review): the body of this function continues beyond this chunk of the
    patch; only the signature and (abridged) docstring are visible here.

    Args:
        in_features (int): The number of input features.
        embed_dim (int): Embedding dimension ("encoder_embed_dim" in fairseq;
            768 for Base, 1024 for Large).
        dropout_input (float): Dropout applied after projecting the input
            feature to ``embed_dim`` ("dropout_input"; 0.1 for both archs).
        pos_conv_kernel (int): Kernel size of the convolutional positional
            embedding ("conv_pos"; 128 for both archs).
        pos_conv_groups (int): Groups of the convolutional positional
            embedding ("conv_pos_groups"; 16 for both archs).
        num_layers (int): Number of self-attention layers ("encoder_layers";
            12 for Base, 24 for Large).
        num_heads (int): Number of attention heads ("encoder_attention_heads";
            12 for Base, 16 for Large).
        attention_dropout (float): Dropout after softmax in self-attention
            ("attention_dropout"; 0.1 for Base, 0.0 for Large).
        ff_interm_features (int): Hidden dim of the feed-forward layer
            ("encoder_ffn_embed_dim"; 3072 for Base, 4096 for Large).
        ff_interm_dropout (float): Dropout in the feed-forward layer
            ("activation_dropout"; 0.1 for both archs).
        dropout (float): Dropout at the end of the feed-forward layer
            ("dropout"; 0.1 for Base, 0.0 for Large).
        layer_norm_first (bool): Pre-norm (True) vs post-norm (False) layout
            ("layer_norm_first"; False for Base, True for Large).
        layer_drop (float): Probability of dropping each encoder layer during
            training ("layerdrop"; 0.1 for both archs).

    See Also:
        Parameter definitions in ``fairseq/models/wav2vec/wav2vec2.py``
        (commit 425c36e) and the fairseq pretraining/finetuning configs.
    """
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54 + """ + feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) + pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) + + # Original impl + # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 + encoder_layers = nn.ModuleList() + for _ in range(num_layers): + attention = SelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + dropout=attention_dropout, + ) + feed_forward = FeedForward( + io_features=embed_dim, + intermediate_features=ff_interm_features, + intermediate_dropout=ff_interm_dropout, + output_dropout=dropout, + ) + encoder_layers.append( + EncoderLayer( + attention=attention, + dropout=dropout, + layer_norm_first=layer_norm_first, + feed_forward=feed_forward, + ) + ) + transformer = Transformer( + pos_conv_embed=pos_conv, + dropout=dropout, + layers=encoder_layers, + layer_norm_first=not layer_norm_first, + layer_drop=layer_drop, + ) + return Encoder(feature_projection, transformer) + + +def _get_wavlm_encoder( + in_features: int, + embed_dim: int, + dropout_input: float, + pos_conv_kernel: int, + pos_conv_groups: int, + num_layers: int, + num_heads: int, + num_buckets: int, + max_distance: int, + attention_dropout: float, + ff_interm_features: int, + ff_interm_dropout: float, + dropout: float, + layer_norm_first: bool, + layer_drop: float, +) -> Encoder: + """ + Construct encoder for WavLM model :cite:`chen2022wavlm`. The structure of the encoder and most of the argments are + the same as in :py:func:`_get_encoder` so refer there for documentation. The only difference from Wav2Vec2 encoder + is usage of `WavLMSelfAttention` instead of `SelfAttention` and two additional parameters: `num_buckets` and + `max_distance`. + Args: + in_features (int): See :py:func:`_get_encoder`. 
+ embed_dim (int): See :py:func:`_get_encoder`. + dropout_input (float): See :py:func:`_get_encoder`. + pos_conv_kernel (int): See :py:func:`_get_encoder`. + pos_conv_groups (int): See :py:func:`_get_encoder`. + num_layers (int): See :py:func:`_get_encoder`. + num_heads (int): See :py:func:`_get_encoder`. + num_buckets (int): Number of buckets for relative position embedding. + max_distance (int): Maximum distance for relative position embedding. + attention_dropout (float): See :py:func:`_get_encoder`. + ff_interm_features (int): See :py:func:`_get_encoder`. + ff_interm_dropout (float): See :py:func:`_get_encoder`. + dropout (float): See :py:func:`_get_encoder`. + layer_norm_first (bool): See :py:func:`_get_encoder`. + layer_drop (float): See :py:func:`_get_encoder`. + + """ + feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) + pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) + + # Original impl + # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 + encoder_layers = nn.ModuleList() + for i in range(num_layers): + attention = WavLMSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_buckets=num_buckets, + max_distance=max_distance, + dropout=attention_dropout, + has_relative_attention_bias=(i == 0), # Position embedding is only necessary in the first layer. 
+ ) + feed_forward = FeedForward( + io_features=embed_dim, + intermediate_features=ff_interm_features, + intermediate_dropout=ff_interm_dropout, + output_dropout=dropout, + ) + encoder_layers.append( + EncoderLayer( + attention=attention, + dropout=dropout, + layer_norm_first=layer_norm_first, + feed_forward=feed_forward, + ) + ) + transformer = Transformer( + pos_conv_embed=pos_conv, + dropout=dropout, + layers=encoder_layers, + layer_norm_first=not layer_norm_first, + layer_drop=layer_drop, + ) + return Encoder(feature_projection, transformer) + + +def _compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> Tensor: + """Computes random mask spans for a given shape. + Args: + shape (int, int): The shape for which to compute masks. + The first element is batch size and second is the number of frames. + padding_mask (Tensor or None): The padding mask of the same dimension as shape, + which will prevent masking padded elements. + mask_prob (float): Probability for each token to be chosen as start of the span to be masked. + This will be multiplied by number of timesteps divided by length of mask span to mask + approximately this percentage of all elements. However due to overlaps, the actual number + will be smaller (unless no_overlap is True). + mask_type (str): How to compute mask lengths. Options: [``static``, ``uniform``, ``normal``, ``poisson``]. + ``static``: Fixed size + ``uniform``: Sample from uniform distribution [mask_other, mask_length*2] + ``normal``: Sample from normal distribution with mean ``mask_length`` and stdev ``mask_other``. + ``poisson``: Sample from possion distribution with lambda = ``mask_length``. + min_masks (int): Minimum number of masked spans. 
+ no_overlap (bool): If true, will switch to an alternative recursive algorithm + that prevents spans from overlapping. + min_space (int): How many frames to keep unmasked between spans (Only used if no_overlap is True). + + Returns: + (Tensor): The mask indices of dimension `[batch, frame]`. + """ + + batch_size, frame = shape + mask = torch.full((batch_size, frame), False) + # add a random number for probabilistic rounding + all_num_mask = int(mask_prob * frame / float(mask_length) + torch.rand(1)) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(batch_size): + if padding_mask is not None: + sz = frame - padding_mask[i].long().sum().item() + # add a random number for probabilistic rounding + num_mask = int(mask_prob * sz / float(mask_length) + torch.rand(1)) + num_mask = max(min_masks, num_mask) + else: + sz = frame + num_mask = all_num_mask + + if mask_type == "static": + lengths = torch.full((num_mask,), mask_length) + elif mask_type == "uniform": + lengths = torch.randint(int(mask_other), mask_length * 2 + 1, size=(num_mask,)) + elif mask_type == "normal": + lengths = torch.normal(mask_length, mask_other, size=(num_mask,)) + lengths = torch.maximum(torch.ones(1), torch.round(lengths)).int() + elif mask_type == "poisson": + lengths = torch.poisson(mask_length, size=(num_mask,)) + lengths = torch.round(lengths).int() + else: + raise Exception(f"unknown mask selection: {mask_type}") + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = torch.randint(s, e - length, size=(1,)) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = 
min(lengths) + for length in sorted(lengths, reverse=True): + lens = torch.tensor([e - s for s, e in parts], dtype=torch.int) + lens[lens < length + min_space] = 0 + l_sum = lens.sum() + if l_sum == 0: + break + probs = lens / l_sum + c = torch.distributions.categorical.Categorical(probs).sample() + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = torch.tensor(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = torch.randperm(sz - min_len)[:num_mask] + mask_idc = torch.tensor( + [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])] + ) + + mask_idcs.append(torch.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = mask_idc[torch.randperm(len(mask_idc))[:min_len].long()] + mask[i, mask_idc] = True + + return mask + + +def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor: + """Generate the padding mask given the padded input and the lengths Tensors. + Args: + input (Tensor): The padded Tensor of dimension `[batch, max_len, frequency]`. + lengths (Tensor): The lengths Tensor of dimension `[batch,]`. + + Returns: + (Tensor): The padding mask. + """ + batch_size, max_len, _ = input.shape + mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] + return mask + + +class MaskGenerator(Module): + """Generate the masks for masked prediction. + Args: + encoder_embed_dim (int): The dimension of the transformer embedding output. + mask_prob (float): Probability for each token to be chosen as start of the span to be masked. + This will be multiplied by number of timesteps divided by length of mask span to mask + approximately this percentage of all elements. However due to overlaps, the actual number + will be smaller (unless no_overlap is True). 
+ mask_selection (str): How to choose the mask length. + Options: [``static``, ``uniform``, ``normal``, ``poisson``]. + mask_other (float): Secondary mask argument (used for more complex distributions). + mask_length (int): The lengths of the mask. + no_mask_overlap (bool): Whether to allow masks to overlap. + mask_min_space (int): Minimum space between spans (if no overlap is enabled). + mask_channel_prob (float): The probability of replacing a feature with 0. + mask_channel_selection (str): How to choose the mask length for channel masking. + Options: [``static``, ``uniform``, ``normal``, ``poisson``]. + mask_channel_other (float): Secondary mask argument for channel masking (used for more complex distributions). + mask_channel_length (int): The lengths of the mask for channel masking. + no_mask_channel_overlap (bool): Whether to allow channel masks to overlap. + mask_channel_min_space (int): Minimum space between spans for channel masking (if no overlap is enabled). 
+ """ + + def __init__( + self, + encoder_embed_dim: int, + mask_prob: float, + mask_selection: str, + mask_other: float, + mask_length: int, + no_mask_overlap: bool, + mask_min_space: int, + mask_channel_prob: float, + mask_channel_selection: str, + mask_channel_other: float, + mask_channel_length: int, + no_mask_channel_overlap: bool, + mask_channel_min_space: int, + ): + super().__init__() + self.mask_prob = mask_prob + self.mask_selection = mask_selection + self.mask_other = mask_other + self.mask_length = mask_length + self.no_mask_overlap = no_mask_overlap + self.mask_min_space = mask_min_space + self.mask_channel_prob = mask_channel_prob + self.mask_channel_selection = mask_channel_selection + self.mask_channel_other = mask_channel_other + self.mask_channel_length = mask_channel_length + self.no_mask_channel_overlap = no_mask_channel_overlap + self.mask_channel_min_space = mask_channel_min_space + self.mask_embedding = Parameter(torch.FloatTensor(encoder_embed_dim)) + torch.nn.init.uniform_(self.mask_embedding) + + def forward(self, x: Tensor, padding_mask: Optional[Tensor]) -> Tensor: + """ + Args: + x (Tensor): The encoded representations after feature extraction module. + padding_mask (Tensor or None): The padding mask of the same dimension as shape, + which will prevent masking padded elements. + + Returns: + Tensor: The feature representations after masking. + Tensor: The generated mask indices. + """ + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = _compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = mask_indices.to(x.device) + # change dtype of mask_embedding to x for mixed-precision training. + # see https://github.com/pytorch/audio/issues/2847 for details. 
+ x[mask_indices] = self.mask_embedding.to(x.dtype) + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = _compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = mask_channel_indices.to(x.device).unsqueeze(1).expand(-1, T, -1) + x[mask_channel_indices] = 0 + + return x, mask_indices + + +def _compute_logits( + proj_x: Tensor, + target: Tensor, + label_embeddings: Parameter, +) -> Tensor: + """Compute the logits of the embeddings. + Args: + proj_x (Tensor): The projected masked representations of dimension `[batch, frame, final_dim]`. + target (Tensor): The target Tensor of dimension `[batch, frame, final_dim]`. + label_embeddings (Parameter): The trainable embeddings of target of dimension `[num_class, final_dim]`. + + Returns: + (Tensor): The logits of the inputs. + """ + logit_temp = 0.1 + pos = torch.index_select(label_embeddings, 0, target.long()) + negs = label_embeddings.unsqueeze(1).expand(-1, proj_x.size(0), -1) + neg_is_pos = (pos == negs).all(-1) + pos = pos.unsqueeze(0) + targets = torch.cat([pos, negs], dim=0) + + logits = torch.cosine_similarity(proj_x.float(), targets.float(), dim=-1).type_as(proj_x) + logits /= logit_temp + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + logits = logits.transpose(0, 1) # (num_x, num_cls+1) + return logits + + +class LogitGenerator(Module): + """Generate the logits of masked and unmasked inputs. + Args: + encoder_embed_dim (int): The dimension of the transformer embedding output. + num_classes (int): The number of classes in the labels. + final_dim (int): Project final representations and targets to `final_dim`. + skip_masked (bool): If True, skip computing losses over masked frames. + skip_nomask (bool): If True, skip computing losses over unmasked frames. 
+ """ + + def __init__( + self, + encoder_embed_dim: int, + num_classes: int, + final_dim: int, + skip_masked: bool, + skip_nomask: bool, + ): + super().__init__() + self.label_embeddings = Parameter(torch.FloatTensor(num_classes, final_dim)) + torch.nn.init.uniform_(self.label_embeddings) + self.final_proj = torch.nn.Linear(encoder_embed_dim, final_dim) + self.skip_masked = skip_masked + self.skip_nomask = skip_nomask + + def forward(self, x: Tensor, label: Tensor, mask_m: Tensor, mask_u: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + x (Tensor): The feature representation of the last transformer layer. + label (Tensor): The label Tensor of dimension `[batch, frame]`. + mask_m (Tensor): The masked indices of dimension `[batch, frame]`. + mask_u (Tensor): The unmasked indices of dimension `[batch, frame]`. + + Returns: + Tensor: The logits of masked frames. Tensor of dimension `[masked_frame, final_dim]`. + Tensor: The logits of unmasked frames. Tensor of dimension `[unmasked_frame, final_dim]`. 
+ """ + proj_x = self.final_proj(x) + if self.skip_masked: + logit_m = None + else: + proj_x_m = proj_x[mask_m] + label_m = label[mask_m] + logit_m = _compute_logits(proj_x_m, label_m, self.label_embeddings) + + if self.skip_nomask: + logit_u = None + else: + proj_x_u = proj_x[mask_u] + label_u = label[mask_u] + logit_u = _compute_logits(proj_x_u, label_u, self.label_embeddings) + return logit_m, logit_u + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/import_fairseq.py b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/import_fairseq.py new file mode 100644 index 0000000000000000000000000000000000000000..39791e9b7d75ac3c2eb1fcf4f9c3517e7483048c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/import_fairseq.py @@ -0,0 +1,213 @@ +"""Import fariseq's wav2vec2.0 pretrained weights to torchaudios's format. + +For this module to work, you need `fairseq`. 
+""" +import re + +from torch.nn import Module + +from ..model import wav2vec2_model, Wav2Vec2Model + + +def _parse_config(w2v_model): + encoder = w2v_model.encoder + conv_layers = w2v_model.feature_extractor.conv_layers + + extractor_mode = "layer_norm" + if "GroupNorm" in conv_layers[0][2].__class__.__name__: + extractor_mode = "group_norm" + else: + extractor_mode = "layer_norm" + + conv_layer_config = [(l[0].out_channels, l[0].kernel_size[0], l[0].stride[0]) for l in conv_layers] + + if all(l[0].bias is None for l in conv_layers): + conv_bias = False + elif all(l[0].bias is not None for l in conv_layers): + conv_bias = True + else: + raise ValueError("Either all the convolutions layers have bias term or none of them should.") + + config = { + "extractor_mode": extractor_mode, + "extractor_conv_layer_config": conv_layer_config, + "extractor_conv_bias": conv_bias, + "encoder_embed_dim": w2v_model.post_extract_proj.out_features, + "encoder_projection_dropout": w2v_model.dropout_input.p, + "encoder_pos_conv_kernel": encoder.pos_conv[0].kernel_size[0], + "encoder_pos_conv_groups": encoder.pos_conv[0].groups, + "encoder_num_layers": len(encoder.layers), + "encoder_num_heads": encoder.layers[0].self_attn.num_heads, + "encoder_attention_dropout": encoder.layers[0].self_attn.dropout_module.p, + "encoder_ff_interm_features": encoder.layers[0].fc1.out_features, + "encoder_ff_interm_dropout": encoder.layers[0].dropout2.p, + "encoder_dropout": encoder.layers[0].dropout3.p, + "encoder_layer_norm_first": encoder.layer_norm_first, + "encoder_layer_drop": encoder.layerdrop, + } + return config + + +def _map_key(key): + key_ = key + if key.startswith("w2v_model."): + key = key.replace("w2v_model.", "") + if re.match(r"(mask_emb|quantizer|project_q|final_proj|mask_emb)", key): + return None + # Feature Extractor + # Group norm when "extractor_mode" is "default". 
+ # (Only the first layer) + # "conv_layers.0.2.weight" -> "conv_layers.0.layer_norm.weight" + # "conv_layers.0.2.bias" -> "conv_layers.0.layer_norm.bias" + match = re.match(r"feature_extractor\.conv_layers\.0\.2\.(weight|bias)", key) + if match: + return f"feature_extractor.conv_layers.0.layer_norm.{match.group(1)}" + # Convolutions + # "conv_layers.X.0.weight" -> "conv_layers.X.conv.weight" + # "conv_layers.X.0.bias" -> "conv_layers.X.conv.bias" + match = re.match(r"feature_extractor\.conv_layers\.(\d+)\.0\.(weight|bias)", key) + if match: + return f"feature_extractor.conv_layers.{match.group(1)}.conv.{match.group(2)}" + # Layer norm when "extractor_mode" is "layer_norm". + # "conv_layers.X.2.1.weight" -> "conv_layers.X.layer_norm.weight" + # "conv_layers.X.2.1.bias" -> "conv_layers.X.layer_norm.bias" + match = re.match(r"feature_extractor\.conv_layers\.(\d+)\.2\.1\.(weight|bias)", key) + if match: + return f"feature_extractor.conv_layers.{match.group(1)}.layer_norm.{match.group(2)}" + match = re.match(r"post_extract_proj\.(weight|bias)", key) + # Encoder - Feature projection + if match: + return f"encoder.feature_projection.projection.{match.group(1)}" + match = re.match(r"layer_norm\.(weight|bias)", key) + if match: + return f"encoder.feature_projection.layer_norm.{match.group(1)}" + # Encoder - Transformer - Convolutional positional embedding + match = re.match(r"encoder\.pos_conv\.0\.(bias|weight_g|weight_v)", key) + if match: + return f"encoder.transformer.pos_conv_embed.conv.{match.group(1)}" + match = re.match(r"encoder\.layer_norm\.(weight|bias)", key) + if match: + return f"encoder.transformer.layer_norm.{match.group(1)}" + # Encoder - Transformer - Self attention layers + match = re.match(r"encoder\.layers\.(\d+)\.self_attn\.((k_|v_|q_|out_)proj\.(weight|bias))", key) + if match: + return f"encoder.transformer.layers.{match.group(1)}.attention.{match.group(2)}" + match = re.match(r"encoder\.layers\.(\d+)\.self_attn_layer_norm\.(weight|bias)", key) + if 
match: + return f"encoder.transformer.layers.{match.group(1)}.layer_norm.{match.group(2)}" + match = re.match(r"encoder\.layers\.(\d+)\.fc1\.(weight|bias)", key) + if match: + return f"encoder.transformer.layers.{match.group(1)}.feed_forward.intermediate_dense.{match.group(2)}" + match = re.match(r"encoder\.layers\.(\d+)\.fc2\.(weight|bias)", key) + if match: + return f"encoder.transformer.layers.{match.group(1)}.feed_forward.output_dense.{match.group(2)}" + match = re.match(r"encoder\.layers\.(\d+)\.final_layer_norm\.(weight|bias)", key) + if match: + return f"encoder.transformer.layers.{match.group(1)}.final_layer_norm.{match.group(2)}" + match = re.match(r"proj\.(weight|bias)", key) + # Auxiliary Module + # Only relevant when loading fine-tuned models + if match: + return f"aux.{match.group(1)}" + # HuBERT Extension + if key in ["label_embs_concat"]: + return key + raise ValueError(f"Unexpected key: {key_}") + + +def _convert_state_dict(state_dict): + converted = {} + for k, v in state_dict.items(): + k = _map_key(k) + if k is not None: + converted[k] = v + return converted + + +def import_fairseq_model(original: Module) -> Wav2Vec2Model: + """Builds :class:`Wav2Vec2Model` from the corresponding model object of + `fairseq `_. + + Args: + original (torch.nn.Module): + An instance of fairseq's Wav2Vec2.0 or HuBERT model. + One of ``fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoder``, + ``fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model`` or + ``fairseq.models.hubert.hubert_asr.HubertEncoder``. + + Returns: + Wav2Vec2Model: Imported model. 
+ + Example - Loading pretrain-only model + >>> from torchaudio.models.wav2vec2.utils import import_fairseq_model + >>> + >>> # Load model using fairseq + >>> model_file = 'wav2vec_small.pt' + >>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file]) + >>> original = model[0] + >>> imported = import_fairseq_model(original) + >>> + >>> # Perform feature extraction + >>> waveform, _ = torchaudio.load('audio.wav') + >>> features, _ = imported.extract_features(waveform) + >>> + >>> # Compare result with the original model from fairseq + >>> reference = original.feature_extractor(waveform).transpose(1, 2) + >>> torch.testing.assert_allclose(features, reference) + + Example - Fine-tuned model + >>> from torchaudio.models.wav2vec2.utils import import_fairseq_model + >>> + >>> # Load model using fairseq + >>> model_file = 'wav2vec_small_960h.pt' + >>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file]) + >>> original = model[0] + >>> imported = import_fairseq_model(original.w2v_encoder) + >>> + >>> # Perform encoding + >>> waveform, _ = torchaudio.load('audio.wav') + >>> emission, _ = imported(waveform) + >>> + >>> # Compare result with the original model from fairseq + >>> mask = torch.zeros_like(waveform) + >>> reference = original(waveform, mask)['encoder_out'].transpose(0, 1) + >>> torch.testing.assert_allclose(emission, reference) + """ + class_ = original.__class__.__name__ + if class_ == "Wav2Vec2Model": + return _import_wav2vec2_pretraining(original) + if class_ == "Wav2VecEncoder": + return _import_wav2vec2_finetuning(original) + if class_ == "HubertModel": + return _import_hubert_pretraining(original) + if class_ == "HubertEncoder": + return _import_hubert_finetuning(original) + raise ValueError(f"Expected an instance of `Wav2Vec2Model` or `Wav2VecEncoder`. 
Found: {class_}") + + +def _import_wav2vec2_finetuning(original: Module) -> Wav2Vec2Model: + config = _parse_config(original.w2v_model) + model = wav2vec2_model(**config, aux_num_out=original.proj.out_features) + model.load_state_dict(_convert_state_dict(original.state_dict())) + return model + + +def _import_wav2vec2_pretraining(original: Module) -> Wav2Vec2Model: + config = _parse_config(original) + model = wav2vec2_model(**config, aux_num_out=None) + model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False) + return model + + +def _import_hubert_finetuning(original: Module) -> Wav2Vec2Model: + config = _parse_config(original.w2v_model) + model = wav2vec2_model(**config, aux_num_out=original.proj.out_features) + model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False) + return model + + +def _import_hubert_pretraining(original: Module) -> Wav2Vec2Model: + config = _parse_config(original) + model = wav2vec2_model(**config, aux_num_out=None) + model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False) + return model diff --git a/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/import_huggingface.py b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/import_huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..519d8c919f02be62b2f2e2aa0dd8db97222430d5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/models/wav2vec2/utils/import_huggingface.py @@ -0,0 +1,134 @@ +"""Import Hugging Face transformers's wav2vec2.0 pretrained weights to torchaudios's format. 
+""" +import logging +from typing import Any, Dict + +import torch +from torch.nn import Module + +from ..model import wav2vec2_model, Wav2Vec2Model, wavlm_model + +_LG = logging.getLogger(__name__) + + +def _get_config(cfg): + config = { + "extractor_mode": f"{cfg.feat_extract_norm}_norm", + "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)), + "extractor_conv_bias": cfg.conv_bias, + "encoder_embed_dim": cfg.hidden_size, + "encoder_projection_dropout": cfg.feat_proj_dropout, + "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings, + "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups, + "encoder_num_layers": cfg.num_hidden_layers, + "encoder_num_heads": cfg.num_attention_heads, + "encoder_attention_dropout": cfg.attention_dropout, + "encoder_ff_interm_features": cfg.intermediate_size, + "encoder_ff_interm_dropout": cfg.activation_dropout, + "encoder_dropout": cfg.hidden_dropout, + "encoder_layer_norm_first": cfg.do_stable_layer_norm, + "encoder_layer_drop": cfg.layerdrop, + } + return config + + +def _get_config_wavlm(cfg): + config = { + "extractor_mode": f"{cfg.feat_extract_norm}_norm", + "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)), + "extractor_conv_bias": cfg.conv_bias, + "encoder_embed_dim": cfg.hidden_size, + "encoder_projection_dropout": cfg.feat_proj_dropout, + "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings, + "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups, + "encoder_num_layers": cfg.num_hidden_layers, + "encoder_num_heads": cfg.num_attention_heads, + "encoder_num_buckets": cfg.num_buckets, + "encoder_max_distance": cfg.max_bucket_distance, + "encoder_attention_dropout": cfg.attention_dropout, + "encoder_ff_interm_features": cfg.intermediate_size, + "encoder_ff_interm_dropout": cfg.activation_dropout, + "encoder_dropout": cfg.hidden_dropout, + "encoder_layer_norm_first": cfg.do_stable_layer_norm, + "encoder_layer_drop": cfg.layerdrop, + } + 
return config + + +def _build(config, original): + is_for_ctc = original.__class__.__name__ in ["Wav2Vec2ForCTC", "WavLMForCTC"] + if is_for_ctc: + aux_num_out = original.config.vocab_size + wav2vec2 = original.wav2vec2 + else: + _LG.warning( + "The model is not an instance of Wav2Vec2ForCTC or WavLMForCTC. " '"lm_head" module is not imported.' + ) + aux_num_out = None + wav2vec2 = original + is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"] + if is_wavlm: + imported = wavlm_model(**config, aux_num_out=aux_num_out) + else: + imported = wav2vec2_model(**config, aux_num_out=aux_num_out) + imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict()) + imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict()) + encoder_state_dict = wav2vec2.encoder.state_dict() + if is_wavlm: # Rename paramaters of linear transformations for compatibility with the HF model + transform_wavlm_encoder_state(encoder_state_dict, config["encoder_num_layers"]) + imported.encoder.transformer.load_state_dict(encoder_state_dict) + if is_for_ctc: + imported.aux.load_state_dict(original.lm_head.state_dict()) + return imported + + +def transform_wavlm_encoder_state(state: Dict[str, Any], encoder_num_layers: int): + """Converts WavLM encoder state from HuggingFace format. In particular, concatenates linear projection weights and + biases to align with the structure of ``torch.nn.MultiheadAttention``. 
+ """ + for i in range(encoder_num_layers): + q_proj_bias = state.pop(f"layers.{i}.attention.q_proj.bias") + k_proj_bias = state.pop(f"layers.{i}.attention.k_proj.bias") + v_proj_bias = state.pop(f"layers.{i}.attention.v_proj.bias") + q_proj_weight = state.pop(f"layers.{i}.attention.q_proj.weight") + k_proj_weight = state.pop(f"layers.{i}.attention.k_proj.weight") + v_proj_weight = state.pop(f"layers.{i}.attention.v_proj.weight") + state[f"layers.{i}.attention.attention.in_proj_bias"] = torch.cat((q_proj_bias, k_proj_bias, v_proj_bias)) + state[f"layers.{i}.attention.attention.in_proj_weight"] = torch.cat( + (q_proj_weight, k_proj_weight, v_proj_weight) + ) + + state[f"layers.{i}.attention.attention.out_proj.weight"] = state.pop(f"layers.{i}.attention.out_proj.weight") + state[f"layers.{i}.attention.attention.out_proj.bias"] = state.pop(f"layers.{i}.attention.out_proj.bias") + + +def import_huggingface_model(original: Module) -> Wav2Vec2Model: + """Builds :class:`Wav2Vec2Model` from the corresponding model object of + `Transformers `_. + + Args: + original (torch.nn.Module): An instance of ``Wav2Vec2ForCTC`` from ``transformers``. + + Returns: + Wav2Vec2Model: Imported model. 
+ + Example + >>> from torchaudio.models.wav2vec2.utils import import_huggingface_model + >>> + >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") + >>> model = import_huggingface_model(original) + >>> + >>> waveforms, _ = torchaudio.load("audio.wav") + >>> logits, _ = model(waveforms) + """ + _LG.info("Importing model.") + _LG.info("Loading model configuration.") + is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"] + if is_wavlm: + config = _get_config_wavlm(original.config) + else: + config = _get_config(original.config) + _LG.debug(" - config: %s", config) + _LG.info("Building model.") + imported = _build(config, original) + return imported diff --git a/.venv/lib/python3.11/site-packages/torchaudio/utils/__init__.py b/.venv/lib/python3.11/site-packages/torchaudio/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89bffaa34d61fbb12cfafbe7287af0b92139b19c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/utils/__init__.py @@ -0,0 +1,11 @@ +from torio.utils import ffmpeg_utils + +from . 
import sox_utils +from .download import download_asset + + +__all__ = [ + "download_asset", + "sox_utils", + "ffmpeg_utils", +] diff --git a/.venv/lib/python3.11/site-packages/torchaudio/utils/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12cc97ad4fb13d9cacf792c6812ff7f0b9c39e7a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/utils/__pycache__/download.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/utils/__pycache__/download.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..411a820c7ff3776be0179ac586220983adbeea24 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/utils/__pycache__/download.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/utils/__pycache__/ffmpeg_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/utils/__pycache__/ffmpeg_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe1bd507cf244f9d6fefcbe65f125dd7476b7b4d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/utils/__pycache__/ffmpeg_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/utils/__pycache__/sox_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchaudio/utils/__pycache__/sox_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f87f06cce9924c27d246998a2a6a43e8ecc36adb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchaudio/utils/__pycache__/sox_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchaudio/utils/download.py b/.venv/lib/python3.11/site-packages/torchaudio/utils/download.py 
new file mode 100644 index 0000000000000000000000000000000000000000..2081877d15a13e91a6fcb87905634addd23cc712 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/utils/download.py @@ -0,0 +1,89 @@ +import hashlib +import logging +from os import PathLike +from pathlib import Path +from typing import Union + +import torch +from torchaudio._internal import download_url_to_file + +_LG = logging.getLogger(__name__) + + +def _get_local_path(key): + path = Path(torch.hub.get_dir()) / "torchaudio" / Path(key) + path.parent.mkdir(parents=True, exist_ok=True) + return path + + +def _download(key, path, progress): + url = f"https://download.pytorch.org/torchaudio/{key}" + download_url_to_file(url, path, progress=progress) + + +def _get_hash(path, hash, chunk_size=1028): + m = hashlib.sha256() + with open(path, "rb") as file: + data = file.read(chunk_size) + while data: + m.update(data) + data = file.read(chunk_size) + return m.hexdigest() + + +def download_asset( + key: str, + hash: str = "", + path: Union[str, PathLike] = "", + *, + progress: bool = True, +) -> str: + """Download and store torchaudio assets to local file system. + + If a file exists at the download path, then that path is returned with or without + hash validation. + + Args: + key (str): The asset identifier. + hash (str, optional): + The value of SHA256 hash of the asset. If provided, it is used to verify + the downloaded / cached object. If not provided, then no hash validation + is performed. This means if a file exists at the download path, then the path + is returned as-is without verifying the identity of the file. + path (path-like object, optional): + By default, the downloaded asset is saved in a directory under + :py:func:`torch.hub.get_dir` and intermediate directories based on the given `key` + are created. + This argument can be used to overwrite the target location. + When this argument is provided, all the intermediate directories have to be + created beforehand. 
+ progress (bool): Whether to show progress bar for downloading. Default: ``True``. + + Note: + Currently the valid key values are the route on ``download.pytorch.org/torchaudio``, + but this is an implementation detail. + + Returns: + str: The path to the asset on the local file system. + """ + path = path or _get_local_path(key) + + if path.exists(): + _LG.info("The local file (%s) exists. Skipping the download.", path) + else: + _LG.info("Downloading %s to %s", key, path) + _download(key, path, progress=progress) + + if hash: + _LG.info("Verifying the hash value.") + digest = _get_hash(path, hash) + + if digest != hash: + raise ValueError( + f"The hash value of the downloaded file ({path}), '{digest}' does not match " + f"the provided hash value, '{hash}'." + ) + + _LG.info("Hash validated.") + + return str(path) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/utils/ffmpeg_utils.py b/.venv/lib/python3.11/site-packages/torchaudio/utils/ffmpeg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..385596edc1491e45ecd4aae14a07b2c0e64ecd22 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/utils/ffmpeg_utils.py @@ -0,0 +1,11 @@ +"""Module to change the configuration of FFmpeg libraries (such as libavformat). + +It affects functionalities in :py:mod:`torchaudio.io` (and indirectly :py:func:`torchaudio.load`). +""" + + +# This file is just for BC. 
+def __getattr__(item): + from torio.utils import ffmpeg_utils + + return getattr(ffmpeg_utils, item) diff --git a/.venv/lib/python3.11/site-packages/torchaudio/utils/sox_utils.py b/.venv/lib/python3.11/site-packages/torchaudio/utils/sox_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5212b77ea9d5ae0e58741322db7c9852a4ddafff --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchaudio/utils/sox_utils.py @@ -0,0 +1,99 @@ +"""Module to change the configuration of libsox, which is used by I/O functions like +:py:mod:`~torchaudio.backend.sox_io_backend` and :py:mod:`~torchaudio.sox_effects`. +""" + +from typing import Dict, List + +import torchaudio + +sox_ext = torchaudio._extension.lazy_import_sox_ext() + + +def set_seed(seed: int): + """Set libsox's PRNG + + Args: + seed (int): seed value. valid range is int32. + + See Also: + http://sox.sourceforge.net/sox.html + """ + sox_ext.set_seed(seed) + + +def set_verbosity(verbosity: int): + """Set libsox's verbosity + + Args: + verbosity (int): Set verbosity level of libsox. + + * ``1`` failure messages + * ``2`` warnings + * ``3`` details of processing + * ``4``-``6`` increasing levels of debug messages + + See Also: + http://sox.sourceforge.net/sox.html + """ + sox_ext.set_verbosity(verbosity) + + +def set_buffer_size(buffer_size: int): + """Set buffer size for sox effect chain + + Args: + buffer_size (int): Set the size in bytes of the buffers used for processing audio. + + See Also: + http://sox.sourceforge.net/sox.html + """ + sox_ext.set_buffer_size(buffer_size) + + +def set_use_threads(use_threads: bool): + """Set multithread option for sox effect chain + + Args: + use_threads (bool): When ``True``, enables ``libsox``'s parallel effects channels processing. + To use mutlithread, the underlying ``libsox`` has to be compiled with OpenMP support. 
+ + See Also: + http://sox.sourceforge.net/sox.html + """ + sox_ext.set_use_threads(use_threads) + + +def list_effects() -> Dict[str, str]: + """List the available sox effect names + + Returns: + Dict[str, str]: Mapping from ``effect name`` to ``usage`` + """ + return dict(sox_ext.list_effects()) + + +def list_read_formats() -> List[str]: + """List the supported audio formats for read + + Returns: + List[str]: List of supported audio formats + """ + return sox_ext.list_read_formats() + + +def list_write_formats() -> List[str]: + """List the supported audio formats for write + + Returns: + List[str]: List of supported audio formats + """ + return sox_ext.list_write_formats() + + +def get_buffer_size() -> int: + """Get buffer size for sox effect chain + + Returns: + int: size in bytes of buffers used for processing audio. + """ + return sox_ext.get_buffer_size()