davesalvi commited on
Commit
36c50ca
·
1 Parent(s): acf7345

add rawnet2 code

Browse files
Files changed (4) hide show
  1. config/rawnet_config.yaml +50 -0
  2. src/audio_utils.py +41 -0
  3. src/rawnet_model.py +558 -0
  4. src/utils.py +247 -0
config/rawnet_config.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+
3
+ seed: 1234
4
+ num_thread: 6
5
+ prefetch_factor: 2
6
+
7
+ num_epochs: 150
8
+ early_stopping: 10
9
+ lr: 0.0001
10
+ weight_decay: 0.0001
11
+ batch_size: 128
12
+ #batch_size: 32
13
+
14
+ T_max: 100
15
+ eta_min: 0.00001
16
+
17
+ save_model_folder: 'checkpoints/rawnet2_model/'
18
+ save_results_folder: 'results/'
19
+ model_pretrained: 'RAWNET_ASVSPOOF.pth'
20
+
21
+ #amsgrad: 1
22
+ win_len: 3.0
23
+
24
+ training_asvspoof: True
25
+ training_FoR: True
26
+ training_InTheWild: True
27
+
28
+ train_model: True
29
+ eval_model: True
30
+
31
+ #model-related
32
+ model:
33
+ first_conv: 1024 # no. of filter coefficients
34
+ in_channels: 1
35
+ filts: [20, [20, 20], [20, 128], [128, 128]] # no. of filters channel in residual blocks
36
+ blocks: [2, 4]
37
+ nb_fc_node: 1024
38
+ gru_node: 1024
39
+ nb_gru_layer: 3
40
+ nb_classes: 2
41
+
42
+ old_model:
43
+ first_conv: 1024 # no. of filter coefficients
44
+ in_channels: 1
45
+ filts: [20, [20, 20], [20, 128], [128, 128]] # no. of filters channel in residual blocks
46
+ blocks: [2, 4]
47
+ nb_fc_node: 1024
48
+ gru_node: 1024
49
+ nb_gru_layer: 3
50
+ nb_classes: 2
src/audio_utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import numpy as np
3
+ import warnings
4
+ import soundfile as sf
5
+
6
+ warnings.filterwarnings("ignore")
7
+
8
def read_audio(audio_path, dur=180, fs=16000, trim=False, int_type=False, windowing=False, freq_min=None, freq_max=6000):
    """Load an audio file and optionally trim, quantize and sub-window it.

    .wav files are read with soundfile and resampled to `fs` if needed;
    any other path is assumed to be a NumPy dump (np.load) already at 16 kHz.

    :param audio_path: path to a .wav file or a NumPy array dump
    :param dur: nominal maximum duration in seconds (sizes the windowing mask)
    :param fs: target sampling frequency for .wav input
    :param trim: if True, strip leading/trailing silence (librosa, top_db=20)
    :param int_type: if True, scale floats by 32768 and cast to int32
    :param windowing: if True, keep a periodic subset of samples (see NOTE)
    :param freq_min: unused — NOTE(review): dead parameter, confirm before removing
    :param freq_max: unused — NOTE(review): dead parameter, confirm before removing
    :return: tuple (samples, sampling frequency)
    """
    if audio_path.endswith('.wav'):
        X, fs_orig = sf.read(audio_path)
        # X, fs_orig = librosa.load(audio_path, sr=None, duration=dur)
        if fs_orig != fs:
            X = librosa.resample(X, orig_sr=fs_orig, target_sr=fs)
    else:
        # non-wav input: NumPy dump assumed to already be at 16 kHz
        X = np.load(audio_path)
        fs = 16000

    if trim:
        X = librosa.effects.trim(X, top_db=20)[0]
    # from float to int
    # NOTE(review): 32768 is the int16 full-scale factor but the cast is to
    # int32 — confirm the intended bit depth.
    if int_type:
        X = (X * 32768).astype(np.int32)
    if windowing:
        win_len = 3  # in seconds
        mask = np.zeros(dur*fs).astype(bool)
        # NOTE(review): each mask span is fs samples long, so this keeps only
        # the FIRST 1 second of every 3-second (win_len) window — confirm
        # whether win_len*fs was intended instead.
        for ii in range(mask.shape[0]//(win_len*fs)):
            mask[ii*win_len*fs:ii*win_len*fs+fs] = True
        mask = mask[:X.shape[0]]
        X = X[mask]

        # NOTE(review): this OVERWRITES the source file with the windowed
        # signal — a destructive side effect inside a "read" function
        # (presumably caching); confirm intent.
        sf.write(audio_path, X, fs)

    return X, fs
35
+
36
+
37
def mix_tracks(audio1, audio2):
    """Average two signals sample-by-sample over their common length.

    :param audio1: first signal (np.array)
    :param audio2: second signal (np.array)
    :return: mixture (audio1 + audio2) / 2, truncated to the shorter input
    """
    common = min(len(audio1), len(audio2))
    return (audio1[:common] + audio2[:common]) / 2
src/rawnet_model.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ from collections import OrderedDict
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor
6
+ import sys
7
+ from src.audio_utils import *
8
+ import random
9
+ import pandas as pd
10
+ import pdb
11
+
12
class SincConv(nn.Module):
    """Fixed (non-learnable) sinc band-pass filterbank applied to raw audio.

    Band edges are spaced uniformly on the mel scale between 0 and fs/2.
    The impulse responses are Hamming-windowed ideal band-pass filters.
    """

    @staticmethod
    def to_mel(hz):
        """Convert frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert a mel value back to Hz."""
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, device, out_channels, kernel_size, in_channels=1, sample_rate=16000,
                 stride=1, padding=0, dilation=1, bias=False, groups=1):
        """
        :param device: torch device the filters are moved to in forward()
        :param out_channels: number of band-pass filters
        :param kernel_size: filter length in samples (bumped to odd if even)
        :param in_channels: must be 1 (raw waveform)
        :param sample_rate: sampling frequency in Hz
        :raises ValueError: for in_channels != 1, bias=True or groups > 1
        """
        super(SincConv, self).__init__()

        if in_channels != 1:
            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels)
            raise ValueError(msg)

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e, perfectly symmetrics)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1

        self.device = device
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        # initialize filterbanks using Mel scale
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)  # Hz to mel conversion
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)  # Mel to Hz conversion
        self.mel = filbandwidthsf
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2, (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)

        # PERF: the band edges are fixed, so build the impulse responses ONCE
        # here instead of recomputing them on every forward pass (the original
        # rebuilt the whole filterbank per call with identical results).
        window = Tensor(np.hamming(self.kernel_size))
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            hHigh = (2 * fmax / self.sample_rate) * np.sinc(2 * fmax * self.hsupp / self.sample_rate)
            hLow = (2 * fmin / self.sample_rate) * np.sinc(2 * fmin * self.hsupp / self.sample_rate)
            hideal = hHigh - hLow
            # ideal band-pass = high-pass sinc minus low-pass sinc, windowed
            self.band_pass[i, :] = window * Tensor(hideal)

    def forward(self, x):
        """Apply the fixed filterbank.

        :param x: waveform batch, shape (batch, 1, samples)
        :return: filtered signal, shape (batch, out_channels, frames)
        """
        band_pass_filter = self.band_pass.to(self.device)

        # kept as an attribute for parity with the original implementation
        self.filters = (band_pass_filter).view(self.out_channels, 1, self.kernel_size)

        return F.conv1d(x, self.filters, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        bias=None, groups=1)
77
+
78
+
79
class Residual_block(nn.Module):
    """Pre-activation 1-D residual block used by RawNet2.

    BN + LeakyReLU (skipped when `first`), two 3-tap convs, a 1x1 conv on
    the skip path when the channel count changes, then MaxPool(3).
    """

    def __init__(self, nb_filts, first=False):
        """
        :param nb_filts: [in_channels, out_channels] for the conv layers
        :param first: if True, skip the leading BN+LeakyReLU (block directly
            follows the sinc front-end)
        """
        super(Residual_block, self).__init__()
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm1d(num_features=nb_filts[0])

        self.lrelu = nn.LeakyReLU(negative_slope=0.3)

        self.conv1 = nn.Conv1d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=3,
                               padding=1,
                               stride=1)

        self.bn2 = nn.BatchNorm1d(num_features=nb_filts[1])
        self.conv2 = nn.Conv1d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               padding=1,
                               kernel_size=3,
                               stride=1)

        # 1x1 conv to match channel counts on the skip path when they differ
        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv1d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=0,
                                             kernel_size=1,
                                             stride=1)

        else:
            self.downsample = False
        self.mp = nn.MaxPool1d(3)

    def forward(self, x):
        """Apply the residual block; time axis shrinks by 3 (MaxPool)."""
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.lrelu(out)
        else:
            out = x

        # NOTE(review): conv1 is applied to the raw input x, so the
        # bn1/lrelu result computed above is DISCARDED. This matches the
        # official RawNet2 ASVspoof baseline code (and therefore any
        # pretrained checkpoint such as RAWNET_ASVSPOOF.pth) — preserved
        # deliberately; confirm before "fixing" to conv1(out).
        out = self.conv1(x)
        out = self.bn2(out)
        out = self.lrelu(out)
        out = self.conv2(out)

        if self.downsample:
            identity = self.conv_downsample(identity)

        out += identity
        out = self.mp(out)
        return out
133
+
134
+
135
+ class RawNet(nn.Module):
136
    def __init__(self, d_args, device):
        """Build the RawNet2 architecture from a config dictionary.

        :param d_args: model config with keys first_conv, in_channels, filts,
            nb_fc_node, gru_node, nb_gru_layer, nb_classes (see
            rawnet_config.yaml, `model` section)
        :param device: torch device passed to the sinc front-end
        """
        super(RawNet, self).__init__()

        self.device = device

        # Fixed sinc filterbank applied directly to the raw waveform
        self.Sinc_conv = SincConv(device=self.device,
                                  out_channels=d_args['filts'][0],
                                  kernel_size=d_args['first_conv'],
                                  in_channels=d_args['in_channels']
                                  )

        # self.Sinc_conv = SincConv(out_channels=d_args['filts'][0],
        #                           kernel_size=d_args['first_conv'])

        self.first_bn = nn.BatchNorm1d(num_features=d_args['filts'][0])
        self.selu = nn.SELU(inplace=True)
        # Six residual blocks: two narrow (filts[1]) then four wide (filts[2])
        self.block0 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][1], first=True))
        self.block1 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][1]))
        self.block2 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
        # NOTE(review): this mutates the caller's config dict IN PLACE so that
        # blocks 3-5 take the widened channel count as input — confirm callers
        # do not reuse d_args afterwards.
        d_args['filts'][2][0] = d_args['filts'][2][1]
        self.block3 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
        self.block4 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
        self.block5 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        # One attention gate (single Linear layer) per residual block
        self.fc_attention0 = self._make_attention_fc(in_features=d_args['filts'][1][-1],
                                                     l_out_features=d_args['filts'][1][-1])
        self.fc_attention1 = self._make_attention_fc(in_features=d_args['filts'][1][-1],
                                                     l_out_features=d_args['filts'][1][-1])
        self.fc_attention2 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
                                                     l_out_features=d_args['filts'][2][-1])
        self.fc_attention3 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
                                                     l_out_features=d_args['filts'][2][-1])
        self.fc_attention4 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
                                                     l_out_features=d_args['filts'][2][-1])
        self.fc_attention5 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
                                                     l_out_features=d_args['filts'][2][-1])

        # GRU back-end pooling the frame sequence into an utterance embedding
        self.bn_before_gru = nn.BatchNorm1d(num_features=d_args['filts'][2][-1])
        self.gru = nn.GRU(input_size=d_args['filts'][2][-1],
                          hidden_size=d_args['gru_node'],
                          num_layers=d_args['nb_gru_layer'],
                          batch_first=True)

        self.fc1_gru = nn.Linear(in_features=d_args['gru_node'],
                                 out_features=d_args['nb_fc_node'])

        self.fc2_gru = nn.Linear(in_features=d_args['nb_fc_node'],
                                 out_features=d_args['nb_classes'], bias=True)

        self.sig = nn.Sigmoid()
        self.logsoftmax = nn.LogSoftmax(dim=1)
188
+ #
189
+ # def forward(self, x, y=None):
190
+ #
191
+ # nb_samp = x.shape[0]
192
+ # len_seq = x.shape[1]
193
+ # x = x.view(nb_samp, 1, len_seq)
194
+ #
195
+ # x = self.Sinc_conv(x)
196
+ # x = F.max_pool1d(torch.abs(x), 3)
197
+ # x = self.first_bn(x)
198
+ # x = self.selu(x)
199
+ #
200
+ # x0 = self.block0(x)
201
+ # y0 = self.avgpool(x0).view(x0.size(0), -1) # torch.Size([batch, filter])
202
+ # y0 = self.fc_attention0(y0)
203
+ # y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1) # torch.Size([batch, filter, 1])
204
+ # x = x0 * y0 + y0 # (batch, filter, time) x (batch, filter, 1)
205
+ #
206
+ # x1 = self.block1(x)
207
+ # y1 = self.avgpool(x1).view(x1.size(0), -1) # torch.Size([batch, filter])
208
+ # y1 = self.fc_attention1(y1)
209
+ # y1 = self.sig(y1).view(y1.size(0), y1.size(1), -1) # torch.Size([batch, filter, 1])
210
+ # x = x1 * y1 + y1 # (batch, filter, time) x (batch, filter, 1)
211
+ #
212
+ # x2 = self.block2(x)
213
+ # y2 = self.avgpool(x2).view(x2.size(0), -1) # torch.Size([batch, filter])
214
+ # y2 = self.fc_attention2(y2)
215
+ # y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1) # torch.Size([batch, filter, 1])
216
+ # x = x2 * y2 + y2 # (batch, filter, time) x (batch, filter, 1)
217
+ #
218
+ # x3 = self.block3(x)
219
+ # y3 = self.avgpool(x3).view(x3.size(0), -1) # torch.Size([batch, filter])
220
+ # y3 = self.fc_attention3(y3)
221
+ # y3 = self.sig(y3).view(y3.size(0), y3.size(1), -1) # torch.Size([batch, filter, 1])
222
+ # x = x3 * y3 + y3 # (batch, filter, time) x (batch, filter, 1)
223
+ #
224
+ # x4 = self.block4(x)
225
+ # y4 = self.avgpool(x4).view(x4.size(0), -1) # torch.Size([batch, filter])
226
+ # y4 = self.fc_attention4(y4)
227
+ # y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1) # torch.Size([batch, filter, 1])
228
+ # x = x4 * y4 + y4 # (batch, filter, time) x (batch, filter, 1)
229
+ #
230
+ # x5 = self.block5(x)
231
+ # y5 = self.avgpool(x5).view(x5.size(0), -1) # torch.Size([batch, filter])
232
+ # y5 = self.fc_attention5(y5)
233
+ # y5 = self.sig(y5).view(y5.size(0), y5.size(1), -1) # torch.Size([batch, filter, 1])
234
+ # x = x5 * y5 + y5 # (batch, filter, time) x (batch, filter, 1)
235
+ #
236
+ # x = self.bn_before_gru(x)
237
+ # x = self.selu(x)
238
+ # x = x.permute(0, 2, 1) # (batch, filt, time) >> (batch, time, filt)
239
+ # self.gru.flatten_parameters()
240
+ # x, _ = self.gru(x)
241
+ # x = x[:, -1, :]
242
+ # x = self.fc1_gru(x)
243
+ # x = self.fc2_gru(x)
244
+ # output = self.logsoftmax(x)
245
+ #
246
+ # return output
247
+
248
+ def forward(self, x):
249
+ # Pass through Residual Part
250
+ x = self._forward_residual_part(x)
251
+
252
+ # pdb.set_trace()
253
+
254
+ # Pass through Processing Part
255
+ x = self._forward_processing_part(x)
256
+
257
+ output = self.logsoftmax(x)
258
+ return output
259
+
260
+
261
+
262
+ def _forward_residual_part(self, x):
263
+ nb_samp = x.shape[0]
264
+ len_seq = x.shape[1]
265
+ x = x.view(nb_samp, 1, len_seq)
266
+
267
+ # pdb.set_trace()
268
+
269
+ x = self.Sinc_conv(x)
270
+ x = F.max_pool1d(torch.abs(x), 3)
271
+ x = self.first_bn(x)
272
+ x = self.selu(x)
273
+
274
+ x0 = self.block0(x)
275
+ y0 = self.avgpool(x0).view(x0.size(0), -1)
276
+ y0 = self.fc_attention0(y0)
277
+ y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1)
278
+ x = x0 * y0 + y0
279
+
280
+ x1 = self.block1(x)
281
+ y1 = self.avgpool(x1).view(x1.size(0), -1)
282
+ y1 = self.fc_attention1(y1)
283
+ y1 = self.sig(y1).view(y1.size(0), y1.size(1), -1)
284
+ x = x1 * y1 + y1
285
+
286
+ x2 = self.block2(x)
287
+ y2 = self.avgpool(x2).view(x2.size(0), -1)
288
+ y2 = self.fc_attention2(y2)
289
+ y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1)
290
+ x = x2 * y2 + y2
291
+
292
+ x3 = self.block3(x)
293
+ y3 = self.avgpool(x3).view(x3.size(0), -1)
294
+ y3 = self.fc_attention3(y3)
295
+ y3 = self.sig(y3).view(y3.size(0), y3.size(1), -1)
296
+ x = x3 * y3 + y3
297
+
298
+ x4 = self.block4(x)
299
+ y4 = self.avgpool(x4).view(x4.size(0), -1)
300
+ y4 = self.fc_attention4(y4)
301
+ y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1)
302
+ x = x4 * y4 + y4
303
+
304
+ x5 = self.block5(x)
305
+ y5 = self.avgpool(x5).view(x5.size(0), -1)
306
+ y5 = self.fc_attention5(y5)
307
+ y5 = self.sig(y5).view(y5.size(0), y5.size(1), -1)
308
+ x = x5 * y5 + y5
309
+
310
+ return x
311
+
312
+ def _forward_processing_part(self, x):
313
+ x = self.bn_before_gru(x)
314
+ x = self.selu(x)
315
+ x = x.permute(0, 2, 1)
316
+ self.gru.flatten_parameters()
317
+ x, _ = self.gru(x)
318
+ x = x[:, -1, :]
319
+ x = self.fc1_gru(x)
320
+ x = self.fc2_gru(x)
321
+ return x
322
+
323
+
324
+ def freeze_processing_part(self):
325
+ for param in self.bn_before_gru.parameters():
326
+ param.requires_grad = False
327
+ for param in self.gru.parameters():
328
+ param.requires_grad = False
329
+ for param in self.fc1_gru.parameters():
330
+ param.requires_grad = False
331
+ for param in self.fc2_gru.parameters():
332
+ param.requires_grad = False
333
+
334
+ def unfreeze_processing_part(self):
335
+ for param in self.bn_before_gru.parameters():
336
+ param.requires_grad = True
337
+ for param in self.gru.parameters():
338
+ param.requires_grad = True
339
+ for param in self.fc1_gru.parameters():
340
+ param.requires_grad = True
341
+ for param in self.fc2_gru.parameters():
342
+ param.requires_grad = True
343
+
344
+ def freeze_residual_part(self):
345
+ for param in self.Sinc_conv.parameters():
346
+ param.requires_grad = False
347
+ for param in self.first_bn.parameters():
348
+ param.requires_grad = False
349
+ for param in self.block0.parameters():
350
+ param.requires_grad = False
351
+ for param in self.block1.parameters():
352
+ param.requires_grad = False
353
+ for param in self.block2.parameters():
354
+ param.requires_grad = False
355
+ for param in self.block3.parameters():
356
+ param.requires_grad = False
357
+ for param in self.block4.parameters():
358
+ param.requires_grad = False
359
+ for param in self.block5.parameters():
360
+ param.requires_grad = False
361
+ for param in self.fc_attention0.parameters():
362
+ param.requires_grad = False
363
+ for param in self.fc_attention1.parameters():
364
+ param.requires_grad = False
365
+ for param in self.fc_attention2.parameters():
366
+ param.requires_grad = False
367
+ for param in self.fc_attention3.parameters():
368
+ param.requires_grad = False
369
+ for param in self.fc_attention4.parameters():
370
+ param.requires_grad = False
371
+ for param in self.fc_attention5.parameters():
372
+ param.requires_grad = False
373
+
374
+ def unfreeze_residual_part(self):
375
+ for param in self.Sinc_conv.parameters():
376
+ param.requires_grad = True
377
+ for param in self.first_bn.parameters():
378
+ param.requires_grad = True
379
+ for param in self.block0.parameters():
380
+ param.requires_grad = True
381
+ for param in self.block1.parameters():
382
+ param.requires_grad = True
383
+ for param in self.block2.parameters():
384
+ param.requires_grad = True
385
+ for param in self.block3.parameters():
386
+ param.requires_grad = True
387
+ for param in self.block4.parameters():
388
+ param.requires_grad = True
389
+ for param in self.block5.parameters():
390
+ param.requires_grad = True
391
+ for param in self.fc_attention0.parameters():
392
+ param.requires_grad = True
393
+ for param in self.fc_attention1.parameters():
394
+ param.requires_grad = True
395
+ for param in self.fc_attention2.parameters():
396
+ param.requires_grad = True
397
+ for param in self.fc_attention3.parameters():
398
+ param.requires_grad = True
399
+ for param in self.fc_attention4.parameters():
400
+ param.requires_grad = True
401
+ for param in self.fc_attention5.parameters():
402
+ param.requires_grad = True
403
+
404
+
405
+ def get_embeddings(self, x):
406
+ nb_samp = x.shape[0]
407
+ len_seq = x.shape[1]
408
+ x = x.view(nb_samp, 1, len_seq)
409
+
410
+ x = self.Sinc_conv(x)
411
+ x = F.max_pool1d(torch.abs(x), 3)
412
+ x = self.first_bn(x)
413
+ x = self.selu(x)
414
+
415
+ x0 = self.block0(x)
416
+ y0 = self.avgpool(x0).view(x0.size(0), -1)
417
+ y0 = self.fc_attention0(y0)
418
+ y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1)
419
+ x = x0 * y0 + y0
420
+
421
+ x1 = self.block1(x)
422
+ y1 = self.avgpool(x1).view(x1.size(0), -1)
423
+ y1 = self.fc_attention1(y1)
424
+ y1 = self.sig(y1).view(y1.size(0), y1.size(1), -1)
425
+ x = x1 * y1 + y1
426
+
427
+ x2 = self.block2(x)
428
+ y2 = self.avgpool(x2).view(x2.size(0), -1)
429
+ y2 = self.fc_attention2(y2)
430
+ y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1)
431
+ x = x2 * y2 + y2
432
+
433
+ x3 = self.block3(x)
434
+ y3 = self.avgpool(x3).view(x3.size(0), -1)
435
+ y3 = self.fc_attention3(y3)
436
+ y3 = self.sig(y3).view(y3.size(0), y3.size(1), -1)
437
+ x = x3 * y3 + y3
438
+
439
+ x4 = self.block4(x)
440
+ y4 = self.avgpool(x4).view(x4.size(0), -1)
441
+ y4 = self.fc_attention4(y4)
442
+ y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1)
443
+ x = x4 * y4 + y4
444
+
445
+ x5 = self.block5(x)
446
+ y5 = self.avgpool(x5).view(x5.size(0), -1)
447
+ y5 = self.fc_attention5(y5)
448
+ y5 = self.sig(y5).view(y5.size(0), y5.size(1), -1)
449
+ x = x5 * y5 + y5
450
+
451
+ x = self.bn_before_gru(x)
452
+ x = self.selu(x)
453
+ x = x.permute(0, 2, 1) # (batch, filt, time) >> (batch, time, filt)
454
+ self.gru.flatten_parameters()
455
+ x, _ = self.gru(x)
456
+
457
+ embeddings = x[:, -1, :] # Extract the embeddings from the GRU output
458
+
459
+ return embeddings
460
+
461
+
462
+ def _make_attention_fc(self, in_features, l_out_features):
463
+
464
+ l_fc = []
465
+
466
+ l_fc.append(nn.Linear(in_features=in_features,
467
+ out_features=l_out_features))
468
+
469
+ return nn.Sequential(*l_fc)
470
+
471
    def _make_layer(self, nb_blocks, nb_filts, first=False):
        """Stack nb_blocks Residual_blocks into a Sequential.

        Only the first block may receive first=True; after it, the input
        channel count is updated so subsequent blocks are channel-matched.

        NOTE(review): this mutates the caller's nb_filts list IN PLACE
        (nb_filts[0] = nb_filts[1] after the first block) — confirm callers
        expect that side effect.

        :param nb_blocks: number of residual blocks to stack
        :param nb_filts: [in_channels, out_channels]; element 0 is updated in place
        :param first: forwarded to the first block only
        :return: nn.Sequential of Residual_block modules
        """
        layers = []
        # def __init__(self, nb_filts, first = False):
        for i in range(nb_blocks):
            first = first if i == 0 else False
            layers.append(Residual_block(nb_filts=nb_filts,
                                         first=first))
            if i == 0: nb_filts[0] = nb_filts[1]

        return nn.Sequential(*layers)
481
+
482
+ def summary(self, input_size, batch_size=-1, device="cuda", print_fn=None):
483
+ if print_fn == None: printfn = print
484
+ model = self
485
+
486
+ def register_hook(module):
487
+ def hook(module, input, output):
488
+ class_name = str(module.__class__).split(".")[-1].split("'")[0]
489
+ module_idx = len(summary)
490
+
491
+ m_key = "%s-%i" % (class_name, module_idx + 1)
492
+ summary[m_key] = OrderedDict()
493
+ summary[m_key]["input_shape"] = list(input[0].size())
494
+ summary[m_key]["input_shape"][0] = batch_size
495
+ if isinstance(output, (list, tuple)):
496
+ summary[m_key]["output_shape"] = [
497
+ [-1] + list(o.size())[1:] for o in output
498
+ ]
499
+ else:
500
+ summary[m_key]["output_shape"] = list(output.size())
501
+ if len(summary[m_key]["output_shape"]) != 0:
502
+ summary[m_key]["output_shape"][0] = batch_size
503
+
504
+ params = 0
505
+ if hasattr(module, "weight") and hasattr(module.weight, "size"):
506
+ params += torch.prod(torch.LongTensor(list(module.weight.size())))
507
+ summary[m_key]["trainable"] = module.weight.requires_grad
508
+ if hasattr(module, "bias") and hasattr(module.bias, "size"):
509
+ params += torch.prod(torch.LongTensor(list(module.bias.size())))
510
+ summary[m_key]["nb_params"] = params
511
+
512
+ if (
513
+ not isinstance(module, nn.Sequential)
514
+ and not isinstance(module, nn.ModuleList)
515
+ and not (module == model)
516
+ ):
517
+ hooks.append(module.register_forward_hook(hook))
518
+
519
+ device = device.lower()
520
+ assert device in [
521
+ "cuda",
522
+ "cpu",
523
+ ], "Input device is not valid, please specify 'cuda' or 'cpu'"
524
+
525
+ if device == "cuda" and torch.cuda.is_available():
526
+ dtype = torch.cuda.FloatTensor
527
+ else:
528
+ dtype = torch.FloatTensor
529
+ if isinstance(input_size, tuple):
530
+ input_size = [input_size]
531
+ x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
532
+ summary = OrderedDict()
533
+ hooks = []
534
+ model.apply(register_hook)
535
+ model(*x)
536
+ for h in hooks:
537
+ h.remove()
538
+
539
+ print_fn("----------------------------------------------------------------")
540
+ line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
541
+ print_fn(line_new)
542
+ print_fn("================================================================")
543
+ total_params = 0
544
+ total_output = 0
545
+ trainable_params = 0
546
+ for layer in summary:
547
+ # input_shape, output_shape, trainable, nb_params
548
+ line_new = "{:>20} {:>25} {:>15}".format(
549
+ layer,
550
+ str(summary[layer]["output_shape"]),
551
+ "{0:,}".format(summary[layer]["nb_params"]),
552
+ )
553
+ total_params += summary[layer]["nb_params"]
554
+ total_output += np.prod(summary[layer]["output_shape"])
555
+ if "trainable" in summary[layer]:
556
+ if summary[layer]["trainable"] == True:
557
+ trainable_params += summary[layer]["nb_params"]
558
+ print_fn(line_new)
src/utils.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import GPUtil
5
+ import yaml
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from sklearn.metrics import roc_curve, auc, confusion_matrix
9
+ import pandas as pd
10
+ import torch.nn as nn
11
+
12
+
13
def set_gpu(id=-1):
    """
    Select the CUDA device exposed through CUDA_VISIBLE_DEVICES.

    Pass None for CPU-only, -1 to auto-pick the GPU with the most free
    memory (via GPUtil), or an explicit GPU index.

    :param id: if specified, corresponds to the GPU index desired.
    :return: chosen GPU index (nothing is returned on the CPU-only path)
    """
    if id is None:
        # CPU only
        print('GPU not selected')
        os.environ["CUDA_VISIBLE_DEVICES"] = str(-1)
    else:
        # -1 means: automatically pick the GPU with the most free memory
        chosen = GPUtil.getFirstAvailable(order='memory')[0] if id == -1 else id
        try:
            gpu_name = GPUtil.getGPUs()[chosen].name
        except IndexError:
            print('The selected GPU does not exist. Switching to the most available one.')
            chosen = GPUtil.getFirstAvailable(order='memory')[0]
            gpu_name = GPUtil.getGPUs()[chosen].name
        print('GPU selected: %d - %s' % (chosen, gpu_name))
        os.environ["CUDA_VISIBLE_DEVICES"] = str(chosen)
        return chosen
35
+
36
+
37
def prepare_asvspoof_data(config):
    """Convert ASVspoof 2019 LA / 2021 DF protocol files into CSVs of
    (path, label) pairs, with labels mapped bonafide -> 0, spoof -> 1.

    Writes three CSVs: train (2019 LA train), dev (2019 LA dev) and eval
    (2021 DF eval), to the paths given in config['df_train_path'],
    config['df_dev_path'] and config['df_eval_path'].

    NOTE(review): dataset locations are hard-coded to a specific NAS mount
    (/nas/public/dataset/...) — this only runs on that cluster.

    :param config: dict with keys df_train_path, df_dev_path, df_eval_path
    """

    data_dir_2019 = '/nas/public/dataset/asvspoof2019/LA/ASVspoof2019_LA_cm_protocols'
    data_eval_2021 = '/nas/public/dataset/asvspoof2021/DF_cm_eval_labels.txt'
    files = [os.path.join(data_dir_2019, 'ASVspoof2019.LA.cm.train.trn.txt'),
             os.path.join(data_dir_2019, 'ASVspoof2019.LA.cm.dev.trl.txt'), data_eval_2021]

    audio_dir_2019 = '/nas/public/dataset/asvspoof2019/LA'
    audio_dir_2021 = '/nas/public/dataset/asvspoof2021/ASVspoof2021_DF_eval/flac/'
    set_dirs = [os.path.join(audio_dir_2019, 'ASVspoof2019_LA_train/flac/'),
                os.path.join(audio_dir_2019, 'ASVspoof2019_LA_dev/flac/'), audio_dir_2021]

    save_paths = [config['df_train_path'], config['df_dev_path'], config['df_eval_path']]

    for file_path, set_dir, save_path in zip(files, set_dirs, save_paths):

        # protocol files are space-separated with no header
        txt_file = pd.read_csv(file_path, sep=' ', header=None)
        txt_file = txt_file.replace({'bonafide': 0, 'spoof': 1})

        # column 1 holds the utterance id; turn it into a full .flac path
        txt_file.iloc[:,1] = set_dir + txt_file.iloc[:,1].astype(str) + '.flac'

        # the 2021 eval protocol carries its label in column 5, 2019 in column 4
        if not file_path == data_eval_2021:
            df = txt_file[[1, 4]]
            df = df.rename({1: 'path', 4: 'label'}, axis='columns')
        else:
            df = txt_file[[1, 5]]
            df = df.rename({1: 'path', 5: 'label'}, axis='columns')

        df.to_csv(save_path)
66
+
67
+
68
def init_weights(module):
    """Xavier-initialise Linear layers (uniform weights, bias = 0.01).

    Intended to be passed to nn.Module.apply(); modules that are not
    nn.Linear are left untouched.

    :param module: any submodule visited by apply()
    """
    if not isinstance(module, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(module.weight)
    module.bias.data.fill_(0.01)
72
+
73
+
74
def read_yaml(config_path):
    """
    Read YAML file.

    :param config_path: path to the YAML config file.
    :type config_path: str
    :return: dictionary correspondent to YAML content
    :rtype dict
    """
    with open(config_path, 'r') as yaml_file:
        return yaml.safe_load(yaml_file)
86
+
87
+
88
def sigmoid(x, factor=1):
    """
    Compute the logistic sigmoid 1 / (1 + exp(-factor * x)).

    :param x: input signal
    :param factor: slope parameter
    :return: sigmoid(x)
    :rtype np.array
    """
    return 1.0 / (1.0 + np.exp(-factor * x))
99
+
100
+
101
def plot_roc_curve(labels, pred, legend=None):
    """
    Plot ROC curve.

    Prints TPR@low-FPR, AUC, EER and the Youden-optimal threshold, draws the
    ROC curve on the current matplotlib figure, and returns the threshold
    maximising (TPR - FPR).

    :param labels: groundtruth labels
    :type labels: list
    :param pred: predicted score
    :type pred: list
    :param legend: if True, add legend to the plot
    :type legend: bool
    :return: threshold maximising TPR - FPR
    """
    # labels and pred bust be given in (N, ) shape

    def tpr5(y_true, y_pred):
        # NOTE(review): despite the name "tpr5", this reports TPR at the first
        # FPR >= 0.1 (10%), not 5% — confirm which operating point is intended.
        # Sorting fpr/tpr is a no-op since roc_curve returns them non-decreasing.
        fpr, tpr, thr = roc_curve(y_true, y_pred)
        fp_sort = sorted(fpr)
        tp_sort = sorted(tpr)
        tpr_ind = [i for (i, val) in enumerate(fp_sort) if val >= 0.1][0]
        tpr01 = tp_sort[tpr_ind]
        return tpr01

    lw = 3  # line width


    fpr, tpr, thres = roc_curve(labels, pred)
    rocauc = auc(fpr, tpr)
    fnr = 1 - tpr
    # EER approximated as the FPR where |FNR - FPR| is smallest
    eer = fpr[np.nanargmin(np.absolute(fnr - fpr))]
    # Youden's J statistic: threshold maximising TPR - FPR
    optimal_index = np.argmax(tpr - fpr)
    optimal_threshold = thres[optimal_index]

    print('TPR5 = {:.3f}'.format(tpr5(labels, pred)))
    print('AUC = {:.3f}'.format(rocauc))
    print('EER = {:.3f}'.format(eer))
    print('Best Thres. = {:.3f}'.format(optimal_threshold))
    print()
    if legend:
        plt.plot(fpr, tpr, lw=lw, label='$\mathrm{' + legend + ' - AUC = %0.2f}$' % rocauc)
    else:
        plt.plot(fpr, tpr, lw=lw, label='$\mathrm{AUC = %0.2f}$' % rocauc)
    plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')  # chance diagonal
    plt.xlim([-0.02, 1.0])
    plt.ylim([0.0, 1.03])
    plt.xlabel(r'$\mathrm{False\;Positive\;Rate}$', fontsize=18)
    plt.ylabel(r'$\mathrm{True\;Positive\;Rate}$', fontsize=18)
    plt.legend(loc="lower right", fontsize=15)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    plt.grid(True)
    # plt.show()

    return optimal_threshold
154
+
155
def plot_confusion_matrix(y_true, y_pred, normalize=False, cmap=plt.cm.Blues):
    """
    Plot confusion matrix.

    Classes are fixed to Real/Fake (binary deepfake-detection task).

    :param y_true: ground-truth labels
    :type y_true: list
    :param y_pred: predicted labels
    :type y_pred: list
    :param normalize: if set to True, normalise the confusion matrix.
    :type normalize: bool
    :param cmap: matplotlib cmap to be used for plot
    :type cmap:
    :return: the matplotlib Axes holding the plot
    """
    cm = confusion_matrix(y_true, y_pred)
    # Only use the labels that appear in the data
    # classes = classes[unique_labels(y_true, y_pred)]
    classes = ['$\it{Real}$','$\it{Fake}$']
    if normalize:
        # row-normalise: each true-class row sums to 1
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    print(cm)

    fsize = 25  # fontsize
    fig, ax = plt.subplots()
    # NOTE(review): clim=(0,1) only makes sense when normalize=True; with raw
    # counts the colour map saturates — confirm intent.
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap, clim=(0,1))
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.tick_params(labelsize=fsize)
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           )
    # NOTE(review): sklearn's confusion_matrix puts TRUE labels on rows
    # (imshow y-axis) and PREDICTED on columns, so these axis labels appear
    # swapped — confirm against how the figure is read downstream.
    ax.set_xlabel('$\mathrm{True\;label}$', fontsize=fsize)
    ax.set_ylabel('$\mathrm{Predicted\;label}$', fontsize=fsize)
    ax.set_xticklabels(classes, fontsize=fsize)
    ax.set_yticklabels(classes, fontsize=fsize)
    # Rotate the tick labels and set their alignment.
    # plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
    #          rotation_mode="anchor")
    # Loop over data dimensions and create text annotations.
    fmt = '.3f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            # white text on dark cells, black on light ones
            ax.text(j, i, format('$\mathrm{' + str(format(cm[i, j], fmt)) + '}$'),
                    ha="center", va="center",
                    fontsize=fsize,
                    color="white" if np.array(cm[i, j]) > thresh else "black")
    fig.tight_layout()
    # plt.show()

    return ax
205
+
206
+
207
def reconstruct_from_pred(pred_array, win_len, hop_size, fs=16000):
    """
    Expand window-level predictions back to a per-sample score array.

    Each prediction is painted over the samples its window covered; where
    windows overlap, the non-zero contributions are averaged.

    :param pred_array: aggregated prediction array
    :type pred_array: list
    :param win_len: length of the window used for aggregation
    :type win_len: int
    :param hop_size: length of the hop used for aggregation
    :type hop_size: int
    :param fs: sampling frequency
    :type fs: int
    :return: reconstructed array
    """

    preds = np.array(pred_array)
    n_samples = int((len(preds) - 1) * hop_size * fs + win_len * fs)

    # one row per window, painted over the samples that window covered
    canvas = np.zeros((len(preds), n_samples))
    for idx, score in enumerate(preds):
        start = int(idx * hop_size * fs)
        stop = int((idx * hop_size + win_len) * fs)
        canvas[idx, start:stop] = score

    # zeros mark samples a window did not cover; treat them as NaN so the
    # mean only averages actual predictions (a literal 0 score is skipped too)
    return np.nanmean(np.where(canvas != 0, canvas, np.nan), 0)
233
+
234
+
235
def seed_everything(seed: int):
    """
    Seed every RNG used by the project (random, NumPy, PyTorch CPU and CUDA)
    and configure cuDNN for reproducible runs.

    :param seed: seed value
    :type seed: int
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seed every visible GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    # BUG FIX: benchmark must be False for reproducibility — with True,
    # cuDNN auto-tunes and may select non-deterministic algorithms, which
    # defeats deterministic=True above (see PyTorch reproducibility notes).
    torch.backends.cudnn.benchmark = False