voice-detection-api / app /models /spectrogram_cnn.py
bakshia's picture
Cleanup: Remove debug prints, improve logging, and refactor comments
476d044
Raw
History Blame Contribute Delete
12.6 kB
"""
Mel Spectrogram CNN Classifier for Voice Deepfake Detection
Uses CNN to analyze mel spectrograms for visual artifacts
indicative of AI-generated speech (vocoder patterns, unnatural harmonics).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchaudio
import torchaudio.transforms as T
from typing import Dict, Tuple, Optional
import logging
logger = logging.getLogger(__name__)
# Import audio utilities for MP3 support
from ..audio_utils import load_audio_torch, extract_advanced_features
class SpectrogramCNN(nn.Module):
"""
CNN architecture for analyzing mel spectrograms.
Inspired by ResNet but optimized for audio deepfake detection.
"""
def __init__(self, num_classes: int = 2):
super().__init__()
# Initial convolution
self.conv1 = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
# Residual-style blocks
self.block1 = self._make_block(32, 64)
self.block2 = self._make_block(64, 128)
self.block3 = self._make_block(128, 256)
# Attention mechanism for focusing on relevant frequency bands
self.attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(256, 64),
nn.ReLU(),
nn.Linear(64, 256),
nn.Sigmoid()
)
# Global pooling and classifier
self.global_pool = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Dropout(0.3),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(128, num_classes)
)
self._initialize_weights()
def _make_block(self, in_channels: int, out_channels: int) -> nn.Module:
"""Create a convolutional block with skip connection"""
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
def _initialize_weights(self):
"""Initialize weights with Xavier uniform"""
for module in self.modules():
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(module, nn.BatchNorm2d):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass returning logits and attention weights.
Args:
x: Input mel spectrogram [B, 1, H, W]
Returns:
logits: Classification logits [B, 2]
attention_weights: Frequency attention weights [B, 256]
"""
# Feature extraction
x = self.conv1(x)
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
# Attention
attn = self.attention(x)
x = x * attn.unsqueeze(-1).unsqueeze(-1)
# Classification
x = self.global_pool(x)
logits = self.classifier(x)
return logits, attn
class SpectrogramDetector:
"""
Mel Spectrogram-based detector for AI-generated voice detection.
Converts audio to mel spectrograms and uses CNN to detect
visual patterns indicative of neural vocoders.
"""
def __init__(self, device: str = None):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"[SpectrogramDetector] Using device: {self.device}")
# Initialize CNN model
self.model = SpectrogramCNN(num_classes=2)
self.model.to(self.device)
self.model.eval()
# Mel spectrogram parameters
self.sample_rate = 16000
self.n_mels = 128
self.n_fft = 1024
self.hop_length = 256
self.target_length = 128 # Fixed width for CNN input
# Mel transform
self.mel_transform = T.MelSpectrogram(
sample_rate=self.sample_rate,
n_fft=self.n_fft,
hop_length=self.hop_length,
n_mels=self.n_mels
)
self.resampler_cache = {}
def _load_audio(self, audio_path: str) -> torch.Tensor:
"""Load and resample audio with MP3 support"""
# Use audio_utils which supports MP3 via soundfile
waveform = load_audio_torch(audio_path, target_sr=self.sample_rate)
return waveform.unsqueeze(0) # Add channel dim
def _create_mel_spectrogram(self, waveform: torch.Tensor) -> torch.Tensor:
"""Convert waveform to normalized mel spectrogram"""
# Compute mel spectrogram
mel_spec = self.mel_transform(waveform)
# Convert to log scale (dB)
mel_spec = torch.log(mel_spec + 1e-9)
# Normalize
mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-8)
# Resize to fixed width
if mel_spec.shape[-1] != self.target_length:
mel_spec = F.interpolate(
mel_spec.unsqueeze(0),
size=(self.n_mels, self.target_length),
mode='bilinear',
align_corners=False
).squeeze(0)
return mel_spec
def _analyze_spectrogram(self, mel_spec: torch.Tensor) -> Dict:
"""Analyze spectrogram for AI-typical patterns"""
spec = mel_spec.squeeze().numpy()
analysis = {}
# Check for unnaturally smooth regions
gradient = np.gradient(spec, axis=1)
analysis["temporal_smoothness"] = float(1.0 / (np.std(gradient) + 1e-8))
# Check frequency band energy distribution
low_band = spec[:32, :].mean()
mid_band = spec[32:96, :].mean()
high_band = spec[96:, :].mean()
analysis["low_band_energy"] = float(low_band)
analysis["mid_band_energy"] = float(mid_band)
analysis["high_band_energy"] = float(high_band)
# Check for vocoder grid patterns (common in neural TTS)
fft_spec = np.abs(np.fft.fft2(spec))
analysis["periodicity_score"] = float(fft_spec[1:10, 1:10].mean() / fft_spec.mean())
# Harmonic-to-noise ratio approximation
sorted_spec = np.sort(spec.flatten())[::-1]
top_10_pct = sorted_spec[:int(len(sorted_spec) * 0.1)].mean()
bottom_50_pct = sorted_spec[int(len(sorted_spec) * 0.5):].mean()
analysis["hnr_approx"] = float(top_10_pct / (bottom_50_pct + 1e-8))
return analysis
def detect(self, audio_path: str) -> Dict:
"""
Detect if audio is AI-generated using spectrogram analysis.
Args:
audio_path: Path to audio file
Returns:
Dictionary with detection results
"""
waveform = self._load_audio(audio_path)
mel_spec = self._create_mel_spectrogram(waveform)
spec_analysis = self._analyze_spectrogram(mel_spec)
spec_analysis['energy_cv'] = self._compute_energy_cv(waveform)
adv_features = extract_advanced_features(audio_path, self.sample_rate)
spec_analysis.update(adv_features)
ai_score = self._compute_ai_score_from_spectrogram(spec_analysis)
ai_score = max(0.0, min(1.0, ai_score))
is_ai = ai_score >= 0.5
confidence = abs(ai_score - 0.5) * 2
result = {
"classification": "ai_generated" if is_ai else "human",
"confidence": confidence,
"model_scores": {
"ai_probability": ai_score,
"human_probability": 1 - ai_score
},
"spectrogram_analysis": spec_analysis,
"frequency_attention": [],
"indicators": self._generate_indicators(1 if is_ai else 0, spec_analysis)
}
return result
def _compute_energy_cv(self, waveform: torch.Tensor) -> float:
"""Compute Coefficient of Variation of energy"""
if waveform.dim() > 1:
waveform = waveform.squeeze()
chunk_size = len(waveform) // 10
if chunk_size == 0: return 0.0
energies = []
for i in range(10):
chunk = waveform[i*chunk_size:(i+1)*chunk_size]
energies.append(float(torch.sqrt(torch.mean(chunk ** 2))))
energy_std = np.std(energies) if energies else 0
energy_mean = np.mean(energies) if energies else 1
return float(energy_std / (energy_mean + 1e-8))
def _compute_ai_score_from_spectrogram(self, analysis: Dict) -> float:
"""
Compute AI probability from spectrogram analysis.
AI-generated voices typically have:
- Higher temporal smoothness (consistent spectrum)
- Low energy variation (consistent volume, less natural pausing)
- Specific band energy distributions
"""
score = 0.5 # Start neutral
# 1. Energy CV - STRONGEST INDICATOR (prioritize this)
# High variation (>0.5) is very typical of human speech (pauses/breathing)
# Low variation (<0.2) is typical of AI
energy_cv = analysis.get("energy_cv", 0.25)
if energy_cv > 0.7:
score -= 0.30 # Very strong human signal
elif energy_cv > 0.5:
score -= 0.20 # Strong human signal
elif energy_cv > 0.35:
score -= 0.10 # Moderate human signal
elif energy_cv < 0.2:
score += 0.10 # Consistent energy = AI like
# 2. Advanced Features
flux = analysis.get("spectral_flux", 0)
mfcc_var = analysis.get("mfcc_variance", 0)
# Flux Heuristic - widened threshold
if flux > 2.4:
score += 0.20 # Strong AI signal
elif flux > 2.0:
score += 0.10 # Moderate AI signal
elif flux < 1.8 and flux > 0.1:
score -= 0.10 # Natural human transitions
# MFCC Var Heuristic
if mfcc_var > 1900:
score -= 0.20 # High complexity = human
# 3. Temporal smoothness (de-prioritized - often misleading)
# Only use extreme values
smoothness = analysis.get("temporal_smoothness", 1.0)
if smoothness > 5.0:
score += 0.10 # Very unnaturally smooth
# Additional features removed as unreliable for edge cases
# (periodicity, high_band, hnr can cause false positives on clean recordings)
return score
def _generate_indicators(self, pred_class: int, analysis: Dict) -> list:
"""Generate human-readable indicators"""
indicators = []
if pred_class == 1: # AI detected
if analysis["temporal_smoothness"] > 5:
indicators.append("Unnaturally smooth spectrogram transitions")
if analysis["periodicity_score"] > 2:
indicators.append("Periodic patterns suggesting neural vocoder")
if analysis["high_band_energy"] < -2:
indicators.append("Reduced high-frequency content typical of TTS")
if not indicators:
indicators.append("Spectrogram patterns consistent with AI synthesis")
else: # Human
if analysis["hnr_approx"] > 10:
indicators.append("Strong harmonic structure of natural voice")
if analysis["temporal_smoothness"] < 3:
indicators.append("Natural variation in spectral features")
if not indicators:
indicators.append("Spectrogram shows natural speech characteristics")
return indicators