Hariharan S committed on
Commit
488006a
·
1 Parent(s): 0cd5695

Upgrade to SOTA Wav2Vec2 deepfake detector

Browse files
Files changed (4) hide show
  1. app/main.py +13 -0
  2. ml/inference.py +36 -13
  3. ml/sota_model.py +86 -0
  4. requirements.txt +3 -1
app/main.py CHANGED
@@ -26,6 +26,19 @@ app = FastAPI(
26
  version="1.0.0"
27
  )
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # CORS configuration
30
  app.add_middleware(
31
  CORSMiddleware,
 
26
  version="1.0.0"
27
  )
28
 
29
# Warm the SOTA detector when the application boots.
@app.on_event("startup")
async def startup_event():
    """Eagerly load the SOTA deepfake model so the first request is fast."""
    try:
        logger.info("Initializing SOTA Deepfake Detector...")
        # Deferred import: a missing torch/transformers install must not
        # prevent the API process itself from starting.
        from ml.sota_model import get_detector

        # Constructing the singleton forces the model weights to load now.
        get_detector()
        logger.info("SOTA Model preloaded successfully!")
    except Exception as e:
        # Best-effort preload: log and continue; inference falls back later.
        logger.warning(f"Could not preload SOTA model: {e}")
42
  # CORS configuration
43
  app.add_middleware(
44
  CORSMiddleware,
ml/inference.py CHANGED
@@ -149,9 +149,18 @@ def heuristic_fallback(features):
149
  # Clamp to valid range
150
  return max(0.01, min(0.99, ai_score))
151
 
 
 
 
 
 
 
 
 
 
152
  async def predict_voice_authenticity(audio_base64: str, language: str) -> Dict:
153
  """
154
- Main inference pipeline
155
  """
156
  temp_path = f"/tmp/{uuid.uuid4()}.mp3"
157
 
@@ -165,23 +174,37 @@ async def predict_voice_authenticity(audio_base64: str, language: str) -> Dict:
165
  logger.error(f"Base64 decode failed: {e}")
166
  raise ValueError("Invalid Base64 audio string")
167
 
168
- # 2. Extract features
169
  features = extract_audio_features(temp_path)
170
 
171
- # 3. Clean up
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  if os.path.exists(temp_path):
173
  os.remove(temp_path)
174
-
175
- # 4. Load model
176
- classifier = load_model()
177
-
178
- # 5. Run inference - Use heuristics for better modern AI voice detection
179
- # The heuristics are calibrated for Canva, ElevenLabs, etc.
180
- ai_probability = heuristic_fallback(features)
181
 
182
  # 6. Interpret results
183
- classification = "AI_GENERATED" if ai_probability > 0.5 else "HUMAN"
184
- confidence = ai_probability if ai_probability > 0.5 else (1 - ai_probability)
 
 
 
 
 
 
 
185
 
186
  # 7. Generate explanation
187
  explanation = generate_explanation(features, ai_probability)
@@ -198,4 +221,4 @@ async def predict_voice_authenticity(audio_base64: str, language: str) -> Dict:
198
  if os.path.exists(temp_path):
199
  os.remove(temp_path)
200
  logger.error(f"Prediction error: {e}")
201
- raise ValueError(f"Audio processing/feature extraction error: {str(e)}")
 
149
  # Clamp to valid range
150
  return max(0.01, min(0.99, ai_score))
151
 
152
+
153
# Import the SOTA deep-learning detector. When its optional dependencies
# (torch/transformers) are absent, fall back to the heuristic scorer and
# record that fact in HAS_SOTA so the inference path can branch on it.
try:
    from ml.sota_model import get_detector
    HAS_SOTA = True
except ImportError as e:
    # Use the module-named logger (consistent with the rest of this file)
    # rather than the root logger that plain logging.warning() would hit.
    logging.getLogger(__name__).warning(f"Could not import SOTA model: {e}")
    HAS_SOTA = False
161
  async def predict_voice_authenticity(audio_base64: str, language: str) -> Dict:
162
  """
163
+ Main inference pipeline using SOTA Deep Learning model
164
  """
165
  temp_path = f"/tmp/{uuid.uuid4()}.mp3"
166
 
 
174
  logger.error(f"Base64 decode failed: {e}")
175
  raise ValueError("Invalid Base64 audio string")
176
 
177
+ # 2. Extract features (still useful for explanation)
178
  features = extract_audio_features(temp_path)
179
 
180
+ # 3. Predict using SOTA Model
181
+ ai_probability = None
182
+ used_method = "SOTA"
183
+
184
+ if HAS_SOTA:
185
+ detector = get_detector()
186
+ ai_probability = detector.predict(temp_path)
187
+
188
+ # 4. Fallback to heuristics if SOTA fails
189
+ if ai_probability is None:
190
+ logger.warning("SOTA model unavailable/failed, falling back to heuristics")
191
+ ai_probability = heuristic_fallback(features)
192
+ used_method = "HEURISTIC"
193
+
194
+ # 5. Clean up
195
  if os.path.exists(temp_path):
196
  os.remove(temp_path)
 
 
 
 
 
 
 
197
 
198
  # 6. Interpret results
199
+ # Threshold can be tuned. SOTA models are usually very confident.
200
+ if ai_probability > 0.5:
201
+ classification = "AI_GENERATED"
202
+ confidence = ai_probability
203
+ else:
204
+ classification = "HUMAN"
205
+ confidence = 1.0 - ai_probability
206
+
207
+ logger.info(f"Method: {used_method}, Prob: {ai_probability:.4f}, Class: {classification}")
208
 
209
  # 7. Generate explanation
210
  explanation = generate_explanation(features, ai_probability)
 
221
  if os.path.exists(temp_path):
222
  os.remove(temp_path)
223
  logger.error(f"Prediction error: {e}")
224
+ raise ValueError(f"Audio processing error: {str(e)}")
ml/sota_model.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+ from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
6
+ import logging
7
+ import os
8
+ import shutil
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
class DeepfakeDetector:
    """Wav2Vec2-based audio deepfake detector.

    Wraps a pre-trained audio-classification checkpoint from the Hugging
    Face hub and exposes a single ``predict`` method returning the
    probability that a clip is AI-generated.
    """

    def __init__(self, model_name="hemgg/Deepfake-audio-detection"):
        """Load the pre-trained model and feature extractor.

        Args:
            model_name: Hub id of a Wav2Vec2 audio-classification model
                fine-tuned for deepfake detection.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Loading SOTA model: {model_name} on {self.device}...")

        try:
            self.model = AutoModelForAudioClassification.from_pretrained(model_name).to(self.device).eval()
            self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
            # Resolve which logit index is the "fake"/AI class from the
            # checkpoint's own label map instead of hard-coding index 0:
            # the previous code's comments disagreed on the mapping
            # ({0:'real',1:'fake'} vs {0:'AIVoice',1:'HumanVoice'}), and a
            # wrong hard-coded index inverts every prediction.
            self.fake_index = self._resolve_fake_index()
            self.loaded = True
            logger.info("SOTA Model loaded successfully!")
        except Exception as e:
            logger.error(f"Failed to load SOTA model: {e}")
            self.loaded = False

    def _resolve_fake_index(self):
        """Return the logit index of the AI/fake class (fallback: 0)."""
        id2label = getattr(self.model.config, "id2label", None) or {}
        for idx, label in id2label.items():
            name = str(label).lower()
            if any(token in name for token in ("fake", "spoof", "ai")):
                return int(idx)
        # Fall back to the original assumption: index 0 is the AI voice.
        return 0

    def predict(self, audio_path):
        """Predict the probability that *audio_path* is AI-generated.

        Args:
            audio_path: Path to an audio file readable by librosa.

        Returns:
            float in [0.0, 1.0] — probability the clip is AI-generated —
            or None if the model is not loaded or inference fails
            (callers are expected to fall back to heuristics on None).
        """
        if not self.loaded:
            logger.warning("SOTA model not loaded, returning None")
            return None

        try:
            # librosa decodes any common container and resamples to mono
            # 16 kHz, the rate the Wav2Vec2 feature extractor expects.
            import librosa

            waveform, _ = librosa.load(audio_path, sr=16000)

            # Pass the mono numpy array straight to the feature extractor;
            # the old tensor round-trip (np -> torch -> np) added nothing.
            input_values = self.feature_extractor(
                waveform,
                return_tensors="pt",
                sampling_rate=16000,
            ).input_values.to(self.device)

            with torch.no_grad():
                logits = self.model(input_values).logits

            probs = F.softmax(logits, dim=-1)
            fake_prob = probs[0][self.fake_index].item()

            logger.info(f"SOTA Prediction - Fake Prob: {fake_prob:.4f}")
            return fake_prob

        except Exception as e:
            # Never raise out of inference: the caller treats None as a
            # signal to use the heuristic fallback path.
            logger.error(f"SOTA prediction failed: {e}")
            return None
78
+
79
# Lazily-created module-level singleton instance.
_detector = None


def get_detector():
    """Return the shared DeepfakeDetector, constructing it on first use."""
    global _detector
    if _detector is not None:
        return _detector
    _detector = DeepfakeDetector()
    return _detector
requirements.txt CHANGED
@@ -5,11 +5,13 @@ pydantic==2.5.3
5
  python-multipart==0.0.6
6
 
7
  # ML & Audio Processing
8
- torch==2.1.2
 
9
  librosa==0.10.1
10
  soundfile==0.12.1
11
  numpy==1.26.3
12
  scipy>=1.10.0
 
13
  scikit-learn==1.4.0
14
 
15
  # Utilities
 
5
  python-multipart==0.0.6
6
 
7
  # ML & Audio Processing
8
+ torch>=2.2.0
9
+ torchaudio>=2.2.0
10
  librosa==0.10.1
11
  soundfile==0.12.1
12
  numpy==1.26.3
13
  scipy>=1.10.0
14
+ transformers>=4.35.0 # For pre-trained deepfake models
15
  scikit-learn==1.4.0
16
 
17
  # Utilities