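"""
SADA Deepfake Detection Model
─────────────────────────────
Wav2Vec2-Base backbone with a custom classification head.
  • projector : Linear(768 → 256)
  • classifier: Linear(256 → 2)   index 0 = AI/fake, index 1 = human/real
    NOTE(review): the code in this module (LABELS and predict) treats index 0
    as "human" and index 1 as "ai", which contradicts the line above. Confirm
    which order the trained weights use and correct whichever side is wrong.

Weights are loaded from a state-dict file (best_deepfake_model_tensor.pt).
"""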
|
|
| from __future__ import annotations |
|
|
| import io |
| import logging |
| import os |
| import glob |
| from pathlib import Path |
|
|
| |
def _path_has_entry(path_value: str, entry: str) -> bool:
    """Return True if *entry* is one of the ``os.pathsep``-separated items of *path_value*.

    Items are compared after ``normpath`` + ``normcase`` so trailing separators
    and (on Windows) letter case do not produce duplicates. A plain substring
    test would wrongly report ``...\\bin`` as present when only ``...\\bin2``
    (or a subdirectory of ``...\\bin``) is on PATH.
    """
    wanted = os.path.normcase(os.path.normpath(entry))
    return any(
        os.path.normcase(os.path.normpath(item)) == wanted
        for item in path_value.split(os.pathsep)
        if item  # skip empty items from "a;;b" or a trailing separator
    )


def _add_winget_ffmpeg_to_path() -> None:
    """On Windows, prepend a WinGet-installed Gyan FFmpeg ``bin`` dir to PATH.

    Searches ``%LOCALAPPDATA%\\Microsoft\\WinGet\\Packages\\Gyan.FFmpeg*\\**\\bin``
    for the first directory containing ``ffmpeg.exe``. It prepends that directory
    to ``os.environ["PATH"]`` unless it is already a PATH entry. Presumably this
    lets the ffmpeg-based audio decoding used later find the executable; this
    function only edits PATH.

    No-op on non-Windows systems or when ``LOCALAPPDATA`` is unset or empty.
    """
    if os.name != "nt":
        return
    local_app_data = os.environ.get("LOCALAPPDATA", "")
    if not local_app_data:
        return

    pattern = os.path.join(
        local_app_data, "Microsoft", "WinGet", "Packages", "Gyan.FFmpeg*", "**", "bin"
    )
    for candidate in glob.glob(pattern, recursive=True):
        if not os.path.isfile(os.path.join(candidate, "ffmpeg.exe")):
            continue
        current = os.environ.get("PATH", "")
        if not _path_has_entry(current, candidate):
            os.environ["PATH"] = candidate + os.pathsep + current
        # Stop at the first FFmpeg install found, whether or not PATH changed.
        break


_add_winget_ffmpeg_to_path()
|
|
|
|
| |
| import librosa |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor |
|
|
logger = logging.getLogger(__name__)

# Classifier output index -> label string reported by this module.
# NOTE(review): the module docstring describes the opposite order
# (index 0 = AI/fake); confirm against the labels used in training.
LABELS = dict(enumerate(("human", "ai")))

SAMPLE_RATE = 16000    # Hz; audio is decoded/resampled to this rate
MAX_DURATION_SEC = 30  # clips longer than this are truncated
|
|
|
|
| |
class DeepfakeDetector(nn.Module):
    """Wav2Vec2-Base backbone + projection head + 2-class classifier.

    Layers:
        wav2vec2   : pretrained ``Wav2Vec2Model`` (default ``facebook/wav2vec2-base``)
        projector  : Linear(768 -> 256)
        classifier : Linear(256 -> 2)

    The backbone's last hidden state is mean-pooled over time before the two
    linear layers. When an attention mask is given, padded frames are ignored.
    No non-linearity is applied between ``projector`` and ``classifier``.
    """

    def __init__(self, pretrained_backbone: str = "facebook/wav2vec2-base"):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model.from_pretrained(pretrained_backbone)
        # 768 must equal the backbone's hidden size (true for wav2vec2-base);
        # another backbone would need a different input width here.
        self.projector = nn.Linear(768, 256)
        self.classifier = nn.Linear(256, 2)

    @staticmethod
    def _masked_mean_pool(
        hidden: torch.Tensor,
        frame_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Mean over dim 1 of *hidden*, counting only positions where *frame_mask* is non-zero.

        *frame_mask* must have the same length along dim 1 as *hidden*, i.e. it
        is a frame-level mask. With no mask this is a plain mean over dim 1. A
        row whose mask is all zero pools to zeros; the clamp avoids 0/0.
        """
        if frame_mask is None:
            return hidden.mean(dim=1)
        mask = frame_mask.unsqueeze(-1).to(hidden.dtype)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return classifier logits, one row of 2 values per batch item.

        Args:
            input_values: waveform batch, presumably (batch, num_samples) as
                produced by ``Wav2Vec2FeatureExtractor``. TODO confirm.
            attention_mask: optional sample-level mask (1 = real audio,
                0 = padding). It has the same length as ``input_values``, as
                returned by the feature extractor.
        """
        outputs = self.wav2vec2(
            input_values=input_values,
            attention_mask=attention_mask,
        )
        hidden = outputs.last_hidden_state

        frame_mask = None
        if attention_mask is not None:
            # The backbone downsamples audio samples into far fewer frames, so
            # the sample-level mask cannot be applied to `hidden` directly. Use
            # the backbone's own helper to shrink it to frame resolution (the
            # same helper HF's Wav2Vec2ForSequenceClassification uses).
            frame_mask = self.wav2vec2._get_feature_vector_attention_mask(
                hidden.shape[1], attention_mask
            )

        pooled = self._masked_mean_pool(hidden, frame_mask)
        projected = self.projector(pooled)
        return self.classifier(projected)
|
|
|
|
| |
def load_model(
    weights_path: str | Path,
    device: str = "cpu",
) -> tuple[DeepfakeDetector, Wav2Vec2FeatureExtractor]:
    """Build the detector, load fine-tuned weights, and return (model, feature_extractor).

    Args:
        weights_path: file whose ``torch.load`` result is a state dict
            accepted by ``model.load_state_dict(..., strict=True)``.
        device: torch device string the model is moved to, e.g. ``"cpu"`` or
            ``"cuda"``.

    Returns:
        The model in eval mode on *device*, and the ``facebook/wav2vec2-base``
        feature extractor.

    Raises:
        Whatever ``torch.load`` raises for a missing or corrupt file.
        ``load_state_dict(strict=True)`` raises (``RuntimeError`` in PyTorch)
        when checkpoint keys or shapes do not match the model.
    """
    backbone = "facebook/wav2vec2-base"

    logger.info("Loading Wav2Vec2 backbone from HuggingFace ...")
    model = DeepfakeDetector(pretrained_backbone=backbone)

    logger.info("Loading fine-tuned weights from %s ...", weights_path)
    # Map to CPU: the freshly built model's parameters are on the CPU, so
    # mapping straight to `device` would only add a device round-trip before
    # `model.to(device)` below.
    # SECURITY: weights_only=False unpickles arbitrary objects and can run code
    # from the file. Only load trusted checkpoints; a plain tensor state dict
    # should also load with weights_only=True.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=False)
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully on device=%s", device)

    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(backbone)
    return model, feature_extractor
|
|
|
|
| import tempfile |
|
|
| |
|
|
def _guess_suffix(raw_bytes: bytes) -> str:
    """Guess a file extension for *raw_bytes* from its leading magic bytes.

    The extension names the temporary file handed to librosa/ffmpeg so the
    decoder picks the right format. Unrecognised or empty input falls back
    to ``".wav"``.

    Recognised signatures:
      * ``RIFF....WAVE``                                -> .wav
      * ``ID3`` tag, or an MPEG audio Layer III frame
        sync (MPEG-1/2/2.5, with or without CRC)        -> .mp3
      * ADTS AAC frame sync                             -> .aac
      * ``fLaC``                                        -> .flac
      * ``OggS``                                        -> .ogg
      * ``....ftyp`` (ISO base media / MP4 family)      -> .m4a
      * EBML header ``1A 45 DF A3`` (Matroska/WebM)     -> .webm
    """
    header = raw_bytes[:16]
    if header[:4] == b"RIFF" and header[8:12] == b"WAVE":
        return ".wav"
    if header[:3] == b"ID3":
        return ".mp3"
    if len(header) >= 2 and header[0] == 0xFF:
        second = header[1]
        # MPEG audio frame header: 11 sync bits, version (2), layer (2),
        # protection (1). Layer III = layer bits 01; version 01 is reserved,
        # leaving MPEG-1 (FB/FA), MPEG-2 (F3/F2) and MPEG-2.5 (E3/E2).
        if second in (0xFB, 0xFA, 0xF3, 0xF2, 0xE3, 0xE2):
            return ".mp3"
        # ADTS header: 12 sync bits, ID (1), layer bits always 00, protection (1).
        if second in (0xF1, 0xF0, 0xF9, 0xF8):
            return ".aac"
    if header[:4] == b"fLaC":
        return ".flac"
    if header[:4] == b"OggS":
        return ".ogg"
    if header[4:8] == b"ftyp":
        return ".m4a"
    if header[:4] == b"\x1aE\xdf\xa3":
        return ".webm"
    return ".wav"
|
|
|
|
def _peak_normalize(audio: np.ndarray) -> np.ndarray:
    """Return *audio* scaled so its largest absolute sample value is 1.0.

    Empty input is returned unchanged, because ``np.max`` raises on a
    zero-size array. Near-silent input whose peak is <= 1e-6 is also
    returned unchanged, to avoid dividing by ~0.
    """
    if audio.size == 0:
        return audio
    peak = np.max(np.abs(audio))
    if peak > 1e-6:
        audio = audio / peak
    return audio


def _load_audio(raw_bytes: bytes) -> np.ndarray:
    """Decode arbitrary audio bytes to a mono float32 array at SAMPLE_RATE (16 kHz).

    Steps:
      1. Sniff a file extension from the magic bytes.
      2. Write the bytes to a temporary file.
      3. Decode and resample at most MAX_DURATION_SEC seconds with librosa.
      4. Cap the result at MAX_DURATION_SEC * SAMPLE_RATE samples.
      5. Peak-normalise.

    Decoder errors for undecodable input propagate unchanged. A zero-length
    decode returns an empty array, leaving the "too short" decision to the
    caller.
    """
    suffix = _guess_suffix(raw_bytes)
    logger.info("Detected audio format suffix: %s (%d bytes)", suffix, len(raw_bytes))

    # delete=False so the file survives being closed. It is closed before
    # decoding so the decoder can reopen it by name, which Windows does not
    # allow while it is still open. It is removed in the `finally`.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        # Writing inside the try means the file is removed even if the write fails.
        with tmp:
            tmp.write(raw_bytes)
        # `duration` stops decoding after MAX_DURATION_SEC seconds, so long
        # uploads are not fully decoded and resampled only to be discarded.
        audio, _ = librosa.load(
            tmp.name, sr=SAMPLE_RATE, mono=True, duration=MAX_DURATION_SEC
        )
    finally:
        os.remove(tmp.name)

    # Hard cap as well, in case the resampled length rounds slightly over.
    max_samples = SAMPLE_RATE * MAX_DURATION_SEC
    audio = audio[:max_samples]

    return _peak_normalize(audio)
|
|
|
|
@torch.no_grad()
def predict(
    audio_bytes: bytes,
    model: DeepfakeDetector,
    feature_extractor: Wav2Vec2FeatureExtractor,
    device: str = "cpu",
) -> dict:
    """
    Run inference on raw audio bytes.

    Parameters
    ----------
    audio_bytes : encoded audio file contents; the format is sniffed from the bytes.
    model : the detector, e.g. as returned by ``load_model``.
    feature_extractor : Wav2Vec2 feature extractor used to prepare the input.
    device : device the input tensor is moved to. The model is presumably
        already on this device; this function does not move it.

    Returns
    -------
    dict
        {"label": "ai" | "human",
         "confidence": float,
         "breakdown": {"ai": float, "human": float, "noise": 0.0},
         "duration_seconds": float}

        Probabilities are percentages rounded to 2 decimals. "confidence" is
        the percentage of the predicted label. "noise" is a fixed placeholder,
        always 0.0. "duration_seconds" is measured after the audio is capped
        at MAX_DURATION_SEC.

    Raises
    ------
    ValueError
        If the decoded audio is shorter than 0.5 seconds.
    """
    waveform = _load_audio(audio_bytes)
    duration_seconds = len(waveform) / SAMPLE_RATE

    if len(waveform) < SAMPLE_RATE * 0.5:
        raise ValueError(
            f"Audio too short ({duration_seconds:.1f}s). "
            "Please provide at least 0.5 seconds of audio."
        )

    inputs = feature_extractor(
        waveform,
        sampling_rate=SAMPLE_RATE,
        return_tensors="pt",
        padding=True,
    )
    input_values = inputs.input_values.to(device)

    # One clip -> logits of shape (1, 2); squeeze to per-class probabilities.
    logits = model(input_values)
    probs = F.softmax(logits, dim=-1).squeeze(0)

    # Name every class through LABELS so the returned numbers always agree
    # with the returned label (no hard-coded 0/1 indices here).
    percent = {LABELS[i]: round(p * 100, 2) for i, p in enumerate(probs.tolist())}
    label = LABELS[int(probs.argmax())]

    return {
        "label": label,
        "confidence": percent[label],
        "breakdown": {
            "ai": percent["ai"],
            "human": percent["human"],
            "noise": 0.0,
        },
        "duration_seconds": round(duration_seconds, 2),
    }
|
|