Commit 9f2b6db
Pure production deploy
Files changed:
- .gitattributes +4 -0
- .gitignore +18 -0
- .vscode/settings.json +4 -0
- AI Project 2026.pdf +0 -0
- Dockerfile +25 -0
- Model_Training_(Odio).ipynb +0 -0
- README.md +108 -0
- backend/app.py +233 -0
- backend/dataset.py +211 -0
- backend/models.py +1019 -0
- backend/preprocess.py +58 -0
- backend/train.py +514 -0
- frontend/index.html +75 -0
- frontend/script.js +243 -0
- frontend/style.css +34 -0
- requirements.txt +17 -0
.gitattributes
ADDED
@@ -0,0 +1,4 @@
+backend/models/*.pth filter=lfs diff=lfs merge=lfs -text
+**/*.pth filter=lfs diff=lfs merge=lfs -text
+backend/models/**/*.pth filter=lfs diff=lfs merge=lfs -text
+backend/models/Ablation[[:space:]]models/*.pth filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,18 @@
+.venv
+MLAAD-tiny/
+data/
+*.pth
+Download_sample (Ignore)/
+__pycache__/
+project_requirements.txt
+generate_notebook.py
+backend/precomputed_features/
+
+# Models
+*.pth
+backend/models/*.pth
+
+# Environment
+.env
+.DS_Store
+node_modules/
.vscode/settings.json
ADDED
@@ -0,0 +1,4 @@
+{
+    "python-envs.defaultEnvManager": "ms-python.python:conda",
+    "python-envs.defaultPackageManager": "ms-python.python:conda"
+}
AI Project 2026.pdf
ADDED
Binary file (73.1 kB).
Dockerfile
ADDED
@@ -0,0 +1,25 @@
+# Use a slim Python image
+FROM python:3.10-slim
+
+# Install system dependencies for audio processing
+RUN apt-get update && apt-get install -y \
+    libsndfile1 \
+    ffmpeg \
+    && rm -rf /var/lib/apt/lists/*
+
+# Set working directory
+WORKDIR /app
+
+# Copy requirements and install
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy the backend code and models
+COPY backend/ ./backend/
+
+# Expose the port FastAPI will run on
+EXPOSE 7860
+
+# Command to run the application
+# Note: Hugging Face uses port 7860 by default
+CMD ["uvicorn", "backend.app:app", "--host", "0.0.0.0", "--port", "7860"]
Model_Training_(Odio).ipynb
ADDED
The diff for this file is too large to render.
README.md
ADDED
@@ -0,0 +1,108 @@
+<img width="1080" height="324" alt="odiocheck" src="https://github.com/user-attachments/assets/4d7b573e-5b0b-4fc7-85de-da60bbb701c2" />
+
+# OdioCheck - Deepfake Voice Detection AI
+*50.021 Artificial Intelligence Project*
+
+## Theme
+**AI for Security & Social Good** (UN SDG #16: Peace, Justice, and Strong Institutions)
+OdioCheck tackles the rising threat of audio deepfakes used in scams and misdirection.
+
+## Requirements Checklist
+- [x] **Fully functioning code:** Complete end-to-end PyTorch implementation from dataset loading to real-time inference via a web UI.
+- [x] **Baseline models (×3):**
+  - **Wav2Vec2**: self-supervised transformer feature extractor (frozen) + attentive pooling classifier (`backend/models.py`)
+  - **AASIST**: graph-based SOTA baseline using sinc-filter frontend + spectro-temporal heterogeneous graph attention (`backend/models.py`)
+  - **CQCC Baseline**: standard CNN processing Constant-Q Cepstral Coefficients (`backend/models.py`)
+- [x] **SOTA Custom Model:** `ImprovedWav2Vec2CQCCDetector`, a novel fusion architecture combining Wav2Vec 2.0 and CQCC features via **bidirectional cross-attention**, followed by a **Graph Attention** backend (`backend/models.py`).
+- [x] **Ablation Study (×4):** Four ablation variants systematically isolate each architectural component to validate the custom model design:
+  - **Ablation 1**: Wav2Vec2 + Graph (no CQCC, no cross-attention)
+  - **Ablation 2**: CQCC + Graph (no Wav2Vec2, no cross-attention)
+  - **Ablation 3**: Wav2Vec2 + CQCC + Simple Concat + Graph (no cross-attention)
+  - **Ablation 4**: Wav2Vec2 + CQCC + Cross-Attention + Linear (no Graph Attention)
+- [x] **Fully Working Frontend:** Glassmorphic UI (Tailwind + Vanilla JS) served via FastAPI. Supports OGG/MP3/M4A/FLAC/WAV. Shows **side-by-side** predictions from all four primary models with real-time animated confidence bars and a per-window **temporal analysis timeline chart**.
+- [x] **Cross-lingual Dataset Split:** Trained on English audio (`MLAAD-tiny/en`), tested on unseen German audio (`MLAAD-tiny/de`) for out-of-distribution generalisation evaluation.
+- [x] **CQCC Feature Caching:** Pre-computed CQCC tensors are cached to disk to avoid redundant computation across training runs.
+
+---
+
+## Installation
+
+Ensure you have Python 3.9+ installed. Install all dependencies:
+```bash
+pip install -r requirements.txt
+```
+
+### Dataset Download
+We use the `MLAAD-tiny` dataset (multi-language audio deepfakes). Download it from Hugging Face before training:
+```bash
+pip install -U "huggingface_hub[cli]"
+huggingface-cli download mueller91/MLAAD-tiny --repo-type dataset --local-dir MLAAD-tiny
+```
+
+---
+
+## Running the Project
+
+### Step 1: (Optional) Pre-compute CQCC Cache
+Pre-computing CQCC features once avoids recomputing them in every subsequent training run:
+```bash
+python backend/train.py --precompute-cqcc-only
+```
+
+### Step 2: Train All Models
+Trains all 4 primary models + 4 ablation variants, evaluates on the German test set, and saves `.pth` weights to `backend/models/`:
+```bash
+python backend/train.py
+```
+
+#### Available Training Flags
+| Flag | Default | Description |
+|---|---|---|
+| `--val-split F` | `0.2` | Fraction of English data reserved for validation (0–0.5). |
+| `--data-dir PATH` | auto | Override dataset root (must contain `original/` and `fake/` folders). |
+| `--cqcc-cache-dir PATH` | `backend/precomputed_features/cqcc` | Where to read/write cached CQCC tensors. |
+| `--precompute-cqcc-only` | `False` | Build CQCC cache and exit without training. |
+| `--force-rebuild-cqcc` | `False` | Recompute CQCC cache even if files already exist. |
+| `--smoke-test` | `False` | Run one forward pass through every model and exit; useful for verifying setup. |
+
+#### Quick Smoke Test
+Verify all models initialise and run a forward pass correctly without full training:
+```bash
+python backend/train.py --smoke-test
+```
+
+### Step 3: Start the Web Interface
+```bash
+uvicorn backend.app:app --reload
+```
+Open **http://127.0.0.1:8000** in your browser. Upload any audio file (WAV, MP3, OGG, FLAC, M4A) to see simultaneous predictions from all four primary models plus an animated temporal confidence chart.
+
+---
+
+## Project Architecture
+
+```
+AI Project/
+├── backend/
+│   ├── models.py         # All model architectures (3 baselines + custom + 4 ablations)
+│   ├── dataset.py        # AudioDataset with CQCC caching + data augmentation
+│   ├── train.py          # Full training + evaluation pipeline (CLI-driven)
+│   ├── app.py            # FastAPI inference server (windowed temporal analysis)
+│   ├── preprocess.py     # Standalone preprocessing utilities
+│   └── models/           # Saved .pth weight files (generated after training)
+├── frontend/
+│   ├── index.html        # Glassmorphic UI shell
+│   ├── script.js         # File upload, Chart.js timeline, model panel rendering
+│   └── style.css         # Custom glassmorphism styles
+├── MLAAD-tiny/           # Dataset (downloaded separately)
+├── requirements.txt      # Python dependencies
+└── Model_Training_(Odio).ipynb   # Google Colab training notebook
+```
+
+---
+
+## Working with Other Datasets
+To replace MLAAD-tiny with another dataset (e.g., ASVspoof):
+1. Place your `fake/` and `original/` (or `real/`) audio folders into a `data/` directory at the project root.
+2. The `AudioDataset` in `dataset.py` auto-detects and falls back to the `data/` directory if `MLAAD-tiny` is absent.
+3. Re-run `python backend/train.py`. The full pipeline runs identically.
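Note on the training flags: `backend/train.py` is part of this commit, but its contents are not shown here, so the following is only a sketch of how the README's flag table could map onto `argparse`. The function names `build_parser` and `parse_args` are hypothetical, and the defaults are copied from the table; the real script may be organised differently.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the README flag table (not the actual train.py)."""
    parser = argparse.ArgumentParser(description="Train and evaluate the detectors")
    parser.add_argument("--val-split", type=float, default=0.2,
                        help="Fraction of English data reserved for validation (0-0.5).")
    parser.add_argument("--data-dir", default=None,
                        help="Dataset root containing original/ and fake/ folders.")
    parser.add_argument("--cqcc-cache-dir", default="backend/precomputed_features/cqcc",
                        help="Where cached CQCC tensors are read and written.")
    parser.add_argument("--precompute-cqcc-only", action="store_true",
                        help="Build the CQCC cache and exit without training.")
    parser.add_argument("--force-rebuild-cqcc", action="store_true",
                        help="Recompute cached CQCC files even if they exist.")
    parser.add_argument("--smoke-test", action="store_true",
                        help="Run one forward pass through every model and exit.")
    return parser


def parse_args(argv=None) -> argparse.Namespace:
    """Parse flags and enforce the documented 0-0.5 range for --val-split."""
    args = build_parser().parse_args(argv)
    if not 0.0 <= args.val_split <= 0.5:
        raise ValueError("--val-split must be between 0 and 0.5")
    return args


if __name__ == "__main__":
    print(parse_args(["--smoke-test", "--val-split", "0.1"]))
```

Using `action="store_true"` gives each boolean flag the `False` default listed in the table without requiring a value on the command line.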
backend/app.py
ADDED
@@ -0,0 +1,233 @@
+import os
+import torch
+import numpy as np
+import torch.nn.functional as F
+from fastapi import FastAPI, UploadFile, File
+from fastapi.responses import JSONResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.middleware.cors import CORSMiddleware
+import sys
+import librosa
+
+# Add this directory to sys.path before importing the sibling modules below.
+sys.path.append(os.path.dirname(__file__))
+from dataset import compute_cqcc
+from models import (
+    Wav2Vec2SpoofDetector,
+    AASISTDetector,
+    CQCCBaselineDetector,
+    ImprovedWav2Vec2CQCCDetector
+)
+
+app = FastAPI(title="Deepfake Voice Detection")
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# -------------------------------------------------------
+# Load Models
+# -------------------------------------------------------
+
+models_dir = os.path.join(os.path.dirname(__file__), "models")
+
+def load_model(model, filename):
+    path = os.path.join(models_dir, filename)
+    if os.path.exists(path):
+        state_dict = torch.load(path, map_location=device)
+        # Handle weight_norm parametrization mismatch (common in Wav2Vec2 between versions)
+        # This converts the 'parametrizations' keys back to 'weight_g' and 'weight_v'
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            if "pos_conv_embed.conv.parametrizations.weight.original0" in k:
+                new_key = k.replace("parametrizations.weight.original0", "weight_g")
+                new_state_dict[new_key] = v
+            elif "pos_conv_embed.conv.parametrizations.weight.original1" in k:
+                new_key = k.replace("parametrizations.weight.original1", "weight_v")
+                new_state_dict[new_key] = v
+            else:
+                new_state_dict[k] = v
+        model.load_state_dict(new_state_dict)
+        print(f"Loaded {filename}")
+    else:
+        print(f"WARNING: {filename} not found. Run train.py first!")
+    model.eval()
+
+    return model
+
+
+wav2vec_model = load_model(
+    Wav2Vec2SpoofDetector(num_classes=2).to(device),
+    "wav2vec2.pth"
+)
+
+aasist_model = load_model(
+    AASISTDetector(num_classes=2).to(device),
+    "aasist.pth"
+)
+
+cqcc_baseline_model = load_model(
+    CQCCBaselineDetector(num_classes=2).to(device),
+    "cqcc_baseline.pth"
+)
+
+custom_hybrid_model = load_model(
+    ImprovedWav2Vec2CQCCDetector(num_classes=2).to(device),
+    "custom_hybrid.pth"
+)
+
+
+# -------------------------------------------------------
+# Audio Preprocessing (mirrors dataset.py __getitem__)
+# -------------------------------------------------------
+
+TARGET_LEN = 64600  # AASIST standard: 4.0375 s at 16 kHz
+CQCC_N_BINS = 60    # Matches AudioDataset default
+
+# 50% overlap: each step is half a window (~2s), giving smooth temporal curves
+# without running 4x Wav2Vec2 passes per second.
+WINDOW_STEP = TARGET_LEN // 2
+
+
+def preprocess_window(wav_np: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Crop or pad a single audio window to TARGET_LEN, then compute waveform
+    and CQCC tensors, identical to AudioDataset.__getitem__ (non-augmented).
+
+    Returns:
+        wav  : (1, TARGET_LEN) float32 tensor
+        cqcc : (1, 20, T) float32 tensor
+    """
+    # Center-crop or zero-pad to exactly TARGET_LEN (matches eval path in dataset.py)
+    if len(wav_np) > TARGET_LEN:
+        start = (len(wav_np) - TARGET_LEN) // 2
+        wav_np = wav_np[start : start + TARGET_LEN]
+    elif len(wav_np) < TARGET_LEN:
+        wav_np = np.pad(wav_np, (0, TARGET_LEN - len(wav_np)), mode='constant')
+
+    wav = torch.from_numpy(wav_np).unsqueeze(0).float()
+    cqcc = compute_cqcc(wav_np, n_bins=CQCC_N_BINS)  # → (1, 20, T)
+    return wav, cqcc
+
+
+def run_window(wav: torch.Tensor, cqcc: torch.Tensor) -> dict:
+    """
+    Run all four models on a single window and return fake probabilities (0–100).
+    """
+    wav = wav.unsqueeze(0).to(device)    # (1, 1, TARGET_LEN)
+    cqcc = cqcc.unsqueeze(0).to(device)  # (1, 1, 20, T)
+
+    with torch.no_grad():
+        w2v_prob = torch.softmax(wav2vec_model(wav), dim=1)[0][1].item()
+        aasist_prob = torch.softmax(aasist_model(wav), dim=1)[0][1].item()
+        cqcc_prob = torch.softmax(cqcc_baseline_model(cqcc), dim=1)[0][1].item()
+        custom_prob = torch.softmax(custom_hybrid_model(wav, cqcc), dim=1)[0][1].item()
+
+    return {
+        "wav2vec2": round(w2v_prob * 100, 2),
+        "aasist": round(aasist_prob * 100, 2),
+        "cqcc_baseline": round(cqcc_prob * 100, 2),
+        "custom_hybrid": round(custom_prob * 100, 2),
+    }
+
+
+def aggregate_prediction(fake_prob_pct: float) -> dict:
+    """Convert a mean fake probability into the standard prediction dict."""
+    return {
+        "prediction": "FAKE" if fake_prob_pct > 50 else "REAL",
+        "fake_probability": fake_prob_pct,
+        "real_probability": round(100 - fake_prob_pct, 2),
+    }
+
+
+# -------------------------------------------------------
+# Prediction Endpoint
+# -------------------------------------------------------
+@app.post("/api/predict")
+async def predict(file: UploadFile = File(...)):
+    temp_path = f"temp_{os.path.basename(file.filename or 'upload')}"  # basename drops client-supplied directories
+    try:
+        with open(temp_path, "wb") as f:
+            f.write(await file.read())
+
+        # Load at 16 kHz mono, identical to the librosa.load call in dataset.py
+        wav_np, sr = librosa.load(temp_path, sr=16000, mono=True)
+
+        # -------------------------------------------------------
+        # Slice into overlapping windows of TARGET_LEN samples.
+        # Step = 50% overlap. Very short clips produce a single window.
+        # -------------------------------------------------------
+        total_samples = len(wav_np)
+        starts = list(range(0, max(total_samples, 1), WINDOW_STEP))  # always at least one window
+
+        window_probs = []   # per-window fake-probability dicts
+        window_labels = []  # x-axis: start of each window in seconds
+
+        for start in starts:
+            chunk = wav_np[start : start + TARGET_LEN]
+            wav_t, cqcc_t = preprocess_window(chunk)
+            probs = run_window(wav_t, cqcc_t)
+            window_probs.append(probs)
+
+            start_sec = round(start / sr, 2)
+            window_labels.append(start_sec)
+
+        # -------------------------------------------------------
+        # Overall prediction = mean fake probability across all windows
+        # -------------------------------------------------------
+        model_keys = ["wav2vec2", "aasist", "cqcc_baseline", "custom_hybrid"]
+        overall = {}
+        for key in model_keys:
+            mean_fake = round(
+                sum(w[key] for w in window_probs) / len(window_probs), 2
+            )
+            overall[key] = aggregate_prediction(mean_fake)
+
+        # -------------------------------------------------------
+        # Time-series data for the frontend chart
+        # -------------------------------------------------------
+        timeline = {
+            key: [w[key] for w in window_probs]
+            for key in model_keys
+        }
+
+        return JSONResponse({
+            "overall": overall,              # {model: {prediction, fake_probability, real_probability}}
+            "timeline": timeline,            # {model: [fake_prob_pct, ...]}, one value per window
+            "window_labels": window_labels,  # [start_sec, ...], x-axis in seconds (starts at 0.0)
+        })
+
+    except Exception as e:
+        return JSONResponse({"error": str(e)}, status_code=500)
+
+    finally:
+        if os.path.exists(temp_path):
+            os.remove(temp_path)
+
+# -------------------------------------------------------
+# Serve frontend
+# -------------------------------------------------------
+
+frontend_dir = os.path.join(os.path.dirname(__file__), "..", "frontend")
+
+if os.path.exists(frontend_dir):
+    app.mount("/", StaticFiles(directory=frontend_dir, html=True), name="frontend")
+
+
+# -------------------------------------------------------
+# Run Server
+# -------------------------------------------------------
+
+if __name__ == "__main__":
+
+    import uvicorn
+
+    print("Starting server at http://127.0.0.1:8000")
+
+    uvicorn.run(app, host="127.0.0.1", port=8000)
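Note on the windowing in `predict`: the slicing and averaging can be looked at without any model. The sketch below reuses the `TARGET_LEN` and `WINDOW_STEP` constants from `app.py`; the helper names `window_starts` and `summarize` are hypothetical and exist only for this illustration. The `max(..., 1)` guard keeps an empty clip from producing zero windows, which would otherwise make the mean divide by zero.

```python
TARGET_LEN = 64600             # samples per window, as in app.py
WINDOW_STEP = TARGET_LEN // 2  # 50% overlap between consecutive windows
SAMPLE_RATE = 16000


def window_starts(total_samples: int) -> list[int]:
    """Start offsets (in samples) of every analysis window."""
    return list(range(0, max(total_samples, 1), WINDOW_STEP))


def summarize(window_probs: list[float]) -> dict:
    """Average per-window fake probabilities (0-100) into one verdict."""
    mean_fake = round(sum(window_probs) / len(window_probs), 2)
    return {
        "prediction": "FAKE" if mean_fake > 50 else "REAL",
        "fake_probability": mean_fake,
        "real_probability": round(100 - mean_fake, 2),
    }


if __name__ == "__main__":
    starts = window_starts(10 * SAMPLE_RATE)  # a 10-second clip
    for s in starts:
        real_samples = min(TARGET_LEN, 10 * SAMPLE_RATE - s)
        print(f"start={s:6d}  real samples={real_samples}  padding={TARGET_LEN - real_samples}")
    print(summarize([80.0, 60.0, 40.0]))
```

For a 10-second clip this yields five windows. The last one starts at sample 129200 and holds only 30800 real samples, so more than half of it is zero padding. Since every window carries equal weight in the mean, heavily padded trailing windows may be worth down-weighting or dropping.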
backend/dataset.py
ADDED
@@ -0,0 +1,211 @@
+import os
+import hashlib
+import torch
+import torchaudio
+import numpy as np
+from torch.utils.data import Dataset
+import librosa
+from scipy.fftpack import dct
+
+def compute_cqcc(wav_np, n_bins, sample_rate=16000, hop_length=160, num_coeffs=20):
+    """Compute CQCC features from a mono waveform numpy array."""
+    try:
+        cqt = np.abs(
+            librosa.cqt(
+                wav_np,
+                sr=sample_rate,
+                n_bins=n_bins,
+                hop_length=hop_length,
+                fmin=librosa.note_to_hz('C1')
+            )
+        )
+        log_power = librosa.amplitude_to_db(cqt, ref=np.max)
+        cqcc = dct(log_power, type=2, axis=0, norm='ortho')[:num_coeffs]
+        return torch.from_numpy(cqcc).unsqueeze(0).float()
+    except Exception:
+        # Fallback for very short or invalid audio.
+        return torch.zeros((1, num_coeffs, 10), dtype=torch.float32)
+
+class AudioDataset(Dataset):
+    def __init__(self, data_dir=None, n_bins=60, augment=False, cqcc_cache_dir=None, target_lang=None):
+        if data_dir is None:
+            # Check if MLAAD-tiny exists, else fallback to 'data'
+            mlaad_dir = os.path.join(os.path.dirname(__file__), "..", "MLAAD-tiny")
+            if os.path.exists(mlaad_dir):
+                data_dir = mlaad_dir
+            else:
+                data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
+
+        self.data_dir = data_dir
+        self.files = []
+        self.labels = []
+        self.n_bins = n_bins
+        self.augment = augment
+        self.cqcc_cache_dir = cqcc_cache_dir
+        self.target_lang = target_lang
+
+        real_path = os.path.join(data_dir, "original")
+        if not os.path.exists(real_path):
+            real_path = os.path.join(data_dir, "real")
+
+        fake_path = os.path.join(data_dir, "fake")
+
+        for root, dirs, files in os.walk(real_path):
+            dirs.sort()
+            files.sort()
+            for f in files:
+                if f.endswith('.wav') or f.endswith('.flac'):
+                    if self.target_lang:
+                        rel_root = os.path.relpath(root, real_path).replace('\\', '/')
+                        if not rel_root.startswith(self.target_lang):
+                            continue
+                    self.files.append(os.path.join(root, f))
+                    self.labels.append(0)  # 0 = Real
+
+        for root, dirs, files in os.walk(fake_path):
+            dirs.sort()
+            files.sort()
+            for f in files:
+                if f.endswith('.wav') or f.endswith('.flac'):
+                    if self.target_lang:
+                        rel_root = os.path.relpath(root, fake_path).replace('\\', '/')
+                        if not rel_root.startswith(self.target_lang):
+                            continue
+                    self.files.append(os.path.join(root, f))
+                    self.labels.append(1)  # 1 = Fake
+
+        if self.cqcc_cache_dir is not None:
+            os.makedirs(self.cqcc_cache_dir, exist_ok=True)
+
+    def __len__(self):
+        return len(self.files)
+
+    def _cqcc_cache_path(self, audio_path):
+        rel_path = os.path.relpath(audio_path, start=self.data_dir)
+        cache_key = hashlib.md5(audio_path.encode("utf-8")).hexdigest()
+        rel_stem = os.path.splitext(rel_path)[0]
+        safe_name = rel_stem.replace(os.sep, "__")
+        return os.path.join(self.cqcc_cache_dir, f"{safe_name}_{cache_key}.pt")
+
+    def _load_or_compute_cqcc(self, audio_path, wav_np, is_augmented=False):
+        if self.cqcc_cache_dir is None or is_augmented:
+            return compute_cqcc(wav_np, n_bins=self.n_bins)
+
+        cache_path = self._cqcc_cache_path(audio_path)
+        if os.path.exists(cache_path):
+            return torch.load(cache_path, map_location="cpu")
+
+        cqcc = compute_cqcc(wav_np, n_bins=self.n_bins)
+        torch.save(cqcc, cache_path)
+        return cqcc
+
+    def precompute_cqcc_cache(self, force=False):
+        """Materialize CQCC features to disk for reuse in training (computed on the full clip, before crop/pad)."""
+        if self.cqcc_cache_dir is None:
+            raise ValueError("cqcc_cache_dir must be set to precompute CQCC features.")
+
+        # tqdm is optional: fall back to a plain loop if it is not installed.
+        try:
+            from tqdm.auto import tqdm
+            iterable_files = tqdm(self.files, desc="Precomputing CQCC Cache")
+        except ImportError:
+            iterable_files = self.files
+
+        total = len(self.files)
+        for idx, audio_path in enumerate(iterable_files):
+            cache_path = self._cqcc_cache_path(audio_path)
+            if not force and os.path.exists(cache_path):
+                continue
+
+            try:
+                wav_np, _ = librosa.load(audio_path, sr=16000, mono=True)
+                cqcc = compute_cqcc(wav_np, n_bins=self.n_bins)
+                torch.save(cqcc, cache_path)
+            except Exception as e:
+                print(f"Error precomputing CQCC for {audio_path}: {e}")
+
+
+            if (idx + 1) % 100 == 0 or idx + 1 == total:
+                print(f"Precomputed CQCC {idx + 1}/{total}")
+
+    def __getitem__(self, idx):
+        audio_path = self.files[idx]
+        wav_np, sr = librosa.load(audio_path, sr=16000, mono=True)
+
+        is_augmented = False
+        # Augmentation on raw audio (Data Augmentation for generalizability)
+        if self.augment and np.random.rand() < 0.3:
+            # Apply only ONE augmentation type per sample to avoid over-modification
+            aug_type = np.random.choice(['noise', 'speed', 'pitch'], p=[0.33, 0.33, 0.34])
+
+            if aug_type == 'noise':
+                # SNR-based noise addition (reverted to original robust method)
+                signal_power = np.mean(wav_np**2)
+                if signal_power > 1e-10:
+                    snr_db = np.random.uniform(10, 30)
+                    snr_linear = 10**(snr_db / 10)
+                    noise_power = signal_power / snr_linear
+                    noise = np.random.randn(len(wav_np)) * np.sqrt(noise_power)
+                    wav_np = wav_np + noise
+                    is_augmented = True
+            elif aug_type == 'speed':
+                # Mild speed perturbation
+                speed_factor = np.random.uniform(0.95, 1.05)
+                wav_np = librosa.effects.time_stretch(wav_np, rate=speed_factor)
+                is_augmented = True
+            elif aug_type == 'pitch':
+                # Subtle pitch shift
+                n_steps = np.random.uniform(-1, 1)
+                wav_np = librosa.effects.pitch_shift(wav_np, sr=sr, n_steps=n_steps)
+                is_augmented = True
+
+        # Crop or pad to exactly 64600 samples (AASIST standard)
+        target_len = 64600
+        if len(wav_np) > target_len:
+            # Center crop or random crop for augment instead of taking just the start.
+            if self.augment:
+                start = np.random.randint(0, len(wav_np) - target_len + 1)  # upper bound is exclusive
+            else:
+                start = (len(wav_np) - target_len) // 2
+            wav_np = wav_np[start:start+target_len]
+        elif len(wav_np) < target_len:
+            pad = target_len - len(wav_np)
+            wav_np = np.pad(wav_np, (0, pad), 'constant')
+
+        wav = torch.from_numpy(wav_np).unsqueeze(0).float()
+
+        cqcc = self._load_or_compute_cqcc(audio_path, wav_np, is_augmented=is_augmented)
+
+        return wav, cqcc, self.labels[idx]
+
+
+def collate_variable_length(batch):
+
+    wavs, cqccs, labels = zip(*batch)
+    labels = torch.tensor(labels)
+
+    # ---------- WAVE ----------
+    max_wav_len = max(w.shape[-1] for w in wavs)
+
+    wavs_padded = []
+    for w in wavs:
+        if w.shape[-1] < max_wav_len:
+            pad = max_wav_len - w.shape[-1]
+            w = torch.nn.functional.pad(w, (0, pad))
+        wavs_padded.append(w)
+
+    wavs = torch.stack(wavs_padded, dim=0)
+
+    # ---------- CQCC ----------
+    max_cqcc_len = max(c.shape[-1] for c in cqccs)
+
+    cqccs_padded = []
+    for c in cqccs:
+        if c.shape[-1] < max_cqcc_len:
+            pad = max_cqcc_len - c.shape[-1]
+            c = torch.nn.functional.pad(c, (0, pad))
+        cqccs_padded.append(c)
+
+    cqccs = torch.stack(cqccs_padded, dim=0)
+
+    return wavs, cqccs, labels
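Note on the noise augmentation in `__getitem__`: the branch scales Gaussian noise so that the result has a chosen signal-to-noise ratio, using noise_power = signal_power / 10^(snr_db / 10). The standard-library sketch below checks that relationship on a synthetic tone. The helper names `add_noise_at_snr` and `measured_snr_db` are hypothetical, and plain lists stand in for the NumPy arrays used in `dataset.py`.

```python
import math
import random


def add_noise_at_snr(samples, snr_db, rng):
    """Add white Gaussian noise so the result has roughly the requested SNR."""
    signal_power = sum(x * x for x in samples) / len(samples)
    if signal_power <= 1e-10:
        return list(samples)  # near-silent clip: leave untouched, as dataset.py does
    noise_power = signal_power / (10 ** (snr_db / 10))
    sigma = math.sqrt(noise_power)
    return [x + rng.gauss(0.0, sigma) for x in samples]


def measured_snr_db(clean, noisy):
    """Estimate the SNR in dB from a clean signal and its noisy version."""
    signal_power = sum(c * c for c in clean) / len(clean)
    noise_power = sum((n - c) ** 2 for c, n in zip(clean, noisy)) / len(clean)
    return 10 * math.log10(signal_power / noise_power)


rng = random.Random(0)
# One second of a 440 Hz tone at 16 kHz stands in for a speech clip.
clean = [math.sin(2 * math.pi * 440 * t / 16000) for t in range(16000)]
noisy = add_noise_at_snr(clean, snr_db=20.0, rng=rng)

if __name__ == "__main__":
    print(f"requested 20.0 dB, measured {measured_snr_db(clean, noisy):.2f} dB")
```

The measured value should land close to the requested 20 dB. `dataset.py` draws the target uniformly from 10 to 30 dB, so augmented clips range from clearly noisy to nearly clean.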
backend/models.py
ADDED
|
@@ -0,0 +1,1019 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from transformers import Wav2Vec2Model
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# ============================================================
|
| 8 |
+
# 1. Wav2Vec2 Detector (Self-supervised Transformer Baseline)
|
| 9 |
+
# ============================================================
|
| 10 |
+
class AttentivePooling(nn.Module):
|
| 11 |
+
def __init__(self, dim):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.attn = nn.Sequential(
|
| 14 |
+
nn.Linear(dim, dim),
|
| 15 |
+
nn.Tanh(),
|
| 16 |
+
nn.Linear(dim, 1)
|
| 17 |
+
)
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
w = torch.softmax(self.attn(x), dim=1)
|
| 20 |
+
return torch.sum(w * x, dim=1)
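As a rough illustration of what `AttentivePooling.forward` computes, the sketch below redoes the same arithmetic in plain Python: a softmax over the time axis followed by a weighted sum of frames. The per-frame scores are passed in directly here, which is an assumption made for the example; in the module above they come from the small `Linear -> Tanh -> Linear` network. The function name `attentive_pool` is hypothetical.

```python
import math


def attentive_pool(frames, scores):
    """Softmax the per-frame scores over time, then return the weighted sum of frames."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(frames[0])
    pooled = [sum(w * f[d] for w, f in zip(weights, frames)) for d in range(dim)]
    return weights, pooled


# Three frames of dimension 2; equal scores reduce attentive pooling to a plain mean.
frames = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights, pooled = attentive_pool(frames, scores=[0.0, 0.0, 0.0])
```

With unequal scores the weights shift toward the higher-scoring frames, which is the point of learning the scores rather than averaging.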
|
| 21 |
+
|
| 22 |
+
class Wav2Vec2SpoofDetector(nn.Module):
|
| 23 |
+
def __init__(self, num_classes=2, model_name="facebook/wav2vec2-base"):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.wav2vec = Wav2Vec2Model.from_pretrained(model_name)
|
| 26 |
+
|
| 27 |
+
# Freeze the pretrained Wav2Vec2 backbone; only the pooling and classifier heads are trained
|
| 28 |
+
for param in self.wav2vec.parameters():
|
| 29 |
+
param.requires_grad = False
|
| 30 |
+
|
| 31 |
+
hidden = self.wav2vec.config.hidden_size
|
| 32 |
+
self.pool = AttentivePooling(hidden)
|
| 33 |
+
self.classifier = nn.Sequential(
|
| 34 |
+
nn.LayerNorm(hidden),
|
| 35 |
+
nn.Dropout(0.2),
|
| 36 |
+
nn.Linear(hidden, num_classes)
|
| 37 |
+
)
|
| 38 |
+
def forward(self, x):
|
| 39 |
+
if x.dim() == 3:
|
| 40 |
+
x = x.squeeze(1)
|
| 41 |
+
out = self.wav2vec(x).last_hidden_state
|
| 42 |
+
pooled = self.pool(out)
|
| 43 |
+
return self.classifier(pooled)
|
| 44 |
+
|
| 45 |
+
# ============================================================
|
| 46 |
+
# 2. AASIST (SOTA Graph-based Baseline)
|
| 47 |
+
# ============================================================
|
| 48 |
+
|
| 49 |
+
import random
|
| 50 |
+
from typing import Union
|
| 51 |
+
import numpy as np
|
| 52 |
+
from torch import Tensor
|
| 53 |
+
|
| 54 |
+
# Original simplified GraphAttention/GraphBlock, kept because the custom model depends on them
|
| 55 |
+
class GraphAttention(nn.Module):
|
| 56 |
+
def __init__(self, in_dim, out_dim):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.fc = nn.Linear(in_dim, out_dim)
|
| 59 |
+
self.attn = nn.Linear(out_dim * 2, 1)
|
| 60 |
+
|
| 61 |
+
def forward(self, x):
|
| 62 |
+
h = self.fc(x)
|
| 63 |
+
# Instead of allocating O(N^2 * D) tensor arrays for pairwise combinations,
|
| 64 |
+
# we can decompose the linear attention matrix and use broadcasting!
|
| 65 |
+
# Memory consumption goes from ~10GB on N=400 to ~2MB.
|
| 66 |
+
W = self.attn.weight.squeeze()
|
| 67 |
+
D = h.shape[-1]
|
| 68 |
+
|
| 69 |
+
W_1 = W[:D]
|
| 70 |
+
W_2 = W[D:]
|
| 71 |
+
|
| 72 |
+
# Compute individual node scores: shape (B, N, 1)
|
| 73 |
+
score_i = torch.matmul(h, W_1).unsqueeze(-1)
|
| 74 |
+
score_j = torch.matmul(h, W_2).unsqueeze(-1)
|
| 75 |
+
|
| 76 |
+
# Broadcast (B, N, 1) + (B, 1, N) -> (B, N, N)
|
| 77 |
+
e = score_i + score_j.transpose(1, 2)
|
| 78 |
+
|
| 79 |
+
if self.attn.bias is not None:
|
| 80 |
+
e = e + self.attn.bias
|
| 81 |
+
|
| 82 |
+
alpha = F.softmax(e, dim=-1)
|
| 83 |
+
out = torch.matmul(alpha, h)
|
| 84 |
+
return out
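The broadcasting trick in `GraphAttention.forward` relies on a linear score over a concatenated pair splitting into two per-node scores: `w · [h_i ; h_j] = w1 · h_i + w2 · h_j`. The sketch below checks that identity numerically in plain Python. The sizes, the random seed, and the helper `dot` are assumptions chosen for the example, and the bias term is left out.

```python
import random


def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))


random.seed(0)
D, N = 4, 5  # feature dimension and number of nodes (illustrative sizes)
h = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(N)]
w = [random.uniform(-1, 1) for _ in range(2 * D)]
w1, w2 = w[:D], w[D:]

# Naive form: materialise the concatenation [h_i ; h_j] for every pair (O(N^2 * D) memory).
naive = [[dot(w, h[i] + h[j]) for j in range(N)] for i in range(N)]

# Decomposed form: one score per node, combined by broadcasting (O(N^2) memory).
score_i = [dot(w1, row) for row in h]
score_j = [dot(w2, row) for row in h]
fast = [[score_i[i] + score_j[j] for j in range(N)] for i in range(N)]

max_diff = max(abs(naive[i][j] - fast[i][j]) for i in range(N) for j in range(N))
```

The two score matrices agree up to floating-point rounding, so the softmax over them gives the same attention weights.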
|
| 85 |
+
|
| 86 |
+
class GraphBlock(nn.Module):
|
| 87 |
+
def __init__(self, dim):
|
| 88 |
+
super().__init__()
|
| 89 |
+
self.gat = GraphAttention(dim, dim)
|
| 90 |
+
self.norm = nn.LayerNorm(dim)
|
| 91 |
+
self.dropout = nn.Dropout(0.2)
|
| 92 |
+
|
| 93 |
+
def forward(self, x):
|
| 94 |
+
res = x
|
| 95 |
+
x = self.gat(x)
|
| 96 |
+
x = self.dropout(x)
|
| 97 |
+
x = self.norm(x + res)
|
| 98 |
+
return x
|
| 99 |
+
|
| 100 |
+
class GraphAttentionLayer(nn.Module):
|
| 101 |
+
def __init__(self, in_dim, out_dim, **kwargs):
|
| 102 |
+
super().__init__()
|
| 103 |
+
|
| 104 |
+
# attention map
|
| 105 |
+
self.att_proj = nn.Linear(in_dim, out_dim)
|
| 106 |
+
self.att_weight = self._init_new_params(out_dim, 1)
|
| 107 |
+
|
| 108 |
+
# project
|
| 109 |
+
self.proj_with_att = nn.Linear(in_dim, out_dim)
|
| 110 |
+
self.proj_without_att = nn.Linear(in_dim, out_dim)
|
| 111 |
+
|
| 112 |
+
# batch norm
|
| 113 |
+
self.bn = nn.BatchNorm1d(out_dim)
|
| 114 |
+
|
| 115 |
+
# dropout for inputs
|
| 116 |
+
self.input_drop = nn.Dropout(p=0.2)
|
| 117 |
+
|
| 118 |
+
# activate
|
| 119 |
+
self.act = nn.SELU(inplace=True)
|
| 120 |
+
|
| 121 |
+
# temperature
|
| 122 |
+
self.temp = 1.
|
| 123 |
+
if "temperature" in kwargs:
|
| 124 |
+
self.temp = kwargs["temperature"]
|
| 125 |
+
|
| 126 |
+
def forward(self, x):
|
| 127 |
+
'''
|
| 128 |
+
x :(#bs, #node, #dim)
|
| 129 |
+
'''
|
| 130 |
+
# apply input dropout
|
| 131 |
+
x = self.input_drop(x)
|
| 132 |
+
|
| 133 |
+
# derive attention map
|
| 134 |
+
att_map = self._derive_att_map(x)
|
| 135 |
+
|
| 136 |
+
# projection
|
| 137 |
+
x = self._project(x, att_map)
|
| 138 |
+
|
| 139 |
+
# apply batch norm
|
| 140 |
+
x = self._apply_BN(x)
|
| 141 |
+
x = self.act(x)
|
| 142 |
+
return x
|
| 143 |
+
|
| 144 |
+
def _pairwise_mul_nodes(self, x):
|
| 145 |
+
'''
|
| 146 |
+
Calculates pairwise multiplication of nodes.
|
| 147 |
+
- for attention map
|
| 148 |
+
x :(#bs, #node, #dim)
|
| 149 |
+
out_shape :(#bs, #node, #node, #dim)
|
| 150 |
+
'''
|
| 151 |
+
|
| 152 |
+
nb_nodes = x.size(1)
|
| 153 |
+
x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
|
| 154 |
+
x_mirror = x.transpose(1, 2)
|
| 155 |
+
|
| 156 |
+
return x * x_mirror
|
| 157 |
+
|
| 158 |
+
def _derive_att_map(self, x):
|
| 159 |
+
'''
|
| 160 |
+
x :(#bs, #node, #dim)
|
| 161 |
+
out_shape :(#bs, #node, #node, 1)
|
| 162 |
+
'''
|
| 163 |
+
att_map = self._pairwise_mul_nodes(x)
|
| 164 |
+
# size: (#bs, #node, #node, #dim_out)
|
| 165 |
+
att_map = torch.tanh(self.att_proj(att_map))
|
| 166 |
+
# size: (#bs, #node, #node, 1)
|
| 167 |
+
att_map = torch.matmul(att_map, self.att_weight)
|
| 168 |
+
|
| 169 |
+
# apply temperature
|
| 170 |
+
att_map = att_map / self.temp
|
| 171 |
+
|
| 172 |
+
att_map = F.softmax(att_map, dim=-2)
|
| 173 |
+
|
| 174 |
+
return att_map
|
| 175 |
+
|
| 176 |
+
def _project(self, x, att_map):
|
| 177 |
+
x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
|
| 178 |
+
x2 = self.proj_without_att(x)
|
| 179 |
+
|
| 180 |
+
return x1 + x2
|
| 181 |
+
|
| 182 |
+
def _apply_BN(self, x):
|
| 183 |
+
org_size = x.size()
|
| 184 |
+
x = x.view(-1, org_size[-1])
|
| 185 |
+
x = self.bn(x)
|
| 186 |
+
x = x.view(org_size)
|
| 187 |
+
|
| 188 |
+
return x
|
| 189 |
+
|
| 190 |
+
def _init_new_params(self, *size):
|
| 191 |
+
out = nn.Parameter(torch.FloatTensor(*size))
|
| 192 |
+
nn.init.xavier_normal_(out)
|
| 193 |
+
return out
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class HtrgGraphAttentionLayer(nn.Module):
|
| 197 |
+
def __init__(self, in_dim, out_dim, **kwargs):
|
| 198 |
+
super().__init__()
|
| 199 |
+
|
| 200 |
+
self.proj_type1 = nn.Linear(in_dim, in_dim)
|
| 201 |
+
self.proj_type2 = nn.Linear(in_dim, in_dim)
|
| 202 |
+
|
| 203 |
+
# attention map
|
| 204 |
+
self.att_proj = nn.Linear(in_dim, out_dim)
|
| 205 |
+
self.att_projM = nn.Linear(in_dim, out_dim)
|
| 206 |
+
|
| 207 |
+
self.att_weight11 = self._init_new_params(out_dim, 1)
|
| 208 |
+
self.att_weight22 = self._init_new_params(out_dim, 1)
|
| 209 |
+
self.att_weight12 = self._init_new_params(out_dim, 1)
|
| 210 |
+
self.att_weightM = self._init_new_params(out_dim, 1)
|
| 211 |
+
|
| 212 |
+
# project
|
| 213 |
+
self.proj_with_att = nn.Linear(in_dim, out_dim)
|
| 214 |
+
self.proj_without_att = nn.Linear(in_dim, out_dim)
|
| 215 |
+
|
| 216 |
+
self.proj_with_attM = nn.Linear(in_dim, out_dim)
|
| 217 |
+
self.proj_without_attM = nn.Linear(in_dim, out_dim)
|
| 218 |
+
|
| 219 |
+
# batch norm
|
| 220 |
+
self.bn = nn.BatchNorm1d(out_dim)
|
| 221 |
+
|
| 222 |
+
# dropout for inputs
|
| 223 |
+
self.input_drop = nn.Dropout(p=0.2)
|
| 224 |
+
|
| 225 |
+
# activate
|
| 226 |
+
self.act = nn.SELU(inplace=True)
|
| 227 |
+
|
| 228 |
+
# temperature
|
| 229 |
+
self.temp = 1.
|
| 230 |
+
if "temperature" in kwargs:
|
| 231 |
+
self.temp = kwargs["temperature"]
|
| 232 |
+
|
| 233 |
+
def forward(self, x1, x2, master=None):
|
| 234 |
+
'''
|
| 235 |
+
x1 :(#bs, #node, #dim)
|
| 236 |
+
x2 :(#bs, #node, #dim)
|
| 237 |
+
'''
|
| 238 |
+
num_type1 = x1.size(1)
|
| 239 |
+
num_type2 = x2.size(1)
|
| 240 |
+
|
| 241 |
+
x1 = self.proj_type1(x1)
|
| 242 |
+
x2 = self.proj_type2(x2)
|
| 243 |
+
|
| 244 |
+
x = torch.cat([x1, x2], dim=1)
|
| 245 |
+
|
| 246 |
+
if master is None:
|
| 247 |
+
master = torch.mean(x, dim=1, keepdim=True)
|
| 248 |
+
|
| 249 |
+
# apply input dropout
|
| 250 |
+
x = self.input_drop(x)
|
| 251 |
+
|
| 252 |
+
# derive attention map
|
| 253 |
+
att_map = self._derive_att_map(x, num_type1, num_type2)
|
| 254 |
+
|
| 255 |
+
# directional edge for master node
|
| 256 |
+
master = self._update_master(x, master)
|
| 257 |
+
|
| 258 |
+
# projection
|
| 259 |
+
x = self._project(x, att_map)
|
| 260 |
+
|
| 261 |
+
# apply batch norm
|
| 262 |
+
x = self._apply_BN(x)
|
| 263 |
+
x = self.act(x)
|
| 264 |
+
|
| 265 |
+
x1 = x.narrow(1, 0, num_type1)
|
| 266 |
+
x2 = x.narrow(1, num_type1, num_type2)
|
| 267 |
+
|
| 268 |
+
return x1, x2, master
|
| 269 |
+
|
| 270 |
+
def _update_master(self, x, master):
|
| 271 |
+
|
| 272 |
+
att_map = self._derive_att_map_master(x, master)
|
| 273 |
+
master = self._project_master(x, master, att_map)
|
| 274 |
+
|
| 275 |
+
return master
|
| 276 |
+
|
| 277 |
+
def _pairwise_mul_nodes(self, x):
|
| 278 |
+
'''
|
| 279 |
+
Calculates pairwise multiplication of nodes.
|
| 280 |
+
- for attention map
|
| 281 |
+
x :(#bs, #node, #dim)
|
| 282 |
+
out_shape :(#bs, #node, #node, #dim)
|
| 283 |
+
'''
|
| 284 |
+
|
| 285 |
+
nb_nodes = x.size(1)
|
| 286 |
+
x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
|
| 287 |
+
x_mirror = x.transpose(1, 2)
|
| 288 |
+
|
| 289 |
+
return x * x_mirror
|
| 290 |
+
|
| 291 |
+
def _derive_att_map_master(self, x, master):
|
| 292 |
+
'''
|
| 293 |
+
x :(#bs, #node, #dim)
|
| 294 |
+
out_shape :(#bs, #node, 1)
|
| 295 |
+
'''
|
| 296 |
+
att_map = x * master
|
| 297 |
+
att_map = torch.tanh(self.att_projM(att_map))
|
| 298 |
+
|
| 299 |
+
att_map = torch.matmul(att_map, self.att_weightM)
|
| 300 |
+
|
| 301 |
+
# apply temperature
|
| 302 |
+
att_map = att_map / self.temp
|
| 303 |
+
|
| 304 |
+
att_map = F.softmax(att_map, dim=-2)
|
| 305 |
+
|
| 306 |
+
return att_map
|
| 307 |
+
|
| 308 |
+
def _derive_att_map(self, x, num_type1, num_type2):
|
| 309 |
+
'''
|
| 310 |
+
x :(#bs, #node, #dim)
|
| 311 |
+
out_shape :(#bs, #node, #node, 1)
|
| 312 |
+
'''
|
| 313 |
+
att_map = self._pairwise_mul_nodes(x)
|
| 314 |
+
# size: (#bs, #node, #node, #dim_out)
|
| 315 |
+
att_map = torch.tanh(self.att_proj(att_map))
|
| 316 |
+
# size: (#bs, #node, #node, 1)
|
| 317 |
+
|
| 318 |
+
att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
|
| 319 |
+
|
| 320 |
+
att_board[:, :num_type1, :num_type1, :] = torch.matmul(
|
| 321 |
+
att_map[:, :num_type1, :num_type1, :], self.att_weight11)
|
| 322 |
+
att_board[:, num_type1:, num_type1:, :] = torch.matmul(
|
| 323 |
+
att_map[:, num_type1:, num_type1:, :], self.att_weight22)
|
| 324 |
+
att_board[:, :num_type1, num_type1:, :] = torch.matmul(
|
| 325 |
+
att_map[:, :num_type1, num_type1:, :], self.att_weight12)
|
| 326 |
+
att_board[:, num_type1:, :num_type1, :] = torch.matmul(
|
| 327 |
+
att_map[:, num_type1:, :num_type1, :], self.att_weight12)
|
| 328 |
+
|
| 329 |
+
att_map = att_board
|
| 330 |
+
|
| 331 |
+
# apply temperature
|
| 332 |
+
att_map = att_map / self.temp
|
| 333 |
+
|
| 334 |
+
att_map = F.softmax(att_map, dim=-2)
|
| 335 |
+
|
| 336 |
+
return att_map
|
| 337 |
+
|
| 338 |
+
def _project(self, x, att_map):
|
| 339 |
+
x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
|
| 340 |
+
x2 = self.proj_without_att(x)
|
| 341 |
+
|
| 342 |
+
return x1 + x2
|
| 343 |
+
|
| 344 |
+
def _project_master(self, x, master, att_map):
|
| 345 |
+
|
| 346 |
+
x1 = self.proj_with_attM(torch.matmul(
|
| 347 |
+
att_map.squeeze(-1).unsqueeze(1), x))
|
| 348 |
+
x2 = self.proj_without_attM(master)
|
| 349 |
+
|
| 350 |
+
return x1 + x2
|
| 351 |
+
|
| 352 |
+
def _apply_BN(self, x):
|
| 353 |
+
org_size = x.size()
|
| 354 |
+
x = x.view(-1, org_size[-1])
|
| 355 |
+
x = self.bn(x)
|
| 356 |
+
x = x.view(org_size)
|
| 357 |
+
|
| 358 |
+
return x
|
| 359 |
+
|
| 360 |
+
def _init_new_params(self, *size):
|
| 361 |
+
out = nn.Parameter(torch.FloatTensor(*size))
|
| 362 |
+
nn.init.xavier_normal_(out)
|
| 363 |
+
return out
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class GraphPool(nn.Module):
|
| 367 |
+
def __init__(self, k: float, in_dim: int, p: Union[float, int]):
|
| 368 |
+
super().__init__()
|
| 369 |
+
self.k = k
|
| 370 |
+
self.sigmoid = nn.Sigmoid()
|
| 371 |
+
self.proj = nn.Linear(in_dim, 1)
|
| 372 |
+
self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
|
| 373 |
+
self.in_dim = in_dim
|
| 374 |
+
|
| 375 |
+
def forward(self, h):
|
| 376 |
+
Z = self.drop(h)
|
| 377 |
+
weights = self.proj(Z)
|
| 378 |
+
scores = self.sigmoid(weights)
|
| 379 |
+
new_h = self.top_k_graph(scores, h, self.k)
|
| 380 |
+
|
| 381 |
+
return new_h
|
| 382 |
+
|
| 383 |
+
def top_k_graph(self, scores, h, k):
|
| 384 |
+
_, n_nodes, n_feat = h.size()
|
| 385 |
+
n_nodes = max(int(n_nodes * k), 1)
|
| 386 |
+
_, idx = torch.topk(scores, n_nodes, dim=1)
|
| 387 |
+
idx = idx.expand(-1, -1, n_feat)
|
| 388 |
+
|
| 389 |
+
h = h * scores
|
| 390 |
+
h = torch.gather(h, 1, idx)
|
| 391 |
+
|
| 392 |
+
return h
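The selection logic of `GraphPool.top_k_graph` can be sketched for a single graph in plain Python: scale each node by its score, then keep the highest-scoring fraction of nodes. The function name `top_k_pool` and the sample values are hypothetical. The scores are given directly here, whereas the module above derives them from a sigmoid over a linear projection.

```python
def top_k_pool(nodes, scores, ratio):
    """Keep the top `ratio` fraction of nodes by score, each scaled by its own score."""
    keep = max(int(len(nodes) * ratio), 1)  # always keep at least one node
    order = sorted(range(len(nodes)), key=lambda i: scores[i], reverse=True)[:keep]
    return [[scores[i] * value for value in nodes[i]] for i in order]


nodes = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
scores = [0.1, 0.9, 0.5, 0.2]
pooled = top_k_pool(nodes, scores, ratio=0.5)  # keeps nodes 1 and 2, in score order
```

Multiplying by the score before gathering lets gradients reach the scoring projection, since the top-k index selection itself is not differentiable.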
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class CONV(nn.Module):
|
| 396 |
+
@staticmethod
|
| 397 |
+
def to_mel(hz):
|
| 398 |
+
return 2595 * np.log10(1 + hz / 700)
|
| 399 |
+
|
| 400 |
+
@staticmethod
|
| 401 |
+
def to_hz(mel):
|
| 402 |
+
return 700 * (10**(mel / 2595) - 1)
|
| 403 |
+
|
| 404 |
+
def __init__(self,
|
| 405 |
+
out_channels,
|
| 406 |
+
kernel_size,
|
| 407 |
+
sample_rate=16000,
|
| 408 |
+
in_channels=1,
|
| 409 |
+
stride=1,
|
| 410 |
+
padding=0,
|
| 411 |
+
dilation=1,
|
| 412 |
+
bias=False,
|
| 413 |
+
groups=1,
|
| 414 |
+
mask=False):
|
| 415 |
+
super().__init__()
|
| 416 |
+
if in_channels != 1:
|
| 417 |
+
msg = "SincConv only supports one input channel (here, in_channels = {%i})" % (in_channels)
|
| 418 |
+
raise ValueError(msg)
|
| 419 |
+
self.out_channels = out_channels
|
| 420 |
+
self.kernel_size = kernel_size
|
| 421 |
+
self.sample_rate = sample_rate
|
| 422 |
+
|
| 423 |
+
# Force the kernel size to be odd so that the filters are perfectly symmetric
|
| 424 |
+
if kernel_size % 2 == 0:
|
| 425 |
+
self.kernel_size = self.kernel_size + 1
|
| 426 |
+
self.stride = stride
|
| 427 |
+
self.padding = padding
|
| 428 |
+
self.dilation = dilation
|
| 429 |
+
self.mask = mask
|
| 430 |
+
if bias:
|
| 431 |
+
raise ValueError('SincConv does not support bias.')
|
| 432 |
+
if groups > 1:
|
| 433 |
+
raise ValueError('SincConv does not support groups.')
|
| 434 |
+
|
| 435 |
+
NFFT = 512
|
| 436 |
+
f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
|
| 437 |
+
fmel = self.to_mel(f)
|
| 438 |
+
fmelmax = np.max(fmel)
|
| 439 |
+
fmelmin = np.min(fmel)
|
| 440 |
+
filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
|
| 441 |
+
filbandwidthsf = self.to_hz(filbandwidthsmel)
|
| 442 |
+
|
| 443 |
+
self.mel = filbandwidthsf
|
| 444 |
+
self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
|
| 445 |
+
(self.kernel_size - 1) / 2 + 1)
|
| 446 |
+
self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
|
| 447 |
+
for i in range(len(self.mel) - 1):
|
| 448 |
+
fmin = self.mel[i]
|
| 449 |
+
fmax = self.mel[i + 1]
|
| 450 |
+
hHigh = (2*fmax/self.sample_rate) * \
|
| 451 |
+
np.sinc(2*fmax*self.hsupp/self.sample_rate)
|
| 452 |
+
hLow = (2*fmin/self.sample_rate) * \
|
| 453 |
+
np.sinc(2*fmin*self.hsupp/self.sample_rate)
|
| 454 |
+
hideal = hHigh - hLow
|
| 455 |
+
|
| 456 |
+
self.band_pass[i, :] = Tensor(np.hamming(
|
| 457 |
+
self.kernel_size)) * Tensor(hideal)
|
| 458 |
+
|
| 459 |
+
def forward(self, x, mask=False):
|
| 460 |
+
band_pass_filter = self.band_pass.clone().to(x.device)
|
| 461 |
+
if mask:
|
| 462 |
+
A = np.random.uniform(0, 20)
|
| 463 |
+
A = int(A)
|
| 464 |
+
A0 = random.randint(0, band_pass_filter.shape[0] - A)
|
| 465 |
+
band_pass_filter[A0:A0 + A, :] = 0
|
| 466 |
+
else:
|
| 467 |
+
band_pass_filter = band_pass_filter
|
| 468 |
+
|
| 469 |
+
self.filters = (band_pass_filter).view(self.out_channels, 1,
|
| 470 |
+
self.kernel_size)
|
| 471 |
+
|
| 472 |
+
return F.conv1d(x,
|
| 473 |
+
self.filters,
|
| 474 |
+
stride=self.stride,
|
| 475 |
+
padding=self.padding,
|
| 476 |
+
dilation=self.dilation,
|
| 477 |
+
bias=None,
|
| 478 |
+
groups=1)
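The band edges used by `CONV.__init__` are evenly spaced on the mel scale between 0 Hz and the Nyquist frequency, then mapped back to Hz. The sketch below reproduces that spacing with the standard library only. The helper `mel_band_edges` is a hypothetical name, and the values 16000 and 70 simply mirror the defaults and `filts[0]` seen in this file.

```python
import math


def to_mel(hz):
    """Hz to mel, same formula as CONV.to_mel."""
    return 2595 * math.log10(1 + hz / 700)


def to_hz(mel):
    """Mel to Hz, inverse of to_mel."""
    return 700 * (10 ** (mel / 2595) - 1)


def mel_band_edges(sample_rate, num_filters):
    """Return num_filters + 1 band edges in Hz, evenly spaced on the mel scale."""
    mel_max = to_mel(sample_rate / 2)
    step = mel_max / num_filters
    return [to_hz(i * step) for i in range(num_filters + 1)]


edges = mel_band_edges(sample_rate=16000, num_filters=70)
first_width = edges[1] - edges[0]
last_width = edges[-1] - edges[-2]
```

Because the spacing is uniform in mel, the bands are narrow at low frequencies and widen toward the Nyquist frequency, which is why `last_width` is much larger than `first_width`.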
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class Residual_block(nn.Module):
|
| 482 |
+
def __init__(self, nb_filts, first=False):
|
| 483 |
+
super().__init__()
|
| 484 |
+
self.first = first
|
| 485 |
+
|
| 486 |
+
if not self.first:
|
| 487 |
+
self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
|
| 488 |
+
self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
|
| 489 |
+
out_channels=nb_filts[1],
|
| 490 |
+
kernel_size=(2, 3),
|
| 491 |
+
padding=(1, 1),
|
| 492 |
+
stride=1)
|
| 493 |
+
self.selu = nn.SELU(inplace=True)
|
| 494 |
+
|
| 495 |
+
self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
|
| 496 |
+
self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
|
| 497 |
+
out_channels=nb_filts[1],
|
| 498 |
+
kernel_size=(2, 3),
|
| 499 |
+
padding=(0, 1),
|
| 500 |
+
stride=1)
|
| 501 |
+
|
| 502 |
+
if nb_filts[0] != nb_filts[1]:
|
| 503 |
+
self.downsample = True
|
| 504 |
+
self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
|
| 505 |
+
out_channels=nb_filts[1],
|
| 506 |
+
padding=(0, 1),
|
| 507 |
+
kernel_size=(1, 3),
|
| 508 |
+
stride=1)
|
| 509 |
+
|
| 510 |
+
else:
|
| 511 |
+
self.downsample = False
|
| 512 |
+
self.mp = nn.MaxPool2d((1, 3))
|
| 513 |
+
|
| 514 |
+
def forward(self, x):
|
| 515 |
+
identity = x
|
| 516 |
+
if not self.first:
|
| 517 |
+
out = self.bn1(x)
|
| 518 |
+
out = self.selu(out)
|
| 519 |
+
else:
|
| 520 |
+
out = x
|
| 521 |
+
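# NOTE: conv1 consumes x rather than out, so the bn1/SELU result above is discarded.
# Left unchanged because altering it would change the behaviour of already-trained checkpoints.
out = self.conv1(x)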
|
| 522 |
+
|
| 523 |
+
out = self.bn2(out)
|
| 524 |
+
out = self.selu(out)
|
| 525 |
+
out = self.conv2(out)
|
| 526 |
+
if self.downsample:
|
| 527 |
+
identity = self.conv_downsample(identity)
|
| 528 |
+
|
| 529 |
+
out += identity
|
| 530 |
+
out = self.mp(out)
|
| 531 |
+
return out
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
class AASISTModel(nn.Module):
|
| 535 |
+
def __init__(self, d_args):
|
| 536 |
+
super().__init__()
|
| 537 |
+
|
| 538 |
+
self.d_args = d_args
|
| 539 |
+
filts = d_args["filts"]
|
| 540 |
+
gat_dims = d_args["gat_dims"]
|
| 541 |
+
pool_ratios = d_args["pool_ratios"]
|
| 542 |
+
temperatures = d_args["temperatures"]
|
| 543 |
+
|
| 544 |
+
self.conv_time = CONV(out_channels=filts[0],
|
| 545 |
+
kernel_size=d_args["first_conv"],
|
| 546 |
+
in_channels=1)
|
| 547 |
+
self.first_bn = nn.BatchNorm2d(num_features=1)
|
| 548 |
+
|
| 549 |
+
self.drop = nn.Dropout(0.5, inplace=True)
|
| 550 |
+
self.drop_way = nn.Dropout(0.2, inplace=True)
|
| 551 |
+
self.selu = nn.SELU(inplace=True)
|
| 552 |
+
|
| 553 |
+
self.encoder = nn.Sequential(
|
| 554 |
+
nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
|
| 555 |
+
nn.Sequential(Residual_block(nb_filts=filts[2])),
|
| 556 |
+
nn.Sequential(Residual_block(nb_filts=filts[3])),
|
| 557 |
+
nn.Sequential(Residual_block(nb_filts=filts[4])),
|
| 558 |
+
nn.Sequential(Residual_block(nb_filts=filts[4])),
|
| 559 |
+
nn.Sequential(Residual_block(nb_filts=filts[4])))
|
| 560 |
+
|
| 561 |
+
self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
|
| 562 |
+
self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
|
| 563 |
+
self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
|
| 564 |
+
|
| 565 |
+
self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
|
| 566 |
+
gat_dims[0],
|
| 567 |
+
temperature=temperatures[0])
|
| 568 |
+
self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
|
| 569 |
+
gat_dims[0],
|
| 570 |
+
temperature=temperatures[1])
|
| 571 |
+
|
| 572 |
+
self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
|
| 573 |
+
gat_dims[0], gat_dims[1], temperature=temperatures[2])
|
| 574 |
+
self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
|
| 575 |
+
gat_dims[1], gat_dims[1], temperature=temperatures[2])
|
| 576 |
+
|
| 577 |
+
self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
|
| 578 |
+
gat_dims[0], gat_dims[1], temperature=temperatures[2])
|
| 579 |
+
|
| 580 |
+
self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
|
| 581 |
+
gat_dims[1], gat_dims[1], temperature=temperatures[2])
|
| 582 |
+
|
| 583 |
+
self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
|
| 584 |
+
self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
|
| 585 |
+
self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
|
| 586 |
+
self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
|
| 587 |
+
|
| 588 |
+
self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
|
| 589 |
+
self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
|
| 590 |
+
|
| 591 |
+
self.out_layer = nn.Linear(5 * gat_dims[1], 2)
|
| 592 |
+
|
| 593 |
+
def forward(self, x, Freq_aug=False):
|
| 594 |
+
|
| 595 |
+
x = x.unsqueeze(1)
|
| 596 |
+
x = self.conv_time(x, mask=Freq_aug)
|
| 597 |
+
x = x.unsqueeze(dim=1)
|
| 598 |
+
x = F.max_pool2d(torch.abs(x), (3, 3))
|
| 599 |
+
x = self.first_bn(x)
|
| 600 |
+
x = self.selu(x)
|
| 601 |
+
|
| 602 |
+
e = self.encoder(x)
|
| 603 |
+
|
| 604 |
+
e_S, _ = torch.max(torch.abs(e), dim=3)
|
| 605 |
+
e_S = e_S.transpose(1, 2) + self.pos_S
|
| 606 |
+
|
| 607 |
+
gat_S = self.GAT_layer_S(e_S)
|
| 608 |
+
out_S = self.pool_S(gat_S)
|
| 609 |
+
|
| 610 |
+
e_T, _ = torch.max(torch.abs(e), dim=2)
|
| 611 |
+
e_T = e_T.transpose(1, 2)
|
| 612 |
+
|
| 613 |
+
gat_T = self.GAT_layer_T(e_T)
|
| 614 |
+
out_T = self.pool_T(gat_T)
|
| 615 |
+
|
| 616 |
+
master1 = self.master1.expand(x.size(0), -1, -1)
|
| 617 |
+
master2 = self.master2.expand(x.size(0), -1, -1)
|
| 618 |
+
|
| 619 |
+
out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
|
| 620 |
+
out_T, out_S, master=self.master1)
|
| 621 |
+
|
| 622 |
+
out_S1 = self.pool_hS1(out_S1)
|
| 623 |
+
out_T1 = self.pool_hT1(out_T1)
|
| 624 |
+
|
| 625 |
+
out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
|
| 626 |
+
out_T1, out_S1, master=master1)
|
| 627 |
+
out_T1 = out_T1 + out_T_aug
|
| 628 |
+
out_S1 = out_S1 + out_S_aug
|
| 629 |
+
master1 = master1 + master_aug
|
| 630 |
+
|
| 631 |
+
out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
|
| 632 |
+
out_T, out_S, master=self.master2)
|
| 633 |
+
out_S2 = self.pool_hS2(out_S2)
|
| 634 |
+
out_T2 = self.pool_hT2(out_T2)
|
| 635 |
+
|
| 636 |
+
out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
|
| 637 |
+
out_T2, out_S2, master=master2)
|
| 638 |
+
out_T2 = out_T2 + out_T_aug
|
| 639 |
+
out_S2 = out_S2 + out_S_aug
|
| 640 |
+
master2 = master2 + master_aug
|
| 641 |
+
|
| 642 |
+
out_T1 = self.drop_way(out_T1)
|
| 643 |
+
out_T2 = self.drop_way(out_T2)
|
| 644 |
+
out_S1 = self.drop_way(out_S1)
|
| 645 |
+
out_S2 = self.drop_way(out_S2)
|
| 646 |
+
master1 = self.drop_way(master1)
|
| 647 |
+
master2 = self.drop_way(master2)
|
| 648 |
+
|
| 649 |
+
out_T = torch.max(out_T1, out_T2)
|
| 650 |
+
out_S = torch.max(out_S1, out_S2)
|
| 651 |
+
master = torch.max(master1, master2)
|
| 652 |
+
|
| 653 |
+
T_max, _ = torch.max(torch.abs(out_T), dim=1)
|
| 654 |
+
T_avg = torch.mean(out_T, dim=1)
|
| 655 |
+
|
| 656 |
+
S_max, _ = torch.max(torch.abs(out_S), dim=1)
|
| 657 |
+
S_avg = torch.mean(out_S, dim=1)
|
| 658 |
+
|
| 659 |
+
last_hidden = torch.cat(
|
| 660 |
+
[T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
|
| 661 |
+
|
| 662 |
+
last_hidden = self.drop(last_hidden)
|
| 663 |
+
output = self.out_layer(last_hidden)
|
| 664 |
+
|
| 665 |
+
return last_hidden, output
|
| 666 |
+
|
| 667 |
+
class AASISTDetector(nn.Module):
|
| 668 |
+
def __init__(self, num_classes=2):
|
| 669 |
+
super().__init__()
|
| 670 |
+
d_args = {
|
| 671 |
+
"nb_samp": 64600,
|
| 672 |
+
"first_conv": 128,
|
| 673 |
+
"in_channels": 1,
|
| 674 |
+
"filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
|
| 675 |
+
"gat_dims": [64, 32],
|
| 676 |
+
"pool_ratios": [0.5, 0.7, 0.5, 0.5],
|
| 677 |
+
"temperatures": [2.0, 2.0, 100.0]
|
| 678 |
+
}
|
| 679 |
+
self.model = AASISTModel(d_args)
|
| 680 |
+
|
| 681 |
+
# Override out_layer if not strictly 2 classes.
|
| 682 |
+
if num_classes != 2:
|
| 683 |
+
self.model.out_layer = nn.Linear(5 * d_args["gat_dims"][1], num_classes)
|
| 684 |
+
|
| 685 |
+
def forward(self, x):
|
| 686 |
+
# x is (B, 1, T) or (B, T)
|
| 687 |
+
if x.dim() == 3:
|
| 688 |
+
x = x.squeeze(1) # Convert to (B, T)
|
| 689 |
+
_, out = self.model(x)
|
| 690 |
+
return out
|
| 691 |
+
|
| 692 |
+
# ============================================================
|
| 693 |
+
# 3. CQCC Baseline Detector (Acoustic Feature Baseline)
|
| 694 |
+
# ============================================================
|
| 695 |
+
|
| 696 |
+
class CQCCBaselineDetector(nn.Module):
|
| 697 |
+
def __init__(self, num_classes=2):
|
| 698 |
+
super().__init__()
|
| 699 |
+
# Input shape expected: (B, 1, 20, T)
|
| 700 |
+
self.features = nn.Sequential(
|
| 701 |
+
nn.Conv2d(1, 16, 3, padding=1),
|
| 702 |
+
nn.BatchNorm2d(16),
|
| 703 |
+
nn.ReLU(),
|
| 704 |
+
nn.MaxPool2d(2),
|
| 705 |
+
nn.Conv2d(16, 32, 3, padding=1),
|
| 706 |
+
nn.BatchNorm2d(32),
|
| 707 |
+
nn.ReLU(),
|
| 708 |
+
nn.MaxPool2d(2),
|
| 709 |
+
nn.Conv2d(32, 64, 3, padding=1),
|
| 710 |
+
nn.BatchNorm2d(64),
|
| 711 |
+
nn.ReLU(),
|
| 712 |
+
nn.AdaptiveAvgPool2d(1)
|
| 713 |
+
)
|
| 714 |
+
self.classifier = nn.Sequential(
|
| 715 |
+
nn.Dropout(0.3),
|
| 716 |
+
nn.Linear(64, num_classes)
|
| 717 |
+
)
|
| 718 |
+
|
| 719 |
+
def forward(self, x):
|
| 720 |
+
x = self.features(x)
|
| 721 |
+
x = x.flatten(1)
|
| 722 |
+
return self.classifier(x)
|
| 723 |
+
|
| 724 |
+
# ============================================================
|
| 725 |
+
# 4. Custom Fusion Model: Wav2Vec2 + CQCC with Cross-Attention + Graph
|
| 726 |
+
# ============================================================
|
| 727 |
+
|
| 728 |
+
class PositionalEncoding(nn.Module):
|
| 729 |
+
def __init__(self, dim, max_len=6000):
|
| 730 |
+
super().__init__()
|
| 731 |
+
self.pos_embed = nn.Parameter(torch.randn(1, max_len, dim))
|
| 732 |
+
|
| 733 |
+
def forward(self, x):
|
| 734 |
+
return x + self.pos_embed[:, :x.size(1)]
|
| 735 |
+
|
| 736 |
+
class BidirectionalCrossAttention(nn.Module):
|
| 737 |
+
def __init__(self, dim, num_heads=4):
|
| 738 |
+
super().__init__()
|
| 739 |
+
self.attn1 = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=0.2)
|
| 740 |
+
self.attn2 = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=0.2)
|
| 741 |
+
self.norm_q = nn.LayerNorm(dim)
|
| 742 |
+
self.norm_kv = nn.LayerNorm(dim)
|
| 743 |
+
|
| 744 |
+
def forward(self, x1, x2):
|
| 745 |
+
# x1 attends to x2
|
| 746 |
+
q1 = self.norm_q(x1)
|
| 747 |
+
k2 = self.norm_kv(x2)
|
| 748 |
+
v2 = k2
|
| 749 |
+
out1, _ = self.attn1(q1, k2, v2)
|
| 750 |
+
|
| 751 |
+
# x2 attends to x1
|
| 752 |
+
q2 = self.norm_q(x2)
|
| 753 |
+
k1 = self.norm_kv(x1)
|
| 754 |
+
v1 = k1
|
| 755 |
+
out2, _ = self.attn2(q2, k1, v1)
|
| 756 |
+
return out1, out2
|
| 757 |
+
|
| 758 |
+
def align_sequences(x, target_len):
|
| 759 |
+
"""Linearly interpolate x of shape (B, T, D) along the time axis to (B, target_len, D)."""
|
| 760 |
+
x = x.transpose(1, 2)
|
| 761 |
+
x = F.interpolate(x, size=target_len, mode='linear', align_corners=False)
|
| 762 |
+
return x.transpose(1, 2)
|
| 763 |
+
|
| 764 |
+
class ImprovedWav2Vec2CQCCDetector(nn.Module):
|
| 765 |
+
def __init__(self, num_classes=2):
|
| 766 |
+
super().__init__()
|
| 767 |
+
|
| 768 |
+
# Wav2Vec2
|
| 769 |
+
self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
|
| 770 |
+
|
| 771 |
+
# Freeze the Wav2Vec2 layer so it acts purely as a feature extractor
|
| 772 |
+
for param in self.wav2vec.parameters():
|
| 773 |
+
param.requires_grad = False
|
| 774 |
+
|
| 775 |
+
dim = self.wav2vec.config.hidden_size
|
| 776 |
+
|
| 777 |
+
# CQCC encoder
|
| 778 |
+
self.cqcc_conv = nn.Sequential(
|
| 779 |
+
nn.Conv1d(20, 128, kernel_size=3, padding=1),
|
| 780 |
+
nn.BatchNorm1d(128),
|
| 781 |
+
nn.GELU(),
|
| 782 |
+
nn.Dropout(0.2),
|
| 783 |
+
nn.Conv1d(128, dim, kernel_size=3, padding=1),
|
| 784 |
+
nn.BatchNorm1d(dim),
|
| 785 |
+
nn.GELU()
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
# Positional Encoding
|
| 789 |
+
self.pos_enc = PositionalEncoding(dim)
|
| 790 |
+
|
| 791 |
+
# Bidirectional Cross Attention
|
| 792 |
+
self.cross_attn = BidirectionalCrossAttention(dim)
|
| 793 |
+
|
| 794 |
+
# True Graph Transformer Backend (using GAT blocks from AASIST)
|
| 795 |
+
self.graph_layers = nn.ModuleList([
|
| 796 |
+
GraphBlock(dim) for _ in range(3)
|
| 797 |
+
])
|
| 798 |
+
|
| 799 |
+
# Classifier
|
| 800 |
+
self.classifier = nn.Sequential(
|
| 801 |
+
nn.Linear(dim, 128),
|
| 802 |
+
nn.GELU(),
|
| 803 |
+
nn.Dropout(0.2),
|
| 804 |
+
nn.Linear(128, num_classes)
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
def forward(self, wav, cqcc):
|
| 808 |
+
if wav.dim() == 3:
|
| 809 |
+
wav = wav.squeeze(1)
|
| 810 |
+
|
| 811 |
+
# Wav2Vec2 features
|
| 812 |
+
w2v = self.wav2vec(wav).last_hidden_state # (B, T_w, D)
|
| 813 |
+
|
| 814 |
+
# CQCC features
|
| 815 |
+
if cqcc.dim() == 4:
|
| 816 |
+
cqcc = cqcc.squeeze(1)
|
| 817 |
+
cqcc_feat = self.cqcc_conv(cqcc).transpose(1, 2) # (B, T_c, D)
|
| 818 |
+
|
| 819 |
+
# Align lengths
|
| 820 |
+
cqcc_feat = align_sequences(cqcc_feat, w2v.size(1))
|
| 821 |
+
|
| 822 |
+
# Add positional encoding
|
| 823 |
+
w2v = self.pos_enc(w2v)
|
| 824 |
+
cqcc_feat = self.pos_enc(cqcc_feat)
|
| 825 |
+
|
| 826 |
+
# Cross attention (bidirectional)
|
| 827 |
+
f1, f2 = self.cross_attn(cqcc_feat, w2v)
|
| 828 |
+
fused = f1 + f2
|
| 829 |
+
|
| 830 |
+
# Graph Transformer processing on node sequences
|
| 831 |
+
x = fused
|
| 832 |
+
for layer in self.graph_layers:
|
| 833 |
+
x = layer(x)
|
| 834 |
+
|
| 835 |
+
# Global average pooling on the nodes
|
| 836 |
+
pooled = x.mean(dim=1)
|
| 837 |
+
|
| 838 |
+
return self.classifier(pooled)
|
| 839 |
+
|
| 840 |
+
# ============================================================
|
| 841 |
+
# 5. Ablation Models
|
| 842 |
+
# ============================================================
|
| 843 |
+
|
| 844 |
+
class AblationWav2Vec2GraphDetector(nn.Module):
|
| 845 |
+
"""Ablation 1: Wav2Vec2 only + Graph Backend (No CQCC, No Cross-Attention)"""
|
| 846 |
+
def __init__(self, num_classes=2):
|
| 847 |
+
super().__init__()
|
| 848 |
+
self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
|
| 849 |
+
for param in self.wav2vec.parameters():
|
| 850 |
+
param.requires_grad = False
|
| 851 |
+
|
| 852 |
+
dim = self.wav2vec.config.hidden_size
|
| 853 |
+
self.pos_enc = PositionalEncoding(dim)
|
| 854 |
+
|
| 855 |
+
self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
|
| 856 |
+
self.classifier = nn.Sequential(
|
| 857 |
+
nn.Linear(dim, 128), nn.GELU(), nn.Dropout(0.2), nn.Linear(128, num_classes)
|
| 858 |
+
)
|
| 859 |
+
|
| 860 |
+
def forward(self, wav, cqcc=None): # Accept both but ignore CQCC
|
| 861 |
+
if wav.dim() == 3:
|
| 862 |
+
wav = wav.squeeze(1)
|
| 863 |
+
|
| 864 |
+
w2v = self.wav2vec(wav).last_hidden_state
|
| 865 |
+
w2v = self.pos_enc(w2v)
|
| 866 |
+
|
| 867 |
+
x = w2v
|
| 868 |
+
for layer in self.graph_layers:
|
| 869 |
+
x = layer(x)
|
| 870 |
+
|
| 871 |
+
pooled = x.mean(dim=1)
|
| 872 |
+
return self.classifier(pooled)
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
class AblationCQCCGraphDetector(nn.Module):
|
| 876 |
+
"""Ablation 2: CQCC only + Graph Backend (No Wav2Vec2, No Cross-Attention)"""
|
| 877 |
+
def __init__(self, num_classes=2):
|
| 878 |
+
super().__init__()
|
| 879 |
+
dim = 768 # Match Wav2Vec2 hidden size for fair comparison
|
| 880 |
+
|
| 881 |
+
self.cqcc_conv = nn.Sequential(
|
| 882 |
+
nn.Conv1d(20, 128, kernel_size=3, padding=1),
|
| 883 |
+
nn.BatchNorm1d(128),
|
| 884 |
+
nn.GELU(),
|
| 885 |
+
nn.Dropout(0.2),
|
| 886 |
+
nn.Conv1d(128, dim, kernel_size=3, padding=1),
|
| 887 |
+
nn.BatchNorm1d(dim),
|
| 888 |
+
nn.GELU()
|
| 889 |
+
)
|
| 890 |
+
self.pos_enc = PositionalEncoding(dim)
|
| 891 |
+
|
| 892 |
+
self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
|
| 893 |
+
self.classifier = nn.Sequential(
|
| 894 |
+
nn.Linear(dim, 128), nn.GELU(), nn.Dropout(0.2), nn.Linear(128, num_classes)
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
def forward(self, cqcc):
|
| 898 |
+
if cqcc.dim() == 4:
|
| 899 |
+
cqcc = cqcc.squeeze(1)
|
| 900 |
+
|
| 901 |
+
cqcc_feat = self.cqcc_conv(cqcc).transpose(1, 2)
|
| 902 |
+
cqcc_feat = self.pos_enc(cqcc_feat)
|
| 903 |
+
|
| 904 |
+
x = cqcc_feat
|
| 905 |
+
for layer in self.graph_layers:
|
| 906 |
+
x = layer(x)
|
| 907 |
+
|
| 908 |
+
pooled = x.mean(dim=1)
|
| 909 |
+
return self.classifier(pooled)
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
class AblationConcatGraphDetector(nn.Module):
|
| 913 |
+
"""Ablation 3: Wav2Vec2 + CQCC + Simple Concat Fusion + Graph Backend (No Cross-Attention)"""
|
| 914 |
+
def __init__(self, num_classes=2):
|
| 915 |
+
super().__init__()
|
| 916 |
+
self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
|
| 917 |
+
for param in self.wav2vec.parameters():
|
| 918 |
+
param.requires_grad = False
|
| 919 |
+
|
| 920 |
+
dim = self.wav2vec.config.hidden_size
|
| 921 |
+
|
| 922 |
+
self.cqcc_conv = nn.Sequential(
|
| 923 |
+
nn.Conv1d(20, 128, kernel_size=3, padding=1),
|
| 924 |
+
nn.BatchNorm1d(128),
|
| 925 |
+
nn.GELU(),
|
| 926 |
+
nn.Dropout(0.2),
|
| 927 |
+
nn.Conv1d(128, dim, kernel_size=3, padding=1),
|
| 928 |
+
nn.BatchNorm1d(dim),
|
| 929 |
+
nn.GELU()
|
| 930 |
+
)
|
| 931 |
+
|
| 932 |
+
self.fusion_proj = nn.Linear(dim * 2, dim) # Project concatenated features back to dim
|
| 933 |
+
self.pos_enc = PositionalEncoding(dim)
|
| 934 |
+
|
| 935 |
+
self.graph_layers = nn.ModuleList([GraphBlock(dim) for _ in range(3)])
|
| 936 |
+
self.classifier = nn.Sequential(
|
| 937 |
+
nn.Linear(dim, 128), nn.GELU(), nn.Dropout(0.2), nn.Linear(128, num_classes)
|
| 938 |
+
)
|
| 939 |
+
|
| 940 |
+
def forward(self, wav, cqcc):
|
| 941 |
+
if wav.dim() == 3:
|
| 942 |
+
wav = wav.squeeze(1)
|
| 943 |
+
w2v = self.wav2vec(wav).last_hidden_state
|
| 944 |
+
|
| 945 |
+
if cqcc.dim() == 4:
|
| 946 |
+
cqcc = cqcc.squeeze(1)
|
| 947 |
+
cqcc_feat = self.cqcc_conv(cqcc).transpose(1, 2)
|
| 948 |
+
|
| 949 |
+
cqcc_feat = align_sequences(cqcc_feat, w2v.size(1))
|
| 950 |
+
|
| 951 |
+
# Simple concat over feature dimension instead of cross-attention
|
| 952 |
+
fused = torch.cat([w2v, cqcc_feat], dim=-1)
|
| 953 |
+
fused = self.fusion_proj(fused)
|
| 954 |
+
|
| 955 |
+
fused = self.pos_enc(fused)
|
| 956 |
+
|
| 957 |
+
x = fused
|
| 958 |
+
for layer in self.graph_layers:
|
| 959 |
+
x = layer(x)
|
| 960 |
+
|
| 961 |
+
pooled = x.mean(dim=1)
|
| 962 |
+
return self.classifier(pooled)
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
class AblationCrossAttnLinearDetector(nn.Module):
|
| 966 |
+
"""Ablation 4: Wav2Vec2 + CQCC + Cross-Attention + Linear Backend (No Graph Transformer)"""
|
| 967 |
+
def __init__(self, num_classes=2):
|
| 968 |
+
super().__init__()
|
| 969 |
+
self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
|
| 970 |
+
for param in self.wav2vec.parameters():
|
| 971 |
+
param.requires_grad = False
|
| 972 |
+
|
| 973 |
+
dim = self.wav2vec.config.hidden_size
|
| 974 |
+
|
| 975 |
+
self.cqcc_conv = nn.Sequential(
|
| 976 |
+
nn.Conv1d(20, 128, kernel_size=3, padding=1),
|
| 977 |
+
nn.BatchNorm1d(128),
|
| 978 |
+
nn.GELU(),
|
| 979 |
+
nn.Dropout(0.2),
|
| 980 |
+
nn.Conv1d(128, dim, kernel_size=3, padding=1),
|
| 981 |
+
nn.BatchNorm1d(dim),
|
| 982 |
+
nn.GELU()
|
| 983 |
+
)
|
| 984 |
+
|
| 985 |
+
self.pos_enc = PositionalEncoding(dim)
|
| 986 |
+
self.cross_attn = BidirectionalCrossAttention(dim)
|
| 987 |
+
|
| 988 |
+
# Richer MLP classifier since graph is missing
|
| 989 |
+
self.classifier = nn.Sequential(
|
| 990 |
+
nn.Linear(dim, 256),
|
| 991 |
+
nn.GELU(),
|
| 992 |
+
nn.Dropout(0.3),
|
| 993 |
+
nn.Linear(256, 128),
|
| 994 |
+
nn.GELU(),
|
| 995 |
+
nn.Dropout(0.2),
|
| 996 |
+
nn.Linear(128, num_classes)
|
| 997 |
+
)
|
| 998 |
+
|
| 999 |
+
def forward(self, wav, cqcc):
|
| 1000 |
+
if wav.dim() == 3:
|
| 1001 |
+
wav = wav.squeeze(1)
|
| 1002 |
+
w2v = self.wav2vec(wav).last_hidden_state
|
| 1003 |
+
|
| 1004 |
+
if cqcc.dim() == 4:
|
| 1005 |
+
cqcc = cqcc.squeeze(1)
|
| 1006 |
+
cqcc_feat = self.cqcc_conv(cqcc).transpose(1, 2)
|
| 1007 |
+
|
| 1008 |
+
cqcc_feat = align_sequences(cqcc_feat, w2v.size(1))
|
| 1009 |
+
|
| 1010 |
+
w2v = self.pos_enc(w2v)
|
| 1011 |
+
cqcc_feat = self.pos_enc(cqcc_feat)
|
| 1012 |
+
|
| 1013 |
+
f1, f2 = self.cross_attn(cqcc_feat, w2v)
|
| 1014 |
+
fused = f1 + f2
|
| 1015 |
+
|
| 1016 |
+
# No graph layer, straight to global average pooling
|
| 1017 |
+
pooled = fused.mean(dim=1)
|
| 1018 |
+
return self.classifier(pooled)
|
| 1019 |
+
|
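The `align_sequences` helper in `backend/models.py` resamples the CQCC features to the Wav2Vec2 frame count with `F.interpolate(..., mode='linear', align_corners=False)`. As a rough illustration of what that resampling does to each feature channel, here is a plain-Python sketch. The names `resample_linear` and `align_frames` are hypothetical and do not exist in the repository, and the half-sample index mapping is an assumption about the `align_corners=False` convention rather than a copy of the library code.

```python
import math
from typing import List


def resample_linear(values: List[float], target_len: int) -> List[float]:
    """Resample a 1-D sequence to target_len points by linear interpolation.

    Assumes the half-sample centre convention (align_corners=False style).
    Illustrative only; this is not the PyTorch implementation.
    """
    if not values or target_len <= 0:
        raise ValueError("need a non-empty sequence and a positive target length")
    src_len = len(values)
    scale = src_len / target_len
    out = []
    for i in range(target_len):
        # Map the centre of output sample i back onto the input axis,
        # clamping positions that fall before the first input sample.
        pos = max((i + 0.5) * scale - 0.5, 0.0)
        lo = min(int(math.floor(pos)), src_len - 1)
        hi = min(lo + 1, src_len - 1)
        frac = pos - lo
        out.append(values[lo] * (1.0 - frac) + values[hi] * frac)
    return out


def align_frames(frames: List[List[float]], target_len: int) -> List[List[float]]:
    """Align a (T, D) list of frames to (target_len, D), one channel at a time."""
    dim = len(frames[0])
    channels = [resample_linear([f[d] for f in frames], target_len) for d in range(dim)]
    return [[channels[d][t] for d in range(dim)] for t in range(target_len)]
```

For instance, `align_frames(cqcc_frames, n_w2v_frames)` would stretch or shrink a CQCC sequence so that it has one frame per Wav2Vec2 frame, which is the precondition for the cross-attention and concatenation fusion paths.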
backend/preprocess.py
ADDED
|
@@ -0,0 +1,58 @@
|
| 1 |
+
import subprocess
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
from dataset import AudioDataset
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def run_command(cmd):
|
| 9 |
+
try:
|
| 10 |
+
subprocess.run(cmd, check=True, text=True, capture_output=True)
|
| 11 |
+
except subprocess.CalledProcessError as exc:
|
| 12 |
+
sys.exit(exc.stderr or 1)  # print the captured stderr (if any) and exit non-zero
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def download_dataset():
|
| 16 |
+
run_command(["git", "lfs", "install"])
|
| 17 |
+
dataset_dir = "MLAAD-tiny"
|
| 18 |
+
if not os.path.exists(dataset_dir):
|
| 19 |
+
print("=== Cloning MLAAD-tiny dataset ===")
|
| 20 |
+
run_command(["git", "clone", "https://huggingface.co/datasets/mueller91/MLAAD-tiny"])
|
| 21 |
+
else:
|
| 22 |
+
print(f"Dataset directory '{dataset_dir}' already exists. Skipping clone.")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def precompute_cqcc(data_dir, cqcc_cache_dir, force=False):
|
| 26 |
+
for lang in ["en", "de"]:
|
| 27 |
+
print(f"\n--- Precomputing CQCC for language: {lang} ---")
|
| 28 |
+
dataset = AudioDataset(
|
| 29 |
+
data_dir=data_dir,
|
| 30 |
+
augment=False,
|
| 31 |
+
cqcc_cache_dir=cqcc_cache_dir,
|
| 32 |
+
target_lang=lang
|
| 33 |
+
)
|
| 34 |
+
dataset.precompute_cqcc_cache(force=force)
|
| 35 |
+
print("\nFinished all CQCC preprocessing.")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def parse_args():
|
| 39 |
+
parser = argparse.ArgumentParser(description="Download dataset and precompute CQCC features.")
|
| 40 |
+
parser.add_argument("--data-dir", default="MLAAD-tiny")
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--cqcc-cache-dir",
|
| 43 |
+
default=os.path.join(os.path.dirname(__file__), "precomputed_features", "cqcc")
|
| 44 |
+
)
|
| 45 |
+
parser.add_argument("--force", action="store_true")
|
| 46 |
+
parser.add_argument("--skip-download", action="store_true")
|
| 47 |
+
parser.add_argument("--skip-cqcc", action="store_true")
|
| 48 |
+
return parser.parse_args()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if __name__ == "__main__":
|
| 52 |
+
args = parse_args()
|
| 53 |
+
|
| 54 |
+
if not args.skip_download:
|
| 55 |
+
download_dataset()
|
| 56 |
+
|
| 57 |
+
if not args.skip_cqcc:
|
| 58 |
+
precompute_cqcc(args.data_dir, args.cqcc_cache_dir, args.force)
|
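`run_command` in `backend/preprocess.py` captures the output of `git lfs install` and `git clone`, so a failed clone can be hard to diagnose unless the captured stderr is surfaced. The sketch below shows one possible alternative that returns stdout and raises with the captured stderr on failure. `run_checked` is a hypothetical name and is not part of the repository.

```python
import subprocess
import sys
from typing import List


def run_checked(cmd: List[str]) -> str:
    """Run cmd, return its stdout, and raise with the captured stderr on failure."""
    result = subprocess.run(cmd, text=True, capture_output=True)
    if result.returncode != 0:
        raise RuntimeError(
            f"{' '.join(cmd)} exited with code {result.returncode}:\n"
            f"{result.stderr.strip()}"
        )
    return result.stdout


if __name__ == "__main__":
    # Uses the current interpreter so the example does not depend on git.
    print(run_checked([sys.executable, "-c", "print('ok')"]).strip())
```

Raising instead of calling `sys.exit` leaves the decision to the caller, which can be convenient when the same helper is reused from a notebook or a test.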
backend/train.py
ADDED
|
@@ -0,0 +1,514 @@
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from dataset import AudioDataset, collate_variable_length
|
| 7 |
+
from models import (
|
| 8 |
+
AASISTDetector,
|
| 9 |
+
Wav2Vec2SpoofDetector,
|
| 10 |
+
CQCCBaselineDetector,
|
| 11 |
+
ImprovedWav2Vec2CQCCDetector,
|
| 12 |
+
AblationWav2Vec2GraphDetector,
|
| 13 |
+
AblationCQCCGraphDetector,
|
| 14 |
+
AblationConcatGraphDetector,
|
| 15 |
+
AblationCrossAttnLinearDetector
|
| 16 |
+
)
|
| 17 |
+
from sklearn.metrics import roc_curve, auc
|
| 18 |
+
import numpy as np
|
| 19 |
+
import random
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def train_model(model, train_dataloader, criterion, optimizer, epochs=5, input_type='wav', device=None, val_dataloader=None, eval_interval=1, patience=2, model_save_path=None):
|
| 24 |
+
|
| 25 |
+
if device is None:
|
| 26 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 27 |
+
|
| 28 |
+
model.to(device)
|
| 29 |
+
|
| 30 |
+
loss_history = []
|
| 31 |
+
best_val_metric = float('inf') # For min_dcf, lower is better
|
| 32 |
+
patience_counter = 0
|
| 33 |
+
best_epoch = 0
|
| 34 |
+
|
| 35 |
+
for epoch in range(epochs):
|
| 36 |
+
model.train()
|
| 37 |
+
epoch_loss = 0
|
| 38 |
+
correct = 0
|
| 39 |
+
total = 0
|
| 40 |
+
# Wrap the dataloader with tqdm for a progress bar
|
| 41 |
+
for batch_idx, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs} - Training")):
|
| 42 |
+
|
| 43 |
+
wavs, cqccs, labels = batch
|
| 44 |
+
wavs = wavs.to(device)
|
| 45 |
+
cqccs = cqccs.to(device)
|
| 46 |
+
labels = labels.to(device)
|
| 47 |
+
|
| 48 |
+
optimizer.zero_grad()
|
| 49 |
+
|
| 50 |
+
if input_type == 'wav':
|
| 51 |
+
outputs = model(wavs)
|
| 52 |
+
elif input_type == 'cqcc':
|
| 53 |
+
outputs = model(cqccs)
|
| 54 |
+
elif input_type == 'wav_and_cqcc':
|
| 55 |
+
outputs = model(wavs, cqccs)
|
| 56 |
+
else:
|
| 57 |
+
raise ValueError("invalid input_type")
|
| 58 |
+
|
| 59 |
+
loss = criterion(outputs, labels)
|
| 60 |
+
|
| 61 |
+
loss.backward()
|
| 62 |
+
optimizer.step()
|
| 63 |
+
|
| 64 |
+
epoch_loss += loss.item()
|
| 65 |
+
|
| 66 |
+
_, predicted = torch.max(outputs.data, 1)
|
| 67 |
+
|
| 68 |
+
total += labels.size(0)
|
| 69 |
+
correct += (predicted == labels).sum().item()
|
| 70 |
+
|
| 71 |
+
# Print intermediate progress within the epoch
|
| 72 |
+
if batch_idx % 500 == 0 and batch_idx > 0: # Report every 500 batches
|
| 73 |
+
current_acc = 100 * correct / total
|
| 74 |
+
current_loss = epoch_loss / (batch_idx + 1)
|
| 75 |
+
print(f" Batch {batch_idx}/{len(train_dataloader)} | Loss: {current_loss:.4f} | Acc: {current_acc:.2f}%")
|
| 76 |
+
|
| 77 |
+
acc = 100 * correct / total if total > 0 else 0
|
| 78 |
+
avg_loss = epoch_loss / len(train_dataloader)
|
| 79 |
+
loss_history.append(avg_loss)
|
| 80 |
+
print(f"Epoch {epoch+1}/{epochs} | Training Loss: {avg_loss:.4f} | Training Acc: {acc:.2f}%")
|
| 81 |
+
|
| 82 |
+
# Validation and Early Stopping
|
| 83 |
+
if val_dataloader is not None and (epoch + 1) % eval_interval == 0:
|
| 84 |
+
print(f"Epoch {epoch+1}/{epochs} - Evaluating on Validation Set...")
|
| 85 |
+
_, _, _, val_eer, val_min_dcf, val_accuracy = evaluate_model(
|
| 86 |
+
model, val_dataloader, input_type=input_type, device=device
|
| 87 |
+
)
|
| 88 |
+
print(f" Validation | EER={val_eer*100:.2f}% | minDCF={val_min_dcf:.4f} | Accuracy={val_accuracy*100:.2f}%")
|
| 89 |
+
|
| 90 |
+
if val_min_dcf < best_val_metric:
|
| 91 |
+
best_val_metric = val_min_dcf
|
| 92 |
+
patience_counter = 0
|
| 93 |
+
best_epoch = epoch + 1
|
| 94 |
+
if model_save_path:
|
| 95 |
+
torch.save(model.state_dict(), model_save_path)
|
| 96 |
+
print(f" Saved best model to {model_save_path} (minDCF: {best_val_metric:.4f})")
|
| 97 |
+
else:
|
| 98 |
+
patience_counter += 1
|
| 99 |
+
print(f" Validation minDCF did not improve. Patience: {patience_counter}/{patience}")
|
| 100 |
+
|
| 101 |
+
if patience_counter >= patience:
|
| 102 |
+
print(f"Early stopping triggered after {epoch+1} epochs. Best minDCF: {best_val_metric:.4f} at epoch {best_epoch}")
|
| 103 |
+
if model_save_path:
|
| 104 |
+
print(f"Loading best model from {model_save_path}")
|
| 105 |
+
model.load_state_dict(torch.load(model_save_path, map_location=device))
|
| 106 |
+
return loss_history # Stop training
|
| 107 |
+
|
| 108 |
+
# Without a validation set no best checkpoint is ever written, so save the final weights here
|
| 109 |
+
if val_dataloader is None and model_save_path is not None:
|
| 110 |
+
torch.save(model.state_dict(), model_save_path)
|
| 111 |
+
print(f" Saved final model to {model_save_path}")
|
| 112 |
+
|
| 113 |
+
return loss_history
|
| 114 |
+
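The early-stopping bookkeeping inside `train_model` (best metric, patience counter, best epoch) can be read in isolation as a small state machine. The sketch below restates that logic without any PyTorch dependency. `EarlyStopper` and `simulate` are hypothetical names introduced for illustration, and the metric is assumed to be lower-is-better, as minDCF is.

```python
from typing import List, Optional


class EarlyStopper:
    """Track a lower-is-better metric and signal when patience runs out."""

    def __init__(self, patience: int = 2) -> None:
        self.patience = patience
        self.best = float("inf")
        self.best_epoch: Optional[int] = None
        self.bad_evals = 0

    def update(self, metric: float, epoch: int) -> bool:
        """Record one evaluation; return True if it is a new best."""
        if metric < self.best:
            self.best = metric
            self.best_epoch = epoch
            self.bad_evals = 0
            return True
        self.bad_evals += 1
        return False

    @property
    def should_stop(self) -> bool:
        return self.bad_evals >= self.patience


def simulate(metrics: List[float], patience: int = 2) -> int:
    """Return the 1-based epoch at which training would stop."""
    stopper = EarlyStopper(patience)
    for epoch, metric in enumerate(metrics, start=1):
        stopper.update(metric, epoch)
        if stopper.should_stop:
            return epoch
    return len(metrics)
```

In a training loop, a `True` return from `update` is the point at which a checkpoint would be saved, and `should_stop` is checked right after each validation pass.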
|
| 115 |
+
|
| 116 |
+
def evaluate_model(model, dataloader, input_type='wav', device=None):
|
| 117 |
+
|
| 118 |
+
if device is None:
|
| 119 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 120 |
+
|
| 121 |
+
model.eval()
|
| 122 |
+
|
| 123 |
+
all_labels = []
|
| 124 |
+
all_probs = []
|
| 125 |
+
|
| 126 |
+
with torch.no_grad():
|
| 127 |
+
for batch in tqdm(dataloader, desc="Evaluating"):
|
| 128 |
+
|
| 129 |
+
wavs, cqccs, labels = batch
|
| 130 |
+
wavs = wavs.to(device)
|
| 131 |
+
cqccs = cqccs.to(device)
|
| 132 |
+
labels = labels.to(device)
|
| 133 |
+
|
| 134 |
+
if input_type == 'wav':
|
| 135 |
+
outputs = model(wavs)
|
| 136 |
+
elif input_type == 'cqcc':
|
| 137 |
+
outputs = model(cqccs)
|
| 138 |
+
elif input_type == 'wav_and_cqcc':
|
| 139 |
+
outputs = model(wavs, cqccs)
|
| 140 |
+
else:
|
| 141 |
+
raise ValueError("invalid input_type")
|
| 142 |
+
|
| 143 |
+
probs = torch.softmax(outputs, dim=1)[:, 1]
|
| 144 |
+
|
| 145 |
+
all_labels.extend(labels.tolist())
|
| 146 |
+
all_probs.extend(probs.tolist())
|
| 147 |
+
|
| 148 |
+
fpr, tpr, thresholds = roc_curve(all_labels, all_probs)
|
| 149 |
+
roc_auc = auc(fpr, tpr)
|
| 150 |
+
|
| 151 |
+
# ------------------
|
| 152 |
+
# EER (Equal Error Rate)
|
| 153 |
+
# ------------------
|
| 154 |
+
fnr = 1 - tpr
|
| 155 |
+
eer_index = np.nanargmin(np.absolute(fnr - fpr))
|
| 156 |
+
eer = (fpr[eer_index] + fnr[eer_index]) / 2  # average both error rates at the closest crossing point
|
| 157 |
+
|
| 158 |
+
# ------------------
|
| 159 |
+
# minDCF (Minimum Detection Cost Function)
|
| 160 |
+
# Parameters according to ASVspoof 5 Evaluation Plan (Track 1)
|
| 161 |
+
# ------------------
|
| 162 |
+
P_spoof = 0.05 # Prior probability of a spoofing attack (\pi_{spf})
|
| 163 |
+
P_bonafide = 0.95 # Prior probability of a real/bonafide utterance (1 - \pi_{spf})
|
| 164 |
+
C_miss = 1 # Cost of falsely rejecting a real voice (Miss)
|
| 165 |
+
C_fa = 10 # Cost of falsely accepting a spoof (False Alarm)
|
| 166 |
+
|
| 167 |
+
# In the dataset, 0 = real (bonafide), 1 = fake (spoof)
|
| 168 |
+
# fpr (False Positive Rate) = predicted fake (1) when true is real (0). This is a "miss" in ASVspoof.
|
| 169 |
+
# fnr (False Negative Rate) = predicted real (0) when true is fake (1). This is a "false alarm" in ASVspoof.
|
| 170 |
+
P_miss = fpr
|
| 171 |
+
P_fa = fnr
|
| 172 |
+
|
| 173 |
+
# Raw DCF = C_miss * P_bonafide * P_miss + C_fa * P_spoof * P_fa
|
| 174 |
+
# Normalized by the default DCF (min cost of predicting all bonafide vs all spoof)
|
| 175 |
+
dcf_default = min(C_miss * P_bonafide, C_fa * P_spoof)
|
| 176 |
+
dcf_array = (C_miss * P_bonafide * P_miss + C_fa * P_spoof * P_fa) / dcf_default
|
| 177 |
+
min_dcf = np.min(dcf_array)
|
| 178 |
+
|
| 179 |
+
# Overall Accuracy (using 0.5 threshold)
|
| 180 |
+
preds = [1 if p > 0.5 else 0 for p in all_probs]
|
| 181 |
+
correct = sum(1 for p, l in zip(preds, all_labels) if p == l)
|
| 182 |
+
accuracy = correct / len(all_labels) if len(all_labels) > 0 else 0
|
| 183 |
+
|
| 184 |
+
return fpr, tpr, roc_auc, eer, min_dcf, accuracy
|
| 185 |
+
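To make the EER and normalised minDCF computation in `evaluate_model` easier to check by hand, here is a dependency-free sketch that sweeps thresholds directly instead of calling `roc_curve`. It uses the same label convention (0 = bona fide, 1 = spoof) and the same cost parameters as the code above. The function names are hypothetical, and the EER here is the mean of the two error rates at the closest crossing point, which is an assumption of this sketch.

```python
from typing import List, Sequence, Tuple


def error_rates(labels: Sequence[int], scores: Sequence[float]) -> List[Tuple[float, float]]:
    """Return (fpr, fnr) for every decision threshold.

    A sample is flagged as spoof when its score is >= the threshold.
    """
    n_bona = sum(1 for y in labels if y == 0)
    n_spoof = sum(1 for y in labels if y == 1)
    if n_bona == 0 or n_spoof == 0:
        raise ValueError("both classes are required")
    # The extra infinite threshold covers the "flag nothing" operating point.
    thresholds = sorted(set(scores)) + [float("inf")]
    rates = []
    for t in thresholds:
        fp = sum(1 for y, s in zip(labels, scores) if y == 0 and s >= t)
        fn = sum(1 for y, s in zip(labels, scores) if y == 1 and s < t)
        rates.append((fp / n_bona, fn / n_spoof))
    return rates


def eer_and_min_dcf(
    labels: Sequence[int],
    scores: Sequence[float],
    p_spoof: float = 0.05,
    c_miss: float = 1.0,
    c_fa: float = 10.0,
) -> Tuple[float, float]:
    rates = error_rates(labels, scores)
    # EER: operating point where the two error rates are closest.
    fpr, fnr = min(rates, key=lambda r: abs(r[0] - r[1]))
    eer = (fpr + fnr) / 2.0
    # Normalised DCF: P_miss = fpr (bona fide rejected), P_fa = fnr (spoof accepted).
    p_bona = 1.0 - p_spoof
    default = min(c_miss * p_bona, c_fa * p_spoof)
    min_dcf = min((c_miss * p_bona * m + c_fa * p_spoof * f) / default for m, f in rates)
    return eer, min_dcf
```

With `labels = [0, 0, 1, 1]` and `scores = [0.1, 0.6, 0.4, 0.9]`, the closest crossing is at threshold 0.6, where both error rates are 0.5. The cheapest operating point is the one that rejects no bona fide sample and accepts half of the spoofs, costing `10 * 0.05 * 0.5 = 0.25`, or 0.5 after normalising by the default cost `min(0.95, 0.5) = 0.5`.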
|
| 186 |
+
|
| 187 |
+
def parse_args():
|
| 188 |
+
parser = argparse.ArgumentParser(description="Train spoof-detection models with optional CQCC caching.")
|
| 189 |
+
parser.add_argument(
|
| 190 |
+
"--data-dir",
|
| 191 |
+
default=None,
|
| 192 |
+
help="Path to dataset root containing original/ and fake/ folders."
|
| 193 |
+
)
|
| 194 |
+
parser.add_argument(
|
| 195 |
+
"--cqcc-cache-dir", # this is where cqcc is stored
|
| 196 |
+
default=os.path.join(os.path.dirname(__file__), "precomputed_features", "cqcc"),
|
| 197 |
+
help="Directory used to store and reuse precomputed CQCC tensors."
|
| 198 |
+
)
|
| 199 |
+
parser.add_argument(
|
| 200 |
+
"--precompute-cqcc-only",
|
| 201 |
+
action="store_true",
|
| 202 |
+
help="Only build the CQCC cache and exit without training."
|
| 203 |
+
)
|
| 204 |
+
parser.add_argument(
|
| 205 |
+
"--val-split",
|
| 206 |
+
type=float,
|
| 207 |
+
default=0.2,
|
| 208 |
+
help="Fraction of English training data to reserve for validation."
|
| 209 |
+
)
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"--force-rebuild-cqcc",
|
| 212 |
+
action="store_true",
|
| 213 |
+
help="Recompute cached CQCC files even if they already exist."
|
| 214 |
+
)
|
| 215 |
+
parser.add_argument(
|
| 216 |
+
"--smoke-test",
|
| 217 |
+
action="store_true",
|
| 218 |
+
help="Load one batch, run a forward pass through each model, and exit without training."
|
| 219 |
+
)
|
| 220 |
+
return parser.parse_args()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def run_smoke_test(dataloader, device):
|
| 224 |
+
print("\n--- Running Smoke Test ---")
|
| 225 |
+
batch = next(iter(dataloader))
|
| 226 |
+
wavs, cqccs, labels = batch
|
| 227 |
+
|
| 228 |
+
models_to_test = [
|
| 229 |
+
("Wav2Vec2 Baseline", Wav2Vec2SpoofDetector(num_classes=2).to(device), "wav"),
|
| 230 |
+
("AASIST Baseline", AASISTDetector(num_classes=2).to(device), "wav"),
|
| 231 |
+
("CQCC Baseline", CQCCBaselineDetector(num_classes=2).to(device), "cqcc"),
|
| 232 |
+
("Custom Fusion Model", ImprovedWav2Vec2CQCCDetector(num_classes=2).to(device), "wav_and_cqcc"),
|
| 233 |
+
("Ablation W2V2+Graph", AblationWav2Vec2GraphDetector(num_classes=2).to(device), "wav"),
|
| 234 |
+
("Ablation CQCC+Graph", AblationCQCCGraphDetector(num_classes=2).to(device), "cqcc"),
|
| 235 |
+
("Ablation Concat+Graph", AblationConcatGraphDetector(num_classes=2).to(device), "wav_and_cqcc"),
|
| 236 |
+
("Ablation CrossAttn+Linear", AblationCrossAttnLinearDetector(num_classes=2).to(device), "wav_and_cqcc"),
|
| 237 |
+
]
|
| 238 |
+
|
| 239 |
+
with torch.no_grad():
|
| 240 |
+
for name, model, input_type in models_to_test:
|
| 241 |
+
model.eval()
|
| 242 |
+
if input_type == "wav":
|
| 243 |
+
outputs = model(wavs.to(device))
|
| 244 |
+
elif input_type == "cqcc":
|
| 245 |
+
outputs = model(cqccs.to(device))
|
| 246 |
+
elif input_type == "wav_and_cqcc":
|
| 247 |
+
outputs = model(wavs.to(device), cqccs.to(device))
|
| 248 |
+
else:
|
| 249 |
+
raise ValueError("invalid input_type")
|
| 250 |
+
|
| 251 |
+
print(f"{name}: input OK, output shape = {tuple(outputs.shape)}")
|
| 252 |
+
|
| 253 |
+
print(f"Labels shape = {tuple(labels.shape)}")
|
| 254 |
+
print("Smoke test complete. Cached CQCC loading and model forward passes succeeded.")
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def main():
|
| 258 |
+
args = parse_args()
|
| 259 |
+
print(args)
|
| 260 |
+
SEED = 42
|
| 261 |
+
random.seed(SEED)
|
| 262 |
+
np.random.seed(SEED)
|
| 263 |
+
torch.manual_seed(SEED)
|
| 264 |
+
if torch.cuda.is_available():
|
| 265 |
+
torch.cuda.manual_seed_all(SEED)
|
| 266 |
+
|
| 267 |
+
g = torch.Generator()
|
| 268 |
+
g.manual_seed(SEED)
|
| 269 |
+
|
| 270 |
+
torch.backends.cudnn.deterministic = True
|
| 271 |
+
torch.backends.cudnn.benchmark = False
|
| 272 |
+
|
| 273 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 274 |
+
|
| 275 |
+
print(f"Using device: {device}")
|
| 276 |
+
|
| 277 |
+
print("Loading English Dataset for training/validation...")
|
| 278 |
+
full_en_dataset = AudioDataset(data_dir=args.data_dir, augment=False, cqcc_cache_dir=args.cqcc_cache_dir, target_lang="en")
|
| 279 |
+
total_en = len(full_en_dataset)
|
| 280 |
+
if total_en == 0:
|
| 281 |
+
raise ValueError("No English data found for target_lang='en'. Check data_dir and directory layout.")
|
| 282 |
+
|
| 283 |
+
val_split = min(max(args.val_split, 0.0), 0.5)
|
| 284 |
+
train_size = int((1.0 - val_split) * total_en)
|
| 285 |
+
val_size = total_en - train_size
|
| 286 |
+
indices = torch.randperm(total_en, generator=g).tolist()
|
| 287 |
+
train_indices = indices[:train_size]
|
| 288 |
+
val_indices = indices[train_size:]
|
| 289 |
+
|
| 290 |
+
train_dataset = torch.utils.data.Subset(
|
| 291 |
+
AudioDataset(data_dir=args.data_dir, augment=True, cqcc_cache_dir=args.cqcc_cache_dir, target_lang="en"),
|
| 292 |
+
train_indices
|
| 293 |
+
)
|
| 294 |
+
val_dataset = torch.utils.data.Subset(
|
| 295 |
+
AudioDataset(data_dir=args.data_dir, augment=False, cqcc_cache_dir=args.cqcc_cache_dir, target_lang="en"),
|
| 296 |
+
val_indices
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
print("Loading German Dataset for Testing...")
|
| 300 |
+
test_dataset = AudioDataset(data_dir=args.data_dir, augment=False, cqcc_cache_dir=args.cqcc_cache_dir, target_lang="de")
|
| 301 |
+
|
| 302 |
+
if args.precompute_cqcc_only:
|
| 303 |
+
print("\n--- Starting CQCC Precomputation ---")
|
| 304 |
+
print(f"Dataset: {full_en_dataset.data_dir}")
|
| 305 |
+
print("Precomputing CQCC cache for English and German data...")
|
| 306 |
+
full_en_dataset.precompute_cqcc_cache(force=args.force_rebuild_cqcc)
|
| 307 |
+
test_dataset.precompute_cqcc_cache(force=args.force_rebuild_cqcc)
|
| 308 |
+
print("CQCC preprocessing complete. Exiting.")
|
| 309 |
+
return
|
| 310 |
+
|
| 311 |
+
train_loader = DataLoader(
|
| 312 |
+
train_dataset,
|
| 313 |
+
batch_size=8,
|
| 314 |
+
shuffle=True,
|
| 315 |
+
collate_fn=collate_variable_length,
|
| 316 |
+
num_workers=2,
|
| 317 |
+
pin_memory=True,
|
| 318 |
+
generator=g, # ensure reproducible shuffling
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
val_loader = DataLoader(
|
| 322 |
+
val_dataset,
|
| 323 |
+
batch_size=8,
|
| 324 |
+
shuffle=False,
|
| 325 |
+
collate_fn=collate_variable_length,
|
| 326 |
+
num_workers=2,
|
| 327 |
+
pin_memory=True
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
test_loader = DataLoader(
|
| 331 |
+
test_dataset,
|
| 332 |
+
batch_size=8,
|
| 333 |
+
shuffle=False,
|
| 334 |
+
collate_fn=collate_variable_length,
|
| 335 |
+
num_workers=2,
|
| 336 |
+
pin_memory=True
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
if args.smoke_test:
|
| 340 |
+
run_smoke_test(train_loader, device)
|
| 341 |
+
return
|
| 342 |
+
|
| 343 |
+
models_dir = os.path.join(os.path.dirname(__file__), "models")
|
| 344 |
+
os.makedirs(models_dir, exist_ok=True)
|
| 345 |
+
|
| 346 |
+
criterion = nn.CrossEntropyLoss()
|
| 347 |
+
|
| 348 |
+
# ============================================================
|
| 349 |
+
# 1 Wav2Vec2 Baseline
|
| 350 |
+
# ============================================================
|
| 351 |
+
|
| 352 |
+
print("\n--- Training Wav2Vec2 Baseline ---")
|
| 353 |
+
|
| 354 |
+
wav2vec_model = Wav2Vec2SpoofDetector(num_classes=2).to(device)
|
| 355 |
+
|
| 356 |
+
optimizer_wav2vec = torch.optim.Adam(wav2vec_model.parameters(), lr=1e-4)
|
| 357 |
+
|
| 358 |
+
wav2vec_loss = train_model(
|
| 359 |
+
wav2vec_model,
|
| 360 |
+
train_loader,
|
| 361 |
+
criterion,
|
| 362 |
+
optimizer_wav2vec,
|
| 363 |
+
input_type='wav',
|
| 364 |
+
device=device,
|
| 365 |
+
val_dataloader=val_loader,
|
| 366 |
+
model_save_path=os.path.join(models_dir, "wav2vec2.pth")
|
| 367 |
+
)
|
| 368 |
+
del wav2vec_model, optimizer_wav2vec
|
| 369 |
+
torch.cuda.empty_cache()
|
| 370 |
+
# ============================================================
|
| 371 |
+
# 2 AASIST Baseline
|
| 372 |
+
# ============================================================
|
| 373 |
+
|
| 374 |
+
print("\n--- Training AASIST Baseline ---")
|
| 375 |
+
|
| 376 |
+
aasist_model = AASISTDetector(num_classes=2).to(device)
|
| 377 |
+
|
| 378 |
+
optimizer_aasist = torch.optim.Adam(aasist_model.parameters(), lr=5e-4)
|
| 379 |
+
|
| 380 |
+
aasist_loss = train_model(
|
| 381 |
+
aasist_model,
|
| 382 |
+
train_loader,
|
| 383 |
+
criterion,
|
| 384 |
+
optimizer_aasist,
|
| 385 |
+
input_type='wav',
|
| 386 |
+
device=device,
|
| 387 |
+
val_dataloader=val_loader,
|
| 388 |
+
model_save_path=os.path.join(models_dir, "aasist.pth")
|
| 389 |
+
)
|
| 390 |
+
del aasist_model, optimizer_aasist
|
| 391 |
+
torch.cuda.empty_cache()
|
| 392 |
+
# ============================================================
|
| 393 |
+
# 3 CQCC Baseline
|
| 394 |
+
# ============================================================
|
| 395 |
+
|
| 396 |
+
print("\n--- Training CQCC Baseline ---")
|
| 397 |
+
|
| 398 |
+
cqcc_baseline = CQCCBaselineDetector(num_classes=2).to(device)
|
| 399 |
+
|
| 400 |
+
optimizer_cqcc = torch.optim.Adam(cqcc_baseline.parameters(), lr=1e-4)
|
| 401 |
+
|
| 402 |
+
cqcc_loss = train_model(
|
| 403 |
+
cqcc_baseline,
|
| 404 |
+
train_loader,
|
| 405 |
+
criterion,
|
| 406 |
+
optimizer_cqcc,
|
| 407 |
+
input_type='cqcc',
|
| 408 |
+
device=device,
|
| 409 |
+
val_dataloader=val_loader,
|
| 410 |
+
model_save_path=os.path.join(models_dir, "cqcc_baseline.pth")
|
| 411 |
+
)
|
| 412 |
+
del cqcc_baseline, optimizer_cqcc
|
| 413 |
+
torch.cuda.empty_cache()
|
| 414 |
+
# ============================================================
|
| 415 |
+
# 4 Custom Fusion: Wav2Vec2 + CQCC with Cross-Attention + Graph
|
| 416 |
+
# ============================================================
|
| 417 |
+
|
| 418 |
+
print("\n--- Training Custom Fusion Detector ---")
|
| 419 |
+
|
| 420 |
+
custom_model = ImprovedWav2Vec2CQCCDetector(num_classes=2).to(device)
|
| 421 |
+
|
| 422 |
+
optimizer_custom = torch.optim.Adam(custom_model.parameters(), lr=1e-4)
|
| 423 |
+
|
| 424 |
+
custom_loss = train_model(
|
| 425 |
+
custom_model,
|
| 426 |
+
train_loader,
|
| 427 |
+
criterion,
|
| 428 |
+
optimizer_custom,
|
| 429 |
+
input_type='wav_and_cqcc',
|
| 430 |
+
device=device,
|
| 431 |
+
val_dataloader=val_loader,
|
| 432 |
+
model_save_path=os.path.join(models_dir, "custom_hybrid.pth")
|
| 433 |
+
)
|
| 434 |
+
del custom_model, optimizer_custom
|
| 435 |
+
torch.cuda.empty_cache()
|
| 436 |
+
|
| 437 |
+
# ============================================================
|
| 438 |
+
# 5 Ablation Models
|
| 439 |
+
# ============================================================
|
| 440 |
+
|
| 441 |
+
print("\n--- Training Ablation 1 (Wav2Vec2 + Graph) ---")
|
| 442 |
+
ab1_model = AblationWav2Vec2GraphDetector(num_classes=2).to(device)
|
| 443 |
+
optimizer_ab1 = torch.optim.Adam(ab1_model.parameters(), lr=1e-4) # learning rate for wav2vec2-based
|
| 444 |
+
ab1_loss = train_model(ab1_model, train_loader, criterion, optimizer_ab1, input_type='wav', device=device, val_dataloader=val_loader, model_save_path=os.path.join(models_dir, "ablation_w2v2_graph.pth"))
|
| 445 |
+
del ab1_model, optimizer_ab1
|
| 446 |
+
torch.cuda.empty_cache()
|
| 447 |
+
|
| 448 |
+
print("\n--- Training Ablation 2 (CQCC + Graph) ---")
|
| 449 |
+
ab2_model = AblationCQCCGraphDetector(num_classes=2).to(device)
|
| 450 |
+
optimizer_ab2 = torch.optim.Adam(ab2_model.parameters(), lr=1e-4) # learning rate for CQCC-based
|
| 451 |
+
ab2_loss = train_model(ab2_model, train_loader, criterion, optimizer_ab2, input_type='cqcc', device=device, val_dataloader=val_loader, model_save_path=os.path.join(models_dir, "ablation_cqcc_graph.pth"))
|
| 452 |
+
del ab2_model, optimizer_ab2
|
| 453 |
+
torch.cuda.empty_cache()
|
| 454 |
+
|
| 455 |
+
print("\n--- Training Ablation 3 (Wav2Vec2 + CQCC + Concat + Graph) ---")
|
| 456 |
+
ab3_model = AblationConcatGraphDetector(num_classes=2).to(device)
|
| 457 |
+
optimizer_ab3 = torch.optim.Adam(ab3_model.parameters(), lr=1e-4)
|
| 458 |
+
ab3_loss = train_model(ab3_model, train_loader, criterion, optimizer_ab3, input_type='wav_and_cqcc', device=device, val_dataloader=val_loader, model_save_path=os.path.join(models_dir, "ablation_concat_graph.pth"))
|
| 459 |
+
del ab3_model, optimizer_ab3
|
| 460 |
+
torch.cuda.empty_cache()
|
| 461 |
+
|
| 462 |
+
print("\n--- Training Ablation 4 (Wav2Vec2 + CQCC + Cross-Attn + Linear) ---")
|
| 463 |
+
ab4_model = AblationCrossAttnLinearDetector(num_classes=2).to(device)
|
| 464 |
+
optimizer_ab4 = torch.optim.Adam(ab4_model.parameters(), lr=1e-4)
|
| 465 |
+
ab4_loss = train_model(ab4_model, train_loader, criterion, optimizer_ab4, input_type='wav_and_cqcc', device=device, val_dataloader=val_loader, model_save_path=os.path.join(models_dir, "ablation_crossattn_linear.pth"))
|
| 466 |
+
del ab4_model, optimizer_ab4
|
| 467 |
+
torch.cuda.empty_cache()
|
| 468 |
+
|
| 469 |
+
# ============================================================
|
| 470 |
+
# Evaluation — reload one at a time
|
| 471 |
+
# ============================================================
|
| 472 |
+
print("\n--- Evaluating Models ---")
|
| 473 |
+
evals = []
|
| 474 |
+
|
| 475 |
+
models_to_eval = [
|
| 476 |
+
("Wav2Vec2 Baseline", Wav2Vec2SpoofDetector, "wav2vec2.pth", 'wav'),
|
| 477 |
+
("AASIST Baseline", AASISTDetector, "aasist.pth", 'wav'),
|
| 478 |
+
("CQCC Baseline", CQCCBaselineDetector, "cqcc_baseline.pth", 'cqcc'),
|
| 479 |
+
("Custom Fusion Model", ImprovedWav2Vec2CQCCDetector, "custom_hybrid.pth", 'wav_and_cqcc'),
|
| 480 |
+
("Ablation 1 (W2V2+Graph)", AblationWav2Vec2GraphDetector, "ablation_w2v2_graph.pth", 'wav'),
|
| 481 |
+
("Ablation 2 (CQCC+Graph)", AblationCQCCGraphDetector, "ablation_cqcc_graph.pth", 'cqcc'),
|
| 482 |
+
("Ablation 3 (Concat+Graph)", AblationConcatGraphDetector, "ablation_concat_graph.pth", 'wav_and_cqcc'),
|
| 483 |
+
("Ablation 4 (CrossAttn+Linear)", AblationCrossAttnLinearDetector, "ablation_crossattn_linear.pth", 'wav_and_cqcc'),
|
| 484 |
+
]
|
| 485 |
+
|
| 486 |
+
for name, model_class, filename, inp in models_to_eval:
|
| 487 |
+
model_path = os.path.join(models_dir, filename)
|
| 488 |
+
if not os.path.exists(model_path):
|
| 489 |
+
print(f"Skipping evaluation for {name} (Model weights not found at {model_path})")
|
| 490 |
+
continue
|
| 491 |
+
|
| 492 |
+
model_obj = model_class(num_classes=2).to(device)
|
| 493 |
+
model_obj.load_state_dict(torch.load(model_path, map_location=device))
|
| 494 |
+
model_obj.eval()
|
| 495 |
+
|
| 496 |
+
print(f"\n--- Metrics for {name} ---")
|
| 497 |
+
|
| 498 |
+
# 1. EVAL ON TRAIN SET
|
| 499 |
+
train_fpr, train_tpr, train_auc, train_eer, train_min_dcf, train_acc = evaluate_model(
|
| 500 |
+
model_obj, train_loader, input_type=inp, device=device
|
| 501 |
+
)
|
| 502 |
+
print(f"[Train] Acc={train_acc*100:.2f}% | EER={train_eer*100:.2f}% | minDCF={train_min_dcf:.4f}")
|
| 503 |
+
|
| 504 |
+
# 2. EVAL ON TEST SET
|
| 505 |
+
test_fpr, test_tpr, test_auc, test_eer, test_min_dcf, test_acc = evaluate_model(
|
| 506 |
+
model_obj, test_loader, input_type=inp, device=device
|
| 507 |
+
)
|
| 508 |
+
print(f"[Test ] Acc={test_acc*100:.2f}% | EER={test_eer*100:.2f}% | minDCF={test_min_dcf:.4f}")
|
| 509 |
+
|
| 510 |
+
del model_obj
|
| 511 |
+
torch.cuda.empty_cache()
|
| 512 |
+
|
| 513 |
+
if __name__ == "__main__":
|
| 514 |
+
main()
|
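Note on the metrics printed above: `evaluate_model` returns an EER value, but its implementation is not part of this section of the diff. As a rough illustration only, the sketch below shows one common way an equal error rate can be computed from labels and scores. The function name `equal_error_rate`, the label convention (1 = fake, 0 = real), and the plain threshold sweep are assumptions made for this example, not the repository's code.

```python
def equal_error_rate(labels, scores):
    """Approximate the equal error rate (EER) with a plain threshold sweep.

    Assumed convention (hypothetical, not taken from the repository):
    labels are 1 for fake (positive) and 0 for real (negative), and a
    higher score means "more likely fake".
    """
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("EER needs at least one sample of each class")

    best_gap = float("inf")
    eer = None
    # Candidate thresholds: every distinct score, plus +inf (accept nothing).
    for thr in sorted(set(scores)) + [float("inf")]:
        # False positives: real samples scored at or above the threshold.
        fp = sum(1 for y, s in zip(labels, scores) if y == 0 and s >= thr)
        # False negatives: fake samples scored below the threshold.
        fn = sum(1 for y, s in zip(labels, scores) if y == 1 and s < thr)
        fpr = fp / n_neg
        fnr = fn / n_pos
        gap = abs(fpr - fnr)
        if gap < best_gap:
            # Keep the operating point where FPR and FNR are closest.
            best_gap = gap
            eer = (fpr + fnr) / 2
    return eer


if __name__ == "__main__":
    print(equal_error_rate([0, 0, 1, 1], [0.1, 0.6, 0.4, 0.9]))
```

The sweep is quadratic in the number of samples, which is acceptable for illustration. For a full test set, a ROC-based routine such as scikit-learn's `roc_curve` (scikit-learn is already listed in `requirements.txt`) would be the more usual choice.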
frontend/index.html
ADDED
|
@@ -0,0 +1,75 @@
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
|
| 4 |
+
<head>
|
| 5 |
+
<meta charset="UTF-8">
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 7 |
+
<title>OdioCheck | Deepfake Voice Detection</title>
|
| 8 |
+
<!-- Tailwind CSS -->
|
| 9 |
+
<script src="https://cdn.tailwindcss.com"></script>
|
| 10 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600;700&display=swap" rel="stylesheet">
|
| 11 |
+
<link rel="stylesheet" href="style.css">
|
| 12 |
+
<!-- Chart.js -->
|
| 13 |
+
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
|
| 14 |
+
</head>
|
| 15 |
+
|
| 16 |
+
<body
|
| 17 |
+
class="bg-slate-900 text-slate-100 font-sans min-h-screen flex flex-col items-center justify-center p-6 subtle-bg">
|
| 18 |
+
|
| 19 |
+
<div class="glass-card max-w-2xl w-full rounded-3xl p-8 relative overflow-hidden transition-all duration-300">
|
| 20 |
+
<!-- Glowing Orb Background -->
|
| 21 |
+
<div
|
| 22 |
+
class="absolute -top-32 -left-32 w-64 h-64 bg-indigo-600 rounded-full mix-blend-multiply filter blur-3xl opacity-30 animate-pulse">
|
| 23 |
+
</div>
|
| 24 |
+
<div class="absolute -bottom-32 -right-32 w-64 h-64 bg-fuchsia-600 rounded-full mix-blend-multiply filter blur-3xl opacity-30 animate-pulse"
|
| 25 |
+
style="animation-delay: 2s;"></div>
|
| 26 |
+
|
| 27 |
+
<div class="relative z-10">
|
| 28 |
+
<h1
|
| 29 |
+
class="text-4xl font-bold mb-2 text-transparent bg-clip-text bg-gradient-to-r from-indigo-400 to-cyan-300">
|
| 30 |
+
OdioCheck
|
| 31 |
+
</h1>
|
| 32 |
+
<p class="text-slate-400 mb-8 font-light">Advanced Deepfake Voice Detection powered by SOTA Graph
|
| 33 |
+
architecture.</p>
|
| 34 |
+
|
| 35 |
+
<div id="drop-zone"
|
| 36 |
+
class="border-2 border-dashed border-slate-600 rounded-2xl p-10 flex flex-col items-center justify-center cursor-pointer hover:border-indigo-400 hover:bg-slate-800/50 transition-all duration-300 group">
|
| 37 |
+
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5"
|
| 38 |
+
stroke="currentColor"
|
| 39 |
+
class="w-12 h-12 text-slate-500 group-hover:text-indigo-400 mb-4 transition-colors">
|
| 40 |
+
<path stroke-linecap="round" stroke-linejoin="round"
|
| 41 |
+
d="M12 18.75a6 6 0 006-6v-1.5m-6 7.5a6 6 0 01-6-6v-1.5m6 7.5v3.75m-3.75 0h7.5M12 15.75a3 3 0 01-3-3V4.5a3 3 0 116 0v8.25a3 3 0 01-3 3z" />
|
| 42 |
+
</svg>
|
| 43 |
+
<p class="text-lg text-slate-300 font-medium">Click to upload or drag &amp; drop</p>
|
| 44 |
+
<p class="text-sm text-slate-500 mt-1">Supports WAV, OGG, MP3, FLAC, M4A &amp; more</p>
|
| 45 |
+
<input type="file" id="file-input" class="hidden" accept="audio/*">
|
| 46 |
+
</div>
|
| 47 |
+
|
| 48 |
+
<!-- Analysis Section -->
|
| 49 |
+
<div id="analysis-section" class="mt-8 hidden opacity-0 transition-opacity duration-500">
|
| 50 |
+
<div class="flex items-center space-x-4 mb-6">
|
| 51 |
+
<div id="loading-spinner" class="hidden">
|
| 52 |
+
<div class="animate-spin rounded-full h-8 w-8 border-b-2 border-indigo-400"></div>
|
| 53 |
+
</div>
|
| 54 |
+
<h2 id="status-text" class="text-xl font-semibold text-slate-300">Analyzing Spectrogram...</h2>
|
| 55 |
+
</div>
|
| 56 |
+
|
| 57 |
+
<!-- Results: panels will be inserted via JavaScript based on response keys -->
|
| 58 |
+
<div id="results" class="hidden">
|
| 59 |
+
<div id="model-panels" class="grid grid-cols-2 gap-6"></div>
|
| 60 |
+
</div>
|
| 61 |
+
</div>
|
| 62 |
+
</div>
|
| 63 |
+
</div>
|
| 64 |
+
|
| 65 |
+
<!-- Additional Graph Section for wow factor -->
|
| 66 |
+
<div id="chart-card"
|
| 67 |
+
class="glass-card max-w-2xl w-full rounded-3xl p-8 mt-6 hidden opacity-0 transition-opacity duration-500">
|
| 68 |
+
<h3 class="text-lg font-semibold mb-4 text-slate-300">Timeline Analysis</h3>
|
| 69 |
+
<canvas id="audioChart" height="100"></canvas>
|
| 70 |
+
</div>
|
| 71 |
+
|
| 72 |
+
<script src="script.js"></script>
|
| 73 |
+
</body>
|
| 74 |
+
|
| 75 |
+
</html>
|
frontend/script.js
ADDED
|
@@ -0,0 +1,243 @@
|
| 1 |
+
const dropZone = document.getElementById('drop-zone');
|
| 2 |
+
const fileInput = document.getElementById('file-input');
|
| 3 |
+
const analysisSection = document.getElementById('analysis-section');
|
| 4 |
+
const statusText = document.getElementById('status-text');
|
| 5 |
+
const results = document.getElementById('results');
|
| 6 |
+
const loadingSpinner = document.getElementById('loading-spinner');
|
| 7 |
+
const chartCard = document.getElementById('chart-card');
|
| 8 |
+
|
| 9 |
+
// -------------------------------------------------------
|
| 10 |
+
// Chart setup
|
| 11 |
+
// -------------------------------------------------------
|
| 12 |
+
const ctx = document.getElementById('audioChart').getContext('2d');
|
| 13 |
+
let audioChart = new Chart(ctx, {
|
| 14 |
+
type: 'line',
|
| 15 |
+
data: { labels: [], datasets: [] },
|
| 16 |
+
options: {
|
| 17 |
+
responsive: true,
|
| 18 |
+
animation: { duration: 600, easing: 'easeInOutQuart' },
|
| 19 |
+
plugins: {
|
| 20 |
+
legend: { display: true, labels: { color: '#94a3b8', font: { size: 12 } } },
|
| 21 |
+
tooltip: {
|
| 22 |
+
callbacks: {
|
| 23 |
+
label: item => ` ${item.dataset.label}: ${item.parsed.y.toFixed(1)}% fake`,
|
| 24 |
+
title: items => `Segment @ ${items[0].label}s`
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
},
|
| 28 |
+
scales: {
|
| 29 |
+
y: {
|
| 30 |
+
beginAtZero: true,
|
| 31 |
+
max: 100,
|
| 32 |
+
ticks: { color: '#94a3b8', callback: v => v + '%' },
|
| 33 |
+
grid: { color: 'rgba(148,163,184,0.1)' },
|
| 34 |
+
title: { display: true, text: 'Fake Probability (%)', color: '#64748b' }
|
| 35 |
+
},
|
| 36 |
+
x: {
|
| 37 |
+
ticks: {
|
| 38 |
+
color: '#94a3b8', callback: (_, i, ticks) => {
|
| 39 |
+
// Show fewer labels when there are many windows
|
| 40 |
+
const step = Math.max(1, Math.floor(ticks.length / 8));
|
| 41 |
+
return i % step === 0 ? audioChart.data.labels[i] + 's' : '';
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
grid: { color: 'rgba(148,163,184,0.05)' },
|
| 45 |
+
title: { display: true, text: 'Time (seconds)', color: '#64748b' }
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
});
|
| 50 |
+
|
| 51 |
+
// Palette and display names for the four models
|
| 52 |
+
const MODEL_META = {
|
| 53 |
+
wav2vec2: { label: 'Wav2Vec2', color: '#3b82f6' },
|
| 54 |
+
aasist: { label: 'AASIST', color: '#f43f5e' },
|
| 55 |
+
cqcc_baseline: { label: 'CQCC Baseline', color: '#fbbf24' },
|
| 56 |
+
custom_hybrid: { label: 'Proposed Custom Hybrid', color: '#10b981' },
|
| 57 |
+
};
|
| 58 |
+
|
| 59 |
+
// -------------------------------------------------------
|
| 60 |
+
// File handling
|
| 61 |
+
// -------------------------------------------------------
|
| 62 |
+
function handleFile(file) {
|
| 63 |
+
if (!file) return;
|
| 64 |
+
|
| 65 |
+
// Show sections
|
| 66 |
+
analysisSection.classList.remove('hidden');
|
| 67 |
+
chartCard.classList.remove('hidden');
|
| 68 |
+
setTimeout(() => {
|
| 69 |
+
analysisSection.classList.remove('opacity-0');
|
| 70 |
+
chartCard.classList.remove('opacity-0');
|
| 71 |
+
}, 50);
|
| 72 |
+
|
| 73 |
+
results.classList.add('hidden');
|
| 74 |
+
loadingSpinner.classList.remove('hidden');
|
| 75 |
+
statusText.innerText = `Analyzing "${file.name}"…`;
|
| 76 |
+
|
| 77 |
+
// Clear previous state
|
| 78 |
+
document.getElementById('model-panels').innerHTML = '';
|
| 79 |
+
audioChart.data.labels = [];
|
| 80 |
+
audioChart.data.datasets = [];
|
| 81 |
+
audioChart.update();
|
| 82 |
+
|
| 83 |
+
// Animated placeholder while waiting: a single pulsing dataset
|
| 84 |
+
const placeholder = {
|
| 85 |
+
label: 'Analyzing…',
|
| 86 |
+
data: Array.from({ length: 20 }, (_, i) => 45 + Math.sin(i / 2) * 10),
|
| 87 |
+
borderColor: 'rgba(99,102,241,0.5)',
|
| 88 |
+
backgroundColor: 'rgba(99,102,241,0.05)',
|
| 89 |
+
borderDash: [4, 4],
|
| 90 |
+
fill: true,
|
| 91 |
+
tension: 0.4,
|
| 92 |
+
pointRadius: 0,
|
| 93 |
+
};
|
| 94 |
+
audioChart.data.labels = Array.from({ length: 20 }, (_, i) => i);
|
| 95 |
+
audioChart.data.datasets = [placeholder];
|
| 96 |
+
audioChart.update();
|
| 97 |
+
|
| 98 |
+
let tick = 0;
|
| 99 |
+
const loadingAnim = setInterval(() => {
|
| 100 |
+
tick++;
|
| 101 |
+
placeholder.data = Array.from({ length: 20 }, (_, i) =>
|
| 102 |
+
45 + Math.sin((i + tick) / 2) * 10
|
| 103 |
+
);
|
| 104 |
+
audioChart.update('none'); // skip animation for perf
|
| 105 |
+
}, 80);
|
| 106 |
+
|
| 107 |
+
const formData = new FormData();
|
| 108 |
+
formData.append('file', file);
|
| 109 |
+
|
| 110 |
+
const HF_API_URL = window.location.hostname === '127.0.0.1' || window.location.hostname === 'localhost'
|
| 111 |
+
? '/api/predict'
|
| 112 |
+
: 'https://junsiang26-odiocheck-backend.hf.space/api/predict';
|
| 113 |
+
|
| 114 |
+
fetch(HF_API_URL, { method: 'POST', body: formData })
|
| 115 |
+
.then(r => r.json())
|
| 116 |
+
.then(data => {
|
| 117 |
+
clearInterval(loadingAnim);
|
| 118 |
+
loadingSpinner.classList.add('hidden');
|
| 119 |
+
|
| 120 |
+
if (data.error) {
|
| 121 |
+
statusText.innerText = 'Error analyzing file.';
|
| 122 |
+
console.error(data.error);
|
| 123 |
+
return;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
renderResults(data);
|
| 127 |
+
})
|
| 128 |
+
.catch(() => {
|
| 129 |
+
clearInterval(loadingAnim);
|
| 130 |
+
loadingSpinner.classList.add('hidden');
|
| 131 |
+
statusText.innerText = 'Connection error. Is the backend running?';
|
| 132 |
+
});
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
// -------------------------------------------------------
|
| 136 |
+
// Render results from the new response shape:
|
| 137 |
+
// data.overall → { model: { prediction, fake_probability, real_probability } }
|
| 138 |
+
// data.timeline → { model: [fake_prob_pct, ...] }
|
| 139 |
+
// data.window_labels → [segment_start_sec, ...]
|
| 140 |
+
// -------------------------------------------------------
|
| 141 |
+
function renderResults(data) {
|
| 142 |
+
const { overall, timeline, window_labels } = data;
|
| 143 |
+
|
| 144 |
+
statusText.innerText = 'Analysis Complete';
|
| 145 |
+
results.classList.remove('hidden');
|
| 146 |
+
|
| 147 |
+
// --- Model panels (overall verdict) ---
|
| 148 |
+
const panelsEl = document.getElementById('model-panels');
|
| 149 |
+
panelsEl.innerHTML = '';
|
| 150 |
+
|
| 151 |
+
for (const [key, info] of Object.entries(overall)) {
|
| 152 |
+
const meta = MODEL_META[key] || { label: key, color: '#94a3b8' };
|
| 153 |
+
const isFake = info.prediction === 'FAKE';
|
| 154 |
+
const barColor = isFake ? 'from-rose-500 to-rose-400' : 'from-emerald-400 to-emerald-500';
|
| 155 |
+
const displayPct = isFake ? info.fake_probability : info.real_probability;
|
| 156 |
+
|
| 157 |
+
panelsEl.insertAdjacentHTML('beforeend', `
|
| 158 |
+
<div>
|
| 159 |
+
<div class="flex justify-between items-end mb-2">
|
| 160 |
+
<span class="text-sm text-slate-400 uppercase tracking-widest font-semibold"
|
| 161 |
+
style="color:${meta.color}">${meta.label}</span>
|
| 162 |
+
<span class="text-3xl font-bold tracking-wider ${isFake ? 'text-rose-500' : 'text-emerald-500'}">
|
| 163 |
+
${info.prediction}
|
| 164 |
+
</span>
|
| 165 |
+
</div>
|
| 166 |
+
<div class="text-xs text-slate-500 mb-2">
|
| 167 |
+
Fake: <span class="text-slate-300">${info.fake_probability}%</span>
|
| 168 |
+
·
|
| 169 |
+
Real: <span class="text-slate-300">${info.real_probability}%</span>
|
| 170 |
+
</div>
|
| 171 |
+
<div class="w-full bg-slate-700 h-4 rounded-full overflow-hidden mb-6 mt-1">
|
| 172 |
+
<div class="prob-bar h-full bg-gradient-to-r transition-all duration-1000 ease-out ${barColor}"
|
| 173 |
+
style="width:0%"
|
| 174 |
+
data-width="${displayPct}">
|
| 175 |
+
</div>
|
| 176 |
+
</div>
|
| 177 |
+
</div>`);
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
// Animate bars
|
| 181 |
+
requestAnimationFrame(() => {
|
| 182 |
+
document.querySelectorAll('.prob-bar').forEach(bar => {
|
| 183 |
+
bar.style.width = bar.dataset.width + '%';
|
| 184 |
+
});
|
| 185 |
+
});
|
| 186 |
+
|
| 187 |
+
// --- Timeline chart (real data) ---
|
| 188 |
+
// window_labels are now start-of-segment times (0, 2, 4 ...)
|
| 189 |
+
// For short audio with a single window, we pad with the audio-end label
|
| 190 |
+
// so the chart shows a line rather than a lonely dot.
|
| 191 |
+
let labels = [...window_labels];
|
| 192 |
+
let timelineValues = {};
|
| 193 |
+
Object.entries(timeline).forEach(([k, v]) => { timelineValues[k] = [...v]; });
|
| 194 |
+
|
| 195 |
+
if (labels.length === 1) {
|
| 196 |
+
// Estimate audio duration: single window = TARGET_LEN / 16000 ≈ 4.025s
|
| 197 |
+
const audioEnd = parseFloat((labels[0] + 4.025).toFixed(2));
|
| 198 |
+
labels.push(audioEnd);
|
| 199 |
+
Object.keys(timelineValues).forEach(k => timelineValues[k].push(timelineValues[k][0]));
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
audioChart.data.labels = labels;
|
| 203 |
+
audioChart.data.datasets = Object.entries(timelineValues).map(([key, values]) => {
|
| 204 |
+
const meta = MODEL_META[key] || { label: key, color: '#94a3b8' };
|
| 205 |
+
const hex = meta.color;
|
| 206 |
+
const rgb = hex.match(/[0-9a-fA-F]{2}/g).map(h => parseInt(h, 16)).join(',');
|
| 207 |
+
return {
|
| 208 |
+
label: meta.label,
|
| 209 |
+
data: values,
|
| 210 |
+
borderColor: hex,
|
| 211 |
+
backgroundColor: `rgba(${rgb},0.08)`,
|
| 212 |
+
fill: true,
|
| 213 |
+
tension: 0.4,
|
| 214 |
+
pointRadius: values.length <= 20 ? 4 : 2,
|
| 215 |
+
pointHoverRadius: 6,
|
| 216 |
+
};
|
| 217 |
+
});
|
| 218 |
+
|
| 219 |
+
// Add a 50% threshold reference line
|
| 220 |
+
audioChart.data.datasets.push({
|
| 221 |
+
label: 'Decision threshold (50%)',
|
| 222 |
+
data: Array(labels.length).fill(50),
|
| 223 |
+
borderColor: 'rgba(255,255,255,0.2)',
|
| 224 |
+
borderDash: [6, 4],
|
| 225 |
+
borderWidth: 1,
|
| 226 |
+
pointRadius: 0,
|
| 227 |
+
fill: false,
|
| 228 |
+
tension: 0,
|
| 229 |
+
});
|
| 230 |
+
|
| 231 |
+
audioChart.update();
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
// -------------------------------------------------------
|
| 235 |
+
// Drop zone wiring
|
| 236 |
+
// -------------------------------------------------------
|
| 237 |
+
dropZone.addEventListener('click', () => fileInput.click());
|
| 238 |
+
fileInput.addEventListener('change', e => handleFile(e.target.files[0]));
|
| 239 |
+
|
| 240 |
+
['dragenter', 'dragover', 'dragleave', 'drop'].forEach(name => {
|
| 241 |
+
dropZone.addEventListener(name, e => { e.preventDefault(); e.stopPropagation(); });
|
| 242 |
+
});
|
| 243 |
+
dropZone.addEventListener('drop', e => handleFile(e.dataTransfer.files[0]));
|
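Note on the response shape: `renderResults` expects `overall`, `timeline`, and `window_labels`, but the payload is only described in comments in `script.js`. The Python sketch below shows how a backend could assemble such a payload from per-window fake probabilities. The helper name `build_response`, the mean aggregation, the 2-second hop, and the rule that exactly 50% counts as FAKE are all assumptions for illustration. The real logic presumably lives in `backend/app.py`, which is not shown in this section.

```python
def build_response(window_probs, hop_seconds=2.0):
    """Assemble a payload in the shape the frontend's renderResults reads.

    window_probs: {model_key: [fake probability in [0, 1] for each window]}.
    Hypothetical helper: the mean aggregation and the 2 s hop are
    assumptions, not the repository's actual behaviour.
    """
    lengths = {len(probs) for probs in window_probs.values()}
    if len(lengths) != 1 or 0 in lengths:
        raise ValueError("every model needs the same non-zero window count")
    n_windows = lengths.pop()

    overall = {}
    timeline = {}
    for key, probs in window_probs.items():
        # Overall verdict: mean fake probability, expressed as a percentage.
        fake_pct = round(100.0 * sum(probs) / n_windows, 2)
        overall[key] = {
            "prediction": "FAKE" if fake_pct >= 50.0 else "REAL",
            "fake_probability": fake_pct,
            "real_probability": round(100.0 - fake_pct, 2),
        }
        # Timeline: one fake-probability percentage per analysis window.
        timeline[key] = [round(100.0 * p, 2) for p in probs]

    # Labels are segment start times in seconds: 0, 2, 4, ...
    window_labels = [round(i * hop_seconds, 2) for i in range(n_windows)]
    return {
        "overall": overall,
        "timeline": timeline,
        "window_labels": window_labels,
    }


if __name__ == "__main__":
    print(build_response({"aasist": [0.25, 0.75]}))
```

Percentages are used throughout because the chart's y-axis and the probability bars in `script.js` both read values on a 0 to 100 scale.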
frontend/style.css
ADDED
|
@@ -0,0 +1,34 @@
|
| 1 |
+
/* Glassmorphism utility classes */
|
| 2 |
+
.glass-card {
|
| 3 |
+
background: rgba(30, 41, 59, 0.7);
|
| 4 |
+
backdrop-filter: blur(12px);
|
| 5 |
+
-webkit-backdrop-filter: blur(12px);
|
| 6 |
+
border: 1px solid rgba(255, 255, 255, 0.1);
|
| 7 |
+
box-shadow: 0 4px 30px rgba(0, 0, 0, 0.1);
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
.subtle-bg {
|
| 11 |
+
background-color: #0f172a;
|
| 12 |
+
background-image:
|
| 13 |
+
radial-gradient(at 0% 0%, hsla(253, 16%, 7%, 1) 0, transparent 50%),
|
| 14 |
+
radial-gradient(at 50% 0%, hsla(225, 39%, 30%, 1) 0, transparent 50%),
|
| 15 |
+
radial-gradient(at 100% 0%, hsla(339, 49%, 30%, 1) 0, transparent 50%);
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
.animate-pulse {
|
| 19 |
+
animation: pulse 4s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
@keyframes pulse {
|
| 23 |
+
|
| 24 |
+
0%,
|
| 25 |
+
100% {
|
| 26 |
+
opacity: 0.3;
|
| 27 |
+
transform: scale(1);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
50% {
|
| 31 |
+
opacity: 0.5;
|
| 32 |
+
transform: scale(1.05);
|
| 33 |
+
}
|
| 34 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
| 1 |
+
datasets == 2.21.0
|
| 2 |
+
fastapi
|
| 3 |
+
librosa
|
| 4 |
+
matplotlib
|
| 5 |
+
numpy
|
| 6 |
+
python-multipart
|
| 7 |
+
python-pptx
|
| 8 |
+
scikit-learn
|
| 9 |
+
scipy
|
| 10 |
+
seaborn
|
| 11 |
+
soundfile
|
| 12 |
+
torch>=2.6.0
|
| 13 |
+
torchaudio
|
| 14 |
+
torchvision
|
| 15 |
+
tqdm
|
| 16 |
+
transformers
|
| 17 |
+
uvicorn
|