File size: 14,774 Bytes
957e2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
import os
import math
from copy import deepcopy

import librosa as li
import numpy as np

import torch
import torch.nn.functional as F

from torch.utils.data import Dataset

from src.data.dataproperties import DataProperties
from src.constants import (
    SAMPLE_RATE,
    HOP_LENGTH
)
from src.attacks.offline.perturbation.voicebox.pitch import PitchEncoder
from src.attacks.offline.perturbation.voicebox.loudness import LoudnessEncoder

from os import path
from tqdm import tqdm
from pathlib import Path

from typing import Union, Iterable

################################################################################
# Cache and load datasets
################################################################################


def ensure_dir(directory: Union[str, Path]):
    """
    Ensure all directories along the given path exist.

    Parameters
    ----------
    directory (str, Path): directory to create, together with any missing
                           parent directories. An empty path is ignored.
    """
    directory = str(directory)

    # nothing to do for an empty path or for a path that is already present
    if not directory or os.path.exists(directory):
        return

    # `exist_ok` tolerates the directory being created by someone else between
    # the existence check above and this call
    os.makedirs(directory, exist_ok=True)


class VoiceBoxDataset(Dataset):

    """

    A Dataset object for the LibriSpeech dataset subsets. The required data can
    be downloaded by running the script `download_librispeech.sh`. This class
    takes audio data from the specified directory and caches tensors to disk.

    """
    def __init__(self,
                 split: str,
                 data_dir: str,
                 cache_dir: str,
                 audio_ext: str,
                 signal_length: Union[float, int],
                 scale: Union[float, int],
                 target: str,
                 features: Union[str, Iterable[str], None] = None,
                 sample_rate: int = SAMPLE_RATE,
                 hop_length: int = HOP_LENGTH,
                 batch_format: str = 'dict',
                 *args,
                 **kwargs):
        """
        Load, organize, and cache LibriSpeech dataset.

        Parameters
        ----------
        split (str):         data subset name. Must be `train` or `test`

        data_dir (str):      dataset root directory

        cache_dir (str):     root directory to which tensors will be saved

        audio_ext (str):     extension for audio files within dataset

        signal_length (int): length of audio files in samples (if `int` given)
                             or seconds (if `float` given)

        scale (float):       range to which audio will be scaled

        target (str):        string specifying target type. Must be `class`
                             or `transcript`

        features (Iterable): strings specifying features to compute for each
                             audio file in the dataset. Must be subset of
                             `pitch`, `periodicity`, `loudness`

        sample_rate (int):   sample rate in Hz

        hop_length (int):    hop size for computing frame-wise features (e.g.
                             pitch, loudness)

        batch_format (str):  format for returning batches. Must be either `dict`
                             or `tuple`

        Raises
        ------
        ValueError: if `batch_format`, `split`, `target` or any entry of
                    `features` is not one of the supported values
        """

        if batch_format not in ('dict', 'tuple'):
            raise ValueError(f'Invalid batch format {batch_format}')
        self.batch_format = batch_format

        self.data_dir = os.fspath(data_dir)
        self.cache_dir = os.fspath(cache_dir)

        self.audio_ext = audio_ext
        self.sample_rate = sample_rate
        self.scale = scale
        self.hop_length = hop_length

        # if signal length is given as floating-point value, assume time in
        # seconds and convert to samples
        if isinstance(signal_length, float):
            self.signal_length = math.floor(signal_length * self.sample_rate)
        else:
            self.signal_length = signal_length

        # compute frame-equivalent signal length for targets/features,
        # accounting for center-padding in spectrogram implementations: one
        # extra frame is added when the hop size divides the signal length
        self.num_frames = math.ceil(self.signal_length / self.hop_length)
        if not self.signal_length % self.hop_length:
            self.num_frames += 1

        # register data properties
        DataProperties.register_properties(
            sample_rate=self.sample_rate,
            signal_length=self.signal_length,
            scale=self.scale
        )

        # check for valid subset
        self.split = self._check_split(split)

        # create the per-split cache directory if necessary
        ensure_dir(path.join(self.cache_dir, self.split))

        # check for valid target types
        self.target = self._check_target(target)

        # check for valid feature types
        self.features = self._check_features(features)

        # scan all audio files in dataset
        self.audio_list = self._get_audio_list()

        # check for cached audio, targets, and features by name. If missing,
        # build required caches. Cache files are identified by sample rate and
        # hop size where necessary (e.g. for pitch features, but not class
        # targets)
        self._build_audio_cache()
        self._build_target_cache()
        for feature in self.features:
            self._build_feature_cache(feature)

        # load data and target tensors from caches
        cache_root = Path(self.cache_dir) / self.split
        self.tx = torch.load(cache_root / f'{self._get_audio_id()}.pt')
        self.ty = torch.load(cache_root / f'{self._get_target_id()}.pt')

        # load feature tensors from cache and store by name
        self.tf = dict()
        for feature in self.features:
            self.tf[feature] = torch.load(
                cache_root / f'{self._get_feature_id(feature)}.pt')

    @staticmethod
    def _check_split(split: str):
        if split not in ['train', 'test']:
            raise ValueError(f'Invalid split {split}')
        return split

    @staticmethod
    def _check_target(target: str):
        if target not in ['class', 'transcript']:
            raise ValueError(f'Invalid target type {target}')
        return target

    @staticmethod
    def _check_features(features: Union[str, Iterable[str]]):
        if features is None or not features:
            features = []
        else:
            if isinstance(features, str):
                features = [features]

            for f in features:
                if f not in ['pitch', 'periodicity', 'loudness']:
                    raise ValueError(f'Invalid feature type {f}')
        return list(features)

    def _get_audio_list(self, *args, **kwargs):
        """Scan for all audio files with given extension"""
        return sorted(
            list((Path(self.data_dir) / self.split).rglob(
                f'*.{self.audio_ext}'))
        )

    def _get_audio_id(self):
        """Identifier for cached audio"""
        return f'{self.sample_rate}-audio'

    def _get_target_id(self):
        """Identifier for cached targets"""
        if self.target in ['class', 'transcript']:
            return f'{self.target}'
        else:
            return f'{self.sample_rate}-{self.hop_length}-{self.target}'

    def _get_feature_id(self, feature: str):
        """Identifier for cached features"""
        return f'{self.sample_rate}-{self.hop_length}-{feature}'

    def _build_audio_cache(self):
        """Load every scanned audio file into one tensor and cache it to disk"""

        cache_name = f'{self._get_audio_id()}.pt'
        cache_root = Path(self.cache_dir) / self.split

        # an existing cache file of this name anywhere below the split's cache
        # directory means there is nothing to build
        if next(cache_root.rglob(cache_name), None) is not None:
            return

        # zero-initialized, so recordings shorter than `signal_length` end up
        # zero-padded on the right
        num_files = len(self.audio_list)
        stacked = torch.zeros(num_files, 1, self.signal_length)

        progress = tqdm(self.audio_list, total=num_files)
        for row, audio_path in enumerate(progress):
            progress.set_description(
                f'Loading {self.split}: {path.basename(audio_path)}')

            # load the whole file at the target sample rate, then keep at most
            # `signal_length` samples
            samples, _ = li.load(audio_path, mono=True, sr=self.sample_rate)
            kept = torch.from_numpy(samples)[..., :self.signal_length]
            stacked[row, :, :min(self.signal_length, len(samples))] = kept

        # cache the stacked waveforms to disk
        torch.save(
            stacked,
            path.join(self.cache_dir, self.split, cache_name)
        )

    def _build_target_cache(self):
        """Load targets and cache to disk"""
        raise NotImplementedError()

    def _build_feature_cache(self, feature: str):
        """
        Compute a frame-wise feature for every audio file and cache to disk.

        Pitch and periodicity are returned by the same extractor call, so
        requesting either one computes and caches both. Nothing is done if a
        cache file for `feature` already exists below the split's cache
        directory.

        Parameters
        ----------
        feature (str): one of `pitch`, `periodicity`, `loudness`

        Raises
        ------
        ValueError: if `feature` is not a supported feature type
        """

        feature_id = self._get_feature_id(feature)
        cache_root = Path(self.cache_dir) / self.split
        if next(cache_root.rglob(f'{feature_id}.pt'), None) is not None:
            return

        # reject unknown features before any extractor is built or any audio
        # is loaded, so that an empty audio list cannot reach the save step
        # with nothing allocated
        if feature not in ('pitch', 'periodicity', 'loudness'):
            raise ValueError(f'Invalid feature type {feature}')

        is_loudness = feature == 'loudness'
        silence = torch.zeros(1, 1, self.signal_length)

        # build only the extractor that is needed and determine the 'zero'
        # value of each feature as the extractor's mean response to silence
        if is_loudness:
            extractor = LoudnessEncoder(hop_length=self.hop_length)
            pad_values = {'loudness': extractor(silence).mean().item()}
        else:
            extractor = PitchEncoder(hop_length=self.hop_length)
            zero_pitch, zero_per = extractor(silence)
            pad_values = {
                'pitch': zero_pitch.mean().item(),
                'periodicity': zero_per.mean().item(),
            }

        # store frame-wise features, pre-filled with the 'zero' value so that
        # frames never written below keep it
        shape = (len(self.audio_list), self.num_frames, 1)
        stored = {
            name: torch.full(shape, pad_val, dtype=torch.float32)
            for name, pad_val in pad_values.items()
        }

        # iterate over audio
        pbar = tqdm(self.audio_list, total=len(self.audio_list))
        for i, audio_fn in enumerate(pbar):
            pbar.set_description(
                f'Computing {feature} ({self.split}): '
                f'{path.basename(audio_fn)}')

            # load and resample at most `signal_length` samples' worth of audio
            waveform, _ = li.load(audio_fn,
                                  mono=True,
                                  sr=self.sample_rate,
                                  duration=self.signal_length / self.sample_rate)

            # convert to tensor, insert batch dimension
            waveform = torch.from_numpy(waveform).unsqueeze(0)

            # trim or pad waveform if necessary
            if waveform.shape[-1] >= self.signal_length:
                waveform = waveform[..., :self.signal_length]
            else:
                pad_len = self.signal_length - waveform.shape[-1]
                waveform = F.pad(waveform, (0, pad_len))

            # the pitch extractor yields (pitch, periodicity) in that order,
            # matching the key order of `stored`; loudness is a single output
            extracted = extractor(waveform)
            if is_loudness:
                extracted = (extracted,)

            for store, values in zip(stored.values(), extracted):
                num_valid = min(values.shape[1], self.num_frames)
                store[i, :num_valid, :] = values[:, :self.num_frames, :]

        # save to disk, one cache file per computed feature
        for name, values in stored.items():
            torch.save(values,
                       path.join(
                           self.cache_dir,
                           self.split,
                           f'{self._get_feature_id(name)}.pt'
                       ))

    def __len__(self):
        return len(self.tx)

    def __getitem__(self, idx):
        """Return batch of audio, targets, and optional feature values"""

        if self.batch_format == 'dict':
            # return batch items by name
            batch = {
                'x': self.tx[idx],
                'y': self.ty[idx],
                **{k: self.tf[k][idx] for k in self.tf}
            }
        elif self.batch_format == 'tuple':
            # return batch items in order
            batch = (self.tx[idx], self.ty[idx]) + tuple(
                self.tf[k][idx] for k in self.tf)
        else:
            raise ValueError(f'Invalid batch format {self.batch_format}')

        return batch

    def index_reduce(self, idx):
        """Reduce to a subset by indexing into all stored tensors"""

        new_dataset = deepcopy(self)
        new_dataset.tx = new_dataset.tx[idx]
        new_dataset.ty = new_dataset.ty[idx]
        for feature in new_dataset.features:
            new_dataset.tf[feature] = new_dataset.tf[feature][idx]

        return new_dataset

    def overwrite_dataset(self, x: torch.Tensor, y: torch.Tensor, idx):
        """Overwrite inputs and targets, and select features correspondingly"""

        # support boolean or integer indices
        assert len(idx) <= self.__len__()
        assert len(idx) == self.__len__() or \
               (len(idx) == len(x) and len(idx) == len(y))

        new_dataset = self.index_reduce(idx)
        new_dataset.tx = x
        new_dataset.ty = y

        return new_dataset