# voicedetectionapi/model_service.py
# Model service for deepfake-audio (AI voice) detection.
# (Header reconstructed from Hugging Face page residue: upload by ROSHANNN123,
# commit 922c67e "Update model_service.py".)
import torch
import librosa
import numpy as np
import io
import os
import tempfile
import torch.nn.functional as F
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
# Hugging Face Hub id of the pretrained audio classifier used for detection.
MODEL_NAME = "Hemgg/Deepfake-audio-detection"
class ModelService:
    """Wraps a pretrained Wav2Vec2 audio classifier for deepfake-voice detection.

    The model and feature extractor are loaded once at construction time and
    moved to GPU when available. `predict` maps raw encoded audio bytes to a
    ("HUMAN" | "AI_GENERATED", confidence) pair.
    """

    def __init__(self):
        print("Loading AI Model...")
        # Prefer GPU when available; the model is moved there once at startup.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
        self.model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(self.device)
        # from_pretrained normally returns the model in eval mode already, but
        # being explicit guards against dropout during inference.
        self.model.eval()

    def preprocess_audio(self, audio_bytes):
        """Decode raw audio bytes and resample to 16 kHz.

        Writes the bytes to a temp file first — the safest way for
        librosa/soundfile to decode MP3/WAV/OGG on cloud servers.

        Args:
            audio_bytes: Raw encoded audio (e.g. the contents of an upload).

        Returns:
            1-D numpy array of samples at 16 kHz.

        Raises:
            ValueError: If decoding fails or the file contains no samples.
        """
        fd, tmp_path = tempfile.mkstemp(suffix=".audio")
        try:
            with os.fdopen(fd, 'wb') as tmp:
                tmp.write(audio_bytes)
            # 16 kHz is the standard sampling rate for Wav2Vec2 models.
            speech, _ = librosa.load(tmp_path, sr=16000)
            if len(speech) == 0:
                # A decodable but empty file would otherwise fail confusingly
                # inside the feature extractor.
                raise ValueError("Audio processing failed: file contains no audio samples")
            return speech
        except ValueError:
            raise
        except Exception as e:
            # Preserve the original cause for debugging while presenting a
            # uniform error type to callers.
            raise ValueError(f"Audio processing failed: {str(e)}") from e
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def predict(self, audio_bytes):
        """Classify audio bytes as human or AI-generated speech.

        Args:
            audio_bytes: Raw encoded audio bytes.

        Returns:
            Tuple (label, confidence): label is "HUMAN" or "AI_GENERATED",
            confidence is the softmax probability of the predicted class.

        Raises:
            ValueError: If the audio cannot be decoded.
        """
        speech = self.preprocess_audio(audio_bytes)
        inputs = self.feature_extractor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
        inputs = {key: val.to(self.device) for key, val in inputs.items()}
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = F.softmax(logits, dim=-1)
        predicted_id = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][predicted_id].item()
        # id2label keys may be ints or strings depending on how the model
        # config was serialized; try both to avoid a KeyError.
        id2label = self.model.config.id2label
        label = id2label.get(predicted_id, id2label.get(str(predicted_id), ""))
        lbl = label.lower()
        # Map the model's own label vocabulary onto the portal's two classes.
        if "real" in lbl or "human" in lbl or "bonafide" in lbl:
            return "HUMAN", confidence
        return "AI_GENERATED", confidence
# Process-wide cache for the lazily constructed ModelService singleton.
model_service = None


def get_model_service():
    """Return the shared ModelService, building it on first use.

    Loading the model is expensive, so construction is deferred until the
    first request and the instance is reused for the process lifetime.
    """
    global model_service
    if model_service is not None:
        return model_service
    model_service = ModelService()
    return model_service