# NOTE(review): the lines "Spaces:" / "Sleeping" here were Hugging Face Space
# status text captured by the scraper, not code; preserved as this comment so
# the module parses.
# Standard library
import logging
import os
import shutil

# Third-party
import torch
import torch.nn.functional as F
import torchaudio
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
class DeepfakeDetector:
    """Audio deepfake detector backed by a pre-trained Wav2Vec2 classifier.

    Wraps a Hugging Face audio-classification checkpoint and exposes a single
    ``predict`` method that returns the probability a clip is AI-generated.
    Model loading is best-effort: on failure ``self.loaded`` is False and
    ``predict`` returns ``None`` instead of raising.
    """

    def __init__(self, model_name="hemgg/Deepfake-audio-detection"):
        """
        Initialize the SOTA Deepfake Detector model.

        Uses a pre-trained Wav2Vec2 model fine-tuned for deepfake detection.
        Downloads the checkpoint on first use; any failure (network, disk,
        incompatible weights) is logged and recorded via ``self.loaded``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Loading SOTA model: {model_name} on {self.device}...")
        try:
            self.model = AutoModelForAudioClassification.from_pretrained(model_name).to(self.device).eval()
            self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
            self.loaded = True
            logger.info("SOTA Model loaded successfully!")
        except Exception as e:
            # Deliberate best-effort boundary: callers check `loaded` / handle None.
            logger.error(f"Failed to load SOTA model: {e}")
            self.loaded = False

    def _fake_label_index(self):
        """Return the logit index of the 'fake'/AI class.

        The original code hard-coded index 0 while its comments contradicted
        each other about the label order ({0:'real',1:'fake'} vs
        {0:'AIVoice',1:'HumanVoice'}). Reading ``config.id2label`` is robust
        to either ordering; if no label matches, fall back to index 0 (the
        previous behavior).
        """
        id2label = getattr(self.model.config, "id2label", None) or {}
        for idx, label in id2label.items():
            name = str(label).lower()
            if any(marker in name for marker in ("fake", "spoof", "ai")):
                return int(idx)
        return 0

    def predict(self, audio_path):
        """
        Predict if audio is AI-generated (Fake) or Human (Real).

        Returns: probability of being AI (0.0 to 1.0), or None when the model
        is not loaded or inference fails.
        """
        if not self.loaded:
            logger.warning("SOTA model not loaded, returning None")
            return None
        try:
            # Imported lazily so the module still loads without librosa.
            # librosa is the more robust decoding backend here and returns a
            # mono float32 numpy array of shape (length,), already resampled
            # to the 16 kHz the feature extractor expects.
            import librosa

            waveform, _ = librosa.load(audio_path, sr=16000)
            # Feed the numpy array directly to the feature extractor — the
            # original round-tripped it through a torch tensor for no benefit.
            input_values = self.feature_extractor(
                waveform,
                return_tensors="pt",
                sampling_rate=16000,
            ).input_values.to(self.device)
            with torch.no_grad():
                logits = self.model(input_values).logits
            probs = F.softmax(logits, dim=-1)
            # Select the fake-class probability by label, not by assumption.
            fake_prob = probs[0][self._fake_label_index()].item()
            logger.info(f"SOTA Prediction - Fake Prob: {fake_prob:.4f}")
            return fake_prob
        except Exception as e:
            # Best-effort boundary: prediction failures degrade to None.
            logger.error(f"SOTA prediction failed: {e}")
            return None
# Process-wide detector instance, created lazily on first request.
_detector = None


def get_detector():
    """Return the shared DeepfakeDetector, constructing it on first call."""
    global _detector
    if _detector is not None:
        return _detector
    _detector = DeepfakeDetector()
    return _detector