prasb committed on
Commit
610facb
·
verified ·
1 Parent(s): b5147e7

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATN.cpython-38.pyc +0 -0
  2. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNConfig.cpython-38.pyc +0 -0
  3. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNConfigSet.cpython-38.pyc +0 -0
  4. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNDeserializationOptions.cpython-38.pyc +0 -0
  5. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNDeserializer.cpython-38.pyc +0 -0
  6. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNSimulator.cpython-38.pyc +0 -0
  7. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNState.cpython-38.pyc +0 -0
  8. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNType.cpython-38.pyc +0 -0
  9. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/LexerATNSimulator.cpython-38.pyc +0 -0
  10. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/LexerAction.cpython-38.pyc +0 -0
  11. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/LexerActionExecutor.cpython-38.pyc +0 -0
  12. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ParserATNSimulator.cpython-38.pyc +0 -0
  13. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/PredictionMode.cpython-38.pyc +0 -0
  14. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/SemanticContext.cpython-38.pyc +0 -0
  15. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/Transition.cpython-38.pyc +0 -0
  16. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/__init__.cpython-38.pyc +0 -0
  17. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/dfa/__pycache__/DFA.cpython-38.pyc +0 -0
  18. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/dfa/__pycache__/DFASerializer.cpython-38.pyc +0 -0
  19. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/dfa/__pycache__/DFAState.cpython-38.pyc +0 -0
  20. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/dfa/__pycache__/__init__.cpython-38.pyc +0 -0
  21. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/error/__pycache__/DiagnosticErrorListener.cpython-38.pyc +0 -0
  22. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/error/__pycache__/ErrorListener.cpython-38.pyc +0 -0
  23. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/error/__pycache__/ErrorStrategy.cpython-38.pyc +0 -0
  24. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/error/__pycache__/Errors.cpython-38.pyc +0 -0
  25. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/error/__pycache__/__init__.cpython-38.pyc +0 -0
  26. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/tree/__pycache__/ParseTreeMatch.cpython-38.pyc +0 -0
  27. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/tree/__pycache__/ParseTreePattern.cpython-38.pyc +0 -0
  28. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/tree/__pycache__/ParseTreePatternMatcher.cpython-38.pyc +0 -0
  29. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/tree/__pycache__/TokenTagToken.cpython-38.pyc +0 -0
  30. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/tree/__pycache__/Trees.cpython-38.pyc +0 -0
  31. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/tree/__pycache__/__init__.cpython-38.pyc +0 -0
  32. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/xpath/__pycache__/XPath.cpython-38.pyc +0 -0
  33. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/xpath/__pycache__/__init__.cpython-38.pyc +0 -0
  34. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/matplotlib/mpl-data/sample_data/goog.npz +3 -0
  35. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/matplotlib/mpl-data/sample_data/topobathy.npz +3 -0
  36. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/__pycache__/__init__.cpython-38.pyc +0 -0
  37. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/__pycache__/_extension.cpython-38.pyc +0 -0
  38. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/__pycache__/kaldi_io.cpython-38.pyc +0 -0
  39. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/__pycache__/version.cpython-38.pyc +0 -0
  40. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/compliance/__init__.py +5 -0
  41. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/compliance/kaldi.py +815 -0
  42. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/__init__.py +34 -0
  43. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/cmuarctic.py +148 -0
  44. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/cmudict.py +183 -0
  45. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/commonvoice.py +71 -0
  46. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/dr_vctk.py +106 -0
  47. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/gtzan.py +1108 -0
  48. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/librilight_limited.py +91 -0
  49. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/librimix.py +85 -0
  50. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/librispeech.py +135 -0
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATN.cpython-38.pyc ADDED
Binary file (3.13 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNConfig.cpython-38.pyc ADDED
Binary file (4.14 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNConfigSet.cpython-38.pyc ADDED
Binary file (6.22 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNDeserializationOptions.cpython-38.pyc ADDED
Binary file (1 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNDeserializer.cpython-38.pyc ADDED
Binary file (15.8 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNSimulator.cpython-38.pyc ADDED
Binary file (1.15 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNState.cpython-38.pyc ADDED
Binary file (6.68 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ATNType.cpython-38.pyc ADDED
Binary file (575 Bytes). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/LexerATNSimulator.cpython-38.pyc ADDED
Binary file (11.7 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/LexerAction.cpython-38.pyc ADDED
Binary file (8.59 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/LexerActionExecutor.cpython-38.pyc ADDED
Binary file (2.53 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/ParserATNSimulator.cpython-38.pyc ADDED
Binary file (24.5 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/PredictionMode.cpython-38.pyc ADDED
Binary file (5.07 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/SemanticContext.cpython-38.pyc ADDED
Binary file (7.36 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/Transition.cpython-38.pyc ADDED
Binary file (9.62 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/atn/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (189 Bytes). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/dfa/__pycache__/DFA.cpython-38.pyc ADDED
Binary file (3.13 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/dfa/__pycache__/DFASerializer.cpython-38.pyc ADDED
Binary file (2.48 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/dfa/__pycache__/DFAState.cpython-38.pyc ADDED
Binary file (2.33 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/dfa/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (189 Bytes). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/error/__pycache__/DiagnosticErrorListener.cpython-38.pyc ADDED
Binary file (2.9 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/error/__pycache__/ErrorListener.cpython-38.pyc ADDED
Binary file (2.78 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/error/__pycache__/ErrorStrategy.cpython-38.pyc ADDED
Binary file (9.89 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/error/__pycache__/Errors.cpython-38.pyc ADDED
Binary file (5.03 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/error/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (191 Bytes). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/tree/__pycache__/ParseTreeMatch.cpython-38.pyc ADDED
Binary file (1.8 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/tree/__pycache__/ParseTreePattern.cpython-38.pyc ADDED
Binary file (1.45 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/tree/__pycache__/ParseTreePatternMatcher.cpython-38.pyc ADDED
Binary file (7.75 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/tree/__pycache__/TokenTagToken.cpython-38.pyc ADDED
Binary file (1.05 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/tree/__pycache__/Trees.cpython-38.pyc ADDED
Binary file (3.41 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/tree/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (163 Bytes). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/xpath/__pycache__/XPath.cpython-38.pyc ADDED
Binary file (10.8 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/antlr4/xpath/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (191 Bytes). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/matplotlib/mpl-data/sample_data/goog.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:400917cf30e6b664f7b0da93d7c745860d3aa9008da8b7f160d2dd12e6a318b1
3
+ size 22845
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/matplotlib/mpl-data/sample_data/topobathy.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0244e03291702df45024dcb5cacbc4f3d4cb30d72dfa7fd371c4ac61c42b4fbf
3
+ size 45224
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (761 Bytes). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/__pycache__/_extension.cpython-38.pyc ADDED
Binary file (3.43 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/__pycache__/kaldi_io.cpython-38.pyc ADDED
Binary file (4.47 kB). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/__pycache__/version.cpython-38.pyc ADDED
Binary file (250 Bytes). View file
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/compliance/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from . import kaldi
2
+
3
+ __all__ = [
4
+ "kaldi",
5
+ ]
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/compliance/kaldi.py ADDED
@@ -0,0 +1,815 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torchaudio
6
+ from torch import Tensor
7
+
8
+ __all__ = [
9
+ "get_mel_banks",
10
+ "inverse_mel_scale",
11
+ "inverse_mel_scale_scalar",
12
+ "mel_scale",
13
+ "mel_scale_scalar",
14
+ "spectrogram",
15
+ "fbank",
16
+ "mfcc",
17
+ "vtln_warp_freq",
18
+ "vtln_warp_mel_freq",
19
+ ]
20
+
21
+ # numeric_limits<float>::epsilon() 1.1920928955078125e-07
22
+ EPSILON = torch.tensor(torch.finfo(torch.float).eps)
23
+ # 1 milliseconds = 0.001 seconds
24
+ MILLISECONDS_TO_SECONDS = 0.001
25
+
26
+ # window types
27
+ HAMMING = "hamming"
28
+ HANNING = "hanning"
29
+ POVEY = "povey"
30
+ RECTANGULAR = "rectangular"
31
+ BLACKMAN = "blackman"
32
+ WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
33
+
34
+
35
+ def _get_epsilon(device, dtype):
36
+ return EPSILON.to(device=device, dtype=dtype)
37
+
38
+
39
+ def _next_power_of_2(x: int) -> int:
40
+ r"""Returns the smallest power of 2 that is greater than x"""
41
+ return 1 if x == 0 else 2 ** (x - 1).bit_length()
42
+
43
+
44
+ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
45
+ r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
46
+ representing how the window is shifted along the waveform. Each row is a frame.
47
+
48
+ Args:
49
+ waveform (Tensor): Tensor of size ``num_samples``
50
+ window_size (int): Frame length
51
+ window_shift (int): Frame shift
52
+ snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
53
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
54
+ depends only on the frame_shift, and we reflect the data at the ends.
55
+
56
+ Returns:
57
+ Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
58
+ """
59
+ assert waveform.dim() == 1
60
+ num_samples = waveform.size(0)
61
+ strides = (window_shift * waveform.stride(0), waveform.stride(0))
62
+
63
+ if snip_edges:
64
+ if num_samples < window_size:
65
+ return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
66
+ else:
67
+ m = 1 + (num_samples - window_size) // window_shift
68
+ else:
69
+ reversed_waveform = torch.flip(waveform, [0])
70
+ m = (num_samples + (window_shift // 2)) // window_shift
71
+ pad = window_size // 2 - window_shift // 2
72
+ pad_right = reversed_waveform
73
+ if pad > 0:
74
+ # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
75
+ # but we want [2, 1, 0, 0, 1, 2]
76
+ pad_left = reversed_waveform[-pad:]
77
+ waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
78
+ else:
79
+ # pad is negative so we want to trim the waveform at the front
80
+ waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
81
+
82
+ sizes = (m, window_size)
83
+ return waveform.as_strided(sizes, strides)
84
+
85
+
86
+ def _feature_window_function(
87
+ window_type: str,
88
+ window_size: int,
89
+ blackman_coeff: float,
90
+ device: torch.device,
91
+ dtype: int,
92
+ ) -> Tensor:
93
+ r"""Returns a window function with the given type and size"""
94
+ if window_type == HANNING:
95
+ return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
96
+ elif window_type == HAMMING:
97
+ return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
98
+ elif window_type == POVEY:
99
+ # like hanning but goes to zero at edges
100
+ return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
101
+ elif window_type == RECTANGULAR:
102
+ return torch.ones(window_size, device=device, dtype=dtype)
103
+ elif window_type == BLACKMAN:
104
+ a = 2 * math.pi / (window_size - 1)
105
+ window_function = torch.arange(window_size, device=device, dtype=dtype)
106
+ # can't use torch.blackman_window as they use different coefficients
107
+ return (
108
+ blackman_coeff
109
+ - 0.5 * torch.cos(a * window_function)
110
+ + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
111
+ ).to(device=device, dtype=dtype)
112
+ else:
113
+ raise Exception("Invalid window type " + window_type)
114
+
115
+
116
+ def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
117
+ r"""Returns the log energy of size (m) for a strided_input (m,*)"""
118
+ device, dtype = strided_input.device, strided_input.dtype
119
+ log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
120
+ if energy_floor == 0.0:
121
+ return log_energy
122
+ return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
123
+
124
+
125
+ def _get_waveform_and_window_properties(
126
+ waveform: Tensor,
127
+ channel: int,
128
+ sample_frequency: float,
129
+ frame_shift: float,
130
+ frame_length: float,
131
+ round_to_power_of_two: bool,
132
+ preemphasis_coefficient: float,
133
+ ) -> Tuple[Tensor, int, int, int]:
134
+ r"""Gets the waveform and window properties"""
135
+ channel = max(channel, 0)
136
+ assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
137
+ waveform = waveform[channel, :] # size (n)
138
+ window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
139
+ window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
140
+ padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
141
+
142
+ assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
143
+ window_size, len(waveform)
144
+ )
145
+ assert 0 < window_shift, "`window_shift` must be greater than 0"
146
+ assert padded_window_size % 2 == 0, (
147
+ "the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
148
+ )
149
+ assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
150
+ assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
151
+ return waveform, window_shift, window_size, padded_window_size
152
+
153
+
154
+ def _get_window(
155
+ waveform: Tensor,
156
+ padded_window_size: int,
157
+ window_size: int,
158
+ window_shift: int,
159
+ window_type: str,
160
+ blackman_coeff: float,
161
+ snip_edges: bool,
162
+ raw_energy: bool,
163
+ energy_floor: float,
164
+ dither: float,
165
+ remove_dc_offset: bool,
166
+ preemphasis_coefficient: float,
167
+ ) -> Tuple[Tensor, Tensor]:
168
+ r"""Gets a window and its log energy
169
+
170
+ Returns:
171
+ (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
172
+ """
173
+ device, dtype = waveform.device, waveform.dtype
174
+ epsilon = _get_epsilon(device, dtype)
175
+
176
+ # size (m, window_size)
177
+ strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
178
+
179
+ if dither != 0.0:
180
+ # Returns a random number strictly between 0 and 1
181
+ x = torch.max(epsilon, torch.rand(strided_input.shape, device=device, dtype=dtype))
182
+ rand_gauss = torch.sqrt(-2 * x.log()) * torch.cos(2 * math.pi * x)
183
+ strided_input = strided_input + rand_gauss * dither
184
+
185
+ if remove_dc_offset:
186
+ # Subtract each row/frame by its mean
187
+ row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
188
+ strided_input = strided_input - row_means
189
+
190
+ if raw_energy:
191
+ # Compute the log energy of each row/frame before applying preemphasis and
192
+ # window function
193
+ signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
194
+
195
+ if preemphasis_coefficient != 0.0:
196
+ # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
197
+ offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
198
+ 0
199
+ ) # size (m, window_size + 1)
200
+ strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
201
+
202
+ # Apply window_function to each row/frame
203
+ window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
204
+ 0
205
+ ) # size (1, window_size)
206
+ strided_input = strided_input * window_function # size (m, window_size)
207
+
208
+ # Pad columns with zero until we reach size (m, padded_window_size)
209
+ if padded_window_size != window_size:
210
+ padding_right = padded_window_size - window_size
211
+ strided_input = torch.nn.functional.pad(
212
+ strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
213
+ ).squeeze(0)
214
+
215
+ # Compute energy after window function (not the raw one)
216
+ if not raw_energy:
217
+ signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
218
+
219
+ return strided_input, signal_log_energy
220
+
221
+
222
+ def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
223
+ # subtracts the column mean of the tensor size (m, n) if subtract_mean=True
224
+ # it returns size (m, n)
225
+ if subtract_mean:
226
+ col_means = torch.mean(tensor, dim=0).unsqueeze(0)
227
+ tensor = tensor - col_means
228
+ return tensor
229
+
230
+
231
+ def spectrogram(
232
+ waveform: Tensor,
233
+ blackman_coeff: float = 0.42,
234
+ channel: int = -1,
235
+ dither: float = 0.0,
236
+ energy_floor: float = 1.0,
237
+ frame_length: float = 25.0,
238
+ frame_shift: float = 10.0,
239
+ min_duration: float = 0.0,
240
+ preemphasis_coefficient: float = 0.97,
241
+ raw_energy: bool = True,
242
+ remove_dc_offset: bool = True,
243
+ round_to_power_of_two: bool = True,
244
+ sample_frequency: float = 16000.0,
245
+ snip_edges: bool = True,
246
+ subtract_mean: bool = False,
247
+ window_type: str = POVEY,
248
+ ) -> Tensor:
249
+ r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
250
+ compute-spectrogram-feats.
251
+
252
+ Args:
253
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
254
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
255
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
256
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
257
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
258
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
259
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
260
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
261
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
262
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
263
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
264
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
265
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
266
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
267
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
268
+ to FFT. (Default: ``True``)
269
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
270
+ specified there) (Default: ``16000.0``)
271
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
272
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
273
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
274
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
275
+ it this way. (Default: ``False``)
276
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
277
+ (Default: ``'povey'``)
278
+
279
+ Returns:
280
+ Tensor: A spectrogram identical to what Kaldi would output. The shape is
281
+ (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
282
+ """
283
+ device, dtype = waveform.device, waveform.dtype
284
+ epsilon = _get_epsilon(device, dtype)
285
+
286
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
287
+ waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
288
+ )
289
+
290
+ if len(waveform) < min_duration * sample_frequency:
291
+ # signal is too short
292
+ return torch.empty(0)
293
+
294
+ strided_input, signal_log_energy = _get_window(
295
+ waveform,
296
+ padded_window_size,
297
+ window_size,
298
+ window_shift,
299
+ window_type,
300
+ blackman_coeff,
301
+ snip_edges,
302
+ raw_energy,
303
+ energy_floor,
304
+ dither,
305
+ remove_dc_offset,
306
+ preemphasis_coefficient,
307
+ )
308
+
309
+ # size (m, padded_window_size // 2 + 1, 2)
310
+ fft = torch.fft.rfft(strided_input)
311
+
312
+ # Convert the FFT into a power spectrum
313
+ power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
314
+ power_spectrum[:, 0] = signal_log_energy
315
+
316
+ power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
317
+ return power_spectrum
318
+
319
+
320
+ def inverse_mel_scale_scalar(mel_freq: float) -> float:
321
+ return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
322
+
323
+
324
+ def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
325
+ return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
326
+
327
+
328
+ def mel_scale_scalar(freq: float) -> float:
329
+ return 1127.0 * math.log(1.0 + freq / 700.0)
330
+
331
+
332
+ def mel_scale(freq: Tensor) -> Tensor:
333
+ return 1127.0 * (1.0 + freq / 700.0).log()
334
+
335
+
336
+ def vtln_warp_freq(
337
+ vtln_low_cutoff: float,
338
+ vtln_high_cutoff: float,
339
+ low_freq: float,
340
+ high_freq: float,
341
+ vtln_warp_factor: float,
342
+ freq: Tensor,
343
+ ) -> Tensor:
344
+ r"""This computes a VTLN warping function that is not the same as HTK's one,
345
+ but has similar inputs (this function has the advantage of never producing
346
+ empty bins).
347
+
348
+ This function computes a warp function F(freq), defined between low_freq
349
+ and high_freq inclusive, with the following properties:
350
+ F(low_freq) == low_freq
351
+ F(high_freq) == high_freq
352
+ The function is continuous and piecewise linear with two inflection
353
+ points.
354
+ The lower inflection point (measured in terms of the unwarped
355
+ frequency) is at frequency l, determined as described below.
356
+ The higher inflection point is at a frequency h, determined as
357
+ described below.
358
+ If l <= f <= h, then F(f) = f/vtln_warp_factor.
359
+ If the higher inflection point (measured in terms of the unwarped
360
+ frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
361
+ Since (by the last point) F(h) == h/vtln_warp_factor, then
362
+ max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
363
+ h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
364
+ = vtln_high_cutoff * min(1, vtln_warp_factor).
365
+ If the lower inflection point (measured in terms of the unwarped
366
+ frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
367
+ This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
368
+ = vtln_low_cutoff * max(1, vtln_warp_factor)
369
+ Args:
370
+ vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
371
+ vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
372
+ low_freq (float): Lower frequency cutoffs in mel computation
373
+ high_freq (float): Upper frequency cutoffs in mel computation
374
+ vtln_warp_factor (float): Vtln warp factor
375
+ freq (Tensor): given frequency in Hz
376
+
377
+ Returns:
378
+ Tensor: Freq after vtln warp
379
+ """
380
+ assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
381
+ assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
382
+ l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
383
+ h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
384
+ scale = 1.0 / vtln_warp_factor
385
+ Fl = scale * l # F(l)
386
+ Fh = scale * h # F(h)
387
+ assert l > low_freq and h < high_freq
388
+ # slope of left part of the 3-piece linear function
389
+ scale_left = (Fl - low_freq) / (l - low_freq)
390
+ # [slope of center part is just "scale"]
391
+
392
+ # slope of right part of the 3-piece linear function
393
+ scale_right = (high_freq - Fh) / (high_freq - h)
394
+
395
+ res = torch.empty_like(freq)
396
+
397
+ outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq
398
+ before_l = torch.lt(freq, l) # freq < l
399
+ before_h = torch.lt(freq, h) # freq < h
400
+ after_h = torch.ge(freq, h) # freq >= h
401
+
402
+ # order of operations matter here (since there is overlapping frequency regions)
403
+ res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
404
+ res[before_h] = scale * freq[before_h]
405
+ res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
406
+ res[outside_low_high_freq] = freq[outside_low_high_freq]
407
+
408
+ return res
409
+
410
+
411
def vtln_warp_mel_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq: float,
    high_freq: float,
    vtln_warp_factor: float,
    mel_freq: Tensor,
) -> Tensor:
    r"""Apply VTLN warping to frequencies given on the mel scale.

    The mel frequencies are converted back to Hz, warped with the
    piecewise-linear VTLN function (``vtln_warp_freq``), and then mapped
    to the mel scale again.

    Args:
        vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
        vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
        low_freq (float): Lower frequency cutoffs in mel computation
        high_freq (float): Upper frequency cutoffs in mel computation
        vtln_warp_factor (float): Vtln warp factor
        mel_freq (Tensor): Given frequency in Mel

    Returns:
        Tensor: ``mel_freq`` after vtln warp
    """
    # ``low_freq`` is annotated as float here for consistency with the other
    # VTLN helpers; it was the only untyped parameter in this signature.
    warped_hz = vtln_warp_freq(
        vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
    )
    return mel_scale(warped_hz)
436
+
437
+
438
def get_mel_banks(
    num_bins: int,
    window_length_padded: int,
    sample_freq: float,
    low_freq: float,
    high_freq: float,
    vtln_low: float,
    vtln_high: float,
    vtln_warp_factor: float,
) -> Tuple[Tensor, Tensor]:
    """Compute a Kaldi-compatible bank of triangular mel filters, optionally VTLN-warped.

    Args:
        num_bins (int): Number of triangular mel-frequency bins
        window_length_padded (int): Padded window size (must be even)
        sample_freq (float): Waveform sample frequency in Hz
        low_freq (float): Lower frequency cutoff in mel computation
        high_freq (float): Upper frequency cutoff (if <= 0, offset from Nyquist)
        vtln_low (float): Low inflection point of the VTLN warp function
        vtln_high (float): High inflection point (if negative, offset from Nyquist)
        vtln_warp_factor (float): VTLN warp factor (``1.0`` means no warping)

    Returns:
        (Tensor, Tensor): The tuple consists of ``bins`` (which is
        melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
        center frequencies of bins of size (``num_bins``)).
    """
    # NOTE(review): the condition requires num_bins >= 4 while the message says
    # "at least 3" — kept as-is to preserve behavior; confirm intended bound.
    assert num_bins > 3, "Must have at least 3 mel bins"
    assert window_length_padded % 2 == 0
    # Integer division: num_fft_bins is a bin count; "/" produced a float that
    # torch.arange tolerated but that is semantically an int.
    num_fft_bins = window_length_padded // 2
    nyquist = 0.5 * sample_freq

    if high_freq <= 0.0:
        high_freq += nyquist

    assert (
        (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
    ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)

    # fft-bin width [think of it as Nyquist-freq / half-window-length]
    fft_bin_width = sample_freq / window_length_padded
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)

    # divide by num_bins+1 in next line because of end-effects where the bins
    # spread out to the sides.
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    if vtln_high < 0.0:
        vtln_high += nyquist

    assert vtln_warp_factor == 1.0 or (
        (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
    ), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
        vtln_low, vtln_high, low_freq, high_freq
    )

    # ``bin_idx`` (renamed from ``bin`` to avoid shadowing the builtin)
    bin_idx = torch.arange(num_bins).unsqueeze(1)
    left_mel = mel_low_freq + bin_idx * mel_freq_delta  # size(num_bins, 1)
    center_mel = mel_low_freq + (bin_idx + 1.0) * mel_freq_delta  # size(num_bins, 1)
    right_mel = mel_low_freq + (bin_idx + 2.0) * mel_freq_delta  # size(num_bins, 1)

    if vtln_warp_factor != 1.0:
        left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
        center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
        right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)

    center_freqs = inverse_mel_scale(center_mel)  # size (num_bins)
    # size(1, num_fft_bins)
    mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)

    # size (num_bins, num_fft_bins)
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)

    if vtln_warp_factor == 1.0:
        # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
        bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
    else:
        # warping can move the order of left_mel, center_mel, right_mel anywhere
        bins = torch.zeros_like(up_slope)
        up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel)  # left_mel < mel <= center_mel
        down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel)  # center_mel < mel < right_mel
        bins[up_idx] = up_slope[up_idx]
        bins[down_idx] = down_slope[down_idx]

    return bins, center_freqs
514
+
515
+
516
def fbank(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    use_log_fbank: bool = True,
    use_power: bool = True,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
    compute-fbank-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
            (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``)
        use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
        where m is calculated in _get_strided
    """
    # Remember the caller's device/dtype so intermediate tensors (mel banks,
    # epsilon) can be created to match.
    device, dtype = waveform.device, waveform.dtype

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # signal is too short
        return torch.empty(0, device=device, dtype=dtype)

    # strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # size (m, padded_window_size // 2 + 1)
    spectrum = torch.fft.rfft(strided_input).abs()
    if use_power:
        spectrum = spectrum.pow(2.0)

    # size (num_mel_bins, padded_window_size // 2)
    mel_energies, _ = get_mel_banks(
        num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp
    )
    mel_energies = mel_energies.to(device=device, dtype=dtype)

    # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
    # (the extra column aligns the banks with the Nyquist bin of the rfft output)
    mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)

    # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins)
    mel_energies = torch.mm(spectrum, mel_energies.T)
    if use_log_fbank:
        # avoid log of zero (which should be prevented anyway by dithering)
        mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

    # if use_energy then add it as the last column for htk_compat == true else first column
    if use_energy:
        signal_log_energy = signal_log_energy.unsqueeze(1)  # size (m, 1)
        # returns size (m, num_mel_bins + 1)
        if htk_compat:
            mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
        else:
            mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)

    mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
    return mel_energies
648
+
649
+
650
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
    """Build the DCT matrix Kaldi uses to turn mel energies into cepstra.

    Returns a matrix of size (num_mel_bins, num_ceps) suitable for right
    multiplication of a (m, num_mel_bins) feature matrix.
    """
    # Start from a full square orthonormal DCT, size (num_mel_bins, num_mel_bins).
    dct = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
    # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
    # this would be the first column in the dct_matrix for torchaudio as it expects a
    # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
    # expects a left multiply e.g. dct_matrix * vector).
    dct[:, 0] = math.sqrt(1 / float(num_mel_bins))
    # Keep only the first num_ceps cepstral coefficients.
    return dct[:, :num_ceps]
661
+
662
+
663
+ def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
664
+ # returns size (num_ceps)
665
+ # Compute liftering coefficients (scaling on cepstral coeffs)
666
+ # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
667
+ i = torch.arange(num_ceps)
668
+ return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
669
+
670
+
671
def mfcc(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    cepstral_lifter: float = 22.0,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    num_ceps: int = 13,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
    compute-mfcc-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
            features (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``"povey"``)

    Returns:
        Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
        where m is calculated in _get_strided
    """
    assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)

    device, dtype = waveform.device, waveform.dtype

    # The mel_energies should not be squared (use_power=True), not have mean subtracted
    # (subtract_mean=False), and use log (use_log_fbank=True).
    # size (m, num_mel_bins + use_energy)
    feature = fbank(
        waveform=waveform,
        blackman_coeff=blackman_coeff,
        channel=channel,
        dither=dither,
        energy_floor=energy_floor,
        frame_length=frame_length,
        frame_shift=frame_shift,
        high_freq=high_freq,
        htk_compat=htk_compat,
        low_freq=low_freq,
        min_duration=min_duration,
        num_mel_bins=num_mel_bins,
        preemphasis_coefficient=preemphasis_coefficient,
        raw_energy=raw_energy,
        remove_dc_offset=remove_dc_offset,
        round_to_power_of_two=round_to_power_of_two,
        sample_frequency=sample_frequency,
        snip_edges=snip_edges,
        subtract_mean=False,
        use_energy=use_energy,
        use_log_fbank=True,
        use_power=True,
        vtln_high=vtln_high,
        vtln_low=vtln_low,
        vtln_warp=vtln_warp,
        window_type=window_type,
    )

    if use_energy:
        # Pull the energy column out before the DCT; its position depends on
        # htk_compat (last column when True, first when False).
        # size (m)
        signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
        # offset is 0 if htk_compat==True else 1
        mel_offset = int(not htk_compat)
        feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]

    # size (num_mel_bins, num_ceps)
    dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)

    # size (m, num_ceps)
    feature = feature.matmul(dct_matrix)

    if cepstral_lifter != 0.0:
        # size (1, num_ceps)
        lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
        feature *= lifter_coeffs.to(device=device, dtype=dtype)

    # if use_energy then replace the last column for htk_compat == true else first column
    if use_energy:
        feature[:, 0] = signal_log_energy

    if htk_compat:
        # Move C0 (or the energy that replaced it) to the last column.
        energy = feature[:, 0].unsqueeze(1)  # size (m, 1)
        feature = feature[:, 1:]  # size (m, num_ceps - 1)
        if not use_energy:
            # scale on C0 (actually removing a scale we previously added that's
            # part of one common definition of the cosine transform.)
            energy *= math.sqrt(2)

        feature = torch.cat((feature, energy), dim=1)

    feature = _subtract_column_mean(feature, subtract_mean)
    return feature
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Re-export every dataset class at package level so callers can write
# e.g. ``torchaudio.datasets.LIBRISPEECH`` without importing submodules.
from .cmuarctic import CMUARCTIC
from .cmudict import CMUDict
from .commonvoice import COMMONVOICE
from .dr_vctk import DR_VCTK
from .gtzan import GTZAN
from .librilight_limited import LibriLightLimited
from .librimix import LibriMix
from .librispeech import LIBRISPEECH
from .libritts import LIBRITTS
from .ljspeech import LJSPEECH
from .quesst14 import QUESST14
from .speechcommands import SPEECHCOMMANDS
from .tedlium import TEDLIUM
from .vctk import VCTK_092
from .yesno import YESNO


# Explicit public API; keep this list in sync with the imports above.
__all__ = [
    "COMMONVOICE",
    "LIBRISPEECH",
    "LibriLightLimited",
    "SPEECHCOMMANDS",
    "VCTK_092",
    "DR_VCTK",
    "YESNO",
    "LJSPEECH",
    "GTZAN",
    "CMUARCTIC",
    "CMUDict",
    "LibriMix",
    "LIBRITTS",
    "TEDLIUM",
    "QUESST14",
]
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/cmuarctic.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Tuple, Union
5
+
6
+ import torchaudio
7
+ from torch import Tensor
8
+ from torch.hub import download_url_to_file
9
+ from torch.utils.data import Dataset
10
+ from torchaudio.datasets.utils import extract_archive
11
+
12
+ URL = "aew"
13
+ FOLDER_IN_ARCHIVE = "ARCTIC"
14
+ _CHECKSUMS = {
15
+ "http://festvox.org/cmu_arctic/packed/cmu_us_aew_arctic.tar.bz2": "645cb33c0f0b2ce41384fdd8d3db2c3f5fc15c1e688baeb74d2e08cab18ab406", # noqa: E501
16
+ "http://festvox.org/cmu_arctic/packed/cmu_us_ahw_arctic.tar.bz2": "024664adeb892809d646a3efd043625b46b5bfa3e6189b3500b2d0d59dfab06c", # noqa: E501
17
+ "http://festvox.org/cmu_arctic/packed/cmu_us_aup_arctic.tar.bz2": "2c55bc3050caa996758869126ad10cf42e1441212111db034b3a45189c18b6fc", # noqa: E501
18
+ "http://festvox.org/cmu_arctic/packed/cmu_us_awb_arctic.tar.bz2": "d74a950c9739a65f7bfc4dfa6187f2730fa03de5b8eb3f2da97a51b74df64d3c", # noqa: E501
19
+ "http://festvox.org/cmu_arctic/packed/cmu_us_axb_arctic.tar.bz2": "dd65c3d2907d1ee52f86e44f578319159e60f4bf722a9142be01161d84e330ff", # noqa: E501
20
+ "http://festvox.org/cmu_arctic/packed/cmu_us_bdl_arctic.tar.bz2": "26b91aaf48b2799b2956792b4632c2f926cd0542f402b5452d5adecb60942904", # noqa: E501
21
+ "http://festvox.org/cmu_arctic/packed/cmu_us_clb_arctic.tar.bz2": "3f16dc3f3b97955ea22623efb33b444341013fc660677b2e170efdcc959fa7c6", # noqa: E501
22
+ "http://festvox.org/cmu_arctic/packed/cmu_us_eey_arctic.tar.bz2": "8a0ee4e5acbd4b2f61a4fb947c1730ab3adcc9dc50b195981d99391d29928e8a", # noqa: E501
23
+ "http://festvox.org/cmu_arctic/packed/cmu_us_fem_arctic.tar.bz2": "3fcff629412b57233589cdb058f730594a62c4f3a75c20de14afe06621ef45e2", # noqa: E501
24
+ "http://festvox.org/cmu_arctic/packed/cmu_us_gka_arctic.tar.bz2": "dc82e7967cbd5eddbed33074b0699128dbd4482b41711916d58103707e38c67f", # noqa: E501
25
+ "http://festvox.org/cmu_arctic/packed/cmu_us_jmk_arctic.tar.bz2": "3a37c0e1dfc91e734fdbc88b562d9e2ebca621772402cdc693bbc9b09b211d73", # noqa: E501
26
+ "http://festvox.org/cmu_arctic/packed/cmu_us_ksp_arctic.tar.bz2": "8029cafce8296f9bed3022c44ef1e7953332b6bf6943c14b929f468122532717", # noqa: E501
27
+ "http://festvox.org/cmu_arctic/packed/cmu_us_ljm_arctic.tar.bz2": "b23993765cbf2b9e7bbc3c85b6c56eaf292ac81ee4bb887b638a24d104f921a0", # noqa: E501
28
+ "http://festvox.org/cmu_arctic/packed/cmu_us_lnh_arctic.tar.bz2": "4faf34d71aa7112813252fb20c5433e2fdd9a9de55a00701ffcbf05f24a5991a", # noqa: E501
29
+ "http://festvox.org/cmu_arctic/packed/cmu_us_rms_arctic.tar.bz2": "c6dc11235629c58441c071a7ba8a2d067903dfefbaabc4056d87da35b72ecda4", # noqa: E501
30
+ "http://festvox.org/cmu_arctic/packed/cmu_us_rxr_arctic.tar.bz2": "1fa4271c393e5998d200e56c102ff46fcfea169aaa2148ad9e9469616fbfdd9b", # noqa: E501
31
+ "http://festvox.org/cmu_arctic/packed/cmu_us_slp_arctic.tar.bz2": "54345ed55e45c23d419e9a823eef427f1cc93c83a710735ec667d068c916abf1", # noqa: E501
32
+ "http://festvox.org/cmu_arctic/packed/cmu_us_slt_arctic.tar.bz2": "7c173297916acf3cc7fcab2713be4c60b27312316765a90934651d367226b4ea", # noqa: E501
33
+ }
34
+
35
+
36
def load_cmuarctic_item(line: str, path: str, folder_audio: str, ext_audio: str) -> Tuple[Tensor, int, str, str]:
    """Load one CMU ARCTIC example described by a line of ``txt.done.data``.

    Returns ``(waveform, sample_rate, transcript, utterance_id)``.
    """
    # NOTE(review): despite the ``str`` annotation, ``line`` is indexed with
    # ``line[0]`` below — it is presumably a one-element csv row (list) coming
    # from the ``csv.reader`` walker in CMUARCTIC; confirm and fix annotation.
    #
    # A txt.done.data line presumably looks like:
    #   ( arctic_a0001 "some transcript" )
    # so split(" ", 2)[1:] yields [utterance_id, '"transcript" )'].
    utterance_id, transcript = line[0].strip().split(" ", 2)[1:]

    # Remove space, double quote, and single parenthesis from transcript
    # (drops the leading '"' and the trailing '" )').
    transcript = transcript[1:-3]

    file_audio = os.path.join(path, folder_audio, utterance_id + ext_audio)

    # Load audio
    waveform, sample_rate = torchaudio.load(file_audio)

    # utterance_id like "arctic_a0001" -> keep only the "a0001" part
    return (waveform, sample_rate, transcript, utterance_id.split("_")[1])
49
+
50
+
51
class CMUARCTIC(Dataset):
    """Create a Dataset for *CMU ARCTIC* [:footcite:`Kominek03cmuarctic`].

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional):
            The URL to download the dataset from or the type of the dataset to download.
            (default: ``"aew"``)
            Allowed type values are ``"aew"``, ``"ahw"``, ``"aup"``, ``"awb"``, ``"axb"``, ``"bdl"``,
            ``"clb"``, ``"eey"``, ``"fem"``, ``"gka"``, ``"jmk"``, ``"ksp"``, ``"ljm"``, ``"lnh"``,
            ``"rms"``, ``"rxr"``, ``"slp"`` or ``"slt"``.
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"ARCTIC"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
    """

    # Layout of an extracted archive: transcripts under etc/txt.done.data,
    # audio files under wav/*.wav.
    _file_text = "txt.done.data"
    _folder_text = "etc"
    _ext_audio = ".wav"
    _folder_audio = "wav"

    def __init__(
        self, root: Union[str, Path], url: str = URL, folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False
    ) -> None:

        if url in [
            "aew",
            "ahw",
            "aup",
            "awb",
            "axb",
            "bdl",
            "clb",
            "eey",
            "fem",
            "gka",
            "jmk",
            "ksp",
            "ljm",
            "lnh",
            "rms",
            "rxr",
            "slp",
            "slt",
        ]:

            url = "cmu_us_" + url + "_arctic"
            ext_archive = ".tar.bz2"
            # Fix: no "www." prefix here so the URL matches the keys of
            # _CHECKSUMS exactly; with "http://www.festvox.org/..." the
            # checksum lookup below always returned None and downloads were
            # never hash-verified.
            base_url = "http://festvox.org/cmu_arctic/packed/"

            # Fix: build the URL by concatenation; os.path.join would insert
            # backslashes on Windows.
            url = base_url + url + ext_archive

        # Get string representation of 'root' in case Path object is passed
        root = os.fspath(root)

        basename = os.path.basename(url)
        root = os.path.join(root, folder_in_archive)
        if not os.path.isdir(root):
            os.mkdir(root)
        archive = os.path.join(root, basename)

        # "cmu_us_aew_arctic.tar.bz2" -> "cmu_us_aew_arctic"
        basename = basename.split(".")[0]

        self._path = os.path.join(root, basename)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url_to_file(url, archive, hash_prefix=checksum)
                extract_archive(archive)
        else:
            if not os.path.exists(self._path):
                raise RuntimeError(
                    f"The path {self._path} doesn't exist. "
                    "Please check the ``root`` path or set `download=True` to download it"
                )
        self._text = os.path.join(self._path, self._folder_text, self._file_text)

        # One csv row per utterance; parsing into (waveform, ...) happens
        # lazily in __getitem__ via load_cmuarctic_item.
        with open(self._text, "r") as text:
            walker = csv.reader(text, delimiter="\n")
            self._walker = list(walker)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            (Tensor, int, str, str): ``(waveform, sample_rate, transcript, utterance_id)``
        """
        line = self._walker[n]
        return load_cmuarctic_item(line, self._path, self._folder_audio, self._ext_audio)

    def __len__(self) -> int:
        return len(self._walker)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/cmudict.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from pathlib import Path
4
+ from typing import Iterable, List, Tuple, Union
5
+
6
+ from torch.hub import download_url_to_file
7
+ from torch.utils.data import Dataset
8
+
9
+ _CHECKSUMS = {
10
+ "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b": "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4", # noqa: E501
11
+ "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols": "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027", # noqa: E501
12
+ }
13
+ _PUNCTUATIONS = set(
14
+ [
15
+ "!EXCLAMATION-POINT",
16
+ '"CLOSE-QUOTE',
17
+ '"DOUBLE-QUOTE',
18
+ '"END-OF-QUOTE',
19
+ '"END-QUOTE',
20
+ '"IN-QUOTES',
21
+ '"QUOTE',
22
+ '"UNQUOTE',
23
+ "#HASH-MARK",
24
+ "#POUND-SIGN",
25
+ "#SHARP-SIGN",
26
+ "%PERCENT",
27
+ "&AMPERSAND",
28
+ "'END-INNER-QUOTE",
29
+ "'END-QUOTE",
30
+ "'INNER-QUOTE",
31
+ "'QUOTE",
32
+ "'SINGLE-QUOTE",
33
+ "(BEGIN-PARENS",
34
+ "(IN-PARENTHESES",
35
+ "(LEFT-PAREN",
36
+ "(OPEN-PARENTHESES",
37
+ "(PAREN",
38
+ "(PARENS",
39
+ "(PARENTHESES",
40
+ ")CLOSE-PAREN",
41
+ ")CLOSE-PARENTHESES",
42
+ ")END-PAREN",
43
+ ")END-PARENS",
44
+ ")END-PARENTHESES",
45
+ ")END-THE-PAREN",
46
+ ")PAREN",
47
+ ")PARENS",
48
+ ")RIGHT-PAREN",
49
+ ")UN-PARENTHESES",
50
+ "+PLUS",
51
+ ",COMMA",
52
+ "--DASH",
53
+ "-DASH",
54
+ "-HYPHEN",
55
+ "...ELLIPSIS",
56
+ ".DECIMAL",
57
+ ".DOT",
58
+ ".FULL-STOP",
59
+ ".PERIOD",
60
+ ".POINT",
61
+ "/SLASH",
62
+ ":COLON",
63
+ ";SEMI-COLON",
64
+ ";SEMI-COLON(1)",
65
+ "?QUESTION-MARK",
66
+ "{BRACE",
67
+ "{LEFT-BRACE",
68
+ "{OPEN-BRACE",
69
+ "}CLOSE-BRACE",
70
+ "}RIGHT-BRACE",
71
+ ]
72
+ )
73
+
74
+
75
def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[Tuple[str, List[str]]]:
    """Parse raw CMUDict lines into ``(word, [phonemes])`` pairs.

    Args:
        lines (Iterable[str]): Raw lines of the ``cmudict-0.7b`` file.
        exclude_punctuations (bool): If True, drop punctuation entries such
            as ``!EXCLAMATION-POINT``.

    Returns:
        List[Tuple[str, List[str]]]: One ``(word, phonemes)`` entry per
        dictionary line. (Fix: the annotation previously said ``List[str]``,
        which did not match the returned value.)
    """
    _alt_re = re.compile(r"\([0-9]+\)")
    cmudict: List[Tuple[str, List[str]]] = list()
    for line in lines:
        if not line or line.startswith(";;;"):  # ignore comments
            continue

        # Fix: CMUDict separates the word from its pronunciation with TWO
        # spaces (the phonemes themselves are single-space separated), so the
        # first split must be on "  "; splitting on a single space raises
        # ValueError for every multi-phoneme entry.
        word, phones = line.strip().split("  ")
        if word in _PUNCTUATIONS:
            if exclude_punctuations:
                continue
            # !EXCLAMATION-POINT -> !
            # --DASH -> --
            # ...ELLIPSIS -> ...
            if word.startswith("..."):
                word = "..."
            elif word.startswith("--"):
                word = "--"
            else:
                word = word[0]

        # if a word have multiple pronunciations, there will be (number) appended to it
        # for example, DATAPOINTS and DATAPOINTS(1),
        # the regular expression `_alt_re` removes the '(1)' and change the word DATAPOINTS(1) to DATAPOINTS
        word = re.sub(_alt_re, "", word)
        phones = phones.split(" ")
        cmudict.append((word, phones))

    return cmudict
104
+
105
+
106
class CMUDict(Dataset):
    """Create a Dataset for *CMU Pronouncing Dictionary* [:footcite:`cmudict`] (CMUDict).

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        exclude_punctuations (bool, optional):
            When enabled, exclude the pronounciation of punctuations, such as
            `!EXCLAMATION-POINT` and `#HASH-MARK`.
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        url (str, optional):
            The URL to download the dictionary from.
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``)
        url_symbols (str, optional):
            The URL to download the list of symbols from.
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        exclude_punctuations: bool = True,
        *,
        download: bool = False,
        url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
        url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
    ) -> None:

        self.exclude_punctuations = exclude_punctuations

        # Unlike other datasets in this package, the root directory must
        # already exist; it is never created here.
        self._root_path = Path(root)
        if not os.path.isdir(self._root_path):
            raise RuntimeError(f"The root directory does not exist; {root}")

        dict_file = self._root_path / os.path.basename(url)
        symbol_file = self._root_path / os.path.basename(url_symbols)
        if not os.path.exists(dict_file):
            if not download:
                raise RuntimeError(
                    "The dictionary file is not found in the following location. "
                    f"Set `download=True` to download it. {dict_file}"
                )
            checksum = _CHECKSUMS.get(url, None)
            # third positional argument of download_url_to_file is ``hash_prefix``
            download_url_to_file(url, dict_file, checksum)
        if not os.path.exists(symbol_file):
            if not download:
                raise RuntimeError(
                    "The symbol file is not found in the following location. "
                    f"Set `download=True` to download it. {symbol_file}"
                )
            checksum = _CHECKSUMS.get(url_symbols, None)
            download_url_to_file(url_symbols, symbol_file, checksum)

        with open(symbol_file, "r") as text:
            self._symbols = [line.strip() for line in text.readlines()]

        # NOTE: the dictionary file is opened as latin-1 — presumably because
        # cmudict-0.7b contains non-UTF-8 bytes; confirm before changing.
        with open(dict_file, "r", encoding="latin-1") as text:
            self._dictionary = _parse_dictionary(text.readlines(), exclude_punctuations=self.exclude_punctuations)

    def __getitem__(self, n: int) -> Tuple[str, List[str]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            (str, List[str]): The corresponding word and phonemes ``(word, [phonemes])``.

        """
        return self._dictionary[n]

    def __len__(self) -> int:
        return len(self._dictionary)

    @property
    def symbols(self) -> List[str]:
        """list[str]: A list of phonemes symbols, such as `AA`, `AE`, `AH`."""
        # Return a copy so callers cannot mutate the dataset's internal list.
        return self._symbols.copy()
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/commonvoice.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Dict, List, Tuple, Union
5
+
6
+ import torchaudio
7
+ from torch import Tensor
8
+ from torch.utils.data import Dataset
9
+
10
+
11
def load_commonvoice_item(
    line: List[str], header: List[str], path: str, folder_audio: str, ext_audio: str
) -> Tuple[Tensor, int, Dict[str, str]]:
    """Load one CommonVoice sample: the referenced audio plus its TSV metadata.

    Each TSV row carries the fields:
    client_id, path, sentence, up_votes, down_votes, age, gender, accent
    """
    assert header[1] == "path"

    # Column 1 holds the audio file name; append the extension if absent.
    audio_name = line[1]
    audio_path = os.path.join(path, folder_audio, audio_name)
    if not audio_path.endswith(ext_audio):
        audio_path = audio_path + ext_audio

    waveform, sample_rate = torchaudio.load(audio_path)

    # Pair every header column with its value for this row.
    metadata = dict(zip(header, line))

    return waveform, sample_rate, metadata
27
+
28
+
29
class COMMONVOICE(Dataset):
    """Create a Dataset for *CommonVoice* [:footcite:`ardila2020common`].

    Args:
        root (str or Path): Path to the directory where the dataset is located.
            (Where the ``tsv`` file is present.)
        tsv (str, optional):
            The name of the tsv file used to construct the metadata, such as
            ``"train.tsv"``, ``"test.tsv"``, ``"dev.tsv"``, ``"invalidated.tsv"``,
            ``"validated.tsv"`` and ``"other.tsv"``. (default: ``"train.tsv"``)
    """

    _ext_txt = ".txt"
    _ext_audio = ".mp3"
    _folder_audio = "clips"

    def __init__(self, root: Union[str, Path], tsv: str = "train.tsv") -> None:
        # Get string representation of 'root' in case Path object is passed
        self._path = os.fspath(root)
        self._tsv = os.path.join(self._path, tsv)

        # CommonVoice TSV files are UTF-8 encoded; pass the encoding explicitly
        # so loading does not depend on the platform's locale default (which is
        # not UTF-8 on e.g. Windows). ``newline=""`` is what the csv module
        # documents for file objects handed to csv.reader.
        with open(self._tsv, "r", encoding="utf-8", newline="") as tsv_:
            walker = csv.reader(tsv_, delimiter="\t")
            self._header = next(walker)  # first row holds the column names
            self._walker = list(walker)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, Dict[str, str]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            (Tensor, int, Dict[str, str]): ``(waveform, sample_rate, dictionary)``, where dictionary
            is built from the TSV file with the following keys: ``client_id``, ``path``, ``sentence``,
            ``up_votes``, ``down_votes``, ``age``, ``gender`` and ``accent``.
        """
        line = self._walker[n]
        return load_commonvoice_item(line, self._header, self._path, self._folder_audio, self._ext_audio)

    def __len__(self) -> int:
        """Return the number of rows read from the TSV file."""
        return len(self._walker)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/dr_vctk.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Dict, Tuple, Union
3
+
4
+ import torchaudio
5
+ from torch import Tensor
6
+ from torch.hub import download_url_to_file
7
+ from torch.utils.data import Dataset
8
+ from torchaudio.datasets.utils import extract_archive
9
+
10
+
11
_URL = "https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"
_CHECKSUM = "781f12f4406ed36ed27ae3bce55da47ba176e2d8bae67319e389e07b2c9bd769"
_SUPPORTED_SUBSETS = {"train", "test"}


class DR_VCTK(Dataset):
    """Create a dataset for *Device Recorded VCTK (Small subset version)* [:footcite:`Sarfjoo2018DeviceRV`].

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found.
        subset (str): The subset to use. Can be one of ``"train"`` and ``"test"``. (default: ``"train"``).
        download (bool):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        url (str): The URL to download the dataset from.
            (default: ``"https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "train",
        *,
        download: bool = False,
        url: str = _URL,
    ) -> None:
        if subset not in _SUPPORTED_SUBSETS:
            raise RuntimeError(
                f"The subset '{subset}' does not match any of the supported subsets: {_SUPPORTED_SUBSETS}"
            )

        dataset_root = Path(root).expanduser()
        archive = dataset_root / "DR-VCTK.zip"

        self._subset = subset
        self._path = dataset_root / "DR-VCTK" / "DR-VCTK"
        self._clean_audio_dir = self._path / f"clean_{self._subset}set_wav_16k"
        self._noisy_audio_dir = self._path / f"device-recorded_{self._subset}set_wav_16k"
        self._config_filepath = self._path / "configurations" / f"{self._subset}_ch_log.txt"

        # Fetch and unpack the archive only when the extracted tree is absent.
        if not self._path.is_dir():
            if not archive.is_file():
                if not download:
                    raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
                download_url_to_file(url, archive, hash_prefix=_CHECKSUM)
            extract_archive(archive, dataset_root)

        self._config = self._load_config(self._config_filepath)
        self._filename_list = sorted(self._config)

    def _load_config(self, filepath: str) -> Dict[str, Tuple[str, int]]:
        """Parse the channel-log file into ``{filename: (source, channel_id)}``."""
        # The train log has a two-line header, the test log a one-line header.
        header_rows = 2 if self._subset == "train" else 1

        entries: Dict[str, Tuple[str, int]] = {}
        with open(filepath) as f:
            for row_index, row in enumerate(f):
                if row_index < header_rows or not row:
                    continue
                fname, source, channel = row.strip().split("\t")
                entries[fname] = (source, int(channel))
        return entries

    def _load_dr_vctk_item(self, filename: str) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
        """Load the clean/noisy waveform pair plus metadata for one utterance."""
        # Filenames look like ``<speaker_id>_<utterance_id>.wav``.
        speaker_id, utterance_id = filename.split(".")[0].split("_")
        source, channel_id = self._config[filename]
        waveform_clean, sample_rate_clean = torchaudio.load(self._clean_audio_dir / filename)
        waveform_noisy, sample_rate_noisy = torchaudio.load(self._noisy_audio_dir / filename)
        return (
            waveform_clean,
            sample_rate_clean,
            waveform_noisy,
            sample_rate_noisy,
            speaker_id,
            utterance_id,
            source,
            channel_id,
        )

    def __getitem__(self, n: int) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            (Tensor, int, Tensor, int, str, str, str, int):
            ``(waveform_clean, sample_rate_clean, waveform_noisy, sample_rate_noisy, speaker_id,\
            utterance_id, source, channel_id)``
        """
        return self._load_dr_vctk_item(self._filename_list[n])

    def __len__(self) -> int:
        """Return the number of utterances listed in the channel log."""
        return len(self._filename_list)
+ return len(self._filename_list)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/gtzan.py ADDED
@@ -0,0 +1,1108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torchaudio
6
+ from torch import Tensor
7
+ from torch.hub import download_url_to_file
8
+ from torch.utils.data import Dataset
9
+ from torchaudio.datasets.utils import extract_archive
10
+
11
+ # The following lists prefixed with `filtered_` provide a filtered split
12
+ # that:
13
+ #
14
+ # a. Mitigate a known issue with GTZAN (duplication)
15
+ #
16
+ # b. Provide a standard split for testing it against other
17
+ # methods (e.g. the one in jordipons/sklearn-audio-transfer-learning).
18
+ #
19
+ # Those are used when GTZAN is initialised with the `filtered` keyword.
20
+ # The split was taken from (github) jordipons/sklearn-audio-transfer-learning.
21
+
22
+ gtzan_genres = [
23
+ "blues",
24
+ "classical",
25
+ "country",
26
+ "disco",
27
+ "hiphop",
28
+ "jazz",
29
+ "metal",
30
+ "pop",
31
+ "reggae",
32
+ "rock",
33
+ ]
34
+
35
+ filtered_test = [
36
+ "blues.00012",
37
+ "blues.00013",
38
+ "blues.00014",
39
+ "blues.00015",
40
+ "blues.00016",
41
+ "blues.00017",
42
+ "blues.00018",
43
+ "blues.00019",
44
+ "blues.00020",
45
+ "blues.00021",
46
+ "blues.00022",
47
+ "blues.00023",
48
+ "blues.00024",
49
+ "blues.00025",
50
+ "blues.00026",
51
+ "blues.00027",
52
+ "blues.00028",
53
+ "blues.00061",
54
+ "blues.00062",
55
+ "blues.00063",
56
+ "blues.00064",
57
+ "blues.00065",
58
+ "blues.00066",
59
+ "blues.00067",
60
+ "blues.00068",
61
+ "blues.00069",
62
+ "blues.00070",
63
+ "blues.00071",
64
+ "blues.00072",
65
+ "blues.00098",
66
+ "blues.00099",
67
+ "classical.00011",
68
+ "classical.00012",
69
+ "classical.00013",
70
+ "classical.00014",
71
+ "classical.00015",
72
+ "classical.00016",
73
+ "classical.00017",
74
+ "classical.00018",
75
+ "classical.00019",
76
+ "classical.00020",
77
+ "classical.00021",
78
+ "classical.00022",
79
+ "classical.00023",
80
+ "classical.00024",
81
+ "classical.00025",
82
+ "classical.00026",
83
+ "classical.00027",
84
+ "classical.00028",
85
+ "classical.00029",
86
+ "classical.00034",
87
+ "classical.00035",
88
+ "classical.00036",
89
+ "classical.00037",
90
+ "classical.00038",
91
+ "classical.00039",
92
+ "classical.00040",
93
+ "classical.00041",
94
+ "classical.00049",
95
+ "classical.00077",
96
+ "classical.00078",
97
+ "classical.00079",
98
+ "country.00030",
99
+ "country.00031",
100
+ "country.00032",
101
+ "country.00033",
102
+ "country.00034",
103
+ "country.00035",
104
+ "country.00036",
105
+ "country.00037",
106
+ "country.00038",
107
+ "country.00039",
108
+ "country.00040",
109
+ "country.00043",
110
+ "country.00044",
111
+ "country.00046",
112
+ "country.00047",
113
+ "country.00048",
114
+ "country.00050",
115
+ "country.00051",
116
+ "country.00053",
117
+ "country.00054",
118
+ "country.00055",
119
+ "country.00056",
120
+ "country.00057",
121
+ "country.00058",
122
+ "country.00059",
123
+ "country.00060",
124
+ "country.00061",
125
+ "country.00062",
126
+ "country.00063",
127
+ "country.00064",
128
+ "disco.00001",
129
+ "disco.00021",
130
+ "disco.00058",
131
+ "disco.00062",
132
+ "disco.00063",
133
+ "disco.00064",
134
+ "disco.00065",
135
+ "disco.00066",
136
+ "disco.00069",
137
+ "disco.00076",
138
+ "disco.00077",
139
+ "disco.00078",
140
+ "disco.00079",
141
+ "disco.00080",
142
+ "disco.00081",
143
+ "disco.00082",
144
+ "disco.00083",
145
+ "disco.00084",
146
+ "disco.00085",
147
+ "disco.00086",
148
+ "disco.00087",
149
+ "disco.00088",
150
+ "disco.00091",
151
+ "disco.00092",
152
+ "disco.00093",
153
+ "disco.00094",
154
+ "disco.00096",
155
+ "disco.00097",
156
+ "disco.00099",
157
+ "hiphop.00000",
158
+ "hiphop.00026",
159
+ "hiphop.00027",
160
+ "hiphop.00030",
161
+ "hiphop.00040",
162
+ "hiphop.00043",
163
+ "hiphop.00044",
164
+ "hiphop.00045",
165
+ "hiphop.00051",
166
+ "hiphop.00052",
167
+ "hiphop.00053",
168
+ "hiphop.00054",
169
+ "hiphop.00062",
170
+ "hiphop.00063",
171
+ "hiphop.00064",
172
+ "hiphop.00065",
173
+ "hiphop.00066",
174
+ "hiphop.00067",
175
+ "hiphop.00068",
176
+ "hiphop.00069",
177
+ "hiphop.00070",
178
+ "hiphop.00071",
179
+ "hiphop.00072",
180
+ "hiphop.00073",
181
+ "hiphop.00074",
182
+ "hiphop.00075",
183
+ "hiphop.00099",
184
+ "jazz.00073",
185
+ "jazz.00074",
186
+ "jazz.00075",
187
+ "jazz.00076",
188
+ "jazz.00077",
189
+ "jazz.00078",
190
+ "jazz.00079",
191
+ "jazz.00080",
192
+ "jazz.00081",
193
+ "jazz.00082",
194
+ "jazz.00083",
195
+ "jazz.00084",
196
+ "jazz.00085",
197
+ "jazz.00086",
198
+ "jazz.00087",
199
+ "jazz.00088",
200
+ "jazz.00089",
201
+ "jazz.00090",
202
+ "jazz.00091",
203
+ "jazz.00092",
204
+ "jazz.00093",
205
+ "jazz.00094",
206
+ "jazz.00095",
207
+ "jazz.00096",
208
+ "jazz.00097",
209
+ "jazz.00098",
210
+ "jazz.00099",
211
+ "metal.00012",
212
+ "metal.00013",
213
+ "metal.00014",
214
+ "metal.00015",
215
+ "metal.00022",
216
+ "metal.00023",
217
+ "metal.00025",
218
+ "metal.00026",
219
+ "metal.00027",
220
+ "metal.00028",
221
+ "metal.00029",
222
+ "metal.00030",
223
+ "metal.00031",
224
+ "metal.00032",
225
+ "metal.00033",
226
+ "metal.00038",
227
+ "metal.00039",
228
+ "metal.00067",
229
+ "metal.00070",
230
+ "metal.00073",
231
+ "metal.00074",
232
+ "metal.00075",
233
+ "metal.00078",
234
+ "metal.00083",
235
+ "metal.00085",
236
+ "metal.00087",
237
+ "metal.00088",
238
+ "pop.00000",
239
+ "pop.00001",
240
+ "pop.00013",
241
+ "pop.00014",
242
+ "pop.00043",
243
+ "pop.00063",
244
+ "pop.00064",
245
+ "pop.00065",
246
+ "pop.00066",
247
+ "pop.00069",
248
+ "pop.00070",
249
+ "pop.00071",
250
+ "pop.00072",
251
+ "pop.00073",
252
+ "pop.00074",
253
+ "pop.00075",
254
+ "pop.00076",
255
+ "pop.00077",
256
+ "pop.00078",
257
+ "pop.00079",
258
+ "pop.00082",
259
+ "pop.00088",
260
+ "pop.00089",
261
+ "pop.00090",
262
+ "pop.00091",
263
+ "pop.00092",
264
+ "pop.00093",
265
+ "pop.00094",
266
+ "pop.00095",
267
+ "pop.00096",
268
+ "reggae.00034",
269
+ "reggae.00035",
270
+ "reggae.00036",
271
+ "reggae.00037",
272
+ "reggae.00038",
273
+ "reggae.00039",
274
+ "reggae.00040",
275
+ "reggae.00046",
276
+ "reggae.00047",
277
+ "reggae.00048",
278
+ "reggae.00052",
279
+ "reggae.00053",
280
+ "reggae.00064",
281
+ "reggae.00065",
282
+ "reggae.00066",
283
+ "reggae.00067",
284
+ "reggae.00068",
285
+ "reggae.00071",
286
+ "reggae.00079",
287
+ "reggae.00082",
288
+ "reggae.00083",
289
+ "reggae.00084",
290
+ "reggae.00087",
291
+ "reggae.00088",
292
+ "reggae.00089",
293
+ "reggae.00090",
294
+ "rock.00010",
295
+ "rock.00011",
296
+ "rock.00012",
297
+ "rock.00013",
298
+ "rock.00014",
299
+ "rock.00015",
300
+ "rock.00027",
301
+ "rock.00028",
302
+ "rock.00029",
303
+ "rock.00030",
304
+ "rock.00031",
305
+ "rock.00032",
306
+ "rock.00033",
307
+ "rock.00034",
308
+ "rock.00035",
309
+ "rock.00036",
310
+ "rock.00037",
311
+ "rock.00039",
312
+ "rock.00040",
313
+ "rock.00041",
314
+ "rock.00042",
315
+ "rock.00043",
316
+ "rock.00044",
317
+ "rock.00045",
318
+ "rock.00046",
319
+ "rock.00047",
320
+ "rock.00048",
321
+ "rock.00086",
322
+ "rock.00087",
323
+ "rock.00088",
324
+ "rock.00089",
325
+ "rock.00090",
326
+ ]
327
+
328
+ filtered_train = [
329
+ "blues.00029",
330
+ "blues.00030",
331
+ "blues.00031",
332
+ "blues.00032",
333
+ "blues.00033",
334
+ "blues.00034",
335
+ "blues.00035",
336
+ "blues.00036",
337
+ "blues.00037",
338
+ "blues.00038",
339
+ "blues.00039",
340
+ "blues.00040",
341
+ "blues.00041",
342
+ "blues.00042",
343
+ "blues.00043",
344
+ "blues.00044",
345
+ "blues.00045",
346
+ "blues.00046",
347
+ "blues.00047",
348
+ "blues.00048",
349
+ "blues.00049",
350
+ "blues.00073",
351
+ "blues.00074",
352
+ "blues.00075",
353
+ "blues.00076",
354
+ "blues.00077",
355
+ "blues.00078",
356
+ "blues.00079",
357
+ "blues.00080",
358
+ "blues.00081",
359
+ "blues.00082",
360
+ "blues.00083",
361
+ "blues.00084",
362
+ "blues.00085",
363
+ "blues.00086",
364
+ "blues.00087",
365
+ "blues.00088",
366
+ "blues.00089",
367
+ "blues.00090",
368
+ "blues.00091",
369
+ "blues.00092",
370
+ "blues.00093",
371
+ "blues.00094",
372
+ "blues.00095",
373
+ "blues.00096",
374
+ "blues.00097",
375
+ "classical.00030",
376
+ "classical.00031",
377
+ "classical.00032",
378
+ "classical.00033",
379
+ "classical.00043",
380
+ "classical.00044",
381
+ "classical.00045",
382
+ "classical.00046",
383
+ "classical.00047",
384
+ "classical.00048",
385
+ "classical.00050",
386
+ "classical.00051",
387
+ "classical.00052",
388
+ "classical.00053",
389
+ "classical.00054",
390
+ "classical.00055",
391
+ "classical.00056",
392
+ "classical.00057",
393
+ "classical.00058",
394
+ "classical.00059",
395
+ "classical.00060",
396
+ "classical.00061",
397
+ "classical.00062",
398
+ "classical.00063",
399
+ "classical.00064",
400
+ "classical.00065",
401
+ "classical.00066",
402
+ "classical.00067",
403
+ "classical.00080",
404
+ "classical.00081",
405
+ "classical.00082",
406
+ "classical.00083",
407
+ "classical.00084",
408
+ "classical.00085",
409
+ "classical.00086",
410
+ "classical.00087",
411
+ "classical.00088",
412
+ "classical.00089",
413
+ "classical.00090",
414
+ "classical.00091",
415
+ "classical.00092",
416
+ "classical.00093",
417
+ "classical.00094",
418
+ "classical.00095",
419
+ "classical.00096",
420
+ "classical.00097",
421
+ "classical.00098",
422
+ "classical.00099",
423
+ "country.00019",
424
+ "country.00020",
425
+ "country.00021",
426
+ "country.00022",
427
+ "country.00023",
428
+ "country.00024",
429
+ "country.00025",
430
+ "country.00026",
431
+ "country.00028",
432
+ "country.00029",
433
+ "country.00065",
434
+ "country.00066",
435
+ "country.00067",
436
+ "country.00068",
437
+ "country.00069",
438
+ "country.00070",
439
+ "country.00071",
440
+ "country.00072",
441
+ "country.00073",
442
+ "country.00074",
443
+ "country.00075",
444
+ "country.00076",
445
+ "country.00077",
446
+ "country.00078",
447
+ "country.00079",
448
+ "country.00080",
449
+ "country.00081",
450
+ "country.00082",
451
+ "country.00083",
452
+ "country.00084",
453
+ "country.00085",
454
+ "country.00086",
455
+ "country.00087",
456
+ "country.00088",
457
+ "country.00089",
458
+ "country.00090",
459
+ "country.00091",
460
+ "country.00092",
461
+ "country.00093",
462
+ "country.00094",
463
+ "country.00095",
464
+ "country.00096",
465
+ "country.00097",
466
+ "country.00098",
467
+ "country.00099",
468
+ "disco.00005",
469
+ "disco.00015",
470
+ "disco.00016",
471
+ "disco.00017",
472
+ "disco.00018",
473
+ "disco.00019",
474
+ "disco.00020",
475
+ "disco.00022",
476
+ "disco.00023",
477
+ "disco.00024",
478
+ "disco.00025",
479
+ "disco.00026",
480
+ "disco.00027",
481
+ "disco.00028",
482
+ "disco.00029",
483
+ "disco.00030",
484
+ "disco.00031",
485
+ "disco.00032",
486
+ "disco.00033",
487
+ "disco.00034",
488
+ "disco.00035",
489
+ "disco.00036",
490
+ "disco.00037",
491
+ "disco.00039",
492
+ "disco.00040",
493
+ "disco.00041",
494
+ "disco.00042",
495
+ "disco.00043",
496
+ "disco.00044",
497
+ "disco.00045",
498
+ "disco.00047",
499
+ "disco.00049",
500
+ "disco.00053",
501
+ "disco.00054",
502
+ "disco.00056",
503
+ "disco.00057",
504
+ "disco.00059",
505
+ "disco.00061",
506
+ "disco.00070",
507
+ "disco.00073",
508
+ "disco.00074",
509
+ "disco.00089",
510
+ "hiphop.00002",
511
+ "hiphop.00003",
512
+ "hiphop.00004",
513
+ "hiphop.00005",
514
+ "hiphop.00006",
515
+ "hiphop.00007",
516
+ "hiphop.00008",
517
+ "hiphop.00009",
518
+ "hiphop.00010",
519
+ "hiphop.00011",
520
+ "hiphop.00012",
521
+ "hiphop.00013",
522
+ "hiphop.00014",
523
+ "hiphop.00015",
524
+ "hiphop.00016",
525
+ "hiphop.00017",
526
+ "hiphop.00018",
527
+ "hiphop.00019",
528
+ "hiphop.00020",
529
+ "hiphop.00021",
530
+ "hiphop.00022",
531
+ "hiphop.00023",
532
+ "hiphop.00024",
533
+ "hiphop.00025",
534
+ "hiphop.00028",
535
+ "hiphop.00029",
536
+ "hiphop.00031",
537
+ "hiphop.00032",
538
+ "hiphop.00033",
539
+ "hiphop.00034",
540
+ "hiphop.00035",
541
+ "hiphop.00036",
542
+ "hiphop.00037",
543
+ "hiphop.00038",
544
+ "hiphop.00041",
545
+ "hiphop.00042",
546
+ "hiphop.00055",
547
+ "hiphop.00056",
548
+ "hiphop.00057",
549
+ "hiphop.00058",
550
+ "hiphop.00059",
551
+ "hiphop.00060",
552
+ "hiphop.00061",
553
+ "hiphop.00077",
554
+ "hiphop.00078",
555
+ "hiphop.00079",
556
+ "hiphop.00080",
557
+ "jazz.00000",
558
+ "jazz.00001",
559
+ "jazz.00011",
560
+ "jazz.00012",
561
+ "jazz.00013",
562
+ "jazz.00014",
563
+ "jazz.00015",
564
+ "jazz.00016",
565
+ "jazz.00017",
566
+ "jazz.00018",
567
+ "jazz.00019",
568
+ "jazz.00020",
569
+ "jazz.00021",
570
+ "jazz.00022",
571
+ "jazz.00023",
572
+ "jazz.00024",
573
+ "jazz.00041",
574
+ "jazz.00047",
575
+ "jazz.00048",
576
+ "jazz.00049",
577
+ "jazz.00050",
578
+ "jazz.00051",
579
+ "jazz.00052",
580
+ "jazz.00053",
581
+ "jazz.00054",
582
+ "jazz.00055",
583
+ "jazz.00056",
584
+ "jazz.00057",
585
+ "jazz.00058",
586
+ "jazz.00059",
587
+ "jazz.00060",
588
+ "jazz.00061",
589
+ "jazz.00062",
590
+ "jazz.00063",
591
+ "jazz.00064",
592
+ "jazz.00065",
593
+ "jazz.00066",
594
+ "jazz.00067",
595
+ "jazz.00068",
596
+ "jazz.00069",
597
+ "jazz.00070",
598
+ "jazz.00071",
599
+ "jazz.00072",
600
+ "metal.00002",
601
+ "metal.00003",
602
+ "metal.00005",
603
+ "metal.00021",
604
+ "metal.00024",
605
+ "metal.00035",
606
+ "metal.00046",
607
+ "metal.00047",
608
+ "metal.00048",
609
+ "metal.00049",
610
+ "metal.00050",
611
+ "metal.00051",
612
+ "metal.00052",
613
+ "metal.00053",
614
+ "metal.00054",
615
+ "metal.00055",
616
+ "metal.00056",
617
+ "metal.00057",
618
+ "metal.00059",
619
+ "metal.00060",
620
+ "metal.00061",
621
+ "metal.00062",
622
+ "metal.00063",
623
+ "metal.00064",
624
+ "metal.00065",
625
+ "metal.00066",
626
+ "metal.00069",
627
+ "metal.00071",
628
+ "metal.00072",
629
+ "metal.00079",
630
+ "metal.00080",
631
+ "metal.00084",
632
+ "metal.00086",
633
+ "metal.00089",
634
+ "metal.00090",
635
+ "metal.00091",
636
+ "metal.00092",
637
+ "metal.00093",
638
+ "metal.00094",
639
+ "metal.00095",
640
+ "metal.00096",
641
+ "metal.00097",
642
+ "metal.00098",
643
+ "metal.00099",
644
+ "pop.00002",
645
+ "pop.00003",
646
+ "pop.00004",
647
+ "pop.00005",
648
+ "pop.00006",
649
+ "pop.00007",
650
+ "pop.00008",
651
+ "pop.00009",
652
+ "pop.00011",
653
+ "pop.00012",
654
+ "pop.00016",
655
+ "pop.00017",
656
+ "pop.00018",
657
+ "pop.00019",
658
+ "pop.00020",
659
+ "pop.00023",
660
+ "pop.00024",
661
+ "pop.00025",
662
+ "pop.00026",
663
+ "pop.00027",
664
+ "pop.00028",
665
+ "pop.00029",
666
+ "pop.00031",
667
+ "pop.00032",
668
+ "pop.00033",
669
+ "pop.00034",
670
+ "pop.00035",
671
+ "pop.00036",
672
+ "pop.00038",
673
+ "pop.00039",
674
+ "pop.00040",
675
+ "pop.00041",
676
+ "pop.00042",
677
+ "pop.00044",
678
+ "pop.00046",
679
+ "pop.00049",
680
+ "pop.00050",
681
+ "pop.00080",
682
+ "pop.00097",
683
+ "pop.00098",
684
+ "pop.00099",
685
+ "reggae.00000",
686
+ "reggae.00001",
687
+ "reggae.00002",
688
+ "reggae.00004",
689
+ "reggae.00006",
690
+ "reggae.00009",
691
+ "reggae.00011",
692
+ "reggae.00012",
693
+ "reggae.00014",
694
+ "reggae.00015",
695
+ "reggae.00016",
696
+ "reggae.00017",
697
+ "reggae.00018",
698
+ "reggae.00019",
699
+ "reggae.00020",
700
+ "reggae.00021",
701
+ "reggae.00022",
702
+ "reggae.00023",
703
+ "reggae.00024",
704
+ "reggae.00025",
705
+ "reggae.00026",
706
+ "reggae.00027",
707
+ "reggae.00028",
708
+ "reggae.00029",
709
+ "reggae.00030",
710
+ "reggae.00031",
711
+ "reggae.00032",
712
+ "reggae.00042",
713
+ "reggae.00043",
714
+ "reggae.00044",
715
+ "reggae.00045",
716
+ "reggae.00049",
717
+ "reggae.00050",
718
+ "reggae.00051",
719
+ "reggae.00054",
720
+ "reggae.00055",
721
+ "reggae.00056",
722
+ "reggae.00057",
723
+ "reggae.00058",
724
+ "reggae.00059",
725
+ "reggae.00060",
726
+ "reggae.00063",
727
+ "reggae.00069",
728
+ "rock.00000",
729
+ "rock.00001",
730
+ "rock.00002",
731
+ "rock.00003",
732
+ "rock.00004",
733
+ "rock.00005",
734
+ "rock.00006",
735
+ "rock.00007",
736
+ "rock.00008",
737
+ "rock.00009",
738
+ "rock.00016",
739
+ "rock.00017",
740
+ "rock.00018",
741
+ "rock.00019",
742
+ "rock.00020",
743
+ "rock.00021",
744
+ "rock.00022",
745
+ "rock.00023",
746
+ "rock.00024",
747
+ "rock.00025",
748
+ "rock.00026",
749
+ "rock.00057",
750
+ "rock.00058",
751
+ "rock.00059",
752
+ "rock.00060",
753
+ "rock.00061",
754
+ "rock.00062",
755
+ "rock.00063",
756
+ "rock.00064",
757
+ "rock.00065",
758
+ "rock.00066",
759
+ "rock.00067",
760
+ "rock.00068",
761
+ "rock.00069",
762
+ "rock.00070",
763
+ "rock.00091",
764
+ "rock.00092",
765
+ "rock.00093",
766
+ "rock.00094",
767
+ "rock.00095",
768
+ "rock.00096",
769
+ "rock.00097",
770
+ "rock.00098",
771
+ "rock.00099",
772
+ ]
773
+
774
+ filtered_valid = [
775
+ "blues.00000",
776
+ "blues.00001",
777
+ "blues.00002",
778
+ "blues.00003",
779
+ "blues.00004",
780
+ "blues.00005",
781
+ "blues.00006",
782
+ "blues.00007",
783
+ "blues.00008",
784
+ "blues.00009",
785
+ "blues.00010",
786
+ "blues.00011",
787
+ "blues.00050",
788
+ "blues.00051",
789
+ "blues.00052",
790
+ "blues.00053",
791
+ "blues.00054",
792
+ "blues.00055",
793
+ "blues.00056",
794
+ "blues.00057",
795
+ "blues.00058",
796
+ "blues.00059",
797
+ "blues.00060",
798
+ "classical.00000",
799
+ "classical.00001",
800
+ "classical.00002",
801
+ "classical.00003",
802
+ "classical.00004",
803
+ "classical.00005",
804
+ "classical.00006",
805
+ "classical.00007",
806
+ "classical.00008",
807
+ "classical.00009",
808
+ "classical.00010",
809
+ "classical.00068",
810
+ "classical.00069",
811
+ "classical.00070",
812
+ "classical.00071",
813
+ "classical.00072",
814
+ "classical.00073",
815
+ "classical.00074",
816
+ "classical.00075",
817
+ "classical.00076",
818
+ "country.00000",
819
+ "country.00001",
820
+ "country.00002",
821
+ "country.00003",
822
+ "country.00004",
823
+ "country.00005",
824
+ "country.00006",
825
+ "country.00007",
826
+ "country.00009",
827
+ "country.00010",
828
+ "country.00011",
829
+ "country.00012",
830
+ "country.00013",
831
+ "country.00014",
832
+ "country.00015",
833
+ "country.00016",
834
+ "country.00017",
835
+ "country.00018",
836
+ "country.00027",
837
+ "country.00041",
838
+ "country.00042",
839
+ "country.00045",
840
+ "country.00049",
841
+ "disco.00000",
842
+ "disco.00002",
843
+ "disco.00003",
844
+ "disco.00004",
845
+ "disco.00006",
846
+ "disco.00007",
847
+ "disco.00008",
848
+ "disco.00009",
849
+ "disco.00010",
850
+ "disco.00011",
851
+ "disco.00012",
852
+ "disco.00013",
853
+ "disco.00014",
854
+ "disco.00046",
855
+ "disco.00048",
856
+ "disco.00052",
857
+ "disco.00067",
858
+ "disco.00068",
859
+ "disco.00072",
860
+ "disco.00075",
861
+ "disco.00090",
862
+ "disco.00095",
863
+ "hiphop.00081",
864
+ "hiphop.00082",
865
+ "hiphop.00083",
866
+ "hiphop.00084",
867
+ "hiphop.00085",
868
+ "hiphop.00086",
869
+ "hiphop.00087",
870
+ "hiphop.00088",
871
+ "hiphop.00089",
872
+ "hiphop.00090",
873
+ "hiphop.00091",
874
+ "hiphop.00092",
875
+ "hiphop.00093",
876
+ "hiphop.00094",
877
+ "hiphop.00095",
878
+ "hiphop.00096",
879
+ "hiphop.00097",
880
+ "hiphop.00098",
881
+ "jazz.00002",
882
+ "jazz.00003",
883
+ "jazz.00004",
884
+ "jazz.00005",
885
+ "jazz.00006",
886
+ "jazz.00007",
887
+ "jazz.00008",
888
+ "jazz.00009",
889
+ "jazz.00010",
890
+ "jazz.00025",
891
+ "jazz.00026",
892
+ "jazz.00027",
893
+ "jazz.00028",
894
+ "jazz.00029",
895
+ "jazz.00030",
896
+ "jazz.00031",
897
+ "jazz.00032",
898
+ "metal.00000",
899
+ "metal.00001",
900
+ "metal.00006",
901
+ "metal.00007",
902
+ "metal.00008",
903
+ "metal.00009",
904
+ "metal.00010",
905
+ "metal.00011",
906
+ "metal.00016",
907
+ "metal.00017",
908
+ "metal.00018",
909
+ "metal.00019",
910
+ "metal.00020",
911
+ "metal.00036",
912
+ "metal.00037",
913
+ "metal.00068",
914
+ "metal.00076",
915
+ "metal.00077",
916
+ "metal.00081",
917
+ "metal.00082",
918
+ "pop.00010",
919
+ "pop.00053",
920
+ "pop.00055",
921
+ "pop.00058",
922
+ "pop.00059",
923
+ "pop.00060",
924
+ "pop.00061",
925
+ "pop.00062",
926
+ "pop.00081",
927
+ "pop.00083",
928
+ "pop.00084",
929
+ "pop.00085",
930
+ "pop.00086",
931
+ "reggae.00061",
932
+ "reggae.00062",
933
+ "reggae.00070",
934
+ "reggae.00072",
935
+ "reggae.00074",
936
+ "reggae.00076",
937
+ "reggae.00077",
938
+ "reggae.00078",
939
+ "reggae.00085",
940
+ "reggae.00092",
941
+ "reggae.00093",
942
+ "reggae.00094",
943
+ "reggae.00095",
944
+ "reggae.00096",
945
+ "reggae.00097",
946
+ "reggae.00098",
947
+ "reggae.00099",
948
+ "rock.00038",
949
+ "rock.00049",
950
+ "rock.00050",
951
+ "rock.00051",
952
+ "rock.00052",
953
+ "rock.00053",
954
+ "rock.00054",
955
+ "rock.00055",
956
+ "rock.00056",
957
+ "rock.00071",
958
+ "rock.00072",
959
+ "rock.00073",
960
+ "rock.00074",
961
+ "rock.00075",
962
+ "rock.00076",
963
+ "rock.00077",
964
+ "rock.00078",
965
+ "rock.00079",
966
+ "rock.00080",
967
+ "rock.00081",
968
+ "rock.00082",
969
+ "rock.00083",
970
+ "rock.00084",
971
+ "rock.00085",
972
+ ]
973
+
974
+
975
URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz"
FOLDER_IN_ARCHIVE = "genres"
_CHECKSUMS = {
    "http://opihi.cs.uvic.ca/sound/genres.tar.gz": "24347e0223d2ba798e0a558c4c172d9d4a19c00bb7963fe055d183dadb4ef2c6"
}


def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, int, str]:
    """Load one GTZAN clip.

    Args:
        fileid (str): File identifier of the form ``label.id``, e.g. ``blues.00078``.
        path (str): Root directory of the extracted dataset.
        ext_audio (str): Audio file extension (e.g. ``".wav"``).

    Returns:
        (Tensor, int, str): ``(waveform, sample_rate, label)``.
    """
    # Filenames are of the form label.id, e.g. blues.00078; ``partition``
    # keeps everything before the first dot as the genre label.
    label, _, _ = fileid.partition(".")

    # Audio files live in a per-genre subdirectory named after the label.
    file_audio = os.path.join(path, label, fileid + ext_audio)
    waveform, sample_rate = torchaudio.load(file_audio)

    return waveform, sample_rate, label


class GTZAN(Dataset):
    """Create a Dataset for *GTZAN* [:footcite:`tzanetakis_essl_cook_2001`].

    Note:
        Please see http://marsyas.info/downloads/datasets.html if you are planning to use
        this dataset to publish results.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from.
            (default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
        folder_in_archive (str, optional): The top-level directory of the dataset.
            (default: ``"genres"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        subset (str or None, optional): Which subset of the dataset to use.
            One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``.
            If ``None``, the entire dataset is used. (default: ``None``).
    """

    _ext_audio = ".wav"

    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
        subset: Optional[str] = None,
    ) -> None:

        # Get string representation of 'root' in case Path object is passed
        root = os.fspath(root)

        self.root = root
        self.url = url
        self.folder_in_archive = folder_in_archive
        self.download = download
        self.subset = subset

        assert subset is None or subset in ["training", "validation", "testing"], (
            "When `subset` not None, it must take a value from " + "{'training', 'validation', 'testing'}."
        )

        archive = os.path.basename(url)
        archive = os.path.join(root, archive)
        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url_to_file(url, archive, hash_prefix=checksum)
                extract_archive(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found. Please use `download=True` to download it.")

        if self.subset is None:
            # Check every subdirectory under dataset root that has the same
            # name as a GTZAN genre (e.g. `root'/blues/, `root'/rock, etc.).
            # This lets users remove or move around song files; only names of
            # the canonical form `genre`.`5 digit number`.wav are accepted.
            self._walker = []

            root = os.path.expanduser(self._path)

            for directory in gtzan_genres:
                fulldir = os.path.join(root, directory)

                if not os.path.exists(fulldir):
                    continue

                songs_in_genre = os.listdir(fulldir)
                songs_in_genre.sort()
                for fname in songs_in_genre:
                    name, ext = os.path.splitext(fname)
                    if ext.lower() == ".wav" and "." in name:
                        # Original code unpacked ``name.split(".")`` into two
                        # names, which raised ValueError for filenames with
                        # more than one dot; skip such nonconforming files
                        # instead of crashing.
                        parts = name.split(".")
                        if len(parts) != 2:
                            continue
                        genre, num = parts
                        if genre in gtzan_genres and len(num) == 5 and num.isdigit():
                            self._walker.append(name)
        else:
            # The filtered_* lists give a deduplicated, standard partition.
            if self.subset == "training":
                self._walker = filtered_train
            elif self.subset == "validation":
                self._walker = filtered_valid
            elif self.subset == "testing":
                self._walker = filtered_test

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            (Tensor, int, str): ``(waveform, sample_rate, label)``
        """
        fileid = self._walker[n]
        waveform, sample_rate, label = load_gtzan_item(fileid, self._path, self._ext_audio)
        return waveform, sample_rate, label

    def __len__(self) -> int:
        """Return the number of clips in the selected subset."""
        return len(self._walker)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/librilight_limited.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Union
4
+
5
+ from torch import Tensor
6
+ from torch.hub import download_url_to_file
7
+ from torch.utils.data import Dataset
8
+ from torchaudio.datasets.librispeech import load_librispeech_item
9
+ from torchaudio.datasets.utils import extract_archive
10
+
11
+
12
# Name of the extracted top-level directory (and basename of the archive).
_ARCHIVE_NAME = "librispeech_finetuning"
# Official download location of the LibriLight fine-tuning (supervised) subset.
_URL = "https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz"
# SHA256 prefix used to verify the downloaded archive.
_CHECKSUM = "5d1efdc777b548194d7e09ba89126e2188026df9fd57aa57eb14408d2b2342af"
15
+
16
+
17
+ def _get_fileids_paths(path, subset, _ext_audio) -> List[Tuple[str, str]]:
18
+ """Get the file names and the corresponding file paths without `speaker_id`
19
+ and `chapter_id` directories.
20
+ The format of path is like:
21
+ {root}/{_ARCHIVE_NAME}/1h/[0-5]/[clean, other] or
22
+ {root}/{_ARCHIVE_NAME}/9h/[clean, other]
23
+ """
24
+ if subset == "10min":
25
+ files_paths = [
26
+ (os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
27
+ for p in Path(path).glob("1h/0/*/*/*/*" + _ext_audio)
28
+ ]
29
+ elif subset in ["1h", "10h"]:
30
+ files_paths = [
31
+ (os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
32
+ for p in Path(path).glob("1h/*/*/*/*/*" + _ext_audio)
33
+ ]
34
+ if subset == "10h":
35
+ files_paths += [
36
+ (os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
37
+ for p in Path(path).glob("9h/*/*/*/*" + _ext_audio)
38
+ ]
39
+ else:
40
+ raise ValueError(f"Unsupported subset value. Found {subset}.")
41
+ files_paths = sorted(files_paths, key=lambda x: x[0] + x[1])
42
+ return files_paths
43
+
44
+
45
class LibriLightLimited(Dataset):
    """Create a Dataset for LibriLightLimited, which is the supervised subset of
    LibriLight dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        subset (str, optional): The subset to use. Options: [``10min``, ``1h``, ``10h``]
            (Default: ``10min``).
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).

    Raises:
        ValueError: if ``subset`` is not one of the supported values.
        RuntimeError: if the dataset is not found and ``download`` is ``False``.
    """

    _ext_txt = ".trans.txt"
    _ext_audio = ".flac"

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "10min",
        download: bool = False,
    ) -> None:
        # Validate with an explicit ValueError rather than `assert`: asserts are
        # stripped under `python -O`, and the sibling dataset classes in this
        # package raise ValueError for invalid subset/url arguments.
        if subset not in ["10min", "1h", "10h"]:
            raise ValueError(f"`subset` must be one of ['10min', '1h', '10h']. Found {subset}.")

        root = os.fspath(root)
        self._path = os.path.join(root, _ARCHIVE_NAME)
        archive = os.path.join(root, f"{_ARCHIVE_NAME}.tgz")
        if not os.path.isdir(self._path):
            if not download:
                raise RuntimeError("Dataset not found. Please use `download=True` to download")
            if not os.path.isfile(archive):
                # Verify the download against the known SHA256 prefix.
                download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
            extract_archive(archive)
        self._fileids_paths = _get_fileids_paths(self._path, subset, self._ext_audio)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
        """Load the n-th sample from the dataset.
        Args:
            n (int): The index of the sample to be loaded
        Returns:
            (Tensor, int, str, int, int, int):
                ``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
        """
        file_path, fileid = self._fileids_paths[n]
        return load_librispeech_item(fileid, file_path, self._ext_audio, self._ext_txt)

    def __len__(self) -> int:
        return len(self._fileids_paths)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/librimix.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import List, Tuple, Union
3
+
4
+ import torch
5
+ import torchaudio
6
+ from torch.utils.data import Dataset
7
+
8
# One sample: (sample_rate, mixture waveform, per-speaker source waveforms).
SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]


class LibriMix(Dataset):
    r"""Create the *LibriMix* [:footcite:`cosentino2020librimix`] dataset.

    Args:
        root (str or Path): The path to the directory where the directory ``Libri2Mix`` or
            ``Libri3Mix`` is stored.
        subset (str, optional): The subset to use. Options: [``train-360``, ``train-100``,
            ``dev``, and ``test``] (Default: ``train-360``).
        num_speakers (int, optional): The number of speakers, which determines the directories
            to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
            N source audios. (Default: 2)
        sample_rate (int, optional): sample rate of audio files. The ``sample_rate`` determines
            which subdirectory the audio are fetched. If any of the audio has a different sample
            rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000)
        task (str, optional): the task of LibriMix.
            Options: [``enh_single``, ``enh_both``, ``sep_clean``, ``sep_noisy``]
            (Default: ``sep_clean``)

    Note:
        The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "train-360",
        num_speakers: int = 2,
        sample_rate: int = 8000,
        task: str = "sep_clean",
    ):
        self.root = Path(root) / f"Libri{num_speakers}Mix"
        # Sample rate picks the pre-generated subdirectory to read from.
        if sample_rate == 8000:
            subdir = "wav8k/min"
        elif sample_rate == 16000:
            subdir = "wav16k/min"
        else:
            raise ValueError(f"Unsupported sample rate. Found {sample_rate}.")
        self.root = self.root / subdir / subset
        self.sample_rate = sample_rate
        self.task = task
        # e.g. task "sep_clean" selects the "mix_clean" mixture directory.
        self.mix_dir = (self.root / f"mix_{task.split('_')[1]}").resolve()
        self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]
        self.files = sorted(p.name for p in self.mix_dir.glob("*wav"))

    def _load_audio(self, path) -> torch.Tensor:
        # Load one waveform and reject files whose rate differs from the
        # rate this dataset instance was constructed for.
        waveform, sample_rate = torchaudio.load(path)
        if sample_rate != self.sample_rate:
            raise ValueError(
                f"The dataset contains audio file of sample rate {sample_rate}, "
                f"but the requested sample rate is {self.sample_rate}."
            )
        return waveform

    def _load_sample(self, filename) -> SampleType:
        mixed = self._load_audio(str(self.mix_dir / filename))
        srcs = []
        # Each source waveform must match the mixture's shape exactly.
        for idx, src_dir in enumerate(self.src_dirs):
            src = self._load_audio(str(src_dir / filename))
            if mixed.shape != src.shape:
                raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{idx}]: {src.shape}")
            srcs.append(src)
        return self.sample_rate, mixed, srcs

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, key: int) -> SampleType:
        """Load the n-th sample from the dataset.
        Args:
            key (int): The index of the sample to be loaded
        Returns:
            (int, Tensor, List[Tensor]): ``(sample_rate, mix_waveform, list_of_source_waveforms)``
        """
        return self._load_sample(self.files[key])
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchaudio/datasets/librispeech.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Tuple, Union
4
+
5
+ import torchaudio
6
+ from torch import Tensor
7
+ from torch.hub import download_url_to_file
8
+ from torch.utils.data import Dataset
9
+ from torchaudio.datasets.utils import extract_archive
10
+
11
# Default subset name (doubles as the archive basename on openslr.org).
URL = "train-clean-100"
# Default top-level directory produced by extracting any LibriSpeech archive.
FOLDER_IN_ARCHIVE = "LibriSpeech"
# All subset names accepted by LIBRISPEECH's ``url`` argument.
_DATA_SUBSETS = [
    "dev-clean",
    "dev-other",
    "test-clean",
    "test-other",
    "train-clean-100",
    "train-clean-360",
    "train-other-500",
]
# SHA256 prefixes keyed by full download URL, used to verify archives.
_CHECKSUMS = {
    "http://www.openslr.org/resources/12/dev-clean.tar.gz": "76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3",  # noqa: E501
    "http://www.openslr.org/resources/12/dev-other.tar.gz": "12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365",  # noqa: E501
    "http://www.openslr.org/resources/12/test-clean.tar.gz": "39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23",  # noqa: E501
    "http://www.openslr.org/resources/12/test-other.tar.gz": "d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29",  # noqa: E501
    "http://www.openslr.org/resources/12/train-clean-100.tar.gz": "d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2",  # noqa: E501
    "http://www.openslr.org/resources/12/train-clean-360.tar.gz": "146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf",  # noqa: E501
    "http://www.openslr.org/resources/12/train-other-500.tar.gz": "ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2",  # noqa: E501
}
31
+
32
+
33
def download_librispeech(root, url):
    """Download a LibriSpeech subset archive into ``root`` and extract it.

    Args:
        root (str): directory in which the archive is stored/extracted.
        url (str): subset name, e.g. ``"train-clean-100"``.
    """
    base_url = "http://www.openslr.org/resources/12/"
    ext_archive = ".tar.gz"

    filename = url + ext_archive
    archive = os.path.join(root, filename)
    # Build the URL by string concatenation, NOT os.path.join: path joining is
    # platform-dependent (it would insert backslashes on Windows) and must not
    # be used for URLs. ``base_url`` already ends with "/".
    download_url = base_url + filename
    if not os.path.isfile(archive):
        # Skip the download if a previously fetched archive is present.
        checksum = _CHECKSUMS.get(download_url, None)
        download_url_to_file(download_url, archive, hash_prefix=checksum)
    extract_archive(archive)
44
+
45
+
46
def load_librispeech_item(
    fileid: str, path: str, ext_audio: str, ext_txt: str
) -> Tuple[Tensor, int, str, int, int, int]:
    """Load one LibriSpeech utterance: the waveform plus its transcript line.

    ``fileid`` has the form ``{speaker}-{chapter}-{utterance}``; audio lives at
    ``{path}/{speaker}/{chapter}/`` and the chapter transcript file holds one
    ``"<fileid> <text>"`` line per utterance.
    """
    speaker_id, chapter_id, utterance_id = fileid.split("-")
    chapter_dir = os.path.join(path, speaker_id, chapter_id)

    # Load audio
    fileid_audio = f"{speaker_id}-{chapter_id}-{utterance_id}"
    waveform, sample_rate = torchaudio.load(os.path.join(chapter_dir, fileid_audio + ext_audio))

    # Scan the chapter transcript file for this utterance's line.
    transcript = None
    file_text = os.path.join(chapter_dir, f"{speaker_id}-{chapter_id}{ext_txt}")
    with open(file_text) as ft:
        for line in ft:
            fileid_text, candidate = line.strip().split(" ", 1)
            if fileid_text == fileid_audio:
                transcript = candidate
                break
    if transcript is None:
        # Translation not found
        raise FileNotFoundError(f"Translation not found for {fileid_audio}")

    return (
        waveform,
        sample_rate,
        transcript,
        int(speaker_id),
        int(chapter_id),
        int(utterance_id),
    )
77
+
78
+
79
class LIBRISPEECH(Dataset):
    """Create a Dataset for *LibriSpeech* [:footcite:`7178964`].

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from,
            or the type of the dataset to dowload.
            Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
            ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and
            ``"train-other-500"``. (default: ``"train-clean-100"``)
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"LibriSpeech"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
    """

    _ext_txt = ".trans.txt"
    _ext_audio = ".flac"

    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
    ) -> None:
        # Reject unknown subset names up front.
        if url not in _DATA_SUBSETS:
            raise ValueError(f"Invalid url '{url}' given; please provide one of {_DATA_SUBSETS}.")

        root = os.fspath(root)
        self._path = os.path.join(root, folder_in_archive, url)

        if not os.path.isdir(self._path):
            if not download:
                raise RuntimeError(
                    f"Dataset not found at {self._path}. Please set `download=True` to download the dataset."
                )
            download_librispeech(root, url)

        # Index every audio file as its stem: "{speaker}-{chapter}-{utterance}".
        self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            (Tensor, int, str, int, int, int):
                ``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
        """
        return load_librispeech_item(self._walker[n], self._path, self._ext_audio, self._ext_txt)

    def __len__(self) -> int:
        return len(self._walker)