backend / model.py
AryaAzhar's picture
Upload 7 files
ef5ede7 verified
Raw
History Blame Contribute Delete
7.25 kB
"""
SADA Deepfake Detection Model
──────────────────────────────
Wav2Vec2-Base backbone with a custom classification head.
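  • projector : Linear(768 → 256)
  • classifier: Linear(256 → 2)   index 0 = AI/fake, index 1 = human/real
    (NOTE: the LABELS constant in this file maps 0 → 'human', 1 → 'ai';
    confirm which order matches the trained weights.)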
Weights are loaded from a state-dict file (best_deepfake_model_tensor.pt).
"""
from __future__ import annotations
import io
import logging
import os
import glob
from pathlib import Path
# --- Auto-inject FFmpeg to PATH for Windows (winget support) ---
def _prepend_to_path(directory: str) -> bool:
    """Put *directory* at the front of ``PATH`` unless it is already an entry.

    ``PATH`` is split on ``os.pathsep`` and each entry is compared after
    ``normpath``/``normcase``.  A plain substring test on the whole ``PATH``
    string would treat ``...\\bin-extra`` as proof that ``...\\bin`` is
    already present; comparing entry by entry avoids that.

    Returns
    -------
    bool
        True when ``PATH`` was modified, False when the directory was
        already listed.
    """
    current = os.environ.get("PATH", "")
    wanted = os.path.normcase(os.path.normpath(directory))
    for entry in current.split(os.pathsep):
        if entry and os.path.normcase(os.path.normpath(entry)) == wanted:
            return False
    os.environ["PATH"] = (directory + os.pathsep + current) if current else directory
    return True


# Runs at import time, before the audio libraries are imported, so that a
# winget-installed FFmpeg is discoverable through PATH on Windows.
if os.name == 'nt':
    local_app_data = os.environ.get('LOCALAPPDATA', '')
    if local_app_data:
        # Escape the base directory: a profile path containing glob
        # metacharacters such as "[" must be matched literally.
        ffmpeg_pattern = os.path.join(
            glob.escape(local_app_data),
            "Microsoft", "WinGet", "Packages", "Gyan.FFmpeg*", "**", "bin",
        )
        for p in glob.glob(ffmpeg_pattern, recursive=True):
            if os.path.isfile(os.path.join(p, "ffmpeg.exe")):
                _prepend_to_path(p)
                # The first directory that actually contains ffmpeg.exe wins.
                break
# pyrefly: ignore [missing-import]
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# ── Label mapping ──────────────────────────────────────────────────────────
# Maps the classifier's arg-max index to the label string reported by
# predict().
# NOTE(review): the module docstring states the opposite order
# (index 0 = AI/fake, index 1 = human/real) — confirm against the training
# code which of the two is correct.
LABELS = {0: "human", 1: "ai"}
# Sample rate every decoded clip is resampled to.
SAMPLE_RATE = 16_000  # Wav2Vec2 expects 16 kHz
# Upper bound, in seconds, on the audio kept per clip.
MAX_DURATION_SEC = 30  # Truncate very long clips to save memory
# ── Model architecture ────────────────────────────────────────────────────
class DeepfakeDetector(nn.Module):
    """Wav2Vec2 backbone + projection head + 2-class classifier.

    The backbone's last hidden states are mean-pooled over time, projected
    to 256 dimensions and mapped to two logits.

    Parameters
    ----------
    pretrained_backbone:
        Name or path handed to ``Wav2Vec2Model.from_pretrained``.
    """

    def __init__(self, pretrained_backbone: str = "facebook/wav2vec2-base"):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model.from_pretrained(pretrained_backbone)
        # Take the width from the backbone config instead of hard-coding 768
        # (768 is the value for wav2vec2-base) so other backbones also fit.
        self.projector = nn.Linear(self.wav2vec2.config.hidden_size, 256)
        self.classifier = nn.Linear(256, 2)

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return raw class logits of shape ``(B, 2)``.

        Parameters
        ----------
        input_values:
            Batch of waveforms as produced by the feature extractor.
        attention_mask:
            Optional padding mask over the *input samples* (same length as
            ``input_values``).  When given, padded frames are excluded from
            the temporal mean.
        """
        outputs = self.wav2vec2(
            input_values=input_values,
            attention_mask=attention_mask,
        )
        # Mean-pool over the time axis.
        hidden = outputs.last_hidden_state  # (B, T, H)
        if attention_mask is None:
            pooled = hidden.mean(dim=1)  # (B, H)
        else:
            # The incoming mask is at sample resolution, but `hidden` is at
            # frame resolution (the convolutional encoder downsamples), so the
            # two cannot be multiplied directly.  Let the backbone translate
            # the mask to one flag per output frame first.
            frame_mask = self.wav2vec2._get_feature_vector_attention_mask(
                hidden.shape[1], attention_mask
            )  # (B, T)
            weights = frame_mask.unsqueeze(-1).to(hidden.dtype)  # (B, T, 1)
            # clamp() guards against division by zero for an all-padding row.
            pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)
        projected = self.projector(pooled)  # (B, 256)
        logits = self.classifier(projected)  # (B, 2)
        return logits
# ── Loading ────────────────────────────────────────────────────────────────
def load_model(
    weights_path: str | Path,
    device: str = "cpu",
) -> tuple[DeepfakeDetector, Wav2Vec2FeatureExtractor]:
    """Build the detector, load the fine-tuned weights and return it ready for inference.

    Parameters
    ----------
    weights_path:
        Path to the state-dict file handed to ``torch.load``.
    device:
        Device the tensors are mapped to and the model is moved onto.

    Returns
    -------
    tuple
        ``(model, feature_extractor)`` — the model is in eval mode on
        *device*; the feature extractor is the one published with the
        ``facebook/wav2vec2-base`` checkpoint.
    """
    backbone_id = "facebook/wav2vec2-base"

    logger.info("Loading Wav2Vec2 backbone from HuggingFace …")
    detector = DeepfakeDetector(pretrained_backbone=backbone_id)

    logger.info("Loading fine-tuned weights from %s …", weights_path)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only point this at a trusted file.  If the checkpoint is a
    # plain tensor state dict, weights_only=True should work — TODO confirm.
    checkpoint = torch.load(weights_path, map_location=device, weights_only=False)
    # strict=True: checkpoint keys must match the model's keys exactly.
    detector.load_state_dict(checkpoint, strict=True)

    detector = detector.to(device).eval()
    logger.info("Model loaded successfully on device=%s", device)

    extractor = Wav2Vec2FeatureExtractor.from_pretrained(backbone_id)
    return detector, extractor
import tempfile
# ── Inference ──────────────────────────────────────────────────────────────
def _looks_like_mp3_frame(header: bytes) -> bool:
    """Return True when *header* begins with an MPEG audio Layer III frame header."""
    if len(header) < 2 or header[0] != 0xFF:
        return False
    second = header[1]
    has_sync = (second & 0xE0) == 0xE0            # low 3 of the 11 sync bits
    version_ok = ((second >> 3) & 0x03) != 0x01   # 0b01 is a reserved version id
    is_layer3 = ((second >> 1) & 0x03) == 0x01    # 0b01 = Layer III
    return has_sync and version_ok and is_layer3


def _guess_suffix(raw_bytes: bytes) -> str:
    """Guess a file extension from magic bytes so librosa/ffmpeg decodes correctly.

    Recognises WAV, MP3 (ID3 tag or a bare Layer III frame header for
    MPEG-1/2/2.5, with or without CRC), FLAC, Ogg, MP4/M4A and
    Matroska/WebM.  Anything else — including empty input — yields ".wav".
    """
    header = raw_bytes[:16]
    if header[:4] == b'RIFF' and header[8:12] == b'WAVE':
        return ".wav"
    # Matching only b'\xff\xfb' would miss MPEG-2/2.5 streams (e.g. FF F3)
    # and CRC-protected MPEG-1 (FF FA); inspect the header bits instead.
    if header[:3] == b'ID3' or _looks_like_mp3_frame(header):
        return ".mp3"
    if header[:4] == b'fLaC':
        return ".flac"
    if header[:4] == b'OggS':
        return ".ogg"
    if header[4:8] == b'ftyp':  # MP4/M4A container
        return ".m4a"
    if header[:4] == b'\x1aE\xdf\xa3':  # Matroska/WebM
        return ".webm"
    return ".wav"  # fallback — most decoders handle raw PCM
def _load_audio(raw_bytes: bytes) -> np.ndarray:
    """Decode arbitrary audio bytes to a 16 kHz mono float32 numpy array.

    The bytes are written to a temporary file (named with the suffix
    guessed from the magic bytes) so that librosa can open it by path.
    At most ``MAX_DURATION_SEC`` seconds are kept, and the result is
    peak-normalised unless it is empty or effectively silent.

    Returns
    -------
    np.ndarray
        1-D waveform at ``SAMPLE_RATE``; may be empty when the input holds
        no decodable samples.
    """
    suffix = _guess_suffix(raw_bytes)
    logger.info("Detected audio format suffix: %s (%d bytes)", suffix, len(raw_bytes))

    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp_path = tmp.name
    try:
        # Close the handle before decoding (needed for re-opening on Windows),
        # and keep the write inside the try so a failed write still cleans up.
        with tmp:
            tmp.write(raw_bytes)
        # duration= stops decoding at the cap instead of loading the whole
        # file and slicing afterwards, which is what protects against OOM.
        audio, _ = librosa.load(
            tmp_path,
            sr=SAMPLE_RATE,
            mono=True,
            duration=MAX_DURATION_SEC,
        )
    finally:
        os.remove(tmp_path)

    # Resampling can round the length up by a few samples; enforce the cap.
    max_samples = SAMPLE_RATE * MAX_DURATION_SEC
    if len(audio) > max_samples:
        audio = audio[:max_samples]

    # np.max raises on a zero-size array; hand the empty clip back so the
    # caller can report it as too short.
    if audio.size == 0:
        return audio

    # Peak-normalise so quiet mic recordings match the amplitude of
    # clean uploaded files the model was trained on.
    peak = np.max(np.abs(audio))
    if peak > 1e-6:
        audio = audio / peak
    return audio
@torch.no_grad()
def predict(
    audio_bytes: bytes,
    model: DeepfakeDetector,
    feature_extractor: Wav2Vec2FeatureExtractor,
    device: str = "cpu",
) -> dict:
    """
    Classify one audio clip supplied as raw bytes.

    Parameters
    ----------
    audio_bytes : encoded audio, decoded via ``_load_audio``.
    model : detector that is called on the extracted input values.
    feature_extractor : turns the waveform into model input tensors.
    device : device the input tensor is moved to before the forward pass.

    Returns
    -------
    dict {"label": "ai"|"human", "confidence": float, "breakdown": {...},
          "duration_seconds": float}
        Percentages are rounded to two decimals; ``confidence`` is the
        percentage of the winning label and ``breakdown["noise"]`` is
        always 0.0.

    Raises
    ------
    ValueError
        If the decoded clip is shorter than 0.5 seconds.
    """
    # 1. Decode and validate length
    samples = _load_audio(audio_bytes)
    duration_seconds = len(samples) / SAMPLE_RATE
    minimum_samples = SAMPLE_RATE * 0.5
    if len(samples) < minimum_samples:
        raise ValueError(
            f"Audio too short ({duration_seconds:.1f}s). "
            "Please provide at least 0.5 seconds of audio."
        )

    # 2. Waveform -> model inputs
    encoded = feature_extractor(
        samples,
        sampling_rate=SAMPLE_RATE,
        return_tensors="pt",
        padding=True,
    )

    # 3. Forward pass (single clip, no attention mask is passed)
    logits = model(encoded.input_values.to(device))  # (1, 2)
    probabilities = F.softmax(logits, dim=-1).squeeze(0)  # (2,)

    # Index 0 is reported as "human" and index 1 as "ai", as in LABELS.
    human_pct = round(probabilities[0].item() * 100, 2)
    ai_pct = round(probabilities[1].item() * 100, 2)
    winner = LABELS[probabilities.argmax().item()]

    if winner == "ai":
        winner_pct = ai_pct
    else:
        winner_pct = human_pct

    breakdown = {
        "ai": ai_pct,
        "human": human_pct,
        "noise": 0.0,
    }
    return {
        "label": winner,
        "confidence": winner_pct,
        "breakdown": breakdown,
        "duration_seconds": round(duration_seconds, 2),
    }