Spaces:
Sleeping
Sleeping
| """ | |
| Mel Spectrogram CNN Classifier for Voice Deepfake Detection | |
| Uses CNN to analyze mel spectrograms for visual artifacts | |
| indicative of AI-generated speech (vocoder patterns, unnatural harmonics). | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import torchaudio | |
| import torchaudio.transforms as T | |
| from typing import Dict, Tuple, Optional | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| # Import audio utilities for MP3 support | |
| from ..audio_utils import load_audio_torch, extract_advanced_features | |
| class SpectrogramCNN(nn.Module): | |
| """ | |
| CNN architecture for analyzing mel spectrograms. | |
| Inspired by ResNet but optimized for audio deepfake detection. | |
| """ | |
| def __init__(self, num_classes: int = 2): | |
| super().__init__() | |
| # Initial convolution | |
| self.conv1 = nn.Sequential( | |
| nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3), | |
| nn.BatchNorm2d(32), | |
| nn.ReLU(), | |
| nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |
| ) | |
| # Residual-style blocks | |
| self.block1 = self._make_block(32, 64) | |
| self.block2 = self._make_block(64, 128) | |
| self.block3 = self._make_block(128, 256) | |
| # Attention mechanism for focusing on relevant frequency bands | |
| self.attention = nn.Sequential( | |
| nn.AdaptiveAvgPool2d(1), | |
| nn.Flatten(), | |
| nn.Linear(256, 64), | |
| nn.ReLU(), | |
| nn.Linear(64, 256), | |
| nn.Sigmoid() | |
| ) | |
| # Global pooling and classifier | |
| self.global_pool = nn.AdaptiveAvgPool2d(1) | |
| self.classifier = nn.Sequential( | |
| nn.Flatten(), | |
| nn.Dropout(0.3), | |
| nn.Linear(256, 128), | |
| nn.ReLU(), | |
| nn.Dropout(0.2), | |
| nn.Linear(128, num_classes) | |
| ) | |
| self._initialize_weights() | |
| def _make_block(self, in_channels: int, out_channels: int) -> nn.Module: | |
| """Create a convolutional block with skip connection""" | |
| return nn.Sequential( | |
| nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1), | |
| nn.BatchNorm2d(out_channels), | |
| nn.ReLU(), | |
| nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), | |
| nn.BatchNorm2d(out_channels), | |
| nn.ReLU() | |
| ) | |
| def _initialize_weights(self): | |
| """Initialize weights with Xavier uniform""" | |
| for module in self.modules(): | |
| if isinstance(module, nn.Conv2d): | |
| nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') | |
| elif isinstance(module, nn.BatchNorm2d): | |
| nn.init.constant_(module.weight, 1) | |
| nn.init.constant_(module.bias, 0) | |
| elif isinstance(module, nn.Linear): | |
| nn.init.xavier_uniform_(module.weight) | |
| if module.bias is not None: | |
| nn.init.zeros_(module.bias) | |
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Forward pass returning logits and attention weights. | |
| Args: | |
| x: Input mel spectrogram [B, 1, H, W] | |
| Returns: | |
| logits: Classification logits [B, 2] | |
| attention_weights: Frequency attention weights [B, 256] | |
| """ | |
| # Feature extraction | |
| x = self.conv1(x) | |
| x = self.block1(x) | |
| x = self.block2(x) | |
| x = self.block3(x) | |
| # Attention | |
| attn = self.attention(x) | |
| x = x * attn.unsqueeze(-1).unsqueeze(-1) | |
| # Classification | |
| x = self.global_pool(x) | |
| logits = self.classifier(x) | |
| return logits, attn | |
| class SpectrogramDetector: | |
| """ | |
| Mel Spectrogram-based detector for AI-generated voice detection. | |
| Converts audio to mel spectrograms and uses CNN to detect | |
| visual patterns indicative of neural vocoders. | |
| """ | |
| def __init__(self, device: str = None): | |
| self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") | |
| logger.info(f"[SpectrogramDetector] Using device: {self.device}") | |
| # Initialize CNN model | |
| self.model = SpectrogramCNN(num_classes=2) | |
| self.model.to(self.device) | |
| self.model.eval() | |
| # Mel spectrogram parameters | |
| self.sample_rate = 16000 | |
| self.n_mels = 128 | |
| self.n_fft = 1024 | |
| self.hop_length = 256 | |
| self.target_length = 128 # Fixed width for CNN input | |
| # Mel transform | |
| self.mel_transform = T.MelSpectrogram( | |
| sample_rate=self.sample_rate, | |
| n_fft=self.n_fft, | |
| hop_length=self.hop_length, | |
| n_mels=self.n_mels | |
| ) | |
| self.resampler_cache = {} | |
| def _load_audio(self, audio_path: str) -> torch.Tensor: | |
| """Load and resample audio with MP3 support""" | |
| # Use audio_utils which supports MP3 via soundfile | |
| waveform = load_audio_torch(audio_path, target_sr=self.sample_rate) | |
| return waveform.unsqueeze(0) # Add channel dim | |
| def _create_mel_spectrogram(self, waveform: torch.Tensor) -> torch.Tensor: | |
| """Convert waveform to normalized mel spectrogram""" | |
| # Compute mel spectrogram | |
| mel_spec = self.mel_transform(waveform) | |
| # Convert to log scale (dB) | |
| mel_spec = torch.log(mel_spec + 1e-9) | |
| # Normalize | |
| mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-8) | |
| # Resize to fixed width | |
| if mel_spec.shape[-1] != self.target_length: | |
| mel_spec = F.interpolate( | |
| mel_spec.unsqueeze(0), | |
| size=(self.n_mels, self.target_length), | |
| mode='bilinear', | |
| align_corners=False | |
| ).squeeze(0) | |
| return mel_spec | |
| def _analyze_spectrogram(self, mel_spec: torch.Tensor) -> Dict: | |
| """Analyze spectrogram for AI-typical patterns""" | |
| spec = mel_spec.squeeze().numpy() | |
| analysis = {} | |
| # Check for unnaturally smooth regions | |
| gradient = np.gradient(spec, axis=1) | |
| analysis["temporal_smoothness"] = float(1.0 / (np.std(gradient) + 1e-8)) | |
| # Check frequency band energy distribution | |
| low_band = spec[:32, :].mean() | |
| mid_band = spec[32:96, :].mean() | |
| high_band = spec[96:, :].mean() | |
| analysis["low_band_energy"] = float(low_band) | |
| analysis["mid_band_energy"] = float(mid_band) | |
| analysis["high_band_energy"] = float(high_band) | |
| # Check for vocoder grid patterns (common in neural TTS) | |
| fft_spec = np.abs(np.fft.fft2(spec)) | |
| analysis["periodicity_score"] = float(fft_spec[1:10, 1:10].mean() / fft_spec.mean()) | |
| # Harmonic-to-noise ratio approximation | |
| sorted_spec = np.sort(spec.flatten())[::-1] | |
| top_10_pct = sorted_spec[:int(len(sorted_spec) * 0.1)].mean() | |
| bottom_50_pct = sorted_spec[int(len(sorted_spec) * 0.5):].mean() | |
| analysis["hnr_approx"] = float(top_10_pct / (bottom_50_pct + 1e-8)) | |
| return analysis | |
| def detect(self, audio_path: str) -> Dict: | |
| """ | |
| Detect if audio is AI-generated using spectrogram analysis. | |
| Args: | |
| audio_path: Path to audio file | |
| Returns: | |
| Dictionary with detection results | |
| """ | |
| waveform = self._load_audio(audio_path) | |
| mel_spec = self._create_mel_spectrogram(waveform) | |
| spec_analysis = self._analyze_spectrogram(mel_spec) | |
| spec_analysis['energy_cv'] = self._compute_energy_cv(waveform) | |
| adv_features = extract_advanced_features(audio_path, self.sample_rate) | |
| spec_analysis.update(adv_features) | |
| ai_score = self._compute_ai_score_from_spectrogram(spec_analysis) | |
| ai_score = max(0.0, min(1.0, ai_score)) | |
| is_ai = ai_score >= 0.5 | |
| confidence = abs(ai_score - 0.5) * 2 | |
| result = { | |
| "classification": "ai_generated" if is_ai else "human", | |
| "confidence": confidence, | |
| "model_scores": { | |
| "ai_probability": ai_score, | |
| "human_probability": 1 - ai_score | |
| }, | |
| "spectrogram_analysis": spec_analysis, | |
| "frequency_attention": [], | |
| "indicators": self._generate_indicators(1 if is_ai else 0, spec_analysis) | |
| } | |
| return result | |
| def _compute_energy_cv(self, waveform: torch.Tensor) -> float: | |
| """Compute Coefficient of Variation of energy""" | |
| if waveform.dim() > 1: | |
| waveform = waveform.squeeze() | |
| chunk_size = len(waveform) // 10 | |
| if chunk_size == 0: return 0.0 | |
| energies = [] | |
| for i in range(10): | |
| chunk = waveform[i*chunk_size:(i+1)*chunk_size] | |
| energies.append(float(torch.sqrt(torch.mean(chunk ** 2)))) | |
| energy_std = np.std(energies) if energies else 0 | |
| energy_mean = np.mean(energies) if energies else 1 | |
| return float(energy_std / (energy_mean + 1e-8)) | |
| def _compute_ai_score_from_spectrogram(self, analysis: Dict) -> float: | |
| """ | |
| Compute AI probability from spectrogram analysis. | |
| AI-generated voices typically have: | |
| - Higher temporal smoothness (consistent spectrum) | |
| - Low energy variation (consistent volume, less natural pausing) | |
| - Specific band energy distributions | |
| """ | |
| score = 0.5 # Start neutral | |
| # 1. Energy CV - STRONGEST INDICATOR (prioritize this) | |
| # High variation (>0.5) is very typical of human speech (pauses/breathing) | |
| # Low variation (<0.2) is typical of AI | |
| energy_cv = analysis.get("energy_cv", 0.25) | |
| if energy_cv > 0.7: | |
| score -= 0.30 # Very strong human signal | |
| elif energy_cv > 0.5: | |
| score -= 0.20 # Strong human signal | |
| elif energy_cv > 0.35: | |
| score -= 0.10 # Moderate human signal | |
| elif energy_cv < 0.2: | |
| score += 0.10 # Consistent energy = AI like | |
| # 2. Advanced Features | |
| flux = analysis.get("spectral_flux", 0) | |
| mfcc_var = analysis.get("mfcc_variance", 0) | |
| # Flux Heuristic - widened threshold | |
| if flux > 2.4: | |
| score += 0.20 # Strong AI signal | |
| elif flux > 2.0: | |
| score += 0.10 # Moderate AI signal | |
| elif flux < 1.8 and flux > 0.1: | |
| score -= 0.10 # Natural human transitions | |
| # MFCC Var Heuristic | |
| if mfcc_var > 1900: | |
| score -= 0.20 # High complexity = human | |
| # 3. Temporal smoothness (de-prioritized - often misleading) | |
| # Only use extreme values | |
| smoothness = analysis.get("temporal_smoothness", 1.0) | |
| if smoothness > 5.0: | |
| score += 0.10 # Very unnaturally smooth | |
| # Additional features removed as unreliable for edge cases | |
| # (periodicity, high_band, hnr can cause false positives on clean recordings) | |
| return score | |
| def _generate_indicators(self, pred_class: int, analysis: Dict) -> list: | |
| """Generate human-readable indicators""" | |
| indicators = [] | |
| if pred_class == 1: # AI detected | |
| if analysis["temporal_smoothness"] > 5: | |
| indicators.append("Unnaturally smooth spectrogram transitions") | |
| if analysis["periodicity_score"] > 2: | |
| indicators.append("Periodic patterns suggesting neural vocoder") | |
| if analysis["high_band_energy"] < -2: | |
| indicators.append("Reduced high-frequency content typical of TTS") | |
| if not indicators: | |
| indicators.append("Spectrogram patterns consistent with AI synthesis") | |
| else: # Human | |
| if analysis["hnr_approx"] > 10: | |
| indicators.append("Strong harmonic structure of natural voice") | |
| if analysis["temporal_smoothness"] < 3: | |
| indicators.append("Natural variation in spectral features") | |
| if not indicators: | |
| indicators.append("Spectrogram shows natural speech characteristics") | |
| return indicators | |