File size: 8,380 Bytes
9f2b6db
2c3f313
 
 
 
 
9f2b6db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409c94a
 
 
 
9f2b6db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
import os
import sys

# Add the current directory to sys.path so we can import local modules
sys.path.append(os.path.dirname(__file__))

import torch
import numpy as np
import torch.nn.functional as F
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from dataset import compute_cqcc
import librosa

from models import (
    Wav2Vec2SpoofDetector,
    AASISTDetector,
    CQCCBaselineDetector,
    ImprovedWav2Vec2CQCCDetector
)

# ASGI application object; the routes, middleware and static mount in this
# module are all registered on it.
app = FastAPI(title="Deepfake Voice Detection")

@app.get("/health")
def health():
    """Liveness probe: report that the service is reachable."""
    status_payload = {"status": "online"}
    return status_payload

# Cross-origin access: every origin, HTTP method and header is accepted.
# NOTE(review): the FastAPI CORS documentation advises against combining
# wildcard origins with allow_credentials=True — confirm whether credentialed
# requests are actually needed, or list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Run inference on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# -------------------------------------------------------
# Load Models
# -------------------------------------------------------

# Directory next to this module from which load_model() reads checkpoint files.
models_dir = os.path.join(os.path.dirname(__file__), "models")

# Key suffixes produced by the two weight-norm implementations in PyTorch:
# the parametrization API (original0 = magnitude, original1 = direction) and
# the legacy hook API (weight_g = magnitude, weight_v = direction).
_WEIGHT_NORM_KEY_PAIRS = (
    ("parametrizations.weight.original0", "weight_g"),
    ("parametrizations.weight.original1", "weight_v"),
)


def _remap_weight_norm_keys(state_dict, expected_keys):
    """Rename weight-norm keys in *state_dict* to the spelling found in *expected_keys*."""
    remapped = {}
    for key, value in state_dict.items():
        target = key
        # Only touch keys the model does not recognise as-is.
        if key not in expected_keys:
            for new_style, old_style in _WEIGHT_NORM_KEY_PAIRS:
                if key.endswith(new_style):
                    candidate = key[: -len(new_style)] + old_style
                elif key.endswith(old_style):
                    candidate = key[: -len(old_style)] + new_style
                else:
                    continue
                # Rename only when the model really expects the other spelling;
                # otherwise keep the key so load_state_dict reports it verbatim.
                if candidate in expected_keys:
                    target = candidate
                    break
        remapped[target] = value
    return remapped


def load_model(model, filename):
    """
    Restore *model*'s weights from ``models_dir/filename`` and switch it to eval mode.

    Checkpoints saved with a different weight-norm implementation than the one
    the instantiated model uses are handled by renaming the affected keys
    (``parametrizations.weight.original0/1`` <-> ``weight_g``/``weight_v``) in
    whichever direction the model's own state dict requires.

    Args:
        model: an ``nn.Module`` already placed on the target device.
        filename: checkpoint file name inside ``models_dir``.

    Returns:
        The same *model* object, in eval mode. When the checkpoint file is
        missing, a warning is printed and the model keeps the weights it was
        constructed with.

    Raises:
        RuntimeError: propagated from ``load_state_dict`` when the checkpoint
            keys or shapes still do not match the model after remapping.
    """
    path = os.path.join(models_dir, filename)
    if os.path.exists(path):
        # NOTE(review): torch.load unpickles the file — only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        state_dict = torch.load(path, map_location=device)
        state_dict = _remap_weight_norm_keys(state_dict, set(model.state_dict()))
        model.load_state_dict(state_dict)
        print(f"Loaded {filename}")
    else:
        print(f"WARNING: {filename} not found. Run train.py first!")
    model.eval()

    return model


# Each detector is first built on the target device, then handed to
# load_model(), which restores its checkpoint and returns it in eval mode.

wav2vec_model = Wav2Vec2SpoofDetector(num_classes=2).to(device)
wav2vec_model = load_model(wav2vec_model, "wav2vec2.pth")

aasist_model = AASISTDetector(num_classes=2).to(device)
aasist_model = load_model(aasist_model, "aasist.pth")

cqcc_baseline_model = CQCCBaselineDetector(num_classes=2).to(device)
cqcc_baseline_model = load_model(cqcc_baseline_model, "cqcc_baseline.pth")

custom_hybrid_model = ImprovedWav2Vec2CQCCDetector(num_classes=2).to(device)
custom_hybrid_model = load_model(custom_hybrid_model, "custom_hybrid.pth")


# -------------------------------------------------------
# Audio Preprocessing (mirrors dataset.py __getitem__)
# -------------------------------------------------------

TARGET_LEN = 64600  # Samples per window (AASIST standard): 64600 / 16000 Hz = 4.0375 s
CQCC_N_BINS = 60    # n_bins passed to compute_cqcc; presumably matches the AudioDataset default — verify against dataset.py

# 50% overlap: each step is half a window (~2s), giving smooth temporal curves
# without running 4x Wav2Vec2 passes per second.
WINDOW_STEP = TARGET_LEN // 2


def preprocess_window(wav_np: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fit one audio window to exactly TARGET_LEN samples and build the model inputs.

    Input longer than TARGET_LEN is center-cropped; shorter input is zero-padded
    on the right. Intended to mirror the non-augmented path of
    AudioDataset.__getitem__ in dataset.py (not verifiable from this module).

    Returns:
        wav  : float32 tensor of the fitted window with a leading axis added
        cqcc : the value returned by compute_cqcc for the fitted window
               (the previous note gave its shape as (1, 20, T) — TODO confirm)
    """
    surplus = len(wav_np) - TARGET_LEN
    if surplus > 0:
        # Too long: keep the middle TARGET_LEN samples.
        offset = surplus // 2
        wav_np = wav_np[offset : offset + TARGET_LEN]
    elif surplus < 0:
        # Too short: append zeros until the window is full.
        wav_np = np.pad(wav_np, (0, -surplus), mode='constant')

    waveform = torch.unsqueeze(torch.from_numpy(wav_np), 0).float()
    features = compute_cqcc(wav_np, n_bins=CQCC_N_BINS)
    return waveform, features


def run_window(wav: torch.Tensor, cqcc: torch.Tensor) -> dict:
    """
    Score one preprocessed window with all four detectors.

    Both inputs get a leading axis and are moved to the inference device. Each
    model output is passed through a softmax over dim 1; entry [0][1] is taken
    as the fake probability, scaled to 0–100 and rounded to two decimals.

    Returns:
        dict keyed by model name ("wav2vec2", "aasist", "cqcc_baseline",
        "custom_hybrid") with the fake probability in percent.
    """
    wav_batch = wav.unsqueeze(0).to(device)
    cqcc_batch = cqcc.unsqueeze(0).to(device)

    def fake_percent(logits):
        # Class index 1 is reported as "fake".
        return round(torch.softmax(logits, dim=1)[0][1].item() * 100, 2)

    # Gradients are not needed for inference.
    with torch.no_grad():
        scores = {
            "wav2vec2": fake_percent(wav2vec_model(wav_batch)),
            "aasist": fake_percent(aasist_model(wav_batch)),
            "cqcc_baseline": fake_percent(cqcc_baseline_model(cqcc_batch)),
            "custom_hybrid": fake_percent(custom_hybrid_model(wav_batch, cqcc_batch)),
        }

    return scores


def aggregate_prediction(fake_prob_pct: float) -> dict:
    """Build the standard prediction dict from a fake probability in percent.

    The label is "FAKE" only when the value is strictly greater than 50;
    anything else (including exactly 50) is labelled "REAL". The real
    probability is the complement to 100, rounded to two decimals.
    """
    if fake_prob_pct > 50:
        label = "FAKE"
    else:
        label = "REAL"

    real_prob_pct = round(100 - fake_prob_pct, 2)

    result = {"prediction": label}
    result["fake_probability"] = fake_prob_pct
    result["real_probability"] = real_prob_pct
    return result


# -------------------------------------------------------
# Prediction Endpoint
# -------------------------------------------------------
@app.post("/api/predict")
async def predict(file: UploadFile = File(...)):
    """
    Score an uploaded audio file with all four detectors.

    The upload is written to a uniquely named temporary file, decoded to
    16 kHz mono, cut into TARGET_LEN-sample windows that advance by
    WINDOW_STEP, and every window is scored by run_window().

    Returns:
        JSON with three fields on success:
          overall       : {model: {prediction, fake_probability, real_probability}}
                          computed from the mean fake probability over all windows
          timeline      : {model: [fake_prob_pct, ...]}, one value per window
          window_labels : [start_sec, ...], the start of each window in seconds
        A 400 response with an "error" field when the decoded audio has no
        samples, and a 500 response with an "error" field for any other failure.

    NOTE(review): decoding and inference are synchronous calls inside an
    ``async def`` handler, so they run on the event loop — consider a worker
    thread if other requests must stay responsive during inference.
    """
    import tempfile  # function-scope import: only this handler needs it

    # The client-supplied file name is untrusted: never use it as a path.
    # Keep only a short, plain extension as a format hint for the decoder.
    extension = os.path.splitext(os.path.basename(file.filename or ""))[1]
    if not (
        1 < len(extension) <= 10
        and extension[1:].isascii()
        and extension[1:].isalnum()
    ):
        extension = ""

    temp_path = None
    try:
        # A unique temp file avoids collisions between concurrent uploads that
        # share a file name. delete=False so it can be reopened by path below.
        with tempfile.NamedTemporaryFile(
            prefix="upload_", suffix=extension, delete=False
        ) as tmp:
            temp_path = tmp.name
            tmp.write(await file.read())

        # Load at 16 kHz mono (intended to mirror the loading done in dataset.py)
        wav_np, sr = librosa.load(temp_path, sr=16000, mono=True)

        total_samples = len(wav_np)
        if total_samples == 0:
            # Without this guard no window is produced and the mean below
            # would fail with a division by zero.
            return JSONResponse(
                {"error": "Uploaded audio contains no samples."},
                status_code=400,
            )

        # -------------------------------------------------------
        # Slice into overlapping windows of TARGET_LEN samples.
        # Step = 50% overlap.  Very short clips produce a single window.
        # NOTE(review): trailing windows can consist mostly of zero padding;
        # confirm that weighting them equally in the mean is intended.
        # -------------------------------------------------------
        window_probs = []    # per-window fake-probability dicts
        window_labels = []   # x-axis: start of each window in seconds

        for start in range(0, total_samples, WINDOW_STEP):
            chunk = wav_np[start : start + TARGET_LEN]
            wav_t, cqcc_t = preprocess_window(chunk)
            window_probs.append(run_window(wav_t, cqcc_t))
            window_labels.append(round(start / sr, 2))

        # -------------------------------------------------------
        # Overall prediction = mean fake probability across all windows
        # -------------------------------------------------------
        model_keys = ["wav2vec2", "aasist", "cqcc_baseline", "custom_hybrid"]
        overall = {}
        for key in model_keys:
            mean_fake = round(
                sum(w[key] for w in window_probs) / len(window_probs), 2
            )
            overall[key] = aggregate_prediction(mean_fake)

        # -------------------------------------------------------
        # Time-series data for the frontend chart
        # -------------------------------------------------------
        timeline = {
            key: [w[key] for w in window_probs]
            for key in model_keys
        }

        return JSONResponse({
            "overall":       overall,         # {model: {prediction, fake_probability, real_probability}}
            "timeline":      timeline,        # {model: [fake_prob_pct, ...]}  — one value per window
            "window_labels": window_labels,   # [start_sec, ...]               — x-axis in seconds (starts at 0.0)
        })

    except Exception as e:
        # Top-level boundary: report the failure to the client as a 500.
        return JSONResponse({"error": str(e)}, status_code=500)

    finally:
        # Always remove the temporary upload, on success and on failure.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)

# -------------------------------------------------------
# Serve frontend
# -------------------------------------------------------

# The static frontend is looked up in a "frontend" directory one level above
# the directory containing this module.
backend_dir = os.path.dirname(__file__)
frontend_dir = os.path.join(backend_dir, "..", "frontend")

# Mount it at the root only when the path exists.
if os.path.exists(frontend_dir):
    frontend_files = StaticFiles(directory=frontend_dir, html=True)
    app.mount("/", frontend_files, name="frontend")


# -------------------------------------------------------
# Run Server
# -------------------------------------------------------

if __name__ == "__main__":
    # uvicorn is imported only when this module is executed as a script.
    import uvicorn

    server_options = {"host": "127.0.0.1", "port": 8000}

    print("Starting server at http://127.0.0.1:8000")
    uvicorn.run(app, **server_options)