| |
|
|
| import os |
| import subprocess |
| from functools import lru_cache |
| from typing import Optional, Union |
| from scipy.io.wavfile import write |
| import tempfile |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
def exact_div(x, y):
    """Return ``x // y``, asserting that ``y`` divides ``x`` with no remainder."""
    quotient, remainder = divmod(x, y)
    assert remainder == 0
    return quotient
|
|
| |
# Hard-coded audio hyperparameters.
SAMPLE_RATE = 16000  # Hz; default target rate of load_audio
N_FFT = 400  # STFT window length, in samples
HOP_LENGTH = 160  # STFT hop between successive frames, in samples
CHUNK_LENGTH = 30  # seconds of audio per chunk
N_SAMPLES = SAMPLE_RATE * CHUNK_LENGTH  # samples in one chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # STFT frames in one chunk

# Derived rates; exact_div asserts each division is exact.
N_SAMPLES_PER_TOKEN = 2 * HOP_LENGTH  # two hops per token
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)
|
|
|
|
def _array_to_temp_wav(audio: np.ndarray) -> str:
    """Downmix ``audio`` to mono, write it to a temporary 16-bit WAV file and return its path.

    The WAV header is tagged with ``SAMPLE_RATE``, so the array is assumed to
    already be sampled at that rate. The caller owns the file and must delete it.
    """
    if audio.dtype != np.float32:
        audio = audio.astype(np.float32)
    if audio.ndim > 1:
        # Average over axis 1, i.e. the array is treated as (samples, channels).
        audio = np.mean(audio, axis=1)

    # Clip before casting: a sample of +1.0 scales to 32768, which does not fit
    # in int16 and would otherwise overflow to a platform-dependent value.
    pcm = np.clip(audio * 32768, -32768, 32767).astype(np.int16)

    # Close the handle before writing by name: on Windows a NamedTemporaryFile
    # cannot be reopened by name while it is still open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as handle:
        path = handle.name
    try:
        write(path, SAMPLE_RATE, pcm)
    except BaseException:
        # Do not leak the temp file if serialisation fails.
        os.remove(path)
        raise
    return path


def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray:
    """
    Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary.

    Parameters
    ----------
    file: Union[str, np.ndarray]
        The audio file to open or a numpy array containing the audio data.
        An array is expected to hold float samples in [-1, 1] at SAMPLE_RATE;
        multi-channel arrays are averaged over axis 1. Values outside [-1, 1]
        are clipped.

    sr: int
        The sample rate to resample the audio if necessary.

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status.
    """
    temp_file_path = None
    if isinstance(file, np.ndarray):
        # Round-trip arrays through ffmpeg so resampling/downmixing follow the
        # same path as files on disk.
        temp_file_path = _array_to_temp_wav(file)
        source = temp_file_path
    else:
        source = file

    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads",
        "0",
        "-i",
        source,
        "-f",
        "s16le",
        "-ac",
        "1",
        "-acodec",
        "pcm_s16le",
        "-ar",
        str(sr),
        "-",
    ]
    try:
        out = subprocess.run(cmd, capture_output=True, check=True).stdout
    except subprocess.CalledProcessError as e:
        # errors="replace": a non-UTF-8 stderr must not mask the real failure.
        raise RuntimeError(
            f"Failed to load audio: {e.stderr.decode(errors='replace')}"
        ) from e
    finally:
        if temp_file_path is not None:
            os.remove(temp_file_path)

    # ffmpeg emits raw little-endian int16 mono PCM; rescale to [-1, 1).
    return np.frombuffer(out, np.int16).astype(np.float32) / 32768.0
|
|
|
|
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Zero-pad on the right, or truncate, ``array`` along ``axis`` to exactly ``length`` items.

    Works on both torch tensors and NumPy arrays; the default length is
    N_SAMPLES, the input size expected by the encoder.
    """
    is_tensor = torch.is_tensor(array)
    current = array.shape[axis]

    if current > length:
        # Keep only the first `length` entries along the axis.
        if is_tensor:
            keep = torch.arange(length, device=array.device)
            array = array.index_select(dim=axis, index=keep)
        else:
            array = array.take(indices=range(length), axis=axis)
    elif current < length:
        # (before, after) zero-padding per dimension; only `axis` grows.
        widths = [(0, 0)] * array.ndim
        widths[axis] = (0, length - current)
        if is_tensor:
            # F.pad takes a flat list ordered from the last dimension backwards.
            flat = [amount for pair in reversed(widths) for amount in pair]
            array = F.pad(array, flat)
        else:
            array = np.pad(array, widths)

    return array
|
|
|
|
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    Load the mel filterbank matrix that projects an STFT onto a Mel spectrogram.

    The matrices are read from ``assets/mel_filters.npz`` next to this module,
    which is expected to hold ``mel_80`` and ``mel_128`` arrays. This avoids a
    runtime librosa dependency; the archive was produced with e.g.

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )

    Results are memoised per (device, n_mels) pair.
    """
    assert n_mels in (80, 128), f"Unsupported n_mels: {n_mels}"
    asset_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
    with np.load(asset_path) as archive:
        matrix = archive[f"mel_{n_mels}"]
    return torch.from_numpy(matrix).to(device)
|
|
|
|
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to an audio file, or a NumPy array or Tensor containing the
        audio waveform at 16 kHz. A path is decoded with ``load_audio``.

    n_mels: int
        The number of Mel-frequency filters; 80 and 128 are supported.

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames) for 1-D input
        A Tensor that contains the normalised log-Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT, device=audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    # Drop the final frame: with torch.stft's default centring, N_SAMPLES input
    # then yields exactly N_SAMPLES // HOP_LENGTH frames.
    magnitudes = stft[..., :-1].abs() ** 2

    # Match the filterbank to the STFT dtype so e.g. float64 input does not fail
    # in the matmul; this is a no-op when the dtypes already agree.
    filters = mel_filters(audio.device, n_mels).to(magnitudes.dtype)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # Limit the dynamic range to 8 decades (80 dB) below the peak, then rescale.
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec