Spaces:
Sleeping
Sleeping
| import os | |
| import math | |
| from copy import deepcopy | |
| import librosa as li | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset | |
| from src.data.dataproperties import DataProperties | |
| from src.constants import ( | |
| SAMPLE_RATE, | |
| HOP_LENGTH | |
| ) | |
| from src.attacks.offline.perturbation.voicebox.pitch import PitchEncoder | |
| from src.attacks.offline.perturbation.voicebox.loudness import LoudnessEncoder | |
| from os import path | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| from typing import Union, Iterable | |
| ################################################################################ | |
| # Cache and load datasets | |
| ################################################################################ | |
def ensure_dir(directory: Union[str, Path]):
    """
    Ensure all directories along given path exist, given directory name.

    Uses ``os.makedirs(..., exist_ok=True)`` instead of the original
    check-then-create sequence, which could raise ``FileExistsError`` if a
    concurrent process created the directory between the ``os.path.exists``
    check and the ``os.makedirs`` call.
    """
    directory = str(directory)
    # empty string means "current directory"; nothing to create
    if directory:
        os.makedirs(directory, exist_ok=True)
class VoiceBoxDataset(Dataset):
    """
    A Dataset object for the LibriSpeech dataset subsets. The required data can
    be downloaded by running the script `download_librispeech.sh`. This class
    takes audio data from the specified directory and caches tensors to disk.
    """

    def __init__(self,
                 split: str,
                 data_dir: str,
                 cache_dir: str,
                 audio_ext: str,
                 signal_length: Union[float, int],
                 scale: Union[float, int],
                 target: str,
                 features: Union[str, Iterable[str]] = None,
                 sample_rate: int = SAMPLE_RATE,
                 hop_length: int = HOP_LENGTH,
                 batch_format: str = 'dict',
                 *args,
                 **kwargs):
        """
        Load, organize, and cache LibriSpeech dataset.

        Parameters
        ----------
        split (str): data subset name
        data_dir (str): dataset root directory
        cache_dir (str): root directory to which tensors will be saved
        audio_ext (str): extension for audio files within dataset
        signal_length (int): length of audio files in samples (if `int` given)
                             or seconds (if `float` given)
        scale (float): range to which audio will be scaled
        target (str): string specifying target type.
        features (Iterable): strings specifying features to compute for each
                             audio file in the dataset. Must be subset of
                             `pitch`, `periodicity`, `loudness`
        sample_rate (int): sample rate in Hz
        hop_length (int): hop size for computing frame-wise features (e.g.
                          pitch, loudness)
        batch_format (str): format for returning batches. Must be either
                            `dict` or `tuple`
        """
        if batch_format not in ['dict', 'tuple']:
            raise ValueError(f'Invalid batch format {batch_format}')
        self.batch_format = batch_format

        self.data_dir = os.fspath(data_dir)
        self.cache_dir = os.fspath(cache_dir)
        self.audio_ext = audio_ext
        self.sample_rate = sample_rate
        self.scale = scale
        self.hop_length = hop_length

        # if signal length is given as floating-point value, assume time in
        # seconds and convert to samples
        if isinstance(signal_length, float):
            self.signal_length = math.floor(signal_length * self.sample_rate)
        else:
            self.signal_length = signal_length

        # compute frame-equivalent signal length for targets/features,
        # accounting for center-padding in spectrogram implementations;
        # yields signal_length // hop_length + 1 frames whether or not the
        # signal length divides evenly
        self.num_frames = math.ceil(self.signal_length / self.hop_length)
        if not self.signal_length % self.hop_length:
            self.num_frames += 1

        # register data properties
        DataProperties.register_properties(
            sample_rate=self.sample_rate,
            signal_length=self.signal_length,
            scale=self.scale
        )

        # check for valid subset
        self.split = self._check_split(split)

        # create split cache directory if necessary (the original code issued
        # this identical call twice; once is sufficient)
        ensure_dir(path.join(self.cache_dir, self.split))

        # check for valid target types
        self.target = self._check_target(target)
        # check for valid feature types
        self.features = self._check_features(features)
        # scan all audio files in dataset
        self.audio_list = self._get_audio_list()

        # check for cached audio, targets, and features by name. If missing,
        # build required caches. Cache files are identified by sample rate and
        # hop size where necessary (e.g. for pitch features, but not class
        # targets)
        self._build_audio_cache()
        self._build_target_cache()
        for feature in self.features:
            self._build_feature_cache(feature)

        # load data and target tensors from caches
        self.tx = torch.load(
            Path(self.cache_dir) /
            self.split /
            f'{self._get_audio_id()}.pt')
        self.ty = torch.load(
            Path(self.cache_dir) /
            self.split /
            f'{self._get_target_id()}.pt')

        # load feature tensors from cache and store by name.
        # `_check_features` always returns a list, so it is safe to iterate
        # directly (the original `is not None` guard was redundant)
        self.tf = dict()
        for feature in self.features:
            self.tf[feature] = torch.load(
                Path(self.cache_dir) /
                self.split /
                f'{self._get_feature_id(feature)}.pt')
| def _check_split(split: str): | |
| if split not in ['train', 'test']: | |
| raise ValueError(f'Invalid split {split}') | |
| return split | |
| def _check_target(target: str): | |
| if target not in ['class', 'transcript']: | |
| raise ValueError(f'Invalid target type {target}') | |
| return target | |
| def _check_features(features: Union[str, Iterable[str]]): | |
| if features is None or not features: | |
| features = [] | |
| else: | |
| if isinstance(features, str): | |
| features = [features] | |
| for f in features: | |
| if f not in ['pitch', 'periodicity', 'loudness']: | |
| raise ValueError(f'Invalid feature type {f}') | |
| return list(features) | |
| def _get_audio_list(self, *args, **kwargs): | |
| """Scan for all audio files with given extension""" | |
| return sorted( | |
| list((Path(self.data_dir) / self.split).rglob( | |
| f'*.{self.audio_ext}')) | |
| ) | |
| def _get_audio_id(self): | |
| """Identifier for cached audio""" | |
| return f'{self.sample_rate}-audio' | |
| def _get_target_id(self): | |
| """Identifier for cached targets""" | |
| if self.target in ['class', 'transcript']: | |
| return f'{self.target}' | |
| else: | |
| return f'{self.sample_rate}-{self.hop_length}-{self.target}' | |
| def _get_feature_id(self, feature: str): | |
| """Identifier for cached features""" | |
| return f'{self.sample_rate}-{self.hop_length}-{feature}' | |
    def _build_audio_cache(self):
        """Load audio data and cache to disk as one padded tensor.

        A no-op when a cache file matching the current sample rate already
        exists anywhere under the split's cache directory.
        """
        audio_id = self._get_audio_id()
        # search recursively for an existing cache for this sample rate
        audio_cache = list(
            (Path(self.cache_dir) / self.split).rglob(
                f'{audio_id}.pt')
        )
        if len(audio_cache) >= 1:
            return
        # prepare to store audio waveforms, truncated or zero-padded to a
        # fixed signal length (original lengths are not retained)
        waveforms = torch.zeros(len(self.audio_list), 1, self.signal_length)
        pbar = tqdm(self.audio_list, total=len(self.audio_list))
        for i, audio_fn in enumerate(pbar):
            pbar.set_description(
                f'Loading {self.split}: {path.basename(audio_fn)}')
            # load audio and resample, but leave original length
            waveform, _ = li.load(audio_fn,
                                  mono=True,
                                  sr=self.sample_rate)
            # copy up to signal_length samples; shorter clips keep their
            # trailing zeros from the pre-allocated tensor
            waveforms[
                i, :, :min(self.signal_length, len(waveform))
            ] = torch.from_numpy(waveform)[..., :self.signal_length]
        # cache padded tensor to disk
        torch.save(waveforms,
                   path.join(
                       self.cache_dir,
                       self.split,
                       f'{audio_id}.pt')
                   )
    def _build_target_cache(self):
        """Load targets and cache to disk.

        Abstract hook: subclasses must override. ``__init__`` calls this
        unconditionally, so the base class cannot be instantiated directly.
        """
        raise NotImplementedError()
| def _build_feature_cache(self, feature: str): | |
| """Load features and cache to disk""" | |
| feature_id = self._get_feature_id(feature) | |
| feature_cache = list( | |
| (Path(self.cache_dir) / self.split).rglob( | |
| f'{feature_id}.pt') | |
| ) | |
| if len(feature_cache) >= 1: | |
| return | |
| # compute f0, periodicity using PyWorld 'dio' algorithm | |
| pitch_extractor = PitchEncoder(hop_length=self.hop_length) | |
| loudness_extractor = LoudnessEncoder(hop_length=self.hop_length) | |
| # determine 'zero' values for each feature | |
| zero_pitch, zero_per = pitch_extractor( | |
| torch.zeros(1, 1, self.signal_length)) | |
| zero_loud = loudness_extractor(torch.zeros(1, 1, self.signal_length)) | |
| pad_val_pitch = zero_pitch.mean().item() | |
| pad_val_per = zero_per.mean().item() | |
| pad_val_loud = zero_loud.mean().item() | |
| # store frame-wise features | |
| if feature == 'loudness': | |
| loudness = torch.full( | |
| (len(self.audio_list), self.num_frames, 1), | |
| pad_val_loud, | |
| dtype=torch.float32 | |
| ) | |
| elif feature in ['pitch', 'periodicity']: | |
| pitch = torch.full( | |
| (len(self.audio_list), self.num_frames, 1), | |
| pad_val_pitch, | |
| dtype=torch.float32 | |
| ) | |
| periodicity = torch.full( | |
| (len(self.audio_list), self.num_frames, 1), | |
| pad_val_per, | |
| dtype=torch.float32 | |
| ) | |
| # iterate over audio | |
| pbar = tqdm(self.audio_list, total=len(self.audio_list)) | |
| for i, audio_fn in enumerate(pbar): | |
| pbar.set_description( | |
| f'Computing {feature} ({self.split}): ' | |
| f'{path.basename(audio_fn)}') | |
| # load audio and resample, but leave original length | |
| waveform, _ = li.load(audio_fn, | |
| mono=True, | |
| sr=self.sample_rate, | |
| duration=self.signal_length / self.sample_rate) | |
| # convert to tensor, insert batch dimension | |
| waveform = torch.from_numpy(waveform).unsqueeze(0) | |
| # trim or pad waveform if necessary | |
| if waveform.shape[-1] >= self.signal_length: | |
| waveform = waveform[..., :self.signal_length] | |
| else: | |
| pad_len = self.signal_length - waveform.shape[-1] | |
| waveform = F.pad(waveform, (0, pad_len)) | |
| # compute and store pitch/periodicity in tandem | |
| if feature in ['pitch', 'periodicity']: | |
| f0, p = pitch_extractor(waveform) | |
| pitch[ | |
| i, :min(f0.shape[1], self.num_frames), : | |
| ] = f0[:, :self.num_frames, :] | |
| periodicity[ | |
| i, :min(p.shape[1], self.num_frames), : | |
| ] = p[:, :self.num_frames, :] | |
| elif feature == 'loudness': | |
| l = loudness_extractor(waveform) | |
| loudness[ | |
| i, :min(l.shape[1], self.num_frames), : | |
| ] = l[:, :self.num_frames, :] | |
| else: | |
| raise ValueError(f'Invalid feature type {feature}') | |
| if feature in ['pitch', 'periodicity']: | |
| # save to disk | |
| torch.save(pitch, | |
| path.join( | |
| self.cache_dir, | |
| self.split, | |
| f'{self._get_feature_id("pitch")}.pt' | |
| )) | |
| torch.save(periodicity, | |
| path.join( | |
| self.cache_dir, | |
| self.split, | |
| f'{self._get_feature_id("periodicity")}.pt' | |
| )) | |
| else: | |
| # save to disk | |
| torch.save(loudness, | |
| path.join( | |
| self.cache_dir, | |
| self.split, | |
| f'{feature_id}.pt' | |
| )) | |
| def __len__(self): | |
| return len(self.tx) | |
| def __getitem__(self, idx): | |
| """Return batch of audio, targets, and optional feature values""" | |
| if self.batch_format == 'dict': | |
| # return batch items by name | |
| batch = { | |
| 'x': self.tx[idx], | |
| 'y': self.ty[idx], | |
| **{k: self.tf[k][idx] for k in self.tf} | |
| } | |
| elif self.batch_format == 'tuple': | |
| # return batch items in order | |
| batch = (self.tx[idx], self.ty[idx]) + tuple( | |
| self.tf[k][idx] for k in self.tf) | |
| else: | |
| raise ValueError(f'Invalid batch format {self.batch_format}') | |
| return batch | |
| def index_reduce(self, idx): | |
| """Reduce to a subset by indexing into all stored tensors""" | |
| new_dataset = deepcopy(self) | |
| new_dataset.tx = new_dataset.tx[idx] | |
| new_dataset.ty = new_dataset.ty[idx] | |
| for feature in new_dataset.features: | |
| new_dataset.tf[feature] = new_dataset.tf[feature][idx] | |
| return new_dataset | |
| def overwrite_dataset(self, x: torch.Tensor, y: torch.Tensor, idx): | |
| """Overwrite inputs and targets, and select features correspondingly""" | |
| # support boolean or integer indices | |
| assert len(idx) <= self.__len__() | |
| assert len(idx) == self.__len__() or \ | |
| (len(idx) == len(x) and len(idx) == len(y)) | |
| new_dataset = self.index_reduce(idx) | |
| new_dataset.tx = x | |
| new_dataset.ty = y | |
| return new_dataset | |