# TrueFrame / audio_detect.py
# Author: Gaurav-Mhatre
# Last change: "Update audio_detect.py" (commit 9d26966, verified)
import os
import subprocess
import tempfile
import traceback

import imageio_ffmpeg  # bundles an FFmpeg binary; no system install needed
import librosa
import numpy as np
import soundfile as sf
import torch
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
class AudioDeepfakeDetector:
    """Detect AI-generated (deepfake) speech with a pretrained audio classifier.

    The Hugging Face model is downloaded/loaded once in the constructor;
    afterwards call :meth:`predict` once per audio file. If loading fails,
    ``self.model`` stays ``None`` and ``predict`` returns an error tuple
    instead of raising.
    """

    def __init__(self):
        # Hugging Face hub id of the audio anti-spoofing classifier.
        self.model_name = "Hemgg/Deepfake-audio-detection"
        self.model = None
        self.extractor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Limit CPU threads to prevent bottlenecking a shared host.
        if self.device == "cpu":
            torch.set_num_threads(4)

        print(f"⚑ Loading Audio AI Model: {self.model_name}...")
        try:
            self.extractor = AutoFeatureExtractor.from_pretrained(self.model_name)
            self.model = AutoModelForAudioClassification.from_pretrained(self.model_name).to(self.device)
            self.model.eval()
            print(f" ℹ️ Labels: {self.model.config.id2label}")
            print("βœ… Audio Model Loaded Successfully.")
        except Exception as e:
            # Degrade gracefully: predict() reports the missing model itself.
            print(f"❌ Failed to load Audio Model: {e}")
            traceback.print_exc()

    def predict(self, audio_path):
        """Classify the file at *audio_path* as real or deepfake speech.

        Parameters
        ----------
        audio_path : str
            Path to any FFmpeg-readable audio file (mp3, wav, ...).

        Returns
        -------
        tuple[str, float]
            ``("DEEPFAKE DETECTED" | "REAL", confidence)`` on success, or
            ``("ERROR...", 0.0)`` when the model is unavailable or any
            processing step fails. Never raises.
        """
        if not self.model:
            return "ERROR: MODEL NOT LOADED", 0.0

        temp_wav = None
        try:
            print(f"πŸ” Analyzing audio: {audio_path}")

            # Unique temp path: the previous fixed "temp_fast_audio.wav" name
            # raced when two predictions ran at once, and a stale file could
            # mask an FFmpeg failure.
            fd, temp_wav = tempfile.mkstemp(suffix=".wav")
            os.close(fd)  # FFmpeg writes the file itself; we only need the path

            # Fast FFmpeg pre-processing: first 4 s, mono, 16 kHz — exactly the
            # format the feature extractor expects, so no Python resampling.
            ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
            command = [
                ffmpeg_exe,
                "-y",              # overwrite output without prompting
                "-i", audio_path,  # input file (mp3, wav, etc.)
                "-t", "4",         # only the first 4 seconds
                "-ac", "1",        # force mono (1 channel)
                "-ar", "16000",    # force 16000 Hz sample rate
                temp_wav,          # pre-formatted temp output
            ]
            result = subprocess.run(
                command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
            )
            # Check the exit code instead of inferring success from file
            # existence (the old check could be fooled by a leftover file).
            if result.returncode != 0 or not os.path.exists(temp_wav):
                raise Exception("FFmpeg failed to process audio.")

            # Soundfile loads the pre-formatted wav with zero resampling work.
            audio, sr = sf.read(temp_wav)

            # PyTorch expects float32; also reject an empty decoded stream.
            audio = np.asarray(audio, dtype=np.float32)
            if audio.size == 0:
                raise Exception("FFmpeg produced an empty audio stream.")

            inputs = self.extractor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
            inputs = {key: val.to(self.device) for key, val in inputs.items()}

            # Fast AI inference (no autograd bookkeeping).
            with torch.inference_mode():
                logits = self.model(**inputs).logits
                probs = torch.nn.functional.softmax(logits, dim=-1)
                confidence, predicted_class_id = torch.max(probs, dim=-1)

            raw_label = self.model.config.id2label[predicted_class_id.item()]

            # Map the model's label vocabulary onto the app's two classes.
            check_label = raw_label.lower()
            is_fake = any(tag in check_label for tag in ("ai", "fake", "spoof"))
            label = "DEEPFAKE DETECTED" if is_fake else "REAL"
            score = confidence.item()

            print(f"βœ… AI Verdict: {raw_label} -> {label} ({score*100:.1f}%)")
            return label, score
        except Exception as e:
            print(f"❌ AUDIO ERROR: {e}")
            traceback.print_exc()
            return "ERROR", 0.0
        finally:
            # Always clean up the temp file, even when inference raised
            # part-way through (the original leaked it on that path).
            if temp_wav and os.path.exists(temp_wav):
                try:
                    os.remove(temp_wav)
                except OSError:
                    pass
if __name__ == "__main__":
    # Manual smoke test: constructing the detector downloads and loads the
    # model, which is enough to verify the environment is set up correctly.
    _detector = AudioDeepfakeDetector()