| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | Audio IO methods are defined in this module (info, read, write). |
| | We rely on the av library for faster reads when possible, otherwise on torchaudio. |
| | """ |
| |
|
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | import logging |
| | import typing as tp |
| |
|
| | import numpy as np |
| | import soundfile |
| | import torch |
| | from torch.nn import functional as F |
| | import torchaudio as ta |
| |
|
| | import av |
| |
|
| | from .audio_utils import f32_pcm, i16_pcm, normalize_audio, convert_audio |
| |
|
| |
|
# Module-level guard: flipped to True once `_init_av` has configured logging.
_av_initialized = False


def _init_av():
    """Raise the 'libav.mp3' logger threshold to ERROR, once per process.

    The first call configures the logger and sets the module-level
    `_av_initialized` flag; every later call is a no-op.
    """
    global _av_initialized
    if not _av_initialized:
        # Only records at ERROR level or above pass through this logger afterwards.
        logging.getLogger('libav.mp3').setLevel(logging.ERROR)
        _av_initialized = True
| |
|
| |
|
@dataclass(frozen=True)
class AudioFileInfo:
    """Immutable summary of an audio file's basic properties."""
    # Sampling rate reported by the backend that probed the file.
    sample_rate: int
    # Length of the audio; presumably in seconds (TODO confirm against callers).
    duration: float
    # Number of audio channels.
    channels: int
| |
|
| |
|
def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Probe an audio file through PyAV and return its basic properties.

    Only the first audio stream of the container is inspected.

    Args:
        filepath (str or Path): Path to the audio file to probe.
    Returns:
        AudioFileInfo: Sample rate, duration and channel count of the first audio stream.
    Raises:
        RuntimeError: If neither the stream nor the container reports a duration.
    """
    _init_av()
    with av.open(str(filepath)) as af:
        stream = af.streams.audio[0]
        codec = stream.codec_context
        if stream.duration is not None:
            # Stream durations are expressed in units of the stream's own time base.
            duration = float(stream.duration * stream.time_base)
        elif af.duration is not None:
            # Some containers only expose a global duration, counted in `av.time_base` units.
            duration = float(af.duration) / av.time_base
        else:
            raise RuntimeError(f"Could not determine the duration of {filepath}")
        # Read the channel count from the codec context, like the sample rate.
        return AudioFileInfo(codec.sample_rate, duration, codec.channels)
| |
|
| |
|
def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Query `soundfile` for a file's metadata and repackage it as an `AudioFileInfo`.

    Args:
        filepath (str or Path): Path to the audio file to probe.
    Returns:
        AudioFileInfo: Sample rate, duration and channel count reported by `soundfile.info`.
    """
    meta = soundfile.info(filepath)
    return AudioFileInfo(
        sample_rate=meta.samplerate,
        duration=meta.duration,
        channels=meta.channels,
    )
| |
|
| |
|
def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Return basic properties of an audio file, picking the backend from its extension.

    Args:
        filepath (str or Path): Path to the audio file to probe.
    Returns:
        AudioFileInfo: Sample rate, duration and channel count of the file.
    """
    filepath = Path(filepath)
    # '.flac' and '.ogg' files are probed with soundfile; everything else goes through PyAV.
    probe = _soundfile_info if filepath.suffix in ['.flac', '.ogg'] else _av_info
    return probe(filepath)
| |
|
| |
|
def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0,
             duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
    """Read audio from a file with the PyAV (FFMPEG) bindings.

    Used for formats soundfile cannot read (e.g. mp3), and as a more efficient
    alternative to torchaudio.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
    Returns:
        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate
    """
    _init_av()
    with av.open(str(filepath)) as container:
        audio_stream = container.streams.audio[0]
        sample_rate = audio_stream.codec_context.sample_rate
        # Translate the requested window into sample counts; -1 stands for "no limit".
        if duration >= 0:
            wanted = int(sample_rate * duration)
        else:
            wanted = -1
        first_sample = int(sample_rate * seek_time)
        # Seek 0.1s ahead of the requested start (never below 0); the surplus samples
        # are stripped per frame below. Presumably this gives the decoder some lead-in.
        seek_target = int(max(0, (seek_time - 0.1)) / audio_stream.time_base)
        container.seek(seek_target, stream=audio_stream)
        chunks = []
        collected = 0
        for frame in container.decode(streams=audio_stream.index):
            # Position of this frame's first sample, derived from its timestamp.
            frame_start = int(frame.rate * frame.pts * frame.time_base)
            skip = max(0, first_sample - frame_start)
            chunk = torch.from_numpy(frame.to_ndarray())
            if chunk.shape[0] != audio_stream.channels:
                # First dimension is not the channel count: reinterpret the data as
                # (samples, channels) and transpose it to channel-first.
                chunk = chunk.view(-1, audio_stream.channels).t()
            chunk = chunk[:, skip:]
            chunks.append(chunk)
            collected += chunk.shape[1]
            # A non-positive `wanted` (including duration == 0) never stops early.
            if wanted > 0 and collected >= wanted:
                break
        # Fails when the decoder produced no frame at all.
        assert chunks
        wav = torch.cat(chunks, dim=1)
        assert wav.shape[0] == audio_stream.channels
        if wanted > 0:
            # The last decoded frame may overshoot the requested length.
            wav = wav[:, :wanted]
        return f32_pcm(wav), sample_rate
| |
|
| |
|
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
    """Load audio from disk, choosing the backend from the file extension and request.

    '.flac' and '.ogg' files are read with soundfile. '.wav' and '.mp3' files that
    torchaudio lists as readable are loaded with torchaudio when the whole file is
    requested. Everything else is decoded through PyAV.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
        pad (bool): Pad output audio if not reaching expected duration.
    Returns:
        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
    """
    path = Path(filepath)
    extension = path.suffix
    if extension in ['.flac', '.ogg']:
        info = _soundfile_info(filepath)
        # soundfile works in frames: -1 means "read until the end of the file".
        count = int(duration * info.sample_rate) if duration > 0 else -1
        start = int(seek_time * info.sample_rate)
        data, sr = soundfile.read(filepath, start=start, frames=count, dtype=np.float32)
        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
        # Transpose soundfile's array so channels come first.
        wav = torch.from_numpy(data).t().contiguous()
        if wav.dim() == 1:
            # Mono data comes back 1D: add the missing channel dimension.
            wav = wav.unsqueeze(0)
    elif (
        extension in ['.wav', '.mp3'] and extension[1:] in ta.utils.sox_utils.list_read_formats()
        and duration <= 0 and seek_time == 0
    ):
        # Whole-file read of a format torchaudio reports as supported.
        wav, sr = ta.load(path)
    else:
        wav, sr = _av_read(filepath, seek_time, duration)
    if pad and duration > 0:
        # Right-pad with zeros up to the number of samples implied by `duration`.
        missing = int(duration * sr) - wav.shape[-1]
        wav = F.pad(wav, (0, missing))
    return wav, sr
| |
|
| |
|
def audio_write(stem_name: tp.Union[str, Path],
                wav: torch.Tensor, sample_rate: int,
                format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                loudness_compressor: bool = False,
                log_clipping: bool = True, make_parent_dir: bool = True,
                add_suffix: bool = True, channels: int = 1) -> Path:
    """Convenience function for saving audio to disk. Returns the filename the audio was written to.

    Args:
        stem_name (str or Path): Filename without extension which will be added automatically.
        wav (torch.Tensor): Floating point audio with 1 or 2 dimensions. A 1D tensor
            is given a leading dimension before processing.
        sample_rate (int): Sample rate of `wav`.
        format (str): Either "wav" or "mp3".
        mp3_rate (int): kbps when using mp3s.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
            NOTE(review): this flag is currently not forwarded to `normalize_audio`;
            confirm against that helper's signature before wiring it through.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms').
        make_parent_dir (bool): Make parent directory if it doesn't exist.
        add_suffix (bool): If True (default), append the format extension to `stem_name`.
            If False, the file is written to `stem_name` exactly as given.
        channels (int): If greater than 1, the audio is passed through `convert_audio`
            with this channel count before saving.
    Returns:
        Path: Path of the saved audio.
    Raises:
        ValueError: If `wav` has more than 2 dimensions.
        RuntimeError: If `format` is neither "wav" nor "mp3".
    """
    assert wav.dtype.is_floating_point, "wav is not floating point"
    if wav.dim() == 1:
        wav = wav[None]
    elif wav.dim() > 2:
        raise ValueError("Input wav should be at most 2 dimension.")
    assert wav.isfinite().all()
    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
                          rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
                          sample_rate=sample_rate, stem_name=str(stem_name))
    if channels > 1:
        # Same source and target sample rate: only the channel layout is converted.
        wav = convert_audio(wav, sample_rate, sample_rate, channels)
    kwargs: dict = {}
    if format == 'mp3':
        suffix = '.mp3'
        kwargs.update({"compression": mp3_rate})
    elif format == 'wav':
        # wav files are stored as signed 16 bit PCM.
        wav = i16_pcm(wav)
        suffix = '.wav'
        kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
    else:
        raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
    if not add_suffix:
        suffix = ''
        # Without an extension torchaudio cannot infer the audio format from the
        # path, so it has to be stated explicitly.
        kwargs["format"] = format
    path = Path(str(stem_name) + suffix)
    if make_parent_dir:
        path.parent.mkdir(exist_ok=True, parents=True)
    try:
        ta.save(path, wav, sample_rate, **kwargs)
    except Exception:
        if path.exists():
            # Do not leave a partially written file behind.
            path.unlink()
        raise
    return path
| |
|