# voice-detection-api/ml/sota_model.py
# Author: Hariharan S — "Upgrade to SOTA Wav2Vec2 deepfake detector" (commit 488006a)
import torch
import torch.nn.functional as F
import torchaudio
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
import logging
# NOTE(review): os, shutil and torchaudio are not referenced in this chunk —
# confirm they are used elsewhere in the file before removing.
import os
import shutil
# Module-level logger; handlers/level are expected to be configured by the app.
logger = logging.getLogger(__name__)
class DeepfakeDetector:
    """Wav2Vec2-based audio deepfake detector.

    Wraps a pre-trained HuggingFace audio-classification model and exposes
    `predict(audio_path)`, which returns the probability that the audio is
    AI-generated. If the model fails to load, the instance stays usable but
    `predict` returns None.
    """

    def __init__(self, model_name="hemgg/Deepfake-audio-detection"):
        """
        Initialize the SOTA Deepfake Detector model.

        Args:
            model_name: HuggingFace hub id of a Wav2Vec2 audio-classification
                model fine-tuned for deepfake detection.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Loading SOTA model: {model_name} on {self.device}...")
        try:
            self.model = AutoModelForAudioClassification.from_pretrained(model_name).to(self.device).eval()
            self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
            # Resolve which logit index is the AI/fake class from the model's
            # own config instead of hard-coding it (the previous code's
            # comments contradicted each other on whether index 0 was 'real'
            # or 'fake').
            self.fake_index = self._resolve_fake_index()
            self.loaded = True
            logger.info("SOTA Model loaded successfully!")
        except Exception as e:
            # logger.exception records the traceback, not just the message.
            logger.exception(f"Failed to load SOTA model: {e}")
            self.loaded = False

    def _resolve_fake_index(self):
        """Return the logit index of the AI/fake class via config.id2label.

        Matches common label spellings case-insensitively ('fake', 'spoof',
        'aivoice', ...). Falls back to index 0 — the previous hard-coded
        behavior — when no label matches or id2label is absent.
        """
        fake_aliases = {"fake", "spoof", "ai", "aivoice", "ai_voice", "synthetic"}
        id2label = getattr(self.model.config, "id2label", None) or {}
        for idx, label in id2label.items():
            if str(label).strip().lower().replace(" ", "") in fake_aliases:
                return int(idx)
        return 0

    def predict(self, audio_path):
        """
        Predict if audio is AI-generated (Fake) or Human (Real).

        Args:
            audio_path: Path to an audio file readable by librosa.

        Returns:
            Probability of being AI-generated (0.0 to 1.0), or None when the
            model is not loaded or inference fails.
        """
        if not self.loaded:
            logger.warning("SOTA model not loaded, returning None")
            return None
        try:
            # librosa is imported lazily so the module loads even if the
            # optional dependency is missing; it resamples to 16 kHz and
            # mixes to mono, returning a float32 numpy array of shape (L,) —
            # exactly what the feature extractor accepts directly.
            import librosa

            waveform, _ = librosa.load(audio_path, sr=16000)
            input_values = self.feature_extractor(
                waveform,
                return_tensors="pt",
                sampling_rate=16000,
            ).input_values.to(self.device)
            with torch.no_grad():
                logits = self.model(input_values).logits
            probs = F.softmax(logits, dim=-1)
            # Index resolved from config.id2label at load time (see
            # _resolve_fake_index), not a hard-coded assumption.
            fake_prob = probs[0][self.fake_index].item()
            logger.info(f"SOTA Prediction - Fake Prob: {fake_prob:.4f}")
            return fake_prob
        except Exception as e:
            logger.exception(f"SOTA prediction failed: {e}")
            return None
# Lazily-created singleton: the model load is expensive, so it happens at
# most once per process.
_detector = None


def get_detector():
    """Return the shared DeepfakeDetector, constructing it on first use."""
    global _detector
    if _detector is not None:
        return _detector
    _detector = DeepfakeDetector()
    return _detector