rayso Claude Opus 4.5 committed on
Commit
5f16dc0
·
1 Parent(s): f251bec

Add AASIST model with multi-segment analysis

Browse files

- Full AASIST architecture for deepfake detection
- Multi-segment analysis with majority voting
- Improved accuracy for ElevenLabs V3 detection
- Git LFS for model weights

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (5) hide show
  1. .gitattributes +1 -0
  2. .gitignore +1 -0
  3. AASIST.pth +3 -0
  4. aasist_model.py +607 -0
  5. app.py +92 -378
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
AASIST.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51d2d9cf0738172f61e2a384ec50a54a55363240f67c971ed55a92435bc1a1c0
3
+ size 1281532
aasist_model.py ADDED
@@ -0,0 +1,607 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AASIST
3
+ Copyright (c) 2021-present NAVER Corp.
4
+ MIT license
5
+ """
6
+
7
+ import random
8
+ from typing import Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch import Tensor
15
+
16
+
17
+ class GraphAttentionLayer(nn.Module):
18
+ def __init__(self, in_dim, out_dim, **kwargs):
19
+ super().__init__()
20
+
21
+ # attention map
22
+ self.att_proj = nn.Linear(in_dim, out_dim)
23
+ self.att_weight = self._init_new_params(out_dim, 1)
24
+
25
+ # project
26
+ self.proj_with_att = nn.Linear(in_dim, out_dim)
27
+ self.proj_without_att = nn.Linear(in_dim, out_dim)
28
+
29
+ # batch norm
30
+ self.bn = nn.BatchNorm1d(out_dim)
31
+
32
+ # dropout for inputs
33
+ self.input_drop = nn.Dropout(p=0.2)
34
+
35
+ # activate
36
+ self.act = nn.SELU(inplace=True)
37
+
38
+ # temperature
39
+ self.temp = 1.
40
+ if "temperature" in kwargs:
41
+ self.temp = kwargs["temperature"]
42
+
43
+ def forward(self, x):
44
+ '''
45
+ x :(#bs, #node, #dim)
46
+ '''
47
+ # apply input dropout
48
+ x = self.input_drop(x)
49
+
50
+ # derive attention map
51
+ att_map = self._derive_att_map(x)
52
+
53
+ # projection
54
+ x = self._project(x, att_map)
55
+
56
+ # apply batch norm
57
+ x = self._apply_BN(x)
58
+ x = self.act(x)
59
+ return x
60
+
61
+ def _pairwise_mul_nodes(self, x):
62
+ '''
63
+ Calculates pairwise multiplication of nodes.
64
+ - for attention map
65
+ x :(#bs, #node, #dim)
66
+ out_shape :(#bs, #node, #node, #dim)
67
+ '''
68
+
69
+ nb_nodes = x.size(1)
70
+ x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
71
+ x_mirror = x.transpose(1, 2)
72
+
73
+ return x * x_mirror
74
+
75
+ def _derive_att_map(self, x):
76
+ '''
77
+ x :(#bs, #node, #dim)
78
+ out_shape :(#bs, #node, #node, 1)
79
+ '''
80
+ att_map = self._pairwise_mul_nodes(x)
81
+ # size: (#bs, #node, #node, #dim_out)
82
+ att_map = torch.tanh(self.att_proj(att_map))
83
+ # size: (#bs, #node, #node, 1)
84
+ att_map = torch.matmul(att_map, self.att_weight)
85
+
86
+ # apply temperature
87
+ att_map = att_map / self.temp
88
+
89
+ att_map = F.softmax(att_map, dim=-2)
90
+
91
+ return att_map
92
+
93
+ def _project(self, x, att_map):
94
+ x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
95
+ x2 = self.proj_without_att(x)
96
+
97
+ return x1 + x2
98
+
99
+ def _apply_BN(self, x):
100
+ org_size = x.size()
101
+ x = x.view(-1, org_size[-1])
102
+ x = self.bn(x)
103
+ x = x.view(org_size)
104
+
105
+ return x
106
+
107
+ def _init_new_params(self, *size):
108
+ out = nn.Parameter(torch.FloatTensor(*size))
109
+ nn.init.xavier_normal_(out)
110
+ return out
111
+
112
+
113
class HtrgGraphAttentionLayer(nn.Module):
    """Heterogeneous graph attention over two node types plus a master node.

    Takes two node sets (e.g. temporal and spectral graphs), projects each
    with a type-specific linear map, attends over the concatenated set with
    type-pair-specific attention weights (1-1, 2-2, cross 1-2), and updates
    a "master" summary node through a directional attention edge.
    """
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # type-specific projections applied before the two sets are merged
        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)

        # attention map (att_projM handles the master-node edge)
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)

        # one attention weight per node-type pair:
        # 11 = type1->type1, 22 = type2->type2, 12 = cross-type, M = master
        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)

        # project (attended path + direct path, as in GraphAttentionLayer)
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        # master-node projections
        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # softmax temperature; 1.0 unless overridden by the caller
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x1, x2, master=None):
        '''
        x1 :(#bs, #node, #dim)  first node type
        x2 :(#bs, #node, #dim)  second node type
        master : summary node broadcastable to (#bs, 1, #dim); defaults to
                 the mean of all projected nodes when None
        returns (x1, x2, master), with x1/x2 projected to out_dim
        '''
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)

        x1 = self.proj_type1(x1)
        x2 = self.proj_type2(x2)

        # work on the concatenated heterogeneous node set
        x = torch.cat([x1, x2], dim=1)

        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)

        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x, num_type1, num_type2)

        # directional edge for master node (uses pre-projection x)
        master = self._update_master(x, master)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)

        # split the merged node set back into the two types
        x1 = x.narrow(1, 0, num_type1)
        x2 = x.narrow(1, num_type1, num_type2)

        return x1, x2, master

    def _update_master(self, x, master):
        # attend from every node to the master, then project the summary

        att_map = self._derive_att_map_master(x, master)
        master = self._project_master(x, master, att_map)

        return master

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x :(#bs, #node, #dim)
        out_shape :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map_master(self, x, master):
        '''
        Master-node attention: score each node against the master.
        x :(#bs, #node, #dim)
        out_shape :(#bs, #node, 1), softmax-normalised over nodes
        '''
        att_map = x * master
        att_map = torch.tanh(self.att_projM(att_map))

        att_map = torch.matmul(att_map, self.att_weightM)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _derive_att_map(self, x, num_type1, num_type2):
        '''
        x :(#bs, #node, #dim)
        out_shape :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)

        # assemble block-wise scores: each quadrant of the (node x node)
        # board uses the attention weight for its node-type pair
        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)

        att_board[:, :num_type1, :num_type1, :] = torch.matmul(
            att_map[:, :num_type1, :num_type1, :], self.att_weight11)
        att_board[:, num_type1:, num_type1:, :] = torch.matmul(
            att_map[:, num_type1:, num_type1:, :], self.att_weight22)
        att_board[:, :num_type1, num_type1:, :] = torch.matmul(
            att_map[:, :num_type1, num_type1:, :], self.att_weight12)
        att_board[:, num_type1:, :num_type1, :] = torch.matmul(
            att_map[:, num_type1:, :num_type1, :], self.att_weight12)

        att_map = att_board

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        # attended projection plus residual-style direct projection
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _project_master(self, x, master, att_map):
        # weighted sum of nodes -> new master, plus direct master projection

        x1 = self.proj_with_attM(torch.matmul(
            att_map.squeeze(-1).unsqueeze(1), x))
        x2 = self.proj_without_attM(master)

        return x1 + x2

    def _apply_BN(self, x):
        # BatchNorm1d over the flattened (#bs * #node, #dim) view
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        # Xavier-normal-initialised free parameter
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out
283
+
284
+
285
class GraphPool(nn.Module):
    """Score-based graph pooling: keep the top-k fraction of nodes.

    Each node is scored with a learned linear projection + sigmoid; the
    highest-scoring nodes are kept and their features are gated by the
    scores.
    """

    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        # dropout on the scoring input only; identity when p == 0
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        """h: (#bs, #node, #dim) -> (#bs, #node', #dim)."""
        scores = self.sigmoid(self.proj(self.drop(h)))
        return self.top_k_graph(scores, h, self.k)

    def top_k_graph(self, scores, h, k):
        """Select the ceil-to-at-least-one top fraction k of nodes.

        scores: (#bs, #node, 1) attention-based weights
        h:      (#bs, #node, #dim) graph data
        k:      fraction of nodes to keep (float)
        returns (#bs, #node', #dim)
        """
        _, total_nodes, feat_dim = h.size()
        keep = max(int(total_nodes * k), 1)
        _, top_idx = torch.topk(scores, keep, dim=1)
        top_idx = top_idx.expand(-1, -1, feat_dim)

        # gate features by their scores before gathering the survivors
        weighted = h * scores
        return torch.gather(weighted, 1, top_idx)
323
+
324
+
325
class CONV(nn.Module):
    """Fixed sinc band-pass filterbank convolution (SincNet-style front end).

    Builds a bank of Hamming-windowed ideal band-pass filters whose cutoffs
    are spaced uniformly on the mel scale, then applies them with conv1d.
    The filters are precomputed in __init__ and are NOT trainable.
    """

    @staticmethod
    def to_mel(hz):
        """Convert frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert mel-scale value back to Hz."""
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self,
                 out_channels,
                 kernel_size,
                 sample_rate=16000,
                 in_channels=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 groups=1,
                 mask=False):
        """
        out_channels: number of band-pass filters
        kernel_size:  filter length (incremented to the next odd value so
                      filters are symmetric)
        mask:         stored but unused here; masking is controlled per
                      forward() call instead
        """
        super().__init__()
        if in_channels != 1:
            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
                in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e, perfectly symmetrics)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        # mel-spaced band edges over [0, sample_rate/2]
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        self.mel = filbandwidthsf
        # symmetric sample support around 0 for the sinc filters
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)
        # NOTE: plain attribute, not a buffer, so it stays out of the
        # state_dict (matches pretrained checkpoints); moved to the input's
        # device on every forward call.
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            # ideal band-pass = difference of two low-pass sinc filters
            hHigh = (2*fmax/self.sample_rate) * \
                np.sinc(2*fmax*self.hsupp/self.sample_rate)
            hLow = (2*fmin/self.sample_rate) * \
                np.sinc(2*fmin*self.hsupp/self.sample_rate)
            hideal = hHigh - hLow

            self.band_pass[i, :] = Tensor(np.hamming(
                self.kernel_size)) * Tensor(hideal)

    def forward(self, x, mask=False):
        """Apply the filterbank to x: (#bs, 1, #samples).

        mask=True applies frequency-masking augmentation by zeroing a
        random contiguous band of up to 20 filters (non-deterministic).
        """
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            A = int(np.random.uniform(0, 20))
            A0 = random.randint(0, band_pass_filter.shape[0] - A)
            band_pass_filter[A0:A0 + A, :] = 0

        self.filters = band_pass_filter.view(self.out_channels, 1,
                                             self.kernel_size)

        return F.conv1d(x,
                        self.filters,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        bias=None,
                        groups=1)
411
+
412
+
413
class Residual_block(nn.Module):
    """2-D residual block used by the AASIST encoder.

    Pre-activation style (BN + SELU before conv) except for the first
    block, followed by two (2, 3) convolutions, an identity shortcut
    (1x3 conv when the channel count changes), and (1, 3) max pooling.
    """
    def __init__(self, nb_filts, first=False):
        # nb_filts: [in_channels, out_channels]; first=True skips the
        # leading BN/SELU (raw features enter the first block)
        super().__init__()
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)

        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        # 1x3 shortcut conv only when in/out channel counts differ
        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

        else:
            self.downsample = False
        self.mp = nn.MaxPool2d((1, 3))

    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.selu(out)
        else:
            out = x
        # NOTE(review): conv1 is applied to the raw input `x`, not to the
        # bn1/selu output `out` — the pre-activation above is effectively
        # discarded. This matches the reference implementation the
        # pretrained weights were trained with, so it must NOT be "fixed"
        # without retraining.
        out = self.conv1(x)

        out = self.bn2(out)
        out = self.selu(out)
        out = self.conv2(out)
        if self.downsample:
            identity = self.conv_downsample(identity)

        out += identity
        out = self.mp(out)
        return out
467
+
468
+
469
class Model(nn.Module):
    """AASIST anti-spoofing model.

    Pipeline: sinc filterbank front end -> 2-D residual encoder -> separate
    spectral and temporal graph attention branches -> two heterogeneous
    graph-attention "inference" passes with learnable master nodes ->
    element-wise max fusion -> linear readout.

    forward() returns (last_hidden, output): the pooled hidden vector and
    the 2-class logits (bonafide vs. spoof).
    """
    def __init__(self, d_args):
        # d_args keys: "filts", "gat_dims", "pool_ratios", "temperatures",
        # "first_conv" (sinc kernel size)
        super().__init__()

        self.d_args = d_args
        filts = d_args["filts"]
        gat_dims = d_args["gat_dims"]
        pool_ratios = d_args["pool_ratios"]
        temperatures = d_args["temperatures"]

        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=d_args["first_conv"],
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)

        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))

        # learnable positional embedding for the 23 spectral nodes
        self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
        # learnable master-node seeds for the two inference branches
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[1])

        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        # readout over [T_max, T_avg, S_max, S_avg, master] concatenation
        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def forward(self, x, Freq_aug=False):
        # x: (#bs, #samples) raw waveform; Freq_aug enables the sinc
        # filterbank's random frequency masking (training-time augmentation)

        x = x.unsqueeze(1)
        x = self.conv_time(x, mask=Freq_aug)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)

        # get embeddings using encoder
        # (#bs, #filt, #spec, #seq)
        e = self.encoder(x)

        # spectral GAT (GAT-S)
        e_S, _ = torch.max(torch.abs(e), dim=3)  # max along time
        e_S = e_S.transpose(1, 2) + self.pos_S

        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)

        # temporal GAT (GAT-T)
        e_T, _ = torch.max(torch.abs(e), dim=2)  # max along freq
        e_T = e_T.transpose(1, 2)

        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)

        # learnable master node
        master1 = self.master1.expand(x.size(0), -1, -1)
        master2 = self.master2.expand(x.size(0), -1, -1)

        # inference 1
        # NOTE(review): the un-expanded self.master1 is passed here (the
        # expanded tensor above is immediately overwritten by the return
        # value). This matches the reference implementation — the Htrg
        # layer broadcasts the (1, 1, D) master — so leave it as is for
        # checkpoint compatibility.
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
            out_T, out_S, master=self.master1)

        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)

        # second pass refines the pooled graphs; residual-style addition
        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
            out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug

        # inference 2 (parallel branch with its own master node)
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
            out_T, out_S, master=self.master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
            out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug

        # drop-path style regularisation on each branch
        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        # fuse the two branches element-wise
        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        # node-wise max/mean readout per graph
        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)

        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)

        last_hidden = self.drop(last_hidden)
        output = self.out_layer(last_hidden)

        return last_hidden, output
app.py CHANGED
@@ -1,20 +1,15 @@
1
  """
2
  VoiceDetector - Forensic Deepfake Audio Detection
3
- Hugging Face Spaces Version
4
-
5
- Powered by AASIST (EER: 0.83% on ASVspoof 2019 LA)
6
  """
7
 
8
  import os
9
  import sys
10
- import json
11
  import time
12
- from datetime import datetime
13
 
14
  import gradio as gr
15
  import numpy as np
16
  import torch
17
- import torch.nn as nn
18
  import librosa
19
  import librosa.display
20
  import matplotlib
@@ -23,342 +18,8 @@ import matplotlib.pyplot as plt
23
  from PIL import Image
24
  import io
25
 
26
- # ============================================
27
- # AASIST Model Definition
28
- # ============================================
29
-
30
- class GraphAttentionLayer(nn.Module):
31
- def __init__(self, in_dim, out_dim, **kwargs):
32
- super().__init__()
33
- self.att_proj = nn.Linear(in_dim, out_dim)
34
- self.att_weight = nn.Parameter(torch.Tensor(out_dim, 1))
35
- nn.init.xavier_uniform_(self.att_weight)
36
- self.proj_with_att = nn.Linear(in_dim, out_dim)
37
- self.proj_without_att = nn.Linear(in_dim, out_dim)
38
- self.bn = nn.BatchNorm1d(out_dim)
39
- self.input_drop = nn.Dropout(p=0.2)
40
- self.act = nn.SELU(inplace=True)
41
- self.temp = kwargs.get("temperature", 1.0)
42
-
43
- def forward(self, x):
44
- x = self.input_drop(x)
45
- att_map = self._derive_att_map(x)
46
- x = self._project(x, att_map)
47
- x = self._apply_BN(x)
48
- x = self.act(x)
49
- return x
50
-
51
- def _pairwise_mul_nodes(self, x):
52
- nb_nodes = x.size(1)
53
- x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
54
- x_mirror = x.transpose(1, 2)
55
- return x * x_mirror
56
-
57
- def _derive_att_map(self, x):
58
- att_map = self._pairwise_mul_nodes(x)
59
- att_map = torch.tanh(self.att_proj(att_map))
60
- att_map = torch.matmul(att_map, self.att_weight)
61
- att_map = att_map / self.temp
62
- att_map = torch.softmax(att_map, dim=-2)
63
- return att_map
64
-
65
- def _project(self, x, att_map):
66
- x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
67
- x2 = self.proj_without_att(x)
68
- return x1 + x2
69
-
70
- def _apply_BN(self, x):
71
- org_size = x.size()
72
- x = x.view(-1, org_size[-1])
73
- x = self.bn(x)
74
- x = x.view(org_size)
75
- return x
76
-
77
-
78
- class HtrgGraphAttentionLayer(nn.Module):
79
- def __init__(self, in_dim, out_dim, **kwargs):
80
- super().__init__()
81
- self.proj_type1 = nn.Linear(in_dim, in_dim)
82
- self.proj_type2 = nn.Linear(in_dim, in_dim)
83
- self.att_proj = nn.Linear(in_dim, out_dim)
84
- self.att_weight = nn.Parameter(torch.Tensor(out_dim, 1))
85
- nn.init.xavier_uniform_(self.att_weight)
86
- self.proj_with_att = nn.Linear(in_dim, out_dim)
87
- self.proj_without_att = nn.Linear(in_dim, out_dim)
88
- self.bn = nn.BatchNorm1d(out_dim)
89
- self.input_drop = nn.Dropout(p=0.2)
90
- self.act = nn.SELU(inplace=True)
91
- self.temp = kwargs.get("temperature", 1.0)
92
-
93
- def forward(self, x1, x2, master=None):
94
- num_type1 = x1.size(1)
95
- if master is None:
96
- x = torch.cat([x1, x2], dim=1)
97
- else:
98
- x = torch.cat([x1, x2, master], dim=1)
99
- x = self.input_drop(x)
100
- x_type1 = self.proj_type1(x)
101
- x_type2 = self.proj_type2(x)
102
- att_map = self._derive_att_map(x_type1, x_type2)
103
- x = self._project(x, att_map)
104
- x = self._apply_BN(x)
105
- x = self.act(x)
106
- x1 = x[:, :num_type1, :]
107
- x2 = x[:, num_type1:, :]
108
- return x1, x2
109
-
110
- def _pairwise_mul_nodes(self, x1, x2):
111
- nb_nodes = x1.size(1) + x2.size(1)
112
- x = torch.cat([x1, x2], dim=1)
113
- x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
114
- x_mirror = x.transpose(1, 2)
115
- return x * x_mirror
116
-
117
- def _derive_att_map(self, x1, x2):
118
- att_map = self._pairwise_mul_nodes(x1, x2)
119
- att_map = torch.tanh(self.att_proj(att_map))
120
- att_map = torch.matmul(att_map, self.att_weight)
121
- att_map = att_map / self.temp
122
- att_map = torch.softmax(att_map, dim=-2)
123
- return att_map
124
-
125
- def _project(self, x, att_map):
126
- x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
127
- x2 = self.proj_without_att(x)
128
- return x1 + x2
129
-
130
- def _apply_BN(self, x):
131
- org_size = x.size()
132
- x = x.view(-1, org_size[-1])
133
- x = self.bn(x)
134
- x = x.view(org_size)
135
- return x
136
-
137
-
138
- class GraphPool(nn.Module):
139
- def __init__(self, k, in_dim, p):
140
- super().__init__()
141
- self.k = k
142
- self.sigmoid = nn.Sigmoid()
143
- self.proj = nn.Linear(in_dim, 1)
144
- self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
145
-
146
- def forward(self, h):
147
- Z = self.drop(h)
148
- weights = self.proj(Z).squeeze(-1)
149
- scores = self.sigmoid(weights)
150
- _, idx = torch.topk(scores, max(2, int(self.k * h.size(1))))
151
- new_h = h[:, idx, :]
152
- return new_h
153
-
154
-
155
- class CONV(nn.Module):
156
- @staticmethod
157
- def to_mel(hz):
158
- return 2595 * np.log10(1 + hz / 700)
159
-
160
- @staticmethod
161
- def to_hz(mel):
162
- return 700 * (10 ** (mel / 2595) - 1)
163
-
164
- def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1,
165
- stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50):
166
- super().__init__()
167
- self.out_channels = out_channels
168
- self.kernel_size = kernel_size
169
- self.sample_rate = sample_rate
170
- self.min_low_hz = min_low_hz
171
- self.min_band_hz = min_band_hz
172
-
173
- low_hz = 30
174
- high_hz = sample_rate / 2 - (min_low_hz + min_band_hz)
175
- mel = np.linspace(self.to_mel(low_hz), self.to_mel(high_hz), out_channels + 1)
176
- hz = self.to_hz(mel)
177
-
178
- self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
179
- self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))
180
-
181
- n_lin = torch.linspace(0, (kernel_size / 2) - 1, steps=kernel_size // 2)
182
- self.window_ = 0.54 - 0.46 * torch.cos(2 * np.pi * n_lin / kernel_size)
183
- n = (kernel_size - 1) / 2.0
184
- self.n_ = 2 * np.pi * torch.arange(-n, 0).view(1, -1) / sample_rate
185
-
186
- self.stride = stride
187
- self.padding = padding
188
- self.dilation = dilation
189
-
190
- def forward(self, x):
191
- self.n_ = self.n_.to(x.device)
192
- self.window_ = self.window_.to(x.device)
193
-
194
- low = self.min_low_hz + torch.abs(self.low_hz_)
195
- high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate / 2)
196
- band = (high - low)[:, 0]
197
-
198
- f_times_t_low = torch.matmul(low, self.n_)
199
- f_times_t_high = torch.matmul(high, self.n_)
200
-
201
- band_pass_left = ((torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) / (self.n_ / 2)) * self.window_
202
- band_pass_center = 2 * band.view(-1, 1)
203
- band_pass_right = torch.flip(band_pass_left, dims=[1])
204
- band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1)
205
- band_pass = band_pass / (2 * band[:, None])
206
- self.filters = band_pass.view(self.out_channels, 1, self.kernel_size)
207
-
208
- return torch.nn.functional.conv1d(x, self.filters, stride=self.stride,
209
- padding=self.padding, dilation=self.dilation, bias=None, groups=1)
210
-
211
-
212
- class Residual_block(nn.Module):
213
- def __init__(self, nb_filts, first=False):
214
- super().__init__()
215
- self.first = first
216
-
217
- if not first:
218
- self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
219
- self.conv1 = nn.Conv2d(in_channels=nb_filts[0], out_channels=nb_filts[1],
220
- kernel_size=(2, 3), padding=(1, 1), stride=1)
221
- self.selu = nn.SELU(inplace=True)
222
- self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
223
- self.conv2 = nn.Conv2d(in_channels=nb_filts[1], out_channels=nb_filts[1],
224
- kernel_size=(2, 3), padding=(0, 1), stride=1)
225
-
226
- if nb_filts[0] != nb_filts[1]:
227
- self.downsample = True
228
- self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0], out_channels=nb_filts[1],
229
- padding=(0, 1), kernel_size=(1, 3), stride=1)
230
- else:
231
- self.downsample = False
232
- self.mp = nn.MaxPool2d((1, 3))
233
-
234
- def forward(self, x):
235
- identity = x
236
- if not self.first:
237
- out = self.bn1(x)
238
- out = self.selu(out)
239
- else:
240
- out = x
241
- out = self.conv1(x)
242
- out = self.bn2(out)
243
- out = self.selu(out)
244
- out = self.conv2(out)
245
-
246
- if self.downsample:
247
- identity = self.conv_downsample(identity)
248
- out += identity
249
- out = self.mp(out)
250
- return out
251
-
252
-
253
- class AASISTModel(nn.Module):
254
- def __init__(self, d_args):
255
- super().__init__()
256
-
257
- filts = d_args.get("filts", [70, [1, 32], [32, 32], [32, 64], [64, 64]])
258
- gat_dims = d_args.get("gat_dims", [64, 32])
259
- pool_ratios = d_args.get("pool_ratios", [0.5, 0.7, 0.5, 0.5])
260
- temperatures = d_args.get("temperatures", [2.0, 2.0, 100.0, 100.0])
261
-
262
- self.conv_time = CONV(out_channels=filts[0], kernel_size=128, in_channels=1)
263
- self.first_bn = nn.BatchNorm2d(num_features=1)
264
- self.selu = nn.SELU(inplace=True)
265
-
266
- self.encoder = nn.Sequential(
267
- nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
268
- nn.Sequential(Residual_block(nb_filts=filts[2])),
269
- nn.Sequential(Residual_block(nb_filts=filts[3])),
270
- nn.Sequential(Residual_block(nb_filts=filts[4])),
271
- nn.Sequential(Residual_block(nb_filts=filts[4])),
272
- nn.Sequential(Residual_block(nb_filts=filts[4]))
273
- )
274
-
275
- self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
276
- self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
277
- self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
278
-
279
- self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1], gat_dims[0], temperature=temperatures[0])
280
- self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1], gat_dims[0], temperature=temperatures[1])
281
-
282
- self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(gat_dims[0], gat_dims[1], temperature=temperatures[2])
283
- self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(gat_dims[1], gat_dims[1], temperature=temperatures[2])
284
- self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(gat_dims[0], gat_dims[1], temperature=temperatures[3])
285
- self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(gat_dims[1], gat_dims[1], temperature=temperatures[3])
286
-
287
- self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
288
- self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
289
- self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
290
- self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
291
- self.pool_hS2 = GraphPool(pool_ratios[3], gat_dims[1], 0.3)
292
- self.pool_hT2 = GraphPool(pool_ratios[3], gat_dims[1], 0.3)
293
-
294
- self.out_layer = nn.Linear(5 * gat_dims[1], 2)
295
- self.drop = nn.Dropout(0.5)
296
- self.drop_way = nn.Dropout(0.2)
297
-
298
- def forward(self, x):
299
- x = x.unsqueeze(1)
300
- x = self.conv_time(x)
301
- x = x.unsqueeze(1)
302
- x = torch.abs(x)
303
- x = self.first_bn(x)
304
- x = self.selu(x)
305
-
306
- e = self.encoder(x)
307
-
308
- e_S = e.mean(dim=3).transpose(1, 2) + self.pos_S
309
- e_T = e.mean(dim=2).transpose(1, 2)
310
-
311
- gat_S = self.GAT_layer_S(e_S)
312
- gat_T = self.GAT_layer_T(e_T)
313
-
314
- out_S = self.pool_S(gat_S)
315
- out_T = self.pool_T(gat_T)
316
-
317
- master1 = self.master1.expand(x.size(0), -1, -1)
318
- master2 = self.master2.expand(x.size(0), -1, -1)
319
-
320
- out_T1, out_S1 = self.HtrgGAT_layer_ST11(out_T, out_S, master=master1)
321
- out_S1 = self.pool_hS1(out_S1)
322
- out_T1 = self.pool_hT1(out_T1)
323
- out_T_branch, out_S_branch = self.HtrgGAT_layer_ST12(out_T1, out_S1, master=None)
324
- out_S_branch = self.pool_hS2(out_S_branch)
325
- out_T_branch = self.pool_hT2(out_T_branch)
326
-
327
- out_T2, out_S2 = self.HtrgGAT_layer_ST21(out_T, out_S, master=master2)
328
- out_S2 = self.pool_hS1(out_S2)
329
- out_T2 = self.pool_hT1(out_T2)
330
- out_T_branch2, out_S_branch2 = self.HtrgGAT_layer_ST22(out_T2, out_S2, master=None)
331
- out_S_branch2 = self.pool_hS2(out_S_branch2)
332
- out_T_branch2 = self.pool_hT2(out_T_branch2)
333
-
334
- out_T_branch = self.drop_way(out_T_branch)
335
- out_S_branch = self.drop_way(out_S_branch)
336
- out_T_branch2 = self.drop_way(out_T_branch2)
337
- out_S_branch2 = self.drop_way(out_S_branch2)
338
- master1 = self.drop_way(master1)
339
- master2 = self.drop_way(master2)
340
-
341
- T_max, _ = out_T_branch.max(dim=1)
342
- T_avg = out_T_branch.mean(dim=1)
343
- S_max, _ = out_S_branch.max(dim=1)
344
- S_avg = out_S_branch.mean(dim=1)
345
- T_max2, _ = out_T_branch2.max(dim=1)
346
- T_avg2 = out_T_branch2.mean(dim=1)
347
- S_max2, _ = out_S_branch2.max(dim=1)
348
- S_avg2 = out_S_branch2.mean(dim=1)
349
- master1_max, _ = master1.max(dim=1)
350
- master2_max, _ = master2.max(dim=1)
351
-
352
- out = torch.cat([T_max, T_avg, S_max, S_avg, T_max2 + master1_max + S_avg2,
353
- T_avg2 + master2_max + S_max2, (T_max + T_avg + S_max + S_avg) / 4,
354
- (T_max2 + T_avg2 + S_max2 + S_avg2 + master1_max + master2_max) / 6,
355
- T_max - T_max2, S_max - S_max2], dim=1)
356
-
357
- out = out[:, :5 * 32]
358
- out = self.drop(out)
359
- out = self.out_layer(out)
360
- return out
361
-
362
 
363
  # ============================================
364
  # Detector Class
@@ -368,9 +29,13 @@ class AASISTDetector:
368
  def __init__(self):
369
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
370
  self.sample_rate = 16000
371
- self.max_length = 64600
372
 
 
373
  self.model_config = {
 
 
 
374
  "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
375
  "gat_dims": [64, 32],
376
  "pool_ratios": [0.5, 0.7, 0.5, 0.5],
@@ -381,52 +46,81 @@ class AASISTDetector:
381
  self._load_weights()
382
  self.model.eval()
383
  print(f"[AASIST] Loaded on {self.device}")
 
384
 
385
  def _load_weights(self):
386
- import urllib.request
387
-
388
- weights_path = "AASIST.pth"
389
 
390
  if not os.path.exists(weights_path):
391
- print("[AASIST] Downloading weights from GitHub...")
392
- try:
393
- url = "https://github.com/clovaai/aasist/releases/download/v1.0/AASIST.pth"
394
- urllib.request.urlretrieve(url, weights_path)
395
- print(f"[AASIST] Downloaded successfully")
396
- except Exception as e:
397
- print(f"[AASIST] Download failed: {e}")
398
- return
399
-
400
- if os.path.exists(weights_path):
401
- checkpoint = torch.load(weights_path, map_location=self.device, weights_only=False)
402
- if 'model' in checkpoint:
403
- self.model.load_state_dict(checkpoint['model'], strict=False)
404
- else:
405
- self.model.load_state_dict(checkpoint, strict=False)
406
- print(f"[AASIST] Weights loaded")
407
 
408
  def analyze(self, audio_path):
409
  start_time = time.time()
410
 
 
411
  audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True)
 
412
 
 
413
  if np.max(np.abs(audio)) > 0:
414
  audio = audio / np.max(np.abs(audio))
415
 
416
- if len(audio) > self.max_length:
417
- start = (len(audio) - self.max_length) // 2
418
- audio = audio[start:start + self.max_length]
419
- else:
420
- audio = np.pad(audio, (0, self.max_length - len(audio)), mode='constant')
421
-
422
- audio_tensor = torch.FloatTensor(audio).unsqueeze(0).to(self.device)
423
 
424
- with torch.no_grad():
425
- output = self.model(audio_tensor)
426
- probs = torch.softmax(output, dim=1)
427
- prob_genuine = probs[0, 0].item()
428
- prob_deepfake = probs[0, 1].item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
 
430
  if prob_deepfake >= 0.60:
431
  prediction = "DEEPFAKE"
432
  confidence = prob_deepfake
@@ -443,9 +137,24 @@ class AASISTDetector:
443
  'prob_genuine': prob_genuine * 100,
444
  'prob_deepfake': prob_deepfake * 100,
445
  'processing_time_ms': (time.time() - start_time) * 1000,
446
- 'duration': len(audio) / self.sample_rate
 
 
 
447
  }
448
 
 
 
 
 
 
 
 
 
 
 
 
 
449
 
450
  # ============================================
451
  # Visualization
@@ -487,7 +196,7 @@ def create_spectrogram(audio_path):
487
  plt.close(fig)
488
  return img
489
  except Exception as e:
490
- print(f"Error: {e}")
491
  return None
492
 
493
 
@@ -556,10 +265,12 @@ def analyze_audio(audio_file):
556
  | **Confianza** | {confidence:.1f}% |
557
  | **Prob. Genuino** | {result['prob_genuine']:.1f}% |
558
  | **Prob. Deepfake** | {result['prob_deepfake']:.1f}% |
 
 
559
  | **Tiempo** | {result['processing_time_ms']:.0f}ms |
560
  | **Duracion** | {result['duration']:.1f}s |
561
 
562
- **Modelo:** AASIST (EER: 0.83%)
563
  """
564
 
565
  spectrogram = create_spectrogram(audio_path)
@@ -568,6 +279,9 @@ def analyze_audio(audio_file):
568
  return pred_display, summary, spectrogram, confidence_chart
569
 
570
  except Exception as e:
 
 
 
571
  return f"Error: {str(e)}", "", None, None
572
 
573
 
@@ -613,4 +327,4 @@ with gr.Blocks(title="VoiceDetector", theme=gr.themes.Soft(primary_hue="blue"))
613
  outputs=[prediction_output, summary_output, spectrogram_output, confidence_output])
614
 
615
  if __name__ == "__main__":
616
- app.launch()
 
1
  """
2
  VoiceDetector - Forensic Deepfake Audio Detection
3
+ Using original AASIST model (EER: 0.83% on ASVspoof 2019 LA)
 
 
4
  """
5
 
6
  import os
7
  import sys
 
8
  import time
 
9
 
10
  import gradio as gr
11
  import numpy as np
12
  import torch
 
13
  import librosa
14
  import librosa.display
15
  import matplotlib
 
18
  from PIL import Image
19
  import io
20
 
21
+ # Import original AASIST model
22
+ from aasist_model import Model as AASISTModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # ============================================
25
  # Detector Class
 
29
  def __init__(self):
30
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
31
  self.sample_rate = 16000
32
+ self.max_length = 64600 # ~4 seconds
33
 
34
+ # Original AASIST config
35
  self.model_config = {
36
+ "architecture": "AASIST",
37
+ "nb_samp": 64600,
38
+ "first_conv": 128,
39
  "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
40
  "gat_dims": [64, 32],
41
  "pool_ratios": [0.5, 0.7, 0.5, 0.5],
 
46
  self._load_weights()
47
  self.model.eval()
48
  print(f"[AASIST] Loaded on {self.device}")
49
+ print(f"[AASIST] Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
50
 
51
  def _load_weights(self):
52
+ weights_path = os.path.join(os.path.dirname(__file__), "AASIST.pth")
 
 
53
 
54
  if not os.path.exists(weights_path):
55
+ print(f"[AASIST] ERROR: Weights not found at {weights_path}")
56
+ return
57
+
58
+ checkpoint = torch.load(weights_path, map_location=self.device, weights_only=False)
59
+ self.model.load_state_dict(checkpoint, strict=False)
60
+ print(f"[AASIST] Weights loaded from {weights_path}")
 
 
 
 
 
 
 
 
 
 
61
 
62
  def analyze(self, audio_path):
63
  start_time = time.time()
64
 
65
+ # Load audio
66
  audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True)
67
+ original_duration = len(audio) / self.sample_rate
68
 
69
+ # Normalize
70
  if np.max(np.abs(audio)) > 0:
71
  audio = audio / np.max(np.abs(audio))
72
 
73
+ # Multi-segment analysis for better detection
74
+ # Analyze multiple segments and use weighted voting
75
+ segment_results = []
 
 
 
 
76
 
77
+ if len(audio) <= self.max_length:
78
+ # Short audio: analyze as single segment
79
+ padded = np.pad(audio, (0, self.max_length - len(audio)), mode='constant')
80
+ segment_results.append(self._analyze_segment(padded))
81
+ else:
82
+ # Long audio: analyze multiple overlapping segments
83
+ # Sample from beginning, middle, and end for comprehensive coverage
84
+ step = self.max_length // 2 # 50% overlap
85
+
86
+ for i in range(0, len(audio) - self.max_length + 1, step):
87
+ segment = audio[i:i + self.max_length]
88
+ segment_results.append(self._analyze_segment(segment))
89
+
90
+ # Also analyze the last segment if we haven't covered the end
91
+ if len(audio) - self.max_length > (len(segment_results) - 1) * step:
92
+ segment = audio[-self.max_length:]
93
+ segment_results.append(self._analyze_segment(segment))
94
+
95
+ # Aggregate results with balanced approach
96
+ all_genuine = [r[0] for r in segment_results]
97
+ all_deepfake = [r[1] for r in segment_results]
98
+
99
+ max_deepfake = max(all_deepfake)
100
+ avg_deepfake = np.mean(all_deepfake)
101
+ avg_genuine = np.mean(all_genuine)
102
+
103
+ # Count how many segments are deepfake vs genuine
104
+ n_deepfake_segs = sum(1 for d in all_deepfake if d > 0.6)
105
+ n_genuine_segs = sum(1 for g in all_genuine if g > 0.6)
106
+ total_segs = len(segment_results)
107
+
108
+ # Majority voting with average as tiebreaker
109
+ # If majority of segments agree, use that
110
+ if n_deepfake_segs > total_segs * 0.5:
111
+ # More than half segments are deepfake
112
+ prob_deepfake = 0.6 * max_deepfake + 0.4 * avg_deepfake
113
+ prob_genuine = 1.0 - prob_deepfake
114
+ elif n_genuine_segs > total_segs * 0.5:
115
+ # More than half segments are genuine
116
+ prob_genuine = avg_genuine
117
+ prob_deepfake = avg_deepfake
118
+ else:
119
+ # Mixed results - use weighted average
120
+ prob_deepfake = 0.5 * max_deepfake + 0.5 * avg_deepfake
121
+ prob_genuine = 1.0 - prob_deepfake
122
 
123
+ # Prediction thresholds
124
  if prob_deepfake >= 0.60:
125
  prediction = "DEEPFAKE"
126
  confidence = prob_deepfake
 
137
  'prob_genuine': prob_genuine * 100,
138
  'prob_deepfake': prob_deepfake * 100,
139
  'processing_time_ms': (time.time() - start_time) * 1000,
140
+ 'duration': original_duration,
141
+ 'segments_analyzed': len(segment_results),
142
+ 'max_deepfake_segment': max_deepfake * 100,
143
+ 'avg_deepfake': avg_deepfake * 100
144
  }
145
 
146
+ def _analyze_segment(self, audio_segment):
147
+ """Analyze a single audio segment and return (prob_genuine, prob_deepfake)"""
148
+ audio_tensor = torch.FloatTensor(audio_segment).unsqueeze(0).to(self.device)
149
+
150
+ with torch.no_grad():
151
+ _, output = self.model(audio_tensor)
152
+ probs = torch.softmax(output, dim=1)
153
+ prob_genuine = probs[0, 0].item()
154
+ prob_deepfake = probs[0, 1].item()
155
+
156
+ return (prob_genuine, prob_deepfake)
157
+
158
 
159
  # ============================================
160
  # Visualization
 
196
  plt.close(fig)
197
  return img
198
  except Exception as e:
199
+ print(f"Error creating spectrogram: {e}")
200
  return None
201
 
202
 
 
265
  | **Confianza** | {confidence:.1f}% |
266
  | **Prob. Genuino** | {result['prob_genuine']:.1f}% |
267
  | **Prob. Deepfake** | {result['prob_deepfake']:.1f}% |
268
+ | **Segmentos analizados** | {result.get('segments_analyzed', 1)} |
269
+ | **Max Deepfake (segmento)** | {result.get('max_deepfake_segment', result['prob_deepfake']):.1f}% |
270
  | **Tiempo** | {result['processing_time_ms']:.0f}ms |
271
  | **Duracion** | {result['duration']:.1f}s |
272
 
273
+ **Modelo:** AASIST (Multi-segment analysis)
274
  """
275
 
276
  spectrogram = create_spectrogram(audio_path)
 
279
  return pred_display, summary, spectrogram, confidence_chart
280
 
281
  except Exception as e:
282
+ import traceback
283
+ error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
284
+ print(error_msg)
285
  return f"Error: {str(e)}", "", None, None
286
 
287
 
 
327
  outputs=[prediction_output, summary_output, spectrogram_output, confidence_output])
328
 
329
  if __name__ == "__main__":
330
+ app.launch(server_name="0.0.0.0", server_port=7860)