""" audio_detector_inference.py ============================ Inference wrapper for AASISTDeepFake. Mirrors the structure of text_detector_inference.py for consistency. Usage ----- from audio_detector_inference import AudioDetectorInference detector = AudioDetectorInference(checkpoint="best_aasist.pt", threshold=0.5) result = detector.predict(waveform_np_array, sample_rate=16000) """ import os import numpy as np import torch from audio_model import AASISTDeepFake, SAMPLE_RATE, MAX_SAMPLES class AudioDetectorInference: """ Thin wrapper around AASISTDeepFake for single-clip audio prediction. Parameters ---------- checkpoint : str Path to the best_aasist.pt state-dict file produced by training. threshold : float sigmoid(logit) >= threshold → Real. Use 0.5 (default) or the optimal F1 threshold printed at the end of training (the value labelled "Optimal threshold" in Cell 14). device : torch.device | None Auto-detects CUDA if None. """ def __init__( self, checkpoint: str = "best_aasist.pt", threshold: float = 0.5, device: torch.device = None, ): self.threshold = threshold self.device = device or torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) self.model = None if os.path.exists(checkpoint): print(f"[AudioDetector] Loading checkpoint: {checkpoint}") self.model = AASISTDeepFake() self.model.load_state_dict( torch.load(checkpoint, map_location=self.device) ) self.model.eval().to(self.device) print(f"[AudioDetector] ✅ AASISTDeepFake ready " f"(threshold={self.threshold})") else: print( f"[AudioDetector] ⚠️ '{checkpoint}' not found.\n" f"[AudioDetector] Upload best_aasist.pt to the Space — " f"audio predictions will fail until then." ) # ────────────────────────────────────────────────────────────────────────── def _preprocess(self, x: np.ndarray, sr: int) -> torch.Tensor: """ Normalise, convert to mono, resample to 16 kHz, and pad/trim to exactly MAX_SAMPLES (80 000). Returns ------- torch.Tensor of shape (1, 80 000), dtype float32, on CPU. """ x = x.astype(np.float32) # ── Normalise int16 recordings ──────────────────────────────────────── if np.abs(x).max() > 1.0: x = x / 32768.0 # ── Stereo → mono ───────────────────────────────────────────────────── if x.ndim == 2: x = x.mean(axis=1) # ── Resample if needed ──────────────────────────────────────────────── if sr != SAMPLE_RATE: import librosa print(f"[AudioDetector] Resampling {sr} Hz → {SAMPLE_RATE} Hz …") x = librosa.resample(x, orig_sr=sr, target_sr=SAMPLE_RATE) # ── Pad or trim to exactly MAX_SAMPLES ──────────────────────────────── if len(x) < MAX_SAMPLES: x = np.pad(x, (0, MAX_SAMPLES - len(x))) else: x = x[:MAX_SAMPLES] return torch.tensor(x, dtype=torch.float32).unsqueeze(0) # (1, 80000) # ────────────────────────────────────────────────────────────────────────── def predict(self, x: np.ndarray, sr: int) -> dict: """ Classify a single raw audio clip. Parameters ---------- x : np.ndarray Raw waveform (any sample rate, mono or stereo). sr : int Sample rate of x. Returns ------- dict with keys: label : "Real" | "Fake" real_prob : P(real) in [0, 1] fake_prob : P(fake) in [0, 1] confidence : probability of the predicted class in [0, 1] """ if self.model is None: return { "error": ( "Model not loaded — upload best_aasist.pt to the Space." ) } wav = self._preprocess(x, sr).to(self.device) with torch.no_grad(): logit = self.model(wav) # (1, 1) real_prob = torch.sigmoid(logit).item() fake_prob = 1.0 - real_prob is_real = real_prob >= self.threshold label = "Real" if is_real else "Fake" confidence = real_prob if is_real else fake_prob print( f"[AudioDetector] real_prob={real_prob:.4f} " f"fake_prob={fake_prob:.4f} → {label}" ) return { "label": label, "real_prob": round(real_prob, 4), "fake_prob": round(fake_prob, 4), "confidence": round(confidence, 4), }