File size: 5,548 Bytes
e950836
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
audio_detector_inference.py
============================
Inference wrapper for AASISTDeepFake.
Mirrors the structure of text_detector_inference.py for consistency.

Usage
-----
from audio_detector_inference import AudioDetectorInference

detector = AudioDetectorInference(checkpoint="best_aasist.pt", threshold=0.5)
result   = detector.predict(waveform_np_array, sr=16000)
"""

import os
import numpy as np
import torch
from audio_model import AASISTDeepFake, SAMPLE_RATE, MAX_SAMPLES


class AudioDetectorInference:
    """
    Thin wrapper around AASISTDeepFake for single-clip audio prediction.

    Parameters
    ----------
    checkpoint : str
        Path to the best_aasist.pt state-dict file produced by training.
        If the path is not an existing file, no model is loaded and
        predict() returns an error dict instead of a prediction.
    threshold  : float
        sigmoid(logit) >= threshold  →  Real.
        Use 0.5 (default) or the optimal F1 threshold printed at the end
        of training (the value labelled "Optimal threshold" in Cell 14).
    device     : torch.device | None
        Auto-detects CUDA if None.
    """

    def __init__(
        self,
        checkpoint: str   = "best_aasist.pt",
        threshold:  float = 0.5,
        device: torch.device = None,
    ):
        self.threshold = threshold
        self.device    = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = None

        # isfile rather than exists: a directory that happens to carry the
        # checkpoint name must take the "not found" path, not crash in load.
        if os.path.isfile(checkpoint):
            print(f"[AudioDetector] Loading checkpoint: {checkpoint}")
            self.model = AASISTDeepFake()
            # NOTE(security): torch.load unpickles the file, which can run
            # arbitrary code. Only load checkpoints from a trusted source
            # (consider weights_only=True where the installed torch has it).
            self.model.load_state_dict(
                torch.load(checkpoint, map_location=self.device)
            )
            self.model.eval().to(self.device)
            print(f"[AudioDetector] ✅ AASISTDeepFake ready  "
                  f"(threshold={self.threshold})")
        else:
            print(
                f"[AudioDetector] ⚠️  '{checkpoint}' not found.\n"
                f"[AudioDetector]    Upload best_aasist.pt to the Space — "
                f"audio predictions will fail until then."
            )

    # ──────────────────────────────────────────────────────────────────────────
    @staticmethod
    def _to_unit_range(raw: np.ndarray) -> np.ndarray:
        """
        Return `raw` as float32, rescaled from integer PCM when applicable.

        Signed-integer input is always rescaled, so a near-silent integer
        clip (peak <= 1) is no longer mistaken for full-scale float audio.
        The divisor is 32768 while the peak fits the 16-bit range and the
        dtype's own full scale otherwise (e.g. 2**31 for genuine 32-bit
        PCM), which keeps the result inside [-1, 1].

        Any other dtype keeps the legacy heuristic: a peak above 1.0 is
        taken to mean 16-bit sample values and is divided by 32768.
        """
        samples = raw.astype(np.float32)
        peak = np.abs(samples).max()

        if np.issubdtype(raw.dtype, np.signedinteger):
            full_scale = float(np.iinfo(raw.dtype).max) + 1.0
            scale = 32768.0 if peak <= 32768.0 else full_scale
            return samples / scale

        if peak > 1.0:
            return samples / 32768.0
        return samples

    # ──────────────────────────────────────────────────────────────────────────
    def _preprocess(self, x: np.ndarray, sr: int) -> torch.Tensor:
        """
        Normalise, convert to mono, resample to SAMPLE_RATE, and
        pad/trim to exactly MAX_SAMPLES samples.

        Parameters
        ----------
        x  : np.ndarray (or array-like)
            1-D waveform, or 2-D waveform that is averaged over axis 1,
            i.e. assumed to be laid out as (samples, channels).
        sr : int
            Sample rate of x. librosa is imported only when resampling is
            actually required.

        Returns
        -------
        torch.Tensor of shape (1, MAX_SAMPLES), dtype float32, on CPU.

        Raises
        ------
        ValueError
            If x is empty or is neither 1-D nor 2-D.
        """
        raw = np.asarray(x)
        if raw.size == 0:
            raise ValueError(
                "[AudioDetector] Empty audio input: nothing to classify."
            )
        if raw.ndim not in (1, 2):
            raise ValueError(
                f"[AudioDetector] Expected a 1-D or 2-D waveform, "
                f"got shape {raw.shape}."
            )

        # ── Normalise integer PCM recordings ──────────────────────────────────
        x = self._to_unit_range(raw)

        # ── Stereo → mono ─────────────────────────────────────────────────────
        if x.ndim == 2:
            x = x.mean(axis=1)

        # ── Resample if needed ────────────────────────────────────────────────
        if sr != SAMPLE_RATE:
            import librosa
            print(f"[AudioDetector] Resampling {sr} Hz → {SAMPLE_RATE} Hz …")
            x = librosa.resample(x, orig_sr=sr, target_sr=SAMPLE_RATE)

        # ── Pad or trim to exactly MAX_SAMPLES ────────────────────────────────
        if len(x) < MAX_SAMPLES:
            # Zero-pad at the end only; the clip itself stays at the start.
            x = np.pad(x, (0, MAX_SAMPLES - len(x)))
        else:
            x = x[:MAX_SAMPLES]

        return torch.tensor(x, dtype=torch.float32).unsqueeze(0)

    # ──────────────────────────────────────────────────────────────────────────
    def predict(self, x: np.ndarray, sr: int) -> dict:
        """
        Classify a single raw audio clip.

        Parameters
        ----------
        x  : np.ndarray   Raw waveform (any sample rate, mono or stereo).
        sr : int          Sample rate of x.

        Returns
        -------
        dict with keys:
            label      : "Real" | "Fake"
            real_prob  : P(real)  in [0, 1]
            fake_prob  : P(fake)  in [0, 1]
            confidence : probability of the predicted class in [0, 1]
        All probabilities are rounded to 4 decimal places.

        If no checkpoint was loaded, the dict instead holds a single
        "error" key with a human-readable message.

        Raises
        ------
        ValueError
            Propagated from _preprocess for empty or wrongly shaped input.
        """
        if self.model is None:
            return {
                "error": (
                    "Model not loaded — upload best_aasist.pt to the Space."
                )
            }

        wav = self._preprocess(x, sr).to(self.device)

        with torch.no_grad():
            # .item() below requires the model output to hold one element.
            logit     = self.model(wav)
            real_prob = torch.sigmoid(logit).item()

        fake_prob  = 1.0 - real_prob
        is_real    = real_prob >= self.threshold
        label      = "Real" if is_real else "Fake"
        # Confidence is the probability of whichever class was chosen, so
        # it can drop below 0.5 when threshold differs from 0.5.
        confidence = real_prob if is_real else fake_prob

        print(
            f"[AudioDetector] real_prob={real_prob:.4f}  "
            f"fake_prob={fake_prob:.4f}  →  {label}"
        )

        return {
            "label":      label,
            "real_prob":  round(real_prob, 4),
            "fake_prob":  round(fake_prob, 4),
            "confidence": round(confidence, 4),
        }