# Provenance (Hugging Face upload): pavankumarvk — "Upload 2 files",
# commit e950836 (verified).
"""
audio_model.py
==============
AASISTDeepFake model definition — matches the training notebook exactly.
Import this in both training scripts and the Gradio app (via audio_detector_inference.py).

Architecture:
    Raw waveform → SincConv → Downsample (32×) → Res2Block
    → CNN (2 layers) → GraphAttn (×2) → AttentionPool → Classifier

Label convention (from training dataset enumerate(["fake", "real"])):
    label = 0 → Fake
    label = 1 → Real
    sigmoid(logit) >= threshold → Real
    sigmoid(logit) <  threshold → Fake
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# ── Audio constants (must match training) ─────────────────────────────────────
SAMPLE_RATE = 16000                             # sample rate (Hz) used in training
MAX_DURATION = 5.0                              # clip length in seconds
MAX_SAMPLES = int(MAX_DURATION * SAMPLE_RATE)   # 80 000 samples per clip
# ── Sub-modules ───────────────────────────────────────────────────────────────
class SincConv(nn.Module):
    """
    Learnable sinc-function band-pass filter bank (SincNet-style front end).

    Only 2×out_channels scalars are learned: one low cut-off and one
    bandwidth per filter. They are initialised from mel-spaced bands between
    30 Hz and ``sample_rate / 2 - 100`` Hz.

    Input : (B, T) or (B, 1, T) waveform.
    Output: (B, out_channels, T). The kernel length is always odd and
            ``padding = kernel_size // 2``, so the time length is preserved.
    """

    @staticmethod
    def to_mel(hz):
        """Convert a frequency in Hz to the mel scale (2595·log10(1 + hz/700))."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert a mel-scale value back to Hz (inverse of ``to_mel``)."""
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, out_channels: int = 64, kernel_size: int = 512,
                 sample_rate: int = 16_000):
        super().__init__()
        self.out_channels = out_channels
        # Filters are laid out as [mirrored half, centre tap, half], so their
        # length is always odd; round an even request up so the window matches.
        self.kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
        self.sample_rate = sample_rate

        # Mel-spaced band edges between 30 Hz and just below Nyquist.
        low_hz, high_hz = 30, sample_rate / 2 - 100
        mel = np.linspace(self.to_mel(low_hz), self.to_mel(high_hz), out_channels + 1)
        hz = self.to_hz(mel)
        # Raw, unconstrained parameters of shape (out_channels, 1). forward()
        # maps them through abs() plus 50 Hz floors, so the effective cut-offs
        # stay positive without explicit clamping.
        self.low_hz_ = nn.Parameter(
            torch.tensor(hz[:-1], dtype=torch.float32).view(-1, 1))
        self.band_hz_ = nn.Parameter(
            torch.tensor(np.diff(hz), dtype=torch.float32).view(-1, 1))

        half = (self.kernel_size - 1) // 2
        n = torch.arange(1, half + 1, dtype=torch.float32)
        # (1, half): 2π·n / sample_rate for tap offsets n = 1..half.
        self.register_buffer('n_', (2 * np.pi * n / sample_rate).unsqueeze(0))
        # torch.hamming_window defaults to the periodic variant.
        self.register_buffer('window_', torch.hamming_window(self.kernel_size))

    def _build_filters(self) -> torch.Tensor:
        """Return the windowed band-pass kernels, shape (out_channels, kernel_size)."""
        low = 50 + torch.abs(self.low_hz_)
        high = torch.clamp(low + 50 + torch.abs(self.band_hz_),
                           max=self.sample_rate / 2)
        band = (high - low)[:, 0]

        f1 = torch.matmul(low, self.n_)     # (out_channels, half)
        f2 = torch.matmul(high, self.n_)
        denom = np.pi * self.n_ / (2 * np.pi)   # == π·n / sample_rate
        lp1 = torch.sin(f1) / denom
        lp2 = torch.sin(f2) / denom
        bp = (lp2 - lp1) / (2 * band[:, None])

        # NOTE: the centre tap is kept at zero (a textbook sinc band-pass would
        # have 1 here after normalisation). This is preserved so pretrained
        # checkpoints behave identically. new_zeros inherits bp's dtype and
        # device, so half/bfloat16 models do not end up with mixed-dtype filters.
        centre = bp.new_zeros(self.out_channels, 1)
        filters = torch.cat([bp.flip(1), centre, bp], dim=1)
        return filters * self.window_

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Filter a (B, T) or (B, 1, T) waveform into (B, out_channels, T)."""
        filters = self._build_filters()
        if x.dim() == 2:
            x = x.unsqueeze(1)      # (B, T) -> (B, 1, T)
        return F.conv1d(x, filters.unsqueeze(1), padding=self.kernel_size // 2)
class Res2Block(nn.Module):
    """
    Res2Net-style multi-scale block with inter-group accumulation.

    The input channels (dim 1) are split into ``scale`` equal groups. Group 0
    passes through unchanged. Every later group i is added to the processed
    output of group i-1 and then goes through conv → BatchNorm → GELU. The
    group outputs are concatenated back, so the output shape equals the input
    shape. There is no outer skip connection (``x + ...``).

    Args:
        channels: number of input/output channels; must be divisible by scale.
        scale: number of channel groups (>= 1; ``scale=1`` is the identity).
        dilation: dilation of each 3-tap conv (padding matches, length kept).

    Raises:
        ValueError: if ``scale < 1`` or ``channels % scale != 0``.
    """

    def __init__(self, channels: int, scale: int = 8, dilation: int = 1):
        super().__init__()
        # Raise instead of assert so validation survives ``python -O``.
        if scale < 1:
            raise ValueError(f"scale must be >= 1, got {scale}")
        if channels % scale != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by scale ({scale})")
        self.scale = scale
        width = channels // scale
        self.convs = nn.ModuleList([
            nn.Conv1d(width, width, 3, padding=dilation, dilation=dilation)
            for _ in range(scale - 1)
        ])
        self.bns = nn.ModuleList([nn.BatchNorm1d(width) for _ in range(scale - 1)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        chunks = torch.chunk(x, self.scale, dim=1)
        out = [chunks[0]]
        y = None
        # Group i (1-based) accumulates the processed previous group. With
        # scale == 1 the loop is empty and the input is returned unchanged.
        for i, (conv, bn) in enumerate(zip(self.convs, self.bns), start=1):
            y = chunks[i] if y is None else y + chunks[i]
            y = F.gelu(bn(conv(y)))
            out.append(y)
        return torch.cat(out, dim=1)
class GraphAttn(nn.Module):
    """
    Memory-efficient multi-head self-attention over temporal frames.

    Despite the name, no explicit graph structure is used. Sequences longer
    than ``max_tokens`` are average-pooled to ``max_tokens`` frames before
    attention. The attention output is then linearly interpolated back to the
    original length. Only the attention branch is returned; the caller adds
    the residual and normalises.

    Args:
        dim: feature size C of the (B, N, C) input; must be divisible by heads.
        heads: number of attention heads.
        max_tokens: pooling threshold and target length (default 64, the
            value the model was trained with).

    Raises:
        ValueError: on ``heads < 1``, ``dim % heads != 0`` or ``max_tokens < 1``.
    """

    def __init__(self, dim: int, heads: int = 4, max_tokens: int = 64):
        super().__init__()
        # Fail at construction instead of with an opaque reshape error later.
        if heads < 1:
            raise ValueError(f"heads must be >= 1, got {heads}")
        if dim % heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")
        if max_tokens < 1:
            raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")
        self.heads = heads
        self.head_dim = dim // heads
        self.max_tokens = max_tokens
        self.qkv = nn.Linear(dim, dim * 3)
        self.out = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        if N > self.max_tokens:
            x_pool = F.adaptive_avg_pool1d(
                x.transpose(1, 2), self.max_tokens).transpose(1, 2)
        else:
            x_pool = x
        Np = x_pool.shape[1]

        # (3, B, heads, Np, head_dim)
        qkv = (self.qkv(x_pool)
               .reshape(B, Np, 3, self.heads, self.head_dim)
               .permute(2, 0, 3, 1, 4))
        q, k, v = qkv.unbind(0)
        attn = torch.softmax(
            q @ k.transpose(-2, -1) / (self.head_dim ** 0.5), dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, Np, C)
        out = self.out(out)

        # Upsample back to the original length for the caller's residual add.
        # Same-length linear interpolation is an exact identity, so skip it.
        if Np != N:
            out = F.interpolate(
                out.transpose(1, 2), size=N,
                mode='linear', align_corners=False).transpose(1, 2)
        return out
class AttentionPool(nn.Module):
    """
    Collapse a (B, T, dim) sequence to (B, dim) with learned soft attention.

    A linear layer scores every time step. The scores are softmax-normalised
    over the time axis and used as weights in a weighted sum of the frames.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.attn = nn.Linear(dim, 1)   # one scalar score per time step

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scores = self.attn(x)              # (B, T, 1)
        weights = scores.softmax(dim=1)    # normalise over the time axis
        weighted = weights * x             # broadcast across the feature dim
        return weighted.sum(dim=1)         # (B, dim)
# ── Main model ────────────────────────────────────────────────────────────────
class AASISTDeepFake(nn.Module):
    """
    AASISTDeepFake — memory-efficient raw-waveform audio spoof detector.

    Input : (B, 80 000) float32 waveform, normalised to [-1, 1]
    Output: (B, 1) raw logit → sigmoid → P(real)

    Prediction:
        sigmoid(logit) >= threshold → Real (label 1)
        sigmoid(logit) <  threshold → Fake (label 0)

    The submodule attribute names (sinc, bn_sinc, downsample, encoder, cnn,
    graph_layers, layer_norms, pool, classifier) and the layer order inside
    each nn.Sequential determine the state-dict keys. Do not change them
    without converting existing checkpoints.
    """

    def __init__(
        self,
        sinc_ch: int = 64,
        sinc_kernel: int = 512,
        hidden: int = 128,
        graph_heads: int = 4,
        n_graph: int = 2,
    ):
        super().__init__()

        # Learnable band-pass front end applied to the raw waveform.
        self.sinc = SincConv(sinc_ch, sinc_kernel, SAMPLE_RATE)
        self.bn_sinc = nn.BatchNorm1d(sinc_ch)

        # Two strided conv stages (×8, then ×4) shrink the time axis 32-fold,
        # keeping the later attention layers affordable on long clips.
        down_layers = []
        for factor in (8, 4):
            down_layers.extend([
                nn.Conv1d(sinc_ch, sinc_ch, kernel_size=factor, stride=factor),
                nn.BatchNorm1d(sinc_ch),
                nn.GELU(),
            ])
        self.downsample = nn.Sequential(*down_layers)

        self.encoder = nn.Sequential(
            Res2Block(sinc_ch), nn.BatchNorm1d(sinc_ch), nn.GELU())

        # Two length-preserving conv stages: sinc_ch → hidden → hidden.
        conv_layers = []
        for in_ch in (sinc_ch, hidden):
            conv_layers.extend([
                nn.Conv1d(in_ch, hidden, kernel_size=3, padding=1),
                nn.BatchNorm1d(hidden),
                nn.GELU(),
            ])
        self.cnn = nn.Sequential(*conv_layers)

        # Attention stack with one LayerNorm per layer (residual + post-norm
        # is applied in forward()).
        self.graph_layers = nn.ModuleList(
            GraphAttn(hidden, graph_heads) for _ in range(n_graph))
        self.layer_norms = nn.ModuleList(
            nn.LayerNorm(hidden) for _ in range(n_graph))

        self.pool = AttentionPool(hidden)
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, 64),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T) → (B, sinc_ch, T): rectified band-pass responses.
        feats = F.gelu(self.bn_sinc(self.sinc(x).abs()))
        # (B, sinc_ch, T) → (B, sinc_ch, T // 32) → (B, hidden, T // 32)
        feats = self.cnn(self.encoder(self.downsample(feats)))
        # Attention layers work on (batch, time, features).
        seq = feats.transpose(1, 2)
        for attn_layer, norm in zip(self.graph_layers, self.layer_norms):
            seq = norm(seq + attn_layer(seq))
        return self.classifier(self.pool(seq))   # (B, 1)
# ── Helper ────────────────────────────────────────────────────────────────────
def load_audio_model(
    checkpoint: str,
    device: "torch.device | str | None" = None,
) -> AASISTDeepFake:
    """Load a trained AASISTDeepFake from a ``.pt`` state-dict checkpoint.

    Args:
        checkpoint: path to a file containing a state dict for the default
            ``AASISTDeepFake()`` architecture.
        device: target device, as a ``torch.device`` or a string such as
            ``"cpu"`` / ``"cuda:0"``. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` when keys/shapes do not match.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    model = AASISTDeepFake()
    try:
        # weights_only=True refuses arbitrary pickled objects, so a tampered
        # checkpoint cannot execute code. It is the default from torch 2.6.
        state_dict = torch.load(checkpoint, map_location=device, weights_only=True)
    except TypeError:
        # Older torch (< 1.13) has no ``weights_only`` keyword.
        state_dict = torch.load(checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()