koichi12 commited on
Commit
7561da3
·
verified ·
1 Parent(s): 91f1872

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torchaudio/_internal/__init__.py +10 -0
  2. .venv/lib/python3.11/site-packages/torchaudio/_internal/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/torchaudio/_internal/__pycache__/module_utils.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/torchaudio/_internal/module_utils.py +113 -0
  5. .venv/lib/python3.11/site-packages/torchaudio/datasets/__init__.py +47 -0
  6. .venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librilight_limited.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librimix.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librispeech_biasing.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/libritts.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/torchaudio/datasets/cmuarctic.py +157 -0
  11. .venv/lib/python3.11/site-packages/torchaudio/datasets/cmudict.py +186 -0
  12. .venv/lib/python3.11/site-packages/torchaudio/datasets/commonvoice.py +86 -0
  13. .venv/lib/python3.11/site-packages/torchaudio/datasets/dr_vctk.py +121 -0
  14. .venv/lib/python3.11/site-packages/torchaudio/datasets/fluentcommands.py +108 -0
  15. .venv/lib/python3.11/site-packages/torchaudio/datasets/gtzan.py +1118 -0
  16. .venv/lib/python3.11/site-packages/torchaudio/datasets/iemocap.py +147 -0
  17. .venv/lib/python3.11/site-packages/torchaudio/datasets/librilight_limited.py +111 -0
  18. .venv/lib/python3.11/site-packages/torchaudio/datasets/librimix.py +133 -0
  19. .venv/lib/python3.11/site-packages/torchaudio/datasets/librispeech.py +174 -0
  20. .venv/lib/python3.11/site-packages/torchaudio/datasets/librispeech_biasing.py +189 -0
  21. .venv/lib/python3.11/site-packages/torchaudio/datasets/libritts.py +168 -0
  22. .venv/lib/python3.11/site-packages/torchaudio/datasets/ljspeech.py +107 -0
  23. .venv/lib/python3.11/site-packages/torchaudio/datasets/musdb_hq.py +139 -0
  24. .venv/lib/python3.11/site-packages/torchaudio/datasets/quesst14.py +136 -0
  25. .venv/lib/python3.11/site-packages/torchaudio/datasets/snips.py +157 -0
  26. .venv/lib/python3.11/site-packages/torchaudio/datasets/speechcommands.py +183 -0
  27. .venv/lib/python3.11/site-packages/torchaudio/datasets/tedlium.py +218 -0
  28. .venv/lib/python3.11/site-packages/torchaudio/datasets/utils.py +54 -0
  29. .venv/lib/python3.11/site-packages/torchaudio/datasets/vctk.py +143 -0
  30. .venv/lib/python3.11/site-packages/torchaudio/datasets/voxceleb1.py +309 -0
  31. .venv/lib/python3.11/site-packages/torchaudio/datasets/yesno.py +89 -0
  32. .venv/lib/python3.11/site-packages/torchaudio/io/__init__.py +13 -0
  33. .venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/__init__.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/_effector.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/_playback.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/torchaudio/io/_effector.py +347 -0
  37. .venv/lib/python3.11/site-packages/torchaudio/io/_playback.py +72 -0
  38. .venv/lib/python3.11/site-packages/torchaudio/lib/__init__.py +0 -0
  39. .venv/lib/python3.11/site-packages/torchaudio/lib/__pycache__/__init__.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/__init__.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/_hdemucs.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/conformer.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/conv_tasnet.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/deepspeech.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/emformer.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/rnnt.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/rnnt_decoder.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/tacotron2.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/wav2letter.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/wavernn.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/torchaudio/_internal/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
# Prefer the internal ``fb`` implementation of the download helpers when it is
# present; otherwise fall back to the public torch.hub helpers, which expose
# the same names and signatures.
try:
    from .fb import download_url_to_file, load_state_dict_from_url
except ImportError:
    from torch.hub import download_url_to_file, load_state_dict_from_url


__all__ = [
    "load_state_dict_from_url",
    "download_url_to_file",
]
.venv/lib/python3.11/site-packages/torchaudio/_internal/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (488 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchaudio/_internal/__pycache__/module_utils.cpython-311.pyc ADDED
Binary file (6.24 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/_internal/module_utils.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.util
2
+ import os
3
+ import warnings
4
+ from functools import wraps
5
+ from typing import Optional
6
+
7
+
8
def eval_env(var, default):
    """Evaluate a boolean-like environment variable.

    Args:
        var (str): Name of the environment variable to inspect.
        default: Value returned when the variable is not set.

    Returns:
        bool: ``True`` or ``False`` parsed from the variable value, or
        ``default`` when the variable is absent from the environment.

    Raises:
        RuntimeError: If the variable is set to a value that is neither
            truthy nor falsy in the accepted spellings below.
    """
    if var not in os.environ:
        return default

    # Membership was checked above, so direct indexing is safe; the previous
    # `os.environ.get(var, "0")` default was dead code.
    val = os.environ[var]
    trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"]
    falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"]
    if val in trues:
        return True
    if val not in falses:
        # fmt: off
        raise RuntimeError(
            f"Unexpected environment variable value `{var}={val}`. "
            f"Expected one of {trues + falses}")
        # fmt: on
    return False
25
+
26
+
27
def is_module_available(*modules: str) -> bool:
    r"""Return whether all the given top-level modules exist, *without*
    importing them.

    This is generally safer than a try-except block around ``import X``.
    It avoids third party libraries breaking assumptions of some of
    our tests, e.g., setting multiprocessing start method when imported
    (see librosa/#747, torchvision/#544).
    """
    return all(importlib.util.find_spec(m) is not None for m in modules)
35
+
36
+
37
def requires_module(*modules: str):
    """Decorate function to give error message if invoked without required optional modules.

    This decorator is to give better error message to users rather
    than raising ``NameError: name 'module' is not defined`` at random places.
    """
    unavailable = [name for name in modules if not is_module_available(name)]

    if not unavailable:
        # All requirements are importable; return the function untouched.
        def decorator(func):
            return func

        return decorator

    if len(unavailable) == 1:
        requirement = f"module: {unavailable[0]}"
    else:
        requirement = f"modules: {unavailable}"

    def decorator(func):
        @wraps(func)
        def wrapped(*args, **kwargs):
            raise RuntimeError(f"{func.__module__}.{func.__name__} requires {requirement}")

        return wrapped

    return decorator
61
+
62
+
63
def deprecated(direction: str, version: Optional[str] = None, remove: bool = False):
    """Decorator to add deprecation message

    Args:
        direction (str): Migration steps to be given to users.
        version (str or int): The version when the object will be removed
        remove (bool): If enabled, append future removal message.
    """

    def decorator(func):
        # Resolve the release label once; it is the same for the runtime
        # warning and the docstring note.
        release = "future" if version is None else version

        @wraps(func)
        def wrapped(*args, **kwargs):
            warning = f"{func.__module__}.{func.__name__} has been deprecated. {direction}"
            if remove:
                warning += f" It will be removed from {release} release. "
            warnings.warn(warning, stacklevel=2)
            return func(*args, **kwargs)

        doc_note = "This function has been deprecated. "
        if remove:
            doc_note += f"It will be removed from {release} release. "

        wrapped.__doc__ = f"""DEPRECATED: {func.__doc__}

    .. warning::

       {doc_note}
       {direction}
    """

        return wrapped

    return decorator
96
+
97
+
98
def fail_with_message(message):
    """Generate decorator to give users message about missing TorchAudio extension."""

    def decorator(target):
        @wraps(target)
        def raiser(*args, **kwargs):
            # Always raise: the wrapped functionality is unusable without the
            # native extension.
            raise RuntimeError(f"{target.__module__}.{target.__name__} {message}")

        return raiser

    return decorator
109
+
110
+
111
def no_op(func):
    """No-op decorator. Used in place of fail_with_message when a functionality that requires extension works fine."""
    return func
.venv/lib/python3.11/site-packages/torchaudio/datasets/__init__.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Re-export every dataset class at the package level so users can write
# e.g. ``torchaudio.datasets.LIBRISPEECH`` directly.
from .cmuarctic import CMUARCTIC
from .cmudict import CMUDict
from .commonvoice import COMMONVOICE
from .dr_vctk import DR_VCTK
from .fluentcommands import FluentSpeechCommands
from .gtzan import GTZAN
from .iemocap import IEMOCAP
from .librilight_limited import LibriLightLimited
from .librimix import LibriMix
from .librispeech import LIBRISPEECH
from .librispeech_biasing import LibriSpeechBiasing
from .libritts import LIBRITTS
from .ljspeech import LJSPEECH
from .musdb_hq import MUSDB_HQ
from .quesst14 import QUESST14
from .snips import Snips
from .speechcommands import SPEECHCOMMANDS
from .tedlium import TEDLIUM
from .vctk import VCTK_092
from .voxceleb1 import VoxCeleb1Identification, VoxCeleb1Verification
from .yesno import YESNO


# Public API of torchaudio.datasets; also controls ``from ... import *``.
__all__ = [
    "COMMONVOICE",
    "LIBRISPEECH",
    "LibriSpeechBiasing",
    "LibriLightLimited",
    "SPEECHCOMMANDS",
    "VCTK_092",
    "DR_VCTK",
    "YESNO",
    "LJSPEECH",
    "GTZAN",
    "CMUARCTIC",
    "CMUDict",
    "LibriMix",
    "LIBRITTS",
    "TEDLIUM",
    "QUESST14",
    "MUSDB_HQ",
    "FluentSpeechCommands",
    "VoxCeleb1Identification",
    "VoxCeleb1Verification",
    "IEMOCAP",
    "Snips",
]
.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librilight_limited.cpython-311.pyc ADDED
Binary file (6.97 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librimix.cpython-311.pyc ADDED
Binary file (7.72 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/librispeech_biasing.cpython-311.pyc ADDED
Binary file (9.78 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/datasets/__pycache__/libritts.cpython-311.pyc ADDED
Binary file (8.05 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/datasets/cmuarctic.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Tuple, Union
5
+
6
+ import torchaudio
7
+ from torch import Tensor
8
+ from torch.utils.data import Dataset
9
+ from torchaudio._internal import download_url_to_file
10
+ from torchaudio.datasets.utils import _extract_tar
11
+
12
+ URL = "aew"
13
+ FOLDER_IN_ARCHIVE = "ARCTIC"
14
+ _CHECKSUMS = {
15
+ "http://festvox.org/cmu_arctic/packed/cmu_us_aew_arctic.tar.bz2": "645cb33c0f0b2ce41384fdd8d3db2c3f5fc15c1e688baeb74d2e08cab18ab406", # noqa: E501
16
+ "http://festvox.org/cmu_arctic/packed/cmu_us_ahw_arctic.tar.bz2": "024664adeb892809d646a3efd043625b46b5bfa3e6189b3500b2d0d59dfab06c", # noqa: E501
17
+ "http://festvox.org/cmu_arctic/packed/cmu_us_aup_arctic.tar.bz2": "2c55bc3050caa996758869126ad10cf42e1441212111db034b3a45189c18b6fc", # noqa: E501
18
+ "http://festvox.org/cmu_arctic/packed/cmu_us_awb_arctic.tar.bz2": "d74a950c9739a65f7bfc4dfa6187f2730fa03de5b8eb3f2da97a51b74df64d3c", # noqa: E501
19
+ "http://festvox.org/cmu_arctic/packed/cmu_us_axb_arctic.tar.bz2": "dd65c3d2907d1ee52f86e44f578319159e60f4bf722a9142be01161d84e330ff", # noqa: E501
20
+ "http://festvox.org/cmu_arctic/packed/cmu_us_bdl_arctic.tar.bz2": "26b91aaf48b2799b2956792b4632c2f926cd0542f402b5452d5adecb60942904", # noqa: E501
21
+ "http://festvox.org/cmu_arctic/packed/cmu_us_clb_arctic.tar.bz2": "3f16dc3f3b97955ea22623efb33b444341013fc660677b2e170efdcc959fa7c6", # noqa: E501
22
+ "http://festvox.org/cmu_arctic/packed/cmu_us_eey_arctic.tar.bz2": "8a0ee4e5acbd4b2f61a4fb947c1730ab3adcc9dc50b195981d99391d29928e8a", # noqa: E501
23
+ "http://festvox.org/cmu_arctic/packed/cmu_us_fem_arctic.tar.bz2": "3fcff629412b57233589cdb058f730594a62c4f3a75c20de14afe06621ef45e2", # noqa: E501
24
+ "http://festvox.org/cmu_arctic/packed/cmu_us_gka_arctic.tar.bz2": "dc82e7967cbd5eddbed33074b0699128dbd4482b41711916d58103707e38c67f", # noqa: E501
25
+ "http://festvox.org/cmu_arctic/packed/cmu_us_jmk_arctic.tar.bz2": "3a37c0e1dfc91e734fdbc88b562d9e2ebca621772402cdc693bbc9b09b211d73", # noqa: E501
26
+ "http://festvox.org/cmu_arctic/packed/cmu_us_ksp_arctic.tar.bz2": "8029cafce8296f9bed3022c44ef1e7953332b6bf6943c14b929f468122532717", # noqa: E501
27
+ "http://festvox.org/cmu_arctic/packed/cmu_us_ljm_arctic.tar.bz2": "b23993765cbf2b9e7bbc3c85b6c56eaf292ac81ee4bb887b638a24d104f921a0", # noqa: E501
28
+ "http://festvox.org/cmu_arctic/packed/cmu_us_lnh_arctic.tar.bz2": "4faf34d71aa7112813252fb20c5433e2fdd9a9de55a00701ffcbf05f24a5991a", # noqa: E501
29
+ "http://festvox.org/cmu_arctic/packed/cmu_us_rms_arctic.tar.bz2": "c6dc11235629c58441c071a7ba8a2d067903dfefbaabc4056d87da35b72ecda4", # noqa: E501
30
+ "http://festvox.org/cmu_arctic/packed/cmu_us_rxr_arctic.tar.bz2": "1fa4271c393e5998d200e56c102ff46fcfea169aaa2148ad9e9469616fbfdd9b", # noqa: E501
31
+ "http://festvox.org/cmu_arctic/packed/cmu_us_slp_arctic.tar.bz2": "54345ed55e45c23d419e9a823eef427f1cc93c83a710735ec667d068c916abf1", # noqa: E501
32
+ "http://festvox.org/cmu_arctic/packed/cmu_us_slt_arctic.tar.bz2": "7c173297916acf3cc7fcab2713be4c60b27312316765a90934651d367226b4ea", # noqa: E501
33
+ }
34
+
35
+
36
def load_cmuarctic_item(line: str, path: str, folder_audio: str, ext_audio: str) -> Tuple[Tensor, int, str, str]:
    """Load one CMU ARCTIC sample described by a row of ``txt.done.data``.

    Returns a ``(waveform, sample_rate, transcript, utterance_number)`` tuple.
    """
    # NOTE(review): `line` comes from csv.reader(..., delimiter="\n") in
    # CMUARCTIC.__init__, i.e. a one-element list whose first entry is a raw
    # line such as: ( arctic_a0001 "Author of the danger trail, ..." )
    # — the `str` annotation appears inaccurate; confirm against __getitem__.

    # split(" ", 2)[1:] drops the leading "(" token and separates the
    # utterance id from the quoted transcript.
    utterance_id, transcript = line[0].strip().split(" ", 2)[1:]

    # Remove space, double quote, and single parenthesis from transcript
    transcript = transcript[1:-3]

    file_audio = os.path.join(path, folder_audio, utterance_id + ext_audio)

    # Load audio
    waveform, sample_rate = torchaudio.load(file_audio)

    # "arctic_a0001" -> "a0001": only the per-speaker suffix is returned.
    return (waveform, sample_rate, transcript, utterance_id.split("_")[1])
49
+
50
+
51
class CMUARCTIC(Dataset):
    """*CMU ARCTIC* :cite:`Kominek03cmuarctic` dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional):
            The URL to download the dataset from or the type of the dataset to download.
            (default: ``"aew"``)
            Allowed type values are ``"aew"``, ``"ahw"``, ``"aup"``, ``"awb"``, ``"axb"``, ``"bdl"``,
            ``"clb"``, ``"eey"``, ``"fem"``, ``"gka"``, ``"jmk"``, ``"ksp"``, ``"ljm"``, ``"lnh"``,
            ``"rms"``, ``"rxr"``, ``"slp"`` or ``"slt"``.
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"ARCTIC"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
    """

    # Transcript index file and the directories used inside each speaker archive.
    _file_text = "txt.done.data"
    _folder_text = "etc"
    _ext_audio = ".wav"
    _folder_audio = "wav"

    def __init__(
        self, root: Union[str, Path], url: str = URL, folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False
    ) -> None:

        # A bare speaker id is expanded to the full archive URL; any other
        # value is treated as an already-complete URL.
        if url in [
            "aew",
            "ahw",
            "aup",
            "awb",
            "axb",
            "bdl",
            "clb",
            "eey",
            "fem",
            "gka",
            "jmk",
            "ksp",
            "ljm",
            "lnh",
            "rms",
            "rxr",
            "slp",
            "slt",
        ]:

            url = "cmu_us_" + url + "_arctic"
            ext_archive = ".tar.bz2"
            # NOTE(review): this host is "www.festvox.org" while the
            # _CHECKSUMS table above keys on "festvox.org" — the checksum
            # lookup below may therefore always return None; verify.
            base_url = "http://www.festvox.org/cmu_arctic/packed/"

            # NOTE(review): os.path.join on a URL assumes "/" is the OS path
            # separator — confirm behavior on Windows.
            url = os.path.join(base_url, url + ext_archive)

        # Get string representation of 'root' in case Path object is passed
        root = os.fspath(root)

        basename = os.path.basename(url)
        root = os.path.join(root, folder_in_archive)
        if not os.path.isdir(root):
            os.mkdir(root)
        archive = os.path.join(root, basename)

        # "cmu_us_aew_arctic.tar.bz2" -> "cmu_us_aew_arctic"
        basename = basename.split(".")[0]

        self._path = os.path.join(root, basename)

        if download:
            # Download and extract only when the extracted directory is absent.
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url_to_file(url, archive, hash_prefix=checksum)
                _extract_tar(archive)
        else:
            if not os.path.exists(self._path):
                raise RuntimeError(
                    f"The path {self._path} doesn't exist. "
                    "Please check the ``root`` path or set `download=True` to download it"
                )
        self._text = os.path.join(self._path, self._folder_text, self._file_text)

        # delimiter="\n" means csv.reader never splits within a line, so each
        # row of txt.done.data is kept as a one-element list.
        with open(self._text, "r") as text:
            walker = csv.reader(text, delimiter="\n")
            self._walker = list(walker)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Transcript
            str:
                Utterance ID
        """
        line = self._walker[n]
        return load_cmuarctic_item(line, self._path, self._folder_audio, self._ext_audio)

    def __len__(self) -> int:
        return len(self._walker)
.venv/lib/python3.11/site-packages/torchaudio/datasets/cmudict.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from pathlib import Path
4
+ from typing import Iterable, List, Tuple, Union
5
+
6
+ from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
+
9
+
10
+ _CHECKSUMS = {
11
+ "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b": "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4", # noqa: E501
12
+ "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols": "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027", # noqa: E501
13
+ }
14
+ _PUNCTUATIONS = {
15
+ "!EXCLAMATION-POINT",
16
+ '"CLOSE-QUOTE',
17
+ '"DOUBLE-QUOTE',
18
+ '"END-OF-QUOTE',
19
+ '"END-QUOTE',
20
+ '"IN-QUOTES',
21
+ '"QUOTE',
22
+ '"UNQUOTE',
23
+ "#HASH-MARK",
24
+ "#POUND-SIGN",
25
+ "#SHARP-SIGN",
26
+ "%PERCENT",
27
+ "&AMPERSAND",
28
+ "'END-INNER-QUOTE",
29
+ "'END-QUOTE",
30
+ "'INNER-QUOTE",
31
+ "'QUOTE",
32
+ "'SINGLE-QUOTE",
33
+ "(BEGIN-PARENS",
34
+ "(IN-PARENTHESES",
35
+ "(LEFT-PAREN",
36
+ "(OPEN-PARENTHESES",
37
+ "(PAREN",
38
+ "(PARENS",
39
+ "(PARENTHESES",
40
+ ")CLOSE-PAREN",
41
+ ")CLOSE-PARENTHESES",
42
+ ")END-PAREN",
43
+ ")END-PARENS",
44
+ ")END-PARENTHESES",
45
+ ")END-THE-PAREN",
46
+ ")PAREN",
47
+ ")PARENS",
48
+ ")RIGHT-PAREN",
49
+ ")UN-PARENTHESES",
50
+ "+PLUS",
51
+ ",COMMA",
52
+ "--DASH",
53
+ "-DASH",
54
+ "-HYPHEN",
55
+ "...ELLIPSIS",
56
+ ".DECIMAL",
57
+ ".DOT",
58
+ ".FULL-STOP",
59
+ ".PERIOD",
60
+ ".POINT",
61
+ "/SLASH",
62
+ ":COLON",
63
+ ";SEMI-COLON",
64
+ ";SEMI-COLON(1)",
65
+ "?QUESTION-MARK",
66
+ "{BRACE",
67
+ "{LEFT-BRACE",
68
+ "{OPEN-BRACE",
69
+ "}CLOSE-BRACE",
70
+ "}RIGHT-BRACE",
71
+ }
72
+
73
+
74
+ def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[str]:
75
+ _alt_re = re.compile(r"\([0-9]+\)")
76
+ cmudict: List[Tuple[str, List[str]]] = []
77
+ for line in lines:
78
+ if not line or line.startswith(";;;"): # ignore comments
79
+ continue
80
+
81
+ word, phones = line.strip().split(" ")
82
+ if word in _PUNCTUATIONS:
83
+ if exclude_punctuations:
84
+ continue
85
+ # !EXCLAMATION-POINT -> !
86
+ # --DASH -> --
87
+ # ...ELLIPSIS -> ...
88
+ if word.startswith("..."):
89
+ word = "..."
90
+ elif word.startswith("--"):
91
+ word = "--"
92
+ else:
93
+ word = word[0]
94
+
95
+ # if a word have multiple pronunciations, there will be (number) appended to it
96
+ # for example, DATAPOINTS and DATAPOINTS(1),
97
+ # the regular expression `_alt_re` removes the '(1)' and change the word DATAPOINTS(1) to DATAPOINTS
98
+ word = re.sub(_alt_re, "", word)
99
+ phones = phones.split(" ")
100
+ cmudict.append((word, phones))
101
+
102
+ return cmudict
103
+
104
+
105
class CMUDict(Dataset):
    """*CMU Pronouncing Dictionary* :cite:`cmudict` (CMUDict) dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        exclude_punctuations (bool, optional):
            When enabled, exclude the pronounciation of punctuations, such as
            `!EXCLAMATION-POINT` and `#HASH-MARK`.
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        url (str, optional):
            The URL to download the dictionary from.
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``)
        url_symbols (str, optional):
            The URL to download the list of symbols from.
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        exclude_punctuations: bool = True,
        *,
        download: bool = False,
        url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
        url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
    ) -> None:

        self.exclude_punctuations = exclude_punctuations

        self._root_path = Path(root)
        if not os.path.isdir(self._root_path):
            raise RuntimeError(f"The root directory does not exist; {root}")

        # Local file names mirror the final path component of the URLs.
        dict_file = self._root_path / os.path.basename(url)
        symbol_file = self._root_path / os.path.basename(url_symbols)
        if not os.path.exists(dict_file):
            if not download:
                raise RuntimeError(
                    "The dictionary file is not found in the following location. "
                    f"Set `download=True` to download it. {dict_file}"
                )
            checksum = _CHECKSUMS.get(url, None)
            # NOTE(review): checksum is passed positionally here (binds to
            # `hash_prefix` of download_url_to_file) while other datasets in
            # this package pass it by keyword — confirm intended.
            download_url_to_file(url, dict_file, checksum)
        if not os.path.exists(symbol_file):
            if not download:
                raise RuntimeError(
                    "The symbol file is not found in the following location. "
                    f"Set `download=True` to download it. {symbol_file}"
                )
            checksum = _CHECKSUMS.get(url_symbols, None)
            download_url_to_file(url_symbols, symbol_file, checksum)

        # One phoneme symbol per line, e.g. "AA", "AE".
        with open(symbol_file, "r") as text:
            self._symbols = [line.strip() for line in text.readlines()]

        # cmudict-0.7b is Latin-1 encoded; parse into (word, phonemes) pairs.
        with open(dict_file, "r", encoding="latin-1") as text:
            self._dictionary = _parse_dictionary(text.readlines(), exclude_punctuations=self.exclude_punctuations)

    def __getitem__(self, n: int) -> Tuple[str, List[str]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            Tuple of a word and its phonemes

            str:
                Word
            List[str]:
                Phonemes
        """
        return self._dictionary[n]

    def __len__(self) -> int:
        return len(self._dictionary)

    @property
    def symbols(self) -> List[str]:
        """list[str]: A list of phonemes symbols, such as ``"AA"``, ``"AE"``, ``"AH"``."""
        # Return a copy so callers cannot mutate the internal list.
        return self._symbols.copy()
.venv/lib/python3.11/site-packages/torchaudio/datasets/commonvoice.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Dict, List, Tuple, Union
5
+
6
+ import torchaudio
7
+ from torch import Tensor
8
+ from torch.utils.data import Dataset
9
+
10
+
11
def load_commonvoice_item(
    line: List[str], header: List[str], path: str, folder_audio: str, ext_audio: str
) -> Tuple[Tensor, int, Dict[str, str]]:
    """Load one CommonVoice sample from a single TSV row.

    Each row has the following columns:
    client_id, path, sentence, up_votes, down_votes, age, gender, accent
    """
    # Sanity-check that column 1 of the TSV really is the audio path.
    if header[1] != "path":
        raise ValueError(f"expect `header[1]` to be 'path', but got {header[1]}")

    audio_path = os.path.join(path, folder_audio, line[1])
    if not audio_path.endswith(ext_audio):
        audio_path = audio_path + ext_audio

    waveform, sample_rate = torchaudio.load(audio_path)

    # Pair every header column with its value for this row.
    metadata = dict(zip(header, line))
    return waveform, sample_rate, metadata
28
+
29
+
30
class COMMONVOICE(Dataset):
    """*CommonVoice* :cite:`ardila2020common` dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is located.
            (Where the ``tsv`` file is present.)
        tsv (str, optional):
            The name of the tsv file used to construct the metadata, such as
            ``"train.tsv"``, ``"test.tsv"``, ``"dev.tsv"``, ``"invalidated.tsv"``,
            ``"validated.tsv"`` and ``"other.tsv"``. (default: ``"train.tsv"``)
    """

    _ext_txt = ".txt"
    _ext_audio = ".mp3"
    _folder_audio = "clips"

    def __init__(self, root: Union[str, Path], tsv: str = "train.tsv") -> None:
        # Accept both str and Path for `root`.
        self._path = os.fspath(root)
        self._tsv = os.path.join(self._path, tsv)

        # Read the whole TSV up front: first row is the header, the rest
        # are the per-utterance metadata rows.
        with open(self._tsv, "r") as tsv_file:
            rows = csv.reader(tsv_file, delimiter="\t")
            self._header = next(rows)
            self._walker = list(rows)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, Dict[str, str]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            Dict[str, str]:
                Dictionary containing the following items from the corresponding TSV file;

                * ``"client_id"``
                * ``"path"``
                * ``"sentence"``
                * ``"up_votes"``
                * ``"down_votes"``
                * ``"age"``
                * ``"gender"``
                * ``"accent"``
        """
        row = self._walker[n]
        return load_commonvoice_item(row, self._header, self._path, self._folder_audio, self._ext_audio)

    def __len__(self) -> int:
        return len(self._walker)
.venv/lib/python3.11/site-packages/torchaudio/datasets/dr_vctk.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Dict, Tuple, Union
3
+
4
+ import torchaudio
5
+ from torch import Tensor
6
+ from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
+ from torchaudio.datasets.utils import _extract_zip
9
+
10
+
11
_URL = "https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"
_CHECKSUM = "781f12f4406ed36ed27ae3bce55da47ba176e2d8bae67319e389e07b2c9bd769"
_SUPPORTED_SUBSETS = {"train", "test"}


class DR_VCTK(Dataset):
    """*Device Recorded VCTK (Small subset version)* :cite:`Sarfjoo2018DeviceRV` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found.
        subset (str): The subset to use. Can be one of ``"train"`` and ``"test"``. (default: ``"train"``).
        download (bool):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        url (str): The URL to download the dataset from.
            (default: ``"https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "train",
        *,
        download: bool = False,
        url: str = _URL,
    ) -> None:
        if subset not in _SUPPORTED_SUBSETS:
            raise RuntimeError(
                f"The subset '{subset}' does not match any of the supported subsets: {_SUPPORTED_SUBSETS}"
            )

        base_dir = Path(root).expanduser()
        archive = base_dir / "DR-VCTK.zip"

        # Directory layout inside the archive: DR-VCTK/DR-VCTK/{clean,noisy,configurations}.
        self._subset = subset
        self._path = base_dir / "DR-VCTK" / "DR-VCTK"
        self._clean_audio_dir = self._path / f"clean_{self._subset}set_wav_16k"
        self._noisy_audio_dir = self._path / f"device-recorded_{self._subset}set_wav_16k"
        self._config_filepath = self._path / "configurations" / f"{self._subset}_ch_log.txt"

        if not self._path.is_dir():
            if not archive.is_file():
                if not download:
                    raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
                download_url_to_file(url, archive, hash_prefix=_CHECKSUM)
            _extract_zip(archive, base_dir)

        self._config = self._load_config(self._config_filepath)
        self._filename_list = sorted(self._config)

    def _load_config(self, filepath: str) -> Dict[str, Tuple[str, int]]:
        # The train log carries two header rows, the test log only one.
        skip_rows = 2 if self._subset == "train" else 1

        config = {}
        with open(filepath) as log:
            for row_index, row in enumerate(log):
                if row_index < skip_rows or not row:
                    continue
                filename, source, channel_id = row.strip().split("\t")
                config[filename] = (source, int(channel_id))
        return config

    def _load_dr_vctk_item(self, filename: str) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
        # "p226_001.wav" -> speaker "p226", utterance "001".
        stem = filename.split(".")[0]
        speaker_id, utterance_id = stem.split("_")
        source, channel_id = self._config[filename]
        waveform_clean, sample_rate_clean = torchaudio.load(self._clean_audio_dir / filename)
        waveform_noisy, sample_rate_noisy = torchaudio.load(self._noisy_audio_dir / filename)
        return (
            waveform_clean,
            sample_rate_clean,
            waveform_noisy,
            sample_rate_noisy,
            speaker_id,
            utterance_id,
            source,
            channel_id,
        )

    def __getitem__(self, n: int) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Clean waveform
            int:
                Sample rate of the clean waveform
            Tensor:
                Noisy waveform
            int:
                Sample rate of the noisy waveform
            str:
                Speaker ID
            str:
                Utterance ID
            str:
                Source
            int:
                Channel ID
        """
        return self._load_dr_vctk_item(self._filename_list[n])

    def __len__(self) -> int:
        return len(self._filename_list)
.venv/lib/python3.11/site-packages/torchaudio/datasets/fluentcommands.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Tuple, Union
5
+
6
+ from torch import Tensor
7
+ from torch.utils.data import Dataset
8
+ from torchaudio.datasets.utils import _load_waveform
9
+
10
+ SAMPLE_RATE = 16000
11
+
12
+
13
class FluentSpeechCommands(Dataset):
    """*Fluent Speech Commands* :cite:`fluent` dataset

    Args:
        root (str of Path): Path to the directory where the dataset is found.
        subset (str, optional): subset of the dataset to use.
            Options: [``"train"``, ``"valid"``, ``"test"``].
            (Default: ``"train"``)
    """

    def __init__(self, root: Union[str, Path], subset: str = "train"):
        if subset not in ["train", "valid", "test"]:
            raise ValueError("`subset` must be one of ['train', 'valid', 'test']")

        self._path = os.path.join(os.fspath(root), "fluent_speech_commands_dataset")

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found.")

        # Each subset ships a CSV listing one utterance per row; the header
        # row is kept so columns can later be located by name.
        csv_path = os.path.join(self._path, "data", f"{subset}_data.csv")
        with open(csv_path) as csv_file:
            rows = list(csv.reader(csv_file))

        self.header = rows[0]
        self.data = rows[1:]

    def get_metadata(self, n: int) -> Tuple[str, int, str, int, str, str, str, str]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            str:
                Path to audio (relative to the dataset root)
            int:
                Sample rate
            str:
                File name
            int:
                Speaker ID (as read from the CSV)
            str:
                Transcription
            str:
                Action
            str:
                Object
            str:
                Location
        """
        row = self.data[n]

        # The "path" column holds a relative path; keep only the stem.
        file_name = row[self.header.index("path")].split("/")[-1].split(".")[0]
        speaker_id, transcription, action, obj, location = row[2:]
        file_path = os.path.join("wavs", "speakers", speaker_id, f"{file_name}.wav")

        return file_path, SAMPLE_RATE, file_name, speaker_id, transcription, action, obj, location

    def __len__(self) -> int:
        """Return the number of utterances in the selected subset."""
        return len(self.data)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, str, str, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                File name
            int:
                Speaker ID (as read from the CSV)
            str:
                Transcription
            str:
                Action
            str:
                Object
            str:
                Location
        """
        file_path, sample_rate, *rest = self.get_metadata(n)
        waveform = _load_waveform(self._path, file_path, sample_rate)
        return (waveform, sample_rate, *rest)
.venv/lib/python3.11/site-packages/torchaudio/datasets/gtzan.py ADDED
@@ -0,0 +1,1118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torchaudio
6
+ from torch import Tensor
7
+ from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
+ from torchaudio.datasets.utils import _extract_tar
10
+
11
+ # The following lists prefixed with `filtered_` provide a filtered split
12
+ # that:
13
+ #
14
+ # a. Mitigate a known issue with GTZAN (duplication)
15
+ #
16
+ # b. Provide a standard split for testing it against other
17
+ # methods (e.g. the one in jordipons/sklearn-audio-transfer-learning).
18
+ #
19
+ # Those are used when GTZAN is initialised with the `filtered` keyword.
20
+ # The split was taken from (github) jordipons/sklearn-audio-transfer-learning.
21
+
22
# The ten genre labels of GTZAN; each is also the name of the per-genre
# subdirectory that holds that genre's audio clips.
gtzan_genres = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]
34
+
35
+ filtered_test = [
36
+ "blues.00012",
37
+ "blues.00013",
38
+ "blues.00014",
39
+ "blues.00015",
40
+ "blues.00016",
41
+ "blues.00017",
42
+ "blues.00018",
43
+ "blues.00019",
44
+ "blues.00020",
45
+ "blues.00021",
46
+ "blues.00022",
47
+ "blues.00023",
48
+ "blues.00024",
49
+ "blues.00025",
50
+ "blues.00026",
51
+ "blues.00027",
52
+ "blues.00028",
53
+ "blues.00061",
54
+ "blues.00062",
55
+ "blues.00063",
56
+ "blues.00064",
57
+ "blues.00065",
58
+ "blues.00066",
59
+ "blues.00067",
60
+ "blues.00068",
61
+ "blues.00069",
62
+ "blues.00070",
63
+ "blues.00071",
64
+ "blues.00072",
65
+ "blues.00098",
66
+ "blues.00099",
67
+ "classical.00011",
68
+ "classical.00012",
69
+ "classical.00013",
70
+ "classical.00014",
71
+ "classical.00015",
72
+ "classical.00016",
73
+ "classical.00017",
74
+ "classical.00018",
75
+ "classical.00019",
76
+ "classical.00020",
77
+ "classical.00021",
78
+ "classical.00022",
79
+ "classical.00023",
80
+ "classical.00024",
81
+ "classical.00025",
82
+ "classical.00026",
83
+ "classical.00027",
84
+ "classical.00028",
85
+ "classical.00029",
86
+ "classical.00034",
87
+ "classical.00035",
88
+ "classical.00036",
89
+ "classical.00037",
90
+ "classical.00038",
91
+ "classical.00039",
92
+ "classical.00040",
93
+ "classical.00041",
94
+ "classical.00049",
95
+ "classical.00077",
96
+ "classical.00078",
97
+ "classical.00079",
98
+ "country.00030",
99
+ "country.00031",
100
+ "country.00032",
101
+ "country.00033",
102
+ "country.00034",
103
+ "country.00035",
104
+ "country.00036",
105
+ "country.00037",
106
+ "country.00038",
107
+ "country.00039",
108
+ "country.00040",
109
+ "country.00043",
110
+ "country.00044",
111
+ "country.00046",
112
+ "country.00047",
113
+ "country.00048",
114
+ "country.00050",
115
+ "country.00051",
116
+ "country.00053",
117
+ "country.00054",
118
+ "country.00055",
119
+ "country.00056",
120
+ "country.00057",
121
+ "country.00058",
122
+ "country.00059",
123
+ "country.00060",
124
+ "country.00061",
125
+ "country.00062",
126
+ "country.00063",
127
+ "country.00064",
128
+ "disco.00001",
129
+ "disco.00021",
130
+ "disco.00058",
131
+ "disco.00062",
132
+ "disco.00063",
133
+ "disco.00064",
134
+ "disco.00065",
135
+ "disco.00066",
136
+ "disco.00069",
137
+ "disco.00076",
138
+ "disco.00077",
139
+ "disco.00078",
140
+ "disco.00079",
141
+ "disco.00080",
142
+ "disco.00081",
143
+ "disco.00082",
144
+ "disco.00083",
145
+ "disco.00084",
146
+ "disco.00085",
147
+ "disco.00086",
148
+ "disco.00087",
149
+ "disco.00088",
150
+ "disco.00091",
151
+ "disco.00092",
152
+ "disco.00093",
153
+ "disco.00094",
154
+ "disco.00096",
155
+ "disco.00097",
156
+ "disco.00099",
157
+ "hiphop.00000",
158
+ "hiphop.00026",
159
+ "hiphop.00027",
160
+ "hiphop.00030",
161
+ "hiphop.00040",
162
+ "hiphop.00043",
163
+ "hiphop.00044",
164
+ "hiphop.00045",
165
+ "hiphop.00051",
166
+ "hiphop.00052",
167
+ "hiphop.00053",
168
+ "hiphop.00054",
169
+ "hiphop.00062",
170
+ "hiphop.00063",
171
+ "hiphop.00064",
172
+ "hiphop.00065",
173
+ "hiphop.00066",
174
+ "hiphop.00067",
175
+ "hiphop.00068",
176
+ "hiphop.00069",
177
+ "hiphop.00070",
178
+ "hiphop.00071",
179
+ "hiphop.00072",
180
+ "hiphop.00073",
181
+ "hiphop.00074",
182
+ "hiphop.00075",
183
+ "hiphop.00099",
184
+ "jazz.00073",
185
+ "jazz.00074",
186
+ "jazz.00075",
187
+ "jazz.00076",
188
+ "jazz.00077",
189
+ "jazz.00078",
190
+ "jazz.00079",
191
+ "jazz.00080",
192
+ "jazz.00081",
193
+ "jazz.00082",
194
+ "jazz.00083",
195
+ "jazz.00084",
196
+ "jazz.00085",
197
+ "jazz.00086",
198
+ "jazz.00087",
199
+ "jazz.00088",
200
+ "jazz.00089",
201
+ "jazz.00090",
202
+ "jazz.00091",
203
+ "jazz.00092",
204
+ "jazz.00093",
205
+ "jazz.00094",
206
+ "jazz.00095",
207
+ "jazz.00096",
208
+ "jazz.00097",
209
+ "jazz.00098",
210
+ "jazz.00099",
211
+ "metal.00012",
212
+ "metal.00013",
213
+ "metal.00014",
214
+ "metal.00015",
215
+ "metal.00022",
216
+ "metal.00023",
217
+ "metal.00025",
218
+ "metal.00026",
219
+ "metal.00027",
220
+ "metal.00028",
221
+ "metal.00029",
222
+ "metal.00030",
223
+ "metal.00031",
224
+ "metal.00032",
225
+ "metal.00033",
226
+ "metal.00038",
227
+ "metal.00039",
228
+ "metal.00067",
229
+ "metal.00070",
230
+ "metal.00073",
231
+ "metal.00074",
232
+ "metal.00075",
233
+ "metal.00078",
234
+ "metal.00083",
235
+ "metal.00085",
236
+ "metal.00087",
237
+ "metal.00088",
238
+ "pop.00000",
239
+ "pop.00001",
240
+ "pop.00013",
241
+ "pop.00014",
242
+ "pop.00043",
243
+ "pop.00063",
244
+ "pop.00064",
245
+ "pop.00065",
246
+ "pop.00066",
247
+ "pop.00069",
248
+ "pop.00070",
249
+ "pop.00071",
250
+ "pop.00072",
251
+ "pop.00073",
252
+ "pop.00074",
253
+ "pop.00075",
254
+ "pop.00076",
255
+ "pop.00077",
256
+ "pop.00078",
257
+ "pop.00079",
258
+ "pop.00082",
259
+ "pop.00088",
260
+ "pop.00089",
261
+ "pop.00090",
262
+ "pop.00091",
263
+ "pop.00092",
264
+ "pop.00093",
265
+ "pop.00094",
266
+ "pop.00095",
267
+ "pop.00096",
268
+ "reggae.00034",
269
+ "reggae.00035",
270
+ "reggae.00036",
271
+ "reggae.00037",
272
+ "reggae.00038",
273
+ "reggae.00039",
274
+ "reggae.00040",
275
+ "reggae.00046",
276
+ "reggae.00047",
277
+ "reggae.00048",
278
+ "reggae.00052",
279
+ "reggae.00053",
280
+ "reggae.00064",
281
+ "reggae.00065",
282
+ "reggae.00066",
283
+ "reggae.00067",
284
+ "reggae.00068",
285
+ "reggae.00071",
286
+ "reggae.00079",
287
+ "reggae.00082",
288
+ "reggae.00083",
289
+ "reggae.00084",
290
+ "reggae.00087",
291
+ "reggae.00088",
292
+ "reggae.00089",
293
+ "reggae.00090",
294
+ "rock.00010",
295
+ "rock.00011",
296
+ "rock.00012",
297
+ "rock.00013",
298
+ "rock.00014",
299
+ "rock.00015",
300
+ "rock.00027",
301
+ "rock.00028",
302
+ "rock.00029",
303
+ "rock.00030",
304
+ "rock.00031",
305
+ "rock.00032",
306
+ "rock.00033",
307
+ "rock.00034",
308
+ "rock.00035",
309
+ "rock.00036",
310
+ "rock.00037",
311
+ "rock.00039",
312
+ "rock.00040",
313
+ "rock.00041",
314
+ "rock.00042",
315
+ "rock.00043",
316
+ "rock.00044",
317
+ "rock.00045",
318
+ "rock.00046",
319
+ "rock.00047",
320
+ "rock.00048",
321
+ "rock.00086",
322
+ "rock.00087",
323
+ "rock.00088",
324
+ "rock.00089",
325
+ "rock.00090",
326
+ ]
327
+
328
+ filtered_train = [
329
+ "blues.00029",
330
+ "blues.00030",
331
+ "blues.00031",
332
+ "blues.00032",
333
+ "blues.00033",
334
+ "blues.00034",
335
+ "blues.00035",
336
+ "blues.00036",
337
+ "blues.00037",
338
+ "blues.00038",
339
+ "blues.00039",
340
+ "blues.00040",
341
+ "blues.00041",
342
+ "blues.00042",
343
+ "blues.00043",
344
+ "blues.00044",
345
+ "blues.00045",
346
+ "blues.00046",
347
+ "blues.00047",
348
+ "blues.00048",
349
+ "blues.00049",
350
+ "blues.00073",
351
+ "blues.00074",
352
+ "blues.00075",
353
+ "blues.00076",
354
+ "blues.00077",
355
+ "blues.00078",
356
+ "blues.00079",
357
+ "blues.00080",
358
+ "blues.00081",
359
+ "blues.00082",
360
+ "blues.00083",
361
+ "blues.00084",
362
+ "blues.00085",
363
+ "blues.00086",
364
+ "blues.00087",
365
+ "blues.00088",
366
+ "blues.00089",
367
+ "blues.00090",
368
+ "blues.00091",
369
+ "blues.00092",
370
+ "blues.00093",
371
+ "blues.00094",
372
+ "blues.00095",
373
+ "blues.00096",
374
+ "blues.00097",
375
+ "classical.00030",
376
+ "classical.00031",
377
+ "classical.00032",
378
+ "classical.00033",
379
+ "classical.00043",
380
+ "classical.00044",
381
+ "classical.00045",
382
+ "classical.00046",
383
+ "classical.00047",
384
+ "classical.00048",
385
+ "classical.00050",
386
+ "classical.00051",
387
+ "classical.00052",
388
+ "classical.00053",
389
+ "classical.00054",
390
+ "classical.00055",
391
+ "classical.00056",
392
+ "classical.00057",
393
+ "classical.00058",
394
+ "classical.00059",
395
+ "classical.00060",
396
+ "classical.00061",
397
+ "classical.00062",
398
+ "classical.00063",
399
+ "classical.00064",
400
+ "classical.00065",
401
+ "classical.00066",
402
+ "classical.00067",
403
+ "classical.00080",
404
+ "classical.00081",
405
+ "classical.00082",
406
+ "classical.00083",
407
+ "classical.00084",
408
+ "classical.00085",
409
+ "classical.00086",
410
+ "classical.00087",
411
+ "classical.00088",
412
+ "classical.00089",
413
+ "classical.00090",
414
+ "classical.00091",
415
+ "classical.00092",
416
+ "classical.00093",
417
+ "classical.00094",
418
+ "classical.00095",
419
+ "classical.00096",
420
+ "classical.00097",
421
+ "classical.00098",
422
+ "classical.00099",
423
+ "country.00019",
424
+ "country.00020",
425
+ "country.00021",
426
+ "country.00022",
427
+ "country.00023",
428
+ "country.00024",
429
+ "country.00025",
430
+ "country.00026",
431
+ "country.00028",
432
+ "country.00029",
433
+ "country.00065",
434
+ "country.00066",
435
+ "country.00067",
436
+ "country.00068",
437
+ "country.00069",
438
+ "country.00070",
439
+ "country.00071",
440
+ "country.00072",
441
+ "country.00073",
442
+ "country.00074",
443
+ "country.00075",
444
+ "country.00076",
445
+ "country.00077",
446
+ "country.00078",
447
+ "country.00079",
448
+ "country.00080",
449
+ "country.00081",
450
+ "country.00082",
451
+ "country.00083",
452
+ "country.00084",
453
+ "country.00085",
454
+ "country.00086",
455
+ "country.00087",
456
+ "country.00088",
457
+ "country.00089",
458
+ "country.00090",
459
+ "country.00091",
460
+ "country.00092",
461
+ "country.00093",
462
+ "country.00094",
463
+ "country.00095",
464
+ "country.00096",
465
+ "country.00097",
466
+ "country.00098",
467
+ "country.00099",
468
+ "disco.00005",
469
+ "disco.00015",
470
+ "disco.00016",
471
+ "disco.00017",
472
+ "disco.00018",
473
+ "disco.00019",
474
+ "disco.00020",
475
+ "disco.00022",
476
+ "disco.00023",
477
+ "disco.00024",
478
+ "disco.00025",
479
+ "disco.00026",
480
+ "disco.00027",
481
+ "disco.00028",
482
+ "disco.00029",
483
+ "disco.00030",
484
+ "disco.00031",
485
+ "disco.00032",
486
+ "disco.00033",
487
+ "disco.00034",
488
+ "disco.00035",
489
+ "disco.00036",
490
+ "disco.00037",
491
+ "disco.00039",
492
+ "disco.00040",
493
+ "disco.00041",
494
+ "disco.00042",
495
+ "disco.00043",
496
+ "disco.00044",
497
+ "disco.00045",
498
+ "disco.00047",
499
+ "disco.00049",
500
+ "disco.00053",
501
+ "disco.00054",
502
+ "disco.00056",
503
+ "disco.00057",
504
+ "disco.00059",
505
+ "disco.00061",
506
+ "disco.00070",
507
+ "disco.00073",
508
+ "disco.00074",
509
+ "disco.00089",
510
+ "hiphop.00002",
511
+ "hiphop.00003",
512
+ "hiphop.00004",
513
+ "hiphop.00005",
514
+ "hiphop.00006",
515
+ "hiphop.00007",
516
+ "hiphop.00008",
517
+ "hiphop.00009",
518
+ "hiphop.00010",
519
+ "hiphop.00011",
520
+ "hiphop.00012",
521
+ "hiphop.00013",
522
+ "hiphop.00014",
523
+ "hiphop.00015",
524
+ "hiphop.00016",
525
+ "hiphop.00017",
526
+ "hiphop.00018",
527
+ "hiphop.00019",
528
+ "hiphop.00020",
529
+ "hiphop.00021",
530
+ "hiphop.00022",
531
+ "hiphop.00023",
532
+ "hiphop.00024",
533
+ "hiphop.00025",
534
+ "hiphop.00028",
535
+ "hiphop.00029",
536
+ "hiphop.00031",
537
+ "hiphop.00032",
538
+ "hiphop.00033",
539
+ "hiphop.00034",
540
+ "hiphop.00035",
541
+ "hiphop.00036",
542
+ "hiphop.00037",
543
+ "hiphop.00038",
544
+ "hiphop.00041",
545
+ "hiphop.00042",
546
+ "hiphop.00055",
547
+ "hiphop.00056",
548
+ "hiphop.00057",
549
+ "hiphop.00058",
550
+ "hiphop.00059",
551
+ "hiphop.00060",
552
+ "hiphop.00061",
553
+ "hiphop.00077",
554
+ "hiphop.00078",
555
+ "hiphop.00079",
556
+ "hiphop.00080",
557
+ "jazz.00000",
558
+ "jazz.00001",
559
+ "jazz.00011",
560
+ "jazz.00012",
561
+ "jazz.00013",
562
+ "jazz.00014",
563
+ "jazz.00015",
564
+ "jazz.00016",
565
+ "jazz.00017",
566
+ "jazz.00018",
567
+ "jazz.00019",
568
+ "jazz.00020",
569
+ "jazz.00021",
570
+ "jazz.00022",
571
+ "jazz.00023",
572
+ "jazz.00024",
573
+ "jazz.00041",
574
+ "jazz.00047",
575
+ "jazz.00048",
576
+ "jazz.00049",
577
+ "jazz.00050",
578
+ "jazz.00051",
579
+ "jazz.00052",
580
+ "jazz.00053",
581
+ "jazz.00054",
582
+ "jazz.00055",
583
+ "jazz.00056",
584
+ "jazz.00057",
585
+ "jazz.00058",
586
+ "jazz.00059",
587
+ "jazz.00060",
588
+ "jazz.00061",
589
+ "jazz.00062",
590
+ "jazz.00063",
591
+ "jazz.00064",
592
+ "jazz.00065",
593
+ "jazz.00066",
594
+ "jazz.00067",
595
+ "jazz.00068",
596
+ "jazz.00069",
597
+ "jazz.00070",
598
+ "jazz.00071",
599
+ "jazz.00072",
600
+ "metal.00002",
601
+ "metal.00003",
602
+ "metal.00005",
603
+ "metal.00021",
604
+ "metal.00024",
605
+ "metal.00035",
606
+ "metal.00046",
607
+ "metal.00047",
608
+ "metal.00048",
609
+ "metal.00049",
610
+ "metal.00050",
611
+ "metal.00051",
612
+ "metal.00052",
613
+ "metal.00053",
614
+ "metal.00054",
615
+ "metal.00055",
616
+ "metal.00056",
617
+ "metal.00057",
618
+ "metal.00059",
619
+ "metal.00060",
620
+ "metal.00061",
621
+ "metal.00062",
622
+ "metal.00063",
623
+ "metal.00064",
624
+ "metal.00065",
625
+ "metal.00066",
626
+ "metal.00069",
627
+ "metal.00071",
628
+ "metal.00072",
629
+ "metal.00079",
630
+ "metal.00080",
631
+ "metal.00084",
632
+ "metal.00086",
633
+ "metal.00089",
634
+ "metal.00090",
635
+ "metal.00091",
636
+ "metal.00092",
637
+ "metal.00093",
638
+ "metal.00094",
639
+ "metal.00095",
640
+ "metal.00096",
641
+ "metal.00097",
642
+ "metal.00098",
643
+ "metal.00099",
644
+ "pop.00002",
645
+ "pop.00003",
646
+ "pop.00004",
647
+ "pop.00005",
648
+ "pop.00006",
649
+ "pop.00007",
650
+ "pop.00008",
651
+ "pop.00009",
652
+ "pop.00011",
653
+ "pop.00012",
654
+ "pop.00016",
655
+ "pop.00017",
656
+ "pop.00018",
657
+ "pop.00019",
658
+ "pop.00020",
659
+ "pop.00023",
660
+ "pop.00024",
661
+ "pop.00025",
662
+ "pop.00026",
663
+ "pop.00027",
664
+ "pop.00028",
665
+ "pop.00029",
666
+ "pop.00031",
667
+ "pop.00032",
668
+ "pop.00033",
669
+ "pop.00034",
670
+ "pop.00035",
671
+ "pop.00036",
672
+ "pop.00038",
673
+ "pop.00039",
674
+ "pop.00040",
675
+ "pop.00041",
676
+ "pop.00042",
677
+ "pop.00044",
678
+ "pop.00046",
679
+ "pop.00049",
680
+ "pop.00050",
681
+ "pop.00080",
682
+ "pop.00097",
683
+ "pop.00098",
684
+ "pop.00099",
685
+ "reggae.00000",
686
+ "reggae.00001",
687
+ "reggae.00002",
688
+ "reggae.00004",
689
+ "reggae.00006",
690
+ "reggae.00009",
691
+ "reggae.00011",
692
+ "reggae.00012",
693
+ "reggae.00014",
694
+ "reggae.00015",
695
+ "reggae.00016",
696
+ "reggae.00017",
697
+ "reggae.00018",
698
+ "reggae.00019",
699
+ "reggae.00020",
700
+ "reggae.00021",
701
+ "reggae.00022",
702
+ "reggae.00023",
703
+ "reggae.00024",
704
+ "reggae.00025",
705
+ "reggae.00026",
706
+ "reggae.00027",
707
+ "reggae.00028",
708
+ "reggae.00029",
709
+ "reggae.00030",
710
+ "reggae.00031",
711
+ "reggae.00032",
712
+ "reggae.00042",
713
+ "reggae.00043",
714
+ "reggae.00044",
715
+ "reggae.00045",
716
+ "reggae.00049",
717
+ "reggae.00050",
718
+ "reggae.00051",
719
+ "reggae.00054",
720
+ "reggae.00055",
721
+ "reggae.00056",
722
+ "reggae.00057",
723
+ "reggae.00058",
724
+ "reggae.00059",
725
+ "reggae.00060",
726
+ "reggae.00063",
727
+ "reggae.00069",
728
+ "rock.00000",
729
+ "rock.00001",
730
+ "rock.00002",
731
+ "rock.00003",
732
+ "rock.00004",
733
+ "rock.00005",
734
+ "rock.00006",
735
+ "rock.00007",
736
+ "rock.00008",
737
+ "rock.00009",
738
+ "rock.00016",
739
+ "rock.00017",
740
+ "rock.00018",
741
+ "rock.00019",
742
+ "rock.00020",
743
+ "rock.00021",
744
+ "rock.00022",
745
+ "rock.00023",
746
+ "rock.00024",
747
+ "rock.00025",
748
+ "rock.00026",
749
+ "rock.00057",
750
+ "rock.00058",
751
+ "rock.00059",
752
+ "rock.00060",
753
+ "rock.00061",
754
+ "rock.00062",
755
+ "rock.00063",
756
+ "rock.00064",
757
+ "rock.00065",
758
+ "rock.00066",
759
+ "rock.00067",
760
+ "rock.00068",
761
+ "rock.00069",
762
+ "rock.00070",
763
+ "rock.00091",
764
+ "rock.00092",
765
+ "rock.00093",
766
+ "rock.00094",
767
+ "rock.00095",
768
+ "rock.00096",
769
+ "rock.00097",
770
+ "rock.00098",
771
+ "rock.00099",
772
+ ]
773
+
774
+ filtered_valid = [
775
+ "blues.00000",
776
+ "blues.00001",
777
+ "blues.00002",
778
+ "blues.00003",
779
+ "blues.00004",
780
+ "blues.00005",
781
+ "blues.00006",
782
+ "blues.00007",
783
+ "blues.00008",
784
+ "blues.00009",
785
+ "blues.00010",
786
+ "blues.00011",
787
+ "blues.00050",
788
+ "blues.00051",
789
+ "blues.00052",
790
+ "blues.00053",
791
+ "blues.00054",
792
+ "blues.00055",
793
+ "blues.00056",
794
+ "blues.00057",
795
+ "blues.00058",
796
+ "blues.00059",
797
+ "blues.00060",
798
+ "classical.00000",
799
+ "classical.00001",
800
+ "classical.00002",
801
+ "classical.00003",
802
+ "classical.00004",
803
+ "classical.00005",
804
+ "classical.00006",
805
+ "classical.00007",
806
+ "classical.00008",
807
+ "classical.00009",
808
+ "classical.00010",
809
+ "classical.00068",
810
+ "classical.00069",
811
+ "classical.00070",
812
+ "classical.00071",
813
+ "classical.00072",
814
+ "classical.00073",
815
+ "classical.00074",
816
+ "classical.00075",
817
+ "classical.00076",
818
+ "country.00000",
819
+ "country.00001",
820
+ "country.00002",
821
+ "country.00003",
822
+ "country.00004",
823
+ "country.00005",
824
+ "country.00006",
825
+ "country.00007",
826
+ "country.00009",
827
+ "country.00010",
828
+ "country.00011",
829
+ "country.00012",
830
+ "country.00013",
831
+ "country.00014",
832
+ "country.00015",
833
+ "country.00016",
834
+ "country.00017",
835
+ "country.00018",
836
+ "country.00027",
837
+ "country.00041",
838
+ "country.00042",
839
+ "country.00045",
840
+ "country.00049",
841
+ "disco.00000",
842
+ "disco.00002",
843
+ "disco.00003",
844
+ "disco.00004",
845
+ "disco.00006",
846
+ "disco.00007",
847
+ "disco.00008",
848
+ "disco.00009",
849
+ "disco.00010",
850
+ "disco.00011",
851
+ "disco.00012",
852
+ "disco.00013",
853
+ "disco.00014",
854
+ "disco.00046",
855
+ "disco.00048",
856
+ "disco.00052",
857
+ "disco.00067",
858
+ "disco.00068",
859
+ "disco.00072",
860
+ "disco.00075",
861
+ "disco.00090",
862
+ "disco.00095",
863
+ "hiphop.00081",
864
+ "hiphop.00082",
865
+ "hiphop.00083",
866
+ "hiphop.00084",
867
+ "hiphop.00085",
868
+ "hiphop.00086",
869
+ "hiphop.00087",
870
+ "hiphop.00088",
871
+ "hiphop.00089",
872
+ "hiphop.00090",
873
+ "hiphop.00091",
874
+ "hiphop.00092",
875
+ "hiphop.00093",
876
+ "hiphop.00094",
877
+ "hiphop.00095",
878
+ "hiphop.00096",
879
+ "hiphop.00097",
880
+ "hiphop.00098",
881
+ "jazz.00002",
882
+ "jazz.00003",
883
+ "jazz.00004",
884
+ "jazz.00005",
885
+ "jazz.00006",
886
+ "jazz.00007",
887
+ "jazz.00008",
888
+ "jazz.00009",
889
+ "jazz.00010",
890
+ "jazz.00025",
891
+ "jazz.00026",
892
+ "jazz.00027",
893
+ "jazz.00028",
894
+ "jazz.00029",
895
+ "jazz.00030",
896
+ "jazz.00031",
897
+ "jazz.00032",
898
+ "metal.00000",
899
+ "metal.00001",
900
+ "metal.00006",
901
+ "metal.00007",
902
+ "metal.00008",
903
+ "metal.00009",
904
+ "metal.00010",
905
+ "metal.00011",
906
+ "metal.00016",
907
+ "metal.00017",
908
+ "metal.00018",
909
+ "metal.00019",
910
+ "metal.00020",
911
+ "metal.00036",
912
+ "metal.00037",
913
+ "metal.00068",
914
+ "metal.00076",
915
+ "metal.00077",
916
+ "metal.00081",
917
+ "metal.00082",
918
+ "pop.00010",
919
+ "pop.00053",
920
+ "pop.00055",
921
+ "pop.00058",
922
+ "pop.00059",
923
+ "pop.00060",
924
+ "pop.00061",
925
+ "pop.00062",
926
+ "pop.00081",
927
+ "pop.00083",
928
+ "pop.00084",
929
+ "pop.00085",
930
+ "pop.00086",
931
+ "reggae.00061",
932
+ "reggae.00062",
933
+ "reggae.00070",
934
+ "reggae.00072",
935
+ "reggae.00074",
936
+ "reggae.00076",
937
+ "reggae.00077",
938
+ "reggae.00078",
939
+ "reggae.00085",
940
+ "reggae.00092",
941
+ "reggae.00093",
942
+ "reggae.00094",
943
+ "reggae.00095",
944
+ "reggae.00096",
945
+ "reggae.00097",
946
+ "reggae.00098",
947
+ "reggae.00099",
948
+ "rock.00038",
949
+ "rock.00049",
950
+ "rock.00050",
951
+ "rock.00051",
952
+ "rock.00052",
953
+ "rock.00053",
954
+ "rock.00054",
955
+ "rock.00055",
956
+ "rock.00056",
957
+ "rock.00071",
958
+ "rock.00072",
959
+ "rock.00073",
960
+ "rock.00074",
961
+ "rock.00075",
962
+ "rock.00076",
963
+ "rock.00077",
964
+ "rock.00078",
965
+ "rock.00079",
966
+ "rock.00080",
967
+ "rock.00081",
968
+ "rock.00082",
969
+ "rock.00083",
970
+ "rock.00084",
971
+ "rock.00085",
972
+ ]
973
+
974
+
975
# Default download location of the dataset archive.
URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz"
# Top-level directory created when the archive is extracted.
FOLDER_IN_ARCHIVE = "genres"
# Archive checksums keyed by URL, passed to ``download_url_to_file`` as
# ``hash_prefix`` (64 hex chars — presumably SHA-256; TODO confirm).
_CHECKSUMS = {
    "http://opihi.cs.uvic.ca/sound/genres.tar.gz": "24347e0223d2ba798e0a558c4c172d9d4a19c00bb7963fe055d183dadb4ef2c6"
}
980
+
981
+
982
def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, int, str]:
    """Load one GTZAN clip.

    Loads a file from the dataset and returns the raw waveform
    as a Torch Tensor, its sample rate as an integer, and its
    genre as a string.

    Args:
        fileid (str): File identifier of the form ``<genre>.<id>``, e.g. ``blues.00078``.
        path (str): Root directory of the extracted dataset.
        ext_audio (str): Audio file extension, including the leading dot.

    Returns:
        (Tensor, int, str): Waveform, sample rate, and genre label.
    """
    # Filenames are of the form label.id, e.g. blues.00078
    label, _ = fileid.split(".")

    # Clips live under a per-genre subdirectory named after the label.
    file_audio = os.path.join(path, label, fileid + ext_audio)
    waveform, sample_rate = torchaudio.load(file_audio)

    return waveform, sample_rate, label
996
+
997
+
998
class GTZAN(Dataset):
    """*GTZAN* :cite:`tzanetakis_essl_cook_2001` dataset.

    Note:
        Please see http://marsyas.info/downloads/datasets.html if you are planning to use
        this dataset to publish results.

    Note:
        As of October 2022, the download link is not currently working. Setting ``download=True``
        in GTZAN dataset will result in a URL connection error.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from.
            (default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
        folder_in_archive (str, optional): The top-level directory of the dataset.
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        subset (str or None, optional): Which subset of the dataset to use.
            One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``.
            If ``None``, the entire dataset is used. (default: ``None``).
    """

    # All clips in the dataset are stored as ".wav" files.
    _ext_audio = ".wav"

    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
        subset: Optional[str] = None,
    ) -> None:

        # Get string representation of 'root' in case Path object is passed
        root = os.fspath(root)

        self.root = root
        self.url = url
        self.folder_in_archive = folder_in_archive
        self.download = download
        self.subset = subset

        if subset is not None and subset not in ["training", "validation", "testing"]:
            raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].")

        # The archive is downloaded next to the extracted dataset directory.
        archive = os.path.basename(url)
        archive = os.path.join(root, archive)
        self._path = os.path.join(root, folder_in_archive)

        if download:
            # Download and extract only when the dataset directory is absent;
            # an already-present archive is reused without re-downloading.
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url_to_file(url, archive, hash_prefix=checksum)
                _extract_tar(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found. Please use `download=True` to download it.")

        if self.subset is None:
            # Check every subdirectory under dataset root
            # which has the same name as the genres in
            # GTZAN (e.g. `root_dir'/blues/, `root_dir'/rock, etc.)
            # This lets users remove or move around song files,
            # useful when e.g. they want to use only some of the files
            # in a genre or want to label other files with a different
            # genre.
            self._walker = []

            root = os.path.expanduser(self._path)

            for directory in gtzan_genres:
                fulldir = os.path.join(root, directory)

                if not os.path.exists(fulldir):
                    continue

                # Sort for a deterministic sample order across filesystems.
                songs_in_genre = os.listdir(fulldir)
                songs_in_genre.sort()
                for fname in songs_in_genre:
                    name, ext = os.path.splitext(fname)
                    if ext.lower() == ".wav" and "." in name:
                        # Check whether the file is of the form
                        # `gtzan_genre`.`5 digit number`.wav
                        genre, num = name.split(".")
                        if genre in gtzan_genres and len(num) == 5 and num.isdigit():
                            self._walker.append(name)
        else:
            # Filtered splits are fixed file-id lists (module-level
            # `filtered_*`) that mitigate known duplication in GTZAN.
            if self.subset == "training":
                self._walker = filtered_train
            elif self.subset == "validation":
                self._walker = filtered_valid
            elif self.subset == "testing":
                self._walker = filtered_test

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Label
        """
        fileid = self._walker[n]
        item = load_gtzan_item(fileid, self._path, self._ext_audio)
        waveform, sample_rate, label = item
        return waveform, sample_rate, label

    def __len__(self) -> int:
        """Return the number of clips visible through the selected subset."""
        return len(self._walker)
.venv/lib/python3.11/site-packages/torchaudio/datasets/iemocap.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from pathlib import Path
4
+ from typing import Optional, Tuple, Union
5
+
6
+ from torch import Tensor
7
+ from torch.utils.data import Dataset
8
+ from torchaudio.datasets.utils import _load_waveform
9
+
10
+
11
+ _SAMPLE_RATE = 16000
12
+
13
+
14
+ def _get_wavs_paths(data_dir):
15
+ wav_dir = data_dir / "sentences" / "wav"
16
+ wav_paths = sorted(str(p) for p in wav_dir.glob("*/*.wav"))
17
+ relative_paths = []
18
+ for wav_path in wav_paths:
19
+ start = wav_path.find("Session")
20
+ wav_path = wav_path[start:]
21
+ relative_paths.append(wav_path)
22
+ return relative_paths
23
+
24
+
25
class IEMOCAP(Dataset):
    """*IEMOCAP* :cite:`iemocap` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found
        sessions (Tuple[int]): Tuple of sessions (1-5) to use. (Default: ``(1, 2, 3, 4, 5)``)
        utterance_type (str or None, optional): Which type(s) of utterances to include in the dataset.
            Options: ("scripted", "improvised", ``None``). If ``None``, both scripted and improvised
            data are used.
    """

    def __init__(
        self,
        root: Union[str, Path],
        sessions: Tuple[int, ...] = (1, 2, 3, 4, 5),
        utterance_type: Optional[str] = None,
    ):
        root = Path(root)
        self._path = root / "IEMOCAP"

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found.")

        if utterance_type not in ["scripted", "improvised", None]:
            raise ValueError("utterance_type must be one of ['scripted', 'improvised', or None]")

        all_data = []        # stems of every wav file found on disk, across sessions
        self.data = []       # stems that also received a usable emotion label
        self.mapping = {}    # stem -> {"label": ..., "path": ...}

        for session in sessions:
            session_name = f"Session{session}"
            session_dir = self._path / session_name

            # get wav paths (relative to the IEMOCAP root)
            wav_paths = _get_wavs_paths(session_dir)
            for wav_path in wav_paths:
                wav_stem = str(Path(wav_path).stem)
                all_data.append(wav_stem)

            # add labels from the EmoEvaluation annotation files; the glob
            # pattern optionally restricts to scripted or improvised dialogs
            label_dir = session_dir / "dialog" / "EmoEvaluation"
            query = "*.txt"
            if utterance_type == "scripted":
                query = "*script*.txt"
            elif utterance_type == "improvised":
                query = "*impro*.txt"
            label_paths = label_dir.glob(query)

            for label_path in label_paths:
                with open(label_path, "r") as f:
                    for line in f:
                        # only lines beginning with "[" carry a labeled segment
                        if not line.startswith("["):
                            continue
                        line = re.split("[\t\n]", line)
                        wav_stem = line[1]
                        label = line[2]
                        # skip labels whose audio file was not found on disk
                        if wav_stem not in all_data:
                            continue
                        # keep only the six emotion classes exposed by this dataset
                        if label not in ["neu", "hap", "ang", "sad", "exc", "fru"]:
                            continue
                        self.mapping[wav_stem] = {}
                        self.mapping[wav_stem]["label"] = label

            # record, in wav-path order, the utterances of this session that
            # ended up with a label, and remember each one's relative path
            for wav_path in wav_paths:
                wav_stem = str(Path(wav_path).stem)
                if wav_stem in self.mapping:
                    self.data.append(wav_stem)
                    self.mapping[wav_stem]["path"] = wav_path

    def get_metadata(self, n: int) -> Tuple[str, int, str, str, str]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:meth:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            str:
                Path to audio
            int:
                Sample rate
            str:
                File name
            str:
                Label (one of ``"neu"``, ``"hap"``, ``"ang"``, ``"sad"``, ``"exc"``, ``"fru"``)
            str:
                Speaker
        """
        wav_stem = self.data[n]
        wav_path = self.mapping[wav_stem]["path"]
        label = self.mapping[wav_stem]["label"]
        # speaker id is the leading underscore-separated token of the file stem
        speaker = wav_stem.split("_")[0]
        return (wav_path, _SAMPLE_RATE, wav_stem, label, speaker)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                File name
            str:
                Label (one of ``"neu"``, ``"hap"``, ``"ang"``, ``"sad"``, ``"exc"``, ``"fru"``)
            str:
                Speaker
        """
        metadata = self.get_metadata(n)
        waveform = _load_waveform(self._path, metadata[0], metadata[1])
        return (waveform,) + metadata[1:]

    def __len__(self):
        # one entry per labeled utterance
        return len(self.data)
.venv/lib/python3.11/site-packages/torchaudio/datasets/librilight_limited.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Union
4
+
5
+ import torchaudio
6
+ from torch import Tensor
7
+ from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
+ from torchaudio.datasets.librispeech import _get_librispeech_metadata
10
+ from torchaudio.datasets.utils import _extract_tar
11
+
12
+
13
# Name of the top-level directory produced by extracting the archive.
_ARCHIVE_NAME = "librispeech_finetuning"
# Download location of the Libri-light fine-tuning subset archive.
_URL = "https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz"
# Hash prefix used to validate the downloaded archive.
_CHECKSUM = "5d1efdc777b548194d7e09ba89126e2188026df9fd57aa57eb14408d2b2342af"
# Maps each subset name to the folder glob(s) it is built from; the 10h
# split is the union of all "1h" folders plus the "9h" folder.
_SUBSET_MAP = {"10min": ["1h/0"], "1h": ["1h/*"], "10h": ["1h/*", "9h"]}
17
+
18
+
19
+ def _get_fileids_paths(path: Path, folders: List[str], _ext_audio: str) -> List[Tuple[str, str]]:
20
+ """Get the file names and the corresponding file paths without `speaker_id`
21
+ and `chapter_id` directories.
22
+ The format of path is like:
23
+ {root}/{_ARCHIVE_NAME}/1h/[0-5]/[clean, other] or
24
+ {root}/{_ARCHIVE_NAME}/9h/[clean, other]
25
+
26
+ Args:
27
+ path (Path): Root path to the dataset.
28
+ folders (List[str]): Folders that contain the desired audio files.
29
+ _ext_audio (str): Extension of audio files.
30
+
31
+ Returns:
32
+ List[Tuple[str, str]]:
33
+ List of tuples where the first element is the relative path to the audio file.
34
+ The format of relative path is like:
35
+ 1h/[0-5]/[clean, other] or 9h/[clean, other]
36
+ The second element is the file name without audio extension.
37
+ """
38
+
39
+ path = Path(path)
40
+ files_paths = []
41
+ for folder in folders:
42
+ paths = [p.relative_to(path) for p in path.glob(f"{folder}/*/*/*/*{_ext_audio}")]
43
+ files_paths += [(str(p.parent.parent.parent), str(p.stem)) for p in paths] # get subset folder and file name
44
+ files_paths.sort(key=lambda x: x[0] + x[1])
45
+ return files_paths
46
+
47
+
48
class LibriLightLimited(Dataset):
    """Subset of Libri-light :cite:`librilight` dataset,
    which was used in HuBERT :cite:`hsu2021hubert` for supervised fine-tuning.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        subset (str, optional): The subset to use. Options: [``"10min"``, ``"1h"``, ``"10h"``]
            (Default: ``"10min"``).
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
    """

    # Transcript and audio file extensions used throughout the corpus.
    _ext_txt = ".trans.txt"
    _ext_audio = ".flac"

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "10min",
        download: bool = False,
    ) -> None:
        if subset not in _SUBSET_MAP:
            raise ValueError(f"`subset` must be one of {_SUBSET_MAP.keys()}. Found: {subset}")
        folders = _SUBSET_MAP[subset]

        root = os.fspath(root)
        self._path = os.path.join(root, _ARCHIVE_NAME)
        archive = os.path.join(root, f"{_ARCHIVE_NAME}.tgz")
        # Download and extract only when the extracted directory is absent;
        # an already-present archive is reused rather than re-downloaded.
        if not os.path.isdir(self._path):
            if not download:
                raise RuntimeError("Dataset not found. Please use `download=True` to download")
            if not os.path.isfile(archive):
                download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
            _extract_tar(archive)
        # Pairs of (subset folder, file id) for every audio file of the subset.
        self._fileids_paths = _get_fileids_paths(self._path, folders, self._ext_audio)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded
        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Transcript
            int:
                Speaker ID
            int:
                Chapter ID
            int:
                Utterance ID
        """
        file_path, fileid = self._fileids_paths[n]
        # Metadata resolution (path + transcript lookup) is shared with the
        # LibriSpeech dataset implementation.
        metadata = _get_librispeech_metadata(fileid, self._path, file_path, self._ext_audio, self._ext_txt)
        waveform, _ = torchaudio.load(os.path.join(self._path, metadata[0]))
        return (waveform,) + metadata[1:]

    def __len__(self) -> int:
        # One entry per discovered audio file.
        return len(self._fileids_paths)
.venv/lib/python3.11/site-packages/torchaudio/datasets/librimix.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Union
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ from torchaudio.datasets.utils import _load_waveform
8
+
9
# Maps each supported task to the mixture subdirectory it reads from.
_TASKS_TO_MIXTURE = {
    "sep_clean": "mix_clean",
    "enh_single": "mix_single",
    "enh_both": "mix_both",
    "sep_noisy": "mix_both",
}


class LibriMix(Dataset):
    r"""*LibriMix* :cite:`cosentino2020librimix` dataset.

    Args:
        root (str or Path): The path where the directory ``Libri2Mix`` or
            ``Libri3Mix`` is stored. Not the path of those directories.
        subset (str, optional): The subset to use. Options: [``"train-360"``, ``"train-100"``,
            ``"dev"``, and ``"test"``] (Default: ``"train-360"``).
        num_speakers (int, optional): The number of speakers, which determines the directories
            to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
            N source audios. (Default: 2)
        sample_rate (int, optional): Sample rate of audio files. The ``sample_rate`` determines
            which subdirectory the audio are fetched. If any of the audio has a different sample
            rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000)
        task (str, optional): The task of LibriMix.
            Options: [``"enh_single"``, ``"enh_both"``, ``"sep_clean"``, ``"sep_noisy"``]
            (Default: ``"sep_clean"``)
        mode (str, optional): The mode when creating the mixture. If set to ``"min"``, the lengths of mixture
            and sources are the minimum length of all sources. If set to ``"max"``, the lengths of mixture and
            sources are zero padded to the maximum length of all sources.
            Options: [``"min"``, ``"max"``]
            (Default: ``"min"``)

    Raises:
        RuntimeError: If the ``Libri{num_speakers}Mix`` directory is not found under ``root``.
        ValueError: If ``mode``, ``sample_rate`` or ``task`` is not one of the supported options.

    Note:
        The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "train-360",
        num_speakers: int = 2,
        sample_rate: int = 8000,
        task: str = "sep_clean",
        mode: str = "min",
    ):
        self.root = Path(root) / f"Libri{num_speakers}Mix"
        if not os.path.exists(self.root):
            raise RuntimeError(
                f"The path {self.root} doesn't exist. "
                "Please check the ``root`` path and ``num_speakers`` or download the dataset manually."
            )
        if mode not in ["max", "min"]:
            raise ValueError(f'Expect ``mode`` to be one in ["min", "max"]. Found {mode}.')
        # The sample rate selects the wav8k/wav16k subtree.
        if sample_rate == 8000:
            mix_dir = self.root / "wav8k" / mode / subset
        elif sample_rate == 16000:
            mix_dir = self.root / "wav16k" / mode / subset
        else:
            raise ValueError(f"Unsupported sample rate. Found {sample_rate}.")
        self.sample_rate = sample_rate
        self.task = task

        # Validate ``task`` explicitly (consistent with the ``mode`` and
        # ``sample_rate`` checks above) instead of letting the dictionary
        # lookup below raise an opaque ``KeyError``.
        if task not in _TASKS_TO_MIXTURE:
            raise ValueError(f"Unsupported task. Expect one of {sorted(_TASKS_TO_MIXTURE)}. Found {task}.")

        self.mix_dir = mix_dir / _TASKS_TO_MIXTURE[task]
        # For "enh_both" the only target is the clean mixture itself;
        # otherwise one source directory per speaker (s1..sN).
        if task == "enh_both":
            self.src_dirs = [(mix_dir / "mix_clean")]
        else:
            self.src_dirs = [(mix_dir / f"s{i+1}") for i in range(num_speakers)]

        self.files = [p.name for p in self.mix_dir.glob("*.wav")]
        self.files.sort()

    def _load_sample(self, key) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
        """Load the mixture and all source waveforms for index *key*."""
        metadata = self.get_metadata(key)
        mixed = _load_waveform(self.root, metadata[1], metadata[0])
        srcs = []
        for i, path_ in enumerate(metadata[2]):
            src = _load_waveform(self.root, path_, metadata[0])
            # All sources must align sample-for-sample with the mixture.
            if mixed.shape != src.shape:
                raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}")
            srcs.append(src)
        return self.sample_rate, mixed, srcs

    def get_metadata(self, key: int) -> Tuple[int, str, List[str]]:
        """Get metadata for the n-th sample from the dataset.

        Args:
            key (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            int:
                Sample rate
            str:
                Path to mixed audio
            List of str:
                List of paths to source audios
        """
        filename = self.files[key]
        # Paths are stored relative to ``self.root`` so the caller can
        # relocate the dataset without invalidating the metadata.
        mixed_path = os.path.relpath(self.mix_dir / filename, self.root)
        srcs_paths = []
        for dir_ in self.src_dirs:
            src = os.path.relpath(dir_ / filename, self.root)
            srcs_paths.append(src)
        return self.sample_rate, mixed_path, srcs_paths

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, key: int) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
        """Load the n-th sample from the dataset.

        Args:
            key (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            int:
                Sample rate
            Tensor:
                Mixture waveform
            List of Tensors:
                List of source waveforms
        """
        return self._load_sample(key)
.venv/lib/python3.11/site-packages/torchaudio/datasets/librispeech.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Tuple, Union
4
+
5
+ from torch import Tensor
6
+ from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
+ from torchaudio.datasets.utils import _extract_tar, _load_waveform
9
+
10
# Default subset name and the top-level directory inside each archive.
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech"
# All LibriSpeech audio is distributed at 16 kHz.
SAMPLE_RATE = 16000
# Subset names accepted by the ``url`` argument of LIBRISPEECH.
_DATA_SUBSETS = [
    "dev-clean",
    "dev-other",
    "test-clean",
    "test-other",
    "train-clean-100",
    "train-clean-360",
    "train-other-500",
]
# Hash prefixes for archive validation, keyed by full download URL.
_CHECKSUMS = {
    "http://www.openslr.org/resources/12/dev-clean.tar.gz": "76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3",  # noqa: E501
    "http://www.openslr.org/resources/12/dev-other.tar.gz": "12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365",  # noqa: E501
    "http://www.openslr.org/resources/12/test-clean.tar.gz": "39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23",  # noqa: E501
    "http://www.openslr.org/resources/12/test-other.tar.gz": "d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29",  # noqa: E501
    "http://www.openslr.org/resources/12/train-clean-100.tar.gz": "d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2",  # noqa: E501
    "http://www.openslr.org/resources/12/train-clean-360.tar.gz": "146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf",  # noqa: E501
    "http://www.openslr.org/resources/12/train-other-500.tar.gz": "ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2",  # noqa: E501
}
31
+
32
+
33
def _download_librispeech(root, url):
    """Download (if needed) and extract a LibriSpeech subset archive into *root*.

    Args:
        root (str): Directory where the ``.tar.gz`` archive is stored/extracted.
        url (str): Subset name, e.g. ``"train-clean-100"``.
    """
    archive_name = f"{url}.tar.gz"
    archive_path = os.path.join(root, archive_name)
    remote_url = os.path.join("http://www.openslr.org/resources/12/", archive_name)
    # Skip the network when the archive already exists on disk; the hash
    # prefix (when known) is verified by ``download_url_to_file``.
    if not os.path.isfile(archive_path):
        download_url_to_file(remote_url, archive_path, hash_prefix=_CHECKSUMS.get(remote_url))
    _extract_tar(archive_path)
+
45
+
46
def _get_librispeech_metadata(
    fileid: str, root: str, folder: str, ext_audio: str, ext_txt: str
) -> Tuple[str, int, str, int, int, int]:
    """Resolve the relative audio path and transcript for one LibriSpeech utterance.

    Args:
        fileid (str): Utterance id of the form ``"{speaker}-{chapter}-{utterance}"``.
        root (str): Dataset root (contains *folder*).
        folder (str): Subset folder, e.g. ``"train-clean-100"``.
        ext_audio (str): Audio file extension.
        ext_txt (str): Transcript file extension.

    Returns:
        Tuple[str, int, str, int, int, int]:
            relative audio path, sample rate, transcript, speaker id, chapter id, utterance id.

    Raises:
        FileNotFoundError: If the transcript file has no line for this utterance.
    """
    speaker_id, chapter_id, utterance_id = fileid.split("-")

    # Relative audio path: {folder}/{speaker}/{chapter}/{speaker}-{chapter}-{utterance}{ext}
    fileid_audio = f"{speaker_id}-{chapter_id}-{utterance_id}"
    filepath = os.path.join(folder, speaker_id, chapter_id, f"{fileid_audio}{ext_audio}")

    # Scan the per-chapter transcript file for the matching utterance line.
    transcript_path = os.path.join(root, folder, speaker_id, chapter_id, f"{speaker_id}-{chapter_id}{ext_txt}")
    with open(transcript_path) as ft:
        for line in ft:
            fileid_text, transcript = line.strip().split(" ", 1)
            if fileid_audio == fileid_text:
                break
        else:
            # Translation not found
            raise FileNotFoundError(f"Translation not found for {fileid_audio}")

    return (
        filepath,
        SAMPLE_RATE,
        transcript,
        int(speaker_id),
        int(chapter_id),
        int(utterance_id),
    )
75
+
76
+
77
class LIBRISPEECH(Dataset):
    """*LibriSpeech* :cite:`7178964` dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from,
            or the type of the dataset to dowload.
            Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
            ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and
            ``"train-other-500"``. (default: ``"train-clean-100"``)
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"LibriSpeech"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
    """

    _ext_txt = ".trans.txt"
    _ext_audio = ".flac"

    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
    ) -> None:
        if url not in _DATA_SUBSETS:
            raise ValueError(f"Invalid url '{url}' given; please provide one of {_DATA_SUBSETS}.")
        self._url = url

        root = os.fspath(root)
        self._archive = os.path.join(root, folder_in_archive)
        self._path = os.path.join(root, folder_in_archive, url)

        # Fetch and extract the archive only when the subset directory is missing.
        if not os.path.isdir(self._path):
            if not download:
                raise RuntimeError(
                    f"Dataset not found at {self._path}. Please set `download=True` to download the dataset."
                )
            _download_librispeech(root, url)

        # Sorted file ids ("speaker-chapter-utterance") of every audio file.
        self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))

    def get_metadata(self, n: int) -> Tuple[str, int, str, int, int, int]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            str:
                Path to audio
            int:
                Sample rate
            str:
                Transcript
            int:
                Speaker ID
            int:
                Chapter ID
            int:
                Utterance ID
        """
        return _get_librispeech_metadata(
            self._walker[n], self._archive, self._url, self._ext_audio, self._ext_txt
        )

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Transcript
            int:
                Speaker ID
            int:
                Chapter ID
            int:
                Utterance ID
        """
        filepath, sample_rate, transcript, speaker_id, chapter_id, utterance_id = self.get_metadata(n)
        waveform = _load_waveform(self._archive, filepath, sample_rate)
        return waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id

    def __len__(self) -> int:
        """Return the number of utterances in the selected subset."""
        return len(self._walker)
+ return len(self._walker)
.venv/lib/python3.11/site-packages/torchaudio/datasets/librispeech_biasing.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Union
4
+
5
+ from torch import Tensor
6
+ from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
+ from torchaudio.datasets.utils import _extract_tar, _load_waveform
9
+
10
# Default subset name and the top-level directory inside each archive.
# NOTE(review): these constants mirror torchaudio.datasets.librispeech.
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech"
# All LibriSpeech audio is distributed at 16 kHz.
SAMPLE_RATE = 16000
# Subset names accepted by the ``url`` argument of LibriSpeechBiasing.
_DATA_SUBSETS = [
    "dev-clean",
    "dev-other",
    "test-clean",
    "test-other",
    "train-clean-100",
    "train-clean-360",
    "train-other-500",
]
# Hash prefixes for archive validation, keyed by full download URL.
_CHECKSUMS = {
    "http://www.openslr.org/resources/12/dev-clean.tar.gz": "76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3",  # noqa: E501
    "http://www.openslr.org/resources/12/dev-other.tar.gz": "12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365",  # noqa: E501
    "http://www.openslr.org/resources/12/test-clean.tar.gz": "39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23",  # noqa: E501
    "http://www.openslr.org/resources/12/test-other.tar.gz": "d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29",  # noqa: E501
    "http://www.openslr.org/resources/12/train-clean-100.tar.gz": "d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2",  # noqa: E501
    "http://www.openslr.org/resources/12/train-clean-360.tar.gz": "146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf",  # noqa: E501
    "http://www.openslr.org/resources/12/train-other-500.tar.gz": "ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2",  # noqa: E501
}
31
+
32
+
33
def _download_librispeech(root, url):
    """Fetch (if absent) and unpack the ``{url}.tar.gz`` LibriSpeech archive under *root*."""
    ext_archive = ".tar.gz"
    base_url = "http://www.openslr.org/resources/12/"

    filename = url + ext_archive
    target = os.path.join(root, filename)
    source = os.path.join(base_url, filename)
    if not os.path.isfile(target):
        # Only hit the network when the archive is not already on disk.
        download_url_to_file(source, target, hash_prefix=_CHECKSUMS.get(source, None))
    _extract_tar(target)
+
45
+
46
def _get_librispeech_metadata(
    fileid: str, root: str, folder: str, ext_audio: str, ext_txt: str, blist: List[str]
) -> Tuple[str, int, str, int, int, int, List[str]]:
    """Resolve path, transcript and per-utterance biasing words for one utterance.

    Args:
        fileid (str): Utterance id of the form ``"{speaker}-{chapter}-{utterance}"``.
        root (str): Dataset root (contains *folder*).
        folder (str): Subset folder, e.g. ``"train-clean-100"``.
        ext_audio (str): Audio file extension.
        ext_txt (str): Transcript file extension.
        blist (List[str]): Global biasing word list; ``None`` is treated as empty.

    Returns:
        Tuple of relative audio path, sample rate, transcript, speaker id,
        chapter id, utterance id, and the biasing words found in the transcript.

    Raises:
        FileNotFoundError: If the transcript file has no line for this utterance.
    """
    blist = blist or []
    speaker_id, chapter_id, utterance_id = fileid.split("-")

    # Get audio path and sample rate
    fileid_audio = f"{speaker_id}-{chapter_id}-{utterance_id}"
    filepath = os.path.join(folder, speaker_id, chapter_id, f"{fileid_audio}{ext_audio}")

    # Load text
    file_text = f"{speaker_id}-{chapter_id}{ext_txt}"
    file_text = os.path.join(root, folder, speaker_id, chapter_id, file_text)
    uttblist = []
    with open(file_text) as ft:
        for line in ft:
            fileid_text, transcript = line.strip().split(" ", 1)
            if fileid_audio == fileid_text:
                # get utterance biasing list (deduplicated, transcript order)
                for word in transcript.split():
                    if word in blist and word not in uttblist:
                        uttblist.append(word)
                break
        else:
            # Translation not found
            raise FileNotFoundError(f"Translation not found for {fileid_audio}")

    return (
        filepath,
        SAMPLE_RATE,
        transcript,
        int(speaker_id),
        int(chapter_id),
        int(utterance_id),
        uttblist,
    )
+
83
+
84
class LibriSpeechBiasing(Dataset):
    """*LibriSpeech* :cite:`7178964` dataset with prefix-tree construction and biasing support.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from,
            or the type of the dataset to dowload.
            Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
            ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and
            ``"train-other-500"``. (default: ``"train-clean-100"``)
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"LibriSpeech"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        blist (list, optional):
            The list of biasing words (default: ``[]``).
    """

    _ext_txt = ".trans.txt"
    _ext_audio = ".flac"

    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
        blist: Union[List[str], None] = None,
    ) -> None:
        self._url = url
        if url not in _DATA_SUBSETS:
            raise ValueError(f"Invalid url '{url}' given; please provide one of {_DATA_SUBSETS}.")

        root = os.fspath(root)
        self._archive = os.path.join(root, folder_in_archive)
        self._path = os.path.join(root, folder_in_archive, url)

        # Download and extract only when the subset directory is missing.
        if not os.path.isdir(self._path):
            if download:
                _download_librispeech(root, url)
            else:
                raise RuntimeError(
                    f"Dataset not found at {self._path}. Please set `download=True` to download the dataset."
                )

        # Sorted file ids ("speaker-chapter-utterance") of every audio file.
        self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))
        # Global biasing word list; may be None (treated as empty downstream).
        self.blist = blist

    def get_metadata(self, n: int) -> Tuple[str, int, str, int, int, int, List[str]]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            str:
                Path to audio
            int:
                Sample rate
            str:
                Transcript
            int:
                Speaker ID
            int:
                Chapter ID
            int:
                Utterance ID
            list:
                List of biasing words in the utterance
        """
        fileid = self._walker[n]
        return _get_librispeech_metadata(fileid, self._archive, self._url, self._ext_audio, self._ext_txt, self.blist)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int, List[str]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Transcript
            int:
                Speaker ID
            int:
                Chapter ID
            int:
                Utterance ID
            list:
                List of biasing words in the utterance
        """
        metadata = self.get_metadata(n)
        waveform = _load_waveform(self._archive, metadata[0], metadata[1])
        return (waveform,) + metadata[1:]

    def __len__(self) -> int:
        # One entry per discovered audio file.
        return len(self._walker)
.venv/lib/python3.11/site-packages/torchaudio/datasets/libritts.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Tuple, Union
4
+
5
+ import torchaudio
6
+ from torch import Tensor
7
+ from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
+ from torchaudio.datasets.utils import _extract_tar
10
+
11
# Default subset name and the top-level directory of the extracted dataset.
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriTTS"
# Hash prefixes for archive validation, keyed by full download URL.
_CHECKSUMS = {
    "http://www.openslr.org/resources/60/dev-clean.tar.gz": "da0864e1bd26debed35da8a869dd5c04dfc27682921936de7cff9c8a254dbe1a",  # noqa: E501
    "http://www.openslr.org/resources/60/dev-other.tar.gz": "d413eda26f3a152ac7c9cf3658ef85504dfb1b625296e5fa83727f5186cca79c",  # noqa: E501
    "http://www.openslr.org/resources/60/test-clean.tar.gz": "234ea5b25859102a87024a4b9b86641f5b5aaaf1197335c95090cde04fe9a4f5",  # noqa: E501
    "http://www.openslr.org/resources/60/test-other.tar.gz": "33a5342094f3bba7ccc2e0500b9e72d558f72eb99328ac8debe1d9080402f10d",  # noqa: E501
    "http://www.openslr.org/resources/60/train-clean-100.tar.gz": "c5608bf1ef74bb621935382b8399c5cdd51cd3ee47cec51f00f885a64c6c7f6b",  # noqa: E501
    "http://www.openslr.org/resources/60/train-clean-360.tar.gz": "ce7cff44dcac46009d18379f37ef36551123a1dc4e5c8e4eb73ae57260de4886",  # noqa: E501
    "http://www.openslr.org/resources/60/train-other-500.tar.gz": "e35f7e34deeb2e2bdfe4403d88c8fdd5fbf64865cae41f027a185a6965f0a5df",  # noqa: E501
}
22
+
23
+
24
def load_libritts_item(
    fileid: str,
    path: str,
    ext_audio: str,
    ext_original_txt: str,
    ext_normalized_txt: str,
) -> Tuple[Tensor, int, str, str, int, int, str]:
    """Load one LibriTTS utterance: waveform plus original/normalized transcripts.

    Args:
        fileid (str): File id of the form ``"{speaker}_{chapter}_{segment}_{utterance}"``.
        path (str): Root of the extracted subset.
        ext_audio (str): Audio file extension.
        ext_original_txt (str): Extension of the original-transcript file.
        ext_normalized_txt (str): Extension of the normalized-transcript file.

    Returns:
        Tuple of waveform, sample rate, original text, normalized text,
        speaker id, chapter id, and utterance id.
    """
    speaker_id, chapter_id, segment_id, utterance_id = fileid.split("_")
    # The full file id serves as the utterance id in the returned tuple
    # and in all three on-disk file names.
    utterance_id = fileid

    chapter_dir = os.path.join(path, speaker_id, chapter_id)
    file_audio = os.path.join(chapter_dir, utterance_id + ext_audio)

    # Load audio
    waveform, sample_rate = torchaudio.load(file_audio)

    # Load original text
    with open(os.path.join(chapter_dir, utterance_id + ext_original_txt)) as ft:
        original_text = ft.readline()

    # Load normalized text
    with open(os.path.join(chapter_dir, utterance_id + ext_normalized_txt), "r") as ft:
        normalized_text = ft.readline()

    return (
        waveform,
        sample_rate,
        original_text,
        normalized_text,
        int(speaker_id),
        int(chapter_id),
        utterance_id,
    )
63
+
64
+
65
class LIBRITTS(Dataset):
    """*LibriTTS* :cite:`Zen2019LibriTTSAC` dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from,
            or the type of the dataset to dowload.
            Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
            ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and
            ``"train-other-500"``. (default: ``"train-clean-100"``)
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"LibriTTS"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
    """

    # Extensions of the transcript and audio files shipped with the corpus.
    _ext_original_txt = ".original.txt"
    _ext_normalized_txt = ".normalized.txt"
    _ext_audio = ".wav"

    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
    ) -> None:

        # When a bare subset name is given, rewrite ``url`` into the full
        # download URL; an already-full URL is left untouched.
        if url in [
            "dev-clean",
            "dev-other",
            "test-clean",
            "test-other",
            "train-clean-100",
            "train-clean-360",
            "train-other-500",
        ]:

            ext_archive = ".tar.gz"
            base_url = "http://www.openslr.org/resources/60/"

            url = os.path.join(base_url, url + ext_archive)

        # Get string representation of 'root' in case Path object is passed
        root = os.fspath(root)

        basename = os.path.basename(url)
        archive = os.path.join(root, basename)

        # Strip the archive extension to obtain the subset directory name.
        basename = basename.split(".")[0]
        folder_in_archive = os.path.join(folder_in_archive, basename)

        self._path = os.path.join(root, folder_in_archive)

        if download:
            # Download and extract only when the subset directory is missing;
            # an already-present archive is reused rather than re-downloaded.
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url_to_file(url, archive, hash_prefix=checksum)
                _extract_tar(archive)
        else:
            if not os.path.exists(self._path):
                raise RuntimeError(
                    f"The path {self._path} doesn't exist. "
                    "Please check the ``root`` path or set `download=True` to download it"
                )

        # Sorted file ids ("speaker_chapter_segment_utterance") of every audio file.
        self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int, int, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Original text
            str:
                Normalized text
            int:
                Speaker ID
            int:
                Chapter ID
            str:
                Utterance ID
        """
        fileid = self._walker[n]
        return load_libritts_item(
            fileid,
            self._path,
            self._ext_audio,
            self._ext_original_txt,
            self._ext_normalized_txt,
        )

    def __len__(self) -> int:
        # One entry per discovered audio file.
        return len(self._walker)
.venv/lib/python3.11/site-packages/torchaudio/datasets/ljspeech.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Tuple, Union
5
+
6
+ import torchaudio
7
+ from torch import Tensor
8
+ from torch.utils.data import Dataset
9
+ from torchaudio._internal import download_url_to_file
10
+ from torchaudio.datasets.utils import _extract_tar
11
+
12
+
13
+ _RELEASE_CONFIGS = {
14
+ "release1": {
15
+ "folder_in_archive": "wavs",
16
+ "url": "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",
17
+ "checksum": "be1a30453f28eb8dd26af4101ae40cbf2c50413b1bb21936cbcdc6fae3de8aa5",
18
+ }
19
+ }
20
+
21
+
22
+ class LJSPEECH(Dataset):
23
+ """*LJSpeech-1.1* :cite:`ljspeech17` dataset.
24
+
25
+ Args:
26
+ root (str or Path): Path to the directory where the dataset is found or downloaded.
27
+ url (str, optional): The URL to download the dataset from.
28
+ (default: ``"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"``)
29
+ folder_in_archive (str, optional):
30
+ The top-level directory of the dataset. (default: ``"wavs"``)
31
+ download (bool, optional):
32
+ Whether to download the dataset if it is not found at root path. (default: ``False``).
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ root: Union[str, Path],
38
+ url: str = _RELEASE_CONFIGS["release1"]["url"],
39
+ folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
40
+ download: bool = False,
41
+ ) -> None:
42
+
43
+ self._parse_filesystem(root, url, folder_in_archive, download)
44
+
45
+ def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None:
46
+ root = Path(root)
47
+
48
+ basename = os.path.basename(url)
49
+ archive = root / basename
50
+
51
+ basename = Path(basename.split(".tar.bz2")[0])
52
+ folder_in_archive = basename / folder_in_archive
53
+
54
+ self._path = root / folder_in_archive
55
+ self._metadata_path = root / basename / "metadata.csv"
56
+
57
+ if download:
58
+ if not os.path.isdir(self._path):
59
+ if not os.path.isfile(archive):
60
+ checksum = _RELEASE_CONFIGS["release1"]["checksum"]
61
+ download_url_to_file(url, archive, hash_prefix=checksum)
62
+ _extract_tar(archive)
63
+ else:
64
+ if not os.path.exists(self._path):
65
+ raise RuntimeError(
66
+ f"The path {self._path} doesn't exist. "
67
+ "Please check the ``root`` path or set `download=True` to download it"
68
+ )
69
+
70
+ with open(self._metadata_path, "r", newline="") as metadata:
71
+ flist = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE)
72
+ self._flist = list(flist)
73
+
74
+ def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
75
+ """Load the n-th sample from the dataset.
76
+
77
+ Args:
78
+ n (int): The index of the sample to be loaded
79
+
80
+ Returns:
81
+ Tuple of the following items;
82
+
83
+ Tensor:
84
+ Waveform
85
+ int:
86
+ Sample rate
87
+ str:
88
+ Transcript
89
+ str:
90
+ Normalized Transcript
91
+ """
92
+ line = self._flist[n]
93
+ fileid, transcript, normalized_transcript = line
94
+ fileid_audio = self._path / (fileid + ".wav")
95
+
96
+ # Load audio
97
+ waveform, sample_rate = torchaudio.load(fileid_audio)
98
+
99
+ return (
100
+ waveform,
101
+ sample_rate,
102
+ transcript,
103
+ normalized_transcript,
104
+ )
105
+
106
+ def __len__(self) -> int:
107
+ return len(self._flist)
.venv/lib/python3.11/site-packages/torchaudio/datasets/musdb_hq.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torchaudio
7
+ from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
+ from torchaudio.datasets.utils import _extract_zip
10
+
11
+ _URL = "https://zenodo.org/record/3338373/files/musdb18hq.zip"
12
+ _CHECKSUM = "baac80d0483c61d74b2e5f3be75fa557eec52898339e6aa45c1fa48833c5d21d"
13
+ _EXT = ".wav"
14
+ _SAMPLE_RATE = 44100
15
+ _VALIDATION_SET = [
16
+ "Actions - One Minute Smile",
17
+ "Clara Berry And Wooldog - Waltz For My Victims",
18
+ "Johnny Lokke - Promises & Lies",
19
+ "Patrick Talbot - A Reason To Leave",
20
+ "Triviul - Angelsaint",
21
+ "Alexander Ross - Goodbye Bolero",
22
+ "Fergessen - Nos Palpitants",
23
+ "Leaf - Summerghost",
24
+ "Skelpolu - Human Mistakes",
25
+ "Young Griffo - Pennies",
26
+ "ANiMAL - Rockshow",
27
+ "James May - On The Line",
28
+ "Meaxic - Take A Step",
29
+ "Traffic Experiment - Sirens",
30
+ ]
31
+
32
+
33
+ class MUSDB_HQ(Dataset):
34
+ """*MUSDB_HQ* :cite:`MUSDB18HQ` dataset.
35
+
36
+ Args:
37
+ root (str or Path): Root directory where the dataset's top level directory is found
38
+ subset (str): Subset of the dataset to use. Options: [``"train"``, ``"test"``].
39
+ sources (List[str] or None, optional): Sources extract data from.
40
+ List can contain the following options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
41
+ If ``None``, dataset consists of tracks except mixture.
42
+ (default: ``None``)
43
+ split (str or None, optional): Whether to split training set into train and validation set.
44
+ If ``None``, no splitting occurs. If ``train`` or ``validation``, returns respective set.
45
+ (default: ``None``)
46
+ download (bool, optional): Whether to download the dataset if it is not found at root path.
47
+ (default: ``False``)
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ root: Union[str, Path],
53
+ subset: str,
54
+ sources: Optional[List[str]] = None,
55
+ split: Optional[str] = None,
56
+ download: bool = False,
57
+ ) -> None:
58
+ self.sources = ["bass", "drums", "other", "vocals"] if not sources else sources
59
+ self.split = split
60
+
61
+ basename = os.path.basename(_URL)
62
+ archive = os.path.join(root, basename)
63
+ basename = basename.rsplit(".", 2)[0]
64
+
65
+ if subset not in ["test", "train"]:
66
+ raise ValueError("`subset` must be one of ['test', 'train']")
67
+ if self.split is not None and self.split not in ["train", "validation"]:
68
+ raise ValueError("`split` must be one of ['train', 'validation']")
69
+ base_path = os.path.join(root, basename)
70
+ self._path = os.path.join(base_path, subset)
71
+ if not os.path.isdir(self._path):
72
+ if not os.path.isfile(archive):
73
+ if not download:
74
+ raise RuntimeError("Dataset not found. Please use `download=True` to download")
75
+ download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
76
+ os.makedirs(base_path, exist_ok=True)
77
+ _extract_zip(archive, base_path)
78
+
79
+ self.names = self._collect_songs()
80
+
81
+ def _get_track(self, name, source):
82
+ return Path(self._path) / name / f"{source}{_EXT}"
83
+
84
+ def _load_sample(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
85
+ name = self.names[n]
86
+ wavs = []
87
+ num_frames = None
88
+ for source in self.sources:
89
+ track = self._get_track(name, source)
90
+ wav, sr = torchaudio.load(str(track))
91
+ if sr != _SAMPLE_RATE:
92
+ raise ValueError(f"expected sample rate {_SAMPLE_RATE}, but got {sr}")
93
+ if num_frames is None:
94
+ num_frames = wav.shape[-1]
95
+ else:
96
+ if wav.shape[-1] != num_frames:
97
+ raise ValueError("num_frames do not match across sources")
98
+ wavs.append(wav)
99
+
100
+ stacked = torch.stack(wavs)
101
+
102
+ return stacked, _SAMPLE_RATE, num_frames, name
103
+
104
+ def _collect_songs(self):
105
+ if self.split == "validation":
106
+ return _VALIDATION_SET
107
+ path = Path(self._path)
108
+ names = []
109
+ for root, folders, _ in os.walk(path, followlinks=True):
110
+ root = Path(root)
111
+ if root.name.startswith(".") or folders or root == path:
112
+ continue
113
+ name = str(root.relative_to(path))
114
+ if self.split and name in _VALIDATION_SET:
115
+ continue
116
+ names.append(name)
117
+ return sorted(names)
118
+
119
+ def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
120
+ """Load the n-th sample from the dataset.
121
+
122
+ Args:
123
+ n (int): The index of the sample to be loaded
124
+ Returns:
125
+ Tuple of the following items;
126
+
127
+ Tensor:
128
+ Waveform
129
+ int:
130
+ Sample rate
131
+ int:
132
+ Num frames
133
+ str:
134
+ Track name
135
+ """
136
+ return self._load_sample(n)
137
+
138
+ def __len__(self) -> int:
139
+ return len(self.names)
.venv/lib/python3.11/site-packages/torchaudio/datasets/quesst14.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from pathlib import Path
4
+ from typing import Optional, Tuple, Union
5
+
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
+ from torchaudio.datasets.utils import _extract_tar, _load_waveform
10
+
11
+
12
+ URL = "https://speech.fit.vutbr.cz/files/quesst14Database.tgz"
13
+ SAMPLE_RATE = 8000
14
+ _CHECKSUM = "4f869e06bc066bbe9c5dde31dbd3909a0870d70291110ebbb38878dcbc2fc5e4"
15
+ _LANGUAGES = [
16
+ "albanian",
17
+ "basque",
18
+ "czech",
19
+ "nnenglish",
20
+ "romanian",
21
+ "slovak",
22
+ ]
23
+
24
+
25
+ class QUESST14(Dataset):
26
+ """*QUESST14* :cite:`Mir2015QUESST2014EQ` dataset.
27
+
28
+ Args:
29
+ root (str or Path): Root directory where the dataset's top level directory is found
30
+ subset (str): Subset of the dataset to use. Options: [``"docs"``, ``"dev"``, ``"eval"``].
31
+ language (str or None, optional): Language to get dataset for.
32
+ Options: [``None``, ``albanian``, ``basque``, ``czech``, ``nnenglish``, ``romanian``, ``slovak``].
33
+ If ``None``, dataset consists of all languages. (default: ``"nnenglish"``)
34
+ download (bool, optional): Whether to download the dataset if it is not found at root path.
35
+ (default: ``False``)
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ root: Union[str, Path],
41
+ subset: str,
42
+ language: Optional[str] = "nnenglish",
43
+ download: bool = False,
44
+ ) -> None:
45
+ if subset not in ["docs", "dev", "eval"]:
46
+ raise ValueError("`subset` must be one of ['docs', 'dev', 'eval']")
47
+
48
+ if language is not None and language not in _LANGUAGES:
49
+ raise ValueError(f"`language` must be None or one of {str(_LANGUAGES)}")
50
+
51
+ # Get string representation of 'root'
52
+ root = os.fspath(root)
53
+
54
+ basename = os.path.basename(URL)
55
+ archive = os.path.join(root, basename)
56
+
57
+ basename = basename.rsplit(".", 2)[0]
58
+ self._path = os.path.join(root, basename)
59
+
60
+ if not os.path.isdir(self._path):
61
+ if not os.path.isfile(archive):
62
+ if not download:
63
+ raise RuntimeError("Dataset not found. Please use `download=True` to download")
64
+ download_url_to_file(URL, archive, hash_prefix=_CHECKSUM)
65
+ _extract_tar(archive, root)
66
+
67
+ if subset == "docs":
68
+ self.data = filter_audio_paths(self._path, language, "language_key_utterances.lst")
69
+ elif subset == "dev":
70
+ self.data = filter_audio_paths(self._path, language, "language_key_dev.lst")
71
+ elif subset == "eval":
72
+ self.data = filter_audio_paths(self._path, language, "language_key_eval.lst")
73
+
74
+ def get_metadata(self, n: int) -> Tuple[str, int, str]:
75
+ """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
76
+ but otherwise returns the same fields as :py:func:`__getitem__`.
77
+
78
+ Args:
79
+ n (int): The index of the sample to be loaded
80
+
81
+ Returns:
82
+ Tuple of the following items;
83
+
84
+ str:
85
+ Path to audio
86
+ int:
87
+ Sample rate
88
+ str:
89
+ File name
90
+ """
91
+ audio_path = self.data[n]
92
+ relpath = os.path.relpath(audio_path, self._path)
93
+ return relpath, SAMPLE_RATE, audio_path.with_suffix("").name
94
+
95
+ def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str]:
96
+ """Load the n-th sample from the dataset.
97
+
98
+ Args:
99
+ n (int): The index of the sample to be loaded
100
+
101
+ Returns:
102
+ Tuple of the following items;
103
+
104
+ Tensor:
105
+ Waveform
106
+ int:
107
+ Sample rate
108
+ str:
109
+ File name
110
+ """
111
+ metadata = self.get_metadata(n)
112
+ waveform = _load_waveform(self._path, metadata[0], metadata[1])
113
+ return (waveform,) + metadata[1:]
114
+
115
+ def __len__(self) -> int:
116
+ return len(self.data)
117
+
118
+
119
+ def filter_audio_paths(
120
+ path: str,
121
+ language: str,
122
+ lst_name: str,
123
+ ):
124
+ """Extract audio paths for the given language."""
125
+ audio_paths = []
126
+
127
+ path = Path(path)
128
+ with open(path / "scoring" / lst_name) as f:
129
+ for line in f:
130
+ audio_path, lang = line.strip().split()
131
+ if language is not None and lang != language:
132
+ continue
133
+ audio_path = re.sub(r"^.*?\/", "", audio_path)
134
+ audio_paths.append(path / audio_path)
135
+
136
+ return audio_paths
.venv/lib/python3.11/site-packages/torchaudio/datasets/snips.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ from torchaudio.datasets.utils import _load_waveform
8
+
9
+
10
+ _SAMPLE_RATE = 16000
11
+ _SPEAKERS = [
12
+ "Aditi",
13
+ "Amy",
14
+ "Brian",
15
+ "Emma",
16
+ "Geraint",
17
+ "Ivy",
18
+ "Joanna",
19
+ "Joey",
20
+ "Justin",
21
+ "Kendra",
22
+ "Kimberly",
23
+ "Matthew",
24
+ "Nicole",
25
+ "Raveena",
26
+ "Russell",
27
+ "Salli",
28
+ ]
29
+
30
+
31
+ def _load_labels(file: Path, subset: str):
32
+ """Load transcirpt, iob, and intent labels for all utterances.
33
+
34
+ Args:
35
+ file (Path): The path to the label file.
36
+ subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``].
37
+
38
+ Returns:
39
+ Dictionary of labels, where the key is the filename of the audio,
40
+ and the label is a Tuple of transcript, Inside–outside–beginning (IOB) label, and intention label.
41
+ """
42
+ labels = {}
43
+ with open(file, "r") as f:
44
+ for line in f:
45
+ line = line.strip().split(" ")
46
+ index = line[0]
47
+ trans, iob_intent = " ".join(line[1:]).split("\t")
48
+ trans = " ".join(trans.split(" ")[1:-1])
49
+ iob = " ".join(iob_intent.split(" ")[1:-1])
50
+ intent = iob_intent.split(" ")[-1]
51
+ if subset in index:
52
+ labels[index] = (trans, iob, intent)
53
+ return labels
54
+
55
+
56
+ class Snips(Dataset):
57
+ """*Snips* :cite:`coucke2018snips` dataset.
58
+
59
+ Args:
60
+ root (str or Path): Root directory where the dataset's top level directory is found.
61
+ subset (str): Subset of the dataset to use. Options: [``"train"``, ``"valid"``, ``"test"``].
62
+ speakers (List[str] or None, optional): The speaker list to include in the dataset. If ``None``,
63
+ include all speakers in the subset. (Default: ``None``)
64
+ audio_format (str, optional): The extension of the audios. Options: [``"mp3"``, ``"wav"``].
65
+ (Default: ``"mp3"``)
66
+ """
67
+
68
+ _trans_file = "all.iob.snips.txt"
69
+
70
+ def __init__(
71
+ self,
72
+ root: Union[str, Path],
73
+ subset: str,
74
+ speakers: Optional[List[str]] = None,
75
+ audio_format: str = "mp3",
76
+ ) -> None:
77
+ if subset not in ["train", "valid", "test"]:
78
+ raise ValueError('`subset` must be one of ["train", "valid", "test"].')
79
+ if audio_format not in ["mp3", "wav"]:
80
+ raise ValueError('`audio_format` must be one of ["mp3", "wav].')
81
+
82
+ root = Path(root)
83
+ self._path = root / "SNIPS"
84
+ self.audio_path = self._path / subset
85
+ if speakers is None:
86
+ speakers = _SPEAKERS
87
+
88
+ if not os.path.isdir(self._path):
89
+ raise RuntimeError("Dataset not found.")
90
+
91
+ self.audio_paths = self.audio_path.glob(f"*.{audio_format}")
92
+ self.data = []
93
+ for audio_path in sorted(self.audio_paths):
94
+ audio_name = str(audio_path.name)
95
+ speaker = audio_name.split("-")[0]
96
+ if speaker in speakers:
97
+ self.data.append(audio_path)
98
+ transcript_path = self._path / self._trans_file
99
+ self.labels = _load_labels(transcript_path, subset)
100
+
101
+ def get_metadata(self, n: int) -> Tuple[str, int, str, str, str]:
102
+ """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
103
+ but otherwise returns the same fields as :py:func:`__getitem__`.
104
+
105
+ Args:
106
+ n (int): The index of the sample to be loaded.
107
+
108
+ Returns:
109
+ Tuple of the following items:
110
+
111
+ str:
112
+ Path to audio
113
+ int:
114
+ Sample rate
115
+ str:
116
+ File name
117
+ str:
118
+ Transcription of audio
119
+ str:
120
+ Inside–outside–beginning (IOB) label of transcription
121
+ str:
122
+ Intention label of the audio.
123
+ """
124
+ audio_path = self.data[n]
125
+ relpath = os.path.relpath(audio_path, self._path)
126
+ file_name = audio_path.with_suffix("").name
127
+ transcript, iob, intent = self.labels[file_name]
128
+ return relpath, _SAMPLE_RATE, file_name, transcript, iob, intent
129
+
130
+ def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str]:
131
+ """Load the n-th sample from the dataset.
132
+
133
+ Args:
134
+ n (int): The index of the sample to be loaded
135
+
136
+ Returns:
137
+ Tuple of the following items:
138
+
139
+ Tensor:
140
+ Waveform
141
+ int:
142
+ Sample rate
143
+ str:
144
+ File name
145
+ str:
146
+ Transcription of audio
147
+ str:
148
+ Inside–outside–beginning (IOB) label of transcription
149
+ str:
150
+ Intention label of the audio.
151
+ """
152
+ metadata = self.get_metadata(n)
153
+ waveform = _load_waveform(self._path, metadata[0], metadata[1])
154
+ return (waveform,) + metadata[1:]
155
+
156
+ def __len__(self) -> int:
157
+ return len(self.data)
.venv/lib/python3.11/site-packages/torchaudio/datasets/speechcommands.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional, Tuple, Union
4
+
5
+ from torch import Tensor
6
+ from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
+ from torchaudio.datasets.utils import _extract_tar, _load_waveform
9
+
10
+ FOLDER_IN_ARCHIVE = "SpeechCommands"
11
+ URL = "speech_commands_v0.02"
12
+ HASH_DIVIDER = "_nohash_"
13
+ EXCEPT_FOLDER = "_background_noise_"
14
+ SAMPLE_RATE = 16000
15
+ _CHECKSUMS = {
16
+ "http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz": "743935421bb51cccdb6bdd152e04c5c70274e935c82119ad7faeec31780d811d", # noqa: E501
17
+ "http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz": "af14739ee7dc311471de98f5f9d2c9191b18aedfe957f4a6ff791c709868ff58", # noqa: E501
18
+ }
19
+
20
+
21
+ def _load_list(root, *filenames):
22
+ output = []
23
+ for filename in filenames:
24
+ filepath = os.path.join(root, filename)
25
+ with open(filepath) as fileobj:
26
+ output += [os.path.normpath(os.path.join(root, line.strip())) for line in fileobj]
27
+ return output
28
+
29
+
30
+ def _get_speechcommands_metadata(filepath: str, path: str) -> Tuple[str, int, str, str, int]:
31
+ relpath = os.path.relpath(filepath, path)
32
+ reldir, filename = os.path.split(relpath)
33
+ _, label = os.path.split(reldir)
34
+ # Besides the officially supported split method for datasets defined by "validation_list.txt"
35
+ # and "testing_list.txt" over "speech_commands_v0.0x.tar.gz" archives, an alternative split
36
+ # method referred to in paragraph 2-3 of Section 7.1, references 13 and 14 of the original
37
+ # paper, and the checksums file from the tensorflow_datasets package [1] is also supported.
38
+ # Some filenames in those "speech_commands_test_set_v0.0x.tar.gz" archives have the form
39
+ # "xxx.wav.wav", so file extensions twice needs to be stripped twice.
40
+ # [1] https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/url_checksums/speech_commands.txt
41
+ speaker, _ = os.path.splitext(filename)
42
+ speaker, _ = os.path.splitext(speaker)
43
+
44
+ speaker_id, utterance_number = speaker.split(HASH_DIVIDER)
45
+ utterance_number = int(utterance_number)
46
+
47
+ return relpath, SAMPLE_RATE, label, speaker_id, utterance_number
48
+
49
+
50
+ class SPEECHCOMMANDS(Dataset):
51
+ """*Speech Commands* :cite:`speechcommandsv2` dataset.
52
+
53
+ Args:
54
+ root (str or Path): Path to the directory where the dataset is found or downloaded.
55
+ url (str, optional): The URL to download the dataset from,
56
+ or the type of the dataset to dowload.
57
+ Allowed type values are ``"speech_commands_v0.01"`` and ``"speech_commands_v0.02"``
58
+ (default: ``"speech_commands_v0.02"``)
59
+ folder_in_archive (str, optional):
60
+ The top-level directory of the dataset. (default: ``"SpeechCommands"``)
61
+ download (bool, optional):
62
+ Whether to download the dataset if it is not found at root path. (default: ``False``).
63
+ subset (str or None, optional):
64
+ Select a subset of the dataset [None, "training", "validation", "testing"]. None means
65
+ the whole dataset. "validation" and "testing" are defined in "validation_list.txt" and
66
+ "testing_list.txt", respectively, and "training" is the rest. Details for the files
67
+ "validation_list.txt" and "testing_list.txt" are explained in the README of the dataset
68
+ and in the introduction of Section 7 of the original paper and its reference 12. The
69
+ original paper can be found `here <https://arxiv.org/pdf/1804.03209.pdf>`_. (Default: ``None``)
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ root: Union[str, Path],
75
+ url: str = URL,
76
+ folder_in_archive: str = FOLDER_IN_ARCHIVE,
77
+ download: bool = False,
78
+ subset: Optional[str] = None,
79
+ ) -> None:
80
+
81
+ if subset is not None and subset not in ["training", "validation", "testing"]:
82
+ raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].")
83
+
84
+ if url in [
85
+ "speech_commands_v0.01",
86
+ "speech_commands_v0.02",
87
+ ]:
88
+ base_url = "http://download.tensorflow.org/data/"
89
+ ext_archive = ".tar.gz"
90
+
91
+ url = os.path.join(base_url, url + ext_archive)
92
+
93
+ # Get string representation of 'root' in case Path object is passed
94
+ root = os.fspath(root)
95
+ self._archive = os.path.join(root, folder_in_archive)
96
+
97
+ basename = os.path.basename(url)
98
+ archive = os.path.join(root, basename)
99
+
100
+ basename = basename.rsplit(".", 2)[0]
101
+ folder_in_archive = os.path.join(folder_in_archive, basename)
102
+
103
+ self._path = os.path.join(root, folder_in_archive)
104
+
105
+ if download:
106
+ if not os.path.isdir(self._path):
107
+ if not os.path.isfile(archive):
108
+ checksum = _CHECKSUMS.get(url, None)
109
+ download_url_to_file(url, archive, hash_prefix=checksum)
110
+ _extract_tar(archive, self._path)
111
+ else:
112
+ if not os.path.exists(self._path):
113
+ raise RuntimeError(
114
+ f"The path {self._path} doesn't exist. "
115
+ "Please check the ``root`` path or set `download=True` to download it"
116
+ )
117
+
118
+ if subset == "validation":
119
+ self._walker = _load_list(self._path, "validation_list.txt")
120
+ elif subset == "testing":
121
+ self._walker = _load_list(self._path, "testing_list.txt")
122
+ elif subset == "training":
123
+ excludes = set(_load_list(self._path, "validation_list.txt", "testing_list.txt"))
124
+ walker = sorted(str(p) for p in Path(self._path).glob("*/*.wav"))
125
+ self._walker = [
126
+ w
127
+ for w in walker
128
+ if HASH_DIVIDER in w and EXCEPT_FOLDER not in w and os.path.normpath(w) not in excludes
129
+ ]
130
+ else:
131
+ walker = sorted(str(p) for p in Path(self._path).glob("*/*.wav"))
132
+ self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w]
133
+
134
+ def get_metadata(self, n: int) -> Tuple[str, int, str, str, int]:
135
+ """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
136
+ but otherwise returns the same fields as :py:func:`__getitem__`.
137
+
138
+ Args:
139
+ n (int): The index of the sample to be loaded
140
+
141
+ Returns:
142
+ Tuple of the following items;
143
+
144
+ str:
145
+ Path to the audio
146
+ int:
147
+ Sample rate
148
+ str:
149
+ Label
150
+ str:
151
+ Speaker ID
152
+ int:
153
+ Utterance number
154
+ """
155
+ fileid = self._walker[n]
156
+ return _get_speechcommands_metadata(fileid, self._archive)
157
+
158
+ def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
159
+ """Load the n-th sample from the dataset.
160
+
161
+ Args:
162
+ n (int): The index of the sample to be loaded
163
+
164
+ Returns:
165
+ Tuple of the following items;
166
+
167
+ Tensor:
168
+ Waveform
169
+ int:
170
+ Sample rate
171
+ str:
172
+ Label
173
+ str:
174
+ Speaker ID
175
+ int:
176
+ Utterance number
177
+ """
178
+ metadata = self.get_metadata(n)
179
+ waveform = _load_waveform(self._archive, metadata[0], metadata[1])
180
+ return (waveform,) + metadata[1:]
181
+
182
+ def __len__(self) -> int:
183
+ return len(self._walker)
.venv/lib/python3.11/site-packages/torchaudio/datasets/tedlium.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Tuple, Union
4
+
5
+ import torchaudio
6
+ from torch import Tensor
7
+ from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
+ from torchaudio.datasets.utils import _extract_tar
10
+
11
+
12
+ _RELEASE_CONFIGS = {
13
+ "release1": {
14
+ "folder_in_archive": "TEDLIUM_release1",
15
+ "url": "http://www.openslr.org/resources/7/TEDLIUM_release1.tar.gz",
16
+ "checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27",
17
+ "data_path": "",
18
+ "subset": "train",
19
+ "supported_subsets": ["train", "test", "dev"],
20
+ "dict": "TEDLIUM.150K.dic",
21
+ },
22
+ "release2": {
23
+ "folder_in_archive": "TEDLIUM_release2",
24
+ "url": "http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz",
25
+ "checksum": "93281b5fcaaae5c88671c9d000b443cb3c7ea3499ad12010b3934ca41a7b9c58",
26
+ "data_path": "",
27
+ "subset": "train",
28
+ "supported_subsets": ["train", "test", "dev"],
29
+ "dict": "TEDLIUM.152k.dic",
30
+ },
31
+ "release3": {
32
+ "folder_in_archive": "TEDLIUM_release-3",
33
+ "url": "http://www.openslr.org/resources/51/TEDLIUM_release-3.tgz",
34
+ "checksum": "ad1e454d14d1ad550bc2564c462d87c7a7ec83d4dc2b9210f22ab4973b9eccdb",
35
+ "data_path": "data/",
36
+ "subset": "train",
37
+ "supported_subsets": ["train", "test", "dev"],
38
+ "dict": "TEDLIUM.152k.dic",
39
+ },
40
+ }
41
+
42
+
43
+ class TEDLIUM(Dataset):
44
+ """*Tedlium* :cite:`rousseau2012tedlium` dataset (releases 1,2 and 3).
45
+
46
+ Args:
47
+ root (str or Path): Path to the directory where the dataset is found or downloaded.
48
+ release (str, optional): Release version.
49
+ Allowed values are ``"release1"``, ``"release2"`` or ``"release3"``.
50
+ (default: ``"release1"``).
51
+ subset (str, optional): The subset of dataset to use. Valid options are ``"train"``, ``"dev"``,
52
+ and ``"test"``. Defaults to ``"train"``.
53
+ download (bool, optional):
54
+ Whether to download the dataset if it is not found at root path. (default: ``False``).
55
+ audio_ext (str, optional): extension for audio file (default: ``".sph"``)
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ root: Union[str, Path],
61
+ release: str = "release1",
62
+ subset: str = "train",
63
+ download: bool = False,
64
+ audio_ext: str = ".sph",
65
+ ) -> None:
66
+ self._ext_audio = audio_ext
67
+ if release in _RELEASE_CONFIGS.keys():
68
+ folder_in_archive = _RELEASE_CONFIGS[release]["folder_in_archive"]
69
+ url = _RELEASE_CONFIGS[release]["url"]
70
+ subset = subset if subset else _RELEASE_CONFIGS[release]["subset"]
71
+ else:
72
+ # Raise warning
73
+ raise RuntimeError(
74
+ "The release {} does not match any of the supported tedlium releases{} ".format(
75
+ release,
76
+ _RELEASE_CONFIGS.keys(),
77
+ )
78
+ )
79
+ if subset not in _RELEASE_CONFIGS[release]["supported_subsets"]:
80
+ # Raise warning
81
+ raise RuntimeError(
82
+ "The subset {} does not match any of the supported tedlium subsets{} ".format(
83
+ subset,
84
+ _RELEASE_CONFIGS[release]["supported_subsets"],
85
+ )
86
+ )
87
+
88
+ # Get string representation of 'root' in case Path object is passed
89
+ root = os.fspath(root)
90
+
91
+ basename = os.path.basename(url)
92
+ archive = os.path.join(root, basename)
93
+
94
+ basename = basename.split(".")[0]
95
+
96
+ if release == "release3":
97
+ if subset == "train":
98
+ self._path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["data_path"])
99
+ else:
100
+ self._path = os.path.join(root, folder_in_archive, "legacy", subset)
101
+ else:
102
+ self._path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["data_path"], subset)
103
+
104
+ if download:
105
+ if not os.path.isdir(self._path):
106
+ if not os.path.isfile(archive):
107
+ checksum = _RELEASE_CONFIGS[release]["checksum"]
108
+ download_url_to_file(url, archive, hash_prefix=checksum)
109
+ _extract_tar(archive)
110
+ else:
111
+ if not os.path.exists(self._path):
112
+ raise RuntimeError(
113
+ f"The path {self._path} doesn't exist. "
114
+ "Please check the ``root`` path or set `download=True` to download it"
115
+ )
116
+
117
+ # Create list for all samples
118
+ self._filelist = []
119
+ stm_path = os.path.join(self._path, "stm")
120
+ for file in sorted(os.listdir(stm_path)):
121
+ if file.endswith(".stm"):
122
+ stm_path = os.path.join(self._path, "stm", file)
123
+ with open(stm_path) as f:
124
+ l = len(f.readlines())
125
+ file = file.replace(".stm", "")
126
+ self._filelist.extend((file, line) for line in range(l))
127
+ # Create dict path for later read
128
+ self._dict_path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["dict"])
129
+ self._phoneme_dict = None
130
+
131
+ def _load_tedlium_item(self, fileid: str, line: int, path: str) -> Tuple[Tensor, int, str, int, int, int]:
132
+ """Loads a TEDLIUM dataset sample given a file name and corresponding sentence name.
133
+
134
+ Args:
135
+ fileid (str): File id to identify both text and audio files corresponding to the sample
136
+ line (int): Line identifier for the sample inside the text file
137
+ path (str): Dataset root path
138
+
139
+ Returns:
140
+ (Tensor, int, str, int, int, int):
141
+ ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)``
142
+ """
143
+ transcript_path = os.path.join(path, "stm", fileid)
144
+ with open(transcript_path + ".stm") as f:
145
+ transcript = f.readlines()[line]
146
+ talk_id, _, speaker_id, start_time, end_time, identifier, transcript = transcript.split(" ", 6)
147
+
148
+ wave_path = os.path.join(path, "sph", fileid)
149
+ waveform, sample_rate = self._load_audio(wave_path + self._ext_audio, start_time=start_time, end_time=end_time)
150
+
151
+ return (waveform, sample_rate, transcript, talk_id, speaker_id, identifier)
152
+
153
+ def _load_audio(self, path: str, start_time: float, end_time: float, sample_rate: int = 16000) -> [Tensor, int]:
154
+ """Default load function used in TEDLIUM dataset, you can overwrite this function to customize functionality
155
+ and load individual sentences from a full ted audio talk file.
156
+
157
+ Args:
158
+ path (str): Path to audio file
159
+ start_time (int): Time in seconds where the sample sentence stars
160
+ end_time (int): Time in seconds where the sample sentence finishes
161
+ sample_rate (float, optional): Sampling rate
162
+
163
+ Returns:
164
+ [Tensor, int]: Audio tensor representation and sample rate
165
+ """
166
+ start_time = int(float(start_time) * sample_rate)
167
+ end_time = int(float(end_time) * sample_rate)
168
+
169
+ kwargs = {"frame_offset": start_time, "num_frames": end_time - start_time}
170
+
171
+ return torchaudio.load(path, **kwargs)
172
+
173
+ def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
174
+ """Load the n-th sample from the dataset.
175
+
176
+ Args:
177
+ n (int): The index of the sample to be loaded
178
+
179
+ Returns:
180
+ Tuple of the following items;
181
+
182
+ Tensor:
183
+ Waveform
184
+ int:
185
+ Sample rate
186
+ str:
187
+ Transcript
188
+ int:
189
+ Talk ID
190
+ int:
191
+ Speaker ID
192
+ int:
193
+ Identifier
194
+ """
195
+ fileid, line = self._filelist[n]
196
+ return self._load_tedlium_item(fileid, line, self._path)
197
+
198
+ def __len__(self) -> int:
199
+ """TEDLIUM dataset custom function overwritting len default behaviour.
200
+
201
+ Returns:
202
+ int: TEDLIUM dataset length
203
+ """
204
+ return len(self._filelist)
205
+
206
+ @property
207
+ def phoneme_dict(self):
208
+ """dict[str, tuple[str]]: Phonemes. Mapping from word to tuple of phonemes.
209
+ Note that some words have empty phonemes.
210
+ """
211
+ # Read phoneme dictionary
212
+ if not self._phoneme_dict:
213
+ self._phoneme_dict = {}
214
+ with open(self._dict_path, "r", encoding="utf-8") as f:
215
+ for line in f.readlines():
216
+ content = line.strip().split()
217
+ self._phoneme_dict[content[0]] = tuple(content[1:]) # content[1:] can be empty list
218
+ return self._phoneme_dict.copy()
.venv/lib/python3.11/site-packages/torchaudio/datasets/utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import tarfile
4
+ import zipfile
5
+ from typing import Any, List, Optional
6
+
7
+ import torchaudio
8
+
9
+ _LG = logging.getLogger(__name__)
10
+
11
+
12
+ def _extract_tar(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
13
+ if to_path is None:
14
+ to_path = os.path.dirname(from_path)
15
+ with tarfile.open(from_path, "r") as tar:
16
+ files = []
17
+ for file_ in tar: # type: Any
18
+ file_path = os.path.join(to_path, file_.name)
19
+ if file_.isfile():
20
+ files.append(file_path)
21
+ if os.path.exists(file_path):
22
+ _LG.info("%s already extracted.", file_path)
23
+ if not overwrite:
24
+ continue
25
+ tar.extract(file_, to_path)
26
+ return files
27
+
28
+
29
+ def _extract_zip(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
30
+ if to_path is None:
31
+ to_path = os.path.dirname(from_path)
32
+
33
+ with zipfile.ZipFile(from_path, "r") as zfile:
34
+ files = zfile.namelist()
35
+ for file_ in files:
36
+ file_path = os.path.join(to_path, file_)
37
+ if os.path.exists(file_path):
38
+ _LG.info("%s already extracted.", file_path)
39
+ if not overwrite:
40
+ continue
41
+ zfile.extract(file_, to_path)
42
+ return files
43
+
44
+
45
def _load_waveform(
    root: str,
    filename: str,
    exp_sample_rate: int,
):
    """Load ``root/filename`` with torchaudio and validate its sample rate.

    Raises:
        ValueError: If the file's sample rate differs from ``exp_sample_rate``.
    """
    full_path = os.path.join(root, filename)
    waveform, sample_rate = torchaudio.load(full_path)
    if sample_rate != exp_sample_rate:
        raise ValueError(f"sample rate should be {exp_sample_rate}, but got {sample_rate}")
    return waveform
.venv/lib/python3.11/site-packages/torchaudio/datasets/vctk.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Tuple
3
+
4
+ import torchaudio
5
+ from torch import Tensor
6
+ from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
+ from torchaudio.datasets.utils import _extract_zip
9
+
10
# Default download location of the VCTK 0.92 corpus archive.
URL = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
# SHA256 prefix used to validate the downloaded archive, keyed by URL.
_CHECKSUMS = {
    "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip": "f96258be9fdc2cbff6559541aae7ea4f59df3fcaf5cf963aae5ca647357e359c"  # noqa: E501
}


# Sample layout: (waveform, sample_rate, transcript, speaker_id, utterance_id)
SampleType = Tuple[Tensor, int, str, str, str]
17
+
18
+
19
class VCTK_092(Dataset):
    """*VCTK 0.92* :cite:`yamagishi2019vctk` dataset

    Args:
        root (str): Root directory where the dataset's top level directory is found.
        mic_id (str, optional): Microphone ID. Either ``"mic1"`` or ``"mic2"``. (default: ``"mic2"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        url (str, optional): The URL to download the dataset from.
            (default: ``"https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"``)
        audio_ext (str, optional): Custom audio extension if dataset is converted to non-default audio format.

    Note:
        * All the speeches from speaker ``p315`` will be skipped due to the lack of the corresponding text files.
        * All the speeches from ``p280`` will be skipped for ``mic_id="mic2"`` due to the lack of the audio files.
        * Some of the speeches from speaker ``p362`` will be skipped due to the lack of the audio files.
        * See Also: https://datashare.is.ed.ac.uk/handle/10283/3443
    """

    def __init__(
        self,
        root: str,
        mic_id: str = "mic2",
        download: bool = False,
        url: str = URL,
        audio_ext=".flac",
    ):
        # Only two microphone channels exist in the corpus.
        if mic_id not in ["mic1", "mic2"]:
            raise RuntimeError(f'`mic_id` has to be either "mic1" or "mic2". Found: {mic_id}')

        archive = os.path.join(root, "VCTK-Corpus-0.92.zip")

        # Expected on-disk layout: <root>/VCTK-Corpus-0.92/{txt,wav48_silence_trimmed}
        self._path = os.path.join(root, "VCTK-Corpus-0.92")
        self._txt_dir = os.path.join(self._path, "txt")
        self._audio_dir = os.path.join(self._path, "wav48_silence_trimmed")
        self._mic_id = mic_id
        self._audio_ext = audio_ext

        # Download and extract only when the extracted directory is absent;
        # reuse a previously downloaded archive if present.
        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url_to_file(url, archive, hash_prefix=checksum)
                _extract_zip(archive, self._path)

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found. Please use `download=True` to download it.")

        # Extracting speaker IDs from the folder structure
        self._speaker_ids = sorted(os.listdir(self._txt_dir))
        self._sample_ids = []

        """
        Due to some insufficient data complexity in the 0.92 version of this dataset,
        we start traversing the audio folder structure in accordance with the text folder.
        As some of the audio files are missing of either ``mic_1`` or ``mic_2`` but the
        text is present for the same, we first check for the existence of the audio file
        before adding it to the ``sample_ids`` list.

        Once the ``audio_ids`` are loaded into memory we can quickly access the list for
        different parameters required by the user.
        """
        for speaker_id in self._speaker_ids:
            # p280 has no "mic2" recordings at all (see class Note above).
            if speaker_id == "p280" and mic_id == "mic2":
                continue
            utterance_dir = os.path.join(self._txt_dir, speaker_id)
            for utterance_file in sorted(f for f in os.listdir(utterance_dir) if f.endswith(".txt")):
                utterance_id = os.path.splitext(utterance_file)[0]
                audio_path_mic = os.path.join(
                    self._audio_dir,
                    speaker_id,
                    f"{utterance_id}_{mic_id}{self._audio_ext}",
                )
                # p362 is missing some audio files; skip transcripts without audio.
                if speaker_id == "p362" and not os.path.isfile(audio_path_mic):
                    continue
                self._sample_ids.append(utterance_id.split("_"))

    def _load_text(self, file_path) -> str:
        # Transcript files contain a single line of text; return the first line.
        with open(file_path) as file_path:
            return file_path.readlines()[0]

    def _load_audio(self, file_path) -> Tuple[Tensor, int]:
        # Returns (waveform, sample_rate) as produced by torchaudio.load.
        return torchaudio.load(file_path)

    def _load_sample(self, speaker_id: str, utterance_id: str, mic_id: str) -> SampleType:
        # Paths follow the corpus naming scheme: <speaker>_<utterance>[_<mic>].<ext>
        transcript_path = os.path.join(self._txt_dir, speaker_id, f"{speaker_id}_{utterance_id}.txt")
        audio_path = os.path.join(
            self._audio_dir,
            speaker_id,
            f"{speaker_id}_{utterance_id}_{mic_id}{self._audio_ext}",
        )

        # Reading text
        transcript = self._load_text(transcript_path)

        # Reading FLAC
        waveform, sample_rate = self._load_audio(audio_path)

        return (waveform, sample_rate, transcript, speaker_id, utterance_id)

    def __getitem__(self, n: int) -> SampleType:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Transcript
            str:
                Speaker ID
            str:
                Utterance ID
        """
        speaker_id, utterance_id = self._sample_ids[n]
        return self._load_sample(speaker_id, utterance_id, self._mic_id)

    def __len__(self) -> int:
        return len(self._sample_ids)
.venv/lib/python3.11/site-packages/torchaudio/datasets/voxceleb1.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Union
4
+
5
+ from torch import Tensor
6
+ from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
+ from torchaudio.datasets.utils import _extract_zip, _load_waveform
9
+
10
+
11
# All VoxCeleb1 recordings are distributed at 16 kHz.
SAMPLE_RATE = 16000
# Download locations and SHA256 prefixes for the dev/test wav archives.
# The "dev" archive is split into four parts that are concatenated into a
# single zip before extraction; "test" is a single zip.
_ARCHIVE_CONFIGS = {
    "dev": {
        "archive_name": "vox1_dev_wav.zip",
        "urls": [
            "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa",
            "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab",
            "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac",
            "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad",
        ],
        "checksums": [
            "21ec6ca843659ebc2fdbe04b530baa4f191ad4b0971912672d92c158f32226a0",
            "311d21e0c8cbf33573a4fce6c80e5a279d80736274b381c394319fc557159a04",
            "92b64465f2b2a3dc0e4196ae8dd6828cbe9ddd1f089419a11e4cbfe2e1750df0",
            "00e6190c770b27f27d2a3dd26ee15596b17066b715ac111906861a7d09a211a5",
        ],
    },
    "test": {
        "archive_name": "vox1_test_wav.zip",
        "url": "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip",
        "checksum": "8de57f347fe22b2c24526e9f444f689ecf5096fc2a92018cf420ff6b5b15eaea",
    },
}
# Meta files: identification train/dev/test split and verification trial pairs.
_IDEN_SPLIT_URL = "https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/iden_split.txt"
_VERI_TEST_URL = "https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test.txt"
36
+
37
+
38
def _download_extract_wavs(root: str):
    """Download the VoxCeleb1 "dev" and "test" archives into ``root`` and extract them."""
    for subset in ("dev", "test"):
        config = _ARCHIVE_CONFIGS[subset]
        archive_path = os.path.join(root, config["archive_name"])
        if subset == "dev":
            # The dev archive is distributed as four chunks. Download each
            # chunk and concatenate them into one zip file before extraction.
            with open(archive_path, "wb") as combined:
                for url, checksum in zip(config["urls"], config["checksums"]):
                    part_path = os.path.join(root, os.path.basename(url))
                    download_url_to_file(url, part_path, hash_prefix=checksum)
                    with open(part_path, "rb") as part:
                        combined.write(part.read())
        else:
            download_url_to_file(config["url"], archive_path, hash_prefix=config["checksum"])
        _extract_zip(archive_path)
58
+
59
+
60
+ def _get_flist(root: str, file_path: str, subset: str) -> List[str]:
61
+ f_list = []
62
+ if subset == "train":
63
+ index = 1
64
+ elif subset == "dev":
65
+ index = 2
66
+ else:
67
+ index = 3
68
+ with open(file_path, "r") as f:
69
+ for line in f:
70
+ id, path = line.split()
71
+ if int(id) == index:
72
+ f_list.append(path)
73
+ return sorted(f_list)
74
+
75
+
76
+ def _get_paired_flist(root: str, veri_test_path: str):
77
+ f_list = []
78
+ with open(veri_test_path, "r") as f:
79
+ for line in f:
80
+ label, path1, path2 = line.split()
81
+ f_list.append((label, path1, path2))
82
+ return f_list
83
+
84
+
85
+ def _get_file_id(file_path: str, _ext_audio: str):
86
+ speaker_id, youtube_id, utterance_id = file_path.split("/")[-3:]
87
+ utterance_id = utterance_id.replace(_ext_audio, "")
88
+ file_id = "-".join([speaker_id, youtube_id, utterance_id])
89
+ return file_id
90
+
91
+
92
class VoxCeleb1(Dataset):
    """*VoxCeleb1* :cite:`nagrani2017voxceleb` dataset.

    Base class holding the shared path handling and download logic; the
    task-specific subclasses implement item access.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (Default: ``False``).
    """

    _ext_audio = ".wav"

    def __init__(self, root: Union[str, Path], download: bool = False) -> None:
        # Normalize a possible Path object to a plain string.
        root = os.fspath(root)
        self._path = os.path.join(root, "wav")
        if os.path.isdir(self._path):
            return
        if not download:
            raise RuntimeError(
                f"Dataset not found at {self._path}. Please set `download=True` to download the dataset."
            )
        _download_extract_wavs(root)

    def get_metadata(self, n: int):
        # Implemented by task-specific subclasses.
        raise NotImplementedError

    def __getitem__(self, n: int):
        # Implemented by task-specific subclasses.
        raise NotImplementedError

    def __len__(self) -> int:
        # Implemented by task-specific subclasses.
        raise NotImplementedError
122
+
123
+
124
class VoxCeleb1Identification(VoxCeleb1):
    """*VoxCeleb1* :cite:`nagrani2017voxceleb` dataset for speaker identification task.

    Each data sample contains the waveform, sample rate, speaker id, and the file id.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        subset (str, optional): Subset of the dataset to use. Options: ["train", "dev", "test"]. (Default: ``"train"``)
        meta_url (str, optional): The url of meta file that contains the list of subset labels and file paths.
            The format of each row is ``subset file_path". For example: ``1 id10006/nLEBBc9oIFs/00003.wav``.
            ``1``, ``2``, ``3`` mean ``train``, ``dev``, and ``test`` subest, respectively.
            (Default: ``"https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/iden_split.txt"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (Default: ``False``).

    Note:
        The file structure of `VoxCeleb1Identification` dataset is as follows:

        └─ root/

        └─ wav/

        └─ speaker_id folders

        Users who pre-downloaded the ``"vox1_dev_wav.zip"`` and ``"vox1_test_wav.zip"`` files need to move
        the extracted files into the same ``root`` directory.
    """

    def __init__(
        self, root: Union[str, Path], subset: str = "train", meta_url: str = _IDEN_SPLIT_URL, download: bool = False
    ) -> None:
        super().__init__(root, download)
        if subset not in ["train", "dev", "test"]:
            raise ValueError("`subset` must be one of ['train', 'dev', 'test']")
        # Fetch iden_split.txt (the train/dev/test partition) unless cached.
        split_file = os.path.join(root, os.path.basename(meta_url))
        if not os.path.exists(split_file):
            download_url_to_file(meta_url, split_file)
        self._flist = _get_flist(self._path, split_file, subset)

    def get_metadata(self, n: int) -> Tuple[str, int, int, str]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): The index of the sample

        Returns:
            Tuple of the following items;

            str:
                Path to audio
            int:
                Sample rate
            int:
                Speaker ID
            str:
                File ID
        """
        rel_path = self._flist[n]
        file_id = _get_file_id(rel_path, self._ext_audio)
        # Speaker IDs look like "id10006"; strip the "id0" prefix digits.
        speaker = int(file_id.split("-")[0][3:])
        return rel_path, SAMPLE_RATE, speaker, file_id

    def __getitem__(self, n: int) -> Tuple[Tensor, int, int, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            int:
                Speaker ID
            str:
                File ID
        """
        meta = self.get_metadata(n)
        waveform = _load_waveform(self._path, meta[0], meta[1])
        return (waveform,) + meta[1:]

    def __len__(self) -> int:
        return len(self._flist)
213
+
214
+
215
class VoxCeleb1Verification(VoxCeleb1):
    """*VoxCeleb1* :cite:`nagrani2017voxceleb` dataset for speaker verification task.

    Each data sample contains a pair of waveforms, sample rate, the label indicating if they are
    from the same speaker, and the file ids.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        meta_url (str, optional): The url of meta file that contains a list of utterance pairs
            and the corresponding labels. The format of each row is ``label file_path1 file_path2".
            For example: ``1 id10270/x6uYqmx31kE/00001.wav id10270/8jEAjG6SegY/00008.wav``.
            ``1`` means the two utterances are from the same speaker, ``0`` means not.
            (Default: ``"https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test.txt"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (Default: ``False``).

    Note:
        The file structure of `VoxCeleb1Verification` dataset is as follows:

        └─ root/

        └─ wav/

        └─ speaker_id folders

        Users who pre-downloaded the ``"vox1_dev_wav.zip"`` and ``"vox1_test_wav.zip"`` files need to move
        the extracted files into the same ``root`` directory.
    """

    def __init__(self, root: Union[str, Path], meta_url: str = _VERI_TEST_URL, download: bool = False) -> None:
        super().__init__(root, download)
        # Fetch veri_test.txt (trial pairs and labels) unless cached.
        pair_file = os.path.join(root, os.path.basename(meta_url))
        if not os.path.exists(pair_file):
            download_url_to_file(meta_url, pair_file)
        self._flist = _get_paired_flist(self._path, pair_file)

    def get_metadata(self, n: int) -> Tuple[str, str, int, int, str, str]:
        """Get metadata for the n-th sample from the dataset. Returns filepaths instead of waveforms,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): The index of the sample

        Returns:
            Tuple of the following items;

            str:
                Path to audio file of speaker 1
            str:
                Path to audio file of speaker 2
            int:
                Sample rate
            int:
                Label
            str:
                File ID of speaker 1
            str:
                File ID of speaker 2
        """
        label, path_a, path_b = self._flist[n]
        return (
            path_a,
            path_b,
            SAMPLE_RATE,
            int(label),
            _get_file_id(path_a, self._ext_audio),
            _get_file_id(path_b, self._ext_audio),
        )

    def __getitem__(self, n: int) -> Tuple[Tensor, Tensor, int, int, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform of speaker 1
            Tensor:
                Waveform of speaker 2
            int:
                Sample rate
            int:
                Label
            str:
                File ID of speaker 1
            str:
                File ID of speaker 2
        """
        meta = self.get_metadata(n)
        wave_a = _load_waveform(self._path, meta[0], meta[2])
        wave_b = _load_waveform(self._path, meta[1], meta[2])
        return (wave_a, wave_b) + meta[2:]

    def __len__(self) -> int:
        return len(self._flist)
.venv/lib/python3.11/site-packages/torchaudio/datasets/yesno.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Union
4
+
5
+ import torchaudio
6
+ from torch import Tensor
7
+ from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
+ from torchaudio.datasets.utils import _extract_tar
10
+
11
+
12
# Release metadata for the YesNo corpus: directory name inside the archive,
# download URL, and SHA256 prefix used to validate the download.
_RELEASE_CONFIGS = {
    "release1": {
        "folder_in_archive": "waves_yesno",
        "url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz",
        "checksum": "c3f49e0cca421f96b75b41640749167b52118f232498667ca7a5f9416aef8e73",
    }
}
19
+
20
+
21
class YESNO(Dataset):
    """*YesNo* :cite:`YesNo` dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from.
            (default: ``"http://www.openslr.org/resources/1/waves_yesno.tar.gz"``)
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"waves_yesno"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
    """

    def __init__(
        self,
        root: Union[str, Path],
        url: str = _RELEASE_CONFIGS["release1"]["url"],
        folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
        download: bool = False,
    ) -> None:

        self._parse_filesystem(root, url, folder_in_archive, download)

    def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None:
        # Locate (and optionally download/extract) the dataset directory.
        root = Path(root)
        archive = root / os.path.basename(url)
        self._path = root / folder_in_archive

        if download and not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                checksum = _RELEASE_CONFIGS["release1"]["checksum"]
                download_url_to_file(url, archive, hash_prefix=checksum)
            _extract_tar(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found. Please use `download=True` to download it.")

        # One entry per wav file; the stem encodes the yes/no labels.
        self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*.wav"))

    def _load_item(self, fileid: str, path: str):
        # File stems look like "0_1_0_...": one 0/1 digit per spoken word.
        labels = [int(digit) for digit in fileid.split("_")]
        audio_file = os.path.join(path, fileid + ".wav")
        waveform, sample_rate = torchaudio.load(audio_file)
        return waveform, sample_rate, labels

    def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            List[int]:
                labels
        """
        return self._load_item(self._walker[n], self._path)

    def __len__(self) -> int:
        return len(self._walker)
.venv/lib/python3.11/site-packages/torchaudio/io/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# torio provides the underlying streaming decoder/encoder; re-export them
# under the historical torchaudio names ``StreamReader``/``StreamWriter``.
from torio.io import CodecConfig, StreamingMediaDecoder as StreamReader, StreamingMediaEncoder as StreamWriter

from ._effector import AudioEffector
from ._playback import play_audio


# Public API of ``torchaudio.io``.
__all__ = [
    "AudioEffector",
    "StreamReader",
    "StreamWriter",
    "CodecConfig",
    "play_audio",
]
.venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (535 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/_effector.cpython-311.pyc ADDED
Binary file (15.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/io/__pycache__/_playback.cpython-311.pyc ADDED
Binary file (3.69 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/io/_effector.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from typing import Iterator, List, Optional
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ from torio.io._streaming_media_decoder import _get_afilter_desc, StreamingMediaDecoder as StreamReader
8
+ from torio.io._streaming_media_encoder import CodecConfig, StreamingMediaEncoder as StreamWriter
9
+
10
+
11
+ class _StreamingIOBuffer:
12
+ """Streaming Bytes IO buffer. Data are dropped when read."""
13
+
14
+ def __init__(self):
15
+ self._buffer: List(bytes) = []
16
+
17
+ def write(self, b: bytes):
18
+ if b:
19
+ self._buffer.append(b)
20
+ return len(b)
21
+
22
+ def pop(self, n):
23
+ """Pop the oldest byte string. It does not necessary return the requested amount"""
24
+ if not self._buffer:
25
+ return b""
26
+ if len(self._buffer[0]) <= n:
27
+ return self._buffer.pop(0)
28
+ ret = self._buffer[0][:n]
29
+ self._buffer[0] = self._buffer[0][n:]
30
+ return ret
31
+
32
+
33
+ def _get_sample_fmt(dtype: torch.dtype):
34
+ types = {
35
+ torch.uint8: "u8",
36
+ torch.int16: "s16",
37
+ torch.int32: "s32",
38
+ torch.float32: "flt",
39
+ torch.float64: "dbl",
40
+ }
41
+ if dtype not in types:
42
+ raise ValueError(f"Unsupported dtype is provided {dtype}. Supported dtypes are: {types.keys()}")
43
+ return types[dtype]
44
+
45
+
46
class _AudioStreamingEncoder:
    """Given a waveform, encode on-demand and return bytes.

    Acts as a read-only file-like object: each ``read`` call encodes just
    enough of the source tensor to produce data, so the whole encoded
    stream never has to be materialized at once.
    """

    def __init__(
        self,
        src: Tensor,
        sample_rate: int,
        effect: str,
        muxer: str,
        encoder: Optional[str],
        codec_config: Optional[CodecConfig],
        frames_per_chunk: int,
    ):
        self.src = src
        # Encoded bytes are written into this FIFO buffer and drained by read().
        self.buffer = _StreamingIOBuffer()
        self.writer = StreamWriter(self.buffer, format=muxer)
        self.writer.add_audio_stream(
            num_channels=src.size(1),
            sample_rate=sample_rate,
            format=_get_sample_fmt(src.dtype),
            encoder=encoder,
            filter_desc=effect,
            codec_config=codec_config,
        )
        self.writer.open()
        self.fpc = frames_per_chunk

        # index on the input tensor (along time-axis)
        # we use -1 to indicate that we finished iterating the tensor and
        # the writer is closed.
        self.i_iter = 0

    def read(self, n):
        # Encode chunks until the buffer has data or the source is exhausted.
        while not self.buffer._buffer and self.i_iter >= 0:
            self.writer.write_audio_chunk(0, self.src[self.i_iter : self.i_iter + self.fpc])
            self.i_iter += self.fpc
            if self.i_iter >= self.src.size(0):
                # Source fully consumed: flush remaining encoder state and
                # mark the stream finished.
                self.writer.flush()
                self.writer.close()
                self.i_iter = -1
        return self.buffer.pop(n)
87
+
88
+
89
def _encode(
    src: Tensor,
    sample_rate: int,
    effect: str,
    muxer: str,
    encoder: Optional[str],
    codec_config: Optional[CodecConfig],
):
    """Encode the whole waveform into an in-memory file object and rewind it."""
    encoded = io.BytesIO()
    writer = StreamWriter(encoded, format=muxer)
    writer.add_audio_stream(
        num_channels=src.size(1),
        sample_rate=sample_rate,
        format=_get_sample_fmt(src.dtype),
        encoder=encoder,
        filter_desc=effect,
        codec_config=codec_config,
    )
    with writer.open():
        writer.write_audio_chunk(0, src)
    # Rewind so the caller can read the encoded stream from the beginning.
    encoded.seek(0)
    return encoded
111
+
112
+
113
+ def _get_muxer(dtype: torch.dtype):
114
+ # TODO: check if this works in Windows.
115
+ types = {
116
+ torch.uint8: "u8",
117
+ torch.int16: "s16le",
118
+ torch.int32: "s32le",
119
+ torch.float32: "f32le",
120
+ torch.float64: "f64le",
121
+ }
122
+ if dtype not in types:
123
+ raise ValueError(f"Unsupported dtype is provided {dtype}. Supported dtypes are: {types.keys()}")
124
+ return types[dtype]
125
+
126
+
127
class AudioEffector:
    """Apply various filters and/or codecs to waveforms.

    .. versionadded:: 2.1

    Args:
        effect (str or None, optional): Filter expressions or ``None`` to apply no filter.
            See https://ffmpeg.org/ffmpeg-filters.html#Audio-Filters for the
            details of filter syntax.

        format (str or None, optional): When provided, encode the audio into the
            corresponding format. Default: ``None``.

        encoder (str or None, optional): When provided, override the encoder used
            by the ``format``. Default: ``None``.

        codec_config (CodecConfig or None, optional): When provided, configure the encoding codec.
            Should be provided in conjunction with ``format`` option.

        pad_end (bool, optional): When enabled, and if the waveform becomes shorter after applying
            effects/codec, then pad the end with silence.

    Example - Basic usage
        To use ``AudioEffector``, first instantiate it with a set of
        ``effect`` and ``format``.

        >>> # instantiate the effector
        >>> effector = AudioEffector(effect=..., format=...)

        Then, use :py:meth:`~AudioEffector.apply` or :py:meth:`~AudioEffector.stream`
        method to apply them.

        >>> # Apply the effect to the whole waveform
        >>> applied = effector.apply(waveform, sample_rate)

        >>> # Apply the effect chunk-by-chunk
        >>> for chunk in effector.stream(waveform, sample_rate):
        >>>     ...

    Example - Applying effects
        Please refer to
        https://ffmpeg.org/ffmpeg-filters.html#Filtergraph-description
        for the overview of filter description, and
        https://ffmpeg.org/ffmpeg-filters.html#toc-Audio-Filters
        for the list of available filters.

        Tempo - https://ffmpeg.org/ffmpeg-filters.html#atempo

        >>> AudioEffector(effect="atempo=1.5")

        Echo - https://ffmpeg.org/ffmpeg-filters.html#aecho

        >>> AudioEffector(effect="aecho=0.8:0.88:60:0.4")

        Flanger - https://ffmpeg.org/ffmpeg-filters.html#flanger

        >>> AudioEffector(effect="aflanger")

        Vibrato - https://ffmpeg.org/ffmpeg-filters.html#vibrato

        >>> AudioEffector(effect="vibrato")

        Tremolo - https://ffmpeg.org/ffmpeg-filters.html#tremolo

        >>> AudioEffector(effect="tremolo")

        You can also apply multiple effects at once.

        >>> AudioEffector(effect="atempo=1.5,aecho=0.8:0.88:60:0.4")

    Example - Applying codec
        One can apply codec using ``format`` argument. ``format`` can be
        audio format or container format. If the container format supports
        multiple encoders, you can specify it with ``encoder`` argument.

        Wav format
        (no compression is applied but samples are converted to
        16-bit signed integer)

        >>> AudioEffector(format="wav")

        Ogg format with default encoder

        >>> AudioEffector(format="ogg")

        Ogg format with vorbis

        >>> AudioEffector(format="ogg", encoder="vorbis")

        Ogg format with opus

        >>> AudioEffector(format="ogg", encoder="opus")

        Webm format with opus

        >>> AudioEffector(format="webm", encoder="opus")

    Example - Applying codec with configuration
        Reference: https://trac.ffmpeg.org/wiki/Encode/MP3

        MP3 with default config

        >>> AudioEffector(format="mp3")

        MP3 with variable bitrate

        >>> AudioEffector(format="mp3", codec_config=CodecConfig(qscale=5))

        MP3 with constant bitrate

        >>> AudioEffector(format="mp3", codec_config=CodecConfig(bit_rate=32_000))
    """

    def __init__(
        self,
        effect: Optional[str] = None,
        format: Optional[str] = None,
        *,
        encoder: Optional[str] = None,
        codec_config: Optional[CodecConfig] = None,
        pad_end: bool = True,
    ):
        if format is None:
            # encoder/codec_config only make sense when a format is requested.
            # Fix: the original message misspelled "codec_config" and "options".
            if encoder is not None or codec_config is not None:
                raise ValueError("`encoder` and/or `codec_config` options are provided without `format` option.")
        self.effect = effect
        self.format = format
        self.encoder = encoder
        self.codec_config = codec_config
        self.pad_end = pad_end

    def _get_reader(self, waveform, sample_rate, output_sample_rate, frames_per_chunk=None):
        """Build a StreamReader that decodes the (effect-applied, encoded) waveform."""
        num_frames, num_channels = waveform.shape

        if self.format is not None:
            muxer = self.format
            encoder = self.encoder
            option = {}
            # Some formats are headerless, so need to provide these information.
            if self.format == "mulaw":
                option = {"sample_rate": f"{sample_rate}", "channels": f"{num_channels}"}

        else:  # PCM
            muxer = _get_muxer(waveform.dtype)
            encoder = None
            option = {"sample_rate": f"{sample_rate}", "channels": f"{num_channels}"}

        if frames_per_chunk is None:
            # One-shot: encode the whole waveform into an in-memory file.
            src = _encode(waveform, sample_rate, self.effect, muxer, encoder, self.codec_config)
        else:
            # Streaming: encode lazily as the reader pulls bytes.
            src = _AudioStreamingEncoder(
                waveform, sample_rate, self.effect, muxer, encoder, self.codec_config, frames_per_chunk
            )

        output_sr = sample_rate if output_sample_rate is None else output_sample_rate
        filter_desc = _get_afilter_desc(output_sr, _get_sample_fmt(waveform.dtype), num_channels)
        if self.pad_end:
            # Pad with silence so the output is at least as long as the input.
            filter_desc = f"{filter_desc},apad=whole_len={num_frames}"

        reader = StreamReader(src, format=muxer, option=option)
        reader.add_audio_stream(frames_per_chunk or -1, -1, filter_desc=filter_desc)
        return reader

    def apply(self, waveform: Tensor, sample_rate: int, output_sample_rate: Optional[int] = None) -> Tensor:
        """Apply the effect and/or codecs to the whole tensor.

        Args:
            waveform (Tensor): The input waveform. Shape: ``(time, channel)``
            sample_rate (int): Sample rate of the input waveform.
            output_sample_rate (int or None, optional): Output sample rate.
                If provided, override the output sample rate.
                Otherwise, the resulting tensor is resampled to have
                the same sample rate as the input.
                Default: ``None``.

        Returns:
            Tensor:
                Resulting Tensor. Shape: ``(time, channel)``. The number of frames
                could be different from that of the input.
        """
        if waveform.ndim != 2:
            raise ValueError(f"Expected the input waveform to be 2D. Found: {waveform.ndim}")

        # Nothing to process for an empty waveform.
        if waveform.numel() == 0:
            return waveform

        reader = self._get_reader(waveform, sample_rate, output_sample_rate)
        reader.process_all_packets()
        (applied,) = reader.pop_chunks()
        return Tensor(applied)

    def stream(
        self, waveform: Tensor, sample_rate: int, frames_per_chunk: int, output_sample_rate: Optional[int] = None
    ) -> Iterator[Tensor]:
        """Apply the effect and/or codecs to the given tensor chunk by chunk.

        Args:
            waveform (Tensor): The input waveform. Shape: ``(time, channel)``
            sample_rate (int): Sample rate of the waveform.
            frames_per_chunk (int): The number of frames to return at a time.
            output_sample_rate (int or None, optional): Output sample rate.
                If provided, override the output sample rate.
                Otherwise, the resulting tensor is resampled to have
                the same sample rate as the input.
                Default: ``None``.

        Returns:
            Iterator[Tensor]:
                Series of processed chunks. Shape: ``(time, channel)``, where the
                the number of frames matches ``frames_per_chunk`` except the
                last chunk, which could be shorter.
        """
        if waveform.ndim != 2:
            raise ValueError(f"Expected the input waveform to be 2D. Found: {waveform.ndim}")

        # Empty input: the generator yields nothing.
        if waveform.numel() == 0:
            return waveform

        reader = self._get_reader(waveform, sample_rate, output_sample_rate, frames_per_chunk)
        for (applied,) in reader.stream():
            yield Tensor(applied)
.venv/lib/python3.11/site-packages/torchaudio/io/_playback.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from sys import platform
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torchaudio
7
+
8
# Mapping from torch dtypes to the sample-format strings passed to
# ``StreamWriter.add_audio_stream`` (``format=``) in ``play_audio`` below.
# Only these dtypes are accepted; anything else raises ValueError.
dict_format = {
    torch.uint8: "u8",
    torch.int16: "s16",
    torch.int32: "s32",
    torch.int64: "s64",
    torch.float32: "flt",
    torch.float64: "dbl",
}
16
+
17
+
18
def play_audio(
    waveform: torch.Tensor,
    sample_rate: Optional[float],
    device: Optional[str] = None,
) -> None:
    """Plays audio through specified or available output device.

    .. warning::
        This function is currently only supported on MacOS, and requires
        libavdevice (FFmpeg) with ``audiotoolbox`` output device.

    .. note::
        This function can play up to two audio channels.

    Args:
        waveform: Tensor containing the audio to play.
            Expected shape: `(time, num_channels)`.
        sample_rate: Sample rate of the audio to play.
        device: Output device to use. If None, the default device is used.

    Raises:
        ValueError: If the current OS is not MacOS, if the requested output
            device is not available, if ``waveform`` has an unsupported
            dtype, or if ``waveform`` is not a 2D tensor.
    """

    if platform == "darwin":
        device = device or "audiotoolbox"
        path = "-"
    else:
        raise ValueError(f"This function only supports MacOS, but current OS is {platform}")

    available_devices = list(torchaudio.utils.ffmpeg_utils.get_output_devices().keys())
    if device not in available_devices:
        raise ValueError(f"Device {device} is not available. Available devices are: {available_devices}")

    if waveform.dtype not in dict_format:
        raise ValueError(f"Unsupported type {waveform.dtype}. The list of supported types is: {dict_format.keys()}")
    # Renamed from ``format`` so the builtin of the same name is not shadowed.
    sample_format = dict_format[waveform.dtype]

    if waveform.ndim != 2:
        raise ValueError(f"Expected 2D tensor with shape `(time, num_channels)`, got {waveform.ndim}D tensor instead")

    # Renamed from ``time`` to avoid confusion with the stdlib module name.
    num_frames, num_channels = waveform.size()
    if num_channels > 2:
        # NOTE(review): the stream below is still configured with the full
        # ``num_channels`` — whether the device truly plays only the first
        # two channels depends on the output device; verify against FFmpeg.
        warnings.warn(
            f"Expected up to 2 channels, got {num_channels} channels instead. "
            "Only the first 2 channels will be played.",
            stacklevel=2,
        )

    # Write to speaker device
    s = torchaudio.io.StreamWriter(dst=path, format=device)
    s.add_audio_stream(sample_rate, num_channels, format=sample_format)

    # Write audio to the device in small blocks so large tensors are not
    # pushed into the device buffer all at once.
    block_size = 256
    with s.open():
        for start in range(0, num_frames, block_size):
            s.write_audio_chunk(0, waveform[start : start + block_size, :])
.venv/lib/python3.11/site-packages/torchaudio/lib/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torchaudio/lib/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (187 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.25 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/_hdemucs.cpython-311.pyc ADDED
Binary file (51.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/conformer.cpython-311.pyc ADDED
Binary file (14.6 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/conv_tasnet.cpython-311.pyc ADDED
Binary file (15.1 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/deepspeech.cpython-311.pyc ADDED
Binary file (4.91 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/emformer.cpython-311.pyc ADDED
Binary file (49.5 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/rnnt.cpython-311.pyc ADDED
Binary file (41.9 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/rnnt_decoder.cpython-311.pyc ADDED
Binary file (20.6 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/tacotron2.cpython-311.pyc ADDED
Binary file (49.9 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/wav2letter.cpython-311.pyc ADDED
Binary file (4.41 kB). View file
 
.venv/lib/python3.11/site-packages/torchaudio/models/__pycache__/wavernn.cpython-311.pyc ADDED
Binary file (22.9 kB). View file