# voice-detection-api / ml / inference.py
# Author: Hariharan S
# Commit 488006a: Upgrade to SOTA Wav2Vec2 deepfake detector
import base64
import os
import uuid
import uuid
import logging
from typing import Dict, Any
try:
import numpy as np
except ImportError:
np = None
from ml.feature_extraction import extract_audio_features
# Import model and the availability flag
from ml.model import VoiceAuthenticityClassifier, HAS_TORCH
from ml.explanation import generate_explanation
if HAS_TORCH:
import torch
logger = logging.getLogger(__name__)
# Lazy-loaded singleton model state, populated by load_model().
# Weights path is resolved relative to this file so it works regardless of the
# process working directory; consider making it configurable via env var.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "saved_models", "voice_classifier.pth")
model = None   # VoiceAuthenticityClassifier instance, or None until loaded / on failure
device = None  # torch.device selected by load_model() (cuda when available, else cpu)
def load_model():
    """Load the Torch voice classifier once and cache it in module globals.

    Returns:
        The cached ``VoiceAuthenticityClassifier`` in eval mode on the chosen
        device, or ``None`` when PyTorch is unavailable or loading failed.

    Side effects:
        Sets the module globals ``model`` and ``device``. When no weight file
        exists at MODEL_PATH the architecture is kept with random init (the
        predictions are then meaningless — a warning is logged).
    """
    global model, device
    if not HAS_TORCH:
        logger.warning("PyTorch not installed. Running in LIGHTWEIGHT MODE (Heuristic only).")
        return None
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model is None:
        try:
            # Initialize model architecture
            unique_model = VoiceAuthenticityClassifier()
            if os.path.exists(MODEL_PATH):
                # weights_only=True restricts unpickling to tensors/primitives,
                # blocking arbitrary-code-execution payloads in the checkpoint.
                # (Requires torch >= 1.13; default behavior from torch 2.6.)
                state = torch.load(MODEL_PATH, map_location=device, weights_only=True)
                unique_model.load_state_dict(state)
                logger.info(f"Loaded model weights from {MODEL_PATH}")
            else:
                # Random init keeps the structure usable, but predictions
                # will be random until real weights are provided.
                logger.warning(f"No weights found at {MODEL_PATH}. Using RANDOM initialization.")
            unique_model.to(device)
            unique_model.eval()
            model = unique_model
        except Exception as e:
            # Best-effort: callers treat None as "fall back to heuristics".
            logger.error(f"Failed to load model: {e}")
            model = None
    return model
def heuristic_fallback(features):
    """Score AI-vs-human likelihood from hand-crafted audio features.

    Modern AI voices (Canva, ElevenLabs, etc.) typically show very LOW
    variance in spectral flux, zero-crossing rate, pitch and MFCCs
    (over-smoothed, unnaturally consistent), while human voices show HIGH
    variance in all four (natural dynamics and intonation).

    Args:
        features: list or 1-D numpy array of at least 36 values. Index layout
            assumed from extract_audio_features (TODO confirm): 13:26 = MFCC
            stds, 31 = spectral-flux std, 33 = pitch std, 35 = ZCR std.

    Returns:
        float AI probability clamped to [0.01, 0.99]; 0.5 (neutral) when
        fewer than 36 features are supplied.
    """
    if len(features) < 36:
        # Too few features to judge — stay neutral.
        return 0.5

    # Only the MFCC-std aggregation differs between list and ndarray input;
    # plain indexing works identically for both.
    if isinstance(features, list):
        mfcc_std_avg = sum(features[13:26]) / 13
    else:
        mfcc_std_avg = features[13:26].mean()
    spectral_flux_std = features[31]
    pitch_std = features[33]
    zcr_std = features[35]

    # Debug logging (lazy %-formatting; resolves to the same module logger).
    logging.getLogger(__name__).info(
        "Features - MFCC_std: %.2f, Flux_std: %.2f, Pitch_std: %.2f, ZCR_std: %.4f",
        mfcc_std_avg, spectral_flux_std, pitch_std, zcr_std,
    )

    # Vote-counting: each feature contributes AI and/or human indicators.
    ai_indicators = 0
    human_indicators = 0

    # Spectral flux: AI tends to be smoother, but short clips can also be smooth.
    if spectral_flux_std < 3.0:
        ai_indicators += 1
    elif spectral_flux_std > 8.0:
        human_indicators += 2  # Strong human indicator

    # MFCC std: very low suggests over-processed audio.
    if mfcc_std_avg < 20:
        ai_indicators += 1
    elif mfcc_std_avg > 40:
        human_indicators += 1

    # ZCR std: high ZCR variance often indicates natural speech.
    if zcr_std < 0.035:
        ai_indicators += 1
    elif zcr_std > 0.08:
        human_indicators += 2  # Strong human indicator

    # Pitch std: only very low pitch std indicates AI
    # (modern AI can fake high variance).
    if pitch_std < 50:
        ai_indicators += 1  # Too uniform — strong AI indicator

    # Combined decision: strong human evidence wins first, then strong AI
    # evidence, then the simple majority, defaulting to mildly-human.
    if human_indicators >= 3:
        ai_score = 0.2
    elif human_indicators >= 2 and ai_indicators <= 1:
        ai_score = 0.35
    elif ai_indicators >= 3:
        ai_score = 0.9
    elif ai_indicators >= 2 and human_indicators <= 1:
        ai_score = 0.75
    elif ai_indicators > human_indicators:
        ai_score = 0.6
    else:
        ai_score = 0.4

    # Clamp to valid probability range.
    return max(0.01, min(0.99, ai_score))
# Optional SOTA (Wav2Vec2-based) detector. When its dependency stack is
# missing, HAS_SOTA stays False and prediction falls back to heuristics.
try:
    from ml.sota_model import get_detector
    HAS_SOTA = True
except ImportError as e:
    # NOTE(review): uses the root logger while the rest of this module uses
    # `logger` — presumably unintentional; consider logger.warning instead.
    logging.warning(f"Could not import SOTA model: {e}")
    HAS_SOTA = False
async def predict_voice_authenticity(audio_base64: str, language: str) -> Dict[str, Any]:
    """Classify a base64-encoded audio clip as AI-generated or human.

    Pipeline: decode to a temp file -> extract hand-crafted features ->
    score with the SOTA model when available, else the heuristic fallback ->
    build the response with a human-readable explanation.

    Args:
        audio_base64: base64-encoded audio bytes (written out as .mp3).
        language: passed through unchanged into the response.

    Returns:
        dict with keys status, language, classification ("AI_GENERATED" /
        "HUMAN"), confidenceScore (rounded to 2 decimals), explanation.

    Raises:
        ValueError: for invalid base64 input or any processing failure.
    """
    temp_path = f"/tmp/{uuid.uuid4()}.mp3"
    try:
        # 1. Decode audio and persist it so file-based loaders can read it.
        try:
            audio_bytes = base64.b64decode(audio_base64)
            with open(temp_path, "wb") as f:
                f.write(audio_bytes)
        except Exception as e:
            logger.error(f"Base64 decode failed: {e}")
            raise ValueError("Invalid Base64 audio string") from e

        # 2. Extract features (still useful for explanation and fallback).
        features = extract_audio_features(temp_path)

        # 3. Predict using the SOTA model when importable.
        ai_probability = None
        used_method = "SOTA"
        if HAS_SOTA:
            detector = get_detector()
            ai_probability = detector.predict(temp_path)

        # 4. Fall back to heuristics if the SOTA model is absent or failed.
        if ai_probability is None:
            logger.warning("SOTA model unavailable/failed, falling back to heuristics")
            ai_probability = heuristic_fallback(features)
            used_method = "HEURISTIC"

        # 5. Interpret results. Threshold can be tuned; SOTA models are
        # usually very confident.
        if ai_probability > 0.5:
            classification = "AI_GENERATED"
            confidence = ai_probability
        else:
            classification = "HUMAN"
            confidence = 1.0 - ai_probability
        logger.info(f"Method: {used_method}, Prob: {ai_probability:.4f}, Class: {classification}")

        # 6. Generate explanation
        explanation = generate_explanation(features, ai_probability)
        return {
            "status": "success",
            "language": language,
            "classification": classification,
            "confidenceScore": round(confidence, 2),
            "explanation": explanation
        }
    except Exception as e:
        logger.error(f"Prediction error: {e}")
        raise ValueError(f"Audio processing error: {str(e)}") from e
    finally:
        # Single cleanup point: remove the temp file on success AND failure
        # (the original duplicated this in both paths, leaking on re-raise
        # paths between them).
        if os.path.exists(temp_path):
            os.remove(temp_path)