File size: 4,142 Bytes
b3f89f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f229a41
 
 
 
 
 
 
 
b3f89f5
 
 
 
 
 
f229a41
b3f89f5
 
 
 
 
f229a41
 
 
 
 
b3f89f5
 
 
 
 
f229a41
b3f89f5
 
f229a41
b3f89f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70820b4
 
 
 
 
 
b3f89f5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import traceback
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
from speechbrain.inference.VAD import VAD
import os

class ModelWrapper:
    """Wraps an audio deepfake-detection classifier plus an optional SpeechBrain VAD.

    Loading failures are logged via print (matching this file's logging style) and
    leave the corresponding attribute as None: a missing classifier makes
    predict() raise RuntimeError, while a missing VAD is tolerated (callers may
    fall back to processing whole audio).
    """

    def __init__(self, config: dict):
        # config keys read here and in the loaders:
        #   name, device, feature_extractor, vad.repo, vad.save_path
        self.config = config
        self.model_name = config.get("name", "nii-yamagishilab/mms-300m-anti-deepfake")
        self.device = config.get("device", "cpu")
        self.model = None
        self.feature_extractor = None
        self.vad = None

        # Log library versions for debugging (best-effort; never fatal).
        try:
            import transformers
            import safetensors
            print(f"Library versions - transformers: {transformers.__version__}, safetensors: {safetensors.__version__}")
        except Exception as e:
            print(f"Warning: Could not log library versions: {e}")

        self.load_model()
        self.load_vad()

    def load_model(self):
        """Load the classifier and its feature extractor onto self.device.

        On any failure both self.model and self.feature_extractor are reset to
        None so a half-initialized pair is never observable.
        """
        try:
            print(f"Loading Deepfake Detection model {self.model_name} on {self.device}...")
            model = AutoModelForAudioClassification.from_pretrained(
                self.model_name, 
                trust_remote_code=True
            ).to(self.device)
            
            # The feature extractor may come from a different repo than the model.
            fe_name = self.config.get("feature_extractor", self.model_name)
            feature_extractor = AutoFeatureExtractor.from_pretrained(fe_name)
            
            # Only set if both loaded successfully
            self.model = model
            self.feature_extractor = feature_extractor
            self.model.eval()
            print("Model loaded successfully.")
            
        except Exception as e:
            print(f"Error loading model: {e}")
            print(f"Model name: {self.model_name}, Device: {self.device}")
            traceback.print_exc()
            self.model = None
            self.feature_extractor = None

    def load_vad(self):
        """Load the SpeechBrain VAD; on failure leave self.vad as None."""
        try:
            vad_repo = self.config.get("vad", {}).get("repo", "speechbrain/vad-crdnn-libriparty")
            print(f"Loading SpeechBrain VAD from {vad_repo}...")
            # VAD loads internal models, ensure we catch errors here too
            self.vad = VAD.from_hparams(
                source=vad_repo, 
                savedir=self.config.get("vad", {}).get("save_path", "model_checkpoints")
            )
            print("SpeechBrain VAD loaded.")
        except Exception as e:
            print(f"Error loading VAD: {e}")
            traceback.print_exc()
            # We can tolerate VAD failure slightly by processing whole audio, or fail hard.
            # For now, let's keep it robust.
            self.vad = None

    def predict(self, audio: torch.Tensor, sr: int) -> float:
        """Predict probability of AI generation.

        Args:
            audio: 1-D waveform tensor (any device; grad tracking is fine).
            sr: sample rate of `audio` in Hz.

        Returns:
            float in [0.0, 1.0], where 1.0 means AI-generated; 0.5 as a
            fallback when the model emits a non-finite probability.

        Raises:
            RuntimeError: if the model or feature extractor failed to load.
        """
        if self.model is None or self.feature_extractor is None:
            raise RuntimeError("Model not loaded")

        with torch.no_grad():
            # detach().cpu() so GPU-resident or grad-tracking tensors convert
            # cleanly; a bare audio.numpy() raises RuntimeError for either case.
            inputs = self.feature_extractor(
                audio.detach().cpu().numpy(), 
                sampling_rate=sr, 
                return_tensors="pt"
            ).to(self.device)

            # Inference
            outputs = self.model(**inputs)
            logits = outputs.logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
            
            # Label mapping: 
            # id2label usually {0: 'bonafide', 1: 'spoof'} OR {0: 'real', 1: 'fake'}
            # For mms-300m-anti-deepfake: 0 is 'bonafide' (human), 1 is 'spoof' (AI)
            # Verify this assumption via config or logs.  
            # (Logs from repro script said: Labels: {0: 'LABEL_0', 1: 'LABEL_1'})
            # Typically, LABEL_1 is the positive class (spoof).
            ai_prob = probs[0, 1]
            
            # Safety check: handle NaN/Inf (can occur if model weights are
            # improperly loaded). Check the tensor directly instead of
            # round-tripping the float back through torch.tensor().
            if not torch.isfinite(ai_prob):
                print(f"WARNING: Model returned non-finite value: {ai_prob.item()}. Returning 0.5 as fallback.")
                return 0.5
            
            return ai_prob.item()