| | |
| | |
| | |
| | |
| | |
| | """AudioDataset support. In order to handle a larger number of files |
| | without having to scan again the folders, we precompute some metadata |
| | (filename, sample rate, duration), and use that to efficiently sample audio segments. |
| | """ |
| | import argparse |
| | import copy |
| | from concurrent.futures import ThreadPoolExecutor, Future |
| | from dataclasses import dataclass, fields |
| | from contextlib import ExitStack |
| | from functools import lru_cache |
| | import gzip |
| | import json |
| | import logging |
| | import os |
| | from pathlib import Path |
| | import random |
| | import sys |
| | import typing as tp |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from .audio import audio_read, audio_info |
| | from .audio_utils import convert_audio |
| | from .zip import PathInZip |
| |
|
| | try: |
| | import dora |
| | except ImportError: |
| | dora = None |
| |
|
| |
|
@dataclass(order=True)
class BaseInfo:
    """Dataclass base providing conversion to and from plain dictionaries.

    Subclasses declare regular dataclass fields; the helpers below only ever
    consider those declared fields.
    """

    @classmethod
    def _dict2fields(cls, dictionary: dict):
        """Return the entries of `dictionary` whose key is a dataclass field of `cls`.

        Keys that are not fields are dropped; fields absent from `dictionary`
        are simply absent from the result.
        """
        selected = {}
        for fld in fields(cls):
            if fld.name in dictionary:
                selected[fld.name] = dictionary[fld.name]
        return selected

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Instantiate `cls` from `dictionary`, ignoring keys that are not fields."""
        kwargs = cls._dict2fields(dictionary)
        return cls(**kwargs)

    def to_dict(self):
        """Return a dict mapping each dataclass field name to its current value."""
        result = {}
        for fld in fields(self):
            result[fld.name] = getattr(self, fld.name)
        return result
| |
|
| |
|
@dataclass(order=True)
class AudioMeta(BaseInfo):
    """Metadata describing one audio file of the dataset."""
    path: str  # location of the audio file
    duration: float  # duration in seconds, as reported by `audio_info`
    sample_rate: int
    amplitude: tp.Optional[float] = None  # peak absolute amplitude, only set for non-minimal metadata
    weight: tp.Optional[float] = None  # optional multiplicative sampling weight
    # Optional location of additional information stored inside a zip archive.
    info_path: tp.Optional[PathInZip] = None

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an AudioMeta from a dict, turning a serialized `info_path` back into a PathInZip."""
        kwargs = cls._dict2fields(dictionary)
        raw_info_path = kwargs.get('info_path')
        if raw_info_path is not None:
            kwargs['info_path'] = PathInZip(raw_info_path)
        return cls(**kwargs)

    def to_dict(self):
        """Return a JSON-friendly dict; a set `info_path` is converted to its string form."""
        serialized = super().to_dict()
        if serialized['info_path'] is not None:
            serialized['info_path'] = str(serialized['info_path'])
        return serialized
| |
|
| |
|
@dataclass(order=True)
class SegmentInfo(BaseInfo):
    """Description of an audio item, returned alongside the waveform by `AudioDataset` when `return_info=True`."""
    meta: AudioMeta  # metadata of the file the audio was read from
    seek_time: float  # offset (in seconds) at which reading started in the source file
    # The size attributes below describe the audio after conversion to the dataset's
    # target sample rate and number of channels.
    n_frames: int  # number of frames actually read, before any padding
    total_frames: int  # target number of frames (the length after padding, when padding is enabled)
    sample_rate: int  # sample rate of the returned audio
    channels: int  # number of channels of the returned audio
| |
|
| |
|
# File extensions considered as audio files when scanning folders. File suffixes are
# lowercased before matching, so entries are expected to be lowercase and include the dot.
DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']

# Module-level logger.
logger = logging.getLogger(__name__)
| |
|
| |
|
def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
    """Build the AudioMeta of a single audio file.

    Args:
        file_path (str): Resolved path of a valid audio file.
        minimal (bool): If True, only `audio_info` is queried. If False, the whole file is
            also decoded with `audio_read` to compute its peak absolute amplitude (slower).
    Returns:
        AudioMeta: Path, duration and sample rate of the file; `amplitude` stays None
        when `minimal` is True.
    """
    info = audio_info(file_path)
    peak: tp.Optional[float]
    if minimal:
        peak = None
    else:
        wav, _ = audio_read(file_path)
        peak = wav.abs().max().item()
    return AudioMeta(file_path, info.duration, info.sample_rate, peak)
| |
|
| |
|
def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
    """If Dora is available as a dependency, try to resolve potential relative paths
    in an AudioMeta. This method is expected to be used when loading meta from file.

    Both the audio file path and, when present, the zip archive path of `info_path`
    are resolved with `dora.git_save.to_absolute_path`. The object is modified in place.

    Args:
        m (AudioMeta): Audio meta to resolve.
        fast (bool): If True, a path is considered absolute when it starts with '/'
            (a really fast check, only valid on Linux/Mac). Otherwise `os.path.isabs` is used.
    Returns:
        AudioMeta: The same audio meta object, with resolved paths. Returned unchanged
        when Dora is not installed.
    """
    def is_abs(p) -> bool:
        if fast:
            # `startswith` (rather than indexing) also copes with an empty path.
            return str(p).startswith('/')
        return os.path.isabs(str(p))

    if not dora:
        return m

    if not is_abs(m.path):
        m.path = dora.git_save.to_absolute_path(m.path)
    if m.info_path is not None and not is_abs(m.info_path.zip_path):
        # Resolve the zip archive's own path, not the audio file path.
        m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
    return m
| |
|
| |
|
def find_audio_files(path: tp.Union[Path, str],
                     exts: tp.List[str] = DEFAULT_EXTS,
                     resolve: bool = True,
                     minimal: bool = True,
                     progress: bool = False,
                     workers: int = 0) -> tp.List[AudioMeta]:
    """Recursively collect the audio files under `path` and fetch their metadata.

    Args:
        path (str or Path): Folder to scan; symbolic links to folders are followed.
        exts (list of str): Accepted file extensions, compared with the lowercased file suffix.
        resolve (bool): Whether to pass each AudioMeta through `_resolve_audio_meta`.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
        progress (bool): Whether to print progress while scanning and reading metadata.
        workers (int): Number of threads used to read metadata; 0 uses only the current thread.
    Returns:
        list of AudioMeta: Sorted metadata of every file that could be read. Files whose
        metadata cannot be read are reported on stderr and skipped.
    """
    found: tp.List[Path] = []
    pending: tp.List[Future] = []
    executor: tp.Optional[ThreadPoolExecutor] = None
    collected: tp.List[AudioMeta] = []
    with ExitStack() as stack:
        if workers > 0:
            # The executor is shut down when leaving the `with` block.
            executor = stack.enter_context(ThreadPoolExecutor(workers))

        if progress:
            print("Finding audio files...")
        for root, _, filenames in os.walk(path, followlinks=True):
            for name in filenames:
                candidate = Path(root) / name
                if candidate.suffix.lower() not in exts:
                    continue
                found.append(candidate)
                if executor is not None:
                    # Start reading metadata while the scan goes on.
                    pending.append(executor.submit(_get_audio_meta, str(candidate), minimal))
                if progress:
                    print(format(len(found), " 8d"), end='\r', file=sys.stderr)

        if progress:
            print("Getting audio metadata...")
        total = len(found)
        for position, audio_path in enumerate(found):
            try:
                if executor is not None:
                    item = pending[position].result()
                else:
                    item = _get_audio_meta(str(audio_path), minimal)
                if resolve:
                    item = _resolve_audio_meta(item)
            except Exception as err:
                print("Error with", str(audio_path), err, file=sys.stderr)
                continue
            collected.append(item)
            if progress:
                print(format((1 + position) / total, " 3.1%"), end='\r', file=sys.stderr)
    collected.sort()
    return collected
| |
|
| |
|
def load_audio_meta(path: tp.Union[str, Path],
                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
    """Load list of AudioMeta from an optionally compressed JSON lines file.

    Each non-empty line holds one JSON object as written by `save_audio_meta`.

    Args:
        path (str or Path): Path to the JSON lines file. A `.gz` suffix (case insensitive)
            selects gzip decompression.
        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
        fast (bool): Activates a faster absolute-path check, forwarded to `_resolve_audio_meta`.
    Returns:
        list of AudioMeta: One entry per non-empty line, in file order.
    """
    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
    meta: tp.List[AudioMeta] = []
    with open_fn(path, 'rb') as fp:
        # Stream the file line by line instead of materializing every line up front.
        for line in fp:
            # Tolerate blank lines (e.g. an extra trailing newline) instead of failing
            # with a JSON decoding error.
            if not line.strip():
                continue
            m = AudioMeta.from_dict(json.loads(line))
            if resolve:
                m = _resolve_audio_meta(m, fast=fast)
            meta.append(m)
    return meta
| |
|
| |
|
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
    """Write audio metadata to `path` as JSON lines, one AudioMeta per line.

    Parent directories are created if needed. The output is gzip-compressed when
    `path` ends with `.gz` (case insensitive). Text is encoded as UTF-8.

    Args:
        path (str or Path): Destination JSON lines file.
        meta (list of AudioMeta): Audio metadata to save.
    """
    Path(path).parent.mkdir(exist_ok=True, parents=True)
    compressed = str(path).lower().endswith('.gz')
    opener = gzip.open if compressed else open
    with opener(path, 'wb') as fp:
        for entry in meta:
            record = json.dumps(entry.to_dict())
            fp.write((record + '\n').encode('utf-8'))
| |
|
| |
|
class AudioDataset:
    """Base audio dataset.

    The dataset takes a list of AudioMeta and creates a dataset composed of segments of audio
    and potentially additional information, by creating random segments from the list of audio
    files referenced in the metadata and applying minimal data pre-processing such as resampling,
    mixing of channels, padding, etc.

    If no segment_duration value is provided, the AudioDataset will return the full wav for each
    audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
    duration, applying padding if required.

    By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
    allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
    original audio meta.

    Note that you can call `start_epoch(epoch)` in order to get
    a deterministic "randomization" for `shuffle=True`.
    For a given epoch and dataset index, this will always return the same extract.
    You can get back some diversity by setting the `shuffle_seed` param.

    Args:
        meta (list of AudioMeta): List of audio files metadata.
        segment_duration (float, optional): Optional segment duration of audio to load.
            If not specified, the dataset will load the full audio segment from the file.
        shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
        num_samples (int): Length of the dataset when `segment_duration` is provided.
            When `segment_duration` is None, the number of audio files is used instead.
        sample_rate (int): Target sample rate of the loaded audio samples.
        channels (int): Target number of channels of the loaded audio samples.
        pad (bool): Whether to zero-pad segments shorter than `segment_duration` and, when
            `segment_duration` is None, to pad full files to the longest one when collating.
        sample_on_duration (bool): Set to `True` to sample segments with probability
            dependent on audio file duration. This is only used if `segment_duration` is provided.
        sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
            `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
            of the file duration and file weight. This is only used if `segment_duration` is provided.
        min_segment_ratio (float): Minimum segment ratio to use when the audio file
            is shorter than the desired segment.
        max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
        return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
        min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
            audio shorter than this will be filtered out.
        max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
            audio longer than this will be filtered out.
        shuffle_seed (int): can be used to further randomize
        load_wav (bool): if False, skip loading the wav but returns a tensor of 0
            with the expected segment_duration (which must be provided if load_wav is False).
        permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
            are False. Will ensure a permutation on files when going through the dataset.
            In that case the epoch number must be provided in order for the model
            to continue the permutation across epochs. In that case, it is assumed
            that `num_samples = total_batch_size * num_updates_per_epoch`, with
            `total_batch_size` the overall batch size accounting for all gpus.
    """
    def __init__(self,
                 meta: tp.List[AudioMeta],
                 segment_duration: tp.Optional[float] = None,
                 shuffle: bool = True,
                 num_samples: int = 10_000,
                 sample_rate: int = 48_000,
                 channels: int = 2,
                 pad: bool = True,
                 sample_on_duration: bool = True,
                 sample_on_weight: bool = True,
                 min_segment_ratio: float = 0.5,
                 max_read_retry: int = 10,
                 return_info: bool = False,
                 min_audio_duration: tp.Optional[float] = None,
                 max_audio_duration: tp.Optional[float] = None,
                 shuffle_seed: int = 0,
                 load_wav: bool = True,
                 permutation_on_files: bool = False,
                 ):
        assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
        assert segment_duration is None or segment_duration > 0
        assert segment_duration is None or min_segment_ratio >= 0
        self.segment_duration = segment_duration
        self.min_segment_ratio = min_segment_ratio
        self.max_audio_duration = max_audio_duration
        self.min_audio_duration = min_audio_duration
        if self.min_audio_duration is not None and self.max_audio_duration is not None:
            assert self.min_audio_duration <= self.max_audio_duration
        self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
        assert len(self.meta), "No audio file left after filtering on min/max audio duration."
        self.total_duration = sum(d.duration for d in self.meta)

        if segment_duration is None:
            # Each item is a full file, so the dataset length is the number of files.
            num_samples = len(self.meta)
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.sample_rate = sample_rate
        self.channels = channels
        self.pad = pad
        self.sample_on_weight = sample_on_weight
        self.sample_on_duration = sample_on_duration
        self.sampling_probabilities = self._get_sampling_probabilities()
        self.max_read_retry = max_read_retry
        self.return_info = return_info
        self.shuffle_seed = shuffle_seed
        self.current_epoch: tp.Optional[int] = None
        self.load_wav = load_wav
        if not load_wav:
            assert segment_duration is not None
        self.permutation_on_files = permutation_on_files
        if permutation_on_files:
            assert not self.sample_on_duration
            assert not self.sample_on_weight
            assert self.shuffle

    def start_epoch(self, epoch: int):
        """Set the current epoch, used to seed the per-item sampling deterministically."""
        self.current_epoch = epoch

    def __len__(self):
        return self.num_samples

    def _get_sampling_probabilities(self, normalized: bool = True):
        """Return the sampling probabilities for each file inside `self.meta`."""
        scores: tp.List[float] = []
        for file_meta in self.meta:
            score = 1.
            if self.sample_on_weight and file_meta.weight is not None:
                score *= file_meta.weight
            if self.sample_on_duration:
                score *= file_meta.duration
            scores.append(score)
        probabilities = torch.tensor(scores)
        if normalized:
            probabilities /= probabilities.sum()
        return probabilities

    @staticmethod
    @lru_cache(16)
    def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
        """Return a deterministic permutation of `range(num_files)` for the given index and seed."""
        # The cache keeps the most recent permutations in memory, so consecutive items
        # of the same permutation do not recompute it.
        rng = torch.Generator()
        rng.manual_seed(base_seed + permutation_index)
        return torch.randperm(num_files, generator=rng)

    def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
        """Sample a given file from `self.meta`. Can be overridden in subclasses.
        This is only called if `segment_duration` is not None.

        You must use the provided random number generator `rng` for reproducibility.
        You can further make use of the index accessed.
        """
        if self.permutation_on_files:
            assert self.current_epoch is not None
            total_index = self.current_epoch * len(self) + index
            permutation_index = total_index // len(self.meta)
            relative_index = total_index % len(self.meta)
            permutation = AudioDataset._get_file_permutation(
                len(self.meta), permutation_index, self.shuffle_seed)
            # Convert the 0-d tensor to a plain int before indexing the Python list.
            file_index = int(permutation[relative_index])
            return self.meta[file_index]

        if not self.sample_on_weight and not self.sample_on_duration:
            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
        else:
            file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())

        return self.meta[file_index]

    def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
        """Read audio from `path`; returns a zero tensor instead when `load_wav` is False.

        Override this method in subclasses if needed.
        """
        if self.load_wav:
            return audio_read(path, seek_time, duration, pad=False)
        else:
            assert self.segment_duration is not None
            n_frames = int(self.sample_rate * self.segment_duration)
            return torch.zeros(self.channels, n_frames), self.sample_rate

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
        """Return the waveform for `index`, plus its SegmentInfo when `return_info` is True.

        Without `segment_duration`, the full file `self.meta[index]` is returned. Otherwise a file
        is drawn with `sample_file` and a random segment is read from it; on read errors another
        file is drawn, up to `max_read_retry` attempts, after which the last error is re-raised.
        """
        if self.segment_duration is None:
            file_meta = self.meta[index]
            out, sr = audio_read(file_meta.path)
            out = convert_audio(out, sr, self.sample_rate, self.channels)
            n_frames = out.shape[-1]
            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
                                       sample_rate=self.sample_rate, channels=out.shape[0])
        else:
            rng = torch.Generator()
            if self.shuffle:
                # Seed from the index plus extra randomness: fully random when the epoch is unknown,
                # otherwise derived from the epoch number and the optional shuffle_seed.
                if self.current_epoch is None:
                    rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
                else:
                    rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
            else:
                # Deterministic: only the index is used.
                rng.manual_seed(index)

            for retry in range(self.max_read_retry):
                file_meta = self.sample_file(index, rng)
                # Add some variance to the start position, even for files shorter than a segment.
                # When the file is long enough, at least `segment_duration * min_segment_ratio`
                # seconds remain after the seek point.
                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
                seek_time = torch.rand(1, generator=rng).item() * max_seek
                try:
                    # Go through `_audio_read` so that `load_wav=False` and subclass overrides apply.
                    out, sr = self._audio_read(file_meta.path, seek_time, self.segment_duration)
                    out = convert_audio(out, sr, self.sample_rate, self.channels)
                    n_frames = out.shape[-1]
                    target_frames = int(self.segment_duration * self.sample_rate)
                    if self.pad:
                        out = F.pad(out, (0, target_frames - n_frames))
                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
                                               sample_rate=self.sample_rate, channels=out.shape[0])
                except Exception as exc:
                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
                    if retry == self.max_read_retry - 1:
                        raise
                else:
                    break

        if self.return_info:
            # Return the wav together with information on the segment it comes from.
            return out, segment_info
        else:
            return out

    def collater(self, samples):
        """The collater function has to be provided to the dataloader
        if AudioDataset has return_info=True in order to properly collate
        the samples of a batch.

        Args:
            samples (list): Items returned by `__getitem__`: tensors, or
                `(tensor, SegmentInfo)` tuples when `return_info` is True.
        Returns:
            The stacked batch tensor, or a `(batch, list of SegmentInfo)` tuple
            when `return_info` is True.
        """
        if self.segment_duration is None and len(samples) > 1:
            assert self.pad, "Must allow padding when batching examples of different durations."

        # Without segment_duration, full files of variable length reach the collater.
        to_pad = self.segment_duration is None and self.pad
        if to_pad:
            # Samples are (wav, info) pairs only when return_info is True. Otherwise they are bare
            # tensors, which must not be unpacked (that would iterate over their channels).
            max_len = max([(s[0] if self.return_info else s).shape[-1] for s in samples])

        def _pad_wav(wav):
            return F.pad(wav, (0, max_len - wav.shape[-1]))

        if self.return_info:
            if len(samples) > 0:
                assert len(samples[0]) == 2
                assert isinstance(samples[0][0], torch.Tensor)
                assert isinstance(samples[0][1], SegmentInfo)

            wavs = [wav for wav, _ in samples]
            # Copy the infos so that updating total_frames does not mutate the dataset items.
            segment_infos = [copy.deepcopy(info) for _, info in samples]

            if to_pad:
                # Each wav may have a different duration as they are not segmented.
                for i in range(len(samples)):
                    # total_frames is the padded length, so update it as we pad.
                    segment_infos[i].total_frames = max_len
                    wavs[i] = _pad_wav(wavs[i])

            wav = torch.stack(wavs)
            return wav, segment_infos
        else:
            assert isinstance(samples[0], torch.Tensor)
            if to_pad:
                samples = [_pad_wav(s) for s in samples]
            return torch.stack(samples)

    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
        """Filters out audio files with audio durations that will not allow to sample examples from them."""
        orig_len = len(meta)

        # Filter data that is too short.
        if self.min_audio_duration is not None:
            meta = [m for m in meta if m.duration >= self.min_audio_duration]

        # Filter data that is too long.
        if self.max_audio_duration is not None:
            meta = [m for m in meta if m.duration <= self.max_audio_duration]

        filtered_len = len(meta)
        removed_percentage = 100*(1-float(filtered_len)/orig_len)
        msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
        # Use the module logger, consistently with the rest of this class.
        if removed_percentage < 10:
            logger.debug(msg)
        else:
            logger.warning(msg)
        return meta

    @classmethod
    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
        """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.

        Args:
            root (str or Path): Directory containing `data.jsonl` or `data.jsonl.gz`,
                or directly the path of a manifest file.
            kwargs: Additional keyword arguments for the AudioDataset.
        Raises:
            ValueError: If `root` is a directory without any of the expected manifest files.
        """
        root = Path(root)
        if root.is_dir():
            if (root / 'data.jsonl').exists():
                root = root / 'data.jsonl'
            elif (root / 'data.jsonl.gz').exists():
                root = root / 'data.jsonl.gz'
            else:
                raise ValueError("Don't know where to read metadata from in the dir. "
                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
        meta = load_audio_meta(root)
        return cls(meta, **kwargs)

    @classmethod
    def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
                  exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
        """Instantiate AudioDataset from a path containing (possibly nested) audio files.

        If `root` is a file, it is loaded as a metadata manifest instead.

        Args:
            root (str or Path): Path to root folder containing audio files.
            minimal_meta (bool): Whether to only load minimal metadata or not.
            exts (list of str): Extensions for audio files.
            kwargs: Additional keyword arguments for the AudioDataset.
        """
        root = Path(root)
        if root.is_file():
            meta = load_audio_meta(root, resolve=True)
        else:
            meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
        return cls(meta, **kwargs)
| |
|
| |
|
def main(argv: tp.Optional[tp.List[str]] = None):
    """Command line entry point: scan a folder for audio files and save their metadata.

    The metadata is written as JSON lines (gzip-compressed if the output path ends with `.gz`).

    Args:
        argv (list of str, optional): Arguments to parse instead of `sys.argv[1:]`.
    """
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    parser = argparse.ArgumentParser(
        prog='audio_dataset',
        description='Generate .jsonl files by scanning a folder.')
    parser.add_argument('root', help='Root folder with all the audio files')
    parser.add_argument('output_meta_file',
                        help='Output file to store the metadata.')
    parser.add_argument('--complete',
                        action='store_false', dest='minimal', default=True,
                        help='Retrieve all metadata, even the ones that are expensive '
                             'to compute (e.g. normalization).')
    parser.add_argument('--resolve',
                        action='store_true', default=False,
                        help='Resolve the paths to be absolute and with no symlinks.')
    parser.add_argument('--workers',
                        default=10, type=int,
                        help='Number of workers.')
    # argv=None makes argparse fall back to sys.argv[1:], preserving the CLI behavior.
    args = parser.parse_args(argv)
    meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
                            resolve=args.resolve, minimal=args.minimal, workers=args.workers)
    save_audio_meta(args.output_meta_file, meta)
|
| |
|
# When executed as a script, scan a folder and write its metadata manifest (see `main`).
if __name__ == '__main__':
    main()
| |
|