"""
SADA Deepfake Detection Model
─────────────────────────────
Wav2Vec2-Base backbone with a custom classification head.
  • projector : Linear(768 → 256)
  • classifier: Linear(256 → 2); the class-index → label mapping is the
    module-level ``LABELS`` table (currently 0 = human, 1 = ai).
    NOTE(review): a previous revision of this docstring stated
    index 0 = AI/fake, index 1 = human/real — confirm against the
    label encoding used in training.
Weights are loaded from a state-dict file (best_deepfake_model_tensor.pt).
"""
from __future__ import annotations
import io
import logging
import os
import glob
from pathlib import Path
# --- Auto-inject FFmpeg to PATH for Windows (winget support) ---
if os.name == 'nt':
local_app_data = os.environ.get('LOCALAPPDATA', '')
if local_app_data:
ffmpeg_pattern = os.path.join(local_app_data, "Microsoft", "WinGet", "Packages", "Gyan.FFmpeg*", "**", "bin")
for p in glob.glob(ffmpeg_pattern, recursive=True):
if os.path.isdir(p) and "ffmpeg.exe" in os.listdir(p):
if p not in os.environ.get("PATH", ""):
os.environ["PATH"] = p + os.pathsep + os.environ.get("PATH", "")
break
# pyrefly: ignore [missing-import]
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
logger = logging.getLogger(__name__)

# ── Label mapping ─────────────────────────────────────────────────────────
# Classifier output index -> label string; predict() uses it to name the
# argmax class.
# NOTE(review): confirm this order against the label encoding used during
# training — other documentation for this model has described index 0 as
# AI/fake and index 1 as human/real.
LABELS: dict[int, str] = {0: "human", 1: "ai"}

SAMPLE_RATE: int = 16_000  # input rate expected by Wav2Vec2 (16 kHz)
MAX_DURATION_SEC: int = 30  # longer clips are truncated to bound memory use
# ── Model architecture ────────────────────────────────────────────────────
class DeepfakeDetector(nn.Module):
    """Wav2Vec2 backbone followed by a linear projector and a 2-class classifier.

    Layers:
        wav2vec2   : ``Wav2Vec2Model`` loaded from ``pretrained_backbone``
        projector  : Linear(hidden_size -> 256); hidden_size is 768 for
                     facebook/wav2vec2-base
        classifier : Linear(256 -> 2)

    ``forward`` mean-pools the backbone's last hidden state over time and
    returns raw (unnormalised) logits of shape (batch, 2). The class-index
    to label-name mapping lives in the module-level ``LABELS`` table.
    """

    def __init__(self, pretrained_backbone: str = "facebook/wav2vec2-base"):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model.from_pretrained(pretrained_backbone)
        # Size the head from the backbone config rather than hard-coding 768,
        # so other checkpoints (e.g. 1024-dim "large" models) also work.
        # For wav2vec2-base this is still 768, so existing state dicts load.
        hidden_size = self.wav2vec2.config.hidden_size
        self.projector = nn.Linear(hidden_size, 256)
        self.classifier = nn.Linear(256, 2)

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return class logits for a batch of waveforms.

        Args:
            input_values: Waveform batch from the feature extractor,
                shape (batch, num_samples).
            attention_mask: Optional per-sample mask with the same shape as
                ``input_values`` (nonzero = real audio, 0 = padding). It is
                forwarded to the backbone and reused for masked mean pooling.

        Returns:
            Logits of shape (batch, 2).
        """
        outputs = self.wav2vec2(
            input_values=input_values,
            attention_mask=attention_mask,
        )
        hidden = outputs.last_hidden_state  # (B, T_frames, hidden_size)

        if attention_mask is None:
            pooled = hidden.mean(dim=1)  # (B, hidden_size)
        else:
            # attention_mask is indexed by input *sample*, but hidden has one
            # row per encoder *frame* (far fewer). Downsample the mask to frame
            # resolution first; this is the same helper Hugging Face's
            # Wav2Vec2ForSequenceClassification uses for its pooling.
            frame_mask = self.wav2vec2._get_feature_vector_attention_mask(
                hidden.shape[1], attention_mask
            )
            mask = frame_mask.unsqueeze(-1).to(hidden.dtype)  # (B, T_frames, 1)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

        projected = self.projector(pooled)  # (B, 256)
        logits = self.classifier(projected)  # (B, 2)
        return logits
# ── Loading ───────────────────────────────────────────────────────────────
def load_model(
    weights_path: str | Path,
    device: str = "cpu",
    *,
    backbone: str = "facebook/wav2vec2-base",
) -> tuple[DeepfakeDetector, Wav2Vec2FeatureExtractor]:
    """Build the detector, load fine-tuned weights and pair it with its feature extractor.

    Args:
        weights_path: Path to a state-dict file written with ``torch.save``.
        device: Torch device string; weights are mapped to it and the model
            is moved to it.
        backbone: Hugging Face model id (or local path) used for both the
            Wav2Vec2 backbone and the feature extractor. It must match the
            backbone the weights were trained on.

    Returns:
        ``(model, feature_extractor)`` with the model in eval mode on *device*.

    Raises:
        RuntimeError: if the state dict's keys or shapes do not exactly match
            the model (``strict=True``).
    """
    logger.info("Loading Wav2Vec2 backbone from HuggingFace β¦")
    model = DeepfakeDetector(pretrained_backbone=backbone)

    logger.info("Loading fine-tuned weights from %s β¦", weights_path)
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state dict contains. weights_only=False could execute
    # arbitrary pickled code embedded in the checkpoint file.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully on device=%s", device)

    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(backbone)
    return model, feature_extractor
import tempfile
# ── Inference ─────────────────────────────────────────────────────────────
def _guess_suffix(raw_bytes: bytes) -> str:
"""Guess file extension from magic bytes so librosa/ffmpeg decodes correctly."""
header = raw_bytes[:16]
if header[:4] == b'RIFF' and header[8:12] == b'WAVE':
return ".wav"
if header[:3] == b'ID3' or header[:2] == b'\xff\xfb':
return ".mp3"
if header[:4] == b'fLaC':
return ".flac"
if header[:4] == b'OggS':
return ".ogg"
if header[4:8] == b'ftyp': # MP4/M4A container
return ".m4a"
if header[:4] == b'\x1aE\xdf\xa3': # Matroska/WebM
return ".webm"
return ".wav" # fallback β most decoders handle raw PCM
def _load_audio(raw_bytes: bytes) -> np.ndarray:
    """Decode encoded audio bytes to a mono waveform at ``SAMPLE_RATE``.

    The bytes are written to a temporary file whose extension comes from
    ``_guess_suffix``, so the decoder sees a real path with a matching
    suffix. The file is always removed afterwards, even if writing or
    decoding fails.

    At most ``MAX_DURATION_SEC`` seconds are decoded. A non-empty,
    non-silent result (peak > 1e-6) is peak-normalised so that
    ``max(abs(audio)) == 1``. Empty or near-silent audio is returned
    unscaled, so callers can apply their own length checks.

    Raises:
        Whatever ``librosa.load`` raises for input it cannot decode.
    """
    suffix = _guess_suffix(raw_bytes)
    logger.info("Detected audio format suffix: %s (%d bytes)", suffix, len(raw_bytes))

    # mkstemp + explicit close lets the decoder reopen the path on Windows.
    # Creating the file before the try block lets the finally clause clean it
    # up even when the write itself fails.
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(raw_bytes)
        # duration= stops decoding at the cap. Decoding the whole file and
        # slicing afterwards would not bound memory use.
        audio, _ = librosa.load(
            tmp_path, sr=SAMPLE_RATE, mono=True, duration=MAX_DURATION_SEC
        )
    finally:
        os.remove(tmp_path)

    # Hard cap as a guard, in case resampling rounds the length up.
    max_samples = SAMPLE_RATE * MAX_DURATION_SEC
    audio = audio[:max_samples]

    # Peak-normalise so quiet mic recordings match the amplitude of
    # clean uploaded files the model was trained on. np.max raises on an
    # empty array, so skip it for empty audio.
    if audio.size:
        peak = np.max(np.abs(audio))
        if peak > 1e-6:
            audio = audio / peak
    return audio
@torch.no_grad()
def predict(
    audio_bytes: bytes,
    model: DeepfakeDetector,
    feature_extractor: Wav2Vec2FeatureExtractor,
    device: str = "cpu",
) -> dict:
    """Classify encoded audio bytes as AI-generated or human speech.

    Args:
        audio_bytes: Encoded audio file contents (decoded by ``_load_audio``).
        model: Detector from ``load_model``, already on *device*.
        feature_extractor: Matching Wav2Vec2 feature extractor.
        device: Device the input tensor is moved to; must match the model.

    Returns:
        dict with keys:
            "label": ``"ai"`` or ``"human"`` (argmax class via ``LABELS``).
            "confidence": probability of the predicted label, in percent,
                rounded to 2 decimals.
            "breakdown": ``{"ai": %, "human": %, "noise": 0.0}``; "noise"
                is a constant placeholder.
            "duration_seconds": length of the analysed clip, after any
                truncation to ``MAX_DURATION_SEC``, rounded to 2 decimals.

    Raises:
        ValueError: if the decoded audio is shorter than 0.5 seconds.
    """
    # 1. Decode audio
    waveform = _load_audio(audio_bytes)
    duration_seconds = len(waveform) / SAMPLE_RATE
    if len(waveform) < SAMPLE_RATE * 0.5:
        # Two decimals, so a 0.46 s clip is not reported as "0.5s" next to
        # a "needs at least 0.5 seconds" message.
        raise ValueError(
            f"Audio too short ({duration_seconds:.2f}s). "
            "Please provide at least 0.5 seconds of audio."
        )

    # 2. Feature extraction
    inputs = feature_extractor(
        waveform,
        sampling_rate=SAMPLE_RATE,
        return_tensors="pt",
        padding=True,
    )
    input_values = inputs.input_values.to(device)

    # 3. Forward pass
    logits = model(input_values)  # (1, 2)
    probs = F.softmax(logits, dim=-1).squeeze(0)  # (2,)

    # Map every class index through LABELS so the index convention is
    # defined in exactly one place.
    percent = {
        LABELS[idx]: round(prob * 100, 2)
        for idx, prob in enumerate(probs.tolist())
    }
    label = LABELS[int(probs.argmax())]

    return {
        "label": label,
        "confidence": percent[label],
        "breakdown": {
            "ai": percent["ai"],
            "human": percent["human"],
            "noise": 0.0,
        },
        "duration_seconds": round(duration_seconds, 2),
    }