File size: 3,121 Bytes
6c1314b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import torchaudio
import numpy as np
from app.config import settings
from app.utils import extract_heuristic_features

class ModelHandler:
    """Process-wide singleton wrapping a Hugging Face audio-classification model.

    The feature extractor and model named by ``settings.MODEL_NAME`` are loaded
    lazily (on the first ``predict`` call, or via an explicit ``load_model``)
    and placed on CUDA when it is available, otherwise on CPU.
    """

    _instance = None
    # Audio longer than this many seconds is truncated by the feature
    # extractor before inference, bounding memory/latency per request.
    MAX_SECONDS = 5

    def __new__(cls):
        # Every ModelHandler() call returns the same cached instance, so all
        # callers share a single model rather than each loading their own.
        if cls._instance is None:
            instance = super().__new__(cls)
            instance.model = None
            instance.feature_extractor = None
            instance.device = "cuda" if torch.cuda.is_available() else "cpu"
            cls._instance = instance
        return cls._instance

    def load_model(self):
        """Load the feature extractor and model for ``settings.MODEL_NAME``.

        Does nothing when the model is already loaded. The handler's attributes
        are only assigned after both components load successfully, so a failed
        attempt leaves the handler fully unloaded and a later call retries.

        Raises:
            Any exception raised while loading is printed and re-raised
            unchanged.
        """
        if self.model is not None:
            return
        print(f"Loading model {settings.MODEL_NAME} on {self.device}...")
        try:
            # Assumes MODEL_NAME points at a checkpoint compatible with
            # AutoModelForAudioClassification (e.g. a spoofing / deepfake
            # detector fine-tuned from wav2vec2) — confirm against settings.
            feature_extractor = AutoFeatureExtractor.from_pretrained(settings.MODEL_NAME)
            model = AutoModelForAudioClassification.from_pretrained(settings.MODEL_NAME)
            model.to(self.device)
            model.eval()
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
        self.feature_extractor = feature_extractor
        self.model = model
        print("Model loaded successfully.")

    @staticmethod
    def _prepare_waveform(waveform, sr, target_sr):
        """Return ``waveform`` as a mono, 1-D float32 numpy array at ``target_sr``.

        Args:
            waveform: torch tensor or array-like of samples. After removing
                singleton dimensions it must be 1-D ``(frames,)`` or 2-D; 2-D
                input is treated as channels-first ``(channels, frames)`` (the
                layout ``torchaudio.load`` returns by default) and averaged.
            sr: sampling rate of ``waveform``; ``None`` means it already equals
                ``target_sr``.
            target_sr: sampling rate the model's feature extractor expects.

        Raises:
            ValueError: if the waveform is empty or has more than two
                non-singleton dimensions.
        """
        # detach/cpu so GPU tensors and tensors that require grad can be
        # converted to numpy; float32 is required by the resampler.
        audio = torch.as_tensor(waveform).detach().to("cpu", dtype=torch.float32)
        audio = audio.squeeze()
        if audio.dim() == 2:
            # Downmix to mono; otherwise the extractor would treat each
            # channel as a separate batch item and only the first is scored.
            audio = audio.mean(dim=0)
        elif audio.dim() > 2:
            raise ValueError(
                f"Expected a 1-D or 2-D waveform, got shape {tuple(audio.shape)}"
            )
        # squeeze() turns a single-sample (1, 1) input into a 0-d scalar.
        audio = torch.atleast_1d(audio)
        if audio.numel() == 0:
            raise ValueError("Cannot classify an empty waveform")
        if sr is not None and int(sr) != int(target_sr):
            audio = torchaudio.functional.resample(
                audio, orig_freq=int(sr), new_freq=int(target_sr)
            )
        return audio.numpy()

    def predict(self, waveform, sr):
        """Classify a waveform and return ``(label, confidence)``.

        Args:
            waveform: audio samples as a torch tensor or array-like, mono
                (``(frames,)`` / ``(1, frames)``) or channels-first
                multi-channel (``(channels, frames)``, averaged to mono).
            sr: sampling rate of ``waveform``. Audio is resampled to the
                feature extractor's rate when the two differ; ``None`` means
                the audio is already at that rate.

        Returns:
            The predicted label from ``model.config.id2label`` and its softmax
            probability as a Python float. Which label means "real" or "fake"
            depends on the loaded checkpoint's config.

        Raises:
            ValueError: if the waveform is empty or has an unsupported shape.
        """
        if self.model is None:
            self.load_model()

        target_sr = self.feature_extractor.sampling_rate
        audio = self._prepare_waveform(waveform, sr, target_sr)

        inputs = self.feature_extractor(
            audio,
            sampling_rate=target_sr,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=int(target_sr * self.MAX_SECONDS),
        )
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # A single mono clip is fed in, so the batch dimension has size 1.
        probs = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class_id = torch.argmax(probs, dim=-1).item()
        predicted_label = self.model.config.id2label[predicted_class_id]
        confidence = probs[0][predicted_class_id].item()

        return predicted_label, confidence

# Shared module-level handler. Constructing it does not load the model: the
# model stays None until load_model() runs (predict() calls it on first use).
# ModelHandler.__new__ caches its instance, so later ModelHandler() calls
# return this same object.
model_handler = ModelHandler()