MeysamSh commited on
Commit
9a5139a
·
verified ·
1 Parent(s): 80b2409

Upload folder using huggingface_hub

Browse files
evaluation/AASIST/.ipynb_checkpoints/AASIST_util-checkpoint.py ADDED
@@ -0,0 +1,1038 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AASIST
3
+ Copyright (c) 2021-present NAVER Corp.
4
+ MIT license
5
+ """
6
+
7
+ import random
8
+ from typing import Union
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+ import sys
15
+ import os
16
+ import argparse
17
+ import torch.optim as optim
18
+ import torchaudio
19
+ from torch.utils.data import Dataset, DataLoader
20
+ from tqdm import tqdm
21
+ import torchaudio.transforms as T
22
+ from collections import defaultdict
23
+ import torch.multiprocessing
24
+
25
+ torch.multiprocessing.set_sharing_strategy('file_system')
26
+
27
+
28
+ def extract_system_id(wavname):
29
+ """Extrait l'identifiant du système à partir du nom du fichier."""
30
+ return wavname.split('-')[0]
31
+
32
+ def pad(x, max_len=64600):
33
+ """ Padding ou découpage d'un signal audio """
34
+ x_len = x.shape[0]
35
+ if x_len >= max_len:
36
+ return x[:max_len]
37
+ num_repeats = int(max_len / x_len) + 1
38
+ padded_x = np.tile(x, (num_repeats))[:max_len]
39
+ return padded_x
40
+
41
+
42
+ def pad_random(x: np.ndarray, max_len: int = 64600):
43
+ """ Découpe aléatoire si trop long, padding si trop court """
44
+ x_len = x.shape[0]
45
+ if x_len >= max_len:
46
+ stt = np.random.randint(x_len - max_len)
47
+ return x[stt:stt + max_len]
48
+ num_repeats = int(max_len / x_len) + 1
49
+ padded_x = np.tile(x, (num_repeats))[:max_len]
50
+ return padded_x
51
+ # ==========================================================
52
+ # Chargement des données (Dataset)
53
+ # ==========================================================
54
+
55
+
56
+ class MyDataset(Dataset):
57
+ def __init__(self, wavdir, mos_list="", target_sample_rate=16000):
58
+ self.mos_lookup = {}
59
+ if mos_list:
60
+ with open(mos_list, 'r') as f:
61
+ for line in f:
62
+ parts = line.strip().split(',')
63
+ wavname = parts[0]
64
+ mos = float(parts[1])
65
+ self.mos_lookup[wavname] = mos
66
+
67
+ self.wavdir = wavdir
68
+ wavnames=os.listdir(self.wavdir)
69
+ self.wavnames = [f_name for f_name in wavnames if f_name.endswith(".wav")]
70
+ self.target_sample_rate = target_sample_rate
71
+
72
+ def __getitem__(self, idx):
73
+ wavname = self.wavnames[idx]
74
+ wavpath = os.path.join(self.wavdir, wavname)
75
+ wav, sample_rate = torchaudio.load(wavpath)
76
+
77
+ if sample_rate != self.target_sample_rate:
78
+ resampler = T.Resample(orig_freq=sample_rate, new_freq=self.target_sample_rate)
79
+ wav = resampler(wav)
80
+ if wavname in self.mos_lookup:
81
+ score = self.mos_lookup[wavname]
82
+ else:
83
+ score = 0 #TODO: it should be manage more properly
84
+ return wav, score, wavname
85
+
86
+ def __len__(self):
87
+ return len(self.wavnames)
88
+
89
+ def collate_fn(self, batch):
90
+ """ Padding et tronquage des séquences audio pour normaliser à 64600 frames """
91
+ wavs, scores, wavnames = zip(*batch)
92
+ max_len = 64600
93
+ output_wavs = []
94
+ for wav in wavs:
95
+
96
+ wav_np = wav.squeeze(0).cpu().numpy() # Enlève la dimension channel (1,) et met sur CPU
97
+ padded_wav = pad_random(wav_np, max_len)
98
+
99
+ padded_wav = torch.tensor(padded_wav, dtype=torch.float32).unsqueeze(0) # Remettre la dimension (1, time)
100
+
101
+ output_wavs.append(padded_wav)
102
+
103
+ output_wavs = torch.stack(output_wavs, dim=0) # [batch_size, 1, 64600]
104
+
105
+ scores = torch.tensor(scores, dtype=torch.float32)
106
+
107
+ return output_wavs, scores, wavnames
108
+
109
+
110
+
111
+ class GraphAttentionLayer(nn.Module):
112
+ def __init__(self, in_dim, out_dim, **kwargs):
113
+ super().__init__()
114
+
115
+ # attention map
116
+ self.att_proj = nn.Linear(in_dim, out_dim)
117
+ self.att_weight = self._init_new_params(out_dim, 1)
118
+
119
+ # project
120
+ self.proj_with_att = nn.Linear(in_dim, out_dim)
121
+ self.proj_without_att = nn.Linear(in_dim, out_dim)
122
+
123
+ # batch norm
124
+ self.bn = nn.BatchNorm1d(out_dim)
125
+
126
+ # dropout for inputs
127
+ self.input_drop = nn.Dropout(p=0.2)
128
+
129
+ # activate
130
+ self.act = nn.SELU(inplace=True)
131
+
132
+ # temperature
133
+ self.temp = 1.
134
+ if "temperature" in kwargs:
135
+ self.temp = kwargs["temperature"]
136
+
137
+ def forward(self, x):
138
+ '''
139
+ x :(#bs, #node, #dim)
140
+ '''
141
+ # apply input dropout
142
+ x = self.input_drop(x)
143
+
144
+ # derive attention map
145
+ att_map = self._derive_att_map(x)
146
+
147
+ # projection
148
+ x = self._project(x, att_map)
149
+
150
+ # apply batch norm
151
+ x = self._apply_BN(x)
152
+ x = self.act(x)
153
+ return x
154
+
155
+ def _pairwise_mul_nodes(self, x):
156
+ '''
157
+ Calculates pairwise multiplication of nodes.
158
+ - for attention map
159
+ x :(#bs, #node, #dim)
160
+ out_shape :(#bs, #node, #node, #dim)
161
+ '''
162
+
163
+ nb_nodes = x.size(1)
164
+ x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
165
+ x_mirror = x.transpose(1, 2)
166
+
167
+ return x * x_mirror
168
+
169
+ def _derive_att_map(self, x):
170
+ '''
171
+ x :(#bs, #node, #dim)
172
+ out_shape :(#bs, #node, #node, 1)
173
+ '''
174
+ att_map = self._pairwise_mul_nodes(x)
175
+ # size: (#bs, #node, #node, #dim_out)
176
+ att_map = torch.tanh(self.att_proj(att_map))
177
+ # size: (#bs, #node, #node, 1)
178
+ att_map = torch.matmul(att_map, self.att_weight)
179
+
180
+ # apply temperature
181
+ att_map = att_map / self.temp
182
+
183
+ att_map = F.softmax(att_map, dim=-2)
184
+
185
+ return att_map
186
+
187
+ def _project(self, x, att_map):
188
+ x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
189
+ x2 = self.proj_without_att(x)
190
+
191
+ return x1 + x2
192
+
193
+ def _apply_BN(self, x):
194
+ org_size = x.size()
195
+ x = x.view(-1, org_size[-1])
196
+ x = self.bn(x)
197
+ x = x.view(org_size)
198
+
199
+ return x
200
+
201
+ def _init_new_params(self, *size):
202
+ out = nn.Parameter(torch.FloatTensor(*size))
203
+ nn.init.xavier_normal_(out)
204
+ return out
205
+
206
+
207
+ class HtrgGraphAttentionLayer(nn.Module):
208
+ def __init__(self, in_dim, out_dim, **kwargs):
209
+ super().__init__()
210
+
211
+ self.proj_type1 = nn.Linear(in_dim, in_dim)
212
+ self.proj_type2 = nn.Linear(in_dim, in_dim)
213
+
214
+ # attention map
215
+ self.att_proj = nn.Linear(in_dim, out_dim)
216
+ self.att_projM = nn.Linear(in_dim, out_dim)
217
+
218
+ self.att_weight11 = self._init_new_params(out_dim, 1)
219
+ self.att_weight22 = self._init_new_params(out_dim, 1)
220
+ self.att_weight12 = self._init_new_params(out_dim, 1)
221
+ self.att_weightM = self._init_new_params(out_dim, 1)
222
+
223
+ # project
224
+ self.proj_with_att = nn.Linear(in_dim, out_dim)
225
+ self.proj_without_att = nn.Linear(in_dim, out_dim)
226
+
227
+ self.proj_with_attM = nn.Linear(in_dim, out_dim)
228
+ self.proj_without_attM = nn.Linear(in_dim, out_dim)
229
+
230
+ # batch norm
231
+ self.bn = nn.BatchNorm1d(out_dim)
232
+
233
+ # dropout for inputs
234
+ self.input_drop = nn.Dropout(p=0.2)
235
+
236
+ # activate
237
+ self.act = nn.SELU(inplace=True)
238
+
239
+ # temperature
240
+ self.temp = 1.
241
+ if "temperature" in kwargs:
242
+ self.temp = kwargs["temperature"]
243
+
244
+ def forward(self, x1, x2, master=None):
245
+ '''
246
+ x1 :(#bs, #node, #dim)
247
+ x2 :(#bs, #node, #dim)
248
+ '''
249
+ num_type1 = x1.size(1)
250
+ num_type2 = x2.size(1)
251
+
252
+ x1 = self.proj_type1(x1)
253
+ x2 = self.proj_type2(x2)
254
+
255
+ x = torch.cat([x1, x2], dim=1)
256
+
257
+ if master is None:
258
+ master = torch.mean(x, dim=1, keepdim=True)
259
+
260
+ # apply input dropout
261
+ x = self.input_drop(x)
262
+
263
+ # derive attention map
264
+ att_map = self._derive_att_map(x, num_type1, num_type2)
265
+
266
+ # directional edge for master node
267
+ master = self._update_master(x, master)
268
+
269
+ # projection
270
+ x = self._project(x, att_map)
271
+
272
+ # apply batch norm
273
+ x = self._apply_BN(x)
274
+ x = self.act(x)
275
+
276
+ x1 = x.narrow(1, 0, num_type1)
277
+ x2 = x.narrow(1, num_type1, num_type2)
278
+
279
+ return x1, x2, master
280
+
281
+ def _update_master(self, x, master):
282
+
283
+ att_map = self._derive_att_map_master(x, master)
284
+ master = self._project_master(x, master, att_map)
285
+
286
+ return master
287
+
288
+ def _pairwise_mul_nodes(self, x):
289
+ '''
290
+ Calculates pairwise multiplication of nodes.
291
+ - for attention map
292
+ x :(#bs, #node, #dim)
293
+ out_shape :(#bs, #node, #node, #dim)
294
+ '''
295
+
296
+ nb_nodes = x.size(1)
297
+ x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
298
+ x_mirror = x.transpose(1, 2)
299
+
300
+ return x * x_mirror
301
+
302
+ def _derive_att_map_master(self, x, master):
303
+ '''
304
+ x :(#bs, #node, #dim)
305
+ out_shape :(#bs, #node, #node, 1)
306
+ '''
307
+ att_map = x * master
308
+ att_map = torch.tanh(self.att_projM(att_map))
309
+
310
+ att_map = torch.matmul(att_map, self.att_weightM)
311
+
312
+ # apply temperature
313
+ att_map = att_map / self.temp
314
+
315
+ att_map = F.softmax(att_map, dim=-2)
316
+
317
+ return att_map
318
+
319
+ def _derive_att_map(self, x, num_type1, num_type2):
320
+ '''
321
+ x :(#bs, #node, #dim)
322
+ out_shape :(#bs, #node, #node, 1)
323
+ '''
324
+ att_map = self._pairwise_mul_nodes(x)
325
+ # size: (#bs, #node, #node, #dim_out)
326
+ att_map = torch.tanh(self.att_proj(att_map))
327
+ # size: (#bs, #node, #node, 1)
328
+
329
+ att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
330
+
331
+ att_board[:, :num_type1, :num_type1, :] = torch.matmul(
332
+ att_map[:, :num_type1, :num_type1, :], self.att_weight11)
333
+ att_board[:, num_type1:, num_type1:, :] = torch.matmul(
334
+ att_map[:, num_type1:, num_type1:, :], self.att_weight22)
335
+ att_board[:, :num_type1, num_type1:, :] = torch.matmul(
336
+ att_map[:, :num_type1, num_type1:, :], self.att_weight12)
337
+ att_board[:, num_type1:, :num_type1, :] = torch.matmul(
338
+ att_map[:, num_type1:, :num_type1, :], self.att_weight12)
339
+
340
+ att_map = att_board
341
+
342
+ # att_map = torch.matmul(att_map, self.att_weight12)
343
+
344
+ # apply temperature
345
+ att_map = att_map / self.temp
346
+
347
+ att_map = F.softmax(att_map, dim=-2)
348
+
349
+ return att_map
350
+
351
+ def _project(self, x, att_map):
352
+ x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
353
+ x2 = self.proj_without_att(x)
354
+
355
+ return x1 + x2
356
+
357
+ def _project_master(self, x, master, att_map):
358
+
359
+ x1 = self.proj_with_attM(torch.matmul(
360
+ att_map.squeeze(-1).unsqueeze(1), x))
361
+ x2 = self.proj_without_attM(master)
362
+
363
+ return x1 + x2
364
+
365
+ def _apply_BN(self, x):
366
+ org_size = x.size()
367
+ x = x.view(-1, org_size[-1])
368
+ x = self.bn(x)
369
+ x = x.view(org_size)
370
+
371
+ return x
372
+
373
+ def _init_new_params(self, *size):
374
+ out = nn.Parameter(torch.FloatTensor(*size))
375
+ nn.init.xavier_normal_(out)
376
+ return out
377
+
378
+
379
+ class GraphPool(nn.Module):
380
+ def __init__(self, k: float, in_dim: int, p: Union[float, int]):
381
+ super().__init__()
382
+ self.k = k
383
+ self.sigmoid = nn.Sigmoid()
384
+ self.proj = nn.Linear(in_dim, 1)
385
+ self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
386
+ self.in_dim = in_dim
387
+
388
+ def forward(self, h):
389
+ Z = self.drop(h)
390
+ weights = self.proj(Z)
391
+ scores = self.sigmoid(weights)
392
+ new_h = self.top_k_graph(scores, h, self.k)
393
+
394
+ return new_h
395
+
396
+ def top_k_graph(self, scores, h, k):
397
+ """
398
+ args
399
+ =====
400
+ scores: attention-based weights (#bs, #node, 1)
401
+ h: graph data (#bs, #node, #dim)
402
+ k: ratio of remaining nodes, (float)
403
+
404
+ returns
405
+ =====
406
+ h: graph pool applied data (#bs, #node', #dim)
407
+ """
408
+ _, n_nodes, n_feat = h.size()
409
+ n_nodes = max(int(n_nodes * k), 1)
410
+ _, idx = torch.topk(scores, n_nodes, dim=1)
411
+ idx = idx.expand(-1, -1, n_feat)
412
+
413
+ h = h * scores
414
+ h = torch.gather(h, 1, idx)
415
+
416
+ return h
417
+
418
+
419
+ class CONV(nn.Module):
420
+ @staticmethod
421
+ def to_mel(hz):
422
+ return 2595 * np.log10(1 + hz / 700)
423
+
424
+ @staticmethod
425
+ def to_hz(mel):
426
+ return 700 * (10**(mel / 2595) - 1)
427
+
428
+ def __init__(self,
429
+ out_channels,
430
+ kernel_size,
431
+ sample_rate=16000,
432
+ in_channels=1,
433
+ stride=1,
434
+ padding=0,
435
+ dilation=1,
436
+ bias=False,
437
+ groups=1,
438
+ mask=False):
439
+ super().__init__()
440
+ if in_channels != 1:
441
+
442
+ msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
443
+ in_channels)
444
+ raise ValueError(msg)
445
+ self.out_channels = out_channels
446
+ self.kernel_size = kernel_size
447
+ self.sample_rate = sample_rate
448
+
449
+ # Forcing the filters to be odd (i.e, perfectly symmetrics)
450
+ if kernel_size % 2 == 0:
451
+ self.kernel_size = self.kernel_size + 1
452
+ self.stride = stride
453
+ self.padding = padding
454
+ self.dilation = dilation
455
+ self.mask = mask
456
+ if bias:
457
+ raise ValueError('SincConv does not support bias.')
458
+ if groups > 1:
459
+ raise ValueError('SincConv does not support groups.')
460
+
461
+ NFFT = 512
462
+ f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
463
+ fmel = self.to_mel(f)
464
+ fmelmax = np.max(fmel)
465
+ fmelmin = np.min(fmel)
466
+ filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
467
+ filbandwidthsf = self.to_hz(filbandwidthsmel)
468
+
469
+ self.mel = filbandwidthsf
470
+ self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
471
+ (self.kernel_size - 1) / 2 + 1)
472
+ self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
473
+ for i in range(len(self.mel) - 1):
474
+ fmin = self.mel[i]
475
+ fmax = self.mel[i + 1]
476
+ hHigh = (2*fmax/self.sample_rate) * \
477
+ np.sinc(2*fmax*self.hsupp/self.sample_rate)
478
+ hLow = (2*fmin/self.sample_rate) * \
479
+ np.sinc(2*fmin*self.hsupp/self.sample_rate)
480
+ hideal = hHigh - hLow
481
+
482
+ self.band_pass[i, :] = Tensor(np.hamming(
483
+ self.kernel_size)) * Tensor(hideal)
484
+
485
+ def forward(self, x, mask=False):
486
+ band_pass_filter = self.band_pass.clone().to(x.device)
487
+ if mask:
488
+ A = np.random.uniform(0, 20)
489
+ A = int(A)
490
+ A0 = random.randint(0, band_pass_filter.shape[0] - A)
491
+ band_pass_filter[A0:A0 + A, :] = 0
492
+ else:
493
+ band_pass_filter = band_pass_filter
494
+
495
+ self.filters = (band_pass_filter).view(self.out_channels, 1,
496
+ self.kernel_size)
497
+
498
+ return F.conv1d(x,
499
+ self.filters,
500
+ stride=self.stride,
501
+ padding=self.padding,
502
+ dilation=self.dilation,
503
+ bias=None,
504
+ groups=1)
505
+
506
+
507
+ class Residual_block(nn.Module):
508
+ def __init__(self, nb_filts, first=False):
509
+ super().__init__()
510
+ self.first = first
511
+
512
+ if not self.first:
513
+ self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
514
+ self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
515
+ out_channels=nb_filts[1],
516
+ kernel_size=(2, 3),
517
+ padding=(1, 1),
518
+ stride=1)
519
+ self.selu = nn.SELU(inplace=True)
520
+
521
+ self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
522
+ self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
523
+ out_channels=nb_filts[1],
524
+ kernel_size=(2, 3),
525
+ padding=(0, 1),
526
+ stride=1)
527
+
528
+ if nb_filts[0] != nb_filts[1]:
529
+ self.downsample = True
530
+ self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
531
+ out_channels=nb_filts[1],
532
+ padding=(0, 1),
533
+ kernel_size=(1, 3),
534
+ stride=1)
535
+
536
+ else:
537
+ self.downsample = False
538
+ self.mp = nn.MaxPool2d((1, 3)) # self.mp = nn.MaxPool2d((1,4))
539
+
540
+ def forward(self, x):
541
+ identity = x
542
+ if not self.first:
543
+ out = self.bn1(x)
544
+ out = self.selu(out)
545
+ else:
546
+ out = x
547
+ out = self.conv1(x)
548
+
549
+ # print('out',out.shape)
550
+ out = self.bn2(out)
551
+ out = self.selu(out)
552
+ # print('out',out.shape)
553
+ out = self.conv2(out)
554
+ #print('conv2 out',out.shape)
555
+ if self.downsample:
556
+ identity = self.conv_downsample(identity)
557
+
558
+ out += identity
559
+ out = self.mp(out)
560
+ return out
561
+
562
+
563
+ class Model(nn.Module):
564
+ def __init__(self, d_args):
565
+ super().__init__()
566
+
567
+ self.d_args = d_args
568
+ filts = d_args["filts"]
569
+ gat_dims = d_args["gat_dims"]
570
+ pool_ratios = d_args["pool_ratios"]
571
+ temperatures = d_args["temperatures"]
572
+
573
+ self.conv_time = CONV(out_channels=filts[0],
574
+ kernel_size=d_args["first_conv"],
575
+ in_channels=1)
576
+ self.first_bn = nn.BatchNorm2d(num_features=1)
577
+
578
+ self.drop = nn.Dropout(0.5, inplace=True)
579
+ self.drop_way = nn.Dropout(0.2, inplace=True)
580
+ self.selu = nn.SELU(inplace=True)
581
+
582
+ self.encoder = nn.Sequential(
583
+ nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
584
+ nn.Sequential(Residual_block(nb_filts=filts[2])),
585
+ nn.Sequential(Residual_block(nb_filts=filts[3])),
586
+ nn.Sequential(Residual_block(nb_filts=filts[4])),
587
+ nn.Sequential(Residual_block(nb_filts=filts[4])),
588
+ nn.Sequential(Residual_block(nb_filts=filts[4])))
589
+
590
+ self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
591
+ self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
592
+ self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
593
+
594
+ self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
595
+ gat_dims[0],
596
+ temperature=temperatures[0])
597
+ self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
598
+ gat_dims[0],
599
+ temperature=temperatures[1])
600
+
601
+ self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
602
+ gat_dims[0], gat_dims[1], temperature=temperatures[2])
603
+ self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
604
+ gat_dims[1], gat_dims[1], temperature=temperatures[2])
605
+
606
+ self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
607
+ gat_dims[0], gat_dims[1], temperature=temperatures[2])
608
+
609
+ self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
610
+ gat_dims[1], gat_dims[1], temperature=temperatures[2])
611
+
612
+ self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
613
+ self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
614
+ self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
615
+ self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
616
+
617
+ self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
618
+ self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
619
+
620
+ if "output_cls" in d_args:
621
+ self.out_layer = nn.Linear(5 * gat_dims[1], d_args["output_cls"])
622
+ else:
623
+ self.out_layer = nn.Linear(5 * gat_dims[1], 2)
624
+
625
+ def forward(self, x, Freq_aug=False):
626
+
627
+ x = x.unsqueeze(1)
628
+ x = self.conv_time(x, mask=Freq_aug)
629
+ x = x.unsqueeze(dim=1)
630
+ x = F.max_pool2d(torch.abs(x), (3, 3))
631
+ x = self.first_bn(x)
632
+ x = self.selu(x)
633
+
634
+ # get embeddings using encoder
635
+ # (#bs, #filt, #spec, #seq)
636
+ e = self.encoder(x)
637
+
638
+ # spectral GAT (GAT-S)
639
+ e_S, _ = torch.max(torch.abs(e), dim=3) # max along time
640
+ e_S = e_S.transpose(1, 2) + self.pos_S
641
+
642
+ gat_S = self.GAT_layer_S(e_S)
643
+ out_S = self.pool_S(gat_S) # (#bs, #node, #dim)
644
+
645
+ # temporal GAT (GAT-T)
646
+ e_T, _ = torch.max(torch.abs(e), dim=2) # max along freq
647
+ e_T = e_T.transpose(1, 2)
648
+
649
+ gat_T = self.GAT_layer_T(e_T)
650
+ out_T = self.pool_T(gat_T)
651
+
652
+ # learnable master node
653
+ master1 = self.master1.expand(x.size(0), -1, -1)
654
+ master2 = self.master2.expand(x.size(0), -1, -1)
655
+
656
+ # inference 1
657
+ out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
658
+ out_T, out_S, master=self.master1)
659
+
660
+ out_S1 = self.pool_hS1(out_S1)
661
+ out_T1 = self.pool_hT1(out_T1)
662
+
663
+ out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
664
+ out_T1, out_S1, master=master1)
665
+ out_T1 = out_T1 + out_T_aug
666
+ out_S1 = out_S1 + out_S_aug
667
+ master1 = master1 + master_aug
668
+
669
+ # inference 2
670
+ out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
671
+ out_T, out_S, master=self.master2)
672
+ out_S2 = self.pool_hS2(out_S2)
673
+ out_T2 = self.pool_hT2(out_T2)
674
+
675
+ out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
676
+ out_T2, out_S2, master=master2)
677
+ out_T2 = out_T2 + out_T_aug
678
+ out_S2 = out_S2 + out_S_aug
679
+ master2 = master2 + master_aug
680
+
681
+ out_T1 = self.drop_way(out_T1)
682
+ out_T2 = self.drop_way(out_T2)
683
+ out_S1 = self.drop_way(out_S1)
684
+ out_S2 = self.drop_way(out_S2)
685
+ master1 = self.drop_way(master1)
686
+ master2 = self.drop_way(master2)
687
+
688
+ out_T = torch.max(out_T1, out_T2)
689
+ out_S = torch.max(out_S1, out_S2)
690
+ master = torch.max(master1, master2)
691
+
692
+ T_max, _ = torch.max(torch.abs(out_T), dim=1)
693
+ T_avg = torch.mean(out_T, dim=1)
694
+
695
+ S_max, _ = torch.max(torch.abs(out_S), dim=1)
696
+ S_avg = torch.mean(out_S, dim=1)
697
+
698
+ last_hidden = torch.cat(
699
+ [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
700
+
701
+ last_hidden = self.drop(last_hidden)
702
+ output = self.out_layer(last_hidden)
703
+
704
+ output=F.softmax(output,dim=1)
705
+
706
+ return last_hidden, output
707
+
708
+
709
+
710
+ def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold):
711
+
712
+ # False alarm and miss rates for ASV
713
+ Pfa_asv = sum(non_asv >= asv_threshold) / non_asv.size
714
+ Pmiss_asv = sum(tar_asv < asv_threshold) / tar_asv.size
715
+
716
+ # Rate of rejecting spoofs in ASV
717
+ if spoof_asv.size == 0:
718
+ Pmiss_spoof_asv = None
719
+ Pfa_spoof_asv = None
720
+ else:
721
+ Pmiss_spoof_asv = np.sum(spoof_asv < asv_threshold) / spoof_asv.size
722
+ Pfa_spoof_asv = np.sum(spoof_asv >= asv_threshold) / spoof_asv.size
723
+
724
+ return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, Pfa_spoof_asv
725
+
726
+
727
+ def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold):
728
+
729
+ # False alarm and miss rates for ASV
730
+ Pfa_asv = sum(non_asv >= asv_threshold) / non_asv.size
731
+ Pmiss_asv = sum(tar_asv < asv_threshold) / tar_asv.size
732
+
733
+ # Rate of rejecting spoofs in ASV
734
+ if spoof_asv.size == 0:
735
+ Pmiss_spoof_asv = None
736
+ Pfa_spoof_asv = None
737
+ else:
738
+ Pmiss_spoof_asv = np.sum(spoof_asv < asv_threshold) / spoof_asv.size
739
+ Pfa_spoof_asv = np.sum(spoof_asv >= asv_threshold) / spoof_asv.size
740
+
741
+ return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, Pfa_spoof_asv
742
+
743
+
744
+ def compute_det_curve(target_scores, nontarget_scores):
745
+
746
+ n_scores = target_scores.size + nontarget_scores.size
747
+ all_scores = np.concatenate((target_scores, nontarget_scores))
748
+ labels = np.concatenate(
749
+ (np.ones(target_scores.size), np.zeros(nontarget_scores.size)))
750
+
751
+ # Sort labels based on scores
752
+ indices = np.argsort(all_scores, kind='mergesort')
753
+ labels = labels[indices]
754
+
755
+ # Compute false rejection and false acceptance rates
756
+ tar_trial_sums = np.cumsum(labels)
757
+ nontarget_trial_sums = nontarget_scores.size - \
758
+ (np.arange(1, n_scores + 1) - tar_trial_sums)
759
+
760
+ # false rejection rates
761
+ frr = np.concatenate(
762
+ (np.atleast_1d(0), tar_trial_sums / target_scores.size))
763
+ far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums /
764
+ nontarget_scores.size)) # false acceptance rates
765
+ # Thresholds are the sorted scores
766
+ thresholds = np.concatenate(
767
+ (np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))
768
+
769
+ return frr, far, thresholds
770
+
771
+
772
+ def compute_Pmiss_Pfa_Pspoof_curves(tar_scores, non_scores, spf_scores):
773
+
774
+ # Concatenate all scores and designate arbitrary labels 1=target, 0=nontarget, -1=spoof
775
+ all_scores = np.concatenate((tar_scores, non_scores, spf_scores))
776
+ labels = np.concatenate((np.ones(tar_scores.size), np.zeros(non_scores.size), -1*np.ones(spf_scores.size)))
777
+
778
+ # Sort labels based on scores
779
+ indices = np.argsort(all_scores, kind='mergesort')
780
+ labels = labels[indices]
781
+
782
+ # Cumulative sums
783
+ tar_sums = np.cumsum(labels==1)
784
+ non_sums = np.cumsum(labels==0)
785
+ spoof_sums = np.cumsum(labels==-1)
786
+
787
+ Pmiss = np.concatenate((np.atleast_1d(0), tar_sums / tar_scores.size))
788
+ Pfa_non = np.concatenate((np.atleast_1d(1), 1 - (non_sums / non_scores.size)))
789
+ Pfa_spoof = np.concatenate((np.atleast_1d(1), 1 - (spoof_sums / spf_scores.size)))
790
+ thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices])) # Thresholds are the sorted scores
791
+
792
+ return Pmiss, Pfa_non, Pfa_spoof, thresholds
793
+
794
+
795
+ def compute_eer(target_scores, nontarget_scores):
796
+ """ Returns equal error rate (EER) and the corresponding threshold. """
797
+ frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
798
+ abs_diffs = np.abs(frr - far)
799
+ min_index = np.argmin(abs_diffs)
800
+ eer = np.mean((frr[min_index], far[min_index]))
801
+ return eer, frr, far, thresholds
802
+
803
+
804
+ def compute_mindcf(frr, far, thresholds, Pspoof, Cmiss, Cfa):
805
+ min_c_det = float("inf")
806
+ min_c_det_threshold = thresholds
807
+
808
+ p_target = 1- Pspoof
809
+ for i in range(0, len(frr)):
810
+ # Weighted sum of false negative and false positive errors.
811
+ c_det = Cmiss * frr[i] * p_target + Cfa * far[i] * (1 - p_target)
812
+ if c_det < min_c_det:
813
+ min_c_det = c_det
814
+ min_c_det_threshold = thresholds[i]
815
+ # See Equations (3) and (4). Now we normalize the cost.
816
+ c_def = min(Cmiss * p_target, Cfa * (1 - p_target))
817
+ min_dcf = min_c_det / c_def
818
+ return min_dcf, min_c_det_threshold
819
+
820
+
821
+ def compute_tDCF(bonafide_score_cm, spoof_score_cm, Pfa_asv, Pmiss_asv,
822
+ Pmiss_spoof_asv, cost_model, print_cost):
823
+
824
+ # Sanity check of cost parameters
825
+ if cost_model['Cfa_asv'] < 0 or cost_model['Cmiss_asv'] < 0 or \
826
+ cost_model['Cfa_cm'] < 0 or cost_model['Cmiss_cm'] < 0:
827
+ print('WARNING: Usually the cost values should be positive!')
828
+
829
+ if cost_model['Ptar'] < 0 or cost_model['Pnon'] < 0 or cost_model['Pspoof'] < 0 or \
830
+ np.abs(cost_model['Ptar'] + cost_model['Pnon'] + cost_model['Pspoof'] - 1) > 1e-10:
831
+ sys.exit(
832
+ 'ERROR: Your prior probabilities should be positive and sum up to one.'
833
+ )
834
+
835
+ # Unless we evaluate worst-case model, we need to have some spoof tests against asv
836
+ if Pmiss_spoof_asv is None:
837
+ sys.exit(
838
+ 'ERROR: you should provide miss rate of spoof tests against your ASV system.'
839
+ )
840
+
841
+ # Sanity check of scores
842
+ combined_scores = np.concatenate((bonafide_score_cm, spoof_score_cm))
843
+ if np.isnan(combined_scores).any() or np.isinf(combined_scores).any():
844
+ sys.exit('ERROR: Your scores contain nan or inf.')
845
+
846
+ # Sanity check that inputs are scores and not decisions
847
+ n_uniq = np.unique(combined_scores).size
848
+ if n_uniq < 3:
849
+ sys.exit(
850
+ 'ERROR: You should provide soft CM scores - not binary decisions')
851
+
852
+ # Obtain miss and false alarm rates of CM
853
+ Pmiss_cm, Pfa_cm, CM_thresholds = compute_det_curve(
854
+ bonafide_score_cm, spoof_score_cm)
855
+
856
+ # Constants - see ASVspoof 2019 evaluation plan
857
+ C1 = cost_model['Ptar'] * (cost_model['Cmiss_cm'] - cost_model['Cmiss_asv'] * Pmiss_asv) - \
858
+ cost_model['Pnon'] * cost_model['Cfa_asv'] * Pfa_asv
859
+ C2 = cost_model['Cfa_cm'] * cost_model['Pspoof'] * (1 - Pmiss_spoof_asv)
860
+
861
+ # Sanity check of the weights
862
+ if C1 < 0 or C2 < 0:
863
+ sys.exit(
864
+ 'You should never see this error but I cannot evalute tDCF with negative weights - please check whether your ASV error rates are correctly computed?'
865
+ )
866
+
867
+ # Obtain t-DCF curve for all thresholds
868
+ tDCF = C1 * Pmiss_cm + C2 * Pfa_cm
869
+
870
+ # Normalized t-DCF
871
+ tDCF_norm = tDCF / np.minimum(C1, C2)
872
+
873
+ # Everything should be fine if reaching here.
874
+ if print_cost:
875
+
876
+ print('t-DCF evaluation from [Nbona={}, Nspoof={}] trials\n'.format(
877
+ bonafide_score_cm.size, spoof_score_cm.size))
878
+ print('t-DCF MODEL')
879
+ print(' Ptar = {:8.5f} (Prior probability of target user)'.
880
+ format(cost_model['Ptar']))
881
+ print(
882
+ ' Pnon = {:8.5f} (Prior probability of nontarget user)'.
883
+ format(cost_model['Pnon']))
884
+ print(
885
+ ' Pspoof = {:8.5f} (Prior probability of spoofing attack)'.
886
+ format(cost_model['Pspoof']))
887
+ print(
888
+ ' Cfa_asv = {:8.5f} (Cost of ASV falsely accepting a nontarget)'
889
+ .format(cost_model['Cfa_asv']))
890
+ print(
891
+ ' Cmiss_asv = {:8.5f} (Cost of ASV falsely rejecting target speaker)'
892
+ .format(cost_model['Cmiss_asv']))
893
+ print(
894
+ ' Cfa_cm = {:8.5f} (Cost of CM falsely passing a spoof to ASV system)'
895
+ .format(cost_model['Cfa_cm']))
896
+ print(
897
+ ' Cmiss_cm = {:8.5f} (Cost of CM falsely blocking target utterance which never reaches ASV)'
898
+ .format(cost_model['Cmiss_cm']))
899
+ print(
900
+ '\n Implied normalized t-DCF function (depends on t-DCF parameters and ASV errors), s=CM threshold)'
901
+ )
902
+
903
+ if C2 == np.minimum(C1, C2):
904
+ print(
905
+ ' tDCF_norm(s) = {:8.5f} x Pmiss_cm(s) + Pfa_cm(s)\n'.format(
906
+ C1 / C2))
907
+ else:
908
+ print(
909
+ ' tDCF_norm(s) = Pmiss_cm(s) + {:8.5f} x Pfa_cm(s)\n'.format(
910
+ C2 / C1))
911
+
912
+ return tDCF_norm, CM_thresholds
913
+
914
+
915
+ def calculate_CLLR(target_llrs, nontarget_llrs):
916
+ """
917
+ Calculate the CLLR of the scores.
918
+
919
+ Parameters:
920
+ target_llrs (list or numpy array): Log-likelihood ratios for target trials.
921
+ nontarget_llrs (list or numpy array): Log-likelihood ratios for non-target trials.
922
+
923
+ Returns:
924
+ float: The calculated CLLR value.
925
+ """
926
+ def negative_log_sigmoid(lodds):
927
+ """
928
+ Calculate the negative log of the sigmoid function.
929
+
930
+ Parameters:
931
+ lodds (numpy array): Log-odds values.
932
+
933
+ Returns:
934
+ numpy array: The negative log of the sigmoid values.
935
+ """
936
+ return np.log1p(np.exp(-lodds))
937
+
938
+ # Convert the input lists to numpy arrays if they are not already
939
+ target_llrs = np.array(target_llrs)
940
+ nontarget_llrs = np.array(nontarget_llrs)
941
+
942
+ # Calculate the CLLR value
943
+ cllr = 0.5 * (np.mean(negative_log_sigmoid(target_llrs)) + np.mean(negative_log_sigmoid(-nontarget_llrs))) / np.log(2)
944
+
945
+ return cllr
946
+
947
+
948
+ def compute_Pmiss_Pfa_Pspoof_curves(tar_scores, non_scores, spf_scores):
949
+
950
+ # Concatenate all scores and designate arbitrary labels 1=target, 0=nontarget, -1=spoof
951
+ all_scores = np.concatenate((tar_scores, non_scores, spf_scores))
952
+ labels = np.concatenate((np.ones(tar_scores.size), np.zeros(non_scores.size), -1*np.ones(spf_scores.size)))
953
+
954
+ # Sort labels based on scores
955
+ indices = np.argsort(all_scores, kind='mergesort')
956
+ labels = labels[indices]
957
+
958
+ # Cumulative sums
959
+ tar_sums = np.cumsum(labels==1)
960
+ non_sums = np.cumsum(labels==0)
961
+ spoof_sums = np.cumsum(labels==-1)
962
+
963
+ Pmiss = np.concatenate((np.atleast_1d(0), tar_sums / tar_scores.size))
964
+ Pfa_non = np.concatenate((np.atleast_1d(1), 1 - (non_sums / non_scores.size)))
965
+ Pfa_spoof = np.concatenate((np.atleast_1d(1), 1 - (spoof_sums / spf_scores.size)))
966
+ thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices])) # Thresholds are the sorted scores
967
+
968
+ return Pmiss, Pfa_non, Pfa_spoof, thresholds
969
+
970
+
971
+ def compute_teer(Pmiss_CM, Pfa_CM, tau_CM, Pmiss_ASV, Pfa_non_ASV, Pfa_spf_ASV, tau_ASV):
972
+ # Different spoofing prevalence priors (rho) parameters values
973
+ rho_vals = [0,0.5,1]
974
+
975
+ tEER_val = np.empty([len(rho_vals),len(tau_ASV)], dtype=float)
976
+
977
+ for rho_idx, rho_spf in enumerate(rho_vals):
978
+
979
+ # Table to store the CM threshold index, per each of the ASV operating points
980
+ tEER_idx_CM = np.empty(len(tau_ASV), dtype=int)
981
+
982
+ tEER_path = np.empty([len(rho_vals),len(tau_ASV),2], dtype=float)
983
+
984
+ # Tables to store the t-EER, total Pfa and total miss valuees along the t-EER path
985
+ Pmiss_total = np.empty(len(tau_ASV), dtype=float)
986
+ Pfa_total = np.empty(len(tau_ASV), dtype=float)
987
+ min_tEER = np.inf
988
+ argmin_tEER = np.empty(2)
989
+
990
+ # best intersection point
991
+ xpoint_crit_best = np.inf
992
+ xpoint = np.empty(2)
993
+
994
+ # Loop over all possible ASV thresholds
995
+ for tau_ASV_idx, tau_ASV_val in enumerate(tau_ASV):
996
+
997
+ # Tandem miss and fa rates as defined in the manuscript
998
+ Pmiss_tdm = Pmiss_CM + (1 - Pmiss_CM) * Pmiss_ASV[tau_ASV_idx]
999
+ Pfa_tdm = (1 - rho_spf) * (1 - Pmiss_CM) * Pfa_non_ASV[tau_ASV_idx] + rho_spf * Pfa_CM * Pfa_spf_ASV[tau_ASV_idx]
1000
+
1001
+ # Store only the INDEX of the CM threshold (for the current ASV threshold)
1002
+ h = Pmiss_tdm - Pfa_tdm
1003
+ tmp = np.argmin(abs(h))
1004
+ tEER_idx_CM[tau_ASV_idx] = tmp
1005
+
1006
+ if Pmiss_ASV[tau_ASV_idx] < (1 - rho_spf) * Pfa_non_ASV[tau_ASV_idx] + rho_spf * Pfa_spf_ASV[tau_ASV_idx]:
1007
+ Pmiss_total[tau_ASV_idx] = Pmiss_tdm[tmp]
1008
+ Pfa_total[tau_ASV_idx] = Pfa_tdm[tmp]
1009
+
1010
+ tEER_val[rho_idx,tau_ASV_idx] = np.mean([Pfa_total[tau_ASV_idx], Pmiss_total[tau_ASV_idx]])
1011
+
1012
+ tEER_path[rho_idx,tau_ASV_idx, 0] = tau_ASV_val
1013
+ tEER_path[rho_idx,tau_ASV_idx, 1] = tau_CM[tmp]
1014
+
1015
+ if tEER_val[rho_idx,tau_ASV_idx] < min_tEER:
1016
+ min_tEER = tEER_val[rho_idx,tau_ASV_idx]
1017
+ argmin_tEER[0] = tau_ASV_val
1018
+ argmin_tEER[1] = tau_CM[tmp]
1019
+
1020
+ # Check how close we are to the INTERSECTION POINT for different prior (rho) values:
1021
+ LHS = Pfa_non_ASV[tau_ASV_idx]/Pfa_spf_ASV[tau_ASV_idx]
1022
+ RHS = Pfa_CM[tmp]/(1 - Pmiss_CM[tmp])
1023
+ crit = abs(LHS - RHS)
1024
+
1025
+ if crit < xpoint_crit_best:
1026
+ xpoint_crit_best = crit
1027
+ xpoint[0] = tau_ASV_val
1028
+ xpoint[1] = tau_CM[tmp]
1029
+ xpoint_tEER = Pfa_spf_ASV[tau_ASV_idx]*Pfa_CM[tmp]
1030
+ else:
1031
+ # Not in allowed region
1032
+ tEER_path[rho_idx,tau_ASV_idx, 0] = np.nan
1033
+ tEER_path[rho_idx,tau_ASV_idx, 1] = np.nan
1034
+ Pmiss_total[tau_ASV_idx] = np.nan
1035
+ Pfa_total[tau_ASV_idx] = np.nan
1036
+ tEER_val[rho_idx,tau_ASV_idx] = np.nan
1037
+
1038
+ return xpoint_tEER*100
evaluation/AASIST/AASIST_util.py ADDED
@@ -0,0 +1,1065 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AASIST
3
+ Copyright (c) 2021-present NAVER Corp.
4
+ MIT license
5
+ """
6
+
7
+ import random
8
+ from typing import Union
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+ import sys
15
+ import os
16
+ import argparse
17
+ import torch.optim as optim
18
+ import torchaudio
19
+ from torch.utils.data import Dataset, DataLoader
20
+ from tqdm import tqdm
21
+ import torchaudio.transforms as T
22
+ from collections import defaultdict
23
+ import torch.multiprocessing
24
+
25
+ torch.multiprocessing.set_sharing_strategy('file_system')
26
+
27
+
28
+ def load_aasist_model(ckpt_path, device):
29
+ model_config = {
30
+ "architecture": "AASIST",
31
+ "nb_samp": 64600,
32
+ "first_conv": 128,
33
+ "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
34
+ "gat_dims": [64, 32],
35
+ "pool_ratios": [0.5, 0.7, 0.5, 0.5],
36
+ "temperatures": [2.0, 2.0, 100.0, 100.0],
37
+ "output_cls": 25
38
+ }
39
+
40
+ net = Model(model_config).to(device)
41
+ checkpoint = torch.load(ckpt_path, map_location=device)
42
+ net.load_state_dict(checkpoint)
43
+ net.eval()
44
+
45
+ return net
46
+
47
+ def aasist_evaluate(models, audio):
48
+ score = []
49
+ for model in models:
50
+ _, probb = model(audio)
51
+ score.append(probb[0, 0:1].item())
52
+ return np.mean(score)
53
+
54
+
55
+ def extract_system_id(wavname):
56
+ """Extrait l'identifiant du système à partir du nom du fichier."""
57
+ return wavname.split('-')[0]
58
+
59
+ def pad(x, max_len=64600):
60
+ """ Padding ou découpage d'un signal audio """
61
+ x_len = x.shape[0]
62
+ if x_len >= max_len:
63
+ return x[:max_len]
64
+ num_repeats = int(max_len / x_len) + 1
65
+ padded_x = np.tile(x, (num_repeats))[:max_len]
66
+ return padded_x
67
+
68
+
69
+ def pad_random(x: np.ndarray, max_len: int = 64600):
70
+ """ Découpe aléatoire si trop long, padding si trop court """
71
+ x_len = x.shape[0]
72
+ if x_len >= max_len:
73
+ stt = np.random.randint(x_len - max_len)
74
+ return x[stt:stt + max_len]
75
+ num_repeats = int(max_len / x_len) + 1
76
+ padded_x = np.tile(x, (num_repeats))[:max_len]
77
+ return padded_x
78
+ # ==========================================================
79
+ # Chargement des données (Dataset)
80
+ # ==========================================================
81
+
82
+
83
+ class MyDataset(Dataset):
84
+ def __init__(self, wavdir, mos_list="", target_sample_rate=16000):
85
+ self.mos_lookup = {}
86
+ if mos_list:
87
+ with open(mos_list, 'r') as f:
88
+ for line in f:
89
+ parts = line.strip().split(',')
90
+ wavname = parts[0]
91
+ mos = float(parts[1])
92
+ self.mos_lookup[wavname] = mos
93
+
94
+ self.wavdir = wavdir
95
+ wavnames=os.listdir(self.wavdir)
96
+ self.wavnames = [f_name for f_name in wavnames if f_name.endswith(".wav")]
97
+ self.target_sample_rate = target_sample_rate
98
+
99
+ def __getitem__(self, idx):
100
+ wavname = self.wavnames[idx]
101
+ wavpath = os.path.join(self.wavdir, wavname)
102
+ wav, sample_rate = torchaudio.load(wavpath)
103
+
104
+ if sample_rate != self.target_sample_rate:
105
+ resampler = T.Resample(orig_freq=sample_rate, new_freq=self.target_sample_rate)
106
+ wav = resampler(wav)
107
+ if wavname in self.mos_lookup:
108
+ score = self.mos_lookup[wavname]
109
+ else:
110
+ score = 0 #TODO: it should be manage more properly
111
+ return wav, score, wavname
112
+
113
+ def __len__(self):
114
+ return len(self.wavnames)
115
+
116
+ def collate_fn(self, batch):
117
+ """ Padding et tronquage des séquences audio pour normaliser à 64600 frames """
118
+ wavs, scores, wavnames = zip(*batch)
119
+ max_len = 64600
120
+ output_wavs = []
121
+ for wav in wavs:
122
+
123
+ wav_np = wav.squeeze(0).cpu().numpy() # Enlève la dimension channel (1,) et met sur CPU
124
+ padded_wav = pad_random(wav_np, max_len)
125
+
126
+ padded_wav = torch.tensor(padded_wav, dtype=torch.float32).unsqueeze(0) # Remettre la dimension (1, time)
127
+
128
+ output_wavs.append(padded_wav)
129
+
130
+ output_wavs = torch.stack(output_wavs, dim=0) # [batch_size, 1, 64600]
131
+
132
+ scores = torch.tensor(scores, dtype=torch.float32)
133
+
134
+ return output_wavs, scores, wavnames
135
+
136
+
137
+
138
+ class GraphAttentionLayer(nn.Module):
139
+ def __init__(self, in_dim, out_dim, **kwargs):
140
+ super().__init__()
141
+
142
+ # attention map
143
+ self.att_proj = nn.Linear(in_dim, out_dim)
144
+ self.att_weight = self._init_new_params(out_dim, 1)
145
+
146
+ # project
147
+ self.proj_with_att = nn.Linear(in_dim, out_dim)
148
+ self.proj_without_att = nn.Linear(in_dim, out_dim)
149
+
150
+ # batch norm
151
+ self.bn = nn.BatchNorm1d(out_dim)
152
+
153
+ # dropout for inputs
154
+ self.input_drop = nn.Dropout(p=0.2)
155
+
156
+ # activate
157
+ self.act = nn.SELU(inplace=True)
158
+
159
+ # temperature
160
+ self.temp = 1.
161
+ if "temperature" in kwargs:
162
+ self.temp = kwargs["temperature"]
163
+
164
+ def forward(self, x):
165
+ '''
166
+ x :(#bs, #node, #dim)
167
+ '''
168
+ # apply input dropout
169
+ x = self.input_drop(x)
170
+
171
+ # derive attention map
172
+ att_map = self._derive_att_map(x)
173
+
174
+ # projection
175
+ x = self._project(x, att_map)
176
+
177
+ # apply batch norm
178
+ x = self._apply_BN(x)
179
+ x = self.act(x)
180
+ return x
181
+
182
+ def _pairwise_mul_nodes(self, x):
183
+ '''
184
+ Calculates pairwise multiplication of nodes.
185
+ - for attention map
186
+ x :(#bs, #node, #dim)
187
+ out_shape :(#bs, #node, #node, #dim)
188
+ '''
189
+
190
+ nb_nodes = x.size(1)
191
+ x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
192
+ x_mirror = x.transpose(1, 2)
193
+
194
+ return x * x_mirror
195
+
196
+ def _derive_att_map(self, x):
197
+ '''
198
+ x :(#bs, #node, #dim)
199
+ out_shape :(#bs, #node, #node, 1)
200
+ '''
201
+ att_map = self._pairwise_mul_nodes(x)
202
+ # size: (#bs, #node, #node, #dim_out)
203
+ att_map = torch.tanh(self.att_proj(att_map))
204
+ # size: (#bs, #node, #node, 1)
205
+ att_map = torch.matmul(att_map, self.att_weight)
206
+
207
+ # apply temperature
208
+ att_map = att_map / self.temp
209
+
210
+ att_map = F.softmax(att_map, dim=-2)
211
+
212
+ return att_map
213
+
214
+ def _project(self, x, att_map):
215
+ x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
216
+ x2 = self.proj_without_att(x)
217
+
218
+ return x1 + x2
219
+
220
+ def _apply_BN(self, x):
221
+ org_size = x.size()
222
+ x = x.view(-1, org_size[-1])
223
+ x = self.bn(x)
224
+ x = x.view(org_size)
225
+
226
+ return x
227
+
228
+ def _init_new_params(self, *size):
229
+ out = nn.Parameter(torch.FloatTensor(*size))
230
+ nn.init.xavier_normal_(out)
231
+ return out
232
+
233
+
234
+ class HtrgGraphAttentionLayer(nn.Module):
235
+ def __init__(self, in_dim, out_dim, **kwargs):
236
+ super().__init__()
237
+
238
+ self.proj_type1 = nn.Linear(in_dim, in_dim)
239
+ self.proj_type2 = nn.Linear(in_dim, in_dim)
240
+
241
+ # attention map
242
+ self.att_proj = nn.Linear(in_dim, out_dim)
243
+ self.att_projM = nn.Linear(in_dim, out_dim)
244
+
245
+ self.att_weight11 = self._init_new_params(out_dim, 1)
246
+ self.att_weight22 = self._init_new_params(out_dim, 1)
247
+ self.att_weight12 = self._init_new_params(out_dim, 1)
248
+ self.att_weightM = self._init_new_params(out_dim, 1)
249
+
250
+ # project
251
+ self.proj_with_att = nn.Linear(in_dim, out_dim)
252
+ self.proj_without_att = nn.Linear(in_dim, out_dim)
253
+
254
+ self.proj_with_attM = nn.Linear(in_dim, out_dim)
255
+ self.proj_without_attM = nn.Linear(in_dim, out_dim)
256
+
257
+ # batch norm
258
+ self.bn = nn.BatchNorm1d(out_dim)
259
+
260
+ # dropout for inputs
261
+ self.input_drop = nn.Dropout(p=0.2)
262
+
263
+ # activate
264
+ self.act = nn.SELU(inplace=True)
265
+
266
+ # temperature
267
+ self.temp = 1.
268
+ if "temperature" in kwargs:
269
+ self.temp = kwargs["temperature"]
270
+
271
+ def forward(self, x1, x2, master=None):
272
+ '''
273
+ x1 :(#bs, #node, #dim)
274
+ x2 :(#bs, #node, #dim)
275
+ '''
276
+ num_type1 = x1.size(1)
277
+ num_type2 = x2.size(1)
278
+
279
+ x1 = self.proj_type1(x1)
280
+ x2 = self.proj_type2(x2)
281
+
282
+ x = torch.cat([x1, x2], dim=1)
283
+
284
+ if master is None:
285
+ master = torch.mean(x, dim=1, keepdim=True)
286
+
287
+ # apply input dropout
288
+ x = self.input_drop(x)
289
+
290
+ # derive attention map
291
+ att_map = self._derive_att_map(x, num_type1, num_type2)
292
+
293
+ # directional edge for master node
294
+ master = self._update_master(x, master)
295
+
296
+ # projection
297
+ x = self._project(x, att_map)
298
+
299
+ # apply batch norm
300
+ x = self._apply_BN(x)
301
+ x = self.act(x)
302
+
303
+ x1 = x.narrow(1, 0, num_type1)
304
+ x2 = x.narrow(1, num_type1, num_type2)
305
+
306
+ return x1, x2, master
307
+
308
+ def _update_master(self, x, master):
309
+
310
+ att_map = self._derive_att_map_master(x, master)
311
+ master = self._project_master(x, master, att_map)
312
+
313
+ return master
314
+
315
+ def _pairwise_mul_nodes(self, x):
316
+ '''
317
+ Calculates pairwise multiplication of nodes.
318
+ - for attention map
319
+ x :(#bs, #node, #dim)
320
+ out_shape :(#bs, #node, #node, #dim)
321
+ '''
322
+
323
+ nb_nodes = x.size(1)
324
+ x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
325
+ x_mirror = x.transpose(1, 2)
326
+
327
+ return x * x_mirror
328
+
329
+ def _derive_att_map_master(self, x, master):
330
+ '''
331
+ x :(#bs, #node, #dim)
332
+ out_shape :(#bs, #node, #node, 1)
333
+ '''
334
+ att_map = x * master
335
+ att_map = torch.tanh(self.att_projM(att_map))
336
+
337
+ att_map = torch.matmul(att_map, self.att_weightM)
338
+
339
+ # apply temperature
340
+ att_map = att_map / self.temp
341
+
342
+ att_map = F.softmax(att_map, dim=-2)
343
+
344
+ return att_map
345
+
346
+ def _derive_att_map(self, x, num_type1, num_type2):
347
+ '''
348
+ x :(#bs, #node, #dim)
349
+ out_shape :(#bs, #node, #node, 1)
350
+ '''
351
+ att_map = self._pairwise_mul_nodes(x)
352
+ # size: (#bs, #node, #node, #dim_out)
353
+ att_map = torch.tanh(self.att_proj(att_map))
354
+ # size: (#bs, #node, #node, 1)
355
+
356
+ att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
357
+
358
+ att_board[:, :num_type1, :num_type1, :] = torch.matmul(
359
+ att_map[:, :num_type1, :num_type1, :], self.att_weight11)
360
+ att_board[:, num_type1:, num_type1:, :] = torch.matmul(
361
+ att_map[:, num_type1:, num_type1:, :], self.att_weight22)
362
+ att_board[:, :num_type1, num_type1:, :] = torch.matmul(
363
+ att_map[:, :num_type1, num_type1:, :], self.att_weight12)
364
+ att_board[:, num_type1:, :num_type1, :] = torch.matmul(
365
+ att_map[:, num_type1:, :num_type1, :], self.att_weight12)
366
+
367
+ att_map = att_board
368
+
369
+ # att_map = torch.matmul(att_map, self.att_weight12)
370
+
371
+ # apply temperature
372
+ att_map = att_map / self.temp
373
+
374
+ att_map = F.softmax(att_map, dim=-2)
375
+
376
+ return att_map
377
+
378
+ def _project(self, x, att_map):
379
+ x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
380
+ x2 = self.proj_without_att(x)
381
+
382
+ return x1 + x2
383
+
384
+ def _project_master(self, x, master, att_map):
385
+
386
+ x1 = self.proj_with_attM(torch.matmul(
387
+ att_map.squeeze(-1).unsqueeze(1), x))
388
+ x2 = self.proj_without_attM(master)
389
+
390
+ return x1 + x2
391
+
392
+ def _apply_BN(self, x):
393
+ org_size = x.size()
394
+ x = x.view(-1, org_size[-1])
395
+ x = self.bn(x)
396
+ x = x.view(org_size)
397
+
398
+ return x
399
+
400
+ def _init_new_params(self, *size):
401
+ out = nn.Parameter(torch.FloatTensor(*size))
402
+ nn.init.xavier_normal_(out)
403
+ return out
404
+
405
+
406
+ class GraphPool(nn.Module):
407
+ def __init__(self, k: float, in_dim: int, p: Union[float, int]):
408
+ super().__init__()
409
+ self.k = k
410
+ self.sigmoid = nn.Sigmoid()
411
+ self.proj = nn.Linear(in_dim, 1)
412
+ self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
413
+ self.in_dim = in_dim
414
+
415
+ def forward(self, h):
416
+ Z = self.drop(h)
417
+ weights = self.proj(Z)
418
+ scores = self.sigmoid(weights)
419
+ new_h = self.top_k_graph(scores, h, self.k)
420
+
421
+ return new_h
422
+
423
+ def top_k_graph(self, scores, h, k):
424
+ """
425
+ args
426
+ =====
427
+ scores: attention-based weights (#bs, #node, 1)
428
+ h: graph data (#bs, #node, #dim)
429
+ k: ratio of remaining nodes, (float)
430
+
431
+ returns
432
+ =====
433
+ h: graph pool applied data (#bs, #node', #dim)
434
+ """
435
+ _, n_nodes, n_feat = h.size()
436
+ n_nodes = max(int(n_nodes * k), 1)
437
+ _, idx = torch.topk(scores, n_nodes, dim=1)
438
+ idx = idx.expand(-1, -1, n_feat)
439
+
440
+ h = h * scores
441
+ h = torch.gather(h, 1, idx)
442
+
443
+ return h
444
+
445
+
446
+ class CONV(nn.Module):
447
+ @staticmethod
448
+ def to_mel(hz):
449
+ return 2595 * np.log10(1 + hz / 700)
450
+
451
+ @staticmethod
452
+ def to_hz(mel):
453
+ return 700 * (10**(mel / 2595) - 1)
454
+
455
+ def __init__(self,
456
+ out_channels,
457
+ kernel_size,
458
+ sample_rate=16000,
459
+ in_channels=1,
460
+ stride=1,
461
+ padding=0,
462
+ dilation=1,
463
+ bias=False,
464
+ groups=1,
465
+ mask=False):
466
+ super().__init__()
467
+ if in_channels != 1:
468
+
469
+ msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
470
+ in_channels)
471
+ raise ValueError(msg)
472
+ self.out_channels = out_channels
473
+ self.kernel_size = kernel_size
474
+ self.sample_rate = sample_rate
475
+
476
+ # Forcing the filters to be odd (i.e, perfectly symmetrics)
477
+ if kernel_size % 2 == 0:
478
+ self.kernel_size = self.kernel_size + 1
479
+ self.stride = stride
480
+ self.padding = padding
481
+ self.dilation = dilation
482
+ self.mask = mask
483
+ if bias:
484
+ raise ValueError('SincConv does not support bias.')
485
+ if groups > 1:
486
+ raise ValueError('SincConv does not support groups.')
487
+
488
+ NFFT = 512
489
+ f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
490
+ fmel = self.to_mel(f)
491
+ fmelmax = np.max(fmel)
492
+ fmelmin = np.min(fmel)
493
+ filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
494
+ filbandwidthsf = self.to_hz(filbandwidthsmel)
495
+
496
+ self.mel = filbandwidthsf
497
+ self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
498
+ (self.kernel_size - 1) / 2 + 1)
499
+ self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
500
+ for i in range(len(self.mel) - 1):
501
+ fmin = self.mel[i]
502
+ fmax = self.mel[i + 1]
503
+ hHigh = (2*fmax/self.sample_rate) * \
504
+ np.sinc(2*fmax*self.hsupp/self.sample_rate)
505
+ hLow = (2*fmin/self.sample_rate) * \
506
+ np.sinc(2*fmin*self.hsupp/self.sample_rate)
507
+ hideal = hHigh - hLow
508
+
509
+ self.band_pass[i, :] = Tensor(np.hamming(
510
+ self.kernel_size)) * Tensor(hideal)
511
+
512
+ def forward(self, x, mask=False):
513
+ band_pass_filter = self.band_pass.clone().to(x.device)
514
+ if mask:
515
+ A = np.random.uniform(0, 20)
516
+ A = int(A)
517
+ A0 = random.randint(0, band_pass_filter.shape[0] - A)
518
+ band_pass_filter[A0:A0 + A, :] = 0
519
+ else:
520
+ band_pass_filter = band_pass_filter
521
+
522
+ self.filters = (band_pass_filter).view(self.out_channels, 1,
523
+ self.kernel_size)
524
+
525
+ return F.conv1d(x,
526
+ self.filters,
527
+ stride=self.stride,
528
+ padding=self.padding,
529
+ dilation=self.dilation,
530
+ bias=None,
531
+ groups=1)
532
+
533
+
534
+ class Residual_block(nn.Module):
535
+ def __init__(self, nb_filts, first=False):
536
+ super().__init__()
537
+ self.first = first
538
+
539
+ if not self.first:
540
+ self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
541
+ self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
542
+ out_channels=nb_filts[1],
543
+ kernel_size=(2, 3),
544
+ padding=(1, 1),
545
+ stride=1)
546
+ self.selu = nn.SELU(inplace=True)
547
+
548
+ self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
549
+ self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
550
+ out_channels=nb_filts[1],
551
+ kernel_size=(2, 3),
552
+ padding=(0, 1),
553
+ stride=1)
554
+
555
+ if nb_filts[0] != nb_filts[1]:
556
+ self.downsample = True
557
+ self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
558
+ out_channels=nb_filts[1],
559
+ padding=(0, 1),
560
+ kernel_size=(1, 3),
561
+ stride=1)
562
+
563
+ else:
564
+ self.downsample = False
565
+ self.mp = nn.MaxPool2d((1, 3)) # self.mp = nn.MaxPool2d((1,4))
566
+
567
+ def forward(self, x):
568
+ identity = x
569
+ if not self.first:
570
+ out = self.bn1(x)
571
+ out = self.selu(out)
572
+ else:
573
+ out = x
574
+ out = self.conv1(x)
575
+
576
+ # print('out',out.shape)
577
+ out = self.bn2(out)
578
+ out = self.selu(out)
579
+ # print('out',out.shape)
580
+ out = self.conv2(out)
581
+ #print('conv2 out',out.shape)
582
+ if self.downsample:
583
+ identity = self.conv_downsample(identity)
584
+
585
+ out += identity
586
+ out = self.mp(out)
587
+ return out
588
+
589
+
590
+ class Model(nn.Module):
591
+ def __init__(self, d_args):
592
+ super().__init__()
593
+
594
+ self.d_args = d_args
595
+ filts = d_args["filts"]
596
+ gat_dims = d_args["gat_dims"]
597
+ pool_ratios = d_args["pool_ratios"]
598
+ temperatures = d_args["temperatures"]
599
+
600
+ self.conv_time = CONV(out_channels=filts[0],
601
+ kernel_size=d_args["first_conv"],
602
+ in_channels=1)
603
+ self.first_bn = nn.BatchNorm2d(num_features=1)
604
+
605
+ self.drop = nn.Dropout(0.5, inplace=True)
606
+ self.drop_way = nn.Dropout(0.2, inplace=True)
607
+ self.selu = nn.SELU(inplace=True)
608
+
609
+ self.encoder = nn.Sequential(
610
+ nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
611
+ nn.Sequential(Residual_block(nb_filts=filts[2])),
612
+ nn.Sequential(Residual_block(nb_filts=filts[3])),
613
+ nn.Sequential(Residual_block(nb_filts=filts[4])),
614
+ nn.Sequential(Residual_block(nb_filts=filts[4])),
615
+ nn.Sequential(Residual_block(nb_filts=filts[4])))
616
+
617
+ self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
618
+ self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
619
+ self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
620
+
621
+ self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
622
+ gat_dims[0],
623
+ temperature=temperatures[0])
624
+ self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
625
+ gat_dims[0],
626
+ temperature=temperatures[1])
627
+
628
+ self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
629
+ gat_dims[0], gat_dims[1], temperature=temperatures[2])
630
+ self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
631
+ gat_dims[1], gat_dims[1], temperature=temperatures[2])
632
+
633
+ self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
634
+ gat_dims[0], gat_dims[1], temperature=temperatures[2])
635
+
636
+ self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
637
+ gat_dims[1], gat_dims[1], temperature=temperatures[2])
638
+
639
+ self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
640
+ self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
641
+ self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
642
+ self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
643
+
644
+ self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
645
+ self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
646
+
647
+ if "output_cls" in d_args:
648
+ self.out_layer = nn.Linear(5 * gat_dims[1], d_args["output_cls"])
649
+ else:
650
+ self.out_layer = nn.Linear(5 * gat_dims[1], 2)
651
+
652
+ def forward(self, x, Freq_aug=False):
653
+
654
+ x = x.unsqueeze(1)
655
+ x = self.conv_time(x, mask=Freq_aug)
656
+ x = x.unsqueeze(dim=1)
657
+ x = F.max_pool2d(torch.abs(x), (3, 3))
658
+ x = self.first_bn(x)
659
+ x = self.selu(x)
660
+
661
+ # get embeddings using encoder
662
+ # (#bs, #filt, #spec, #seq)
663
+ e = self.encoder(x)
664
+
665
+ # spectral GAT (GAT-S)
666
+ e_S, _ = torch.max(torch.abs(e), dim=3) # max along time
667
+ e_S = e_S.transpose(1, 2) + self.pos_S
668
+
669
+ gat_S = self.GAT_layer_S(e_S)
670
+ out_S = self.pool_S(gat_S) # (#bs, #node, #dim)
671
+
672
+ # temporal GAT (GAT-T)
673
+ e_T, _ = torch.max(torch.abs(e), dim=2) # max along freq
674
+ e_T = e_T.transpose(1, 2)
675
+
676
+ gat_T = self.GAT_layer_T(e_T)
677
+ out_T = self.pool_T(gat_T)
678
+
679
+ # learnable master node
680
+ master1 = self.master1.expand(x.size(0), -1, -1)
681
+ master2 = self.master2.expand(x.size(0), -1, -1)
682
+
683
+ # inference 1
684
+ out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
685
+ out_T, out_S, master=self.master1)
686
+
687
+ out_S1 = self.pool_hS1(out_S1)
688
+ out_T1 = self.pool_hT1(out_T1)
689
+
690
+ out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
691
+ out_T1, out_S1, master=master1)
692
+ out_T1 = out_T1 + out_T_aug
693
+ out_S1 = out_S1 + out_S_aug
694
+ master1 = master1 + master_aug
695
+
696
+ # inference 2
697
+ out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
698
+ out_T, out_S, master=self.master2)
699
+ out_S2 = self.pool_hS2(out_S2)
700
+ out_T2 = self.pool_hT2(out_T2)
701
+
702
+ out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
703
+ out_T2, out_S2, master=master2)
704
+ out_T2 = out_T2 + out_T_aug
705
+ out_S2 = out_S2 + out_S_aug
706
+ master2 = master2 + master_aug
707
+
708
+ out_T1 = self.drop_way(out_T1)
709
+ out_T2 = self.drop_way(out_T2)
710
+ out_S1 = self.drop_way(out_S1)
711
+ out_S2 = self.drop_way(out_S2)
712
+ master1 = self.drop_way(master1)
713
+ master2 = self.drop_way(master2)
714
+
715
+ out_T = torch.max(out_T1, out_T2)
716
+ out_S = torch.max(out_S1, out_S2)
717
+ master = torch.max(master1, master2)
718
+
719
+ T_max, _ = torch.max(torch.abs(out_T), dim=1)
720
+ T_avg = torch.mean(out_T, dim=1)
721
+
722
+ S_max, _ = torch.max(torch.abs(out_S), dim=1)
723
+ S_avg = torch.mean(out_S, dim=1)
724
+
725
+ last_hidden = torch.cat(
726
+ [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
727
+
728
+ last_hidden = self.drop(last_hidden)
729
+ output = self.out_layer(last_hidden)
730
+
731
+ output=F.softmax(output,dim=1)
732
+
733
+ return last_hidden, output
734
+
735
+
736
+
737
+ def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold):
738
+
739
+ # False alarm and miss rates for ASV
740
+ Pfa_asv = sum(non_asv >= asv_threshold) / non_asv.size
741
+ Pmiss_asv = sum(tar_asv < asv_threshold) / tar_asv.size
742
+
743
+ # Rate of rejecting spoofs in ASV
744
+ if spoof_asv.size == 0:
745
+ Pmiss_spoof_asv = None
746
+ Pfa_spoof_asv = None
747
+ else:
748
+ Pmiss_spoof_asv = np.sum(spoof_asv < asv_threshold) / spoof_asv.size
749
+ Pfa_spoof_asv = np.sum(spoof_asv >= asv_threshold) / spoof_asv.size
750
+
751
+ return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, Pfa_spoof_asv
752
+
753
+
754
+ def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold):
755
+
756
+ # False alarm and miss rates for ASV
757
+ Pfa_asv = sum(non_asv >= asv_threshold) / non_asv.size
758
+ Pmiss_asv = sum(tar_asv < asv_threshold) / tar_asv.size
759
+
760
+ # Rate of rejecting spoofs in ASV
761
+ if spoof_asv.size == 0:
762
+ Pmiss_spoof_asv = None
763
+ Pfa_spoof_asv = None
764
+ else:
765
+ Pmiss_spoof_asv = np.sum(spoof_asv < asv_threshold) / spoof_asv.size
766
+ Pfa_spoof_asv = np.sum(spoof_asv >= asv_threshold) / spoof_asv.size
767
+
768
+ return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, Pfa_spoof_asv
769
+
770
+
771
+ def compute_det_curve(target_scores, nontarget_scores):
772
+
773
+ n_scores = target_scores.size + nontarget_scores.size
774
+ all_scores = np.concatenate((target_scores, nontarget_scores))
775
+ labels = np.concatenate(
776
+ (np.ones(target_scores.size), np.zeros(nontarget_scores.size)))
777
+
778
+ # Sort labels based on scores
779
+ indices = np.argsort(all_scores, kind='mergesort')
780
+ labels = labels[indices]
781
+
782
+ # Compute false rejection and false acceptance rates
783
+ tar_trial_sums = np.cumsum(labels)
784
+ nontarget_trial_sums = nontarget_scores.size - \
785
+ (np.arange(1, n_scores + 1) - tar_trial_sums)
786
+
787
+ # false rejection rates
788
+ frr = np.concatenate(
789
+ (np.atleast_1d(0), tar_trial_sums / target_scores.size))
790
+ far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums /
791
+ nontarget_scores.size)) # false acceptance rates
792
+ # Thresholds are the sorted scores
793
+ thresholds = np.concatenate(
794
+ (np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))
795
+
796
+ return frr, far, thresholds
797
+
798
+
799
+ def compute_Pmiss_Pfa_Pspoof_curves(tar_scores, non_scores, spf_scores):
800
+
801
+ # Concatenate all scores and designate arbitrary labels 1=target, 0=nontarget, -1=spoof
802
+ all_scores = np.concatenate((tar_scores, non_scores, spf_scores))
803
+ labels = np.concatenate((np.ones(tar_scores.size), np.zeros(non_scores.size), -1*np.ones(spf_scores.size)))
804
+
805
+ # Sort labels based on scores
806
+ indices = np.argsort(all_scores, kind='mergesort')
807
+ labels = labels[indices]
808
+
809
+ # Cumulative sums
810
+ tar_sums = np.cumsum(labels==1)
811
+ non_sums = np.cumsum(labels==0)
812
+ spoof_sums = np.cumsum(labels==-1)
813
+
814
+ Pmiss = np.concatenate((np.atleast_1d(0), tar_sums / tar_scores.size))
815
+ Pfa_non = np.concatenate((np.atleast_1d(1), 1 - (non_sums / non_scores.size)))
816
+ Pfa_spoof = np.concatenate((np.atleast_1d(1), 1 - (spoof_sums / spf_scores.size)))
817
+ thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices])) # Thresholds are the sorted scores
818
+
819
+ return Pmiss, Pfa_non, Pfa_spoof, thresholds
820
+
821
+
822
+ def compute_eer(target_scores, nontarget_scores):
823
+ """ Returns equal error rate (EER) and the corresponding threshold. """
824
+ frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
825
+ abs_diffs = np.abs(frr - far)
826
+ min_index = np.argmin(abs_diffs)
827
+ eer = np.mean((frr[min_index], far[min_index]))
828
+ return eer, frr, far, thresholds
829
+
830
+
831
+ def compute_mindcf(frr, far, thresholds, Pspoof, Cmiss, Cfa):
832
+ min_c_det = float("inf")
833
+ min_c_det_threshold = thresholds
834
+
835
+ p_target = 1- Pspoof
836
+ for i in range(0, len(frr)):
837
+ # Weighted sum of false negative and false positive errors.
838
+ c_det = Cmiss * frr[i] * p_target + Cfa * far[i] * (1 - p_target)
839
+ if c_det < min_c_det:
840
+ min_c_det = c_det
841
+ min_c_det_threshold = thresholds[i]
842
+ # See Equations (3) and (4). Now we normalize the cost.
843
+ c_def = min(Cmiss * p_target, Cfa * (1 - p_target))
844
+ min_dcf = min_c_det / c_def
845
+ return min_dcf, min_c_det_threshold
846
+
847
+
848
+ def compute_tDCF(bonafide_score_cm, spoof_score_cm, Pfa_asv, Pmiss_asv,
849
+ Pmiss_spoof_asv, cost_model, print_cost):
850
+
851
+ # Sanity check of cost parameters
852
+ if cost_model['Cfa_asv'] < 0 or cost_model['Cmiss_asv'] < 0 or \
853
+ cost_model['Cfa_cm'] < 0 or cost_model['Cmiss_cm'] < 0:
854
+ print('WARNING: Usually the cost values should be positive!')
855
+
856
+ if cost_model['Ptar'] < 0 or cost_model['Pnon'] < 0 or cost_model['Pspoof'] < 0 or \
857
+ np.abs(cost_model['Ptar'] + cost_model['Pnon'] + cost_model['Pspoof'] - 1) > 1e-10:
858
+ sys.exit(
859
+ 'ERROR: Your prior probabilities should be positive and sum up to one.'
860
+ )
861
+
862
+ # Unless we evaluate worst-case model, we need to have some spoof tests against asv
863
+ if Pmiss_spoof_asv is None:
864
+ sys.exit(
865
+ 'ERROR: you should provide miss rate of spoof tests against your ASV system.'
866
+ )
867
+
868
+ # Sanity check of scores
869
+ combined_scores = np.concatenate((bonafide_score_cm, spoof_score_cm))
870
+ if np.isnan(combined_scores).any() or np.isinf(combined_scores).any():
871
+ sys.exit('ERROR: Your scores contain nan or inf.')
872
+
873
+ # Sanity check that inputs are scores and not decisions
874
+ n_uniq = np.unique(combined_scores).size
875
+ if n_uniq < 3:
876
+ sys.exit(
877
+ 'ERROR: You should provide soft CM scores - not binary decisions')
878
+
879
+ # Obtain miss and false alarm rates of CM
880
+ Pmiss_cm, Pfa_cm, CM_thresholds = compute_det_curve(
881
+ bonafide_score_cm, spoof_score_cm)
882
+
883
+ # Constants - see ASVspoof 2019 evaluation plan
884
+ C1 = cost_model['Ptar'] * (cost_model['Cmiss_cm'] - cost_model['Cmiss_asv'] * Pmiss_asv) - \
885
+ cost_model['Pnon'] * cost_model['Cfa_asv'] * Pfa_asv
886
+ C2 = cost_model['Cfa_cm'] * cost_model['Pspoof'] * (1 - Pmiss_spoof_asv)
887
+
888
+ # Sanity check of the weights
889
+ if C1 < 0 or C2 < 0:
890
+ sys.exit(
891
+ 'You should never see this error but I cannot evalute tDCF with negative weights - please check whether your ASV error rates are correctly computed?'
892
+ )
893
+
894
+ # Obtain t-DCF curve for all thresholds
895
+ tDCF = C1 * Pmiss_cm + C2 * Pfa_cm
896
+
897
+ # Normalized t-DCF
898
+ tDCF_norm = tDCF / np.minimum(C1, C2)
899
+
900
+ # Everything should be fine if reaching here.
901
+ if print_cost:
902
+
903
+ print('t-DCF evaluation from [Nbona={}, Nspoof={}] trials\n'.format(
904
+ bonafide_score_cm.size, spoof_score_cm.size))
905
+ print('t-DCF MODEL')
906
+ print(' Ptar = {:8.5f} (Prior probability of target user)'.
907
+ format(cost_model['Ptar']))
908
+ print(
909
+ ' Pnon = {:8.5f} (Prior probability of nontarget user)'.
910
+ format(cost_model['Pnon']))
911
+ print(
912
+ ' Pspoof = {:8.5f} (Prior probability of spoofing attack)'.
913
+ format(cost_model['Pspoof']))
914
+ print(
915
+ ' Cfa_asv = {:8.5f} (Cost of ASV falsely accepting a nontarget)'
916
+ .format(cost_model['Cfa_asv']))
917
+ print(
918
+ ' Cmiss_asv = {:8.5f} (Cost of ASV falsely rejecting target speaker)'
919
+ .format(cost_model['Cmiss_asv']))
920
+ print(
921
+ ' Cfa_cm = {:8.5f} (Cost of CM falsely passing a spoof to ASV system)'
922
+ .format(cost_model['Cfa_cm']))
923
+ print(
924
+ ' Cmiss_cm = {:8.5f} (Cost of CM falsely blocking target utterance which never reaches ASV)'
925
+ .format(cost_model['Cmiss_cm']))
926
+ print(
927
+ '\n Implied normalized t-DCF function (depends on t-DCF parameters and ASV errors), s=CM threshold)'
928
+ )
929
+
930
+ if C2 == np.minimum(C1, C2):
931
+ print(
932
+ ' tDCF_norm(s) = {:8.5f} x Pmiss_cm(s) + Pfa_cm(s)\n'.format(
933
+ C1 / C2))
934
+ else:
935
+ print(
936
+ ' tDCF_norm(s) = Pmiss_cm(s) + {:8.5f} x Pfa_cm(s)\n'.format(
937
+ C2 / C1))
938
+
939
+ return tDCF_norm, CM_thresholds
940
+
941
+
942
+ def calculate_CLLR(target_llrs, nontarget_llrs):
943
+ """
944
+ Calculate the CLLR of the scores.
945
+
946
+ Parameters:
947
+ target_llrs (list or numpy array): Log-likelihood ratios for target trials.
948
+ nontarget_llrs (list or numpy array): Log-likelihood ratios for non-target trials.
949
+
950
+ Returns:
951
+ float: The calculated CLLR value.
952
+ """
953
+ def negative_log_sigmoid(lodds):
954
+ """
955
+ Calculate the negative log of the sigmoid function.
956
+
957
+ Parameters:
958
+ lodds (numpy array): Log-odds values.
959
+
960
+ Returns:
961
+ numpy array: The negative log of the sigmoid values.
962
+ """
963
+ return np.log1p(np.exp(-lodds))
964
+
965
+ # Convert the input lists to numpy arrays if they are not already
966
+ target_llrs = np.array(target_llrs)
967
+ nontarget_llrs = np.array(nontarget_llrs)
968
+
969
+ # Calculate the CLLR value
970
+ cllr = 0.5 * (np.mean(negative_log_sigmoid(target_llrs)) + np.mean(negative_log_sigmoid(-nontarget_llrs))) / np.log(2)
971
+
972
+ return cllr
973
+
974
+
975
+ def compute_Pmiss_Pfa_Pspoof_curves(tar_scores, non_scores, spf_scores):
976
+
977
+ # Concatenate all scores and designate arbitrary labels 1=target, 0=nontarget, -1=spoof
978
+ all_scores = np.concatenate((tar_scores, non_scores, spf_scores))
979
+ labels = np.concatenate((np.ones(tar_scores.size), np.zeros(non_scores.size), -1*np.ones(spf_scores.size)))
980
+
981
+ # Sort labels based on scores
982
+ indices = np.argsort(all_scores, kind='mergesort')
983
+ labels = labels[indices]
984
+
985
+ # Cumulative sums
986
+ tar_sums = np.cumsum(labels==1)
987
+ non_sums = np.cumsum(labels==0)
988
+ spoof_sums = np.cumsum(labels==-1)
989
+
990
+ Pmiss = np.concatenate((np.atleast_1d(0), tar_sums / tar_scores.size))
991
+ Pfa_non = np.concatenate((np.atleast_1d(1), 1 - (non_sums / non_scores.size)))
992
+ Pfa_spoof = np.concatenate((np.atleast_1d(1), 1 - (spoof_sums / spf_scores.size)))
993
+ thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices])) # Thresholds are the sorted scores
994
+
995
+ return Pmiss, Pfa_non, Pfa_spoof, thresholds
996
+
997
+
998
+ def compute_teer(Pmiss_CM, Pfa_CM, tau_CM, Pmiss_ASV, Pfa_non_ASV, Pfa_spf_ASV, tau_ASV):
999
+ # Different spoofing prevalence priors (rho) parameters values
1000
+ rho_vals = [0,0.5,1]
1001
+
1002
+ tEER_val = np.empty([len(rho_vals),len(tau_ASV)], dtype=float)
1003
+
1004
+ for rho_idx, rho_spf in enumerate(rho_vals):
1005
+
1006
+ # Table to store the CM threshold index, per each of the ASV operating points
1007
+ tEER_idx_CM = np.empty(len(tau_ASV), dtype=int)
1008
+
1009
+ tEER_path = np.empty([len(rho_vals),len(tau_ASV),2], dtype=float)
1010
+
1011
+ # Tables to store the t-EER, total Pfa and total miss valuees along the t-EER path
1012
+ Pmiss_total = np.empty(len(tau_ASV), dtype=float)
1013
+ Pfa_total = np.empty(len(tau_ASV), dtype=float)
1014
+ min_tEER = np.inf
1015
+ argmin_tEER = np.empty(2)
1016
+
1017
+ # best intersection point
1018
+ xpoint_crit_best = np.inf
1019
+ xpoint = np.empty(2)
1020
+
1021
+ # Loop over all possible ASV thresholds
1022
+ for tau_ASV_idx, tau_ASV_val in enumerate(tau_ASV):
1023
+
1024
+ # Tandem miss and fa rates as defined in the manuscript
1025
+ Pmiss_tdm = Pmiss_CM + (1 - Pmiss_CM) * Pmiss_ASV[tau_ASV_idx]
1026
+ Pfa_tdm = (1 - rho_spf) * (1 - Pmiss_CM) * Pfa_non_ASV[tau_ASV_idx] + rho_spf * Pfa_CM * Pfa_spf_ASV[tau_ASV_idx]
1027
+
1028
+ # Store only the INDEX of the CM threshold (for the current ASV threshold)
1029
+ h = Pmiss_tdm - Pfa_tdm
1030
+ tmp = np.argmin(abs(h))
1031
+ tEER_idx_CM[tau_ASV_idx] = tmp
1032
+
1033
+ if Pmiss_ASV[tau_ASV_idx] < (1 - rho_spf) * Pfa_non_ASV[tau_ASV_idx] + rho_spf * Pfa_spf_ASV[tau_ASV_idx]:
1034
+ Pmiss_total[tau_ASV_idx] = Pmiss_tdm[tmp]
1035
+ Pfa_total[tau_ASV_idx] = Pfa_tdm[tmp]
1036
+
1037
+ tEER_val[rho_idx,tau_ASV_idx] = np.mean([Pfa_total[tau_ASV_idx], Pmiss_total[tau_ASV_idx]])
1038
+
1039
+ tEER_path[rho_idx,tau_ASV_idx, 0] = tau_ASV_val
1040
+ tEER_path[rho_idx,tau_ASV_idx, 1] = tau_CM[tmp]
1041
+
1042
+ if tEER_val[rho_idx,tau_ASV_idx] < min_tEER:
1043
+ min_tEER = tEER_val[rho_idx,tau_ASV_idx]
1044
+ argmin_tEER[0] = tau_ASV_val
1045
+ argmin_tEER[1] = tau_CM[tmp]
1046
+
1047
+ # Check how close we are to the INTERSECTION POINT for different prior (rho) values:
1048
+ LHS = Pfa_non_ASV[tau_ASV_idx]/Pfa_spf_ASV[tau_ASV_idx]
1049
+ RHS = Pfa_CM[tmp]/(1 - Pmiss_CM[tmp])
1050
+ crit = abs(LHS - RHS)
1051
+
1052
+ if crit < xpoint_crit_best:
1053
+ xpoint_crit_best = crit
1054
+ xpoint[0] = tau_ASV_val
1055
+ xpoint[1] = tau_CM[tmp]
1056
+ xpoint_tEER = Pfa_spf_ASV[tau_ASV_idx]*Pfa_CM[tmp]
1057
+ else:
1058
+ # Not in allowed region
1059
+ tEER_path[rho_idx,tau_ASV_idx, 0] = np.nan
1060
+ tEER_path[rho_idx,tau_ASV_idx, 1] = np.nan
1061
+ Pmiss_total[tau_ASV_idx] = np.nan
1062
+ Pfa_total[tau_ASV_idx] = np.nan
1063
+ tEER_val[rho_idx,tau_ASV_idx] = np.nan
1064
+
1065
+ return xpoint_tEER*100
evaluation/AASIST/S1_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b36eddfdb4fa2c1dbdf00e57e34b83e841218872da6c6d6f97f9616182a9f876
3
+ size 1277933
evaluation/AASIST/S2_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a333d6c7d7a40cfdb25f69d4ac2dd2bc3731ba71ec3adf58e2dd837bbe1eef93
3
+ size 1277933
evaluation/AASIST/S3_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3eaf2873d0b367721d96ea2407539f19e52700eb0c3c8f6dcf16e9603b02739f
3
+ size 1277933
evaluation/AASIST/S4_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29cf1672b9bdde392de88aa875ca7ea915d750d4ac3d8ed5c93c5e691a3939dd
3
+ size 1277933
evaluation/AASIST/S5_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49c34c6bcadfee296cce9b3ff3ea9e0d7852f39c7cef3ad8b02c16ac213c2427
3
+ size 1277933
evaluation/AASIST/__pycache__/AASIST_util.cpython-310.pyc ADDED
Binary file (24.6 kB). View file
 
evaluation/AASIST/__pycache__/AASIST_util.cpython-39.pyc ADDED
Binary file (24.5 kB). View file