File size: 6,286 Bytes
0db822c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""
Audio and spectrogram augmentation for training data diversity.

Three augmentation families are provided:

  1. Speed perturbation   — resample audio to simulate faster/slower speech.
                            Changes both tempo and pitch (intentional for ASR
                            augmentation).
  2. Noise addition       — add Gaussian white noise at a controlled SNR (dB).
  3. SpecAugment          — mask random time-steps and frequency bins in the
                            mel-spectrogram; applied inside the DataCollator so
                            it is random on every training step, not cached.

All functions operate on numpy float32 arrays (audio) or torch.Tensor
(spectrogram).  They are designed to be called from:
  - make_prepare_fn() in trainer.py  → speed + noise on raw audio
  - DataCollatorSpeechSeq2SeqWithPadding.__call__()  → SpecAugment on features
"""

from __future__ import annotations

import random
from typing import Optional

import numpy as np
import torch
import torchaudio.functional as F_audio
import torchaudio.transforms as T


# ---------------------------------------------------------------------------
# Speed perturbation
# ---------------------------------------------------------------------------

def apply_speed_perturbation(audio: np.ndarray, sr: int, factor: float) -> np.ndarray:
    """
    Change the playback speed of `audio` by `factor`.

    factor > 1.0  ->  faster speech (audio gets shorter, about N / factor samples)
    factor < 1.0  ->  slower speech (audio gets longer)

    Implemented via resampling: the signal is treated as if it had been
    recorded at ``sr * factor`` Hz and is resampled back to ``sr`` Hz.  This
    shifts pitch together with tempo (tape-speed effect), the standard
    approach for ASR data augmentation.

    Args:
        audio:  float32 numpy array, shape [N]
        sr:     original sample rate (e.g. 16000)
        factor: speed multiplier (e.g. 0.9, 1.1); must be > 0

    Returns:
        float32 numpy array at `sr` Hz with the new speed.  When
        ``factor == 1.0`` the input array itself is returned unchanged.

    Raises:
        ValueError: if `factor` is not a positive number, or is so small that
            the virtual sample rate rounds to zero.
    """
    # Written as `not (factor > 0)` so that NaN is rejected too.
    if not (factor > 0):
        raise ValueError(f"speed factor must be > 0, got {factor!r}")
    if factor == 1.0:
        return audio

    # Round instead of truncating: float products such as 16000 * 1.15
    # evaluate to 18399.999..., which int() would turn into the wrong rate.
    virtual_sr = int(round(sr * factor))
    if virtual_sr < 1:
        raise ValueError(
            f"speed factor {factor!r} is too small for sample rate {sr}"
        )

    # torch.from_numpy rejects negatively-strided arrays; also normalise the
    # dtype to the documented float32.
    samples = np.ascontiguousarray(audio, dtype=np.float32)
    waveform = torch.from_numpy(samples).unsqueeze(0)        # [1, N]
    resampled = F_audio.resample(waveform, virtual_sr, sr)   # back to target sr
    return resampled.squeeze(0).numpy().astype(np.float32)


def maybe_apply_speed(
    audio: np.ndarray,
    sr: int,
    config: dict,
) -> np.ndarray:
    """
    Speed-perturb `audio` at random, as controlled by `config`.

    Recognised `config` entries (each optional):
      enabled     : bool  -- turns the augmentation on/off (default True)
      probability : float -- per-sample chance of perturbing (default 0.3)
      factors     : list  -- candidate speed multipliers; one is picked with
                             random.choice (default [0.9, 0.95, 1.05, 1.1])

    Returns `audio` untouched when the augmentation is disabled or the random
    draw skips this sample; otherwise returns apply_speed_perturbation's output.
    """
    # `or` short-circuits, so no random number is drawn when disabled.
    skip = (
        not config.get("enabled", True)
        or random.random() >= config.get("probability", 0.3)
    )
    if skip:
        return audio

    candidates = config.get("factors", [0.9, 0.95, 1.05, 1.1])
    chosen_factor = random.choice(candidates)
    return apply_speed_perturbation(audio, sr, chosen_factor)


# ---------------------------------------------------------------------------
# Noise addition
# ---------------------------------------------------------------------------

def apply_noise(
    audio: np.ndarray,
    snr_db: float,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """
    Add Gaussian white noise to `audio` at the given SNR (dB).

    Lower SNR -> more noise (harder).  Typical training range: 15-30 dB.

    The noise standard deviation is chosen so that
    ``10 * log10(signal_power / noise_power) == snr_db``, where signal_power
    is the mean squared sample value.  The noisy signal is clipped to [-1, 1]
    to stay within valid PCM range.

    Args:
        audio:  float numpy array of any shape (typically float32, shape [N]).
        snr_db: target signal-to-noise ratio in decibels.
        rng:    optional numpy Generator for reproducible noise.  When None,
                the global ``np.random`` state is used.

    Returns:
        float32 array with the same shape as `audio`.  Empty or near-silent
        input (mean power < 1e-10) is returned unchanged.
    """
    if audio.size == 0:               # nothing to measure; avoid mean of empty slice
        return audio
    signal_power = np.mean(audio.astype(np.float64) ** 2)
    if signal_power < 1e-10:          # near-silent segment -- SNR undefined, skip
        return audio

    noise_std = np.sqrt(signal_power / (10.0 ** (snr_db / 10.0)))
    sampler = np.random if rng is None else rng
    # size=audio.shape (not len(audio)) so multi-dimensional input gets
    # element-wise noise instead of a broadcasting error.
    noise = sampler.normal(0.0, noise_std, size=audio.shape).astype(np.float32)
    return np.clip(audio + noise, -1.0, 1.0).astype(np.float32, copy=False)


def maybe_apply_noise(
    audio: np.ndarray,
    config: dict,
) -> np.ndarray:
    """
    Add Gaussian noise to `audio` at random, as controlled by `config`.

    Recognised `config` entries (each optional):
      enabled     : bool  -- turns the augmentation on/off (default True)
      probability : float -- per-sample chance of adding noise (default 0.3)
      min_snr_db  : float -- lower bound of the SNR draw, in dB (default 15.0)
      max_snr_db  : float -- upper bound of the SNR draw, in dB (default 30.0)

    The SNR is drawn with random.uniform(min_snr_db, max_snr_db) and handed to
    apply_noise.  Returns `audio` untouched when disabled or skipped.
    """
    # `or` short-circuits, so no random number is drawn when disabled.
    if not config.get("enabled", True) or random.random() >= config.get("probability", 0.3):
        return audio

    target_snr = random.uniform(
        config.get("min_snr_db", 15.0),
        config.get("max_snr_db", 30.0),
    )
    return apply_noise(audio, target_snr)


# ---------------------------------------------------------------------------
# SpecAugment
# ---------------------------------------------------------------------------

def apply_spec_augment(
    input_features: torch.Tensor,
    time_mask_param: int = 80,
    freq_mask_param: int = 27,
    num_time_masks: int = 2,
    num_freq_masks: int = 2,
) -> torch.Tensor:
    """
    Apply SpecAugment (Park et al. 2019) to mel-spectrogram features.

    Applies `num_freq_masks` frequency masks and then `num_time_masks` time
    masks.  Each mask zeroes a random contiguous band, and every example in a
    batch gets its own independently drawn masks.  This is meant to run INSIDE
    the DataCollator so masking is stochastically fresh on every training
    step -- it is never cached to disk.

    Args:
        input_features : torch.Tensor  shape [batch, n_mels, time] or [n_mels, time]
        time_mask_param: maximum number of consecutive time-steps to mask
        freq_mask_param: maximum number of consecutive frequency bins to mask
        num_time_masks : how many separate time masks to apply per example
        num_freq_masks : how many separate frequency masks to apply per example

    Returns:
        A new tensor of the same shape with masked regions set to zero; the
        caller's `input_features` is not modified.

    Raises:
        ValueError: if `input_features` is neither 2-D nor 3-D.
    """
    if input_features.dim() not in (2, 3):
        raise ValueError(
            "input_features must have shape [n_mels, time] or "
            f"[batch, n_mels, time], got {tuple(input_features.shape)}"
        )
    is_batched = input_features.dim() == 3

    # Clone in both cases so the caller's tensor is never written to, even if
    # the installed torchaudio masks in place (behaviour varies by version).
    features = input_features.clone()
    if not is_batched:
        features = features.unsqueeze(0)                 # [1, n_mels, time]

    # Build the transforms once; every call draws a fresh random mask.
    freq_masking = T.FrequencyMasking(freq_mask_param=freq_mask_param)
    time_masking = T.TimeMasking(time_mask_param=time_mask_param)

    # With the default iid_masks=False, torchaudio applies one shared mask to
    # every example of a [batch, freq, time] tensor.  Mask each example on its
    # own so utterances in a batch receive independent masks.
    masked_examples = []
    for example in features.split(1, dim=0):             # each [1, n_mels, time]
        for _ in range(num_freq_masks):
            example = freq_masking(example)
        for _ in range(num_time_masks):
            example = time_masking(example)
        masked_examples.append(example)
    if masked_examples:                                  # empty batch: nothing to mask
        features = torch.cat(masked_examples, dim=0)

    return features if is_batched else features.squeeze(0)