Spaces:
Sleeping
Sleeping
| """ | |
| Audio and spectrogram augmentation for training data diversity. | |
| Three augmentation families are provided: | |
| 1. Speed perturbation — resample audio to simulate faster/slower speech. | |
| Changes both tempo and pitch (intentional for ASR aug). | |
| 2. Noise addition — add Gaussian noise at a controlled SNR (dB). | |
| 3. SpecAugment — mask random time-steps and frequency bins in the | |
| mel-spectrogram; applied inside the DataCollator so | |
| it is random on every training step, not cached. | |
| All functions operate on numpy float32 arrays (audio) or torch.Tensor | |
| (spectrogram). They are designed to be called from: | |
| - make_prepare_fn() in trainer.py → speed + noise on raw audio | |
| - DataCollatorSpeechSeq2SeqWithPadding.__call__() → SpecAugment on features | |
| """ | |
| from __future__ import annotations | |
| import random | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| import torchaudio.functional as F_audio | |
| import torchaudio.transforms as T | |
| # --------------------------------------------------------------------------- | |
| # Speed perturbation | |
| # --------------------------------------------------------------------------- | |
def apply_speed_perturbation(audio: np.ndarray, sr: int, factor: float) -> np.ndarray:
    """
    Change the playback speed of `audio` by `factor`.

        factor > 1.0 -> faster speech (audio gets shorter)
        factor < 1.0 -> slower speech (audio gets longer)

    Implemented via resampling: the signal is treated as if it had been
    recorded at ``sr * factor`` Hz and is then resampled to ``sr`` Hz. This
    shifts pitch proportionally to speed (tape-speed effect), which is the
    usual approach for ASR data augmentation.

    Args:
        audio:  numpy array of samples, shape [N] (float32 expected)
        sr:     sample rate of `audio` in Hz (e.g. 16000)
        factor: speed multiplier (e.g. 0.9, 1.1); must be finite and > 0

    Returns:
        float32 numpy array at the new speed. When ``factor == 1.0`` the
        input array itself is returned, untouched and uncopied.

    Raises:
        ValueError: if `factor` is not a finite positive number, or if the
            resulting virtual sample rate rounds to less than 1 Hz.
    """
    # Reject bad factors up front; otherwise they only fail later inside the
    # resampler (or in the int conversion) with an unrelated message.
    if not (np.isfinite(factor) and factor > 0):
        raise ValueError(
            f"speed factor must be a finite positive number, got {factor!r}"
        )
    if factor == 1.0:
        return audio

    # "Virtual" source rate. Round instead of truncating so floating-point
    # error in sr * factor (e.g. 17599.999...) cannot shift the rate by 1 Hz.
    orig_sr = int(round(sr * factor))
    if orig_sr < 1:
        raise ValueError(
            f"speed factor {factor!r} is too small for sample rate {sr!r}"
        )

    waveform = torch.from_numpy(audio).unsqueeze(0)      # [1, N]
    resampled = F_audio.resample(waveform, orig_sr, sr)  # back to target sr
    return resampled.squeeze(0).numpy().astype(np.float32)
def maybe_apply_speed(
    audio: np.ndarray,
    sr: int,
    config: dict,
) -> np.ndarray:
    """
    Apply speed perturbation to `audio` at random, as directed by `config`.

    Recognised config keys (each one optional):
        enabled     : bool  - master switch (default True)
        probability : float - per-sample chance of applying (default 0.3)
        factors     : list  - candidate speed multipliers
                              (default [0.9, 0.95, 1.05, 1.1])

    Returns `audio` itself when the augmentation is disabled or not drawn;
    otherwise the result of `apply_speed_perturbation` for a randomly
    chosen factor.
    """
    if config.get("enabled", True):
        # One uniform draw decides whether this sample is perturbed at all.
        draw = random.random()
        if draw >= config.get("probability", 0.3):
            return audio
        speed_options = config.get("factors", [0.9, 0.95, 1.05, 1.1])
        chosen = random.choice(speed_options)
        return apply_speed_perturbation(audio, sr, chosen)
    return audio
| # --------------------------------------------------------------------------- | |
| # Noise addition | |
| # --------------------------------------------------------------------------- | |
def apply_noise(audio: np.ndarray, snr_db: float) -> np.ndarray:
    """
    Add Gaussian white noise to `audio` at the given SNR (dB).

    Lower SNR means more noise (harder). Typical training range: 15-30 dB.
    The noisy signal is clipped to [-1, 1] to stay within valid PCM range.

    Args:
        audio:  numpy array of samples. Any shape is accepted (e.g. [N] or
                [channels, N]); signal power is measured over all elements.
        snr_db: target signal-to-noise ratio in decibels.

    Returns:
        Array of the same shape as `audio` with noise added and clipped to
        [-1, 1]. Empty or near-silent input (mean power < 1e-10) is
        returned as-is, without a copy.
    """
    # Nothing to measure on empty input; np.mean would yield NaN and warn.
    if audio.size == 0:
        return audio

    # Accumulate power in float64 so long float32 clips do not lose precision.
    signal_power = np.mean(audio.astype(np.float64) ** 2)
    if signal_power < 1e-10:  # near-silent segment: skip
        return audio

    noise_power = signal_power / (10.0 ** (snr_db / 10.0))
    # Draw with the full shape (not len()) so multi-dimensional input works;
    # for 1-D input this consumes the same random stream as before.
    noise = np.random.normal(0.0, np.sqrt(noise_power), audio.shape).astype(np.float32)
    return np.clip(audio + noise, -1.0, 1.0)
def maybe_apply_noise(
    audio: np.ndarray,
    config: dict,
) -> np.ndarray:
    """
    Add Gaussian noise to `audio` at random, as directed by `config`.

    Recognised config keys (each one optional):
        enabled     : bool  - master switch (default True)
        probability : float - per-sample chance of applying (default 0.3)
        min_snr_db  : float - lower bound of the SNR range in dB (default 15.0)
        max_snr_db  : float - upper bound of the SNR range in dB (default 30.0)

    Returns `audio` itself when the augmentation is disabled or not drawn;
    otherwise the result of `apply_noise` at an SNR sampled uniformly
    between the two bounds.
    """
    if config.get("enabled", True):
        # One uniform draw decides whether this sample gets noise at all.
        draw = random.random()
        if draw >= config.get("probability", 0.3):
            return audio
        snr_low = config.get("min_snr_db", 15.0)
        snr_high = config.get("max_snr_db", 30.0)
        return apply_noise(audio, random.uniform(snr_low, snr_high))
    return audio
| # --------------------------------------------------------------------------- | |
| # SpecAugment | |
| # --------------------------------------------------------------------------- | |
def apply_spec_augment(
    input_features: torch.Tensor,
    time_mask_param: int = 80,
    freq_mask_param: int = 27,
    num_time_masks: int = 2,
    num_freq_masks: int = 2,
) -> torch.Tensor:
    """
    Apply SpecAugment (Park et al. 2019) to mel-spectrogram features.

    All frequency masks are applied first, then all time masks. Each mask
    covers a random contiguous band, filled with the torchaudio transforms'
    default mask value (zero). Intended to be called inside the
    DataCollator so the masking is drawn afresh on every training step.

    Args:
        input_features : torch.Tensor, shape [batch, n_mels, time] or
                         [n_mels, time]
        time_mask_param: maximum number of consecutive time-steps to mask
        freq_mask_param: maximum number of consecutive frequency bins to mask
        num_time_masks : how many separate time masks to apply
        num_freq_masks : how many separate frequency masks to apply

    Returns:
        A new tensor of the same shape as `input_features` with the masked
        regions filled in. The input tensor is never modified and never
        shares storage with the result.

    NOTE(review): with the default transform settings the same mask may be
    shared by every item of a batch; confirm against the installed
    torchaudio version (`iid_masks`) if per-example masks are required.
    """
    is_batched = input_features.dim() == 3

    # Always work on a private copy. Previously only the batched path was
    # cloned, so an unbatched result could alias the caller's tensor.
    features = input_features.clone()
    if not is_batched:
        features = features.unsqueeze(0)  # transforms expect [batch, freq, time]

    # Build each transform once; the random draw happens on every call.
    if num_freq_masks > 0:
        freq_masker = T.FrequencyMasking(freq_mask_param=freq_mask_param)
        for _ in range(num_freq_masks):
            features = freq_masker(features)
    if num_time_masks > 0:
        time_masker = T.TimeMasking(time_mask_param=time_mask_param)
        for _ in range(num_time_masks):
            features = time_masker(features)

    return features if is_batched else features.squeeze(0)