| |
| |
| |
| |
| |
|
|
| import argparse |
| import copy |
| from concurrent.futures import ThreadPoolExecutor, Future |
| from dataclasses import dataclass, fields |
| from contextlib import ExitStack |
| import gzip |
| import json |
| import logging |
| import os |
| from pathlib import Path |
| import random |
| import sys |
| import typing as tp |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from .audio import audio_read, audio_info |
| from .audio_utils import convert_audio |
| from .zip import PathInZip |
|
|
| try: |
| import dora |
| except ImportError: |
| dora = None |
|
|
|
|
@dataclass(order=True)
class BaseInfo:
    """Base dataclass providing conversions to and from plain dictionaries."""

    @classmethod
    def _dict2fields(cls, dictionary: dict):
        """Return the subset of `dictionary` whose keys are dataclass fields of `cls`.

        Keys are emitted in field-declaration order; unknown keys are dropped.
        """
        field_names = [f.name for f in fields(cls)]
        return {name: dictionary[name] for name in field_names if name in dictionary}

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an instance from `dictionary`, ignoring keys that are not fields."""
        return cls(**cls._dict2fields(dictionary))

    def to_dict(self):
        """Return a dict mapping every field name to its current value."""
        return {f.name: getattr(self, f.name) for f in fields(self)}
|
|
|
|
@dataclass(order=True)
class AudioMeta(BaseInfo):
    """Metadata describing one audio file.

    `to_dict` / `from_dict` round-trip through JSON-friendly values: `info_path` is stored
    as a string and converted back to a `PathInZip` when loading.
    """
    path: str
    duration: float
    sample_rate: int
    amplitude: tp.Optional[float] = None
    weight: tp.Optional[float] = None
    # Optional `PathInZip` location of additional information about this file.
    info_path: tp.Optional[PathInZip] = None

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an AudioMeta from a dict, turning a serialized `info_path` into a `PathInZip`."""
        kwargs = cls._dict2fields(dictionary)
        raw_info_path = kwargs.get('info_path')
        if raw_info_path is not None:
            kwargs['info_path'] = PathInZip(raw_info_path)
        return cls(**kwargs)

    def to_dict(self):
        """Return a dict of all fields, with `info_path` (if set) converted to `str`."""
        serialized = super().to_dict()
        info_path = serialized['info_path']
        if info_path is not None:
            serialized['info_path'] = str(info_path)
        return serialized
|
|
|
|
@dataclass(order=True)
class SegmentInfo(BaseInfo):
    """Information about one audio item produced by `AudioDataset` (when `return_info=True`)."""
    meta: AudioMeta  # metadata of the source audio file
    seek_time: float  # position in the source file where the item starts (0. for full files); presumably seconds
    n_frames: int  # number of frames actually read from the file (before any padding)
    total_frames: int  # number of frames of the returned item, padding included
    sample_rate: int  # sample rate of the returned audio (the dataset's target rate)
|
|
|
|
# Audio file extensions (lower-case, leading dot) picked up when scanning folders.
DEFAULT_EXTS = '.wav .mp3 .flac .ogg .m4a'.split()

# Module-level logger.
logger = logging.getLogger(name=__name__)
|
|
|
|
def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
    """Build the AudioMeta of a single audio file.

    Args:
        file_path (str): Resolved path of a valid audio file.
        minimal (bool): When True, only header information (duration, sample rate) is read.
            When False, the whole file is also decoded to compute its peak absolute
            amplitude, which is slower.
    Returns:
        AudioMeta: Path, duration and sample rate of the file; `amplitude` is the peak
            absolute sample value when `minimal` is False, None otherwise.
    """
    info = audio_info(file_path)
    peak: tp.Optional[float]
    if minimal:
        peak = None
    else:
        waveform = audio_read(file_path)[0]
        peak = waveform.abs().max().item()
    return AudioMeta(file_path, info.duration, info.sample_rate, peak)
|
|
|
|
def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
    """If Dora is available as a dependency, try to resolve potential relative paths
    in an AudioMeta. This method is expected to be used when loading meta from file.

    Both `m.path` and, when present, `m.info_path.zip_path` are made absolute in place.

    Args:
        m (AudioMeta): Audio meta to resolve. It is modified in place and returned.
        fast (bool): If True, uses a really fast check for determining if a file is already absolute or not.
            Only valid on Linux/Mac.
    Returns:
        AudioMeta: Audio meta with resolved path(s). Returned unchanged when Dora is not installed.
    """
    def is_abs(path) -> bool:
        if fast:
            # startswith (rather than indexing [0]) also copes with empty paths.
            return str(path).startswith('/')
        return os.path.isabs(str(path))

    if not dora:
        return m

    if not is_abs(m.path):
        m.path = dora.git_save.to_absolute_path(m.path)
    if m.info_path is not None and not is_abs(m.info_path.zip_path):
        # Resolve the zip path itself; it must not be replaced by the audio file path.
        m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
    return m
|
|
|
|
def find_audio_files(path: tp.Union[Path, str],
                     exts: tp.List[str] = DEFAULT_EXTS,
                     resolve: bool = True,
                     minimal: bool = True,
                     progress: bool = False,
                     workers: int = 0) -> tp.List[AudioMeta]:
    """Build a list of AudioMeta from a given path,
    collecting relevant audio files and fetching meta info.

    Args:
        path (str or Path): Path to folder containing audio files (symlinks are followed).
        exts (list of str): List of file extensions to consider for audio files.
            Matching is case-insensitive, e.g. '.WAV' and '.wav' are equivalent.
        resolve (bool): Whether to resolve relative paths of each AudioMeta
            (see `_resolve_audio_meta`).
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
        progress (bool): Whether to log progress on audio files collection.
        workers (int): number of parallel workers, if 0, use only the current thread.
    Returns:
        List[AudioMeta]: Sorted list of audio file path and its metadata. Files whose
            metadata cannot be read are reported on stderr and skipped.
    """
    # File suffixes are lower-cased below, so the wanted extensions must be too.
    wanted_exts = {ext.lower() for ext in exts}
    audio_files: tp.List[Path] = []
    futures: tp.List[Future] = []
    pool: tp.Optional[ThreadPoolExecutor] = None
    with ExitStack() as stack:
        if workers > 0:
            pool = ThreadPoolExecutor(workers)
            stack.enter_context(pool)

        if progress:
            print("Finding audio files...")
        for root, _dirs, files in os.walk(path, followlinks=True):
            for file in files:
                full_path = Path(root) / file
                if full_path.suffix.lower() in wanted_exts:
                    audio_files.append(full_path)
                    if pool is not None:
                        # Start reading metadata right away; futures[i] matches audio_files[i].
                        futures.append(pool.submit(_get_audio_meta, str(full_path), minimal))
                    if progress:
                        print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)

        if progress:
            print("Getting audio metadata...")
        meta: tp.List[AudioMeta] = []
        for idx, file_path in enumerate(audio_files):
            try:
                if pool is None:
                    m = _get_audio_meta(str(file_path), minimal)
                else:
                    m = futures[idx].result()
                if resolve:
                    m = _resolve_audio_meta(m)
            except Exception as err:
                # Best-effort scan: report the unreadable file and keep going.
                print("Error with", str(file_path), err, file=sys.stderr)
                continue
            meta.append(m)
            if progress:
                print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
    meta.sort()
    return meta
|
|
|
|
def load_audio_meta(path: tp.Union[str, Path],
                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
    """Load list of AudioMeta from an optionally compressed JSON-lines file.

    Each non-blank line must hold one JSON object as written by `save_audio_meta`.
    The file is read with gzip when its name ends with '.gz' (case-insensitive).

    Args:
        path (str or Path): Path to JSON file.
        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
        fast (bool): activates some tricks to make things faster.
    Returns:
        List[AudioMeta]: List of audio file paths and their metadata, in file order.
    """
    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
    meta = []
    with open_fn(path, 'rb') as fp:
        # Stream line by line instead of materializing the whole file with readlines().
        for line in fp:
            # Tolerate blank lines (e.g. a trailing empty line) rather than failing in json.loads.
            if not line.strip():
                continue
            m = AudioMeta.from_dict(json.loads(line))
            if resolve:
                m = _resolve_audio_meta(m, fast=fast)
            meta.append(m)
    return meta
|
|
|
|
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
    """Write audio metadata to `path` as JSON lines, one object per entry.

    Missing parent directories are created. The output is gzip-compressed when `path`
    ends with '.gz' (case-insensitive).

    Args:
        path (str or Path): Path to the output JSON-lines file.
        meta (list of AudioMeta): Audio meta to save; each item is serialized with `to_dict()`.
    """
    Path(path).parent.mkdir(exist_ok=True, parents=True)
    is_gzip = str(path).lower().endswith('.gz')
    opener = gzip.open if is_gzip else open
    with opener(path, 'wb') as out:
        for entry in meta:
            record = json.dumps(entry.to_dict()) + '\n'
            out.write(record.encode('utf-8'))
|
|
|
|
class AudioDataset:
    """Base audio dataset.

    The dataset takes a list of AudioMeta and creates a dataset composed of segments of audio
    and potentially additional information, by creating random segments from the list of audio
    files referenced in the metadata and applying minimal data pre-processing such as resampling,
    mixing of channels, padding, etc.

    If no segment_duration value is provided, the AudioDataset will return the full wav for each
    audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
    duration, applying padding if required.

    By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
    makes it return a tuple containing the torch Tensor and additional metadata on the segment and the
    original audio meta.

    Args:
        meta (tp.List[AudioMeta]): List of audio files metadata.
        segment_duration (float): Optional segment duration of audio to load.
            If not specified, the dataset will load the full audio segment from the file.
        shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
        num_samples (int): Number of items in the dataset when `segment_duration` is set.
            Ignored when `segment_duration` is None: the dataset then has one item per file.
        sample_rate (int): Target sample rate of the loaded audio samples.
        channels (int): Target number of channels of the loaded audio samples.
        pad (bool): Whether to right-pad with zeros: segments up to `segment_duration`, and
            full files up to the longest item of a batch in `collater`.
        sample_on_duration (bool): Set to `True` to sample segments with probability
            dependent on audio file duration. This is only used if `segment_duration` is provided.
        sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
            `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
            of the file duration and file weight. This is only used if `segment_duration` is provided.
        min_segment_ratio (float): Minimum segment ratio to use when the audio file
            is shorter than the desired segment.
        max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
        return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
        min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided
            audio shorter than this will be filtered out.
        max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided
            audio longer than this will be filtered out.
    """
    def __init__(self,
                 meta: tp.List[AudioMeta],
                 segment_duration: tp.Optional[float] = None,
                 shuffle: bool = True,
                 num_samples: int = 10_000,
                 sample_rate: int = 48_000,
                 channels: int = 2,
                 pad: bool = True,
                 sample_on_duration: bool = True,
                 sample_on_weight: bool = True,
                 min_segment_ratio: float = 0.5,
                 max_read_retry: int = 10,
                 return_info: bool = False,
                 min_audio_duration: tp.Optional[float] = None,
                 max_audio_duration: tp.Optional[float] = None
                 ):
        assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
        assert segment_duration is None or segment_duration > 0
        assert segment_duration is None or min_segment_ratio >= 0
        logger.debug(f'sample_on_duration: {sample_on_duration}')
        logger.debug(f'sample_on_weight: {sample_on_weight}')
        logger.debug(f'pad: {pad}')
        logger.debug(f'min_segment_ratio: {min_segment_ratio}')

        self.segment_duration = segment_duration
        self.min_segment_ratio = min_segment_ratio
        self.max_audio_duration = max_audio_duration
        self.min_audio_duration = min_audio_duration
        if self.min_audio_duration is not None and self.max_audio_duration is not None:
            assert self.min_audio_duration <= self.max_audio_duration
        self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
        assert len(self.meta), 'No audio file left after filtering on min/max audio duration.'
        self.total_duration = sum(d.duration for d in self.meta)

        # Without segments, every file is exactly one item.
        if segment_duration is None:
            num_samples = len(self.meta)
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.sample_rate = sample_rate
        self.channels = channels
        self.pad = pad
        self.sample_on_weight = sample_on_weight
        self.sample_on_duration = sample_on_duration
        self.sampling_probabilities = self._get_sampling_probabilities()
        self.max_read_retry = max_read_retry
        self.return_info = return_info

    def __len__(self):
        return self.num_samples

    def _get_sampling_probabilities(self, normalized: bool = True):
        """Return the sampling probabilities for each file inside `self.meta`.

        Each score starts at 1 and is multiplied by the file weight (if `sample_on_weight`
        and the weight is set) and by the file duration (if `sample_on_duration`).
        """
        scores: tp.List[float] = []
        for file_meta in self.meta:
            score = 1.
            if self.sample_on_weight and file_meta.weight is not None:
                score *= file_meta.weight
            if self.sample_on_duration:
                score *= file_meta.duration
            scores.append(score)
        probabilities = torch.tensor(scores)
        if normalized:
            probabilities /= probabilities.sum()
        return probabilities

    def sample_file(self, rng: torch.Generator) -> AudioMeta:
        """Sample a given file from `self.meta`. Can be overridden in subclasses.
        This is only called if `segment_duration` is not None.

        You must use the provided random number generator `rng` for reproducibility.
        """
        if not self.sample_on_weight and not self.sample_on_duration:
            # All scores are 1: sample uniformly.
            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
        else:
            file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())

        return self.meta[file_index]

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
        if self.segment_duration is None:
            file_meta = self.meta[index]
            out, sr = audio_read(file_meta.path)
            out = convert_audio(out, sr, self.sample_rate, self.channels)
            n_frames = out.shape[-1]
            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
                                       sample_rate=self.sample_rate)
        else:
            rng = torch.Generator()
            if self.shuffle:
                # Mix the index with the global `random` state so that the same index
                # can yield a different segment on each call.
                rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
            else:
                # Seeded from the index only: sampling is reproducible for a given index.
                rng.manual_seed(index)

            for retry in range(self.max_read_retry):
                file_meta = self.sample_file(rng)
                # Bound the seek so that at least `min_segment_ratio * segment_duration` of the
                # file remains after `seek_time` (when the file is long enough).
                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
                seek_time = torch.rand(1, generator=rng).item() * max_seek
                try:
                    out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
                    out = convert_audio(out, sr, self.sample_rate, self.channels)
                    n_frames = out.shape[-1]
                    target_frames = int(self.segment_duration * self.sample_rate)
                    if self.pad:
                        out = F.pad(out, (0, target_frames - n_frames))
                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
                                               sample_rate=self.sample_rate)
                except Exception as exc:
                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
                    if retry == self.max_read_retry - 1:
                        raise
                else:
                    break

        if self.return_info:
            return out, segment_info
        else:
            return out

    def collater(self, samples):
        """The collater function has to be provided to the dataloader
        if AudioDataset has return_info=True in order to properly collate
        the samples of a batch.

        Returns:
            The stacked waveforms when `return_info` is False, otherwise a tuple
            (stacked waveforms, list of SegmentInfo copies).
        """
        if self.segment_duration is None and len(samples) > 1:
            assert self.pad, "Must allow padding when batching examples of different durations."

        # With segment_duration=None the items are whole files of variable length,
        # so they are right-padded to the longest item before stacking.
        to_pad = self.segment_duration is None and self.pad

        segment_infos: tp.List[SegmentInfo] = []
        if self.return_info:
            if len(samples) > 0:
                assert len(samples[0]) == 2
                assert isinstance(samples[0][0], torch.Tensor)
                assert isinstance(samples[0][1], SegmentInfo)
            wavs = [wav for wav, _ in samples]
            # Copy so that updating total_frames below does not mutate the incoming SegmentInfo objects.
            segment_infos = [copy.deepcopy(info) for _, info in samples]
        else:
            assert isinstance(samples[0], torch.Tensor)
            # Items are bare tensors here; they must not be tuple-unpacked (that would split
            # the channel dimension and fail for anything but 2-channel audio).
            wavs = list(samples)

        if to_pad:
            max_len = max(wav.shape[-1] for wav in wavs)
            wavs = [F.pad(wav, (0, max_len - wav.shape[-1])) for wav in wavs]
            for info in segment_infos:
                info.total_frames = max_len

        wav = torch.stack(wavs)
        if self.return_info:
            return wav, segment_infos
        return wav

    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
        """Filter out audio files that are too short or too long.

        Removes files shorter than `min_audio_duration` or longer than `max_audio_duration`
        (each bound only applies when set) and logs the percentage of removed files.
        """
        orig_len = len(meta)

        if self.min_audio_duration is not None:
            meta = [m for m in meta if m.duration >= self.min_audio_duration]

        if self.max_audio_duration is not None:
            meta = [m for m in meta if m.duration <= self.max_audio_duration]

        filtered_len = len(meta)
        # Guard against division by zero on an empty input list.
        removed_percentage = 100 * (1 - float(filtered_len) / orig_len) if orig_len else 0.
        msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
        if removed_percentage < 10:
            logger.debug(msg)
        else:
            logger.warning(msg)
        return meta

    @classmethod
    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
        """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.

        Args:
            root (str or Path): Path to a manifest file, or to a folder containing
                `data.jsonl` or `data.jsonl.gz`.
            kwargs: Additional keyword arguments for the AudioDataset.
        Raises:
            ValueError: If `root` is a folder containing neither manifest file.
        """
        root = Path(root)
        if root.is_dir():
            if (root / 'data.jsonl').exists():
                root = root / 'data.jsonl'
            elif (root / 'data.jsonl.gz').exists():
                root = root / 'data.jsonl.gz'
            else:
                raise ValueError("Don't know where to read metadata from in the dir. "
                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
        meta = load_audio_meta(root)
        return cls(meta, **kwargs)

    @classmethod
    def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
                  exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
        """Instantiate AudioDataset from a path containing (possibly nested) audio files.

        Args:
            root (str or Path): Path to root folder containing audio files, or to a metadata file.
            minimal_meta (bool): Whether to only load minimal metadata or not.
            exts (list of str): Extensions for audio files.
            kwargs: Additional keyword arguments for the AudioDataset.
        """
        root = Path(root)
        if root.is_file():
            meta = load_audio_meta(root, resolve=True)
        else:
            meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
        return cls(meta, **kwargs)
|
|
|
|
def main():
    """Command-line entry point: scan a folder for audio files and save their metadata."""
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    parser = argparse.ArgumentParser(
        prog='audio_dataset',
        description='Generate .jsonl files by scanning a folder.')
    parser.add_argument('root', help='Root folder with all the audio files')
    parser.add_argument('output_meta_file',
                        help='Output file to store the metadata '
                             '(gzip-compressed if the name ends with .gz).')
    parser.add_argument('--complete',
                        action='store_false', dest='minimal', default=True,
                        help='Retrieve all metadata, even those that are expensive '
                             'to compute (e.g. normalization).')
    parser.add_argument('--resolve',
                        action='store_true', default=False,
                        help='Resolve the paths to be absolute and with no symlinks.')
    parser.add_argument('--workers',
                        default=10, type=int,
                        help='Number of workers.')
    args = parser.parse_args()
    meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
                            resolve=args.resolve, minimal=args.minimal, workers=args.workers)
    save_audio_meta(args.output_meta_file, meta)


if __name__ == '__main__':
    main()
|
|