Spaces:
Running
Running
| import torch | |
| from transformers import AutoFeatureExtractor, AutoModelForAudioClassification | |
| import torchaudio | |
| import numpy as np | |
| from app.config import settings | |
| from app.utils import extract_heuristic_features | |
class ModelHandler:
    """Lazy-loading singleton around a Hugging Face audio-classification model.

    The model named by ``settings.MODEL_NAME`` is loaded on first use (not at
    construction) onto CUDA when available, otherwise CPU. Intended here for
    audio spoofing / deepfake detection, but any checkpoint compatible with
    ``AutoModelForAudioClassification`` works.
    """

    _instance = None

    def __new__(cls):
        # Classic singleton: all call sites share one model/extractor pair.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.model = None
            cls._instance.feature_extractor = None
            cls._instance.device = "cuda" if torch.cuda.is_available() else "cpu"
        return cls._instance

    def load_model(self):
        """Load the feature extractor and model if not already loaded.

        Raises:
            Exception: whatever ``from_pretrained`` raises (network, missing
                checkpoint, incompatible architecture). Partial state is reset
                first so a later retry starts from a clean slate.
        """
        if self.model is not None:
            return
        print(f"Loading model {settings.MODEL_NAME} on {self.device}...")
        try:
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(settings.MODEL_NAME)
            self.model = AutoModelForAudioClassification.from_pretrained(settings.MODEL_NAME)
            self.model.to(self.device)
            self.model.eval()
            print("Model loaded successfully.")
        except Exception as e:
            # Undo any half-finished load so `model is None` remains the
            # reliable "not loaded" sentinel used by predict().
            self.model = None
            self.feature_extractor = None
            print(f"Error loading model: {e}")
            # Bare `raise` preserves the original traceback; `raise e` does not.
            raise

    def predict(self, waveform, sr):
        """Classify a waveform and return ``(label, confidence)``.

        Args:
            waveform: audio tensor; assumed mono or squeezable to 1-D on CPU
                (``.squeeze().numpy()`` is applied) — TODO confirm with callers.
            sr: sample rate of ``waveform`` in Hz.

        Returns:
            Tuple of (predicted label string from the model's ``id2label``
            config, softmax probability of that label as a float).
        """
        if self.model is None:
            self.load_model()

        target_sr = self.feature_extractor.sampling_rate
        # BUG FIX: the original accepted `sr` but never used it, feeding audio
        # at the wrong rate straight to the extractor. Resample when needed.
        if sr != target_sr:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sr, new_freq=target_sr
            )

        waveform_np = waveform.squeeze().numpy()
        inputs = self.feature_extractor(
            waveform_np,
            sampling_rate=target_sr,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=target_sr * 5,  # cap at 5 s of audio to bound memory/latency
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # NOTE: label semantics (which index means "real" vs "fake") depend
        # entirely on the checkpoint's id2label mapping — verify per model.
        predicted_class_id = torch.argmax(probs, dim=-1).item()
        predicted_label = self.model.config.id2label[predicted_class_id]
        confidence = probs[0][predicted_class_id].item()
        return predicted_label, confidence
| model_handler = ModelHandler() | |