# Voiceblock demo: Attempt 8 — author ALeLacheur, commit 957e2dc
import os
import math
from copy import deepcopy
import librosa as li
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from src.data.dataproperties import DataProperties
from src.constants import (
SAMPLE_RATE,
HOP_LENGTH
)
from src.attacks.offline.perturbation.voicebox.pitch import PitchEncoder
from src.attacks.offline.perturbation.voicebox.loudness import LoudnessEncoder
from os import path
from tqdm import tqdm
from pathlib import Path
from typing import Union, Iterable
################################################################################
# Cache and load datasets
################################################################################
def ensure_dir(directory: Union[str, Path]):
    """
    Ensure all directories along given path exist, given directory name.

    Parameters
    ----------
    directory (str or Path): directory path to create. An empty string is a
                             no-op.
    """
    directory = str(directory)
    if len(directory) > 0:
        # `exist_ok=True` avoids the check-then-create race that the previous
        # `os.path.exists` guard allowed when several processes prepare the
        # same cache directory concurrently
        os.makedirs(directory, exist_ok=True)
class VoiceBoxDataset(Dataset):
    """
    A Dataset object for the LibriSpeech dataset subsets. The required data can
    be downloaded by running the script `download_librispeech.sh`. This class
    takes audio data from the specified directory and caches tensors to disk.
    """
    def __init__(self,
                 split: str,
                 data_dir: str,
                 cache_dir: str,
                 audio_ext: str,
                 signal_length: Union[float, int],
                 scale: Union[float, int],
                 target: str,
                 features: Union[str, Iterable[str]] = None,
                 sample_rate: int = SAMPLE_RATE,
                 hop_length: int = HOP_LENGTH,
                 batch_format: str = 'dict',
                 *args,
                 **kwargs):
        """
        Load, organize, and cache LibriSpeech dataset.

        Parameters
        ----------
        split (str): data subset name; must be 'train' or 'test'
        data_dir (str): dataset root directory
        cache_dir (str): root directory to which tensors will be saved
        audio_ext (str): extension for audio files within dataset
        signal_length (int): length of audio files in samples (if `int` given)
                             or seconds (if `float` given)
        scale (float): range to which audio will be scaled
        target (str): string specifying target type ('class' or 'transcript')
        features (Iterable): strings specifying features to compute for each
                             audio file in the dataset. Must be subset of
                             `pitch`, `periodicity`, `loudness`
        sample_rate (int): sample rate in Hz
        hop_length (int): hop size for computing frame-wise features (e.g.
                          pitch, loudness)
        batch_format (str): format for returning batches. Must be either
                            `dict` or `tuple`

        Raises
        ------
        ValueError: if `batch_format`, `split`, `target`, or any entry of
                    `features` is invalid
        """
        if batch_format not in ['dict', 'tuple']:
            raise ValueError(f'Invalid batch format {batch_format}')
        self.batch_format = batch_format

        self.data_dir = os.fspath(data_dir)
        self.cache_dir = os.fspath(cache_dir)
        self.audio_ext = audio_ext
        self.sample_rate = sample_rate
        self.scale = scale
        self.hop_length = hop_length

        # if signal length is given as floating-point value, assume time in
        # seconds and convert to samples
        if isinstance(signal_length, float):
            self.signal_length = math.floor(signal_length * self.sample_rate)
        else:
            self.signal_length = signal_length

        # compute frame-equivalent signal length for targets/features,
        # accounting for center-padding in spectrogram implementations
        # (works out to signal_length // hop_length + 1 in all cases)
        self.num_frames = math.ceil(self.signal_length / self.hop_length)
        if not self.signal_length % self.hop_length:
            self.num_frames += 1

        # register data properties for downstream components
        DataProperties.register_properties(
            sample_rate=self.sample_rate,
            signal_length=self.signal_length,
            scale=self.scale
        )

        # check for valid subset
        self.split = self._check_split(split)

        # create cache directory if necessary (previously called twice with
        # identical arguments; once is sufficient)
        ensure_dir(path.join(self.cache_dir, self.split))

        # check for valid target types
        self.target = self._check_target(target)

        # check for valid feature types
        self.features = self._check_features(features)

        # scan all audio files in dataset
        self.audio_list = self._get_audio_list()

        # check for cached audio, targets, and features by name. If missing,
        # build required caches. Cache files are identified by sample rate and
        # hop size where necessary (e.g. for pitch features, but not class
        # targets)
        self._build_audio_cache()
        self._build_target_cache()
        for feature in self.features:
            self._build_feature_cache(feature)

        # load data and target tensors from caches
        self.tx = torch.load(
            Path(self.cache_dir) /
            self.split /
            f'{self._get_audio_id()}.pt')
        self.ty = torch.load(
            Path(self.cache_dir) /
            self.split /
            f'{self._get_target_id()}.pt')

        # load feature tensors from cache and store by name;
        # `_check_features` always returns a list, so no None-guard is needed
        self.tf = dict()
        for feature in self.features:
            self.tf[feature] = torch.load(
                Path(self.cache_dir) /
                self.split /
                f'{self._get_feature_id(feature)}.pt')

    @staticmethod
    def _check_split(split: str):
        """Validate subset name; returns it unchanged or raises ValueError"""
        if split not in ['train', 'test']:
            raise ValueError(f'Invalid split {split}')
        return split

    @staticmethod
    def _check_target(target: str):
        """Validate target type; returns it unchanged or raises ValueError"""
        if target not in ['class', 'transcript']:
            raise ValueError(f'Invalid target type {target}')
        return target

    @staticmethod
    def _check_features(features: Union[str, Iterable[str]]):
        """
        Normalize `features` to a list. Accepts None / empty (-> []), a single
        string, or an iterable of strings; raises ValueError on unknown names.
        """
        if features is None or not features:
            features = []
        else:
            if isinstance(features, str):
                features = [features]
            for f in features:
                if f not in ['pitch', 'periodicity', 'loudness']:
                    raise ValueError(f'Invalid feature type {f}')
        return list(features)

    def _get_audio_list(self, *args, **kwargs):
        """Scan for all audio files with given extension"""
        return sorted(
            list((Path(self.data_dir) / self.split).rglob(
                f'*.{self.audio_ext}'))
        )

    def _get_audio_id(self):
        """Identifier for cached audio"""
        return f'{self.sample_rate}-audio'

    def _get_target_id(self):
        """Identifier for cached targets"""
        # 'class'/'transcript' targets do not depend on sample rate or hop
        # size; the fallback branch is defensive (currently unreachable, since
        # `_check_target` only admits the two types above)
        if self.target in ['class', 'transcript']:
            return f'{self.target}'
        else:
            return f'{self.sample_rate}-{self.hop_length}-{self.target}'

    def _get_feature_id(self, feature: str):
        """Identifier for cached features"""
        return f'{self.sample_rate}-{self.hop_length}-{feature}'

    def _build_audio_cache(self):
        """Load audio data and cache to disk (no-op if cache exists)"""
        audio_id = self._get_audio_id()
        audio_cache = list(
            (Path(self.cache_dir) / self.split).rglob(
                f'{audio_id}.pt')
        )
        if len(audio_cache) >= 1:
            return

        # prepare to store fixed-length (zero-padded) audio waveforms
        waveforms = torch.zeros(len(self.audio_list), 1, self.signal_length)

        pbar = tqdm(self.audio_list, total=len(self.audio_list))
        for i, audio_fn in enumerate(pbar):
            pbar.set_description(
                f'Loading {self.split}: {path.basename(audio_fn)}')

            # load audio and resample, but leave original length
            waveform, _ = li.load(audio_fn,
                                  mono=True,
                                  sr=self.sample_rate)

            # copy up to `signal_length` samples; shorter clips remain
            # zero-padded at the end
            waveforms[
                i, :, :min(self.signal_length, len(waveform))
            ] = torch.from_numpy(waveform)[..., :self.signal_length]

        # cache padded tensors to disk
        torch.save(waveforms,
                   path.join(
                       self.cache_dir,
                       self.split,
                       f'{audio_id}.pt')
                   )

    def _build_target_cache(self):
        """Load targets and cache to disk (subclass responsibility)"""
        raise NotImplementedError()

    def _build_feature_cache(self, feature: str):
        """
        Compute frame-wise feature for all audio and cache to disk (no-op if
        cache exists). 'pitch' and 'periodicity' are computed in tandem and
        both cached regardless of which was requested.
        """
        feature_id = self._get_feature_id(feature)
        feature_cache = list(
            (Path(self.cache_dir) / self.split).rglob(
                f'{feature_id}.pt')
        )
        if len(feature_cache) >= 1:
            return

        # allocate frame-wise feature tensors pre-filled with each feature's
        # 'zero' value (the feature of an all-zeros signal), so unwritten /
        # padded frames hold a sensible default rather than 0.0. Only the
        # extractor actually needed for this feature is constructed.
        if feature == 'loudness':
            loudness_extractor = LoudnessEncoder(hop_length=self.hop_length)
            zero_loud = loudness_extractor(
                torch.zeros(1, 1, self.signal_length))
            loudness = torch.full(
                (len(self.audio_list), self.num_frames, 1),
                zero_loud.mean().item(),
                dtype=torch.float32
            )
        elif feature in ['pitch', 'periodicity']:
            # f0 / periodicity via PitchEncoder (PyWorld 'dio' algorithm)
            pitch_extractor = PitchEncoder(hop_length=self.hop_length)
            zero_pitch, zero_per = pitch_extractor(
                torch.zeros(1, 1, self.signal_length))
            pitch = torch.full(
                (len(self.audio_list), self.num_frames, 1),
                zero_pitch.mean().item(),
                dtype=torch.float32
            )
            periodicity = torch.full(
                (len(self.audio_list), self.num_frames, 1),
                zero_per.mean().item(),
                dtype=torch.float32
            )
        else:
            # fail fast on an invalid feature name, before any audio is loaded
            raise ValueError(f'Invalid feature type {feature}')

        # iterate over audio
        pbar = tqdm(self.audio_list, total=len(self.audio_list))
        for i, audio_fn in enumerate(pbar):
            pbar.set_description(
                f'Computing {feature} ({self.split}): '
                f'{path.basename(audio_fn)}')

            # load audio and resample, capping duration at the target length
            waveform, _ = li.load(audio_fn,
                                  mono=True,
                                  sr=self.sample_rate,
                                  duration=self.signal_length / self.sample_rate)

            # convert to tensor, insert batch dimension
            waveform = torch.from_numpy(waveform).unsqueeze(0)

            # trim or zero-pad waveform to fixed length
            if waveform.shape[-1] >= self.signal_length:
                waveform = waveform[..., :self.signal_length]
            else:
                pad_len = self.signal_length - waveform.shape[-1]
                waveform = F.pad(waveform, (0, pad_len))

            if feature in ['pitch', 'periodicity']:
                # compute and store pitch/periodicity in tandem
                f0, per = pitch_extractor(waveform)
                pitch[
                    i, :min(f0.shape[1], self.num_frames), :
                ] = f0[:, :self.num_frames, :]
                periodicity[
                    i, :min(per.shape[1], self.num_frames), :
                ] = per[:, :self.num_frames, :]
            else:  # 'loudness' (only remaining possibility after validation)
                loud = loudness_extractor(waveform)
                loudness[
                    i, :min(loud.shape[1], self.num_frames), :
                ] = loud[:, :self.num_frames, :]

        # save to disk
        if feature in ['pitch', 'periodicity']:
            torch.save(pitch,
                       path.join(
                           self.cache_dir,
                           self.split,
                           f'{self._get_feature_id("pitch")}.pt'
                       ))
            torch.save(periodicity,
                       path.join(
                           self.cache_dir,
                           self.split,
                           f'{self._get_feature_id("periodicity")}.pt'
                       ))
        else:
            torch.save(loudness,
                       path.join(
                           self.cache_dir,
                           self.split,
                           f'{feature_id}.pt'
                       ))

    def __len__(self):
        # first dimension of the cached audio tensor = number of examples
        return len(self.tx)

    def __getitem__(self, idx):
        """Return batch of audio, targets, and optional feature values"""
        if self.batch_format == 'dict':
            # return batch items by name
            batch = {
                'x': self.tx[idx],
                'y': self.ty[idx],
                **{k: self.tf[k][idx] for k in self.tf}
            }
        elif self.batch_format == 'tuple':
            # return batch items in order
            batch = (self.tx[idx], self.ty[idx]) + tuple(
                self.tf[k][idx] for k in self.tf)
        else:
            raise ValueError(f'Invalid batch format {self.batch_format}')
        return batch

    def index_reduce(self, idx):
        """Reduce to a subset by indexing into all stored tensors"""
        new_dataset = deepcopy(self)
        new_dataset.tx = new_dataset.tx[idx]
        new_dataset.ty = new_dataset.ty[idx]
        for feature in new_dataset.features:
            new_dataset.tf[feature] = new_dataset.tf[feature][idx]
        return new_dataset

    def overwrite_dataset(self, x: torch.Tensor, y: torch.Tensor, idx):
        """Overwrite inputs and targets, and select features correspondingly"""
        # support boolean or integer indices: either a full-length boolean
        # mask, or an index list matching the replacement tensors
        assert len(idx) <= self.__len__()
        assert len(idx) == self.__len__() or \
            (len(idx) == len(x) and len(idx) == len(y))
        new_dataset = self.index_reduce(idx)
        new_dataset.tx = x
        new_dataset.ty = y
        return new_dataset