# Scraped from a hosted repo page; original commit: "clean commit" (6c1314b) by aadhi97x
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import torchaudio
import numpy as np
from app.config import settings
from app.utils import extract_heuristic_features
class ModelHandler:
    """Lazy-loading singleton wrapper around a Hugging Face audio-classification model.

    Model weights are only fetched on the first call to ``predict`` (or an
    explicit ``load_model``), so importing this module stays cheap.
    """

    _instance = None

    def __new__(cls):
        # Classic singleton: every ModelHandler() call returns the same object,
        # so the (large) model is held in memory exactly once per process.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.model = None
            cls._instance.feature_extractor = None
            # Prefer GPU when available; all tensors are moved to this device.
            cls._instance.device = "cuda" if torch.cuda.is_available() else "cpu"
        return cls._instance

    def load_model(self):
        """Load the feature extractor and classification model if not already loaded.

        Raises:
            Exception: re-raises whatever ``from_pretrained`` fails with
                (network error, missing repo, incompatible architecture, ...).
        """
        if self.model is None:
            print(f"Loading model {settings.MODEL_NAME} on {self.device}...")
            try:
                # Assumes settings.MODEL_NAME points at a checkpoint compatible
                # with AutoModelForAudioClassification (e.g. a wav2vec2-style
                # model fine-tuned for spoof/deepfake detection) — TODO confirm.
                self.feature_extractor = AutoFeatureExtractor.from_pretrained(settings.MODEL_NAME)
                self.model = AutoModelForAudioClassification.from_pretrained(settings.MODEL_NAME)
                self.model.to(self.device)
                self.model.eval()  # inference mode: disables dropout etc.
                print("Model loaded successfully.")
            except Exception as e:
                print(f"Error loading model: {e}")
                # Bare `raise` preserves the original traceback
                # (`raise e` would re-anchor it at this line).
                raise

    def predict(self, waveform, sr):
        """Classify a waveform with the loaded audio-classification model.

        Args:
            waveform: torch.Tensor of audio samples, shape (channels, samples)
                or (samples,).
            sr: sample rate of ``waveform`` in Hz.

        Returns:
            tuple[str, float]: the predicted label string (taken from the
            checkpoint's ``id2label`` mapping) and its softmax probability.
        """
        if self.model is None:
            self.load_model()

        target_sr = self.feature_extractor.sampling_rate
        # FIX: the original ignored `sr` and told the extractor the audio was
        # already at the model's native rate; audio recorded at any other rate
        # was silently fed to the model mis-rated. Resample when rates differ.
        if sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)

        # HF audio models take float arrays through their feature extractor.
        waveform_np = waveform.squeeze().numpy()
        inputs = self.feature_extractor(
            waveform_np,
            sampling_rate=target_sr,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=target_sr * 5,  # cap at 5 s to bound memory/latency
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # Which index means "real" vs "fake" depends on the checkpoint; trust
        # its id2label mapping rather than hard-coding indices.
        id2label = self.model.config.id2label
        predicted_class_id = torch.argmax(probs, dim=-1).item()
        predicted_label = id2label[predicted_class_id]
        confidence = probs[0][predicted_class_id].item()
        return predicted_label, confidence
model_handler = ModelHandler()