Voice Activity Detection
English
hypersunflower committed on
Commit
d1124fa
·
verified ·
1 Parent(s): fcdae06

Upload 3 files

Browse files
Files changed (3) hide show
  1. logMelSpectrogram.py +175 -0
  2. sadModel.py +32 -0
  3. speech_detection.py +76 -0
logMelSpectrogram.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import numpy as np
4
+ from matplotlib import pyplot as plt
5
+
6
+ from typing import Optional
7
+
8
class logMelSpectrogram:
    """Compute a log-mel spectrogram from a raw waveform with torch.

    Pipeline: pre-emphasis -> framing -> Hann window -> FFT power spectrum
    -> triangular mel filterbank -> dB conversion -> dynamic-range clamp.
    """

    def __init__(
        self,
        frame_rate_s: int = 30,
        stride_s: int = 10,
        n_fft: Optional[int] = None,
        n_mels: Optional[int] = 40,
        top_db: int = 80,
        pre_emph_coef: float = 0.95,
        device: Optional[str] = None
    ):
        """
        Args:
            frame_rate_s: frame length in milliseconds.
            stride_s: hop between frames in milliseconds.
            n_fft: FFT size; if None, the next power of two >= frame length.
            n_mels: number of mel filterbank bands.
            top_db: dynamic range (dB) kept below the spectrogram maximum.
            pre_emph_coef: pre-emphasis coefficient.
            device: torch device string; auto-detected when None.
        """
        self.frame_rate_s = frame_rate_s
        self.stride_s = stride_s
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.log_mel_spec_is_computed = False
        self.top_db = top_db
        self.pre_emph_coef = pre_emph_coef

        # BUG FIX: the original unconditionally overwrote the auto-detected
        # device with the (possibly None) argument, leaving self.device = None
        # and calling torch.set_default_device(None).
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else (
                "mps" if torch.backends.mps.is_available() else "cpu"
            )
        else:
            self.device = device
        torch.set_default_device(self.device)
        torch.set_default_dtype(torch.float32)

    def transform(
        self,
        samples: np.ndarray,
        sr: int,
    ):
        """Return the log-mel spectrogram of `samples`.

        Args:
            samples: 1-D float32 waveform.
            sr: sample rate in Hz.

        Returns:
            Tensor of shape (n_mels, num_frames), in dB.

        Raises:
            ValueError: if fewer than two samples are given.
        """
        # Move samples to the configured device so they match tensors created
        # on the default device below (torch.arange, torch.linspace, ...).
        self.samples = torch.from_numpy(samples).to(self.device)
        self.sr = sr

        if self.samples.shape[0] < 2:
            raise ValueError("Samples should be longer than two")

        # Pre-emphasis: amplifies the difference between the current sample
        # and the previous one, compensating for the audio roll-off.
        pre_emph_samples = torch.cat([
            self.samples[0:1],
            self.samples[1:] - self.pre_emph_coef * self.samples[:-1]
        ], dim=0)

        # Framing: turn the audio into discrete overlapping chunks.
        stride = self.sr * self.stride_s // 1000
        frame_rate = self.sr * self.frame_rate_s // 1000

        chunks = pre_emph_samples.unfold(0, frame_rate, stride).contiguous()

        # Hann window: smooth out the chunk edges to avoid sudden drops and
        # rises in volume at the frame boundaries (spectral leakage).
        n = torch.arange(frame_rate)
        hanning_weights = 0.5 - 0.5 * torch.cos(2 * torch.pi * n / (frame_rate - 1))

        weighted_chunks = chunks * hanning_weights

        # FFT decomposes each frame into its underlying frequencies.
        # Only positive frequencies are kept (n_fft // 2 + 1 bins); for a
        # real signal the negative half carries no new information.
        if not self.n_fft:
            # Next power of two >= frame length, as a plain int
            # (the original produced a 0-dim int32 tensor here).
            self.n_fft = 1 << (frame_rate - 1).bit_length()

        fft_chunks = torch.fft.rfft(weighted_chunks, n=self.n_fft)
        power_spec = (2 / self.n_fft ** 2) * torch.abs(fft_chunks) ** 2

        # Hertz <-> mel converters.
        def hz_to_mel(hz):
            return 2595 * torch.log10(1 + hz / 700)

        def mel_to_hz(m):
            return 700 * (10 ** (m / 2595) - 1)

        fmax = self.sr / 2
        fmin = 0

        # Points evenly spaced on the mel scale ...
        mels = torch.linspace(
            hz_to_mel(torch.tensor(fmin)),
            hz_to_mel(torch.tensor(fmax)),
            self.n_mels + 2
        )

        # ... converted back to Hz, which introduces the mel non-linearity.
        hz_points = mel_to_hz(mels)
        bins = torch.floor((self.n_fft + 1) * hz_points / self.sr).to(torch.int32)

        # Triangular overlapping filters that widen as frequency rises,
        # simulating human hearing: better at distinguishing low
        # frequencies than high ones.
        k = torch.arange(self.n_fft // 2 + 1).unsqueeze(0)

        f_left = bins[:-2].unsqueeze(1)
        f_center = bins[1:-1].unsqueeze(1)
        f_right = bins[2:].unsqueeze(1)

        up = (k - f_left) / torch.clamp(f_center - f_left, min=1e-8)      # (n_mels, bins)
        down = (f_right - k) / torch.clamp(f_right - f_center, min=1e-8)  # (n_mels, bins)

        filters = torch.clamp(torch.minimum(up, down), min=0.0)

        mel_spec = torch.matmul(filters, power_spec.T)

        # Convert the mel spectrogram to log (dB) scale;
        # the clamp avoids log(0).
        mel_spec = torch.clamp(mel_spec, min=1e-10)
        log_mel_spec = 10 * torch.log10(mel_spec)

        # Normalise: limit the dynamic range to top_db below the maximum.
        log_mel_spec = torch.clamp(
            log_mel_spec,
            min=torch.max(log_mel_spec) - self.top_db
        )

        self.log_mel_spec = log_mel_spec
        self.log_mel_spec_is_computed = True

        return log_mel_spec

    def plot_waveform(self):
        """Plot the last-transformed waveform against time in seconds."""
        plt.figure(figsize=(10, 4))
        cpu_samples = self.samples.cpu().numpy()
        plt.plot(np.arange(cpu_samples.shape[0]) / self.sr, cpu_samples)
        plt.title("Waveform")
        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude")
        plt.show()

    def plot_log_mel_spec(self, cmap="magma_r"):
        """Plot the computed log-mel spectrogram as an image.

        Raises:
            ValueError: if transform() has not been run yet.
        """
        if not self.log_mel_spec_is_computed:
            # BUG FIX: the original message referred to a nonexistent compute()
            raise ValueError("run transform() before plotting log mel spectrogram")

        plt.figure(figsize=(10, 4))
        spec_to_plot = self.log_mel_spec.cpu().numpy()
        plt.imshow(spec_to_plot, origin="lower", aspect="auto", cmap=cmap)
        plt.title("Log-Mel Spectrogram (dB)")
        plt.xlabel("Time frames")
        plt.ylabel("Mel bins")
        plt.colorbar()
        plt.show()
sadModel.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torch import nn
4
+
5
class sadModel(nn.Module):
    """Bidirectional-GRU speech activity detection head.

    Consumes a windowed log-mel spectrogram of shape
    (batch, 1, input_dim, seq_len) and emits output_dim logits per example.
    """

    def __init__(self, input_dim=40, hidden_dim=64, num_layers=1, output_dim=800,
                 seq_len=400):
        """
        Args:
            input_dim: number of mel bins per time step.
            hidden_dim: GRU hidden size (doubled by bidirectionality).
            num_layers: number of stacked GRU layers.
            output_dim: size of the final logit vector.
            seq_len: number of time frames expected in the input.
                Generalizes the previously hard-coded 400; the default
                keeps the original behavior.
        """
        super().__init__()
        self.seq_len = seq_len

        # batch_first=True -> GRU input is (batch, seq_len, input_size)
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True
        )

        # 2x hidden size because the GRU is bidirectional
        self.fc = nn.Linear(hidden_dim * 2 * seq_len, output_dim)

    def forward(self, x):
        """Map a (batch, 1, input_dim, seq_len) tensor to (batch, output_dim) logits."""
        # drop the channel dim and put time first: (batch, seq_len, input_dim)
        x = x.squeeze(1).permute(0, 2, 1)

        # pass through the GRU
        out, _ = self.gru(x)  # (batch, seq_len, hidden_dim*2)

        # flatten the time dimension
        out = out.contiguous().view(out.size(0), -1)  # (batch, seq_len*hidden_dim*2)

        out = self.fc(out)  # (batch, output_dim)

        return out
speech_detection.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from pydub import AudioSegment
4
+ AudioSegment.converter = "/usr/bin/ffmpeg"
5
+
6
+ import numpy as np
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ import matplotlib.pyplot as plt
12
+
13
+ from typing import Optional
14
+
15
class detectSpeech:
    """Run a trained speech-activity model over an audio file.

    Loads a model checkpoint, converts the audio into a log-mel
    spectrogram, windows it into fixed-size chunks and classifies
    each frame as speech onset / offset.
    """

    def __init__(
        self,
        model_class,
        logMelSpectrogram,
        model_path: str,
        stride_s: int = 25,
        frame_rate_s: int = 25,
        device: Optional[str] = None,
        threshold: float = 0.5,
        batch_size: int = 32,
        sr: int = 16000
    ):
        """
        Args:
            model_class: an already-constructed nn.Module instance.
            logMelSpectrogram: feature extractor exposing
                transform(samples, sr) -> (n_mels, frames) tensor.
            model_path: path to the saved state_dict checkpoint.
            stride_s: chunk hop in milliseconds.
            frame_rate_s: chunk length in milliseconds.
            device: torch device string; auto-detected when None.
            threshold: sigmoid probability threshold for a positive frame.
            batch_size: reserved for batched inference (stored, not yet used).
            sr: target sample rate the audio is resampled to.
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else (
                "mps" if torch.backends.mps.is_available() else "cpu"
            )
        else:
            self.device = device

        self.model_path = model_path

        # weights_only=True restricts torch.load to plain tensors (safe load)
        self.model = model_class.to(self.device)
        self.model.load_state_dict(torch.load(self.model_path, weights_only=True))
        self.model.eval()

        self.log_mel_spec = logMelSpectrogram

        self.sr = sr
        self.stride = sr * stride_s // 1000
        self.frame_rate = sr * frame_rate_s // 1000

        # BUG FIX: threshold and batch_size were accepted but silently
        # discarded; detect() hard-coded a 0.5 threshold.
        self.threshold = threshold
        self.batch_size = batch_size

    def detect(
        self,
        audio_path: str
    ):
        """Return (onset, offset) binary frame decisions for `audio_path`.

        The audio is converted to mono at self.sr, featurized, windowed
        into frame_rate-wide chunks along time and run through the model.

        Returns:
            Tuple of two flat int tensors: onset flags and offset flags.
        """
        audio = AudioSegment.from_file(audio_path)
        audio = audio.set_channels(1).set_frame_rate(self.sr)
        samples = np.array(audio.get_array_of_samples(), dtype=np.float32)

        log_mel = self.log_mel_spec.transform(samples=samples, sr=self.sr).to(self.device)

        # window the spectrogram along time -> (chunks, n_mels, frame_rate)
        chunks_mel = log_mel.unfold(dimension=1, size=self.frame_rate, step=self.stride)
        chunks_mel = chunks_mel.permute(1, 0, 2)

        # L2-normalize each chunk and add a channel dim for the model
        chunks_mel = F.normalize(chunks_mel).unsqueeze(1)

        with torch.no_grad():
            outputs = self.model(chunks_mel)
            outputs = torch.sigmoid(outputs)
            # BUG FIX: use the configured threshold instead of hard-coded 0.5
            outputs = (outputs >= self.threshold).int()

        # The first half of each output vector holds onset scores, the
        # second half offsets (generalized from the hard-coded 400).
        half = outputs.size(1) // 2
        onset, offset = torch.split(outputs, half, dim=1)

        return torch.flatten(onset), torch.flatten(offset)