Speach-To-Text / src /data_preparation /augmentation.py
MIP-Tech's picture
Deploy to HF Spaces
0db822c
"""
Audio and spectrogram augmentation for training data diversity.
Three augmentation families are provided:
1. Speed perturbation — resample audio to simulate faster/slower speech.
   Changes both tempo and pitch (intentional for ASR augmentation).
2. Noise addition — add Gaussian noise at a controlled SNR (dB).
3. SpecAugment — mask random time-steps and frequency bins in the
   mel-spectrogram; applied inside the DataCollator so
   it is random on every training step, not cached.
All functions operate on numpy float32 arrays (audio) or torch.Tensor
(spectrogram). They are designed to be called from:
- make_prepare_fn() in trainer.py → speed + noise on raw audio
- DataCollatorSpeechSeq2SeqWithPadding.__call__() → SpecAugment on features
"""
from __future__ import annotations
import random
from typing import Optional
import numpy as np
import torch
import torchaudio.functional as F_audio
import torchaudio.transforms as T
# ---------------------------------------------------------------------------
# Speed perturbation
# ---------------------------------------------------------------------------
def apply_speed_perturbation(audio: np.ndarray, sr: int, factor: float) -> np.ndarray:
    """
    Change the playback speed of `audio` by `factor`.

        factor > 1.0 → faster speech (audio gets shorter)
        factor < 1.0 → slower speech (audio gets longer)

    Implemented via resampling: the signal is treated as if it had been
    recorded at sr * factor and is then resampled back to sr. This shifts
    pitch proportionally to speed (tape-speed effect), which is the standard
    approach for ASR data augmentation and is well-supported by Whisper.

    Args:
        audio:  float32 numpy array, shape [N]
        sr:     original sample rate (e.g. 16000)
        factor: speed multiplier (e.g. 0.9, 1.1); must be greater than 0

    Returns:
        float32 numpy array, resampled to sr Hz at the new speed.
        When factor == 1.0 the input array itself is returned untouched.

    Raises:
        ValueError: if `factor` is not a positive number, or if sr * factor
            rounds to a non-positive sample rate.
    """
    # `not factor > 0` (instead of `factor <= 0`) also rejects NaN.
    if not factor > 0:
        raise ValueError(f"speed factor must be positive, got {factor!r}")
    if factor == 1.0:
        return audio

    # "Virtual" sample rate. round() rather than int(): a product that lands
    # just below an integer in floating point would otherwise be truncated
    # to an off-by-one rate.
    orig_sr = int(round(float(sr) * factor))
    if orig_sr <= 0:
        raise ValueError(
            f"sr * factor must round to a positive sample rate, got {orig_sr}"
        )

    waveform = torch.from_numpy(audio).unsqueeze(0)       # [1, N]
    resampled = F_audio.resample(waveform, orig_sr, sr)   # back to target sr
    return resampled.squeeze(0).numpy().astype(np.float32)
def maybe_apply_speed(
    audio: np.ndarray,
    sr: int,
    config: dict,
) -> np.ndarray:
    """
    Apply speed perturbation to `audio` with a configurable probability.

    Recognised `config` keys (each may be omitted):
        enabled     : bool  — master switch (default True)
        probability : float — per-sample chance of perturbing (default 0.3)
        factors     : list  — candidate speed multipliers, one is picked
                              uniformly at random
                              (default [0.9, 0.95, 1.05, 1.1])

    Returns `audio` itself when the augmentation is disabled or the random
    draw says to skip; otherwise returns the speed-perturbed signal.
    """
    if config.get("enabled", True):
        roll = random.random()
        skip = roll >= config.get("probability", 0.3)
        if not skip:
            candidates = config.get("factors", [0.9, 0.95, 1.05, 1.1])
            chosen = random.choice(candidates)
            return apply_speed_perturbation(audio, sr, chosen)
    return audio
# ---------------------------------------------------------------------------
# Noise addition
# ---------------------------------------------------------------------------
def apply_noise(audio: np.ndarray, snr_db: float) -> np.ndarray:
    """
    Add Gaussian white noise to `audio` at the given SNR (dB).

    Lower SNR → more noise (harder). Typical training range: 15–30 dB.
    The noisy signal is clipped to [-1, 1] to stay within valid PCM range.

    Args:
        audio:  numpy array of samples; any shape (e.g. [N] or [channels, N]).
                The signal power is measured over the whole array.
        snr_db: target signal-to-noise ratio in decibels.

    Returns:
        Array of the same shape as `audio` with noise added and clipped.
        Empty or near-silent input is returned unchanged (same object).
    """
    if audio.size == 0:  # nothing to perturb; avoids mean-of-empty NaN/warning
        return audio

    # Accumulate in float64 so the power estimate is not limited by float32.
    signal_power = np.mean(audio.astype(np.float64) ** 2)
    if signal_power < 1e-10:  # near-silent segment — skip
        return audio

    noise_power = signal_power / (10.0 ** (snr_db / 10.0))
    # Draw with the full shape of `audio` (not len(audio)) so that
    # multi-dimensional input gets one independent noise value per sample.
    noise = np.random.normal(0.0, np.sqrt(noise_power), audio.shape).astype(np.float32)
    return np.clip(audio + noise, -1.0, 1.0)
def maybe_apply_noise(
    audio: np.ndarray,
    config: dict,
) -> np.ndarray:
    """
    Add Gaussian noise to `audio` with a configurable probability.

    Recognised `config` keys (each may be omitted):
        enabled     : bool  — master switch (default True)
        probability : float — per-sample chance of adding noise (default 0.3)
        min_snr_db  : float — lower bound of the SNR range in dB (default 15.0)
        max_snr_db  : float — upper bound of the SNR range in dB (default 30.0)

    The SNR is drawn uniformly between the two bounds. Returns `audio` itself
    when the augmentation is disabled or the random draw says to skip.
    """
    # Short-circuit keeps the random draw from happening when disabled.
    if not config.get("enabled", True) or random.random() >= config.get("probability", 0.3):
        return audio

    snr_low = config.get("min_snr_db", 15.0)
    snr_high = config.get("max_snr_db", 30.0)
    return apply_noise(audio, random.uniform(snr_low, snr_high))
# ---------------------------------------------------------------------------
# SpecAugment
# ---------------------------------------------------------------------------
def apply_spec_augment(
    input_features: torch.Tensor,
    time_mask_param: int = 80,
    freq_mask_param: int = 27,
    num_time_masks: int = 2,
    num_freq_masks: int = 2,
) -> torch.Tensor:
    """
    Apply SpecAugment (Park et al. 2019) to a batch of mel-spectrogram features.

    All frequency masks are applied first, followed by all time masks; each
    mask covers a random contiguous run of bins / time-steps.

    This is meant to be called from inside the DataCollator so that it is
    stochastically fresh on every training step — it is never cached to disk.

    Args:
        input_features : torch.Tensor shape [batch, n_mels, time] or [n_mels, time]
        time_mask_param: maximum number of consecutive time-steps to mask
        freq_mask_param: maximum number of consecutive frequency bins to mask
        num_time_masks : how many separate time masks to apply
        num_freq_masks : how many separate frequency masks to apply

    Returns:
        New tensor of the same shape with masked regions set to zero. The
        result never shares storage with `input_features`.
    """
    is_batched = input_features.dim() == 3

    # Always work on a private copy so the caller's tensor is never aliased
    # or modified, for batched and unbatched input alike.
    features = input_features.clone()
    if not is_batched:
        features = features.unsqueeze(0)  # torchaudio expects [batch, freq, time]

    # Each transform draws a fresh random mask on every call, so one instance
    # per axis is enough.
    if num_freq_masks > 0:
        freq_masker = T.FrequencyMasking(freq_mask_param=freq_mask_param)
        for _ in range(num_freq_masks):
            features = freq_masker(features)

    if num_time_masks > 0:
        time_masker = T.TimeMasking(time_mask_param=time_mask_param)
        for _ in range(num_time_masks):
            features = time_masker(features)

    return features if is_batched else features.squeeze(0)