# Multi_Modal_Deepfake_Detection / audio_detector_inference.py
# Uploaded by pavankumarvk — commit e950836 ("Upload 2 files"), verified
"""
audio_detector_inference.py
============================
Inference wrapper for AASISTDeepFake.
Mirrors the structure of text_detector_inference.py for consistency.
Usage
-----
from audio_detector_inference import AudioDetectorInference
detector = AudioDetectorInference(checkpoint="best_aasist.pt", threshold=0.5)
result = detector.predict(waveform_np_array, sr=16000)
"""
import os
import numpy as np
import torch
from audio_model import AASISTDeepFake, SAMPLE_RATE, MAX_SAMPLES
class AudioDetectorInference:
    """
    Thin wrapper around AASISTDeepFake for single-clip audio prediction.

    Parameters
    ----------
    checkpoint : str
        Path to the best_aasist.pt state-dict file produced by training.
    threshold : float
        sigmoid(logit) >= threshold → Real.
        Use 0.5 (default) or the optimal F1 threshold printed at the end
        of training (the value labelled "Optimal threshold" in Cell 14).
    device : torch.device | None
        Auto-detects CUDA if None.

    Attributes
    ----------
    threshold : float
        Decision threshold applied to P(real).
    device : torch.device
        Device the model and input tensors are moved to.
    model : AASISTDeepFake | None
        Loaded model in eval mode, or None when the checkpoint file was
        not found (``predict`` then returns an ``{"error": ...}`` dict).
    """

    def __init__(
        self,
        checkpoint: str = "best_aasist.pt",
        threshold: float = 0.5,
        device: "torch.device | None" = None,
    ):
        self.threshold = threshold
        if device is None:
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
        self.device = device
        self.model = None

        if os.path.exists(checkpoint):
            print(f"[AudioDetector] Loading checkpoint: {checkpoint}")
            model = AASISTDeepFake()
            # SECURITY NOTE(review): torch.load unpickles the file, which can
            # execute arbitrary code. Only load checkpoints from a trusted
            # source; consider weights_only=True if the installed torch
            # version supports it.
            state_dict = torch.load(checkpoint, map_location=self.device)
            model.load_state_dict(state_dict)
            # Publish the model only once its weights loaded successfully.
            self.model = model.eval().to(self.device)
            print(f"[AudioDetector] ✅ AASISTDeepFake ready "
                  f"(threshold={self.threshold})")
        else:
            # Deliberately non-fatal: the wrapper can be constructed before
            # the checkpoint is uploaded; predict() reports the problem.
            print(
                f"[AudioDetector] ⚠️ '{checkpoint}' not found.\n"
                f"[AudioDetector] Upload best_aasist.pt to the Space — "
                f"audio predictions will fail until then."
            )

    # ──────────────────────────────────────────────────────────────────────
    @staticmethod
    def _to_float_mono(x: np.ndarray) -> np.ndarray:
        """
        Convert a raw waveform to a 1-D float32 array in roughly [-1, 1].

        Raises
        ------
        ValueError
            If the waveform is empty or is not 1-D / 2-D.
        """
        x = np.asarray(x)
        if x.size == 0:
            raise ValueError(
                "[AudioDetector] Cannot classify an empty waveform."
            )
        if x.ndim not in (1, 2):
            raise ValueError(
                f"[AudioDetector] Expected a 1-D or 2-D waveform, "
                f"got {x.ndim} dimensions."
            )

        x = x.astype(np.float32)

        # ── Normalise int16 recordings ────────────────────────────────────
        # Heuristic: a peak above 1.0 is taken to mean 16-bit PCM values,
        # so they are rescaled by the int16 full-scale value.
        if np.abs(x).max() > 1.0:
            x = x / 32768.0

        # ── Stereo → mono ─────────────────────────────────────────────────
        if x.ndim == 2:
            # The channel axis is the shorter one: (samples, channels) keeps
            # the original axis=1 behaviour, while channel-first
            # (channels, samples) input is averaged along axis 0 instead of
            # being collapsed to a handful of samples.
            channel_axis = 1 if x.shape[0] >= x.shape[1] else 0
            x = x.mean(axis=channel_axis)

        return x

    @staticmethod
    def _fit_length(x: np.ndarray, length: int) -> np.ndarray:
        """Zero-pad at the end, or truncate, so that ``len(x) == length``."""
        missing = length - len(x)
        if missing > 0:
            return np.pad(x, (0, missing))
        return x[:length]

    def _preprocess(self, x: np.ndarray, sr: int) -> torch.Tensor:
        """
        Normalise, convert to mono, resample to SAMPLE_RATE, and
        pad/trim to exactly MAX_SAMPLES samples.

        Parameters
        ----------
        x : np.ndarray   Raw waveform, 1-D (mono) or 2-D (multi-channel).
        sr : int         Sample rate of x.

        Returns
        -------
        torch.Tensor of shape (1, MAX_SAMPLES), dtype float32, on CPU.

        Raises
        ------
        ValueError
            If x is empty or has an unsupported number of dimensions.
        """
        x = self._to_float_mono(x)

        # ── Resample if needed ────────────────────────────────────────────
        if sr != SAMPLE_RATE:
            # Imported lazily so librosa is only required when the input is
            # not already at the model's sample rate.
            import librosa
            print(f"[AudioDetector] Resampling {sr} Hz → {SAMPLE_RATE} Hz …")
            x = librosa.resample(x, orig_sr=sr, target_sr=SAMPLE_RATE)

        # ── Pad or trim to exactly MAX_SAMPLES ────────────────────────────
        x = self._fit_length(x, MAX_SAMPLES)

        return torch.tensor(x, dtype=torch.float32).unsqueeze(0)  # (1, MAX_SAMPLES)

    # ──────────────────────────────────────────────────────────────────────
    def predict(self, x: np.ndarray, sr: int) -> dict:
        """
        Classify a single raw audio clip.

        Parameters
        ----------
        x : np.ndarray   Raw waveform (any sample rate, mono or stereo).
        sr : int         Sample rate of x.

        Returns
        -------
        dict with keys:
            label       : "Real" | "Fake"
            real_prob   : P(real) in [0, 1]
            fake_prob   : P(fake) in [0, 1]
            confidence  : probability of the predicted class in [0, 1]
        or, when no checkpoint was loaded, a dict with the single key
        "error" holding a human-readable message.

        Raises
        ------
        ValueError
            If x is empty or is not a 1-D / 2-D waveform.
        """
        if self.model is None:
            return {
                "error": (
                    "Model not loaded — upload best_aasist.pt to the Space."
                )
            }

        wav = self._preprocess(x, sr).to(self.device)
        with torch.no_grad():
            logit = self.model(wav)  # (1, 1)

        # The model's sigmoid output is interpreted as P(real).
        real_prob = torch.sigmoid(logit).item()
        fake_prob = 1.0 - real_prob
        is_real = real_prob >= self.threshold
        label = "Real" if is_real else "Fake"
        # Confidence is the probability of whichever class was chosen, so it
        # can drop below 0.5 when a non-default threshold is in use.
        confidence = real_prob if is_real else fake_prob

        print(
            f"[AudioDetector] real_prob={real_prob:.4f} "
            f"fake_prob={fake_prob:.4f} → {label}"
        )
        return {
            "label": label,
            "real_prob": round(real_prob, 4),
            "fake_prob": round(fake_prob, 4),
            "confidence": round(confidence, 4),
        }