primepake committed
Commit 11eafe6 · 1 Parent(s): 72229da

add dac continous latent space

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. dac-vae/audiotools/__init__.py +10 -0
  2. dac-vae/audiotools/__pycache__/__init__.cpython-310.pyc +0 -0
  3. dac-vae/audiotools/core/__init__.py +4 -0
  4. dac-vae/audiotools/core/__pycache__/__init__.cpython-310.pyc +0 -0
  5. dac-vae/audiotools/core/__pycache__/audio_signal.cpython-310.pyc +0 -0
  6. dac-vae/audiotools/core/__pycache__/display.cpython-310.pyc +0 -0
  7. dac-vae/audiotools/core/__pycache__/dsp.cpython-310.pyc +0 -0
  8. dac-vae/audiotools/core/__pycache__/effects.cpython-310.pyc +0 -0
  9. dac-vae/audiotools/core/__pycache__/ffmpeg.cpython-310.pyc +0 -0
  10. dac-vae/audiotools/core/__pycache__/loudness.cpython-310.pyc +0 -0
  11. dac-vae/audiotools/core/__pycache__/playback.cpython-310.pyc +0 -0
  12. dac-vae/audiotools/core/__pycache__/util.cpython-310.pyc +0 -0
  13. dac-vae/audiotools/core/__pycache__/whisper.cpython-310.pyc +0 -0
  14. dac-vae/audiotools/core/audio_signal.py +1682 -0
  15. dac-vae/audiotools/core/display.py +194 -0
  16. dac-vae/audiotools/core/dsp.py +390 -0
  17. dac-vae/audiotools/core/effects.py +647 -0
  18. dac-vae/audiotools/core/ffmpeg.py +211 -0
  19. dac-vae/audiotools/core/loudness.py +320 -0
  20. dac-vae/audiotools/core/playback.py +252 -0
  21. dac-vae/audiotools/core/templates/__init__.py +0 -0
  22. dac-vae/audiotools/core/templates/__pycache__/__init__.cpython-310.pyc +0 -0
  23. dac-vae/audiotools/core/templates/headers.html +322 -0
  24. dac-vae/audiotools/core/templates/pandoc.css +407 -0
  25. dac-vae/audiotools/core/templates/widget.html +52 -0
  26. dac-vae/audiotools/core/util.py +671 -0
  27. dac-vae/audiotools/core/whisper.py +97 -0
  28. dac-vae/audiotools/data/__init__.py +3 -0
  29. dac-vae/audiotools/data/__pycache__/__init__.cpython-310.pyc +0 -0
  30. dac-vae/audiotools/data/__pycache__/datasets.cpython-310.pyc +0 -0
  31. dac-vae/audiotools/data/__pycache__/preprocess.cpython-310.pyc +0 -0
  32. dac-vae/audiotools/data/__pycache__/transforms.cpython-310.pyc +0 -0
  33. dac-vae/audiotools/data/datasets.py +517 -0
  34. dac-vae/audiotools/data/preprocess.py +81 -0
  35. dac-vae/audiotools/data/transforms.py +1592 -0
  36. dac-vae/audiotools/metrics/__init__.py +6 -0
  37. dac-vae/audiotools/metrics/__pycache__/__init__.cpython-310.pyc +0 -0
  38. dac-vae/audiotools/metrics/__pycache__/distance.cpython-310.pyc +0 -0
  39. dac-vae/audiotools/metrics/__pycache__/quality.cpython-310.pyc +0 -0
  40. dac-vae/audiotools/metrics/__pycache__/spectral.cpython-310.pyc +0 -0
  41. dac-vae/audiotools/metrics/distance.py +131 -0
  42. dac-vae/audiotools/metrics/quality.py +159 -0
  43. dac-vae/audiotools/metrics/spectral.py +247 -0
  44. dac-vae/audiotools/ml/__init__.py +5 -0
  45. dac-vae/audiotools/ml/__pycache__/__init__.cpython-310.pyc +0 -0
  46. dac-vae/audiotools/ml/__pycache__/accelerator.cpython-310.pyc +0 -0
  47. dac-vae/audiotools/ml/__pycache__/decorators.cpython-310.pyc +0 -0
  48. dac-vae/audiotools/ml/__pycache__/experiment.cpython-310.pyc +0 -0
  49. dac-vae/audiotools/ml/accelerator.py +184 -0
  50. dac-vae/audiotools/ml/decorators.py +441 -0
dac-vae/audiotools/__init__.py ADDED
@@ -0,0 +1,10 @@
__version__ = "0.7.4"
from .core import AudioSignal
from .core import STFTParams
from .core import Meter
from .core import util
from . import metrics
from . import data
from . import ml
from .data import datasets
from .data import transforms
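To illustrate the exports above, here is a minimal sketch, assuming the vendored package is importable as `audiotools` (the tensor shape and STFT sizes are arbitrary example values):

    import torch
    from audiotools import AudioSignal, STFTParams

    # Build a 1-second, single-channel signal and take its STFT.
    signal = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
    signal.stft_params = STFTParams(window_length=512, hop_length=128)
    spec = signal.stft()  # complex tensor, shape (batch, channels, freq, time)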
dac-vae/audiotools/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (453 Bytes).
 
dac-vae/audiotools/core/__init__.py ADDED
@@ -0,0 +1,4 @@
from . import util
from .audio_signal import AudioSignal
from .audio_signal import STFTParams
from .loudness import Meter
dac-vae/audiotools/core/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (304 Bytes).
 
dac-vae/audiotools/core/__pycache__/audio_signal.cpython-310.pyc ADDED
Binary file (45.5 kB).
 
dac-vae/audiotools/core/__pycache__/display.cpython-310.pyc ADDED
Binary file (6.39 kB).
 
dac-vae/audiotools/core/__pycache__/dsp.cpython-310.pyc ADDED
Binary file (11.6 kB).
 
dac-vae/audiotools/core/__pycache__/effects.cpython-310.pyc ADDED
Binary file (17.5 kB).
 
dac-vae/audiotools/core/__pycache__/ffmpeg.cpython-310.pyc ADDED
Binary file (5.63 kB).
 
dac-vae/audiotools/core/__pycache__/loudness.cpython-310.pyc ADDED
Binary file (8.47 kB).
 
dac-vae/audiotools/core/__pycache__/playback.cpython-310.pyc ADDED
Binary file (6.89 kB).
 
dac-vae/audiotools/core/__pycache__/util.cpython-310.pyc ADDED
Binary file (18.6 kB).
 
dac-vae/audiotools/core/__pycache__/whisper.cpython-310.pyc ADDED
Binary file (2.95 kB).
 
dac-vae/audiotools/core/audio_signal.py ADDED
@@ -0,0 +1,1682 @@
import copy
import functools
import hashlib
import math
import pathlib
import tempfile
import typing
import warnings
from collections import namedtuple
from pathlib import Path

import julius
import numpy as np
import soundfile
import torch

from . import util
from .display import DisplayMixin
from .dsp import DSPMixin
from .effects import EffectMixin
from .effects import ImpulseResponseMixin
from .ffmpeg import FFMPEGMixin
from .loudness import LoudnessMixin
from .playback import PlayMixin
from .whisper import WhisperMixin


STFTParams = namedtuple(
    "STFTParams",
    ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
)
"""
STFTParams object is a container that holds STFT parameters - window_length,
hop_length, and window_type. Not all parameters need to be specified. Ones that
are not specified will be inferred by the AudioSignal parameters.

Parameters
----------
window_length : int, optional
    Window length of STFT, by default ``0.032 * self.sample_rate``.
hop_length : int, optional
    Hop length of STFT, by default ``window_length // 4``.
window_type : str, optional
    Type of window to use, by default ``sqrt_hann``.
match_stride : bool, optional
    Whether to match the stride of convolutional layers, by default False
padding_type : str, optional
    Type of padding to use, by default 'reflect'
"""
STFTParams.__new__.__defaults__ = (None, None, None, None, None)
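A short sketch of how this namedtuple behaves with partial specification (the sizes are arbitrary example values): fields left unset default to None and are filled in later by the `stft_params` setter defined further down this file.

    # Only window_length and hop_length are given; the rest stay None
    # until the AudioSignal.stft_params setter fills in its defaults.
    params = STFTParams(window_length=512, hop_length=128)
    assert params.window_type is None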


class AudioSignal(
    EffectMixin,
    LoudnessMixin,
    PlayMixin,
    ImpulseResponseMixin,
    DSPMixin,
    DisplayMixin,
    FFMPEGMixin,
    WhisperMixin,
):
    """This is the core object of this library. Audio is always
    loaded into an AudioSignal, which then enables all the features
    of this library, including audio augmentations, I/O, playback,
    and more.

    The structure of this object is that the base functionality
    is defined in ``core/audio_signal.py``, while extensions to
    that functionality are defined in the other ``core/*.py``
    files. For example, all the display-based functionality
    (e.g. plot spectrograms, waveforms, write to tensorboard)
    are in ``core/display.py``.

    Parameters
    ----------
    audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray]
        Object to create AudioSignal from. Can be a tensor, numpy array,
        or a path to a file. The audio is always reshaped to
        (batch, channels, samples).
    sample_rate : int, optional
        Sample rate of the audio. If different from underlying file, resampling is
        performed. If passing in an array or tensor, this must be defined,
        by default None
    stft_params : STFTParams, optional
        Parameters of STFT to use, by default None
    offset : float, optional
        Offset in seconds to read from file, by default 0
    duration : float, optional
        Duration in seconds to read from file, by default None
    device : str, optional
        Device to load audio onto, by default None

    Examples
    --------
    Loading an AudioSignal from an array, at a sample rate of
    44100.

    >>> signal = AudioSignal(torch.randn(5*44100), 44100)

    Note, the signal is reshaped to have a batch size, and one
    audio channel:

    >>> print(signal.shape)
    (1, 1, 220500)

    You can treat AudioSignals like tensors, and many of the same
    functions you might use on tensors are defined for AudioSignals
    as well:

    >>> signal.to("cuda")
    >>> signal.cuda()
    >>> signal.clone()
    >>> signal.detach()

    Indexing AudioSignals returns an AudioSignal:

    >>> signal[..., 3*44100:4*44100]

    The above signal is 1 second long, and is also an AudioSignal.
    """

    def __init__(
        self,
        audio_path_or_array: typing.Union[torch.Tensor, str, Path, np.ndarray],
        sample_rate: int = None,
        stft_params: STFTParams = None,
        offset: float = 0,
        duration: float = None,
        device: str = None,
    ):
        audio_path = None
        audio_array = None

        if isinstance(audio_path_or_array, str):
            audio_path = audio_path_or_array
        elif isinstance(audio_path_or_array, pathlib.Path):
            audio_path = audio_path_or_array
        elif isinstance(audio_path_or_array, np.ndarray):
            audio_array = audio_path_or_array
        elif torch.is_tensor(audio_path_or_array):
            audio_array = audio_path_or_array
        else:
            raise ValueError(
                "audio_path_or_array must be either a Path, "
                "string, numpy array, or torch Tensor!"
            )

        self.path_to_file = None

        self.audio_data = None
        self.sources = None  # List of AudioSignal objects.
        self.stft_data = None
        if audio_path is not None:
            self.load_from_file(
                audio_path, offset=offset, duration=duration, device=device
            )
        elif audio_array is not None:
            assert sample_rate is not None, "Must set sample rate!"
            self.load_from_array(audio_array, sample_rate, device=device)

        self.window = None
        self.stft_params = stft_params

        self.metadata = {
            "offset": offset,
            "duration": duration,
        }

    @property
    def path_to_input_file(
        self,
    ):
        """
        Path to input file, if it exists.
        Alias to ``path_to_file`` for backwards compatibility
        """
        return self.path_to_file

    @classmethod
    def excerpt(
        cls,
        audio_path: typing.Union[str, Path],
        offset: float = None,
        duration: float = None,
        state: typing.Union[np.random.RandomState, int] = None,
        **kwargs,
    ):
        """Randomly draw an excerpt of ``duration`` seconds from an
        audio file specified at ``audio_path``, between ``offset`` seconds
        and end of file. ``state`` can be used to seed the random draw.

        Parameters
        ----------
        audio_path : typing.Union[str, Path]
            Path to audio file to grab excerpt from.
        offset : float, optional
            Lower bound for the start time of the excerpt, in seconds,
            by default None.
        duration : float, optional
            Duration of excerpt, in seconds, by default None
        state : typing.Union[np.random.RandomState, int], optional
            RandomState or seed of random state, by default None

        Returns
        -------
        AudioSignal
            AudioSignal containing excerpt.

        Examples
        --------
        >>> signal = AudioSignal.excerpt("path/to/audio", duration=5)
        """
        info = util.info(audio_path)
        total_duration = info.duration

        state = util.random_state(state)
        lower_bound = 0 if offset is None else offset
        upper_bound = max(total_duration - duration, 0)
        offset = state.uniform(lower_bound, upper_bound)

        signal = cls(audio_path, offset=offset, duration=duration, **kwargs)
        signal.metadata["offset"] = offset
        signal.metadata["duration"] = duration

        return signal

    @classmethod
    def salient_excerpt(
        cls,
        audio_path: typing.Union[str, Path],
        loudness_cutoff: float = None,
        num_tries: int = 8,
        state: typing.Union[np.random.RandomState, int] = None,
        **kwargs,
    ):
        """Similar to AudioSignal.excerpt, except it extracts excerpts only
        if they are above a specified loudness threshold, which is computed via
        a fast LUFS routine.

        Parameters
        ----------
        audio_path : typing.Union[str, Path]
            Path to audio file to grab excerpt from.
        loudness_cutoff : float, optional
            Loudness threshold in dB. Typical values are ``-40, -60``,
            etc, by default None
        num_tries : int, optional
            Number of tries to grab an excerpt above the threshold
            before giving up, by default 8.
        state : typing.Union[np.random.RandomState, int], optional
            RandomState or seed of random state, by default None
        kwargs : dict
            Keyword arguments to AudioSignal.excerpt

        Returns
        -------
        AudioSignal
            AudioSignal containing excerpt.


        .. warning::
            If ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can
            result in an infinite loop if ``audio_path`` does not have
            any loud enough excerpts.

        Examples
        --------
        >>> signal = AudioSignal.salient_excerpt(
                "path/to/audio",
                loudness_cutoff=-40,
                duration=5
            )
        """
        state = util.random_state(state)
        if loudness_cutoff is None:
            excerpt = cls.excerpt(audio_path, state=state, **kwargs)
        else:
            loudness = -np.inf
            num_try = 0
            while loudness <= loudness_cutoff:
                excerpt = cls.excerpt(audio_path, state=state, **kwargs)
                loudness = excerpt.loudness()
                num_try += 1
                if num_tries is not None and num_try >= num_tries:
                    break
        return excerpt

    @classmethod
    def zeros(
        cls,
        duration: float,
        sample_rate: int,
        num_channels: int = 1,
        batch_size: int = 1,
        **kwargs,
    ):
        """Helper function to create an AudioSignal of all zeros.

        Parameters
        ----------
        duration : float
            Duration of AudioSignal
        sample_rate : int
            Sample rate of AudioSignal
        num_channels : int, optional
            Number of channels, by default 1
        batch_size : int, optional
            Batch size, by default 1

        Returns
        -------
        AudioSignal
            AudioSignal containing all zeros.

        Examples
        --------
        Generate 5 seconds of all zeros at a sample rate of 44100.

        >>> signal = AudioSignal.zeros(5.0, 44100)
        """
        n_samples = int(duration * sample_rate)
        return cls(
            torch.zeros(batch_size, num_channels, n_samples), sample_rate, **kwargs
        )

    @classmethod
    def wave(
        cls,
        frequency: float,
        duration: float,
        sample_rate: int,
        num_channels: int = 1,
        shape: str = "sine",
        **kwargs,
    ):
        """
        Generate a waveform of a given frequency and shape.

        Parameters
        ----------
        frequency : float
            Frequency of the waveform
        duration : float
            Duration of the waveform
        sample_rate : int
            Sample rate of the waveform
        num_channels : int, optional
            Number of channels, by default 1
        shape : str, optional
            Shape of the waveform, by default "sine".
            One of "sawtooth", "square", "sine", "triangle"
        kwargs : dict
            Keyword arguments to AudioSignal
        """
        n_samples = int(duration * sample_rate)
        t = torch.linspace(0, duration, n_samples)
        if shape == "sawtooth":
            from scipy.signal import sawtooth

            wave_data = sawtooth(2 * np.pi * frequency * t, 0.5)
        elif shape == "square":
            from scipy.signal import square

            wave_data = square(2 * np.pi * frequency * t)
        elif shape == "sine":
            wave_data = np.sin(2 * np.pi * frequency * t)
        elif shape == "triangle":
            from scipy.signal import sawtooth

            # frequency is doubled by the abs call, so omit the 2 in 2pi
            wave_data = sawtooth(np.pi * frequency * t, 0.5)
            wave_data = -np.abs(wave_data) * 2 + 1
        else:
            raise ValueError(f"Invalid shape {shape}")

        wave_data = torch.tensor(wave_data, dtype=torch.float32)
        wave_data = wave_data.unsqueeze(0).unsqueeze(0).repeat(1, num_channels, 1)
        return cls(wave_data, sample_rate, **kwargs)

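A quick sketch of the factory above; the frequency and duration are arbitrary example values.

    # 1 second of a 440 Hz sine tone, mono, batch size 1.
    tone = AudioSignal.wave(frequency=440.0, duration=1.0, sample_rate=44100)
    print(tone.shape)  # torch.Size([1, 1, 44100])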
    @classmethod
    def batch(
        cls,
        audio_signals: list,
        pad_signals: bool = False,
        truncate_signals: bool = False,
        resample: bool = False,
        dim: int = 0,
    ):
        """Creates a batched AudioSignal from a list of AudioSignals.

        Parameters
        ----------
        audio_signals : list[AudioSignal]
            List of AudioSignal objects
        pad_signals : bool, optional
            Whether to pad signals to length of the maximum length
            AudioSignal in the list, by default False
        truncate_signals : bool, optional
            Whether to truncate signals to length of shortest length
            AudioSignal in the list, by default False
        resample : bool, optional
            Whether to resample AudioSignal to the sample rate of
            the first AudioSignal in the list, by default False
        dim : int, optional
            Dimension along which to batch the signals.

        Returns
        -------
        AudioSignal
            Batched AudioSignal.

        Raises
        ------
        RuntimeError
            If not all AudioSignals are the same sample rate, and
            ``resample=False``, an error is raised.
        RuntimeError
            If not all AudioSignals are the same length, and
            both ``pad_signals=False`` and ``truncate_signals=False``,
            an error is raised.

        Examples
        --------
        Batching a bunch of random signals:

        >>> signal_list = [AudioSignal(torch.randn(44100), 44100) for _ in range(10)]
        >>> signal = AudioSignal.batch(signal_list)
        >>> print(signal.shape)
        (10, 1, 44100)

        """
        signal_lengths = [x.signal_length for x in audio_signals]
        sample_rates = [x.sample_rate for x in audio_signals]

        if len(set(sample_rates)) != 1:
            if resample:
                for x in audio_signals:
                    x.resample(sample_rates[0])
            else:
                raise RuntimeError(
                    f"Not all signals had the same sample rate! Got {sample_rates}. "
                    f"All signals must have the same sample rate, or resample must be True. "
                )

        if len(set(signal_lengths)) != 1:
            if pad_signals:
                max_length = max(signal_lengths)
                for x in audio_signals:
                    pad_len = max_length - x.signal_length
                    x.zero_pad(0, pad_len)
            elif truncate_signals:
                min_length = min(signal_lengths)
                for x in audio_signals:
                    x.truncate_samples(min_length)
            else:
                raise RuntimeError(
                    f"Not all signals had the same length! Got {signal_lengths}. "
                    f"All signals must be the same length, or pad_signals/truncate_signals "
                    f"must be True. "
                )
        # Concatenate along the specified dimension (default 0)
        audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim)
        audio_paths = [x.path_to_file for x in audio_signals]

        batched_signal = cls(
            audio_data,
            sample_rate=audio_signals[0].sample_rate,
        )
        batched_signal.path_to_file = audio_paths
        return batched_signal

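A sketch of batching two clips of different lengths with `pad_signals=True` (the lengths are arbitrary example values):

    # The shorter clip is zero-padded to the longer one, then the two
    # are concatenated along the batch dimension.
    clips = [AudioSignal(torch.randn(1, 1, n), 44100) for n in (22050, 44100)]
    batch = AudioSignal.batch(clips, pad_signals=True)
    print(batch.shape)  # torch.Size([2, 1, 44100])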
    # I/O
    def load_from_file(
        self,
        audio_path: typing.Union[str, Path],
        offset: float,
        duration: float,
        device: str = "cpu",
    ):
        """Loads data from file. Used internally when AudioSignal
        is instantiated with a path to a file.

        Parameters
        ----------
        audio_path : typing.Union[str, Path]
            Path to file
        offset : float
            Offset in seconds
        duration : float
            Duration in seconds
        device : str, optional
            Device to put AudioSignal on, by default "cpu"

        Returns
        -------
        AudioSignal
            AudioSignal loaded from file
        """
        import librosa

        data, sample_rate = librosa.load(
            audio_path,
            offset=offset,
            duration=duration,
            sr=None,
            mono=False,
        )
        data = util.ensure_tensor(data)
        if data.shape[-1] == 0:
            raise RuntimeError(
                f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!"
            )

        if data.ndim < 2:
            data = data.unsqueeze(0)
        if data.ndim < 3:
            data = data.unsqueeze(0)
        self.audio_data = data

        self.original_signal_length = self.signal_length

        self.sample_rate = sample_rate
        self.path_to_file = audio_path
        return self.to(device)

    def load_from_array(
        self,
        audio_array: typing.Union[torch.Tensor, np.ndarray],
        sample_rate: int,
        device: str = "cpu",
    ):
        """Loads data from array, reshaping it to be exactly 3
        dimensions. Used internally when AudioSignal is called
        with a tensor or an array.

        Parameters
        ----------
        audio_array : typing.Union[torch.Tensor, np.ndarray]
            Array/tensor of audio samples.
        sample_rate : int
            Sample rate of audio
        device : str, optional
            Device to move audio onto, by default "cpu"

        Returns
        -------
        AudioSignal
            AudioSignal loaded from array
        """
        audio_data = util.ensure_tensor(audio_array)

        if audio_data.dtype == torch.double:
            audio_data = audio_data.float()

        if audio_data.ndim < 2:
            audio_data = audio_data.unsqueeze(0)
        if audio_data.ndim < 3:
            audio_data = audio_data.unsqueeze(0)
        self.audio_data = audio_data

        self.original_signal_length = self.signal_length

        self.sample_rate = sample_rate
        return self.to(device)

    def write(self, audio_path: typing.Union[str, Path]):
        """Writes audio to a file. Only writes the audio
        that is in the very first item of the batch. To write other items
        in the batch, index the signal along the batch dimension
        before writing. After writing, the signal's ``path_to_file``
        attribute is updated to the new path.

        Parameters
        ----------
        audio_path : typing.Union[str, Path]
            Path to write audio to.

        Returns
        -------
        AudioSignal
            Returns original AudioSignal, so you can use this in a fluent
            interface.

        Examples
        --------
        Creating and writing a signal to disk:

        >>> signal = AudioSignal(torch.randn(10, 1, 44100), 44100)
        >>> signal.write("/tmp/out.wav")

        Writing a different element of the batch:

        >>> signal[5].write("/tmp/out.wav")

        Using this in a fluent interface:

        >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav")

        """
        if self.audio_data[0].abs().max() > 1:
            warnings.warn("Audio amplitude > 1 clipped when saving")
        soundfile.write(str(audio_path), self.audio_data[0].numpy().T, self.sample_rate)

        self.path_to_file = audio_path
        return self

    def deepcopy(self):
        """Copies the signal and all of its attributes.

        Returns
        -------
        AudioSignal
            Deep copy of the audio signal.
        """
        return copy.deepcopy(self)

    def copy(self):
        """Shallow copy of signal.

        Returns
        -------
        AudioSignal
            Shallow copy of the audio signal.
        """
        return copy.copy(self)

    def clone(self):
        """Clones all tensors contained in the AudioSignal,
        and returns a copy of the signal with everything
        cloned. Useful when using AudioSignal within autograd
        computation graphs.

        Relevant attributes are the stft data, the audio data,
        and the loudness of the file.

        Returns
        -------
        AudioSignal
            Clone of AudioSignal.
        """
        clone = type(self)(
            self.audio_data.clone(),
            self.sample_rate,
            stft_params=self.stft_params,
        )
        if self.stft_data is not None:
            clone.stft_data = self.stft_data.clone()
        if self._loudness is not None:
            clone._loudness = self._loudness.clone()
        clone.path_to_file = copy.deepcopy(self.path_to_file)
        clone.metadata = copy.deepcopy(self.metadata)
        return clone

    def detach(self):
        """Detaches tensors contained in AudioSignal.

        Relevant attributes are the stft data, the audio data,
        and the loudness of the file.

        Returns
        -------
        AudioSignal
            Same signal, but with all tensors detached.
        """
        if self._loudness is not None:
            self._loudness = self._loudness.detach()
        if self.stft_data is not None:
            self.stft_data = self.stft_data.detach()

        self.audio_data = self.audio_data.detach()
        return self

    def hash(self):
        """Writes the audio data to a temporary file, and then
        hashes it using hashlib. Useful for creating a file
        name based on the audio content.

        Returns
        -------
        str
            Hash of audio data.

        Examples
        --------
        Creating a signal, and writing it to a unique file name:

        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> hash = signal.hash()
        >>> signal.write(f"{hash}.wav")

        """
        with tempfile.NamedTemporaryFile(suffix=".wav") as f:
            self.write(f.name)
            h = hashlib.sha256()
            b = bytearray(128 * 1024)
            mv = memoryview(b)
            with open(f.name, "rb", buffering=0) as f:
                for n in iter(lambda: f.readinto(mv), 0):
                    h.update(mv[:n])
            file_hash = h.hexdigest()
        return file_hash

    # Signal operations
    def to_mono(self):
        """Converts audio data to mono audio, by taking the mean
        along the channels dimension.

        Returns
        -------
        AudioSignal
            AudioSignal with mean of channels.
        """
        self.audio_data = self.audio_data.mean(1, keepdim=True)
        return self

    def resample(self, sample_rate: int):
        """Resamples the audio, using sinc interpolation. This works on both
        cpu and gpu, and is much faster on gpu.

        Parameters
        ----------
        sample_rate : int
            Sample rate to resample to.

        Returns
        -------
        AudioSignal
            Resampled AudioSignal
        """
        if sample_rate == self.sample_rate:
            return self
        self.audio_data = julius.resample_frac(
            self.audio_data, self.sample_rate, sample_rate
        )
        self.sample_rate = sample_rate
        return self

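Since each of the signal operations above returns the signal itself, they compose fluently — a brief sketch with arbitrary rates:

    # Downmix stereo to mono, then resample 48 kHz -> 16 kHz.
    sig = AudioSignal(torch.randn(1, 2, 48000), sample_rate=48000)
    sig = sig.to_mono().resample(16000)
    print(sig.num_channels, sig.sample_rate)  # 1 16000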
    # Tensor operations
    def to(self, device: str):
        """Moves all tensors contained in signal to the specified device.

        Parameters
        ----------
        device : str
            Device to move AudioSignal onto. Typical values are
            "cuda", "cpu", or "cuda:n" to specify the nth gpu.

        Returns
        -------
        AudioSignal
            AudioSignal with all tensors moved to specified device.
        """
        if self._loudness is not None:
            self._loudness = self._loudness.to(device)
        if self.stft_data is not None:
            self.stft_data = self.stft_data.to(device)
        if self.audio_data is not None:
            self.audio_data = self.audio_data.to(device)
        return self

    def float(self):
        """Calls ``.float()`` on ``self.audio_data``.

        Returns
        -------
        AudioSignal
        """
        self.audio_data = self.audio_data.float()
        return self

    def cpu(self):
        """Moves AudioSignal to cpu.

        Returns
        -------
        AudioSignal
        """
        return self.to("cpu")

    def cuda(self):  # pragma: no cover
        """Moves AudioSignal to cuda.

        Returns
        -------
        AudioSignal
        """
        return self.to("cuda")

    def numpy(self):
        """Detaches ``self.audio_data``, moves to cpu, and converts to numpy.

        Returns
        -------
        np.ndarray
            Audio data as a numpy array.
        """
        return self.audio_data.detach().cpu().numpy()

    def zero_pad(self, before: int, after: int):
        """Zero pads the audio_data tensor before and after.

        Parameters
        ----------
        before : int
            How many zeros to prepend to audio.
        after : int
            How many zeros to append to audio.

        Returns
        -------
        AudioSignal
            AudioSignal with padding applied.
        """
        self.audio_data = torch.nn.functional.pad(self.audio_data, (before, after))
        return self

    def zero_pad_to(self, length: int, mode: str = "after"):
        """Pad with zeros to a specified length, either before or after
        the audio data.

        Parameters
        ----------
        length : int
            Length to pad to
        mode : str, optional
            Whether to prepend or append zeros to signal, by default "after"

        Returns
        -------
        AudioSignal
            AudioSignal with padding applied.
        """
        if mode == "before":
            self.zero_pad(max(length - self.signal_length, 0), 0)
        elif mode == "after":
            self.zero_pad(0, max(length - self.signal_length, 0))
        return self

    def trim(self, before: int, after: int):
        """Trims the audio_data tensor before and after.

        Parameters
        ----------
        before : int
            How many samples to trim from beginning.
        after : int
            How many samples to trim from end.

        Returns
        -------
        AudioSignal
            AudioSignal with trimming applied.
        """
        if after == 0:
            self.audio_data = self.audio_data[..., before:]
        else:
            self.audio_data = self.audio_data[..., before:-after]
        return self

    def truncate_samples(self, length_in_samples: int):
        """Truncate signal to specified length.

        Parameters
        ----------
        length_in_samples : int
            Truncate to this many samples.

        Returns
        -------
        AudioSignal
            AudioSignal with truncation applied.
        """
        self.audio_data = self.audio_data[..., :length_in_samples]
        return self

    @property
    def device(self):
        """Get device that AudioSignal is on.

        Returns
        -------
        torch.device
            Device that AudioSignal is on.
        """
        if self.audio_data is not None:
            device = self.audio_data.device
        elif self.stft_data is not None:
            device = self.stft_data.device
        return device

    # Properties
    @property
    def audio_data(self):
        """Returns the audio data tensor in the object.

        Audio data is always of the shape
        (batch_size, num_channels, num_samples). If value has less
        than 3 dims (e.g. is (num_channels, num_samples)), then it will
        be reshaped to (1, num_channels, num_samples) - a batch size of 1.

        Parameters
        ----------
        data : typing.Union[torch.Tensor, np.ndarray]
            Audio data to set.

        Returns
        -------
        torch.Tensor
            Audio samples.
        """
        return self._audio_data

    @audio_data.setter
    def audio_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
        if data is not None:
            assert torch.is_tensor(data), "audio_data should be torch.Tensor"
            assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)"
        self._audio_data = data
        # Old loudness value not guaranteed to be right, reset it.
        self._loudness = None
        return

    # alias for audio_data
    samples = audio_data

    @property
    def stft_data(self):
        """Returns the STFT data inside the signal. Shape is
        (batch, channels, frequencies, time).

        Returns
        -------
        torch.Tensor
            Complex spectrogram data.
        """
        return self._stft_data

    @stft_data.setter
    def stft_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
        if data is not None:
            assert torch.is_tensor(data) and torch.is_complex(data)
            if self.stft_data is not None and self.stft_data.shape != data.shape:
                warnings.warn("stft_data changed shape")
        self._stft_data = data
        return

    @property
    def batch_size(self):
        """Batch size of audio signal.

        Returns
        -------
        int
            Batch size of signal.
        """
        return self.audio_data.shape[0]

    @property
    def signal_length(self):
        """Length of audio signal.

        Returns
        -------
        int
            Length of signal in samples.
        """
        return self.audio_data.shape[-1]

    # alias for signal_length
    length = signal_length

    @property
    def shape(self):
        """Shape of audio data.

        Returns
        -------
        tuple
            Shape of audio data.
        """
        return self.audio_data.shape

    @property
    def signal_duration(self):
        """Length of audio signal in seconds.

        Returns
        -------
        float
            Length of signal in seconds.
        """
        return self.signal_length / self.sample_rate

    # alias for signal_duration
    duration = signal_duration

    @property
    def num_channels(self):
        """Number of audio channels.

        Returns
        -------
        int
            Number of audio channels.
        """
        return self.audio_data.shape[1]

    # STFT
    @staticmethod
    @functools.lru_cache(None)
    def get_window(window_type: str, window_length: int, device: str):
        """Wrapper around scipy.signal.get_window so one can also get the
        popular sqrt-hann window. This function caches for efficiency
        using functools.lru_cache.

        Parameters
        ----------
        window_type : str
            Type of window to get
        window_length : int
            Length of the window
        device : str
            Device to put window onto.

        Returns
        -------
        torch.Tensor
            Window returned by scipy.signal.get_window, as a tensor.
        """
        from scipy import signal

        if window_type == "average":
            window = np.ones(window_length) / window_length
        elif window_type == "sqrt_hann":
            window = np.sqrt(signal.get_window("hann", window_length))
        else:
            window = signal.get_window(window_type, window_length)
        window = torch.from_numpy(window).to(device).float()
        return window

    @property
    def stft_params(self):
        """Returns STFTParams object, which can be re-used for other
        AudioSignals.

        This property can be set as well. If values are not defined in STFTParams,
        they are inferred automatically from the signal properties. The default is to use
        32ms windows, with 8ms hop length, and a hann window.

        Returns
        -------
        STFTParams
            STFT parameters for the AudioSignal.

        Examples
        --------
        >>> stft_params = STFTParams(128, 32)
        >>> signal1 = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params)
        >>> signal2 = AudioSignal(torch.randn(44100), 44100, stft_params=signal1.stft_params)
        >>> signal1.stft_params = STFTParams()  # Defaults
        """
        return self._stft_params

    @stft_params.setter
    def stft_params(self, value: STFTParams):
        default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate))))
        default_hop_len = default_win_len // 4
        default_win_type = "hann"
        default_match_stride = False
        default_padding_type = "reflect"

        default_stft_params = STFTParams(
            window_length=default_win_len,
            hop_length=default_hop_len,
            window_type=default_win_type,
            match_stride=default_match_stride,
            padding_type=default_padding_type,
        )._asdict()

        value = value._asdict() if value else default_stft_params

        for key in default_stft_params:
            if value[key] is None:
                value[key] = default_stft_params[key]

        self._stft_params = STFTParams(**value)
        self.stft_data = None

    def compute_stft_padding(
        self, window_length: int, hop_length: int, match_stride: bool
    ):
        """Compute how the STFT should be padded, based on match_stride.

        Parameters
        ----------
        window_length : int
            Window length of STFT.
        hop_length : int
            Hop length of STFT.
        match_stride : bool
            Whether or not to match stride, making the STFT have the same alignment as
            convolutional layers.

        Returns
        -------
        tuple
            Amount to pad on either side of audio.
        """
        length = self.signal_length

        if match_stride:
            assert (
                hop_length == window_length // 4
            ), "For match_stride, hop must equal n_fft // 4"
            right_pad = math.ceil(length / hop_length) * hop_length - length
            pad = (window_length - hop_length) // 2
        else:
            right_pad = 0
            pad = 0

        return right_pad, pad

    def stft(
        self,
        window_length: int = None,
        hop_length: int = None,
        window_type: str = None,
        match_stride: bool = None,
        padding_type: str = None,
    ):
        """Computes the short-time Fourier transform of the audio data,
        with specified STFT parameters.

        Parameters
        ----------
        window_length : int, optional
            Window length of STFT, by default ``0.032 * self.sample_rate``.
        hop_length : int, optional
            Hop length of STFT, by default ``window_length // 4``.
        window_type : str, optional
            Type of window to use, by default ``sqrt_hann``.
        match_stride : bool, optional
            Whether to match the stride of convolutional layers, by default False
        padding_type : str, optional
            Type of padding to use, by default 'reflect'

        Returns
        -------
        torch.Tensor
            STFT of audio data.

        Examples
        --------
        Compute the STFT of an AudioSignal:

        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> signal.stft()

        Vary the window and hop length:

        >>> stft_params = [STFTParams(128, 32), STFTParams(512, 128)]
        >>> for stft_param in stft_params:
        >>>     signal.stft_params = stft_param
        >>>     signal.stft()

        """
        window_length = (
            self.stft_params.window_length
            if window_length is None
            else int(window_length)
        )
        hop_length = (
            self.stft_params.hop_length if hop_length is None else int(hop_length)
        )
        window_type = (
            self.stft_params.window_type if window_type is None else window_type
        )
        match_stride = (
            self.stft_params.match_stride if match_stride is None else match_stride
        )
        padding_type = (
            self.stft_params.padding_type if padding_type is None else padding_type
        )

        window = self.get_window(window_type, window_length, self.audio_data.device)
        window = window.to(self.audio_data.device)

        audio_data = self.audio_data
        right_pad, pad = self.compute_stft_padding(
            window_length, hop_length, match_stride
        )
        audio_data = torch.nn.functional.pad(
            audio_data, (pad, pad + right_pad), padding_type
        )
        stft_data = torch.stft(
            audio_data.reshape(-1, audio_data.shape[-1]),
            n_fft=window_length,
            hop_length=hop_length,
            window=window,
            return_complex=True,
            center=True,
        )
        _, nf, nt = stft_data.shape
        stft_data = stft_data.reshape(self.batch_size, self.num_channels, nf, nt)

        if match_stride:
            # Drop first two and last two frames, which are added
            # because of padding. Now num_frames * hop_length = num_samples.
            stft_data = stft_data[..., 2:-2]
        self.stft_data = stft_data

        return stft_data

    def istft(
        self,
        window_length: int = None,
        hop_length: int = None,
        window_type: str = None,
        match_stride: bool = None,
        length: int = None,
    ):
        """Computes inverse STFT and sets it to audio_data.

        Parameters
        ----------
        window_length : int, optional
            Window length of STFT, by default ``0.032 * self.sample_rate``.
        hop_length : int, optional
            Hop length of STFT, by default ``window_length // 4``.
        window_type : str, optional
            Type of window to use, by default ``sqrt_hann``.
        match_stride : bool, optional
            Whether to match the stride of convolutional layers, by default False
        length : int, optional
            Original length of signal, by default None

        Returns
        -------
        AudioSignal
            AudioSignal with istft applied.

        Raises
        ------
        RuntimeError
            Raises an error if stft was not called prior to istft on the signal,
            or if stft_data is not set.
        """
        if self.stft_data is None:
            raise RuntimeError("Cannot do inverse STFT without self.stft_data!")

        window_length = (
            self.stft_params.window_length
            if window_length is None
            else int(window_length)
        )
        hop_length = (
            self.stft_params.hop_length if hop_length is None else int(hop_length)
        )
        window_type = (
            self.stft_params.window_type if window_type is None else window_type
        )
        match_stride = (
            self.stft_params.match_stride if match_stride is None else match_stride
        )

        window = self.get_window(window_type, window_length, self.stft_data.device)

        nb, nch, nf, nt = self.stft_data.shape
        stft_data = self.stft_data.reshape(nb * nch, nf, nt)
        right_pad, pad = self.compute_stft_padding(
            window_length, hop_length, match_stride
        )

        if length is None:
            length = self.original_signal_length
            length = length + 2 * pad + right_pad

        if match_stride:
            # Zero-pad the STFT on either side, putting back the frames that were
            # dropped in stft().
            stft_data = torch.nn.functional.pad(stft_data, (2, 2))

        audio_data = torch.istft(
            stft_data,
            n_fft=window_length,
            hop_length=hop_length,
            window=window,
            length=length,
            center=True,
        )
        audio_data = audio_data.reshape(nb, nch, -1)
        if match_stride:
            audio_data = audio_data[..., pad : -(pad + right_pad)]
        self.audio_data = audio_data

        return self

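A round-trip sketch for the two methods above (window sizes are arbitrary example values): `stft()` caches the complex spectrogram on the signal, and `istft()` reconstructs `audio_data` from it, trimming back to the original length.

    sig = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
    sig.stft_params = STFTParams(window_length=512, hop_length=128)
    sig.stft()                # fills sig.stft_data (batch, channels, freq, time)
    sig.istft()               # inverts it back into sig.audio_data
    print(sig.signal_length)  # 44100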
    @staticmethod
    @functools.lru_cache(None)
    def get_mel_filters(
        sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None
    ):
        """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins.

        Parameters
        ----------
        sr : int
            Sample rate of audio
        n_fft : int
            Number of FFT bins
        n_mels : int
            Number of mels
        fmin : float, optional
            Lowest frequency, in Hz, by default 0.0
        fmax : float, optional
            Highest frequency, by default None

        Returns
        -------
        np.ndarray [shape=(n_mels, 1 + n_fft/2)]
            Mel transform matrix
        """
        from librosa.filters import mel as librosa_mel_fn

        return librosa_mel_fn(
            sr=sr,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
        )

    def mel_spectrogram(
        self, n_mels: int = 80, mel_fmin: float = 0.0, mel_fmax: float = None, **kwargs
    ):
        """Computes a Mel spectrogram.

        Parameters
        ----------
        n_mels : int, optional
            Number of mels, by default 80
        mel_fmin : float, optional
            Lowest frequency, in Hz, by default 0.0
        mel_fmax : float, optional
            Highest frequency, by default None
        kwargs : dict, optional
            Keyword arguments to self.stft().

        Returns
        -------
        torch.Tensor [shape=(batch, channels, mels, time)]
            Mel spectrogram.
        """
        stft = self.stft(**kwargs)
        magnitude = torch.abs(stft)

        nf = magnitude.shape[2]
        mel_basis = self.get_mel_filters(
            sr=self.sample_rate,
            n_fft=2 * (nf - 1),
            n_mels=n_mels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis = torch.from_numpy(mel_basis).to(self.device)

        mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
        mel_spectrogram = mel_spectrogram.transpose(-1, 2)
        return mel_spectrogram

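A short sketch of the mel pipeline above, using the method's default of 80 mels (the input shape is an arbitrary example):

    sig = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
    mels = sig.mel_spectrogram(n_mels=80)
    print(mels.shape)  # torch.Size([1, 1, 80, num_frames])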
    @staticmethod
    @functools.lru_cache(None)
    def get_dct(n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None):
        """Create a discrete cosine transform (DCT) transformation matrix with shape
        (``n_mels``, ``n_mfcc``); it can be normalized depending on ``norm``. For more
        information about the DCT:
        http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II

        Parameters
        ----------
        n_mfcc : int
            Number of MFCCs
        n_mels : int
            Number of mels
        norm : str
            Use "ortho" to get an orthogonal matrix, or None, by default "ortho"
        device : str, optional
            Device to load the transformation matrix on, by default None

        Returns
        -------
        torch.Tensor [shape=(n_mels, n_mfcc)]
            The dct transformation matrix.
        """
        from torchaudio.functional import create_dct

        return create_dct(n_mfcc, n_mels, norm).to(device)

    def mfcc(
        self, n_mfcc: int = 40, n_mels: int = 80, log_offset: float = 1e-6, **kwargs
    ):
        """Computes mel-frequency cepstral coefficients (MFCCs).

        Parameters
        ----------
        n_mfcc : int, optional
            Number of MFCCs, by default 40
        n_mels : int, optional
            Number of mels, by default 80
        log_offset: float, optional
            Small value to prevent numerical issues when trying to compute log(0), by default 1e-6
        kwargs : dict, optional
            Keyword arguments to self.mel_spectrogram(), note that some of them will be used for self.stft()

        Returns
        -------
        torch.Tensor [shape=(batch, channels, mfccs, time)]
            MFCCs.
        """

        mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs)
        mel_spectrogram = torch.log(mel_spectrogram + log_offset)
        dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device)

        mfcc = mel_spectrogram.transpose(-1, -2) @ dct_mat
        mfcc = mfcc.transpose(-1, -2)
        return mfcc

    @property
    def magnitude(self):
        """Computes and returns the absolute value of the STFT, which
        is the magnitude. This value can also be set to some tensor.
        When set, ``self.stft_data`` is manipulated so that its magnitude
        matches what this is set to, and modulated by the phase.

        Returns
        -------
        torch.Tensor
            Magnitude of STFT.

        Examples
        --------
        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> magnitude = signal.magnitude  # Computes stft if not computed
        >>> magnitude[magnitude < magnitude.mean()] = 0
        >>> signal.magnitude = magnitude
        >>> signal.istft()
        """
        if self.stft_data is None:
            self.stft()
        return torch.abs(self.stft_data)

    @magnitude.setter
    def magnitude(self, value):
        self.stft_data = value * torch.exp(1j * self.phase)
        return

    def log_magnitude(
        self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0
    ):
        """Computes the log-magnitude of the spectrogram.

        Parameters
        ----------
        ref_value : float, optional
            The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``.
            Zeros in the output correspond to positions where ``S == ref``,
            by default 1.0
        amin : float, optional
            Minimum threshold for ``S`` and ``ref``, by default 1e-5
        top_db : float, optional
            Threshold the output at ``top_db`` below the peak:
            ``max(10 * log10(S/ref)) - top_db``, by default 80.0

        Returns
        -------
        torch.Tensor
            Log-magnitude spectrogram
        """
        magnitude = self.magnitude

        amin = amin**2
        log_spec = 10.0 * torch.log10(magnitude.pow(2).clamp(min=amin))
        log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))

        if top_db is not None:
            log_spec = torch.maximum(log_spec, log_spec.max() - top_db)
        return log_spec

    @property
    def phase(self):
        """Computes and returns the phase of the STFT.
        This value can also be set to some tensor.
        When set, ``self.stft_data`` is manipulated so that its phase
        matches what this is set to, with the original magnitude.

        Returns
        -------
        torch.Tensor
            Phase of STFT.

        Examples
        --------
        >>> signal = AudioSignal(torch.randn(44100), 44100)
        >>> phase = signal.phase  # Computes stft if not computed
        >>> phase[phase < phase.mean()] = 0
        >>> signal.phase = phase
        >>> signal.istft()
        """
        if self.stft_data is None:
            self.stft()
        return torch.angle(self.stft_data)

    @phase.setter
    def phase(self, value):
        self.stft_data = self.magnitude * torch.exp(1j * value)
        return

    # Operator overloading
    def __add__(self, other):
        new_signal = self.clone()
        new_signal.audio_data += util._get_value(other)
        return new_signal

    def __iadd__(self, other):
        self.audio_data += util._get_value(other)
        return self

    def __radd__(self, other):
        return self + other

    def __sub__(self, other):
        new_signal = self.clone()
        new_signal.audio_data -= util._get_value(other)
        return new_signal

    def __isub__(self, other):
        self.audio_data -= util._get_value(other)
        return self

    def __mul__(self, other):
        new_signal = self.clone()
        new_signal.audio_data *= util._get_value(other)
        return new_signal

    def __imul__(self, other):
        self.audio_data *= util._get_value(other)
        return self

    def __rmul__(self, other):
        return self * other

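With the overloads above, mixing signals reads like tensor arithmetic — a brief sketch with an arbitrary noise gain:

    speech = AudioSignal(torch.randn(1, 1, 44100), 44100)
    noise = AudioSignal(torch.randn(1, 1, 44100), 44100)
    mixture = speech + 0.1 * noise  # __rmul__ then __add__; returns a new AudioSignal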
1552
+     # Representation
+     def _info(self):
+         dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]"
+         info = {
+             "duration": f"{dur} seconds",
+             "batch_size": self.batch_size,
+             "path": self.path_to_file if self.path_to_file else "path unknown",
+             "sample_rate": self.sample_rate,
+             "num_channels": self.num_channels if self.num_channels else "[unknown]",
+             "audio_data.shape": self.audio_data.shape,
+             "stft_params": self.stft_params,
+             "device": self.device,
+         }
+
+         return info
+
+     def markdown(self):
+         """Produces a markdown representation of AudioSignal, in a markdown table.
+
+         Returns
+         -------
+         str
+             Markdown representation of AudioSignal.
+
+         Examples
+         --------
+         >>> signal = AudioSignal(torch.randn(44100), 44100)
+         >>> print(signal.markdown())
+         | Key | Value
+         |---|---
+         | duration | 1.000 seconds |
+         | batch_size | 1 |
+         | path | path unknown |
+         | sample_rate | 44100 |
+         | num_channels | 1 |
+         | audio_data.shape | torch.Size([1, 1, 44100]) |
+         | stft_params | STFTParams(window_length=2048, hop_length=512, window_type='sqrt_hann', match_stride=False) |
+         | device | cpu |
+         """
+         info = self._info()
+
+         FORMAT = "| Key | Value \n" "|---|--- \n"
+         for k, v in info.items():
+             row = f"| {k} | {v} |\n"
+             FORMAT += row
+         return FORMAT
+
+     def __str__(self):
+         info = self._info()
+
+         desc = ""
+         for k, v in info.items():
+             desc += f"{k}: {v}\n"
+         return desc
+
+     def __rich__(self):
+         from rich.table import Table
+
+         info = self._info()
+
+         table = Table(title=f"{self.__class__.__name__}")
+         table.add_column("Key", style="green")
+         table.add_column("Value", style="cyan")
+
+         for k, v in info.items():
+             table.add_row(k, str(v))
+         return table
+
+     # Comparison
+     def __eq__(self, other):
+         for k, v in list(self.__dict__.items()):
+             if torch.is_tensor(v):
+                 if not torch.allclose(v, other.__dict__[k], atol=1e-6):
+                     max_error = (v - other.__dict__[k]).abs().max()
+                     print(f"Max abs error for {k}: {max_error}")
+                     return False
+         return True
+
+     # Indexing
+     def __getitem__(self, key):
+         if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
+             assert self.batch_size == 1
+             audio_data = self.audio_data
+             _loudness = self._loudness
+             stft_data = self.stft_data
+
+         elif isinstance(key, (bool, int, list, slice, tuple)) or (
+             torch.is_tensor(key) and key.ndim <= 1
+         ):
+             # Indexing only on the batch dimension.
+             # Then let's copy over relevant stuff.
+             # Future work: make this work for time-indexing
+             # as well, using the hop length.
+             audio_data = self.audio_data[key]
+             _loudness = self._loudness[key] if self._loudness is not None else None
+             stft_data = self.stft_data[key] if self.stft_data is not None else None
+
+         sources = None
+
+         copy = type(self)(audio_data, self.sample_rate, stft_params=self.stft_params)
+         copy._loudness = _loudness
+         copy._stft_data = stft_data
+         copy.sources = sources
+
+         return copy
+
+     def __setitem__(self, key, value):
+         if not isinstance(value, type(self)):
+             self.audio_data[key] = value
+             return
+
+         if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
+             assert self.batch_size == 1
+             self.audio_data = value.audio_data
+             self._loudness = value._loudness
+             self.stft_data = value.stft_data
+             return
+
+         elif isinstance(key, (bool, int, list, slice, tuple)) or (
+             torch.is_tensor(key) and key.ndim <= 1
+         ):
+             if self.audio_data is not None and value.audio_data is not None:
+                 self.audio_data[key] = value.audio_data
+             if self._loudness is not None and value._loudness is not None:
+                 self._loudness[key] = value._loudness
+             if self.stft_data is not None and value.stft_data is not None:
+                 self.stft_data[key] = value.stft_data
+             return
+
+     def __ne__(self, other):
+         return not self == other
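Since `__getitem__` slices only the batch dimension (copying the loudness and STFT caches when they exist), a batched signal can be split and written back. A small sketch under that assumption:

    import torch
    from audiotools import AudioSignal

    batch = AudioSignal(torch.randn(4, 1, 44100), 44100)
    first_two = batch[:2]         # new AudioSignal with batch_size == 2
    batch[2:3] = first_two[0:1]   # copy item 0 into slot 2; shapes (1, 1, 44100)
    assert first_two.batch_size == 2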
dac-vae/audiotools/core/display.py ADDED
@@ -0,0 +1,194 @@
+ import inspect
+ import typing
+ from functools import wraps
+
+ from . import util
+
+
+ def format_figure(func):
+     """Decorator for formatting figures produced by the code below.
+     See :py:func:`audiotools.core.util.format_figure` for more.
+
+     Parameters
+     ----------
+     func : Callable
+         Plotting function that is decorated by this function.
+
+     """
+
+     @wraps(func)
+     def wrapper(*args, **kwargs):
+         f_keys = inspect.signature(util.format_figure).parameters.keys()
+         f_kwargs = {}
+         for k, v in list(kwargs.items()):
+             if k in f_keys:
+                 kwargs.pop(k)
+                 f_kwargs[k] = v
+         func(*args, **kwargs)
+         util.format_figure(**f_kwargs)
+
+     return wrapper
+
+
+ class DisplayMixin:
+     @format_figure
+     def specshow(
+         self,
+         preemphasis: bool = False,
+         x_axis: str = "time",
+         y_axis: str = "linear",
+         n_mels: int = 128,
+         **kwargs,
+     ):
+         """Displays a spectrogram, using ``librosa.display.specshow``.
+
+         Parameters
+         ----------
+         preemphasis : bool, optional
+             Whether or not to apply preemphasis, which makes high
+             frequency detail easier to see, by default False
+         x_axis : str, optional
+             How to label the x axis, by default "time"
+         y_axis : str, optional
+             How to label the y axis, by default "linear"
+         n_mels : int, optional
+             If displaying a mel spectrogram with ``y_axis = "mel"``,
+             this controls the number of mels, by default 128.
+         kwargs : dict, optional
+             Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
+         """
+         import librosa
+         import librosa.display
+
+         # Always re-compute the STFT data before showing it, in case
+         # it changed.
+         signal = self.clone()
+         signal.stft_data = None
+
+         if preemphasis:
+             signal.preemphasis()
+
+         ref = signal.magnitude.max()
+         log_mag = signal.log_magnitude(ref_value=ref)
+
+         if y_axis == "mel":
+             log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10()
+             log_mag -= log_mag.max()
+
+         librosa.display.specshow(
+             log_mag.numpy()[0].mean(axis=0),
+             x_axis=x_axis,
+             y_axis=y_axis,
+             sr=signal.sample_rate,
+             **kwargs,
+         )
+
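A minimal sketch of how the decorator and `specshow` combine: keyword arguments that match `util.format_figure`'s signature are peeled off by the wrapper and routed there, while the rest go to the plotting call (`matplotlib` assumed):

    import matplotlib.pyplot as plt
    import torch
    from audiotools import AudioSignal

    signal = AudioSignal(torch.randn(1, 1, 44100), 44100)
    # y_axis / n_mels go to the plot; any kwarg accepted by
    # util.format_figure would be intercepted by the decorator instead.
    signal.specshow(y_axis="mel", n_mels=64)
    plt.savefig("mel.png")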
+     @format_figure
+     def waveplot(self, x_axis: str = "time", **kwargs):
+         """Displays a waveform plot, using ``librosa.display.waveshow``.
+
+         Parameters
+         ----------
+         x_axis : str, optional
+             How to label the x axis, by default "time"
+         kwargs : dict, optional
+             Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
+         """
+         import librosa
+         import librosa.display
+
+         audio_data = self.audio_data[0].mean(dim=0)
+         audio_data = audio_data.cpu().numpy()
+
+         plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot"
+         wave_plot_fn = getattr(librosa.display, plot_fn)
+         wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)
+
+     @format_figure
+     def wavespec(self, x_axis: str = "time", **kwargs):
+         """Displays a waveform plot stacked above a spectrogram, combining
+         :py:func:`waveplot` and :py:func:`specshow`.
+
+         Parameters
+         ----------
+         x_axis : str, optional
+             How to label the x axis, by default "time"
+         kwargs : dict, optional
+             Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`.
+         """
+         import matplotlib.pyplot as plt
+         from matplotlib.gridspec import GridSpec
+
+         gs = GridSpec(6, 1)
+         plt.subplot(gs[0, :])
+         self.waveplot(x_axis=x_axis)
+         plt.subplot(gs[1:, :])
+         self.specshow(x_axis=x_axis, **kwargs)
+
+     def write_audio_to_tb(
+         self,
+         tag: str,
+         writer,
+         step: int = None,
+         plot_fn: typing.Union[typing.Callable, str] = "specshow",
+         **kwargs,
+     ):
+         """Writes a signal and its spectrogram to Tensorboard. Will show up
+         under the Audio and Images tab in Tensorboard.
+
+         Parameters
+         ----------
+         tag : str
+             Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be
+             written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``).
+         writer : SummaryWriter
+             A SummaryWriter object from PyTorch library.
+         step : int, optional
+             The step to write the signal to, by default None
+         plot_fn : typing.Union[typing.Callable, str], optional
+             How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
+         kwargs : dict, optional
+             Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
+             whatever ``plot_fn`` is set to.
+         """
+         import matplotlib.pyplot as plt
+
+         audio_data = self.audio_data[0, 0].detach().cpu()
+         sample_rate = self.sample_rate
+         writer.add_audio(tag, audio_data, step, sample_rate)
+
+         if plot_fn is not None:
+             if isinstance(plot_fn, str):
+                 plot_fn = getattr(self, plot_fn)
+             fig = plt.figure()
+             plt.clf()
+             plot_fn(**kwargs)
+             writer.add_figure(tag.replace("wav", "png"), fig, step)
+
+     def save_image(
+         self,
+         image_path: str,
+         plot_fn: typing.Union[typing.Callable, str] = "specshow",
+         **kwargs,
+     ):
+         """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
+         a specified file.
+
+         Parameters
+         ----------
+         image_path : str
+             Where to save the file to.
+         plot_fn : typing.Union[typing.Callable, str], optional
+             How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
+         kwargs : dict, optional
+             Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
+             whatever ``plot_fn`` is set to.
+         """
+         import matplotlib.pyplot as plt
+
+         if isinstance(plot_fn, str):
+             plot_fn = getattr(self, plot_fn)
+
+         plt.clf()
+         plot_fn(**kwargs)
+         plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
+         plt.close()
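A hedged sketch of logging the first batch item to TensorBoard with these helpers (`SummaryWriter` from `torch.utils.tensorboard`; the tag and log directory are illustrative):

    import torch
    from torch.utils.tensorboard import SummaryWriter
    from audiotools import AudioSignal

    writer = SummaryWriter(log_dir="runs/demo")
    signal = AudioSignal(torch.randn(1, 1, 44100), 44100)

    # Audio lands under "Audio"; the figure under "Images" as clean/sample_0.png.
    signal.write_audio_to_tb("clean/sample_0.wav", writer, step=0)
    signal.save_image("sample_0.png", plot_fn="wavespec")
    writer.close()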
dac-vae/audiotools/core/dsp.py ADDED
@@ -0,0 +1,390 @@
+ import typing
+
+ import julius
+ import numpy as np
+ import torch
+
+ from . import util
+
+
+ class DSPMixin:
+     _original_batch_size = None
+     _original_num_channels = None
+     _padded_signal_length = None
+
+     def _preprocess_signal_for_windowing(self, window_duration, hop_duration):
+         self._original_batch_size = self.batch_size
+         self._original_num_channels = self.num_channels
+
+         window_length = int(window_duration * self.sample_rate)
+         hop_length = int(hop_duration * self.sample_rate)
+
+         if window_length % hop_length != 0:
+             factor = window_length // hop_length
+             window_length = factor * hop_length
+
+         self.zero_pad(hop_length, hop_length)
+         self._padded_signal_length = self.signal_length
+
+         return window_length, hop_length
+
+     def windows(
+         self, window_duration: float, hop_duration: float, preprocess: bool = True
+     ):
+         """Generator which yields windows of specified duration from signal with a specified
+         hop length.
+
+         Parameters
+         ----------
+         window_duration : float
+             Duration of every window in seconds.
+         hop_duration : float
+             Hop between windows in seconds.
+         preprocess : bool, optional
+             Whether to preprocess the signal, so that the first sample is in
+             the middle of the first window, by default True
+
+         Yields
+         ------
+         AudioSignal
+             Each window is returned as an AudioSignal.
+         """
+         if preprocess:
+             window_length, hop_length = self._preprocess_signal_for_windowing(
+                 window_duration, hop_duration
+             )
+
+         self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length)
+
+         for b in range(self.batch_size):
+             i = 0
+             while True:
+                 start_idx = i * hop_length
+                 i += 1
+                 end_idx = start_idx + window_length
+                 if end_idx > self.signal_length:
+                     break
+                 yield self[b, ..., start_idx:end_idx]
+
+     def collect_windows(
+         self, window_duration: float, hop_duration: float, preprocess: bool = True
+     ):
+         """Reshapes signal into windows of specified duration from signal with a specified
+         hop length. Windows are placed along the batch dimension. Use with
+         :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the
+         original signal.
+
+         Parameters
+         ----------
+         window_duration : float
+             Duration of every window in seconds.
+         hop_duration : float
+             Hop between windows in seconds.
+         preprocess : bool, optional
+             Whether to preprocess the signal, so that the first sample is in
+             the middle of the first window, by default True
+
+         Returns
+         -------
+         AudioSignal
+             AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)``
+         """
+         if preprocess:
+             window_length, hop_length = self._preprocess_signal_for_windowing(
+                 window_duration, hop_duration
+             )
+
+         # self.audio_data: (nb, nch, nt).
+         unfolded = torch.nn.functional.unfold(
+             self.audio_data.reshape(-1, 1, 1, self.signal_length),
+             kernel_size=(1, window_length),
+             stride=(1, hop_length),
+         )
+         # unfolded: (nb * nch, window_length, num_windows).
+         # -> (nb * nch * num_windows, 1, window_length)
+         unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length)
+         self.audio_data = unfolded
+         return self
+
+     def overlap_and_add(self, hop_duration: float):
+         """Function which takes a list of windows and overlap adds them into a
+         signal the same length as ``audio_signal``.
+
+         Parameters
+         ----------
+         hop_duration : float
+             How much to shift for each window
+             (overlap is window_duration - hop_duration) in seconds.
+
+         Returns
+         -------
+         AudioSignal
+             overlap-and-added signal.
+         """
+         hop_length = int(hop_duration * self.sample_rate)
+         window_length = self.signal_length
+
+         nb, nch = self._original_batch_size, self._original_num_channels
+
+         unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1)
+         folded = torch.nn.functional.fold(
+             unfolded,
+             output_size=(1, self._padded_signal_length),
+             kernel_size=(1, window_length),
+             stride=(1, hop_length),
+         )
+
+         norm = torch.ones_like(unfolded, device=unfolded.device)
+         norm = torch.nn.functional.fold(
+             norm,
+             output_size=(1, self._padded_signal_length),
+             kernel_size=(1, window_length),
+             stride=(1, hop_length),
+         )
+
+         folded = folded / norm
+
+         folded = folded.reshape(nb, nch, -1)
+         self.audio_data = folded
+         self.trim(hop_length, hop_length)
+         return self
+
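A minimal round-trip sketch of the two methods above: windows are unfolded onto the batch axis, processed, then overlap-added back to the original length (durations here are illustrative):

    import torch
    from audiotools import AudioSignal

    signal = AudioSignal(torch.randn(1, 1, 5 * 44100), 44100)
    windowed = signal.clone().collect_windows(window_duration=1.0, hop_duration=0.5)
    # windowed.batch_size is now nb * nch * num_windows; process each window here.
    reconstructed = windowed.overlap_and_add(hop_duration=0.5)
    assert reconstructed.signal_length == signal.signal_length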
+     def low_pass(
+         self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
+     ):
+         """Low-passes the signal in-place. Each item in the batch
+         can have a different low-pass cutoff, if the input
+         to this method is an array or tensor. If a float, all
+         items are given the same low-pass filter.
+
+         Parameters
+         ----------
+         cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
+             Cutoff in Hz of low-pass filter.
+         zeros : int, optional
+             Number of taps to use in low-pass filter, by default 51
+
+         Returns
+         -------
+         AudioSignal
+             Low-passed AudioSignal.
+         """
+         cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
+         cutoffs = cutoffs / self.sample_rate
+         filtered = torch.empty_like(self.audio_data)
+
+         for i, cutoff in enumerate(cutoffs):
+             lp_filter = julius.LowPassFilter(cutoff.cpu(), zeros=zeros).to(self.device)
+             filtered[i] = lp_filter(self.audio_data[i])
+
+         self.audio_data = filtered
+         self.stft_data = None
+         return self
+
+     def high_pass(
+         self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
+     ):
+         """High-passes the signal in-place. Each item in the batch
+         can have a different high-pass cutoff, if the input
+         to this method is an array or tensor. If a float, all
+         items are given the same high-pass filter.
+
+         Parameters
+         ----------
+         cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
+             Cutoff in Hz of high-pass filter.
+         zeros : int, optional
+             Number of taps to use in high-pass filter, by default 51
+
+         Returns
+         -------
+         AudioSignal
+             High-passed AudioSignal.
+         """
+         cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
+         cutoffs = cutoffs / self.sample_rate
+         filtered = torch.empty_like(self.audio_data)
+
+         for i, cutoff in enumerate(cutoffs):
+             hp_filter = julius.HighPassFilter(cutoff.cpu(), zeros=zeros).to(self.device)
+             filtered[i] = hp_filter(self.audio_data[i])
+
+         self.audio_data = filtered
+         self.stft_data = None
+         return self
+
+     def mask_frequencies(
+         self,
+         fmin_hz: typing.Union[torch.Tensor, np.ndarray, float],
+         fmax_hz: typing.Union[torch.Tensor, np.ndarray, float],
+         val: float = 0.0,
+     ):
+         """Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them
+         with the value specified by ``val``. Useful for implementing SpecAug.
+         The min and max can be different for every item in the batch.
+
+         Parameters
+         ----------
+         fmin_hz : typing.Union[torch.Tensor, np.ndarray, float]
+             Lower end of band to mask out.
+         fmax_hz : typing.Union[torch.Tensor, np.ndarray, float]
+             Upper end of band to mask out.
+         val : float, optional
+             Value to fill in, by default 0.0
+
+         Returns
+         -------
+         AudioSignal
+             Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+             masked audio data.
+         """
+         # SpecAug
+         mag, phase = self.magnitude, self.phase
+         fmin_hz = util.ensure_tensor(fmin_hz, ndim=mag.ndim)
+         fmax_hz = util.ensure_tensor(fmax_hz, ndim=mag.ndim)
+         assert torch.all(fmin_hz < fmax_hz)
+
+         # build mask
+         nbins = mag.shape[-2]
+         bins_hz = torch.linspace(0, self.sample_rate / 2, nbins, device=self.device)
+         bins_hz = bins_hz[None, None, :, None].repeat(
+             self.batch_size, 1, 1, mag.shape[-1]
+         )
+         mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz)
+         mask = mask.to(self.device)
+
+         mag = mag.masked_fill(mask, val)
+         phase = phase.masked_fill(mask, val)
+         self.stft_data = mag * torch.exp(1j * phase)
+         return self
+
+     def mask_timesteps(
+         self,
+         tmin_s: typing.Union[torch.Tensor, np.ndarray, float],
+         tmax_s: typing.Union[torch.Tensor, np.ndarray, float],
+         val: float = 0.0,
+     ):
+         """Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them
+         with the value specified by ``val``. Useful for implementing SpecAug.
+         The min and max can be different for every item in the batch.
+
+         Parameters
+         ----------
+         tmin_s : typing.Union[torch.Tensor, np.ndarray, float]
+             Lower end of timesteps to mask out.
+         tmax_s : typing.Union[torch.Tensor, np.ndarray, float]
+             Upper end of timesteps to mask out.
+         val : float, optional
+             Value to fill in, by default 0.0
+
+         Returns
+         -------
+         AudioSignal
+             Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+             masked audio data.
+         """
+         # SpecAug
+         mag, phase = self.magnitude, self.phase
+         tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim)
+         tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim)
+
+         assert torch.all(tmin_s < tmax_s)
+
+         # build mask
+         nt = mag.shape[-1]
+         bins_t = torch.linspace(0, self.signal_duration, nt, device=self.device)
+         bins_t = bins_t[None, None, None, :].repeat(
+             self.batch_size, 1, mag.shape[-2], 1
+         )
+         mask = (tmin_s <= bins_t) & (bins_t < tmax_s)
+
+         mag = mag.masked_fill(mask, val)
+         phase = phase.masked_fill(mask, val)
+         self.stft_data = mag * torch.exp(1j * phase)
+         return self
+
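The two masks together give a SpecAug-style augmentation on the cached STFT; a small sketch (the band and time ranges are arbitrary):

    import torch
    from audiotools import AudioSignal

    signal = AudioSignal(torch.randn(1, 1, 44100), 44100)
    signal.mask_frequencies(fmin_hz=400, fmax_hz=800)  # drop a frequency band
    signal.mask_timesteps(tmin_s=0.2, tmax_s=0.4)      # drop a time span
    signal.istft()                                     # back to the waveform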
+     def mask_low_magnitudes(
+         self, db_cutoff: typing.Union[torch.Tensor, np.ndarray, float], val: float = 0.0
+     ):
+         """Mask away magnitudes below a specified threshold, which
+         can be different for every item in the batch.
+
+         Parameters
+         ----------
+         db_cutoff : typing.Union[torch.Tensor, np.ndarray, float]
+             Decibel value below which magnitudes will be masked away.
+         val : float, optional
+             Value to fill in for masked portions, by default 0.0
+
+         Returns
+         -------
+         AudioSignal
+             Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+             masked audio data.
+         """
+         mag = self.magnitude
+         log_mag = self.log_magnitude()
+
+         db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim)
+         mask = log_mag < db_cutoff
+         mag = mag.masked_fill(mask, val)
+
+         self.magnitude = mag
+         return self
+
+     def shift_phase(self, shift: typing.Union[torch.Tensor, np.ndarray, float]):
+         """Shifts the phase by a constant value.
+
+         Parameters
+         ----------
+         shift : typing.Union[torch.Tensor, np.ndarray, float]
+             What to shift the phase by.
+
+         Returns
+         -------
+         AudioSignal
+             Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+             shifted audio data.
+         """
+         shift = util.ensure_tensor(shift, ndim=self.phase.ndim)
+         self.phase = self.phase + shift
+         return self
+
+     def corrupt_phase(self, scale: typing.Union[torch.Tensor, np.ndarray, float]):
+         """Corrupts the phase randomly by some scaled value.
+
+         Parameters
+         ----------
+         scale : typing.Union[torch.Tensor, np.ndarray, float]
+             Standard deviation of noise to add to the phase.
+
+         Returns
+         -------
+         AudioSignal
+             Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+             corrupted audio data.
+         """
+         scale = util.ensure_tensor(scale, ndim=self.phase.ndim)
+         self.phase = self.phase + scale * torch.randn_like(self.phase)
+         return self
+
+     def preemphasis(self, coef: float = 0.85):
+         """Applies pre-emphasis to audio signal.
+
+         Parameters
+         ----------
+         coef : float, optional
+             How much pre-emphasis to apply; lower values do less, and
+             0 does nothing. By default 0.85.
+
+         Returns
+         -------
+         AudioSignal
+             Pre-emphasized signal.
+         """
+         kernel = torch.tensor([1, -coef, 0]).view(1, 1, -1).to(self.device)
+         x = self.audio_data.reshape(-1, 1, self.signal_length)
+         x = torch.nn.functional.conv1d(x, kernel, padding=1)
+         self.audio_data = x.reshape(*self.audio_data.shape)
+         return self
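Because `ensure_tensor` broadcasts cutoffs across the batch, each item can get its own filter; a sketch of batched filtering plus pre-emphasis (values illustrative):

    import torch
    from audiotools import AudioSignal

    batch = AudioSignal(torch.randn(2, 1, 44100), 44100)
    batch.low_pass(torch.tensor([4000.0, 8000.0]))  # per-item cutoffs in Hz
    batch.preemphasis(coef=0.97)                    # boost high-frequency detail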
dac-vae/audiotools/core/effects.py ADDED
@@ -0,0 +1,647 @@
+ import typing
+
+ import julius
+ import numpy as np
+ import torch
+ import torchaudio
+
+ from . import util
+
+
+ class EffectMixin:
+     GAIN_FACTOR = np.log(10) / 20
+     """Gain factor for converting between amplitude and decibels."""
+     CODEC_PRESETS = {
+         "8-bit": {"format": "wav", "encoding": "ULAW", "bits_per_sample": 8},
+         "GSM-FR": {"format": "gsm"},
+         "MP3": {"format": "mp3", "compression": -9},
+         "Vorbis": {"format": "vorbis", "compression": -1},
+         "Ogg": {
+             "format": "ogg",
+             "compression": -1,
+         },
+         "Amr-nb": {"format": "amr-nb"},
+     }
+     """Presets for applying codecs via torchaudio."""
+
+     def mix(
+         self,
+         other,
+         snr: typing.Union[torch.Tensor, np.ndarray, float] = 10,
+         other_eq: typing.Union[torch.Tensor, np.ndarray] = None,
+     ):
+         """Mixes noise with signal at specified
+         signal-to-noise ratio. Optionally, the
+         other signal can be equalized in-place.
+
+         Parameters
+         ----------
+         other : AudioSignal
+             AudioSignal object to mix with.
+         snr : typing.Union[torch.Tensor, np.ndarray, float], optional
+             Signal to noise ratio, by default 10
+         other_eq : typing.Union[torch.Tensor, np.ndarray], optional
+             EQ curve to apply to other signal, if any, by default None
+
+         Returns
+         -------
+         AudioSignal
+             In-place modification of AudioSignal.
+         """
+         snr = util.ensure_tensor(snr).to(self.device)
+
+         pad_len = max(0, self.signal_length - other.signal_length)
+         other.zero_pad(0, pad_len)
+         other.truncate_samples(self.signal_length)
+         if other_eq is not None:
+             other = other.equalizer(other_eq)
+
+         tgt_loudness = self.loudness() - snr
+         other = other.normalize(tgt_loudness)
+
+         self.audio_data = self.audio_data + other.audio_data
+         return self
+
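A sketch of SNR mixing: the noise is loudness-normalized to sit `snr` dB below the signal before the two are summed (note that `mix` also pads and normalizes `other` in place):

    import torch
    from audiotools import AudioSignal

    speech = AudioSignal(torch.randn(1, 1, 44100), 44100)
    noise = AudioSignal(torch.randn(1, 1, 44100), 44100)

    noisy = speech.clone().mix(noise, snr=20)  # noise ends up 20 dB below speech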
+     def convolve(self, other, start_at_max: bool = True):
+         """Convolves self with other.
+         This function uses FFTs to do the convolution.
+
+         Parameters
+         ----------
+         other : AudioSignal
+             Signal to convolve with.
+         start_at_max : bool, optional
+             Whether to start at the max value of other signal, to
+             avoid inducing delays, by default True
+
+         Returns
+         -------
+         AudioSignal
+             Convolved signal, in-place.
+         """
+         from . import AudioSignal
+
+         pad_len = self.signal_length - other.signal_length
+
+         if pad_len > 0:
+             other.zero_pad(0, pad_len)
+         else:
+             other.truncate_samples(self.signal_length)
+
+         if start_at_max:
+             # Use roll to rotate over the max for every item
+             # so that the impulse responses don't induce any
+             # delay.
+             idx = other.audio_data.abs().argmax(axis=-1)
+             irs = torch.zeros_like(other.audio_data)
+             for i in range(other.batch_size):
+                 irs[i] = torch.roll(other.audio_data[i], -idx[i].item(), -1)
+             other = AudioSignal(irs, other.sample_rate)
+
+         delta = torch.zeros_like(other.audio_data)
+         delta[..., 0] = 1
+
+         length = self.signal_length
+         delta_fft = torch.fft.rfft(delta, length)
+         other_fft = torch.fft.rfft(other.audio_data, length)
+         self_fft = torch.fft.rfft(self.audio_data, length)
+
+         convolved_fft = other_fft * self_fft
+         convolved_audio = torch.fft.irfft(convolved_fft, length)
+
+         delta_convolved_fft = other_fft * delta_fft
+         delta_audio = torch.fft.irfft(delta_convolved_fft, length)
+
+         # Use the delta to rescale the audio exactly as needed.
+         delta_max = delta_audio.abs().max(dim=-1, keepdims=True)[0]
+         scale = 1 / delta_max.clamp(1e-5)
+         convolved_audio = convolved_audio * scale
+
+         self.audio_data = convolved_audio
+
+         return self
+
+     def apply_ir(
+         self,
+         ir,
+         drr: typing.Union[torch.Tensor, np.ndarray, float] = None,
+         ir_eq: typing.Union[torch.Tensor, np.ndarray] = None,
+         use_original_phase: bool = False,
+     ):
+         """Applies an impulse response to the signal. If ``ir_eq``
+         is specified, the impulse response is equalized before
+         it is applied, using the given curve.
+
+         Parameters
+         ----------
+         ir : AudioSignal
+             Impulse response to convolve with.
+         drr : typing.Union[torch.Tensor, np.ndarray, float], optional
+             Direct-to-reverberant ratio that impulse response will be
+             altered to, if specified, by default None
+         ir_eq : typing.Union[torch.Tensor, np.ndarray], optional
+             Equalization that will be applied to impulse response
+             if specified, by default None
+         use_original_phase : bool, optional
+             Whether to use the original phase, instead of the convolved
+             phase, by default False
+
+         Returns
+         -------
+         AudioSignal
+             Signal with impulse response applied to it
+         """
+         if ir_eq is not None:
+             ir = ir.equalizer(ir_eq)
+         if drr is not None:
+             ir = ir.alter_drr(drr)
+
+         # Save the peak before convolution.
+         max_spk = self.audio_data.abs().max(dim=-1, keepdims=True).values
+
+         # Augment the impulse response to simulate microphone effects
+         # and with varying direct-to-reverberant ratio.
+         phase = self.phase
+         self.convolve(ir)
+
+         # Use the input phase
+         if use_original_phase:
+             self.stft()
+             self.stft_data = self.magnitude * torch.exp(1j * phase)
+             self.istft()
+
+         # Rescale to the input's amplitude
+         max_transformed = self.audio_data.abs().max(dim=-1, keepdims=True).values
+         scale_factor = max_spk.clamp(1e-8) / max_transformed.clamp(1e-8)
+         self = self * scale_factor
+
+         return self
+
+     def ensure_max_of_audio(self, max: float = 1.0):
+         """Ensures that ``abs(audio_data) <= max``.
+
+         Parameters
+         ----------
+         max : float, optional
+             Max absolute value of signal, by default 1.0
+
+         Returns
+         -------
+         AudioSignal
+             Signal with values scaled between -max and max.
+         """
+         peak = self.audio_data.abs().max(dim=-1, keepdims=True)[0]
+         peak_gain = torch.ones_like(peak)
+         peak_gain[peak > max] = max / peak[peak > max]
+         self.audio_data = self.audio_data * peak_gain
+         return self
+
+     def normalize(self, db: typing.Union[torch.Tensor, np.ndarray, float] = -24.0):
+         """Normalizes the signal's volume to the specified db, in LUFS.
+         This is GPU-compatible, making for very fast loudness normalization.
+
+         Parameters
+         ----------
+         db : typing.Union[torch.Tensor, np.ndarray, float], optional
+             Loudness to normalize to, by default -24.0
+
+         Returns
+         -------
+         AudioSignal
+             Normalized audio signal.
+         """
+         db = util.ensure_tensor(db).to(self.device)
+         ref_db = self.loudness()
+         gain = db - ref_db
+         gain = torch.exp(gain * self.GAIN_FACTOR)
+
+         self.audio_data = self.audio_data * gain[:, None, None]
+         return self
+
+     def volume_change(self, db: typing.Union[torch.Tensor, np.ndarray, float]):
+         """Change volume of signal by some amount, in dB.
+
+         Parameters
+         ----------
+         db : typing.Union[torch.Tensor, np.ndarray, float]
+             Amount to change volume by.
+
+         Returns
+         -------
+         AudioSignal
+             Signal at new volume.
+         """
+         db = util.ensure_tensor(db, ndim=1).to(self.device)
+         gain = torch.exp(db * self.GAIN_FACTOR)
+         self.audio_data = self.audio_data * gain[:, None, None]
+         return self
+
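Because `normalize` works from the batched LUFS measurement, a whole batch can be brought to a target loudness in one call; a minimal sketch:

    import torch
    from audiotools import AudioSignal

    batch = AudioSignal(torch.randn(4, 1, 44100), 44100)
    batch.normalize(-24)         # target -24 LUFS for every item
    batch.volume_change(-3)      # then drop everything by 3 dB
    batch.ensure_max_of_audio()  # guard against clipping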
+     def _to_2d(self):
+         waveform = self.audio_data.reshape(-1, self.signal_length)
+         return waveform
+
+     def _to_3d(self, waveform):
+         return waveform.reshape(self.batch_size, self.num_channels, -1)
+
+     def pitch_shift(self, n_semitones: int, quick: bool = True):
+         """Pitch shift the signal. All items in the batch
+         get the same pitch shift.
+
+         Parameters
+         ----------
+         n_semitones : int
+             How many semitones to shift the signal by.
+         quick : bool, optional
+             Using quick pitch shifting, by default True
+
+         Returns
+         -------
+         AudioSignal
+             Pitch shifted audio signal.
+         """
+         device = self.device
+         effects = [
+             ["pitch", str(n_semitones * 100)],
+             ["rate", str(self.sample_rate)],
+         ]
+         if quick:
+             effects[0].insert(1, "-q")
+
+         waveform = self._to_2d().cpu()
+         waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
+             waveform, self.sample_rate, effects, channels_first=True
+         )
+         self.sample_rate = sample_rate
+         self.audio_data = self._to_3d(waveform)
+         return self.to(device)
+
+     def time_stretch(self, factor: float, quick: bool = True):
+         """Time stretch the audio signal.
+
+         Parameters
+         ----------
+         factor : float
+             Factor by which to stretch the AudioSignal. Typically
+             between 0.8 and 1.2.
+         quick : bool, optional
+             Whether to use quick time stretching, by default True
+
+         Returns
+         -------
+         AudioSignal
+             Time-stretched AudioSignal.
+         """
+         device = self.device
+         effects = [
+             ["tempo", str(factor)],
+             ["rate", str(self.sample_rate)],
+         ]
+         if quick:
+             effects[0].insert(1, "-q")
+
+         waveform = self._to_2d().cpu()
+         waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
+             waveform, self.sample_rate, effects, channels_first=True
+         )
+         self.sample_rate = sample_rate
+         self.audio_data = self._to_3d(waveform)
+         return self.to(device)
+
+     def apply_codec(
+         self,
+         preset: str = None,
+         format: str = "wav",
+         encoding: str = None,
+         bits_per_sample: int = None,
+         compression: int = None,
+     ):  # pragma: no cover
+         """Applies an audio codec to the signal.
+
+         Parameters
+         ----------
+         preset : str, optional
+             One of the keys in ``self.CODEC_PRESETS``, by default None
+         format : str, optional
+             Format for audio codec, by default "wav"
+         encoding : str, optional
+             Encoding to use, by default None
+         bits_per_sample : int, optional
+             How many bits per sample, by default None
+         compression : int, optional
+             Compression amount of codec, by default None
+
+         Returns
+         -------
+         AudioSignal
+             AudioSignal with codec applied.
+
+         Raises
+         ------
+         ValueError
+             If preset is not in ``self.CODEC_PRESETS``, an error
+             is thrown.
+         """
+         torchaudio_version_070 = "0.7" in torchaudio.__version__
+         if torchaudio_version_070:
+             return self
+
+         kwargs = {
+             "format": format,
+             "encoding": encoding,
+             "bits_per_sample": bits_per_sample,
+             "compression": compression,
+         }
+
+         if preset is not None:
+             if preset in self.CODEC_PRESETS:
+                 kwargs = self.CODEC_PRESETS[preset]
+             else:
+                 raise ValueError(
+                     f"Unknown preset: {preset}. "
+                     f"Known presets: {list(self.CODEC_PRESETS.keys())}"
+                 )
+
+         waveform = self._to_2d()
+         if kwargs["format"] in ["vorbis", "mp3", "ogg", "amr-nb"]:
+             # Apply it in a for loop
+             augmented = torch.cat(
+                 [
+                     torchaudio.functional.apply_codec(
+                         waveform[i][None, :], self.sample_rate, **kwargs
+                     )
+                     for i in range(waveform.shape[0])
+                 ],
+                 dim=0,
+             )
+         else:
+             augmented = torchaudio.functional.apply_codec(
+                 waveform, self.sample_rate, **kwargs
+             )
+         augmented = self._to_3d(augmented)
+
+         self.audio_data = augmented
+         return self
+
+     def mel_filterbank(self, n_bands: int):
+         """Breaks signal into mel bands.
+
+         Parameters
+         ----------
+         n_bands : int
+             Number of mel bands to use.
+
+         Returns
+         -------
+         torch.Tensor
+             Mel-filtered bands, with last axis being the band index.
+         """
+         filterbank = (
+             julius.SplitBands(self.sample_rate, n_bands).float().to(self.device)
+         )
+         filtered = filterbank(self.audio_data)
+         return filtered.permute(1, 2, 3, 0)
+
+     def equalizer(self, db: typing.Union[torch.Tensor, np.ndarray]):
+         """Applies a mel-spaced equalizer to the audio signal.
+
+         Parameters
+         ----------
+         db : typing.Union[torch.Tensor, np.ndarray]
+             EQ curve to apply.
+
+         Returns
+         -------
+         AudioSignal
+             AudioSignal with equalization applied.
+         """
+         db = util.ensure_tensor(db)
+         n_bands = db.shape[-1]
+         fbank = self.mel_filterbank(n_bands)
+
+         # If there's a batch dimension, make sure it's the same.
+         if db.ndim == 2:
+             if db.shape[0] != 1:
+                 assert db.shape[0] == fbank.shape[0]
+         else:
+             db = db.unsqueeze(0)
+
+         weights = (10**db).to(self.device).float()
+         fbank = fbank * weights[:, None, None, :]
+         eq_audio_data = fbank.sum(-1)
+         self.audio_data = eq_audio_data
+         return self
+
+     def clip_distortion(
+         self, clip_percentile: typing.Union[torch.Tensor, np.ndarray, float]
+     ):
+         """Clips the signal at a given percentile. The higher it is,
+         the lower the threshold for clipping.
+
+         Parameters
+         ----------
+         clip_percentile : typing.Union[torch.Tensor, np.ndarray, float]
+             Values are between 0.0 and 1.0. Typical values are 0.1 or below.
+
+         Returns
+         -------
+         AudioSignal
+             Audio signal with clipped audio data.
+         """
+         clip_percentile = util.ensure_tensor(clip_percentile, ndim=1)
+         min_thresh = torch.quantile(self.audio_data, clip_percentile / 2, dim=-1)
+         max_thresh = torch.quantile(self.audio_data, 1 - (clip_percentile / 2), dim=-1)
+
+         nc = self.audio_data.shape[1]
+         min_thresh = min_thresh[:, :nc, :]
+         max_thresh = max_thresh[:, :nc, :]
+
+         self.audio_data = self.audio_data.clamp(min_thresh, max_thresh)
+
+         return self
+
+     def quantization(
+         self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int]
+     ):
+         """Applies quantization to the input waveform.
+
+         Parameters
+         ----------
+         quantization_channels : typing.Union[torch.Tensor, np.ndarray, int]
+             Number of evenly spaced quantization channels to quantize
+             to.
+
+         Returns
+         -------
+         AudioSignal
+             Quantized AudioSignal.
+         """
+         quantization_channels = util.ensure_tensor(quantization_channels, ndim=3)
+
+         x = self.audio_data
+         x = (x + 1) / 2
+         x = x * quantization_channels
+         x = x.floor()
+         x = x / quantization_channels
+         x = 2 * x - 1
+
+         residual = (self.audio_data - x).detach()
+         self.audio_data = self.audio_data - residual
+         return self
+
+     def mulaw_quantization(
+         self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int]
+     ):
+         """Applies mu-law quantization to the input waveform.
+
+         Parameters
+         ----------
+         quantization_channels : typing.Union[torch.Tensor, np.ndarray, int]
+             Number of mu-law spaced quantization channels to quantize
+             to.
+
+         Returns
+         -------
+         AudioSignal
+             Quantized AudioSignal.
+         """
+         mu = quantization_channels - 1.0
+         mu = util.ensure_tensor(mu, ndim=3)
+
+         x = self.audio_data
+
+         # quantize
+         x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
+         x = ((x + 1) / 2 * mu + 0.5).to(torch.int64)
+
+         # unquantize
+         x = (x / mu) * 2 - 1.0
+         x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
+
+         residual = (self.audio_data - x).detach()
+         self.audio_data = self.audio_data - residual
+         return self
+
+     def __matmul__(self, other):
+         return self.convolve(other)
+
+
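`__matmul__` aliases `convolve`, so applying a room impulse response reads like an operator; a sketch with a synthetic IR (note `convolve` pads or truncates the IR in place):

    import torch
    from audiotools import AudioSignal

    dry = AudioSignal(torch.randn(1, 1, 44100), 44100)
    ir = AudioSignal(torch.randn(1, 1, 4410).abs() * 0.1, 44100)

    wet = dry.clone() @ ir  # same as dry.clone().convolve(ir)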
+ class ImpulseResponseMixin:
+     """These functions are generally only used with AudioSignals that are derived
+     from impulse responses, not other sources like music or speech. These methods
+     are used to replicate the data augmentation described in [1].
+
+     1. Bryan, Nicholas J. "Impulse response data augmentation and deep
+        neural networks for blind room acoustic parameter estimation."
+        ICASSP 2020-2020 IEEE International Conference on Acoustics,
+        Speech and Signal Processing (ICASSP). IEEE, 2020.
+     """
+
+     def decompose_ir(self):
+         """Decomposes an impulse response into early and late
+         field responses.
+         """
+         # Equations 1 and 2
+         # -----------------
+         # Breaking up into early
+         # response + late field response.
+
+         td = torch.argmax(self.audio_data, dim=-1, keepdim=True)
+         t0 = int(self.sample_rate * 0.0025)
+
+         idx = torch.arange(self.audio_data.shape[-1], device=self.device)[None, None, :]
+         idx = idx.expand(self.batch_size, -1, -1)
+         early_idx = (idx >= td - t0) * (idx <= td + t0)
+
+         early_response = torch.zeros_like(self.audio_data, device=self.device)
+         early_response[early_idx] = self.audio_data[early_idx]
+
+         late_idx = ~early_idx
+         late_field = torch.zeros_like(self.audio_data, device=self.device)
+         late_field[late_idx] = self.audio_data[late_idx]
+
+         # Equation 4
+         # ----------
+         # Decompose early response into windowed
+         # direct path and windowed residual.
+
+         window = torch.zeros_like(self.audio_data, device=self.device)
+         for idx in range(self.batch_size):
+             window_idx = early_idx[idx, 0].nonzero()
+             window[idx, ..., window_idx] = self.get_window(
+                 "hann", window_idx.shape[-1], self.device
+             )
+         return early_response, late_field, window
+
+     def measure_drr(self):
+         """Measures the direct-to-reverberant ratio of the impulse
+         response.
+
+         Returns
+         -------
+         float
+             Direct-to-reverberant ratio
+         """
+         early_response, late_field, _ = self.decompose_ir()
+         num = (early_response**2).sum(dim=-1)
+         den = (late_field**2).sum(dim=-1)
+         drr = 10 * torch.log10(num / den)
+         return drr
+
+     @staticmethod
+     def solve_alpha(early_response, late_field, wd, target_drr):
+         """Used to solve for the alpha value, which is used
+         to alter the drr.
+         """
+         # Equation 5
+         # ----------
+         # Apply the good ol' quadratic formula.
+
+         wd_sq = wd**2
+         wd_sq_1 = (1 - wd) ** 2
+         e_sq = early_response**2
+         l_sq = late_field**2
+         a = (wd_sq * e_sq).sum(dim=-1)
+         b = (2 * (1 - wd) * wd * e_sq).sum(dim=-1)
+         c = (wd_sq_1 * e_sq).sum(dim=-1) - torch.pow(10, target_drr / 10) * l_sq.sum(
+             dim=-1
+         )
+
+         expr = ((b**2) - 4 * a * c).sqrt()
+         alpha = torch.maximum(
+             (-b - expr) / (2 * a),
+             (-b + expr) / (2 * a),
+         )
+         return alpha
+
+     def alter_drr(self, drr: typing.Union[torch.Tensor, np.ndarray, float]):
+         """Alters the direct-to-reverberant ratio of the impulse response.
+
+         Parameters
+         ----------
+         drr : typing.Union[torch.Tensor, np.ndarray, float]
+             Direct-to-reverberant ratio that the impulse response will be
+             altered to.
+
+         Returns
+         -------
+         AudioSignal
+             Altered impulse response.
+         """
+         drr = util.ensure_tensor(drr, 2, self.batch_size).to(self.device)
+
+         early_response, late_field, window = self.decompose_ir()
+         alpha = self.solve_alpha(early_response, late_field, window, drr)
+         min_alpha = (
+             late_field.abs().max(dim=-1)[0] / early_response.abs().max(dim=-1)[0]
+         )
+         alpha = torch.maximum(alpha, min_alpha)[..., None]
+
+         aug_ir_data = (
+             alpha * window * early_response
+             + ((1 - window) * early_response)
+             + late_field
+         )
+         self.audio_data = aug_ir_data
+         self.ensure_max_of_audio()
+         return self
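A sketch of the IR augmentation loop from Bryan (2020) as wired up here: measure the current DRR, pick a new target, and let `apply_ir` handle re-targeting and convolution (the IR tensor is synthetic and illustrative):

    import torch
    from audiotools import AudioSignal

    ir = AudioSignal(torch.randn(1, 1, 22050).abs() * 0.05, 44100)
    print(ir.measure_drr())      # current direct-to-reverberant ratio, in dB

    speech = AudioSignal(torch.randn(1, 1, 44100), 44100)
    speech.apply_ir(ir, drr=15)  # re-target the IR to 15 dB DRR, then convolve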
dac-vae/audiotools/core/ffmpeg.py ADDED
@@ -0,0 +1,211 @@
+ import json
+ import shlex
+ import subprocess
+ import tempfile
+ from pathlib import Path
+ from typing import Tuple
+
+ import ffmpy
+ import numpy as np
+ import torch
+
+
+ def r128stats(filepath: str, quiet: bool):
+     """Takes a path to an audio file, returns a dict with the loudness
+     stats computed by the ffmpeg ebur128 filter.
+
+     Parameters
+     ----------
+     filepath : str
+         Path to compute loudness stats on.
+     quiet : bool
+         Whether to show FFMPEG output during computation.
+
+     Returns
+     -------
+     dict
+         Dictionary containing loudness stats.
+     """
+     ffargs = [
+         "ffmpeg",
+         "-nostats",
+         "-i",
+         filepath,
+         "-filter_complex",
+         "ebur128",
+         "-f",
+         "null",
+         "-",
+     ]
+     if quiet:
+         ffargs += ["-hide_banner"]
+     proc = subprocess.Popen(ffargs, stderr=subprocess.PIPE, universal_newlines=True)
+     stats = proc.communicate()[1]
+     summary_index = stats.rfind("Summary:")
+
+     summary_list = stats[summary_index:].split()
+     i_lufs = float(summary_list[summary_list.index("I:") + 1])
+     i_thresh = float(summary_list[summary_list.index("I:") + 4])
+     lra = float(summary_list[summary_list.index("LRA:") + 1])
+     lra_thresh = float(summary_list[summary_list.index("LRA:") + 4])
+     lra_low = float(summary_list[summary_list.index("low:") + 1])
+     lra_high = float(summary_list[summary_list.index("high:") + 1])
+     stats_dict = {
+         "I": i_lufs,
+         "I Threshold": i_thresh,
+         "LRA": lra,
+         "LRA Threshold": lra_thresh,
+         "LRA Low": lra_low,
+         "LRA High": lra_high,
+     }
+
+     return stats_dict
+
+
+ def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]:
+     """Given a path to a file, returns the start time offset and codec of
+     the first audio stream.
+     """
+     ff = ffmpy.FFprobe(
+         inputs={path: None},
+         global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet",
+     )
+     streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"]
+     seconds_offset = 0.0
+     codec = None
+
+     # Get the offset and codec of the first audio stream we find
+     # and return its start time, if it has one.
+     for stream in streams:
+         if stream["codec_type"] == "audio":
+             seconds_offset = stream.get("start_time", 0.0)
+             codec = stream.get("codec_name")
+             break
+     return float(seconds_offset), codec
+
+
+ class FFMPEGMixin:
+     _loudness = None
+
+     def ffmpeg_loudness(self, quiet: bool = True):
+         """Computes loudness of audio file using FFMPEG.
+
+         Parameters
+         ----------
+         quiet : bool, optional
+             Whether to show FFMPEG output during computation,
+             by default True
+
+         Returns
+         -------
+         torch.Tensor
+             Loudness of every item in the batch, computed via
+             FFMPEG.
+         """
+         loudness = []
+
+         with tempfile.NamedTemporaryFile(suffix=".wav") as f:
+             for i in range(self.batch_size):
+                 self[i].write(f.name)
+                 loudness_stats = r128stats(f.name, quiet=quiet)
+                 loudness.append(loudness_stats["I"])
+
+         self._loudness = torch.from_numpy(np.array(loudness)).float()
+         return self.loudness()
+
+     def ffmpeg_resample(self, sample_rate: int, quiet: bool = True):
+         """Resamples AudioSignal using FFMPEG. More memory-efficient
+         than using julius.resample for long audio files.
+
+         Parameters
+         ----------
+         sample_rate : int
+             Sample rate to resample to.
+         quiet : bool, optional
+             Whether to show FFMPEG output during computation,
+             by default True
+
+         Returns
+         -------
+         AudioSignal
+             Resampled AudioSignal.
+         """
+         from audiotools import AudioSignal
+
+         if sample_rate == self.sample_rate:
+             return self
+
+         with tempfile.NamedTemporaryFile(suffix=".wav") as f:
+             self.write(f.name)
+             f_out = f.name.replace("wav", "rs.wav")
+             command = f"ffmpeg -i {f.name} -ar {sample_rate} {f_out}"
+             if quiet:
+                 command += " -hide_banner -loglevel error"
+             subprocess.check_call(shlex.split(command))
+             resampled = AudioSignal(f_out)
+             Path.unlink(Path(f_out))
+         return resampled
+
+     @classmethod
+     def load_from_file_with_ffmpeg(cls, audio_path: str, quiet: bool = True, **kwargs):
+         """Loads AudioSignal object after decoding it to a wav file using FFMPEG.
+         Useful for loading audio that isn't covered by librosa's loading mechanism. Also
+         useful for loading mp3 files, without any offset.
+
+         Parameters
+         ----------
+         audio_path : str
+             Path to load AudioSignal from.
+         quiet : bool, optional
+             Whether to show FFMPEG output during computation,
+             by default True
+
+         Returns
+         -------
+         AudioSignal
+             AudioSignal loaded from file with FFMPEG.
+         """
+         audio_path = str(audio_path)
+         with tempfile.TemporaryDirectory() as d:
+             wav_file = str(Path(d) / "extracted.wav")
+             padded_wav = str(Path(d) / "padded.wav")
+
+             global_options = "-y"
+             if quiet:
+                 global_options += " -loglevel error"
+
+             ff = ffmpy.FFmpeg(
+                 inputs={audio_path: None},
+                 # For inputs that are m4a (and others?), the input audio can
+                 # have samples that don't match the sample rate. This aresample
+                 # option forces ffmpeg to read timing information in the source
+                 # file instead of assuming constant sample rate.
+                 #
+                 # This fixes an issue where an input m4a file might be a
+                 # different length than the output wav file
+                 outputs={wav_file: "-af aresample=async=1000"},
+                 global_options=global_options,
+             )
+             ff.run()
+
+             # We pad the file using the start time offset in case it's an audio
+             # stream starting at some offset in a video container.
+             pad, codec = ffprobe_offset_and_codec(audio_path)
+
+             # For mp3s, don't pad files with discrepancies less than 0.027s -
+             # it's likely due to codec latency. The amount of latency introduced
+             # by mp3 is 1152 samples, which is 0.0261s at 44.1 kHz. So we set the
+             # threshold here slightly above that.
+             # Source: https://lame.sourceforge.io/tech-FAQ.txt.
+             if codec == "mp3" and pad < 0.027:
+                 pad = 0.0
+             ff = ffmpy.FFmpeg(
+                 inputs={wav_file: None},
+                 outputs={padded_wav: f"-af 'adelay={pad*1000}:all=true'"},
+                 global_options=global_options,
+             )
+             ff.run()
+
+             signal = cls(padded_wav, **kwargs)
+
+         return signal
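A sketch of the FFMPEG-backed path end to end (the media path is a placeholder; `ffmpeg` and `ffprobe` must be on PATH):

    from audiotools import AudioSignal

    # Decode via ffmpeg, padding by the container's start-time offset.
    signal = AudioSignal.load_from_file_with_ffmpeg("clip.m4a")
    print(signal.ffmpeg_loudness())         # LUFS per batch item, via ebur128
    signal = signal.ffmpeg_resample(16000)  # memory-friendly resample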
dac-vae/audiotools/core/loudness.py ADDED
@@ -0,0 +1,320 @@
1
+ import copy
2
+
3
+ import julius
4
+ import numpy as np
5
+ import scipy
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+
10
+
11
+ class Meter(torch.nn.Module):
12
+ """Tensorized version of pyloudnorm.Meter. Works with batched audio tensors.
13
+
14
+ Parameters
15
+ ----------
16
+ rate : int
17
+ Sample rate of audio.
18
+ filter_class : str, optional
19
+ Class of weighting filter used.
20
+ K-weighting' (default), 'Fenton/Lee 1'
21
+ 'Fenton/Lee 2', 'Dash et al.'
22
+ by default "K-weighting"
23
+ block_size : float, optional
24
+ Gating block size in seconds, by default 0.400
25
+ zeros : int, optional
26
+ Number of zeros to use in FIR approximation of
27
+ IIR filters, by default 512
28
+ use_fir : bool, optional
29
+ Whether to use FIR approximation or exact IIR formulation.
30
+ If computing on GPU, ``use_fir=True`` will be used, as its
31
+ much faster, by default False
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ rate: int,
37
+ filter_class: str = "K-weighting",
38
+ block_size: float = 0.400,
39
+ zeros: int = 512,
40
+ use_fir: bool = False,
41
+ ):
42
+ super().__init__()
43
+
44
+ self.rate = rate
45
+ self.filter_class = filter_class
46
+ self.block_size = block_size
47
+ self.use_fir = use_fir
48
+
49
+ G = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.41, 1.41]))
50
+ self.register_buffer("G", G)
51
+
52
+ # Compute impulse responses so that filtering is fast via
53
+ # a convolution at runtime, on GPU, unlike lfilter.
54
+ impulse = np.zeros((zeros,))
55
+ impulse[..., 0] = 1.0
56
+
57
+ firs = np.zeros((len(self._filters), 1, zeros))
58
+ passband_gain = torch.zeros(len(self._filters))
59
+
60
+ for i, (_, filter_stage) in enumerate(self._filters.items()):
61
+ firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a, impulse)
62
+ passband_gain[i] = filter_stage.passband_gain
63
+
64
+ firs = torch.from_numpy(firs[..., ::-1].copy()).float()
65
+
66
+ self.register_buffer("firs", firs)
67
+ self.register_buffer("passband_gain", passband_gain)
68
+
69
+ def apply_filter_gpu(self, data: torch.Tensor):
70
+ """Performs FIR approximation of loudness computation.
71
+
72
+ Parameters
73
+ ----------
74
+ data : torch.Tensor
75
+ Audio data of shape (nb, nch, nt).
76
+
77
+ Returns
78
+ -------
79
+ torch.Tensor
80
+ Filtered audio data.
81
+ """
82
+ # Data is of shape (nb, nch, nt)
83
+ # Reshape to (nb*nch, 1, nt)
84
+ nb, nt, nch = data.shape
85
+ data = data.permute(0, 2, 1)
86
+ data = data.reshape(nb * nch, 1, nt)
87
+
88
+ # Apply padding
89
+ pad_length = self.firs.shape[-1]
90
+
91
+ # Apply filtering in sequence
92
+ for i in range(self.firs.shape[0]):
93
+ data = F.pad(data, (pad_length, pad_length))
94
+ data = julius.fftconv.fft_conv1d(data, self.firs[i, None, ...])
95
+ data = self.passband_gain[i] * data
96
+ data = data[..., 1 : nt + 1]
97
+
98
+ data = data.permute(0, 2, 1)
99
+ data = data[:, :nt, :]
100
+ return data
101
+
102
+ def apply_filter_cpu(self, data: torch.Tensor):
103
+ """Performs IIR formulation of loudness computation.
104
+
105
+ Parameters
106
+ ----------
107
+ data : torch.Tensor
108
+ Audio data of shape (nb, nch, nt).
109
+
110
+ Returns
111
+ -------
112
+ torch.Tensor
113
+ Filtered audio data.
114
+ """
115
+ for _, filter_stage in self._filters.items():
116
+ passband_gain = filter_stage.passband_gain
117
+
118
+ a_coeffs = torch.from_numpy(filter_stage.a).float().to(data.device)
119
+ b_coeffs = torch.from_numpy(filter_stage.b).float().to(data.device)
120
+
121
+ _data = data.permute(0, 2, 1)
122
+ filtered = torchaudio.functional.lfilter(
123
+ _data, a_coeffs, b_coeffs, clamp=False
124
+ )
125
+ data = passband_gain * filtered.permute(0, 2, 1)
126
+ return data
127
+
128
+ def apply_filter(self, data: torch.Tensor):
129
+ """Applies filter on either CPU or GPU, depending
130
+ on if the audio is on GPU or is on CPU, or if
131
+ ``self.use_fir`` is True.
132
+
133
+ Parameters
134
+ ----------
135
+ data : torch.Tensor
136
+ Audio data of shape (nb, nch, nt).
137
+
138
+ Returns
139
+ -------
140
+ torch.Tensor
141
+ Filtered audio data.
142
+ """
143
+ if data.is_cuda or self.use_fir:
144
+ data = self.apply_filter_gpu(data)
145
+ else:
146
+ data = self.apply_filter_cpu(data)
147
+ return data
148
+
149
+ def forward(self, data: torch.Tensor):
150
+ """Computes integrated loudness of data.
151
+
152
+ Parameters
153
+ ----------
154
+ data : torch.Tensor
155
+ Audio data of shape (nb, nch, nt).
156
+
157
+ Returns
158
+ -------
159
+ torch.Tensor
160
+ Filtered audio data.
161
+ """
162
+ return self.integrated_loudness(data)
163
+
164
+ def _unfold(self, input_data):
165
+ T_g = self.block_size
166
+ overlap = 0.75 # overlap of 75% of the block duration
167
+ step = 1.0 - overlap # step size by percentage
168
+
169
+ kernel_size = int(T_g * self.rate)
170
+ stride = int(T_g * self.rate * step)
171
+ unfolded = julius.core.unfold(input_data.permute(0, 2, 1), kernel_size, stride)
172
+ unfolded = unfolded.transpose(-1, -2)
173
+
174
+ return unfolded
175
+
176
+ def integrated_loudness(self, data: torch.Tensor):
177
+ """Computes integrated loudness of data.
178
+
179
+ Parameters
180
+ ----------
181
+ data : torch.Tensor
182
+ Audio data of shape (nb, nch, nt).
183
+
184
+ Returns
185
+ -------
186
+ torch.Tensor
187
+ Filtered audio data.
188
+ """
189
+ if not torch.is_tensor(data):
190
+ data = torch.from_numpy(data).float()
191
+ else:
192
+ data = data.float()
193
+
194
+ input_data = copy.copy(data)
195
+ # Data always has a batch and channel dimension.
196
+ # Is of shape (nb, nt, nch)
197
+ if input_data.ndim < 2:
198
+ input_data = input_data.unsqueeze(-1)
199
+ if input_data.ndim < 3:
200
+ input_data = input_data.unsqueeze(0)
201
+
202
+ nb, nt, nch = input_data.shape
203
+
204
+ # Apply frequency weighting filters - account
205
+ # for the acoustic response of the head and auditory system
206
+ input_data = self.apply_filter(input_data)
207
+
208
+ G = self.G # channel gains
209
+ T_g = self.block_size # 400 ms gating block standard
210
+ Gamma_a = -70.0 # -70 LKFS = absolute loudness threshold
211
+
212
+ unfolded = self._unfold(input_data)
213
+
214
+ z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
215
+ l = -0.691 + 10.0 * torch.log10((G[None, :nch, None] * z).sum(1, keepdim=True))
216
+ l = l.expand_as(z)
217
+
218
+ # find gating block indices above absolute threshold
219
+ z_avg_gated = z
220
+ z_avg_gated[l <= Gamma_a] = 0
221
+ masked = l > Gamma_a
222
+ z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
223
+
224
+ # calculate the relative threshold value (see eq. 6)
225
+ Gamma_r = (
226
+ -0.691 + 10.0 * torch.log10((z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
227
+ )
228
+ Gamma_r = Gamma_r[:, None, None]
229
+ Gamma_r = Gamma_r.expand(nb, nch, l.shape[-1])
230
+
231
+ # find gating block indices above relative and absolute thresholds (end of eq. 7)
232
+ z_avg_gated = z
233
+ z_avg_gated[l <= Gamma_a] = 0
234
+ z_avg_gated[l <= Gamma_r] = 0
235
+ masked = (l > Gamma_a) * (l > Gamma_r)
236
+ z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
237
+
238
+ # # Cannot use nan_to_num (pytorch 1.8 does not come with GCP-supported cuda version)
239
+ # z_avg_gated = torch.nan_to_num(z_avg_gated)
240
+ z_avg_gated = torch.where(
241
+ z_avg_gated.isnan(), torch.zeros_like(z_avg_gated), z_avg_gated
242
+ )
243
+ z_avg_gated[z_avg_gated == float("inf")] = float(np.finfo(np.float32).max)
244
+ z_avg_gated[z_avg_gated == -float("inf")] = float(np.finfo(np.float32).min)
245
+
246
+ LUFS = -0.691 + 10.0 * torch.log10((G[None, :nch] * z_avg_gated).sum(1))
247
+ return LUFS.float()
248
+
249
+ @property
250
+ def filter_class(self):
251
+ return self._filter_class
252
+
253
+ @filter_class.setter
254
+ def filter_class(self, value):
255
+ from pyloudnorm import Meter
256
+
257
+ meter = Meter(self.rate)
258
+ meter.filter_class = value
259
+ self._filter_class = value
260
+ self._filters = meter._filters
261
+
262
+
263
+ class LoudnessMixin:
264
+ _loudness = None
265
+ MIN_LOUDNESS = -70
266
+ """Minimum loudness possible."""
267
+
268
+ def loudness(
269
+ self, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
270
+ ):
271
+ """Calculates loudness using an implementation of ITU-R BS.1770-4.
272
+ Allows control over the gating block size and frequency weighting filters,
273
+ and measures the integrated gated loudness of a signal.
274
+
275
+ API is derived from PyLoudnorm, but this implementation is ported to PyTorch
276
+ and is tensorized across batches. When on GPU, an FIR approximation of the IIR
277
+ filters is used to compute loudness for speed.
278
+
279
+ Uses the weighting filters and block size defined by the meter;
280
+ the integrated loudness is measured based upon the gating algorithm
281
+ defined in the ITU-R BS.1770-4 specification.
282
+
283
+ Parameters
284
+ ----------
285
+ filter_class : str, optional
286
+ Class of weighting filter used.
287
+ K-weighting' (default), 'Fenton/Lee 1'
288
+ 'Fenton/Lee 2', 'Dash et al.'
289
+ by default "K-weighting"
290
+ block_size : float, optional
291
+ Gating block size in seconds, by default 0.400
292
+ kwargs : dict, optional
293
+ Keyword arguments to :py:class:`audiotools.core.loudness.Meter`.
294
+
295
+ Returns
296
+ -------
297
+ torch.Tensor
298
+ Loudness of audio data.
299
+ """
300
+ if self._loudness is not None:
301
+ return self._loudness.to(self.device)
302
+ original_length = self.signal_length
303
+ if self.signal_duration < 0.5:
304
+ pad_len = int((0.5 - self.signal_duration) * self.sample_rate)
305
+ self.zero_pad(0, pad_len)
306
+
307
+ # create BS.1770 meter
308
+ meter = Meter(
309
+ self.sample_rate, filter_class=filter_class, block_size=block_size, **kwargs
310
+ )
311
+ meter = meter.to(self.device)
312
+ # measure loudness
313
+ loudness = meter.integrated_loudness(self.audio_data.permute(0, 2, 1))
314
+ self.truncate_samples(original_length)
315
+ min_loudness = (
316
+ torch.ones_like(loudness, device=loudness.device) * self.MIN_LOUDNESS
317
+ )
318
+ self._loudness = torch.maximum(loudness, min_loudness)
319
+
320
+ return self._loudness.to(self.device)
dac-vae/audiotools/core/playback.py ADDED
@@ -0,0 +1,252 @@
1
+ """
2
+ These are utilities that allow one to embed an AudioSignal
3
+ as a playable object in a Jupyter notebook, or to play audio from
4
+ the terminal, etc.
5
+ """ # fmt: skip
6
+ import base64
7
+ import io
8
+ import random
9
+ import string
10
+ import subprocess
11
+ from tempfile import NamedTemporaryFile
12
+
13
+ import importlib_resources as pkg_resources
14
+
15
+ from . import templates
16
+ from .util import _close_temp_files
17
+ from .util import format_figure
18
+
19
+ headers = pkg_resources.files(templates).joinpath("headers.html").read_text()
20
+ widget = pkg_resources.files(templates).joinpath("widget.html").read_text()
21
+
22
+ DEFAULT_EXTENSION = ".wav"
23
+
24
+
25
+ def _check_imports(): # pragma: no cover
26
+ try:
27
+ import ffmpy
28
+ except ImportError:
29
+ ffmpy = False
30
+
31
+ try:
32
+ import IPython
33
+ except ImportError:
34
+ raise ImportError("IPython must be installed in order to use this function!")
35
+ return ffmpy, IPython
36
+
37
+
38
+ class PlayMixin:
39
+ def embed(self, ext: str = None, display: bool = True, return_html: bool = False):
40
+ """Embeds audio as a playable audio embed in a notebook, or HTML
41
+ document, etc.
42
+
43
+ Parameters
44
+ ----------
45
+ ext : str, optional
46
+ Extension to use when saving the audio, by default ".wav"
47
+ display : bool, optional
48
+ This controls whether or not to display the audio when called. This
49
+ is used when the embed is the last line in a Jupyter cell, to prevent
50
+ the audio from being embedded twice, by default True
51
+ return_html : bool, optional
52
+ Whether to return the data wrapped in an HTML audio element, by default False
53
+
54
+ Returns
55
+ -------
56
+ IPython.display.Audio or str
57
+ Either the element for display, or the HTML string of it.
58
+ """
59
+ if ext is None:
60
+ ext = DEFAULT_EXTENSION
61
+ ext = f".{ext}" if not ext.startswith(".") else ext
62
+ ffmpy, IPython = _check_imports()
63
+ sr = self.sample_rate
64
+ tmpfiles = []
65
+
66
+ with _close_temp_files(tmpfiles):
67
+ tmp_wav = NamedTemporaryFile(mode="w+", suffix=".wav", delete=False)
68
+ tmpfiles.append(tmp_wav)
69
+ self.write(tmp_wav.name)
70
+ if ext != ".wav" and ffmpy:
71
+ tmp_converted = NamedTemporaryFile(mode="w+", suffix=ext, delete=False)
72
+ tmpfiles.append(tmp_converted)
73
+ ff = ffmpy.FFmpeg(
74
+ inputs={tmp_wav.name: None},
75
+ outputs={
76
+ tmp_converted.name: "-write_xing 0 -codec:a libmp3lame -b:a 128k -y -hide_banner -loglevel error"
77
+ },
78
+ )
79
+ ff.run()
80
+ else:
81
+ tmp_converted = tmp_wav
82
+
83
+ audio_element = IPython.display.Audio(data=tmp_converted.name, rate=sr)
84
+ if display:
85
+ IPython.display.display(audio_element)
86
+
87
+ if return_html:
88
+ audio_element = (
89
+ f"<audio "
90
+ f" controls "
91
+ f" src='{audio_element.src_attr()}'> "
92
+ f"</audio> "
93
+ )
94
+ return audio_element
95
+
96
+ def widget(
97
+ self,
98
+ title: str = None,
99
+ ext: str = ".wav",
100
+ add_headers: bool = True,
101
+ player_width: str = "100%",
102
+ margin: str = "10px",
103
+ plot_fn: str = "specshow",
104
+ return_html: bool = False,
105
+ **kwargs,
106
+ ):
107
+ """Creates a playable widget with spectrogram. Inspired (heavily) by
108
+ https://sjvasquez.github.io/blog/melnet/.
109
+
110
+ Parameters
111
+ ----------
112
+ title : str, optional
113
+ Title of plot, placed in upper right of top-most axis.
114
+ ext : str, optional
115
+ Extension for embedding, by default ".mp3"
116
+ add_headers : bool, optional
117
+ Whether or not to add headers (use for first embed, False for later embeds), by default True
118
+ player_width : str, optional
119
+ Width of the player, as a string in a CSS rule, by default "100%"
120
+ margin : str, optional
121
+ Margin on all sides of player, by default "10px"
122
+ plot_fn : function, optional
123
+ Plotting function to use (by default self.specshow).
124
+ return_html : bool, optional
125
+ Whether to return the data wrapped in an HTML audio element, by default False
126
+ kwargs : dict, optional
127
+ Keyword arguments to plot_fn (by default self.specshow).
128
+
129
+ Returns
130
+ -------
131
+ str, optional
132
+ The widget HTML, returned only if ``return_html`` is True.
133
+ """
134
+ import matplotlib.pyplot as plt
135
+
136
+ def _save_fig_to_tag():
137
+ buffer = io.BytesIO()
138
+
139
+ plt.savefig(buffer, bbox_inches="tight", pad_inches=0)
140
+ plt.close()
141
+
142
+ buffer.seek(0)
143
+ data_uri = base64.b64encode(buffer.read()).decode("ascii")
144
+ tag = "data:image/png;base64,{0}".format(data_uri)
145
+
146
+ return tag
147
+
148
+ _, IPython = _check_imports()
149
+
150
+ header_html = ""
151
+
152
+ if add_headers:
153
+ header_html = headers.replace("PLAYER_WIDTH", str(player_width))
154
+ header_html = header_html.replace("MARGIN", str(margin))
155
+ IPython.display.display(IPython.display.HTML(header_html))
156
+
157
+ widget_html = widget
158
+ if isinstance(plot_fn, str):
159
+ plot_fn = getattr(self, plot_fn)
160
+ kwargs["title"] = title
161
+ plot_fn(**kwargs)
162
+
163
+ fig = plt.gcf()
164
+ pixels = fig.get_size_inches() * fig.dpi
165
+
166
+ tag = _save_fig_to_tag()
167
+
168
+ # Make the source image for the levels
169
+ self.specshow()
170
+ format_figure((12, 1.5))
171
+ levels_tag = _save_fig_to_tag()
172
+
173
+ player_id = "".join(random.choice(string.ascii_uppercase) for _ in range(10))
174
+
175
+ audio_elem = self.embed(ext=ext, display=False)
176
+ widget_html = widget_html.replace("AUDIO_SRC", audio_elem.src_attr())
177
+ widget_html = widget_html.replace("IMAGE_SRC", tag)
178
+ widget_html = widget_html.replace("LEVELS_SRC", levels_tag)
179
+ widget_html = widget_html.replace("PLAYER_ID", player_id)
180
+
181
+ # Calculate width/height of figure based on figure size.
182
+ widget_html = widget_html.replace("PADDING_AMOUNT", f"{int(pixels[1])}px")
183
+ widget_html = widget_html.replace("MAX_WIDTH", f"{int(pixels[0])}px")
184
+
185
+ IPython.display.display(IPython.display.HTML(widget_html))
186
+
187
+ if return_html:
188
+ html = header_html if add_headers else ""
189
+ html += widget_html
190
+ return html
191
+
192
+ def play(self):
193
+ """
194
+ Plays an audio signal if ffplay from the ffmpeg suite of tools is installed.
195
+ Otherwise, will fail. The audio signal is written to a temporary file
196
+ and then played with ffplay.
197
+ """
198
+ tmpfiles = []
199
+ with _close_temp_files(tmpfiles):
200
+ tmp_wav = NamedTemporaryFile(suffix=".wav", delete=False)
201
+ tmpfiles.append(tmp_wav)
202
+ self.write(tmp_wav.name)
203
+ print(self)
204
+ subprocess.call(
205
+ [
206
+ "ffplay",
207
+ "-nodisp",
208
+ "-autoexit",
209
+ "-hide_banner",
210
+ "-loglevel",
211
+ "error",
212
+ tmp_wav.name,
213
+ ]
214
+ )
215
+ return self
216
+
217
+
218
+ if __name__ == "__main__": # pragma: no cover
219
+ from audiotools import AudioSignal
220
+
221
+ signal = AudioSignal(
222
+ "tests/audio/spk/f10_script4_produced.mp3", offset=5, duration=5
223
+ )
224
+
225
+ wave_html = signal.widget(
226
+ "Waveform",
227
+ plot_fn="waveplot",
228
+ return_html=True,
229
+ )
230
+
231
+ spec_html = signal.widget("Spectrogram", return_html=True, add_headers=False)
232
+
233
+ combined_html = signal.widget(
234
+ "Waveform + spectrogram",
235
+ plot_fn="wavespec",
236
+ return_html=True,
237
+ add_headers=False,
238
+ )
239
+
240
+ signal.low_pass(8000)
241
+ lowpass_html = signal.widget(
242
+ "Lowpassed audio",
243
+ plot_fn="wavespec",
244
+ return_html=True,
245
+ add_headers=False,
246
+ )
247
+
248
+ with open("/tmp/index.html", "w") as f:
249
+ f.write(wave_html)
250
+ f.write(spec_html)
251
+ f.write(combined_html)
252
+ f.write(lowpass_html)
dac-vae/audiotools/core/templates/__init__.py ADDED
File without changes
dac-vae/audiotools/core/templates/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (166 Bytes).
 
dac-vae/audiotools/core/templates/headers.html ADDED
@@ -0,0 +1,322 @@
1
+ <style>
2
+ .player {
3
+ width: 100%;
4
+ /*border: 1px solid black;*/
5
+ margin: 10px;
6
+ }
7
+
8
+ .underlay img {
9
+ width: 100%;
10
+ height: 100%;
11
+ }
12
+
13
+ .spectrogram {
14
+ height: 0;
15
+ width: 100%;
16
+ position: relative;
17
+ }
18
+
19
+ .audio-controls {
20
+ width: 100%;
21
+ height: 54px;
22
+ display: flex;
23
+ /*border-top: 1px solid black;*/
24
+ /*background-color: rgb(241, 243, 244);*/
25
+ background-color: rgb(248, 248, 248);
26
+ background-color: rgb(253, 253, 254);
27
+ border: 1px solid rgb(205, 208, 211);
28
+ margin-top: 20px;
29
+ /*border: 1px solid black;*/
30
+ border-radius: 30px;
31
+
32
+ }
33
+
34
+ .play-img {
35
+ margin: auto;
36
+ height: 45%;
37
+ width: 45%;
38
+ display: block;
39
+ }
40
+
41
+ .download-img {
42
+ margin: auto;
43
+ height: 100%;
44
+ width: 100%;
45
+ display: block;
46
+ }
47
+
48
+ .pause-img {
49
+ margin: auto;
50
+ height: 45%;
51
+ width: 45%;
52
+ display: none
53
+ }
54
+
55
+ .playpause {
56
+ margin:11px 11px 11px 11px;
57
+ width: 32px;
58
+ min-width: 32px;
59
+ height: 32px;
60
+ /*background-color: rgb(241, 243, 244);*/
61
+ background-color: rgba(0, 0, 0, 0.0);
62
+ /*border-right: 1px solid black;*/
63
+ /*border: 1px solid red;*/
64
+ border-radius: 16px;
65
+ color: black;
66
+ transition: 0.25s;
67
+ box-sizing: border-box !important;
68
+ }
69
+
70
+ .download {
71
+ margin:11px 11px 11px 11px;
72
+ width: 32px;
73
+ min-width: 32px;
74
+ height: 32px;
75
+ /*background-color: rgb(241, 243, 244);*/
76
+ background-color: rgba(0, 0, 0, 0.0);
77
+ /*border-right: 1px solid black;*/
78
+ /*border: 1px solid red;*/
79
+ border-radius: 16px;
80
+ color: black;
81
+ transition: 0.25s;
82
+ box-sizing: border-box !important;
83
+ }
84
+
85
+ /*.playpause:disabled {
86
+ background-color: red;
87
+ }*/
88
+
89
+ .playpause:hover {
90
+ background-color: rgba(10, 20, 30, 0.03);
91
+ }
92
+
93
+ .playpause:focus {
94
+ outline:none;
95
+ }
96
+
97
+ .response {
98
+ padding:0px 20px 0px 0px;
99
+ width: calc(100% - 132px);
100
+ height: 100%;
101
+
102
+ /*border: 1px solid red;*/
103
+ /*border-bottom: 1px solid rgb(89, 89, 89);*/
104
+ }
105
+
106
+ .response-canvas {
107
+ height: 100%;
108
+ width: 100%;
109
+ }
110
+
111
+
112
+ .underlay {
113
+ height: 100%;
114
+ width: 100%;
115
+ position: absolute;
116
+ top: 0;
117
+ left: 0;
118
+ }
119
+
120
+ .overlay{
121
+ width: 0%;
122
+ height:100%;
123
+ top: 0;
124
+ right: 0px;
125
+
126
+ background:rgba(255, 255, 255, 0.15);
127
+ overflow:hidden;
128
+ position: absolute;
129
+ z-index: 10;
130
+ border-left: solid 1px rgba(0, 0, 0, 0.664);
131
+
132
+ position: absolute;
133
+ pointer-events: none;
134
+ }
135
+ </style>
136
+
137
+ <script>
138
+ !function(t){if("object"==typeof exports&&"undefined"!=typeof module)module.exports=t();else if("function"==typeof define&&define.amd)define([],t);else{("undefined"!=typeof window?window:"undefined"!=typeof global?global:"undefined"!=typeof self?self:this).pako=t()}}(function(){return function(){return function t(e,a,i){function n(s,o){if(!a[s]){if(!e[s]){var l="function"==typeof require&&require;if(!o&&l)return l(s,!0);if(r)return r(s,!0);var h=new Error("Cannot find module '"+s+"'");throw h.code="MODULE_NOT_FOUND",h}var d=a[s]={exports:{}};e[s][0].call(d.exports,function(t){return n(e[s][1][t]||t)},d,d.exports,t,e,a,i)}return a[s].exports}for(var r="function"==typeof require&&require,s=0;s<i.length;s++)n(i[s]);return n}}()({1:[function(t,e,a){"use strict";var i=t("./zlib/deflate"),n=t("./utils/common"),r=t("./utils/strings"),s=t("./zlib/messages"),o=t("./zlib/zstream"),l=Object.prototype.toString,h=0,d=-1,f=0,_=8;function u(t){if(!(this instanceof u))return new u(t);this.options=n.assign({level:d,method:_,chunkSize:16384,windowBits:15,memLevel:8,strategy:f,to:""},t||{});var e=this.options;e.raw&&e.windowBits>0?e.windowBits=-e.windowBits:e.gzip&&e.windowBits>0&&e.windowBits<16&&(e.windowBits+=16),this.err=0,this.msg="",this.ended=!1,this.chunks=[],this.strm=new o,this.strm.avail_out=0;var a=i.deflateInit2(this.strm,e.level,e.method,e.windowBits,e.memLevel,e.strategy);if(a!==h)throw new Error(s[a]);if(e.header&&i.deflateSetHeader(this.strm,e.header),e.dictionary){var c;if(c="string"==typeof e.dictionary?r.string2buf(e.dictionary):"[object ArrayBuffer]"===l.call(e.dictionary)?new Uint8Array(e.dictionary):e.dictionary,(a=i.deflateSetDictionary(this.strm,c))!==h)throw new Error(s[a]);this._dict_set=!0}}function c(t,e){var a=new u(e);if(a.push(t,!0),a.err)throw a.msg||s[a.err];return a.result}u.prototype.push=function(t,e){var a,s,o=this.strm,d=this.options.chunkSize;if(this.ended)return!1;s=e===~~e?e:!0===e?4:0,"string"==typeof t?o.input=r.string2buf(t):"[object ArrayBuffer]"===l.call(t)?o.input=new Uint8Array(t):o.input=t,o.next_in=0,o.avail_in=o.input.length;do{if(0===o.avail_out&&(o.output=new n.Buf8(d),o.next_out=0,o.avail_out=d),1!==(a=i.deflate(o,s))&&a!==h)return this.onEnd(a),this.ended=!0,!1;0!==o.avail_out&&(0!==o.avail_in||4!==s&&2!==s)||("string"===this.options.to?this.onData(r.buf2binstring(n.shrinkBuf(o.output,o.next_out))):this.onData(n.shrinkBuf(o.output,o.next_out)))}while((o.avail_in>0||0===o.avail_out)&&1!==a);return 4===s?(a=i.deflateEnd(this.strm),this.onEnd(a),this.ended=!0,a===h):2!==s||(this.onEnd(h),o.avail_out=0,!0)},u.prototype.onData=function(t){this.chunks.push(t)},u.prototype.onEnd=function(t){t===h&&("string"===this.options.to?this.result=this.chunks.join(""):this.result=n.flattenChunks(this.chunks)),this.chunks=[],this.err=t,this.msg=this.strm.msg},a.Deflate=u,a.deflate=c,a.deflateRaw=function(t,e){return(e=e||{}).raw=!0,c(t,e)},a.gzip=function(t,e){return(e=e||{}).gzip=!0,c(t,e)}},{"./utils/common":3,"./utils/strings":4,"./zlib/deflate":8,"./zlib/messages":13,"./zlib/zstream":15}],2:[function(t,e,a){"use strict";var i=t("./zlib/inflate"),n=t("./utils/common"),r=t("./utils/strings"),s=t("./zlib/constants"),o=t("./zlib/messages"),l=t("./zlib/zstream"),h=t("./zlib/gzheader"),d=Object.prototype.toString;function f(t){if(!(this instanceof f))return new f(t);this.options=n.assign({chunkSize:16384,windowBits:0,to:""},t||{});var 
e=this.options;e.raw&&e.windowBits>=0&&e.windowBits<16&&(e.windowBits=-e.windowBits,0===e.windowBits&&(e.windowBits=-15)),!(e.windowBits>=0&&e.windowBits<16)||t&&t.windowBits||(e.windowBits+=32),e.windowBits>15&&e.windowBits<48&&0==(15&e.windowBits)&&(e.windowBits|=15),this.err=0,this.msg="",this.ended=!1,this.chunks=[],this.strm=new l,this.strm.avail_out=0;var a=i.inflateInit2(this.strm,e.windowBits);if(a!==s.Z_OK)throw new Error(o[a]);if(this.header=new h,i.inflateGetHeader(this.strm,this.header),e.dictionary&&("string"==typeof e.dictionary?e.dictionary=r.string2buf(e.dictionary):"[object ArrayBuffer]"===d.call(e.dictionary)&&(e.dictionary=new Uint8Array(e.dictionary)),e.raw&&(a=i.inflateSetDictionary(this.strm,e.dictionary))!==s.Z_OK))throw new Error(o[a])}function _(t,e){var a=new f(e);if(a.push(t,!0),a.err)throw a.msg||o[a.err];return a.result}f.prototype.push=function(t,e){var a,o,l,h,f,_=this.strm,u=this.options.chunkSize,c=this.options.dictionary,b=!1;if(this.ended)return!1;o=e===~~e?e:!0===e?s.Z_FINISH:s.Z_NO_FLUSH,"string"==typeof t?_.input=r.binstring2buf(t):"[object ArrayBuffer]"===d.call(t)?_.input=new Uint8Array(t):_.input=t,_.next_in=0,_.avail_in=_.input.length;do{if(0===_.avail_out&&(_.output=new n.Buf8(u),_.next_out=0,_.avail_out=u),(a=i.inflate(_,s.Z_NO_FLUSH))===s.Z_NEED_DICT&&c&&(a=i.inflateSetDictionary(this.strm,c)),a===s.Z_BUF_ERROR&&!0===b&&(a=s.Z_OK,b=!1),a!==s.Z_STREAM_END&&a!==s.Z_OK)return this.onEnd(a),this.ended=!0,!1;_.next_out&&(0!==_.avail_out&&a!==s.Z_STREAM_END&&(0!==_.avail_in||o!==s.Z_FINISH&&o!==s.Z_SYNC_FLUSH)||("string"===this.options.to?(l=r.utf8border(_.output,_.next_out),h=_.next_out-l,f=r.buf2string(_.output,l),_.next_out=h,_.avail_out=u-h,h&&n.arraySet(_.output,_.output,l,h,0),this.onData(f)):this.onData(n.shrinkBuf(_.output,_.next_out)))),0===_.avail_in&&0===_.avail_out&&(b=!0)}while((_.avail_in>0||0===_.avail_out)&&a!==s.Z_STREAM_END);return a===s.Z_STREAM_END&&(o=s.Z_FINISH),o===s.Z_FINISH?(a=i.inflateEnd(this.strm),this.onEnd(a),this.ended=!0,a===s.Z_OK):o!==s.Z_SYNC_FLUSH||(this.onEnd(s.Z_OK),_.avail_out=0,!0)},f.prototype.onData=function(t){this.chunks.push(t)},f.prototype.onEnd=function(t){t===s.Z_OK&&("string"===this.options.to?this.result=this.chunks.join(""):this.result=n.flattenChunks(this.chunks)),this.chunks=[],this.err=t,this.msg=this.strm.msg},a.Inflate=f,a.inflate=_,a.inflateRaw=function(t,e){return(e=e||{}).raw=!0,_(t,e)},a.ungzip=_},{"./utils/common":3,"./utils/strings":4,"./zlib/constants":6,"./zlib/gzheader":9,"./zlib/inflate":11,"./zlib/messages":13,"./zlib/zstream":15}],3:[function(t,e,a){"use strict";var i="undefined"!=typeof Uint8Array&&"undefined"!=typeof Uint16Array&&"undefined"!=typeof Int32Array;function n(t,e){return Object.prototype.hasOwnProperty.call(t,e)}a.assign=function(t){for(var e=Array.prototype.slice.call(arguments,1);e.length;){var a=e.shift();if(a){if("object"!=typeof a)throw new TypeError(a+"must be non-object");for(var i in a)n(a,i)&&(t[i]=a[i])}}return t},a.shrinkBuf=function(t,e){return t.length===e?t:t.subarray?t.subarray(0,e):(t.length=e,t)};var r={arraySet:function(t,e,a,i,n){if(e.subarray&&t.subarray)t.set(e.subarray(a,a+i),n);else for(var r=0;r<i;r++)t[n+r]=e[a+r]},flattenChunks:function(t){var e,a,i,n,r,s;for(i=0,e=0,a=t.length;e<a;e++)i+=t[e].length;for(s=new Uint8Array(i),n=0,e=0,a=t.length;e<a;e++)r=t[e],s.set(r,n),n+=r.length;return s}},s={arraySet:function(t,e,a,i,n){for(var 
r=0;r<i;r++)t[n+r]=e[a+r]},flattenChunks:function(t){return[].concat.apply([],t)}};a.setTyped=function(t){t?(a.Buf8=Uint8Array,a.Buf16=Uint16Array,a.Buf32=Int32Array,a.assign(a,r)):(a.Buf8=Array,a.Buf16=Array,a.Buf32=Array,a.assign(a,s))},a.setTyped(i)},{}],4:[function(t,e,a){"use strict";var i=t("./common"),n=!0,r=!0;try{String.fromCharCode.apply(null,[0])}catch(t){n=!1}try{String.fromCharCode.apply(null,new Uint8Array(1))}catch(t){r=!1}for(var s=new i.Buf8(256),o=0;o<256;o++)s[o]=o>=252?6:o>=248?5:o>=240?4:o>=224?3:o>=192?2:1;function l(t,e){if(e<65534&&(t.subarray&&r||!t.subarray&&n))return String.fromCharCode.apply(null,i.shrinkBuf(t,e));for(var a="",s=0;s<e;s++)a+=String.fromCharCode(t[s]);return a}s[254]=s[254]=1,a.string2buf=function(t){var e,a,n,r,s,o=t.length,l=0;for(r=0;r<o;r++)55296==(64512&(a=t.charCodeAt(r)))&&r+1<o&&56320==(64512&(n=t.charCodeAt(r+1)))&&(a=65536+(a-55296<<10)+(n-56320),r++),l+=a<128?1:a<2048?2:a<65536?3:4;for(e=new i.Buf8(l),s=0,r=0;s<l;r++)55296==(64512&(a=t.charCodeAt(r)))&&r+1<o&&56320==(64512&(n=t.charCodeAt(r+1)))&&(a=65536+(a-55296<<10)+(n-56320),r++),a<128?e[s++]=a:a<2048?(e[s++]=192|a>>>6,e[s++]=128|63&a):a<65536?(e[s++]=224|a>>>12,e[s++]=128|a>>>6&63,e[s++]=128|63&a):(e[s++]=240|a>>>18,e[s++]=128|a>>>12&63,e[s++]=128|a>>>6&63,e[s++]=128|63&a);return e},a.buf2binstring=function(t){return l(t,t.length)},a.binstring2buf=function(t){for(var e=new i.Buf8(t.length),a=0,n=e.length;a<n;a++)e[a]=t.charCodeAt(a);return e},a.buf2string=function(t,e){var a,i,n,r,o=e||t.length,h=new Array(2*o);for(i=0,a=0;a<o;)if((n=t[a++])<128)h[i++]=n;else if((r=s[n])>4)h[i++]=65533,a+=r-1;else{for(n&=2===r?31:3===r?15:7;r>1&&a<o;)n=n<<6|63&t[a++],r--;r>1?h[i++]=65533:n<65536?h[i++]=n:(n-=65536,h[i++]=55296|n>>10&1023,h[i++]=56320|1023&n)}return l(h,i)},a.utf8border=function(t,e){var a;for((e=e||t.length)>t.length&&(e=t.length),a=e-1;a>=0&&128==(192&t[a]);)a--;return a<0?e:0===a?e:a+s[t[a]]>e?a:e}},{"./common":3}],5:[function(t,e,a){"use strict";e.exports=function(t,e,a,i){for(var n=65535&t|0,r=t>>>16&65535|0,s=0;0!==a;){a-=s=a>2e3?2e3:a;do{r=r+(n=n+e[i++]|0)|0}while(--s);n%=65521,r%=65521}return n|r<<16|0}},{}],6:[function(t,e,a){"use strict";e.exports={Z_NO_FLUSH:0,Z_PARTIAL_FLUSH:1,Z_SYNC_FLUSH:2,Z_FULL_FLUSH:3,Z_FINISH:4,Z_BLOCK:5,Z_TREES:6,Z_OK:0,Z_STREAM_END:1,Z_NEED_DICT:2,Z_ERRNO:-1,Z_STREAM_ERROR:-2,Z_DATA_ERROR:-3,Z_BUF_ERROR:-5,Z_NO_COMPRESSION:0,Z_BEST_SPEED:1,Z_BEST_COMPRESSION:9,Z_DEFAULT_COMPRESSION:-1,Z_FILTERED:1,Z_HUFFMAN_ONLY:2,Z_RLE:3,Z_FIXED:4,Z_DEFAULT_STRATEGY:0,Z_BINARY:0,Z_TEXT:1,Z_UNKNOWN:2,Z_DEFLATED:8}},{}],7:[function(t,e,a){"use strict";var i=function(){for(var t,e=[],a=0;a<256;a++){t=a;for(var i=0;i<8;i++)t=1&t?3988292384^t>>>1:t>>>1;e[a]=t}return e}();e.exports=function(t,e,a,n){var r=i,s=n+a;t^=-1;for(var o=n;o<s;o++)t=t>>>8^r[255&(t^e[o])];return-1^t}},{}],8:[function(t,e,a){"use strict";var i,n=t("../utils/common"),r=t("./trees"),s=t("./adler32"),o=t("./crc32"),l=t("./messages"),h=0,d=1,f=3,_=4,u=5,c=0,b=1,g=-2,m=-3,w=-5,p=-1,v=1,k=2,y=3,x=4,z=0,B=2,S=8,E=9,A=15,Z=8,R=286,C=30,N=19,O=2*R+1,D=15,I=3,U=258,T=U+I+1,F=32,L=42,H=69,j=73,K=91,M=103,P=113,Y=666,q=1,G=2,X=3,W=4,J=3;function Q(t,e){return t.msg=l[e],e}function V(t){return(t<<1)-(t>4?9:0)}function $(t){for(var e=t.length;--e>=0;)t[e]=0}function tt(t){var 
e=t.state,a=e.pending;a>t.avail_out&&(a=t.avail_out),0!==a&&(n.arraySet(t.output,e.pending_buf,e.pending_out,a,t.next_out),t.next_out+=a,e.pending_out+=a,t.total_out+=a,t.avail_out-=a,e.pending-=a,0===e.pending&&(e.pending_out=0))}function et(t,e){r._tr_flush_block(t,t.block_start>=0?t.block_start:-1,t.strstart-t.block_start,e),t.block_start=t.strstart,tt(t.strm)}function at(t,e){t.pending_buf[t.pending++]=e}function it(t,e){t.pending_buf[t.pending++]=e>>>8&255,t.pending_buf[t.pending++]=255&e}function nt(t,e){var a,i,n=t.max_chain_length,r=t.strstart,s=t.prev_length,o=t.nice_match,l=t.strstart>t.w_size-T?t.strstart-(t.w_size-T):0,h=t.window,d=t.w_mask,f=t.prev,_=t.strstart+U,u=h[r+s-1],c=h[r+s];t.prev_length>=t.good_match&&(n>>=2),o>t.lookahead&&(o=t.lookahead);do{if(h[(a=e)+s]===c&&h[a+s-1]===u&&h[a]===h[r]&&h[++a]===h[r+1]){r+=2,a++;do{}while(h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&r<_);if(i=U-(_-r),r=_-U,i>s){if(t.match_start=e,s=i,i>=o)break;u=h[r+s-1],c=h[r+s]}}}while((e=f[e&d])>l&&0!=--n);return s<=t.lookahead?s:t.lookahead}function rt(t){var e,a,i,r,l,h,d,f,_,u,c=t.w_size;do{if(r=t.window_size-t.lookahead-t.strstart,t.strstart>=c+(c-T)){n.arraySet(t.window,t.window,c,c,0),t.match_start-=c,t.strstart-=c,t.block_start-=c,e=a=t.hash_size;do{i=t.head[--e],t.head[e]=i>=c?i-c:0}while(--a);e=a=c;do{i=t.prev[--e],t.prev[e]=i>=c?i-c:0}while(--a);r+=c}if(0===t.strm.avail_in)break;if(h=t.strm,d=t.window,f=t.strstart+t.lookahead,_=r,u=void 0,(u=h.avail_in)>_&&(u=_),a=0===u?0:(h.avail_in-=u,n.arraySet(d,h.input,h.next_in,u,f),1===h.state.wrap?h.adler=s(h.adler,d,u,f):2===h.state.wrap&&(h.adler=o(h.adler,d,u,f)),h.next_in+=u,h.total_in+=u,u),t.lookahead+=a,t.lookahead+t.insert>=I)for(l=t.strstart-t.insert,t.ins_h=t.window[l],t.ins_h=(t.ins_h<<t.hash_shift^t.window[l+1])&t.hash_mask;t.insert&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[l+I-1])&t.hash_mask,t.prev[l&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=l,l++,t.insert--,!(t.lookahead+t.insert<I)););}while(t.lookahead<T&&0!==t.strm.avail_in)}function st(t,e){for(var a,i;;){if(t.lookahead<T){if(rt(t),t.lookahead<T&&e===h)return q;if(0===t.lookahead)break}if(a=0,t.lookahead>=I&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart),0!==a&&t.strstart-a<=t.w_size-T&&(t.match_length=nt(t,a)),t.match_length>=I)if(i=r._tr_tally(t,t.strstart-t.match_start,t.match_length-I),t.lookahead-=t.match_length,t.match_length<=t.max_lazy_match&&t.lookahead>=I){t.match_length--;do{t.strstart++,t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart}while(0!=--t.match_length);t.strstart++}else t.strstart+=t.match_length,t.match_length=0,t.ins_h=t.window[t.strstart],t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+1])&t.hash_mask;else i=r._tr_tally(t,0,t.window[t.strstart]),t.lookahead--,t.strstart++;if(i&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=t.strstart<I-1?t.strstart:I-1,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}function ot(t,e){for(var a,i,n;;){if(t.lookahead<T){if(rt(t),t.lookahead<T&&e===h)return 
q;if(0===t.lookahead)break}if(a=0,t.lookahead>=I&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart),t.prev_length=t.match_length,t.prev_match=t.match_start,t.match_length=I-1,0!==a&&t.prev_length<t.max_lazy_match&&t.strstart-a<=t.w_size-T&&(t.match_length=nt(t,a),t.match_length<=5&&(t.strategy===v||t.match_length===I&&t.strstart-t.match_start>4096)&&(t.match_length=I-1)),t.prev_length>=I&&t.match_length<=t.prev_length){n=t.strstart+t.lookahead-I,i=r._tr_tally(t,t.strstart-1-t.prev_match,t.prev_length-I),t.lookahead-=t.prev_length-1,t.prev_length-=2;do{++t.strstart<=n&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart)}while(0!=--t.prev_length);if(t.match_available=0,t.match_length=I-1,t.strstart++,i&&(et(t,!1),0===t.strm.avail_out))return q}else if(t.match_available){if((i=r._tr_tally(t,0,t.window[t.strstart-1]))&&et(t,!1),t.strstart++,t.lookahead--,0===t.strm.avail_out)return q}else t.match_available=1,t.strstart++,t.lookahead--}return t.match_available&&(i=r._tr_tally(t,0,t.window[t.strstart-1]),t.match_available=0),t.insert=t.strstart<I-1?t.strstart:I-1,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}function lt(t,e,a,i,n){this.good_length=t,this.max_lazy=e,this.nice_length=a,this.max_chain=i,this.func=n}function ht(){this.strm=null,this.status=0,this.pending_buf=null,this.pending_buf_size=0,this.pending_out=0,this.pending=0,this.wrap=0,this.gzhead=null,this.gzindex=0,this.method=S,this.last_flush=-1,this.w_size=0,this.w_bits=0,this.w_mask=0,this.window=null,this.window_size=0,this.prev=null,this.head=null,this.ins_h=0,this.hash_size=0,this.hash_bits=0,this.hash_mask=0,this.hash_shift=0,this.block_start=0,this.match_length=0,this.prev_match=0,this.match_available=0,this.strstart=0,this.match_start=0,this.lookahead=0,this.prev_length=0,this.max_chain_length=0,this.max_lazy_match=0,this.level=0,this.strategy=0,this.good_match=0,this.nice_match=0,this.dyn_ltree=new n.Buf16(2*O),this.dyn_dtree=new n.Buf16(2*(2*C+1)),this.bl_tree=new n.Buf16(2*(2*N+1)),$(this.dyn_ltree),$(this.dyn_dtree),$(this.bl_tree),this.l_desc=null,this.d_desc=null,this.bl_desc=null,this.bl_count=new n.Buf16(D+1),this.heap=new n.Buf16(2*R+1),$(this.heap),this.heap_len=0,this.heap_max=0,this.depth=new n.Buf16(2*R+1),$(this.depth),this.l_buf=0,this.lit_bufsize=0,this.last_lit=0,this.d_buf=0,this.opt_len=0,this.static_len=0,this.matches=0,this.insert=0,this.bi_buf=0,this.bi_valid=0}function dt(t){var e;return t&&t.state?(t.total_in=t.total_out=0,t.data_type=B,(e=t.state).pending=0,e.pending_out=0,e.wrap<0&&(e.wrap=-e.wrap),e.status=e.wrap?L:P,t.adler=2===e.wrap?0:1,e.last_flush=h,r._tr_init(e),c):Q(t,g)}function ft(t){var e,a=dt(t);return a===c&&((e=t.state).window_size=2*e.w_size,$(e.head),e.max_lazy_match=i[e.level].max_lazy,e.good_match=i[e.level].good_length,e.nice_match=i[e.level].nice_length,e.max_chain_length=i[e.level].max_chain,e.strstart=0,e.block_start=0,e.lookahead=0,e.insert=0,e.match_length=e.prev_length=I-1,e.match_available=0,e.ins_h=0),a}function _t(t,e,a,i,r,s){if(!t)return g;var o=1;if(e===p&&(e=6),i<0?(o=0,i=-i):i>15&&(o=2,i-=16),r<1||r>E||a!==S||i<8||i>15||e<0||e>9||s<0||s>x)return Q(t,g);8===i&&(i=9);var l=new ht;return 
t.state=l,l.strm=t,l.wrap=o,l.gzhead=null,l.w_bits=i,l.w_size=1<<l.w_bits,l.w_mask=l.w_size-1,l.hash_bits=r+7,l.hash_size=1<<l.hash_bits,l.hash_mask=l.hash_size-1,l.hash_shift=~~((l.hash_bits+I-1)/I),l.window=new n.Buf8(2*l.w_size),l.head=new n.Buf16(l.hash_size),l.prev=new n.Buf16(l.w_size),l.lit_bufsize=1<<r+6,l.pending_buf_size=4*l.lit_bufsize,l.pending_buf=new n.Buf8(l.pending_buf_size),l.d_buf=1*l.lit_bufsize,l.l_buf=3*l.lit_bufsize,l.level=e,l.strategy=s,l.method=a,ft(t)}i=[new lt(0,0,0,0,function(t,e){var a=65535;for(a>t.pending_buf_size-5&&(a=t.pending_buf_size-5);;){if(t.lookahead<=1){if(rt(t),0===t.lookahead&&e===h)return q;if(0===t.lookahead)break}t.strstart+=t.lookahead,t.lookahead=0;var i=t.block_start+a;if((0===t.strstart||t.strstart>=i)&&(t.lookahead=t.strstart-i,t.strstart=i,et(t,!1),0===t.strm.avail_out))return q;if(t.strstart-t.block_start>=t.w_size-T&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=0,e===_?(et(t,!0),0===t.strm.avail_out?X:W):(t.strstart>t.block_start&&(et(t,!1),t.strm.avail_out),q)}),new lt(4,4,8,4,st),new lt(4,5,16,8,st),new lt(4,6,32,32,st),new lt(4,4,16,16,ot),new lt(8,16,32,32,ot),new lt(8,16,128,128,ot),new lt(8,32,128,256,ot),new lt(32,128,258,1024,ot),new lt(32,258,258,4096,ot)],a.deflateInit=function(t,e){return _t(t,e,S,A,Z,z)},a.deflateInit2=_t,a.deflateReset=ft,a.deflateResetKeep=dt,a.deflateSetHeader=function(t,e){return t&&t.state?2!==t.state.wrap?g:(t.state.gzhead=e,c):g},a.deflate=function(t,e){var a,n,s,l;if(!t||!t.state||e>u||e<0)return t?Q(t,g):g;if(n=t.state,!t.output||!t.input&&0!==t.avail_in||n.status===Y&&e!==_)return Q(t,0===t.avail_out?w:g);if(n.strm=t,a=n.last_flush,n.last_flush=e,n.status===L)if(2===n.wrap)t.adler=0,at(n,31),at(n,139),at(n,8),n.gzhead?(at(n,(n.gzhead.text?1:0)+(n.gzhead.hcrc?2:0)+(n.gzhead.extra?4:0)+(n.gzhead.name?8:0)+(n.gzhead.comment?16:0)),at(n,255&n.gzhead.time),at(n,n.gzhead.time>>8&255),at(n,n.gzhead.time>>16&255),at(n,n.gzhead.time>>24&255),at(n,9===n.level?2:n.strategy>=k||n.level<2?4:0),at(n,255&n.gzhead.os),n.gzhead.extra&&n.gzhead.extra.length&&(at(n,255&n.gzhead.extra.length),at(n,n.gzhead.extra.length>>8&255)),n.gzhead.hcrc&&(t.adler=o(t.adler,n.pending_buf,n.pending,0)),n.gzindex=0,n.status=H):(at(n,0),at(n,0),at(n,0),at(n,0),at(n,0),at(n,9===n.level?2:n.strategy>=k||n.level<2?4:0),at(n,J),n.status=P);else{var m=S+(n.w_bits-8<<4)<<8;m|=(n.strategy>=k||n.level<2?0:n.level<6?1:6===n.level?2:3)<<6,0!==n.strstart&&(m|=F),m+=31-m%31,n.status=P,it(n,m),0!==n.strstart&&(it(n,t.adler>>>16),it(n,65535&t.adler)),t.adler=1}if(n.status===H)if(n.gzhead.extra){for(s=n.pending;n.gzindex<(65535&n.gzhead.extra.length)&&(n.pending!==n.pending_buf_size||(n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),tt(t),s=n.pending,n.pending!==n.pending_buf_size));)at(n,255&n.gzhead.extra[n.gzindex]),n.gzindex++;n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),n.gzindex===n.gzhead.extra.length&&(n.gzindex=0,n.status=j)}else n.status=j;if(n.status===j)if(n.gzhead.name){s=n.pending;do{if(n.pending===n.pending_buf_size&&(n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),tt(t),s=n.pending,n.pending===n.pending_buf_size)){l=1;break}l=n.gzindex<n.gzhead.name.length?255&n.gzhead.name.charCodeAt(n.gzindex++):0,at(n,l)}while(0!==l);n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),0===l&&(n.gzindex=0,n.status=K)}else 
n.status=K;if(n.status===K)if(n.gzhead.comment){s=n.pending;do{if(n.pending===n.pending_buf_size&&(n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),tt(t),s=n.pending,n.pending===n.pending_buf_size)){l=1;break}l=n.gzindex<n.gzhead.comment.length?255&n.gzhead.comment.charCodeAt(n.gzindex++):0,at(n,l)}while(0!==l);n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),0===l&&(n.status=M)}else n.status=M;if(n.status===M&&(n.gzhead.hcrc?(n.pending+2>n.pending_buf_size&&tt(t),n.pending+2<=n.pending_buf_size&&(at(n,255&t.adler),at(n,t.adler>>8&255),t.adler=0,n.status=P)):n.status=P),0!==n.pending){if(tt(t),0===t.avail_out)return n.last_flush=-1,c}else if(0===t.avail_in&&V(e)<=V(a)&&e!==_)return Q(t,w);if(n.status===Y&&0!==t.avail_in)return Q(t,w);if(0!==t.avail_in||0!==n.lookahead||e!==h&&n.status!==Y){var p=n.strategy===k?function(t,e){for(var a;;){if(0===t.lookahead&&(rt(t),0===t.lookahead)){if(e===h)return q;break}if(t.match_length=0,a=r._tr_tally(t,0,t.window[t.strstart]),t.lookahead--,t.strstart++,a&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=0,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}(n,e):n.strategy===y?function(t,e){for(var a,i,n,s,o=t.window;;){if(t.lookahead<=U){if(rt(t),t.lookahead<=U&&e===h)return q;if(0===t.lookahead)break}if(t.match_length=0,t.lookahead>=I&&t.strstart>0&&(i=o[n=t.strstart-1])===o[++n]&&i===o[++n]&&i===o[++n]){s=t.strstart+U;do{}while(i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&n<s);t.match_length=U-(s-n),t.match_length>t.lookahead&&(t.match_length=t.lookahead)}if(t.match_length>=I?(a=r._tr_tally(t,1,t.match_length-I),t.lookahead-=t.match_length,t.strstart+=t.match_length,t.match_length=0):(a=r._tr_tally(t,0,t.window[t.strstart]),t.lookahead--,t.strstart++),a&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=0,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}(n,e):i[n.level].func(n,e);if(p!==X&&p!==W||(n.status=Y),p===q||p===X)return 0===t.avail_out&&(n.last_flush=-1),c;if(p===G&&(e===d?r._tr_align(n):e!==u&&(r._tr_stored_block(n,0,0,!1),e===f&&($(n.head),0===n.lookahead&&(n.strstart=0,n.block_start=0,n.insert=0))),tt(t),0===t.avail_out))return n.last_flush=-1,c}return e!==_?c:n.wrap<=0?b:(2===n.wrap?(at(n,255&t.adler),at(n,t.adler>>8&255),at(n,t.adler>>16&255),at(n,t.adler>>24&255),at(n,255&t.total_in),at(n,t.total_in>>8&255),at(n,t.total_in>>16&255),at(n,t.total_in>>24&255)):(it(n,t.adler>>>16),it(n,65535&t.adler)),tt(t),n.wrap>0&&(n.wrap=-n.wrap),0!==n.pending?c:b)},a.deflateEnd=function(t){var e;return t&&t.state?(e=t.state.status)!==L&&e!==H&&e!==j&&e!==K&&e!==M&&e!==P&&e!==Y?Q(t,g):(t.state=null,e===P?Q(t,m):c):g},a.deflateSetDictionary=function(t,e){var a,i,r,o,l,h,d,f,_=e.length;if(!t||!t.state)return g;if(2===(o=(a=t.state).wrap)||1===o&&a.status!==L||a.lookahead)return g;for(1===o&&(t.adler=s(t.adler,e,_,0)),a.wrap=0,_>=a.w_size&&(0===o&&($(a.head),a.strstart=0,a.block_start=0,a.insert=0),f=new n.Buf8(a.w_size),n.arraySet(f,e,_-a.w_size,a.w_size,0),e=f,_=a.w_size),l=t.avail_in,h=t.next_in,d=t.input,t.avail_in=_,t.next_in=0,t.input=e,rt(a);a.lookahead>=I;){i=a.strstart,r=a.lookahead-(I-1);do{a.ins_h=(a.ins_h<<a.hash_shift^a.window[i+I-1])&a.hash_mask,a.prev[i&a.w_mask]=a.head[a.ins_h],a.head[a.ins_h]=i,i++}while(--r);a.strstart=i,a.lookahead=I-1,rt(a)}return 
a.strstart+=a.lookahead,a.block_start=a.strstart,a.insert=a.lookahead,a.lookahead=0,a.match_length=a.prev_length=I-1,a.match_available=0,t.next_in=h,t.input=d,t.avail_in=l,a.wrap=o,c},a.deflateInfo="pako deflate (from Nodeca project)"},{"../utils/common":3,"./adler32":5,"./crc32":7,"./messages":13,"./trees":14}],9:[function(t,e,a){"use strict";e.exports=function(){this.text=0,this.time=0,this.xflags=0,this.os=0,this.extra=null,this.extra_len=0,this.name="",this.comment="",this.hcrc=0,this.done=!1}},{}],10:[function(t,e,a){"use strict";e.exports=function(t,e){var a,i,n,r,s,o,l,h,d,f,_,u,c,b,g,m,w,p,v,k,y,x,z,B,S;a=t.state,i=t.next_in,B=t.input,n=i+(t.avail_in-5),r=t.next_out,S=t.output,s=r-(e-t.avail_out),o=r+(t.avail_out-257),l=a.dmax,h=a.wsize,d=a.whave,f=a.wnext,_=a.window,u=a.hold,c=a.bits,b=a.lencode,g=a.distcode,m=(1<<a.lenbits)-1,w=(1<<a.distbits)-1;t:do{c<15&&(u+=B[i++]<<c,c+=8,u+=B[i++]<<c,c+=8),p=b[u&m];e:for(;;){if(u>>>=v=p>>>24,c-=v,0===(v=p>>>16&255))S[r++]=65535&p;else{if(!(16&v)){if(0==(64&v)){p=b[(65535&p)+(u&(1<<v)-1)];continue e}if(32&v){a.mode=12;break t}t.msg="invalid literal/length code",a.mode=30;break t}k=65535&p,(v&=15)&&(c<v&&(u+=B[i++]<<c,c+=8),k+=u&(1<<v)-1,u>>>=v,c-=v),c<15&&(u+=B[i++]<<c,c+=8,u+=B[i++]<<c,c+=8),p=g[u&w];a:for(;;){if(u>>>=v=p>>>24,c-=v,!(16&(v=p>>>16&255))){if(0==(64&v)){p=g[(65535&p)+(u&(1<<v)-1)];continue a}t.msg="invalid distance code",a.mode=30;break t}if(y=65535&p,c<(v&=15)&&(u+=B[i++]<<c,(c+=8)<v&&(u+=B[i++]<<c,c+=8)),(y+=u&(1<<v)-1)>l){t.msg="invalid distance too far back",a.mode=30;break t}if(u>>>=v,c-=v,y>(v=r-s)){if((v=y-v)>d&&a.sane){t.msg="invalid distance too far back",a.mode=30;break t}if(x=0,z=_,0===f){if(x+=h-v,v<k){k-=v;do{S[r++]=_[x++]}while(--v);x=r-y,z=S}}else if(f<v){if(x+=h+f-v,(v-=f)<k){k-=v;do{S[r++]=_[x++]}while(--v);if(x=0,f<k){k-=v=f;do{S[r++]=_[x++]}while(--v);x=r-y,z=S}}}else if(x+=f-v,v<k){k-=v;do{S[r++]=_[x++]}while(--v);x=r-y,z=S}for(;k>2;)S[r++]=z[x++],S[r++]=z[x++],S[r++]=z[x++],k-=3;k&&(S[r++]=z[x++],k>1&&(S[r++]=z[x++]))}else{x=r-y;do{S[r++]=S[x++],S[r++]=S[x++],S[r++]=S[x++],k-=3}while(k>2);k&&(S[r++]=S[x++],k>1&&(S[r++]=S[x++]))}break}}break}}while(i<n&&r<o);i-=k=c>>3,u&=(1<<(c-=k<<3))-1,t.next_in=i,t.next_out=r,t.avail_in=i<n?n-i+5:5-(i-n),t.avail_out=r<o?o-r+257:257-(r-o),a.hold=u,a.bits=c}},{}],11:[function(t,e,a){"use strict";var i=t("../utils/common"),n=t("./adler32"),r=t("./crc32"),s=t("./inffast"),o=t("./inftrees"),l=0,h=1,d=2,f=4,_=5,u=6,c=0,b=1,g=2,m=-2,w=-3,p=-4,v=-5,k=8,y=1,x=2,z=3,B=4,S=5,E=6,A=7,Z=8,R=9,C=10,N=11,O=12,D=13,I=14,U=15,T=16,F=17,L=18,H=19,j=20,K=21,M=22,P=23,Y=24,q=25,G=26,X=27,W=28,J=29,Q=30,V=31,$=32,tt=852,et=592,at=15;function it(t){return(t>>>24&255)+(t>>>8&65280)+((65280&t)<<8)+((255&t)<<24)}function nt(){this.mode=0,this.last=!1,this.wrap=0,this.havedict=!1,this.flags=0,this.dmax=0,this.check=0,this.total=0,this.head=null,this.wbits=0,this.wsize=0,this.whave=0,this.wnext=0,this.window=null,this.hold=0,this.bits=0,this.length=0,this.offset=0,this.extra=0,this.lencode=null,this.distcode=null,this.lenbits=0,this.distbits=0,this.ncode=0,this.nlen=0,this.ndist=0,this.have=0,this.next=null,this.lens=new i.Buf16(320),this.work=new i.Buf16(288),this.lendyn=null,this.distdyn=null,this.sane=0,this.back=0,this.was=0}function rt(t){var e;return t&&t.state?(e=t.state,t.total_in=t.total_out=e.total=0,t.msg="",e.wrap&&(t.adler=1&e.wrap),e.mode=y,e.last=0,e.havedict=0,e.dmax=32768,e.head=null,e.hold=0,e.bits=0,e.lencode=e.lendyn=new i.Buf32(tt),e.distcode=e.distdyn=new 
i.Buf32(et),e.sane=1,e.back=-1,c):m}function st(t){var e;return t&&t.state?((e=t.state).wsize=0,e.whave=0,e.wnext=0,rt(t)):m}function ot(t,e){var a,i;return t&&t.state?(i=t.state,e<0?(a=0,e=-e):(a=1+(e>>4),e<48&&(e&=15)),e&&(e<8||e>15)?m:(null!==i.window&&i.wbits!==e&&(i.window=null),i.wrap=a,i.wbits=e,st(t))):m}function lt(t,e){var a,i;return t?(i=new nt,t.state=i,i.window=null,(a=ot(t,e))!==c&&(t.state=null),a):m}var ht,dt,ft=!0;function _t(t){if(ft){var e;for(ht=new i.Buf32(512),dt=new i.Buf32(32),e=0;e<144;)t.lens[e++]=8;for(;e<256;)t.lens[e++]=9;for(;e<280;)t.lens[e++]=7;for(;e<288;)t.lens[e++]=8;for(o(h,t.lens,0,288,ht,0,t.work,{bits:9}),e=0;e<32;)t.lens[e++]=5;o(d,t.lens,0,32,dt,0,t.work,{bits:5}),ft=!1}t.lencode=ht,t.lenbits=9,t.distcode=dt,t.distbits=5}function ut(t,e,a,n){var r,s=t.state;return null===s.window&&(s.wsize=1<<s.wbits,s.wnext=0,s.whave=0,s.window=new i.Buf8(s.wsize)),n>=s.wsize?(i.arraySet(s.window,e,a-s.wsize,s.wsize,0),s.wnext=0,s.whave=s.wsize):((r=s.wsize-s.wnext)>n&&(r=n),i.arraySet(s.window,e,a-n,r,s.wnext),(n-=r)?(i.arraySet(s.window,e,a-n,n,0),s.wnext=n,s.whave=s.wsize):(s.wnext+=r,s.wnext===s.wsize&&(s.wnext=0),s.whave<s.wsize&&(s.whave+=r))),0}a.inflateReset=st,a.inflateReset2=ot,a.inflateResetKeep=rt,a.inflateInit=function(t){return lt(t,at)},a.inflateInit2=lt,a.inflate=function(t,e){var a,tt,et,at,nt,rt,st,ot,lt,ht,dt,ft,ct,bt,gt,mt,wt,pt,vt,kt,yt,xt,zt,Bt,St=0,Et=new i.Buf8(4),At=[16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15];if(!t||!t.state||!t.output||!t.input&&0!==t.avail_in)return m;(a=t.state).mode===O&&(a.mode=D),nt=t.next_out,et=t.output,st=t.avail_out,at=t.next_in,tt=t.input,rt=t.avail_in,ot=a.hold,lt=a.bits,ht=rt,dt=st,xt=c;t:for(;;)switch(a.mode){case y:if(0===a.wrap){a.mode=D;break}for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(2&a.wrap&&35615===ot){a.check=0,Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0),ot=0,lt=0,a.mode=x;break}if(a.flags=0,a.head&&(a.head.done=!1),!(1&a.wrap)||(((255&ot)<<8)+(ot>>8))%31){t.msg="incorrect header check",a.mode=Q;break}if((15&ot)!==k){t.msg="unknown compression method",a.mode=Q;break}if(lt-=4,yt=8+(15&(ot>>>=4)),0===a.wbits)a.wbits=yt;else if(yt>a.wbits){t.msg="invalid window size",a.mode=Q;break}a.dmax=1<<yt,t.adler=a.check=1,a.mode=512&ot?C:O,ot=0,lt=0;break;case x:for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(a.flags=ot,(255&a.flags)!==k){t.msg="unknown compression method",a.mode=Q;break}if(57344&a.flags){t.msg="unknown header flags set",a.mode=Q;break}a.head&&(a.head.text=ot>>8&1),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0)),ot=0,lt=0,a.mode=z;case z:for(;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.head&&(a.head.time=ot),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,Et[2]=ot>>>16&255,Et[3]=ot>>>24&255,a.check=r(a.check,Et,4,0)),ot=0,lt=0,a.mode=B;case B:for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.head&&(a.head.xflags=255&ot,a.head.os=ot>>8),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0)),ot=0,lt=0,a.mode=S;case S:if(1024&a.flags){for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.length=ot,a.head&&(a.head.extra_len=ot),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0)),ot=0,lt=0}else a.head&&(a.head.extra=null);a.mode=E;case E:if(1024&a.flags&&((ft=a.length)>rt&&(ft=rt),ft&&(a.head&&(yt=a.head.extra_len-a.length,a.head.extra||(a.head.extra=new 
Array(a.head.extra_len)),i.arraySet(a.head.extra,tt,at,ft,yt)),512&a.flags&&(a.check=r(a.check,tt,ft,at)),rt-=ft,at+=ft,a.length-=ft),a.length))break t;a.length=0,a.mode=A;case A:if(2048&a.flags){if(0===rt)break t;ft=0;do{yt=tt[at+ft++],a.head&&yt&&a.length<65536&&(a.head.name+=String.fromCharCode(yt))}while(yt&&ft<rt);if(512&a.flags&&(a.check=r(a.check,tt,ft,at)),rt-=ft,at+=ft,yt)break t}else a.head&&(a.head.name=null);a.length=0,a.mode=Z;case Z:if(4096&a.flags){if(0===rt)break t;ft=0;do{yt=tt[at+ft++],a.head&&yt&&a.length<65536&&(a.head.comment+=String.fromCharCode(yt))}while(yt&&ft<rt);if(512&a.flags&&(a.check=r(a.check,tt,ft,at)),rt-=ft,at+=ft,yt)break t}else a.head&&(a.head.comment=null);a.mode=R;case R:if(512&a.flags){for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(ot!==(65535&a.check)){t.msg="header crc mismatch",a.mode=Q;break}ot=0,lt=0}a.head&&(a.head.hcrc=a.flags>>9&1,a.head.done=!0),t.adler=a.check=0,a.mode=O;break;case C:for(;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}t.adler=a.check=it(ot),ot=0,lt=0,a.mode=N;case N:if(0===a.havedict)return t.next_out=nt,t.avail_out=st,t.next_in=at,t.avail_in=rt,a.hold=ot,a.bits=lt,g;t.adler=a.check=1,a.mode=O;case O:if(e===_||e===u)break t;case D:if(a.last){ot>>>=7&lt,lt-=7&lt,a.mode=X;break}for(;lt<3;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}switch(a.last=1&ot,lt-=1,3&(ot>>>=1)){case 0:a.mode=I;break;case 1:if(_t(a),a.mode=j,e===u){ot>>>=2,lt-=2;break t}break;case 2:a.mode=F;break;case 3:t.msg="invalid block type",a.mode=Q}ot>>>=2,lt-=2;break;case I:for(ot>>>=7&lt,lt-=7&lt;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if((65535&ot)!=(ot>>>16^65535)){t.msg="invalid stored block lengths",a.mode=Q;break}if(a.length=65535&ot,ot=0,lt=0,a.mode=U,e===u)break t;case U:a.mode=T;case T:if(ft=a.length){if(ft>rt&&(ft=rt),ft>st&&(ft=st),0===ft)break t;i.arraySet(et,tt,at,ft,nt),rt-=ft,at+=ft,st-=ft,nt+=ft,a.length-=ft;break}a.mode=O;break;case F:for(;lt<14;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(a.nlen=257+(31&ot),ot>>>=5,lt-=5,a.ndist=1+(31&ot),ot>>>=5,lt-=5,a.ncode=4+(15&ot),ot>>>=4,lt-=4,a.nlen>286||a.ndist>30){t.msg="too many length or distance symbols",a.mode=Q;break}a.have=0,a.mode=L;case L:for(;a.have<a.ncode;){for(;lt<3;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.lens[At[a.have++]]=7&ot,ot>>>=3,lt-=3}for(;a.have<19;)a.lens[At[a.have++]]=0;if(a.lencode=a.lendyn,a.lenbits=7,zt={bits:a.lenbits},xt=o(l,a.lens,0,19,a.lencode,0,a.work,zt),a.lenbits=zt.bits,xt){t.msg="invalid code lengths set",a.mode=Q;break}a.have=0,a.mode=H;case H:for(;a.have<a.nlen+a.ndist;){for(;mt=(St=a.lencode[ot&(1<<a.lenbits)-1])>>>16&255,wt=65535&St,!((gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(wt<16)ot>>>=gt,lt-=gt,a.lens[a.have++]=wt;else{if(16===wt){for(Bt=gt+2;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(ot>>>=gt,lt-=gt,0===a.have){t.msg="invalid bit length repeat",a.mode=Q;break}yt=a.lens[a.have-1],ft=3+(3&ot),ot>>>=2,lt-=2}else if(17===wt){for(Bt=gt+3;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}lt-=gt,yt=0,ft=3+(7&(ot>>>=gt)),ot>>>=3,lt-=3}else{for(Bt=gt+7;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}lt-=gt,yt=0,ft=11+(127&(ot>>>=gt)),ot>>>=7,lt-=7}if(a.have+ft>a.nlen+a.ndist){t.msg="invalid bit length repeat",a.mode=Q;break}for(;ft--;)a.lens[a.have++]=yt}}if(a.mode===Q)break;if(0===a.lens[256]){t.msg="invalid code -- missing 
end-of-block",a.mode=Q;break}if(a.lenbits=9,zt={bits:a.lenbits},xt=o(h,a.lens,0,a.nlen,a.lencode,0,a.work,zt),a.lenbits=zt.bits,xt){t.msg="invalid literal/lengths set",a.mode=Q;break}if(a.distbits=6,a.distcode=a.distdyn,zt={bits:a.distbits},xt=o(d,a.lens,a.nlen,a.ndist,a.distcode,0,a.work,zt),a.distbits=zt.bits,xt){t.msg="invalid distances set",a.mode=Q;break}if(a.mode=j,e===u)break t;case j:a.mode=K;case K:if(rt>=6&&st>=258){t.next_out=nt,t.avail_out=st,t.next_in=at,t.avail_in=rt,a.hold=ot,a.bits=lt,s(t,dt),nt=t.next_out,et=t.output,st=t.avail_out,at=t.next_in,tt=t.input,rt=t.avail_in,ot=a.hold,lt=a.bits,a.mode===O&&(a.back=-1);break}for(a.back=0;mt=(St=a.lencode[ot&(1<<a.lenbits)-1])>>>16&255,wt=65535&St,!((gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(mt&&0==(240&mt)){for(pt=gt,vt=mt,kt=wt;mt=(St=a.lencode[kt+((ot&(1<<pt+vt)-1)>>pt)])>>>16&255,wt=65535&St,!(pt+(gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}ot>>>=pt,lt-=pt,a.back+=pt}if(ot>>>=gt,lt-=gt,a.back+=gt,a.length=wt,0===mt){a.mode=G;break}if(32&mt){a.back=-1,a.mode=O;break}if(64&mt){t.msg="invalid literal/length code",a.mode=Q;break}a.extra=15&mt,a.mode=M;case M:if(a.extra){for(Bt=a.extra;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.length+=ot&(1<<a.extra)-1,ot>>>=a.extra,lt-=a.extra,a.back+=a.extra}a.was=a.length,a.mode=P;case P:for(;mt=(St=a.distcode[ot&(1<<a.distbits)-1])>>>16&255,wt=65535&St,!((gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(0==(240&mt)){for(pt=gt,vt=mt,kt=wt;mt=(St=a.distcode[kt+((ot&(1<<pt+vt)-1)>>pt)])>>>16&255,wt=65535&St,!(pt+(gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}ot>>>=pt,lt-=pt,a.back+=pt}if(ot>>>=gt,lt-=gt,a.back+=gt,64&mt){t.msg="invalid distance code",a.mode=Q;break}a.offset=wt,a.extra=15&mt,a.mode=Y;case Y:if(a.extra){for(Bt=a.extra;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.offset+=ot&(1<<a.extra)-1,ot>>>=a.extra,lt-=a.extra,a.back+=a.extra}if(a.offset>a.dmax){t.msg="invalid distance too far back",a.mode=Q;break}a.mode=q;case q:if(0===st)break t;if(ft=dt-st,a.offset>ft){if((ft=a.offset-ft)>a.whave&&a.sane){t.msg="invalid distance too far back",a.mode=Q;break}ft>a.wnext?(ft-=a.wnext,ct=a.wsize-ft):ct=a.wnext-ft,ft>a.length&&(ft=a.length),bt=a.window}else bt=et,ct=nt-a.offset,ft=a.length;ft>st&&(ft=st),st-=ft,a.length-=ft;do{et[nt++]=bt[ct++]}while(--ft);0===a.length&&(a.mode=K);break;case G:if(0===st)break t;et[nt++]=a.length,st--,a.mode=K;break;case X:if(a.wrap){for(;lt<32;){if(0===rt)break t;rt--,ot|=tt[at++]<<lt,lt+=8}if(dt-=st,t.total_out+=dt,a.total+=dt,dt&&(t.adler=a.check=a.flags?r(a.check,et,dt,nt-dt):n(a.check,et,dt,nt-dt)),dt=st,(a.flags?ot:it(ot))!==a.check){t.msg="incorrect data check",a.mode=Q;break}ot=0,lt=0}a.mode=W;case W:if(a.wrap&&a.flags){for(;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(ot!==(4294967295&a.total)){t.msg="incorrect length check",a.mode=Q;break}ot=0,lt=0}a.mode=J;case J:xt=b;break t;case Q:xt=w;break t;case V:return p;case $:default:return m}return 
t.next_out=nt,t.avail_out=st,t.next_in=at,t.avail_in=rt,a.hold=ot,a.bits=lt,(a.wsize||dt!==t.avail_out&&a.mode<Q&&(a.mode<X||e!==f))&&ut(t,t.output,t.next_out,dt-t.avail_out)?(a.mode=V,p):(ht-=t.avail_in,dt-=t.avail_out,t.total_in+=ht,t.total_out+=dt,a.total+=dt,a.wrap&&dt&&(t.adler=a.check=a.flags?r(a.check,et,dt,t.next_out-dt):n(a.check,et,dt,t.next_out-dt)),t.data_type=a.bits+(a.last?64:0)+(a.mode===O?128:0)+(a.mode===j||a.mode===U?256:0),(0===ht&&0===dt||e===f)&&xt===c&&(xt=v),xt)},a.inflateEnd=function(t){if(!t||!t.state)return m;var e=t.state;return e.window&&(e.window=null),t.state=null,c},a.inflateGetHeader=function(t,e){var a;return t&&t.state?0==(2&(a=t.state).wrap)?m:(a.head=e,e.done=!1,c):m},a.inflateSetDictionary=function(t,e){var a,i=e.length;return t&&t.state?0!==(a=t.state).wrap&&a.mode!==N?m:a.mode===N&&n(1,e,i,0)!==a.check?w:ut(t,e,i,i)?(a.mode=V,p):(a.havedict=1,c):m},a.inflateInfo="pako inflate (from Nodeca project)"},{"../utils/common":3,"./adler32":5,"./crc32":7,"./inffast":10,"./inftrees":12}],12:[function(t,e,a){"use strict";var i=t("../utils/common"),n=[3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258,0,0],r=[16,16,16,16,16,16,16,16,17,17,17,17,18,18,18,18,19,19,19,19,20,20,20,20,21,21,21,21,16,72,78],s=[1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0],o=[16,16,16,16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,25,25,26,26,27,27,28,28,29,29,64,64];e.exports=function(t,e,a,l,h,d,f,_){var u,c,b,g,m,w,p,v,k,y=_.bits,x=0,z=0,B=0,S=0,E=0,A=0,Z=0,R=0,C=0,N=0,O=null,D=0,I=new i.Buf16(16),U=new i.Buf16(16),T=null,F=0;for(x=0;x<=15;x++)I[x]=0;for(z=0;z<l;z++)I[e[a+z]]++;for(E=y,S=15;S>=1&&0===I[S];S--);if(E>S&&(E=S),0===S)return h[d++]=20971520,h[d++]=20971520,_.bits=1,0;for(B=1;B<S&&0===I[B];B++);for(E<B&&(E=B),R=1,x=1;x<=15;x++)if(R<<=1,(R-=I[x])<0)return-1;if(R>0&&(0===t||1!==S))return-1;for(U[1]=0,x=1;x<15;x++)U[x+1]=U[x]+I[x];for(z=0;z<l;z++)0!==e[a+z]&&(f[U[e[a+z]]++]=z);if(0===t?(O=T=f,w=19):1===t?(O=n,D-=257,T=r,F-=257,w=256):(O=s,T=o,w=-1),N=0,z=0,x=B,m=d,A=E,Z=0,b=-1,g=(C=1<<E)-1,1===t&&C>852||2===t&&C>592)return 1;for(;;){p=x-Z,f[z]<w?(v=0,k=f[z]):f[z]>w?(v=T[F+f[z]],k=O[D+f[z]]):(v=96,k=0),u=1<<x-Z,B=c=1<<A;do{h[m+(N>>Z)+(c-=u)]=p<<24|v<<16|k|0}while(0!==c);for(u=1<<x-1;N&u;)u>>=1;if(0!==u?(N&=u-1,N+=u):N=0,z++,0==--I[x]){if(x===S)break;x=e[a+f[z]]}if(x>E&&(N&g)!==b){for(0===Z&&(Z=E),m+=B,R=1<<(A=x-Z);A+Z<S&&!((R-=I[A+Z])<=0);)A++,R<<=1;if(C+=1<<A,1===t&&C>852||2===t&&C>592)return 1;h[b=N&g]=E<<24|A<<16|m-d|0}}return 0!==N&&(h[m+N]=x-Z<<24|64<<16|0),_.bits=E,0}},{"../utils/common":3}],13:[function(t,e,a){"use strict";e.exports={2:"need dictionary",1:"stream end",0:"","-1":"file error","-2":"stream error","-3":"data error","-4":"insufficient memory","-5":"buffer error","-6":"incompatible version"}},{}],14:[function(t,e,a){"use strict";var i=t("../utils/common"),n=4,r=0,s=1,o=2;function l(t){for(var e=t.length;--e>=0;)t[e]=0}var h=0,d=1,f=2,_=29,u=256,c=u+1+_,b=30,g=19,m=2*c+1,w=15,p=16,v=7,k=256,y=16,x=17,z=18,B=[0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0],S=[0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13],E=[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,3,7],A=[16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15],Z=new Array(2*(c+2));l(Z);var R=new Array(2*b);l(R);var C=new Array(512);l(C);var N=new Array(256);l(N);var O=new Array(_);l(O);var D,I,U,T=new Array(b);function 
F(t,e,a,i,n){this.static_tree=t,this.extra_bits=e,this.extra_base=a,this.elems=i,this.max_length=n,this.has_stree=t&&t.length}function L(t,e){this.dyn_tree=t,this.max_code=0,this.stat_desc=e}function H(t){return t<256?C[t]:C[256+(t>>>7)]}function j(t,e){t.pending_buf[t.pending++]=255&e,t.pending_buf[t.pending++]=e>>>8&255}function K(t,e,a){t.bi_valid>p-a?(t.bi_buf|=e<<t.bi_valid&65535,j(t,t.bi_buf),t.bi_buf=e>>p-t.bi_valid,t.bi_valid+=a-p):(t.bi_buf|=e<<t.bi_valid&65535,t.bi_valid+=a)}function M(t,e,a){K(t,a[2*e],a[2*e+1])}function P(t,e){var a=0;do{a|=1&t,t>>>=1,a<<=1}while(--e>0);return a>>>1}function Y(t,e,a){var i,n,r=new Array(w+1),s=0;for(i=1;i<=w;i++)r[i]=s=s+a[i-1]<<1;for(n=0;n<=e;n++){var o=t[2*n+1];0!==o&&(t[2*n]=P(r[o]++,o))}}function q(t){var e;for(e=0;e<c;e++)t.dyn_ltree[2*e]=0;for(e=0;e<b;e++)t.dyn_dtree[2*e]=0;for(e=0;e<g;e++)t.bl_tree[2*e]=0;t.dyn_ltree[2*k]=1,t.opt_len=t.static_len=0,t.last_lit=t.matches=0}function G(t){t.bi_valid>8?j(t,t.bi_buf):t.bi_valid>0&&(t.pending_buf[t.pending++]=t.bi_buf),t.bi_buf=0,t.bi_valid=0}function X(t,e,a,i){var n=2*e,r=2*a;return t[n]<t[r]||t[n]===t[r]&&i[e]<=i[a]}function W(t,e,a){for(var i=t.heap[a],n=a<<1;n<=t.heap_len&&(n<t.heap_len&&X(e,t.heap[n+1],t.heap[n],t.depth)&&n++,!X(e,i,t.heap[n],t.depth));)t.heap[a]=t.heap[n],a=n,n<<=1;t.heap[a]=i}function J(t,e,a){var i,n,r,s,o=0;if(0!==t.last_lit)do{i=t.pending_buf[t.d_buf+2*o]<<8|t.pending_buf[t.d_buf+2*o+1],n=t.pending_buf[t.l_buf+o],o++,0===i?M(t,n,e):(M(t,(r=N[n])+u+1,e),0!==(s=B[r])&&K(t,n-=O[r],s),M(t,r=H(--i),a),0!==(s=S[r])&&K(t,i-=T[r],s))}while(o<t.last_lit);M(t,k,e)}function Q(t,e){var a,i,n,r=e.dyn_tree,s=e.stat_desc.static_tree,o=e.stat_desc.has_stree,l=e.stat_desc.elems,h=-1;for(t.heap_len=0,t.heap_max=m,a=0;a<l;a++)0!==r[2*a]?(t.heap[++t.heap_len]=h=a,t.depth[a]=0):r[2*a+1]=0;for(;t.heap_len<2;)r[2*(n=t.heap[++t.heap_len]=h<2?++h:0)]=1,t.depth[n]=0,t.opt_len--,o&&(t.static_len-=s[2*n+1]);for(e.max_code=h,a=t.heap_len>>1;a>=1;a--)W(t,r,a);n=l;do{a=t.heap[1],t.heap[1]=t.heap[t.heap_len--],W(t,r,1),i=t.heap[1],t.heap[--t.heap_max]=a,t.heap[--t.heap_max]=i,r[2*n]=r[2*a]+r[2*i],t.depth[n]=(t.depth[a]>=t.depth[i]?t.depth[a]:t.depth[i])+1,r[2*a+1]=r[2*i+1]=n,t.heap[1]=n++,W(t,r,1)}while(t.heap_len>=2);t.heap[--t.heap_max]=t.heap[1],function(t,e){var a,i,n,r,s,o,l=e.dyn_tree,h=e.max_code,d=e.stat_desc.static_tree,f=e.stat_desc.has_stree,_=e.stat_desc.extra_bits,u=e.stat_desc.extra_base,c=e.stat_desc.max_length,b=0;for(r=0;r<=w;r++)t.bl_count[r]=0;for(l[2*t.heap[t.heap_max]+1]=0,a=t.heap_max+1;a<m;a++)(r=l[2*l[2*(i=t.heap[a])+1]+1]+1)>c&&(r=c,b++),l[2*i+1]=r,i>h||(t.bl_count[r]++,s=0,i>=u&&(s=_[i-u]),o=l[2*i],t.opt_len+=o*(r+s),f&&(t.static_len+=o*(d[2*i+1]+s)));if(0!==b){do{for(r=c-1;0===t.bl_count[r];)r--;t.bl_count[r]--,t.bl_count[r+1]+=2,t.bl_count[c]--,b-=2}while(b>0);for(r=c;0!==r;r--)for(i=t.bl_count[r];0!==i;)(n=t.heap[--a])>h||(l[2*n+1]!==r&&(t.opt_len+=(r-l[2*n+1])*l[2*n],l[2*n+1]=r),i--)}}(t,e),Y(r,h,t.bl_count)}function V(t,e,a){var i,n,r=-1,s=e[1],o=0,l=7,h=4;for(0===s&&(l=138,h=3),e[2*(a+1)+1]=65535,i=0;i<=a;i++)n=s,s=e[2*(i+1)+1],++o<l&&n===s||(o<h?t.bl_tree[2*n]+=o:0!==n?(n!==r&&t.bl_tree[2*n]++,t.bl_tree[2*y]++):o<=10?t.bl_tree[2*x]++:t.bl_tree[2*z]++,o=0,r=n,0===s?(l=138,h=3):n===s?(l=6,h=3):(l=7,h=4))}function $(t,e,a){var i,n,r=-1,s=e[1],o=0,l=7,h=4;for(0===s&&(l=138,h=3),i=0;i<=a;i++)if(n=s,s=e[2*(i+1)+1],!(++o<l&&n===s)){if(o<h)do{M(t,n,t.bl_tree)}while(0!=--o);else 
0!==n?(n!==r&&(M(t,n,t.bl_tree),o--),M(t,y,t.bl_tree),K(t,o-3,2)):o<=10?(M(t,x,t.bl_tree),K(t,o-3,3)):(M(t,z,t.bl_tree),K(t,o-11,7));o=0,r=n,0===s?(l=138,h=3):n===s?(l=6,h=3):(l=7,h=4)}}l(T);var tt=!1;function et(t,e,a,n){K(t,(h<<1)+(n?1:0),3),function(t,e,a,n){G(t),n&&(j(t,a),j(t,~a)),i.arraySet(t.pending_buf,t.window,e,a,t.pending),t.pending+=a}(t,e,a,!0)}a._tr_init=function(t){tt||(function(){var t,e,a,i,n,r=new Array(w+1);for(a=0,i=0;i<_-1;i++)for(O[i]=a,t=0;t<1<<B[i];t++)N[a++]=i;for(N[a-1]=i,n=0,i=0;i<16;i++)for(T[i]=n,t=0;t<1<<S[i];t++)C[n++]=i;for(n>>=7;i<b;i++)for(T[i]=n<<7,t=0;t<1<<S[i]-7;t++)C[256+n++]=i;for(e=0;e<=w;e++)r[e]=0;for(t=0;t<=143;)Z[2*t+1]=8,t++,r[8]++;for(;t<=255;)Z[2*t+1]=9,t++,r[9]++;for(;t<=279;)Z[2*t+1]=7,t++,r[7]++;for(;t<=287;)Z[2*t+1]=8,t++,r[8]++;for(Y(Z,c+1,r),t=0;t<b;t++)R[2*t+1]=5,R[2*t]=P(t,5);D=new F(Z,B,u+1,c,w),I=new F(R,S,0,b,w),U=new F(new Array(0),E,0,g,v)}(),tt=!0),t.l_desc=new L(t.dyn_ltree,D),t.d_desc=new L(t.dyn_dtree,I),t.bl_desc=new L(t.bl_tree,U),t.bi_buf=0,t.bi_valid=0,q(t)},a._tr_stored_block=et,a._tr_flush_block=function(t,e,a,i){var l,h,_=0;t.level>0?(t.strm.data_type===o&&(t.strm.data_type=function(t){var e,a=4093624447;for(e=0;e<=31;e++,a>>>=1)if(1&a&&0!==t.dyn_ltree[2*e])return r;if(0!==t.dyn_ltree[18]||0!==t.dyn_ltree[20]||0!==t.dyn_ltree[26])return s;for(e=32;e<u;e++)if(0!==t.dyn_ltree[2*e])return s;return r}(t)),Q(t,t.l_desc),Q(t,t.d_desc),_=function(t){var e;for(V(t,t.dyn_ltree,t.l_desc.max_code),V(t,t.dyn_dtree,t.d_desc.max_code),Q(t,t.bl_desc),e=g-1;e>=3&&0===t.bl_tree[2*A[e]+1];e--);return t.opt_len+=3*(e+1)+5+5+4,e}(t),l=t.opt_len+3+7>>>3,(h=t.static_len+3+7>>>3)<=l&&(l=h)):l=h=a+5,a+4<=l&&-1!==e?et(t,e,a,i):t.strategy===n||h===l?(K(t,(d<<1)+(i?1:0),3),J(t,Z,R)):(K(t,(f<<1)+(i?1:0),3),function(t,e,a,i){var n;for(K(t,e-257,5),K(t,a-1,5),K(t,i-4,4),n=0;n<i;n++)K(t,t.bl_tree[2*A[n]+1],3);$(t,t.dyn_ltree,e-1),$(t,t.dyn_dtree,a-1)}(t,t.l_desc.max_code+1,t.d_desc.max_code+1,_+1),J(t,t.dyn_ltree,t.dyn_dtree)),q(t),i&&G(t)},a._tr_tally=function(t,e,a){return t.pending_buf[t.d_buf+2*t.last_lit]=e>>>8&255,t.pending_buf[t.d_buf+2*t.last_lit+1]=255&e,t.pending_buf[t.l_buf+t.last_lit]=255&a,t.last_lit++,0===e?t.dyn_ltree[2*a]++:(t.matches++,e--,t.dyn_ltree[2*(N[a]+u+1)]++,t.dyn_dtree[2*H(e)]++),t.last_lit===t.lit_bufsize-1},a._tr_align=function(t){K(t,d<<1,3),M(t,k,Z),function(t){16===t.bi_valid?(j(t,t.bi_buf),t.bi_buf=0,t.bi_valid=0):t.bi_valid>=8&&(t.pending_buf[t.pending++]=255&t.bi_buf,t.bi_buf>>=8,t.bi_valid-=8)}(t)}},{"../utils/common":3}],15:[function(t,e,a){"use strict";e.exports=function(){this.input=null,this.next_in=0,this.avail_in=0,this.total_in=0,this.output=null,this.next_out=0,this.avail_out=0,this.total_out=0,this.msg="",this.state=null,this.data_type=2,this.adler=0}},{}],"/":[function(t,e,a){"use strict";var i={};(0,t("./lib/utils/common").assign)(i,t("./lib/deflate"),t("./lib/inflate"),t("./lib/zlib/constants")),e.exports=i},{"./lib/deflate":1,"./lib/inflate":2,"./lib/utils/common":3,"./lib/zlib/constants":6}]},{},[])("/")});
+ </script>
+ <script>
+ !function(){var e={};"object"==typeof module?module.exports=e:window.UPNG=e,function(e,r){e.toRGBA8=function(r){var t=r.width,n=r.height;if(null==r.tabs.acTL)return[e.toRGBA8.decodeImage(r.data,t,n,r).buffer];var i=[];null==r.frames[0].data&&(r.frames[0].data=r.data);for(var a,f=new Uint8Array(t*n*4),o=0;o<r.frames.length;o++){var s=r.frames[o],l=s.rect.x,c=s.rect.y,u=s.rect.width,d=s.rect.height,h=e.toRGBA8.decodeImage(s.data,u,d,r);if(0==o?a=h:0==s.blend?e._copyTile(h,u,d,a,t,n,l,c,0):1==s.blend&&e._copyTile(h,u,d,a,t,n,l,c,1),i.push(a.buffer),a=a.slice(0),0==s.dispose);else if(1==s.dispose)e._copyTile(f,u,d,a,t,n,l,c,0);else if(2==s.dispose){for(var v=o-1;2==r.frames[v].dispose;)v--;a=new Uint8Array(i[v]).slice(0)}}return i},e.toRGBA8.decodeImage=function(r,t,n,i){var a=t*n,f=e.decode._getBPP(i),o=Math.ceil(t*f/8),s=new Uint8Array(4*a),l=new Uint32Array(s.buffer),c=i.ctype,u=i.depth,d=e._bin.readUshort;if(6==c){var h=a<<2;if(8==u)for(var v=0;v<h;v++)s[v]=r[v];if(16==u)for(v=0;v<h;v++)s[v]=r[v<<1]}else if(2==c){var p=i.tabs.tRNS,b=-1,g=-1,m=-1;if(p&&(b=p[0],g=p[1],m=p[2]),8==u)for(v=0;v<a;v++){var y=3*v;s[M=v<<2]=r[y],s[M+1]=r[y+1],s[M+2]=r[y+2],s[M+3]=255,-1!=b&&r[y]==b&&r[y+1]==g&&r[y+2]==m&&(s[M+3]=0)}if(16==u)for(v=0;v<a;v++){y=6*v;s[M=v<<2]=r[y],s[M+1]=r[y+2],s[M+2]=r[y+4],s[M+3]=255,-1!=b&&d(r,y)==b&&d(r,y+2)==g&&d(r,y+4)==m&&(s[M+3]=0)}}else if(3==c){var w=i.tabs.PLTE,A=i.tabs.tRNS,U=A?A.length:0;if(1==u)for(var _=0;_<n;_++){var q=_*o,I=_*t;for(v=0;v<t;v++){var M=I+v<<2,T=3*(z=r[q+(v>>3)]>>7-((7&v)<<0)&1);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}}if(2==u)for(_=0;_<n;_++)for(q=_*o,I=_*t,v=0;v<t;v++){M=I+v<<2,T=3*(z=r[q+(v>>2)]>>6-((3&v)<<1)&3);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}if(4==u)for(_=0;_<n;_++)for(q=_*o,I=_*t,v=0;v<t;v++){M=I+v<<2,T=3*(z=r[q+(v>>1)]>>4-((1&v)<<2)&15);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}if(8==u)for(v=0;v<a;v++){var z;M=v<<2,T=3*(z=r[v]);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}}else if(4==c){if(8==u)for(v=0;v<a;v++){M=v<<2;var R=r[N=v<<1];s[M]=R,s[M+1]=R,s[M+2]=R,s[M+3]=r[N+1]}if(16==u)for(v=0;v<a;v++){var N;M=v<<2,R=r[N=v<<2];s[M]=R,s[M+1]=R,s[M+2]=R,s[M+3]=r[N+2]}}else if(0==c){b=i.tabs.tRNS?i.tabs.tRNS:-1;if(1==u)for(v=0;v<a;v++){var L=(R=255*(r[v>>3]>>7-(7&v)&1))==255*b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(2==u)for(v=0;v<a;v++){L=(R=85*(r[v>>2]>>6-((3&v)<<1)&3))==85*b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(4==u)for(v=0;v<a;v++){L=(R=17*(r[v>>1]>>4-((1&v)<<2)&15))==17*b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(8==u)for(v=0;v<a;v++){L=(R=r[v])==b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(16==u)for(v=0;v<a;v++){R=r[v<<1],L=d(r,v<<1)==b?0:255;l[v]=L<<24|R<<16|R<<8|R}}return s},e.decode=function(r){for(var t,n=new Uint8Array(r),i=8,a=e._bin,f=a.readUshort,o=a.readUint,s={tabs:{},frames:[]},l=new Uint8Array(n.length),c=0,u=0,d=[137,80,78,71,13,10,26,10],h=0;h<8;h++)if(n[h]!=d[h])throw"The input is not a PNG file!";for(;i<n.length;){var v=a.readUint(n,i);i+=4;var p=a.readASCII(n,i,4);if(i+=4,"IHDR"==p)e.decode._IHDR(n,i,s);else if("IDAT"==p){for(h=0;h<v;h++)l[c+h]=n[i+h];c+=v}else if("acTL"==p)s.tabs[p]={num_frames:o(n,i),num_plays:o(n,i+4)},t=new Uint8Array(n.length);else if("fcTL"==p){var b;if(0!=u)(b=s.frames[s.frames.length-1]).data=e.decode._decompress(s,t.slice(0,u),b.rect.width,b.rect.height),u=0;var g={x:o(n,i+12),y:o(n,i+16),width:o(n,i+4),height:o(n,i+8)},m=f(n,i+22);m=f(n,i+20)/(0==m?100:m);var y={rect:g,delay:Math.round(1e3*m),dispose:n[i+24],blend:n[i+25]};s.frames.push(y)}else 
if("fdAT"==p){for(h=0;h<v-4;h++)t[u+h]=n[i+h+4];u+=v-4}else if("pHYs"==p)s.tabs[p]=[a.readUint(n,i),a.readUint(n,i+4),n[i+8]];else if("cHRM"==p){s.tabs[p]=[];for(h=0;h<8;h++)s.tabs[p].push(a.readUint(n,i+4*h))}else if("tEXt"==p){null==s.tabs[p]&&(s.tabs[p]={});var w=a.nextZero(n,i),A=a.readASCII(n,i,w-i),U=a.readASCII(n,w+1,i+v-w-1);s.tabs[p][A]=U}else if("iTXt"==p){null==s.tabs[p]&&(s.tabs[p]={});w=0;var _=i;w=a.nextZero(n,_);A=a.readASCII(n,_,w-_),n[_=w+1],n[_+1];_+=2,w=a.nextZero(n,_);a.readASCII(n,_,w-_);_=w+1,w=a.nextZero(n,_);a.readUTF8(n,_,w-_);_=w+1;U=a.readUTF8(n,_,v-(_-i));s.tabs[p][A]=U}else if("PLTE"==p)s.tabs[p]=a.readBytes(n,i,v);else if("hIST"==p){var q=s.tabs.PLTE.length/3;s.tabs[p]=[];for(h=0;h<q;h++)s.tabs[p].push(f(n,i+2*h))}else if("tRNS"==p)3==s.ctype?s.tabs[p]=a.readBytes(n,i,v):0==s.ctype?s.tabs[p]=f(n,i):2==s.ctype&&(s.tabs[p]=[f(n,i),f(n,i+2),f(n,i+4)]);else if("gAMA"==p)s.tabs[p]=a.readUint(n,i)/1e5;else if("sRGB"==p)s.tabs[p]=n[i];else if("bKGD"==p)0==s.ctype||4==s.ctype?s.tabs[p]=[f(n,i)]:2==s.ctype||6==s.ctype?s.tabs[p]=[f(n,i),f(n,i+2),f(n,i+4)]:3==s.ctype&&(s.tabs[p]=n[i]);else if("IEND"==p)break;i+=v;a.readUint(n,i);i+=4}0!=u&&((b=s.frames[s.frames.length-1]).data=e.decode._decompress(s,t.slice(0,u),b.rect.width,b.rect.height),u=0);return s.data=e.decode._decompress(s,l,s.width,s.height),delete s.compress,delete s.interlace,delete s.filter,s},e.decode._decompress=function(r,t,n,i){return 0==r.compress&&(t=e.decode._inflate(t)),0==r.interlace?t=e.decode._filterZero(t,r,0,n,i):1==r.interlace&&(t=e.decode._readInterlace(t,r)),t},e.decode._inflate=function(e){return r.inflate(e)},e.decode._readInterlace=function(r,t){for(var n=t.width,i=t.height,a=e.decode._getBPP(t),f=a>>3,o=Math.ceil(n*a/8),s=new Uint8Array(i*o),l=0,c=[0,0,4,0,2,0,1],u=[0,4,0,2,0,1,0],d=[8,8,8,4,4,2,2],h=[8,8,4,4,2,2,1],v=0;v<7;){for(var p=d[v],b=h[v],g=0,m=0,y=c[v];y<i;)y+=p,m++;for(var w=u[v];w<n;)w+=b,g++;var A=Math.ceil(g*a/8);e.decode._filterZero(r,t,l,g,m);for(var U=0,_=c[v];_<i;){for(var q=u[v],I=l+U*A<<3;q<n;){var M;if(1==a)M=(M=r[I>>3])>>7-(7&I)&1,s[_*o+(q>>3)]|=M<<7-((3&q)<<0);if(2==a)M=(M=r[I>>3])>>6-(7&I)&3,s[_*o+(q>>2)]|=M<<6-((3&q)<<1);if(4==a)M=(M=r[I>>3])>>4-(7&I)&15,s[_*o+(q>>1)]|=M<<4-((1&q)<<2);if(a>=8)for(var T=_*o+q*f,z=0;z<f;z++)s[T+z]=r[(I>>3)+z];I+=a,q+=b}U++,_+=p}g*m!=0&&(l+=m*(1+A)),v+=1}return s},e.decode._getBPP=function(e){return[1,null,3,1,2,null,4][e.ctype]*e.depth},e.decode._filterZero=function(r,t,n,i,a){var f=e.decode._getBPP(t),o=Math.ceil(i*f/8),s=e.decode._paeth;f=Math.ceil(f/8);for(var l=0;l<a;l++){var c=n+l*o,u=c+l+1,d=r[u-1];if(0==d)for(var h=0;h<o;h++)r[c+h]=r[u+h];else if(1==d){for(h=0;h<f;h++)r[c+h]=r[u+h];for(h=f;h<o;h++)r[c+h]=r[u+h]+r[c+h-f]&255}else if(0==l){for(h=0;h<f;h++)r[c+h]=r[u+h];if(2==d)for(h=f;h<o;h++)r[c+h]=255&r[u+h];if(3==d)for(h=f;h<o;h++)r[c+h]=r[u+h]+(r[c+h-f]>>1)&255;if(4==d)for(h=f;h<o;h++)r[c+h]=r[u+h]+s(r[c+h-f],0,0)&255}else{if(2==d)for(h=0;h<o;h++)r[c+h]=r[u+h]+r[c+h-o]&255;if(3==d){for(h=0;h<f;h++)r[c+h]=r[u+h]+(r[c+h-o]>>1)&255;for(h=f;h<o;h++)r[c+h]=r[u+h]+(r[c+h-o]+r[c+h-f]>>1)&255}if(4==d){for(h=0;h<f;h++)r[c+h]=r[u+h]+s(0,r[c+h-o],0)&255;for(h=f;h<o;h++)r[c+h]=r[u+h]+s(r[c+h-f],r[c+h-o],r[c+h-f-o])&255}}}return r},e.decode._paeth=function(e,r,t){var n=e+r-t,i=Math.abs(n-e),a=Math.abs(n-r),f=Math.abs(n-t);return i<=a&&i<=f?e:a<=f?r:t},e.decode._IHDR=function(r,t,n){var 
i=e._bin;n.width=i.readUint(r,t),t+=4,n.height=i.readUint(r,t),t+=4,n.depth=r[t],t++,n.ctype=r[t],t++,n.compress=r[t],t++,n.filter=r[t],t++,n.interlace=r[t],t++},e._bin={nextZero:function(e,r){for(;0!=e[r];)r++;return r},readUshort:function(e,r){return e[r]<<8|e[r+1]},writeUshort:function(e,r,t){e[r]=t>>8&255,e[r+1]=255&t},readUint:function(e,r){return 16777216*e[r]+(e[r+1]<<16|e[r+2]<<8|e[r+3])},writeUint:function(e,r,t){e[r]=t>>24&255,e[r+1]=t>>16&255,e[r+2]=t>>8&255,e[r+3]=255&t},readASCII:function(e,r,t){for(var n="",i=0;i<t;i++)n+=String.fromCharCode(e[r+i]);return n},writeASCII:function(e,r,t){for(var n=0;n<t.length;n++)e[r+n]=t.charCodeAt(n)},readBytes:function(e,r,t){for(var n=[],i=0;i<t;i++)n.push(e[r+i]);return n},pad:function(e){return e.length<2?"0"+e:e},readUTF8:function(r,t,n){for(var i,a="",f=0;f<n;f++)a+="%"+e._bin.pad(r[t+f].toString(16));try{i=decodeURIComponent(a)}catch(i){return e._bin.readASCII(r,t,n)}return i}},e._copyTile=function(e,r,t,n,i,a,f,o,s){for(var l=Math.min(r,i),c=Math.min(t,a),u=0,d=0,h=0;h<c;h++)for(var v=0;v<l;v++)if(f>=0&&o>=0?(u=h*r+v<<2,d=(o+h)*i+f+v<<2):(u=(-o+h)*r-f+v<<2,d=h*i+v<<2),0==s)n[d]=e[u],n[d+1]=e[u+1],n[d+2]=e[u+2],n[d+3]=e[u+3];else if(1==s){var p=e[u+3]*(1/255),b=e[u]*p,g=e[u+1]*p,m=e[u+2]*p,y=n[d+3]*(1/255),w=n[d]*y,A=n[d+1]*y,U=n[d+2]*y,_=1-p,q=p+y*_,I=0==q?0:1/q;n[d+3]=255*q,n[d+0]=(b+w*_)*I,n[d+1]=(g+A*_)*I,n[d+2]=(m+U*_)*I}else if(2==s){p=e[u+3],b=e[u],g=e[u+1],m=e[u+2],y=n[d+3],w=n[d],A=n[d+1],U=n[d+2];p==y&&b==w&&g==A&&m==U?(n[d]=0,n[d+1]=0,n[d+2]=0,n[d+3]=0):(n[d]=b,n[d+1]=g,n[d+2]=m,n[d+3]=p)}else if(3==s){p=e[u+3],b=e[u],g=e[u+1],m=e[u+2],y=n[d+3],w=n[d],A=n[d+1],U=n[d+2];if(p==y&&b==w&&g==A&&m==U)continue;if(p<220&&y>20)return!1}return!0},e.encode=function(r,t,n,i,a,f){null==i&&(i=0),null==f&&(f=!1);var o=e.encode.compress(r,t,n,i,!1,f);return e.encode.compressPNG(o,-1),e.encode._main(o,t,n,a)},e.encodeLL=function(r,t,n,i,a,f,o){for(var s={ctype:0+(1==i?0:2)+(0==a?0:4),depth:f,frames:[]},l=(i+a)*f,c=l*t,u=0;u<r.length;u++)s.frames.push({rect:{x:0,y:0,width:t,height:n},img:new Uint8Array(r[u]),blend:0,dispose:1,bpp:Math.ceil(l/8),bpl:Math.ceil(c/8)});return e.encode.compressPNG(s,4),e.encode._main(s,t,n,o)},e.encode._main=function(r,t,n,i){var a=e.crc.crc,f=e._bin.writeUint,o=e._bin.writeUshort,s=e._bin.writeASCII,l=8,c=r.frames.length>1,u=!1,d=46+(c?20:0);if(3==r.ctype){for(var h=r.plte.length,v=0;v<h;v++)r.plte[v]>>>24!=255&&(u=!0);d+=8+3*h+4+(u?8+1*h+4:0)}for(var p=0;p<r.frames.length;p++){c&&(d+=38),d+=(q=r.frames[p]).cimg.length+12,0!=p&&(d+=4)}d+=12;var b=new Uint8Array(d),g=[137,80,78,71,13,10,26,10];for(v=0;v<8;v++)b[v]=g[v];if(f(b,l,13),s(b,l+=4,"IHDR"),f(b,l+=4,t),f(b,l+=4,n),b[l+=4]=r.depth,b[++l]=r.ctype,b[++l]=0,b[++l]=0,b[++l]=0,f(b,++l,a(b,l-17,17)),f(b,l+=4,1),s(b,l+=4,"sRGB"),b[l+=4]=1,f(b,++l,a(b,l-5,5)),l+=4,c&&(f(b,l,8),s(b,l+=4,"acTL"),f(b,l+=4,r.frames.length),f(b,l+=4,0),f(b,l+=4,a(b,l-12,12)),l+=4),3==r.ctype){f(b,l,3*(h=r.plte.length)),s(b,l+=4,"PLTE"),l+=4;for(v=0;v<h;v++){var m=3*v,y=r.plte[v],w=255&y,A=y>>>8&255,U=y>>>16&255;b[l+m+0]=w,b[l+m+1]=A,b[l+m+2]=U}if(f(b,l+=3*h,a(b,l-3*h-4,3*h+4)),l+=4,u){f(b,l,h),s(b,l+=4,"tRNS"),l+=4;for(v=0;v<h;v++)b[l+v]=r.plte[v]>>>24&255;f(b,l+=h,a(b,l-h-4,h+4)),l+=4}}var _=0;for(p=0;p<r.frames.length;p++){var q=r.frames[p];c&&(f(b,l,26),s(b,l+=4,"fcTL"),f(b,l+=4,_++),f(b,l+=4,q.rect.width),f(b,l+=4,q.rect.height),f(b,l+=4,q.rect.x),f(b,l+=4,q.rect.y),o(b,l+=4,i[p]),o(b,l+=2,1e3),b[l+=2]=q.dispose,b[++l]=q.blend,f(b,++l,a(b,l-30,30)),l+=4);var 
I=q.cimg;f(b,l,(h=I.length)+(0==p?0:4));var M=l+=4;s(b,l,0==p?"IDAT":"fdAT"),l+=4,0!=p&&(f(b,l,_++),l+=4);for(v=0;v<h;v++)b[l+v]=I[v];f(b,l+=h,a(b,M,l-M)),l+=4}return f(b,l,0),s(b,l+=4,"IEND"),f(b,l+=4,a(b,l-4,4)),l+=4,b.buffer},e.encode.compressPNG=function(r,t){for(var n=0;n<r.frames.length;n++){var i=r.frames[n],a=(i.rect.width,i.rect.height),f=new Uint8Array(a*i.bpl+a);i.cimg=e.encode._filterZero(i.img,a,i.bpp,i.bpl,f,t)}},e.encode.compress=function(r,t,n,i,a,f){null==f&&(f=!1);for(var o=6,s=8,l=255,c=0;c<r.length;c++)for(var u=new Uint8Array(r[c]),d=u.length,h=0;h<d;h+=4)l&=u[h+3];var v=255!=l,p=v&&a,b=e.encode.framize(r,t,n,a,p),g={},m=[],y=[];if(0!=i){var w=[];for(h=0;h<b.length;h++)w.push(b[h].img.buffer);var A=e.encode.concatRGBA(w,a),U=e.quantize(A,i),_=0,q=new Uint8Array(U.abuf);for(h=0;h<b.length;h++){var I=(F=b[h].img).length;y.push(new Uint8Array(U.inds.buffer,_>>2,I>>2));for(c=0;c<I;c+=4)F[c]=q[_+c],F[c+1]=q[_+c+1],F[c+2]=q[_+c+2],F[c+3]=q[_+c+3];_+=I}for(h=0;h<U.plte.length;h++)m.push(U.plte[h].est.rgba)}else for(c=0;c<b.length;c++){var M=b[c],T=new Uint32Array(M.img.buffer),z=M.rect.width,R=(d=T.length,new Uint8Array(d));y.push(R);for(h=0;h<d;h++){var N=T[h];if(0!=h&&N==T[h-1])R[h]=R[h-1];else if(h>z&&N==T[h-z])R[h]=R[h-z];else{var L=g[N];if(null==L&&(g[N]=L=m.length,m.push(N),m.length>=300))break;R[h]=L}}}var P=m.length;P<=256&&0==f&&(s=P<=2?1:P<=4?2:P<=16?4:8,a&&(s=8));for(c=0;c<b.length;c++){(M=b[c]).rect.x,M.rect.y,z=M.rect.width;var S=M.rect.height,D=M.img,B=(new Uint32Array(D.buffer),4*z),x=4;if(P<=256&&0==f){B=Math.ceil(s*z/8);for(var C=new Uint8Array(B*S),G=y[c],Z=0;Z<S;Z++){h=Z*B;var k=Z*z;if(8==s)for(var E=0;E<z;E++)C[h+E]=G[k+E];else if(4==s)for(E=0;E<z;E++)C[h+(E>>1)]|=G[k+E]<<4-4*(1&E);else if(2==s)for(E=0;E<z;E++)C[h+(E>>2)]|=G[k+E]<<6-2*(3&E);else if(1==s)for(E=0;E<z;E++)C[h+(E>>3)]|=G[k+E]<<7-1*(7&E)}D=C,o=3,x=1}else if(0==v&&1==b.length){C=new Uint8Array(z*S*3);var H=z*S;for(h=0;h<H;h++){var F,K=4*h;C[F=3*h]=D[K],C[F+1]=D[K+1],C[F+2]=D[K+2]}D=C,o=2,x=3,B=3*z}M.img=D,M.bpl=B,M.bpp=x}return{ctype:o,depth:s,plte:m,frames:b}},e.encode.framize=function(r,t,n,i,a){for(var f=[],o=0;o<r.length;o++){var s=new Uint8Array(r[o]),l=new Uint32Array(s.buffer),c=0,u=0,d=t,h=n,v=0;if(0==o||a)s=s.slice(0);else{for(var p=i||1==o||2==f[f.length-2].dispose?1:2,b=0,g=1e9,m=0;m<p;m++){for(var y=new Uint8Array(r[o-1-m]),w=new Uint32Array(r[o-1-m]),A=t,U=n,_=-1,q=-1,I=0;I<n;I++)for(var M=0;M<t;M++){var T=I*t+M;l[T]!=w[T]&&(M<A&&(A=M),M>_&&(_=M),I<U&&(U=I),I>q&&(q=I))}var z=-1==_?1:(_-A+1)*(q-U+1);z<g&&(g=z,b=m,-1==_?(c=u=0,d=h=1):(c=A,u=U,d=_-A+1,h=q-U+1))}y=new Uint8Array(r[o-1-b]);1==b&&(f[f.length-1].dispose=2);var R=new Uint8Array(d*h*4);new Uint32Array(R.buffer);e._copyTile(y,t,n,R,d,h,-c,-u,0),e._copyTile(s,t,n,R,d,h,-c,-u,3)?(e._copyTile(s,t,n,R,d,h,-c,-u,2),v=1):(e._copyTile(s,t,n,R,d,h,-c,-u,0),v=0),s=R}f.push({rect:{x:c,y:u,width:d,height:h},img:s,blend:v,dispose:a?1:0})}return f},e.encode._filterZero=function(t,n,i,a,f,o){if(-1!=o){for(var s=0;s<n;s++)e.encode._filterLine(f,t,s,a,i,o);return r.deflate(f)}for(var l=[],c=0;c<5;c++)if(!(n*a>5e5)||2!=c&&3!=c&&4!=c){for(s=0;s<n;s++)e.encode._filterLine(f,t,s,a,i,c);if(l.push(r.deflate(f)),1==i)break}for(var u,d=1e9,h=0;h<l.length;h++)l[h].length<d&&(u=h,d=l[h].length);return l[u]},e.encode._filterLine=function(r,t,n,i,a,f){var o=n*i,s=o+n,l=e.decode._paeth;if(r[s]=f,s++,0==f)for(var c=0;c<i;c++)r[s+c]=t[o+c];else if(1==f){for(c=0;c<a;c++)r[s+c]=t[o+c];for(c=a;c<i;c++)r[s+c]=t[o+c]-t[o+c-a]+256&255}else 
if(0==n){for(c=0;c<a;c++)r[s+c]=t[o+c];if(2==f)for(c=a;c<i;c++)r[s+c]=t[o+c];if(3==f)for(c=a;c<i;c++)r[s+c]=t[o+c]-(t[o+c-a]>>1)+256&255;if(4==f)for(c=a;c<i;c++)r[s+c]=t[o+c]-l(t[o+c-a],0,0)+256&255}else{if(2==f)for(c=0;c<i;c++)r[s+c]=t[o+c]+256-t[o+c-i]&255;if(3==f){for(c=0;c<a;c++)r[s+c]=t[o+c]+256-(t[o+c-i]>>1)&255;for(c=a;c<i;c++)r[s+c]=t[o+c]+256-(t[o+c-i]+t[o+c-a]>>1)&255}if(4==f){for(c=0;c<a;c++)r[s+c]=t[o+c]+256-l(0,t[o+c-i],0)&255;for(c=a;c<i;c++)r[s+c]=t[o+c]+256-l(t[o+c-a],t[o+c-i],t[o+c-a-i])&255}}},e.crc={table:function(){for(var e=new Uint32Array(256),r=0;r<256;r++){for(var t=r,n=0;n<8;n++)1&t?t=3988292384^t>>>1:t>>>=1;e[r]=t}return e}(),update:function(r,t,n,i){for(var a=0;a<i;a++)r=e.crc.table[255&(r^t[n+a])]^r>>>8;return r},crc:function(r,t,n){return 4294967295^e.crc.update(4294967295,r,t,n)}},e.quantize=function(r,t){for(var n=new Uint8Array(r),i=n.slice(0),a=new Uint32Array(i.buffer),f=e.quantize.getKDtree(i,t),o=f[0],s=f[1],l=(e.quantize.planeDst,n),c=a,u=l.length,d=new Uint8Array(n.length>>2),h=0;h<u;h+=4){var v=l[h]*(1/255),p=l[h+1]*(1/255),b=l[h+2]*(1/255),g=l[h+3]*(1/255),m=e.quantize.getNearest(o,v,p,b,g);d[h>>2]=m.ind,c[h>>2]=m.est.rgba}return{abuf:i.buffer,inds:d,plte:s}},e.quantize.getKDtree=function(r,t,n){null==n&&(n=1e-4);var i=new Uint32Array(r.buffer),a={i0:0,i1:r.length,bst:null,est:null,tdst:0,left:null,right:null};a.bst=e.quantize.stats(r,a.i0,a.i1),a.est=e.quantize.estats(a.bst);for(var f=[a];f.length<t;){for(var o=0,s=0,l=0;l<f.length;l++)f[l].est.L>o&&(o=f[l].est.L,s=l);if(o<n)break;var c=f[s],u=e.quantize.splitPixels(r,i,c.i0,c.i1,c.est.e,c.est.eMq255);if(c.i0>=u||c.i1<=u)c.est.L=0;else{var d={i0:c.i0,i1:u,bst:null,est:null,tdst:0,left:null,right:null};d.bst=e.quantize.stats(r,d.i0,d.i1),d.est=e.quantize.estats(d.bst);var h={i0:u,i1:c.i1,bst:null,est:null,tdst:0,left:null,right:null};h.bst={R:[],m:[],N:c.bst.N-d.bst.N};for(l=0;l<16;l++)h.bst.R[l]=c.bst.R[l]-d.bst.R[l];for(l=0;l<4;l++)h.bst.m[l]=c.bst.m[l]-d.bst.m[l];h.est=e.quantize.estats(h.bst),c.left=d,c.right=h,f[s]=d,f.push(h)}}f.sort(function(e,r){return r.bst.N-e.bst.N});for(l=0;l<f.length;l++)f[l].ind=l;return[a,f]},e.quantize.getNearest=function(r,t,n,i,a){if(null==r.left)return r.tdst=e.quantize.dist(r.est.q,t,n,i,a),r;var f=e.quantize.planeDst(r.est,t,n,i,a),o=r.left,s=r.right;f>0&&(o=r.right,s=r.left);var l=e.quantize.getNearest(o,t,n,i,a);if(l.tdst<=f*f)return l;var c=e.quantize.getNearest(s,t,n,i,a);return c.tdst<l.tdst?c:l},e.quantize.planeDst=function(e,r,t,n,i){var a=e.e;return a[0]*r+a[1]*t+a[2]*n+a[3]*i-e.eMq},e.quantize.dist=function(e,r,t,n,i){var a=r-e[0],f=t-e[1],o=n-e[2],s=i-e[3];return a*a+f*f+o*o+s*s},e.quantize.splitPixels=function(r,t,n,i,a,f){var o=e.quantize.vecDot;i-=4;for(;n<i;){for(;o(r,n,a)<=f;)n+=4;for(;o(r,i,a)>f;)i-=4;if(n>=i)break;var s=t[n>>2];t[n>>2]=t[i>>2],t[i>>2]=s,n+=4,i-=4}for(;o(r,n,a)>f;)n-=4;return n+4},e.quantize.vecDot=function(e,r,t){return e[r]*t[0]+e[r+1]*t[1]+e[r+2]*t[2]+e[r+3]*t[3]},e.quantize.stats=function(e,r,t){for(var n=[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],i=[0,0,0,0],a=t-r>>2,f=r;f<t;f+=4){var o=e[f]*(1/255),s=e[f+1]*(1/255),l=e[f+2]*(1/255),c=e[f+3]*(1/255);i[0]+=o,i[1]+=s,i[2]+=l,i[3]+=c,n[0]+=o*o,n[1]+=o*s,n[2]+=o*l,n[3]+=o*c,n[5]+=s*s,n[6]+=s*l,n[7]+=s*c,n[10]+=l*l,n[11]+=l*c,n[15]+=c*c}return n[4]=n[1],n[8]=n[2],n[9]=n[6],n[12]=n[3],n[13]=n[7],n[14]=n[11],{R:n,m:i,N:a}},e.quantize.estats=function(r){var 
t=r.R,n=r.m,i=r.N,a=n[0],f=n[1],o=n[2],s=n[3],l=0==i?0:1/i,c=[t[0]-a*a*l,t[1]-a*f*l,t[2]-a*o*l,t[3]-a*s*l,t[4]-f*a*l,t[5]-f*f*l,t[6]-f*o*l,t[7]-f*s*l,t[8]-o*a*l,t[9]-o*f*l,t[10]-o*o*l,t[11]-o*s*l,t[12]-s*a*l,t[13]-s*f*l,t[14]-s*o*l,t[15]-s*s*l],u=c,d=e.M4,h=[.5,.5,.5,.5],v=0,p=0;if(0!=i)for(var b=0;b<10&&(h=d.multVec(u,h),p=Math.sqrt(d.dot(h,h)),h=d.sml(1/p,h),!(Math.abs(p-v)<1e-9));b++)v=p;var g=[a*l,f*l,o*l,s*l];return{Cov:c,q:g,e:h,L:v,eMq255:d.dot(d.sml(255,g),h),eMq:d.dot(h,g),rgba:(Math.round(255*g[3])<<24|Math.round(255*g[2])<<16|Math.round(255*g[1])<<8|Math.round(255*g[0])<<0)>>>0}},e.M4={multVec:function(e,r){return[e[0]*r[0]+e[1]*r[1]+e[2]*r[2]+e[3]*r[3],e[4]*r[0]+e[5]*r[1]+e[6]*r[2]+e[7]*r[3],e[8]*r[0]+e[9]*r[1]+e[10]*r[2]+e[11]*r[3],e[12]*r[0]+e[13]*r[1]+e[14]*r[2]+e[15]*r[3]]},dot:function(e,r){return e[0]*r[0]+e[1]*r[1]+e[2]*r[2]+e[3]*r[3]},sml:function(e,r){return[e*r[0],e*r[1],e*r[2],e*r[3]]}},e.encode.concatRGBA=function(e,r){for(var t=0,n=0;n<e.length;n++)t+=e[n].byteLength;var i=new Uint8Array(t),a=0;for(n=0;n<e.length;n++){for(var f=new Uint8Array(e[n]),o=f.length,s=0;s<o;s+=4){var l=f[s],c=f[s+1],u=f[s+2],d=f[s+3];r&&(d=0==(128&d)?0:255),0==d&&(l=c=u=0),i[a+s]=l,i[a+s+1]=c,i[a+s+2]=u,i[a+s+3]=d}a+=o}return i.buffer}}(e,"function"==typeof require?require("pako"):window.pako)}();
+ </script>
+
+ <script>
+ class Player {
+
+   constructor(container) {
+     this.global_frac = 0.0
+     this.container = document.getElementById(container)
+     this.progress = null;
+     this.mat = [[]]
+
+     this.player = this.container.querySelector('audio')
+     this.demo_img = this.container.querySelector('.underlay > img')
+     this.overlay = this.container.querySelector('.overlay')
+     this.playpause = this.container.querySelector(".playpause");
+     this.download = this.container.querySelector(".download");
+     this.play_img = this.container.querySelector('.play-img')
+     this.pause_img = this.container.querySelector('.pause-img')
+     this.canvas = this.container.querySelector('.response-canvas')
+     this.response_container = this.container.querySelector('.response')
+     this.context = this.canvas.getContext('2d');
+
+     var togglePlayPause = () => {
+       if (this.player.networkState !== 1) {
+         return
+       }
+       if (this.player.paused || this.player.ended) {
+         this.play()
+       } else {
+         this.pause()
+       }
+     }
+
+     // Sync the progress overlay and level bars to the playhead.
+     this.update = () => {
+       this.global_frac = this.player.currentTime / this.player.duration
+       this.overlay.style.width = (100*(1.0 - this.global_frac)).toString() + '%'
+       this.redraw()
+     }
+
+     this.updateLoop = (timestamp) => {
+       this.update()
+       this.progress = window.requestAnimationFrame(this.updateLoop)
+     }
+
+     // Clicking the spectrogram seeks proportionally to the click position.
+     this.seek = (e) => {
+       this.global_frac = e.offsetX / this.demo_img.width
+       this.player.currentTime = this.global_frac * this.player.duration
+       this.overlay.style.width = (100*(1.0 - this.global_frac)).toString() + '%'
+       this.redraw()
+     }
+
+     var download_audio = () => {
+       var url = this.player.querySelector('#src').src
+       const a = document.createElement('a')
+       a.href = url
+       a.download = "download"
+       document.body.appendChild(a)
+       a.click()
+       document.body.removeChild(a)
+     }
+
+     this.demo_img.onclick = this.seek;
+     this.playpause.disabled = true
+     this.player.onplay = this.updateLoop
+     this.player.onpause = () => {
+       window.cancelAnimationFrame(this.progress)
+       this.update();
+     }
+     this.player.onended = () => {this.pause()}
+     this.playpause.onclick = togglePlayPause;
+     this.download.onclick = download_audio;
+   }
+
+   load(audio_fname, img_fname, levels_fname) {
+     this.pause()
+     window.cancelAnimationFrame(this.progress)
+     this.playpause.disabled = true
+
+     this.player.querySelector('#src').setAttribute("src", audio_fname)
+     this.player.load()
+     this.demo_img.setAttribute("src", img_fname)
+     this.overlay.style.width = '0%'
+
+     fetch(levels_fname)
+       .then(response => response.arrayBuffer())
+       .then(buffer => {
+         this.mat = this.parse(buffer);
+         this.playpause.disabled = false;
+         this.redraw();
+       })
+   }
+
+   // Decode the levels PNG into a [time][band] matrix of bar heights.
+   parse(buffer) {
+     var img = UPNG.decode(buffer)
+     var dat = UPNG.toRGBA8(img)[0]
+     var view = new DataView(dat)
+     var data = new Array(img.width).fill(0).map(() => new Array(img.height).fill(0));
+
+     // Each pixel's RGB magnitude becomes a level; track min/max for normalization.
+     var min = 100
+     var max = -100
+     var idx = 0
+     for (let i=0; i < img.height*img.width*4; i+=4) {
+       var rgba = [view.getUint8(i) / 255, view.getUint8(i + 1) / 255, view.getUint8(i + 2) / 255, view.getUint8(i + 3) / 255]
+       var norm = Math.pow(Math.pow(rgba[0], 2) + Math.pow(rgba[1], 2) + Math.pow(rgba[2], 2), 0.5)
+       data[idx % img.width][img.height - Math.floor(idx / img.width) - 1] = norm
+
+       idx += 1
+       min = Math.min(min, norm)
+       max = Math.max(max, norm)
+     }
+     // Normalize to [0, 1] and apply a 1.5 gamma.
+     for (let i = 0; i < data.length; i++) {
+       for (let j = 0; j < data[i].length; j++) {
+         data[i][j] = Math.pow((data[i][j] - min) / (max - min), 1.5)
+       }
+     }
+     // Smooth along time with a three-tap moving average.
+     var data3 = new Array(img.width).fill(0).map(() => new Array(img.height).fill(0));
+     for (let i = 0; i < data.length; i++) {
+       for (let j = 0; j < data[i].length; j++) {
+         if (i == 0 || i == (data.length - 1)) {
+           data3[i][j] = data[i][j]
+         } else {
+           data3[i][j] = 0.33*(data[i - 1][j]) + 0.33*(data[i][j]) + 0.33*(data[i + 1][j])
+         }
+       }
+     }
+
+     // Linearly interpolate 5x along time for smoother animation.
+     var scale = 5
+     var data2 = new Array(scale*img.width).fill(0).map(() => new Array(img.height).fill(0));
+     for (let j = 0; j < data[0].length; j++) {
+       for (let i = 0; i < data.length - 1; i++) {
+         for (let k = 0; k < scale; k++) {
+           data2[scale*i + k][j] = (1.0 - (k/scale))*data3[i][j] + (k / scale)*data3[i + 1][j]
+         }
+       }
+     }
+     return data2
+   }
+
+   play() {
+     this.player.play();
+     this.play_img.style.display = 'none'
+     this.pause_img.style.display = 'block'
+   }
+
+   pause() {
+     this.player.pause();
+     this.pause_img.style.display = 'none'
+     this.play_img.style.display = 'block'
+   }
+
+   // Repaint the level bars for the current timestep, scaled for device DPI.
+   redraw() {
+     this.canvas.width = window.devicePixelRatio*this.response_container.offsetWidth;
+     this.canvas.height = window.devicePixelRatio*this.response_container.offsetHeight;
+
+     this.context.clearRect(0, 0, this.canvas.width, this.canvas.height)
+     this.canvas.style.width = (this.canvas.width / window.devicePixelRatio).toString() + "px";
+     this.canvas.style.height = (this.canvas.height / window.devicePixelRatio).toString() + "px";
+
+     var f = this.global_frac*this.mat.length
+     var tstep = Math.min(Math.floor(f), this.mat.length - 2)
+     var heights = this.mat[tstep]
+     if (!heights) { return }  // nothing loaded yet
+     var bar_width = (this.canvas.width / heights.length) - 1
+
+     for (let k = 0; k < heights.length - 1; k++) {
+       var height = Math.max(Math.round((heights[k])*this.canvas.height), 3)
+       this.context.fillStyle = '#696f7b';
+       this.context.fillRect(k*(bar_width + 1), (this.canvas.height - height) / 2, bar_width, height);
+     }
+   }
+ }
+ </script>
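The levels file fetched by Player.load is an ordinary PNG; parse reads each pixel's RGB magnitude as a level. A minimal Python sketch of the same first stages, assuming Pillow and numpy (the file name is illustrative):

import numpy as np
from PIL import Image

# Per-pixel RGB magnitude, min-max normalized, with the same 1.5 gamma
# that Player.parse applies before smoothing and upsampling.
rgb = np.asarray(Image.open("levels.png").convert("RGB"), dtype=np.float64) / 255.0
norm = np.sqrt((rgb ** 2).sum(axis=-1))
norm = ((norm - norm.min()) / (norm.max() - norm.min())) ** 1.5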
dac-vae/audiotools/core/templates/pandoc.css ADDED
@@ -0,0 +1,407 @@
+ /*
+ Copyright (c) 2017 Chris Patuzzo
+ https://twitter.com/chrispatuzzo
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ */
+
+ body {
+   font-family: Helvetica, arial, sans-serif;
+   font-size: 14px;
+   line-height: 1.6;
+   padding-top: 10px;
+   padding-bottom: 10px;
+   background-color: white;
+   padding: 30px;
+   color: #333;
+ }
+
+ body > *:first-child {
+   margin-top: 0 !important;
+ }
+
+ body > *:last-child {
+   margin-bottom: 0 !important;
+ }
+
+ a {
+   color: #4183C4;
+   text-decoration: none;
+ }
+
+ a.absent {
+   color: #cc0000;
+ }
+
+ a.anchor {
+   display: block;
+   padding-left: 30px;
+   margin-left: -30px;
+   cursor: pointer;
+   position: absolute;
+   top: 0;
+   left: 0;
+   bottom: 0;
+ }
+
+ h1, h2, h3, h4, h5, h6 {
+   margin: 20px 0 10px;
+   padding: 0;
+   font-weight: bold;
+   -webkit-font-smoothing: antialiased;
+   cursor: text;
+   position: relative;
+ }
+
+ h2:first-child, h1:first-child, h1:first-child + h2, h3:first-child, h4:first-child, h5:first-child, h6:first-child {
+   margin-top: 0;
+   padding-top: 0;
+ }
+
+ h1:hover a.anchor, h2:hover a.anchor, h3:hover a.anchor, h4:hover a.anchor, h5:hover a.anchor, h6:hover a.anchor {
+   text-decoration: none;
+ }
+
+ h1 tt, h1 code {
+   font-size: inherit;
+ }
+
+ h2 tt, h2 code {
+   font-size: inherit;
+ }
+
+ h3 tt, h3 code {
+   font-size: inherit;
+ }
+
+ h4 tt, h4 code {
+   font-size: inherit;
+ }
+
+ h5 tt, h5 code {
+   font-size: inherit;
+ }
+
+ h6 tt, h6 code {
+   font-size: inherit;
+ }
+
+ h1 {
+   font-size: 28px;
+   color: black;
+ }
+
+ h2 {
+   font-size: 24px;
+   border-bottom: 1px solid #cccccc;
+   color: black;
+ }
+
+ h3 {
+   font-size: 18px;
+ }
+
+ h4 {
+   font-size: 16px;
+ }
+
+ h5 {
+   font-size: 14px;
+ }
+
+ h6 {
+   color: #777777;
+   font-size: 14px;
+ }
+
+ p, blockquote, ul, ol, dl, li, table, pre {
+   margin: 15px 0;
+ }
+
+ hr {
+   border: 0 none;
+   color: #cccccc;
+   height: 4px;
+   padding: 0;
+ }
+
+ body > h2:first-child {
+   margin-top: 0;
+   padding-top: 0;
+ }
+
+ body > h1:first-child {
+   margin-top: 0;
+   padding-top: 0;
+ }
+
+ body > h1:first-child + h2 {
+   margin-top: 0;
+   padding-top: 0;
+ }
+
+ body > h3:first-child, body > h4:first-child, body > h5:first-child, body > h6:first-child {
+   margin-top: 0;
+   padding-top: 0;
+ }
+
+ a:first-child h1, a:first-child h2, a:first-child h3, a:first-child h4, a:first-child h5, a:first-child h6 {
+   margin-top: 0;
+   padding-top: 0;
+ }
+
+ h1 p, h2 p, h3 p, h4 p, h5 p, h6 p {
+   margin-top: 0;
+ }
+
+ li p.first {
+   display: inline-block;
+ }
+
+ ul, ol {
+   padding-left: 30px;
+ }
+
+ ul :first-child, ol :first-child {
+   margin-top: 0;
+ }
+
+ ul :last-child, ol :last-child {
+   margin-bottom: 0;
+ }
+
+ dl {
+   padding: 0;
+ }
+
+ dl dt {
+   font-size: 14px;
+   font-weight: bold;
+   font-style: italic;
+   padding: 0;
+   margin: 15px 0 5px;
+ }
+
+ dl dt:first-child {
+   padding: 0;
+ }
+
+ dl dt > :first-child {
+   margin-top: 0;
+ }
+
+ dl dt > :last-child {
+   margin-bottom: 0;
+ }
+
+ dl dd {
+   margin: 0 0 15px;
+   padding: 0 15px;
+ }
+
+ dl dd > :first-child {
+   margin-top: 0;
+ }
+
+ dl dd > :last-child {
+   margin-bottom: 0;
+ }
+
+ blockquote {
+   border-left: 4px solid #dddddd;
+   padding: 0 15px;
+   color: #777777;
+ }
+
+ blockquote > :first-child {
+   margin-top: 0;
+ }
+
+ blockquote > :last-child {
+   margin-bottom: 0;
+ }
+
+ table {
+   padding: 0;
+ }
+ table tr {
+   border-top: 1px solid #cccccc;
+   background-color: white;
+   margin: 0;
+   padding: 0;
+ }
+
+ table tr:nth-child(2n) {
+   background-color: #f8f8f8;
+ }
+
+ table tr th {
+   font-weight: bold;
+   border: 1px solid #cccccc;
+   text-align: left;
+   margin: 0;
+   padding: 6px 13px;
+ }
+
+ table tr td {
+   border: 1px solid #cccccc;
+   text-align: left;
+   margin: 0;
+   padding: 6px 13px;
+ }
+
+ table tr th :first-child, table tr td :first-child {
+   margin-top: 0;
+ }
+
+ table tr th :last-child, table tr td :last-child {
+   margin-bottom: 0;
+ }
+
+ img {
+   max-width: 100%;
+ }
+
+ span.frame {
+   display: block;
+   overflow: hidden;
+ }
+
+ span.frame > span {
+   border: 1px solid #dddddd;
+   display: block;
+   float: left;
+   overflow: hidden;
+   margin: 13px 0 0;
+   padding: 7px;
+   width: auto;
+ }
+
+ span.frame span img {
+   display: block;
+   float: left;
+ }
+
+ span.frame span span {
+   clear: both;
+   color: #333333;
+   display: block;
+   padding: 5px 0 0;
+ }
+
+ span.align-center {
+   display: block;
+   overflow: hidden;
+   clear: both;
+ }
+
+ span.align-center > span {
+   display: block;
+   overflow: hidden;
+   margin: 13px auto 0;
+   text-align: center;
+ }
+
+ span.align-center span img {
+   margin: 0 auto;
+   text-align: center;
+ }
+
+ span.align-right {
+   display: block;
+   overflow: hidden;
+   clear: both;
+ }
+
+ span.align-right > span {
+   display: block;
+   overflow: hidden;
+   margin: 13px 0 0;
+   text-align: right;
+ }
+
+ span.align-right span img {
+   margin: 0;
+   text-align: right;
+ }
+
+ span.float-left {
+   display: block;
+   margin-right: 13px;
+   overflow: hidden;
+   float: left;
+ }
+
+ span.float-left span {
+   margin: 13px 0 0;
+ }
+
+ span.float-right {
+   display: block;
+   margin-left: 13px;
+   overflow: hidden;
+   float: right;
+ }
+
+ span.float-right > span {
+   display: block;
+   overflow: hidden;
+   margin: 13px auto 0;
+   text-align: right;
+ }
+
+ code, tt {
+   margin: 0 2px;
+   padding: 0 5px;
+   white-space: nowrap;
+   border-radius: 3px;
+ }
+
+ pre code {
+   margin: 0;
+   padding: 0;
+   white-space: pre;
+   border: none;
+   background: transparent;
+ }
+
+ .highlight pre {
+   font-size: 13px;
+   line-height: 19px;
+   overflow: auto;
+   padding: 6px 10px;
+   border-radius: 3px;
+ }
+
+ pre {
+   font-size: 13px;
+   line-height: 19px;
+   overflow: auto;
+   padding: 6px 10px;
+   border-radius: 3px;
+ }
+
+ pre code, pre tt {
+   background-color: transparent;
+   border: none;
+ }
+
+ body {
+   max-width: 600px;
+ }
dac-vae/audiotools/core/templates/widget.html ADDED
@@ -0,0 +1,52 @@
+ <div id='PLAYER_ID' class='player' style="max-width: MAX_WIDTH;">
+   <div class='spectrogram' style="padding-top: PADDING_AMOUNT;">
+     <div class='overlay'></div>
+     <div class='underlay'>
+       <img>
+     </div>
+   </div>
+
+   <div class='audio-controls'>
+     <button id="playpause" disabled class='playpause' title="play">
+       <svg class='play-img' width="14px" height="19px" viewBox="0 0 14 19">
+         <polygon id="Triangle" fill="#000000" transform="translate(9, 9.5) rotate(90) translate(-7, -9.5) " points="7 2.5 16.5 16.5 -2.5 16.5"></polygon>
+       </svg>
+       <svg class='pause-img' width="16px" height="19px" viewBox="0 0 16 19">
+         <g fill="#000000" stroke="#000000">
+           <rect x="0.5" y="0.5" width="4" height="18"></rect>
+           <rect x="11.5" y="0.5" width="4" height="18"></rect>
+         </g>
+       </svg>
+     </button>
+
+     <audio class='play'>
+       <source id='src'>
+     </audio>
+     <div class='response'>
+       <canvas class='response-canvas'></canvas>
+     </div>
+
+     <button id="download" class='download' title="download">
+       <svg class='download-img' x="0px" y="0px" viewBox="0 0 29.978 29.978" style="enable-background:new 0 0 29.978 29.978;" xml:space="preserve">
+         <g>
+           <path d="M25.462,19.105v6.848H4.515v-6.848H0.489v8.861c0,1.111,0.9,2.012,2.016,2.012h24.967c1.115,0,2.016-0.9,2.016-2.012
+             v-8.861H25.462z"/>
+           <path d="M14.62,18.426l-5.764-6.965c0,0-0.877-0.828,0.074-0.828s3.248,0,3.248,0s0-0.557,0-1.416c0-2.449,0-6.906,0-8.723
+             c0,0-0.129-0.494,0.615-0.494c0.75,0,4.035,0,4.572,0c0.536,0,0.524,0.416,0.524,0.416c0,1.762,0,6.373,0,8.742
+             c0,0.768,0,1.266,0,1.266s1.842,0,2.998,0c1.154,0,0.285,0.867,0.285,0.867s-4.904,6.51-5.588,7.193
+             C15.092,18.979,14.62,18.426,14.62,18.426z"/>
+         </g>
+       </svg>
+     </button>
+   </div>
+ </div>
+
+ <script>
+   var PLAYER_ID = new Player('PLAYER_ID')
+   PLAYER_ID.load(
+     "AUDIO_SRC",
+     "IMAGE_SRC",
+     "LEVELS_SRC"
+   )
+   window.addEventListener("resize", function() {PLAYER_ID.redraw()})
+ </script>
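The ALL-CAPS tokens above (PLAYER_ID, MAX_WIDTH, PADDING_AMOUNT, AUDIO_SRC, IMAGE_SRC, LEVELS_SRC) are placeholders for the embedding code to substitute. A hypothetical sketch of that substitution in Python; the names and values below are illustrative, not the package's actual API:

from pathlib import Path

template = Path("widget.html").read_text()
for token, value in {
    "PLAYER_ID": "player0",       # also becomes the JS variable name
    "MAX_WIDTH": "100%",
    "PADDING_AMOUNT": "50%",
    "AUDIO_SRC": "audio.wav",
    "IMAGE_SRC": "spec.png",
    "LEVELS_SRC": "levels.png",
}.items():
    template = template.replace(token, value)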
dac-vae/audiotools/core/util.py ADDED
@@ -0,0 +1,671 @@
+ import csv
+ import glob
+ import math
+ import numbers
+ import os
+ import random
+ import typing
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict
+ from typing import List
+
+ import numpy as np
+ import torch
+ import torchaudio
+ from flatten_dict import flatten
+ from flatten_dict import unflatten
+
+
+ @dataclass
+ class Info:
+     """Shim for torchaudio.info API changes."""
+
+     sample_rate: float
+     num_frames: int
+
+     @property
+     def duration(self) -> float:
+         return self.num_frames / self.sample_rate
+
+
+ def info(audio_path: str):
+     """Shim for torchaudio.info to make 0.7.2 API match 0.8.0.
+
+     Parameters
+     ----------
+     audio_path : str
+         Path to audio file.
+     """
+     # Try the default backend first, then fall back to soundfile.
+     try:
+         info = torchaudio.info(str(audio_path))
+     except Exception:  # pragma: no cover
+         info = torchaudio.backend.soundfile_backend.info(str(audio_path))
+
+     if isinstance(info, tuple):  # pragma: no cover
+         signal_info = info[0]
+         info = Info(sample_rate=signal_info.rate, num_frames=signal_info.length)
+     else:
+         info = Info(sample_rate=info.sample_rate, num_frames=info.num_frames)
+
+     return info
+
+
+ def ensure_tensor(
+     x: typing.Union[np.ndarray, torch.Tensor, float, int],
+     ndim: int = None,
+     batch_size: int = None,
+ ):
+     """Ensures that the input ``x`` is a tensor with the specified
+     number of dimensions and batch size.
+
+     Parameters
+     ----------
+     x : typing.Union[np.ndarray, torch.Tensor, float, int]
+         Data that will become a tensor on its way out.
+     ndim : int, optional
+         How many dimensions should be in the output, by default None
+     batch_size : int, optional
+         The batch size of the output, by default None
+
+     Returns
+     -------
+     torch.Tensor
+         Modified version of ``x`` as a tensor.
+     """
+     if not torch.is_tensor(x):
+         x = torch.as_tensor(x)
+     if ndim is not None:
+         assert x.ndim <= ndim
+         while x.ndim < ndim:
+             x = x.unsqueeze(-1)
+     if batch_size is not None:
+         if x.shape[0] != batch_size:
+             shape = list(x.shape)
+             shape[0] = batch_size
+             x = x.expand(*shape)
+     return x
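A minimal usage sketch of ensure_tensor (the values are illustrative):

import torch
from audiotools.core.util import ensure_tensor

x = ensure_tensor(0.5, ndim=3, batch_size=4)
# The scalar is unsqueezed to shape (1, 1, 1), then expanded to (4, 1, 1).
assert x.shape == (4, 1, 1)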
+
+
+ def _get_value(other):
+     from . import AudioSignal
+
+     if isinstance(other, AudioSignal):
+         return other.audio_data
+     return other
+
+
+ def hz_to_bin(hz: torch.Tensor, n_fft: int, sample_rate: int):
+     """Closest frequency bin given a frequency, number
+     of bins, and a sampling rate.
+
+     Parameters
+     ----------
+     hz : torch.Tensor
+         Tensor of frequencies in Hz.
+     n_fft : int
+         Number of FFT bins.
+     sample_rate : int
+         Sample rate of audio.
+
+     Returns
+     -------
+     torch.Tensor
+         Closest bins to the data.
+     """
+     shape = hz.shape
+     hz = hz.flatten()
+     freqs = torch.linspace(0, sample_rate / 2, 2 + n_fft // 2)
+     hz[hz > sample_rate / 2] = sample_rate / 2
+
+     closest = (hz[None, :] - freqs[:, None]).abs()
+     closest_bins = closest.min(dim=0).indices
+
+     return closest_bins.reshape(*shape)
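For instance (a sketch; the exact index follows from the 2 + n_fft // 2 point grid built above):

import torch
from audiotools.core.util import hz_to_bin

bins = hz_to_bin(torch.tensor([440.0]), n_fft=2048, sample_rate=44100)
# 440 Hz lands near bin 20 on the 1026-point grid spanning 0 to 22050 Hz.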
+
+
+ def random_state(seed: typing.Union[int, np.random.RandomState]):
+     """
+     Turn seed into a np.random.RandomState instance.
+
+     Parameters
+     ----------
+     seed : typing.Union[int, np.random.RandomState] or None
+         If seed is None, return the RandomState singleton used by np.random.
+         If seed is an int, return a new RandomState instance seeded with seed.
+         If seed is already a RandomState instance, return it.
+         Otherwise raise ValueError.
+
+     Returns
+     -------
+     np.random.RandomState
+         Random state object.
+
+     Raises
+     ------
+     ValueError
+         If seed is not valid, an error is thrown.
+     """
+     if seed is None or seed is np.random:
+         return np.random.mtrand._rand
+     elif isinstance(seed, (numbers.Integral, np.integer, int)):
+         return np.random.RandomState(seed)
+     elif isinstance(seed, np.random.RandomState):
+         return seed
+     else:
+         raise ValueError(
+             "%r cannot be used to seed a numpy.random.RandomState instance" % seed
+         )
+
+
+ def seed(random_seed, set_cudnn=False):
+     """
+     Seeds all random states with the same random seed
+     for reproducibility. Seeds the ``numpy``, ``random`` and ``torch``
+     random generators.
+     For full reproducibility, two further options must be set
+     according to the torch documentation:
+     https://pytorch.org/docs/stable/notes/randomness.html
+     To do this, ``set_cudnn`` must be True. It defaults to
+     False, since setting it to True results in a performance
+     hit.
+
+     Parameters
+     ----------
+     random_seed : int
+         Integer corresponding to the random seed to use.
+     set_cudnn : bool, optional
+         Whether or not to set cudnn into deterministic
+         mode and out of benchmark mode. Defaults to False.
+     """
+
+     torch.manual_seed(random_seed)
+     np.random.seed(random_seed)
+     random.seed(random_seed)
+
+     if set_cudnn:
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
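A usage sketch, assuming the module import path from this file's location:

from audiotools.core import util

util.seed(42)                  # numpy, random, and torch now share one seed
util.seed(42, set_cudnn=True)  # also pins cuDNN: deterministic, but slower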
+
+
+ @contextmanager
+ def _close_temp_files(tmpfiles: list):
+     """Utility function for creating a context and closing all temporary files
+     once the context is exited. For correct functionality, all temporary file
+     handles created inside the context must be appended to the ``tmpfiles``
+     list.
+
+     This function is taken wholesale from Scaper.
+
+     Parameters
+     ----------
+     tmpfiles : list
+         List of temporary file handles
+     """
+
+     def _close():
+         for t in tmpfiles:
+             try:
+                 t.close()
+                 os.unlink(t.name)
+             except Exception:
+                 pass
+
+     try:
+         yield
+     except:  # pragma: no cover
+         _close()
+         raise
+     _close()
+
+
+ AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]
+
+
+ def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
+     """Finds all audio files in a directory recursively.
+     Returns a list.
+
+     Parameters
+     ----------
+     folder : str
+         Folder to look for audio files in, recursively.
+     ext : List[str], optional
+         Extensions to look for (including the ``.``), by default
+         ``['.wav', '.flac', '.mp3', '.mp4']``.
+     """
+     folder = Path(folder)
+     # Take care of the case where the user has passed an audio file directly
+     # into one of the calling functions.
+     if str(folder).endswith(tuple(ext)):
+         # If, however, there's a glob in the path, we need to
+         # return the glob, not the file.
+         if "*" in str(folder):
+             return glob.glob(str(folder), recursive=("**" in str(folder)))
+         else:
+             return [folder]
+
+     files = []
+     for x in ext:
+         files += folder.glob(f"**/*{x}")
+     return files
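A usage sketch (the folder path is illustrative):

from audiotools.core.util import find_audio

wavs = find_audio("data/speech", ext=[".wav"])  # recursive search for .wav files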
+
+
+ def read_sources(
+     sources: List[str],
+     remove_empty: bool = True,
+     relative_path: str = "",
+     ext: List[str] = AUDIO_EXTENSIONS,
+ ):
+     """Reads audio sources that can either be folders
+     full of audio files, or CSV files that contain paths
+     to audio files. CSV files that adhere to the expected
+     format can be generated by
+     :py:func:`audiotools.data.preprocess.create_csv`.
+
+     Parameters
+     ----------
+     sources : List[str]
+         List of audio sources to be converted into a
+         list of lists of audio files.
+     remove_empty : bool, optional
+         Whether or not to remove rows with an empty "path"
+         from each CSV file, by default True.
+     relative_path : str, optional
+         Path that each source path is resolved relative to, by default "".
+     ext : List[str], optional
+         Extensions to look for when a source is a folder,
+         by default ``AUDIO_EXTENSIONS``.
+
+     Returns
+     -------
+     list
+         List of lists of rows of CSV files.
+     """
+     files = []
+     relative_path = Path(relative_path)
+     for source in sources:
+         source = str(source)
+         _files = []
+         if source.endswith(".csv"):
+             with open(source, "r") as f:
+                 reader = csv.DictReader(f)
+                 for x in reader:
+                     if remove_empty and x["path"] == "":
+                         continue
+                     if x["path"] != "":
+                         x["path"] = str(relative_path / x["path"])
+                     _files.append(x)
+         else:
+             for x in find_audio(source, ext=ext):
+                 x = str(relative_path / x)
+                 _files.append({"path": x})
+         files.append(sorted(_files, key=lambda x: x["path"]))
+     return files
+
+
+ def choose_from_list_of_lists(
+     state: np.random.RandomState, list_of_lists: list, p: float = None
+ ):
+     """Choose a single item from a list of lists.
+
+     Parameters
+     ----------
+     state : np.random.RandomState
+         Random state to use when choosing an item.
+     list_of_lists : list
+         A list of lists from which items will be drawn.
+     p : float, optional
+         Probabilities of each list, by default None
+
+     Returns
+     -------
+     tuple
+         A 3-tuple of the chosen item, the index of the list it was
+         drawn from, and its index within that list.
+     """
+     source_idx = state.choice(list(range(len(list_of_lists))), p=p)
+     item_idx = state.randint(len(list_of_lists[source_idx]))
+     return list_of_lists[source_idx][item_idx], source_idx, item_idx
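Note the 3-tuple return; a sketch with illustrative data:

import numpy as np
from audiotools.core.util import choose_from_list_of_lists

state = np.random.RandomState(0)
item, source_idx, item_idx = choose_from_list_of_lists(
    state, [["a", "b"], ["c"]], p=[0.5, 0.5]
)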
+
+
+ @contextmanager
+ def chdir(newdir: typing.Union[Path, str]):
+     """
+     Context manager for switching directories to run a
+     function. Useful for when you want to use relative
+     paths to different runs.
+
+     Parameters
+     ----------
+     newdir : typing.Union[Path, str]
+         Directory to switch to.
+     """
+     curdir = os.getcwd()
+     try:
+         os.chdir(newdir)
+         yield
+     finally:
+         os.chdir(curdir)
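A usage sketch (the directory name is illustrative):

from audiotools.core.util import chdir

with chdir("runs/exp1"):
    ...  # relative paths now resolve inside runs/exp1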
+
+
+ def prepare_batch(batch: typing.Union[dict, list, torch.Tensor], device: str = "cpu"):
+     """Moves items in a batch (typically generated by a DataLoader as a list
+     or a dict) to the specified device. This works even if dictionaries
+     are nested.
+
+     Parameters
+     ----------
+     batch : typing.Union[dict, list, torch.Tensor]
+         Batch, typically generated by a dataloader, that will be moved to
+         the device.
+     device : str, optional
+         Device to move batch to, by default "cpu"
+
+     Returns
+     -------
+     typing.Union[dict, list, torch.Tensor]
+         Batch with all values moved to the specified device.
+     """
+     if isinstance(batch, dict):
+         batch = flatten(batch)
+         for key, val in batch.items():
+             try:
+                 batch[key] = val.to(device)
+             except Exception:
+                 pass
+         batch = unflatten(batch)
+     elif torch.is_tensor(batch):
+         batch = batch.to(device)
+     elif isinstance(batch, list):
+         for i in range(len(batch)):
+             try:
+                 batch[i] = batch[i].to(device)
+             except Exception:
+                 pass
+     return batch
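A sketch, assuming a CUDA device is available; the batch layout is illustrative:

import torch
from audiotools.core.util import prepare_batch

batch = {"audio": {"x": torch.zeros(4, 1, 16000)}, "label": [0, 1, 2, 3]}
batch = prepare_batch(batch, device="cuda")  # nested tensors moved; the plain list is left as-is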
+
+
+ def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None):
+     """Samples from a distribution defined by a tuple. The first
+     item in the tuple is the distribution type, and the rest of the
+     items are arguments to that distribution. The distribution function
+     is gotten from the ``np.random.RandomState`` object.
+
+     Parameters
+     ----------
+     dist_tuple : tuple
+         Distribution tuple
+     state : np.random.RandomState, optional
+         Random state, or seed to use, by default None
+
+     Returns
+     -------
+     typing.Union[float, int, str]
+         Draw from the distribution.
+
+     Examples
+     --------
+     Sample from a uniform distribution:
+
+     >>> dist_tuple = ("uniform", 0, 1)
+     >>> sample_from_dist(dist_tuple)
+
+     Sample from a constant distribution:
+
+     >>> dist_tuple = ("const", 0)
+     >>> sample_from_dist(dist_tuple)
+
+     Sample from a normal distribution:
+
+     >>> dist_tuple = ("normal", 0, 0.5)
+     >>> sample_from_dist(dist_tuple)
+
+     """
+     if dist_tuple[0] == "const":
+         return dist_tuple[1]
+     state = random_state(state)
+     dist_fn = getattr(state, dist_tuple[0])
+     return dist_fn(*dist_tuple[1:])
+
+
+ def collate(list_of_dicts: list, n_splits: int = None):
+     """Collates a list of dictionaries (e.g. as returned by a
+     dataloader) into a dictionary with batched values. This routine
+     uses the default torch collate function for everything
+     except AudioSignal objects, which are handled by the
+     :py:func:`audiotools.core.audio_signal.AudioSignal.batch`
+     function.
+
+     This function takes n_splits to enable splitting a batch
+     into multiple sub-batches for the purposes of gradient accumulation,
+     etc.
+
+     Parameters
+     ----------
+     list_of_dicts : list
+         List of dictionaries to be collated.
+     n_splits : int
+         Number of splits to make when creating the batches (split into
+         sub-batches). Useful for things like gradient accumulation.
+
+     Returns
+     -------
+     dict
+         Dictionary containing batched data.
+     """
+
+     from . import AudioSignal
+
+     batches = []
+     list_len = len(list_of_dicts)
+
+     return_list = n_splits is not None
+     n_splits = 1 if n_splits is None else n_splits
+     n_items = int(math.ceil(list_len / n_splits))
+
+     for i in range(0, list_len, n_items):
+         # Flatten the dictionaries to avoid recursion.
+         list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]]
+         dict_of_lists = {
+             k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0]
+         }
+
+         batch = {}
+         for k, v in dict_of_lists.items():
+             if isinstance(v, list):
+                 if all(isinstance(s, AudioSignal) for s in v):
+                     batch[k] = AudioSignal.batch(v, pad_signals=True)
+                 else:
+                     # Borrow the default collate fn from torch.
+                     batch[k] = torch.utils.data._utils.collate.default_collate(v)
+         batches.append(unflatten(batch))
+
+     batches = batches[0] if not return_list else batches
+     return batches
480
+
481
+
482
+ BASE_SIZE = 864
483
+ DEFAULT_FIG_SIZE = (9, 3)
484
+
485
+
486
+ def format_figure(
487
+ fig_size: tuple = None,
488
+ title: str = None,
489
+ fig=None,
490
+ format_axes: bool = True,
491
+ format: bool = True,
492
+ font_color: str = "white",
493
+ ):
494
+ """Prettifies the spectrogram and waveform plots. A title
495
+ can be inset into the top right corner, and the axes can be
496
+ inset into the figure, allowing the data to take up the entire
497
+ image. Used in
498
+
499
+ - :py:func:`audiotools.core.display.DisplayMixin.specshow`
500
+ - :py:func:`audiotools.core.display.DisplayMixin.waveplot`
501
+ - :py:func:`audiotools.core.display.DisplayMixin.wavespec`
502
+
503
+ Parameters
504
+ ----------
505
+ fig_size : tuple, optional
506
+ Size of figure, by default (9, 3)
507
+ title : str, optional
508
+ Title to inset in top right, by default None
509
+ fig : matplotlib.figure.Figure, optional
510
+ Figure object, if None ``plt.gcf()`` will be used, by default None
511
+ format_axes : bool, optional
512
+ Format the axes to be inside the figure, by default True
513
+ format : bool, optional
514
+ This formatting can be skipped entirely by passing ``format=False``
515
+ to any of the plotting functions that use this formater, by default True
516
+ font_color : str, optional
517
+ Color of font of axes, by default "white"
518
+ """
519
+ import matplotlib
520
+ import matplotlib.pyplot as plt
521
+
522
+ if fig_size is None:
523
+ fig_size = DEFAULT_FIG_SIZE
524
+ if not format:
525
+ return
526
+ if fig is None:
527
+ fig = plt.gcf()
528
+ fig.set_size_inches(*fig_size)
529
+ axs = fig.axes
530
+
531
+ pixels = (fig.get_size_inches() * fig.dpi)[0]
532
+ font_scale = pixels / BASE_SIZE
533
+
534
+ if format_axes:
535
+ axs = fig.axes
536
+
537
+ for ax in axs:
538
+ ymin, _ = ax.get_ylim()
539
+ xmin, _ = ax.get_xlim()
540
+
541
+ ticks = ax.get_yticks()
542
+ for t in ticks[2:-1]:
543
+ t = axs[0].annotate(
544
+ f"{(t / 1000):2.1f}k",
545
+ xy=(xmin, t),
546
+ xycoords="data",
547
+ xytext=(5, -5),
548
+ textcoords="offset points",
549
+ ha="left",
550
+ va="top",
551
+ color=font_color,
552
+ fontsize=12 * font_scale,
553
+ alpha=0.75,
554
+ )
555
+
556
+ ticks = ax.get_xticks()[2:]
557
+ for t in ticks[:-1]:
558
+ t = axs[0].annotate(
559
+ f"{t:2.1f}s",
560
+ xy=(t, ymin),
561
+ xycoords="data",
562
+ xytext=(5, 5),
563
+ textcoords="offset points",
564
+ ha="center",
565
+ va="bottom",
566
+ color=font_color,
567
+ fontsize=12 * font_scale,
568
+ alpha=0.75,
569
+ )
570
+
571
+ ax.margins(0, 0)
572
+ ax.set_axis_off()
573
+ ax.xaxis.set_major_locator(plt.NullLocator())
574
+ ax.yaxis.set_major_locator(plt.NullLocator())
575
+
576
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
577
+
578
+ if title is not None:
579
+ t = axs[0].annotate(
580
+ title,
581
+ xy=(1, 1),
582
+ xycoords="axes fraction",
583
+ fontsize=20 * font_scale,
584
+ xytext=(-5, -5),
585
+ textcoords="offset points",
586
+ ha="right",
587
+ va="top",
588
+ color="white",
589
+ )
590
+ t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
591
+
592
+
593
+ def generate_chord_dataset(
594
+ max_voices: int = 8,
595
+ sample_rate: int = 44100,
596
+ num_items: int = 5,
597
+ duration: float = 1.0,
598
+ min_note: str = "C2",
599
+ max_note: str = "C6",
600
+ output_dir: Path = "chords",
601
+ ):
602
+ """
603
+ Generates a toy multitrack dataset of chords, synthesized from sine waves.
604
+
605
+
606
+ Parameters
607
+ ----------
608
+ max_voices : int, optional
609
+ Maximum number of voices in a chord, by default 8
610
+ sample_rate : int, optional
611
+ Sample rate of audio, by default 44100
612
+ num_items : int, optional
613
+ Number of items to generate, by default 5
614
+ duration : float, optional
615
+ Duration of each item, by default 1.0
616
+ min_note : str, optional
617
+ Minimum note in the dataset, by default "C2"
618
+ max_note : str, optional
619
+ Maximum note in the dataset, by default "C6"
620
+ output_dir : Path, optional
621
+ Directory to save the dataset, by default "chords"
622
+
623
+ """
624
+ import librosa
625
+ from . import AudioSignal
626
+ from ..data.preprocess import create_csv
627
+
628
+ min_midi = librosa.note_to_midi(min_note)
629
+ max_midi = librosa.note_to_midi(max_note)
630
+
631
+ tracks = []
632
+ for idx in range(num_items):
633
+ track = {}
634
+ # figure out how many voices to put in this track
635
+ num_voices = random.randint(1, max_voices)
636
+ for voice_idx in range(num_voices):
637
+ # choose some random params
638
+ midinote = random.randint(min_midi, max_midi)
639
+ dur = random.uniform(0.85 * duration, duration)
640
+
641
+ sig = AudioSignal.wave(
642
+ frequency=librosa.midi_to_hz(midinote),
643
+ duration=dur,
644
+ sample_rate=sample_rate,
645
+ shape="sine",
646
+ )
647
+ track[f"voice_{voice_idx}"] = sig
648
+ tracks.append(track)
649
+
650
+ # save the tracks to disk
651
+ output_dir = Path(output_dir)
652
+ output_dir.mkdir(exist_ok=True)
653
+ for idx, track in enumerate(tracks):
654
+ track_dir = output_dir / f"track_{idx}"
655
+ track_dir.mkdir(exist_ok=True)
656
+ for voice_name, sig in track.items():
657
+ sig.write(track_dir / f"{voice_name}.wav")
658
+
659
+ all_voices = list(set([k for track in tracks for k in track.keys()]))
660
+ voice_lists = {voice: [] for voice in all_voices}
661
+ for track in tracks:
662
+ for voice_name in all_voices:
663
+ if voice_name in track:
664
+ voice_lists[voice_name].append(track[voice_name].path_to_file)
665
+ else:
666
+ voice_lists[voice_name].append("")
667
+
668
+ for voice_name, paths in voice_lists.items():
669
+ create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)
670
+
671
+ return output_dir
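[Editor's note: for orientation, here is a minimal usage sketch of the utilities above (``sample_from_dist``, ``collate``, ``prepare_batch``); the tensor values are hypothetical and this snippet is not part of the commit, it just assumes the package is installed as ``audiotools``.]

    import torch
    from audiotools.core import util

    # Reproducibly draw a parameter from a distribution tuple.
    state = util.random_state(0)
    snr = util.sample_from_dist(("uniform", 10.0, 30.0), state)

    # Collate per-item dicts into one batch (nested dicts are handled),
    # then move every tensor in the batch to a device.
    items = [{"feats": {"x": torch.randn(4)}, "idx": i} for i in range(8)]
    batch = util.collate(items)
    batch = util.prepare_batch(batch, device="cpu")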
dac-vae/audiotools/core/whisper.py ADDED
@@ -0,0 +1,97 @@
+ import torch
+
+
+ class WhisperMixin:
+     is_initialized = False
+
+     def setup_whisper(
+         self,
+         pretrained_model_name_or_path: str = "openai/whisper-base.en",
+         device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+     ):
+         from transformers import WhisperForConditionalGeneration
+         from transformers import WhisperProcessor
+
+         self.whisper_device = device
+         self.whisper_processor = WhisperProcessor.from_pretrained(
+             pretrained_model_name_or_path
+         )
+         self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
+             pretrained_model_name_or_path
+         ).to(self.whisper_device)
+         self.is_initialized = True
+
+     def get_whisper_features(self) -> torch.Tensor:
+         """Preprocess the audio signal as per the whisper model's training config.
+
+         Returns
+         -------
+         torch.Tensor
+             The preprocessed input features of the audio signal. Shape: (1, channels, seq_len)
+         """
+         import torch
+
+         if not self.is_initialized:
+             self.setup_whisper()
+
+         signal = self.to(self.device)
+         raw_speech = list(
+             (
+                 signal.clone()
+                 .resample(self.whisper_processor.feature_extractor.sampling_rate)
+                 .audio_data[:, 0, :]
+                 .numpy()
+             )
+         )
+
+         with torch.inference_mode():
+             input_features = self.whisper_processor(
+                 raw_speech,
+                 sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
+                 return_tensors="pt",
+             ).input_features
+
+         return input_features
+
+     def get_whisper_transcript(self) -> str:
+         """Get the transcript of the audio signal using the whisper model.
+
+         Returns
+         -------
+         str
+             The transcript of the audio signal, including special tokens such as <|startoftranscript|> and <|endoftext|>.
+         """
+
+         if not self.is_initialized:
+             self.setup_whisper()
+
+         input_features = self.get_whisper_features()
+
+         with torch.inference_mode():
+             input_features = input_features.to(self.whisper_device)
+             generated_ids = self.whisper_model.generate(inputs=input_features)
+
+         transcription = self.whisper_processor.batch_decode(generated_ids)
+         return transcription[0]
+
+     def get_whisper_embeddings(self) -> torch.Tensor:
+         """Get the last hidden state embeddings of the audio signal using the whisper model.
+
+         Returns
+         -------
+         torch.Tensor
+             The Whisper embeddings of the audio signal. Shape: (1, seq_len, hidden_size)
+         """
+         import torch
+
+         if not self.is_initialized:
+             self.setup_whisper()
+
+         input_features = self.get_whisper_features()
+         encoder = self.whisper_model.get_encoder()
+
+         with torch.inference_mode():
+             input_features = input_features.to(self.whisper_device)
+             embeddings = encoder(input_features)
+
+         return embeddings.last_hidden_state
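[Editor's note: a sketch of how this mixin is consumed. In upstream descript-audiotools, ``AudioSignal`` inherits ``WhisperMixin``; assuming that wiring holds here too, that ``transformers`` is installed, and given a hypothetical ``speech.wav``, transcription looks like the following. This is not part of the commit.]

    from audiotools import AudioSignal

    signal = AudioSignal("speech.wav")
    signal.setup_whisper("openai/whisper-base.en")  # optional; called lazily otherwise
    text = signal.get_whisper_transcript()          # str including special tokens
    embeddings = signal.get_whisper_embeddings()    # (1, seq_len, hidden_size)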
dac-vae/audiotools/data/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from . import datasets
+ from . import preprocess
+ from . import transforms
dac-vae/audiotools/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (255 Bytes)
dac-vae/audiotools/data/__pycache__/datasets.cpython-310.pyc ADDED
Binary file (17 kB)
dac-vae/audiotools/data/__pycache__/preprocess.cpython-310.pyc ADDED
Binary file (2.85 kB)
dac-vae/audiotools/data/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (55.5 kB)
dac-vae/audiotools/data/datasets.py ADDED
@@ -0,0 +1,517 @@
+ from pathlib import Path
+ from typing import Callable
+ from typing import Dict
+ from typing import List
+ from typing import Union
+
+ import numpy as np
+ from torch.utils.data import SequentialSampler
+ from torch.utils.data.distributed import DistributedSampler
+
+ from ..core import AudioSignal
+ from ..core import util
+
+
+ class AudioLoader:
+     """Loads audio endlessly from a list of audio sources
+     containing paths to audio files. Audio sources can be
+     folders full of audio files (which are found via file
+     extension) or CSV files which contain paths
+     to audio files.
+
+     Parameters
+     ----------
+     sources : List[str], optional
+         Sources containing folders, or CSVs with
+         paths to audio files, by default None
+     weights : List[float], optional
+         Weights to sample audio files from each source, by default None
+     relative_path : str, optional
+         Path audio should be loaded relative to, by default ""
+     transform : Callable, optional
+         Transform to instantiate alongside audio sample,
+         by default None
+     ext : List[str]
+         List of extensions to find audio within each source by. Can
+         also be a file name (e.g. "vocals.wav"). By default
+         ``['.wav', '.flac', '.mp3', '.mp4']``.
+     shuffle : bool
+         Whether to shuffle the files within the dataloader. Defaults to True.
+     shuffle_state : int
+         State to use to seed the shuffle of the files.
+     """
+
+     def __init__(
+         self,
+         sources: List[str] = None,
+         weights: List[float] = None,
+         transform: Callable = None,
+         relative_path: str = "",
+         ext: List[str] = util.AUDIO_EXTENSIONS,
+         shuffle: bool = True,
+         shuffle_state: int = 0,
+     ):
+         self.audio_lists = util.read_sources(
+             sources, relative_path=relative_path, ext=ext
+         )
+
+         self.audio_indices = [
+             (src_idx, item_idx)
+             for src_idx, src in enumerate(self.audio_lists)
+             for item_idx in range(len(src))
+         ]
+         if shuffle:
+             state = util.random_state(shuffle_state)
+             state.shuffle(self.audio_indices)
+
+         self.sources = sources
+         self.weights = weights
+         self.transform = transform
+
+     def __call__(
+         self,
+         state,
+         sample_rate: int,
+         duration: float,
+         loudness_cutoff: float = -40,
+         num_channels: int = 1,
+         offset: float = None,
+         source_idx: int = None,
+         item_idx: int = None,
+         global_idx: int = None,
+     ):
+         if source_idx is not None and item_idx is not None:
+             try:
+                 audio_info = self.audio_lists[source_idx][item_idx]
+             except:
+                 audio_info = {"path": "none"}
+         elif global_idx is not None:
+             source_idx, item_idx = self.audio_indices[
+                 global_idx % len(self.audio_indices)
+             ]
+             audio_info = self.audio_lists[source_idx][item_idx]
+         else:
+             audio_info, source_idx, item_idx = util.choose_from_list_of_lists(
+                 state, self.audio_lists, p=self.weights
+             )
+
+         path = audio_info["path"]
+         signal = AudioSignal.zeros(duration, sample_rate, num_channels)
+
+         if path != "none":
+             if offset is None:
+                 signal = AudioSignal.salient_excerpt(
+                     path,
+                     duration=duration,
+                     state=state,
+                     loudness_cutoff=loudness_cutoff,
+                 )
+             else:
+                 signal = AudioSignal(
+                     path,
+                     offset=offset,
+                     duration=duration,
+                 )
+
+         if num_channels == 1:
+             signal = signal.to_mono()
+         signal = signal.resample(sample_rate)
+
+         if signal.duration < duration:
+             signal = signal.zero_pad_to(int(duration * sample_rate))
+
+         for k, v in audio_info.items():
+             signal.metadata[k] = v
+
+         item = {
+             "signal": signal,
+             "source_idx": source_idx,
+             "item_idx": item_idx,
+             "source": str(self.sources[source_idx]),
+             "path": str(path),
+         }
+         if self.transform is not None:
+             item["transform_args"] = self.transform.instantiate(state, signal=signal)
+         return item
+
+
+ def default_matcher(x, y):
+     return Path(x).parent == Path(y).parent
+
+
+ def align_lists(lists, matcher: Callable = default_matcher):
+     longest_list = lists[np.argmax([len(l) for l in lists])]
+     for i, x in enumerate(longest_list):
+         for l in lists:
+             if i >= len(l):
+                 l.append({"path": "none"})
+             elif not matcher(l[i]["path"], x["path"]):
+                 l.insert(i, {"path": "none"})
+     return lists
+
+
+ class AudioDataset:
+     """Loads audio from multiple loaders (with associated transforms)
+     for a specified number of samples. Excerpts are drawn randomly
+     of the specified duration, above a specified loudness threshold
+     and are resampled on the fly to the desired sample rate
+     (if it is different from the audio source sample rate).
+
+     This takes either a single AudioLoader object,
+     a list of AudioLoader objects, or a dictionary of AudioLoader
+     objects. Each AudioLoader is called by the dataset, and the
+     result is placed in the output dictionary. A transform can also be
+     specified for the entire dataset, rather than for each specific
+     loader. This transform can be applied to the output of all the
+     loaders if desired.
+
+     AudioLoader objects can be specified as aligned, which means the
+     loaders correspond to multitrack audio (e.g. a vocals, bass,
+     drums, and other loader for multitrack music mixtures).
+
+     Parameters
+     ----------
+     loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]]
+         AudioLoaders to sample audio from.
+     sample_rate : int
+         Desired sample rate.
+     n_examples : int, optional
+         Number of examples (length of dataset), by default 1000
+     duration : float, optional
+         Duration of audio samples, by default 0.5
+     loudness_cutoff : float, optional
+         Loudness cutoff threshold for audio samples, by default -40
+     num_channels : int, optional
+         Number of channels in output audio, by default 1
+     transform : Callable, optional
+         Transform to instantiate alongside each dataset item, by default None
+     aligned : bool, optional
+         Whether the loaders should be sampled in an aligned manner (e.g. same
+         offset, duration, and matched file name), by default False
+     shuffle_loaders : bool, optional
+         Whether to shuffle the loaders before sampling from them, by default False
+     matcher : Callable
+         How to match files from adjacent audio lists (e.g. for a multitrack audio loader),
+         by default uses the parent directory of each file.
+     without_replacement : bool
+         Whether to choose files with or without replacement, by default True.
+
+     Examples
+     --------
+     >>> from audiotools.data.datasets import AudioLoader
+     >>> from audiotools.data.datasets import AudioDataset
+     >>> from audiotools import transforms as tfm
+     >>> import numpy as np
+     >>>
+     >>> loaders = [
+     >>>     AudioLoader(
+     >>>         sources=[f"tests/audio/spk"],
+     >>>         transform=tfm.Equalizer(),
+     >>>         ext=["wav"],
+     >>>     )
+     >>>     for i in range(5)
+     >>> ]
+     >>>
+     >>> dataset = AudioDataset(
+     >>>     loaders = loaders,
+     >>>     sample_rate = 44100,
+     >>>     duration = 1.0,
+     >>>     transform = tfm.RescaleAudio(),
+     >>> )
+     >>>
+     >>> item = dataset[np.random.randint(len(dataset))]
+     >>>
+     >>> for i in range(len(loaders)):
+     >>>     item[i]["signal"] = loaders[i].transform(
+     >>>         item[i]["signal"], **item[i]["transform_args"]
+     >>>     )
+     >>>     item[i]["signal"].widget(i)
+     >>>
+     >>> mix = sum([item[i]["signal"] for i in range(len(loaders))])
+     >>> mix = dataset.transform(mix, **item["transform_args"])
+     >>> mix.widget("mix")
+
+     Below is an example of how one could load MUSDB multitrack data:
+
+     >>> import audiotools as at
+     >>> from pathlib import Path
+     >>> from audiotools import transforms as tfm
+     >>> import numpy as np
+     >>> import torch
+     >>>
+     >>> def build_dataset(
+     >>>     sample_rate: int = 44100,
+     >>>     duration: float = 5.0,
+     >>>     musdb_path: str = "~/.data/musdb/",
+     >>> ):
+     >>>     musdb_path = Path(musdb_path).expanduser()
+     >>>     loaders = {
+     >>>         src: at.datasets.AudioLoader(
+     >>>             sources=[musdb_path],
+     >>>             transform=tfm.Compose(
+     >>>                 tfm.VolumeNorm(("uniform", -20, -10)),
+     >>>                 tfm.Silence(prob=0.1),
+     >>>             ),
+     >>>             ext=[f"{src}.wav"],
+     >>>         )
+     >>>         for src in ["vocals", "bass", "drums", "other"]
+     >>>     }
+     >>>
+     >>>     dataset = at.datasets.AudioDataset(
+     >>>         loaders=loaders,
+     >>>         sample_rate=sample_rate,
+     >>>         duration=duration,
+     >>>         num_channels=1,
+     >>>         aligned=True,
+     >>>         transform=tfm.RescaleAudio(),
+     >>>         shuffle_loaders=True,
+     >>>     )
+     >>>     return dataset, list(loaders.keys())
+     >>>
+     >>> train_data, sources = build_dataset()
+     >>> dataloader = torch.utils.data.DataLoader(
+     >>>     train_data,
+     >>>     batch_size=16,
+     >>>     num_workers=0,
+     >>>     collate_fn=train_data.collate,
+     >>> )
+     >>> batch = next(iter(dataloader))
+     >>>
+     >>> for k in sources:
+     >>>     src = batch[k]
+     >>>     src["transformed"] = train_data.loaders[k].transform(
+     >>>         src["signal"].clone(), **src["transform_args"]
+     >>>     )
+     >>>
+     >>> mixture = sum(batch[k]["transformed"] for k in sources)
+     >>> mixture = train_data.transform(mixture, **batch["transform_args"])
+     >>>
+     >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time).
+     >>> # Construct the targets:
+     >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1)
+
+     Similarly, here's example code for loading Slakh data:
+
+     >>> import audiotools as at
+     >>> from pathlib import Path
+     >>> from audiotools import transforms as tfm
+     >>> import numpy as np
+     >>> import torch
+     >>> import glob
+     >>>
+     >>> def build_dataset(
+     >>>     sample_rate: int = 16000,
+     >>>     duration: float = 10.0,
+     >>>     slakh_path: str = "~/.data/slakh/",
+     >>> ):
+     >>>     slakh_path = Path(slakh_path).expanduser()
+     >>>
+     >>>     # Find the max number of sources in Slakh
+     >>>     src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)]
+     >>>     n_sources = len(list(set(src_names)))
+     >>>
+     >>>     loaders = {
+     >>>         f"S{i:02d}": at.datasets.AudioLoader(
+     >>>             sources=[slakh_path],
+     >>>             transform=tfm.Compose(
+     >>>                 tfm.VolumeNorm(("uniform", -20, -10)),
+     >>>                 tfm.Silence(prob=0.1),
+     >>>             ),
+     >>>             ext=[f"S{i:02d}.wav"],
+     >>>         )
+     >>>         for i in range(n_sources)
+     >>>     }
+     >>>     dataset = at.datasets.AudioDataset(
+     >>>         loaders=loaders,
+     >>>         sample_rate=sample_rate,
+     >>>         duration=duration,
+     >>>         num_channels=1,
+     >>>         aligned=True,
+     >>>         transform=tfm.RescaleAudio(),
+     >>>         shuffle_loaders=False,
+     >>>     )
+     >>>
+     >>>     return dataset, list(loaders.keys())
+     >>>
+     >>> train_data, sources = build_dataset()
+     >>> dataloader = torch.utils.data.DataLoader(
+     >>>     train_data,
+     >>>     batch_size=16,
+     >>>     num_workers=0,
+     >>>     collate_fn=train_data.collate,
+     >>> )
+     >>> batch = next(iter(dataloader))
+     >>>
+     >>> for k in sources:
+     >>>     src = batch[k]
+     >>>     src["transformed"] = train_data.loaders[k].transform(
+     >>>         src["signal"].clone(), **src["transform_args"]
+     >>>     )
+     >>>
+     >>> mixture = sum(batch[k]["transformed"] for k in sources)
+     >>> mixture = train_data.transform(mixture, **batch["transform_args"])
+
+     """
+
+     def __init__(
+         self,
+         loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]],
+         sample_rate: int,
+         n_examples: int = 1000,
+         duration: float = 0.5,
+         offset: float = None,
+         loudness_cutoff: float = -40,
+         num_channels: int = 1,
+         transform: Callable = None,
+         aligned: bool = False,
+         shuffle_loaders: bool = False,
+         matcher: Callable = default_matcher,
+         without_replacement: bool = True,
+     ):
+         # Internally we convert loaders to a dictionary
+         if isinstance(loaders, list):
+             loaders = {i: l for i, l in enumerate(loaders)}
+         elif isinstance(loaders, AudioLoader):
+             loaders = {0: loaders}
+
+         self.loaders = loaders
+         self.loudness_cutoff = loudness_cutoff
+         self.num_channels = num_channels
+
+         self.length = n_examples
+         self.transform = transform
+         self.sample_rate = sample_rate
+         self.duration = duration
+         self.offset = offset
+         self.aligned = aligned
+         self.shuffle_loaders = shuffle_loaders
+         self.without_replacement = without_replacement
+
+         if aligned:
+             loaders_list = list(loaders.values())
+             for i in range(len(loaders_list[0].audio_lists)):
+                 input_lists = [l.audio_lists[i] for l in loaders_list]
+                 # Alignment happens in-place
+                 align_lists(input_lists, matcher)
+
+     def __getitem__(self, idx):
+         state = util.random_state(idx)
+         offset = None if self.offset is None else self.offset
+         item = {}
+
+         keys = list(self.loaders.keys())
+         if self.shuffle_loaders:
+             state.shuffle(keys)
+
+         loader_kwargs = {
+             "state": state,
+             "sample_rate": self.sample_rate,
+             "duration": self.duration,
+             "loudness_cutoff": self.loudness_cutoff,
+             "num_channels": self.num_channels,
+             "global_idx": idx if self.without_replacement else None,
+         }
+
+         # Draw item from first loader
+         loader = self.loaders[keys[0]]
+         item[keys[0]] = loader(**loader_kwargs)
+
+         for key in keys[1:]:
+             loader = self.loaders[key]
+             if self.aligned:
+                 # Path mapper takes the current loader + everything
+                 # returned by the first loader.
+                 offset = item[keys[0]]["signal"].metadata["offset"]
+                 loader_kwargs.update(
+                     {
+                         "offset": offset,
+                         "source_idx": item[keys[0]]["source_idx"],
+                         "item_idx": item[keys[0]]["item_idx"],
+                     }
+                 )
+             item[key] = loader(**loader_kwargs)
+
+         # Sort dictionary back into original order
+         keys = list(self.loaders.keys())
+         item = {k: item[k] for k in keys}
+
+         item["idx"] = idx
+         if self.transform is not None:
+             item["transform_args"] = self.transform.instantiate(
+                 state=state, signal=item[keys[0]]["signal"]
+             )
+
+         # If there's only one loader, pop it up
+         # to the main dictionary, instead of keeping it
+         # nested.
+         if len(keys) == 1:
+             item.update(item.pop(keys[0]))
+
+         return item
+
+     def __len__(self):
+         return self.length
+
+     @staticmethod
+     def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
+         """Collates items drawn from this dataset. Uses
+         :py:func:`audiotools.core.util.collate`.
+
+         Parameters
+         ----------
+         list_of_dicts : typing.Union[list, dict]
+             Data drawn from each item.
+         n_splits : int
+             Number of splits to make when creating the batches (split into
+             sub-batches). Useful for things like gradient accumulation.
+
+         Returns
+         -------
+         dict
+             Dictionary of batched data.
+         """
+         return util.collate(list_of_dicts, n_splits=n_splits)
+
+
+ class ConcatDataset(AudioDataset):
+     def __init__(self, datasets: list):
+         self.datasets = datasets
+
+     def __len__(self):
+         return sum([len(d) for d in self.datasets])
+
+     def __getitem__(self, idx):
+         dataset = self.datasets[idx % len(self.datasets)]
+         return dataset[idx // len(self.datasets)]
+
+
+ class ResumableDistributedSampler(DistributedSampler):  # pragma: no cover
+     """Distributed sampler that can be resumed from a given start index."""
+
+     def __init__(self, dataset, start_idx: int = None, **kwargs):
+         super().__init__(dataset, **kwargs)
+         # Start index; allows resuming an experiment from the index it stopped at.
+         self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0
+
+     def __iter__(self):
+         for i, idx in enumerate(super().__iter__()):
+             if i >= self.start_idx:
+                 yield idx
+         self.start_idx = 0  # reset the start index for the next epoch
+
+
+ class ResumableSequentialSampler(SequentialSampler):  # pragma: no cover
+     """Sequential sampler that can be resumed from a given start index."""
+
+     def __init__(self, dataset, start_idx: int = None, **kwargs):
+         super().__init__(dataset, **kwargs)
+         # Start index; allows resuming an experiment from the index it stopped at.
+         self.start_idx = start_idx if start_idx is not None else 0
+
+     def __iter__(self):
+         for i, idx in enumerate(super().__iter__()):
+             if i >= self.start_idx:
+                 yield idx
+         self.start_idx = 0  # reset the start index for the next epoch
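[Editor's note: the resumable samplers above carry no docstring example, so here is a minimal sketch of resuming a run mid-epoch. The ``audio/`` folder is hypothetical and this snippet is not part of the commit.]

    import torch
    from audiotools.data.datasets import (
        AudioDataset,
        AudioLoader,
        ResumableSequentialSampler,
    )

    dataset = AudioDataset(AudioLoader(sources=["audio/"]), sample_rate=44100)
    # Skip the first 100 items, e.g. after reloading a checkpoint.
    sampler = ResumableSequentialSampler(dataset, start_idx=100)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=4, sampler=sampler, collate_fn=dataset.collate
    )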
dac-vae/audiotools/data/preprocess.py ADDED
@@ -0,0 +1,81 @@
+ import csv
+ import os
+ from pathlib import Path
+
+ from tqdm import tqdm
+
+ from ..core import AudioSignal
+
+
+ def create_csv(
+     audio_files: list, output_csv: Path, loudness: bool = False, data_path: str = None
+ ):
+     """Converts a folder of audio files to a CSV file. If ``loudness = True``,
+     the output of this function will create a CSV file that looks something
+     like:
+
+     .. csv-table::
+         :header: path,loudness
+
+         daps/produced/f1_script1_produced.wav,-16.299999237060547
+         daps/produced/f1_script2_produced.wav,-16.600000381469727
+         daps/produced/f1_script3_produced.wav,-17.299999237060547
+         daps/produced/f1_script4_produced.wav,-16.100000381469727
+         daps/produced/f1_script5_produced.wav,-16.700000762939453
+         daps/produced/f3_script1_produced.wav,-16.5
+
+     .. note::
+         The paths above are written relative to the ``data_path`` argument
+         which defaults to the environment variable ``PATH_TO_DATA`` if
+         it isn't passed to this function, and defaults to the empty string
+         if that environment variable is not set.
+
+     You can produce a CSV file from a directory of audio files via:
+
+     >>> import audiotools
+     >>> directory = ...
+     >>> audio_files = audiotools.util.find_audio(directory)
+     >>> output_csv = "train.csv"
+     >>> audiotools.data.preprocess.create_csv(
+     >>>     audio_files, output_csv, loudness=True
+     >>> )
+
+     Note that you can create empty rows in the CSV file by passing an empty
+     string or None in the ``audio_files`` list. This is useful if you want to
+     sync multiple CSV files in a multitrack setting. The loudness of these
+     empty rows will be set to -inf.
+
+     Parameters
+     ----------
+     audio_files : list
+         List of audio files.
+     output_csv : Path
+         Output CSV, with each row containing the relative path of every file
+         to ``data_path``, if specified (defaults to None).
+     loudness : bool
+         Compute loudness of entire file and store alongside path.
+     data_path : str, optional
+         Path that the audio file paths are written relative to in the CSV,
+         by default None.
+     """
+
+     info = []
+     pbar = tqdm(audio_files)
+     for af in pbar:
+         af = Path(af)
+         pbar.set_description(f"Processing {af.name}")
+         _info = {}
+         if af.name == "":
+             _info["path"] = ""
+             if loudness:
+                 _info["loudness"] = -float("inf")
+         else:
+             _info["path"] = af.relative_to(data_path) if data_path is not None else af
+             if loudness:
+                 _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item()
+
+         info.append(_info)
+
+     with open(output_csv, "w") as f:
+         writer = csv.DictWriter(f, fieldnames=list(info[0].keys()))
+         writer.writeheader()
+
+         for item in info:
+             writer.writerow(item)
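[Editor's note: a short sketch of the ``data_path`` behavior described in the docstring above. The directory paths are hypothetical and this snippet is not part of the commit.]

    import audiotools

    # Paths in the CSV are written relative to data_path, so the file stays
    # portable across machines that mount the data at different roots.
    files = audiotools.util.find_audio("/data/daps/produced")
    audiotools.data.preprocess.create_csv(
        files, "train.csv", loudness=True, data_path="/data"
    )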
dac-vae/audiotools/data/transforms.py ADDED
@@ -0,0 +1,1592 @@
1
+ import copy
2
+ from contextlib import contextmanager
3
+ from inspect import signature
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import torch
8
+ from flatten_dict import flatten
9
+ from flatten_dict import unflatten
10
+ from numpy.random import RandomState
11
+
12
+ from .. import ml
13
+ from ..core import AudioSignal
14
+ from ..core import util
15
+ from .datasets import AudioLoader
16
+
17
+ tt = torch.tensor
18
+ """Shorthand for converting things to torch.tensor."""
19
+
20
+
21
+ class BaseTransform:
22
+ """This is the base class for all transforms that are implemented
23
+ in this library. Transforms have two main operations: ``transform``
24
+ and ``instantiate``.
25
+
26
+ ``instantiate`` sets the parameters randomly
27
+ from distribution tuples for each parameter. For example, for the
28
+ ``BackgroundNoise`` transform, the signal-to-noise ratio (``snr``)
29
+ is chosen randomly by instantiate. By default, it chosen uniformly
30
+ between 10.0 and 30.0 (the tuple is set to ``("uniform", 10.0, 30.0)``).
31
+
32
+ ``transform`` applies the transform using the instantiated parameters.
33
+ A simple example is as follows:
34
+
35
+ >>> seed = 0
36
+ >>> signal = ...
37
+ >>> transform = transforms.NoiseFloor(db = ("uniform", -50.0, -30.0))
38
+ >>> kwargs = transform.instantiate()
39
+ >>> output = transform(signal.clone(), **kwargs)
40
+
41
+ By breaking apart the instantiation of parameters from the actual audio
42
+ processing of the transform, we can make things more reproducible, while
43
+ also applying the transform on batches of data efficiently on GPU,
44
+ rather than on individual audio samples.
45
+
46
+ .. note::
47
+ We call ``signal.clone()`` for the input to the ``transform`` function
48
+ because signals are modified in-place! If you don't clone the signal,
49
+ you will lose the original data.
50
+
51
+ Parameters
52
+ ----------
53
+ keys : list, optional
54
+ Keys that the transform looks for when
55
+ calling ``self.transform``, by default []. In general this is
56
+ set automatically, and you won't need to manipulate this argument.
57
+ name : str, optional
58
+ Name of this transform, used to identify it in the dictionary
59
+ produced by ``self.instantiate``, by default None
60
+ prob : float, optional
61
+ Probability of applying this transform, by default 1.0
62
+
63
+ Examples
64
+ --------
65
+
66
+ >>> seed = 0
67
+ >>>
68
+ >>> audio_path = "tests/audio/spk/f10_script4_produced.wav"
69
+ >>> signal = AudioSignal(audio_path, offset=10, duration=2)
70
+ >>> transform = tfm.Compose(
71
+ >>> [
72
+ >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
73
+ >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
74
+ >>> ],
75
+ >>> )
76
+ >>>
77
+ >>> kwargs = transform.instantiate(seed, signal)
78
+ >>> output = transform(signal, **kwargs)
79
+
80
+ """
81
+
82
+ def __init__(self, keys: list = [], name: str = None, prob: float = 1.0):
83
+ # Get keys from the _transform signature.
84
+ tfm_keys = list(signature(self._transform).parameters.keys())
85
+
86
+ # Filter out signal and kwargs keys.
87
+ ignore_keys = ["signal", "kwargs"]
88
+ tfm_keys = [k for k in tfm_keys if k not in ignore_keys]
89
+
90
+ # Combine keys specified by the child class, the keys found in
91
+ # _transform signature, and the mask key.
92
+ self.keys = keys + tfm_keys + ["mask"]
93
+
94
+ self.prob = prob
95
+
96
+ if name is None:
97
+ name = self.__class__.__name__
98
+ self.name = name
99
+
100
+ def _prepare(self, batch: dict):
101
+ sub_batch = batch[self.name]
102
+
103
+ for k in self.keys:
104
+ assert k in sub_batch.keys(), f"{k} not in batch"
105
+
106
+ return sub_batch
107
+
108
+ def _transform(self, signal):
109
+ return signal
110
+
111
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
112
+ return {}
113
+
114
+ @staticmethod
115
+ def apply_mask(batch: dict, mask: torch.Tensor):
116
+ """Applies a mask to the batch.
117
+
118
+ Parameters
119
+ ----------
120
+ batch : dict
121
+ Batch whose values will be masked in the ``transform`` pass.
122
+ mask : torch.Tensor
123
+ Mask to apply to batch.
124
+
125
+ Returns
126
+ -------
127
+ dict
128
+ A dictionary that contains values only where ``mask = True``.
129
+ """
130
+ masked_batch = {k: v[mask] for k, v in flatten(batch).items()}
131
+ return unflatten(masked_batch)
132
+
133
+ def transform(self, signal: AudioSignal, **kwargs):
134
+ """Apply the transform to the audio signal,
135
+ with given keyword arguments.
136
+
137
+ Parameters
138
+ ----------
139
+ signal : AudioSignal
140
+ Signal that will be modified by the transforms in-place.
141
+ kwargs: dict
142
+ Keyword arguments to the specific transforms ``self._transform``
143
+ function.
144
+
145
+ Returns
146
+ -------
147
+ AudioSignal
148
+ Transformed AudioSignal.
149
+
150
+ Examples
151
+ --------
152
+
153
+ >>> for seed in range(10):
154
+ >>> kwargs = transform.instantiate(seed, signal)
155
+ >>> output = transform(signal.clone(), **kwargs)
156
+
157
+ """
158
+ tfm_kwargs = self._prepare(kwargs)
159
+ mask = tfm_kwargs["mask"]
160
+
161
+ if torch.any(mask):
162
+ tfm_kwargs = self.apply_mask(tfm_kwargs, mask)
163
+ tfm_kwargs = {k: v for k, v in tfm_kwargs.items() if k != "mask"}
164
+ signal[mask] = self._transform(signal[mask], **tfm_kwargs)
165
+
166
+ return signal
167
+
168
+ def __call__(self, *args, **kwargs):
169
+ return self.transform(*args, **kwargs)
170
+
171
+ def instantiate(
172
+ self,
173
+ state: RandomState = None,
174
+ signal: AudioSignal = None,
175
+ ):
176
+ """Instantiates parameters for the transform.
177
+
178
+ Parameters
179
+ ----------
180
+ state : RandomState, optional
181
+ _description_, by default None
182
+ signal : AudioSignal, optional
183
+ _description_, by default None
184
+
185
+ Returns
186
+ -------
187
+ dict
188
+ Dictionary containing instantiated arguments for every keyword
189
+ argument to ``self._transform``.
190
+
191
+ Examples
192
+ --------
193
+
194
+ >>> for seed in range(10):
195
+ >>> kwargs = transform.instantiate(seed, signal)
196
+ >>> output = transform(signal.clone(), **kwargs)
197
+
198
+ """
199
+ state = util.random_state(state)
200
+
201
+ # Not all instantiates need the signal. Check if signal
202
+ # is needed before passing it in, so that the end-user
203
+ # doesn't need to have variables they're not using flowing
204
+ # into their function.
205
+ needs_signal = "signal" in set(signature(self._instantiate).parameters.keys())
206
+ kwargs = {}
207
+ if needs_signal:
208
+ kwargs = {"signal": signal}
209
+
210
+ # Instantiate the parameters for the transform.
211
+ params = self._instantiate(state, **kwargs)
212
+ for k in list(params.keys()):
213
+ v = params[k]
214
+ if isinstance(v, (AudioSignal, torch.Tensor, dict)):
215
+ params[k] = v
216
+ else:
217
+ params[k] = tt(v)
218
+ mask = state.rand() <= self.prob
219
+ params[f"mask"] = tt(mask)
220
+
221
+ # Put the params into a nested dictionary that will be
222
+ # used later when calling the transform. This is to avoid
223
+ # collisions in the dictionary.
224
+ params = {self.name: params}
225
+
226
+ return params
227
+
228
+ def batch_instantiate(
229
+ self,
230
+ states: list = None,
231
+ signal: AudioSignal = None,
232
+ ):
233
+ """Instantiates arguments for every item in a batch,
234
+ given a list of states. Each state in the list
235
+ corresponds to one item in the batch.
236
+
237
+ Parameters
238
+ ----------
239
+ states : list, optional
240
+ List of states, by default None
241
+ signal : AudioSignal, optional
242
+ AudioSignal to pass to the ``self.instantiate`` section
243
+ if it is needed for this transform, by default None
244
+
245
+ Returns
246
+ -------
247
+ dict
248
+ Collated dictionary of arguments.
249
+
250
+ Examples
251
+ --------
252
+
253
+ >>> batch_size = 4
254
+ >>> signal = AudioSignal(audio_path, offset=10, duration=2)
255
+ >>> signal_batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)])
256
+ >>>
257
+ >>> states = [seed + idx for idx in list(range(batch_size))]
258
+ >>> kwargs = transform.batch_instantiate(states, signal_batch)
259
+ >>> batch_output = transform(signal_batch, **kwargs)
260
+ """
261
+ kwargs = []
262
+ for state in states:
263
+ kwargs.append(self.instantiate(state, signal))
264
+ kwargs = util.collate(kwargs)
265
+ return kwargs
266
+
267
+
268
+ class Identity(BaseTransform):
269
+ """This transform just returns the original signal."""
270
+
271
+ pass
272
+
273
+
274
+ class SpectralTransform(BaseTransform):
275
+ """Spectral transforms require STFT data to exist, since manipulations
276
+ of the STFT require the spectrogram. This just calls ``stft`` before
277
+ the transform is called, and calls ``istft`` after the transform is
278
+ called so that the audio data is written to after the spectral
279
+ manipulation.
280
+ """
281
+
282
+ def transform(self, signal, **kwargs):
283
+ signal.stft()
284
+ super().transform(signal, **kwargs)
285
+ signal.istft()
286
+ return signal
287
+
288
+
289
+ class Compose(BaseTransform):
290
+ """Compose applies transforms in sequence, one after the other. The
291
+ transforms are passed in as positional arguments or as a list like so:
292
+
293
+ >>> transform = tfm.Compose(
294
+ >>> [
295
+ >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
296
+ >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
297
+ >>> ],
298
+ >>> )
299
+
300
+ This will convolve the signal with a room impulse response, and then
301
+ add background noise to the signal. Instantiate instantiates
302
+ all the parameters for every transform in the transform list so the
303
+ interface for using the Compose transform is the same as everything
304
+ else:
305
+
306
+ >>> kwargs = transform.instantiate()
307
+ >>> output = transform(signal.clone(), **kwargs)
308
+
309
+ Under the hood, the transform maps each transform to a unique name
310
+ under the hood of the form ``{position}.{name}``, where ``position``
311
+ is the index of the transform in the list. ``Compose`` can nest
312
+ within other ``Compose`` transforms, like so:
313
+
314
+ >>> preprocess = transforms.Compose(
315
+ >>> tfm.GlobalVolumeNorm(),
316
+ >>> tfm.CrossTalk(),
317
+ >>> name="preprocess",
318
+ >>> )
319
+ >>> augment = transforms.Compose(
320
+ >>> tfm.RoomImpulseResponse(),
321
+ >>> tfm.BackgroundNoise(),
322
+ >>> name="augment",
323
+ >>> )
324
+ >>> postprocess = transforms.Compose(
325
+ >>> tfm.VolumeChange(),
326
+ >>> tfm.RescaleAudio(),
327
+ >>> tfm.ShiftPhase(),
328
+ >>> name="postprocess",
329
+ >>> )
330
+ >>> transform = transforms.Compose(preprocess, augment, postprocess),
331
+
332
+ This defines 3 composed transforms, and then composes them in sequence
333
+ with one another.
334
+
335
+ Parameters
336
+ ----------
337
+ *transforms : list
338
+ List of transforms to apply
339
+ name : str, optional
340
+ Name of this transform, used to identify it in the dictionary
341
+ produced by ``self.instantiate``, by default None
342
+ prob : float, optional
343
+ Probability of applying this transform, by default 1.0
344
+ """
345
+
346
+ def __init__(self, *transforms: list, name: str = None, prob: float = 1.0):
347
+ if isinstance(transforms[0], list):
348
+ transforms = transforms[0]
349
+
350
+ for i, tfm in enumerate(transforms):
351
+ tfm.name = f"{i}.{tfm.name}"
352
+
353
+ keys = [tfm.name for tfm in transforms]
354
+ super().__init__(keys=keys, name=name, prob=prob)
355
+
356
+ self.transforms = transforms
357
+ self.transforms_to_apply = keys
358
+
359
+ @contextmanager
360
+ def filter(self, *names: list):
361
+ """This can be used to skip transforms entirely when applying
362
+ the sequence of transforms to a signal. For example, take
363
+ the following transforms with the names ``preprocess, augment, postprocess``.
364
+
365
+ >>> preprocess = transforms.Compose(
366
+ >>> tfm.GlobalVolumeNorm(),
367
+ >>> tfm.CrossTalk(),
368
+ >>> name="preprocess",
369
+ >>> )
370
+ >>> augment = transforms.Compose(
371
+ >>> tfm.RoomImpulseResponse(),
372
+ >>> tfm.BackgroundNoise(),
373
+ >>> name="augment",
374
+ >>> )
375
+ >>> postprocess = transforms.Compose(
376
+ >>> tfm.VolumeChange(),
377
+ >>> tfm.RescaleAudio(),
378
+ >>> tfm.ShiftPhase(),
379
+ >>> name="postprocess",
380
+ >>> )
381
+ >>> transform = transforms.Compose(preprocess, augment, postprocess)
382
+
383
+ If we wanted to apply all 3 to a signal, we do:
384
+
385
+ >>> kwargs = transform.instantiate()
386
+ >>> output = transform(signal.clone(), **kwargs)
387
+
388
+ But if we only wanted to apply the ``preprocess`` and ``postprocess``
389
+ transforms to the signal, we do:
390
+
391
+ >>> with transform_fn.filter("preprocess", "postprocess"):
392
+ >>> output = transform(signal.clone(), **kwargs)
393
+
394
+ Parameters
395
+ ----------
396
+ *names : list
397
+ List of transforms, identified by name, to apply to signal.
398
+ """
399
+ old_transforms = self.transforms_to_apply
400
+ self.transforms_to_apply = names
401
+ yield
402
+ self.transforms_to_apply = old_transforms
403
+
404
+ def _transform(self, signal, **kwargs):
405
+ for transform in self.transforms:
406
+ if any([x in transform.name for x in self.transforms_to_apply]):
407
+ signal = transform(signal, **kwargs)
408
+ return signal
409
+
410
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
411
+ parameters = {}
412
+ for transform in self.transforms:
413
+ parameters.update(transform.instantiate(state, signal=signal))
414
+ return parameters
415
+
416
+ def __getitem__(self, idx):
417
+ return self.transforms[idx]
418
+
419
+ def __len__(self):
420
+ return len(self.transforms)
421
+
422
+ def __iter__(self):
423
+ for transform in self.transforms:
424
+ yield transform
425
+
426
+
427
+ class Choose(Compose):
428
+ """Choose logic is the same as :py:func:`audiotools.data.transforms.Compose`,
429
+ but instead of applying all the transforms in sequence, it applies just a single transform,
430
+ which is chosen for each item in the batch.
431
+
432
+ Parameters
433
+ ----------
434
+ *transforms : list
435
+ List of transforms to apply
436
+ weights : list
437
+ Probability of choosing any specific transform.
438
+ name : str, optional
439
+ Name of this transform, used to identify it in the dictionary
440
+ produced by ``self.instantiate``, by default None
441
+ prob : float, optional
442
+ Probability of applying this transform, by default 1.0
443
+
444
+ Examples
445
+ --------
446
+
447
+ >>> transforms.Choose(tfm.LowPass(), tfm.HighPass())
448
+ """
449
+
450
+ def __init__(
451
+ self,
452
+ *transforms: list,
453
+ weights: list = None,
454
+ name: str = None,
455
+ prob: float = 1.0,
456
+ ):
457
+ super().__init__(*transforms, name=name, prob=prob)
458
+
459
+ if weights is None:
460
+ _len = len(self.transforms)
461
+ weights = [1 / _len for _ in range(_len)]
462
+ self.weights = np.array(weights)
463
+
464
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
465
+ kwargs = super()._instantiate(state, signal)
466
+ tfm_idx = list(range(len(self.transforms)))
467
+ tfm_idx = state.choice(tfm_idx, p=self.weights)
468
+ one_hot = []
469
+ for i, t in enumerate(self.transforms):
470
+ mask = kwargs[t.name]["mask"]
471
+ if mask.item():
472
+ kwargs[t.name]["mask"] = tt(i == tfm_idx)
473
+ one_hot.append(kwargs[t.name]["mask"])
474
+ kwargs["one_hot"] = one_hot
475
+ return kwargs
476
+
477
+
478
+ class Repeat(Compose):
479
+ """Repeatedly applies a given transform ``n_repeat`` times."
480
+
481
+ Parameters
482
+ ----------
483
+ transform : BaseTransform
484
+ Transform to repeat.
485
+ n_repeat : int, optional
486
+ Number of times to repeat transform, by default 1
487
+ """
488
+
489
+ def __init__(
490
+ self,
491
+ transform,
492
+ n_repeat: int = 1,
493
+ name: str = None,
494
+ prob: float = 1.0,
495
+ ):
496
+ transforms = [copy.copy(transform) for _ in range(n_repeat)]
497
+ super().__init__(transforms, name=name, prob=prob)
498
+
499
+ self.n_repeat = n_repeat
500
+
501
+
502
+ class RepeatUpTo(Choose):
503
+ """Repeatedly applies a given transform up to ``max_repeat`` times."
504
+
505
+ Parameters
506
+ ----------
507
+ transform : BaseTransform
508
+ Transform to repeat.
509
+ max_repeat : int, optional
510
+ Max number of times to repeat transform, by default 1
511
+ weights : list
512
+ Probability of choosing any specific number up to ``max_repeat``.
513
+ """
514
+
515
+ def __init__(
516
+ self,
517
+ transform,
518
+ max_repeat: int = 5,
519
+ weights: list = None,
520
+ name: str = None,
521
+ prob: float = 1.0,
522
+ ):
523
+ transforms = []
524
+ for n in range(1, max_repeat):
525
+ transforms.append(Repeat(transform, n_repeat=n))
526
+ super().__init__(transforms, name=name, prob=prob, weights=weights)
527
+
528
+ self.max_repeat = max_repeat
529
+
530
+
531
+ class ClippingDistortion(BaseTransform):
532
+ """Adds clipping distortion to signal. Corresponds
533
+ to :py:func:`audiotools.core.effects.EffectMixin.clip_distortion`.
534
+
535
+ Parameters
536
+ ----------
537
+ perc : tuple, optional
538
+ Clipping percentile. Values are between 0.0 to 1.0.
539
+ Typical values are 0.1 or below, by default ("uniform", 0.0, 0.1)
540
+ name : str, optional
541
+ Name of this transform, used to identify it in the dictionary
542
+ produced by ``self.instantiate``, by default None
543
+ prob : float, optional
544
+ Probability of applying this transform, by default 1.0
545
+ """
546
+
547
+ def __init__(
548
+ self,
549
+ perc: tuple = ("uniform", 0.0, 0.1),
550
+ name: str = None,
551
+ prob: float = 1.0,
552
+ ):
553
+ super().__init__(name=name, prob=prob)
554
+
555
+ self.perc = perc
556
+
557
+ def _instantiate(self, state: RandomState):
558
+ return {"perc": util.sample_from_dist(self.perc, state)}
559
+
560
+ def _transform(self, signal, perc):
561
+ return signal.clip_distortion(perc)
562
+
563
+
564
+ class Equalizer(BaseTransform):
565
+ """Applies an equalization curve to the audio signal. Corresponds
566
+ to :py:func:`audiotools.core.effects.EffectMixin.equalizer`.
567
+
568
+ Parameters
569
+ ----------
570
+ eq_amount : tuple, optional
571
+ The maximum dB cut to apply to the audio in any band,
572
+ by default ("const", 1.0 dB)
573
+ n_bands : int, optional
574
+ Number of bands in EQ, by default 6
575
+ name : str, optional
576
+ Name of this transform, used to identify it in the dictionary
577
+ produced by ``self.instantiate``, by default None
578
+ prob : float, optional
579
+ Probability of applying this transform, by default 1.0
580
+ """
581
+
582
+ def __init__(
583
+ self,
584
+ eq_amount: tuple = ("const", 1.0),
585
+ n_bands: int = 6,
586
+ name: str = None,
587
+ prob: float = 1.0,
588
+ ):
589
+ super().__init__(name=name, prob=prob)
590
+
591
+ self.eq_amount = eq_amount
592
+ self.n_bands = n_bands
593
+
594
+ def _instantiate(self, state: RandomState):
595
+ eq_amount = util.sample_from_dist(self.eq_amount, state)
596
+ eq = -eq_amount * state.rand(self.n_bands)
597
+ return {"eq": eq}
598
+
599
+ def _transform(self, signal, eq):
600
+ return signal.equalizer(eq)
601
+
602
+
603
+ class Quantization(BaseTransform):
604
+ """Applies quantization to the input waveform. Corresponds
605
+ to :py:func:`audiotools.core.effects.EffectMixin.quantization`.
606
+
607
+ Parameters
608
+ ----------
609
+ channels : tuple, optional
610
+ Number of evenly spaced quantization channels to quantize
611
+ to, by default ("choice", [8, 32, 128, 256, 1024])
612
+ name : str, optional
613
+ Name of this transform, used to identify it in the dictionary
614
+ produced by ``self.instantiate``, by default None
615
+ prob : float, optional
616
+ Probability of applying this transform, by default 1.0
617
+ """
618
+
619
+ def __init__(
620
+ self,
621
+ channels: tuple = ("choice", [8, 32, 128, 256, 1024]),
622
+ name: str = None,
623
+ prob: float = 1.0,
624
+ ):
625
+ super().__init__(name=name, prob=prob)
626
+
627
+ self.channels = channels
628
+
629
+ def _instantiate(self, state: RandomState):
630
+ return {"channels": util.sample_from_dist(self.channels, state)}
631
+
632
+ def _transform(self, signal, channels):
633
+ return signal.quantization(channels)
634
+
635
+
636
+ class MuLawQuantization(BaseTransform):
637
+ """Applies mu-law quantization to the input waveform. Corresponds
638
+ to :py:func:`audiotools.core.effects.EffectMixin.mulaw_quantization`.
639
+
640
+ Parameters
641
+ ----------
642
+ channels : tuple, optional
643
+ Number of mu-law spaced quantization channels to quantize
644
+ to, by default ("choice", [8, 32, 128, 256, 1024])
645
+ name : str, optional
646
+ Name of this transform, used to identify it in the dictionary
647
+ produced by ``self.instantiate``, by default None
648
+ prob : float, optional
649
+ Probability of applying this transform, by default 1.0
650
+ """
651
+
652
+ def __init__(
653
+ self,
654
+ channels: tuple = ("choice", [8, 32, 128, 256, 1024]),
655
+ name: str = None,
656
+ prob: float = 1.0,
657
+ ):
658
+ super().__init__(name=name, prob=prob)
659
+
660
+ self.channels = channels
661
+
662
+ def _instantiate(self, state: RandomState):
663
+ return {"channels": util.sample_from_dist(self.channels, state)}
664
+
665
+ def _transform(self, signal, channels):
666
+ return signal.mulaw_quantization(channels)
667
+
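The two quantizers differ only in how the levels are spaced: ``Quantization`` uses evenly spaced levels, while ``MuLawQuantization`` spaces them on the mu-law curve, which is gentler on quiet passages at low channel counts. A rough sketch contrasting them, under the same wrapper-API assumptions as the snippet above:

    import torch
    from audiotools import AudioSignal
    from audiotools.data import transforms as tfm

    signal = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)

    q = tfm.Quantization(channels=("const", 8))
    mq = tfm.MuLawQuantization(channels=("const", 8))
    q_out = q(signal.clone(), **q.instantiate(0))     # 8 evenly spaced levels
    mq_out = mq(signal.clone(), **mq.instantiate(0))  # 8 mu-law spaced levels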
668
+
669
+ class NoiseFloor(BaseTransform):
670
+ """Adds a noise floor of Gaussian noise to the signal at a specified
671
+ dB.
672
+
673
+ Parameters
674
+ ----------
675
+ db : tuple, optional
676
+ Level of noise to add to signal, by default ("const", -50.0)
677
+ name : str, optional
678
+ Name of this transform, used to identify it in the dictionary
679
+ produced by ``self.instantiate``, by default None
680
+ prob : float, optional
681
+ Probability of applying this transform, by default 1.0
682
+ """
683
+
684
+ def __init__(
685
+ self,
686
+ db: tuple = ("const", -50.0),
687
+ name: str = None,
688
+ prob: float = 1.0,
689
+ ):
690
+ super().__init__(name=name, prob=prob)
691
+
692
+ self.db = db
693
+
694
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
695
+ db = util.sample_from_dist(self.db, state)
696
+ audio_data = state.randn(signal.num_channels, signal.signal_length)
697
+ nz_signal = AudioSignal(audio_data, signal.sample_rate)
698
+ nz_signal.normalize(db)
699
+ return {"nz_signal": nz_signal}
700
+
701
+ def _transform(self, signal, nz_signal):
702
+ # Adding nz_signal leaves it unmodified, so the transform can be
703
+ # repeatedly applied to different signals with the same effect.
704
+ return signal + nz_signal
705
+
706
+
707
+ class BackgroundNoise(BaseTransform):
708
+ """Adds background noise from audio specified by a set of CSV files.
709
+ A valid CSV file looks like, and is typically generated by
710
+ :py:func:`audiotools.data.preprocess.create_csv`:
711
+
712
+ .. csv-table::
713
+ :header: path
714
+
715
+ room_tone/m6_script2_clean.wav
716
+ room_tone/m6_script2_cleanraw.wav
717
+ room_tone/m6_script2_ipad_balcony1.wav
718
+ room_tone/m6_script2_ipad_bedroom1.wav
719
+ room_tone/m6_script2_ipad_confroom1.wav
720
+ room_tone/m6_script2_ipad_confroom2.wav
721
+ room_tone/m6_script2_ipad_livingroom1.wav
722
+ room_tone/m6_script2_ipad_office1.wav
723
+
724
+ .. note::
725
+ All paths are relative to an environment variable called ``PATH_TO_DATA``,
726
+ so that CSV files are portable across machines where data may be
727
+ located in different places.
728
+
729
+ This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix`
730
+ and :py:func:`audiotools.core.effects.EffectMixin.equalizer` under the
731
+ hood.
732
+
733
+ Parameters
734
+ ----------
735
+ snr : tuple, optional
736
+ Signal-to-noise ratio, by default ("uniform", 10.0, 30.0)
737
+ sources : List[str], optional
738
+ Sources containing folders, or CSVs with paths to audio files,
739
+ by default None
740
+ weights : List[float], optional
741
+ Weights to sample audio files from each source, by default None
742
+ eq_amount : tuple, optional
743
+ Amount of equalization to apply, by default ("const", 1.0)
744
+ n_bands : int, optional
745
+ Number of bands in equalizer, by default 3
746
+ name : str, optional
747
+ Name of this transform, used to identify it in the dictionary
748
+ produced by ``self.instantiate``, by default None
749
+ prob : float, optional
750
+ Probability of applying this transform, by default 1.0
751
+ loudness_cutoff : float, optional
752
+ Loudness cutoff when loading from audio files, by default None
753
+ """
754
+
755
+ def __init__(
756
+ self,
757
+ snr: tuple = ("uniform", 10.0, 30.0),
758
+ sources: List[str] = None,
759
+ weights: List[float] = None,
760
+ eq_amount: tuple = ("const", 1.0),
761
+ n_bands: int = 3,
762
+ name: str = None,
763
+ prob: float = 1.0,
764
+ loudness_cutoff: float = None,
765
+ ):
766
+ super().__init__(name=name, prob=prob)
767
+
768
+ self.snr = snr
769
+ self.eq_amount = eq_amount
770
+ self.n_bands = n_bands
771
+ self.loader = AudioLoader(sources, weights)
772
+ self.loudness_cutoff = loudness_cutoff
773
+
774
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
775
+ eq_amount = util.sample_from_dist(self.eq_amount, state)
776
+ eq = -eq_amount * state.rand(self.n_bands)
777
+ snr = util.sample_from_dist(self.snr, state)
778
+
779
+ bg_signal = self.loader(
780
+ state,
781
+ signal.sample_rate,
782
+ duration=signal.signal_duration,
783
+ loudness_cutoff=self.loudness_cutoff,
784
+ num_channels=signal.num_channels,
785
+ )["signal"]
786
+
787
+ return {"eq": eq, "bg_signal": bg_signal, "snr": snr}
788
+
789
+ def _transform(self, signal, bg_signal, snr, eq):
790
+ # Clone bg_signal so that transform can be repeatedly applied
791
+ # to different signals with the same effect.
792
+ return signal.mix(bg_signal.clone(), snr, eq)
793
+
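A sketch of wiring this up, with a hypothetical ``room_tone.csv`` produced by ``create_csv`` and ``PATH_TO_DATA`` pointing at the data root, as the note above describes:

    import os
    import torch
    from audiotools import AudioSignal
    from audiotools.data import transforms as tfm

    os.environ["PATH_TO_DATA"] = "/data"   # CSV paths resolve relative to this
    signal = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)

    noise = tfm.BackgroundNoise(
        snr=("uniform", 10.0, 30.0),
        sources=["room_tone.csv"],         # hypothetical CSV of noise file paths
    )
    kwargs = noise.instantiate(0, signal)  # loads an excerpt, draws EQ curve and SNR
    noisy = noise(signal.clone(), **kwargs)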
794
+
795
+ class CrossTalk(BaseTransform):
796
+ """Adds crosstalk between speakers, whose audio is drawn from a CSV file
797
+ that was produced via :py:func:`audiotools.data.preprocess.create_csv`.
798
+
799
+ This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix`
800
+ under the hood.
801
+
802
+ Parameters
803
+ ----------
804
+ snr : tuple, optional
805
+ How loud cross-talk speaker is relative to original signal in dB,
806
+ by default ("uniform", 0.0, 10.0)
807
+ sources : List[str], optional
808
+ Sources containing folders, or CSVs with paths to audio files,
809
+ by default None
810
+ weights : List[float], optional
811
+ Weights to sample audio files from each source, by default None
812
+ name : str, optional
813
+ Name of this transform, used to identify it in the dictionary
814
+ produced by ``self.instantiate``, by default None
815
+ prob : float, optional
816
+ Probability of applying this transform, by default 1.0
817
+ loudness_cutoff : float, optional
818
+ Loudness cutoff when loading from audio files, by default -40
819
+ """
820
+
821
+ def __init__(
822
+ self,
823
+ snr: tuple = ("uniform", 0.0, 10.0),
824
+ sources: List[str] = None,
825
+ weights: List[float] = None,
826
+ name: str = None,
827
+ prob: float = 1.0,
828
+ loudness_cutoff: float = -40,
829
+ ):
830
+ super().__init__(name=name, prob=prob)
831
+
832
+ self.snr = snr
833
+ self.loader = AudioLoader(sources, weights)
834
+ self.loudness_cutoff = loudness_cutoff
835
+
836
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
837
+ snr = util.sample_from_dist(self.snr, state)
838
+ crosstalk_signal = self.loader(
839
+ state,
840
+ signal.sample_rate,
841
+ duration=signal.signal_duration,
842
+ loudness_cutoff=self.loudness_cutoff,
843
+ num_channels=signal.num_channels,
844
+ )["signal"]
845
+
846
+ return {"crosstalk_signal": crosstalk_signal, "snr": snr}
847
+
848
+ def _transform(self, signal, crosstalk_signal, snr):
849
+ # Clone crosstalk_signal so that transform can be repeatedly applied
850
+ # to different signals with the same effect.
851
+ loudness = signal.loudness()
852
+ mix = signal.mix(crosstalk_signal.clone(), snr)
853
+ mix.normalize(loudness)
854
+ return mix
855
+
856
+
857
+ class RoomImpulseResponse(BaseTransform):
858
+ """Convolves signal with a room impulse response, at a specified
859
+ direct-to-reverberant ratio, with equalization applied. Room impulse
860
+ response data is drawn from a CSV file that was produced via
861
+ :py:func:`audiotools.data.preprocess.create_csv`.
862
+
863
+ This transform calls :py:func:`audiotools.core.effects.EffectMixin.apply_ir`
864
+ under the hood.
865
+
866
+ Parameters
867
+ ----------
868
+ drr : tuple, optional
869
+ Direct-to-reverberant ratio in dB, by default ("uniform", 0.0, 30.0)
870
+ sources : List[str], optional
871
+ Sources containing folders, or CSVs with paths to audio files,
872
+ by default None
873
+ weights : List[float], optional
874
+ Weights to sample audio files from each source, by default None
875
+ eq_amount : tuple, optional
876
+ Amount of equalization to apply, by default ("const", 1.0)
877
+ n_bands : int, optional
878
+ Number of bands in equalizer, by default 6
879
+ name : str, optional
880
+ Name of this transform, used to identify it in the dictionary
881
+ produced by ``self.instantiate``, by default None
882
+ prob : float, optional
883
+ Probability of applying this transform, by default 1.0
884
+ use_original_phase : bool, optional
885
+ Whether or not to use the original phase, by default False
886
+ offset : float, optional
887
+ Offset from each impulse response file to use, by default 0.0
888
+ duration : float, optional
889
+ Duration of each impulse response, by default 1.0
890
+ """
891
+
892
+ def __init__(
893
+ self,
894
+ drr: tuple = ("uniform", 0.0, 30.0),
895
+ sources: List[str] = None,
896
+ weights: List[float] = None,
897
+ eq_amount: tuple = ("const", 1.0),
898
+ n_bands: int = 6,
899
+ name: str = None,
900
+ prob: float = 1.0,
901
+ use_original_phase: bool = False,
902
+ offset: float = 0.0,
903
+ duration: float = 1.0,
904
+ ):
905
+ super().__init__(name=name, prob=prob)
906
+
907
+ self.drr = drr
908
+ self.eq_amount = eq_amount
909
+ self.n_bands = n_bands
910
+ self.use_original_phase = use_original_phase
911
+
912
+ self.loader = AudioLoader(sources, weights)
913
+ self.offset = offset
914
+ self.duration = duration
915
+
916
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
917
+ eq_amount = util.sample_from_dist(self.eq_amount, state)
918
+ eq = -eq_amount * state.rand(self.n_bands)
919
+ drr = util.sample_from_dist(self.drr, state)
920
+
921
+ ir_signal = self.loader(
922
+ state,
923
+ signal.sample_rate,
924
+ offset=self.offset,
925
+ duration=self.duration,
926
+ loudness_cutoff=None,
927
+ num_channels=signal.num_channels,
928
+ )["signal"]
929
+ ir_signal.zero_pad_to(signal.sample_rate)
930
+
931
+ return {"eq": eq, "ir_signal": ir_signal, "drr": drr}
932
+
933
+ def _transform(self, signal, ir_signal, drr, eq):
934
+ # Clone ir_signal so that transform can be repeatedly applied
935
+ # to different signals with the same effect.
936
+ return signal.apply_ir(
937
+ ir_signal.clone(), drr, eq, use_original_phase=self.use_original_phase
938
+ )
939
+
940
+
941
+ class VolumeChange(BaseTransform):
942
+ """Changes the volume of the input signal.
943
+
944
+ Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
945
+
946
+ Parameters
947
+ ----------
948
+ db : tuple, optional
949
+ Change in volume in decibels, by default ("uniform", -12.0, 0.0)
950
+ name : str, optional
951
+ Name of this transform, used to identify it in the dictionary
952
+ produced by ``self.instantiate``, by default None
953
+ prob : float, optional
954
+ Probability of applying this transform, by default 1.0
955
+ """
956
+
957
+ def __init__(
958
+ self,
959
+ db: tuple = ("uniform", -12.0, 0.0),
960
+ name: str = None,
961
+ prob: float = 1.0,
962
+ ):
963
+ super().__init__(name=name, prob=prob)
964
+ self.db = db
965
+
966
+ def _instantiate(self, state: RandomState):
967
+ return {"db": util.sample_from_dist(self.db, state)}
968
+
969
+ def _transform(self, signal, db):
970
+ return signal.volume_change(db)
971
+
972
+
973
+ class VolumeNorm(BaseTransform):
974
+ """Normalizes the volume of the excerpt to a specified decibel.
975
+
976
+ Uses :py:func:`audiotools.core.effects.EffectMixin.normalize`.
977
+
978
+ Parameters
979
+ ----------
980
+ db : tuple, optional
981
+ dB to normalize signal to, by default ("const", -24)
982
+ name : str, optional
983
+ Name of this transform, used to identify it in the dictionary
984
+ produced by ``self.instantiate``, by default None
985
+ prob : float, optional
986
+ Probability of applying this transform, by default 1.0
987
+ """
988
+
989
+ def __init__(
990
+ self,
991
+ db: tuple = ("const", -24),
992
+ name: str = None,
993
+ prob: float = 1.0,
994
+ ):
995
+ super().__init__(name=name, prob=prob)
996
+
997
+ self.db = db
998
+
999
+ def _instantiate(self, state: RandomState):
1000
+ return {"db": util.sample_from_dist(self.db, state)}
1001
+
1002
+ def _transform(self, signal, db):
1003
+ return signal.normalize(db)
1004
+
1005
+
1006
+ class GlobalVolumeNorm(BaseTransform):
1007
+ """Similar to :py:func:`audiotools.data.transforms.VolumeNorm`, this
1008
+ transform also normalizes the volume of a signal, but it uses
1009
+ the volume of the entire audio file the loaded excerpt comes from,
1010
+ rather than the volume of just the excerpt. The volume of the
1011
+ entire audio file is expected in ``signal.metadata["loudness"]``.
1012
+ If loading audio from a CSV generated by :py:func:`audiotools.data.preprocess.create_csv`
1013
+ with ``loudness = True``, like the following:
1014
+
1015
+ .. csv-table::
1016
+ :header: path,loudness
1017
+
1018
+ daps/produced/f1_script1_produced.wav,-16.299999237060547
1019
+ daps/produced/f1_script2_produced.wav,-16.600000381469727
1020
+ daps/produced/f1_script3_produced.wav,-17.299999237060547
1021
+ daps/produced/f1_script4_produced.wav,-16.100000381469727
1022
+ daps/produced/f1_script5_produced.wav,-16.700000762939453
1023
+ daps/produced/f3_script1_produced.wav,-16.5
1024
+
1025
+ The ``AudioLoader`` will automatically load the loudness column into
1026
+ the metadata of the signal.
1027
+
1028
+ Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
1029
+
1030
+ Parameters
1031
+ ----------
1032
+ db : tuple, optional
1033
+ dB to normalize signal to, by default ("const", -24)
1034
+ name : str, optional
1035
+ Name of this transform, used to identify it in the dictionary
1036
+ produced by ``self.instantiate``, by default None
1037
+ prob : float, optional
1038
+ Probability of applying this transform, by default 1.0
1039
+ """
1040
+
1041
+ def __init__(
1042
+ self,
1043
+ db: tuple = ("const", -24),
1044
+ name: str = None,
1045
+ prob: float = 1.0,
1046
+ ):
1047
+ super().__init__(name=name, prob=prob)
1048
+
1049
+ self.db = db
1050
+
1051
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
1052
+ if "loudness" not in signal.metadata:
1053
+ db_change = 0.0
1054
+ elif float(signal.metadata["loudness"]) == float("-inf"):
1055
+ db_change = 0.0
1056
+ else:
1057
+ db = util.sample_from_dist(self.db, state)
1058
+ db_change = db - float(signal.metadata["loudness"])
1059
+
1060
+ return {"db": db_change}
1061
+
1062
+ def _transform(self, signal, db):
1063
+ return signal.volume_change(db)
1064
+
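Since ``_instantiate`` returns the gain ``db - loudness``, the whole file (not just the excerpt) ends up at the target level. A small sketch, filling in ``metadata["loudness"]`` by hand where ``AudioLoader`` would normally do it:

    import torch
    from audiotools import AudioSignal
    from audiotools.data import transforms as tfm

    signal = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
    signal.metadata["loudness"] = -30.0     # normally read from the CSV's loudness column

    t = tfm.GlobalVolumeNorm(db=("const", -24))
    kwargs = t.instantiate(0, signal)       # carries db = -24 - (-30) = +6 dB of gain
    out = t(signal.clone(), **kwargs)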
1065
+
1066
+ class Silence(BaseTransform):
1067
+ """Zeros out the signal with some probability.
1068
+
1069
+ Parameters
1070
+ ----------
1071
+ name : str, optional
1072
+ Name of this transform, used to identify it in the dictionary
1073
+ produced by ``self.instantiate``, by default None
1074
+ prob : float, optional
1075
+ Probability of applying this transform, by default 0.1
1076
+ """
1077
+
1078
+ def __init__(self, name: str = None, prob: float = 0.1):
1079
+ super().__init__(name=name, prob=prob)
1080
+
1081
+ def _transform(self, signal):
1082
+ _loudness = signal._loudness
1083
+ signal = AudioSignal(
1084
+ torch.zeros_like(signal.audio_data),
1085
+ sample_rate=signal.sample_rate,
1086
+ stft_params=signal.stft_params,
1087
+ )
1088
+ # So that the amount of noise added is as if it wasn't silenced.
1089
+ # TODO: improve this hack
1090
+ signal._loudness = _loudness
1091
+
1092
+ return signal
1093
+
1094
+
1095
+ class LowPass(BaseTransform):
1096
+ """Applies a LowPass filter.
1097
+
1098
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.low_pass`.
1099
+
1100
+ Parameters
1101
+ ----------
1102
+ cutoff : tuple, optional
1103
+ Cutoff frequency distribution,
1104
+ by default ``("choice", [4000, 8000, 16000])``
1105
+ zeros : int, optional
1106
+ Number of zero-crossings in filter, argument to
1107
+ ``julius.LowPassFilters``, by default 51
1108
+ name : str, optional
1109
+ Name of this transform, used to identify it in the dictionary
1110
+ produced by ``self.instantiate``, by default None
1111
+ prob : float, optional
1112
+ Probability of applying this transform, by default 1.0
1113
+ """
1114
+
1115
+ def __init__(
1116
+ self,
1117
+ cutoff: tuple = ("choice", [4000, 8000, 16000]),
1118
+ zeros: int = 51,
1119
+ name: str = None,
1120
+ prob: float = 1,
1121
+ ):
1122
+ super().__init__(name=name, prob=prob)
1123
+
1124
+ self.cutoff = cutoff
1125
+ self.zeros = zeros
1126
+
1127
+ def _instantiate(self, state: RandomState):
1128
+ return {"cutoff": util.sample_from_dist(self.cutoff, state)}
1129
+
1130
+ def _transform(self, signal, cutoff):
1131
+ return signal.low_pass(cutoff, zeros=self.zeros)
1132
+
1133
+
1134
+ class HighPass(BaseTransform):
1135
+ """Applies a HighPass filter.
1136
+
1137
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.high_pass`.
1138
+
1139
+ Parameters
1140
+ ----------
1141
+ cutoff : tuple, optional
1142
+ Cutoff frequency distribution,
1143
+ by default ``("choice", [50, 100, 250, 500, 1000])``
1144
+ zeros : int, optional
1145
+ Number of zero-crossings in filter, argument to
1146
+ ``julius.LowPassFilters``, by default 51
1147
+ name : str, optional
1148
+ Name of this transform, used to identify it in the dictionary
1149
+ produced by ``self.instantiate``, by default None
1150
+ prob : float, optional
1151
+ Probability of applying this transform, by default 1.0
1152
+ """
1153
+
1154
+ def __init__(
1155
+ self,
1156
+ cutoff: tuple = ("choice", [50, 100, 250, 500, 1000]),
1157
+ zeros: int = 51,
1158
+ name: str = None,
1159
+ prob: float = 1,
1160
+ ):
1161
+ super().__init__(name=name, prob=prob)
1162
+
1163
+ self.cutoff = cutoff
1164
+ self.zeros = zeros
1165
+
1166
+ def _instantiate(self, state: RandomState):
1167
+ return {"cutoff": util.sample_from_dist(self.cutoff, state)}
1168
+
1169
+ def _transform(self, signal, cutoff):
1170
+ return signal.high_pass(cutoff, zeros=self.zeros)
1171
+
1172
+
1173
+ class RescaleAudio(BaseTransform):
1174
+ """Rescales the audio so it is in between ``-val`` and ``val``
1175
+ only if the original audio exceeds those bounds. Useful if
1176
+ transforms have caused the audio to clip.
1177
+
1178
+ Uses :py:func:`audiotools.core.effects.EffectMixin.ensure_max_of_audio`.
1179
+
1180
+ Parameters
1181
+ ----------
1182
+ val : float, optional
1183
+ Max absolute value of signal, by default 1.0
1184
+ name : str, optional
1185
+ Name of this transform, used to identify it in the dictionary
1186
+ produced by ``self.instantiate``, by default None
1187
+ prob : float, optional
1188
+ Probability of applying this transform, by default 1.0
1189
+ """
1190
+
1191
+ def __init__(self, val: float = 1.0, name: str = None, prob: float = 1):
1192
+ super().__init__(name=name, prob=prob)
1193
+
1194
+ self.val = val
1195
+
1196
+ def _transform(self, signal):
1197
+ return signal.ensure_max_of_audio(self.val)
1198
+
1199
+
1200
+ class ShiftPhase(SpectralTransform):
1201
+ """Shifts the phase of the audio.
1202
+
1203
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`.
1204
+
1205
+ Parameters
1206
+ ----------
1207
+ shift : tuple, optional
1208
+ How much to shift phase by, by default ("uniform", -np.pi, np.pi)
1209
+ name : str, optional
1210
+ Name of this transform, used to identify it in the dictionary
1211
+ produced by ``self.instantiate``, by default None
1212
+ prob : float, optional
1213
+ Probability of applying this transform, by default 1.0
1214
+ """
1215
+
1216
+ def __init__(
1217
+ self,
1218
+ shift: tuple = ("uniform", -np.pi, np.pi),
1219
+ name: str = None,
1220
+ prob: float = 1,
1221
+ ):
1222
+ super().__init__(name=name, prob=prob)
1223
+ self.shift = shift
1224
+
1225
+ def _instantiate(self, state: RandomState):
1226
+ return {"shift": util.sample_from_dist(self.shift, state)}
1227
+
1228
+ def _transform(self, signal, shift):
1229
+ return signal.shift_phase(shift)
1230
+
1231
+
1232
+ class InvertPhase(ShiftPhase):
1233
+ """Inverts the phase of the audio.
1234
+
1235
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`.
1236
+
1237
+ Parameters
1238
+ ----------
1239
+ name : str, optional
1240
+ Name of this transform, used to identify it in the dictionary
1241
+ produced by ``self.instantiate``, by default None
1242
+ prob : float, optional
1243
+ Probability of applying this transform, by default 1.0
1244
+ """
1245
+
1246
+ def __init__(self, name: str = None, prob: float = 1):
1247
+ super().__init__(shift=("const", np.pi), name=name, prob=prob)
1248
+
1249
+
1250
+ class CorruptPhase(SpectralTransform):
1251
+ """Corrupts the phase of the audio.
1252
+
1253
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.corrupt_phase`.
1254
+
1255
+ Parameters
1256
+ ----------
1257
+ scale : tuple, optional
1258
+ How much to corrupt phase by, by default ("uniform", 0, np.pi)
1259
+ name : str, optional
1260
+ Name of this transform, used to identify it in the dictionary
1261
+ produced by ``self.instantiate``, by default None
1262
+ prob : float, optional
1263
+ Probability of applying this transform, by default 1.0
1264
+ """
1265
+
1266
+ def __init__(
1267
+ self, scale: tuple = ("uniform", 0, np.pi), name: str = None, prob: float = 1
1268
+ ):
1269
+ super().__init__(name=name, prob=prob)
1270
+ self.scale = scale
1271
+
1272
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
1273
+ scale = util.sample_from_dist(self.scale, state)
1274
+ corruption = state.normal(scale=scale, size=signal.phase.shape[1:])
1275
+ return {"corruption": corruption.astype("float32")}
1276
+
1277
+ def _transform(self, signal, corruption):
1278
+ return signal.shift_phase(shift=corruption)
1279
+
1280
+
1281
+ class FrequencyMask(SpectralTransform):
1282
+ """Masks a band of frequencies at a center frequency
1283
+ from the audio.
1284
+
1285
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_frequencies`.
1286
+
1287
+ Parameters
1288
+ ----------
1289
+ f_center : tuple, optional
1290
+ Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
1291
+ f_width : tuple, optional
1292
+ Width of zero'd out band, by default ("const", 0.1)
1293
+ name : str, optional
1294
+ Name of this transform, used to identify it in the dictionary
1295
+ produced by ``self.instantiate``, by default None
1296
+ prob : float, optional
1297
+ Probability of applying this transform, by default 1.0
1298
+ """
1299
+
1300
+ def __init__(
1301
+ self,
1302
+ f_center: tuple = ("uniform", 0.0, 1.0),
1303
+ f_width: tuple = ("const", 0.1),
1304
+ name: str = None,
1305
+ prob: float = 1,
1306
+ ):
1307
+ super().__init__(name=name, prob=prob)
1308
+ self.f_center = f_center
1309
+ self.f_width = f_width
1310
+
1311
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
1312
+ f_center = util.sample_from_dist(self.f_center, state)
1313
+ f_width = util.sample_from_dist(self.f_width, state)
1314
+
1315
+ fmin = max(f_center - (f_width / 2), 0.0)
1316
+ fmax = min(f_center + (f_width / 2), 1.0)
1317
+
1318
+ fmin_hz = (signal.sample_rate / 2) * fmin
1319
+ fmax_hz = (signal.sample_rate / 2) * fmax
1320
+
1321
+ return {"fmin_hz": fmin_hz, "fmax_hz": fmax_hz}
1322
+
1323
+ def _transform(self, signal, fmin_hz: float, fmax_hz: float):
1324
+ return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
1325
+
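The normalized center and width are mapped onto Hz against the Nyquist frequency; a quick worked example of that arithmetic at a 44.1 kHz sample rate:

    sample_rate = 44100
    f_center, f_width = 0.5, 0.1
    fmin = max(f_center - f_width / 2, 0.0)  # 0.45 of Nyquist
    fmax = min(f_center + f_width / 2, 1.0)  # 0.55 of Nyquist
    fmin_hz = (sample_rate / 2) * fmin       # 9922.5 Hz
    fmax_hz = (sample_rate / 2) * fmax       # 12127.5 Hz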
1326
+
1327
+ class TimeMask(SpectralTransform):
1328
+ """Masks out contiguous time-steps from signal.
1329
+
1330
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_timesteps`.
1331
+
1332
+ Parameters
1333
+ ----------
1334
+ t_center : tuple, optional
1335
+ Center time in terms of 0.0 and 1.0 (duration of signal),
1336
+ by default ("uniform", 0.0, 1.0)
1337
+ t_width : tuple, optional
1338
+ Width of dropped out portion, by default ("const", 0.025)
1339
+ name : str, optional
1340
+ Name of this transform, used to identify it in the dictionary
1341
+ produced by ``self.instantiate``, by default None
1342
+ prob : float, optional
1343
+ Probability of applying this transform, by default 1.0
1344
+ """
1345
+
1346
+ def __init__(
1347
+ self,
1348
+ t_center: tuple = ("uniform", 0.0, 1.0),
1349
+ t_width: tuple = ("const", 0.025),
1350
+ name: str = None,
1351
+ prob: float = 1,
1352
+ ):
1353
+ super().__init__(name=name, prob=prob)
1354
+ self.t_center = t_center
1355
+ self.t_width = t_width
1356
+
1357
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
1358
+ t_center = util.sample_from_dist(self.t_center, state)
1359
+ t_width = util.sample_from_dist(self.t_width, state)
1360
+
1361
+ tmin = max(t_center - (t_width / 2), 0.0)
1362
+ tmax = min(t_center + (t_width / 2), 1.0)
1363
+
1364
+ tmin_s = signal.signal_duration * tmin
1365
+ tmax_s = signal.signal_duration * tmax
1366
+ return {"tmin_s": tmin_s, "tmax_s": tmax_s}
1367
+
1368
+ def _transform(self, signal, tmin_s: float, tmax_s: float):
1369
+ return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s)
1370
+
1371
+
1372
+ class MaskLowMagnitudes(SpectralTransform):
1373
+ """Masks low magnitude regions out of signal.
1374
+
1375
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_low_magnitudes`.
1376
+
1377
+ Parameters
1378
+ ----------
1379
+ db_cutoff : tuple, optional
1380
+ Decibel cutoff; magnitudes below this value are masked away,
1381
+ by default ("uniform", -10, 10)
1382
+ name : str, optional
1383
+ Name of this transform, used to identify it in the dictionary
1384
+ produced by ``self.instantiate``, by default None
1385
+ prob : float, optional
1386
+ Probability of applying this transform, by default 1.0
1387
+ """
1388
+
1389
+ def __init__(
1390
+ self,
1391
+ db_cutoff: tuple = ("uniform", -10, 10),
1392
+ name: str = None,
1393
+ prob: float = 1,
1394
+ ):
1395
+ super().__init__(name=name, prob=prob)
1396
+ self.db_cutoff = db_cutoff
1397
+
1398
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
1399
+ return {"db_cutoff": util.sample_from_dist(self.db_cutoff, state)}
1400
+
1401
+ def _transform(self, signal, db_cutoff: float):
1402
+ return signal.mask_low_magnitudes(db_cutoff)
1403
+
1404
+
1405
+ class Smoothing(BaseTransform):
1406
+ """Convolves the signal with a smoothing window.
1407
+
1408
+ Uses :py:func:`audiotools.core.effects.EffectMixin.convolve`.
1409
+
1410
+ Parameters
1411
+ ----------
1412
+ window_type : tuple, optional
1413
+ Type of window to use, by default ("const", "average")
1414
+ window_length : tuple, optional
1415
+ Length of smoothing window, by
1416
+ default ("choice", [8, 16, 32, 64, 128, 256, 512])
1417
+ name : str, optional
1418
+ Name of this transform, used to identify it in the dictionary
1419
+ produced by ``self.instantiate``, by default None
1420
+ prob : float, optional
1421
+ Probability of applying this transform, by default 1.0
1422
+ """
1423
+
1424
+ def __init__(
1425
+ self,
1426
+ window_type: tuple = ("const", "average"),
1427
+ window_length: tuple = ("choice", [8, 16, 32, 64, 128, 256, 512]),
1428
+ name: str = None,
1429
+ prob: float = 1,
1430
+ ):
1431
+ super().__init__(name=name, prob=prob)
1432
+ self.window_type = window_type
1433
+ self.window_length = window_length
1434
+
1435
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
1436
+ window_type = util.sample_from_dist(self.window_type, state)
1437
+ window_length = util.sample_from_dist(self.window_length, state)
1438
+ window = signal.get_window(
1439
+ window_type=window_type, window_length=window_length, device="cpu"
1440
+ )
1441
+ return {"window": AudioSignal(window, signal.sample_rate)}
1442
+
1443
+ def _transform(self, signal, window):
1444
+ sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values
1445
+ sscale[sscale == 0.0] = 1.0
1446
+
1447
+ out = signal.convolve(window)
1448
+
1449
+ oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values
1450
+ oscale[oscale == 0.0] = 1.0
1451
+
1452
+ out = out * (sscale / oscale)
1453
+ return out
1454
+
1455
+
1456
+ class TimeNoise(TimeMask):
1457
+ """Similar to :py:func:`audiotools.data.transforms.TimeMask`, but
1458
+ replaces with noise instead of zeros.
1459
+
1460
+ Parameters
1461
+ ----------
1462
+ t_center : tuple, optional
1463
+ Center time in terms of 0.0 and 1.0 (duration of signal),
1464
+ by default ("uniform", 0.0, 1.0)
1465
+ t_width : tuple, optional
1466
+ Width of dropped out portion, by default ("const", 0.025)
1467
+ name : str, optional
1468
+ Name of this transform, used to identify it in the dictionary
1469
+ produced by ``self.instantiate``, by default None
1470
+ prob : float, optional
1471
+ Probability of applying this transform, by default 1.0
1472
+ """
1473
+
1474
+ def __init__(
1475
+ self,
1476
+ t_center: tuple = ("uniform", 0.0, 1.0),
1477
+ t_width: tuple = ("const", 0.025),
1478
+ name: str = None,
1479
+ prob: float = 1,
1480
+ ):
1481
+ super().__init__(t_center=t_center, t_width=t_width, name=name, prob=prob)
1482
+
1483
+ def _transform(self, signal, tmin_s: float, tmax_s: float):
1484
+ signal = signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s, val=0.0)
1485
+ mag, phase = signal.magnitude, signal.phase
1486
+
1487
+ mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
1488
+ mask = (mag == 0.0) * (phase == 0.0)
1489
+
1490
+ mag[mask] = mag_r[mask]
1491
+ phase[mask] = phase_r[mask]
1492
+
1493
+ signal.magnitude = mag
1494
+ signal.phase = phase
1495
+ return signal
1496
+
1497
+
1498
+ class FrequencyNoise(FrequencyMask):
1499
+ """Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but
1500
+ replaces with noise instead of zeros.
1501
+
1502
+ Parameters
1503
+ ----------
1504
+ f_center : tuple, optional
1505
+ Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
1506
+ f_width : tuple, optional
1507
+ Width of zero'd out band, by default ("const", 0.1)
1508
+ name : str, optional
1509
+ Name of this transform, used to identify it in the dictionary
1510
+ produced by ``self.instantiate``, by default None
1511
+ prob : float, optional
1512
+ Probability of applying this transform, by default 1.0
1513
+ """
1514
+
1515
+ def __init__(
1516
+ self,
1517
+ f_center: tuple = ("uniform", 0.0, 1.0),
1518
+ f_width: tuple = ("const", 0.1),
1519
+ name: str = None,
1520
+ prob: float = 1,
1521
+ ):
1522
+ super().__init__(f_center=f_center, f_width=f_width, name=name, prob=prob)
1523
+
1524
+ def _transform(self, signal, fmin_hz: float, fmax_hz: float):
1525
+ signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
1526
+ mag, phase = signal.magnitude, signal.phase
1527
+
1528
+ mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
1529
+ mask = (mag == 0.0) * (phase == 0.0)
1530
+
1531
+ mag[mask] = mag_r[mask]
1532
+ phase[mask] = phase_r[mask]
1533
+
1534
+ signal.magnitude = mag
1535
+ signal.phase = phase
1536
+ return signal
1537
+
1538
+
1539
+ class SpectralDenoising(Equalizer):
1540
+ """Applies denoising algorithm detailed in
1541
+ :py:func:`audiotools.ml.layers.spectral_gate.SpectralGate`,
1542
+ using a randomly generated noise signal for denoising.
1543
+
1544
+ Parameters
1545
+ ----------
1546
+ eq_amount : tuple, optional
1547
+ Amount of eq to apply to noise signal, by default ("const", 1.0)
1548
+ denoise_amount : tuple, optional
1549
+ Amount to denoise by, by default ("uniform", 0.8, 1.0)
1550
+ nz_volume : float, optional
1551
+ Volume of noise to denoise with, by default -40
1552
+ n_bands : int, optional
1553
+ Number of bands in equalizer, by default 6
1554
+ n_freq : int, optional
1555
+ Number of frequency bins to smooth by, by default 3
1556
+ n_time : int, optional
1557
+ Number of time bins to smooth by, by default 5
1558
+ name : str, optional
1559
+ Name of this transform, used to identify it in the dictionary
1560
+ produced by ``self.instantiate``, by default None
1561
+ prob : float, optional
1562
+ Probability of applying this transform, by default 1.0
1563
+ """
1564
+
1565
+ def __init__(
1566
+ self,
1567
+ eq_amount: tuple = ("const", 1.0),
1568
+ denoise_amount: tuple = ("uniform", 0.8, 1.0),
1569
+ nz_volume: float = -40,
1570
+ n_bands: int = 6,
1571
+ n_freq: int = 3,
1572
+ n_time: int = 5,
1573
+ name: str = None,
1574
+ prob: float = 1,
1575
+ ):
1576
+ super().__init__(eq_amount=eq_amount, n_bands=n_bands, name=name, prob=prob)
1577
+
1578
+ self.nz_volume = nz_volume
1579
+ self.denoise_amount = denoise_amount
1580
+ self.spectral_gate = ml.layers.SpectralGate(n_freq, n_time)
1581
+
1582
+ def _transform(self, signal, nz, eq, denoise_amount):
1583
+ nz = nz.normalize(self.nz_volume).equalizer(eq)
1584
+ self.spectral_gate = self.spectral_gate.to(signal.device)
1585
+ signal = self.spectral_gate(signal, nz, denoise_amount)
1586
+ return signal
1587
+
1588
+ def _instantiate(self, state: RandomState):
1589
+ kwargs = super()._instantiate(state)
1590
+ kwargs["denoise_amount"] = util.sample_from_dist(self.denoise_amount, state)
1591
+ kwargs["nz"] = AudioSignal(state.randn(22050), 44100)
1592
+ return kwargs
dac-vae/audiotools/metrics/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """
2
+ Functions for comparing AudioSignal objects to one another.
3
+ """ # fmt: skip
4
+ from . import distance
5
+ from . import quality
6
+ from . import spectral
dac-vae/audiotools/metrics/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (331 Bytes).
 
dac-vae/audiotools/metrics/__pycache__/distance.cpython-310.pyc ADDED
Binary file (3.84 kB).
 
dac-vae/audiotools/metrics/__pycache__/quality.cpython-310.pyc ADDED
Binary file (4.47 kB).
 
dac-vae/audiotools/metrics/__pycache__/spectral.cpython-310.pyc ADDED
Binary file (7.45 kB).
 
dac-vae/audiotools/metrics/distance.py ADDED
@@ -0,0 +1,131 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .. import AudioSignal
5
+
6
+
7
+ class L1Loss(nn.L1Loss):
8
+ """L1 Loss between AudioSignals. Defaults
9
+ to comparing ``audio_data``, but any
10
+ attribute of an AudioSignal can be used.
11
+
12
+ Parameters
13
+ ----------
14
+ attribute : str, optional
15
+ Attribute of signal to compare, defaults to ``audio_data``.
16
+ weight : float, optional
17
+ Weight of this loss, defaults to 1.0.
18
+ """
19
+
20
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
21
+ self.attribute = attribute
22
+ self.weight = weight
23
+ super().__init__(**kwargs)
24
+
25
+ def forward(self, x: AudioSignal, y: AudioSignal):
26
+ """
27
+ Parameters
28
+ ----------
29
+ x : AudioSignal
30
+ Estimate AudioSignal
31
+ y : AudioSignal
32
+ Reference AudioSignal
33
+
34
+ Returns
35
+ -------
36
+ torch.Tensor
37
+ L1 loss between AudioSignal attributes.
38
+ """
39
+ if isinstance(x, AudioSignal):
40
+ x = getattr(x, self.attribute)
41
+ y = getattr(y, self.attribute)
42
+ return super().forward(x, y)
43
+
44
+
45
+ class SISDRLoss(nn.Module):
46
+ """
47
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
48
+ of estimated and reference audio signals or aligned features.
49
+
50
+ Parameters
51
+ ----------
52
+ scaling : bool, optional
53
+ Whether to use scale-invariant (True) or
54
+ signal-to-noise ratio (False), by default True
55
+ reduction : str, optional
56
+ How to reduce across the batch (either 'mean',
57
+ 'sum', or None), by default 'mean'
58
+ zero_mean : bool, optional
59
+ Zero mean the references and estimates before
60
+ computing the loss, by default True
61
+ clip_min : float, optional
62
+ The minimum possible loss value. Helps network
63
+ to not focus on making already good examples better, by default None
64
+ weight : float, optional
65
+ Weight of this loss, defaults to 1.0.
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ scaling: bool = True,
71
+ reduction: str = "mean",
72
+ zero_mean: bool = True,
73
+ clip_min: float = None,
74
+ weight: float = 1.0,
75
+ ):
76
+ self.scaling = scaling
77
+ self.reduction = reduction
78
+ self.zero_mean = zero_mean
79
+ self.clip_min = clip_min
80
+ self.weight = weight
81
+ super().__init__()
82
+
83
+ def forward(self, x: AudioSignal, y: AudioSignal):
84
+ eps = 1e-8
85
+ # nb, nc, nt
86
+ if isinstance(x, AudioSignal):
87
+ references = x.audio_data
88
+ estimates = y.audio_data
89
+ else:
90
+ references = x
91
+ estimates = y
92
+
93
+ nb = references.shape[0]
94
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
95
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
96
+
97
+ # samples now on axis 1
98
+ if self.zero_mean:
99
+ mean_reference = references.mean(dim=1, keepdim=True)
100
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
101
+ else:
102
+ mean_reference = 0
103
+ mean_estimate = 0
104
+
105
+ _references = references - mean_reference
106
+ _estimates = estimates - mean_estimate
107
+
108
+ references_projection = (_references**2).sum(dim=-2) + eps
109
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
110
+
111
+ scale = (
112
+ (references_on_estimates / references_projection).unsqueeze(1)
113
+ if self.scaling
114
+ else 1
115
+ )
116
+
117
+ e_true = scale * _references
118
+ e_res = _estimates - e_true
119
+
120
+ signal = (e_true**2).sum(dim=1)
121
+ noise = (e_res**2).sum(dim=1)
122
+ sdr = -10 * torch.log10(signal / noise + eps)
123
+
124
+ if self.clip_min is not None:
125
+ sdr = torch.clamp(sdr, min=self.clip_min)
126
+
127
+ if self.reduction == "mean":
128
+ sdr = sdr.mean()
129
+ elif self.reduction == "sum":
130
+ sdr = sdr.sum()
131
+ return sdr
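A minimal sketch of computing the loss on toy tensors; note that ``forward`` unpacks its first argument as the references and the second as the estimates:

    import torch
    from audiotools import AudioSignal
    from audiotools.metrics.distance import SISDRLoss

    ref = AudioSignal(torch.randn(2, 1, 44100), 44100)
    est = ref.clone()
    est.audio_data = est.audio_data + 0.01 * torch.randn_like(est.audio_data)

    loss = SISDRLoss()(ref, est)   # scalar; more negative means a better estimate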
dac-vae/audiotools/metrics/quality.py ADDED
@@ -0,0 +1,159 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from .. import AudioSignal
7
+
8
+
9
+ def stoi(
10
+ estimates: AudioSignal,
11
+ references: AudioSignal,
12
+ extended: bool = False,
13
+ ):
14
+ """Short term objective intelligibility
15
+ Computes the STOI (See [1][2]) of a denoised signal compared to a clean
16
+ signal. The output is expected to have a monotonic relation with the
17
+ subjective speech intelligibility, where a higher score denotes better
18
+ speech intelligibility. Uses pystoi under the hood.
19
+
20
+ Parameters
21
+ ----------
22
+ estimates : AudioSignal
23
+ Denoised speech
24
+ references : AudioSignal
25
+ Clean original speech
26
+ extended : bool, optional
27
+ Whether to use the extended STOI described in [3], by default False
28
+
29
+ Returns
30
+ -------
31
+ Tensor[float]
32
+ Short time objective intelligibility measure between clean and
33
+ denoised speech
34
+
35
+ References
36
+ ----------
37
+ 1. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
38
+ Objective Intelligibility Measure for Time-Frequency Weighted Noisy
39
+ Speech', ICASSP 2010, Texas, Dallas.
40
+ 2. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
41
+ Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
42
+ IEEE Transactions on Audio, Speech, and Language Processing, 2011.
43
+ 3. Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
44
+ Intelligibility of Speech Masked by Modulated Noise Maskers',
45
+ IEEE Transactions on Audio, Speech and Language Processing, 2016.
46
+ """
47
+ import pystoi
48
+
49
+ estimates = estimates.clone().to_mono()
50
+ references = references.clone().to_mono()
51
+
52
+ stois = []
53
+ for i in range(estimates.batch_size):
54
+ _stoi = pystoi.stoi(
55
+ references.audio_data[i, 0].detach().cpu().numpy(),
56
+ estimates.audio_data[i, 0].detach().cpu().numpy(),
57
+ references.sample_rate,
58
+ extended=extended,
59
+ )
60
+ stois.append(_stoi)
61
+ return torch.from_numpy(np.array(stois))
62
+
63
+
64
+ def pesq(
65
+ estimates: AudioSignal,
66
+ references: AudioSignal,
67
+ mode: str = "wb",
68
+ target_sr: float = 16000,
69
+ ):
70
+ """_summary_
71
+
72
+ Parameters
73
+ ----------
74
+ estimates : AudioSignal
75
+ Degraded AudioSignal
76
+ references : AudioSignal
77
+ Reference AudioSignal
78
+ mode : str, optional
79
+ 'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
80
+ target_sr : float, optional
81
+ Target sample rate, by default 16000
82
+
83
+ Returns
84
+ -------
85
+ Tensor[float]
86
+ PESQ score: P.862.2 Prediction (MOS-LQO)
87
+ """
88
+ from pesq import pesq as pesq_fn
89
+
90
+ estimates = estimates.clone().to_mono().resample(target_sr)
91
+ references = references.clone().to_mono().resample(target_sr)
92
+
93
+ pesqs = []
94
+ for i in range(estimates.batch_size):
95
+ _pesq = pesq_fn(
96
+ estimates.sample_rate,
97
+ references.audio_data[i, 0].detach().cpu().numpy(),
98
+ estimates.audio_data[i, 0].detach().cpu().numpy(),
99
+ mode,
100
+ )
101
+ pesqs.append(_pesq)
102
+ return torch.from_numpy(np.array(pesqs))
103
+
104
+
105
+ def visqol(
106
+ estimates: AudioSignal,
107
+ references: AudioSignal,
108
+ mode: str = "audio",
109
+ ): # pragma: no cover
110
+ """ViSQOL score.
111
+
112
+ Parameters
113
+ ----------
114
+ estimates : AudioSignal
115
+ Degraded AudioSignal
116
+ references : AudioSignal
117
+ Reference AudioSignal
118
+ mode : str, optional
119
+ 'audio' or 'speech', by default 'audio'
120
+
121
+ Returns
122
+ -------
123
+ Tensor[float]
124
+ ViSQOL score (MOS-LQO)
125
+ """
126
+ from visqol import visqol_lib_py
127
+ from visqol.pb2 import visqol_config_pb2
128
+ from visqol.pb2 import similarity_result_pb2
129
+
130
+ config = visqol_config_pb2.VisqolConfig()
131
+ if mode == "audio":
132
+ target_sr = 48000
133
+ config.options.use_speech_scoring = False
134
+ svr_model_path = "libsvm_nu_svr_model.txt"
135
+ elif mode == "speech":
136
+ target_sr = 16000
137
+ config.options.use_speech_scoring = True
138
+ svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
139
+ else:
140
+ raise ValueError(f"Unrecognized mode: {mode}")
141
+ config.audio.sample_rate = target_sr
142
+ config.options.svr_model_path = os.path.join(
143
+ os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path
144
+ )
145
+
146
+ api = visqol_lib_py.VisqolApi()
147
+ api.Create(config)
148
+
149
+ estimates = estimates.clone().to_mono().resample(target_sr)
150
+ references = references.clone().to_mono().resample(target_sr)
151
+
152
+ visqols = []
153
+ for i in range(estimates.batch_size):
154
+ _visqol = api.Measure(
155
+ references.audio_data[i, 0].detach().cpu().numpy().astype(float),
156
+ estimates.audio_data[i, 0].detach().cpu().numpy().astype(float),
157
+ )
158
+ visqols.append(_visqol.moslqo)
159
+ return torch.from_numpy(np.array(visqols))
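A rough usage sketch for these metrics, assuming ``pystoi`` and ``pesq`` are installed (they are imported lazily inside each function) and using toy tensors as stand-ins for real speech:

    import torch
    from audiotools import AudioSignal
    from audiotools.metrics import quality

    ref = AudioSignal(torch.randn(1, 1, 16000), 16000)
    est = AudioSignal(ref.audio_data + 0.01 * torch.randn_like(ref.audio_data), 16000)

    stoi_scores = quality.stoi(est, ref)   # one score per item in the batch
    pesq_scores = quality.pesq(est, ref)   # resamples both signals to 16 kHz first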
dac-vae/audiotools/metrics/spectral.py ADDED
@@ -0,0 +1,247 @@
1
+ import typing
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ from torch import nn
6
+
7
+ from .. import AudioSignal
8
+ from .. import STFTParams
9
+
10
+
11
+ class MultiScaleSTFTLoss(nn.Module):
12
+ """Computes the multi-scale STFT loss from [1].
13
+
14
+ Parameters
15
+ ----------
16
+ window_lengths : List[int], optional
17
+ Length of each window of each STFT, by default [2048, 512]
18
+ loss_fn : typing.Callable, optional
19
+ How to compare each loss, by default nn.L1Loss()
20
+ clamp_eps : float, optional
21
+ Clamp on the log magnitude, below, by default 1e-5
22
+ mag_weight : float, optional
23
+ Weight of raw magnitude portion of loss, by default 1.0
24
+ log_weight : float, optional
25
+ Weight of log magnitude portion of loss, by default 1.0
26
+ pow : float, optional
27
+ Power to raise magnitude to before taking log, by default 2.0
28
+ weight : float, optional
29
+ Weight of this loss, by default 1.0
30
+ match_stride : bool, optional
31
+ Whether to match the stride of convolutional layers, by default False
32
+
33
+ References
34
+ ----------
35
+
36
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
37
+ "DDSP: Differentiable Digital Signal Processing."
38
+ International Conference on Learning Representations. 2019.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ window_lengths: List[int] = [2048, 512],
44
+ loss_fn: typing.Callable = nn.L1Loss(),
45
+ clamp_eps: float = 1e-5,
46
+ mag_weight: float = 1.0,
47
+ log_weight: float = 1.0,
48
+ pow: float = 2.0,
49
+ weight: float = 1.0,
50
+ match_stride: bool = False,
51
+ window_type: str = None,
52
+ ):
53
+ super().__init__()
54
+ self.stft_params = [
55
+ STFTParams(
56
+ window_length=w,
57
+ hop_length=w // 4,
58
+ match_stride=match_stride,
59
+ window_type=window_type,
60
+ )
61
+ for w in window_lengths
62
+ ]
63
+ self.loss_fn = loss_fn
64
+ self.log_weight = log_weight
65
+ self.mag_weight = mag_weight
66
+ self.clamp_eps = clamp_eps
67
+ self.weight = weight
68
+ self.pow = pow
69
+
70
+ def forward(self, x: AudioSignal, y: AudioSignal):
71
+ """Computes multi-scale STFT between an estimate and a reference
72
+ signal.
73
+
74
+ Parameters
75
+ ----------
76
+ x : AudioSignal
77
+ Estimate signal
78
+ y : AudioSignal
79
+ Reference signal
80
+
81
+ Returns
82
+ -------
83
+ torch.Tensor
84
+ Multi-scale STFT loss.
85
+ """
86
+ loss = 0.0
87
+ for s in self.stft_params:
88
+ x.stft(s.window_length, s.hop_length, s.window_type)
89
+ y.stft(s.window_length, s.hop_length, s.window_type)
90
+ loss += self.log_weight * self.loss_fn(
91
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
92
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
93
+ )
94
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
95
+ return loss
96
+
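A minimal sketch of using the loss on toy signals; each scale contributes a log-magnitude L1 term and a raw-magnitude L1 term:

    import torch
    from audiotools import AudioSignal
    from audiotools.metrics.spectral import MultiScaleSTFTLoss

    ref = AudioSignal(torch.randn(1, 1, 44100), 44100)
    est = AudioSignal(ref.audio_data + 0.01 * torch.randn_like(ref.audio_data), 44100)

    stft_loss = MultiScaleSTFTLoss(window_lengths=[2048, 512])
    value = stft_loss(est, ref)   # estimate first, reference second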
97
+
98
+ class MelSpectrogramLoss(nn.Module):
99
+ """Compute distance between mel spectrograms. Can be used
100
+ in a multi-scale way.
101
+
102
+ Parameters
103
+ ----------
104
+ n_mels : List[int]
105
+ Number of mels per STFT, by default [150, 80]
106
+ window_lengths : List[int], optional
107
+ Length of each window of each STFT, by default [2048, 512]
108
+ loss_fn : typing.Callable, optional
109
+ How to compare each loss, by default nn.L1Loss()
110
+ clamp_eps : float, optional
111
+ Clamp on the log magnitude, below, by default 1e-5
112
+ mag_weight : float, optional
113
+ Weight of raw magnitude portion of loss, by default 1.0
114
+ log_weight : float, optional
115
+ Weight of log magnitude portion of loss, by default 1.0
116
+ pow : float, optional
117
+ Power to raise magnitude to before taking log, by default 2.0
118
+ weight : float, optional
119
+ Weight of this loss, by default 1.0
120
+ match_stride : bool, optional
121
+ Whether to match the stride of convolutional layers, by default False
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ n_mels: List[int] = [150, 80],
127
+ window_lengths: List[int] = [2048, 512],
128
+ loss_fn: typing.Callable = nn.L1Loss(),
129
+ clamp_eps: float = 1e-5,
130
+ mag_weight: float = 1.0,
131
+ log_weight: float = 1.0,
132
+ pow: float = 2.0,
133
+ weight: float = 1.0,
134
+ match_stride: bool = False,
135
+ mel_fmin: List[float] = [0.0, 0.0],
136
+ mel_fmax: List[float] = [None, None],
137
+ window_type: str = None,
138
+ ):
139
+ super().__init__()
140
+ self.stft_params = [
141
+ STFTParams(
142
+ window_length=w,
143
+ hop_length=w // 4,
144
+ match_stride=match_stride,
145
+ window_type=window_type,
146
+ )
147
+ for w in window_lengths
148
+ ]
149
+ self.n_mels = n_mels
150
+ self.loss_fn = loss_fn
151
+ self.clamp_eps = clamp_eps
152
+ self.log_weight = log_weight
153
+ self.mag_weight = mag_weight
154
+ self.weight = weight
155
+ self.mel_fmin = mel_fmin
156
+ self.mel_fmax = mel_fmax
157
+ self.pow = pow
158
+
159
+ def forward(self, x: AudioSignal, y: AudioSignal):
160
+ """Computes mel loss between an estimate and a reference
161
+ signal.
162
+
163
+ Parameters
164
+ ----------
165
+ x : AudioSignal
166
+ Estimate signal
167
+ y : AudioSignal
168
+ Reference signal
169
+
170
+ Returns
171
+ -------
172
+ torch.Tensor
173
+ Mel loss.
174
+ """
175
+ loss = 0.0
176
+ for n_mels, fmin, fmax, s in zip(
177
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
178
+ ):
179
+ kwargs = {
180
+ "window_length": s.window_length,
181
+ "hop_length": s.hop_length,
182
+ "window_type": s.window_type,
183
+ }
184
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
185
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
186
+
187
+ loss += self.log_weight * self.loss_fn(
188
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
189
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
190
+ )
191
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
192
+ return loss
193
+
194
+
195
+ class PhaseLoss(nn.Module):
196
+ """Difference between phase spectrograms.
197
+
198
+ Parameters
199
+ ----------
200
+ window_length : int, optional
201
+ Length of STFT window, by default 2048
202
+ hop_length : int, optional
203
+ Hop length of STFT window, by default 512
204
+ weight : float, optional
205
+ Weight of loss, by default 1.0
206
+ """
207
+
208
+ def __init__(
209
+ self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
210
+ ):
211
+ super().__init__()
212
+
213
+ self.weight = weight
214
+ self.stft_params = STFTParams(window_length, hop_length)
215
+
216
+ def forward(self, x: AudioSignal, y: AudioSignal):
217
+ """Computes phase loss between an estimate and a reference
218
+ signal.
219
+
220
+ Parameters
221
+ ----------
222
+ x : AudioSignal
223
+ Estimate signal
224
+ y : AudioSignal
225
+ Reference signal
226
+
227
+ Returns
228
+ -------
229
+ torch.Tensor
230
+ Phase loss.
231
+ """
232
+ s = self.stft_params
233
+ x.stft(s.window_length, s.hop_length, s.window_type)
234
+ y.stft(s.window_length, s.hop_length, s.window_type)
235
+
236
+ # Take circular difference
237
+ diff = x.phase - y.phase
238
+ diff[diff < -np.pi] += 2 * np.pi
239
+ diff[diff > np.pi] -= 2 * np.pi
240
+
241
+ # Scale true magnitude to weights in [0, 1]
242
+ x_min, x_max = x.magnitude.min(), x.magnitude.max()
243
+ weights = (x.magnitude - x_min) / (x_max - x_min)
244
+
245
+ # Take weighted mean of all phase errors
246
+ loss = ((weights * diff) ** 2).mean()
247
+ return loss
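With the sign of the second correction fixed above, the difference is wrapped onto (-pi, pi]; a quick numpy check of the wrapping:

    import numpy as np

    diff = np.array([3.5, -3.5, 1.0])
    diff[diff < -np.pi] += 2 * np.pi
    diff[diff > np.pi] -= 2 * np.pi
    assert (np.abs(diff) <= np.pi).all()   # 3.5 -> -2.78, -3.5 -> 2.78, 1.0 unchanged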
dac-vae/audiotools/ml/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from . import decorators
2
+ from . import layers
3
+ from .accelerator import Accelerator
4
+ from .experiment import Experiment
5
+ from .layers import BaseModel
dac-vae/audiotools/ml/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (342 Bytes).
 
dac-vae/audiotools/ml/__pycache__/accelerator.cpython-310.pyc ADDED
Binary file (6.67 kB).
 
dac-vae/audiotools/ml/__pycache__/decorators.cpython-310.pyc ADDED
Binary file (14.2 kB).
 
dac-vae/audiotools/ml/__pycache__/experiment.cpython-310.pyc ADDED
Binary file (3.34 kB).
 
dac-vae/audiotools/ml/accelerator.py ADDED
@@ -0,0 +1,184 @@
1
+ import os
2
+ import typing
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.nn.parallel import DataParallel
7
+ from torch.nn.parallel import DistributedDataParallel
8
+
9
+ from ..data.datasets import ResumableDistributedSampler as DistributedSampler
10
+ from ..data.datasets import ResumableSequentialSampler as SequentialSampler
11
+
12
+
13
+ class Accelerator: # pragma: no cover
14
+ """This class is used to prepare models and dataloaders for
15
+ usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
16
+ prepare the respective objects. In the case of models, they are moved to
17
+ the appropriate GPU and SyncBatchNorm is applied to them. In the case of
18
+ dataloaders, a sampler is created and the dataloader is initialized with
19
+ that sampler.
20
+
21
+ If the world size is 1, prepare_model and prepare_dataloader are
22
+ no-ops. If the environment variable ``LOCAL_RANK`` is not set, the
23
+ script was launched without ``torchrun``; in that case, if the world
24
+ size (number of GPUs) is greater than 1, ``DataParallel`` is used
25
+ instead of ``DistributedDataParallel`` (not recommended).
26
+
27
+ Parameters
28
+ ----------
29
+ amp : bool, optional
30
+ Whether or not to enable automatic mixed precision, by default False
31
+ """
32
+
33
+ def __init__(self, amp: bool = False):
34
+ local_rank = os.getenv("LOCAL_RANK", None)
35
+ self.world_size = torch.cuda.device_count()
36
+
37
+ self.use_ddp = self.world_size > 1 and local_rank is not None
38
+ self.use_dp = self.world_size > 1 and local_rank is None
39
+ self.device = "cpu" if self.world_size == 0 else "cuda"
40
+
41
+ if self.use_ddp:
42
+ local_rank = int(local_rank)
43
+ dist.init_process_group(
44
+ "nccl",
45
+ init_method="env://",
46
+ world_size=self.world_size,
47
+ rank=local_rank,
48
+ )
49
+
50
+ self.local_rank = 0 if local_rank is None else local_rank
51
+ self.amp = amp
52
+
53
+ class DummyScaler:
54
+ def __init__(self):
55
+ pass
56
+
57
+ def step(self, optimizer):
58
+ optimizer.step()
59
+
60
+ def scale(self, loss):
61
+ return loss
62
+
63
+ def unscale_(self, optimizer):
64
+ return optimizer
65
+
66
+ def update(self):
67
+ pass
68
+
69
+ self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
70
+ self.device_ctx = (
71
+ torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
72
+ )
73
+
74
+ def __enter__(self):
75
+ if self.device_ctx is not None:
76
+ self.device_ctx.__enter__()
77
+ return self
78
+
79
+ def __exit__(self, exc_type, exc_value, traceback):
80
+ if self.device_ctx is not None:
81
+ self.device_ctx.__exit__(exc_type, exc_value, traceback)
82
+
83
+ def prepare_model(self, model: torch.nn.Module, **kwargs):
84
+ """Prepares model for DDP or DP. The model is moved to
85
+ the device of the correct rank.
86
+
87
+ Parameters
88
+ ----------
89
+ model : torch.nn.Module
90
+ Model that is converted for DDP or DP.
91
+
92
+ Returns
93
+ -------
94
+ torch.nn.Module
95
+ Wrapped model, or original model if DDP and DP are turned off.
96
+ """
97
+ model = model.to(self.device)
98
+ if self.use_ddp:
99
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
100
+ model = DistributedDataParallel(
101
+ model, device_ids=[self.local_rank], **kwargs
102
+ )
103
+ elif self.use_dp:
104
+ model = DataParallel(model, **kwargs)
105
+ return model
106
+
107
+ # Automatic mixed-precision utilities
108
+ def autocast(self, *args, **kwargs):
109
+ """Context manager for autocasting. Arguments
110
+ go to ``torch.cuda.amp.autocast``.
111
+ """
112
+ return torch.cuda.amp.autocast(self.amp, *args, **kwargs)
113
+
114
+ def backward(self, loss: torch.Tensor):
115
+ """Backwards pass, after scaling the loss if ``amp`` is
116
+ enabled.
117
+
118
+ Parameters
119
+ ----------
120
+ loss : torch.Tensor
121
+ Loss value.
122
+ """
123
+ self.scaler.scale(loss).backward()
124
+
125
+ def step(self, optimizer: torch.optim.Optimizer):
126
+ """Steps the optimizer, using a ``scaler`` if ``amp`` is
127
+ enabled.
128
+
129
+ Parameters
130
+ ----------
131
+ optimizer : torch.optim.Optimizer
132
+ Optimizer to step forward.
133
+ """
134
+ self.scaler.step(optimizer)
135
+
136
+ def update(self):
137
+ """Updates the scale factor."""
138
+ self.scaler.update()
139
+
140
+ def prepare_dataloader(
141
+ self, dataset: typing.Iterable, start_idx: int = None, **kwargs
142
+ ):
143
+ """Wraps a dataset with a DataLoader, using the correct sampler if DDP is
144
+ enabled.
145
+
146
+ Parameters
147
+ ----------
148
+ dataset : typing.Iterable
149
+ Dataset to build Dataloader around.
150
+ start_idx : int, optional
151
+ Start index of sampler, useful if resuming from some epoch,
152
+ by default None
153
+
154
+ Returns
155
+ -------
156
+ torch.utils.data.DataLoader
157
+ DataLoader wrapping ``dataset``, with the appropriate sampler attached.
158
+ """
159
+
+        if self.use_ddp:
+            sampler = DistributedSampler(
+                dataset,
+                start_idx,
+                num_replicas=self.world_size,
+                rank=self.local_rank,
+            )
+            if "num_workers" in kwargs:
+                kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1)
+            kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1)
+        else:
+            sampler = SequentialSampler(dataset, start_idx)
+
+        dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)
+        return dataloader
+
+    @staticmethod
+    def unwrap(model):
+        """Unwraps the model if it was wrapped in DDP or DP, otherwise
+        just returns the model. Use this to unwrap the model returned by
+        :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
+        """
+        if hasattr(model, "module"):
+            return model.module
+        return model
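
For orientation, here is a minimal sketch of how these `Accelerator` methods compose into a training step. It assumes `Accelerator` is importable from `audiotools.ml`, as in the upstream audiotools package; the toy model, optimizer, and data are illustrative only, not part of this diff:

```python
import torch

from audiotools.ml import Accelerator  # assumed export, as in upstream audiotools

# Toy end-to-end loop; runs single-process (no DDP/DP) on CPU or one GPU.
with Accelerator(amp=False) as accel:
    model = accel.prepare_model(torch.nn.Linear(16, 1))
    optimizer = torch.optim.Adam(model.parameters())

    for _ in range(10):
        x = torch.randn(8, 16, device=accel.device)
        with accel.autocast():        # no-op context unless amp=True
            loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        accel.backward(loss)          # scaler.scale(loss).backward()
        accel.step(optimizer)         # scaler.step(optimizer)
        accel.update()                # adjust the AMP scale factor
```

With `amp=False` the `DummyScaler` makes `backward`/`step`/`update` plain pass-throughs, so the same loop works unchanged whether or not mixed precision is enabled.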
dac-vae/audiotools/ml/decorators.py ADDED
@@ -0,0 +1,441 @@
+import math
+import os
+import time
+from collections import defaultdict
+from functools import wraps
+
+import torch
+import torch.distributed as dist
+from rich import box
+from rich.console import Console
+from rich.console import Group
+from rich.live import Live
+from rich.markdown import Markdown
+from rich.padding import Padding
+from rich.panel import Panel
+from rich.progress import BarColumn
+from rich.progress import Progress
+from rich.progress import SpinnerColumn
+from rich.progress import TimeElapsedColumn
+from rich.progress import TimeRemainingColumn
+from rich.rule import Rule
+from rich.table import Table
+from torch.utils.tensorboard import SummaryWriter
+
+
+# This is here so that the history can be pickled.
+def default_list():
+    return []
+
+
+class Mean:
+    """Keeps track of the running mean, along with the latest
+    value.
+    """
+
+    def __init__(self):
+        self.reset()
+
+    def __call__(self):
+        mean = self.total / max(self.count, 1)
+        return mean
+
+    def reset(self):
+        self.count = 0
+        self.total = 0
+
+    def update(self, val):
+        if math.isfinite(val):
+            self.count += 1
+            self.total += val
+
+
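One behavior of `Mean` worth making explicit: `update` filters values through `math.isfinite`, so a NaN or inf (e.g. a diverged loss) is simply skipped rather than corrupting the running mean. A tiny sketch:

```python
m = Mean()
for v in [1.0, 2.0, float("nan")]:
    m.update(v)   # the NaN is dropped by the isfinite check
print(m())        # 1.5 -- mean over the two finite values only
```
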
+def when(condition):
+    """Runs a function only when the condition is met. The condition is
+    a function that is run.
+
+    Parameters
+    ----------
+    condition : Callable
+        Function to run to check whether or not to run the decorated
+        function.
+
+    Example
+    -------
+    Checkpoint only runs every 100 iterations, and only if the
+    local rank is 0.
+
+    >>> i = 0
+    >>> rank = 0
+    >>>
+    >>> @when(lambda: i % 100 == 0 and rank == 0)
+    >>> def checkpoint():
+    >>>     print("Saving to /runs/exp1")
+    >>>
+    >>> for i in range(1000):
+    >>>     checkpoint()
+
+    """
+
+    def decorator(fn):
+        @wraps(fn)
+        def decorated(*args, **kwargs):
+            if condition():
+                return fn(*args, **kwargs)
+
+        return decorated
+
+    return decorator
+
+
+def timer(prefix: str = "time"):
+    """Adds execution time to the output dictionary of the decorated
+    function. The function decorated by this must output a dictionary.
+    The key added will follow the form "[prefix]/[name_of_function]"
+
+    Parameters
+    ----------
+    prefix : str, optional
+        The key added will follow the form "[prefix]/[name_of_function]",
+        by default "time".
+    """
+
+    def decorator(fn):
+        @wraps(fn)
+        def decorated(*args, **kwargs):
+            s = time.perf_counter()
+            output = fn(*args, **kwargs)
+            assert isinstance(output, dict)
+            e = time.perf_counter()
+            output[f"{prefix}/{fn.__name__}"] = e - s
+            return output
+
+        return decorated
+
+    return decorator
+
+
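An illustrative sketch of `timer` (the function name and returned values are made up): the decorated function must return a dict, and the decorator adds an elapsed-time entry to it.

```python
import time

@timer()                      # uses the default "time" prefix
def train_step():
    time.sleep(0.01)          # stand-in for real work
    return {"loss": 0.1}

out = train_step()
# out == {"loss": 0.1, "time/train_step": <seconds elapsed>}
```
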
+class Tracker:
+    """
+    A tracker class that helps to monitor training progress and log metrics.
+
+    Attributes
+    ----------
+    metrics : dict
+        A dictionary containing the metrics for each label.
+    history : dict
+        A dictionary containing the history of metrics for each label.
+    writer : SummaryWriter
+        A SummaryWriter object for logging the metrics.
+    rank : int
+        The rank of the current process.
+    step : int
+        The current step of the training.
+    tasks : dict
+        A dictionary containing the progress bars and tables for each label.
+    pbar : Progress
+        A progress bar object for displaying the progress.
+    consoles : list
+        A list of console objects for logging.
+    live : Live
+        A Live object for updating the display live.
+
+    Methods
+    -------
+    print(msg: str)
+        Prints the given message to all consoles.
+    update(label: str, fn_name: str)
+        Updates the progress bar and table for the given label.
+    done(label: str, title: str)
+        Resets the progress bar and table for the given label and prints the final result.
+    track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ)
+        A decorator for tracking the progress and metrics of a function.
+    log(label: str, value_type: str = "value", history: bool = True)
+        A decorator for logging the metrics of a function.
+    is_best(label: str, key: str) -> bool
+        Checks if the latest value of the given key in the label is the best so far.
+    state_dict() -> dict
+        Returns a dictionary containing the state of the tracker.
+    load_state_dict(state_dict: dict) -> Tracker
+        Loads the state of the tracker from the given state dictionary.
+    """
+
+    def __init__(
+        self,
+        writer: SummaryWriter = None,
+        log_file: str = None,
+        rank: int = 0,
+        console_width: int = 100,
+        step: int = 0,
+    ):
+        """
+        Initializes the Tracker object.
+
+        Parameters
+        ----------
+        writer : SummaryWriter, optional
+            A SummaryWriter object for logging the metrics, by default None.
+        log_file : str, optional
+            The path to the log file, by default None.
+        rank : int, optional
+            The rank of the current process, by default 0.
+        console_width : int, optional
+            The width of the console, by default 100.
+        step : int, optional
+            The current step of the training, by default 0.
+        """
+        self.metrics = {}
+        self.history = {}
+        self.writer = writer
+        self.rank = rank
+        self.step = step
+
+        # Create progress bars etc.
+        self.tasks = {}
+        self.pbar = Progress(
+            SpinnerColumn(),
+            "[progress.description]{task.description}",
+            "{task.completed}/{task.total}",
+            BarColumn(),
+            TimeElapsedColumn(),
+            "/",
+            TimeRemainingColumn(),
+        )
+        self.consoles = [Console(width=console_width)]
+        self.live = Live(console=self.consoles[0], refresh_per_second=10)
+        if log_file is not None:
+            self.consoles.append(Console(width=console_width, file=open(log_file, "a")))
+
+    def print(self, msg):
+        """
+        Prints the given message to all consoles.
+
+        Parameters
+        ----------
+        msg : str
+            The message to be printed.
+        """
+        if self.rank == 0:
+            for c in self.consoles:
+                c.log(msg)
+
+    def update(self, label, fn_name):
+        """
+        Updates the progress bar and table for the given label.
+
+        Parameters
+        ----------
+        label : str
+            The label of the progress bar and table to be updated.
+        fn_name : str
+            The name of the function associated with the label.
+        """
+        if self.rank == 0:
+            self.pbar.advance(self.tasks[label]["pbar"])
+
+            # Create table
+            table = Table(title=label, expand=True, box=box.MINIMAL)
+            table.add_column("key", style="cyan")
+            table.add_column("value", style="bright_blue")
+            table.add_column("mean", style="bright_green")
+
+            keys = self.metrics[label]["value"].keys()
+            for k in keys:
+                value = self.metrics[label]["value"][k]
+                mean = self.metrics[label]["mean"][k]()
+                table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")
+
+            self.tasks[label]["table"] = table
+            tables = [t["table"] for t in self.tasks.values()]
+            group = Group(*tables, self.pbar)
+            self.live.update(
+                Group(
+                    Padding("", (0, 0)),
+                    Rule(f"[italic]{fn_name}()", style="white"),
+                    Padding("", (0, 0)),
+                    Panel.fit(
+                        group, padding=(0, 5), title="[b]Progress", border_style="blue"
+                    ),
+                )
+            )
+
+    def done(self, label: str, title: str):
+        """
+        Resets the progress bar and table for the given label and prints the final result.
+
+        Parameters
+        ----------
+        label : str
+            The label of the progress bar and table to be reset.
+        title : str
+            The title to be displayed when printing the final result.
+        """
+        # Reset the running means for every tracked label. The loop
+        # variable is named ``lbl`` so it does not shadow the ``label``
+        # argument used below to pick the progress bar to reset.
+        for lbl in self.metrics:
+            for v in self.metrics[lbl]["mean"].values():
+                v.reset()
+
+        if self.rank == 0:
+            self.pbar.reset(self.tasks[label]["pbar"])
+            tables = [t["table"] for t in self.tasks.values()]
+            group = Group(Markdown(f"# {title}"), *tables, self.pbar)
+            self.print(group)
+
+    def track(
+        self,
+        label: str,
+        length: int,
+        completed: int = 0,
+        op: dist.ReduceOp = dist.ReduceOp.AVG,
+        ddp_active: bool = "LOCAL_RANK" in os.environ,
+    ):
+        """
+        A decorator for tracking the progress and metrics of a function.
+
+        Parameters
+        ----------
+        label : str
+            The label to be associated with the progress and metrics.
+        length : int
+            The total number of iterations to be completed.
+        completed : int, optional
+            The number of iterations already completed, by default 0.
+        op : dist.ReduceOp, optional
+            The reduce operation to be used, by default dist.ReduceOp.AVG.
+        ddp_active : bool, optional
+            Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ.
+        """
+        self.tasks[label] = {
+            "pbar": self.pbar.add_task(
+                f"[white]Iteration ({label})", total=length, completed=completed
+            ),
+            "table": Table(),
+        }
+        self.metrics[label] = {
+            "value": defaultdict(),
+            "mean": defaultdict(lambda: Mean()),
+        }
+
+        def decorator(fn):
+            @wraps(fn)
+            def decorated(*args, **kwargs):
+                output = fn(*args, **kwargs)
+                if not isinstance(output, dict):
+                    self.update(label, fn.__name__)
+                    return output
+                # Collect across all DDP processes
+                scalar_keys = []
+                for k, v in output.items():
+                    if isinstance(v, (int, float)):
+                        v = torch.tensor([v])
+                    if not torch.is_tensor(v):
+                        continue
+                    if ddp_active and v.is_cuda:  # pragma: no cover
+                        dist.all_reduce(v, op=op)
+                    output[k] = v.detach()
+                    if torch.numel(v) == 1:
+                        scalar_keys.append(k)
+                        output[k] = v.item()
+
+                # Save the outputs to tracker
+                for k, v in output.items():
+                    if k not in scalar_keys:
+                        continue
+                    self.metrics[label]["value"][k] = v
+                    # Update the running mean
+                    self.metrics[label]["mean"][k].update(v)
+
+                self.update(label, fn.__name__)
+                return output
+
+            return decorated
+
+        return decorator
+
+    def log(self, label: str, value_type: str = "value", history: bool = True):
+        """
+        A decorator for logging the metrics of a function.
+
+        Parameters
+        ----------
+        label : str
+            The label to be associated with the logging.
+        value_type : str, optional
+            The type of value to be logged, by default "value".
+        history : bool, optional
+            Whether to save the history of the metrics, by default True.
+        """
+        assert value_type in ["mean", "value"]
+        if history:
+            if label not in self.history:
+                self.history[label] = defaultdict(default_list)
+
+        def decorator(fn):
+            @wraps(fn)
+            def decorated(*args, **kwargs):
+                output = fn(*args, **kwargs)
+                if self.rank == 0:
+                    nonlocal value_type, label
+                    metrics = self.metrics[label][value_type]
+                    for k, v in metrics.items():
+                        v = v() if isinstance(v, Mean) else v
+                        if self.writer is not None:
+                            # NOTE: unlike the ``SummaryWriter`` type hint
+                            # suggests, ``writer`` is used here through a
+                            # ``log_metric(name, value, step=...)`` method
+                            # (an MLflow/W&B-style client), not add_scalar:
+                            # self.writer.add_scalar(f"{k}/{label}", v, self.step)
+                            self.writer.log_metric(f"{k}_{label}", v, step=self.step)
+                        if label in self.history:
+                            self.history[label][k].append(v)
+
+                    if label in self.history:
+                        self.history[label]["step"].append(self.step)
+
+                return output
+
+            return decorated
+
+        return decorator
+
+    def is_best(self, label, key):
+        """
+        Checks if the latest value of the given key in the label is the best
+        so far (lower is considered better, as for a loss).
+
+        Parameters
+        ----------
+        label : str
+            The label of the metrics to be checked.
+        key : str
+            The key of the metric to be checked.
+
+        Returns
+        -------
+        bool
+            True if the latest value is the best (lowest) so far, otherwise False.
+        """
+        return self.history[label][key][-1] == min(self.history[label][key])
+
+    def state_dict(self):
+        """
+        Returns a dictionary containing the state of the tracker.
+
+        Returns
+        -------
+        dict
+            A dictionary containing the history and step of the tracker.
+        """
+        return {"history": self.history, "step": self.step}
+
+    def load_state_dict(self, state_dict):
+        """
+        Loads the state of the tracker from the given state dictionary.
+
+        Parameters
+        ----------
+        state_dict : dict
+            A dictionary containing the history and step of the tracker.
+
+        Returns
+        -------
+        Tracker
+            The tracker object with the loaded state.
+        """
+        self.history = state_dict["history"]
+        self.step = state_dict["step"]
+        return self
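
To close the loop, a hedged sketch of how `Tracker`, `track`, and `log` are typically composed around a training function. The loop body, label, and loss values are illustrative only; `writer=None` skips metric logging, and the history/`is_best` behavior follows directly from the code above:

```python
tracker = Tracker(writer=None, log_file=None)

@tracker.log("train", value_type="mean")
@tracker.track("train", length=100)
def train_loop(step):
    return {"loss": 1.0 / (step + 1)}  # stand-in for a real training step

with tracker.live:                     # drives the rich progress display
    for step in range(100):
        tracker.step = step
        train_loop(step)
    tracker.done("train", "Epoch 0")

print(tracker.is_best("train", "loss"))  # True: the latest loss is the minimum
```

The decorator order matters: `track` (innermost) runs the function and records metrics, then `log` (outermost) reads those metrics, appends them to the history, and would forward them to the writer if one were configured.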