| """ |
| audio_model.py |
| ============== |
| AASISTDeepFake model definition β matches the training notebook exactly. |
| Import this in both training scripts and the Gradio app (via audio_detector_inference.py). |
| |
| Architecture: |
| Raw waveform β SincConv β Downsample (32Γ) β Res2Block |
| β CNN (2 layers) β GraphAttn (Γ2) β AttentionPool β Classifier |
| |
| Label convention (from training dataset enumerate(["fake", "real"])): |
| label = 0 β Fake |
| label = 1 β Real |
| sigmoid(logit) >= threshold β Real |
| sigmoid(logit) < threshold β Fake |
| """ |
|
|
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
| |
# Audio preprocessing constants shared by training and inference.
SAMPLE_RATE = 16_000                             # Hz
MAX_DURATION = 5.0                               # seconds per clip
MAX_SAMPLES = int(SAMPLE_RATE * MAX_DURATION)    # 80_000 samples
|
|
|
|
| |
|
|
class SincConv(nn.Module):
    """
    Learnable sinc-function band-pass filter bank (SincNet-style).

    Each of the ``out_channels`` filters is parameterised by just two
    scalars (low cut-off and bandwidth), i.e. 2×out_channels learnable
    parameters in total.  Cut-offs are initialised from an equally spaced
    mel-scale grid.
    """

    # Hz <-> mel conversions, used only to build the initial filter grid.
    @staticmethod
    def to_mel(hz): return 2595 * np.log10(1 + hz / 700)
    @staticmethod
    def to_hz(mel): return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, out_channels: int = 64, kernel_size: int = 512,
                 sample_rate: int = 16_000):
        super().__init__()
        self.out_channels = out_channels
        # Force an odd kernel so the filter has a single centre tap and
        # conv1d with padding=kernel_size//2 preserves the input length.
        self.kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
        self.sample_rate = sample_rate

        # Initial band edges: out_channels+1 points equally spaced on the
        # mel scale between 30 Hz and (Nyquist - 100) Hz.
        low_hz, high_hz = 30, sample_rate / 2 - 100
        mel = np.linspace(self.to_mel(low_hz), self.to_mel(high_hz), out_channels + 1)
        hz = self.to_hz(mel)

        # Learnable per-filter low cut-off and bandwidth in Hz, shape (C, 1).
        self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))

        # Positive half of the (angular-frequency scaled) time axis, plus a
        # Hamming window; buffers so they move with the module's device.
        half = (self.kernel_size - 1) // 2
        n = torch.arange(1, half + 1, dtype=torch.float32)
        self.register_buffer('n_', (2 * np.pi * n / sample_rate).unsqueeze(0))
        self.register_buffer('window_', torch.hamming_window(self.kernel_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the filter bank to a (B, T) waveform; returns (B, C, T)
        (odd kernel + symmetric padding keep the time length unchanged)."""
        # Constrain cut-offs: low >= 50 Hz, band >= 50 Hz, high <= Nyquist.
        low = 50 + torch.abs(self.low_hz_)
        high = torch.clamp(low + 50 + torch.abs(self.band_hz_),
                           max=self.sample_rate / 2)
        band = (high - low)[:, 0]

        # Band-pass = difference of two sinc low-pass responses, evaluated
        # on the positive half-axis only (the filter is symmetric).
        f1 = torch.matmul(low, self.n_)
        f2 = torch.matmul(high, self.n_)
        lp1 = torch.sin(f1) / (np.pi * self.n_ / (2 * np.pi))
        lp2 = torch.sin(f2) / (np.pi * self.n_ / (2 * np.pi))
        bp = (lp2 - lp1) / (2 * band[:, None])

        # Mirror the half filter around a zero centre tap, then window.
        # NOTE(review): classic SincNet puts 2*band at the centre tap, not 0;
        # kept as-is to stay weight-compatible with the trained checkpoint.
        centre = torch.zeros(self.out_channels, 1, device=bp.device)
        filters = torch.cat([bp.flip(1), centre, bp], dim=1)
        filters = filters * self.window_

        # (B, T) -> (B, 1, T), one conv filter per output channel.
        x = x.unsqueeze(1)
        return F.conv1d(x, filters.unsqueeze(1), padding=self.kernel_size // 2)
|
|
|
|
class Res2Block(nn.Module):
    """
    Res2Net-style multi-scale residual block.

    The channel dimension is split into ``scale`` equal groups.  Group 0
    passes through untouched; every later group is convolved after adding
    the previous group's output, growing the receptive field hierarchically
    within a single block.
    """

    def __init__(self, channels: int, scale: int = 8, dilation: int = 1):
        super().__init__()
        assert channels % scale == 0, \
            f"channels ({channels}) must be divisible by scale ({scale})"
        self.scale = scale
        width = channels // scale
        # One conv + BN pair per group, except the identity group 0.
        self.convs = nn.ModuleList(
            nn.Conv1d(width, width, 3, padding=dilation, dilation=dilation)
            for _ in range(scale - 1))
        self.bns = nn.ModuleList(
            nn.BatchNorm1d(width) for _ in range(scale - 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, C, T) -> (B, C, T) with hierarchical multi-scale mixing."""
        groups = torch.chunk(x, self.scale, dim=1)
        collected = [groups[0]]          # group 0 is an identity shortcut
        running = None
        for idx in range(1, self.scale):
            # Feed the previous group's output forward before convolving.
            fed = groups[idx] if running is None else running + groups[idx]
            running = F.gelu(self.bns[idx - 1](self.convs[idx - 1](fed)))
            collected.append(running)
        return torch.cat(collected, dim=1)
|
|
|
|
class GraphAttn(nn.Module):
    """
    Memory-efficient multi-head self-attention over temporal frames.

    Sequences longer than ``max_tokens`` frames are average-pooled down to
    ``max_tokens`` tokens before attention (bounding the O(N^2) cost) and
    linearly interpolated back to the original length so the caller can add
    the output as a residual.

    Args:
        dim: embedding width; must be divisible by ``heads``.
        heads: number of attention heads.
        max_tokens: attention sequence-length cap (default 64, matching
            the trained checkpoint's behaviour).
    """

    def __init__(self, dim: int, heads: int = 4, max_tokens: int = 64):
        super().__init__()
        # Fail fast here; otherwise the qkv reshape in forward() dies with
        # an opaque shape error.
        if dim % heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")
        self.heads = heads
        self.head_dim = dim // heads
        self.max_tokens = max_tokens          # plain attribute, not in state dict
        self.qkv = nn.Linear(dim, dim * 3)
        self.out = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, N, C) -> (B, N, C) attention output (same length as input)."""
        B, N, C = x.shape
        # Bound the attention cost for long sequences.
        if N > self.max_tokens:
            x_pool = F.adaptive_avg_pool1d(
                x.transpose(1, 2), self.max_tokens).transpose(1, 2)
        else:
            x_pool = x

        # Standard scaled dot-product multi-head attention on the
        # (possibly pooled) sequence.
        Bp, Np, Cp = x_pool.shape
        qkv = (self.qkv(x_pool)
               .reshape(Bp, Np, 3, self.heads, self.head_dim)
               .permute(2, 0, 3, 1, 4))
        q, k, v = qkv.unbind(0)
        attn = torch.softmax(
            q @ k.transpose(-2, -1) / (self.head_dim ** 0.5), dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(Bp, Np, Cp)
        out = self.out(out)

        # Upsample back only when pooling actually shortened the sequence;
        # interpolating to the same length would be a redundant identity op.
        if Np != N:
            out = F.interpolate(
                out.transpose(1, 2), size=N,
                mode='linear', align_corners=False).transpose(1, 2)
        return out
|
|
|
|
class AttentionPool(nn.Module):
    """Collapse a sequence into one vector via learned soft-attention weights."""

    def __init__(self, dim: int):
        super().__init__()
        # Produces one scalar relevance score per time step.
        self.attn = nn.Linear(dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, N, C) -> (B, C): softmax the scores over N, weighted-sum x."""
        scores = self.attn(x)               # (B, N, 1)
        weights = F.softmax(scores, dim=1)  # normalise across time steps
        return torch.sum(weights * x, dim=1)
|
|
|
|
| |
|
|
class AASISTDeepFake(nn.Module):
    """
    AASISTDeepFake -- memory-efficient raw-waveform audio spoof detector.

    Pipeline: SincConv front end -> 32x temporal downsampling -> Res2Block
    -> 2-layer CNN -> n_graph post-norm residual attention layers ->
    attention pooling -> MLP head producing one logit.

    Input : (B, 80_000) float32 waveform, normalised to [-1, 1]
    Output : (B, 1) raw logit; sigmoid(logit) estimates P(real)

    Prediction:
        sigmoid(logit) >= threshold -> Real (label 1)
        sigmoid(logit) <  threshold -> Fake (label 0)
    """

    def __init__(
        self,
        sinc_ch: int = 64,       # filters in the sinc front end
        sinc_kernel: int = 512,  # sinc kernel length (made odd inside SincConv)
        hidden: int = 128,       # channel width after the CNN
        graph_heads: int = 4,    # attention heads per GraphAttn layer
        n_graph: int = 2,        # number of attention layers
    ):
        super().__init__()
        # Learnable band-pass filter bank applied to the raw waveform.
        self.sinc = SincConv(sinc_ch, sinc_kernel, SAMPLE_RATE)
        self.bn_sinc = nn.BatchNorm1d(sinc_ch)

        # Two strided convs: 8x then 4x -> combined 32x temporal downsample,
        # channel count unchanged.
        self.downsample = nn.Sequential(
            nn.Conv1d(sinc_ch, sinc_ch, kernel_size=8, stride=8),
            nn.BatchNorm1d(sinc_ch), nn.GELU(),
            nn.Conv1d(sinc_ch, sinc_ch, kernel_size=4, stride=4),
            nn.BatchNorm1d(sinc_ch), nn.GELU(),
        )
        # Multi-scale residual mixing at constant channel count.
        self.encoder = nn.Sequential(
            Res2Block(sinc_ch), nn.BatchNorm1d(sinc_ch), nn.GELU(),
        )
        # Plain 2-layer CNN lifting channels from sinc_ch to hidden.
        self.cnn = nn.Sequential(
            nn.Conv1d(sinc_ch, hidden, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden), nn.GELU(),
            nn.Conv1d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden), nn.GELU(),
        )
        # Residual attention layers, each paired with its own LayerNorm.
        self.graph_layers = nn.ModuleList(
            [GraphAttn(hidden, graph_heads) for _ in range(n_graph)])
        self.layer_norms = nn.ModuleList(
            [nn.LayerNorm(hidden) for _ in range(n_graph)])
        # Sequence -> single vector, then a small MLP head to one logit.
        self.pool = AttentionPool(hidden)
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, 64),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, T) waveform -> (B, 1) raw logit."""
        # abs() turns the signed filter-bank response into a magnitude-like map.
        x = torch.abs(self.sinc(x))
        x = F.gelu(self.bn_sinc(x))
        x = self.downsample(x)       # (B, sinc_ch, ~T/32)
        x = self.encoder(x)
        x = self.cnn(x)              # (B, hidden, ~T/32)
        x = x.transpose(1, 2)        # (B, frames, hidden) for attention
        for attn, ln in zip(self.graph_layers, self.layer_norms):
            x = ln(x + attn(x))      # post-norm residual attention
        pooled = self.pool(x)        # (B, hidden)
        return self.classifier(pooled)
|
|
|
|
| |
|
|
def load_audio_model(
    checkpoint: str,
    device: Optional[torch.device] = None,
) -> AASISTDeepFake:
    """
    Load a trained AASISTDeepFake from a .pt state-dict checkpoint.

    Parameters
    ----------
    checkpoint : str
        Path to a file produced by ``torch.save(model.state_dict(), ...)``.
    device : torch.device, optional
        Target device; defaults to CUDA when available, otherwise CPU.

    Returns
    -------
    AASISTDeepFake
        Model moved to ``device`` and switched to eval mode.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AASISTDeepFake()
    # NOTE(review): the checkpoint is expected to be a plain state dict.  If
    # it can ever come from an untrusted source, prefer
    # torch.load(..., weights_only=True) (torch >= 1.13) to avoid arbitrary
    # pickle code execution.
    state = torch.load(checkpoint, map_location=device)
    model.load_state_dict(state)
    model.to(device).eval()
    return model
|
|