import decimal
import math
from typing import List

import torch
import torch.nn.functional as F
from torchaudio.transforms import MFCC

from src.simulation.component import Component

################################################################################
# Voice Activity Detection (VAD)
################################################################################


class KaldiStyleVAD(Component):
    """

    Kaldi-style Voice Activity Detection (VAD) module. Adapted from

    https://github.com/fsepteixeira/FoolHD/blob/main/code/utils/vad_cmvn.py

    """
    def __init__(self,
                 compute_grad: bool = True,
                 threshold: float = -15.0,
                 proportion_threshold: float = 0.12,
                 frame_len: float = 0.025,
                 hop_len: float = 0.010,
                 mean_scale: float = 0.5,
                 context: int = 2):
        super().__init__(compute_grad)

        self.threshold = threshold
        self.proportion_threshold = proportion_threshold
        self.mean_scale = mean_scale
        self.context = context
        self.diff_zero = mean_scale != 0  # whether to offset threshold by mean log-energy
        self.unfold_size = 2 * context + 1  # symmetric context window size, in frames

        # convert frame/hop lengths from seconds to samples
        self.frame_len = int(frame_len * self.sample_rate)
        self.hop_len = int(hop_len * self.sample_rate)

        # prepare to compute MFCC
        self.mfcc = MFCC(
            sample_rate=self.sample_rate,
            n_mfcc=30,
            dct_type=2,
            norm='ortho',
            log_mels=True,
            melkwargs={
                'n_fft': self.frame_len,
                'hop_length': self.hop_len,
                'n_mels': 30,
                'f_min': 20,
                'f_max': self.sample_rate // 2,
                'power': 2.0,
                'center': True
            }
        )

    def forward(self, x: torch.Tensor):

        # input too short to frame: pass through unchanged
        if x.shape[-1] < self.frame_len + self.hop_len:
            return x

        # require batch dimension
        assert x.ndim >= 2

        # require mono audio, discard channel dimension
        n_batch, slen = x.shape[0], x.shape[-1]
        x = x.reshape(n_batch, slen)

        # compute MFCC
        x_mfcc = self.mfcc(x).permute(0, 2, 1)  # (n_batch, n_frames, n_mfcc)

        # set device for energy threshold
        energy_threshold = torch.tensor([self.threshold]).to(x_mfcc.device)

        # first MFCC coefficient represents log energy
        log_energy = x_mfcc[:, :, 0]

        if self.diff_zero:
            energy_threshold = energy_threshold + self.mean_scale * log_energy.mean(dim=1)

        # prepare frame-wise mask
        mask = torch.ones_like(log_energy)

        # pad borders with symmetric context before striding
        mask = F.pad(mask, pad=(self.context, self.context), value=1.0)

        # get all (overlapping) context "windows"
        mask = mask.unfold(dimension=1, size=self.unfold_size, step=1)

        # number of values included in each context window
        den_count = mask.sum(dim=-1)

        # pad borders with symmetric context
        log_energy = F.pad(log_energy, pad=(self.context, self.context))

        # get all (overlapping) context "windows"
        log_energy = log_energy.unfold(
            dimension=1,
            size=self.unfold_size,
            step=1
        )

        # number of values in each context window above threshold
        num_count = log_energy.gt(
            energy_threshold.unsqueeze(-1).unsqueeze(-1)
        ).sum(dim=-1)

        # frame-by-frame decision: a frame is voiced if at least
        # `proportion_threshold` of its context window exceeds the threshold
        mask = num_count.ge(den_count * self.proportion_threshold)

        # "fold" to obtain waveform mask
        mask_wav = mask.unsqueeze(-1).repeat_interleave(
            repeats=self.frame_len, dim=-1
        )
        mask_wav = torch.cat(
            [
                mask_wav[:, 0],
                mask_wav[:, 1:][:, :, self.frame_len - self.hop_len:].reshape(
                    n_batch, -1
                )
            ], dim=-1
        )
        # trim the centered-STFT padding so the mask aligns with the input;
        # slicing by length avoids the empty-tensor edge case when the
        # right-hand trim would be zero
        left_trim = self.frame_len // 2
        mask_wav = mask_wav[..., left_trim:left_trim + slen]

        # compute number of accepted samples per input waveform
        samples_per_row: List[int] = []
        for e in torch.sum(mask_wav, dim=-1):
            samples_per_row.append(e.item())

        # split resulting tensor to keep trimmed inputs separate
        split = torch.split(x[mask_wav], samples_per_row)

        # placeholder for outputs: (n_batch, 1, slen), zero-padded to preserve length
        final = torch.zeros_like(x).unsqueeze(1)

        # concatenate and pad split views
        for i, tensor in enumerate(split):
            length = tensor.shape[-1]
            final[i, :, :length] = tensor

        return final[..., :slen]


class VAD(Component):
    """

    Apply Voice Activity Detection (VAD) while allowing for straight-through

    gradient estimation. For now, only supports simple energy-based method,

    and should be placed after normalization to avoid scale-dependence.

    """
    def __init__(self,
                 compute_grad: bool = True,
                 frame_len: float = 0.05,
                 threshold: float = -72.0):
        super().__init__(compute_grad)

        self.threshold = threshold
        self.frame_len = int(
            decimal.Decimal(
                frame_len * self.sample_rate
            ).quantize(
                decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP
            )
        )  # convert seconds to samples, rounding half up to the nearest integer

    def forward(self, x: torch.Tensor):

        # require batch dimension
        assert x.ndim >= 2

        # require mono audio, discard channel dimension
        n_batch, slen = x.shape[0], x.shape[-1]
        audio = x.reshape(n_batch, slen)

        eps = 1e-12  # numerical stability

        # determine number of frames
        if slen <= self.frame_len:
            n_frames = 1
        else:
            n_frames = 1 + int(
                math.ceil(
                    (1.0 * slen - self.frame_len) / self.frame_len)
            )

        # pad to integer frame length
        padlen = int(n_frames * self.frame_len)
        zeros = torch.zeros((x.shape[0], padlen - slen,)).to(x)
        padded = torch.cat((audio, zeros), dim=-1)

        # obtain strided (frame-wise) view of audio
        shape = (padded.shape[0], n_frames, self.frame_len)
        frames = torch.as_strided(
            padded,
            size=shape,
            stride=(padded.shape[-1], self.frame_len, 1)
        )

        # create frame-by-frame mask by comparing per-frame energy (in dB)
        # against the threshold; `self.scale` is assumed to be provided by the
        # `Component` base class (it is not defined in this file)
        mask = 20 * torch.log10(
            ((frames * self.scale).norm(dim=-1) / self.frame_len) + eps
        ) > self.threshold

        # turn frame-by-frame mask into sample-by-sample mask
        mask_wav = torch.repeat_interleave(mask, self.frame_len, dim=-1)
        samples_per_row = (torch.sum(mask, dim=-1) * self.frame_len).tolist()

        # split resulting tensor to keep trimmed inputs separate
        split = torch.split(padded[mask_wav], samples_per_row)

        # placeholder for outputs: (n_batch, 1, padded_length)
        final = torch.zeros_like(padded).unsqueeze(1)  # pad to preserve length

        # concatenate and pad split views
        for i, tensor in enumerate(split):
            length = tensor.shape[-1]
            final[i, :, :length] = tensor

        return final[..., :slen]
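

################################################################################
# Usage sketch (illustrative only)
################################################################################


if __name__ == "__main__":
    # Minimal, hedged sketch of how these components might be used. It assumes
    # the `Component` base class behaves like a `torch.nn.Module` (so instances
    # are callable) and provides the `sample_rate` and `scale` attributes that
    # the classes above rely on but do not define. The waveform below is
    # synthetic and purely illustrative.
    x = torch.randn(2, 1, 16000)  # (n_batch, n_channels, n_samples)

    kaldi_vad = KaldiStyleVAD(compute_grad=True)
    energy_vad = VAD(compute_grad=True, frame_len=0.05, threshold=-72.0)

    # both components return zero-padded tensors of shape (n_batch, 1, n_samples)
    print(kaldi_vad(x).shape)
    print(energy_vad(x).shape)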