import torch
import torchaudio
import numpy as np
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

from app.config import settings
from app.utils import extract_heuristic_features


class ModelHandler:
    """Singleton wrapper around a Hugging Face audio-classification model.

    Lazily loads the model named by ``settings.MODEL_NAME`` onto GPU when
    available (CPU otherwise) and exposes :meth:`predict` for scoring raw
    waveforms. Label semantics (e.g. real vs. fake) come from the model's
    own ``config.id2label`` mapping.
    """

    _instance = None  # class-level singleton slot

    def __new__(cls):
        # Classic lazy singleton: state is initialized exactly once, on the
        # first construction; subsequent calls return the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.model = None
            cls._instance.feature_extractor = None
            cls._instance.device = "cuda" if torch.cuda.is_available() else "cpu"
        return cls._instance

    def load_model(self):
        """Load the model and feature extractor if not already loaded.

        Idempotent: a second call is a no-op. Raises whatever
        ``from_pretrained`` raises (network/auth/config errors) after
        logging it, so callers can decide how to recover.
        """
        if self.model is not None:
            return  # already loaded

        print(f"Loading model {settings.MODEL_NAME} on {self.device}...")
        try:
            # NOTE(review): assumes MODEL_NAME is a checkpoint compatible with
            # AutoModelForAudioClassification (e.g. a wav2vec2-style spoofing
            # detector) — confirm against deployment config.
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(settings.MODEL_NAME)
            self.model = AutoModelForAudioClassification.from_pretrained(settings.MODEL_NAME)
            self.model.to(self.device)
            self.model.eval()
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Re-raise with the original traceback intact (bare raise
            # instead of `raise e`).
            raise

    def predict(self, waveform, sr):
        """Classify a waveform and return ``(label, confidence)``.

        Args:
            waveform: torch.Tensor of shape ``(channels, samples)`` or
                ``(samples,)`` on CPU.
            sr: sample rate of ``waveform`` in Hz.

        Returns:
            Tuple of the predicted label string (from the model's
            ``id2label``) and its softmax probability.
        """
        if self.model is None:
            self.load_model()

        target_sr = self.feature_extractor.sampling_rate

        # BUG FIX: `sr` was previously ignored, so audio at any other rate
        # was scored as-is. Resample to the extractor's expected rate.
        if sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)

        # BUG FIX: squeeze() left stereo input as (2, T), which the feature
        # extractor treats as a batch of two and which breaks the
        # batch-size-1 assumptions below (`.item()` on argmax, `probs[0]`).
        # Downmix multi-channel audio to mono instead.
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=0)
        waveform_np = waveform.numpy()

        inputs = self.feature_extractor(
            waveform_np,
            sampling_rate=target_sr,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=target_sr * 5,  # cap at 5 s of audio for stability
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # Label mapping is model-specific (typically 0: real / 1: fake or
        # vice versa) — trust the checkpoint's id2label, don't hard-code.
        predicted_class_id = torch.argmax(probs, dim=-1).item()
        predicted_label = self.model.config.id2label[predicted_class_id]
        confidence = probs[0][predicted_class_id].item()
        return predicted_label, confidence


# Module-level singleton used by the rest of the app.
model_handler = ModelHandler()