"""
audio_detector_inference.py
============================
Inference wrapper for AASISTDeepFake.
Mirrors the structure of text_detector_inference.py for consistency.

Usage
-----
    from audio_detector_inference import AudioDetectorInference

    detector = AudioDetectorInference(checkpoint="best_aasist.pt", threshold=0.5)
    result = detector.predict(waveform_np_array, sr=16000)
"""

import os
from typing import Optional

import numpy as np
import torch

from audio_model import AASISTDeepFake, SAMPLE_RATE, MAX_SAMPLES


class AudioDetectorInference:
    """
    Thin wrapper around AASISTDeepFake for single-clip audio prediction.

    Parameters
    ----------
    checkpoint : str
        Path to the best_aasist.pt state-dict file produced by training.
    threshold : float
        sigmoid(logit) >= threshold → Real.
        Use 0.5 (default) or the optimal F1 threshold printed at the end
        of training (the value labelled "Optimal threshold" in Cell 14).
    device : torch.device | None
        Auto-detects CUDA if None.

    Notes
    -----
    If ``checkpoint`` does not exist, construction still succeeds but
    ``self.model`` stays ``None`` and ``predict`` returns an error dict
    instead of a prediction.
    """

    def __init__(
        self,
        checkpoint: str = "best_aasist.pt",
        threshold: float = 0.5,
        device: Optional[torch.device] = None,
    ):
        self.threshold = threshold
        if device is None:
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
        self.device = device
        # Stays None when the checkpoint is missing; predict() checks this.
        self.model = None

        if os.path.exists(checkpoint):
            print(f"[AudioDetector] Loading checkpoint: {checkpoint}")
            self.model = AASISTDeepFake()
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source (consider weights_only=True
            # on torch versions that support it).
            self.model.load_state_dict(
                torch.load(checkpoint, map_location=self.device)
            )
            self.model.eval().to(self.device)
            print(
                f"[AudioDetector] ✅ AASISTDeepFake ready "
                f"(threshold={self.threshold})"
            )
        else:
            print(
                f"[AudioDetector] ⚠️ '{checkpoint}' not found.\n"
                f"[AudioDetector] Upload best_aasist.pt to the Space — "
                f"audio predictions will fail until then."
            )

    def _preprocess(self, x: np.ndarray, sr: int) -> torch.Tensor:
        """
        Normalise, convert to mono, resample to SAMPLE_RATE (16 kHz), and
        pad/trim to exactly MAX_SAMPLES (80 000) samples.

        An empty waveform is accepted and yields an all-zero clip.

        Parameters
        ----------
        x : np.ndarray
            Raw waveform, 1-D (mono) or 2-D (multi-channel).
        sr : int
            Sample rate of ``x``.

        Returns
        -------
        torch.Tensor of shape (1, MAX_SAMPLES), dtype float32, on CPU.
        """
        x = np.asarray(x, dtype=np.float32)

        # A peak above 1.0 is taken to mean integer PCM (presumably int16),
        # so scale into [-1, 1]. The size guard avoids the ValueError that
        # max() raises on a zero-size array.
        if x.size and np.abs(x).max() > 1.0:
            x = x / 32768.0

        # Down-mix by averaging over axis 1 — assumes a
        # (samples, channels) layout; TODO confirm against the caller.
        if x.ndim == 2:
            x = x.mean(axis=1)

        # Nothing to resample for an empty clip.
        if sr != SAMPLE_RATE and x.size:
            import librosa  # imported lazily; only needed when rates differ

            print(f"[AudioDetector] Resampling {sr} Hz → {SAMPLE_RATE} Hz …")
            x = librosa.resample(x, orig_sr=sr, target_sr=SAMPLE_RATE)

        # Right-pad with zeros or truncate so the length is MAX_SAMPLES.
        if len(x) < MAX_SAMPLES:
            x = np.pad(x, (0, MAX_SAMPLES - len(x)))
        else:
            x = x[:MAX_SAMPLES]

        return torch.tensor(x, dtype=torch.float32).unsqueeze(0)

    def predict(self, x: np.ndarray, sr: int) -> dict:
        """
        Classify a single raw audio clip.

        Parameters
        ----------
        x : np.ndarray
            Raw waveform (any sample rate, mono or stereo).
        sr : int
            Sample rate of ``x``.

        Returns
        -------
        dict with keys:
            label      : "Real" | "Fake"
            real_prob  : P(real) in [0, 1]
            fake_prob  : P(fake) in [0, 1]
            confidence : probability of the predicted class in [0, 1]
        or, when no checkpoint was loaded, a dict with the single key
        "error" holding a human-readable message.
        """
        if self.model is None:
            return {
                "error": (
                    "Model not loaded — upload best_aasist.pt to the Space."
                )
            }

        wav = self._preprocess(x, sr).to(self.device)

        with torch.no_grad():
            logit = self.model(wav)
            # The model output is treated as a single logit for "real".
            real_prob = torch.sigmoid(logit).item()

        fake_prob = 1.0 - real_prob
        is_real = real_prob >= self.threshold
        label = "Real" if is_real else "Fake"
        confidence = real_prob if is_real else fake_prob

        print(
            f"[AudioDetector] real_prob={real_prob:.4f} "
            f"fake_prob={fake_prob:.4f} → {label}"
        )

        return {
            "label": label,
            "real_prob": round(real_prob, 4),
            "fake_prob": round(fake_prob, 4),
            "confidence": round(confidence, 4),
        }