File size: 5,548 Bytes
e950836 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | """
audio_detector_inference.py
============================
Inference wrapper for AASISTDeepFake.
Mirrors the structure of text_detector_inference.py for consistency.
Usage
-----
from audio_detector_inference import AudioDetectorInference
detector = AudioDetectorInference(checkpoint="best_aasist.pt", threshold=0.5)
result = detector.predict(waveform_np_array, sample_rate=16000)
"""
import os
import numpy as np
import torch
from audio_model import AASISTDeepFake, SAMPLE_RATE, MAX_SAMPLES
class AudioDetectorInference:
    """
    Thin wrapper around AASISTDeepFake for single-clip audio prediction.

    Parameters
    ----------
    checkpoint : str
        Path to the best_aasist.pt state-dict file produced by training.
    threshold : float
        sigmoid(logit) >= threshold → Real.
        Use 0.5 (default) or the optimal F1 threshold printed at the end
        of training (the value labelled "Optimal threshold" in Cell 14).
    device : torch.device | None
        Auto-detects CUDA if None.

    Notes
    -----
    If ``checkpoint`` does not exist the instance is still constructed, but
    ``self.model`` stays ``None`` and :meth:`predict` returns an
    ``{"error": ...}`` dict instead of a prediction.
    """

    def __init__(
        self,
        checkpoint: str = "best_aasist.pt",
        threshold: float = 0.5,
        device: "torch.device | None" = None,
    ):
        self.threshold = threshold
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = None

        if not os.path.exists(checkpoint):
            # Deliberately non-fatal: the object is usable, predict() just
            # reports an error until the checkpoint is provided.
            print(
                f"[AudioDetector] ⚠️ '{checkpoint}' not found.\n"
                f"[AudioDetector] Upload best_aasist.pt to the Space — "
                f"audio predictions will fail until then."
            )
            return

        print(f"[AudioDetector] Loading checkpoint: {checkpoint}")
        self.model = AASISTDeepFake()
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source; on
        # torch >= 1.13 consider weights_only=True since a state dict is
        # all that is expected here.
        state_dict = torch.load(checkpoint, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval().to(self.device)
        print(f"[AudioDetector] ✅ AASISTDeepFake ready "
              f"(threshold={self.threshold})")

    # ──────────────────────────────────────────────────────────────────────────
    @staticmethod
    def _to_float_mono(x: np.ndarray) -> np.ndarray:
        """
        Convert a raw waveform to a 1-D float32 array scaled to ~[-1, 1].

        - Signed integer PCM is divided by its full-scale value
          (2**(bits-1), e.g. 32768 for int16).
        - Unsigned integer PCM (e.g. 8-bit WAV) is re-centred on its midpoint.
        - Float input whose peak exceeds 1.0 is assumed to hold int16-scaled
          samples and is divided by 32768.
        - 2-D input is averaged over its channel axis, taken to be the
          shorter axis, so both (samples, channels) and (channels, samples)
          layouts work.

        Returns an empty float32 array for empty input.

        Raises
        ------
        ValueError
            If ``x`` is not 1-D or 2-D.
        """
        x = np.asarray(x)
        if x.ndim not in (1, 2):
            raise ValueError(
                f"[AudioDetector] Expected a 1-D or 2-D waveform, "
                f"got shape {x.shape}."
            )
        if x.size == 0:
            return np.zeros(0, dtype=np.float32)

        if np.issubdtype(x.dtype, np.signedinteger):
            scale = -float(np.iinfo(x.dtype).min)
            x = x.astype(np.float32) / scale
        elif np.issubdtype(x.dtype, np.unsignedinteger):
            mid = (float(np.iinfo(x.dtype).max) + 1.0) / 2.0
            x = (x.astype(np.float32) - mid) / mid
        else:
            x = x.astype(np.float32)
            # Heuristic kept for float arrays carrying int16-range values.
            if np.abs(x).max() > 1.0:
                x = x / 32768.0

        if x.ndim == 2:
            # Channels are the short axis: (N, C) → axis 1, (C, N) → axis 0.
            channel_axis = 1 if x.shape[0] >= x.shape[1] else 0
            x = x.mean(axis=channel_axis)

        return x.astype(np.float32, copy=False)

    @staticmethod
    def _fit_length(x: np.ndarray, n: int) -> np.ndarray:
        """Zero-pad (at the end) or truncate 1-D ``x`` to exactly ``n`` samples."""
        if len(x) < n:
            return np.pad(x, (0, n - len(x)))
        return x[:n]

    def _preprocess(self, x: np.ndarray, sr: int) -> torch.Tensor:
        """
        Normalise, convert to mono, resample to SAMPLE_RATE, and
        pad/trim to exactly MAX_SAMPLES.

        Returns
        -------
        torch.Tensor of shape (1, MAX_SAMPLES), dtype float32, on CPU.

        Raises
        ------
        ValueError
            If the waveform is empty or not 1-D / 2-D.
        """
        wav = self._to_float_mono(x)
        if wav.size == 0:
            raise ValueError(
                "[AudioDetector] Empty waveform — nothing to classify."
            )

        if sr != SAMPLE_RATE:
            import librosa  # only needed when the input must be resampled
            print(f"[AudioDetector] Resampling {sr} Hz → {SAMPLE_RATE} Hz …")
            wav = librosa.resample(wav, orig_sr=sr, target_sr=SAMPLE_RATE)

        wav = self._fit_length(wav, MAX_SAMPLES)
        return torch.tensor(wav, dtype=torch.float32).unsqueeze(0)

    # ──────────────────────────────────────────────────────────────────────────
    def predict(self, x: np.ndarray, sr: int) -> dict:
        """
        Classify a single raw audio clip.

        Parameters
        ----------
        x : np.ndarray   Raw waveform (any sample rate, mono or stereo,
                         integer PCM or float).
        sr : int         Sample rate of x.

        Returns
        -------
        dict with keys:
            label      : "Real" | "Fake"
            real_prob  : P(real) in [0, 1]
            fake_prob  : P(fake) in [0, 1]
            confidence : probability of the predicted class in [0, 1]
        or, when no checkpoint was loaded, a dict with a single "error" key.

        Raises
        ------
        ValueError
            If the waveform is empty or not 1-D / 2-D.
        """
        if self.model is None:
            return {
                "error": (
                    "Model not loaded — upload best_aasist.pt to the Space."
                )
            }

        wav = self._preprocess(x, sr).to(self.device)
        with torch.no_grad():
            logit = self.model(wav)  # expected to hold a single logit, e.g. (1, 1)

        # .item() requires the model output to contain exactly one element.
        real_prob = torch.sigmoid(logit).item()
        fake_prob = 1.0 - real_prob
        is_real = real_prob >= self.threshold
        label = "Real" if is_real else "Fake"
        confidence = real_prob if is_real else fake_prob

        print(
            f"[AudioDetector] real_prob={real_prob:.4f} "
            f"fake_prob={fake_prob:.4f} → {label}"
        )
        return {
            "label": label,
            "real_prob": round(real_prob, 4),
            "fake_prob": round(fake_prob, 4),
            "confidence": round(confidence, 4),
        }
|