""" audio_model.py ============== AASISTDeepFake model definition — matches the training notebook exactly. Import this in both training scripts and the Gradio app (via audio_detector_inference.py). Architecture: Raw waveform → SincConv → Downsample (32×) → Res2Block → CNN (2 layers) → GraphAttn (×2) → AttentionPool → Classifier Label convention (from training dataset enumerate(["fake", "real"])): label = 0 → Fake label = 1 → Real sigmoid(logit) >= threshold → Real sigmoid(logit) < threshold → Fake """ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F # ── Audio constants (must match training) ───────────────────────────────────── SAMPLE_RATE = 16_000 MAX_DURATION = 5.0 MAX_SAMPLES = int(SAMPLE_RATE * MAX_DURATION) # 80 000 samples # ── Sub-modules ─────────────────────────────────────────────────────────────── class SincConv(nn.Module): """ Learnable sinc-function band-pass filter bank. Only 2×out_channels parameters (one f_low, one f_high per filter). Initialised from mel-scale frequency bands. """ @staticmethod def to_mel(hz): return 2595 * np.log10(1 + hz / 700) @staticmethod def to_hz(mel): return 700 * (10 ** (mel / 2595) - 1) def __init__(self, out_channels: int = 64, kernel_size: int = 512, sample_rate: int = 16_000): super().__init__() self.out_channels = out_channels self.kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size self.sample_rate = sample_rate low_hz, high_hz = 30, sample_rate / 2 - 100 mel = np.linspace(self.to_mel(low_hz), self.to_mel(high_hz), out_channels + 1) hz = self.to_hz(mel) self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1)) self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1)) half = (self.kernel_size - 1) // 2 n = torch.arange(1, half + 1, dtype=torch.float32) self.register_buffer('n_', (2 * np.pi * n / sample_rate).unsqueeze(0)) self.register_buffer('window_', torch.hamming_window(self.kernel_size)) def forward(self, x: torch.Tensor) -> torch.Tensor: low = 50 + torch.abs(self.low_hz_) high = torch.clamp(low + 50 + torch.abs(self.band_hz_), max=self.sample_rate / 2) band = (high - low)[:, 0] f1 = torch.matmul(low, self.n_) f2 = torch.matmul(high, self.n_) lp1 = torch.sin(f1) / (np.pi * self.n_ / (2 * np.pi)) lp2 = torch.sin(f2) / (np.pi * self.n_ / (2 * np.pi)) bp = (lp2 - lp1) / (2 * band[:, None]) centre = torch.zeros(self.out_channels, 1, device=bp.device) filters = torch.cat([bp.flip(1), centre, bp], dim=1) filters = filters * self.window_ x = x.unsqueeze(1) return F.conv1d(x, filters.unsqueeze(1), padding=self.kernel_size // 2) class Res2Block(nn.Module): """ Multi-scale residual block with inter-group accumulation. Splits channels into `scale` groups; each group accumulates the previous. 
""" def __init__(self, channels: int, scale: int = 8, dilation: int = 1): super().__init__() assert channels % scale == 0, \ f"channels ({channels}) must be divisible by scale ({scale})" self.scale = scale width = channels // scale self.convs = nn.ModuleList([ nn.Conv1d(width, width, 3, padding=dilation, dilation=dilation) for _ in range(scale - 1) ]) self.bns = nn.ModuleList([nn.BatchNorm1d(width) for _ in range(scale - 1)]) def forward(self, x: torch.Tensor) -> torch.Tensor: chunks = torch.chunk(x, self.scale, dim=1) out = [chunks[0]] y = chunks[1] for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): if i > 0: y = y + chunks[i + 1] y = F.gelu(bn(conv(y))) out.append(y) return torch.cat(out, dim=1) class GraphAttn(nn.Module): """ Memory-efficient multi-head self-attention over temporal frames. Sequences longer than 64 tokens are pooled before attention and upsampled back for the residual addition. """ def __init__(self, dim: int, heads: int = 4): super().__init__() self.heads = heads self.head_dim = dim // heads self.qkv = nn.Linear(dim, dim * 3) self.out = nn.Linear(dim, dim) def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape if N > 64: x_pool = F.adaptive_avg_pool1d( x.transpose(1, 2), 64).transpose(1, 2) else: x_pool = x Bp, Np, Cp = x_pool.shape qkv = (self.qkv(x_pool) .reshape(Bp, Np, 3, self.heads, self.head_dim) .permute(2, 0, 3, 1, 4)) q, k, v = qkv.unbind(0) attn = torch.softmax( q @ k.transpose(-2, -1) / (self.head_dim ** 0.5), dim=-1) out = (attn @ v).transpose(1, 2).reshape(Bp, Np, Cp) out = self.out(out) # Upsample back to original length for the residual connection out = F.interpolate( out.transpose(1, 2), size=N, mode='linear', align_corners=False).transpose(1, 2) return out class AttentionPool(nn.Module): """Soft-attention weighted pooling over a sequence.""" def __init__(self, dim: int): super().__init__() self.attn = nn.Linear(dim, 1) def forward(self, x: torch.Tensor) -> torch.Tensor: w = torch.softmax(self.attn(x), dim=1) # (B, T, 1) return (w * x).sum(dim=1) # (B, dim) # ── Main model ──────────────────────────────────────────────────────────────── class AASISTDeepFake(nn.Module): """ AASISTDeepFake — memory-efficient raw-waveform audio spoof detector. 
# ── Main model ────────────────────────────────────────────────────────────────
class AASISTDeepFake(nn.Module):
    """
    AASISTDeepFake — memory-efficient raw-waveform audio spoof detector.

    Input  : (B, 80 000) float32 waveform, normalised to [-1, 1]
    Output : (B, 1) raw logit → sigmoid → P(real)

    Prediction:
        sigmoid(logit) >= threshold → Real (label 1)
        sigmoid(logit) <  threshold → Fake (label 0)
    """

    def __init__(
        self,
        sinc_ch: int = 64,
        sinc_kernel: int = 512,
        hidden: int = 128,
        graph_heads: int = 4,
        n_graph: int = 2,
    ):
        super().__init__()
        self.sinc = SincConv(sinc_ch, sinc_kernel, SAMPLE_RATE)
        self.bn_sinc = nn.BatchNorm1d(sinc_ch)

        # Aggressive downsampling: T → T/32 (kills OOM on long sequences)
        self.downsample = nn.Sequential(
            nn.Conv1d(sinc_ch, sinc_ch, kernel_size=8, stride=8),
            nn.BatchNorm1d(sinc_ch),
            nn.GELU(),
            nn.Conv1d(sinc_ch, sinc_ch, kernel_size=4, stride=4),
            nn.BatchNorm1d(sinc_ch),
            nn.GELU(),
        )

        self.encoder = nn.Sequential(
            Res2Block(sinc_ch),
            nn.BatchNorm1d(sinc_ch),
            nn.GELU(),
        )

        self.cnn = nn.Sequential(
            nn.Conv1d(sinc_ch, hidden, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden),
            nn.GELU(),
            nn.Conv1d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden),
            nn.GELU(),
        )

        self.graph_layers = nn.ModuleList(
            [GraphAttn(hidden, graph_heads) for _ in range(n_graph)])
        self.layer_norms = nn.ModuleList(
            [nn.LayerNorm(hidden) for _ in range(n_graph)])

        self.pool = AttentionPool(hidden)
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, 64),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.abs(self.sinc(x))        # (B, sinc_ch, T)
        x = F.gelu(self.bn_sinc(x))
        x = self.downsample(x)             # (B, sinc_ch, T/32)
        x = self.encoder(x)
        x = self.cnn(x)                    # (B, hidden, T/32)

        x = x.transpose(1, 2)              # (B, T/32, hidden)
        for attn, ln in zip(self.graph_layers, self.layer_norms):
            x = ln(x + attn(x))            # residual attention, then LayerNorm

        pooled = self.pool(x)              # (B, hidden)
        return self.classifier(pooled)     # (B, 1)


# ── Helper ────────────────────────────────────────────────────────────────────
def load_audio_model(
    checkpoint: str,
    device: Optional[torch.device] = None,
) -> AASISTDeepFake:
    """Load a trained AASISTDeepFake from a .pt state-dict checkpoint."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AASISTDeepFake()
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval().to(device)
    return model
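
# End-to-end smoke test (an illustrative sketch: random noise through
# untrained weights, so the printed probabilities are meaningless; real
# use goes through load_audio_model() with a trained checkpoint).
if __name__ == "__main__":
    _model = AASISTDeepFake().eval()
    _batch = torch.randn(2, MAX_SAMPLES)   # (B, 80 000) dummy 5 s batch
    with torch.no_grad():
        _logits = _model(_batch)           # (B, 1) raw logits
    print("P(real):", torch.sigmoid(_logits).squeeze(1).tolist())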