Spaces:
Running
Running
| import torch | |
| import librosa | |
| import numpy as np | |
| import io | |
| import os | |
| import tempfile | |
| import torch.nn.functional as F | |
| from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor | |
MODEL_NAME = "Hemgg/Deepfake-audio-detection"


class ModelService:
    """Classify uploaded audio as human speech or AI-generated.

    Raw audio bytes are decoded/resampled with librosa, fed through a
    Hugging Face audio-classification model, and the winning class is mapped
    to one of the portal labels ``"HUMAN"`` or ``"AI_GENERATED"``.
    """

    # Default target sample rate (the 16 kHz the service originally
    # hard-coded for Wav2Vec2). __init__ replaces it with the loaded feature
    # extractor's own rate so decoding and feature extraction always agree.
    sampling_rate = 16000

    # Substrings of a lower-cased model label that mean "genuine speech";
    # every other label is reported as AI-generated.
    _HUMAN_LABEL_KEYWORDS = ("real", "human", "bonafide")

    def __init__(self):
        """Load the feature extractor and model, on CUDA when available."""
        print("Loading AI Model...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
        # Use the extractor's expected rate rather than a separate constant.
        self.sampling_rate = self.feature_extractor.sampling_rate
        self.model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(self.device)
        # Inference only: make sure dropout-style training behavior is off.
        self.model.eval()

    def preprocess_audio(self, audio_bytes):
        """Decode encoded audio bytes into a waveform at ``self.sampling_rate``.

        The bytes are written to a temporary file first; the original author
        noted this is the most reliable way to read MP3/WAV/OGG on cloud
        hosts. The file is always removed afterwards.

        Args:
            audio_bytes: Contents of an encoded audio file.

        Returns:
            The numpy waveform returned by ``librosa.load`` (mono by
            librosa's default), resampled to ``self.sampling_rate``.

        Raises:
            ValueError: If decoding fails or the audio contains no samples.
        """
        fd, tmp_path = tempfile.mkstemp(suffix=".audio")
        try:
            with os.fdopen(fd, "wb") as tmp:
                tmp.write(audio_bytes)
            speech, _ = librosa.load(tmp_path, sr=self.sampling_rate)
        except Exception as e:
            # Chain the decoder's exception so its traceback is preserved.
            raise ValueError(f"Audio processing failed: {str(e)}") from e
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        if speech.size == 0:
            # A valid container with zero frames would otherwise fail deep
            # inside the model with an unrelated-looking error.
            raise ValueError("Audio processing failed: audio contains no samples")
        return speech

    def predict(self, audio_bytes):
        """Classify one audio clip.

        Args:
            audio_bytes: Contents of an encoded audio file.

        Returns:
            ``(label, confidence)``: ``label`` is ``"HUMAN"`` or
            ``"AI_GENERATED"``; ``confidence`` is the softmax probability of
            the model's top class.

        Raises:
            ValueError: Propagated from :meth:`preprocess_audio`.
        """
        speech = self.preprocess_audio(audio_bytes)
        inputs = self.feature_extractor(
            speech, sampling_rate=self.sampling_rate, return_tensors="pt", padding=True
        )
        inputs = {key: val.to(self.device) for key, val in inputs.items()}
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # A single clip is a batch of one, so work on row 0.
        probs = F.softmax(logits, dim=-1)[0]
        predicted_id = torch.argmax(probs).item()
        confidence = probs[predicted_id].item()
        model_label = self.model.config.id2label[predicted_id]
        return self._to_portal_label(model_label), confidence

    @classmethod
    def _to_portal_label(cls, model_label):
        """Map a model class name to ``"HUMAN"`` or ``"AI_GENERATED"``."""
        lbl = str(model_label).lower()
        if any(keyword in lbl for keyword in cls._HUMAN_LABEL_KEYWORDS):
            return "HUMAN"
        return "AI_GENERATED"
model_service = None


def get_model_service():
    """Return the shared ModelService, building it on the first call.

    The instance is cached in the module-level ``model_service`` global and
    handed back unchanged on every later call.
    NOTE(review): no lock guards construction, so concurrent first calls
    could each build a ModelService — confirm callers are single-threaded.
    """
    global model_service
    if model_service is not None:
        return model_service
    model_service = ModelService()
    return model_service