JunSiang26 committed on
Commit
9f2b6db
·
0 Parent(s):

Pure production deploy

.gitattributes ADDED
@@ -0,0 +1,4 @@
1
+ backend/models/*.pth filter=lfs diff=lfs merge=lfs -text
2
+ **/*.pth filter=lfs diff=lfs merge=lfs -text
3
+ backend/models/**/*.pth filter=lfs diff=lfs merge=lfs -text
4
+ backend/models/Ablation[[:space:]]models/*.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
1
+ .venv
2
+ MLAAD-tiny/
3
+ data/
4
+ *.pth
5
+ Download_sample (Ignore)/
6
+ __pycache__/
7
+ project_requirements.txt
8
+ generate_notebook.py
9
+ backend/precomputed_features/
10
+
11
+ # Models
12
+ *.pth
13
+ backend/models/*.pth
14
+
15
+ # Environment
16
+ .env
17
+ .DS_Store
18
+ node_modules/
.vscode/settings.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "python-envs.defaultEnvManager": "ms-python.python:conda",
3
+ "python-envs.defaultPackageManager": "ms-python.python:conda"
4
+ }
AI Project 2026.pdf ADDED
Binary file (73.1 kB).
Dockerfile ADDED
@@ -0,0 +1,25 @@
1
+ # Use a slim Python image
2
+ FROM python:3.10-slim
3
+
4
+ # Install system dependencies for audio processing
5
+ RUN apt-get update && apt-get install -y \
6
+ libsndfile1 \
7
+ ffmpeg \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # Set working directory
11
+ WORKDIR /app
12
+
13
+ # Copy requirements and install
14
+ COPY requirements.txt .
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ # Copy the backend code and models
18
+ COPY backend/ ./backend/
19
+
20
+ # Expose the port FastAPI will run on
21
+ EXPOSE 7860
22
+
23
+ # Command to run the application
24
+ # Note: Hugging Face uses port 7860 by default
25
+ CMD ["uvicorn", "backend.app:app", "--host", "0.0.0.0", "--port", "7860"]
Model_Training_(Odio).ipynb ADDED
The diff for this file is too large to render.
README.md ADDED
@@ -0,0 +1,108 @@
1
+ <img width="1080" height="324" alt="odiocheck" src="https://github.com/user-attachments/assets/4d7b573e-5b0b-4fc7-85de-da60bbb701c2" />
2
+
3
+ # OdioCheck - Deepfake Voice Detection AI
4
+ *50.021 Artificial Intelligence Project*
5
+
6
+ ## Theme
7
+ **AI for Security & Social Good** (UN SDG #16: Peace, Justice, and Strong Institutions)
8
+ OdioCheck tackles the rising threat of audio deepfakes used in scams and misdirection.
9
+
10
+ ## Requirements Checklist
11
+ - [x] **Fully functioning code:** Complete end-to-end PyTorch implementation from dataset loading to real-time inference via a web UI.
12
+ - [x] **Baseline models (×3):**
13
+ - **Wav2Vec2** — self-supervised transformer feature extractor (frozen) + attentive pooling classifier (`backend/models.py`)
14
+ - **AASIST** — graph-based SOTA baseline using sinc-filter frontend + spectro-temporal heterogeneous graph attention (`backend/models.py`)
15
+ - **CQCC Baseline** — standard CNN processing Constant-Q Cepstral Coefficients (`backend/models.py`)
16
+ - [x] **SOTA Custom Model:** `ImprovedWav2Vec2CQCCDetector` — a novel fusion architecture combining Wav2Vec 2.0 and CQCC features via **bidirectional cross-attention**, followed by a **Graph Attention** backend (`backend/models.py`).
17
+ - [x] **Ablation Study (×4):** Four ablation variants systematically isolate each architectural component to validate the custom model design:
18
+ - **Ablation 1** — Wav2Vec2 + Graph (no CQCC, no cross-attention)
19
+ - **Ablation 2** — CQCC + Graph (no Wav2Vec2, no cross-attention)
20
+ - **Ablation 3** — Wav2Vec2 + CQCC + Simple Concat + Graph (no cross-attention)
21
+ - **Ablation 4** — Wav2Vec2 + CQCC + Cross-Attention + Linear (no Graph Attention)
22
+ - [x] **Fully Working Frontend:** Glassmorphic UI (Tailwind + Vanilla JS) served via FastAPI. Supports OGG/MP3/M4A/FLAC/WAV. Shows **side-by-side** predictions from all four primary models with real-time animated confidence bars and a per-window **temporal analysis timeline chart**.
23
+ - [x] **Cross-lingual Dataset Split:** Trained on English audio (`MLAAD-tiny/en`), tested on unseen German audio (`MLAAD-tiny/de`) for out-of-distribution generalisation evaluation.
24
+ - [x] **CQCC Feature Caching:** Pre-computed CQCC tensors are cached to disk to avoid redundant computation across training runs.
25
+
26
+ ---
27
+
28
+ ## Installation
29
+
30
+ Ensure you have Python 3.9+ installed. Install all dependencies:
31
+ ```bash
32
+ pip install -r requirements.txt
33
+ ```
34
+
35
+ ### Dataset Download
36
+ We use the `MLAAD-tiny` dataset (multi-language audio deepfakes). Download it from Hugging Face before training:
37
+ ```bash
38
+ pip install -U "huggingface_hub[cli]"
39
+ huggingface-cli download mueller91/MLAAD-tiny --repo-type dataset --local-dir MLAAD-tiny
40
+ ```
41
+
42
+ ---
43
+
44
+ ## Running the Project
45
+
46
+ ### Step 1 — (Optional) Pre-compute CQCC Cache
47
+ Pre-computing CQCC features once dramatically speeds up all subsequent training runs:
48
+ ```bash
49
+ python backend/train.py --precompute-cqcc-only
50
+ ```
51
+
52
+ ### Step 2 — Train All Models
53
+ Trains all 4 primary models + 4 ablation variants, evaluates on the German test set, and saves `.pth` weights to `backend/models/`:
54
+ ```bash
55
+ python backend/train.py
56
+ ```
57
+
58
+ #### Available Training Flags
59
+ | Flag | Default | Description |
60
+ |---|---|---|
61
+ | `--val-split F` | `0.2` | Fraction of English data reserved for validation (0–0.5). |
62
+ | `--data-dir PATH` | auto | Override dataset root (must contain `original/` and `fake/` folders). |
63
+ | `--cqcc-cache-dir PATH` | `backend/precomputed_features/cqcc` | Where to read/write cached CQCC tensors. |
64
+ | `--precompute-cqcc-only` | `False` | Build CQCC cache and exit without training. |
65
+ | `--force-rebuild-cqcc` | `False` | Recompute CQCC cache even if files already exist. |
66
+ | `--smoke-test` | `False` | Run one forward pass through every model and exit — useful for verifying setup. |
67
+
68
+ #### Quick Smoke Test
69
+ Verify all models initialise and run a forward pass correctly without full training:
70
+ ```bash
71
+ python backend/train.py --smoke-test
72
+ ```
73
+
74
+ ### Step 3 — Start the Web Interface
75
+ ```bash
76
+ uvicorn backend.app:app --reload
77
+ ```
78
+ Open **http://127.0.0.1:8000** in your browser. Upload any audio file (WAV, MP3, OGG, FLAC, M4A) to see simultaneous predictions from all four primary models plus an animated temporal confidence chart.
79
+
80
+ ---
81
+
82
+ ## Project Architecture
83
+
84
+ ```
85
+ AI Project/
86
+ ├── backend/
87
+ │ ├── models.py # All model architectures (3 baselines + custom + 4 ablations)
88
+ │ ├── dataset.py # AudioDataset with CQCC caching + data augmentation
89
+ │ ├── train.py # Full training + evaluation pipeline (CLI-driven)
90
+ │ ├── app.py # FastAPI inference server (windowed temporal analysis)
91
+ │ ├── preprocess.py # Standalone preprocessing utilities
92
+ │ └── models/ # Saved .pth weight files (generated after training)
93
+ ├── frontend/
94
+ │ ├── index.html # Glassmorphic UI shell
95
+ │ ├── script.js # File upload, Chart.js timeline, model panel rendering
96
+ │ └── style.css # Custom glassmorphism styles
97
+ ├── MLAAD-tiny/ # Dataset (downloaded separately)
98
+ ├── requirements.txt # Python dependencies
99
+ └── colab_training_notebook.ipynb # Google Colab training notebook
100
+ ```
101
+
102
+ ---
103
+
104
+ ## Working with Other Datasets
105
+ To replace MLAAD-tiny with another dataset (e.g., ASVspoof):
106
+ 1. Place your `fake/` and `original/` (or `real/`) audio folders into a `data/` directory at the project root.
107
+ 2. The `AudioDataset` in `dataset.py` auto-detects and falls back to the `data/` directory if `MLAAD-tiny` is absent.
108
+ 3. Re-run `python backend/train.py`. The full pipeline runs identically.
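The flag table in the README above could be wired up with `argparse` roughly as follows. This is only a sketch: `backend/train.py` is not part of the rendered diff, so the `build_parser` and `val_split_fraction` helpers and the range check are assumptions. Only the flag names, defaults, and descriptions come from the README.

```python
import argparse


def val_split_fraction(text: str) -> float:
    """Parse --val-split and enforce the documented 0 to 0.5 range."""
    value = float(text)
    if not 0.0 <= value <= 0.5:
        raise argparse.ArgumentTypeError("--val-split must be between 0 and 0.5")
    return value


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical helper mirroring the README flag table, not the real train.py.
    parser = argparse.ArgumentParser(description="Train and evaluate the OdioCheck models")
    parser.add_argument("--val-split", type=val_split_fraction, default=0.2,
                        help="Fraction of English data reserved for validation")
    parser.add_argument("--data-dir", default=None,
                        help="Dataset root containing original/ and fake/")
    parser.add_argument("--cqcc-cache-dir", default="backend/precomputed_features/cqcc",
                        help="Where to read/write cached CQCC tensors")
    parser.add_argument("--precompute-cqcc-only", action="store_true",
                        help="Build the CQCC cache and exit without training")
    parser.add_argument("--force-rebuild-cqcc", action="store_true",
                        help="Recompute the CQCC cache even if files exist")
    parser.add_argument("--smoke-test", action="store_true",
                        help="Run one forward pass through every model and exit")
    return parser
```

For example, `build_parser().parse_args(["--smoke-test"])` yields a namespace with `smoke_test` set and every other flag at its documented default.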
backend/app.py ADDED
@@ -0,0 +1,233 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ from fastapi import FastAPI, UploadFile, File
6
+ from fastapi.responses import JSONResponse
7
+ from fastapi.staticfiles import StaticFiles
8
+ from fastapi.middleware.cors import CORSMiddleware
9
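+ import sys
10
+ import librosa
11
+
12
+ # Put backend/ on sys.path first so `dataset` and `models` resolve under `uvicorn backend.app:app`.
13
+ sys.path.append(os.path.dirname(__file__))
14
+ from dataset import compute_cqcc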
15
+ from models import (
16
+ Wav2Vec2SpoofDetector,
17
+ AASISTDetector,
18
+ CQCCBaselineDetector,
19
+ ImprovedWav2Vec2CQCCDetector
20
+ )
21
+
22
+ app = FastAPI(title="Deepfake Voice Detection")
23
+
24
+ app.add_middleware(
25
+ CORSMiddleware,
26
+ allow_origins=["*"],
27
+ allow_credentials=True,
28
+ allow_methods=["*"],
29
+ allow_headers=["*"],
30
+ )
31
+
32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+
34
+ # -------------------------------------------------------
35
+ # Load Models
36
+ # -------------------------------------------------------
37
+
38
+ models_dir = os.path.join(os.path.dirname(__file__), "models")
39
+
40
+ def load_model(model, filename):
41
+ path = os.path.join(models_dir, filename)
42
+ if os.path.exists(path):
43
+ state_dict = torch.load(path, map_location=device)
44
+ # Handle weight_norm parametrization mismatch (common in Wav2Vec2 between versions)
45
+ # This converts the 'parametrizations' keys back to 'weight_g' and 'weight_v'
46
+ new_state_dict = {}
47
+ for k, v in state_dict.items():
48
+ if "pos_conv_embed.conv.parametrizations.weight.original0" in k:
49
+ new_key = k.replace("parametrizations.weight.original0", "weight_g")
50
+ new_state_dict[new_key] = v
51
+ elif "pos_conv_embed.conv.parametrizations.weight.original1" in k:
52
+ new_key = k.replace("parametrizations.weight.original1", "weight_v")
53
+ new_state_dict[new_key] = v
54
+ else:
55
+ new_state_dict[k] = v
56
+ model.load_state_dict(new_state_dict)
57
+ print(f"Loaded {filename}")
58
+ else:
59
+ print(f"WARNING: {filename} not found. Run train.py first!")
60
+ model.eval()
61
+
62
+ return model
63
+
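The key renaming inside `load_model` can be looked at in isolation. The sketch below works on a plain dictionary so it needs no checkpoint file. The function name `remap_weight_norm_keys` and the sample keys in the usage note are illustrative assumptions; the two substrings being swapped are the ones that appear in the code above.

```python
# Newer-style parametrization names mapped to the older weight_g / weight_v names.
RENAMES = {
    "pos_conv_embed.conv.parametrizations.weight.original0": "pos_conv_embed.conv.weight_g",
    "pos_conv_embed.conv.parametrizations.weight.original1": "pos_conv_embed.conv.weight_v",
}


def remap_weight_norm_keys(state_dict: dict) -> dict:
    """Return a copy of state_dict with the positional-conv weight_norm keys renamed."""
    remapped = {}
    for key, value in state_dict.items():
        for new_style, old_style in RENAMES.items():
            if new_style in key:
                key = key.replace(new_style, old_style)
                break
        remapped[key] = value  # all other entries pass through unchanged
    return remapped
```

Calling it on `{"wav2vec.encoder.pos_conv_embed.conv.parametrizations.weight.original0": 1}` would return the same value under `wav2vec.encoder.pos_conv_embed.conv.weight_g`.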
64
+
65
+ wav2vec_model = load_model(
66
+ Wav2Vec2SpoofDetector(num_classes=2).to(device),
67
+ "wav2vec2.pth"
68
+ )
69
+
70
+ aasist_model = load_model(
71
+ AASISTDetector(num_classes=2).to(device),
72
+ "aasist.pth"
73
+ )
74
+
75
+ cqcc_baseline_model = load_model(
76
+ CQCCBaselineDetector(num_classes=2).to(device),
77
+ "cqcc_baseline.pth"
78
+ )
79
+
80
+ custom_hybrid_model = load_model(
81
+ ImprovedWav2Vec2CQCCDetector(num_classes=2).to(device),
82
+ "custom_hybrid.pth"
83
+ )
84
+
85
+
86
+ # -------------------------------------------------------
87
+ # Audio Preprocessing (mirrors dataset.py __getitem__)
88
+ # -------------------------------------------------------
89
+
90
+ TARGET_LEN = 64600 # AASIST standard: ~4.04 s at 16 kHz (64600 / 16000 = 4.0375)
91
+ CQCC_N_BINS = 60 # Matches AudioDataset default
92
+
93
+ # 50% overlap: each step is half a window (~2s), giving smooth temporal curves
94
+ # without running 4x Wav2Vec2 passes per second.
95
+ WINDOW_STEP = TARGET_LEN // 2
96
+
97
+
98
+ def preprocess_window(wav_np: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
99
+ """
100
+ Crop or pad a single audio window to TARGET_LEN, then compute waveform
101
+ and CQCC tensors — identical to AudioDataset.__getitem__ (non-augmented).
102
+
103
+ Returns:
104
+ wav : (1, TARGET_LEN) float32 tensor
105
+ cqcc : (1, 20, T) float32 tensor
106
+ """
107
+ # Center-crop or zero-pad to exactly TARGET_LEN (matches eval path in dataset.py)
108
+ if len(wav_np) > TARGET_LEN:
109
+ start = (len(wav_np) - TARGET_LEN) // 2
110
+ wav_np = wav_np[start : start + TARGET_LEN]
111
+ elif len(wav_np) < TARGET_LEN:
112
+ wav_np = np.pad(wav_np, (0, TARGET_LEN - len(wav_np)), mode='constant')
113
+
114
+ wav = torch.from_numpy(wav_np).unsqueeze(0).float()
115
+ cqcc = compute_cqcc(wav_np, n_bins=CQCC_N_BINS) # → (1, 20, T)
116
+ return wav, cqcc
117
+
118
+
119
+ def run_window(wav: torch.Tensor, cqcc: torch.Tensor) -> dict:
120
+ """
121
+ Run all four models on a single window and return fake probabilities (0–100).
122
+ """
123
+ wav = wav.unsqueeze(0).to(device) # (1, 1, TARGET_LEN)
124
+ cqcc = cqcc.unsqueeze(0).to(device) # (1, 1, 20, T)
125
+
126
+ with torch.no_grad():
127
+ w2v_prob = torch.softmax(wav2vec_model(wav), dim=1)[0][1].item()
128
+ aasist_prob = torch.softmax(aasist_model(wav), dim=1)[0][1].item()
129
+ cqcc_prob = torch.softmax(cqcc_baseline_model(cqcc), dim=1)[0][1].item()
130
+ custom_prob = torch.softmax(custom_hybrid_model(wav, cqcc), dim=1)[0][1].item()
131
+
132
+ return {
133
+ "wav2vec2": round(w2v_prob * 100, 2),
134
+ "aasist": round(aasist_prob * 100, 2),
135
+ "cqcc_baseline": round(cqcc_prob * 100, 2),
136
+ "custom_hybrid": round(custom_prob * 100, 2),
137
+ }
138
+
139
+
140
+ def aggregate_prediction(fake_prob_pct: float) -> dict:
141
+ """Convert a mean fake probability into the standard prediction dict."""
142
+ return {
143
+ "prediction": "FAKE" if fake_prob_pct > 50 else "REAL",
144
+ "fake_probability": fake_prob_pct,
145
+ "real_probability": round(100 - fake_prob_pct, 2),
146
+ }
147
+
148
+
149
+ # -------------------------------------------------------
150
+ # Prediction Endpoint
151
+ # -------------------------------------------------------
152
+ @app.post("/api/predict")
153
+ async def predict(file: UploadFile = File(...)):
154
+ temp_path = f"temp_{os.path.basename(file.filename or 'upload')}"  # basename blocks path traversal via the client-supplied name
155
+ try:
156
+ with open(temp_path, "wb") as f:
157
+ f.write(await file.read())
158
+
159
+ # Load at 16 kHz mono — identical to librosa.load call in dataset.py
160
+ wav_np, sr = librosa.load(temp_path, sr=16000, mono=True)
161
+
162
+ # -------------------------------------------------------
163
+ # Slice into overlapping windows of TARGET_LEN samples.
164
+ # Step = 50% overlap. Very short clips produce a single window.
165
+ # -------------------------------------------------------
166
+ total_samples = len(wav_np)
167
+ starts = list(range(0, max(total_samples, 1), WINDOW_STEP))  # max(..., 1) keeps one window for empty audio, so the mean below never divides by zero
168
+
169
+ window_probs = [] # per-window fake-probability dicts
170
+ window_labels = [] # x-axis: start of each window in seconds
171
+
172
+ for start in starts:
173
+ chunk = wav_np[start : start + TARGET_LEN]
174
+ wav_t, cqcc_t = preprocess_window(chunk)
175
+ probs = run_window(wav_t, cqcc_t)
176
+ window_probs.append(probs)
177
+
178
+ start_sec = round(start / sr, 2)
179
+ window_labels.append(start_sec)
180
+
181
+ # -------------------------------------------------------
182
+ # Overall prediction = mean fake probability across all windows
183
+ # -------------------------------------------------------
184
+ model_keys = ["wav2vec2", "aasist", "cqcc_baseline", "custom_hybrid"]
185
+ overall = {}
186
+ for key in model_keys:
187
+ mean_fake = round(
188
+ sum(w[key] for w in window_probs) / len(window_probs), 2
189
+ )
190
+ overall[key] = aggregate_prediction(mean_fake)
191
+
192
+ # -------------------------------------------------------
193
+ # Time-series data for the frontend chart
194
+ # -------------------------------------------------------
195
+ timeline = {
196
+ key: [w[key] for w in window_probs]
197
+ for key in model_keys
198
+ }
199
+
200
+ return JSONResponse({
201
+ "overall": overall, # {model: {prediction, fake_probability, real_probability}}
202
+ "timeline": timeline, # {model: [fake_prob_pct, ...]} — one value per window
203
+ "window_labels": window_labels, # [start_sec, ...] — x-axis in seconds (starts at 0.0)
204
+ })
205
+
206
+ except Exception as e:
207
+ return JSONResponse({"error": str(e)}, status_code=500)
208
+
209
+ finally:
210
+ if os.path.exists(temp_path):
211
+ os.remove(temp_path)
212
+
213
+ # -------------------------------------------------------
214
+ # Serve frontend
215
+ # -------------------------------------------------------
216
+
217
+ frontend_dir = os.path.join(os.path.dirname(__file__), "..", "frontend")
218
+
219
+ if os.path.exists(frontend_dir):
220
+ app.mount("/", StaticFiles(directory=frontend_dir, html=True), name="frontend")
221
+
222
+
223
+ # -------------------------------------------------------
224
+ # Run Server
225
+ # -------------------------------------------------------
226
+
227
+ if __name__ == "__main__":
228
+
229
+ import uvicorn
230
+
231
+ print("Starting server at http://127.0.0.1:8000")
232
+
233
+ uvicorn.run(app, host="127.0.0.1", port=8000)
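The windowing and averaging performed by the `/api/predict` endpoint can be sketched without any audio or model dependencies. The helper names `window_starts` and `mean_fake_probability` are hypothetical; the constants and the arithmetic follow `app.py`. The `max(total_samples, 1)` guard ensures an empty upload still yields one (zero-padded) window.

```python
TARGET_LEN = 64600             # samples per window, as in app.py
WINDOW_STEP = TARGET_LEN // 2  # 50% overlap between consecutive windows


def window_starts(total_samples: int, step: int = WINDOW_STEP) -> list[int]:
    """Start offsets of the analysis windows; always at least one window."""
    return list(range(0, max(total_samples, 1), step))


def mean_fake_probability(window_probs: list[dict], key: str) -> float:
    """Average one model's per-window fake probability (0 to 100)."""
    return round(sum(w[key] for w in window_probs) / len(window_probs), 2)
```

Note that the final window usually holds fewer than `TARGET_LEN` samples and is zero-padded by `preprocess_window`, so for long files the last point on the timeline is computed from partly silent input.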
backend/dataset.py ADDED
@@ -0,0 +1,211 @@
1
+ import os
2
+ import hashlib
3
+ import torch
4
+ import torchaudio
5
+ import numpy as np
6
+ from torch.utils.data import Dataset
7
+ import librosa
8
+ from scipy.fftpack import dct
9
+
10
+ def compute_cqcc(wav_np, n_bins, sample_rate=16000, hop_length=160, num_coeffs=20):
11
+ """Compute CQCC features from a mono waveform numpy array."""
12
+ try:
13
+ cqt = np.abs(
14
+ librosa.cqt(
15
+ wav_np,
16
+ sr=sample_rate,
17
+ n_bins=n_bins,
18
+ hop_length=hop_length,
19
+ fmin=librosa.note_to_hz('C1')
20
+ )
21
+ )
22
+ log_power = librosa.amplitude_to_db(cqt, ref=np.max)
23
+ cqcc = dct(log_power, type=2, axis=0, norm='ortho')[:num_coeffs]
24
+ return torch.from_numpy(cqcc).unsqueeze(0).float()
25
+ except Exception:
26
+ # Fallback for very short or invalid audio.
27
+ return torch.zeros((1, num_coeffs, 10), dtype=torch.float32)
28
+
29
+ class AudioDataset(Dataset):
30
+ def __init__(self, data_dir=None, n_bins=60, augment=False, cqcc_cache_dir=None, target_lang=None):
31
+ if data_dir is None:
32
+ # Check if MLAAD-tiny exists, else fallback to 'data'
33
+ mlaad_dir = os.path.join(os.path.dirname(__file__), "..", "MLAAD-tiny")
34
+ if os.path.exists(mlaad_dir):
35
+ data_dir = mlaad_dir
36
+ else:
37
+ data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
38
+
39
+ self.data_dir = data_dir
40
+ self.files = []
41
+ self.labels = []
42
+ self.n_bins = n_bins
43
+ self.augment = augment
44
+ self.cqcc_cache_dir = cqcc_cache_dir
45
+ self.target_lang = target_lang
46
+
47
+ real_path = os.path.join(data_dir, "original")
48
+ if not os.path.exists(real_path):
49
+ real_path = os.path.join(data_dir, "real")
50
+
51
+ fake_path = os.path.join(data_dir, "fake")
52
+
53
+ for root, dirs, files in os.walk(real_path):
54
+ dirs.sort()
55
+ files.sort()
56
+ for f in files:
57
+ if f.endswith('.wav') or f.endswith('.flac'):
58
+ if self.target_lang:
59
+ rel_root = os.path.relpath(root, real_path).replace('\\', '/')
60
+ if not rel_root.startswith(self.target_lang):
61
+ continue
62
+ self.files.append(os.path.join(root, f))
63
+ self.labels.append(0) # 0 = Real
64
+
65
+ for root, dirs, files in os.walk(fake_path):
66
+ dirs.sort()
67
+ files.sort()
68
+ for f in files:
69
+ if f.endswith('.wav') or f.endswith('.flac'):
70
+ if self.target_lang:
71
+ rel_root = os.path.relpath(root, fake_path).replace('\\', '/')
72
+ if not rel_root.startswith(self.target_lang):
73
+ continue
74
+ self.files.append(os.path.join(root, f))
75
+ self.labels.append(1) # 1 = Fake
76
+
77
+ if self.cqcc_cache_dir is not None:
78
+ os.makedirs(self.cqcc_cache_dir, exist_ok=True)
79
+
80
+ def __len__(self):
81
+ return len(self.files)
82
+
83
+ def _cqcc_cache_path(self, audio_path):
84
+ rel_path = os.path.relpath(audio_path, start=self.data_dir)
85
+ cache_key = hashlib.md5(audio_path.encode("utf-8")).hexdigest()
86
+ rel_stem = os.path.splitext(rel_path)[0]
87
+ safe_name = rel_stem.replace(os.sep, "__")
88
+ return os.path.join(self.cqcc_cache_dir, f"{safe_name}_{cache_key}.pt")
89
+
90
+ def _load_or_compute_cqcc(self, audio_path, wav_np, is_augmented=False):
91
+ if self.cqcc_cache_dir is None or is_augmented:
92
+ return compute_cqcc(wav_np, n_bins=self.n_bins)
93
+
94
+ cache_path = self._cqcc_cache_path(audio_path)
95
+ if os.path.exists(cache_path):
96
+ return torch.load(cache_path, map_location="cpu")
97
+
98
+ cqcc = compute_cqcc(wav_np, n_bins=self.n_bins)
99
+ torch.save(cqcc, cache_path)
100
+ return cqcc
101
+
102
+ def precompute_cqcc_cache(self, force=False):
103
+ """Materialize CQCC features to disk so training can reuse them."""
104
+ # tqdm is optional: the try/except below falls back to plain iteration if it is missing.
105
+ if self.cqcc_cache_dir is None:
106
+ raise ValueError("cqcc_cache_dir must be set to precompute CQCC features.")
107
+
108
+ try:
109
+ from tqdm.notebook import tqdm
110
+ iterable_files = tqdm(self.files, desc="Precomputing CQCC Cache")
111
+ except ImportError:
112
+ iterable_files = self.files
113
+
114
+ total = len(self.files)
115
+ for idx, audio_path in enumerate(iterable_files):
116
+ cache_path = self._cqcc_cache_path(audio_path)
117
+ if not force and os.path.exists(cache_path):
118
+ continue
119
+
120
+ try:
121
+ wav_np, _ = librosa.load(audio_path, sr=16000, mono=True)
122
+ cqcc = compute_cqcc(wav_np, n_bins=self.n_bins)
123
+ torch.save(cqcc, cache_path)
124
+ except Exception as e:
125
+ print(f"Error precomputing CQCC for {audio_path}: {e}")
126
+
127
+
128
+ if (idx + 1) % 100 == 0 or idx + 1 == total:
129
+ print(f"Precomputed CQCC {idx + 1}/{total}")
130
+
131
+ def __getitem__(self, idx):
132
+ audio_path = self.files[idx]
133
+ wav_np, sr = librosa.load(audio_path, sr=16000, mono=True)
134
+
135
+ is_augmented = False
136
+ # Augmentation on raw audio (Data Augmentation for generalizability)
137
+ if self.augment and np.random.rand() < 0.3:
138
+ # Apply only ONE augmentation type per sample to avoid over-modification
139
+ aug_type = np.random.choice(['noise', 'speed', 'pitch'], p=[0.33, 0.33, 0.34])
140
+
141
+ if aug_type == 'noise':
142
+ # SNR-based noise addition (reverted to original robust method)
143
+ signal_power = np.mean(wav_np**2)
144
+ if signal_power > 1e-10:
145
+ snr_db = np.random.uniform(10, 30)
146
+ snr_linear = 10**(snr_db / 10)
147
+ noise_power = signal_power / snr_linear
148
+ noise = np.random.randn(len(wav_np)) * np.sqrt(noise_power)
149
+ wav_np = wav_np + noise
150
+ is_augmented = True
151
+ elif aug_type == 'speed':
152
+ # Mild speed perturbation
153
+ speed_factor = np.random.uniform(0.95, 1.05)
154
+ wav_np = librosa.effects.time_stretch(wav_np, rate=speed_factor)
155
+ is_augmented = True
156
+ elif aug_type == 'pitch':
157
+ # Subtle pitch shift
158
+ n_steps = np.random.uniform(-1, 1)
159
+ wav_np = librosa.effects.pitch_shift(wav_np, sr=sr, n_steps=n_steps)
160
+ is_augmented = True
161
+
162
+ # Crop or pad to exactly 64600 samples (AASIST standard)
163
+ target_len = 64600
164
+ if len(wav_np) > target_len:
165
+ # Center crop or random crop for augment instead of taking just the start.
166
+ if self.augment:
167
+ start = np.random.randint(0, len(wav_np) - target_len)
168
+ else:
169
+ start = (len(wav_np) - target_len) // 2
170
+ wav_np = wav_np[start:start+target_len]
171
+ elif len(wav_np) < target_len:
172
+ pad = target_len - len(wav_np)
173
+ wav_np = np.pad(wav_np, (0, pad), 'constant')
174
+
175
+ wav = torch.from_numpy(wav_np).unsqueeze(0).float()
176
+
177
+ cqcc = self._load_or_compute_cqcc(audio_path, wav_np, is_augmented=is_augmented)
178
+
179
+ return wav, cqcc, self.labels[idx]
180
+
181
+
182
+ def collate_variable_length(batch):
183
+
184
+ wavs, cqccs, labels = zip(*batch)
185
+ labels = torch.tensor(labels)
186
+
187
+ # ---------- WAVE ----------
188
+ max_wav_len = max(w.shape[-1] for w in wavs)
189
+
190
+ wavs_padded = []
191
+ for w in wavs:
192
+ if w.shape[-1] < max_wav_len:
193
+ pad = max_wav_len - w.shape[-1]
194
+ w = torch.nn.functional.pad(w, (0, pad))
195
+ wavs_padded.append(w)
196
+
197
+ wavs = torch.stack(wavs_padded, dim=0)
198
+
199
+ # ---------- CQCC ----------
200
+ max_cqcc_len = max(c.shape[-1] for c in cqccs)
201
+
202
+ cqccs_padded = []
203
+ for c in cqccs:
204
+ if c.shape[-1] < max_cqcc_len:
205
+ pad = max_cqcc_len - c.shape[-1]
206
+ c = torch.nn.functional.pad(c, (0, pad))
207
+ cqccs_padded.append(c)
208
+
209
+ cqccs = torch.stack(cqccs_padded, dim=0)
210
+
211
+ return wavs, cqccs, labels
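The noise augmentation in `AudioDataset.__getitem__` scales white noise so that the ratio of signal power to noise power matches a chosen SNR in decibels: `noise_power = signal_power / 10**(snr_db / 10)`. Below is a minimal standard-library sketch of that relation. The function name `add_noise_at_snr` and the explicit `rng` argument are assumptions made for reproducibility; `dataset.py` itself uses NumPy and draws the SNR uniformly from 10 to 30 dB.

```python
import math
import random


def add_noise_at_snr(samples: list[float], snr_db: float, rng: random.Random) -> list[float]:
    """Add white Gaussian noise whose power is signal_power / 10**(snr_db / 10)."""
    signal_power = sum(s * s for s in samples) / len(samples)
    if signal_power <= 1e-10:
        # Near-silent clip: leave it untouched, as dataset.py does.
        return list(samples)
    noise_power = signal_power / (10 ** (snr_db / 10))
    sigma = math.sqrt(noise_power)  # standard deviation of the noise
    return [s + rng.gauss(0.0, sigma) for s in samples]
```

At 20 dB the noise power is one hundredth of the signal power, which keeps the perturbation mild enough that the label should remain valid.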
backend/models.py ADDED
@@ -0,0 +1,1019 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import Wav2Vec2Model
5
+
6
+
7
+ # ============================================================
8
+ # 1. Wav2Vec2 Detector (Self-supervised Transformer Baseline)
9
+ # ============================================================
10
+ class AttentivePooling(nn.Module):
11
+ def __init__(self, dim):
12
+ super().__init__()
13
+ self.attn = nn.Sequential(
14
+ nn.Linear(dim, dim),
15
+ nn.Tanh(),
16
+ nn.Linear(dim, 1)
17
+ )
18
+ def forward(self, x):
19
+ w = torch.softmax(self.attn(x), dim=1)
20
+ return torch.sum(w * x, dim=1)
21
+
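`AttentivePooling` collapses a sequence of frame vectors into a single vector by softmax-weighting the frames along the time axis. The sketch below shows that weighting on plain Python lists. It takes the per-frame scores as given instead of computing them with the two-layer scorer, and the helper names are illustrative.

```python
import math


def softmax(scores: list[float]) -> list[float]:
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attentive_pool(frames: list[list[float]], scores: list[float]) -> list[float]:
    """Weighted sum of frames (T x D) using softmax(scores) over the time axis."""
    weights = softmax(scores)
    dim = len(frames[0])
    return [sum(w * frame[d] for w, frame in zip(weights, frames)) for d in range(dim)]
```

With equal scores this reduces to mean pooling; a single dominant score makes the output approach that one frame.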
22
+ class Wav2Vec2SpoofDetector(nn.Module):
23
+ def __init__(self, num_classes=2, model_name="facebook/wav2vec2-base"):
24
+ super().__init__()
25
+ self.wav2vec = Wav2Vec2Model.from_pretrained(model_name)
26
+
27
+ #freeze model
28
+ for param in self.wav2vec.parameters():
29
+ param.requires_grad = False
30
+
31
+ hidden = self.wav2vec.config.hidden_size
32
+ self.pool = AttentivePooling(hidden)
33
+ self.classifier = nn.Sequential(
34
+ nn.LayerNorm(hidden),
35
+ nn.Dropout(0.2),
36
+ nn.Linear(hidden, num_classes)
37
+ )
38
+ def forward(self, x):
39
+ if x.dim() == 3:
40
+ x = x.squeeze(1)
41
+ out = self.wav2vec(x).last_hidden_state
42
+ pooled = self.pool(out)
43
+ return self.classifier(pooled)
44
+
45
+ # ============================================================
46
+ # 2. AASIST (SOTA Graph-based Baseline)
47
+ # ============================================================
48
+
49
+ import random
50
+ from typing import Union
51
+ import numpy as np
52
+ from torch import Tensor
53
+
54
+ # Original simplistic Graph Attention/Block kept for the Custom model dependent on it
55
+ class GraphAttention(nn.Module):
56
+ def __init__(self, in_dim, out_dim):
57
+ super().__init__()
58
+ self.fc = nn.Linear(in_dim, out_dim)
59
+ self.attn = nn.Linear(out_dim * 2, 1)
60
+
61
+ def forward(self, x):
62
+ h = self.fc(x)
63
+ # Instead of allocating O(N^2 * D) tensor arrays for pairwise combinations,
64
+ # we can decompose the linear attention matrix and use broadcasting!
65
+ # Memory consumption goes from ~10GB on N=400 to ~2MB.
66
+ W = self.attn.weight.squeeze()
67
+ D = h.shape[-1]
68
+
69
+ W_1 = W[:D]
70
+ W_2 = W[D:]
71
+
72
+ # Compute individual node scores: shape (B, N, 1)
73
+ score_i = torch.matmul(h, W_1).unsqueeze(-1)
74
+ score_j = torch.matmul(h, W_2).unsqueeze(-1)
75
+
76
+ # Broadcast (B, N, 1) + (B, 1, N) -> (B, N, N)
77
+ e = score_i + score_j.transpose(1, 2)
78
+
79
+ if self.attn.bias is not None:
80
+ e = e + self.attn.bias
81
+
82
+ alpha = F.softmax(e, dim=-1)
83
+ out = torch.matmul(alpha, h)
84
+ return out
85
+
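The memory saving described in `GraphAttention.forward` rests on one identity: a linear score over a concatenated pair splits into two per-node scores, `w·[h_i ‖ h_j] + b = w1·h_i + w2·h_j + b`, where `w = [w1 ‖ w2]`. The sketch below checks this on plain lists. The function names are hypothetical and no PyTorch is involved.

```python
def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def scores_concat(h, w, bias):
    """Naive form: score every concatenated pair [h_i || h_j] (N*N vectors of size 2D)."""
    return [[dot(hi + hj, w) + bias for hj in h] for hi in h]


def scores_decomposed(h, w, bias):
    """Split w into halves, score each node once, then add row and column scores."""
    d = len(h[0])
    w1, w2 = w[:d], w[d:]
    s_i = [dot(hi, w1) for hi in h]  # one score per node as the "source"
    s_j = [dot(hj, w2) for hj in h]  # one score per node as the "target"
    return [[si + sj + bias for sj in s_j] for si in s_i]
```

The decomposed form only ever materialises two length-N score vectors before the N x N sum, which is why it avoids building the pairwise feature tensor.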
86
+ class GraphBlock(nn.Module):
87
+ def __init__(self, dim):
88
+ super().__init__()
89
+ self.gat = GraphAttention(dim, dim)
90
+ self.norm = nn.LayerNorm(dim)
91
+ self.dropout = nn.Dropout(0.2)
92
+
93
+ def forward(self, x):
94
+ res = x
95
+ x = self.gat(x)
96
+ x = self.dropout(x)
97
+ x = self.norm(x + res)
98
+ return x
99
+
100
+ class GraphAttentionLayer(nn.Module):
101
+ def __init__(self, in_dim, out_dim, **kwargs):
102
+ super().__init__()
103
+
104
+ # attention map
105
+ self.att_proj = nn.Linear(in_dim, out_dim)
106
+ self.att_weight = self._init_new_params(out_dim, 1)
107
+
108
+ # project
109
+ self.proj_with_att = nn.Linear(in_dim, out_dim)
110
+ self.proj_without_att = nn.Linear(in_dim, out_dim)
111
+
112
+ # batch norm
113
+ self.bn = nn.BatchNorm1d(out_dim)
114
+
115
+ # dropout for inputs
116
+ self.input_drop = nn.Dropout(p=0.2)
117
+
118
+ # activate
119
+ self.act = nn.SELU(inplace=True)
120
+
121
+ # temperature
122
+ self.temp = 1.
123
+ if "temperature" in kwargs:
124
+ self.temp = kwargs["temperature"]
125
+
126
+ def forward(self, x):
127
+ '''
128
+ x :(#bs, #node, #dim)
129
+ '''
130
+ # apply input dropout
131
+ x = self.input_drop(x)
132
+
133
+ # derive attention map
134
+ att_map = self._derive_att_map(x)
135
+
136
+ # projection
137
+ x = self._project(x, att_map)
138
+
139
+ # apply batch norm
140
+ x = self._apply_BN(x)
141
+ x = self.act(x)
142
+ return x
143
+
144
+ def _pairwise_mul_nodes(self, x):
145
+ '''
146
+ Calculates pairwise multiplication of nodes.
147
+ - for attention map
148
+ x :(#bs, #node, #dim)
149
+ out_shape :(#bs, #node, #node, #dim)
150
+ '''
151
+
152
+ nb_nodes = x.size(1)
153
+ x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
154
+ x_mirror = x.transpose(1, 2)
155
+
156
+ return x * x_mirror
157
+
158
+ def _derive_att_map(self, x):
159
+ '''
160
+ x :(#bs, #node, #dim)
161
+ out_shape :(#bs, #node, #node, 1)
162
+ '''
163
+ att_map = self._pairwise_mul_nodes(x)
164
+ # size: (#bs, #node, #node, #dim_out)
165
+ att_map = torch.tanh(self.att_proj(att_map))
166
+ # size: (#bs, #node, #node, 1)
167
+ att_map = torch.matmul(att_map, self.att_weight)
168
+
169
+ # apply temperature
170
+ att_map = att_map / self.temp
171
+
172
+ att_map = F.softmax(att_map, dim=-2)
173
+
174
+ return att_map
175
+
176
+ def _project(self, x, att_map):
177
+ x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
178
+ x2 = self.proj_without_att(x)
179
+
180
+ return x1 + x2
181
+
182
+ def _apply_BN(self, x):
183
+ org_size = x.size()
184
+ x = x.view(-1, org_size[-1])
185
+ x = self.bn(x)
186
+ x = x.view(org_size)
187
+
188
+ return x
189
+
190
+ def _init_new_params(self, *size):
191
+ out = nn.Parameter(torch.FloatTensor(*size))
192
+ nn.init.xavier_normal_(out)
193
+ return out
194
+
195
+
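Both attention layers divide their scores by `self.temp` before the softmax, and `AASISTDetector` passes temperatures of 2.0 and 100.0. The pure-Python sketch below illustrates what that scaling does to the resulting weights. The helper `softmax_with_temperature` is hypothetical and is not part of this repository.

```python
import math


def softmax_with_temperature(scores, temperature=1.0):
    """Softmax over a list of scores after dividing each by `temperature`."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 0.0]
sharp = softmax_with_temperature(scores, temperature=1.0)
flat = softmax_with_temperature(scores, temperature=100.0)
```

A large temperature pushes the weights toward uniform, so every node contributes almost equally; a small one lets the highest-scoring node dominate.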
196
+ class HtrgGraphAttentionLayer(nn.Module):
197
+ def __init__(self, in_dim, out_dim, **kwargs):
198
+ super().__init__()
199
+
200
+ self.proj_type1 = nn.Linear(in_dim, in_dim)
201
+ self.proj_type2 = nn.Linear(in_dim, in_dim)
202
+
203
+ # attention map
204
+ self.att_proj = nn.Linear(in_dim, out_dim)
205
+ self.att_projM = nn.Linear(in_dim, out_dim)
206
+
207
+ self.att_weight11 = self._init_new_params(out_dim, 1)
208
+ self.att_weight22 = self._init_new_params(out_dim, 1)
209
+ self.att_weight12 = self._init_new_params(out_dim, 1)
210
+ self.att_weightM = self._init_new_params(out_dim, 1)
211
+
212
+ # project
213
+ self.proj_with_att = nn.Linear(in_dim, out_dim)
214
+ self.proj_without_att = nn.Linear(in_dim, out_dim)
215
+
216
+ self.proj_with_attM = nn.Linear(in_dim, out_dim)
217
+ self.proj_without_attM = nn.Linear(in_dim, out_dim)
218
+
219
+ # batch norm
220
+ self.bn = nn.BatchNorm1d(out_dim)
221
+
222
+ # dropout for inputs
223
+ self.input_drop = nn.Dropout(p=0.2)
224
+
225
+ # activate
226
+ self.act = nn.SELU(inplace=True)
227
+
228
+ # temperature
229
+ self.temp = 1.
230
+ if "temperature" in kwargs:
231
+ self.temp = kwargs["temperature"]
232
+
233
+ def forward(self, x1, x2, master=None):
234
+ '''
235
+ x1 :(#bs, #node, #dim)
236
+ x2 :(#bs, #node, #dim)
237
+ '''
238
+ num_type1 = x1.size(1)
239
+ num_type2 = x2.size(1)
240
+
241
+ x1 = self.proj_type1(x1)
242
+ x2 = self.proj_type2(x2)
243
+
244
+ x = torch.cat([x1, x2], dim=1)
245
+
246
+ if master is None:
247
+ master = torch.mean(x, dim=1, keepdim=True)
248
+
249
+ # apply input dropout
250
+ x = self.input_drop(x)
251
+
252
+ # derive attention map
253
+ att_map = self._derive_att_map(x, num_type1, num_type2)
254
+
255
+ # directional edge for master node
256
+ master = self._update_master(x, master)
257
+
258
+ # projection
259
+ x = self._project(x, att_map)
260
+
261
+ # apply batch norm
262
+ x = self._apply_BN(x)
263
+ x = self.act(x)
264
+
265
+ x1 = x.narrow(1, 0, num_type1)
266
+ x2 = x.narrow(1, num_type1, num_type2)
267
+
268
+ return x1, x2, master
269
+
270
+ def _update_master(self, x, master):
271
+
272
+ att_map = self._derive_att_map_master(x, master)
273
+ master = self._project_master(x, master, att_map)
274
+
275
+ return master
276
+
277
+ def _pairwise_mul_nodes(self, x):
278
+ '''
279
+ Calculates pairwise multiplication of nodes.
280
+ - for attention map
281
+ x :(#bs, #node, #dim)
282
+ out_shape :(#bs, #node, #node, #dim)
283
+ '''
284
+
285
+ nb_nodes = x.size(1)
286
+ x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
287
+ x_mirror = x.transpose(1, 2)
288
+
289
+ return x * x_mirror
290
+
291
+ def _derive_att_map_master(self, x, master):
292
+ '''
293
+ x :(#bs, #node, #dim)
294
+ out_shape :(#bs, #node, 1)
295
+ '''
296
+ att_map = x * master
297
+ att_map = torch.tanh(self.att_projM(att_map))
298
+
299
+ att_map = torch.matmul(att_map, self.att_weightM)
300
+
301
+ # apply temperature
302
+ att_map = att_map / self.temp
303
+
304
+ att_map = F.softmax(att_map, dim=-2)
305
+
306
+ return att_map
307
+
308
+ def _derive_att_map(self, x, num_type1, num_type2):
309
+ '''
310
+ x :(#bs, #node, #dim)
311
+ out_shape :(#bs, #node, #node, 1)
312
+ '''
313
+ att_map = self._pairwise_mul_nodes(x)
314
+ # size: (#bs, #node, #node, #dim_out)
315
+ att_map = torch.tanh(self.att_proj(att_map))
316
+ # size: (#bs, #node, #node, 1)
317
+
318
+ att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
319
+
320
+ att_board[:, :num_type1, :num_type1, :] = torch.matmul(
321
+ att_map[:, :num_type1, :num_type1, :], self.att_weight11)
322
+ att_board[:, num_type1:, num_type1:, :] = torch.matmul(
323
+ att_map[:, num_type1:, num_type1:, :], self.att_weight22)
324
+ att_board[:, :num_type1, num_type1:, :] = torch.matmul(
325
+ att_map[:, :num_type1, num_type1:, :], self.att_weight12)
326
+ att_board[:, num_type1:, :num_type1, :] = torch.matmul(
327
+ att_map[:, num_type1:, :num_type1, :], self.att_weight12)
328
+
329
+ att_map = att_board
330
+
331
+ # apply temperature
332
+ att_map = att_map / self.temp
333
+
334
+ att_map = F.softmax(att_map, dim=-2)
335
+
336
+ return att_map
337
+
338
+ def _project(self, x, att_map):
339
+ x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
340
+ x2 = self.proj_without_att(x)
341
+
342
+ return x1 + x2
343
+
344
+ def _project_master(self, x, master, att_map):
345
+
346
+ x1 = self.proj_with_attM(torch.matmul(
347
+ att_map.squeeze(-1).unsqueeze(1), x))
348
+ x2 = self.proj_without_attM(master)
349
+
350
+ return x1 + x2
351
+
352
+ def _apply_BN(self, x):
353
+ org_size = x.size()
354
+ x = x.view(-1, org_size[-1])
355
+ x = self.bn(x)
356
+ x = x.view(org_size)
357
+
358
+ return x
359
+
360
+ def _init_new_params(self, *size):
361
+ out = nn.Parameter(torch.FloatTensor(*size))
362
+ nn.init.xavier_normal_(out)
363
+ return out
364
+
365
+
366
+ class GraphPool(nn.Module):
367
+ def __init__(self, k: float, in_dim: int, p: Union[float, int]):
368
+ super().__init__()
369
+ self.k = k
370
+ self.sigmoid = nn.Sigmoid()
371
+ self.proj = nn.Linear(in_dim, 1)
372
+ self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
373
+ self.in_dim = in_dim
374
+
375
+ def forward(self, h):
376
+ Z = self.drop(h)
377
+ weights = self.proj(Z)
378
+ scores = self.sigmoid(weights)
379
+ new_h = self.top_k_graph(scores, h, self.k)
380
+
381
+ return new_h
382
+
383
+ def top_k_graph(self, scores, h, k):
384
+ _, n_nodes, n_feat = h.size()
385
+ n_nodes = max(int(n_nodes * k), 1)
386
+ _, idx = torch.topk(scores, n_nodes, dim=1)
387
+ idx = idx.expand(-1, -1, n_feat)
388
+
389
+ h = h * scores
390
+ h = torch.gather(h, 1, idx)
391
+
392
+ return h
393
+
394
+
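`GraphPool.top_k_graph` keeps the `max(int(n_nodes * k), 1)` highest-scoring nodes and scales each kept node by its score. A minimal single-graph sketch of the same idea in plain Python, assuming no batch dimension and using a hypothetical helper name `top_k_pool`:

```python
def top_k_pool(nodes, scores, k):
    """Keep the top fraction `k` of nodes, each scaled by its score.

    nodes  : list of feature lists, one per node
    scores : list of floats in (0, 1), one per node
    """
    n_keep = max(int(len(nodes) * k), 1)  # always keep at least one node
    # Node indices ordered from highest to lowest score.
    order = sorted(range(len(nodes)), key=lambda i: scores[i], reverse=True)
    keep = order[:n_keep]
    return [[f * scores[i] for f in nodes[i]] for i in keep]


nodes = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
scores = [0.1, 0.9, 0.5, 0.3]
pooled = top_k_pool(nodes, scores, k=0.5)
```

Because the kept features are multiplied by the scores, the scoring projection still receives gradients in the real module even though the top-k selection itself is not differentiable.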
395
+ class CONV(nn.Module):
396
+ @staticmethod
397
+ def to_mel(hz):
398
+ return 2595 * np.log10(1 + hz / 700)
399
+
400
+ @staticmethod
401
+ def to_hz(mel):
402
+ return 700 * (10**(mel / 2595) - 1)
403
+
404
+ def __init__(self,
405
+ out_channels,
406
+ kernel_size,
407
+ sample_rate=16000,
408
+ in_channels=1,
409
+ stride=1,
410
+ padding=0,
411
+ dilation=1,
412
+ bias=False,
413
+ groups=1,
414
+ mask=False):
415
+ super().__init__()
416
+ if in_channels != 1:
417
+ msg = "SincConv only supports one input channel (here, in_channels = {%i})" % (in_channels)
418
+ raise ValueError(msg)
419
+ self.out_channels = out_channels
420
+ self.kernel_size = kernel_size
421
+ self.sample_rate = sample_rate
422
+
423
+ # Force the filter length to be odd (i.e., perfectly symmetric)
424
+ if kernel_size % 2 == 0:
425
+ self.kernel_size = self.kernel_size + 1
426
+ self.stride = stride
427
+ self.padding = padding
428
+ self.dilation = dilation
429
+ self.mask = mask
430
+ if bias:
431
+ raise ValueError('SincConv does not support bias.')
432
+ if groups > 1:
433
+ raise ValueError('SincConv does not support groups.')
434
+
435
+ NFFT = 512
436
+ f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
437
+ fmel = self.to_mel(f)
438
+ fmelmax = np.max(fmel)
439
+ fmelmin = np.min(fmel)
440
+ filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
441
+ filbandwidthsf = self.to_hz(filbandwidthsmel)
442
+
443
+ self.mel = filbandwidthsf
444
+ self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
445
+ (self.kernel_size - 1) / 2 + 1)
446
+ self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
447
+ for i in range(len(self.mel) - 1):
448
+ fmin = self.mel[i]
449
+ fmax = self.mel[i + 1]
450
+ hHigh = (2*fmax/self.sample_rate) * \
451
+ np.sinc(2*fmax*self.hsupp/self.sample_rate)
452
+ hLow = (2*fmin/self.sample_rate) * \
453
+ np.sinc(2*fmin*self.hsupp/self.sample_rate)
454
+ hideal = hHigh - hLow
455
+
456
+ self.band_pass[i, :] = Tensor(np.hamming(
457
+ self.kernel_size)) * Tensor(hideal)
458
+
459
+ def forward(self, x, mask=False):
460
+ band_pass_filter = self.band_pass.clone().to(x.device)
461
+ if mask:
462
+ A = np.random.uniform(0, 20)
463
+ A = int(A)
464
+ A0 = random.randint(0, band_pass_filter.shape[0] - A)
465
+ band_pass_filter[A0:A0 + A, :] = 0
466
+ else:
467
+ band_pass_filter = band_pass_filter
468
+
469
+ self.filters = (band_pass_filter).view(self.out_channels, 1,
470
+ self.kernel_size)
471
+
472
+ return F.conv1d(x,
473
+ self.filters,
474
+ stride=self.stride,
475
+ padding=self.padding,
476
+ dilation=self.dilation,
477
+ bias=None,
478
+ groups=1)
479
+
480
+
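The `CONV` constructor spaces its band edges evenly on the mel scale between 0 Hz and the Nyquist frequency, then converts them back to Hz. The sketch below reproduces only that step with the standard library. `mel_band_edges` is a hypothetical helper, and the default of 70 filters is taken from `filts[0]` in `AASISTDetector`.

```python
import math


def to_mel(hz):
    return 2595 * math.log10(1 + hz / 700)


def to_hz(mel):
    return 700 * (10 ** (mel / 2595) - 1)


def mel_band_edges(n_filters, sample_rate=16000):
    """Return n_filters + 1 band edges in Hz, evenly spaced on the mel scale."""
    mel_max = to_mel(sample_rate / 2)  # the mel value at 0 Hz is 0
    step = mel_max / n_filters
    return [to_hz(i * step) for i in range(n_filters + 1)]


edges = mel_band_edges(70)
```

Each consecutive pair of edges gives the `fmin` and `fmax` of one band-pass filter, so low-frequency bands come out narrower than high-frequency ones.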
481
+ class Residual_block(nn.Module):
482
+ def __init__(self, nb_filts, first=False):
483
+ super().__init__()
484
+ self.first = first
485
+
486
+ if not self.first:
487
+ self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
488
+ self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
489
+ out_channels=nb_filts[1],
490
+ kernel_size=(2, 3),
491
+ padding=(1, 1),
492
+ stride=1)
493
+ self.selu = nn.SELU(inplace=True)
494
+
495
+ self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
496
+ self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
497
+ out_channels=nb_filts[1],
498
+ kernel_size=(2, 3),
499
+ padding=(0, 1),
500
+ stride=1)
501
+
502
+ if nb_filts[0] != nb_filts[1]:
503
+ self.downsample = True
504
+ self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
505
+ out_channels=nb_filts[1],
506
+ padding=(0, 1),
507
+ kernel_size=(1, 3),
508
+ stride=1)
509
+
510
+ else:
511
+ self.downsample = False
512
+ self.mp = nn.MaxPool2d((1, 3))
513
+
514
+ def forward(self, x):
515
+ identity = x
516
+ if not self.first:
517
+ out = self.bn1(x)
518
+ out = self.selu(out)
519
+ else:
520
+ out = x
521
+ out = self.conv1(x)
522
+
523
+ out = self.bn2(out)
524
+ out = self.selu(out)
525
+ out = self.conv2(out)
526
+ if self.downsample:
527
+ identity = self.conv_downsample(identity)
528
+
529
+ out += identity
530
+ out = self.mp(out)
531
+ return out
532
+
533
+
534
+ class AASISTModel(nn.Module):
535
+ def __init__(self, d_args):
536
+ super().__init__()
537
+
538
+ self.d_args = d_args
539
+ filts = d_args["filts"]
540
+ gat_dims = d_args["gat_dims"]
541
+ pool_ratios = d_args["pool_ratios"]
542
+ temperatures = d_args["temperatures"]
543
+
544
+ self.conv_time = CONV(out_channels=filts[0],
545
+ kernel_size=d_args["first_conv"],
546
+ in_channels=1)
547
+ self.first_bn = nn.BatchNorm2d(num_features=1)
548
+
549
+ self.drop = nn.Dropout(0.5, inplace=True)
550
+ self.drop_way = nn.Dropout(0.2, inplace=True)
551
+ self.selu = nn.SELU(inplace=True)
552
+
553
+ self.encoder = nn.Sequential(
554
+ nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
555
+ nn.Sequential(Residual_block(nb_filts=filts[2])),
556
+ nn.Sequential(Residual_block(nb_filts=filts[3])),
557
+ nn.Sequential(Residual_block(nb_filts=filts[4])),
558
+ nn.Sequential(Residual_block(nb_filts=filts[4])),
559
+ nn.Sequential(Residual_block(nb_filts=filts[4])))
560
+
561
+ self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
562
+ self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
563
+ self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
564
+
565
+ self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
566
+ gat_dims[0],
567
+ temperature=temperatures[0])
568
+ self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
569
+ gat_dims[0],
570
+ temperature=temperatures[1])
571
+
572
+ self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
573
+ gat_dims[0], gat_dims[1], temperature=temperatures[2])
574
+ self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
575
+ gat_dims[1], gat_dims[1], temperature=temperatures[2])
576
+
577
+ self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
578
+ gat_dims[0], gat_dims[1], temperature=temperatures[2])
579
+
580
+ self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
581
+ gat_dims[1], gat_dims[1], temperature=temperatures[2])
582
+
583
+ self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
584
+ self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
585
+ self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
586
+ self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
587
+
588
+ self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
589
+ self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
590
+
591
+ self.out_layer = nn.Linear(5 * gat_dims[1], 2)
592
+
593
+ def forward(self, x, Freq_aug=False):
594
+
595
+ x = x.unsqueeze(1)
596
+ x = self.conv_time(x, mask=Freq_aug)
597
+ x = x.unsqueeze(dim=1)
598
+ x = F.max_pool2d(torch.abs(x), (3, 3))
599
+ x = self.first_bn(x)
600
+ x = self.selu(x)
601
+
602
+ e = self.encoder(x)
603
+
604
+ e_S, _ = torch.max(torch.abs(e), dim=3)
605
+ e_S = e_S.transpose(1, 2) + self.pos_S
606
+
607
+ gat_S = self.GAT_layer_S(e_S)
608
+ out_S = self.pool_S(gat_S)
609
+
610
+ e_T, _ = torch.max(torch.abs(e), dim=2)
611
+ e_T = e_T.transpose(1, 2)
612
+
613
+ gat_T = self.GAT_layer_T(e_T)
614
+ out_T = self.pool_T(gat_T)
615
+
616
+ master1 = self.master1.expand(x.size(0), -1, -1)
617
+ master2 = self.master2.expand(x.size(0), -1, -1)
618
+
619
+ out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
620
+ out_T, out_S, master=self.master1)
621
+
622
+ out_S1 = self.pool_hS1(out_S1)
623
+ out_T1 = self.pool_hT1(out_T1)
624
+
625
+ out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
626
+ out_T1, out_S1, master=master1)
627
+ out_T1 = out_T1 + out_T_aug
628
+ out_S1 = out_S1 + out_S_aug
629
+ master1 = master1 + master_aug
630
+
631
+ out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
632
+ out_T, out_S, master=self.master2)
633
+ out_S2 = self.pool_hS2(out_S2)
634
+ out_T2 = self.pool_hT2(out_T2)
635
+
636
+ out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
637
+ out_T2, out_S2, master=master2)
638
+ out_T2 = out_T2 + out_T_aug
639
+ out_S2 = out_S2 + out_S_aug
640
+ master2 = master2 + master_aug
641
+
642
+ out_T1 = self.drop_way(out_T1)
643
+ out_T2 = self.drop_way(out_T2)
644
+ out_S1 = self.drop_way(out_S1)
645
+ out_S2 = self.drop_way(out_S2)
646
+ master1 = self.drop_way(master1)
647
+ master2 = self.drop_way(master2)
648
+
649
+ out_T = torch.max(out_T1, out_T2)
650
+ out_S = torch.max(out_S1, out_S2)
651
+ master = torch.max(master1, master2)
652
+
653
+ T_max, _ = torch.max(torch.abs(out_T), dim=1)
654
+ T_avg = torch.mean(out_T, dim=1)
655
+
656
+ S_max, _ = torch.max(torch.abs(out_S), dim=1)
657
+ S_avg = torch.mean(out_S, dim=1)
658
+
659
+ last_hidden = torch.cat(
660
+ [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
661
+
662
+ last_hidden = self.drop(last_hidden)
663
+ output = self.out_layer(last_hidden)
664
+
665
+ return last_hidden, output
666
+
667
+ class AASISTDetector(nn.Module):
668
+ def __init__(self, num_classes=2):
669
+ super().__init__()
670
+ d_args = {
671
+ "nb_samp": 64600,
672
+ "first_conv": 128,
673
+ "in_channels": 1,
674
+ "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
675
+ "gat_dims": [64, 32],
676
+ "pool_ratios": [0.5, 0.7, 0.5, 0.5],
677
+ "temperatures": [2.0, 2.0, 100.0]
678
+ }
679
+ self.model = AASISTModel(d_args)
680
+
681
+ # Override out_layer if not strictly 2 classes.
682
+ if num_classes != 2:
683
+ self.model.out_layer = nn.Linear(5 * d_args["gat_dims"][1], num_classes)
684
+
685
+ def forward(self, x):
686
+ # x is (B, 1, T) or (B, T)
687
+ if x.dim() == 3:
688
+ x = x.squeeze(1) # Convert to (B, T)
689
+ _, out = self.model(x)
690
+ return out
691
+
692
+ # ============================================================
693
+ # 3. CQCC Baseline Detector (Acoustic Feature Baseline)
694
+ # ============================================================
695
+
696
+ class CQCCBaselineDetector(nn.Module):
697
+ def __init__(self, num_classes=2):
698
+ super().__init__()
699
+ # Input shape expected: (B, 1, 20, T)
700
+ self.features = nn.Sequential(
701
+ nn.Conv2d(1, 16, 3, padding=1),
702
+ nn.BatchNorm2d(16),
703
+ nn.ReLU(),
704
+ nn.MaxPool2d(2),
705
+ nn.Conv2d(16, 32, 3, padding=1),
706
+ nn.BatchNorm2d(32),
707
+ nn.ReLU(),
708
+ nn.MaxPool2d(2),
709
+ nn.Conv2d(32, 64, 3, padding=1),
710
+ nn.BatchNorm2d(64),
711
+ nn.ReLU(),
712
+ nn.AdaptiveAvgPool2d(1)
713
+ )
714
+ self.classifier = nn.Sequential(
715
+ nn.Dropout(0.3),
716
+ nn.Linear(64, num_classes)
717
+ )
718
+
719
+ def forward(self, x):
720
+ x = self.features(x)
721
+ x = x.flatten(1)
722
+ return self.classifier(x)
723
+
724
+ # ============================================================
725
+ # 4. Custom Fusion Model: Wav2Vec2 + CQCC with Cross-Attention + Graph
726
+ # ============================================================
727
+
728
+ class PositionalEncoding(nn.Module):
729
+ def __init__(self, dim, max_len=6000):
730
+ super().__init__()
731
+ self.pos_embed = nn.Parameter(torch.randn(1, max_len, dim))
732
+
733
+ def forward(self, x):
734
+ return x + self.pos_embed[:, :x.size(1)]
735
+
736
+ class BidirectionalCrossAttention(nn.Module):
737
+ def __init__(self, dim, num_heads=4):
738
+ super().__init__()
739
+ self.attn1 = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=0.2)
740
+ self.attn2 = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=0.2)
741
+ self.norm_q = nn.LayerNorm(dim)
742
+ self.norm_kv = nn.LayerNorm(dim)
743
+
744
+ def forward(self, x1, x2):
745
+ # x1 attends to x2
746
+ q1 = self.norm_q(x1)
747
+ k2 = self.norm_kv(x2)
748
+ v2 = k2
749
+ out1, _ = self.attn1(q1, k2, v2)
750
+
751
+ # x2 attends to x1
752
+ q2 = self.norm_q(x2)
753
+ k1 = self.norm_kv(x1)
754
+ v1 = k1
755
+ out2, _ = self.attn2(q2, k1, v1)
756
+ return out1, out2
757
+
758
+ def align_sequences(x, target_len):
759
+ """Linear interpolation to match sequence lengths"""
760
+ x = x.transpose(1, 2)
761
+ x = F.interpolate(x, size=target_len, mode='linear', align_corners=False)
762
+ return x.transpose(1, 2)
763
+
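`align_sequences` resamples the CQCC feature sequence along time so that it has as many frames as the Wav2Vec2 output. The sketch below shows linear resampling of a single 1-D sequence in plain Python, assuming the half-pixel sampling convention that `align_corners=False` is understood to imply. `resample_linear` is a hypothetical helper and is not guaranteed to match PyTorch bit for bit.

```python
def resample_linear(values, target_len):
    """Linearly resample a 1-D list to `target_len` points (half-pixel centres)."""
    n = len(values)
    scale = n / target_len
    out = []
    for i in range(target_len):
        # Map the centre of output sample i back into input coordinates.
        src = max((i + 0.5) * scale - 0.5, 0.0)
        lo = min(int(src), n - 1)
        hi = min(lo + 1, n - 1)
        frac = src - lo
        out.append(values[lo] * (1 - frac) + values[hi] * frac)
    return out


stretched = resample_linear([0.0, 10.0], 4)
```

In the model the same operation is applied independently to every feature channel, which is why the tensor is transposed to `(B, D, T)` before interpolation and transposed back afterwards.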
764
+ class ImprovedWav2Vec2CQCCDetector(nn.Module):
765
+ def __init__(self, num_classes=2):
766
+ super().__init__()
767
+
768
+ # Wav2Vec2
769
+ self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
770
+
771
+ # Freeze the Wav2Vec2 layer so it acts purely as a feature extractor
772
+ for param in self.wav2vec.parameters():
773
+ param.requires_grad = False
774
+
775
+ dim = self.wav2vec.config.hidden_size
776
+
777
+ # CQCC encoder
778
+ self.cqcc_conv = nn.Sequential(
779
+ nn.Conv1d(20, 128, kernel_size=3, padding=1),
780
+ nn.BatchNorm1d(128),
781
+ nn.GELU(),
782
+ nn.Dropout(0.2),
783
+ nn.Conv1d(128, dim, kernel_size=3, padding=1),
784
+ nn.BatchNorm1d(dim),
785
+ nn.GELU()
786
+ )
787
+
788
+ # Positional Encoding
789
+ self.pos_enc = PositionalEncoding(dim)
790
+
791
+ # Bidirectional Cross Attention
792
+ self.cross_attn = BidirectionalCrossAttention(dim)
793
+
794
+ # True Graph Transformer Backend (using GAT blocks from AASIST)
795
+ self.graph_layers = nn.ModuleList([
796
+ GraphBlock(dim) for _ in range(3)
797
+ ])
798
+
799
+ # Classifier
800
+ self.classifier = nn.Sequential(
801
+ nn.Linear(dim, 128),
802
+ nn.GELU(),
803
+ nn.Dropout(0.2),
804
+ nn.Linear(128, num_classes)
805
+ )
806
+
807
+ def forward(self, wav, cqcc):
808
+ if wav.dim() == 3:
809
+ wav = wav.squeeze(1)
810
+
811
+ # Wav2Vec2 features
812
+ w2v = self.wav2vec(wav).last_hidden_state # (B, T_w, D)
813
+
814
+ # CQCC features
815
+ if cqcc.dim() == 4:
816
+ cqcc = cqcc.squeeze(1)
817
+ cqcc_feat = self.cqcc_conv(cqcc).transpose(1, 2) # (B, T_c, D)
818
+
819
+ # Align lengths
820
+ cqcc_feat = align_sequences(cqcc_feat, w2v.size(1))
821
+
822
+ # Add positional encoding
823
+ w2v = self.pos_enc(w2v)
824
+ cqcc_feat = self.pos_enc(cqcc_feat)
825
+
826
+ # Cross attention (bidirectional)
827
+ f1, f2 = self.cross_attn(cqcc_feat, w2v)
828
+ fused = f1 + f2
829
+
830
+ # Graph Transformer processing on node sequences
831
+ x = fused
832
+ for layer in self.graph_layers:
833
+ x = layer(x)
834
+
835
+ # Global average pooling on the nodes
836
+ pooled = x.mean(dim=1)
837
+
838
+ return self.classifier(pooled)
839
+
840
+ # ============================================================
841
+ # 5. Ablation Models
842
+ # ============================================================
843
+
844
+ class AblationWav2Vec2GraphDetector(nn.Module):
845
+ """Ablation 1: Wav2Vec2 only + Graph Backend (No CQCC, No Cross-Attention)"""
846
+ def __init__(self, num_classes=2):
847
+ super().__init__()
848
+ self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
849
+ for param in self.wav2vec.parameters():
850
+ param.requires_grad = False
851
+
852
+ dim = self.wav2vec.config.hidden_size
853
+ self.pos_enc = PositionalEncoding(dim)
854
+
855
+ self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
856
+ self.classifier = nn.Sequential(
857
+ nn.Linear(dim, 128), nn.GELU(), nn.Dropout(0.2), nn.Linear(128, num_classes)
858
+ )
859
+
860
+ def forward(self, wav, cqcc=None): # Accept both but ignore CQCC
861
+ if wav.dim() == 3:
862
+ wav = wav.squeeze(1)
863
+
864
+ w2v = self.wav2vec(wav).last_hidden_state
865
+ w2v = self.pos_enc(w2v)
866
+
867
+ x = w2v
868
+ for layer in self.graph_layers:
869
+ x = layer(x)
870
+
871
+ pooled = x.mean(dim=1)
872
+ return self.classifier(pooled)
873
+
874
+
875
+ class AblationCQCCGraphDetector(nn.Module):
876
+ """Ablation 2: CQCC only + Graph Backend (No Wav2Vec2, No Cross-Attention)"""
877
+ def __init__(self, num_classes=2):
878
+ super().__init__()
879
+ dim = 768 # Match Wav2Vec2 hidden size for fair comparison
880
+
881
+ self.cqcc_conv = nn.Sequential(
882
+ nn.Conv1d(20, 128, kernel_size=3, padding=1),
883
+ nn.BatchNorm1d(128),
884
+ nn.GELU(),
885
+ nn.Dropout(0.2),
886
+ nn.Conv1d(128, dim, kernel_size=3, padding=1),
887
+ nn.BatchNorm1d(dim),
888
+ nn.GELU()
889
+ )
890
+ self.pos_enc = PositionalEncoding(dim)
891
+
892
+ self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
893
+ self.classifier = nn.Sequential(
894
+ nn.Linear(dim, 128), nn.GELU(), nn.Dropout(0.2), nn.Linear(128, num_classes)
895
+ )
896
+
897
+ def forward(self, cqcc):
898
+ if cqcc.dim() == 4:
899
+ cqcc = cqcc.squeeze(1)
900
+
901
+ cqcc_feat = self.cqcc_conv(cqcc).transpose(1, 2)
902
+ cqcc_feat = self.pos_enc(cqcc_feat)
903
+
904
+ x = cqcc_feat
905
+ for layer in self.graph_layers:
906
+ x = layer(x)
907
+
908
+ pooled = x.mean(dim=1)
909
+ return self.classifier(pooled)
910
+
911
+
912
+ class AblationConcatGraphDetector(nn.Module):
913
+ """Ablation 3: Wav2Vec2 + CQCC + Simple Concat Fusion + Graph Backend (No Cross-Attention)"""
914
+ def __init__(self, num_classes=2):
915
+ super().__init__()
916
+ self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
917
+ for param in self.wav2vec.parameters():
918
+ param.requires_grad = False
919
+
920
+ dim = self.wav2vec.config.hidden_size
921
+
922
+ self.cqcc_conv = nn.Sequential(
923
+ nn.Conv1d(20, 128, kernel_size=3, padding=1),
924
+ nn.BatchNorm1d(128),
925
+ nn.GELU(),
926
+ nn.Dropout(0.2),
927
+ nn.Conv1d(128, dim, kernel_size=3, padding=1),
928
+ nn.BatchNorm1d(dim),
929
+ nn.GELU()
930
+ )
931
+
932
+ self.fusion_proj = nn.Linear(dim * 2, dim) # Project concatenated features back to dim
933
+ self.pos_enc = PositionalEncoding(dim)
934
+
935
+ self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
936
+ self.classifier = nn.Sequential(
937
+ nn.Linear(dim, 128), nn.GELU(), nn.Dropout(0.2), nn.Linear(128, num_classes)
938
+ )
939
+
940
+ def forward(self, wav, cqcc):
941
+ if wav.dim() == 3:
942
+ wav = wav.squeeze(1)
943
+ w2v = self.wav2vec(wav).last_hidden_state
944
+
945
+ if cqcc.dim() == 4:
946
+ cqcc = cqcc.squeeze(1)
947
+ cqcc_feat = self.cqcc_conv(cqcc).transpose(1, 2)
948
+
949
+ cqcc_feat = align_sequences(cqcc_feat, w2v.size(1))
950
+
951
+ # Simple concat over feature dimension instead of cross-attention
952
+ fused = torch.cat([w2v, cqcc_feat], dim=-1)
953
+ fused = self.fusion_proj(fused)
954
+
955
+ fused = self.pos_enc(fused)
956
+
957
+ x = fused
958
+ for layer in self.graph_layers:
959
+ x = layer(x)
960
+
961
+ pooled = x.mean(dim=1)
962
+ return self.classifier(pooled)
963
+
964
+
965
+ class AblationCrossAttnLinearDetector(nn.Module):
966
+ """Ablation 4: Wav2Vec2 + CQCC + Cross-Attention + Linear Backend (No Graph Transformer)"""
967
+ def __init__(self, num_classes=2):
968
+ super().__init__()
969
+ self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
970
+ for param in self.wav2vec.parameters():
971
+ param.requires_grad = False
972
+
973
+ dim = self.wav2vec.config.hidden_size
974
+
975
+ self.cqcc_conv = nn.Sequential(
976
+ nn.Conv1d(20, 128, kernel_size=3, padding=1),
977
+ nn.BatchNorm1d(128),
978
+ nn.GELU(),
979
+ nn.Dropout(0.2),
980
+ nn.Conv1d(128, dim, kernel_size=3, padding=1),
981
+ nn.BatchNorm1d(dim),
982
+ nn.GELU()
983
+ )
984
+
985
+ self.pos_enc = PositionalEncoding(dim)
986
+ self.cross_attn = BidirectionalCrossAttention(dim)
987
+
988
+ # Richer MLP classifier since graph is missing
989
+ self.classifier = nn.Sequential(
990
+ nn.Linear(dim, 256),
991
+ nn.GELU(),
992
+ nn.Dropout(0.3),
993
+ nn.Linear(256, 128),
994
+ nn.GELU(),
995
+ nn.Dropout(0.2),
996
+ nn.Linear(128, num_classes)
997
+ )
998
+
999
+ def forward(self, wav, cqcc):
1000
+ if wav.dim() == 3:
1001
+ wav = wav.squeeze(1)
1002
+ w2v = self.wav2vec(wav).last_hidden_state
1003
+
1004
+ if cqcc.dim() == 4:
1005
+ cqcc = cqcc.squeeze(1)
1006
+ cqcc_feat = self.cqcc_conv(cqcc).transpose(1, 2)
1007
+
1008
+ cqcc_feat = align_sequences(cqcc_feat, w2v.size(1))
1009
+
1010
+ w2v = self.pos_enc(w2v)
1011
+ cqcc_feat = self.pos_enc(cqcc_feat)
1012
+
1013
+ f1, f2 = self.cross_attn(cqcc_feat, w2v)
1014
+ fused = f1 + f2
1015
+
1016
+ # No graph layer, straight to global average pooling
1017
+ pooled = fused.mean(dim=1)
1018
+ return self.classifier(pooled)
1019
+
backend/preprocess.py ADDED
@@ -0,0 +1,58 @@
+ import subprocess
+ import sys
+ import os
+ import argparse
+ from dataset import AudioDataset
+
+
+ def run_command(cmd):
+     try:
+         subprocess.run(cmd, check=True, text=True, capture_output=True)
+     except subprocess.CalledProcessError as e:
+         # Output is captured, so surface stderr before exiting instead of failing silently.
+         sys.exit(e.stderr or 1)
+
+
+ def download_dataset():
+     run_command(["git", "lfs", "install"])
+     dataset_dir = "MLAAD-tiny"
+     if not os.path.exists(dataset_dir):
+         print("=== Cloning MLAAD-tiny dataset ===")
+         run_command(["git", "clone", "https://huggingface.co/datasets/mueller91/MLAAD-tiny"])
+     else:
+         print(f"Dataset directory '{dataset_dir}' already exists. Skipping clone.")
+
+
+ def precompute_cqcc(data_dir, cqcc_cache_dir, force=False):
+     for lang in ["en", "de"]:
+         print(f"\n--- Precomputing CQCC for language: {lang} ---")
+         dataset = AudioDataset(
+             data_dir=data_dir,
+             augment=False,
+             cqcc_cache_dir=cqcc_cache_dir,
+             target_lang=lang
+         )
+         dataset.precompute_cqcc_cache(force=force)
+     print("\nFinished all CQCC preprocessing.")
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Download dataset and precompute CQCC features.")
+     parser.add_argument("--data-dir", default="MLAAD-tiny")
+     parser.add_argument(
+         "--cqcc-cache-dir",
+         default=os.path.join(os.path.dirname(__file__), "precomputed_features", "cqcc")
+     )
+     parser.add_argument("--force", action="store_true")
+     parser.add_argument("--skip-download", action="store_true")
+     parser.add_argument("--skip-cqcc", action="store_true")
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+
+     if not args.skip_download:
+         download_dataset()
+
+     if not args.skip_cqcc:
+         precompute_cqcc(args.data_dir, args.cqcc_cache_dir, args.force)
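The script's behaviour is controlled entirely by its command-line flags. The sketch below rebuilds an equivalent parser and parses a sample argument list to show how the flags map onto attributes. `build_parser` is a hypothetical stand-in for `parse_args`, the cache-dir default is simplified to a relative path, and `/tmp/cqcc` is only an example value.

```python
import argparse
import os


def build_parser():
    # Hypothetical stand-in for parse_args(); the flag names match the script above.
    parser = argparse.ArgumentParser(description="Download dataset and precompute CQCC features.")
    parser.add_argument("--data-dir", default="MLAAD-tiny")
    # Assumption: a relative default instead of one resolved from __file__.
    parser.add_argument("--cqcc-cache-dir", default=os.path.join("precomputed_features", "cqcc"))
    parser.add_argument("--force", action="store_true")
    parser.add_argument("--skip-download", action="store_true")
    parser.add_argument("--skip-cqcc", action="store_true")
    return parser


# Equivalent to: python preprocess.py --skip-download --cqcc-cache-dir /tmp/cqcc
args = build_parser().parse_args(["--skip-download", "--cqcc-cache-dir", "/tmp/cqcc"])
```

With these arguments the clone step would be skipped and CQCC features would be recomputed only for files missing from the cache, since `--force` is not set.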
backend/train.py ADDED
@@ -0,0 +1,514 @@
1
+ import os
2
+ import argparse
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.data import DataLoader
6
+ from dataset import AudioDataset, collate_variable_length
7
+ from models import (
8
+ AASISTDetector,
9
+ Wav2Vec2SpoofDetector,
10
+ CQCCBaselineDetector,
11
+ ImprovedWav2Vec2CQCCDetector,
12
+ AblationWav2Vec2GraphDetector,
13
+ AblationCQCCGraphDetector,
14
+ AblationConcatGraphDetector,
15
+ AblationCrossAttnLinearDetector
16
+ )
17
+ from sklearn.metrics import roc_curve, auc
18
+ import numpy as np
19
+ import random
20
+ from tqdm import tqdm
21
+
22
+
23
+ def train_model(model, train_dataloader, criterion, optimizer, epochs=5, input_type='wav', device=None, val_dataloader=None, eval_interval=1, patience=2, model_save_path=None):
24
+
25
+ if device is None:
26
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
+
28
+ model.to(device)
29
+
30
+ loss_history = []
31
+ best_val_metric = float('inf') # For min_dcf, lower is better
32
+ patience_counter = 0
33
+ best_epoch = 0
34
+
35
+ for epoch in range(epochs):
36
+ model.train()
37
+ epoch_loss = 0
38
+ correct = 0
39
+ total = 0
40
+ # Wrap the dataloader with tqdm for a progress bar
41
+ for batch_idx, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs} - Training")):
42
+
43
+ wavs, cqccs, labels = batch
44
+ wavs = wavs.to(device)
45
+ cqccs = cqccs.to(device)
46
+ labels = labels.to(device)
47
+
48
+ optimizer.zero_grad()
49
+
50
+ if input_type == 'wav':
51
+ outputs = model(wavs)
52
+ elif input_type == 'cqcc':
53
+ outputs = model(cqccs)
54
+ elif input_type == 'wav_and_cqcc':
55
+ outputs = model(wavs, cqccs)
56
+ else:
57
+ raise ValueError("invalid input_type")
58
+
59
+ loss = criterion(outputs, labels)
60
+
61
+ loss.backward()
62
+ optimizer.step()
63
+
64
+ epoch_loss += loss.item()
65
+
66
+ _, predicted = torch.max(outputs.data, 1)
67
+
68
+ total += labels.size(0)
69
+ correct += (predicted == labels).sum().item()
70
+
71
+ # Print intermediate progress within the epoch
72
+ if batch_idx % 500 == 0 and batch_idx > 0: # Report every 500 batches
73
+ current_acc = 100 * correct / total
74
+ current_loss = epoch_loss / (batch_idx + 1)
75
+ print(f" Batch {batch_idx}/{len(train_dataloader)} | Loss: {current_loss:.4f} | Acc: {current_acc:.2f}%")
76
+
77
+ acc = 100 * correct / total if total > 0 else 0
78
+ avg_loss = epoch_loss / len(train_dataloader)
79
+ loss_history.append(avg_loss)
80
+ print(f"Epoch {epoch+1}/{epochs} | Training Loss: {avg_loss:.4f} | Training Acc: {acc:.2f}%")
81
+
82
+ # Validation and Early Stopping
83
+ if val_dataloader is not None and (epoch + 1) % eval_interval == 0:
84
+ print(f"Epoch {epoch+1}/{epochs} - Evaluating on Validation Set...")
85
+ _, _, _, val_eer, val_min_dcf, val_accuracy = evaluate_model(
86
+ model, val_dataloader, input_type=input_type, device=device
87
+ )
88
+ print(f" Validation | EER={val_eer*100:.2f}% | minDCF={val_min_dcf:.4f} | Accuracy={val_accuracy:.2f}")
89
+
90
+ if val_min_dcf < best_val_metric:
91
+ best_val_metric = val_min_dcf
92
+ patience_counter = 0
93
+ best_epoch = epoch + 1
94
+ if model_save_path:
95
+ torch.save(model.state_dict(), model_save_path)
96
+ print(f" Saved best model to {model_save_path} (minDCF: {best_val_metric:.4f})")
97
+ else:
98
+ patience_counter += 1
99
+ print(f" Validation minDCF did not improve. Patience: {patience_counter}/{patience}")
100
+
101
+ if patience_counter >= patience:
102
+ print(f"Early stopping triggered after {epoch+1} epochs. Best minDCF: {best_val_metric:.4f} at epoch {best_epoch}")
103
+ if model_save_path:
104
+ print(f"Loading best model from {model_save_path}")
105
+ model.load_state_dict(torch.load(model_save_path, map_location=device))
106
+ return loss_history # Stop training
107
+
108
+ # No validation set means no best-checkpoint tracking, so save the final weights once the loop ends naturally
109
+ if val_dataloader is None and model_save_path is not None:
110
+ torch.save(model.state_dict(), model_save_path)
111
+ print(f" Saved final model to {model_save_path}")
112
+
113
+ return loss_history
114
+
115
+
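The patience logic in `train_model` above can be isolated into a small helper. This is a minimal sketch, not part of the commit: the `EarlyStopper` class and its `step` method are hypothetical names, and it assumes a lower-is-better metric such as minDCF.

```python
class EarlyStopper:
    """Hypothetical helper mirroring the patience logic in train_model."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")   # best (lowest) metric seen so far
        self.best_epoch = 0
        self.counter = 0           # evaluations since the last improvement

    def step(self, metric, epoch):
        """Record one validation result and return (improved, should_stop)."""
        if metric < self.best:
            self.best = metric
            self.best_epoch = epoch
            self.counter = 0
            return True, False
        self.counter += 1
        return False, self.counter >= self.patience


if __name__ == "__main__":
    stopper = EarlyStopper(patience=2)
    for epoch, min_dcf in enumerate([0.40, 0.35, 0.36, 0.37, 0.30], start=1):
        improved, stop = stopper.step(min_dcf, epoch)
        if improved:
            print(f"epoch {epoch}: new best minDCF {min_dcf:.2f} (checkpoint would be saved here)")
        if stop:
            print(f"stopping at epoch {epoch}; best was epoch {stopper.best_epoch}")
            break
```

Returning both flags keeps checkpoint saving and loop exit in the caller, which is where `train_model` handles them.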
116
+ def evaluate_model(model, dataloader, input_type='wav', device=None):
117
+
118
+ if device is None:
119
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
120
+
121
+ model.eval()
122
+
123
+ all_labels = []
124
+ all_probs = []
125
+
126
+ with torch.no_grad():
127
+ for batch in tqdm(dataloader, desc="Evaluating"):
128
+
129
+ wavs, cqccs, labels = batch
130
+ wavs = wavs.to(device)
131
+ cqccs = cqccs.to(device)
132
+ labels = labels.to(device)
133
+
134
+ if input_type == 'wav':
135
+ outputs = model(wavs)
136
+ elif input_type == 'cqcc':
137
+ outputs = model(cqccs)
138
+ elif input_type == 'wav_and_cqcc':
139
+ outputs = model(wavs, cqccs)
140
+ else:
141
+ raise ValueError("invalid input_type")
142
+
143
+ probs = torch.softmax(outputs, dim=1)[:, 1]
144
+
145
+ all_labels.extend(labels.tolist())
146
+ all_probs.extend(probs.tolist())
147
+
148
+ fpr, tpr, thresholds = roc_curve(all_labels, all_probs)
149
+ roc_auc = auc(fpr, tpr)
150
+
151
+ # ------------------
152
+ # EER (Equal Error Rate)
153
+ # ------------------
154
+ fnr = 1 - tpr
155
+ eer_index = np.nanargmin(np.absolute(fnr - fpr))
156
+ eer = fpr[eer_index]
157
+
158
+ # ------------------
159
+ # minDCF (Minimum Detection Cost Function)
160
+ # Parameters according to ASVspoof 5 Evaluation Plan (Track 1)
161
+ # ------------------
162
+ P_spoof = 0.05 # Prior probability of a spoofing attack (\pi_{spf})
163
+ P_bonafide = 0.95 # Prior probability of a real/bonafide utterance (1 - \pi_{spf})
164
+ C_miss = 1 # Cost of falsely rejecting a real voice (Miss)
165
+ C_fa = 10 # Cost of falsely accepting a spoof (False Alarm)
166
+
167
+ # In the dataset, 0 = real (bonafide), 1 = fake (spoof)
168
+ # fpr (False Positive Rate) = predicted fake (1) when true is real (0). This is a "miss" in ASVspoof.
169
+ # fnr (False Negative Rate) = predicted real (0) when true is fake (1). This is a "false alarm" in ASVspoof.
170
+ P_miss = fpr
171
+ P_fa = fnr
172
+
173
+ # Raw DCF = C_miss * P_bonafide * P_miss + C_fa * P_spoof * P_fa
174
+ # Normalized by the default DCF (min cost of predicting all bonafide vs all spoof)
175
+ dcf_default = min(C_miss * P_bonafide, C_fa * P_spoof)
176
+ dcf_array = (C_miss * P_bonafide * P_miss + C_fa * P_spoof * P_fa) / dcf_default
177
+ min_dcf = np.min(dcf_array)
178
+
179
+ # Overall Accuracy (using 0.5 threshold)
180
+ preds = [1 if p > 0.5 else 0 for p in all_probs]
181
+ correct = sum(1 for p, l in zip(preds, all_labels) if p == l)
182
+ accuracy = correct / len(all_labels) if len(all_labels) > 0 else 0
183
+
184
+ return fpr, tpr, roc_auc, eer, min_dcf, accuracy
185
+
186
+
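The EER and minDCF computation in `evaluate_model` can be checked by hand on toy scores. The sketch below is an illustration rather than the committed code: it replaces `roc_curve` with an explicit threshold sweep in plain Python, the function names are hypothetical, and it averages the two error rates at the crossing point where `evaluate_model` reports `fpr` alone. It assumes both classes are present in `labels`.

```python
def sweep_error_rates(labels, scores):
    """Return (P_miss, P_fa) for every decision threshold.

    Labels follow the convention above: 0 = bonafide, 1 = spoof.
    A sample is called spoof when its score is >= the threshold.
    P_miss = bonafide rejected as spoof, P_fa = spoof accepted as bonafide.
    """
    n_bona = sum(1 for l in labels if l == 0)
    n_spoof = sum(1 for l in labels if l == 1)
    thresholds = sorted(set(scores)) + [float("inf")]
    rates = []
    for t in thresholds:
        miss = sum(1 for l, s in zip(labels, scores) if l == 0 and s >= t)
        fa = sum(1 for l, s in zip(labels, scores) if l == 1 and s < t)
        rates.append((miss / n_bona, fa / n_spoof))
    return rates


def eer_and_min_dcf(labels, scores, p_spoof=0.05, c_miss=1, c_fa=10):
    """Compute EER and normalised minDCF with the cost parameters used above."""
    rates = sweep_error_rates(labels, scores)
    # EER: the operating point where the two error rates are closest.
    p_miss, p_fa = min(rates, key=lambda r: abs(r[0] - r[1]))
    eer = (p_miss + p_fa) / 2
    # Normalise by the cheaper of the two trivial systems (accept all / reject all).
    p_bona = 1 - p_spoof
    default = min(c_miss * p_bona, c_fa * p_spoof)
    min_dcf = min((c_miss * p_bona * m + c_fa * p_spoof * f) / default for m, f in rates)
    return eer, min_dcf


if __name__ == "__main__":
    labels = [0, 0, 0, 1, 1, 1]
    scores = [0.1, 0.2, 0.7, 0.4, 0.8, 0.9]
    eer, min_dcf = eer_and_min_dcf(labels, scores)
    print(f"EER={eer:.3f} minDCF={min_dcf:.3f}")
```

With these costs the normaliser is `min(0.95, 0.5) = 0.5`, so a system that accepts everything as bonafide scores exactly 1.0 and any useful detector should land below that.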
187
+ def parse_args():
188
+ parser = argparse.ArgumentParser(description="Train spoof-detection models with optional CQCC caching.")
189
+ parser.add_argument(
190
+ "--data-dir",
191
+ default=None,
192
+ help="Path to dataset root containing original/ and fake/ folders."
193
+ )
194
+ parser.add_argument(
195
+ "--cqcc-cache-dir", # this is where cqcc is stored
196
+ default=os.path.join(os.path.dirname(__file__), "precomputed_features", "cqcc"),
197
+ help="Directory used to store and reuse precomputed CQCC tensors."
198
+ )
199
+ parser.add_argument(
200
+ "--precompute-cqcc-only",
201
+ action="store_true",
202
+ help="Only build the CQCC cache and exit without training."
203
+ )
204
+ parser.add_argument(
205
+ "--val-split",
206
+ type=float,
207
+ default=0.2,
208
+ help="Fraction of English training data to reserve for validation."
209
+ )
210
+ parser.add_argument(
211
+ "--force-rebuild-cqcc",
212
+ action="store_true",
213
+ help="Recompute cached CQCC files even if they already exist."
214
+ )
215
+ parser.add_argument(
216
+ "--smoke-test",
217
+ action="store_true",
218
+ help="Load one batch, run a forward pass through each model, and exit without training."
219
+ )
220
+ return parser.parse_args()
221
+
222
+
223
+ def run_smoke_test(dataloader, device):
224
+ print("\n--- Running Smoke Test ---")
225
+ batch = next(iter(dataloader))
226
+ wavs, cqccs, labels = batch
227
+
228
+ models_to_test = [
229
+ ("Wav2Vec2 Baseline", Wav2Vec2SpoofDetector(num_classes=2).to(device), "wav"),
230
+ ("AASIST Baseline", AASISTDetector(num_classes=2).to(device), "wav"),
231
+ ("CQCC Baseline", CQCCBaselineDetector(num_classes=2).to(device), "cqcc"),
232
+ ("Custom Fusion Model", ImprovedWav2Vec2CQCCDetector(num_classes=2).to(device), "wav_and_cqcc"),
233
+ ("Ablation W2V2+Graph", AblationWav2Vec2GraphDetector(num_classes=2).to(device), "wav"),
234
+ ("Ablation CQCC+Graph", AblationCQCCGraphDetector(num_classes=2).to(device), "cqcc"),
235
+ ("Ablation Concat+Graph", AblationConcatGraphDetector(num_classes=2).to(device), "wav_and_cqcc"),
236
+ ("Ablation CrossAttn+Linear", AblationCrossAttnLinearDetector(num_classes=2).to(device), "wav_and_cqcc"),
237
+ ]
238
+
239
+ with torch.no_grad():
240
+ for name, model, input_type in models_to_test:
241
+ model.eval()
242
+ if input_type == "wav":
243
+ outputs = model(wavs.to(device))
244
+ elif input_type == "cqcc":
245
+ outputs = model(cqccs.to(device))
246
+ elif input_type == "wav_and_cqcc":
247
+ outputs = model(wavs.to(device), cqccs.to(device))
248
+ else:
249
+ raise ValueError("invalid input_type")
250
+
251
+ print(f"{name}: input OK, output shape = {tuple(outputs.shape)}")
252
+
253
+ print(f"Labels shape = {tuple(labels.shape)}")
254
+ print("Smoke test complete. Cached CQCC loading and model forward passes succeeded.")
255
+
256
+
257
+ def main():
258
+ args = parse_args()
259
+ print(args)
260
+ SEED = 42
261
+ random.seed(SEED)
262
+ np.random.seed(SEED)
263
+ torch.manual_seed(SEED)
264
+ if torch.cuda.is_available():
265
+ torch.cuda.manual_seed_all(SEED)
266
+
267
+ g = torch.Generator()
268
+ g.manual_seed(SEED)
269
+
270
+ torch.backends.cudnn.deterministic = True
271
+ torch.backends.cudnn.benchmark = False
272
+
273
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
274
+
275
+ print(f"Using device: {device}")
276
+
277
+ print("Loading English Dataset for training/validation...")
278
+ full_en_dataset = AudioDataset(data_dir=args.data_dir, augment=False, cqcc_cache_dir=args.cqcc_cache_dir, target_lang="en")
279
+ total_en = len(full_en_dataset)
280
+ if total_en == 0:
281
+ raise ValueError("No English data found for target_lang='en'. Check data_dir and directory layout.")
282
+
283
+ val_split = min(max(args.val_split, 0.0), 0.5)
284
+ train_size = int((1.0 - val_split) * total_en)
285
+ val_size = total_en - train_size
286
+ indices = torch.randperm(total_en, generator=g).tolist()
287
+ train_indices = indices[:train_size]
288
+ val_indices = indices[train_size:]
289
+
290
+ train_dataset = torch.utils.data.Subset(
291
+ AudioDataset(data_dir=args.data_dir, augment=True, cqcc_cache_dir=args.cqcc_cache_dir, target_lang="en"),
292
+ train_indices
293
+ )
294
+ val_dataset = torch.utils.data.Subset(
295
+ AudioDataset(data_dir=args.data_dir, augment=False, cqcc_cache_dir=args.cqcc_cache_dir, target_lang="en"),
296
+ val_indices
297
+ )
298
+
299
+ print("Loading German Dataset for Testing...")
300
+ test_dataset = AudioDataset(data_dir=args.data_dir, augment=False, cqcc_cache_dir=args.cqcc_cache_dir, target_lang="de")
301
+
302
+ if args.precompute_cqcc_only:
303
+ print("\n--- Starting CQCC Precomputation ---")
304
+ print(f"Dataset: {full_en_dataset.data_dir}")
305
+ print("Precomputing CQCC cache for English data...")
306
+ full_en_dataset.precompute_cqcc_cache(force=args.force_rebuild_cqcc)
307
+ test_dataset.precompute_cqcc_cache(force=args.force_rebuild_cqcc)
308
+ print("CQCC preprocessing complete. Exiting.")
309
+ return
310
+
311
+ train_loader = DataLoader(
312
+ train_dataset,
313
+ batch_size=8,
314
+ shuffle=True,
315
+ collate_fn=collate_variable_length,
316
+ num_workers=2,
317
+ pin_memory=True,
318
+ generator=g, # ensure reproducible shuffling
319
+ )
320
+
321
+ val_loader = DataLoader(
322
+ val_dataset,
323
+ batch_size=8,
324
+ shuffle=False,
325
+ collate_fn=collate_variable_length,
326
+ num_workers=2,
327
+ pin_memory=True
328
+ )
329
+
330
+ test_loader = DataLoader(
331
+ test_dataset,
332
+ batch_size=8,
333
+ shuffle=False,
334
+ collate_fn=collate_variable_length,
335
+ num_workers=2,
336
+ pin_memory=True
337
+ )
338
+
339
+ if args.smoke_test:
340
+ run_smoke_test(train_loader, device)
341
+ return
342
+
343
+ models_dir = os.path.join(os.path.dirname(__file__), "models")
344
+ os.makedirs(models_dir, exist_ok=True)
345
+
346
+ criterion = nn.CrossEntropyLoss()
347
+
348
+ # ============================================================
349
+ # 1 Wav2Vec2 Baseline
350
+ # ============================================================
351
+
352
+ print("\n--- Training Wav2Vec2 Baseline ---")
353
+
354
+ wav2vec_model = Wav2Vec2SpoofDetector(num_classes=2).to(device)
355
+
356
+ optimizer_wav2vec = torch.optim.Adam(wav2vec_model.parameters(), lr=1e-4)
357
+
358
+ wav2vec_loss = train_model(
359
+ wav2vec_model,
360
+ train_loader,
361
+ criterion,
362
+ optimizer_wav2vec,
363
+ input_type='wav',
364
+ device=device,
365
+ val_dataloader=val_loader,
366
+ model_save_path=os.path.join(models_dir, "wav2vec2.pth")
367
+ )
368
+ del wav2vec_model, optimizer_wav2vec
369
+ torch.cuda.empty_cache()
370
+ # ============================================================
371
+ # 2 AASIST Baseline
372
+ # ============================================================
373
+
374
+ print("\n--- Training AASIST Baseline ---")
375
+
376
+ aasist_model = AASISTDetector(num_classes=2).to(device)
377
+
378
+ optimizer_aasist = torch.optim.Adam(aasist_model.parameters(), lr=5e-4)
379
+
380
+ aasist_loss = train_model(
381
+ aasist_model,
382
+ train_loader,
383
+ criterion,
384
+ optimizer_aasist,
385
+ input_type='wav',
386
+ device=device,
387
+ val_dataloader=val_loader,
388
+ model_save_path=os.path.join(models_dir, "aasist.pth")
389
+ )
390
+ del aasist_model, optimizer_aasist
391
+ torch.cuda.empty_cache()
392
+ # ============================================================
393
+ # 3 CQCC Baseline
394
+ # ============================================================
395
+
396
+ print("\n--- Training CQCC Baseline ---")
397
+
398
+ cqcc_baseline = CQCCBaselineDetector(num_classes=2).to(device)
399
+
400
+ optimizer_cqcc = torch.optim.Adam(cqcc_baseline.parameters(), lr=1e-4)
401
+
402
+ cqcc_loss = train_model(
403
+ cqcc_baseline,
404
+ train_loader,
405
+ criterion,
406
+ optimizer_cqcc,
407
+ input_type='cqcc',
408
+ device=device,
409
+ val_dataloader=val_loader,
410
+ model_save_path=os.path.join(models_dir, "cqcc_baseline.pth")
411
+ )
412
+ del cqcc_baseline, optimizer_cqcc
413
+ torch.cuda.empty_cache()
414
+ # ============================================================
415
+ # 4 Custom Fusion: Wav2Vec2 + CQCC with Cross-Attention + Graph
416
+ # ============================================================
417
+
418
+ print("\n--- Training Custom Fusion Detector ---")
419
+
420
+ custom_model = ImprovedWav2Vec2CQCCDetector(num_classes=2).to(device)
421
+
422
+ optimizer_custom = torch.optim.Adam(custom_model.parameters(), lr=1e-4)
423
+
424
+ custom_loss = train_model(
425
+ custom_model,
426
+ train_loader,
427
+ criterion,
428
+ optimizer_custom,
429
+ input_type='wav_and_cqcc',
430
+ device=device,
431
+ val_dataloader=val_loader,
432
+ model_save_path=os.path.join(models_dir, "custom_hybrid.pth")
433
+ )
434
+ del custom_model, optimizer_custom
435
+ torch.cuda.empty_cache()
436
+
437
+ # ============================================================
438
+ # 5 Ablation Models
439
+ # ============================================================
440
+
441
+ print("\n--- Training Ablation 1 (Wav2Vec2 + Graph) ---")
442
+ ab1_model = AblationWav2Vec2GraphDetector(num_classes=2).to(device)
443
+ optimizer_ab1 = torch.optim.Adam(ab1_model.parameters(), lr=1e-4) # learning rate for wav2vec2-based
444
+ ab1_loss = train_model(ab1_model, train_loader, criterion, optimizer_ab1, input_type='wav', device=device, val_dataloader=val_loader, model_save_path=os.path.join(models_dir, "ablation_w2v2_graph.pth"))
445
+ del ab1_model, optimizer_ab1
446
+ torch.cuda.empty_cache()
447
+
448
+ print("\n--- Training Ablation 2 (CQCC + Graph) ---")
449
+ ab2_model = AblationCQCCGraphDetector(num_classes=2).to(device)
450
+ optimizer_ab2 = torch.optim.Adam(ab2_model.parameters(), lr=1e-4) # learning rate for CQCC-based
451
+ ab2_loss = train_model(ab2_model, train_loader, criterion, optimizer_ab2, input_type='cqcc', device=device, val_dataloader=val_loader, model_save_path=os.path.join(models_dir, "ablation_cqcc_graph.pth"))
452
+ del ab2_model, optimizer_ab2
453
+ torch.cuda.empty_cache()
454
+
455
+ print("\n--- Training Ablation 3 (Wav2Vec2 + CQCC + Concat + Graph) ---")
456
+ ab3_model = AblationConcatGraphDetector(num_classes=2).to(device)
457
+ optimizer_ab3 = torch.optim.Adam(ab3_model.parameters(), lr=1e-4)
458
+ ab3_loss = train_model(ab3_model, train_loader, criterion, optimizer_ab3, input_type='wav_and_cqcc', device=device, val_dataloader=val_loader, model_save_path=os.path.join(models_dir, "ablation_concat_graph.pth"))
459
+ del ab3_model, optimizer_ab3
460
+ torch.cuda.empty_cache()
461
+
462
+ print("\n--- Training Ablation 4 (Wav2Vec2 + CQCC + Cross-Attn + Linear) ---")
463
+ ab4_model = AblationCrossAttnLinearDetector(num_classes=2).to(device)
464
+ optimizer_ab4 = torch.optim.Adam(ab4_model.parameters(), lr=1e-4)
465
+ ab4_loss = train_model(ab4_model, train_loader, criterion, optimizer_ab4, input_type='wav_and_cqcc', device=device, val_dataloader=val_loader, model_save_path=os.path.join(models_dir, "ablation_crossattn_linear.pth"))
466
+ del ab4_model, optimizer_ab4
467
+ torch.cuda.empty_cache()
468
+
469
+ # ============================================================
470
+ # Evaluation — reload one at a time
471
+ # ============================================================
472
+ print("\n--- Evaluating Models ---")
473
+ evals = []
474
+
475
+ models_to_eval = [
476
+ ("Wav2Vec2 Baseline", Wav2Vec2SpoofDetector, "wav2vec2.pth", 'wav'),
477
+ ("AASIST Baseline", AASISTDetector, "aasist.pth", 'wav'),
478
+ ("CQCC Baseline", CQCCBaselineDetector, "cqcc_baseline.pth", 'cqcc'),
479
+ ("Custom Fusion Model", ImprovedWav2Vec2CQCCDetector, "custom_hybrid.pth", 'wav_and_cqcc'),
480
+ ("Ablation 1 (W2V2+Graph)", AblationWav2Vec2GraphDetector, "ablation_w2v2_graph.pth", 'wav'),
481
+ ("Ablation 2 (CQCC+Graph)", AblationCQCCGraphDetector, "ablation_cqcc_graph.pth", 'cqcc'),
482
+ ("Ablation 3 (Concat+Graph)", AblationConcatGraphDetector, "ablation_concat_graph.pth", 'wav_and_cqcc'),
483
+ ("Ablation 4 (CrossAttn+Linear)", AblationCrossAttnLinearDetector, "ablation_crossattn_linear.pth", 'wav_and_cqcc'),
484
+ ]
485
+
486
+ for name, model_class, filename, inp in models_to_eval:
487
+ model_path = os.path.join(models_dir, filename)
488
+ if not os.path.exists(model_path):
489
+ print(f"Skipping evaluation for {name} (Model weights not found at {model_path})")
490
+ continue
491
+
492
+ model_obj = model_class(num_classes=2).to(device)
493
+ model_obj.load_state_dict(torch.load(model_path, map_location=device))
494
+ model_obj.eval()
495
+
496
+ print(f"\n--- Metrics for {name} ---")
497
+
498
+ # 1. EVAL ON TRAIN SET
499
+ train_fpr, train_tpr, train_auc, train_eer, train_min_dcf, train_acc = evaluate_model(
500
+ model_obj, train_loader, input_type=inp, device=device
501
+ )
502
+ print(f"[Train] Acc={train_acc*100:.2f}% | EER={train_eer*100:.2f}% | minDCF={train_min_dcf:.4f}")
503
+
504
+ # 2. EVAL ON TEST SET
505
+ test_fpr, test_tpr, test_auc, test_eer, test_min_dcf, test_acc = evaluate_model(
506
+ model_obj, test_loader, input_type=inp, device=device
507
+ )
508
+ print(f"[Test ] Acc={test_acc*100:.2f}% | EER={test_eer*100:.2f}% | minDCF={test_min_dcf:.4f}")
509
+
510
+ del model_obj
511
+ torch.cuda.empty_cache()
512
+
513
+ if __name__ == "__main__":
514
+ main()
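The train/validation split in `main()` clamps `--val-split` to the range 0 to 0.5 and slices one seeded permutation. A minimal sketch of that step in plain Python follows; `split_indices` is a hypothetical name, and `random.Random(seed).shuffle` stands in for `torch.randperm(..., generator=g)`, so the actual index order will differ from the script's.

```python
import random


def split_indices(n_items, val_split, seed=42):
    """Shuffle range(n_items) reproducibly and split it into train/val index lists."""
    val_split = min(max(val_split, 0.0), 0.5)      # same clamp as main()
    train_size = int((1.0 - val_split) * n_items)
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)            # stand-in for torch.randperm
    return indices[:train_size], indices[train_size:]


if __name__ == "__main__":
    train_idx, val_idx = split_indices(100, 0.25)
    print(len(train_idx), len(val_idx))
    # An out-of-range request is clamped to an even split.
    train_idx, val_idx = split_indices(10, 0.9)
    print(len(train_idx), len(val_idx))
```

Because the two subsets come from one permutation, they cannot overlap, which is what lets the script wrap the same indices around an augmented dataset for training and a non-augmented one for validation.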
frontend/index.html ADDED
@@ -0,0 +1,75 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>OdioCheck | Deepfake Voice Detection</title>
8
+ <!-- Tailwind CSS -->
9
+ <script src="https://cdn.tailwindcss.com"></script>
10
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600;700&display=swap" rel="stylesheet">
11
+ <link rel="stylesheet" href="style.css">
12
+ <!-- Chart.js -->
13
+ <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
14
+ </head>
15
+
16
+ <body
17
+ class="bg-slate-900 text-slate-100 font-sans min-h-screen flex flex-col items-center justify-center p-6 subtle-bg">
18
+
19
+ <div class="glass-card max-w-2xl w-full rounded-3xl p-8 relative overflow-hidden transition-all duration-300">
20
+ <!-- Glowing Orb Background -->
21
+ <div
22
+ class="absolute -top-32 -left-32 w-64 h-64 bg-indigo-600 rounded-full mix-blend-multiply filter blur-3xl opacity-30 animate-pulse">
23
+ </div>
24
+ <div class="absolute -bottom-32 -right-32 w-64 h-64 bg-fuchsia-600 rounded-full mix-blend-multiply filter blur-3xl opacity-30 animate-pulse"
25
+ style="animation-delay: 2s;"></div>
26
+
27
+ <div class="relative z-10">
28
+ <h1
29
+ class="text-4xl font-bold mb-2 text-transparent bg-clip-text bg-gradient-to-r from-indigo-400 to-cyan-300">
30
+ OdioCheck
31
+ </h1>
32
+ <p class="text-slate-400 mb-8 font-light">Advanced Deepfake Voice Detection powered by SOTA Graph
33
+ architecture.</p>
34
+
35
+ <div id="drop-zone"
36
+ class="border-2 border-dashed border-slate-600 rounded-2xl p-10 flex flex-col items-center justify-center cursor-pointer hover:border-indigo-400 hover:bg-slate-800/50 transition-all duration-300 group">
37
+ <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5"
38
+ stroke="currentColor"
39
+ class="w-12 h-12 text-slate-500 group-hover:text-indigo-400 mb-4 transition-colors">
40
+ <path stroke-linecap="round" stroke-linejoin="round"
41
+ d="M12 18.75a6 6 0 006-6v-1.5m-6 7.5a6 6 0 01-6-6v-1.5m6 7.5v3.75m-3.75 0h7.5M12 15.75a3 3 0 01-3-3V4.5a3 3 0 116 0v8.25a3 3 0 01-3 3z" />
42
+ </svg>
43
+ <p class="text-lg text-slate-300 font-medium">Click to upload or drag & drop</p>
44
+ <p class="text-sm text-slate-500 mt-1">Supports WAV, OGG, MP3, FLAC, M4A & more</p>
45
+ <input type="file" id="file-input" class="hidden" accept="audio/*">
46
+ </div>
47
+
48
+ <!-- Analysis Section -->
49
+ <div id="analysis-section" class="mt-8 hidden opacity-0 transition-opacity duration-500">
50
+ <div class="flex items-center space-x-4 mb-6">
51
+ <div id="loading-spinner" class="hidden">
52
+ <div class="animate-spin rounded-full h-8 w-8 border-b-2 border-indigo-400"></div>
53
+ </div>
54
+ <h2 id="status-text" class="text-xl font-semibold text-slate-300">Analyzing Spectrogram...</h2>
55
+ </div>
56
+
57
+ <!-- Results: panels will be inserted via JavaScript based on response keys -->
58
+ <div id="results" class="hidden">
59
+ <div id="model-panels" class="grid grid-cols-2 gap-6"></div>
60
+ </div>
61
+ </div>
62
+ </div>
63
+ </div>
64
+
65
+ <!-- Additional Graph Section for wow factor -->
66
+ <div id="chart-card"
67
+ class="glass-card max-w-2xl w-full rounded-3xl p-8 mt-6 hidden opacity-0 transition-opacity duration-500">
68
+ <h3 class="text-lg font-semibold mb-4 text-slate-300">Timeline Analysis</h3>
69
+ <canvas id="audioChart" height="100"></canvas>
70
+ </div>
71
+
72
+ <script src="script.js"></script>
73
+ </body>
74
+
75
+ </html>
frontend/script.js ADDED
@@ -0,0 +1,243 @@
1
+ const dropZone = document.getElementById('drop-zone');
2
+ const fileInput = document.getElementById('file-input');
3
+ const analysisSection = document.getElementById('analysis-section');
4
+ const statusText = document.getElementById('status-text');
5
+ const results = document.getElementById('results');
6
+ const loadingSpinner = document.getElementById('loading-spinner');
7
+ const chartCard = document.getElementById('chart-card');
8
+
9
+ // -------------------------------------------------------
10
+ // Chart setup
11
+ // -------------------------------------------------------
12
+ const ctx = document.getElementById('audioChart').getContext('2d');
13
+ let audioChart = new Chart(ctx, {
14
+ type: 'line',
15
+ data: { labels: [], datasets: [] },
16
+ options: {
17
+ responsive: true,
18
+ animation: { duration: 600, easing: 'easeInOutQuart' },
19
+ plugins: {
20
+ legend: { display: true, labels: { color: '#94a3b8', font: { size: 12 } } },
21
+ tooltip: {
22
+ callbacks: {
23
+ label: ctx => ` ${ctx.dataset.label}: ${ctx.parsed.y.toFixed(1)}% fake`,
24
+ title: items => `Segment @ ${items[0].label}s`
25
+ }
26
+ }
27
+ },
28
+ scales: {
29
+ y: {
30
+ beginAtZero: true,
31
+ max: 100,
32
+ ticks: { color: '#94a3b8', callback: v => v + '%' },
33
+ grid: { color: 'rgba(148,163,184,0.1)' },
34
+ title: { display: true, text: 'Fake Probability (%)', color: '#64748b' }
35
+ },
36
+ x: {
37
+ ticks: {
38
+ color: '#94a3b8', callback: (_, i, ticks) => {
39
+ // Show fewer labels when there are many windows
40
+ const step = Math.max(1, Math.floor(ticks.length / 8));
41
+ return i % step === 0 ? audioChart.data.labels[i] + 's' : '';
42
+ }
43
+ },
44
+ grid: { color: 'rgba(148,163,184,0.05)' },
45
+ title: { display: true, text: 'Time (seconds)', color: '#64748b' }
46
+ }
47
+ }
48
+ }
49
+ });
50
+
51
+ // Palette and display names for the four models
52
+ const MODEL_META = {
53
+ wav2vec2: { label: 'Wav2Vec2', color: '#3b82f6' },
54
+ aasist: { label: 'AASIST', color: '#f43f5e' },
55
+ cqcc_baseline: { label: 'CQCC Baseline', color: '#fbbf24' },
56
+ custom_hybrid: { label: 'Proposed Custom Hybrid', color: '#10b981' },
57
+ };
58
+
59
+ // -------------------------------------------------------
60
+ // File handling
61
+ // -------------------------------------------------------
62
+ function handleFile(file) {
63
+ if (!file) return;
64
+
65
+ // Show sections
66
+ analysisSection.classList.remove('hidden');
67
+ chartCard.classList.remove('hidden');
68
+ setTimeout(() => {
69
+ analysisSection.classList.remove('opacity-0');
70
+ chartCard.classList.remove('opacity-0');
71
+ }, 50);
72
+
73
+ results.classList.add('hidden');
74
+ loadingSpinner.classList.remove('hidden');
75
+ statusText.innerText = `Analyzing "${file.name}"…`;
76
+
77
+ // Clear previous state
78
+ document.getElementById('model-panels').innerHTML = '';
79
+ audioChart.data.labels = [];
80
+ audioChart.data.datasets = [];
81
+ audioChart.update();
82
+
83
+ // Animated placeholder while waiting: a single pulsing dataset
84
+ const placeholder = {
85
+ label: 'Analyzing…',
86
+ data: Array.from({ length: 20 }, (_, i) => 45 + Math.sin(i / 2) * 10),
87
+ borderColor: 'rgba(99,102,241,0.5)',
88
+ backgroundColor: 'rgba(99,102,241,0.05)',
89
+ borderDash: [4, 4],
90
+ fill: true,
91
+ tension: 0.4,
92
+ pointRadius: 0,
93
+ };
94
+ audioChart.data.labels = Array.from({ length: 20 }, (_, i) => i);
95
+ audioChart.data.datasets = [placeholder];
96
+ audioChart.update();
97
+
98
+ let tick = 0;
99
+ const loadingAnim = setInterval(() => {
100
+ tick++;
101
+ placeholder.data = Array.from({ length: 20 }, (_, i) =>
102
+ 45 + Math.sin((i + tick) / 2) * 10
103
+ );
104
+ audioChart.update('none'); // skip animation for perf
105
+ }, 80);
106
+
107
+ const formData = new FormData();
108
+ formData.append('file', file);
109
+
110
+ const HF_API_URL = window.location.hostname === '127.0.0.1' || window.location.hostname === 'localhost'
111
+ ? '/api/predict'
112
+ : 'https://junsiang26-odiocheck-backend.hf.space/api/predict';
113
+
114
+ fetch(HF_API_URL, { method: 'POST', body: formData })
115
+ .then(r => r.json())
116
+ .then(data => {
117
+ clearInterval(loadingAnim);
118
+ loadingSpinner.classList.add('hidden');
119
+
120
+ if (data.error) {
121
+ statusText.innerText = 'Error analyzing file.';
122
+ console.error(data.error);
123
+ return;
124
+ }
125
+
126
+ renderResults(data);
127
+ })
128
+ .catch(() => {
129
+ clearInterval(loadingAnim);
130
+ loadingSpinner.classList.add('hidden');
131
+ statusText.innerText = 'Connection error. Is the backend running?';
132
+ });
133
+ }
134
+
135
+ // -------------------------------------------------------
136
+ // Render results from the new response shape:
137
+ // data.overall → { model: { prediction, fake_probability, real_probability } }
138
+ // data.timeline → { model: [fake_prob_pct, ...] }
139
+ // data.window_labels → [segment_start_sec, ...]
140
+ // -------------------------------------------------------
141
+ function renderResults(data) {
142
+ const { overall, timeline, window_labels } = data;
143
+
144
+ statusText.innerText = 'Analysis Complete';
145
+ results.classList.remove('hidden');
146
+
147
+ // --- Model panels (overall verdict) ---
148
+ const panelsEl = document.getElementById('model-panels');
149
+ panelsEl.innerHTML = '';
150
+
151
+ for (const [key, info] of Object.entries(overall)) {
152
+ const meta = MODEL_META[key] || { label: key, color: '#94a3b8' };
153
+ const isFake = info.prediction === 'FAKE';
154
+ const barColor = isFake ? 'from-rose-500 to-rose-400' : 'from-emerald-400 to-emerald-500';
155
+ const displayPct = isFake ? info.fake_probability : info.real_probability;
156
+
157
+ panelsEl.insertAdjacentHTML('beforeend', `
158
+ <div>
159
+ <div class="flex justify-between items-end mb-2">
160
+ <span class="text-sm text-slate-400 uppercase tracking-widest font-semibold"
161
+ style="color:${meta.color}">${meta.label}</span>
162
+ <span class="text-3xl font-bold tracking-wider ${isFake ? 'text-rose-500' : 'text-emerald-500'}">
163
+ ${info.prediction}
164
+ </span>
165
+ </div>
166
+ <div class="text-xs text-slate-500 mb-2">
167
+ Fake: <span class="text-slate-300">${info.fake_probability}%</span>
168
+ &nbsp;·&nbsp;
169
+ Real: <span class="text-slate-300">${info.real_probability}%</span>
170
+ </div>
171
+ <div class="w-full bg-slate-700 h-4 rounded-full overflow-hidden mb-6 mt-1">
172
+ <div class="prob-bar h-full bg-gradient-to-r transition-all duration-1000 ease-out ${barColor}"
173
+ style="width:0%"
174
+ data-width="${displayPct}">
175
+ </div>
176
+ </div>
177
+ </div>`);
178
+ }
179
+
180
+ // Animate bars
181
+ requestAnimationFrame(() => {
182
+ document.querySelectorAll('.prob-bar').forEach(bar => {
183
+ bar.style.width = bar.dataset.width + '%';
184
+ });
185
+ });
186
+
187
+ // --- Timeline chart (real data) ---
188
+ // window_labels are now start-of-segment times (0, 2, 4 ...)
189
+ // For short audio with a single window, we pad with the audio-end label
190
+ // so the chart shows a line rather than a lonely dot.
191
+ let labels = [...window_labels];
192
+ let timelineValues = {};
193
+ Object.entries(timeline).forEach(([k, v]) => { timelineValues[k] = [...v]; });
194
+
195
+ if (labels.length === 1) {
196
+ // Estimate audio duration: single window = TARGET_LEN / 16000 ≈ 4.025s
197
+ const audioEnd = parseFloat((labels[0] + 4.025).toFixed(2));
198
+ labels.push(audioEnd);
199
+ Object.keys(timelineValues).forEach(k => timelineValues[k].push(timelineValues[k][0]));
200
+ }
201
+
202
+ audioChart.data.labels = labels;
203
+ audioChart.data.datasets = Object.entries(timelineValues).map(([key, values]) => {
204
+ const meta = MODEL_META[key] || { label: key, color: '#94a3b8' };
205
+ const hex = meta.color;
206
+ const rgb = hex.match(/[0-9a-fA-F]{2}/g).map(h => parseInt(h, 16)).join(',');
207
+ return {
208
+ label: meta.label,
209
+ data: values,
210
+ borderColor: hex,
211
+ backgroundColor: `rgba(${rgb},0.08)`,
212
+ fill: true,
213
+ tension: 0.4,
214
+ pointRadius: values.length <= 20 ? 4 : 2,
215
+ pointHoverRadius: 6,
216
+ };
217
+ });
218
+
219
+ // Add a 50% threshold reference line
220
+ audioChart.data.datasets.push({
221
+ label: 'Decision threshold (50%)',
222
+ data: Array(labels.length).fill(50),
223
+ borderColor: 'rgba(255,255,255,0.2)',
224
+ borderDash: [6, 4],
225
+ borderWidth: 1,
226
+ pointRadius: 0,
227
+ fill: false,
228
+ tension: 0,
229
+ });
230
+
231
+ audioChart.update();
232
+ }
233
+
234
+ // -------------------------------------------------------
235
+ // Drop zone wiring
236
+ // -------------------------------------------------------
237
+ dropZone.addEventListener('click', () => fileInput.click());
238
+ fileInput.addEventListener('change', e => handleFile(e.target.files[0]));
239
+
240
+ ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(name => {
241
+ dropZone.addEventListener(name, e => { e.preventDefault(); e.stopPropagation(); });
242
+ });
243
+ dropZone.addEventListener('drop', e => handleFile(e.dataTransfer.files[0]));
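For reference, the payload that `renderResults` destructures (`overall`, `timeline`, `window_labels`) can be sketched on the Python side. This is a hypothetical helper, not the backend's code: the name `build_response`, the mean-over-windows aggregation, the 50% cut-off for the verdict, and the `"REAL"` string are all assumptions. The frontend only confirms the key names and that `prediction` equals `'FAKE'` for fakes.

```python
def build_response(window_probs, window_labels):
    """Assemble a payload in the shape the frontend reads.

    window_probs:  {model_key: [fake probability per window, 0.0 to 1.0]}
    window_labels: [start time of each window in seconds]
    Assumption: the overall verdict is the mean window probability
    thresholded at 50%; the real backend may aggregate differently.
    """
    overall, timeline = {}, {}
    for key, probs in window_probs.items():
        fake_pct = round(100 * sum(probs) / len(probs), 2)
        overall[key] = {
            "prediction": "FAKE" if fake_pct > 50 else "REAL",
            "fake_probability": fake_pct,
            "real_probability": round(100 - fake_pct, 2),
        }
        # The chart plots percentages, one value per window.
        timeline[key] = [round(100 * p, 2) for p in probs]
    return {
        "overall": overall,
        "timeline": timeline,
        "window_labels": list(window_labels),
    }


if __name__ == "__main__":
    payload = build_response({"aasist": [0.5, 1.0]}, [0, 2])
    print(payload)
```

Keys such as `aasist` match the entries in `MODEL_META`; unknown keys still render because the script falls back to the raw key and a grey colour.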
frontend/style.css ADDED
@@ -0,0 +1,34 @@
+ /* Glassmorphism utility classes */
+ .glass-card {
+     background: rgba(30, 41, 59, 0.7);
+     backdrop-filter: blur(12px);
+     -webkit-backdrop-filter: blur(12px);
+     border: 1px solid rgba(255, 255, 255, 0.1);
+     box-shadow: 0 4px 30px rgba(0, 0, 0, 0.1);
+ }
+
+ .subtle-bg {
+     background-color: #0f172a;
+     background-image:
+         radial-gradient(at 0% 0%, hsla(253, 16%, 7%, 1) 0, transparent 50%),
+         radial-gradient(at 50% 0%, hsla(225, 39%, 30%, 1) 0, transparent 50%),
+         radial-gradient(at 100% 0%, hsla(339, 49%, 30%, 1) 0, transparent 50%);
+ }
+
+ .animate-pulse {
+     animation: pulse 4s cubic-bezier(0.4, 0, 0.6, 1) infinite;
+ }
+
+ @keyframes pulse {
+
+     0%,
+     100% {
+         opacity: 0.3;
+         transform: scale(1);
+     }
+
+     50% {
+         opacity: 0.5;
+         transform: scale(1.05);
+     }
+ }
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ datasets == 2.21.0
+ fastapi
+ librosa
+ matplotlib
+ numpy
+ python-multipart
+ python-pptx
+ scikit-learn
+ scipy
+ seaborn
+ soundfile
+ torch>=2.6.0
+ torchaudio
+ torchvision
+ tqdm
+ transformers
+ uvicorn