MeysamSh commited on
Commit
c49e455
·
1 Parent(s): 56053bc

Add application file

Browse files
AASIST_ASVspoof5_Exp4_CL.conf ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "database_path": "/lium/corpus/vrac/asini/deepfake_dataset/ASVspoof5_2024/",
3
+ "train_path": "ASVspoof5.train.metadata.txt",
4
+ "dev_path": "ASVspoof5.dev.metadata.txt",
5
+ "model_path": "./models/weights/AASIST/Exp4_CL/best.pth",
6
+ "score_file_dir":"exp_result/AASIST_ASVspoof5_Exp4_eval_train_ep50_bs64/eval_scores_using_best_dev_model_onTrain.txt",
7
+ "split_num":5,
8
+ "accumulating":"False",
9
+ "re_init_optim":"False",
10
+ "train_wav_path":"flac_T/",
11
+ "dev_wav_path":"flac_D/",
12
+ "debug_mode": "False",
13
+ "batch_size": 64,
14
+ "num_epochs": 20,
15
+ "loss": "CCE",
16
+ "track": "LA",
17
+ "eval_all_best": "True",
18
+ "eval_output": "eval_scores_using_best_dev_model.txt",
19
+ "cudnn_deterministic_toggle": "True",
20
+ "cudnn_benchmark_toggle": "False",
21
+ "model_config": {
22
+ "architecture": "AASIST",
23
+ "nb_samp": 64600,
24
+ "first_conv": 128,
25
+ "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
26
+ "gat_dims": [64, 32],
27
+ "pool_ratios": [0.5, 0.7, 0.5, 0.5],
28
+ "temperatures": [2.0, 2.0, 100.0, 100.0],
29
+ "output_cls": 9
30
+ },
31
+ "optim_config": {
32
+ "optimizer": "adam",
33
+ "amsgrad": "False",
34
+ "base_lr": 0.0001,
35
+ "lr_min": 0.000005,
36
+ "betas": [0.9, 0.999],
37
+ "weight_decay": 0.0001,
38
+ "scheduler": "cosine"
39
+ }
40
+ }
Web/index.html ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>Audio Analysis API</title>
8
+ <link rel="stylesheet" href="styles.css">
9
+
10
+ <!-- Bootstrap CSS -->
11
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
12
+ <style>
13
+ body {
14
+ background-color: #f8f9fa;
15
+ padding: 20px;
16
+ }
17
+
18
+ .container {
19
+ max-width: 800px;
20
+ margin: 0 auto;
21
+ background: #fff;
22
+ padding: 30px;
23
+ border-radius: 10px;
24
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
25
+ }
26
+
27
+ h1 {
28
+ text-align: center;
29
+ margin-bottom: 20px;
30
+ color: #333;
31
+ font-weight: bold;
32
+ }
33
+
34
+ h2 {
35
+ color: #555;
36
+ margin-bottom: 20px;
37
+ font-size: 1.5rem;
38
+ }
39
+
40
+ .btn {
41
+ margin: 5px;
42
+ font-weight: 500;
43
+ }
44
+
45
+ #recordingsList {
46
+ margin-top: 20px;
47
+ }
48
+
49
+ .response {
50
+ margin-top: 20px;
51
+ padding: 15px;
52
+ background-color: #e9ecef;
53
+ border-radius: 5px;
54
+ color: #333;
55
+ font-size: 1.1rem;
56
+ }
57
+
58
+ .metadata {
59
+ margin-top: 20px;
60
+ padding: 15px;
61
+ background-color: #f1f3f4;
62
+ border-radius: 5px;
63
+ color: #333;
64
+ font-size: 1.1rem;
65
+ }
66
+
67
+ .list-group-item {
68
+ display: flex;
69
+ justify-content: space-between;
70
+ align-items: center;
71
+ }
72
+
73
+ .list-group-item a {
74
+ text-decoration: none;
75
+ color: #0d6efd;
76
+ }
77
+
78
+ .list-group-item a:hover {
79
+ text-decoration: underline;
80
+ }
81
+
82
+ #controls {
83
+ margin-bottom: 20px;
84
+ }
85
+
86
+ #formats {
87
+ font-size: 0.9rem;
88
+ color: #666;
89
+ margin-bottom: 10px;
90
+ }
91
+ </style>
92
+ </head>
93
+
94
+ <body>
95
+ <div class="container">
96
+ <h1>Audio Analysis API</h1>
97
+ <h2>Upload or Record Audio Files</h2>
98
+
99
+ <!-- Form for Uploading Files -->
100
+ <form id="upload-form" class="mb-4">
101
+ <div class="mb-3">
102
+ <input type="file" id="audio-file" class="form-control" accept="audio/*" multiple />
103
+ </div>
104
+ <button type="button" id="upload-button" class="btn btn-primary w-100">Upload & Analyze</button>
105
+ </form>
106
+
107
+ <hr>
108
+
109
+ <!-- Buttons for Recording Audio -->
110
+ <div id="controls" class="mb-4 text-center">
111
+ <button id="recordButton" class="btn btn-success">Record</button>
112
+ <button id="pauseButton" class="btn btn-warning" disabled>Pause</button>
113
+ <button id="stopButton" class="btn btn-danger" disabled>Stop</button>
114
+ </div>
115
+ <div id="formats" class="mb-3 text-center">Format: Start recording to see sample rate</div>
116
+ <p class="text-center"><strong>Recordings:</strong></p>
117
+ <ol id="recordingsList" class="list-group"></ol>
118
+
119
+ <!-- Metadata Display -->
120
+ <div class="metadata mt-4">
121
+ <h3>File Metadata</h3>
122
+
123
+ <!-- Dropdown Filters -->
124
+ <div class="mb-3 d-flex flex-wrap gap-3">
125
+ <i>Choisir un Label</i>
126
+ <select id="filter-label" class="form-select">
127
+ <option value="">All Labels</option>
128
+ </select>
129
+ <i>Choisir un System</i>
130
+ <select id="filter-system" class="form-select">
131
+ <option value="">All Systems</option>
132
+ </select>
133
+ <i>Choisir un Codec</i>
134
+ <select id="filter-codec" class="form-select">
135
+ <option value="">All Codecs</option>
136
+ </select>
137
+ <i>Choisir un Genre</i>
138
+ <select id="filter-genre" class="form-select">
139
+ <option value="">All Genres</option>
140
+ </select>
141
+ <i>Choisir une Année</i>
142
+ <select id="filter-year" class="form-select">
143
+ <option value="">All Years</option>
144
+ </select>
145
+ </div>
146
+
147
+ <div id="metadata-display"></div>
148
+ </div>
149
+
150
+ <!-- Response Display -->
151
+ <div class="response mt-4">
152
+ <h3>Analysis Results</h3>
153
+ <div id="response"></div>
154
+ </div>
155
+ </div>
156
+
157
+ <!-- Load Recorder.js and your script.js -->
158
+ <script src="recorder.js"></script>
159
+ <script src="script.js"></script>
160
+ </body>
161
+
162
+ </html>
Web/recorder.js ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ (function(f){if(typeof exports==="object"&&typeof module!=="undefined"){module.exports=f()}else if(typeof define==="function"&&define.amd){define([],f)}else{var g;if(typeof window!=="undefined"){g=window}else if(typeof global!=="undefined"){g=global}else if(typeof self!=="undefined"){g=self}else{g=this}g.Recorder = f()}})(function(){var define,module,exports;return (function e(t,n,r){function s(o,u){if(!n[o]){if(!t[o]){var a=typeof require=="function"&&require;if(!u&&a)return a(o,!0);if(i)return i(o,!0);var f=new Error("Cannot find module '"+o+"'");throw f.code="MODULE_NOT_FOUND",f}var l=n[o]={exports:{}};t[o][0].call(l.exports,function(e){var n=t[o][1][e];return s(n?n:e)},l,l.exports,e,t,n,r)}return n[o].exports}var i=typeof require=="function"&&require;for(var o=0;o<r.length;o++)s(r[o]);return s})({1:[function(require,module,exports){
2
+ "use strict";
3
+
4
+ module.exports = require("./recorder").Recorder;
5
+
6
+ },{"./recorder":2}],2:[function(require,module,exports){
7
+ 'use strict';
8
+
9
+ var _createClass = (function () {
10
+ function defineProperties(target, props) {
11
+ for (var i = 0; i < props.length; i++) {
12
+ var descriptor = props[i];descriptor.enumerable = descriptor.enumerable || false;descriptor.configurable = true;if ("value" in descriptor) descriptor.writable = true;Object.defineProperty(target, descriptor.key, descriptor);
13
+ }
14
+ }return function (Constructor, protoProps, staticProps) {
15
+ if (protoProps) defineProperties(Constructor.prototype, protoProps);if (staticProps) defineProperties(Constructor, staticProps);return Constructor;
16
+ };
17
+ })();
18
+
19
+ Object.defineProperty(exports, "__esModule", {
20
+ value: true
21
+ });
22
+ exports.Recorder = undefined;
23
+
24
+ var _inlineWorker = require('inline-worker');
25
+
26
+ var _inlineWorker2 = _interopRequireDefault(_inlineWorker);
27
+
28
+ function _interopRequireDefault(obj) {
29
+ return obj && obj.__esModule ? obj : { default: obj };
30
+ }
31
+
32
+ function _classCallCheck(instance, Constructor) {
33
+ if (!(instance instanceof Constructor)) {
34
+ throw new TypeError("Cannot call a class as a function");
35
+ }
36
+ }
37
+
38
+ var Recorder = exports.Recorder = (function () {
39
+ function Recorder(source, cfg) {
40
+ var _this = this;
41
+
42
+ _classCallCheck(this, Recorder);
43
+
44
+ this.config = {
45
+ bufferLen: 4096,
46
+ numChannels: 2,
47
+ mimeType: 'audio/wav'
48
+ };
49
+ this.recording = false;
50
+ this.callbacks = {
51
+ getBuffer: [],
52
+ exportWAV: []
53
+ };
54
+
55
+ Object.assign(this.config, cfg);
56
+ this.context = source.context;
57
+ this.node = (this.context.createScriptProcessor || this.context.createJavaScriptNode).call(this.context, this.config.bufferLen, this.config.numChannels, this.config.numChannels);
58
+
59
+ this.node.onaudioprocess = function (e) {
60
+ if (!_this.recording) return;
61
+
62
+ var buffer = [];
63
+ for (var channel = 0; channel < _this.config.numChannels; channel++) {
64
+ buffer.push(e.inputBuffer.getChannelData(channel));
65
+ }
66
+ _this.worker.postMessage({
67
+ command: 'record',
68
+ buffer: buffer
69
+ });
70
+ };
71
+
72
+ source.connect(this.node);
73
+ this.node.connect(this.context.destination); //this should not be necessary
74
+
75
+ var self = {};
76
+ this.worker = new _inlineWorker2.default(function () {
77
+ var recLength = 0,
78
+ recBuffers = [],
79
+ sampleRate = undefined,
80
+ numChannels = undefined;
81
+
82
+ self.onmessage = function (e) {
83
+ switch (e.data.command) {
84
+ case 'init':
85
+ init(e.data.config);
86
+ break;
87
+ case 'record':
88
+ record(e.data.buffer);
89
+ break;
90
+ case 'exportWAV':
91
+ exportWAV(e.data.type);
92
+ break;
93
+ case 'getBuffer':
94
+ getBuffer();
95
+ break;
96
+ case 'clear':
97
+ clear();
98
+ break;
99
+ }
100
+ };
101
+
102
+ function init(config) {
103
+ sampleRate = config.sampleRate;
104
+ numChannels = config.numChannels;
105
+ initBuffers();
106
+ }
107
+
108
+ function record(inputBuffer) {
109
+ for (var channel = 0; channel < numChannels; channel++) {
110
+ recBuffers[channel].push(inputBuffer[channel]);
111
+ }
112
+ recLength += inputBuffer[0].length;
113
+ }
114
+
115
+ function exportWAV(type) {
116
+ var buffers = [];
117
+ for (var channel = 0; channel < numChannels; channel++) {
118
+ buffers.push(mergeBuffers(recBuffers[channel], recLength));
119
+ }
120
+ var interleaved = undefined;
121
+ if (numChannels === 2) {
122
+ interleaved = interleave(buffers[0], buffers[1]);
123
+ } else {
124
+ interleaved = buffers[0];
125
+ }
126
+ var dataview = encodeWAV(interleaved);
127
+ var audioBlob = new Blob([dataview], { type: type });
128
+
129
+ self.postMessage({ command: 'exportWAV', data: audioBlob });
130
+ }
131
+
132
+ function getBuffer() {
133
+ var buffers = [];
134
+ for (var channel = 0; channel < numChannels; channel++) {
135
+ buffers.push(mergeBuffers(recBuffers[channel], recLength));
136
+ }
137
+ self.postMessage({ command: 'getBuffer', data: buffers });
138
+ }
139
+
140
+ function clear() {
141
+ recLength = 0;
142
+ recBuffers = [];
143
+ initBuffers();
144
+ }
145
+
146
+ function initBuffers() {
147
+ for (var channel = 0; channel < numChannels; channel++) {
148
+ recBuffers[channel] = [];
149
+ }
150
+ }
151
+
152
+ function mergeBuffers(recBuffers, recLength) {
153
+ var result = new Float32Array(recLength);
154
+ var offset = 0;
155
+ for (var i = 0; i < recBuffers.length; i++) {
156
+ result.set(recBuffers[i], offset);
157
+ offset += recBuffers[i].length;
158
+ }
159
+ return result;
160
+ }
161
+
162
+ function interleave(inputL, inputR) {
163
+ var length = inputL.length + inputR.length;
164
+ var result = new Float32Array(length);
165
+
166
+ var index = 0,
167
+ inputIndex = 0;
168
+
169
+ while (index < length) {
170
+ result[index++] = inputL[inputIndex];
171
+ result[index++] = inputR[inputIndex];
172
+ inputIndex++;
173
+ }
174
+ return result;
175
+ }
176
+
177
+ function floatTo16BitPCM(output, offset, input) {
178
+ for (var i = 0; i < input.length; i++, offset += 2) {
179
+ var s = Math.max(-1, Math.min(1, input[i]));
180
+ output.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
181
+ }
182
+ }
183
+
184
+ function writeString(view, offset, string) {
185
+ for (var i = 0; i < string.length; i++) {
186
+ view.setUint8(offset + i, string.charCodeAt(i));
187
+ }
188
+ }
189
+
190
+ function encodeWAV(samples) {
191
+ var buffer = new ArrayBuffer(44 + samples.length * 2);
192
+ var view = new DataView(buffer);
193
+
194
+ /* RIFF identifier */
195
+ writeString(view, 0, 'RIFF');
196
+ /* RIFF chunk length */
197
+ view.setUint32(4, 36 + samples.length * 2, true);
198
+ /* RIFF type */
199
+ writeString(view, 8, 'WAVE');
200
+ /* format chunk identifier */
201
+ writeString(view, 12, 'fmt ');
202
+ /* format chunk length */
203
+ view.setUint32(16, 16, true);
204
+ /* sample format (raw) */
205
+ view.setUint16(20, 1, true);
206
+ /* channel count */
207
+ view.setUint16(22, numChannels, true);
208
+ /* sample rate */
209
+ view.setUint32(24, sampleRate, true);
210
+ /* byte rate (sample rate * block align) */
211
+ view.setUint32(28, sampleRate * 4, true);
212
+ /* block align (channel count * bytes per sample) */
213
+ view.setUint16(32, numChannels * 2, true);
214
+ /* bits per sample */
215
+ view.setUint16(34, 16, true);
216
+ /* data chunk identifier */
217
+ writeString(view, 36, 'data');
218
+ /* data chunk length */
219
+ view.setUint32(40, samples.length * 2, true);
220
+
221
+ floatTo16BitPCM(view, 44, samples);
222
+
223
+ return view;
224
+ }
225
+ }, self);
226
+
227
+ this.worker.postMessage({
228
+ command: 'init',
229
+ config: {
230
+ sampleRate: this.context.sampleRate,
231
+ numChannels: this.config.numChannels
232
+ }
233
+ });
234
+
235
+ this.worker.onmessage = function (e) {
236
+ var cb = _this.callbacks[e.data.command].pop();
237
+ if (typeof cb == 'function') {
238
+ cb(e.data.data);
239
+ }
240
+ };
241
+ }
242
+
243
+ _createClass(Recorder, [{
244
+ key: 'record',
245
+ value: function record() {
246
+ this.recording = true;
247
+ }
248
+ }, {
249
+ key: 'stop',
250
+ value: function stop() {
251
+ this.recording = false;
252
+ }
253
+ }, {
254
+ key: 'clear',
255
+ value: function clear() {
256
+ this.worker.postMessage({ command: 'clear' });
257
+ }
258
+ }, {
259
+ key: 'getBuffer',
260
+ value: function getBuffer(cb) {
261
+ cb = cb || this.config.callback;
262
+ if (!cb) throw new Error('Callback not set');
263
+
264
+ this.callbacks.getBuffer.push(cb);
265
+
266
+ this.worker.postMessage({ command: 'getBuffer' });
267
+ }
268
+ }, {
269
+ key: 'exportWAV',
270
+ value: function exportWAV(cb, mimeType) {
271
+ mimeType = mimeType || this.config.mimeType;
272
+ cb = cb || this.config.callback;
273
+ if (!cb) throw new Error('Callback not set');
274
+
275
+ this.callbacks.exportWAV.push(cb);
276
+
277
+ this.worker.postMessage({
278
+ command: 'exportWAV',
279
+ type: mimeType
280
+ });
281
+ }
282
+ }], [{
283
+ key: 'forceDownload',
284
+ value: function forceDownload(blob, filename) {
285
+ var url = (window.URL || window.webkitURL).createObjectURL(blob);
286
+ var link = window.document.createElement('a');
287
+ link.href = url;
288
+ link.download = filename || 'output.wav';
289
+ var click = document.createEvent("Event");
290
+ click.initEvent("click", true, true);
291
+ link.dispatchEvent(click);
292
+ }
293
+ }]);
294
+
295
+ return Recorder;
296
+ })();
297
+
298
+ exports.default = Recorder;
299
+
300
+ },{"inline-worker":3}],3:[function(require,module,exports){
301
+ "use strict";
302
+
303
+ module.exports = require("./inline-worker");
304
+ },{"./inline-worker":4}],4:[function(require,module,exports){
305
+ (function (global){
306
+ "use strict";
307
+
308
+ var _createClass = (function () { function defineProperties(target, props) { for (var key in props) { var prop = props[key]; prop.configurable = true; if (prop.value) prop.writable = true; } Object.defineProperties(target, props); } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; })();
309
+
310
+ var _classCallCheck = function (instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } };
311
+
312
+ var WORKER_ENABLED = !!(global === global.window && global.URL && global.Blob && global.Worker);
313
+
314
+ var InlineWorker = (function () {
315
+ function InlineWorker(func, self) {
316
+ var _this = this;
317
+
318
+ _classCallCheck(this, InlineWorker);
319
+
320
+ if (WORKER_ENABLED) {
321
+ var functionBody = func.toString().trim().match(/^function\s*\w*\s*\([\w\s,]*\)\s*{([\w\W]*?)}$/)[1];
322
+ var url = global.URL.createObjectURL(new global.Blob([functionBody], { type: "text/javascript" }));
323
+
324
+ return new global.Worker(url);
325
+ }
326
+
327
+ this.self = self;
328
+ this.self.postMessage = function (data) {
329
+ setTimeout(function () {
330
+ _this.onmessage({ data: data });
331
+ }, 0);
332
+ };
333
+
334
+ setTimeout(function () {
335
+ func.call(self);
336
+ }, 0);
337
+ }
338
+
339
+ _createClass(InlineWorker, {
340
+ postMessage: {
341
+ value: function postMessage(data) {
342
+ var _this = this;
343
+
344
+ setTimeout(function () {
345
+ _this.self.onmessage({ data: data });
346
+ }, 0);
347
+ }
348
+ }
349
+ });
350
+
351
+ return InlineWorker;
352
+ })();
353
+
354
+ module.exports = InlineWorker;
355
+ }).call(this,typeof global !== "undefined" ? global : typeof self !== "undefined" ? self : typeof window !== "undefined" ? window : {})
356
+ },{}]},{},[1])(1)
357
+ });
Web/script.js ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// --- DOM element handles, resolved once at script load ---
const uploadButton = document.getElementById('upload-button');       // "Upload & Analyze" button
const audioFileInput = document.getElementById('audio-file');        // <input type="file"> for audio uploads
const recordButton = document.getElementById('recordButton');        // starts microphone recording
const stopButton = document.getElementById('stopButton');            // stops recording and triggers WAV export
const pauseButton = document.getElementById('pauseButton');          // pauses/resumes recording
const responseDiv = document.getElementById('response');             // container for analysis results
const metadataDisplay = document.getElementById('metadata-display'); // container for the metadata table

// --- Recording state, assigned by the record/stop handlers ---
let gumStream;    // MediaStream from getUserMedia; used to release the mic on stop
let rec;          // Recorder.js instance wrapping the audio source
let input;        // MediaStreamAudioSourceNode feeding the recorder
let audioContext; // lazily created AudioContext (see startAudioContext)
13
+
14
// Lazily create the shared AudioContext, or resume it if the browser
// auto-suspended it (autoplay policy).
function startAudioContext() {
    if (!audioContext) {
        audioContext = new (window.AudioContext || window.webkitAudioContext)();
        return;
    }
    if (audioContext.state === 'suspended') {
        audioContext.resume().then(() => console.log('AudioContext repris'));
    }
}
23
+
24
+ // Fonction pour rééchantillonner l'audio à 16 kHz
25
// Resample an audio Blob to `targetSampleRate` (default 16 kHz) and return
// a WAV Blob. The blob is decoded, rendered through an OfflineAudioContext
// running at the target rate, then re-encoded with bufferToWav.
async function resampleAudio(blob, targetSampleRate = 16000) {
    return new Promise((resolve, reject) => {
        const reader = new FileReader();
        reader.onerror = reject;
        reader.onload = async () => {
            const decodeCtx = new (window.AudioContext || window.webkitAudioContext)();
            const decoded = await decodeCtx.decodeAudioData(reader.result);

            // Offline context sized for the same duration at the new rate.
            const offline = new OfflineAudioContext(
                decoded.numberOfChannels,
                decoded.length * (targetSampleRate / decoded.sampleRate),
                targetSampleRate
            );

            // Play the decoded buffer into the offline destination.
            const source = offline.createBufferSource();
            source.buffer = decoded;
            source.connect(offline.destination);
            source.start();

            // Render, then convert the resampled buffer to WAV.
            const rendered = await offline.startRendering();
            resolve(bufferToWav(rendered));
        };
        reader.readAsArrayBuffer(blob);
    });
}
58
+
59
+ // Fonction pour convertir un AudioBuffer en WAV
60
// Convert an AudioBuffer to a 16-bit PCM WAV Blob.
//
// Channel data is interleaved sample-by-sample (L,R,L,R,...) into a single
// float array, then encoded via encodeWAV.
function bufferToWav(buffer) {
    const numChannels = buffer.numberOfChannels;
    const sampleRate = buffer.sampleRate;

    // BUG FIX: the interleaved array must hold one float per sample
    // (frames * channels). The previous size included an extra "* 2"
    // (the bytes-per-sample factor, which belongs to the encoding step),
    // doubling the array and appending a tail of silence to every WAV.
    const data = new Float32Array(buffer.length * numChannels);

    // Interleave the channels.
    for (let channel = 0; channel < numChannels; channel++) {
        const channelData = buffer.getChannelData(channel);
        for (let i = 0; i < channelData.length; i++) {
            data[i * numChannels + channel] = channelData[i];
        }
    }

    // Encode as 16-bit PCM WAV.
    return encodeWAV(data, sampleRate, numChannels);
}
78
+
79
+ // Fonction pour encoder des données audio en WAV
80
// Encode interleaved float samples as a 16-bit PCM WAV Blob.
function encodeWAV(samples, sampleRate, numChannels) {
    const bytesPerSample = 2;
    const dataSize = samples.length * bytesPerSample;
    const view = new DataView(new ArrayBuffer(44 + dataSize));

    // RIFF container header.
    writeString(view, 0, 'RIFF');
    view.setUint32(4, 36 + dataSize, true); // remaining chunk size
    writeString(view, 8, 'WAVE');

    // "fmt " sub-chunk: linear PCM, 16 bits per sample.
    writeString(view, 12, 'fmt ');
    view.setUint32(16, 16, true);                                        // sub-chunk length
    view.setUint16(20, 1, true);                                         // audio format: PCM
    view.setUint16(22, numChannels, true);                               // channel count
    view.setUint32(24, sampleRate, true);                                // sample rate
    view.setUint32(28, sampleRate * numChannels * bytesPerSample, true); // byte rate
    view.setUint16(32, numChannels * bytesPerSample, true);              // block align
    view.setUint16(34, 16, true);                                        // bits per sample

    // "data" sub-chunk with the PCM payload.
    writeString(view, 36, 'data');
    view.setUint32(40, dataSize, true);
    floatTo16BitPCM(view, 44, samples);

    return new Blob([view], { type: 'audio/wav' });
}
104
+
105
+ // Fonction utilitaire pour écrire une chaîne dans un DataView
106
// Write the character codes of `string` into `view` starting at `offset`,
// one byte per character (intended for ASCII chunk tags like 'RIFF').
function writeString(view, offset, string) {
    for (let idx = string.length; idx-- > 0; ) {
        view.setUint8(offset + idx, string.charCodeAt(idx));
    }
}
111
+
112
+ // Fonction utilitaire pour convertir des échantillons flottants en PCM 16 bits
113
// Clamp each float sample to [-1, 1] and write it as little-endian
// signed 16-bit PCM, two bytes per sample.
function floatTo16BitPCM(view, offset, input) {
    let pos = offset;
    for (const sample of input) {
        const clamped = Math.min(1, Math.max(-1, sample));
        // int16 range is asymmetric: negatives scale by 0x8000, positives by 0x7FFF.
        view.setInt16(pos, clamped < 0 ? clamped * 0x8000 : clamped * 0x7FFF, true);
        pos += 2;
    }
}
119
+
120
+ // Function to fetch metadata from the text file
121
// Fetch and parse ../metadata.txt (semicolon-separated values, first line
// holds the column headers). Returns an array of {header: value} objects,
// lower-cased header keys, or [] on any failure.
async function fetchMetadata() {
    try {
        const response = await fetch('../metadata.txt'); // file must be served alongside the page
        if (!response.ok) {
            throw new Error('Failed to fetch metadata');
        }
        const text = await response.text();
        console.log('Metadata file content:', text); // Debugging

        // Non-empty, trimmed lines only.
        const lines = text
            .split('\n')
            .map((line) => line.trim())
            .filter((line) => line !== '');

        if (lines.length < 2) {
            throw new Error('Metadata file is empty or malformed');
        }

        // First line = column headers, normalized to lowercase.
        const headers = lines[0].split(';').map((h) => h.trim().toLowerCase());

        // One object per data row; missing cells default to 'N/A'.
        const metadata = lines.slice(1).map((line) => {
            const values = line.split(';').map((value) => value.trim());
            const entry = {};
            headers.forEach((header, idx) => {
                entry[header] = values[idx] || 'N/A';
            });
            return entry;
        });

        console.log('Parsed Metadata:', metadata); // Debugging
        return metadata;
    } catch (error) {
        console.error('Error fetching metadata:', error);
        return [];
    }
}
157
+
158
// Fill every filter dropdown with its fixed set of options.
function populateFilters() {
    // Spoofing-system ids A01..A19, preceded by "bonafide".
    const systems = ["bonafide"].concat(
        Array.from({ length: 19 }, (_, i) => `A${String(i + 1).padStart(2, '0')}`)
    );

    const predefinedValues = {
        label: ["spoof", "genuine"],
        system: systems,
        codec: ["FLAC", "WAV", "MP3"],
        genre: ["male", "female"],
        year: ["2020", "2021", "2022", "2023", "2024", "2025"]
    };

    for (const [key, values] of Object.entries(predefinedValues)) {
        populateDropdown(`filter-${key}`, values);
    }
}
171
+
172
// Replace the options of select#id with a default "All" entry plus one
// option per value, and re-render the metadata table on every change.
function populateDropdown(id, values) {
    const select = document.getElementById(id);
    select.innerHTML = '<option value="">All</option>'; // default "All" option

    for (const value of values) {
        const option = document.createElement("option");
        option.value = value;
        // Capitalize the first letter for display.
        option.textContent = value.charAt(0).toUpperCase() + value.slice(1);
        select.appendChild(option);
    }

    select.addEventListener("change", filterMetadata);
}
185
+
186
// Re-render the metadata table according to the current dropdown selections.
//
// The filtering itself happens inside displayMetadata (filteredOnly mode),
// which reads the five dropdown values directly; this change-handler only
// needs to re-fetch the metadata and trigger that render. (A previous
// version duplicated the entire filter expression here, then discarded
// the result — dead code, now removed; behavior is unchanged.)
function filterMetadata() {
    fetchMetadata().then(metadata => {
        displayMetadata(null, metadata, true); // filtering mode
    });
}
205
+
206
+
207
// Render a Bootstrap table of metadata entries into #metadata-display.
//
// Two modes:
//  - filteredOnly === true : ignore `files`; keep the entries matching the
//    five dropdown selections ("" / "All" places no constraint on a field).
//  - filteredOnly === false: keep only the entries whose `filedir` column
//    matches the name of one of the selected `files`.
function displayMetadata(files, metadata, filteredOnly = false) {
    metadataDisplay.innerHTML = ''; // clear the display before refilling

    // Not filtering and no file selected: nothing to show.
    if (!filteredOnly && (!files || files.length === 0)) {
        metadataDisplay.innerHTML = '<p>No files selected.</p>';
        return;
    }

    let filteredMetadata;

    if (filteredOnly) {
        // Apply the dropdown filters; "" means "All" for that field.
        const selectedLabel = document.getElementById("filter-label").value.toLowerCase();
        const selectedSystem = document.getElementById("filter-system").value.toLowerCase();
        const selectedCodec = document.getElementById("filter-codec").value.toLowerCase();
        const selectedGenre = document.getElementById("filter-genre").value.toLowerCase();
        const selectedYear = document.getElementById("filter-year").value.toLowerCase();

        filteredMetadata = metadata.filter(entry =>
            (selectedLabel === "" || entry.label.toLowerCase() === selectedLabel) &&
            (selectedSystem === "" || entry.system.toLowerCase() === selectedSystem) &&
            (selectedCodec === "" || entry.codec.toLowerCase() === selectedCodec) &&
            (selectedGenre === "" || entry.genre.toLowerCase() === selectedGenre) &&
            (selectedYear === "" || entry.year.toLowerCase() === selectedYear)
        );
    } else {
        // Names of the selected files, normalized for comparison.
        const selectedFiles = Array.from(files).map(file => file.name.trim().toLowerCase());

        // Keep only the metadata rows belonging to the selected files.
        // NOTE(review): assumes the `filedir` column holds bare file names — confirm against metadata.txt.
        filteredMetadata = metadata.filter(entry => selectedFiles.includes(entry.filedir.trim().toLowerCase()));
    }

    // No rows survived the filtering.
    if (filteredMetadata.length === 0) {
        metadataDisplay.innerHTML = '<p>No metadata found.</p>';
        return;
    }

    // Build the Bootstrap table.
    const table = document.createElement('table');
    table.classList.add('table', 'table-striped', 'table-bordered');

    // Header row from the keys of the first entry.
    const headerRow = document.createElement('tr');
    Object.keys(filteredMetadata[0]).forEach(headerText => {
        const header = document.createElement('th');
        header.textContent = headerText.charAt(0).toUpperCase() + headerText.slice(1);
        headerRow.appendChild(header);
    });
    table.appendChild(headerRow);

    // One table row per filtered metadata entry.
    filteredMetadata.forEach(entry => {
        const row = document.createElement('tr');
        Object.values(entry).forEach(value => {
            const cell = document.createElement('td');
            cell.textContent = value;
            row.appendChild(cell);
        });
        table.appendChild(row);
    });

    // Attach the finished table to the metadata section.
    metadataDisplay.appendChild(table);
}
274
+
275
+
276
// On page load: populate the fixed dropdown values, then render the full
// metadata table.
document.addEventListener('DOMContentLoaded', async () => {
    populateFilters();
    const metadata = await fetchMetadata();
    // BUG FIX: the metadata array was previously passed as the *files*
    // argument (displayMetadata(metadata)), which left the `metadata`
    // parameter undefined inside displayMetadata and crashed the initial
    // render on `entry.name`/`metadata.filter`. Calling in filtered mode
    // with every dropdown still at "All" displays all entries as intended.
    displayMetadata(null, metadata, true);
});
281
+
282
// Upload the given audio files to the local prediction API and render one
// result line per file; also refreshes the metadata panel so it shows only
// the selected files.
//
// `files` is a FileList (or array of File/Blob objects).
async function uploadAudio(files) {
    if (!files || files.length === 0) {
        alert('Please select or record files first!');
        return;
    }

    // One multipart form; every file goes under the same "files" field.
    const formData = new FormData();
    const filesArray = Array.from(files);
    for (let i = 0; i < filesArray.length; i++) {
        formData.append('files', filesArray[i]);
    }

    responseDiv.textContent = 'Uploading and analyzing audio...';

    try {
        const metadataObj = await fetchMetadata();
        displayMetadata(filesArray, metadataObj); // show metadata only for the selected files

        // NOTE(review): endpoint is hard-coded to a local dev server.
        const response = await fetch('http://127.0.0.1:8000/predict/', {
            method: 'POST',
            body: formData,
        });

        if (!response.ok) {
            const errorData = await response.json();
            throw new Error(`Server error: ${errorData.message || response.statusText}`);
        }

        const data = await response.json();
        responseDiv.innerHTML = '';

        // One line per analyzed file: name, predicted label, confidence.
        data.forEach((result, index) => {
            const resultDiv = document.createElement('div');
            resultDiv.innerHTML = `File: <b>${result.filename}</b>, Label: <b>${result.label}</b>, Confidence: <b>${result.confidence}</b>`;
            responseDiv.appendChild(resultDiv);


        });

    } catch (error) {
        console.error('Error:', error);
        responseDiv.textContent = 'Error: ' + error.message;
    }
}
326
+
327
+
328
+
329
// Kick off analysis for whatever files are currently chosen in the input.
uploadButton.addEventListener('click', () => {
    const selected = audioFileInput.files;
    if (!selected || selected.length === 0) {
        alert('Please select files first!');
        return;
    }
    uploadAudio(selected);
});
337
+
338
+ // Start Recording
339
// Start microphone capture and begin a mono Recorder.js recording.
recordButton.addEventListener('click', async () => {
    startAudioContext(); // create or resume the shared AudioContext

    console.log('Recording started');

    try {
        gumStream = await navigator.mediaDevices.getUserMedia({ audio: true, video: false });
        console.log('Microphone access granted');

        input = audioContext.createMediaStreamSource(gumStream);
        console.log('Audio source created');

        rec = new Recorder(input, { numChannels: 1 });
        console.log('Recorder initialized');

        rec.record();
        console.log('Recording started');

        // Reflect the recording state in the controls.
        recordButton.disabled = true;
        stopButton.disabled = false;
        pauseButton.disabled = false;
    } catch (error) {
        console.error('Error accessing microphone:', error);
        alert('Error accessing microphone: ' + error.message);
    }
});
369
+
370
+
371
// Stop the current recording, release the microphone, resample the capture
// to 16 kHz, and forward it to the prediction API.
function stopRecording() {
    console.log('stopRecording called');

    // Reset the control states
    stopButton.disabled = true;
    recordButton.disabled = false;
    pauseButton.disabled = true;
    pauseButton.innerHTML = 'Pause';

    // Stop the recording
    rec.stop();
    console.log('Recording stopped');

    // Release the microphone
    gumStream.getAudioTracks()[0].stop();
    console.log('Microphone access stopped');

    // Export the captured audio as WAV
    rec.exportWAV(async (blob) => {
        console.log('Audio exported as WAV');

        // Guard against an empty capture
        if (blob.size === 0) {
            console.error('Le fichier audio est vide.');
            responseDiv.textContent = 'Erreur : Le fichier audio est vide.';
            return;
        }

        // Resample the audio to 16 kHz before upload
        try {
            const resampledBlob = await resampleAudio(blob, 16000);
            console.log('Audio rééchantillonné à 16 kHz');

            // Send the resampled audio to the API (await so errors land in this catch)
            await sendAudioToAPI(resampledBlob);
        } catch (error) {
            console.error('Erreur lors du rééchantillonnage :', error);
            responseDiv.textContent = 'Erreur : ' + error.message;
        }
    });
}
413
+
414
// POST a single recorded WAV blob to the prediction endpoint and render
// the returned label/confidence in the response area.
async function sendAudioToAPI(blob) {
    console.log('Sending audio to API');

    const formData = new FormData();
    const filename = 'recorded-audio.wav'; // fixed name for the recorded clip
    // The field name must be 'files' to match the backend signature.
    formData.append('files', blob, filename);

    try {
        const response = await fetch('http://127.0.0.1:8000/predict/', {
            method: 'POST',
            body: formData,
        });

        console.log('API response status:', response.status);

        if (!response.ok) {
            throw new Error(`HTTP error! status: ${response.status}`);
        }

        const data = await response.json();
        console.log('API response data:', data);

        // Render the first prediction, if any came back.
        if (data.length > 0) {
            responseDiv.innerHTML = `Label: <b>${data[0].label}</b>, Confidence: <b>${data[0].confidence}</b>`;
        } else {
            responseDiv.textContent = 'Error: No data returned from the API.';
        }
    } catch (error) {
        console.error('Error sending audio to API:', error);
        responseDiv.textContent = 'Error: ' + error.message;
    }
}
447
+
448
+ // Pause Recording
449
// Toggle between paused and recording states.
pauseButton.addEventListener('click', () => {
    if (!rec.recording) {
        // Currently paused: resume capture.
        rec.record();
        pauseButton.textContent = 'Pause';
    } else {
        // Currently recording: pause capture.
        rec.stop();
        pauseButton.textContent = 'Resume';
    }
});
460
+
461
+
462
// Finalize the recording and send it off for analysis.
stopButton.addEventListener('click', stopRecording);
465
+
466
// NOTE(review): a second, byte-identical 'click' handler for recordButton was
// registered here. Because the two handlers were distinct function objects,
// addEventListener attached BOTH, so every click requested the microphone
// twice and created two Recorder instances. The duplicate registration has
// been removed; the handler defined earlier in this file is the single one.
497
+
498
+
Web/styles.css ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Metadata results table: full width, collapsed single-line borders. */
.metadata-table {
    width: 100%;
    border-collapse: collapse;
    margin-top: 10px;
}

/* Shared cell styling for headers and data cells. */
.metadata-table th,
.metadata-table td {
    border: 1px solid #ddd;
    padding: 8px;
    text-align: left;
}

/* Header row: light grey background, bold labels. */
.metadata-table th {
    background-color: #f2f2f2;
    font-weight: bold;
}

/* Zebra striping on even rows for readability. */
.metadata-table tr:nth-child(even) {
    background-color: #f9f9f9;
}

/* Subtle highlight of the hovered row. */
.metadata-table tr:hover {
    background-color: #f1f1f1;
}
calculate_modules.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+
4
+
5
def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold):
    """Compute ASV false-alarm / miss rates at a fixed operating point.

    NOTE(review): this function is defined a second time, identically, just
    below in this module; the later definition shadows this one at import.

    Args:
        tar_asv: target-trial ASV scores (numpy array).
        non_asv: nontarget-trial ASV scores (numpy array).
        spoof_asv: spoof-trial ASV scores (numpy array, may be empty).
        asv_threshold: decision threshold applied to all scores.

    Returns:
        (Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, Pfa_spoof_asv); the spoof
        rates are None when no spoof scores are supplied.
    """
    # False alarm and miss rates for ASV.
    # Fix: use np.sum consistently (builtin sum on numpy bool arrays is
    # correct but slow and inconsistent with the spoof branch below).
    Pfa_asv = np.sum(non_asv >= asv_threshold) / non_asv.size
    Pmiss_asv = np.sum(tar_asv < asv_threshold) / tar_asv.size

    # Rate of rejecting spoofs in ASV
    if spoof_asv.size == 0:
        Pmiss_spoof_asv = None
        Pfa_spoof_asv = None
    else:
        Pmiss_spoof_asv = np.sum(spoof_asv < asv_threshold) / spoof_asv.size
        Pfa_spoof_asv = np.sum(spoof_asv >= asv_threshold) / spoof_asv.size

    return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, Pfa_spoof_asv
20
+
21
+
22
def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold):
    """Compute ASV false-alarm / miss rates at a fixed operating point.

    NOTE(review): exact duplicate of the definition above; this later copy
    is the one that wins at import time. Consider deleting one of the two.

    Returns (Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, Pfa_spoof_asv); the spoof
    rates are None when spoof_asv is empty.
    """
    # False alarm and miss rates for ASV
    Pfa_asv = sum(non_asv >= asv_threshold) / non_asv.size
    Pmiss_asv = sum(tar_asv < asv_threshold) / tar_asv.size

    # Rate of rejecting spoofs in ASV
    if spoof_asv.size == 0:
        Pmiss_spoof_asv = None
        Pfa_spoof_asv = None
    else:
        Pmiss_spoof_asv = np.sum(spoof_asv < asv_threshold) / spoof_asv.size
        Pfa_spoof_asv = np.sum(spoof_asv >= asv_threshold) / spoof_asv.size

    return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, Pfa_spoof_asv
37
+
38
+
39
def compute_det_curve(target_scores, nontarget_scores):
    """Build the detection error trade-off (DET) curve.

    Args:
        target_scores: scores of genuine/target trials (numpy array).
        nontarget_scores: scores of impostor/nontarget trials (numpy array).

    Returns:
        (frr, far, thresholds): false rejection rates, false acceptance
        rates, and the matching decision thresholds (sorted scores, with a
        leading sentinel just below the minimum score).
    """
    n_total = target_scores.size + nontarget_scores.size
    pooled = np.concatenate((target_scores, nontarget_scores))
    is_target = np.concatenate(
        (np.ones(target_scores.size), np.zeros(nontarget_scores.size)))

    # Stable sort keeps ties in input order, matching the reference curve.
    order = np.argsort(pooled, kind='mergesort')
    is_target = is_target[order]

    # Targets at or below each threshold are misses; nontargets above are FAs.
    miss_counts = np.cumsum(is_target)
    fa_counts = nontarget_scores.size - (np.arange(1, n_total + 1) - miss_counts)

    frr = np.concatenate((np.atleast_1d(0), miss_counts / target_scores.size))
    far = np.concatenate((np.atleast_1d(1), fa_counts / nontarget_scores.size))
    thresholds = np.concatenate(
        (np.atleast_1d(pooled[order[0]] - 0.001), pooled[order]))

    return frr, far, thresholds
65
+
66
+
67
def compute_Pmiss_Pfa_Pspoof_curves(tar_scores, non_scores, spf_scores):
    """Compute miss / false-alarm curves over target, nontarget and spoof trials.

    NOTE(review): this function is re-defined, identically, further down in
    this module; the later copy shadows this one. Consider removing one.

    Returns (Pmiss, Pfa_non, Pfa_spoof, thresholds), each one element longer
    than the pooled score list (leading sentinel point).
    """
    # Concatenate all scores and designate arbitrary labels 1=target, 0=nontarget, -1=spoof
    all_scores = np.concatenate((tar_scores, non_scores, spf_scores))
    labels = np.concatenate((np.ones(tar_scores.size), np.zeros(non_scores.size), -1*np.ones(spf_scores.size)))

    # Sort labels based on scores (stable sort preserves tie order)
    indices = np.argsort(all_scores, kind='mergesort')
    labels = labels[indices]

    # Cumulative sums of each trial type below every threshold
    tar_sums = np.cumsum(labels==1)
    non_sums = np.cumsum(labels==0)
    spoof_sums = np.cumsum(labels==-1)

    Pmiss = np.concatenate((np.atleast_1d(0), tar_sums / tar_scores.size))
    Pfa_non = np.concatenate((np.atleast_1d(1), 1 - (non_sums / non_scores.size)))
    Pfa_spoof = np.concatenate((np.atleast_1d(1), 1 - (spoof_sums / spf_scores.size)))
    thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))  # Thresholds are the sorted scores

    return Pmiss, Pfa_non, Pfa_spoof, thresholds
88
+
89
+
90
def compute_eer(target_scores, nontarget_scores):
    """Return the equal error rate (EER) along with the full DET curve.

    Note: despite the historical one-line summary ("EER and threshold"),
    this returns FOUR values: (eer, frr, far, thresholds). The EER is the
    mean of FRR and FAR at the point where they are closest.
    """
    frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
    abs_diffs = np.abs(frr - far)
    min_index = np.argmin(abs_diffs)
    eer = np.mean((frr[min_index], far[min_index]))
    return eer, frr, far, thresholds
97
+
98
+
99
def compute_mindcf(frr, far, thresholds, Pspoof, Cmiss, Cfa):
    """Return the minimum normalized detection cost (minDCF) and its threshold.

    Args:
        frr: false rejection rates along the DET curve.
        far: false acceptance rates along the DET curve (same length).
        thresholds: decision thresholds matching frr/far.
        Pspoof: prior probability of a spoof trial (p_target = 1 - Pspoof).
        Cmiss: cost of a miss.
        Cfa: cost of a false acceptance.

    Returns:
        (min_dcf, min_c_det_threshold): the normalized minimum cost and the
        threshold at which it is attained.
    """
    min_c_det = float("inf")
    # Fix: initialise with a scalar threshold — the original bound the whole
    # thresholds array here, which would be returned if the loop never ran.
    min_c_det_threshold = thresholds[0]

    p_target = 1 - Pspoof
    for miss, fa, thr in zip(frr, far, thresholds):
        # Weighted sum of false negative and false positive errors.
        c_det = Cmiss * miss * p_target + Cfa * fa * (1 - p_target)
        if c_det < min_c_det:
            min_c_det = c_det
            min_c_det_threshold = thr

    # Normalize by the cost of a trivial always-accept/always-reject system.
    c_def = min(Cmiss * p_target, Cfa * (1 - p_target))
    min_dcf = min_c_det / c_def
    return min_dcf, min_c_det_threshold
114
+
115
+
116
def compute_tDCF(bonafide_score_cm, spoof_score_cm, Pfa_asv, Pmiss_asv,
                 Pmiss_spoof_asv, cost_model, print_cost):
    """Compute the normalized tandem detection cost function (t-DCF) curve.

    Args:
        bonafide_score_cm: CM scores of bona fide trials.
        spoof_score_cm: CM scores of spoof trials.
        Pfa_asv, Pmiss_asv, Pmiss_spoof_asv: fixed ASV error rates.
        cost_model: dict with priors ('Ptar', 'Pnon', 'Pspoof') and costs
            ('Cfa_asv', 'Cmiss_asv', 'Cfa_cm', 'Cmiss_cm').
        print_cost: when True, print the cost model and the implied formula.

    Returns:
        (tDCF_norm, CM_thresholds): normalized t-DCF values and the CM
        thresholds they correspond to.

    Note: invalid inputs terminate the process via sys.exit rather than
    raising — callers cannot catch these as exceptions.
    """
    # Sanity check of cost parameters
    if cost_model['Cfa_asv'] < 0 or cost_model['Cmiss_asv'] < 0 or \
            cost_model['Cfa_cm'] < 0 or cost_model['Cmiss_cm'] < 0:
        print('WARNING: Usually the cost values should be positive!')

    if cost_model['Ptar'] < 0 or cost_model['Pnon'] < 0 or cost_model['Pspoof'] < 0 or \
            np.abs(cost_model['Ptar'] + cost_model['Pnon'] + cost_model['Pspoof'] - 1) > 1e-10:
        sys.exit(
            'ERROR: Your prior probabilities should be positive and sum up to one.'
        )

    # Unless we evaluate worst-case model, we need to have some spoof tests against asv
    if Pmiss_spoof_asv is None:
        sys.exit(
            'ERROR: you should provide miss rate of spoof tests against your ASV system.'
        )

    # Sanity check of scores
    combined_scores = np.concatenate((bonafide_score_cm, spoof_score_cm))
    if np.isnan(combined_scores).any() or np.isinf(combined_scores).any():
        sys.exit('ERROR: Your scores contain nan or inf.')

    # Sanity check that inputs are scores and not decisions
    n_uniq = np.unique(combined_scores).size
    if n_uniq < 3:
        sys.exit(
            'ERROR: You should provide soft CM scores - not binary decisions')

    # Obtain miss and false alarm rates of CM
    Pmiss_cm, Pfa_cm, CM_thresholds = compute_det_curve(
        bonafide_score_cm, spoof_score_cm)

    # Constants - see ASVspoof 2019 evaluation plan
    C1 = cost_model['Ptar'] * (cost_model['Cmiss_cm'] - cost_model['Cmiss_asv'] * Pmiss_asv) - \
        cost_model['Pnon'] * cost_model['Cfa_asv'] * Pfa_asv
    C2 = cost_model['Cfa_cm'] * cost_model['Pspoof'] * (1 - Pmiss_spoof_asv)

    # Sanity check of the weights
    if C1 < 0 or C2 < 0:
        sys.exit(
            'You should never see this error but I cannot evalute tDCF with negative weights - please check whether your ASV error rates are correctly computed?'
        )

    # Obtain t-DCF curve for all thresholds
    tDCF = C1 * Pmiss_cm + C2 * Pfa_cm

    # Normalized t-DCF
    tDCF_norm = tDCF / np.minimum(C1, C2)

    # Everything should be fine if reaching here.
    if print_cost:

        print('t-DCF evaluation from [Nbona={}, Nspoof={}] trials\n'.format(
            bonafide_score_cm.size, spoof_score_cm.size))
        print('t-DCF MODEL')
        print('   Ptar         = {:8.5f} (Prior probability of target user)'.
              format(cost_model['Ptar']))
        print(
            '   Pnon         = {:8.5f} (Prior probability of nontarget user)'.
            format(cost_model['Pnon']))
        print(
            '   Pspoof       = {:8.5f} (Prior probability of spoofing attack)'.
            format(cost_model['Pspoof']))
        print(
            '   Cfa_asv      = {:8.5f} (Cost of ASV falsely accepting a nontarget)'
            .format(cost_model['Cfa_asv']))
        print(
            '   Cmiss_asv    = {:8.5f} (Cost of ASV falsely rejecting target speaker)'
            .format(cost_model['Cmiss_asv']))
        print(
            '   Cfa_cm       = {:8.5f} (Cost of CM falsely passing a spoof to ASV system)'
            .format(cost_model['Cfa_cm']))
        print(
            '   Cmiss_cm     = {:8.5f} (Cost of CM falsely blocking target utterance which never reaches ASV)'
            .format(cost_model['Cmiss_cm']))
        print(
            '\n   Implied normalized t-DCF function (depends on t-DCF parameters and ASV errors), s=CM threshold)'
        )

        if C2 == np.minimum(C1, C2):
            print(
                '   tDCF_norm(s) = {:8.5f} x Pmiss_cm(s) + Pfa_cm(s)\n'.format(
                    C1 / C2))
        else:
            print(
                '   tDCF_norm(s) = Pmiss_cm(s) + {:8.5f} x Pfa_cm(s)\n'.format(
                    C2 / C1))

    return tDCF_norm, CM_thresholds
208
+
209
+
210
def calculate_CLLR(target_llrs, nontarget_llrs):
    """
    Calculate the CLLR (cost of log-likelihood ratio) of the scores.

    Parameters:
        target_llrs (list or numpy array): Log-likelihood ratios for target trials.
        nontarget_llrs (list or numpy array): Log-likelihood ratios for non-target trials.

    Returns:
        float: The calculated CLLR value, in bits (divided by log 2).
    """
    def negative_log_sigmoid(lodds):
        """Return -log(sigmoid(lodds)) in a numerically stable way.

        Fix: the previous form np.log1p(np.exp(-lodds)) overflows to inf for
        large negative lodds; logaddexp(0, -x) == log(1 + exp(-x)) exactly,
        without overflow.
        """
        return np.logaddexp(0, -lodds)

    # Convert the inputs to float arrays (accepts lists or arrays).
    target_llrs = np.asarray(target_llrs, dtype=float)
    nontarget_llrs = np.asarray(nontarget_llrs, dtype=float)

    # Mean miss-information over both trial types, normalized to bits.
    cllr = 0.5 * (np.mean(negative_log_sigmoid(target_llrs))
                  + np.mean(negative_log_sigmoid(-nontarget_llrs))) / np.log(2)

    return cllr
241
+
242
+
243
def compute_Pmiss_Pfa_Pspoof_curves(tar_scores, non_scores, spf_scores):
    """Compute miss / false-alarm curves over target, nontarget and spoof trials.

    Returns:
        (Pmiss, Pfa_non, Pfa_spoof, thresholds), each one element longer than
        the pooled score list (a sentinel operating point is prepended).
    """
    # Pool the scores and tag each trial: 1 = target, 0 = nontarget, -1 = spoof.
    pooled = np.concatenate((tar_scores, non_scores, spf_scores))
    tags = np.concatenate((np.ones(tar_scores.size),
                           np.zeros(non_scores.size),
                           -1 * np.ones(spf_scores.size)))

    # Stable sort by score so ties keep their input order.
    order = np.argsort(pooled, kind='mergesort')
    tags = tags[order]

    # Counts of each trial type at or below every candidate threshold.
    n_tar_below = np.cumsum(tags == 1)
    n_non_below = np.cumsum(tags == 0)
    n_spf_below = np.cumsum(tags == -1)

    Pmiss = np.concatenate((np.atleast_1d(0), n_tar_below / tar_scores.size))
    Pfa_non = np.concatenate((np.atleast_1d(1), 1 - (n_non_below / non_scores.size)))
    Pfa_spoof = np.concatenate((np.atleast_1d(1), 1 - (n_spf_below / spf_scores.size)))
    # Thresholds are the sorted scores, preceded by a point just below the minimum.
    thresholds = np.concatenate(
        (np.atleast_1d(pooled[order[0]] - 0.001), pooled[order]))

    return Pmiss, Pfa_non, Pfa_spoof, thresholds
264
+
265
+
266
def compute_teer(Pmiss_CM, Pfa_CM, tau_CM, Pmiss_ASV, Pfa_non_ASV, Pfa_spf_ASV, tau_ASV):
    """Compute the concurrent tandem EER (t-EER), returned as a percentage.

    Sweeps all ASV operating points, finds the CM threshold equalizing
    tandem miss and false-alarm rates, and tracks the intersection point
    across spoof-prevalence priors rho in {0, 0.5, 1}.

    NOTE(review): if the "allowed region" condition is never satisfied for
    any ASV threshold, xpoint_tEER is never assigned and the final return
    raises UnboundLocalError — confirm inputs always enter the region.
    NOTE(review): tEER_idx_CM/tEER_path/etc. are re-allocated on every rho
    iteration, so only the last rho's tables survive; appears intentional
    since only xpoint_tEER is returned, but verify.
    """
    # Different spoofing prevalence priors (rho) parameter values
    rho_vals = [0,0.5,1]

    tEER_val = np.empty([len(rho_vals),len(tau_ASV)], dtype=float)

    for rho_idx, rho_spf in enumerate(rho_vals):

        # Table to store the CM threshold index, per each of the ASV operating points
        tEER_idx_CM = np.empty(len(tau_ASV), dtype=int)

        tEER_path = np.empty([len(rho_vals),len(tau_ASV),2], dtype=float)

        # Tables to store the t-EER, total Pfa and total miss values along the t-EER path
        Pmiss_total = np.empty(len(tau_ASV), dtype=float)
        Pfa_total = np.empty(len(tau_ASV), dtype=float)
        min_tEER = np.inf
        argmin_tEER = np.empty(2)

        # best intersection point
        xpoint_crit_best = np.inf
        xpoint = np.empty(2)

        # Loop over all possible ASV thresholds
        for tau_ASV_idx, tau_ASV_val in enumerate(tau_ASV):

            # Tandem miss and fa rates as defined in the manuscript
            Pmiss_tdm = Pmiss_CM + (1 - Pmiss_CM) * Pmiss_ASV[tau_ASV_idx]
            Pfa_tdm = (1 - rho_spf) * (1 - Pmiss_CM) * Pfa_non_ASV[tau_ASV_idx] + rho_spf * Pfa_CM * Pfa_spf_ASV[tau_ASV_idx]

            # Store only the INDEX of the CM threshold (for the current ASV threshold)
            h = Pmiss_tdm - Pfa_tdm
            tmp = np.argmin(abs(h))
            tEER_idx_CM[tau_ASV_idx] = tmp

            if Pmiss_ASV[tau_ASV_idx] < (1 - rho_spf) * Pfa_non_ASV[tau_ASV_idx] + rho_spf * Pfa_spf_ASV[tau_ASV_idx]:
                Pmiss_total[tau_ASV_idx] = Pmiss_tdm[tmp]
                Pfa_total[tau_ASV_idx] = Pfa_tdm[tmp]

                tEER_val[rho_idx,tau_ASV_idx] = np.mean([Pfa_total[tau_ASV_idx], Pmiss_total[tau_ASV_idx]])

                tEER_path[rho_idx,tau_ASV_idx, 0] = tau_ASV_val
                tEER_path[rho_idx,tau_ASV_idx, 1] = tau_CM[tmp]

                if tEER_val[rho_idx,tau_ASV_idx] < min_tEER:
                    min_tEER = tEER_val[rho_idx,tau_ASV_idx]
                    argmin_tEER[0] = tau_ASV_val
                    argmin_tEER[1] = tau_CM[tmp]

                # Check how close we are to the INTERSECTION POINT for different prior (rho) values:
                LHS = Pfa_non_ASV[tau_ASV_idx]/Pfa_spf_ASV[tau_ASV_idx]
                RHS = Pfa_CM[tmp]/(1 - Pmiss_CM[tmp])
                crit = abs(LHS - RHS)

                if crit < xpoint_crit_best:
                    xpoint_crit_best = crit
                    xpoint[0] = tau_ASV_val
                    xpoint[1] = tau_CM[tmp]
                    xpoint_tEER = Pfa_spf_ASV[tau_ASV_idx]*Pfa_CM[tmp]
            else:
                # Not in allowed region
                tEER_path[rho_idx,tau_ASV_idx, 0] = np.nan
                tEER_path[rho_idx,tau_ASV_idx, 1] = np.nan
                Pmiss_total[tau_ASV_idx] = np.nan
                Pfa_total[tau_ASV_idx] = np.nan
                tEER_val[rho_idx,tau_ASV_idx] = np.nan

    return xpoint_tEER*100
docker-compose.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Two-service development stack: FastAPI backend (uvicorn) + nginx static UI.
version: '3.8'

services:
  backend:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - .:/app  # live-mount the source tree for development
    command: uvicorn main:app --host 0.0.0.0 --port 8000

  frontend:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      # NOTE(review): the UI files live under Web/ in this repo
      # (Web/index.html) — confirm these relative paths resolve from the
      # directory containing this compose file.
      - ./index.html:/usr/share/nginx/html/index.html
      - ./script.js:/usr/share/nginx/html/script.js
      - ./recorder.js:/usr/share/nginx/html/recorder.js
    depends_on:
      - backend
model_utils.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AASIST
3
+ Copyright (c) 2021-present NAVER Corp.
4
+ MIT license
5
+ """
6
+
7
+ import random
8
+ from typing import Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch import Tensor
15
+
16
+ import json
17
+ import torchaudio
18
+ import numpy as np
19
+
20
+ # Ensure that the Model class and all related components (GraphAttentionLayer, etc.) are defined here
21
+ # Placeholder for dependencies
22
+ # class Model(nn.Module):
23
+ # def __init__(self, d_args):
24
+ # # Your model implementation
25
+ # pass
26
+
27
+ # Function to load configuration
28
def load_config(config_path):
    """Load a JSON experiment configuration file and return the parsed object.

    Args:
        config_path: path to a JSON file (e.g. an AASIST .conf file).

    Returns:
        The deserialized JSON content (typically a dict).
    """
    # Explicit encoding: JSON is UTF-8 by spec; don't rely on the locale default.
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)
31
+
32
+ # Function to load the model
33
# Function to load the model
def load_model(checkpoint_path, d_args):
    """Instantiate the AASIST Model, load a checkpoint, and set eval mode.

    NOTE(review): `Model` is not defined in this module (its definition is
    commented out above) — it must be imported/defined elsewhere before this
    function is called, or a NameError is raised here.

    Args:
        checkpoint_path: path to a state_dict checkpoint (.pth).
        d_args: model-config dict passed to the Model constructor.

    Returns:
        The model in eval() mode, loaded on CPU.
    """
    model = Model(d_args)
    try:
        # Load checkpoint on CPU so no GPU is required for inference.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint)
        print("Model loaded successfully.")
    except Exception as e:
        # Log and re-raise: callers must not proceed with uninitialized weights.
        print(f"Error loading model: {e}")
        raise
    model.eval()
    return model
45
+
46
+ # Preprocess audio
47
# Preprocess audio
def preprocess_audio(audio_path, sample_rate=16000):
    """Load an audio file, resample to `sample_rate`, and downmix to mono.

    Args:
        audio_path: path to an audio file readable by torchaudio.
        sample_rate: target sampling rate in Hz (default 16 kHz).

    Returns:
        A mono waveform tensor of shape (1, num_samples).
    """
    try:
        waveform, sr = torchaudio.load(audio_path)
        print(f"Loaded audio: {audio_path}, Sample Rate: {sr}")
        # Resample only when the source rate differs from the target.
        if sr != sample_rate:
            resample_transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
            waveform = resample_transform(waveform)
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)  # Convert to mono if stereo
        return waveform
    except Exception as e:
        # Log and re-raise so the caller sees the original failure.
        print(f"Error in audio preprocessing: {e}")
        raise
60
+
61
+ # Inference function
62
# Inference function
def infer(model, waveform, freq_aug=False):
    """Run one forward pass with gradients disabled.

    Args:
        model: a callable returning (last_hidden, output) given a waveform.
        waveform: input tensor for the model.
        freq_aug: forwarded to the model as its Freq_aug keyword.

    Returns:
        (predicted_label, output): argmax class index (int) and the raw output.
    """
    try:
        with torch.no_grad():
            _last_hidden, logits = model(waveform, Freq_aug=freq_aug)
        print("Model output:", logits)
        if logits is None:
            raise ValueError("Model output is None.")
        return torch.argmax(logits, dim=1).item(), logits
    except Exception as exc:
        print(f"Error during inference: {exc}")
        raise
74
+
75
+
76
class GraphAttentionLayer(nn.Module):
    """Homogeneous graph attention layer (AASIST).

    Computes a pairwise node-product attention map, applies an attention-
    weighted projection plus a skip projection, then batch norm and SELU.
    """
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature for softmax sharpening (default 1, overridable via kwargs)
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x):
        '''
        x   :(#bs, #node, #dim)
        '''
        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)
        return x

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map(self, x):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)
        att_map = torch.matmul(att_map, self.att_weight)

        # apply temperature
        att_map = att_map / self.temp

        # softmax over the second node axis: weights over neighbours
        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        # attention-weighted aggregation plus a per-node skip projection
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _apply_BN(self, x):
        # flatten nodes into the batch axis so BatchNorm1d sees (N, dim)
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        # Xavier-initialized free parameter (used for attention weights)
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out
170
+
171
+
172
class HtrgGraphAttentionLayer(nn.Module):
    """Heterogeneous graph attention layer (AASIST).

    Attends over the union of two node sets (e.g. spectral and temporal
    graphs) with separate attention weights for within-type (11, 22) and
    cross-type (12) edges, and maintains a directed "master" summary node.
    """
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # per-type input projections before pooling the two node sets
        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)

        # separate attention weights per edge type (11/22 within, 12 across, M master)
        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature for softmax sharpening (default 1, overridable via kwargs)
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x1, x2, master=None):
        '''
        x1  :(#bs, #node, #dim)
        x2  :(#bs, #node, #dim)
        '''
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)

        x1 = self.proj_type1(x1)
        x2 = self.proj_type2(x2)

        # pool both node sets into one graph
        x = torch.cat([x1, x2], dim=1)

        # initialize the master node as the mean of all nodes if not given
        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)

        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x, num_type1, num_type2)

        # directional edge for master node
        master = self._update_master(x, master)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)

        # split the pooled graph back into the two node sets
        x1 = x.narrow(1, 0, num_type1)
        x2 = x.narrow(1, num_type1, num_type2)

        return x1, x2, master

    def _update_master(self, x, master):
        # attend nodes -> master, then project the aggregate into the master
        att_map = self._derive_att_map_master(x, master)
        master = self._project_master(x, master, att_map)

        return master

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map_master(self, x, master):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = x * master
        att_map = torch.tanh(self.att_projM(att_map))

        att_map = torch.matmul(att_map, self.att_weightM)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _derive_att_map(self, x, num_type1, num_type2):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)

        # assemble a per-quadrant attention board: each edge type gets its
        # own learned weight vector (11/22 within-type, 12 cross-type)
        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)

        att_board[:, :num_type1, :num_type1, :] = torch.matmul(
            att_map[:, :num_type1, :num_type1, :], self.att_weight11)
        att_board[:, num_type1:, num_type1:, :] = torch.matmul(
            att_map[:, num_type1:, num_type1:, :], self.att_weight22)
        att_board[:, :num_type1, num_type1:, :] = torch.matmul(
            att_map[:, :num_type1, num_type1:, :], self.att_weight12)
        att_board[:, num_type1:, :num_type1, :] = torch.matmul(
            att_map[:, num_type1:, :num_type1, :], self.att_weight12)

        att_map = att_board

        # att_map = torch.matmul(att_map, self.att_weight12)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        # attention-weighted aggregation plus a per-node skip projection
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _project_master(self, x, master, att_map):
        # aggregate all nodes into the master via its attention weights
        x1 = self.proj_with_attM(torch.matmul(
            att_map.squeeze(-1).unsqueeze(1), x))
        x2 = self.proj_without_attM(master)

        return x1 + x2

    def _apply_BN(self, x):
        # flatten nodes into the batch axis so BatchNorm1d sees (N, dim)
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        # Xavier-initialized free parameter (used for attention weights)
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out
342
+
343
+
344
class GraphPool(nn.Module):
    """Attention-based graph pooling: keep the top-k fraction of nodes.

    Nodes are scored with a sigmoid-gated linear projection; the highest-
    scoring ceil(k * n) nodes (at least 1) are kept, scaled by their scores.
    """
    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k                     # fraction of nodes to retain
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)   # per-node scalar score
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        Z = self.drop(h)
        weights = self.proj(Z)
        scores = self.sigmoid(weights)
        new_h = self.top_k_graph(scores, h, self.k)

        return new_h

    def top_k_graph(self, scores, h, k):
        """
        args
        =====
        scores: attention-based weights (#bs, #node, 1)
        h: graph data (#bs, #node, #dim)
        k: ratio of remaining nodes, (float)

        returns
        =====
        h: graph pool applied data (#bs, #node', #dim)
        """
        _, n_nodes, n_feat = h.size()
        # keep at least one node even for tiny graphs / small k
        n_nodes = max(int(n_nodes * k), 1)
        _, idx = torch.topk(scores, n_nodes, dim=1)
        idx = idx.expand(-1, -1, n_feat)

        # gate node features by their scores, then gather the selected nodes
        h = h * scores
        h = torch.gather(h, 1, idx)

        return h
382
+
383
+
384
class CONV(nn.Module):
    """Fixed (non-learnable) sinc band-pass filterbank convolution.

    Builds ``out_channels`` band-pass FIR filters whose cut-off frequencies
    are spaced linearly on the mel scale, each filter being the difference of
    two windowed sinc low-pass filters (Hamming window). The filters are
    computed once in ``__init__`` and applied with ``F.conv1d`` in
    ``forward``. Only single-channel input is supported.
    """

    @staticmethod
    def to_mel(hz):
        # Standard Hz -> mel conversion.
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        # Inverse of ``to_mel``: mel -> Hz.
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self,
                 out_channels,
                 kernel_size,
                 sample_rate=16000,
                 in_channels=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 groups=1,
                 mask=False):
        super().__init__()
        if in_channels != 1:

            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
                in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e, perfectly symmetrics)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        # Band edges: linearly spaced on the mel scale over [0, sample_rate/2],
        # then converted back to Hz; out_channels+1 edges -> out_channels bands.
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        # NOTE(review): band_pass/hsupp are plain tensor attributes, not
        # registered buffers, so they do not follow ``model.to(device)`` or
        # appear in state_dict; forward compensates with ``.to(x.device)``.
        self.mel = filbandwidthsf
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            # Band-pass = high-cutoff sinc low-pass minus low-cutoff sinc
            # low-pass, then Hamming-windowed to reduce spectral leakage.
            hHigh = (2*fmax/self.sample_rate) * \
                np.sinc(2*fmax*self.hsupp/self.sample_rate)
            hLow = (2*fmin/self.sample_rate) * \
                np.sinc(2*fmin*self.hsupp/self.sample_rate)
            hideal = hHigh - hLow

            self.band_pass[i, :] = Tensor(np.hamming(
                self.kernel_size)) * Tensor(hideal)

    def forward(self, x, mask=False):
        """Apply the filterbank to ``x``.

        When ``mask`` is True, a random contiguous run of up to 20 filters is
        zeroed out (presumably a frequency-masking augmentation — confirm
        against the training recipe). The result is non-deterministic in
        that case.
        """
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            # Pick a random band width A in [0, 20) and a random start A0,
            # then silence those filters.
            A = np.random.uniform(0, 20)
            A = int(A)
            A0 = random.randint(0, band_pass_filter.shape[0] - A)
            band_pass_filter[A0:A0 + A, :] = 0
        else:
            band_pass_filter = band_pass_filter

        self.filters = (band_pass_filter).view(self.out_channels, 1,
                                               self.kernel_size)

        return F.conv1d(x,
                        self.filters,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        bias=None,
                        groups=1)
470
+
471
+
472
class Residual_block(nn.Module):
    """2-D residual block for the AASIST encoder.

    ``nb_filts`` is ``[in_channels, out_channels]``; when the two differ, a
    1x3 convolution projects the identity branch before the residual add.
    ``first=True`` skips the leading BatchNorm+SELU pre-activation (used for
    the first encoder block). Output is max-pooled over (1, 3).
    """

    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first

        # Pre-activation BN is only created when this is not the first block.
        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)

        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        # Channel mismatch between branches -> project the identity path.
        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

        else:
            self.downsample = False
        self.mp = nn.MaxPool2d((1, 3))  # self.mp = nn.MaxPool2d((1,4))

    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.selu(out)
        else:
            out = x
        # NOTE(review): conv1 is applied to the raw input ``x``, so the
        # BN+SELU result computed just above is discarded. This looks like a
        # bug, but pretrained checkpoints were trained with this exact graph
        # — confirm against the upstream AASIST reference before changing.
        out = self.conv1(x)

        # print('out',out.shape)
        out = self.bn2(out)
        out = self.selu(out)
        # print('out',out.shape)
        out = self.conv2(out)
        #print('conv2 out',out.shape)
        if self.downsample:
            identity = self.conv_downsample(identity)

        # Residual add (in-place), then temporal max-pool.
        out += identity
        out = self.mp(out)
        return out
526
+
527
+
528
class Model(nn.Module):
    """AASIST spoofing-detection model.

    Pipeline: sinc filterbank front-end (CONV) -> residual-block encoder ->
    two graph-attention branches (spectral GAT-S and temporal GAT-T) fused by
    two parallel heterogeneous GAT stacks with learnable master nodes ->
    max/avg readout -> linear classifier.

    ``d_args`` keys used: "filts", "gat_dims", "pool_ratios", "temperatures",
    "first_conv", and optionally "output_cls" (number of output classes,
    default 2).

    forward returns ``(last_hidden, output)`` where ``output`` is
    softmax-normalized class probabilities.
    """

    def __init__(self, d_args):
        super().__init__()

        self.d_args = d_args
        filts = d_args["filts"]
        gat_dims = d_args["gat_dims"]
        pool_ratios = d_args["pool_ratios"]
        temperatures = d_args["temperatures"]

        # Sinc filterbank front-end operating on raw waveform.
        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=d_args["first_conv"],
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)

        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        # Six residual blocks; each wrapped in its own nn.Sequential
        # (checkpoint key layout depends on this nesting).
        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))

        # NOTE(review): pos_S hard-codes 23 spectral nodes — this must match
        # the encoder output height for the configured nb_samp/first_conv;
        # confirm if the input length changes.
        self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[1])

        # Two parallel heterogeneous spectro-temporal GAT stacks.
        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        # Readout concatenates 5 vectors of size gat_dims[1]:
        # T_max, T_avg, S_max, S_avg, master.
        if "output_cls" in d_args:
            self.out_layer = nn.Linear(5 * gat_dims[1], d_args["output_cls"])
        else:
            self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def forward(self, x, Freq_aug=False):
        # x: raw waveform batch; Freq_aug enables the front-end's random
        # filter masking (training-time augmentation).
        x = x.unsqueeze(1)
        x = self.conv_time(x, mask=Freq_aug)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)

        # get embeddings using encoder
        # (#bs, #filt, #spec, #seq)
        e = self.encoder(x)

        # spectral GAT (GAT-S)
        e_S, _ = torch.max(torch.abs(e), dim=3)  # max along time
        e_S = e_S.transpose(1, 2) + self.pos_S

        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)

        # temporal GAT (GAT-T)
        e_T, _ = torch.max(torch.abs(e), dim=2)  # max along freq
        e_T = e_T.transpose(1, 2)

        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)

        # learnable master node
        # NOTE(review): the expanded master1/master2 below are never used —
        # the raw parameters self.master1/self.master2 are passed to the GAT
        # layers instead. Pretrained weights were trained this way; confirm
        # before changing.
        master1 = self.master1.expand(x.size(0), -1, -1)
        master2 = self.master2.expand(x.size(0), -1, -1)

        # inference 1
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
            out_T, out_S, master=self.master1)

        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)

        # Second heterogeneous layer output is added residually.
        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
            out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug

        # inference 2
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
            out_T, out_S, master=self.master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
            out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug

        # Stochastic-depth-style dropout on each branch before the max-merge.
        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        # Element-wise max merges the two parallel stacks.
        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        # Node-level readout: max of |.| and mean over nodes, per branch.
        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)

        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)

        last_hidden = self.drop(last_hidden)
        output = self.out_layer(last_hidden)

        # NOTE(review): softmax is applied here; if training uses a CE loss
        # that expects raw logits, this double-normalizes — verify against
        # the training loop.
        output=F.softmax(output,dim=1)

        return last_hidden, output