Spaces:
Sleeping
Sleeping
Hariharan S committed on
Commit ·
488006a
1
Parent(s): 0cd5695
Upgrade to SOTA Wav2Vec2 deepfake detector
Browse files- app/main.py +13 -0
- ml/inference.py +36 -13
- ml/sota_model.py +86 -0
- requirements.txt +3 -1
app/main.py
CHANGED
|
@@ -26,6 +26,19 @@ app = FastAPI(
|
|
| 26 |
version="1.0.0"
|
| 27 |
)
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
# CORS configuration
|
| 30 |
app.add_middleware(
|
| 31 |
CORSMiddleware,
|
|
|
|
| 26 |
version="1.0.0"
|
| 27 |
)
|
| 28 |
|
| 29 |
+
# Startup Event to Preload Model
@app.on_event("startup")
async def startup_event():
    """Preload SOTA model on startup to avoid first-request latency.

    NOTE(review): ``@app.on_event("startup")`` is deprecated in recent
    FastAPI releases in favour of lifespan handlers; it still works, but
    consider migrating when the FastAPI pin is bumped.
    """
    try:
        logger.info("Initializing SOTA Deepfake Detector...")
        # Import inside function to avoid top-level overhead if imports fail
        from ml.sota_model import get_detector
        get_detector()  # Triggers model loading
        logger.info("SOTA Model preloaded successfully!")
    except Exception as e:
        # Best-effort preload: on failure the app still starts and the model
        # is loaded lazily on the first prediction request instead.
        logger.warning(f"Could not preload SOTA model: {e}")
|
| 41 |
+
|
| 42 |
# CORS configuration
|
| 43 |
app.add_middleware(
|
| 44 |
CORSMiddleware,
|
ml/inference.py
CHANGED
|
@@ -149,9 +149,18 @@ def heuristic_fallback(features):
|
|
| 149 |
# Clamp to valid range
|
| 150 |
return max(0.01, min(0.99, ai_score))
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
async def predict_voice_authenticity(audio_base64: str, language: str) -> Dict:
|
| 153 |
"""
|
| 154 |
-
Main inference pipeline
|
| 155 |
"""
|
| 156 |
temp_path = f"/tmp/{uuid.uuid4()}.mp3"
|
| 157 |
|
|
@@ -165,23 +174,37 @@ async def predict_voice_authenticity(audio_base64: str, language: str) -> Dict:
|
|
| 165 |
logger.error(f"Base64 decode failed: {e}")
|
| 166 |
raise ValueError("Invalid Base64 audio string")
|
| 167 |
|
| 168 |
-
# 2. Extract features
|
| 169 |
features = extract_audio_features(temp_path)
|
| 170 |
|
| 171 |
-
# 3.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
if os.path.exists(temp_path):
|
| 173 |
os.remove(temp_path)
|
| 174 |
-
|
| 175 |
-
# 4. Load model
|
| 176 |
-
classifier = load_model()
|
| 177 |
-
|
| 178 |
-
# 5. Run inference - Use heuristics for better modern AI voice detection
|
| 179 |
-
# The heuristics are calibrated for Canva, ElevenLabs, etc.
|
| 180 |
-
ai_probability = heuristic_fallback(features)
|
| 181 |
|
| 182 |
# 6. Interpret results
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
# 7. Generate explanation
|
| 187 |
explanation = generate_explanation(features, ai_probability)
|
|
@@ -198,4 +221,4 @@ async def predict_voice_authenticity(audio_base64: str, language: str) -> Dict:
|
|
| 198 |
if os.path.exists(temp_path):
|
| 199 |
os.remove(temp_path)
|
| 200 |
logger.error(f"Prediction error: {e}")
|
| 201 |
-
raise ValueError(f"Audio processing
|
|
|
|
| 149 |
# Clamp to valid range
|
| 150 |
return max(0.01, min(0.99, ai_score))
|
| 151 |
|
| 152 |
+
|
| 153 |
+
# Import SOTA model at module load time; HAS_SOTA gates the deep-learning
# path so inference can fall back to heuristics when torch/transformers
# (or ml.sota_model itself) are unavailable.
try:
    from ml.sota_model import get_detector
    HAS_SOTA = True
except ImportError as e:
    # NOTE(review): this logs via the root logger while the rest of the
    # module uses its module-level `logger` — consider unifying.
    logging.warning(f"Could not import SOTA model: {e}")
    HAS_SOTA = False
|
| 160 |
+
|
| 161 |
async def predict_voice_authenticity(audio_base64: str, language: str) -> Dict:
|
| 162 |
"""
|
| 163 |
+
Main inference pipeline using SOTA Deep Learning model
|
| 164 |
"""
|
| 165 |
temp_path = f"/tmp/{uuid.uuid4()}.mp3"
|
| 166 |
|
|
|
|
| 174 |
logger.error(f"Base64 decode failed: {e}")
|
| 175 |
raise ValueError("Invalid Base64 audio string")
|
| 176 |
|
| 177 |
+
# 2. Extract features (still useful for explanation)
|
| 178 |
features = extract_audio_features(temp_path)
|
| 179 |
|
| 180 |
+
# 3. Predict using SOTA Model
|
| 181 |
+
ai_probability = None
|
| 182 |
+
used_method = "SOTA"
|
| 183 |
+
|
| 184 |
+
if HAS_SOTA:
|
| 185 |
+
detector = get_detector()
|
| 186 |
+
ai_probability = detector.predict(temp_path)
|
| 187 |
+
|
| 188 |
+
# 4. Fallback to heuristics if SOTA fails
|
| 189 |
+
if ai_probability is None:
|
| 190 |
+
logger.warning("SOTA model unavailable/failed, falling back to heuristics")
|
| 191 |
+
ai_probability = heuristic_fallback(features)
|
| 192 |
+
used_method = "HEURISTIC"
|
| 193 |
+
|
| 194 |
+
# 5. Clean up
|
| 195 |
if os.path.exists(temp_path):
|
| 196 |
os.remove(temp_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
# 6. Interpret results
|
| 199 |
+
# Threshold can be tuned. SOTA models are usually very confident.
|
| 200 |
+
if ai_probability > 0.5:
|
| 201 |
+
classification = "AI_GENERATED"
|
| 202 |
+
confidence = ai_probability
|
| 203 |
+
else:
|
| 204 |
+
classification = "HUMAN"
|
| 205 |
+
confidence = 1.0 - ai_probability
|
| 206 |
+
|
| 207 |
+
logger.info(f"Method: {used_method}, Prob: {ai_probability:.4f}, Class: {classification}")
|
| 208 |
|
| 209 |
# 7. Generate explanation
|
| 210 |
explanation = generate_explanation(features, ai_probability)
|
|
|
|
| 221 |
if os.path.exists(temp_path):
|
| 222 |
os.remove(temp_path)
|
| 223 |
logger.error(f"Prediction error: {e}")
|
| 224 |
+
raise ValueError(f"Audio processing error: {str(e)}")
|
ml/sota_model.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchaudio
|
| 5 |
+
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import shutil
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)


class DeepfakeDetector:
    """Wav2Vec2-based audio deepfake detector built on a pre-trained
    Hugging Face audio-classification checkpoint.

    Load failures are recorded in ``self.loaded`` instead of raising, so the
    calling service can fall back to its heuristic pipeline.
    """

    def __init__(self, model_name="hemgg/Deepfake-audio-detection"):
        """Load the classifier and feature extractor onto GPU if available.

        Parameters
        ----------
        model_name : str
            Hugging Face model id of a Wav2Vec2 model fine-tuned for
            deepfake detection.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Loading SOTA model: {model_name} on {self.device}...")

        try:
            self.model = AutoModelForAudioClassification.from_pretrained(model_name).to(self.device).eval()
            self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
            self.loaded = True
            logger.info("SOTA Model loaded successfully!")
        except Exception as e:
            # Download/deserialization can fail (no network, bad cache);
            # mark as unloaded so predict() degrades gracefully.
            logger.error(f"Failed to load SOTA model: {e}")
            self.loaded = False

    def predict(self, audio_path):
        """Return the probability (0.0-1.0) that the audio is AI-generated.

        Parameters
        ----------
        audio_path : str
            Path to an audio file readable by librosa.

        Returns
        -------
        float | None
            Fake probability, or ``None`` when the model is not loaded or
            inference fails — callers treat ``None`` as "fall back".
        """
        if not self.loaded:
            logger.warning("SOTA model not loaded, returning None")
            return None

        try:
            # Lazy import: librosa is only needed at inference time and is a
            # more robust decoding backend (mp3 etc.) than torchaudio here.
            import librosa

            # Load mono audio resampled to the 16 kHz rate Wav2Vec2 expects.
            # librosa.load returns a 1-D float32 numpy array, which is exactly
            # what the feature extractor accepts — no tensor round-trip needed
            # (the previous torch.tensor(...).unsqueeze(0)/.squeeze().numpy()
            # dance was a no-op).
            waveform, _ = librosa.load(audio_path, sr=16000)

            input_values = self.feature_extractor(
                waveform,
                return_tensors="pt",
                sampling_rate=16000
            ).input_values.to(self.device)

            with torch.no_grad():
                logits = self.model(input_values).logits

            probs = F.softmax(logits, dim=-1)
            # Index 0 is treated as the "fake"/AI class.
            # NOTE(review): the checkpoint's label order is NOT checked here —
            # verify against self.model.config.id2label; an inverted mapping
            # would flip every prediction.
            fake_prob = probs[0][0].item()

            logger.info(f"SOTA Prediction - Fake Prob: {fake_prob:.4f}")
            return fake_prob

        except Exception as e:
            logger.error(f"SOTA prediction failed: {e}")
            return None
|
| 78 |
+
|
| 79 |
+
# Lazily-constructed module-wide singleton detector.
_detector = None


def get_detector():
    """Return the shared DeepfakeDetector, creating it on first call."""
    global _detector
    instance = _detector
    if instance is None:
        instance = DeepfakeDetector()
        _detector = instance
    return instance
|
requirements.txt
CHANGED
|
@@ -5,11 +5,13 @@ pydantic==2.5.3
|
|
| 5 |
python-multipart==0.0.6
|
| 6 |
|
| 7 |
# ML & Audio Processing
|
| 8 |
-
torch=
|
|
|
|
| 9 |
librosa==0.10.1
|
| 10 |
soundfile==0.12.1
|
| 11 |
numpy==1.26.3
|
| 12 |
scipy>=1.10.0
|
|
|
|
| 13 |
scikit-learn==1.4.0
|
| 14 |
|
| 15 |
# Utilities
|
|
|
|
| 5 |
python-multipart==0.0.6
|
| 6 |
|
| 7 |
# ML & Audio Processing
|
| 8 |
+
torch>=2.2.0
|
| 9 |
+
torchaudio>=2.2.0
|
| 10 |
librosa==0.10.1
|
| 11 |
soundfile==0.12.1
|
| 12 |
numpy==1.26.3
|
| 13 |
scipy>=1.10.0
|
| 14 |
+
transformers>=4.35.0 # For pre-trained deepfake models
|
| 15 |
scikit-learn==1.4.0
|
| 16 |
|
| 17 |
# Utilities
|