import os
import sys
# Add the current directory to sys.path so we can import local modules
sys.path.append(os.path.dirname(__file__))
import torch
import numpy as np
import torch.nn.functional as F
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from dataset import compute_cqcc
import librosa
from models import (
Wav2Vec2SpoofDetector,
AASISTDetector,
CQCCBaselineDetector,
ImprovedWav2Vec2CQCCDetector
)
# ASGI application object; the routes and middleware below are registered on it.
app = FastAPI(title="Deepfake Voice Detection")


@app.get("/health")
def health():
    # Liveness probe: always answers with the same constant payload.
    status_payload = {"status": "online"}
    return status_payload


# NOTE(review): wildcard origins/methods/headers together with
# allow_credentials=True is a very permissive CORS policy — confirm that this
# is intended for the deployment target.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
# -------------------------------------------------------
# Load Models
# -------------------------------------------------------
# Checkpoints are looked up in a "models" folder next to this file.
models_dir = os.path.join(
    os.path.dirname(__file__),
    "models",
)
# Suffix pairs naming the same weight_norm tensors under the two PyTorch
# conventions: (parametrization-style, legacy-style).
_WEIGHT_NORM_ALIASES = (
    ("parametrizations.weight.original0", "weight_g"),
    ("parametrizations.weight.original1", "weight_v"),
)


def _resolve_weight_norm_key(key, expected_keys):
    """Return the name under which `key` should be loaded into the model.

    A key is renamed only if the model does not already expect it, it belongs
    to `pos_conv_embed.conv`, and the alternative weight_norm spelling is a
    key the model does expect. Otherwise the key is returned unchanged.
    """
    if key in expected_keys:
        return key
    for new_style, legacy in _WEIGHT_NORM_ALIASES:
        # Try both directions so either checkpoint flavour can be loaded.
        for src, dst in ((new_style, legacy), (legacy, new_style)):
            if f"pos_conv_embed.conv.{src}" in key:
                candidate = key.replace(src, dst)
                if candidate in expected_keys:
                    return candidate
    return key


def load_model(model, filename):
    """Load weights from `models_dir/filename` into `model` and set eval mode.

    If the checkpoint file is missing, a warning is printed and the model is
    returned with whatever weights it was constructed with.

    The weight_norm keys of `pos_conv_embed.conv` can be stored either as
    `parametrizations.weight.original0/1` or as `weight_g`/`weight_v`,
    depending on the library versions used for training. Keys are renamed to
    whichever spelling the instantiated model expects, so a checkpoint saved
    under one convention loads under the other. The previous unconditional
    rename to `weight_g`/`weight_v` failed whenever the model itself used the
    parametrization-style names.

    Args:
        model: module instance to populate (already moved to `device`).
        filename: checkpoint file name inside `models_dir`.

    Returns:
        The same `model` object, switched to evaluation mode.

    Raises:
        Whatever `torch.load` / `load_state_dict` raise for unreadable or
        incompatible checkpoints.
    """
    path = os.path.join(models_dir, filename)
    if os.path.exists(path):
        state_dict = torch.load(path, map_location=device)
        expected_keys = set(model.state_dict().keys())
        new_state_dict = {
            _resolve_weight_norm_key(k, expected_keys): v
            for k, v in state_dict.items()
        }
        model.load_state_dict(new_state_dict)
        print(f"Loaded {filename}")
    else:
        print(f"WARNING: {filename} not found. Run train.py first!")
    model.eval()
    return model
# Instantiate every detector on `device` and load its checkpoint (if present).
# Each call returns the model in eval mode.
wav2vec_model = load_model(
    Wav2Vec2SpoofDetector(num_classes=2).to(device), "wav2vec2.pth"
)
aasist_model = load_model(
    AASISTDetector(num_classes=2).to(device), "aasist.pth"
)
cqcc_baseline_model = load_model(
    CQCCBaselineDetector(num_classes=2).to(device), "cqcc_baseline.pth"
)
custom_hybrid_model = load_model(
    ImprovedWav2Vec2CQCCDetector(num_classes=2).to(device), "custom_hybrid.pth"
)
# -------------------------------------------------------
# Audio Preprocessing (mirrors dataset.py __getitem__)
# -------------------------------------------------------
# Window length in samples. At the 16 kHz rate used when loading audio this is
# 64600 / 16000 ≈ 4.04 s (the original note describes it as the AASIST
# standard input length).
TARGET_LEN = 64_600
# Number of CQCC bins requested from compute_cqcc; intended to match the
# AudioDataset default — confirm against dataset.py.
CQCC_N_BINS = 60
# Hop between consecutive windows: half a window (about 2 s), i.e. 50% overlap.
# This keeps the per-window timeline reasonably smooth while limiting the
# number of model passes per second of audio.
WINDOW_STEP = TARGET_LEN // 2
def preprocess_window(wav_np: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Bring one audio window to exactly TARGET_LEN samples and build model inputs.

    Longer inputs are center-cropped; shorter inputs are zero-padded on the
    right. The original note states this mirrors the non-augmented path of
    AudioDataset.__getitem__ in dataset.py.

    Returns:
        wav  : float tensor of shape (1, TARGET_LEN)
        cqcc : whatever compute_cqcc returns for the window
               (described upstream as a (1, 20, T) float32 tensor — TODO confirm)
    """
    n_samples = len(wav_np)
    if n_samples > TARGET_LEN:
        # Keep the middle TARGET_LEN samples.
        offset = (n_samples - TARGET_LEN) // 2
        wav_np = wav_np[offset : offset + TARGET_LEN]
    elif n_samples < TARGET_LEN:
        # Append zeros until the window is full length.
        missing = TARGET_LEN - n_samples
        wav_np = np.pad(wav_np, (0, missing), mode='constant')

    wav = torch.from_numpy(wav_np).unsqueeze(0).float()
    cqcc = compute_cqcc(wav_np, n_bins=CQCC_N_BINS)
    return wav, cqcc
def run_window(wav: torch.Tensor, cqcc: torch.Tensor) -> dict:
    """
    Score one preprocessed window with all four detectors.

    Both inputs get a leading batch dimension and are moved to `device`.
    Each model's logits are softmaxed over dim 1, and the class-1 ("fake")
    probability of the single batch item is reported as a percentage rounded
    to two decimals.

    Returns:
        dict with keys "wav2vec2", "aasist", "cqcc_baseline", "custom_hybrid".
    """
    wav_batch = wav.unsqueeze(0).to(device)    # adds batch dim: (1, 1, TARGET_LEN)
    cqcc_batch = cqcc.unsqueeze(0).to(device)  # adds batch dim in front of the CQCC tensor

    def _fake_pct(logits: torch.Tensor) -> float:
        # Class index 1 of the first (only) batch item, scaled to 0–100.
        prob = torch.softmax(logits, dim=1)[0][1].item()
        return round(prob * 100, 2)

    with torch.no_grad():
        return {
            "wav2vec2": _fake_pct(wav2vec_model(wav_batch)),
            "aasist": _fake_pct(aasist_model(wav_batch)),
            "cqcc_baseline": _fake_pct(cqcc_baseline_model(cqcc_batch)),
            "custom_hybrid": _fake_pct(custom_hybrid_model(wav_batch, cqcc_batch)),
        }
def aggregate_prediction(fake_prob_pct: float) -> dict:
    """Build the standard prediction dict from a fake probability in percent.

    The label is "FAKE" only when the value is strictly above 50; the real
    probability is the complement to 100, rounded to two decimals.
    """
    if fake_prob_pct > 50:
        label = "FAKE"
    else:
        label = "REAL"
    real_prob_pct = round(100 - fake_prob_pct, 2)
    return {
        "prediction": label,
        "fake_probability": fake_prob_pct,
        "real_probability": real_prob_pct,
    }
# -------------------------------------------------------
# Prediction Endpoint
# -------------------------------------------------------
@app.post("/api/predict")
async def predict(file: UploadFile = File(...)):
    """Score an uploaded audio file with all four detectors.

    The upload is written to a private temporary file, decoded at 16 kHz mono,
    and cut into TARGET_LEN-sample windows that start every WINDOW_STEP
    samples (50% overlap). Each window is scored by `run_window`; the overall
    verdict per model is derived from the mean fake probability across
    windows.

    Returns a JSON response with:
        overall       : {model: {prediction, fake_probability, real_probability}}
        timeline      : {model: [fake_prob_pct, ...]} — one value per window
        window_labels : [start_sec, ...] — window start times, beginning at 0.0

    Error responses:
        400 if the decoded audio contains no samples.
        500 with {"error": <message>} for any other failure.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    temp_path = None
    try:
        # Never build a path from the client-supplied filename: it may contain
        # directory components and collides across concurrent uploads. Only a
        # plain alphanumeric extension is carried over to the temp file name.
        ext = os.path.splitext(os.path.basename(file.filename or ""))[1]
        if not ext[1:].isalnum():
            ext = ""
        with tempfile.NamedTemporaryFile(
            prefix="temp_", suffix=ext, delete=False
        ) as tmp:
            temp_path = tmp.name
            tmp.write(await file.read())

        # Load at 16 kHz mono — identical to librosa.load call in dataset.py
        wav_np, sr = librosa.load(temp_path, sr=16000, mono=True)

        total_samples = len(wav_np)
        if total_samples == 0:
            # No samples means no windows; report that explicitly rather than
            # failing later with a division by zero.
            return JSONResponse(
                {"error": "Uploaded audio contains no samples."},
                status_code=400,
            )

        # -------------------------------------------------------
        # Slice into overlapping windows of TARGET_LEN samples.
        # Step = 50% overlap. Very short clips produce a single window.
        # -------------------------------------------------------
        starts = range(0, total_samples, WINDOW_STEP)
        window_probs = []   # per-window fake-probability dicts
        window_labels = []  # x-axis: start of each window in seconds
        for start in starts:
            chunk = wav_np[start : start + TARGET_LEN]
            wav_t, cqcc_t = preprocess_window(chunk)
            window_probs.append(run_window(wav_t, cqcc_t))
            window_labels.append(round(start / sr, 2))

        # -------------------------------------------------------
        # Overall prediction = mean fake probability across all windows
        # -------------------------------------------------------
        model_keys = ["wav2vec2", "aasist", "cqcc_baseline", "custom_hybrid"]
        overall = {}
        for key in model_keys:
            mean_fake = round(
                sum(w[key] for w in window_probs) / len(window_probs), 2
            )
            overall[key] = aggregate_prediction(mean_fake)

        # -------------------------------------------------------
        # Time-series data for the frontend chart
        # -------------------------------------------------------
        timeline = {
            key: [w[key] for w in window_probs]
            for key in model_keys
        }

        return JSONResponse({
            "overall": overall,              # {model: {prediction, fake_probability, real_probability}}
            "timeline": timeline,            # {model: [fake_prob_pct, ...]} — one value per window
            "window_labels": window_labels,  # [start_sec, ...] — x-axis in seconds (starts at 0.0)
        })
    except Exception as e:
        # Top-level boundary: report the failure to the client as JSON.
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # Always remove the temporary upload, whatever happened above.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
# -------------------------------------------------------
# Serve frontend
# -------------------------------------------------------
# The static frontend is expected in a sibling "frontend" folder one level up.
frontend_dir = os.path.join(os.path.dirname(__file__), "..", "frontend")
if os.path.exists(frontend_dir):
    # Mounted at "/" only when the folder is present.
    app.mount(
        "/",
        StaticFiles(directory=frontend_dir, html=True),
        name="frontend",
    )
# -------------------------------------------------------
# Run Server
# -------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: serve the app locally with uvicorn.
    import uvicorn

    print("Starting server at http://127.0.0.1:8000")
    uvicorn.run(app, host="127.0.0.1", port=8000)